//go:build with_cloudflared package cloudflare import ( "context" "encoding/binary" "net/netip" "sync" "time" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-tun" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" ) const ( icmpFlowTimeout = 30 * time.Second icmpTraceIdentityLength = 16 + 8 + 1 defaultICMPPacketTTL = 64 icmpErrorHeaderLen = 8 icmpv4TypeEchoRequest = 8 icmpv4TypeEchoReply = 0 icmpv4TypeTimeExceeded = 11 icmpv6TypeEchoRequest = 128 icmpv6TypeEchoReply = 129 icmpv6TypeTimeExceeded = 3 ) type ICMPTraceContext struct { Traced bool Identity []byte } type ICMPFlowKey struct { IPVersion uint8 SourceIP netip.Addr Destination netip.Addr } type ICMPRequestKey struct { Flow ICMPFlowKey Identifier uint16 Sequence uint16 } type ICMPPacketInfo struct { IPVersion uint8 Protocol uint8 SourceIP netip.Addr Destination netip.Addr ICMPType uint8 ICMPCode uint8 Identifier uint16 Sequence uint16 IPv4HeaderLen int IPv4TTL uint8 IPv6HopLimit uint8 RawPacket []byte } func (i ICMPPacketInfo) FlowKey() ICMPFlowKey { return ICMPFlowKey{ IPVersion: i.IPVersion, SourceIP: i.SourceIP, Destination: i.Destination, } } func (i ICMPPacketInfo) RequestKey() ICMPRequestKey { return ICMPRequestKey{ Flow: i.FlowKey(), Identifier: i.Identifier, Sequence: i.Sequence, } } func (i ICMPPacketInfo) ReplyRequestKey() ICMPRequestKey { return ICMPRequestKey{ Flow: ICMPFlowKey{ IPVersion: i.IPVersion, SourceIP: i.Destination, Destination: i.SourceIP, }, Identifier: i.Identifier, Sequence: i.Sequence, } } func (i ICMPPacketInfo) IsEchoRequest() bool { switch i.IPVersion { case 4: return i.ICMPType == icmpv4TypeEchoRequest && i.ICMPCode == 0 case 6: return i.ICMPType == icmpv6TypeEchoRequest && i.ICMPCode == 0 default: return false } } func (i ICMPPacketInfo) IsEchoReply() bool { switch i.IPVersion { case 4: return i.ICMPType == icmpv4TypeEchoReply && i.ICMPCode == 0 case 6: return i.ICMPType == icmpv6TypeEchoReply && i.ICMPCode == 0 default: return false } } func (i ICMPPacketInfo) TTL() uint8 { if i.IPVersion == 4 { return i.IPv4TTL } return i.IPv6HopLimit } func (i ICMPPacketInfo) TTLExpired() bool { return i.TTL() <= 1 } func (i *ICMPPacketInfo) DecrementTTL() error { switch i.IPVersion { case 4: if i.IPv4TTL == 0 || i.IPv4HeaderLen < 20 || len(i.RawPacket) < i.IPv4HeaderLen { return E.New("invalid IPv4 packet TTL state") } i.IPv4TTL-- i.RawPacket[8] = i.IPv4TTL binary.BigEndian.PutUint16(i.RawPacket[10:12], 0) binary.BigEndian.PutUint16(i.RawPacket[10:12], checksum(i.RawPacket[:i.IPv4HeaderLen], 0)) case 6: if i.IPv6HopLimit == 0 || len(i.RawPacket) < 40 { return E.New("invalid IPv6 packet hop limit state") } i.IPv6HopLimit-- i.RawPacket[7] = i.IPv6HopLimit default: return E.New("unsupported IP version: ", i.IPVersion) } return nil } type icmpWireVersion uint8 const ( icmpWireV2 icmpWireVersion = iota + 1 icmpWireV3 ) type icmpFlowState struct { writer *ICMPReplyWriter } type ICMPReplyWriter struct { sender DatagramSender wireVersion icmpWireVersion access sync.Mutex traces map[ICMPRequestKey]ICMPTraceContext } func NewICMPReplyWriter(sender DatagramSender, wireVersion icmpWireVersion) *ICMPReplyWriter { return &ICMPReplyWriter{ sender: sender, wireVersion: wireVersion, traces: make(map[ICMPRequestKey]ICMPTraceContext), } } func (w *ICMPReplyWriter) RegisterRequestTrace(packetInfo ICMPPacketInfo, traceContext ICMPTraceContext) { if !traceContext.Traced { return } w.access.Lock() w.traces[packetInfo.RequestKey()] = traceContext w.access.Unlock() } func (w *ICMPReplyWriter) WritePacket(packet []byte) error { packetInfo, err := ParseICMPPacket(packet) if err != nil { return err } if !packetInfo.IsEchoReply() { return nil } requestKey := packetInfo.ReplyRequestKey() w.access.Lock() traceContext, loaded := w.traces[requestKey] if loaded { delete(w.traces, requestKey) } w.access.Unlock() datagram, err := encodeICMPDatagram(packetInfo.RawPacket, w.wireVersion, traceContext) if err != nil { return err } return w.sender.SendDatagram(datagram) } type ICMPBridge struct { inbound *Inbound sender DatagramSender wireVersion icmpWireVersion routeMapping *tun.DirectRouteMapping flowAccess sync.Mutex flows map[ICMPFlowKey]*icmpFlowState } func NewICMPBridge(inbound *Inbound, sender DatagramSender, wireVersion icmpWireVersion) *ICMPBridge { return &ICMPBridge{ inbound: inbound, sender: sender, wireVersion: wireVersion, routeMapping: tun.NewDirectRouteMapping(icmpFlowTimeout), flows: make(map[ICMPFlowKey]*icmpFlowState), } } func (b *ICMPBridge) HandleV2(ctx context.Context, datagramType DatagramV2Type, payload []byte) error { traceContext := ICMPTraceContext{} switch datagramType { case DatagramV2TypeIP: case DatagramV2TypeIPWithTrace: if len(payload) < icmpTraceIdentityLength { return E.New("icmp trace payload is too short") } traceContext.Traced = true traceContext.Identity = append([]byte(nil), payload[len(payload)-icmpTraceIdentityLength:]...) payload = payload[:len(payload)-icmpTraceIdentityLength] default: return E.New("unsupported v2 icmp datagram type: ", datagramType) } return b.handlePacket(ctx, payload, traceContext) } func (b *ICMPBridge) HandleV3(ctx context.Context, payload []byte) error { return b.handlePacket(ctx, payload, ICMPTraceContext{}) } func (b *ICMPBridge) handlePacket(ctx context.Context, payload []byte, traceContext ICMPTraceContext) error { packetInfo, err := ParseICMPPacket(payload) if err != nil { return err } if !packetInfo.IsEchoRequest() { return nil } if packetInfo.TTLExpired() { ttlExceededPacket, err := buildICMPTTLExceededPacket(packetInfo, maxEncodedICMPPacketLen(b.wireVersion, traceContext)) if err != nil { return err } datagram, err := encodeICMPDatagram(ttlExceededPacket, b.wireVersion, traceContext) if err != nil { return err } return b.sender.SendDatagram(datagram) } if err := packetInfo.DecrementTTL(); err != nil { return err } state := b.getFlowState(packetInfo.FlowKey()) if traceContext.Traced { state.writer.RegisterRequestTrace(packetInfo, traceContext) } action, err := b.routeMapping.Lookup(tun.DirectRouteSession{ Source: packetInfo.SourceIP, Destination: packetInfo.Destination, }, func(timeout time.Duration) (tun.DirectRouteDestination, error) { metadata := adapter.InboundContext{ Inbound: b.inbound.Tag(), InboundType: b.inbound.Type(), IPVersion: packetInfo.IPVersion, Network: N.NetworkICMP, Source: M.SocksaddrFrom(packetInfo.SourceIP, 0), Destination: M.SocksaddrFrom(packetInfo.Destination, 0), OriginDestination: M.SocksaddrFrom(packetInfo.Destination, 0), } return b.inbound.router.PreMatch(metadata, state.writer, timeout, false) }) if err != nil { return nil } return action.WritePacket(buf.As(packetInfo.RawPacket).ToOwned()) } func (b *ICMPBridge) getFlowState(key ICMPFlowKey) *icmpFlowState { b.flowAccess.Lock() defer b.flowAccess.Unlock() state, loaded := b.flows[key] if loaded { return state } state = &icmpFlowState{ writer: NewICMPReplyWriter(b.sender, b.wireVersion), } b.flows[key] = state return state } func ParseICMPPacket(packet []byte) (ICMPPacketInfo, error) { if len(packet) < 1 { return ICMPPacketInfo{}, E.New("empty IP packet") } version := packet[0] >> 4 switch version { case 4: return parseIPv4ICMPPacket(packet) case 6: return parseIPv6ICMPPacket(packet) default: return ICMPPacketInfo{}, E.New("unsupported IP version: ", version) } } func parseIPv4ICMPPacket(packet []byte) (ICMPPacketInfo, error) { if len(packet) < 20 { return ICMPPacketInfo{}, E.New("IPv4 packet too short") } headerLen := int(packet[0]&0x0F) * 4 if headerLen < 20 || len(packet) < headerLen+8 { return ICMPPacketInfo{}, E.New("invalid IPv4 header length") } if packet[9] != 1 { return ICMPPacketInfo{}, E.New("IPv4 packet is not ICMP") } sourceIP, ok := netip.AddrFromSlice(packet[12:16]) if !ok { return ICMPPacketInfo{}, E.New("invalid IPv4 source address") } destinationIP, ok := netip.AddrFromSlice(packet[16:20]) if !ok { return ICMPPacketInfo{}, E.New("invalid IPv4 destination address") } return ICMPPacketInfo{ IPVersion: 4, Protocol: 1, SourceIP: sourceIP, Destination: destinationIP, ICMPType: packet[headerLen], ICMPCode: packet[headerLen+1], Identifier: binary.BigEndian.Uint16(packet[headerLen+4 : headerLen+6]), Sequence: binary.BigEndian.Uint16(packet[headerLen+6 : headerLen+8]), IPv4HeaderLen: headerLen, IPv4TTL: packet[8], RawPacket: append([]byte(nil), packet...), }, nil } func parseIPv6ICMPPacket(packet []byte) (ICMPPacketInfo, error) { if len(packet) < 48 { return ICMPPacketInfo{}, E.New("IPv6 packet too short") } if packet[6] != 58 { return ICMPPacketInfo{}, E.New("IPv6 packet is not ICMP") } sourceIP, ok := netip.AddrFromSlice(packet[8:24]) if !ok { return ICMPPacketInfo{}, E.New("invalid IPv6 source address") } destinationIP, ok := netip.AddrFromSlice(packet[24:40]) if !ok { return ICMPPacketInfo{}, E.New("invalid IPv6 destination address") } return ICMPPacketInfo{ IPVersion: 6, Protocol: 58, SourceIP: sourceIP, Destination: destinationIP, ICMPType: packet[40], ICMPCode: packet[41], Identifier: binary.BigEndian.Uint16(packet[44:46]), Sequence: binary.BigEndian.Uint16(packet[46:48]), IPv6HopLimit: packet[7], RawPacket: append([]byte(nil), packet...), }, nil } func maxEncodedICMPPacketLen(wireVersion icmpWireVersion, traceContext ICMPTraceContext) int { limit := maxV3UDPPayloadLen switch wireVersion { case icmpWireV2: limit -= typeIDLength if traceContext.Traced { limit -= len(traceContext.Identity) } case icmpWireV3: limit -= 1 default: return 0 } if limit < 0 { return 0 } return limit } func buildICMPTTLExceededPacket(packetInfo ICMPPacketInfo, maxPacketLen int) ([]byte, error) { switch packetInfo.IPVersion { case 4: return buildIPv4ICMPTTLExceededPacket(packetInfo, maxPacketLen) case 6: return buildIPv6ICMPTTLExceededPacket(packetInfo, maxPacketLen) default: return nil, E.New("unsupported IP version: ", packetInfo.IPVersion) } } func buildIPv4ICMPTTLExceededPacket(packetInfo ICMPPacketInfo, maxPacketLen int) ([]byte, error) { const headerLen = 20 if !packetInfo.SourceIP.Is4() || !packetInfo.Destination.Is4() { return nil, E.New("TTL exceeded packet requires IPv4 addresses") } if maxPacketLen <= headerLen+icmpErrorHeaderLen { return nil, E.New("TTL exceeded packet size limit is too small") } quotedLength := min(len(packetInfo.RawPacket), maxPacketLen-headerLen-icmpErrorHeaderLen) packet := make([]byte, headerLen+icmpErrorHeaderLen+quotedLength) packet[0] = 0x45 binary.BigEndian.PutUint16(packet[2:4], uint16(len(packet))) packet[8] = defaultICMPPacketTTL packet[9] = 1 copy(packet[12:16], packetInfo.Destination.AsSlice()) copy(packet[16:20], packetInfo.SourceIP.AsSlice()) packet[20] = icmpv4TypeTimeExceeded packet[21] = 0 copy(packet[headerLen+icmpErrorHeaderLen:], packetInfo.RawPacket[:quotedLength]) binary.BigEndian.PutUint16(packet[22:24], checksum(packet[20:], 0)) binary.BigEndian.PutUint16(packet[10:12], checksum(packet[:headerLen], 0)) return packet, nil } func buildIPv6ICMPTTLExceededPacket(packetInfo ICMPPacketInfo, maxPacketLen int) ([]byte, error) { const headerLen = 40 if !packetInfo.SourceIP.Is6() || !packetInfo.Destination.Is6() { return nil, E.New("TTL exceeded packet requires IPv6 addresses") } if maxPacketLen <= headerLen+icmpErrorHeaderLen { return nil, E.New("TTL exceeded packet size limit is too small") } quotedLength := min(len(packetInfo.RawPacket), maxPacketLen-headerLen-icmpErrorHeaderLen) packet := make([]byte, headerLen+icmpErrorHeaderLen+quotedLength) packet[0] = 0x60 binary.BigEndian.PutUint16(packet[4:6], uint16(icmpErrorHeaderLen+quotedLength)) packet[6] = 58 packet[7] = defaultICMPPacketTTL copy(packet[8:24], packetInfo.Destination.AsSlice()) copy(packet[24:40], packetInfo.SourceIP.AsSlice()) packet[40] = icmpv6TypeTimeExceeded packet[41] = 0 copy(packet[headerLen+icmpErrorHeaderLen:], packetInfo.RawPacket[:quotedLength]) binary.BigEndian.PutUint16(packet[42:44], checksum(packet[40:], ipv6PseudoHeaderChecksum(packetInfo.Destination, packetInfo.SourceIP, uint32(icmpErrorHeaderLen+quotedLength), 58))) return packet, nil } func encodeICMPDatagram(packet []byte, wireVersion icmpWireVersion, traceContext ICMPTraceContext) ([]byte, error) { switch wireVersion { case icmpWireV2: return encodeV2ICMPDatagram(packet, traceContext) case icmpWireV3: return encodeV3ICMPDatagram(packet), nil default: return nil, E.New("unsupported icmp wire version: ", wireVersion) } } func ipv6PseudoHeaderChecksum(source, destination netip.Addr, payloadLength uint32, nextHeader uint8) uint32 { var sum uint32 sum = checksumSum(source.AsSlice(), sum) sum = checksumSum(destination.AsSlice(), sum) var lengthBytes [4]byte binary.BigEndian.PutUint32(lengthBytes[:], payloadLength) sum = checksumSum(lengthBytes[:], sum) sum = checksumSum([]byte{0, 0, 0, nextHeader}, sum) return sum } func checksumSum(data []byte, sum uint32) uint32 { for len(data) >= 2 { sum += uint32(binary.BigEndian.Uint16(data[:2])) data = data[2:] } if len(data) == 1 { sum += uint32(data[0]) << 8 } return sum } func checksum(data []byte, initial uint32) uint16 { sum := checksumSum(data, initial) for sum > 0xffff { sum = (sum >> 16) + (sum & 0xffff) } return ^uint16(sum) } func encodeV2ICMPDatagram(packet []byte, traceContext ICMPTraceContext) ([]byte, error) { if traceContext.Traced { data := make([]byte, 0, len(packet)+len(traceContext.Identity)+1) data = append(data, packet...) data = append(data, traceContext.Identity...) data = append(data, byte(DatagramV2TypeIPWithTrace)) return data, nil } data := make([]byte, 0, len(packet)+1) data = append(data, packet...) data = append(data, byte(DatagramV2TypeIP)) return data, nil } func encodeV3ICMPDatagram(packet []byte) []byte { data := make([]byte, 0, len(packet)+1) data = append(data, byte(DatagramV3TypeICMP)) data = append(data, packet...) return data }