mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-14 12:48:28 +10:00
Route cloudflare TCP through pipe
This commit is contained in:
@@ -23,7 +23,6 @@ import (
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/pipe"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -196,14 +195,22 @@ func (i *Inbound) handleTCPStream(ctx context.Context, stream io.ReadWriteCloser
|
||||
}
|
||||
defer i.flowLimiter.Release(limit)
|
||||
|
||||
targetConn, err := i.dialWarpTCP(ctx, metadata.Destination)
|
||||
warpRouting := i.configManager.Snapshot().WarpRouting
|
||||
targetConn, cleanup, err := i.dialRouterTCPWithMetadata(ctx, metadata, routedPipeTCPOptions{
|
||||
timeout: warpRouting.ConnectTimeout,
|
||||
onHandshake: func(conn net.Conn) {
|
||||
_ = applyTCPKeepAlive(conn, warpRouting.TCPKeepAlive)
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "dial tcp origin: ", err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
defer targetConn.Close()
|
||||
defer cleanup()
|
||||
|
||||
// Cloudflare expects an optimistic ACK here so the routed TCP path can sniff
|
||||
// the real input stream before the outbound connection is fully established.
|
||||
err = respWriter.WriteResponse(nil, nil)
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "write connect response: ", err)
|
||||
@@ -391,12 +398,7 @@ func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser,
|
||||
}
|
||||
|
||||
func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter.InboundContext, originRequest OriginRequestConfig, requestHost string) (*http.Transport, func()) {
|
||||
input, output := pipe.Pipe()
|
||||
done := make(chan struct{})
|
||||
go i.router.RouteConnectionEx(ctx, output, metadata, N.OnceClose(func(it error) {
|
||||
common.Close(input, output)
|
||||
close(done)
|
||||
}))
|
||||
input, cleanup, _ := i.dialRouterTCPWithMetadata(ctx, metadata, routedPipeTCPOptions{})
|
||||
|
||||
transport := &http.Transport{
|
||||
DisableCompression: true,
|
||||
@@ -411,13 +413,7 @@ func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter
|
||||
},
|
||||
}
|
||||
applyHTTPTransportProxy(transport, originRequest)
|
||||
return transport, func() {
|
||||
common.Close(input, output)
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
}
|
||||
return transport, cleanup
|
||||
}
|
||||
|
||||
func (i *Inbound) newDirectOriginTransport(service ResolvedService, requestHost string) (*http.Transport, func(), error) {
|
||||
|
||||
@@ -138,10 +138,6 @@ func (r *testRouter) RoutePacketConnectionEx(ctx context.Context, conn N.PacketC
|
||||
onClose(nil)
|
||||
}
|
||||
|
||||
func (r *testRouter) DialRouteConnection(ctx context.Context, metadata adapter.InboundContext) (net.Conn, error) {
|
||||
return net.Dial("tcp", metadata.Destination.String())
|
||||
}
|
||||
|
||||
func (r *testRouter) DialRoutePacketConnection(ctx context.Context, metadata adapter.InboundContext) (N.PacketConn, error) {
|
||||
conn, err := net.Dial("udp", metadata.Destination.String())
|
||||
if err != nil {
|
||||
|
||||
@@ -4,9 +4,7 @@ package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
@@ -14,39 +12,12 @@ import (
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type routedOriginDialer interface {
|
||||
DialRouteConnection(ctx context.Context, metadata adapter.InboundContext) (net.Conn, error)
|
||||
type routedOriginPacketDialer interface {
|
||||
DialRoutePacketConnection(ctx context.Context, metadata adapter.InboundContext) (N.PacketConn, error)
|
||||
}
|
||||
|
||||
func (i *Inbound) dialWarpTCP(ctx context.Context, destination M.Socksaddr) (net.Conn, error) {
|
||||
originDialer, ok := i.router.(routedOriginDialer)
|
||||
if !ok {
|
||||
return nil, E.New("router does not support cloudflare routed dialing")
|
||||
}
|
||||
|
||||
warpRouting := i.configManager.Snapshot().WarpRouting
|
||||
if warpRouting.ConnectTimeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, warpRouting.ConnectTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
conn, err := originDialer.DialRouteConnection(ctx, adapter.InboundContext{
|
||||
Inbound: i.Tag(),
|
||||
InboundType: i.Type(),
|
||||
Network: N.NetworkTCP,
|
||||
Destination: destination,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_ = applyTCPKeepAlive(conn, warpRouting.TCPKeepAlive)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (i *Inbound) dialWarpPacketConnection(ctx context.Context, destination netip.AddrPort) (N.PacketConn, error) {
|
||||
originDialer, ok := i.router.(routedOriginDialer)
|
||||
originDialer, ok := i.router.(routedOriginPacketDialer)
|
||||
if !ok {
|
||||
return nil, E.New("router does not support cloudflare routed packet dialing")
|
||||
}
|
||||
@@ -66,21 +37,3 @@ func (i *Inbound) dialWarpPacketConnection(ctx context.Context, destination neti
|
||||
UDPConnect: true,
|
||||
})
|
||||
}
|
||||
|
||||
func applyTCPKeepAlive(conn net.Conn, keepAlive time.Duration) error {
|
||||
if keepAlive <= 0 {
|
||||
return nil
|
||||
}
|
||||
type keepAliveConn interface {
|
||||
SetKeepAlive(bool) error
|
||||
SetKeepAlivePeriod(time.Duration) error
|
||||
}
|
||||
tcpConn, ok := conn.(keepAliveConn)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if err := tcpConn.SetKeepAlive(true); err != nil {
|
||||
return err
|
||||
}
|
||||
return tcpConn.SetKeepAlivePeriod(keepAlive)
|
||||
}
|
||||
|
||||
90
protocol/cloudflare/router_pipe.go
Normal file
90
protocol/cloudflare/router_pipe.go
Normal file
@@ -0,0 +1,90 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing/common"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/pipe"
|
||||
)
|
||||
|
||||
type routedPipeTCPOptions struct {
|
||||
timeout time.Duration
|
||||
onHandshake func(net.Conn)
|
||||
}
|
||||
|
||||
type routedPipeTCPConn struct {
|
||||
net.Conn
|
||||
handshakeOnce sync.Once
|
||||
onHandshake func(net.Conn)
|
||||
}
|
||||
|
||||
func (c *routedPipeTCPConn) ConnHandshakeSuccess(conn net.Conn) error {
|
||||
if c.onHandshake != nil {
|
||||
c.handshakeOnce.Do(func() {
|
||||
c.onHandshake(conn)
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *Inbound) dialRouterTCPWithMetadata(ctx context.Context, metadata adapter.InboundContext, options routedPipeTCPOptions) (net.Conn, func(), error) {
|
||||
input, output := pipe.Pipe()
|
||||
routerConn := &routedPipeTCPConn{
|
||||
Conn: output,
|
||||
onHandshake: options.onHandshake,
|
||||
}
|
||||
done := make(chan struct{})
|
||||
|
||||
routeCtx := ctx
|
||||
var cancel context.CancelFunc
|
||||
if options.timeout > 0 {
|
||||
routeCtx, cancel = context.WithTimeout(ctx, options.timeout)
|
||||
}
|
||||
|
||||
var closeOnce sync.Once
|
||||
closePipe := func() {
|
||||
closeOnce.Do(func() {
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
}
|
||||
common.Close(input, routerConn)
|
||||
})
|
||||
}
|
||||
go i.router.RouteConnectionEx(routeCtx, routerConn, metadata, N.OnceClose(func(it error) {
|
||||
closePipe()
|
||||
close(done)
|
||||
}))
|
||||
|
||||
return input, func() {
|
||||
closePipe()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func applyTCPKeepAlive(conn net.Conn, keepAlive time.Duration) error {
|
||||
if keepAlive <= 0 {
|
||||
return nil
|
||||
}
|
||||
type keepAliveConn interface {
|
||||
SetKeepAlive(bool) error
|
||||
SetKeepAlivePeriod(time.Duration) error
|
||||
}
|
||||
tcpConn, ok := conn.(keepAliveConn)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if err := tcpConn.SetKeepAlive(true); err != nil {
|
||||
return err
|
||||
}
|
||||
return tcpConn.SetKeepAlivePeriod(keepAlive)
|
||||
}
|
||||
165
protocol/cloudflare/router_pipe_test.go
Normal file
165
protocol/cloudflare/router_pipe_test.go
Normal file
@@ -0,0 +1,165 @@
|
||||
//go:build with_cloudflared
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func TestHandleTCPStreamUsesRouteConnectionEx(t *testing.T) {
|
||||
listener := startEchoListener(t)
|
||||
defer listener.Close()
|
||||
|
||||
router := &countingRouter{}
|
||||
inboundInstance := newSpecialServiceInboundWithRouter(t, router)
|
||||
|
||||
serverSide, clientSide := net.Pipe()
|
||||
defer clientSide.Close()
|
||||
|
||||
respWriter := &fakeConnectResponseWriter{done: make(chan struct{})}
|
||||
responseDone := respWriter.done
|
||||
finished := make(chan struct{})
|
||||
go func() {
|
||||
inboundInstance.handleTCPStream(context.Background(), serverSide, respWriter, adapter.InboundContext{
|
||||
Destination: M.ParseSocksaddr(listener.Addr().String()),
|
||||
})
|
||||
close(finished)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-responseDone:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for connect response")
|
||||
}
|
||||
if respWriter.err != nil {
|
||||
t.Fatal("unexpected response error: ", respWriter.err)
|
||||
}
|
||||
|
||||
if err := clientSide.SetDeadline(time.Now().Add(time.Second)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
payload := []byte("ping")
|
||||
if _, err := clientSide.Write(payload); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
response := make([]byte, len(payload))
|
||||
if _, err := io.ReadFull(clientSide, response); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(response) != string(payload) {
|
||||
t.Fatalf("unexpected echo payload: %q", string(response))
|
||||
}
|
||||
if router.count.Load() != 1 {
|
||||
t.Fatalf("expected RouteConnectionEx to be used once, got %d", router.count.Load())
|
||||
}
|
||||
|
||||
_ = clientSide.Close()
|
||||
select {
|
||||
case <-finished:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for TCP stream handler to exit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleTCPStreamWritesOptimisticAck(t *testing.T) {
|
||||
router := &blockingRouteRouter{
|
||||
started: make(chan struct{}),
|
||||
release: make(chan struct{}),
|
||||
}
|
||||
inboundInstance := newSpecialServiceInboundWithRouter(t, router)
|
||||
|
||||
serverSide, clientSide := net.Pipe()
|
||||
defer clientSide.Close()
|
||||
|
||||
respWriter := &fakeConnectResponseWriter{done: make(chan struct{})}
|
||||
responseDone := respWriter.done
|
||||
finished := make(chan struct{})
|
||||
go func() {
|
||||
inboundInstance.handleTCPStream(context.Background(), serverSide, respWriter, adapter.InboundContext{
|
||||
Destination: M.ParseSocksaddr("127.0.0.1:443"),
|
||||
})
|
||||
close(finished)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-router.started:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for router goroutine to start")
|
||||
}
|
||||
select {
|
||||
case <-responseDone:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for optimistic connect response")
|
||||
}
|
||||
if respWriter.err != nil {
|
||||
t.Fatal("unexpected response error: ", respWriter.err)
|
||||
}
|
||||
|
||||
close(router.release)
|
||||
_ = clientSide.Close()
|
||||
select {
|
||||
case <-finished:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for TCP stream handler to exit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoutedPipeTCPConnHandshakeAppliesKeepAlive(t *testing.T) {
|
||||
left, right := net.Pipe()
|
||||
defer left.Close()
|
||||
defer right.Close()
|
||||
|
||||
remoteConn := &keepAliveTestConn{Conn: right}
|
||||
routerConn := &routedPipeTCPConn{
|
||||
Conn: left,
|
||||
onHandshake: func(conn net.Conn) {
|
||||
_ = applyTCPKeepAlive(conn, 15*time.Second)
|
||||
},
|
||||
}
|
||||
if err := routerConn.ConnHandshakeSuccess(remoteConn); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !remoteConn.enabled {
|
||||
t.Fatal("expected keepalive to be enabled")
|
||||
}
|
||||
if remoteConn.period != 15*time.Second {
|
||||
t.Fatalf("unexpected keepalive period: %s", remoteConn.period)
|
||||
}
|
||||
}
|
||||
|
||||
type blockingRouteRouter struct {
|
||||
testRouter
|
||||
started chan struct{}
|
||||
release chan struct{}
|
||||
}
|
||||
|
||||
func (r *blockingRouteRouter) RouteConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||||
close(r.started)
|
||||
<-r.release
|
||||
_ = conn.Close()
|
||||
onClose(nil)
|
||||
}
|
||||
|
||||
type keepAliveTestConn struct {
|
||||
net.Conn
|
||||
enabled bool
|
||||
period time.Duration
|
||||
}
|
||||
|
||||
func (c *keepAliveTestConn) SetKeepAlive(enabled bool) error {
|
||||
c.enabled = enabled
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *keepAliveTestConn) SetKeepAlivePeriod(period time.Duration) error {
|
||||
c.period = period
|
||||
return nil
|
||||
}
|
||||
@@ -13,16 +13,13 @@ import (
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/transport/v2raywebsocket"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/pipe"
|
||||
"github.com/sagernet/ws"
|
||||
)
|
||||
|
||||
@@ -118,25 +115,13 @@ func requestHeaderValue(request *ConnectRequest, headerName string) string {
|
||||
}
|
||||
|
||||
func (i *Inbound) dialRouterTCP(ctx context.Context, destination M.Socksaddr) (net.Conn, func(), error) {
|
||||
input, output := pipe.Pipe()
|
||||
done := make(chan struct{})
|
||||
metadata := adapter.InboundContext{
|
||||
Inbound: i.Tag(),
|
||||
InboundType: i.Type(),
|
||||
Network: N.NetworkTCP,
|
||||
Destination: destination,
|
||||
}
|
||||
go i.router.RouteConnectionEx(ctx, output, metadata, N.OnceClose(func(it error) {
|
||||
common.Close(input, output)
|
||||
close(done)
|
||||
}))
|
||||
return input, func() {
|
||||
common.Close(input, output)
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
}, nil
|
||||
return i.dialRouterTCPWithMetadata(ctx, metadata, routedPipeTCPOptions{})
|
||||
}
|
||||
|
||||
func (i *Inbound) serveSocksProxy(ctx context.Context, conn net.Conn, policy *ipRulePolicy) error {
|
||||
|
||||
@@ -68,6 +68,7 @@ func newSpecialServiceInboundWithRouter(t *testing.T, router adapter.Router) *In
|
||||
router: router,
|
||||
logger: logFactory.NewLogger("test"),
|
||||
configManager: configManager,
|
||||
flowLimiter: &FlowLimiter{},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user