diff --git a/option/cloudflare_tunnel.go b/option/cloudflare_tunnel.go index a1a2c4442..c0fdbfa87 100644 --- a/option/cloudflare_tunnel.go +++ b/option/cloudflare_tunnel.go @@ -3,11 +3,63 @@ package option import "github.com/sagernet/sing/common/json/badoption" type CloudflareTunnelInboundOptions struct { - Token string `json:"token,omitempty"` - CredentialPath string `json:"credential_path,omitempty"` - HAConnections int `json:"ha_connections,omitempty"` - Protocol string `json:"protocol,omitempty"` - EdgeIPVersion int `json:"edge_ip_version,omitempty"` - DatagramVersion string `json:"datagram_version,omitempty"` - GracePeriod badoption.Duration `json:"grace_period,omitempty"` + Token string `json:"token,omitempty"` + CredentialPath string `json:"credential_path,omitempty"` + HAConnections int `json:"ha_connections,omitempty"` + Protocol string `json:"protocol,omitempty"` + EdgeIPVersion int `json:"edge_ip_version,omitempty"` + DatagramVersion string `json:"datagram_version,omitempty"` + GracePeriod badoption.Duration `json:"grace_period,omitempty"` + Region string `json:"region,omitempty"` + Ingress []CloudflareTunnelIngressRule `json:"ingress,omitempty"` + OriginRequest CloudflareTunnelOriginRequestOptions `json:"origin_request,omitempty"` + WarpRouting CloudflareTunnelWarpRoutingOptions `json:"warp_routing,omitempty"` +} + +type CloudflareTunnelIngressRule struct { + Hostname string `json:"hostname,omitempty"` + Path string `json:"path,omitempty"` + Service string `json:"service,omitempty"` + OriginRequest CloudflareTunnelOriginRequestOptions `json:"origin_request,omitempty"` +} + +type CloudflareTunnelOriginRequestOptions struct { + ConnectTimeout badoption.Duration `json:"connect_timeout,omitempty"` + TLSTimeout badoption.Duration `json:"tls_timeout,omitempty"` + TCPKeepAlive badoption.Duration `json:"tcp_keep_alive,omitempty"` + NoHappyEyeballs bool `json:"no_happy_eyeballs,omitempty"` + KeepAliveTimeout badoption.Duration `json:"keep_alive_timeout,omitempty"` + KeepAliveConnections int `json:"keep_alive_connections,omitempty"` + HTTPHostHeader string `json:"http_host_header,omitempty"` + OriginServerName string `json:"origin_server_name,omitempty"` + MatchSNIToHost bool `json:"match_sni_to_host,omitempty"` + CAPool string `json:"ca_pool,omitempty"` + NoTLSVerify bool `json:"no_tls_verify,omitempty"` + DisableChunkedEncoding bool `json:"disable_chunked_encoding,omitempty"` + BastionMode bool `json:"bastion_mode,omitempty"` + ProxyAddress string `json:"proxy_address,omitempty"` + ProxyPort uint `json:"proxy_port,omitempty"` + ProxyType string `json:"proxy_type,omitempty"` + IPRules []CloudflareTunnelIPRule `json:"ip_rules,omitempty"` + HTTP2Origin bool `json:"http2_origin,omitempty"` + Access CloudflareTunnelAccessRule `json:"access,omitempty"` +} + +type CloudflareTunnelAccessRule struct { + Required bool `json:"required,omitempty"` + TeamName string `json:"team_name,omitempty"` + AudTag []string `json:"aud_tag,omitempty"` + Environment string `json:"environment,omitempty"` +} + +type CloudflareTunnelIPRule struct { + Prefix string `json:"prefix,omitempty"` + Ports []int `json:"ports,omitempty"` + Allow bool `json:"allow,omitempty"` +} + +type CloudflareTunnelWarpRoutingOptions struct { + ConnectTimeout badoption.Duration `json:"connect_timeout,omitempty"` + MaxActiveFlows uint64 `json:"max_active_flows,omitempty"` + TCPKeepAlive badoption.Duration `json:"tcp_keep_alive,omitempty"` } diff --git a/protocol/cloudflare/connection_http2.go b/protocol/cloudflare/connection_http2.go index afe0699bb..42d9bffea 100644 --- a/protocol/cloudflare/connection_http2.go +++ b/protocol/cloudflare/connection_http2.go @@ -90,10 +90,10 @@ func NewHTTP2Connection( server: &http2.Server{ MaxConcurrentStreams: math.MaxUint32, }, - logger: logger, - edgeAddr: edgeAddr, - connIndex: connIndex, - credentials: credentials, + logger: logger, + edgeAddr: edgeAddr, + connIndex: connIndex, + credentials: credentials, connectorID: connectorID, features: features, numPreviousAttempts: numPreviousAttempts, @@ -244,9 +244,13 @@ func (c *HTTP2Connection) handleConfigurationUpdate(r *http.Request, w http.Resp w.WriteHeader(http.StatusBadRequest) return } - c.inbound.UpdateIngress(body.Version, body.Config) + result := c.inbound.ApplyConfig(body.Version, body.Config) w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"lastAppliedVersion":` + strconv.FormatInt(int64(body.Version), 10) + `,"err":null}`)) + if result.Err != nil { + w.Write([]byte(`{"lastAppliedVersion":` + strconv.FormatInt(int64(result.LastAppliedVersion), 10) + `,"err":` + strconv.Quote(result.Err.Error()) + `}`)) + return + } + w.Write([]byte(`{"lastAppliedVersion":` + strconv.FormatInt(int64(result.LastAppliedVersion), 10) + `,"err":null}`)) } func (c *HTTP2Connection) close() { diff --git a/protocol/cloudflare/control.go b/protocol/cloudflare/control.go index a4130f096..6e9881114 100644 --- a/protocol/cloudflare/control.go +++ b/protocol/cloudflare/control.go @@ -171,7 +171,6 @@ func DefaultFeatures(datagramVersion string) []string { "support_datagram_v2", "support_quic_eof", "allow_remote_config", - "management_logs", } if datagramVersion == "v3" { features = append(features, "support_datagram_v3_2") diff --git a/protocol/cloudflare/datagram_v2.go b/protocol/cloudflare/datagram_v2.go index 8159b04cc..2071d86e9 100644 --- a/protocol/cloudflare/datagram_v2.go +++ b/protocol/cloudflare/datagram_v2.go @@ -318,12 +318,17 @@ func (s *cloudflaredServer) UnregisterUdpSession(call tunnelrpc.SessionManager_u func (s *cloudflaredServer) UpdateConfiguration(call tunnelrpc.ConfigurationManager_updateConfiguration) error { version := call.Params.Version() configData, _ := call.Params.Config() - s.inbound.UpdateIngress(version, configData) + updateResult := s.inbound.ApplyConfig(version, configData) result, err := call.Results.NewResult() if err != nil { return err } - result.SetErr("") + result.SetLatestAppliedVersion(updateResult.LastAppliedVersion) + if updateResult.Err != nil { + result.SetErr(updateResult.Err.Error()) + } else { + result.SetErr("") + } return nil } diff --git a/protocol/cloudflare/dispatch.go b/protocol/cloudflare/dispatch.go index 7f949a7a4..9263e16c3 100644 --- a/protocol/cloudflare/dispatch.go +++ b/protocol/cloudflare/dispatch.go @@ -4,12 +4,16 @@ package cloudflare import ( "context" + "crypto/tls" + "crypto/x509" "io" "net" "net/http" "net/url" + "os" "strconv" "strings" + "sync" "time" "github.com/sagernet/sing-box/adapter" @@ -124,19 +128,42 @@ func (i *Inbound) dispatchRequest(ctx context.Context, stream io.ReadWriteCloser metadata.Destination = M.ParseSocksaddr(request.Dest) i.handleTCPStream(ctx, stream, respWriter, metadata) case ConnectionTypeHTTP, ConnectionTypeWebsocket: - originURL := i.ResolveOriginURL(request.Dest) - request.Dest = originURL - metadata.Destination = parseHTTPDestination(originURL) - if request.Type == ConnectionTypeHTTP { - i.handleHTTPStream(ctx, stream, respWriter, request, metadata) - } else { - i.handleWebSocketStream(ctx, stream, respWriter, request, metadata) + service, originURL, err := i.resolveHTTPService(request.Dest) + if err != nil { + i.logger.ErrorContext(ctx, "resolve origin service: ", err) + respWriter.WriteResponse(err, nil) + return } + request.Dest = originURL + i.handleHTTPService(ctx, stream, respWriter, request, metadata, service) default: i.logger.ErrorContext(ctx, "unknown connection type: ", request.Type) } } +func (i *Inbound) resolveHTTPService(requestURL string) (ResolvedService, string, error) { + parsedURL, err := url.Parse(requestURL) + if err != nil { + return ResolvedService{}, "", E.Cause(err, "parse request URL") + } + service, loaded := i.configManager.Resolve(parsedURL.Hostname(), parsedURL.Path) + if !loaded { + return ResolvedService{}, "", E.New("no ingress rule matched request host/path") + } + if service.Kind == ResolvedServiceHelloWorld { + helloURL, err := i.ensureHelloWorldURL() + if err != nil { + return ResolvedService{}, "", err + } + service.BaseURL = helloURL + } + originURL, err := service.BuildRequestURL(requestURL) + if err != nil { + return ResolvedService{}, "", E.Cause(err, "build origin request URL") + } + return service, originURL, nil +} + func parseHTTPDestination(dest string) M.Socksaddr { parsed, err := url.Parse(dest) if err != nil { @@ -172,10 +199,81 @@ func (i *Inbound) handleTCPStream(ctx context.Context, stream io.ReadWriteCloser <-done } -func (i *Inbound) handleHTTPStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext) { +func (i *Inbound) handleHTTPService(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) { + switch service.Kind { + case ResolvedServiceStatus: + err := respWriter.WriteResponse(nil, encodeResponseHeaders(service.StatusCode, http.Header{})) + if err != nil { + i.logger.ErrorContext(ctx, "write status service response: ", err) + } + return + case ResolvedServiceHTTP: + metadata.Destination = service.Destination + if request.Type == ConnectionTypeHTTP { + i.handleHTTPStream(ctx, stream, respWriter, request, metadata, service) + } else { + i.handleWebSocketStream(ctx, stream, respWriter, request, metadata, service) + } + case ResolvedServiceUnix, ResolvedServiceUnixTLS, ResolvedServiceHelloWorld: + if request.Type == ConnectionTypeHTTP { + i.handleDirectHTTPStream(ctx, stream, respWriter, request, metadata, service) + } else { + i.handleDirectWebSocketStream(ctx, stream, respWriter, request, metadata, service) + } + default: + err := E.New("unsupported service kind for HTTP/WebSocket request") + i.logger.ErrorContext(ctx, err) + respWriter.WriteResponse(err, nil) + } +} + +func (i *Inbound) handleHTTPStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) { metadata.Network = N.NetworkTCP i.logger.InfoContext(ctx, "inbound HTTP connection to ", metadata.Destination) + transport, cleanup := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest) + defer cleanup() + i.roundTripHTTP(ctx, stream, respWriter, request, service, transport) +} + +func (i *Inbound) handleWebSocketStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) { + metadata.Network = N.NetworkTCP + i.logger.InfoContext(ctx, "inbound WebSocket connection to ", metadata.Destination) + + transport, cleanup := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest) + defer cleanup() + i.roundTripHTTP(ctx, stream, respWriter, request, service, transport) +} + +func (i *Inbound) handleDirectHTTPStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) { + metadata.Network = N.NetworkTCP + i.logger.InfoContext(ctx, "inbound HTTP connection to ", request.Dest) + + transport, cleanup, err := i.newDirectOriginTransport(service) + if err != nil { + i.logger.ErrorContext(ctx, "build direct origin transport: ", err) + respWriter.WriteResponse(err, nil) + return + } + defer cleanup() + i.roundTripHTTP(ctx, stream, respWriter, request, service, transport) +} + +func (i *Inbound) handleDirectWebSocketStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) { + metadata.Network = N.NetworkTCP + i.logger.InfoContext(ctx, "inbound WebSocket connection to ", request.Dest) + + transport, cleanup, err := i.newDirectOriginTransport(service) + if err != nil { + i.logger.ErrorContext(ctx, "build direct origin transport: ", err) + respWriter.WriteResponse(err, nil) + return + } + defer cleanup() + i.roundTripHTTP(ctx, stream, respWriter, request, service, transport) +} + +func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, service ResolvedService, transport *http.Transport) { httpRequest, err := buildHTTPRequestFromMetadata(ctx, request, stream) if err != nil { i.logger.ErrorContext(ctx, "build HTTP request: ", err) @@ -183,23 +281,17 @@ func (i *Inbound) handleHTTPStream(ctx context.Context, stream io.ReadWriteClose return } - input, output := pipe.Pipe() - var innerError error - - done := make(chan struct{}) - go i.router.RouteConnectionEx(ctx, output, metadata, N.OnceClose(func(it error) { - innerError = it - common.Close(input, output) - close(done) - })) + httpRequest = applyOriginRequest(httpRequest, service.OriginRequest) + requestCtx := httpRequest.Context() + if service.OriginRequest.ConnectTimeout > 0 { + var cancel context.CancelFunc + requestCtx, cancel = context.WithTimeout(requestCtx, service.OriginRequest.ConnectTimeout) + defer cancel() + httpRequest = httpRequest.WithContext(requestCtx) + } httpClient := &http.Client{ - Transport: &http.Transport{ - DisableCompression: true, - DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { - return input, nil - }, - }, + Transport: transport, CheckRedirect: func(request *http.Request, via []*http.Request) error { return http.ErrUseLastResponse }, @@ -208,87 +300,146 @@ func (i *Inbound) handleHTTPStream(ctx context.Context, stream io.ReadWriteClose response, err := httpClient.Do(httpRequest) if err != nil { - <-done - i.logger.ErrorContext(ctx, "HTTP request: ", E.Errors(innerError, err)) + i.logger.ErrorContext(ctx, "origin request: ", err) respWriter.WriteResponse(err, nil) return } + defer response.Body.Close() responseMetadata := encodeResponseHeaders(response.StatusCode, response.Header) err = respWriter.WriteResponse(nil, responseMetadata) if err != nil { - response.Body.Close() - i.logger.ErrorContext(ctx, "write HTTP response headers: ", err) - <-done + i.logger.ErrorContext(ctx, "write origin response headers: ", err) + return + } + + if request.Type == ConnectionTypeWebsocket && response.StatusCode == http.StatusSwitchingProtocols { + rwc, ok := response.Body.(io.ReadWriteCloser) + if !ok { + i.logger.ErrorContext(ctx, "websocket origin response body is not duplex") + return + } + bidirectionalCopy(stream, rwc) return } _, err = io.Copy(stream, response.Body) - response.Body.Close() - common.Close(input, output) if err != nil && !E.IsClosedOrCanceled(err) { i.logger.DebugContext(ctx, "copy HTTP response body: ", err) } - <-done } -func (i *Inbound) handleWebSocketStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext) { - metadata.Network = N.NetworkTCP - i.logger.InfoContext(ctx, "inbound WebSocket connection to ", metadata.Destination) - - httpRequest, err := buildHTTPRequestFromMetadata(ctx, request, stream) - if err != nil { - i.logger.ErrorContext(ctx, "build WebSocket request: ", err) - respWriter.WriteResponse(err, nil) - return - } - +func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter.InboundContext, originRequest OriginRequestConfig) (*http.Transport, func()) { input, output := pipe.Pipe() - var innerError error - done := make(chan struct{}) go i.router.RouteConnectionEx(ctx, output, metadata, N.OnceClose(func(it error) { - innerError = it common.Close(input, output) close(done) })) - httpClient := &http.Client{ - Transport: &http.Transport{ - DisableCompression: true, - DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { - return input, nil - }, - }, - CheckRedirect: func(request *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse + transport := &http.Transport{ + DisableCompression: true, + ForceAttemptHTTP2: originRequest.HTTP2Origin, + TLSHandshakeTimeout: originRequest.TLSTimeout, + IdleConnTimeout: originRequest.KeepAliveTimeout, + MaxIdleConns: originRequest.KeepAliveConnections, + MaxIdleConnsPerHost: originRequest.KeepAliveConnections, + TLSClientConfig: buildOriginTLSConfig(originRequest), + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return input, nil }, } - defer httpClient.CloseIdleConnections() + return transport, func() { + common.Close(input, output) + select { + case <-done: + case <-time.After(time.Second): + } + } +} - response, err := httpClient.Do(httpRequest) +func (i *Inbound) newDirectOriginTransport(service ResolvedService) (*http.Transport, func(), error) { + transport := &http.Transport{ + DisableCompression: true, + ForceAttemptHTTP2: service.OriginRequest.HTTP2Origin, + TLSHandshakeTimeout: service.OriginRequest.TLSTimeout, + IdleConnTimeout: service.OriginRequest.KeepAliveTimeout, + MaxIdleConns: service.OriginRequest.KeepAliveConnections, + MaxIdleConnsPerHost: service.OriginRequest.KeepAliveConnections, + TLSClientConfig: buildOriginTLSConfig(service.OriginRequest), + } + switch service.Kind { + case ResolvedServiceUnix, ResolvedServiceUnixTLS: + dialer := &net.Dialer{} + transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { + return dialer.DialContext(ctx, "unix", service.UnixPath) + } + case ResolvedServiceHelloWorld: + dialer := &net.Dialer{} + target := service.BaseURL.Host + transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { + return dialer.DialContext(ctx, "tcp", target) + } + default: + return nil, nil, E.New("unsupported direct origin service") + } + return transport, func() {}, nil +} + +func buildOriginTLSConfig(originRequest OriginRequestConfig) *tls.Config { + tlsConfig := &tls.Config{ + InsecureSkipVerify: originRequest.NoTLSVerify, //nolint:gosec + ServerName: originRequest.OriginServerName, + } + if originRequest.CAPool == "" { + return tlsConfig + } + pemData, err := os.ReadFile(originRequest.CAPool) if err != nil { - <-done - i.logger.ErrorContext(ctx, "WebSocket request: ", E.Errors(innerError, err)) - respWriter.WriteResponse(err, nil) - return + return tlsConfig + } + pool := x509.NewCertPool() + if pool.AppendCertsFromPEM(pemData) { + tlsConfig.RootCAs = pool + } + return tlsConfig +} + +func applyOriginRequest(request *http.Request, originRequest OriginRequestConfig) *http.Request { + request = request.Clone(request.Context()) + if originRequest.HTTPHostHeader != "" { + request.Header.Set("X-Forwarded-Host", request.Host) + request.Host = originRequest.HTTPHostHeader + } + if originRequest.DisableChunkedEncoding && request.Header.Get("Content-Length") != "" { + if contentLength, err := strconv.ParseInt(request.Header.Get("Content-Length"), 10, 64); err == nil { + request.ContentLength = contentLength + request.TransferEncoding = nil + } + } + return request +} + +func bidirectionalCopy(left, right io.ReadWriteCloser) { + var closeOnce sync.Once + closeBoth := func() { + closeOnce.Do(func() { + common.Close(left, right) + }) } - responseMetadata := encodeResponseHeaders(response.StatusCode, response.Header) - err = respWriter.WriteResponse(nil, responseMetadata) - if err != nil { - response.Body.Close() - i.logger.ErrorContext(ctx, "write WebSocket response headers: ", err) - <-done - return - } - - _, err = io.Copy(stream, response.Body) - response.Body.Close() - common.Close(input, output) - if err != nil && !E.IsClosedOrCanceled(err) { - i.logger.DebugContext(ctx, "copy WebSocket response body: ", err) - } + done := make(chan struct{}, 2) + go func() { + io.Copy(left, right) + closeBoth() + done <- struct{}{} + }() + go func() { + io.Copy(right, left) + closeBoth() + done <- struct{}{} + }() + <-done <-done } diff --git a/protocol/cloudflare/helpers_test.go b/protocol/cloudflare/helpers_test.go index 82676434d..8920b7705 100644 --- a/protocol/cloudflare/helpers_test.go +++ b/protocol/cloudflare/helpers_test.go @@ -142,6 +142,11 @@ func newTestInbound(t *testing.T, token string, protocol string, haConnections i t.Fatal("create logger: ", err) } + configManager, err := NewConfigManager(option.CloudflareTunnelInboundOptions{}) + if err != nil { + t.Fatal("create config manager: ", err) + } + ctx, cancel := context.WithCancel(context.Background()) inboundInstance := &Inbound{ Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"), @@ -156,6 +161,7 @@ func newTestInbound(t *testing.T, token string, protocol string, haConnections i edgeIPVersion: 0, datagramVersion: "", gracePeriod: 5 * time.Second, + configManager: configManager, datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer), datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer), } diff --git a/protocol/cloudflare/inbound.go b/protocol/cloudflare/inbound.go index dd0bd4398..0fcb2ac65 100644 --- a/protocol/cloudflare/inbound.go +++ b/protocol/cloudflare/inbound.go @@ -7,9 +7,10 @@ import ( "encoding/base64" "io" "math/rand" + "net" + "net/http" "net/url" "os" - "strings" "sync" "time" @@ -30,17 +31,19 @@ func RegisterInbound(registry *inbound.Registry) { type Inbound struct { inbound.Adapter - ctx context.Context - cancel context.CancelFunc - router adapter.ConnectionRouterEx - logger log.ContextLogger - credentials Credentials - connectorID uuid.UUID - haConnections int - protocol string - edgeIPVersion int - datagramVersion string - gracePeriod time.Duration + ctx context.Context + cancel context.CancelFunc + router adapter.ConnectionRouterEx + logger log.ContextLogger + credentials Credentials + connectorID uuid.UUID + haConnections int + protocol string + region string + edgeIPVersion int + datagramVersion string + gracePeriod time.Duration + configManager *ConfigManager connectionAccess sync.Mutex connections []io.Closer @@ -50,101 +53,9 @@ type Inbound struct { datagramV2Muxers map[DatagramSender]*DatagramV2Muxer datagramV3Muxers map[DatagramSender]*DatagramV3Muxer - ingressAccess sync.RWMutex - ingressVersion int32 - ingressRules []IngressRule -} - -// IngressRule maps a hostname pattern to an origin service URL. -type IngressRule struct { - Hostname string - Service string -} - -type ingressConfig struct { - Ingress []ingressConfigRule `json:"ingress"` -} - -type ingressConfigRule struct { - Hostname string `json:"hostname,omitempty"` - Service string `json:"service"` -} - -// UpdateIngress applies a new ingress configuration from the edge. -func (i *Inbound) UpdateIngress(version int32, config []byte) { - i.ingressAccess.Lock() - defer i.ingressAccess.Unlock() - - if version <= i.ingressVersion { - return - } - - var parsed ingressConfig - err := json.Unmarshal(config, &parsed) - if err != nil { - i.logger.Error("parse ingress config: ", err) - return - } - - rules := make([]IngressRule, 0, len(parsed.Ingress)) - for _, rule := range parsed.Ingress { - rules = append(rules, IngressRule{ - Hostname: rule.Hostname, - Service: rule.Service, - }) - } - i.ingressRules = rules - i.ingressVersion = version - i.logger.Info("updated ingress configuration (version ", version, ", ", len(rules), " rules)") -} - -// ResolveOrigin finds the origin service URL for a given hostname. -// Returns the service URL if matched, or empty string if no match. -func (i *Inbound) ResolveOrigin(hostname string) string { - i.ingressAccess.RLock() - defer i.ingressAccess.RUnlock() - - for _, rule := range i.ingressRules { - if rule.Hostname == "" { - return rule.Service - } - if matchIngress(rule.Hostname, hostname) { - return rule.Service - } - } - return "" -} - -func matchIngress(pattern, hostname string) bool { - if pattern == hostname { - return true - } - if strings.HasPrefix(pattern, "*.") { - suffix := pattern[1:] - return strings.HasSuffix(hostname, suffix) - } - return false -} - -// ResolveOriginURL rewrites a request URL to point to the origin service. -// For example, https://testbox.badnet.work/path → http://127.0.0.1:8083/path -func (i *Inbound) ResolveOriginURL(requestURL string) string { - parsed, err := url.Parse(requestURL) - if err != nil { - return requestURL - } - hostname := parsed.Hostname() - origin := i.ResolveOrigin(hostname) - if origin == "" || strings.HasPrefix(origin, "http_status:") { - return requestURL - } - originURL, err := url.Parse(origin) - if err != nil { - return requestURL - } - parsed.Scheme = originURL.Scheme - parsed.Host = originURL.Host - return parsed.String() + helloWorldAccess sync.Mutex + helloWorldServer *http.Server + helloWorldURL *url.URL } func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.CloudflareTunnelInboundOptions) (adapter.Inbound, error) { @@ -178,23 +89,30 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo gracePeriod = 30 * time.Second } + configManager, err := NewConfigManager(options) + if err != nil { + return nil, E.Cause(err, "build cloudflare tunnel runtime config") + } + inboundCtx, cancel := context.WithCancel(ctx) return &Inbound{ - Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, tag), - ctx: inboundCtx, - cancel: cancel, - router: router, - logger: logger, - credentials: credentials, - connectorID: uuid.New(), - haConnections: haConnections, - protocol: protocol, - edgeIPVersion: edgeIPVersion, - datagramVersion: datagramVersion, - gracePeriod: gracePeriod, - datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer), - datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer), + Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, tag), + ctx: inboundCtx, + cancel: cancel, + router: router, + logger: logger, + credentials: credentials, + connectorID: uuid.New(), + haConnections: haConnections, + protocol: protocol, + region: options.Region, + edgeIPVersion: edgeIPVersion, + datagramVersion: datagramVersion, + gracePeriod: gracePeriod, + configManager: configManager, + datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer), + datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer), }, nil } @@ -238,6 +156,16 @@ func (i *Inbound) Start(stage adapter.StartStage) error { return nil } +func (i *Inbound) ApplyConfig(version int32, config []byte) ConfigUpdateResult { + result := i.configManager.Apply(version, config) + if result.Err != nil { + i.logger.Error("update ingress configuration: ", result.Err) + return result + } + i.logger.Info("updated ingress configuration (version ", result.LastAppliedVersion, ")") + return result +} + func (i *Inbound) Close() error { i.cancel() i.done.Wait() @@ -247,9 +175,41 @@ func (i *Inbound) Close() error { } i.connections = nil i.connectionAccess.Unlock() + if i.helloWorldServer != nil { + i.helloWorldServer.Close() + } return nil } +func (i *Inbound) ensureHelloWorldURL() (*url.URL, error) { + i.helloWorldAccess.Lock() + defer i.helloWorldAccess.Unlock() + if i.helloWorldURL != nil { + return i.helloWorldURL, nil + } + + mux := http.NewServeMux() + mux.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) { + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusOK) + _, _ = writer.Write([]byte("Hello World")) + }) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, E.Cause(err, "listen hello world server") + } + server := &http.Server{Handler: mux} + go server.Serve(listener) + + i.helloWorldServer = server + i.helloWorldURL = &url.URL{ + Scheme: "http", + Host: listener.Addr().String(), + } + return i.helloWorldURL, nil +} + const ( backoffBaseTime = time.Second backoffMaxTime = 2 * time.Minute diff --git a/protocol/cloudflare/ingress_test.go b/protocol/cloudflare/ingress_test.go index 190eb5b15..03a91f0f0 100644 --- a/protocol/cloudflare/ingress_test.go +++ b/protocol/cloudflare/ingress_test.go @@ -6,143 +6,148 @@ import ( "testing" "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" ) -func newTestIngressInbound() *Inbound { - return &Inbound{logger: log.NewNOPFactory().NewLogger("test")} +func newTestIngressInbound(t *testing.T) *Inbound { + t.Helper() + configManager, err := NewConfigManager(option.CloudflareTunnelInboundOptions{}) + if err != nil { + t.Fatal(err) + } + return &Inbound{ + logger: log.NewNOPFactory().NewLogger("test"), + configManager: configManager, + } } -func TestUpdateIngress(t *testing.T) { - inboundInstance := newTestIngressInbound() +func mustResolvedService(t *testing.T, rawService string) ResolvedService { + t.Helper() + service, err := parseResolvedService(rawService, defaultOriginRequestConfig()) + if err != nil { + t.Fatal(err) + } + return service +} + +func TestApplyConfig(t *testing.T) { + inboundInstance := newTestIngressInbound(t) config1 := []byte(`{"ingress":[{"hostname":"a.com","service":"http://localhost:80"},{"hostname":"b.com","service":"http://localhost:81"},{"service":"http_status:404"}]}`) - inboundInstance.UpdateIngress(1, config1) - - inboundInstance.ingressAccess.RLock() - count := len(inboundInstance.ingressRules) - inboundInstance.ingressAccess.RUnlock() - if count != 3 { - t.Fatalf("expected 3 rules, got %d", count) + result := inboundInstance.ApplyConfig(1, config1) + if result.Err != nil { + t.Fatal(result.Err) + } + if result.LastAppliedVersion != 1 { + t.Fatalf("expected version 1, got %d", result.LastAppliedVersion) } - inboundInstance.UpdateIngress(1, []byte(`{"ingress":[{"service":"http_status:503"}]}`)) - inboundInstance.ingressAccess.RLock() - count = len(inboundInstance.ingressRules) - inboundInstance.ingressAccess.RUnlock() - if count != 3 { - t.Error("version 1 re-apply should not change rules, got ", count) + service, loaded := inboundInstance.configManager.Resolve("a.com", "/") + if !loaded || service.Service != "http://localhost:80" { + t.Fatalf("expected a.com to resolve to localhost:80, got %#v, loaded=%v", service, loaded) } - inboundInstance.UpdateIngress(2, []byte(`{"ingress":[{"service":"http_status:503"}]}`)) - inboundInstance.ingressAccess.RLock() - count = len(inboundInstance.ingressRules) - inboundInstance.ingressAccess.RUnlock() - if count != 1 { - t.Error("version 2 should update to 1 rule, got ", count) + result = inboundInstance.ApplyConfig(1, []byte(`{"ingress":[{"service":"http_status:503"}]}`)) + if result.Err != nil { + t.Fatal(result.Err) + } + if result.LastAppliedVersion != 1 { + t.Fatalf("same version should keep current version, got %d", result.LastAppliedVersion) + } + + service, loaded = inboundInstance.configManager.Resolve("b.com", "/") + if !loaded || service.Service != "http://localhost:81" { + t.Fatalf("expected old rules to remain, got %#v, loaded=%v", service, loaded) + } + + result = inboundInstance.ApplyConfig(2, []byte(`{"ingress":[{"service":"http_status:503"}]}`)) + if result.Err != nil { + t.Fatal(result.Err) + } + if result.LastAppliedVersion != 2 { + t.Fatalf("expected version 2, got %d", result.LastAppliedVersion) + } + + service, loaded = inboundInstance.configManager.Resolve("anything.com", "/") + if !loaded || service.StatusCode != 503 { + t.Fatalf("expected catch-all status 503, got %#v, loaded=%v", service, loaded) } } -func TestUpdateIngressInvalidJSON(t *testing.T) { - inboundInstance := newTestIngressInbound() - inboundInstance.UpdateIngress(1, []byte("not json")) - - inboundInstance.ingressAccess.RLock() - count := len(inboundInstance.ingressRules) - inboundInstance.ingressAccess.RUnlock() - if count != 0 { - t.Error("invalid JSON should leave rules empty, got ", count) +func TestApplyConfigInvalidJSON(t *testing.T) { + inboundInstance := newTestIngressInbound(t) + result := inboundInstance.ApplyConfig(1, []byte("not json")) + if result.Err == nil { + t.Fatal("expected parse error") + } + if result.LastAppliedVersion != -1 { + t.Fatalf("expected version to stay -1, got %d", result.LastAppliedVersion) } } -func TestResolveOriginExact(t *testing.T) { - inboundInstance := newTestIngressInbound() - inboundInstance.ingressRules = []IngressRule{ - {Hostname: "test.example.com", Service: "http://localhost:8080"}, - {Hostname: "", Service: "http_status:404"}, +func TestResolveExactAndWildcard(t *testing.T) { + inboundInstance := newTestIngressInbound(t) + inboundInstance.configManager.activeConfig = RuntimeConfig{ + Ingress: []compiledIngressRule{ + {Hostname: "test.example.com", Service: mustResolvedService(t, "http://localhost:8080")}, + {Hostname: "*.example.com", Service: mustResolvedService(t, "http://localhost:9090")}, + {Service: mustResolvedService(t, "http_status:404")}, + }, } - result := inboundInstance.ResolveOrigin("test.example.com") - if result != "http://localhost:8080" { - t.Error("expected http://localhost:8080, got ", result) + service, loaded := inboundInstance.configManager.Resolve("test.example.com", "/") + if !loaded || service.Service != "http://localhost:8080" { + t.Fatalf("expected exact match, got %#v, loaded=%v", service, loaded) + } + + service, loaded = inboundInstance.configManager.Resolve("sub.example.com", "/") + if !loaded || service.Service != "http://localhost:9090" { + t.Fatalf("expected wildcard match, got %#v, loaded=%v", service, loaded) + } + + service, loaded = inboundInstance.configManager.Resolve("unknown.test", "/") + if !loaded || service.StatusCode != 404 { + t.Fatalf("expected catch-all 404, got %#v, loaded=%v", service, loaded) } } -func TestResolveOriginWildcard(t *testing.T) { - inboundInstance := newTestIngressInbound() - inboundInstance.ingressRules = []IngressRule{ - {Hostname: "*.example.com", Service: "http://localhost:9090"}, +func TestResolveHTTPService(t *testing.T) { + inboundInstance := newTestIngressInbound(t) + inboundInstance.configManager.activeConfig = RuntimeConfig{ + Ingress: []compiledIngressRule{ + {Hostname: "foo.com", Service: mustResolvedService(t, "http://127.0.0.1:8083")}, + {Service: mustResolvedService(t, "http_status:404")}, + }, } - result := inboundInstance.ResolveOrigin("sub.example.com") - if result != "http://localhost:9090" { - t.Error("wildcard should match sub.example.com, got ", result) + service, requestURL, err := inboundInstance.resolveHTTPService("https://foo.com/path?q=1") + if err != nil { + t.Fatal(err) } - - result = inboundInstance.ResolveOrigin("example.com") - if result != "" { - t.Error("wildcard should not match bare example.com, got ", result) + if service.Destination.String() != "127.0.0.1:8083" { + t.Fatalf("expected destination 127.0.0.1:8083, got %s", service.Destination) + } + if requestURL != "http://127.0.0.1:8083/path?q=1" { + t.Fatalf("expected rewritten URL, got %s", requestURL) } } -func TestResolveOriginCatchAll(t *testing.T) { - inboundInstance := newTestIngressInbound() - inboundInstance.ingressRules = []IngressRule{ - {Hostname: "specific.com", Service: "http://localhost:1"}, - {Hostname: "", Service: "http://localhost:2"}, +func TestResolveHTTPServiceStatus(t *testing.T) { + inboundInstance := newTestIngressInbound(t) + inboundInstance.configManager.activeConfig = RuntimeConfig{ + Ingress: []compiledIngressRule{ + {Service: mustResolvedService(t, "http_status:404")}, + }, } - result := inboundInstance.ResolveOrigin("anything.com") - if result != "http://localhost:2" { - t.Error("catch-all should match, got ", result) - } -} - -func TestResolveOriginNoMatch(t *testing.T) { - inboundInstance := newTestIngressInbound() - inboundInstance.ingressRules = []IngressRule{ - {Hostname: "specific.com", Service: "http://localhost:1"}, - } - - result := inboundInstance.ResolveOrigin("other.com") - if result != "" { - t.Error("expected empty for no match, got ", result) - } -} - -func TestResolveOriginURLRewrite(t *testing.T) { - inboundInstance := newTestIngressInbound() - inboundInstance.ingressRules = []IngressRule{ - {Hostname: "foo.com", Service: "http://127.0.0.1:8083"}, - } - - result := inboundInstance.ResolveOriginURL("https://foo.com/path?q=1") - if result != "http://127.0.0.1:8083/path?q=1" { - t.Error("expected http://127.0.0.1:8083/path?q=1, got ", result) - } -} - -func TestResolveOriginURLNoMatch(t *testing.T) { - inboundInstance := newTestIngressInbound() - inboundInstance.ingressRules = []IngressRule{ - {Hostname: "other.com", Service: "http://localhost:1"}, - } - - original := "https://unknown.com/page" - result := inboundInstance.ResolveOriginURL(original) - if result != original { - t.Error("no match should return original, got ", result) - } -} - -func TestResolveOriginURLHTTPStatus(t *testing.T) { - inboundInstance := newTestIngressInbound() - inboundInstance.ingressRules = []IngressRule{ - {Hostname: "", Service: "http_status:404"}, - } - - original := "https://any.com/page" - result := inboundInstance.ResolveOriginURL(original) - if result != original { - t.Error("http_status service should return original, got ", result) + service, requestURL, err := inboundInstance.resolveHTTPService("https://any.com/path") + if err != nil { + t.Fatal(err) + } + if service.StatusCode != 404 { + t.Fatalf("expected status 404, got %#v", service) + } + if requestURL != "https://any.com/path" { + t.Fatalf("status service should keep request URL, got %s", requestURL) } } diff --git a/protocol/cloudflare/runtime_config.go b/protocol/cloudflare/runtime_config.go new file mode 100644 index 000000000..99fe73c81 --- /dev/null +++ b/protocol/cloudflare/runtime_config.go @@ -0,0 +1,803 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "encoding/json" + "net" + "net/url" + "regexp" + "strconv" + "strings" + "sync" + "time" + + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + + "golang.org/x/net/idna" +) + +const ( + defaultHTTPConnectTimeout = 30 * time.Second + defaultTLSTimeout = 10 * time.Second + defaultTCPKeepAlive = 30 * time.Second + defaultKeepAliveTimeout = 90 * time.Second + defaultKeepAliveConnections = 100 + defaultProxyAddress = "127.0.0.1" + defaultWarpRoutingConnectTime = 5 * time.Second + defaultWarpRoutingTCPKeepAlive = 30 * time.Second +) + +type ResolvedServiceKind int + +const ( + ResolvedServiceHTTP ResolvedServiceKind = iota + ResolvedServiceStream + ResolvedServiceStatus + ResolvedServiceHelloWorld + ResolvedServiceUnix + ResolvedServiceUnixTLS +) + +type ResolvedService struct { + Kind ResolvedServiceKind + Service string + Destination M.Socksaddr + BaseURL *url.URL + UnixPath string + StatusCode int + OriginRequest OriginRequestConfig +} + +func (s ResolvedService) RouterControlled() bool { + return s.Kind == ResolvedServiceHTTP || s.Kind == ResolvedServiceStream +} + +func (s ResolvedService) BuildRequestURL(requestURL string) (string, error) { + switch s.Kind { + case ResolvedServiceHTTP, ResolvedServiceUnix, ResolvedServiceUnixTLS: + requestParsed, err := url.Parse(requestURL) + if err != nil { + return "", err + } + originURL := *s.BaseURL + originURL.Path = requestParsed.Path + originURL.RawPath = requestParsed.RawPath + originURL.RawQuery = requestParsed.RawQuery + originURL.Fragment = requestParsed.Fragment + return originURL.String(), nil + case ResolvedServiceHelloWorld: + if s.BaseURL == nil { + return "", E.New("hello world service is unavailable") + } + requestParsed, err := url.Parse(requestURL) + if err != nil { + return "", err + } + originURL := *s.BaseURL + originURL.Path = requestParsed.Path + originURL.RawPath = requestParsed.RawPath + originURL.RawQuery = requestParsed.RawQuery + originURL.Fragment = requestParsed.Fragment + return originURL.String(), nil + default: + return requestURL, nil + } +} + +type compiledIngressRule struct { + Hostname string + PunycodeHostname string + Path *regexp.Regexp + Service ResolvedService +} + +type RuntimeConfig struct { + Ingress []compiledIngressRule + OriginRequest OriginRequestConfig + WarpRouting WarpRoutingConfig +} + +type OriginRequestConfig struct { + ConnectTimeout time.Duration + TLSTimeout time.Duration + TCPKeepAlive time.Duration + NoHappyEyeballs bool + KeepAliveTimeout time.Duration + KeepAliveConnections int + HTTPHostHeader string + OriginServerName string + MatchSNIToHost bool + CAPool string + NoTLSVerify bool + DisableChunkedEncoding bool + BastionMode bool + ProxyAddress string + ProxyPort uint + ProxyType string + IPRules []IPRule + HTTP2Origin bool + Access AccessConfig +} + +type AccessConfig struct { + Required bool + TeamName string + AudTag []string + Environment string +} + +type IPRule struct { + Prefix string + Ports []int + Allow bool +} + +type WarpRoutingConfig struct { + ConnectTimeout time.Duration + MaxActiveFlows uint64 + TCPKeepAlive time.Duration +} + +type ConfigUpdateResult struct { + LastAppliedVersion int32 + Err error +} + +type ConfigManager struct { + access sync.RWMutex + currentVersion int32 + activeConfig RuntimeConfig +} + +func NewConfigManager(options option.CloudflareTunnelInboundOptions) (*ConfigManager, error) { + config, err := buildLocalRuntimeConfig(options) + if err != nil { + return nil, err + } + return &ConfigManager{ + currentVersion: -1, + activeConfig: config, + }, nil +} + +func (m *ConfigManager) Snapshot() RuntimeConfig { + m.access.RLock() + defer m.access.RUnlock() + return m.activeConfig +} + +func (m *ConfigManager) CurrentVersion() int32 { + m.access.RLock() + defer m.access.RUnlock() + return m.currentVersion +} + +func (m *ConfigManager) Apply(version int32, raw []byte) ConfigUpdateResult { + m.access.Lock() + defer m.access.Unlock() + + if version <= m.currentVersion { + return ConfigUpdateResult{LastAppliedVersion: m.currentVersion} + } + + config, err := buildRemoteRuntimeConfig(raw) + if err != nil { + return ConfigUpdateResult{ + LastAppliedVersion: m.currentVersion, + Err: err, + } + } + + m.activeConfig = config + m.currentVersion = version + return ConfigUpdateResult{LastAppliedVersion: m.currentVersion} +} + +func (m *ConfigManager) Resolve(hostname, path string) (ResolvedService, bool) { + m.access.RLock() + defer m.access.RUnlock() + return m.activeConfig.Resolve(hostname, path) +} + +func (c RuntimeConfig) Resolve(hostname, path string) (ResolvedService, bool) { + host := stripPort(hostname) + for _, rule := range c.Ingress { + if !matchIngressRule(rule, host, path) { + continue + } + return rule.Service, true + } + return ResolvedService{}, false +} + +func matchIngressRule(rule compiledIngressRule, hostname, path string) bool { + hostMatch := rule.Hostname == "" || rule.Hostname == "*" || matchIngressHost(rule.Hostname, hostname) + if !hostMatch && rule.PunycodeHostname != "" { + hostMatch = matchIngressHost(rule.PunycodeHostname, hostname) + } + if !hostMatch { + return false + } + return rule.Path == nil || rule.Path.MatchString(path) +} + +func matchIngressHost(pattern, hostname string) bool { + if pattern == hostname { + return true + } + if strings.HasPrefix(pattern, "*.") { + return strings.HasSuffix(hostname, strings.TrimPrefix(pattern, "*")) + } + return false +} + +func buildLocalRuntimeConfig(options option.CloudflareTunnelInboundOptions) (RuntimeConfig, error) { + defaultOriginRequest := originRequestFromOption(options.OriginRequest) + warpRouting := warpRoutingFromOption(options.WarpRouting) + var ingressRules []localIngressRule + for _, rule := range options.Ingress { + ingressRules = append(ingressRules, localIngressRule{ + Hostname: rule.Hostname, + Path: rule.Path, + Service: rule.Service, + OriginRequest: mergeOptionOriginRequest(defaultOriginRequest, rule.OriginRequest), + }) + } + compiledRules, err := compileIngressRules(defaultOriginRequest, ingressRules) + if err != nil { + return RuntimeConfig{}, err + } + return RuntimeConfig{ + Ingress: compiledRules, + OriginRequest: defaultOriginRequest, + WarpRouting: warpRouting, + }, nil +} + +func buildRemoteRuntimeConfig(raw []byte) (RuntimeConfig, error) { + var remote remoteConfigJSON + if err := json.Unmarshal(raw, &remote); err != nil { + return RuntimeConfig{}, E.Cause(err, "decode remote config") + } + defaultOriginRequest := originRequestFromRemote(remote.OriginRequest) + warpRouting := warpRoutingFromRemote(remote.WarpRouting) + var ingressRules []localIngressRule + for _, rule := range remote.Ingress { + ingressRules = append(ingressRules, localIngressRule{ + Hostname: rule.Hostname, + Path: rule.Path, + Service: rule.Service, + OriginRequest: mergeRemoteOriginRequest(defaultOriginRequest, rule.OriginRequest), + }) + } + compiledRules, err := compileIngressRules(defaultOriginRequest, ingressRules) + if err != nil { + return RuntimeConfig{}, err + } + return RuntimeConfig{ + Ingress: compiledRules, + OriginRequest: defaultOriginRequest, + WarpRouting: warpRouting, + }, nil +} + +type localIngressRule struct { + Hostname string + Path string + Service string + OriginRequest OriginRequestConfig +} + +type remoteConfigJSON struct { + OriginRequest remoteOriginRequestJSON `json:"originRequest"` + Ingress []remoteIngressRuleJSON `json:"ingress"` + WarpRouting remoteWarpRoutingJSON `json:"warp-routing"` +} + +type remoteIngressRuleJSON struct { + Hostname string `json:"hostname,omitempty"` + Path string `json:"path,omitempty"` + Service string `json:"service"` + OriginRequest remoteOriginRequestJSON `json:"originRequest,omitempty"` +} + +type remoteOriginRequestJSON struct { + ConnectTimeout int64 `json:"connectTimeout,omitempty"` + TLSTimeout int64 `json:"tlsTimeout,omitempty"` + TCPKeepAlive int64 `json:"tcpKeepAlive,omitempty"` + NoHappyEyeballs *bool `json:"noHappyEyeballs,omitempty"` + KeepAliveTimeout int64 `json:"keepAliveTimeout,omitempty"` + KeepAliveConnections *int `json:"keepAliveConnections,omitempty"` + HTTPHostHeader string `json:"httpHostHeader,omitempty"` + OriginServerName string `json:"originServerName,omitempty"` + MatchSNIToHost *bool `json:"matchSNIToHost,omitempty"` + CAPool string `json:"caPool,omitempty"` + NoTLSVerify *bool `json:"noTLSVerify,omitempty"` + DisableChunkedEncoding *bool `json:"disableChunkedEncoding,omitempty"` + BastionMode *bool `json:"bastionMode,omitempty"` + ProxyAddress string `json:"proxyAddress,omitempty"` + ProxyPort *uint `json:"proxyPort,omitempty"` + ProxyType string `json:"proxyType,omitempty"` + IPRules []remoteIPRuleJSON `json:"ipRules,omitempty"` + HTTP2Origin *bool `json:"http2Origin,omitempty"` + Access *remoteAccessJSON `json:"access,omitempty"` +} + +type remoteAccessJSON struct { + Required bool `json:"required,omitempty"` + TeamName string `json:"teamName,omitempty"` + AudTag []string `json:"audTag,omitempty"` + Environment string `json:"environment,omitempty"` +} + +type remoteIPRuleJSON struct { + Prefix string `json:"prefix,omitempty"` + Ports []int `json:"ports,omitempty"` + Allow bool `json:"allow,omitempty"` +} + +type remoteWarpRoutingJSON struct { + ConnectTimeout int64 `json:"connectTimeout,omitempty"` + MaxActiveFlows uint64 `json:"maxActiveFlows,omitempty"` + TCPKeepAlive int64 `json:"tcpKeepAlive,omitempty"` +} + +func compileIngressRules(defaultOriginRequest OriginRequestConfig, rawRules []localIngressRule) ([]compiledIngressRule, error) { + if len(rawRules) == 0 { + rawRules = []localIngressRule{{ + Service: "http_status:503", + OriginRequest: defaultOriginRequest, + }} + } + if !isCatchAllRule(rawRules[len(rawRules)-1].Hostname, rawRules[len(rawRules)-1].Path) { + return nil, E.New("the last ingress rule must be a catch-all rule") + } + + compiled := make([]compiledIngressRule, 0, len(rawRules)) + for index, rule := range rawRules { + if err := validateHostname(rule.Hostname, index == len(rawRules)-1); err != nil { + return nil, err + } + service, err := parseResolvedService(rule.Service, rule.OriginRequest) + if err != nil { + return nil, err + } + var pathPattern *regexp.Regexp + if rule.Path != "" { + pathPattern, err = regexp.Compile(rule.Path) + if err != nil { + return nil, E.Cause(err, "compile ingress path regex") + } + } + punycode := "" + if rule.Hostname != "" && rule.Hostname != "*" { + punycodeValue, err := idna.Lookup.ToASCII(rule.Hostname) + if err == nil && punycodeValue != rule.Hostname { + punycode = punycodeValue + } + } + compiled = append(compiled, compiledIngressRule{ + Hostname: rule.Hostname, + PunycodeHostname: punycode, + Path: pathPattern, + Service: service, + }) + } + return compiled, nil +} + +func parseResolvedService(rawService string, originRequest OriginRequestConfig) (ResolvedService, error) { + switch { + case rawService == "": + return ResolvedService{}, E.New("missing ingress service") + case strings.HasPrefix(rawService, "http_status:"): + statusCode, err := strconv.Atoi(strings.TrimPrefix(rawService, "http_status:")) + if err != nil { + return ResolvedService{}, E.Cause(err, "parse http_status service") + } + if statusCode < 100 || statusCode > 999 { + return ResolvedService{}, E.New("invalid http_status code: ", statusCode) + } + return ResolvedService{ + Kind: ResolvedServiceStatus, + Service: rawService, + StatusCode: statusCode, + OriginRequest: originRequest, + }, nil + case rawService == "hello_world" || rawService == "hello-world": + return ResolvedService{ + Kind: ResolvedServiceHelloWorld, + Service: rawService, + OriginRequest: originRequest, + }, nil + case strings.HasPrefix(rawService, "unix:"): + return ResolvedService{ + Kind: ResolvedServiceUnix, + Service: rawService, + UnixPath: strings.TrimPrefix(rawService, "unix:"), + BaseURL: &url.URL{Scheme: "http", Host: "localhost"}, + OriginRequest: originRequest, + }, nil + case strings.HasPrefix(rawService, "unix+tls:"): + return ResolvedService{ + Kind: ResolvedServiceUnixTLS, + Service: rawService, + UnixPath: strings.TrimPrefix(rawService, "unix+tls:"), + BaseURL: &url.URL{Scheme: "https", Host: "localhost"}, + OriginRequest: originRequest, + }, nil + } + + parsedURL, err := url.Parse(rawService) + if err != nil { + return ResolvedService{}, E.Cause(err, "parse ingress service URL") + } + if parsedURL.Scheme == "" || parsedURL.Hostname() == "" { + return ResolvedService{}, E.New("ingress service must include scheme and hostname: ", rawService) + } + if parsedURL.Path != "" { + return ResolvedService{}, E.New("ingress service cannot include a path: ", rawService) + } + + switch parsedURL.Scheme { + case "http", "https", "ws", "wss": + return ResolvedService{ + Kind: ResolvedServiceHTTP, + Service: rawService, + Destination: parseServiceDestination(parsedURL), + BaseURL: parsedURL, + OriginRequest: originRequest, + }, nil + case "tcp", "ssh", "rdp", "smb": + return ResolvedService{ + Kind: ResolvedServiceStream, + Service: rawService, + Destination: parseServiceDestination(parsedURL), + BaseURL: parsedURL, + OriginRequest: originRequest, + }, nil + default: + return ResolvedService{}, E.New("unsupported ingress service scheme: ", parsedURL.Scheme) + } +} + +func parseServiceDestination(parsedURL *url.URL) M.Socksaddr { + host := parsedURL.Hostname() + port := parsedURL.Port() + if port == "" { + switch parsedURL.Scheme { + case "https", "wss": + port = "443" + case "ssh": + port = "22" + case "rdp": + port = "3389" + case "smb": + port = "445" + case "tcp": + port = "7864" + default: + port = "80" + } + } + return M.ParseSocksaddr(net.JoinHostPort(host, port)) +} + +func validateHostname(hostname string, isLast bool) error { + if hostname == "" || hostname == "*" { + if !isLast { + return E.New("only the last ingress rule may be a catch-all rule") + } + return nil + } + if strings.Count(hostname, "*") > 1 || (strings.Contains(hostname, "*") && !strings.HasPrefix(hostname, "*.")) { + return E.New("hostname wildcard must be in the form *.example.com") + } + if stripPort(hostname) != hostname { + return E.New("ingress hostname cannot contain a port") + } + return nil +} + +func isCatchAllRule(hostname, path string) bool { + return (hostname == "" || hostname == "*") && path == "" +} + +func stripPort(hostname string) string { + if host, _, err := net.SplitHostPort(hostname); err == nil { + return host + } + return hostname +} + +func defaultOriginRequestConfig() OriginRequestConfig { + return OriginRequestConfig{ + ConnectTimeout: defaultHTTPConnectTimeout, + TLSTimeout: defaultTLSTimeout, + TCPKeepAlive: defaultTCPKeepAlive, + KeepAliveTimeout: defaultKeepAliveTimeout, + KeepAliveConnections: defaultKeepAliveConnections, + ProxyAddress: defaultProxyAddress, + } +} + +func originRequestFromOption(input option.CloudflareTunnelOriginRequestOptions) OriginRequestConfig { + config := defaultOriginRequestConfig() + if input.ConnectTimeout != 0 { + config.ConnectTimeout = time.Duration(input.ConnectTimeout) + } + if input.TLSTimeout != 0 { + config.TLSTimeout = time.Duration(input.TLSTimeout) + } + if input.TCPKeepAlive != 0 { + config.TCPKeepAlive = time.Duration(input.TCPKeepAlive) + } + if input.KeepAliveTimeout != 0 { + config.KeepAliveTimeout = time.Duration(input.KeepAliveTimeout) + } + if input.KeepAliveConnections != 0 { + config.KeepAliveConnections = input.KeepAliveConnections + } + config.NoHappyEyeballs = input.NoHappyEyeballs + config.HTTPHostHeader = input.HTTPHostHeader + config.OriginServerName = input.OriginServerName + config.MatchSNIToHost = input.MatchSNIToHost + config.CAPool = input.CAPool + config.NoTLSVerify = input.NoTLSVerify + config.DisableChunkedEncoding = input.DisableChunkedEncoding + config.BastionMode = input.BastionMode + if input.ProxyAddress != "" { + config.ProxyAddress = input.ProxyAddress + } + if input.ProxyPort != 0 { + config.ProxyPort = input.ProxyPort + } + config.ProxyType = input.ProxyType + config.HTTP2Origin = input.HTTP2Origin + config.Access = AccessConfig{ + Required: input.Access.Required, + TeamName: input.Access.TeamName, + AudTag: append([]string(nil), input.Access.AudTag...), + Environment: input.Access.Environment, + } + for _, rule := range input.IPRules { + config.IPRules = append(config.IPRules, IPRule{ + Prefix: rule.Prefix, + Ports: append([]int(nil), rule.Ports...), + Allow: rule.Allow, + }) + } + return config +} + +func mergeOptionOriginRequest(base OriginRequestConfig, override option.CloudflareTunnelOriginRequestOptions) OriginRequestConfig { + result := base + if override.ConnectTimeout != 0 { + result.ConnectTimeout = time.Duration(override.ConnectTimeout) + } + if override.TLSTimeout != 0 { + result.TLSTimeout = time.Duration(override.TLSTimeout) + } + if override.TCPKeepAlive != 0 { + result.TCPKeepAlive = time.Duration(override.TCPKeepAlive) + } + if override.KeepAliveTimeout != 0 { + result.KeepAliveTimeout = time.Duration(override.KeepAliveTimeout) + } + if override.KeepAliveConnections != 0 { + result.KeepAliveConnections = override.KeepAliveConnections + } + result.NoHappyEyeballs = override.NoHappyEyeballs + if override.HTTPHostHeader != "" { + result.HTTPHostHeader = override.HTTPHostHeader + } + if override.OriginServerName != "" { + result.OriginServerName = override.OriginServerName + } + result.MatchSNIToHost = override.MatchSNIToHost + if override.CAPool != "" { + result.CAPool = override.CAPool + } + result.NoTLSVerify = override.NoTLSVerify + result.DisableChunkedEncoding = override.DisableChunkedEncoding + result.BastionMode = override.BastionMode + if override.ProxyAddress != "" { + result.ProxyAddress = override.ProxyAddress + } + if override.ProxyPort != 0 { + result.ProxyPort = override.ProxyPort + } + if override.ProxyType != "" { + result.ProxyType = override.ProxyType + } + if len(override.IPRules) > 0 { + result.IPRules = nil + for _, rule := range override.IPRules { + result.IPRules = append(result.IPRules, IPRule{ + Prefix: rule.Prefix, + Ports: append([]int(nil), rule.Ports...), + Allow: rule.Allow, + }) + } + } + result.HTTP2Origin = override.HTTP2Origin + if override.Access.Required || override.Access.TeamName != "" || len(override.Access.AudTag) > 0 || override.Access.Environment != "" { + result.Access = AccessConfig{ + Required: override.Access.Required, + TeamName: override.Access.TeamName, + AudTag: append([]string(nil), override.Access.AudTag...), + Environment: override.Access.Environment, + } + } + return result +} + +func originRequestFromRemote(input remoteOriginRequestJSON) OriginRequestConfig { + config := defaultOriginRequestConfig() + if input.ConnectTimeout != 0 { + config.ConnectTimeout = time.Duration(input.ConnectTimeout) * time.Second + } + if input.TLSTimeout != 0 { + config.TLSTimeout = time.Duration(input.TLSTimeout) * time.Second + } + if input.TCPKeepAlive != 0 { + config.TCPKeepAlive = time.Duration(input.TCPKeepAlive) * time.Second + } + if input.KeepAliveTimeout != 0 { + config.KeepAliveTimeout = time.Duration(input.KeepAliveTimeout) * time.Second + } + if input.KeepAliveConnections != nil { + config.KeepAliveConnections = *input.KeepAliveConnections + } + if input.NoHappyEyeballs != nil { + config.NoHappyEyeballs = *input.NoHappyEyeballs + } + config.HTTPHostHeader = input.HTTPHostHeader + config.OriginServerName = input.OriginServerName + if input.MatchSNIToHost != nil { + config.MatchSNIToHost = *input.MatchSNIToHost + } + config.CAPool = input.CAPool + if input.NoTLSVerify != nil { + config.NoTLSVerify = *input.NoTLSVerify + } + if input.DisableChunkedEncoding != nil { + config.DisableChunkedEncoding = *input.DisableChunkedEncoding + } + if input.BastionMode != nil { + config.BastionMode = *input.BastionMode + } + if input.ProxyAddress != "" { + config.ProxyAddress = input.ProxyAddress + } + if input.ProxyPort != nil { + config.ProxyPort = *input.ProxyPort + } + config.ProxyType = input.ProxyType + if input.HTTP2Origin != nil { + config.HTTP2Origin = *input.HTTP2Origin + } + if input.Access != nil { + config.Access = AccessConfig{ + Required: input.Access.Required, + TeamName: input.Access.TeamName, + AudTag: append([]string(nil), input.Access.AudTag...), + Environment: input.Access.Environment, + } + } + for _, rule := range input.IPRules { + config.IPRules = append(config.IPRules, IPRule{ + Prefix: rule.Prefix, + Ports: append([]int(nil), rule.Ports...), + Allow: rule.Allow, + }) + } + return config +} + +func mergeRemoteOriginRequest(base OriginRequestConfig, override remoteOriginRequestJSON) OriginRequestConfig { + result := base + if override.ConnectTimeout != 0 { + result.ConnectTimeout = time.Duration(override.ConnectTimeout) * time.Second + } + if override.TLSTimeout != 0 { + result.TLSTimeout = time.Duration(override.TLSTimeout) * time.Second + } + if override.TCPKeepAlive != 0 { + result.TCPKeepAlive = time.Duration(override.TCPKeepAlive) * time.Second + } + if override.NoHappyEyeballs != nil { + result.NoHappyEyeballs = *override.NoHappyEyeballs + } + if override.KeepAliveTimeout != 0 { + result.KeepAliveTimeout = time.Duration(override.KeepAliveTimeout) * time.Second + } + if override.KeepAliveConnections != nil { + result.KeepAliveConnections = *override.KeepAliveConnections + } + if override.HTTPHostHeader != "" { + result.HTTPHostHeader = override.HTTPHostHeader + } + if override.OriginServerName != "" { + result.OriginServerName = override.OriginServerName + } + if override.MatchSNIToHost != nil { + result.MatchSNIToHost = *override.MatchSNIToHost + } + if override.CAPool != "" { + result.CAPool = override.CAPool + } + if override.NoTLSVerify != nil { + result.NoTLSVerify = *override.NoTLSVerify + } + if override.DisableChunkedEncoding != nil { + result.DisableChunkedEncoding = *override.DisableChunkedEncoding + } + if override.BastionMode != nil { + result.BastionMode = *override.BastionMode + } + if override.ProxyAddress != "" { + result.ProxyAddress = override.ProxyAddress + } + if override.ProxyPort != nil { + result.ProxyPort = *override.ProxyPort + } + if override.ProxyType != "" { + result.ProxyType = override.ProxyType + } + if len(override.IPRules) > 0 { + result.IPRules = nil + for _, rule := range override.IPRules { + result.IPRules = append(result.IPRules, IPRule{ + Prefix: rule.Prefix, + Ports: append([]int(nil), rule.Ports...), + Allow: rule.Allow, + }) + } + } + if override.HTTP2Origin != nil { + result.HTTP2Origin = *override.HTTP2Origin + } + if override.Access != nil { + result.Access = AccessConfig{ + Required: override.Access.Required, + TeamName: override.Access.TeamName, + AudTag: append([]string(nil), override.Access.AudTag...), + Environment: override.Access.Environment, + } + } + return result +} + +func warpRoutingFromOption(input option.CloudflareTunnelWarpRoutingOptions) WarpRoutingConfig { + config := WarpRoutingConfig{ + ConnectTimeout: defaultWarpRoutingConnectTime, + TCPKeepAlive: defaultWarpRoutingTCPKeepAlive, + MaxActiveFlows: input.MaxActiveFlows, + } + if input.ConnectTimeout != 0 { + config.ConnectTimeout = time.Duration(input.ConnectTimeout) + } + if input.TCPKeepAlive != 0 { + config.TCPKeepAlive = time.Duration(input.TCPKeepAlive) + } + return config +} + +func warpRoutingFromRemote(input remoteWarpRoutingJSON) WarpRoutingConfig { + config := WarpRoutingConfig{ + ConnectTimeout: defaultWarpRoutingConnectTime, + TCPKeepAlive: defaultWarpRoutingTCPKeepAlive, + MaxActiveFlows: input.MaxActiveFlows, + } + if input.ConnectTimeout != 0 { + config.ConnectTimeout = time.Duration(input.ConnectTimeout) * time.Second + } + if input.TCPKeepAlive != 0 { + config.TCPKeepAlive = time.Duration(input.TCPKeepAlive) * time.Second + } + return config +}