mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-14 12:48:28 +10:00
Route cloudflare control plane through configurable dialer
This commit is contained in:
@@ -5,19 +5,29 @@ package cloudflare
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
const accessJWTAssertionHeader = "Cf-Access-Jwt-Assertion"
|
||||
|
||||
var newAccessValidator = func(access AccessConfig) (accessValidator, error) {
|
||||
var newAccessValidator = func(access AccessConfig, dialer N.Dialer) (accessValidator, error) {
|
||||
issuerURL := accessIssuerURL(access.TeamName, access.Environment)
|
||||
keySet := oidc.NewRemoteKeySet(context.Background(), issuerURL+"/cdn-cgi/access/certs")
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return dialer.DialContext(ctx, network, M.ParseSocksaddr(address))
|
||||
},
|
||||
},
|
||||
}
|
||||
keySet := oidc.NewRemoteKeySet(oidc.ClientContext(context.Background(), client), issuerURL+"/cdn-cgi/access/certs")
|
||||
verifier := oidc.NewVerifier(issuerURL, keySet, &oidc.Config{
|
||||
SkipClientIDCheck: true,
|
||||
})
|
||||
@@ -82,6 +92,7 @@ func accessValidatorKey(access AccessConfig) string {
|
||||
type accessValidatorCache struct {
|
||||
access sync.RWMutex
|
||||
values map[string]accessValidator
|
||||
dialer N.Dialer
|
||||
}
|
||||
|
||||
func (c *accessValidatorCache) Get(accessConfig AccessConfig) (accessValidator, error) {
|
||||
@@ -93,7 +104,7 @@ func (c *accessValidatorCache) Get(accessConfig AccessConfig) (accessValidator,
|
||||
return validator, nil
|
||||
}
|
||||
|
||||
validator, err := newAccessValidator(accessConfig)
|
||||
validator, err := newAccessValidator(accessConfig, c.dialer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type fakeAccessValidator struct {
|
||||
@@ -31,10 +32,11 @@ func newAccessTestInbound(t *testing.T) *Inbound {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return &Inbound{
|
||||
Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"),
|
||||
logger: logFactory.NewLogger("test"),
|
||||
accessCache: &accessValidatorCache{values: make(map[string]accessValidator)},
|
||||
router: &testRouter{},
|
||||
Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"),
|
||||
logger: logFactory.NewLogger("test"),
|
||||
accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: N.SystemDialer},
|
||||
router: &testRouter{},
|
||||
controlDialer: N.SystemDialer,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,7 +55,7 @@ func TestRoundTripHTTPAccessDenied(t *testing.T) {
|
||||
defer func() {
|
||||
newAccessValidator = originalFactory
|
||||
}()
|
||||
newAccessValidator = func(access AccessConfig) (accessValidator, error) {
|
||||
newAccessValidator = func(access AccessConfig, dialer N.Dialer) (accessValidator, error) {
|
||||
return &fakeAccessValidator{err: E.New("forbidden")}, nil
|
||||
}
|
||||
|
||||
@@ -87,7 +89,7 @@ func TestHandleHTTPServiceStatusAccessDenied(t *testing.T) {
|
||||
defer func() {
|
||||
newAccessValidator = originalFactory
|
||||
}()
|
||||
newAccessValidator = func(access AccessConfig) (accessValidator, error) {
|
||||
newAccessValidator = func(access AccessConfig, dialer N.Dialer) (accessValidator, error) {
|
||||
return &fakeAccessValidator{err: E.New("forbidden")}, nil
|
||||
}
|
||||
|
||||
@@ -121,7 +123,7 @@ func TestHandleHTTPServiceStreamAccessDenied(t *testing.T) {
|
||||
defer func() {
|
||||
newAccessValidator = originalFactory
|
||||
}()
|
||||
newAccessValidator = func(access AccessConfig) (accessValidator, error) {
|
||||
newAccessValidator = func(access AccessConfig, dialer N.Dialer) (accessValidator, error) {
|
||||
return &fakeAccessValidator{err: E.New("forbidden")}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/sagernet/sing-box/log"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/net/http2"
|
||||
@@ -71,8 +72,7 @@ func NewHTTP2Connection(
|
||||
ServerName: h2EdgeSNI,
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{}
|
||||
tcpConn, err := dialer.DialContext(ctx, "tcp", edgeAddr.TCP.String())
|
||||
tcpConn, err := inbound.controlDialer.DialContext(ctx, "tcp", M.SocksaddrFrom(edgeAddr.TCP.AddrPort().Addr(), edgeAddr.TCP.AddrPort().Port()))
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "dial edge TCP")
|
||||
}
|
||||
@@ -113,10 +113,13 @@ func (c *HTTP2Connection) Serve(ctx context.Context) error {
|
||||
Handler: c,
|
||||
})
|
||||
|
||||
if c.registrationResult != nil {
|
||||
return nil
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
return E.New("edge connection closed before registration")
|
||||
if c.registrationResult == nil {
|
||||
return E.New("edge connection closed before registration")
|
||||
}
|
||||
return E.New("edge connection closed")
|
||||
}
|
||||
|
||||
func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -167,6 +170,12 @@ func (c *HTTP2Connection) handleControlStream(ctx context.Context, r *http.Reque
|
||||
" (connection ", result.ConnectionID, ")")
|
||||
|
||||
<-ctx.Done()
|
||||
unregisterCtx, cancel := context.WithTimeout(context.Background(), c.gracePeriod)
|
||||
defer cancel()
|
||||
err = c.registrationClient.Unregister(unregisterCtx)
|
||||
if err != nil {
|
||||
c.logger.Debug("failed to unregister: ", err)
|
||||
}
|
||||
c.registrationClient.Close()
|
||||
}
|
||||
|
||||
|
||||
@@ -8,13 +8,14 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/quic-go"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -88,6 +89,7 @@ func NewQUICConnection(
|
||||
features []string,
|
||||
numPreviousAttempts uint8,
|
||||
gracePeriod time.Duration,
|
||||
controlDialer N.Dialer,
|
||||
logger log.ContextLogger,
|
||||
) (*QUICConnection, error) {
|
||||
rootCAs, err := cloudflareRootCertPool()
|
||||
@@ -111,7 +113,7 @@ func NewQUICConnection(
|
||||
InitialPacketSize: quicInitialPacketSize(edgeAddr.IPVersion),
|
||||
}
|
||||
|
||||
udpConn, err := createUDPConnForConnIndex(connIndex, edgeAddr)
|
||||
udpConn, err := createUDPConnForConnIndex(ctx, connIndex, edgeAddr, controlDialer)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "listen UDP for QUIC edge")
|
||||
}
|
||||
@@ -135,30 +137,19 @@ func NewQUICConnection(
|
||||
}, nil
|
||||
}
|
||||
|
||||
func createUDPConnForConnIndex(connIndex uint8, edgeAddr *EdgeAddr) (*net.UDPConn, error) {
|
||||
func createUDPConnForConnIndex(ctx context.Context, connIndex uint8, edgeAddr *EdgeAddr, controlDialer N.Dialer) (*net.UDPConn, error) {
|
||||
quicPortAccess.Lock()
|
||||
defer quicPortAccess.Unlock()
|
||||
|
||||
network := "udp"
|
||||
if runtime.GOOS == "darwin" {
|
||||
if edgeAddr.IPVersion == 4 {
|
||||
network = "udp4"
|
||||
} else {
|
||||
network = "udp6"
|
||||
}
|
||||
}
|
||||
|
||||
if port, loaded := quicPortByConnIndex[connIndex]; loaded {
|
||||
udpConn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
|
||||
if err == nil {
|
||||
return udpConn, nil
|
||||
}
|
||||
}
|
||||
|
||||
udpConn, err := net.ListenUDP(network, &net.UDPAddr{Port: 0})
|
||||
packetConn, err := controlDialer.ListenPacket(ctx, M.SocksaddrFrom(edgeAddr.UDP.AddrPort().Addr(), edgeAddr.UDP.AddrPort().Port()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
udpConn, ok := packetConn.(*net.UDPConn)
|
||||
if !ok {
|
||||
packetConn.Close()
|
||||
return nil, fmt.Errorf("unexpected packet conn type %T", packetConn)
|
||||
}
|
||||
udpAddr, ok := udpConn.LocalAddr().(*net.UDPAddr)
|
||||
if !ok {
|
||||
udpConn.Close()
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"time"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -37,10 +39,10 @@ type EdgeAddr struct {
|
||||
|
||||
// DiscoverEdge performs SRV-based edge discovery and returns addresses
|
||||
// partitioned into regions (typically 2).
|
||||
func DiscoverEdge(ctx context.Context, region string) ([][]*EdgeAddr, error) {
|
||||
func DiscoverEdge(ctx context.Context, region string, controlDialer N.Dialer) ([][]*EdgeAddr, error) {
|
||||
regions, err := lookupEdgeSRV(region)
|
||||
if err != nil {
|
||||
regions, err = lookupEdgeSRVWithDoT(ctx, region)
|
||||
regions, err = lookupEdgeSRVWithDoT(ctx, region, controlDialer)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "edge discovery")
|
||||
}
|
||||
@@ -59,12 +61,11 @@ func lookupEdgeSRV(region string) ([][]*EdgeAddr, error) {
|
||||
return resolveSRVRecords(addrs)
|
||||
}
|
||||
|
||||
func lookupEdgeSRVWithDoT(ctx context.Context, region string) ([][]*EdgeAddr, error) {
|
||||
func lookupEdgeSRVWithDoT(ctx context.Context, region string, controlDialer N.Dialer) ([][]*EdgeAddr, error) {
|
||||
resolver := &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
var dialer net.Dialer
|
||||
conn, err := dialer.DialContext(ctx, "tcp", dotServerAddr)
|
||||
conn, err := controlDialer.DialContext(ctx, "tcp", M.ParseSocksaddr(dotServerAddr))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -6,10 +6,12 @@ import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func TestDiscoverEdge(t *testing.T) {
|
||||
regions, err := DiscoverEdge(context.Background(), "")
|
||||
regions, err := DiscoverEdge(context.Background(), "", N.SystemDialer)
|
||||
if err != nil {
|
||||
t.Fatal("DiscoverEdge: ", err)
|
||||
}
|
||||
|
||||
@@ -192,6 +192,8 @@ func newTestInbound(t *testing.T, token string, protocol string, haConnections i
|
||||
configManager: configManager,
|
||||
datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer),
|
||||
datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer),
|
||||
controlDialer: N.SystemDialer,
|
||||
accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: N.SystemDialer},
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
|
||||
@@ -16,11 +16,13 @@ import (
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/inbound"
|
||||
boxDialer "github.com/sagernet/sing-box/common/dialer"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -46,6 +48,7 @@ type Inbound struct {
|
||||
configManager *ConfigManager
|
||||
flowLimiter *FlowLimiter
|
||||
accessCache *accessValidatorCache
|
||||
controlDialer N.Dialer
|
||||
|
||||
connectionAccess sync.Mutex
|
||||
connections []io.Closer
|
||||
@@ -95,6 +98,14 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "build cloudflare tunnel runtime config")
|
||||
}
|
||||
controlDialer, err := boxDialer.NewWithOptions(boxDialer.Options{
|
||||
Context: ctx,
|
||||
Options: options.ControlDialer,
|
||||
RemoteIsDomain: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "build cloudflare tunnel control dialer")
|
||||
}
|
||||
|
||||
region := options.Region
|
||||
if region != "" && credentials.Endpoint != "" {
|
||||
@@ -122,7 +133,8 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
||||
gracePeriod: gracePeriod,
|
||||
configManager: configManager,
|
||||
flowLimiter: &FlowLimiter{},
|
||||
accessCache: &accessValidatorCache{values: make(map[string]accessValidator)},
|
||||
accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: controlDialer},
|
||||
controlDialer: controlDialer,
|
||||
datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer),
|
||||
datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer),
|
||||
}, nil
|
||||
@@ -135,7 +147,7 @@ func (i *Inbound) Start(stage adapter.StartStage) error {
|
||||
|
||||
i.logger.Info("starting Cloudflare Tunnel with ", i.haConnections, " HA connections")
|
||||
|
||||
regions, err := DiscoverEdge(i.ctx, i.region)
|
||||
regions, err := DiscoverEdge(i.ctx, i.region, i.controlDialer)
|
||||
if err != nil {
|
||||
return E.Cause(err, "discover edge")
|
||||
}
|
||||
@@ -287,7 +299,7 @@ func (i *Inbound) serveQUIC(connIndex uint8, edgeAddr *EdgeAddr, features []stri
|
||||
connection, err := NewQUICConnection(
|
||||
i.ctx, edgeAddr, connIndex,
|
||||
i.credentials, i.connectorID,
|
||||
features, numPreviousAttempts, i.gracePeriod, i.logger,
|
||||
features, numPreviousAttempts, i.gracePeriod, i.controlDialer, i.logger,
|
||||
)
|
||||
if err != nil {
|
||||
return E.Cause(err, "create QUIC connection")
|
||||
|
||||
Reference in New Issue
Block a user