diff --git a/protocol/cloudflare/dispatch.go b/protocol/cloudflare/dispatch.go index 746e83bb3..5cedfe2a2 100644 --- a/protocol/cloudflare/dispatch.go +++ b/protocol/cloudflare/dispatch.go @@ -23,7 +23,6 @@ import ( E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/pipe" ) const ( @@ -196,14 +195,22 @@ func (i *Inbound) handleTCPStream(ctx context.Context, stream io.ReadWriteCloser } defer i.flowLimiter.Release(limit) - targetConn, err := i.dialWarpTCP(ctx, metadata.Destination) + warpRouting := i.configManager.Snapshot().WarpRouting + targetConn, cleanup, err := i.dialRouterTCPWithMetadata(ctx, metadata, routedPipeTCPOptions{ + timeout: warpRouting.ConnectTimeout, + onHandshake: func(conn net.Conn) { + _ = applyTCPKeepAlive(conn, warpRouting.TCPKeepAlive) + }, + }) if err != nil { i.logger.ErrorContext(ctx, "dial tcp origin: ", err) respWriter.WriteResponse(err, nil) return } - defer targetConn.Close() + defer cleanup() + // Cloudflare expects an optimistic ACK here so the routed TCP path can sniff + // the real input stream before the outbound connection is fully established. err = respWriter.WriteResponse(nil, nil) if err != nil { i.logger.ErrorContext(ctx, "write connect response: ", err) @@ -391,12 +398,7 @@ func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser, } func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter.InboundContext, originRequest OriginRequestConfig, requestHost string) (*http.Transport, func()) { - input, output := pipe.Pipe() - done := make(chan struct{}) - go i.router.RouteConnectionEx(ctx, output, metadata, N.OnceClose(func(it error) { - common.Close(input, output) - close(done) - })) + input, cleanup, _ := i.dialRouterTCPWithMetadata(ctx, metadata, routedPipeTCPOptions{}) transport := &http.Transport{ DisableCompression: true, @@ -411,13 +413,7 @@ func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter }, } applyHTTPTransportProxy(transport, originRequest) - return transport, func() { - common.Close(input, output) - select { - case <-done: - case <-time.After(time.Second): - } - } + return transport, cleanup } func (i *Inbound) newDirectOriginTransport(service ResolvedService, requestHost string) (*http.Transport, func(), error) { diff --git a/protocol/cloudflare/helpers_test.go b/protocol/cloudflare/helpers_test.go index 81f829daf..253eed5cf 100644 --- a/protocol/cloudflare/helpers_test.go +++ b/protocol/cloudflare/helpers_test.go @@ -138,10 +138,6 @@ func (r *testRouter) RoutePacketConnectionEx(ctx context.Context, conn N.PacketC onClose(nil) } -func (r *testRouter) DialRouteConnection(ctx context.Context, metadata adapter.InboundContext) (net.Conn, error) { - return net.Dial("tcp", metadata.Destination.String()) -} - func (r *testRouter) DialRoutePacketConnection(ctx context.Context, metadata adapter.InboundContext) (N.PacketConn, error) { conn, err := net.Dial("udp", metadata.Destination.String()) if err != nil { diff --git a/protocol/cloudflare/origin_dial.go b/protocol/cloudflare/origin_dial.go index 5c6f80b7a..c937aa35b 100644 --- a/protocol/cloudflare/origin_dial.go +++ b/protocol/cloudflare/origin_dial.go @@ -4,9 +4,7 @@ package cloudflare import ( "context" - "net" "net/netip" - "time" "github.com/sagernet/sing-box/adapter" E "github.com/sagernet/sing/common/exceptions" @@ -14,39 +12,12 @@ import ( N "github.com/sagernet/sing/common/network" ) -type routedOriginDialer interface { - DialRouteConnection(ctx context.Context, metadata adapter.InboundContext) (net.Conn, error) +type routedOriginPacketDialer interface { DialRoutePacketConnection(ctx context.Context, metadata adapter.InboundContext) (N.PacketConn, error) } -func (i *Inbound) dialWarpTCP(ctx context.Context, destination M.Socksaddr) (net.Conn, error) { - originDialer, ok := i.router.(routedOriginDialer) - if !ok { - return nil, E.New("router does not support cloudflare routed dialing") - } - - warpRouting := i.configManager.Snapshot().WarpRouting - if warpRouting.ConnectTimeout > 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, warpRouting.ConnectTimeout) - defer cancel() - } - - conn, err := originDialer.DialRouteConnection(ctx, adapter.InboundContext{ - Inbound: i.Tag(), - InboundType: i.Type(), - Network: N.NetworkTCP, - Destination: destination, - }) - if err != nil { - return nil, err - } - _ = applyTCPKeepAlive(conn, warpRouting.TCPKeepAlive) - return conn, nil -} - func (i *Inbound) dialWarpPacketConnection(ctx context.Context, destination netip.AddrPort) (N.PacketConn, error) { - originDialer, ok := i.router.(routedOriginDialer) + originDialer, ok := i.router.(routedOriginPacketDialer) if !ok { return nil, E.New("router does not support cloudflare routed packet dialing") } @@ -66,21 +37,3 @@ func (i *Inbound) dialWarpPacketConnection(ctx context.Context, destination neti UDPConnect: true, }) } - -func applyTCPKeepAlive(conn net.Conn, keepAlive time.Duration) error { - if keepAlive <= 0 { - return nil - } - type keepAliveConn interface { - SetKeepAlive(bool) error - SetKeepAlivePeriod(time.Duration) error - } - tcpConn, ok := conn.(keepAliveConn) - if !ok { - return nil - } - if err := tcpConn.SetKeepAlive(true); err != nil { - return err - } - return tcpConn.SetKeepAlivePeriod(keepAlive) -} diff --git a/protocol/cloudflare/router_pipe.go b/protocol/cloudflare/router_pipe.go new file mode 100644 index 000000000..9431fa228 --- /dev/null +++ b/protocol/cloudflare/router_pipe.go @@ -0,0 +1,90 @@ +//go:build with_cloudflared + +package cloudflare + +import ( + "context" + "net" + "sync" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing/common" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/pipe" +) + +type routedPipeTCPOptions struct { + timeout time.Duration + onHandshake func(net.Conn) +} + +type routedPipeTCPConn struct { + net.Conn + handshakeOnce sync.Once + onHandshake func(net.Conn) +} + +func (c *routedPipeTCPConn) ConnHandshakeSuccess(conn net.Conn) error { + if c.onHandshake != nil { + c.handshakeOnce.Do(func() { + c.onHandshake(conn) + }) + } + return nil +} + +func (i *Inbound) dialRouterTCPWithMetadata(ctx context.Context, metadata adapter.InboundContext, options routedPipeTCPOptions) (net.Conn, func(), error) { + input, output := pipe.Pipe() + routerConn := &routedPipeTCPConn{ + Conn: output, + onHandshake: options.onHandshake, + } + done := make(chan struct{}) + + routeCtx := ctx + var cancel context.CancelFunc + if options.timeout > 0 { + routeCtx, cancel = context.WithTimeout(ctx, options.timeout) + } + + var closeOnce sync.Once + closePipe := func() { + closeOnce.Do(func() { + if cancel != nil { + cancel() + } + common.Close(input, routerConn) + }) + } + go i.router.RouteConnectionEx(routeCtx, routerConn, metadata, N.OnceClose(func(it error) { + closePipe() + close(done) + })) + + return input, func() { + closePipe() + select { + case <-done: + case <-time.After(time.Second): + } + }, nil +} + +func applyTCPKeepAlive(conn net.Conn, keepAlive time.Duration) error { + if keepAlive <= 0 { + return nil + } + type keepAliveConn interface { + SetKeepAlive(bool) error + SetKeepAlivePeriod(time.Duration) error + } + tcpConn, ok := conn.(keepAliveConn) + if !ok { + return nil + } + if err := tcpConn.SetKeepAlive(true); err != nil { + return err + } + return tcpConn.SetKeepAlivePeriod(keepAlive) +} diff --git a/protocol/cloudflare/router_pipe_test.go b/protocol/cloudflare/router_pipe_test.go new file mode 100644 index 000000000..779ade53d --- /dev/null +++ b/protocol/cloudflare/router_pipe_test.go @@ -0,0 +1,165 @@ +//go:build with_cloudflared + +package cloudflare + +import ( + "context" + "io" + "net" + "testing" + "time" + + "github.com/sagernet/sing-box/adapter" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +func TestHandleTCPStreamUsesRouteConnectionEx(t *testing.T) { + listener := startEchoListener(t) + defer listener.Close() + + router := &countingRouter{} + inboundInstance := newSpecialServiceInboundWithRouter(t, router) + + serverSide, clientSide := net.Pipe() + defer clientSide.Close() + + respWriter := &fakeConnectResponseWriter{done: make(chan struct{})} + responseDone := respWriter.done + finished := make(chan struct{}) + go func() { + inboundInstance.handleTCPStream(context.Background(), serverSide, respWriter, adapter.InboundContext{ + Destination: M.ParseSocksaddr(listener.Addr().String()), + }) + close(finished) + }() + + select { + case <-responseDone: + case <-time.After(time.Second): + t.Fatal("timed out waiting for connect response") + } + if respWriter.err != nil { + t.Fatal("unexpected response error: ", respWriter.err) + } + + if err := clientSide.SetDeadline(time.Now().Add(time.Second)); err != nil { + t.Fatal(err) + } + payload := []byte("ping") + if _, err := clientSide.Write(payload); err != nil { + t.Fatal(err) + } + response := make([]byte, len(payload)) + if _, err := io.ReadFull(clientSide, response); err != nil { + t.Fatal(err) + } + if string(response) != string(payload) { + t.Fatalf("unexpected echo payload: %q", string(response)) + } + if router.count.Load() != 1 { + t.Fatalf("expected RouteConnectionEx to be used once, got %d", router.count.Load()) + } + + _ = clientSide.Close() + select { + case <-finished: + case <-time.After(time.Second): + t.Fatal("timed out waiting for TCP stream handler to exit") + } +} + +func TestHandleTCPStreamWritesOptimisticAck(t *testing.T) { + router := &blockingRouteRouter{ + started: make(chan struct{}), + release: make(chan struct{}), + } + inboundInstance := newSpecialServiceInboundWithRouter(t, router) + + serverSide, clientSide := net.Pipe() + defer clientSide.Close() + + respWriter := &fakeConnectResponseWriter{done: make(chan struct{})} + responseDone := respWriter.done + finished := make(chan struct{}) + go func() { + inboundInstance.handleTCPStream(context.Background(), serverSide, respWriter, adapter.InboundContext{ + Destination: M.ParseSocksaddr("127.0.0.1:443"), + }) + close(finished) + }() + + select { + case <-router.started: + case <-time.After(time.Second): + t.Fatal("timed out waiting for router goroutine to start") + } + select { + case <-responseDone: + case <-time.After(time.Second): + t.Fatal("timed out waiting for optimistic connect response") + } + if respWriter.err != nil { + t.Fatal("unexpected response error: ", respWriter.err) + } + + close(router.release) + _ = clientSide.Close() + select { + case <-finished: + case <-time.After(time.Second): + t.Fatal("timed out waiting for TCP stream handler to exit") + } +} + +func TestRoutedPipeTCPConnHandshakeAppliesKeepAlive(t *testing.T) { + left, right := net.Pipe() + defer left.Close() + defer right.Close() + + remoteConn := &keepAliveTestConn{Conn: right} + routerConn := &routedPipeTCPConn{ + Conn: left, + onHandshake: func(conn net.Conn) { + _ = applyTCPKeepAlive(conn, 15*time.Second) + }, + } + if err := routerConn.ConnHandshakeSuccess(remoteConn); err != nil { + t.Fatal(err) + } + if !remoteConn.enabled { + t.Fatal("expected keepalive to be enabled") + } + if remoteConn.period != 15*time.Second { + t.Fatalf("unexpected keepalive period: %s", remoteConn.period) + } +} + +type blockingRouteRouter struct { + testRouter + started chan struct{} + release chan struct{} +} + +func (r *blockingRouteRouter) RouteConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { + close(r.started) + <-r.release + _ = conn.Close() + onClose(nil) +} + +type keepAliveTestConn struct { + net.Conn + enabled bool + period time.Duration +} + +func (c *keepAliveTestConn) SetKeepAlive(enabled bool) error { + c.enabled = enabled + return nil +} + +func (c *keepAliveTestConn) SetKeepAlivePeriod(period time.Duration) error { + c.period = period + return nil +} diff --git a/protocol/cloudflare/special_service.go b/protocol/cloudflare/special_service.go index e60edfa0d..c5b5e0f4d 100644 --- a/protocol/cloudflare/special_service.go +++ b/protocol/cloudflare/special_service.go @@ -13,16 +13,13 @@ import ( "net/url" "strconv" "strings" - "time" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/transport/v2raywebsocket" - "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/pipe" "github.com/sagernet/ws" ) @@ -118,25 +115,13 @@ func requestHeaderValue(request *ConnectRequest, headerName string) string { } func (i *Inbound) dialRouterTCP(ctx context.Context, destination M.Socksaddr) (net.Conn, func(), error) { - input, output := pipe.Pipe() - done := make(chan struct{}) metadata := adapter.InboundContext{ Inbound: i.Tag(), InboundType: i.Type(), Network: N.NetworkTCP, Destination: destination, } - go i.router.RouteConnectionEx(ctx, output, metadata, N.OnceClose(func(it error) { - common.Close(input, output) - close(done) - })) - return input, func() { - common.Close(input, output) - select { - case <-done: - case <-time.After(time.Second): - } - }, nil + return i.dialRouterTCPWithMetadata(ctx, metadata, routedPipeTCPOptions{}) } func (i *Inbound) serveSocksProxy(ctx context.Context, conn net.Conn, policy *ipRulePolicy) error { diff --git a/protocol/cloudflare/special_service_test.go b/protocol/cloudflare/special_service_test.go index 35d8245fc..2df966afa 100644 --- a/protocol/cloudflare/special_service_test.go +++ b/protocol/cloudflare/special_service_test.go @@ -68,6 +68,7 @@ func newSpecialServiceInboundWithRouter(t *testing.T, router adapter.Router) *In router: router, logger: logFactory.NewLogger("test"), configManager: configManager, + flowLimiter: &FlowLimiter{}, } } diff --git a/route/dial.go b/route/dial.go index 013a6350a..48187debe 100644 --- a/route/dial.go +++ b/route/dial.go @@ -6,7 +6,6 @@ import ( "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/dialer" - tf "github.com/sagernet/sing-box/common/tlsfragment" C "github.com/sagernet/sing-box/constant" R "github.com/sagernet/sing-box/route/rule" "github.com/sagernet/sing/common" @@ -16,45 +15,6 @@ import ( N "github.com/sagernet/sing/common/network" ) -// DialRouteConnection dials a routed TCP connection for metadata without requiring an upstream accepted socket. -func (r *Router) DialRouteConnection(ctx context.Context, metadata adapter.InboundContext) (net.Conn, error) { - metadata.Network = N.NetworkTCP - ctx = adapter.WithContext(ctx, &metadata) - - selectedRule, selectedOutbound, err := r.selectRoutedOutbound(ctx, &metadata, N.NetworkTCP) - if err != nil { - return nil, err - } - - var conn net.Conn - if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() { - conn, err = dialer.DialSerialNetwork( - ctx, - selectedOutbound, - N.NetworkTCP, - metadata.Destination, - metadata.DestinationAddresses, - metadata.NetworkStrategy, - metadata.NetworkType, - metadata.FallbackNetworkType, - metadata.FallbackDelay, - ) - } else { - conn, err = selectedOutbound.DialContext(ctx, N.NetworkTCP, metadata.Destination) - } - if err != nil { - return nil, err - } - - if metadata.TLSFragment || metadata.TLSRecordFragment { - conn = tf.NewConn(conn, ctx, metadata.TLSFragment, metadata.TLSRecordFragment, metadata.TLSFragmentFallbackDelay) - } - for _, tracker := range r.trackers { - conn = tracker.RoutedConnection(ctx, conn, metadata, selectedRule, selectedOutbound) - } - return conn, nil -} - // DialRoutePacketConnection dials a routed connected UDP packet connection for metadata. func (r *Router) DialRoutePacketConnection(ctx context.Context, metadata adapter.InboundContext) (N.PacketConn, error) { metadata.Network = N.NetworkUDP