mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-13 20:28:32 +10:00
Apply origin request SNI selection
This commit is contained in:
@@ -231,7 +231,7 @@ func (i *Inbound) handleHTTPStream(ctx context.Context, stream io.ReadWriteClose
|
||||
metadata.Network = N.NetworkTCP
|
||||
i.logger.InfoContext(ctx, "inbound HTTP connection to ", metadata.Destination)
|
||||
|
||||
transport, cleanup := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest)
|
||||
transport, cleanup := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest, request.MetadataMap()[metadataHTTPHost])
|
||||
defer cleanup()
|
||||
i.roundTripHTTP(ctx, stream, respWriter, request, service, transport)
|
||||
}
|
||||
@@ -240,7 +240,7 @@ func (i *Inbound) handleWebSocketStream(ctx context.Context, stream io.ReadWrite
|
||||
metadata.Network = N.NetworkTCP
|
||||
i.logger.InfoContext(ctx, "inbound WebSocket connection to ", metadata.Destination)
|
||||
|
||||
transport, cleanup := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest)
|
||||
transport, cleanup := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest, request.MetadataMap()[metadataHTTPHost])
|
||||
defer cleanup()
|
||||
i.roundTripHTTP(ctx, stream, respWriter, request, service, transport)
|
||||
}
|
||||
@@ -249,7 +249,7 @@ func (i *Inbound) handleDirectHTTPStream(ctx context.Context, stream io.ReadWrit
|
||||
metadata.Network = N.NetworkTCP
|
||||
i.logger.InfoContext(ctx, "inbound HTTP connection to ", request.Dest)
|
||||
|
||||
transport, cleanup, err := i.newDirectOriginTransport(service)
|
||||
transport, cleanup, err := i.newDirectOriginTransport(service, request.MetadataMap()[metadataHTTPHost])
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "build direct origin transport: ", err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
@@ -263,7 +263,7 @@ func (i *Inbound) handleDirectWebSocketStream(ctx context.Context, stream io.Rea
|
||||
metadata.Network = N.NetworkTCP
|
||||
i.logger.InfoContext(ctx, "inbound WebSocket connection to ", request.Dest)
|
||||
|
||||
transport, cleanup, err := i.newDirectOriginTransport(service)
|
||||
transport, cleanup, err := i.newDirectOriginTransport(service, request.MetadataMap()[metadataHTTPHost])
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "build direct origin transport: ", err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
@@ -329,7 +329,7 @@ func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser,
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter.InboundContext, originRequest OriginRequestConfig) (*http.Transport, func()) {
|
||||
func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter.InboundContext, originRequest OriginRequestConfig, requestHost string) (*http.Transport, func()) {
|
||||
input, output := pipe.Pipe()
|
||||
done := make(chan struct{})
|
||||
go i.router.RouteConnectionEx(ctx, output, metadata, N.OnceClose(func(it error) {
|
||||
@@ -344,7 +344,7 @@ func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter
|
||||
IdleConnTimeout: originRequest.KeepAliveTimeout,
|
||||
MaxIdleConns: originRequest.KeepAliveConnections,
|
||||
MaxIdleConnsPerHost: originRequest.KeepAliveConnections,
|
||||
TLSClientConfig: buildOriginTLSConfig(originRequest),
|
||||
TLSClientConfig: buildOriginTLSConfig(originRequest, requestHost),
|
||||
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
|
||||
return input, nil
|
||||
},
|
||||
@@ -358,7 +358,7 @@ func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Inbound) newDirectOriginTransport(service ResolvedService) (*http.Transport, func(), error) {
|
||||
func (i *Inbound) newDirectOriginTransport(service ResolvedService, requestHost string) (*http.Transport, func(), error) {
|
||||
transport := &http.Transport{
|
||||
DisableCompression: true,
|
||||
ForceAttemptHTTP2: service.OriginRequest.HTTP2Origin,
|
||||
@@ -366,7 +366,7 @@ func (i *Inbound) newDirectOriginTransport(service ResolvedService) (*http.Trans
|
||||
IdleConnTimeout: service.OriginRequest.KeepAliveTimeout,
|
||||
MaxIdleConns: service.OriginRequest.KeepAliveConnections,
|
||||
MaxIdleConnsPerHost: service.OriginRequest.KeepAliveConnections,
|
||||
TLSClientConfig: buildOriginTLSConfig(service.OriginRequest),
|
||||
TLSClientConfig: buildOriginTLSConfig(service.OriginRequest, requestHost),
|
||||
}
|
||||
switch service.Kind {
|
||||
case ResolvedServiceUnix, ResolvedServiceUnixTLS:
|
||||
@@ -386,10 +386,10 @@ func (i *Inbound) newDirectOriginTransport(service ResolvedService) (*http.Trans
|
||||
return transport, func() {}, nil
|
||||
}
|
||||
|
||||
func buildOriginTLSConfig(originRequest OriginRequestConfig) *tls.Config {
|
||||
func buildOriginTLSConfig(originRequest OriginRequestConfig, requestHost string) *tls.Config {
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: originRequest.NoTLSVerify, //nolint:gosec
|
||||
ServerName: originRequest.OriginServerName,
|
||||
ServerName: originTLSServerName(originRequest, requestHost),
|
||||
}
|
||||
if originRequest.CAPool == "" {
|
||||
return tlsConfig
|
||||
@@ -405,6 +405,19 @@ func buildOriginTLSConfig(originRequest OriginRequestConfig) *tls.Config {
|
||||
return tlsConfig
|
||||
}
|
||||
|
||||
func originTLSServerName(originRequest OriginRequestConfig, requestHost string) string {
|
||||
if originRequest.OriginServerName != "" {
|
||||
return originRequest.OriginServerName
|
||||
}
|
||||
if !originRequest.MatchSNIToHost {
|
||||
return ""
|
||||
}
|
||||
if host, _, err := net.SplitHostPort(requestHost); err == nil {
|
||||
return host
|
||||
}
|
||||
return requestHost
|
||||
}
|
||||
|
||||
func applyOriginRequest(request *http.Request, originRequest OriginRequestConfig) *http.Request {
|
||||
request = request.Clone(request.Context())
|
||||
if originRequest.HTTPHostHeader != "" {
|
||||
|
||||
33
protocol/cloudflare/origin_request_test.go
Normal file
33
protocol/cloudflare/origin_request_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
//go:build with_cloudflare_tunnel
|
||||
|
||||
package cloudflare
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestOriginTLSServerName(t *testing.T) {
|
||||
t.Run("origin server name overrides host", func(t *testing.T) {
|
||||
serverName := originTLSServerName(OriginRequestConfig{
|
||||
OriginServerName: "origin.example.com",
|
||||
MatchSNIToHost: true,
|
||||
}, "request.example.com")
|
||||
if serverName != "origin.example.com" {
|
||||
t.Fatalf("expected origin.example.com, got %s", serverName)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("match sni to host strips port", func(t *testing.T) {
|
||||
serverName := originTLSServerName(OriginRequestConfig{
|
||||
MatchSNIToHost: true,
|
||||
}, "request.example.com:443")
|
||||
if serverName != "request.example.com" {
|
||||
t.Fatalf("expected request.example.com, got %s", serverName)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("disabled match keeps empty server name", func(t *testing.T) {
|
||||
serverName := originTLSServerName(OriginRequestConfig{}, "request.example.com")
|
||||
if serverName != "" {
|
||||
t.Fatalf("expected empty server name, got %s", serverName)
|
||||
}
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user