From 124379fc1d3c57fb6138e7f98cebddc488130bfc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 11:18:43 +0800 Subject: [PATCH] Support regional cloudflare edge selection --- protocol/cloudflare/edge_discovery.go | 21 ++++++++++++++------- protocol/cloudflare/edge_discovery_test.go | 11 ++++++++++- protocol/cloudflare/inbound.go | 12 ++++++++++-- 3 files changed, 34 insertions(+), 10 deletions(-) diff --git a/protocol/cloudflare/edge_discovery.go b/protocol/cloudflare/edge_discovery.go index 6ca9403d0..0c08bcbf8 100644 --- a/protocol/cloudflare/edge_discovery.go +++ b/protocol/cloudflare/edge_discovery.go @@ -21,6 +21,13 @@ const ( dotTimeout = 15 * time.Second ) +func getRegionalServiceName(region string) string { + if region == "" { + return edgeSRVService + } + return region + "-" + edgeSRVService +} + // EdgeAddr represents a Cloudflare edge server address. type EdgeAddr struct { TCP *net.TCPAddr @@ -30,10 +37,10 @@ type EdgeAddr struct { // DiscoverEdge performs SRV-based edge discovery and returns addresses // partitioned into regions (typically 2). -func DiscoverEdge(ctx context.Context) ([][]*EdgeAddr, error) { - regions, err := lookupEdgeSRV() +func DiscoverEdge(ctx context.Context, region string) ([][]*EdgeAddr, error) { + regions, err := lookupEdgeSRV(region) if err != nil { - regions, err = lookupEdgeSRVWithDoT(ctx) + regions, err = lookupEdgeSRVWithDoT(ctx, region) if err != nil { return nil, E.Cause(err, "edge discovery") } @@ -44,15 +51,15 @@ func DiscoverEdge(ctx context.Context) ([][]*EdgeAddr, error) { return regions, nil } -func lookupEdgeSRV() ([][]*EdgeAddr, error) { - _, addrs, err := net.LookupSRV(edgeSRVService, edgeSRVProto, edgeSRVName) +func lookupEdgeSRV(region string) ([][]*EdgeAddr, error) { + _, addrs, err := net.LookupSRV(getRegionalServiceName(region), edgeSRVProto, edgeSRVName) if err != nil { return nil, err } return resolveSRVRecords(addrs) } -func lookupEdgeSRVWithDoT(ctx context.Context) ([][]*EdgeAddr, error) { +func lookupEdgeSRVWithDoT(ctx context.Context, region string) ([][]*EdgeAddr, error) { resolver := &net.Resolver{ PreferGo: true, Dial: func(ctx context.Context, _, _ string) (net.Conn, error) { @@ -66,7 +73,7 @@ func lookupEdgeSRVWithDoT(ctx context.Context) ([][]*EdgeAddr, error) { } lookupCtx, cancel := context.WithTimeout(ctx, dotTimeout) defer cancel() - _, addrs, err := resolver.LookupSRV(lookupCtx, edgeSRVService, edgeSRVProto, edgeSRVName) + _, addrs, err := resolver.LookupSRV(lookupCtx, getRegionalServiceName(region), edgeSRVProto, edgeSRVName) if err != nil { return nil, err } diff --git a/protocol/cloudflare/edge_discovery_test.go b/protocol/cloudflare/edge_discovery_test.go index 6d602cfa6..c282009d0 100644 --- a/protocol/cloudflare/edge_discovery_test.go +++ b/protocol/cloudflare/edge_discovery_test.go @@ -9,7 +9,7 @@ import ( ) func TestDiscoverEdge(t *testing.T) { - regions, err := DiscoverEdge(context.Background()) + regions, err := DiscoverEdge(context.Background(), "") if err != nil { t.Fatal("DiscoverEdge: ", err) } @@ -86,3 +86,12 @@ func TestFilterByIPVersion(t *testing.T) { } }) } + +func TestGetRegionalServiceName(t *testing.T) { + if got := getRegionalServiceName(""); got != edgeSRVService { + t.Fatalf("expected global service %s, got %s", edgeSRVService, got) + } + if got := getRegionalServiceName("us"); got != "us-"+edgeSRVService { + t.Fatalf("expected regional service us-%s, got %s", edgeSRVService, got) + } +} diff --git a/protocol/cloudflare/inbound.go b/protocol/cloudflare/inbound.go index 0fcb2ac65..3b349d3e3 100644 --- a/protocol/cloudflare/inbound.go +++ b/protocol/cloudflare/inbound.go @@ -94,6 +94,14 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo return nil, E.Cause(err, "build cloudflare tunnel runtime config") } + region := options.Region + if region != "" && credentials.Endpoint != "" { + return nil, E.New("region cannot be specified when credentials already include an endpoint") + } + if region == "" { + region = credentials.Endpoint + } + inboundCtx, cancel := context.WithCancel(ctx) return &Inbound{ @@ -106,7 +114,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo connectorID: uuid.New(), haConnections: haConnections, protocol: protocol, - region: options.Region, + region: region, edgeIPVersion: edgeIPVersion, datagramVersion: datagramVersion, gracePeriod: gracePeriod, @@ -123,7 +131,7 @@ func (i *Inbound) Start(stage adapter.StartStage) error { i.logger.Info("starting Cloudflare Tunnel with ", i.haConnections, " HA connections") - regions, err := DiscoverEdge(i.ctx) + regions, err := DiscoverEdge(i.ctx, i.region) if err != nil { return E.Cause(err, "discover edge") }