Files
sing-box/protocol/cloudflare/dispatch.go

525 lines
17 KiB
Go

//go:build with_cloudflare_tunnel
package cloudflare
import (
"context"
"crypto/tls"
"crypto/x509"
"io"
"net"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/pipe"
)
// Metadata keys exchanged with the Cloudflare edge for HTTP and WebSocket
// streams. The method, host and header keys are read from the incoming
// ConnectRequest metadata (buildHTTPRequestFromMetadata); the status and
// header keys are written into the connect response (encodeResponseHeaders).
// Header entries use the combined key form "HttpHeader:<header name>".
const (
	metadataHTTPMethod = "HttpMethod"
	metadataHTTPHost   = "HttpHost"
	metadataHTTPHeader = "HttpHeader"
	metadataHTTPStatus = "HttpStatus"
)
// ConnectResponseWriter is the transport-specific sink for the reply to an
// edge connect request, so the dispatch logic can answer the edge the same
// way over QUIC and HTTP/2.
type ConnectResponseWriter interface {
	// WriteResponse acknowledges the request when err is nil and reports err
	// to the edge otherwise. meta carries optional response metadata such as
	// the HTTP status code and response headers.
	WriteResponse(err error, meta []Metadata) error
}
// quicResponseWriter answers a connect request on a QUIC data stream by
// encoding a ConnectResponse (signature followed by the capnp message) with
// WriteConnectResponse.
type quicResponseWriter struct {
	stream io.Writer
}

// WriteResponse implements ConnectResponseWriter for the QUIC data-stream
// format; the metadata entries are forwarded unchanged.
func (writer *quicResponseWriter) WriteResponse(responseError error, metadata []Metadata) error {
	return WriteConnectResponse(writer.stream, responseError, metadata...)
}
// HandleDataStream serves one edge data stream received over QUIC. It gives
// the context a fresh log ID and makes connect responses use the QUIC
// data-stream format. The request is then dispatched by its connection type.
// connIndex is currently unused.
func (i *Inbound) HandleDataStream(ctx context.Context, stream io.ReadWriteCloser, request *ConnectRequest, connIndex uint8) {
	i.dispatchRequest(log.ContextWithNewID(ctx), stream, &quicResponseWriter{stream: stream}, request)
}
// HandleRPCStream is the entry point for an incoming edge RPC stream when no
// DatagramSender is available. It only logs the event.
//
// V2 session RPCs (RegisterUdpSession / UnregisterUdpSession) must reach the
// per-connection V2 datagram muxer. That muxer is keyed by DatagramSender
// (see getOrCreateV2Muxer), which this entry point does not receive.
// Callers that have the sender should use HandleRPCStreamWithSender instead.
//
// NOTE(review): the stream is neither read nor closed here; confirm that the
// caller owns and closes it.
func (i *Inbound) HandleRPCStream(ctx context.Context, stream io.ReadWriteCloser, connIndex uint8) {
	i.logger.DebugContext(ctx, "received RPC stream on connection ", connIndex)
}
// HandleRPCStreamWithSender serves an edge RPC stream against the V2 datagram
// muxer bound to sender, creating that muxer on first use. RPCs on this
// stream therefore operate on the same muxer that HandleDatagram uses for V2
// datagrams from sender. connIndex is currently unused.
func (i *Inbound) HandleRPCStreamWithSender(ctx context.Context, stream io.ReadWriteCloser, connIndex uint8, sender DatagramSender) {
	ServeRPCStream(ctx, stream, i, i.getOrCreateV2Muxer(sender), i.logger)
}
// HandleDatagram routes one incoming QUIC datagram to the muxer for sender.
// It uses the V3 muxer when datagramVersion is "v3" and the V2 muxer for any
// other value. The muxer is created on first use.
func (i *Inbound) HandleDatagram(ctx context.Context, datagram []byte, sender DatagramSender) {
	if i.datagramVersion == "v3" {
		i.getOrCreateV3Muxer(sender).HandleDatagram(ctx, datagram)
		return
	}
	i.getOrCreateV2Muxer(sender).HandleDatagram(ctx, datagram)
}
// getOrCreateV2Muxer returns the V2 datagram muxer registered for sender. If
// none exists, it creates and registers one. The lookup and insertion both
// happen under datagramMuxerAccess.
func (i *Inbound) getOrCreateV2Muxer(sender DatagramSender) *DatagramV2Muxer {
	i.datagramMuxerAccess.Lock()
	defer i.datagramMuxerAccess.Unlock()
	if existing, found := i.datagramV2Muxers[sender]; found {
		return existing
	}
	created := NewDatagramV2Muxer(i, sender, i.logger)
	i.datagramV2Muxers[sender] = created
	return created
}
// getOrCreateV3Muxer returns the V3 datagram muxer registered for sender. If
// none exists, it creates and registers one. The lookup and insertion both
// happen under datagramMuxerAccess.
func (i *Inbound) getOrCreateV3Muxer(sender DatagramSender) *DatagramV3Muxer {
	i.datagramMuxerAccess.Lock()
	defer i.datagramMuxerAccess.Unlock()
	if existing, found := i.datagramV3Muxers[sender]; found {
		return existing
	}
	created := NewDatagramV3Muxer(i, sender, i.logger)
	i.datagramV3Muxers[sender] = created
	return created
}
// RemoveDatagramMuxer detaches and closes the V2 and V3 datagram muxers
// registered for sender, if any. It is intended to run when the connection
// behind sender closes.
//
// The muxers are removed from the maps while datagramMuxerAccess is held but
// are closed only after the lock is released. Previously Close ran under the
// lock. A Close that blocks, or that re-enters getOrCreate*/RemoveDatagramMuxer,
// would then stall or deadlock datagram dispatch for every connection, and a
// panic in Close would leave the lock held.
func (i *Inbound) RemoveDatagramMuxer(sender DatagramSender) {
	i.datagramMuxerAccess.Lock()
	v2Muxer, hasV2 := i.datagramV2Muxers[sender]
	if hasV2 {
		delete(i.datagramV2Muxers, sender)
	}
	v3Muxer, hasV3 := i.datagramV3Muxers[sender]
	if hasV3 {
		delete(i.datagramV3Muxers, sender)
	}
	i.datagramMuxerAccess.Unlock()

	if hasV2 {
		v2Muxer.Close()
	}
	if hasV3 {
		v3Muxer.Close()
	}
}
// dispatchRequest routes an edge connect request according to its type:
//   - TCP requests are proxied through the router to request.Dest.
//   - HTTP and WebSocket requests are matched against the ingress rules.
//     request.Dest is rewritten to the origin URL, and the matched service
//     handles the request.
//
// Failures that happen before a handler takes over are reported to the edge
// through respWriter instead of leaving the request unanswered. These are an
// unmatched ingress rule and an unknown connection type. Previously an
// unknown type was only logged.
func (i *Inbound) dispatchRequest(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest) {
	metadata := adapter.InboundContext{
		Inbound:     i.Tag(),
		InboundType: i.Type(),
	}
	switch request.Type {
	case ConnectionTypeTCP:
		metadata.Destination = M.ParseSocksaddr(request.Dest)
		i.handleTCPStream(ctx, stream, respWriter, metadata)
	case ConnectionTypeHTTP, ConnectionTypeWebsocket:
		service, originURL, err := i.resolveHTTPService(request.Dest)
		if err != nil {
			i.logger.ErrorContext(ctx, "resolve origin service: ", err)
			if writeErr := respWriter.WriteResponse(err, nil); writeErr != nil {
				i.logger.ErrorContext(ctx, "write connect response: ", writeErr)
			}
			return
		}
		request.Dest = originURL
		i.handleHTTPService(ctx, stream, respWriter, request, metadata, service)
	default:
		err := E.New("unknown connection type: ", request.Type)
		i.logger.ErrorContext(ctx, err)
		if writeErr := respWriter.WriteResponse(err, nil); writeErr != nil {
			i.logger.ErrorContext(ctx, "write connect response: ", writeErr)
		}
	}
}
// resolveHTTPService matches requestURL (by hostname and path) against the
// ingress configuration. It returns the matched service together with the URL
// to request from the origin. For the hello-world service, the URL of the
// local hello-world server is first installed as the service's base URL.
func (i *Inbound) resolveHTTPService(requestURL string) (ResolvedService, string, error) {
	parsed, parseErr := url.Parse(requestURL)
	if parseErr != nil {
		return ResolvedService{}, "", E.Cause(parseErr, "parse request URL")
	}
	service, matched := i.configManager.Resolve(parsed.Hostname(), parsed.Path)
	if !matched {
		return ResolvedService{}, "", E.New("no ingress rule matched request host/path")
	}
	if service.Kind == ResolvedServiceHelloWorld {
		baseURL, helloErr := i.ensureHelloWorldURL()
		if helloErr != nil {
			return ResolvedService{}, "", helloErr
		}
		service.BaseURL = baseURL
	}
	originURL, buildErr := service.BuildRequestURL(requestURL)
	if buildErr != nil {
		return ResolvedService{}, "", E.Cause(buildErr, "build origin request URL")
	}
	return service, originURL, nil
}
// parseHTTPDestination converts an http(s)/ws(s) URL into a host:port socket
// address. When the URL has no explicit port, it uses 443 for https/wss and
// 80 for anything else.
//
// Some input is passed to M.ParseSocksaddr unchanged:
//   - input that fails to parse as a URL;
//   - a URL with no host, such as a bare "localhost:8080". url.Parse reads
//     that as scheme "localhost" with opaque data "8080", and the old code
//     turned it into ":80".
func parseHTTPDestination(dest string) M.Socksaddr {
	parsed, err := url.Parse(dest)
	if err != nil || parsed.Host == "" {
		return M.ParseSocksaddr(dest)
	}
	port := parsed.Port()
	if port == "" {
		switch parsed.Scheme {
		case "https", "wss":
			port = "443"
		default:
			port = "80"
		}
	}
	return M.ParseSocksaddr(net.JoinHostPort(parsed.Hostname(), port))
}
// handleTCPStream acknowledges a TCP connect request to the edge, then hands
// the stream (wrapped as a net.Conn) to the router as a TCP connection to
// metadata.Destination. It blocks until the router's close callback fires.
// If the acknowledgement cannot be written, the stream is not routed.
func (i *Inbound) handleTCPStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, metadata adapter.InboundContext) {
	metadata.Network = N.NetworkTCP
	i.logger.InfoContext(ctx, "inbound TCP connection to ", metadata.Destination)
	if ackErr := respWriter.WriteResponse(nil, nil); ackErr != nil {
		i.logger.ErrorContext(ctx, "write connect response: ", ackErr)
		return
	}
	closed := make(chan struct{})
	onClose := N.OnceClose(func(error) {
		close(closed)
	})
	i.router.RouteConnectionEx(ctx, newStreamConn(stream), metadata, onClose)
	<-closed
}
// handleHTTPService serves an HTTP or WebSocket request with the resolved
// ingress service:
//   - status services answer immediately with their fixed status code;
//   - HTTP services are reached through the router at service.Destination;
//   - unix, unix+TLS and hello-world services are dialed directly.
//
// Any other service kind is reported to the edge as an error.
func (i *Inbound) handleHTTPService(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
	isPlainHTTP := request.Type == ConnectionTypeHTTP
	switch service.Kind {
	case ResolvedServiceStatus:
		statusMetadata := encodeResponseHeaders(service.StatusCode, http.Header{})
		if writeErr := respWriter.WriteResponse(nil, statusMetadata); writeErr != nil {
			i.logger.ErrorContext(ctx, "write status service response: ", writeErr)
		}
	case ResolvedServiceHTTP:
		metadata.Destination = service.Destination
		if isPlainHTTP {
			i.handleHTTPStream(ctx, stream, respWriter, request, metadata, service)
			return
		}
		i.handleWebSocketStream(ctx, stream, respWriter, request, metadata, service)
	case ResolvedServiceUnix, ResolvedServiceUnixTLS, ResolvedServiceHelloWorld:
		if isPlainHTTP {
			i.handleDirectHTTPStream(ctx, stream, respWriter, request, metadata, service)
			return
		}
		i.handleDirectWebSocketStream(ctx, stream, respWriter, request, metadata, service)
	default:
		unsupported := E.New("unsupported service kind for HTTP/WebSocket request")
		i.logger.ErrorContext(ctx, unsupported)
		respWriter.WriteResponse(unsupported, nil)
	}
}
// handleHTTPStream proxies a plain HTTP request to an origin reached through
// the router at metadata.Destination. The router-backed transport is torn
// down once the round trip finishes.
func (i *Inbound) handleHTTPStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
	metadata.Network = N.NetworkTCP
	i.logger.InfoContext(ctx, "inbound HTTP connection to ", metadata.Destination)
	originTransport, release := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest)
	defer release()
	i.roundTripHTTP(ctx, stream, respWriter, request, service, originTransport)
}
// handleWebSocketStream proxies a WebSocket request to an origin reached
// through the router at metadata.Destination. The router-backed transport is
// torn down once the round trip (including any upgraded session) finishes.
func (i *Inbound) handleWebSocketStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
	metadata.Network = N.NetworkTCP
	i.logger.InfoContext(ctx, "inbound WebSocket connection to ", metadata.Destination)
	originTransport, release := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest)
	defer release()
	i.roundTripHTTP(ctx, stream, respWriter, request, service, originTransport)
}
// handleDirectHTTPStream proxies a plain HTTP request to an origin that is
// dialed directly rather than through the router. If no direct transport can
// be built for the service, the error is reported to the edge.
func (i *Inbound) handleDirectHTTPStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
	metadata.Network = N.NetworkTCP
	i.logger.InfoContext(ctx, "inbound HTTP connection to ", request.Dest)
	originTransport, release, buildErr := i.newDirectOriginTransport(service)
	if buildErr != nil {
		i.logger.ErrorContext(ctx, "build direct origin transport: ", buildErr)
		respWriter.WriteResponse(buildErr, nil)
		return
	}
	defer release()
	i.roundTripHTTP(ctx, stream, respWriter, request, service, originTransport)
}
// handleDirectWebSocketStream proxies a WebSocket request to an origin that
// is dialed directly rather than through the router. If no direct transport
// can be built for the service, the error is reported to the edge.
func (i *Inbound) handleDirectWebSocketStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
	metadata.Network = N.NetworkTCP
	i.logger.InfoContext(ctx, "inbound WebSocket connection to ", request.Dest)
	originTransport, release, buildErr := i.newDirectOriginTransport(service)
	if buildErr != nil {
		i.logger.ErrorContext(ctx, "build direct origin transport: ", buildErr)
		respWriter.WriteResponse(buildErr, nil)
		return
	}
	defer release()
	i.roundTripHTTP(ctx, stream, respWriter, request, service, originTransport)
}
// roundTripHTTP does the following:
//  1. Rebuilds the edge request from request/stream.
//  2. Sends it to the origin through transport. Redirects are returned to the
//     edge rather than followed.
//  3. Relays the origin's status and headers to the edge via respWriter.
//  4. Streams the response body back over stream.
//
// A successful WebSocket upgrade (101 Switching Protocols) instead turns the
// stream into a bidirectional pipe with the origin connection.
//
// service.OriginRequest.ConnectTimeout now bounds only the time until the
// origin answers with response headers. It used to be applied to the whole
// request context. net/http keeps watching that context while the response
// body is read, so any response streaming longer than the connect timeout
// was cut off mid-body. Examples are large downloads, server-sent events and
// long polls.
func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, service ResolvedService, transport *http.Transport) {
	httpRequest, err := buildHTTPRequestFromMetadata(ctx, request, stream)
	if err != nil {
		i.logger.ErrorContext(ctx, "build HTTP request: ", err)
		respWriter.WriteResponse(err, nil)
		return
	}
	httpRequest = applyOriginRequest(httpRequest, service.OriginRequest)

	isWebSocket := request.Type == ConnectionTypeWebsocket
	if isWebSocket {
		// The upgrade request itself carries no body. If the edge stream stays
		// attached as Body, the transport reads (and chunk-encodes) client
		// data from it concurrently with the post-upgrade copy below. The
		// transport may also close it, since it owns Request.Body.
		httpRequest.Body = http.NoBody
		httpRequest.ContentLength = 0
		httpRequest.TransferEncoding = nil
		// NOTE(review): the edge appears to signal the upgrade through
		// Cf-Cloudflared-Proxy-Connection-Upgrade (stripped in
		// buildHTTPRequestFromMetadata) rather than the hop-by-hop headers.
		// cloudflared restores them the same way. Setting them is idempotent
		// if the edge already sent them.
		httpRequest.Header.Set("Connection", "Upgrade")
		httpRequest.Header.Set("Upgrade", "websocket")
		httpRequest.Header.Set("Sec-Websocket-Version", "13")
	}

	requestCtx, cancelRequest := context.WithCancel(httpRequest.Context())
	defer cancelRequest()
	httpRequest = httpRequest.WithContext(requestCtx)
	var connectTimer *time.Timer
	if connectTimeout := service.OriginRequest.ConnectTimeout; connectTimeout > 0 {
		connectTimer = time.AfterFunc(connectTimeout, cancelRequest)
	}

	httpClient := &http.Client{
		Transport: transport,
		CheckRedirect: func(*http.Request, []*http.Request) error {
			return http.ErrUseLastResponse
		},
	}
	defer httpClient.CloseIdleConnections()
	response, err := httpClient.Do(httpRequest)
	if connectTimer != nil && !connectTimer.Stop() && err == nil {
		// The timer fired after Do returned but before it could be stopped.
		// The request context is already canceled, so the body is unusable.
		response.Body.Close()
		err = E.New("origin response exceeded connect timeout")
	}
	if err != nil {
		i.logger.ErrorContext(ctx, "origin request: ", err)
		respWriter.WriteResponse(err, nil)
		return
	}
	defer response.Body.Close()

	err = respWriter.WriteResponse(nil, encodeResponseHeaders(response.StatusCode, response.Header))
	if err != nil {
		i.logger.ErrorContext(ctx, "write origin response headers: ", err)
		return
	}
	if isWebSocket && response.StatusCode == http.StatusSwitchingProtocols {
		upgraded, ok := response.Body.(io.ReadWriteCloser)
		if !ok {
			i.logger.ErrorContext(ctx, "websocket origin response body is not duplex")
			return
		}
		bidirectionalCopy(stream, upgraded)
		return
	}
	_, err = io.Copy(stream, response.Body)
	if err != nil && !E.IsClosedOrCanceled(err) {
		i.logger.DebugContext(ctx, "copy HTTP response body: ", err)
	}
}
// newRouterOriginTransport builds an HTTP transport whose DialContext always
// returns the same end of an in-memory pipe. The other end is handed to the
// router (in a new goroutine) as a connection described by metadata. Timeouts,
// keep-alive limits, HTTP/2 and TLS settings come from originRequest.
//
// The returned cleanup closes both pipe ends. It then waits up to one second
// for the router's close callback before returning.
func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter.InboundContext, originRequest OriginRequestConfig) (*http.Transport, func()) {
	clientSide, routerSide := pipe.Pipe()
	routed := make(chan struct{})
	onClose := N.OnceClose(func(error) {
		common.Close(clientSide, routerSide)
		close(routed)
	})
	go i.router.RouteConnectionEx(ctx, routerSide, metadata, onClose)

	transport := &http.Transport{
		DisableCompression:  true,
		ForceAttemptHTTP2:   originRequest.HTTP2Origin,
		TLSHandshakeTimeout: originRequest.TLSTimeout,
		IdleConnTimeout:     originRequest.KeepAliveTimeout,
		MaxIdleConns:        originRequest.KeepAliveConnections,
		MaxIdleConnsPerHost: originRequest.KeepAliveConnections,
		TLSClientConfig:     buildOriginTLSConfig(originRequest),
		DialContext: func(context.Context, string, string) (net.Conn, error) {
			return clientSide, nil
		},
	}
	cleanup := func() {
		common.Close(clientSide, routerSide)
		grace := time.NewTimer(time.Second)
		defer grace.Stop()
		select {
		case <-routed:
		case <-grace.C:
		}
	}
	return transport, cleanup
}
// newDirectOriginTransport builds an HTTP transport that dials the origin
// itself instead of going through the router:
//   - unix and unix+TLS services dial service.UnixPath over "unix";
//   - the hello-world service dials the host of service.BaseURL over "tcp".
//
// The dial ignores the network/address requested by net/http. Other service
// kinds are rejected with an error. The returned cleanup function is a no-op.
func (i *Inbound) newDirectOriginTransport(service ResolvedService) (*http.Transport, func(), error) {
	options := service.OriginRequest
	transport := &http.Transport{
		DisableCompression:  true,
		ForceAttemptHTTP2:   options.HTTP2Origin,
		TLSHandshakeTimeout: options.TLSTimeout,
		IdleConnTimeout:     options.KeepAliveTimeout,
		MaxIdleConns:        options.KeepAliveConnections,
		MaxIdleConnsPerHost: options.KeepAliveConnections,
		TLSClientConfig:     buildOriginTLSConfig(options),
	}
	var network, address string
	switch service.Kind {
	case ResolvedServiceUnix, ResolvedServiceUnixTLS:
		network, address = "unix", service.UnixPath
	case ResolvedServiceHelloWorld:
		network, address = "tcp", service.BaseURL.Host
	default:
		return nil, nil, E.New("unsupported direct origin service")
	}
	dialer := &net.Dialer{}
	transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
		return dialer.DialContext(ctx, network, address)
	}
	return transport, func() {}, nil
}
// buildOriginTLSConfig returns the TLS client settings for an origin:
//   - NoTLSVerify disables certificate verification;
//   - OriginServerName sets the server name used for SNI and verification;
//   - CAPool names a PEM file whose certificates become the only trusted
//     roots.
//
// If the CA file cannot be read, or holds no parsable certificate, it is
// silently ignored. RootCAs then stays nil, which means the host's root CA
// set is used.
func buildOriginTLSConfig(originRequest OriginRequestConfig) *tls.Config {
	config := &tls.Config{
		InsecureSkipVerify: originRequest.NoTLSVerify, //nolint:gosec
		ServerName:         originRequest.OriginServerName,
	}
	if originRequest.CAPool != "" {
		if pemBytes, readErr := os.ReadFile(originRequest.CAPool); readErr == nil {
			roots := x509.NewCertPool()
			if roots.AppendCertsFromPEM(pemBytes) {
				config.RootCAs = roots
			}
		}
	}
	return config
}
// applyOriginRequest returns a clone of request adjusted by the per-origin
// settings. The input request is left unmodified; the body is shared with
// the clone.
//   - HTTPHostHeader replaces Host, and the previous Host value is recorded
//     in X-Forwarded-Host.
//   - DisableChunkedEncoding: when the request carries a Content-Length
//     header that parses as an integer, that length is used and the
//     TransferEncoding list is cleared.
func applyOriginRequest(request *http.Request, originRequest OriginRequestConfig) *http.Request {
	adjusted := request.Clone(request.Context())
	if hostOverride := originRequest.HTTPHostHeader; hostOverride != "" {
		adjusted.Header.Set("X-Forwarded-Host", adjusted.Host)
		adjusted.Host = hostOverride
	}
	if !originRequest.DisableChunkedEncoding {
		return adjusted
	}
	rawLength := adjusted.Header.Get("Content-Length")
	if rawLength == "" {
		return adjusted
	}
	length, parseErr := strconv.ParseInt(rawLength, 10, 64)
	if parseErr != nil {
		return adjusted
	}
	adjusted.ContentLength = length
	adjusted.TransferEncoding = nil
	return adjusted
}
// bidirectionalCopy copies data in both directions between left and right
// and returns only after both copies have finished. As soon as either
// direction stops, both ends are closed exactly once. This is expected to
// unblock the opposite copy. Copy errors are ignored.
func bidirectionalCopy(left, right io.ReadWriteCloser) {
	var (
		closeOnce sync.Once
		pumps     sync.WaitGroup
	)
	pump := func(dst io.Writer, src io.Reader) {
		defer pumps.Done()
		io.Copy(dst, src)
		closeOnce.Do(func() {
			common.Close(left, right)
		})
	}
	pumps.Add(2)
	go pump(left, right)
	go pump(right, left)
	pumps.Wait()
}
// buildHTTPRequestFromMetadata reconstructs the client's HTTP request from an
// edge connect request:
//   - the method and Host come from the HttpMethod / HttpHost metadata;
//   - the URL comes from connectRequest.Dest;
//   - headers come from every "HttpHeader:<name>" metadata entry;
//   - the body is read from body.
//
// net/http derives the outgoing framing from the ContentLength and
// TransferEncoding fields, not from the header map. So a Content-Length
// header is parsed into ContentLength, and a chunked Transfer-Encoding header
// is mirrored into TransferEncoding.
//
// The body is replaced with http.NoBody only when all of these hold: the
// request is not a WebSocket upgrade, it is not chunked, and it declares a
// zero length. Previously TransferEncoding was never populated, so a chunked
// upload without Content-Length met that last condition and its body was
// silently dropped.
//
// NOTE(review): body is typically the edge stream itself. net/http's Transport
// closes Request.Body after sending it. Confirm that the stream's Close
// tolerates this; cloudflared wraps it in a no-op closer.
func buildHTTPRequestFromMetadata(ctx context.Context, connectRequest *ConnectRequest, body io.Reader) (*http.Request, error) {
	metadataMap := connectRequest.MetadataMap()
	method := metadataMap[metadataHTTPMethod]
	host := metadataMap[metadataHTTPHost]
	request, err := http.NewRequestWithContext(ctx, method, connectRequest.Dest, body)
	if err != nil {
		return nil, E.Cause(err, "create HTTP request")
	}
	request.Host = host
	for _, entry := range connectRequest.Metadata {
		if !strings.Contains(entry.Key, metadataHTTPHeader) {
			continue
		}
		parts := strings.SplitN(entry.Key, ":", 2)
		if len(parts) != 2 || parts[1] == "" {
			// Malformed key. An empty header name would also yield an
			// invalid header line toward the origin.
			continue
		}
		request.Header.Add(parts[1], entry.Val)
	}
	if contentLength := request.Header.Get("Content-Length"); contentLength != "" {
		request.ContentLength, err = strconv.ParseInt(contentLength, 10, 64)
		if err != nil {
			return nil, E.Cause(err, "parse content-length")
		}
	}
	// Transfer-Encoding may list several comma-separated codings.
	for _, value := range request.Header.Values("Transfer-Encoding") {
		for _, coding := range strings.Split(value, ",") {
			if strings.EqualFold(strings.TrimSpace(coding), "chunked") {
				request.TransferEncoding = []string{"chunked"}
			}
		}
	}
	if connectRequest.Type != ConnectionTypeWebsocket && !isTransferEncodingChunked(request) && request.ContentLength == 0 {
		request.Body = http.NoBody
	}
	request.Header.Del("Cf-Cloudflared-Proxy-Connection-Upgrade")
	return request, nil
}
// isTransferEncodingChunked reports whether request uses chunked transfer
// coding. It checks two places:
//   - the TransferEncoding field;
//   - the Transfer-Encoding header.
//
// The header must be consulted because requests rebuilt from edge metadata
// carry the coding only in the header map. Nothing populates the field for
// them. Header values may list several comma-separated codings
// (e.g. "gzip, chunked"); each is compared case-insensitively.
func isTransferEncodingChunked(request *http.Request) bool {
	for _, encoding := range request.TransferEncoding {
		if strings.EqualFold(strings.TrimSpace(encoding), "chunked") {
			return true
		}
	}
	for _, value := range request.Header.Values("Transfer-Encoding") {
		for _, coding := range strings.Split(value, ",") {
			if strings.EqualFold(strings.TrimSpace(coding), "chunked") {
				return true
			}
		}
	}
	return false
}
// encodeResponseHeaders flattens an HTTP status code and header set into edge
// response metadata. The first entry is always the status code under
// metadataHTTPStatus. It is followed by one "HttpHeader:<name>" entry per
// header value, so multi-valued headers yield several entries. Header order
// follows map iteration and is therefore unspecified.
func encodeResponseHeaders(statusCode int, header http.Header) []Metadata {
	entryCount := 1
	for _, values := range header {
		entryCount += len(values)
	}
	result := make([]Metadata, 1, entryCount)
	result[0] = Metadata{Key: metadataHTTPStatus, Val: strconv.Itoa(statusCode)}
	keyPrefix := metadataHTTPHeader + ":"
	for name, values := range header {
		for _, value := range values {
			result = append(result, Metadata{Key: keyPrefix + name, Val: value})
		}
	}
	return result
}
// streamConn adapts an io.ReadWriteCloser (an edge stream) to net.Conn so it
// can be handed to the router. Read, Write and Close go straight to the
// wrapped stream. Both address accessors return nil, and deadline setters
// accept any value but do nothing.
type streamConn struct {
	io.ReadWriteCloser
}

var _ net.Conn = (*streamConn)(nil)

// newStreamConn wraps stream as a net.Conn.
func newStreamConn(stream io.ReadWriteCloser) *streamConn {
	return &streamConn{stream}
}

// LocalAddr always returns nil; the stream has no socket address.
func (s *streamConn) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr always returns nil; the stream has no socket address.
func (s *streamConn) RemoteAddr() net.Addr {
	return nil
}

// SetDeadline is a no-op; deadlines are not forwarded to the stream.
func (s *streamConn) SetDeadline(time.Time) error {
	return nil
}

// SetReadDeadline is a no-op; deadlines are not forwarded to the stream.
func (s *streamConn) SetReadDeadline(time.Time) error {
	return nil
}

// SetWriteDeadline is a no-op; deadlines are not forwarded to the stream.
func (s *streamConn) SetWriteDeadline(time.Time) error {
	return nil
}