Route cloudflare control plane through configurable dialer

This commit is contained in:
世界
2026-03-24 13:52:55 +08:00
parent d017cbe008
commit 2321e941e0
9 changed files with 75 additions and 44 deletions

View File

@@ -5,19 +5,29 @@ package cloudflare
import (
"context"
"fmt"
"net"
"net/http"
"strings"
"sync"
"github.com/coreos/go-oidc/v3/oidc"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
const accessJWTAssertionHeader = "Cf-Access-Jwt-Assertion"
var newAccessValidator = func(access AccessConfig) (accessValidator, error) {
var newAccessValidator = func(access AccessConfig, dialer N.Dialer) (accessValidator, error) {
issuerURL := accessIssuerURL(access.TeamName, access.Environment)
keySet := oidc.NewRemoteKeySet(context.Background(), issuerURL+"/cdn-cgi/access/certs")
client := &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return dialer.DialContext(ctx, network, M.ParseSocksaddr(address))
},
},
}
keySet := oidc.NewRemoteKeySet(oidc.ClientContext(context.Background(), client), issuerURL+"/cdn-cgi/access/certs")
verifier := oidc.NewVerifier(issuerURL, keySet, &oidc.Config{
SkipClientIDCheck: true,
})
@@ -82,6 +92,7 @@ func accessValidatorKey(access AccessConfig) string {
type accessValidatorCache struct {
access sync.RWMutex
values map[string]accessValidator
dialer N.Dialer
}
func (c *accessValidatorCache) Get(accessConfig AccessConfig) (accessValidator, error) {
@@ -93,7 +104,7 @@ func (c *accessValidatorCache) Get(accessConfig AccessConfig) (accessValidator,
return validator, nil
}
validator, err := newAccessValidator(accessConfig)
validator, err := newAccessValidator(accessConfig, c.dialer)
if err != nil {
return nil, err
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type fakeAccessValidator struct {
@@ -31,10 +32,11 @@ func newAccessTestInbound(t *testing.T) *Inbound {
t.Fatal(err)
}
return &Inbound{
Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"),
logger: logFactory.NewLogger("test"),
accessCache: &accessValidatorCache{values: make(map[string]accessValidator)},
router: &testRouter{},
Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"),
logger: logFactory.NewLogger("test"),
accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: N.SystemDialer},
router: &testRouter{},
controlDialer: N.SystemDialer,
}
}
@@ -53,7 +55,7 @@ func TestRoundTripHTTPAccessDenied(t *testing.T) {
defer func() {
newAccessValidator = originalFactory
}()
newAccessValidator = func(access AccessConfig) (accessValidator, error) {
newAccessValidator = func(access AccessConfig, dialer N.Dialer) (accessValidator, error) {
return &fakeAccessValidator{err: E.New("forbidden")}, nil
}
@@ -87,7 +89,7 @@ func TestHandleHTTPServiceStatusAccessDenied(t *testing.T) {
defer func() {
newAccessValidator = originalFactory
}()
newAccessValidator = func(access AccessConfig) (accessValidator, error) {
newAccessValidator = func(access AccessConfig, dialer N.Dialer) (accessValidator, error) {
return &fakeAccessValidator{err: E.New("forbidden")}, nil
}
@@ -121,7 +123,7 @@ func TestHandleHTTPServiceStreamAccessDenied(t *testing.T) {
defer func() {
newAccessValidator = originalFactory
}()
newAccessValidator = func(access AccessConfig) (accessValidator, error) {
newAccessValidator = func(access AccessConfig, dialer N.Dialer) (accessValidator, error) {
return &fakeAccessValidator{err: E.New("forbidden")}, nil
}

View File

@@ -17,6 +17,7 @@ import (
"github.com/sagernet/sing-box/log"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
M "github.com/sagernet/sing/common/metadata"
"github.com/google/uuid"
"golang.org/x/net/http2"
@@ -71,8 +72,7 @@ func NewHTTP2Connection(
ServerName: h2EdgeSNI,
}
dialer := &net.Dialer{}
tcpConn, err := dialer.DialContext(ctx, "tcp", edgeAddr.TCP.String())
tcpConn, err := inbound.controlDialer.DialContext(ctx, "tcp", M.SocksaddrFrom(edgeAddr.TCP.AddrPort().Addr(), edgeAddr.TCP.AddrPort().Port()))
if err != nil {
return nil, E.Cause(err, "dial edge TCP")
}
@@ -113,10 +113,13 @@ func (c *HTTP2Connection) Serve(ctx context.Context) error {
Handler: c,
})
if c.registrationResult != nil {
return nil
if ctx.Err() != nil {
return ctx.Err()
}
return E.New("edge connection closed before registration")
if c.registrationResult == nil {
return E.New("edge connection closed before registration")
}
return E.New("edge connection closed")
}
func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@@ -167,6 +170,12 @@ func (c *HTTP2Connection) handleControlStream(ctx context.Context, r *http.Reque
" (connection ", result.ConnectionID, ")")
<-ctx.Done()
unregisterCtx, cancel := context.WithTimeout(context.Background(), c.gracePeriod)
defer cancel()
err = c.registrationClient.Unregister(unregisterCtx)
if err != nil {
c.logger.Debug("failed to unregister: ", err)
}
c.registrationClient.Close()
}

View File

@@ -8,13 +8,14 @@ import (
"fmt"
"io"
"net"
"runtime"
"sync"
"time"
"github.com/sagernet/quic-go"
"github.com/sagernet/sing-box/log"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/google/uuid"
)
@@ -88,6 +89,7 @@ func NewQUICConnection(
features []string,
numPreviousAttempts uint8,
gracePeriod time.Duration,
controlDialer N.Dialer,
logger log.ContextLogger,
) (*QUICConnection, error) {
rootCAs, err := cloudflareRootCertPool()
@@ -111,7 +113,7 @@ func NewQUICConnection(
InitialPacketSize: quicInitialPacketSize(edgeAddr.IPVersion),
}
udpConn, err := createUDPConnForConnIndex(connIndex, edgeAddr)
udpConn, err := createUDPConnForConnIndex(ctx, connIndex, edgeAddr, controlDialer)
if err != nil {
return nil, E.Cause(err, "listen UDP for QUIC edge")
}
@@ -135,30 +137,19 @@ func NewQUICConnection(
}, nil
}
func createUDPConnForConnIndex(connIndex uint8, edgeAddr *EdgeAddr) (*net.UDPConn, error) {
func createUDPConnForConnIndex(ctx context.Context, connIndex uint8, edgeAddr *EdgeAddr, controlDialer N.Dialer) (*net.UDPConn, error) {
quicPortAccess.Lock()
defer quicPortAccess.Unlock()
network := "udp"
if runtime.GOOS == "darwin" {
if edgeAddr.IPVersion == 4 {
network = "udp4"
} else {
network = "udp6"
}
}
if port, loaded := quicPortByConnIndex[connIndex]; loaded {
udpConn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
if err == nil {
return udpConn, nil
}
}
udpConn, err := net.ListenUDP(network, &net.UDPAddr{Port: 0})
packetConn, err := controlDialer.ListenPacket(ctx, M.SocksaddrFrom(edgeAddr.UDP.AddrPort().Addr(), edgeAddr.UDP.AddrPort().Port()))
if err != nil {
return nil, err
}
udpConn, ok := packetConn.(*net.UDPConn)
if !ok {
packetConn.Close()
return nil, fmt.Errorf("unexpected packet conn type %T", packetConn)
}
udpAddr, ok := udpConn.LocalAddr().(*net.UDPAddr)
if !ok {
udpConn.Close()

View File

@@ -9,6 +9,8 @@ import (
"time"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
const (
@@ -37,10 +39,10 @@ type EdgeAddr struct {
// DiscoverEdge performs SRV-based edge discovery and returns addresses
// partitioned into regions (typically 2).
func DiscoverEdge(ctx context.Context, region string) ([][]*EdgeAddr, error) {
func DiscoverEdge(ctx context.Context, region string, controlDialer N.Dialer) ([][]*EdgeAddr, error) {
regions, err := lookupEdgeSRV(region)
if err != nil {
regions, err = lookupEdgeSRVWithDoT(ctx, region)
regions, err = lookupEdgeSRVWithDoT(ctx, region, controlDialer)
if err != nil {
return nil, E.Cause(err, "edge discovery")
}
@@ -59,12 +61,11 @@ func lookupEdgeSRV(region string) ([][]*EdgeAddr, error) {
return resolveSRVRecords(addrs)
}
func lookupEdgeSRVWithDoT(ctx context.Context, region string) ([][]*EdgeAddr, error) {
func lookupEdgeSRVWithDoT(ctx context.Context, region string, controlDialer N.Dialer) ([][]*EdgeAddr, error) {
resolver := &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
var dialer net.Dialer
conn, err := dialer.DialContext(ctx, "tcp", dotServerAddr)
conn, err := controlDialer.DialContext(ctx, "tcp", M.ParseSocksaddr(dotServerAddr))
if err != nil {
return nil, err
}

View File

@@ -6,10 +6,12 @@ import (
"context"
"net"
"testing"
N "github.com/sagernet/sing/common/network"
)
func TestDiscoverEdge(t *testing.T) {
regions, err := DiscoverEdge(context.Background(), "")
regions, err := DiscoverEdge(context.Background(), "", N.SystemDialer)
if err != nil {
t.Fatal("DiscoverEdge: ", err)
}

View File

@@ -192,6 +192,8 @@ func newTestInbound(t *testing.T, token string, protocol string, haConnections i
configManager: configManager,
datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer),
datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer),
controlDialer: N.SystemDialer,
accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: N.SystemDialer},
}
t.Cleanup(func() {

View File

@@ -16,11 +16,13 @@ import (
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound"
boxDialer "github.com/sagernet/sing-box/common/dialer"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
N "github.com/sagernet/sing/common/network"
"github.com/google/uuid"
)
@@ -46,6 +48,7 @@ type Inbound struct {
configManager *ConfigManager
flowLimiter *FlowLimiter
accessCache *accessValidatorCache
controlDialer N.Dialer
connectionAccess sync.Mutex
connections []io.Closer
@@ -95,6 +98,14 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
if err != nil {
return nil, E.Cause(err, "build cloudflare tunnel runtime config")
}
controlDialer, err := boxDialer.NewWithOptions(boxDialer.Options{
Context: ctx,
Options: options.ControlDialer,
RemoteIsDomain: true,
})
if err != nil {
return nil, E.Cause(err, "build cloudflare tunnel control dialer")
}
region := options.Region
if region != "" && credentials.Endpoint != "" {
@@ -122,7 +133,8 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
gracePeriod: gracePeriod,
configManager: configManager,
flowLimiter: &FlowLimiter{},
accessCache: &accessValidatorCache{values: make(map[string]accessValidator)},
accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: controlDialer},
controlDialer: controlDialer,
datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer),
datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer),
}, nil
@@ -135,7 +147,7 @@ func (i *Inbound) Start(stage adapter.StartStage) error {
i.logger.Info("starting Cloudflare Tunnel with ", i.haConnections, " HA connections")
regions, err := DiscoverEdge(i.ctx, i.region)
regions, err := DiscoverEdge(i.ctx, i.region, i.controlDialer)
if err != nil {
return E.Cause(err, "discover edge")
}
@@ -287,7 +299,7 @@ func (i *Inbound) serveQUIC(connIndex uint8, edgeAddr *EdgeAddr, features []stri
connection, err := NewQUICConnection(
i.ctx, edgeAddr, connIndex,
i.credentials, i.connectorID,
features, numPreviousAttempts, i.gracePeriod, i.logger,
features, numPreviousAttempts, i.gracePeriod, i.controlDialer, i.logger,
)
if err != nil {
return E.Cause(err, "create QUIC connection")