diff --git a/option/cloudflared.go b/option/cloudflared.go index e597ebb77..7daaafb9f 100644 --- a/option/cloudflared.go +++ b/option/cloudflared.go @@ -7,6 +7,7 @@ type CloudflaredInboundOptions struct { HAConnections int `json:"ha_connections,omitempty"` Protocol string `json:"protocol,omitempty"` ControlDialer DialerOptions `json:"control_dialer,omitempty"` + TunnelDialer DialerOptions `json:"tunnel_dialer,omitempty"` EdgeIPVersion int `json:"edge_ip_version,omitempty"` DatagramVersion string `json:"datagram_version,omitempty"` GracePeriod badoption.Duration `json:"grace_period,omitempty"` diff --git a/protocol/cloudflare/connection_http2.go b/protocol/cloudflare/connection_http2.go index 56ac895e5..0b21cab06 100644 --- a/protocol/cloudflare/connection_http2.go +++ b/protocol/cloudflare/connection_http2.go @@ -78,7 +78,7 @@ func NewHTTP2Connection( ServerName: h2EdgeSNI, } - tcpConn, err := inbound.controlDialer.DialContext(ctx, "tcp", M.SocksaddrFrom(edgeAddr.TCP.AddrPort().Addr(), edgeAddr.TCP.AddrPort().Port())) + tcpConn, err := inbound.tunnelDialer.DialContext(ctx, "tcp", M.SocksaddrFrom(edgeAddr.TCP.AddrPort().Addr(), edgeAddr.TCP.AddrPort().Port())) if err != nil { return nil, E.Cause(err, "dial edge TCP") } diff --git a/protocol/cloudflare/connection_quic.go b/protocol/cloudflare/connection_quic.go index e83bf8298..32f872a56 100644 --- a/protocol/cloudflare/connection_quic.go +++ b/protocol/cloudflare/connection_quic.go @@ -80,10 +80,6 @@ func (c *closeableQUICConn) CloseWithError(code quic.ApplicationErrorCode, reaso return err } -var ( - quicPortByConnIndex = make(map[uint8]int) - quicPortAccess sync.Mutex -) // NewQUICConnection dials the edge and establishes a QUIC connection. func NewQUICConnection( @@ -96,7 +92,7 @@ func NewQUICConnection( features []string, numPreviousAttempts uint8, gracePeriod time.Duration, - controlDialer N.Dialer, + tunnelDialer N.Dialer, onConnected func(), logger log.ContextLogger, ) (*QUICConnection, error) { @@ -121,7 +117,7 @@ func NewQUICConnection( InitialPacketSize: quicInitialPacketSize(edgeAddr.IPVersion), } - udpConn, err := createUDPConnForConnIndex(ctx, connIndex, edgeAddr, controlDialer) + udpConn, err := createUDPConnForConnIndex(ctx, edgeAddr, tunnelDialer) if err != nil { return nil, E.Cause(err, "listen UDP for QUIC edge") } @@ -147,11 +143,15 @@ func NewQUICConnection( }, nil } -func createUDPConnForConnIndex(ctx context.Context, connIndex uint8, edgeAddr *EdgeAddr, controlDialer N.Dialer) (*net.UDPConn, error) { - quicPortAccess.Lock() - defer quicPortAccess.Unlock() - - packetConn, err := controlDialer.ListenPacket(ctx, M.SocksaddrFrom(edgeAddr.UDP.AddrPort().Addr(), edgeAddr.UDP.AddrPort().Port())) +// createUDPConnForConnIndex creates a UDP socket for QUIC via the tunnel dialer. +// Unlike cloudflared, we do not attempt to reuse previously-bound ports across +// reconnects — the dialer interface does not support specifying local ports, +// and fixed port binding is not important for our use case. +// We also do not apply Darwin-specific udp4/udp6 network selection to work around +// quic-go#3793 (DF bit on macOS dual-stack); the dialer controls network selection +// and this is a non-critical platform-specific limitation. +func createUDPConnForConnIndex(ctx context.Context, edgeAddr *EdgeAddr, tunnelDialer N.Dialer) (*net.UDPConn, error) { + packetConn, err := tunnelDialer.ListenPacket(ctx, M.SocksaddrFrom(edgeAddr.UDP.AddrPort().Addr(), edgeAddr.UDP.AddrPort().Port())) if err != nil { return nil, err } @@ -160,12 +160,6 @@ func createUDPConnForConnIndex(ctx context.Context, connIndex uint8, edgeAddr *E packetConn.Close() return nil, fmt.Errorf("unexpected packet conn type %T", packetConn) } - udpAddr, ok := udpConn.LocalAddr().(*net.UDPAddr) - if !ok { - udpConn.Close() - return nil, fmt.Errorf("unexpected local UDP address type %T", udpConn.LocalAddr()) - } - quicPortByConnIndex[connIndex] = udpAddr.Port return udpConn, nil } @@ -368,9 +362,11 @@ type DatagramSender interface { SendDatagram(data []byte) error } -// streamReadWriteCloser adapts a *quic.Stream to io.ReadWriteCloser. +// streamReadWriteCloser adapts a *quic.Stream to io.ReadWriteCloser +// with mutex-protected writes and safe close semantics. type streamReadWriteCloser struct { - stream *quic.Stream + stream *quic.Stream + writeAccess sync.Mutex } func newStreamReadWriteCloser(stream *quic.Stream) *streamReadWriteCloser { @@ -382,10 +378,15 @@ func (s *streamReadWriteCloser) Read(p []byte) (int, error) { } func (s *streamReadWriteCloser) Write(p []byte) (int, error) { + s.writeAccess.Lock() + defer s.writeAccess.Unlock() return s.stream.Write(p) } func (s *streamReadWriteCloser) Close() error { + _ = s.stream.SetWriteDeadline(time.Now()) + s.writeAccess.Lock() + defer s.writeAccess.Unlock() s.stream.CancelRead(0) return s.stream.Close() } diff --git a/protocol/cloudflare/control.go b/protocol/cloudflare/control.go index e6a0b070f..194a72259 100644 --- a/protocol/cloudflare/control.go +++ b/protocol/cloudflare/control.go @@ -46,8 +46,8 @@ type registrationRPCClient interface { // NewRegistrationClient creates a Cap'n Proto RPC client over the given stream. // The stream should be the first QUIC stream (control stream). func NewRegistrationClient(ctx context.Context, stream io.ReadWriteCloser) *RegistrationClient { - transport := rpc.StreamTransport(stream) - conn := rpc.NewConn(transport) + transport := safeTransport(stream) + conn := newRPCClientConn(transport, ctx) return &RegistrationClient{ client: tunnelrpc.TunnelServer{Client: conn.Bootstrap(ctx)}, rpcConn: conn, diff --git a/protocol/cloudflare/datagram_rpc_v3.go b/protocol/cloudflare/datagram_rpc_v3.go index 38af323ff..6c40db882 100644 --- a/protocol/cloudflare/datagram_rpc_v3.go +++ b/protocol/cloudflare/datagram_rpc_v3.go @@ -10,8 +10,6 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc" E "github.com/sagernet/sing/common/exceptions" - - "zombiezen.com/go/capnproto2/rpc" ) var ( @@ -63,8 +61,8 @@ func ServeV3RPCStream(ctx context.Context, stream io.ReadWriteCloser, inbound *I logger: logger, } client := tunnelrpc.CloudflaredServer_ServerToClient(srv) - transport := rpc.StreamTransport(stream) - rpcConn := rpc.NewConn(transport, rpc.MainInterface(client.Client)) + transport := safeTransport(stream) + rpcConn := newRPCServerConn(transport, client.Client) <-rpcConn.Done() E.Errors( rpcConn.Close(), diff --git a/protocol/cloudflare/datagram_v2.go b/protocol/cloudflare/datagram_v2.go index 8fa3ffa62..d7454cab1 100644 --- a/protocol/cloudflare/datagram_v2.go +++ b/protocol/cloudflare/datagram_v2.go @@ -76,8 +76,8 @@ var newV2SessionRPCClient = func(ctx context.Context, sender DatagramSender) (v2 if err != nil { return nil, err } - transport := rpc.StreamTransport(stream) - conn := rpc.NewConn(transport) + transport := safeTransport(stream) + conn := newRPCClientConn(transport, ctx) return &capnpV2SessionRPCClient{ client: tunnelrpc.SessionManager{Client: conn.Bootstrap(ctx)}, rpcConn: conn, @@ -533,8 +533,8 @@ func ServeRPCStream(ctx context.Context, stream io.ReadWriteCloser, inbound *Inb logger: logger, } client := tunnelrpc.CloudflaredServer_ServerToClient(srv) - transport := rpc.StreamTransport(stream) - rpcConn := rpc.NewConn(transport, rpc.MainInterface(client.Client)) + transport := safeTransport(stream) + rpcConn := newRPCServerConn(transport, client.Client) <-rpcConn.Done() E.Errors( rpcConn.Close(), diff --git a/protocol/cloudflare/inbound.go b/protocol/cloudflare/inbound.go index f405d3636..674abef16 100644 --- a/protocol/cloudflare/inbound.go +++ b/protocol/cloudflare/inbound.go @@ -49,6 +49,7 @@ type Inbound struct { flowLimiter *FlowLimiter accessCache *accessValidatorCache controlDialer N.Dialer + tunnelDialer N.Dialer connectionAccess sync.Mutex connections []io.Closer @@ -110,6 +111,13 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo if err != nil { return nil, E.Cause(err, "build cloudflared control dialer") } + tunnelDialer, err := boxDialer.NewWithOptions(boxDialer.Options{ + Context: ctx, + Options: options.TunnelDialer, + }) + if err != nil { + return nil, E.Cause(err, "build cloudflared tunnel dialer") + } region := options.Region if region != "" && credentials.Endpoint != "" { @@ -140,6 +148,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo flowLimiter: &FlowLimiter{}, accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: controlDialer}, controlDialer: controlDialer, + tunnelDialer: tunnelDialer, datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer), datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer), datagramV3Manager: NewDatagramV3SessionManager(), @@ -310,7 +319,7 @@ func (i *Inbound) serveQUIC(connIndex uint8, edgeAddr *EdgeAddr, datagramVersion connection, err := NewQUICConnection( i.ctx, edgeAddr, connIndex, i.credentials, i.connectorID, datagramVersion, - features, numPreviousAttempts, i.gracePeriod, i.controlDialer, func() { + features, numPreviousAttempts, i.gracePeriod, i.tunnelDialer, func() { i.notifyConnected(connIndex) }, i.logger, ) diff --git a/protocol/cloudflare/safe_transport.go b/protocol/cloudflare/safe_transport.go new file mode 100644 index 000000000..99b7880d7 --- /dev/null +++ b/protocol/cloudflare/safe_transport.go @@ -0,0 +1,63 @@ +//go:build with_cloudflared + +package cloudflare + +import ( + "context" + "io" + "time" + + E "github.com/sagernet/sing/common/exceptions" + + capnp "zombiezen.com/go/capnproto2" + "zombiezen.com/go/capnproto2/rpc" +) + +const ( + safeTransportMaxRetries = 3 + safeTransportRetryInterval = 500 * time.Millisecond +) + +type safeReadWriteCloser struct { + io.ReadWriteCloser + retries int +} + +func (s *safeReadWriteCloser) Read(p []byte) (int, error) { + n, err := s.ReadWriteCloser.Read(p) + if n == 0 && err != nil && isTemporaryError(err) { + if s.retries >= safeTransportMaxRetries { + return 0, E.Cause(err, "read capnproto transport after multiple temporary errors") + } + s.retries++ + time.Sleep(safeTransportRetryInterval) + return n, err + } + if err == nil { + s.retries = 0 + } + return n, err +} + +func isTemporaryError(err error) bool { + type temporary interface{ Temporary() bool } + t, ok := err.(temporary) + return ok && t.Temporary() +} + +func safeTransport(stream io.ReadWriteCloser) rpc.Transport { + return rpc.StreamTransport(&safeReadWriteCloser{ReadWriteCloser: stream}) +} + +type noopCapnpLogger struct{} + +func (noopCapnpLogger) Infof(ctx context.Context, format string, args ...interface{}) {} +func (noopCapnpLogger) Errorf(ctx context.Context, format string, args ...interface{}) {} + +func newRPCClientConn(transport rpc.Transport, ctx context.Context) *rpc.Conn { + return rpc.NewConn(transport, rpc.ConnLog(noopCapnpLogger{})) +} + +func newRPCServerConn(transport rpc.Transport, client capnp.Client) *rpc.Conn { + return rpc.NewConn(transport, rpc.MainInterface(client), rpc.ConnLog(noopCapnpLogger{})) +}