Route cloudflare tunnel ICMP through sing-box router

This commit is contained in:
世界
2026-03-24 11:47:06 +08:00
parent b3cad021b8
commit 71c7a585ef
6 changed files with 641 additions and 7 deletions

View File

@@ -42,6 +42,7 @@ type DatagramV2Muxer struct {
inbound *Inbound
logger log.ContextLogger
sender DatagramSender
icmp *ICMPBridge
sessionAccess sync.RWMutex
sessions map[uuid.UUID]*udpSession
@@ -53,6 +54,7 @@ func NewDatagramV2Muxer(inbound *Inbound, sender DatagramSender, logger log.Cont
inbound: inbound,
logger: logger,
sender: sender,
icmp: NewICMPBridge(inbound, sender, icmpWireV2),
sessions: make(map[uuid.UUID]*udpSession),
}
}
@@ -70,10 +72,13 @@ func (m *DatagramV2Muxer) HandleDatagram(ctx context.Context, data []byte) {
case DatagramV2TypeUDP:
m.handleUDPDatagram(ctx, payload)
case DatagramV2TypeIP:
// TODO: ICMP handling
m.logger.Debug("received V2 IP datagram (ICMP not yet implemented)")
if err := m.icmp.HandleV2(ctx, datagramType, payload); err != nil {
m.logger.Debug("drop V2 ICMP datagram: ", err)
}
case DatagramV2TypeIPWithTrace:
m.logger.Debug("received V2 IP+trace datagram")
if err := m.icmp.HandleV2(ctx, datagramType, payload); err != nil {
m.logger.Debug("drop V2 traced ICMP datagram: ", err)
}
case DatagramV2TypeTracingSpan:
// Tracing spans, ignore
}

View File

@@ -61,6 +61,7 @@ type DatagramV3Muxer struct {
inbound *Inbound
logger log.ContextLogger
sender DatagramSender
icmp *ICMPBridge
sessionAccess sync.RWMutex
sessions map[RequestID]*v3Session
@@ -72,6 +73,7 @@ func NewDatagramV3Muxer(inbound *Inbound, sender DatagramSender, logger log.Cont
inbound: inbound,
logger: logger,
sender: sender,
icmp: NewICMPBridge(inbound, sender, icmpWireV3),
sessions: make(map[RequestID]*v3Session),
}
}
@@ -91,8 +93,9 @@ func (m *DatagramV3Muxer) HandleDatagram(ctx context.Context, data []byte) {
case DatagramV3TypePayload:
m.handlePayload(payload)
case DatagramV3TypeICMP:
// TODO: ICMP handling
m.logger.Debug("received V3 ICMP datagram (not yet implemented)")
if err := m.icmp.HandleV3(ctx, payload); err != nil {
m.logger.Debug("drop V3 ICMP datagram: ", err)
}
case DatagramV3TypeRegistrationResponse:
// Unexpected - we never send registrations
m.logger.Debug("received unexpected V3 registration response")

View File

@@ -21,6 +21,7 @@ import (
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-tun"
N "github.com/sagernet/sing/common/network"
"github.com/google/uuid"
@@ -80,7 +81,13 @@ func startOriginServer(t *testing.T) {
})
}
type testRouter struct{}
type testRouter struct {
preMatch func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error)
}
func (r *testRouter) Start(stage adapter.StartStage) error { return nil }
func (r *testRouter) Close() error { return nil }
func (r *testRouter) RouteConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
destination := metadata.Destination.String()
@@ -130,6 +137,27 @@ func (r *testRouter) RoutePacketConnectionEx(ctx context.Context, conn N.PacketC
onClose(nil)
}
func (r *testRouter) PreMatch(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
if r.preMatch != nil {
return r.preMatch(metadata, routeContext, timeout, supportBypass)
}
return nil, nil
}
func (r *testRouter) RuleSet(tag string) (adapter.RuleSet, bool) { return nil, false }
func (r *testRouter) Rules() []adapter.Rule { return nil }
func (r *testRouter) NeedFindProcess() bool { return false }
func (r *testRouter) NeedFindNeighbor() bool { return false }
func (r *testRouter) NeighborResolver() adapter.NeighborResolver { return nil }
func (r *testRouter) AppendTracker(tracker adapter.ConnectionTracker) {}
func (r *testRouter) ResetNetwork() {}
func newTestInbound(t *testing.T, token string, protocol string, haConnections int) *Inbound {
t.Helper()
credentials, err := parseToken(token)

356
protocol/cloudflare/icmp.go Normal file
View File

@@ -0,0 +1,356 @@
//go:build with_cloudflare_tunnel
package cloudflare
import (
"context"
"encoding/binary"
"net/netip"
"sync"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
const (
icmpFlowTimeout = 30 * time.Second
icmpTraceIdentityLength = 16 + 8 + 1
)
type ICMPTraceContext struct {
Traced bool
Identity []byte
}
type ICMPFlowKey struct {
IPVersion uint8
SourceIP netip.Addr
Destination netip.Addr
}
type ICMPRequestKey struct {
Flow ICMPFlowKey
Identifier uint16
Sequence uint16
}
type ICMPPacketInfo struct {
IPVersion uint8
Protocol uint8
SourceIP netip.Addr
Destination netip.Addr
ICMPType uint8
ICMPCode uint8
Identifier uint16
Sequence uint16
RawPacket []byte
}
func (i ICMPPacketInfo) FlowKey() ICMPFlowKey {
return ICMPFlowKey{
IPVersion: i.IPVersion,
SourceIP: i.SourceIP,
Destination: i.Destination,
}
}
func (i ICMPPacketInfo) RequestKey() ICMPRequestKey {
return ICMPRequestKey{
Flow: i.FlowKey(),
Identifier: i.Identifier,
Sequence: i.Sequence,
}
}
func (i ICMPPacketInfo) ReplyRequestKey() ICMPRequestKey {
return ICMPRequestKey{
Flow: ICMPFlowKey{
IPVersion: i.IPVersion,
SourceIP: i.Destination,
Destination: i.SourceIP,
},
Identifier: i.Identifier,
Sequence: i.Sequence,
}
}
func (i ICMPPacketInfo) IsEchoRequest() bool {
switch i.IPVersion {
case 4:
return i.ICMPType == 8 && i.ICMPCode == 0
case 6:
return i.ICMPType == 128 && i.ICMPCode == 0
default:
return false
}
}
func (i ICMPPacketInfo) IsEchoReply() bool {
switch i.IPVersion {
case 4:
return i.ICMPType == 0 && i.ICMPCode == 0
case 6:
return i.ICMPType == 129 && i.ICMPCode == 0
default:
return false
}
}
type icmpWireVersion uint8
const (
icmpWireV2 icmpWireVersion = iota + 1
icmpWireV3
)
type icmpFlowState struct {
writer *ICMPReplyWriter
}
type ICMPReplyWriter struct {
sender DatagramSender
wireVersion icmpWireVersion
access sync.Mutex
traces map[ICMPRequestKey]ICMPTraceContext
}
func NewICMPReplyWriter(sender DatagramSender, wireVersion icmpWireVersion) *ICMPReplyWriter {
return &ICMPReplyWriter{
sender: sender,
wireVersion: wireVersion,
traces: make(map[ICMPRequestKey]ICMPTraceContext),
}
}
func (w *ICMPReplyWriter) RegisterRequestTrace(packetInfo ICMPPacketInfo, traceContext ICMPTraceContext) {
if !traceContext.Traced {
return
}
w.access.Lock()
w.traces[packetInfo.RequestKey()] = traceContext
w.access.Unlock()
}
func (w *ICMPReplyWriter) WritePacket(packet []byte) error {
packetInfo, err := ParseICMPPacket(packet)
if err != nil {
return err
}
if !packetInfo.IsEchoReply() {
return nil
}
requestKey := packetInfo.ReplyRequestKey()
w.access.Lock()
traceContext, loaded := w.traces[requestKey]
if loaded {
delete(w.traces, requestKey)
}
w.access.Unlock()
var datagram []byte
switch w.wireVersion {
case icmpWireV2:
datagram, err = encodeV2ICMPDatagram(packetInfo.RawPacket, traceContext)
case icmpWireV3:
datagram = encodeV3ICMPDatagram(packetInfo.RawPacket)
default:
err = E.New("unsupported icmp wire version: ", w.wireVersion)
}
if err != nil {
return err
}
return w.sender.SendDatagram(datagram)
}
type ICMPBridge struct {
inbound *Inbound
sender DatagramSender
wireVersion icmpWireVersion
routeMapping *tun.DirectRouteMapping
flowAccess sync.Mutex
flows map[ICMPFlowKey]*icmpFlowState
}
func NewICMPBridge(inbound *Inbound, sender DatagramSender, wireVersion icmpWireVersion) *ICMPBridge {
return &ICMPBridge{
inbound: inbound,
sender: sender,
wireVersion: wireVersion,
routeMapping: tun.NewDirectRouteMapping(icmpFlowTimeout),
flows: make(map[ICMPFlowKey]*icmpFlowState),
}
}
func (b *ICMPBridge) HandleV2(ctx context.Context, datagramType DatagramV2Type, payload []byte) error {
traceContext := ICMPTraceContext{}
switch datagramType {
case DatagramV2TypeIP:
case DatagramV2TypeIPWithTrace:
if len(payload) < icmpTraceIdentityLength {
return E.New("icmp trace payload is too short")
}
traceContext.Traced = true
traceContext.Identity = append([]byte(nil), payload[len(payload)-icmpTraceIdentityLength:]...)
payload = payload[:len(payload)-icmpTraceIdentityLength]
default:
return E.New("unsupported v2 icmp datagram type: ", datagramType)
}
return b.handlePacket(ctx, payload, traceContext)
}
func (b *ICMPBridge) HandleV3(ctx context.Context, payload []byte) error {
return b.handlePacket(ctx, payload, ICMPTraceContext{})
}
func (b *ICMPBridge) handlePacket(ctx context.Context, payload []byte, traceContext ICMPTraceContext) error {
packetInfo, err := ParseICMPPacket(payload)
if err != nil {
return err
}
if !packetInfo.IsEchoRequest() {
return nil
}
state := b.getFlowState(packetInfo.FlowKey())
if traceContext.Traced {
state.writer.RegisterRequestTrace(packetInfo, traceContext)
}
action, err := b.routeMapping.Lookup(tun.DirectRouteSession{
Source: packetInfo.SourceIP,
Destination: packetInfo.Destination,
}, func(timeout time.Duration) (tun.DirectRouteDestination, error) {
metadata := adapter.InboundContext{
Inbound: b.inbound.Tag(),
InboundType: b.inbound.Type(),
IPVersion: packetInfo.IPVersion,
Network: N.NetworkICMP,
Source: M.SocksaddrFrom(packetInfo.SourceIP, 0),
Destination: M.SocksaddrFrom(packetInfo.Destination, 0),
OriginDestination: M.SocksaddrFrom(packetInfo.Destination, 0),
}
return b.inbound.router.PreMatch(metadata, state.writer, timeout, false)
})
if err != nil {
return nil
}
return action.WritePacket(buf.As(packetInfo.RawPacket).ToOwned())
}
func (b *ICMPBridge) getFlowState(key ICMPFlowKey) *icmpFlowState {
b.flowAccess.Lock()
defer b.flowAccess.Unlock()
state, loaded := b.flows[key]
if loaded {
return state
}
state = &icmpFlowState{
writer: NewICMPReplyWriter(b.sender, b.wireVersion),
}
b.flows[key] = state
return state
}
func ParseICMPPacket(packet []byte) (ICMPPacketInfo, error) {
if len(packet) < 1 {
return ICMPPacketInfo{}, E.New("empty IP packet")
}
version := packet[0] >> 4
switch version {
case 4:
return parseIPv4ICMPPacket(packet)
case 6:
return parseIPv6ICMPPacket(packet)
default:
return ICMPPacketInfo{}, E.New("unsupported IP version: ", version)
}
}
func parseIPv4ICMPPacket(packet []byte) (ICMPPacketInfo, error) {
if len(packet) < 20 {
return ICMPPacketInfo{}, E.New("IPv4 packet too short")
}
headerLen := int(packet[0]&0x0F) * 4
if headerLen < 20 || len(packet) < headerLen+8 {
return ICMPPacketInfo{}, E.New("invalid IPv4 header length")
}
if packet[9] != 1 {
return ICMPPacketInfo{}, E.New("IPv4 packet is not ICMP")
}
sourceIP, ok := netip.AddrFromSlice(packet[12:16])
if !ok {
return ICMPPacketInfo{}, E.New("invalid IPv4 source address")
}
destinationIP, ok := netip.AddrFromSlice(packet[16:20])
if !ok {
return ICMPPacketInfo{}, E.New("invalid IPv4 destination address")
}
return ICMPPacketInfo{
IPVersion: 4,
Protocol: 1,
SourceIP: sourceIP,
Destination: destinationIP,
ICMPType: packet[headerLen],
ICMPCode: packet[headerLen+1],
Identifier: binary.BigEndian.Uint16(packet[headerLen+4 : headerLen+6]),
Sequence: binary.BigEndian.Uint16(packet[headerLen+6 : headerLen+8]),
RawPacket: append([]byte(nil), packet...),
}, nil
}
func parseIPv6ICMPPacket(packet []byte) (ICMPPacketInfo, error) {
if len(packet) < 48 {
return ICMPPacketInfo{}, E.New("IPv6 packet too short")
}
if packet[6] != 58 {
return ICMPPacketInfo{}, E.New("IPv6 packet is not ICMP")
}
sourceIP, ok := netip.AddrFromSlice(packet[8:24])
if !ok {
return ICMPPacketInfo{}, E.New("invalid IPv6 source address")
}
destinationIP, ok := netip.AddrFromSlice(packet[24:40])
if !ok {
return ICMPPacketInfo{}, E.New("invalid IPv6 destination address")
}
return ICMPPacketInfo{
IPVersion: 6,
Protocol: 58,
SourceIP: sourceIP,
Destination: destinationIP,
ICMPType: packet[40],
ICMPCode: packet[41],
Identifier: binary.BigEndian.Uint16(packet[44:46]),
Sequence: binary.BigEndian.Uint16(packet[46:48]),
RawPacket: append([]byte(nil), packet...),
}, nil
}
func encodeV2ICMPDatagram(packet []byte, traceContext ICMPTraceContext) ([]byte, error) {
if traceContext.Traced {
data := make([]byte, 0, len(packet)+len(traceContext.Identity)+1)
data = append(data, packet...)
data = append(data, traceContext.Identity...)
data = append(data, byte(DatagramV2TypeIPWithTrace))
return data, nil
}
data := make([]byte, 0, len(packet)+1)
data = append(data, packet...)
data = append(data, byte(DatagramV2TypeIP))
return data, nil
}
func encodeV3ICMPDatagram(packet []byte) []byte {
data := make([]byte, 0, len(packet)+1)
data = append(data, byte(DatagramV3TypeICMP))
data = append(data, packet...)
return data
}

View File

@@ -0,0 +1,242 @@
//go:build with_cloudflare_tunnel
package cloudflare
import (
"bytes"
"context"
"encoding/binary"
"net/netip"
"testing"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/buf"
N "github.com/sagernet/sing/common/network"
)
type captureDatagramSender struct {
sent [][]byte
}
func (s *captureDatagramSender) SendDatagram(data []byte) error {
s.sent = append(s.sent, append([]byte(nil), data...))
return nil
}
type fakeDirectRouteDestination struct {
routeContext tun.DirectRouteContext
packets [][]byte
reply func(packet []byte) []byte
closed bool
}
func (d *fakeDirectRouteDestination) WritePacket(packet *buf.Buffer) error {
data := append([]byte(nil), packet.Bytes()...)
packet.Release()
d.packets = append(d.packets, data)
if d.reply != nil {
reply := d.reply(data)
if reply != nil {
return d.routeContext.WritePacket(reply)
}
}
return nil
}
func (d *fakeDirectRouteDestination) Close() error {
d.closed = true
return nil
}
func (d *fakeDirectRouteDestination) IsClosed() bool {
return d.closed
}
func TestICMPBridgeHandleV2RoutesEchoRequest(t *testing.T) {
var (
preMatchCalls int
captured adapter.InboundContext
destination *fakeDirectRouteDestination
)
router := &testRouter{
preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
preMatchCalls++
captured = metadata
destination = &fakeDirectRouteDestination{routeContext: routeContext}
return destination, nil
},
}
inboundInstance := &Inbound{
Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"),
router: router,
}
sender := &captureDatagramSender{}
bridge := NewICMPBridge(inboundInstance, sender, icmpWireV2)
source := netip.MustParseAddr("198.18.0.2")
target := netip.MustParseAddr("1.1.1.1")
packet1 := buildIPv4ICMPPacket(source, target, 8, 0, 1, 1)
packet2 := buildIPv4ICMPPacket(source, target, 8, 0, 1, 2)
if err := bridge.HandleV2(context.Background(), DatagramV2TypeIP, packet1); err != nil {
t.Fatal(err)
}
if err := bridge.HandleV2(context.Background(), DatagramV2TypeIP, packet2); err != nil {
t.Fatal(err)
}
if preMatchCalls != 1 {
t.Fatalf("expected one direct-route lookup, got %d", preMatchCalls)
}
if captured.Network != N.NetworkICMP {
t.Fatalf("expected NetworkICMP, got %s", captured.Network)
}
if captured.Source.Addr != source || captured.Destination.Addr != target {
t.Fatalf("unexpected metadata source/destination: %#v", captured)
}
if len(destination.packets) != 2 {
t.Fatalf("expected two packets written, got %d", len(destination.packets))
}
if len(sender.sent) != 0 {
t.Fatalf("expected no reply datagrams, got %d", len(sender.sent))
}
}
func TestICMPBridgeHandleV2TracedReply(t *testing.T) {
traceIdentity := bytes.Repeat([]byte{0x7a}, icmpTraceIdentityLength)
sender := &captureDatagramSender{}
router := &testRouter{
preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
return &fakeDirectRouteDestination{
routeContext: routeContext,
reply: buildEchoReply,
}, nil
},
}
inboundInstance := &Inbound{
Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"),
router: router,
}
bridge := NewICMPBridge(inboundInstance, sender, icmpWireV2)
request := buildIPv4ICMPPacket(netip.MustParseAddr("198.18.0.2"), netip.MustParseAddr("1.1.1.1"), 8, 0, 9, 7)
request = append(request, traceIdentity...)
if err := bridge.HandleV2(context.Background(), DatagramV2TypeIPWithTrace, request); err != nil {
t.Fatal(err)
}
if len(sender.sent) != 1 {
t.Fatalf("expected one reply datagram, got %d", len(sender.sent))
}
reply := sender.sent[0]
if reply[len(reply)-1] != byte(DatagramV2TypeIPWithTrace) {
t.Fatalf("expected traced v2 reply, got type %d", reply[len(reply)-1])
}
gotIdentity := reply[len(reply)-1-icmpTraceIdentityLength : len(reply)-1]
if !bytes.Equal(gotIdentity, traceIdentity) {
t.Fatalf("unexpected trace identity: %x", gotIdentity)
}
}
func TestICMPBridgeHandleV3Reply(t *testing.T) {
sender := &captureDatagramSender{}
router := &testRouter{
preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
return &fakeDirectRouteDestination{
routeContext: routeContext,
reply: buildEchoReply,
}, nil
},
}
inboundInstance := &Inbound{
Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"),
router: router,
}
bridge := NewICMPBridge(inboundInstance, sender, icmpWireV3)
request := buildIPv6ICMPPacket(netip.MustParseAddr("2001:db8::2"), netip.MustParseAddr("2606:4700:4700::1111"), 128, 0, 3, 5)
if err := bridge.HandleV3(context.Background(), request); err != nil {
t.Fatal(err)
}
if len(sender.sent) != 1 {
t.Fatalf("expected one reply datagram, got %d", len(sender.sent))
}
reply := sender.sent[0]
if reply[0] != byte(DatagramV3TypeICMP) {
t.Fatalf("expected v3 ICMP datagram, got %d", reply[0])
}
}
func TestICMPBridgeDropsNonEcho(t *testing.T) {
var preMatchCalls int
router := &testRouter{
preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
preMatchCalls++
return nil, nil
},
}
inboundInstance := &Inbound{
Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"),
router: router,
}
sender := &captureDatagramSender{}
bridge := NewICMPBridge(inboundInstance, sender, icmpWireV2)
packet := buildIPv4ICMPPacket(netip.MustParseAddr("198.18.0.2"), netip.MustParseAddr("1.1.1.1"), 3, 0, 1, 1)
if err := bridge.HandleV2(context.Background(), DatagramV2TypeIP, packet); err != nil {
t.Fatal(err)
}
if preMatchCalls != 0 {
t.Fatalf("expected no route lookup, got %d", preMatchCalls)
}
if len(sender.sent) != 0 {
t.Fatalf("expected no sender datagrams, got %d", len(sender.sent))
}
}
func buildEchoReply(packet []byte) []byte {
info, err := ParseICMPPacket(packet)
if err != nil {
panic(err)
}
switch info.IPVersion {
case 4:
return buildIPv4ICMPPacket(info.Destination, info.SourceIP, 0, 0, info.Identifier, info.Sequence)
case 6:
return buildIPv6ICMPPacket(info.Destination, info.SourceIP, 129, 0, info.Identifier, info.Sequence)
default:
panic("unsupported version")
}
}
func buildIPv4ICMPPacket(source, destination netip.Addr, icmpType, icmpCode uint8, identifier, sequence uint16) []byte {
packet := make([]byte, 28)
packet[0] = 0x45
binary.BigEndian.PutUint16(packet[2:4], uint16(len(packet)))
packet[8] = 64
packet[9] = 1
copy(packet[12:16], source.AsSlice())
copy(packet[16:20], destination.AsSlice())
packet[20] = icmpType
packet[21] = icmpCode
binary.BigEndian.PutUint16(packet[24:26], identifier)
binary.BigEndian.PutUint16(packet[26:28], sequence)
return packet
}
func buildIPv6ICMPPacket(source, destination netip.Addr, icmpType, icmpCode uint8, identifier, sequence uint16) []byte {
packet := make([]byte, 48)
packet[0] = 0x60
binary.BigEndian.PutUint16(packet[4:6], 8)
packet[6] = 58
packet[7] = 64
copy(packet[8:24], source.AsSlice())
copy(packet[24:40], destination.AsSlice())
packet[40] = icmpType
packet[41] = icmpCode
binary.BigEndian.PutUint16(packet[44:46], identifier)
binary.BigEndian.PutUint16(packet[46:48], sequence)
return packet
}

View File

@@ -33,7 +33,7 @@ type Inbound struct {
inbound.Adapter
ctx context.Context
cancel context.CancelFunc
router adapter.ConnectionRouterEx
router adapter.Router
logger log.ContextLogger
credentials Credentials
connectorID uuid.UUID