From af2afc529b70514bec0a74a8e821e8cd23d303fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 16:26:50 +0800 Subject: [PATCH] cloudflare: require remote-managed tunnels --- option/cloudflare_tunnel.go | 68 ++------- protocol/cloudflare/config_decode_test.go | 28 ++++ protocol/cloudflare/connection_http2.go | 17 +++ protocol/cloudflare/connection_quic.go | 10 ++ protocol/cloudflare/control.go | 7 + protocol/cloudflare/credentials_test.go | 51 ------- protocol/cloudflare/flow_limiter_test.go | 2 +- protocol/cloudflare/helpers_test.go | 2 +- protocol/cloudflare/inbound.go | 99 +++++++------ protocol/cloudflare/ingress_test.go | 15 +- protocol/cloudflare/runtime_config.go | 152 ++------------------ protocol/cloudflare/special_service_test.go | 18 ++- 12 files changed, 154 insertions(+), 315 deletions(-) create mode 100644 protocol/cloudflare/config_decode_test.go diff --git a/option/cloudflare_tunnel.go b/option/cloudflare_tunnel.go index bf388a044..74b511eef 100644 --- a/option/cloudflare_tunnel.go +++ b/option/cloudflare_tunnel.go @@ -3,64 +3,12 @@ package option import "github.com/sagernet/sing/common/json/badoption" type CloudflareTunnelInboundOptions struct { - Token string `json:"token,omitempty"` - CredentialPath string `json:"credential_path,omitempty"` - HAConnections int `json:"ha_connections,omitempty"` - Protocol string `json:"protocol,omitempty"` - ControlDialer DialerOptions `json:"control_dialer,omitempty"` - EdgeIPVersion int `json:"edge_ip_version,omitempty"` - DatagramVersion string `json:"datagram_version,omitempty"` - GracePeriod badoption.Duration `json:"grace_period,omitempty"` - Region string `json:"region,omitempty"` - Ingress []CloudflareTunnelIngressRule `json:"ingress,omitempty"` - OriginRequest CloudflareTunnelOriginRequestOptions `json:"origin_request,omitempty"` - WarpRouting CloudflareTunnelWarpRoutingOptions `json:"warp_routing,omitempty"` -} - -type CloudflareTunnelIngressRule struct { - Hostname string `json:"hostname,omitempty"` - Path string `json:"path,omitempty"` - Service string `json:"service,omitempty"` - OriginRequest CloudflareTunnelOriginRequestOptions `json:"origin_request,omitempty"` -} - -type CloudflareTunnelOriginRequestOptions struct { - ConnectTimeout badoption.Duration `json:"connect_timeout,omitempty"` - TLSTimeout badoption.Duration `json:"tls_timeout,omitempty"` - TCPKeepAlive badoption.Duration `json:"tcp_keep_alive,omitempty"` - NoHappyEyeballs bool `json:"no_happy_eyeballs,omitempty"` - KeepAliveTimeout badoption.Duration `json:"keep_alive_timeout,omitempty"` - KeepAliveConnections int `json:"keep_alive_connections,omitempty"` - HTTPHostHeader string `json:"http_host_header,omitempty"` - OriginServerName string `json:"origin_server_name,omitempty"` - MatchSNIToHost bool `json:"match_sni_to_host,omitempty"` - CAPool string `json:"ca_pool,omitempty"` - NoTLSVerify bool `json:"no_tls_verify,omitempty"` - DisableChunkedEncoding bool `json:"disable_chunked_encoding,omitempty"` - BastionMode bool `json:"bastion_mode,omitempty"` - ProxyAddress string `json:"proxy_address,omitempty"` - ProxyPort uint `json:"proxy_port,omitempty"` - ProxyType string `json:"proxy_type,omitempty"` - IPRules []CloudflareTunnelIPRule `json:"ip_rules,omitempty"` - HTTP2Origin bool `json:"http2_origin,omitempty"` - Access CloudflareTunnelAccessRule `json:"access,omitempty"` -} - -type CloudflareTunnelAccessRule struct { - Required bool `json:"required,omitempty"` - TeamName string `json:"team_name,omitempty"` - AudTag []string `json:"aud_tag,omitempty"` - Environment string `json:"environment,omitempty"` -} - -type CloudflareTunnelIPRule struct { - Prefix string `json:"prefix,omitempty"` - Ports []int `json:"ports,omitempty"` - Allow bool `json:"allow,omitempty"` -} - -type CloudflareTunnelWarpRoutingOptions struct { - ConnectTimeout badoption.Duration `json:"connect_timeout,omitempty"` - MaxActiveFlows uint64 `json:"max_active_flows,omitempty"` - TCPKeepAlive badoption.Duration `json:"tcp_keep_alive,omitempty"` + Token string `json:"token,omitempty"` + HAConnections int `json:"ha_connections,omitempty"` + Protocol string `json:"protocol,omitempty"` + ControlDialer DialerOptions `json:"control_dialer,omitempty"` + EdgeIPVersion int `json:"edge_ip_version,omitempty"` + DatagramVersion string `json:"datagram_version,omitempty"` + GracePeriod badoption.Duration `json:"grace_period,omitempty"` + Region string `json:"region,omitempty"` } diff --git a/protocol/cloudflare/config_decode_test.go b/protocol/cloudflare/config_decode_test.go new file mode 100644 index 000000000..588e0355e --- /dev/null +++ b/protocol/cloudflare/config_decode_test.go @@ -0,0 +1,28 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "context" + "testing" + + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" +) + +func TestNewInboundRequiresToken(t *testing.T) { + _, err := NewInbound(context.Background(), nil, log.NewNOPFactory().NewLogger("test"), "test", option.CloudflareTunnelInboundOptions{}) + if err == nil { + t.Fatal("expected missing token error") + } +} + +func TestValidateRegistrationResultRejectsNonRemoteManaged(t *testing.T) { + err := validateRegistrationResult(&RegistrationResult{TunnelIsRemotelyManaged: false}) + if err == nil { + t.Fatal("expected unsupported tunnel error") + } + if err != ErrNonRemoteManagedTunnelUnsupported { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/protocol/cloudflare/connection_http2.go b/protocol/cloudflare/connection_http2.go index 9fed72cf8..24ddadd6c 100644 --- a/protocol/cloudflare/connection_http2.go +++ b/protocol/cloudflare/connection_http2.go @@ -44,6 +44,7 @@ type HTTP2Connection struct { numPreviousAttempts uint8 registrationClient *RegistrationClient registrationResult *RegistrationResult + controlStreamErr error activeRequests sync.WaitGroup closeOnce sync.Once @@ -113,6 +114,9 @@ func (c *HTTP2Connection) Serve(ctx context.Context) error { Handler: c, }) + if c.controlStreamErr != nil { + return c.controlStreamErr + } if ctx.Err() != nil { return ctx.Err() } @@ -161,10 +165,23 @@ func (c *HTTP2Connection) handleControlStream(ctx context.Context, r *http.Reque ctx, c.credentials.Auth(), c.credentials.TunnelID, c.connIndex, options, ) if err != nil { + c.controlStreamErr = err c.logger.Error("register connection: ", err) + if c.registrationClient != nil { + c.registrationClient.Close() + } + go c.close() + return + } + if err := validateRegistrationResult(result); err != nil { + c.controlStreamErr = err + c.logger.Error("register connection: ", err) + c.registrationClient.Close() + go c.close() return } c.registrationResult = result + c.inbound.notifyConnected(c.connIndex) c.logger.Info("connected to ", result.Location, " (connection ", result.ConnectionID, ")") diff --git a/protocol/cloudflare/connection_quic.go b/protocol/cloudflare/connection_quic.go index fd2e56a3b..2a02a06d0 100644 --- a/protocol/cloudflare/connection_quic.go +++ b/protocol/cloudflare/connection_quic.go @@ -50,6 +50,7 @@ type QUICConnection struct { gracePeriod time.Duration registrationClient *RegistrationClient registrationResult *RegistrationResult + onConnected func() closeOnce sync.Once } @@ -90,6 +91,7 @@ func NewQUICConnection( numPreviousAttempts uint8, gracePeriod time.Duration, controlDialer N.Dialer, + onConnected func(), logger log.ContextLogger, ) (*QUICConnection, error) { rootCAs, err := cloudflareRootCertPool() @@ -134,6 +136,7 @@ func NewQUICConnection( features: features, numPreviousAttempts: numPreviousAttempts, gracePeriod: gracePeriod, + onConnected: onConnected, }, nil } @@ -170,6 +173,7 @@ func (q *QUICConnection) Serve(ctx context.Context, handler StreamHandler) error err = q.register(ctx, controlStream) if err != nil { controlStream.Close() + q.Close() return err } @@ -208,7 +212,13 @@ func (q *QUICConnection) register(ctx context.Context, stream *quic.Stream) erro if err != nil { return E.Cause(err, "register connection") } + if err := validateRegistrationResult(result); err != nil { + return err + } q.registrationResult = result + if q.onConnected != nil { + q.onConnected() + } return nil } diff --git a/protocol/cloudflare/control.go b/protocol/cloudflare/control.go index 6e9881114..f72d627f0 100644 --- a/protocol/cloudflare/control.go +++ b/protocol/cloudflare/control.go @@ -150,6 +150,13 @@ func (c *RegistrationClient) Close() error { ) } +func validateRegistrationResult(result *RegistrationResult) error { + if result == nil || result.TunnelIsRemotelyManaged { + return nil + } + return ErrNonRemoteManagedTunnelUnsupported +} + // BuildConnectionOptions creates the ConnectionOptions to send during registration. func BuildConnectionOptions(connectorID uuid.UUID, features []string, numPreviousAttempts uint8, originLocalIP net.IP) *RegistrationConnectionOptions { return &RegistrationConnectionOptions{ diff --git a/protocol/cloudflare/credentials_test.go b/protocol/cloudflare/credentials_test.go index 31759aa34..506d8601a 100644 --- a/protocol/cloudflare/credentials_test.go +++ b/protocol/cloudflare/credentials_test.go @@ -4,8 +4,6 @@ package cloudflare import ( "encoding/base64" - "os" - "path/filepath" "testing" "github.com/google/uuid" @@ -43,52 +41,3 @@ func TestParseTokenInvalidJSON(t *testing.T) { t.Fatal("expected error for invalid JSON") } } - -func TestParseCredentialFile(t *testing.T) { - tunnelID := uuid.New() - content := `{"AccountTag":"acct","TunnelSecret":"c2VjcmV0","TunnelID":"` + tunnelID.String() + `"}` - path := filepath.Join(t.TempDir(), "creds.json") - err := os.WriteFile(path, []byte(content), 0o644) - if err != nil { - t.Fatal(err) - } - - credentials, err := parseCredentialFile(path) - if err != nil { - t.Fatal("parseCredentialFile: ", err) - } - if credentials.AccountTag != "acct" { - t.Error("expected AccountTag acct, got ", credentials.AccountTag) - } - if credentials.TunnelID != tunnelID { - t.Error("expected TunnelID ", tunnelID, ", got ", credentials.TunnelID) - } -} - -func TestParseCredentialFileMissingTunnelID(t *testing.T) { - content := `{"AccountTag":"acct","TunnelSecret":"c2VjcmV0","TunnelID":"00000000-0000-0000-0000-000000000000"}` - path := filepath.Join(t.TempDir(), "creds.json") - err := os.WriteFile(path, []byte(content), 0o644) - if err != nil { - t.Fatal(err) - } - - _, err = parseCredentialFile(path) - if err == nil { - t.Fatal("expected error for missing tunnel ID") - } -} - -func TestParseCredentialsBothSpecified(t *testing.T) { - _, err := parseCredentials("sometoken", "/some/path") - if err == nil { - t.Fatal("expected error when both specified") - } -} - -func TestParseCredentialsNoneSpecified(t *testing.T) { - _, err := parseCredentials("", "") - if err == nil { - t.Fatal("expected error when none specified") - } -} diff --git a/protocol/cloudflare/flow_limiter_test.go b/protocol/cloudflare/flow_limiter_test.go index ad27c534b..b8e69aeeb 100644 --- a/protocol/cloudflare/flow_limiter_test.go +++ b/protocol/cloudflare/flow_limiter_test.go @@ -23,7 +23,7 @@ func newLimitedInbound(t *testing.T, limit uint64) *Inbound { if err != nil { t.Fatal(err) } - configManager, err := NewConfigManager(option.CloudflareTunnelInboundOptions{}) + configManager, err := NewConfigManager() if err != nil { t.Fatal(err) } diff --git a/protocol/cloudflare/helpers_test.go b/protocol/cloudflare/helpers_test.go index d873a40b1..8dbec9c7a 100644 --- a/protocol/cloudflare/helpers_test.go +++ b/protocol/cloudflare/helpers_test.go @@ -170,7 +170,7 @@ func newTestInbound(t *testing.T, token string, protocol string, haConnections i t.Fatal("create logger: ", err) } - configManager, err := NewConfigManager(option.CloudflareTunnelInboundOptions{}) + configManager, err := NewConfigManager() if err != nil { t.Fatal("create config manager: ", err) } diff --git a/protocol/cloudflare/inbound.go b/protocol/cloudflare/inbound.go index fd6a0f0b5..038a6a909 100644 --- a/protocol/cloudflare/inbound.go +++ b/protocol/cloudflare/inbound.go @@ -6,12 +6,12 @@ import ( "context" stdTLS "crypto/tls" "encoding/base64" + "errors" "io" "math/rand" "net" "net/http" "net/url" - "os" "sync" "time" @@ -33,6 +33,8 @@ func RegisterInbound(registry *inbound.Registry) { inbound.Register[option.CloudflareTunnelInboundOptions](registry, C.TypeCloudflareTunnel, NewInbound) } +var ErrNonRemoteManagedTunnelUnsupported = errors.New("cloudflare tunnel only supports remote-managed tunnels") + type Inbound struct { inbound.Adapter ctx context.Context @@ -63,12 +65,19 @@ type Inbound struct { helloWorldAccess sync.Mutex helloWorldServer *http.Server helloWorldURL *url.URL + + connectedAccess sync.Mutex + connectedIndices map[uint8]struct{} + connectedNotify chan uint8 } func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.CloudflareTunnelInboundOptions) (adapter.Inbound, error) { - credentials, err := parseCredentials(options.Token, options.CredentialPath) + if options.Token == "" { + return nil, E.New("missing token") + } + credentials, err := parseToken(options.Token) if err != nil { - return nil, E.Cause(err, "parse credentials") + return nil, E.Cause(err, "parse token") } haConnections := options.HAConnections @@ -96,7 +105,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo gracePeriod = 30 * time.Second } - configManager, err := NewConfigManager(options) + configManager, err := NewConfigManager() if err != nil { return nil, E.Cause(err, "build cloudflare tunnel runtime config") } @@ -139,6 +148,8 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo controlDialer: controlDialer, datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer), datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer), + connectedIndices: make(map[uint8]struct{}), + connectedNotify: make(chan uint8, haConnections), }, nil } @@ -164,24 +175,36 @@ func (i *Inbound) Start(stage adapter.StartStage) error { for connIndex := 0; connIndex < i.haConnections; connIndex++ { i.done.Add(1) go i.superviseConnection(uint8(connIndex), edgeAddrs, features) - if connIndex == 0 { - // Wait a bit for the first connection before starting others - select { - case <-time.After(time.Second): - case <-i.ctx.Done(): + select { + case readyConnIndex := <-i.connectedNotify: + if readyConnIndex != uint8(connIndex) { + i.logger.Debug("received unexpected ready notification for connection ", readyConnIndex) + } + case <-time.After(firstConnectionReadyTimeout): + case <-i.ctx.Done(): + if connIndex == 0 { return i.ctx.Err() } - } else { - select { - case <-time.After(time.Second): - case <-i.ctx.Done(): - return nil - } + return nil } } return nil } +func (i *Inbound) notifyConnected(connIndex uint8) { + if i.connectedNotify == nil { + return + } + i.connectedAccess.Lock() + if _, loaded := i.connectedIndices[connIndex]; loaded { + i.connectedAccess.Unlock() + return + } + i.connectedIndices[connIndex] = struct{}{} + i.connectedAccess.Unlock() + i.connectedNotify <- connIndex +} + func (i *Inbound) ApplyConfig(version int32, config []byte) ConfigUpdateResult { result := i.configManager.Apply(version, config) if result.Err != nil { @@ -249,8 +272,9 @@ func (i *Inbound) ensureHelloWorldURL() (*url.URL, error) { } const ( - backoffBaseTime = time.Second - backoffMaxTime = 2 * time.Minute + backoffBaseTime = time.Second + backoffMaxTime = 2 * time.Minute + firstConnectionReadyTimeout = 15 * time.Second ) func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr, features []string) { @@ -269,6 +293,11 @@ func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr, fe if err == nil || i.ctx.Err() != nil { return } + if errors.Is(err, ErrNonRemoteManagedTunnelUnsupported) { + i.logger.Error("connection ", connIndex, " failed permanently: ", err) + i.cancel() + return + } retries++ backoff := backoffDuration(retries) @@ -294,6 +323,9 @@ func (i *Inbound) serveConnection(connIndex uint8, edgeAddr *EdgeAddr, features if err == nil || i.ctx.Err() != nil { return err } + if errors.Is(err, ErrNonRemoteManagedTunnelUnsupported) { + return err + } i.logger.Warn("QUIC connection failed, falling back to HTTP/2: ", err) return i.serveHTTP2(connIndex, edgeAddr, features, numPreviousAttempts) case "http2": @@ -309,7 +341,9 @@ func (i *Inbound) serveQUIC(connIndex uint8, edgeAddr *EdgeAddr, features []stri connection, err := NewQUICConnection( i.ctx, edgeAddr, connIndex, i.credentials, i.connectorID, - features, numPreviousAttempts, i.gracePeriod, i.controlDialer, i.logger, + features, numPreviousAttempts, i.gracePeriod, i.controlDialer, func() { + i.notifyConnected(connIndex) + }, i.logger, ) if err != nil { return E.Cause(err, "create QUIC connection") @@ -377,19 +411,6 @@ func flattenRegions(regions [][]*EdgeAddr) []*EdgeAddr { return result } -func parseCredentials(token string, credentialPath string) (Credentials, error) { - if token == "" && credentialPath == "" { - return Credentials{}, E.New("either token or credential_path must be specified") - } - if token != "" && credentialPath != "" { - return Credentials{}, E.New("token and credential_path are mutually exclusive") - } - if token != "" { - return parseToken(token) - } - return parseCredentialFile(credentialPath) -} - func parseToken(token string) (Credentials, error) { data, err := base64.StdEncoding.DecodeString(token) if err != nil { @@ -402,19 +423,3 @@ func parseToken(token string) (Credentials, error) { } return tunnelToken.ToCredentials(), nil } - -func parseCredentialFile(path string) (Credentials, error) { - data, err := os.ReadFile(path) - if err != nil { - return Credentials{}, E.Cause(err, "read credential file") - } - var credentials Credentials - err = json.Unmarshal(data, &credentials) - if err != nil { - return Credentials{}, E.Cause(err, "unmarshal credential file") - } - if credentials.TunnelID == (uuid.UUID{}) { - return Credentials{}, E.New("credential file missing tunnel ID") - } - return credentials, nil -} diff --git a/protocol/cloudflare/ingress_test.go b/protocol/cloudflare/ingress_test.go index e4432cf32..f73db11a1 100644 --- a/protocol/cloudflare/ingress_test.go +++ b/protocol/cloudflare/ingress_test.go @@ -6,12 +6,11 @@ import ( "testing" "github.com/sagernet/sing-box/log" - "github.com/sagernet/sing-box/option" ) func newTestIngressInbound(t *testing.T) *Inbound { t.Helper() - configManager, err := NewConfigManager(option.CloudflareTunnelInboundOptions{}) + configManager, err := NewConfigManager() if err != nil { t.Fatal(err) } @@ -85,6 +84,18 @@ func TestApplyConfigInvalidJSON(t *testing.T) { } } +func TestDefaultConfigIsCatchAll503(t *testing.T) { + inboundInstance := newTestIngressInbound(t) + + service, loaded := inboundInstance.configManager.Resolve("any.example.com", "/") + if !loaded { + t.Fatal("expected default config to resolve catch-all rule") + } + if service.StatusCode != 503 { + t.Fatalf("expected catch-all 503, got %#v", service) + } +} + func TestResolveExactAndWildcard(t *testing.T) { inboundInstance := newTestIngressInbound(t) inboundInstance.configManager.activeConfig = RuntimeConfig{ diff --git a/protocol/cloudflare/runtime_config.go b/protocol/cloudflare/runtime_config.go index 52583fa3e..d46e1062f 100644 --- a/protocol/cloudflare/runtime_config.go +++ b/protocol/cloudflare/runtime_config.go @@ -12,7 +12,6 @@ import ( "sync" "time" - "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" @@ -155,8 +154,8 @@ type ConfigManager struct { activeConfig RuntimeConfig } -func NewConfigManager(options option.CloudflareTunnelInboundOptions) (*ConfigManager, error) { - config, err := buildLocalRuntimeConfig(options) +func NewConfigManager() (*ConfigManager, error) { + config, err := defaultRuntimeConfig() if err != nil { return nil, err } @@ -237,26 +236,19 @@ func matchIngressHost(pattern, hostname string) bool { return false } -func buildLocalRuntimeConfig(options option.CloudflareTunnelInboundOptions) (RuntimeConfig, error) { - defaultOriginRequest := originRequestFromOption(options.OriginRequest) - warpRouting := warpRoutingFromOption(options.WarpRouting) - var ingressRules []localIngressRule - for _, rule := range options.Ingress { - ingressRules = append(ingressRules, localIngressRule{ - Hostname: rule.Hostname, - Path: rule.Path, - Service: rule.Service, - OriginRequest: mergeOptionOriginRequest(defaultOriginRequest, rule.OriginRequest), - }) - } - compiledRules, err := compileIngressRules(defaultOriginRequest, ingressRules) +func defaultRuntimeConfig() (RuntimeConfig, error) { + defaultOriginRequest := defaultOriginRequestConfig() + compiledRules, err := compileIngressRules(defaultOriginRequest, nil) if err != nil { return RuntimeConfig{}, err } return RuntimeConfig{ Ingress: compiledRules, OriginRequest: defaultOriginRequest, - WarpRouting: warpRouting, + WarpRouting: WarpRoutingConfig{ + ConnectTimeout: defaultWarpRoutingConnectTime, + TCPKeepAlive: defaultWarpRoutingTCPKeepAlive, + }, }, nil } @@ -554,117 +546,6 @@ func defaultOriginRequestConfig() OriginRequestConfig { } } -func originRequestFromOption(input option.CloudflareTunnelOriginRequestOptions) OriginRequestConfig { - config := defaultOriginRequestConfig() - if input.ConnectTimeout != 0 { - config.ConnectTimeout = time.Duration(input.ConnectTimeout) - } - if input.TLSTimeout != 0 { - config.TLSTimeout = time.Duration(input.TLSTimeout) - } - if input.TCPKeepAlive != 0 { - config.TCPKeepAlive = time.Duration(input.TCPKeepAlive) - } - if input.KeepAliveTimeout != 0 { - config.KeepAliveTimeout = time.Duration(input.KeepAliveTimeout) - } - if input.KeepAliveConnections != 0 { - config.KeepAliveConnections = input.KeepAliveConnections - } - config.NoHappyEyeballs = input.NoHappyEyeballs - config.HTTPHostHeader = input.HTTPHostHeader - config.OriginServerName = input.OriginServerName - config.MatchSNIToHost = input.MatchSNIToHost - config.CAPool = input.CAPool - config.NoTLSVerify = input.NoTLSVerify - config.DisableChunkedEncoding = input.DisableChunkedEncoding - config.BastionMode = input.BastionMode - if input.ProxyAddress != "" { - config.ProxyAddress = input.ProxyAddress - } - if input.ProxyPort != 0 { - config.ProxyPort = input.ProxyPort - } - config.ProxyType = input.ProxyType - config.HTTP2Origin = input.HTTP2Origin - config.Access = AccessConfig{ - Required: input.Access.Required, - TeamName: input.Access.TeamName, - AudTag: append([]string(nil), input.Access.AudTag...), - Environment: input.Access.Environment, - } - for _, rule := range input.IPRules { - config.IPRules = append(config.IPRules, IPRule{ - Prefix: rule.Prefix, - Ports: append([]int(nil), rule.Ports...), - Allow: rule.Allow, - }) - } - return config -} - -func mergeOptionOriginRequest(base OriginRequestConfig, override option.CloudflareTunnelOriginRequestOptions) OriginRequestConfig { - result := base - if override.ConnectTimeout != 0 { - result.ConnectTimeout = time.Duration(override.ConnectTimeout) - } - if override.TLSTimeout != 0 { - result.TLSTimeout = time.Duration(override.TLSTimeout) - } - if override.TCPKeepAlive != 0 { - result.TCPKeepAlive = time.Duration(override.TCPKeepAlive) - } - if override.KeepAliveTimeout != 0 { - result.KeepAliveTimeout = time.Duration(override.KeepAliveTimeout) - } - if override.KeepAliveConnections != 0 { - result.KeepAliveConnections = override.KeepAliveConnections - } - result.NoHappyEyeballs = override.NoHappyEyeballs - if override.HTTPHostHeader != "" { - result.HTTPHostHeader = override.HTTPHostHeader - } - if override.OriginServerName != "" { - result.OriginServerName = override.OriginServerName - } - result.MatchSNIToHost = override.MatchSNIToHost - if override.CAPool != "" { - result.CAPool = override.CAPool - } - result.NoTLSVerify = override.NoTLSVerify - result.DisableChunkedEncoding = override.DisableChunkedEncoding - result.BastionMode = override.BastionMode - if override.ProxyAddress != "" { - result.ProxyAddress = override.ProxyAddress - } - if override.ProxyPort != 0 { - result.ProxyPort = override.ProxyPort - } - if override.ProxyType != "" { - result.ProxyType = override.ProxyType - } - if len(override.IPRules) > 0 { - result.IPRules = nil - for _, rule := range override.IPRules { - result.IPRules = append(result.IPRules, IPRule{ - Prefix: rule.Prefix, - Ports: append([]int(nil), rule.Ports...), - Allow: rule.Allow, - }) - } - } - result.HTTP2Origin = override.HTTP2Origin - if override.Access.Required || override.Access.TeamName != "" || len(override.Access.AudTag) > 0 || override.Access.Environment != "" { - result.Access = AccessConfig{ - Required: override.Access.Required, - TeamName: override.Access.TeamName, - AudTag: append([]string(nil), override.Access.AudTag...), - Environment: override.Access.Environment, - } - } - return result -} - func originRequestFromRemote(input remoteOriginRequestJSON) OriginRequestConfig { config := defaultOriginRequestConfig() if input.ConnectTimeout != 0 { @@ -802,21 +683,6 @@ func mergeRemoteOriginRequest(base OriginRequestConfig, override remoteOriginReq return result } -func warpRoutingFromOption(input option.CloudflareTunnelWarpRoutingOptions) WarpRoutingConfig { - config := WarpRoutingConfig{ - ConnectTimeout: defaultWarpRoutingConnectTime, - TCPKeepAlive: defaultWarpRoutingTCPKeepAlive, - MaxActiveFlows: input.MaxActiveFlows, - } - if input.ConnectTimeout != 0 { - config.ConnectTimeout = time.Duration(input.ConnectTimeout) - } - if input.TCPKeepAlive != 0 { - config.TCPKeepAlive = time.Duration(input.TCPKeepAlive) - } - return config -} - func warpRoutingFromRemote(input remoteWarpRoutingJSON) WarpRoutingConfig { config := WarpRoutingConfig{ ConnectTimeout: defaultWarpRoutingConnectTime, diff --git a/protocol/cloudflare/special_service_test.go b/protocol/cloudflare/special_service_test.go index a2f445553..9b29c0a0e 100644 --- a/protocol/cloudflare/special_service_test.go +++ b/protocol/cloudflare/special_service_test.go @@ -59,7 +59,7 @@ func newSpecialServiceInboundWithRouter(t *testing.T, router adapter.Router) *In if err != nil { t.Fatal(err) } - configManager, err := NewConfigManager(option.CloudflareTunnelInboundOptions{}) + configManager, err := NewConfigManager() if err != nil { t.Fatal(err) } @@ -102,11 +102,9 @@ func startEchoListener(t *testing.T) net.Listener { return listener } -func newSocksProxyService(t *testing.T, rules []option.CloudflareTunnelIPRule) ResolvedService { +func newSocksProxyService(t *testing.T, rules []IPRule) ResolvedService { t.Helper() - service, err := parseResolvedService("socks-proxy", originRequestFromOption(option.CloudflareTunnelOriginRequestOptions{ - IPRules: rules, - })) + service, err := parseResolvedService("socks-proxy", OriginRequestConfig{IPRules: rules}) if err != nil { t.Fatal(err) } @@ -247,7 +245,7 @@ func TestHandleSocksProxyStream(t *testing.T) { _, portText, _ := net.SplitHostPort(listener.Addr().String()) port, _ := strconv.Atoi(portText) - service := newSocksProxyService(t, []option.CloudflareTunnelIPRule{{ + service := newSocksProxyService(t, []IPRule{{ Prefix: "127.0.0.0/8", Ports: []int{port}, Allow: true, @@ -286,7 +284,7 @@ func TestHandleSocksProxyStreamDenyRule(t *testing.T) { _, portText, _ := net.SplitHostPort(listener.Addr().String()) port, _ := strconv.Atoi(portText) - service := newSocksProxyService(t, []option.CloudflareTunnelIPRule{{ + service := newSocksProxyService(t, []IPRule{{ Prefix: "127.0.0.0/8", Ports: []int{port}, Allow: false, @@ -317,7 +315,7 @@ func TestHandleSocksProxyStreamPortMismatchDefaultDeny(t *testing.T) { _, portText, _ := net.SplitHostPort(listener.Addr().String()) port, _ := strconv.Atoi(portText) - service := newSocksProxyService(t, []option.CloudflareTunnelIPRule{{ + service := newSocksProxyService(t, []IPRule{{ Prefix: "127.0.0.0/8", Ports: []int{port + 1}, Allow: true, @@ -372,11 +370,11 @@ func TestHandleSocksProxyStreamRuleOrderFirstMatchWins(t *testing.T) { _, portText, _ := net.SplitHostPort(listener.Addr().String()) port, _ := strconv.Atoi(portText) - allowFirst := newSocksProxyService(t, []option.CloudflareTunnelIPRule{ + allowFirst := newSocksProxyService(t, []IPRule{ {Prefix: "127.0.0.0/8", Ports: []int{port}, Allow: true}, {Prefix: "127.0.0.1/32", Ports: []int{port}, Allow: false}, }) - denyFirst := newSocksProxyService(t, []option.CloudflareTunnelIPRule{ + denyFirst := newSocksProxyService(t, []IPRule{ {Prefix: "127.0.0.1/32", Ports: []int{port}, Allow: false}, {Prefix: "127.0.0.0/8", Ports: []int{port}, Allow: true}, })