Implement router-backed cloudflare tunnel ingress config
This commit is contained in:
@@ -4,12 +4,16 @@ package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
@@ -124,19 +128,42 @@ func (i *Inbound) dispatchRequest(ctx context.Context, stream io.ReadWriteCloser
|
||||
metadata.Destination = M.ParseSocksaddr(request.Dest)
|
||||
i.handleTCPStream(ctx, stream, respWriter, metadata)
|
||||
case ConnectionTypeHTTP, ConnectionTypeWebsocket:
|
||||
originURL := i.ResolveOriginURL(request.Dest)
|
||||
request.Dest = originURL
|
||||
metadata.Destination = parseHTTPDestination(originURL)
|
||||
if request.Type == ConnectionTypeHTTP {
|
||||
i.handleHTTPStream(ctx, stream, respWriter, request, metadata)
|
||||
} else {
|
||||
i.handleWebSocketStream(ctx, stream, respWriter, request, metadata)
|
||||
service, originURL, err := i.resolveHTTPService(request.Dest)
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "resolve origin service: ", err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
request.Dest = originURL
|
||||
i.handleHTTPService(ctx, stream, respWriter, request, metadata, service)
|
||||
default:
|
||||
i.logger.ErrorContext(ctx, "unknown connection type: ", request.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Inbound) resolveHTTPService(requestURL string) (ResolvedService, string, error) {
|
||||
parsedURL, err := url.Parse(requestURL)
|
||||
if err != nil {
|
||||
return ResolvedService{}, "", E.Cause(err, "parse request URL")
|
||||
}
|
||||
service, loaded := i.configManager.Resolve(parsedURL.Hostname(), parsedURL.Path)
|
||||
if !loaded {
|
||||
return ResolvedService{}, "", E.New("no ingress rule matched request host/path")
|
||||
}
|
||||
if service.Kind == ResolvedServiceHelloWorld {
|
||||
helloURL, err := i.ensureHelloWorldURL()
|
||||
if err != nil {
|
||||
return ResolvedService{}, "", err
|
||||
}
|
||||
service.BaseURL = helloURL
|
||||
}
|
||||
originURL, err := service.BuildRequestURL(requestURL)
|
||||
if err != nil {
|
||||
return ResolvedService{}, "", E.Cause(err, "build origin request URL")
|
||||
}
|
||||
return service, originURL, nil
|
||||
}
|
||||
|
||||
func parseHTTPDestination(dest string) M.Socksaddr {
|
||||
parsed, err := url.Parse(dest)
|
||||
if err != nil {
|
||||
@@ -172,10 +199,81 @@ func (i *Inbound) handleTCPStream(ctx context.Context, stream io.ReadWriteCloser
|
||||
<-done
|
||||
}
|
||||
|
||||
func (i *Inbound) handleHTTPStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext) {
|
||||
func (i *Inbound) handleHTTPService(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
|
||||
switch service.Kind {
|
||||
case ResolvedServiceStatus:
|
||||
err := respWriter.WriteResponse(nil, encodeResponseHeaders(service.StatusCode, http.Header{}))
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "write status service response: ", err)
|
||||
}
|
||||
return
|
||||
case ResolvedServiceHTTP:
|
||||
metadata.Destination = service.Destination
|
||||
if request.Type == ConnectionTypeHTTP {
|
||||
i.handleHTTPStream(ctx, stream, respWriter, request, metadata, service)
|
||||
} else {
|
||||
i.handleWebSocketStream(ctx, stream, respWriter, request, metadata, service)
|
||||
}
|
||||
case ResolvedServiceUnix, ResolvedServiceUnixTLS, ResolvedServiceHelloWorld:
|
||||
if request.Type == ConnectionTypeHTTP {
|
||||
i.handleDirectHTTPStream(ctx, stream, respWriter, request, metadata, service)
|
||||
} else {
|
||||
i.handleDirectWebSocketStream(ctx, stream, respWriter, request, metadata, service)
|
||||
}
|
||||
default:
|
||||
err := E.New("unsupported service kind for HTTP/WebSocket request")
|
||||
i.logger.ErrorContext(ctx, err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Inbound) handleHTTPStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
|
||||
metadata.Network = N.NetworkTCP
|
||||
i.logger.InfoContext(ctx, "inbound HTTP connection to ", metadata.Destination)
|
||||
|
||||
transport, cleanup := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest)
|
||||
defer cleanup()
|
||||
i.roundTripHTTP(ctx, stream, respWriter, request, service, transport)
|
||||
}
|
||||
|
||||
func (i *Inbound) handleWebSocketStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
|
||||
metadata.Network = N.NetworkTCP
|
||||
i.logger.InfoContext(ctx, "inbound WebSocket connection to ", metadata.Destination)
|
||||
|
||||
transport, cleanup := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest)
|
||||
defer cleanup()
|
||||
i.roundTripHTTP(ctx, stream, respWriter, request, service, transport)
|
||||
}
|
||||
|
||||
func (i *Inbound) handleDirectHTTPStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
|
||||
metadata.Network = N.NetworkTCP
|
||||
i.logger.InfoContext(ctx, "inbound HTTP connection to ", request.Dest)
|
||||
|
||||
transport, cleanup, err := i.newDirectOriginTransport(service)
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "build direct origin transport: ", err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
defer cleanup()
|
||||
i.roundTripHTTP(ctx, stream, respWriter, request, service, transport)
|
||||
}
|
||||
|
||||
func (i *Inbound) handleDirectWebSocketStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
|
||||
metadata.Network = N.NetworkTCP
|
||||
i.logger.InfoContext(ctx, "inbound WebSocket connection to ", request.Dest)
|
||||
|
||||
transport, cleanup, err := i.newDirectOriginTransport(service)
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "build direct origin transport: ", err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
defer cleanup()
|
||||
i.roundTripHTTP(ctx, stream, respWriter, request, service, transport)
|
||||
}
|
||||
|
||||
func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, service ResolvedService, transport *http.Transport) {
|
||||
httpRequest, err := buildHTTPRequestFromMetadata(ctx, request, stream)
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "build HTTP request: ", err)
|
||||
@@ -183,23 +281,17 @@ func (i *Inbound) handleHTTPStream(ctx context.Context, stream io.ReadWriteClose
|
||||
return
|
||||
}
|
||||
|
||||
input, output := pipe.Pipe()
|
||||
var innerError error
|
||||
|
||||
done := make(chan struct{})
|
||||
go i.router.RouteConnectionEx(ctx, output, metadata, N.OnceClose(func(it error) {
|
||||
innerError = it
|
||||
common.Close(input, output)
|
||||
close(done)
|
||||
}))
|
||||
httpRequest = applyOriginRequest(httpRequest, service.OriginRequest)
|
||||
requestCtx := httpRequest.Context()
|
||||
if service.OriginRequest.ConnectTimeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
requestCtx, cancel = context.WithTimeout(requestCtx, service.OriginRequest.ConnectTimeout)
|
||||
defer cancel()
|
||||
httpRequest = httpRequest.WithContext(requestCtx)
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DisableCompression: true,
|
||||
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
|
||||
return input, nil
|
||||
},
|
||||
},
|
||||
Transport: transport,
|
||||
CheckRedirect: func(request *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
@@ -208,87 +300,146 @@ func (i *Inbound) handleHTTPStream(ctx context.Context, stream io.ReadWriteClose
|
||||
|
||||
response, err := httpClient.Do(httpRequest)
|
||||
if err != nil {
|
||||
<-done
|
||||
i.logger.ErrorContext(ctx, "HTTP request: ", E.Errors(innerError, err))
|
||||
i.logger.ErrorContext(ctx, "origin request: ", err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
responseMetadata := encodeResponseHeaders(response.StatusCode, response.Header)
|
||||
err = respWriter.WriteResponse(nil, responseMetadata)
|
||||
if err != nil {
|
||||
response.Body.Close()
|
||||
i.logger.ErrorContext(ctx, "write HTTP response headers: ", err)
|
||||
<-done
|
||||
i.logger.ErrorContext(ctx, "write origin response headers: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
if request.Type == ConnectionTypeWebsocket && response.StatusCode == http.StatusSwitchingProtocols {
|
||||
rwc, ok := response.Body.(io.ReadWriteCloser)
|
||||
if !ok {
|
||||
i.logger.ErrorContext(ctx, "websocket origin response body is not duplex")
|
||||
return
|
||||
}
|
||||
bidirectionalCopy(stream, rwc)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = io.Copy(stream, response.Body)
|
||||
response.Body.Close()
|
||||
common.Close(input, output)
|
||||
if err != nil && !E.IsClosedOrCanceled(err) {
|
||||
i.logger.DebugContext(ctx, "copy HTTP response body: ", err)
|
||||
}
|
||||
<-done
|
||||
}
|
||||
|
||||
func (i *Inbound) handleWebSocketStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext) {
|
||||
metadata.Network = N.NetworkTCP
|
||||
i.logger.InfoContext(ctx, "inbound WebSocket connection to ", metadata.Destination)
|
||||
|
||||
httpRequest, err := buildHTTPRequestFromMetadata(ctx, request, stream)
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "build WebSocket request: ", err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
|
||||
func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter.InboundContext, originRequest OriginRequestConfig) (*http.Transport, func()) {
|
||||
input, output := pipe.Pipe()
|
||||
var innerError error
|
||||
|
||||
done := make(chan struct{})
|
||||
go i.router.RouteConnectionEx(ctx, output, metadata, N.OnceClose(func(it error) {
|
||||
innerError = it
|
||||
common.Close(input, output)
|
||||
close(done)
|
||||
}))
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DisableCompression: true,
|
||||
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
|
||||
return input, nil
|
||||
},
|
||||
},
|
||||
CheckRedirect: func(request *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
transport := &http.Transport{
|
||||
DisableCompression: true,
|
||||
ForceAttemptHTTP2: originRequest.HTTP2Origin,
|
||||
TLSHandshakeTimeout: originRequest.TLSTimeout,
|
||||
IdleConnTimeout: originRequest.KeepAliveTimeout,
|
||||
MaxIdleConns: originRequest.KeepAliveConnections,
|
||||
MaxIdleConnsPerHost: originRequest.KeepAliveConnections,
|
||||
TLSClientConfig: buildOriginTLSConfig(originRequest),
|
||||
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
|
||||
return input, nil
|
||||
},
|
||||
}
|
||||
defer httpClient.CloseIdleConnections()
|
||||
return transport, func() {
|
||||
common.Close(input, output)
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
response, err := httpClient.Do(httpRequest)
|
||||
func (i *Inbound) newDirectOriginTransport(service ResolvedService) (*http.Transport, func(), error) {
|
||||
transport := &http.Transport{
|
||||
DisableCompression: true,
|
||||
ForceAttemptHTTP2: service.OriginRequest.HTTP2Origin,
|
||||
TLSHandshakeTimeout: service.OriginRequest.TLSTimeout,
|
||||
IdleConnTimeout: service.OriginRequest.KeepAliveTimeout,
|
||||
MaxIdleConns: service.OriginRequest.KeepAliveConnections,
|
||||
MaxIdleConnsPerHost: service.OriginRequest.KeepAliveConnections,
|
||||
TLSClientConfig: buildOriginTLSConfig(service.OriginRequest),
|
||||
}
|
||||
switch service.Kind {
|
||||
case ResolvedServiceUnix, ResolvedServiceUnixTLS:
|
||||
dialer := &net.Dialer{}
|
||||
transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
return dialer.DialContext(ctx, "unix", service.UnixPath)
|
||||
}
|
||||
case ResolvedServiceHelloWorld:
|
||||
dialer := &net.Dialer{}
|
||||
target := service.BaseURL.Host
|
||||
transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
return dialer.DialContext(ctx, "tcp", target)
|
||||
}
|
||||
default:
|
||||
return nil, nil, E.New("unsupported direct origin service")
|
||||
}
|
||||
return transport, func() {}, nil
|
||||
}
|
||||
|
||||
func buildOriginTLSConfig(originRequest OriginRequestConfig) *tls.Config {
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: originRequest.NoTLSVerify, //nolint:gosec
|
||||
ServerName: originRequest.OriginServerName,
|
||||
}
|
||||
if originRequest.CAPool == "" {
|
||||
return tlsConfig
|
||||
}
|
||||
pemData, err := os.ReadFile(originRequest.CAPool)
|
||||
if err != nil {
|
||||
<-done
|
||||
i.logger.ErrorContext(ctx, "WebSocket request: ", E.Errors(innerError, err))
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
return tlsConfig
|
||||
}
|
||||
pool := x509.NewCertPool()
|
||||
if pool.AppendCertsFromPEM(pemData) {
|
||||
tlsConfig.RootCAs = pool
|
||||
}
|
||||
return tlsConfig
|
||||
}
|
||||
|
||||
func applyOriginRequest(request *http.Request, originRequest OriginRequestConfig) *http.Request {
|
||||
request = request.Clone(request.Context())
|
||||
if originRequest.HTTPHostHeader != "" {
|
||||
request.Header.Set("X-Forwarded-Host", request.Host)
|
||||
request.Host = originRequest.HTTPHostHeader
|
||||
}
|
||||
if originRequest.DisableChunkedEncoding && request.Header.Get("Content-Length") != "" {
|
||||
if contentLength, err := strconv.ParseInt(request.Header.Get("Content-Length"), 10, 64); err == nil {
|
||||
request.ContentLength = contentLength
|
||||
request.TransferEncoding = nil
|
||||
}
|
||||
}
|
||||
return request
|
||||
}
|
||||
|
||||
func bidirectionalCopy(left, right io.ReadWriteCloser) {
|
||||
var closeOnce sync.Once
|
||||
closeBoth := func() {
|
||||
closeOnce.Do(func() {
|
||||
common.Close(left, right)
|
||||
})
|
||||
}
|
||||
|
||||
responseMetadata := encodeResponseHeaders(response.StatusCode, response.Header)
|
||||
err = respWriter.WriteResponse(nil, responseMetadata)
|
||||
if err != nil {
|
||||
response.Body.Close()
|
||||
i.logger.ErrorContext(ctx, "write WebSocket response headers: ", err)
|
||||
<-done
|
||||
return
|
||||
}
|
||||
|
||||
_, err = io.Copy(stream, response.Body)
|
||||
response.Body.Close()
|
||||
common.Close(input, output)
|
||||
if err != nil && !E.IsClosedOrCanceled(err) {
|
||||
i.logger.DebugContext(ctx, "copy WebSocket response body: ", err)
|
||||
}
|
||||
done := make(chan struct{}, 2)
|
||||
go func() {
|
||||
io.Copy(left, right)
|
||||
closeBoth()
|
||||
done <- struct{}{}
|
||||
}()
|
||||
go func() {
|
||||
io.Copy(right, left)
|
||||
closeBoth()
|
||||
done <- struct{}{}
|
||||
}()
|
||||
<-done
|
||||
<-done
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user