From 316c2559b1a0aec9df70d241062d4ef19e798484 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 25 Mar 2026 18:57:38 +0800 Subject: [PATCH] Fix cloudflared compatibility gaps --- protocol/cloudflare/connection_drain_test.go | 232 ++++++++++++++++ protocol/cloudflare/connection_http2.go | 111 ++++++-- protocol/cloudflare/connection_quic.go | 74 +++-- protocol/cloudflare/control.go | 12 + .../cloudflare/datagram_lifecycle_test.go | 96 +++++++ protocol/cloudflare/datagram_v3.go | 59 +++- protocol/cloudflare/datagram_v3_test.go | 85 ++++++ protocol/cloudflare/icmp.go | 255 +++++++++++++++--- protocol/cloudflare/icmp_test.go | 152 +++++++++++ protocol/cloudflare/ingress_test.go | 46 ++++ protocol/cloudflare/runtime_config.go | 16 +- 11 files changed, 1052 insertions(+), 86 deletions(-) create mode 100644 protocol/cloudflare/connection_drain_test.go diff --git a/protocol/cloudflare/connection_drain_test.go b/protocol/cloudflare/connection_drain_test.go new file mode 100644 index 000000000..0d975a154 --- /dev/null +++ b/protocol/cloudflare/connection_drain_test.go @@ -0,0 +1,232 @@ +//go:build with_cloudflared + +package cloudflare + +import ( + "context" + "errors" + "io" + "net" + "testing" + "time" + + "github.com/google/uuid" + "github.com/sagernet/quic-go" +) + +type stubNetConn struct { + closed chan struct{} +} + +func newStubNetConn() *stubNetConn { + return &stubNetConn{closed: make(chan struct{})} +} + +func (c *stubNetConn) Read(_ []byte) (int, error) { <-c.closed; return 0, io.EOF } +func (c *stubNetConn) Write(b []byte) (int, error) { return len(b), nil } +func (c *stubNetConn) Close() error { closeOnce(c.closed); return nil } +func (c *stubNetConn) LocalAddr() net.Addr { return &net.TCPAddr{} } +func (c *stubNetConn) RemoteAddr() net.Addr { return &net.TCPAddr{} } +func (c *stubNetConn) SetDeadline(time.Time) error { return nil } +func (c *stubNetConn) SetReadDeadline(time.Time) error { return nil } +func (c *stubNetConn) SetWriteDeadline(time.Time) error { return nil } + +type stubQUICConn struct { + closed chan string +} + +func newStubQUICConn() *stubQUICConn { + return &stubQUICConn{closed: make(chan string, 1)} +} + +func (c *stubQUICConn) OpenStream() (*quic.Stream, error) { return nil, errors.New("unused") } +func (c *stubQUICConn) AcceptStream(context.Context) (*quic.Stream, error) { + return nil, errors.New("unused") +} +func (c *stubQUICConn) ReceiveDatagram(context.Context) ([]byte, error) { + return nil, errors.New("unused") +} +func (c *stubQUICConn) SendDatagram([]byte) error { return nil } +func (c *stubQUICConn) LocalAddr() net.Addr { return &net.UDPAddr{} } +func (c *stubQUICConn) CloseWithError(_ quic.ApplicationErrorCode, reason string) error { + select { + case c.closed <- reason: + default: + } + return nil +} + +type mockRegistrationClient struct { + unregisterCalled chan struct{} + closed chan struct{} +} + +func newMockRegistrationClient() *mockRegistrationClient { + return &mockRegistrationClient{ + unregisterCalled: make(chan struct{}, 1), + closed: make(chan struct{}, 1), + } +} + +func (c *mockRegistrationClient) RegisterConnection(context.Context, TunnelAuth, uuid.UUID, uint8, *RegistrationConnectionOptions) (*RegistrationResult, error) { + return &RegistrationResult{}, nil +} + +func (c *mockRegistrationClient) Unregister(context.Context) error { + select { + case c.unregisterCalled <- struct{}{}: + default: + } + return nil +} + +func (c *mockRegistrationClient) Close() error { + select { + case c.closed <- struct{}{}: + default: + } + return nil +} + +func closeOnce(ch chan struct{}) { + select { + case <-ch: + default: + close(ch) + } +} + +func TestHTTP2GracefulShutdownWaitsForActiveRequests(t *testing.T) { + conn := newStubNetConn() + registrationClient := newMockRegistrationClient() + connection := &HTTP2Connection{ + conn: conn, + gracePeriod: 200 * time.Millisecond, + registrationClient: registrationClient, + registrationResult: &RegistrationResult{}, + serveCancel: func() {}, + } + connection.activeRequests.Add(1) + + done := make(chan struct{}) + go func() { + connection.gracefulShutdown() + close(done) + }() + + select { + case <-registrationClient.unregisterCalled: + case <-time.After(time.Second): + t.Fatal("expected unregister call") + } + + select { + case <-conn.closed: + t.Fatal("connection closed before active requests completed") + case <-time.After(50 * time.Millisecond): + } + + connection.activeRequests.Done() + + select { + case <-conn.closed: + case <-time.After(time.Second): + t.Fatal("expected connection close after active requests finished") + } + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("expected graceful shutdown to finish") + } +} + +func TestHTTP2GracefulShutdownTimesOut(t *testing.T) { + conn := newStubNetConn() + registrationClient := newMockRegistrationClient() + connection := &HTTP2Connection{ + conn: conn, + gracePeriod: 50 * time.Millisecond, + registrationClient: registrationClient, + registrationResult: &RegistrationResult{}, + serveCancel: func() {}, + } + connection.activeRequests.Add(1) + + done := make(chan struct{}) + go func() { + connection.gracefulShutdown() + close(done) + }() + + select { + case <-conn.closed: + case <-time.After(500 * time.Millisecond): + t.Fatal("expected connection close after grace timeout") + } + + connection.activeRequests.Done() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("expected graceful shutdown to finish after request completion") + } +} + +func TestQUICGracefulShutdownWaitsForDrainWindow(t *testing.T) { + conn := newStubQUICConn() + registrationClient := newMockRegistrationClient() + serveCancelCalled := make(chan struct{}, 1) + connection := &QUICConnection{ + conn: conn, + gracePeriod: 80 * time.Millisecond, + registrationClient: registrationClient, + registrationResult: &RegistrationResult{}, + serveCancel: func() { + select { + case serveCancelCalled <- struct{}{}: + default: + } + }, + } + + done := make(chan struct{}) + go func() { + connection.gracefulShutdown() + close(done) + }() + + select { + case <-registrationClient.unregisterCalled: + case <-time.After(time.Second): + t.Fatal("expected unregister call") + } + + select { + case <-conn.closed: + t.Fatal("connection closed before grace window elapsed") + case <-time.After(20 * time.Millisecond): + } + + select { + case reason := <-conn.closed: + if reason != "graceful shutdown" { + t.Fatalf("unexpected close reason: %q", reason) + } + case <-time.After(time.Second): + t.Fatal("expected graceful close") + } + + select { + case <-serveCancelCalled: + case <-time.After(time.Second): + t.Fatal("expected serve cancel to be called") + } + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("expected graceful shutdown to finish") + } +} diff --git a/protocol/cloudflare/connection_http2.go b/protocol/cloudflare/connection_http2.go index daa5cfd90..33192806f 100644 --- a/protocol/cloudflare/connection_http2.go +++ b/protocol/cloudflare/connection_http2.go @@ -44,12 +44,15 @@ type HTTP2Connection struct { inbound *Inbound numPreviousAttempts uint8 - registrationClient *RegistrationClient + registrationClient registrationRPCClient registrationResult *RegistrationResult controlStreamErr error - activeRequests sync.WaitGroup - closeOnce sync.Once + activeRequests sync.WaitGroup + serveCancel context.CancelFunc + registrationClose sync.Once + shutdownOnce sync.Once + closeOnce sync.Once } // NewHTTP2Connection dials the edge and establishes an HTTP/2 connection with role reversal. @@ -106,22 +109,28 @@ func NewHTTP2Connection( // Serve runs the HTTP/2 server. Blocks until the context is cancelled or the connection ends. func (c *HTTP2Connection) Serve(ctx context.Context) error { + serveCtx, serveCancel := context.WithCancel(context.WithoutCancel(ctx)) + c.serveCancel = serveCancel + + shutdownDone := make(chan struct{}) go func() { <-ctx.Done() - c.close() + c.gracefulShutdown() + close(shutdownDone) }() c.server.ServeConn(c.conn, &http2.ServeConnOpts{ - Context: ctx, + Context: serveCtx, Handler: c, }) + if ctx.Err() != nil { + <-shutdownDone + return ctx.Err() + } if c.controlStreamErr != nil { return c.controlStreamErr } - if ctx.Err() != nil { - return ctx.Err() - } if c.registrationResult == nil { return E.New("edge connection closed before registration") } @@ -129,12 +138,15 @@ func (c *HTTP2Connection) Serve(ctx context.Context) error { } func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Header.Get(h2HeaderUpgrade) == h2UpgradeControlStream { + c.handleControlStream(r.Context(), r, w) + return + } + c.activeRequests.Add(1) defer c.activeRequests.Done() switch { - case r.Header.Get(h2HeaderUpgrade) == h2UpgradeControlStream: - c.handleControlStream(r.Context(), r, w) case r.Header.Get(h2HeaderUpgrade) == h2UpgradeWebsocket: c.handleH2DataStream(r.Context(), r, w, ConnectionTypeWebsocket) case r.Header.Get(h2HeaderTCPSrc) != "": @@ -169,17 +181,13 @@ func (c *HTTP2Connection) handleControlStream(ctx context.Context, r *http.Reque if err != nil { c.controlStreamErr = err c.logger.Error("register connection: ", err) - if c.registrationClient != nil { - c.registrationClient.Close() - } - go c.close() + go c.forceClose() return } if err := validateRegistrationResult(result); err != nil { c.controlStreamErr = err c.logger.Error("register connection: ", err) - c.registrationClient.Close() - go c.close() + go c.forceClose() return } c.registrationResult = result @@ -189,13 +197,6 @@ func (c *HTTP2Connection) handleControlStream(ctx context.Context, r *http.Reque " (connection ", result.ConnectionID, ")") <-ctx.Done() - unregisterCtx, cancel := context.WithTimeout(context.Background(), c.gracePeriod) - defer cancel() - err = c.registrationClient.Unregister(unregisterCtx) - if err != nil { - c.logger.Debug("failed to unregister: ", err) - } - c.registrationClient.Close() } func (c *HTTP2Connection) handleH2DataStream(ctx context.Context, r *http.Request, w http.ResponseWriter, connectionType ConnectionType) { @@ -280,16 +281,74 @@ func (c *HTTP2Connection) handleConfigurationUpdate(r *http.Request, w http.Resp w.Write([]byte(`{"lastAppliedVersion":` + strconv.FormatInt(int64(result.LastAppliedVersion), 10) + `,"err":null}`)) } -func (c *HTTP2Connection) close() { +func (c *HTTP2Connection) gracefulShutdown() { + c.shutdownOnce.Do(func() { + if c.registrationClient == nil || c.registrationResult == nil { + c.closeNow() + return + } + + unregisterCtx, cancel := context.WithTimeout(context.Background(), c.gracePeriod) + err := c.registrationClient.Unregister(unregisterCtx) + cancel() + if err != nil { + c.logger.Debug("failed to unregister: ", err) + } + c.closeRegistrationClient() + c.waitForActiveRequests(c.gracePeriod) + c.closeNow() + }) +} + +func (c *HTTP2Connection) forceClose() { + c.shutdownOnce.Do(func() { + c.closeNow() + }) +} + +func (c *HTTP2Connection) waitForActiveRequests(timeout time.Duration) { + if timeout <= 0 { + c.activeRequests.Wait() + return + } + + done := make(chan struct{}) + go func() { + c.activeRequests.Wait() + close(done) + }() + + timer := time.NewTimer(timeout) + defer timer.Stop() + + select { + case <-done: + case <-timer.C: + } +} + +func (c *HTTP2Connection) closeRegistrationClient() { + c.registrationClose.Do(func() { + if c.registrationClient != nil { + _ = c.registrationClient.Close() + } + }) +} + +func (c *HTTP2Connection) closeNow() { c.closeOnce.Do(func() { - c.conn.Close() + _ = c.conn.Close() + if c.serveCancel != nil { + c.serveCancel() + } + c.closeRegistrationClient() c.activeRequests.Wait() }) } // Close closes the HTTP/2 connection. func (c *HTTP2Connection) Close() error { - c.close() + c.forceClose() return nil } diff --git a/protocol/cloudflare/connection_quic.go b/protocol/cloudflare/connection_quic.go index f654ee4cb..bb935b4e3 100644 --- a/protocol/cloudflare/connection_quic.go +++ b/protocol/cloudflare/connection_quic.go @@ -48,11 +48,14 @@ type QUICConnection struct { features []string numPreviousAttempts uint8 gracePeriod time.Duration - registrationClient *RegistrationClient + registrationClient registrationRPCClient registrationResult *RegistrationResult onConnected func() - closeOnce sync.Once + serveCancel context.CancelFunc + registrationClose sync.Once + shutdownOnce sync.Once + closeOnce sync.Once } type quicConnection interface { @@ -180,22 +183,29 @@ func (q *QUICConnection) Serve(ctx context.Context, handler StreamHandler) error q.logger.Info("connected to ", q.registrationResult.Location, " (connection ", q.registrationResult.ConnectionID, ")") + serveCtx, serveCancel := context.WithCancel(context.WithoutCancel(ctx)) + q.serveCancel = serveCancel + errChan := make(chan error, 2) go func() { - errChan <- q.acceptStreams(ctx, handler) + errChan <- q.acceptStreams(serveCtx, handler) }() go func() { - errChan <- q.handleDatagrams(ctx, handler) + errChan <- q.handleDatagrams(serveCtx, handler) }() select { case <-ctx.Done(): q.gracefulShutdown() + <-errChan return ctx.Err() case err = <-errChan: - q.gracefulShutdown() + q.forceClose() + if ctx.Err() != nil { + return ctx.Err() + } return err } } @@ -285,23 +295,55 @@ func (q *QUICConnection) OpenRPCStream(ctx context.Context) (io.ReadWriteCloser, } func (q *QUICConnection) gracefulShutdown() { - q.closeOnce.Do(func() { - if q.registrationClient != nil { - ctx, cancel := context.WithTimeout(context.Background(), q.gracePeriod) - defer cancel() - err := q.registrationClient.Unregister(ctx) - if err != nil { - q.logger.Debug("failed to unregister: ", err) - } - q.registrationClient.Close() + q.shutdownOnce.Do(func() { + if q.registrationClient == nil || q.registrationResult == nil { + q.closeNow("connection closed") + return } - q.conn.CloseWithError(0, "graceful shutdown") + + ctx, cancel := context.WithTimeout(context.Background(), q.gracePeriod) + err := q.registrationClient.Unregister(ctx) + cancel() + if err != nil { + q.logger.Debug("failed to unregister: ", err) + } + q.closeRegistrationClient() + if q.gracePeriod > 0 { + timer := time.NewTimer(q.gracePeriod) + <-timer.C + timer.Stop() + } + q.closeNow("graceful shutdown") + }) +} + +func (q *QUICConnection) forceClose() { + q.shutdownOnce.Do(func() { + q.closeNow("connection closed") + }) +} + +func (q *QUICConnection) closeRegistrationClient() { + q.registrationClose.Do(func() { + if q.registrationClient != nil { + _ = q.registrationClient.Close() + } + }) +} + +func (q *QUICConnection) closeNow(reason string) { + q.closeOnce.Do(func() { + if q.serveCancel != nil { + q.serveCancel() + } + q.closeRegistrationClient() + _ = q.conn.CloseWithError(0, reason) }) } // Close closes the QUIC connection immediately. func (q *QUICConnection) Close() error { - q.gracefulShutdown() + q.forceClose() return nil } diff --git a/protocol/cloudflare/control.go b/protocol/cloudflare/control.go index dd8b99da6..e6a0b070f 100644 --- a/protocol/cloudflare/control.go +++ b/protocol/cloudflare/control.go @@ -31,6 +31,18 @@ type RegistrationClient struct { transport rpc.Transport } +type registrationRPCClient interface { + RegisterConnection( + ctx context.Context, + auth TunnelAuth, + tunnelID uuid.UUID, + connIndex uint8, + options *RegistrationConnectionOptions, + ) (*RegistrationResult, error) + Unregister(ctx context.Context) error + Close() error +} + // NewRegistrationClient creates a Cap'n Proto RPC client over the given stream. // The stream should be the first QUIC stream (control stream). func NewRegistrationClient(ctx context.Context, stream io.ReadWriteCloser) *RegistrationClient { diff --git a/protocol/cloudflare/datagram_lifecycle_test.go b/protocol/cloudflare/datagram_lifecycle_test.go index 11a98b8bc..b08e3a7e5 100644 --- a/protocol/cloudflare/datagram_lifecycle_test.go +++ b/protocol/cloudflare/datagram_lifecycle_test.go @@ -5,11 +5,16 @@ package cloudflare import ( "context" "encoding/binary" + "io" "net" "testing" "time" "github.com/google/uuid" + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" ) type v2UnregisterCall struct { @@ -25,6 +30,43 @@ type captureV2SessionRPCClient struct { unregisterCh chan v2UnregisterCall } +type blockingPacketConn struct { + closed chan struct{} +} + +func newBlockingPacketConn() *blockingPacketConn { + return &blockingPacketConn{closed: make(chan struct{})} +} + +func (c *blockingPacketConn) ReadPacket(_ *buf.Buffer) (M.Socksaddr, error) { + <-c.closed + return M.Socksaddr{}, io.EOF +} + +func (c *blockingPacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error { + buffer.Release() + return nil +} + +func (c *blockingPacketConn) Close() error { + closeOnce(c.closed) + return nil +} + +func (c *blockingPacketConn) LocalAddr() net.Addr { return &net.UDPAddr{} } +func (c *blockingPacketConn) SetDeadline(time.Time) error { return nil } +func (c *blockingPacketConn) SetReadDeadline(time.Time) error { return nil } +func (c *blockingPacketConn) SetWriteDeadline(time.Time) error { return nil } + +type packetDialingRouter struct { + testRouter + packetConn N.PacketConn +} + +func (r *packetDialingRouter) DialRoutePacketConnection(ctx context.Context, metadata adapter.InboundContext) (N.PacketConn, error) { + return r.packetConn, nil +} + func (c *captureV2SessionRPCClient) UnregisterSession(ctx context.Context, sessionID uuid.UUID, message string) error { c.unregisterCh <- v2UnregisterCall{sessionID: sessionID, message: message} return nil @@ -105,3 +147,57 @@ func TestDatagramV3RegistrationMigratesSender(t *testing.T) { session.close() } + +func TestDatagramV3MigrationUpdatesSessionContext(t *testing.T) { + packetConn := newBlockingPacketConn() + inboundInstance := newLimitedInbound(t, 0) + inboundInstance.router = &packetDialingRouter{packetConn: packetConn} + sender1 := &captureDatagramSender{} + sender2 := &captureDatagramSender{} + muxer1 := NewDatagramV3Muxer(inboundInstance, sender1, inboundInstance.logger) + muxer2 := NewDatagramV3Muxer(inboundInstance, sender2, inboundInstance.logger) + + requestID := RequestID{} + requestID[15] = 10 + payload := make([]byte, 1+2+2+16+4) + payload[0] = 0 + binary.BigEndian.PutUint16(payload[1:3], 53) + binary.BigEndian.PutUint16(payload[3:5], 30) + copy(payload[5:21], requestID[:]) + copy(payload[21:25], []byte{127, 0, 0, 1}) + + ctx1, cancel1 := context.WithCancel(context.Background()) + muxer1.handleRegistration(ctx1, payload) + + ctx2, cancel2 := context.WithCancel(context.Background()) + muxer2.handleRegistration(ctx2, payload) + + cancel1() + time.Sleep(50 * time.Millisecond) + + session, exists := inboundInstance.datagramV3Manager.Get(requestID) + if !exists { + t.Fatal("expected session to survive old connection context cancellation") + } + + session.senderAccess.RLock() + currentSender := session.sender + session.senderAccess.RUnlock() + if currentSender != sender2 { + t.Fatal("expected migrated sender to stay active") + } + + cancel2() + + deadline := time.After(time.Second) + for { + if _, exists := inboundInstance.datagramV3Manager.Get(requestID); !exists { + return + } + select { + case <-deadline: + t.Fatal("expected session to be removed after new context cancellation") + case <-time.After(10 * time.Millisecond): + } + } +} diff --git a/protocol/cloudflare/datagram_v3.go b/protocol/cloudflare/datagram_v3.go index 436fc5a33..42719c5a1 100644 --- a/protocol/cloudflare/datagram_v3.go +++ b/protocol/cloudflare/datagram_v3.go @@ -37,6 +37,7 @@ const ( v3RegistrationBaseLen = 1 + v3RegistrationFlagLen + v3RegistrationPortLen + v3RegistrationIdleLen + v3RequestIDLength // 22 v3PayloadHeaderLen = 1 + v3RequestIDLength // 17 v3RegistrationRespLen = 1 + 1 + v3RequestIDLength + 2 // 20 + maxV3UDPPayloadLen = 1280 // V3 registration flags v3FlagIPv6 byte = 0x01 @@ -238,6 +239,10 @@ type v3Session struct { senderAccess sync.RWMutex sender DatagramSender + + contextAccess sync.RWMutex + connCtx context.Context + contextChan chan context.Context } var errTooManyActiveFlows = errors.New("too many active flows") @@ -253,11 +258,12 @@ func (m *DatagramV3SessionManager) Register( m.sessionAccess.Lock() if existing, exists := m.sessions[requestID]; exists { if existing.sender == sender { + existing.updateContext(ctx) existing.markActive() m.sessionAccess.Unlock() return existing, v3RegistrationExisting, nil } - existing.setSender(sender) + existing.migrate(sender, ctx) existing.markActive() m.sessionAccess.Unlock() return existing, v3RegistrationMigrated, nil @@ -286,14 +292,17 @@ func (m *DatagramV3SessionManager) Register( closeChan: make(chan struct{}), activeAt: time.Now(), sender: sender, + connCtx: ctx, + contextChan: make(chan context.Context, 1), } m.sessions[requestID] = session m.sessionAccess.Unlock() - sessionCtx := inbound.ctx + sessionCtx := ctx if sessionCtx == nil { sessionCtx = context.Background() } + session.connCtx = sessionCtx go session.serve(sessionCtx, limit) return session, v3RegistrationNew, nil } @@ -320,6 +329,8 @@ func (s *v3Session) serve(ctx context.Context, limit uint64) { go s.readLoop() go s.writeLoop() + connCtx := ctx + tickInterval := s.closeAfterIdle / 2 if tickInterval <= 0 || tickInterval > 10*time.Second { tickInterval = time.Second @@ -329,8 +340,16 @@ func (s *v3Session) serve(ctx context.Context, limit uint64) { for { select { - case <-ctx.Done(): + case <-connCtx.Done(): + if latestCtx := s.currentContext(); latestCtx != nil && latestCtx != connCtx { + connCtx = latestCtx + continue + } s.close() + case newCtx := <-s.contextChan: + if newCtx != nil { + connCtx = newCtx + } case <-ticker.C: if time.Since(s.lastActive()) >= s.closeAfterIdle { s.close() @@ -350,6 +369,11 @@ func (s *v3Session) readLoop() { s.close() return } + if buffer.Len() > maxV3UDPPayloadLen { + s.inbound.logger.Debug("drop oversized V3 UDP payload: ", buffer.Len()) + buffer.Release() + continue + } s.markActive() if err := s.senderDatagram(append([]byte(nil), buffer.Bytes()...)); err != nil { buffer.Release() @@ -403,6 +427,35 @@ func (s *v3Session) setSender(sender DatagramSender) { s.senderAccess.Unlock() } +func (s *v3Session) updateContext(ctx context.Context) { + if ctx == nil { + return + } + s.contextAccess.Lock() + s.connCtx = ctx + s.contextAccess.Unlock() + select { + case s.contextChan <- ctx: + default: + select { + case <-s.contextChan: + default: + } + s.contextChan <- ctx + } +} + +func (s *v3Session) migrate(sender DatagramSender, ctx context.Context) { + s.setSender(sender) + s.updateContext(ctx) +} + +func (s *v3Session) currentContext() context.Context { + s.contextAccess.RLock() + defer s.contextAccess.RUnlock() + return s.connCtx +} + func (s *v3Session) markActive() { s.activeAccess.Lock() s.activeAt = time.Now() diff --git a/protocol/cloudflare/datagram_v3_test.go b/protocol/cloudflare/datagram_v3_test.go index 87f9148dd..22f7cc585 100644 --- a/protocol/cloudflare/datagram_v3_test.go +++ b/protocol/cloudflare/datagram_v3_test.go @@ -5,10 +5,18 @@ package cloudflare import ( "context" "encoding/binary" + "errors" + "io" + "net" + "net/netip" "testing" + "time" "github.com/sagernet/sing-box/adapter/inbound" C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" ) func TestDatagramV3RegistrationDestinationUnreachable(t *testing.T) { @@ -64,3 +72,80 @@ func TestDatagramV3RegistrationErrorWithMessage(t *testing.T) { t.Fatalf("unexpected datagram response: %v", sender.sent[0]) } } + +type scriptedPacketConn struct { + reads [][]byte + index int +} + +func (c *scriptedPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { + if c.index >= len(c.reads) { + return M.Socksaddr{}, io.EOF + } + _, err := buffer.Write(c.reads[c.index]) + c.index++ + return M.Socksaddr{}, err +} + +func (c *scriptedPacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error { + buffer.Release() + return nil +} + +func (c *scriptedPacketConn) Close() error { return nil } +func (c *scriptedPacketConn) LocalAddr() net.Addr { return &net.UDPAddr{} } +func (c *scriptedPacketConn) SetDeadline(time.Time) error { return nil } +func (c *scriptedPacketConn) SetReadDeadline(time.Time) error { return nil } +func (c *scriptedPacketConn) SetWriteDeadline(time.Time) error { return nil } + +type sizeLimitedSender struct { + sent [][]byte + max int +} + +func (s *sizeLimitedSender) SendDatagram(data []byte) error { + if len(data) > s.max { + return errors.New("datagram too large") + } + s.sent = append(s.sent, append([]byte(nil), data...)) + return nil +} + +func TestDatagramV3ReadLoopDropsOversizedOriginPackets(t *testing.T) { + logger := log.NewNOPFactory().NewLogger("test") + sender := &sizeLimitedSender{max: v3PayloadHeaderLen + maxV3UDPPayloadLen} + session := &v3Session{ + id: RequestID{}, + destination: netip.MustParseAddrPort("127.0.0.1:53"), + origin: &scriptedPacketConn{reads: [][]byte{ + make([]byte, maxV3UDPPayloadLen+1), + []byte("ok"), + }}, + inbound: &Inbound{ + logger: logger, + }, + writeChan: make(chan []byte, 1), + closeChan: make(chan struct{}), + contextChan: make(chan context.Context, 1), + sender: sender, + } + + done := make(chan struct{}) + go func() { + session.readLoop() + close(done) + }() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("expected read loop to finish") + } + + if len(sender.sent) != 1 { + t.Fatalf("expected one datagram after dropping oversized payload, got %d", len(sender.sent)) + } + if len(sender.sent[0]) != v3PayloadHeaderLen+2 { + t.Fatalf("unexpected forwarded datagram length: %d", len(sender.sent[0])) + } +} diff --git a/protocol/cloudflare/icmp.go b/protocol/cloudflare/icmp.go index 1070a2d83..088fd4159 100644 --- a/protocol/cloudflare/icmp.go +++ b/protocol/cloudflare/icmp.go @@ -20,6 +20,15 @@ import ( const ( icmpFlowTimeout = 30 * time.Second icmpTraceIdentityLength = 16 + 8 + 1 + defaultICMPPacketTTL = 64 + icmpErrorHeaderLen = 8 + + icmpv4TypeEchoRequest = 8 + icmpv4TypeEchoReply = 0 + icmpv4TypeTimeExceeded = 11 + icmpv6TypeEchoRequest = 128 + icmpv6TypeEchoReply = 129 + icmpv6TypeTimeExceeded = 3 ) type ICMPTraceContext struct { @@ -40,15 +49,18 @@ type ICMPRequestKey struct { } type ICMPPacketInfo struct { - IPVersion uint8 - Protocol uint8 - SourceIP netip.Addr - Destination netip.Addr - ICMPType uint8 - ICMPCode uint8 - Identifier uint16 - Sequence uint16 - RawPacket []byte + IPVersion uint8 + Protocol uint8 + SourceIP netip.Addr + Destination netip.Addr + ICMPType uint8 + ICMPCode uint8 + Identifier uint16 + Sequence uint16 + IPv4HeaderLen int + IPv4TTL uint8 + IPv6HopLimit uint8 + RawPacket []byte } func (i ICMPPacketInfo) FlowKey() ICMPFlowKey { @@ -82,9 +94,9 @@ func (i ICMPPacketInfo) ReplyRequestKey() ICMPRequestKey { func (i ICMPPacketInfo) IsEchoRequest() bool { switch i.IPVersion { case 4: - return i.ICMPType == 8 && i.ICMPCode == 0 + return i.ICMPType == icmpv4TypeEchoRequest && i.ICMPCode == 0 case 6: - return i.ICMPType == 128 && i.ICMPCode == 0 + return i.ICMPType == icmpv6TypeEchoRequest && i.ICMPCode == 0 default: return false } @@ -93,14 +105,47 @@ func (i ICMPPacketInfo) IsEchoRequest() bool { func (i ICMPPacketInfo) IsEchoReply() bool { switch i.IPVersion { case 4: - return i.ICMPType == 0 && i.ICMPCode == 0 + return i.ICMPType == icmpv4TypeEchoReply && i.ICMPCode == 0 case 6: - return i.ICMPType == 129 && i.ICMPCode == 0 + return i.ICMPType == icmpv6TypeEchoReply && i.ICMPCode == 0 default: return false } } +func (i ICMPPacketInfo) TTL() uint8 { + if i.IPVersion == 4 { + return i.IPv4TTL + } + return i.IPv6HopLimit +} + +func (i ICMPPacketInfo) TTLExpired() bool { + return i.TTL() <= 1 +} + +func (i *ICMPPacketInfo) DecrementTTL() error { + switch i.IPVersion { + case 4: + if i.IPv4TTL == 0 || i.IPv4HeaderLen < 20 || len(i.RawPacket) < i.IPv4HeaderLen { + return E.New("invalid IPv4 packet TTL state") + } + i.IPv4TTL-- + i.RawPacket[8] = i.IPv4TTL + binary.BigEndian.PutUint16(i.RawPacket[10:12], 0) + binary.BigEndian.PutUint16(i.RawPacket[10:12], checksum(i.RawPacket[:i.IPv4HeaderLen], 0)) + case 6: + if i.IPv6HopLimit == 0 || len(i.RawPacket) < 40 { + return E.New("invalid IPv6 packet hop limit state") + } + i.IPv6HopLimit-- + i.RawPacket[7] = i.IPv6HopLimit + default: + return E.New("unsupported IP version: ", i.IPVersion) + } + return nil +} + type icmpWireVersion uint8 const ( @@ -154,15 +199,7 @@ func (w *ICMPReplyWriter) WritePacket(packet []byte) error { } w.access.Unlock() - var datagram []byte - switch w.wireVersion { - case icmpWireV2: - datagram, err = encodeV2ICMPDatagram(packetInfo.RawPacket, traceContext) - case icmpWireV3: - datagram = encodeV3ICMPDatagram(packetInfo.RawPacket) - default: - err = E.New("unsupported icmp wire version: ", w.wireVersion) - } + datagram, err := encodeICMPDatagram(packetInfo.RawPacket, w.wireVersion, traceContext) if err != nil { return err } @@ -218,6 +255,21 @@ func (b *ICMPBridge) handlePacket(ctx context.Context, payload []byte, traceCont if !packetInfo.IsEchoRequest() { return nil } + if packetInfo.TTLExpired() { + ttlExceededPacket, err := buildICMPTTLExceededPacket(packetInfo, maxEncodedICMPPacketLen(b.wireVersion, traceContext)) + if err != nil { + return err + } + datagram, err := encodeICMPDatagram(ttlExceededPacket, b.wireVersion, traceContext) + if err != nil { + return err + } + return b.sender.SendDatagram(datagram) + } + + if err := packetInfo.DecrementTTL(); err != nil { + return err + } state := b.getFlowState(packetInfo.FlowKey()) if traceContext.Traced { @@ -294,15 +346,17 @@ func parseIPv4ICMPPacket(packet []byte) (ICMPPacketInfo, error) { return ICMPPacketInfo{}, E.New("invalid IPv4 destination address") } return ICMPPacketInfo{ - IPVersion: 4, - Protocol: 1, - SourceIP: sourceIP, - Destination: destinationIP, - ICMPType: packet[headerLen], - ICMPCode: packet[headerLen+1], - Identifier: binary.BigEndian.Uint16(packet[headerLen+4 : headerLen+6]), - Sequence: binary.BigEndian.Uint16(packet[headerLen+6 : headerLen+8]), - RawPacket: append([]byte(nil), packet...), + IPVersion: 4, + Protocol: 1, + SourceIP: sourceIP, + Destination: destinationIP, + ICMPType: packet[headerLen], + ICMPCode: packet[headerLen+1], + Identifier: binary.BigEndian.Uint16(packet[headerLen+4 : headerLen+6]), + Sequence: binary.BigEndian.Uint16(packet[headerLen+6 : headerLen+8]), + IPv4HeaderLen: headerLen, + IPv4TTL: packet[8], + RawPacket: append([]byte(nil), packet...), }, nil } @@ -322,18 +376,139 @@ func parseIPv6ICMPPacket(packet []byte) (ICMPPacketInfo, error) { return ICMPPacketInfo{}, E.New("invalid IPv6 destination address") } return ICMPPacketInfo{ - IPVersion: 6, - Protocol: 58, - SourceIP: sourceIP, - Destination: destinationIP, - ICMPType: packet[40], - ICMPCode: packet[41], - Identifier: binary.BigEndian.Uint16(packet[44:46]), - Sequence: binary.BigEndian.Uint16(packet[46:48]), - RawPacket: append([]byte(nil), packet...), + IPVersion: 6, + Protocol: 58, + SourceIP: sourceIP, + Destination: destinationIP, + ICMPType: packet[40], + ICMPCode: packet[41], + Identifier: binary.BigEndian.Uint16(packet[44:46]), + Sequence: binary.BigEndian.Uint16(packet[46:48]), + IPv6HopLimit: packet[7], + RawPacket: append([]byte(nil), packet...), }, nil } +func maxEncodedICMPPacketLen(wireVersion icmpWireVersion, traceContext ICMPTraceContext) int { + limit := maxV3UDPPayloadLen + switch wireVersion { + case icmpWireV2: + limit -= typeIDLength + if traceContext.Traced { + limit -= len(traceContext.Identity) + } + case icmpWireV3: + limit -= 1 + default: + return 0 + } + if limit < 0 { + return 0 + } + return limit +} + +func buildICMPTTLExceededPacket(packetInfo ICMPPacketInfo, maxPacketLen int) ([]byte, error) { + switch packetInfo.IPVersion { + case 4: + return buildIPv4ICMPTTLExceededPacket(packetInfo, maxPacketLen) + case 6: + return buildIPv6ICMPTTLExceededPacket(packetInfo, maxPacketLen) + default: + return nil, E.New("unsupported IP version: ", packetInfo.IPVersion) + } +} + +func buildIPv4ICMPTTLExceededPacket(packetInfo ICMPPacketInfo, maxPacketLen int) ([]byte, error) { + const headerLen = 20 + if !packetInfo.SourceIP.Is4() || !packetInfo.Destination.Is4() { + return nil, E.New("TTL exceeded packet requires IPv4 addresses") + } + if maxPacketLen <= headerLen+icmpErrorHeaderLen { + return nil, E.New("TTL exceeded packet size limit is too small") + } + + quotedLength := min(len(packetInfo.RawPacket), maxPacketLen-headerLen-icmpErrorHeaderLen) + packet := make([]byte, headerLen+icmpErrorHeaderLen+quotedLength) + packet[0] = 0x45 + binary.BigEndian.PutUint16(packet[2:4], uint16(len(packet))) + packet[8] = defaultICMPPacketTTL + packet[9] = 1 + copy(packet[12:16], packetInfo.Destination.AsSlice()) + copy(packet[16:20], packetInfo.SourceIP.AsSlice()) + packet[20] = icmpv4TypeTimeExceeded + packet[21] = 0 + copy(packet[headerLen+icmpErrorHeaderLen:], packetInfo.RawPacket[:quotedLength]) + binary.BigEndian.PutUint16(packet[22:24], checksum(packet[20:], 0)) + binary.BigEndian.PutUint16(packet[10:12], checksum(packet[:headerLen], 0)) + return packet, nil +} + +func buildIPv6ICMPTTLExceededPacket(packetInfo ICMPPacketInfo, maxPacketLen int) ([]byte, error) { + const headerLen = 40 + if !packetInfo.SourceIP.Is6() || !packetInfo.Destination.Is6() { + return nil, E.New("TTL exceeded packet requires IPv6 addresses") + } + if maxPacketLen <= headerLen+icmpErrorHeaderLen { + return nil, E.New("TTL exceeded packet size limit is too small") + } + + quotedLength := min(len(packetInfo.RawPacket), maxPacketLen-headerLen-icmpErrorHeaderLen) + packet := make([]byte, headerLen+icmpErrorHeaderLen+quotedLength) + packet[0] = 0x60 + binary.BigEndian.PutUint16(packet[4:6], uint16(icmpErrorHeaderLen+quotedLength)) + packet[6] = 58 + packet[7] = defaultICMPPacketTTL + copy(packet[8:24], packetInfo.Destination.AsSlice()) + copy(packet[24:40], packetInfo.SourceIP.AsSlice()) + packet[40] = icmpv6TypeTimeExceeded + packet[41] = 0 + copy(packet[headerLen+icmpErrorHeaderLen:], packetInfo.RawPacket[:quotedLength]) + binary.BigEndian.PutUint16(packet[42:44], checksum(packet[40:], ipv6PseudoHeaderChecksum(packetInfo.Destination, packetInfo.SourceIP, uint32(icmpErrorHeaderLen+quotedLength), 58))) + return packet, nil +} + +func encodeICMPDatagram(packet []byte, wireVersion icmpWireVersion, traceContext ICMPTraceContext) ([]byte, error) { + switch wireVersion { + case icmpWireV2: + return encodeV2ICMPDatagram(packet, traceContext) + case icmpWireV3: + return encodeV3ICMPDatagram(packet), nil + default: + return nil, E.New("unsupported icmp wire version: ", wireVersion) + } +} + +func ipv6PseudoHeaderChecksum(source, destination netip.Addr, payloadLength uint32, nextHeader uint8) uint32 { + var sum uint32 + sum = checksumSum(source.AsSlice(), sum) + sum = checksumSum(destination.AsSlice(), sum) + var lengthBytes [4]byte + binary.BigEndian.PutUint32(lengthBytes[:], payloadLength) + sum = checksumSum(lengthBytes[:], sum) + sum = checksumSum([]byte{0, 0, 0, nextHeader}, sum) + return sum +} + +func checksumSum(data []byte, sum uint32) uint32 { + for len(data) >= 2 { + sum += uint32(binary.BigEndian.Uint16(data[:2])) + data = data[2:] + } + if len(data) == 1 { + sum += uint32(data[0]) << 8 + } + return sum +} + +func checksum(data []byte, initial uint32) uint16 { + sum := checksumSum(data, initial) + for sum > 0xffff { + sum = (sum >> 16) + (sum & 0xffff) + } + return ^uint16(sum) +} + func encodeV2ICMPDatagram(packet []byte, traceContext ICMPTraceContext) ([]byte, error) { if traceContext.Traced { data := make([]byte, 0, len(packet)+len(traceContext.Identity)+1) diff --git a/protocol/cloudflare/icmp_test.go b/protocol/cloudflare/icmp_test.go index 9557fa16f..aeecf5751 100644 --- a/protocol/cloudflare/icmp_test.go +++ b/protocol/cloudflare/icmp_test.go @@ -169,6 +169,158 @@ func TestICMPBridgeHandleV3Reply(t *testing.T) { } } +func TestICMPBridgeDecrementsIPv4TTLBeforeRouting(t *testing.T) { + var destination *fakeDirectRouteDestination + router := &testRouter{ + preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) { + destination = &fakeDirectRouteDestination{routeContext: routeContext} + return destination, nil + }, + } + inboundInstance := &Inbound{ + Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"), + router: router, + } + bridge := NewICMPBridge(inboundInstance, &captureDatagramSender{}, icmpWireV2) + + packet := buildIPv4ICMPPacket(netip.MustParseAddr("198.18.0.2"), netip.MustParseAddr("1.1.1.1"), icmpv4TypeEchoRequest, 0, 1, 1) + packet[8] = 5 + + if err := bridge.HandleV2(context.Background(), DatagramV2TypeIP, packet); err != nil { + t.Fatal(err) + } + if len(destination.packets) != 1 { + t.Fatalf("expected one routed packet, got %d", len(destination.packets)) + } + if got := destination.packets[0][8]; got != 4 { + t.Fatalf("expected decremented IPv4 TTL, got %d", got) + } +} + +func TestICMPBridgeDecrementsIPv6HopLimitBeforeRouting(t *testing.T) { + var destination *fakeDirectRouteDestination + router := &testRouter{ + preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) { + destination = &fakeDirectRouteDestination{routeContext: routeContext} + return destination, nil + }, + } + inboundInstance := &Inbound{ + Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"), + router: router, + } + bridge := NewICMPBridge(inboundInstance, &captureDatagramSender{}, icmpWireV3) + + packet := buildIPv6ICMPPacket(netip.MustParseAddr("2001:db8::2"), netip.MustParseAddr("2606:4700:4700::1111"), icmpv6TypeEchoRequest, 0, 1, 1) + packet[7] = 3 + + if err := bridge.HandleV3(context.Background(), packet); err != nil { + t.Fatal(err) + } + if len(destination.packets) != 1 { + t.Fatalf("expected one routed packet, got %d", len(destination.packets)) + } + if got := destination.packets[0][7]; got != 2 { + t.Fatalf("expected decremented IPv6 hop limit, got %d", got) + } +} + +func TestICMPBridgeHandleV2TTLExceededTracedReply(t *testing.T) { + var preMatchCalls int + traceIdentity := bytes.Repeat([]byte{0x6b}, icmpTraceIdentityLength) + sender := &captureDatagramSender{} + router := &testRouter{ + preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) { + preMatchCalls++ + return nil, nil + }, + } + inboundInstance := &Inbound{ + Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"), + router: router, + } + bridge := NewICMPBridge(inboundInstance, sender, icmpWireV2) + + source := netip.MustParseAddr("198.18.0.2") + target := netip.MustParseAddr("1.1.1.1") + packet := buildIPv4ICMPPacket(source, target, icmpv4TypeEchoRequest, 0, 1, 1) + packet[8] = 1 + packet = append(packet, traceIdentity...) + + if err := bridge.HandleV2(context.Background(), DatagramV2TypeIPWithTrace, packet); err != nil { + t.Fatal(err) + } + if preMatchCalls != 0 { + t.Fatalf("expected TTL exceeded to bypass routing, got %d route lookups", preMatchCalls) + } + if len(sender.sent) != 1 { + t.Fatalf("expected one TTL exceeded reply, got %d", len(sender.sent)) + } + reply := sender.sent[0] + if reply[len(reply)-1] != byte(DatagramV2TypeIPWithTrace) { + t.Fatalf("expected traced v2 reply, got type %d", reply[len(reply)-1]) + } + gotIdentity := reply[len(reply)-1-icmpTraceIdentityLength : len(reply)-1] + if !bytes.Equal(gotIdentity, traceIdentity) { + t.Fatalf("unexpected trace identity: %x", gotIdentity) + } + rawReply := reply[:len(reply)-1-icmpTraceIdentityLength] + packetInfo, err := ParseICMPPacket(rawReply) + if err != nil { + t.Fatal(err) + } + if packetInfo.ICMPType != icmpv4TypeTimeExceeded || packetInfo.ICMPCode != 0 { + t.Fatalf("expected IPv4 time exceeded reply, got type=%d code=%d", packetInfo.ICMPType, packetInfo.ICMPCode) + } + if packetInfo.SourceIP != target || packetInfo.Destination != source { + t.Fatalf("unexpected TTL exceeded routing: src=%s dst=%s", packetInfo.SourceIP, packetInfo.Destination) + } +} + +func TestICMPBridgeHandleV3TTLExceededReply(t *testing.T) { + var preMatchCalls int + sender := &captureDatagramSender{} + router := &testRouter{ + preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) { + preMatchCalls++ + return nil, nil + }, + } + inboundInstance := &Inbound{ + Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"), + router: router, + } + bridge := NewICMPBridge(inboundInstance, sender, icmpWireV3) + + source := netip.MustParseAddr("2001:db8::2") + target := netip.MustParseAddr("2606:4700:4700::1111") + packet := buildIPv6ICMPPacket(source, target, icmpv6TypeEchoRequest, 0, 1, 1) + packet[7] = 1 + + if err := bridge.HandleV3(context.Background(), packet); err != nil { + t.Fatal(err) + } + if preMatchCalls != 0 { + t.Fatalf("expected TTL exceeded to bypass routing, got %d route lookups", preMatchCalls) + } + if len(sender.sent) != 1 { + t.Fatalf("expected one TTL exceeded reply, got %d", len(sender.sent)) + } + if sender.sent[0][0] != byte(DatagramV3TypeICMP) { + t.Fatalf("expected v3 ICMP reply, got %d", sender.sent[0][0]) + } + packetInfo, err := ParseICMPPacket(sender.sent[0][1:]) + if err != nil { + t.Fatal(err) + } + if packetInfo.ICMPType != icmpv6TypeTimeExceeded || packetInfo.ICMPCode != 0 { + t.Fatalf("expected IPv6 time exceeded reply, got type=%d code=%d", packetInfo.ICMPType, packetInfo.ICMPCode) + } + if packetInfo.SourceIP != target || packetInfo.Destination != source { + t.Fatalf("unexpected TTL exceeded routing: src=%s dst=%s", packetInfo.SourceIP, packetInfo.Destination) + } +} + func TestICMPBridgeDropsNonEcho(t *testing.T) { var preMatchCalls int router := &testRouter{ diff --git a/protocol/cloudflare/ingress_test.go b/protocol/cloudflare/ingress_test.go index 5ff2db3db..61111ab77 100644 --- a/protocol/cloudflare/ingress_test.go +++ b/protocol/cloudflare/ingress_test.go @@ -162,3 +162,49 @@ func TestResolveHTTPServiceStatus(t *testing.T) { t.Fatalf("status service should keep request URL, got %s", requestURL) } } + +func TestParseResolvedServiceCanonicalizesWebSocketOrigin(t *testing.T) { + testCases := []struct { + rawService string + wantScheme string + }{ + {rawService: "ws://127.0.0.1:8080", wantScheme: "http"}, + {rawService: "wss://127.0.0.1:8443", wantScheme: "https"}, + } + + for _, testCase := range testCases { + t.Run(testCase.rawService, func(t *testing.T) { + service, err := parseResolvedService(testCase.rawService, defaultOriginRequestConfig()) + if err != nil { + t.Fatal(err) + } + if service.BaseURL == nil { + t.Fatal("expected base URL") + } + if service.BaseURL.Scheme != testCase.wantScheme { + t.Fatalf("expected scheme %q, got %q", testCase.wantScheme, service.BaseURL.Scheme) + } + if service.Service != testCase.rawService { + t.Fatalf("expected raw service to stay %q, got %q", testCase.rawService, service.Service) + } + }) + } +} + +func TestResolveHTTPServiceWebSocketOrigin(t *testing.T) { + inboundInstance := newTestIngressInbound(t) + inboundInstance.configManager.activeConfig = RuntimeConfig{ + Ingress: []compiledIngressRule{ + {Hostname: "foo.com", Service: mustResolvedService(t, "ws://127.0.0.1:8083")}, + {Service: mustResolvedService(t, "http_status:404")}, + }, + } + + _, requestURL, err := inboundInstance.resolveHTTPService("https://foo.com/path?q=1") + if err != nil { + t.Fatal(err) + } + if requestURL != "http://127.0.0.1:8083/path?q=1" { + t.Fatalf("expected websocket origin to be canonicalized, got %s", requestURL) + } +} diff --git a/protocol/cloudflare/runtime_config.go b/protocol/cloudflare/runtime_config.go index ef8c50495..5b0c8d9ed 100644 --- a/protocol/cloudflare/runtime_config.go +++ b/protocol/cloudflare/runtime_config.go @@ -74,6 +74,20 @@ func (s ResolvedService) BuildRequestURL(requestURL string) (string, error) { } } +func canonicalizeHTTPOriginURL(parsedURL *url.URL) *url.URL { + if parsedURL == nil { + return nil + } + canonicalURL := *parsedURL + switch canonicalURL.Scheme { + case "ws": + canonicalURL.Scheme = "http" + case "wss": + canonicalURL.Scheme = "https" + } + return &canonicalURL +} + type compiledIngressRule struct { Hostname string PunycodeHostname string @@ -451,7 +465,7 @@ func parseResolvedService(rawService string, originRequest OriginRequestConfig) Kind: ResolvedServiceHTTP, Service: rawService, Destination: parseServiceDestination(parsedURL), - BaseURL: parsedURL, + BaseURL: canonicalizeHTTPOriginURL(parsedURL), OriginRequest: originRequest, }, nil case "tcp", "ssh", "rdp", "smb":