mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-14 04:38:28 +10:00
cloudflare: require remote-managed tunnels
This commit is contained in:
@@ -3,64 +3,12 @@ package option
|
||||
import "github.com/sagernet/sing/common/json/badoption"
|
||||
|
||||
type CloudflareTunnelInboundOptions struct {
|
||||
Token string `json:"token,omitempty"`
|
||||
CredentialPath string `json:"credential_path,omitempty"`
|
||||
HAConnections int `json:"ha_connections,omitempty"`
|
||||
Protocol string `json:"protocol,omitempty"`
|
||||
ControlDialer DialerOptions `json:"control_dialer,omitempty"`
|
||||
EdgeIPVersion int `json:"edge_ip_version,omitempty"`
|
||||
DatagramVersion string `json:"datagram_version,omitempty"`
|
||||
GracePeriod badoption.Duration `json:"grace_period,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
Ingress []CloudflareTunnelIngressRule `json:"ingress,omitempty"`
|
||||
OriginRequest CloudflareTunnelOriginRequestOptions `json:"origin_request,omitempty"`
|
||||
WarpRouting CloudflareTunnelWarpRoutingOptions `json:"warp_routing,omitempty"`
|
||||
}
|
||||
|
||||
type CloudflareTunnelIngressRule struct {
|
||||
Hostname string `json:"hostname,omitempty"`
|
||||
Path string `json:"path,omitempty"`
|
||||
Service string `json:"service,omitempty"`
|
||||
OriginRequest CloudflareTunnelOriginRequestOptions `json:"origin_request,omitempty"`
|
||||
}
|
||||
|
||||
type CloudflareTunnelOriginRequestOptions struct {
|
||||
ConnectTimeout badoption.Duration `json:"connect_timeout,omitempty"`
|
||||
TLSTimeout badoption.Duration `json:"tls_timeout,omitempty"`
|
||||
TCPKeepAlive badoption.Duration `json:"tcp_keep_alive,omitempty"`
|
||||
NoHappyEyeballs bool `json:"no_happy_eyeballs,omitempty"`
|
||||
KeepAliveTimeout badoption.Duration `json:"keep_alive_timeout,omitempty"`
|
||||
KeepAliveConnections int `json:"keep_alive_connections,omitempty"`
|
||||
HTTPHostHeader string `json:"http_host_header,omitempty"`
|
||||
OriginServerName string `json:"origin_server_name,omitempty"`
|
||||
MatchSNIToHost bool `json:"match_sni_to_host,omitempty"`
|
||||
CAPool string `json:"ca_pool,omitempty"`
|
||||
NoTLSVerify bool `json:"no_tls_verify,omitempty"`
|
||||
DisableChunkedEncoding bool `json:"disable_chunked_encoding,omitempty"`
|
||||
BastionMode bool `json:"bastion_mode,omitempty"`
|
||||
ProxyAddress string `json:"proxy_address,omitempty"`
|
||||
ProxyPort uint `json:"proxy_port,omitempty"`
|
||||
ProxyType string `json:"proxy_type,omitempty"`
|
||||
IPRules []CloudflareTunnelIPRule `json:"ip_rules,omitempty"`
|
||||
HTTP2Origin bool `json:"http2_origin,omitempty"`
|
||||
Access CloudflareTunnelAccessRule `json:"access,omitempty"`
|
||||
}
|
||||
|
||||
type CloudflareTunnelAccessRule struct {
|
||||
Required bool `json:"required,omitempty"`
|
||||
TeamName string `json:"team_name,omitempty"`
|
||||
AudTag []string `json:"aud_tag,omitempty"`
|
||||
Environment string `json:"environment,omitempty"`
|
||||
}
|
||||
|
||||
type CloudflareTunnelIPRule struct {
|
||||
Prefix string `json:"prefix,omitempty"`
|
||||
Ports []int `json:"ports,omitempty"`
|
||||
Allow bool `json:"allow,omitempty"`
|
||||
}
|
||||
|
||||
type CloudflareTunnelWarpRoutingOptions struct {
|
||||
ConnectTimeout badoption.Duration `json:"connect_timeout,omitempty"`
|
||||
MaxActiveFlows uint64 `json:"max_active_flows,omitempty"`
|
||||
TCPKeepAlive badoption.Duration `json:"tcp_keep_alive,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
HAConnections int `json:"ha_connections,omitempty"`
|
||||
Protocol string `json:"protocol,omitempty"`
|
||||
ControlDialer DialerOptions `json:"control_dialer,omitempty"`
|
||||
EdgeIPVersion int `json:"edge_ip_version,omitempty"`
|
||||
DatagramVersion string `json:"datagram_version,omitempty"`
|
||||
GracePeriod badoption.Duration `json:"grace_period,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
}
|
||||
|
||||
28
protocol/cloudflare/config_decode_test.go
Normal file
28
protocol/cloudflare/config_decode_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
//go:build with_cloudflare_tunnel
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
)
|
||||
|
||||
func TestNewInboundRequiresToken(t *testing.T) {
|
||||
_, err := NewInbound(context.Background(), nil, log.NewNOPFactory().NewLogger("test"), "test", option.CloudflareTunnelInboundOptions{})
|
||||
if err == nil {
|
||||
t.Fatal("expected missing token error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRegistrationResultRejectsNonRemoteManaged(t *testing.T) {
|
||||
err := validateRegistrationResult(&RegistrationResult{TunnelIsRemotelyManaged: false})
|
||||
if err == nil {
|
||||
t.Fatal("expected unsupported tunnel error")
|
||||
}
|
||||
if err != ErrNonRemoteManagedTunnelUnsupported {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -44,6 +44,7 @@ type HTTP2Connection struct {
|
||||
numPreviousAttempts uint8
|
||||
registrationClient *RegistrationClient
|
||||
registrationResult *RegistrationResult
|
||||
controlStreamErr error
|
||||
|
||||
activeRequests sync.WaitGroup
|
||||
closeOnce sync.Once
|
||||
@@ -113,6 +114,9 @@ func (c *HTTP2Connection) Serve(ctx context.Context) error {
|
||||
Handler: c,
|
||||
})
|
||||
|
||||
if c.controlStreamErr != nil {
|
||||
return c.controlStreamErr
|
||||
}
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
@@ -161,10 +165,23 @@ func (c *HTTP2Connection) handleControlStream(ctx context.Context, r *http.Reque
|
||||
ctx, c.credentials.Auth(), c.credentials.TunnelID, c.connIndex, options,
|
||||
)
|
||||
if err != nil {
|
||||
c.controlStreamErr = err
|
||||
c.logger.Error("register connection: ", err)
|
||||
if c.registrationClient != nil {
|
||||
c.registrationClient.Close()
|
||||
}
|
||||
go c.close()
|
||||
return
|
||||
}
|
||||
if err := validateRegistrationResult(result); err != nil {
|
||||
c.controlStreamErr = err
|
||||
c.logger.Error("register connection: ", err)
|
||||
c.registrationClient.Close()
|
||||
go c.close()
|
||||
return
|
||||
}
|
||||
c.registrationResult = result
|
||||
c.inbound.notifyConnected(c.connIndex)
|
||||
|
||||
c.logger.Info("connected to ", result.Location,
|
||||
" (connection ", result.ConnectionID, ")")
|
||||
|
||||
@@ -50,6 +50,7 @@ type QUICConnection struct {
|
||||
gracePeriod time.Duration
|
||||
registrationClient *RegistrationClient
|
||||
registrationResult *RegistrationResult
|
||||
onConnected func()
|
||||
|
||||
closeOnce sync.Once
|
||||
}
|
||||
@@ -90,6 +91,7 @@ func NewQUICConnection(
|
||||
numPreviousAttempts uint8,
|
||||
gracePeriod time.Duration,
|
||||
controlDialer N.Dialer,
|
||||
onConnected func(),
|
||||
logger log.ContextLogger,
|
||||
) (*QUICConnection, error) {
|
||||
rootCAs, err := cloudflareRootCertPool()
|
||||
@@ -134,6 +136,7 @@ func NewQUICConnection(
|
||||
features: features,
|
||||
numPreviousAttempts: numPreviousAttempts,
|
||||
gracePeriod: gracePeriod,
|
||||
onConnected: onConnected,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -170,6 +173,7 @@ func (q *QUICConnection) Serve(ctx context.Context, handler StreamHandler) error
|
||||
err = q.register(ctx, controlStream)
|
||||
if err != nil {
|
||||
controlStream.Close()
|
||||
q.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -208,7 +212,13 @@ func (q *QUICConnection) register(ctx context.Context, stream *quic.Stream) erro
|
||||
if err != nil {
|
||||
return E.Cause(err, "register connection")
|
||||
}
|
||||
if err := validateRegistrationResult(result); err != nil {
|
||||
return err
|
||||
}
|
||||
q.registrationResult = result
|
||||
if q.onConnected != nil {
|
||||
q.onConnected()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -150,6 +150,13 @@ func (c *RegistrationClient) Close() error {
|
||||
)
|
||||
}
|
||||
|
||||
func validateRegistrationResult(result *RegistrationResult) error {
|
||||
if result == nil || result.TunnelIsRemotelyManaged {
|
||||
return nil
|
||||
}
|
||||
return ErrNonRemoteManagedTunnelUnsupported
|
||||
}
|
||||
|
||||
// BuildConnectionOptions creates the ConnectionOptions to send during registration.
|
||||
func BuildConnectionOptions(connectorID uuid.UUID, features []string, numPreviousAttempts uint8, originLocalIP net.IP) *RegistrationConnectionOptions {
|
||||
return &RegistrationConnectionOptions{
|
||||
|
||||
@@ -4,8 +4,6 @@ package cloudflare
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -43,52 +41,3 @@ func TestParseTokenInvalidJSON(t *testing.T) {
|
||||
t.Fatal("expected error for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCredentialFile(t *testing.T) {
|
||||
tunnelID := uuid.New()
|
||||
content := `{"AccountTag":"acct","TunnelSecret":"c2VjcmV0","TunnelID":"` + tunnelID.String() + `"}`
|
||||
path := filepath.Join(t.TempDir(), "creds.json")
|
||||
err := os.WriteFile(path, []byte(content), 0o644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
credentials, err := parseCredentialFile(path)
|
||||
if err != nil {
|
||||
t.Fatal("parseCredentialFile: ", err)
|
||||
}
|
||||
if credentials.AccountTag != "acct" {
|
||||
t.Error("expected AccountTag acct, got ", credentials.AccountTag)
|
||||
}
|
||||
if credentials.TunnelID != tunnelID {
|
||||
t.Error("expected TunnelID ", tunnelID, ", got ", credentials.TunnelID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCredentialFileMissingTunnelID(t *testing.T) {
|
||||
content := `{"AccountTag":"acct","TunnelSecret":"c2VjcmV0","TunnelID":"00000000-0000-0000-0000-000000000000"}`
|
||||
path := filepath.Join(t.TempDir(), "creds.json")
|
||||
err := os.WriteFile(path, []byte(content), 0o644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = parseCredentialFile(path)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing tunnel ID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCredentialsBothSpecified(t *testing.T) {
|
||||
_, err := parseCredentials("sometoken", "/some/path")
|
||||
if err == nil {
|
||||
t.Fatal("expected error when both specified")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCredentialsNoneSpecified(t *testing.T) {
|
||||
_, err := parseCredentials("", "")
|
||||
if err == nil {
|
||||
t.Fatal("expected error when none specified")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ func newLimitedInbound(t *testing.T, limit uint64) *Inbound {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
configManager, err := NewConfigManager(option.CloudflareTunnelInboundOptions{})
|
||||
configManager, err := NewConfigManager()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -170,7 +170,7 @@ func newTestInbound(t *testing.T, token string, protocol string, haConnections i
|
||||
t.Fatal("create logger: ", err)
|
||||
}
|
||||
|
||||
configManager, err := NewConfigManager(option.CloudflareTunnelInboundOptions{})
|
||||
configManager, err := NewConfigManager()
|
||||
if err != nil {
|
||||
t.Fatal("create config manager: ", err)
|
||||
}
|
||||
|
||||
@@ -6,12 +6,12 @@ import (
|
||||
"context"
|
||||
stdTLS "crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -33,6 +33,8 @@ func RegisterInbound(registry *inbound.Registry) {
|
||||
inbound.Register[option.CloudflareTunnelInboundOptions](registry, C.TypeCloudflareTunnel, NewInbound)
|
||||
}
|
||||
|
||||
var ErrNonRemoteManagedTunnelUnsupported = errors.New("cloudflare tunnel only supports remote-managed tunnels")
|
||||
|
||||
type Inbound struct {
|
||||
inbound.Adapter
|
||||
ctx context.Context
|
||||
@@ -63,12 +65,19 @@ type Inbound struct {
|
||||
helloWorldAccess sync.Mutex
|
||||
helloWorldServer *http.Server
|
||||
helloWorldURL *url.URL
|
||||
|
||||
connectedAccess sync.Mutex
|
||||
connectedIndices map[uint8]struct{}
|
||||
connectedNotify chan uint8
|
||||
}
|
||||
|
||||
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.CloudflareTunnelInboundOptions) (adapter.Inbound, error) {
|
||||
credentials, err := parseCredentials(options.Token, options.CredentialPath)
|
||||
if options.Token == "" {
|
||||
return nil, E.New("missing token")
|
||||
}
|
||||
credentials, err := parseToken(options.Token)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "parse credentials")
|
||||
return nil, E.Cause(err, "parse token")
|
||||
}
|
||||
|
||||
haConnections := options.HAConnections
|
||||
@@ -96,7 +105,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
||||
gracePeriod = 30 * time.Second
|
||||
}
|
||||
|
||||
configManager, err := NewConfigManager(options)
|
||||
configManager, err := NewConfigManager()
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "build cloudflare tunnel runtime config")
|
||||
}
|
||||
@@ -139,6 +148,8 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
||||
controlDialer: controlDialer,
|
||||
datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer),
|
||||
datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer),
|
||||
connectedIndices: make(map[uint8]struct{}),
|
||||
connectedNotify: make(chan uint8, haConnections),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -164,24 +175,36 @@ func (i *Inbound) Start(stage adapter.StartStage) error {
|
||||
for connIndex := 0; connIndex < i.haConnections; connIndex++ {
|
||||
i.done.Add(1)
|
||||
go i.superviseConnection(uint8(connIndex), edgeAddrs, features)
|
||||
if connIndex == 0 {
|
||||
// Wait a bit for the first connection before starting others
|
||||
select {
|
||||
case <-time.After(time.Second):
|
||||
case <-i.ctx.Done():
|
||||
select {
|
||||
case readyConnIndex := <-i.connectedNotify:
|
||||
if readyConnIndex != uint8(connIndex) {
|
||||
i.logger.Debug("received unexpected ready notification for connection ", readyConnIndex)
|
||||
}
|
||||
case <-time.After(firstConnectionReadyTimeout):
|
||||
case <-i.ctx.Done():
|
||||
if connIndex == 0 {
|
||||
return i.ctx.Err()
|
||||
}
|
||||
} else {
|
||||
select {
|
||||
case <-time.After(time.Second):
|
||||
case <-i.ctx.Done():
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *Inbound) notifyConnected(connIndex uint8) {
|
||||
if i.connectedNotify == nil {
|
||||
return
|
||||
}
|
||||
i.connectedAccess.Lock()
|
||||
if _, loaded := i.connectedIndices[connIndex]; loaded {
|
||||
i.connectedAccess.Unlock()
|
||||
return
|
||||
}
|
||||
i.connectedIndices[connIndex] = struct{}{}
|
||||
i.connectedAccess.Unlock()
|
||||
i.connectedNotify <- connIndex
|
||||
}
|
||||
|
||||
func (i *Inbound) ApplyConfig(version int32, config []byte) ConfigUpdateResult {
|
||||
result := i.configManager.Apply(version, config)
|
||||
if result.Err != nil {
|
||||
@@ -249,8 +272,9 @@ func (i *Inbound) ensureHelloWorldURL() (*url.URL, error) {
|
||||
}
|
||||
|
||||
const (
|
||||
backoffBaseTime = time.Second
|
||||
backoffMaxTime = 2 * time.Minute
|
||||
backoffBaseTime = time.Second
|
||||
backoffMaxTime = 2 * time.Minute
|
||||
firstConnectionReadyTimeout = 15 * time.Second
|
||||
)
|
||||
|
||||
func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr, features []string) {
|
||||
@@ -269,6 +293,11 @@ func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr, fe
|
||||
if err == nil || i.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if errors.Is(err, ErrNonRemoteManagedTunnelUnsupported) {
|
||||
i.logger.Error("connection ", connIndex, " failed permanently: ", err)
|
||||
i.cancel()
|
||||
return
|
||||
}
|
||||
|
||||
retries++
|
||||
backoff := backoffDuration(retries)
|
||||
@@ -294,6 +323,9 @@ func (i *Inbound) serveConnection(connIndex uint8, edgeAddr *EdgeAddr, features
|
||||
if err == nil || i.ctx.Err() != nil {
|
||||
return err
|
||||
}
|
||||
if errors.Is(err, ErrNonRemoteManagedTunnelUnsupported) {
|
||||
return err
|
||||
}
|
||||
i.logger.Warn("QUIC connection failed, falling back to HTTP/2: ", err)
|
||||
return i.serveHTTP2(connIndex, edgeAddr, features, numPreviousAttempts)
|
||||
case "http2":
|
||||
@@ -309,7 +341,9 @@ func (i *Inbound) serveQUIC(connIndex uint8, edgeAddr *EdgeAddr, features []stri
|
||||
connection, err := NewQUICConnection(
|
||||
i.ctx, edgeAddr, connIndex,
|
||||
i.credentials, i.connectorID,
|
||||
features, numPreviousAttempts, i.gracePeriod, i.controlDialer, i.logger,
|
||||
features, numPreviousAttempts, i.gracePeriod, i.controlDialer, func() {
|
||||
i.notifyConnected(connIndex)
|
||||
}, i.logger,
|
||||
)
|
||||
if err != nil {
|
||||
return E.Cause(err, "create QUIC connection")
|
||||
@@ -377,19 +411,6 @@ func flattenRegions(regions [][]*EdgeAddr) []*EdgeAddr {
|
||||
return result
|
||||
}
|
||||
|
||||
func parseCredentials(token string, credentialPath string) (Credentials, error) {
|
||||
if token == "" && credentialPath == "" {
|
||||
return Credentials{}, E.New("either token or credential_path must be specified")
|
||||
}
|
||||
if token != "" && credentialPath != "" {
|
||||
return Credentials{}, E.New("token and credential_path are mutually exclusive")
|
||||
}
|
||||
if token != "" {
|
||||
return parseToken(token)
|
||||
}
|
||||
return parseCredentialFile(credentialPath)
|
||||
}
|
||||
|
||||
func parseToken(token string) (Credentials, error) {
|
||||
data, err := base64.StdEncoding.DecodeString(token)
|
||||
if err != nil {
|
||||
@@ -402,19 +423,3 @@ func parseToken(token string) (Credentials, error) {
|
||||
}
|
||||
return tunnelToken.ToCredentials(), nil
|
||||
}
|
||||
|
||||
func parseCredentialFile(path string) (Credentials, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return Credentials{}, E.Cause(err, "read credential file")
|
||||
}
|
||||
var credentials Credentials
|
||||
err = json.Unmarshal(data, &credentials)
|
||||
if err != nil {
|
||||
return Credentials{}, E.Cause(err, "unmarshal credential file")
|
||||
}
|
||||
if credentials.TunnelID == (uuid.UUID{}) {
|
||||
return Credentials{}, E.New("credential file missing tunnel ID")
|
||||
}
|
||||
return credentials, nil
|
||||
}
|
||||
|
||||
@@ -6,12 +6,11 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
)
|
||||
|
||||
func newTestIngressInbound(t *testing.T) *Inbound {
|
||||
t.Helper()
|
||||
configManager, err := NewConfigManager(option.CloudflareTunnelInboundOptions{})
|
||||
configManager, err := NewConfigManager()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -85,6 +84,18 @@ func TestApplyConfigInvalidJSON(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfigIsCatchAll503(t *testing.T) {
|
||||
inboundInstance := newTestIngressInbound(t)
|
||||
|
||||
service, loaded := inboundInstance.configManager.Resolve("any.example.com", "/")
|
||||
if !loaded {
|
||||
t.Fatal("expected default config to resolve catch-all rule")
|
||||
}
|
||||
if service.StatusCode != 503 {
|
||||
t.Fatalf("expected catch-all 503, got %#v", service)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveExactAndWildcard(t *testing.T) {
|
||||
inboundInstance := newTestIngressInbound(t)
|
||||
inboundInstance.configManager.activeConfig = RuntimeConfig{
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
||||
@@ -155,8 +154,8 @@ type ConfigManager struct {
|
||||
activeConfig RuntimeConfig
|
||||
}
|
||||
|
||||
func NewConfigManager(options option.CloudflareTunnelInboundOptions) (*ConfigManager, error) {
|
||||
config, err := buildLocalRuntimeConfig(options)
|
||||
func NewConfigManager() (*ConfigManager, error) {
|
||||
config, err := defaultRuntimeConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -237,26 +236,19 @@ func matchIngressHost(pattern, hostname string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func buildLocalRuntimeConfig(options option.CloudflareTunnelInboundOptions) (RuntimeConfig, error) {
|
||||
defaultOriginRequest := originRequestFromOption(options.OriginRequest)
|
||||
warpRouting := warpRoutingFromOption(options.WarpRouting)
|
||||
var ingressRules []localIngressRule
|
||||
for _, rule := range options.Ingress {
|
||||
ingressRules = append(ingressRules, localIngressRule{
|
||||
Hostname: rule.Hostname,
|
||||
Path: rule.Path,
|
||||
Service: rule.Service,
|
||||
OriginRequest: mergeOptionOriginRequest(defaultOriginRequest, rule.OriginRequest),
|
||||
})
|
||||
}
|
||||
compiledRules, err := compileIngressRules(defaultOriginRequest, ingressRules)
|
||||
func defaultRuntimeConfig() (RuntimeConfig, error) {
|
||||
defaultOriginRequest := defaultOriginRequestConfig()
|
||||
compiledRules, err := compileIngressRules(defaultOriginRequest, nil)
|
||||
if err != nil {
|
||||
return RuntimeConfig{}, err
|
||||
}
|
||||
return RuntimeConfig{
|
||||
Ingress: compiledRules,
|
||||
OriginRequest: defaultOriginRequest,
|
||||
WarpRouting: warpRouting,
|
||||
WarpRouting: WarpRoutingConfig{
|
||||
ConnectTimeout: defaultWarpRoutingConnectTime,
|
||||
TCPKeepAlive: defaultWarpRoutingTCPKeepAlive,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -554,117 +546,6 @@ func defaultOriginRequestConfig() OriginRequestConfig {
|
||||
}
|
||||
}
|
||||
|
||||
func originRequestFromOption(input option.CloudflareTunnelOriginRequestOptions) OriginRequestConfig {
|
||||
config := defaultOriginRequestConfig()
|
||||
if input.ConnectTimeout != 0 {
|
||||
config.ConnectTimeout = time.Duration(input.ConnectTimeout)
|
||||
}
|
||||
if input.TLSTimeout != 0 {
|
||||
config.TLSTimeout = time.Duration(input.TLSTimeout)
|
||||
}
|
||||
if input.TCPKeepAlive != 0 {
|
||||
config.TCPKeepAlive = time.Duration(input.TCPKeepAlive)
|
||||
}
|
||||
if input.KeepAliveTimeout != 0 {
|
||||
config.KeepAliveTimeout = time.Duration(input.KeepAliveTimeout)
|
||||
}
|
||||
if input.KeepAliveConnections != 0 {
|
||||
config.KeepAliveConnections = input.KeepAliveConnections
|
||||
}
|
||||
config.NoHappyEyeballs = input.NoHappyEyeballs
|
||||
config.HTTPHostHeader = input.HTTPHostHeader
|
||||
config.OriginServerName = input.OriginServerName
|
||||
config.MatchSNIToHost = input.MatchSNIToHost
|
||||
config.CAPool = input.CAPool
|
||||
config.NoTLSVerify = input.NoTLSVerify
|
||||
config.DisableChunkedEncoding = input.DisableChunkedEncoding
|
||||
config.BastionMode = input.BastionMode
|
||||
if input.ProxyAddress != "" {
|
||||
config.ProxyAddress = input.ProxyAddress
|
||||
}
|
||||
if input.ProxyPort != 0 {
|
||||
config.ProxyPort = input.ProxyPort
|
||||
}
|
||||
config.ProxyType = input.ProxyType
|
||||
config.HTTP2Origin = input.HTTP2Origin
|
||||
config.Access = AccessConfig{
|
||||
Required: input.Access.Required,
|
||||
TeamName: input.Access.TeamName,
|
||||
AudTag: append([]string(nil), input.Access.AudTag...),
|
||||
Environment: input.Access.Environment,
|
||||
}
|
||||
for _, rule := range input.IPRules {
|
||||
config.IPRules = append(config.IPRules, IPRule{
|
||||
Prefix: rule.Prefix,
|
||||
Ports: append([]int(nil), rule.Ports...),
|
||||
Allow: rule.Allow,
|
||||
})
|
||||
}
|
||||
return config
|
||||
}
|
||||
|
||||
func mergeOptionOriginRequest(base OriginRequestConfig, override option.CloudflareTunnelOriginRequestOptions) OriginRequestConfig {
|
||||
result := base
|
||||
if override.ConnectTimeout != 0 {
|
||||
result.ConnectTimeout = time.Duration(override.ConnectTimeout)
|
||||
}
|
||||
if override.TLSTimeout != 0 {
|
||||
result.TLSTimeout = time.Duration(override.TLSTimeout)
|
||||
}
|
||||
if override.TCPKeepAlive != 0 {
|
||||
result.TCPKeepAlive = time.Duration(override.TCPKeepAlive)
|
||||
}
|
||||
if override.KeepAliveTimeout != 0 {
|
||||
result.KeepAliveTimeout = time.Duration(override.KeepAliveTimeout)
|
||||
}
|
||||
if override.KeepAliveConnections != 0 {
|
||||
result.KeepAliveConnections = override.KeepAliveConnections
|
||||
}
|
||||
result.NoHappyEyeballs = override.NoHappyEyeballs
|
||||
if override.HTTPHostHeader != "" {
|
||||
result.HTTPHostHeader = override.HTTPHostHeader
|
||||
}
|
||||
if override.OriginServerName != "" {
|
||||
result.OriginServerName = override.OriginServerName
|
||||
}
|
||||
result.MatchSNIToHost = override.MatchSNIToHost
|
||||
if override.CAPool != "" {
|
||||
result.CAPool = override.CAPool
|
||||
}
|
||||
result.NoTLSVerify = override.NoTLSVerify
|
||||
result.DisableChunkedEncoding = override.DisableChunkedEncoding
|
||||
result.BastionMode = override.BastionMode
|
||||
if override.ProxyAddress != "" {
|
||||
result.ProxyAddress = override.ProxyAddress
|
||||
}
|
||||
if override.ProxyPort != 0 {
|
||||
result.ProxyPort = override.ProxyPort
|
||||
}
|
||||
if override.ProxyType != "" {
|
||||
result.ProxyType = override.ProxyType
|
||||
}
|
||||
if len(override.IPRules) > 0 {
|
||||
result.IPRules = nil
|
||||
for _, rule := range override.IPRules {
|
||||
result.IPRules = append(result.IPRules, IPRule{
|
||||
Prefix: rule.Prefix,
|
||||
Ports: append([]int(nil), rule.Ports...),
|
||||
Allow: rule.Allow,
|
||||
})
|
||||
}
|
||||
}
|
||||
result.HTTP2Origin = override.HTTP2Origin
|
||||
if override.Access.Required || override.Access.TeamName != "" || len(override.Access.AudTag) > 0 || override.Access.Environment != "" {
|
||||
result.Access = AccessConfig{
|
||||
Required: override.Access.Required,
|
||||
TeamName: override.Access.TeamName,
|
||||
AudTag: append([]string(nil), override.Access.AudTag...),
|
||||
Environment: override.Access.Environment,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func originRequestFromRemote(input remoteOriginRequestJSON) OriginRequestConfig {
|
||||
config := defaultOriginRequestConfig()
|
||||
if input.ConnectTimeout != 0 {
|
||||
@@ -802,21 +683,6 @@ func mergeRemoteOriginRequest(base OriginRequestConfig, override remoteOriginReq
|
||||
return result
|
||||
}
|
||||
|
||||
func warpRoutingFromOption(input option.CloudflareTunnelWarpRoutingOptions) WarpRoutingConfig {
|
||||
config := WarpRoutingConfig{
|
||||
ConnectTimeout: defaultWarpRoutingConnectTime,
|
||||
TCPKeepAlive: defaultWarpRoutingTCPKeepAlive,
|
||||
MaxActiveFlows: input.MaxActiveFlows,
|
||||
}
|
||||
if input.ConnectTimeout != 0 {
|
||||
config.ConnectTimeout = time.Duration(input.ConnectTimeout)
|
||||
}
|
||||
if input.TCPKeepAlive != 0 {
|
||||
config.TCPKeepAlive = time.Duration(input.TCPKeepAlive)
|
||||
}
|
||||
return config
|
||||
}
|
||||
|
||||
func warpRoutingFromRemote(input remoteWarpRoutingJSON) WarpRoutingConfig {
|
||||
config := WarpRoutingConfig{
|
||||
ConnectTimeout: defaultWarpRoutingConnectTime,
|
||||
|
||||
@@ -59,7 +59,7 @@ func newSpecialServiceInboundWithRouter(t *testing.T, router adapter.Router) *In
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
configManager, err := NewConfigManager(option.CloudflareTunnelInboundOptions{})
|
||||
configManager, err := NewConfigManager()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -102,11 +102,9 @@ func startEchoListener(t *testing.T) net.Listener {
|
||||
return listener
|
||||
}
|
||||
|
||||
func newSocksProxyService(t *testing.T, rules []option.CloudflareTunnelIPRule) ResolvedService {
|
||||
func newSocksProxyService(t *testing.T, rules []IPRule) ResolvedService {
|
||||
t.Helper()
|
||||
service, err := parseResolvedService("socks-proxy", originRequestFromOption(option.CloudflareTunnelOriginRequestOptions{
|
||||
IPRules: rules,
|
||||
}))
|
||||
service, err := parseResolvedService("socks-proxy", OriginRequestConfig{IPRules: rules})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -247,7 +245,7 @@ func TestHandleSocksProxyStream(t *testing.T) {
|
||||
|
||||
_, portText, _ := net.SplitHostPort(listener.Addr().String())
|
||||
port, _ := strconv.Atoi(portText)
|
||||
service := newSocksProxyService(t, []option.CloudflareTunnelIPRule{{
|
||||
service := newSocksProxyService(t, []IPRule{{
|
||||
Prefix: "127.0.0.0/8",
|
||||
Ports: []int{port},
|
||||
Allow: true,
|
||||
@@ -286,7 +284,7 @@ func TestHandleSocksProxyStreamDenyRule(t *testing.T) {
|
||||
|
||||
_, portText, _ := net.SplitHostPort(listener.Addr().String())
|
||||
port, _ := strconv.Atoi(portText)
|
||||
service := newSocksProxyService(t, []option.CloudflareTunnelIPRule{{
|
||||
service := newSocksProxyService(t, []IPRule{{
|
||||
Prefix: "127.0.0.0/8",
|
||||
Ports: []int{port},
|
||||
Allow: false,
|
||||
@@ -317,7 +315,7 @@ func TestHandleSocksProxyStreamPortMismatchDefaultDeny(t *testing.T) {
|
||||
|
||||
_, portText, _ := net.SplitHostPort(listener.Addr().String())
|
||||
port, _ := strconv.Atoi(portText)
|
||||
service := newSocksProxyService(t, []option.CloudflareTunnelIPRule{{
|
||||
service := newSocksProxyService(t, []IPRule{{
|
||||
Prefix: "127.0.0.0/8",
|
||||
Ports: []int{port + 1},
|
||||
Allow: true,
|
||||
@@ -372,11 +370,11 @@ func TestHandleSocksProxyStreamRuleOrderFirstMatchWins(t *testing.T) {
|
||||
|
||||
_, portText, _ := net.SplitHostPort(listener.Addr().String())
|
||||
port, _ := strconv.Atoi(portText)
|
||||
allowFirst := newSocksProxyService(t, []option.CloudflareTunnelIPRule{
|
||||
allowFirst := newSocksProxyService(t, []IPRule{
|
||||
{Prefix: "127.0.0.0/8", Ports: []int{port}, Allow: true},
|
||||
{Prefix: "127.0.0.1/32", Ports: []int{port}, Allow: false},
|
||||
})
|
||||
denyFirst := newSocksProxyService(t, []option.CloudflareTunnelIPRule{
|
||||
denyFirst := newSocksProxyService(t, []IPRule{
|
||||
{Prefix: "127.0.0.1/32", Ports: []int{port}, Allow: false},
|
||||
{Prefix: "127.0.0.0/8", Ports: []int{port}, Allow: true},
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user