Files
sing-box/protocol/cloudflare/inbound.go
2026-03-31 15:32:56 +08:00

426 lines
12 KiB
Go

//go:build with_cloudflared
package cloudflare
import (
"context"
stdTLS "crypto/tls"
"encoding/base64"
"errors"
"io"
"math/rand"
"net"
"net/http"
"net/url"
"sync"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound"
boxDialer "github.com/sagernet/sing-box/common/dialer"
boxTLS "github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
N "github.com/sagernet/sing/common/network"
"github.com/google/uuid"
)
// RegisterInbound registers NewInbound with registry as the constructor for
// the C.TypeCloudflared inbound type, decoding its options as
// option.CloudflaredInboundOptions.
func RegisterInbound(registry *inbound.Registry) {
	inbound.Register[option.CloudflaredInboundOptions](registry, C.TypeCloudflared, NewInbound)
}
// ErrNonRemoteManagedTunnelUnsupported signals that the tunnel is not
// remote-managed, which this inbound does not support. It is treated as a
// permanent failure: serveConnection does not fall back from QUIC to HTTP/2
// on it, and superviseConnection stops retrying and cancels the inbound.
// NOTE(review): the code that produces this error is outside this section.
var ErrNonRemoteManagedTunnelUnsupported = errors.New("cloudflared only supports remote-managed tunnels")
// Inbound is the cloudflared (Cloudflare Tunnel) inbound adapter. It keeps
// haConnections supervised connections to the Cloudflare edge and applies
// ingress configuration delivered through ApplyConfig.
type Inbound struct {
	inbound.Adapter
	// ctx is derived from the construction context in NewInbound. cancel
	// cancels it; it is called by Close and on a permanent connection failure.
	ctx    context.Context
	cancel context.CancelFunc
	router adapter.Router
	logger log.ContextLogger
	// credentials are parsed from the tunnel token.
	credentials Credentials
	// connectorID is a random UUID generated once per inbound instance.
	connectorID uuid.UUID
	// haConnections is the number of edge connections Start launches.
	haConnections int
	// protocol is "", "quic" or "http2"; "" means QUIC with HTTP/2 fallback.
	protocol string
	// region is the configured region, or the credentials' endpoint if unset.
	region string
	// edgeIPVersion (0, 4 or 6) is used to filter discovered edge addresses.
	edgeIPVersion int
	// datagramVersion is "", "v2" or "v3" and is passed to DefaultFeatures.
	datagramVersion string
	gracePeriod     time.Duration
	// configManager holds the ingress configuration applied via ApplyConfig.
	configManager *ConfigManager
	// NOTE(review): flowLimiter and accessCache are used outside this section.
	flowLimiter *FlowLimiter
	accessCache *accessValidatorCache
	// controlDialer is used for edge discovery, QUIC connections and the
	// access validator cache.
	controlDialer N.Dialer
	// connectionAccess guards connections, the closers of live edge
	// connections that Close shuts down.
	connectionAccess sync.Mutex
	connections      []io.Closer
	// done counts running superviseConnection goroutines.
	done sync.WaitGroup
	// NOTE(review): presumably datagramMuxerAccess guards both muxer maps; the
	// code that uses them is outside this section.
	datagramMuxerAccess sync.Mutex
	datagramV2Muxers    map[DatagramSender]*DatagramV2Muxer
	datagramV3Muxers    map[DatagramSender]*DatagramV3Muxer
	// helloWorldAccess guards the lazily started local hello-world HTTPS
	// server and its URL.
	helloWorldAccess sync.Mutex
	helloWorldServer *http.Server
	helloWorldURL    *url.URL
	// connectedAccess guards connectedIndices, the set of connection indices
	// already reported on connectedNotify. Start reads connectedNotify.
	connectedAccess  sync.Mutex
	connectedIndices map[uint8]struct{}
	connectedNotify  chan uint8
}
// NewInbound constructs the cloudflared inbound from options.
//
// It does the following:
//   - parses the tunnel token into credentials;
//   - validates ha_connections, protocol, edge_ip_version, datagram_version
//     and region;
//   - applies the defaults: 4 HA connections, a 30s grace period, and the
//     region taken from the credentials' endpoint;
//   - builds the runtime config manager and the control-plane dialer.
//
// No edge connection is made here; Start does that.
//
// It returns an error when:
//   - the token is missing or cannot be decoded;
//   - an option is out of range;
//   - a region is given while the credentials already carry an endpoint;
//   - the config manager or the dialer cannot be built.
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.CloudflaredInboundOptions) (adapter.Inbound, error) {
	if options.Token == "" {
		return nil, E.New("missing token")
	}
	credentials, err := parseToken(options.Token)
	if err != nil {
		return nil, E.Cause(err, "parse token")
	}
	// Connection indices are carried as uint8, so at most 256 connections can
	// have distinct indices; more would silently alias each other.
	const maxHAConnections = 1 << 8
	haConnections := options.HAConnections
	switch {
	case haConnections <= 0:
		haConnections = 4
	case haConnections > maxHAConnections:
		return nil, E.New("unsupported ha_connections: ", haConnections, ", expected at most ", maxHAConnections)
	}
	protocol := options.Protocol
	if protocol != "" && protocol != "quic" && protocol != "http2" {
		return nil, E.New("unsupported protocol: ", protocol, ", expected quic or http2")
	}
	edgeIPVersion := options.EdgeIPVersion
	if edgeIPVersion != 0 && edgeIPVersion != 4 && edgeIPVersion != 6 {
		return nil, E.New("unsupported edge_ip_version: ", edgeIPVersion, ", expected 0, 4 or 6")
	}
	datagramVersion := options.DatagramVersion
	if datagramVersion != "" && datagramVersion != "v2" && datagramVersion != "v3" {
		return nil, E.New("unsupported datagram_version: ", datagramVersion, ", expected v2 or v3")
	}
	// Validate the region before building any runtime components.
	region := options.Region
	if region != "" && credentials.Endpoint != "" {
		return nil, E.New("region cannot be specified when credentials already include an endpoint")
	}
	if region == "" {
		region = credentials.Endpoint
	}
	gracePeriod := time.Duration(options.GracePeriod)
	if gracePeriod == 0 {
		gracePeriod = 30 * time.Second
	}
	configManager, err := NewConfigManager()
	if err != nil {
		return nil, E.Cause(err, "build cloudflared runtime config")
	}
	controlDialer, err := boxDialer.NewWithOptions(boxDialer.Options{
		Context:        ctx,
		Options:        options.ControlDialer,
		RemoteIsDomain: true,
	})
	if err != nil {
		return nil, E.Cause(err, "build cloudflared control dialer")
	}
	inboundCtx, cancel := context.WithCancel(ctx)
	return &Inbound{
		Adapter:          inbound.NewAdapter(C.TypeCloudflared, tag),
		ctx:              inboundCtx,
		cancel:           cancel,
		router:           router,
		logger:           logger,
		credentials:      credentials,
		connectorID:      uuid.New(),
		haConnections:    haConnections,
		protocol:         protocol,
		region:           region,
		edgeIPVersion:    edgeIPVersion,
		datagramVersion:  datagramVersion,
		gracePeriod:      gracePeriod,
		configManager:    configManager,
		flowLimiter:      &FlowLimiter{},
		accessCache:      &accessValidatorCache{values: make(map[string]accessValidator), dialer: controlDialer},
		controlDialer:    controlDialer,
		datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer),
		datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer),
		connectedIndices: make(map[uint8]struct{}),
		connectedNotify:  make(chan uint8, haConnections),
	}, nil
}
// Start connects the tunnel during adapter.StartStateStart; every other stage
// is a no-op.
//
// It discovers edge addresses for the configured region, filters them with
// FilterByIPVersion, and launches one superviseConnection goroutine per HA
// connection. After each launch it waits for one of three events:
//   - a readiness notification on connectedNotify;
//   - firstConnectionReadyTimeout elapsing;
//   - the inbound context finishing.
//
// If the context finishes while waiting on the first connection, its error
// is returned. For later connections Start returns nil in that case.
func (i *Inbound) Start(stage adapter.StartStage) error {
	if stage != adapter.StartStateStart {
		return nil
	}
	i.logger.Info("starting Cloudflare Tunnel with ", i.haConnections, " HA connections")
	regions, err := DiscoverEdge(i.ctx, i.region, i.controlDialer)
	if err != nil {
		return E.Cause(err, "discover edge")
	}
	edgeAddrs := flattenRegions(FilterByIPVersion(regions, i.edgeIPVersion))
	if len(edgeAddrs) == 0 {
		return E.New("no edge addresses available")
	}
	features := DefaultFeatures(i.datagramVersion)
	for launched := 0; launched < i.haConnections; launched++ {
		connIndex := uint8(launched)
		i.done.Add(1)
		go i.superviseConnection(connIndex, edgeAddrs, features)
		select {
		case <-i.ctx.Done():
			if launched > 0 {
				return nil
			}
			return i.ctx.Err()
		case readyIndex := <-i.connectedNotify:
			if readyIndex != connIndex {
				i.logger.Debug("received unexpected ready notification for connection ", readyIndex)
			}
		case <-time.After(firstConnectionReadyTimeout):
			// Do not block startup on a slow connection; keep launching.
		}
	}
	return nil
}
// notifyConnected reports on connectedNotify that connection connIndex became
// ready.
//
// Each index is reported at most once for the lifetime of the inbound. Repeat
// calls are no-ops, as are calls on an inbound without a notify channel. The
// channel send happens after connectedAccess is released.
// NOTE: the send blocks if the channel's buffer is full.
func (i *Inbound) notifyConnected(connIndex uint8) {
	if i.connectedNotify == nil {
		return
	}
	i.connectedAccess.Lock()
	_, alreadyReported := i.connectedIndices[connIndex]
	if !alreadyReported {
		i.connectedIndices[connIndex] = struct{}{}
	}
	i.connectedAccess.Unlock()
	if alreadyReported {
		return
	}
	i.connectedNotify <- connIndex
}
// ApplyConfig applies an ingress configuration of the given version through
// the config manager and logs the outcome: an error on failure, otherwise the
// last applied version. The manager's result is returned unchanged.
func (i *Inbound) ApplyConfig(version int32, config []byte) ConfigUpdateResult {
	result := i.configManager.Apply(version, config)
	if err := result.Err; err != nil {
		i.logger.Error("update ingress configuration: ", err)
	} else {
		i.logger.Info("updated ingress configuration (version ", result.LastAppliedVersion, ")")
	}
	return result
}
// maxActiveFlows returns WarpRouting.MaxActiveFlows from the config manager's
// current snapshot.
func (i *Inbound) maxActiveFlows() uint64 {
	snapshot := i.configManager.Snapshot()
	return snapshot.WarpRouting.MaxActiveFlows
}
// Close shuts the inbound down. The steps run in this order:
//   - cancel the inbound context;
//   - wait for every supervisor goroutine to exit;
//   - close any still-tracked edge connections;
//   - stop the local hello-world server, if one was started.
//
// Close is safe to call more than once and always returns nil; close errors
// are best-effort and discarded.
func (i *Inbound) Close() error {
	i.cancel()
	i.done.Wait()

	i.connectionAccess.Lock()
	for _, connection := range i.connections {
		_ = connection.Close() // best-effort during shutdown
	}
	i.connections = nil
	i.connectionAccess.Unlock()

	// helloWorldServer is written by ensureHelloWorldURL under
	// helloWorldAccess, so it must be read under the same lock. Clearing it
	// also keeps a repeated Close from closing the server twice.
	i.helloWorldAccess.Lock()
	server := i.helloWorldServer
	i.helloWorldServer = nil
	i.helloWorldAccess.Unlock()
	if server != nil {
		_ = server.Close() // best-effort during shutdown
	}
	return nil
}
// ensureHelloWorldURL returns the URL of a local hello-world HTTPS server,
// starting the server on first use.
//
// The server listens on 127.0.0.1 on an ephemeral port, serves a
// self-signed certificate for "localhost", and answers every path with
// "Hello World". Later calls return the cached URL. All state is guarded by
// helloWorldAccess.
func (i *Inbound) ensureHelloWorldURL() (*url.URL, error) {
	i.helloWorldAccess.Lock()
	defer i.helloWorldAccess.Unlock()
	if cached := i.helloWorldURL; cached != nil {
		return cached, nil
	}
	rawListener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return nil, E.Cause(err, "listen hello world server")
	}
	certificate, err := boxTLS.GenerateKeyPair(nil, nil, time.Now, "localhost")
	if err != nil {
		// Best-effort cleanup; the certificate error is the one reported.
		_ = rawListener.Close()
		return nil, E.Cause(err, "generate hello world certificate")
	}
	handler := http.NewServeMux()
	handler.HandleFunc("/", func(writer http.ResponseWriter, _ *http.Request) {
		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
		writer.WriteHeader(http.StatusOK)
		_, _ = writer.Write([]byte("Hello World"))
	})
	server := &http.Server{Handler: handler}
	tlsConfig := &stdTLS.Config{Certificates: []stdTLS.Certificate{*certificate}}
	go server.Serve(stdTLS.NewListener(rawListener, tlsConfig))
	i.helloWorldServer = server
	i.helloWorldURL = &url.URL{Scheme: "https", Host: rawListener.Addr().String()}
	return i.helloWorldURL, nil
}
// backoffBaseTime and backoffMaxTime bound the reconnection delay computed by
// backoffDuration.
const (
	backoffBaseTime = 1 * time.Second
	backoffMaxTime  = 120 * time.Second
)

// firstConnectionReadyTimeout is how long Start waits for a newly launched
// connection to report readiness before launching the next one.
const firstConnectionReadyTimeout = 15 * time.Second
// superviseConnection keeps HA connection connIndex running until the inbound
// context is done. Each attempt serves a randomly chosen entry of edgeAddrs
// (which must be non-empty) via serveConnection. The outcome decides what
// happens next:
//   - a nil error or a finished context ends supervision;
//   - ErrNonRemoteManagedTunnelUnsupported is permanent: it is logged and
//     the whole inbound is cancelled;
//   - any other error is logged and retried after backoffDuration.
//
// The caller must i.done.Add(1) beforehand; this function calls Done on exit.
// NOTE(review): retries is never reset, even after a connection served for a
// while before failing.
func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr, features []string) {
	defer i.done.Done()
	// numPreviousAttempts is a uint8: saturate at its maximum instead of
	// letting the conversion wrap back to 0 after 256 consecutive failures.
	const maxPreviousAttempts = 1<<8 - 1
	retries := 0
	for {
		select {
		case <-i.ctx.Done():
			return
		default:
		}
		edgeAddr := edgeAddrs[rand.Intn(len(edgeAddrs))]
		err := i.serveConnection(connIndex, edgeAddr, features, uint8(min(retries, maxPreviousAttempts)))
		if err == nil || i.ctx.Err() != nil {
			return
		}
		if errors.Is(err, ErrNonRemoteManagedTunnelUnsupported) {
			i.logger.Error("connection ", connIndex, " failed permanently: ", err)
			i.cancel()
			return
		}
		retries++
		backoff := backoffDuration(retries)
		i.logger.Error("connection ", connIndex, " failed: ", err, ", retrying in ", backoff)
		select {
		case <-time.After(backoff):
		case <-i.ctx.Done():
			return
		}
	}
}
// serveConnection runs a single connection attempt against edgeAddr using the
// configured protocol; an empty protocol selects QUIC.
//
// A failed QUIC attempt falls back to one HTTP/2 attempt, except in three
// cases, where the QUIC result is returned as-is:
//   - the QUIC attempt returned nil;
//   - the inbound context is done;
//   - the error is ErrNonRemoteManagedTunnelUnsupported.
//
// An unrecognised protocol returns an error without dialing.
func (i *Inbound) serveConnection(connIndex uint8, edgeAddr *EdgeAddr, features []string, numPreviousAttempts uint8) error {
	selected := i.protocol
	if selected == "" {
		selected = "quic"
	}
	if selected == "http2" {
		return i.serveHTTP2(connIndex, edgeAddr, features, numPreviousAttempts)
	}
	if selected != "quic" {
		return E.New("unsupported protocol: ", selected)
	}
	quicErr := i.serveQUIC(connIndex, edgeAddr, features, numPreviousAttempts)
	if quicErr == nil || i.ctx.Err() != nil || errors.Is(quicErr, ErrNonRemoteManagedTunnelUnsupported) {
		return quicErr
	}
	i.logger.Warn("QUIC connection failed, falling back to HTTP/2: ", quicErr)
	return i.serveHTTP2(connIndex, edgeAddr, features, numPreviousAttempts)
}
// serveQUIC opens a QUIC connection to edgeAddr and serves it until it ends.
//
// The connection reports readiness through notifyConnected. It is tracked so
// Close can shut it down. On return it is untracked and its datagram muxer is
// removed, in that order.
func (i *Inbound) serveQUIC(connIndex uint8, edgeAddr *EdgeAddr, features []string, numPreviousAttempts uint8) error {
	i.logger.Info("connecting to edge via QUIC (connection ", connIndex, ")")
	onConnected := func() {
		i.notifyConnected(connIndex)
	}
	quicConnection, err := NewQUICConnection(
		i.ctx, edgeAddr, connIndex,
		i.credentials, i.connectorID,
		features, numPreviousAttempts, i.gracePeriod, i.controlDialer,
		onConnected, i.logger,
	)
	if err != nil {
		return E.Cause(err, "create QUIC connection")
	}
	i.trackConnection(quicConnection)
	// Deferred calls run LIFO: untrack first, then drop the datagram muxer.
	defer i.RemoveDatagramMuxer(quicConnection)
	defer i.untrackConnection(quicConnection)
	return quicConnection.Serve(i.ctx, i)
}
// serveHTTP2 opens an HTTP/2 connection to edgeAddr and serves it until it
// ends. The inbound itself is passed to NewHTTP2Connection. The connection is
// tracked for Close while it is being served.
func (i *Inbound) serveHTTP2(connIndex uint8, edgeAddr *EdgeAddr, features []string, numPreviousAttempts uint8) error {
	i.logger.Info("connecting to edge via HTTP/2 (connection ", connIndex, ")")
	http2Connection, err := NewHTTP2Connection(
		i.ctx, edgeAddr, connIndex,
		i.credentials, i.connectorID,
		features, numPreviousAttempts, i.gracePeriod,
		i, i.logger,
	)
	if err != nil {
		return E.Cause(err, "create HTTP/2 connection")
	}
	i.trackConnection(http2Connection)
	defer func() {
		i.untrackConnection(http2Connection)
	}()
	return http2Connection.Serve(i.ctx)
}
// trackConnection records connection so that Close can shut it down. It is
// safe for concurrent use.
func (i *Inbound) trackConnection(connection io.Closer) {
	i.connectionAccess.Lock()
	i.connections = append(i.connections, connection)
	i.connectionAccess.Unlock()
}
// untrackConnection removes connection from the set that Close shuts down.
// It is a no-op if connection is not tracked, and is safe for concurrent use.
func (i *Inbound) untrackConnection(connection io.Closer) {
	i.connectionAccess.Lock()
	defer i.connectionAccess.Unlock()
	for index, tracked := range i.connections {
		if tracked != connection {
			continue
		}
		last := len(i.connections) - 1
		copy(i.connections[index:], i.connections[index+1:])
		// Clear the vacated tail slot so the backing array does not keep the
		// removed connection reachable.
		i.connections[last] = nil
		i.connections = i.connections[:last]
		return
	}
}
// backoffDuration returns the delay before the next reconnection attempt
// after retries failures.
//
// The nominal delay is backoffBaseTime doubled per retry, with the exponent
// capped at 7, and is clamped to backoffMaxTime. The result is jittered
// uniformly into [nominal/2, nominal).
//
// A retries value of 0 or less yields the base delay; previously a negative
// value caused a negative-shift panic.
func backoffDuration(retries int) time.Duration {
	exponent := min(max(retries, 0), 7)
	backoff := min(backoffBaseTime<<exponent, backoffMaxTime)
	half := backoff / 2
	// Add jitter: random duration in [backoff/2, backoff).
	return half + time.Duration(rand.Int63n(int64(half)))
}
// flattenRegions concatenates the per-region address lists into one slice,
// preserving region order. It returns nil when there are no addresses.
func flattenRegions(regions [][]*EdgeAddr) []*EdgeAddr {
	var flattened []*EdgeAddr
	for index := range regions {
		flattened = append(flattened, regions[index]...)
	}
	return flattened
}
// parseToken decodes a standard-base64 tunnel token, unmarshals the JSON it
// carries into a TunnelToken, and converts that into Credentials. Decoding
// and unmarshalling errors are wrapped with the failing step.
func parseToken(token string) (Credentials, error) {
	var tunnelToken TunnelToken
	decoded, err := base64.StdEncoding.DecodeString(token)
	if err != nil {
		return Credentials{}, E.Cause(err, "decode token")
	}
	if err := json.Unmarshal(decoded, &tunnelToken); err != nil {
		return Credentials{}, E.Cause(err, "unmarshal token")
	}
	return tunnelToken.ToCredentials(), nil
}