mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-14 04:38:28 +10:00
Implement router-backed cloudflare tunnel ingress config
This commit is contained in:
@@ -3,11 +3,63 @@ package option
|
||||
import "github.com/sagernet/sing/common/json/badoption"
|
||||
|
||||
type CloudflareTunnelInboundOptions struct {
|
||||
Token string `json:"token,omitempty"`
|
||||
CredentialPath string `json:"credential_path,omitempty"`
|
||||
HAConnections int `json:"ha_connections,omitempty"`
|
||||
Protocol string `json:"protocol,omitempty"`
|
||||
EdgeIPVersion int `json:"edge_ip_version,omitempty"`
|
||||
DatagramVersion string `json:"datagram_version,omitempty"`
|
||||
GracePeriod badoption.Duration `json:"grace_period,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
CredentialPath string `json:"credential_path,omitempty"`
|
||||
HAConnections int `json:"ha_connections,omitempty"`
|
||||
Protocol string `json:"protocol,omitempty"`
|
||||
EdgeIPVersion int `json:"edge_ip_version,omitempty"`
|
||||
DatagramVersion string `json:"datagram_version,omitempty"`
|
||||
GracePeriod badoption.Duration `json:"grace_period,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
Ingress []CloudflareTunnelIngressRule `json:"ingress,omitempty"`
|
||||
OriginRequest CloudflareTunnelOriginRequestOptions `json:"origin_request,omitempty"`
|
||||
WarpRouting CloudflareTunnelWarpRoutingOptions `json:"warp_routing,omitempty"`
|
||||
}
|
||||
|
||||
type CloudflareTunnelIngressRule struct {
|
||||
Hostname string `json:"hostname,omitempty"`
|
||||
Path string `json:"path,omitempty"`
|
||||
Service string `json:"service,omitempty"`
|
||||
OriginRequest CloudflareTunnelOriginRequestOptions `json:"origin_request,omitempty"`
|
||||
}
|
||||
|
||||
type CloudflareTunnelOriginRequestOptions struct {
|
||||
ConnectTimeout badoption.Duration `json:"connect_timeout,omitempty"`
|
||||
TLSTimeout badoption.Duration `json:"tls_timeout,omitempty"`
|
||||
TCPKeepAlive badoption.Duration `json:"tcp_keep_alive,omitempty"`
|
||||
NoHappyEyeballs bool `json:"no_happy_eyeballs,omitempty"`
|
||||
KeepAliveTimeout badoption.Duration `json:"keep_alive_timeout,omitempty"`
|
||||
KeepAliveConnections int `json:"keep_alive_connections,omitempty"`
|
||||
HTTPHostHeader string `json:"http_host_header,omitempty"`
|
||||
OriginServerName string `json:"origin_server_name,omitempty"`
|
||||
MatchSNIToHost bool `json:"match_sni_to_host,omitempty"`
|
||||
CAPool string `json:"ca_pool,omitempty"`
|
||||
NoTLSVerify bool `json:"no_tls_verify,omitempty"`
|
||||
DisableChunkedEncoding bool `json:"disable_chunked_encoding,omitempty"`
|
||||
BastionMode bool `json:"bastion_mode,omitempty"`
|
||||
ProxyAddress string `json:"proxy_address,omitempty"`
|
||||
ProxyPort uint `json:"proxy_port,omitempty"`
|
||||
ProxyType string `json:"proxy_type,omitempty"`
|
||||
IPRules []CloudflareTunnelIPRule `json:"ip_rules,omitempty"`
|
||||
HTTP2Origin bool `json:"http2_origin,omitempty"`
|
||||
Access CloudflareTunnelAccessRule `json:"access,omitempty"`
|
||||
}
|
||||
|
||||
type CloudflareTunnelAccessRule struct {
|
||||
Required bool `json:"required,omitempty"`
|
||||
TeamName string `json:"team_name,omitempty"`
|
||||
AudTag []string `json:"aud_tag,omitempty"`
|
||||
Environment string `json:"environment,omitempty"`
|
||||
}
|
||||
|
||||
type CloudflareTunnelIPRule struct {
|
||||
Prefix string `json:"prefix,omitempty"`
|
||||
Ports []int `json:"ports,omitempty"`
|
||||
Allow bool `json:"allow,omitempty"`
|
||||
}
|
||||
|
||||
type CloudflareTunnelWarpRoutingOptions struct {
|
||||
ConnectTimeout badoption.Duration `json:"connect_timeout,omitempty"`
|
||||
MaxActiveFlows uint64 `json:"max_active_flows,omitempty"`
|
||||
TCPKeepAlive badoption.Duration `json:"tcp_keep_alive,omitempty"`
|
||||
}
|
||||
|
||||
@@ -90,10 +90,10 @@ func NewHTTP2Connection(
|
||||
server: &http2.Server{
|
||||
MaxConcurrentStreams: math.MaxUint32,
|
||||
},
|
||||
logger: logger,
|
||||
edgeAddr: edgeAddr,
|
||||
connIndex: connIndex,
|
||||
credentials: credentials,
|
||||
logger: logger,
|
||||
edgeAddr: edgeAddr,
|
||||
connIndex: connIndex,
|
||||
credentials: credentials,
|
||||
connectorID: connectorID,
|
||||
features: features,
|
||||
numPreviousAttempts: numPreviousAttempts,
|
||||
@@ -244,9 +244,13 @@ func (c *HTTP2Connection) handleConfigurationUpdate(r *http.Request, w http.Resp
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
c.inbound.UpdateIngress(body.Version, body.Config)
|
||||
result := c.inbound.ApplyConfig(body.Version, body.Config)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"lastAppliedVersion":` + strconv.FormatInt(int64(body.Version), 10) + `,"err":null}`))
|
||||
if result.Err != nil {
|
||||
w.Write([]byte(`{"lastAppliedVersion":` + strconv.FormatInt(int64(result.LastAppliedVersion), 10) + `,"err":` + strconv.Quote(result.Err.Error()) + `}`))
|
||||
return
|
||||
}
|
||||
w.Write([]byte(`{"lastAppliedVersion":` + strconv.FormatInt(int64(result.LastAppliedVersion), 10) + `,"err":null}`))
|
||||
}
|
||||
|
||||
func (c *HTTP2Connection) close() {
|
||||
|
||||
@@ -171,7 +171,6 @@ func DefaultFeatures(datagramVersion string) []string {
|
||||
"support_datagram_v2",
|
||||
"support_quic_eof",
|
||||
"allow_remote_config",
|
||||
"management_logs",
|
||||
}
|
||||
if datagramVersion == "v3" {
|
||||
features = append(features, "support_datagram_v3_2")
|
||||
|
||||
@@ -318,12 +318,17 @@ func (s *cloudflaredServer) UnregisterUdpSession(call tunnelrpc.SessionManager_u
|
||||
func (s *cloudflaredServer) UpdateConfiguration(call tunnelrpc.ConfigurationManager_updateConfiguration) error {
|
||||
version := call.Params.Version()
|
||||
configData, _ := call.Params.Config()
|
||||
s.inbound.UpdateIngress(version, configData)
|
||||
updateResult := s.inbound.ApplyConfig(version, configData)
|
||||
result, err := call.Results.NewResult()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result.SetErr("")
|
||||
result.SetLatestAppliedVersion(updateResult.LastAppliedVersion)
|
||||
if updateResult.Err != nil {
|
||||
result.SetErr(updateResult.Err.Error())
|
||||
} else {
|
||||
result.SetErr("")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -4,12 +4,16 @@ package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
@@ -124,19 +128,42 @@ func (i *Inbound) dispatchRequest(ctx context.Context, stream io.ReadWriteCloser
|
||||
metadata.Destination = M.ParseSocksaddr(request.Dest)
|
||||
i.handleTCPStream(ctx, stream, respWriter, metadata)
|
||||
case ConnectionTypeHTTP, ConnectionTypeWebsocket:
|
||||
originURL := i.ResolveOriginURL(request.Dest)
|
||||
request.Dest = originURL
|
||||
metadata.Destination = parseHTTPDestination(originURL)
|
||||
if request.Type == ConnectionTypeHTTP {
|
||||
i.handleHTTPStream(ctx, stream, respWriter, request, metadata)
|
||||
} else {
|
||||
i.handleWebSocketStream(ctx, stream, respWriter, request, metadata)
|
||||
service, originURL, err := i.resolveHTTPService(request.Dest)
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "resolve origin service: ", err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
request.Dest = originURL
|
||||
i.handleHTTPService(ctx, stream, respWriter, request, metadata, service)
|
||||
default:
|
||||
i.logger.ErrorContext(ctx, "unknown connection type: ", request.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Inbound) resolveHTTPService(requestURL string) (ResolvedService, string, error) {
|
||||
parsedURL, err := url.Parse(requestURL)
|
||||
if err != nil {
|
||||
return ResolvedService{}, "", E.Cause(err, "parse request URL")
|
||||
}
|
||||
service, loaded := i.configManager.Resolve(parsedURL.Hostname(), parsedURL.Path)
|
||||
if !loaded {
|
||||
return ResolvedService{}, "", E.New("no ingress rule matched request host/path")
|
||||
}
|
||||
if service.Kind == ResolvedServiceHelloWorld {
|
||||
helloURL, err := i.ensureHelloWorldURL()
|
||||
if err != nil {
|
||||
return ResolvedService{}, "", err
|
||||
}
|
||||
service.BaseURL = helloURL
|
||||
}
|
||||
originURL, err := service.BuildRequestURL(requestURL)
|
||||
if err != nil {
|
||||
return ResolvedService{}, "", E.Cause(err, "build origin request URL")
|
||||
}
|
||||
return service, originURL, nil
|
||||
}
|
||||
|
||||
func parseHTTPDestination(dest string) M.Socksaddr {
|
||||
parsed, err := url.Parse(dest)
|
||||
if err != nil {
|
||||
@@ -172,10 +199,81 @@ func (i *Inbound) handleTCPStream(ctx context.Context, stream io.ReadWriteCloser
|
||||
<-done
|
||||
}
|
||||
|
||||
func (i *Inbound) handleHTTPStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext) {
|
||||
func (i *Inbound) handleHTTPService(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
|
||||
switch service.Kind {
|
||||
case ResolvedServiceStatus:
|
||||
err := respWriter.WriteResponse(nil, encodeResponseHeaders(service.StatusCode, http.Header{}))
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "write status service response: ", err)
|
||||
}
|
||||
return
|
||||
case ResolvedServiceHTTP:
|
||||
metadata.Destination = service.Destination
|
||||
if request.Type == ConnectionTypeHTTP {
|
||||
i.handleHTTPStream(ctx, stream, respWriter, request, metadata, service)
|
||||
} else {
|
||||
i.handleWebSocketStream(ctx, stream, respWriter, request, metadata, service)
|
||||
}
|
||||
case ResolvedServiceUnix, ResolvedServiceUnixTLS, ResolvedServiceHelloWorld:
|
||||
if request.Type == ConnectionTypeHTTP {
|
||||
i.handleDirectHTTPStream(ctx, stream, respWriter, request, metadata, service)
|
||||
} else {
|
||||
i.handleDirectWebSocketStream(ctx, stream, respWriter, request, metadata, service)
|
||||
}
|
||||
default:
|
||||
err := E.New("unsupported service kind for HTTP/WebSocket request")
|
||||
i.logger.ErrorContext(ctx, err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Inbound) handleHTTPStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
|
||||
metadata.Network = N.NetworkTCP
|
||||
i.logger.InfoContext(ctx, "inbound HTTP connection to ", metadata.Destination)
|
||||
|
||||
transport, cleanup := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest)
|
||||
defer cleanup()
|
||||
i.roundTripHTTP(ctx, stream, respWriter, request, service, transport)
|
||||
}
|
||||
|
||||
func (i *Inbound) handleWebSocketStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
|
||||
metadata.Network = N.NetworkTCP
|
||||
i.logger.InfoContext(ctx, "inbound WebSocket connection to ", metadata.Destination)
|
||||
|
||||
transport, cleanup := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest)
|
||||
defer cleanup()
|
||||
i.roundTripHTTP(ctx, stream, respWriter, request, service, transport)
|
||||
}
|
||||
|
||||
func (i *Inbound) handleDirectHTTPStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
|
||||
metadata.Network = N.NetworkTCP
|
||||
i.logger.InfoContext(ctx, "inbound HTTP connection to ", request.Dest)
|
||||
|
||||
transport, cleanup, err := i.newDirectOriginTransport(service)
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "build direct origin transport: ", err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
defer cleanup()
|
||||
i.roundTripHTTP(ctx, stream, respWriter, request, service, transport)
|
||||
}
|
||||
|
||||
func (i *Inbound) handleDirectWebSocketStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
|
||||
metadata.Network = N.NetworkTCP
|
||||
i.logger.InfoContext(ctx, "inbound WebSocket connection to ", request.Dest)
|
||||
|
||||
transport, cleanup, err := i.newDirectOriginTransport(service)
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "build direct origin transport: ", err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
defer cleanup()
|
||||
i.roundTripHTTP(ctx, stream, respWriter, request, service, transport)
|
||||
}
|
||||
|
||||
func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, service ResolvedService, transport *http.Transport) {
|
||||
httpRequest, err := buildHTTPRequestFromMetadata(ctx, request, stream)
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "build HTTP request: ", err)
|
||||
@@ -183,23 +281,17 @@ func (i *Inbound) handleHTTPStream(ctx context.Context, stream io.ReadWriteClose
|
||||
return
|
||||
}
|
||||
|
||||
input, output := pipe.Pipe()
|
||||
var innerError error
|
||||
|
||||
done := make(chan struct{})
|
||||
go i.router.RouteConnectionEx(ctx, output, metadata, N.OnceClose(func(it error) {
|
||||
innerError = it
|
||||
common.Close(input, output)
|
||||
close(done)
|
||||
}))
|
||||
httpRequest = applyOriginRequest(httpRequest, service.OriginRequest)
|
||||
requestCtx := httpRequest.Context()
|
||||
if service.OriginRequest.ConnectTimeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
requestCtx, cancel = context.WithTimeout(requestCtx, service.OriginRequest.ConnectTimeout)
|
||||
defer cancel()
|
||||
httpRequest = httpRequest.WithContext(requestCtx)
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DisableCompression: true,
|
||||
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
|
||||
return input, nil
|
||||
},
|
||||
},
|
||||
Transport: transport,
|
||||
CheckRedirect: func(request *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
@@ -208,87 +300,146 @@ func (i *Inbound) handleHTTPStream(ctx context.Context, stream io.ReadWriteClose
|
||||
|
||||
response, err := httpClient.Do(httpRequest)
|
||||
if err != nil {
|
||||
<-done
|
||||
i.logger.ErrorContext(ctx, "HTTP request: ", E.Errors(innerError, err))
|
||||
i.logger.ErrorContext(ctx, "origin request: ", err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
responseMetadata := encodeResponseHeaders(response.StatusCode, response.Header)
|
||||
err = respWriter.WriteResponse(nil, responseMetadata)
|
||||
if err != nil {
|
||||
response.Body.Close()
|
||||
i.logger.ErrorContext(ctx, "write HTTP response headers: ", err)
|
||||
<-done
|
||||
i.logger.ErrorContext(ctx, "write origin response headers: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
if request.Type == ConnectionTypeWebsocket && response.StatusCode == http.StatusSwitchingProtocols {
|
||||
rwc, ok := response.Body.(io.ReadWriteCloser)
|
||||
if !ok {
|
||||
i.logger.ErrorContext(ctx, "websocket origin response body is not duplex")
|
||||
return
|
||||
}
|
||||
bidirectionalCopy(stream, rwc)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = io.Copy(stream, response.Body)
|
||||
response.Body.Close()
|
||||
common.Close(input, output)
|
||||
if err != nil && !E.IsClosedOrCanceled(err) {
|
||||
i.logger.DebugContext(ctx, "copy HTTP response body: ", err)
|
||||
}
|
||||
<-done
|
||||
}
|
||||
|
||||
func (i *Inbound) handleWebSocketStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext) {
|
||||
metadata.Network = N.NetworkTCP
|
||||
i.logger.InfoContext(ctx, "inbound WebSocket connection to ", metadata.Destination)
|
||||
|
||||
httpRequest, err := buildHTTPRequestFromMetadata(ctx, request, stream)
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "build WebSocket request: ", err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
|
||||
func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter.InboundContext, originRequest OriginRequestConfig) (*http.Transport, func()) {
|
||||
input, output := pipe.Pipe()
|
||||
var innerError error
|
||||
|
||||
done := make(chan struct{})
|
||||
go i.router.RouteConnectionEx(ctx, output, metadata, N.OnceClose(func(it error) {
|
||||
innerError = it
|
||||
common.Close(input, output)
|
||||
close(done)
|
||||
}))
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DisableCompression: true,
|
||||
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
|
||||
return input, nil
|
||||
},
|
||||
},
|
||||
CheckRedirect: func(request *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
transport := &http.Transport{
|
||||
DisableCompression: true,
|
||||
ForceAttemptHTTP2: originRequest.HTTP2Origin,
|
||||
TLSHandshakeTimeout: originRequest.TLSTimeout,
|
||||
IdleConnTimeout: originRequest.KeepAliveTimeout,
|
||||
MaxIdleConns: originRequest.KeepAliveConnections,
|
||||
MaxIdleConnsPerHost: originRequest.KeepAliveConnections,
|
||||
TLSClientConfig: buildOriginTLSConfig(originRequest),
|
||||
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
|
||||
return input, nil
|
||||
},
|
||||
}
|
||||
defer httpClient.CloseIdleConnections()
|
||||
return transport, func() {
|
||||
common.Close(input, output)
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
response, err := httpClient.Do(httpRequest)
|
||||
func (i *Inbound) newDirectOriginTransport(service ResolvedService) (*http.Transport, func(), error) {
|
||||
transport := &http.Transport{
|
||||
DisableCompression: true,
|
||||
ForceAttemptHTTP2: service.OriginRequest.HTTP2Origin,
|
||||
TLSHandshakeTimeout: service.OriginRequest.TLSTimeout,
|
||||
IdleConnTimeout: service.OriginRequest.KeepAliveTimeout,
|
||||
MaxIdleConns: service.OriginRequest.KeepAliveConnections,
|
||||
MaxIdleConnsPerHost: service.OriginRequest.KeepAliveConnections,
|
||||
TLSClientConfig: buildOriginTLSConfig(service.OriginRequest),
|
||||
}
|
||||
switch service.Kind {
|
||||
case ResolvedServiceUnix, ResolvedServiceUnixTLS:
|
||||
dialer := &net.Dialer{}
|
||||
transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
return dialer.DialContext(ctx, "unix", service.UnixPath)
|
||||
}
|
||||
case ResolvedServiceHelloWorld:
|
||||
dialer := &net.Dialer{}
|
||||
target := service.BaseURL.Host
|
||||
transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
return dialer.DialContext(ctx, "tcp", target)
|
||||
}
|
||||
default:
|
||||
return nil, nil, E.New("unsupported direct origin service")
|
||||
}
|
||||
return transport, func() {}, nil
|
||||
}
|
||||
|
||||
func buildOriginTLSConfig(originRequest OriginRequestConfig) *tls.Config {
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: originRequest.NoTLSVerify, //nolint:gosec
|
||||
ServerName: originRequest.OriginServerName,
|
||||
}
|
||||
if originRequest.CAPool == "" {
|
||||
return tlsConfig
|
||||
}
|
||||
pemData, err := os.ReadFile(originRequest.CAPool)
|
||||
if err != nil {
|
||||
<-done
|
||||
i.logger.ErrorContext(ctx, "WebSocket request: ", E.Errors(innerError, err))
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
return tlsConfig
|
||||
}
|
||||
pool := x509.NewCertPool()
|
||||
if pool.AppendCertsFromPEM(pemData) {
|
||||
tlsConfig.RootCAs = pool
|
||||
}
|
||||
return tlsConfig
|
||||
}
|
||||
|
||||
func applyOriginRequest(request *http.Request, originRequest OriginRequestConfig) *http.Request {
|
||||
request = request.Clone(request.Context())
|
||||
if originRequest.HTTPHostHeader != "" {
|
||||
request.Header.Set("X-Forwarded-Host", request.Host)
|
||||
request.Host = originRequest.HTTPHostHeader
|
||||
}
|
||||
if originRequest.DisableChunkedEncoding && request.Header.Get("Content-Length") != "" {
|
||||
if contentLength, err := strconv.ParseInt(request.Header.Get("Content-Length"), 10, 64); err == nil {
|
||||
request.ContentLength = contentLength
|
||||
request.TransferEncoding = nil
|
||||
}
|
||||
}
|
||||
return request
|
||||
}
|
||||
|
||||
func bidirectionalCopy(left, right io.ReadWriteCloser) {
|
||||
var closeOnce sync.Once
|
||||
closeBoth := func() {
|
||||
closeOnce.Do(func() {
|
||||
common.Close(left, right)
|
||||
})
|
||||
}
|
||||
|
||||
responseMetadata := encodeResponseHeaders(response.StatusCode, response.Header)
|
||||
err = respWriter.WriteResponse(nil, responseMetadata)
|
||||
if err != nil {
|
||||
response.Body.Close()
|
||||
i.logger.ErrorContext(ctx, "write WebSocket response headers: ", err)
|
||||
<-done
|
||||
return
|
||||
}
|
||||
|
||||
_, err = io.Copy(stream, response.Body)
|
||||
response.Body.Close()
|
||||
common.Close(input, output)
|
||||
if err != nil && !E.IsClosedOrCanceled(err) {
|
||||
i.logger.DebugContext(ctx, "copy WebSocket response body: ", err)
|
||||
}
|
||||
done := make(chan struct{}, 2)
|
||||
go func() {
|
||||
io.Copy(left, right)
|
||||
closeBoth()
|
||||
done <- struct{}{}
|
||||
}()
|
||||
go func() {
|
||||
io.Copy(right, left)
|
||||
closeBoth()
|
||||
done <- struct{}{}
|
||||
}()
|
||||
<-done
|
||||
<-done
|
||||
}
|
||||
|
||||
|
||||
@@ -142,6 +142,11 @@ func newTestInbound(t *testing.T, token string, protocol string, haConnections i
|
||||
t.Fatal("create logger: ", err)
|
||||
}
|
||||
|
||||
configManager, err := NewConfigManager(option.CloudflareTunnelInboundOptions{})
|
||||
if err != nil {
|
||||
t.Fatal("create config manager: ", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
inboundInstance := &Inbound{
|
||||
Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"),
|
||||
@@ -156,6 +161,7 @@ func newTestInbound(t *testing.T, token string, protocol string, haConnections i
|
||||
edgeIPVersion: 0,
|
||||
datagramVersion: "",
|
||||
gracePeriod: 5 * time.Second,
|
||||
configManager: configManager,
|
||||
datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer),
|
||||
datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer),
|
||||
}
|
||||
|
||||
@@ -7,9 +7,10 @@ import (
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -30,17 +31,19 @@ func RegisterInbound(registry *inbound.Registry) {
|
||||
|
||||
type Inbound struct {
|
||||
inbound.Adapter
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
router adapter.ConnectionRouterEx
|
||||
logger log.ContextLogger
|
||||
credentials Credentials
|
||||
connectorID uuid.UUID
|
||||
haConnections int
|
||||
protocol string
|
||||
edgeIPVersion int
|
||||
datagramVersion string
|
||||
gracePeriod time.Duration
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
router adapter.ConnectionRouterEx
|
||||
logger log.ContextLogger
|
||||
credentials Credentials
|
||||
connectorID uuid.UUID
|
||||
haConnections int
|
||||
protocol string
|
||||
region string
|
||||
edgeIPVersion int
|
||||
datagramVersion string
|
||||
gracePeriod time.Duration
|
||||
configManager *ConfigManager
|
||||
|
||||
connectionAccess sync.Mutex
|
||||
connections []io.Closer
|
||||
@@ -50,101 +53,9 @@ type Inbound struct {
|
||||
datagramV2Muxers map[DatagramSender]*DatagramV2Muxer
|
||||
datagramV3Muxers map[DatagramSender]*DatagramV3Muxer
|
||||
|
||||
ingressAccess sync.RWMutex
|
||||
ingressVersion int32
|
||||
ingressRules []IngressRule
|
||||
}
|
||||
|
||||
// IngressRule maps a hostname pattern to an origin service URL.
|
||||
type IngressRule struct {
|
||||
Hostname string
|
||||
Service string
|
||||
}
|
||||
|
||||
type ingressConfig struct {
|
||||
Ingress []ingressConfigRule `json:"ingress"`
|
||||
}
|
||||
|
||||
type ingressConfigRule struct {
|
||||
Hostname string `json:"hostname,omitempty"`
|
||||
Service string `json:"service"`
|
||||
}
|
||||
|
||||
// UpdateIngress applies a new ingress configuration from the edge.
|
||||
func (i *Inbound) UpdateIngress(version int32, config []byte) {
|
||||
i.ingressAccess.Lock()
|
||||
defer i.ingressAccess.Unlock()
|
||||
|
||||
if version <= i.ingressVersion {
|
||||
return
|
||||
}
|
||||
|
||||
var parsed ingressConfig
|
||||
err := json.Unmarshal(config, &parsed)
|
||||
if err != nil {
|
||||
i.logger.Error("parse ingress config: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
rules := make([]IngressRule, 0, len(parsed.Ingress))
|
||||
for _, rule := range parsed.Ingress {
|
||||
rules = append(rules, IngressRule{
|
||||
Hostname: rule.Hostname,
|
||||
Service: rule.Service,
|
||||
})
|
||||
}
|
||||
i.ingressRules = rules
|
||||
i.ingressVersion = version
|
||||
i.logger.Info("updated ingress configuration (version ", version, ", ", len(rules), " rules)")
|
||||
}
|
||||
|
||||
// ResolveOrigin finds the origin service URL for a given hostname.
|
||||
// Returns the service URL if matched, or empty string if no match.
|
||||
func (i *Inbound) ResolveOrigin(hostname string) string {
|
||||
i.ingressAccess.RLock()
|
||||
defer i.ingressAccess.RUnlock()
|
||||
|
||||
for _, rule := range i.ingressRules {
|
||||
if rule.Hostname == "" {
|
||||
return rule.Service
|
||||
}
|
||||
if matchIngress(rule.Hostname, hostname) {
|
||||
return rule.Service
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func matchIngress(pattern, hostname string) bool {
|
||||
if pattern == hostname {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(pattern, "*.") {
|
||||
suffix := pattern[1:]
|
||||
return strings.HasSuffix(hostname, suffix)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ResolveOriginURL rewrites a request URL to point to the origin service.
|
||||
// For example, https://testbox.badnet.work/path → http://127.0.0.1:8083/path
|
||||
func (i *Inbound) ResolveOriginURL(requestURL string) string {
|
||||
parsed, err := url.Parse(requestURL)
|
||||
if err != nil {
|
||||
return requestURL
|
||||
}
|
||||
hostname := parsed.Hostname()
|
||||
origin := i.ResolveOrigin(hostname)
|
||||
if origin == "" || strings.HasPrefix(origin, "http_status:") {
|
||||
return requestURL
|
||||
}
|
||||
originURL, err := url.Parse(origin)
|
||||
if err != nil {
|
||||
return requestURL
|
||||
}
|
||||
parsed.Scheme = originURL.Scheme
|
||||
parsed.Host = originURL.Host
|
||||
return parsed.String()
|
||||
helloWorldAccess sync.Mutex
|
||||
helloWorldServer *http.Server
|
||||
helloWorldURL *url.URL
|
||||
}
|
||||
|
||||
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.CloudflareTunnelInboundOptions) (adapter.Inbound, error) {
|
||||
@@ -178,23 +89,30 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
||||
gracePeriod = 30 * time.Second
|
||||
}
|
||||
|
||||
configManager, err := NewConfigManager(options)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "build cloudflare tunnel runtime config")
|
||||
}
|
||||
|
||||
inboundCtx, cancel := context.WithCancel(ctx)
|
||||
|
||||
return &Inbound{
|
||||
Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, tag),
|
||||
ctx: inboundCtx,
|
||||
cancel: cancel,
|
||||
router: router,
|
||||
logger: logger,
|
||||
credentials: credentials,
|
||||
connectorID: uuid.New(),
|
||||
haConnections: haConnections,
|
||||
protocol: protocol,
|
||||
edgeIPVersion: edgeIPVersion,
|
||||
datagramVersion: datagramVersion,
|
||||
gracePeriod: gracePeriod,
|
||||
datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer),
|
||||
datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer),
|
||||
Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, tag),
|
||||
ctx: inboundCtx,
|
||||
cancel: cancel,
|
||||
router: router,
|
||||
logger: logger,
|
||||
credentials: credentials,
|
||||
connectorID: uuid.New(),
|
||||
haConnections: haConnections,
|
||||
protocol: protocol,
|
||||
region: options.Region,
|
||||
edgeIPVersion: edgeIPVersion,
|
||||
datagramVersion: datagramVersion,
|
||||
gracePeriod: gracePeriod,
|
||||
configManager: configManager,
|
||||
datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer),
|
||||
datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -238,6 +156,16 @@ func (i *Inbound) Start(stage adapter.StartStage) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *Inbound) ApplyConfig(version int32, config []byte) ConfigUpdateResult {
|
||||
result := i.configManager.Apply(version, config)
|
||||
if result.Err != nil {
|
||||
i.logger.Error("update ingress configuration: ", result.Err)
|
||||
return result
|
||||
}
|
||||
i.logger.Info("updated ingress configuration (version ", result.LastAppliedVersion, ")")
|
||||
return result
|
||||
}
|
||||
|
||||
func (i *Inbound) Close() error {
|
||||
i.cancel()
|
||||
i.done.Wait()
|
||||
@@ -247,9 +175,41 @@ func (i *Inbound) Close() error {
|
||||
}
|
||||
i.connections = nil
|
||||
i.connectionAccess.Unlock()
|
||||
if i.helloWorldServer != nil {
|
||||
i.helloWorldServer.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *Inbound) ensureHelloWorldURL() (*url.URL, error) {
|
||||
i.helloWorldAccess.Lock()
|
||||
defer i.helloWorldAccess.Unlock()
|
||||
if i.helloWorldURL != nil {
|
||||
return i.helloWorldURL, nil
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) {
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
_, _ = writer.Write([]byte("Hello World"))
|
||||
})
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "listen hello world server")
|
||||
}
|
||||
server := &http.Server{Handler: mux}
|
||||
go server.Serve(listener)
|
||||
|
||||
i.helloWorldServer = server
|
||||
i.helloWorldURL = &url.URL{
|
||||
Scheme: "http",
|
||||
Host: listener.Addr().String(),
|
||||
}
|
||||
return i.helloWorldURL, nil
|
||||
}
|
||||
|
||||
const (
|
||||
backoffBaseTime = time.Second
|
||||
backoffMaxTime = 2 * time.Minute
|
||||
|
||||
@@ -6,143 +6,148 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
)
|
||||
|
||||
func newTestIngressInbound() *Inbound {
|
||||
return &Inbound{logger: log.NewNOPFactory().NewLogger("test")}
|
||||
func newTestIngressInbound(t *testing.T) *Inbound {
|
||||
t.Helper()
|
||||
configManager, err := NewConfigManager(option.CloudflareTunnelInboundOptions{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return &Inbound{
|
||||
logger: log.NewNOPFactory().NewLogger("test"),
|
||||
configManager: configManager,
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateIngress(t *testing.T) {
|
||||
inboundInstance := newTestIngressInbound()
|
||||
func mustResolvedService(t *testing.T, rawService string) ResolvedService {
|
||||
t.Helper()
|
||||
service, err := parseResolvedService(rawService, defaultOriginRequestConfig())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return service
|
||||
}
|
||||
|
||||
func TestApplyConfig(t *testing.T) {
|
||||
inboundInstance := newTestIngressInbound(t)
|
||||
|
||||
config1 := []byte(`{"ingress":[{"hostname":"a.com","service":"http://localhost:80"},{"hostname":"b.com","service":"http://localhost:81"},{"service":"http_status:404"}]}`)
|
||||
inboundInstance.UpdateIngress(1, config1)
|
||||
|
||||
inboundInstance.ingressAccess.RLock()
|
||||
count := len(inboundInstance.ingressRules)
|
||||
inboundInstance.ingressAccess.RUnlock()
|
||||
if count != 3 {
|
||||
t.Fatalf("expected 3 rules, got %d", count)
|
||||
result := inboundInstance.ApplyConfig(1, config1)
|
||||
if result.Err != nil {
|
||||
t.Fatal(result.Err)
|
||||
}
|
||||
if result.LastAppliedVersion != 1 {
|
||||
t.Fatalf("expected version 1, got %d", result.LastAppliedVersion)
|
||||
}
|
||||
|
||||
inboundInstance.UpdateIngress(1, []byte(`{"ingress":[{"service":"http_status:503"}]}`))
|
||||
inboundInstance.ingressAccess.RLock()
|
||||
count = len(inboundInstance.ingressRules)
|
||||
inboundInstance.ingressAccess.RUnlock()
|
||||
if count != 3 {
|
||||
t.Error("version 1 re-apply should not change rules, got ", count)
|
||||
service, loaded := inboundInstance.configManager.Resolve("a.com", "/")
|
||||
if !loaded || service.Service != "http://localhost:80" {
|
||||
t.Fatalf("expected a.com to resolve to localhost:80, got %#v, loaded=%v", service, loaded)
|
||||
}
|
||||
|
||||
inboundInstance.UpdateIngress(2, []byte(`{"ingress":[{"service":"http_status:503"}]}`))
|
||||
inboundInstance.ingressAccess.RLock()
|
||||
count = len(inboundInstance.ingressRules)
|
||||
inboundInstance.ingressAccess.RUnlock()
|
||||
if count != 1 {
|
||||
t.Error("version 2 should update to 1 rule, got ", count)
|
||||
result = inboundInstance.ApplyConfig(1, []byte(`{"ingress":[{"service":"http_status:503"}]}`))
|
||||
if result.Err != nil {
|
||||
t.Fatal(result.Err)
|
||||
}
|
||||
if result.LastAppliedVersion != 1 {
|
||||
t.Fatalf("same version should keep current version, got %d", result.LastAppliedVersion)
|
||||
}
|
||||
|
||||
service, loaded = inboundInstance.configManager.Resolve("b.com", "/")
|
||||
if !loaded || service.Service != "http://localhost:81" {
|
||||
t.Fatalf("expected old rules to remain, got %#v, loaded=%v", service, loaded)
|
||||
}
|
||||
|
||||
result = inboundInstance.ApplyConfig(2, []byte(`{"ingress":[{"service":"http_status:503"}]}`))
|
||||
if result.Err != nil {
|
||||
t.Fatal(result.Err)
|
||||
}
|
||||
if result.LastAppliedVersion != 2 {
|
||||
t.Fatalf("expected version 2, got %d", result.LastAppliedVersion)
|
||||
}
|
||||
|
||||
service, loaded = inboundInstance.configManager.Resolve("anything.com", "/")
|
||||
if !loaded || service.StatusCode != 503 {
|
||||
t.Fatalf("expected catch-all status 503, got %#v, loaded=%v", service, loaded)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateIngressInvalidJSON(t *testing.T) {
|
||||
inboundInstance := newTestIngressInbound()
|
||||
inboundInstance.UpdateIngress(1, []byte("not json"))
|
||||
|
||||
inboundInstance.ingressAccess.RLock()
|
||||
count := len(inboundInstance.ingressRules)
|
||||
inboundInstance.ingressAccess.RUnlock()
|
||||
if count != 0 {
|
||||
t.Error("invalid JSON should leave rules empty, got ", count)
|
||||
func TestApplyConfigInvalidJSON(t *testing.T) {
|
||||
inboundInstance := newTestIngressInbound(t)
|
||||
result := inboundInstance.ApplyConfig(1, []byte("not json"))
|
||||
if result.Err == nil {
|
||||
t.Fatal("expected parse error")
|
||||
}
|
||||
if result.LastAppliedVersion != -1 {
|
||||
t.Fatalf("expected version to stay -1, got %d", result.LastAppliedVersion)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveOriginExact(t *testing.T) {
|
||||
inboundInstance := newTestIngressInbound()
|
||||
inboundInstance.ingressRules = []IngressRule{
|
||||
{Hostname: "test.example.com", Service: "http://localhost:8080"},
|
||||
{Hostname: "", Service: "http_status:404"},
|
||||
func TestResolveExactAndWildcard(t *testing.T) {
|
||||
inboundInstance := newTestIngressInbound(t)
|
||||
inboundInstance.configManager.activeConfig = RuntimeConfig{
|
||||
Ingress: []compiledIngressRule{
|
||||
{Hostname: "test.example.com", Service: mustResolvedService(t, "http://localhost:8080")},
|
||||
{Hostname: "*.example.com", Service: mustResolvedService(t, "http://localhost:9090")},
|
||||
{Service: mustResolvedService(t, "http_status:404")},
|
||||
},
|
||||
}
|
||||
|
||||
result := inboundInstance.ResolveOrigin("test.example.com")
|
||||
if result != "http://localhost:8080" {
|
||||
t.Error("expected http://localhost:8080, got ", result)
|
||||
service, loaded := inboundInstance.configManager.Resolve("test.example.com", "/")
|
||||
if !loaded || service.Service != "http://localhost:8080" {
|
||||
t.Fatalf("expected exact match, got %#v, loaded=%v", service, loaded)
|
||||
}
|
||||
|
||||
service, loaded = inboundInstance.configManager.Resolve("sub.example.com", "/")
|
||||
if !loaded || service.Service != "http://localhost:9090" {
|
||||
t.Fatalf("expected wildcard match, got %#v, loaded=%v", service, loaded)
|
||||
}
|
||||
|
||||
service, loaded = inboundInstance.configManager.Resolve("unknown.test", "/")
|
||||
if !loaded || service.StatusCode != 404 {
|
||||
t.Fatalf("expected catch-all 404, got %#v, loaded=%v", service, loaded)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveOriginWildcard(t *testing.T) {
|
||||
inboundInstance := newTestIngressInbound()
|
||||
inboundInstance.ingressRules = []IngressRule{
|
||||
{Hostname: "*.example.com", Service: "http://localhost:9090"},
|
||||
func TestResolveHTTPService(t *testing.T) {
|
||||
inboundInstance := newTestIngressInbound(t)
|
||||
inboundInstance.configManager.activeConfig = RuntimeConfig{
|
||||
Ingress: []compiledIngressRule{
|
||||
{Hostname: "foo.com", Service: mustResolvedService(t, "http://127.0.0.1:8083")},
|
||||
{Service: mustResolvedService(t, "http_status:404")},
|
||||
},
|
||||
}
|
||||
|
||||
result := inboundInstance.ResolveOrigin("sub.example.com")
|
||||
if result != "http://localhost:9090" {
|
||||
t.Error("wildcard should match sub.example.com, got ", result)
|
||||
service, requestURL, err := inboundInstance.resolveHTTPService("https://foo.com/path?q=1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
result = inboundInstance.ResolveOrigin("example.com")
|
||||
if result != "" {
|
||||
t.Error("wildcard should not match bare example.com, got ", result)
|
||||
if service.Destination.String() != "127.0.0.1:8083" {
|
||||
t.Fatalf("expected destination 127.0.0.1:8083, got %s", service.Destination)
|
||||
}
|
||||
if requestURL != "http://127.0.0.1:8083/path?q=1" {
|
||||
t.Fatalf("expected rewritten URL, got %s", requestURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveOriginCatchAll(t *testing.T) {
|
||||
inboundInstance := newTestIngressInbound()
|
||||
inboundInstance.ingressRules = []IngressRule{
|
||||
{Hostname: "specific.com", Service: "http://localhost:1"},
|
||||
{Hostname: "", Service: "http://localhost:2"},
|
||||
func TestResolveHTTPServiceStatus(t *testing.T) {
|
||||
inboundInstance := newTestIngressInbound(t)
|
||||
inboundInstance.configManager.activeConfig = RuntimeConfig{
|
||||
Ingress: []compiledIngressRule{
|
||||
{Service: mustResolvedService(t, "http_status:404")},
|
||||
},
|
||||
}
|
||||
|
||||
result := inboundInstance.ResolveOrigin("anything.com")
|
||||
if result != "http://localhost:2" {
|
||||
t.Error("catch-all should match, got ", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveOriginNoMatch(t *testing.T) {
|
||||
inboundInstance := newTestIngressInbound()
|
||||
inboundInstance.ingressRules = []IngressRule{
|
||||
{Hostname: "specific.com", Service: "http://localhost:1"},
|
||||
}
|
||||
|
||||
result := inboundInstance.ResolveOrigin("other.com")
|
||||
if result != "" {
|
||||
t.Error("expected empty for no match, got ", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveOriginURLRewrite(t *testing.T) {
|
||||
inboundInstance := newTestIngressInbound()
|
||||
inboundInstance.ingressRules = []IngressRule{
|
||||
{Hostname: "foo.com", Service: "http://127.0.0.1:8083"},
|
||||
}
|
||||
|
||||
result := inboundInstance.ResolveOriginURL("https://foo.com/path?q=1")
|
||||
if result != "http://127.0.0.1:8083/path?q=1" {
|
||||
t.Error("expected http://127.0.0.1:8083/path?q=1, got ", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveOriginURLNoMatch(t *testing.T) {
|
||||
inboundInstance := newTestIngressInbound()
|
||||
inboundInstance.ingressRules = []IngressRule{
|
||||
{Hostname: "other.com", Service: "http://localhost:1"},
|
||||
}
|
||||
|
||||
original := "https://unknown.com/page"
|
||||
result := inboundInstance.ResolveOriginURL(original)
|
||||
if result != original {
|
||||
t.Error("no match should return original, got ", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveOriginURLHTTPStatus(t *testing.T) {
|
||||
inboundInstance := newTestIngressInbound()
|
||||
inboundInstance.ingressRules = []IngressRule{
|
||||
{Hostname: "", Service: "http_status:404"},
|
||||
}
|
||||
|
||||
original := "https://any.com/page"
|
||||
result := inboundInstance.ResolveOriginURL(original)
|
||||
if result != original {
|
||||
t.Error("http_status service should return original, got ", result)
|
||||
service, requestURL, err := inboundInstance.resolveHTTPService("https://any.com/path")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if service.StatusCode != 404 {
|
||||
t.Fatalf("expected status 404, got %#v", service)
|
||||
}
|
||||
if requestURL != "https://any.com/path" {
|
||||
t.Fatalf("status service should keep request URL, got %s", requestURL)
|
||||
}
|
||||
}
|
||||
|
||||
803
protocol/cloudflare/runtime_config.go
Normal file
803
protocol/cloudflare/runtime_config.go
Normal file
@@ -0,0 +1,803 @@
|
||||
//go:build with_cloudflare_tunnel
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
||||
"golang.org/x/net/idna"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultHTTPConnectTimeout = 30 * time.Second
|
||||
defaultTLSTimeout = 10 * time.Second
|
||||
defaultTCPKeepAlive = 30 * time.Second
|
||||
defaultKeepAliveTimeout = 90 * time.Second
|
||||
defaultKeepAliveConnections = 100
|
||||
defaultProxyAddress = "127.0.0.1"
|
||||
defaultWarpRoutingConnectTime = 5 * time.Second
|
||||
defaultWarpRoutingTCPKeepAlive = 30 * time.Second
|
||||
)
|
||||
|
||||
type ResolvedServiceKind int
|
||||
|
||||
const (
|
||||
ResolvedServiceHTTP ResolvedServiceKind = iota
|
||||
ResolvedServiceStream
|
||||
ResolvedServiceStatus
|
||||
ResolvedServiceHelloWorld
|
||||
ResolvedServiceUnix
|
||||
ResolvedServiceUnixTLS
|
||||
)
|
||||
|
||||
type ResolvedService struct {
|
||||
Kind ResolvedServiceKind
|
||||
Service string
|
||||
Destination M.Socksaddr
|
||||
BaseURL *url.URL
|
||||
UnixPath string
|
||||
StatusCode int
|
||||
OriginRequest OriginRequestConfig
|
||||
}
|
||||
|
||||
func (s ResolvedService) RouterControlled() bool {
|
||||
return s.Kind == ResolvedServiceHTTP || s.Kind == ResolvedServiceStream
|
||||
}
|
||||
|
||||
func (s ResolvedService) BuildRequestURL(requestURL string) (string, error) {
|
||||
switch s.Kind {
|
||||
case ResolvedServiceHTTP, ResolvedServiceUnix, ResolvedServiceUnixTLS:
|
||||
requestParsed, err := url.Parse(requestURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
originURL := *s.BaseURL
|
||||
originURL.Path = requestParsed.Path
|
||||
originURL.RawPath = requestParsed.RawPath
|
||||
originURL.RawQuery = requestParsed.RawQuery
|
||||
originURL.Fragment = requestParsed.Fragment
|
||||
return originURL.String(), nil
|
||||
case ResolvedServiceHelloWorld:
|
||||
if s.BaseURL == nil {
|
||||
return "", E.New("hello world service is unavailable")
|
||||
}
|
||||
requestParsed, err := url.Parse(requestURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
originURL := *s.BaseURL
|
||||
originURL.Path = requestParsed.Path
|
||||
originURL.RawPath = requestParsed.RawPath
|
||||
originURL.RawQuery = requestParsed.RawQuery
|
||||
originURL.Fragment = requestParsed.Fragment
|
||||
return originURL.String(), nil
|
||||
default:
|
||||
return requestURL, nil
|
||||
}
|
||||
}
|
||||
|
||||
type compiledIngressRule struct {
|
||||
Hostname string
|
||||
PunycodeHostname string
|
||||
Path *regexp.Regexp
|
||||
Service ResolvedService
|
||||
}
|
||||
|
||||
type RuntimeConfig struct {
|
||||
Ingress []compiledIngressRule
|
||||
OriginRequest OriginRequestConfig
|
||||
WarpRouting WarpRoutingConfig
|
||||
}
|
||||
|
||||
type OriginRequestConfig struct {
|
||||
ConnectTimeout time.Duration
|
||||
TLSTimeout time.Duration
|
||||
TCPKeepAlive time.Duration
|
||||
NoHappyEyeballs bool
|
||||
KeepAliveTimeout time.Duration
|
||||
KeepAliveConnections int
|
||||
HTTPHostHeader string
|
||||
OriginServerName string
|
||||
MatchSNIToHost bool
|
||||
CAPool string
|
||||
NoTLSVerify bool
|
||||
DisableChunkedEncoding bool
|
||||
BastionMode bool
|
||||
ProxyAddress string
|
||||
ProxyPort uint
|
||||
ProxyType string
|
||||
IPRules []IPRule
|
||||
HTTP2Origin bool
|
||||
Access AccessConfig
|
||||
}
|
||||
|
||||
type AccessConfig struct {
|
||||
Required bool
|
||||
TeamName string
|
||||
AudTag []string
|
||||
Environment string
|
||||
}
|
||||
|
||||
type IPRule struct {
|
||||
Prefix string
|
||||
Ports []int
|
||||
Allow bool
|
||||
}
|
||||
|
||||
type WarpRoutingConfig struct {
|
||||
ConnectTimeout time.Duration
|
||||
MaxActiveFlows uint64
|
||||
TCPKeepAlive time.Duration
|
||||
}
|
||||
|
||||
type ConfigUpdateResult struct {
|
||||
LastAppliedVersion int32
|
||||
Err error
|
||||
}
|
||||
|
||||
type ConfigManager struct {
|
||||
access sync.RWMutex
|
||||
currentVersion int32
|
||||
activeConfig RuntimeConfig
|
||||
}
|
||||
|
||||
func NewConfigManager(options option.CloudflareTunnelInboundOptions) (*ConfigManager, error) {
|
||||
config, err := buildLocalRuntimeConfig(options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ConfigManager{
|
||||
currentVersion: -1,
|
||||
activeConfig: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *ConfigManager) Snapshot() RuntimeConfig {
|
||||
m.access.RLock()
|
||||
defer m.access.RUnlock()
|
||||
return m.activeConfig
|
||||
}
|
||||
|
||||
func (m *ConfigManager) CurrentVersion() int32 {
|
||||
m.access.RLock()
|
||||
defer m.access.RUnlock()
|
||||
return m.currentVersion
|
||||
}
|
||||
|
||||
func (m *ConfigManager) Apply(version int32, raw []byte) ConfigUpdateResult {
|
||||
m.access.Lock()
|
||||
defer m.access.Unlock()
|
||||
|
||||
if version <= m.currentVersion {
|
||||
return ConfigUpdateResult{LastAppliedVersion: m.currentVersion}
|
||||
}
|
||||
|
||||
config, err := buildRemoteRuntimeConfig(raw)
|
||||
if err != nil {
|
||||
return ConfigUpdateResult{
|
||||
LastAppliedVersion: m.currentVersion,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
m.activeConfig = config
|
||||
m.currentVersion = version
|
||||
return ConfigUpdateResult{LastAppliedVersion: m.currentVersion}
|
||||
}
|
||||
|
||||
func (m *ConfigManager) Resolve(hostname, path string) (ResolvedService, bool) {
|
||||
m.access.RLock()
|
||||
defer m.access.RUnlock()
|
||||
return m.activeConfig.Resolve(hostname, path)
|
||||
}
|
||||
|
||||
func (c RuntimeConfig) Resolve(hostname, path string) (ResolvedService, bool) {
|
||||
host := stripPort(hostname)
|
||||
for _, rule := range c.Ingress {
|
||||
if !matchIngressRule(rule, host, path) {
|
||||
continue
|
||||
}
|
||||
return rule.Service, true
|
||||
}
|
||||
return ResolvedService{}, false
|
||||
}
|
||||
|
||||
func matchIngressRule(rule compiledIngressRule, hostname, path string) bool {
|
||||
hostMatch := rule.Hostname == "" || rule.Hostname == "*" || matchIngressHost(rule.Hostname, hostname)
|
||||
if !hostMatch && rule.PunycodeHostname != "" {
|
||||
hostMatch = matchIngressHost(rule.PunycodeHostname, hostname)
|
||||
}
|
||||
if !hostMatch {
|
||||
return false
|
||||
}
|
||||
return rule.Path == nil || rule.Path.MatchString(path)
|
||||
}
|
||||
|
||||
func matchIngressHost(pattern, hostname string) bool {
|
||||
if pattern == hostname {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(pattern, "*.") {
|
||||
return strings.HasSuffix(hostname, strings.TrimPrefix(pattern, "*"))
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func buildLocalRuntimeConfig(options option.CloudflareTunnelInboundOptions) (RuntimeConfig, error) {
|
||||
defaultOriginRequest := originRequestFromOption(options.OriginRequest)
|
||||
warpRouting := warpRoutingFromOption(options.WarpRouting)
|
||||
var ingressRules []localIngressRule
|
||||
for _, rule := range options.Ingress {
|
||||
ingressRules = append(ingressRules, localIngressRule{
|
||||
Hostname: rule.Hostname,
|
||||
Path: rule.Path,
|
||||
Service: rule.Service,
|
||||
OriginRequest: mergeOptionOriginRequest(defaultOriginRequest, rule.OriginRequest),
|
||||
})
|
||||
}
|
||||
compiledRules, err := compileIngressRules(defaultOriginRequest, ingressRules)
|
||||
if err != nil {
|
||||
return RuntimeConfig{}, err
|
||||
}
|
||||
return RuntimeConfig{
|
||||
Ingress: compiledRules,
|
||||
OriginRequest: defaultOriginRequest,
|
||||
WarpRouting: warpRouting,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildRemoteRuntimeConfig(raw []byte) (RuntimeConfig, error) {
|
||||
var remote remoteConfigJSON
|
||||
if err := json.Unmarshal(raw, &remote); err != nil {
|
||||
return RuntimeConfig{}, E.Cause(err, "decode remote config")
|
||||
}
|
||||
defaultOriginRequest := originRequestFromRemote(remote.OriginRequest)
|
||||
warpRouting := warpRoutingFromRemote(remote.WarpRouting)
|
||||
var ingressRules []localIngressRule
|
||||
for _, rule := range remote.Ingress {
|
||||
ingressRules = append(ingressRules, localIngressRule{
|
||||
Hostname: rule.Hostname,
|
||||
Path: rule.Path,
|
||||
Service: rule.Service,
|
||||
OriginRequest: mergeRemoteOriginRequest(defaultOriginRequest, rule.OriginRequest),
|
||||
})
|
||||
}
|
||||
compiledRules, err := compileIngressRules(defaultOriginRequest, ingressRules)
|
||||
if err != nil {
|
||||
return RuntimeConfig{}, err
|
||||
}
|
||||
return RuntimeConfig{
|
||||
Ingress: compiledRules,
|
||||
OriginRequest: defaultOriginRequest,
|
||||
WarpRouting: warpRouting,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type localIngressRule struct {
|
||||
Hostname string
|
||||
Path string
|
||||
Service string
|
||||
OriginRequest OriginRequestConfig
|
||||
}
|
||||
|
||||
type remoteConfigJSON struct {
|
||||
OriginRequest remoteOriginRequestJSON `json:"originRequest"`
|
||||
Ingress []remoteIngressRuleJSON `json:"ingress"`
|
||||
WarpRouting remoteWarpRoutingJSON `json:"warp-routing"`
|
||||
}
|
||||
|
||||
type remoteIngressRuleJSON struct {
|
||||
Hostname string `json:"hostname,omitempty"`
|
||||
Path string `json:"path,omitempty"`
|
||||
Service string `json:"service"`
|
||||
OriginRequest remoteOriginRequestJSON `json:"originRequest,omitempty"`
|
||||
}
|
||||
|
||||
type remoteOriginRequestJSON struct {
|
||||
ConnectTimeout int64 `json:"connectTimeout,omitempty"`
|
||||
TLSTimeout int64 `json:"tlsTimeout,omitempty"`
|
||||
TCPKeepAlive int64 `json:"tcpKeepAlive,omitempty"`
|
||||
NoHappyEyeballs *bool `json:"noHappyEyeballs,omitempty"`
|
||||
KeepAliveTimeout int64 `json:"keepAliveTimeout,omitempty"`
|
||||
KeepAliveConnections *int `json:"keepAliveConnections,omitempty"`
|
||||
HTTPHostHeader string `json:"httpHostHeader,omitempty"`
|
||||
OriginServerName string `json:"originServerName,omitempty"`
|
||||
MatchSNIToHost *bool `json:"matchSNIToHost,omitempty"`
|
||||
CAPool string `json:"caPool,omitempty"`
|
||||
NoTLSVerify *bool `json:"noTLSVerify,omitempty"`
|
||||
DisableChunkedEncoding *bool `json:"disableChunkedEncoding,omitempty"`
|
||||
BastionMode *bool `json:"bastionMode,omitempty"`
|
||||
ProxyAddress string `json:"proxyAddress,omitempty"`
|
||||
ProxyPort *uint `json:"proxyPort,omitempty"`
|
||||
ProxyType string `json:"proxyType,omitempty"`
|
||||
IPRules []remoteIPRuleJSON `json:"ipRules,omitempty"`
|
||||
HTTP2Origin *bool `json:"http2Origin,omitempty"`
|
||||
Access *remoteAccessJSON `json:"access,omitempty"`
|
||||
}
|
||||
|
||||
type remoteAccessJSON struct {
|
||||
Required bool `json:"required,omitempty"`
|
||||
TeamName string `json:"teamName,omitempty"`
|
||||
AudTag []string `json:"audTag,omitempty"`
|
||||
Environment string `json:"environment,omitempty"`
|
||||
}
|
||||
|
||||
type remoteIPRuleJSON struct {
|
||||
Prefix string `json:"prefix,omitempty"`
|
||||
Ports []int `json:"ports,omitempty"`
|
||||
Allow bool `json:"allow,omitempty"`
|
||||
}
|
||||
|
||||
type remoteWarpRoutingJSON struct {
|
||||
ConnectTimeout int64 `json:"connectTimeout,omitempty"`
|
||||
MaxActiveFlows uint64 `json:"maxActiveFlows,omitempty"`
|
||||
TCPKeepAlive int64 `json:"tcpKeepAlive,omitempty"`
|
||||
}
|
||||
|
||||
func compileIngressRules(defaultOriginRequest OriginRequestConfig, rawRules []localIngressRule) ([]compiledIngressRule, error) {
|
||||
if len(rawRules) == 0 {
|
||||
rawRules = []localIngressRule{{
|
||||
Service: "http_status:503",
|
||||
OriginRequest: defaultOriginRequest,
|
||||
}}
|
||||
}
|
||||
if !isCatchAllRule(rawRules[len(rawRules)-1].Hostname, rawRules[len(rawRules)-1].Path) {
|
||||
return nil, E.New("the last ingress rule must be a catch-all rule")
|
||||
}
|
||||
|
||||
compiled := make([]compiledIngressRule, 0, len(rawRules))
|
||||
for index, rule := range rawRules {
|
||||
if err := validateHostname(rule.Hostname, index == len(rawRules)-1); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
service, err := parseResolvedService(rule.Service, rule.OriginRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var pathPattern *regexp.Regexp
|
||||
if rule.Path != "" {
|
||||
pathPattern, err = regexp.Compile(rule.Path)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "compile ingress path regex")
|
||||
}
|
||||
}
|
||||
punycode := ""
|
||||
if rule.Hostname != "" && rule.Hostname != "*" {
|
||||
punycodeValue, err := idna.Lookup.ToASCII(rule.Hostname)
|
||||
if err == nil && punycodeValue != rule.Hostname {
|
||||
punycode = punycodeValue
|
||||
}
|
||||
}
|
||||
compiled = append(compiled, compiledIngressRule{
|
||||
Hostname: rule.Hostname,
|
||||
PunycodeHostname: punycode,
|
||||
Path: pathPattern,
|
||||
Service: service,
|
||||
})
|
||||
}
|
||||
return compiled, nil
|
||||
}
|
||||
|
||||
func parseResolvedService(rawService string, originRequest OriginRequestConfig) (ResolvedService, error) {
|
||||
switch {
|
||||
case rawService == "":
|
||||
return ResolvedService{}, E.New("missing ingress service")
|
||||
case strings.HasPrefix(rawService, "http_status:"):
|
||||
statusCode, err := strconv.Atoi(strings.TrimPrefix(rawService, "http_status:"))
|
||||
if err != nil {
|
||||
return ResolvedService{}, E.Cause(err, "parse http_status service")
|
||||
}
|
||||
if statusCode < 100 || statusCode > 999 {
|
||||
return ResolvedService{}, E.New("invalid http_status code: ", statusCode)
|
||||
}
|
||||
return ResolvedService{
|
||||
Kind: ResolvedServiceStatus,
|
||||
Service: rawService,
|
||||
StatusCode: statusCode,
|
||||
OriginRequest: originRequest,
|
||||
}, nil
|
||||
case rawService == "hello_world" || rawService == "hello-world":
|
||||
return ResolvedService{
|
||||
Kind: ResolvedServiceHelloWorld,
|
||||
Service: rawService,
|
||||
OriginRequest: originRequest,
|
||||
}, nil
|
||||
case strings.HasPrefix(rawService, "unix:"):
|
||||
return ResolvedService{
|
||||
Kind: ResolvedServiceUnix,
|
||||
Service: rawService,
|
||||
UnixPath: strings.TrimPrefix(rawService, "unix:"),
|
||||
BaseURL: &url.URL{Scheme: "http", Host: "localhost"},
|
||||
OriginRequest: originRequest,
|
||||
}, nil
|
||||
case strings.HasPrefix(rawService, "unix+tls:"):
|
||||
return ResolvedService{
|
||||
Kind: ResolvedServiceUnixTLS,
|
||||
Service: rawService,
|
||||
UnixPath: strings.TrimPrefix(rawService, "unix+tls:"),
|
||||
BaseURL: &url.URL{Scheme: "https", Host: "localhost"},
|
||||
OriginRequest: originRequest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(rawService)
|
||||
if err != nil {
|
||||
return ResolvedService{}, E.Cause(err, "parse ingress service URL")
|
||||
}
|
||||
if parsedURL.Scheme == "" || parsedURL.Hostname() == "" {
|
||||
return ResolvedService{}, E.New("ingress service must include scheme and hostname: ", rawService)
|
||||
}
|
||||
if parsedURL.Path != "" {
|
||||
return ResolvedService{}, E.New("ingress service cannot include a path: ", rawService)
|
||||
}
|
||||
|
||||
switch parsedURL.Scheme {
|
||||
case "http", "https", "ws", "wss":
|
||||
return ResolvedService{
|
||||
Kind: ResolvedServiceHTTP,
|
||||
Service: rawService,
|
||||
Destination: parseServiceDestination(parsedURL),
|
||||
BaseURL: parsedURL,
|
||||
OriginRequest: originRequest,
|
||||
}, nil
|
||||
case "tcp", "ssh", "rdp", "smb":
|
||||
return ResolvedService{
|
||||
Kind: ResolvedServiceStream,
|
||||
Service: rawService,
|
||||
Destination: parseServiceDestination(parsedURL),
|
||||
BaseURL: parsedURL,
|
||||
OriginRequest: originRequest,
|
||||
}, nil
|
||||
default:
|
||||
return ResolvedService{}, E.New("unsupported ingress service scheme: ", parsedURL.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
func parseServiceDestination(parsedURL *url.URL) M.Socksaddr {
|
||||
host := parsedURL.Hostname()
|
||||
port := parsedURL.Port()
|
||||
if port == "" {
|
||||
switch parsedURL.Scheme {
|
||||
case "https", "wss":
|
||||
port = "443"
|
||||
case "ssh":
|
||||
port = "22"
|
||||
case "rdp":
|
||||
port = "3389"
|
||||
case "smb":
|
||||
port = "445"
|
||||
case "tcp":
|
||||
port = "7864"
|
||||
default:
|
||||
port = "80"
|
||||
}
|
||||
}
|
||||
return M.ParseSocksaddr(net.JoinHostPort(host, port))
|
||||
}
|
||||
|
||||
func validateHostname(hostname string, isLast bool) error {
|
||||
if hostname == "" || hostname == "*" {
|
||||
if !isLast {
|
||||
return E.New("only the last ingress rule may be a catch-all rule")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if strings.Count(hostname, "*") > 1 || (strings.Contains(hostname, "*") && !strings.HasPrefix(hostname, "*.")) {
|
||||
return E.New("hostname wildcard must be in the form *.example.com")
|
||||
}
|
||||
if stripPort(hostname) != hostname {
|
||||
return E.New("ingress hostname cannot contain a port")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isCatchAllRule(hostname, path string) bool {
|
||||
return (hostname == "" || hostname == "*") && path == ""
|
||||
}
|
||||
|
||||
func stripPort(hostname string) string {
|
||||
if host, _, err := net.SplitHostPort(hostname); err == nil {
|
||||
return host
|
||||
}
|
||||
return hostname
|
||||
}
|
||||
|
||||
func defaultOriginRequestConfig() OriginRequestConfig {
|
||||
return OriginRequestConfig{
|
||||
ConnectTimeout: defaultHTTPConnectTimeout,
|
||||
TLSTimeout: defaultTLSTimeout,
|
||||
TCPKeepAlive: defaultTCPKeepAlive,
|
||||
KeepAliveTimeout: defaultKeepAliveTimeout,
|
||||
KeepAliveConnections: defaultKeepAliveConnections,
|
||||
ProxyAddress: defaultProxyAddress,
|
||||
}
|
||||
}
|
||||
|
||||
func originRequestFromOption(input option.CloudflareTunnelOriginRequestOptions) OriginRequestConfig {
|
||||
config := defaultOriginRequestConfig()
|
||||
if input.ConnectTimeout != 0 {
|
||||
config.ConnectTimeout = time.Duration(input.ConnectTimeout)
|
||||
}
|
||||
if input.TLSTimeout != 0 {
|
||||
config.TLSTimeout = time.Duration(input.TLSTimeout)
|
||||
}
|
||||
if input.TCPKeepAlive != 0 {
|
||||
config.TCPKeepAlive = time.Duration(input.TCPKeepAlive)
|
||||
}
|
||||
if input.KeepAliveTimeout != 0 {
|
||||
config.KeepAliveTimeout = time.Duration(input.KeepAliveTimeout)
|
||||
}
|
||||
if input.KeepAliveConnections != 0 {
|
||||
config.KeepAliveConnections = input.KeepAliveConnections
|
||||
}
|
||||
config.NoHappyEyeballs = input.NoHappyEyeballs
|
||||
config.HTTPHostHeader = input.HTTPHostHeader
|
||||
config.OriginServerName = input.OriginServerName
|
||||
config.MatchSNIToHost = input.MatchSNIToHost
|
||||
config.CAPool = input.CAPool
|
||||
config.NoTLSVerify = input.NoTLSVerify
|
||||
config.DisableChunkedEncoding = input.DisableChunkedEncoding
|
||||
config.BastionMode = input.BastionMode
|
||||
if input.ProxyAddress != "" {
|
||||
config.ProxyAddress = input.ProxyAddress
|
||||
}
|
||||
if input.ProxyPort != 0 {
|
||||
config.ProxyPort = input.ProxyPort
|
||||
}
|
||||
config.ProxyType = input.ProxyType
|
||||
config.HTTP2Origin = input.HTTP2Origin
|
||||
config.Access = AccessConfig{
|
||||
Required: input.Access.Required,
|
||||
TeamName: input.Access.TeamName,
|
||||
AudTag: append([]string(nil), input.Access.AudTag...),
|
||||
Environment: input.Access.Environment,
|
||||
}
|
||||
for _, rule := range input.IPRules {
|
||||
config.IPRules = append(config.IPRules, IPRule{
|
||||
Prefix: rule.Prefix,
|
||||
Ports: append([]int(nil), rule.Ports...),
|
||||
Allow: rule.Allow,
|
||||
})
|
||||
}
|
||||
return config
|
||||
}
|
||||
|
||||
func mergeOptionOriginRequest(base OriginRequestConfig, override option.CloudflareTunnelOriginRequestOptions) OriginRequestConfig {
|
||||
result := base
|
||||
if override.ConnectTimeout != 0 {
|
||||
result.ConnectTimeout = time.Duration(override.ConnectTimeout)
|
||||
}
|
||||
if override.TLSTimeout != 0 {
|
||||
result.TLSTimeout = time.Duration(override.TLSTimeout)
|
||||
}
|
||||
if override.TCPKeepAlive != 0 {
|
||||
result.TCPKeepAlive = time.Duration(override.TCPKeepAlive)
|
||||
}
|
||||
if override.KeepAliveTimeout != 0 {
|
||||
result.KeepAliveTimeout = time.Duration(override.KeepAliveTimeout)
|
||||
}
|
||||
if override.KeepAliveConnections != 0 {
|
||||
result.KeepAliveConnections = override.KeepAliveConnections
|
||||
}
|
||||
result.NoHappyEyeballs = override.NoHappyEyeballs
|
||||
if override.HTTPHostHeader != "" {
|
||||
result.HTTPHostHeader = override.HTTPHostHeader
|
||||
}
|
||||
if override.OriginServerName != "" {
|
||||
result.OriginServerName = override.OriginServerName
|
||||
}
|
||||
result.MatchSNIToHost = override.MatchSNIToHost
|
||||
if override.CAPool != "" {
|
||||
result.CAPool = override.CAPool
|
||||
}
|
||||
result.NoTLSVerify = override.NoTLSVerify
|
||||
result.DisableChunkedEncoding = override.DisableChunkedEncoding
|
||||
result.BastionMode = override.BastionMode
|
||||
if override.ProxyAddress != "" {
|
||||
result.ProxyAddress = override.ProxyAddress
|
||||
}
|
||||
if override.ProxyPort != 0 {
|
||||
result.ProxyPort = override.ProxyPort
|
||||
}
|
||||
if override.ProxyType != "" {
|
||||
result.ProxyType = override.ProxyType
|
||||
}
|
||||
if len(override.IPRules) > 0 {
|
||||
result.IPRules = nil
|
||||
for _, rule := range override.IPRules {
|
||||
result.IPRules = append(result.IPRules, IPRule{
|
||||
Prefix: rule.Prefix,
|
||||
Ports: append([]int(nil), rule.Ports...),
|
||||
Allow: rule.Allow,
|
||||
})
|
||||
}
|
||||
}
|
||||
result.HTTP2Origin = override.HTTP2Origin
|
||||
if override.Access.Required || override.Access.TeamName != "" || len(override.Access.AudTag) > 0 || override.Access.Environment != "" {
|
||||
result.Access = AccessConfig{
|
||||
Required: override.Access.Required,
|
||||
TeamName: override.Access.TeamName,
|
||||
AudTag: append([]string(nil), override.Access.AudTag...),
|
||||
Environment: override.Access.Environment,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func originRequestFromRemote(input remoteOriginRequestJSON) OriginRequestConfig {
|
||||
config := defaultOriginRequestConfig()
|
||||
if input.ConnectTimeout != 0 {
|
||||
config.ConnectTimeout = time.Duration(input.ConnectTimeout) * time.Second
|
||||
}
|
||||
if input.TLSTimeout != 0 {
|
||||
config.TLSTimeout = time.Duration(input.TLSTimeout) * time.Second
|
||||
}
|
||||
if input.TCPKeepAlive != 0 {
|
||||
config.TCPKeepAlive = time.Duration(input.TCPKeepAlive) * time.Second
|
||||
}
|
||||
if input.KeepAliveTimeout != 0 {
|
||||
config.KeepAliveTimeout = time.Duration(input.KeepAliveTimeout) * time.Second
|
||||
}
|
||||
if input.KeepAliveConnections != nil {
|
||||
config.KeepAliveConnections = *input.KeepAliveConnections
|
||||
}
|
||||
if input.NoHappyEyeballs != nil {
|
||||
config.NoHappyEyeballs = *input.NoHappyEyeballs
|
||||
}
|
||||
config.HTTPHostHeader = input.HTTPHostHeader
|
||||
config.OriginServerName = input.OriginServerName
|
||||
if input.MatchSNIToHost != nil {
|
||||
config.MatchSNIToHost = *input.MatchSNIToHost
|
||||
}
|
||||
config.CAPool = input.CAPool
|
||||
if input.NoTLSVerify != nil {
|
||||
config.NoTLSVerify = *input.NoTLSVerify
|
||||
}
|
||||
if input.DisableChunkedEncoding != nil {
|
||||
config.DisableChunkedEncoding = *input.DisableChunkedEncoding
|
||||
}
|
||||
if input.BastionMode != nil {
|
||||
config.BastionMode = *input.BastionMode
|
||||
}
|
||||
if input.ProxyAddress != "" {
|
||||
config.ProxyAddress = input.ProxyAddress
|
||||
}
|
||||
if input.ProxyPort != nil {
|
||||
config.ProxyPort = *input.ProxyPort
|
||||
}
|
||||
config.ProxyType = input.ProxyType
|
||||
if input.HTTP2Origin != nil {
|
||||
config.HTTP2Origin = *input.HTTP2Origin
|
||||
}
|
||||
if input.Access != nil {
|
||||
config.Access = AccessConfig{
|
||||
Required: input.Access.Required,
|
||||
TeamName: input.Access.TeamName,
|
||||
AudTag: append([]string(nil), input.Access.AudTag...),
|
||||
Environment: input.Access.Environment,
|
||||
}
|
||||
}
|
||||
for _, rule := range input.IPRules {
|
||||
config.IPRules = append(config.IPRules, IPRule{
|
||||
Prefix: rule.Prefix,
|
||||
Ports: append([]int(nil), rule.Ports...),
|
||||
Allow: rule.Allow,
|
||||
})
|
||||
}
|
||||
return config
|
||||
}
|
||||
|
||||
func mergeRemoteOriginRequest(base OriginRequestConfig, override remoteOriginRequestJSON) OriginRequestConfig {
|
||||
result := base
|
||||
if override.ConnectTimeout != 0 {
|
||||
result.ConnectTimeout = time.Duration(override.ConnectTimeout) * time.Second
|
||||
}
|
||||
if override.TLSTimeout != 0 {
|
||||
result.TLSTimeout = time.Duration(override.TLSTimeout) * time.Second
|
||||
}
|
||||
if override.TCPKeepAlive != 0 {
|
||||
result.TCPKeepAlive = time.Duration(override.TCPKeepAlive) * time.Second
|
||||
}
|
||||
if override.NoHappyEyeballs != nil {
|
||||
result.NoHappyEyeballs = *override.NoHappyEyeballs
|
||||
}
|
||||
if override.KeepAliveTimeout != 0 {
|
||||
result.KeepAliveTimeout = time.Duration(override.KeepAliveTimeout) * time.Second
|
||||
}
|
||||
if override.KeepAliveConnections != nil {
|
||||
result.KeepAliveConnections = *override.KeepAliveConnections
|
||||
}
|
||||
if override.HTTPHostHeader != "" {
|
||||
result.HTTPHostHeader = override.HTTPHostHeader
|
||||
}
|
||||
if override.OriginServerName != "" {
|
||||
result.OriginServerName = override.OriginServerName
|
||||
}
|
||||
if override.MatchSNIToHost != nil {
|
||||
result.MatchSNIToHost = *override.MatchSNIToHost
|
||||
}
|
||||
if override.CAPool != "" {
|
||||
result.CAPool = override.CAPool
|
||||
}
|
||||
if override.NoTLSVerify != nil {
|
||||
result.NoTLSVerify = *override.NoTLSVerify
|
||||
}
|
||||
if override.DisableChunkedEncoding != nil {
|
||||
result.DisableChunkedEncoding = *override.DisableChunkedEncoding
|
||||
}
|
||||
if override.BastionMode != nil {
|
||||
result.BastionMode = *override.BastionMode
|
||||
}
|
||||
if override.ProxyAddress != "" {
|
||||
result.ProxyAddress = override.ProxyAddress
|
||||
}
|
||||
if override.ProxyPort != nil {
|
||||
result.ProxyPort = *override.ProxyPort
|
||||
}
|
||||
if override.ProxyType != "" {
|
||||
result.ProxyType = override.ProxyType
|
||||
}
|
||||
if len(override.IPRules) > 0 {
|
||||
result.IPRules = nil
|
||||
for _, rule := range override.IPRules {
|
||||
result.IPRules = append(result.IPRules, IPRule{
|
||||
Prefix: rule.Prefix,
|
||||
Ports: append([]int(nil), rule.Ports...),
|
||||
Allow: rule.Allow,
|
||||
})
|
||||
}
|
||||
}
|
||||
if override.HTTP2Origin != nil {
|
||||
result.HTTP2Origin = *override.HTTP2Origin
|
||||
}
|
||||
if override.Access != nil {
|
||||
result.Access = AccessConfig{
|
||||
Required: override.Access.Required,
|
||||
TeamName: override.Access.TeamName,
|
||||
AudTag: append([]string(nil), override.Access.AudTag...),
|
||||
Environment: override.Access.Environment,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func warpRoutingFromOption(input option.CloudflareTunnelWarpRoutingOptions) WarpRoutingConfig {
|
||||
config := WarpRoutingConfig{
|
||||
ConnectTimeout: defaultWarpRoutingConnectTime,
|
||||
TCPKeepAlive: defaultWarpRoutingTCPKeepAlive,
|
||||
MaxActiveFlows: input.MaxActiveFlows,
|
||||
}
|
||||
if input.ConnectTimeout != 0 {
|
||||
config.ConnectTimeout = time.Duration(input.ConnectTimeout)
|
||||
}
|
||||
if input.TCPKeepAlive != 0 {
|
||||
config.TCPKeepAlive = time.Duration(input.TCPKeepAlive)
|
||||
}
|
||||
return config
|
||||
}
|
||||
|
||||
func warpRoutingFromRemote(input remoteWarpRoutingJSON) WarpRoutingConfig {
|
||||
config := WarpRoutingConfig{
|
||||
ConnectTimeout: defaultWarpRoutingConnectTime,
|
||||
TCPKeepAlive: defaultWarpRoutingTCPKeepAlive,
|
||||
MaxActiveFlows: input.MaxActiveFlows,
|
||||
}
|
||||
if input.ConnectTimeout != 0 {
|
||||
config.ConnectTimeout = time.Duration(input.ConnectTimeout) * time.Second
|
||||
}
|
||||
if input.TCPKeepAlive != 0 {
|
||||
config.TCPKeepAlive = time.Duration(input.TCPKeepAlive) * time.Second
|
||||
}
|
||||
return config
|
||||
}
|
||||
Reference in New Issue
Block a user