Apply origin request SNI selection

This commit is contained in:
世界
2026-03-24 11:20:26 +08:00
parent 124379fc1d
commit b3cad021b8
2 changed files with 56 additions and 10 deletions

View File

@@ -231,7 +231,7 @@ func (i *Inbound) handleHTTPStream(ctx context.Context, stream io.ReadWriteClose
metadata.Network = N.NetworkTCP
i.logger.InfoContext(ctx, "inbound HTTP connection to ", metadata.Destination)
transport, cleanup := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest)
transport, cleanup := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest, request.MetadataMap()[metadataHTTPHost])
defer cleanup()
i.roundTripHTTP(ctx, stream, respWriter, request, service, transport)
}
@@ -240,7 +240,7 @@ func (i *Inbound) handleWebSocketStream(ctx context.Context, stream io.ReadWrite
metadata.Network = N.NetworkTCP
i.logger.InfoContext(ctx, "inbound WebSocket connection to ", metadata.Destination)
transport, cleanup := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest)
transport, cleanup := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest, request.MetadataMap()[metadataHTTPHost])
defer cleanup()
i.roundTripHTTP(ctx, stream, respWriter, request, service, transport)
}
@@ -249,7 +249,7 @@ func (i *Inbound) handleDirectHTTPStream(ctx context.Context, stream io.ReadWrit
metadata.Network = N.NetworkTCP
i.logger.InfoContext(ctx, "inbound HTTP connection to ", request.Dest)
transport, cleanup, err := i.newDirectOriginTransport(service)
transport, cleanup, err := i.newDirectOriginTransport(service, request.MetadataMap()[metadataHTTPHost])
if err != nil {
i.logger.ErrorContext(ctx, "build direct origin transport: ", err)
respWriter.WriteResponse(err, nil)
@@ -263,7 +263,7 @@ func (i *Inbound) handleDirectWebSocketStream(ctx context.Context, stream io.Rea
metadata.Network = N.NetworkTCP
i.logger.InfoContext(ctx, "inbound WebSocket connection to ", request.Dest)
transport, cleanup, err := i.newDirectOriginTransport(service)
transport, cleanup, err := i.newDirectOriginTransport(service, request.MetadataMap()[metadataHTTPHost])
if err != nil {
i.logger.ErrorContext(ctx, "build direct origin transport: ", err)
respWriter.WriteResponse(err, nil)
@@ -329,7 +329,7 @@ func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser,
}
}
func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter.InboundContext, originRequest OriginRequestConfig) (*http.Transport, func()) {
func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter.InboundContext, originRequest OriginRequestConfig, requestHost string) (*http.Transport, func()) {
input, output := pipe.Pipe()
done := make(chan struct{})
go i.router.RouteConnectionEx(ctx, output, metadata, N.OnceClose(func(it error) {
@@ -344,7 +344,7 @@ func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter
IdleConnTimeout: originRequest.KeepAliveTimeout,
MaxIdleConns: originRequest.KeepAliveConnections,
MaxIdleConnsPerHost: originRequest.KeepAliveConnections,
TLSClientConfig: buildOriginTLSConfig(originRequest),
TLSClientConfig: buildOriginTLSConfig(originRequest, requestHost),
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return input, nil
},
@@ -358,7 +358,7 @@ func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter
}
}
func (i *Inbound) newDirectOriginTransport(service ResolvedService) (*http.Transport, func(), error) {
func (i *Inbound) newDirectOriginTransport(service ResolvedService, requestHost string) (*http.Transport, func(), error) {
transport := &http.Transport{
DisableCompression: true,
ForceAttemptHTTP2: service.OriginRequest.HTTP2Origin,
@@ -366,7 +366,7 @@ func (i *Inbound) newDirectOriginTransport(service ResolvedService) (*http.Trans
IdleConnTimeout: service.OriginRequest.KeepAliveTimeout,
MaxIdleConns: service.OriginRequest.KeepAliveConnections,
MaxIdleConnsPerHost: service.OriginRequest.KeepAliveConnections,
TLSClientConfig: buildOriginTLSConfig(service.OriginRequest),
TLSClientConfig: buildOriginTLSConfig(service.OriginRequest, requestHost),
}
switch service.Kind {
case ResolvedServiceUnix, ResolvedServiceUnixTLS:
@@ -386,10 +386,10 @@ func (i *Inbound) newDirectOriginTransport(service ResolvedService) (*http.Trans
return transport, func() {}, nil
}
func buildOriginTLSConfig(originRequest OriginRequestConfig) *tls.Config {
func buildOriginTLSConfig(originRequest OriginRequestConfig, requestHost string) *tls.Config {
tlsConfig := &tls.Config{
InsecureSkipVerify: originRequest.NoTLSVerify, //nolint:gosec
ServerName: originRequest.OriginServerName,
ServerName: originTLSServerName(originRequest, requestHost),
}
if originRequest.CAPool == "" {
return tlsConfig
@@ -405,6 +405,19 @@ func buildOriginTLSConfig(originRequest OriginRequestConfig) *tls.Config {
return tlsConfig
}
func originTLSServerName(originRequest OriginRequestConfig, requestHost string) string {
if originRequest.OriginServerName != "" {
return originRequest.OriginServerName
}
if !originRequest.MatchSNIToHost {
return ""
}
if host, _, err := net.SplitHostPort(requestHost); err == nil {
return host
}
return requestHost
}
func applyOriginRequest(request *http.Request, originRequest OriginRequestConfig) *http.Request {
request = request.Clone(request.Context())
if originRequest.HTTPHostHeader != "" {

View File

@@ -0,0 +1,33 @@
//go:build with_cloudflare_tunnel
package cloudflare
import "testing"
func TestOriginTLSServerName(t *testing.T) {
t.Run("origin server name overrides host", func(t *testing.T) {
serverName := originTLSServerName(OriginRequestConfig{
OriginServerName: "origin.example.com",
MatchSNIToHost: true,
}, "request.example.com")
if serverName != "origin.example.com" {
t.Fatalf("expected origin.example.com, got %s", serverName)
}
})
t.Run("match sni to host strips port", func(t *testing.T) {
serverName := originTLSServerName(OriginRequestConfig{
MatchSNIToHost: true,
}, "request.example.com:443")
if serverName != "request.example.com" {
t.Fatalf("expected request.example.com, got %s", serverName)
}
})
t.Run("disabled match keeps empty server name", func(t *testing.T) {
serverName := originTLSServerName(OriginRequestConfig{}, "request.example.com")
if serverName != "" {
t.Fatalf("expected empty server name, got %s", serverName)
}
})
}