diff --git a/protocol/cloudflare/dispatch.go b/protocol/cloudflare/dispatch.go index 9263e16c3..761e4a978 100644 --- a/protocol/cloudflare/dispatch.go +++ b/protocol/cloudflare/dispatch.go @@ -231,7 +231,7 @@ func (i *Inbound) handleHTTPStream(ctx context.Context, stream io.ReadWriteClose metadata.Network = N.NetworkTCP i.logger.InfoContext(ctx, "inbound HTTP connection to ", metadata.Destination) - transport, cleanup := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest) + transport, cleanup := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest, request.MetadataMap()[metadataHTTPHost]) defer cleanup() i.roundTripHTTP(ctx, stream, respWriter, request, service, transport) } @@ -240,7 +240,7 @@ func (i *Inbound) handleWebSocketStream(ctx context.Context, stream io.ReadWrite metadata.Network = N.NetworkTCP i.logger.InfoContext(ctx, "inbound WebSocket connection to ", metadata.Destination) - transport, cleanup := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest) + transport, cleanup := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest, request.MetadataMap()[metadataHTTPHost]) defer cleanup() i.roundTripHTTP(ctx, stream, respWriter, request, service, transport) } @@ -249,7 +249,7 @@ func (i *Inbound) handleDirectHTTPStream(ctx context.Context, stream io.ReadWrit metadata.Network = N.NetworkTCP i.logger.InfoContext(ctx, "inbound HTTP connection to ", request.Dest) - transport, cleanup, err := i.newDirectOriginTransport(service) + transport, cleanup, err := i.newDirectOriginTransport(service, request.MetadataMap()[metadataHTTPHost]) if err != nil { i.logger.ErrorContext(ctx, "build direct origin transport: ", err) respWriter.WriteResponse(err, nil) @@ -263,7 +263,7 @@ func (i *Inbound) handleDirectWebSocketStream(ctx context.Context, stream io.Rea metadata.Network = N.NetworkTCP i.logger.InfoContext(ctx, "inbound WebSocket connection to ", request.Dest) - transport, cleanup, err := i.newDirectOriginTransport(service) + transport, cleanup, err := i.newDirectOriginTransport(service, request.MetadataMap()[metadataHTTPHost]) if err != nil { i.logger.ErrorContext(ctx, "build direct origin transport: ", err) respWriter.WriteResponse(err, nil) @@ -329,7 +329,7 @@ func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser, } } -func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter.InboundContext, originRequest OriginRequestConfig) (*http.Transport, func()) { +func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter.InboundContext, originRequest OriginRequestConfig, requestHost string) (*http.Transport, func()) { input, output := pipe.Pipe() done := make(chan struct{}) go i.router.RouteConnectionEx(ctx, output, metadata, N.OnceClose(func(it error) { @@ -344,7 +344,7 @@ func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter IdleConnTimeout: originRequest.KeepAliveTimeout, MaxIdleConns: originRequest.KeepAliveConnections, MaxIdleConnsPerHost: originRequest.KeepAliveConnections, - TLSClientConfig: buildOriginTLSConfig(originRequest), + TLSClientConfig: buildOriginTLSConfig(originRequest, requestHost), DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { return input, nil }, @@ -358,7 +358,7 @@ func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter } } -func (i *Inbound) newDirectOriginTransport(service ResolvedService) (*http.Transport, func(), error) { +func (i *Inbound) newDirectOriginTransport(service ResolvedService, requestHost string) (*http.Transport, func(), error) { transport := &http.Transport{ DisableCompression: true, ForceAttemptHTTP2: service.OriginRequest.HTTP2Origin, @@ -366,7 +366,7 @@ func (i *Inbound) newDirectOriginTransport(service ResolvedService) (*http.Trans IdleConnTimeout: service.OriginRequest.KeepAliveTimeout, MaxIdleConns: service.OriginRequest.KeepAliveConnections, MaxIdleConnsPerHost: service.OriginRequest.KeepAliveConnections, - TLSClientConfig: buildOriginTLSConfig(service.OriginRequest), + TLSClientConfig: buildOriginTLSConfig(service.OriginRequest, requestHost), } switch service.Kind { case ResolvedServiceUnix, ResolvedServiceUnixTLS: @@ -386,10 +386,10 @@ func (i *Inbound) newDirectOriginTransport(service ResolvedService) (*http.Trans return transport, func() {}, nil } -func buildOriginTLSConfig(originRequest OriginRequestConfig) *tls.Config { +func buildOriginTLSConfig(originRequest OriginRequestConfig, requestHost string) *tls.Config { tlsConfig := &tls.Config{ InsecureSkipVerify: originRequest.NoTLSVerify, //nolint:gosec - ServerName: originRequest.OriginServerName, + ServerName: originTLSServerName(originRequest, requestHost), } if originRequest.CAPool == "" { return tlsConfig @@ -405,6 +405,19 @@ func buildOriginTLSConfig(originRequest OriginRequestConfig) *tls.Config { return tlsConfig } +func originTLSServerName(originRequest OriginRequestConfig, requestHost string) string { + if originRequest.OriginServerName != "" { + return originRequest.OriginServerName + } + if !originRequest.MatchSNIToHost { + return "" + } + if host, _, err := net.SplitHostPort(requestHost); err == nil { + return host + } + return requestHost +} + func applyOriginRequest(request *http.Request, originRequest OriginRequestConfig) *http.Request { request = request.Clone(request.Context()) if originRequest.HTTPHostHeader != "" { diff --git a/protocol/cloudflare/origin_request_test.go b/protocol/cloudflare/origin_request_test.go new file mode 100644 index 000000000..b56a0a52f --- /dev/null +++ b/protocol/cloudflare/origin_request_test.go @@ -0,0 +1,33 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import "testing" + +func TestOriginTLSServerName(t *testing.T) { + t.Run("origin server name overrides host", func(t *testing.T) { + serverName := originTLSServerName(OriginRequestConfig{ + OriginServerName: "origin.example.com", + MatchSNIToHost: true, + }, "request.example.com") + if serverName != "origin.example.com" { + t.Fatalf("expected origin.example.com, got %s", serverName) + } + }) + + t.Run("match sni to host strips port", func(t *testing.T) { + serverName := originTLSServerName(OriginRequestConfig{ + MatchSNIToHost: true, + }, "request.example.com:443") + if serverName != "request.example.com" { + t.Fatalf("expected request.example.com, got %s", serverName) + } + }) + + t.Run("disabled match keeps empty server name", func(t *testing.T) { + serverName := originTLSServerName(OriginRequestConfig{}, "request.example.com") + if serverName != "" { + t.Fatalf("expected empty server name, got %s", serverName) + } + }) +}