Files
sing-box/protocol/cloudflare/connection_http2.go
2026-03-31 15:32:57 +08:00

523 lines
14 KiB
Go

//go:build with_cloudflared
package cloudflare
import (
"context"
"crypto/tls"
"io"
"math"
"net"
"net/http"
"runtime/debug"
"strconv"
"strings"
"sync"
"time"
"github.com/sagernet/sing-box/log"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
M "github.com/sagernet/sing/common/metadata"
"github.com/google/uuid"
"golang.org/x/net/http2"
)
// Wire-level constants for the HTTP/2 edge transport.
const (
	// h2EdgeSNI is the TLS server name presented when dialing the edge over HTTP/2.
	h2EdgeSNI = "h2.cftunnel.com"

	// Response-meta header values written on error paths handled by this side;
	// the limited variant additionally reports flow_rate_limited.
	h2ResponseMetaCloudflared        = `{"src":"cloudflared"}`
	h2ResponseMetaCloudflaredLimited = `{"src":"cloudflared","flow_rate_limited":true}`

	// Header names and values inspected by shouldFlushHTTPHeaders.
	contentTypeHeader      = "content-type"
	contentLengthHeader    = "content-length"
	transferEncodingHeader = "transfer-encoding"
	chunkTransferEncoding  = "chunked"

	// Content types whose responses are flushed on every write.
	sseContentType    = "text/event-stream"
	grpcContentType   = "application/grpc"
	ndjsonContentType = "application/x-ndjson"
)

// flushableContentTypes lists the (lower-case) Content-Type prefixes that make a
// response flush on every write.
var flushableContentTypes = []string{sseContentType, grpcContentType, ndjsonContentType}
// HTTP2Connection manages a single HTTP/2 connection to the Cloudflare edge.
// Uses role reversal: we dial the edge as a TLS client but serve HTTP/2 as server.
//
// NewHTTP2Connection dials and completes the TLS handshake, Serve runs the
// HTTP/2 server, and gracefulShutdown, forceClose and Close tear it down. The
// sync.Once fields make each teardown step run at most once, whichever path
// reaches it first.
type HTTP2Connection struct {
	conn                net.Conn // TLS connection to the edge; ServeConn runs over it
	server              *http2.Server
	logger              log.ContextLogger
	edgeAddr            *EdgeAddr
	connIndex           uint8
	credentials         Credentials
	connectorID         uuid.UUID
	features            []string
	gracePeriod         time.Duration // timeout for Unregister and for waiting on active requests at shutdown
	inbound             *Inbound
	numPreviousAttempts uint8
	// Written by handleControlStream; read by Serve and gracefulShutdown.
	// NOTE(review): no lock guards these across goroutines — confirm race-freedom.
	registrationClient registrationRPCClient
	registrationResult *RegistrationResult
	controlStreamErr   error
	activeRequests     sync.WaitGroup     // in-flight non-control streams, tracked in ServeHTTP
	serveCancel        context.CancelFunc // cancels the context given to ServeConn
	registrationClose  sync.Once          // guards closeRegistrationClient
	shutdownOnce       sync.Once          // shared by gracefulShutdown and forceClose
	closeOnce          sync.Once          // guards closeNow
}
// NewHTTP2Connection dials the edge and establishes an HTTP/2 connection with role reversal.
//
// It loads the Cloudflare root CAs, dials edgeAddr over TCP through the
// inbound's tunnel dialer, and completes a TLS handshake using h2EdgeSNI. The
// HTTP/2 server is not started here; call Serve on the result. If the handshake
// fails, the TCP connection is closed before the error is returned.
func NewHTTP2Connection(
	ctx context.Context,
	edgeAddr *EdgeAddr,
	connIndex uint8,
	credentials Credentials,
	connectorID uuid.UUID,
	features []string,
	numPreviousAttempts uint8,
	gracePeriod time.Duration,
	inbound *Inbound,
	logger log.ContextLogger,
) (*HTTP2Connection, error) {
	pool, err := cloudflareRootCertPool()
	if err != nil {
		return nil, E.Cause(err, "load Cloudflare root CAs")
	}
	config := newEdgeTLSConfig(pool, h2EdgeSNI, nil)

	edge := edgeAddr.TCP.AddrPort()
	rawConn, err := inbound.tunnelDialer.DialContext(ctx, "tcp", M.SocksaddrFrom(edge.Addr(), edge.Port()))
	if err != nil {
		return nil, E.Cause(err, "dial edge TCP")
	}

	secureConn := tls.Client(rawConn, config)
	if err = secureConn.HandshakeContext(ctx); err != nil {
		rawConn.Close()
		return nil, E.Cause(err, "TLS handshake")
	}

	connection := &HTTP2Connection{
		conn:                secureConn,
		server:              &http2.Server{MaxConcurrentStreams: math.MaxUint32},
		logger:              logger,
		edgeAddr:            edgeAddr,
		connIndex:           connIndex,
		credentials:         credentials,
		connectorID:         connectorID,
		features:            features,
		numPreviousAttempts: numPreviousAttempts,
		gracePeriod:         gracePeriod,
		inbound:             inbound,
	}
	return connection, nil
}
// Serve runs the HTTP/2 server on the edge connection. It blocks until the
// context is cancelled or the connection ends.
//
// If ctx is cancelled, the connection is shut down through gracefulShutdown and
// ctx.Err() is returned after that shutdown has finished. Otherwise Serve
// returns the control-stream registration error if registration failed, or an
// error saying whether the connection closed before or after registration.
func (c *HTTP2Connection) Serve(ctx context.Context) error {
	// ServeConn runs under a context detached from ctx, so cancelling ctx does
	// not tear streams down directly; the watcher below performs the shutdown.
	// The cancel func is kept for closeNow and is also released on return.
	serveCtx, serveCancel := context.WithCancel(context.WithoutCancel(ctx))
	defer serveCancel()
	c.serveCancel = serveCancel

	serveDone := make(chan struct{})
	watcherDone := make(chan struct{})
	go func() {
		defer close(watcherDone)
		// Stop watching once ServeConn has returned. Waiting only on ctx would
		// leak this goroutine (and the connection it references) for as long as
		// a long-lived ctx stays alive, e.g. across reconnect attempts.
		select {
		case <-ctx.Done():
			c.gracefulShutdown()
		case <-serveDone:
			// ctx may have been cancelled at the same moment; still shut down.
			if ctx.Err() != nil {
				c.gracefulShutdown()
			}
		}
	}()

	c.server.ServeConn(c.conn, &http2.ServeConnOpts{
		Context: serveCtx,
		Handler: c,
	})
	close(serveDone)
	<-watcherDone

	if err := ctx.Err(); err != nil {
		return err
	}
	// NOTE(review): controlStreamErr and registrationResult are written by the
	// control-stream handler goroutine without synchronization; confirm these
	// reads cannot race with those writes.
	if c.controlStreamErr != nil {
		return c.controlStreamErr
	}
	if c.registrationResult == nil {
		return E.New("edge connection closed before registration")
	}
	return E.New("edge connection closed")
}
// ServeHTTP dispatches each stream opened by the edge based on its headers.
// Control streams are handled inline and are not counted as active requests.
// Every other stream is tracked in activeRequests and routed, in order of
// precedence, as a websocket stream, a TCP stream (TCP-source header present),
// a configuration update, or otherwise a plain HTTP request.
func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	upgrade := r.Header.Get(h2HeaderUpgrade)
	if upgrade == h2UpgradeControlStream {
		c.handleControlStream(r.Context(), r, w)
		return
	}

	c.activeRequests.Add(1)
	defer c.activeRequests.Done()

	ctx := r.Context()
	if upgrade == h2UpgradeWebsocket {
		c.handleH2DataStream(ctx, r, w, ConnectionTypeWebsocket)
		return
	}
	if r.Header.Get(h2HeaderTCPSrc) != "" {
		c.handleH2DataStream(ctx, r, w, ConnectionTypeTCP)
		return
	}
	if upgrade == h2UpgradeConfiguration {
		c.handleConfigurationUpdate(r, w)
		return
	}
	c.handleH2DataStream(ctx, r, w, ConnectionTypeHTTP)
}
// handleControlStream serves the connection's control stream. It replies 200,
// builds a registration RPC client over the stream, and registers the
// connection. On success it records the result, notifies the inbound, and
// blocks until ctx is done. If registration or validation of its result fails,
// the error is stored in controlStreamErr (later reported by Serve) and the
// connection is force-closed in a separate goroutine.
func (c *HTTP2Connection) handleControlStream(ctx context.Context, r *http.Request, w http.ResponseWriter) {
	flusher, ok := w.(http.Flusher)
	if !ok {
		c.logger.Error("response writer does not support flushing")
		return
	}
	// Send the response headers before any RPC traffic.
	w.WriteHeader(http.StatusOK)
	flusher.Flush()

	rpcStream := newHTTP2Stream(r.Body, &http2FlushWriter{w: w, flusher: flusher})
	c.registrationClient = NewRegistrationClient(ctx, rpcStream)

	// The local address of the edge connection is passed as the origin-local IP.
	localHost, _, _ := net.SplitHostPort(c.conn.LocalAddr().String())
	connOptions := BuildConnectionOptions(c.connectorID, c.features, c.numPreviousAttempts, net.ParseIP(localHost))

	result, err := c.registrationClient.RegisterConnection(
		ctx, c.credentials.Auth(), c.credentials.TunnelID, c.connIndex, connOptions,
	)
	if err == nil {
		err = validateRegistrationResult(result)
	}
	if err != nil {
		c.controlStreamErr = err
		c.logger.Error("register connection: ", err)
		// Presumably asynchronous so this handler is not blocked while closeNow
		// waits for active requests.
		go c.forceClose()
		return
	}

	c.registrationResult = result
	c.inbound.notifyConnected(c.connIndex, "http2")
	c.logger.Info("connected to ", result.Location, " (connection ", result.ConnectionID, ")")
	<-ctx.Done()
}
// handleH2DataStream forwards a websocket, TCP, or plain HTTP stream from the
// edge to the inbound as a ConnectRequest.
//
// The tunnel-internal upgrade and TCP-source headers are removed first. Every
// remaining request header becomes a metadata entry, after the method and host
// entries. For TCP streams the destination is r.Host, falling back to the URL
// host. For other streams it is the request URL, with the scheme defaulted to
// "http" and the host defaulted to r.Host. Non-HTTP streams flush every write
// from the start; HTTP streams flush only once WriteResponse enables it.
func (c *HTTP2Connection) handleH2DataStream(ctx context.Context, r *http.Request, w http.ResponseWriter, connectionType ConnectionType) {
	r.Header.Del(h2HeaderUpgrade)
	r.Header.Del(h2HeaderTCPSrc)

	flusher, ok := w.(http.Flusher)
	if !ok {
		c.logger.Error("response writer does not support flushing")
		return
	}

	var destination string
	switch connectionType {
	case ConnectionTypeTCP:
		destination = r.Host
		if destination == "" && r.URL != nil {
			destination = r.URL.Host
		}
	default:
		if r.URL.Scheme == "" {
			r.URL.Scheme = "http"
		}
		if r.URL.Host == "" {
			r.URL.Host = r.Host
		}
		destination = r.URL.String()
	}

	metadata := make([]Metadata, 0, 2+len(r.Header))
	metadata = append(metadata,
		Metadata{Key: metadataHTTPMethod, Val: r.Method},
		Metadata{Key: metadataHTTPHost, Val: r.Host},
	)
	for name, values := range r.Header {
		headerKey := metadataHTTPHeader + ":" + name
		for _, value := range values {
			metadata = append(metadata, Metadata{Key: headerKey, Val: value})
		}
	}
	request := &ConnectRequest{
		Dest:     destination,
		Type:     connectionType,
		Metadata: metadata,
	}

	// The stream and the response writer share one flush state so that
	// WriteResponse can switch per-write flushing on for the body.
	shared := &http2FlushState{shouldFlush: connectionType != ConnectionTypeHTTP}
	c.inbound.dispatchRequest(ctx,
		&http2DataStream{reader: r.Body, writer: w, flusher: flusher, state: shared, logger: c.logger},
		&http2ResponseWriter{writer: w, flusher: flusher, flushState: shared},
		request,
	)
}
// h2ConfigurationUpdateBody is the JSON payload of a configuration-update
// stream. Config is kept as raw JSON and passed to Inbound.ApplyConfig as-is.
type h2ConfigurationUpdateBody struct {
	Version int32           `json:"version"`
	Config  json.RawMessage `json:"config"`
}
// handleConfigurationUpdate applies a configuration pushed by the edge and
// replies with JSON of the form {"lastAppliedVersion":N,"err":null|"message"}.
//
// A body that cannot be decoded is answered with 502 Bad Gateway and the
// cloudflared response-meta header. Otherwise the status is 200 even if applying
// the configuration failed; the failure is carried in the "err" field.
func (c *HTTP2Connection) handleConfigurationUpdate(r *http.Request, w http.ResponseWriter) {
	var body h2ConfigurationUpdateBody
	if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
		c.logger.Error("decode configuration update: ", err)
		w.Header().Set(h2HeaderResponseMeta, h2ResponseMetaCloudflared)
		w.WriteHeader(http.StatusBadGateway)
		return
	}
	result := c.inbound.ApplyConfig(body.Version, body.Config)

	// Marshal the reply rather than concatenating it: strconv.Quote produces Go
	// escapes (\x00, \a, \v, \U...) that are not valid in JSON strings.
	reply := struct {
		LastAppliedVersion int64   `json:"lastAppliedVersion"`
		Err                *string `json:"err"`
	}{
		LastAppliedVersion: int64(result.LastAppliedVersion),
	}
	if result.Err != nil {
		message := result.Err.Error()
		reply.Err = &message
	}
	payload, err := json.Marshal(reply)
	if err != nil {
		// Not expected for this fixed shape; fail the same way as a bad request body.
		c.logger.Error("encode configuration update response: ", err)
		w.Header().Set(h2HeaderResponseMeta, h2ResponseMetaCloudflared)
		w.WriteHeader(http.StatusBadGateway)
		return
	}
	w.WriteHeader(http.StatusOK)
	if _, err = w.Write(payload); err != nil {
		c.logger.Debug("write configuration update response: ", err)
	}
}
// gracefulShutdown unregisters the connection from the edge (bounded by
// gracePeriod), closes the registration client, waits up to gracePeriod for
// active requests, and then closes the connection. If registration never
// completed, it closes immediately instead. It shares shutdownOnce with
// forceClose, so only the first of the two has any effect.
func (c *HTTP2Connection) gracefulShutdown() {
	c.shutdownOnce.Do(func() {
		registered := c.registrationClient != nil && c.registrationResult != nil
		if !registered {
			c.closeNow()
			return
		}

		unregisterErr := func() error {
			deadlineCtx, cancel := context.WithTimeout(context.Background(), c.gracePeriod)
			defer cancel()
			return c.registrationClient.Unregister(deadlineCtx)
		}()
		if unregisterErr != nil {
			c.logger.Debug("failed to unregister: ", unregisterErr)
		}

		c.closeRegistrationClient()
		c.waitForActiveRequests(c.gracePeriod)
		c.closeNow()
	})
}
// forceClose closes the connection immediately, without unregistering. It
// shares shutdownOnce with gracefulShutdown, so whichever runs first wins.
func (c *HTTP2Connection) forceClose() {
	c.shutdownOnce.Do(c.closeNow)
}
// waitForActiveRequests blocks until activeRequests drains or timeout elapses.
// A non-positive timeout waits without bound. When the timeout fires first, the
// helper goroutine keeps waiting in the background until the requests finish.
// NOTE(review): with gracePeriod <= 0, gracefulShutdown therefore waits
// indefinitely — confirm that is the intended meaning of a zero grace period.
func (c *HTTP2Connection) waitForActiveRequests(timeout time.Duration) {
	if timeout > 0 {
		drained := make(chan struct{})
		go func() {
			defer close(drained)
			c.activeRequests.Wait()
		}()
		deadline := time.NewTimer(timeout)
		defer deadline.Stop()
		select {
		case <-drained:
		case <-deadline.C:
		}
		return
	}
	c.activeRequests.Wait()
}
// closeRegistrationClient closes the registration RPC client, if one was
// created, at most once. The Close error is deliberately discarded.
func (c *HTTP2Connection) closeRegistrationClient() {
	c.registrationClose.Do(func() {
		if client := c.registrationClient; client != nil {
			_ = client.Close()
		}
	})
}
// closeNow tears the connection down at most once. In order, it closes the
// transport, cancels the ServeConn context (if Serve has set one), closes the
// registration client, and then waits without bound for active requests.
func (c *HTTP2Connection) closeNow() {
	c.closeOnce.Do(func() {
		// The transport close error is intentionally ignored.
		_ = c.conn.Close()
		if cancel := c.serveCancel; cancel != nil {
			cancel()
		}
		c.closeRegistrationClient()
		c.activeRequests.Wait()
	})
}
// Close closes the HTTP/2 connection immediately, without unregistering it
// from the edge. It always returns nil.
func (c *HTTP2Connection) Close() error {
	c.forceClose()
	return nil
}
// http2Stream joins the control stream's request body (read side) and a
// flushing response writer (write side) into one io.ReadWriteCloser.
// Closing it closes only the request body.
type http2Stream struct {
	reader io.ReadCloser
	writer io.Writer
}

// newHTTP2Stream pairs reader and writer into an http2Stream.
func newHTTP2Stream(reader io.ReadCloser, writer io.Writer) *http2Stream {
	stream := new(http2Stream)
	stream.reader = reader
	stream.writer = writer
	return stream
}

// Read reads from the request body.
func (s *http2Stream) Read(p []byte) (int, error) {
	return s.reader.Read(p)
}

// Write writes to the wrapped writer.
func (s *http2Stream) Write(p []byte) (int, error) {
	return s.writer.Write(p)
}

// Close closes the request body; the writer is left untouched.
func (s *http2Stream) Close() error {
	return s.reader.Close()
}
// http2FlushWriter is an io.Writer over an http.ResponseWriter that flushes
// after every successful write. A failed write is returned without flushing.
type http2FlushWriter struct {
	w       http.ResponseWriter
	flusher http.Flusher
}

func (fw *http2FlushWriter) Write(p []byte) (int, error) {
	written, err := fw.w.Write(p)
	if err != nil {
		return written, err
	}
	fw.flusher.Flush()
	return written, nil
}
// http2DataStream exposes an HTTP/2 request body and response writer as an
// io.ReadWriteCloser for data streams. Writes are flushed immediately while the
// shared flush state has shouldFlush set; Close closes only the request body.
type http2DataStream struct {
	reader  io.ReadCloser
	writer  http.ResponseWriter
	flusher http.Flusher
	state   *http2FlushState
	logger  log.ContextLogger
}

// Read reads from the request body.
func (s *http2DataStream) Read(p []byte) (int, error) {
	return s.reader.Read(p)
}

// Write writes p to the response and flushes if the flush state asks for it.
// A panic raised while writing or flushing is recovered, logged at debug level
// with a stack trace, and reported as (0, io.ErrClosedPipe).
func (s *http2DataStream) Write(p []byte) (n int, err error) {
	defer func() {
		recovered := recover()
		if recovered == nil {
			return
		}
		if s.logger != nil {
			s.logger.Debug("recovered from HTTP/2 data stream panic: ", recovered, "\n", string(debug.Stack()))
		}
		n, err = 0, io.ErrClosedPipe
	}()
	n, err = s.writer.Write(p)
	if err != nil || s.state == nil || !s.state.shouldFlush {
		return n, err
	}
	s.flusher.Flush()
	return n, err
}

// Close closes the request body only.
func (s *http2DataStream) Close() error {
	return s.reader.Close()
}
// http2ResponseWriter translates ConnectResponse metadata to HTTP/2 response headers.
// headersSent records whether WriteResponse has already run; trailers are only
// accepted after that point.
type http2ResponseWriter struct {
	writer      http.ResponseWriter
	flusher     http.Flusher
	headersSent bool
	flushState  *http2FlushState
}

// AddTrailer adds a response trailer. It is ignored until headers have been sent.
func (w *http2ResponseWriter) AddTrailer(name, value string) {
	if !w.headersSent {
		return
	}
	w.writer.Header().Add(http2.TrailerPrefix+name, value)
}

// WriteResponse writes the response headers once; later calls are no-ops.
//
// If responseError is non-nil, it replies 502 with the cloudflared
// response-meta header (the rate-limited variant when the metadata reports a
// flow rate limit). Otherwise it takes the status from the metadataHTTPStatus
// entry (default 200) and collects metadataHTTPHeader-prefixed entries into the
// serialized user-header set. Content-Length is also copied onto the outer
// response. It may enable per-write flushing based on those headers. It always
// returns nil.
func (w *http2ResponseWriter) WriteResponse(responseError error, metadata []Metadata) error {
	if w.headersSent {
		return nil
	}
	w.headersSent = true

	if responseError != nil {
		meta := h2ResponseMetaCloudflared
		if hasFlowConnectRateLimited(metadata) {
			meta = h2ResponseMetaCloudflaredLimited
		}
		w.writer.Header().Set(h2HeaderResponseMeta, meta)
		w.writer.WriteHeader(http.StatusBadGateway)
		w.flusher.Flush()
		return nil
	}

	headerPrefix := metadataHTTPHeader + ":"
	statusCode := http.StatusOK
	userHeaders := make(http.Header)
	for _, entry := range metadata {
		if entry.Key == metadataHTTPStatus {
			// WriteHeader panics on codes outside 100-999, so such values are
			// ignored just like non-numeric ones.
			if code, err := strconv.Atoi(entry.Val); err == nil && code >= 100 && code <= 999 {
				statusCode = code
			}
			continue
		}
		headerName, isHeader := strings.CutPrefix(entry.Key, headerPrefix)
		if !isHeader {
			continue
		}
		lower := strings.ToLower(headerName)
		if lower == contentLengthHeader {
			w.writer.Header().Set(headerName, entry.Val)
		}
		if !isControlResponseHeader(lower) || isWebsocketClientHeader(lower) {
			userHeaders.Add(headerName, entry.Val)
		}
	}

	w.writer.Header().Set(h2HeaderResponseUser, SerializeHeaders(userHeaders))
	w.writer.Header().Set(h2HeaderResponseMeta, h2ResponseMetaOrigin)
	if w.flushState != nil && shouldFlushHTTPHeaders(userHeaders) {
		w.flushState.shouldFlush = true
	}
	// HTTP/2 has no 101 Switching Protocols; an upgraded response is sent as 200.
	if statusCode == http.StatusSwitchingProtocols {
		statusCode = http.StatusOK
	}
	w.writer.WriteHeader(statusCode)
	if w.flushState != nil && w.flushState.shouldFlush {
		w.flusher.Flush()
	}
	return nil
}
// http2FlushState is shared by a data stream and its response writer. When
// shouldFlush is true, every body write is flushed immediately. It starts true
// for non-HTTP streams and may be switched on by WriteResponse based on the
// origin's response headers.
type http2FlushState struct {
	shouldFlush bool
}
// shouldFlushHTTPHeaders reports whether a response with these origin headers
// should be flushed on every write. That is the case when the response declares
// no Content-Length, when its Transfer-Encoding mentions chunked, or when its
// lower-cased Content-Type starts with one of flushableContentTypes.
func shouldFlushHTTPHeaders(headers http.Header) bool {
	if len(headers.Get(contentLengthHeader)) == 0 {
		return true
	}
	if strings.Contains(strings.ToLower(headers.Get(transferEncodingHeader)), chunkTransferEncoding) {
		return true
	}
	contentType := strings.ToLower(headers.Get(contentTypeHeader))
	for i := range flushableContentTypes {
		if strings.HasPrefix(contentType, flushableContentTypes[i]) {
			return true
		}
	}
	return false
}