Implement router-backed cloudflare tunnel ingress config

This commit is contained in:
世界
2026-03-24 11:17:39 +08:00
parent 87a2f4c336
commit 01a8405069
9 changed files with 1306 additions and 321 deletions

View File

@@ -3,11 +3,63 @@ package option
import "github.com/sagernet/sing/common/json/badoption"
type CloudflareTunnelInboundOptions struct {
Token string `json:"token,omitempty"`
CredentialPath string `json:"credential_path,omitempty"`
HAConnections int `json:"ha_connections,omitempty"`
Protocol string `json:"protocol,omitempty"`
EdgeIPVersion int `json:"edge_ip_version,omitempty"`
DatagramVersion string `json:"datagram_version,omitempty"`
GracePeriod badoption.Duration `json:"grace_period,omitempty"`
Token string `json:"token,omitempty"`
CredentialPath string `json:"credential_path,omitempty"`
HAConnections int `json:"ha_connections,omitempty"`
Protocol string `json:"protocol,omitempty"`
EdgeIPVersion int `json:"edge_ip_version,omitempty"`
DatagramVersion string `json:"datagram_version,omitempty"`
GracePeriod badoption.Duration `json:"grace_period,omitempty"`
Region string `json:"region,omitempty"`
Ingress []CloudflareTunnelIngressRule `json:"ingress,omitempty"`
OriginRequest CloudflareTunnelOriginRequestOptions `json:"origin_request,omitempty"`
WarpRouting CloudflareTunnelWarpRoutingOptions `json:"warp_routing,omitempty"`
}
type CloudflareTunnelIngressRule struct {
Hostname string `json:"hostname,omitempty"`
Path string `json:"path,omitempty"`
Service string `json:"service,omitempty"`
OriginRequest CloudflareTunnelOriginRequestOptions `json:"origin_request,omitempty"`
}
type CloudflareTunnelOriginRequestOptions struct {
ConnectTimeout badoption.Duration `json:"connect_timeout,omitempty"`
TLSTimeout badoption.Duration `json:"tls_timeout,omitempty"`
TCPKeepAlive badoption.Duration `json:"tcp_keep_alive,omitempty"`
NoHappyEyeballs bool `json:"no_happy_eyeballs,omitempty"`
KeepAliveTimeout badoption.Duration `json:"keep_alive_timeout,omitempty"`
KeepAliveConnections int `json:"keep_alive_connections,omitempty"`
HTTPHostHeader string `json:"http_host_header,omitempty"`
OriginServerName string `json:"origin_server_name,omitempty"`
MatchSNIToHost bool `json:"match_sni_to_host,omitempty"`
CAPool string `json:"ca_pool,omitempty"`
NoTLSVerify bool `json:"no_tls_verify,omitempty"`
DisableChunkedEncoding bool `json:"disable_chunked_encoding,omitempty"`
BastionMode bool `json:"bastion_mode,omitempty"`
ProxyAddress string `json:"proxy_address,omitempty"`
ProxyPort uint `json:"proxy_port,omitempty"`
ProxyType string `json:"proxy_type,omitempty"`
IPRules []CloudflareTunnelIPRule `json:"ip_rules,omitempty"`
HTTP2Origin bool `json:"http2_origin,omitempty"`
Access CloudflareTunnelAccessRule `json:"access,omitempty"`
}
type CloudflareTunnelAccessRule struct {
Required bool `json:"required,omitempty"`
TeamName string `json:"team_name,omitempty"`
AudTag []string `json:"aud_tag,omitempty"`
Environment string `json:"environment,omitempty"`
}
type CloudflareTunnelIPRule struct {
Prefix string `json:"prefix,omitempty"`
Ports []int `json:"ports,omitempty"`
Allow bool `json:"allow,omitempty"`
}
type CloudflareTunnelWarpRoutingOptions struct {
ConnectTimeout badoption.Duration `json:"connect_timeout,omitempty"`
MaxActiveFlows uint64 `json:"max_active_flows,omitempty"`
TCPKeepAlive badoption.Duration `json:"tcp_keep_alive,omitempty"`
}

View File

@@ -90,10 +90,10 @@ func NewHTTP2Connection(
server: &http2.Server{
MaxConcurrentStreams: math.MaxUint32,
},
logger: logger,
edgeAddr: edgeAddr,
connIndex: connIndex,
credentials: credentials,
logger: logger,
edgeAddr: edgeAddr,
connIndex: connIndex,
credentials: credentials,
connectorID: connectorID,
features: features,
numPreviousAttempts: numPreviousAttempts,
@@ -244,9 +244,13 @@ func (c *HTTP2Connection) handleConfigurationUpdate(r *http.Request, w http.Resp
w.WriteHeader(http.StatusBadRequest)
return
}
c.inbound.UpdateIngress(body.Version, body.Config)
result := c.inbound.ApplyConfig(body.Version, body.Config)
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"lastAppliedVersion":` + strconv.FormatInt(int64(body.Version), 10) + `,"err":null}`))
if result.Err != nil {
w.Write([]byte(`{"lastAppliedVersion":` + strconv.FormatInt(int64(result.LastAppliedVersion), 10) + `,"err":` + strconv.Quote(result.Err.Error()) + `}`))
return
}
w.Write([]byte(`{"lastAppliedVersion":` + strconv.FormatInt(int64(result.LastAppliedVersion), 10) + `,"err":null}`))
}
func (c *HTTP2Connection) close() {

View File

@@ -171,7 +171,6 @@ func DefaultFeatures(datagramVersion string) []string {
"support_datagram_v2",
"support_quic_eof",
"allow_remote_config",
"management_logs",
}
if datagramVersion == "v3" {
features = append(features, "support_datagram_v3_2")

View File

@@ -318,12 +318,17 @@ func (s *cloudflaredServer) UnregisterUdpSession(call tunnelrpc.SessionManager_u
func (s *cloudflaredServer) UpdateConfiguration(call tunnelrpc.ConfigurationManager_updateConfiguration) error {
version := call.Params.Version()
configData, _ := call.Params.Config()
s.inbound.UpdateIngress(version, configData)
updateResult := s.inbound.ApplyConfig(version, configData)
result, err := call.Results.NewResult()
if err != nil {
return err
}
result.SetErr("")
result.SetLatestAppliedVersion(updateResult.LastAppliedVersion)
if updateResult.Err != nil {
result.SetErr(updateResult.Err.Error())
} else {
result.SetErr("")
}
return nil
}

View File

@@ -4,12 +4,16 @@ package cloudflare
import (
"context"
"crypto/tls"
"crypto/x509"
"io"
"net"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/sagernet/sing-box/adapter"
@@ -124,19 +128,42 @@ func (i *Inbound) dispatchRequest(ctx context.Context, stream io.ReadWriteCloser
metadata.Destination = M.ParseSocksaddr(request.Dest)
i.handleTCPStream(ctx, stream, respWriter, metadata)
case ConnectionTypeHTTP, ConnectionTypeWebsocket:
originURL := i.ResolveOriginURL(request.Dest)
request.Dest = originURL
metadata.Destination = parseHTTPDestination(originURL)
if request.Type == ConnectionTypeHTTP {
i.handleHTTPStream(ctx, stream, respWriter, request, metadata)
} else {
i.handleWebSocketStream(ctx, stream, respWriter, request, metadata)
service, originURL, err := i.resolveHTTPService(request.Dest)
if err != nil {
i.logger.ErrorContext(ctx, "resolve origin service: ", err)
respWriter.WriteResponse(err, nil)
return
}
request.Dest = originURL
i.handleHTTPService(ctx, stream, respWriter, request, metadata, service)
default:
i.logger.ErrorContext(ctx, "unknown connection type: ", request.Type)
}
}
func (i *Inbound) resolveHTTPService(requestURL string) (ResolvedService, string, error) {
parsedURL, err := url.Parse(requestURL)
if err != nil {
return ResolvedService{}, "", E.Cause(err, "parse request URL")
}
service, loaded := i.configManager.Resolve(parsedURL.Hostname(), parsedURL.Path)
if !loaded {
return ResolvedService{}, "", E.New("no ingress rule matched request host/path")
}
if service.Kind == ResolvedServiceHelloWorld {
helloURL, err := i.ensureHelloWorldURL()
if err != nil {
return ResolvedService{}, "", err
}
service.BaseURL = helloURL
}
originURL, err := service.BuildRequestURL(requestURL)
if err != nil {
return ResolvedService{}, "", E.Cause(err, "build origin request URL")
}
return service, originURL, nil
}
func parseHTTPDestination(dest string) M.Socksaddr {
parsed, err := url.Parse(dest)
if err != nil {
@@ -172,10 +199,81 @@ func (i *Inbound) handleTCPStream(ctx context.Context, stream io.ReadWriteCloser
<-done
}
func (i *Inbound) handleHTTPStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext) {
func (i *Inbound) handleHTTPService(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
switch service.Kind {
case ResolvedServiceStatus:
err := respWriter.WriteResponse(nil, encodeResponseHeaders(service.StatusCode, http.Header{}))
if err != nil {
i.logger.ErrorContext(ctx, "write status service response: ", err)
}
return
case ResolvedServiceHTTP:
metadata.Destination = service.Destination
if request.Type == ConnectionTypeHTTP {
i.handleHTTPStream(ctx, stream, respWriter, request, metadata, service)
} else {
i.handleWebSocketStream(ctx, stream, respWriter, request, metadata, service)
}
case ResolvedServiceUnix, ResolvedServiceUnixTLS, ResolvedServiceHelloWorld:
if request.Type == ConnectionTypeHTTP {
i.handleDirectHTTPStream(ctx, stream, respWriter, request, metadata, service)
} else {
i.handleDirectWebSocketStream(ctx, stream, respWriter, request, metadata, service)
}
default:
err := E.New("unsupported service kind for HTTP/WebSocket request")
i.logger.ErrorContext(ctx, err)
respWriter.WriteResponse(err, nil)
}
}
func (i *Inbound) handleHTTPStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
metadata.Network = N.NetworkTCP
i.logger.InfoContext(ctx, "inbound HTTP connection to ", metadata.Destination)
transport, cleanup := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest)
defer cleanup()
i.roundTripHTTP(ctx, stream, respWriter, request, service, transport)
}
func (i *Inbound) handleWebSocketStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
metadata.Network = N.NetworkTCP
i.logger.InfoContext(ctx, "inbound WebSocket connection to ", metadata.Destination)
transport, cleanup := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest)
defer cleanup()
i.roundTripHTTP(ctx, stream, respWriter, request, service, transport)
}
func (i *Inbound) handleDirectHTTPStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
metadata.Network = N.NetworkTCP
i.logger.InfoContext(ctx, "inbound HTTP connection to ", request.Dest)
transport, cleanup, err := i.newDirectOriginTransport(service)
if err != nil {
i.logger.ErrorContext(ctx, "build direct origin transport: ", err)
respWriter.WriteResponse(err, nil)
return
}
defer cleanup()
i.roundTripHTTP(ctx, stream, respWriter, request, service, transport)
}
func (i *Inbound) handleDirectWebSocketStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
metadata.Network = N.NetworkTCP
i.logger.InfoContext(ctx, "inbound WebSocket connection to ", request.Dest)
transport, cleanup, err := i.newDirectOriginTransport(service)
if err != nil {
i.logger.ErrorContext(ctx, "build direct origin transport: ", err)
respWriter.WriteResponse(err, nil)
return
}
defer cleanup()
i.roundTripHTTP(ctx, stream, respWriter, request, service, transport)
}
func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, service ResolvedService, transport *http.Transport) {
httpRequest, err := buildHTTPRequestFromMetadata(ctx, request, stream)
if err != nil {
i.logger.ErrorContext(ctx, "build HTTP request: ", err)
@@ -183,23 +281,17 @@ func (i *Inbound) handleHTTPStream(ctx context.Context, stream io.ReadWriteClose
return
}
input, output := pipe.Pipe()
var innerError error
done := make(chan struct{})
go i.router.RouteConnectionEx(ctx, output, metadata, N.OnceClose(func(it error) {
innerError = it
common.Close(input, output)
close(done)
}))
httpRequest = applyOriginRequest(httpRequest, service.OriginRequest)
requestCtx := httpRequest.Context()
if service.OriginRequest.ConnectTimeout > 0 {
var cancel context.CancelFunc
requestCtx, cancel = context.WithTimeout(requestCtx, service.OriginRequest.ConnectTimeout)
defer cancel()
httpRequest = httpRequest.WithContext(requestCtx)
}
httpClient := &http.Client{
Transport: &http.Transport{
DisableCompression: true,
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return input, nil
},
},
Transport: transport,
CheckRedirect: func(request *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
@@ -208,87 +300,146 @@ func (i *Inbound) handleHTTPStream(ctx context.Context, stream io.ReadWriteClose
response, err := httpClient.Do(httpRequest)
if err != nil {
<-done
i.logger.ErrorContext(ctx, "HTTP request: ", E.Errors(innerError, err))
i.logger.ErrorContext(ctx, "origin request: ", err)
respWriter.WriteResponse(err, nil)
return
}
defer response.Body.Close()
responseMetadata := encodeResponseHeaders(response.StatusCode, response.Header)
err = respWriter.WriteResponse(nil, responseMetadata)
if err != nil {
response.Body.Close()
i.logger.ErrorContext(ctx, "write HTTP response headers: ", err)
<-done
i.logger.ErrorContext(ctx, "write origin response headers: ", err)
return
}
if request.Type == ConnectionTypeWebsocket && response.StatusCode == http.StatusSwitchingProtocols {
rwc, ok := response.Body.(io.ReadWriteCloser)
if !ok {
i.logger.ErrorContext(ctx, "websocket origin response body is not duplex")
return
}
bidirectionalCopy(stream, rwc)
return
}
_, err = io.Copy(stream, response.Body)
response.Body.Close()
common.Close(input, output)
if err != nil && !E.IsClosedOrCanceled(err) {
i.logger.DebugContext(ctx, "copy HTTP response body: ", err)
}
<-done
}
func (i *Inbound) handleWebSocketStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext) {
metadata.Network = N.NetworkTCP
i.logger.InfoContext(ctx, "inbound WebSocket connection to ", metadata.Destination)
httpRequest, err := buildHTTPRequestFromMetadata(ctx, request, stream)
if err != nil {
i.logger.ErrorContext(ctx, "build WebSocket request: ", err)
respWriter.WriteResponse(err, nil)
return
}
func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter.InboundContext, originRequest OriginRequestConfig) (*http.Transport, func()) {
input, output := pipe.Pipe()
var innerError error
done := make(chan struct{})
go i.router.RouteConnectionEx(ctx, output, metadata, N.OnceClose(func(it error) {
innerError = it
common.Close(input, output)
close(done)
}))
httpClient := &http.Client{
Transport: &http.Transport{
DisableCompression: true,
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return input, nil
},
},
CheckRedirect: func(request *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
transport := &http.Transport{
DisableCompression: true,
ForceAttemptHTTP2: originRequest.HTTP2Origin,
TLSHandshakeTimeout: originRequest.TLSTimeout,
IdleConnTimeout: originRequest.KeepAliveTimeout,
MaxIdleConns: originRequest.KeepAliveConnections,
MaxIdleConnsPerHost: originRequest.KeepAliveConnections,
TLSClientConfig: buildOriginTLSConfig(originRequest),
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return input, nil
},
}
defer httpClient.CloseIdleConnections()
return transport, func() {
common.Close(input, output)
select {
case <-done:
case <-time.After(time.Second):
}
}
}
response, err := httpClient.Do(httpRequest)
func (i *Inbound) newDirectOriginTransport(service ResolvedService) (*http.Transport, func(), error) {
transport := &http.Transport{
DisableCompression: true,
ForceAttemptHTTP2: service.OriginRequest.HTTP2Origin,
TLSHandshakeTimeout: service.OriginRequest.TLSTimeout,
IdleConnTimeout: service.OriginRequest.KeepAliveTimeout,
MaxIdleConns: service.OriginRequest.KeepAliveConnections,
MaxIdleConnsPerHost: service.OriginRequest.KeepAliveConnections,
TLSClientConfig: buildOriginTLSConfig(service.OriginRequest),
}
switch service.Kind {
case ResolvedServiceUnix, ResolvedServiceUnixTLS:
dialer := &net.Dialer{}
transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
return dialer.DialContext(ctx, "unix", service.UnixPath)
}
case ResolvedServiceHelloWorld:
dialer := &net.Dialer{}
target := service.BaseURL.Host
transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
return dialer.DialContext(ctx, "tcp", target)
}
default:
return nil, nil, E.New("unsupported direct origin service")
}
return transport, func() {}, nil
}
func buildOriginTLSConfig(originRequest OriginRequestConfig) *tls.Config {
tlsConfig := &tls.Config{
InsecureSkipVerify: originRequest.NoTLSVerify, //nolint:gosec
ServerName: originRequest.OriginServerName,
}
if originRequest.CAPool == "" {
return tlsConfig
}
pemData, err := os.ReadFile(originRequest.CAPool)
if err != nil {
<-done
i.logger.ErrorContext(ctx, "WebSocket request: ", E.Errors(innerError, err))
respWriter.WriteResponse(err, nil)
return
return tlsConfig
}
pool := x509.NewCertPool()
if pool.AppendCertsFromPEM(pemData) {
tlsConfig.RootCAs = pool
}
return tlsConfig
}
func applyOriginRequest(request *http.Request, originRequest OriginRequestConfig) *http.Request {
request = request.Clone(request.Context())
if originRequest.HTTPHostHeader != "" {
request.Header.Set("X-Forwarded-Host", request.Host)
request.Host = originRequest.HTTPHostHeader
}
if originRequest.DisableChunkedEncoding && request.Header.Get("Content-Length") != "" {
if contentLength, err := strconv.ParseInt(request.Header.Get("Content-Length"), 10, 64); err == nil {
request.ContentLength = contentLength
request.TransferEncoding = nil
}
}
return request
}
func bidirectionalCopy(left, right io.ReadWriteCloser) {
var closeOnce sync.Once
closeBoth := func() {
closeOnce.Do(func() {
common.Close(left, right)
})
}
responseMetadata := encodeResponseHeaders(response.StatusCode, response.Header)
err = respWriter.WriteResponse(nil, responseMetadata)
if err != nil {
response.Body.Close()
i.logger.ErrorContext(ctx, "write WebSocket response headers: ", err)
<-done
return
}
_, err = io.Copy(stream, response.Body)
response.Body.Close()
common.Close(input, output)
if err != nil && !E.IsClosedOrCanceled(err) {
i.logger.DebugContext(ctx, "copy WebSocket response body: ", err)
}
done := make(chan struct{}, 2)
go func() {
io.Copy(left, right)
closeBoth()
done <- struct{}{}
}()
go func() {
io.Copy(right, left)
closeBoth()
done <- struct{}{}
}()
<-done
<-done
}

View File

@@ -142,6 +142,11 @@ func newTestInbound(t *testing.T, token string, protocol string, haConnections i
t.Fatal("create logger: ", err)
}
configManager, err := NewConfigManager(option.CloudflareTunnelInboundOptions{})
if err != nil {
t.Fatal("create config manager: ", err)
}
ctx, cancel := context.WithCancel(context.Background())
inboundInstance := &Inbound{
Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"),
@@ -156,6 +161,7 @@ func newTestInbound(t *testing.T, token string, protocol string, haConnections i
edgeIPVersion: 0,
datagramVersion: "",
gracePeriod: 5 * time.Second,
configManager: configManager,
datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer),
datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer),
}

View File

@@ -7,9 +7,10 @@ import (
"encoding/base64"
"io"
"math/rand"
"net"
"net/http"
"net/url"
"os"
"strings"
"sync"
"time"
@@ -30,17 +31,19 @@ func RegisterInbound(registry *inbound.Registry) {
type Inbound struct {
inbound.Adapter
ctx context.Context
cancel context.CancelFunc
router adapter.ConnectionRouterEx
logger log.ContextLogger
credentials Credentials
connectorID uuid.UUID
haConnections int
protocol string
edgeIPVersion int
datagramVersion string
gracePeriod time.Duration
ctx context.Context
cancel context.CancelFunc
router adapter.ConnectionRouterEx
logger log.ContextLogger
credentials Credentials
connectorID uuid.UUID
haConnections int
protocol string
region string
edgeIPVersion int
datagramVersion string
gracePeriod time.Duration
configManager *ConfigManager
connectionAccess sync.Mutex
connections []io.Closer
@@ -50,101 +53,9 @@ type Inbound struct {
datagramV2Muxers map[DatagramSender]*DatagramV2Muxer
datagramV3Muxers map[DatagramSender]*DatagramV3Muxer
ingressAccess sync.RWMutex
ingressVersion int32
ingressRules []IngressRule
}
// IngressRule maps a hostname pattern to an origin service URL.
type IngressRule struct {
Hostname string
Service string
}
type ingressConfig struct {
Ingress []ingressConfigRule `json:"ingress"`
}
type ingressConfigRule struct {
Hostname string `json:"hostname,omitempty"`
Service string `json:"service"`
}
// UpdateIngress applies a new ingress configuration from the edge.
func (i *Inbound) UpdateIngress(version int32, config []byte) {
i.ingressAccess.Lock()
defer i.ingressAccess.Unlock()
if version <= i.ingressVersion {
return
}
var parsed ingressConfig
err := json.Unmarshal(config, &parsed)
if err != nil {
i.logger.Error("parse ingress config: ", err)
return
}
rules := make([]IngressRule, 0, len(parsed.Ingress))
for _, rule := range parsed.Ingress {
rules = append(rules, IngressRule{
Hostname: rule.Hostname,
Service: rule.Service,
})
}
i.ingressRules = rules
i.ingressVersion = version
i.logger.Info("updated ingress configuration (version ", version, ", ", len(rules), " rules)")
}
// ResolveOrigin finds the origin service URL for a given hostname.
// Returns the service URL if matched, or empty string if no match.
func (i *Inbound) ResolveOrigin(hostname string) string {
i.ingressAccess.RLock()
defer i.ingressAccess.RUnlock()
for _, rule := range i.ingressRules {
if rule.Hostname == "" {
return rule.Service
}
if matchIngress(rule.Hostname, hostname) {
return rule.Service
}
}
return ""
}
func matchIngress(pattern, hostname string) bool {
if pattern == hostname {
return true
}
if strings.HasPrefix(pattern, "*.") {
suffix := pattern[1:]
return strings.HasSuffix(hostname, suffix)
}
return false
}
// ResolveOriginURL rewrites a request URL to point to the origin service.
// For example, https://testbox.badnet.work/path → http://127.0.0.1:8083/path
func (i *Inbound) ResolveOriginURL(requestURL string) string {
parsed, err := url.Parse(requestURL)
if err != nil {
return requestURL
}
hostname := parsed.Hostname()
origin := i.ResolveOrigin(hostname)
if origin == "" || strings.HasPrefix(origin, "http_status:") {
return requestURL
}
originURL, err := url.Parse(origin)
if err != nil {
return requestURL
}
parsed.Scheme = originURL.Scheme
parsed.Host = originURL.Host
return parsed.String()
helloWorldAccess sync.Mutex
helloWorldServer *http.Server
helloWorldURL *url.URL
}
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.CloudflareTunnelInboundOptions) (adapter.Inbound, error) {
@@ -178,23 +89,30 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
gracePeriod = 30 * time.Second
}
configManager, err := NewConfigManager(options)
if err != nil {
return nil, E.Cause(err, "build cloudflare tunnel runtime config")
}
inboundCtx, cancel := context.WithCancel(ctx)
return &Inbound{
Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, tag),
ctx: inboundCtx,
cancel: cancel,
router: router,
logger: logger,
credentials: credentials,
connectorID: uuid.New(),
haConnections: haConnections,
protocol: protocol,
edgeIPVersion: edgeIPVersion,
datagramVersion: datagramVersion,
gracePeriod: gracePeriod,
datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer),
datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer),
Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, tag),
ctx: inboundCtx,
cancel: cancel,
router: router,
logger: logger,
credentials: credentials,
connectorID: uuid.New(),
haConnections: haConnections,
protocol: protocol,
region: options.Region,
edgeIPVersion: edgeIPVersion,
datagramVersion: datagramVersion,
gracePeriod: gracePeriod,
configManager: configManager,
datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer),
datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer),
}, nil
}
@@ -238,6 +156,16 @@ func (i *Inbound) Start(stage adapter.StartStage) error {
return nil
}
func (i *Inbound) ApplyConfig(version int32, config []byte) ConfigUpdateResult {
result := i.configManager.Apply(version, config)
if result.Err != nil {
i.logger.Error("update ingress configuration: ", result.Err)
return result
}
i.logger.Info("updated ingress configuration (version ", result.LastAppliedVersion, ")")
return result
}
func (i *Inbound) Close() error {
i.cancel()
i.done.Wait()
@@ -247,9 +175,41 @@ func (i *Inbound) Close() error {
}
i.connections = nil
i.connectionAccess.Unlock()
if i.helloWorldServer != nil {
i.helloWorldServer.Close()
}
return nil
}
func (i *Inbound) ensureHelloWorldURL() (*url.URL, error) {
i.helloWorldAccess.Lock()
defer i.helloWorldAccess.Unlock()
if i.helloWorldURL != nil {
return i.helloWorldURL, nil
}
mux := http.NewServeMux()
mux.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) {
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, _ = writer.Write([]byte("Hello World"))
})
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return nil, E.Cause(err, "listen hello world server")
}
server := &http.Server{Handler: mux}
go server.Serve(listener)
i.helloWorldServer = server
i.helloWorldURL = &url.URL{
Scheme: "http",
Host: listener.Addr().String(),
}
return i.helloWorldURL, nil
}
const (
backoffBaseTime = time.Second
backoffMaxTime = 2 * time.Minute

View File

@@ -6,143 +6,148 @@ import (
"testing"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
)
func newTestIngressInbound() *Inbound {
return &Inbound{logger: log.NewNOPFactory().NewLogger("test")}
func newTestIngressInbound(t *testing.T) *Inbound {
t.Helper()
configManager, err := NewConfigManager(option.CloudflareTunnelInboundOptions{})
if err != nil {
t.Fatal(err)
}
return &Inbound{
logger: log.NewNOPFactory().NewLogger("test"),
configManager: configManager,
}
}
func TestUpdateIngress(t *testing.T) {
inboundInstance := newTestIngressInbound()
func mustResolvedService(t *testing.T, rawService string) ResolvedService {
t.Helper()
service, err := parseResolvedService(rawService, defaultOriginRequestConfig())
if err != nil {
t.Fatal(err)
}
return service
}
func TestApplyConfig(t *testing.T) {
inboundInstance := newTestIngressInbound(t)
config1 := []byte(`{"ingress":[{"hostname":"a.com","service":"http://localhost:80"},{"hostname":"b.com","service":"http://localhost:81"},{"service":"http_status:404"}]}`)
inboundInstance.UpdateIngress(1, config1)
inboundInstance.ingressAccess.RLock()
count := len(inboundInstance.ingressRules)
inboundInstance.ingressAccess.RUnlock()
if count != 3 {
t.Fatalf("expected 3 rules, got %d", count)
result := inboundInstance.ApplyConfig(1, config1)
if result.Err != nil {
t.Fatal(result.Err)
}
if result.LastAppliedVersion != 1 {
t.Fatalf("expected version 1, got %d", result.LastAppliedVersion)
}
inboundInstance.UpdateIngress(1, []byte(`{"ingress":[{"service":"http_status:503"}]}`))
inboundInstance.ingressAccess.RLock()
count = len(inboundInstance.ingressRules)
inboundInstance.ingressAccess.RUnlock()
if count != 3 {
t.Error("version 1 re-apply should not change rules, got ", count)
service, loaded := inboundInstance.configManager.Resolve("a.com", "/")
if !loaded || service.Service != "http://localhost:80" {
t.Fatalf("expected a.com to resolve to localhost:80, got %#v, loaded=%v", service, loaded)
}
inboundInstance.UpdateIngress(2, []byte(`{"ingress":[{"service":"http_status:503"}]}`))
inboundInstance.ingressAccess.RLock()
count = len(inboundInstance.ingressRules)
inboundInstance.ingressAccess.RUnlock()
if count != 1 {
t.Error("version 2 should update to 1 rule, got ", count)
result = inboundInstance.ApplyConfig(1, []byte(`{"ingress":[{"service":"http_status:503"}]}`))
if result.Err != nil {
t.Fatal(result.Err)
}
if result.LastAppliedVersion != 1 {
t.Fatalf("same version should keep current version, got %d", result.LastAppliedVersion)
}
service, loaded = inboundInstance.configManager.Resolve("b.com", "/")
if !loaded || service.Service != "http://localhost:81" {
t.Fatalf("expected old rules to remain, got %#v, loaded=%v", service, loaded)
}
result = inboundInstance.ApplyConfig(2, []byte(`{"ingress":[{"service":"http_status:503"}]}`))
if result.Err != nil {
t.Fatal(result.Err)
}
if result.LastAppliedVersion != 2 {
t.Fatalf("expected version 2, got %d", result.LastAppliedVersion)
}
service, loaded = inboundInstance.configManager.Resolve("anything.com", "/")
if !loaded || service.StatusCode != 503 {
t.Fatalf("expected catch-all status 503, got %#v, loaded=%v", service, loaded)
}
}
func TestUpdateIngressInvalidJSON(t *testing.T) {
inboundInstance := newTestIngressInbound()
inboundInstance.UpdateIngress(1, []byte("not json"))
inboundInstance.ingressAccess.RLock()
count := len(inboundInstance.ingressRules)
inboundInstance.ingressAccess.RUnlock()
if count != 0 {
t.Error("invalid JSON should leave rules empty, got ", count)
func TestApplyConfigInvalidJSON(t *testing.T) {
inboundInstance := newTestIngressInbound(t)
result := inboundInstance.ApplyConfig(1, []byte("not json"))
if result.Err == nil {
t.Fatal("expected parse error")
}
if result.LastAppliedVersion != -1 {
t.Fatalf("expected version to stay -1, got %d", result.LastAppliedVersion)
}
}
func TestResolveOriginExact(t *testing.T) {
inboundInstance := newTestIngressInbound()
inboundInstance.ingressRules = []IngressRule{
{Hostname: "test.example.com", Service: "http://localhost:8080"},
{Hostname: "", Service: "http_status:404"},
func TestResolveExactAndWildcard(t *testing.T) {
inboundInstance := newTestIngressInbound(t)
inboundInstance.configManager.activeConfig = RuntimeConfig{
Ingress: []compiledIngressRule{
{Hostname: "test.example.com", Service: mustResolvedService(t, "http://localhost:8080")},
{Hostname: "*.example.com", Service: mustResolvedService(t, "http://localhost:9090")},
{Service: mustResolvedService(t, "http_status:404")},
},
}
result := inboundInstance.ResolveOrigin("test.example.com")
if result != "http://localhost:8080" {
t.Error("expected http://localhost:8080, got ", result)
service, loaded := inboundInstance.configManager.Resolve("test.example.com", "/")
if !loaded || service.Service != "http://localhost:8080" {
t.Fatalf("expected exact match, got %#v, loaded=%v", service, loaded)
}
service, loaded = inboundInstance.configManager.Resolve("sub.example.com", "/")
if !loaded || service.Service != "http://localhost:9090" {
t.Fatalf("expected wildcard match, got %#v, loaded=%v", service, loaded)
}
service, loaded = inboundInstance.configManager.Resolve("unknown.test", "/")
if !loaded || service.StatusCode != 404 {
t.Fatalf("expected catch-all 404, got %#v, loaded=%v", service, loaded)
}
}
func TestResolveOriginWildcard(t *testing.T) {
inboundInstance := newTestIngressInbound()
inboundInstance.ingressRules = []IngressRule{
{Hostname: "*.example.com", Service: "http://localhost:9090"},
func TestResolveHTTPService(t *testing.T) {
inboundInstance := newTestIngressInbound(t)
inboundInstance.configManager.activeConfig = RuntimeConfig{
Ingress: []compiledIngressRule{
{Hostname: "foo.com", Service: mustResolvedService(t, "http://127.0.0.1:8083")},
{Service: mustResolvedService(t, "http_status:404")},
},
}
result := inboundInstance.ResolveOrigin("sub.example.com")
if result != "http://localhost:9090" {
t.Error("wildcard should match sub.example.com, got ", result)
service, requestURL, err := inboundInstance.resolveHTTPService("https://foo.com/path?q=1")
if err != nil {
t.Fatal(err)
}
result = inboundInstance.ResolveOrigin("example.com")
if result != "" {
t.Error("wildcard should not match bare example.com, got ", result)
if service.Destination.String() != "127.0.0.1:8083" {
t.Fatalf("expected destination 127.0.0.1:8083, got %s", service.Destination)
}
if requestURL != "http://127.0.0.1:8083/path?q=1" {
t.Fatalf("expected rewritten URL, got %s", requestURL)
}
}
func TestResolveOriginCatchAll(t *testing.T) {
inboundInstance := newTestIngressInbound()
inboundInstance.ingressRules = []IngressRule{
{Hostname: "specific.com", Service: "http://localhost:1"},
{Hostname: "", Service: "http://localhost:2"},
func TestResolveHTTPServiceStatus(t *testing.T) {
inboundInstance := newTestIngressInbound(t)
inboundInstance.configManager.activeConfig = RuntimeConfig{
Ingress: []compiledIngressRule{
{Service: mustResolvedService(t, "http_status:404")},
},
}
result := inboundInstance.ResolveOrigin("anything.com")
if result != "http://localhost:2" {
t.Error("catch-all should match, got ", result)
}
}
func TestResolveOriginNoMatch(t *testing.T) {
inboundInstance := newTestIngressInbound()
inboundInstance.ingressRules = []IngressRule{
{Hostname: "specific.com", Service: "http://localhost:1"},
}
result := inboundInstance.ResolveOrigin("other.com")
if result != "" {
t.Error("expected empty for no match, got ", result)
}
}
func TestResolveOriginURLRewrite(t *testing.T) {
inboundInstance := newTestIngressInbound()
inboundInstance.ingressRules = []IngressRule{
{Hostname: "foo.com", Service: "http://127.0.0.1:8083"},
}
result := inboundInstance.ResolveOriginURL("https://foo.com/path?q=1")
if result != "http://127.0.0.1:8083/path?q=1" {
t.Error("expected http://127.0.0.1:8083/path?q=1, got ", result)
}
}
func TestResolveOriginURLNoMatch(t *testing.T) {
inboundInstance := newTestIngressInbound()
inboundInstance.ingressRules = []IngressRule{
{Hostname: "other.com", Service: "http://localhost:1"},
}
original := "https://unknown.com/page"
result := inboundInstance.ResolveOriginURL(original)
if result != original {
t.Error("no match should return original, got ", result)
}
}
func TestResolveOriginURLHTTPStatus(t *testing.T) {
inboundInstance := newTestIngressInbound()
inboundInstance.ingressRules = []IngressRule{
{Hostname: "", Service: "http_status:404"},
}
original := "https://any.com/page"
result := inboundInstance.ResolveOriginURL(original)
if result != original {
t.Error("http_status service should return original, got ", result)
service, requestURL, err := inboundInstance.resolveHTTPService("https://any.com/path")
if err != nil {
t.Fatal(err)
}
if service.StatusCode != 404 {
t.Fatalf("expected status 404, got %#v", service)
}
if requestURL != "https://any.com/path" {
t.Fatalf("status service should keep request URL, got %s", requestURL)
}
}

View File

@@ -0,0 +1,803 @@
//go:build with_cloudflare_tunnel
package cloudflare
import (
"encoding/json"
"net"
"net/url"
"regexp"
"strconv"
"strings"
"sync"
"time"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"golang.org/x/net/idna"
)
const (
defaultHTTPConnectTimeout = 30 * time.Second
defaultTLSTimeout = 10 * time.Second
defaultTCPKeepAlive = 30 * time.Second
defaultKeepAliveTimeout = 90 * time.Second
defaultKeepAliveConnections = 100
defaultProxyAddress = "127.0.0.1"
defaultWarpRoutingConnectTime = 5 * time.Second
defaultWarpRoutingTCPKeepAlive = 30 * time.Second
)
type ResolvedServiceKind int
const (
ResolvedServiceHTTP ResolvedServiceKind = iota
ResolvedServiceStream
ResolvedServiceStatus
ResolvedServiceHelloWorld
ResolvedServiceUnix
ResolvedServiceUnixTLS
)
type ResolvedService struct {
Kind ResolvedServiceKind
Service string
Destination M.Socksaddr
BaseURL *url.URL
UnixPath string
StatusCode int
OriginRequest OriginRequestConfig
}
func (s ResolvedService) RouterControlled() bool {
return s.Kind == ResolvedServiceHTTP || s.Kind == ResolvedServiceStream
}
func (s ResolvedService) BuildRequestURL(requestURL string) (string, error) {
switch s.Kind {
case ResolvedServiceHTTP, ResolvedServiceUnix, ResolvedServiceUnixTLS:
requestParsed, err := url.Parse(requestURL)
if err != nil {
return "", err
}
originURL := *s.BaseURL
originURL.Path = requestParsed.Path
originURL.RawPath = requestParsed.RawPath
originURL.RawQuery = requestParsed.RawQuery
originURL.Fragment = requestParsed.Fragment
return originURL.String(), nil
case ResolvedServiceHelloWorld:
if s.BaseURL == nil {
return "", E.New("hello world service is unavailable")
}
requestParsed, err := url.Parse(requestURL)
if err != nil {
return "", err
}
originURL := *s.BaseURL
originURL.Path = requestParsed.Path
originURL.RawPath = requestParsed.RawPath
originURL.RawQuery = requestParsed.RawQuery
originURL.Fragment = requestParsed.Fragment
return originURL.String(), nil
default:
return requestURL, nil
}
}
type compiledIngressRule struct {
Hostname string
PunycodeHostname string
Path *regexp.Regexp
Service ResolvedService
}
type RuntimeConfig struct {
Ingress []compiledIngressRule
OriginRequest OriginRequestConfig
WarpRouting WarpRoutingConfig
}
type OriginRequestConfig struct {
ConnectTimeout time.Duration
TLSTimeout time.Duration
TCPKeepAlive time.Duration
NoHappyEyeballs bool
KeepAliveTimeout time.Duration
KeepAliveConnections int
HTTPHostHeader string
OriginServerName string
MatchSNIToHost bool
CAPool string
NoTLSVerify bool
DisableChunkedEncoding bool
BastionMode bool
ProxyAddress string
ProxyPort uint
ProxyType string
IPRules []IPRule
HTTP2Origin bool
Access AccessConfig
}
type AccessConfig struct {
Required bool
TeamName string
AudTag []string
Environment string
}
type IPRule struct {
Prefix string
Ports []int
Allow bool
}
type WarpRoutingConfig struct {
ConnectTimeout time.Duration
MaxActiveFlows uint64
TCPKeepAlive time.Duration
}
type ConfigUpdateResult struct {
LastAppliedVersion int32
Err error
}
type ConfigManager struct {
access sync.RWMutex
currentVersion int32
activeConfig RuntimeConfig
}
func NewConfigManager(options option.CloudflareTunnelInboundOptions) (*ConfigManager, error) {
config, err := buildLocalRuntimeConfig(options)
if err != nil {
return nil, err
}
return &ConfigManager{
currentVersion: -1,
activeConfig: config,
}, nil
}
func (m *ConfigManager) Snapshot() RuntimeConfig {
m.access.RLock()
defer m.access.RUnlock()
return m.activeConfig
}
func (m *ConfigManager) CurrentVersion() int32 {
m.access.RLock()
defer m.access.RUnlock()
return m.currentVersion
}
func (m *ConfigManager) Apply(version int32, raw []byte) ConfigUpdateResult {
m.access.Lock()
defer m.access.Unlock()
if version <= m.currentVersion {
return ConfigUpdateResult{LastAppliedVersion: m.currentVersion}
}
config, err := buildRemoteRuntimeConfig(raw)
if err != nil {
return ConfigUpdateResult{
LastAppliedVersion: m.currentVersion,
Err: err,
}
}
m.activeConfig = config
m.currentVersion = version
return ConfigUpdateResult{LastAppliedVersion: m.currentVersion}
}
func (m *ConfigManager) Resolve(hostname, path string) (ResolvedService, bool) {
m.access.RLock()
defer m.access.RUnlock()
return m.activeConfig.Resolve(hostname, path)
}
func (c RuntimeConfig) Resolve(hostname, path string) (ResolvedService, bool) {
host := stripPort(hostname)
for _, rule := range c.Ingress {
if !matchIngressRule(rule, host, path) {
continue
}
return rule.Service, true
}
return ResolvedService{}, false
}
func matchIngressRule(rule compiledIngressRule, hostname, path string) bool {
hostMatch := rule.Hostname == "" || rule.Hostname == "*" || matchIngressHost(rule.Hostname, hostname)
if !hostMatch && rule.PunycodeHostname != "" {
hostMatch = matchIngressHost(rule.PunycodeHostname, hostname)
}
if !hostMatch {
return false
}
return rule.Path == nil || rule.Path.MatchString(path)
}
func matchIngressHost(pattern, hostname string) bool {
if pattern == hostname {
return true
}
if strings.HasPrefix(pattern, "*.") {
return strings.HasSuffix(hostname, strings.TrimPrefix(pattern, "*"))
}
return false
}
func buildLocalRuntimeConfig(options option.CloudflareTunnelInboundOptions) (RuntimeConfig, error) {
defaultOriginRequest := originRequestFromOption(options.OriginRequest)
warpRouting := warpRoutingFromOption(options.WarpRouting)
var ingressRules []localIngressRule
for _, rule := range options.Ingress {
ingressRules = append(ingressRules, localIngressRule{
Hostname: rule.Hostname,
Path: rule.Path,
Service: rule.Service,
OriginRequest: mergeOptionOriginRequest(defaultOriginRequest, rule.OriginRequest),
})
}
compiledRules, err := compileIngressRules(defaultOriginRequest, ingressRules)
if err != nil {
return RuntimeConfig{}, err
}
return RuntimeConfig{
Ingress: compiledRules,
OriginRequest: defaultOriginRequest,
WarpRouting: warpRouting,
}, nil
}
func buildRemoteRuntimeConfig(raw []byte) (RuntimeConfig, error) {
var remote remoteConfigJSON
if err := json.Unmarshal(raw, &remote); err != nil {
return RuntimeConfig{}, E.Cause(err, "decode remote config")
}
defaultOriginRequest := originRequestFromRemote(remote.OriginRequest)
warpRouting := warpRoutingFromRemote(remote.WarpRouting)
var ingressRules []localIngressRule
for _, rule := range remote.Ingress {
ingressRules = append(ingressRules, localIngressRule{
Hostname: rule.Hostname,
Path: rule.Path,
Service: rule.Service,
OriginRequest: mergeRemoteOriginRequest(defaultOriginRequest, rule.OriginRequest),
})
}
compiledRules, err := compileIngressRules(defaultOriginRequest, ingressRules)
if err != nil {
return RuntimeConfig{}, err
}
return RuntimeConfig{
Ingress: compiledRules,
OriginRequest: defaultOriginRequest,
WarpRouting: warpRouting,
}, nil
}
type localIngressRule struct {
Hostname string
Path string
Service string
OriginRequest OriginRequestConfig
}
type remoteConfigJSON struct {
OriginRequest remoteOriginRequestJSON `json:"originRequest"`
Ingress []remoteIngressRuleJSON `json:"ingress"`
WarpRouting remoteWarpRoutingJSON `json:"warp-routing"`
}
type remoteIngressRuleJSON struct {
Hostname string `json:"hostname,omitempty"`
Path string `json:"path,omitempty"`
Service string `json:"service"`
OriginRequest remoteOriginRequestJSON `json:"originRequest,omitempty"`
}
type remoteOriginRequestJSON struct {
ConnectTimeout int64 `json:"connectTimeout,omitempty"`
TLSTimeout int64 `json:"tlsTimeout,omitempty"`
TCPKeepAlive int64 `json:"tcpKeepAlive,omitempty"`
NoHappyEyeballs *bool `json:"noHappyEyeballs,omitempty"`
KeepAliveTimeout int64 `json:"keepAliveTimeout,omitempty"`
KeepAliveConnections *int `json:"keepAliveConnections,omitempty"`
HTTPHostHeader string `json:"httpHostHeader,omitempty"`
OriginServerName string `json:"originServerName,omitempty"`
MatchSNIToHost *bool `json:"matchSNIToHost,omitempty"`
CAPool string `json:"caPool,omitempty"`
NoTLSVerify *bool `json:"noTLSVerify,omitempty"`
DisableChunkedEncoding *bool `json:"disableChunkedEncoding,omitempty"`
BastionMode *bool `json:"bastionMode,omitempty"`
ProxyAddress string `json:"proxyAddress,omitempty"`
ProxyPort *uint `json:"proxyPort,omitempty"`
ProxyType string `json:"proxyType,omitempty"`
IPRules []remoteIPRuleJSON `json:"ipRules,omitempty"`
HTTP2Origin *bool `json:"http2Origin,omitempty"`
Access *remoteAccessJSON `json:"access,omitempty"`
}
type remoteAccessJSON struct {
Required bool `json:"required,omitempty"`
TeamName string `json:"teamName,omitempty"`
AudTag []string `json:"audTag,omitempty"`
Environment string `json:"environment,omitempty"`
}
type remoteIPRuleJSON struct {
Prefix string `json:"prefix,omitempty"`
Ports []int `json:"ports,omitempty"`
Allow bool `json:"allow,omitempty"`
}
type remoteWarpRoutingJSON struct {
ConnectTimeout int64 `json:"connectTimeout,omitempty"`
MaxActiveFlows uint64 `json:"maxActiveFlows,omitempty"`
TCPKeepAlive int64 `json:"tcpKeepAlive,omitempty"`
}
func compileIngressRules(defaultOriginRequest OriginRequestConfig, rawRules []localIngressRule) ([]compiledIngressRule, error) {
if len(rawRules) == 0 {
rawRules = []localIngressRule{{
Service: "http_status:503",
OriginRequest: defaultOriginRequest,
}}
}
if !isCatchAllRule(rawRules[len(rawRules)-1].Hostname, rawRules[len(rawRules)-1].Path) {
return nil, E.New("the last ingress rule must be a catch-all rule")
}
compiled := make([]compiledIngressRule, 0, len(rawRules))
for index, rule := range rawRules {
if err := validateHostname(rule.Hostname, index == len(rawRules)-1); err != nil {
return nil, err
}
service, err := parseResolvedService(rule.Service, rule.OriginRequest)
if err != nil {
return nil, err
}
var pathPattern *regexp.Regexp
if rule.Path != "" {
pathPattern, err = regexp.Compile(rule.Path)
if err != nil {
return nil, E.Cause(err, "compile ingress path regex")
}
}
punycode := ""
if rule.Hostname != "" && rule.Hostname != "*" {
punycodeValue, err := idna.Lookup.ToASCII(rule.Hostname)
if err == nil && punycodeValue != rule.Hostname {
punycode = punycodeValue
}
}
compiled = append(compiled, compiledIngressRule{
Hostname: rule.Hostname,
PunycodeHostname: punycode,
Path: pathPattern,
Service: service,
})
}
return compiled, nil
}
func parseResolvedService(rawService string, originRequest OriginRequestConfig) (ResolvedService, error) {
switch {
case rawService == "":
return ResolvedService{}, E.New("missing ingress service")
case strings.HasPrefix(rawService, "http_status:"):
statusCode, err := strconv.Atoi(strings.TrimPrefix(rawService, "http_status:"))
if err != nil {
return ResolvedService{}, E.Cause(err, "parse http_status service")
}
if statusCode < 100 || statusCode > 999 {
return ResolvedService{}, E.New("invalid http_status code: ", statusCode)
}
return ResolvedService{
Kind: ResolvedServiceStatus,
Service: rawService,
StatusCode: statusCode,
OriginRequest: originRequest,
}, nil
case rawService == "hello_world" || rawService == "hello-world":
return ResolvedService{
Kind: ResolvedServiceHelloWorld,
Service: rawService,
OriginRequest: originRequest,
}, nil
case strings.HasPrefix(rawService, "unix:"):
return ResolvedService{
Kind: ResolvedServiceUnix,
Service: rawService,
UnixPath: strings.TrimPrefix(rawService, "unix:"),
BaseURL: &url.URL{Scheme: "http", Host: "localhost"},
OriginRequest: originRequest,
}, nil
case strings.HasPrefix(rawService, "unix+tls:"):
return ResolvedService{
Kind: ResolvedServiceUnixTLS,
Service: rawService,
UnixPath: strings.TrimPrefix(rawService, "unix+tls:"),
BaseURL: &url.URL{Scheme: "https", Host: "localhost"},
OriginRequest: originRequest,
}, nil
}
parsedURL, err := url.Parse(rawService)
if err != nil {
return ResolvedService{}, E.Cause(err, "parse ingress service URL")
}
if parsedURL.Scheme == "" || parsedURL.Hostname() == "" {
return ResolvedService{}, E.New("ingress service must include scheme and hostname: ", rawService)
}
if parsedURL.Path != "" {
return ResolvedService{}, E.New("ingress service cannot include a path: ", rawService)
}
switch parsedURL.Scheme {
case "http", "https", "ws", "wss":
return ResolvedService{
Kind: ResolvedServiceHTTP,
Service: rawService,
Destination: parseServiceDestination(parsedURL),
BaseURL: parsedURL,
OriginRequest: originRequest,
}, nil
case "tcp", "ssh", "rdp", "smb":
return ResolvedService{
Kind: ResolvedServiceStream,
Service: rawService,
Destination: parseServiceDestination(parsedURL),
BaseURL: parsedURL,
OriginRequest: originRequest,
}, nil
default:
return ResolvedService{}, E.New("unsupported ingress service scheme: ", parsedURL.Scheme)
}
}
func parseServiceDestination(parsedURL *url.URL) M.Socksaddr {
host := parsedURL.Hostname()
port := parsedURL.Port()
if port == "" {
switch parsedURL.Scheme {
case "https", "wss":
port = "443"
case "ssh":
port = "22"
case "rdp":
port = "3389"
case "smb":
port = "445"
case "tcp":
port = "7864"
default:
port = "80"
}
}
return M.ParseSocksaddr(net.JoinHostPort(host, port))
}
func validateHostname(hostname string, isLast bool) error {
if hostname == "" || hostname == "*" {
if !isLast {
return E.New("only the last ingress rule may be a catch-all rule")
}
return nil
}
if strings.Count(hostname, "*") > 1 || (strings.Contains(hostname, "*") && !strings.HasPrefix(hostname, "*.")) {
return E.New("hostname wildcard must be in the form *.example.com")
}
if stripPort(hostname) != hostname {
return E.New("ingress hostname cannot contain a port")
}
return nil
}
func isCatchAllRule(hostname, path string) bool {
return (hostname == "" || hostname == "*") && path == ""
}
func stripPort(hostname string) string {
if host, _, err := net.SplitHostPort(hostname); err == nil {
return host
}
return hostname
}
func defaultOriginRequestConfig() OriginRequestConfig {
return OriginRequestConfig{
ConnectTimeout: defaultHTTPConnectTimeout,
TLSTimeout: defaultTLSTimeout,
TCPKeepAlive: defaultTCPKeepAlive,
KeepAliveTimeout: defaultKeepAliveTimeout,
KeepAliveConnections: defaultKeepAliveConnections,
ProxyAddress: defaultProxyAddress,
}
}
func originRequestFromOption(input option.CloudflareTunnelOriginRequestOptions) OriginRequestConfig {
config := defaultOriginRequestConfig()
if input.ConnectTimeout != 0 {
config.ConnectTimeout = time.Duration(input.ConnectTimeout)
}
if input.TLSTimeout != 0 {
config.TLSTimeout = time.Duration(input.TLSTimeout)
}
if input.TCPKeepAlive != 0 {
config.TCPKeepAlive = time.Duration(input.TCPKeepAlive)
}
if input.KeepAliveTimeout != 0 {
config.KeepAliveTimeout = time.Duration(input.KeepAliveTimeout)
}
if input.KeepAliveConnections != 0 {
config.KeepAliveConnections = input.KeepAliveConnections
}
config.NoHappyEyeballs = input.NoHappyEyeballs
config.HTTPHostHeader = input.HTTPHostHeader
config.OriginServerName = input.OriginServerName
config.MatchSNIToHost = input.MatchSNIToHost
config.CAPool = input.CAPool
config.NoTLSVerify = input.NoTLSVerify
config.DisableChunkedEncoding = input.DisableChunkedEncoding
config.BastionMode = input.BastionMode
if input.ProxyAddress != "" {
config.ProxyAddress = input.ProxyAddress
}
if input.ProxyPort != 0 {
config.ProxyPort = input.ProxyPort
}
config.ProxyType = input.ProxyType
config.HTTP2Origin = input.HTTP2Origin
config.Access = AccessConfig{
Required: input.Access.Required,
TeamName: input.Access.TeamName,
AudTag: append([]string(nil), input.Access.AudTag...),
Environment: input.Access.Environment,
}
for _, rule := range input.IPRules {
config.IPRules = append(config.IPRules, IPRule{
Prefix: rule.Prefix,
Ports: append([]int(nil), rule.Ports...),
Allow: rule.Allow,
})
}
return config
}
func mergeOptionOriginRequest(base OriginRequestConfig, override option.CloudflareTunnelOriginRequestOptions) OriginRequestConfig {
result := base
if override.ConnectTimeout != 0 {
result.ConnectTimeout = time.Duration(override.ConnectTimeout)
}
if override.TLSTimeout != 0 {
result.TLSTimeout = time.Duration(override.TLSTimeout)
}
if override.TCPKeepAlive != 0 {
result.TCPKeepAlive = time.Duration(override.TCPKeepAlive)
}
if override.KeepAliveTimeout != 0 {
result.KeepAliveTimeout = time.Duration(override.KeepAliveTimeout)
}
if override.KeepAliveConnections != 0 {
result.KeepAliveConnections = override.KeepAliveConnections
}
result.NoHappyEyeballs = override.NoHappyEyeballs
if override.HTTPHostHeader != "" {
result.HTTPHostHeader = override.HTTPHostHeader
}
if override.OriginServerName != "" {
result.OriginServerName = override.OriginServerName
}
result.MatchSNIToHost = override.MatchSNIToHost
if override.CAPool != "" {
result.CAPool = override.CAPool
}
result.NoTLSVerify = override.NoTLSVerify
result.DisableChunkedEncoding = override.DisableChunkedEncoding
result.BastionMode = override.BastionMode
if override.ProxyAddress != "" {
result.ProxyAddress = override.ProxyAddress
}
if override.ProxyPort != 0 {
result.ProxyPort = override.ProxyPort
}
if override.ProxyType != "" {
result.ProxyType = override.ProxyType
}
if len(override.IPRules) > 0 {
result.IPRules = nil
for _, rule := range override.IPRules {
result.IPRules = append(result.IPRules, IPRule{
Prefix: rule.Prefix,
Ports: append([]int(nil), rule.Ports...),
Allow: rule.Allow,
})
}
}
result.HTTP2Origin = override.HTTP2Origin
if override.Access.Required || override.Access.TeamName != "" || len(override.Access.AudTag) > 0 || override.Access.Environment != "" {
result.Access = AccessConfig{
Required: override.Access.Required,
TeamName: override.Access.TeamName,
AudTag: append([]string(nil), override.Access.AudTag...),
Environment: override.Access.Environment,
}
}
return result
}
func originRequestFromRemote(input remoteOriginRequestJSON) OriginRequestConfig {
config := defaultOriginRequestConfig()
if input.ConnectTimeout != 0 {
config.ConnectTimeout = time.Duration(input.ConnectTimeout) * time.Second
}
if input.TLSTimeout != 0 {
config.TLSTimeout = time.Duration(input.TLSTimeout) * time.Second
}
if input.TCPKeepAlive != 0 {
config.TCPKeepAlive = time.Duration(input.TCPKeepAlive) * time.Second
}
if input.KeepAliveTimeout != 0 {
config.KeepAliveTimeout = time.Duration(input.KeepAliveTimeout) * time.Second
}
if input.KeepAliveConnections != nil {
config.KeepAliveConnections = *input.KeepAliveConnections
}
if input.NoHappyEyeballs != nil {
config.NoHappyEyeballs = *input.NoHappyEyeballs
}
config.HTTPHostHeader = input.HTTPHostHeader
config.OriginServerName = input.OriginServerName
if input.MatchSNIToHost != nil {
config.MatchSNIToHost = *input.MatchSNIToHost
}
config.CAPool = input.CAPool
if input.NoTLSVerify != nil {
config.NoTLSVerify = *input.NoTLSVerify
}
if input.DisableChunkedEncoding != nil {
config.DisableChunkedEncoding = *input.DisableChunkedEncoding
}
if input.BastionMode != nil {
config.BastionMode = *input.BastionMode
}
if input.ProxyAddress != "" {
config.ProxyAddress = input.ProxyAddress
}
if input.ProxyPort != nil {
config.ProxyPort = *input.ProxyPort
}
config.ProxyType = input.ProxyType
if input.HTTP2Origin != nil {
config.HTTP2Origin = *input.HTTP2Origin
}
if input.Access != nil {
config.Access = AccessConfig{
Required: input.Access.Required,
TeamName: input.Access.TeamName,
AudTag: append([]string(nil), input.Access.AudTag...),
Environment: input.Access.Environment,
}
}
for _, rule := range input.IPRules {
config.IPRules = append(config.IPRules, IPRule{
Prefix: rule.Prefix,
Ports: append([]int(nil), rule.Ports...),
Allow: rule.Allow,
})
}
return config
}
func mergeRemoteOriginRequest(base OriginRequestConfig, override remoteOriginRequestJSON) OriginRequestConfig {
result := base
if override.ConnectTimeout != 0 {
result.ConnectTimeout = time.Duration(override.ConnectTimeout) * time.Second
}
if override.TLSTimeout != 0 {
result.TLSTimeout = time.Duration(override.TLSTimeout) * time.Second
}
if override.TCPKeepAlive != 0 {
result.TCPKeepAlive = time.Duration(override.TCPKeepAlive) * time.Second
}
if override.NoHappyEyeballs != nil {
result.NoHappyEyeballs = *override.NoHappyEyeballs
}
if override.KeepAliveTimeout != 0 {
result.KeepAliveTimeout = time.Duration(override.KeepAliveTimeout) * time.Second
}
if override.KeepAliveConnections != nil {
result.KeepAliveConnections = *override.KeepAliveConnections
}
if override.HTTPHostHeader != "" {
result.HTTPHostHeader = override.HTTPHostHeader
}
if override.OriginServerName != "" {
result.OriginServerName = override.OriginServerName
}
if override.MatchSNIToHost != nil {
result.MatchSNIToHost = *override.MatchSNIToHost
}
if override.CAPool != "" {
result.CAPool = override.CAPool
}
if override.NoTLSVerify != nil {
result.NoTLSVerify = *override.NoTLSVerify
}
if override.DisableChunkedEncoding != nil {
result.DisableChunkedEncoding = *override.DisableChunkedEncoding
}
if override.BastionMode != nil {
result.BastionMode = *override.BastionMode
}
if override.ProxyAddress != "" {
result.ProxyAddress = override.ProxyAddress
}
if override.ProxyPort != nil {
result.ProxyPort = *override.ProxyPort
}
if override.ProxyType != "" {
result.ProxyType = override.ProxyType
}
if len(override.IPRules) > 0 {
result.IPRules = nil
for _, rule := range override.IPRules {
result.IPRules = append(result.IPRules, IPRule{
Prefix: rule.Prefix,
Ports: append([]int(nil), rule.Ports...),
Allow: rule.Allow,
})
}
}
if override.HTTP2Origin != nil {
result.HTTP2Origin = *override.HTTP2Origin
}
if override.Access != nil {
result.Access = AccessConfig{
Required: override.Access.Required,
TeamName: override.Access.TeamName,
AudTag: append([]string(nil), override.Access.AudTag...),
Environment: override.Access.Environment,
}
}
return result
}
func warpRoutingFromOption(input option.CloudflareTunnelWarpRoutingOptions) WarpRoutingConfig {
config := WarpRoutingConfig{
ConnectTimeout: defaultWarpRoutingConnectTime,
TCPKeepAlive: defaultWarpRoutingTCPKeepAlive,
MaxActiveFlows: input.MaxActiveFlows,
}
if input.ConnectTimeout != 0 {
config.ConnectTimeout = time.Duration(input.ConnectTimeout)
}
if input.TCPKeepAlive != 0 {
config.TCPKeepAlive = time.Duration(input.TCPKeepAlive)
}
return config
}
func warpRoutingFromRemote(input remoteWarpRoutingJSON) WarpRoutingConfig {
config := WarpRoutingConfig{
ConnectTimeout: defaultWarpRoutingConnectTime,
TCPKeepAlive: defaultWarpRoutingTCPKeepAlive,
MaxActiveFlows: input.MaxActiveFlows,
}
if input.ConnectTimeout != 0 {
config.ConnectTimeout = time.Duration(input.ConnectTimeout) * time.Second
}
if input.TCPKeepAlive != 0 {
config.TCPKeepAlive = time.Duration(input.TCPKeepAlive) * time.Second
}
return config
}