Add domain sniffer
This commit is contained in:
58
common/sniff/dns.go
Normal file
58
common/sniff/dns.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package sniff
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/task"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
)
|
||||
|
||||
func StreamDomainNameQuery(readCtx context.Context, reader io.Reader) (*adapter.InboundContext, error) {
|
||||
var length uint16
|
||||
err := binary.Read(reader, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if length > 512 {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
_buffer := buf.StackNewSize(int(length))
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
|
||||
readCtx, cancel := context.WithTimeout(readCtx, time.Millisecond*100)
|
||||
err = task.Run(readCtx, func() error {
|
||||
return common.Error(buffer.ReadFullFrom(reader, buffer.FreeLen()))
|
||||
})
|
||||
cancel()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return DomainNameQuery(readCtx, buffer.Bytes())
|
||||
}
|
||||
|
||||
func DomainNameQuery(ctx context.Context, packet []byte) (*adapter.InboundContext, error) {
|
||||
var parser dnsmessage.Parser
|
||||
_, err := parser.Start(packet)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
question, err := parser.Question()
|
||||
if err != nil {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
domain := question.Name.String()
|
||||
if question.Class == dnsmessage.ClassINET && (question.Type == dnsmessage.TypeA || question.Type == dnsmessage.TypeAAAA) && IsDomainName(domain) {
|
||||
return &adapter.InboundContext{Protocol: C.ProtocolDNS, Domain: domain}, nil
|
||||
}
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
6
common/sniff/domain.go
Normal file
6
common/sniff/domain.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package sniff
|
||||
|
||||
import _ "unsafe" // for linkname
|
||||
|
||||
//go:linkname IsDomainName net.isDomainName
|
||||
func IsDomainName(domain string) bool
|
||||
19
common/sniff/http.go
Normal file
19
common/sniff/http.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package sniff
|
||||
|
||||
import (
|
||||
std_bufio "bufio"
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing/protocol/http"
|
||||
)
|
||||
|
||||
func HTTPHost(ctx context.Context, reader io.Reader) (*adapter.InboundContext, error) {
|
||||
request, err := http.ReadRequest(std_bufio.NewReader(reader))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &adapter.InboundContext{Protocol: C.ProtocolHTTP, Domain: request.Host}, nil
|
||||
}
|
||||
148
common/sniff/internal/qtls/qtls.go
Normal file
148
common/sniff/internal/qtls/qtls.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
const (
|
||||
VersionDraft29 = 0xff00001d
|
||||
Version1 = 0x1
|
||||
Version2 = 0x709a50c4
|
||||
)
|
||||
|
||||
var (
|
||||
SaltOld = []byte{0xaf, 0xbf, 0xec, 0x28, 0x99, 0x93, 0xd2, 0x4c, 0x9e, 0x97, 0x86, 0xf1, 0x9c, 0x61, 0x11, 0xe0, 0x43, 0x90, 0xa8, 0x99}
|
||||
SaltV1 = []byte{0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a}
|
||||
SaltV2 = []byte{0xa7, 0x07, 0xc2, 0x03, 0xa5, 0x9b, 0x47, 0x18, 0x4a, 0x1d, 0x62, 0xca, 0x57, 0x04, 0x06, 0xea, 0x7a, 0xe3, 0xe5, 0xd3}
|
||||
)
|
||||
|
||||
const (
|
||||
HKDFLabelKeyV1 = "quic key"
|
||||
HKDFLabelKeyV2 = "quicv2 key"
|
||||
HKDFLabelIVV1 = "quic iv"
|
||||
HKDFLabelIVV2 = "quicv2 iv"
|
||||
HKDFLabelHeaderProtectionV1 = "quic hp"
|
||||
HKDFLabelHeaderProtectionV2 = "quicv2 hp"
|
||||
)
|
||||
|
||||
func AEADAESGCMTLS13(key, nonceMask []byte) cipher.AEAD {
|
||||
if len(nonceMask) != 12 {
|
||||
panic("tls: internal error: wrong nonce length")
|
||||
}
|
||||
aes, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
aead, err := cipher.NewGCM(aes)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ret := &xorNonceAEAD{aead: aead}
|
||||
copy(ret.nonceMask[:], nonceMask)
|
||||
return ret
|
||||
}
|
||||
|
||||
type xorNonceAEAD struct {
|
||||
nonceMask [12]byte
|
||||
aead cipher.AEAD
|
||||
}
|
||||
|
||||
func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number
|
||||
func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() }
|
||||
func (f *xorNonceAEAD) explicitNonceLen() int { return 0 }
|
||||
|
||||
func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte {
|
||||
for i, b := range nonce {
|
||||
f.nonceMask[4+i] ^= b
|
||||
}
|
||||
result := f.aead.Seal(out, f.nonceMask[:], plaintext, additionalData)
|
||||
for i, b := range nonce {
|
||||
f.nonceMask[4+i] ^= b
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (f *xorNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) {
|
||||
for i, b := range nonce {
|
||||
f.nonceMask[4+i] ^= b
|
||||
}
|
||||
result, err := f.aead.Open(out, f.nonceMask[:], ciphertext, additionalData)
|
||||
for i, b := range nonce {
|
||||
f.nonceMask[4+i] ^= b
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func HKDFExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte {
|
||||
b := make([]byte, 3, 3+6+len(label)+1+len(context))
|
||||
binary.BigEndian.PutUint16(b, uint16(length))
|
||||
b[2] = uint8(6 + len(label))
|
||||
b = append(b, []byte("tls13 ")...)
|
||||
b = append(b, []byte(label)...)
|
||||
b = b[:3+6+len(label)+1]
|
||||
b[3+6+len(label)] = uint8(len(context))
|
||||
b = append(b, context...)
|
||||
out := make([]byte, length)
|
||||
n, err := hkdf.Expand(hash.New, secret, b).Read(out)
|
||||
if err != nil || n != length {
|
||||
panic("quic: HKDF-Expand-Label invocation failed unexpectedly")
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func ReadUvarint(r io.ByteReader) (uint64, error) {
|
||||
firstByte, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
// the first two bits of the first byte encode the length
|
||||
len := 1 << ((firstByte & 0xc0) >> 6)
|
||||
b1 := firstByte & (0xff - 0xc0)
|
||||
if len == 1 {
|
||||
return uint64(b1), nil
|
||||
}
|
||||
b2, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len == 2 {
|
||||
return uint64(b2) + uint64(b1)<<8, nil
|
||||
}
|
||||
b3, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
b4, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len == 4 {
|
||||
return uint64(b4) + uint64(b3)<<8 + uint64(b2)<<16 + uint64(b1)<<24, nil
|
||||
}
|
||||
b5, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
b6, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
b7, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
b8, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return uint64(b8) + uint64(b7)<<8 + uint64(b6)<<16 + uint64(b5)<<24 + uint64(b4)<<32 + uint64(b3)<<40 + uint64(b2)<<48 + uint64(b1)<<56, nil
|
||||
}
|
||||
187
common/sniff/quic.go
Normal file
187
common/sniff/quic.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package sniff
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/aes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/sniff/internal/qtls"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
func QUICClientHello(ctx context.Context, packet []byte) (*adapter.InboundContext, error) {
|
||||
reader := bytes.NewReader(packet)
|
||||
|
||||
typeByte, err := reader.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if typeByte&0x80 == 0 || typeByte&0x40 == 0 {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
var versionNumber uint32
|
||||
err = binary.Read(reader, binary.BigEndian, &versionNumber)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if versionNumber != qtls.VersionDraft29 && versionNumber != qtls.Version1 && versionNumber != qtls.Version2 {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
if (typeByte&0x30)>>4 == 0x0 {
|
||||
} else if (typeByte&0x30)>>4 != 0x01 {
|
||||
// 0-rtt
|
||||
} else {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
|
||||
destConnIDLen, err := reader.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
destConnID := make([]byte, destConnIDLen)
|
||||
_, err = io.ReadFull(reader, destConnID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
srcConnIDLen, err := reader.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = io.CopyN(io.Discard, reader, int64(srcConnIDLen))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tokenLen, err := qtls.ReadUvarint(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = io.CopyN(io.Discard, reader, int64(tokenLen))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
packetLen, err := qtls.ReadUvarint(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hdrLen := int(reader.Size()) - reader.Len()
|
||||
if hdrLen != len(packet)-int(packetLen) {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
|
||||
_, err = io.CopyN(io.Discard, reader, 4)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pnBytes := make([]byte, aes.BlockSize)
|
||||
_, err = io.ReadFull(reader, pnBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var salt []byte
|
||||
switch versionNumber {
|
||||
case qtls.Version1:
|
||||
salt = qtls.SaltV1
|
||||
case qtls.Version2:
|
||||
salt = qtls.SaltV2
|
||||
default:
|
||||
salt = qtls.SaltOld
|
||||
}
|
||||
var hkdfHeaderProtectionLabel string
|
||||
switch versionNumber {
|
||||
case qtls.Version2:
|
||||
hkdfHeaderProtectionLabel = qtls.HKDFLabelHeaderProtectionV2
|
||||
default:
|
||||
hkdfHeaderProtectionLabel = qtls.HKDFLabelHeaderProtectionV1
|
||||
}
|
||||
initialSecret := hkdf.Extract(crypto.SHA256.New, destConnID, salt)
|
||||
secret := qtls.HKDFExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size())
|
||||
hpKey := qtls.HKDFExpandLabel(crypto.SHA256, secret, []byte{}, hkdfHeaderProtectionLabel, 16)
|
||||
block, err := aes.NewCipher(hpKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mask := make([]byte, aes.BlockSize)
|
||||
block.Encrypt(mask, pnBytes)
|
||||
newPacket := make([]byte, len(packet))
|
||||
copy(newPacket, packet)
|
||||
newPacket[0] ^= mask[0] & 0xf
|
||||
for i := range newPacket[hdrLen : hdrLen+4] {
|
||||
newPacket[hdrLen+i] ^= mask[i+1]
|
||||
}
|
||||
packetNumberLength := newPacket[0]&0x3 + 1
|
||||
if packetNumberLength != 1 {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
packetNumber := newPacket[hdrLen]
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if packetNumber != 0 {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
|
||||
extHdrLen := hdrLen + int(packetNumberLength)
|
||||
copy(newPacket[extHdrLen:hdrLen+4], packet[extHdrLen:])
|
||||
data := newPacket[extHdrLen : int(packetLen)+hdrLen]
|
||||
|
||||
var keyLabel string
|
||||
var ivLabel string
|
||||
switch versionNumber {
|
||||
case qtls.Version2:
|
||||
keyLabel = qtls.HKDFLabelKeyV2
|
||||
ivLabel = qtls.HKDFLabelIVV2
|
||||
default:
|
||||
keyLabel = qtls.HKDFLabelKeyV1
|
||||
ivLabel = qtls.HKDFLabelIVV1
|
||||
}
|
||||
|
||||
key := qtls.HKDFExpandLabel(crypto.SHA256, secret, []byte{}, keyLabel, 16)
|
||||
iv := qtls.HKDFExpandLabel(crypto.SHA256, secret, []byte{}, ivLabel, 12)
|
||||
cipher := qtls.AEADAESGCMTLS13(key, iv)
|
||||
nonce := make([]byte, int32(cipher.NonceSize()))
|
||||
binary.BigEndian.PutUint64(nonce[len(nonce)-8:], uint64(packetNumber))
|
||||
decrypted, err := cipher.Open(newPacket[extHdrLen:extHdrLen], nonce, data, newPacket[:extHdrLen])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decryptedReader := bytes.NewReader(decrypted)
|
||||
frameType, err := decryptedReader.ReadByte()
|
||||
if frameType != 0x6 {
|
||||
// not crypto frame
|
||||
return &adapter.InboundContext{Protocol: C.ProtocolQUIC}, nil
|
||||
}
|
||||
_, err = qtls.ReadUvarint(decryptedReader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = qtls.ReadUvarint(decryptedReader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsHdr := make([]byte, 5)
|
||||
tlsHdr[0] = 0x16
|
||||
binary.BigEndian.PutUint16(tlsHdr[1:], uint16(0x0303))
|
||||
binary.BigEndian.PutUint16(tlsHdr[3:], uint16(decryptedReader.Len()))
|
||||
metadata, err := TLSClientHello(ctx, io.MultiReader(bytes.NewReader(tlsHdr), decryptedReader))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
metadata.Protocol = C.ProtocolQUIC
|
||||
return metadata, nil
|
||||
}
|
||||
23
common/sniff/quic_test.go
Normal file
23
common/sniff/quic_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package sniff
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSniffQUICv1(t *testing.T) {
|
||||
pkt, err := hex.DecodeString("cc0000000108d2dc7bad02241f5003796e71004215a71bfcb05159416c724be418537389acdd9a4047306283dcb4d7a9cad5cc06322042d204da67a8dbaa328ab476bb428b48fd001501863afd203f8d4ef085629d664f1a734a65969a47e4a63d4e01a21f18c1d90db0c027180906dc135f9ae421bb8617314c8d54c175fef3d3383d310d0916ebcbd6eed9329befbbb109d8fd4af1d2cf9d6adce8e6c1260a7f8256e273e326da0aa7cc148d76e7a08489dc9d52ade89c027cbc3491ada46417c2c04e2ca768e9a7dd6aa00c594e48b678927325da796817693499bb727050cb3baf3d3291a397c3a8d868e8ec7b8f7295e347455c9dadbe2252ae917ac793d958c7fb8a3d2cdb34e3891eb4286f18617556ff7216dd60256aa5b1d11ff4753459fc5f9dedf11d483a26a0835dc6cd50e1c1f54f86e8f1e502821183cd874f6447a74e818bf3445c7795acf4559d1c1fac474911d2ead5c8d23e4aa4f67afb66efe305a30a0b5d825679b31ddc186cbea936535795c7e8c378c87b8c5adc065154d15bae8f85ac8fec2da40c3aa623b682a065440831555011d7647cde44446a0fb4cf5892f2c088ae1920643094be72e3c499fe8d265caf939e8ab607a5b9317917d2a32a812e8a0e6a2f84721bbb5984ffd242838f705d13f4cfb249bc6a5c80d58ac2595edf56648ec3fe21d787573c253a79805252d6d81e26d367d4ff29ef66b5fe8992086af7bada8cad10b82a7c0dc406c5b6d0c5ec3c583e767f759ce08cad6c3c8f91e5a8")
|
||||
require.NoError(t, err)
|
||||
metadata, err := QUICClientHello(context.Background(), pkt)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, metadata.Domain, "cloudflare-quic.com")
|
||||
}
|
||||
|
||||
func FuzzSniffQUIC(f *testing.F) {
|
||||
f.Fuzz(func(t *testing.T, data []byte) {
|
||||
QUICClientHello(context.Background(), data)
|
||||
})
|
||||
}
|
||||
36
common/sniff/sniff.go
Normal file
36
common/sniff/sniff.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package sniff
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
)
|
||||
|
||||
type (
|
||||
StreamSniffer = func(ctx context.Context, reader io.Reader) (*adapter.InboundContext, error)
|
||||
PacketSniffer = func(ctx context.Context, packet []byte) (*adapter.InboundContext, error)
|
||||
)
|
||||
|
||||
func PeekStream(ctx context.Context, reader io.Reader, sniffers ...StreamSniffer) (*adapter.InboundContext, error) {
|
||||
for _, sniffer := range sniffers {
|
||||
sniffMetadata, err := sniffer(ctx, reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sniffMetadata, nil
|
||||
}
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
|
||||
func PeekPacket(ctx context.Context, packet []byte, sniffers ...PacketSniffer) (*adapter.InboundContext, error) {
|
||||
for _, sniffer := range sniffers {
|
||||
sniffMetadata, err := sniffer(ctx, packet)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sniffMetadata, nil
|
||||
}
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
28
common/sniff/tls.go
Normal file
28
common/sniff/tls.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package sniff
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
)
|
||||
|
||||
func TLSClientHello(ctx context.Context, reader io.Reader) (*adapter.InboundContext, error) {
|
||||
var clientHello *tls.ClientHelloInfo
|
||||
err := tls.Server(bufio.NewReadOnlyConn(reader), &tls.Config{
|
||||
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
clientHello = argHello
|
||||
return nil, nil
|
||||
},
|
||||
}).HandshakeContext(ctx)
|
||||
if clientHello != nil {
|
||||
return &adapter.InboundContext{Protocol: C.ProtocolTLS, Domain: clientHello.ServerName}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func Packet() {
|
||||
}
|
||||
Reference in New Issue
Block a user