From c07abeeab3c7ca6054035801b976d3d19af11c82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 26 Mar 2026 11:30:22 +0800 Subject: [PATCH] Fix cloudflared parity regressions --- protocol/cloudflare/access.go | 16 ++- protocol/cloudflare/access_test.go | 37 ++++++ protocol/cloudflare/connection_http2.go | 7 + protocol/cloudflare/connection_quic.go | 33 ++++- protocol/cloudflare/connection_quic_test.go | 56 +++++++- protocol/cloudflare/datagram_rpc_test.go | 133 +++++++++++++++++++ protocol/cloudflare/datagram_rpc_v3.go | 73 ++++++++++ protocol/cloudflare/datagram_v2.go | 6 + protocol/cloudflare/dispatch.go | 25 +++- protocol/cloudflare/origin_request_test.go | 8 +- protocol/cloudflare/response_trailer_test.go | 91 +++++++++++++ 11 files changed, 472 insertions(+), 13 deletions(-) create mode 100644 protocol/cloudflare/datagram_rpc_test.go create mode 100644 protocol/cloudflare/datagram_rpc_v3.go create mode 100644 protocol/cloudflare/response_trailer_test.go diff --git a/protocol/cloudflare/access.go b/protocol/cloudflare/access.go index fc40e7233..f51168e21 100644 --- a/protocol/cloudflare/access.go +++ b/protocol/cloudflare/access.go @@ -56,17 +56,21 @@ func (v *oidcAccessValidator) Validate(ctx context.Context, request *http.Reques if err != nil { return err } - if len(v.audTags) == 0 { + if accessTokenAudienceAllowed(token.Audience, v.audTags) { return nil } - for _, jwtAudTag := range token.Audience { - for _, acceptedAudTag := range v.audTags { - if acceptedAudTag == jwtAudTag { - return nil + return E.New("access token audience does not match configured aud_tag") +} + +func accessTokenAudienceAllowed(tokenAudience []string, configuredAudTags []string) bool { + for _, tokenAudTag := range tokenAudience { + for _, configuredAudTag := range configuredAudTags { + if configuredAudTag == tokenAudTag { + return true } } } - return E.New("access token audience does not match configured aud_tag") + return false } func accessIssuerURL(teamName string, environment string) string { diff --git a/protocol/cloudflare/access_test.go b/protocol/cloudflare/access_test.go index 3cceb155e..5fb6fa178 100644 --- a/protocol/cloudflare/access_test.go +++ b/protocol/cloudflare/access_test.go @@ -50,6 +50,43 @@ func TestValidateAccessConfiguration(t *testing.T) { } } +func TestAccessTokenAudienceAllowed(t *testing.T) { + testCases := []struct { + name string + tokenAudience []string + configuredTags []string + expected bool + }{ + { + name: "matching audience", + tokenAudience: []string{"aud-1", "aud-2"}, + configuredTags: []string{"aud-2"}, + expected: true, + }, + { + name: "empty configured tags rejected", + tokenAudience: []string{"aud-1"}, + configuredTags: nil, + expected: false, + }, + { + name: "non matching audience rejected", + tokenAudience: []string{"aud-1"}, + configuredTags: []string{"aud-2"}, + expected: false, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + allowed := accessTokenAudienceAllowed(testCase.tokenAudience, testCase.configuredTags) + if allowed != testCase.expected { + t.Fatalf("accessTokenAudienceAllowed(%v, %v) = %v, want %v", testCase.tokenAudience, testCase.configuredTags, allowed, testCase.expected) + } + }) + } +} + func TestRoundTripHTTPAccessDenied(t *testing.T) { originalFactory := newAccessValidator defer func() { diff --git a/protocol/cloudflare/connection_http2.go b/protocol/cloudflare/connection_http2.go index 33192806f..56ac895e5 100644 --- a/protocol/cloudflare/connection_http2.go +++ b/protocol/cloudflare/connection_http2.go @@ -411,6 +411,13 @@ type http2ResponseWriter struct { headersSent bool } +func (w *http2ResponseWriter) AddTrailer(name, value string) { + if !w.headersSent { + return + } + w.writer.Header().Add(http2.TrailerPrefix+name, value) +} + func (w *http2ResponseWriter) WriteResponse(responseError error, metadata []Metadata) error { if w.headersSent { return nil diff --git a/protocol/cloudflare/connection_quic.go b/protocol/cloudflare/connection_quic.go index 4c5256d09..e83bf8298 100644 --- a/protocol/cloudflare/connection_quic.go +++ b/protocol/cloudflare/connection_quic.go @@ -9,6 +9,7 @@ import ( "io" "net" "sync" + "sync/atomic" "time" "github.com/sagernet/quic-go" @@ -262,7 +263,7 @@ func (q *QUICConnection) handleStream(ctx context.Context, stream *quic.Stream, q.logger.Debug("failed to read connect request: ", err) return } - handler.HandleDataStream(ctx, rwc, request, q.connIndex) + handler.HandleDataStream(ctx, &nopCloserReadWriter{ReadWriteCloser: rwc}, request, q.connIndex) case StreamTypeRPC: handler.HandleRPCStreamWithSender(ctx, rwc, q.connIndex, q) @@ -388,3 +389,33 @@ func (s *streamReadWriteCloser) Close() error { s.stream.CancelRead(0) return s.stream.Close() } + +// nopCloserReadWriter lets handlers stop consuming the read side without closing +// the underlying stream write side. This matches cloudflared's QUIC HTTP behavior, +// where the request body can be closed before the response is fully written. +type nopCloserReadWriter struct { + io.ReadWriteCloser + + sawEOF bool + closed uint32 +} + +func (n *nopCloserReadWriter) Read(p []byte) (int, error) { + if n.sawEOF { + return 0, io.EOF + } + if atomic.LoadUint32(&n.closed) > 0 { + return 0, fmt.Errorf("closed by handler") + } + + readLen, err := n.ReadWriteCloser.Read(p) + if err == io.EOF { + n.sawEOF = true + } + return readLen, err +} + +func (n *nopCloserReadWriter) Close() error { + atomic.StoreUint32(&n.closed, 1) + return nil +} diff --git a/protocol/cloudflare/connection_quic_test.go b/protocol/cloudflare/connection_quic_test.go index ac7f58aba..78479dad8 100644 --- a/protocol/cloudflare/connection_quic_test.go +++ b/protocol/cloudflare/connection_quic_test.go @@ -2,7 +2,11 @@ package cloudflare -import "testing" +import ( + "io" + "strings" + "testing" +) func TestQUICInitialPacketSize(t *testing.T) { testCases := []struct { @@ -23,3 +27,53 @@ func TestQUICInitialPacketSize(t *testing.T) { }) } } + +type mockReadWriteCloser struct { + reader strings.Reader + writes []byte +} + +func (m *mockReadWriteCloser) Read(p []byte) (int, error) { + return m.reader.Read(p) +} + +func (m *mockReadWriteCloser) Write(p []byte) (int, error) { + m.writes = append(m.writes, p...) + return len(p), nil +} + +func (m *mockReadWriteCloser) Close() error { + return nil +} + +func TestNOPCloserReadWriterCloseOnlyStopsReads(t *testing.T) { + inner := &mockReadWriteCloser{reader: *strings.NewReader("payload")} + wrapper := &nopCloserReadWriter{ReadWriteCloser: inner} + + if err := wrapper.Close(); err != nil { + t.Fatal(err) + } + + if _, err := wrapper.Read(make([]byte, 1)); err == nil { + t.Fatal("expected read to fail after close") + } + + if _, err := wrapper.Write([]byte("response")); err != nil { + t.Fatal(err) + } + if string(inner.writes) != "response" { + t.Fatalf("unexpected writes %q", inner.writes) + } +} + +func TestNOPCloserReadWriterTracksEOF(t *testing.T) { + inner := &mockReadWriteCloser{reader: *strings.NewReader("")} + wrapper := &nopCloserReadWriter{ReadWriteCloser: inner} + + if _, err := wrapper.Read(make([]byte, 1)); err != io.EOF { + t.Fatalf("expected EOF, got %v", err) + } + if _, err := wrapper.Read(make([]byte, 1)); err != io.EOF { + t.Fatalf("expected cached EOF, got %v", err) + } +} diff --git a/protocol/cloudflare/datagram_rpc_test.go b/protocol/cloudflare/datagram_rpc_test.go new file mode 100644 index 000000000..08974a9cb --- /dev/null +++ b/protocol/cloudflare/datagram_rpc_test.go @@ -0,0 +1,133 @@ +//go:build with_cloudflared + +package cloudflare + +import ( + "context" + "testing" + + "github.com/google/uuid" + "github.com/sagernet/sing-box/adapter/inbound" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc" + + capnp "zombiezen.com/go/capnproto2" +) + +func newRegisterUDPSessionCall(t *testing.T, traceContext string) (tunnelrpc.SessionManager_registerUdpSession, func() (tunnelrpc.RegisterUdpSessionResponse, error)) { + t.Helper() + + _, paramsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil)) + if err != nil { + t.Fatal(err) + } + params, err := tunnelrpc.NewSessionManager_registerUdpSession_Params(paramsSeg) + if err != nil { + t.Fatal(err) + } + sessionID := uuid.New() + if err := params.SetSessionId(sessionID[:]); err != nil { + t.Fatal(err) + } + if err := params.SetDstIp([]byte{127, 0, 0, 1}); err != nil { + t.Fatal(err) + } + params.SetDstPort(53) + params.SetCloseAfterIdleHint(int64(30)) + if err := params.SetTraceContext(traceContext); err != nil { + t.Fatal(err) + } + + _, resultsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil)) + if err != nil { + t.Fatal(err) + } + results, err := tunnelrpc.NewSessionManager_registerUdpSession_Results(resultsSeg) + if err != nil { + t.Fatal(err) + } + + call := tunnelrpc.SessionManager_registerUdpSession{ + Ctx: context.Background(), + Params: params, + Results: results, + } + return call, results.Result +} + +func newUnregisterUDPSessionCall(t *testing.T) tunnelrpc.SessionManager_unregisterUdpSession { + t.Helper() + + _, paramsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil)) + if err != nil { + t.Fatal(err) + } + params, err := tunnelrpc.NewSessionManager_unregisterUdpSession_Params(paramsSeg) + if err != nil { + t.Fatal(err) + } + sessionID := uuid.New() + if err := params.SetSessionId(sessionID[:]); err != nil { + t.Fatal(err) + } + if err := params.SetMessage("close"); err != nil { + t.Fatal(err) + } + + _, resultsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil)) + if err != nil { + t.Fatal(err) + } + results, err := tunnelrpc.NewSessionManager_unregisterUdpSession_Results(resultsSeg) + if err != nil { + t.Fatal(err) + } + + return tunnelrpc.SessionManager_unregisterUdpSession{ + Ctx: context.Background(), + Params: params, + Results: results, + } +} + +func TestV3RPCRegisterUDPSessionReturnsUnsupportedResult(t *testing.T) { + server := &cloudflaredV3Server{ + inbound: &Inbound{Adapter: inbound.NewAdapter(C.TypeCloudflared, "test")}, + } + call, readResult := newRegisterUDPSessionCall(t, "trace-context") + if err := server.RegisterUdpSession(call); err != nil { + t.Fatal(err) + } + + result, err := readResult() + if err != nil { + t.Fatal(err) + } + resultErr, err := result.Err() + if err != nil { + t.Fatal(err) + } + if resultErr != errUnsupportedDatagramV3UDPRegistration.Error() { + t.Fatalf("unexpected registration error %q", resultErr) + } + spans, err := result.Spans() + if err != nil { + t.Fatal(err) + } + if len(spans) != 0 { + t.Fatalf("expected empty spans, got %x", spans) + } +} + +func TestV3RPCUnregisterUDPSessionReturnsUnsupportedError(t *testing.T) { + server := &cloudflaredV3Server{ + inbound: &Inbound{Adapter: inbound.NewAdapter(C.TypeCloudflared, "test")}, + } + err := server.UnregisterUdpSession(newUnregisterUDPSessionCall(t)) + if err == nil { + t.Fatal("expected unsupported unregister error") + } + if err.Error() != errUnsupportedDatagramV3UDPUnregistration.Error() { + t.Fatalf("unexpected unregister error %v", err) + } +} diff --git a/protocol/cloudflare/datagram_rpc_v3.go b/protocol/cloudflare/datagram_rpc_v3.go new file mode 100644 index 000000000..38af323ff --- /dev/null +++ b/protocol/cloudflare/datagram_rpc_v3.go @@ -0,0 +1,73 @@ +//go:build with_cloudflared + +package cloudflare + +import ( + "context" + "errors" + "io" + + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc" + E "github.com/sagernet/sing/common/exceptions" + + "zombiezen.com/go/capnproto2/rpc" +) + +var ( + errUnsupportedDatagramV3UDPRegistration = errors.New("datagram v3 does not support RegisterUdpSession RPC") + errUnsupportedDatagramV3UDPUnregistration = errors.New("datagram v3 does not support UnregisterUdpSession RPC") +) + +type cloudflaredV3Server struct { + inbound *Inbound + logger log.ContextLogger +} + +func (s *cloudflaredV3Server) RegisterUdpSession(call tunnelrpc.SessionManager_registerUdpSession) error { + result, err := call.Results.NewResult() + if err != nil { + return err + } + if err := result.SetErr(errUnsupportedDatagramV3UDPRegistration.Error()); err != nil { + return err + } + return result.SetSpans([]byte{}) +} + +func (s *cloudflaredV3Server) UnregisterUdpSession(call tunnelrpc.SessionManager_unregisterUdpSession) error { + return errUnsupportedDatagramV3UDPUnregistration +} + +func (s *cloudflaredV3Server) UpdateConfiguration(call tunnelrpc.ConfigurationManager_updateConfiguration) error { + version := call.Params.Version() + configData, _ := call.Params.Config() + updateResult := s.inbound.ApplyConfig(version, configData) + result, err := call.Results.NewResult() + if err != nil { + return err + } + result.SetLatestAppliedVersion(updateResult.LastAppliedVersion) + if updateResult.Err != nil { + result.SetErr(updateResult.Err.Error()) + } else { + result.SetErr("") + } + return nil +} + +// ServeV3RPCStream serves configuration updates on v3 and rejects legacy UDP RPCs. +func ServeV3RPCStream(ctx context.Context, stream io.ReadWriteCloser, inbound *Inbound, logger log.ContextLogger) { + srv := &cloudflaredV3Server{ + inbound: inbound, + logger: logger, + } + client := tunnelrpc.CloudflaredServer_ServerToClient(srv) + transport := rpc.StreamTransport(stream) + rpcConn := rpc.NewConn(transport, rpc.MainInterface(client.Client)) + <-rpcConn.Done() + E.Errors( + rpcConn.Close(), + transport.Close(), + ) +} diff --git a/protocol/cloudflare/datagram_v2.go b/protocol/cloudflare/datagram_v2.go index 9ba52f973..8fa3ffa62 100644 --- a/protocol/cloudflare/datagram_v2.go +++ b/protocol/cloudflare/datagram_v2.go @@ -474,6 +474,9 @@ func (s *cloudflaredServer) RegisterUdpSession(call tunnelrpc.SessionManager_reg destinationPort := call.Params.DstPort() closeAfterIdle := time.Duration(call.Params.CloseAfterIdleHint()) + if _, traceErr := call.Params.TraceContext(); traceErr != nil { + return traceErr + } err = s.muxer.RegisterSession(s.ctx, sessionID, net.IP(destinationIP), destinationPort, closeAfterIdle) @@ -481,6 +484,9 @@ func (s *cloudflaredServer) RegisterUdpSession(call tunnelrpc.SessionManager_reg if allocErr != nil { return allocErr } + if spansErr := result.SetSpans([]byte{}); spansErr != nil { + return spansErr + } if err != nil { result.SetErr(err.Error()) } diff --git a/protocol/cloudflare/dispatch.go b/protocol/cloudflare/dispatch.go index f3f92ddca..77afee784 100644 --- a/protocol/cloudflare/dispatch.go +++ b/protocol/cloudflare/dispatch.go @@ -34,6 +34,7 @@ const ( var ( loadOriginCABasePool = cloudflareRootCertPool readOriginCAFile = os.ReadFile + proxyFromEnvironment = http.ProxyFromEnvironment ) // ConnectResponseWriter abstracts the response writing for both QUIC and HTTP/2. @@ -42,6 +43,10 @@ type ConnectResponseWriter interface { WriteResponse(responseError error, metadata []Metadata) error } +type connectResponseTrailerWriter interface { + AddTrailer(name, value string) +} + // quicResponseWriter writes ConnectResponse in QUIC data stream format (signature + capnp). type quicResponseWriter struct { stream io.Writer @@ -69,8 +74,13 @@ func (i *Inbound) HandleRPCStream(ctx context.Context, stream io.ReadWriteCloser // HandleRPCStreamWithSender handles an RPC stream with access to the DatagramSender for V2 muxer lookup. func (i *Inbound) HandleRPCStreamWithSender(ctx context.Context, stream io.ReadWriteCloser, connIndex uint8, sender DatagramSender) { - muxer := i.getOrCreateV2Muxer(sender) - ServeRPCStream(ctx, stream, i, muxer, i.logger) + switch datagramVersionForSender(sender) { + case "v3": + ServeV3RPCStream(ctx, stream, i, i.logger) + default: + muxer := i.getOrCreateV2Muxer(sender) + ServeRPCStream(ctx, stream, i, muxer, i.logger) + } } // HandleDatagram handles an incoming QUIC datagram. @@ -401,6 +411,13 @@ func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser, if err != nil && !E.IsClosedOrCanceled(err) { i.logger.DebugContext(ctx, "copy HTTP response body: ", err) } + if trailerWriter, ok := respWriter.(connectResponseTrailerWriter); ok { + for name, values := range response.Trailer { + for _, value := range values { + trailerWriter.AddTrailer(name, value) + } + } + } } func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter.InboundContext, originRequest OriginRequestConfig, requestHost string) (*http.Transport, func(), error) { @@ -417,7 +434,7 @@ func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter IdleConnTimeout: originRequest.KeepAliveTimeout, MaxIdleConns: originRequest.KeepAliveConnections, MaxIdleConnsPerHost: originRequest.KeepAliveConnections, - Proxy: http.ProxyFromEnvironment, + Proxy: proxyFromEnvironment, TLSClientConfig: tlsConfig, DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { return input, nil @@ -445,7 +462,7 @@ func (i *Inbound) newDirectOriginTransport(service ResolvedService, requestHost IdleConnTimeout: service.OriginRequest.KeepAliveTimeout, MaxIdleConns: service.OriginRequest.KeepAliveConnections, MaxIdleConnsPerHost: service.OriginRequest.KeepAliveConnections, - Proxy: http.ProxyFromEnvironment, + Proxy: proxyFromEnvironment, TLSClientConfig: tlsConfig, } switch service.Kind { diff --git a/protocol/cloudflare/origin_request_test.go b/protocol/cloudflare/origin_request_test.go index ec94ef123..a63c42236 100644 --- a/protocol/cloudflare/origin_request_test.go +++ b/protocol/cloudflare/origin_request_test.go @@ -131,7 +131,13 @@ func TestNewOriginTLSConfigAppendsCustomCAInsteadOfReplacingBasePool(t *testing. } func TestOriginTransportUsesProxyFromEnvironmentOnly(t *testing.T) { - t.Setenv("HTTP_PROXY", "http://proxy.example.com:8080") + originalProxyFromEnvironment := proxyFromEnvironment + proxyFromEnvironment = func(request *http.Request) (*url.URL, error) { + return url.Parse("http://proxy.example.com:8080") + } + defer func() { + proxyFromEnvironment = originalProxyFromEnvironment + }() inbound := &Inbound{} transport, cleanup, err := inbound.newDirectOriginTransport(ResolvedService{ diff --git a/protocol/cloudflare/response_trailer_test.go b/protocol/cloudflare/response_trailer_test.go new file mode 100644 index 000000000..5b833972a --- /dev/null +++ b/protocol/cloudflare/response_trailer_test.go @@ -0,0 +1,91 @@ +//go:build with_cloudflared + +package cloudflare + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/sagernet/sing-box/log" +) + +type trailerCaptureResponseWriter struct { + status int + trailers http.Header +} + +func (w *trailerCaptureResponseWriter) WriteResponse(responseError error, metadata []Metadata) error { + for _, entry := range metadata { + if entry.Key == metadataHTTPStatus { + w.status = http.StatusOK + } + } + return nil +} + +func (w *trailerCaptureResponseWriter) AddTrailer(name, value string) { + if w.trailers == nil { + w.trailers = make(http.Header) + } + w.trailers.Add(name, value) +} + +type captureReadWriteCloser struct { + body []byte +} + +func (c *captureReadWriteCloser) Read(_ []byte) (int, error) { + return 0, io.EOF +} + +func (c *captureReadWriteCloser) Write(p []byte) (int, error) { + c.body = append(c.body, p...) + return len(p), nil +} + +func (c *captureReadWriteCloser) Close() error { + return nil +} + +func TestRoundTripHTTPCopiesTrailers(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Trailer", "X-Test-Trailer") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + w.Header().Set("X-Test-Trailer", "trailer-value") + })) + defer server.Close() + + transport, ok := server.Client().Transport.(*http.Transport) + if !ok { + t.Fatalf("unexpected transport type %T", server.Client().Transport) + } + + inboundInstance := &Inbound{ + logger: log.NewNOPFactory().NewLogger("test"), + } + stream := &captureReadWriteCloser{} + respWriter := &trailerCaptureResponseWriter{} + request := &ConnectRequest{ + Dest: server.URL, + Type: ConnectionTypeHTTP, + Metadata: []Metadata{ + {Key: metadataHTTPMethod, Val: http.MethodGet}, + {Key: metadataHTTPHost, Val: "example.com"}, + }, + } + + inboundInstance.roundTripHTTP(context.Background(), stream, respWriter, request, ResolvedService{ + OriginRequest: defaultOriginRequestConfig(), + }, transport) + + if got := respWriter.trailers.Get("X-Test-Trailer"); got != "trailer-value" { + t.Fatalf("expected copied trailer, got %q", got) + } + if string(stream.body) != "ok" { + t.Fatalf("unexpected response body %q", stream.body) + } +}