Fix cloudflared parity regressions

This commit is contained in:
世界
2026-03-26 11:30:22 +08:00
parent 3eb626581f
commit c07abeeab3
11 changed files with 472 additions and 13 deletions

View File

@@ -56,17 +56,21 @@ func (v *oidcAccessValidator) Validate(ctx context.Context, request *http.Reques
if err != nil {
return err
}
if len(v.audTags) == 0 {
if accessTokenAudienceAllowed(token.Audience, v.audTags) {
return nil
}
for _, jwtAudTag := range token.Audience {
for _, acceptedAudTag := range v.audTags {
if acceptedAudTag == jwtAudTag {
return nil
return E.New("access token audience does not match configured aud_tag")
}
func accessTokenAudienceAllowed(tokenAudience []string, configuredAudTags []string) bool {
for _, tokenAudTag := range tokenAudience {
for _, configuredAudTag := range configuredAudTags {
if configuredAudTag == tokenAudTag {
return true
}
}
}
return E.New("access token audience does not match configured aud_tag")
return false
}
func accessIssuerURL(teamName string, environment string) string {

View File

@@ -50,6 +50,43 @@ func TestValidateAccessConfiguration(t *testing.T) {
}
}
func TestAccessTokenAudienceAllowed(t *testing.T) {
testCases := []struct {
name string
tokenAudience []string
configuredTags []string
expected bool
}{
{
name: "matching audience",
tokenAudience: []string{"aud-1", "aud-2"},
configuredTags: []string{"aud-2"},
expected: true,
},
{
name: "empty configured tags rejected",
tokenAudience: []string{"aud-1"},
configuredTags: nil,
expected: false,
},
{
name: "non matching audience rejected",
tokenAudience: []string{"aud-1"},
configuredTags: []string{"aud-2"},
expected: false,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
allowed := accessTokenAudienceAllowed(testCase.tokenAudience, testCase.configuredTags)
if allowed != testCase.expected {
t.Fatalf("accessTokenAudienceAllowed(%v, %v) = %v, want %v", testCase.tokenAudience, testCase.configuredTags, allowed, testCase.expected)
}
})
}
}
func TestRoundTripHTTPAccessDenied(t *testing.T) {
originalFactory := newAccessValidator
defer func() {

View File

@@ -411,6 +411,13 @@ type http2ResponseWriter struct {
headersSent bool
}
func (w *http2ResponseWriter) AddTrailer(name, value string) {
if !w.headersSent {
return
}
w.writer.Header().Add(http2.TrailerPrefix+name, value)
}
func (w *http2ResponseWriter) WriteResponse(responseError error, metadata []Metadata) error {
if w.headersSent {
return nil

View File

@@ -9,6 +9,7 @@ import (
"io"
"net"
"sync"
"sync/atomic"
"time"
"github.com/sagernet/quic-go"
@@ -262,7 +263,7 @@ func (q *QUICConnection) handleStream(ctx context.Context, stream *quic.Stream,
q.logger.Debug("failed to read connect request: ", err)
return
}
handler.HandleDataStream(ctx, rwc, request, q.connIndex)
handler.HandleDataStream(ctx, &nopCloserReadWriter{ReadWriteCloser: rwc}, request, q.connIndex)
case StreamTypeRPC:
handler.HandleRPCStreamWithSender(ctx, rwc, q.connIndex, q)
@@ -388,3 +389,33 @@ func (s *streamReadWriteCloser) Close() error {
s.stream.CancelRead(0)
return s.stream.Close()
}
// nopCloserReadWriter lets handlers stop consuming the read side without closing
// the underlying stream write side. This matches cloudflared's QUIC HTTP behavior,
// where the request body can be closed before the response is fully written.
type nopCloserReadWriter struct {
io.ReadWriteCloser
sawEOF bool
closed uint32
}
func (n *nopCloserReadWriter) Read(p []byte) (int, error) {
if n.sawEOF {
return 0, io.EOF
}
if atomic.LoadUint32(&n.closed) > 0 {
return 0, fmt.Errorf("closed by handler")
}
readLen, err := n.ReadWriteCloser.Read(p)
if err == io.EOF {
n.sawEOF = true
}
return readLen, err
}
func (n *nopCloserReadWriter) Close() error {
atomic.StoreUint32(&n.closed, 1)
return nil
}

View File

@@ -2,7 +2,11 @@
package cloudflare
import "testing"
import (
"io"
"strings"
"testing"
)
func TestQUICInitialPacketSize(t *testing.T) {
testCases := []struct {
@@ -23,3 +27,53 @@ func TestQUICInitialPacketSize(t *testing.T) {
})
}
}
type mockReadWriteCloser struct {
reader strings.Reader
writes []byte
}
func (m *mockReadWriteCloser) Read(p []byte) (int, error) {
return m.reader.Read(p)
}
func (m *mockReadWriteCloser) Write(p []byte) (int, error) {
m.writes = append(m.writes, p...)
return len(p), nil
}
func (m *mockReadWriteCloser) Close() error {
return nil
}
func TestNOPCloserReadWriterCloseOnlyStopsReads(t *testing.T) {
inner := &mockReadWriteCloser{reader: *strings.NewReader("payload")}
wrapper := &nopCloserReadWriter{ReadWriteCloser: inner}
if err := wrapper.Close(); err != nil {
t.Fatal(err)
}
if _, err := wrapper.Read(make([]byte, 1)); err == nil {
t.Fatal("expected read to fail after close")
}
if _, err := wrapper.Write([]byte("response")); err != nil {
t.Fatal(err)
}
if string(inner.writes) != "response" {
t.Fatalf("unexpected writes %q", inner.writes)
}
}
func TestNOPCloserReadWriterTracksEOF(t *testing.T) {
inner := &mockReadWriteCloser{reader: *strings.NewReader("")}
wrapper := &nopCloserReadWriter{ReadWriteCloser: inner}
if _, err := wrapper.Read(make([]byte, 1)); err != io.EOF {
t.Fatalf("expected EOF, got %v", err)
}
if _, err := wrapper.Read(make([]byte, 1)); err != io.EOF {
t.Fatalf("expected cached EOF, got %v", err)
}
}

View File

@@ -0,0 +1,133 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"testing"
"github.com/google/uuid"
"github.com/sagernet/sing-box/adapter/inbound"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc"
capnp "zombiezen.com/go/capnproto2"
)
func newRegisterUDPSessionCall(t *testing.T, traceContext string) (tunnelrpc.SessionManager_registerUdpSession, func() (tunnelrpc.RegisterUdpSessionResponse, error)) {
t.Helper()
_, paramsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil))
if err != nil {
t.Fatal(err)
}
params, err := tunnelrpc.NewSessionManager_registerUdpSession_Params(paramsSeg)
if err != nil {
t.Fatal(err)
}
sessionID := uuid.New()
if err := params.SetSessionId(sessionID[:]); err != nil {
t.Fatal(err)
}
if err := params.SetDstIp([]byte{127, 0, 0, 1}); err != nil {
t.Fatal(err)
}
params.SetDstPort(53)
params.SetCloseAfterIdleHint(int64(30))
if err := params.SetTraceContext(traceContext); err != nil {
t.Fatal(err)
}
_, resultsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil))
if err != nil {
t.Fatal(err)
}
results, err := tunnelrpc.NewSessionManager_registerUdpSession_Results(resultsSeg)
if err != nil {
t.Fatal(err)
}
call := tunnelrpc.SessionManager_registerUdpSession{
Ctx: context.Background(),
Params: params,
Results: results,
}
return call, results.Result
}
func newUnregisterUDPSessionCall(t *testing.T) tunnelrpc.SessionManager_unregisterUdpSession {
t.Helper()
_, paramsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil))
if err != nil {
t.Fatal(err)
}
params, err := tunnelrpc.NewSessionManager_unregisterUdpSession_Params(paramsSeg)
if err != nil {
t.Fatal(err)
}
sessionID := uuid.New()
if err := params.SetSessionId(sessionID[:]); err != nil {
t.Fatal(err)
}
if err := params.SetMessage("close"); err != nil {
t.Fatal(err)
}
_, resultsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil))
if err != nil {
t.Fatal(err)
}
results, err := tunnelrpc.NewSessionManager_unregisterUdpSession_Results(resultsSeg)
if err != nil {
t.Fatal(err)
}
return tunnelrpc.SessionManager_unregisterUdpSession{
Ctx: context.Background(),
Params: params,
Results: results,
}
}
func TestV3RPCRegisterUDPSessionReturnsUnsupportedResult(t *testing.T) {
server := &cloudflaredV3Server{
inbound: &Inbound{Adapter: inbound.NewAdapter(C.TypeCloudflared, "test")},
}
call, readResult := newRegisterUDPSessionCall(t, "trace-context")
if err := server.RegisterUdpSession(call); err != nil {
t.Fatal(err)
}
result, err := readResult()
if err != nil {
t.Fatal(err)
}
resultErr, err := result.Err()
if err != nil {
t.Fatal(err)
}
if resultErr != errUnsupportedDatagramV3UDPRegistration.Error() {
t.Fatalf("unexpected registration error %q", resultErr)
}
spans, err := result.Spans()
if err != nil {
t.Fatal(err)
}
if len(spans) != 0 {
t.Fatalf("expected empty spans, got %x", spans)
}
}
func TestV3RPCUnregisterUDPSessionReturnsUnsupportedError(t *testing.T) {
server := &cloudflaredV3Server{
inbound: &Inbound{Adapter: inbound.NewAdapter(C.TypeCloudflared, "test")},
}
err := server.UnregisterUdpSession(newUnregisterUDPSessionCall(t))
if err == nil {
t.Fatal("expected unsupported unregister error")
}
if err.Error() != errUnsupportedDatagramV3UDPUnregistration.Error() {
t.Fatalf("unexpected unregister error %v", err)
}
}

View File

@@ -0,0 +1,73 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"errors"
"io"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc"
E "github.com/sagernet/sing/common/exceptions"
"zombiezen.com/go/capnproto2/rpc"
)
var (
errUnsupportedDatagramV3UDPRegistration = errors.New("datagram v3 does not support RegisterUdpSession RPC")
errUnsupportedDatagramV3UDPUnregistration = errors.New("datagram v3 does not support UnregisterUdpSession RPC")
)
type cloudflaredV3Server struct {
inbound *Inbound
logger log.ContextLogger
}
func (s *cloudflaredV3Server) RegisterUdpSession(call tunnelrpc.SessionManager_registerUdpSession) error {
result, err := call.Results.NewResult()
if err != nil {
return err
}
if err := result.SetErr(errUnsupportedDatagramV3UDPRegistration.Error()); err != nil {
return err
}
return result.SetSpans([]byte{})
}
func (s *cloudflaredV3Server) UnregisterUdpSession(call tunnelrpc.SessionManager_unregisterUdpSession) error {
return errUnsupportedDatagramV3UDPUnregistration
}
func (s *cloudflaredV3Server) UpdateConfiguration(call tunnelrpc.ConfigurationManager_updateConfiguration) error {
version := call.Params.Version()
configData, _ := call.Params.Config()
updateResult := s.inbound.ApplyConfig(version, configData)
result, err := call.Results.NewResult()
if err != nil {
return err
}
result.SetLatestAppliedVersion(updateResult.LastAppliedVersion)
if updateResult.Err != nil {
result.SetErr(updateResult.Err.Error())
} else {
result.SetErr("")
}
return nil
}
// ServeV3RPCStream serves configuration updates on v3 and rejects legacy UDP RPCs.
func ServeV3RPCStream(ctx context.Context, stream io.ReadWriteCloser, inbound *Inbound, logger log.ContextLogger) {
srv := &cloudflaredV3Server{
inbound: inbound,
logger: logger,
}
client := tunnelrpc.CloudflaredServer_ServerToClient(srv)
transport := rpc.StreamTransport(stream)
rpcConn := rpc.NewConn(transport, rpc.MainInterface(client.Client))
<-rpcConn.Done()
E.Errors(
rpcConn.Close(),
transport.Close(),
)
}

View File

@@ -474,6 +474,9 @@ func (s *cloudflaredServer) RegisterUdpSession(call tunnelrpc.SessionManager_reg
destinationPort := call.Params.DstPort()
closeAfterIdle := time.Duration(call.Params.CloseAfterIdleHint())
if _, traceErr := call.Params.TraceContext(); traceErr != nil {
return traceErr
}
err = s.muxer.RegisterSession(s.ctx, sessionID, net.IP(destinationIP), destinationPort, closeAfterIdle)
@@ -481,6 +484,9 @@ func (s *cloudflaredServer) RegisterUdpSession(call tunnelrpc.SessionManager_reg
if allocErr != nil {
return allocErr
}
if spansErr := result.SetSpans([]byte{}); spansErr != nil {
return spansErr
}
if err != nil {
result.SetErr(err.Error())
}

View File

@@ -34,6 +34,7 @@ const (
var (
loadOriginCABasePool = cloudflareRootCertPool
readOriginCAFile = os.ReadFile
proxyFromEnvironment = http.ProxyFromEnvironment
)
// ConnectResponseWriter abstracts the response writing for both QUIC and HTTP/2.
@@ -42,6 +43,10 @@ type ConnectResponseWriter interface {
WriteResponse(responseError error, metadata []Metadata) error
}
type connectResponseTrailerWriter interface {
AddTrailer(name, value string)
}
// quicResponseWriter writes ConnectResponse in QUIC data stream format (signature + capnp).
type quicResponseWriter struct {
stream io.Writer
@@ -69,8 +74,13 @@ func (i *Inbound) HandleRPCStream(ctx context.Context, stream io.ReadWriteCloser
// HandleRPCStreamWithSender handles an RPC stream with access to the DatagramSender for V2 muxer lookup.
func (i *Inbound) HandleRPCStreamWithSender(ctx context.Context, stream io.ReadWriteCloser, connIndex uint8, sender DatagramSender) {
muxer := i.getOrCreateV2Muxer(sender)
ServeRPCStream(ctx, stream, i, muxer, i.logger)
switch datagramVersionForSender(sender) {
case "v3":
ServeV3RPCStream(ctx, stream, i, i.logger)
default:
muxer := i.getOrCreateV2Muxer(sender)
ServeRPCStream(ctx, stream, i, muxer, i.logger)
}
}
// HandleDatagram handles an incoming QUIC datagram.
@@ -401,6 +411,13 @@ func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser,
if err != nil && !E.IsClosedOrCanceled(err) {
i.logger.DebugContext(ctx, "copy HTTP response body: ", err)
}
if trailerWriter, ok := respWriter.(connectResponseTrailerWriter); ok {
for name, values := range response.Trailer {
for _, value := range values {
trailerWriter.AddTrailer(name, value)
}
}
}
}
func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter.InboundContext, originRequest OriginRequestConfig, requestHost string) (*http.Transport, func(), error) {
@@ -417,7 +434,7 @@ func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter
IdleConnTimeout: originRequest.KeepAliveTimeout,
MaxIdleConns: originRequest.KeepAliveConnections,
MaxIdleConnsPerHost: originRequest.KeepAliveConnections,
Proxy: http.ProxyFromEnvironment,
Proxy: proxyFromEnvironment,
TLSClientConfig: tlsConfig,
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return input, nil
@@ -445,7 +462,7 @@ func (i *Inbound) newDirectOriginTransport(service ResolvedService, requestHost
IdleConnTimeout: service.OriginRequest.KeepAliveTimeout,
MaxIdleConns: service.OriginRequest.KeepAliveConnections,
MaxIdleConnsPerHost: service.OriginRequest.KeepAliveConnections,
Proxy: http.ProxyFromEnvironment,
Proxy: proxyFromEnvironment,
TLSClientConfig: tlsConfig,
}
switch service.Kind {

View File

@@ -131,7 +131,13 @@ func TestNewOriginTLSConfigAppendsCustomCAInsteadOfReplacingBasePool(t *testing.
}
func TestOriginTransportUsesProxyFromEnvironmentOnly(t *testing.T) {
t.Setenv("HTTP_PROXY", "http://proxy.example.com:8080")
originalProxyFromEnvironment := proxyFromEnvironment
proxyFromEnvironment = func(request *http.Request) (*url.URL, error) {
return url.Parse("http://proxy.example.com:8080")
}
defer func() {
proxyFromEnvironment = originalProxyFromEnvironment
}()
inbound := &Inbound{}
transport, cleanup, err := inbound.newDirectOriginTransport(ResolvedService{

View File

@@ -0,0 +1,91 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/sagernet/sing-box/log"
)
type trailerCaptureResponseWriter struct {
status int
trailers http.Header
}
func (w *trailerCaptureResponseWriter) WriteResponse(responseError error, metadata []Metadata) error {
for _, entry := range metadata {
if entry.Key == metadataHTTPStatus {
w.status = http.StatusOK
}
}
return nil
}
func (w *trailerCaptureResponseWriter) AddTrailer(name, value string) {
if w.trailers == nil {
w.trailers = make(http.Header)
}
w.trailers.Add(name, value)
}
type captureReadWriteCloser struct {
body []byte
}
func (c *captureReadWriteCloser) Read(_ []byte) (int, error) {
return 0, io.EOF
}
func (c *captureReadWriteCloser) Write(p []byte) (int, error) {
c.body = append(c.body, p...)
return len(p), nil
}
func (c *captureReadWriteCloser) Close() error {
return nil
}
func TestRoundTripHTTPCopiesTrailers(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Trailer", "X-Test-Trailer")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
w.Header().Set("X-Test-Trailer", "trailer-value")
}))
defer server.Close()
transport, ok := server.Client().Transport.(*http.Transport)
if !ok {
t.Fatalf("unexpected transport type %T", server.Client().Transport)
}
inboundInstance := &Inbound{
logger: log.NewNOPFactory().NewLogger("test"),
}
stream := &captureReadWriteCloser{}
respWriter := &trailerCaptureResponseWriter{}
request := &ConnectRequest{
Dest: server.URL,
Type: ConnectionTypeHTTP,
Metadata: []Metadata{
{Key: metadataHTTPMethod, Val: http.MethodGet},
{Key: metadataHTTPHost, Val: "example.com"},
},
}
inboundInstance.roundTripHTTP(context.Background(), stream, respWriter, request, ResolvedService{
OriginRequest: defaultOriginRequestConfig(),
}, transport)
if got := respWriter.trailers.Get("X-Test-Trailer"); got != "trailer-value" {
t.Fatalf("expected copied trailer, got %q", got)
}
if string(stream.body) != "ok" {
t.Fatalf("unexpected response body %q", stream.body)
}
}