mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-11 17:47:20 +10:00
Refactor cloudflared to depend on sing-cloudflared
Replace the inline cloudflared implementation with a thin adapter wrapping github.com/sagernet/sing-cloudflared. The protocol/cloudflare package is reduced to a single inbound.go that bridges the external Service to the sing-box router.
This commit is contained in:
13
go.mod
13
go.mod
@@ -14,7 +14,6 @@ require (
|
||||
github.com/go-chi/render v1.0.3
|
||||
github.com/godbus/dbus/v5 v5.2.2
|
||||
github.com/gofrs/uuid/v5 v5.4.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/insomniacslk/dhcp v0.0.0-20260220084031-5adc3eb26f91
|
||||
github.com/jsimonetti/rtnetlink v1.4.0
|
||||
github.com/keybase/go-keychain v0.0.1
|
||||
@@ -38,13 +37,14 @@ require (
|
||||
github.com/sagernet/gomobile v0.1.12
|
||||
github.com/sagernet/gvisor v0.0.0-20250811.0-sing-box-mod.1
|
||||
github.com/sagernet/quic-go v0.59.0-sing-box-mod.4
|
||||
github.com/sagernet/sing v0.8.3
|
||||
github.com/sagernet/sing v0.8.4
|
||||
github.com/sagernet/sing-cloudflared v0.0.0-20260407120610-7715dc2523fa
|
||||
github.com/sagernet/sing-mux v0.3.4
|
||||
github.com/sagernet/sing-quic v0.6.2-0.20260330152607-bf674c163212
|
||||
github.com/sagernet/sing-shadowsocks v0.2.8
|
||||
github.com/sagernet/sing-shadowsocks2 v0.2.1
|
||||
github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11
|
||||
github.com/sagernet/sing-tun v0.8.7-0.20260323120017-8eb4e8acfc2d
|
||||
github.com/sagernet/sing-tun v0.8.7-0.20260402180740-11f6e77ec6c6
|
||||
github.com/sagernet/sing-vmess v0.2.8-0.20250909125414-3aed155119a1
|
||||
github.com/sagernet/smux v1.5.50-sing-box-mod.1
|
||||
github.com/sagernet/tailscale v1.92.4-sing-box-1.13-mod.7
|
||||
@@ -64,7 +64,6 @@ require (
|
||||
google.golang.org/grpc v1.79.1
|
||||
google.golang.org/protobuf v1.36.11
|
||||
howett.net/plist v1.0.1
|
||||
zombiezen.com/go/capnproto2 v2.18.2+incompatible
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -75,7 +74,7 @@ require (
|
||||
github.com/andybalholm/brotli v1.1.0 // indirect
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||
github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6 // indirect
|
||||
github.com/coreos/go-oidc/v3 v3.12.0 // indirect
|
||||
github.com/coreos/go-oidc/v3 v3.17.0 // indirect
|
||||
github.com/database64128/netx-go v0.1.1 // indirect
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa // indirect
|
||||
@@ -95,6 +94,7 @@ require (
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/go-querystring v1.1.0 // indirect
|
||||
github.com/google/nftables v0.2.1-0.20240414091927-5e242ec57806 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/hashicorp/yamux v0.1.2 // indirect
|
||||
github.com/hdevalence/ed25519consensus v0.2.0 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
@@ -102,6 +102,7 @@ require (
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/mdlayher/socket v0.5.1 // indirect
|
||||
github.com/mitchellh/go-ps v1.0.0 // indirect
|
||||
github.com/philhofer/fwd v1.2.0 // indirect
|
||||
github.com/pierrec/lz4/v4 v4.1.21 // indirect
|
||||
github.com/pires/go-proxyproto v0.8.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
@@ -151,7 +152,6 @@ require (
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
github.com/tinylib/msgp v1.6.3 // indirect
|
||||
github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 // indirect
|
||||
github.com/x448/float16 v0.8.4 // indirect
|
||||
github.com/zeebo/blake3 v0.2.4 // indirect
|
||||
@@ -169,4 +169,5 @@ require (
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
lukechampine.com/blake3 v1.3.0 // indirect
|
||||
zombiezen.com/go/capnproto2 v2.18.2+incompatible // indirect
|
||||
)
|
||||
|
||||
16
go.sum
16
go.sum
@@ -28,8 +28,8 @@ github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9
|
||||
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
|
||||
github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6 h1:8h5+bWd7R6AYUslN6c6iuZWTKsKxUFDlpnmilO6R2n0=
|
||||
github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q=
|
||||
github.com/coreos/go-oidc/v3 v3.12.0 h1:sJk+8G2qq94rDI6ehZ71Bol3oUHy63qNYmkiSjrc/Jo=
|
||||
github.com/coreos/go-oidc/v3 v3.12.0/go.mod h1:gE3LgjOgFoHi9a4ce4/tJczr0Ai2/BoDhf0r5lltWI0=
|
||||
github.com/coreos/go-oidc/v3 v3.17.0 h1:hWBGaQfbi0iVviX4ibC7bk8OKT5qNr4klBaCHVNvehc=
|
||||
github.com/coreos/go-oidc/v3 v3.17.0/go.mod h1:wqPbKFrVnE90vty060SB40FCJ8fTHTxSwyXJqZH+sI8=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
github.com/cretz/bine v0.2.0 h1:8GiDRGlTgz+o8H9DSnsl+5MeBK4HsExxgl6WgzOCuZo=
|
||||
github.com/cretz/bine v0.2.0/go.mod h1:WU4o9QR9wWp8AVKtTM1XD5vUHkEqnf2vVSo6dBqbetI=
|
||||
@@ -112,6 +112,8 @@ github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zt
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/letsencrypt/challtestsrv v1.4.2 h1:0ON3ldMhZyWlfVNYYpFuWRTmZNnyfiL9Hh5YzC3JVwU=
|
||||
github.com/letsencrypt/challtestsrv v1.4.2/go.mod h1:GhqMqcSoeGpYd5zX5TgwA6er/1MbWzx/o7yuuVya+Wk=
|
||||
github.com/letsencrypt/pebble/v2 v2.10.0 h1:Wq6gYXlsY6ubqI3hhxsTzdyotvfdjFBxuwYqCLCnj/U=
|
||||
@@ -240,8 +242,10 @@ github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNen
|
||||
github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8=
|
||||
github.com/sagernet/quic-go v0.59.0-sing-box-mod.4 h1:6qvrUW79S+CrPwWz6cMePXohgjHoKxLo3c+MDhNwc3o=
|
||||
github.com/sagernet/quic-go v0.59.0-sing-box-mod.4/go.mod h1:OqILvS182CyOol5zNNo6bguvOGgXzV459+chpRaUC+4=
|
||||
github.com/sagernet/sing v0.8.3 h1:zGMy9M1deBPEew9pCYIUHKeE+/lDQ5A2CBqjBjjzqkA=
|
||||
github.com/sagernet/sing v0.8.3/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
|
||||
github.com/sagernet/sing v0.8.4 h1:Fj+jlY3F8vhcRfz/G/P3Dwcs5wqnmyNPT7u1RVVmjFI=
|
||||
github.com/sagernet/sing v0.8.4/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
|
||||
github.com/sagernet/sing-cloudflared v0.0.0-20260407120610-7715dc2523fa h1:165HiOfgfofJIirEp1NGSmsoJAi+++WhR29IhtAu4A4=
|
||||
github.com/sagernet/sing-cloudflared v0.0.0-20260407120610-7715dc2523fa/go.mod h1:bH2NKX+NpDTY1Zkxfboxw6MXB/ZywaNLmrDJYgKMJ2Y=
|
||||
github.com/sagernet/sing-mux v0.3.4 h1:ZQplKl8MNXutjzbMVtWvWG31fohhgOfCuUZR4dVQ8+s=
|
||||
github.com/sagernet/sing-mux v0.3.4/go.mod h1:QvlKMyNBNrQoyX4x+gq028uPbLM2XeRpWtDsWBJbFSk=
|
||||
github.com/sagernet/sing-quic v0.6.2-0.20260330152607-bf674c163212 h1:7mFOUqy+DyOj7qKGd1X54UMXbnbJiiMileK/tn17xYc=
|
||||
@@ -252,8 +256,8 @@ github.com/sagernet/sing-shadowsocks2 v0.2.1 h1:dWV9OXCeFPuYGHb6IRqlSptVnSzOelnq
|
||||
github.com/sagernet/sing-shadowsocks2 v0.2.1/go.mod h1:RnXS0lExcDAovvDeniJ4IKa2IuChrdipolPYWBv9hWQ=
|
||||
github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11 h1:tK+75l64tm9WvEFrYRE1t0YxoFdWQqw/h7Uhzj0vJ+w=
|
||||
github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11/go.mod h1:sWqKnGlMipCHaGsw1sTTlimyUpgzP4WP3pjhCsYt9oA=
|
||||
github.com/sagernet/sing-tun v0.8.7-0.20260323120017-8eb4e8acfc2d h1:vi0j6301f6H8t2GYgAC2PA2AdnGdMwkP34B4+N03Qt4=
|
||||
github.com/sagernet/sing-tun v0.8.7-0.20260323120017-8eb4e8acfc2d/go.mod h1:pLCo4o+LacXEzz0bhwhJkKBjLlKOGPBNOAZ97ZVZWzs=
|
||||
github.com/sagernet/sing-tun v0.8.7-0.20260402180740-11f6e77ec6c6 h1:HV2I7DicF5Ar8v6F55f03W5FviBB7jgvLhJSDwbFvbk=
|
||||
github.com/sagernet/sing-tun v0.8.7-0.20260402180740-11f6e77ec6c6/go.mod h1:pLCo4o+LacXEzz0bhwhJkKBjLlKOGPBNOAZ97ZVZWzs=
|
||||
github.com/sagernet/sing-vmess v0.2.8-0.20250909125414-3aed155119a1 h1:aSwUNYUkVyVvdmBSufR8/nRFonwJeKSIROxHcm5br9o=
|
||||
github.com/sagernet/sing-vmess v0.2.8-0.20250909125414-3aed155119a1/go.mod h1:P11scgTxMxVVQ8dlM27yNm3Cro40mD0+gHbnqrNGDuY=
|
||||
github.com/sagernet/smux v1.5.50-sing-box-mod.1 h1:XkJcivBC9V4wBjiGXIXZ229aZCU1hzcbp6kSkkyQ478=
|
||||
|
||||
@@ -6,6 +6,7 @@ type CloudflaredInboundOptions struct {
|
||||
Token string `json:"token,omitempty"`
|
||||
HAConnections int `json:"ha_connections,omitempty"`
|
||||
Protocol string `json:"protocol,omitempty"`
|
||||
PostQuantum bool `json:"post_quantum,omitempty"`
|
||||
ControlDialer DialerOptions `json:"control_dialer,omitempty"`
|
||||
TunnelDialer DialerOptions `json:"tunnel_dialer,omitempty"`
|
||||
EdgeIPVersion int `json:"edge_ip_version,omitempty"`
|
||||
|
||||
@@ -1,120 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
)
|
||||
|
||||
const accessJWTAssertionHeader = "Cf-Access-Jwt-Assertion"
|
||||
|
||||
var newAccessValidator = func(access AccessConfig, dialer N.Dialer) (accessValidator, error) {
|
||||
issuerURL := accessIssuerURL(access.TeamName, access.Environment)
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return dialer.DialContext(ctx, network, M.ParseSocksaddr(address))
|
||||
},
|
||||
},
|
||||
}
|
||||
keySet := oidc.NewRemoteKeySet(oidc.ClientContext(context.Background(), client), issuerURL+"/cdn-cgi/access/certs")
|
||||
verifier := oidc.NewVerifier(issuerURL, keySet, &oidc.Config{
|
||||
SkipClientIDCheck: true,
|
||||
})
|
||||
return &oidcAccessValidator{
|
||||
verifier: verifier,
|
||||
audTags: append([]string(nil), access.AudTag...),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type accessValidator interface {
|
||||
Validate(ctx context.Context, request *http.Request) error
|
||||
}
|
||||
|
||||
type oidcAccessValidator struct {
|
||||
verifier *oidc.IDTokenVerifier
|
||||
audTags []string
|
||||
}
|
||||
|
||||
func (v *oidcAccessValidator) Validate(ctx context.Context, request *http.Request) error {
|
||||
accessJWT := request.Header.Get(accessJWTAssertionHeader)
|
||||
if accessJWT == "" {
|
||||
return E.New("missing access jwt assertion")
|
||||
}
|
||||
token, err := v.verifier.Verify(ctx, accessJWT)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if accessTokenAudienceAllowed(token.Audience, v.audTags) {
|
||||
return nil
|
||||
}
|
||||
return E.New("access token audience does not match configured aud_tag")
|
||||
}
|
||||
|
||||
func accessTokenAudienceAllowed(tokenAudience []string, configuredAudTags []string) bool {
|
||||
for _, tokenAudTag := range tokenAudience {
|
||||
for _, configuredAudTag := range configuredAudTags {
|
||||
if configuredAudTag == tokenAudTag {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func accessIssuerURL(teamName string, environment string) string {
|
||||
if strings.EqualFold(environment, "fed") || strings.EqualFold(environment, "fips") {
|
||||
return fmt.Sprintf("https://%s.fed.cloudflareaccess.com", teamName)
|
||||
}
|
||||
return fmt.Sprintf("https://%s.cloudflareaccess.com", teamName)
|
||||
}
|
||||
|
||||
func validateAccessConfiguration(access AccessConfig) error {
|
||||
if !access.Required {
|
||||
return nil
|
||||
}
|
||||
if access.TeamName == "" && len(access.AudTag) > 0 {
|
||||
return E.New("access.team_name cannot be blank when access.aud_tag is present")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func accessValidatorKey(access AccessConfig) string {
|
||||
return access.TeamName + "|" + access.Environment + "|" + strings.Join(access.AudTag, ",")
|
||||
}
|
||||
|
||||
type accessValidatorCache struct {
|
||||
access sync.RWMutex
|
||||
values map[string]accessValidator
|
||||
dialer N.Dialer
|
||||
}
|
||||
|
||||
func (c *accessValidatorCache) Get(accessConfig AccessConfig) (accessValidator, error) {
|
||||
key := accessValidatorKey(accessConfig)
|
||||
c.access.RLock()
|
||||
validator, loaded := c.values[key]
|
||||
c.access.RUnlock()
|
||||
if loaded {
|
||||
return validator, nil
|
||||
}
|
||||
|
||||
validator, err := newAccessValidator(accessConfig, c.dialer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.access.Lock()
|
||||
c.values[key] = validator
|
||||
c.access.Unlock()
|
||||
return validator, nil
|
||||
}
|
||||
@@ -1,191 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/inbound"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type fakeAccessValidator struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (v *fakeAccessValidator) Validate(ctx context.Context, request *http.Request) error {
|
||||
return v.err
|
||||
}
|
||||
|
||||
func newAccessTestInbound(t *testing.T) *Inbound {
|
||||
t.Helper()
|
||||
logFactory, err := log.New(log.Options{Options: option.LogOptions{Level: "debug"}})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return &Inbound{
|
||||
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
|
||||
logger: logFactory.NewLogger("test"),
|
||||
accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: N.SystemDialer},
|
||||
router: &testRouter{},
|
||||
controlDialer: N.SystemDialer,
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAccessConfiguration(t *testing.T) {
|
||||
err := validateAccessConfiguration(AccessConfig{
|
||||
Required: true,
|
||||
AudTag: []string{"aud"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected access config validation error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessTokenAudienceAllowed(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
tokenAudience []string
|
||||
configuredTags []string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "matching audience",
|
||||
tokenAudience: []string{"aud-1", "aud-2"},
|
||||
configuredTags: []string{"aud-2"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "empty configured tags rejected",
|
||||
tokenAudience: []string{"aud-1"},
|
||||
configuredTags: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "non matching audience rejected",
|
||||
tokenAudience: []string{"aud-1"},
|
||||
configuredTags: []string{"aud-2"},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
allowed := accessTokenAudienceAllowed(testCase.tokenAudience, testCase.configuredTags)
|
||||
if allowed != testCase.expected {
|
||||
t.Fatalf("accessTokenAudienceAllowed(%v, %v) = %v, want %v", testCase.tokenAudience, testCase.configuredTags, allowed, testCase.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundTripHTTPAccessDenied(t *testing.T) {
|
||||
originalFactory := newAccessValidator
|
||||
defer func() {
|
||||
newAccessValidator = originalFactory
|
||||
}()
|
||||
newAccessValidator = func(access AccessConfig, dialer N.Dialer) (accessValidator, error) {
|
||||
return &fakeAccessValidator{err: E.New("forbidden")}, nil
|
||||
}
|
||||
|
||||
inboundInstance := newAccessTestInbound(t)
|
||||
respWriter := &fakeConnectResponseWriter{}
|
||||
request := &ConnectRequest{
|
||||
Type: ConnectionTypeHTTP,
|
||||
Dest: "http://127.0.0.1:8083/test",
|
||||
Metadata: []Metadata{
|
||||
{Key: metadataHTTPMethod, Val: http.MethodGet},
|
||||
{Key: metadataHTTPHost, Val: "example.com"},
|
||||
},
|
||||
}
|
||||
inboundInstance.handleHTTPService(context.Background(), nil, respWriter, request, adapter.InboundContext{}, ResolvedService{
|
||||
Kind: ResolvedServiceHTTP,
|
||||
Destination: M.ParseSocksaddr("127.0.0.1:8083"),
|
||||
OriginRequest: OriginRequestConfig{
|
||||
Access: AccessConfig{
|
||||
Required: true,
|
||||
TeamName: "team",
|
||||
},
|
||||
},
|
||||
})
|
||||
if respWriter.status != http.StatusForbidden {
|
||||
t.Fatalf("expected 403, got %d", respWriter.status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleHTTPServiceStatusAccessDenied(t *testing.T) {
|
||||
originalFactory := newAccessValidator
|
||||
defer func() {
|
||||
newAccessValidator = originalFactory
|
||||
}()
|
||||
newAccessValidator = func(access AccessConfig, dialer N.Dialer) (accessValidator, error) {
|
||||
return &fakeAccessValidator{err: E.New("forbidden")}, nil
|
||||
}
|
||||
|
||||
inboundInstance := newAccessTestInbound(t)
|
||||
respWriter := &fakeConnectResponseWriter{}
|
||||
request := &ConnectRequest{
|
||||
Type: ConnectionTypeHTTP,
|
||||
Dest: "https://example.com/status",
|
||||
Metadata: []Metadata{
|
||||
{Key: metadataHTTPMethod, Val: http.MethodGet},
|
||||
{Key: metadataHTTPHost, Val: "example.com"},
|
||||
},
|
||||
}
|
||||
inboundInstance.handleHTTPService(context.Background(), nil, respWriter, request, adapter.InboundContext{}, ResolvedService{
|
||||
Kind: ResolvedServiceStatus,
|
||||
OriginRequest: OriginRequestConfig{
|
||||
Access: AccessConfig{
|
||||
Required: true,
|
||||
TeamName: "team",
|
||||
},
|
||||
},
|
||||
StatusCode: 404,
|
||||
})
|
||||
if respWriter.status != http.StatusForbidden {
|
||||
t.Fatalf("expected 403, got %d", respWriter.status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleHTTPServiceStreamAccessDenied(t *testing.T) {
|
||||
originalFactory := newAccessValidator
|
||||
defer func() {
|
||||
newAccessValidator = originalFactory
|
||||
}()
|
||||
newAccessValidator = func(access AccessConfig, dialer N.Dialer) (accessValidator, error) {
|
||||
return &fakeAccessValidator{err: E.New("forbidden")}, nil
|
||||
}
|
||||
|
||||
inboundInstance := newAccessTestInbound(t)
|
||||
respWriter := &fakeConnectResponseWriter{}
|
||||
request := &ConnectRequest{
|
||||
Type: ConnectionTypeWebsocket,
|
||||
Dest: "https://example.com/ws",
|
||||
Metadata: []Metadata{
|
||||
{Key: metadataHTTPMethod, Val: http.MethodGet},
|
||||
{Key: metadataHTTPHost, Val: "example.com"},
|
||||
{Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="},
|
||||
},
|
||||
}
|
||||
inboundInstance.handleHTTPService(context.Background(), nil, respWriter, request, adapter.InboundContext{}, ResolvedService{
|
||||
Kind: ResolvedServiceStream,
|
||||
Destination: M.ParseSocksaddr("127.0.0.1:8080"),
|
||||
OriginRequest: OriginRequestConfig{
|
||||
Access: AccessConfig{
|
||||
Required: true,
|
||||
TeamName: "team",
|
||||
},
|
||||
},
|
||||
})
|
||||
if respWriter.status != http.StatusForbidden {
|
||||
t.Fatalf("expected 403, got %d", respWriter.status)
|
||||
}
|
||||
}
|
||||
@@ -1,75 +0,0 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIICiTCCAi6gAwIBAgIUXZP3MWb8MKwBE1Qbawsp1sfA/Y4wCgYIKoZIzj0EAwIw
|
||||
gY8xCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1T
|
||||
YW4gRnJhbmNpc2NvMRkwFwYDVQQKExBDbG91ZEZsYXJlLCBJbmMuMTgwNgYDVQQL
|
||||
Ey9DbG91ZEZsYXJlIE9yaWdpbiBTU0wgRUNDIENlcnRpZmljYXRlIEF1dGhvcml0
|
||||
eTAeFw0xOTA4MjMyMTA4MDBaFw0yOTA4MTUxNzAwMDBaMIGPMQswCQYDVQQGEwJV
|
||||
UzETMBEGA1UECBMKQ2FsaWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEZ
|
||||
MBcGA1UEChMQQ2xvdWRGbGFyZSwgSW5jLjE4MDYGA1UECxMvQ2xvdWRGbGFyZSBP
|
||||
cmlnaW4gU1NMIEVDQyBDZXJ0aWZpY2F0ZSBBdXRob3JpdHkwWTATBgcqhkjOPQIB
|
||||
BggqhkjOPQMBBwNCAASR+sGALuaGshnUbcxKry+0LEXZ4NY6JUAtSeA6g87K3jaA
|
||||
xpIg9G50PokpfWkhbarLfpcZu0UAoYy2su0EhN7wo2YwZDAOBgNVHQ8BAf8EBAMC
|
||||
AQYwEgYDVR0TAQH/BAgwBgEB/wIBAjAdBgNVHQ4EFgQUhTBdOypw1O3VkmcH/es5
|
||||
tBoOOKcwHwYDVR0jBBgwFoAUhTBdOypw1O3VkmcH/es5tBoOOKcwCgYIKoZIzj0E
|
||||
AwIDSQAwRgIhAKilfntP2ILGZjwajktkBtXE1pB4Y/fjAfLkIRUzrI15AiEA5UCL
|
||||
XYZZ9m2c3fKwIenMMojL1eqydsgqj/wK4p5kagQ=
|
||||
-----END CERTIFICATE-----
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIEADCCAuigAwIBAgIID+rOSdTGfGcwDQYJKoZIhvcNAQELBQAwgYsxCzAJBgNV
|
||||
BAYTAlVTMRkwFwYDVQQKExBDbG91ZEZsYXJlLCBJbmMuMTQwMgYDVQQLEytDbG91
|
||||
ZEZsYXJlIE9yaWdpbiBTU0wgQ2VydGlmaWNhdGUgQXV0aG9yaXR5MRYwFAYDVQQH
|
||||
Ew1TYW4gRnJhbmNpc2NvMRMwEQYDVQQIEwpDYWxpZm9ybmlhMB4XDTE5MDgyMzIx
|
||||
MDgwMFoXDTI5MDgxNTE3MDAwMFowgYsxCzAJBgNVBAYTAlVTMRkwFwYDVQQKExBD
|
||||
bG91ZEZsYXJlLCBJbmMuMTQwMgYDVQQLEytDbG91ZEZsYXJlIE9yaWdpbiBTU0wg
|
||||
Q2VydGlmaWNhdGUgQXV0aG9yaXR5MRYwFAYDVQQHEw1TYW4gRnJhbmNpc2NvMRMw
|
||||
EQYDVQQIEwpDYWxpZm9ybmlhMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC
|
||||
AQEAwEiVZ/UoQpHmFsHvk5isBxRehukP8DG9JhFev3WZtG76WoTthvLJFRKFCHXm
|
||||
V6Z5/66Z4S09mgsUuFwvJzMnE6Ej6yIsYNCb9r9QORa8BdhrkNn6kdTly3mdnykb
|
||||
OomnwbUfLlExVgNdlP0XoRoeMwbQ4598foiHblO2B/LKuNfJzAMfS7oZe34b+vLB
|
||||
yrP/1bgCSLdc1AxQc1AC0EsQQhgcyTJNgnG4va1c7ogPlwKyhbDyZ4e59N5lbYPJ
|
||||
SmXI/cAe3jXj1FBLJZkwnoDKe0v13xeF+nF32smSH0qB7aJX2tBMW4TWtFPmzs5I
|
||||
lwrFSySWAdwYdgxw180yKU0dvwIDAQABo2YwZDAOBgNVHQ8BAf8EBAMCAQYwEgYD
|
||||
VR0TAQH/BAgwBgEB/wIBAjAdBgNVHQ4EFgQUJOhTV118NECHqeuU27rhFnj8KaQw
|
||||
HwYDVR0jBBgwFoAUJOhTV118NECHqeuU27rhFnj8KaQwDQYJKoZIhvcNAQELBQAD
|
||||
ggEBAHwOf9Ur1l0Ar5vFE6PNrZWrDfQIMyEfdgSKofCdTckbqXNTiXdgbHs+TWoQ
|
||||
wAB0pfJDAHJDXOTCWRyTeXOseeOi5Btj5CnEuw3P0oXqdqevM1/+uWp0CM35zgZ8
|
||||
VD4aITxity0djzE6Qnx3Syzz+ZkoBgTnNum7d9A66/V636x4vTeqbZFBr9erJzgz
|
||||
hhurjcoacvRNhnjtDRM0dPeiCJ50CP3wEYuvUzDHUaowOsnLCjQIkWbR7Ni6KEIk
|
||||
MOz2U0OBSif3FTkhCgZWQKOOLo1P42jHC3ssUZAtVNXrCk3fw9/E15k8NPkBazZ6
|
||||
0iykLhH1trywrKRMVw67F44IE8Y=
|
||||
-----END CERTIFICATE-----
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIGCjCCA/KgAwIBAgIIV5G6lVbCLmEwDQYJKoZIhvcNAQENBQAwgZAxCzAJBgNV
|
||||
BAYTAlVTMRkwFwYDVQQKExBDbG91ZEZsYXJlLCBJbmMuMRQwEgYDVQQLEwtPcmln
|
||||
aW4gUHVsbDEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzETMBEGA1UECBMKQ2FsaWZv
|
||||
cm5pYTEjMCEGA1UEAxMab3JpZ2luLXB1bGwuY2xvdWRmbGFyZS5uZXQwHhcNMTkx
|
||||
MDEwMTg0NTAwWhcNMjkxMTAxMTcwMDAwWjCBkDELMAkGA1UEBhMCVVMxGTAXBgNV
|
||||
BAoTEENsb3VkRmxhcmUsIEluYy4xFDASBgNVBAsTC09yaWdpbiBQdWxsMRYwFAYD
|
||||
VQQHEw1TYW4gRnJhbmNpc2NvMRMwEQYDVQQIEwpDYWxpZm9ybmlhMSMwIQYDVQQD
|
||||
ExpvcmlnaW4tcHVsbC5jbG91ZGZsYXJlLm5ldDCCAiIwDQYJKoZIhvcNAQEBBQAD
|
||||
ggIPADCCAgoCggIBAN2y2zojYfl0bKfhp0AJBFeV+jQqbCw3sHmvEPwLmqDLqynI
|
||||
42tZXR5y914ZB9ZrwbL/K5O46exd/LujJnV2b3dzcx5rtiQzso0xzljqbnbQT20e
|
||||
ihx/WrF4OkZKydZzsdaJsWAPuplDH5P7J82q3re88jQdgE5hqjqFZ3clCG7lxoBw
|
||||
hLaazm3NJJlUfzdk97ouRvnFGAuXd5cQVx8jYOOeU60sWqmMe4QHdOvpqB91bJoY
|
||||
QSKVFjUgHeTpN8tNpKJfb9LIn3pun3bC9NKNHtRKMNX3Kl/sAPq7q/AlndvA2Kw3
|
||||
Dkum2mHQUGdzVHqcOgea9BGjLK2h7SuX93zTWL02u799dr6Xkrad/WShHchfjjRn
|
||||
aL35niJUDr02YJtPgxWObsrfOU63B8juLUphW/4BOjjJyAG5l9j1//aUGEi/sEe5
|
||||
lqVv0P78QrxoxR+MMXiJwQab5FB8TG/ac6mRHgF9CmkX90uaRh+OC07XjTdfSKGR
|
||||
PpM9hB2ZhLol/nf8qmoLdoD5HvODZuKu2+muKeVHXgw2/A6wM7OwrinxZiyBk5Hh
|
||||
CvaADH7PZpU6z/zv5NU5HSvXiKtCzFuDu4/Zfi34RfHXeCUfHAb4KfNRXJwMsxUa
|
||||
+4ZpSAX2G6RnGU5meuXpU5/V+DQJp/e69XyyY6RXDoMywaEFlIlXBqjRRA2pAgMB
|
||||
AAGjZjBkMA4GA1UdDwEB/wQEAwIBBjASBgNVHRMBAf8ECDAGAQH/AgECMB0GA1Ud
|
||||
DgQWBBRDWUsraYuA4REzalfNVzjann3F6zAfBgNVHSMEGDAWgBRDWUsraYuA4REz
|
||||
alfNVzjann3F6zANBgkqhkiG9w0BAQ0FAAOCAgEAkQ+T9nqcSlAuW/90DeYmQOW1
|
||||
QhqOor5psBEGvxbNGV2hdLJY8h6QUq48BCevcMChg/L1CkznBNI40i3/6heDn3IS
|
||||
zVEwXKf34pPFCACWVMZxbQjkNRTiH8iRur9EsaNQ5oXCPJkhwg2+IFyoPAAYURoX
|
||||
VcI9SCDUa45clmYHJ/XYwV1icGVI8/9b2JUqklnOTa5tugwIUi5sTfipNcJXHhgz
|
||||
6BKYDl0/UP0lLKbsUETXeTGDiDpxZYIgbcFrRDDkHC6BSvdWVEiH5b9mH2BON60z
|
||||
0O0j8EEKTwi9jnafVtZQXP/D8yoVowdFDjXcKkOPF/1gIh9qrFR6GdoPVgB3SkLc
|
||||
5ulBqZaCHm563jsvWb/kXJnlFxW+1bsO9BDD6DweBcGdNurgmH625wBXksSdD7y/
|
||||
fakk8DagjbjKShYlPEFOAqEcliwjF45eabL0t27MJV61O/jHzHL3dknXeE4BDa2j
|
||||
bA+JbyJeUMtU7KMsxvx82RmhqBEJJDBCJ3scVptvhDMRrtqDBW5JShxoAOcpFQGm
|
||||
iYWicn46nPDjgTU0bX1ZPpTpryXbvciVL5RkVBuyX2ntcOLDPlZWgxZCBp96x07F
|
||||
AnOzKgZk4RzZPNAxCXERVxajn/FLcOhglVAKo5H0ac+AitlQ0ip55D2/mf8o72tM
|
||||
fVQ6VpyjEXdiIXWUq/o=
|
||||
-----END CERTIFICATE-----
|
||||
@@ -1,72 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
)
|
||||
|
||||
func TestNewInboundRequiresToken(t *testing.T) {
|
||||
_, err := NewInbound(context.Background(), nil, log.NewNOPFactory().NewLogger("test"), "test", option.CloudflaredInboundOptions{})
|
||||
if err == nil {
|
||||
t.Fatal("expected missing token error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRegistrationResultRejectsNonRemoteManaged(t *testing.T) {
|
||||
err := validateRegistrationResult(&RegistrationResult{TunnelIsRemotelyManaged: false})
|
||||
if err == nil {
|
||||
t.Fatal("expected unsupported tunnel error")
|
||||
}
|
||||
if err != ErrNonRemoteManagedTunnelUnsupported {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeProtocolAutoUsesTokenStyleSentinel(t *testing.T) {
|
||||
protocol, err := normalizeProtocol("auto")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if protocol != "" {
|
||||
t.Fatalf("expected auto protocol to normalize to token-style empty sentinel, got %q", protocol)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveGracePeriodDefaultsToThirtySeconds(t *testing.T) {
|
||||
if got := resolveGracePeriod(nil); got != 30*time.Second {
|
||||
t.Fatalf("expected default grace period 30s, got %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveGracePeriodPreservesExplicitZero(t *testing.T) {
|
||||
var options option.CloudflaredInboundOptions
|
||||
if err := json.Unmarshal([]byte(`{"grace_period":"0s"}`), &options); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if options.GracePeriod == nil {
|
||||
t.Fatal("expected explicit grace period to be set")
|
||||
}
|
||||
if got := resolveGracePeriod(options.GracePeriod); got != 0 {
|
||||
t.Fatalf("expected explicit zero grace period, got %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveGracePeriodPreservesNonZeroValue(t *testing.T) {
|
||||
var options option.CloudflaredInboundOptions
|
||||
if err := json.Unmarshal([]byte(`{"grace_period":"45s"}`), &options); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if options.GracePeriod == nil {
|
||||
t.Fatal("expected explicit grace period to be set")
|
||||
}
|
||||
if got := resolveGracePeriod(options.GracePeriod); got != 45*time.Second {
|
||||
t.Fatalf("expected grace period 45s, got %s", got)
|
||||
}
|
||||
}
|
||||
@@ -1,268 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/quic-go"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type stubNetConn struct {
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
func newStubNetConn() *stubNetConn {
|
||||
return &stubNetConn{closed: make(chan struct{})}
|
||||
}
|
||||
|
||||
func (c *stubNetConn) Read(_ []byte) (int, error) { <-c.closed; return 0, io.EOF }
|
||||
func (c *stubNetConn) Write(b []byte) (int, error) { return len(b), nil }
|
||||
func (c *stubNetConn) Close() error { closeOnce(c.closed); return nil }
|
||||
func (c *stubNetConn) LocalAddr() net.Addr { return &net.TCPAddr{} }
|
||||
func (c *stubNetConn) RemoteAddr() net.Addr { return &net.TCPAddr{} }
|
||||
func (c *stubNetConn) SetDeadline(time.Time) error { return nil }
|
||||
func (c *stubNetConn) SetReadDeadline(time.Time) error { return nil }
|
||||
func (c *stubNetConn) SetWriteDeadline(time.Time) error { return nil }
|
||||
|
||||
type stubQUICConn struct {
|
||||
closed chan string
|
||||
}
|
||||
|
||||
func newStubQUICConn() *stubQUICConn {
|
||||
return &stubQUICConn{closed: make(chan string, 1)}
|
||||
}
|
||||
|
||||
func (c *stubQUICConn) OpenStream() (*quic.Stream, error) { return nil, errors.New("unused") }
|
||||
func (c *stubQUICConn) AcceptStream(context.Context) (*quic.Stream, error) {
|
||||
return nil, errors.New("unused")
|
||||
}
|
||||
|
||||
func (c *stubQUICConn) ReceiveDatagram(context.Context) ([]byte, error) {
|
||||
return nil, errors.New("unused")
|
||||
}
|
||||
func (c *stubQUICConn) SendDatagram([]byte) error { return nil }
|
||||
func (c *stubQUICConn) LocalAddr() net.Addr { return &net.UDPAddr{} }
|
||||
func (c *stubQUICConn) CloseWithError(_ quic.ApplicationErrorCode, reason string) error {
|
||||
select {
|
||||
case c.closed <- reason:
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockRegistrationClient struct {
|
||||
unregisterCalled chan struct{}
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
func newMockRegistrationClient() *mockRegistrationClient {
|
||||
return &mockRegistrationClient{
|
||||
unregisterCalled: make(chan struct{}, 1),
|
||||
closed: make(chan struct{}, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *mockRegistrationClient) RegisterConnection(context.Context, TunnelAuth, uuid.UUID, uint8, *RegistrationConnectionOptions) (*RegistrationResult, error) {
|
||||
return &RegistrationResult{}, nil
|
||||
}
|
||||
|
||||
func (c *mockRegistrationClient) Unregister(context.Context) error {
|
||||
select {
|
||||
case c.unregisterCalled <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mockRegistrationClient) Close() error {
|
||||
select {
|
||||
case c.closed <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func closeOnce(ch chan struct{}) {
|
||||
select {
|
||||
case <-ch:
|
||||
default:
|
||||
close(ch)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTP2GracefulShutdownWaitsForActiveRequests(t *testing.T) {
|
||||
conn := newStubNetConn()
|
||||
registrationClient := newMockRegistrationClient()
|
||||
connection := &HTTP2Connection{
|
||||
conn: conn,
|
||||
gracePeriod: 200 * time.Millisecond,
|
||||
registrationClient: registrationClient,
|
||||
registrationResult: &RegistrationResult{},
|
||||
serveCancel: func() {},
|
||||
}
|
||||
connection.activeRequests.Add(1)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
connection.gracefulShutdown()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-registrationClient.unregisterCalled:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected unregister call")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-conn.closed:
|
||||
t.Fatal("connection closed before active requests completed")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
|
||||
connection.activeRequests.Done()
|
||||
|
||||
select {
|
||||
case <-conn.closed:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected connection close after active requests finished")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected graceful shutdown to finish")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTP2GracefulShutdownTimesOut(t *testing.T) {
|
||||
conn := newStubNetConn()
|
||||
registrationClient := newMockRegistrationClient()
|
||||
connection := &HTTP2Connection{
|
||||
conn: conn,
|
||||
gracePeriod: 50 * time.Millisecond,
|
||||
registrationClient: registrationClient,
|
||||
registrationResult: &RegistrationResult{},
|
||||
serveCancel: func() {},
|
||||
}
|
||||
connection.activeRequests.Add(1)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
connection.gracefulShutdown()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-conn.closed:
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Fatal("expected connection close after grace timeout")
|
||||
}
|
||||
|
||||
connection.activeRequests.Done()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected graceful shutdown to finish after request completion")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQUICGracefulShutdownWaitsForDrainWindow(t *testing.T) {
|
||||
conn := newStubQUICConn()
|
||||
registrationClient := newMockRegistrationClient()
|
||||
serveCancelCalled := make(chan struct{}, 1)
|
||||
connection := &QUICConnection{
|
||||
conn: conn,
|
||||
gracePeriod: 80 * time.Millisecond,
|
||||
registrationClient: registrationClient,
|
||||
registrationResult: &RegistrationResult{},
|
||||
serveCancel: func() {
|
||||
select {
|
||||
case serveCancelCalled <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
connection.gracefulShutdown()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-registrationClient.unregisterCalled:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected unregister call")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-conn.closed:
|
||||
t.Fatal("connection closed before grace window elapsed")
|
||||
case <-time.After(20 * time.Millisecond):
|
||||
}
|
||||
|
||||
select {
|
||||
case reason := <-conn.closed:
|
||||
if reason != "graceful shutdown" {
|
||||
t.Fatalf("unexpected close reason: %q", reason)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected graceful close")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-serveCancelCalled:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected serve cancel to be called")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected graceful shutdown to finish")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQUICGracefulShutdownStopsWaitingWhenServeContextEnds(t *testing.T) {
|
||||
conn := newStubQUICConn()
|
||||
registrationClient := newMockRegistrationClient()
|
||||
serveCtx, cancelServe := context.WithCancel(context.Background())
|
||||
connection := &QUICConnection{
|
||||
conn: conn,
|
||||
gracePeriod: time.Second,
|
||||
registrationClient: registrationClient,
|
||||
registrationResult: &RegistrationResult{},
|
||||
serveCtx: serveCtx,
|
||||
serveCancel: func() {},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
connection.gracefulShutdown()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-registrationClient.unregisterCalled:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected unregister call")
|
||||
}
|
||||
|
||||
cancelServe()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("expected graceful shutdown to stop waiting once serve context ends")
|
||||
}
|
||||
}
|
||||
@@ -1,522 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
const (
|
||||
h2EdgeSNI = "h2.cftunnel.com"
|
||||
h2ResponseMetaCloudflared = `{"src":"cloudflared"}`
|
||||
h2ResponseMetaCloudflaredLimited = `{"src":"cloudflared","flow_rate_limited":true}`
|
||||
contentTypeHeader = "content-type"
|
||||
contentLengthHeader = "content-length"
|
||||
transferEncodingHeader = "transfer-encoding"
|
||||
chunkTransferEncoding = "chunked"
|
||||
sseContentType = "text/event-stream"
|
||||
grpcContentType = "application/grpc"
|
||||
ndjsonContentType = "application/x-ndjson"
|
||||
)
|
||||
|
||||
var flushableContentTypes = []string{sseContentType, grpcContentType, ndjsonContentType}
|
||||
|
||||
// HTTP2Connection manages a single HTTP/2 connection to the Cloudflare edge.
|
||||
// Uses role reversal: we dial the edge as a TLS client but serve HTTP/2 as server.
|
||||
type HTTP2Connection struct {
|
||||
conn net.Conn
|
||||
server *http2.Server
|
||||
logger log.ContextLogger
|
||||
edgeAddr *EdgeAddr
|
||||
connIndex uint8
|
||||
credentials Credentials
|
||||
connectorID uuid.UUID
|
||||
features []string
|
||||
gracePeriod time.Duration
|
||||
inbound *Inbound
|
||||
|
||||
numPreviousAttempts uint8
|
||||
registrationClient registrationRPCClient
|
||||
registrationResult *RegistrationResult
|
||||
controlStreamErr error
|
||||
|
||||
activeRequests sync.WaitGroup
|
||||
serveCancel context.CancelFunc
|
||||
registrationClose sync.Once
|
||||
shutdownOnce sync.Once
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
// NewHTTP2Connection dials the edge and establishes an HTTP/2 connection with role reversal.
|
||||
func NewHTTP2Connection(
|
||||
ctx context.Context,
|
||||
edgeAddr *EdgeAddr,
|
||||
connIndex uint8,
|
||||
credentials Credentials,
|
||||
connectorID uuid.UUID,
|
||||
features []string,
|
||||
numPreviousAttempts uint8,
|
||||
gracePeriod time.Duration,
|
||||
inbound *Inbound,
|
||||
logger log.ContextLogger,
|
||||
) (*HTTP2Connection, error) {
|
||||
rootCAs, err := cloudflareRootCertPool()
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "load Cloudflare root CAs")
|
||||
}
|
||||
|
||||
tlsConfig := newEdgeTLSConfig(rootCAs, h2EdgeSNI, nil)
|
||||
|
||||
tcpConn, err := inbound.tunnelDialer.DialContext(ctx, "tcp", M.SocksaddrFrom(edgeAddr.TCP.AddrPort().Addr(), edgeAddr.TCP.AddrPort().Port()))
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "dial edge TCP")
|
||||
}
|
||||
|
||||
tlsConn := tls.Client(tcpConn, tlsConfig)
|
||||
err = tlsConn.HandshakeContext(ctx)
|
||||
if err != nil {
|
||||
tcpConn.Close()
|
||||
return nil, E.Cause(err, "TLS handshake")
|
||||
}
|
||||
|
||||
return &HTTP2Connection{
|
||||
conn: tlsConn,
|
||||
server: &http2.Server{
|
||||
MaxConcurrentStreams: math.MaxUint32,
|
||||
},
|
||||
logger: logger,
|
||||
edgeAddr: edgeAddr,
|
||||
connIndex: connIndex,
|
||||
credentials: credentials,
|
||||
connectorID: connectorID,
|
||||
features: features,
|
||||
numPreviousAttempts: numPreviousAttempts,
|
||||
gracePeriod: gracePeriod,
|
||||
inbound: inbound,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Serve runs the HTTP/2 server. Blocks until the context is cancelled or the connection ends.
|
||||
func (c *HTTP2Connection) Serve(ctx context.Context) error {
|
||||
serveCtx, serveCancel := context.WithCancel(context.WithoutCancel(ctx))
|
||||
c.serveCancel = serveCancel
|
||||
|
||||
shutdownDone := make(chan struct{})
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
c.gracefulShutdown()
|
||||
close(shutdownDone)
|
||||
}()
|
||||
|
||||
c.server.ServeConn(c.conn, &http2.ServeConnOpts{
|
||||
Context: serveCtx,
|
||||
Handler: c,
|
||||
})
|
||||
|
||||
if ctx.Err() != nil {
|
||||
<-shutdownDone
|
||||
return ctx.Err()
|
||||
}
|
||||
if c.controlStreamErr != nil {
|
||||
return c.controlStreamErr
|
||||
}
|
||||
if c.registrationResult == nil {
|
||||
return E.New("edge connection closed before registration")
|
||||
}
|
||||
return E.New("edge connection closed")
|
||||
}
|
||||
|
||||
func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get(h2HeaderUpgrade) == h2UpgradeControlStream {
|
||||
c.handleControlStream(r.Context(), r, w)
|
||||
return
|
||||
}
|
||||
|
||||
c.activeRequests.Add(1)
|
||||
defer c.activeRequests.Done()
|
||||
|
||||
switch {
|
||||
case r.Header.Get(h2HeaderUpgrade) == h2UpgradeWebsocket:
|
||||
c.handleH2DataStream(r.Context(), r, w, ConnectionTypeWebsocket)
|
||||
case r.Header.Get(h2HeaderTCPSrc) != "":
|
||||
c.handleH2DataStream(r.Context(), r, w, ConnectionTypeTCP)
|
||||
case r.Header.Get(h2HeaderUpgrade) == h2UpgradeConfiguration:
|
||||
c.handleConfigurationUpdate(r, w)
|
||||
default:
|
||||
c.handleH2DataStream(r.Context(), r, w, ConnectionTypeHTTP)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HTTP2Connection) handleControlStream(ctx context.Context, r *http.Request, w http.ResponseWriter) {
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
c.logger.Error("response writer does not support flushing")
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
flusher.Flush()
|
||||
|
||||
stream := newHTTP2Stream(r.Body, &http2FlushWriter{w: w, flusher: flusher})
|
||||
|
||||
c.registrationClient = NewRegistrationClient(ctx, stream)
|
||||
|
||||
host, _, _ := net.SplitHostPort(c.conn.LocalAddr().String())
|
||||
originLocalIP := net.ParseIP(host)
|
||||
options := BuildConnectionOptions(c.connectorID, c.features, c.numPreviousAttempts, originLocalIP)
|
||||
result, err := c.registrationClient.RegisterConnection(
|
||||
ctx, c.credentials.Auth(), c.credentials.TunnelID, c.connIndex, options,
|
||||
)
|
||||
if err != nil {
|
||||
c.controlStreamErr = err
|
||||
c.logger.Error("register connection: ", err)
|
||||
go c.forceClose()
|
||||
return
|
||||
}
|
||||
if err := validateRegistrationResult(result); err != nil {
|
||||
c.controlStreamErr = err
|
||||
c.logger.Error("register connection: ", err)
|
||||
go c.forceClose()
|
||||
return
|
||||
}
|
||||
c.registrationResult = result
|
||||
c.inbound.notifyConnected(c.connIndex, "http2")
|
||||
|
||||
c.logger.Info("connected to ", result.Location,
|
||||
" (connection ", result.ConnectionID, ")")
|
||||
|
||||
<-ctx.Done()
|
||||
}
|
||||
|
||||
func (c *HTTP2Connection) handleH2DataStream(ctx context.Context, r *http.Request, w http.ResponseWriter, connectionType ConnectionType) {
|
||||
r.Header.Del(h2HeaderUpgrade)
|
||||
r.Header.Del(h2HeaderTCPSrc)
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
c.logger.Error("response writer does not support flushing")
|
||||
return
|
||||
}
|
||||
|
||||
var destination string
|
||||
if connectionType == ConnectionTypeTCP {
|
||||
destination = r.Host
|
||||
if destination == "" && r.URL != nil {
|
||||
destination = r.URL.Host
|
||||
}
|
||||
} else {
|
||||
if r.URL.Scheme == "" {
|
||||
r.URL.Scheme = "http"
|
||||
}
|
||||
if r.URL.Host == "" {
|
||||
r.URL.Host = r.Host
|
||||
}
|
||||
destination = r.URL.String()
|
||||
}
|
||||
|
||||
request := &ConnectRequest{
|
||||
Dest: destination,
|
||||
Type: connectionType,
|
||||
}
|
||||
request.Metadata = append(request.Metadata, Metadata{
|
||||
Key: metadataHTTPMethod,
|
||||
Val: r.Method,
|
||||
})
|
||||
request.Metadata = append(request.Metadata, Metadata{
|
||||
Key: metadataHTTPHost,
|
||||
Val: r.Host,
|
||||
})
|
||||
for name, values := range r.Header {
|
||||
for _, value := range values {
|
||||
request.Metadata = append(request.Metadata, Metadata{
|
||||
Key: metadataHTTPHeader + ":" + name,
|
||||
Val: value,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
flushState := &http2FlushState{shouldFlush: connectionType != ConnectionTypeHTTP}
|
||||
stream := &http2DataStream{
|
||||
reader: r.Body,
|
||||
writer: w,
|
||||
flusher: flusher,
|
||||
state: flushState,
|
||||
logger: c.logger,
|
||||
}
|
||||
respWriter := &http2ResponseWriter{
|
||||
writer: w,
|
||||
flusher: flusher,
|
||||
flushState: flushState,
|
||||
}
|
||||
|
||||
c.inbound.dispatchRequest(ctx, stream, respWriter, request)
|
||||
}
|
||||
|
||||
type h2ConfigurationUpdateBody struct {
|
||||
Version int32 `json:"version"`
|
||||
Config json.RawMessage `json:"config"`
|
||||
}
|
||||
|
||||
func (c *HTTP2Connection) handleConfigurationUpdate(r *http.Request, w http.ResponseWriter) {
|
||||
var body h2ConfigurationUpdateBody
|
||||
err := json.NewDecoder(r.Body).Decode(&body)
|
||||
if err != nil {
|
||||
c.logger.Error("decode configuration update: ", err)
|
||||
w.Header().Set(h2HeaderResponseMeta, h2ResponseMetaCloudflared)
|
||||
w.WriteHeader(http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
result := c.inbound.ApplyConfig(body.Version, body.Config)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if result.Err != nil {
|
||||
w.Write([]byte(`{"lastAppliedVersion":` + strconv.FormatInt(int64(result.LastAppliedVersion), 10) + `,"err":` + strconv.Quote(result.Err.Error()) + `}`))
|
||||
return
|
||||
}
|
||||
w.Write([]byte(`{"lastAppliedVersion":` + strconv.FormatInt(int64(result.LastAppliedVersion), 10) + `,"err":null}`))
|
||||
}
|
||||
|
||||
func (c *HTTP2Connection) gracefulShutdown() {
|
||||
c.shutdownOnce.Do(func() {
|
||||
if c.registrationClient == nil || c.registrationResult == nil {
|
||||
c.closeNow()
|
||||
return
|
||||
}
|
||||
|
||||
unregisterCtx, cancel := context.WithTimeout(context.Background(), c.gracePeriod)
|
||||
err := c.registrationClient.Unregister(unregisterCtx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
c.logger.Debug("failed to unregister: ", err)
|
||||
}
|
||||
c.closeRegistrationClient()
|
||||
c.waitForActiveRequests(c.gracePeriod)
|
||||
c.closeNow()
|
||||
})
|
||||
}
|
||||
|
||||
func (c *HTTP2Connection) forceClose() {
|
||||
c.shutdownOnce.Do(func() {
|
||||
c.closeNow()
|
||||
})
|
||||
}
|
||||
|
||||
func (c *HTTP2Connection) waitForActiveRequests(timeout time.Duration) {
|
||||
if timeout <= 0 {
|
||||
c.activeRequests.Wait()
|
||||
return
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
c.activeRequests.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-timer.C:
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HTTP2Connection) closeRegistrationClient() {
|
||||
c.registrationClose.Do(func() {
|
||||
if c.registrationClient != nil {
|
||||
_ = c.registrationClient.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (c *HTTP2Connection) closeNow() {
|
||||
c.closeOnce.Do(func() {
|
||||
_ = c.conn.Close()
|
||||
if c.serveCancel != nil {
|
||||
c.serveCancel()
|
||||
}
|
||||
c.closeRegistrationClient()
|
||||
c.activeRequests.Wait()
|
||||
})
|
||||
}
|
||||
|
||||
// Close closes the HTTP/2 connection.
|
||||
func (c *HTTP2Connection) Close() error {
|
||||
c.forceClose()
|
||||
return nil
|
||||
}
|
||||
|
||||
// http2Stream wraps an HTTP/2 request body (reader) and a flush-writer (writer) as an io.ReadWriteCloser.
|
||||
// Used for the control stream.
|
||||
type http2Stream struct {
|
||||
reader io.ReadCloser
|
||||
writer io.Writer
|
||||
}
|
||||
|
||||
func newHTTP2Stream(reader io.ReadCloser, writer io.Writer) *http2Stream {
|
||||
return &http2Stream{reader: reader, writer: writer}
|
||||
}
|
||||
|
||||
func (s *http2Stream) Read(p []byte) (int, error) { return s.reader.Read(p) }
|
||||
func (s *http2Stream) Write(p []byte) (int, error) { return s.writer.Write(p) }
|
||||
func (s *http2Stream) Close() error { return s.reader.Close() }
|
||||
|
||||
// http2FlushWriter wraps an http.ResponseWriter and flushes after every write.
|
||||
type http2FlushWriter struct {
|
||||
w http.ResponseWriter
|
||||
flusher http.Flusher
|
||||
}
|
||||
|
||||
func (w *http2FlushWriter) Write(p []byte) (int, error) {
|
||||
n, err := w.w.Write(p)
|
||||
if err == nil {
|
||||
w.flusher.Flush()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// http2DataStream wraps an HTTP/2 request/response pair as io.ReadWriteCloser for data streams.
|
||||
type http2DataStream struct {
|
||||
reader io.ReadCloser
|
||||
writer http.ResponseWriter
|
||||
flusher http.Flusher
|
||||
state *http2FlushState
|
||||
logger log.ContextLogger
|
||||
}
|
||||
|
||||
func (s *http2DataStream) Read(p []byte) (int, error) {
|
||||
return s.reader.Read(p)
|
||||
}
|
||||
|
||||
func (s *http2DataStream) Write(p []byte) (n int, err error) {
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
if s.logger != nil {
|
||||
s.logger.Debug("recovered from HTTP/2 data stream panic: ", recovered, "\n", string(debug.Stack()))
|
||||
}
|
||||
n = 0
|
||||
err = io.ErrClosedPipe
|
||||
}
|
||||
}()
|
||||
n, err = s.writer.Write(p)
|
||||
if err == nil && s.state != nil && s.state.shouldFlush {
|
||||
s.flusher.Flush()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (s *http2DataStream) Close() error {
|
||||
return s.reader.Close()
|
||||
}
|
||||
|
||||
// http2ResponseWriter translates ConnectResponse metadata to HTTP/2 response headers.
|
||||
type http2ResponseWriter struct {
|
||||
writer http.ResponseWriter
|
||||
flusher http.Flusher
|
||||
headersSent bool
|
||||
flushState *http2FlushState
|
||||
}
|
||||
|
||||
func (w *http2ResponseWriter) AddTrailer(name, value string) {
|
||||
if !w.headersSent {
|
||||
return
|
||||
}
|
||||
w.writer.Header().Add(http2.TrailerPrefix+name, value)
|
||||
}
|
||||
|
||||
func (w *http2ResponseWriter) WriteResponse(responseError error, metadata []Metadata) error {
|
||||
if w.headersSent {
|
||||
return nil
|
||||
}
|
||||
w.headersSent = true
|
||||
|
||||
if responseError != nil {
|
||||
if hasFlowConnectRateLimited(metadata) {
|
||||
w.writer.Header().Set(h2HeaderResponseMeta, h2ResponseMetaCloudflaredLimited)
|
||||
} else {
|
||||
w.writer.Header().Set(h2HeaderResponseMeta, h2ResponseMetaCloudflared)
|
||||
}
|
||||
w.writer.WriteHeader(http.StatusBadGateway)
|
||||
w.flusher.Flush()
|
||||
return nil
|
||||
}
|
||||
|
||||
statusCode := http.StatusOK
|
||||
userHeaders := make(http.Header)
|
||||
|
||||
for _, entry := range metadata {
|
||||
if entry.Key == metadataHTTPStatus {
|
||||
code, err := strconv.Atoi(entry.Val)
|
||||
if err == nil {
|
||||
statusCode = code
|
||||
}
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(entry.Key, metadataHTTPHeader+":") {
|
||||
headerName := strings.TrimPrefix(entry.Key, metadataHTTPHeader+":")
|
||||
lower := strings.ToLower(headerName)
|
||||
|
||||
if lower == "content-length" {
|
||||
w.writer.Header().Set(headerName, entry.Val)
|
||||
}
|
||||
|
||||
if !isControlResponseHeader(lower) || isWebsocketClientHeader(lower) {
|
||||
userHeaders.Add(headerName, entry.Val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
w.writer.Header().Set(h2HeaderResponseUser, SerializeHeaders(userHeaders))
|
||||
w.writer.Header().Set(h2HeaderResponseMeta, h2ResponseMetaOrigin)
|
||||
if w.flushState != nil && shouldFlushHTTPHeaders(userHeaders) {
|
||||
w.flushState.shouldFlush = true
|
||||
}
|
||||
|
||||
if statusCode == http.StatusSwitchingProtocols {
|
||||
statusCode = http.StatusOK
|
||||
}
|
||||
|
||||
w.writer.WriteHeader(statusCode)
|
||||
if w.flushState != nil && w.flushState.shouldFlush {
|
||||
w.flusher.Flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type http2FlushState struct {
|
||||
shouldFlush bool
|
||||
}
|
||||
|
||||
func shouldFlushHTTPHeaders(headers http.Header) bool {
|
||||
if headers.Get(contentLengthHeader) == "" {
|
||||
return true
|
||||
}
|
||||
if transferEncoding := strings.ToLower(headers.Get(transferEncodingHeader)); transferEncoding != "" && strings.Contains(transferEncoding, chunkTransferEncoding) {
|
||||
return true
|
||||
}
|
||||
contentType := strings.ToLower(headers.Get(contentTypeHeader))
|
||||
for _, flushable := range flushableContentTypes {
|
||||
if strings.HasPrefix(contentType, flushable) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -1,191 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
)
|
||||
|
||||
type captureHTTP2Writer struct {
|
||||
header http.Header
|
||||
flushCount int
|
||||
statusCode int
|
||||
body []byte
|
||||
panicWrite bool
|
||||
}
|
||||
|
||||
func (w *captureHTTP2Writer) Header() http.Header {
|
||||
if w.header == nil {
|
||||
w.header = make(http.Header)
|
||||
}
|
||||
return w.header
|
||||
}
|
||||
|
||||
func (w *captureHTTP2Writer) WriteHeader(statusCode int) {
|
||||
w.statusCode = statusCode
|
||||
}
|
||||
|
||||
func (w *captureHTTP2Writer) Write(p []byte) (int, error) {
|
||||
if w.panicWrite {
|
||||
panic("write after close")
|
||||
}
|
||||
w.body = append(w.body, p...)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (w *captureHTTP2Writer) Flush() {
|
||||
w.flushCount++
|
||||
}
|
||||
|
||||
func TestHTTP2NonStreamingResponseDoesNotFlush(t *testing.T) {
|
||||
writer := &captureHTTP2Writer{}
|
||||
flushState := &http2FlushState{}
|
||||
respWriter := &http2ResponseWriter{
|
||||
writer: writer,
|
||||
flusher: writer,
|
||||
flushState: flushState,
|
||||
}
|
||||
|
||||
err := respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusOK, http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
"Content-Length": []string{"2"},
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if writer.flushCount != 0 {
|
||||
t.Fatalf("expected no header flush for non-streaming response, got %d", writer.flushCount)
|
||||
}
|
||||
|
||||
stream := &http2DataStream{
|
||||
writer: writer,
|
||||
flusher: writer,
|
||||
state: flushState,
|
||||
logger: log.NewNOPFactory().NewLogger("test"),
|
||||
}
|
||||
if _, err := stream.Write([]byte("ok")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if writer.flushCount != 0 {
|
||||
t.Fatalf("expected no body flush for non-streaming response, got %d", writer.flushCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTP2StreamingResponsesFlush(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
header http.Header
|
||||
}{
|
||||
{
|
||||
name: "sse",
|
||||
header: http.Header{
|
||||
"Content-Type": []string{"text/event-stream"},
|
||||
"Content-Length": []string{"1"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "grpc",
|
||||
header: http.Header{
|
||||
"Content-Type": []string{"application/grpc"},
|
||||
"Content-Length": []string{"1"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ndjson",
|
||||
header: http.Header{
|
||||
"Content-Type": []string{"application/x-ndjson"},
|
||||
"Content-Length": []string{"1"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chunked",
|
||||
header: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
"Content-Length": []string{"-1"},
|
||||
"Transfer-Encoding": []string{"chunked"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no-content-length",
|
||||
header: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
writer := &captureHTTP2Writer{}
|
||||
flushState := &http2FlushState{}
|
||||
respWriter := &http2ResponseWriter{
|
||||
writer: writer,
|
||||
flusher: writer,
|
||||
flushState: flushState,
|
||||
}
|
||||
|
||||
err := respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusOK, testCase.header))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if writer.flushCount == 0 {
|
||||
t.Fatal("expected header flush for streaming response")
|
||||
}
|
||||
|
||||
stream := &http2DataStream{
|
||||
writer: writer,
|
||||
flusher: writer,
|
||||
state: flushState,
|
||||
logger: log.NewNOPFactory().NewLogger("test"),
|
||||
}
|
||||
if _, err := stream.Write([]byte("chunk")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if writer.flushCount < 2 {
|
||||
t.Fatalf("expected body flush for streaming response, got %d flushes", writer.flushCount)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTP2DataStreamWriteRecoversPanic(t *testing.T) {
|
||||
writer := &captureHTTP2Writer{panicWrite: true}
|
||||
stream := &http2DataStream{
|
||||
writer: writer,
|
||||
flusher: writer,
|
||||
state: &http2FlushState{shouldFlush: true},
|
||||
logger: log.NewNOPFactory().NewLogger("test"),
|
||||
}
|
||||
|
||||
_, err := stream.Write([]byte("panic"))
|
||||
if err != io.ErrClosedPipe {
|
||||
t.Fatalf("expected io.ErrClosedPipe, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleConfigurationUpdateDecodeFailureReturnsBadGateway(t *testing.T) {
|
||||
writer := &captureHTTP2Writer{}
|
||||
connection := &HTTP2Connection{
|
||||
logger: log.NewNOPFactory().NewLogger("test"),
|
||||
}
|
||||
request, err := http.NewRequest(http.MethodPost, "https://example.com", bytes.NewBufferString("{"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
connection.handleConfigurationUpdate(request, writer)
|
||||
|
||||
if writer.statusCode != http.StatusBadGateway {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusBadGateway, writer.statusCode)
|
||||
}
|
||||
if meta := writer.Header().Get(h2HeaderResponseMeta); meta != h2ResponseMetaCloudflared {
|
||||
t.Fatalf("unexpected response meta: %q", meta)
|
||||
}
|
||||
if len(writer.body) != 0 {
|
||||
t.Fatalf("expected empty response body, got %q", string(writer.body))
|
||||
}
|
||||
}
|
||||
@@ -1,436 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/quic-go"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
quicEdgeSNI = "quic.cftunnel.com"
|
||||
quicEdgeALPN = "argotunnel"
|
||||
|
||||
quicHandshakeIdleTimeout = 5 * time.Second
|
||||
quicMaxIdleTimeout = 5 * time.Second
|
||||
quicKeepAlivePeriod = 1 * time.Second
|
||||
)
|
||||
|
||||
func quicInitialPacketSize(ipVersion int) uint16 {
|
||||
initialPacketSize := uint16(1252)
|
||||
if ipVersion == 4 {
|
||||
initialPacketSize = 1232
|
||||
}
|
||||
return initialPacketSize
|
||||
}
|
||||
|
||||
// QUICConnection manages a single QUIC connection to the Cloudflare edge.
|
||||
type QUICConnection struct {
|
||||
conn quicConnection
|
||||
logger log.ContextLogger
|
||||
edgeAddr *EdgeAddr
|
||||
connIndex uint8
|
||||
credentials Credentials
|
||||
connectorID uuid.UUID
|
||||
datagramVersion string
|
||||
features []string
|
||||
numPreviousAttempts uint8
|
||||
gracePeriod time.Duration
|
||||
registrationClient registrationRPCClient
|
||||
registrationResult *RegistrationResult
|
||||
onConnected func()
|
||||
|
||||
serveCtx context.Context
|
||||
serveCancel context.CancelFunc
|
||||
registrationClose sync.Once
|
||||
shutdownOnce sync.Once
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
type quicStreamHandle interface {
|
||||
io.Reader
|
||||
io.Writer
|
||||
io.Closer
|
||||
CancelRead(code quic.StreamErrorCode)
|
||||
CancelWrite(code quic.StreamErrorCode)
|
||||
SetWriteDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
type quicConnection interface {
|
||||
OpenStream() (*quic.Stream, error)
|
||||
AcceptStream(ctx context.Context) (*quic.Stream, error)
|
||||
ReceiveDatagram(ctx context.Context) ([]byte, error)
|
||||
SendDatagram(data []byte) error
|
||||
LocalAddr() net.Addr
|
||||
CloseWithError(code quic.ApplicationErrorCode, reason string) error
|
||||
}
|
||||
|
||||
type closeableQUICConn struct {
|
||||
*quic.Conn
|
||||
udpConn *net.UDPConn
|
||||
}
|
||||
|
||||
func (c *closeableQUICConn) CloseWithError(code quic.ApplicationErrorCode, reason string) error {
|
||||
err := c.Conn.CloseWithError(code, reason)
|
||||
_ = c.udpConn.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
// NewQUICConnection dials the edge and establishes a QUIC connection.
|
||||
func NewQUICConnection(
|
||||
ctx context.Context,
|
||||
edgeAddr *EdgeAddr,
|
||||
connIndex uint8,
|
||||
credentials Credentials,
|
||||
connectorID uuid.UUID,
|
||||
datagramVersion string,
|
||||
features []string,
|
||||
numPreviousAttempts uint8,
|
||||
gracePeriod time.Duration,
|
||||
tunnelDialer N.Dialer,
|
||||
onConnected func(),
|
||||
logger log.ContextLogger,
|
||||
) (*QUICConnection, error) {
|
||||
rootCAs, err := cloudflareRootCertPool()
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "load Cloudflare root CAs")
|
||||
}
|
||||
|
||||
tlsConfig := newEdgeTLSConfig(rootCAs, quicEdgeSNI, []string{quicEdgeALPN})
|
||||
|
||||
quicConfig := &quic.Config{
|
||||
HandshakeIdleTimeout: quicHandshakeIdleTimeout,
|
||||
MaxIdleTimeout: quicMaxIdleTimeout,
|
||||
KeepAlivePeriod: quicKeepAlivePeriod,
|
||||
MaxIncomingStreams: 1 << 60,
|
||||
MaxIncomingUniStreams: 1 << 60,
|
||||
EnableDatagrams: true,
|
||||
InitialPacketSize: quicInitialPacketSize(edgeAddr.IPVersion),
|
||||
}
|
||||
|
||||
udpConn, err := createUDPConnForConnIndex(ctx, edgeAddr, tunnelDialer)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "listen UDP for QUIC edge")
|
||||
}
|
||||
|
||||
conn, err := quic.Dial(ctx, udpConn, edgeAddr.UDP, tlsConfig, quicConfig)
|
||||
if err != nil {
|
||||
udpConn.Close()
|
||||
return nil, E.Cause(err, "dial QUIC edge")
|
||||
}
|
||||
|
||||
return &QUICConnection{
|
||||
conn: &closeableQUICConn{Conn: conn, udpConn: udpConn},
|
||||
logger: logger,
|
||||
edgeAddr: edgeAddr,
|
||||
connIndex: connIndex,
|
||||
credentials: credentials,
|
||||
connectorID: connectorID,
|
||||
datagramVersion: datagramVersion,
|
||||
features: features,
|
||||
numPreviousAttempts: numPreviousAttempts,
|
||||
gracePeriod: gracePeriod,
|
||||
onConnected: onConnected,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// createUDPConnForConnIndex creates a UDP socket for QUIC via the tunnel dialer.
|
||||
// Unlike cloudflared, we do not attempt to reuse previously-bound ports across
|
||||
// reconnects — the dialer interface does not support specifying local ports,
|
||||
// and fixed port binding is not important for our use case.
|
||||
// We also do not apply Darwin-specific udp4/udp6 network selection to work around
|
||||
// quic-go#3793 (DF bit on macOS dual-stack); the dialer controls network selection
|
||||
// and this is a non-critical platform-specific limitation.
|
||||
func createUDPConnForConnIndex(ctx context.Context, edgeAddr *EdgeAddr, tunnelDialer N.Dialer) (*net.UDPConn, error) {
|
||||
packetConn, err := tunnelDialer.ListenPacket(ctx, M.SocksaddrFrom(edgeAddr.UDP.AddrPort().Addr(), edgeAddr.UDP.AddrPort().Port()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
udpConn, ok := packetConn.(*net.UDPConn)
|
||||
if !ok {
|
||||
packetConn.Close()
|
||||
return nil, fmt.Errorf("unexpected packet conn type %T", packetConn)
|
||||
}
|
||||
return udpConn, nil
|
||||
}
|
||||
|
||||
// Serve runs the QUIC connection: registers, accepts streams, handles datagrams.
|
||||
// Blocks until the context is cancelled or a fatal error occurs.
|
||||
func (q *QUICConnection) Serve(ctx context.Context, handler StreamHandler) error {
|
||||
controlStream, err := q.conn.OpenStream()
|
||||
if err != nil {
|
||||
return E.Cause(err, "open control stream")
|
||||
}
|
||||
|
||||
err = q.register(ctx, controlStream)
|
||||
if err != nil {
|
||||
controlStream.Close()
|
||||
q.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
q.logger.Info("connected to ", q.registrationResult.Location,
|
||||
" (connection ", q.registrationResult.ConnectionID, ")")
|
||||
|
||||
serveCtx, serveCancel := context.WithCancel(context.WithoutCancel(ctx))
|
||||
q.serveCtx = serveCtx
|
||||
q.serveCancel = serveCancel
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
errChan <- q.acceptStreams(serveCtx, handler)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
errChan <- q.handleDatagrams(serveCtx, handler)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
q.gracefulShutdown()
|
||||
<-errChan
|
||||
return ctx.Err()
|
||||
case err = <-errChan:
|
||||
q.forceClose()
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (q *QUICConnection) register(ctx context.Context, stream *quic.Stream) error {
|
||||
q.registrationClient = NewRegistrationClient(ctx, newStreamReadWriteCloser(stream))
|
||||
|
||||
host, _, _ := net.SplitHostPort(q.conn.LocalAddr().String())
|
||||
originLocalIP := net.ParseIP(host)
|
||||
options := BuildConnectionOptions(q.connectorID, q.features, q.numPreviousAttempts, originLocalIP)
|
||||
result, err := q.registrationClient.RegisterConnection(
|
||||
ctx, q.credentials.Auth(), q.credentials.TunnelID, q.connIndex, options,
|
||||
)
|
||||
if err != nil {
|
||||
return E.Cause(err, "register connection")
|
||||
}
|
||||
if err := validateRegistrationResult(result); err != nil {
|
||||
return err
|
||||
}
|
||||
q.registrationResult = result
|
||||
if q.onConnected != nil {
|
||||
q.onConnected()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *QUICConnection) acceptStreams(ctx context.Context, handler StreamHandler) error {
|
||||
for {
|
||||
stream, err := q.conn.AcceptStream(ctx)
|
||||
if err != nil {
|
||||
return E.Cause(err, "accept stream")
|
||||
}
|
||||
go q.handleStream(ctx, stream, handler)
|
||||
}
|
||||
}
|
||||
|
||||
func (q *QUICConnection) handleStream(ctx context.Context, stream quicStreamHandle, handler StreamHandler) {
|
||||
rwc := newStreamReadWriteCloser(stream)
|
||||
defer rwc.Close()
|
||||
|
||||
streamType, err := ReadStreamSignature(rwc)
|
||||
if err != nil {
|
||||
q.logger.Debug("failed to read stream signature: ", err)
|
||||
stream.CancelWrite(0)
|
||||
return
|
||||
}
|
||||
|
||||
switch streamType {
|
||||
case StreamTypeData:
|
||||
request, err := ReadConnectRequest(rwc)
|
||||
if err != nil {
|
||||
q.logger.Debug("failed to read connect request: ", err)
|
||||
stream.CancelWrite(0)
|
||||
return
|
||||
}
|
||||
handler.HandleDataStream(ctx, &nopCloserReadWriter{ReadWriteCloser: rwc}, request, q.connIndex)
|
||||
|
||||
case StreamTypeRPC:
|
||||
handler.HandleRPCStreamWithSender(ctx, rwc, q.connIndex, q)
|
||||
}
|
||||
}
|
||||
|
||||
func (q *QUICConnection) handleDatagrams(ctx context.Context, handler StreamHandler) error {
|
||||
for {
|
||||
datagram, err := q.conn.ReceiveDatagram(ctx)
|
||||
if err != nil {
|
||||
return E.Cause(err, "receive datagram")
|
||||
}
|
||||
handler.HandleDatagram(ctx, datagram, q)
|
||||
}
|
||||
}
|
||||
|
||||
// SendDatagram sends a QUIC datagram to the edge.
|
||||
func (q *QUICConnection) SendDatagram(data []byte) error {
|
||||
return q.conn.SendDatagram(data)
|
||||
}
|
||||
|
||||
func (q *QUICConnection) DatagramVersion() string {
|
||||
return q.datagramVersion
|
||||
}
|
||||
|
||||
func (q *QUICConnection) OpenRPCStream(ctx context.Context) (io.ReadWriteCloser, error) {
|
||||
stream, err := q.conn.OpenStream()
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "open rpc stream")
|
||||
}
|
||||
rwc := newStreamReadWriteCloser(stream)
|
||||
if err := WriteRPCStreamSignature(rwc); err != nil {
|
||||
rwc.Close()
|
||||
return nil, E.Cause(err, "write rpc stream signature")
|
||||
}
|
||||
return rwc, nil
|
||||
}
|
||||
|
||||
func (q *QUICConnection) gracefulShutdown() {
|
||||
q.shutdownOnce.Do(func() {
|
||||
if q.registrationClient == nil || q.registrationResult == nil {
|
||||
q.closeNow("connection closed")
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), q.gracePeriod)
|
||||
err := q.registrationClient.Unregister(ctx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
q.logger.Debug("failed to unregister: ", err)
|
||||
}
|
||||
q.closeRegistrationClient()
|
||||
if q.gracePeriod > 0 {
|
||||
waitCtx := q.serveCtx
|
||||
if waitCtx == nil {
|
||||
waitCtx = context.Background()
|
||||
}
|
||||
timer := time.NewTimer(q.gracePeriod)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-timer.C:
|
||||
case <-waitCtx.Done():
|
||||
}
|
||||
}
|
||||
q.closeNow("graceful shutdown")
|
||||
})
|
||||
}
|
||||
|
||||
func (q *QUICConnection) forceClose() {
|
||||
q.shutdownOnce.Do(func() {
|
||||
q.closeNow("connection closed")
|
||||
})
|
||||
}
|
||||
|
||||
func (q *QUICConnection) closeRegistrationClient() {
|
||||
q.registrationClose.Do(func() {
|
||||
if q.registrationClient != nil {
|
||||
_ = q.registrationClient.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (q *QUICConnection) closeNow(reason string) {
|
||||
q.closeOnce.Do(func() {
|
||||
if q.serveCancel != nil {
|
||||
q.serveCancel()
|
||||
}
|
||||
q.closeRegistrationClient()
|
||||
_ = q.conn.CloseWithError(0, reason)
|
||||
})
|
||||
}
|
||||
|
||||
// Close closes the QUIC connection immediately.
|
||||
func (q *QUICConnection) Close() error {
|
||||
q.forceClose()
|
||||
return nil
|
||||
}
|
||||
|
||||
// StreamHandler handles incoming edge streams and datagrams.
|
||||
type StreamHandler interface {
|
||||
HandleDataStream(ctx context.Context, stream io.ReadWriteCloser, request *ConnectRequest, connIndex uint8)
|
||||
HandleRPCStream(ctx context.Context, stream io.ReadWriteCloser, connIndex uint8)
|
||||
HandleRPCStreamWithSender(ctx context.Context, stream io.ReadWriteCloser, connIndex uint8, sender DatagramSender)
|
||||
HandleDatagram(ctx context.Context, datagram []byte, sender DatagramSender)
|
||||
}
|
||||
|
||||
// DatagramSender can send QUIC datagrams back to the edge.
|
||||
type DatagramSender interface {
|
||||
SendDatagram(data []byte) error
|
||||
}
|
||||
|
||||
// streamReadWriteCloser adapts a *quic.Stream to io.ReadWriteCloser
|
||||
// with mutex-protected writes and safe close semantics.
|
||||
type streamReadWriteCloser struct {
|
||||
stream quicStreamHandle
|
||||
writeAccess sync.Mutex
|
||||
}
|
||||
|
||||
func newStreamReadWriteCloser(stream quicStreamHandle) *streamReadWriteCloser {
|
||||
return &streamReadWriteCloser{stream: stream}
|
||||
}
|
||||
|
||||
func (s *streamReadWriteCloser) Read(p []byte) (int, error) {
|
||||
return s.stream.Read(p)
|
||||
}
|
||||
|
||||
func (s *streamReadWriteCloser) Write(p []byte) (int, error) {
|
||||
s.writeAccess.Lock()
|
||||
defer s.writeAccess.Unlock()
|
||||
return s.stream.Write(p)
|
||||
}
|
||||
|
||||
func (s *streamReadWriteCloser) Close() error {
|
||||
_ = s.stream.SetWriteDeadline(time.Now())
|
||||
s.writeAccess.Lock()
|
||||
defer s.writeAccess.Unlock()
|
||||
s.stream.CancelRead(0)
|
||||
return s.stream.Close()
|
||||
}
|
||||
|
||||
// nopCloserReadWriter lets handlers stop consuming the read side without closing
|
||||
// the underlying stream write side. This matches cloudflared's QUIC HTTP behavior,
|
||||
// where the request body can be closed before the response is fully written.
|
||||
type nopCloserReadWriter struct {
|
||||
io.ReadWriteCloser
|
||||
|
||||
sawEOF bool
|
||||
closed uint32
|
||||
}
|
||||
|
||||
func (n *nopCloserReadWriter) Read(p []byte) (int, error) {
|
||||
if n.sawEOF {
|
||||
return 0, io.EOF
|
||||
}
|
||||
if atomic.LoadUint32(&n.closed) > 0 {
|
||||
return 0, fmt.Errorf("closed by handler")
|
||||
}
|
||||
|
||||
readLen, err := n.ReadWriteCloser.Read(p)
|
||||
if err == io.EOF {
|
||||
n.sawEOF = true
|
||||
}
|
||||
return readLen, err
|
||||
}
|
||||
|
||||
func (n *nopCloserReadWriter) Close() error {
|
||||
atomic.StoreUint32(&n.closed, 1)
|
||||
return nil
|
||||
}
|
||||
@@ -1,129 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/quic-go"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
)
|
||||
|
||||
func TestQUICInitialPacketSize(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
ipVersion int
|
||||
expected uint16
|
||||
}{
|
||||
{name: "ipv4", ipVersion: 4, expected: 1232},
|
||||
{name: "ipv6", ipVersion: 6, expected: 1252},
|
||||
{name: "default", ipVersion: 0, expected: 1252},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
if actual := quicInitialPacketSize(testCase.ipVersion); actual != testCase.expected {
|
||||
t.Fatalf("quicInitialPacketSize(%d) = %d, want %d", testCase.ipVersion, actual, testCase.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockReadWriteCloser struct {
|
||||
reader strings.Reader
|
||||
writes []byte
|
||||
}
|
||||
|
||||
func (m *mockReadWriteCloser) Read(p []byte) (int, error) {
|
||||
return m.reader.Read(p)
|
||||
}
|
||||
|
||||
func (m *mockReadWriteCloser) Write(p []byte) (int, error) {
|
||||
m.writes = append(m.writes, p...)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (m *mockReadWriteCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestNOPCloserReadWriterCloseOnlyStopsReads(t *testing.T) {
|
||||
inner := &mockReadWriteCloser{reader: *strings.NewReader("payload")}
|
||||
wrapper := &nopCloserReadWriter{ReadWriteCloser: inner}
|
||||
|
||||
if err := wrapper.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := wrapper.Read(make([]byte, 1)); err == nil {
|
||||
t.Fatal("expected read to fail after close")
|
||||
}
|
||||
|
||||
if _, err := wrapper.Write([]byte("response")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(inner.writes) != "response" {
|
||||
t.Fatalf("unexpected writes %q", inner.writes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNOPCloserReadWriterTracksEOF(t *testing.T) {
|
||||
inner := &mockReadWriteCloser{reader: *strings.NewReader("")}
|
||||
wrapper := &nopCloserReadWriter{ReadWriteCloser: inner}
|
||||
|
||||
if _, err := wrapper.Read(make([]byte, 1)); err != io.EOF {
|
||||
t.Fatalf("expected EOF, got %v", err)
|
||||
}
|
||||
if _, err := wrapper.Read(make([]byte, 1)); err != io.EOF {
|
||||
t.Fatalf("expected cached EOF, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
type fakeQUICStream struct {
|
||||
reader strings.Reader
|
||||
cancelWriteCount int
|
||||
}
|
||||
|
||||
func (s *fakeQUICStream) Read(p []byte) (int, error) { return s.reader.Read(p) }
|
||||
func (s *fakeQUICStream) Write(p []byte) (int, error) { return len(p), nil }
|
||||
func (s *fakeQUICStream) Close() error { return nil }
|
||||
func (s *fakeQUICStream) CancelRead(quic.StreamErrorCode) {}
|
||||
func (s *fakeQUICStream) CancelWrite(quic.StreamErrorCode) {
|
||||
s.cancelWriteCount++
|
||||
}
|
||||
func (s *fakeQUICStream) SetWriteDeadline(time.Time) error { return nil }
|
||||
|
||||
func TestHandleStreamCancelsWriteOnSignatureError(t *testing.T) {
|
||||
stream := &fakeQUICStream{reader: *strings.NewReader("broken")}
|
||||
connection := &QUICConnection{logger: log.NewNOPFactory().NewLogger("test")}
|
||||
|
||||
connection.handleStream(context.Background(), stream, nil)
|
||||
if stream.cancelWriteCount != 1 {
|
||||
t.Fatalf("expected CancelWrite on signature error, got %d", stream.cancelWriteCount)
|
||||
}
|
||||
}
|
||||
|
||||
type nopStreamHandler struct{}
|
||||
|
||||
func (nopStreamHandler) HandleDataStream(context.Context, io.ReadWriteCloser, *ConnectRequest, uint8) {
|
||||
}
|
||||
func (nopStreamHandler) HandleRPCStream(context.Context, io.ReadWriteCloser, uint8) {}
|
||||
func (nopStreamHandler) HandleRPCStreamWithSender(context.Context, io.ReadWriteCloser, uint8, DatagramSender) {
|
||||
}
|
||||
func (nopStreamHandler) HandleDatagram(context.Context, []byte, DatagramSender) {}
|
||||
|
||||
func TestHandleStreamCancelsWriteOnConnectRequestError(t *testing.T) {
|
||||
stream := &fakeQUICStream{
|
||||
reader: *strings.NewReader(string(dataStreamSignature[:])),
|
||||
}
|
||||
connection := &QUICConnection{logger: log.NewNOPFactory().NewLogger("test")}
|
||||
|
||||
connection.handleStream(context.Background(), stream, nopStreamHandler{})
|
||||
if stream.cancelWriteCount != 1 {
|
||||
t.Fatalf("expected CancelWrite on connect request error, got %d", stream.cancelWriteCount)
|
||||
}
|
||||
}
|
||||
@@ -1,222 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"zombiezen.com/go/capnproto2/pogs"
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
)
|
||||
|
||||
const (
|
||||
rpcTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
var clientVersion = "sing-box " + C.Version
|
||||
|
||||
// RegistrationClient handles the Cap'n Proto RPC for tunnel registration.
|
||||
type RegistrationClient struct {
|
||||
client tunnelrpc.TunnelServer
|
||||
rpcConn *rpc.Conn
|
||||
transport rpc.Transport
|
||||
}
|
||||
|
||||
type registrationRPCClient interface {
|
||||
RegisterConnection(
|
||||
ctx context.Context,
|
||||
auth TunnelAuth,
|
||||
tunnelID uuid.UUID,
|
||||
connIndex uint8,
|
||||
options *RegistrationConnectionOptions,
|
||||
) (*RegistrationResult, error)
|
||||
Unregister(ctx context.Context) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
type permanentRegistrationError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *permanentRegistrationError) Error() string {
|
||||
if e == nil || e.Err == nil {
|
||||
return "permanent registration error"
|
||||
}
|
||||
return e.Err.Error()
|
||||
}
|
||||
|
||||
func (e *permanentRegistrationError) Unwrap() error {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
return e.Err
|
||||
}
|
||||
|
||||
func isPermanentRegistrationError(err error) bool {
|
||||
var permanentErr *permanentRegistrationError
|
||||
return errors.As(err, &permanentErr)
|
||||
}
|
||||
|
||||
// NewRegistrationClient creates a Cap'n Proto RPC client over the given stream.
|
||||
// The stream should be the first QUIC stream (control stream).
|
||||
func NewRegistrationClient(ctx context.Context, stream io.ReadWriteCloser) *RegistrationClient {
|
||||
transport := safeTransport(stream)
|
||||
conn := newRPCClientConn(transport, ctx)
|
||||
return &RegistrationClient{
|
||||
client: tunnelrpc.TunnelServer{Client: conn.Bootstrap(ctx)},
|
||||
rpcConn: conn,
|
||||
transport: transport,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterConnection registers this tunnel connection with the edge.
|
||||
func (c *RegistrationClient) RegisterConnection(
|
||||
ctx context.Context,
|
||||
auth TunnelAuth,
|
||||
tunnelID uuid.UUID,
|
||||
connIndex uint8,
|
||||
options *RegistrationConnectionOptions,
|
||||
) (*RegistrationResult, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, rpcTimeout)
|
||||
defer cancel()
|
||||
|
||||
promise := c.client.RegisterConnection(ctx, func(p tunnelrpc.RegistrationServer_registerConnection_Params) error {
|
||||
// Marshal TunnelAuth
|
||||
tunnelAuth, err := p.NewAuth()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
authPogs := &RegistrationTunnelAuth{
|
||||
AccountTag: auth.AccountTag,
|
||||
TunnelSecret: auth.TunnelSecret,
|
||||
}
|
||||
err = pogs.Insert(tunnelrpc.TunnelAuth_TypeID, tunnelAuth.Struct, authPogs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set tunnel ID
|
||||
err = p.SetTunnelId(tunnelID[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set connection index
|
||||
p.SetConnIndex(connIndex)
|
||||
|
||||
// Marshal ConnectionOptions
|
||||
connOptions, err := p.NewOptions()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return pogs.Insert(tunnelrpc.ConnectionOptions_TypeID, connOptions.Struct, options)
|
||||
})
|
||||
|
||||
response, err := promise.Result().Struct()
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "registration RPC")
|
||||
}
|
||||
|
||||
result := response.Result()
|
||||
switch result.Which() {
|
||||
case tunnelrpc.ConnectionResponse_result_Which_error:
|
||||
resultError, err := result.Error()
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "read registration error")
|
||||
}
|
||||
cause, _ := resultError.Cause()
|
||||
registrationError := E.New(cause)
|
||||
if resultError.ShouldRetry() {
|
||||
return nil, &RetryableError{
|
||||
Err: registrationError,
|
||||
Delay: time.Duration(resultError.RetryAfter()),
|
||||
}
|
||||
}
|
||||
return nil, &permanentRegistrationError{Err: registrationError}
|
||||
|
||||
case tunnelrpc.ConnectionResponse_result_Which_connectionDetails:
|
||||
connDetails, err := result.ConnectionDetails()
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "read connection details")
|
||||
}
|
||||
uuidBytes, err := connDetails.Uuid()
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "read connection UUID")
|
||||
}
|
||||
connectionID, err := uuid.FromBytes(uuidBytes)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "parse connection UUID")
|
||||
}
|
||||
location, _ := connDetails.LocationName()
|
||||
return &RegistrationResult{
|
||||
ConnectionID: connectionID,
|
||||
Location: location,
|
||||
TunnelIsRemotelyManaged: connDetails.TunnelIsRemotelyManaged(),
|
||||
}, nil
|
||||
|
||||
default:
|
||||
return nil, E.New("unexpected registration response type")
|
||||
}
|
||||
}
|
||||
|
||||
// Unregister sends the UnregisterConnection RPC.
|
||||
func (c *RegistrationClient) Unregister(ctx context.Context) error {
|
||||
promise := c.client.UnregisterConnection(ctx, nil)
|
||||
_, err := promise.Struct()
|
||||
return err
|
||||
}
|
||||
|
||||
// Close closes the RPC connection and transport.
|
||||
func (c *RegistrationClient) Close() error {
|
||||
return E.Errors(
|
||||
c.rpcConn.Close(),
|
||||
c.transport.Close(),
|
||||
)
|
||||
}
|
||||
|
||||
func validateRegistrationResult(result *RegistrationResult) error {
|
||||
if result == nil || result.TunnelIsRemotelyManaged {
|
||||
return nil
|
||||
}
|
||||
return ErrNonRemoteManagedTunnelUnsupported
|
||||
}
|
||||
|
||||
// BuildConnectionOptions creates the ConnectionOptions to send during registration.
|
||||
func BuildConnectionOptions(connectorID uuid.UUID, features []string, numPreviousAttempts uint8, originLocalIP net.IP) *RegistrationConnectionOptions {
|
||||
return &RegistrationConnectionOptions{
|
||||
Client: RegistrationClientInfo{
|
||||
ClientID: connectorID[:],
|
||||
Features: features,
|
||||
Version: clientVersion,
|
||||
Arch: runtime.GOOS + "_" + runtime.GOARCH,
|
||||
},
|
||||
ReplaceExisting: false,
|
||||
CompressionQuality: 0,
|
||||
OriginLocalIP: originLocalIP,
|
||||
NumPreviousAttempts: numPreviousAttempts,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultFeatures returns the feature strings to advertise.
|
||||
func DefaultFeatures(datagramVersion string) []string {
|
||||
features := []string{
|
||||
"serialized_headers",
|
||||
"support_datagram_v2",
|
||||
"support_quic_eof",
|
||||
"allow_remote_config",
|
||||
}
|
||||
if datagramVersion == "v3" {
|
||||
features = append(features, "support_datagram_v3_2")
|
||||
}
|
||||
return features
|
||||
}
|
||||
@@ -1,44 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import "github.com/google/uuid"
|
||||
|
||||
// Credentials contains all info needed to run a tunnel.
|
||||
type Credentials struct {
|
||||
AccountTag string `json:"AccountTag"`
|
||||
TunnelSecret []byte `json:"TunnelSecret"`
|
||||
TunnelID uuid.UUID `json:"TunnelID"`
|
||||
Endpoint string `json:"Endpoint,omitempty"`
|
||||
}
|
||||
|
||||
// TunnelToken is the compact token format used in the --token flag.
|
||||
// Field names match cloudflared's JSON encoding.
|
||||
type TunnelToken struct {
|
||||
AccountTag string `json:"a"`
|
||||
TunnelSecret []byte `json:"s"`
|
||||
TunnelID uuid.UUID `json:"t"`
|
||||
Endpoint string `json:"e,omitempty"`
|
||||
}
|
||||
|
||||
func (t TunnelToken) ToCredentials() Credentials {
|
||||
return Credentials{
|
||||
AccountTag: t.AccountTag,
|
||||
TunnelSecret: t.TunnelSecret,
|
||||
TunnelID: t.TunnelID,
|
||||
Endpoint: t.Endpoint,
|
||||
}
|
||||
}
|
||||
|
||||
// TunnelAuth is the authentication data sent during tunnel registration.
|
||||
type TunnelAuth struct {
|
||||
AccountTag string
|
||||
TunnelSecret []byte
|
||||
}
|
||||
|
||||
func (c *Credentials) Auth() TunnelAuth {
|
||||
return TunnelAuth{
|
||||
AccountTag: c.AccountTag,
|
||||
TunnelSecret: c.TunnelSecret,
|
||||
}
|
||||
}
|
||||
@@ -1,43 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func TestParseToken(t *testing.T) {
|
||||
tunnelID := uuid.New()
|
||||
secret := []byte("test-secret-32-bytes-long-xxxxx")
|
||||
tokenJSON := `{"a":"account123","t":"` + tunnelID.String() + `","s":"` + base64.StdEncoding.EncodeToString(secret) + `"}`
|
||||
token := base64.StdEncoding.EncodeToString([]byte(tokenJSON))
|
||||
|
||||
credentials, err := parseToken(token)
|
||||
if err != nil {
|
||||
t.Fatal("parseToken: ", err)
|
||||
}
|
||||
if credentials.AccountTag != "account123" {
|
||||
t.Error("expected AccountTag account123, got ", credentials.AccountTag)
|
||||
}
|
||||
if credentials.TunnelID != tunnelID {
|
||||
t.Error("expected TunnelID ", tunnelID, ", got ", credentials.TunnelID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTokenInvalidBase64(t *testing.T) {
|
||||
_, err := parseToken("not-valid-base64!!!")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid base64")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTokenInvalidJSON(t *testing.T) {
|
||||
token := base64.StdEncoding.EncodeToString([]byte("{bad json"))
|
||||
_, err := parseToken(token)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid JSON")
|
||||
}
|
||||
}
|
||||
@@ -1,204 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type v2UnregisterCall struct {
|
||||
sessionID uuid.UUID
|
||||
message string
|
||||
}
|
||||
|
||||
type captureRPCDatagramSender struct {
|
||||
captureDatagramSender
|
||||
}
|
||||
|
||||
type captureV2SessionRPCClient struct {
|
||||
unregisterCh chan v2UnregisterCall
|
||||
}
|
||||
|
||||
type blockingPacketConn struct {
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
func newBlockingPacketConn() *blockingPacketConn {
|
||||
return &blockingPacketConn{closed: make(chan struct{})}
|
||||
}
|
||||
|
||||
func (c *blockingPacketConn) ReadPacket(_ *buf.Buffer) (M.Socksaddr, error) {
|
||||
<-c.closed
|
||||
return M.Socksaddr{}, io.EOF
|
||||
}
|
||||
|
||||
func (c *blockingPacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error {
|
||||
buffer.Release()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *blockingPacketConn) Close() error {
|
||||
closeOnce(c.closed)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *blockingPacketConn) LocalAddr() net.Addr { return &net.UDPAddr{} }
|
||||
func (c *blockingPacketConn) SetDeadline(time.Time) error { return nil }
|
||||
func (c *blockingPacketConn) SetReadDeadline(time.Time) error { return nil }
|
||||
func (c *blockingPacketConn) SetWriteDeadline(time.Time) error { return nil }
|
||||
|
||||
type packetDialingRouter struct {
|
||||
testRouter
|
||||
packetConn N.PacketConn
|
||||
}
|
||||
|
||||
func (r *packetDialingRouter) DialRoutePacketConnection(ctx context.Context, metadata adapter.InboundContext) (N.PacketConn, error) {
|
||||
return r.packetConn, nil
|
||||
}
|
||||
|
||||
func (c *captureV2SessionRPCClient) UnregisterSession(ctx context.Context, sessionID uuid.UUID, message string) error {
|
||||
c.unregisterCh <- v2UnregisterCall{sessionID: sessionID, message: message}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *captureV2SessionRPCClient) Close() error { return nil }
|
||||
|
||||
func TestDatagramV2LocalCloseUnregistersRemote(t *testing.T) {
|
||||
inboundInstance := newLimitedInbound(t, 0)
|
||||
sender := &captureRPCDatagramSender{}
|
||||
muxer := NewDatagramV2Muxer(inboundInstance, sender, inboundInstance.logger)
|
||||
unregisterCh := make(chan v2UnregisterCall, 1)
|
||||
originalClientFactory := newV2SessionRPCClient
|
||||
newV2SessionRPCClient = func(ctx context.Context, sender DatagramSender) (v2SessionRPCClient, error) {
|
||||
return &captureV2SessionRPCClient{unregisterCh: unregisterCh}, nil
|
||||
}
|
||||
defer func() {
|
||||
newV2SessionRPCClient = originalClientFactory
|
||||
}()
|
||||
|
||||
sessionID := uuidTest(7)
|
||||
if err := muxer.RegisterSession(context.Background(), sessionID, net.IPv4(127, 0, 0, 1), 53, time.Second); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
muxer.sessionAccess.RLock()
|
||||
session := muxer.sessions[sessionID]
|
||||
muxer.sessionAccess.RUnlock()
|
||||
if session == nil {
|
||||
t.Fatal("expected registered session")
|
||||
}
|
||||
|
||||
session.closeWithReason("local close")
|
||||
|
||||
select {
|
||||
case call := <-unregisterCh:
|
||||
if call.sessionID != sessionID {
|
||||
t.Fatalf("unexpected session id: %s", call.sessionID)
|
||||
}
|
||||
if call.message != "local close" {
|
||||
t.Fatalf("unexpected message: %q", call.message)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("expected unregister rpc")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatagramV3RegistrationMigratesSender(t *testing.T) {
|
||||
inboundInstance := newLimitedInbound(t, 0)
|
||||
sender1 := &captureDatagramSender{}
|
||||
sender2 := &captureDatagramSender{}
|
||||
muxer1 := NewDatagramV3Muxer(inboundInstance, sender1, inboundInstance.logger)
|
||||
muxer2 := NewDatagramV3Muxer(inboundInstance, sender2, inboundInstance.logger)
|
||||
|
||||
requestID := RequestID{}
|
||||
requestID[15] = 9
|
||||
payload := make([]byte, 1+2+2+16+4)
|
||||
payload[0] = 0
|
||||
binary.BigEndian.PutUint16(payload[1:3], 53)
|
||||
binary.BigEndian.PutUint16(payload[3:5], 30)
|
||||
copy(payload[5:21], requestID[:])
|
||||
copy(payload[21:25], []byte{127, 0, 0, 1})
|
||||
|
||||
muxer1.handleRegistration(context.Background(), payload)
|
||||
session, exists := inboundInstance.datagramV3Manager.Get(requestID)
|
||||
if !exists {
|
||||
t.Fatal("expected v3 session after first registration")
|
||||
}
|
||||
|
||||
muxer2.handleRegistration(context.Background(), payload)
|
||||
|
||||
session.senderAccess.RLock()
|
||||
currentSender := session.sender
|
||||
session.senderAccess.RUnlock()
|
||||
if currentSender != sender2 {
|
||||
t.Fatal("expected v3 session sender migration to second sender")
|
||||
}
|
||||
|
||||
session.close()
|
||||
}
|
||||
|
||||
func TestDatagramV3MigrationUpdatesSessionContext(t *testing.T) {
|
||||
packetConn := newBlockingPacketConn()
|
||||
inboundInstance := newLimitedInbound(t, 0)
|
||||
inboundInstance.router = &packetDialingRouter{packetConn: packetConn}
|
||||
sender1 := &captureDatagramSender{}
|
||||
sender2 := &captureDatagramSender{}
|
||||
muxer1 := NewDatagramV3Muxer(inboundInstance, sender1, inboundInstance.logger)
|
||||
muxer2 := NewDatagramV3Muxer(inboundInstance, sender2, inboundInstance.logger)
|
||||
|
||||
requestID := RequestID{}
|
||||
requestID[15] = 10
|
||||
payload := make([]byte, 1+2+2+16+4)
|
||||
payload[0] = 0
|
||||
binary.BigEndian.PutUint16(payload[1:3], 53)
|
||||
binary.BigEndian.PutUint16(payload[3:5], 30)
|
||||
copy(payload[5:21], requestID[:])
|
||||
copy(payload[21:25], []byte{127, 0, 0, 1})
|
||||
|
||||
ctx1, cancel1 := context.WithCancel(context.Background())
|
||||
muxer1.handleRegistration(ctx1, payload)
|
||||
|
||||
ctx2, cancel2 := context.WithCancel(context.Background())
|
||||
muxer2.handleRegistration(ctx2, payload)
|
||||
|
||||
cancel1()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
session, exists := inboundInstance.datagramV3Manager.Get(requestID)
|
||||
if !exists {
|
||||
t.Fatal("expected session to survive old connection context cancellation")
|
||||
}
|
||||
|
||||
session.senderAccess.RLock()
|
||||
currentSender := session.sender
|
||||
session.senderAccess.RUnlock()
|
||||
if currentSender != sender2 {
|
||||
t.Fatal("expected migrated sender to stay active")
|
||||
}
|
||||
|
||||
cancel2()
|
||||
|
||||
deadline := time.After(time.Second)
|
||||
for {
|
||||
if _, exists := inboundInstance.datagramV3Manager.Get(requestID); !exists {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-deadline:
|
||||
t.Fatal("expected session to be removed after new context cancellation")
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,231 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter/inbound"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc"
|
||||
|
||||
"github.com/google/uuid"
|
||||
capnp "zombiezen.com/go/capnproto2"
|
||||
)
|
||||
|
||||
func newRegisterUDPSessionCall(t *testing.T, traceContext string) (tunnelrpc.SessionManager_registerUdpSession, func() (tunnelrpc.RegisterUdpSessionResponse, error)) {
|
||||
return newRegisterUDPSessionCallWithDstIP(t, []byte{127, 0, 0, 1}, traceContext)
|
||||
}
|
||||
|
||||
func newRegisterUDPSessionCallWithDstIP(t *testing.T, dstIP []byte, traceContext string) (tunnelrpc.SessionManager_registerUdpSession, func() (tunnelrpc.RegisterUdpSessionResponse, error)) {
|
||||
t.Helper()
|
||||
|
||||
_, paramsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
params, err := tunnelrpc.NewSessionManager_registerUdpSession_Params(paramsSeg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sessionID := uuid.New()
|
||||
if err := params.SetSessionId(sessionID[:]); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := params.SetDstIp(dstIP); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
params.SetDstPort(53)
|
||||
params.SetCloseAfterIdleHint(int64(30))
|
||||
if err := params.SetTraceContext(traceContext); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, resultsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
results, err := tunnelrpc.NewSessionManager_registerUdpSession_Results(resultsSeg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
call := tunnelrpc.SessionManager_registerUdpSession{
|
||||
Ctx: context.Background(),
|
||||
Params: params,
|
||||
Results: results,
|
||||
}
|
||||
return call, results.Result
|
||||
}
|
||||
|
||||
func newUnregisterUDPSessionCall(t *testing.T) tunnelrpc.SessionManager_unregisterUdpSession {
|
||||
t.Helper()
|
||||
|
||||
_, paramsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
params, err := tunnelrpc.NewSessionManager_unregisterUdpSession_Params(paramsSeg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sessionID := uuid.New()
|
||||
if err := params.SetSessionId(sessionID[:]); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := params.SetMessage("close"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, resultsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
results, err := tunnelrpc.NewSessionManager_unregisterUdpSession_Results(resultsSeg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return tunnelrpc.SessionManager_unregisterUdpSession{
|
||||
Ctx: context.Background(),
|
||||
Params: params,
|
||||
Results: results,
|
||||
}
|
||||
}
|
||||
|
||||
func newUnregisterUDPSessionCallForSession(t *testing.T, sessionID uuid.UUID, message string) tunnelrpc.SessionManager_unregisterUdpSession {
|
||||
t.Helper()
|
||||
|
||||
_, paramsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
params, err := tunnelrpc.NewSessionManager_unregisterUdpSession_Params(paramsSeg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := params.SetSessionId(sessionID[:]); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := params.SetMessage(message); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, resultsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
results, err := tunnelrpc.NewSessionManager_unregisterUdpSession_Results(resultsSeg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return tunnelrpc.SessionManager_unregisterUdpSession{
|
||||
Ctx: context.Background(),
|
||||
Params: params,
|
||||
Results: results,
|
||||
}
|
||||
}
|
||||
|
||||
func TestV3RPCRegisterUDPSessionReturnsUnsupportedResult(t *testing.T) {
|
||||
server := &cloudflaredV3Server{
|
||||
inbound: &Inbound{Adapter: inbound.NewAdapter(C.TypeCloudflared, "test")},
|
||||
}
|
||||
call, readResult := newRegisterUDPSessionCall(t, "trace-context")
|
||||
if err := server.RegisterUdpSession(call); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
result, err := readResult()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resultErr, err := result.Err()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resultErr != errUnsupportedDatagramV3UDPRegistration.Error() {
|
||||
t.Fatalf("unexpected registration error %q", resultErr)
|
||||
}
|
||||
spans, err := result.Spans()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(spans) != 0 {
|
||||
t.Fatalf("expected empty spans, got %x", spans)
|
||||
}
|
||||
}
|
||||
|
||||
func TestV3RPCUnregisterUDPSessionReturnsUnsupportedError(t *testing.T) {
|
||||
server := &cloudflaredV3Server{
|
||||
inbound: &Inbound{Adapter: inbound.NewAdapter(C.TypeCloudflared, "test")},
|
||||
}
|
||||
err := server.UnregisterUdpSession(newUnregisterUDPSessionCall(t))
|
||||
if err == nil {
|
||||
t.Fatal("expected unsupported unregister error")
|
||||
}
|
||||
if err.Error() != errUnsupportedDatagramV3UDPUnregistration.Error() {
|
||||
t.Fatalf("unexpected unregister error %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestV2RPCUnregisterUDPSessionPropagatesMessage(t *testing.T) {
|
||||
inboundInstance := newLimitedInbound(t, 0)
|
||||
inboundInstance.router = &packetDialingRouter{packetConn: newBlockingPacketConn()}
|
||||
muxer := NewDatagramV2Muxer(inboundInstance, &captureDatagramSender{}, inboundInstance.logger)
|
||||
|
||||
sessionID := uuid.New()
|
||||
if err := muxer.RegisterSession(context.Background(), sessionID, net.IPv4(127, 0, 0, 1), 53, time.Second); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
muxer.sessionAccess.RLock()
|
||||
session := muxer.sessions[sessionID]
|
||||
muxer.sessionAccess.RUnlock()
|
||||
if session == nil {
|
||||
t.Fatal("expected registered session")
|
||||
}
|
||||
|
||||
server := &cloudflaredServer{
|
||||
inbound: inboundInstance,
|
||||
muxer: muxer,
|
||||
ctx: context.Background(),
|
||||
logger: inboundInstance.logger,
|
||||
}
|
||||
if err := server.UnregisterUdpSession(newUnregisterUDPSessionCallForSession(t, sessionID, "edge close")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if reason := session.closeReason(); reason != "edge close" {
|
||||
t.Fatalf("expected close reason propagated from edge, got %q", reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestV2RPCRegisterUDPSessionRejectsMissingDestinationIP(t *testing.T) {
|
||||
inboundInstance := newLimitedInbound(t, 0)
|
||||
inboundInstance.router = &packetDialingRouter{packetConn: newBlockingPacketConn()}
|
||||
server := &cloudflaredServer{
|
||||
inbound: inboundInstance,
|
||||
muxer: NewDatagramV2Muxer(inboundInstance, &captureDatagramSender{}, inboundInstance.logger),
|
||||
ctx: context.Background(),
|
||||
logger: inboundInstance.logger,
|
||||
}
|
||||
call, readResult := newRegisterUDPSessionCallWithDstIP(t, nil, "")
|
||||
|
||||
if err := server.RegisterUdpSession(call); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
result, err := readResult()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resultErr, err := result.Err()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resultErr != "missing destination IP" {
|
||||
t.Fatalf("unexpected result error %q", resultErr)
|
||||
}
|
||||
}
|
||||
@@ -1,78 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"zombiezen.com/go/capnproto2/server"
|
||||
)
|
||||
|
||||
var (
|
||||
errUnsupportedDatagramV3UDPRegistration = errors.New("datagram v3 does not support RegisterUdpSession RPC")
|
||||
errUnsupportedDatagramV3UDPUnregistration = errors.New("datagram v3 does not support UnregisterUdpSession RPC")
|
||||
)
|
||||
|
||||
type cloudflaredV3Server struct {
|
||||
inbound *Inbound
|
||||
logger log.ContextLogger
|
||||
}
|
||||
|
||||
func (s *cloudflaredV3Server) RegisterUdpSession(call tunnelrpc.SessionManager_registerUdpSession) error {
|
||||
result, err := call.Results.NewResult()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := result.SetErr(errUnsupportedDatagramV3UDPRegistration.Error()); err != nil {
|
||||
return err
|
||||
}
|
||||
return result.SetSpans([]byte{})
|
||||
}
|
||||
|
||||
func (s *cloudflaredV3Server) UnregisterUdpSession(call tunnelrpc.SessionManager_unregisterUdpSession) error {
|
||||
return errUnsupportedDatagramV3UDPUnregistration
|
||||
}
|
||||
|
||||
func (s *cloudflaredV3Server) UpdateConfiguration(call tunnelrpc.ConfigurationManager_updateConfiguration) error {
|
||||
server.Ack(call.Options)
|
||||
version := call.Params.Version()
|
||||
configData, _ := call.Params.Config()
|
||||
updateResult := s.inbound.ApplyConfig(version, configData)
|
||||
result, err := call.Results.NewResult()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result.SetLatestAppliedVersion(updateResult.LastAppliedVersion)
|
||||
if updateResult.Err != nil {
|
||||
result.SetErr(updateResult.Err.Error())
|
||||
} else {
|
||||
result.SetErr("")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ServeV3RPCStream serves configuration updates on v3 and rejects legacy UDP RPCs.
|
||||
func ServeV3RPCStream(ctx context.Context, stream io.ReadWriteCloser, inbound *Inbound, logger log.ContextLogger) {
|
||||
srv := &cloudflaredV3Server{
|
||||
inbound: inbound,
|
||||
logger: logger,
|
||||
}
|
||||
client := tunnelrpc.CloudflaredServer_ServerToClient(srv)
|
||||
transport := safeTransport(stream)
|
||||
rpcConn := newRPCServerConn(transport, client.Client)
|
||||
rpcCtx, cancel := context.WithTimeout(ctx, rpcTimeout)
|
||||
defer cancel()
|
||||
select {
|
||||
case <-rpcConn.Done():
|
||||
case <-rpcCtx.Done():
|
||||
}
|
||||
E.Errors(
|
||||
rpcConn.Close(),
|
||||
transport.Close(),
|
||||
)
|
||||
}
|
||||
@@ -1,568 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
"zombiezen.com/go/capnproto2/server"
|
||||
)
|
||||
|
||||
// V2 wire format: [payload | 16B sessionID | 1B type] (suffix-based)
|
||||
|
||||
// DatagramV2Type identifies the type of a V2 datagram.
|
||||
type DatagramV2Type byte
|
||||
|
||||
const (
|
||||
DatagramV2TypeUDP DatagramV2Type = 0
|
||||
DatagramV2TypeIP DatagramV2Type = 1
|
||||
DatagramV2TypeIPWithTrace DatagramV2Type = 2
|
||||
DatagramV2TypeTracingSpan DatagramV2Type = 3
|
||||
|
||||
sessionIDLength = 16
|
||||
typeIDLength = 1
|
||||
)
|
||||
|
||||
// DatagramV2Muxer handles V2 datagram demuxing and session management.
|
||||
type DatagramV2Muxer struct {
|
||||
inbound *Inbound
|
||||
logger log.ContextLogger
|
||||
sender DatagramSender
|
||||
icmp *ICMPBridge
|
||||
|
||||
sessionAccess sync.RWMutex
|
||||
sessions map[uuid.UUID]*udpSession
|
||||
}
|
||||
|
||||
// NewDatagramV2Muxer creates a new V2 datagram muxer.
|
||||
func NewDatagramV2Muxer(inbound *Inbound, sender DatagramSender, logger log.ContextLogger) *DatagramV2Muxer {
|
||||
return &DatagramV2Muxer{
|
||||
inbound: inbound,
|
||||
logger: logger,
|
||||
sender: sender,
|
||||
icmp: NewICMPBridge(inbound, sender, icmpWireV2),
|
||||
sessions: make(map[uuid.UUID]*udpSession),
|
||||
}
|
||||
}
|
||||
|
||||
type rpcStreamOpener interface {
|
||||
OpenRPCStream(ctx context.Context) (io.ReadWriteCloser, error)
|
||||
}
|
||||
|
||||
type v2SessionRPCClient interface {
|
||||
UnregisterSession(ctx context.Context, sessionID uuid.UUID, message string) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
var newV2SessionRPCClient = func(ctx context.Context, sender DatagramSender) (v2SessionRPCClient, error) {
|
||||
opener, ok := sender.(rpcStreamOpener)
|
||||
if !ok {
|
||||
return nil, E.New("sender does not support rpc streams")
|
||||
}
|
||||
stream, err := opener.OpenRPCStream(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
transport := safeTransport(stream)
|
||||
conn := newRPCClientConn(transport, ctx)
|
||||
return &capnpV2SessionRPCClient{
|
||||
client: tunnelrpc.SessionManager{Client: conn.Bootstrap(ctx)},
|
||||
rpcConn: conn,
|
||||
transport: transport,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type capnpV2SessionRPCClient struct {
|
||||
client tunnelrpc.SessionManager
|
||||
rpcConn *rpc.Conn
|
||||
transport rpc.Transport
|
||||
}
|
||||
|
||||
func (c *capnpV2SessionRPCClient) UnregisterSession(ctx context.Context, sessionID uuid.UUID, message string) error {
|
||||
promise := c.client.UnregisterUdpSession(ctx, func(p tunnelrpc.SessionManager_unregisterUdpSession_Params) error {
|
||||
if err := p.SetSessionId(sessionID[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
return p.SetMessage(message)
|
||||
})
|
||||
_, err := promise.Struct()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *capnpV2SessionRPCClient) Close() error {
|
||||
return E.Errors(c.rpcConn.Close(), c.transport.Close())
|
||||
}
|
||||
|
||||
// HandleDatagram demuxes an incoming V2 datagram.
|
||||
func (m *DatagramV2Muxer) HandleDatagram(ctx context.Context, data []byte) {
|
||||
if len(data) < typeIDLength {
|
||||
return
|
||||
}
|
||||
|
||||
datagramType := DatagramV2Type(data[len(data)-typeIDLength])
|
||||
payload := data[:len(data)-typeIDLength]
|
||||
|
||||
switch datagramType {
|
||||
case DatagramV2TypeUDP:
|
||||
m.handleUDPDatagram(ctx, payload)
|
||||
case DatagramV2TypeIP:
|
||||
if err := m.icmp.HandleV2(ctx, datagramType, payload); err != nil {
|
||||
m.logger.Debug("drop V2 ICMP datagram: ", err)
|
||||
}
|
||||
case DatagramV2TypeIPWithTrace:
|
||||
if err := m.icmp.HandleV2(ctx, datagramType, payload); err != nil {
|
||||
m.logger.Debug("drop V2 traced ICMP datagram: ", err)
|
||||
}
|
||||
case DatagramV2TypeTracingSpan:
|
||||
// Tracing spans, ignore
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DatagramV2Muxer) handleUDPDatagram(ctx context.Context, data []byte) {
|
||||
if len(data) < sessionIDLength {
|
||||
return
|
||||
}
|
||||
|
||||
payload := data[:len(data)-sessionIDLength]
|
||||
sessionID, err := uuid.FromBytes(data[len(data)-sessionIDLength:])
|
||||
if err != nil {
|
||||
m.logger.Debug("invalid session ID in V2 datagram: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
m.sessionAccess.RLock()
|
||||
session, exists := m.sessions[sessionID]
|
||||
m.sessionAccess.RUnlock()
|
||||
|
||||
if !exists {
|
||||
m.logger.Debug("unknown V2 UDP session: ", sessionID)
|
||||
return
|
||||
}
|
||||
|
||||
session.writeToOrigin(payload)
|
||||
}
|
||||
|
||||
// RegisterSession registers a new UDP session from an RPC call.
|
||||
func (m *DatagramV2Muxer) RegisterSession(
|
||||
ctx context.Context,
|
||||
sessionID uuid.UUID,
|
||||
destinationIP net.IP,
|
||||
destinationPort uint16,
|
||||
closeAfterIdle time.Duration,
|
||||
) error {
|
||||
if destinationIP == nil {
|
||||
return E.New("missing destination IP")
|
||||
}
|
||||
var destinationAddr netip.Addr
|
||||
if ip4 := destinationIP.To4(); ip4 != nil {
|
||||
destinationAddr = netip.AddrFrom4([4]byte(ip4))
|
||||
} else if ip16 := destinationIP.To16(); ip16 != nil {
|
||||
destinationAddr = netip.AddrFrom16([16]byte(ip16))
|
||||
} else {
|
||||
return E.New("invalid destination IP")
|
||||
}
|
||||
destination := netip.AddrPortFrom(destinationAddr, destinationPort)
|
||||
|
||||
if closeAfterIdle == 0 {
|
||||
closeAfterIdle = 210 * time.Second
|
||||
}
|
||||
|
||||
m.sessionAccess.Lock()
|
||||
if _, exists := m.sessions[sessionID]; exists {
|
||||
m.sessionAccess.Unlock()
|
||||
return nil
|
||||
}
|
||||
limit := m.inbound.maxActiveFlows()
|
||||
if !m.inbound.flowLimiter.Acquire(limit) {
|
||||
m.sessionAccess.Unlock()
|
||||
return E.New("too many active flows")
|
||||
}
|
||||
|
||||
origin, err := m.inbound.dialWarpPacketConnection(ctx, destination)
|
||||
if err != nil {
|
||||
m.inbound.flowLimiter.Release(limit)
|
||||
m.sessionAccess.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
session := newUDPSession(sessionID, destination, closeAfterIdle, origin, m)
|
||||
m.sessions[sessionID] = session
|
||||
m.sessionAccess.Unlock()
|
||||
|
||||
m.logger.Info("registered V2 UDP session ", sessionID, " to ", destination)
|
||||
|
||||
go m.serveSession(ctx, session, limit)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnregisterSession removes a UDP session.
|
||||
func (m *DatagramV2Muxer) UnregisterSession(sessionID uuid.UUID, message string) {
|
||||
m.sessionAccess.Lock()
|
||||
session, exists := m.sessions[sessionID]
|
||||
if exists {
|
||||
delete(m.sessions, sessionID)
|
||||
}
|
||||
m.sessionAccess.Unlock()
|
||||
|
||||
if exists {
|
||||
session.markRemoteClosed(message)
|
||||
session.close()
|
||||
m.logger.Info("unregistered V2 UDP session ", sessionID)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DatagramV2Muxer) serveSession(ctx context.Context, session *udpSession, limit uint64) {
|
||||
defer m.inbound.flowLimiter.Release(limit)
|
||||
|
||||
session.serve(ctx)
|
||||
|
||||
m.sessionAccess.Lock()
|
||||
if current, exists := m.sessions[session.id]; exists && current == session {
|
||||
delete(m.sessions, session.id)
|
||||
}
|
||||
m.sessionAccess.Unlock()
|
||||
|
||||
if !session.remoteClosed() {
|
||||
unregisterCtx, cancel := context.WithTimeout(context.Background(), rpcTimeout)
|
||||
defer cancel()
|
||||
if err := m.unregisterRemoteSession(unregisterCtx, session.id, session.closeReason()); err != nil {
|
||||
m.logger.Debug("failed to unregister V2 UDP session ", session.id, ": ", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendToEdge sends a V2 UDP datagram back to the edge.
|
||||
func (m *DatagramV2Muxer) sendToEdge(sessionID uuid.UUID, payload []byte) {
|
||||
data := make([]byte, len(payload)+sessionIDLength+typeIDLength)
|
||||
copy(data, payload)
|
||||
copy(data[len(payload):], sessionID[:])
|
||||
data[len(data)-1] = byte(DatagramV2TypeUDP)
|
||||
m.sender.SendDatagram(data)
|
||||
}
|
||||
|
||||
// Close closes all sessions.
|
||||
func (m *DatagramV2Muxer) Close() {
|
||||
m.sessionAccess.Lock()
|
||||
sessions := m.sessions
|
||||
m.sessions = make(map[uuid.UUID]*udpSession)
|
||||
m.sessionAccess.Unlock()
|
||||
|
||||
for _, session := range sessions {
|
||||
session.close()
|
||||
}
|
||||
}
|
||||
|
||||
// udpSession represents a V2 UDP session.
|
||||
type udpSession struct {
|
||||
id uuid.UUID
|
||||
destination netip.AddrPort
|
||||
closeAfterIdle time.Duration
|
||||
origin N.PacketConn
|
||||
muxer *DatagramV2Muxer
|
||||
|
||||
writeChan chan []byte
|
||||
closeOnce sync.Once
|
||||
closeChan chan struct{}
|
||||
|
||||
activeAccess sync.RWMutex
|
||||
activeAt time.Time
|
||||
|
||||
stateAccess sync.RWMutex
|
||||
closedByRemote bool
|
||||
closeReasonString string
|
||||
}
|
||||
|
||||
func newUDPSession(id uuid.UUID, destination netip.AddrPort, closeAfterIdle time.Duration, origin N.PacketConn, muxer *DatagramV2Muxer) *udpSession {
|
||||
return &udpSession{
|
||||
id: id,
|
||||
destination: destination,
|
||||
closeAfterIdle: closeAfterIdle,
|
||||
origin: origin,
|
||||
muxer: muxer,
|
||||
writeChan: make(chan []byte, 256),
|
||||
closeChan: make(chan struct{}),
|
||||
activeAt: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *udpSession) writeToOrigin(payload []byte) {
|
||||
data := make([]byte, len(payload))
|
||||
copy(data, payload)
|
||||
select {
|
||||
case s.writeChan <- data:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (s *udpSession) close() {
|
||||
s.closeOnce.Do(func() {
|
||||
if s.origin != nil {
|
||||
_ = s.origin.Close()
|
||||
}
|
||||
close(s.closeChan)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *udpSession) serve(ctx context.Context) {
|
||||
go s.readLoop()
|
||||
go s.writeLoop()
|
||||
|
||||
tickInterval := s.closeAfterIdle / 2
|
||||
if tickInterval <= 0 || tickInterval > 10*time.Second {
|
||||
tickInterval = time.Second
|
||||
}
|
||||
ticker := time.NewTicker(tickInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
s.closeWithReason("connection closed")
|
||||
case <-ticker.C:
|
||||
if time.Since(s.lastActive()) >= s.closeAfterIdle {
|
||||
s.closeWithReason("idle timeout")
|
||||
}
|
||||
case <-s.closeChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *udpSession) readLoop() {
|
||||
for {
|
||||
buffer := buf.NewPacket()
|
||||
_, err := s.origin.ReadPacket(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
s.closeWithReason(err.Error())
|
||||
return
|
||||
}
|
||||
s.markActive()
|
||||
s.muxer.sendToEdge(s.id, append([]byte(nil), buffer.Bytes()...))
|
||||
buffer.Release()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *udpSession) writeLoop() {
|
||||
for {
|
||||
select {
|
||||
case payload := <-s.writeChan:
|
||||
err := s.origin.WritePacket(buf.As(payload), M.SocksaddrFromNetIP(s.destination))
|
||||
if err != nil {
|
||||
s.closeWithReason(err.Error())
|
||||
return
|
||||
}
|
||||
s.markActive()
|
||||
case <-s.closeChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *udpSession) markActive() {
|
||||
s.activeAccess.Lock()
|
||||
s.activeAt = time.Now()
|
||||
s.activeAccess.Unlock()
|
||||
}
|
||||
|
||||
func (s *udpSession) lastActive() time.Time {
|
||||
s.activeAccess.RLock()
|
||||
defer s.activeAccess.RUnlock()
|
||||
return s.activeAt
|
||||
}
|
||||
|
||||
func (s *udpSession) closeWithReason(reason string) {
|
||||
s.stateAccess.Lock()
|
||||
if s.closeReasonString == "" {
|
||||
s.closeReasonString = reason
|
||||
}
|
||||
s.stateAccess.Unlock()
|
||||
s.close()
|
||||
}
|
||||
|
||||
func (s *udpSession) markRemoteClosed(message string) {
|
||||
s.stateAccess.Lock()
|
||||
s.closedByRemote = true
|
||||
if message != "" {
|
||||
s.closeReasonString = message
|
||||
} else if s.closeReasonString == "" {
|
||||
s.closeReasonString = "unregistered by edge"
|
||||
}
|
||||
s.stateAccess.Unlock()
|
||||
}
|
||||
|
||||
func (s *udpSession) remoteClosed() bool {
|
||||
s.stateAccess.RLock()
|
||||
defer s.stateAccess.RUnlock()
|
||||
return s.closedByRemote
|
||||
}
|
||||
|
||||
func (s *udpSession) closeReason() string {
|
||||
s.stateAccess.RLock()
|
||||
defer s.stateAccess.RUnlock()
|
||||
if s.closeReasonString == "" {
|
||||
return "session closed"
|
||||
}
|
||||
return s.closeReasonString
|
||||
}
|
||||
|
||||
// ReadPacket implements N.PacketConn - reads packets from the edge to forward to origin.
|
||||
func (s *udpSession) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
|
||||
select {
|
||||
case data := <-s.writeChan:
|
||||
_, err := buffer.Write(data)
|
||||
return M.SocksaddrFromNetIP(s.destination), err
|
||||
case <-s.closeChan:
|
||||
return M.Socksaddr{}, io.EOF
|
||||
}
|
||||
}
|
||||
|
||||
// WritePacket implements N.PacketConn - receives packets from origin to forward to edge.
|
||||
func (s *udpSession) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
s.muxer.sendToEdge(s.id, buffer.Bytes())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *udpSession) Close() error {
|
||||
s.close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *udpSession) LocalAddr() net.Addr { return nil }
|
||||
func (s *udpSession) SetDeadline(_ time.Time) error { return nil }
|
||||
func (s *udpSession) SetReadDeadline(_ time.Time) error { return nil }
|
||||
func (s *udpSession) SetWriteDeadline(_ time.Time) error { return nil }
|
||||
|
||||
func (m *DatagramV2Muxer) unregisterRemoteSession(ctx context.Context, sessionID uuid.UUID, message string) error {
|
||||
client, err := newV2SessionRPCClient(ctx, m.sender)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer client.Close()
|
||||
return client.UnregisterSession(ctx, sessionID, message)
|
||||
}
|
||||
|
||||
// V2 RPC server implementation for HandleRPCStream.
|
||||
|
||||
type cloudflaredServer struct {
|
||||
inbound *Inbound
|
||||
muxer *DatagramV2Muxer
|
||||
ctx context.Context
|
||||
logger log.ContextLogger
|
||||
}
|
||||
|
||||
func (s *cloudflaredServer) RegisterUdpSession(call tunnelrpc.SessionManager_registerUdpSession) error {
|
||||
server.Ack(call.Options)
|
||||
sessionIDBytes, err := call.Params.SessionId()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sessionID, err := uuid.FromBytes(sessionIDBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
destinationIP, err := call.Params.DstIp()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
destinationPort := call.Params.DstPort()
|
||||
closeAfterIdle := time.Duration(call.Params.CloseAfterIdleHint())
|
||||
if _, traceErr := call.Params.TraceContext(); traceErr != nil {
|
||||
return traceErr
|
||||
}
|
||||
|
||||
if len(destinationIP) == 0 {
|
||||
err = E.New("missing destination IP")
|
||||
} else {
|
||||
err = s.muxer.RegisterSession(s.ctx, sessionID, net.IP(destinationIP), destinationPort, closeAfterIdle)
|
||||
}
|
||||
|
||||
result, allocErr := call.Results.NewResult()
|
||||
if allocErr != nil {
|
||||
return allocErr
|
||||
}
|
||||
if spansErr := result.SetSpans([]byte{}); spansErr != nil {
|
||||
return spansErr
|
||||
}
|
||||
if err != nil {
|
||||
result.SetErr(err.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *cloudflaredServer) UnregisterUdpSession(call tunnelrpc.SessionManager_unregisterUdpSession) error {
|
||||
server.Ack(call.Options)
|
||||
sessionIDBytes, err := call.Params.SessionId()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sessionID, err := uuid.FromBytes(sessionIDBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
message, err := call.Params.Message()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.muxer.UnregisterSession(sessionID, message)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *cloudflaredServer) UpdateConfiguration(call tunnelrpc.ConfigurationManager_updateConfiguration) error {
|
||||
server.Ack(call.Options)
|
||||
version := call.Params.Version()
|
||||
configData, _ := call.Params.Config()
|
||||
updateResult := s.inbound.ApplyConfig(version, configData)
|
||||
result, err := call.Results.NewResult()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result.SetLatestAppliedVersion(updateResult.LastAppliedVersion)
|
||||
if updateResult.Err != nil {
|
||||
result.SetErr(updateResult.Err.Error())
|
||||
} else {
|
||||
result.SetErr("")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ServeRPCStream handles an incoming V2 RPC stream (session management + configuration).
|
||||
func ServeRPCStream(ctx context.Context, stream io.ReadWriteCloser, inbound *Inbound, muxer *DatagramV2Muxer, logger log.ContextLogger) {
|
||||
srv := &cloudflaredServer{
|
||||
inbound: inbound,
|
||||
muxer: muxer,
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
}
|
||||
client := tunnelrpc.CloudflaredServer_ServerToClient(srv)
|
||||
transport := safeTransport(stream)
|
||||
rpcConn := newRPCServerConn(transport, client.Client)
|
||||
rpcCtx, cancel := context.WithTimeout(ctx, rpcTimeout)
|
||||
defer cancel()
|
||||
select {
|
||||
case <-rpcConn.Done():
|
||||
case <-rpcCtx.Done():
|
||||
}
|
||||
E.Errors(
|
||||
rpcConn.Close(),
|
||||
transport.Close(),
|
||||
)
|
||||
}
|
||||
@@ -1,483 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
// V3 wire format: [1B type | payload] (prefix-based)
|
||||
|
||||
// DatagramV3Type identifies the type of a V3 datagram.
|
||||
type DatagramV3Type byte
|
||||
|
||||
const (
|
||||
DatagramV3TypeRegistration DatagramV3Type = 0
|
||||
DatagramV3TypePayload DatagramV3Type = 1
|
||||
DatagramV3TypeICMP DatagramV3Type = 2
|
||||
DatagramV3TypeRegistrationResponse DatagramV3Type = 3
|
||||
|
||||
// V3 registration header sizes
|
||||
v3RegistrationFlagLen = 1
|
||||
v3RegistrationPortLen = 2
|
||||
v3RegistrationIdleLen = 2
|
||||
v3RequestIDLength = 16
|
||||
v3IPv4AddrLen = 4
|
||||
v3IPv6AddrLen = 16
|
||||
v3RegistrationBaseLen = 1 + v3RegistrationFlagLen + v3RegistrationPortLen + v3RegistrationIdleLen + v3RequestIDLength // 22
|
||||
v3PayloadHeaderLen = 1 + v3RequestIDLength // 17
|
||||
v3RegistrationRespLen = 1 + 1 + v3RequestIDLength + 2 // 20
|
||||
maxV3UDPPayloadLen = 1280
|
||||
|
||||
// V3 registration flags
|
||||
v3FlagIPv6 byte = 0x01
|
||||
v3FlagTraced byte = 0x02
|
||||
v3FlagBundle byte = 0x04
|
||||
|
||||
// V3 registration response types
|
||||
v3ResponseOK byte = 0x00
|
||||
v3ResponseDestinationUnreachable byte = 0x01
|
||||
v3ResponseUnableToBindSocket byte = 0x02
|
||||
v3ResponseTooManyActiveFlows byte = 0x03
|
||||
v3ResponseErrorWithMsg byte = 0xFF
|
||||
)
|
||||
|
||||
// RequestID is a 128-bit session identifier for V3.
|
||||
type RequestID [v3RequestIDLength]byte
|
||||
|
||||
type v3RegistrationState uint8
|
||||
|
||||
const (
|
||||
v3RegistrationNew v3RegistrationState = iota
|
||||
v3RegistrationExisting
|
||||
v3RegistrationMigrated
|
||||
)
|
||||
|
||||
type DatagramV3SessionManager struct {
|
||||
sessionAccess sync.RWMutex
|
||||
sessions map[RequestID]*v3Session
|
||||
}
|
||||
|
||||
func NewDatagramV3SessionManager() *DatagramV3SessionManager {
|
||||
return &DatagramV3SessionManager{
|
||||
sessions: make(map[RequestID]*v3Session),
|
||||
}
|
||||
}
|
||||
|
||||
// DatagramV3Muxer handles V3 datagram demuxing and session management.
|
||||
type DatagramV3Muxer struct {
|
||||
inbound *Inbound
|
||||
logger log.ContextLogger
|
||||
sender DatagramSender
|
||||
icmp *ICMPBridge
|
||||
}
|
||||
|
||||
// NewDatagramV3Muxer creates a new V3 datagram muxer.
|
||||
func NewDatagramV3Muxer(inbound *Inbound, sender DatagramSender, logger log.ContextLogger) *DatagramV3Muxer {
|
||||
return &DatagramV3Muxer{
|
||||
inbound: inbound,
|
||||
logger: logger,
|
||||
sender: sender,
|
||||
icmp: NewICMPBridge(inbound, sender, icmpWireV3),
|
||||
}
|
||||
}
|
||||
|
||||
// HandleDatagram demuxes an incoming V3 datagram.
|
||||
func (m *DatagramV3Muxer) HandleDatagram(ctx context.Context, data []byte) {
|
||||
if len(data) < 1 {
|
||||
return
|
||||
}
|
||||
|
||||
datagramType := DatagramV3Type(data[0])
|
||||
payload := data[1:]
|
||||
|
||||
switch datagramType {
|
||||
case DatagramV3TypeRegistration:
|
||||
m.handleRegistration(ctx, payload)
|
||||
case DatagramV3TypePayload:
|
||||
m.handlePayload(payload)
|
||||
case DatagramV3TypeICMP:
|
||||
if err := m.icmp.HandleV3(ctx, payload); err != nil {
|
||||
m.logger.Debug("drop V3 ICMP datagram: ", err)
|
||||
}
|
||||
case DatagramV3TypeRegistrationResponse:
|
||||
// Unexpected - we never send registrations
|
||||
m.logger.Debug("received unexpected V3 registration response")
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DatagramV3Muxer) handleRegistration(ctx context.Context, data []byte) {
|
||||
if len(data) < v3RegistrationFlagLen+v3RegistrationPortLen+v3RegistrationIdleLen+v3RequestIDLength {
|
||||
m.logger.Debug("V3 registration too short")
|
||||
return
|
||||
}
|
||||
|
||||
flags := data[0]
|
||||
destinationPort := binary.BigEndian.Uint16(data[1:3])
|
||||
idleDurationSeconds := binary.BigEndian.Uint16(data[3:5])
|
||||
|
||||
var requestID RequestID
|
||||
copy(requestID[:], data[5:5+v3RequestIDLength])
|
||||
|
||||
offset := 5 + v3RequestIDLength
|
||||
var destination netip.AddrPort
|
||||
|
||||
if flags&v3FlagIPv6 != 0 {
|
||||
if len(data) < offset+v3IPv6AddrLen {
|
||||
m.sendRegistrationResponse(requestID, v3ResponseErrorWithMsg, "registration too short for IPv6")
|
||||
return
|
||||
}
|
||||
var addr [16]byte
|
||||
copy(addr[:], data[offset:offset+v3IPv6AddrLen])
|
||||
destination = netip.AddrPortFrom(netip.AddrFrom16(addr), destinationPort)
|
||||
offset += v3IPv6AddrLen
|
||||
} else {
|
||||
if len(data) < offset+v3IPv4AddrLen {
|
||||
m.sendRegistrationResponse(requestID, v3ResponseErrorWithMsg, "registration too short for IPv4")
|
||||
return
|
||||
}
|
||||
var addr [4]byte
|
||||
copy(addr[:], data[offset:offset+v3IPv4AddrLen])
|
||||
destination = netip.AddrPortFrom(netip.AddrFrom4(addr), destinationPort)
|
||||
offset += v3IPv4AddrLen
|
||||
}
|
||||
|
||||
closeAfterIdle := time.Duration(idleDurationSeconds) * time.Second
|
||||
if closeAfterIdle == 0 {
|
||||
closeAfterIdle = 210 * time.Second
|
||||
}
|
||||
if !destination.Addr().IsValid() || destination.Addr().IsUnspecified() || destination.Port() == 0 {
|
||||
m.sendRegistrationResponse(requestID, v3ResponseDestinationUnreachable, "")
|
||||
return
|
||||
}
|
||||
|
||||
session, state, err := m.inbound.datagramV3Manager.Register(m.inbound, ctx, requestID, destination, closeAfterIdle, m.sender)
|
||||
if err == errTooManyActiveFlows {
|
||||
m.sendRegistrationResponse(requestID, v3ResponseTooManyActiveFlows, "")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
m.sendRegistrationResponse(requestID, v3ResponseUnableToBindSocket, "")
|
||||
return
|
||||
}
|
||||
|
||||
if state == v3RegistrationNew {
|
||||
m.logger.Info("registered V3 UDP session to ", destination)
|
||||
}
|
||||
m.sendRegistrationResponse(requestID, v3ResponseOK, "")
|
||||
|
||||
// Handle bundled first payload
|
||||
if flags&v3FlagBundle != 0 && len(data) > offset {
|
||||
session.writeToOrigin(data[offset:])
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DatagramV3Muxer) handlePayload(data []byte) {
|
||||
if len(data) < v3RequestIDLength || len(data) > v3RequestIDLength+maxV3UDPPayloadLen {
|
||||
return
|
||||
}
|
||||
|
||||
var requestID RequestID
|
||||
copy(requestID[:], data[:v3RequestIDLength])
|
||||
payload := data[v3RequestIDLength:]
|
||||
|
||||
session, exists := m.inbound.datagramV3Manager.Get(requestID)
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
session.writeToOrigin(payload)
|
||||
}
|
||||
|
||||
func (m *DatagramV3Muxer) sendRegistrationResponse(requestID RequestID, responseType byte, errorMessage string) {
|
||||
errorBytes := []byte(errorMessage)
|
||||
data := make([]byte, v3RegistrationRespLen+len(errorBytes))
|
||||
data[0] = byte(DatagramV3TypeRegistrationResponse)
|
||||
data[1] = responseType
|
||||
copy(data[2:2+v3RequestIDLength], requestID[:])
|
||||
binary.BigEndian.PutUint16(data[2+v3RequestIDLength:], uint16(len(errorBytes)))
|
||||
copy(data[v3RegistrationRespLen:], errorBytes)
|
||||
m.sender.SendDatagram(data)
|
||||
}
|
||||
|
||||
func (m *DatagramV3Muxer) sendPayload(requestID RequestID, payload []byte) {
|
||||
data := make([]byte, v3PayloadHeaderLen+len(payload))
|
||||
data[0] = byte(DatagramV3TypePayload)
|
||||
copy(data[1:1+v3RequestIDLength], requestID[:])
|
||||
copy(data[v3PayloadHeaderLen:], payload)
|
||||
m.sender.SendDatagram(data)
|
||||
}
|
||||
|
||||
// Close closes all V3 sessions.
|
||||
func (m *DatagramV3Muxer) Close() {}
|
||||
|
||||
// v3Session represents a V3 UDP session.
|
||||
type v3Session struct {
|
||||
id RequestID
|
||||
destination netip.AddrPort
|
||||
closeAfterIdle time.Duration
|
||||
origin N.PacketConn
|
||||
manager *DatagramV3SessionManager
|
||||
inbound *Inbound
|
||||
|
||||
writeChan chan []byte
|
||||
closeOnce sync.Once
|
||||
closeChan chan struct{}
|
||||
|
||||
activeAccess sync.RWMutex
|
||||
activeAt time.Time
|
||||
|
||||
senderAccess sync.RWMutex
|
||||
sender DatagramSender
|
||||
|
||||
contextAccess sync.RWMutex
|
||||
connCtx context.Context
|
||||
contextChan chan context.Context
|
||||
}
|
||||
|
||||
var errTooManyActiveFlows = errors.New("too many active flows")
|
||||
|
||||
func (m *DatagramV3SessionManager) Register(
|
||||
inbound *Inbound,
|
||||
ctx context.Context,
|
||||
requestID RequestID,
|
||||
destination netip.AddrPort,
|
||||
closeAfterIdle time.Duration,
|
||||
sender DatagramSender,
|
||||
) (*v3Session, v3RegistrationState, error) {
|
||||
m.sessionAccess.Lock()
|
||||
if existing, exists := m.sessions[requestID]; exists {
|
||||
if existing.sender == sender {
|
||||
existing.updateContext(ctx)
|
||||
existing.markActive()
|
||||
m.sessionAccess.Unlock()
|
||||
return existing, v3RegistrationExisting, nil
|
||||
}
|
||||
existing.migrate(sender, ctx)
|
||||
existing.markActive()
|
||||
m.sessionAccess.Unlock()
|
||||
return existing, v3RegistrationMigrated, nil
|
||||
}
|
||||
|
||||
limit := inbound.maxActiveFlows()
|
||||
if !inbound.flowLimiter.Acquire(limit) {
|
||||
m.sessionAccess.Unlock()
|
||||
return nil, 0, errTooManyActiveFlows
|
||||
}
|
||||
origin, err := inbound.dialWarpPacketConnection(ctx, destination)
|
||||
if err != nil {
|
||||
inbound.flowLimiter.Release(limit)
|
||||
m.sessionAccess.Unlock()
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
session := &v3Session{
|
||||
id: requestID,
|
||||
destination: destination,
|
||||
closeAfterIdle: closeAfterIdle,
|
||||
origin: origin,
|
||||
manager: m,
|
||||
inbound: inbound,
|
||||
writeChan: make(chan []byte, 512),
|
||||
closeChan: make(chan struct{}),
|
||||
activeAt: time.Now(),
|
||||
sender: sender,
|
||||
connCtx: ctx,
|
||||
contextChan: make(chan context.Context, 1),
|
||||
}
|
||||
m.sessions[requestID] = session
|
||||
m.sessionAccess.Unlock()
|
||||
|
||||
sessionCtx := ctx
|
||||
if sessionCtx == nil {
|
||||
sessionCtx = context.Background()
|
||||
}
|
||||
session.connCtx = sessionCtx
|
||||
go session.serve(sessionCtx, limit)
|
||||
return session, v3RegistrationNew, nil
|
||||
}
|
||||
|
||||
func (m *DatagramV3SessionManager) Get(requestID RequestID) (*v3Session, bool) {
|
||||
m.sessionAccess.RLock()
|
||||
defer m.sessionAccess.RUnlock()
|
||||
session, exists := m.sessions[requestID]
|
||||
return session, exists
|
||||
}
|
||||
|
||||
func (m *DatagramV3SessionManager) remove(session *v3Session) {
|
||||
m.sessionAccess.Lock()
|
||||
if current, exists := m.sessions[session.id]; exists && current == session {
|
||||
delete(m.sessions, session.id)
|
||||
}
|
||||
m.sessionAccess.Unlock()
|
||||
}
|
||||
|
||||
func (s *v3Session) serve(ctx context.Context, limit uint64) {
|
||||
defer s.inbound.flowLimiter.Release(limit)
|
||||
defer s.manager.remove(s)
|
||||
|
||||
go s.readLoop()
|
||||
go s.writeLoop()
|
||||
|
||||
connCtx := ctx
|
||||
|
||||
tickInterval := s.closeAfterIdle / 2
|
||||
if tickInterval <= 0 || tickInterval > 10*time.Second {
|
||||
tickInterval = time.Second
|
||||
}
|
||||
ticker := time.NewTicker(tickInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-connCtx.Done():
|
||||
if latestCtx := s.currentContext(); latestCtx != nil && latestCtx != connCtx {
|
||||
connCtx = latestCtx
|
||||
continue
|
||||
}
|
||||
s.close()
|
||||
case newCtx := <-s.contextChan:
|
||||
if newCtx != nil {
|
||||
connCtx = newCtx
|
||||
}
|
||||
case <-ticker.C:
|
||||
if time.Since(s.lastActive()) >= s.closeAfterIdle {
|
||||
s.close()
|
||||
}
|
||||
case <-s.closeChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *v3Session) readLoop() {
|
||||
for {
|
||||
buffer := buf.NewPacket()
|
||||
_, err := s.origin.ReadPacket(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
s.close()
|
||||
return
|
||||
}
|
||||
if buffer.Len() > maxV3UDPPayloadLen {
|
||||
s.inbound.logger.Debug("drop oversized V3 UDP payload: ", buffer.Len())
|
||||
buffer.Release()
|
||||
continue
|
||||
}
|
||||
s.markActive()
|
||||
if err := s.senderDatagram(append([]byte(nil), buffer.Bytes()...)); err != nil {
|
||||
buffer.Release()
|
||||
s.close()
|
||||
return
|
||||
}
|
||||
buffer.Release()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *v3Session) writeLoop() {
|
||||
for {
|
||||
select {
|
||||
case payload := <-s.writeChan:
|
||||
err := s.origin.WritePacket(buf.As(payload), M.SocksaddrFromNetIP(s.destination))
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
s.inbound.logger.Debug("drop V3 UDP payload due to write deadline exceeded")
|
||||
continue
|
||||
}
|
||||
s.close()
|
||||
return
|
||||
}
|
||||
s.markActive()
|
||||
case <-s.closeChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *v3Session) writeToOrigin(payload []byte) {
|
||||
data := make([]byte, len(payload))
|
||||
copy(data, payload)
|
||||
select {
|
||||
case s.writeChan <- data:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (s *v3Session) senderDatagram(payload []byte) error {
|
||||
data := make([]byte, v3PayloadHeaderLen+len(payload))
|
||||
data[0] = byte(DatagramV3TypePayload)
|
||||
copy(data[1:1+v3RequestIDLength], s.id[:])
|
||||
copy(data[v3PayloadHeaderLen:], payload)
|
||||
|
||||
s.senderAccess.RLock()
|
||||
sender := s.sender
|
||||
s.senderAccess.RUnlock()
|
||||
return sender.SendDatagram(data)
|
||||
}
|
||||
|
||||
func (s *v3Session) setSender(sender DatagramSender) {
|
||||
s.senderAccess.Lock()
|
||||
s.sender = sender
|
||||
s.senderAccess.Unlock()
|
||||
}
|
||||
|
||||
func (s *v3Session) updateContext(ctx context.Context) {
|
||||
if ctx == nil {
|
||||
return
|
||||
}
|
||||
s.contextAccess.Lock()
|
||||
s.connCtx = ctx
|
||||
s.contextAccess.Unlock()
|
||||
select {
|
||||
case s.contextChan <- ctx:
|
||||
default:
|
||||
select {
|
||||
case <-s.contextChan:
|
||||
default:
|
||||
}
|
||||
s.contextChan <- ctx
|
||||
}
|
||||
}
|
||||
|
||||
func (s *v3Session) migrate(sender DatagramSender, ctx context.Context) {
|
||||
s.setSender(sender)
|
||||
s.updateContext(ctx)
|
||||
}
|
||||
|
||||
func (s *v3Session) currentContext() context.Context {
|
||||
s.contextAccess.RLock()
|
||||
defer s.contextAccess.RUnlock()
|
||||
return s.connCtx
|
||||
}
|
||||
|
||||
func (s *v3Session) markActive() {
|
||||
s.activeAccess.Lock()
|
||||
s.activeAt = time.Now()
|
||||
s.activeAccess.Unlock()
|
||||
}
|
||||
|
||||
func (s *v3Session) lastActive() time.Time {
|
||||
s.activeAccess.RLock()
|
||||
defer s.activeAccess.RUnlock()
|
||||
return s.activeAt
|
||||
}
|
||||
|
||||
func (s *v3Session) close() {
|
||||
s.closeOnce.Do(func() {
|
||||
if s.origin != nil {
|
||||
_ = s.origin.Close()
|
||||
}
|
||||
close(s.closeChan)
|
||||
})
|
||||
}
|
||||
@@ -1,232 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter/inbound"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
func TestDatagramV3RegistrationDestinationUnreachable(t *testing.T) {
|
||||
sender := &captureDatagramSender{}
|
||||
inboundInstance := &Inbound{
|
||||
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
|
||||
flowLimiter: &FlowLimiter{},
|
||||
datagramV3Manager: NewDatagramV3SessionManager(),
|
||||
}
|
||||
muxer := NewDatagramV3Muxer(inboundInstance, sender, nil)
|
||||
|
||||
requestID := RequestID{}
|
||||
requestID[15] = 1
|
||||
payload := make([]byte, 1+2+2+16+4)
|
||||
payload[0] = 0
|
||||
binary.BigEndian.PutUint16(payload[1:3], 0)
|
||||
binary.BigEndian.PutUint16(payload[3:5], 30)
|
||||
copy(payload[5:21], requestID[:])
|
||||
copy(payload[21:25], []byte{0, 0, 0, 0})
|
||||
|
||||
muxer.handleRegistration(context.Background(), payload)
|
||||
if len(sender.sent) != 1 {
|
||||
t.Fatalf("expected one registration response, got %d", len(sender.sent))
|
||||
}
|
||||
if sender.sent[0][0] != byte(DatagramV3TypeRegistrationResponse) || sender.sent[0][1] != v3ResponseDestinationUnreachable {
|
||||
t.Fatalf("unexpected datagram response: %v", sender.sent[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatagramV3RegistrationErrorWithMessage(t *testing.T) {
|
||||
sender := &captureDatagramSender{}
|
||||
inboundInstance := &Inbound{
|
||||
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
|
||||
flowLimiter: &FlowLimiter{},
|
||||
datagramV3Manager: NewDatagramV3SessionManager(),
|
||||
}
|
||||
muxer := NewDatagramV3Muxer(inboundInstance, sender, nil)
|
||||
|
||||
requestID := RequestID{}
|
||||
requestID[15] = 2
|
||||
payload := make([]byte, 1+2+2+16+1)
|
||||
payload[0] = 1
|
||||
binary.BigEndian.PutUint16(payload[1:3], 53)
|
||||
binary.BigEndian.PutUint16(payload[3:5], 30)
|
||||
copy(payload[5:21], requestID[:])
|
||||
payload[21] = 0xaa
|
||||
|
||||
muxer.handleRegistration(context.Background(), payload)
|
||||
if len(sender.sent) != 1 {
|
||||
t.Fatalf("expected one registration response, got %d", len(sender.sent))
|
||||
}
|
||||
if sender.sent[0][0] != byte(DatagramV3TypeRegistrationResponse) || sender.sent[0][1] != v3ResponseErrorWithMsg {
|
||||
t.Fatalf("unexpected datagram response: %v", sender.sent[0])
|
||||
}
|
||||
}
|
||||
|
||||
type scriptedPacketConn struct {
|
||||
reads [][]byte
|
||||
index int
|
||||
}
|
||||
|
||||
func (c *scriptedPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
|
||||
if c.index >= len(c.reads) {
|
||||
return M.Socksaddr{}, io.EOF
|
||||
}
|
||||
_, err := buffer.Write(c.reads[c.index])
|
||||
c.index++
|
||||
return M.Socksaddr{}, err
|
||||
}
|
||||
|
||||
func (c *scriptedPacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error {
|
||||
buffer.Release()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *scriptedPacketConn) Close() error { return nil }
|
||||
func (c *scriptedPacketConn) LocalAddr() net.Addr { return &net.UDPAddr{} }
|
||||
func (c *scriptedPacketConn) SetDeadline(time.Time) error { return nil }
|
||||
func (c *scriptedPacketConn) SetReadDeadline(time.Time) error { return nil }
|
||||
func (c *scriptedPacketConn) SetWriteDeadline(time.Time) error { return nil }
|
||||
|
||||
type sizeLimitedSender struct {
|
||||
sent [][]byte
|
||||
max int
|
||||
}
|
||||
|
||||
func (s *sizeLimitedSender) SendDatagram(data []byte) error {
|
||||
if len(data) > s.max {
|
||||
return errors.New("datagram too large")
|
||||
}
|
||||
s.sent = append(s.sent, append([]byte(nil), data...))
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestDatagramV3ReadLoopDropsOversizedOriginPackets(t *testing.T) {
|
||||
logger := log.NewNOPFactory().NewLogger("test")
|
||||
sender := &sizeLimitedSender{max: v3PayloadHeaderLen + maxV3UDPPayloadLen}
|
||||
session := &v3Session{
|
||||
id: RequestID{},
|
||||
destination: netip.MustParseAddrPort("127.0.0.1:53"),
|
||||
origin: &scriptedPacketConn{reads: [][]byte{
|
||||
make([]byte, maxV3UDPPayloadLen+1),
|
||||
[]byte("ok"),
|
||||
}},
|
||||
inbound: &Inbound{
|
||||
logger: logger,
|
||||
},
|
||||
writeChan: make(chan []byte, 1),
|
||||
closeChan: make(chan struct{}),
|
||||
contextChan: make(chan context.Context, 1),
|
||||
sender: sender,
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
session.readLoop()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected read loop to finish")
|
||||
}
|
||||
|
||||
if len(sender.sent) != 1 {
|
||||
t.Fatalf("expected one datagram after dropping oversized payload, got %d", len(sender.sent))
|
||||
}
|
||||
if len(sender.sent[0]) != v3PayloadHeaderLen+2 {
|
||||
t.Fatalf("unexpected forwarded datagram length: %d", len(sender.sent[0]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatagramV3HandlePayloadDropsOversizedPayload(t *testing.T) {
|
||||
requestID := RequestID{}
|
||||
requestID[15] = 9
|
||||
session := &v3Session{
|
||||
id: requestID,
|
||||
writeChan: make(chan []byte, 1),
|
||||
}
|
||||
manager := NewDatagramV3SessionManager()
|
||||
manager.sessions[requestID] = session
|
||||
muxer := &DatagramV3Muxer{
|
||||
inbound: &Inbound{
|
||||
datagramV3Manager: manager,
|
||||
},
|
||||
}
|
||||
|
||||
payload := make([]byte, v3RequestIDLength+maxV3UDPPayloadLen+1)
|
||||
copy(payload[:v3RequestIDLength], requestID[:])
|
||||
muxer.handlePayload(payload)
|
||||
|
||||
select {
|
||||
case <-session.writeChan:
|
||||
t.Fatal("expected oversized payload to be dropped")
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
type deadlinePacketConn struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (c *deadlinePacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
|
||||
buffer.Release()
|
||||
return M.Socksaddr{}, io.EOF
|
||||
}
|
||||
|
||||
func (c *deadlinePacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error {
|
||||
buffer.Release()
|
||||
return c.err
|
||||
}
|
||||
|
||||
func (c *deadlinePacketConn) Close() error { return nil }
|
||||
func (c *deadlinePacketConn) LocalAddr() net.Addr { return &net.UDPAddr{} }
|
||||
func (c *deadlinePacketConn) SetDeadline(time.Time) error { return nil }
|
||||
func (c *deadlinePacketConn) SetReadDeadline(time.Time) error { return nil }
|
||||
func (c *deadlinePacketConn) SetWriteDeadline(time.Time) error { return nil }
|
||||
|
||||
func TestDatagramV3WriteLoopDropsDeadlineExceeded(t *testing.T) {
|
||||
session := &v3Session{
|
||||
destination: netip.MustParseAddrPort("127.0.0.1:53"),
|
||||
origin: &deadlinePacketConn{err: os.ErrDeadlineExceeded},
|
||||
inbound: &Inbound{
|
||||
logger: log.NewNOPFactory().NewLogger("test"),
|
||||
},
|
||||
writeChan: make(chan []byte, 1),
|
||||
closeChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
session.writeLoop()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
session.writeToOrigin([]byte("payload"))
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
select {
|
||||
case <-session.closeChan:
|
||||
t.Fatal("expected session to remain open after deadline exceeded")
|
||||
default:
|
||||
}
|
||||
|
||||
session.close()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected write loop to exit after manual close")
|
||||
}
|
||||
}
|
||||
@@ -1,206 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
stdTLS "crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
boxTLS "github.com/sagernet/sing-box/common/tls"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
)
|
||||
|
||||
func TestNewDirectOriginTransportUnix(t *testing.T) {
|
||||
socketPath := fmt.Sprintf("/tmp/cf-origin-%d.sock", time.Now().UnixNano())
|
||||
_ = os.Remove(socketPath)
|
||||
t.Cleanup(func() { _ = os.Remove(socketPath) })
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
go serveTestHTTPOverListener(listener, func(writer http.ResponseWriter, request *http.Request) {
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
_, _ = writer.Write([]byte("unix-ok"))
|
||||
})
|
||||
|
||||
inboundInstance := &Inbound{}
|
||||
transport, cleanup, err := inboundInstance.newDirectOriginTransport(ResolvedService{
|
||||
Kind: ResolvedServiceUnix,
|
||||
UnixPath: socketPath,
|
||||
BaseURL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "localhost",
|
||||
},
|
||||
}, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
resp, err := client.Get("http://localhost/")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(body) != "unix-ok" {
|
||||
t.Fatalf("unexpected response body: %q", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewDirectOriginTransportUnixTLS(t *testing.T) {
|
||||
socketPath := fmt.Sprintf("/tmp/cf-origin-tls-%d.sock", time.Now().UnixNano())
|
||||
_ = os.Remove(socketPath)
|
||||
t.Cleanup(func() { _ = os.Remove(socketPath) })
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
certificate, err := boxTLS.GenerateKeyPair(nil, nil, time.Now, "localhost")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tlsListener := stdTLS.NewListener(listener, &stdTLS.Config{
|
||||
Certificates: []stdTLS.Certificate{*certificate},
|
||||
})
|
||||
defer tlsListener.Close()
|
||||
|
||||
go serveTestHTTPOverListener(tlsListener, func(writer http.ResponseWriter, request *http.Request) {
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
_, _ = writer.Write([]byte("unix-tls-ok"))
|
||||
})
|
||||
|
||||
inboundInstance := &Inbound{}
|
||||
transport, cleanup, err := inboundInstance.newDirectOriginTransport(ResolvedService{
|
||||
Kind: ResolvedServiceUnixTLS,
|
||||
OriginRequest: OriginRequestConfig{
|
||||
NoTLSVerify: true,
|
||||
},
|
||||
UnixPath: socketPath,
|
||||
BaseURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "localhost",
|
||||
},
|
||||
}, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
resp, err := client.Get("https://localhost/")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(body) != "unix-tls-ok" {
|
||||
t.Fatalf("unexpected response body: %q", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func serveTestHTTPOverListener(listener net.Listener, handler func(http.ResponseWriter, *http.Request)) {
|
||||
server := &http.Server{Handler: http.HandlerFunc(handler)}
|
||||
_ = server.Serve(listener)
|
||||
}
|
||||
|
||||
func TestDirectOriginTransportCacheReusesMatchingTransports(t *testing.T) {
|
||||
inboundInstance := &Inbound{
|
||||
directTransports: make(map[string]*http.Transport),
|
||||
}
|
||||
service := ResolvedService{
|
||||
Kind: ResolvedServiceUnix,
|
||||
UnixPath: "/tmp/test.sock",
|
||||
BaseURL: &url.URL{Scheme: "http", Host: "localhost"},
|
||||
}
|
||||
|
||||
transport1, _, err := inboundInstance.newDirectOriginTransport(service, "example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
transport2, _, err := inboundInstance.newDirectOriginTransport(service, "example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if transport1 != transport2 {
|
||||
t.Fatal("expected matching direct-origin transports to be reused")
|
||||
}
|
||||
|
||||
transport3, _, err := inboundInstance.newDirectOriginTransport(service, "other.example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if transport3 == transport1 {
|
||||
t.Fatal("expected different cache keys to produce different transports")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyConfigClearsDirectOriginTransportCache(t *testing.T) {
|
||||
configManager, err := NewConfigManager()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
inboundInstance := &Inbound{
|
||||
logger: log.NewNOPFactory().NewLogger("test"),
|
||||
configManager: configManager,
|
||||
directTransports: make(map[string]*http.Transport),
|
||||
}
|
||||
service := ResolvedService{
|
||||
Kind: ResolvedServiceUnix,
|
||||
UnixPath: "/tmp/test.sock",
|
||||
BaseURL: &url.URL{Scheme: "http", Host: "localhost"},
|
||||
}
|
||||
|
||||
transport1, _, err := inboundInstance.newDirectOriginTransport(service, "example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
result := inboundInstance.ApplyConfig(1, []byte(`{"ingress":[{"service":"http_status:503"}]}`))
|
||||
if result.Err != nil {
|
||||
t.Fatal(result.Err)
|
||||
}
|
||||
|
||||
transport2, _, err := inboundInstance.newDirectOriginTransport(service, "example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if transport1 == transport2 {
|
||||
t.Fatal("expected ApplyConfig to clear direct-origin transport cache")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewDirectOriginTransportUsesCloudflaredDefaults(t *testing.T) {
|
||||
inboundInstance := &Inbound{}
|
||||
transport, cleanup, err := inboundInstance.newDirectOriginTransport(ResolvedService{
|
||||
Kind: ResolvedServiceUnix,
|
||||
UnixPath: "/tmp/test.sock",
|
||||
BaseURL: &url.URL{Scheme: "http", Host: "localhost"},
|
||||
}, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
if transport.ExpectContinueTimeout != time.Second {
|
||||
t.Fatalf("expected ExpectContinueTimeout=1s, got %s", transport.ExpectContinueTimeout)
|
||||
}
|
||||
if transport.DisableCompression {
|
||||
t.Fatal("expected compression to remain enabled by default")
|
||||
}
|
||||
}
|
||||
@@ -1,731 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
const (
|
||||
metadataHTTPMethod = "HttpMethod"
|
||||
metadataHTTPHost = "HttpHost"
|
||||
metadataHTTPHeader = "HttpHeader"
|
||||
metadataHTTPStatus = "HttpStatus"
|
||||
)
|
||||
|
||||
var (
|
||||
loadOriginCABasePool = cloudflareRootCertPool
|
||||
readOriginCAFile = os.ReadFile
|
||||
proxyFromEnvironment = http.ProxyFromEnvironment
|
||||
)
|
||||
|
||||
// ConnectResponseWriter abstracts the response writing for both QUIC and HTTP/2.
|
||||
type ConnectResponseWriter interface {
|
||||
// WriteResponse sends the connect response (ack or error) with optional metadata.
|
||||
WriteResponse(responseError error, metadata []Metadata) error
|
||||
}
|
||||
|
||||
type connectResponseTrailerWriter interface {
|
||||
AddTrailer(name, value string)
|
||||
}
|
||||
|
||||
// quicResponseWriter writes ConnectResponse in QUIC data stream format (signature + capnp).
|
||||
type quicResponseWriter struct {
|
||||
stream io.Writer
|
||||
}
|
||||
|
||||
func (w *quicResponseWriter) WriteResponse(responseError error, metadata []Metadata) error {
|
||||
return WriteConnectResponse(w.stream, responseError, metadata...)
|
||||
}
|
||||
|
||||
// HandleDataStream dispatches an incoming edge data stream (QUIC path).
|
||||
func (i *Inbound) HandleDataStream(ctx context.Context, stream io.ReadWriteCloser, request *ConnectRequest, connIndex uint8) {
|
||||
ctx = log.ContextWithNewID(ctx)
|
||||
respWriter := &quicResponseWriter{stream: stream}
|
||||
i.dispatchRequest(ctx, stream, respWriter, request)
|
||||
}
|
||||
|
||||
// HandleRPCStream handles an incoming edge RPC stream (session management, configuration).
|
||||
func (i *Inbound) HandleRPCStream(ctx context.Context, stream io.ReadWriteCloser, connIndex uint8) {
|
||||
i.logger.DebugContext(ctx, "received RPC stream on connection ", connIndex)
|
||||
// V2 RPC streams are handled here - the edge calls RegisterUdpSession/UnregisterUdpSession
|
||||
// We need the sender (DatagramSender) to find the muxer - but HandleRPCStream doesn't have it.
|
||||
// The V2 muxer is looked up via GetOrCreateV2Muxer in HandleDatagram when first datagram arrives.
|
||||
// For RPC, we need a different approach - see handleRPCStreamWithSender below.
|
||||
}
|
||||
|
||||
// HandleRPCStreamWithSender handles an RPC stream with access to the DatagramSender for V2 muxer lookup.
|
||||
func (i *Inbound) HandleRPCStreamWithSender(ctx context.Context, stream io.ReadWriteCloser, connIndex uint8, sender DatagramSender) {
|
||||
switch datagramVersionForSender(sender) {
|
||||
case "v3":
|
||||
ServeV3RPCStream(ctx, stream, i, i.logger)
|
||||
default:
|
||||
muxer := i.getOrCreateV2Muxer(sender)
|
||||
ServeRPCStream(ctx, stream, i, muxer, i.logger)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleDatagram handles an incoming QUIC datagram.
|
||||
func (i *Inbound) HandleDatagram(ctx context.Context, datagram []byte, sender DatagramSender) {
|
||||
switch datagramVersionForSender(sender) {
|
||||
case "v3":
|
||||
muxer := i.getOrCreateV3Muxer(sender)
|
||||
muxer.HandleDatagram(ctx, datagram)
|
||||
default:
|
||||
muxer := i.getOrCreateV2Muxer(sender)
|
||||
muxer.HandleDatagram(ctx, datagram)
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Inbound) getOrCreateV2Muxer(sender DatagramSender) *DatagramV2Muxer {
|
||||
i.datagramMuxerAccess.Lock()
|
||||
defer i.datagramMuxerAccess.Unlock()
|
||||
muxer, exists := i.datagramV2Muxers[sender]
|
||||
if !exists {
|
||||
muxer = NewDatagramV2Muxer(i, sender, i.logger)
|
||||
i.datagramV2Muxers[sender] = muxer
|
||||
}
|
||||
return muxer
|
||||
}
|
||||
|
||||
func (i *Inbound) getOrCreateV3Muxer(sender DatagramSender) *DatagramV3Muxer {
|
||||
i.datagramMuxerAccess.Lock()
|
||||
defer i.datagramMuxerAccess.Unlock()
|
||||
muxer, exists := i.datagramV3Muxers[sender]
|
||||
if !exists {
|
||||
muxer = NewDatagramV3Muxer(i, sender, i.logger)
|
||||
i.datagramV3Muxers[sender] = muxer
|
||||
}
|
||||
return muxer
|
||||
}
|
||||
|
||||
// RemoveDatagramMuxer cleans up muxers when a connection closes.
|
||||
func (i *Inbound) RemoveDatagramMuxer(sender DatagramSender) {
|
||||
i.datagramMuxerAccess.Lock()
|
||||
if muxer, exists := i.datagramV2Muxers[sender]; exists {
|
||||
muxer.Close()
|
||||
delete(i.datagramV2Muxers, sender)
|
||||
}
|
||||
if muxer, exists := i.datagramV3Muxers[sender]; exists {
|
||||
muxer.Close()
|
||||
delete(i.datagramV3Muxers, sender)
|
||||
}
|
||||
i.datagramMuxerAccess.Unlock()
|
||||
}
|
||||
|
||||
func (i *Inbound) dispatchRequest(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest) {
|
||||
metadata := adapter.InboundContext{
|
||||
Inbound: i.Tag(),
|
||||
InboundType: i.Type(),
|
||||
}
|
||||
|
||||
switch request.Type {
|
||||
case ConnectionTypeTCP:
|
||||
metadata.Destination = M.ParseSocksaddr(request.Dest)
|
||||
i.handleTCPStream(ctx, stream, respWriter, metadata)
|
||||
case ConnectionTypeHTTP, ConnectionTypeWebsocket:
|
||||
service, originURL, err := i.resolveHTTPService(request.Dest)
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "resolve origin service: ", err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
request.Dest = originURL
|
||||
i.handleHTTPService(ctx, stream, respWriter, request, metadata, service)
|
||||
default:
|
||||
i.logger.ErrorContext(ctx, "unknown connection type: ", request.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Inbound) resolveHTTPService(requestURL string) (ResolvedService, string, error) {
|
||||
parsedURL, err := url.Parse(requestURL)
|
||||
if err != nil {
|
||||
return ResolvedService{}, "", E.Cause(err, "parse request URL")
|
||||
}
|
||||
service, loaded := i.configManager.Resolve(parsedURL.Hostname(), parsedURL.Path)
|
||||
if !loaded {
|
||||
return ResolvedService{}, "", E.New("no ingress rule matched request host/path")
|
||||
}
|
||||
originURL, err := service.BuildRequestURL(requestURL)
|
||||
if err != nil {
|
||||
return ResolvedService{}, "", E.Cause(err, "build origin request URL")
|
||||
}
|
||||
return service, originURL, nil
|
||||
}
|
||||
|
||||
func parseHTTPDestination(dest string) M.Socksaddr {
|
||||
parsed, err := url.Parse(dest)
|
||||
if err != nil {
|
||||
return M.ParseSocksaddr(dest)
|
||||
}
|
||||
host := parsed.Hostname()
|
||||
port := parsed.Port()
|
||||
if port == "" {
|
||||
switch parsed.Scheme {
|
||||
case "https", "wss":
|
||||
port = "443"
|
||||
default:
|
||||
port = "80"
|
||||
}
|
||||
}
|
||||
return M.ParseSocksaddr(net.JoinHostPort(host, port))
|
||||
}
|
||||
|
||||
func (i *Inbound) handleTCPStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, metadata adapter.InboundContext) {
|
||||
metadata.Network = N.NetworkTCP
|
||||
i.logger.InfoContext(ctx, "inbound TCP connection to ", metadata.Destination)
|
||||
limit := i.maxActiveFlows()
|
||||
if !i.flowLimiter.Acquire(limit) {
|
||||
err := E.New("too many active flows")
|
||||
i.logger.ErrorContext(ctx, err)
|
||||
respWriter.WriteResponse(err, flowConnectRateLimitedMetadata())
|
||||
return
|
||||
}
|
||||
defer i.flowLimiter.Release(limit)
|
||||
|
||||
warpRouting := i.configManager.Snapshot().WarpRouting
|
||||
targetConn, cleanup, err := i.dialRouterTCPWithMetadata(ctx, metadata, routedPipeTCPOptions{
|
||||
timeout: warpRouting.ConnectTimeout,
|
||||
onHandshake: func(conn net.Conn) {
|
||||
_ = applyTCPKeepAlive(conn, warpRouting.TCPKeepAlive)
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "dial tcp origin: ", err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
// Cloudflare expects an optimistic ACK here so the routed TCP path can sniff
|
||||
// the real input stream before the outbound connection is fully established.
|
||||
err = respWriter.WriteResponse(nil, nil)
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "write connect response: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = bufio.CopyConn(ctx, newStreamConn(stream), targetConn)
|
||||
if err != nil && !E.IsClosedOrCanceled(err) {
|
||||
i.logger.DebugContext(ctx, "copy TCP stream: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Inbound) handleHTTPService(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
|
||||
validationRequest, err := buildMetadataOnlyHTTPRequest(ctx, request)
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "build request for access validation: ", err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
validationRequest = applyOriginRequest(validationRequest, service.OriginRequest)
|
||||
if service.OriginRequest.Access.Required {
|
||||
validator, err := i.accessCache.Get(service.OriginRequest.Access)
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "create access validator: ", err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
if err := validator.Validate(validationRequest.Context(), validationRequest); err != nil {
|
||||
respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusForbidden, http.Header{}))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
switch service.Kind {
|
||||
case ResolvedServiceStatus:
|
||||
err = respWriter.WriteResponse(nil, encodeResponseHeaders(service.StatusCode, http.Header{}))
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "write status service response: ", err)
|
||||
}
|
||||
return
|
||||
case ResolvedServiceHTTP:
|
||||
metadata.Destination = service.Destination
|
||||
if request.Type == ConnectionTypeHTTP {
|
||||
i.handleHTTPStream(ctx, stream, respWriter, request, metadata, service)
|
||||
} else {
|
||||
i.handleWebSocketStream(ctx, stream, respWriter, request, metadata, service)
|
||||
}
|
||||
case ResolvedServiceStream:
|
||||
if request.Type != ConnectionTypeWebsocket {
|
||||
err := E.New("stream service requires websocket request type")
|
||||
i.logger.ErrorContext(ctx, err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
i.handleStreamService(ctx, stream, respWriter, request, metadata, service)
|
||||
case ResolvedServiceUnix, ResolvedServiceUnixTLS:
|
||||
if request.Type == ConnectionTypeHTTP {
|
||||
i.handleDirectHTTPStream(ctx, stream, respWriter, request, metadata, service)
|
||||
} else {
|
||||
i.handleDirectWebSocketStream(ctx, stream, respWriter, request, metadata, service)
|
||||
}
|
||||
case ResolvedServiceBastion:
|
||||
if request.Type != ConnectionTypeWebsocket {
|
||||
err := E.New("bastion service requires websocket request type")
|
||||
i.logger.ErrorContext(ctx, err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
i.handleBastionStream(ctx, stream, respWriter, request, metadata, service)
|
||||
case ResolvedServiceSocksProxy:
|
||||
if request.Type != ConnectionTypeWebsocket {
|
||||
err := E.New("socks-proxy service requires websocket request type")
|
||||
i.logger.ErrorContext(ctx, err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
i.handleSocksProxyStream(ctx, stream, respWriter, request, metadata, service)
|
||||
default:
|
||||
err := E.New("unsupported service kind for HTTP/WebSocket request")
|
||||
i.logger.ErrorContext(ctx, err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Inbound) handleHTTPStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
|
||||
metadata.Network = N.NetworkTCP
|
||||
i.logger.InfoContext(ctx, "inbound HTTP connection to ", metadata.Destination)
|
||||
|
||||
transport, cleanup, err := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest, request.MetadataMap()[metadataHTTPHost])
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "build origin transport: ", err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
defer cleanup()
|
||||
i.roundTripHTTP(ctx, stream, respWriter, request, service, transport)
|
||||
}
|
||||
|
||||
func (i *Inbound) handleWebSocketStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
|
||||
metadata.Network = N.NetworkTCP
|
||||
i.logger.InfoContext(ctx, "inbound WebSocket connection to ", metadata.Destination)
|
||||
|
||||
transport, cleanup, err := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest, request.MetadataMap()[metadataHTTPHost])
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "build origin transport: ", err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
defer cleanup()
|
||||
i.roundTripHTTP(ctx, stream, respWriter, request, service, transport)
|
||||
}
|
||||
|
||||
func (i *Inbound) handleDirectHTTPStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
|
||||
metadata.Network = N.NetworkTCP
|
||||
i.logger.InfoContext(ctx, "inbound HTTP connection to ", request.Dest)
|
||||
|
||||
transport, cleanup, err := i.newDirectOriginTransport(service, request.MetadataMap()[metadataHTTPHost])
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "build direct origin transport: ", err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
defer cleanup()
|
||||
i.roundTripHTTP(ctx, stream, respWriter, request, service, transport)
|
||||
}
|
||||
|
||||
func (i *Inbound) handleDirectWebSocketStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
|
||||
metadata.Network = N.NetworkTCP
|
||||
i.logger.InfoContext(ctx, "inbound WebSocket connection to ", request.Dest)
|
||||
|
||||
transport, cleanup, err := i.newDirectOriginTransport(service, request.MetadataMap()[metadataHTTPHost])
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "build direct origin transport: ", err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
defer cleanup()
|
||||
i.roundTripHTTP(ctx, stream, respWriter, request, service, transport)
|
||||
}
|
||||
|
||||
func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, service ResolvedService, transport *http.Transport) {
|
||||
httpRequest, err := buildHTTPRequestFromMetadata(ctx, request, stream)
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "build HTTP request: ", err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
|
||||
httpRequest = normalizeOriginRequest(request.Type, httpRequest, service.OriginRequest)
|
||||
requestCtx := httpRequest.Context()
|
||||
if service.OriginRequest.ConnectTimeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
requestCtx, cancel = context.WithTimeout(requestCtx, service.OriginRequest.ConnectTimeout)
|
||||
defer cancel()
|
||||
httpRequest = httpRequest.WithContext(requestCtx)
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: transport,
|
||||
CheckRedirect: func(request *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
|
||||
response, err := httpClient.Do(httpRequest)
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "origin request: ", err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
responseMetadata := encodeResponseHeaders(response.StatusCode, response.Header)
|
||||
err = respWriter.WriteResponse(nil, responseMetadata)
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "write origin response headers: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
if request.Type == ConnectionTypeWebsocket && response.StatusCode == http.StatusSwitchingProtocols {
|
||||
rwc, ok := response.Body.(io.ReadWriteCloser)
|
||||
if !ok {
|
||||
i.logger.ErrorContext(ctx, "websocket origin response body is not duplex")
|
||||
return
|
||||
}
|
||||
bidirectionalCopy(stream, rwc)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = io.Copy(stream, response.Body)
|
||||
if err != nil && !E.IsClosedOrCanceled(err) {
|
||||
i.logger.DebugContext(ctx, "copy HTTP response body: ", err)
|
||||
}
|
||||
if trailerWriter, ok := respWriter.(connectResponseTrailerWriter); ok {
|
||||
for name, values := range response.Trailer {
|
||||
for _, value := range values {
|
||||
trailerWriter.AddTrailer(name, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter.InboundContext, originRequest OriginRequestConfig, requestHost string) (*http.Transport, func(), error) {
|
||||
tlsConfig, err := newOriginTLSConfig(originRequest, effectiveOriginHost(originRequest, requestHost))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
input, cleanup, _ := i.dialRouterTCPWithMetadata(ctx, metadata, routedPipeTCPOptions{})
|
||||
|
||||
transport := &http.Transport{
|
||||
ExpectContinueTimeout: time.Second,
|
||||
ForceAttemptHTTP2: originRequest.HTTP2Origin,
|
||||
TLSHandshakeTimeout: originRequest.TLSTimeout,
|
||||
IdleConnTimeout: originRequest.KeepAliveTimeout,
|
||||
MaxIdleConns: originRequest.KeepAliveConnections,
|
||||
MaxIdleConnsPerHost: originRequest.KeepAliveConnections,
|
||||
Proxy: proxyFromEnvironment,
|
||||
TLSClientConfig: tlsConfig,
|
||||
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
|
||||
return input, nil
|
||||
},
|
||||
}
|
||||
return transport, cleanup, nil
|
||||
}
|
||||
|
||||
func (i *Inbound) newDirectOriginTransport(service ResolvedService, requestHost string) (*http.Transport, func(), error) {
|
||||
cacheKey, err := directOriginTransportKey(service, requestHost)
|
||||
if err != nil {
|
||||
return nil, nil, E.Cause(err, "marshal direct origin transport key")
|
||||
}
|
||||
|
||||
i.directTransportAccess.Lock()
|
||||
if i.directTransports == nil {
|
||||
i.directTransports = make(map[string]*http.Transport)
|
||||
}
|
||||
if transport, exists := i.directTransports[cacheKey]; exists {
|
||||
i.directTransportAccess.Unlock()
|
||||
return transport, func() {}, nil
|
||||
}
|
||||
i.directTransportAccess.Unlock()
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Timeout: service.OriginRequest.ConnectTimeout,
|
||||
KeepAlive: service.OriginRequest.TCPKeepAlive,
|
||||
}
|
||||
if service.OriginRequest.NoHappyEyeballs {
|
||||
dialer.FallbackDelay = -1
|
||||
}
|
||||
tlsConfig, err := newOriginTLSConfig(service.OriginRequest, effectiveOriginHost(service.OriginRequest, requestHost))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
transport := &http.Transport{
|
||||
ExpectContinueTimeout: time.Second,
|
||||
ForceAttemptHTTP2: service.OriginRequest.HTTP2Origin,
|
||||
TLSHandshakeTimeout: service.OriginRequest.TLSTimeout,
|
||||
IdleConnTimeout: service.OriginRequest.KeepAliveTimeout,
|
||||
MaxIdleConns: service.OriginRequest.KeepAliveConnections,
|
||||
MaxIdleConnsPerHost: service.OriginRequest.KeepAliveConnections,
|
||||
Proxy: proxyFromEnvironment,
|
||||
TLSClientConfig: tlsConfig,
|
||||
}
|
||||
switch service.Kind {
|
||||
case ResolvedServiceUnix, ResolvedServiceUnixTLS:
|
||||
transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
return dialer.DialContext(ctx, "unix", service.UnixPath)
|
||||
}
|
||||
default:
|
||||
return nil, nil, E.New("unsupported direct origin service")
|
||||
}
|
||||
|
||||
i.directTransportAccess.Lock()
|
||||
if i.directTransports == nil {
|
||||
i.directTransports = make(map[string]*http.Transport)
|
||||
}
|
||||
if cached, exists := i.directTransports[cacheKey]; exists {
|
||||
i.directTransportAccess.Unlock()
|
||||
transport.CloseIdleConnections()
|
||||
return cached, func() {}, nil
|
||||
}
|
||||
i.directTransports[cacheKey] = transport
|
||||
i.directTransportAccess.Unlock()
|
||||
return transport, func() {}, nil
|
||||
}
|
||||
|
||||
type directOriginTransportCacheKey struct {
|
||||
Kind ResolvedServiceKind `json:"kind"`
|
||||
UnixPath string `json:"unix_path,omitempty"`
|
||||
RequestHost string `json:"request_host,omitempty"`
|
||||
Origin OriginRequestConfig `json:"origin"`
|
||||
}
|
||||
|
||||
func directOriginTransportKey(service ResolvedService, requestHost string) (string, error) {
|
||||
key := directOriginTransportCacheKey{
|
||||
Kind: service.Kind,
|
||||
UnixPath: service.UnixPath,
|
||||
RequestHost: effectiveOriginHost(service.OriginRequest, requestHost),
|
||||
Origin: service.OriginRequest,
|
||||
}
|
||||
data, err := json.Marshal(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
func effectiveOriginHost(originRequest OriginRequestConfig, requestHost string) string {
|
||||
if originRequest.HTTPHostHeader != "" {
|
||||
return originRequest.HTTPHostHeader
|
||||
}
|
||||
return requestHost
|
||||
}
|
||||
|
||||
func newOriginTLSConfig(originRequest OriginRequestConfig, requestHost string) (*tls.Config, error) {
|
||||
rootCAs, err := loadOriginCABasePool()
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "load origin root CAs")
|
||||
}
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: originRequest.NoTLSVerify, //nolint:gosec
|
||||
ServerName: originTLSServerName(originRequest, requestHost),
|
||||
RootCAs: rootCAs,
|
||||
}
|
||||
if originRequest.CAPool == "" {
|
||||
return tlsConfig, nil
|
||||
}
|
||||
pemData, err := readOriginCAFile(originRequest.CAPool)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "read origin ca pool")
|
||||
}
|
||||
if !tlsConfig.RootCAs.AppendCertsFromPEM(pemData) {
|
||||
return nil, E.New("parse origin ca pool")
|
||||
}
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
func originTLSServerName(originRequest OriginRequestConfig, requestHost string) string {
|
||||
if originRequest.OriginServerName != "" {
|
||||
return originRequest.OriginServerName
|
||||
}
|
||||
if !originRequest.MatchSNIToHost {
|
||||
return ""
|
||||
}
|
||||
if host, _, err := net.SplitHostPort(requestHost); err == nil {
|
||||
return host
|
||||
}
|
||||
return requestHost
|
||||
}
|
||||
|
||||
func applyOriginRequest(request *http.Request, originRequest OriginRequestConfig) *http.Request {
|
||||
request = request.Clone(request.Context())
|
||||
if originRequest.HTTPHostHeader != "" {
|
||||
request.Header.Set("X-Forwarded-Host", request.Host)
|
||||
request.Host = originRequest.HTTPHostHeader
|
||||
}
|
||||
return request
|
||||
}
|
||||
|
||||
func normalizeOriginRequest(connectType ConnectionType, request *http.Request, originRequest OriginRequestConfig) *http.Request {
|
||||
request = applyOriginRequest(request, originRequest)
|
||||
|
||||
switch connectType {
|
||||
case ConnectionTypeWebsocket:
|
||||
request.Header.Set("Connection", "Upgrade")
|
||||
request.Header.Set("Upgrade", "websocket")
|
||||
request.Header.Set("Sec-Websocket-Version", "13")
|
||||
request.ContentLength = 0
|
||||
request.Body = nil
|
||||
default:
|
||||
if originRequest.DisableChunkedEncoding {
|
||||
request.TransferEncoding = []string{"gzip", "deflate"}
|
||||
if contentLength, err := strconv.ParseInt(request.Header.Get("Content-Length"), 10, 64); err == nil {
|
||||
request.ContentLength = contentLength
|
||||
}
|
||||
}
|
||||
request.Header.Set("Connection", "keep-alive")
|
||||
}
|
||||
|
||||
if _, exists := request.Header["User-Agent"]; !exists {
|
||||
request.Header.Set("User-Agent", "")
|
||||
}
|
||||
|
||||
return request
|
||||
}
|
||||
|
||||
func buildMetadataOnlyHTTPRequest(ctx context.Context, connectRequest *ConnectRequest) (*http.Request, error) {
|
||||
return buildHTTPRequestFromMetadata(ctx, &ConnectRequest{
|
||||
Dest: connectRequest.Dest,
|
||||
Type: connectRequest.Type,
|
||||
Metadata: append([]Metadata(nil), connectRequest.Metadata...),
|
||||
}, http.NoBody)
|
||||
}
|
||||
|
||||
func bidirectionalCopy(left, right io.ReadWriteCloser) {
|
||||
var closeOnce sync.Once
|
||||
closeBoth := func() {
|
||||
closeOnce.Do(func() {
|
||||
common.Close(left, right)
|
||||
})
|
||||
}
|
||||
|
||||
done := make(chan struct{}, 2)
|
||||
go func() {
|
||||
io.Copy(left, right)
|
||||
closeBoth()
|
||||
done <- struct{}{}
|
||||
}()
|
||||
go func() {
|
||||
io.Copy(right, left)
|
||||
closeBoth()
|
||||
done <- struct{}{}
|
||||
}()
|
||||
<-done
|
||||
<-done
|
||||
}
|
||||
|
||||
func buildHTTPRequestFromMetadata(ctx context.Context, connectRequest *ConnectRequest, body io.Reader) (*http.Request, error) {
|
||||
metadataMap := connectRequest.MetadataMap()
|
||||
method := metadataMap[metadataHTTPMethod]
|
||||
host := metadataMap[metadataHTTPHost]
|
||||
|
||||
request, err := http.NewRequestWithContext(ctx, method, connectRequest.Dest, body)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "create HTTP request")
|
||||
}
|
||||
request.Host = host
|
||||
|
||||
for _, entry := range connectRequest.Metadata {
|
||||
if !strings.Contains(entry.Key, metadataHTTPHeader) {
|
||||
continue
|
||||
}
|
||||
parts := strings.SplitN(entry.Key, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
request.Header.Add(parts[1], entry.Val)
|
||||
}
|
||||
|
||||
contentLengthStr := request.Header.Get("Content-Length")
|
||||
if contentLengthStr != "" {
|
||||
request.ContentLength, err = strconv.ParseInt(contentLengthStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "parse content-length")
|
||||
}
|
||||
}
|
||||
|
||||
if connectRequest.Type != ConnectionTypeWebsocket && !isTransferEncodingChunked(request) && request.ContentLength == 0 {
|
||||
request.Body = http.NoBody
|
||||
}
|
||||
|
||||
request.Header.Del("Cf-Cloudflared-Proxy-Connection-Upgrade")
|
||||
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func isTransferEncodingChunked(request *http.Request) bool {
|
||||
for _, encoding := range request.TransferEncoding {
|
||||
if strings.Contains(strings.ToLower(encoding), "chunked") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return strings.Contains(strings.ToLower(request.Header.Get("Transfer-Encoding")), "chunked")
|
||||
}
|
||||
|
||||
func encodeResponseHeaders(statusCode int, header http.Header) []Metadata {
|
||||
metadata := make([]Metadata, 0, len(header)+1)
|
||||
metadata = append(metadata, Metadata{
|
||||
Key: metadataHTTPStatus,
|
||||
Val: strconv.Itoa(statusCode),
|
||||
})
|
||||
for name, values := range header {
|
||||
for _, value := range values {
|
||||
metadata = append(metadata, Metadata{
|
||||
Key: metadataHTTPHeader + ":" + name,
|
||||
Val: value,
|
||||
})
|
||||
}
|
||||
}
|
||||
return metadata
|
||||
}
|
||||
|
||||
// streamConn wraps an io.ReadWriteCloser as a net.Conn.
|
||||
type streamConn struct {
|
||||
io.ReadWriteCloser
|
||||
}
|
||||
|
||||
func newStreamConn(stream io.ReadWriteCloser) *streamConn {
|
||||
return &streamConn{ReadWriteCloser: stream}
|
||||
}
|
||||
|
||||
func (c *streamConn) LocalAddr() net.Addr { return nil }
|
||||
func (c *streamConn) RemoteAddr() net.Addr { return nil }
|
||||
func (c *streamConn) SetDeadline(_ time.Time) error { return nil }
|
||||
func (c *streamConn) SetReadDeadline(_ time.Time) error { return nil }
|
||||
func (c *streamConn) SetWriteDeadline(_ time.Time) error { return nil }
|
||||
|
||||
type datagramVersionedSender interface {
|
||||
DatagramVersion() string
|
||||
}
|
||||
|
||||
func datagramVersionForSender(sender DatagramSender) string {
|
||||
versioned, ok := sender.(datagramVersionedSender)
|
||||
if !ok {
|
||||
return defaultDatagramVersion
|
||||
}
|
||||
version := versioned.DatagramVersion()
|
||||
if version == "" {
|
||||
return defaultDatagramVersion
|
||||
}
|
||||
return version
|
||||
}
|
||||
@@ -1,137 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseHTTPDestination(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dest string
|
||||
expected string
|
||||
}{
|
||||
{"http with port", "http://127.0.0.1:8083/path", "127.0.0.1:8083"},
|
||||
{"https default port", "https://example.com", "example.com:443"},
|
||||
{"http default port", "http://example.com", "example.com:80"},
|
||||
{"wss default port", "wss://example.com/ws", "example.com:443"},
|
||||
{"explicit port", "https://example.com:9443/api", "example.com:9443"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := parseHTTPDestination(tt.dest)
|
||||
if result.String() != tt.expected {
|
||||
t.Errorf("parseHTTPDestination(%q) = %q, want %q", tt.dest, result.String(), tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeHeaders(t *testing.T) {
|
||||
header := http.Header{}
|
||||
header.Set("Content-Type", "text/html")
|
||||
header.Set("X-Foo", "bar")
|
||||
|
||||
serialized := SerializeHeaders(header)
|
||||
if serialized == "" {
|
||||
t.Fatal("expected non-empty serialized headers")
|
||||
}
|
||||
|
||||
decoded := make(map[string]string)
|
||||
for _, pair := range splitNonEmpty(serialized, ";") {
|
||||
parts := splitNonEmpty(pair, ":")
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("malformed pair: %q", pair)
|
||||
}
|
||||
name, err := headerEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
t.Fatal("decode name: ", err)
|
||||
}
|
||||
value, err := headerEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
t.Fatal("decode value: ", err)
|
||||
}
|
||||
decoded[string(name)] = string(value)
|
||||
}
|
||||
|
||||
if decoded["Content-Type"] != "text/html" {
|
||||
t.Error("expected Content-Type=text/html, got ", decoded["Content-Type"])
|
||||
}
|
||||
if decoded["X-Foo"] != "bar" {
|
||||
t.Error("expected X-Foo=bar, got ", decoded["X-Foo"])
|
||||
}
|
||||
}
|
||||
|
||||
func splitNonEmpty(s string, sep string) []string {
|
||||
var result []string
|
||||
for _, part := range splitString(s, sep) {
|
||||
if part != "" {
|
||||
result = append(result, part)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func splitString(s string, sep string) []string {
|
||||
if len(sep) == 0 {
|
||||
return []string{s}
|
||||
}
|
||||
var result []string
|
||||
start := 0
|
||||
for i := 0; i <= len(s)-len(sep); i++ {
|
||||
if s[i:i+len(sep)] == sep {
|
||||
result = append(result, s[start:i])
|
||||
start = i + len(sep)
|
||||
i += len(sep) - 1
|
||||
}
|
||||
}
|
||||
result = append(result, s[start:])
|
||||
return result
|
||||
}
|
||||
|
||||
func TestIsControlResponseHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expected bool
|
||||
}{
|
||||
{":status", true},
|
||||
{"cf-int-foo", true},
|
||||
{"cf-cloudflared-response-meta", true},
|
||||
{"cf-proxy-src", true},
|
||||
{"content-type", false},
|
||||
{"x-custom", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isControlResponseHeader(tt.name)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isControlResponseHeader(%q) = %v, want %v", tt.name, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsWebsocketClientHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expected bool
|
||||
}{
|
||||
{"sec-websocket-accept", true},
|
||||
{"connection", true},
|
||||
{"upgrade", true},
|
||||
{"content-type", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isWebsocketClientHeader(tt.name)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isWebsocketClientHeader(%q) = %v, want %v", tt.name, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,130 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
const (
|
||||
edgeSRVService = "v2-origintunneld"
|
||||
edgeSRVProto = "tcp"
|
||||
edgeSRVName = "argotunnel.com"
|
||||
|
||||
dotServerName = "cloudflare-dns.com"
|
||||
dotServerAddr = "1.1.1.1:853"
|
||||
dotTimeout = 15 * time.Second
|
||||
)
|
||||
|
||||
func getRegionalServiceName(region string) string {
|
||||
if region == "" {
|
||||
return edgeSRVService
|
||||
}
|
||||
return region + "-" + edgeSRVService
|
||||
}
|
||||
|
||||
// EdgeAddr represents a Cloudflare edge server address.
|
||||
type EdgeAddr struct {
|
||||
TCP *net.TCPAddr
|
||||
UDP *net.UDPAddr
|
||||
IPVersion int // 4 or 6
|
||||
}
|
||||
|
||||
// DiscoverEdge performs SRV-based edge discovery and returns addresses
|
||||
// partitioned into regions (typically 2).
|
||||
func DiscoverEdge(ctx context.Context, region string, controlDialer N.Dialer) ([][]*EdgeAddr, error) {
|
||||
regions, err := lookupEdgeSRV(region)
|
||||
if err != nil {
|
||||
regions, err = lookupEdgeSRVWithDoT(ctx, region, controlDialer)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "edge discovery")
|
||||
}
|
||||
}
|
||||
if len(regions) == 0 {
|
||||
return nil, E.New("edge discovery: no edge addresses found")
|
||||
}
|
||||
return regions, nil
|
||||
}
|
||||
|
||||
func lookupEdgeSRV(region string) ([][]*EdgeAddr, error) {
|
||||
_, addrs, err := net.LookupSRV(getRegionalServiceName(region), edgeSRVProto, edgeSRVName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resolveSRVRecords(addrs)
|
||||
}
|
||||
|
||||
func lookupEdgeSRVWithDoT(ctx context.Context, region string, controlDialer N.Dialer) ([][]*EdgeAddr, error) {
|
||||
resolver := &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
conn, err := controlDialer.DialContext(ctx, "tcp", M.ParseSocksaddr(dotServerAddr))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tls.Client(conn, &tls.Config{ServerName: dotServerName}), nil
|
||||
},
|
||||
}
|
||||
lookupCtx, cancel := context.WithTimeout(ctx, dotTimeout)
|
||||
defer cancel()
|
||||
_, addrs, err := resolver.LookupSRV(lookupCtx, getRegionalServiceName(region), edgeSRVProto, edgeSRVName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resolveSRVRecords(addrs)
|
||||
}
|
||||
|
||||
func resolveSRVRecords(records []*net.SRV) ([][]*EdgeAddr, error) {
|
||||
var regions [][]*EdgeAddr
|
||||
for _, record := range records {
|
||||
ips, err := net.LookupIP(record.Target)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "resolve SRV target: ", record.Target)
|
||||
}
|
||||
if len(ips) == 0 {
|
||||
continue
|
||||
}
|
||||
edgeAddrs := make([]*EdgeAddr, 0, len(ips))
|
||||
for _, ip := range ips {
|
||||
ipVersion := 6
|
||||
if ip.To4() != nil {
|
||||
ipVersion = 4
|
||||
}
|
||||
edgeAddrs = append(edgeAddrs, &EdgeAddr{
|
||||
TCP: &net.TCPAddr{IP: ip, Port: int(record.Port)},
|
||||
UDP: &net.UDPAddr{IP: ip, Port: int(record.Port)},
|
||||
IPVersion: ipVersion,
|
||||
})
|
||||
}
|
||||
regions = append(regions, edgeAddrs)
|
||||
}
|
||||
return regions, nil
|
||||
}
|
||||
|
||||
// FilterByIPVersion filters edge addresses to only include the specified IP version.
|
||||
// version 0 means no filtering (auto).
|
||||
func FilterByIPVersion(regions [][]*EdgeAddr, version int) [][]*EdgeAddr {
|
||||
if version == 0 {
|
||||
return regions
|
||||
}
|
||||
var filtered [][]*EdgeAddr
|
||||
for _, region := range regions {
|
||||
var addrs []*EdgeAddr
|
||||
for _, addr := range region {
|
||||
if addr.IPVersion == version {
|
||||
addrs = append(addrs, addr)
|
||||
}
|
||||
}
|
||||
if len(addrs) > 0 {
|
||||
filtered = append(filtered, addrs)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
@@ -1,148 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func TestDiscoverEdge(t *testing.T) {
|
||||
regions, err := DiscoverEdge(context.Background(), "", N.SystemDialer)
|
||||
if err != nil {
|
||||
t.Fatal("DiscoverEdge: ", err)
|
||||
}
|
||||
if len(regions) == 0 {
|
||||
t.Fatal("expected at least 1 region")
|
||||
}
|
||||
for i, region := range regions {
|
||||
if len(region) == 0 {
|
||||
t.Errorf("region %d is empty", i)
|
||||
continue
|
||||
}
|
||||
for j, addr := range region {
|
||||
if addr.TCP == nil {
|
||||
t.Errorf("region %d addr %d: TCP is nil", i, j)
|
||||
}
|
||||
if addr.UDP == nil {
|
||||
t.Errorf("region %d addr %d: UDP is nil", i, j)
|
||||
}
|
||||
if addr.IPVersion != 4 && addr.IPVersion != 6 {
|
||||
t.Errorf("region %d addr %d: invalid IPVersion %d", i, j, addr.IPVersion)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterByIPVersion(t *testing.T) {
|
||||
v4Addr := &EdgeAddr{
|
||||
TCP: &net.TCPAddr{IP: net.IPv4(1, 1, 1, 1), Port: 7844},
|
||||
UDP: &net.UDPAddr{IP: net.IPv4(1, 1, 1, 1), Port: 7844},
|
||||
IPVersion: 4,
|
||||
}
|
||||
v6Addr := &EdgeAddr{
|
||||
TCP: &net.TCPAddr{IP: net.ParseIP("2606:4700::1"), Port: 7844},
|
||||
UDP: &net.UDPAddr{IP: net.ParseIP("2606:4700::1"), Port: 7844},
|
||||
IPVersion: 6,
|
||||
}
|
||||
mixed := [][]*EdgeAddr{{v4Addr, v6Addr}}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
version int
|
||||
expected int
|
||||
}{
|
||||
{"auto", 0, 2},
|
||||
{"v4 only", 4, 1},
|
||||
{"v6 only", 6, 1},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := FilterByIPVersion(mixed, tt.version)
|
||||
total := 0
|
||||
for _, region := range result {
|
||||
total += len(region)
|
||||
}
|
||||
if total != tt.expected {
|
||||
t.Errorf("expected %d addrs, got %d", tt.expected, total)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("no match", func(t *testing.T) {
|
||||
v4Only := [][]*EdgeAddr{{v4Addr}}
|
||||
result := FilterByIPVersion(v4Only, 6)
|
||||
if len(result) != 0 {
|
||||
t.Error("expected empty result for no match")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty input", func(t *testing.T) {
|
||||
result := FilterByIPVersion(nil, 4)
|
||||
if len(result) != 0 {
|
||||
t.Error("expected empty result for nil input")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetRegionalServiceName(t *testing.T) {
|
||||
if got := getRegionalServiceName(""); got != edgeSRVService {
|
||||
t.Fatalf("expected global service %s, got %s", edgeSRVService, got)
|
||||
}
|
||||
if got := getRegionalServiceName("us"); got != "us-"+edgeSRVService {
|
||||
t.Fatalf("expected regional service us-%s, got %s", edgeSRVService, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitialEdgeAddrIndex(t *testing.T) {
|
||||
if got := initialEdgeAddrIndex(0, 4); got != 0 {
|
||||
t.Fatalf("expected conn 0 to get index 0, got %d", got)
|
||||
}
|
||||
if got := initialEdgeAddrIndex(3, 4); got != 3 {
|
||||
t.Fatalf("expected conn 3 to get index 3, got %d", got)
|
||||
}
|
||||
if got := initialEdgeAddrIndex(5, 4); got != 1 {
|
||||
t.Fatalf("expected conn 5 to wrap to index 1, got %d", got)
|
||||
}
|
||||
if got := initialEdgeAddrIndex(2, 1); got != 0 {
|
||||
t.Fatalf("expected single-address pool to always return 0, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRotateEdgeAddrIndex(t *testing.T) {
|
||||
if got := rotateEdgeAddrIndex(0, 4); got != 1 {
|
||||
t.Fatalf("expected index 0 to rotate to 1, got %d", got)
|
||||
}
|
||||
if got := rotateEdgeAddrIndex(3, 4); got != 0 {
|
||||
t.Fatalf("expected last index to wrap to 0, got %d", got)
|
||||
}
|
||||
if got := rotateEdgeAddrIndex(0, 1); got != 0 {
|
||||
t.Fatalf("expected single-address pool to stay at 0, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveHAConnections(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
requested int
|
||||
available int
|
||||
expected int
|
||||
}{
|
||||
{name: "requested below available", requested: 2, available: 4, expected: 2},
|
||||
{name: "requested equals available", requested: 4, available: 4, expected: 4},
|
||||
{name: "requested above available", requested: 5, available: 3, expected: 3},
|
||||
{name: "no available edges", requested: 4, available: 0, expected: 0},
|
||||
}
|
||||
|
||||
for _, testCase := range tests {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
if actual := effectiveHAConnections(testCase.requested, testCase.available); actual != testCase.expected {
|
||||
t.Fatalf("effectiveHAConnections(%d, %d) = %d, want %d", testCase.requested, testCase.available, actual, testCase.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
)
|
||||
|
||||
func newEdgeTLSConfig(rootCAs *x509.CertPool, serverName string, nextProtos []string) *tls.Config {
|
||||
return &tls.Config{
|
||||
RootCAs: rootCAs,
|
||||
ServerName: serverName,
|
||||
NextProtos: nextProtos,
|
||||
CurvePreferences: []tls.CurveID{tls.CurveP256},
|
||||
}
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewEdgeTLSConfigUsesP256(t *testing.T) {
|
||||
rootCAs := x509.NewCertPool()
|
||||
config := newEdgeTLSConfig(rootCAs, h2EdgeSNI, nil)
|
||||
|
||||
if config.RootCAs != rootCAs {
|
||||
t.Fatal("expected root CA pool to be preserved")
|
||||
}
|
||||
if config.ServerName != h2EdgeSNI {
|
||||
t.Fatalf("expected server name %q, got %q", h2EdgeSNI, config.ServerName)
|
||||
}
|
||||
if len(config.CurvePreferences) != 1 || config.CurvePreferences[0] != tls.CurveP256 {
|
||||
t.Fatalf("unexpected curve preferences: %#v", config.CurvePreferences)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewEdgeTLSConfigPreservesNextProtos(t *testing.T) {
|
||||
config := newEdgeTLSConfig(x509.NewCertPool(), quicEdgeSNI, []string{quicEdgeALPN})
|
||||
if len(config.NextProtos) != 1 || config.NextProtos[0] != quicEdgeALPN {
|
||||
t.Fatalf("unexpected next protos: %#v", config.NextProtos)
|
||||
}
|
||||
}
|
||||
@@ -1,123 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"hash/fnv"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
featureSelectorHostname = "cfd-features.argotunnel.com"
|
||||
featureLookupTimeout = 10 * time.Second
|
||||
defaultDatagramVersion = "v2"
|
||||
defaultFeatureRefreshInterval = time.Hour
|
||||
)
|
||||
|
||||
type cloudflaredFeaturesRecord struct {
|
||||
DatagramV3Percentage uint32 `json:"dv3_2"`
|
||||
}
|
||||
|
||||
var lookupCloudflaredFeatures = func(ctx context.Context) ([]byte, error) {
|
||||
lookupCtx, cancel := context.WithTimeout(ctx, featureLookupTimeout)
|
||||
defer cancel()
|
||||
|
||||
records, err := net.DefaultResolver.LookupTXT(lookupCtx, featureSelectorHostname)
|
||||
if err != nil || len(records) == 0 {
|
||||
return nil, err
|
||||
}
|
||||
return []byte(records[0]), nil
|
||||
}
|
||||
|
||||
type featureSelector struct {
|
||||
configured string
|
||||
accountTag string
|
||||
lookup func(context.Context) ([]byte, error)
|
||||
refreshInterval time.Duration
|
||||
currentDatagramVersion string
|
||||
|
||||
access sync.RWMutex
|
||||
}
|
||||
|
||||
func newFeatureSelector(ctx context.Context, accountTag string, configured string) *featureSelector {
|
||||
selector := &featureSelector{
|
||||
configured: configured,
|
||||
accountTag: accountTag,
|
||||
lookup: lookupCloudflaredFeatures,
|
||||
refreshInterval: defaultFeatureRefreshInterval,
|
||||
currentDatagramVersion: defaultDatagramVersion,
|
||||
}
|
||||
if configured != "" {
|
||||
selector.currentDatagramVersion = configured
|
||||
return selector
|
||||
}
|
||||
_ = selector.refresh(ctx)
|
||||
if selector.refreshInterval > 0 {
|
||||
go selector.refreshLoop(ctx)
|
||||
}
|
||||
return selector
|
||||
}
|
||||
|
||||
func (s *featureSelector) Snapshot() (string, []string) {
|
||||
if s == nil {
|
||||
return defaultDatagramVersion, DefaultFeatures(defaultDatagramVersion)
|
||||
}
|
||||
s.access.RLock()
|
||||
defer s.access.RUnlock()
|
||||
return s.currentDatagramVersion, DefaultFeatures(s.currentDatagramVersion)
|
||||
}
|
||||
|
||||
func (s *featureSelector) refresh(ctx context.Context) error {
|
||||
if s == nil || s.configured != "" {
|
||||
return nil
|
||||
}
|
||||
record, err := s.lookup(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
version, err := resolveRemoteDatagramVersion(s.accountTag, record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.access.Lock()
|
||||
s.currentDatagramVersion = version
|
||||
s.access.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *featureSelector) refreshLoop(ctx context.Context) {
|
||||
ticker := time.NewTicker(s.refreshInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
_ = s.refresh(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func resolveRemoteDatagramVersion(accountTag string, record []byte) (string, error) {
|
||||
var features cloudflaredFeaturesRecord
|
||||
if err := json.Unmarshal(record, &features); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if accountEnabled(accountTag, features.DatagramV3Percentage) {
|
||||
return "v3", nil
|
||||
}
|
||||
return defaultDatagramVersion, nil
|
||||
}
|
||||
|
||||
func accountEnabled(accountTag string, percentage uint32) bool {
|
||||
if percentage == 0 {
|
||||
return false
|
||||
}
|
||||
hasher := fnv.New32a()
|
||||
_, _ = hasher.Write([]byte(accountTag))
|
||||
return percentage > hasher.Sum32()%100
|
||||
}
|
||||
@@ -1,119 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFeatureSelectorConfiguredWins(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
selector := newFeatureSelector(ctx, "account", "v3")
|
||||
version, features := selector.Snapshot()
|
||||
if version != "v3" {
|
||||
t.Fatalf("expected configured version to win, got %s", version)
|
||||
}
|
||||
if !slices.Contains(features, "support_datagram_v3_2") {
|
||||
t.Fatalf("expected v3 feature list, got %#v", features)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFeatureSelectorInitialRemoteSelection(t *testing.T) {
|
||||
selector := &featureSelector{
|
||||
accountTag: "account",
|
||||
lookup: func(context.Context) ([]byte, error) { return []byte(`{"dv3_2":100}`), nil },
|
||||
currentDatagramVersion: defaultDatagramVersion,
|
||||
}
|
||||
|
||||
if err := selector.refresh(context.Background()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
version, _ := selector.Snapshot()
|
||||
if version != "v3" {
|
||||
t.Fatalf("expected auto-selected v3, got %s", version)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFeatureSelectorRefreshUpdatesSnapshot(t *testing.T) {
|
||||
record := []byte(`{"dv3_2":0}`)
|
||||
selector := &featureSelector{
|
||||
accountTag: "account",
|
||||
currentDatagramVersion: defaultDatagramVersion,
|
||||
lookup: func(context.Context) ([]byte, error) {
|
||||
return record, nil
|
||||
},
|
||||
}
|
||||
|
||||
if err := selector.refresh(context.Background()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
version, _ := selector.Snapshot()
|
||||
if version != defaultDatagramVersion {
|
||||
t.Fatalf("expected initial v2, got %s", version)
|
||||
}
|
||||
|
||||
record = []byte(`{"dv3_2":100}`)
|
||||
if err := selector.refresh(context.Background()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
version, _ = selector.Snapshot()
|
||||
if version != "v3" {
|
||||
t.Fatalf("expected refreshed v3, got %s", version)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFeatureSelectorRefreshFailureKeepsPreviousValue(t *testing.T) {
|
||||
selector := &featureSelector{
|
||||
accountTag: "account",
|
||||
currentDatagramVersion: "v3",
|
||||
lookup: func(context.Context) ([]byte, error) {
|
||||
return nil, errors.New("lookup failed")
|
||||
},
|
||||
}
|
||||
|
||||
if err := selector.refresh(context.Background()); err == nil {
|
||||
t.Fatal("expected refresh failure")
|
||||
}
|
||||
|
||||
version, _ := selector.Snapshot()
|
||||
if version != "v3" {
|
||||
t.Fatalf("expected previous version to be retained, got %s", version)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInboundUsesFreshFeatureSnapshotOnRetry(t *testing.T) {
|
||||
inbound := &Inbound{
|
||||
featureSelector: &featureSelector{
|
||||
accountTag: "account",
|
||||
currentDatagramVersion: defaultDatagramVersion,
|
||||
},
|
||||
}
|
||||
|
||||
version, features := inbound.currentConnectionFeatures()
|
||||
if version != defaultDatagramVersion {
|
||||
t.Fatalf("expected initial v2, got %s", version)
|
||||
}
|
||||
if slices.Contains(features, "support_datagram_v3_2") {
|
||||
t.Fatalf("unexpected v3 feature list: %#v", features)
|
||||
}
|
||||
|
||||
inbound.featureSelector.access.Lock()
|
||||
inbound.featureSelector.currentDatagramVersion = "v3"
|
||||
inbound.featureSelector.access.Unlock()
|
||||
|
||||
version, features = inbound.currentConnectionFeatures()
|
||||
if version != "v3" {
|
||||
t.Fatalf("expected refreshed v3, got %s", version)
|
||||
}
|
||||
if !slices.Contains(features, "support_datagram_v3_2") {
|
||||
t.Fatalf("expected v3 feature list, got %#v", features)
|
||||
}
|
||||
}
|
||||
@@ -1,34 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import "sync"
|
||||
|
||||
type FlowLimiter struct {
|
||||
access sync.Mutex
|
||||
active uint64
|
||||
}
|
||||
|
||||
func (l *FlowLimiter) Acquire(limit uint64) bool {
|
||||
if limit == 0 {
|
||||
return true
|
||||
}
|
||||
l.access.Lock()
|
||||
defer l.access.Unlock()
|
||||
if l.active >= limit {
|
||||
return false
|
||||
}
|
||||
l.active++
|
||||
return true
|
||||
}
|
||||
|
||||
func (l *FlowLimiter) Release(limit uint64) {
|
||||
if limit == 0 {
|
||||
return
|
||||
}
|
||||
l.access.Lock()
|
||||
defer l.access.Unlock()
|
||||
if l.active > 0 {
|
||||
l.active--
|
||||
}
|
||||
}
|
||||
@@ -1,160 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/inbound"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type captureConnectMetadataWriter struct {
|
||||
err error
|
||||
metadata []Metadata
|
||||
}
|
||||
|
||||
func (w *captureConnectMetadataWriter) WriteResponse(responseError error, metadata []Metadata) error {
|
||||
w.err = responseError
|
||||
w.metadata = append([]Metadata(nil), metadata...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func newLimitedInbound(t *testing.T, limit uint64) *Inbound {
|
||||
t.Helper()
|
||||
logFactory, err := log.New(log.Options{Options: option.LogOptions{Level: "debug"}})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
configManager, err := NewConfigManager()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
config := configManager.Snapshot()
|
||||
config.WarpRouting.MaxActiveFlows = limit
|
||||
configManager.activeConfig = config
|
||||
return &Inbound{
|
||||
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
router: &testRouter{},
|
||||
logger: logFactory.NewLogger("test"),
|
||||
configManager: configManager,
|
||||
flowLimiter: &FlowLimiter{},
|
||||
datagramV3Manager: NewDatagramV3SessionManager(),
|
||||
connectionStates: make([]connectionState, 1),
|
||||
successfulProtocols: make(map[string]struct{}),
|
||||
directTransports: make(map[string]*http.Transport),
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleTCPStreamRespectsMaxActiveFlows(t *testing.T) {
|
||||
inboundInstance := newLimitedInbound(t, 1)
|
||||
if !inboundInstance.flowLimiter.Acquire(1) {
|
||||
t.Fatal("failed to pre-acquire limiter")
|
||||
}
|
||||
|
||||
stream, peer := net.Pipe()
|
||||
defer stream.Close()
|
||||
defer peer.Close()
|
||||
respWriter := &fakeConnectResponseWriter{}
|
||||
inboundInstance.handleTCPStream(context.Background(), stream, respWriter, adapter.InboundContext{})
|
||||
if respWriter.err == nil {
|
||||
t.Fatal("expected too many active flows error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleTCPStreamRateLimitMetadata(t *testing.T) {
|
||||
inboundInstance := newLimitedInbound(t, 1)
|
||||
if !inboundInstance.flowLimiter.Acquire(1) {
|
||||
t.Fatal("failed to pre-acquire limiter")
|
||||
}
|
||||
|
||||
stream, peer := net.Pipe()
|
||||
defer stream.Close()
|
||||
defer peer.Close()
|
||||
|
||||
respWriter := &captureConnectMetadataWriter{}
|
||||
inboundInstance.handleTCPStream(context.Background(), stream, respWriter, adapter.InboundContext{})
|
||||
if respWriter.err == nil {
|
||||
t.Fatal("expected too many active flows error")
|
||||
}
|
||||
if !hasFlowConnectRateLimited(respWriter.metadata) {
|
||||
t.Fatal("expected flow rate limit metadata")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTP2ResponseWriterFlowRateLimitedMeta(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
writer := &http2ResponseWriter{
|
||||
writer: recorder,
|
||||
flusher: recorder,
|
||||
}
|
||||
|
||||
err := writer.WriteResponse(context.DeadlineExceeded, flowConnectRateLimitedMetadata())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if recorder.Code != http.StatusBadGateway {
|
||||
t.Fatalf("expected %d, got %d", http.StatusBadGateway, recorder.Code)
|
||||
}
|
||||
if meta := recorder.Header().Get(h2HeaderResponseMeta); meta != h2ResponseMetaCloudflaredLimited {
|
||||
t.Fatalf("unexpected response meta: %q", meta)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatagramV2RegisterSessionRespectsMaxActiveFlows(t *testing.T) {
|
||||
inboundInstance := newLimitedInbound(t, 1)
|
||||
if !inboundInstance.flowLimiter.Acquire(1) {
|
||||
t.Fatal("failed to pre-acquire limiter")
|
||||
}
|
||||
muxer := NewDatagramV2Muxer(inboundInstance, &captureDatagramSender{}, inboundInstance.logger)
|
||||
err := muxer.RegisterSession(context.Background(), uuidTest(1), net.IPv4(1, 1, 1, 1), 53, 0)
|
||||
if err == nil {
|
||||
t.Fatal("expected too many active flows error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatagramV3RegistrationTooManyActiveFlows(t *testing.T) {
|
||||
inboundInstance := newLimitedInbound(t, 1)
|
||||
if !inboundInstance.flowLimiter.Acquire(1) {
|
||||
t.Fatal("failed to pre-acquire limiter")
|
||||
}
|
||||
sender := &captureDatagramSender{}
|
||||
muxer := NewDatagramV3Muxer(inboundInstance, sender, inboundInstance.logger)
|
||||
|
||||
requestID := RequestID{}
|
||||
requestID[15] = 1
|
||||
payload := make([]byte, 1+1+2+2+16+4)
|
||||
payload[0] = 0
|
||||
binary.BigEndian.PutUint16(payload[1:3], 53)
|
||||
binary.BigEndian.PutUint16(payload[3:5], 30)
|
||||
copy(payload[5:21], requestID[:])
|
||||
copy(payload[21:25], []byte{1, 1, 1, 1})
|
||||
|
||||
muxer.handleRegistration(context.Background(), payload)
|
||||
if len(sender.sent) != 1 {
|
||||
t.Fatalf("expected one registration response, got %d", len(sender.sent))
|
||||
}
|
||||
if sender.sent[0][0] != byte(DatagramV3TypeRegistrationResponse) || sender.sent[0][1] != v3ResponseTooManyActiveFlows {
|
||||
t.Fatalf("unexpected v3 response: %v", sender.sent[0])
|
||||
}
|
||||
}
|
||||
|
||||
func uuidTest(last byte) uuid.UUID {
|
||||
var value uuid.UUID
|
||||
value[15] = last
|
||||
return value
|
||||
}
|
||||
@@ -1,55 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
h2HeaderUpgrade = "Cf-Cloudflared-Proxy-Connection-Upgrade"
|
||||
h2HeaderTCPSrc = "Cf-Cloudflared-Proxy-Src"
|
||||
h2HeaderResponseMeta = "Cf-Cloudflared-Response-Meta"
|
||||
h2HeaderResponseUser = "Cf-Cloudflared-Response-Headers"
|
||||
h2UpgradeControlStream = "control-stream"
|
||||
h2UpgradeWebsocket = "websocket"
|
||||
h2UpgradeConfiguration = "update-configuration"
|
||||
h2ResponseMetaOrigin = `{"src":"origin"}`
|
||||
)
|
||||
|
||||
var headerEncoding = base64.RawStdEncoding
|
||||
|
||||
// SerializeHeaders encodes HTTP/1 headers into base64 pairs: base64(name):base64(value);...
|
||||
func SerializeHeaders(header http.Header) string {
|
||||
var builder strings.Builder
|
||||
for name, values := range header {
|
||||
for _, value := range values {
|
||||
if builder.Len() > 0 {
|
||||
builder.WriteByte(';')
|
||||
}
|
||||
builder.WriteString(headerEncoding.EncodeToString([]byte(name)))
|
||||
builder.WriteByte(':')
|
||||
builder.WriteString(headerEncoding.EncodeToString([]byte(value)))
|
||||
}
|
||||
}
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// isControlResponseHeader returns true for headers that are internal control headers.
|
||||
func isControlResponseHeader(name string) bool {
|
||||
lower := strings.ToLower(name)
|
||||
return strings.HasPrefix(lower, ":") ||
|
||||
strings.HasPrefix(lower, "cf-int-") ||
|
||||
strings.HasPrefix(lower, "cf-cloudflared-") ||
|
||||
strings.HasPrefix(lower, "cf-proxy-")
|
||||
}
|
||||
|
||||
// isWebsocketClientHeader returns true for headers needed by the client for WebSocket upgrade.
|
||||
func isWebsocketClientHeader(name string) bool {
|
||||
lower := strings.ToLower(name)
|
||||
return lower == "sec-websocket-accept" ||
|
||||
lower == "connection" ||
|
||||
lower == "upgrade"
|
||||
}
|
||||
@@ -1,247 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/inbound"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func requireEnvVars(t *testing.T) (token string, testURL string) {
|
||||
t.Helper()
|
||||
token = os.Getenv("CF_TUNNEL_TOKEN")
|
||||
testURL = os.Getenv("CF_TEST_URL")
|
||||
if token == "" || testURL == "" {
|
||||
t.Skip("CF_TUNNEL_TOKEN and CF_TEST_URL must be set")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func startOriginServer(t *testing.T) {
|
||||
t.Helper()
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"ok":true}`))
|
||||
})
|
||||
mux.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
io.Copy(w, r.Body)
|
||||
})
|
||||
mux.HandleFunc("/status/", func(w http.ResponseWriter, r *http.Request) {
|
||||
codeStr := strings.TrimPrefix(r.URL.Path, "/status/")
|
||||
code, err := strconv.Atoi(codeStr)
|
||||
if err != nil {
|
||||
code = 200
|
||||
}
|
||||
w.Header().Set("X-Custom", "test-value")
|
||||
w.WriteHeader(code)
|
||||
fmt.Fprintf(w, "status: %d", code)
|
||||
})
|
||||
mux.HandleFunc("/headers", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(r.Header)
|
||||
})
|
||||
|
||||
server := &http.Server{
|
||||
Addr: "127.0.0.1:8083",
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", server.Addr)
|
||||
if err != nil {
|
||||
t.Fatal("start origin server: ", err)
|
||||
}
|
||||
|
||||
go server.Serve(listener)
|
||||
t.Cleanup(func() {
|
||||
server.Close()
|
||||
})
|
||||
}
|
||||
|
||||
type testRouter struct {
|
||||
preMatch func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error)
|
||||
}
|
||||
|
||||
func (r *testRouter) Start(stage adapter.StartStage) error { return nil }
|
||||
|
||||
func (r *testRouter) Close() error { return nil }
|
||||
|
||||
func (r *testRouter) RouteConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
|
||||
destination := metadata.Destination.String()
|
||||
upstream, err := net.Dial("tcp", destination)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
io.Copy(upstream, conn)
|
||||
upstream.Close()
|
||||
}()
|
||||
io.Copy(conn, upstream)
|
||||
conn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *testRouter) RoutePacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *testRouter) RouteConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||||
destination := metadata.Destination.String()
|
||||
upstream, err := net.Dial("tcp", destination)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
onClose(err)
|
||||
return
|
||||
}
|
||||
var once sync.Once
|
||||
closeFn := func() {
|
||||
once.Do(func() {
|
||||
conn.Close()
|
||||
upstream.Close()
|
||||
})
|
||||
}
|
||||
go func() {
|
||||
io.Copy(upstream, conn)
|
||||
closeFn()
|
||||
}()
|
||||
io.Copy(conn, upstream)
|
||||
closeFn()
|
||||
onClose(nil)
|
||||
}
|
||||
|
||||
func (r *testRouter) RoutePacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||||
onClose(nil)
|
||||
}
|
||||
|
||||
func (r *testRouter) DialRoutePacketConnection(ctx context.Context, metadata adapter.InboundContext) (N.PacketConn, error) {
|
||||
conn, err := net.Dial("udp", metadata.Destination.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bufio.NewUnbindPacketConn(conn), nil
|
||||
}
|
||||
|
||||
func (r *testRouter) PreMatch(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
|
||||
if r.preMatch != nil {
|
||||
return r.preMatch(metadata, routeContext, timeout, supportBypass)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *testRouter) RuleSet(tag string) (adapter.RuleSet, bool) { return nil, false }
|
||||
|
||||
func (r *testRouter) Rules() []adapter.Rule { return nil }
|
||||
|
||||
func (r *testRouter) NeedFindProcess() bool { return false }
|
||||
|
||||
func (r *testRouter) NeedFindNeighbor() bool { return false }
|
||||
|
||||
func (r *testRouter) NeighborResolver() adapter.NeighborResolver { return nil }
|
||||
|
||||
func (r *testRouter) AppendTracker(tracker adapter.ConnectionTracker) {}
|
||||
|
||||
func (r *testRouter) ResetNetwork() {}
|
||||
|
||||
func newTestInbound(t *testing.T, token string, protocol string, haConnections int) *Inbound {
|
||||
t.Helper()
|
||||
credentials, err := parseToken(token)
|
||||
if err != nil {
|
||||
t.Fatal("parse token: ", err)
|
||||
}
|
||||
|
||||
logFactory, err := log.New(log.Options{Options: option.LogOptions{Level: "debug"}})
|
||||
if err != nil {
|
||||
t.Fatal("create logger: ", err)
|
||||
}
|
||||
|
||||
configManager, err := NewConfigManager()
|
||||
if err != nil {
|
||||
t.Fatal("create config manager: ", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
inboundInstance := &Inbound{
|
||||
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
router: &testRouter{},
|
||||
logger: logFactory.NewLogger("test"),
|
||||
credentials: credentials,
|
||||
connectorID: uuid.New(),
|
||||
haConnections: haConnections,
|
||||
protocol: protocol,
|
||||
edgeIPVersion: 0,
|
||||
datagramVersion: "",
|
||||
featureSelector: newFeatureSelector(ctx, credentials.AccountTag, ""),
|
||||
gracePeriod: 5 * time.Second,
|
||||
configManager: configManager,
|
||||
datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer),
|
||||
datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer),
|
||||
datagramV3Manager: NewDatagramV3SessionManager(),
|
||||
connectedIndices: make(map[uint8]struct{}),
|
||||
connectedNotify: make(chan uint8, haConnections),
|
||||
controlDialer: N.SystemDialer,
|
||||
tunnelDialer: N.SystemDialer,
|
||||
accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: N.SystemDialer},
|
||||
connectionStates: make([]connectionState, haConnections),
|
||||
successfulProtocols: make(map[string]struct{}),
|
||||
directTransports: make(map[string]*http.Transport),
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
cancel()
|
||||
inboundInstance.Close()
|
||||
})
|
||||
return inboundInstance
|
||||
}
|
||||
|
||||
func waitForTunnel(t *testing.T, testURL string, timeout time.Duration) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(timeout)
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
var lastErr error
|
||||
var lastStatus int
|
||||
var lastBody string
|
||||
for time.Now().Before(deadline) {
|
||||
resp, err := client.Get(testURL + "/ping")
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
lastStatus = resp.StatusCode
|
||||
lastBody = string(body)
|
||||
if resp.StatusCode == http.StatusOK && lastBody == `{"ok":true}` {
|
||||
return
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("tunnel not ready after %s (lastErr=%v, lastStatus=%d, lastBody=%q)", timeout, lastErr, lastStatus, lastBody)
|
||||
}
|
||||
@@ -1,577 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
const (
|
||||
icmpFlowTimeout = 30 * time.Second
|
||||
icmpTraceIdentityLength = 16 + 8 + 1
|
||||
defaultICMPPacketTTL = 255
|
||||
icmpErrorHeaderLen = 8
|
||||
ipv4TTLExceededQuoteLen = 548
|
||||
ipv6TTLExceededQuoteLen = 1232
|
||||
maxICMPPayloadLen = 1280
|
||||
|
||||
icmpv4TypeEchoRequest = 8
|
||||
icmpv4TypeEchoReply = 0
|
||||
icmpv4TypeTimeExceeded = 11
|
||||
icmpv6TypeEchoRequest = 128
|
||||
icmpv6TypeEchoReply = 129
|
||||
icmpv6TypeTimeExceeded = 3
|
||||
)
|
||||
|
||||
type ICMPTraceContext struct {
|
||||
Traced bool
|
||||
Identity []byte
|
||||
}
|
||||
|
||||
type ICMPFlowKey struct {
|
||||
IPVersion uint8
|
||||
SourceIP netip.Addr
|
||||
Destination netip.Addr
|
||||
}
|
||||
|
||||
type ICMPRequestKey struct {
|
||||
Flow ICMPFlowKey
|
||||
Identifier uint16
|
||||
Sequence uint16
|
||||
}
|
||||
|
||||
type ICMPPacketInfo struct {
|
||||
IPVersion uint8
|
||||
Protocol uint8
|
||||
SourceIP netip.Addr
|
||||
Destination netip.Addr
|
||||
ICMPType uint8
|
||||
ICMPCode uint8
|
||||
Identifier uint16
|
||||
Sequence uint16
|
||||
IPv4HeaderLen int
|
||||
IPv4TTL uint8
|
||||
IPv6HopLimit uint8
|
||||
RawPacket []byte
|
||||
}
|
||||
|
||||
func (i ICMPPacketInfo) FlowKey() ICMPFlowKey {
|
||||
return ICMPFlowKey{
|
||||
IPVersion: i.IPVersion,
|
||||
SourceIP: i.SourceIP,
|
||||
Destination: i.Destination,
|
||||
}
|
||||
}
|
||||
|
||||
func (i ICMPPacketInfo) RequestKey() ICMPRequestKey {
|
||||
return ICMPRequestKey{
|
||||
Flow: i.FlowKey(),
|
||||
Identifier: i.Identifier,
|
||||
Sequence: i.Sequence,
|
||||
}
|
||||
}
|
||||
|
||||
func (i ICMPPacketInfo) ReplyRequestKey() ICMPRequestKey {
|
||||
return ICMPRequestKey{
|
||||
Flow: ICMPFlowKey{
|
||||
IPVersion: i.IPVersion,
|
||||
SourceIP: i.Destination,
|
||||
Destination: i.SourceIP,
|
||||
},
|
||||
Identifier: i.Identifier,
|
||||
Sequence: i.Sequence,
|
||||
}
|
||||
}
|
||||
|
||||
func (i ICMPPacketInfo) IsEchoRequest() bool {
|
||||
switch i.IPVersion {
|
||||
case 4:
|
||||
return i.ICMPType == icmpv4TypeEchoRequest && i.ICMPCode == 0
|
||||
case 6:
|
||||
return i.ICMPType == icmpv6TypeEchoRequest && i.ICMPCode == 0
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (i ICMPPacketInfo) IsEchoReply() bool {
|
||||
switch i.IPVersion {
|
||||
case 4:
|
||||
return i.ICMPType == icmpv4TypeEchoReply && i.ICMPCode == 0
|
||||
case 6:
|
||||
return i.ICMPType == icmpv6TypeEchoReply && i.ICMPCode == 0
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (i ICMPPacketInfo) TTL() uint8 {
|
||||
if i.IPVersion == 4 {
|
||||
return i.IPv4TTL
|
||||
}
|
||||
return i.IPv6HopLimit
|
||||
}
|
||||
|
||||
func (i ICMPPacketInfo) TTLExpired() bool {
|
||||
return i.TTL() <= 1
|
||||
}
|
||||
|
||||
func (i *ICMPPacketInfo) DecrementTTL() error {
|
||||
switch i.IPVersion {
|
||||
case 4:
|
||||
if i.IPv4TTL == 0 || i.IPv4HeaderLen < 20 || len(i.RawPacket) < i.IPv4HeaderLen {
|
||||
return E.New("invalid IPv4 packet TTL state")
|
||||
}
|
||||
i.IPv4TTL--
|
||||
i.RawPacket[8] = i.IPv4TTL
|
||||
binary.BigEndian.PutUint16(i.RawPacket[10:12], 0)
|
||||
binary.BigEndian.PutUint16(i.RawPacket[10:12], checksum(i.RawPacket[:i.IPv4HeaderLen], 0))
|
||||
case 6:
|
||||
if i.IPv6HopLimit == 0 || len(i.RawPacket) < 40 {
|
||||
return E.New("invalid IPv6 packet hop limit state")
|
||||
}
|
||||
i.IPv6HopLimit--
|
||||
i.RawPacket[7] = i.IPv6HopLimit
|
||||
default:
|
||||
return E.New("unsupported IP version: ", i.IPVersion)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type icmpWireVersion uint8
|
||||
|
||||
const (
|
||||
icmpWireV2 icmpWireVersion = iota + 1
|
||||
icmpWireV3
|
||||
)
|
||||
|
||||
type icmpFlowState struct {
|
||||
writer *ICMPReplyWriter
|
||||
lastActive time.Time
|
||||
}
|
||||
|
||||
type traceEntry struct {
|
||||
context ICMPTraceContext
|
||||
createdAt time.Time
|
||||
}
|
||||
|
||||
type ICMPReplyWriter struct {
|
||||
sender DatagramSender
|
||||
wireVersion icmpWireVersion
|
||||
|
||||
access sync.Mutex
|
||||
traces map[ICMPRequestKey]traceEntry
|
||||
}
|
||||
|
||||
func NewICMPReplyWriter(sender DatagramSender, wireVersion icmpWireVersion) *ICMPReplyWriter {
|
||||
return &ICMPReplyWriter{
|
||||
sender: sender,
|
||||
wireVersion: wireVersion,
|
||||
traces: make(map[ICMPRequestKey]traceEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (w *ICMPReplyWriter) RegisterRequestTrace(packetInfo ICMPPacketInfo, traceContext ICMPTraceContext) {
|
||||
if !traceContext.Traced {
|
||||
return
|
||||
}
|
||||
w.access.Lock()
|
||||
w.traces[packetInfo.RequestKey()] = traceEntry{
|
||||
context: traceContext,
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
w.access.Unlock()
|
||||
}
|
||||
|
||||
func (w *ICMPReplyWriter) WritePacket(packet []byte) error {
|
||||
packetInfo, err := ParseICMPPacket(packet)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !packetInfo.IsEchoReply() {
|
||||
return nil
|
||||
}
|
||||
|
||||
requestKey := packetInfo.ReplyRequestKey()
|
||||
w.access.Lock()
|
||||
entry, loaded := w.traces[requestKey]
|
||||
if loaded {
|
||||
delete(w.traces, requestKey)
|
||||
}
|
||||
w.access.Unlock()
|
||||
traceContext := entry.context
|
||||
|
||||
datagram, err := encodeICMPDatagram(packetInfo.RawPacket, w.wireVersion, traceContext)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return w.sender.SendDatagram(datagram)
|
||||
}
|
||||
|
||||
func (w *ICMPReplyWriter) cleanupExpired(now time.Time) {
|
||||
w.access.Lock()
|
||||
defer w.access.Unlock()
|
||||
for key, entry := range w.traces {
|
||||
if now.After(entry.createdAt.Add(icmpFlowTimeout)) {
|
||||
delete(w.traces, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type ICMPBridge struct {
|
||||
inbound *Inbound
|
||||
sender DatagramSender
|
||||
wireVersion icmpWireVersion
|
||||
routeMapping *tun.DirectRouteMapping
|
||||
|
||||
flowAccess sync.Mutex
|
||||
flows map[ICMPFlowKey]*icmpFlowState
|
||||
}
|
||||
|
||||
func NewICMPBridge(inbound *Inbound, sender DatagramSender, wireVersion icmpWireVersion) *ICMPBridge {
|
||||
bridge := &ICMPBridge{
|
||||
inbound: inbound,
|
||||
sender: sender,
|
||||
wireVersion: wireVersion,
|
||||
routeMapping: tun.NewDirectRouteMapping(icmpFlowTimeout),
|
||||
flows: make(map[ICMPFlowKey]*icmpFlowState),
|
||||
}
|
||||
if inbound != nil && inbound.ctx != nil {
|
||||
go bridge.cleanupLoop(inbound.ctx)
|
||||
}
|
||||
return bridge
|
||||
}
|
||||
|
||||
func (b *ICMPBridge) HandleV2(ctx context.Context, datagramType DatagramV2Type, payload []byte) error {
|
||||
traceContext := ICMPTraceContext{}
|
||||
switch datagramType {
|
||||
case DatagramV2TypeIP:
|
||||
case DatagramV2TypeIPWithTrace:
|
||||
if len(payload) < icmpTraceIdentityLength {
|
||||
return E.New("icmp trace payload is too short")
|
||||
}
|
||||
traceContext.Traced = true
|
||||
traceContext.Identity = append([]byte(nil), payload[len(payload)-icmpTraceIdentityLength:]...)
|
||||
payload = payload[:len(payload)-icmpTraceIdentityLength]
|
||||
default:
|
||||
return E.New("unsupported v2 icmp datagram type: ", datagramType)
|
||||
}
|
||||
return b.handlePacket(ctx, payload, traceContext)
|
||||
}
|
||||
|
||||
func (b *ICMPBridge) HandleV3(ctx context.Context, payload []byte) error {
|
||||
return b.handlePacket(ctx, payload, ICMPTraceContext{})
|
||||
}
|
||||
|
||||
func (b *ICMPBridge) handlePacket(ctx context.Context, payload []byte, traceContext ICMPTraceContext) error {
|
||||
packetInfo, err := ParseICMPPacket(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !packetInfo.IsEchoRequest() {
|
||||
return nil
|
||||
}
|
||||
if packetInfo.TTLExpired() {
|
||||
ttlExceededPacket, err := buildICMPTTLExceededPacket(packetInfo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
datagram, err := encodeICMPDatagram(ttlExceededPacket, b.wireVersion, traceContext)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return b.sender.SendDatagram(datagram)
|
||||
}
|
||||
|
||||
if err := packetInfo.DecrementTTL(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
state := b.getFlowState(packetInfo.FlowKey())
|
||||
state.lastActive = time.Now()
|
||||
if traceContext.Traced {
|
||||
state.writer.RegisterRequestTrace(packetInfo, traceContext)
|
||||
}
|
||||
|
||||
action, err := b.routeMapping.Lookup(tun.DirectRouteSession{
|
||||
Source: packetInfo.SourceIP,
|
||||
Destination: packetInfo.Destination,
|
||||
}, func(timeout time.Duration) (tun.DirectRouteDestination, error) {
|
||||
metadata := adapter.InboundContext{
|
||||
Inbound: b.inbound.Tag(),
|
||||
InboundType: b.inbound.Type(),
|
||||
IPVersion: packetInfo.IPVersion,
|
||||
Network: N.NetworkICMP,
|
||||
Source: M.SocksaddrFrom(packetInfo.SourceIP, 0),
|
||||
Destination: M.SocksaddrFrom(packetInfo.Destination, 0),
|
||||
OriginDestination: M.SocksaddrFrom(packetInfo.Destination, 0),
|
||||
}
|
||||
return b.inbound.router.PreMatch(metadata, state.writer, timeout, false)
|
||||
})
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return action.WritePacket(buf.As(packetInfo.RawPacket).ToOwned())
|
||||
}
|
||||
|
||||
func (b *ICMPBridge) getFlowState(key ICMPFlowKey) *icmpFlowState {
|
||||
b.flowAccess.Lock()
|
||||
defer b.flowAccess.Unlock()
|
||||
state, loaded := b.flows[key]
|
||||
if loaded {
|
||||
return state
|
||||
}
|
||||
state = &icmpFlowState{
|
||||
writer: NewICMPReplyWriter(b.sender, b.wireVersion),
|
||||
}
|
||||
b.flows[key] = state
|
||||
return state
|
||||
}
|
||||
|
||||
func (b *ICMPBridge) cleanupLoop(ctx context.Context) {
|
||||
ticker := time.NewTicker(icmpFlowTimeout)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case now := <-ticker.C:
|
||||
b.cleanupExpired(now)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *ICMPBridge) cleanupExpired(now time.Time) {
|
||||
b.flowAccess.Lock()
|
||||
defer b.flowAccess.Unlock()
|
||||
for key, state := range b.flows {
|
||||
state.writer.cleanupExpired(now)
|
||||
if now.After(state.lastActive.Add(icmpFlowTimeout)) {
|
||||
delete(b.flows, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ParseICMPPacket(packet []byte) (ICMPPacketInfo, error) {
|
||||
if len(packet) < 1 {
|
||||
return ICMPPacketInfo{}, E.New("empty IP packet")
|
||||
}
|
||||
version := packet[0] >> 4
|
||||
switch version {
|
||||
case 4:
|
||||
return parseIPv4ICMPPacket(packet)
|
||||
case 6:
|
||||
return parseIPv6ICMPPacket(packet)
|
||||
default:
|
||||
return ICMPPacketInfo{}, E.New("unsupported IP version: ", version)
|
||||
}
|
||||
}
|
||||
|
||||
func parseIPv4ICMPPacket(packet []byte) (ICMPPacketInfo, error) {
|
||||
if len(packet) < 20 {
|
||||
return ICMPPacketInfo{}, E.New("IPv4 packet too short")
|
||||
}
|
||||
headerLen := int(packet[0]&0x0F) * 4
|
||||
if headerLen < 20 || len(packet) < headerLen+8 {
|
||||
return ICMPPacketInfo{}, E.New("invalid IPv4 header length")
|
||||
}
|
||||
if packet[9] != 1 {
|
||||
return ICMPPacketInfo{}, E.New("IPv4 packet is not ICMP")
|
||||
}
|
||||
sourceIP, ok := netip.AddrFromSlice(packet[12:16])
|
||||
if !ok {
|
||||
return ICMPPacketInfo{}, E.New("invalid IPv4 source address")
|
||||
}
|
||||
destinationIP, ok := netip.AddrFromSlice(packet[16:20])
|
||||
if !ok {
|
||||
return ICMPPacketInfo{}, E.New("invalid IPv4 destination address")
|
||||
}
|
||||
return ICMPPacketInfo{
|
||||
IPVersion: 4,
|
||||
Protocol: 1,
|
||||
SourceIP: sourceIP,
|
||||
Destination: destinationIP,
|
||||
ICMPType: packet[headerLen],
|
||||
ICMPCode: packet[headerLen+1],
|
||||
Identifier: binary.BigEndian.Uint16(packet[headerLen+4 : headerLen+6]),
|
||||
Sequence: binary.BigEndian.Uint16(packet[headerLen+6 : headerLen+8]),
|
||||
IPv4HeaderLen: headerLen,
|
||||
IPv4TTL: packet[8],
|
||||
RawPacket: append([]byte(nil), packet...),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseIPv6ICMPPacket(packet []byte) (ICMPPacketInfo, error) {
|
||||
if len(packet) < 48 {
|
||||
return ICMPPacketInfo{}, E.New("IPv6 packet too short")
|
||||
}
|
||||
if packet[6] != 58 {
|
||||
return ICMPPacketInfo{}, E.New("IPv6 packet is not ICMP")
|
||||
}
|
||||
sourceIP, ok := netip.AddrFromSlice(packet[8:24])
|
||||
if !ok {
|
||||
return ICMPPacketInfo{}, E.New("invalid IPv6 source address")
|
||||
}
|
||||
destinationIP, ok := netip.AddrFromSlice(packet[24:40])
|
||||
if !ok {
|
||||
return ICMPPacketInfo{}, E.New("invalid IPv6 destination address")
|
||||
}
|
||||
return ICMPPacketInfo{
|
||||
IPVersion: 6,
|
||||
Protocol: 58,
|
||||
SourceIP: sourceIP,
|
||||
Destination: destinationIP,
|
||||
ICMPType: packet[40],
|
||||
ICMPCode: packet[41],
|
||||
Identifier: binary.BigEndian.Uint16(packet[44:46]),
|
||||
Sequence: binary.BigEndian.Uint16(packet[46:48]),
|
||||
IPv6HopLimit: packet[7],
|
||||
RawPacket: append([]byte(nil), packet...),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func maxEncodedICMPPacketLen(wireVersion icmpWireVersion, traceContext ICMPTraceContext) int {
|
||||
limit := maxV3UDPPayloadLen
|
||||
switch wireVersion {
|
||||
case icmpWireV2:
|
||||
limit -= typeIDLength
|
||||
if traceContext.Traced {
|
||||
limit -= len(traceContext.Identity)
|
||||
}
|
||||
case icmpWireV3:
|
||||
limit -= 1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
if limit < 0 {
|
||||
return 0
|
||||
}
|
||||
return limit
|
||||
}
|
||||
|
||||
func buildICMPTTLExceededPacket(packetInfo ICMPPacketInfo) ([]byte, error) {
|
||||
switch packetInfo.IPVersion {
|
||||
case 4:
|
||||
return buildIPv4ICMPTTLExceededPacket(packetInfo)
|
||||
case 6:
|
||||
return buildIPv6ICMPTTLExceededPacket(packetInfo)
|
||||
default:
|
||||
return nil, E.New("unsupported IP version: ", packetInfo.IPVersion)
|
||||
}
|
||||
}
|
||||
|
||||
func buildIPv4ICMPTTLExceededPacket(packetInfo ICMPPacketInfo) ([]byte, error) {
|
||||
const headerLen = 20
|
||||
if !packetInfo.SourceIP.Is4() || !packetInfo.Destination.Is4() {
|
||||
return nil, E.New("TTL exceeded packet requires IPv4 addresses")
|
||||
}
|
||||
|
||||
quotedLength := min(len(packetInfo.RawPacket), ipv4TTLExceededQuoteLen)
|
||||
packet := make([]byte, headerLen+icmpErrorHeaderLen+quotedLength)
|
||||
packet[0] = 0x45
|
||||
binary.BigEndian.PutUint16(packet[2:4], uint16(len(packet)))
|
||||
packet[8] = defaultICMPPacketTTL
|
||||
packet[9] = 1
|
||||
copy(packet[12:16], packetInfo.Destination.AsSlice())
|
||||
copy(packet[16:20], packetInfo.SourceIP.AsSlice())
|
||||
packet[20] = icmpv4TypeTimeExceeded
|
||||
packet[21] = 0
|
||||
copy(packet[headerLen+icmpErrorHeaderLen:], packetInfo.RawPacket[:quotedLength])
|
||||
binary.BigEndian.PutUint16(packet[22:24], checksum(packet[20:], 0))
|
||||
binary.BigEndian.PutUint16(packet[10:12], checksum(packet[:headerLen], 0))
|
||||
return packet, nil
|
||||
}
|
||||
|
||||
func buildIPv6ICMPTTLExceededPacket(packetInfo ICMPPacketInfo) ([]byte, error) {
|
||||
const headerLen = 40
|
||||
if !packetInfo.SourceIP.Is6() || !packetInfo.Destination.Is6() {
|
||||
return nil, E.New("TTL exceeded packet requires IPv6 addresses")
|
||||
}
|
||||
|
||||
quotedLength := min(len(packetInfo.RawPacket), ipv6TTLExceededQuoteLen)
|
||||
packet := make([]byte, headerLen+icmpErrorHeaderLen+quotedLength)
|
||||
packet[0] = 0x60
|
||||
binary.BigEndian.PutUint16(packet[4:6], uint16(icmpErrorHeaderLen+quotedLength))
|
||||
packet[6] = 58
|
||||
packet[7] = defaultICMPPacketTTL
|
||||
copy(packet[8:24], packetInfo.Destination.AsSlice())
|
||||
copy(packet[24:40], packetInfo.SourceIP.AsSlice())
|
||||
packet[40] = icmpv6TypeTimeExceeded
|
||||
packet[41] = 0
|
||||
copy(packet[headerLen+icmpErrorHeaderLen:], packetInfo.RawPacket[:quotedLength])
|
||||
binary.BigEndian.PutUint16(packet[42:44], checksum(packet[40:], ipv6PseudoHeaderChecksum(packetInfo.Destination, packetInfo.SourceIP, uint32(icmpErrorHeaderLen+quotedLength), 58)))
|
||||
return packet, nil
|
||||
}
|
||||
|
||||
func encodeICMPDatagram(packet []byte, wireVersion icmpWireVersion, traceContext ICMPTraceContext) ([]byte, error) {
|
||||
switch wireVersion {
|
||||
case icmpWireV2:
|
||||
return encodeV2ICMPDatagram(packet, traceContext)
|
||||
case icmpWireV3:
|
||||
return encodeV3ICMPDatagram(packet)
|
||||
default:
|
||||
return nil, E.New("unsupported icmp wire version: ", wireVersion)
|
||||
}
|
||||
}
|
||||
|
||||
func ipv6PseudoHeaderChecksum(source, destination netip.Addr, payloadLength uint32, nextHeader uint8) uint32 {
|
||||
var sum uint32
|
||||
sum = checksumSum(source.AsSlice(), sum)
|
||||
sum = checksumSum(destination.AsSlice(), sum)
|
||||
var lengthBytes [4]byte
|
||||
binary.BigEndian.PutUint32(lengthBytes[:], payloadLength)
|
||||
sum = checksumSum(lengthBytes[:], sum)
|
||||
sum = checksumSum([]byte{0, 0, 0, nextHeader}, sum)
|
||||
return sum
|
||||
}
|
||||
|
||||
func checksumSum(data []byte, sum uint32) uint32 {
|
||||
for len(data) >= 2 {
|
||||
sum += uint32(binary.BigEndian.Uint16(data[:2]))
|
||||
data = data[2:]
|
||||
}
|
||||
if len(data) == 1 {
|
||||
sum += uint32(data[0]) << 8
|
||||
}
|
||||
return sum
|
||||
}
|
||||
|
||||
func checksum(data []byte, initial uint32) uint16 {
|
||||
sum := checksumSum(data, initial)
|
||||
for sum > 0xffff {
|
||||
sum = (sum >> 16) + (sum & 0xffff)
|
||||
}
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
func encodeV2ICMPDatagram(packet []byte, _ ICMPTraceContext) ([]byte, error) {
|
||||
data := make([]byte, 0, len(packet)+1)
|
||||
data = append(data, packet...)
|
||||
data = append(data, byte(DatagramV2TypeIP))
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func encodeV3ICMPDatagram(packet []byte) ([]byte, error) {
|
||||
if len(packet) == 0 {
|
||||
return nil, E.New("icmp payload is missing")
|
||||
}
|
||||
if len(packet) > maxICMPPayloadLen {
|
||||
return nil, E.New("icmp payload is too large")
|
||||
}
|
||||
data := make([]byte, 0, len(packet)+1)
|
||||
data = append(data, byte(DatagramV3TypeICMP))
|
||||
data = append(data, packet...)
|
||||
return data, nil
|
||||
}
|
||||
@@ -1,478 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/inbound"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type captureDatagramSender struct {
|
||||
sent [][]byte
|
||||
}
|
||||
|
||||
func (s *captureDatagramSender) SendDatagram(data []byte) error {
|
||||
s.sent = append(s.sent, append([]byte(nil), data...))
|
||||
return nil
|
||||
}
|
||||
|
||||
type fakeDirectRouteDestination struct {
|
||||
routeContext tun.DirectRouteContext
|
||||
packets [][]byte
|
||||
reply func(packet []byte) []byte
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (d *fakeDirectRouteDestination) WritePacket(packet *buf.Buffer) error {
|
||||
data := append([]byte(nil), packet.Bytes()...)
|
||||
packet.Release()
|
||||
d.packets = append(d.packets, data)
|
||||
if d.reply != nil {
|
||||
reply := d.reply(data)
|
||||
if reply != nil {
|
||||
return d.routeContext.WritePacket(reply)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *fakeDirectRouteDestination) Close() error {
|
||||
d.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *fakeDirectRouteDestination) IsClosed() bool {
|
||||
return d.closed
|
||||
}
|
||||
|
||||
func TestICMPBridgeHandleV2RoutesEchoRequest(t *testing.T) {
|
||||
var (
|
||||
preMatchCalls int
|
||||
captured adapter.InboundContext
|
||||
destination *fakeDirectRouteDestination
|
||||
)
|
||||
router := &testRouter{
|
||||
preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
|
||||
preMatchCalls++
|
||||
captured = metadata
|
||||
destination = &fakeDirectRouteDestination{routeContext: routeContext}
|
||||
return destination, nil
|
||||
},
|
||||
}
|
||||
inboundInstance := &Inbound{
|
||||
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
|
||||
router: router,
|
||||
}
|
||||
sender := &captureDatagramSender{}
|
||||
bridge := NewICMPBridge(inboundInstance, sender, icmpWireV2)
|
||||
|
||||
source := netip.MustParseAddr("198.18.0.2")
|
||||
target := netip.MustParseAddr("1.1.1.1")
|
||||
packet1 := buildIPv4ICMPPacket(source, target, 8, 0, 1, 1)
|
||||
packet2 := buildIPv4ICMPPacket(source, target, 8, 0, 1, 2)
|
||||
|
||||
if err := bridge.HandleV2(context.Background(), DatagramV2TypeIP, packet1); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := bridge.HandleV2(context.Background(), DatagramV2TypeIP, packet2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if preMatchCalls != 1 {
|
||||
t.Fatalf("expected one direct-route lookup, got %d", preMatchCalls)
|
||||
}
|
||||
if captured.Network != N.NetworkICMP {
|
||||
t.Fatalf("expected NetworkICMP, got %s", captured.Network)
|
||||
}
|
||||
if captured.Source.Addr != source || captured.Destination.Addr != target {
|
||||
t.Fatalf("unexpected metadata source/destination: %#v", captured)
|
||||
}
|
||||
if len(destination.packets) != 2 {
|
||||
t.Fatalf("expected two packets written, got %d", len(destination.packets))
|
||||
}
|
||||
if len(sender.sent) != 0 {
|
||||
t.Fatalf("expected no reply datagrams, got %d", len(sender.sent))
|
||||
}
|
||||
}
|
||||
|
||||
func TestICMPBridgeHandleV2TracedReply(t *testing.T) {
|
||||
traceIdentity := bytes.Repeat([]byte{0x7a}, icmpTraceIdentityLength)
|
||||
sender := &captureDatagramSender{}
|
||||
router := &testRouter{
|
||||
preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
|
||||
return &fakeDirectRouteDestination{
|
||||
routeContext: routeContext,
|
||||
reply: buildEchoReply,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
inboundInstance := &Inbound{
|
||||
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
|
||||
router: router,
|
||||
}
|
||||
bridge := NewICMPBridge(inboundInstance, sender, icmpWireV2)
|
||||
|
||||
request := buildIPv4ICMPPacket(netip.MustParseAddr("198.18.0.2"), netip.MustParseAddr("1.1.1.1"), 8, 0, 9, 7)
|
||||
request = append(request, traceIdentity...)
|
||||
if err := bridge.HandleV2(context.Background(), DatagramV2TypeIPWithTrace, request); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(sender.sent) != 1 {
|
||||
t.Fatalf("expected one reply datagram, got %d", len(sender.sent))
|
||||
}
|
||||
reply := sender.sent[0]
|
||||
if reply[len(reply)-1] != byte(DatagramV2TypeIP) {
|
||||
t.Fatalf("expected plain v2 IP reply, got type %d", reply[len(reply)-1])
|
||||
}
|
||||
if len(reply) != len(buildEchoReply(buildIPv4ICMPPacket(netip.MustParseAddr("198.18.0.2"), netip.MustParseAddr("1.1.1.1"), 8, 0, 9, 7)))+1 {
|
||||
t.Fatalf("unexpected traced reply size: %d", len(reply))
|
||||
}
|
||||
}
|
||||
|
||||
func TestICMPBridgeHandleV3Reply(t *testing.T) {
|
||||
sender := &captureDatagramSender{}
|
||||
router := &testRouter{
|
||||
preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
|
||||
return &fakeDirectRouteDestination{
|
||||
routeContext: routeContext,
|
||||
reply: buildEchoReply,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
inboundInstance := &Inbound{
|
||||
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
|
||||
router: router,
|
||||
}
|
||||
bridge := NewICMPBridge(inboundInstance, sender, icmpWireV3)
|
||||
|
||||
request := buildIPv6ICMPPacket(netip.MustParseAddr("2001:db8::2"), netip.MustParseAddr("2606:4700:4700::1111"), 128, 0, 3, 5)
|
||||
if err := bridge.HandleV3(context.Background(), request); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(sender.sent) != 1 {
|
||||
t.Fatalf("expected one reply datagram, got %d", len(sender.sent))
|
||||
}
|
||||
reply := sender.sent[0]
|
||||
if reply[0] != byte(DatagramV3TypeICMP) {
|
||||
t.Fatalf("expected v3 ICMP datagram, got %d", reply[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestICMPBridgeDecrementsIPv4TTLBeforeRouting(t *testing.T) {
|
||||
var destination *fakeDirectRouteDestination
|
||||
router := &testRouter{
|
||||
preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
|
||||
destination = &fakeDirectRouteDestination{routeContext: routeContext}
|
||||
return destination, nil
|
||||
},
|
||||
}
|
||||
inboundInstance := &Inbound{
|
||||
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
|
||||
router: router,
|
||||
}
|
||||
bridge := NewICMPBridge(inboundInstance, &captureDatagramSender{}, icmpWireV2)
|
||||
|
||||
packet := buildIPv4ICMPPacket(netip.MustParseAddr("198.18.0.2"), netip.MustParseAddr("1.1.1.1"), icmpv4TypeEchoRequest, 0, 1, 1)
|
||||
packet[8] = 5
|
||||
|
||||
if err := bridge.HandleV2(context.Background(), DatagramV2TypeIP, packet); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(destination.packets) != 1 {
|
||||
t.Fatalf("expected one routed packet, got %d", len(destination.packets))
|
||||
}
|
||||
if got := destination.packets[0][8]; got != 4 {
|
||||
t.Fatalf("expected decremented IPv4 TTL, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestICMPBridgeDecrementsIPv6HopLimitBeforeRouting(t *testing.T) {
|
||||
var destination *fakeDirectRouteDestination
|
||||
router := &testRouter{
|
||||
preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
|
||||
destination = &fakeDirectRouteDestination{routeContext: routeContext}
|
||||
return destination, nil
|
||||
},
|
||||
}
|
||||
inboundInstance := &Inbound{
|
||||
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
|
||||
router: router,
|
||||
}
|
||||
bridge := NewICMPBridge(inboundInstance, &captureDatagramSender{}, icmpWireV3)
|
||||
|
||||
packet := buildIPv6ICMPPacket(netip.MustParseAddr("2001:db8::2"), netip.MustParseAddr("2606:4700:4700::1111"), icmpv6TypeEchoRequest, 0, 1, 1)
|
||||
packet[7] = 3
|
||||
|
||||
if err := bridge.HandleV3(context.Background(), packet); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(destination.packets) != 1 {
|
||||
t.Fatalf("expected one routed packet, got %d", len(destination.packets))
|
||||
}
|
||||
if got := destination.packets[0][7]; got != 2 {
|
||||
t.Fatalf("expected decremented IPv6 hop limit, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestICMPBridgeHandleV2TTLExceededTracedReply(t *testing.T) {
|
||||
var preMatchCalls int
|
||||
traceIdentity := bytes.Repeat([]byte{0x6b}, icmpTraceIdentityLength)
|
||||
sender := &captureDatagramSender{}
|
||||
router := &testRouter{
|
||||
preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
|
||||
preMatchCalls++
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
inboundInstance := &Inbound{
|
||||
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
|
||||
router: router,
|
||||
}
|
||||
bridge := NewICMPBridge(inboundInstance, sender, icmpWireV2)
|
||||
|
||||
source := netip.MustParseAddr("198.18.0.2")
|
||||
target := netip.MustParseAddr("1.1.1.1")
|
||||
packet := buildIPv4ICMPPacket(source, target, icmpv4TypeEchoRequest, 0, 1, 1)
|
||||
packet[8] = 1
|
||||
packet = append(packet, traceIdentity...)
|
||||
|
||||
if err := bridge.HandleV2(context.Background(), DatagramV2TypeIPWithTrace, packet); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if preMatchCalls != 0 {
|
||||
t.Fatalf("expected TTL exceeded to bypass routing, got %d route lookups", preMatchCalls)
|
||||
}
|
||||
if len(sender.sent) != 1 {
|
||||
t.Fatalf("expected one TTL exceeded reply, got %d", len(sender.sent))
|
||||
}
|
||||
reply := sender.sent[0]
|
||||
if reply[len(reply)-1] != byte(DatagramV2TypeIP) {
|
||||
t.Fatalf("expected plain v2 reply, got type %d", reply[len(reply)-1])
|
||||
}
|
||||
rawReply := reply[:len(reply)-1]
|
||||
packetInfo, err := ParseICMPPacket(rawReply)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if packetInfo.ICMPType != icmpv4TypeTimeExceeded || packetInfo.ICMPCode != 0 {
|
||||
t.Fatalf("expected IPv4 time exceeded reply, got type=%d code=%d", packetInfo.ICMPType, packetInfo.ICMPCode)
|
||||
}
|
||||
if packetInfo.SourceIP != target || packetInfo.Destination != source {
|
||||
t.Fatalf("unexpected TTL exceeded routing: src=%s dst=%s", packetInfo.SourceIP, packetInfo.Destination)
|
||||
}
|
||||
if packetInfo.TTL() != 255 {
|
||||
t.Fatalf("expected TTL exceeded packet TTL 255, got %d", packetInfo.TTL())
|
||||
}
|
||||
}
|
||||
|
||||
func TestICMPBridgeHandleV3TTLExceededReply(t *testing.T) {
|
||||
var preMatchCalls int
|
||||
sender := &captureDatagramSender{}
|
||||
router := &testRouter{
|
||||
preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
|
||||
preMatchCalls++
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
inboundInstance := &Inbound{
|
||||
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
|
||||
router: router,
|
||||
}
|
||||
bridge := NewICMPBridge(inboundInstance, sender, icmpWireV3)
|
||||
|
||||
source := netip.MustParseAddr("2001:db8::2")
|
||||
target := netip.MustParseAddr("2606:4700:4700::1111")
|
||||
packet := buildIPv6ICMPPacket(source, target, icmpv6TypeEchoRequest, 0, 1, 1)
|
||||
packet[7] = 1
|
||||
|
||||
if err := bridge.HandleV3(context.Background(), packet); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if preMatchCalls != 0 {
|
||||
t.Fatalf("expected TTL exceeded to bypass routing, got %d route lookups", preMatchCalls)
|
||||
}
|
||||
if len(sender.sent) != 1 {
|
||||
t.Fatalf("expected one TTL exceeded reply, got %d", len(sender.sent))
|
||||
}
|
||||
if sender.sent[0][0] != byte(DatagramV3TypeICMP) {
|
||||
t.Fatalf("expected v3 ICMP reply, got %d", sender.sent[0][0])
|
||||
}
|
||||
packetInfo, err := ParseICMPPacket(sender.sent[0][1:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if packetInfo.ICMPType != icmpv6TypeTimeExceeded || packetInfo.ICMPCode != 0 {
|
||||
t.Fatalf("expected IPv6 time exceeded reply, got type=%d code=%d", packetInfo.ICMPType, packetInfo.ICMPCode)
|
||||
}
|
||||
if packetInfo.SourceIP != target || packetInfo.Destination != source {
|
||||
t.Fatalf("unexpected TTL exceeded routing: src=%s dst=%s", packetInfo.SourceIP, packetInfo.Destination)
|
||||
}
|
||||
if packetInfo.TTL() != 255 {
|
||||
t.Fatalf("expected TTL exceeded packet TTL 255, got %d", packetInfo.TTL())
|
||||
}
|
||||
}
|
||||
|
||||
func TestICMPBridgeDropsNonEcho(t *testing.T) {
|
||||
var preMatchCalls int
|
||||
router := &testRouter{
|
||||
preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
|
||||
preMatchCalls++
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
inboundInstance := &Inbound{
|
||||
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
|
||||
router: router,
|
||||
}
|
||||
sender := &captureDatagramSender{}
|
||||
bridge := NewICMPBridge(inboundInstance, sender, icmpWireV2)
|
||||
|
||||
packet := buildIPv4ICMPPacket(netip.MustParseAddr("198.18.0.2"), netip.MustParseAddr("1.1.1.1"), 3, 0, 1, 1)
|
||||
if err := bridge.HandleV2(context.Background(), DatagramV2TypeIP, packet); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if preMatchCalls != 0 {
|
||||
t.Fatalf("expected no route lookup, got %d", preMatchCalls)
|
||||
}
|
||||
if len(sender.sent) != 0 {
|
||||
t.Fatalf("expected no sender datagrams, got %d", len(sender.sent))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildICMPTTLExceededPacketUsesRFCQuoteLengths(t *testing.T) {
|
||||
ipv4Packet := buildIPv4ICMPPacket(netip.MustParseAddr("198.18.0.2"), netip.MustParseAddr("1.1.1.1"), icmpv4TypeEchoRequest, 0, 1, 1)
|
||||
ipv4Packet = append(ipv4Packet, bytes.Repeat([]byte{0xaa}, 4096)...)
|
||||
ipv4Info, err := ParseICMPPacket(ipv4Packet)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ipv4Reply, err := buildICMPTTLExceededPacket(ipv4Info)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(ipv4Reply) != 20+icmpErrorHeaderLen+ipv4TTLExceededQuoteLen {
|
||||
t.Fatalf("unexpected IPv4 TTL exceeded size: %d", len(ipv4Reply))
|
||||
}
|
||||
|
||||
ipv6Packet := buildIPv6ICMPPacket(netip.MustParseAddr("2001:db8::2"), netip.MustParseAddr("2606:4700:4700::1111"), icmpv6TypeEchoRequest, 0, 1, 1)
|
||||
ipv6Packet = append(ipv6Packet, bytes.Repeat([]byte{0xbb}, 4096)...)
|
||||
ipv6Info, err := ParseICMPPacket(ipv6Packet)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ipv6Reply, err := buildICMPTTLExceededPacket(ipv6Info)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(ipv6Reply) != 40+icmpErrorHeaderLen+ipv6TTLExceededQuoteLen {
|
||||
t.Fatalf("unexpected IPv6 TTL exceeded size: %d", len(ipv6Reply))
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeV3ICMPDatagramRejectsEmptyPayload(t *testing.T) {
|
||||
if _, err := encodeV3ICMPDatagram(nil); err == nil {
|
||||
t.Fatal("expected empty payload to be rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeV3ICMPDatagramRejectsOversizedPayload(t *testing.T) {
|
||||
if _, err := encodeV3ICMPDatagram(make([]byte, maxICMPPayloadLen+1)); err == nil {
|
||||
t.Fatal("expected oversized payload to be rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestICMPBridgeCleanupExpired(t *testing.T) {
|
||||
bridge := NewICMPBridge(&Inbound{}, &captureDatagramSender{}, icmpWireV2)
|
||||
now := time.Now()
|
||||
|
||||
expiredKey := ICMPFlowKey{
|
||||
IPVersion: 4,
|
||||
SourceIP: netip.MustParseAddr("198.18.0.2"),
|
||||
Destination: netip.MustParseAddr("1.1.1.1"),
|
||||
}
|
||||
expiredState := bridge.getFlowState(expiredKey)
|
||||
expiredState.lastActive = now.Add(-icmpFlowTimeout - time.Second)
|
||||
expiredState.writer.traces[ICMPRequestKey{Flow: expiredKey, Identifier: 1, Sequence: 1}] = traceEntry{
|
||||
context: ICMPTraceContext{Traced: true, Identity: []byte{1}},
|
||||
createdAt: now.Add(-icmpFlowTimeout - time.Second),
|
||||
}
|
||||
|
||||
activeKey := ICMPFlowKey{
|
||||
IPVersion: 6,
|
||||
SourceIP: netip.MustParseAddr("2001:db8::2"),
|
||||
Destination: netip.MustParseAddr("2606:4700:4700::1111"),
|
||||
}
|
||||
activeState := bridge.getFlowState(activeKey)
|
||||
activeState.lastActive = now
|
||||
activeState.writer.traces[ICMPRequestKey{Flow: activeKey, Identifier: 2, Sequence: 2}] = traceEntry{
|
||||
context: ICMPTraceContext{Traced: true, Identity: []byte{2}},
|
||||
createdAt: now,
|
||||
}
|
||||
|
||||
bridge.cleanupExpired(now)
|
||||
|
||||
if _, exists := bridge.flows[expiredKey]; exists {
|
||||
t.Fatal("expected expired flow to be removed")
|
||||
}
|
||||
if _, exists := bridge.flows[activeKey]; !exists {
|
||||
t.Fatal("expected active flow to remain")
|
||||
}
|
||||
if len(activeState.writer.traces) != 1 {
|
||||
t.Fatalf("expected active trace to remain, got %d", len(activeState.writer.traces))
|
||||
}
|
||||
}
|
||||
|
||||
func buildEchoReply(packet []byte) []byte {
|
||||
info, err := ParseICMPPacket(packet)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
switch info.IPVersion {
|
||||
case 4:
|
||||
return buildIPv4ICMPPacket(info.Destination, info.SourceIP, 0, 0, info.Identifier, info.Sequence)
|
||||
case 6:
|
||||
return buildIPv6ICMPPacket(info.Destination, info.SourceIP, 129, 0, info.Identifier, info.Sequence)
|
||||
default:
|
||||
panic("unsupported version")
|
||||
}
|
||||
}
|
||||
|
||||
func buildIPv4ICMPPacket(source, destination netip.Addr, icmpType, icmpCode uint8, identifier, sequence uint16) []byte {
|
||||
packet := make([]byte, 28)
|
||||
packet[0] = 0x45
|
||||
binary.BigEndian.PutUint16(packet[2:4], uint16(len(packet)))
|
||||
packet[8] = 64
|
||||
packet[9] = 1
|
||||
copy(packet[12:16], source.AsSlice())
|
||||
copy(packet[16:20], destination.AsSlice())
|
||||
packet[20] = icmpType
|
||||
packet[21] = icmpCode
|
||||
binary.BigEndian.PutUint16(packet[24:26], identifier)
|
||||
binary.BigEndian.PutUint16(packet[26:28], sequence)
|
||||
return packet
|
||||
}
|
||||
|
||||
func buildIPv6ICMPPacket(source, destination netip.Addr, icmpType, icmpCode uint8, identifier, sequence uint16) []byte {
|
||||
packet := make([]byte, 48)
|
||||
packet[0] = 0x60
|
||||
binary.BigEndian.PutUint16(packet[4:6], 8)
|
||||
packet[6] = 58
|
||||
packet[7] = 64
|
||||
copy(packet[8:24], source.AsSlice())
|
||||
copy(packet[24:40], destination.AsSlice())
|
||||
packet[40] = icmpType
|
||||
packet[41] = icmpCode
|
||||
binary.BigEndian.PutUint16(packet[44:46], identifier)
|
||||
binary.BigEndian.PutUint16(packet[46:48], sequence)
|
||||
return packet
|
||||
}
|
||||
@@ -4,149 +4,32 @@ package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
cloudflared "github.com/sagernet/sing-cloudflared"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/inbound"
|
||||
boxDialer "github.com/sagernet/sing-box/common/dialer"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
"github.com/sagernet/sing/common/json/badoption"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sagernet/sing/common/pipe"
|
||||
tun "github.com/sagernet/sing-tun"
|
||||
)
|
||||
|
||||
func RegisterInbound(registry *inbound.Registry) {
|
||||
inbound.Register[option.CloudflaredInboundOptions](registry, C.TypeCloudflared, NewInbound)
|
||||
}
|
||||
|
||||
var ErrNonRemoteManagedTunnelUnsupported = errors.New("cloudflared only supports remote-managed tunnels")
|
||||
|
||||
var (
|
||||
newQUICConnection = NewQUICConnection
|
||||
newHTTP2Connection = NewHTTP2Connection
|
||||
serveQUICConnection = func(connection *QUICConnection, ctx context.Context, handler StreamHandler) error {
|
||||
return connection.Serve(ctx, handler)
|
||||
}
|
||||
serveHTTP2Connection = func(connection *HTTP2Connection, ctx context.Context) error {
|
||||
return connection.Serve(ctx)
|
||||
}
|
||||
)
|
||||
|
||||
type Inbound struct {
|
||||
inbound.Adapter
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
router adapter.Router
|
||||
logger log.ContextLogger
|
||||
credentials Credentials
|
||||
connectorID uuid.UUID
|
||||
haConnections int
|
||||
protocol string
|
||||
region string
|
||||
edgeIPVersion int
|
||||
datagramVersion string
|
||||
featureSelector *featureSelector
|
||||
gracePeriod time.Duration
|
||||
configManager *ConfigManager
|
||||
flowLimiter *FlowLimiter
|
||||
accessCache *accessValidatorCache
|
||||
controlDialer N.Dialer
|
||||
tunnelDialer N.Dialer
|
||||
|
||||
connectionAccess sync.Mutex
|
||||
connections []io.Closer
|
||||
done sync.WaitGroup
|
||||
|
||||
datagramMuxerAccess sync.Mutex
|
||||
datagramV2Muxers map[DatagramSender]*DatagramV2Muxer
|
||||
datagramV3Muxers map[DatagramSender]*DatagramV3Muxer
|
||||
datagramV3Manager *DatagramV3SessionManager
|
||||
|
||||
connectedAccess sync.Mutex
|
||||
connectedIndices map[uint8]struct{}
|
||||
connectedNotify chan uint8
|
||||
|
||||
stateAccess sync.Mutex
|
||||
connectionStates []connectionState
|
||||
successfulProtocols map[string]struct{}
|
||||
firstSuccessfulProtocol string
|
||||
|
||||
directTransportAccess sync.Mutex
|
||||
directTransports map[string]*http.Transport
|
||||
}
|
||||
|
||||
type connectionState struct {
|
||||
protocol string
|
||||
retries uint8
|
||||
}
|
||||
|
||||
func resolveGracePeriod(value *badoption.Duration) time.Duration {
|
||||
if value == nil {
|
||||
return 30 * time.Second
|
||||
}
|
||||
return time.Duration(*value)
|
||||
}
|
||||
|
||||
func connectionRetryDecision(err error) (retry bool, cancelAll bool) {
|
||||
switch {
|
||||
case err == nil:
|
||||
return false, false
|
||||
case errors.Is(err, ErrNonRemoteManagedTunnelUnsupported):
|
||||
return false, true
|
||||
case isPermanentRegistrationError(err):
|
||||
return false, false
|
||||
default:
|
||||
return true, false
|
||||
}
|
||||
}
|
||||
|
||||
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.CloudflaredInboundOptions) (adapter.Inbound, error) {
|
||||
if options.Token == "" {
|
||||
return nil, E.New("missing token")
|
||||
}
|
||||
credentials, err := parseToken(options.Token)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "parse token")
|
||||
}
|
||||
|
||||
haConnections := options.HAConnections
|
||||
if haConnections <= 0 {
|
||||
haConnections = 4
|
||||
}
|
||||
|
||||
protocol, err := normalizeProtocol(options.Protocol)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
edgeIPVersion := options.EdgeIPVersion
|
||||
if edgeIPVersion != 0 && edgeIPVersion != 4 && edgeIPVersion != 6 {
|
||||
return nil, E.New("unsupported edge_ip_version: ", edgeIPVersion, ", expected 0, 4 or 6")
|
||||
}
|
||||
|
||||
datagramVersion := options.DatagramVersion
|
||||
if datagramVersion != "" && datagramVersion != "v2" && datagramVersion != "v3" {
|
||||
return nil, E.New("unsupported datagram_version: ", datagramVersion, ", expected v2 or v3")
|
||||
}
|
||||
|
||||
gracePeriod := resolveGracePeriod(options.GracePeriod)
|
||||
|
||||
configManager, err := NewConfigManager()
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "build cloudflared runtime config")
|
||||
}
|
||||
controlDialer, err := boxDialer.NewWithOptions(boxDialer.Options{
|
||||
Context: ctx,
|
||||
Options: options.ControlDialer,
|
||||
@@ -163,457 +46,131 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
||||
return nil, E.Cause(err, "build cloudflared tunnel dialer")
|
||||
}
|
||||
|
||||
region := options.Region
|
||||
if region != "" && credentials.Endpoint != "" {
|
||||
return nil, E.New("region cannot be specified when credentials already include an endpoint")
|
||||
service, err := cloudflared.NewService(cloudflared.ServiceOptions{
|
||||
Logger: logger,
|
||||
ConnectionDialer: &routerDialer{router: router, tag: tag},
|
||||
ControlDialer: controlDialer,
|
||||
TunnelDialer: tunnelDialer,
|
||||
ICMPHandler: &icmpRouterHandler{router: router, tag: tag},
|
||||
ConnContext: func(ctx context.Context) context.Context {
|
||||
return adapter.WithContext(ctx, &adapter.InboundContext{
|
||||
Inbound: tag,
|
||||
InboundType: C.TypeCloudflared,
|
||||
})
|
||||
},
|
||||
Token: options.Token,
|
||||
HAConnections: options.HAConnections,
|
||||
Protocol: options.Protocol,
|
||||
PostQuantum: options.PostQuantum,
|
||||
EdgeIPVersion: options.EdgeIPVersion,
|
||||
DatagramVersion: options.DatagramVersion,
|
||||
GracePeriod: resolveGracePeriod(options.GracePeriod),
|
||||
Region: options.Region,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if region == "" {
|
||||
region = credentials.Endpoint
|
||||
}
|
||||
|
||||
inboundCtx, cancel := context.WithCancel(ctx)
|
||||
|
||||
return &Inbound{
|
||||
Adapter: inbound.NewAdapter(C.TypeCloudflared, tag),
|
||||
ctx: inboundCtx,
|
||||
cancel: cancel,
|
||||
router: router,
|
||||
logger: logger,
|
||||
credentials: credentials,
|
||||
connectorID: uuid.New(),
|
||||
haConnections: haConnections,
|
||||
protocol: protocol,
|
||||
region: region,
|
||||
edgeIPVersion: edgeIPVersion,
|
||||
datagramVersion: datagramVersion,
|
||||
featureSelector: newFeatureSelector(inboundCtx, credentials.AccountTag, datagramVersion),
|
||||
gracePeriod: gracePeriod,
|
||||
configManager: configManager,
|
||||
flowLimiter: &FlowLimiter{},
|
||||
accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: controlDialer},
|
||||
controlDialer: controlDialer,
|
||||
tunnelDialer: tunnelDialer,
|
||||
datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer),
|
||||
datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer),
|
||||
datagramV3Manager: NewDatagramV3SessionManager(),
|
||||
connectedIndices: make(map[uint8]struct{}),
|
||||
connectedNotify: make(chan uint8, haConnections),
|
||||
connectionStates: make([]connectionState, haConnections),
|
||||
successfulProtocols: make(map[string]struct{}),
|
||||
directTransports: make(map[string]*http.Transport),
|
||||
Adapter: inbound.NewAdapter(C.TypeCloudflared, tag),
|
||||
service: service,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type Inbound struct {
|
||||
inbound.Adapter
|
||||
service *cloudflared.Service
|
||||
}
|
||||
|
||||
func (i *Inbound) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStateStart {
|
||||
return nil
|
||||
}
|
||||
|
||||
i.logger.Info("starting Cloudflare Tunnel with ", i.haConnections, " HA connections")
|
||||
|
||||
regions, err := DiscoverEdge(i.ctx, i.region, i.controlDialer)
|
||||
if err != nil {
|
||||
return E.Cause(err, "discover edge")
|
||||
}
|
||||
regions = FilterByIPVersion(regions, i.edgeIPVersion)
|
||||
edgeAddrs := flattenRegions(regions)
|
||||
if len(edgeAddrs) == 0 {
|
||||
return E.New("no edge addresses available")
|
||||
}
|
||||
if cappedHAConnections := effectiveHAConnections(i.haConnections, len(edgeAddrs)); cappedHAConnections != i.haConnections {
|
||||
i.logger.Info("requested ", i.haConnections, " HA connections but only ", cappedHAConnections, " edge addresses are available")
|
||||
i.haConnections = cappedHAConnections
|
||||
}
|
||||
|
||||
for connIndex := 0; connIndex < i.haConnections; connIndex++ {
|
||||
i.initializeConnectionState(uint8(connIndex))
|
||||
i.done.Add(1)
|
||||
go i.superviseConnection(uint8(connIndex), edgeAddrs)
|
||||
select {
|
||||
case readyConnIndex := <-i.connectedNotify:
|
||||
if readyConnIndex != uint8(connIndex) {
|
||||
i.logger.Debug("received unexpected ready notification for connection ", readyConnIndex)
|
||||
}
|
||||
case <-time.After(firstConnectionReadyTimeout):
|
||||
case <-i.ctx.Done():
|
||||
if connIndex == 0 {
|
||||
return i.ctx.Err()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *Inbound) notifyConnected(connIndex uint8, protocol string) {
|
||||
i.stateAccess.Lock()
|
||||
if i.successfulProtocols == nil {
|
||||
i.successfulProtocols = make(map[string]struct{})
|
||||
}
|
||||
i.ensureConnectionStateLocked(connIndex)
|
||||
state := i.connectionStates[connIndex]
|
||||
state.retries = 0
|
||||
state.protocol = protocol
|
||||
i.connectionStates[connIndex] = state
|
||||
if protocol != "" {
|
||||
i.successfulProtocols[protocol] = struct{}{}
|
||||
if i.firstSuccessfulProtocol == "" {
|
||||
i.firstSuccessfulProtocol = protocol
|
||||
}
|
||||
}
|
||||
i.stateAccess.Unlock()
|
||||
|
||||
if i.connectedNotify == nil {
|
||||
return
|
||||
}
|
||||
i.connectedAccess.Lock()
|
||||
if _, loaded := i.connectedIndices[connIndex]; loaded {
|
||||
i.connectedAccess.Unlock()
|
||||
return
|
||||
}
|
||||
i.connectedIndices[connIndex] = struct{}{}
|
||||
i.connectedAccess.Unlock()
|
||||
i.connectedNotify <- connIndex
|
||||
}
|
||||
|
||||
func (i *Inbound) ApplyConfig(version int32, config []byte) ConfigUpdateResult {
|
||||
result := i.configManager.Apply(version, config)
|
||||
if result.Err != nil {
|
||||
i.logger.Error("update ingress configuration: ", result.Err)
|
||||
return result
|
||||
}
|
||||
i.resetDirectOriginTransports()
|
||||
i.logger.Info("updated ingress configuration (version ", result.LastAppliedVersion, ")")
|
||||
return result
|
||||
}
|
||||
|
||||
func (i *Inbound) maxActiveFlows() uint64 {
|
||||
return i.configManager.Snapshot().WarpRouting.MaxActiveFlows
|
||||
return i.service.Start()
|
||||
}
|
||||
|
||||
func (i *Inbound) Close() error {
|
||||
i.cancel()
|
||||
i.done.Wait()
|
||||
i.connectionAccess.Lock()
|
||||
for _, connection := range i.connections {
|
||||
connection.Close()
|
||||
}
|
||||
i.connections = nil
|
||||
i.connectionAccess.Unlock()
|
||||
i.resetDirectOriginTransports()
|
||||
return nil
|
||||
return i.service.Close()
|
||||
}
|
||||
|
||||
const (
|
||||
backoffBaseTime = time.Second
|
||||
backoffMaxTime = 2 * time.Minute
|
||||
firstConnectionReadyTimeout = 15 * time.Second
|
||||
)
|
||||
|
||||
func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr) {
|
||||
defer i.done.Done()
|
||||
|
||||
edgeIndex := initialEdgeAddrIndex(connIndex, len(edgeAddrs))
|
||||
for {
|
||||
select {
|
||||
case <-i.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
edgeAddr := edgeAddrs[edgeIndex]
|
||||
err := i.safeServeConnection(connIndex, edgeAddr)
|
||||
if err == nil || i.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
retry, cancelAll := connectionRetryDecision(err)
|
||||
if cancelAll {
|
||||
i.logger.Error("connection ", connIndex, " failed permanently: ", err)
|
||||
i.cancel()
|
||||
return
|
||||
}
|
||||
if !retry {
|
||||
i.logger.Error("connection ", connIndex, " failed permanently: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
retries := i.incrementConnectionRetries(connIndex)
|
||||
edgeIndex = rotateEdgeAddrIndex(edgeIndex, len(edgeAddrs))
|
||||
backoff := backoffDuration(int(retries))
|
||||
var retryableErr *RetryableError
|
||||
if errors.As(err, &retryableErr) && retryableErr.Delay > 0 {
|
||||
backoff = retryableErr.Delay
|
||||
}
|
||||
i.logger.Error("connection ", connIndex, " failed: ", err, ", retrying in ", backoff)
|
||||
|
||||
select {
|
||||
case <-time.After(backoff):
|
||||
case <-i.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Inbound) serveConnection(connIndex uint8, edgeAddr *EdgeAddr) error {
|
||||
state := i.connectionState(connIndex)
|
||||
protocol := state.protocol
|
||||
numPreviousAttempts := state.retries
|
||||
datagramVersion, features := i.currentConnectionFeatures()
|
||||
|
||||
switch protocol {
|
||||
case "quic":
|
||||
err := i.serveQUIC(connIndex, edgeAddr, datagramVersion, features, numPreviousAttempts)
|
||||
if err == nil || i.ctx.Err() != nil {
|
||||
return err
|
||||
}
|
||||
if errors.Is(err, ErrNonRemoteManagedTunnelUnsupported) {
|
||||
return err
|
||||
}
|
||||
if !i.protocolIsAuto() {
|
||||
return err
|
||||
}
|
||||
if i.hasSuccessfulProtocol("quic") {
|
||||
return err
|
||||
}
|
||||
i.setConnectionProtocol(connIndex, "http2")
|
||||
i.logger.Warn("QUIC connection failed, falling back to HTTP/2: ", err)
|
||||
return i.serveHTTP2(connIndex, edgeAddr, features, numPreviousAttempts)
|
||||
case "http2":
|
||||
return i.serveHTTP2(connIndex, edgeAddr, features, numPreviousAttempts)
|
||||
default:
|
||||
return E.New("unsupported protocol: ", protocol)
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Inbound) safeServeConnection(connIndex uint8, edgeAddr *EdgeAddr) (err error) {
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
err = E.New("panic in serve connection: ", recovered, "\n", string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
return i.serveConnection(connIndex, edgeAddr)
|
||||
}
|
||||
|
||||
func (i *Inbound) serveQUIC(connIndex uint8, edgeAddr *EdgeAddr, datagramVersion string, features []string, numPreviousAttempts uint8) error {
|
||||
i.logger.Info("connecting to edge via QUIC (connection ", connIndex, ")")
|
||||
|
||||
connection, err := newQUICConnection(
|
||||
i.ctx, edgeAddr, connIndex,
|
||||
i.credentials, i.connectorID, datagramVersion,
|
||||
features, numPreviousAttempts, i.gracePeriod, i.tunnelDialer, func() {
|
||||
i.notifyConnected(connIndex, "quic")
|
||||
}, i.logger,
|
||||
)
|
||||
if err != nil {
|
||||
return E.Cause(err, "create QUIC connection")
|
||||
}
|
||||
|
||||
i.trackConnection(connection)
|
||||
defer func() {
|
||||
i.untrackConnection(connection)
|
||||
i.RemoveDatagramMuxer(connection)
|
||||
}()
|
||||
|
||||
return serveQUICConnection(connection, i.ctx, i)
|
||||
}
|
||||
|
||||
func (i *Inbound) currentConnectionFeatures() (string, []string) {
|
||||
if i.featureSelector != nil {
|
||||
return i.featureSelector.Snapshot()
|
||||
}
|
||||
version := i.datagramVersion
|
||||
if version == "" {
|
||||
version = defaultDatagramVersion
|
||||
}
|
||||
return version, DefaultFeatures(version)
|
||||
}
|
||||
|
||||
func (i *Inbound) serveHTTP2(connIndex uint8, edgeAddr *EdgeAddr, features []string, numPreviousAttempts uint8) error {
|
||||
i.logger.Info("connecting to edge via HTTP/2 (connection ", connIndex, ")")
|
||||
|
||||
connection, err := newHTTP2Connection(
|
||||
i.ctx, edgeAddr, connIndex,
|
||||
i.credentials, i.connectorID,
|
||||
features, numPreviousAttempts, i.gracePeriod, i, i.logger,
|
||||
)
|
||||
if err != nil {
|
||||
return E.Cause(err, "create HTTP/2 connection")
|
||||
}
|
||||
|
||||
i.trackConnection(connection)
|
||||
defer i.untrackConnection(connection)
|
||||
|
||||
return serveHTTP2Connection(connection, i.ctx)
|
||||
}
|
||||
|
||||
func (i *Inbound) initializeConnectionState(connIndex uint8) {
|
||||
i.stateAccess.Lock()
|
||||
defer i.stateAccess.Unlock()
|
||||
i.ensureConnectionStateLocked(connIndex)
|
||||
if i.connectionStates[connIndex].protocol == "" {
|
||||
i.connectionStates[connIndex].protocol = i.initialProtocolLocked()
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Inbound) connectionState(connIndex uint8) connectionState {
|
||||
i.stateAccess.Lock()
|
||||
defer i.stateAccess.Unlock()
|
||||
i.ensureConnectionStateLocked(connIndex)
|
||||
state := i.connectionStates[connIndex]
|
||||
if state.protocol == "" {
|
||||
state.protocol = i.initialProtocolLocked()
|
||||
i.connectionStates[connIndex] = state
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func (i *Inbound) incrementConnectionRetries(connIndex uint8) uint8 {
|
||||
i.stateAccess.Lock()
|
||||
defer i.stateAccess.Unlock()
|
||||
i.ensureConnectionStateLocked(connIndex)
|
||||
state := i.connectionStates[connIndex]
|
||||
state.retries++
|
||||
i.connectionStates[connIndex] = state
|
||||
return state.retries
|
||||
}
|
||||
|
||||
func (i *Inbound) setConnectionProtocol(connIndex uint8, protocol string) {
|
||||
i.stateAccess.Lock()
|
||||
defer i.stateAccess.Unlock()
|
||||
i.ensureConnectionStateLocked(connIndex)
|
||||
state := i.connectionStates[connIndex]
|
||||
state.protocol = protocol
|
||||
i.connectionStates[connIndex] = state
|
||||
}
|
||||
|
||||
func (i *Inbound) hasSuccessfulProtocol(protocol string) bool {
|
||||
i.stateAccess.Lock()
|
||||
defer i.stateAccess.Unlock()
|
||||
if i.successfulProtocols == nil {
|
||||
return false
|
||||
}
|
||||
_, ok := i.successfulProtocols[protocol]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (i *Inbound) protocolIsAuto() bool {
|
||||
return i.protocol == ""
|
||||
}
|
||||
|
||||
func (i *Inbound) ensureConnectionStateLocked(connIndex uint8) {
|
||||
requiredLen := int(connIndex) + 1
|
||||
if len(i.connectionStates) >= requiredLen {
|
||||
return
|
||||
}
|
||||
grown := make([]connectionState, requiredLen)
|
||||
copy(grown, i.connectionStates)
|
||||
i.connectionStates = grown
|
||||
}
|
||||
|
||||
func (i *Inbound) initialProtocolLocked() string {
|
||||
if i.protocol != "" {
|
||||
return i.protocol
|
||||
}
|
||||
if i.firstSuccessfulProtocol != "" {
|
||||
return i.firstSuccessfulProtocol
|
||||
}
|
||||
return "quic"
|
||||
}
|
||||
|
||||
func (i *Inbound) resetDirectOriginTransports() {
|
||||
i.directTransportAccess.Lock()
|
||||
transports := i.directTransports
|
||||
i.directTransports = make(map[string]*http.Transport)
|
||||
i.directTransportAccess.Unlock()
|
||||
|
||||
for _, transport := range transports {
|
||||
transport.CloseIdleConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Inbound) trackConnection(connection io.Closer) {
|
||||
i.connectionAccess.Lock()
|
||||
defer i.connectionAccess.Unlock()
|
||||
i.connections = append(i.connections, connection)
|
||||
}
|
||||
|
||||
func (i *Inbound) untrackConnection(connection io.Closer) {
|
||||
i.connectionAccess.Lock()
|
||||
defer i.connectionAccess.Unlock()
|
||||
for index, tracked := range i.connections {
|
||||
if tracked == connection {
|
||||
i.connections = append(i.connections[:index], i.connections[index+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func backoffDuration(retries int) time.Duration {
|
||||
backoff := backoffBaseTime * (1 << min(retries, 7))
|
||||
if backoff > backoffMaxTime {
|
||||
backoff = backoffMaxTime
|
||||
}
|
||||
// Add jitter: random duration in [backoff/2, backoff)
|
||||
jitter := time.Duration(rand.Int63n(int64(backoff / 2)))
|
||||
return backoff/2 + jitter
|
||||
}
|
||||
|
||||
func initialEdgeAddrIndex(connIndex uint8, size int) int {
|
||||
if size <= 1 {
|
||||
func resolveGracePeriod(value *badoption.Duration) time.Duration {
|
||||
if value == nil {
|
||||
return 0
|
||||
}
|
||||
return int(connIndex) % size
|
||||
return time.Duration(*value)
|
||||
}
|
||||
|
||||
func rotateEdgeAddrIndex(current int, size int) int {
|
||||
if size <= 1 {
|
||||
return 0
|
||||
}
|
||||
return (current + 1) % size
|
||||
// routerDialer bridges N.Dialer to the sing-box router for origin connections.
|
||||
type routerDialer struct {
|
||||
router adapter.Router
|
||||
tag string
|
||||
}
|
||||
|
||||
func flattenRegions(regions [][]*EdgeAddr) []*EdgeAddr {
|
||||
var result []*EdgeAddr
|
||||
for _, region := range regions {
|
||||
result = append(result, region...)
|
||||
func (d *routerDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
input, output := pipe.Pipe()
|
||||
done := make(chan struct{})
|
||||
metadata := adapter.InboundContext{
|
||||
Inbound: d.tag,
|
||||
InboundType: C.TypeCloudflared,
|
||||
Network: N.NetworkTCP,
|
||||
Destination: destination,
|
||||
}
|
||||
return result
|
||||
var closeOnce sync.Once
|
||||
closePipe := func() {
|
||||
closeOnce.Do(func() {
|
||||
common.Close(input, output)
|
||||
})
|
||||
}
|
||||
go d.router.RouteConnectionEx(ctx, output, metadata, N.OnceClose(func(it error) {
|
||||
closePipe()
|
||||
close(done)
|
||||
}))
|
||||
return input, nil
|
||||
}
|
||||
|
||||
func effectiveHAConnections(requested, available int) int {
|
||||
if available <= 0 {
|
||||
return 0
|
||||
func (d *routerDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
originDialer, ok := d.router.(routedOriginPacketDialer)
|
||||
if !ok {
|
||||
return nil, E.New("router does not support cloudflare routed packet dialing")
|
||||
}
|
||||
if requested > available {
|
||||
return available
|
||||
}
|
||||
return requested
|
||||
}
|
||||
|
||||
func parseToken(token string) (Credentials, error) {
|
||||
data, err := base64.StdEncoding.DecodeString(token)
|
||||
packetConn, err := originDialer.DialRoutePacketConnection(ctx, adapter.InboundContext{
|
||||
Inbound: d.tag,
|
||||
InboundType: C.TypeCloudflared,
|
||||
Network: N.NetworkUDP,
|
||||
Destination: destination,
|
||||
UDPConnect: true,
|
||||
})
|
||||
if err != nil {
|
||||
return Credentials{}, E.Cause(err, "decode token")
|
||||
return nil, err
|
||||
}
|
||||
var tunnelToken TunnelToken
|
||||
err = json.Unmarshal(data, &tunnelToken)
|
||||
if err != nil {
|
||||
return Credentials{}, E.Cause(err, "unmarshal token")
|
||||
}
|
||||
return tunnelToken.ToCredentials(), nil
|
||||
return bufio.NewNetPacketConn(packetConn), nil
|
||||
}
|
||||
|
||||
// "auto" does not choose a transport here. We normalize it to an empty
|
||||
// sentinel so serveConnection can apply the token-style behavior later.
|
||||
// In the token-provided, remotely-managed tunnel path supported here, that
|
||||
// matches cloudflared's NewProtocolSelector(..., tunnelTokenProvided=true)
|
||||
// branch rather than the non-token remote-percentage selector.
|
||||
func normalizeProtocol(protocol string) (string, error) {
|
||||
if protocol == "auto" {
|
||||
return "", nil
|
||||
}
|
||||
if protocol != "" && protocol != "quic" && protocol != "http2" {
|
||||
return "", E.New("unsupported protocol: ", protocol, ", expected auto, quic or http2")
|
||||
}
|
||||
return protocol, nil
|
||||
type routedOriginPacketDialer interface {
|
||||
DialRoutePacketConnection(ctx context.Context, metadata adapter.InboundContext) (N.PacketConn, error)
|
||||
}
|
||||
|
||||
// icmpRouterHandler bridges cloudflared.ICMPHandler to router.PreMatch.
|
||||
type icmpRouterHandler struct {
|
||||
router adapter.Router
|
||||
tag string
|
||||
}
|
||||
|
||||
func (h *icmpRouterHandler) RouteICMPConnection(ctx context.Context, session tun.DirectRouteSession, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) {
|
||||
var ipVersion uint8
|
||||
if session.Source.Is4() {
|
||||
ipVersion = 4
|
||||
} else {
|
||||
ipVersion = 6
|
||||
}
|
||||
metadata := adapter.InboundContext{
|
||||
Inbound: h.tag,
|
||||
InboundType: C.TypeCloudflared,
|
||||
IPVersion: ipVersion,
|
||||
Network: N.NetworkICMP,
|
||||
Source: M.SocksaddrFrom(session.Source, 0),
|
||||
Destination: M.SocksaddrFrom(session.Destination, 0),
|
||||
OriginDestination: M.SocksaddrFrom(session.Destination, 0),
|
||||
}
|
||||
return h.router.PreMatch(metadata, routeContext, timeout, false)
|
||||
}
|
||||
|
||||
@@ -1,228 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func restoreConnectionHooks(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
originalNewQUICConnection := newQUICConnection
|
||||
originalNewHTTP2Connection := newHTTP2Connection
|
||||
originalServeQUICConnection := serveQUICConnection
|
||||
originalServeHTTP2Connection := serveHTTP2Connection
|
||||
t.Cleanup(func() {
|
||||
newQUICConnection = originalNewQUICConnection
|
||||
newHTTP2Connection = originalNewHTTP2Connection
|
||||
serveQUICConnection = originalServeQUICConnection
|
||||
serveHTTP2Connection = originalServeHTTP2Connection
|
||||
})
|
||||
}
|
||||
|
||||
func TestServeConnectionAutoFallbackSticky(t *testing.T) {
|
||||
restoreConnectionHooks(t)
|
||||
|
||||
inboundInstance := newLimitedInbound(t, 0)
|
||||
inboundInstance.protocol = ""
|
||||
inboundInstance.initializeConnectionState(0)
|
||||
|
||||
var quicCalls, http2Calls int
|
||||
newQUICConnection = func(context.Context, *EdgeAddr, uint8, Credentials, uuid.UUID, string, []string, uint8, time.Duration, N.Dialer, func(), log.ContextLogger) (*QUICConnection, error) {
|
||||
quicCalls++
|
||||
return &QUICConnection{}, nil
|
||||
}
|
||||
serveQUICConnection = func(*QUICConnection, context.Context, StreamHandler) error {
|
||||
return errors.New("quic failed")
|
||||
}
|
||||
newHTTP2Connection = func(context.Context, *EdgeAddr, uint8, Credentials, uuid.UUID, []string, uint8, time.Duration, *Inbound, log.ContextLogger) (*HTTP2Connection, error) {
|
||||
http2Calls++
|
||||
return &HTTP2Connection{}, nil
|
||||
}
|
||||
serveHTTP2Connection = func(*HTTP2Connection, context.Context) error {
|
||||
return errors.New("http2 failed")
|
||||
}
|
||||
|
||||
if err := inboundInstance.serveConnection(0, &EdgeAddr{}); err == nil || err.Error() != "http2 failed" {
|
||||
t.Fatalf("expected HTTP/2 fallback error, got %v", err)
|
||||
}
|
||||
if state := inboundInstance.connectionState(0); state.protocol != "http2" {
|
||||
t.Fatalf("expected sticky HTTP/2 fallback, got %#v", state)
|
||||
}
|
||||
|
||||
if err := inboundInstance.serveConnection(0, &EdgeAddr{}); err == nil || err.Error() != "http2 failed" {
|
||||
t.Fatalf("expected second HTTP/2 error, got %v", err)
|
||||
}
|
||||
if quicCalls != 1 {
|
||||
t.Fatalf("expected QUIC to be attempted once, got %d", quicCalls)
|
||||
}
|
||||
if http2Calls != 2 {
|
||||
t.Fatalf("expected HTTP/2 to be attempted twice, got %d", http2Calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecondConnectionInitialProtocolUsesFirstSuccess(t *testing.T) {
|
||||
inboundInstance := newLimitedInbound(t, 0)
|
||||
inboundInstance.protocol = ""
|
||||
|
||||
inboundInstance.notifyConnected(0, "http2")
|
||||
inboundInstance.initializeConnectionState(1)
|
||||
|
||||
if state := inboundInstance.connectionState(1); state.protocol != "http2" {
|
||||
t.Fatalf("expected second connection to inherit HTTP/2, got %#v", state)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeConnectionSkipsFallbackWhenQUICAlreadySucceeded(t *testing.T) {
|
||||
restoreConnectionHooks(t)
|
||||
|
||||
inboundInstance := newLimitedInbound(t, 0)
|
||||
inboundInstance.protocol = ""
|
||||
inboundInstance.notifyConnected(0, "quic")
|
||||
inboundInstance.initializeConnectionState(1)
|
||||
|
||||
var http2Calls int
|
||||
quicErr := errors.New("quic failed")
|
||||
newQUICConnection = func(context.Context, *EdgeAddr, uint8, Credentials, uuid.UUID, string, []string, uint8, time.Duration, N.Dialer, func(), log.ContextLogger) (*QUICConnection, error) {
|
||||
return &QUICConnection{}, nil
|
||||
}
|
||||
serveQUICConnection = func(*QUICConnection, context.Context, StreamHandler) error {
|
||||
return quicErr
|
||||
}
|
||||
newHTTP2Connection = func(context.Context, *EdgeAddr, uint8, Credentials, uuid.UUID, []string, uint8, time.Duration, *Inbound, log.ContextLogger) (*HTTP2Connection, error) {
|
||||
http2Calls++
|
||||
return &HTTP2Connection{}, nil
|
||||
}
|
||||
|
||||
err := inboundInstance.serveConnection(1, &EdgeAddr{})
|
||||
if !errors.Is(err, quicErr) {
|
||||
t.Fatalf("expected QUIC error without fallback, got %v", err)
|
||||
}
|
||||
if http2Calls != 0 {
|
||||
t.Fatalf("expected no HTTP/2 fallback, got %d calls", http2Calls)
|
||||
}
|
||||
if state := inboundInstance.connectionState(1); state.protocol != "quic" {
|
||||
t.Fatalf("expected connection to remain on QUIC, got %#v", state)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotifyConnectedResetsRetries(t *testing.T) {
|
||||
inboundInstance := newLimitedInbound(t, 0)
|
||||
inboundInstance.protocol = ""
|
||||
inboundInstance.initializeConnectionState(0)
|
||||
inboundInstance.incrementConnectionRetries(0)
|
||||
inboundInstance.incrementConnectionRetries(0)
|
||||
|
||||
inboundInstance.notifyConnected(0, "http2")
|
||||
|
||||
state := inboundInstance.connectionState(0)
|
||||
if state.retries != 0 {
|
||||
t.Fatalf("expected retries reset after success, got %d", state.retries)
|
||||
}
|
||||
if state.protocol != "http2" {
|
||||
t.Fatalf("expected protocol to be pinned to success, got %q", state.protocol)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeServeConnectionRecoversPanic(t *testing.T) {
|
||||
restoreConnectionHooks(t)
|
||||
|
||||
inboundInstance := newLimitedInbound(t, 0)
|
||||
inboundInstance.protocol = "quic"
|
||||
inboundInstance.initializeConnectionState(0)
|
||||
|
||||
newQUICConnection = func(context.Context, *EdgeAddr, uint8, Credentials, uuid.UUID, string, []string, uint8, time.Duration, N.Dialer, func(), log.ContextLogger) (*QUICConnection, error) {
|
||||
return &QUICConnection{}, nil
|
||||
}
|
||||
serveQUICConnection = func(*QUICConnection, context.Context, StreamHandler) error {
|
||||
panic("boom")
|
||||
}
|
||||
|
||||
err := inboundInstance.safeServeConnection(0, &EdgeAddr{})
|
||||
if err == nil || !strings.Contains(err.Error(), "panic in serve connection") {
|
||||
t.Fatalf("expected recovered panic error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSuperviseConnectionStopsOnPermanentRegistrationError(t *testing.T) {
|
||||
restoreConnectionHooks(t)
|
||||
|
||||
inboundInstance := newLimitedInbound(t, 0)
|
||||
inboundInstance.protocol = "quic"
|
||||
inboundInstance.initializeConnectionState(0)
|
||||
|
||||
permanentErr := &permanentRegistrationError{Err: errors.New("permanent register error")}
|
||||
newQUICConnection = func(context.Context, *EdgeAddr, uint8, Credentials, uuid.UUID, string, []string, uint8, time.Duration, N.Dialer, func(), log.ContextLogger) (*QUICConnection, error) {
|
||||
return &QUICConnection{}, nil
|
||||
}
|
||||
serveQUICConnection = func(*QUICConnection, context.Context, StreamHandler) error {
|
||||
return permanentErr
|
||||
}
|
||||
|
||||
inboundInstance.done.Add(1)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
inboundInstance.superviseConnection(0, []*EdgeAddr{{}})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected connection supervision to stop")
|
||||
}
|
||||
|
||||
if retries := inboundInstance.connectionState(0).retries; retries != 0 {
|
||||
t.Fatalf("expected no retries for permanent registration error, got %d", retries)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-inboundInstance.ctx.Done():
|
||||
t.Fatal("expected permanent registration error to stop only this connection")
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestSuperviseConnectionCancelsInboundOnNonRemoteManagedError(t *testing.T) {
|
||||
restoreConnectionHooks(t)
|
||||
|
||||
inboundInstance := newLimitedInbound(t, 0)
|
||||
inboundInstance.protocol = "quic"
|
||||
inboundInstance.initializeConnectionState(0)
|
||||
|
||||
newQUICConnection = func(context.Context, *EdgeAddr, uint8, Credentials, uuid.UUID, string, []string, uint8, time.Duration, N.Dialer, func(), log.ContextLogger) (*QUICConnection, error) {
|
||||
return &QUICConnection{}, nil
|
||||
}
|
||||
serveQUICConnection = func(*QUICConnection, context.Context, StreamHandler) error {
|
||||
return ErrNonRemoteManagedTunnelUnsupported
|
||||
}
|
||||
|
||||
inboundInstance.done.Add(1)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
inboundInstance.superviseConnection(0, []*EdgeAddr{{}})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected connection supervision to stop")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-inboundInstance.ctx.Done():
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected inbound cancellation on non-remote-managed tunnel error")
|
||||
}
|
||||
}
|
||||
@@ -1,271 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
)
|
||||
|
||||
func newTestIngressInbound(t *testing.T) *Inbound {
|
||||
t.Helper()
|
||||
configManager, err := NewConfigManager()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return &Inbound{
|
||||
logger: log.NewNOPFactory().NewLogger("test"),
|
||||
configManager: configManager,
|
||||
}
|
||||
}
|
||||
|
||||
func mustResolvedService(t *testing.T, rawService string) ResolvedService {
|
||||
t.Helper()
|
||||
service, err := parseResolvedService(rawService, defaultOriginRequestConfig())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return service
|
||||
}
|
||||
|
||||
func TestApplyConfig(t *testing.T) {
|
||||
inboundInstance := newTestIngressInbound(t)
|
||||
|
||||
config1 := []byte(`{"ingress":[{"hostname":"a.com","service":"http://localhost:80"},{"hostname":"b.com","service":"http://localhost:81"},{"service":"http_status:404"}]}`)
|
||||
result := inboundInstance.ApplyConfig(1, config1)
|
||||
if result.Err != nil {
|
||||
t.Fatal(result.Err)
|
||||
}
|
||||
if result.LastAppliedVersion != 1 {
|
||||
t.Fatalf("expected version 1, got %d", result.LastAppliedVersion)
|
||||
}
|
||||
|
||||
service, loaded := inboundInstance.configManager.Resolve("a.com", "/")
|
||||
if !loaded || service.Service != "http://localhost:80" {
|
||||
t.Fatalf("expected a.com to resolve to localhost:80, got %#v, loaded=%v", service, loaded)
|
||||
}
|
||||
|
||||
result = inboundInstance.ApplyConfig(1, []byte(`{"ingress":[{"service":"http_status:503"}]}`))
|
||||
if result.Err != nil {
|
||||
t.Fatal(result.Err)
|
||||
}
|
||||
if result.LastAppliedVersion != 1 {
|
||||
t.Fatalf("same version should keep current version, got %d", result.LastAppliedVersion)
|
||||
}
|
||||
|
||||
service, loaded = inboundInstance.configManager.Resolve("b.com", "/")
|
||||
if !loaded || service.Service != "http://localhost:81" {
|
||||
t.Fatalf("expected old rules to remain, got %#v, loaded=%v", service, loaded)
|
||||
}
|
||||
|
||||
result = inboundInstance.ApplyConfig(2, []byte(`{"ingress":[{"service":"http_status:503"}]}`))
|
||||
if result.Err != nil {
|
||||
t.Fatal(result.Err)
|
||||
}
|
||||
if result.LastAppliedVersion != 2 {
|
||||
t.Fatalf("expected version 2, got %d", result.LastAppliedVersion)
|
||||
}
|
||||
|
||||
service, loaded = inboundInstance.configManager.Resolve("anything.com", "/")
|
||||
if !loaded || service.StatusCode != 503 {
|
||||
t.Fatalf("expected catch-all status 503, got %#v, loaded=%v", service, loaded)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyConfigInvalidJSON(t *testing.T) {
|
||||
inboundInstance := newTestIngressInbound(t)
|
||||
result := inboundInstance.ApplyConfig(1, []byte("not json"))
|
||||
if result.Err == nil {
|
||||
t.Fatal("expected parse error")
|
||||
}
|
||||
if result.LastAppliedVersion != -1 {
|
||||
t.Fatalf("expected version to stay -1, got %d", result.LastAppliedVersion)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfigIsCatchAll503(t *testing.T) {
|
||||
inboundInstance := newTestIngressInbound(t)
|
||||
|
||||
service, loaded := inboundInstance.configManager.Resolve("any.example.com", "/")
|
||||
if !loaded {
|
||||
t.Fatal("expected default config to resolve catch-all rule")
|
||||
}
|
||||
if service.StatusCode != 503 {
|
||||
t.Fatalf("expected catch-all 503, got %#v", service)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveExactAndWildcard(t *testing.T) {
|
||||
inboundInstance := newTestIngressInbound(t)
|
||||
inboundInstance.configManager.activeConfig = RuntimeConfig{
|
||||
Ingress: []compiledIngressRule{
|
||||
{Hostname: "test.example.com", Service: mustResolvedService(t, "http://localhost:8080")},
|
||||
{Hostname: "*.example.com", Service: mustResolvedService(t, "http://localhost:9090")},
|
||||
{Service: mustResolvedService(t, "http_status:404")},
|
||||
},
|
||||
}
|
||||
|
||||
service, loaded := inboundInstance.configManager.Resolve("test.example.com", "/")
|
||||
if !loaded || service.Service != "http://localhost:8080" {
|
||||
t.Fatalf("expected exact match, got %#v, loaded=%v", service, loaded)
|
||||
}
|
||||
|
||||
service, loaded = inboundInstance.configManager.Resolve("sub.example.com", "/")
|
||||
if !loaded || service.Service != "http://localhost:9090" {
|
||||
t.Fatalf("expected wildcard match, got %#v, loaded=%v", service, loaded)
|
||||
}
|
||||
|
||||
service, loaded = inboundInstance.configManager.Resolve("unknown.test", "/")
|
||||
if !loaded || service.StatusCode != 404 {
|
||||
t.Fatalf("expected catch-all 404, got %#v, loaded=%v", service, loaded)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveHTTPService(t *testing.T) {
|
||||
inboundInstance := newTestIngressInbound(t)
|
||||
inboundInstance.configManager.activeConfig = RuntimeConfig{
|
||||
Ingress: []compiledIngressRule{
|
||||
{Hostname: "foo.com", Service: mustResolvedService(t, "http://127.0.0.1:8083")},
|
||||
{Service: mustResolvedService(t, "http_status:404")},
|
||||
},
|
||||
}
|
||||
|
||||
service, requestURL, err := inboundInstance.resolveHTTPService("https://foo.com/path?q=1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if service.Destination.String() != "127.0.0.1:8083" {
|
||||
t.Fatalf("expected destination 127.0.0.1:8083, got %s", service.Destination)
|
||||
}
|
||||
if requestURL != "http://127.0.0.1:8083/path?q=1" {
|
||||
t.Fatalf("expected rewritten URL, got %s", requestURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveHTTPServiceStatus(t *testing.T) {
|
||||
inboundInstance := newTestIngressInbound(t)
|
||||
inboundInstance.configManager.activeConfig = RuntimeConfig{
|
||||
Ingress: []compiledIngressRule{
|
||||
{Service: mustResolvedService(t, "http_status:404")},
|
||||
},
|
||||
}
|
||||
|
||||
service, requestURL, err := inboundInstance.resolveHTTPService("https://any.com/path")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if service.StatusCode != 404 {
|
||||
t.Fatalf("expected status 404, got %#v", service)
|
||||
}
|
||||
if requestURL != "https://any.com/path" {
|
||||
t.Fatalf("status service should keep request URL, got %s", requestURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResolvedServiceCanonicalizesWebSocketOrigin(t *testing.T) {
|
||||
testCases := []struct {
|
||||
rawService string
|
||||
wantScheme string
|
||||
}{
|
||||
{rawService: "ws://127.0.0.1:8080", wantScheme: "http"},
|
||||
{rawService: "wss://127.0.0.1:8443", wantScheme: "https"},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.rawService, func(t *testing.T) {
|
||||
service, err := parseResolvedService(testCase.rawService, defaultOriginRequestConfig())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if service.BaseURL == nil {
|
||||
t.Fatal("expected base URL")
|
||||
}
|
||||
if service.BaseURL.Scheme != testCase.wantScheme {
|
||||
t.Fatalf("expected scheme %q, got %q", testCase.wantScheme, service.BaseURL.Scheme)
|
||||
}
|
||||
if service.Service != testCase.rawService {
|
||||
t.Fatalf("expected raw service to stay %q, got %q", testCase.rawService, service.Service)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResolvedServiceGenericStreamSchemeWithoutPort(t *testing.T) {
|
||||
service, err := parseResolvedService("ftp://127.0.0.1", defaultOriginRequestConfig())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if service.Kind != ResolvedServiceStream {
|
||||
t.Fatalf("expected stream service, got %v", service.Kind)
|
||||
}
|
||||
if service.Destination.AddrString() != "127.0.0.1" {
|
||||
t.Fatalf("expected destination host 127.0.0.1, got %s", service.Destination.AddrString())
|
||||
}
|
||||
if service.Destination.Port != 0 {
|
||||
t.Fatalf("expected destination port 0, got %d", service.Destination.Port)
|
||||
}
|
||||
if service.StreamHasPort {
|
||||
t.Fatal("expected generic stream service without port to report missing port")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResolvedServiceGenericStreamSchemeWithPort(t *testing.T) {
|
||||
service, err := parseResolvedService("ftp://127.0.0.1:21", defaultOriginRequestConfig())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if service.Kind != ResolvedServiceStream {
|
||||
t.Fatalf("expected stream service, got %v", service.Kind)
|
||||
}
|
||||
if service.Destination.String() != "127.0.0.1:21" {
|
||||
t.Fatalf("expected destination 127.0.0.1:21, got %s", service.Destination)
|
||||
}
|
||||
if !service.StreamHasPort {
|
||||
t.Fatal("expected generic stream service with explicit port to be dialable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResolvedServiceSSHDefaultPort(t *testing.T) {
|
||||
service, err := parseResolvedService("ssh://127.0.0.1", defaultOriginRequestConfig())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if service.Destination.String() != "127.0.0.1:22" {
|
||||
t.Fatalf("expected destination 127.0.0.1:22, got %s", service.Destination)
|
||||
}
|
||||
if !service.StreamHasPort {
|
||||
t.Fatal("expected ssh stream service to apply default port")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResolvedServiceTCPDefaultPort(t *testing.T) {
|
||||
service, err := parseResolvedService("tcp://127.0.0.1", defaultOriginRequestConfig())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if service.Destination.String() != "127.0.0.1:7864" {
|
||||
t.Fatalf("expected destination 127.0.0.1:7864, got %s", service.Destination)
|
||||
}
|
||||
if !service.StreamHasPort {
|
||||
t.Fatal("expected tcp stream service to apply default port")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveHTTPServiceWebSocketOrigin(t *testing.T) {
|
||||
inboundInstance := newTestIngressInbound(t)
|
||||
inboundInstance.configManager.activeConfig = RuntimeConfig{
|
||||
Ingress: []compiledIngressRule{
|
||||
{Hostname: "foo.com", Service: mustResolvedService(t, "ws://127.0.0.1:8083")},
|
||||
{Service: mustResolvedService(t, "http_status:404")},
|
||||
},
|
||||
}
|
||||
|
||||
_, requestURL, err := inboundInstance.resolveHTTPService("https://foo.com/path?q=1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if requestURL != "http://127.0.0.1:8083/path?q=1" {
|
||||
t.Fatalf("expected websocket origin to be canonicalized, got %s", requestURL)
|
||||
}
|
||||
}
|
||||
@@ -1,181 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
)
|
||||
|
||||
func TestQUICIntegration(t *testing.T) {
|
||||
token, testURL := requireEnvVars(t)
|
||||
startOriginServer(t)
|
||||
|
||||
inboundInstance := newTestInbound(t, token, "quic", 1)
|
||||
err := inboundInstance.Start(adapter.StartStateStart)
|
||||
if err != nil {
|
||||
t.Fatal("Start: ", err)
|
||||
}
|
||||
|
||||
waitForTunnel(t, testURL, 30*time.Second)
|
||||
|
||||
resp, err := http.Get(testURL + "/ping")
|
||||
if err != nil {
|
||||
t.Fatal("GET /ping: ", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatal("expected 200, got ", resp.StatusCode)
|
||||
}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal("read body: ", err)
|
||||
}
|
||||
if string(body) != `{"ok":true}` {
|
||||
t.Error("unexpected body: ", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTP2Integration(t *testing.T) {
|
||||
token, testURL := requireEnvVars(t)
|
||||
startOriginServer(t)
|
||||
|
||||
inboundInstance := newTestInbound(t, token, "http2", 1)
|
||||
err := inboundInstance.Start(adapter.StartStateStart)
|
||||
if err != nil {
|
||||
t.Fatal("Start: ", err)
|
||||
}
|
||||
|
||||
waitForTunnel(t, testURL, 30*time.Second)
|
||||
|
||||
resp, err := http.Get(testURL + "/ping")
|
||||
if err != nil {
|
||||
t.Fatal("GET /ping: ", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatal("expected 200, got ", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleHAConnections(t *testing.T) {
|
||||
token, testURL := requireEnvVars(t)
|
||||
startOriginServer(t)
|
||||
|
||||
inboundInstance := newTestInbound(t, token, "quic", 2)
|
||||
err := inboundInstance.Start(adapter.StartStateStart)
|
||||
if err != nil {
|
||||
t.Fatal("Start: ", err)
|
||||
}
|
||||
|
||||
waitForTunnel(t, testURL, 30*time.Second)
|
||||
|
||||
// Allow time for second connection to register
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
inboundInstance.connectionAccess.Lock()
|
||||
connCount := len(inboundInstance.connections)
|
||||
inboundInstance.connectionAccess.Unlock()
|
||||
if connCount < 2 {
|
||||
t.Errorf("expected at least 2 connections, got %d", connCount)
|
||||
}
|
||||
|
||||
resp, err := http.Get(testURL + "/ping")
|
||||
if err != nil {
|
||||
t.Fatal("GET /ping: ", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatal("expected 200, got ", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPResponseCorrectness(t *testing.T) {
|
||||
token, testURL := requireEnvVars(t)
|
||||
startOriginServer(t)
|
||||
|
||||
inboundInstance := newTestInbound(t, token, "quic", 1)
|
||||
err := inboundInstance.Start(adapter.StartStateStart)
|
||||
if err != nil {
|
||||
t.Fatal("Start: ", err)
|
||||
}
|
||||
|
||||
waitForTunnel(t, testURL, 30*time.Second)
|
||||
|
||||
t.Run("StatusCode", func(t *testing.T) {
|
||||
resp, err := http.Get(testURL + "/status/201")
|
||||
if err != nil {
|
||||
t.Fatal("GET /status/201: ", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode != 201 {
|
||||
t.Error("expected 201, got ", resp.StatusCode)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CustomHeader", func(t *testing.T) {
|
||||
resp, err := http.Get(testURL + "/status/200")
|
||||
if err != nil {
|
||||
t.Fatal("GET /status/200: ", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
customHeader := resp.Header.Get("X-Custom")
|
||||
if customHeader != "test-value" {
|
||||
t.Error("expected X-Custom=test-value, got ", customHeader)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PostEcho", func(t *testing.T) {
|
||||
resp, err := http.Post(testURL+"/echo", "text/plain", strings.NewReader("payload"))
|
||||
if err != nil {
|
||||
t.Fatal("POST /echo: ", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatal("expected 200, got ", resp.StatusCode)
|
||||
}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal("read body: ", err)
|
||||
}
|
||||
if string(body) != "payload" {
|
||||
t.Error("unexpected body: ", string(body))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGracefulClose(t *testing.T) {
|
||||
token, testURL := requireEnvVars(t)
|
||||
startOriginServer(t)
|
||||
|
||||
inboundInstance := newTestInbound(t, token, "quic", 1)
|
||||
err := inboundInstance.Start(adapter.StartStateStart)
|
||||
if err != nil {
|
||||
t.Fatal("Start: ", err)
|
||||
}
|
||||
|
||||
waitForTunnel(t, testURL, 30*time.Second)
|
||||
|
||||
err = inboundInstance.Close()
|
||||
if err != nil {
|
||||
t.Fatal("Close: ", err)
|
||||
}
|
||||
|
||||
if inboundInstance.ctx.Err() == nil {
|
||||
t.Error("expected context to be cancelled after Close")
|
||||
}
|
||||
|
||||
inboundInstance.connectionAccess.Lock()
|
||||
remaining := inboundInstance.connections
|
||||
inboundInstance.connectionAccess.Unlock()
|
||||
if remaining != nil {
|
||||
t.Error("expected connections to be nil after Close, got ", len(remaining))
|
||||
}
|
||||
}
|
||||
@@ -1,96 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sort"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
type compiledIPRule struct {
|
||||
prefix netip.Prefix
|
||||
ports []int
|
||||
allow bool
|
||||
}
|
||||
|
||||
type ipRulePolicy struct {
|
||||
rules []compiledIPRule
|
||||
}
|
||||
|
||||
func newIPRulePolicy(rawRules []IPRule) (*ipRulePolicy, error) {
|
||||
policy := &ipRulePolicy{
|
||||
rules: make([]compiledIPRule, 0, len(rawRules)),
|
||||
}
|
||||
for _, rawRule := range rawRules {
|
||||
if rawRule.Prefix == "" {
|
||||
return nil, E.New("ip_rule prefix cannot be blank")
|
||||
}
|
||||
prefix, err := netip.ParsePrefix(rawRule.Prefix)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "parse ip_rule prefix")
|
||||
}
|
||||
ports := append([]int(nil), rawRule.Ports...)
|
||||
sort.Ints(ports)
|
||||
for _, port := range ports {
|
||||
if port < 1 || port > 65535 {
|
||||
return nil, E.New("invalid ip_rule port: ", port)
|
||||
}
|
||||
}
|
||||
policy.rules = append(policy.rules, compiledIPRule{
|
||||
prefix: prefix,
|
||||
ports: ports,
|
||||
allow: rawRule.Allow,
|
||||
})
|
||||
}
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
func (p *ipRulePolicy) Allow(ctx context.Context, destination M.Socksaddr) (bool, error) {
|
||||
if p == nil {
|
||||
return false, nil
|
||||
}
|
||||
ipAddr, err := resolvePolicyDestination(ctx, destination)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
port := int(destination.Port)
|
||||
for _, rule := range p.rules {
|
||||
if !rule.prefix.Contains(ipAddr) {
|
||||
continue
|
||||
}
|
||||
if len(rule.ports) == 0 {
|
||||
return rule.allow, nil
|
||||
}
|
||||
portIndex := sort.SearchInts(rule.ports, port)
|
||||
if portIndex < len(rule.ports) && rule.ports[portIndex] == port {
|
||||
return rule.allow, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func resolvePolicyDestination(ctx context.Context, destination M.Socksaddr) (netip.Addr, error) {
|
||||
if destination.IsIP() {
|
||||
return destination.Unwrap().Addr, nil
|
||||
}
|
||||
if !destination.IsFqdn() {
|
||||
return netip.Addr{}, E.New("destination is neither IP nor FQDN")
|
||||
}
|
||||
ipAddrs, err := net.DefaultResolver.LookupIPAddr(ctx, destination.Fqdn)
|
||||
if err != nil {
|
||||
return netip.Addr{}, E.Cause(err, "resolve destination")
|
||||
}
|
||||
if len(ipAddrs) == 0 {
|
||||
return netip.Addr{}, E.New("resolved destination is empty")
|
||||
}
|
||||
resolvedAddr, ok := netip.AddrFromSlice(ipAddrs[0].IP)
|
||||
if !ok {
|
||||
return netip.Addr{}, E.New("resolved destination is invalid")
|
||||
}
|
||||
return resolvedAddr.Unmap(), nil
|
||||
}
|
||||
@@ -1,59 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
const originUDPWriteTimeout = 200 * time.Millisecond
|
||||
|
||||
type udpWriteDeadlinePacketConn struct {
|
||||
N.PacketConn
|
||||
}
|
||||
|
||||
func (c *udpWriteDeadlinePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
_ = c.PacketConn.SetWriteDeadline(time.Now().Add(originUDPWriteTimeout))
|
||||
defer func() {
|
||||
_ = c.PacketConn.SetWriteDeadline(time.Time{})
|
||||
}()
|
||||
return c.PacketConn.WritePacket(buffer, destination)
|
||||
}
|
||||
|
||||
type routedOriginPacketDialer interface {
|
||||
DialRoutePacketConnection(ctx context.Context, metadata adapter.InboundContext) (N.PacketConn, error)
|
||||
}
|
||||
|
||||
func (i *Inbound) dialWarpPacketConnection(ctx context.Context, destination netip.AddrPort) (N.PacketConn, error) {
|
||||
originDialer, ok := i.router.(routedOriginPacketDialer)
|
||||
if !ok {
|
||||
return nil, E.New("router does not support cloudflare routed packet dialing")
|
||||
}
|
||||
|
||||
warpRouting := i.configManager.Snapshot().WarpRouting
|
||||
if warpRouting.ConnectTimeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, warpRouting.ConnectTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
packetConn, err := originDialer.DialRoutePacketConnection(ctx, adapter.InboundContext{
|
||||
Inbound: i.Tag(),
|
||||
InboundType: i.Type(),
|
||||
Network: N.NetworkUDP,
|
||||
Destination: M.SocksaddrFromNetIP(destination),
|
||||
UDPConnect: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &udpWriteDeadlinePacketConn{PacketConn: packetConn}, nil
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
type captureDeadlinePacketConn struct {
|
||||
err error
|
||||
deadlines []time.Time
|
||||
}
|
||||
|
||||
func (c *captureDeadlinePacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
|
||||
buffer.Release()
|
||||
return M.Socksaddr{}, errors.New("unused")
|
||||
}
|
||||
|
||||
func (c *captureDeadlinePacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error {
|
||||
buffer.Release()
|
||||
return c.err
|
||||
}
|
||||
|
||||
func (c *captureDeadlinePacketConn) Close() error { return nil }
|
||||
func (c *captureDeadlinePacketConn) LocalAddr() net.Addr { return &net.UDPAddr{} }
|
||||
func (c *captureDeadlinePacketConn) SetDeadline(time.Time) error { return nil }
|
||||
func (c *captureDeadlinePacketConn) SetReadDeadline(time.Time) error { return nil }
|
||||
func (c *captureDeadlinePacketConn) SetWriteDeadline(t time.Time) error {
|
||||
c.deadlines = append(c.deadlines, t)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestDeadlinePacketConnWrapsWriteDeadline(t *testing.T) {
|
||||
packetConn := &captureDeadlinePacketConn{}
|
||||
wrapped := &udpWriteDeadlinePacketConn{PacketConn: packetConn}
|
||||
|
||||
if err := wrapped.WritePacket(buf.As([]byte("payload")), M.Socksaddr{}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(packetConn.deadlines) != 2 {
|
||||
t.Fatalf("expected two deadline updates, got %d", len(packetConn.deadlines))
|
||||
}
|
||||
if packetConn.deadlines[0].IsZero() {
|
||||
t.Fatal("expected first deadline to set a timeout")
|
||||
}
|
||||
if !packetConn.deadlines[1].IsZero() {
|
||||
t.Fatal("expected second deadline to clear the timeout")
|
||||
}
|
||||
}
|
||||
@@ -1,337 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type noopRouteConnectionRouter struct {
|
||||
testRouter
|
||||
}
|
||||
|
||||
func (r *noopRouteConnectionRouter) RouteConnectionEx(_ context.Context, conn net.Conn, _ adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||||
_ = conn.Close()
|
||||
onClose(nil)
|
||||
}
|
||||
|
||||
func TestOriginTLSServerName(t *testing.T) {
|
||||
t.Run("origin server name overrides host", func(t *testing.T) {
|
||||
serverName := originTLSServerName(OriginRequestConfig{
|
||||
OriginServerName: "origin.example.com",
|
||||
MatchSNIToHost: true,
|
||||
}, "request.example.com")
|
||||
if serverName != "origin.example.com" {
|
||||
t.Fatalf("expected origin.example.com, got %s", serverName)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("match sni to host strips port", func(t *testing.T) {
|
||||
serverName := originTLSServerName(OriginRequestConfig{
|
||||
MatchSNIToHost: true,
|
||||
}, "request.example.com:443")
|
||||
if serverName != "request.example.com" {
|
||||
t.Fatalf("expected request.example.com, got %s", serverName)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("match sni to host uses http host header", func(t *testing.T) {
|
||||
serverName := originTLSServerName(OriginRequestConfig{
|
||||
MatchSNIToHost: true,
|
||||
}, effectiveOriginHost(OriginRequestConfig{
|
||||
HTTPHostHeader: "origin.example.com",
|
||||
MatchSNIToHost: true,
|
||||
}, "request.example.com"))
|
||||
if serverName != "origin.example.com" {
|
||||
t.Fatalf("expected origin.example.com, got %s", serverName)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("match sni to host strips port from http host header", func(t *testing.T) {
|
||||
serverName := originTLSServerName(OriginRequestConfig{
|
||||
MatchSNIToHost: true,
|
||||
}, effectiveOriginHost(OriginRequestConfig{
|
||||
HTTPHostHeader: "origin.example.com:8443",
|
||||
MatchSNIToHost: true,
|
||||
}, "request.example.com"))
|
||||
if serverName != "origin.example.com" {
|
||||
t.Fatalf("expected origin.example.com, got %s", serverName)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("disabled match keeps empty server name", func(t *testing.T) {
|
||||
serverName := originTLSServerName(OriginRequestConfig{}, "request.example.com")
|
||||
if serverName != "" {
|
||||
t.Fatalf("expected empty server name, got %s", serverName)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewOriginTLSConfigErrorsOnMissingCAPool(t *testing.T) {
|
||||
originalBaseLoader := loadOriginCABasePool
|
||||
loadOriginCABasePool = func() (*x509.CertPool, error) {
|
||||
return x509.NewCertPool(), nil
|
||||
}
|
||||
defer func() {
|
||||
loadOriginCABasePool = originalBaseLoader
|
||||
}()
|
||||
|
||||
_, err := newOriginTLSConfig(OriginRequestConfig{
|
||||
CAPool: "/path/does/not/exist.pem",
|
||||
}, "request.example.com")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing ca pool")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewOriginTLSConfigAppendsCustomCAInsteadOfReplacingBasePool(t *testing.T) {
|
||||
basePEM, baseCert := createTestCertificatePEM(t, "base")
|
||||
customPEM, customCert := createTestCertificatePEM(t, "custom")
|
||||
|
||||
basePool := x509.NewCertPool()
|
||||
if !basePool.AppendCertsFromPEM(basePEM) {
|
||||
t.Fatal("expected base cert to append")
|
||||
}
|
||||
|
||||
originalBaseLoader := loadOriginCABasePool
|
||||
loadOriginCABasePool = func() (*x509.CertPool, error) {
|
||||
return basePool, nil
|
||||
}
|
||||
defer func() {
|
||||
loadOriginCABasePool = originalBaseLoader
|
||||
}()
|
||||
|
||||
caFile := writeTempPEM(t, customPEM)
|
||||
tlsConfig, err := newOriginTLSConfig(OriginRequestConfig{
|
||||
CAPool: caFile,
|
||||
}, "request.example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if tlsConfig.RootCAs == nil {
|
||||
t.Fatal("expected root CA pool")
|
||||
}
|
||||
subjects := tlsConfig.RootCAs.Subjects()
|
||||
if len(subjects) != 2 {
|
||||
t.Fatalf("expected 2 subjects, got %d", len(subjects))
|
||||
}
|
||||
if !containsSubject(subjects, baseCert.RawSubject) {
|
||||
t.Fatal("expected base subject to remain in pool")
|
||||
}
|
||||
if !containsSubject(subjects, customCert.RawSubject) {
|
||||
t.Fatal("expected custom subject to be appended to pool")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOriginTransportUsesProxyFromEnvironmentOnly(t *testing.T) {
|
||||
originalProxyFromEnvironment := proxyFromEnvironment
|
||||
proxyFromEnvironment = func(request *http.Request) (*url.URL, error) {
|
||||
return url.Parse("http://proxy.example.com:8080")
|
||||
}
|
||||
defer func() {
|
||||
proxyFromEnvironment = originalProxyFromEnvironment
|
||||
}()
|
||||
|
||||
inbound := &Inbound{}
|
||||
transport, cleanup, err := inbound.newDirectOriginTransport(ResolvedService{
|
||||
Kind: ResolvedServiceUnix,
|
||||
UnixPath: "/tmp/test.sock",
|
||||
OriginRequest: OriginRequestConfig{
|
||||
ProxyAddress: "127.0.0.1",
|
||||
ProxyPort: 8081,
|
||||
ProxyType: "http",
|
||||
},
|
||||
}, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
proxyURL, err := transport.Proxy(&http.Request{URL: &url.URL{Scheme: "http", Host: "example.com"}})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if proxyURL == nil || proxyURL.String() != "http://proxy.example.com:8080" {
|
||||
t.Fatalf("expected environment proxy URL, got %#v", proxyURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewDirectOriginTransportNoHappyEyeballs(t *testing.T) {
|
||||
inbound := &Inbound{}
|
||||
transport, cleanup, err := inbound.newDirectOriginTransport(ResolvedService{
|
||||
Kind: ResolvedServiceUnix,
|
||||
UnixPath: "/tmp/test.sock",
|
||||
OriginRequest: OriginRequestConfig{
|
||||
NoHappyEyeballs: true,
|
||||
},
|
||||
}, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer cleanup()
|
||||
if transport.Proxy == nil {
|
||||
t.Fatal("expected proxy function to be configured from environment")
|
||||
}
|
||||
if transport.DialContext == nil {
|
||||
t.Fatal("expected custom direct dial context")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRouterOriginTransportPropagatesTLSConfigError(t *testing.T) {
|
||||
originalBaseLoader := loadOriginCABasePool
|
||||
loadOriginCABasePool = func() (*x509.CertPool, error) {
|
||||
return x509.NewCertPool(), nil
|
||||
}
|
||||
defer func() {
|
||||
loadOriginCABasePool = originalBaseLoader
|
||||
}()
|
||||
|
||||
inbound := &Inbound{}
|
||||
_, _, err := inbound.newRouterOriginTransport(context.Background(), adapter.InboundContext{}, OriginRequestConfig{
|
||||
CAPool: "/path/does/not/exist.pem",
|
||||
}, "")
|
||||
if err == nil {
|
||||
t.Fatal("expected transport build error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRouterOriginTransportUsesCloudflaredDefaults(t *testing.T) {
|
||||
inbound := &Inbound{
|
||||
router: &noopRouteConnectionRouter{},
|
||||
}
|
||||
transport, cleanup, err := inbound.newRouterOriginTransport(context.Background(), adapter.InboundContext{}, OriginRequestConfig{}, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
if transport.ExpectContinueTimeout != time.Second {
|
||||
t.Fatalf("expected ExpectContinueTimeout=1s, got %s", transport.ExpectContinueTimeout)
|
||||
}
|
||||
if transport.DisableCompression {
|
||||
t.Fatal("expected compression to remain enabled by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOriginRequestSetsKeepAliveAndEmptyUserAgent(t *testing.T) {
|
||||
request, err := http.NewRequest(http.MethodGet, "https://example.com/path", http.NoBody)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
request = normalizeOriginRequest(ConnectionTypeHTTP, request, OriginRequestConfig{})
|
||||
if connection := request.Header.Get("Connection"); connection != "keep-alive" {
|
||||
t.Fatalf("expected keep-alive connection header, got %q", connection)
|
||||
}
|
||||
if values, exists := request.Header["User-Agent"]; !exists || len(values) != 1 || values[0] != "" {
|
||||
t.Fatalf("expected empty User-Agent header, got %#v", request.Header["User-Agent"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOriginRequestDisableChunkedEncoding(t *testing.T) {
|
||||
request, err := http.NewRequest(http.MethodPost, "https://example.com/path", strings.NewReader("payload"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
request.TransferEncoding = []string{"chunked"}
|
||||
request.Header.Set("Content-Length", "7")
|
||||
|
||||
request = normalizeOriginRequest(ConnectionTypeHTTP, request, OriginRequestConfig{
|
||||
DisableChunkedEncoding: true,
|
||||
})
|
||||
if len(request.TransferEncoding) != 2 || request.TransferEncoding[0] != "gzip" || request.TransferEncoding[1] != "deflate" {
|
||||
t.Fatalf("unexpected transfer encoding: %#v", request.TransferEncoding)
|
||||
}
|
||||
if request.ContentLength != 7 {
|
||||
t.Fatalf("expected content length 7, got %d", request.ContentLength)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOriginRequestWebsocket(t *testing.T) {
|
||||
request, err := http.NewRequest(http.MethodGet, "https://example.com/path", io.NopCloser(strings.NewReader("payload")))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
request = normalizeOriginRequest(ConnectionTypeWebsocket, request, OriginRequestConfig{})
|
||||
if connection := request.Header.Get("Connection"); connection != "Upgrade" {
|
||||
t.Fatalf("expected websocket connection header, got %q", connection)
|
||||
}
|
||||
if upgrade := request.Header.Get("Upgrade"); upgrade != "websocket" {
|
||||
t.Fatalf("expected websocket upgrade header, got %q", upgrade)
|
||||
}
|
||||
if version := request.Header.Get("Sec-Websocket-Version"); version != "13" {
|
||||
t.Fatalf("expected websocket version 13, got %q", version)
|
||||
}
|
||||
if request.ContentLength != 0 {
|
||||
t.Fatalf("expected websocket content length 0, got %d", request.ContentLength)
|
||||
}
|
||||
if request.Body != nil {
|
||||
t.Fatal("expected websocket body to be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func createTestCertificatePEM(t *testing.T, commonName string) ([]byte, *x509.Certificate) {
|
||||
t.Helper()
|
||||
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(time.Now().UnixNano()),
|
||||
Subject: pkix.Name{
|
||||
CommonName: commonName,
|
||||
},
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature,
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
der, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
certificate, err := x509.ParseCertificate(der)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}), certificate
|
||||
}
|
||||
|
||||
func writeTempPEM(t *testing.T, pemData []byte) string {
|
||||
t.Helper()
|
||||
path := t.TempDir() + "/ca.pem"
|
||||
if err := os.WriteFile(path, pemData, 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func containsSubject(subjects [][]byte, want []byte) bool {
|
||||
for _, subject := range subjects {
|
||||
if bytes.Equal(subject, want) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -1,78 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildHTTPRequestFromMetadataUsesNoBodyWhenLengthZeroWithoutChunked(t *testing.T) {
|
||||
request, err := buildHTTPRequestFromMetadata(context.Background(), &ConnectRequest{
|
||||
Dest: "http://example.com",
|
||||
Type: ConnectionTypeHTTP,
|
||||
Metadata: []Metadata{
|
||||
{Key: metadataHTTPMethod, Val: http.MethodGet},
|
||||
{Key: metadataHTTPHost, Val: "cf.host"},
|
||||
},
|
||||
}, io.NopCloser(bytes.NewBuffer(nil)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if request.Body != http.NoBody {
|
||||
t.Fatalf("expected http.NoBody, got %#v", request.Body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildHTTPRequestFromMetadataPreservesBodyWhenTransferEncodingChunked(t *testing.T) {
|
||||
request, err := buildHTTPRequestFromMetadata(context.Background(), &ConnectRequest{
|
||||
Dest: "http://example.com",
|
||||
Type: ConnectionTypeHTTP,
|
||||
Metadata: []Metadata{
|
||||
{Key: metadataHTTPMethod, Val: http.MethodPost},
|
||||
{Key: metadataHTTPHost, Val: "cf.host"},
|
||||
{Key: metadataHTTPHeader + ":Transfer-Encoding", Val: "chunked"},
|
||||
},
|
||||
}, io.NopCloser(bytes.NewBufferString("payload")))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if request.Body == http.NoBody {
|
||||
t.Fatal("expected request body to be preserved")
|
||||
}
|
||||
body, err := io.ReadAll(request.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(body) != "payload" {
|
||||
t.Fatalf("unexpected body %q", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildHTTPRequestFromMetadataPreservesBodyWhenTransferEncodingContainsChunked(t *testing.T) {
|
||||
request, err := buildHTTPRequestFromMetadata(context.Background(), &ConnectRequest{
|
||||
Dest: "http://example.com",
|
||||
Type: ConnectionTypeHTTP,
|
||||
Metadata: []Metadata{
|
||||
{Key: metadataHTTPMethod, Val: http.MethodPost},
|
||||
{Key: metadataHTTPHost, Val: "cf.host"},
|
||||
{Key: metadataHTTPHeader + ":Transfer-Encoding", Val: "gzip,chunked"},
|
||||
},
|
||||
}, io.NopCloser(bytes.NewBufferString("payload")))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if request.Body == http.NoBody {
|
||||
t.Fatal("expected request body to be preserved")
|
||||
}
|
||||
body, err := io.ReadAll(request.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(body) != "payload" {
|
||||
t.Fatalf("unexpected body %q", body)
|
||||
}
|
||||
}
|
||||
@@ -1,91 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
)
|
||||
|
||||
type trailerCaptureResponseWriter struct {
|
||||
status int
|
||||
trailers http.Header
|
||||
}
|
||||
|
||||
func (w *trailerCaptureResponseWriter) WriteResponse(responseError error, metadata []Metadata) error {
|
||||
for _, entry := range metadata {
|
||||
if entry.Key == metadataHTTPStatus {
|
||||
w.status = http.StatusOK
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *trailerCaptureResponseWriter) AddTrailer(name, value string) {
|
||||
if w.trailers == nil {
|
||||
w.trailers = make(http.Header)
|
||||
}
|
||||
w.trailers.Add(name, value)
|
||||
}
|
||||
|
||||
type captureReadWriteCloser struct {
|
||||
body []byte
|
||||
}
|
||||
|
||||
func (c *captureReadWriteCloser) Read(_ []byte) (int, error) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (c *captureReadWriteCloser) Write(p []byte) (int, error) {
|
||||
c.body = append(c.body, p...)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (c *captureReadWriteCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestRoundTripHTTPCopiesTrailers(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Add("Trailer", "X-Test-Trailer")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
w.Header().Set("X-Test-Trailer", "trailer-value")
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
transport, ok := server.Client().Transport.(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected transport type %T", server.Client().Transport)
|
||||
}
|
||||
|
||||
inboundInstance := &Inbound{
|
||||
logger: log.NewNOPFactory().NewLogger("test"),
|
||||
}
|
||||
stream := &captureReadWriteCloser{}
|
||||
respWriter := &trailerCaptureResponseWriter{}
|
||||
request := &ConnectRequest{
|
||||
Dest: server.URL,
|
||||
Type: ConnectionTypeHTTP,
|
||||
Metadata: []Metadata{
|
||||
{Key: metadataHTTPMethod, Val: http.MethodGet},
|
||||
{Key: metadataHTTPHost, Val: "example.com"},
|
||||
},
|
||||
}
|
||||
|
||||
inboundInstance.roundTripHTTP(context.Background(), stream, respWriter, request, ResolvedService{
|
||||
OriginRequest: defaultOriginRequestConfig(),
|
||||
}, transport)
|
||||
|
||||
if got := respWriter.trailers.Get("X-Test-Trailer"); got != "trailer-value" {
|
||||
t.Fatalf("expected copied trailer, got %q", got)
|
||||
}
|
||||
if string(stream.body) != "ok" {
|
||||
t.Fatalf("unexpected response body %q", stream.body)
|
||||
}
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
_ "embed"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
//go:embed cloudflare_ca.pem
|
||||
var cloudflareRootCAPEM []byte
|
||||
|
||||
func cloudflareRootCertPool() (*x509.CertPool, error) {
|
||||
pool, err := x509.SystemCertPool()
|
||||
if err != nil {
|
||||
pool = x509.NewCertPool()
|
||||
}
|
||||
if !pool.AppendCertsFromPEM(cloudflareRootCAPEM) {
|
||||
return nil, E.New("failed to parse embedded Cloudflare root CAs")
|
||||
}
|
||||
return pool, nil
|
||||
}
|
||||
@@ -1,90 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing/common"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/pipe"
|
||||
)
|
||||
|
||||
type routedPipeTCPOptions struct {
|
||||
timeout time.Duration
|
||||
onHandshake func(net.Conn)
|
||||
}
|
||||
|
||||
type routedPipeTCPConn struct {
|
||||
net.Conn
|
||||
handshakeOnce sync.Once
|
||||
onHandshake func(net.Conn)
|
||||
}
|
||||
|
||||
func (c *routedPipeTCPConn) ConnHandshakeSuccess(conn net.Conn) error {
|
||||
if c.onHandshake != nil {
|
||||
c.handshakeOnce.Do(func() {
|
||||
c.onHandshake(conn)
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *Inbound) dialRouterTCPWithMetadata(ctx context.Context, metadata adapter.InboundContext, options routedPipeTCPOptions) (net.Conn, func(), error) {
|
||||
input, output := pipe.Pipe()
|
||||
routerConn := &routedPipeTCPConn{
|
||||
Conn: output,
|
||||
onHandshake: options.onHandshake,
|
||||
}
|
||||
done := make(chan struct{})
|
||||
|
||||
routeCtx := ctx
|
||||
var cancel context.CancelFunc
|
||||
if options.timeout > 0 {
|
||||
routeCtx, cancel = context.WithTimeout(ctx, options.timeout)
|
||||
}
|
||||
|
||||
var closeOnce sync.Once
|
||||
closePipe := func() {
|
||||
closeOnce.Do(func() {
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
}
|
||||
common.Close(input, routerConn)
|
||||
})
|
||||
}
|
||||
go i.router.RouteConnectionEx(routeCtx, routerConn, metadata, N.OnceClose(func(it error) {
|
||||
closePipe()
|
||||
close(done)
|
||||
}))
|
||||
|
||||
return input, func() {
|
||||
closePipe()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func applyTCPKeepAlive(conn net.Conn, keepAlive time.Duration) error {
|
||||
if keepAlive <= 0 {
|
||||
return nil
|
||||
}
|
||||
type keepAliveConn interface {
|
||||
SetKeepAlive(bool) error
|
||||
SetKeepAlivePeriod(time.Duration) error
|
||||
}
|
||||
tcpConn, ok := conn.(keepAliveConn)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if err := tcpConn.SetKeepAlive(true); err != nil {
|
||||
return err
|
||||
}
|
||||
return tcpConn.SetKeepAlivePeriod(keepAlive)
|
||||
}
|
||||
@@ -1,165 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func TestHandleTCPStreamUsesRouteConnectionEx(t *testing.T) {
|
||||
listener := startEchoListener(t)
|
||||
defer listener.Close()
|
||||
|
||||
router := &countingRouter{}
|
||||
inboundInstance := newSpecialServiceInboundWithRouter(t, router)
|
||||
|
||||
serverSide, clientSide := net.Pipe()
|
||||
defer clientSide.Close()
|
||||
|
||||
respWriter := &fakeConnectResponseWriter{done: make(chan struct{})}
|
||||
responseDone := respWriter.done
|
||||
finished := make(chan struct{})
|
||||
go func() {
|
||||
inboundInstance.handleTCPStream(context.Background(), serverSide, respWriter, adapter.InboundContext{
|
||||
Destination: M.ParseSocksaddr(listener.Addr().String()),
|
||||
})
|
||||
close(finished)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-responseDone:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for connect response")
|
||||
}
|
||||
if respWriter.err != nil {
|
||||
t.Fatal("unexpected response error: ", respWriter.err)
|
||||
}
|
||||
|
||||
if err := clientSide.SetDeadline(time.Now().Add(time.Second)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
payload := []byte("ping")
|
||||
if _, err := clientSide.Write(payload); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
response := make([]byte, len(payload))
|
||||
if _, err := io.ReadFull(clientSide, response); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(response) != string(payload) {
|
||||
t.Fatalf("unexpected echo payload: %q", string(response))
|
||||
}
|
||||
if router.count.Load() != 1 {
|
||||
t.Fatalf("expected RouteConnectionEx to be used once, got %d", router.count.Load())
|
||||
}
|
||||
|
||||
_ = clientSide.Close()
|
||||
select {
|
||||
case <-finished:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for TCP stream handler to exit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleTCPStreamWritesOptimisticAck(t *testing.T) {
|
||||
router := &blockingRouteRouter{
|
||||
started: make(chan struct{}),
|
||||
release: make(chan struct{}),
|
||||
}
|
||||
inboundInstance := newSpecialServiceInboundWithRouter(t, router)
|
||||
|
||||
serverSide, clientSide := net.Pipe()
|
||||
defer clientSide.Close()
|
||||
|
||||
respWriter := &fakeConnectResponseWriter{done: make(chan struct{})}
|
||||
responseDone := respWriter.done
|
||||
finished := make(chan struct{})
|
||||
go func() {
|
||||
inboundInstance.handleTCPStream(context.Background(), serverSide, respWriter, adapter.InboundContext{
|
||||
Destination: M.ParseSocksaddr("127.0.0.1:443"),
|
||||
})
|
||||
close(finished)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-router.started:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for router goroutine to start")
|
||||
}
|
||||
select {
|
||||
case <-responseDone:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for optimistic connect response")
|
||||
}
|
||||
if respWriter.err != nil {
|
||||
t.Fatal("unexpected response error: ", respWriter.err)
|
||||
}
|
||||
|
||||
close(router.release)
|
||||
_ = clientSide.Close()
|
||||
select {
|
||||
case <-finished:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for TCP stream handler to exit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoutedPipeTCPConnHandshakeAppliesKeepAlive(t *testing.T) {
|
||||
left, right := net.Pipe()
|
||||
defer left.Close()
|
||||
defer right.Close()
|
||||
|
||||
remoteConn := &keepAliveTestConn{Conn: right}
|
||||
routerConn := &routedPipeTCPConn{
|
||||
Conn: left,
|
||||
onHandshake: func(conn net.Conn) {
|
||||
_ = applyTCPKeepAlive(conn, 15*time.Second)
|
||||
},
|
||||
}
|
||||
if err := routerConn.ConnHandshakeSuccess(remoteConn); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !remoteConn.enabled {
|
||||
t.Fatal("expected keepalive to be enabled")
|
||||
}
|
||||
if remoteConn.period != 15*time.Second {
|
||||
t.Fatalf("unexpected keepalive period: %s", remoteConn.period)
|
||||
}
|
||||
}
|
||||
|
||||
type blockingRouteRouter struct {
|
||||
testRouter
|
||||
started chan struct{}
|
||||
release chan struct{}
|
||||
}
|
||||
|
||||
func (r *blockingRouteRouter) RouteConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||||
close(r.started)
|
||||
<-r.release
|
||||
_ = conn.Close()
|
||||
onClose(nil)
|
||||
}
|
||||
|
||||
type keepAliveTestConn struct {
|
||||
net.Conn
|
||||
enabled bool
|
||||
period time.Duration
|
||||
}
|
||||
|
||||
func (c *keepAliveTestConn) SetKeepAlive(enabled bool) error {
|
||||
c.enabled = enabled
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *keepAliveTestConn) SetKeepAlivePeriod(period time.Duration) error {
|
||||
c.period = period
|
||||
return nil
|
||||
}
|
||||
@@ -1,251 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type blockingRPCStream struct {
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
func newBlockingRPCStream() *blockingRPCStream {
|
||||
return &blockingRPCStream{closed: make(chan struct{})}
|
||||
}
|
||||
|
||||
func (s *blockingRPCStream) Read(_ []byte) (int, error) {
|
||||
<-s.closed
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (s *blockingRPCStream) Write(p []byte) (int, error) {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (s *blockingRPCStream) Close() error {
|
||||
select {
|
||||
case <-s.closed:
|
||||
default:
|
||||
close(s.closed)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type blockingPacketDialRouter struct {
|
||||
testRouter
|
||||
entered chan struct{}
|
||||
release chan struct{}
|
||||
}
|
||||
|
||||
func (r *blockingPacketDialRouter) DialRoutePacketConnection(ctx context.Context, metadata adapter.InboundContext) (N.PacketConn, error) {
|
||||
select {
|
||||
case <-r.entered:
|
||||
default:
|
||||
close(r.entered)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-r.release:
|
||||
return newBlockingPacketConn(), nil
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func newRPCInbound(t *testing.T, router adapter.Router) *Inbound {
|
||||
t.Helper()
|
||||
|
||||
inboundInstance := newLimitedInbound(t, 0)
|
||||
inboundInstance.router = router
|
||||
return inboundInstance
|
||||
}
|
||||
|
||||
func newRPCClientPair(t *testing.T, ctx context.Context) (tunnelrpc.CloudflaredServer, io.Closer, io.Closer, net.Conn, net.Conn) {
|
||||
t.Helper()
|
||||
|
||||
serverSide, clientSide := net.Pipe()
|
||||
transport := safeTransport(clientSide)
|
||||
clientConn := newRPCClientConn(transport, ctx)
|
||||
client := tunnelrpc.CloudflaredServer{Client: clientConn.Bootstrap(ctx)}
|
||||
return client, clientConn, transport, serverSide, clientSide
|
||||
}
|
||||
|
||||
func TestServeRPCStreamRespectsContextDeadline(t *testing.T) {
|
||||
inboundInstance := newLimitedInbound(t, 0)
|
||||
stream := newBlockingRPCStream()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
ServeRPCStream(ctx, stream, inboundInstance, NewDatagramV2Muxer(inboundInstance, &captureDatagramSender{}, inboundInstance.logger), inboundInstance.logger)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected ServeRPCStream to exit after context deadline")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeV3RPCStreamRespectsContextDeadline(t *testing.T) {
|
||||
inboundInstance := newLimitedInbound(t, 0)
|
||||
stream := newBlockingRPCStream()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
ServeV3RPCStream(ctx, stream, inboundInstance, inboundInstance.logger)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected ServeV3RPCStream to exit after context deadline")
|
||||
}
|
||||
}
|
||||
|
||||
func TestV2RPCAckAllowsConcurrentDispatch(t *testing.T) {
|
||||
router := &blockingPacketDialRouter{
|
||||
entered: make(chan struct{}),
|
||||
release: make(chan struct{}),
|
||||
}
|
||||
inboundInstance := newRPCInbound(t, router)
|
||||
muxer := NewDatagramV2Muxer(inboundInstance, &captureDatagramSender{}, inboundInstance.logger)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
client, clientConn, transport, serverSide, clientSide := newRPCClientPair(t, ctx)
|
||||
defer clientConn.Close()
|
||||
defer transport.Close()
|
||||
defer clientSide.Close()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
ServeRPCStream(ctx, serverSide, inboundInstance, muxer, inboundInstance.logger)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
registerPromise := client.RegisterUdpSession(ctx, func(p tunnelrpc.SessionManager_registerUdpSession_Params) error {
|
||||
sessionID := uuid.New()
|
||||
if err := p.SetSessionId(sessionID[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := p.SetDstIp([]byte{127, 0, 0, 1}); err != nil {
|
||||
return err
|
||||
}
|
||||
p.SetDstPort(53)
|
||||
p.SetCloseAfterIdleHint(int64(time.Second))
|
||||
return p.SetTraceContext("")
|
||||
})
|
||||
|
||||
select {
|
||||
case <-router.entered:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected register RPC to enter the blocking dial")
|
||||
}
|
||||
|
||||
updateCtx, updateCancel := context.WithTimeout(ctx, 500*time.Millisecond)
|
||||
defer updateCancel()
|
||||
updatePromise := client.UpdateConfiguration(updateCtx, func(p tunnelrpc.ConfigurationManager_updateConfiguration_Params) error {
|
||||
p.SetVersion(1)
|
||||
return p.SetConfig([]byte(`{"ingress":[{"service":"http_status:503"}]}`))
|
||||
})
|
||||
if _, err := updatePromise.Result().Struct(); err != nil {
|
||||
t.Fatalf("expected concurrent update RPC to succeed, got %v", err)
|
||||
}
|
||||
|
||||
close(router.release)
|
||||
if _, err := registerPromise.Result().Struct(); err != nil {
|
||||
t.Fatalf("expected register RPC to complete, got %v", err)
|
||||
}
|
||||
|
||||
cancel()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected ServeRPCStream to exit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestV3RPCAckAllowsConcurrentDispatch(t *testing.T) {
|
||||
inboundInstance := newLimitedInbound(t, 0)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
client, clientConn, transport, serverSide, clientSide := newRPCClientPair(t, ctx)
|
||||
defer clientConn.Close()
|
||||
defer transport.Close()
|
||||
defer clientSide.Close()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
ServeV3RPCStream(ctx, serverSide, inboundInstance, inboundInstance.logger)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
inboundInstance.configManager.access.Lock()
|
||||
updatePromise := client.UpdateConfiguration(ctx, func(p tunnelrpc.ConfigurationManager_updateConfiguration_Params) error {
|
||||
p.SetVersion(1)
|
||||
return p.SetConfig([]byte(`{"ingress":[{"service":"http_status:503"}]}`))
|
||||
})
|
||||
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
registerCtx, registerCancel := context.WithTimeout(ctx, 500*time.Millisecond)
|
||||
defer registerCancel()
|
||||
registerPromise := client.RegisterUdpSession(registerCtx, func(p tunnelrpc.SessionManager_registerUdpSession_Params) error {
|
||||
sessionID := uuid.New()
|
||||
if err := p.SetSessionId(sessionID[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := p.SetDstIp([]byte{127, 0, 0, 1}); err != nil {
|
||||
return err
|
||||
}
|
||||
p.SetDstPort(53)
|
||||
p.SetCloseAfterIdleHint(int64(time.Second))
|
||||
return p.SetTraceContext("")
|
||||
})
|
||||
|
||||
registerResult, err := registerPromise.Result().Struct()
|
||||
if err != nil {
|
||||
t.Fatalf("expected concurrent v3 register RPC to succeed, got %v", err)
|
||||
}
|
||||
resultErr, err := registerResult.Err()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resultErr != errUnsupportedDatagramV3UDPRegistration.Error() {
|
||||
t.Fatalf("unexpected registration error %q", resultErr)
|
||||
}
|
||||
|
||||
inboundInstance.configManager.access.Unlock()
|
||||
if _, err := updatePromise.Result().Struct(); err != nil {
|
||||
t.Fatalf("expected update RPC to complete, got %v", err)
|
||||
}
|
||||
|
||||
cancel()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected ServeV3RPCStream to exit")
|
||||
}
|
||||
}
|
||||
@@ -1,715 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
||||
"golang.org/x/net/idna"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultHTTPConnectTimeout = 30 * time.Second
|
||||
defaultTLSTimeout = 10 * time.Second
|
||||
defaultTCPKeepAlive = 30 * time.Second
|
||||
defaultKeepAliveTimeout = 90 * time.Second
|
||||
defaultKeepAliveConnections = 100
|
||||
defaultProxyAddress = "127.0.0.1"
|
||||
defaultWarpRoutingConnectTime = 5 * time.Second
|
||||
defaultWarpRoutingTCPKeepAlive = 30 * time.Second
|
||||
)
|
||||
|
||||
type ResolvedServiceKind int
|
||||
|
||||
const (
|
||||
ResolvedServiceHTTP ResolvedServiceKind = iota
|
||||
ResolvedServiceStream
|
||||
ResolvedServiceStatus
|
||||
ResolvedServiceUnix
|
||||
ResolvedServiceUnixTLS
|
||||
ResolvedServiceBastion
|
||||
ResolvedServiceSocksProxy
|
||||
)
|
||||
|
||||
type ResolvedService struct {
|
||||
Kind ResolvedServiceKind
|
||||
Service string
|
||||
Destination M.Socksaddr
|
||||
StreamHasPort bool
|
||||
BaseURL *url.URL
|
||||
UnixPath string
|
||||
StatusCode int
|
||||
SocksPolicy *ipRulePolicy
|
||||
OriginRequest OriginRequestConfig
|
||||
}
|
||||
|
||||
func (s ResolvedService) RouterControlled() bool {
|
||||
return s.Kind == ResolvedServiceHTTP || s.Kind == ResolvedServiceStream
|
||||
}
|
||||
|
||||
func (s ResolvedService) BuildRequestURL(requestURL string) (string, error) {
|
||||
switch s.Kind {
|
||||
case ResolvedServiceHTTP, ResolvedServiceUnix, ResolvedServiceUnixTLS:
|
||||
requestParsed, err := url.Parse(requestURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
originURL := *s.BaseURL
|
||||
originURL.Path = requestParsed.Path
|
||||
originURL.RawPath = requestParsed.RawPath
|
||||
originURL.RawQuery = requestParsed.RawQuery
|
||||
originURL.Fragment = requestParsed.Fragment
|
||||
return originURL.String(), nil
|
||||
default:
|
||||
return requestURL, nil
|
||||
}
|
||||
}
|
||||
|
||||
func canonicalizeHTTPOriginURL(parsedURL *url.URL) *url.URL {
|
||||
if parsedURL == nil {
|
||||
return nil
|
||||
}
|
||||
canonicalURL := *parsedURL
|
||||
switch canonicalURL.Scheme {
|
||||
case "ws":
|
||||
canonicalURL.Scheme = "http"
|
||||
case "wss":
|
||||
canonicalURL.Scheme = "https"
|
||||
}
|
||||
return &canonicalURL
|
||||
}
|
||||
|
||||
func isHTTPServiceScheme(scheme string) bool {
|
||||
switch scheme {
|
||||
case "http", "https", "ws", "wss":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
type compiledIngressRule struct {
|
||||
Hostname string
|
||||
PunycodeHostname string
|
||||
Path *regexp.Regexp
|
||||
Service ResolvedService
|
||||
}
|
||||
|
||||
type RuntimeConfig struct {
|
||||
Ingress []compiledIngressRule
|
||||
OriginRequest OriginRequestConfig
|
||||
WarpRouting WarpRoutingConfig
|
||||
}
|
||||
|
||||
type OriginRequestConfig struct {
|
||||
ConnectTimeout time.Duration
|
||||
TLSTimeout time.Duration
|
||||
TCPKeepAlive time.Duration
|
||||
NoHappyEyeballs bool
|
||||
KeepAliveTimeout time.Duration
|
||||
KeepAliveConnections int
|
||||
HTTPHostHeader string
|
||||
OriginServerName string
|
||||
MatchSNIToHost bool
|
||||
CAPool string
|
||||
NoTLSVerify bool
|
||||
DisableChunkedEncoding bool
|
||||
BastionMode bool
|
||||
ProxyAddress string
|
||||
ProxyPort uint
|
||||
ProxyType string
|
||||
IPRules []IPRule
|
||||
HTTP2Origin bool
|
||||
Access AccessConfig
|
||||
}
|
||||
|
||||
type AccessConfig struct {
|
||||
Required bool
|
||||
TeamName string
|
||||
AudTag []string
|
||||
Environment string
|
||||
}
|
||||
|
||||
type IPRule struct {
|
||||
Prefix string
|
||||
Ports []int
|
||||
Allow bool
|
||||
}
|
||||
|
||||
type WarpRoutingConfig struct {
|
||||
ConnectTimeout time.Duration
|
||||
MaxActiveFlows uint64
|
||||
TCPKeepAlive time.Duration
|
||||
}
|
||||
|
||||
type ConfigUpdateResult struct {
|
||||
LastAppliedVersion int32
|
||||
Err error
|
||||
}
|
||||
|
||||
type ConfigManager struct {
|
||||
access sync.RWMutex
|
||||
currentVersion int32
|
||||
activeConfig RuntimeConfig
|
||||
}
|
||||
|
||||
func NewConfigManager() (*ConfigManager, error) {
|
||||
config, err := defaultRuntimeConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ConfigManager{
|
||||
currentVersion: -1,
|
||||
activeConfig: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *ConfigManager) Snapshot() RuntimeConfig {
|
||||
m.access.RLock()
|
||||
defer m.access.RUnlock()
|
||||
return m.activeConfig
|
||||
}
|
||||
|
||||
func (m *ConfigManager) CurrentVersion() int32 {
|
||||
m.access.RLock()
|
||||
defer m.access.RUnlock()
|
||||
return m.currentVersion
|
||||
}
|
||||
|
||||
func (m *ConfigManager) Apply(version int32, raw []byte) ConfigUpdateResult {
|
||||
m.access.Lock()
|
||||
defer m.access.Unlock()
|
||||
|
||||
if version <= m.currentVersion {
|
||||
return ConfigUpdateResult{LastAppliedVersion: m.currentVersion}
|
||||
}
|
||||
|
||||
config, err := buildRemoteRuntimeConfig(raw)
|
||||
if err != nil {
|
||||
return ConfigUpdateResult{
|
||||
LastAppliedVersion: m.currentVersion,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
m.activeConfig = config
|
||||
m.currentVersion = version
|
||||
return ConfigUpdateResult{LastAppliedVersion: m.currentVersion}
|
||||
}
|
||||
|
||||
func (m *ConfigManager) Resolve(hostname, path string) (ResolvedService, bool) {
|
||||
m.access.RLock()
|
||||
defer m.access.RUnlock()
|
||||
return m.activeConfig.Resolve(hostname, path)
|
||||
}
|
||||
|
||||
func (c RuntimeConfig) Resolve(hostname, path string) (ResolvedService, bool) {
|
||||
host := stripPort(hostname)
|
||||
for _, rule := range c.Ingress {
|
||||
if !matchIngressRule(rule, host, path) {
|
||||
continue
|
||||
}
|
||||
return rule.Service, true
|
||||
}
|
||||
return ResolvedService{}, false
|
||||
}
|
||||
|
||||
func matchIngressRule(rule compiledIngressRule, hostname, path string) bool {
|
||||
hostMatch := rule.Hostname == "" || rule.Hostname == "*" || matchIngressHost(rule.Hostname, hostname)
|
||||
if !hostMatch && rule.PunycodeHostname != "" {
|
||||
hostMatch = matchIngressHost(rule.PunycodeHostname, hostname)
|
||||
}
|
||||
if !hostMatch {
|
||||
return false
|
||||
}
|
||||
return rule.Path == nil || rule.Path.MatchString(path)
|
||||
}
|
||||
|
||||
func matchIngressHost(pattern, hostname string) bool {
|
||||
if pattern == hostname {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(pattern, "*.") {
|
||||
return strings.HasSuffix(hostname, strings.TrimPrefix(pattern, "*"))
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func defaultRuntimeConfig() (RuntimeConfig, error) {
|
||||
defaultOriginRequest := defaultOriginRequestConfig()
|
||||
compiledRules, err := compileIngressRules(defaultOriginRequest, nil)
|
||||
if err != nil {
|
||||
return RuntimeConfig{}, err
|
||||
}
|
||||
return RuntimeConfig{
|
||||
Ingress: compiledRules,
|
||||
OriginRequest: defaultOriginRequest,
|
||||
WarpRouting: WarpRoutingConfig{
|
||||
ConnectTimeout: defaultWarpRoutingConnectTime,
|
||||
TCPKeepAlive: defaultWarpRoutingTCPKeepAlive,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildRemoteRuntimeConfig(raw []byte) (RuntimeConfig, error) {
|
||||
var remote remoteConfigJSON
|
||||
if err := json.Unmarshal(raw, &remote); err != nil {
|
||||
return RuntimeConfig{}, E.Cause(err, "decode remote config")
|
||||
}
|
||||
defaultOriginRequest := originRequestFromRemote(remote.OriginRequest)
|
||||
warpRouting := warpRoutingFromRemote(remote.WarpRouting)
|
||||
var ingressRules []localIngressRule
|
||||
for _, rule := range remote.Ingress {
|
||||
ingressRules = append(ingressRules, localIngressRule{
|
||||
Hostname: rule.Hostname,
|
||||
Path: rule.Path,
|
||||
Service: rule.Service,
|
||||
OriginRequest: mergeRemoteOriginRequest(defaultOriginRequest, rule.OriginRequest),
|
||||
})
|
||||
}
|
||||
compiledRules, err := compileIngressRules(defaultOriginRequest, ingressRules)
|
||||
if err != nil {
|
||||
return RuntimeConfig{}, err
|
||||
}
|
||||
return RuntimeConfig{
|
||||
Ingress: compiledRules,
|
||||
OriginRequest: defaultOriginRequest,
|
||||
WarpRouting: warpRouting,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type localIngressRule struct {
|
||||
Hostname string
|
||||
Path string
|
||||
Service string
|
||||
OriginRequest OriginRequestConfig
|
||||
}
|
||||
|
||||
type remoteConfigJSON struct {
|
||||
OriginRequest remoteOriginRequestJSON `json:"originRequest"`
|
||||
Ingress []remoteIngressRuleJSON `json:"ingress"`
|
||||
WarpRouting remoteWarpRoutingJSON `json:"warp-routing"`
|
||||
}
|
||||
|
||||
type remoteIngressRuleJSON struct {
|
||||
Hostname string `json:"hostname,omitempty"`
|
||||
Path string `json:"path,omitempty"`
|
||||
Service string `json:"service"`
|
||||
OriginRequest remoteOriginRequestJSON `json:"originRequest,omitempty"`
|
||||
}
|
||||
|
||||
type remoteOriginRequestJSON struct {
|
||||
ConnectTimeout int64 `json:"connectTimeout,omitempty"`
|
||||
TLSTimeout int64 `json:"tlsTimeout,omitempty"`
|
||||
TCPKeepAlive int64 `json:"tcpKeepAlive,omitempty"`
|
||||
NoHappyEyeballs *bool `json:"noHappyEyeballs,omitempty"`
|
||||
KeepAliveTimeout int64 `json:"keepAliveTimeout,omitempty"`
|
||||
KeepAliveConnections *int `json:"keepAliveConnections,omitempty"`
|
||||
HTTPHostHeader string `json:"httpHostHeader,omitempty"`
|
||||
OriginServerName string `json:"originServerName,omitempty"`
|
||||
MatchSNIToHost *bool `json:"matchSNIToHost,omitempty"`
|
||||
CAPool string `json:"caPool,omitempty"`
|
||||
NoTLSVerify *bool `json:"noTLSVerify,omitempty"`
|
||||
DisableChunkedEncoding *bool `json:"disableChunkedEncoding,omitempty"`
|
||||
BastionMode *bool `json:"bastionMode,omitempty"`
|
||||
ProxyAddress string `json:"proxyAddress,omitempty"`
|
||||
ProxyPort *uint `json:"proxyPort,omitempty"`
|
||||
ProxyType string `json:"proxyType,omitempty"`
|
||||
IPRules []remoteIPRuleJSON `json:"ipRules,omitempty"`
|
||||
HTTP2Origin *bool `json:"http2Origin,omitempty"`
|
||||
Access *remoteAccessJSON `json:"access,omitempty"`
|
||||
}
|
||||
|
||||
type remoteAccessJSON struct {
|
||||
Required bool `json:"required,omitempty"`
|
||||
TeamName string `json:"teamName,omitempty"`
|
||||
AudTag []string `json:"audTag,omitempty"`
|
||||
Environment string `json:"environment,omitempty"`
|
||||
}
|
||||
|
||||
type remoteIPRuleJSON struct {
|
||||
Prefix string `json:"prefix,omitempty"`
|
||||
Ports []int `json:"ports,omitempty"`
|
||||
Allow bool `json:"allow,omitempty"`
|
||||
}
|
||||
|
||||
type remoteWarpRoutingJSON struct {
|
||||
ConnectTimeout int64 `json:"connectTimeout,omitempty"`
|
||||
MaxActiveFlows uint64 `json:"maxActiveFlows,omitempty"`
|
||||
TCPKeepAlive int64 `json:"tcpKeepAlive,omitempty"`
|
||||
}
|
||||
|
||||
func compileIngressRules(defaultOriginRequest OriginRequestConfig, rawRules []localIngressRule) ([]compiledIngressRule, error) {
|
||||
if len(rawRules) == 0 {
|
||||
rawRules = []localIngressRule{{
|
||||
Service: "http_status:503",
|
||||
OriginRequest: defaultOriginRequest,
|
||||
}}
|
||||
}
|
||||
if !isCatchAllRule(rawRules[len(rawRules)-1].Hostname, rawRules[len(rawRules)-1].Path) {
|
||||
return nil, E.New("the last ingress rule must be a catch-all rule")
|
||||
}
|
||||
|
||||
compiled := make([]compiledIngressRule, 0, len(rawRules))
|
||||
for index, rule := range rawRules {
|
||||
if err := validateHostname(rule.Hostname, index == len(rawRules)-1); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validateAccessConfiguration(rule.OriginRequest.Access); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
service, err := parseResolvedService(rule.Service, rule.OriginRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var pathPattern *regexp.Regexp
|
||||
if rule.Path != "" {
|
||||
pathPattern, err = regexp.Compile(rule.Path)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "compile ingress path regex")
|
||||
}
|
||||
}
|
||||
punycode := ""
|
||||
if rule.Hostname != "" && rule.Hostname != "*" {
|
||||
punycodeValue, err := idna.Lookup.ToASCII(rule.Hostname)
|
||||
if err == nil && punycodeValue != rule.Hostname {
|
||||
punycode = punycodeValue
|
||||
}
|
||||
}
|
||||
compiled = append(compiled, compiledIngressRule{
|
||||
Hostname: rule.Hostname,
|
||||
PunycodeHostname: punycode,
|
||||
Path: pathPattern,
|
||||
Service: service,
|
||||
})
|
||||
}
|
||||
return compiled, nil
|
||||
}
|
||||
|
||||
func parseResolvedService(rawService string, originRequest OriginRequestConfig) (ResolvedService, error) {
|
||||
switch {
|
||||
case rawService == "":
|
||||
if originRequest.BastionMode {
|
||||
return ResolvedService{
|
||||
Kind: ResolvedServiceBastion,
|
||||
Service: "bastion",
|
||||
OriginRequest: originRequest,
|
||||
}, nil
|
||||
}
|
||||
return ResolvedService{}, E.New("missing ingress service")
|
||||
case strings.HasPrefix(rawService, "http_status:"):
|
||||
statusCode, err := strconv.Atoi(strings.TrimPrefix(rawService, "http_status:"))
|
||||
if err != nil {
|
||||
return ResolvedService{}, E.Cause(err, "parse http_status service")
|
||||
}
|
||||
if statusCode < 100 || statusCode > 999 {
|
||||
return ResolvedService{}, E.New("invalid http_status code: ", statusCode)
|
||||
}
|
||||
return ResolvedService{
|
||||
Kind: ResolvedServiceStatus,
|
||||
Service: rawService,
|
||||
StatusCode: statusCode,
|
||||
OriginRequest: originRequest,
|
||||
}, nil
|
||||
case rawService == "hello_world" || rawService == "hello-world":
|
||||
return ResolvedService{}, E.New("unsupported ingress service: hello_world")
|
||||
case rawService == "bastion":
|
||||
return ResolvedService{
|
||||
Kind: ResolvedServiceBastion,
|
||||
Service: rawService,
|
||||
OriginRequest: originRequest,
|
||||
}, nil
|
||||
case rawService == "socks-proxy":
|
||||
policy, err := newIPRulePolicy(originRequest.IPRules)
|
||||
if err != nil {
|
||||
return ResolvedService{}, E.Cause(err, "compile socks-proxy ip rules")
|
||||
}
|
||||
return ResolvedService{
|
||||
Kind: ResolvedServiceSocksProxy,
|
||||
Service: rawService,
|
||||
SocksPolicy: policy,
|
||||
OriginRequest: originRequest,
|
||||
}, nil
|
||||
case strings.HasPrefix(rawService, "unix:"):
|
||||
return ResolvedService{
|
||||
Kind: ResolvedServiceUnix,
|
||||
Service: rawService,
|
||||
UnixPath: strings.TrimPrefix(rawService, "unix:"),
|
||||
BaseURL: &url.URL{Scheme: "http", Host: "localhost"},
|
||||
OriginRequest: originRequest,
|
||||
}, nil
|
||||
case strings.HasPrefix(rawService, "unix+tls:"):
|
||||
return ResolvedService{
|
||||
Kind: ResolvedServiceUnixTLS,
|
||||
Service: rawService,
|
||||
UnixPath: strings.TrimPrefix(rawService, "unix+tls:"),
|
||||
BaseURL: &url.URL{Scheme: "https", Host: "localhost"},
|
||||
OriginRequest: originRequest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(rawService)
|
||||
if err != nil {
|
||||
return ResolvedService{}, E.Cause(err, "parse ingress service URL")
|
||||
}
|
||||
if parsedURL.Scheme == "" || parsedURL.Hostname() == "" {
|
||||
return ResolvedService{}, E.New("ingress service must include scheme and hostname: ", rawService)
|
||||
}
|
||||
if parsedURL.Path != "" {
|
||||
return ResolvedService{}, E.New("ingress service cannot include a path: ", rawService)
|
||||
}
|
||||
|
||||
if isHTTPServiceScheme(parsedURL.Scheme) {
|
||||
return ResolvedService{
|
||||
Kind: ResolvedServiceHTTP,
|
||||
Service: rawService,
|
||||
Destination: parseHTTPServiceDestination(parsedURL),
|
||||
BaseURL: canonicalizeHTTPOriginURL(parsedURL),
|
||||
OriginRequest: originRequest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
destination, hasPort := parseStreamServiceDestination(parsedURL)
|
||||
return ResolvedService{
|
||||
Kind: ResolvedServiceStream,
|
||||
Service: rawService,
|
||||
Destination: destination,
|
||||
StreamHasPort: hasPort,
|
||||
BaseURL: parsedURL,
|
||||
OriginRequest: originRequest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseHTTPServiceDestination(parsedURL *url.URL) M.Socksaddr {
|
||||
host := parsedURL.Hostname()
|
||||
port := parsedURL.Port()
|
||||
if port == "" {
|
||||
switch parsedURL.Scheme {
|
||||
case "https", "wss":
|
||||
port = "443"
|
||||
default:
|
||||
port = "80"
|
||||
}
|
||||
}
|
||||
return M.ParseSocksaddr(net.JoinHostPort(host, port))
|
||||
}
|
||||
|
||||
func parseStreamServiceDestination(parsedURL *url.URL) (M.Socksaddr, bool) {
|
||||
host := parsedURL.Hostname()
|
||||
port := parsedURL.Port()
|
||||
if port == "" {
|
||||
switch parsedURL.Scheme {
|
||||
case "ssh":
|
||||
port = "22"
|
||||
case "rdp":
|
||||
port = "3389"
|
||||
case "smb":
|
||||
port = "445"
|
||||
case "tcp":
|
||||
port = "7864"
|
||||
default:
|
||||
return M.ParseSocksaddrHostPort(host, 0), false
|
||||
}
|
||||
}
|
||||
return M.ParseSocksaddr(net.JoinHostPort(host, port)), true
|
||||
}
|
||||
|
||||
func validateHostname(hostname string, isLast bool) error {
|
||||
if hostname == "" || hostname == "*" {
|
||||
if !isLast {
|
||||
return E.New("only the last ingress rule may be a catch-all rule")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if strings.Count(hostname, "*") > 1 || (strings.Contains(hostname, "*") && !strings.HasPrefix(hostname, "*.")) {
|
||||
return E.New("hostname wildcard must be in the form *.example.com")
|
||||
}
|
||||
if stripPort(hostname) != hostname {
|
||||
return E.New("ingress hostname cannot contain a port")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isCatchAllRule(hostname, path string) bool {
|
||||
return (hostname == "" || hostname == "*") && path == ""
|
||||
}
|
||||
|
||||
func stripPort(hostname string) string {
|
||||
if host, _, err := net.SplitHostPort(hostname); err == nil {
|
||||
return host
|
||||
}
|
||||
return hostname
|
||||
}
|
||||
|
||||
func defaultOriginRequestConfig() OriginRequestConfig {
|
||||
return OriginRequestConfig{
|
||||
ConnectTimeout: defaultHTTPConnectTimeout,
|
||||
TLSTimeout: defaultTLSTimeout,
|
||||
TCPKeepAlive: defaultTCPKeepAlive,
|
||||
KeepAliveTimeout: defaultKeepAliveTimeout,
|
||||
KeepAliveConnections: defaultKeepAliveConnections,
|
||||
ProxyAddress: defaultProxyAddress,
|
||||
}
|
||||
}
|
||||
|
||||
func originRequestFromRemote(input remoteOriginRequestJSON) OriginRequestConfig {
|
||||
config := defaultOriginRequestConfig()
|
||||
if input.ConnectTimeout != 0 {
|
||||
config.ConnectTimeout = time.Duration(input.ConnectTimeout) * time.Second
|
||||
}
|
||||
if input.TLSTimeout != 0 {
|
||||
config.TLSTimeout = time.Duration(input.TLSTimeout) * time.Second
|
||||
}
|
||||
if input.TCPKeepAlive != 0 {
|
||||
config.TCPKeepAlive = time.Duration(input.TCPKeepAlive) * time.Second
|
||||
}
|
||||
if input.KeepAliveTimeout != 0 {
|
||||
config.KeepAliveTimeout = time.Duration(input.KeepAliveTimeout) * time.Second
|
||||
}
|
||||
if input.KeepAliveConnections != nil {
|
||||
config.KeepAliveConnections = *input.KeepAliveConnections
|
||||
}
|
||||
if input.NoHappyEyeballs != nil {
|
||||
config.NoHappyEyeballs = *input.NoHappyEyeballs
|
||||
}
|
||||
config.HTTPHostHeader = input.HTTPHostHeader
|
||||
config.OriginServerName = input.OriginServerName
|
||||
if input.MatchSNIToHost != nil {
|
||||
config.MatchSNIToHost = *input.MatchSNIToHost
|
||||
}
|
||||
config.CAPool = input.CAPool
|
||||
if input.NoTLSVerify != nil {
|
||||
config.NoTLSVerify = *input.NoTLSVerify
|
||||
}
|
||||
if input.DisableChunkedEncoding != nil {
|
||||
config.DisableChunkedEncoding = *input.DisableChunkedEncoding
|
||||
}
|
||||
if input.BastionMode != nil {
|
||||
config.BastionMode = *input.BastionMode
|
||||
}
|
||||
if input.ProxyAddress != "" {
|
||||
config.ProxyAddress = input.ProxyAddress
|
||||
}
|
||||
if input.ProxyPort != nil {
|
||||
config.ProxyPort = *input.ProxyPort
|
||||
}
|
||||
config.ProxyType = input.ProxyType
|
||||
if input.HTTP2Origin != nil {
|
||||
config.HTTP2Origin = *input.HTTP2Origin
|
||||
}
|
||||
if input.Access != nil {
|
||||
config.Access = AccessConfig{
|
||||
Required: input.Access.Required,
|
||||
TeamName: input.Access.TeamName,
|
||||
AudTag: append([]string(nil), input.Access.AudTag...),
|
||||
Environment: input.Access.Environment,
|
||||
}
|
||||
}
|
||||
for _, rule := range input.IPRules {
|
||||
config.IPRules = append(config.IPRules, IPRule{
|
||||
Prefix: rule.Prefix,
|
||||
Ports: append([]int(nil), rule.Ports...),
|
||||
Allow: rule.Allow,
|
||||
})
|
||||
}
|
||||
return config
|
||||
}
|
||||
|
||||
func mergeRemoteOriginRequest(base OriginRequestConfig, override remoteOriginRequestJSON) OriginRequestConfig {
|
||||
result := base
|
||||
if override.ConnectTimeout != 0 {
|
||||
result.ConnectTimeout = time.Duration(override.ConnectTimeout) * time.Second
|
||||
}
|
||||
if override.TLSTimeout != 0 {
|
||||
result.TLSTimeout = time.Duration(override.TLSTimeout) * time.Second
|
||||
}
|
||||
if override.TCPKeepAlive != 0 {
|
||||
result.TCPKeepAlive = time.Duration(override.TCPKeepAlive) * time.Second
|
||||
}
|
||||
if override.NoHappyEyeballs != nil {
|
||||
result.NoHappyEyeballs = *override.NoHappyEyeballs
|
||||
}
|
||||
if override.KeepAliveTimeout != 0 {
|
||||
result.KeepAliveTimeout = time.Duration(override.KeepAliveTimeout) * time.Second
|
||||
}
|
||||
if override.KeepAliveConnections != nil {
|
||||
result.KeepAliveConnections = *override.KeepAliveConnections
|
||||
}
|
||||
if override.HTTPHostHeader != "" {
|
||||
result.HTTPHostHeader = override.HTTPHostHeader
|
||||
}
|
||||
if override.OriginServerName != "" {
|
||||
result.OriginServerName = override.OriginServerName
|
||||
}
|
||||
if override.MatchSNIToHost != nil {
|
||||
result.MatchSNIToHost = *override.MatchSNIToHost
|
||||
}
|
||||
if override.CAPool != "" {
|
||||
result.CAPool = override.CAPool
|
||||
}
|
||||
if override.NoTLSVerify != nil {
|
||||
result.NoTLSVerify = *override.NoTLSVerify
|
||||
}
|
||||
if override.DisableChunkedEncoding != nil {
|
||||
result.DisableChunkedEncoding = *override.DisableChunkedEncoding
|
||||
}
|
||||
if override.BastionMode != nil {
|
||||
result.BastionMode = *override.BastionMode
|
||||
}
|
||||
if override.ProxyAddress != "" {
|
||||
result.ProxyAddress = override.ProxyAddress
|
||||
}
|
||||
if override.ProxyPort != nil {
|
||||
result.ProxyPort = *override.ProxyPort
|
||||
}
|
||||
if override.ProxyType != "" {
|
||||
result.ProxyType = override.ProxyType
|
||||
}
|
||||
if len(override.IPRules) > 0 {
|
||||
result.IPRules = nil
|
||||
for _, rule := range override.IPRules {
|
||||
result.IPRules = append(result.IPRules, IPRule{
|
||||
Prefix: rule.Prefix,
|
||||
Ports: append([]int(nil), rule.Ports...),
|
||||
Allow: rule.Allow,
|
||||
})
|
||||
}
|
||||
}
|
||||
if override.HTTP2Origin != nil {
|
||||
result.HTTP2Origin = *override.HTTP2Origin
|
||||
}
|
||||
if override.Access != nil {
|
||||
result.Access = AccessConfig{
|
||||
Required: override.Access.Required,
|
||||
TeamName: override.Access.TeamName,
|
||||
AudTag: append([]string(nil), override.Access.AudTag...),
|
||||
Environment: override.Access.Environment,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func warpRoutingFromRemote(input remoteWarpRoutingJSON) WarpRoutingConfig {
|
||||
config := WarpRoutingConfig{
|
||||
ConnectTimeout: defaultWarpRoutingConnectTime,
|
||||
TCPKeepAlive: defaultWarpRoutingTCPKeepAlive,
|
||||
MaxActiveFlows: input.MaxActiveFlows,
|
||||
}
|
||||
if input.ConnectTimeout != 0 {
|
||||
config.ConnectTimeout = time.Duration(input.ConnectTimeout) * time.Second
|
||||
}
|
||||
if input.TCPKeepAlive != 0 {
|
||||
config.TCPKeepAlive = time.Duration(input.TCPKeepAlive) * time.Second
|
||||
}
|
||||
return config
|
||||
}
|
||||
@@ -1,63 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
|
||||
capnp "zombiezen.com/go/capnproto2"
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
)
|
||||
|
||||
const (
|
||||
safeTransportMaxRetries = 3
|
||||
safeTransportRetryInterval = 500 * time.Millisecond
|
||||
)
|
||||
|
||||
type safeReadWriteCloser struct {
|
||||
io.ReadWriteCloser
|
||||
retries int
|
||||
}
|
||||
|
||||
func (s *safeReadWriteCloser) Read(p []byte) (int, error) {
|
||||
n, err := s.ReadWriteCloser.Read(p)
|
||||
if n == 0 && err != nil && isTemporaryError(err) {
|
||||
if s.retries >= safeTransportMaxRetries {
|
||||
return 0, E.Cause(err, "read capnproto transport after multiple temporary errors")
|
||||
}
|
||||
s.retries++
|
||||
time.Sleep(safeTransportRetryInterval)
|
||||
return n, err
|
||||
}
|
||||
if err == nil {
|
||||
s.retries = 0
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func isTemporaryError(err error) bool {
|
||||
type temporary interface{ Temporary() bool }
|
||||
t, ok := err.(temporary)
|
||||
return ok && t.Temporary()
|
||||
}
|
||||
|
||||
func safeTransport(stream io.ReadWriteCloser) rpc.Transport {
|
||||
return rpc.StreamTransport(&safeReadWriteCloser{ReadWriteCloser: stream})
|
||||
}
|
||||
|
||||
type noopCapnpLogger struct{}
|
||||
|
||||
func (noopCapnpLogger) Infof(ctx context.Context, format string, args ...interface{}) {}
|
||||
func (noopCapnpLogger) Errorf(ctx context.Context, format string, args ...interface{}) {}
|
||||
|
||||
func newRPCClientConn(transport rpc.Transport, ctx context.Context) *rpc.Conn {
|
||||
return rpc.NewConn(transport, rpc.ConnLog(noopCapnpLogger{}))
|
||||
}
|
||||
|
||||
func newRPCServerConn(transport rpc.Transport, client capnp.Client) *rpc.Conn {
|
||||
return rpc.NewConn(transport, rpc.MainInterface(client), rpc.ConnLog(noopCapnpLogger{}))
|
||||
}
|
||||
@@ -1,355 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/transport/v2raywebsocket"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/ws"
|
||||
)
|
||||
|
||||
var wsAcceptGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
||||
|
||||
const (
|
||||
socksReplySuccess = 0
|
||||
socksReplyRuleFailure = 2
|
||||
socksReplyNetworkUnreachable = 3
|
||||
socksReplyHostUnreachable = 4
|
||||
socksReplyConnectionRefused = 5
|
||||
socksReplyCommandNotSupported = 7
|
||||
)
|
||||
|
||||
func (i *Inbound) handleBastionStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
|
||||
destination, err := resolveBastionDestination(request)
|
||||
if err != nil {
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
i.handleRouterBackedStream(ctx, stream, respWriter, request, M.ParseSocksaddr(destination), service.OriginRequest.ProxyType)
|
||||
}
|
||||
|
||||
func (i *Inbound) handleStreamService(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
|
||||
if !service.StreamHasPort {
|
||||
respWriter.WriteResponse(E.New("address ", streamServiceHostname(service), ": missing port in address"), nil)
|
||||
return
|
||||
}
|
||||
i.handleRouterBackedStream(ctx, stream, respWriter, request, service.Destination, service.OriginRequest.ProxyType)
|
||||
}
|
||||
|
||||
func (i *Inbound) handleRouterBackedStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, destination M.Socksaddr, proxyType string) {
|
||||
targetConn, cleanup, err := i.dialRouterTCP(ctx, destination)
|
||||
if err != nil {
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
err = respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusSwitchingProtocols, websocketResponseHeaders(request)))
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "write bastion websocket response: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
wsConn := v2raywebsocket.NewConn(newStreamConn(stream), nil, ws.StateServerSide)
|
||||
defer wsConn.Close()
|
||||
if isSocksProxyType(proxyType) {
|
||||
if err := serveFixedSocksStream(ctx, wsConn, targetConn); err != nil && !E.IsClosedOrCanceled(err) {
|
||||
i.logger.DebugContext(ctx, "socks-over-websocket stream closed: ", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
_ = bufio.CopyConn(ctx, wsConn, targetConn)
|
||||
}
|
||||
|
||||
func (i *Inbound) handleSocksProxyStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
|
||||
err := respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusSwitchingProtocols, websocketResponseHeaders(request)))
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "write socks-proxy websocket response: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
wsConn := v2raywebsocket.NewConn(newStreamConn(stream), nil, ws.StateServerSide)
|
||||
defer wsConn.Close()
|
||||
if err := i.serveSocksProxy(ctx, wsConn, service.SocksPolicy); err != nil && !E.IsClosedOrCanceled(err) {
|
||||
i.logger.DebugContext(ctx, "socks-proxy stream closed: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
func resolveBastionDestination(request *ConnectRequest) (string, error) {
|
||||
headerValue := requestHeaderValue(request, "Cf-Access-Jump-Destination")
|
||||
if headerValue == "" {
|
||||
return "", E.New("missing Cf-Access-Jump-Destination header")
|
||||
}
|
||||
if parsed, err := url.Parse(headerValue); err == nil && parsed.Host != "" {
|
||||
headerValue = parsed.Host
|
||||
}
|
||||
return strings.SplitN(headerValue, "/", 2)[0], nil
|
||||
}
|
||||
|
||||
func websocketResponseHeaders(request *ConnectRequest) http.Header {
|
||||
header := http.Header{}
|
||||
header.Set("Connection", "Upgrade")
|
||||
header.Set("Upgrade", "websocket")
|
||||
secKey := requestHeaderValue(request, "Sec-WebSocket-Key")
|
||||
if secKey != "" {
|
||||
sum := sha1.Sum(append([]byte(secKey), wsAcceptGUID...))
|
||||
header.Set("Sec-WebSocket-Accept", base64.StdEncoding.EncodeToString(sum[:]))
|
||||
}
|
||||
return header
|
||||
}
|
||||
|
||||
func isSocksProxyType(proxyType string) bool {
|
||||
lower := strings.ToLower(strings.TrimSpace(proxyType))
|
||||
return lower == "socks" || lower == "socks5"
|
||||
}
|
||||
|
||||
func serveFixedSocksStream(ctx context.Context, conn net.Conn, targetConn net.Conn) error {
|
||||
version := make([]byte, 1)
|
||||
if _, err := io.ReadFull(conn, version); err != nil {
|
||||
return err
|
||||
}
|
||||
if version[0] != 5 {
|
||||
return E.New("unsupported SOCKS version: ", version[0])
|
||||
}
|
||||
|
||||
methodCount := make([]byte, 1)
|
||||
if _, err := io.ReadFull(conn, methodCount); err != nil {
|
||||
return err
|
||||
}
|
||||
methods := make([]byte, int(methodCount[0]))
|
||||
if _, err := io.ReadFull(conn, methods); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var supportsNoAuth bool
|
||||
for _, method := range methods {
|
||||
if method == 0 {
|
||||
supportsNoAuth = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !supportsNoAuth {
|
||||
_, err := conn.Write([]byte{5, 255})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return E.New("unknown authentication type")
|
||||
}
|
||||
if _, err := conn.Write([]byte{5, 0}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
requestHeader := make([]byte, 4)
|
||||
if _, err := io.ReadFull(conn, requestHeader); err != nil {
|
||||
return err
|
||||
}
|
||||
if requestHeader[0] != 5 {
|
||||
return E.New("unsupported SOCKS request version: ", requestHeader[0])
|
||||
}
|
||||
if requestHeader[1] != 1 {
|
||||
_ = writeSocksReply(conn, socksReplyCommandNotSupported)
|
||||
return E.New("unsupported SOCKS command: ", requestHeader[1])
|
||||
}
|
||||
if _, err := readSocksDestination(conn, requestHeader[3]); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeSocksReply(conn, socksReplySuccess); err != nil {
|
||||
return err
|
||||
}
|
||||
return bufio.CopyConn(ctx, conn, targetConn)
|
||||
}
|
||||
|
||||
func requestHeaderValue(request *ConnectRequest, headerName string) string {
|
||||
for _, entry := range request.Metadata {
|
||||
if !strings.HasPrefix(entry.Key, metadataHTTPHeader+":") {
|
||||
continue
|
||||
}
|
||||
name := strings.TrimPrefix(entry.Key, metadataHTTPHeader+":")
|
||||
if strings.EqualFold(name, headerName) {
|
||||
return entry.Val
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func streamServiceHostname(service ResolvedService) string {
|
||||
if service.BaseURL != nil && service.BaseURL.Hostname() != "" {
|
||||
return service.BaseURL.Hostname()
|
||||
}
|
||||
parsedURL, err := url.Parse(service.Service)
|
||||
if err == nil && parsedURL.Hostname() != "" {
|
||||
return parsedURL.Hostname()
|
||||
}
|
||||
return service.Destination.AddrString()
|
||||
}
|
||||
|
||||
func (i *Inbound) dialRouterTCP(ctx context.Context, destination M.Socksaddr) (net.Conn, func(), error) {
|
||||
metadata := adapter.InboundContext{
|
||||
Inbound: i.Tag(),
|
||||
InboundType: i.Type(),
|
||||
Network: N.NetworkTCP,
|
||||
Destination: destination,
|
||||
}
|
||||
return i.dialRouterTCPWithMetadata(ctx, metadata, routedPipeTCPOptions{})
|
||||
}
|
||||
|
||||
func (i *Inbound) serveSocksProxy(ctx context.Context, conn net.Conn, policy *ipRulePolicy) error {
|
||||
version := make([]byte, 1)
|
||||
if _, err := io.ReadFull(conn, version); err != nil {
|
||||
return err
|
||||
}
|
||||
if version[0] != 5 {
|
||||
return E.New("unsupported SOCKS version: ", version[0])
|
||||
}
|
||||
|
||||
methodCount := make([]byte, 1)
|
||||
if _, err := io.ReadFull(conn, methodCount); err != nil {
|
||||
return err
|
||||
}
|
||||
methods := make([]byte, int(methodCount[0]))
|
||||
if _, err := io.ReadFull(conn, methods); err != nil {
|
||||
return err
|
||||
}
|
||||
var supportsNoAuth bool
|
||||
for _, method := range methods {
|
||||
if method == 0 {
|
||||
supportsNoAuth = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !supportsNoAuth {
|
||||
if _, err := conn.Write([]byte{5, 255}); err != nil {
|
||||
return err
|
||||
}
|
||||
return E.New("unknown authentication type")
|
||||
}
|
||||
if _, err := conn.Write([]byte{5, 0}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
requestHeader := make([]byte, 4)
|
||||
if _, err := io.ReadFull(conn, requestHeader); err != nil {
|
||||
return err
|
||||
}
|
||||
if requestHeader[0] != 5 {
|
||||
return E.New("unsupported SOCKS request version: ", requestHeader[0])
|
||||
}
|
||||
if requestHeader[1] != 1 {
|
||||
_ = writeSocksReply(conn, socksReplyCommandNotSupported)
|
||||
return E.New("unsupported SOCKS command: ", requestHeader[1])
|
||||
}
|
||||
|
||||
destination, err := readSocksDestination(conn, requestHeader[3])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
allowed, err := policy.Allow(ctx, destination)
|
||||
if err != nil {
|
||||
_ = writeSocksReply(conn, socksReplyRuleFailure)
|
||||
return err
|
||||
}
|
||||
if !allowed {
|
||||
_ = writeSocksReply(conn, socksReplyRuleFailure)
|
||||
return E.New("connect to ", destination, " denied by ip_rules")
|
||||
}
|
||||
targetConn, cleanup, err := i.dialRouterTCP(ctx, destination)
|
||||
if err != nil {
|
||||
_ = writeSocksReply(conn, socksReplyForDialError(err))
|
||||
return err
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
if err := writeSocksReply(conn, socksReplySuccess); err != nil {
|
||||
return err
|
||||
}
|
||||
return bufio.CopyConn(ctx, conn, targetConn)
|
||||
}
|
||||
|
||||
func writeSocksReply(conn net.Conn, reply byte) error {
|
||||
_, err := conn.Write([]byte{5, reply, 0, 1, 0, 0, 0, 0, 0, 0})
|
||||
return err
|
||||
}
|
||||
|
||||
func socksReplyForDialError(err error) byte {
|
||||
lower := strings.ToLower(err.Error())
|
||||
switch {
|
||||
case strings.Contains(lower, "refused"):
|
||||
return socksReplyConnectionRefused
|
||||
case strings.Contains(lower, "network is unreachable"):
|
||||
return socksReplyNetworkUnreachable
|
||||
default:
|
||||
return socksReplyHostUnreachable
|
||||
}
|
||||
}
|
||||
|
||||
func readSocksDestination(conn net.Conn, addressType byte) (M.Socksaddr, error) {
|
||||
switch addressType {
|
||||
case 1:
|
||||
addr := make([]byte, 4)
|
||||
if _, err := io.ReadFull(conn, addr); err != nil {
|
||||
return M.Socksaddr{}, err
|
||||
}
|
||||
port, err := readSocksPort(conn)
|
||||
if err != nil {
|
||||
return M.Socksaddr{}, err
|
||||
}
|
||||
ipAddr, ok := netip.AddrFromSlice(addr)
|
||||
if !ok {
|
||||
return M.Socksaddr{}, E.New("invalid IPv4 SOCKS destination")
|
||||
}
|
||||
return M.SocksaddrFrom(ipAddr, port), nil
|
||||
case 3:
|
||||
length := make([]byte, 1)
|
||||
if _, err := io.ReadFull(conn, length); err != nil {
|
||||
return M.Socksaddr{}, err
|
||||
}
|
||||
host := make([]byte, int(length[0]))
|
||||
if _, err := io.ReadFull(conn, host); err != nil {
|
||||
return M.Socksaddr{}, err
|
||||
}
|
||||
port, err := readSocksPort(conn)
|
||||
if err != nil {
|
||||
return M.Socksaddr{}, err
|
||||
}
|
||||
return M.ParseSocksaddr(net.JoinHostPort(string(host), strconv.Itoa(int(port)))), nil
|
||||
case 4:
|
||||
addr := make([]byte, 16)
|
||||
if _, err := io.ReadFull(conn, addr); err != nil {
|
||||
return M.Socksaddr{}, err
|
||||
}
|
||||
port, err := readSocksPort(conn)
|
||||
if err != nil {
|
||||
return M.Socksaddr{}, err
|
||||
}
|
||||
ipAddr, ok := netip.AddrFromSlice(addr)
|
||||
if !ok {
|
||||
return M.Socksaddr{}, E.New("invalid IPv6 SOCKS destination")
|
||||
}
|
||||
return M.SocksaddrFrom(ipAddr, port), nil
|
||||
default:
|
||||
return M.Socksaddr{}, E.New("unsupported SOCKS address type: ", addressType)
|
||||
}
|
||||
}
|
||||
|
||||
func readSocksPort(conn net.Conn) (uint16, error) {
|
||||
port := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, port); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return uint16(port[0])<<8 | uint16(port[1]), nil
|
||||
}
|
||||
@@ -1,680 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/inbound"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/ws"
|
||||
"github.com/sagernet/ws/wsutil"
|
||||
)
|
||||
|
||||
type fakeConnectResponseWriter struct {
|
||||
status int
|
||||
headers http.Header
|
||||
err error
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func (w *fakeConnectResponseWriter) WriteResponse(responseError error, metadata []Metadata) error {
|
||||
w.err = responseError
|
||||
w.headers = make(http.Header)
|
||||
for _, entry := range metadata {
|
||||
switch {
|
||||
case entry.Key == metadataHTTPStatus:
|
||||
status, _ := strconv.Atoi(entry.Val)
|
||||
w.status = status
|
||||
case len(entry.Key) > len(metadataHTTPHeader)+1 && entry.Key[:len(metadataHTTPHeader)+1] == metadataHTTPHeader+":":
|
||||
w.headers.Add(entry.Key[len(metadataHTTPHeader)+1:], entry.Val)
|
||||
}
|
||||
}
|
||||
if w.done != nil {
|
||||
close(w.done)
|
||||
w.done = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newSpecialServiceInbound(t *testing.T) *Inbound {
|
||||
return newSpecialServiceInboundWithRouter(t, &testRouter{})
|
||||
}
|
||||
|
||||
func newSpecialServiceInboundWithRouter(t *testing.T, router adapter.Router) *Inbound {
|
||||
t.Helper()
|
||||
logFactory, err := log.New(log.Options{Options: option.LogOptions{Level: "debug"}})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
configManager, err := NewConfigManager()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return &Inbound{
|
||||
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
|
||||
router: router,
|
||||
logger: logFactory.NewLogger("test"),
|
||||
configManager: configManager,
|
||||
flowLimiter: &FlowLimiter{},
|
||||
}
|
||||
}
|
||||
|
||||
type countingRouter struct {
|
||||
testRouter
|
||||
count atomic.Int32
|
||||
}
|
||||
|
||||
func (r *countingRouter) RouteConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||||
r.count.Add(1)
|
||||
r.testRouter.RouteConnectionEx(ctx, conn, metadata, onClose)
|
||||
}
|
||||
|
||||
func startEchoListener(t *testing.T) net.Listener {
|
||||
t.Helper()
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
_, _ = io.Copy(conn, conn)
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
return listener
|
||||
}
|
||||
|
||||
func newSocksProxyService(t *testing.T, rules []IPRule) ResolvedService {
|
||||
t.Helper()
|
||||
service, err := parseResolvedService("socks-proxy", OriginRequestConfig{IPRules: rules})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return service
|
||||
}
|
||||
|
||||
func newSocksProxyConnectRequest() *ConnectRequest {
|
||||
return &ConnectRequest{
|
||||
Type: ConnectionTypeWebsocket,
|
||||
Metadata: []Metadata{
|
||||
{Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func startSocksProxyStream(t *testing.T, inboundInstance *Inbound, service ResolvedService) (net.Conn, <-chan struct{}) {
|
||||
t.Helper()
|
||||
serverSide, clientSide := net.Pipe()
|
||||
respWriter := &fakeConnectResponseWriter{done: make(chan struct{})}
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
inboundInstance.handleSocksProxyStream(context.Background(), serverSide, respWriter, newSocksProxyConnectRequest(), adapter.InboundContext{}, service)
|
||||
}()
|
||||
select {
|
||||
case <-respWriter.done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for socks-proxy connect response")
|
||||
}
|
||||
if respWriter.err != nil {
|
||||
t.Fatal(respWriter.err)
|
||||
}
|
||||
if respWriter.status != http.StatusSwitchingProtocols {
|
||||
t.Fatalf("expected 101 response, got %d", respWriter.status)
|
||||
}
|
||||
return clientSide, done
|
||||
}
|
||||
|
||||
func writeSocksAuth(t *testing.T, conn net.Conn) {
|
||||
t.Helper()
|
||||
if err := wsutil.WriteClientMessage(conn, ws.OpBinary, []byte{5, 1, 0}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _, err := wsutil.ReadServerData(conn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(data) != string([]byte{5, 0}) {
|
||||
t.Fatalf("unexpected auth response: %v", data)
|
||||
}
|
||||
}
|
||||
|
||||
func writeSocksConnectIPv4(t *testing.T, conn net.Conn, address string) []byte {
|
||||
t.Helper()
|
||||
host, portText, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
port, err := strconv.Atoi(portText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
requestBytes := []byte{5, 1, 0, 1}
|
||||
requestBytes = append(requestBytes, net.ParseIP(host).To4()...)
|
||||
requestBytes = append(requestBytes, byte(port>>8), byte(port))
|
||||
if err := wsutil.WriteClientMessage(conn, ws.OpBinary, requestBytes); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _, err := wsutil.ReadServerData(conn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func TestServeSocksProxyRejectsMissingNoAuth(t *testing.T) {
|
||||
inboundInstance := newSpecialServiceInbound(t)
|
||||
serverSide, clientSide := net.Pipe()
|
||||
defer clientSide.Close()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- inboundInstance.serveSocksProxy(context.Background(), serverSide, nil)
|
||||
}()
|
||||
|
||||
if _, err := clientSide.Write([]byte{5, 1, 2}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
response := make([]byte, 2)
|
||||
if _, err := io.ReadFull(clientSide, response); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(response) != string([]byte{5, 255}) {
|
||||
t.Fatalf("unexpected auth rejection response: %v", response)
|
||||
}
|
||||
if err := <-errCh; err == nil {
|
||||
t.Fatal("expected socks auth rejection error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSocksReplyForDialError(t *testing.T) {
|
||||
if reply := socksReplyForDialError(io.EOF); reply != socksReplyHostUnreachable {
|
||||
t.Fatalf("expected host unreachable for generic error, got %d", reply)
|
||||
}
|
||||
if reply := socksReplyForDialError(errors.New("connection refused")); reply != 5 {
|
||||
t.Fatalf("expected connection refused reply, got %d", reply)
|
||||
}
|
||||
if reply := socksReplyForDialError(errors.New("network is unreachable")); reply != 3 {
|
||||
t.Fatalf("expected network unreachable reply, got %d", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleBastionStream(t *testing.T) {
|
||||
listener := startEchoListener(t)
|
||||
defer listener.Close()
|
||||
|
||||
serverSide, clientSide := net.Pipe()
|
||||
defer clientSide.Close()
|
||||
|
||||
inboundInstance := newSpecialServiceInbound(t)
|
||||
request := &ConnectRequest{
|
||||
Type: ConnectionTypeWebsocket,
|
||||
Metadata: []Metadata{
|
||||
{Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="},
|
||||
{Key: metadataHTTPHeader + ":Cf-Access-Jump-Destination", Val: listener.Addr().String()},
|
||||
},
|
||||
}
|
||||
respWriter := &fakeConnectResponseWriter{done: make(chan struct{})}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
inboundInstance.handleBastionStream(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, ResolvedService{})
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-respWriter.done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for bastion connect response")
|
||||
}
|
||||
if respWriter.err != nil {
|
||||
t.Fatal(respWriter.err)
|
||||
}
|
||||
if respWriter.status != http.StatusSwitchingProtocols {
|
||||
t.Fatalf("expected 101 response, got %d", respWriter.status)
|
||||
}
|
||||
if respWriter.headers.Get("Sec-WebSocket-Accept") == "" {
|
||||
t.Fatal("expected websocket accept header")
|
||||
}
|
||||
|
||||
if err := wsutil.WriteClientMessage(clientSide, ws.OpBinary, []byte("hello")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, opCode, err := wsutil.ReadServerData(clientSide)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if opCode != ws.OpBinary {
|
||||
t.Fatalf("expected binary frame, got %v", opCode)
|
||||
}
|
||||
if string(data) != "hello" {
|
||||
t.Fatalf("expected echoed payload, got %q", string(data))
|
||||
}
|
||||
_ = clientSide.Close()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("bastion stream did not exit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSocksProxyStream(t *testing.T) {
|
||||
listener := startEchoListener(t)
|
||||
defer listener.Close()
|
||||
|
||||
_, portText, _ := net.SplitHostPort(listener.Addr().String())
|
||||
port, _ := strconv.Atoi(portText)
|
||||
service := newSocksProxyService(t, []IPRule{{
|
||||
Prefix: "127.0.0.0/8",
|
||||
Ports: []int{port},
|
||||
Allow: true,
|
||||
}})
|
||||
|
||||
clientSide, done := startSocksProxyStream(t, newSpecialServiceInbound(t), service)
|
||||
defer clientSide.Close()
|
||||
|
||||
writeSocksAuth(t, clientSide)
|
||||
data := writeSocksConnectIPv4(t, clientSide, listener.Addr().String())
|
||||
if len(data) != 10 || data[1] != 0 {
|
||||
t.Fatalf("unexpected connect response: %v", data)
|
||||
}
|
||||
|
||||
if err := wsutil.WriteClientMessage(clientSide, ws.OpBinary, []byte("hello")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _, err := wsutil.ReadServerData(clientSide)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(data) != "hello" {
|
||||
t.Fatalf("expected echoed payload, got %q", string(data))
|
||||
}
|
||||
_ = clientSide.Close()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("socks-proxy stream did not exit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSocksProxyStreamDenyRule(t *testing.T) {
|
||||
listener := startEchoListener(t)
|
||||
defer listener.Close()
|
||||
|
||||
_, portText, _ := net.SplitHostPort(listener.Addr().String())
|
||||
port, _ := strconv.Atoi(portText)
|
||||
service := newSocksProxyService(t, []IPRule{{
|
||||
Prefix: "127.0.0.0/8",
|
||||
Ports: []int{port},
|
||||
Allow: false,
|
||||
}})
|
||||
router := &countingRouter{}
|
||||
clientSide, done := startSocksProxyStream(t, newSpecialServiceInboundWithRouter(t, router), service)
|
||||
defer clientSide.Close()
|
||||
|
||||
writeSocksAuth(t, clientSide)
|
||||
data := writeSocksConnectIPv4(t, clientSide, listener.Addr().String())
|
||||
if len(data) != 10 || data[1] != socksReplyRuleFailure {
|
||||
t.Fatalf("unexpected deny response: %v", data)
|
||||
}
|
||||
if router.count.Load() != 0 {
|
||||
t.Fatalf("expected no router dial, got %d", router.count.Load())
|
||||
}
|
||||
_ = clientSide.Close()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("socks-proxy stream did not exit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSocksProxyStreamPortMismatchDefaultDeny(t *testing.T) {
|
||||
listener := startEchoListener(t)
|
||||
defer listener.Close()
|
||||
|
||||
_, portText, _ := net.SplitHostPort(listener.Addr().String())
|
||||
port, _ := strconv.Atoi(portText)
|
||||
service := newSocksProxyService(t, []IPRule{{
|
||||
Prefix: "127.0.0.0/8",
|
||||
Ports: []int{port + 1},
|
||||
Allow: true,
|
||||
}})
|
||||
router := &countingRouter{}
|
||||
clientSide, done := startSocksProxyStream(t, newSpecialServiceInboundWithRouter(t, router), service)
|
||||
defer clientSide.Close()
|
||||
|
||||
writeSocksAuth(t, clientSide)
|
||||
data := writeSocksConnectIPv4(t, clientSide, listener.Addr().String())
|
||||
if len(data) != 10 || data[1] != socksReplyRuleFailure {
|
||||
t.Fatalf("unexpected port mismatch response: %v", data)
|
||||
}
|
||||
if router.count.Load() != 0 {
|
||||
t.Fatalf("expected no router dial, got %d", router.count.Load())
|
||||
}
|
||||
_ = clientSide.Close()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("socks-proxy stream did not exit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSocksProxyStreamEmptyRulesDefaultDeny(t *testing.T) {
|
||||
listener := startEchoListener(t)
|
||||
defer listener.Close()
|
||||
|
||||
router := &countingRouter{}
|
||||
clientSide, done := startSocksProxyStream(t, newSpecialServiceInboundWithRouter(t, router), newSocksProxyService(t, nil))
|
||||
defer clientSide.Close()
|
||||
|
||||
writeSocksAuth(t, clientSide)
|
||||
data := writeSocksConnectIPv4(t, clientSide, listener.Addr().String())
|
||||
if len(data) != 10 || data[1] != socksReplyRuleFailure {
|
||||
t.Fatalf("unexpected empty-rule response: %v", data)
|
||||
}
|
||||
if router.count.Load() != 0 {
|
||||
t.Fatalf("expected no router dial, got %d", router.count.Load())
|
||||
}
|
||||
_ = clientSide.Close()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("socks-proxy stream did not exit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSocksProxyStreamRuleOrderFirstMatchWins(t *testing.T) {
|
||||
listener := startEchoListener(t)
|
||||
defer listener.Close()
|
||||
|
||||
_, portText, _ := net.SplitHostPort(listener.Addr().String())
|
||||
port, _ := strconv.Atoi(portText)
|
||||
allowFirst := newSocksProxyService(t, []IPRule{
|
||||
{Prefix: "127.0.0.0/8", Ports: []int{port}, Allow: true},
|
||||
{Prefix: "127.0.0.1/32", Ports: []int{port}, Allow: false},
|
||||
})
|
||||
denyFirst := newSocksProxyService(t, []IPRule{
|
||||
{Prefix: "127.0.0.1/32", Ports: []int{port}, Allow: false},
|
||||
{Prefix: "127.0.0.0/8", Ports: []int{port}, Allow: true},
|
||||
})
|
||||
|
||||
t.Run("allow-first", func(t *testing.T) {
|
||||
clientSide, done := startSocksProxyStream(t, newSpecialServiceInbound(t), allowFirst)
|
||||
defer clientSide.Close()
|
||||
|
||||
writeSocksAuth(t, clientSide)
|
||||
data := writeSocksConnectIPv4(t, clientSide, listener.Addr().String())
|
||||
if len(data) != 10 || data[1] != socksReplySuccess {
|
||||
t.Fatalf("unexpected allow-first response: %v", data)
|
||||
}
|
||||
_ = clientSide.Close()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("socks-proxy stream did not exit")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("deny-first", func(t *testing.T) {
|
||||
router := &countingRouter{}
|
||||
clientSide, done := startSocksProxyStream(t, newSpecialServiceInboundWithRouter(t, router), denyFirst)
|
||||
defer clientSide.Close()
|
||||
|
||||
writeSocksAuth(t, clientSide)
|
||||
data := writeSocksConnectIPv4(t, clientSide, listener.Addr().String())
|
||||
if len(data) != 10 || data[1] != socksReplyRuleFailure {
|
||||
t.Fatalf("unexpected deny-first response: %v", data)
|
||||
}
|
||||
if router.count.Load() != 0 {
|
||||
t.Fatalf("expected no router dial, got %d", router.count.Load())
|
||||
}
|
||||
_ = clientSide.Close()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("socks-proxy stream did not exit")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleStreamService(t *testing.T) {
|
||||
listener := startEchoListener(t)
|
||||
defer listener.Close()
|
||||
|
||||
serverSide, clientSide := net.Pipe()
|
||||
defer clientSide.Close()
|
||||
|
||||
inboundInstance := newSpecialServiceInbound(t)
|
||||
request := &ConnectRequest{
|
||||
Type: ConnectionTypeWebsocket,
|
||||
Metadata: []Metadata{
|
||||
{Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="},
|
||||
},
|
||||
}
|
||||
respWriter := &fakeConnectResponseWriter{done: make(chan struct{})}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
inboundInstance.handleStreamService(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, ResolvedService{
|
||||
Kind: ResolvedServiceStream,
|
||||
Destination: M.ParseSocksaddr(listener.Addr().String()),
|
||||
StreamHasPort: true,
|
||||
})
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-respWriter.done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for stream service connect response")
|
||||
}
|
||||
if respWriter.err != nil {
|
||||
t.Fatal(respWriter.err)
|
||||
}
|
||||
if respWriter.status != http.StatusSwitchingProtocols {
|
||||
t.Fatalf("expected 101 response, got %d", respWriter.status)
|
||||
}
|
||||
|
||||
if err := wsutil.WriteClientMessage(clientSide, ws.OpBinary, []byte("hello")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, opCode, err := wsutil.ReadServerData(clientSide)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if opCode != ws.OpBinary {
|
||||
t.Fatalf("expected binary frame, got %v", opCode)
|
||||
}
|
||||
if string(data) != "hello" {
|
||||
t.Fatalf("expected echoed payload, got %q", string(data))
|
||||
}
|
||||
_ = clientSide.Close()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("stream service did not exit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamServiceProxyTypeSocks(t *testing.T) {
|
||||
listener := startEchoListener(t)
|
||||
defer listener.Close()
|
||||
|
||||
serverSide, clientSide := net.Pipe()
|
||||
defer clientSide.Close()
|
||||
|
||||
inboundInstance := newSpecialServiceInbound(t)
|
||||
request := &ConnectRequest{
|
||||
Type: ConnectionTypeWebsocket,
|
||||
Metadata: []Metadata{
|
||||
{Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="},
|
||||
},
|
||||
}
|
||||
respWriter := &fakeConnectResponseWriter{done: make(chan struct{})}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
inboundInstance.handleStreamService(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, ResolvedService{
|
||||
Kind: ResolvedServiceStream,
|
||||
Destination: M.ParseSocksaddr(listener.Addr().String()),
|
||||
StreamHasPort: true,
|
||||
OriginRequest: OriginRequestConfig{
|
||||
ProxyType: "socks",
|
||||
},
|
||||
})
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-respWriter.done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for stream service connect response")
|
||||
}
|
||||
if respWriter.err != nil {
|
||||
t.Fatal(respWriter.err)
|
||||
}
|
||||
if respWriter.status != http.StatusSwitchingProtocols {
|
||||
t.Fatalf("expected 101 response, got %d", respWriter.status)
|
||||
}
|
||||
|
||||
writeSocksAuth(t, clientSide)
|
||||
data := writeSocksConnectIPv4(t, clientSide, listener.Addr().String())
|
||||
if len(data) != 10 || data[1] != socksReplySuccess {
|
||||
t.Fatalf("unexpected socks connect response: %v", data)
|
||||
}
|
||||
|
||||
if err := wsutil.WriteClientMessage(clientSide, ws.OpBinary, []byte("hello")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _, err := wsutil.ReadServerData(clientSide)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(data) != "hello" {
|
||||
t.Fatalf("expected echoed payload, got %q", string(data))
|
||||
}
|
||||
_ = clientSide.Close()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("socks stream service did not exit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamServiceGenericSchemeWithPort(t *testing.T) {
|
||||
listener := startEchoListener(t)
|
||||
defer listener.Close()
|
||||
|
||||
serverSide, clientSide := net.Pipe()
|
||||
defer clientSide.Close()
|
||||
|
||||
inboundInstance := newSpecialServiceInbound(t)
|
||||
request := &ConnectRequest{
|
||||
Type: ConnectionTypeWebsocket,
|
||||
Metadata: []Metadata{
|
||||
{Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="},
|
||||
},
|
||||
}
|
||||
respWriter := &fakeConnectResponseWriter{done: make(chan struct{})}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
inboundInstance.handleStreamService(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, ResolvedService{
|
||||
Kind: ResolvedServiceStream,
|
||||
Service: "ftp://" + listener.Addr().String(),
|
||||
Destination: M.ParseSocksaddr(listener.Addr().String()),
|
||||
StreamHasPort: true,
|
||||
})
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-respWriter.done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for stream service connect response")
|
||||
}
|
||||
if respWriter.err != nil {
|
||||
t.Fatal(respWriter.err)
|
||||
}
|
||||
if respWriter.status != http.StatusSwitchingProtocols {
|
||||
t.Fatalf("expected 101 response, got %d", respWriter.status)
|
||||
}
|
||||
|
||||
if err := wsutil.WriteClientMessage(clientSide, ws.OpBinary, []byte("hello")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _, err := wsutil.ReadServerData(clientSide)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(data) != "hello" {
|
||||
t.Fatalf("expected echoed payload, got %q", string(data))
|
||||
}
|
||||
_ = clientSide.Close()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("generic stream service did not exit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamServiceGenericSchemeWithoutPort(t *testing.T) {
|
||||
serverSide, clientSide := net.Pipe()
|
||||
defer clientSide.Close()
|
||||
defer serverSide.Close()
|
||||
|
||||
router := &countingRouter{}
|
||||
inboundInstance := newSpecialServiceInboundWithRouter(t, router)
|
||||
request := &ConnectRequest{
|
||||
Type: ConnectionTypeWebsocket,
|
||||
Metadata: []Metadata{
|
||||
{Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="},
|
||||
},
|
||||
}
|
||||
respWriter := &fakeConnectResponseWriter{done: make(chan struct{})}
|
||||
|
||||
inboundInstance.handleStreamService(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, ResolvedService{
|
||||
Kind: ResolvedServiceStream,
|
||||
Service: "ftp://127.0.0.1",
|
||||
Destination: M.ParseSocksaddrHostPort("127.0.0.1", 0),
|
||||
StreamHasPort: false,
|
||||
BaseURL: &url.URL{
|
||||
Scheme: "ftp",
|
||||
Host: "127.0.0.1",
|
||||
},
|
||||
})
|
||||
|
||||
if respWriter.err == nil {
|
||||
t.Fatal("expected missing port error")
|
||||
}
|
||||
if respWriter.err.Error() != "address 127.0.0.1: missing port in address" {
|
||||
t.Fatalf("unexpected error: %v", respWriter.err)
|
||||
}
|
||||
if respWriter.status == http.StatusSwitchingProtocols {
|
||||
t.Fatalf("expected non-upgrade response on error, got %d", respWriter.status)
|
||||
}
|
||||
if router.count.Load() != 0 {
|
||||
t.Fatalf("expected router not to be used, got %d", router.count.Load())
|
||||
}
|
||||
}
|
||||
@@ -1,235 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
|
||||
"github.com/google/uuid"
|
||||
capnp "zombiezen.com/go/capnproto2"
|
||||
"zombiezen.com/go/capnproto2/pogs"
|
||||
)
|
||||
|
||||
// Protocol signatures distinguish stream types.
|
||||
var (
|
||||
dataStreamSignature = [6]byte{0x0A, 0x36, 0xCD, 0x12, 0xA1, 0x3E}
|
||||
rpcStreamSignature = [6]byte{0x52, 0xBB, 0x82, 0x5C, 0xDB, 0x65}
|
||||
)
|
||||
|
||||
const protocolVersion = "01"
|
||||
|
||||
// StreamType identifies the kind of QUIC stream.
|
||||
type StreamType int
|
||||
|
||||
const (
|
||||
StreamTypeData StreamType = iota
|
||||
StreamTypeRPC
|
||||
)
|
||||
|
||||
const metadataFlowConnectRateLimited = "FlowConnectRateLimited"
|
||||
|
||||
// ConnectionType indicates the proxied connection type within a data stream.
|
||||
type ConnectionType uint16
|
||||
|
||||
const (
|
||||
ConnectionTypeHTTP ConnectionType = iota
|
||||
ConnectionTypeWebsocket
|
||||
ConnectionTypeTCP
|
||||
)
|
||||
|
||||
func (c ConnectionType) String() string {
|
||||
switch c {
|
||||
case ConnectionTypeHTTP:
|
||||
return "http"
|
||||
case ConnectionTypeWebsocket:
|
||||
return "websocket"
|
||||
case ConnectionTypeTCP:
|
||||
return "tcp"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// Metadata is a key-value pair in stream metadata.
|
||||
type Metadata struct {
|
||||
Key string `capnp:"key"`
|
||||
Val string `capnp:"val"`
|
||||
}
|
||||
|
||||
func flowConnectRateLimitedMetadata() []Metadata {
|
||||
return []Metadata{{
|
||||
Key: metadataFlowConnectRateLimited,
|
||||
Val: "true",
|
||||
}}
|
||||
}
|
||||
|
||||
func hasFlowConnectRateLimited(metadata []Metadata) bool {
|
||||
for _, entry := range metadata {
|
||||
if entry.Key == metadataFlowConnectRateLimited && entry.Val == "true" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ConnectRequest is sent by the edge at the start of a data stream.
|
||||
type ConnectRequest struct {
|
||||
Dest string `capnp:"dest"`
|
||||
Type ConnectionType `capnp:"type"`
|
||||
Metadata []Metadata `capnp:"metadata"`
|
||||
}
|
||||
|
||||
func (r *ConnectRequest) MetadataMap() map[string]string {
|
||||
result := make(map[string]string, len(r.Metadata))
|
||||
for _, m := range r.Metadata {
|
||||
result[m.Key] = m.Val
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (r *ConnectRequest) fromCapnp(msg *capnp.Message) error {
|
||||
root, err := tunnelrpc.ReadRootConnectRequest(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return pogs.Extract(r, tunnelrpc.ConnectRequest_TypeID, root.Struct)
|
||||
}
|
||||
|
||||
// ConnectResponse is sent back to the edge after processing a ConnectRequest.
|
||||
type ConnectResponse struct {
|
||||
Error string `capnp:"error"`
|
||||
Metadata []Metadata `capnp:"metadata"`
|
||||
}
|
||||
|
||||
func (r *ConnectResponse) toCapnp() (*capnp.Message, error) {
|
||||
msg, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
root, err := tunnelrpc.NewRootConnectResponse(seg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = pogs.Insert(tunnelrpc.ConnectResponse_TypeID, root.Struct, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// ReadStreamSignature reads the 6-byte stream type signature.
|
||||
func ReadStreamSignature(r io.Reader) (StreamType, error) {
|
||||
var signature [6]byte
|
||||
_, err := io.ReadFull(r, signature[:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
switch signature {
|
||||
case dataStreamSignature:
|
||||
return StreamTypeData, nil
|
||||
case rpcStreamSignature:
|
||||
return StreamTypeRPC, nil
|
||||
default:
|
||||
return 0, E.New("unknown stream signature")
|
||||
}
|
||||
}
|
||||
|
||||
// ReadConnectRequest reads the version and ConnectRequest from a data stream.
|
||||
func ReadConnectRequest(r io.Reader) (*ConnectRequest, error) {
|
||||
version := make([]byte, 2)
|
||||
_, err := io.ReadFull(r, version)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "read version")
|
||||
}
|
||||
|
||||
msg, err := capnp.NewDecoder(r).Decode()
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode connect request")
|
||||
}
|
||||
|
||||
request := &ConnectRequest{}
|
||||
err = request.fromCapnp(msg)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "extract connect request")
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
// WriteConnectResponse writes a ConnectResponse with the data stream preamble.
|
||||
func WriteConnectResponse(w io.Writer, responseError error, metadata ...Metadata) error {
|
||||
response := &ConnectResponse{
|
||||
Metadata: metadata,
|
||||
}
|
||||
if responseError != nil {
|
||||
response.Error = responseError.Error()
|
||||
}
|
||||
|
||||
msg, err := response.toCapnp()
|
||||
if err != nil {
|
||||
return E.Cause(err, "encode connect response")
|
||||
}
|
||||
|
||||
// Write data stream preamble
|
||||
_, err = w.Write(dataStreamSignature[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = w.Write([]byte(protocolVersion))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return capnp.NewEncoder(w).Encode(msg)
|
||||
}
|
||||
|
||||
func WriteRPCStreamSignature(w io.Writer) error {
|
||||
_, err := w.Write(rpcStreamSignature[:])
|
||||
return err
|
||||
}
|
||||
|
||||
// Registration data structures for the control stream.
|
||||
|
||||
type RegistrationTunnelAuth struct {
|
||||
AccountTag string `capnp:"accountTag"`
|
||||
TunnelSecret []byte `capnp:"tunnelSecret"`
|
||||
}
|
||||
|
||||
type RegistrationClientInfo struct {
|
||||
ClientID []byte `capnp:"clientId"`
|
||||
Features []string `capnp:"features"`
|
||||
Version string `capnp:"version"`
|
||||
Arch string `capnp:"arch"`
|
||||
}
|
||||
|
||||
type RegistrationConnectionOptions struct {
|
||||
Client RegistrationClientInfo `capnp:"client"`
|
||||
OriginLocalIP net.IP `capnp:"originLocalIp"`
|
||||
ReplaceExisting bool `capnp:"replaceExisting"`
|
||||
CompressionQuality uint8 `capnp:"compressionQuality"`
|
||||
NumPreviousAttempts uint8 `capnp:"numPreviousAttempts"`
|
||||
}
|
||||
|
||||
// RegistrationResult is the parsed result of a RegisterConnection RPC.
|
||||
type RegistrationResult struct {
|
||||
ConnectionID uuid.UUID
|
||||
Location string
|
||||
TunnelIsRemotelyManaged bool
|
||||
}
|
||||
|
||||
// RetryableError signals the edge wants us to retry after a delay.
|
||||
type RetryableError struct {
|
||||
Err error
|
||||
Delay time.Duration
|
||||
}
|
||||
|
||||
func (e *RetryableError) Error() string {
|
||||
return e.Err.Error()
|
||||
}
|
||||
|
||||
func (e *RetryableError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
@@ -1,95 +0,0 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestReadStreamSignatureData(t *testing.T) {
|
||||
buf := bytes.NewBuffer(dataStreamSignature[:])
|
||||
streamType, err := ReadStreamSignature(buf)
|
||||
if err != nil {
|
||||
t.Fatal("ReadStreamSignature: ", err)
|
||||
}
|
||||
if streamType != StreamTypeData {
|
||||
t.Error("expected StreamTypeData, got ", streamType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadStreamSignatureRPC(t *testing.T) {
|
||||
buf := bytes.NewBuffer(rpcStreamSignature[:])
|
||||
streamType, err := ReadStreamSignature(buf)
|
||||
if err != nil {
|
||||
t.Fatal("ReadStreamSignature: ", err)
|
||||
}
|
||||
if streamType != StreamTypeRPC {
|
||||
t.Error("expected StreamTypeRPC, got ", streamType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadStreamSignatureUnknown(t *testing.T) {
|
||||
buf := bytes.NewBuffer([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00})
|
||||
_, err := ReadStreamSignature(buf)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown signature")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadStreamSignatureTooShort(t *testing.T) {
|
||||
buf := bytes.NewBuffer([]byte{0x0A, 0x36, 0xCD})
|
||||
_, err := ReadStreamSignature(buf)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for short input")
|
||||
}
|
||||
if !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
t.Error("expected ErrUnexpectedEOF, got ", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteConnectResponseSuccess(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
metadata := Metadata{Key: "testKey", Val: "testVal"}
|
||||
err := WriteConnectResponse(&buf, nil, metadata)
|
||||
if err != nil {
|
||||
t.Fatal("WriteConnectResponse: ", err)
|
||||
}
|
||||
|
||||
data := buf.Bytes()
|
||||
if len(data) < 8 {
|
||||
t.Fatal("response too short: ", len(data))
|
||||
}
|
||||
|
||||
var signature [6]byte
|
||||
copy(signature[:], data[:6])
|
||||
if signature != dataStreamSignature {
|
||||
t.Error("expected data stream signature")
|
||||
}
|
||||
|
||||
version := string(data[6:8])
|
||||
if version != "01" {
|
||||
t.Error("expected version 01, got ", version)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteConnectResponseError(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
err := WriteConnectResponse(&buf, errors.New("test failure"))
|
||||
if err != nil {
|
||||
t.Fatal("WriteConnectResponse: ", err)
|
||||
}
|
||||
|
||||
data := buf.Bytes()
|
||||
if len(data) < 8 {
|
||||
t.Fatal("response too short")
|
||||
}
|
||||
|
||||
var signature [6]byte
|
||||
copy(signature[:], data[:6])
|
||||
if signature != dataStreamSignature {
|
||||
t.Error("expected data stream signature")
|
||||
}
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
# Generate go.capnp.out with:
|
||||
# capnp compile -o- go.capnp > go.capnp.out
|
||||
# Must run inside this directory to preserve paths.
|
||||
|
||||
@0xd12a1c51fedd6c88;
|
||||
|
||||
annotation package(file) :Text;
|
||||
# The Go package name for the generated file.
|
||||
|
||||
annotation import(file) :Text;
|
||||
# The Go import path that the generated file is accessible from.
|
||||
# Used to generate import statements and check if two types are in the
|
||||
# same package.
|
||||
|
||||
annotation doc(struct, field, enum) :Text;
|
||||
# Adds a doc comment to the generated code.
|
||||
|
||||
annotation tag(enumerant) :Text;
|
||||
# Changes the string representation of the enum in the generated code.
|
||||
|
||||
annotation notag(enumerant) :Void;
|
||||
# Removes the string representation of the enum in the generated code.
|
||||
|
||||
annotation customtype(field) :Text;
|
||||
# OBSOLETE, not used by code generator.
|
||||
|
||||
annotation name(struct, field, union, enum, enumerant, interface, method, param, annotation, const, group) :Text;
|
||||
# Used to rename the element in the generated code.
|
||||
|
||||
$package("capnp");
|
||||
$import("zombiezen.com/go/capnproto2");
|
||||
@@ -1,28 +0,0 @@
|
||||
using Go = import "go.capnp";
|
||||
@0xb29021ef7421cc32;
|
||||
|
||||
$Go.package("tunnelrpc");
|
||||
$Go.import("github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc");
|
||||
|
||||
|
||||
struct ConnectRequest @0xc47116a1045e4061 {
|
||||
dest @0 :Text;
|
||||
type @1 :ConnectionType;
|
||||
metadata @2 :List(Metadata);
|
||||
}
|
||||
|
||||
enum ConnectionType @0xc52e1bac26d379c8 {
|
||||
http @0;
|
||||
websocket @1;
|
||||
tcp @2;
|
||||
}
|
||||
|
||||
struct Metadata @0xe1446b97bfd1cd37 {
|
||||
key @0 :Text;
|
||||
val @1 :Text;
|
||||
}
|
||||
|
||||
struct ConnectResponse @0xb1032ec91cef8727 {
|
||||
error @0 :Text;
|
||||
metadata @1 :List(Metadata);
|
||||
}
|
||||
@@ -1,394 +0,0 @@
|
||||
// Code generated by capnpc-go. DO NOT EDIT.
|
||||
|
||||
package tunnelrpc
|
||||
|
||||
import (
|
||||
capnp "zombiezen.com/go/capnproto2"
|
||||
text "zombiezen.com/go/capnproto2/encoding/text"
|
||||
schemas "zombiezen.com/go/capnproto2/schemas"
|
||||
)
|
||||
|
||||
type ConnectRequest struct{ capnp.Struct }
|
||||
|
||||
// ConnectRequest_TypeID is the unique identifier for the type ConnectRequest.
|
||||
const ConnectRequest_TypeID = 0xc47116a1045e4061
|
||||
|
||||
func NewConnectRequest(s *capnp.Segment) (ConnectRequest, error) {
|
||||
st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 2})
|
||||
return ConnectRequest{st}, err
|
||||
}
|
||||
|
||||
func NewRootConnectRequest(s *capnp.Segment) (ConnectRequest, error) {
|
||||
st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 2})
|
||||
return ConnectRequest{st}, err
|
||||
}
|
||||
|
||||
func ReadRootConnectRequest(msg *capnp.Message) (ConnectRequest, error) {
|
||||
root, err := msg.RootPtr()
|
||||
return ConnectRequest{root.Struct()}, err
|
||||
}
|
||||
|
||||
func (s ConnectRequest) String() string {
|
||||
str, _ := text.Marshal(0xc47116a1045e4061, s.Struct)
|
||||
return str
|
||||
}
|
||||
|
||||
func (s ConnectRequest) Dest() (string, error) {
|
||||
p, err := s.Struct.Ptr(0)
|
||||
return p.Text(), err
|
||||
}
|
||||
|
||||
func (s ConnectRequest) HasDest() bool {
|
||||
p, err := s.Struct.Ptr(0)
|
||||
return p.IsValid() || err != nil
|
||||
}
|
||||
|
||||
func (s ConnectRequest) DestBytes() ([]byte, error) {
|
||||
p, err := s.Struct.Ptr(0)
|
||||
return p.TextBytes(), err
|
||||
}
|
||||
|
||||
func (s ConnectRequest) SetDest(v string) error {
|
||||
return s.Struct.SetText(0, v)
|
||||
}
|
||||
|
||||
func (s ConnectRequest) Type() ConnectionType {
|
||||
return ConnectionType(s.Struct.Uint16(0))
|
||||
}
|
||||
|
||||
func (s ConnectRequest) SetType(v ConnectionType) {
|
||||
s.Struct.SetUint16(0, uint16(v))
|
||||
}
|
||||
|
||||
func (s ConnectRequest) Metadata() (Metadata_List, error) {
|
||||
p, err := s.Struct.Ptr(1)
|
||||
return Metadata_List{List: p.List()}, err
|
||||
}
|
||||
|
||||
func (s ConnectRequest) HasMetadata() bool {
|
||||
p, err := s.Struct.Ptr(1)
|
||||
return p.IsValid() || err != nil
|
||||
}
|
||||
|
||||
func (s ConnectRequest) SetMetadata(v Metadata_List) error {
|
||||
return s.Struct.SetPtr(1, v.List.ToPtr())
|
||||
}
|
||||
|
||||
// NewMetadata sets the metadata field to a newly
|
||||
// allocated Metadata_List, preferring placement in s's segment.
|
||||
func (s ConnectRequest) NewMetadata(n int32) (Metadata_List, error) {
|
||||
l, err := NewMetadata_List(s.Struct.Segment(), n)
|
||||
if err != nil {
|
||||
return Metadata_List{}, err
|
||||
}
|
||||
err = s.Struct.SetPtr(1, l.List.ToPtr())
|
||||
return l, err
|
||||
}
|
||||
|
||||
// ConnectRequest_List is a list of ConnectRequest.
|
||||
type ConnectRequest_List struct{ capnp.List }
|
||||
|
||||
// NewConnectRequest creates a new list of ConnectRequest.
|
||||
func NewConnectRequest_List(s *capnp.Segment, sz int32) (ConnectRequest_List, error) {
|
||||
l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 8, PointerCount: 2}, sz)
|
||||
return ConnectRequest_List{l}, err
|
||||
}
|
||||
|
||||
func (s ConnectRequest_List) At(i int) ConnectRequest { return ConnectRequest{s.List.Struct(i)} }
|
||||
|
||||
func (s ConnectRequest_List) Set(i int, v ConnectRequest) error { return s.List.SetStruct(i, v.Struct) }
|
||||
|
||||
func (s ConnectRequest_List) String() string {
|
||||
str, _ := text.MarshalList(0xc47116a1045e4061, s.List)
|
||||
return str
|
||||
}
|
||||
|
||||
// ConnectRequest_Promise is a wrapper for a ConnectRequest promised by a client call.
|
||||
type ConnectRequest_Promise struct{ *capnp.Pipeline }
|
||||
|
||||
func (p ConnectRequest_Promise) Struct() (ConnectRequest, error) {
|
||||
s, err := p.Pipeline.Struct()
|
||||
return ConnectRequest{s}, err
|
||||
}
|
||||
|
||||
type ConnectionType uint16
|
||||
|
||||
// ConnectionType_TypeID is the unique identifier for the type ConnectionType.
|
||||
const ConnectionType_TypeID = 0xc52e1bac26d379c8
|
||||
|
||||
// Values of ConnectionType.
|
||||
const (
|
||||
ConnectionType_http ConnectionType = 0
|
||||
ConnectionType_websocket ConnectionType = 1
|
||||
ConnectionType_tcp ConnectionType = 2
|
||||
)
|
||||
|
||||
// String returns the enum's constant name.
|
||||
func (c ConnectionType) String() string {
|
||||
switch c {
|
||||
case ConnectionType_http:
|
||||
return "http"
|
||||
case ConnectionType_websocket:
|
||||
return "websocket"
|
||||
case ConnectionType_tcp:
|
||||
return "tcp"
|
||||
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// ConnectionTypeFromString returns the enum value with a name,
|
||||
// or the zero value if there's no such value.
|
||||
func ConnectionTypeFromString(c string) ConnectionType {
|
||||
switch c {
|
||||
case "http":
|
||||
return ConnectionType_http
|
||||
case "websocket":
|
||||
return ConnectionType_websocket
|
||||
case "tcp":
|
||||
return ConnectionType_tcp
|
||||
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
type ConnectionType_List struct{ capnp.List }
|
||||
|
||||
func NewConnectionType_List(s *capnp.Segment, sz int32) (ConnectionType_List, error) {
|
||||
l, err := capnp.NewUInt16List(s, sz)
|
||||
return ConnectionType_List{l.List}, err
|
||||
}
|
||||
|
||||
func (l ConnectionType_List) At(i int) ConnectionType {
|
||||
ul := capnp.UInt16List{List: l.List}
|
||||
return ConnectionType(ul.At(i))
|
||||
}
|
||||
|
||||
func (l ConnectionType_List) Set(i int, v ConnectionType) {
|
||||
ul := capnp.UInt16List{List: l.List}
|
||||
ul.Set(i, uint16(v))
|
||||
}
|
||||
|
||||
type Metadata struct{ capnp.Struct }
|
||||
|
||||
// Metadata_TypeID is the unique identifier for the type Metadata.
|
||||
const Metadata_TypeID = 0xe1446b97bfd1cd37
|
||||
|
||||
func NewMetadata(s *capnp.Segment) (Metadata, error) {
|
||||
st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2})
|
||||
return Metadata{st}, err
|
||||
}
|
||||
|
||||
func NewRootMetadata(s *capnp.Segment) (Metadata, error) {
|
||||
st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2})
|
||||
return Metadata{st}, err
|
||||
}
|
||||
|
||||
func ReadRootMetadata(msg *capnp.Message) (Metadata, error) {
|
||||
root, err := msg.RootPtr()
|
||||
return Metadata{root.Struct()}, err
|
||||
}
|
||||
|
||||
func (s Metadata) String() string {
|
||||
str, _ := text.Marshal(0xe1446b97bfd1cd37, s.Struct)
|
||||
return str
|
||||
}
|
||||
|
||||
func (s Metadata) Key() (string, error) {
|
||||
p, err := s.Struct.Ptr(0)
|
||||
return p.Text(), err
|
||||
}
|
||||
|
||||
func (s Metadata) HasKey() bool {
|
||||
p, err := s.Struct.Ptr(0)
|
||||
return p.IsValid() || err != nil
|
||||
}
|
||||
|
||||
func (s Metadata) KeyBytes() ([]byte, error) {
|
||||
p, err := s.Struct.Ptr(0)
|
||||
return p.TextBytes(), err
|
||||
}
|
||||
|
||||
func (s Metadata) SetKey(v string) error {
|
||||
return s.Struct.SetText(0, v)
|
||||
}
|
||||
|
||||
func (s Metadata) Val() (string, error) {
|
||||
p, err := s.Struct.Ptr(1)
|
||||
return p.Text(), err
|
||||
}
|
||||
|
||||
func (s Metadata) HasVal() bool {
|
||||
p, err := s.Struct.Ptr(1)
|
||||
return p.IsValid() || err != nil
|
||||
}
|
||||
|
||||
func (s Metadata) ValBytes() ([]byte, error) {
|
||||
p, err := s.Struct.Ptr(1)
|
||||
return p.TextBytes(), err
|
||||
}
|
||||
|
||||
func (s Metadata) SetVal(v string) error {
|
||||
return s.Struct.SetText(1, v)
|
||||
}
|
||||
|
||||
// Metadata_List is a list of Metadata.
|
||||
type Metadata_List struct{ capnp.List }
|
||||
|
||||
// NewMetadata creates a new list of Metadata.
|
||||
func NewMetadata_List(s *capnp.Segment, sz int32) (Metadata_List, error) {
|
||||
l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2}, sz)
|
||||
return Metadata_List{l}, err
|
||||
}
|
||||
|
||||
func (s Metadata_List) At(i int) Metadata { return Metadata{s.List.Struct(i)} }
|
||||
|
||||
func (s Metadata_List) Set(i int, v Metadata) error { return s.List.SetStruct(i, v.Struct) }
|
||||
|
||||
func (s Metadata_List) String() string {
|
||||
str, _ := text.MarshalList(0xe1446b97bfd1cd37, s.List)
|
||||
return str
|
||||
}
|
||||
|
||||
// Metadata_Promise is a wrapper for a Metadata promised by a client call.
|
||||
type Metadata_Promise struct{ *capnp.Pipeline }
|
||||
|
||||
func (p Metadata_Promise) Struct() (Metadata, error) {
|
||||
s, err := p.Pipeline.Struct()
|
||||
return Metadata{s}, err
|
||||
}
|
||||
|
||||
type ConnectResponse struct{ capnp.Struct }
|
||||
|
||||
// ConnectResponse_TypeID is the unique identifier for the type ConnectResponse.
|
||||
const ConnectResponse_TypeID = 0xb1032ec91cef8727
|
||||
|
||||
func NewConnectResponse(s *capnp.Segment) (ConnectResponse, error) {
|
||||
st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2})
|
||||
return ConnectResponse{st}, err
|
||||
}
|
||||
|
||||
func NewRootConnectResponse(s *capnp.Segment) (ConnectResponse, error) {
|
||||
st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2})
|
||||
return ConnectResponse{st}, err
|
||||
}
|
||||
|
||||
func ReadRootConnectResponse(msg *capnp.Message) (ConnectResponse, error) {
|
||||
root, err := msg.RootPtr()
|
||||
return ConnectResponse{root.Struct()}, err
|
||||
}
|
||||
|
||||
func (s ConnectResponse) String() string {
|
||||
str, _ := text.Marshal(0xb1032ec91cef8727, s.Struct)
|
||||
return str
|
||||
}
|
||||
|
||||
func (s ConnectResponse) Error() (string, error) {
|
||||
p, err := s.Struct.Ptr(0)
|
||||
return p.Text(), err
|
||||
}
|
||||
|
||||
func (s ConnectResponse) HasError() bool {
|
||||
p, err := s.Struct.Ptr(0)
|
||||
return p.IsValid() || err != nil
|
||||
}
|
||||
|
||||
func (s ConnectResponse) ErrorBytes() ([]byte, error) {
|
||||
p, err := s.Struct.Ptr(0)
|
||||
return p.TextBytes(), err
|
||||
}
|
||||
|
||||
func (s ConnectResponse) SetError(v string) error {
|
||||
return s.Struct.SetText(0, v)
|
||||
}
|
||||
|
||||
func (s ConnectResponse) Metadata() (Metadata_List, error) {
|
||||
p, err := s.Struct.Ptr(1)
|
||||
return Metadata_List{List: p.List()}, err
|
||||
}
|
||||
|
||||
func (s ConnectResponse) HasMetadata() bool {
|
||||
p, err := s.Struct.Ptr(1)
|
||||
return p.IsValid() || err != nil
|
||||
}
|
||||
|
||||
func (s ConnectResponse) SetMetadata(v Metadata_List) error {
|
||||
return s.Struct.SetPtr(1, v.List.ToPtr())
|
||||
}
|
||||
|
||||
// NewMetadata sets the metadata field to a newly
|
||||
// allocated Metadata_List, preferring placement in s's segment.
|
||||
func (s ConnectResponse) NewMetadata(n int32) (Metadata_List, error) {
|
||||
l, err := NewMetadata_List(s.Struct.Segment(), n)
|
||||
if err != nil {
|
||||
return Metadata_List{}, err
|
||||
}
|
||||
err = s.Struct.SetPtr(1, l.List.ToPtr())
|
||||
return l, err
|
||||
}
|
||||
|
||||
// ConnectResponse_List is a list of ConnectResponse.
|
||||
type ConnectResponse_List struct{ capnp.List }
|
||||
|
||||
// NewConnectResponse creates a new list of ConnectResponse.
|
||||
func NewConnectResponse_List(s *capnp.Segment, sz int32) (ConnectResponse_List, error) {
|
||||
l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2}, sz)
|
||||
return ConnectResponse_List{l}, err
|
||||
}
|
||||
|
||||
func (s ConnectResponse_List) At(i int) ConnectResponse { return ConnectResponse{s.List.Struct(i)} }
|
||||
|
||||
func (s ConnectResponse_List) Set(i int, v ConnectResponse) error {
|
||||
return s.List.SetStruct(i, v.Struct)
|
||||
}
|
||||
|
||||
func (s ConnectResponse_List) String() string {
|
||||
str, _ := text.MarshalList(0xb1032ec91cef8727, s.List)
|
||||
return str
|
||||
}
|
||||
|
||||
// ConnectResponse_Promise is a wrapper for a ConnectResponse promised by a client call.
|
||||
type ConnectResponse_Promise struct{ *capnp.Pipeline }
|
||||
|
||||
func (p ConnectResponse_Promise) Struct() (ConnectResponse, error) {
|
||||
s, err := p.Pipeline.Struct()
|
||||
return ConnectResponse{s}, err
|
||||
}
|
||||
|
||||
const schema_b29021ef7421cc32 = "x\xda\xac\x91Ak\x13A\x1c\xc5\xdf\x9bI\\\x85\xe8" +
|
||||
"fH\x15DCi\x11\xb5\xc5\x06\x9b\x08\x82\xa7\x80\x15" +
|
||||
"TZ\xcc\x14\xcf\x96\xedv\xb05\xed\xee$;\xb5\xe4" +
|
||||
"\x13x\xf5&\x1e=\x0a\x8a\xe8\x17\xf0\xa2\xa0\xa0\x88\x88" +
|
||||
"\x1f\xc0\x83\x07O\xfd\x04\xb22\x0b\xdb@\xc9\xc1Co" +
|
||||
"\x7f\xde<\xde\xfc\xfe\xffW\xff\xd6\x15\x8b\xd5\x87\x04t" +
|
||||
"\xbdz,\xbf\xf4d\xff\xfc\xe7\x96|\x0b\xd5d\xde\xfe" +
|
||||
"2\xe3\xf6g\x9e\xbeCU\x04\xc0\xe2\xf3GT\xaf\x03" +
|
||||
"@\xbd\xdc\x03\xf3\xa8\xfb\xa0\xf2\xe2\xcc\xe0\x03t\x93\x87" +
|
||||
"\xad\x9d\xb3\\gc\x81\x01\xd0\x98\xe3\x1b0\xff4\xfa" +
|
||||
"q\xf1\xd5\xb9\xd6G\xa8\xa6\x18\x9b\xc1\xceO\xef\xfcS" +
|
||||
"8\x7f\xf3\x1e\x98_\xff\xfa\xfd\xfd\xb3\xfe\xd2\xaf\x09\x04" +
|
||||
"\x9d\xbfl\xb3q\xd2\x8f\x8d\x13\xc2C\x0cv\xb7\xe2\xb5" +
|
||||
"\x1d\xe3*\xd1F\xe4\xa25;L]\x1a\xa7\xdb\xad8" +
|
||||
"\xb2\x89\xbdq3M\x12\x13\xbbU\x93\xd90M2\xd3" +
|
||||
"#\xf5qY\x01*\x04\xd4\\\x1b\xd0\x17$\xf5UA" +
|
||||
"EN\xd1\x8b\x0bw\x01}ER\xdf\x16\x9c6\xc3a" +
|
||||
":d\x0d\x8250\xdf1\xae\xf8\x05\x00O\x81=I" +
|
||||
"\xd6\xc7\xb4\xa0\x17\xff\x17h\xb0\x1b\x98\xccy\x9e\xda\x01" +
|
||||
"\xcf\xady@w%\xf5\xb2`\x89s\xc7kK\x92\xba" +
|
||||
"'\xa8\x04\xa7(\x00\xb5\xe2\x19\x97%\xf5\xa6`\xb8a" +
|
||||
"2W\"\x86nd\x0d\xc3\xf1\xb1A\x86GJ\xbe\x95" +
|
||||
"&\xf7\x83\x91-.Y+`\x9a\xf3>@\x9d^\x05" +
|
||||
"(\x94\x9a\x05\xc2M\xe7l\xbeg\xd6\xb34\xee\x1b\xd0" +
|
||||
"\x05.\xb6\x07\xf1rb\xfc\x8aq\xd3\xc5\xc3\xa1\x8af" +
|
||||
"'U\xe4\xc5\xcb\x92\xfa\x9a`\xd07\xa3r\xfb\xe0q" +
|
||||
"\xb4]\xce\xff\x02\x00\x00\xff\xff\x14\xd5\xb6\xda"
|
||||
|
||||
func init() {
|
||||
schemas.Register(schema_b29021ef7421cc32,
|
||||
0xb1032ec91cef8727,
|
||||
0xc47116a1045e4061,
|
||||
0xc52e1bac26d379c8,
|
||||
0xe1446b97bfd1cd37)
|
||||
}
|
||||
@@ -1,195 +0,0 @@
|
||||
using Go = import "go.capnp";
|
||||
@0xdb8274f9144abc7e;
|
||||
$Go.package("tunnelrpc");
|
||||
$Go.import("github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc");
|
||||
|
||||
# === DEPRECATED Legacy Tunnel Authentication and Registration methods/servers ===
|
||||
#
|
||||
# These structs and interfaces are no longer used but it is important to keep
|
||||
# them around to make sure backwards compatibility within the rpc protocol is
|
||||
# maintained.
|
||||
|
||||
struct Authentication @0xc082ef6e0d42ed1d {
|
||||
# DEPRECATED: Legacy tunnel authentication mechanism
|
||||
key @0 :Text;
|
||||
email @1 :Text;
|
||||
originCAKey @2 :Text;
|
||||
}
|
||||
|
||||
struct TunnelRegistration @0xf41a0f001ad49e46 {
|
||||
# DEPRECATED: Legacy tunnel authentication mechanism
|
||||
err @0 :Text;
|
||||
# the url to access the tunnel
|
||||
url @1 :Text;
|
||||
# Used to inform the client of actions taken.
|
||||
logLines @2 :List(Text);
|
||||
# In case of error, whether the client should attempt to reconnect.
|
||||
permanentFailure @3 :Bool;
|
||||
# Displayed to user
|
||||
tunnelID @4 :Text;
|
||||
# How long should this connection wait to retry in seconds, if the error wasn't permanent
|
||||
retryAfterSeconds @5 :UInt16;
|
||||
# A unique ID used to reconnect this tunnel.
|
||||
eventDigest @6 :Data;
|
||||
# A unique ID used to prove this tunnel was previously connected to a given metal.
|
||||
connDigest @7 :Data;
|
||||
}
|
||||
|
||||
struct RegistrationOptions @0xc793e50592935b4a {
|
||||
# DEPRECATED: Legacy tunnel authentication mechanism
|
||||
|
||||
# The tunnel client's unique identifier, used to verify a reconnection.
|
||||
clientId @0 :Text;
|
||||
# Information about the running binary.
|
||||
version @1 :Text;
|
||||
os @2 :Text;
|
||||
# What to do with existing tunnels for the given hostname.
|
||||
existingTunnelPolicy @3 :ExistingTunnelPolicy;
|
||||
# If using the balancing policy, identifies the LB pool to use.
|
||||
poolName @4 :Text;
|
||||
# Client-defined tags to associate with the tunnel
|
||||
tags @5 :List(Tag);
|
||||
# A unique identifier for a high-availability connection made by a single client.
|
||||
connectionId @6 :UInt8;
|
||||
# origin LAN IP
|
||||
originLocalIp @7 :Text;
|
||||
# whether Argo Tunnel client has been autoupdated
|
||||
isAutoupdated @8 :Bool;
|
||||
# whether Argo Tunnel client is run from a terminal
|
||||
runFromTerminal @9 :Bool;
|
||||
# cross stream compression setting, 0 - off, 3 - high
|
||||
compressionQuality @10 :UInt64;
|
||||
uuid @11 :Text;
|
||||
# number of previous attempts to send RegisterTunnel/ReconnectTunnel
|
||||
numPreviousAttempts @12 :UInt8;
|
||||
# Set of features this cloudflared knows it supports
|
||||
features @13 :List(Text);
|
||||
}
|
||||
|
||||
enum ExistingTunnelPolicy @0x84cb9536a2cf6d3c {
|
||||
# DEPRECATED: Legacy tunnel registration mechanism
|
||||
|
||||
ignore @0;
|
||||
disconnect @1;
|
||||
balance @2;
|
||||
}
|
||||
|
||||
struct ServerInfo @0xf2c68e2547ec3866 {
|
||||
# DEPRECATED: Legacy tunnel registration mechanism
|
||||
|
||||
locationName @0 :Text;
|
||||
}
|
||||
|
||||
struct AuthenticateResponse @0x82c325a07ad22a65 {
|
||||
# DEPRECATED: Legacy tunnel registration mechanism
|
||||
|
||||
permanentErr @0 :Text;
|
||||
retryableErr @1 :Text;
|
||||
jwt @2 :Data;
|
||||
hoursUntilRefresh @3 :UInt8;
|
||||
}
|
||||
|
||||
interface TunnelServer @0xea58385c65416035 extends (RegistrationServer) {
|
||||
# DEPRECATED: Legacy tunnel authentication server
|
||||
|
||||
registerTunnel @0 (originCert :Data, hostname :Text, options :RegistrationOptions) -> (result :TunnelRegistration);
|
||||
getServerInfo @1 () -> (result :ServerInfo);
|
||||
unregisterTunnel @2 (gracePeriodNanoSec :Int64) -> ();
|
||||
# obsoleteDeclarativeTunnelConnect RPC deprecated in TUN-3019
|
||||
obsoleteDeclarativeTunnelConnect @3 () -> ();
|
||||
authenticate @4 (originCert :Data, hostname :Text, options :RegistrationOptions) -> (result :AuthenticateResponse);
|
||||
reconnectTunnel @5 (jwt :Data, eventDigest :Data, connDigest :Data, hostname :Text, options :RegistrationOptions) -> (result :TunnelRegistration);
|
||||
}
|
||||
|
||||
struct Tag @0xcbd96442ae3bb01a {
|
||||
# DEPRECATED: Legacy tunnel additional HTTP header mechanism
|
||||
|
||||
name @0 :Text;
|
||||
value @1 :Text;
|
||||
}
|
||||
|
||||
# === End DEPRECATED Objects ===
|
||||
|
||||
struct ClientInfo @0x83ced0145b2f114b {
|
||||
# The tunnel client's unique identifier, used to verify a reconnection.
|
||||
clientId @0 :Data;
|
||||
# Set of features this cloudflared knows it supports
|
||||
features @1 :List(Text);
|
||||
# Information about the running binary.
|
||||
version @2 :Text;
|
||||
# Client OS and CPU info
|
||||
arch @3 :Text;
|
||||
}
|
||||
|
||||
struct ConnectionOptions @0xb4bf9861fe035d04 {
|
||||
# client details
|
||||
client @0 :ClientInfo;
|
||||
# origin LAN IP
|
||||
originLocalIp @1 :Data;
|
||||
# What to do if connection already exists
|
||||
replaceExisting @2 :Bool;
|
||||
# cross stream compression setting, 0 - off, 3 - high
|
||||
compressionQuality @3 :UInt8;
|
||||
# number of previous attempts to send RegisterConnection
|
||||
numPreviousAttempts @4 :UInt8;
|
||||
}
|
||||
|
||||
struct ConnectionResponse @0xdbaa9d03d52b62dc {
|
||||
result :union {
|
||||
error @0 :ConnectionError;
|
||||
connectionDetails @1 :ConnectionDetails;
|
||||
}
|
||||
}
|
||||
|
||||
struct ConnectionError @0xf5f383d2785edb86 {
|
||||
cause @0 :Text;
|
||||
# How long should this connection wait to retry in ns
|
||||
retryAfter @1 :Int64;
|
||||
shouldRetry @2 :Bool;
|
||||
}
|
||||
|
||||
struct ConnectionDetails @0xb5f39f082b9ac18a {
|
||||
# identifier of this connection
|
||||
uuid @0 :Data;
|
||||
# airport code of the colo where this connection landed
|
||||
locationName @1 :Text;
|
||||
# tells if the tunnel is remotely managed
|
||||
tunnelIsRemotelyManaged @2: Bool;
|
||||
}
|
||||
|
||||
struct TunnelAuth @0x9496331ab9cd463f {
|
||||
accountTag @0 :Text;
|
||||
tunnelSecret @1 :Data;
|
||||
}
|
||||
|
||||
interface RegistrationServer @0xf71695ec7fe85497 {
|
||||
registerConnection @0 (auth :TunnelAuth, tunnelId :Data, connIndex :UInt8, options :ConnectionOptions) -> (result :ConnectionResponse);
|
||||
unregisterConnection @1 () -> ();
|
||||
updateLocalConfiguration @2 (config :Data) -> ();
|
||||
}
|
||||
|
||||
struct RegisterUdpSessionResponse @0xab6d5210c1f26687 {
|
||||
err @0 :Text;
|
||||
spans @1 :Data;
|
||||
}
|
||||
|
||||
interface SessionManager @0x839445a59fb01686 {
|
||||
# Let the edge decide closeAfterIdle to make sure cloudflared doesn't close session before the edge closes its side
|
||||
registerUdpSession @0 (sessionId :Data, dstIp :Data, dstPort :UInt16, closeAfterIdleHint :Int64, traceContext :Text = "") -> (result :RegisterUdpSessionResponse);
|
||||
unregisterUdpSession @1 (sessionId :Data, message :Text) -> ();
|
||||
}
|
||||
|
||||
struct UpdateConfigurationResponse @0xdb58ff694ba05cf9 {
|
||||
# Latest configuration that was applied successfully. The err field might be populated at the same time to indicate
|
||||
# that cloudflared is using an older configuration because the latest cannot be applied
|
||||
latestAppliedVersion @0 :Int32;
|
||||
# Any error encountered when trying to apply the last configuration
|
||||
err @1 :Text;
|
||||
}
|
||||
|
||||
# ConfigurationManager defines RPC to manage cloudflared configuration remotely
|
||||
interface ConfigurationManager @0xb48edfbdaa25db04 {
|
||||
updateConfiguration @0 (version :Int32, config :Data) -> (result: UpdateConfigurationResponse);
|
||||
}
|
||||
|
||||
interface CloudflaredServer @0xf548cef9dea2a4a1 extends(SessionManager, ConfigurationManager) {}
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user