mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-14 04:38:28 +10:00
Support regional cloudflare edge selection
This commit is contained in:
@@ -21,6 +21,13 @@ const (
|
||||
dotTimeout = 15 * time.Second
|
||||
)
|
||||
|
||||
func getRegionalServiceName(region string) string {
|
||||
if region == "" {
|
||||
return edgeSRVService
|
||||
}
|
||||
return region + "-" + edgeSRVService
|
||||
}
|
||||
|
||||
// EdgeAddr represents a Cloudflare edge server address.
|
||||
type EdgeAddr struct {
|
||||
TCP *net.TCPAddr
|
||||
@@ -30,10 +37,10 @@ type EdgeAddr struct {
|
||||
|
||||
// DiscoverEdge performs SRV-based edge discovery and returns addresses
|
||||
// partitioned into regions (typically 2).
|
||||
func DiscoverEdge(ctx context.Context) ([][]*EdgeAddr, error) {
|
||||
regions, err := lookupEdgeSRV()
|
||||
func DiscoverEdge(ctx context.Context, region string) ([][]*EdgeAddr, error) {
|
||||
regions, err := lookupEdgeSRV(region)
|
||||
if err != nil {
|
||||
regions, err = lookupEdgeSRVWithDoT(ctx)
|
||||
regions, err = lookupEdgeSRVWithDoT(ctx, region)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "edge discovery")
|
||||
}
|
||||
@@ -44,15 +51,15 @@ func DiscoverEdge(ctx context.Context) ([][]*EdgeAddr, error) {
|
||||
return regions, nil
|
||||
}
|
||||
|
||||
func lookupEdgeSRV() ([][]*EdgeAddr, error) {
|
||||
_, addrs, err := net.LookupSRV(edgeSRVService, edgeSRVProto, edgeSRVName)
|
||||
func lookupEdgeSRV(region string) ([][]*EdgeAddr, error) {
|
||||
_, addrs, err := net.LookupSRV(getRegionalServiceName(region), edgeSRVProto, edgeSRVName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resolveSRVRecords(addrs)
|
||||
}
|
||||
|
||||
func lookupEdgeSRVWithDoT(ctx context.Context) ([][]*EdgeAddr, error) {
|
||||
func lookupEdgeSRVWithDoT(ctx context.Context, region string) ([][]*EdgeAddr, error) {
|
||||
resolver := &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
@@ -66,7 +73,7 @@ func lookupEdgeSRVWithDoT(ctx context.Context) ([][]*EdgeAddr, error) {
|
||||
}
|
||||
lookupCtx, cancel := context.WithTimeout(ctx, dotTimeout)
|
||||
defer cancel()
|
||||
_, addrs, err := resolver.LookupSRV(lookupCtx, edgeSRVService, edgeSRVProto, edgeSRVName)
|
||||
_, addrs, err := resolver.LookupSRV(lookupCtx, getRegionalServiceName(region), edgeSRVProto, edgeSRVName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
func TestDiscoverEdge(t *testing.T) {
|
||||
regions, err := DiscoverEdge(context.Background())
|
||||
regions, err := DiscoverEdge(context.Background(), "")
|
||||
if err != nil {
|
||||
t.Fatal("DiscoverEdge: ", err)
|
||||
}
|
||||
@@ -86,3 +86,12 @@ func TestFilterByIPVersion(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetRegionalServiceName(t *testing.T) {
|
||||
if got := getRegionalServiceName(""); got != edgeSRVService {
|
||||
t.Fatalf("expected global service %s, got %s", edgeSRVService, got)
|
||||
}
|
||||
if got := getRegionalServiceName("us"); got != "us-"+edgeSRVService {
|
||||
t.Fatalf("expected regional service us-%s, got %s", edgeSRVService, got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,6 +94,14 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
||||
return nil, E.Cause(err, "build cloudflare tunnel runtime config")
|
||||
}
|
||||
|
||||
region := options.Region
|
||||
if region != "" && credentials.Endpoint != "" {
|
||||
return nil, E.New("region cannot be specified when credentials already include an endpoint")
|
||||
}
|
||||
if region == "" {
|
||||
region = credentials.Endpoint
|
||||
}
|
||||
|
||||
inboundCtx, cancel := context.WithCancel(ctx)
|
||||
|
||||
return &Inbound{
|
||||
@@ -106,7 +114,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
||||
connectorID: uuid.New(),
|
||||
haConnections: haConnections,
|
||||
protocol: protocol,
|
||||
region: options.Region,
|
||||
region: region,
|
||||
edgeIPVersion: edgeIPVersion,
|
||||
datagramVersion: datagramVersion,
|
||||
gracePeriod: gracePeriod,
|
||||
@@ -123,7 +131,7 @@ func (i *Inbound) Start(stage adapter.StartStage) error {
|
||||
|
||||
i.logger.Info("starting Cloudflare Tunnel with ", i.haConnections, " HA connections")
|
||||
|
||||
regions, err := DiscoverEdge(i.ctx)
|
||||
regions, err := DiscoverEdge(i.ctx, i.region)
|
||||
if err != nil {
|
||||
return E.Cause(err, "discover edge")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user