mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-13 20:28:32 +10:00
Fix cloudflared parity regressions
This commit is contained in:
@@ -56,17 +56,21 @@ func (v *oidcAccessValidator) Validate(ctx context.Context, request *http.Reques
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(v.audTags) == 0 {
|
||||
if accessTokenAudienceAllowed(token.Audience, v.audTags) {
|
||||
return nil
|
||||
}
|
||||
for _, jwtAudTag := range token.Audience {
|
||||
for _, acceptedAudTag := range v.audTags {
|
||||
if acceptedAudTag == jwtAudTag {
|
||||
return nil
|
||||
return E.New("access token audience does not match configured aud_tag")
|
||||
}
|
||||
|
||||
func accessTokenAudienceAllowed(tokenAudience []string, configuredAudTags []string) bool {
|
||||
for _, tokenAudTag := range tokenAudience {
|
||||
for _, configuredAudTag := range configuredAudTags {
|
||||
if configuredAudTag == tokenAudTag {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return E.New("access token audience does not match configured aud_tag")
|
||||
return false
|
||||
}
|
||||
|
||||
func accessIssuerURL(teamName string, environment string) string {
|
||||
|
||||
@@ -50,6 +50,43 @@ func TestValidateAccessConfiguration(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessTokenAudienceAllowed(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
tokenAudience []string
|
||||
configuredTags []string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "matching audience",
|
||||
tokenAudience: []string{"aud-1", "aud-2"},
|
||||
configuredTags: []string{"aud-2"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "empty configured tags rejected",
|
||||
tokenAudience: []string{"aud-1"},
|
||||
configuredTags: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "non matching audience rejected",
|
||||
tokenAudience: []string{"aud-1"},
|
||||
configuredTags: []string{"aud-2"},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
allowed := accessTokenAudienceAllowed(testCase.tokenAudience, testCase.configuredTags)
|
||||
if allowed != testCase.expected {
|
||||
t.Fatalf("accessTokenAudienceAllowed(%v, %v) = %v, want %v", testCase.tokenAudience, testCase.configuredTags, allowed, testCase.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundTripHTTPAccessDenied(t *testing.T) {
|
||||
originalFactory := newAccessValidator
|
||||
defer func() {
|
||||
|
||||
@@ -411,6 +411,13 @@ type http2ResponseWriter struct {
|
||||
headersSent bool
|
||||
}
|
||||
|
||||
func (w *http2ResponseWriter) AddTrailer(name, value string) {
|
||||
if !w.headersSent {
|
||||
return
|
||||
}
|
||||
w.writer.Header().Add(http2.TrailerPrefix+name, value)
|
||||
}
|
||||
|
||||
func (w *http2ResponseWriter) WriteResponse(responseError error, metadata []Metadata) error {
|
||||
if w.headersSent {
|
||||
return nil
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/quic-go"
|
||||
@@ -262,7 +263,7 @@ func (q *QUICConnection) handleStream(ctx context.Context, stream *quic.Stream,
|
||||
q.logger.Debug("failed to read connect request: ", err)
|
||||
return
|
||||
}
|
||||
handler.HandleDataStream(ctx, rwc, request, q.connIndex)
|
||||
handler.HandleDataStream(ctx, &nopCloserReadWriter{ReadWriteCloser: rwc}, request, q.connIndex)
|
||||
|
||||
case StreamTypeRPC:
|
||||
handler.HandleRPCStreamWithSender(ctx, rwc, q.connIndex, q)
|
||||
@@ -388,3 +389,33 @@ func (s *streamReadWriteCloser) Close() error {
|
||||
s.stream.CancelRead(0)
|
||||
return s.stream.Close()
|
||||
}
|
||||
|
||||
// nopCloserReadWriter lets handlers stop consuming the read side without closing
|
||||
// the underlying stream write side. This matches cloudflared's QUIC HTTP behavior,
|
||||
// where the request body can be closed before the response is fully written.
|
||||
type nopCloserReadWriter struct {
|
||||
io.ReadWriteCloser
|
||||
|
||||
sawEOF bool
|
||||
closed uint32
|
||||
}
|
||||
|
||||
func (n *nopCloserReadWriter) Read(p []byte) (int, error) {
|
||||
if n.sawEOF {
|
||||
return 0, io.EOF
|
||||
}
|
||||
if atomic.LoadUint32(&n.closed) > 0 {
|
||||
return 0, fmt.Errorf("closed by handler")
|
||||
}
|
||||
|
||||
readLen, err := n.ReadWriteCloser.Read(p)
|
||||
if err == io.EOF {
|
||||
n.sawEOF = true
|
||||
}
|
||||
return readLen, err
|
||||
}
|
||||
|
||||
func (n *nopCloserReadWriter) Close() error {
|
||||
atomic.StoreUint32(&n.closed, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,7 +2,11 @@
|
||||
|
||||
package cloudflare
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestQUICInitialPacketSize(t *testing.T) {
|
||||
testCases := []struct {
|
||||
@@ -23,3 +27,53 @@ func TestQUICInitialPacketSize(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockReadWriteCloser struct {
|
||||
reader strings.Reader
|
||||
writes []byte
|
||||
}
|
||||
|
||||
func (m *mockReadWriteCloser) Read(p []byte) (int, error) {
|
||||
return m.reader.Read(p)
|
||||
}
|
||||
|
||||
func (m *mockReadWriteCloser) Write(p []byte) (int, error) {
|
||||
m.writes = append(m.writes, p...)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (m *mockReadWriteCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestNOPCloserReadWriterCloseOnlyStopsReads(t *testing.T) {
|
||||
inner := &mockReadWriteCloser{reader: *strings.NewReader("payload")}
|
||||
wrapper := &nopCloserReadWriter{ReadWriteCloser: inner}
|
||||
|
||||
if err := wrapper.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := wrapper.Read(make([]byte, 1)); err == nil {
|
||||
t.Fatal("expected read to fail after close")
|
||||
}
|
||||
|
||||
if _, err := wrapper.Write([]byte("response")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(inner.writes) != "response" {
|
||||
t.Fatalf("unexpected writes %q", inner.writes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNOPCloserReadWriterTracksEOF(t *testing.T) {
|
||||
inner := &mockReadWriteCloser{reader: *strings.NewReader("")}
|
||||
wrapper := &nopCloserReadWriter{ReadWriteCloser: inner}
|
||||
|
||||
if _, err := wrapper.Read(make([]byte, 1)); err != io.EOF {
|
||||
t.Fatalf("expected EOF, got %v", err)
|
||||
}
|
||||
if _, err := wrapper.Read(make([]byte, 1)); err != io.EOF {
|
||||
t.Fatalf("expected cached EOF, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
133
protocol/cloudflare/datagram_rpc_test.go
Normal file
133
protocol/cloudflare/datagram_rpc_test.go
Normal file
@@ -0,0 +1,133 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sagernet/sing-box/adapter/inbound"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc"
|
||||
|
||||
capnp "zombiezen.com/go/capnproto2"
|
||||
)
|
||||
|
||||
func newRegisterUDPSessionCall(t *testing.T, traceContext string) (tunnelrpc.SessionManager_registerUdpSession, func() (tunnelrpc.RegisterUdpSessionResponse, error)) {
|
||||
t.Helper()
|
||||
|
||||
_, paramsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
params, err := tunnelrpc.NewSessionManager_registerUdpSession_Params(paramsSeg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sessionID := uuid.New()
|
||||
if err := params.SetSessionId(sessionID[:]); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := params.SetDstIp([]byte{127, 0, 0, 1}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
params.SetDstPort(53)
|
||||
params.SetCloseAfterIdleHint(int64(30))
|
||||
if err := params.SetTraceContext(traceContext); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, resultsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
results, err := tunnelrpc.NewSessionManager_registerUdpSession_Results(resultsSeg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
call := tunnelrpc.SessionManager_registerUdpSession{
|
||||
Ctx: context.Background(),
|
||||
Params: params,
|
||||
Results: results,
|
||||
}
|
||||
return call, results.Result
|
||||
}
|
||||
|
||||
func newUnregisterUDPSessionCall(t *testing.T) tunnelrpc.SessionManager_unregisterUdpSession {
|
||||
t.Helper()
|
||||
|
||||
_, paramsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
params, err := tunnelrpc.NewSessionManager_unregisterUdpSession_Params(paramsSeg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sessionID := uuid.New()
|
||||
if err := params.SetSessionId(sessionID[:]); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := params.SetMessage("close"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, resultsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
results, err := tunnelrpc.NewSessionManager_unregisterUdpSession_Results(resultsSeg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return tunnelrpc.SessionManager_unregisterUdpSession{
|
||||
Ctx: context.Background(),
|
||||
Params: params,
|
||||
Results: results,
|
||||
}
|
||||
}
|
||||
|
||||
func TestV3RPCRegisterUDPSessionReturnsUnsupportedResult(t *testing.T) {
|
||||
server := &cloudflaredV3Server{
|
||||
inbound: &Inbound{Adapter: inbound.NewAdapter(C.TypeCloudflared, "test")},
|
||||
}
|
||||
call, readResult := newRegisterUDPSessionCall(t, "trace-context")
|
||||
if err := server.RegisterUdpSession(call); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
result, err := readResult()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resultErr, err := result.Err()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resultErr != errUnsupportedDatagramV3UDPRegistration.Error() {
|
||||
t.Fatalf("unexpected registration error %q", resultErr)
|
||||
}
|
||||
spans, err := result.Spans()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(spans) != 0 {
|
||||
t.Fatalf("expected empty spans, got %x", spans)
|
||||
}
|
||||
}
|
||||
|
||||
func TestV3RPCUnregisterUDPSessionReturnsUnsupportedError(t *testing.T) {
|
||||
server := &cloudflaredV3Server{
|
||||
inbound: &Inbound{Adapter: inbound.NewAdapter(C.TypeCloudflared, "test")},
|
||||
}
|
||||
err := server.UnregisterUdpSession(newUnregisterUDPSessionCall(t))
|
||||
if err == nil {
|
||||
t.Fatal("expected unsupported unregister error")
|
||||
}
|
||||
if err.Error() != errUnsupportedDatagramV3UDPUnregistration.Error() {
|
||||
t.Fatalf("unexpected unregister error %v", err)
|
||||
}
|
||||
}
|
||||
73
protocol/cloudflare/datagram_rpc_v3.go
Normal file
73
protocol/cloudflare/datagram_rpc_v3.go
Normal file
@@ -0,0 +1,73 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
)
|
||||
|
||||
var (
|
||||
errUnsupportedDatagramV3UDPRegistration = errors.New("datagram v3 does not support RegisterUdpSession RPC")
|
||||
errUnsupportedDatagramV3UDPUnregistration = errors.New("datagram v3 does not support UnregisterUdpSession RPC")
|
||||
)
|
||||
|
||||
type cloudflaredV3Server struct {
|
||||
inbound *Inbound
|
||||
logger log.ContextLogger
|
||||
}
|
||||
|
||||
func (s *cloudflaredV3Server) RegisterUdpSession(call tunnelrpc.SessionManager_registerUdpSession) error {
|
||||
result, err := call.Results.NewResult()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := result.SetErr(errUnsupportedDatagramV3UDPRegistration.Error()); err != nil {
|
||||
return err
|
||||
}
|
||||
return result.SetSpans([]byte{})
|
||||
}
|
||||
|
||||
func (s *cloudflaredV3Server) UnregisterUdpSession(call tunnelrpc.SessionManager_unregisterUdpSession) error {
|
||||
return errUnsupportedDatagramV3UDPUnregistration
|
||||
}
|
||||
|
||||
func (s *cloudflaredV3Server) UpdateConfiguration(call tunnelrpc.ConfigurationManager_updateConfiguration) error {
|
||||
version := call.Params.Version()
|
||||
configData, _ := call.Params.Config()
|
||||
updateResult := s.inbound.ApplyConfig(version, configData)
|
||||
result, err := call.Results.NewResult()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result.SetLatestAppliedVersion(updateResult.LastAppliedVersion)
|
||||
if updateResult.Err != nil {
|
||||
result.SetErr(updateResult.Err.Error())
|
||||
} else {
|
||||
result.SetErr("")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ServeV3RPCStream serves configuration updates on v3 and rejects legacy UDP RPCs.
|
||||
func ServeV3RPCStream(ctx context.Context, stream io.ReadWriteCloser, inbound *Inbound, logger log.ContextLogger) {
|
||||
srv := &cloudflaredV3Server{
|
||||
inbound: inbound,
|
||||
logger: logger,
|
||||
}
|
||||
client := tunnelrpc.CloudflaredServer_ServerToClient(srv)
|
||||
transport := rpc.StreamTransport(stream)
|
||||
rpcConn := rpc.NewConn(transport, rpc.MainInterface(client.Client))
|
||||
<-rpcConn.Done()
|
||||
E.Errors(
|
||||
rpcConn.Close(),
|
||||
transport.Close(),
|
||||
)
|
||||
}
|
||||
@@ -474,6 +474,9 @@ func (s *cloudflaredServer) RegisterUdpSession(call tunnelrpc.SessionManager_reg
|
||||
|
||||
destinationPort := call.Params.DstPort()
|
||||
closeAfterIdle := time.Duration(call.Params.CloseAfterIdleHint())
|
||||
if _, traceErr := call.Params.TraceContext(); traceErr != nil {
|
||||
return traceErr
|
||||
}
|
||||
|
||||
err = s.muxer.RegisterSession(s.ctx, sessionID, net.IP(destinationIP), destinationPort, closeAfterIdle)
|
||||
|
||||
@@ -481,6 +484,9 @@ func (s *cloudflaredServer) RegisterUdpSession(call tunnelrpc.SessionManager_reg
|
||||
if allocErr != nil {
|
||||
return allocErr
|
||||
}
|
||||
if spansErr := result.SetSpans([]byte{}); spansErr != nil {
|
||||
return spansErr
|
||||
}
|
||||
if err != nil {
|
||||
result.SetErr(err.Error())
|
||||
}
|
||||
|
||||
@@ -34,6 +34,7 @@ const (
|
||||
var (
|
||||
loadOriginCABasePool = cloudflareRootCertPool
|
||||
readOriginCAFile = os.ReadFile
|
||||
proxyFromEnvironment = http.ProxyFromEnvironment
|
||||
)
|
||||
|
||||
// ConnectResponseWriter abstracts the response writing for both QUIC and HTTP/2.
|
||||
@@ -42,6 +43,10 @@ type ConnectResponseWriter interface {
|
||||
WriteResponse(responseError error, metadata []Metadata) error
|
||||
}
|
||||
|
||||
type connectResponseTrailerWriter interface {
|
||||
AddTrailer(name, value string)
|
||||
}
|
||||
|
||||
// quicResponseWriter writes ConnectResponse in QUIC data stream format (signature + capnp).
|
||||
type quicResponseWriter struct {
|
||||
stream io.Writer
|
||||
@@ -69,8 +74,13 @@ func (i *Inbound) HandleRPCStream(ctx context.Context, stream io.ReadWriteCloser
|
||||
|
||||
// HandleRPCStreamWithSender handles an RPC stream with access to the DatagramSender for V2 muxer lookup.
|
||||
func (i *Inbound) HandleRPCStreamWithSender(ctx context.Context, stream io.ReadWriteCloser, connIndex uint8, sender DatagramSender) {
|
||||
muxer := i.getOrCreateV2Muxer(sender)
|
||||
ServeRPCStream(ctx, stream, i, muxer, i.logger)
|
||||
switch datagramVersionForSender(sender) {
|
||||
case "v3":
|
||||
ServeV3RPCStream(ctx, stream, i, i.logger)
|
||||
default:
|
||||
muxer := i.getOrCreateV2Muxer(sender)
|
||||
ServeRPCStream(ctx, stream, i, muxer, i.logger)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleDatagram handles an incoming QUIC datagram.
|
||||
@@ -401,6 +411,13 @@ func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser,
|
||||
if err != nil && !E.IsClosedOrCanceled(err) {
|
||||
i.logger.DebugContext(ctx, "copy HTTP response body: ", err)
|
||||
}
|
||||
if trailerWriter, ok := respWriter.(connectResponseTrailerWriter); ok {
|
||||
for name, values := range response.Trailer {
|
||||
for _, value := range values {
|
||||
trailerWriter.AddTrailer(name, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter.InboundContext, originRequest OriginRequestConfig, requestHost string) (*http.Transport, func(), error) {
|
||||
@@ -417,7 +434,7 @@ func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter
|
||||
IdleConnTimeout: originRequest.KeepAliveTimeout,
|
||||
MaxIdleConns: originRequest.KeepAliveConnections,
|
||||
MaxIdleConnsPerHost: originRequest.KeepAliveConnections,
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
Proxy: proxyFromEnvironment,
|
||||
TLSClientConfig: tlsConfig,
|
||||
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
|
||||
return input, nil
|
||||
@@ -445,7 +462,7 @@ func (i *Inbound) newDirectOriginTransport(service ResolvedService, requestHost
|
||||
IdleConnTimeout: service.OriginRequest.KeepAliveTimeout,
|
||||
MaxIdleConns: service.OriginRequest.KeepAliveConnections,
|
||||
MaxIdleConnsPerHost: service.OriginRequest.KeepAliveConnections,
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
Proxy: proxyFromEnvironment,
|
||||
TLSClientConfig: tlsConfig,
|
||||
}
|
||||
switch service.Kind {
|
||||
|
||||
@@ -131,7 +131,13 @@ func TestNewOriginTLSConfigAppendsCustomCAInsteadOfReplacingBasePool(t *testing.
|
||||
}
|
||||
|
||||
func TestOriginTransportUsesProxyFromEnvironmentOnly(t *testing.T) {
|
||||
t.Setenv("HTTP_PROXY", "http://proxy.example.com:8080")
|
||||
originalProxyFromEnvironment := proxyFromEnvironment
|
||||
proxyFromEnvironment = func(request *http.Request) (*url.URL, error) {
|
||||
return url.Parse("http://proxy.example.com:8080")
|
||||
}
|
||||
defer func() {
|
||||
proxyFromEnvironment = originalProxyFromEnvironment
|
||||
}()
|
||||
|
||||
inbound := &Inbound{}
|
||||
transport, cleanup, err := inbound.newDirectOriginTransport(ResolvedService{
|
||||
|
||||
91
protocol/cloudflare/response_trailer_test.go
Normal file
91
protocol/cloudflare/response_trailer_test.go
Normal file
@@ -0,0 +1,91 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
)
|
||||
|
||||
type trailerCaptureResponseWriter struct {
|
||||
status int
|
||||
trailers http.Header
|
||||
}
|
||||
|
||||
func (w *trailerCaptureResponseWriter) WriteResponse(responseError error, metadata []Metadata) error {
|
||||
for _, entry := range metadata {
|
||||
if entry.Key == metadataHTTPStatus {
|
||||
w.status = http.StatusOK
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *trailerCaptureResponseWriter) AddTrailer(name, value string) {
|
||||
if w.trailers == nil {
|
||||
w.trailers = make(http.Header)
|
||||
}
|
||||
w.trailers.Add(name, value)
|
||||
}
|
||||
|
||||
type captureReadWriteCloser struct {
|
||||
body []byte
|
||||
}
|
||||
|
||||
func (c *captureReadWriteCloser) Read(_ []byte) (int, error) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (c *captureReadWriteCloser) Write(p []byte) (int, error) {
|
||||
c.body = append(c.body, p...)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (c *captureReadWriteCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestRoundTripHTTPCopiesTrailers(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Add("Trailer", "X-Test-Trailer")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
w.Header().Set("X-Test-Trailer", "trailer-value")
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
transport, ok := server.Client().Transport.(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected transport type %T", server.Client().Transport)
|
||||
}
|
||||
|
||||
inboundInstance := &Inbound{
|
||||
logger: log.NewNOPFactory().NewLogger("test"),
|
||||
}
|
||||
stream := &captureReadWriteCloser{}
|
||||
respWriter := &trailerCaptureResponseWriter{}
|
||||
request := &ConnectRequest{
|
||||
Dest: server.URL,
|
||||
Type: ConnectionTypeHTTP,
|
||||
Metadata: []Metadata{
|
||||
{Key: metadataHTTPMethod, Val: http.MethodGet},
|
||||
{Key: metadataHTTPHost, Val: "example.com"},
|
||||
},
|
||||
}
|
||||
|
||||
inboundInstance.roundTripHTTP(context.Background(), stream, respWriter, request, ResolvedService{
|
||||
OriginRequest: defaultOriginRequestConfig(),
|
||||
}, transport)
|
||||
|
||||
if got := respWriter.trailers.Get("X-Test-Trailer"); got != "trailer-value" {
|
||||
t.Fatalf("expected copied trailer, got %q", got)
|
||||
}
|
||||
if string(stream.body) != "ok" {
|
||||
t.Fatalf("unexpected response body %q", stream.body)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user