From 1af14a0237949b4ed17f2c99193087ad1a4949cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 28 Jan 2026 18:22:28 +0800 Subject: [PATCH] Remove varbin usages --- adapter/experimental.go | 28 +- common/geosite/compat_test.go | 234 ++++++++++++ common/geosite/reader.go | 31 +- common/geosite/writer.go | 18 +- common/srs/binary.go | 71 +++- common/srs/compat_test.go | 494 ++++++++++++++++++++++++++ common/srs/ip_cidr.go | 19 +- common/srs/ip_set.go | 65 ++-- experimental/libbox/profile_import.go | 46 ++- go.mod | 2 +- go.sum | 4 +- 11 files changed, 953 insertions(+), 59 deletions(-) create mode 100644 common/geosite/compat_test.go create mode 100644 common/srs/compat_test.go diff --git a/adapter/experimental.go b/adapter/experimental.go index d4d37922b..1bd8d2d92 100644 --- a/adapter/experimental.go +++ b/adapter/experimental.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/binary" + "io" "time" "github.com/sagernet/sing/common/observable" @@ -68,7 +69,11 @@ func (s *SavedBinary) MarshalBinary() ([]byte, error) { if err != nil { return nil, err } - err = varbin.Write(&buffer, binary.BigEndian, s.Content) + _, err = varbin.WriteUvarint(&buffer, uint64(len(s.Content))) + if err != nil { + return nil, err + } + _, err = buffer.Write(s.Content) if err != nil { return nil, err } @@ -76,7 +81,11 @@ func (s *SavedBinary) MarshalBinary() ([]byte, error) { if err != nil { return nil, err } - err = varbin.Write(&buffer, binary.BigEndian, s.LastEtag) + _, err = varbin.WriteUvarint(&buffer, uint64(len(s.LastEtag))) + if err != nil { + return nil, err + } + _, err = buffer.WriteString(s.LastEtag) if err != nil { return nil, err } @@ -90,7 +99,12 @@ func (s *SavedBinary) UnmarshalBinary(data []byte) error { if err != nil { return err } - err = varbin.Read(reader, binary.BigEndian, &s.Content) + contentLength, err := binary.ReadUvarint(reader) + if err != nil { + return err + } + s.Content = make([]byte, contentLength) + _, err = io.ReadFull(reader, s.Content) if err != nil { return err } @@ -100,10 +114,16 @@ func (s *SavedBinary) UnmarshalBinary(data []byte) error { return err } s.LastUpdated = time.Unix(lastUpdated, 0) - err = varbin.Read(reader, binary.BigEndian, &s.LastEtag) + etagLength, err := binary.ReadUvarint(reader) if err != nil { return err } + etagBytes := make([]byte, etagLength) + _, err = io.ReadFull(reader, etagBytes) + if err != nil { + return err + } + s.LastEtag = string(etagBytes) return nil } diff --git a/common/geosite/compat_test.go b/common/geosite/compat_test.go new file mode 100644 index 000000000..1a55c6442 --- /dev/null +++ b/common/geosite/compat_test.go @@ -0,0 +1,234 @@ +package geosite + +import ( + "bufio" + "bytes" + "encoding/binary" + "strings" + "testing" + + "github.com/sagernet/sing/common/varbin" + + "github.com/stretchr/testify/require" +) + +// Old implementation using varbin reflection-based serialization + +func oldWriteString(writer varbin.Writer, value string) error { + //nolint:staticcheck + return varbin.Write(writer, binary.BigEndian, value) +} + +func oldWriteItem(writer varbin.Writer, item Item) error { + //nolint:staticcheck + return varbin.Write(writer, binary.BigEndian, item) +} + +func oldReadString(reader varbin.Reader) (string, error) { + //nolint:staticcheck + return varbin.ReadValue[string](reader, binary.BigEndian) +} + +func oldReadItem(reader varbin.Reader) (Item, error) { + //nolint:staticcheck + return varbin.ReadValue[Item](reader, binary.BigEndian) +} + +func TestStringCompat(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + input string + }{ + {"empty", ""}, + {"single_char", "a"}, + {"ascii", "example.com"}, + {"utf8", "测试域名.中国"}, + {"special_chars", "\x00\xff\n\t"}, + {"127_bytes", strings.Repeat("x", 127)}, + {"128_bytes", strings.Repeat("x", 128)}, + {"16383_bytes", strings.Repeat("x", 16383)}, + {"16384_bytes", strings.Repeat("x", 16384)}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // Old write + var oldBuf bytes.Buffer + err := oldWriteString(&oldBuf, tc.input) + require.NoError(t, err) + + // New write + var newBuf bytes.Buffer + err = writeString(&newBuf, tc.input) + require.NoError(t, err) + + // Bytes must match + require.Equal(t, oldBuf.Bytes(), newBuf.Bytes(), + "mismatch for %q\nold: %x\nnew: %x", tc.name, oldBuf.Bytes(), newBuf.Bytes()) + + // New write -> old read + readBack, err := oldReadString(bufio.NewReader(bytes.NewReader(newBuf.Bytes()))) + require.NoError(t, err) + require.Equal(t, tc.input, readBack) + + // Old write -> new read + readBack2, err := readString(bufio.NewReader(bytes.NewReader(oldBuf.Bytes()))) + require.NoError(t, err) + require.Equal(t, tc.input, readBack2) + }) + } +} + +func TestItemCompat(t *testing.T) { + t.Parallel() + + // Note: varbin.Write has a bug where struct values (not pointers) don't write their fields + // because field.CanSet() returns false for non-addressable values. + // The old geosite code passed Item values to varbin.Write, which silently wrote nothing. + // The new code correctly writes Type + Value using manual serialization. + // This test verifies the new serialization format and round-trip correctness. + + cases := []struct { + name string + input Item + }{ + {"domain_empty", Item{Type: RuleTypeDomain, Value: ""}}, + {"domain_normal", Item{Type: RuleTypeDomain, Value: "example.com"}}, + {"domain_suffix", Item{Type: RuleTypeDomainSuffix, Value: ".example.com"}}, + {"domain_keyword", Item{Type: RuleTypeDomainKeyword, Value: "google"}}, + {"domain_regex", Item{Type: RuleTypeDomainRegex, Value: `^.*\.example\.com$`}}, + {"utf8_domain", Item{Type: RuleTypeDomain, Value: "测试.com"}}, + {"long_domain", Item{Type: RuleTypeDomainSuffix, Value: strings.Repeat("a", 200) + ".com"}}, + {"128_bytes_value", Item{Type: RuleTypeDomain, Value: strings.Repeat("x", 128)}}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // New write + var newBuf bytes.Buffer + err := newBuf.WriteByte(byte(tc.input.Type)) + require.NoError(t, err) + err = writeString(&newBuf, tc.input.Value) + require.NoError(t, err) + + // Verify format: Type (1 byte) + Value (uvarint len + bytes) + require.True(t, len(newBuf.Bytes()) >= 1, "output too short") + require.Equal(t, byte(tc.input.Type), newBuf.Bytes()[0], "type byte mismatch") + + // New write -> old read (varbin can read correctly when given addressable target) + readBack, err := oldReadItem(bufio.NewReader(bytes.NewReader(newBuf.Bytes()))) + require.NoError(t, err) + require.Equal(t, tc.input, readBack) + + // New write -> new read + reader := bufio.NewReader(bytes.NewReader(newBuf.Bytes())) + typeByte, err := reader.ReadByte() + require.NoError(t, err) + value, err := readString(reader) + require.NoError(t, err) + require.Equal(t, tc.input, Item{Type: ItemType(typeByte), Value: value}) + }) + } +} + +func TestGeositeWriteReadCompat(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + input map[string][]Item + }{ + { + "empty_map", + map[string][]Item{}, + }, + { + "single_code_empty_items", + map[string][]Item{"test": {}}, + }, + { + "single_code_single_item", + map[string][]Item{"test": {{Type: RuleTypeDomain, Value: "a.com"}}}, + }, + { + "single_code_multi_items", + map[string][]Item{ + "test": { + {Type: RuleTypeDomain, Value: "a.com"}, + {Type: RuleTypeDomainSuffix, Value: ".b.com"}, + {Type: RuleTypeDomainKeyword, Value: "keyword"}, + {Type: RuleTypeDomainRegex, Value: `^.*$`}, + }, + }, + }, + { + "multi_code", + map[string][]Item{ + "cn": {{Type: RuleTypeDomain, Value: "baidu.com"}, {Type: RuleTypeDomainSuffix, Value: ".cn"}}, + "us": {{Type: RuleTypeDomain, Value: "google.com"}}, + "jp": {{Type: RuleTypeDomainSuffix, Value: ".jp"}}, + }, + }, + { + "utf8_values", + map[string][]Item{ + "test": { + {Type: RuleTypeDomain, Value: "测试.中国"}, + {Type: RuleTypeDomainSuffix, Value: ".テスト"}, + }, + }, + }, + { + "large_items", + generateLargeItems(1000), + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // Write using new implementation + var buf bytes.Buffer + err := Write(&buf, tc.input) + require.NoError(t, err) + + // Read back and verify + reader, codes, err := NewReader(bytes.NewReader(buf.Bytes())) + require.NoError(t, err) + + // Verify all codes exist + codeSet := make(map[string]bool) + for _, code := range codes { + codeSet[code] = true + } + for code := range tc.input { + require.True(t, codeSet[code], "missing code: %s", code) + } + + // Verify items match + for code, expectedItems := range tc.input { + items, err := reader.Read(code) + require.NoError(t, err) + require.Equal(t, expectedItems, items, "items mismatch for code: %s", code) + } + }) + } +} + +func generateLargeItems(count int) map[string][]Item { + items := make([]Item, count) + for i := 0; i < count; i++ { + items[i] = Item{ + Type: ItemType(i % 4), + Value: strings.Repeat("x", i%200) + ".com", + } + } + return map[string][]Item{"large": items} +} diff --git a/common/geosite/reader.go b/common/geosite/reader.go index 3b3f7fec2..ef99837d8 100644 --- a/common/geosite/reader.go +++ b/common/geosite/reader.go @@ -9,7 +9,6 @@ import ( "sync/atomic" E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/varbin" ) type Reader struct { @@ -78,7 +77,7 @@ func (r *Reader) readMetadata() error { codeIndex uint64 codeLength uint64 ) - code, err = varbin.ReadValue[string](reader, binary.BigEndian) + code, err = readString(reader) if err != nil { return err } @@ -112,9 +111,16 @@ func (r *Reader) Read(code string) ([]Item, error) { } r.bufferedReader.Reset(r.reader) itemList := make([]Item, r.domainLength[code]) - err = varbin.Read(r.bufferedReader, binary.BigEndian, &itemList) - if err != nil { - return nil, err + for i := range itemList { + typeByte, err := r.bufferedReader.ReadByte() + if err != nil { + return nil, err + } + itemList[i].Type = ItemType(typeByte) + itemList[i].Value, err = readString(r.bufferedReader) + if err != nil { + return nil, err + } } return itemList, nil } @@ -135,3 +141,18 @@ func (r *readCounter) Read(p []byte) (n int, err error) { } return } + +func readString(reader io.ByteReader) (string, error) { + length, err := binary.ReadUvarint(reader) + if err != nil { + return "", err + } + bytes := make([]byte, length) + for i := range bytes { + bytes[i], err = reader.ReadByte() + if err != nil { + return "", err + } + } + return string(bytes), nil +} diff --git a/common/geosite/writer.go b/common/geosite/writer.go index 1615fa348..52f2f7b9b 100644 --- a/common/geosite/writer.go +++ b/common/geosite/writer.go @@ -2,7 +2,6 @@ package geosite import ( "bytes" - "encoding/binary" "sort" "github.com/sagernet/sing/common/varbin" @@ -20,7 +19,11 @@ func Write(writer varbin.Writer, domains map[string][]Item) error { for _, code := range keys { index[code] = content.Len() for _, item := range domains[code] { - err := varbin.Write(content, binary.BigEndian, item) + err := content.WriteByte(byte(item.Type)) + if err != nil { + return err + } + err = writeString(content, item.Value) if err != nil { return err } @@ -38,7 +41,7 @@ func Write(writer varbin.Writer, domains map[string][]Item) error { } for _, code := range keys { - err = varbin.Write(writer, binary.BigEndian, code) + err = writeString(writer, code) if err != nil { return err } @@ -59,3 +62,12 @@ func Write(writer varbin.Writer, domains map[string][]Item) error { return nil } + +func writeString(writer varbin.Writer, value string) error { + _, err := varbin.WriteUvarint(writer, uint64(len(value))) + if err != nil { + return err + } + _, err = writer.Write([]byte(value)) + return err +} diff --git a/common/srs/binary.go b/common/srs/binary.go index 0c93c2842..ca12fff09 100644 --- a/common/srs/binary.go +++ b/common/srs/binary.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "io" "net/netip" + "unsafe" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/option" @@ -505,7 +506,24 @@ func writeDefaultRule(writer varbin.Writer, rule option.DefaultHeadlessRule, gen } func readRuleItemString(reader varbin.Reader) ([]string, error) { - return varbin.ReadValue[[]string](reader, binary.BigEndian) + length, err := binary.ReadUvarint(reader) + if err != nil { + return nil, err + } + result := make([]string, length) + for i := range result { + strLen, err := binary.ReadUvarint(reader) + if err != nil { + return nil, err + } + buf := make([]byte, strLen) + _, err = io.ReadFull(reader, buf) + if err != nil { + return nil, err + } + result[i] = string(buf) + } + return result, nil } func writeRuleItemString(writer varbin.Writer, itemType uint8, value []string) error { @@ -513,11 +531,34 @@ func writeRuleItemString(writer varbin.Writer, itemType uint8, value []string) e if err != nil { return err } - return varbin.Write(writer, binary.BigEndian, value) + _, err = varbin.WriteUvarint(writer, uint64(len(value))) + if err != nil { + return err + } + for _, s := range value { + _, err = varbin.WriteUvarint(writer, uint64(len(s))) + if err != nil { + return err + } + _, err = writer.Write([]byte(s)) + if err != nil { + return err + } + } + return nil } func readRuleItemUint8[E ~uint8](reader varbin.Reader) ([]E, error) { - return varbin.ReadValue[[]E](reader, binary.BigEndian) + length, err := binary.ReadUvarint(reader) + if err != nil { + return nil, err + } + result := make([]E, length) + _, err = io.ReadFull(reader, *(*[]byte)(unsafe.Pointer(&result))) + if err != nil { + return nil, err + } + return result, nil } func writeRuleItemUint8[E ~uint8](writer varbin.Writer, itemType uint8, value []E) error { @@ -525,11 +566,25 @@ func writeRuleItemUint8[E ~uint8](writer varbin.Writer, itemType uint8, value [] if err != nil { return err } - return varbin.Write(writer, binary.BigEndian, value) + _, err = varbin.WriteUvarint(writer, uint64(len(value))) + if err != nil { + return err + } + _, err = writer.Write(*(*[]byte)(unsafe.Pointer(&value))) + return err } func readRuleItemUint16(reader varbin.Reader) ([]uint16, error) { - return varbin.ReadValue[[]uint16](reader, binary.BigEndian) + length, err := binary.ReadUvarint(reader) + if err != nil { + return nil, err + } + result := make([]uint16, length) + err = binary.Read(reader, binary.BigEndian, result) + if err != nil { + return nil, err + } + return result, nil } func writeRuleItemUint16(writer varbin.Writer, itemType uint8, value []uint16) error { @@ -537,7 +592,11 @@ func writeRuleItemUint16(writer varbin.Writer, itemType uint8, value []uint16) e if err != nil { return err } - return varbin.Write(writer, binary.BigEndian, value) + _, err = varbin.WriteUvarint(writer, uint64(len(value))) + if err != nil { + return err + } + return binary.Write(writer, binary.BigEndian, value) } func writeRuleItemCIDR(writer varbin.Writer, itemType uint8, value []string) error { diff --git a/common/srs/compat_test.go b/common/srs/compat_test.go new file mode 100644 index 000000000..98552b324 --- /dev/null +++ b/common/srs/compat_test.go @@ -0,0 +1,494 @@ +package srs + +import ( + "bufio" + "bytes" + "encoding/binary" + "net/netip" + "strings" + "testing" + "unsafe" + + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/varbin" + + "github.com/stretchr/testify/require" + "go4.org/netipx" +) + +// Old implementations using varbin reflection-based serialization + +func oldWriteStringSlice(writer varbin.Writer, value []string) error { + //nolint:staticcheck + return varbin.Write(writer, binary.BigEndian, value) +} + +func oldReadStringSlice(reader varbin.Reader) ([]string, error) { + //nolint:staticcheck + return varbin.ReadValue[[]string](reader, binary.BigEndian) +} + +func oldWriteUint8Slice[E ~uint8](writer varbin.Writer, value []E) error { + //nolint:staticcheck + return varbin.Write(writer, binary.BigEndian, value) +} + +func oldReadUint8Slice[E ~uint8](reader varbin.Reader) ([]E, error) { + //nolint:staticcheck + return varbin.ReadValue[[]E](reader, binary.BigEndian) +} + +func oldWriteUint16Slice(writer varbin.Writer, value []uint16) error { + //nolint:staticcheck + return varbin.Write(writer, binary.BigEndian, value) +} + +func oldReadUint16Slice(reader varbin.Reader) ([]uint16, error) { + //nolint:staticcheck + return varbin.ReadValue[[]uint16](reader, binary.BigEndian) +} + +func oldWritePrefix(writer varbin.Writer, prefix netip.Prefix) error { + //nolint:staticcheck + err := varbin.Write(writer, binary.BigEndian, prefix.Addr().AsSlice()) + if err != nil { + return err + } + return binary.Write(writer, binary.BigEndian, uint8(prefix.Bits())) +} + +type oldIPRangeData struct { + From []byte + To []byte +} + +// Note: The old writeIPSet had a bug where varbin.Write(writer, binary.BigEndian, data) +// with a struct VALUE (not pointer) silently wrote nothing because field.CanSet() returned false. +// This caused IP range data to be missing from the output. +// The new implementation correctly writes all range data. +// +// The old readIPSet used varbin.Read with a pre-allocated slice, which worked because +// slice elements are addressable and CanSet() returns true for them. +// +// For compatibility testing, we verify: +// 1. New write produces correct output with range data +// 2. New read can parse the new format correctly +// 3. Round-trip works correctly + +func oldReadIPSet(reader varbin.Reader) (*netipx.IPSet, error) { + version, err := reader.ReadByte() + if err != nil { + return nil, err + } + if version != 1 { + return nil, err + } + var length uint64 + err = binary.Read(reader, binary.BigEndian, &length) + if err != nil { + return nil, err + } + ranges := make([]oldIPRangeData, length) + //nolint:staticcheck + err = varbin.Read(reader, binary.BigEndian, &ranges) + if err != nil { + return nil, err + } + mySet := &myIPSet{ + rr: make([]myIPRange, len(ranges)), + } + for i, rangeData := range ranges { + mySet.rr[i].from = M.AddrFromIP(rangeData.From) + mySet.rr[i].to = M.AddrFromIP(rangeData.To) + } + return (*netipx.IPSet)(unsafe.Pointer(mySet)), nil +} + +// New write functions (without itemType prefix for testing) + +func newWriteStringSlice(writer varbin.Writer, value []string) error { + _, err := varbin.WriteUvarint(writer, uint64(len(value))) + if err != nil { + return err + } + for _, s := range value { + _, err = varbin.WriteUvarint(writer, uint64(len(s))) + if err != nil { + return err + } + _, err = writer.Write([]byte(s)) + if err != nil { + return err + } + } + return nil +} + +func newWriteUint8Slice[E ~uint8](writer varbin.Writer, value []E) error { + _, err := varbin.WriteUvarint(writer, uint64(len(value))) + if err != nil { + return err + } + _, err = writer.Write(*(*[]byte)(unsafe.Pointer(&value))) + return err +} + +func newWriteUint16Slice(writer varbin.Writer, value []uint16) error { + _, err := varbin.WriteUvarint(writer, uint64(len(value))) + if err != nil { + return err + } + return binary.Write(writer, binary.BigEndian, value) +} + +func newWritePrefix(writer varbin.Writer, prefix netip.Prefix) error { + addrSlice := prefix.Addr().AsSlice() + _, err := varbin.WriteUvarint(writer, uint64(len(addrSlice))) + if err != nil { + return err + } + _, err = writer.Write(addrSlice) + if err != nil { + return err + } + return writer.WriteByte(uint8(prefix.Bits())) +} + +// Tests + +func TestStringSliceCompat(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + input []string + }{ + {"nil", nil}, + {"empty", []string{}}, + {"single_empty", []string{""}}, + {"single", []string{"test"}}, + {"multi", []string{"a", "b", "c"}}, + {"with_empty", []string{"a", "", "c"}}, + {"utf8", []string{"测试", "テスト", "тест"}}, + {"long_string", []string{strings.Repeat("x", 128)}}, + {"many_elements", generateStrings(128)}, + {"many_elements_256", generateStrings(256)}, + {"127_byte_string", []string{strings.Repeat("x", 127)}}, + {"128_byte_string", []string{strings.Repeat("x", 128)}}, + {"mixed_lengths", []string{"a", strings.Repeat("b", 100), "", strings.Repeat("c", 200)}}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // Old write + var oldBuf bytes.Buffer + err := oldWriteStringSlice(&oldBuf, tc.input) + require.NoError(t, err) + + // New write + var newBuf bytes.Buffer + err = newWriteStringSlice(&newBuf, tc.input) + require.NoError(t, err) + + // Bytes must match + require.Equal(t, oldBuf.Bytes(), newBuf.Bytes(), + "mismatch for %q\nold: %x\nnew: %x", tc.name, oldBuf.Bytes(), newBuf.Bytes()) + + // New write -> old read + readBack, err := oldReadStringSlice(bufio.NewReader(bytes.NewReader(newBuf.Bytes()))) + require.NoError(t, err) + requireStringSliceEqual(t, tc.input, readBack) + + // Old write -> new read + readBack2, err := readRuleItemString(bufio.NewReader(bytes.NewReader(oldBuf.Bytes()))) + require.NoError(t, err) + requireStringSliceEqual(t, tc.input, readBack2) + }) + } +} + +func TestUint8SliceCompat(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + input []uint8 + }{ + {"nil", nil}, + {"empty", []uint8{}}, + {"single_zero", []uint8{0}}, + {"single_max", []uint8{255}}, + {"multi", []uint8{0, 1, 127, 128, 255}}, + {"boundary", []uint8{0x00, 0x7f, 0x80, 0xff}}, + {"sequential", generateUint8Slice(256)}, + {"127_elements", generateUint8Slice(127)}, + {"128_elements", generateUint8Slice(128)}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // Old write + var oldBuf bytes.Buffer + err := oldWriteUint8Slice(&oldBuf, tc.input) + require.NoError(t, err) + + // New write + var newBuf bytes.Buffer + err = newWriteUint8Slice(&newBuf, tc.input) + require.NoError(t, err) + + // Bytes must match + require.Equal(t, oldBuf.Bytes(), newBuf.Bytes(), + "mismatch for %q\nold: %x\nnew: %x", tc.name, oldBuf.Bytes(), newBuf.Bytes()) + + // New write -> old read + readBack, err := oldReadUint8Slice[uint8](bufio.NewReader(bytes.NewReader(newBuf.Bytes()))) + require.NoError(t, err) + requireUint8SliceEqual(t, tc.input, readBack) + + // Old write -> new read + readBack2, err := readRuleItemUint8[uint8](bufio.NewReader(bytes.NewReader(oldBuf.Bytes()))) + require.NoError(t, err) + requireUint8SliceEqual(t, tc.input, readBack2) + }) + } +} + +func TestUint16SliceCompat(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + input []uint16 + }{ + {"nil", nil}, + {"empty", []uint16{}}, + {"single_zero", []uint16{0}}, + {"single_max", []uint16{65535}}, + {"multi", []uint16{0, 255, 256, 32767, 32768, 65535}}, + {"ports", []uint16{80, 443, 8080, 8443}}, + {"127_elements", generateUint16Slice(127)}, + {"128_elements", generateUint16Slice(128)}, + {"256_elements", generateUint16Slice(256)}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // Old write + var oldBuf bytes.Buffer + err := oldWriteUint16Slice(&oldBuf, tc.input) + require.NoError(t, err) + + // New write + var newBuf bytes.Buffer + err = newWriteUint16Slice(&newBuf, tc.input) + require.NoError(t, err) + + // Bytes must match + require.Equal(t, oldBuf.Bytes(), newBuf.Bytes(), + "mismatch for %q\nold: %x\nnew: %x", tc.name, oldBuf.Bytes(), newBuf.Bytes()) + + // New write -> old read + readBack, err := oldReadUint16Slice(bufio.NewReader(bytes.NewReader(newBuf.Bytes()))) + require.NoError(t, err) + requireUint16SliceEqual(t, tc.input, readBack) + + // Old write -> new read + readBack2, err := readRuleItemUint16(bufio.NewReader(bytes.NewReader(oldBuf.Bytes()))) + require.NoError(t, err) + requireUint16SliceEqual(t, tc.input, readBack2) + }) + } +} + +func TestPrefixCompat(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + input netip.Prefix + }{ + {"ipv4_0", netip.MustParsePrefix("0.0.0.0/0")}, + {"ipv4_8", netip.MustParsePrefix("10.0.0.0/8")}, + {"ipv4_16", netip.MustParsePrefix("192.168.0.0/16")}, + {"ipv4_24", netip.MustParsePrefix("192.168.1.0/24")}, + {"ipv4_32", netip.MustParsePrefix("1.2.3.4/32")}, + {"ipv6_0", netip.MustParsePrefix("::/0")}, + {"ipv6_64", netip.MustParsePrefix("2001:db8::/64")}, + {"ipv6_128", netip.MustParsePrefix("::1/128")}, + {"ipv6_full", netip.MustParsePrefix("2001:0db8:85a3:0000:0000:8a2e:0370:7334/128")}, + {"ipv4_private", netip.MustParsePrefix("172.16.0.0/12")}, + {"ipv6_link_local", netip.MustParsePrefix("fe80::/10")}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // Old write + var oldBuf bytes.Buffer + err := oldWritePrefix(&oldBuf, tc.input) + require.NoError(t, err) + + // New write + var newBuf bytes.Buffer + err = newWritePrefix(&newBuf, tc.input) + require.NoError(t, err) + + // Bytes must match + require.Equal(t, oldBuf.Bytes(), newBuf.Bytes(), + "mismatch for %q\nold: %x\nnew: %x", tc.name, oldBuf.Bytes(), newBuf.Bytes()) + + // New write -> new read (no old read for prefix) + readBack, err := readPrefix(bufio.NewReader(bytes.NewReader(newBuf.Bytes()))) + require.NoError(t, err) + require.Equal(t, tc.input, readBack) + + // Old write -> new read + readBack2, err := readPrefix(bufio.NewReader(bytes.NewReader(oldBuf.Bytes()))) + require.NoError(t, err) + require.Equal(t, tc.input, readBack2) + }) + } +} + +func TestIPSetCompat(t *testing.T) { + t.Parallel() + + // Note: The old writeIPSet was buggy (varbin.Write with struct values wrote nothing). + // This test verifies the new implementation writes correct data and round-trips correctly. + + cases := []struct { + name string + input *netipx.IPSet + }{ + {"single_ipv4", buildIPSet("1.2.3.4")}, + {"ipv4_range", buildIPSet("192.168.0.0/16")}, + {"multi_ipv4", buildIPSet("10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16")}, + {"single_ipv6", buildIPSet("::1")}, + {"ipv6_range", buildIPSet("2001:db8::/32")}, + {"mixed", buildIPSet("10.0.0.0/8", "::1", "2001:db8::/32")}, + {"large", buildLargeIPSet(100)}, + {"adjacent_ranges", buildIPSet("192.168.0.0/24", "192.168.1.0/24", "192.168.2.0/24")}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // New write + var newBuf bytes.Buffer + err := writeIPSet(&newBuf, tc.input) + require.NoError(t, err) + + // Verify format starts with version byte (1) + uint64 count + require.True(t, len(newBuf.Bytes()) >= 9, "output too short") + require.Equal(t, byte(1), newBuf.Bytes()[0], "version byte mismatch") + + // New write -> old read (varbin.Read with pre-allocated slice works correctly) + readBack, err := oldReadIPSet(bufio.NewReader(bytes.NewReader(newBuf.Bytes()))) + require.NoError(t, err) + requireIPSetEqual(t, tc.input, readBack) + + // New write -> new read + readBack2, err := readIPSet(bufio.NewReader(bytes.NewReader(newBuf.Bytes()))) + require.NoError(t, err) + requireIPSetEqual(t, tc.input, readBack2) + }) + } +} + +// Helper functions + +func generateStrings(count int) []string { + result := make([]string, count) + for i := range result { + result[i] = strings.Repeat("x", i%50) + } + return result +} + +func generateUint8Slice(count int) []uint8 { + result := make([]uint8, count) + for i := range result { + result[i] = uint8(i % 256) + } + return result +} + +func generateUint16Slice(count int) []uint16 { + result := make([]uint16, count) + for i := range result { + result[i] = uint16(i * 257) + } + return result +} + +func buildIPSet(cidrs ...string) *netipx.IPSet { + var builder netipx.IPSetBuilder + for _, cidr := range cidrs { + prefix, err := netip.ParsePrefix(cidr) + if err != nil { + addr, err := netip.ParseAddr(cidr) + if err != nil { + panic(err) + } + builder.Add(addr) + } else { + builder.AddPrefix(prefix) + } + } + set, _ := builder.IPSet() + return set +} + +func buildLargeIPSet(count int) *netipx.IPSet { + var builder netipx.IPSetBuilder + for i := 0; i < count; i++ { + prefix := netip.PrefixFrom(netip.AddrFrom4([4]byte{10, byte(i / 256), byte(i % 256), 0}), 24) + builder.AddPrefix(prefix) + } + set, _ := builder.IPSet() + return set +} + +func requireStringSliceEqual(t *testing.T, expected, actual []string) { + t.Helper() + if len(expected) == 0 && len(actual) == 0 { + return + } + require.Equal(t, expected, actual) +} + +func requireUint8SliceEqual(t *testing.T, expected, actual []uint8) { + t.Helper() + if len(expected) == 0 && len(actual) == 0 { + return + } + require.Equal(t, expected, actual) +} + +func requireUint16SliceEqual(t *testing.T, expected, actual []uint16) { + t.Helper() + if len(expected) == 0 && len(actual) == 0 { + return + } + require.Equal(t, expected, actual) +} + +func requireIPSetEqual(t *testing.T, expected, actual *netipx.IPSet) { + t.Helper() + expectedRanges := expected.Ranges() + actualRanges := actual.Ranges() + require.Equal(t, len(expectedRanges), len(actualRanges), "range count mismatch") + for i := range expectedRanges { + require.Equal(t, expectedRanges[i].From(), actualRanges[i].From(), "range[%d].from mismatch", i) + require.Equal(t, expectedRanges[i].To(), actualRanges[i].To(), "range[%d].to mismatch", i) + } +} diff --git a/common/srs/ip_cidr.go b/common/srs/ip_cidr.go index 93ae84ad3..7c81abda2 100644 --- a/common/srs/ip_cidr.go +++ b/common/srs/ip_cidr.go @@ -2,6 +2,7 @@ package srs import ( "encoding/binary" + "io" "net/netip" M "github.com/sagernet/sing/common/metadata" @@ -9,11 +10,16 @@ import ( ) func readPrefix(reader varbin.Reader) (netip.Prefix, error) { - addrSlice, err := varbin.ReadValue[[]byte](reader, binary.BigEndian) + addrLen, err := binary.ReadUvarint(reader) if err != nil { return netip.Prefix{}, err } - prefixBits, err := varbin.ReadValue[uint8](reader, binary.BigEndian) + addrSlice := make([]byte, addrLen) + _, err = io.ReadFull(reader, addrSlice) + if err != nil { + return netip.Prefix{}, err + } + prefixBits, err := reader.ReadByte() if err != nil { return netip.Prefix{}, err } @@ -21,11 +27,16 @@ func readPrefix(reader varbin.Reader) (netip.Prefix, error) { } func writePrefix(writer varbin.Writer, prefix netip.Prefix) error { - err := varbin.Write(writer, binary.BigEndian, prefix.Addr().AsSlice()) + addrSlice := prefix.Addr().AsSlice() + _, err := varbin.WriteUvarint(writer, uint64(len(addrSlice))) if err != nil { return err } - err = binary.Write(writer, binary.BigEndian, uint8(prefix.Bits())) + _, err = writer.Write(addrSlice) + if err != nil { + return err + } + err = writer.WriteByte(uint8(prefix.Bits())) if err != nil { return err } diff --git a/common/srs/ip_set.go b/common/srs/ip_set.go index 044dc823b..a10ac08c0 100644 --- a/common/srs/ip_set.go +++ b/common/srs/ip_set.go @@ -2,11 +2,11 @@ package srs import ( "encoding/binary" + "io" "net/netip" "os" "unsafe" - "github.com/sagernet/sing/common" M "github.com/sagernet/sing/common/metadata" "github.com/sagernet/sing/common/varbin" @@ -22,11 +22,6 @@ type myIPRange struct { to netip.Addr } -type myIPRangeData struct { - From []byte - To []byte -} - func readIPSet(reader varbin.Reader) (*netipx.IPSet, error) { version, err := reader.ReadByte() if err != nil { @@ -41,17 +36,30 @@ func readIPSet(reader varbin.Reader) (*netipx.IPSet, error) { if err != nil { return nil, err } - ranges := make([]myIPRangeData, length) - err = varbin.Read(reader, binary.BigEndian, &ranges) - if err != nil { - return nil, err - } mySet := &myIPSet{ - rr: make([]myIPRange, len(ranges)), + rr: make([]myIPRange, length), } - for i, rangeData := range ranges { - mySet.rr[i].from = M.AddrFromIP(rangeData.From) - mySet.rr[i].to = M.AddrFromIP(rangeData.To) + for i := range mySet.rr { + fromLen, err := binary.ReadUvarint(reader) + if err != nil { + return nil, err + } + fromBytes := make([]byte, fromLen) + _, err = io.ReadFull(reader, fromBytes) + if err != nil { + return nil, err + } + toLen, err := binary.ReadUvarint(reader) + if err != nil { + return nil, err + } + toBytes := make([]byte, toLen) + _, err = io.ReadFull(reader, toBytes) + if err != nil { + return nil, err + } + mySet.rr[i].from = M.AddrFromIP(fromBytes) + mySet.rr[i].to = M.AddrFromIP(toBytes) } return (*netipx.IPSet)(unsafe.Pointer(mySet)), nil } @@ -61,18 +69,27 @@ func writeIPSet(writer varbin.Writer, set *netipx.IPSet) error { if err != nil { return err } - dataList := common.Map((*myIPSet)(unsafe.Pointer(set)).rr, func(rr myIPRange) myIPRangeData { - return myIPRangeData{ - From: rr.from.AsSlice(), - To: rr.to.AsSlice(), - } - }) - err = binary.Write(writer, binary.BigEndian, uint64(len(dataList))) + mySet := (*myIPSet)(unsafe.Pointer(set)) + err = binary.Write(writer, binary.BigEndian, uint64(len(mySet.rr))) if err != nil { return err } - for _, data := range dataList { - err = varbin.Write(writer, binary.BigEndian, data) + for _, rr := range mySet.rr { + fromBytes := rr.from.AsSlice() + _, err = varbin.WriteUvarint(writer, uint64(len(fromBytes))) + if err != nil { + return err + } + _, err = writer.Write(fromBytes) + if err != nil { + return err + } + toBytes := rr.to.AsSlice() + _, err = varbin.WriteUvarint(writer, uint64(len(toBytes))) + if err != nil { + return err + } + _, err = writer.Write(toBytes) if err != nil { return err } diff --git a/experimental/libbox/profile_import.go b/experimental/libbox/profile_import.go index 17671e560..c337d015f 100644 --- a/experimental/libbox/profile_import.go +++ b/experimental/libbox/profile_import.go @@ -5,6 +5,7 @@ import ( "bytes" "compress/gzip" "encoding/binary" + "io" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/varbin" @@ -35,7 +36,7 @@ type ErrorMessage struct { func (e *ErrorMessage) Encode() []byte { var buffer bytes.Buffer buffer.WriteByte(MessageTypeError) - varbin.Write(&buffer, binary.BigEndian, e.Message) + writeString(&buffer, e.Message) return buffer.Bytes() } @@ -49,7 +50,7 @@ func DecodeErrorMessage(data []byte) (*ErrorMessage, error) { return nil, E.New("invalid message") } var message ErrorMessage - message.Message, err = varbin.ReadValue[string](reader, binary.BigEndian) + message.Message, err = readString(reader) if err != nil { return nil, err } @@ -87,7 +88,7 @@ func (e *ProfileEncoder) Encode() []byte { binary.Write(&buffer, binary.BigEndian, uint16(len(e.profiles))) for _, preview := range e.profiles { binary.Write(&buffer, binary.BigEndian, preview.ProfileID) - varbin.Write(&buffer, binary.BigEndian, preview.Name) + writeString(&buffer, preview.Name) binary.Write(&buffer, binary.BigEndian, preview.Type) } return buffer.Bytes() @@ -117,7 +118,7 @@ func (d *ProfileDecoder) Decode(data []byte) error { if err != nil { return err } - profile.Name, err = varbin.ReadValue[string](reader, binary.BigEndian) + profile.Name, err = readString(reader) if err != nil { return err } @@ -178,11 +179,11 @@ func (c *ProfileContent) Encode() []byte { buffer.WriteByte(1) gWriter := gzip.NewWriter(buffer) writer := bufio.NewWriter(gWriter) - varbin.Write(writer, binary.BigEndian, c.Name) + writeStringBuffered(writer, c.Name) binary.Write(writer, binary.BigEndian, c.Type) - varbin.Write(writer, binary.BigEndian, c.Config) + writeStringBuffered(writer, c.Config) if c.Type != ProfileTypeLocal { - varbin.Write(writer, binary.BigEndian, c.RemotePath) + writeStringBuffered(writer, c.RemotePath) } if c.Type == ProfileTypeRemote { binary.Write(writer, binary.BigEndian, c.AutoUpdate) @@ -214,7 +215,7 @@ func DecodeProfileContent(data []byte) (*ProfileContent, error) { } bReader := varbin.StubReader(gReader) var content ProfileContent - content.Name, err = varbin.ReadValue[string](bReader, binary.BigEndian) + content.Name, err = readString(bReader) if err != nil { return nil, err } @@ -222,12 +223,12 @@ func DecodeProfileContent(data []byte) (*ProfileContent, error) { if err != nil { return nil, err } - content.Config, err = varbin.ReadValue[string](bReader, binary.BigEndian) + content.Config, err = readString(bReader) if err != nil { return nil, err } if content.Type != ProfileTypeLocal { - content.RemotePath, err = varbin.ReadValue[string](bReader, binary.BigEndian) + content.RemotePath, err = readString(bReader) if err != nil { return nil, err } @@ -250,3 +251,28 @@ func DecodeProfileContent(data []byte) (*ProfileContent, error) { } return &content, nil } + +func readString(reader io.ByteReader) (string, error) { + length, err := binary.ReadUvarint(reader) + if err != nil { + return "", err + } + buf := make([]byte, length) + for i := range buf { + buf[i], err = reader.ReadByte() + if err != nil { + return "", err + } + } + return string(buf), nil +} + +func writeString(buffer *bytes.Buffer, value string) { + varbin.WriteUvarint(buffer, uint64(len(value))) + buffer.WriteString(value) +} + +func writeStringBuffered(writer *bufio.Writer, value string) { + varbin.WriteUvarint(writer, uint64(len(value))) + writer.WriteString(value) +} diff --git a/go.mod b/go.mod index 91c529ec5..80b4fa1c2 100644 --- a/go.mod +++ b/go.mod @@ -33,7 +33,7 @@ require ( github.com/sagernet/gomobile v0.1.11 github.com/sagernet/gvisor v0.0.0-20250811.0-sing-box-mod.1 github.com/sagernet/quic-go v0.59.0-sing-box-mod.2 - github.com/sagernet/sing v0.8.0-beta.11 + github.com/sagernet/sing v0.8.0-beta.12 github.com/sagernet/sing-mux v0.3.4 github.com/sagernet/sing-quic v0.6.0-beta.11 github.com/sagernet/sing-shadowsocks v0.2.8 diff --git a/go.sum b/go.sum index b2c410026..b99db9b1a 100644 --- a/go.sum +++ b/go.sum @@ -210,8 +210,8 @@ github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNen github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8= github.com/sagernet/quic-go v0.59.0-sing-box-mod.2 h1:hJUL+HtxEOjxsa0CsucbBVqI/AMS4k52NwNU637zmdw= github.com/sagernet/quic-go v0.59.0-sing-box-mod.2/go.mod h1:OqILvS182CyOol5zNNo6bguvOGgXzV459+chpRaUC+4= -github.com/sagernet/sing v0.8.0-beta.11 h1:nn/2Uod61b5rLHnXCuaFcbDxI1oRCodaxjzjt7Mobe4= -github.com/sagernet/sing v0.8.0-beta.11/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= +github.com/sagernet/sing v0.8.0-beta.12 h1:Xt0MNk6i6vI9f2vV6QkwhgiLB0g53fZSVsCCBEtQ8qQ= +github.com/sagernet/sing v0.8.0-beta.12/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= github.com/sagernet/sing-mux v0.3.4 h1:ZQplKl8MNXutjzbMVtWvWG31fohhgOfCuUZR4dVQ8+s= github.com/sagernet/sing-mux v0.3.4/go.mod h1:QvlKMyNBNrQoyX4x+gq028uPbLM2XeRpWtDsWBJbFSk= github.com/sagernet/sing-quic v0.6.0-beta.11 h1:eUusxITKKRedhWC2ScUYFUvD96h/QfbKLaS3N6/7in4=