Fix cloudflared compatibility gaps

This commit is contained in:
世界
2026-03-25 18:57:38 +08:00
parent 4497f61323
commit 316c2559b1
11 changed files with 1052 additions and 86 deletions

View File

@@ -0,0 +1,232 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"errors"
"io"
"net"
"testing"
"time"
"github.com/google/uuid"
"github.com/sagernet/quic-go"
)
type stubNetConn struct {
closed chan struct{}
}
func newStubNetConn() *stubNetConn {
return &stubNetConn{closed: make(chan struct{})}
}
func (c *stubNetConn) Read(_ []byte) (int, error) { <-c.closed; return 0, io.EOF }
func (c *stubNetConn) Write(b []byte) (int, error) { return len(b), nil }
func (c *stubNetConn) Close() error { closeOnce(c.closed); return nil }
func (c *stubNetConn) LocalAddr() net.Addr { return &net.TCPAddr{} }
func (c *stubNetConn) RemoteAddr() net.Addr { return &net.TCPAddr{} }
func (c *stubNetConn) SetDeadline(time.Time) error { return nil }
func (c *stubNetConn) SetReadDeadline(time.Time) error { return nil }
func (c *stubNetConn) SetWriteDeadline(time.Time) error { return nil }
type stubQUICConn struct {
closed chan string
}
func newStubQUICConn() *stubQUICConn {
return &stubQUICConn{closed: make(chan string, 1)}
}
func (c *stubQUICConn) OpenStream() (*quic.Stream, error) { return nil, errors.New("unused") }
func (c *stubQUICConn) AcceptStream(context.Context) (*quic.Stream, error) {
return nil, errors.New("unused")
}
func (c *stubQUICConn) ReceiveDatagram(context.Context) ([]byte, error) {
return nil, errors.New("unused")
}
func (c *stubQUICConn) SendDatagram([]byte) error { return nil }
func (c *stubQUICConn) LocalAddr() net.Addr { return &net.UDPAddr{} }
func (c *stubQUICConn) CloseWithError(_ quic.ApplicationErrorCode, reason string) error {
select {
case c.closed <- reason:
default:
}
return nil
}
type mockRegistrationClient struct {
unregisterCalled chan struct{}
closed chan struct{}
}
func newMockRegistrationClient() *mockRegistrationClient {
return &mockRegistrationClient{
unregisterCalled: make(chan struct{}, 1),
closed: make(chan struct{}, 1),
}
}
func (c *mockRegistrationClient) RegisterConnection(context.Context, TunnelAuth, uuid.UUID, uint8, *RegistrationConnectionOptions) (*RegistrationResult, error) {
return &RegistrationResult{}, nil
}
func (c *mockRegistrationClient) Unregister(context.Context) error {
select {
case c.unregisterCalled <- struct{}{}:
default:
}
return nil
}
func (c *mockRegistrationClient) Close() error {
select {
case c.closed <- struct{}{}:
default:
}
return nil
}
func closeOnce(ch chan struct{}) {
select {
case <-ch:
default:
close(ch)
}
}
func TestHTTP2GracefulShutdownWaitsForActiveRequests(t *testing.T) {
conn := newStubNetConn()
registrationClient := newMockRegistrationClient()
connection := &HTTP2Connection{
conn: conn,
gracePeriod: 200 * time.Millisecond,
registrationClient: registrationClient,
registrationResult: &RegistrationResult{},
serveCancel: func() {},
}
connection.activeRequests.Add(1)
done := make(chan struct{})
go func() {
connection.gracefulShutdown()
close(done)
}()
select {
case <-registrationClient.unregisterCalled:
case <-time.After(time.Second):
t.Fatal("expected unregister call")
}
select {
case <-conn.closed:
t.Fatal("connection closed before active requests completed")
case <-time.After(50 * time.Millisecond):
}
connection.activeRequests.Done()
select {
case <-conn.closed:
case <-time.After(time.Second):
t.Fatal("expected connection close after active requests finished")
}
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("expected graceful shutdown to finish")
}
}
func TestHTTP2GracefulShutdownTimesOut(t *testing.T) {
conn := newStubNetConn()
registrationClient := newMockRegistrationClient()
connection := &HTTP2Connection{
conn: conn,
gracePeriod: 50 * time.Millisecond,
registrationClient: registrationClient,
registrationResult: &RegistrationResult{},
serveCancel: func() {},
}
connection.activeRequests.Add(1)
done := make(chan struct{})
go func() {
connection.gracefulShutdown()
close(done)
}()
select {
case <-conn.closed:
case <-time.After(500 * time.Millisecond):
t.Fatal("expected connection close after grace timeout")
}
connection.activeRequests.Done()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("expected graceful shutdown to finish after request completion")
}
}
func TestQUICGracefulShutdownWaitsForDrainWindow(t *testing.T) {
conn := newStubQUICConn()
registrationClient := newMockRegistrationClient()
serveCancelCalled := make(chan struct{}, 1)
connection := &QUICConnection{
conn: conn,
gracePeriod: 80 * time.Millisecond,
registrationClient: registrationClient,
registrationResult: &RegistrationResult{},
serveCancel: func() {
select {
case serveCancelCalled <- struct{}{}:
default:
}
},
}
done := make(chan struct{})
go func() {
connection.gracefulShutdown()
close(done)
}()
select {
case <-registrationClient.unregisterCalled:
case <-time.After(time.Second):
t.Fatal("expected unregister call")
}
select {
case <-conn.closed:
t.Fatal("connection closed before grace window elapsed")
case <-time.After(20 * time.Millisecond):
}
select {
case reason := <-conn.closed:
if reason != "graceful shutdown" {
t.Fatalf("unexpected close reason: %q", reason)
}
case <-time.After(time.Second):
t.Fatal("expected graceful close")
}
select {
case <-serveCancelCalled:
case <-time.After(time.Second):
t.Fatal("expected serve cancel to be called")
}
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("expected graceful shutdown to finish")
}
}

View File

@@ -44,12 +44,15 @@ type HTTP2Connection struct {
inbound *Inbound
numPreviousAttempts uint8
registrationClient *RegistrationClient
registrationClient registrationRPCClient
registrationResult *RegistrationResult
controlStreamErr error
activeRequests sync.WaitGroup
closeOnce sync.Once
activeRequests sync.WaitGroup
serveCancel context.CancelFunc
registrationClose sync.Once
shutdownOnce sync.Once
closeOnce sync.Once
}
// NewHTTP2Connection dials the edge and establishes an HTTP/2 connection with role reversal.
@@ -106,22 +109,28 @@ func NewHTTP2Connection(
// Serve runs the HTTP/2 server. Blocks until the context is cancelled or the connection ends.
func (c *HTTP2Connection) Serve(ctx context.Context) error {
serveCtx, serveCancel := context.WithCancel(context.WithoutCancel(ctx))
c.serveCancel = serveCancel
shutdownDone := make(chan struct{})
go func() {
<-ctx.Done()
c.close()
c.gracefulShutdown()
close(shutdownDone)
}()
c.server.ServeConn(c.conn, &http2.ServeConnOpts{
Context: ctx,
Context: serveCtx,
Handler: c,
})
if ctx.Err() != nil {
<-shutdownDone
return ctx.Err()
}
if c.controlStreamErr != nil {
return c.controlStreamErr
}
if ctx.Err() != nil {
return ctx.Err()
}
if c.registrationResult == nil {
return E.New("edge connection closed before registration")
}
@@ -129,12 +138,15 @@ func (c *HTTP2Connection) Serve(ctx context.Context) error {
}
func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Header.Get(h2HeaderUpgrade) == h2UpgradeControlStream {
c.handleControlStream(r.Context(), r, w)
return
}
c.activeRequests.Add(1)
defer c.activeRequests.Done()
switch {
case r.Header.Get(h2HeaderUpgrade) == h2UpgradeControlStream:
c.handleControlStream(r.Context(), r, w)
case r.Header.Get(h2HeaderUpgrade) == h2UpgradeWebsocket:
c.handleH2DataStream(r.Context(), r, w, ConnectionTypeWebsocket)
case r.Header.Get(h2HeaderTCPSrc) != "":
@@ -169,17 +181,13 @@ func (c *HTTP2Connection) handleControlStream(ctx context.Context, r *http.Reque
if err != nil {
c.controlStreamErr = err
c.logger.Error("register connection: ", err)
if c.registrationClient != nil {
c.registrationClient.Close()
}
go c.close()
go c.forceClose()
return
}
if err := validateRegistrationResult(result); err != nil {
c.controlStreamErr = err
c.logger.Error("register connection: ", err)
c.registrationClient.Close()
go c.close()
go c.forceClose()
return
}
c.registrationResult = result
@@ -189,13 +197,6 @@ func (c *HTTP2Connection) handleControlStream(ctx context.Context, r *http.Reque
" (connection ", result.ConnectionID, ")")
<-ctx.Done()
unregisterCtx, cancel := context.WithTimeout(context.Background(), c.gracePeriod)
defer cancel()
err = c.registrationClient.Unregister(unregisterCtx)
if err != nil {
c.logger.Debug("failed to unregister: ", err)
}
c.registrationClient.Close()
}
func (c *HTTP2Connection) handleH2DataStream(ctx context.Context, r *http.Request, w http.ResponseWriter, connectionType ConnectionType) {
@@ -280,16 +281,74 @@ func (c *HTTP2Connection) handleConfigurationUpdate(r *http.Request, w http.Resp
w.Write([]byte(`{"lastAppliedVersion":` + strconv.FormatInt(int64(result.LastAppliedVersion), 10) + `,"err":null}`))
}
func (c *HTTP2Connection) close() {
func (c *HTTP2Connection) gracefulShutdown() {
c.shutdownOnce.Do(func() {
if c.registrationClient == nil || c.registrationResult == nil {
c.closeNow()
return
}
unregisterCtx, cancel := context.WithTimeout(context.Background(), c.gracePeriod)
err := c.registrationClient.Unregister(unregisterCtx)
cancel()
if err != nil {
c.logger.Debug("failed to unregister: ", err)
}
c.closeRegistrationClient()
c.waitForActiveRequests(c.gracePeriod)
c.closeNow()
})
}
func (c *HTTP2Connection) forceClose() {
c.shutdownOnce.Do(func() {
c.closeNow()
})
}
func (c *HTTP2Connection) waitForActiveRequests(timeout time.Duration) {
if timeout <= 0 {
c.activeRequests.Wait()
return
}
done := make(chan struct{})
go func() {
c.activeRequests.Wait()
close(done)
}()
timer := time.NewTimer(timeout)
defer timer.Stop()
select {
case <-done:
case <-timer.C:
}
}
func (c *HTTP2Connection) closeRegistrationClient() {
c.registrationClose.Do(func() {
if c.registrationClient != nil {
_ = c.registrationClient.Close()
}
})
}
func (c *HTTP2Connection) closeNow() {
c.closeOnce.Do(func() {
c.conn.Close()
_ = c.conn.Close()
if c.serveCancel != nil {
c.serveCancel()
}
c.closeRegistrationClient()
c.activeRequests.Wait()
})
}
// Close closes the HTTP/2 connection.
func (c *HTTP2Connection) Close() error {
c.close()
c.forceClose()
return nil
}

View File

@@ -48,11 +48,14 @@ type QUICConnection struct {
features []string
numPreviousAttempts uint8
gracePeriod time.Duration
registrationClient *RegistrationClient
registrationClient registrationRPCClient
registrationResult *RegistrationResult
onConnected func()
closeOnce sync.Once
serveCancel context.CancelFunc
registrationClose sync.Once
shutdownOnce sync.Once
closeOnce sync.Once
}
type quicConnection interface {
@@ -180,22 +183,29 @@ func (q *QUICConnection) Serve(ctx context.Context, handler StreamHandler) error
q.logger.Info("connected to ", q.registrationResult.Location,
" (connection ", q.registrationResult.ConnectionID, ")")
serveCtx, serveCancel := context.WithCancel(context.WithoutCancel(ctx))
q.serveCancel = serveCancel
errChan := make(chan error, 2)
go func() {
errChan <- q.acceptStreams(ctx, handler)
errChan <- q.acceptStreams(serveCtx, handler)
}()
go func() {
errChan <- q.handleDatagrams(ctx, handler)
errChan <- q.handleDatagrams(serveCtx, handler)
}()
select {
case <-ctx.Done():
q.gracefulShutdown()
<-errChan
return ctx.Err()
case err = <-errChan:
q.gracefulShutdown()
q.forceClose()
if ctx.Err() != nil {
return ctx.Err()
}
return err
}
}
@@ -285,23 +295,55 @@ func (q *QUICConnection) OpenRPCStream(ctx context.Context) (io.ReadWriteCloser,
}
func (q *QUICConnection) gracefulShutdown() {
q.closeOnce.Do(func() {
if q.registrationClient != nil {
ctx, cancel := context.WithTimeout(context.Background(), q.gracePeriod)
defer cancel()
err := q.registrationClient.Unregister(ctx)
if err != nil {
q.logger.Debug("failed to unregister: ", err)
}
q.registrationClient.Close()
q.shutdownOnce.Do(func() {
if q.registrationClient == nil || q.registrationResult == nil {
q.closeNow("connection closed")
return
}
q.conn.CloseWithError(0, "graceful shutdown")
ctx, cancel := context.WithTimeout(context.Background(), q.gracePeriod)
err := q.registrationClient.Unregister(ctx)
cancel()
if err != nil {
q.logger.Debug("failed to unregister: ", err)
}
q.closeRegistrationClient()
if q.gracePeriod > 0 {
timer := time.NewTimer(q.gracePeriod)
<-timer.C
timer.Stop()
}
q.closeNow("graceful shutdown")
})
}
func (q *QUICConnection) forceClose() {
q.shutdownOnce.Do(func() {
q.closeNow("connection closed")
})
}
func (q *QUICConnection) closeRegistrationClient() {
q.registrationClose.Do(func() {
if q.registrationClient != nil {
_ = q.registrationClient.Close()
}
})
}
func (q *QUICConnection) closeNow(reason string) {
q.closeOnce.Do(func() {
if q.serveCancel != nil {
q.serveCancel()
}
q.closeRegistrationClient()
_ = q.conn.CloseWithError(0, reason)
})
}
// Close closes the QUIC connection immediately.
func (q *QUICConnection) Close() error {
q.gracefulShutdown()
q.forceClose()
return nil
}

View File

@@ -31,6 +31,18 @@ type RegistrationClient struct {
transport rpc.Transport
}
type registrationRPCClient interface {
RegisterConnection(
ctx context.Context,
auth TunnelAuth,
tunnelID uuid.UUID,
connIndex uint8,
options *RegistrationConnectionOptions,
) (*RegistrationResult, error)
Unregister(ctx context.Context) error
Close() error
}
// NewRegistrationClient creates a Cap'n Proto RPC client over the given stream.
// The stream should be the first QUIC stream (control stream).
func NewRegistrationClient(ctx context.Context, stream io.ReadWriteCloser) *RegistrationClient {

View File

@@ -5,11 +5,16 @@ package cloudflare
import (
"context"
"encoding/binary"
"io"
"net"
"testing"
"time"
"github.com/google/uuid"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type v2UnregisterCall struct {
@@ -25,6 +30,43 @@ type captureV2SessionRPCClient struct {
unregisterCh chan v2UnregisterCall
}
type blockingPacketConn struct {
closed chan struct{}
}
func newBlockingPacketConn() *blockingPacketConn {
return &blockingPacketConn{closed: make(chan struct{})}
}
func (c *blockingPacketConn) ReadPacket(_ *buf.Buffer) (M.Socksaddr, error) {
<-c.closed
return M.Socksaddr{}, io.EOF
}
func (c *blockingPacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error {
buffer.Release()
return nil
}
func (c *blockingPacketConn) Close() error {
closeOnce(c.closed)
return nil
}
func (c *blockingPacketConn) LocalAddr() net.Addr { return &net.UDPAddr{} }
func (c *blockingPacketConn) SetDeadline(time.Time) error { return nil }
func (c *blockingPacketConn) SetReadDeadline(time.Time) error { return nil }
func (c *blockingPacketConn) SetWriteDeadline(time.Time) error { return nil }
type packetDialingRouter struct {
testRouter
packetConn N.PacketConn
}
func (r *packetDialingRouter) DialRoutePacketConnection(ctx context.Context, metadata adapter.InboundContext) (N.PacketConn, error) {
return r.packetConn, nil
}
func (c *captureV2SessionRPCClient) UnregisterSession(ctx context.Context, sessionID uuid.UUID, message string) error {
c.unregisterCh <- v2UnregisterCall{sessionID: sessionID, message: message}
return nil
@@ -105,3 +147,57 @@ func TestDatagramV3RegistrationMigratesSender(t *testing.T) {
session.close()
}
func TestDatagramV3MigrationUpdatesSessionContext(t *testing.T) {
packetConn := newBlockingPacketConn()
inboundInstance := newLimitedInbound(t, 0)
inboundInstance.router = &packetDialingRouter{packetConn: packetConn}
sender1 := &captureDatagramSender{}
sender2 := &captureDatagramSender{}
muxer1 := NewDatagramV3Muxer(inboundInstance, sender1, inboundInstance.logger)
muxer2 := NewDatagramV3Muxer(inboundInstance, sender2, inboundInstance.logger)
requestID := RequestID{}
requestID[15] = 10
payload := make([]byte, 1+2+2+16+4)
payload[0] = 0
binary.BigEndian.PutUint16(payload[1:3], 53)
binary.BigEndian.PutUint16(payload[3:5], 30)
copy(payload[5:21], requestID[:])
copy(payload[21:25], []byte{127, 0, 0, 1})
ctx1, cancel1 := context.WithCancel(context.Background())
muxer1.handleRegistration(ctx1, payload)
ctx2, cancel2 := context.WithCancel(context.Background())
muxer2.handleRegistration(ctx2, payload)
cancel1()
time.Sleep(50 * time.Millisecond)
session, exists := inboundInstance.datagramV3Manager.Get(requestID)
if !exists {
t.Fatal("expected session to survive old connection context cancellation")
}
session.senderAccess.RLock()
currentSender := session.sender
session.senderAccess.RUnlock()
if currentSender != sender2 {
t.Fatal("expected migrated sender to stay active")
}
cancel2()
deadline := time.After(time.Second)
for {
if _, exists := inboundInstance.datagramV3Manager.Get(requestID); !exists {
return
}
select {
case <-deadline:
t.Fatal("expected session to be removed after new context cancellation")
case <-time.After(10 * time.Millisecond):
}
}
}

View File

@@ -37,6 +37,7 @@ const (
v3RegistrationBaseLen = 1 + v3RegistrationFlagLen + v3RegistrationPortLen + v3RegistrationIdleLen + v3RequestIDLength // 22
v3PayloadHeaderLen = 1 + v3RequestIDLength // 17
v3RegistrationRespLen = 1 + 1 + v3RequestIDLength + 2 // 20
maxV3UDPPayloadLen = 1280
// V3 registration flags
v3FlagIPv6 byte = 0x01
@@ -238,6 +239,10 @@ type v3Session struct {
senderAccess sync.RWMutex
sender DatagramSender
contextAccess sync.RWMutex
connCtx context.Context
contextChan chan context.Context
}
var errTooManyActiveFlows = errors.New("too many active flows")
@@ -253,11 +258,12 @@ func (m *DatagramV3SessionManager) Register(
m.sessionAccess.Lock()
if existing, exists := m.sessions[requestID]; exists {
if existing.sender == sender {
existing.updateContext(ctx)
existing.markActive()
m.sessionAccess.Unlock()
return existing, v3RegistrationExisting, nil
}
existing.setSender(sender)
existing.migrate(sender, ctx)
existing.markActive()
m.sessionAccess.Unlock()
return existing, v3RegistrationMigrated, nil
@@ -286,14 +292,17 @@ func (m *DatagramV3SessionManager) Register(
closeChan: make(chan struct{}),
activeAt: time.Now(),
sender: sender,
connCtx: ctx,
contextChan: make(chan context.Context, 1),
}
m.sessions[requestID] = session
m.sessionAccess.Unlock()
sessionCtx := inbound.ctx
sessionCtx := ctx
if sessionCtx == nil {
sessionCtx = context.Background()
}
session.connCtx = sessionCtx
go session.serve(sessionCtx, limit)
return session, v3RegistrationNew, nil
}
@@ -320,6 +329,8 @@ func (s *v3Session) serve(ctx context.Context, limit uint64) {
go s.readLoop()
go s.writeLoop()
connCtx := ctx
tickInterval := s.closeAfterIdle / 2
if tickInterval <= 0 || tickInterval > 10*time.Second {
tickInterval = time.Second
@@ -329,8 +340,16 @@ func (s *v3Session) serve(ctx context.Context, limit uint64) {
for {
select {
case <-ctx.Done():
case <-connCtx.Done():
if latestCtx := s.currentContext(); latestCtx != nil && latestCtx != connCtx {
connCtx = latestCtx
continue
}
s.close()
case newCtx := <-s.contextChan:
if newCtx != nil {
connCtx = newCtx
}
case <-ticker.C:
if time.Since(s.lastActive()) >= s.closeAfterIdle {
s.close()
@@ -350,6 +369,11 @@ func (s *v3Session) readLoop() {
s.close()
return
}
if buffer.Len() > maxV3UDPPayloadLen {
s.inbound.logger.Debug("drop oversized V3 UDP payload: ", buffer.Len())
buffer.Release()
continue
}
s.markActive()
if err := s.senderDatagram(append([]byte(nil), buffer.Bytes()...)); err != nil {
buffer.Release()
@@ -403,6 +427,35 @@ func (s *v3Session) setSender(sender DatagramSender) {
s.senderAccess.Unlock()
}
func (s *v3Session) updateContext(ctx context.Context) {
if ctx == nil {
return
}
s.contextAccess.Lock()
s.connCtx = ctx
s.contextAccess.Unlock()
select {
case s.contextChan <- ctx:
default:
select {
case <-s.contextChan:
default:
}
s.contextChan <- ctx
}
}
func (s *v3Session) migrate(sender DatagramSender, ctx context.Context) {
s.setSender(sender)
s.updateContext(ctx)
}
func (s *v3Session) currentContext() context.Context {
s.contextAccess.RLock()
defer s.contextAccess.RUnlock()
return s.connCtx
}
func (s *v3Session) markActive() {
s.activeAccess.Lock()
s.activeAt = time.Now()

View File

@@ -5,10 +5,18 @@ package cloudflare
import (
"context"
"encoding/binary"
"errors"
"io"
"net"
"net/netip"
"testing"
"time"
"github.com/sagernet/sing-box/adapter/inbound"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
)
func TestDatagramV3RegistrationDestinationUnreachable(t *testing.T) {
@@ -64,3 +72,80 @@ func TestDatagramV3RegistrationErrorWithMessage(t *testing.T) {
t.Fatalf("unexpected datagram response: %v", sender.sent[0])
}
}
type scriptedPacketConn struct {
reads [][]byte
index int
}
func (c *scriptedPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
if c.index >= len(c.reads) {
return M.Socksaddr{}, io.EOF
}
_, err := buffer.Write(c.reads[c.index])
c.index++
return M.Socksaddr{}, err
}
func (c *scriptedPacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error {
buffer.Release()
return nil
}
func (c *scriptedPacketConn) Close() error { return nil }
func (c *scriptedPacketConn) LocalAddr() net.Addr { return &net.UDPAddr{} }
func (c *scriptedPacketConn) SetDeadline(time.Time) error { return nil }
func (c *scriptedPacketConn) SetReadDeadline(time.Time) error { return nil }
func (c *scriptedPacketConn) SetWriteDeadline(time.Time) error { return nil }
type sizeLimitedSender struct {
sent [][]byte
max int
}
func (s *sizeLimitedSender) SendDatagram(data []byte) error {
if len(data) > s.max {
return errors.New("datagram too large")
}
s.sent = append(s.sent, append([]byte(nil), data...))
return nil
}
func TestDatagramV3ReadLoopDropsOversizedOriginPackets(t *testing.T) {
logger := log.NewNOPFactory().NewLogger("test")
sender := &sizeLimitedSender{max: v3PayloadHeaderLen + maxV3UDPPayloadLen}
session := &v3Session{
id: RequestID{},
destination: netip.MustParseAddrPort("127.0.0.1:53"),
origin: &scriptedPacketConn{reads: [][]byte{
make([]byte, maxV3UDPPayloadLen+1),
[]byte("ok"),
}},
inbound: &Inbound{
logger: logger,
},
writeChan: make(chan []byte, 1),
closeChan: make(chan struct{}),
contextChan: make(chan context.Context, 1),
sender: sender,
}
done := make(chan struct{})
go func() {
session.readLoop()
close(done)
}()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("expected read loop to finish")
}
if len(sender.sent) != 1 {
t.Fatalf("expected one datagram after dropping oversized payload, got %d", len(sender.sent))
}
if len(sender.sent[0]) != v3PayloadHeaderLen+2 {
t.Fatalf("unexpected forwarded datagram length: %d", len(sender.sent[0]))
}
}

View File

@@ -20,6 +20,15 @@ import (
const (
icmpFlowTimeout = 30 * time.Second
icmpTraceIdentityLength = 16 + 8 + 1
defaultICMPPacketTTL = 64
icmpErrorHeaderLen = 8
icmpv4TypeEchoRequest = 8
icmpv4TypeEchoReply = 0
icmpv4TypeTimeExceeded = 11
icmpv6TypeEchoRequest = 128
icmpv6TypeEchoReply = 129
icmpv6TypeTimeExceeded = 3
)
type ICMPTraceContext struct {
@@ -40,15 +49,18 @@ type ICMPRequestKey struct {
}
type ICMPPacketInfo struct {
IPVersion uint8
Protocol uint8
SourceIP netip.Addr
Destination netip.Addr
ICMPType uint8
ICMPCode uint8
Identifier uint16
Sequence uint16
RawPacket []byte
IPVersion uint8
Protocol uint8
SourceIP netip.Addr
Destination netip.Addr
ICMPType uint8
ICMPCode uint8
Identifier uint16
Sequence uint16
IPv4HeaderLen int
IPv4TTL uint8
IPv6HopLimit uint8
RawPacket []byte
}
func (i ICMPPacketInfo) FlowKey() ICMPFlowKey {
@@ -82,9 +94,9 @@ func (i ICMPPacketInfo) ReplyRequestKey() ICMPRequestKey {
func (i ICMPPacketInfo) IsEchoRequest() bool {
switch i.IPVersion {
case 4:
return i.ICMPType == 8 && i.ICMPCode == 0
return i.ICMPType == icmpv4TypeEchoRequest && i.ICMPCode == 0
case 6:
return i.ICMPType == 128 && i.ICMPCode == 0
return i.ICMPType == icmpv6TypeEchoRequest && i.ICMPCode == 0
default:
return false
}
@@ -93,14 +105,47 @@ func (i ICMPPacketInfo) IsEchoRequest() bool {
func (i ICMPPacketInfo) IsEchoReply() bool {
switch i.IPVersion {
case 4:
return i.ICMPType == 0 && i.ICMPCode == 0
return i.ICMPType == icmpv4TypeEchoReply && i.ICMPCode == 0
case 6:
return i.ICMPType == 129 && i.ICMPCode == 0
return i.ICMPType == icmpv6TypeEchoReply && i.ICMPCode == 0
default:
return false
}
}
func (i ICMPPacketInfo) TTL() uint8 {
if i.IPVersion == 4 {
return i.IPv4TTL
}
return i.IPv6HopLimit
}
func (i ICMPPacketInfo) TTLExpired() bool {
return i.TTL() <= 1
}
func (i *ICMPPacketInfo) DecrementTTL() error {
switch i.IPVersion {
case 4:
if i.IPv4TTL == 0 || i.IPv4HeaderLen < 20 || len(i.RawPacket) < i.IPv4HeaderLen {
return E.New("invalid IPv4 packet TTL state")
}
i.IPv4TTL--
i.RawPacket[8] = i.IPv4TTL
binary.BigEndian.PutUint16(i.RawPacket[10:12], 0)
binary.BigEndian.PutUint16(i.RawPacket[10:12], checksum(i.RawPacket[:i.IPv4HeaderLen], 0))
case 6:
if i.IPv6HopLimit == 0 || len(i.RawPacket) < 40 {
return E.New("invalid IPv6 packet hop limit state")
}
i.IPv6HopLimit--
i.RawPacket[7] = i.IPv6HopLimit
default:
return E.New("unsupported IP version: ", i.IPVersion)
}
return nil
}
type icmpWireVersion uint8
const (
@@ -154,15 +199,7 @@ func (w *ICMPReplyWriter) WritePacket(packet []byte) error {
}
w.access.Unlock()
var datagram []byte
switch w.wireVersion {
case icmpWireV2:
datagram, err = encodeV2ICMPDatagram(packetInfo.RawPacket, traceContext)
case icmpWireV3:
datagram = encodeV3ICMPDatagram(packetInfo.RawPacket)
default:
err = E.New("unsupported icmp wire version: ", w.wireVersion)
}
datagram, err := encodeICMPDatagram(packetInfo.RawPacket, w.wireVersion, traceContext)
if err != nil {
return err
}
@@ -218,6 +255,21 @@ func (b *ICMPBridge) handlePacket(ctx context.Context, payload []byte, traceCont
if !packetInfo.IsEchoRequest() {
return nil
}
if packetInfo.TTLExpired() {
ttlExceededPacket, err := buildICMPTTLExceededPacket(packetInfo, maxEncodedICMPPacketLen(b.wireVersion, traceContext))
if err != nil {
return err
}
datagram, err := encodeICMPDatagram(ttlExceededPacket, b.wireVersion, traceContext)
if err != nil {
return err
}
return b.sender.SendDatagram(datagram)
}
if err := packetInfo.DecrementTTL(); err != nil {
return err
}
state := b.getFlowState(packetInfo.FlowKey())
if traceContext.Traced {
@@ -294,15 +346,17 @@ func parseIPv4ICMPPacket(packet []byte) (ICMPPacketInfo, error) {
return ICMPPacketInfo{}, E.New("invalid IPv4 destination address")
}
return ICMPPacketInfo{
IPVersion: 4,
Protocol: 1,
SourceIP: sourceIP,
Destination: destinationIP,
ICMPType: packet[headerLen],
ICMPCode: packet[headerLen+1],
Identifier: binary.BigEndian.Uint16(packet[headerLen+4 : headerLen+6]),
Sequence: binary.BigEndian.Uint16(packet[headerLen+6 : headerLen+8]),
RawPacket: append([]byte(nil), packet...),
IPVersion: 4,
Protocol: 1,
SourceIP: sourceIP,
Destination: destinationIP,
ICMPType: packet[headerLen],
ICMPCode: packet[headerLen+1],
Identifier: binary.BigEndian.Uint16(packet[headerLen+4 : headerLen+6]),
Sequence: binary.BigEndian.Uint16(packet[headerLen+6 : headerLen+8]),
IPv4HeaderLen: headerLen,
IPv4TTL: packet[8],
RawPacket: append([]byte(nil), packet...),
}, nil
}
@@ -322,18 +376,139 @@ func parseIPv6ICMPPacket(packet []byte) (ICMPPacketInfo, error) {
return ICMPPacketInfo{}, E.New("invalid IPv6 destination address")
}
return ICMPPacketInfo{
IPVersion: 6,
Protocol: 58,
SourceIP: sourceIP,
Destination: destinationIP,
ICMPType: packet[40],
ICMPCode: packet[41],
Identifier: binary.BigEndian.Uint16(packet[44:46]),
Sequence: binary.BigEndian.Uint16(packet[46:48]),
RawPacket: append([]byte(nil), packet...),
IPVersion: 6,
Protocol: 58,
SourceIP: sourceIP,
Destination: destinationIP,
ICMPType: packet[40],
ICMPCode: packet[41],
Identifier: binary.BigEndian.Uint16(packet[44:46]),
Sequence: binary.BigEndian.Uint16(packet[46:48]),
IPv6HopLimit: packet[7],
RawPacket: append([]byte(nil), packet...),
}, nil
}
func maxEncodedICMPPacketLen(wireVersion icmpWireVersion, traceContext ICMPTraceContext) int {
limit := maxV3UDPPayloadLen
switch wireVersion {
case icmpWireV2:
limit -= typeIDLength
if traceContext.Traced {
limit -= len(traceContext.Identity)
}
case icmpWireV3:
limit -= 1
default:
return 0
}
if limit < 0 {
return 0
}
return limit
}
func buildICMPTTLExceededPacket(packetInfo ICMPPacketInfo, maxPacketLen int) ([]byte, error) {
switch packetInfo.IPVersion {
case 4:
return buildIPv4ICMPTTLExceededPacket(packetInfo, maxPacketLen)
case 6:
return buildIPv6ICMPTTLExceededPacket(packetInfo, maxPacketLen)
default:
return nil, E.New("unsupported IP version: ", packetInfo.IPVersion)
}
}
func buildIPv4ICMPTTLExceededPacket(packetInfo ICMPPacketInfo, maxPacketLen int) ([]byte, error) {
const headerLen = 20
if !packetInfo.SourceIP.Is4() || !packetInfo.Destination.Is4() {
return nil, E.New("TTL exceeded packet requires IPv4 addresses")
}
if maxPacketLen <= headerLen+icmpErrorHeaderLen {
return nil, E.New("TTL exceeded packet size limit is too small")
}
quotedLength := min(len(packetInfo.RawPacket), maxPacketLen-headerLen-icmpErrorHeaderLen)
packet := make([]byte, headerLen+icmpErrorHeaderLen+quotedLength)
packet[0] = 0x45
binary.BigEndian.PutUint16(packet[2:4], uint16(len(packet)))
packet[8] = defaultICMPPacketTTL
packet[9] = 1
copy(packet[12:16], packetInfo.Destination.AsSlice())
copy(packet[16:20], packetInfo.SourceIP.AsSlice())
packet[20] = icmpv4TypeTimeExceeded
packet[21] = 0
copy(packet[headerLen+icmpErrorHeaderLen:], packetInfo.RawPacket[:quotedLength])
binary.BigEndian.PutUint16(packet[22:24], checksum(packet[20:], 0))
binary.BigEndian.PutUint16(packet[10:12], checksum(packet[:headerLen], 0))
return packet, nil
}
func buildIPv6ICMPTTLExceededPacket(packetInfo ICMPPacketInfo, maxPacketLen int) ([]byte, error) {
const headerLen = 40
if !packetInfo.SourceIP.Is6() || !packetInfo.Destination.Is6() {
return nil, E.New("TTL exceeded packet requires IPv6 addresses")
}
if maxPacketLen <= headerLen+icmpErrorHeaderLen {
return nil, E.New("TTL exceeded packet size limit is too small")
}
quotedLength := min(len(packetInfo.RawPacket), maxPacketLen-headerLen-icmpErrorHeaderLen)
packet := make([]byte, headerLen+icmpErrorHeaderLen+quotedLength)
packet[0] = 0x60
binary.BigEndian.PutUint16(packet[4:6], uint16(icmpErrorHeaderLen+quotedLength))
packet[6] = 58
packet[7] = defaultICMPPacketTTL
copy(packet[8:24], packetInfo.Destination.AsSlice())
copy(packet[24:40], packetInfo.SourceIP.AsSlice())
packet[40] = icmpv6TypeTimeExceeded
packet[41] = 0
copy(packet[headerLen+icmpErrorHeaderLen:], packetInfo.RawPacket[:quotedLength])
binary.BigEndian.PutUint16(packet[42:44], checksum(packet[40:], ipv6PseudoHeaderChecksum(packetInfo.Destination, packetInfo.SourceIP, uint32(icmpErrorHeaderLen+quotedLength), 58)))
return packet, nil
}
func encodeICMPDatagram(packet []byte, wireVersion icmpWireVersion, traceContext ICMPTraceContext) ([]byte, error) {
switch wireVersion {
case icmpWireV2:
return encodeV2ICMPDatagram(packet, traceContext)
case icmpWireV3:
return encodeV3ICMPDatagram(packet), nil
default:
return nil, E.New("unsupported icmp wire version: ", wireVersion)
}
}
func ipv6PseudoHeaderChecksum(source, destination netip.Addr, payloadLength uint32, nextHeader uint8) uint32 {
var sum uint32
sum = checksumSum(source.AsSlice(), sum)
sum = checksumSum(destination.AsSlice(), sum)
var lengthBytes [4]byte
binary.BigEndian.PutUint32(lengthBytes[:], payloadLength)
sum = checksumSum(lengthBytes[:], sum)
sum = checksumSum([]byte{0, 0, 0, nextHeader}, sum)
return sum
}
func checksumSum(data []byte, sum uint32) uint32 {
for len(data) >= 2 {
sum += uint32(binary.BigEndian.Uint16(data[:2]))
data = data[2:]
}
if len(data) == 1 {
sum += uint32(data[0]) << 8
}
return sum
}
func checksum(data []byte, initial uint32) uint16 {
sum := checksumSum(data, initial)
for sum > 0xffff {
sum = (sum >> 16) + (sum & 0xffff)
}
return ^uint16(sum)
}
func encodeV2ICMPDatagram(packet []byte, traceContext ICMPTraceContext) ([]byte, error) {
if traceContext.Traced {
data := make([]byte, 0, len(packet)+len(traceContext.Identity)+1)

View File

@@ -169,6 +169,158 @@ func TestICMPBridgeHandleV3Reply(t *testing.T) {
}
}
func TestICMPBridgeDecrementsIPv4TTLBeforeRouting(t *testing.T) {
var destination *fakeDirectRouteDestination
router := &testRouter{
preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
destination = &fakeDirectRouteDestination{routeContext: routeContext}
return destination, nil
},
}
inboundInstance := &Inbound{
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
router: router,
}
bridge := NewICMPBridge(inboundInstance, &captureDatagramSender{}, icmpWireV2)
packet := buildIPv4ICMPPacket(netip.MustParseAddr("198.18.0.2"), netip.MustParseAddr("1.1.1.1"), icmpv4TypeEchoRequest, 0, 1, 1)
packet[8] = 5
if err := bridge.HandleV2(context.Background(), DatagramV2TypeIP, packet); err != nil {
t.Fatal(err)
}
if len(destination.packets) != 1 {
t.Fatalf("expected one routed packet, got %d", len(destination.packets))
}
if got := destination.packets[0][8]; got != 4 {
t.Fatalf("expected decremented IPv4 TTL, got %d", got)
}
}
func TestICMPBridgeDecrementsIPv6HopLimitBeforeRouting(t *testing.T) {
var destination *fakeDirectRouteDestination
router := &testRouter{
preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
destination = &fakeDirectRouteDestination{routeContext: routeContext}
return destination, nil
},
}
inboundInstance := &Inbound{
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
router: router,
}
bridge := NewICMPBridge(inboundInstance, &captureDatagramSender{}, icmpWireV3)
packet := buildIPv6ICMPPacket(netip.MustParseAddr("2001:db8::2"), netip.MustParseAddr("2606:4700:4700::1111"), icmpv6TypeEchoRequest, 0, 1, 1)
packet[7] = 3
if err := bridge.HandleV3(context.Background(), packet); err != nil {
t.Fatal(err)
}
if len(destination.packets) != 1 {
t.Fatalf("expected one routed packet, got %d", len(destination.packets))
}
if got := destination.packets[0][7]; got != 2 {
t.Fatalf("expected decremented IPv6 hop limit, got %d", got)
}
}
func TestICMPBridgeHandleV2TTLExceededTracedReply(t *testing.T) {
var preMatchCalls int
traceIdentity := bytes.Repeat([]byte{0x6b}, icmpTraceIdentityLength)
sender := &captureDatagramSender{}
router := &testRouter{
preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
preMatchCalls++
return nil, nil
},
}
inboundInstance := &Inbound{
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
router: router,
}
bridge := NewICMPBridge(inboundInstance, sender, icmpWireV2)
source := netip.MustParseAddr("198.18.0.2")
target := netip.MustParseAddr("1.1.1.1")
packet := buildIPv4ICMPPacket(source, target, icmpv4TypeEchoRequest, 0, 1, 1)
packet[8] = 1
packet = append(packet, traceIdentity...)
if err := bridge.HandleV2(context.Background(), DatagramV2TypeIPWithTrace, packet); err != nil {
t.Fatal(err)
}
if preMatchCalls != 0 {
t.Fatalf("expected TTL exceeded to bypass routing, got %d route lookups", preMatchCalls)
}
if len(sender.sent) != 1 {
t.Fatalf("expected one TTL exceeded reply, got %d", len(sender.sent))
}
reply := sender.sent[0]
if reply[len(reply)-1] != byte(DatagramV2TypeIPWithTrace) {
t.Fatalf("expected traced v2 reply, got type %d", reply[len(reply)-1])
}
gotIdentity := reply[len(reply)-1-icmpTraceIdentityLength : len(reply)-1]
if !bytes.Equal(gotIdentity, traceIdentity) {
t.Fatalf("unexpected trace identity: %x", gotIdentity)
}
rawReply := reply[:len(reply)-1-icmpTraceIdentityLength]
packetInfo, err := ParseICMPPacket(rawReply)
if err != nil {
t.Fatal(err)
}
if packetInfo.ICMPType != icmpv4TypeTimeExceeded || packetInfo.ICMPCode != 0 {
t.Fatalf("expected IPv4 time exceeded reply, got type=%d code=%d", packetInfo.ICMPType, packetInfo.ICMPCode)
}
if packetInfo.SourceIP != target || packetInfo.Destination != source {
t.Fatalf("unexpected TTL exceeded routing: src=%s dst=%s", packetInfo.SourceIP, packetInfo.Destination)
}
}
func TestICMPBridgeHandleV3TTLExceededReply(t *testing.T) {
var preMatchCalls int
sender := &captureDatagramSender{}
router := &testRouter{
preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
preMatchCalls++
return nil, nil
},
}
inboundInstance := &Inbound{
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
router: router,
}
bridge := NewICMPBridge(inboundInstance, sender, icmpWireV3)
source := netip.MustParseAddr("2001:db8::2")
target := netip.MustParseAddr("2606:4700:4700::1111")
packet := buildIPv6ICMPPacket(source, target, icmpv6TypeEchoRequest, 0, 1, 1)
packet[7] = 1
if err := bridge.HandleV3(context.Background(), packet); err != nil {
t.Fatal(err)
}
if preMatchCalls != 0 {
t.Fatalf("expected TTL exceeded to bypass routing, got %d route lookups", preMatchCalls)
}
if len(sender.sent) != 1 {
t.Fatalf("expected one TTL exceeded reply, got %d", len(sender.sent))
}
if sender.sent[0][0] != byte(DatagramV3TypeICMP) {
t.Fatalf("expected v3 ICMP reply, got %d", sender.sent[0][0])
}
packetInfo, err := ParseICMPPacket(sender.sent[0][1:])
if err != nil {
t.Fatal(err)
}
if packetInfo.ICMPType != icmpv6TypeTimeExceeded || packetInfo.ICMPCode != 0 {
t.Fatalf("expected IPv6 time exceeded reply, got type=%d code=%d", packetInfo.ICMPType, packetInfo.ICMPCode)
}
if packetInfo.SourceIP != target || packetInfo.Destination != source {
t.Fatalf("unexpected TTL exceeded routing: src=%s dst=%s", packetInfo.SourceIP, packetInfo.Destination)
}
}
func TestICMPBridgeDropsNonEcho(t *testing.T) {
var preMatchCalls int
router := &testRouter{

View File

@@ -162,3 +162,49 @@ func TestResolveHTTPServiceStatus(t *testing.T) {
t.Fatalf("status service should keep request URL, got %s", requestURL)
}
}
func TestParseResolvedServiceCanonicalizesWebSocketOrigin(t *testing.T) {
testCases := []struct {
rawService string
wantScheme string
}{
{rawService: "ws://127.0.0.1:8080", wantScheme: "http"},
{rawService: "wss://127.0.0.1:8443", wantScheme: "https"},
}
for _, testCase := range testCases {
t.Run(testCase.rawService, func(t *testing.T) {
service, err := parseResolvedService(testCase.rawService, defaultOriginRequestConfig())
if err != nil {
t.Fatal(err)
}
if service.BaseURL == nil {
t.Fatal("expected base URL")
}
if service.BaseURL.Scheme != testCase.wantScheme {
t.Fatalf("expected scheme %q, got %q", testCase.wantScheme, service.BaseURL.Scheme)
}
if service.Service != testCase.rawService {
t.Fatalf("expected raw service to stay %q, got %q", testCase.rawService, service.Service)
}
})
}
}
func TestResolveHTTPServiceWebSocketOrigin(t *testing.T) {
inboundInstance := newTestIngressInbound(t)
inboundInstance.configManager.activeConfig = RuntimeConfig{
Ingress: []compiledIngressRule{
{Hostname: "foo.com", Service: mustResolvedService(t, "ws://127.0.0.1:8083")},
{Service: mustResolvedService(t, "http_status:404")},
},
}
_, requestURL, err := inboundInstance.resolveHTTPService("https://foo.com/path?q=1")
if err != nil {
t.Fatal(err)
}
if requestURL != "http://127.0.0.1:8083/path?q=1" {
t.Fatalf("expected websocket origin to be canonicalized, got %s", requestURL)
}
}

View File

@@ -74,6 +74,20 @@ func (s ResolvedService) BuildRequestURL(requestURL string) (string, error) {
}
}
func canonicalizeHTTPOriginURL(parsedURL *url.URL) *url.URL {
if parsedURL == nil {
return nil
}
canonicalURL := *parsedURL
switch canonicalURL.Scheme {
case "ws":
canonicalURL.Scheme = "http"
case "wss":
canonicalURL.Scheme = "https"
}
return &canonicalURL
}
type compiledIngressRule struct {
Hostname string
PunycodeHostname string
@@ -451,7 +465,7 @@ func parseResolvedService(rawService string, originRequest OriginRequestConfig)
Kind: ResolvedServiceHTTP,
Service: rawService,
Destination: parseServiceDestination(parsedURL),
BaseURL: parsedURL,
BaseURL: canonicalizeHTTPOriginURL(parsedURL),
OriginRequest: originRequest,
}, nil
case "tcp", "ssh", "rdp", "smb":