mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-14 04:38:28 +10:00
- Set QUIC InitialPacketSize per IP family (IPv4: 1252, IPv6: 1232)
- Set MaxIncomingStreams/MaxIncomingUniStreams to 1<<60
- Populate OriginLocalIP from the local socket address in both QUIC and HTTP/2
- Pass NumPreviousAttempts from the retry counter to registration
- Include the version number in the client version string
- Use OS_GOARCH format for the Arch field
418 lines
11 KiB
Go
418 lines
11 KiB
Go
//go:build with_cloudflare_tunnel
|
|
|
|
package cloudflare
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"io"
|
|
"math/rand"
|
|
"net/url"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/sagernet/sing-box/adapter"
|
|
"github.com/sagernet/sing-box/adapter/inbound"
|
|
C "github.com/sagernet/sing-box/constant"
|
|
"github.com/sagernet/sing-box/log"
|
|
"github.com/sagernet/sing-box/option"
|
|
E "github.com/sagernet/sing/common/exceptions"
|
|
"github.com/sagernet/sing/common/json"
|
|
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
func RegisterInbound(registry *inbound.Registry) {
|
|
inbound.Register[option.CloudflareTunnelInboundOptions](registry, C.TypeCloudflareTunnel, NewInbound)
|
|
}
|
|
|
|
type Inbound struct {
|
|
inbound.Adapter
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
router adapter.ConnectionRouterEx
|
|
logger log.ContextLogger
|
|
credentials Credentials
|
|
connectorID uuid.UUID
|
|
haConnections int
|
|
protocol string
|
|
edgeIPVersion int
|
|
datagramVersion string
|
|
gracePeriod time.Duration
|
|
|
|
connectionAccess sync.Mutex
|
|
connections []io.Closer
|
|
done sync.WaitGroup
|
|
|
|
datagramMuxerAccess sync.Mutex
|
|
datagramV2Muxers map[DatagramSender]*DatagramV2Muxer
|
|
datagramV3Muxers map[DatagramSender]*DatagramV3Muxer
|
|
|
|
ingressAccess sync.RWMutex
|
|
ingressVersion int32
|
|
ingressRules []IngressRule
|
|
}
|
|
|
|
// IngressRule associates a hostname pattern with the origin service that
// handles requests for it.
type IngressRule struct {
	// Hostname is an exact host name, a "*."-prefixed wildcard, or empty.
	// An empty Hostname matches every request (catch-all rule).
	Hostname string
	// Service is the origin service, e.g. an origin URL or an
	// "http_status:" entry.
	Service string
}
|
|
|
|
type (
	// ingressConfig is the JSON document decoded by UpdateIngress.
	ingressConfig struct {
		Ingress []ingressConfigRule `json:"ingress"`
	}

	// ingressConfigRule is one entry of ingressConfig.Ingress.
	ingressConfigRule struct {
		Hostname string `json:"hostname,omitempty"`
		Service  string `json:"service"`
	}
)
|
|
|
|
// UpdateIngress applies a new ingress configuration from the edge.
|
|
func (i *Inbound) UpdateIngress(version int32, config []byte) {
|
|
i.ingressAccess.Lock()
|
|
defer i.ingressAccess.Unlock()
|
|
|
|
if version <= i.ingressVersion {
|
|
return
|
|
}
|
|
|
|
var parsed ingressConfig
|
|
err := json.Unmarshal(config, &parsed)
|
|
if err != nil {
|
|
i.logger.Error("parse ingress config: ", err)
|
|
return
|
|
}
|
|
|
|
rules := make([]IngressRule, 0, len(parsed.Ingress))
|
|
for _, rule := range parsed.Ingress {
|
|
rules = append(rules, IngressRule{
|
|
Hostname: rule.Hostname,
|
|
Service: rule.Service,
|
|
})
|
|
}
|
|
i.ingressRules = rules
|
|
i.ingressVersion = version
|
|
i.logger.Info("updated ingress configuration (version ", version, ", ", len(rules), " rules)")
|
|
}
|
|
|
|
// ResolveOrigin finds the origin service URL for a given hostname.
|
|
// Returns the service URL if matched, or empty string if no match.
|
|
func (i *Inbound) ResolveOrigin(hostname string) string {
|
|
i.ingressAccess.RLock()
|
|
defer i.ingressAccess.RUnlock()
|
|
|
|
for _, rule := range i.ingressRules {
|
|
if rule.Hostname == "" {
|
|
return rule.Service
|
|
}
|
|
if matchIngress(rule.Hostname, hostname) {
|
|
return rule.Service
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// matchIngress reports whether hostname is selected by the ingress hostname
// pattern.
//
// A pattern matches either when it equals hostname, or, for a pattern of the
// form "*.suffix", when hostname ends in ".suffix" and has at least one
// character in front of that dot. Host names are compared without regard to
// letter case, since DNS names are case-insensitive.
func matchIngress(pattern, hostname string) bool {
	if strings.EqualFold(pattern, hostname) {
		return true
	}
	if !strings.HasPrefix(pattern, "*.") {
		return false
	}

	// Drop only the '*', keeping the dot, so that "*.example.com" does not
	// match "badexample.com".
	suffix := pattern[1:]
	if len(hostname) <= len(suffix) {
		// Nothing left for the wildcard label (e.g. ".example.com").
		return false
	}
	return strings.EqualFold(hostname[len(hostname)-len(suffix):], suffix)
}
|
|
|
|
// ResolveOriginURL rewrites a request URL to point to the origin service.
|
|
// For example, https://testbox.badnet.work/path → http://127.0.0.1:8083/path
|
|
func (i *Inbound) ResolveOriginURL(requestURL string) string {
|
|
parsed, err := url.Parse(requestURL)
|
|
if err != nil {
|
|
return requestURL
|
|
}
|
|
hostname := parsed.Hostname()
|
|
origin := i.ResolveOrigin(hostname)
|
|
if origin == "" || strings.HasPrefix(origin, "http_status:") {
|
|
return requestURL
|
|
}
|
|
originURL, err := url.Parse(origin)
|
|
if err != nil {
|
|
return requestURL
|
|
}
|
|
parsed.Scheme = originURL.Scheme
|
|
parsed.Host = originURL.Host
|
|
return parsed.String()
|
|
}
|
|
|
|
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.CloudflareTunnelInboundOptions) (adapter.Inbound, error) {
|
|
credentials, err := parseCredentials(options.Token, options.CredentialPath)
|
|
if err != nil {
|
|
return nil, E.Cause(err, "parse credentials")
|
|
}
|
|
|
|
haConnections := options.HAConnections
|
|
if haConnections <= 0 {
|
|
haConnections = 4
|
|
}
|
|
|
|
protocol := options.Protocol
|
|
if protocol != "" && protocol != "quic" && protocol != "http2" {
|
|
return nil, E.New("unsupported protocol: ", protocol, ", expected quic or http2")
|
|
}
|
|
|
|
edgeIPVersion := options.EdgeIPVersion
|
|
if edgeIPVersion != 0 && edgeIPVersion != 4 && edgeIPVersion != 6 {
|
|
return nil, E.New("unsupported edge_ip_version: ", edgeIPVersion, ", expected 0, 4 or 6")
|
|
}
|
|
|
|
datagramVersion := options.DatagramVersion
|
|
if datagramVersion != "" && datagramVersion != "v2" && datagramVersion != "v3" {
|
|
return nil, E.New("unsupported datagram_version: ", datagramVersion, ", expected v2 or v3")
|
|
}
|
|
|
|
gracePeriod := time.Duration(options.GracePeriod)
|
|
if gracePeriod == 0 {
|
|
gracePeriod = 30 * time.Second
|
|
}
|
|
|
|
inboundCtx, cancel := context.WithCancel(ctx)
|
|
|
|
return &Inbound{
|
|
Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, tag),
|
|
ctx: inboundCtx,
|
|
cancel: cancel,
|
|
router: router,
|
|
logger: logger,
|
|
credentials: credentials,
|
|
connectorID: uuid.New(),
|
|
haConnections: haConnections,
|
|
protocol: protocol,
|
|
edgeIPVersion: edgeIPVersion,
|
|
datagramVersion: datagramVersion,
|
|
gracePeriod: gracePeriod,
|
|
datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer),
|
|
datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer),
|
|
}, nil
|
|
}
|
|
|
|
func (i *Inbound) Start(stage adapter.StartStage) error {
|
|
if stage != adapter.StartStateStart {
|
|
return nil
|
|
}
|
|
|
|
i.logger.Info("starting Cloudflare Tunnel with ", i.haConnections, " HA connections")
|
|
|
|
regions, err := DiscoverEdge(i.ctx)
|
|
if err != nil {
|
|
return E.Cause(err, "discover edge")
|
|
}
|
|
regions = FilterByIPVersion(regions, i.edgeIPVersion)
|
|
edgeAddrs := flattenRegions(regions)
|
|
if len(edgeAddrs) == 0 {
|
|
return E.New("no edge addresses available")
|
|
}
|
|
|
|
features := DefaultFeatures(i.datagramVersion)
|
|
|
|
for connIndex := 0; connIndex < i.haConnections; connIndex++ {
|
|
i.done.Add(1)
|
|
go i.superviseConnection(uint8(connIndex), edgeAddrs, features)
|
|
if connIndex == 0 {
|
|
// Wait a bit for the first connection before starting others
|
|
select {
|
|
case <-time.After(time.Second):
|
|
case <-i.ctx.Done():
|
|
return i.ctx.Err()
|
|
}
|
|
} else {
|
|
select {
|
|
case <-time.After(time.Second):
|
|
case <-i.ctx.Done():
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (i *Inbound) Close() error {
|
|
i.cancel()
|
|
i.done.Wait()
|
|
i.connectionAccess.Lock()
|
|
for _, connection := range i.connections {
|
|
connection.Close()
|
|
}
|
|
i.connections = nil
|
|
i.connectionAccess.Unlock()
|
|
return nil
|
|
}
|
|
|
|
// Reconnect backoff parameters used by backoffDuration.
const (
	// backoffBaseTime is the delay unit that is doubled per retry.
	backoffBaseTime = time.Second
	// backoffMaxTime is the upper bound of the delay before jitter is applied.
	backoffMaxTime = 2 * time.Minute
)
|
|
|
|
func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr, features []string) {
|
|
defer i.done.Done()
|
|
|
|
retries := 0
|
|
for {
|
|
select {
|
|
case <-i.ctx.Done():
|
|
return
|
|
default:
|
|
}
|
|
|
|
edgeAddr := edgeAddrs[rand.Intn(len(edgeAddrs))]
|
|
err := i.serveConnection(connIndex, edgeAddr, features, uint8(retries))
|
|
if err == nil || i.ctx.Err() != nil {
|
|
return
|
|
}
|
|
|
|
retries++
|
|
backoff := backoffDuration(retries)
|
|
i.logger.Error("connection ", connIndex, " failed: ", err, ", retrying in ", backoff)
|
|
|
|
select {
|
|
case <-time.After(backoff):
|
|
case <-i.ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (i *Inbound) serveConnection(connIndex uint8, edgeAddr *EdgeAddr, features []string, numPreviousAttempts uint8) error {
|
|
protocol := i.protocol
|
|
if protocol == "" {
|
|
protocol = "quic"
|
|
}
|
|
|
|
switch protocol {
|
|
case "quic":
|
|
return i.serveQUIC(connIndex, edgeAddr, features, numPreviousAttempts)
|
|
case "http2":
|
|
return i.serveHTTP2(connIndex, edgeAddr, features, numPreviousAttempts)
|
|
default:
|
|
return E.New("unsupported protocol: ", protocol)
|
|
}
|
|
}
|
|
|
|
func (i *Inbound) serveQUIC(connIndex uint8, edgeAddr *EdgeAddr, features []string, numPreviousAttempts uint8) error {
|
|
i.logger.Info("connecting to edge via QUIC (connection ", connIndex, ")")
|
|
|
|
connection, err := NewQUICConnection(
|
|
i.ctx, edgeAddr, connIndex,
|
|
i.credentials, i.connectorID,
|
|
features, numPreviousAttempts, i.gracePeriod, i.logger,
|
|
)
|
|
if err != nil {
|
|
return E.Cause(err, "create QUIC connection")
|
|
}
|
|
|
|
i.trackConnection(connection)
|
|
defer func() {
|
|
i.untrackConnection(connection)
|
|
i.RemoveDatagramMuxer(connection)
|
|
}()
|
|
|
|
return connection.Serve(i.ctx, i)
|
|
}
|
|
|
|
func (i *Inbound) serveHTTP2(connIndex uint8, edgeAddr *EdgeAddr, features []string, numPreviousAttempts uint8) error {
|
|
i.logger.Info("connecting to edge via HTTP/2 (connection ", connIndex, ")")
|
|
|
|
connection, err := NewHTTP2Connection(
|
|
i.ctx, edgeAddr, connIndex,
|
|
i.credentials, i.connectorID,
|
|
features, numPreviousAttempts, i.gracePeriod, i, i.logger,
|
|
)
|
|
if err != nil {
|
|
return E.Cause(err, "create HTTP/2 connection")
|
|
}
|
|
|
|
i.trackConnection(connection)
|
|
defer i.untrackConnection(connection)
|
|
|
|
return connection.Serve(i.ctx)
|
|
}
|
|
|
|
func (i *Inbound) trackConnection(connection io.Closer) {
|
|
i.connectionAccess.Lock()
|
|
defer i.connectionAccess.Unlock()
|
|
i.connections = append(i.connections, connection)
|
|
}
|
|
|
|
func (i *Inbound) untrackConnection(connection io.Closer) {
|
|
i.connectionAccess.Lock()
|
|
defer i.connectionAccess.Unlock()
|
|
for index, tracked := range i.connections {
|
|
if tracked == connection {
|
|
i.connections = append(i.connections[:index], i.connections[index+1:]...)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
func backoffDuration(retries int) time.Duration {
|
|
backoff := backoffBaseTime * (1 << min(retries, 7))
|
|
if backoff > backoffMaxTime {
|
|
backoff = backoffMaxTime
|
|
}
|
|
// Add jitter: random duration in [backoff/2, backoff)
|
|
jitter := time.Duration(rand.Int63n(int64(backoff / 2)))
|
|
return backoff/2 + jitter
|
|
}
|
|
|
|
func flattenRegions(regions [][]*EdgeAddr) []*EdgeAddr {
|
|
var result []*EdgeAddr
|
|
for _, region := range regions {
|
|
result = append(result, region...)
|
|
}
|
|
return result
|
|
}
|
|
|
|
func parseCredentials(token string, credentialPath string) (Credentials, error) {
|
|
if token == "" && credentialPath == "" {
|
|
return Credentials{}, E.New("either token or credential_path must be specified")
|
|
}
|
|
if token != "" && credentialPath != "" {
|
|
return Credentials{}, E.New("token and credential_path are mutually exclusive")
|
|
}
|
|
if token != "" {
|
|
return parseToken(token)
|
|
}
|
|
return parseCredentialFile(credentialPath)
|
|
}
|
|
|
|
func parseToken(token string) (Credentials, error) {
|
|
data, err := base64.StdEncoding.DecodeString(token)
|
|
if err != nil {
|
|
return Credentials{}, E.Cause(err, "decode token")
|
|
}
|
|
var tunnelToken TunnelToken
|
|
err = json.Unmarshal(data, &tunnelToken)
|
|
if err != nil {
|
|
return Credentials{}, E.Cause(err, "unmarshal token")
|
|
}
|
|
return tunnelToken.ToCredentials(), nil
|
|
}
|
|
|
|
func parseCredentialFile(path string) (Credentials, error) {
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
return Credentials{}, E.Cause(err, "read credential file")
|
|
}
|
|
var credentials Credentials
|
|
err = json.Unmarshal(data, &credentials)
|
|
if err != nil {
|
|
return Credentials{}, E.Cause(err, "unmarshal credential file")
|
|
}
|
|
if credentials.TunnelID == (uuid.UUID{}) {
|
|
return Credentials{}, E.New("credential file missing tunnel ID")
|
|
}
|
|
return credentials, nil
|
|
}
|