Files
sing-box/protocol/cloudflare/icmp_test.go
2026-03-31 15:32:56 +08:00

395 lines
13 KiB
Go

//go:build with_cloudflared
package cloudflare
import (
"bytes"
"context"
"encoding/binary"
"net/netip"
"testing"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/buf"
N "github.com/sagernet/sing/common/network"
)
// captureDatagramSender is a test double passed to NewICMPBridge as the
// datagram sender. It never fails and records every datagram it is asked to
// send, in call order.
type captureDatagramSender struct {
	// sent holds one private copy per SendDatagram call.
	sent [][]byte
}

// SendDatagram records a copy of data and always reports success. Copying
// keeps later reuse of the caller's buffer from changing what the test
// inspects.
func (s *captureDatagramSender) SendDatagram(data []byte) error {
	snapshot := append([]byte(nil), data...)
	s.sent = append(s.sent, snapshot)
	return nil
}
// fakeDirectRouteDestination is a stand-in tun.DirectRouteDestination returned
// by the test router. It records every packet written to it and can optionally
// synthesize a response that is pushed back through the route context.
type fakeDirectRouteDestination struct {
	// routeContext receives any response produced by reply.
	routeContext tun.DirectRouteContext

	// packets holds a private copy of every packet passed to WritePacket.
	packets [][]byte

	// reply, when non-nil, maps a written packet to a response; a nil
	// response means nothing is written back.
	reply func(packet []byte) []byte

	// closed is set once Close has been called.
	closed bool
}

// WritePacket copies the packet's bytes, releases the buffer, records the
// copy, and, if a reply function is configured and yields a non-nil response,
// writes that response to the route context and returns its error.
func (d *fakeDirectRouteDestination) WritePacket(packet *buf.Buffer) error {
	payload := append([]byte(nil), packet.Bytes()...)
	// The copy above is taken first, so the buffer is no longer needed.
	packet.Release()
	d.packets = append(d.packets, payload)
	if d.reply == nil {
		return nil
	}
	response := d.reply(payload)
	if response == nil {
		return nil
	}
	return d.routeContext.WritePacket(response)
}

// Close marks the destination as closed; it never fails.
func (d *fakeDirectRouteDestination) Close() error {
	d.closed = true
	return nil
}

// IsClosed reports whether Close has been called.
func (d *fakeDirectRouteDestination) IsClosed() bool {
	return d.closed
}
// TestICMPBridgeHandleV2RoutesEchoRequest sends two IPv4 echo requests with the
// same source, destination and identifier (different sequence numbers). It
// expects them to reuse a single direct-route lookup whose metadata describes
// an ICMP flow from source to target. It also expects both packets to reach the
// route destination and no datagram to be sent back.
func TestICMPBridgeHandleV2RoutesEchoRequest(t *testing.T) {
	var (
		preMatchCalls int
		captured      adapter.InboundContext
		destination   *fakeDirectRouteDestination
	)
	router := &testRouter{
		preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
			preMatchCalls++
			captured = metadata
			destination = &fakeDirectRouteDestination{routeContext: routeContext}
			return destination, nil
		},
	}
	inboundInstance := &Inbound{
		Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
		router:  router,
	}
	sender := &captureDatagramSender{}
	bridge := NewICMPBridge(inboundInstance, sender, icmpWireV2)
	source := netip.MustParseAddr("198.18.0.2")
	target := netip.MustParseAddr("1.1.1.1")
	packet1 := buildIPv4ICMPPacket(source, target, icmpv4TypeEchoRequest, 0, 1, 1)
	packet2 := buildIPv4ICMPPacket(source, target, icmpv4TypeEchoRequest, 0, 1, 2)
	if err := bridge.HandleV2(context.Background(), DatagramV2TypeIP, packet1); err != nil {
		t.Fatal(err)
	}
	if err := bridge.HandleV2(context.Background(), DatagramV2TypeIP, packet2); err != nil {
		t.Fatal(err)
	}
	// Checked first: it also guarantees destination is non-nil below.
	if preMatchCalls != 1 {
		t.Fatalf("expected one direct-route lookup, got %d", preMatchCalls)
	}
	if captured.Network != N.NetworkICMP {
		t.Fatalf("expected NetworkICMP, got %s", captured.Network)
	}
	if captured.Source.Addr != source || captured.Destination.Addr != target {
		t.Fatalf("unexpected metadata source/destination: %#v", captured)
	}
	if len(destination.packets) != 2 {
		t.Fatalf("expected two packets written, got %d", len(destination.packets))
	}
	if len(sender.sent) != 0 {
		t.Fatalf("expected no reply datagrams, got %d", len(sender.sent))
	}
}
// TestICMPBridgeHandleV2TracedReply sends a traced IPv4 echo request. The route
// destination answers it with an echo reply. The test expects exactly one
// datagram back: a traced v2 datagram carrying the same trace identity.
func TestICMPBridgeHandleV2TracedReply(t *testing.T) {
	traceIdentity := bytes.Repeat([]byte{0x7a}, icmpTraceIdentityLength)
	sender := &captureDatagramSender{}
	router := &testRouter{
		preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
			return &fakeDirectRouteDestination{
				routeContext: routeContext,
				reply:        buildEchoReply,
			}, nil
		},
	}
	inboundInstance := &Inbound{
		Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
		router:  router,
	}
	bridge := NewICMPBridge(inboundInstance, sender, icmpWireV2)
	request := buildIPv4ICMPPacket(netip.MustParseAddr("198.18.0.2"), netip.MustParseAddr("1.1.1.1"), icmpv4TypeEchoRequest, 0, 9, 7)
	// The traced request carries the trace identity right after the IP packet.
	request = append(request, traceIdentity...)
	if err := bridge.HandleV2(context.Background(), DatagramV2TypeIPWithTrace, request); err != nil {
		t.Fatal(err)
	}
	if len(sender.sent) != 1 {
		t.Fatalf("expected one reply datagram, got %d", len(sender.sent))
	}
	reply := sender.sent[0]
	// Expected layout: packet | trace identity | datagram type byte. Guard the
	// length so a short reply fails clearly instead of panicking on indexing.
	if len(reply) < icmpTraceIdentityLength+1 {
		t.Fatalf("reply too short for a traced v2 datagram: %d bytes", len(reply))
	}
	if reply[len(reply)-1] != byte(DatagramV2TypeIPWithTrace) {
		t.Fatalf("expected traced v2 reply, got type %d", reply[len(reply)-1])
	}
	gotIdentity := reply[len(reply)-1-icmpTraceIdentityLength : len(reply)-1]
	if !bytes.Equal(gotIdentity, traceIdentity) {
		t.Fatalf("unexpected trace identity: %x", gotIdentity)
	}
}
// TestICMPBridgeHandleV3Reply sends an IPv6 echo request over the v3 wire
// format. The route destination answers it with an echo reply. The test
// expects exactly one datagram back, tagged with the v3 ICMP type in its first
// byte.
func TestICMPBridgeHandleV3Reply(t *testing.T) {
	sender := &captureDatagramSender{}
	router := &testRouter{
		preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
			return &fakeDirectRouteDestination{
				routeContext: routeContext,
				reply:        buildEchoReply,
			}, nil
		},
	}
	inboundInstance := &Inbound{
		Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
		router:  router,
	}
	bridge := NewICMPBridge(inboundInstance, sender, icmpWireV3)
	request := buildIPv6ICMPPacket(netip.MustParseAddr("2001:db8::2"), netip.MustParseAddr("2606:4700:4700::1111"), icmpv6TypeEchoRequest, 0, 3, 5)
	if err := bridge.HandleV3(context.Background(), request); err != nil {
		t.Fatal(err)
	}
	if len(sender.sent) != 1 {
		t.Fatalf("expected one reply datagram, got %d", len(sender.sent))
	}
	reply := sender.sent[0]
	// Guard before reading the type byte so an empty datagram fails clearly.
	if len(reply) == 0 {
		t.Fatal("expected a non-empty v3 reply datagram")
	}
	if reply[0] != byte(DatagramV3TypeICMP) {
		t.Fatalf("expected v3 ICMP datagram, got %d", reply[0])
	}
}
// TestICMPBridgeDecrementsIPv4TTLBeforeRouting sends an IPv4 echo request with
// TTL 5 and expects the packet handed to the route destination to carry TTL 4.
func TestICMPBridgeDecrementsIPv4TTLBeforeRouting(t *testing.T) {
	var destination *fakeDirectRouteDestination
	router := &testRouter{
		preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
			destination = &fakeDirectRouteDestination{routeContext: routeContext}
			return destination, nil
		},
	}
	inboundInstance := &Inbound{
		Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
		router:  router,
	}
	bridge := NewICMPBridge(inboundInstance, &captureDatagramSender{}, icmpWireV2)
	packet := buildIPv4ICMPPacket(netip.MustParseAddr("198.18.0.2"), netip.MustParseAddr("1.1.1.1"), icmpv4TypeEchoRequest, 0, 1, 1)
	// Byte 8 of the IPv4 header is the TTL.
	packet[8] = 5
	if err := bridge.HandleV2(context.Background(), DatagramV2TypeIP, packet); err != nil {
		t.Fatal(err)
	}
	// Without this guard a missing route lookup would surface as a nil
	// pointer panic on destination.packets instead of a readable failure.
	if destination == nil {
		t.Fatal("expected a direct-route lookup, got none")
	}
	if len(destination.packets) != 1 {
		t.Fatalf("expected one routed packet, got %d", len(destination.packets))
	}
	if got := destination.packets[0][8]; got != 4 {
		t.Fatalf("expected decremented IPv4 TTL, got %d", got)
	}
}
// TestICMPBridgeDecrementsIPv6HopLimitBeforeRouting sends an IPv6 echo request
// with hop limit 3 and expects the packet handed to the route destination to
// carry hop limit 2.
func TestICMPBridgeDecrementsIPv6HopLimitBeforeRouting(t *testing.T) {
	var destination *fakeDirectRouteDestination
	router := &testRouter{
		preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
			destination = &fakeDirectRouteDestination{routeContext: routeContext}
			return destination, nil
		},
	}
	inboundInstance := &Inbound{
		Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
		router:  router,
	}
	bridge := NewICMPBridge(inboundInstance, &captureDatagramSender{}, icmpWireV3)
	packet := buildIPv6ICMPPacket(netip.MustParseAddr("2001:db8::2"), netip.MustParseAddr("2606:4700:4700::1111"), icmpv6TypeEchoRequest, 0, 1, 1)
	// Byte 7 of the IPv6 header is the hop limit.
	packet[7] = 3
	if err := bridge.HandleV3(context.Background(), packet); err != nil {
		t.Fatal(err)
	}
	// Without this guard a missing route lookup would surface as a nil
	// pointer panic on destination.packets instead of a readable failure.
	if destination == nil {
		t.Fatal("expected a direct-route lookup, got none")
	}
	if len(destination.packets) != 1 {
		t.Fatalf("expected one routed packet, got %d", len(destination.packets))
	}
	if got := destination.packets[0][7]; got != 2 {
		t.Fatalf("expected decremented IPv6 hop limit, got %d", got)
	}
}
// TestICMPBridgeHandleV2TTLExceededTracedReply sends a traced IPv4 echo
// request with TTL 1. It expects the bridge to answer it without consulting
// the router, with one traced v2 datagram. That datagram must carry the same
// trace identity and wrap an ICMP time-exceeded message sent from the original
// target back to the original source.
func TestICMPBridgeHandleV2TTLExceededTracedReply(t *testing.T) {
	var preMatchCalls int
	traceIdentity := bytes.Repeat([]byte{0x6b}, icmpTraceIdentityLength)
	sender := &captureDatagramSender{}
	router := &testRouter{
		preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
			preMatchCalls++
			return nil, nil
		},
	}
	inboundInstance := &Inbound{
		Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
		router:  router,
	}
	bridge := NewICMPBridge(inboundInstance, sender, icmpWireV2)
	source := netip.MustParseAddr("198.18.0.2")
	target := netip.MustParseAddr("1.1.1.1")
	packet := buildIPv4ICMPPacket(source, target, icmpv4TypeEchoRequest, 0, 1, 1)
	// TTL 1 (byte 8) is expected to expire at the bridge rather than be routed.
	packet[8] = 1
	packet = append(packet, traceIdentity...)
	if err := bridge.HandleV2(context.Background(), DatagramV2TypeIPWithTrace, packet); err != nil {
		t.Fatal(err)
	}
	if preMatchCalls != 0 {
		t.Fatalf("expected TTL exceeded to bypass routing, got %d route lookups", preMatchCalls)
	}
	if len(sender.sent) != 1 {
		t.Fatalf("expected one TTL exceeded reply, got %d", len(sender.sent))
	}
	reply := sender.sent[0]
	// Expected layout: packet | trace identity | datagram type byte. Guard the
	// length so a short reply fails clearly instead of panicking on slicing.
	if len(reply) < icmpTraceIdentityLength+1 {
		t.Fatalf("reply too short for a traced v2 datagram: %d bytes", len(reply))
	}
	if reply[len(reply)-1] != byte(DatagramV2TypeIPWithTrace) {
		t.Fatalf("expected traced v2 reply, got type %d", reply[len(reply)-1])
	}
	gotIdentity := reply[len(reply)-1-icmpTraceIdentityLength : len(reply)-1]
	if !bytes.Equal(gotIdentity, traceIdentity) {
		t.Fatalf("unexpected trace identity: %x", gotIdentity)
	}
	rawReply := reply[:len(reply)-1-icmpTraceIdentityLength]
	packetInfo, err := ParseICMPPacket(rawReply)
	if err != nil {
		t.Fatal(err)
	}
	if packetInfo.ICMPType != icmpv4TypeTimeExceeded || packetInfo.ICMPCode != 0 {
		t.Fatalf("expected IPv4 time exceeded reply, got type=%d code=%d", packetInfo.ICMPType, packetInfo.ICMPCode)
	}
	if packetInfo.SourceIP != target || packetInfo.Destination != source {
		t.Fatalf("unexpected TTL exceeded routing: src=%s dst=%s", packetInfo.SourceIP, packetInfo.Destination)
	}
}
// TestICMPBridgeHandleV3TTLExceededReply sends an IPv6 echo request with hop
// limit 1 over the v3 wire format. It expects the bridge to answer it without
// consulting the router, with one v3 ICMP datagram. That datagram must wrap an
// ICMPv6 time-exceeded message sent from the original target back to the
// original source.
func TestICMPBridgeHandleV3TTLExceededReply(t *testing.T) {
	var preMatchCalls int
	sender := &captureDatagramSender{}
	router := &testRouter{
		preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
			preMatchCalls++
			return nil, nil
		},
	}
	inboundInstance := &Inbound{
		Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
		router:  router,
	}
	bridge := NewICMPBridge(inboundInstance, sender, icmpWireV3)
	source := netip.MustParseAddr("2001:db8::2")
	target := netip.MustParseAddr("2606:4700:4700::1111")
	packet := buildIPv6ICMPPacket(source, target, icmpv6TypeEchoRequest, 0, 1, 1)
	// Hop limit 1 (byte 7) is expected to expire at the bridge rather than be routed.
	packet[7] = 1
	if err := bridge.HandleV3(context.Background(), packet); err != nil {
		t.Fatal(err)
	}
	if preMatchCalls != 0 {
		t.Fatalf("expected TTL exceeded to bypass routing, got %d route lookups", preMatchCalls)
	}
	if len(sender.sent) != 1 {
		t.Fatalf("expected one TTL exceeded reply, got %d", len(sender.sent))
	}
	reply := sender.sent[0]
	// Guard before reading the type byte so an empty datagram fails clearly.
	if len(reply) == 0 {
		t.Fatal("expected a non-empty v3 reply datagram")
	}
	if reply[0] != byte(DatagramV3TypeICMP) {
		t.Fatalf("expected v3 ICMP reply, got %d", reply[0])
	}
	// The v3 datagram is a one-byte type tag followed by the IP packet.
	packetInfo, err := ParseICMPPacket(reply[1:])
	if err != nil {
		t.Fatal(err)
	}
	if packetInfo.ICMPType != icmpv6TypeTimeExceeded || packetInfo.ICMPCode != 0 {
		t.Fatalf("expected IPv6 time exceeded reply, got type=%d code=%d", packetInfo.ICMPType, packetInfo.ICMPCode)
	}
	if packetInfo.SourceIP != target || packetInfo.Destination != source {
		t.Fatalf("unexpected TTL exceeded routing: src=%s dst=%s", packetInfo.SourceIP, packetInfo.Destination)
	}
}
// TestICMPBridgeDropsNonEcho feeds the bridge an IPv4 ICMP message that is not
// an echo request (type 3, destination unreachable). It expects the message to
// be dropped without error: no route lookup and no datagram sent back.
func TestICMPBridgeDropsNonEcho(t *testing.T) {
	const destinationUnreachable = 3
	sender := &captureDatagramSender{}
	lookups := 0
	countingRouter := &testRouter{
		preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
			lookups++
			return nil, nil
		},
	}
	inboundInstance := &Inbound{
		Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
		router:  countingRouter,
	}
	bridge := NewICMPBridge(inboundInstance, sender, icmpWireV2)
	nonEcho := buildIPv4ICMPPacket(netip.MustParseAddr("198.18.0.2"), netip.MustParseAddr("1.1.1.1"), destinationUnreachable, 0, 1, 1)
	err := bridge.HandleV2(context.Background(), DatagramV2TypeIP, nonEcho)
	if err != nil {
		t.Fatal(err)
	}
	if lookups != 0 {
		t.Fatalf("expected no route lookup, got %d", lookups)
	}
	if sent := len(sender.sent); sent != 0 {
		t.Fatalf("expected no sender datagrams, got %d", sent)
	}
}
// buildEchoReply is a reply function for fakeDirectRouteDestination. It parses
// packet with ParseICMPPacket and returns an echo reply that swaps source and
// destination and keeps the identifier and sequence number. The reply uses
// ICMP type 0 for IPv4 and ICMPv6 type 129 for IPv6. It panics on a parse
// error or an unsupported IP version, which is acceptable in a test helper.
func buildEchoReply(packet []byte) []byte {
	info, err := ParseICMPPacket(packet)
	if err != nil {
		panic(err)
	}
	if info.IPVersion == 4 {
		return buildIPv4ICMPPacket(info.Destination, info.SourceIP, 0, 0, info.Identifier, info.Sequence)
	}
	if info.IPVersion == 6 {
		return buildIPv6ICMPPacket(info.Destination, info.SourceIP, 129, 0, info.Identifier, info.Sequence)
	}
	panic("unsupported version")
}
func buildIPv4ICMPPacket(source, destination netip.Addr, icmpType, icmpCode uint8, identifier, sequence uint16) []byte {
packet := make([]byte, 28)
packet[0] = 0x45
binary.BigEndian.PutUint16(packet[2:4], uint16(len(packet)))
packet[8] = 64
packet[9] = 1
copy(packet[12:16], source.AsSlice())
copy(packet[16:20], destination.AsSlice())
packet[20] = icmpType
packet[21] = icmpCode
binary.BigEndian.PutUint16(packet[24:26], identifier)
binary.BigEndian.PutUint16(packet[26:28], sequence)
return packet
}
func buildIPv6ICMPPacket(source, destination netip.Addr, icmpType, icmpCode uint8, identifier, sequence uint16) []byte {
packet := make([]byte, 48)
packet[0] = 0x60
binary.BigEndian.PutUint16(packet[4:6], 8)
packet[6] = 58
packet[7] = 64
copy(packet[8:24], source.AsSlice())
copy(packet[24:40], destination.AsSlice())
packet[40] = icmpType
packet[41] = icmpCode
binary.BigEndian.PutUint16(packet[44:46], identifier)
binary.BigEndian.PutUint16(packet[46:48], sequence)
return packet
}