From 2321e941e0d0824a3df0c3e8cd9df097cd5dfd01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 13:52:55 +0800 Subject: [PATCH] Route cloudflare control plane through configurable dialer --- option/cloudflare_tunnel.go | 1 + protocol/cloudflare/access.go | 17 +++++++++--- protocol/cloudflare/access_test.go | 16 ++++++----- protocol/cloudflare/connection_http2.go | 19 +++++++++---- protocol/cloudflare/connection_quic.go | 31 ++++++++-------------- protocol/cloudflare/edge_discovery.go | 11 ++++---- protocol/cloudflare/edge_discovery_test.go | 4 ++- protocol/cloudflare/helpers_test.go | 2 ++ protocol/cloudflare/inbound.go | 18 ++++++++++--- 9 files changed, 75 insertions(+), 44 deletions(-) diff --git a/option/cloudflare_tunnel.go b/option/cloudflare_tunnel.go index c0fdbfa87..bf388a044 100644 --- a/option/cloudflare_tunnel.go +++ b/option/cloudflare_tunnel.go @@ -7,6 +7,7 @@ type CloudflareTunnelInboundOptions struct { CredentialPath string `json:"credential_path,omitempty"` HAConnections int `json:"ha_connections,omitempty"` Protocol string `json:"protocol,omitempty"` + ControlDialer DialerOptions `json:"control_dialer,omitempty"` EdgeIPVersion int `json:"edge_ip_version,omitempty"` DatagramVersion string `json:"datagram_version,omitempty"` GracePeriod badoption.Duration `json:"grace_period,omitempty"` diff --git a/protocol/cloudflare/access.go b/protocol/cloudflare/access.go index 9407d0312..75c1e8ada 100644 --- a/protocol/cloudflare/access.go +++ b/protocol/cloudflare/access.go @@ -5,19 +5,29 @@ package cloudflare import ( "context" "fmt" + "net" "net/http" "strings" "sync" "github.com/coreos/go-oidc/v3/oidc" E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" ) const accessJWTAssertionHeader = "Cf-Access-Jwt-Assertion" -var newAccessValidator = func(access AccessConfig) (accessValidator, error) { +var newAccessValidator = func(access AccessConfig, dialer N.Dialer) (accessValidator, error) { issuerURL := accessIssuerURL(access.TeamName, access.Environment) - keySet := oidc.NewRemoteKeySet(context.Background(), issuerURL+"/cdn-cgi/access/certs") + client := &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return dialer.DialContext(ctx, network, M.ParseSocksaddr(address)) + }, + }, + } + keySet := oidc.NewRemoteKeySet(oidc.ClientContext(context.Background(), client), issuerURL+"/cdn-cgi/access/certs") verifier := oidc.NewVerifier(issuerURL, keySet, &oidc.Config{ SkipClientIDCheck: true, }) @@ -82,6 +92,7 @@ func accessValidatorKey(access AccessConfig) string { type accessValidatorCache struct { access sync.RWMutex values map[string]accessValidator + dialer N.Dialer } func (c *accessValidatorCache) Get(accessConfig AccessConfig) (accessValidator, error) { @@ -93,7 +104,7 @@ func (c *accessValidatorCache) Get(accessConfig AccessConfig) (accessValidator, return validator, nil } - validator, err := newAccessValidator(accessConfig) + validator, err := newAccessValidator(accessConfig, c.dialer) if err != nil { return nil, err } diff --git a/protocol/cloudflare/access_test.go b/protocol/cloudflare/access_test.go index 357cd9f43..8c7d2b9e1 100644 --- a/protocol/cloudflare/access_test.go +++ b/protocol/cloudflare/access_test.go @@ -14,6 +14,7 @@ import ( "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" ) type fakeAccessValidator struct { @@ -31,10 +32,11 @@ func newAccessTestInbound(t *testing.T) *Inbound { t.Fatal(err) } return &Inbound{ - Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"), - logger: logFactory.NewLogger("test"), - accessCache: &accessValidatorCache{values: make(map[string]accessValidator)}, - router: &testRouter{}, + Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"), + logger: logFactory.NewLogger("test"), + accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: N.SystemDialer}, + router: &testRouter{}, + controlDialer: N.SystemDialer, } } @@ -53,7 +55,7 @@ func TestRoundTripHTTPAccessDenied(t *testing.T) { defer func() { newAccessValidator = originalFactory }() - newAccessValidator = func(access AccessConfig) (accessValidator, error) { + newAccessValidator = func(access AccessConfig, dialer N.Dialer) (accessValidator, error) { return &fakeAccessValidator{err: E.New("forbidden")}, nil } @@ -87,7 +89,7 @@ func TestHandleHTTPServiceStatusAccessDenied(t *testing.T) { defer func() { newAccessValidator = originalFactory }() - newAccessValidator = func(access AccessConfig) (accessValidator, error) { + newAccessValidator = func(access AccessConfig, dialer N.Dialer) (accessValidator, error) { return &fakeAccessValidator{err: E.New("forbidden")}, nil } @@ -121,7 +123,7 @@ func TestHandleHTTPServiceStreamAccessDenied(t *testing.T) { defer func() { newAccessValidator = originalFactory }() - newAccessValidator = func(access AccessConfig) (accessValidator, error) { + newAccessValidator = func(access AccessConfig, dialer N.Dialer) (accessValidator, error) { return &fakeAccessValidator{err: E.New("forbidden")}, nil } diff --git a/protocol/cloudflare/connection_http2.go b/protocol/cloudflare/connection_http2.go index 2de68d464..9fed72cf8 100644 --- a/protocol/cloudflare/connection_http2.go +++ b/protocol/cloudflare/connection_http2.go @@ -17,6 +17,7 @@ import ( "github.com/sagernet/sing-box/log" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/json" + M "github.com/sagernet/sing/common/metadata" "github.com/google/uuid" "golang.org/x/net/http2" @@ -71,8 +72,7 @@ func NewHTTP2Connection( ServerName: h2EdgeSNI, } - dialer := &net.Dialer{} - tcpConn, err := dialer.DialContext(ctx, "tcp", edgeAddr.TCP.String()) + tcpConn, err := inbound.controlDialer.DialContext(ctx, "tcp", M.SocksaddrFrom(edgeAddr.TCP.AddrPort().Addr(), edgeAddr.TCP.AddrPort().Port())) if err != nil { return nil, E.Cause(err, "dial edge TCP") } @@ -113,10 +113,13 @@ func (c *HTTP2Connection) Serve(ctx context.Context) error { Handler: c, }) - if c.registrationResult != nil { - return nil + if ctx.Err() != nil { + return ctx.Err() } - return E.New("edge connection closed before registration") + if c.registrationResult == nil { + return E.New("edge connection closed before registration") + } + return E.New("edge connection closed") } func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -167,6 +170,12 @@ func (c *HTTP2Connection) handleControlStream(ctx context.Context, r *http.Reque " (connection ", result.ConnectionID, ")") <-ctx.Done() + unregisterCtx, cancel := context.WithTimeout(context.Background(), c.gracePeriod) + defer cancel() + err = c.registrationClient.Unregister(unregisterCtx) + if err != nil { + c.logger.Debug("failed to unregister: ", err) + } c.registrationClient.Close() } diff --git a/protocol/cloudflare/connection_quic.go b/protocol/cloudflare/connection_quic.go index f06ad714f..fd2e56a3b 100644 --- a/protocol/cloudflare/connection_quic.go +++ b/protocol/cloudflare/connection_quic.go @@ -8,13 +8,14 @@ import ( "fmt" "io" "net" - "runtime" "sync" "time" "github.com/sagernet/quic-go" "github.com/sagernet/sing-box/log" E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" "github.com/google/uuid" ) @@ -88,6 +89,7 @@ func NewQUICConnection( features []string, numPreviousAttempts uint8, gracePeriod time.Duration, + controlDialer N.Dialer, logger log.ContextLogger, ) (*QUICConnection, error) { rootCAs, err := cloudflareRootCertPool() @@ -111,7 +113,7 @@ func NewQUICConnection( InitialPacketSize: quicInitialPacketSize(edgeAddr.IPVersion), } - udpConn, err := createUDPConnForConnIndex(connIndex, edgeAddr) + udpConn, err := createUDPConnForConnIndex(ctx, connIndex, edgeAddr, controlDialer) if err != nil { return nil, E.Cause(err, "listen UDP for QUIC edge") } @@ -135,30 +137,19 @@ func NewQUICConnection( }, nil } -func createUDPConnForConnIndex(connIndex uint8, edgeAddr *EdgeAddr) (*net.UDPConn, error) { +func createUDPConnForConnIndex(ctx context.Context, connIndex uint8, edgeAddr *EdgeAddr, controlDialer N.Dialer) (*net.UDPConn, error) { quicPortAccess.Lock() defer quicPortAccess.Unlock() - network := "udp" - if runtime.GOOS == "darwin" { - if edgeAddr.IPVersion == 4 { - network = "udp4" - } else { - network = "udp6" - } - } - - if port, loaded := quicPortByConnIndex[connIndex]; loaded { - udpConn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) - if err == nil { - return udpConn, nil - } - } - - udpConn, err := net.ListenUDP(network, &net.UDPAddr{Port: 0}) + packetConn, err := controlDialer.ListenPacket(ctx, M.SocksaddrFrom(edgeAddr.UDP.AddrPort().Addr(), edgeAddr.UDP.AddrPort().Port())) if err != nil { return nil, err } + udpConn, ok := packetConn.(*net.UDPConn) + if !ok { + packetConn.Close() + return nil, fmt.Errorf("unexpected packet conn type %T", packetConn) + } udpAddr, ok := udpConn.LocalAddr().(*net.UDPAddr) if !ok { udpConn.Close() diff --git a/protocol/cloudflare/edge_discovery.go b/protocol/cloudflare/edge_discovery.go index 0c08bcbf8..922063ce4 100644 --- a/protocol/cloudflare/edge_discovery.go +++ b/protocol/cloudflare/edge_discovery.go @@ -9,6 +9,8 @@ import ( "time" E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" ) const ( @@ -37,10 +39,10 @@ type EdgeAddr struct { // DiscoverEdge performs SRV-based edge discovery and returns addresses // partitioned into regions (typically 2). -func DiscoverEdge(ctx context.Context, region string) ([][]*EdgeAddr, error) { +func DiscoverEdge(ctx context.Context, region string, controlDialer N.Dialer) ([][]*EdgeAddr, error) { regions, err := lookupEdgeSRV(region) if err != nil { - regions, err = lookupEdgeSRVWithDoT(ctx, region) + regions, err = lookupEdgeSRVWithDoT(ctx, region, controlDialer) if err != nil { return nil, E.Cause(err, "edge discovery") } @@ -59,12 +61,11 @@ func lookupEdgeSRV(region string) ([][]*EdgeAddr, error) { return resolveSRVRecords(addrs) } -func lookupEdgeSRVWithDoT(ctx context.Context, region string) ([][]*EdgeAddr, error) { +func lookupEdgeSRVWithDoT(ctx context.Context, region string, controlDialer N.Dialer) ([][]*EdgeAddr, error) { resolver := &net.Resolver{ PreferGo: true, Dial: func(ctx context.Context, _, _ string) (net.Conn, error) { - var dialer net.Dialer - conn, err := dialer.DialContext(ctx, "tcp", dotServerAddr) + conn, err := controlDialer.DialContext(ctx, "tcp", M.ParseSocksaddr(dotServerAddr)) if err != nil { return nil, err } diff --git a/protocol/cloudflare/edge_discovery_test.go b/protocol/cloudflare/edge_discovery_test.go index c282009d0..930fd46be 100644 --- a/protocol/cloudflare/edge_discovery_test.go +++ b/protocol/cloudflare/edge_discovery_test.go @@ -6,10 +6,12 @@ import ( "context" "net" "testing" + + N "github.com/sagernet/sing/common/network" ) func TestDiscoverEdge(t *testing.T) { - regions, err := DiscoverEdge(context.Background(), "") + regions, err := DiscoverEdge(context.Background(), "", N.SystemDialer) if err != nil { t.Fatal("DiscoverEdge: ", err) } diff --git a/protocol/cloudflare/helpers_test.go b/protocol/cloudflare/helpers_test.go index f06a5fa2a..d873a40b1 100644 --- a/protocol/cloudflare/helpers_test.go +++ b/protocol/cloudflare/helpers_test.go @@ -192,6 +192,8 @@ func newTestInbound(t *testing.T, token string, protocol string, haConnections i configManager: configManager, datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer), datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer), + controlDialer: N.SystemDialer, + accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: N.SystemDialer}, } t.Cleanup(func() { diff --git a/protocol/cloudflare/inbound.go b/protocol/cloudflare/inbound.go index ae48ba40e..5e30b89d1 100644 --- a/protocol/cloudflare/inbound.go +++ b/protocol/cloudflare/inbound.go @@ -16,11 +16,13 @@ import ( "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter/inbound" + boxDialer "github.com/sagernet/sing-box/common/dialer" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/json" + N "github.com/sagernet/sing/common/network" "github.com/google/uuid" ) @@ -46,6 +48,7 @@ type Inbound struct { configManager *ConfigManager flowLimiter *FlowLimiter accessCache *accessValidatorCache + controlDialer N.Dialer connectionAccess sync.Mutex connections []io.Closer @@ -95,6 +98,14 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo if err != nil { return nil, E.Cause(err, "build cloudflare tunnel runtime config") } + controlDialer, err := boxDialer.NewWithOptions(boxDialer.Options{ + Context: ctx, + Options: options.ControlDialer, + RemoteIsDomain: true, + }) + if err != nil { + return nil, E.Cause(err, "build cloudflare tunnel control dialer") + } region := options.Region if region != "" && credentials.Endpoint != "" { @@ -122,7 +133,8 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo gracePeriod: gracePeriod, configManager: configManager, flowLimiter: &FlowLimiter{}, - accessCache: &accessValidatorCache{values: make(map[string]accessValidator)}, + accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: controlDialer}, + controlDialer: controlDialer, datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer), datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer), }, nil @@ -135,7 +147,7 @@ func (i *Inbound) Start(stage adapter.StartStage) error { i.logger.Info("starting Cloudflare Tunnel with ", i.haConnections, " HA connections") - regions, err := DiscoverEdge(i.ctx, i.region) + regions, err := DiscoverEdge(i.ctx, i.region, i.controlDialer) if err != nil { return E.Cause(err, "discover edge") } @@ -287,7 +299,7 @@ func (i *Inbound) serveQUIC(connIndex uint8, edgeAddr *EdgeAddr, features []stri connection, err := NewQUICConnection( i.ctx, edgeAddr, connIndex, i.credentials, i.connectorID, - features, numPreviousAttempts, i.gracePeriod, i.logger, + features, numPreviousAttempts, i.gracePeriod, i.controlDialer, i.logger, ) if err != nil { return E.Cause(err, "create QUIC connection")