mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-13 20:28:32 +10:00
Fix cloudflared compatibility gaps
This commit is contained in:
232
protocol/cloudflare/connection_drain_test.go
Normal file
232
protocol/cloudflare/connection_drain_test.go
Normal file
@@ -0,0 +1,232 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sagernet/quic-go"
|
||||
)
|
||||
|
||||
type stubNetConn struct {
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
func newStubNetConn() *stubNetConn {
|
||||
return &stubNetConn{closed: make(chan struct{})}
|
||||
}
|
||||
|
||||
func (c *stubNetConn) Read(_ []byte) (int, error) { <-c.closed; return 0, io.EOF }
|
||||
func (c *stubNetConn) Write(b []byte) (int, error) { return len(b), nil }
|
||||
func (c *stubNetConn) Close() error { closeOnce(c.closed); return nil }
|
||||
func (c *stubNetConn) LocalAddr() net.Addr { return &net.TCPAddr{} }
|
||||
func (c *stubNetConn) RemoteAddr() net.Addr { return &net.TCPAddr{} }
|
||||
func (c *stubNetConn) SetDeadline(time.Time) error { return nil }
|
||||
func (c *stubNetConn) SetReadDeadline(time.Time) error { return nil }
|
||||
func (c *stubNetConn) SetWriteDeadline(time.Time) error { return nil }
|
||||
|
||||
type stubQUICConn struct {
|
||||
closed chan string
|
||||
}
|
||||
|
||||
func newStubQUICConn() *stubQUICConn {
|
||||
return &stubQUICConn{closed: make(chan string, 1)}
|
||||
}
|
||||
|
||||
func (c *stubQUICConn) OpenStream() (*quic.Stream, error) { return nil, errors.New("unused") }
|
||||
func (c *stubQUICConn) AcceptStream(context.Context) (*quic.Stream, error) {
|
||||
return nil, errors.New("unused")
|
||||
}
|
||||
func (c *stubQUICConn) ReceiveDatagram(context.Context) ([]byte, error) {
|
||||
return nil, errors.New("unused")
|
||||
}
|
||||
func (c *stubQUICConn) SendDatagram([]byte) error { return nil }
|
||||
func (c *stubQUICConn) LocalAddr() net.Addr { return &net.UDPAddr{} }
|
||||
func (c *stubQUICConn) CloseWithError(_ quic.ApplicationErrorCode, reason string) error {
|
||||
select {
|
||||
case c.closed <- reason:
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockRegistrationClient struct {
|
||||
unregisterCalled chan struct{}
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
func newMockRegistrationClient() *mockRegistrationClient {
|
||||
return &mockRegistrationClient{
|
||||
unregisterCalled: make(chan struct{}, 1),
|
||||
closed: make(chan struct{}, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *mockRegistrationClient) RegisterConnection(context.Context, TunnelAuth, uuid.UUID, uint8, *RegistrationConnectionOptions) (*RegistrationResult, error) {
|
||||
return &RegistrationResult{}, nil
|
||||
}
|
||||
|
||||
func (c *mockRegistrationClient) Unregister(context.Context) error {
|
||||
select {
|
||||
case c.unregisterCalled <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mockRegistrationClient) Close() error {
|
||||
select {
|
||||
case c.closed <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func closeOnce(ch chan struct{}) {
|
||||
select {
|
||||
case <-ch:
|
||||
default:
|
||||
close(ch)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTP2GracefulShutdownWaitsForActiveRequests(t *testing.T) {
|
||||
conn := newStubNetConn()
|
||||
registrationClient := newMockRegistrationClient()
|
||||
connection := &HTTP2Connection{
|
||||
conn: conn,
|
||||
gracePeriod: 200 * time.Millisecond,
|
||||
registrationClient: registrationClient,
|
||||
registrationResult: &RegistrationResult{},
|
||||
serveCancel: func() {},
|
||||
}
|
||||
connection.activeRequests.Add(1)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
connection.gracefulShutdown()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-registrationClient.unregisterCalled:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected unregister call")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-conn.closed:
|
||||
t.Fatal("connection closed before active requests completed")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
|
||||
connection.activeRequests.Done()
|
||||
|
||||
select {
|
||||
case <-conn.closed:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected connection close after active requests finished")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected graceful shutdown to finish")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTP2GracefulShutdownTimesOut(t *testing.T) {
|
||||
conn := newStubNetConn()
|
||||
registrationClient := newMockRegistrationClient()
|
||||
connection := &HTTP2Connection{
|
||||
conn: conn,
|
||||
gracePeriod: 50 * time.Millisecond,
|
||||
registrationClient: registrationClient,
|
||||
registrationResult: &RegistrationResult{},
|
||||
serveCancel: func() {},
|
||||
}
|
||||
connection.activeRequests.Add(1)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
connection.gracefulShutdown()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-conn.closed:
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Fatal("expected connection close after grace timeout")
|
||||
}
|
||||
|
||||
connection.activeRequests.Done()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected graceful shutdown to finish after request completion")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQUICGracefulShutdownWaitsForDrainWindow(t *testing.T) {
|
||||
conn := newStubQUICConn()
|
||||
registrationClient := newMockRegistrationClient()
|
||||
serveCancelCalled := make(chan struct{}, 1)
|
||||
connection := &QUICConnection{
|
||||
conn: conn,
|
||||
gracePeriod: 80 * time.Millisecond,
|
||||
registrationClient: registrationClient,
|
||||
registrationResult: &RegistrationResult{},
|
||||
serveCancel: func() {
|
||||
select {
|
||||
case serveCancelCalled <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
connection.gracefulShutdown()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-registrationClient.unregisterCalled:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected unregister call")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-conn.closed:
|
||||
t.Fatal("connection closed before grace window elapsed")
|
||||
case <-time.After(20 * time.Millisecond):
|
||||
}
|
||||
|
||||
select {
|
||||
case reason := <-conn.closed:
|
||||
if reason != "graceful shutdown" {
|
||||
t.Fatalf("unexpected close reason: %q", reason)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected graceful close")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-serveCancelCalled:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected serve cancel to be called")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected graceful shutdown to finish")
|
||||
}
|
||||
}
|
||||
@@ -44,12 +44,15 @@ type HTTP2Connection struct {
|
||||
inbound *Inbound
|
||||
|
||||
numPreviousAttempts uint8
|
||||
registrationClient *RegistrationClient
|
||||
registrationClient registrationRPCClient
|
||||
registrationResult *RegistrationResult
|
||||
controlStreamErr error
|
||||
|
||||
activeRequests sync.WaitGroup
|
||||
closeOnce sync.Once
|
||||
activeRequests sync.WaitGroup
|
||||
serveCancel context.CancelFunc
|
||||
registrationClose sync.Once
|
||||
shutdownOnce sync.Once
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
// NewHTTP2Connection dials the edge and establishes an HTTP/2 connection with role reversal.
|
||||
@@ -106,22 +109,28 @@ func NewHTTP2Connection(
|
||||
|
||||
// Serve runs the HTTP/2 server. Blocks until the context is cancelled or the connection ends.
|
||||
func (c *HTTP2Connection) Serve(ctx context.Context) error {
|
||||
serveCtx, serveCancel := context.WithCancel(context.WithoutCancel(ctx))
|
||||
c.serveCancel = serveCancel
|
||||
|
||||
shutdownDone := make(chan struct{})
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
c.close()
|
||||
c.gracefulShutdown()
|
||||
close(shutdownDone)
|
||||
}()
|
||||
|
||||
c.server.ServeConn(c.conn, &http2.ServeConnOpts{
|
||||
Context: ctx,
|
||||
Context: serveCtx,
|
||||
Handler: c,
|
||||
})
|
||||
|
||||
if ctx.Err() != nil {
|
||||
<-shutdownDone
|
||||
return ctx.Err()
|
||||
}
|
||||
if c.controlStreamErr != nil {
|
||||
return c.controlStreamErr
|
||||
}
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if c.registrationResult == nil {
|
||||
return E.New("edge connection closed before registration")
|
||||
}
|
||||
@@ -129,12 +138,15 @@ func (c *HTTP2Connection) Serve(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get(h2HeaderUpgrade) == h2UpgradeControlStream {
|
||||
c.handleControlStream(r.Context(), r, w)
|
||||
return
|
||||
}
|
||||
|
||||
c.activeRequests.Add(1)
|
||||
defer c.activeRequests.Done()
|
||||
|
||||
switch {
|
||||
case r.Header.Get(h2HeaderUpgrade) == h2UpgradeControlStream:
|
||||
c.handleControlStream(r.Context(), r, w)
|
||||
case r.Header.Get(h2HeaderUpgrade) == h2UpgradeWebsocket:
|
||||
c.handleH2DataStream(r.Context(), r, w, ConnectionTypeWebsocket)
|
||||
case r.Header.Get(h2HeaderTCPSrc) != "":
|
||||
@@ -169,17 +181,13 @@ func (c *HTTP2Connection) handleControlStream(ctx context.Context, r *http.Reque
|
||||
if err != nil {
|
||||
c.controlStreamErr = err
|
||||
c.logger.Error("register connection: ", err)
|
||||
if c.registrationClient != nil {
|
||||
c.registrationClient.Close()
|
||||
}
|
||||
go c.close()
|
||||
go c.forceClose()
|
||||
return
|
||||
}
|
||||
if err := validateRegistrationResult(result); err != nil {
|
||||
c.controlStreamErr = err
|
||||
c.logger.Error("register connection: ", err)
|
||||
c.registrationClient.Close()
|
||||
go c.close()
|
||||
go c.forceClose()
|
||||
return
|
||||
}
|
||||
c.registrationResult = result
|
||||
@@ -189,13 +197,6 @@ func (c *HTTP2Connection) handleControlStream(ctx context.Context, r *http.Reque
|
||||
" (connection ", result.ConnectionID, ")")
|
||||
|
||||
<-ctx.Done()
|
||||
unregisterCtx, cancel := context.WithTimeout(context.Background(), c.gracePeriod)
|
||||
defer cancel()
|
||||
err = c.registrationClient.Unregister(unregisterCtx)
|
||||
if err != nil {
|
||||
c.logger.Debug("failed to unregister: ", err)
|
||||
}
|
||||
c.registrationClient.Close()
|
||||
}
|
||||
|
||||
func (c *HTTP2Connection) handleH2DataStream(ctx context.Context, r *http.Request, w http.ResponseWriter, connectionType ConnectionType) {
|
||||
@@ -280,16 +281,74 @@ func (c *HTTP2Connection) handleConfigurationUpdate(r *http.Request, w http.Resp
|
||||
w.Write([]byte(`{"lastAppliedVersion":` + strconv.FormatInt(int64(result.LastAppliedVersion), 10) + `,"err":null}`))
|
||||
}
|
||||
|
||||
func (c *HTTP2Connection) close() {
|
||||
func (c *HTTP2Connection) gracefulShutdown() {
|
||||
c.shutdownOnce.Do(func() {
|
||||
if c.registrationClient == nil || c.registrationResult == nil {
|
||||
c.closeNow()
|
||||
return
|
||||
}
|
||||
|
||||
unregisterCtx, cancel := context.WithTimeout(context.Background(), c.gracePeriod)
|
||||
err := c.registrationClient.Unregister(unregisterCtx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
c.logger.Debug("failed to unregister: ", err)
|
||||
}
|
||||
c.closeRegistrationClient()
|
||||
c.waitForActiveRequests(c.gracePeriod)
|
||||
c.closeNow()
|
||||
})
|
||||
}
|
||||
|
||||
func (c *HTTP2Connection) forceClose() {
|
||||
c.shutdownOnce.Do(func() {
|
||||
c.closeNow()
|
||||
})
|
||||
}
|
||||
|
||||
func (c *HTTP2Connection) waitForActiveRequests(timeout time.Duration) {
|
||||
if timeout <= 0 {
|
||||
c.activeRequests.Wait()
|
||||
return
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
c.activeRequests.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-timer.C:
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HTTP2Connection) closeRegistrationClient() {
|
||||
c.registrationClose.Do(func() {
|
||||
if c.registrationClient != nil {
|
||||
_ = c.registrationClient.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (c *HTTP2Connection) closeNow() {
|
||||
c.closeOnce.Do(func() {
|
||||
c.conn.Close()
|
||||
_ = c.conn.Close()
|
||||
if c.serveCancel != nil {
|
||||
c.serveCancel()
|
||||
}
|
||||
c.closeRegistrationClient()
|
||||
c.activeRequests.Wait()
|
||||
})
|
||||
}
|
||||
|
||||
// Close closes the HTTP/2 connection.
|
||||
func (c *HTTP2Connection) Close() error {
|
||||
c.close()
|
||||
c.forceClose()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -48,11 +48,14 @@ type QUICConnection struct {
|
||||
features []string
|
||||
numPreviousAttempts uint8
|
||||
gracePeriod time.Duration
|
||||
registrationClient *RegistrationClient
|
||||
registrationClient registrationRPCClient
|
||||
registrationResult *RegistrationResult
|
||||
onConnected func()
|
||||
|
||||
closeOnce sync.Once
|
||||
serveCancel context.CancelFunc
|
||||
registrationClose sync.Once
|
||||
shutdownOnce sync.Once
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
type quicConnection interface {
|
||||
@@ -180,22 +183,29 @@ func (q *QUICConnection) Serve(ctx context.Context, handler StreamHandler) error
|
||||
q.logger.Info("connected to ", q.registrationResult.Location,
|
||||
" (connection ", q.registrationResult.ConnectionID, ")")
|
||||
|
||||
serveCtx, serveCancel := context.WithCancel(context.WithoutCancel(ctx))
|
||||
q.serveCancel = serveCancel
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
errChan <- q.acceptStreams(ctx, handler)
|
||||
errChan <- q.acceptStreams(serveCtx, handler)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
errChan <- q.handleDatagrams(ctx, handler)
|
||||
errChan <- q.handleDatagrams(serveCtx, handler)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
q.gracefulShutdown()
|
||||
<-errChan
|
||||
return ctx.Err()
|
||||
case err = <-errChan:
|
||||
q.gracefulShutdown()
|
||||
q.forceClose()
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -285,23 +295,55 @@ func (q *QUICConnection) OpenRPCStream(ctx context.Context) (io.ReadWriteCloser,
|
||||
}
|
||||
|
||||
func (q *QUICConnection) gracefulShutdown() {
|
||||
q.closeOnce.Do(func() {
|
||||
if q.registrationClient != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), q.gracePeriod)
|
||||
defer cancel()
|
||||
err := q.registrationClient.Unregister(ctx)
|
||||
if err != nil {
|
||||
q.logger.Debug("failed to unregister: ", err)
|
||||
}
|
||||
q.registrationClient.Close()
|
||||
q.shutdownOnce.Do(func() {
|
||||
if q.registrationClient == nil || q.registrationResult == nil {
|
||||
q.closeNow("connection closed")
|
||||
return
|
||||
}
|
||||
q.conn.CloseWithError(0, "graceful shutdown")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), q.gracePeriod)
|
||||
err := q.registrationClient.Unregister(ctx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
q.logger.Debug("failed to unregister: ", err)
|
||||
}
|
||||
q.closeRegistrationClient()
|
||||
if q.gracePeriod > 0 {
|
||||
timer := time.NewTimer(q.gracePeriod)
|
||||
<-timer.C
|
||||
timer.Stop()
|
||||
}
|
||||
q.closeNow("graceful shutdown")
|
||||
})
|
||||
}
|
||||
|
||||
func (q *QUICConnection) forceClose() {
|
||||
q.shutdownOnce.Do(func() {
|
||||
q.closeNow("connection closed")
|
||||
})
|
||||
}
|
||||
|
||||
func (q *QUICConnection) closeRegistrationClient() {
|
||||
q.registrationClose.Do(func() {
|
||||
if q.registrationClient != nil {
|
||||
_ = q.registrationClient.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (q *QUICConnection) closeNow(reason string) {
|
||||
q.closeOnce.Do(func() {
|
||||
if q.serveCancel != nil {
|
||||
q.serveCancel()
|
||||
}
|
||||
q.closeRegistrationClient()
|
||||
_ = q.conn.CloseWithError(0, reason)
|
||||
})
|
||||
}
|
||||
|
||||
// Close closes the QUIC connection immediately.
|
||||
func (q *QUICConnection) Close() error {
|
||||
q.gracefulShutdown()
|
||||
q.forceClose()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -31,6 +31,18 @@ type RegistrationClient struct {
|
||||
transport rpc.Transport
|
||||
}
|
||||
|
||||
type registrationRPCClient interface {
|
||||
RegisterConnection(
|
||||
ctx context.Context,
|
||||
auth TunnelAuth,
|
||||
tunnelID uuid.UUID,
|
||||
connIndex uint8,
|
||||
options *RegistrationConnectionOptions,
|
||||
) (*RegistrationResult, error)
|
||||
Unregister(ctx context.Context) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
// NewRegistrationClient creates a Cap'n Proto RPC client over the given stream.
|
||||
// The stream should be the first QUIC stream (control stream).
|
||||
func NewRegistrationClient(ctx context.Context, stream io.ReadWriteCloser) *RegistrationClient {
|
||||
|
||||
@@ -5,11 +5,16 @@ package cloudflare
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type v2UnregisterCall struct {
|
||||
@@ -25,6 +30,43 @@ type captureV2SessionRPCClient struct {
|
||||
unregisterCh chan v2UnregisterCall
|
||||
}
|
||||
|
||||
type blockingPacketConn struct {
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
func newBlockingPacketConn() *blockingPacketConn {
|
||||
return &blockingPacketConn{closed: make(chan struct{})}
|
||||
}
|
||||
|
||||
func (c *blockingPacketConn) ReadPacket(_ *buf.Buffer) (M.Socksaddr, error) {
|
||||
<-c.closed
|
||||
return M.Socksaddr{}, io.EOF
|
||||
}
|
||||
|
||||
func (c *blockingPacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error {
|
||||
buffer.Release()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *blockingPacketConn) Close() error {
|
||||
closeOnce(c.closed)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *blockingPacketConn) LocalAddr() net.Addr { return &net.UDPAddr{} }
|
||||
func (c *blockingPacketConn) SetDeadline(time.Time) error { return nil }
|
||||
func (c *blockingPacketConn) SetReadDeadline(time.Time) error { return nil }
|
||||
func (c *blockingPacketConn) SetWriteDeadline(time.Time) error { return nil }
|
||||
|
||||
type packetDialingRouter struct {
|
||||
testRouter
|
||||
packetConn N.PacketConn
|
||||
}
|
||||
|
||||
func (r *packetDialingRouter) DialRoutePacketConnection(ctx context.Context, metadata adapter.InboundContext) (N.PacketConn, error) {
|
||||
return r.packetConn, nil
|
||||
}
|
||||
|
||||
func (c *captureV2SessionRPCClient) UnregisterSession(ctx context.Context, sessionID uuid.UUID, message string) error {
|
||||
c.unregisterCh <- v2UnregisterCall{sessionID: sessionID, message: message}
|
||||
return nil
|
||||
@@ -105,3 +147,57 @@ func TestDatagramV3RegistrationMigratesSender(t *testing.T) {
|
||||
|
||||
session.close()
|
||||
}
|
||||
|
||||
func TestDatagramV3MigrationUpdatesSessionContext(t *testing.T) {
|
||||
packetConn := newBlockingPacketConn()
|
||||
inboundInstance := newLimitedInbound(t, 0)
|
||||
inboundInstance.router = &packetDialingRouter{packetConn: packetConn}
|
||||
sender1 := &captureDatagramSender{}
|
||||
sender2 := &captureDatagramSender{}
|
||||
muxer1 := NewDatagramV3Muxer(inboundInstance, sender1, inboundInstance.logger)
|
||||
muxer2 := NewDatagramV3Muxer(inboundInstance, sender2, inboundInstance.logger)
|
||||
|
||||
requestID := RequestID{}
|
||||
requestID[15] = 10
|
||||
payload := make([]byte, 1+2+2+16+4)
|
||||
payload[0] = 0
|
||||
binary.BigEndian.PutUint16(payload[1:3], 53)
|
||||
binary.BigEndian.PutUint16(payload[3:5], 30)
|
||||
copy(payload[5:21], requestID[:])
|
||||
copy(payload[21:25], []byte{127, 0, 0, 1})
|
||||
|
||||
ctx1, cancel1 := context.WithCancel(context.Background())
|
||||
muxer1.handleRegistration(ctx1, payload)
|
||||
|
||||
ctx2, cancel2 := context.WithCancel(context.Background())
|
||||
muxer2.handleRegistration(ctx2, payload)
|
||||
|
||||
cancel1()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
session, exists := inboundInstance.datagramV3Manager.Get(requestID)
|
||||
if !exists {
|
||||
t.Fatal("expected session to survive old connection context cancellation")
|
||||
}
|
||||
|
||||
session.senderAccess.RLock()
|
||||
currentSender := session.sender
|
||||
session.senderAccess.RUnlock()
|
||||
if currentSender != sender2 {
|
||||
t.Fatal("expected migrated sender to stay active")
|
||||
}
|
||||
|
||||
cancel2()
|
||||
|
||||
deadline := time.After(time.Second)
|
||||
for {
|
||||
if _, exists := inboundInstance.datagramV3Manager.Get(requestID); !exists {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-deadline:
|
||||
t.Fatal("expected session to be removed after new context cancellation")
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,6 +37,7 @@ const (
|
||||
v3RegistrationBaseLen = 1 + v3RegistrationFlagLen + v3RegistrationPortLen + v3RegistrationIdleLen + v3RequestIDLength // 22
|
||||
v3PayloadHeaderLen = 1 + v3RequestIDLength // 17
|
||||
v3RegistrationRespLen = 1 + 1 + v3RequestIDLength + 2 // 20
|
||||
maxV3UDPPayloadLen = 1280
|
||||
|
||||
// V3 registration flags
|
||||
v3FlagIPv6 byte = 0x01
|
||||
@@ -238,6 +239,10 @@ type v3Session struct {
|
||||
|
||||
senderAccess sync.RWMutex
|
||||
sender DatagramSender
|
||||
|
||||
contextAccess sync.RWMutex
|
||||
connCtx context.Context
|
||||
contextChan chan context.Context
|
||||
}
|
||||
|
||||
var errTooManyActiveFlows = errors.New("too many active flows")
|
||||
@@ -253,11 +258,12 @@ func (m *DatagramV3SessionManager) Register(
|
||||
m.sessionAccess.Lock()
|
||||
if existing, exists := m.sessions[requestID]; exists {
|
||||
if existing.sender == sender {
|
||||
existing.updateContext(ctx)
|
||||
existing.markActive()
|
||||
m.sessionAccess.Unlock()
|
||||
return existing, v3RegistrationExisting, nil
|
||||
}
|
||||
existing.setSender(sender)
|
||||
existing.migrate(sender, ctx)
|
||||
existing.markActive()
|
||||
m.sessionAccess.Unlock()
|
||||
return existing, v3RegistrationMigrated, nil
|
||||
@@ -286,14 +292,17 @@ func (m *DatagramV3SessionManager) Register(
|
||||
closeChan: make(chan struct{}),
|
||||
activeAt: time.Now(),
|
||||
sender: sender,
|
||||
connCtx: ctx,
|
||||
contextChan: make(chan context.Context, 1),
|
||||
}
|
||||
m.sessions[requestID] = session
|
||||
m.sessionAccess.Unlock()
|
||||
|
||||
sessionCtx := inbound.ctx
|
||||
sessionCtx := ctx
|
||||
if sessionCtx == nil {
|
||||
sessionCtx = context.Background()
|
||||
}
|
||||
session.connCtx = sessionCtx
|
||||
go session.serve(sessionCtx, limit)
|
||||
return session, v3RegistrationNew, nil
|
||||
}
|
||||
@@ -320,6 +329,8 @@ func (s *v3Session) serve(ctx context.Context, limit uint64) {
|
||||
go s.readLoop()
|
||||
go s.writeLoop()
|
||||
|
||||
connCtx := ctx
|
||||
|
||||
tickInterval := s.closeAfterIdle / 2
|
||||
if tickInterval <= 0 || tickInterval > 10*time.Second {
|
||||
tickInterval = time.Second
|
||||
@@ -329,8 +340,16 @@ func (s *v3Session) serve(ctx context.Context, limit uint64) {
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-connCtx.Done():
|
||||
if latestCtx := s.currentContext(); latestCtx != nil && latestCtx != connCtx {
|
||||
connCtx = latestCtx
|
||||
continue
|
||||
}
|
||||
s.close()
|
||||
case newCtx := <-s.contextChan:
|
||||
if newCtx != nil {
|
||||
connCtx = newCtx
|
||||
}
|
||||
case <-ticker.C:
|
||||
if time.Since(s.lastActive()) >= s.closeAfterIdle {
|
||||
s.close()
|
||||
@@ -350,6 +369,11 @@ func (s *v3Session) readLoop() {
|
||||
s.close()
|
||||
return
|
||||
}
|
||||
if buffer.Len() > maxV3UDPPayloadLen {
|
||||
s.inbound.logger.Debug("drop oversized V3 UDP payload: ", buffer.Len())
|
||||
buffer.Release()
|
||||
continue
|
||||
}
|
||||
s.markActive()
|
||||
if err := s.senderDatagram(append([]byte(nil), buffer.Bytes()...)); err != nil {
|
||||
buffer.Release()
|
||||
@@ -403,6 +427,35 @@ func (s *v3Session) setSender(sender DatagramSender) {
|
||||
s.senderAccess.Unlock()
|
||||
}
|
||||
|
||||
func (s *v3Session) updateContext(ctx context.Context) {
|
||||
if ctx == nil {
|
||||
return
|
||||
}
|
||||
s.contextAccess.Lock()
|
||||
s.connCtx = ctx
|
||||
s.contextAccess.Unlock()
|
||||
select {
|
||||
case s.contextChan <- ctx:
|
||||
default:
|
||||
select {
|
||||
case <-s.contextChan:
|
||||
default:
|
||||
}
|
||||
s.contextChan <- ctx
|
||||
}
|
||||
}
|
||||
|
||||
func (s *v3Session) migrate(sender DatagramSender, ctx context.Context) {
|
||||
s.setSender(sender)
|
||||
s.updateContext(ctx)
|
||||
}
|
||||
|
||||
func (s *v3Session) currentContext() context.Context {
|
||||
s.contextAccess.RLock()
|
||||
defer s.contextAccess.RUnlock()
|
||||
return s.connCtx
|
||||
}
|
||||
|
||||
func (s *v3Session) markActive() {
|
||||
s.activeAccess.Lock()
|
||||
s.activeAt = time.Now()
|
||||
|
||||
@@ -5,10 +5,18 @@ package cloudflare
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter/inbound"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
func TestDatagramV3RegistrationDestinationUnreachable(t *testing.T) {
|
||||
@@ -64,3 +72,80 @@ func TestDatagramV3RegistrationErrorWithMessage(t *testing.T) {
|
||||
t.Fatalf("unexpected datagram response: %v", sender.sent[0])
|
||||
}
|
||||
}
|
||||
|
||||
type scriptedPacketConn struct {
|
||||
reads [][]byte
|
||||
index int
|
||||
}
|
||||
|
||||
func (c *scriptedPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
|
||||
if c.index >= len(c.reads) {
|
||||
return M.Socksaddr{}, io.EOF
|
||||
}
|
||||
_, err := buffer.Write(c.reads[c.index])
|
||||
c.index++
|
||||
return M.Socksaddr{}, err
|
||||
}
|
||||
|
||||
func (c *scriptedPacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error {
|
||||
buffer.Release()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *scriptedPacketConn) Close() error { return nil }
|
||||
func (c *scriptedPacketConn) LocalAddr() net.Addr { return &net.UDPAddr{} }
|
||||
func (c *scriptedPacketConn) SetDeadline(time.Time) error { return nil }
|
||||
func (c *scriptedPacketConn) SetReadDeadline(time.Time) error { return nil }
|
||||
func (c *scriptedPacketConn) SetWriteDeadline(time.Time) error { return nil }
|
||||
|
||||
type sizeLimitedSender struct {
|
||||
sent [][]byte
|
||||
max int
|
||||
}
|
||||
|
||||
func (s *sizeLimitedSender) SendDatagram(data []byte) error {
|
||||
if len(data) > s.max {
|
||||
return errors.New("datagram too large")
|
||||
}
|
||||
s.sent = append(s.sent, append([]byte(nil), data...))
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestDatagramV3ReadLoopDropsOversizedOriginPackets(t *testing.T) {
|
||||
logger := log.NewNOPFactory().NewLogger("test")
|
||||
sender := &sizeLimitedSender{max: v3PayloadHeaderLen + maxV3UDPPayloadLen}
|
||||
session := &v3Session{
|
||||
id: RequestID{},
|
||||
destination: netip.MustParseAddrPort("127.0.0.1:53"),
|
||||
origin: &scriptedPacketConn{reads: [][]byte{
|
||||
make([]byte, maxV3UDPPayloadLen+1),
|
||||
[]byte("ok"),
|
||||
}},
|
||||
inbound: &Inbound{
|
||||
logger: logger,
|
||||
},
|
||||
writeChan: make(chan []byte, 1),
|
||||
closeChan: make(chan struct{}),
|
||||
contextChan: make(chan context.Context, 1),
|
||||
sender: sender,
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
session.readLoop()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected read loop to finish")
|
||||
}
|
||||
|
||||
if len(sender.sent) != 1 {
|
||||
t.Fatalf("expected one datagram after dropping oversized payload, got %d", len(sender.sent))
|
||||
}
|
||||
if len(sender.sent[0]) != v3PayloadHeaderLen+2 {
|
||||
t.Fatalf("unexpected forwarded datagram length: %d", len(sender.sent[0]))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,6 +20,15 @@ import (
|
||||
const (
|
||||
icmpFlowTimeout = 30 * time.Second
|
||||
icmpTraceIdentityLength = 16 + 8 + 1
|
||||
defaultICMPPacketTTL = 64
|
||||
icmpErrorHeaderLen = 8
|
||||
|
||||
icmpv4TypeEchoRequest = 8
|
||||
icmpv4TypeEchoReply = 0
|
||||
icmpv4TypeTimeExceeded = 11
|
||||
icmpv6TypeEchoRequest = 128
|
||||
icmpv6TypeEchoReply = 129
|
||||
icmpv6TypeTimeExceeded = 3
|
||||
)
|
||||
|
||||
type ICMPTraceContext struct {
|
||||
@@ -40,15 +49,18 @@ type ICMPRequestKey struct {
|
||||
}
|
||||
|
||||
type ICMPPacketInfo struct {
|
||||
IPVersion uint8
|
||||
Protocol uint8
|
||||
SourceIP netip.Addr
|
||||
Destination netip.Addr
|
||||
ICMPType uint8
|
||||
ICMPCode uint8
|
||||
Identifier uint16
|
||||
Sequence uint16
|
||||
RawPacket []byte
|
||||
IPVersion uint8
|
||||
Protocol uint8
|
||||
SourceIP netip.Addr
|
||||
Destination netip.Addr
|
||||
ICMPType uint8
|
||||
ICMPCode uint8
|
||||
Identifier uint16
|
||||
Sequence uint16
|
||||
IPv4HeaderLen int
|
||||
IPv4TTL uint8
|
||||
IPv6HopLimit uint8
|
||||
RawPacket []byte
|
||||
}
|
||||
|
||||
func (i ICMPPacketInfo) FlowKey() ICMPFlowKey {
|
||||
@@ -82,9 +94,9 @@ func (i ICMPPacketInfo) ReplyRequestKey() ICMPRequestKey {
|
||||
func (i ICMPPacketInfo) IsEchoRequest() bool {
|
||||
switch i.IPVersion {
|
||||
case 4:
|
||||
return i.ICMPType == 8 && i.ICMPCode == 0
|
||||
return i.ICMPType == icmpv4TypeEchoRequest && i.ICMPCode == 0
|
||||
case 6:
|
||||
return i.ICMPType == 128 && i.ICMPCode == 0
|
||||
return i.ICMPType == icmpv6TypeEchoRequest && i.ICMPCode == 0
|
||||
default:
|
||||
return false
|
||||
}
|
||||
@@ -93,14 +105,47 @@ func (i ICMPPacketInfo) IsEchoRequest() bool {
|
||||
func (i ICMPPacketInfo) IsEchoReply() bool {
|
||||
switch i.IPVersion {
|
||||
case 4:
|
||||
return i.ICMPType == 0 && i.ICMPCode == 0
|
||||
return i.ICMPType == icmpv4TypeEchoReply && i.ICMPCode == 0
|
||||
case 6:
|
||||
return i.ICMPType == 129 && i.ICMPCode == 0
|
||||
return i.ICMPType == icmpv6TypeEchoReply && i.ICMPCode == 0
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (i ICMPPacketInfo) TTL() uint8 {
|
||||
if i.IPVersion == 4 {
|
||||
return i.IPv4TTL
|
||||
}
|
||||
return i.IPv6HopLimit
|
||||
}
|
||||
|
||||
func (i ICMPPacketInfo) TTLExpired() bool {
|
||||
return i.TTL() <= 1
|
||||
}
|
||||
|
||||
func (i *ICMPPacketInfo) DecrementTTL() error {
|
||||
switch i.IPVersion {
|
||||
case 4:
|
||||
if i.IPv4TTL == 0 || i.IPv4HeaderLen < 20 || len(i.RawPacket) < i.IPv4HeaderLen {
|
||||
return E.New("invalid IPv4 packet TTL state")
|
||||
}
|
||||
i.IPv4TTL--
|
||||
i.RawPacket[8] = i.IPv4TTL
|
||||
binary.BigEndian.PutUint16(i.RawPacket[10:12], 0)
|
||||
binary.BigEndian.PutUint16(i.RawPacket[10:12], checksum(i.RawPacket[:i.IPv4HeaderLen], 0))
|
||||
case 6:
|
||||
if i.IPv6HopLimit == 0 || len(i.RawPacket) < 40 {
|
||||
return E.New("invalid IPv6 packet hop limit state")
|
||||
}
|
||||
i.IPv6HopLimit--
|
||||
i.RawPacket[7] = i.IPv6HopLimit
|
||||
default:
|
||||
return E.New("unsupported IP version: ", i.IPVersion)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type icmpWireVersion uint8
|
||||
|
||||
const (
|
||||
@@ -154,15 +199,7 @@ func (w *ICMPReplyWriter) WritePacket(packet []byte) error {
|
||||
}
|
||||
w.access.Unlock()
|
||||
|
||||
var datagram []byte
|
||||
switch w.wireVersion {
|
||||
case icmpWireV2:
|
||||
datagram, err = encodeV2ICMPDatagram(packetInfo.RawPacket, traceContext)
|
||||
case icmpWireV3:
|
||||
datagram = encodeV3ICMPDatagram(packetInfo.RawPacket)
|
||||
default:
|
||||
err = E.New("unsupported icmp wire version: ", w.wireVersion)
|
||||
}
|
||||
datagram, err := encodeICMPDatagram(packetInfo.RawPacket, w.wireVersion, traceContext)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -218,6 +255,21 @@ func (b *ICMPBridge) handlePacket(ctx context.Context, payload []byte, traceCont
|
||||
if !packetInfo.IsEchoRequest() {
|
||||
return nil
|
||||
}
|
||||
if packetInfo.TTLExpired() {
|
||||
ttlExceededPacket, err := buildICMPTTLExceededPacket(packetInfo, maxEncodedICMPPacketLen(b.wireVersion, traceContext))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
datagram, err := encodeICMPDatagram(ttlExceededPacket, b.wireVersion, traceContext)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return b.sender.SendDatagram(datagram)
|
||||
}
|
||||
|
||||
if err := packetInfo.DecrementTTL(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
state := b.getFlowState(packetInfo.FlowKey())
|
||||
if traceContext.Traced {
|
||||
@@ -294,15 +346,17 @@ func parseIPv4ICMPPacket(packet []byte) (ICMPPacketInfo, error) {
|
||||
return ICMPPacketInfo{}, E.New("invalid IPv4 destination address")
|
||||
}
|
||||
return ICMPPacketInfo{
|
||||
IPVersion: 4,
|
||||
Protocol: 1,
|
||||
SourceIP: sourceIP,
|
||||
Destination: destinationIP,
|
||||
ICMPType: packet[headerLen],
|
||||
ICMPCode: packet[headerLen+1],
|
||||
Identifier: binary.BigEndian.Uint16(packet[headerLen+4 : headerLen+6]),
|
||||
Sequence: binary.BigEndian.Uint16(packet[headerLen+6 : headerLen+8]),
|
||||
RawPacket: append([]byte(nil), packet...),
|
||||
IPVersion: 4,
|
||||
Protocol: 1,
|
||||
SourceIP: sourceIP,
|
||||
Destination: destinationIP,
|
||||
ICMPType: packet[headerLen],
|
||||
ICMPCode: packet[headerLen+1],
|
||||
Identifier: binary.BigEndian.Uint16(packet[headerLen+4 : headerLen+6]),
|
||||
Sequence: binary.BigEndian.Uint16(packet[headerLen+6 : headerLen+8]),
|
||||
IPv4HeaderLen: headerLen,
|
||||
IPv4TTL: packet[8],
|
||||
RawPacket: append([]byte(nil), packet...),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -322,18 +376,139 @@ func parseIPv6ICMPPacket(packet []byte) (ICMPPacketInfo, error) {
|
||||
return ICMPPacketInfo{}, E.New("invalid IPv6 destination address")
|
||||
}
|
||||
return ICMPPacketInfo{
|
||||
IPVersion: 6,
|
||||
Protocol: 58,
|
||||
SourceIP: sourceIP,
|
||||
Destination: destinationIP,
|
||||
ICMPType: packet[40],
|
||||
ICMPCode: packet[41],
|
||||
Identifier: binary.BigEndian.Uint16(packet[44:46]),
|
||||
Sequence: binary.BigEndian.Uint16(packet[46:48]),
|
||||
RawPacket: append([]byte(nil), packet...),
|
||||
IPVersion: 6,
|
||||
Protocol: 58,
|
||||
SourceIP: sourceIP,
|
||||
Destination: destinationIP,
|
||||
ICMPType: packet[40],
|
||||
ICMPCode: packet[41],
|
||||
Identifier: binary.BigEndian.Uint16(packet[44:46]),
|
||||
Sequence: binary.BigEndian.Uint16(packet[46:48]),
|
||||
IPv6HopLimit: packet[7],
|
||||
RawPacket: append([]byte(nil), packet...),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func maxEncodedICMPPacketLen(wireVersion icmpWireVersion, traceContext ICMPTraceContext) int {
|
||||
limit := maxV3UDPPayloadLen
|
||||
switch wireVersion {
|
||||
case icmpWireV2:
|
||||
limit -= typeIDLength
|
||||
if traceContext.Traced {
|
||||
limit -= len(traceContext.Identity)
|
||||
}
|
||||
case icmpWireV3:
|
||||
limit -= 1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
if limit < 0 {
|
||||
return 0
|
||||
}
|
||||
return limit
|
||||
}
|
||||
|
||||
func buildICMPTTLExceededPacket(packetInfo ICMPPacketInfo, maxPacketLen int) ([]byte, error) {
|
||||
switch packetInfo.IPVersion {
|
||||
case 4:
|
||||
return buildIPv4ICMPTTLExceededPacket(packetInfo, maxPacketLen)
|
||||
case 6:
|
||||
return buildIPv6ICMPTTLExceededPacket(packetInfo, maxPacketLen)
|
||||
default:
|
||||
return nil, E.New("unsupported IP version: ", packetInfo.IPVersion)
|
||||
}
|
||||
}
|
||||
|
||||
func buildIPv4ICMPTTLExceededPacket(packetInfo ICMPPacketInfo, maxPacketLen int) ([]byte, error) {
|
||||
const headerLen = 20
|
||||
if !packetInfo.SourceIP.Is4() || !packetInfo.Destination.Is4() {
|
||||
return nil, E.New("TTL exceeded packet requires IPv4 addresses")
|
||||
}
|
||||
if maxPacketLen <= headerLen+icmpErrorHeaderLen {
|
||||
return nil, E.New("TTL exceeded packet size limit is too small")
|
||||
}
|
||||
|
||||
quotedLength := min(len(packetInfo.RawPacket), maxPacketLen-headerLen-icmpErrorHeaderLen)
|
||||
packet := make([]byte, headerLen+icmpErrorHeaderLen+quotedLength)
|
||||
packet[0] = 0x45
|
||||
binary.BigEndian.PutUint16(packet[2:4], uint16(len(packet)))
|
||||
packet[8] = defaultICMPPacketTTL
|
||||
packet[9] = 1
|
||||
copy(packet[12:16], packetInfo.Destination.AsSlice())
|
||||
copy(packet[16:20], packetInfo.SourceIP.AsSlice())
|
||||
packet[20] = icmpv4TypeTimeExceeded
|
||||
packet[21] = 0
|
||||
copy(packet[headerLen+icmpErrorHeaderLen:], packetInfo.RawPacket[:quotedLength])
|
||||
binary.BigEndian.PutUint16(packet[22:24], checksum(packet[20:], 0))
|
||||
binary.BigEndian.PutUint16(packet[10:12], checksum(packet[:headerLen], 0))
|
||||
return packet, nil
|
||||
}
|
||||
|
||||
func buildIPv6ICMPTTLExceededPacket(packetInfo ICMPPacketInfo, maxPacketLen int) ([]byte, error) {
|
||||
const headerLen = 40
|
||||
if !packetInfo.SourceIP.Is6() || !packetInfo.Destination.Is6() {
|
||||
return nil, E.New("TTL exceeded packet requires IPv6 addresses")
|
||||
}
|
||||
if maxPacketLen <= headerLen+icmpErrorHeaderLen {
|
||||
return nil, E.New("TTL exceeded packet size limit is too small")
|
||||
}
|
||||
|
||||
quotedLength := min(len(packetInfo.RawPacket), maxPacketLen-headerLen-icmpErrorHeaderLen)
|
||||
packet := make([]byte, headerLen+icmpErrorHeaderLen+quotedLength)
|
||||
packet[0] = 0x60
|
||||
binary.BigEndian.PutUint16(packet[4:6], uint16(icmpErrorHeaderLen+quotedLength))
|
||||
packet[6] = 58
|
||||
packet[7] = defaultICMPPacketTTL
|
||||
copy(packet[8:24], packetInfo.Destination.AsSlice())
|
||||
copy(packet[24:40], packetInfo.SourceIP.AsSlice())
|
||||
packet[40] = icmpv6TypeTimeExceeded
|
||||
packet[41] = 0
|
||||
copy(packet[headerLen+icmpErrorHeaderLen:], packetInfo.RawPacket[:quotedLength])
|
||||
binary.BigEndian.PutUint16(packet[42:44], checksum(packet[40:], ipv6PseudoHeaderChecksum(packetInfo.Destination, packetInfo.SourceIP, uint32(icmpErrorHeaderLen+quotedLength), 58)))
|
||||
return packet, nil
|
||||
}
|
||||
|
||||
func encodeICMPDatagram(packet []byte, wireVersion icmpWireVersion, traceContext ICMPTraceContext) ([]byte, error) {
|
||||
switch wireVersion {
|
||||
case icmpWireV2:
|
||||
return encodeV2ICMPDatagram(packet, traceContext)
|
||||
case icmpWireV3:
|
||||
return encodeV3ICMPDatagram(packet), nil
|
||||
default:
|
||||
return nil, E.New("unsupported icmp wire version: ", wireVersion)
|
||||
}
|
||||
}
|
||||
|
||||
func ipv6PseudoHeaderChecksum(source, destination netip.Addr, payloadLength uint32, nextHeader uint8) uint32 {
|
||||
var sum uint32
|
||||
sum = checksumSum(source.AsSlice(), sum)
|
||||
sum = checksumSum(destination.AsSlice(), sum)
|
||||
var lengthBytes [4]byte
|
||||
binary.BigEndian.PutUint32(lengthBytes[:], payloadLength)
|
||||
sum = checksumSum(lengthBytes[:], sum)
|
||||
sum = checksumSum([]byte{0, 0, 0, nextHeader}, sum)
|
||||
return sum
|
||||
}
|
||||
|
||||
func checksumSum(data []byte, sum uint32) uint32 {
|
||||
for len(data) >= 2 {
|
||||
sum += uint32(binary.BigEndian.Uint16(data[:2]))
|
||||
data = data[2:]
|
||||
}
|
||||
if len(data) == 1 {
|
||||
sum += uint32(data[0]) << 8
|
||||
}
|
||||
return sum
|
||||
}
|
||||
|
||||
func checksum(data []byte, initial uint32) uint16 {
|
||||
sum := checksumSum(data, initial)
|
||||
for sum > 0xffff {
|
||||
sum = (sum >> 16) + (sum & 0xffff)
|
||||
}
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
func encodeV2ICMPDatagram(packet []byte, traceContext ICMPTraceContext) ([]byte, error) {
|
||||
if traceContext.Traced {
|
||||
data := make([]byte, 0, len(packet)+len(traceContext.Identity)+1)
|
||||
|
||||
@@ -169,6 +169,158 @@ func TestICMPBridgeHandleV3Reply(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestICMPBridgeDecrementsIPv4TTLBeforeRouting(t *testing.T) {
|
||||
var destination *fakeDirectRouteDestination
|
||||
router := &testRouter{
|
||||
preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
|
||||
destination = &fakeDirectRouteDestination{routeContext: routeContext}
|
||||
return destination, nil
|
||||
},
|
||||
}
|
||||
inboundInstance := &Inbound{
|
||||
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
|
||||
router: router,
|
||||
}
|
||||
bridge := NewICMPBridge(inboundInstance, &captureDatagramSender{}, icmpWireV2)
|
||||
|
||||
packet := buildIPv4ICMPPacket(netip.MustParseAddr("198.18.0.2"), netip.MustParseAddr("1.1.1.1"), icmpv4TypeEchoRequest, 0, 1, 1)
|
||||
packet[8] = 5
|
||||
|
||||
if err := bridge.HandleV2(context.Background(), DatagramV2TypeIP, packet); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(destination.packets) != 1 {
|
||||
t.Fatalf("expected one routed packet, got %d", len(destination.packets))
|
||||
}
|
||||
if got := destination.packets[0][8]; got != 4 {
|
||||
t.Fatalf("expected decremented IPv4 TTL, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestICMPBridgeDecrementsIPv6HopLimitBeforeRouting(t *testing.T) {
|
||||
var destination *fakeDirectRouteDestination
|
||||
router := &testRouter{
|
||||
preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
|
||||
destination = &fakeDirectRouteDestination{routeContext: routeContext}
|
||||
return destination, nil
|
||||
},
|
||||
}
|
||||
inboundInstance := &Inbound{
|
||||
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
|
||||
router: router,
|
||||
}
|
||||
bridge := NewICMPBridge(inboundInstance, &captureDatagramSender{}, icmpWireV3)
|
||||
|
||||
packet := buildIPv6ICMPPacket(netip.MustParseAddr("2001:db8::2"), netip.MustParseAddr("2606:4700:4700::1111"), icmpv6TypeEchoRequest, 0, 1, 1)
|
||||
packet[7] = 3
|
||||
|
||||
if err := bridge.HandleV3(context.Background(), packet); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(destination.packets) != 1 {
|
||||
t.Fatalf("expected one routed packet, got %d", len(destination.packets))
|
||||
}
|
||||
if got := destination.packets[0][7]; got != 2 {
|
||||
t.Fatalf("expected decremented IPv6 hop limit, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestICMPBridgeHandleV2TTLExceededTracedReply(t *testing.T) {
|
||||
var preMatchCalls int
|
||||
traceIdentity := bytes.Repeat([]byte{0x6b}, icmpTraceIdentityLength)
|
||||
sender := &captureDatagramSender{}
|
||||
router := &testRouter{
|
||||
preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
|
||||
preMatchCalls++
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
inboundInstance := &Inbound{
|
||||
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
|
||||
router: router,
|
||||
}
|
||||
bridge := NewICMPBridge(inboundInstance, sender, icmpWireV2)
|
||||
|
||||
source := netip.MustParseAddr("198.18.0.2")
|
||||
target := netip.MustParseAddr("1.1.1.1")
|
||||
packet := buildIPv4ICMPPacket(source, target, icmpv4TypeEchoRequest, 0, 1, 1)
|
||||
packet[8] = 1
|
||||
packet = append(packet, traceIdentity...)
|
||||
|
||||
if err := bridge.HandleV2(context.Background(), DatagramV2TypeIPWithTrace, packet); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if preMatchCalls != 0 {
|
||||
t.Fatalf("expected TTL exceeded to bypass routing, got %d route lookups", preMatchCalls)
|
||||
}
|
||||
if len(sender.sent) != 1 {
|
||||
t.Fatalf("expected one TTL exceeded reply, got %d", len(sender.sent))
|
||||
}
|
||||
reply := sender.sent[0]
|
||||
if reply[len(reply)-1] != byte(DatagramV2TypeIPWithTrace) {
|
||||
t.Fatalf("expected traced v2 reply, got type %d", reply[len(reply)-1])
|
||||
}
|
||||
gotIdentity := reply[len(reply)-1-icmpTraceIdentityLength : len(reply)-1]
|
||||
if !bytes.Equal(gotIdentity, traceIdentity) {
|
||||
t.Fatalf("unexpected trace identity: %x", gotIdentity)
|
||||
}
|
||||
rawReply := reply[:len(reply)-1-icmpTraceIdentityLength]
|
||||
packetInfo, err := ParseICMPPacket(rawReply)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if packetInfo.ICMPType != icmpv4TypeTimeExceeded || packetInfo.ICMPCode != 0 {
|
||||
t.Fatalf("expected IPv4 time exceeded reply, got type=%d code=%d", packetInfo.ICMPType, packetInfo.ICMPCode)
|
||||
}
|
||||
if packetInfo.SourceIP != target || packetInfo.Destination != source {
|
||||
t.Fatalf("unexpected TTL exceeded routing: src=%s dst=%s", packetInfo.SourceIP, packetInfo.Destination)
|
||||
}
|
||||
}
|
||||
|
||||
func TestICMPBridgeHandleV3TTLExceededReply(t *testing.T) {
|
||||
var preMatchCalls int
|
||||
sender := &captureDatagramSender{}
|
||||
router := &testRouter{
|
||||
preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
|
||||
preMatchCalls++
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
inboundInstance := &Inbound{
|
||||
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
|
||||
router: router,
|
||||
}
|
||||
bridge := NewICMPBridge(inboundInstance, sender, icmpWireV3)
|
||||
|
||||
source := netip.MustParseAddr("2001:db8::2")
|
||||
target := netip.MustParseAddr("2606:4700:4700::1111")
|
||||
packet := buildIPv6ICMPPacket(source, target, icmpv6TypeEchoRequest, 0, 1, 1)
|
||||
packet[7] = 1
|
||||
|
||||
if err := bridge.HandleV3(context.Background(), packet); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if preMatchCalls != 0 {
|
||||
t.Fatalf("expected TTL exceeded to bypass routing, got %d route lookups", preMatchCalls)
|
||||
}
|
||||
if len(sender.sent) != 1 {
|
||||
t.Fatalf("expected one TTL exceeded reply, got %d", len(sender.sent))
|
||||
}
|
||||
if sender.sent[0][0] != byte(DatagramV3TypeICMP) {
|
||||
t.Fatalf("expected v3 ICMP reply, got %d", sender.sent[0][0])
|
||||
}
|
||||
packetInfo, err := ParseICMPPacket(sender.sent[0][1:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if packetInfo.ICMPType != icmpv6TypeTimeExceeded || packetInfo.ICMPCode != 0 {
|
||||
t.Fatalf("expected IPv6 time exceeded reply, got type=%d code=%d", packetInfo.ICMPType, packetInfo.ICMPCode)
|
||||
}
|
||||
if packetInfo.SourceIP != target || packetInfo.Destination != source {
|
||||
t.Fatalf("unexpected TTL exceeded routing: src=%s dst=%s", packetInfo.SourceIP, packetInfo.Destination)
|
||||
}
|
||||
}
|
||||
|
||||
func TestICMPBridgeDropsNonEcho(t *testing.T) {
|
||||
var preMatchCalls int
|
||||
router := &testRouter{
|
||||
|
||||
@@ -162,3 +162,49 @@ func TestResolveHTTPServiceStatus(t *testing.T) {
|
||||
t.Fatalf("status service should keep request URL, got %s", requestURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResolvedServiceCanonicalizesWebSocketOrigin(t *testing.T) {
|
||||
testCases := []struct {
|
||||
rawService string
|
||||
wantScheme string
|
||||
}{
|
||||
{rawService: "ws://127.0.0.1:8080", wantScheme: "http"},
|
||||
{rawService: "wss://127.0.0.1:8443", wantScheme: "https"},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.rawService, func(t *testing.T) {
|
||||
service, err := parseResolvedService(testCase.rawService, defaultOriginRequestConfig())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if service.BaseURL == nil {
|
||||
t.Fatal("expected base URL")
|
||||
}
|
||||
if service.BaseURL.Scheme != testCase.wantScheme {
|
||||
t.Fatalf("expected scheme %q, got %q", testCase.wantScheme, service.BaseURL.Scheme)
|
||||
}
|
||||
if service.Service != testCase.rawService {
|
||||
t.Fatalf("expected raw service to stay %q, got %q", testCase.rawService, service.Service)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveHTTPServiceWebSocketOrigin(t *testing.T) {
|
||||
inboundInstance := newTestIngressInbound(t)
|
||||
inboundInstance.configManager.activeConfig = RuntimeConfig{
|
||||
Ingress: []compiledIngressRule{
|
||||
{Hostname: "foo.com", Service: mustResolvedService(t, "ws://127.0.0.1:8083")},
|
||||
{Service: mustResolvedService(t, "http_status:404")},
|
||||
},
|
||||
}
|
||||
|
||||
_, requestURL, err := inboundInstance.resolveHTTPService("https://foo.com/path?q=1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if requestURL != "http://127.0.0.1:8083/path?q=1" {
|
||||
t.Fatalf("expected websocket origin to be canonicalized, got %s", requestURL)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,6 +74,20 @@ func (s ResolvedService) BuildRequestURL(requestURL string) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func canonicalizeHTTPOriginURL(parsedURL *url.URL) *url.URL {
|
||||
if parsedURL == nil {
|
||||
return nil
|
||||
}
|
||||
canonicalURL := *parsedURL
|
||||
switch canonicalURL.Scheme {
|
||||
case "ws":
|
||||
canonicalURL.Scheme = "http"
|
||||
case "wss":
|
||||
canonicalURL.Scheme = "https"
|
||||
}
|
||||
return &canonicalURL
|
||||
}
|
||||
|
||||
type compiledIngressRule struct {
|
||||
Hostname string
|
||||
PunycodeHostname string
|
||||
@@ -451,7 +465,7 @@ func parseResolvedService(rawService string, originRequest OriginRequestConfig)
|
||||
Kind: ResolvedServiceHTTP,
|
||||
Service: rawService,
|
||||
Destination: parseServiceDestination(parsedURL),
|
||||
BaseURL: parsedURL,
|
||||
BaseURL: canonicalizeHTTPOriginURL(parsedURL),
|
||||
OriginRequest: originRequest,
|
||||
}, nil
|
||||
case "tcp", "ssh", "rdp", "smb":
|
||||
|
||||
Reference in New Issue
Block a user