Compare commits
46 Commits
v1.14.0-al
...
testing
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3bf4216bd3 | ||
|
|
a4fde8012f | ||
|
|
17dd36ce31 | ||
|
|
52df438e02 | ||
|
|
2ceecf9b3c | ||
|
|
89ce67c9f4 | ||
|
|
73d9acd274 | ||
|
|
f5f7269962 | ||
|
|
a6fd0604cc | ||
|
|
ca63351ec2 | ||
|
|
d89783ba3d | ||
|
|
845e94d7d3 | ||
|
|
bf3454ee5e | ||
|
|
83c67939fa | ||
|
|
00adb23810 | ||
|
|
34fdc29f18 | ||
|
|
1b5fd56232 | ||
|
|
49704b49cc | ||
|
|
539302511a | ||
|
|
e80b662da8 | ||
|
|
c80f4add73 | ||
|
|
32a34ef434 | ||
|
|
f72c4c1f88 | ||
|
|
90c342e941 | ||
|
|
d70bfb9bb6 | ||
|
|
47f09c96a9 | ||
|
|
3769ad4296 | ||
|
|
fcb43d7cf9 | ||
|
|
16e81c4c68 | ||
|
|
6912e272ea | ||
|
|
c208f4aea4 | ||
|
|
e4450ec230 | ||
|
|
d4dacfc480 | ||
|
|
a45639370f | ||
|
|
986de07d0c | ||
|
|
284c4d8494 | ||
|
|
129396f490 | ||
|
|
df806c96fb | ||
|
|
72a039a8f3 | ||
|
|
9b155ba467 | ||
|
|
7ed5ef6da4 | ||
|
|
bb3ad9c694 | ||
|
|
d5adb54bc6 | ||
|
|
1cfcea769f | ||
|
|
f43fc797d4 | ||
|
|
8e3176b789 |
2
.github/CRONET_GO_VERSION
vendored
2
.github/CRONET_GO_VERSION
vendored
@@ -1 +1 @@
|
|||||||
ea7cd33752aed62603775af3df946c1b83f4b0b3
|
e4926ba205fae5351e3d3eeafff7e7029654424a
|
||||||
|
|||||||
2
.github/setup_go_for_macos1013.sh
vendored
2
.github/setup_go_for_macos1013.sh
vendored
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
set -euo pipefail
|
set -euo pipefail
|
||||||
|
|
||||||
VERSION="1.25.8"
|
VERSION="1.25.9"
|
||||||
PATCH_COMMITS=(
|
PATCH_COMMITS=(
|
||||||
"afe69d3cec1c6dcf0f1797b20546795730850070"
|
"afe69d3cec1c6dcf0f1797b20546795730850070"
|
||||||
"1ed289b0cf87dc5aae9c6fe1aa5f200a83412938"
|
"1ed289b0cf87dc5aae9c6fe1aa5f200a83412938"
|
||||||
|
|||||||
2
.github/setup_go_for_windows7.sh
vendored
2
.github/setup_go_for_windows7.sh
vendored
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
set -euo pipefail
|
set -euo pipefail
|
||||||
|
|
||||||
VERSION="1.25.8"
|
VERSION="1.25.9"
|
||||||
PATCH_COMMITS=(
|
PATCH_COMMITS=(
|
||||||
"466f6c7a29bc098b0d4c987b803c779222894a11"
|
"466f6c7a29bc098b0d4c987b803c779222894a11"
|
||||||
"1bdabae205052afe1dadb2ad6f1ba612cdbc532a"
|
"1bdabae205052afe1dadb2ad6f1ba612cdbc532a"
|
||||||
|
|||||||
10
.github/workflows/build.yml
vendored
10
.github/workflows/build.yml
vendored
@@ -47,7 +47,7 @@ jobs:
|
|||||||
- name: Setup Go
|
- name: Setup Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: ~1.25.8
|
go-version: ~1.25.9
|
||||||
- name: Check input version
|
- name: Check input version
|
||||||
if: github.event_name == 'workflow_dispatch'
|
if: github.event_name == 'workflow_dispatch'
|
||||||
run: |-
|
run: |-
|
||||||
@@ -124,7 +124,7 @@ jobs:
|
|||||||
if: ${{ ! matrix.legacy_win7 }}
|
if: ${{ ! matrix.legacy_win7 }}
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: ~1.25.8
|
go-version: ~1.25.9
|
||||||
- name: Cache Go for Windows 7
|
- name: Cache Go for Windows 7
|
||||||
if: matrix.legacy_win7
|
if: matrix.legacy_win7
|
||||||
id: cache-go-for-windows7
|
id: cache-go-for-windows7
|
||||||
@@ -641,7 +641,7 @@ jobs:
|
|||||||
- name: Setup Go
|
- name: Setup Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: ~1.25.8
|
go-version: ~1.25.9
|
||||||
- name: Setup Android NDK
|
- name: Setup Android NDK
|
||||||
id: setup-ndk
|
id: setup-ndk
|
||||||
uses: nttld/setup-ndk@v1
|
uses: nttld/setup-ndk@v1
|
||||||
@@ -731,7 +731,7 @@ jobs:
|
|||||||
- name: Setup Go
|
- name: Setup Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: ~1.25.8
|
go-version: ~1.25.9
|
||||||
- name: Setup Android NDK
|
- name: Setup Android NDK
|
||||||
id: setup-ndk
|
id: setup-ndk
|
||||||
uses: nttld/setup-ndk@v1
|
uses: nttld/setup-ndk@v1
|
||||||
@@ -830,7 +830,7 @@ jobs:
|
|||||||
if: matrix.if
|
if: matrix.if
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: ~1.25.8
|
go-version: ~1.25.9
|
||||||
- name: Set tag
|
- name: Set tag
|
||||||
if: matrix.if
|
if: matrix.if
|
||||||
run: |-
|
run: |-
|
||||||
|
|||||||
2
.github/workflows/docker.yml
vendored
2
.github/workflows/docker.yml
vendored
@@ -55,7 +55,7 @@ jobs:
|
|||||||
- name: Setup Go
|
- name: Setup Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: ~1.25.8
|
go-version: ~1.25.9
|
||||||
- name: Clone cronet-go
|
- name: Clone cronet-go
|
||||||
if: matrix.naive
|
if: matrix.naive
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
4
.github/workflows/linux.yml
vendored
4
.github/workflows/linux.yml
vendored
@@ -29,7 +29,7 @@ jobs:
|
|||||||
- name: Setup Go
|
- name: Setup Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: ~1.25.8
|
go-version: ~1.25.9
|
||||||
- name: Check input version
|
- name: Check input version
|
||||||
if: github.event_name == 'workflow_dispatch'
|
if: github.event_name == 'workflow_dispatch'
|
||||||
run: |-
|
run: |-
|
||||||
@@ -72,7 +72,7 @@ jobs:
|
|||||||
- name: Setup Go
|
- name: Setup Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: ~1.25.8
|
go-version: ~1.25.9
|
||||||
- name: Clone cronet-go
|
- name: Clone cronet-go
|
||||||
if: matrix.naive
|
if: matrix.naive
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
55
.github/workflows/test.yml
vendored
Normal file
55
.github/workflows/test.yml
vendored
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
name: Test
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- stable
|
||||||
|
- testing
|
||||||
|
- unstable
|
||||||
|
paths-ignore:
|
||||||
|
- '**.md'
|
||||||
|
- '.github/**'
|
||||||
|
- '!.github/workflows/test.yml'
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- stable
|
||||||
|
- testing
|
||||||
|
- unstable
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.event_name }}-${{ inputs.build }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
name: Test
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
os:
|
||||||
|
- ubuntu-latest
|
||||||
|
- windows-latest
|
||||||
|
- macos-latest
|
||||||
|
go:
|
||||||
|
- ~1.24
|
||||||
|
- ~1.25
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5
|
||||||
|
- name: Setup Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: ${{ matrix.go }}
|
||||||
|
- name: Set build tags and ldflags
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
echo "BUILD_TAGS=$(cat release/DEFAULT_BUILD_TAGS_OTHERS)" >> "$GITHUB_ENV"
|
||||||
|
echo "LDFLAGS_SHARED=$(cat release/LDFLAGS)" >> "$GITHUB_ENV"
|
||||||
|
- name: Test (unix)
|
||||||
|
if: matrix.os != 'windows-latest'
|
||||||
|
run: go test -v -exec sudo -tags "$BUILD_TAGS" -ldflags "$LDFLAGS_SHARED" ./...
|
||||||
|
- name: Test (windows)
|
||||||
|
if: matrix.os == 'windows-latest'
|
||||||
|
shell: bash
|
||||||
|
run: go test -v -tags "$BUILD_TAGS" -ldflags "$LDFLAGS_SHARED" ./...
|
||||||
@@ -19,7 +19,6 @@ linters:
|
|||||||
enable:
|
enable:
|
||||||
- govet
|
- govet
|
||||||
- ineffassign
|
- ineffassign
|
||||||
- paralleltest
|
|
||||||
- staticcheck
|
- staticcheck
|
||||||
settings:
|
settings:
|
||||||
staticcheck:
|
staticcheck:
|
||||||
|
|||||||
2
Makefile
2
Makefile
@@ -52,7 +52,7 @@ lint:
|
|||||||
GOOS=android golangci-lint run ./...
|
GOOS=android golangci-lint run ./...
|
||||||
GOOS=windows golangci-lint run ./...
|
GOOS=windows golangci-lint run ./...
|
||||||
GOOS=darwin golangci-lint run ./...
|
GOOS=darwin golangci-lint run ./...
|
||||||
GOOS=freebsd golangci-lint run ./...
|
# GOOS=freebsd golangci-lint run ./...
|
||||||
|
|
||||||
lint_install:
|
lint_install:
|
||||||
go install -v github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest
|
go install -v github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package adapter
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
"github.com/sagernet/sing-box/log"
|
"github.com/sagernet/sing-box/log"
|
||||||
@@ -35,6 +36,7 @@ type DNSQueryOptions struct {
|
|||||||
Strategy C.DomainStrategy
|
Strategy C.DomainStrategy
|
||||||
LookupStrategy C.DomainStrategy
|
LookupStrategy C.DomainStrategy
|
||||||
DisableCache bool
|
DisableCache bool
|
||||||
|
DisableOptimisticCache bool
|
||||||
RewriteTTL *uint32
|
RewriteTTL *uint32
|
||||||
ClientSubnet netip.Prefix
|
ClientSubnet netip.Prefix
|
||||||
}
|
}
|
||||||
@@ -52,6 +54,7 @@ func DNSQueryOptionsFrom(ctx context.Context, options *option.DomainResolveOptio
|
|||||||
Transport: transport,
|
Transport: transport,
|
||||||
Strategy: C.DomainStrategy(options.Strategy),
|
Strategy: C.DomainStrategy(options.Strategy),
|
||||||
DisableCache: options.DisableCache,
|
DisableCache: options.DisableCache,
|
||||||
|
DisableOptimisticCache: options.DisableOptimisticCache,
|
||||||
RewriteTTL: options.RewriteTTL,
|
RewriteTTL: options.RewriteTTL,
|
||||||
ClientSubnet: options.ClientSubnet.Build(netip.Prefix{}),
|
ClientSubnet: options.ClientSubnet.Build(netip.Prefix{}),
|
||||||
}, nil
|
}, nil
|
||||||
@@ -63,6 +66,13 @@ type RDRCStore interface {
|
|||||||
SaveRDRCAsync(transportName string, qName string, qType uint16, logger logger.Logger)
|
SaveRDRCAsync(transportName string, qName string, qType uint16, logger logger.Logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type DNSCacheStore interface {
|
||||||
|
LoadDNSCache(transportName string, qName string, qType uint16) (rawMessage []byte, expireAt time.Time, loaded bool)
|
||||||
|
SaveDNSCache(transportName string, qName string, qType uint16, rawMessage []byte, expireAt time.Time) error
|
||||||
|
SaveDNSCacheAsync(transportName string, qName string, qType uint16, rawMessage []byte, expireAt time.Time, logger logger.Logger)
|
||||||
|
ClearDNSCache() error
|
||||||
|
}
|
||||||
|
|
||||||
type DNSTransport interface {
|
type DNSTransport interface {
|
||||||
Lifecycle
|
Lifecycle
|
||||||
Type() string
|
Type() string
|
||||||
|
|||||||
@@ -47,6 +47,12 @@ type CacheFile interface {
|
|||||||
StoreRDRC() bool
|
StoreRDRC() bool
|
||||||
RDRCStore
|
RDRCStore
|
||||||
|
|
||||||
|
StoreDNS() bool
|
||||||
|
DNSCacheStore
|
||||||
|
|
||||||
|
SetDisableExpire(disableExpire bool)
|
||||||
|
SetOptimisticTimeout(timeout time.Duration)
|
||||||
|
|
||||||
LoadMode() string
|
LoadMode() string
|
||||||
StoreMode(mode string) error
|
StoreMode(mode string) error
|
||||||
LoadSelected(group string) string
|
LoadSelected(group string) string
|
||||||
|
|||||||
43
adapter/http.go
Normal file
43
adapter/http.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package adapter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
"github.com/sagernet/sing/common/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
type HTTPTransport interface {
|
||||||
|
http.RoundTripper
|
||||||
|
CloseIdleConnections()
|
||||||
|
Reset()
|
||||||
|
}
|
||||||
|
|
||||||
|
type HTTPClientManager interface {
|
||||||
|
ResolveTransport(ctx context.Context, logger logger.ContextLogger, options option.HTTPClientOptions) (HTTPTransport, error)
|
||||||
|
DefaultTransport() HTTPTransport
|
||||||
|
ResetNetwork()
|
||||||
|
}
|
||||||
|
|
||||||
|
type HTTPStartContext struct {
|
||||||
|
access sync.Mutex
|
||||||
|
transports []HTTPTransport
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHTTPStartContext() *HTTPStartContext {
|
||||||
|
return &HTTPStartContext{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *HTTPStartContext) Register(transport HTTPTransport) {
|
||||||
|
c.access.Lock()
|
||||||
|
defer c.access.Unlock()
|
||||||
|
c.transports = append(c.transports, transport)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *HTTPStartContext) Close() {
|
||||||
|
for _, transport := range c.transports {
|
||||||
|
transport.CloseIdleConnections()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2,17 +2,11 @@ package adapter
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
C "github.com/sagernet/sing-box/constant"
|
|
||||||
"github.com/sagernet/sing-tun"
|
"github.com/sagernet/sing-tun"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
"github.com/sagernet/sing/common/ntp"
|
|
||||||
"github.com/sagernet/sing/common/x/list"
|
"github.com/sagernet/sing/common/x/list"
|
||||||
|
|
||||||
"go4.org/netipx"
|
"go4.org/netipx"
|
||||||
@@ -76,47 +70,10 @@ type RuleSetMetadata struct {
|
|||||||
ContainsWIFIRule bool
|
ContainsWIFIRule bool
|
||||||
ContainsIPCIDRRule bool
|
ContainsIPCIDRRule bool
|
||||||
ContainsDNSQueryTypeRule bool
|
ContainsDNSQueryTypeRule bool
|
||||||
}
|
// ContainsNonIPCIDRRule signals that the rule-set carries at least one sub-rule
|
||||||
type HTTPStartContext struct {
|
// with a predicate other than destination ip_cidr / ip_set, so it can contribute
|
||||||
ctx context.Context
|
// to DNS pre-response matching. A rule-set where this is false and
|
||||||
access sync.Mutex
|
// ContainsIPCIDRRule is true is "pure-IP" and matches nothing before a DNS
|
||||||
httpClientCache map[string]*http.Client
|
// response is available.
|
||||||
}
|
ContainsNonIPCIDRRule bool
|
||||||
|
|
||||||
func NewHTTPStartContext(ctx context.Context) *HTTPStartContext {
|
|
||||||
return &HTTPStartContext{
|
|
||||||
ctx: ctx,
|
|
||||||
httpClientCache: make(map[string]*http.Client),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HTTPStartContext) HTTPClient(detour string, dialer N.Dialer) *http.Client {
|
|
||||||
c.access.Lock()
|
|
||||||
defer c.access.Unlock()
|
|
||||||
if httpClient, loaded := c.httpClientCache[detour]; loaded {
|
|
||||||
return httpClient
|
|
||||||
}
|
|
||||||
httpClient := &http.Client{
|
|
||||||
Transport: &http.Transport{
|
|
||||||
ForceAttemptHTTP2: true,
|
|
||||||
TLSHandshakeTimeout: C.TCPTimeout,
|
|
||||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
||||||
return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
|
||||||
},
|
|
||||||
TLSClientConfig: &tls.Config{
|
|
||||||
Time: ntp.TimeFuncFromContext(c.ctx),
|
|
||||||
RootCAs: RootPoolFromContext(c.ctx),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
c.httpClientCache[detour] = httpClient
|
|
||||||
return httpClient
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HTTPStartContext) Close() {
|
|
||||||
c.access.Lock()
|
|
||||||
defer c.access.Unlock()
|
|
||||||
for _, client := range c.httpClientCache {
|
|
||||||
client.CloseIdleConnections()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
42
box.go
42
box.go
@@ -16,12 +16,14 @@ import (
|
|||||||
boxService "github.com/sagernet/sing-box/adapter/service"
|
boxService "github.com/sagernet/sing-box/adapter/service"
|
||||||
"github.com/sagernet/sing-box/common/certificate"
|
"github.com/sagernet/sing-box/common/certificate"
|
||||||
"github.com/sagernet/sing-box/common/dialer"
|
"github.com/sagernet/sing-box/common/dialer"
|
||||||
|
"github.com/sagernet/sing-box/common/httpclient"
|
||||||
"github.com/sagernet/sing-box/common/taskmonitor"
|
"github.com/sagernet/sing-box/common/taskmonitor"
|
||||||
"github.com/sagernet/sing-box/common/tls"
|
"github.com/sagernet/sing-box/common/tls"
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
"github.com/sagernet/sing-box/dns"
|
"github.com/sagernet/sing-box/dns"
|
||||||
"github.com/sagernet/sing-box/experimental"
|
"github.com/sagernet/sing-box/experimental"
|
||||||
"github.com/sagernet/sing-box/experimental/cachefile"
|
"github.com/sagernet/sing-box/experimental/cachefile"
|
||||||
|
"github.com/sagernet/sing-box/experimental/deprecated"
|
||||||
"github.com/sagernet/sing-box/log"
|
"github.com/sagernet/sing-box/log"
|
||||||
"github.com/sagernet/sing-box/option"
|
"github.com/sagernet/sing-box/option"
|
||||||
"github.com/sagernet/sing-box/protocol/direct"
|
"github.com/sagernet/sing-box/protocol/direct"
|
||||||
@@ -50,6 +52,7 @@ type Box struct {
|
|||||||
dnsRouter *dns.Router
|
dnsRouter *dns.Router
|
||||||
connection *route.ConnectionManager
|
connection *route.ConnectionManager
|
||||||
router *route.Router
|
router *route.Router
|
||||||
|
httpClientService adapter.LifecycleService
|
||||||
internalService []adapter.LifecycleService
|
internalService []adapter.LifecycleService
|
||||||
done chan struct{}
|
done chan struct{}
|
||||||
}
|
}
|
||||||
@@ -169,6 +172,7 @@ func New(options Options) (*Box, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var internalServices []adapter.LifecycleService
|
var internalServices []adapter.LifecycleService
|
||||||
|
routeOptions := common.PtrValueOrDefault(options.Route)
|
||||||
certificateOptions := common.PtrValueOrDefault(options.Certificate)
|
certificateOptions := common.PtrValueOrDefault(options.Certificate)
|
||||||
if C.IsAndroid || certificateOptions.Store != "" && certificateOptions.Store != C.CertificateStoreSystem ||
|
if C.IsAndroid || certificateOptions.Store != "" && certificateOptions.Store != C.CertificateStoreSystem ||
|
||||||
len(certificateOptions.Certificate) > 0 ||
|
len(certificateOptions.Certificate) > 0 ||
|
||||||
@@ -181,8 +185,6 @@ func New(options Options) (*Box, error) {
|
|||||||
service.MustRegister[adapter.CertificateStore](ctx, certificateStore)
|
service.MustRegister[adapter.CertificateStore](ctx, certificateStore)
|
||||||
internalServices = append(internalServices, certificateStore)
|
internalServices = append(internalServices, certificateStore)
|
||||||
}
|
}
|
||||||
|
|
||||||
routeOptions := common.PtrValueOrDefault(options.Route)
|
|
||||||
dnsOptions := common.PtrValueOrDefault(options.DNS)
|
dnsOptions := common.PtrValueOrDefault(options.DNS)
|
||||||
endpointManager := endpoint.NewManager(logFactory.NewLogger("endpoint"), endpointRegistry)
|
endpointManager := endpoint.NewManager(logFactory.NewLogger("endpoint"), endpointRegistry)
|
||||||
inboundManager := inbound.NewManager(logFactory.NewLogger("inbound"), inboundRegistry, endpointManager)
|
inboundManager := inbound.NewManager(logFactory.NewLogger("inbound"), inboundRegistry, endpointManager)
|
||||||
@@ -196,7 +198,10 @@ func New(options Options) (*Box, error) {
|
|||||||
service.MustRegister[adapter.DNSTransportManager](ctx, dnsTransportManager)
|
service.MustRegister[adapter.DNSTransportManager](ctx, dnsTransportManager)
|
||||||
service.MustRegister[adapter.ServiceManager](ctx, serviceManager)
|
service.MustRegister[adapter.ServiceManager](ctx, serviceManager)
|
||||||
service.MustRegister[adapter.CertificateProviderManager](ctx, certificateProviderManager)
|
service.MustRegister[adapter.CertificateProviderManager](ctx, certificateProviderManager)
|
||||||
dnsRouter := dns.NewRouter(ctx, logFactory, dnsOptions)
|
dnsRouter, err := dns.NewRouter(ctx, logFactory, dnsOptions)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "initialize DNS router")
|
||||||
|
}
|
||||||
service.MustRegister[adapter.DNSRouter](ctx, dnsRouter)
|
service.MustRegister[adapter.DNSRouter](ctx, dnsRouter)
|
||||||
service.MustRegister[adapter.DNSRuleSetUpdateValidator](ctx, dnsRouter)
|
service.MustRegister[adapter.DNSRuleSetUpdateValidator](ctx, dnsRouter)
|
||||||
networkManager, err := route.NewNetworkManager(ctx, logFactory.NewLogger("network"), routeOptions, dnsOptions)
|
networkManager, err := route.NewNetworkManager(ctx, logFactory.NewLogger("network"), routeOptions, dnsOptions)
|
||||||
@@ -206,6 +211,10 @@ func New(options Options) (*Box, error) {
|
|||||||
service.MustRegister[adapter.NetworkManager](ctx, networkManager)
|
service.MustRegister[adapter.NetworkManager](ctx, networkManager)
|
||||||
connectionManager := route.NewConnectionManager(logFactory.NewLogger("connection"))
|
connectionManager := route.NewConnectionManager(logFactory.NewLogger("connection"))
|
||||||
service.MustRegister[adapter.ConnectionManager](ctx, connectionManager)
|
service.MustRegister[adapter.ConnectionManager](ctx, connectionManager)
|
||||||
|
// Must register after ConnectionManager: the Apple HTTP engine's proxy bridge reads it from the context when Manager.Start resolves the default client.
|
||||||
|
httpClientManager := httpclient.NewManager(ctx, logFactory.NewLogger("httpclient"), options.HTTPClients, routeOptions.DefaultHTTPClient)
|
||||||
|
service.MustRegister[adapter.HTTPClientManager](ctx, httpClientManager)
|
||||||
|
httpClientService := adapter.LifecycleService(httpClientManager)
|
||||||
router := route.NewRouter(ctx, logFactory, routeOptions, dnsOptions)
|
router := route.NewRouter(ctx, logFactory, routeOptions, dnsOptions)
|
||||||
service.MustRegister[adapter.Router](ctx, router)
|
service.MustRegister[adapter.Router](ctx, router)
|
||||||
err = router.Initialize(routeOptions.Rules, routeOptions.RuleSet)
|
err = router.Initialize(routeOptions.Rules, routeOptions.RuleSet)
|
||||||
@@ -365,6 +374,12 @@ func New(options Options) (*Box, error) {
|
|||||||
&option.LocalDNSServerOptions{},
|
&option.LocalDNSServerOptions{},
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
httpClientManager.Initialize(func() (*httpclient.ManagedTransport, error) {
|
||||||
|
deprecated.Report(ctx, deprecated.OptionImplicitDefaultHTTPClient)
|
||||||
|
var httpClientOptions option.HTTPClientOptions
|
||||||
|
httpClientOptions.DefaultOutbound = true
|
||||||
|
return httpclient.NewTransport(ctx, logFactory.NewLogger("httpclient"), "", httpClientOptions)
|
||||||
|
})
|
||||||
if platformInterface != nil {
|
if platformInterface != nil {
|
||||||
err = platformInterface.Initialize(networkManager)
|
err = platformInterface.Initialize(networkManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -372,7 +387,7 @@ func New(options Options) (*Box, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if needCacheFile {
|
if needCacheFile {
|
||||||
cacheFile := cachefile.New(ctx, common.PtrValueOrDefault(experimentalOptions.CacheFile))
|
cacheFile := cachefile.New(ctx, logFactory.NewLogger("cache-file"), common.PtrValueOrDefault(experimentalOptions.CacheFile))
|
||||||
service.MustRegister[adapter.CacheFile](ctx, cacheFile)
|
service.MustRegister[adapter.CacheFile](ctx, cacheFile)
|
||||||
internalServices = append(internalServices, cacheFile)
|
internalServices = append(internalServices, cacheFile)
|
||||||
}
|
}
|
||||||
@@ -425,6 +440,7 @@ func New(options Options) (*Box, error) {
|
|||||||
dnsRouter: dnsRouter,
|
dnsRouter: dnsRouter,
|
||||||
connection: connectionManager,
|
connection: connectionManager,
|
||||||
router: router,
|
router: router,
|
||||||
|
httpClientService: httpClientService,
|
||||||
createdAt: createdAt,
|
createdAt: createdAt,
|
||||||
logFactory: logFactory,
|
logFactory: logFactory,
|
||||||
logger: logFactory.Logger(),
|
logger: logFactory.Logger(),
|
||||||
@@ -487,7 +503,15 @@ func (s *Box) preStart() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = adapter.Start(s.logger, adapter.StartStateStart, s.outbound, s.dnsTransport, s.network, s.connection, s.router, s.dnsRouter)
|
err = adapter.Start(s.logger, adapter.StartStateStart, s.outbound, s.dnsTransport, s.network, s.connection)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = adapter.StartNamed(s.logger, adapter.StartStateStart, []adapter.LifecycleService{s.httpClientService})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = adapter.Start(s.logger, adapter.StartStateStart, s.router, s.dnsRouter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -564,6 +588,14 @@ func (s *Box) Close() error {
|
|||||||
})
|
})
|
||||||
s.logger.Trace("close ", closeItem.name, " completed (", F.Seconds(time.Since(startTime).Seconds()), "s)")
|
s.logger.Trace("close ", closeItem.name, " completed (", F.Seconds(time.Since(startTime).Seconds()), "s)")
|
||||||
}
|
}
|
||||||
|
if s.httpClientService != nil {
|
||||||
|
s.logger.Trace("close ", s.httpClientService.Name())
|
||||||
|
startTime := time.Now()
|
||||||
|
err = E.Append(err, s.httpClientService.Close(), func(err error) error {
|
||||||
|
return E.Cause(err, "close ", s.httpClientService.Name())
|
||||||
|
})
|
||||||
|
s.logger.Trace("close ", s.httpClientService.Name(), " completed (", F.Seconds(time.Since(startTime).Seconds()), "s)")
|
||||||
|
}
|
||||||
for _, lifecycleService := range s.internalService {
|
for _, lifecycleService := range s.internalService {
|
||||||
s.logger.Trace("close ", lifecycleService.Name())
|
s.logger.Trace("close ", lifecycleService.Name())
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
|
|||||||
Submodule clients/android updated: fea0f3a7ba...ab09918615
Submodule clients/apple updated: ffbf405b52...ad7434d676
@@ -204,6 +204,9 @@ func buildApple() {
|
|||||||
"-target", bindTarget,
|
"-target", bindTarget,
|
||||||
"-libname=box",
|
"-libname=box",
|
||||||
"-tags-not-macos=with_low_memory",
|
"-tags-not-macos=with_low_memory",
|
||||||
|
"-iosversion=15.0",
|
||||||
|
"-macosversion=13.0",
|
||||||
|
"-tvosversion=17.0",
|
||||||
}
|
}
|
||||||
//if !withTailscale {
|
//if !withTailscale {
|
||||||
// args = append(args, "-tags-macos="+strings.Join(memcTags, ","))
|
// args = append(args, "-tags-macos="+strings.Join(memcTags, ","))
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/sagernet/sing-box/log"
|
"github.com/sagernet/sing-box/log"
|
||||||
@@ -35,21 +36,9 @@ func updateMozillaIncludedRootCAs() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
geoIndex := slices.Index(header, "Geographic Focus")
|
geoIndex := slices.Index(header, "Geographic Focus")
|
||||||
nameIndex := slices.Index(header, "Common Name or Certificate Name")
|
|
||||||
certIndex := slices.Index(header, "PEM Info")
|
certIndex := slices.Index(header, "PEM Info")
|
||||||
|
|
||||||
generated := strings.Builder{}
|
pemBundle := strings.Builder{}
|
||||||
generated.WriteString(`// Code generated by 'make update_certificates'. DO NOT EDIT.
|
|
||||||
|
|
||||||
package certificate
|
|
||||||
|
|
||||||
import "crypto/x509"
|
|
||||||
|
|
||||||
var mozillaIncluded *x509.CertPool
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
mozillaIncluded = x509.NewCertPool()
|
|
||||||
`)
|
|
||||||
for {
|
for {
|
||||||
record, err := reader.Read()
|
record, err := reader.Read()
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
@@ -60,18 +49,12 @@ func init() {
|
|||||||
if record[geoIndex] == "China" {
|
if record[geoIndex] == "China" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
generated.WriteString("\n // ")
|
|
||||||
generated.WriteString(record[nameIndex])
|
|
||||||
generated.WriteString("\n")
|
|
||||||
generated.WriteString(" mozillaIncluded.AppendCertsFromPEM([]byte(`")
|
|
||||||
cert := record[certIndex]
|
cert := record[certIndex]
|
||||||
// Remove single quotes
|
|
||||||
cert = cert[1 : len(cert)-1]
|
cert = cert[1 : len(cert)-1]
|
||||||
generated.WriteString(cert)
|
pemBundle.WriteString(cert)
|
||||||
generated.WriteString("`))\n")
|
pemBundle.WriteString("\n")
|
||||||
}
|
}
|
||||||
generated.WriteString("}\n")
|
return writeGeneratedCertificateBundle("mozilla", "mozillaIncluded", pemBundle.String())
|
||||||
return os.WriteFile("common/certificate/mozilla.go", []byte(generated.String()), 0o644)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchChinaFingerprints() (map[string]bool, error) {
|
func fetchChinaFingerprints() (map[string]bool, error) {
|
||||||
@@ -119,23 +102,11 @@ func updateChromeIncludedRootCAs() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
subjectIndex := slices.Index(header, "Subject")
|
|
||||||
statusIndex := slices.Index(header, "Google Chrome Status")
|
statusIndex := slices.Index(header, "Google Chrome Status")
|
||||||
certIndex := slices.Index(header, "X.509 Certificate (PEM)")
|
certIndex := slices.Index(header, "X.509 Certificate (PEM)")
|
||||||
fingerprintIndex := slices.Index(header, "SHA-256 Fingerprint")
|
fingerprintIndex := slices.Index(header, "SHA-256 Fingerprint")
|
||||||
|
|
||||||
generated := strings.Builder{}
|
pemBundle := strings.Builder{}
|
||||||
generated.WriteString(`// Code generated by 'make update_certificates'. DO NOT EDIT.
|
|
||||||
|
|
||||||
package certificate
|
|
||||||
|
|
||||||
import "crypto/x509"
|
|
||||||
|
|
||||||
var chromeIncluded *x509.CertPool
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
chromeIncluded = x509.NewCertPool()
|
|
||||||
`)
|
|
||||||
for {
|
for {
|
||||||
record, err := reader.Read()
|
record, err := reader.Read()
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
@@ -149,18 +120,39 @@ func init() {
|
|||||||
if chinaFingerprints[record[fingerprintIndex]] {
|
if chinaFingerprints[record[fingerprintIndex]] {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
generated.WriteString("\n // ")
|
|
||||||
generated.WriteString(record[subjectIndex])
|
|
||||||
generated.WriteString("\n")
|
|
||||||
generated.WriteString(" chromeIncluded.AppendCertsFromPEM([]byte(`")
|
|
||||||
cert := record[certIndex]
|
cert := record[certIndex]
|
||||||
// Remove single quotes if present
|
|
||||||
if len(cert) > 0 && cert[0] == '\'' {
|
if len(cert) > 0 && cert[0] == '\'' {
|
||||||
cert = cert[1 : len(cert)-1]
|
cert = cert[1 : len(cert)-1]
|
||||||
}
|
}
|
||||||
generated.WriteString(cert)
|
pemBundle.WriteString(cert)
|
||||||
generated.WriteString("`))\n")
|
pemBundle.WriteString("\n")
|
||||||
}
|
}
|
||||||
generated.WriteString("}\n")
|
return writeGeneratedCertificateBundle("chrome", "chromeIncluded", pemBundle.String())
|
||||||
return os.WriteFile("common/certificate/chrome.go", []byte(generated.String()), 0o644)
|
}
|
||||||
|
|
||||||
|
func writeGeneratedCertificateBundle(name string, variableName string, pemBundle string) error {
|
||||||
|
goSource := `// Code generated by 'make update_certificates'. DO NOT EDIT.
|
||||||
|
|
||||||
|
package certificate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
_ "embed"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed ` + name + `.pem
|
||||||
|
var ` + variableName + `PEM string
|
||||||
|
|
||||||
|
var ` + variableName + ` *x509.CertPool
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
` + variableName + ` = x509.NewCertPool()
|
||||||
|
` + variableName + `.AppendCertsFromPEM([]byte(` + variableName + `PEM))
|
||||||
|
}
|
||||||
|
`
|
||||||
|
err := os.WriteFile(filepath.Join("common/certificate", name+".pem"), []byte(pemBundle), 0o644)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return os.WriteFile(filepath.Join("common/certificate", name+".go"), []byte(goSource), 0o644)
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
2650
common/certificate/chrome.pem
Normal file
2650
common/certificate/chrome.pem
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
4256
common/certificate/mozilla.pem
Normal file
4256
common/certificate/mozilla.pem
Normal file
File diff suppressed because it is too large
Load Diff
@@ -22,8 +22,10 @@ var _ adapter.CertificateStore = (*Store)(nil)
|
|||||||
|
|
||||||
type Store struct {
|
type Store struct {
|
||||||
access sync.RWMutex
|
access sync.RWMutex
|
||||||
|
store string
|
||||||
systemPool *x509.CertPool
|
systemPool *x509.CertPool
|
||||||
currentPool *x509.CertPool
|
currentPool *x509.CertPool
|
||||||
|
currentPEM []string
|
||||||
certificate string
|
certificate string
|
||||||
certificatePaths []string
|
certificatePaths []string
|
||||||
certificateDirectoryPaths []string
|
certificateDirectoryPaths []string
|
||||||
@@ -61,6 +63,7 @@ func NewStore(ctx context.Context, logger logger.Logger, options option.Certific
|
|||||||
return nil, E.New("unknown certificate store: ", options.Store)
|
return nil, E.New("unknown certificate store: ", options.Store)
|
||||||
}
|
}
|
||||||
store := &Store{
|
store := &Store{
|
||||||
|
store: options.Store,
|
||||||
systemPool: systemPool,
|
systemPool: systemPool,
|
||||||
certificate: strings.Join(options.Certificate, "\n"),
|
certificate: strings.Join(options.Certificate, "\n"),
|
||||||
certificatePaths: options.CertificatePath,
|
certificatePaths: options.CertificatePath,
|
||||||
@@ -123,19 +126,37 @@ func (s *Store) Pool() *x509.CertPool {
|
|||||||
return s.currentPool
|
return s.currentPool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Store) StoreKind() string {
|
||||||
|
return s.store
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) CurrentPEM() []string {
|
||||||
|
s.access.RLock()
|
||||||
|
defer s.access.RUnlock()
|
||||||
|
return append([]string(nil), s.currentPEM...)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Store) update() error {
|
func (s *Store) update() error {
|
||||||
s.access.Lock()
|
s.access.Lock()
|
||||||
defer s.access.Unlock()
|
defer s.access.Unlock()
|
||||||
var currentPool *x509.CertPool
|
var currentPool *x509.CertPool
|
||||||
|
var currentPEM []string
|
||||||
if s.systemPool == nil {
|
if s.systemPool == nil {
|
||||||
currentPool = x509.NewCertPool()
|
currentPool = x509.NewCertPool()
|
||||||
} else {
|
} else {
|
||||||
currentPool = s.systemPool.Clone()
|
currentPool = s.systemPool.Clone()
|
||||||
}
|
}
|
||||||
|
switch s.store {
|
||||||
|
case C.CertificateStoreMozilla:
|
||||||
|
currentPEM = append(currentPEM, mozillaIncludedPEM)
|
||||||
|
case C.CertificateStoreChrome:
|
||||||
|
currentPEM = append(currentPEM, chromeIncludedPEM)
|
||||||
|
}
|
||||||
if s.certificate != "" {
|
if s.certificate != "" {
|
||||||
if !currentPool.AppendCertsFromPEM([]byte(s.certificate)) {
|
if !currentPool.AppendCertsFromPEM([]byte(s.certificate)) {
|
||||||
return E.New("invalid certificate PEM strings")
|
return E.New("invalid certificate PEM strings")
|
||||||
}
|
}
|
||||||
|
currentPEM = append(currentPEM, s.certificate)
|
||||||
}
|
}
|
||||||
for _, path := range s.certificatePaths {
|
for _, path := range s.certificatePaths {
|
||||||
pemContent, err := os.ReadFile(path)
|
pemContent, err := os.ReadFile(path)
|
||||||
@@ -145,6 +166,7 @@ func (s *Store) update() error {
|
|||||||
if !currentPool.AppendCertsFromPEM(pemContent) {
|
if !currentPool.AppendCertsFromPEM(pemContent) {
|
||||||
return E.New("invalid certificate PEM file: ", path)
|
return E.New("invalid certificate PEM file: ", path)
|
||||||
}
|
}
|
||||||
|
currentPEM = append(currentPEM, string(pemContent))
|
||||||
}
|
}
|
||||||
var firstErr error
|
var firstErr error
|
||||||
for _, directoryPath := range s.certificateDirectoryPaths {
|
for _, directoryPath := range s.certificateDirectoryPaths {
|
||||||
@@ -157,8 +179,8 @@ func (s *Store) update() error {
|
|||||||
}
|
}
|
||||||
for _, directoryEntry := range directoryEntries {
|
for _, directoryEntry := range directoryEntries {
|
||||||
pemContent, err := os.ReadFile(filepath.Join(directoryPath, directoryEntry.Name()))
|
pemContent, err := os.ReadFile(filepath.Join(directoryPath, directoryEntry.Name()))
|
||||||
if err == nil {
|
if err == nil && currentPool.AppendCertsFromPEM(pemContent) {
|
||||||
currentPool.AppendCertsFromPEM(pemContent)
|
currentPEM = append(currentPEM, string(pemContent))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -166,6 +188,7 @@ func (s *Store) update() error {
|
|||||||
return firstErr
|
return firstErr
|
||||||
}
|
}
|
||||||
s.currentPool = currentPool
|
s.currentPool = currentPool
|
||||||
|
s.currentPEM = currentPEM
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,17 +19,25 @@ type DirectDialer interface {
|
|||||||
type DetourDialer struct {
|
type DetourDialer struct {
|
||||||
outboundManager adapter.OutboundManager
|
outboundManager adapter.OutboundManager
|
||||||
detour string
|
detour string
|
||||||
legacyDNSDialer bool
|
defaultOutbound bool
|
||||||
|
disableEmptyDirectCheck bool
|
||||||
dialer N.Dialer
|
dialer N.Dialer
|
||||||
initOnce sync.Once
|
initOnce sync.Once
|
||||||
initErr error
|
initErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDetour(outboundManager adapter.OutboundManager, detour string, legacyDNSDialer bool) N.Dialer {
|
func NewDetour(outboundManager adapter.OutboundManager, detour string, disableEmptyDirectCheck bool) N.Dialer {
|
||||||
return &DetourDialer{
|
return &DetourDialer{
|
||||||
outboundManager: outboundManager,
|
outboundManager: outboundManager,
|
||||||
detour: detour,
|
detour: detour,
|
||||||
legacyDNSDialer: legacyDNSDialer,
|
disableEmptyDirectCheck: disableEmptyDirectCheck,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDefaultOutboundDetour(outboundManager adapter.OutboundManager) N.Dialer {
|
||||||
|
return &DetourDialer{
|
||||||
|
outboundManager: outboundManager,
|
||||||
|
defaultOutbound: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -47,12 +55,18 @@ func (d *DetourDialer) Dialer() (N.Dialer, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *DetourDialer) init() {
|
func (d *DetourDialer) init() {
|
||||||
dialer, loaded := d.outboundManager.Outbound(d.detour)
|
var dialer adapter.Outbound
|
||||||
|
if d.detour != "" {
|
||||||
|
var loaded bool
|
||||||
|
dialer, loaded = d.outboundManager.Outbound(d.detour)
|
||||||
if !loaded {
|
if !loaded {
|
||||||
d.initErr = E.New("outbound detour not found: ", d.detour)
|
d.initErr = E.New("outbound detour not found: ", d.detour)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !d.legacyDNSDialer {
|
} else {
|
||||||
|
dialer = d.outboundManager.Default()
|
||||||
|
}
|
||||||
|
if !d.defaultOutbound && !d.disableEmptyDirectCheck {
|
||||||
if directDialer, isDirect := dialer.(DirectDialer); isDirect {
|
if directDialer, isDirect := dialer.(DirectDialer); isDirect {
|
||||||
if directDialer.IsEmpty() {
|
if directDialer.IsEmpty() {
|
||||||
d.initErr = E.New("detour to an empty direct outbound makes no sense")
|
d.initErr = E.New("detour to an empty direct outbound makes no sense")
|
||||||
|
|||||||
@@ -23,8 +23,9 @@ type Options struct {
|
|||||||
DirectResolver bool
|
DirectResolver bool
|
||||||
ResolverOnDetour bool
|
ResolverOnDetour bool
|
||||||
NewDialer bool
|
NewDialer bool
|
||||||
LegacyDNSDialer bool
|
DisableEmptyDirectCheck bool
|
||||||
DirectOutbound bool
|
DirectOutbound bool
|
||||||
|
DefaultOutbound bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: merge with NewWithOptions
|
// TODO: merge with NewWithOptions
|
||||||
@@ -42,19 +43,26 @@ func NewWithOptions(options Options) (N.Dialer, error) {
|
|||||||
dialer N.Dialer
|
dialer N.Dialer
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
hasDetour := dialOptions.Detour != "" || options.DefaultOutbound
|
||||||
if dialOptions.Detour != "" {
|
if dialOptions.Detour != "" {
|
||||||
outboundManager := service.FromContext[adapter.OutboundManager](options.Context)
|
outboundManager := service.FromContext[adapter.OutboundManager](options.Context)
|
||||||
if outboundManager == nil {
|
if outboundManager == nil {
|
||||||
return nil, E.New("missing outbound manager")
|
return nil, E.New("missing outbound manager")
|
||||||
}
|
}
|
||||||
dialer = NewDetour(outboundManager, dialOptions.Detour, options.LegacyDNSDialer)
|
dialer = NewDetour(outboundManager, dialOptions.Detour, options.DisableEmptyDirectCheck)
|
||||||
|
} else if options.DefaultOutbound {
|
||||||
|
outboundManager := service.FromContext[adapter.OutboundManager](options.Context)
|
||||||
|
if outboundManager == nil {
|
||||||
|
return nil, E.New("missing outbound manager")
|
||||||
|
}
|
||||||
|
dialer = NewDefaultOutboundDetour(outboundManager)
|
||||||
} else {
|
} else {
|
||||||
dialer, err = NewDefault(options.Context, dialOptions)
|
dialer, err = NewDefault(options.Context, dialOptions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if options.RemoteIsDomain && (dialOptions.Detour == "" || options.ResolverOnDetour || dialOptions.DomainResolver != nil && dialOptions.DomainResolver.Server != "") {
|
if options.RemoteIsDomain && (!hasDetour || options.ResolverOnDetour || dialOptions.DomainResolver != nil && dialOptions.DomainResolver.Server != "") {
|
||||||
networkManager := service.FromContext[adapter.NetworkManager](options.Context)
|
networkManager := service.FromContext[adapter.NetworkManager](options.Context)
|
||||||
dnsTransport := service.FromContext[adapter.DNSTransportManager](options.Context)
|
dnsTransport := service.FromContext[adapter.DNSTransportManager](options.Context)
|
||||||
var defaultOptions adapter.NetworkOptions
|
var defaultOptions adapter.NetworkOptions
|
||||||
@@ -90,6 +98,7 @@ func NewWithOptions(options Options) (N.Dialer, error) {
|
|||||||
Transport: transport,
|
Transport: transport,
|
||||||
Strategy: strategy,
|
Strategy: strategy,
|
||||||
DisableCache: dialOptions.DomainResolver.DisableCache,
|
DisableCache: dialOptions.DomainResolver.DisableCache,
|
||||||
|
DisableOptimisticCache: dialOptions.DomainResolver.DisableOptimisticCache,
|
||||||
RewriteTTL: dialOptions.DomainResolver.RewriteTTL,
|
RewriteTTL: dialOptions.DomainResolver.RewriteTTL,
|
||||||
ClientSubnet: dialOptions.DomainResolver.ClientSubnet.Build(netip.Prefix{}),
|
ClientSubnet: dialOptions.DomainResolver.ClientSubnet.Build(netip.Prefix{}),
|
||||||
}
|
}
|
||||||
|
|||||||
423
common/httpclient/apple_transport_darwin.go
Normal file
423
common/httpclient/apple_transport_darwin.go
Normal file
@@ -0,0 +1,423 @@
|
|||||||
|
//go:build darwin && cgo
|
||||||
|
|
||||||
|
package httpclient
|
||||||
|
|
||||||
|
/*
|
||||||
|
#cgo CFLAGS: -x objective-c -fobjc-arc
|
||||||
|
#cgo LDFLAGS: -framework Foundation -framework Security
|
||||||
|
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include "apple_transport_darwin.h"
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/common/proxybridge"
|
||||||
|
boxTLS "github.com/sagernet/sing-box/common/tls"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
"github.com/sagernet/sing/common/logger"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
"github.com/sagernet/sing/common/ntp"
|
||||||
|
)
|
||||||
|
|
||||||
|
const applePinnedHashSize = sha256.Size
|
||||||
|
|
||||||
|
func verifyApplePinnedPublicKeySHA256(flatHashes []byte, leafCertificate []byte) error {
|
||||||
|
if len(flatHashes)%applePinnedHashSize != 0 {
|
||||||
|
return E.New("invalid pinned public key list")
|
||||||
|
}
|
||||||
|
knownHashes := make([][]byte, 0, len(flatHashes)/applePinnedHashSize)
|
||||||
|
for offset := 0; offset < len(flatHashes); offset += applePinnedHashSize {
|
||||||
|
knownHashes = append(knownHashes, append([]byte(nil), flatHashes[offset:offset+applePinnedHashSize]...))
|
||||||
|
}
|
||||||
|
return boxTLS.VerifyPublicKeySHA256(knownHashes, [][]byte{leafCertificate})
|
||||||
|
}
|
||||||
|
|
||||||
|
//export box_apple_http_verify_public_key_sha256
|
||||||
|
func box_apple_http_verify_public_key_sha256(knownHashValues *C.uint8_t, knownHashValuesLen C.size_t, leafCert *C.uint8_t, leafCertLen C.size_t) *C.char {
|
||||||
|
flatHashes := C.GoBytes(unsafe.Pointer(knownHashValues), C.int(knownHashValuesLen))
|
||||||
|
leafCertificate := C.GoBytes(unsafe.Pointer(leafCert), C.int(leafCertLen))
|
||||||
|
err := verifyApplePinnedPublicKeySHA256(flatHashes, leafCertificate)
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return C.CString(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
type appleSessionConfig struct {
|
||||||
|
serverName string
|
||||||
|
minVersion uint16
|
||||||
|
maxVersion uint16
|
||||||
|
insecure bool
|
||||||
|
anchorPEM string
|
||||||
|
anchorOnly bool
|
||||||
|
pinnedPublicKeySHA256s []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type appleTransportShared struct {
|
||||||
|
logger logger.ContextLogger
|
||||||
|
bridge *proxybridge.Bridge
|
||||||
|
config appleSessionConfig
|
||||||
|
timeFunc func() time.Time
|
||||||
|
refs atomic.Int32
|
||||||
|
}
|
||||||
|
|
||||||
|
type appleTransport struct {
|
||||||
|
shared *appleTransportShared
|
||||||
|
access sync.Mutex
|
||||||
|
session *C.box_apple_http_session_t
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAppleTransport(ctx context.Context, logger logger.ContextLogger, rawDialer N.Dialer, options option.HTTPClientOptions) (innerTransport, error) {
|
||||||
|
sessionConfig, err := newAppleSessionConfig(ctx, options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
bridge, err := proxybridge.New(ctx, logger, "apple http proxy", rawDialer)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
shared := &appleTransportShared{
|
||||||
|
logger: logger,
|
||||||
|
bridge: bridge,
|
||||||
|
config: sessionConfig,
|
||||||
|
timeFunc: ntp.TimeFuncFromContext(ctx),
|
||||||
|
}
|
||||||
|
shared.refs.Store(1)
|
||||||
|
session, err := shared.newSession()
|
||||||
|
if err != nil {
|
||||||
|
bridge.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &appleTransport{
|
||||||
|
shared: shared,
|
||||||
|
session: session,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAppleSessionConfig(ctx context.Context, options option.HTTPClientOptions) (appleSessionConfig, error) {
|
||||||
|
version := options.Version
|
||||||
|
if version == 0 {
|
||||||
|
version = 2
|
||||||
|
}
|
||||||
|
switch version {
|
||||||
|
case 2:
|
||||||
|
case 1:
|
||||||
|
return appleSessionConfig{}, E.New("HTTP/1.1 is unsupported in Apple HTTP engine")
|
||||||
|
case 3:
|
||||||
|
return appleSessionConfig{}, E.New("HTTP/3 is unsupported in Apple HTTP engine")
|
||||||
|
default:
|
||||||
|
return appleSessionConfig{}, E.New("unknown HTTP version: ", version)
|
||||||
|
}
|
||||||
|
if options.DisableVersionFallback {
|
||||||
|
return appleSessionConfig{}, E.New("disable_version_fallback is unsupported in Apple HTTP engine")
|
||||||
|
}
|
||||||
|
if options.HTTP2Options != (option.HTTP2Options{}) {
|
||||||
|
return appleSessionConfig{}, E.New("HTTP/2 options are unsupported in Apple HTTP engine")
|
||||||
|
}
|
||||||
|
if options.HTTP3Options != (option.QUICOptions{}) {
|
||||||
|
return appleSessionConfig{}, E.New("QUIC options are unsupported in Apple HTTP engine")
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsOptions := common.PtrValueOrDefault(options.TLS)
|
||||||
|
if tlsOptions.Engine != "" {
|
||||||
|
return appleSessionConfig{}, E.New("tls.engine is unsupported in Apple HTTP engine")
|
||||||
|
}
|
||||||
|
if len(tlsOptions.ALPN) > 0 {
|
||||||
|
return appleSessionConfig{}, E.New("tls.alpn is unsupported in Apple HTTP engine")
|
||||||
|
}
|
||||||
|
validated, err := boxTLS.ValidateAppleTLSOptions(ctx, tlsOptions, "Apple HTTP engine")
|
||||||
|
if err != nil {
|
||||||
|
return appleSessionConfig{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
config := appleSessionConfig{
|
||||||
|
serverName: tlsOptions.ServerName,
|
||||||
|
minVersion: validated.MinVersion,
|
||||||
|
maxVersion: validated.MaxVersion,
|
||||||
|
insecure: tlsOptions.Insecure || len(tlsOptions.CertificatePublicKeySHA256) > 0,
|
||||||
|
anchorPEM: validated.AnchorPEM,
|
||||||
|
anchorOnly: validated.AnchorOnly,
|
||||||
|
}
|
||||||
|
if len(tlsOptions.CertificatePublicKeySHA256) > 0 {
|
||||||
|
config.pinnedPublicKeySHA256s = make([]byte, 0, len(tlsOptions.CertificatePublicKeySHA256)*applePinnedHashSize)
|
||||||
|
for _, hashValue := range tlsOptions.CertificatePublicKeySHA256 {
|
||||||
|
if len(hashValue) != applePinnedHashSize {
|
||||||
|
return appleSessionConfig{}, E.New("invalid certificate_public_key_sha256 length: ", len(hashValue))
|
||||||
|
}
|
||||||
|
config.pinnedPublicKeySHA256s = append(config.pinnedPublicKeySHA256s, hashValue...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appleTransportShared) retain() {
|
||||||
|
s.refs.Add(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appleTransportShared) release() error {
|
||||||
|
if s.refs.Add(-1) == 0 {
|
||||||
|
return s.bridge.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appleTransportShared) newSession() (*C.box_apple_http_session_t, error) {
|
||||||
|
cProxyHost := C.CString("127.0.0.1")
|
||||||
|
defer C.free(unsafe.Pointer(cProxyHost))
|
||||||
|
cProxyUsername := C.CString(s.bridge.Username())
|
||||||
|
defer C.free(unsafe.Pointer(cProxyUsername))
|
||||||
|
cProxyPassword := C.CString(s.bridge.Password())
|
||||||
|
defer C.free(unsafe.Pointer(cProxyPassword))
|
||||||
|
var cAnchorPEM *C.char
|
||||||
|
if s.config.anchorPEM != "" {
|
||||||
|
cAnchorPEM = C.CString(s.config.anchorPEM)
|
||||||
|
defer C.free(unsafe.Pointer(cAnchorPEM))
|
||||||
|
}
|
||||||
|
var pinnedPointer *C.uint8_t
|
||||||
|
if len(s.config.pinnedPublicKeySHA256s) > 0 {
|
||||||
|
pinnedPointer = (*C.uint8_t)(C.CBytes(s.config.pinnedPublicKeySHA256s))
|
||||||
|
defer C.free(unsafe.Pointer(pinnedPointer))
|
||||||
|
}
|
||||||
|
cConfig := C.box_apple_http_session_config_t{
|
||||||
|
proxy_host: cProxyHost,
|
||||||
|
proxy_port: C.int(s.bridge.Port()),
|
||||||
|
proxy_username: cProxyUsername,
|
||||||
|
proxy_password: cProxyPassword,
|
||||||
|
min_tls_version: C.uint16_t(s.config.minVersion),
|
||||||
|
max_tls_version: C.uint16_t(s.config.maxVersion),
|
||||||
|
insecure: C.bool(s.config.insecure),
|
||||||
|
anchor_pem: cAnchorPEM,
|
||||||
|
anchor_pem_len: C.size_t(len(s.config.anchorPEM)),
|
||||||
|
anchor_only: C.bool(s.config.anchorOnly),
|
||||||
|
pinned_public_key_sha256: pinnedPointer,
|
||||||
|
pinned_public_key_sha256_len: C.size_t(len(s.config.pinnedPublicKeySHA256s)),
|
||||||
|
}
|
||||||
|
var cErr *C.char
|
||||||
|
session := C.box_apple_http_session_create(&cConfig, &cErr)
|
||||||
|
if session != nil {
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
return nil, appleCStringError(cErr, "create Apple HTTP session")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *appleTransport) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||||
|
if requestRequiresHTTP1(request) {
|
||||||
|
return nil, E.New("HTTP upgrade requests are unsupported in Apple HTTP engine")
|
||||||
|
}
|
||||||
|
if request.URL == nil {
|
||||||
|
return nil, E.New("missing request URL")
|
||||||
|
}
|
||||||
|
switch request.URL.Scheme {
|
||||||
|
case "http", "https":
|
||||||
|
default:
|
||||||
|
return nil, E.New("unsupported URL scheme: ", request.URL.Scheme)
|
||||||
|
}
|
||||||
|
if request.URL.Scheme == "https" && t.shared.config.serverName != "" && !strings.EqualFold(t.shared.config.serverName, request.URL.Hostname()) {
|
||||||
|
return nil, E.New("tls.server_name is unsupported in Apple HTTP engine unless it matches request host")
|
||||||
|
}
|
||||||
|
var body []byte
|
||||||
|
if request.Body != nil && request.Body != http.NoBody {
|
||||||
|
defer request.Body.Close()
|
||||||
|
var err error
|
||||||
|
body, err = io.ReadAll(request.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
headerKeys, headerValues := flattenRequestHeaders(request)
|
||||||
|
cMethod := C.CString(request.Method)
|
||||||
|
defer C.free(unsafe.Pointer(cMethod))
|
||||||
|
cURL := C.CString(request.URL.String())
|
||||||
|
defer C.free(unsafe.Pointer(cURL))
|
||||||
|
cHeaderKeys := make([]*C.char, len(headerKeys))
|
||||||
|
cHeaderValues := make([]*C.char, len(headerValues))
|
||||||
|
defer func() {
|
||||||
|
for _, ptr := range cHeaderKeys {
|
||||||
|
C.free(unsafe.Pointer(ptr))
|
||||||
|
}
|
||||||
|
for _, ptr := range cHeaderValues {
|
||||||
|
C.free(unsafe.Pointer(ptr))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
for index, value := range headerKeys {
|
||||||
|
cHeaderKeys[index] = C.CString(value)
|
||||||
|
}
|
||||||
|
for index, value := range headerValues {
|
||||||
|
cHeaderValues[index] = C.CString(value)
|
||||||
|
}
|
||||||
|
var headerKeysPointer **C.char
|
||||||
|
var headerValuesPointer **C.char
|
||||||
|
if len(cHeaderKeys) > 0 {
|
||||||
|
pointerArraySize := C.size_t(len(cHeaderKeys)) * C.size_t(unsafe.Sizeof((*C.char)(nil)))
|
||||||
|
headerKeysPointer = (**C.char)(C.malloc(pointerArraySize))
|
||||||
|
defer C.free(unsafe.Pointer(headerKeysPointer))
|
||||||
|
headerValuesPointer = (**C.char)(C.malloc(pointerArraySize))
|
||||||
|
defer C.free(unsafe.Pointer(headerValuesPointer))
|
||||||
|
copy(unsafe.Slice(headerKeysPointer, len(cHeaderKeys)), cHeaderKeys)
|
||||||
|
copy(unsafe.Slice(headerValuesPointer, len(cHeaderValues)), cHeaderValues)
|
||||||
|
}
|
||||||
|
var bodyPointer *C.uint8_t
|
||||||
|
if len(body) > 0 {
|
||||||
|
bodyPointer = (*C.uint8_t)(C.CBytes(body))
|
||||||
|
defer C.free(unsafe.Pointer(bodyPointer))
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
hasVerifyTime bool
|
||||||
|
verifyTimeUnixMilli int64
|
||||||
|
)
|
||||||
|
if t.shared.timeFunc != nil {
|
||||||
|
hasVerifyTime = true
|
||||||
|
verifyTimeUnixMilli = t.shared.timeFunc().UnixMilli()
|
||||||
|
}
|
||||||
|
cRequest := C.box_apple_http_request_t{
|
||||||
|
method: cMethod,
|
||||||
|
url: cURL,
|
||||||
|
header_keys: (**C.char)(headerKeysPointer),
|
||||||
|
header_values: (**C.char)(headerValuesPointer),
|
||||||
|
header_count: C.size_t(len(cHeaderKeys)),
|
||||||
|
body: bodyPointer,
|
||||||
|
body_len: C.size_t(len(body)),
|
||||||
|
has_verify_time: C.bool(hasVerifyTime),
|
||||||
|
verify_time_unix_millis: C.int64_t(verifyTimeUnixMilli),
|
||||||
|
}
|
||||||
|
var cErr *C.char
|
||||||
|
var task *C.box_apple_http_task_t
|
||||||
|
t.access.Lock()
|
||||||
|
if t.session == nil {
|
||||||
|
t.access.Unlock()
|
||||||
|
return nil, net.ErrClosed
|
||||||
|
}
|
||||||
|
// Keep the session attached until NSURLSession has created the task.
|
||||||
|
task = C.box_apple_http_session_send_async(t.session, &cRequest, &cErr)
|
||||||
|
t.access.Unlock()
|
||||||
|
if task == nil {
|
||||||
|
return nil, appleCStringError(cErr, "create Apple HTTP request")
|
||||||
|
}
|
||||||
|
cancelDone := make(chan struct{})
|
||||||
|
cancelExit := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(cancelExit)
|
||||||
|
select {
|
||||||
|
case <-request.Context().Done():
|
||||||
|
C.box_apple_http_task_cancel(task)
|
||||||
|
case <-cancelDone:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
cResponse := C.box_apple_http_task_wait(task, &cErr)
|
||||||
|
close(cancelDone)
|
||||||
|
<-cancelExit
|
||||||
|
C.box_apple_http_task_close(task)
|
||||||
|
if cResponse == nil {
|
||||||
|
err := appleCStringError(cErr, "Apple HTTP request failed")
|
||||||
|
if request.Context().Err() != nil {
|
||||||
|
return nil, request.Context().Err()
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer C.box_apple_http_response_free(cResponse)
|
||||||
|
return parseAppleHTTPResponse(request, cResponse), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *appleTransport) CloseIdleConnections() {
|
||||||
|
t.access.Lock()
|
||||||
|
if t.closed {
|
||||||
|
t.access.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.access.Unlock()
|
||||||
|
newSession, err := t.shared.newSession()
|
||||||
|
if err != nil {
|
||||||
|
t.shared.logger.Error(E.Cause(err, "reset Apple HTTP session"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.access.Lock()
|
||||||
|
if t.closed {
|
||||||
|
t.access.Unlock()
|
||||||
|
C.box_apple_http_session_close(newSession)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
oldSession := t.session
|
||||||
|
t.session = newSession
|
||||||
|
t.access.Unlock()
|
||||||
|
C.box_apple_http_session_retire(oldSession)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *appleTransport) Close() error {
|
||||||
|
t.access.Lock()
|
||||||
|
if t.closed {
|
||||||
|
t.access.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
t.closed = true
|
||||||
|
session := t.session
|
||||||
|
t.session = nil
|
||||||
|
t.access.Unlock()
|
||||||
|
C.box_apple_http_session_close(session)
|
||||||
|
return t.shared.release()
|
||||||
|
}
|
||||||
|
|
||||||
|
func flattenRequestHeaders(request *http.Request) ([]string, []string) {
|
||||||
|
var (
|
||||||
|
keys []string
|
||||||
|
values []string
|
||||||
|
)
|
||||||
|
for key, headerValues := range request.Header {
|
||||||
|
for _, value := range headerValues {
|
||||||
|
keys = append(keys, key)
|
||||||
|
values = append(values, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if request.Host != "" {
|
||||||
|
keys = append(keys, "Host")
|
||||||
|
values = append(values, request.Host)
|
||||||
|
}
|
||||||
|
return keys, values
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAppleHTTPResponse(request *http.Request, response *C.box_apple_http_response_t) *http.Response {
|
||||||
|
headers := make(http.Header)
|
||||||
|
headerKeys := unsafe.Slice(response.header_keys, int(response.header_count))
|
||||||
|
headerValues := unsafe.Slice(response.header_values, int(response.header_count))
|
||||||
|
for index := range headerKeys {
|
||||||
|
headers.Add(C.GoString(headerKeys[index]), C.GoString(headerValues[index]))
|
||||||
|
}
|
||||||
|
body := bytes.NewReader(C.GoBytes(unsafe.Pointer(response.body), C.int(response.body_len)))
|
||||||
|
// NSURLSession's completion-handler API does not expose the negotiated protocol;
|
||||||
|
// callers that read Response.Proto will see HTTP/1.1 even when the wire was HTTP/2.
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: int(response.status_code),
|
||||||
|
Status: fmt.Sprintf("%d %s", int(response.status_code), http.StatusText(int(response.status_code))),
|
||||||
|
Proto: "HTTP/1.1",
|
||||||
|
ProtoMajor: 1,
|
||||||
|
ProtoMinor: 1,
|
||||||
|
Header: headers,
|
||||||
|
Body: io.NopCloser(body),
|
||||||
|
ContentLength: int64(body.Len()),
|
||||||
|
Request: request,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func appleCStringError(cErr *C.char, message string) error {
|
||||||
|
if cErr == nil {
|
||||||
|
return E.New(message)
|
||||||
|
}
|
||||||
|
defer C.free(unsafe.Pointer(cErr))
|
||||||
|
return E.New(message, ": ", C.GoString(cErr))
|
||||||
|
}
|
||||||
71
common/httpclient/apple_transport_darwin.h
Normal file
71
common/httpclient/apple_transport_darwin.h
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
#include <stdbool.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
typedef struct box_apple_http_session box_apple_http_session_t;
|
||||||
|
typedef struct box_apple_http_task box_apple_http_task_t;
|
||||||
|
|
||||||
|
typedef struct box_apple_http_session_config {
|
||||||
|
const char *proxy_host;
|
||||||
|
int proxy_port;
|
||||||
|
const char *proxy_username;
|
||||||
|
const char *proxy_password;
|
||||||
|
uint16_t min_tls_version;
|
||||||
|
uint16_t max_tls_version;
|
||||||
|
bool insecure;
|
||||||
|
const char *anchor_pem;
|
||||||
|
size_t anchor_pem_len;
|
||||||
|
bool anchor_only;
|
||||||
|
const uint8_t *pinned_public_key_sha256;
|
||||||
|
size_t pinned_public_key_sha256_len;
|
||||||
|
} box_apple_http_session_config_t;
|
||||||
|
|
||||||
|
typedef struct box_apple_http_request {
|
||||||
|
const char *method;
|
||||||
|
const char *url;
|
||||||
|
const char **header_keys;
|
||||||
|
const char **header_values;
|
||||||
|
size_t header_count;
|
||||||
|
const uint8_t *body;
|
||||||
|
size_t body_len;
|
||||||
|
bool has_verify_time;
|
||||||
|
int64_t verify_time_unix_millis;
|
||||||
|
} box_apple_http_request_t;
|
||||||
|
|
||||||
|
typedef struct box_apple_http_response {
|
||||||
|
int status_code;
|
||||||
|
char **header_keys;
|
||||||
|
char **header_values;
|
||||||
|
size_t header_count;
|
||||||
|
uint8_t *body;
|
||||||
|
size_t body_len;
|
||||||
|
char *error;
|
||||||
|
} box_apple_http_response_t;
|
||||||
|
|
||||||
|
box_apple_http_session_t *box_apple_http_session_create(
|
||||||
|
const box_apple_http_session_config_t *config,
|
||||||
|
char **error_out
|
||||||
|
);
|
||||||
|
void box_apple_http_session_retire(box_apple_http_session_t *session);
|
||||||
|
void box_apple_http_session_close(box_apple_http_session_t *session);
|
||||||
|
|
||||||
|
box_apple_http_task_t *box_apple_http_session_send_async(
|
||||||
|
box_apple_http_session_t *session,
|
||||||
|
const box_apple_http_request_t *request,
|
||||||
|
char **error_out
|
||||||
|
);
|
||||||
|
box_apple_http_response_t *box_apple_http_task_wait(
|
||||||
|
box_apple_http_task_t *task,
|
||||||
|
char **error_out
|
||||||
|
);
|
||||||
|
void box_apple_http_task_cancel(box_apple_http_task_t *task);
|
||||||
|
void box_apple_http_task_close(box_apple_http_task_t *task);
|
||||||
|
|
||||||
|
void box_apple_http_response_free(box_apple_http_response_t *response);
|
||||||
|
|
||||||
|
char *box_apple_http_verify_public_key_sha256(
|
||||||
|
uint8_t *known_hash_values,
|
||||||
|
size_t known_hash_values_len,
|
||||||
|
uint8_t *leaf_cert,
|
||||||
|
size_t leaf_cert_len
|
||||||
|
);
|
||||||
398
common/httpclient/apple_transport_darwin.m
Normal file
398
common/httpclient/apple_transport_darwin.m
Normal file
@@ -0,0 +1,398 @@
|
|||||||
|
#import "apple_transport_darwin.h"
|
||||||
|
|
||||||
|
#import <CoreFoundation/CFStream.h>
|
||||||
|
#import <Foundation/Foundation.h>
|
||||||
|
#import <Security/Security.h>
|
||||||
|
#import <dispatch/dispatch.h>
|
||||||
|
#import <stdlib.h>
|
||||||
|
#import <string.h>
|
||||||
|
|
||||||
|
typedef struct box_apple_http_session {
|
||||||
|
void *handle;
|
||||||
|
} box_apple_http_session_t;
|
||||||
|
|
||||||
|
typedef struct box_apple_http_task {
|
||||||
|
void *task;
|
||||||
|
void *done_semaphore;
|
||||||
|
box_apple_http_response_t *response;
|
||||||
|
char *error;
|
||||||
|
} box_apple_http_task_t;
|
||||||
|
|
||||||
|
static NSString *const box_apple_http_verify_time_key = @"sing-box.verify-time";
|
||||||
|
|
||||||
|
static void box_set_error_string(char **error_out, NSString *message) {
|
||||||
|
if (error_out == NULL || *error_out != NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const char *utf8 = [message UTF8String];
|
||||||
|
*error_out = strdup(utf8 != NULL ? utf8 : "unknown error");
|
||||||
|
}
|
||||||
|
|
||||||
|
static void box_set_error_from_nserror(char **error_out, NSError *error) {
|
||||||
|
if (error == nil) {
|
||||||
|
box_set_error_string(error_out, @"unknown error");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
box_set_error_string(error_out, error.localizedDescription ?: error.description);
|
||||||
|
}
|
||||||
|
|
||||||
|
static NSArray *box_parse_certificates_from_pem(const char *pem, size_t pem_len) {
|
||||||
|
if (pem == NULL || pem_len == 0) {
|
||||||
|
return @[];
|
||||||
|
}
|
||||||
|
NSString *content = [[NSString alloc] initWithBytes:pem length:pem_len encoding:NSUTF8StringEncoding];
|
||||||
|
if (content == nil) {
|
||||||
|
return @[];
|
||||||
|
}
|
||||||
|
NSString *beginMarker = @"-----BEGIN CERTIFICATE-----";
|
||||||
|
NSString *endMarker = @"-----END CERTIFICATE-----";
|
||||||
|
NSMutableArray *certificates = [NSMutableArray array];
|
||||||
|
NSUInteger searchFrom = 0;
|
||||||
|
while (searchFrom < content.length) {
|
||||||
|
NSRange beginRange = [content rangeOfString:beginMarker options:0 range:NSMakeRange(searchFrom, content.length - searchFrom)];
|
||||||
|
if (beginRange.location == NSNotFound) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
NSUInteger bodyStart = beginRange.location + beginRange.length;
|
||||||
|
NSRange endRange = [content rangeOfString:endMarker options:0 range:NSMakeRange(bodyStart, content.length - bodyStart)];
|
||||||
|
if (endRange.location == NSNotFound) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
NSString *base64Section = [content substringWithRange:NSMakeRange(bodyStart, endRange.location - bodyStart)];
|
||||||
|
NSArray<NSString *> *components = [base64Section componentsSeparatedByCharactersInSet:[NSCharacterSet whitespaceAndNewlineCharacterSet]];
|
||||||
|
NSString *base64Content = [components componentsJoinedByString:@""];
|
||||||
|
NSData *der = [[NSData alloc] initWithBase64EncodedString:base64Content options:0];
|
||||||
|
if (der != nil) {
|
||||||
|
SecCertificateRef certificate = SecCertificateCreateWithData(NULL, (__bridge CFDataRef)der);
|
||||||
|
if (certificate != NULL) {
|
||||||
|
[certificates addObject:(__bridge id)certificate];
|
||||||
|
CFRelease(certificate);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
searchFrom = endRange.location + endRange.length;
|
||||||
|
}
|
||||||
|
return certificates;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool box_evaluate_trust(SecTrustRef trustRef, NSArray *anchors, bool anchor_only, NSDate *verifyDate) {
|
||||||
|
if (trustRef == NULL) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (verifyDate != nil && SecTrustSetVerifyDate(trustRef, (__bridge CFDateRef)verifyDate) != errSecSuccess) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (anchors.count > 0 || anchor_only) {
|
||||||
|
CFMutableArrayRef anchorArray = CFArrayCreateMutable(NULL, 0, &kCFTypeArrayCallBacks);
|
||||||
|
for (id certificate in anchors) {
|
||||||
|
CFArrayAppendValue(anchorArray, (__bridge const void *)certificate);
|
||||||
|
}
|
||||||
|
SecTrustSetAnchorCertificates(trustRef, anchorArray);
|
||||||
|
SecTrustSetAnchorCertificatesOnly(trustRef, anchor_only);
|
||||||
|
CFRelease(anchorArray);
|
||||||
|
}
|
||||||
|
CFErrorRef error = NULL;
|
||||||
|
bool result = SecTrustEvaluateWithError(trustRef, &error);
|
||||||
|
if (error != NULL) {
|
||||||
|
CFRelease(error);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
static NSDate *box_apple_http_verify_date_for_request(NSURLRequest *request) {
|
||||||
|
if (request == nil) {
|
||||||
|
return nil;
|
||||||
|
}
|
||||||
|
id value = [NSURLProtocol propertyForKey:box_apple_http_verify_time_key inRequest:request];
|
||||||
|
if (![value isKindOfClass:[NSNumber class]]) {
|
||||||
|
return nil;
|
||||||
|
}
|
||||||
|
return [NSDate dateWithTimeIntervalSince1970:[(NSNumber *)value longLongValue] / 1000.0];
|
||||||
|
}
|
||||||
|
|
||||||
|
static box_apple_http_response_t *box_create_response(NSHTTPURLResponse *httpResponse, NSData *data) {
|
||||||
|
box_apple_http_response_t *response = calloc(1, sizeof(box_apple_http_response_t));
|
||||||
|
response->status_code = (int)httpResponse.statusCode;
|
||||||
|
NSDictionary *headers = httpResponse.allHeaderFields;
|
||||||
|
response->header_count = headers.count;
|
||||||
|
if (response->header_count > 0) {
|
||||||
|
response->header_keys = calloc(response->header_count, sizeof(char *));
|
||||||
|
response->header_values = calloc(response->header_count, sizeof(char *));
|
||||||
|
NSUInteger index = 0;
|
||||||
|
for (id key in headers) {
|
||||||
|
NSString *keyString = [[key description] copy];
|
||||||
|
NSString *valueString = [[headers[key] description] copy];
|
||||||
|
response->header_keys[index] = strdup(keyString.UTF8String ?: "");
|
||||||
|
response->header_values[index] = strdup(valueString.UTF8String ?: "");
|
||||||
|
index++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (data.length > 0) {
|
||||||
|
response->body_len = data.length;
|
||||||
|
response->body = malloc(data.length);
|
||||||
|
memcpy(response->body, data.bytes, data.length);
|
||||||
|
}
|
||||||
|
return response;
|
||||||
|
}
|
||||||
|
|
||||||
|
@interface BoxAppleHTTPSessionDelegate : NSObject <NSURLSessionTaskDelegate, NSURLSessionDataDelegate>
|
||||||
|
@property(nonatomic, assign) BOOL insecure;
|
||||||
|
@property(nonatomic, assign) BOOL anchorOnly;
|
||||||
|
@property(nonatomic, strong) NSArray *anchors;
|
||||||
|
@property(nonatomic, strong) NSData *pinnedPublicKeyHashes;
|
||||||
|
@end
|
||||||
|
|
||||||
|
@implementation BoxAppleHTTPSessionDelegate
|
||||||
|
|
||||||
|
- (void)URLSession:(NSURLSession *)session
|
||||||
|
task:(NSURLSessionTask *)task
|
||||||
|
willPerformHTTPRedirection:(NSHTTPURLResponse *)response
|
||||||
|
newRequest:(NSURLRequest *)request
|
||||||
|
completionHandler:(void (^)(NSURLRequest * _Nullable))completionHandler {
|
||||||
|
completionHandler(nil);
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)URLSession:(NSURLSession *)session
|
||||||
|
task:(NSURLSessionTask *)task
|
||||||
|
didReceiveChallenge:(NSURLAuthenticationChallenge *)challenge
|
||||||
|
completionHandler:(void (^)(NSURLSessionAuthChallengeDisposition disposition, NSURLCredential * _Nullable credential))completionHandler {
|
||||||
|
if (![challenge.protectionSpace.authenticationMethod isEqualToString:NSURLAuthenticationMethodServerTrust]) {
|
||||||
|
completionHandler(NSURLSessionAuthChallengePerformDefaultHandling, nil);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
SecTrustRef trustRef = challenge.protectionSpace.serverTrust;
|
||||||
|
if (trustRef == NULL) {
|
||||||
|
completionHandler(NSURLSessionAuthChallengeCancelAuthenticationChallenge, nil);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
NSDate *verifyDate = box_apple_http_verify_date_for_request(task.currentRequest ?: task.originalRequest);
|
||||||
|
BOOL needsCustomHandling = self.insecure || self.anchorOnly || self.anchors.count > 0 || self.pinnedPublicKeyHashes.length > 0 || verifyDate != nil;
|
||||||
|
if (!needsCustomHandling) {
|
||||||
|
completionHandler(NSURLSessionAuthChallengePerformDefaultHandling, nil);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
BOOL ok = YES;
|
||||||
|
if (!self.insecure) {
|
||||||
|
ok = box_evaluate_trust(trustRef, self.anchors, self.anchorOnly, verifyDate);
|
||||||
|
}
|
||||||
|
if (ok && self.pinnedPublicKeyHashes.length > 0) {
|
||||||
|
CFArrayRef certificateChain = SecTrustCopyCertificateChain(trustRef);
|
||||||
|
SecCertificateRef leafCertificate = NULL;
|
||||||
|
if (certificateChain != NULL && CFArrayGetCount(certificateChain) > 0) {
|
||||||
|
leafCertificate = (SecCertificateRef)CFArrayGetValueAtIndex(certificateChain, 0);
|
||||||
|
}
|
||||||
|
if (leafCertificate == NULL) {
|
||||||
|
ok = NO;
|
||||||
|
} else {
|
||||||
|
NSData *leafData = CFBridgingRelease(SecCertificateCopyData(leafCertificate));
|
||||||
|
char *pinError = box_apple_http_verify_public_key_sha256(
|
||||||
|
(uint8_t *)self.pinnedPublicKeyHashes.bytes,
|
||||||
|
self.pinnedPublicKeyHashes.length,
|
||||||
|
(uint8_t *)leafData.bytes,
|
||||||
|
leafData.length
|
||||||
|
);
|
||||||
|
if (pinError != NULL) {
|
||||||
|
free(pinError);
|
||||||
|
ok = NO;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (certificateChain != NULL) {
|
||||||
|
CFRelease(certificateChain);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!ok) {
|
||||||
|
completionHandler(NSURLSessionAuthChallengeCancelAuthenticationChallenge, nil);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
completionHandler(NSURLSessionAuthChallengeUseCredential, [NSURLCredential credentialForTrust:trustRef]);
|
||||||
|
}
|
||||||
|
|
||||||
|
@end
|
||||||
|
|
||||||
|
@interface BoxAppleHTTPSessionHandle : NSObject
|
||||||
|
@property(nonatomic, strong) NSURLSession *session;
|
||||||
|
@property(nonatomic, strong) BoxAppleHTTPSessionDelegate *delegate;
|
||||||
|
@end
|
||||||
|
|
||||||
|
@implementation BoxAppleHTTPSessionHandle
|
||||||
|
@end
|
||||||
|
|
||||||
|
box_apple_http_session_t *box_apple_http_session_create(
|
||||||
|
const box_apple_http_session_config_t *config,
|
||||||
|
char **error_out
|
||||||
|
) {
|
||||||
|
@autoreleasepool {
|
||||||
|
NSURLSessionConfiguration *sessionConfig = [NSURLSessionConfiguration ephemeralSessionConfiguration];
|
||||||
|
sessionConfig.URLCache = nil;
|
||||||
|
sessionConfig.HTTPCookieStorage = nil;
|
||||||
|
sessionConfig.URLCredentialStorage = nil;
|
||||||
|
sessionConfig.HTTPShouldSetCookies = NO;
|
||||||
|
if (config != NULL && config->proxy_host != NULL && config->proxy_port > 0) {
|
||||||
|
NSMutableDictionary *proxyDictionary = [NSMutableDictionary dictionary];
|
||||||
|
proxyDictionary[(__bridge NSString *)kCFStreamPropertySOCKSProxyHost] = [NSString stringWithUTF8String:config->proxy_host];
|
||||||
|
proxyDictionary[(__bridge NSString *)kCFStreamPropertySOCKSProxyPort] = @(config->proxy_port);
|
||||||
|
proxyDictionary[(__bridge NSString *)kCFStreamPropertySOCKSVersion] = (__bridge NSString *)kCFStreamSocketSOCKSVersion5;
|
||||||
|
if (config->proxy_username != NULL) {
|
||||||
|
proxyDictionary[(__bridge NSString *)kCFStreamPropertySOCKSUser] = [NSString stringWithUTF8String:config->proxy_username];
|
||||||
|
}
|
||||||
|
if (config->proxy_password != NULL) {
|
||||||
|
proxyDictionary[(__bridge NSString *)kCFStreamPropertySOCKSPassword] = [NSString stringWithUTF8String:config->proxy_password];
|
||||||
|
}
|
||||||
|
sessionConfig.connectionProxyDictionary = proxyDictionary;
|
||||||
|
}
|
||||||
|
if (config != NULL && config->min_tls_version != 0) {
|
||||||
|
sessionConfig.TLSMinimumSupportedProtocolVersion = (tls_protocol_version_t)config->min_tls_version;
|
||||||
|
}
|
||||||
|
if (config != NULL && config->max_tls_version != 0) {
|
||||||
|
sessionConfig.TLSMaximumSupportedProtocolVersion = (tls_protocol_version_t)config->max_tls_version;
|
||||||
|
}
|
||||||
|
BoxAppleHTTPSessionDelegate *delegate = [[BoxAppleHTTPSessionDelegate alloc] init];
|
||||||
|
if (config != NULL) {
|
||||||
|
delegate.insecure = config->insecure;
|
||||||
|
delegate.anchorOnly = config->anchor_only;
|
||||||
|
delegate.anchors = box_parse_certificates_from_pem(config->anchor_pem, config->anchor_pem_len);
|
||||||
|
if (config->pinned_public_key_sha256 != NULL && config->pinned_public_key_sha256_len > 0) {
|
||||||
|
delegate.pinnedPublicKeyHashes = [NSData dataWithBytes:config->pinned_public_key_sha256 length:config->pinned_public_key_sha256_len];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
NSURLSession *session = [NSURLSession sessionWithConfiguration:sessionConfig delegate:delegate delegateQueue:nil];
|
||||||
|
if (session == nil) {
|
||||||
|
box_set_error_string(error_out, @"create URLSession");
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
BoxAppleHTTPSessionHandle *handle = [[BoxAppleHTTPSessionHandle alloc] init];
|
||||||
|
handle.session = session;
|
||||||
|
handle.delegate = delegate;
|
||||||
|
box_apple_http_session_t *sessionHandle = calloc(1, sizeof(box_apple_http_session_t));
|
||||||
|
sessionHandle->handle = (__bridge_retained void *)handle;
|
||||||
|
return sessionHandle;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void box_apple_http_session_retire(box_apple_http_session_t *session) {
|
||||||
|
if (session == NULL || session->handle == NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
BoxAppleHTTPSessionHandle *handle = (__bridge_transfer BoxAppleHTTPSessionHandle *)session->handle;
|
||||||
|
[handle.session finishTasksAndInvalidate];
|
||||||
|
free(session);
|
||||||
|
}
|
||||||
|
|
||||||
|
void box_apple_http_session_close(box_apple_http_session_t *session) {
|
||||||
|
if (session == NULL || session->handle == NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
BoxAppleHTTPSessionHandle *handle = (__bridge_transfer BoxAppleHTTPSessionHandle *)session->handle;
|
||||||
|
[handle.session invalidateAndCancel];
|
||||||
|
free(session);
|
||||||
|
}
|
||||||
|
|
||||||
|
box_apple_http_task_t *box_apple_http_session_send_async(
|
||||||
|
box_apple_http_session_t *session,
|
||||||
|
const box_apple_http_request_t *request,
|
||||||
|
char **error_out
|
||||||
|
) {
|
||||||
|
@autoreleasepool {
|
||||||
|
if (session == NULL || session->handle == NULL || request == NULL || request->method == NULL || request->url == NULL) {
|
||||||
|
box_set_error_string(error_out, @"invalid apple HTTP request");
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
BoxAppleHTTPSessionHandle *handle = (__bridge BoxAppleHTTPSessionHandle *)session->handle;
|
||||||
|
NSURL *requestURL = [NSURL URLWithString:[NSString stringWithUTF8String:request->url]];
|
||||||
|
if (requestURL == nil) {
|
||||||
|
box_set_error_string(error_out, @"invalid request URL");
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
NSMutableURLRequest *urlRequest = [NSMutableURLRequest requestWithURL:requestURL];
|
||||||
|
urlRequest.HTTPMethod = [NSString stringWithUTF8String:request->method];
|
||||||
|
for (size_t index = 0; index < request->header_count; index++) {
|
||||||
|
const char *key = request->header_keys[index];
|
||||||
|
const char *value = request->header_values[index];
|
||||||
|
if (key == NULL || value == NULL) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
[urlRequest addValue:[NSString stringWithUTF8String:value] forHTTPHeaderField:[NSString stringWithUTF8String:key]];
|
||||||
|
}
|
||||||
|
if (request->body != NULL && request->body_len > 0) {
|
||||||
|
urlRequest.HTTPBody = [NSData dataWithBytes:request->body length:request->body_len];
|
||||||
|
}
|
||||||
|
if (request->has_verify_time) {
|
||||||
|
[NSURLProtocol setProperty:@(request->verify_time_unix_millis) forKey:box_apple_http_verify_time_key inRequest:urlRequest];
|
||||||
|
}
|
||||||
|
box_apple_http_task_t *task = calloc(1, sizeof(box_apple_http_task_t));
|
||||||
|
dispatch_semaphore_t doneSemaphore = dispatch_semaphore_create(0);
|
||||||
|
task->done_semaphore = (__bridge_retained void *)doneSemaphore;
|
||||||
|
NSURLSessionDataTask *dataTask = [handle.session dataTaskWithRequest:urlRequest completionHandler:^(NSData *data, NSURLResponse *response, NSError *error) {
|
||||||
|
if (error != nil) {
|
||||||
|
box_set_error_from_nserror(&task->error, error);
|
||||||
|
} else if (![response isKindOfClass:[NSHTTPURLResponse class]]) {
|
||||||
|
box_set_error_string(&task->error, @"unexpected HTTP response type");
|
||||||
|
} else {
|
||||||
|
task->response = box_create_response((NSHTTPURLResponse *)response, data ?: [NSData data]);
|
||||||
|
}
|
||||||
|
dispatch_semaphore_signal((__bridge dispatch_semaphore_t)task->done_semaphore);
|
||||||
|
}];
|
||||||
|
if (dataTask == nil) {
|
||||||
|
box_set_error_string(error_out, @"create data task");
|
||||||
|
box_apple_http_task_close(task);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
task->task = (__bridge_retained void *)dataTask;
|
||||||
|
[dataTask resume];
|
||||||
|
return task;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
box_apple_http_response_t *box_apple_http_task_wait(
|
||||||
|
box_apple_http_task_t *task,
|
||||||
|
char **error_out
|
||||||
|
) {
|
||||||
|
if (task == NULL || task->done_semaphore == NULL) {
|
||||||
|
box_set_error_string(error_out, @"invalid apple HTTP task");
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
dispatch_semaphore_wait((__bridge dispatch_semaphore_t)task->done_semaphore, DISPATCH_TIME_FOREVER);
|
||||||
|
if (task->error != NULL) {
|
||||||
|
box_set_error_string(error_out, [NSString stringWithUTF8String:task->error]);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
return task->response;
|
||||||
|
}
|
||||||
|
|
||||||
|
void box_apple_http_task_cancel(box_apple_http_task_t *task) {
|
||||||
|
if (task == NULL || task->task == NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
NSURLSessionTask *nsTask = (__bridge NSURLSessionTask *)task->task;
|
||||||
|
[nsTask cancel];
|
||||||
|
}
|
||||||
|
|
||||||
|
void box_apple_http_task_close(box_apple_http_task_t *task) {
|
||||||
|
if (task == NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (task->task != NULL) {
|
||||||
|
__unused NSURLSessionTask *nsTask = (__bridge_transfer NSURLSessionTask *)task->task;
|
||||||
|
task->task = NULL;
|
||||||
|
}
|
||||||
|
if (task->done_semaphore != NULL) {
|
||||||
|
__unused dispatch_semaphore_t doneSemaphore = (__bridge_transfer dispatch_semaphore_t)task->done_semaphore;
|
||||||
|
task->done_semaphore = NULL;
|
||||||
|
}
|
||||||
|
free(task->error);
|
||||||
|
free(task);
|
||||||
|
}
|
||||||
|
|
||||||
|
void box_apple_http_response_free(box_apple_http_response_t *response) {
|
||||||
|
if (response == NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (size_t index = 0; index < response->header_count; index++) {
|
||||||
|
free(response->header_keys[index]);
|
||||||
|
free(response->header_values[index]);
|
||||||
|
}
|
||||||
|
free(response->header_keys);
|
||||||
|
free(response->header_values);
|
||||||
|
free(response->body);
|
||||||
|
free(response->error);
|
||||||
|
free(response);
|
||||||
|
}
|
||||||
855
common/httpclient/apple_transport_darwin_test.go
Normal file
855
common/httpclient/apple_transport_darwin_test.go
Normal file
@@ -0,0 +1,855 @@
|
|||||||
|
//go:build darwin && cgo
|
||||||
|
|
||||||
|
package httpclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
stdtls "crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"slices"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
boxTLS "github.com/sagernet/sing-box/common/tls"
|
||||||
|
"github.com/sagernet/sing-box/log"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
"github.com/sagernet/sing-box/route"
|
||||||
|
"github.com/sagernet/sing/common/json/badoption"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
"github.com/sagernet/sing/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
const appleHTTPTestTimeout = 5 * time.Second
|
||||||
|
|
||||||
|
const appleHTTPRecoveryLoops = 5
|
||||||
|
|
||||||
|
type appleHTTPTestDialer struct {
|
||||||
|
dialer net.Dialer
|
||||||
|
listener net.ListenConfig
|
||||||
|
hostMap map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
type appleHTTPObservedRequest struct {
|
||||||
|
method string
|
||||||
|
body string
|
||||||
|
host string
|
||||||
|
values []string
|
||||||
|
protoMajor int
|
||||||
|
}
|
||||||
|
|
||||||
|
type appleHTTPTestServer struct {
|
||||||
|
server *httptest.Server
|
||||||
|
baseURL string
|
||||||
|
dialHost string
|
||||||
|
certificate stdtls.Certificate
|
||||||
|
certificatePEM string
|
||||||
|
publicKeyHash []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewAppleSessionConfig(t *testing.T) {
|
||||||
|
serverCertificate, serverCertificatePEM := newAppleHTTPTestCertificate(t, "localhost")
|
||||||
|
serverHash := certificatePublicKeySHA256(t, serverCertificate.Certificate[0])
|
||||||
|
otherHash := bytes.Repeat([]byte{0x7f}, applePinnedHashSize)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
options option.HTTPClientOptions
|
||||||
|
check func(t *testing.T, config appleSessionConfig)
|
||||||
|
wantErr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success with certificate anchors",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
Version: 2,
|
||||||
|
DialerOptions: option.DialerOptions{
|
||||||
|
ConnectTimeout: badoption.Duration(2 * time.Second),
|
||||||
|
},
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
ServerName: "localhost",
|
||||||
|
MinVersion: "1.2",
|
||||||
|
MaxVersion: "1.3",
|
||||||
|
Certificate: badoption.Listable[string]{serverCertificatePEM},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
check: func(t *testing.T, config appleSessionConfig) {
|
||||||
|
t.Helper()
|
||||||
|
if config.serverName != "localhost" {
|
||||||
|
t.Fatalf("unexpected server name: %q", config.serverName)
|
||||||
|
}
|
||||||
|
if config.minVersion != stdtls.VersionTLS12 {
|
||||||
|
t.Fatalf("unexpected min version: %x", config.minVersion)
|
||||||
|
}
|
||||||
|
if config.maxVersion != stdtls.VersionTLS13 {
|
||||||
|
t.Fatalf("unexpected max version: %x", config.maxVersion)
|
||||||
|
}
|
||||||
|
if config.insecure {
|
||||||
|
t.Fatal("unexpected insecure flag")
|
||||||
|
}
|
||||||
|
if !config.anchorOnly {
|
||||||
|
t.Fatal("expected anchor_only")
|
||||||
|
}
|
||||||
|
if !strings.Contains(config.anchorPEM, "BEGIN CERTIFICATE") {
|
||||||
|
t.Fatalf("unexpected anchor pem: %q", config.anchorPEM)
|
||||||
|
}
|
||||||
|
if len(config.pinnedPublicKeySHA256s) != 0 {
|
||||||
|
t.Fatalf("unexpected pinned hashes: %d", len(config.pinnedPublicKeySHA256s))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "success with flattened pins",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
Version: 2,
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
Insecure: true,
|
||||||
|
CertificatePublicKeySHA256: badoption.Listable[[]byte]{serverHash, otherHash},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
check: func(t *testing.T, config appleSessionConfig) {
|
||||||
|
t.Helper()
|
||||||
|
if !config.insecure {
|
||||||
|
t.Fatal("expected insecure flag")
|
||||||
|
}
|
||||||
|
if len(config.pinnedPublicKeySHA256s) != 2*applePinnedHashSize {
|
||||||
|
t.Fatalf("unexpected flattened pin length: %d", len(config.pinnedPublicKeySHA256s))
|
||||||
|
}
|
||||||
|
if !bytes.Equal(config.pinnedPublicKeySHA256s[:applePinnedHashSize], serverHash) {
|
||||||
|
t.Fatal("unexpected first pin")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(config.pinnedPublicKeySHA256s[applePinnedHashSize:], otherHash) {
|
||||||
|
t.Fatal("unexpected second pin")
|
||||||
|
}
|
||||||
|
if config.anchorPEM != "" {
|
||||||
|
t.Fatalf("unexpected anchor pem: %q", config.anchorPEM)
|
||||||
|
}
|
||||||
|
if config.anchorOnly {
|
||||||
|
t.Fatal("unexpected anchor_only")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "http11 unsupported",
|
||||||
|
options: option.HTTPClientOptions{Version: 1},
|
||||||
|
wantErr: "HTTP/1.1 is unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "http3 unsupported",
|
||||||
|
options: option.HTTPClientOptions{Version: 3},
|
||||||
|
wantErr: "HTTP/3 is unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown version",
|
||||||
|
options: option.HTTPClientOptions{Version: 9},
|
||||||
|
wantErr: "unknown HTTP version: 9",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "disable version fallback unsupported",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
DisableVersionFallback: true,
|
||||||
|
},
|
||||||
|
wantErr: "disable_version_fallback is unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "http2 options unsupported",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
HTTP2Options: option.HTTP2Options{
|
||||||
|
IdleTimeout: badoption.Duration(time.Second),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "HTTP/2 options are unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "quic options unsupported",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
HTTP3Options: option.QUICOptions{
|
||||||
|
InitialPacketSize: 1200,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "QUIC options are unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tls engine unsupported",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{Engine: "go"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "tls.engine is unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "disable sni unsupported",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{DisableSNI: true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "disable_sni is unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "alpn unsupported",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{
|
||||||
|
ALPN: badoption.Listable[string]{"h2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "tls.alpn is unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cipher suites unsupported",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{
|
||||||
|
CipherSuites: badoption.Listable[string]{"TLS_AES_128_GCM_SHA256"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "cipher_suites is unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "curve preferences unsupported",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{
|
||||||
|
CurvePreferences: badoption.Listable[option.CurvePreference]{option.CurvePreference(option.X25519)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "curve_preferences is unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "client certificate unsupported",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{
|
||||||
|
ClientCertificate: badoption.Listable[string]{"client-certificate"},
|
||||||
|
ClientKey: badoption.Listable[string]{"client-key"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "client certificate is unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tls fragment unsupported",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{Fragment: true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "tls fragment is unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ktls unsupported",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{KernelTx: true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "ktls is unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ech unsupported",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{
|
||||||
|
ECH: &option.OutboundECHOptions{Enabled: true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "ech is unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "utls unsupported",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{
|
||||||
|
UTLS: &option.OutboundUTLSOptions{Enabled: true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "utls is unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "reality unsupported",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{
|
||||||
|
Reality: &option.OutboundRealityOptions{Enabled: true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "reality is unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "pin and certificate conflict",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{
|
||||||
|
Certificate: badoption.Listable[string]{serverCertificatePEM},
|
||||||
|
CertificatePublicKeySHA256: badoption.Listable[[]byte]{serverHash},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "certificate_public_key_sha256 is conflict with certificate or certificate_path",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid min version",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{MinVersion: "bogus"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "parse min_version",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid max version",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{MaxVersion: "bogus"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "parse max_version",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid pin length",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{
|
||||||
|
CertificatePublicKeySHA256: badoption.Listable[[]byte]{{0x01, 0x02}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "invalid certificate_public_key_sha256 length: 2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
config, err := newAppleSessionConfig(context.Background(), testCase.options)
|
||||||
|
if testCase.wantErr != "" {
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), testCase.wantErr) {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if testCase.check != nil {
|
||||||
|
testCase.check(t, config)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppleTransportVerifyPublicKeySHA256(t *testing.T) {
|
||||||
|
serverCertificate, _ := newAppleHTTPTestCertificate(t, "localhost")
|
||||||
|
goodHash := certificatePublicKeySHA256(t, serverCertificate.Certificate[0])
|
||||||
|
badHash := append([]byte(nil), goodHash...)
|
||||||
|
badHash[0] ^= 0xff
|
||||||
|
|
||||||
|
err := verifyApplePinnedPublicKeySHA256(goodHash, serverCertificate.Certificate[0])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected correct pin to succeed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = verifyApplePinnedPublicKeySHA256(badHash, serverCertificate.Certificate[0])
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected incorrect pin to fail")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "unrecognized remote public key") {
|
||||||
|
t.Fatalf("unexpected pin mismatch error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = verifyApplePinnedPublicKeySHA256(goodHash[:applePinnedHashSize-1], serverCertificate.Certificate[0])
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected malformed pin list to fail")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "invalid pinned public key list") {
|
||||||
|
t.Fatalf("unexpected malformed pin error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppleTransportRoundTripHTTPS(t *testing.T) {
|
||||||
|
requests := make(chan appleHTTPObservedRequest, 1)
|
||||||
|
server := startAppleHTTPTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
requests <- appleHTTPObservedRequest{
|
||||||
|
method: r.Method,
|
||||||
|
body: string(body),
|
||||||
|
host: r.Host,
|
||||||
|
values: append([]string(nil), r.Header.Values("X-Test")...),
|
||||||
|
protoMajor: r.ProtoMajor,
|
||||||
|
}
|
||||||
|
w.Header().Set("X-Reply", "apple")
|
||||||
|
w.WriteHeader(http.StatusCreated)
|
||||||
|
_, _ = w.Write([]byte("response body"))
|
||||||
|
})
|
||||||
|
|
||||||
|
transport := newAppleHTTPTestTransport(t, server, option.HTTPClientOptions{
|
||||||
|
Version: 2,
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: appleHTTPServerTLSOptions(server),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
request, err := http.NewRequest(http.MethodPost, server.URL("/roundtrip"), bytes.NewReader([]byte("request body")))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
request.Header.Add("X-Test", "one")
|
||||||
|
request.Header.Add("X-Test", "two")
|
||||||
|
request.Host = "custom.example"
|
||||||
|
|
||||||
|
response, err := transport.RoundTrip(request)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer response.Body.Close()
|
||||||
|
|
||||||
|
responseBody := readResponseBody(t, response)
|
||||||
|
if response.StatusCode != http.StatusCreated {
|
||||||
|
t.Fatalf("unexpected status code: %d", response.StatusCode)
|
||||||
|
}
|
||||||
|
if response.Status != "201 Created" {
|
||||||
|
t.Fatalf("unexpected status: %q", response.Status)
|
||||||
|
}
|
||||||
|
if response.Header.Get("X-Reply") != "apple" {
|
||||||
|
t.Fatalf("unexpected response header: %q", response.Header.Get("X-Reply"))
|
||||||
|
}
|
||||||
|
if responseBody != "response body" {
|
||||||
|
t.Fatalf("unexpected response body: %q", responseBody)
|
||||||
|
}
|
||||||
|
if response.ContentLength != int64(len(responseBody)) {
|
||||||
|
t.Fatalf("unexpected content length: %d", response.ContentLength)
|
||||||
|
}
|
||||||
|
|
||||||
|
observed := waitObservedRequest(t, requests)
|
||||||
|
if observed.method != http.MethodPost {
|
||||||
|
t.Fatalf("unexpected method: %q", observed.method)
|
||||||
|
}
|
||||||
|
if observed.body != "request body" {
|
||||||
|
t.Fatalf("unexpected request body: %q", observed.body)
|
||||||
|
}
|
||||||
|
if observed.host != "custom.example" {
|
||||||
|
t.Fatalf("unexpected host: %q", observed.host)
|
||||||
|
}
|
||||||
|
if observed.protoMajor != 2 {
|
||||||
|
t.Fatalf("expected HTTP/2 request, got HTTP/%d", observed.protoMajor)
|
||||||
|
}
|
||||||
|
var normalizedValues []string
|
||||||
|
for _, value := range observed.values {
|
||||||
|
for _, part := range strings.Split(value, ",") {
|
||||||
|
normalizedValues = append(normalizedValues, strings.TrimSpace(part))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
slices.Sort(normalizedValues)
|
||||||
|
if !slices.Equal(normalizedValues, []string{"one", "two"}) {
|
||||||
|
t.Fatalf("unexpected header values: %#v", observed.values)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppleTransportPinnedPublicKey(t *testing.T) {
|
||||||
|
server := startAppleHTTPTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("pinned"))
|
||||||
|
})
|
||||||
|
|
||||||
|
goodTransport := newAppleHTTPTestTransport(t, server, option.HTTPClientOptions{
|
||||||
|
Version: 2,
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
ServerName: "localhost",
|
||||||
|
Insecure: true,
|
||||||
|
CertificatePublicKeySHA256: badoption.Listable[[]byte]{server.publicKeyHash},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
response, err := goodTransport.RoundTrip(newAppleHTTPRequest(t, http.MethodGet, server.URL("/good"), nil))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected pinned request to succeed: %v", err)
|
||||||
|
}
|
||||||
|
response.Body.Close()
|
||||||
|
|
||||||
|
badHash := append([]byte(nil), server.publicKeyHash...)
|
||||||
|
badHash[0] ^= 0xff
|
||||||
|
badTransport := newAppleHTTPTestTransport(t, server, option.HTTPClientOptions{
|
||||||
|
Version: 2,
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
ServerName: "localhost",
|
||||||
|
Insecure: true,
|
||||||
|
CertificatePublicKeySHA256: badoption.Listable[[]byte]{badHash},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
response, err = badTransport.RoundTrip(newAppleHTTPRequest(t, http.MethodGet, server.URL("/bad"), nil))
|
||||||
|
if err == nil {
|
||||||
|
response.Body.Close()
|
||||||
|
t.Fatal("expected incorrect pinned public key to fail")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppleTransportGuardrails(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
options option.HTTPClientOptions
|
||||||
|
buildRequest func(t *testing.T) *http.Request
|
||||||
|
wantErrSubstr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "websocket upgrade rejected",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
Version: 2,
|
||||||
|
},
|
||||||
|
buildRequest: func(t *testing.T) *http.Request {
|
||||||
|
t.Helper()
|
||||||
|
request := newAppleHTTPRequest(t, http.MethodGet, "https://localhost/socket", nil)
|
||||||
|
request.Header.Set("Connection", "Upgrade")
|
||||||
|
request.Header.Set("Upgrade", "websocket")
|
||||||
|
return request
|
||||||
|
},
|
||||||
|
wantErrSubstr: "HTTP upgrade requests are unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing url rejected",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
Version: 2,
|
||||||
|
},
|
||||||
|
buildRequest: func(t *testing.T) *http.Request {
|
||||||
|
t.Helper()
|
||||||
|
return &http.Request{Method: http.MethodGet}
|
||||||
|
},
|
||||||
|
wantErrSubstr: "missing request URL",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unsupported scheme rejected",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
Version: 2,
|
||||||
|
},
|
||||||
|
buildRequest: func(t *testing.T) *http.Request {
|
||||||
|
t.Helper()
|
||||||
|
return newAppleHTTPRequest(t, http.MethodGet, "ftp://localhost/file", nil)
|
||||||
|
},
|
||||||
|
wantErrSubstr: "unsupported URL scheme: ftp",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "server name mismatch rejected",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
Version: 2,
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
ServerName: "example.com",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
buildRequest: func(t *testing.T) *http.Request {
|
||||||
|
t.Helper()
|
||||||
|
return newAppleHTTPRequest(t, http.MethodGet, "https://localhost/path", nil)
|
||||||
|
},
|
||||||
|
wantErrSubstr: "tls.server_name is unsupported in Apple HTTP engine unless it matches request host",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
transport := newAppleHTTPTestTransport(t, nil, testCase.options)
|
||||||
|
response, err := transport.RoundTrip(testCase.buildRequest(t))
|
||||||
|
if err == nil {
|
||||||
|
response.Body.Close()
|
||||||
|
t.Fatal("expected error")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), testCase.wantErrSubstr) {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppleTransportCancellationRecovery(t *testing.T) {
|
||||||
|
server := startAppleHTTPTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/block":
|
||||||
|
select {
|
||||||
|
case <-r.Context().Done():
|
||||||
|
return
|
||||||
|
case <-time.After(appleHTTPTestTimeout):
|
||||||
|
http.Error(w, "request was not canceled", http.StatusGatewayTimeout)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
transport := newAppleHTTPTestTransport(t, server, option.HTTPClientOptions{
|
||||||
|
Version: 2,
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: appleHTTPServerTLSOptions(server),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
for index := 0; index < appleHTTPRecoveryLoops; index++ {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
request := newAppleHTTPRequestWithContext(t, ctx, http.MethodGet, server.URL("/block"), nil)
|
||||||
|
response, err := transport.RoundTrip(request)
|
||||||
|
cancel()
|
||||||
|
if err == nil {
|
||||||
|
response.Body.Close()
|
||||||
|
t.Fatalf("iteration %d: expected cancellation error", index)
|
||||||
|
}
|
||||||
|
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatalf("iteration %d: unexpected cancellation error: %v", index, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err = transport.RoundTrip(newAppleHTTPRequest(t, http.MethodGet, server.URL("/ok"), nil))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("iteration %d: follow-up request failed: %v", index, err)
|
||||||
|
}
|
||||||
|
if body := readResponseBody(t, response); body != "ok" {
|
||||||
|
response.Body.Close()
|
||||||
|
t.Fatalf("iteration %d: unexpected follow-up body: %q", index, body)
|
||||||
|
}
|
||||||
|
response.Body.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppleTransportLifecycle(t *testing.T) {
|
||||||
|
server := startAppleHTTPTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
})
|
||||||
|
|
||||||
|
transport := newAppleHTTPTestTransport(t, server, option.HTTPClientOptions{
|
||||||
|
Version: 2,
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: appleHTTPServerTLSOptions(server),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
assertAppleHTTPSucceeds(t, transport, server.URL("/original"))
|
||||||
|
|
||||||
|
transport.CloseIdleConnections()
|
||||||
|
assertAppleHTTPSucceeds(t, transport, server.URL("/reset"))
|
||||||
|
|
||||||
|
innerTransport := transport.(*appleTransport)
|
||||||
|
if err := innerTransport.Close(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := innerTransport.RoundTrip(newAppleHTTPRequest(t, http.MethodGet, server.URL("/closed"), nil))
|
||||||
|
if err == nil {
|
||||||
|
response.Body.Close()
|
||||||
|
t.Fatal("expected closed transport to fail")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, net.ErrClosed) {
|
||||||
|
t.Fatalf("unexpected closed transport error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func startAppleHTTPTestServer(t *testing.T, handler http.HandlerFunc) *appleHTTPTestServer {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
serverCertificate, serverCertificatePEM := newAppleHTTPTestCertificate(t, "localhost")
|
||||||
|
server := httptest.NewUnstartedServer(handler)
|
||||||
|
server.EnableHTTP2 = true
|
||||||
|
server.TLS = &stdtls.Config{
|
||||||
|
Certificates: []stdtls.Certificate{serverCertificate},
|
||||||
|
MinVersion: stdtls.VersionTLS12,
|
||||||
|
}
|
||||||
|
server.StartTLS()
|
||||||
|
t.Cleanup(server.Close)
|
||||||
|
|
||||||
|
parsedURL, err := url.Parse(server.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
baseURL := *parsedURL
|
||||||
|
baseURL.Host = net.JoinHostPort("localhost", parsedURL.Port())
|
||||||
|
|
||||||
|
return &appleHTTPTestServer{
|
||||||
|
server: server,
|
||||||
|
baseURL: baseURL.String(),
|
||||||
|
dialHost: parsedURL.Hostname(),
|
||||||
|
certificate: serverCertificate,
|
||||||
|
certificatePEM: serverCertificatePEM,
|
||||||
|
publicKeyHash: certificatePublicKeySHA256(t, serverCertificate.Certificate[0]),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appleHTTPTestServer) URL(path string) string {
|
||||||
|
if path == "" {
|
||||||
|
return s.baseURL
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(path, "/") {
|
||||||
|
return s.baseURL + path
|
||||||
|
}
|
||||||
|
return s.baseURL + "/" + path
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAppleHTTPTestTransport(t *testing.T, server *appleHTTPTestServer, options option.HTTPClientOptions) innerTransport {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
ctx := service.ContextWith[adapter.ConnectionManager](
|
||||||
|
context.Background(),
|
||||||
|
route.NewConnectionManager(log.NewNOPFactory().NewLogger("connection")),
|
||||||
|
)
|
||||||
|
dialer := &appleHTTPTestDialer{
|
||||||
|
hostMap: make(map[string]string),
|
||||||
|
}
|
||||||
|
if server != nil {
|
||||||
|
dialer.hostMap["localhost"] = server.dialHost
|
||||||
|
}
|
||||||
|
|
||||||
|
transport, err := newAppleTransport(ctx, log.NewNOPFactory().NewLogger("httpclient"), dialer, options)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = transport.Close()
|
||||||
|
})
|
||||||
|
return transport
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *appleHTTPTestDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||||
|
host := destination.AddrString()
|
||||||
|
if destination.IsDomain() {
|
||||||
|
host = destination.Fqdn
|
||||||
|
if mappedHost, loaded := d.hostMap[host]; loaded {
|
||||||
|
host = mappedHost
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return d.dialer.DialContext(ctx, network, net.JoinHostPort(host, strconv.Itoa(int(destination.Port))))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *appleHTTPTestDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||||
|
host := destination.AddrString()
|
||||||
|
if destination.IsDomain() {
|
||||||
|
host = destination.Fqdn
|
||||||
|
if mappedHost, loaded := d.hostMap[host]; loaded {
|
||||||
|
host = mappedHost
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if host == "" {
|
||||||
|
host = "127.0.0.1"
|
||||||
|
}
|
||||||
|
return d.listener.ListenPacket(ctx, N.NetworkUDP, net.JoinHostPort(host, strconv.Itoa(int(destination.Port))))
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAppleHTTPTestCertificate(t *testing.T, serverName string) (stdtls.Certificate, string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
privateKeyPEM, certificatePEM, err := boxTLS.GenerateCertificate(nil, nil, time.Now, serverName, time.Now().Add(time.Hour))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
certificate, err := stdtls.X509KeyPair(certificatePEM, privateKeyPEM)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return certificate, string(certificatePEM)
|
||||||
|
}
|
||||||
|
|
||||||
|
func certificatePublicKeySHA256(t *testing.T, certificateDER []byte) []byte {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
certificate, err := x509.ParseCertificate(certificateDER)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
publicKeyDER, err := x509.MarshalPKIXPublicKey(certificate.PublicKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
hashValue := sha256.Sum256(publicKeyDER)
|
||||||
|
return append([]byte(nil), hashValue[:]...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func appleHTTPServerTLSOptions(server *appleHTTPTestServer) *option.OutboundTLSOptions {
|
||||||
|
return &option.OutboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
ServerName: "localhost",
|
||||||
|
Certificate: badoption.Listable[string]{server.certificatePEM},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAppleHTTPRequest(t *testing.T, method string, rawURL string, body []byte) *http.Request {
|
||||||
|
t.Helper()
|
||||||
|
return newAppleHTTPRequestWithContext(t, context.Background(), method, rawURL, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAppleHTTPRequestWithContext(t *testing.T, ctx context.Context, method string, rawURL string, body []byte) *http.Request {
|
||||||
|
t.Helper()
|
||||||
|
request, err := http.NewRequestWithContext(ctx, method, rawURL, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return request
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitObservedRequest(t *testing.T, requests <-chan appleHTTPObservedRequest) appleHTTPObservedRequest {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case request := <-requests:
|
||||||
|
return request
|
||||||
|
case <-time.After(appleHTTPTestTimeout):
|
||||||
|
t.Fatal("timed out waiting for observed request")
|
||||||
|
return appleHTTPObservedRequest{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func readResponseBody(t *testing.T, response *http.Response) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(response.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return string(body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertAppleHTTPSucceeds(t *testing.T, transport http.RoundTripper, rawURL string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
response, err := transport.RoundTrip(newAppleHTTPRequest(t, http.MethodGet, rawURL, nil))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer response.Body.Close()
|
||||||
|
if body := readResponseBody(t, response); body != "ok" {
|
||||||
|
t.Fatalf("unexpected response body: %q", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
16
common/httpclient/apple_transport_stub.go
Normal file
16
common/httpclient/apple_transport_stub.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
//go:build !darwin || !cgo
|
||||||
|
|
||||||
|
package httpclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
"github.com/sagernet/sing/common/logger"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newAppleTransport(ctx context.Context, logger logger.ContextLogger, rawDialer N.Dialer, options option.HTTPClientOptions) (innerTransport, error) {
|
||||||
|
return nil, E.New("Apple HTTP engine is not available on non-Apple platforms")
|
||||||
|
}
|
||||||
131
common/httpclient/client.go
Normal file
131
common/httpclient/client.go
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
package httpclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/common/dialer"
|
||||||
|
"github.com/sagernet/sing-box/common/tls"
|
||||||
|
C "github.com/sagernet/sing-box/constant"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
"github.com/sagernet/sing/common/logger"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewTransport(ctx context.Context, logger logger.ContextLogger, tag string, options option.HTTPClientOptions) (*ManagedTransport, error) {
|
||||||
|
rawDialer, err := dialer.NewWithOptions(dialer.Options{
|
||||||
|
Context: ctx,
|
||||||
|
Options: options.DialerOptions,
|
||||||
|
RemoteIsDomain: true,
|
||||||
|
DirectResolver: options.DirectResolver,
|
||||||
|
ResolverOnDetour: options.ResolveOnDetour,
|
||||||
|
NewDialer: options.ResolveOnDetour,
|
||||||
|
DisableEmptyDirectCheck: options.DisableEmptyDirectCheck,
|
||||||
|
DefaultOutbound: options.DefaultOutbound,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
headers := options.Headers.Build()
|
||||||
|
host := headers.Get("Host")
|
||||||
|
headers.Del("Host")
|
||||||
|
|
||||||
|
var cheapRebuild bool
|
||||||
|
switch options.Engine {
|
||||||
|
case C.TLSEngineApple:
|
||||||
|
inner, transportErr := newAppleTransport(ctx, logger, rawDialer, options)
|
||||||
|
if transportErr != nil {
|
||||||
|
return nil, transportErr
|
||||||
|
}
|
||||||
|
managedTransport := &ManagedTransport{
|
||||||
|
dialer: rawDialer,
|
||||||
|
headers: headers,
|
||||||
|
host: host,
|
||||||
|
tag: tag,
|
||||||
|
factory: func() (innerTransport, error) {
|
||||||
|
return newAppleTransport(ctx, logger, rawDialer, options)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
managedTransport.epoch.Store(&transportEpoch{transport: inner})
|
||||||
|
return managedTransport, nil
|
||||||
|
case C.TLSEngineDefault, "go":
|
||||||
|
cheapRebuild = true
|
||||||
|
default:
|
||||||
|
return nil, E.New("unknown HTTP engine: ", options.Engine)
|
||||||
|
}
|
||||||
|
tlsOptions := common.PtrValueOrDefault(options.TLS)
|
||||||
|
tlsOptions.Enabled = true
|
||||||
|
baseTLSConfig, err := tls.NewClientWithOptions(tls.ClientOptions{
|
||||||
|
Context: ctx,
|
||||||
|
Logger: logger,
|
||||||
|
Options: tlsOptions,
|
||||||
|
AllowEmptyServerName: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
inner, err := newTransport(rawDialer, baseTLSConfig, options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
managedTransport := &ManagedTransport{
|
||||||
|
cheapRebuild: cheapRebuild,
|
||||||
|
dialer: rawDialer,
|
||||||
|
headers: headers,
|
||||||
|
host: host,
|
||||||
|
tag: tag,
|
||||||
|
factory: func() (innerTransport, error) {
|
||||||
|
return newTransport(rawDialer, baseTLSConfig, options)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
managedTransport.epoch.Store(&transportEpoch{transport: inner})
|
||||||
|
return managedTransport, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTransport(rawDialer N.Dialer, baseTLSConfig tls.Config, options option.HTTPClientOptions) (innerTransport, error) {
|
||||||
|
version := options.Version
|
||||||
|
if version == 0 {
|
||||||
|
version = 2
|
||||||
|
}
|
||||||
|
fallbackDelay := time.Duration(options.DialerOptions.FallbackDelay)
|
||||||
|
if fallbackDelay == 0 {
|
||||||
|
fallbackDelay = 300 * time.Millisecond
|
||||||
|
}
|
||||||
|
var transport innerTransport
|
||||||
|
var err error
|
||||||
|
switch version {
|
||||||
|
case 1:
|
||||||
|
transport = newHTTP1Transport(rawDialer, baseTLSConfig)
|
||||||
|
case 2:
|
||||||
|
if options.DisableVersionFallback {
|
||||||
|
transport, err = newHTTP2Transport(rawDialer, baseTLSConfig, options.HTTP2Options)
|
||||||
|
} else {
|
||||||
|
transport, err = newHTTP2FallbackTransport(rawDialer, baseTLSConfig, options.HTTP2Options)
|
||||||
|
}
|
||||||
|
case 3:
|
||||||
|
if baseTLSConfig != nil {
|
||||||
|
_, err = baseTLSConfig.STDConfig()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if options.DisableVersionFallback {
|
||||||
|
transport, err = newHTTP3Transport(rawDialer, baseTLSConfig, options.HTTP3Options)
|
||||||
|
} else {
|
||||||
|
var h2Fallback innerTransport
|
||||||
|
h2Fallback, err = newHTTP2FallbackTransport(rawDialer, baseTLSConfig, options.HTTP2Options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
transport, err = newHTTP3FallbackTransport(rawDialer, baseTLSConfig, h2Fallback, options.HTTP3Options, fallbackDelay)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, E.New("unknown HTTP version: ", version)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return transport, nil
|
||||||
|
}
|
||||||
14
common/httpclient/context.go
Normal file
14
common/httpclient/context.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
package httpclient
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
type transportKey struct{}
|
||||||
|
|
||||||
|
func contextWithTransportTag(ctx context.Context, transportTag string) context.Context {
|
||||||
|
return context.WithValue(ctx, transportKey{}, transportTag)
|
||||||
|
}
|
||||||
|
|
||||||
|
func transportTagFromContext(ctx context.Context) (string, bool) {
|
||||||
|
value, loaded := ctx.Value(transportKey{}).(string)
|
||||||
|
return value, loaded
|
||||||
|
}
|
||||||
86
common/httpclient/helpers.go
Normal file
86
common/httpclient/helpers.go
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
package httpclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
stdTLS "crypto/tls"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/common/tls"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
func dialTLS(ctx context.Context, rawDialer N.Dialer, baseTLSConfig tls.Config, destination M.Socksaddr, nextProtos []string, expectProto string) (net.Conn, error) {
|
||||||
|
if baseTLSConfig == nil {
|
||||||
|
return nil, E.New("TLS transport unavailable")
|
||||||
|
}
|
||||||
|
tlsConfig := baseTLSConfig.Clone()
|
||||||
|
if tlsConfig.ServerName() == "" && destination.IsValid() {
|
||||||
|
tlsConfig.SetServerName(destination.AddrString())
|
||||||
|
}
|
||||||
|
tlsConfig.SetNextProtos(nextProtos)
|
||||||
|
conn, err := rawDialer.DialContext(ctx, N.NetworkTCP, destination)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tlsConn, err := tls.ClientHandshake(ctx, conn, tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if expectProto != "" && tlsConn.ConnectionState().NegotiatedProtocol != expectProto {
|
||||||
|
tlsConn.Close()
|
||||||
|
return nil, errHTTP2Fallback
|
||||||
|
}
|
||||||
|
return tlsConn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyHeaders(request *http.Request, headers http.Header, host string) {
|
||||||
|
for header, values := range headers {
|
||||||
|
request.Header[header] = append([]string(nil), values...)
|
||||||
|
}
|
||||||
|
if host != "" {
|
||||||
|
request.Host = host
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestRequiresHTTP1(request *http.Request) bool {
|
||||||
|
return strings.Contains(strings.ToLower(request.Header.Get("Connection")), "upgrade") &&
|
||||||
|
strings.EqualFold(request.Header.Get("Upgrade"), "websocket")
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestReplayable(request *http.Request) bool {
|
||||||
|
return request.Body == nil || request.Body == http.NoBody || request.GetBody != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneRequestForRetry(request *http.Request) *http.Request {
|
||||||
|
cloned := request.Clone(request.Context())
|
||||||
|
if request.Body != nil && request.Body != http.NoBody && request.GetBody != nil {
|
||||||
|
cloned.Body = mustGetBody(request)
|
||||||
|
}
|
||||||
|
return cloned
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustGetBody(request *http.Request) io.ReadCloser {
|
||||||
|
body, err := request.GetBody()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildSTDTLSConfig(baseTLSConfig tls.Config, destination M.Socksaddr, nextProtos []string) (*stdTLS.Config, error) {
|
||||||
|
if baseTLSConfig == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
tlsConfig := baseTLSConfig.Clone()
|
||||||
|
if tlsConfig.ServerName() == "" && destination.IsValid() {
|
||||||
|
tlsConfig.SetServerName(destination.AddrString())
|
||||||
|
}
|
||||||
|
tlsConfig.SetNextProtos(nextProtos)
|
||||||
|
return tlsConfig.STDConfig()
|
||||||
|
}
|
||||||
42
common/httpclient/http1_transport.go
Normal file
42
common/httpclient/http1_transport.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package httpclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/common/tls"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
type http1Transport struct {
|
||||||
|
transport *http.Transport
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTP1Transport(rawDialer N.Dialer, baseTLSConfig tls.Config) *http1Transport {
|
||||||
|
transport := &http.Transport{
|
||||||
|
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
return rawDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if baseTLSConfig != nil {
|
||||||
|
transport.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
return dialTLS(ctx, rawDialer, baseTLSConfig, M.ParseSocksaddr(addr), []string{"http/1.1"}, "")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &http1Transport{transport: transport}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http1Transport) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||||
|
return t.transport.RoundTrip(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http1Transport) CloseIdleConnections() {
|
||||||
|
t.transport.CloseIdleConnections()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http1Transport) Close() error {
|
||||||
|
t.CloseIdleConnections()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
42
common/httpclient/http2_config.go
Normal file
42
common/httpclient/http2_config.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package httpclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
stdTLS "crypto/tls"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CloneHTTP2Transport(transport *http2.Transport) *http2.Transport {
|
||||||
|
return &http2.Transport{
|
||||||
|
ReadIdleTimeout: transport.ReadIdleTimeout,
|
||||||
|
PingTimeout: transport.PingTimeout,
|
||||||
|
DialTLSContext: transport.DialTLSContext,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConfigureHTTP2Transport(options option.HTTP2Options) (*http2.Transport, error) {
|
||||||
|
stdTransport := &http.Transport{
|
||||||
|
TLSClientConfig: &stdTLS.Config{},
|
||||||
|
HTTP2: &http.HTTP2Config{
|
||||||
|
MaxReceiveBufferPerStream: int(options.StreamReceiveWindow.Value()),
|
||||||
|
MaxReceiveBufferPerConnection: int(options.ConnectionReceiveWindow.Value()),
|
||||||
|
MaxConcurrentStreams: options.MaxConcurrentStreams,
|
||||||
|
SendPingTimeout: time.Duration(options.KeepAlivePeriod),
|
||||||
|
PingTimeout: time.Duration(options.IdleTimeout),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h2Transport, err := http2.ConfigureTransports(stdTransport)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "configure HTTP/2 transport")
|
||||||
|
}
|
||||||
|
// ConfigureTransports binds ConnPool to the throwaway http.Transport; sever it so DialTLSContext is used directly.
|
||||||
|
h2Transport.ConnPool = nil
|
||||||
|
h2Transport.ReadIdleTimeout = time.Duration(options.KeepAlivePeriod)
|
||||||
|
h2Transport.PingTimeout = time.Duration(options.IdleTimeout)
|
||||||
|
return h2Transport, nil
|
||||||
|
}
|
||||||
84
common/httpclient/http2_fallback_transport.go
Normal file
84
common/httpclient/http2_fallback_transport.go
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
package httpclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
stdTLS "crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/common/tls"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errHTTP2Fallback = E.New("fallback to HTTP/1.1")
|
||||||
|
|
||||||
|
type http2FallbackTransport struct {
|
||||||
|
h2Transport *http2.Transport
|
||||||
|
h1Transport *http1Transport
|
||||||
|
h2Fallback *atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTP2FallbackTransport(rawDialer N.Dialer, baseTLSConfig tls.Config, options option.HTTP2Options) (*http2FallbackTransport, error) {
|
||||||
|
h1 := newHTTP1Transport(rawDialer, baseTLSConfig)
|
||||||
|
var fallback atomic.Bool
|
||||||
|
h2Transport, err := ConfigureHTTP2Transport(options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
h2Transport.DialTLSContext = func(ctx context.Context, network, addr string, _ *stdTLS.Config) (net.Conn, error) {
|
||||||
|
conn, dialErr := dialTLS(ctx, rawDialer, baseTLSConfig, M.ParseSocksaddr(addr), []string{http2.NextProtoTLS, "http/1.1"}, http2.NextProtoTLS)
|
||||||
|
if dialErr != nil {
|
||||||
|
if errors.Is(dialErr, errHTTP2Fallback) {
|
||||||
|
fallback.Store(true)
|
||||||
|
}
|
||||||
|
return nil, dialErr
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
return &http2FallbackTransport{
|
||||||
|
h2Transport: h2Transport,
|
||||||
|
h1Transport: h1,
|
||||||
|
h2Fallback: &fallback,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http2FallbackTransport) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||||
|
return t.roundTrip(request, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http2FallbackTransport) roundTrip(request *http.Request, allowHTTP1Fallback bool) (*http.Response, error) {
|
||||||
|
if request.URL.Scheme != "https" || requestRequiresHTTP1(request) {
|
||||||
|
return t.h1Transport.RoundTrip(request)
|
||||||
|
}
|
||||||
|
if t.h2Fallback.Load() {
|
||||||
|
if !allowHTTP1Fallback {
|
||||||
|
return nil, errHTTP2Fallback
|
||||||
|
}
|
||||||
|
return t.h1Transport.RoundTrip(request)
|
||||||
|
}
|
||||||
|
response, err := t.h2Transport.RoundTrip(request)
|
||||||
|
if err == nil {
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
if !errors.Is(err, errHTTP2Fallback) || !allowHTTP1Fallback {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return t.h1Transport.RoundTrip(cloneRequestForRetry(request))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http2FallbackTransport) CloseIdleConnections() {
|
||||||
|
t.h1Transport.CloseIdleConnections()
|
||||||
|
t.h2Transport.CloseIdleConnections()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http2FallbackTransport) Close() error {
|
||||||
|
t.CloseIdleConnections()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
52
common/httpclient/http2_transport.go
Normal file
52
common/httpclient/http2_transport.go
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
package httpclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
stdTLS "crypto/tls"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/common/tls"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type http2Transport struct {
|
||||||
|
h2Transport *http2.Transport
|
||||||
|
h1Transport *http1Transport
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTP2Transport(rawDialer N.Dialer, baseTLSConfig tls.Config, options option.HTTP2Options) (*http2Transport, error) {
|
||||||
|
h1 := newHTTP1Transport(rawDialer, baseTLSConfig)
|
||||||
|
h2Transport, err := ConfigureHTTP2Transport(options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
h2Transport.DialTLSContext = func(ctx context.Context, network, addr string, _ *stdTLS.Config) (net.Conn, error) {
|
||||||
|
return dialTLS(ctx, rawDialer, baseTLSConfig, M.ParseSocksaddr(addr), []string{http2.NextProtoTLS}, http2.NextProtoTLS)
|
||||||
|
}
|
||||||
|
return &http2Transport{
|
||||||
|
h2Transport: h2Transport,
|
||||||
|
h1Transport: h1,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http2Transport) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||||
|
if request.URL.Scheme != "https" || requestRequiresHTTP1(request) {
|
||||||
|
return t.h1Transport.RoundTrip(request)
|
||||||
|
}
|
||||||
|
return t.h2Transport.RoundTrip(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http2Transport) CloseIdleConnections() {
|
||||||
|
t.h1Transport.CloseIdleConnections()
|
||||||
|
t.h2Transport.CloseIdleConnections()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http2Transport) Close() error {
|
||||||
|
t.CloseIdleConnections()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
297
common/httpclient/http3_transport.go
Normal file
297
common/httpclient/http3_transport.go
Normal file
@@ -0,0 +1,297 @@
|
|||||||
|
//go:build with_quic
|
||||||
|
|
||||||
|
package httpclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
stdTLS "crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/quic-go"
|
||||||
|
"github.com/sagernet/quic-go/http3"
|
||||||
|
"github.com/sagernet/sing-box/common/tls"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
"github.com/sagernet/sing/common/bufio"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
type http3Transport struct {
|
||||||
|
h3Transport *http3.Transport
|
||||||
|
}
|
||||||
|
|
||||||
|
type http3FallbackTransport struct {
|
||||||
|
h3Transport *http3.Transport
|
||||||
|
h2Fallback innerTransport
|
||||||
|
fallbackDelay time.Duration
|
||||||
|
brokenAccess sync.Mutex
|
||||||
|
brokenUntil time.Time
|
||||||
|
brokenBackoff time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTP3RoundTripper(
|
||||||
|
rawDialer N.Dialer,
|
||||||
|
baseTLSConfig tls.Config,
|
||||||
|
options option.QUICOptions,
|
||||||
|
) *http3.Transport {
|
||||||
|
var handshakeTimeout time.Duration
|
||||||
|
if baseTLSConfig != nil {
|
||||||
|
handshakeTimeout = baseTLSConfig.HandshakeTimeout()
|
||||||
|
}
|
||||||
|
quicConfig := &quic.Config{
|
||||||
|
InitialStreamReceiveWindow: options.StreamReceiveWindow.Value(),
|
||||||
|
MaxStreamReceiveWindow: options.StreamReceiveWindow.Value(),
|
||||||
|
InitialConnectionReceiveWindow: options.ConnectionReceiveWindow.Value(),
|
||||||
|
MaxConnectionReceiveWindow: options.ConnectionReceiveWindow.Value(),
|
||||||
|
KeepAlivePeriod: time.Duration(options.KeepAlivePeriod),
|
||||||
|
MaxIdleTimeout: time.Duration(options.IdleTimeout),
|
||||||
|
DisablePathMTUDiscovery: options.DisablePathMTUDiscovery,
|
||||||
|
}
|
||||||
|
if options.InitialPacketSize > 0 {
|
||||||
|
quicConfig.InitialPacketSize = uint16(options.InitialPacketSize)
|
||||||
|
}
|
||||||
|
if options.MaxConcurrentStreams > 0 {
|
||||||
|
quicConfig.MaxIncomingStreams = int64(options.MaxConcurrentStreams)
|
||||||
|
}
|
||||||
|
if handshakeTimeout > 0 {
|
||||||
|
quicConfig.HandshakeIdleTimeout = handshakeTimeout
|
||||||
|
}
|
||||||
|
h3Transport := &http3.Transport{
|
||||||
|
TLSClientConfig: &stdTLS.Config{},
|
||||||
|
QUICConfig: quicConfig,
|
||||||
|
Dial: func(ctx context.Context, addr string, tlsConfig *stdTLS.Config, quicConfig *quic.Config) (*quic.Conn, error) {
|
||||||
|
if handshakeTimeout > 0 && quicConfig.HandshakeIdleTimeout == 0 {
|
||||||
|
quicConfig = quicConfig.Clone()
|
||||||
|
quicConfig.HandshakeIdleTimeout = handshakeTimeout
|
||||||
|
}
|
||||||
|
if baseTLSConfig != nil {
|
||||||
|
var err error
|
||||||
|
tlsConfig, err = buildSTDTLSConfig(baseTLSConfig, M.ParseSocksaddr(addr), []string{http3.NextProtoH3})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
tlsConfig = tlsConfig.Clone()
|
||||||
|
tlsConfig.NextProtos = []string{http3.NextProtoH3}
|
||||||
|
}
|
||||||
|
conn, err := rawDialer.DialContext(ctx, N.NetworkUDP, M.ParseSocksaddr(addr))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
quicConn, err := quic.DialEarly(ctx, bufio.NewUnbindPacketConn(conn), conn.RemoteAddr(), tlsConfig, quicConfig)
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return quicConn, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return h3Transport
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTP3Transport(
|
||||||
|
rawDialer N.Dialer,
|
||||||
|
baseTLSConfig tls.Config,
|
||||||
|
options option.QUICOptions,
|
||||||
|
) (innerTransport, error) {
|
||||||
|
return &http3Transport{
|
||||||
|
h3Transport: newHTTP3RoundTripper(rawDialer, baseTLSConfig, options),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTP3FallbackTransport(
|
||||||
|
rawDialer N.Dialer,
|
||||||
|
baseTLSConfig tls.Config,
|
||||||
|
h2Fallback innerTransport,
|
||||||
|
options option.QUICOptions,
|
||||||
|
fallbackDelay time.Duration,
|
||||||
|
) (innerTransport, error) {
|
||||||
|
return &http3FallbackTransport{
|
||||||
|
h3Transport: newHTTP3RoundTripper(rawDialer, baseTLSConfig, options),
|
||||||
|
h2Fallback: h2Fallback,
|
||||||
|
fallbackDelay: fallbackDelay,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http3Transport) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||||
|
return t.h3Transport.RoundTrip(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http3Transport) CloseIdleConnections() {
|
||||||
|
t.h3Transport.CloseIdleConnections()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http3Transport) Close() error {
|
||||||
|
t.CloseIdleConnections()
|
||||||
|
return t.h3Transport.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http3FallbackTransport) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||||
|
if request.URL.Scheme != "https" || requestRequiresHTTP1(request) {
|
||||||
|
return t.h2Fallback.RoundTrip(request)
|
||||||
|
}
|
||||||
|
return t.roundTripHTTP3(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http3FallbackTransport) roundTripHTTP3(request *http.Request) (*http.Response, error) {
|
||||||
|
if t.h3Broken() {
|
||||||
|
return t.h2FallbackRoundTrip(request)
|
||||||
|
}
|
||||||
|
response, err := t.h3Transport.RoundTripOpt(request, http3.RoundTripOpt{OnlyCachedConn: true})
|
||||||
|
if err == nil {
|
||||||
|
t.clearH3Broken()
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
if !errors.Is(err, http3.ErrNoCachedConn) {
|
||||||
|
t.markH3Broken()
|
||||||
|
return t.h2FallbackRoundTrip(cloneRequestForRetry(request))
|
||||||
|
}
|
||||||
|
if !requestReplayable(request) {
|
||||||
|
response, err = t.h3Transport.RoundTrip(request)
|
||||||
|
if err == nil {
|
||||||
|
t.clearH3Broken()
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
t.markH3Broken()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return t.roundTripHTTP3Race(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http3FallbackTransport) roundTripHTTP3Race(request *http.Request) (*http.Response, error) {
|
||||||
|
ctx, cancel := context.WithCancel(request.Context())
|
||||||
|
defer cancel()
|
||||||
|
type result struct {
|
||||||
|
response *http.Response
|
||||||
|
err error
|
||||||
|
h3 bool
|
||||||
|
}
|
||||||
|
results := make(chan result, 2)
|
||||||
|
startRoundTrip := func(request *http.Request, useH3 bool) {
|
||||||
|
request = request.WithContext(ctx)
|
||||||
|
var (
|
||||||
|
response *http.Response
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if useH3 {
|
||||||
|
response, err = t.h3Transport.RoundTrip(request)
|
||||||
|
} else {
|
||||||
|
response, err = t.h2FallbackRoundTrip(request)
|
||||||
|
}
|
||||||
|
results <- result{response: response, err: err, h3: useH3}
|
||||||
|
}
|
||||||
|
goroutines := 1
|
||||||
|
received := 0
|
||||||
|
drainRemaining := func() {
|
||||||
|
cancel()
|
||||||
|
for range goroutines - received {
|
||||||
|
go func() {
|
||||||
|
loser := <-results
|
||||||
|
if loser.response != nil && loser.response.Body != nil {
|
||||||
|
loser.response.Body.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
go startRoundTrip(cloneRequestForRetry(request), true)
|
||||||
|
timer := time.NewTimer(t.fallbackDelay)
|
||||||
|
defer timer.Stop()
|
||||||
|
var (
|
||||||
|
h3Err error
|
||||||
|
fallbackErr error
|
||||||
|
)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
if goroutines == 1 {
|
||||||
|
goroutines++
|
||||||
|
go startRoundTrip(cloneRequestForRetry(request), false)
|
||||||
|
}
|
||||||
|
case raceResult := <-results:
|
||||||
|
received++
|
||||||
|
if raceResult.err == nil {
|
||||||
|
if raceResult.h3 {
|
||||||
|
t.clearH3Broken()
|
||||||
|
}
|
||||||
|
drainRemaining()
|
||||||
|
return raceResult.response, nil
|
||||||
|
}
|
||||||
|
if raceResult.h3 {
|
||||||
|
t.markH3Broken()
|
||||||
|
h3Err = raceResult.err
|
||||||
|
if goroutines == 1 {
|
||||||
|
goroutines++
|
||||||
|
if !timer.Stop() {
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
go startRoundTrip(cloneRequestForRetry(request), false)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
fallbackErr = raceResult.err
|
||||||
|
}
|
||||||
|
if received < goroutines {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
drainRemaining()
|
||||||
|
switch {
|
||||||
|
case h3Err != nil && fallbackErr != nil:
|
||||||
|
return nil, E.Errors(h3Err, fallbackErr)
|
||||||
|
case fallbackErr != nil:
|
||||||
|
return nil, fallbackErr
|
||||||
|
default:
|
||||||
|
return nil, h3Err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http3FallbackTransport) h2FallbackRoundTrip(request *http.Request) (*http.Response, error) {
|
||||||
|
if fallback, isFallback := t.h2Fallback.(*http2FallbackTransport); isFallback {
|
||||||
|
return fallback.roundTrip(request, true)
|
||||||
|
}
|
||||||
|
return t.h2Fallback.RoundTrip(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http3FallbackTransport) CloseIdleConnections() {
|
||||||
|
t.h3Transport.CloseIdleConnections()
|
||||||
|
t.h2Fallback.CloseIdleConnections()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http3FallbackTransport) Close() error {
|
||||||
|
t.CloseIdleConnections()
|
||||||
|
return t.h3Transport.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http3FallbackTransport) h3Broken() bool {
|
||||||
|
t.brokenAccess.Lock()
|
||||||
|
defer t.brokenAccess.Unlock()
|
||||||
|
return !t.brokenUntil.IsZero() && time.Now().Before(t.brokenUntil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http3FallbackTransport) clearH3Broken() {
|
||||||
|
t.brokenAccess.Lock()
|
||||||
|
t.brokenUntil = time.Time{}
|
||||||
|
t.brokenBackoff = 0
|
||||||
|
t.brokenAccess.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http3FallbackTransport) markH3Broken() {
|
||||||
|
t.brokenAccess.Lock()
|
||||||
|
defer t.brokenAccess.Unlock()
|
||||||
|
if t.brokenBackoff == 0 {
|
||||||
|
t.brokenBackoff = 5 * time.Minute
|
||||||
|
} else {
|
||||||
|
t.brokenBackoff *= 2
|
||||||
|
if t.brokenBackoff > 48*time.Hour {
|
||||||
|
t.brokenBackoff = 48 * time.Hour
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.brokenUntil = time.Now().Add(t.brokenBackoff)
|
||||||
|
}
|
||||||
30
common/httpclient/http3_transport_stub.go
Normal file
30
common/httpclient/http3_transport_stub.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
//go:build !with_quic
|
||||||
|
|
||||||
|
package httpclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/common/tls"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newHTTP3FallbackTransport(
|
||||||
|
rawDialer N.Dialer,
|
||||||
|
baseTLSConfig tls.Config,
|
||||||
|
h2Fallback innerTransport,
|
||||||
|
options option.QUICOptions,
|
||||||
|
fallbackDelay time.Duration,
|
||||||
|
) (innerTransport, error) {
|
||||||
|
return nil, E.New("HTTP/3 requires building with the with_quic tag")
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTP3Transport(
|
||||||
|
rawDialer N.Dialer,
|
||||||
|
baseTLSConfig tls.Config,
|
||||||
|
options option.QUICOptions,
|
||||||
|
) (innerTransport, error) {
|
||||||
|
return nil, E.New("HTTP/3 requires building with the with_quic tag")
|
||||||
|
}
|
||||||
209
common/httpclient/managed_transport.go
Normal file
209
common/httpclient/managed_transport.go
Normal file
@@ -0,0 +1,209 @@
|
|||||||
|
package httpclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
type innerTransport interface {
|
||||||
|
http.RoundTripper
|
||||||
|
CloseIdleConnections()
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ adapter.HTTPTransport = (*ManagedTransport)(nil)
|
||||||
|
|
||||||
|
type ManagedTransport struct {
|
||||||
|
epoch atomic.Pointer[transportEpoch]
|
||||||
|
rebuildAccess sync.Mutex
|
||||||
|
factory func() (innerTransport, error)
|
||||||
|
cheapRebuild bool
|
||||||
|
|
||||||
|
dialer N.Dialer
|
||||||
|
headers http.Header
|
||||||
|
host string
|
||||||
|
tag string
|
||||||
|
}
|
||||||
|
|
||||||
|
type transportEpoch struct {
|
||||||
|
transport innerTransport
|
||||||
|
active atomic.Int64
|
||||||
|
marked atomic.Bool
|
||||||
|
closeOnce sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
type managedResponseBody struct {
|
||||||
|
body io.ReadCloser
|
||||||
|
release func()
|
||||||
|
once sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *transportEpoch) tryClose() {
|
||||||
|
e.closeOnce.Do(func() {
|
||||||
|
e.transport.Close()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *managedResponseBody) Read(p []byte) (int, error) {
|
||||||
|
return b.body.Read(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *managedResponseBody) Close() error {
|
||||||
|
err := b.body.Close()
|
||||||
|
b.once.Do(b.release)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ManagedTransport) getEpoch() (*transportEpoch, error) {
|
||||||
|
epoch := t.epoch.Load()
|
||||||
|
if epoch != nil {
|
||||||
|
return epoch, nil
|
||||||
|
}
|
||||||
|
t.rebuildAccess.Lock()
|
||||||
|
defer t.rebuildAccess.Unlock()
|
||||||
|
epoch = t.epoch.Load()
|
||||||
|
if epoch != nil {
|
||||||
|
return epoch, nil
|
||||||
|
}
|
||||||
|
inner, err := t.factory()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
epoch = &transportEpoch{transport: inner}
|
||||||
|
t.epoch.Store(epoch)
|
||||||
|
return epoch, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ManagedTransport) acquireEpoch() (*transportEpoch, error) {
|
||||||
|
for {
|
||||||
|
epoch, err := t.getEpoch()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
epoch.active.Add(1)
|
||||||
|
if epoch == t.epoch.Load() {
|
||||||
|
return epoch, nil
|
||||||
|
}
|
||||||
|
t.releaseEpoch(epoch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ManagedTransport) releaseEpoch(epoch *transportEpoch) {
|
||||||
|
if epoch.active.Add(-1) == 0 && epoch.marked.Load() {
|
||||||
|
epoch.tryClose()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ManagedTransport) retireEpoch(epoch *transportEpoch) {
|
||||||
|
if epoch == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
epoch.marked.Store(true)
|
||||||
|
if epoch.active.Load() == 0 {
|
||||||
|
epoch.tryClose()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ManagedTransport) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||||
|
epoch, err := t.acquireEpoch()
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "rebuild http transport")
|
||||||
|
}
|
||||||
|
if t.tag != "" {
|
||||||
|
if transportTag, loaded := transportTagFromContext(request.Context()); loaded && transportTag == t.tag {
|
||||||
|
t.releaseEpoch(epoch)
|
||||||
|
return nil, E.New("HTTP request loopback in transport[", t.tag, "]")
|
||||||
|
}
|
||||||
|
request = request.Clone(contextWithTransportTag(request.Context(), t.tag))
|
||||||
|
} else if len(t.headers) > 0 || t.host != "" {
|
||||||
|
request = request.Clone(request.Context())
|
||||||
|
}
|
||||||
|
applyHeaders(request, t.headers, t.host)
|
||||||
|
response, roundTripErr := epoch.transport.RoundTrip(request)
|
||||||
|
if roundTripErr != nil || response == nil || response.Body == nil {
|
||||||
|
t.releaseEpoch(epoch)
|
||||||
|
return response, roundTripErr
|
||||||
|
}
|
||||||
|
response.Body = &managedResponseBody{
|
||||||
|
body: response.Body,
|
||||||
|
release: func() { t.releaseEpoch(epoch) },
|
||||||
|
}
|
||||||
|
return response, roundTripErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ManagedTransport) CloseIdleConnections() {
|
||||||
|
oldEpoch := t.epoch.Swap(nil)
|
||||||
|
if oldEpoch == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
oldEpoch.transport.CloseIdleConnections()
|
||||||
|
t.retireEpoch(oldEpoch)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ManagedTransport) Reset() {
|
||||||
|
oldEpoch := t.epoch.Swap(nil)
|
||||||
|
if t.cheapRebuild {
|
||||||
|
t.rebuildAccess.Lock()
|
||||||
|
if t.epoch.Load() == nil {
|
||||||
|
inner, err := t.factory()
|
||||||
|
if err == nil {
|
||||||
|
t.epoch.Store(&transportEpoch{transport: inner})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.rebuildAccess.Unlock()
|
||||||
|
}
|
||||||
|
t.retireEpoch(oldEpoch)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ManagedTransport) close() error {
|
||||||
|
epoch := t.epoch.Swap(nil)
|
||||||
|
if epoch != nil {
|
||||||
|
return epoch.transport.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ adapter.HTTPTransport = (*sharedRef)(nil)
|
||||||
|
|
||||||
|
type sharedRef struct {
|
||||||
|
managed *ManagedTransport
|
||||||
|
shared *sharedState
|
||||||
|
idle atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type sharedState struct {
|
||||||
|
activeRefs atomic.Int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSharedRef(managed *ManagedTransport, shared *sharedState) *sharedRef {
|
||||||
|
shared.activeRefs.Add(1)
|
||||||
|
return &sharedRef{
|
||||||
|
managed: managed,
|
||||||
|
shared: shared,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *sharedRef) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||||
|
if r.idle.CompareAndSwap(true, false) {
|
||||||
|
r.shared.activeRefs.Add(1)
|
||||||
|
}
|
||||||
|
return r.managed.RoundTrip(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *sharedRef) CloseIdleConnections() {
|
||||||
|
if r.idle.CompareAndSwap(false, true) {
|
||||||
|
if r.shared.activeRefs.Add(-1) == 0 {
|
||||||
|
r.managed.CloseIdleConnections()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *sharedRef) Reset() {
|
||||||
|
r.managed.Reset()
|
||||||
|
}
|
||||||
175
common/httpclient/manager.go
Normal file
175
common/httpclient/manager.go
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
package httpclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/log"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
"github.com/sagernet/sing/common/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ adapter.HTTPClientManager = (*Manager)(nil)
|
||||||
|
_ adapter.LifecycleService = (*Manager)(nil)
|
||||||
|
)
|
||||||
|
|
||||||
|
type Manager struct {
|
||||||
|
ctx context.Context
|
||||||
|
logger log.ContextLogger
|
||||||
|
access sync.Mutex
|
||||||
|
defines map[string]option.HTTPClient
|
||||||
|
sharedTransports map[string]*sharedManagedTransport
|
||||||
|
managedTransports []*ManagedTransport
|
||||||
|
defaultTag string
|
||||||
|
defaultTransport *sharedManagedTransport
|
||||||
|
defaultTransportFallback func() (*ManagedTransport, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type sharedManagedTransport struct {
|
||||||
|
managed *ManagedTransport
|
||||||
|
shared *sharedState
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager(ctx context.Context, logger log.ContextLogger, clients []option.HTTPClient, defaultHTTPClient string) *Manager {
|
||||||
|
defines := make(map[string]option.HTTPClient, len(clients))
|
||||||
|
for _, client := range clients {
|
||||||
|
defines[client.Tag] = client
|
||||||
|
}
|
||||||
|
defaultTag := defaultHTTPClient
|
||||||
|
if defaultTag == "" && len(clients) > 0 {
|
||||||
|
defaultTag = clients[0].Tag
|
||||||
|
}
|
||||||
|
return &Manager{
|
||||||
|
ctx: ctx,
|
||||||
|
logger: logger,
|
||||||
|
defines: defines,
|
||||||
|
sharedTransports: make(map[string]*sharedManagedTransport),
|
||||||
|
defaultTag: defaultTag,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Initialize(defaultTransportFallback func() (*ManagedTransport, error)) {
|
||||||
|
m.defaultTransportFallback = defaultTransportFallback
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Name() string {
|
||||||
|
return "http-client"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Start(stage adapter.StartStage) error {
|
||||||
|
if stage != adapter.StartStateStart {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if m.defaultTag != "" {
|
||||||
|
sharedTransport, err := m.resolveShared(m.defaultTag)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "resolve default http client")
|
||||||
|
}
|
||||||
|
m.defaultTransport = sharedTransport
|
||||||
|
} else if m.defaultTransportFallback != nil {
|
||||||
|
transport, err := m.defaultTransportFallback()
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "create default http client")
|
||||||
|
}
|
||||||
|
m.trackTransport(transport)
|
||||||
|
m.defaultTransport = &sharedManagedTransport{
|
||||||
|
managed: transport,
|
||||||
|
shared: &sharedState{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) DefaultTransport() adapter.HTTPTransport {
|
||||||
|
if m.defaultTransport == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return newSharedRef(m.defaultTransport.managed, m.defaultTransport.shared)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) ResolveTransport(ctx context.Context, logger logger.ContextLogger, options option.HTTPClientOptions) (adapter.HTTPTransport, error) {
|
||||||
|
if options.Tag != "" {
|
||||||
|
if options.ResolveOnDetour {
|
||||||
|
define, loaded := m.defines[options.Tag]
|
||||||
|
if !loaded {
|
||||||
|
return nil, E.New("http_client not found: ", options.Tag)
|
||||||
|
}
|
||||||
|
resolvedOptions := define.Options()
|
||||||
|
resolvedOptions.ResolveOnDetour = true
|
||||||
|
transport, err := NewTransport(ctx, logger, options.Tag, resolvedOptions)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m.trackTransport(transport)
|
||||||
|
return transport, nil
|
||||||
|
}
|
||||||
|
sharedTransport, err := m.resolveShared(options.Tag)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return newSharedRef(sharedTransport.managed, sharedTransport.shared), nil
|
||||||
|
}
|
||||||
|
transport, err := NewTransport(ctx, logger, "", options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m.trackTransport(transport)
|
||||||
|
return transport, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) trackTransport(transport *ManagedTransport) {
|
||||||
|
m.access.Lock()
|
||||||
|
defer m.access.Unlock()
|
||||||
|
m.managedTransports = append(m.managedTransports, transport)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) resolveShared(tag string) (*sharedManagedTransport, error) {
|
||||||
|
m.access.Lock()
|
||||||
|
defer m.access.Unlock()
|
||||||
|
if sharedTransport, loaded := m.sharedTransports[tag]; loaded {
|
||||||
|
return sharedTransport, nil
|
||||||
|
}
|
||||||
|
define, loaded := m.defines[tag]
|
||||||
|
if !loaded {
|
||||||
|
return nil, E.New("http_client not found: ", tag)
|
||||||
|
}
|
||||||
|
transport, err := NewTransport(m.ctx, m.logger, tag, define.Options())
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "create shared http_client[", tag, "]")
|
||||||
|
}
|
||||||
|
sharedTransport := &sharedManagedTransport{
|
||||||
|
managed: transport,
|
||||||
|
shared: &sharedState{},
|
||||||
|
}
|
||||||
|
m.sharedTransports[tag] = sharedTransport
|
||||||
|
m.managedTransports = append(m.managedTransports, transport)
|
||||||
|
return sharedTransport, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) ResetNetwork() {
|
||||||
|
m.access.Lock()
|
||||||
|
defer m.access.Unlock()
|
||||||
|
for _, transport := range m.managedTransports {
|
||||||
|
transport.Reset()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Close() error {
|
||||||
|
m.access.Lock()
|
||||||
|
defer m.access.Unlock()
|
||||||
|
if m.managedTransports == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
for _, transport := range m.managedTransports {
|
||||||
|
err = E.Append(err, transport.close(), func(err error) error {
|
||||||
|
return E.Cause(err, "close http client")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
m.managedTransports = nil
|
||||||
|
m.sharedTransports = nil
|
||||||
|
return err
|
||||||
|
}
|
||||||
115
common/proxybridge/bridge.go
Normal file
115
common/proxybridge/bridge.go
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
package proxybridge
|
||||||
|
|
||||||
|
import (
|
||||||
|
std_bufio "bufio"
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/log"
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
"github.com/sagernet/sing/common/auth"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
"github.com/sagernet/sing/common/logger"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
"github.com/sagernet/sing/protocol/socks"
|
||||||
|
"github.com/sagernet/sing/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Bridge struct {
|
||||||
|
ctx context.Context
|
||||||
|
logger logger.ContextLogger
|
||||||
|
tag string
|
||||||
|
dialer N.Dialer
|
||||||
|
connection adapter.ConnectionManager
|
||||||
|
tcpListener *net.TCPListener
|
||||||
|
username string
|
||||||
|
password string
|
||||||
|
authenticator *auth.Authenticator
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(ctx context.Context, logger logger.ContextLogger, tag string, dialer N.Dialer) (*Bridge, error) {
|
||||||
|
username := randomHex(16)
|
||||||
|
password := randomHex(16)
|
||||||
|
tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1)})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
bridge := &Bridge{
|
||||||
|
ctx: ctx,
|
||||||
|
logger: logger,
|
||||||
|
tag: tag,
|
||||||
|
dialer: dialer,
|
||||||
|
connection: service.FromContext[adapter.ConnectionManager](ctx),
|
||||||
|
tcpListener: tcpListener,
|
||||||
|
username: username,
|
||||||
|
password: password,
|
||||||
|
authenticator: auth.NewAuthenticator([]auth.User{{Username: username, Password: password}}),
|
||||||
|
}
|
||||||
|
go bridge.acceptLoop()
|
||||||
|
return bridge, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func randomHex(size int) string {
|
||||||
|
raw := make([]byte, size)
|
||||||
|
rand.Read(raw)
|
||||||
|
return hex.EncodeToString(raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bridge) Port() uint16 {
|
||||||
|
return M.SocksaddrFromNet(b.tcpListener.Addr()).Port
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bridge) Username() string {
|
||||||
|
return b.username
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bridge) Password() string {
|
||||||
|
return b.password
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bridge) Close() error {
|
||||||
|
return common.Close(b.tcpListener)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bridge) acceptLoop() {
|
||||||
|
for {
|
||||||
|
tcpConn, err := b.tcpListener.AcceptTCP()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctx := log.ContextWithNewID(b.ctx)
|
||||||
|
go func() {
|
||||||
|
hErr := socks.HandleConnectionEx(ctx, tcpConn, std_bufio.NewReader(tcpConn), b.authenticator, b, nil, 0, M.SocksaddrFromNet(tcpConn.RemoteAddr()), nil)
|
||||||
|
if hErr == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if E.IsClosedOrCanceled(hErr) {
|
||||||
|
b.logger.DebugContext(ctx, E.Cause(hErr, b.tag, " connection closed"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.logger.ErrorContext(ctx, E.Cause(hErr, b.tag))
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bridge) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
|
||||||
|
var metadata adapter.InboundContext
|
||||||
|
metadata.Source = source
|
||||||
|
metadata.Destination = destination
|
||||||
|
metadata.Network = N.NetworkTCP
|
||||||
|
b.logger.InfoContext(ctx, b.tag, " connection to ", metadata.Destination)
|
||||||
|
b.connection.NewConnection(ctx, b.dialer, conn, metadata, onClose)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bridge) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
|
||||||
|
var metadata adapter.InboundContext
|
||||||
|
metadata.Source = source
|
||||||
|
metadata.Destination = destination
|
||||||
|
metadata.Network = N.NetworkUDP
|
||||||
|
b.logger.InfoContext(ctx, b.tag, " packet connection to ", metadata.Destination)
|
||||||
|
b.connection.NewPacketConnection(ctx, b.dialer, conn, metadata, onClose)
|
||||||
|
}
|
||||||
@@ -9,6 +9,8 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing/common/bufio"
|
||||||
|
"github.com/sagernet/sing/common/bufio/deadline"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
@@ -431,6 +433,9 @@ func Run(options Options) (*Result, error) {
|
|||||||
defer func() {
|
defer func() {
|
||||||
_ = packetConn.Close()
|
_ = packetConn.Close()
|
||||||
}()
|
}()
|
||||||
|
if deadline.NeedAdditionalReadDeadline(packetConn) {
|
||||||
|
packetConn = deadline.NewPacketConn(bufio.NewPacketConn(packetConn))
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
|||||||
221
common/tls/apple_client.go
Normal file
221
common/tls/apple_client.go
Normal file
@@ -0,0 +1,221 @@
|
|||||||
|
//go:build darwin && cgo
|
||||||
|
|
||||||
|
package tls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
boxConstant "github.com/sagernet/sing-box/constant"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
"github.com/sagernet/sing/common/logger"
|
||||||
|
"github.com/sagernet/sing/common/ntp"
|
||||||
|
"github.com/sagernet/sing/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
type appleCertificateStore interface {
|
||||||
|
StoreKind() string
|
||||||
|
CurrentPEM() []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type appleClientConfig struct {
|
||||||
|
serverName string
|
||||||
|
nextProtos []string
|
||||||
|
handshakeTimeout time.Duration
|
||||||
|
minVersion uint16
|
||||||
|
maxVersion uint16
|
||||||
|
insecure bool
|
||||||
|
anchorPEM string
|
||||||
|
anchorOnly bool
|
||||||
|
certificatePublicKeySHA256 [][]byte
|
||||||
|
timeFunc func() time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleClientConfig) ServerName() string {
|
||||||
|
return c.serverName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleClientConfig) SetServerName(serverName string) {
|
||||||
|
c.serverName = serverName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleClientConfig) NextProtos() []string {
|
||||||
|
return c.nextProtos
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleClientConfig) SetNextProtos(nextProto []string) {
|
||||||
|
c.nextProtos = append(c.nextProtos[:0], nextProto...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleClientConfig) HandshakeTimeout() time.Duration {
|
||||||
|
return c.handshakeTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleClientConfig) SetHandshakeTimeout(timeout time.Duration) {
|
||||||
|
c.handshakeTimeout = timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleClientConfig) STDConfig() (*STDConfig, error) {
|
||||||
|
return nil, E.New("unsupported usage for Apple TLS engine")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleClientConfig) Client(conn net.Conn) (Conn, error) {
|
||||||
|
return nil, os.ErrInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleClientConfig) Clone() Config {
|
||||||
|
return &appleClientConfig{
|
||||||
|
serverName: c.serverName,
|
||||||
|
nextProtos: append([]string(nil), c.nextProtos...),
|
||||||
|
handshakeTimeout: c.handshakeTimeout,
|
||||||
|
minVersion: c.minVersion,
|
||||||
|
maxVersion: c.maxVersion,
|
||||||
|
insecure: c.insecure,
|
||||||
|
anchorPEM: c.anchorPEM,
|
||||||
|
anchorOnly: c.anchorOnly,
|
||||||
|
certificatePublicKeySHA256: append([][]byte(nil), c.certificatePublicKeySHA256...),
|
||||||
|
timeFunc: c.timeFunc,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAppleClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions, allowEmptyServerName bool) (Config, error) {
|
||||||
|
validated, err := ValidateAppleTLSOptions(ctx, options, "Apple TLS engine")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var serverName string
|
||||||
|
if options.ServerName != "" {
|
||||||
|
serverName = options.ServerName
|
||||||
|
} else if serverAddress != "" {
|
||||||
|
serverName = serverAddress
|
||||||
|
}
|
||||||
|
if serverName == "" && !options.Insecure && !allowEmptyServerName {
|
||||||
|
return nil, errMissingServerName
|
||||||
|
}
|
||||||
|
|
||||||
|
var handshakeTimeout time.Duration
|
||||||
|
if options.HandshakeTimeout > 0 {
|
||||||
|
handshakeTimeout = options.HandshakeTimeout.Build()
|
||||||
|
} else {
|
||||||
|
handshakeTimeout = boxConstant.TCPTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
return &appleClientConfig{
|
||||||
|
serverName: serverName,
|
||||||
|
nextProtos: append([]string(nil), options.ALPN...),
|
||||||
|
handshakeTimeout: handshakeTimeout,
|
||||||
|
minVersion: validated.MinVersion,
|
||||||
|
maxVersion: validated.MaxVersion,
|
||||||
|
insecure: options.Insecure || len(options.CertificatePublicKeySHA256) > 0,
|
||||||
|
anchorPEM: validated.AnchorPEM,
|
||||||
|
anchorOnly: validated.AnchorOnly,
|
||||||
|
certificatePublicKeySHA256: append([][]byte(nil), options.CertificatePublicKeySHA256...),
|
||||||
|
timeFunc: ntp.TimeFuncFromContext(ctx),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type AppleTLSValidated struct {
|
||||||
|
MinVersion uint16
|
||||||
|
MaxVersion uint16
|
||||||
|
AnchorPEM string
|
||||||
|
AnchorOnly bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func ValidateAppleTLSOptions(ctx context.Context, options option.OutboundTLSOptions, engineName string) (AppleTLSValidated, error) {
|
||||||
|
if options.Reality != nil && options.Reality.Enabled {
|
||||||
|
return AppleTLSValidated{}, E.New("reality is unsupported in ", engineName)
|
||||||
|
}
|
||||||
|
if options.UTLS != nil && options.UTLS.Enabled {
|
||||||
|
return AppleTLSValidated{}, E.New("utls is unsupported in ", engineName)
|
||||||
|
}
|
||||||
|
if options.ECH != nil && options.ECH.Enabled {
|
||||||
|
return AppleTLSValidated{}, E.New("ech is unsupported in ", engineName)
|
||||||
|
}
|
||||||
|
if options.DisableSNI {
|
||||||
|
return AppleTLSValidated{}, E.New("disable_sni is unsupported in ", engineName)
|
||||||
|
}
|
||||||
|
if len(options.CipherSuites) > 0 {
|
||||||
|
return AppleTLSValidated{}, E.New("cipher_suites is unsupported in ", engineName)
|
||||||
|
}
|
||||||
|
if len(options.CurvePreferences) > 0 {
|
||||||
|
return AppleTLSValidated{}, E.New("curve_preferences is unsupported in ", engineName)
|
||||||
|
}
|
||||||
|
if len(options.ClientCertificate) > 0 || options.ClientCertificatePath != "" || len(options.ClientKey) > 0 || options.ClientKeyPath != "" {
|
||||||
|
return AppleTLSValidated{}, E.New("client certificate is unsupported in ", engineName)
|
||||||
|
}
|
||||||
|
if options.Fragment || options.RecordFragment {
|
||||||
|
return AppleTLSValidated{}, E.New("tls fragment is unsupported in ", engineName)
|
||||||
|
}
|
||||||
|
if options.KernelTx || options.KernelRx {
|
||||||
|
return AppleTLSValidated{}, E.New("ktls is unsupported in ", engineName)
|
||||||
|
}
|
||||||
|
if options.Spoof != "" || options.SpoofMethod != "" {
|
||||||
|
return AppleTLSValidated{}, E.New("spoof is unsupported in ", engineName)
|
||||||
|
}
|
||||||
|
if len(options.CertificatePublicKeySHA256) > 0 && (len(options.Certificate) > 0 || options.CertificatePath != "") {
|
||||||
|
return AppleTLSValidated{}, E.New("certificate_public_key_sha256 is conflict with certificate or certificate_path")
|
||||||
|
}
|
||||||
|
var minVersion uint16
|
||||||
|
if options.MinVersion != "" {
|
||||||
|
var err error
|
||||||
|
minVersion, err = ParseTLSVersion(options.MinVersion)
|
||||||
|
if err != nil {
|
||||||
|
return AppleTLSValidated{}, E.Cause(err, "parse min_version")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var maxVersion uint16
|
||||||
|
if options.MaxVersion != "" {
|
||||||
|
var err error
|
||||||
|
maxVersion, err = ParseTLSVersion(options.MaxVersion)
|
||||||
|
if err != nil {
|
||||||
|
return AppleTLSValidated{}, E.Cause(err, "parse max_version")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
anchorPEM, anchorOnly, err := AppleAnchorPEM(ctx, options)
|
||||||
|
if err != nil {
|
||||||
|
return AppleTLSValidated{}, err
|
||||||
|
}
|
||||||
|
return AppleTLSValidated{
|
||||||
|
MinVersion: minVersion,
|
||||||
|
MaxVersion: maxVersion,
|
||||||
|
AnchorPEM: anchorPEM,
|
||||||
|
AnchorOnly: anchorOnly,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func AppleAnchorPEM(ctx context.Context, options option.OutboundTLSOptions) (string, bool, error) {
|
||||||
|
if len(options.Certificate) > 0 {
|
||||||
|
return strings.Join(options.Certificate, "\n"), true, nil
|
||||||
|
}
|
||||||
|
if options.CertificatePath != "" {
|
||||||
|
content, err := os.ReadFile(options.CertificatePath)
|
||||||
|
if err != nil {
|
||||||
|
return "", false, E.Cause(err, "read certificate")
|
||||||
|
}
|
||||||
|
return string(content), true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
certificateStore := service.FromContext[adapter.CertificateStore](ctx)
|
||||||
|
if certificateStore == nil {
|
||||||
|
return "", false, nil
|
||||||
|
}
|
||||||
|
store, ok := certificateStore.(appleCertificateStore)
|
||||||
|
if !ok {
|
||||||
|
return "", false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch store.StoreKind() {
|
||||||
|
case boxConstant.CertificateStoreSystem, "":
|
||||||
|
return strings.Join(store.CurrentPEM(), "\n"), false, nil
|
||||||
|
case boxConstant.CertificateStoreMozilla, boxConstant.CertificateStoreChrome, boxConstant.CertificateStoreNone:
|
||||||
|
return strings.Join(store.CurrentPEM(), "\n"), true, nil
|
||||||
|
default:
|
||||||
|
return "", false, E.New("unsupported certificate store for Apple TLS engine: ", store.StoreKind())
|
||||||
|
}
|
||||||
|
}
|
||||||
517
common/tls/apple_client_platform.go
Normal file
517
common/tls/apple_client_platform.go
Normal file
@@ -0,0 +1,517 @@
|
|||||||
|
//go:build darwin && cgo
|
||||||
|
|
||||||
|
package tls
|
||||||
|
|
||||||
|
/*
|
||||||
|
#cgo CFLAGS: -x objective-c -fobjc-arc
|
||||||
|
#cgo LDFLAGS: -framework Foundation -framework Network -framework Security
|
||||||
|
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include "apple_client_platform_darwin.h"
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *appleClientConfig) ClientHandshake(ctx context.Context, conn net.Conn) (Conn, error) {
|
||||||
|
rawSyscallConn, ok := common.Cast[syscall.Conn](conn)
|
||||||
|
if !ok {
|
||||||
|
return nil, E.New("apple TLS: requires fd-backed TCP connection")
|
||||||
|
}
|
||||||
|
syscallConn, err := rawSyscallConn.SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "access raw connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
var dupFD int
|
||||||
|
controlErr := syscallConn.Control(func(fd uintptr) {
|
||||||
|
dupFD, err = unix.Dup(int(fd))
|
||||||
|
})
|
||||||
|
if controlErr != nil {
|
||||||
|
return nil, E.Cause(controlErr, "access raw connection")
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "duplicate raw connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
serverName := c.serverName
|
||||||
|
serverNamePtr := cStringOrNil(serverName)
|
||||||
|
defer cFree(serverNamePtr)
|
||||||
|
|
||||||
|
alpn := strings.Join(c.nextProtos, "\n")
|
||||||
|
alpnPtr := cStringOrNil(alpn)
|
||||||
|
defer cFree(alpnPtr)
|
||||||
|
|
||||||
|
anchorPEMPtr := cStringOrNil(c.anchorPEM)
|
||||||
|
defer cFree(anchorPEMPtr)
|
||||||
|
|
||||||
|
var (
|
||||||
|
hasVerifyTime bool
|
||||||
|
verifyTimeUnixMilli int64
|
||||||
|
)
|
||||||
|
if c.timeFunc != nil {
|
||||||
|
hasVerifyTime = true
|
||||||
|
verifyTimeUnixMilli = c.timeFunc().UnixMilli()
|
||||||
|
}
|
||||||
|
|
||||||
|
var errorPtr *C.char
|
||||||
|
client := C.box_apple_tls_client_create(
|
||||||
|
C.int(dupFD),
|
||||||
|
serverNamePtr,
|
||||||
|
alpnPtr,
|
||||||
|
C.size_t(len(alpn)),
|
||||||
|
C.uint16_t(c.minVersion),
|
||||||
|
C.uint16_t(c.maxVersion),
|
||||||
|
C.bool(c.insecure),
|
||||||
|
anchorPEMPtr,
|
||||||
|
C.size_t(len(c.anchorPEM)),
|
||||||
|
C.bool(c.anchorOnly),
|
||||||
|
C.bool(hasVerifyTime),
|
||||||
|
C.int64_t(verifyTimeUnixMilli),
|
||||||
|
&errorPtr,
|
||||||
|
)
|
||||||
|
if client == nil {
|
||||||
|
if errorPtr != nil {
|
||||||
|
defer C.free(unsafe.Pointer(errorPtr))
|
||||||
|
return nil, E.New(C.GoString(errorPtr))
|
||||||
|
}
|
||||||
|
return nil, E.New("apple TLS: create connection")
|
||||||
|
}
|
||||||
|
if err = waitAppleTLSClientReady(ctx, client); err != nil {
|
||||||
|
C.box_apple_tls_client_cancel(client)
|
||||||
|
C.box_apple_tls_client_free(client)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var state C.box_apple_tls_state_t
|
||||||
|
stateOK := C.box_apple_tls_client_copy_state(client, &state, &errorPtr)
|
||||||
|
if !bool(stateOK) {
|
||||||
|
C.box_apple_tls_client_cancel(client)
|
||||||
|
C.box_apple_tls_client_free(client)
|
||||||
|
if errorPtr != nil {
|
||||||
|
defer C.free(unsafe.Pointer(errorPtr))
|
||||||
|
return nil, E.New(C.GoString(errorPtr))
|
||||||
|
}
|
||||||
|
return nil, E.New("apple TLS: read metadata")
|
||||||
|
}
|
||||||
|
defer C.box_apple_tls_state_free(&state)
|
||||||
|
|
||||||
|
connectionState, rawCerts, err := parseAppleTLSState(&state)
|
||||||
|
if err != nil {
|
||||||
|
C.box_apple_tls_client_cancel(client)
|
||||||
|
C.box_apple_tls_client_free(client)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(c.certificatePublicKeySHA256) > 0 {
|
||||||
|
err = VerifyPublicKeySHA256(c.certificatePublicKeySHA256, rawCerts)
|
||||||
|
if err != nil {
|
||||||
|
C.box_apple_tls_client_cancel(client)
|
||||||
|
C.box_apple_tls_client_free(client)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &appleTLSConn{
|
||||||
|
rawConn: conn,
|
||||||
|
client: client,
|
||||||
|
state: connectionState,
|
||||||
|
closed: make(chan struct{}),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const appleTLSHandshakePollInterval = 100 * time.Millisecond
|
||||||
|
|
||||||
|
func waitAppleTLSClientReady(ctx context.Context, client *C.box_apple_tls_client_t) error {
|
||||||
|
for {
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
C.box_apple_tls_client_cancel(client)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
waitTimeout := appleTLSHandshakePollInterval
|
||||||
|
if deadline, loaded := ctx.Deadline(); loaded {
|
||||||
|
remaining := time.Until(deadline)
|
||||||
|
if remaining <= 0 {
|
||||||
|
C.box_apple_tls_client_cancel(client)
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return context.DeadlineExceeded
|
||||||
|
}
|
||||||
|
if remaining < waitTimeout {
|
||||||
|
waitTimeout = remaining
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var errorPtr *C.char
|
||||||
|
waitResult := C.box_apple_tls_client_wait_ready(client, C.int(timeoutFromDuration(waitTimeout)), &errorPtr)
|
||||||
|
switch waitResult {
|
||||||
|
case 1:
|
||||||
|
return nil
|
||||||
|
case -2:
|
||||||
|
continue
|
||||||
|
case 0:
|
||||||
|
if errorPtr != nil {
|
||||||
|
defer C.free(unsafe.Pointer(errorPtr))
|
||||||
|
return E.New(C.GoString(errorPtr))
|
||||||
|
}
|
||||||
|
return E.New("apple TLS: handshake failed")
|
||||||
|
default:
|
||||||
|
return E.New("apple TLS: invalid handshake state")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type appleTLSConn struct {
|
||||||
|
rawConn net.Conn
|
||||||
|
client *C.box_apple_tls_client_t
|
||||||
|
state tls.ConnectionState
|
||||||
|
|
||||||
|
readAccess sync.Mutex
|
||||||
|
writeAccess sync.Mutex
|
||||||
|
stateAccess sync.RWMutex
|
||||||
|
closeOnce sync.Once
|
||||||
|
ioAccess sync.Mutex
|
||||||
|
ioGroup sync.WaitGroup
|
||||||
|
closed chan struct{}
|
||||||
|
readEOF bool
|
||||||
|
deadlineAccess sync.Mutex
|
||||||
|
readDeadline time.Time
|
||||||
|
writeDeadline time.Time
|
||||||
|
readTimedOut bool
|
||||||
|
writeTimedOut bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) Read(p []byte) (int, error) {
|
||||||
|
c.readAccess.Lock()
|
||||||
|
defer c.readAccess.Unlock()
|
||||||
|
if c.readEOF {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
if len(p) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
timeoutMs, err := c.prepareReadTimeout()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := c.acquireClient()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer c.releaseClient()
|
||||||
|
|
||||||
|
var eof C.bool
|
||||||
|
var errorPtr *C.char
|
||||||
|
n := C.box_apple_tls_client_read(client, unsafe.Pointer(&p[0]), C.size_t(len(p)), C.int(timeoutMs), &eof, &errorPtr)
|
||||||
|
switch {
|
||||||
|
case n == -2:
|
||||||
|
c.markReadTimedOut()
|
||||||
|
return 0, os.ErrDeadlineExceeded
|
||||||
|
case n >= 0:
|
||||||
|
if bool(eof) {
|
||||||
|
c.readEOF = true
|
||||||
|
if n == 0 {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return int(n), nil
|
||||||
|
default:
|
||||||
|
if errorPtr != nil {
|
||||||
|
defer C.free(unsafe.Pointer(errorPtr))
|
||||||
|
if c.isClosed() {
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
return 0, E.New(C.GoString(errorPtr))
|
||||||
|
}
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) Write(p []byte) (int, error) {
|
||||||
|
c.writeAccess.Lock()
|
||||||
|
defer c.writeAccess.Unlock()
|
||||||
|
if len(p) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
timeoutMs, err := c.prepareWriteTimeout()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := c.acquireClient()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer c.releaseClient()
|
||||||
|
|
||||||
|
var errorPtr *C.char
|
||||||
|
n := C.box_apple_tls_client_write(client, unsafe.Pointer(&p[0]), C.size_t(len(p)), C.int(timeoutMs), &errorPtr)
|
||||||
|
switch {
|
||||||
|
case n == -2:
|
||||||
|
c.markWriteTimedOut()
|
||||||
|
return 0, os.ErrDeadlineExceeded
|
||||||
|
case n >= 0:
|
||||||
|
return int(n), nil
|
||||||
|
}
|
||||||
|
if errorPtr != nil {
|
||||||
|
defer C.free(unsafe.Pointer(errorPtr))
|
||||||
|
if c.isClosed() {
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
return 0, E.New(C.GoString(errorPtr))
|
||||||
|
}
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) Close() error {
|
||||||
|
var closeErr error
|
||||||
|
c.closeOnce.Do(func() {
|
||||||
|
close(c.closed)
|
||||||
|
C.box_apple_tls_client_cancel(c.client)
|
||||||
|
closeErr = c.rawConn.Close()
|
||||||
|
c.ioAccess.Lock()
|
||||||
|
c.ioGroup.Wait()
|
||||||
|
C.box_apple_tls_client_free(c.client)
|
||||||
|
c.client = nil
|
||||||
|
c.ioAccess.Unlock()
|
||||||
|
})
|
||||||
|
return closeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) LocalAddr() net.Addr {
|
||||||
|
return c.rawConn.LocalAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) RemoteAddr() net.Addr {
|
||||||
|
return c.rawConn.RemoteAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDeadline installs deadlines for subsequent Read and Write calls.
|
||||||
|
//
|
||||||
|
// Deadlines only apply to subsequent Read or Write calls; an in-flight call
|
||||||
|
// does not observe later updates to its deadline. Callers that need to cancel
|
||||||
|
// an in-flight I/O must Close the connection instead.
|
||||||
|
//
|
||||||
|
// Once an active Read or Write trips its deadline, the underlying
|
||||||
|
// nw_connection is cancelled and the conn is no longer usable — callers must
|
||||||
|
// Close after a deadline error.
|
||||||
|
func (c *appleTLSConn) SetDeadline(t time.Time) error {
|
||||||
|
c.deadlineAccess.Lock()
|
||||||
|
c.readDeadline = t
|
||||||
|
c.writeDeadline = t
|
||||||
|
c.readTimedOut = false
|
||||||
|
c.writeTimedOut = false
|
||||||
|
c.deadlineAccess.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) SetReadDeadline(t time.Time) error {
|
||||||
|
c.deadlineAccess.Lock()
|
||||||
|
c.readDeadline = t
|
||||||
|
c.readTimedOut = false
|
||||||
|
c.deadlineAccess.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) SetWriteDeadline(t time.Time) error {
|
||||||
|
c.deadlineAccess.Lock()
|
||||||
|
c.writeDeadline = t
|
||||||
|
c.writeTimedOut = false
|
||||||
|
c.deadlineAccess.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) prepareReadTimeout() (int, error) {
|
||||||
|
c.deadlineAccess.Lock()
|
||||||
|
defer c.deadlineAccess.Unlock()
|
||||||
|
if c.readTimedOut {
|
||||||
|
return 0, os.ErrDeadlineExceeded
|
||||||
|
}
|
||||||
|
timeoutMs, expired := deadlineTimeoutMs(c.readDeadline)
|
||||||
|
if expired {
|
||||||
|
c.readTimedOut = true
|
||||||
|
return 0, os.ErrDeadlineExceeded
|
||||||
|
}
|
||||||
|
return timeoutMs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) prepareWriteTimeout() (int, error) {
|
||||||
|
c.deadlineAccess.Lock()
|
||||||
|
defer c.deadlineAccess.Unlock()
|
||||||
|
if c.writeTimedOut {
|
||||||
|
return 0, os.ErrDeadlineExceeded
|
||||||
|
}
|
||||||
|
timeoutMs, expired := deadlineTimeoutMs(c.writeDeadline)
|
||||||
|
if expired {
|
||||||
|
c.writeTimedOut = true
|
||||||
|
return 0, os.ErrDeadlineExceeded
|
||||||
|
}
|
||||||
|
return timeoutMs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) markReadTimedOut() {
|
||||||
|
c.deadlineAccess.Lock()
|
||||||
|
c.readTimedOut = true
|
||||||
|
c.deadlineAccess.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) markWriteTimedOut() {
|
||||||
|
c.deadlineAccess.Lock()
|
||||||
|
c.writeTimedOut = true
|
||||||
|
c.deadlineAccess.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func deadlineTimeoutMs(deadline time.Time) (int, bool) {
|
||||||
|
if deadline.IsZero() {
|
||||||
|
return -1, false
|
||||||
|
}
|
||||||
|
remaining := time.Until(deadline)
|
||||||
|
if remaining <= 0 {
|
||||||
|
return 0, true
|
||||||
|
}
|
||||||
|
return timeoutFromDuration(remaining), false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) isClosed() bool {
|
||||||
|
select {
|
||||||
|
case <-c.closed:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) acquireClient() (*C.box_apple_tls_client_t, error) {
|
||||||
|
c.ioAccess.Lock()
|
||||||
|
defer c.ioAccess.Unlock()
|
||||||
|
if c.isClosed() {
|
||||||
|
return nil, net.ErrClosed
|
||||||
|
}
|
||||||
|
client := c.client
|
||||||
|
if client == nil {
|
||||||
|
return nil, net.ErrClosed
|
||||||
|
}
|
||||||
|
c.ioGroup.Add(1)
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) releaseClient() {
|
||||||
|
c.ioGroup.Done()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) NetConn() net.Conn {
|
||||||
|
return c.rawConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) HandshakeContext(ctx context.Context) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) ConnectionState() ConnectionState {
|
||||||
|
c.stateAccess.RLock()
|
||||||
|
defer c.stateAccess.RUnlock()
|
||||||
|
return c.state
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAppleTLSState(state *C.box_apple_tls_state_t) (tls.ConnectionState, [][]byte, error) {
|
||||||
|
rawCerts, peerCertificates, err := parseAppleCertChain(state.peer_cert_chain, state.peer_cert_chain_len)
|
||||||
|
if err != nil {
|
||||||
|
return tls.ConnectionState{}, nil, err
|
||||||
|
}
|
||||||
|
var negotiatedProtocol string
|
||||||
|
if state.alpn != nil {
|
||||||
|
negotiatedProtocol = C.GoString(state.alpn)
|
||||||
|
}
|
||||||
|
var serverName string
|
||||||
|
if state.server_name != nil {
|
||||||
|
serverName = C.GoString(state.server_name)
|
||||||
|
}
|
||||||
|
return tls.ConnectionState{
|
||||||
|
Version: uint16(state.version),
|
||||||
|
HandshakeComplete: true,
|
||||||
|
CipherSuite: uint16(state.cipher_suite),
|
||||||
|
NegotiatedProtocol: negotiatedProtocol,
|
||||||
|
ServerName: serverName,
|
||||||
|
PeerCertificates: peerCertificates,
|
||||||
|
}, rawCerts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAppleCertChain(chain *C.uint8_t, chainLen C.size_t) ([][]byte, []*x509.Certificate, error) {
|
||||||
|
if chain == nil || chainLen == 0 {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
chainBytes := C.GoBytes(unsafe.Pointer(chain), C.int(chainLen))
|
||||||
|
var (
|
||||||
|
rawCerts [][]byte
|
||||||
|
peerCertificates []*x509.Certificate
|
||||||
|
)
|
||||||
|
for len(chainBytes) >= 4 {
|
||||||
|
certificateLen := binary.BigEndian.Uint32(chainBytes[:4])
|
||||||
|
chainBytes = chainBytes[4:]
|
||||||
|
if len(chainBytes) < int(certificateLen) {
|
||||||
|
return nil, nil, E.New("apple TLS: invalid certificate chain")
|
||||||
|
}
|
||||||
|
certificateData := append([]byte(nil), chainBytes[:certificateLen]...)
|
||||||
|
certificate, err := x509.ParseCertificate(certificateData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, E.Cause(err, "parse peer certificate")
|
||||||
|
}
|
||||||
|
rawCerts = append(rawCerts, certificateData)
|
||||||
|
peerCertificates = append(peerCertificates, certificate)
|
||||||
|
chainBytes = chainBytes[certificateLen:]
|
||||||
|
}
|
||||||
|
if len(chainBytes) != 0 {
|
||||||
|
return nil, nil, E.New("apple TLS: invalid certificate chain")
|
||||||
|
}
|
||||||
|
return rawCerts, peerCertificates, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func timeoutFromDuration(timeout time.Duration) int {
|
||||||
|
if timeout <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
timeoutMilliseconds := int64(timeout / time.Millisecond)
|
||||||
|
if timeout%time.Millisecond != 0 {
|
||||||
|
timeoutMilliseconds++
|
||||||
|
}
|
||||||
|
if timeoutMilliseconds > math.MaxInt32 {
|
||||||
|
return math.MaxInt32
|
||||||
|
}
|
||||||
|
return int(timeoutMilliseconds)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cStringOrNil(value string) *C.char {
|
||||||
|
if value == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return C.CString(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cFree(pointer *C.char) {
|
||||||
|
if pointer != nil {
|
||||||
|
C.free(unsafe.Pointer(pointer))
|
||||||
|
}
|
||||||
|
}
|
||||||
39
common/tls/apple_client_platform_darwin.h
Normal file
39
common/tls/apple_client_platform_darwin.h
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
#include <stdbool.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
|
||||||
|
typedef struct box_apple_tls_client box_apple_tls_client_t;
|
||||||
|
|
||||||
|
typedef struct box_apple_tls_state {
|
||||||
|
uint16_t version;
|
||||||
|
uint16_t cipher_suite;
|
||||||
|
char *alpn;
|
||||||
|
char *server_name;
|
||||||
|
uint8_t *peer_cert_chain;
|
||||||
|
size_t peer_cert_chain_len;
|
||||||
|
} box_apple_tls_state_t;
|
||||||
|
|
||||||
|
box_apple_tls_client_t *box_apple_tls_client_create(
|
||||||
|
int connected_socket,
|
||||||
|
const char *server_name,
|
||||||
|
const char *alpn,
|
||||||
|
size_t alpn_len,
|
||||||
|
uint16_t min_version,
|
||||||
|
uint16_t max_version,
|
||||||
|
bool insecure,
|
||||||
|
const char *anchor_pem,
|
||||||
|
size_t anchor_pem_len,
|
||||||
|
bool anchor_only,
|
||||||
|
bool has_verify_time,
|
||||||
|
int64_t verify_time_unix_millis,
|
||||||
|
char **error_out
|
||||||
|
);
|
||||||
|
|
||||||
|
int box_apple_tls_client_wait_ready(box_apple_tls_client_t *client, int timeout_msec, char **error_out);
|
||||||
|
void box_apple_tls_client_cancel(box_apple_tls_client_t *client);
|
||||||
|
void box_apple_tls_client_free(box_apple_tls_client_t *client);
|
||||||
|
ssize_t box_apple_tls_client_read(box_apple_tls_client_t *client, void *buffer, size_t buffer_len, int timeout_msec, bool *eof_out, char **error_out);
|
||||||
|
ssize_t box_apple_tls_client_write(box_apple_tls_client_t *client, const void *buffer, size_t buffer_len, int timeout_msec, char **error_out);
|
||||||
|
bool box_apple_tls_client_copy_state(box_apple_tls_client_t *client, box_apple_tls_state_t *state, char **error_out);
|
||||||
|
void box_apple_tls_state_free(box_apple_tls_state_t *state);
|
||||||
667
common/tls/apple_client_platform_darwin.m
Normal file
667
common/tls/apple_client_platform_darwin.m
Normal file
@@ -0,0 +1,667 @@
|
|||||||
|
#import "apple_client_platform_darwin.h"
|
||||||
|
|
||||||
|
#import <Foundation/Foundation.h>
|
||||||
|
#import <Network/Network.h>
|
||||||
|
#import <Security/Security.h>
|
||||||
|
#import <Security/SecProtocolMetadata.h>
|
||||||
|
#import <Security/SecProtocolOptions.h>
|
||||||
|
#import <Security/SecProtocolTypes.h>
|
||||||
|
#import <arpa/inet.h>
|
||||||
|
#import <dlfcn.h>
|
||||||
|
#import <dispatch/dispatch.h>
|
||||||
|
#import <stdatomic.h>
|
||||||
|
#import <stdlib.h>
|
||||||
|
#import <string.h>
|
||||||
|
#import <unistd.h>
|
||||||
|
|
||||||
|
typedef nw_connection_t _Nullable (*box_nw_connection_create_with_connected_socket_and_parameters_f)(int connected_socket, nw_parameters_t parameters);
|
||||||
|
typedef const char * _Nullable (*box_sec_protocol_metadata_string_accessor_f)(sec_protocol_metadata_t metadata);
|
||||||
|
|
||||||
|
typedef struct box_apple_tls_client {
|
||||||
|
void *connection;
|
||||||
|
void *queue;
|
||||||
|
void *ready_semaphore;
|
||||||
|
atomic_int ref_count;
|
||||||
|
atomic_bool ready;
|
||||||
|
atomic_bool ready_done;
|
||||||
|
char *ready_error;
|
||||||
|
box_apple_tls_state_t state;
|
||||||
|
} box_apple_tls_client_t;
|
||||||
|
|
||||||
|
static nw_connection_t box_apple_tls_connection(box_apple_tls_client_t *client) {
|
||||||
|
if (client == NULL || client->connection == NULL) {
|
||||||
|
return nil;
|
||||||
|
}
|
||||||
|
return (__bridge nw_connection_t)client->connection;
|
||||||
|
}
|
||||||
|
|
||||||
|
static dispatch_queue_t box_apple_tls_client_queue(box_apple_tls_client_t *client) {
|
||||||
|
if (client == NULL || client->queue == NULL) {
|
||||||
|
return nil;
|
||||||
|
}
|
||||||
|
return (__bridge dispatch_queue_t)client->queue;
|
||||||
|
}
|
||||||
|
|
||||||
|
static dispatch_semaphore_t box_apple_tls_ready_semaphore(box_apple_tls_client_t *client) {
|
||||||
|
if (client == NULL || client->ready_semaphore == NULL) {
|
||||||
|
return nil;
|
||||||
|
}
|
||||||
|
return (__bridge dispatch_semaphore_t)client->ready_semaphore;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void box_apple_tls_state_reset(box_apple_tls_state_t *state) {
|
||||||
|
if (state == NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
free(state->alpn);
|
||||||
|
free(state->server_name);
|
||||||
|
free(state->peer_cert_chain);
|
||||||
|
memset(state, 0, sizeof(box_apple_tls_state_t));
|
||||||
|
}
|
||||||
|
|
||||||
|
static void box_apple_tls_client_destroy(box_apple_tls_client_t *client) {
|
||||||
|
free(client->ready_error);
|
||||||
|
box_apple_tls_state_reset(&client->state);
|
||||||
|
if (client->ready_semaphore != NULL) {
|
||||||
|
CFBridgingRelease(client->ready_semaphore);
|
||||||
|
}
|
||||||
|
if (client->connection != NULL) {
|
||||||
|
CFBridgingRelease(client->connection);
|
||||||
|
}
|
||||||
|
if (client->queue != NULL) {
|
||||||
|
CFBridgingRelease(client->queue);
|
||||||
|
}
|
||||||
|
free(client);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void box_apple_tls_client_release(box_apple_tls_client_t *client) {
|
||||||
|
if (client == NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (atomic_fetch_sub(&client->ref_count, 1) == 1) {
|
||||||
|
box_apple_tls_client_destroy(client);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void box_set_error_string(char **error_out, NSString *message) {
|
||||||
|
if (error_out == NULL || *error_out != NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const char *utf8 = [message UTF8String];
|
||||||
|
*error_out = strdup(utf8 != NULL ? utf8 : "unknown error");
|
||||||
|
}
|
||||||
|
|
||||||
|
static void box_set_error_message(char **error_out, const char *message) {
|
||||||
|
if (error_out == NULL || *error_out != NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
*error_out = strdup(message != NULL ? message : "unknown error");
|
||||||
|
}
|
||||||
|
|
||||||
|
static void box_set_error_from_nw_error(char **error_out, nw_error_t error) {
|
||||||
|
if (error == NULL) {
|
||||||
|
box_set_error_message(error_out, "unknown network error");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
CFErrorRef cfError = nw_error_copy_cf_error(error);
|
||||||
|
if (cfError == NULL) {
|
||||||
|
box_set_error_message(error_out, "unknown network error");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
NSString *description = [(__bridge NSError *)cfError description];
|
||||||
|
box_set_error_string(error_out, description);
|
||||||
|
CFRelease(cfError);
|
||||||
|
}
|
||||||
|
|
||||||
|
static char *box_apple_tls_metadata_copy_negotiated_protocol(sec_protocol_metadata_t metadata) {
|
||||||
|
static box_sec_protocol_metadata_string_accessor_f copy_fn;
|
||||||
|
static box_sec_protocol_metadata_string_accessor_f get_fn;
|
||||||
|
static dispatch_once_t onceToken;
|
||||||
|
dispatch_once(&onceToken, ^{
|
||||||
|
copy_fn = (box_sec_protocol_metadata_string_accessor_f)dlsym(RTLD_DEFAULT, "sec_protocol_metadata_copy_negotiated_protocol");
|
||||||
|
get_fn = (box_sec_protocol_metadata_string_accessor_f)dlsym(RTLD_DEFAULT, "sec_protocol_metadata_get_negotiated_protocol");
|
||||||
|
});
|
||||||
|
if (copy_fn != NULL) {
|
||||||
|
return (char *)copy_fn(metadata);
|
||||||
|
}
|
||||||
|
if (get_fn != NULL) {
|
||||||
|
const char *protocol = get_fn(metadata);
|
||||||
|
if (protocol != NULL) {
|
||||||
|
return strdup(protocol);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
static char *box_apple_tls_metadata_copy_server_name(sec_protocol_metadata_t metadata) {
|
||||||
|
static box_sec_protocol_metadata_string_accessor_f copy_fn;
|
||||||
|
static box_sec_protocol_metadata_string_accessor_f get_fn;
|
||||||
|
static dispatch_once_t onceToken;
|
||||||
|
dispatch_once(&onceToken, ^{
|
||||||
|
copy_fn = (box_sec_protocol_metadata_string_accessor_f)dlsym(RTLD_DEFAULT, "sec_protocol_metadata_copy_server_name");
|
||||||
|
get_fn = (box_sec_protocol_metadata_string_accessor_f)dlsym(RTLD_DEFAULT, "sec_protocol_metadata_get_server_name");
|
||||||
|
});
|
||||||
|
if (copy_fn != NULL) {
|
||||||
|
return (char *)copy_fn(metadata);
|
||||||
|
}
|
||||||
|
if (get_fn != NULL) {
|
||||||
|
const char *server_name = get_fn(metadata);
|
||||||
|
if (server_name != NULL) {
|
||||||
|
return strdup(server_name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
static NSArray<NSString *> *box_split_lines(const char *content, size_t content_len) {
|
||||||
|
if (content == NULL || content_len == 0) {
|
||||||
|
return @[];
|
||||||
|
}
|
||||||
|
NSString *string = [[NSString alloc] initWithBytes:content length:content_len encoding:NSUTF8StringEncoding];
|
||||||
|
if (string == nil) {
|
||||||
|
return @[];
|
||||||
|
}
|
||||||
|
NSMutableArray<NSString *> *lines = [NSMutableArray array];
|
||||||
|
[string enumerateLinesUsingBlock:^(NSString *line, BOOL *stop) {
|
||||||
|
if (line.length > 0) {
|
||||||
|
[lines addObject:line];
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
return lines;
|
||||||
|
}
|
||||||
|
|
||||||
|
static NSArray *box_parse_certificates_from_pem(const char *pem, size_t pem_len) {
|
||||||
|
if (pem == NULL || pem_len == 0) {
|
||||||
|
return @[];
|
||||||
|
}
|
||||||
|
NSString *content = [[NSString alloc] initWithBytes:pem length:pem_len encoding:NSUTF8StringEncoding];
|
||||||
|
if (content == nil) {
|
||||||
|
return @[];
|
||||||
|
}
|
||||||
|
NSString *beginMarker = @"-----BEGIN CERTIFICATE-----";
|
||||||
|
NSString *endMarker = @"-----END CERTIFICATE-----";
|
||||||
|
NSMutableArray *certificates = [NSMutableArray array];
|
||||||
|
NSUInteger searchFrom = 0;
|
||||||
|
while (searchFrom < content.length) {
|
||||||
|
NSRange beginRange = [content rangeOfString:beginMarker options:0 range:NSMakeRange(searchFrom, content.length - searchFrom)];
|
||||||
|
if (beginRange.location == NSNotFound) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
NSUInteger bodyStart = beginRange.location + beginRange.length;
|
||||||
|
NSRange endRange = [content rangeOfString:endMarker options:0 range:NSMakeRange(bodyStart, content.length - bodyStart)];
|
||||||
|
if (endRange.location == NSNotFound) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
NSString *base64Section = [content substringWithRange:NSMakeRange(bodyStart, endRange.location - bodyStart)];
|
||||||
|
NSArray<NSString *> *components = [base64Section componentsSeparatedByCharactersInSet:[NSCharacterSet whitespaceAndNewlineCharacterSet]];
|
||||||
|
NSString *base64Content = [components componentsJoinedByString:@""];
|
||||||
|
NSData *der = [[NSData alloc] initWithBase64EncodedString:base64Content options:0];
|
||||||
|
if (der != nil) {
|
||||||
|
SecCertificateRef certificate = SecCertificateCreateWithData(NULL, (__bridge CFDataRef)der);
|
||||||
|
if (certificate != NULL) {
|
||||||
|
[certificates addObject:(__bridge id)certificate];
|
||||||
|
CFRelease(certificate);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
searchFrom = endRange.location + endRange.length;
|
||||||
|
}
|
||||||
|
return certificates;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool box_evaluate_trust(sec_trust_t trust, NSArray *anchors, bool anchor_only, NSDate *verify_date) {
|
||||||
|
bool result = false;
|
||||||
|
SecTrustRef trustRef = sec_trust_copy_ref(trust);
|
||||||
|
if (trustRef == NULL) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (verify_date != nil && SecTrustSetVerifyDate(trustRef, (__bridge CFDateRef)verify_date) != errSecSuccess) {
|
||||||
|
CFRelease(trustRef);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (anchors.count > 0 || anchor_only) {
|
||||||
|
CFMutableArrayRef anchorArray = CFArrayCreateMutable(NULL, 0, &kCFTypeArrayCallBacks);
|
||||||
|
for (id certificate in anchors) {
|
||||||
|
CFArrayAppendValue(anchorArray, (__bridge const void *)certificate);
|
||||||
|
}
|
||||||
|
SecTrustSetAnchorCertificates(trustRef, anchorArray);
|
||||||
|
SecTrustSetAnchorCertificatesOnly(trustRef, anchor_only);
|
||||||
|
CFRelease(anchorArray);
|
||||||
|
}
|
||||||
|
CFErrorRef error = NULL;
|
||||||
|
result = SecTrustEvaluateWithError(trustRef, &error);
|
||||||
|
if (error != NULL) {
|
||||||
|
CFRelease(error);
|
||||||
|
}
|
||||||
|
CFRelease(trustRef);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
static nw_connection_t box_apple_tls_create_connection(int connected_socket, nw_parameters_t parameters) {
|
||||||
|
static box_nw_connection_create_with_connected_socket_and_parameters_f create_fn;
|
||||||
|
static dispatch_once_t onceToken;
|
||||||
|
dispatch_once(&onceToken, ^{
|
||||||
|
char name[] = "sretemarap_dna_tekcos_detcennoc_htiw_etaerc_noitcennoc_wn";
|
||||||
|
for (size_t i = 0, j = sizeof(name) - 2; i < j; i++, j--) {
|
||||||
|
char t = name[i];
|
||||||
|
name[i] = name[j];
|
||||||
|
name[j] = t;
|
||||||
|
}
|
||||||
|
create_fn = (box_nw_connection_create_with_connected_socket_and_parameters_f)dlsym(RTLD_DEFAULT, name);
|
||||||
|
});
|
||||||
|
if (create_fn == NULL) {
|
||||||
|
return nil;
|
||||||
|
}
|
||||||
|
return create_fn(connected_socket, parameters);
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool box_apple_tls_state_copy(const box_apple_tls_state_t *source, box_apple_tls_state_t *destination) {
|
||||||
|
memset(destination, 0, sizeof(box_apple_tls_state_t));
|
||||||
|
destination->version = source->version;
|
||||||
|
destination->cipher_suite = source->cipher_suite;
|
||||||
|
if (source->alpn != NULL) {
|
||||||
|
destination->alpn = strdup(source->alpn);
|
||||||
|
if (destination->alpn == NULL) {
|
||||||
|
goto oom;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (source->server_name != NULL) {
|
||||||
|
destination->server_name = strdup(source->server_name);
|
||||||
|
if (destination->server_name == NULL) {
|
||||||
|
goto oom;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (source->peer_cert_chain_len > 0) {
|
||||||
|
destination->peer_cert_chain = malloc(source->peer_cert_chain_len);
|
||||||
|
if (destination->peer_cert_chain == NULL) {
|
||||||
|
goto oom;
|
||||||
|
}
|
||||||
|
memcpy(destination->peer_cert_chain, source->peer_cert_chain, source->peer_cert_chain_len);
|
||||||
|
destination->peer_cert_chain_len = source->peer_cert_chain_len;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
|
||||||
|
oom:
|
||||||
|
box_apple_tls_state_reset(destination);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool box_apple_tls_state_load(nw_connection_t connection, box_apple_tls_state_t *state, char **error_out) {
|
||||||
|
box_apple_tls_state_reset(state);
|
||||||
|
if (connection == nil) {
|
||||||
|
box_set_error_message(error_out, "apple TLS: invalid client");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
nw_protocol_definition_t tls_definition = nw_protocol_copy_tls_definition();
|
||||||
|
nw_protocol_metadata_t metadata = nw_connection_copy_protocol_metadata(connection, tls_definition);
|
||||||
|
if (metadata == NULL || !nw_protocol_metadata_is_tls(metadata)) {
|
||||||
|
box_set_error_message(error_out, "apple TLS: metadata unavailable");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
sec_protocol_metadata_t sec_metadata = nw_tls_copy_sec_protocol_metadata(metadata);
|
||||||
|
if (sec_metadata == NULL) {
|
||||||
|
box_set_error_message(error_out, "apple TLS: metadata unavailable");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
state->version = (uint16_t)sec_protocol_metadata_get_negotiated_tls_protocol_version(sec_metadata);
|
||||||
|
state->cipher_suite = (uint16_t)sec_protocol_metadata_get_negotiated_tls_ciphersuite(sec_metadata);
|
||||||
|
state->alpn = box_apple_tls_metadata_copy_negotiated_protocol(sec_metadata);
|
||||||
|
state->server_name = box_apple_tls_metadata_copy_server_name(sec_metadata);
|
||||||
|
|
||||||
|
NSMutableData *chain_data = [NSMutableData data];
|
||||||
|
sec_protocol_metadata_access_peer_certificate_chain(sec_metadata, ^(sec_certificate_t certificate) {
|
||||||
|
SecCertificateRef certificate_ref = sec_certificate_copy_ref(certificate);
|
||||||
|
if (certificate_ref == NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
CFDataRef certificate_data = SecCertificateCopyData(certificate_ref);
|
||||||
|
CFRelease(certificate_ref);
|
||||||
|
if (certificate_data == NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
uint32_t certificate_len = (uint32_t)CFDataGetLength(certificate_data);
|
||||||
|
uint32_t network_len = htonl(certificate_len);
|
||||||
|
[chain_data appendBytes:&network_len length:sizeof(network_len)];
|
||||||
|
[chain_data appendBytes:CFDataGetBytePtr(certificate_data) length:certificate_len];
|
||||||
|
CFRelease(certificate_data);
|
||||||
|
});
|
||||||
|
if (chain_data.length > 0) {
|
||||||
|
state->peer_cert_chain = malloc(chain_data.length);
|
||||||
|
if (state->peer_cert_chain == NULL) {
|
||||||
|
box_set_error_message(error_out, "apple TLS: out of memory");
|
||||||
|
box_apple_tls_state_reset(state);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
memcpy(state->peer_cert_chain, chain_data.bytes, chain_data.length);
|
||||||
|
state->peer_cert_chain_len = chain_data.length;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
box_apple_tls_client_t *box_apple_tls_client_create(
|
||||||
|
int connected_socket,
|
||||||
|
const char *server_name,
|
||||||
|
const char *alpn,
|
||||||
|
size_t alpn_len,
|
||||||
|
uint16_t min_version,
|
||||||
|
uint16_t max_version,
|
||||||
|
bool insecure,
|
||||||
|
const char *anchor_pem,
|
||||||
|
size_t anchor_pem_len,
|
||||||
|
bool anchor_only,
|
||||||
|
bool has_verify_time,
|
||||||
|
int64_t verify_time_unix_millis,
|
||||||
|
char **error_out
|
||||||
|
) {
|
||||||
|
box_apple_tls_client_t *client = calloc(1, sizeof(box_apple_tls_client_t));
|
||||||
|
if (client == NULL) {
|
||||||
|
close(connected_socket);
|
||||||
|
box_set_error_message(error_out, "apple TLS: out of memory");
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
client->queue = (__bridge_retained void *)dispatch_queue_create("sing-box.apple-private-tls", DISPATCH_QUEUE_SERIAL);
|
||||||
|
client->ready_semaphore = (__bridge_retained void *)dispatch_semaphore_create(0);
|
||||||
|
atomic_init(&client->ref_count, 1);
|
||||||
|
atomic_init(&client->ready, false);
|
||||||
|
atomic_init(&client->ready_done, false);
|
||||||
|
|
||||||
|
NSArray<NSString *> *alpnList = box_split_lines(alpn, alpn_len);
|
||||||
|
NSArray *anchors = box_parse_certificates_from_pem(anchor_pem, anchor_pem_len);
|
||||||
|
NSDate *verifyDate = nil;
|
||||||
|
if (has_verify_time) {
|
||||||
|
verifyDate = [NSDate dateWithTimeIntervalSince1970:(NSTimeInterval)verify_time_unix_millis / 1000.0];
|
||||||
|
}
|
||||||
|
nw_parameters_t parameters = nw_parameters_create_secure_tcp(^(nw_protocol_options_t tls_options) {
|
||||||
|
sec_protocol_options_t sec_options = nw_tls_copy_sec_protocol_options(tls_options);
|
||||||
|
if (min_version != 0) {
|
||||||
|
sec_protocol_options_set_min_tls_protocol_version(sec_options, (tls_protocol_version_t)min_version);
|
||||||
|
}
|
||||||
|
if (max_version != 0) {
|
||||||
|
sec_protocol_options_set_max_tls_protocol_version(sec_options, (tls_protocol_version_t)max_version);
|
||||||
|
}
|
||||||
|
if (server_name != NULL && server_name[0] != '\0') {
|
||||||
|
sec_protocol_options_set_tls_server_name(sec_options, server_name);
|
||||||
|
}
|
||||||
|
for (NSString *protocol in alpnList) {
|
||||||
|
sec_protocol_options_add_tls_application_protocol(sec_options, protocol.UTF8String);
|
||||||
|
}
|
||||||
|
sec_protocol_options_set_peer_authentication_required(sec_options, !insecure);
|
||||||
|
if (insecure) {
|
||||||
|
sec_protocol_options_set_verify_block(sec_options, ^(sec_protocol_metadata_t metadata, sec_trust_t trust, sec_protocol_verify_complete_t complete) {
|
||||||
|
complete(true);
|
||||||
|
}, box_apple_tls_client_queue(client));
|
||||||
|
} else if (verifyDate != nil || anchors.count > 0 || anchor_only) {
|
||||||
|
sec_protocol_options_set_verify_block(sec_options, ^(sec_protocol_metadata_t metadata, sec_trust_t trust, sec_protocol_verify_complete_t complete) {
|
||||||
|
complete(box_evaluate_trust(trust, anchors, anchor_only, verifyDate));
|
||||||
|
}, box_apple_tls_client_queue(client));
|
||||||
|
}
|
||||||
|
}, NW_PARAMETERS_DEFAULT_CONFIGURATION);
|
||||||
|
|
||||||
|
nw_connection_t connection = box_apple_tls_create_connection(connected_socket, parameters);
|
||||||
|
if (connection == NULL) {
|
||||||
|
close(connected_socket);
|
||||||
|
if (client->ready_semaphore != NULL) {
|
||||||
|
CFBridgingRelease(client->ready_semaphore);
|
||||||
|
}
|
||||||
|
if (client->queue != NULL) {
|
||||||
|
CFBridgingRelease(client->queue);
|
||||||
|
}
|
||||||
|
free(client);
|
||||||
|
box_set_error_message(error_out, "apple TLS: failed to create connection");
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
client->connection = (__bridge_retained void *)connection;
|
||||||
|
atomic_fetch_add(&client->ref_count, 1);
|
||||||
|
|
||||||
|
nw_connection_set_state_changed_handler(connection, ^(nw_connection_state_t state, nw_error_t error) {
|
||||||
|
switch (state) {
|
||||||
|
case nw_connection_state_ready:
|
||||||
|
if (!atomic_load(&client->ready_done)) {
|
||||||
|
atomic_store(&client->ready, box_apple_tls_state_load(connection, &client->state, &client->ready_error));
|
||||||
|
atomic_store(&client->ready_done, true);
|
||||||
|
dispatch_semaphore_signal(box_apple_tls_ready_semaphore(client));
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case nw_connection_state_failed:
|
||||||
|
if (!atomic_load(&client->ready_done)) {
|
||||||
|
box_set_error_from_nw_error(&client->ready_error, error);
|
||||||
|
atomic_store(&client->ready_done, true);
|
||||||
|
dispatch_semaphore_signal(box_apple_tls_ready_semaphore(client));
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case nw_connection_state_cancelled:
|
||||||
|
if (!atomic_load(&client->ready_done)) {
|
||||||
|
box_set_error_from_nw_error(&client->ready_error, error);
|
||||||
|
atomic_store(&client->ready_done, true);
|
||||||
|
dispatch_semaphore_signal(box_apple_tls_ready_semaphore(client));
|
||||||
|
}
|
||||||
|
box_apple_tls_client_release(client);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
nw_connection_set_queue(connection, box_apple_tls_client_queue(client));
|
||||||
|
nw_connection_start(connection);
|
||||||
|
return client;
|
||||||
|
}
|
||||||
|
|
||||||
|
int box_apple_tls_client_wait_ready(box_apple_tls_client_t *client, int timeout_msec, char **error_out) {
|
||||||
|
dispatch_semaphore_t ready_semaphore = box_apple_tls_ready_semaphore(client);
|
||||||
|
if (ready_semaphore == nil) {
|
||||||
|
box_set_error_message(error_out, "apple TLS: invalid client");
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
if (!atomic_load(&client->ready_done)) {
|
||||||
|
dispatch_time_t timeout = DISPATCH_TIME_FOREVER;
|
||||||
|
if (timeout_msec >= 0) {
|
||||||
|
timeout = dispatch_time(DISPATCH_TIME_NOW, (int64_t)timeout_msec * NSEC_PER_MSEC);
|
||||||
|
}
|
||||||
|
long wait_result = dispatch_semaphore_wait(ready_semaphore, timeout);
|
||||||
|
if (wait_result != 0) {
|
||||||
|
return -2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (atomic_load(&client->ready)) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
if (client->ready_error != NULL) {
|
||||||
|
if (error_out != NULL) {
|
||||||
|
*error_out = client->ready_error;
|
||||||
|
client->ready_error = NULL;
|
||||||
|
} else {
|
||||||
|
free(client->ready_error);
|
||||||
|
client->ready_error = NULL;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
box_set_error_message(error_out, "apple TLS: handshake failed");
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void box_apple_tls_client_cancel(box_apple_tls_client_t *client) {
|
||||||
|
if (client == NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
nw_connection_t connection = box_apple_tls_connection(client);
|
||||||
|
if (connection != nil) {
|
||||||
|
nw_connection_cancel(connection);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void box_apple_tls_client_free(box_apple_tls_client_t *client) {
|
||||||
|
if (client == NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
nw_connection_t connection = box_apple_tls_connection(client);
|
||||||
|
if (connection != nil) {
|
||||||
|
nw_connection_cancel(connection);
|
||||||
|
}
|
||||||
|
box_apple_tls_client_release(client);
|
||||||
|
}
|
||||||
|
|
||||||
|
ssize_t box_apple_tls_client_read(box_apple_tls_client_t *client, void *buffer, size_t buffer_len, int timeout_msec, bool *eof_out, char **error_out) {
|
||||||
|
nw_connection_t connection = box_apple_tls_connection(client);
|
||||||
|
if (connection == nil) {
|
||||||
|
box_set_error_message(error_out, "apple TLS: invalid client");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch_semaphore_t read_semaphore = dispatch_semaphore_create(0);
|
||||||
|
__block NSData *content_data = nil;
|
||||||
|
__block bool read_eof = false;
|
||||||
|
__block char *local_error = NULL;
|
||||||
|
|
||||||
|
nw_connection_receive(connection, 1, (uint32_t)buffer_len, ^(dispatch_data_t content, nw_content_context_t context, bool is_complete, nw_error_t error) {
|
||||||
|
if (content != NULL) {
|
||||||
|
const void *mapped = NULL;
|
||||||
|
size_t mapped_len = 0;
|
||||||
|
dispatch_data_t mapped_data = dispatch_data_create_map(content, &mapped, &mapped_len);
|
||||||
|
if (mapped != NULL && mapped_len > 0) {
|
||||||
|
content_data = [NSData dataWithBytes:mapped length:mapped_len];
|
||||||
|
}
|
||||||
|
(void)mapped_data;
|
||||||
|
}
|
||||||
|
if (error != NULL && content_data.length == 0) {
|
||||||
|
box_set_error_from_nw_error(&local_error, error);
|
||||||
|
}
|
||||||
|
if (is_complete && (context == NULL || nw_content_context_get_is_final(context))) {
|
||||||
|
read_eof = true;
|
||||||
|
}
|
||||||
|
dispatch_semaphore_signal(read_semaphore);
|
||||||
|
});
|
||||||
|
|
||||||
|
dispatch_time_t wait_deadline = DISPATCH_TIME_FOREVER;
|
||||||
|
if (timeout_msec >= 0) {
|
||||||
|
wait_deadline = dispatch_time(DISPATCH_TIME_NOW, (int64_t)timeout_msec * NSEC_PER_MSEC);
|
||||||
|
}
|
||||||
|
long wait_result = dispatch_semaphore_wait(read_semaphore, wait_deadline);
|
||||||
|
if (wait_result != 0) {
|
||||||
|
nw_connection_cancel(connection);
|
||||||
|
dispatch_semaphore_wait(read_semaphore, DISPATCH_TIME_FOREVER);
|
||||||
|
if (local_error != NULL) {
|
||||||
|
free(local_error);
|
||||||
|
local_error = NULL;
|
||||||
|
}
|
||||||
|
return -2;
|
||||||
|
}
|
||||||
|
if (local_error != NULL) {
|
||||||
|
if (error_out != NULL) {
|
||||||
|
*error_out = local_error;
|
||||||
|
} else {
|
||||||
|
free(local_error);
|
||||||
|
}
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
if (eof_out != NULL) {
|
||||||
|
*eof_out = read_eof;
|
||||||
|
}
|
||||||
|
if (content_data == nil || content_data.length == 0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
memcpy(buffer, content_data.bytes, content_data.length);
|
||||||
|
return (ssize_t)content_data.length;
|
||||||
|
}
|
||||||
|
|
||||||
|
ssize_t box_apple_tls_client_write(box_apple_tls_client_t *client, const void *buffer, size_t buffer_len, int timeout_msec, char **error_out) {
|
||||||
|
nw_connection_t connection = box_apple_tls_connection(client);
|
||||||
|
if (connection == nil) {
|
||||||
|
box_set_error_message(error_out, "apple TLS: invalid client");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
if (buffer_len == 0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void *content_copy = malloc(buffer_len);
|
||||||
|
dispatch_queue_t queue = box_apple_tls_client_queue(client);
|
||||||
|
if (content_copy == NULL) {
|
||||||
|
free(content_copy);
|
||||||
|
box_set_error_message(error_out, "apple TLS: out of memory");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
if (queue == nil) {
|
||||||
|
free(content_copy);
|
||||||
|
box_set_error_message(error_out, "apple TLS: invalid client");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
memcpy(content_copy, buffer, buffer_len);
|
||||||
|
dispatch_data_t content = dispatch_data_create(content_copy, buffer_len, queue, ^{
|
||||||
|
free(content_copy);
|
||||||
|
});
|
||||||
|
|
||||||
|
dispatch_semaphore_t write_semaphore = dispatch_semaphore_create(0);
|
||||||
|
__block char *local_error = NULL;
|
||||||
|
|
||||||
|
nw_connection_send(connection, content, NW_CONNECTION_DEFAULT_STREAM_CONTEXT, false, ^(nw_error_t error) {
|
||||||
|
if (error != NULL) {
|
||||||
|
box_set_error_from_nw_error(&local_error, error);
|
||||||
|
}
|
||||||
|
dispatch_semaphore_signal(write_semaphore);
|
||||||
|
});
|
||||||
|
|
||||||
|
dispatch_time_t wait_deadline = DISPATCH_TIME_FOREVER;
|
||||||
|
if (timeout_msec >= 0) {
|
||||||
|
wait_deadline = dispatch_time(DISPATCH_TIME_NOW, (int64_t)timeout_msec * NSEC_PER_MSEC);
|
||||||
|
}
|
||||||
|
long wait_result = dispatch_semaphore_wait(write_semaphore, wait_deadline);
|
||||||
|
if (wait_result != 0) {
|
||||||
|
nw_connection_cancel(connection);
|
||||||
|
dispatch_semaphore_wait(write_semaphore, DISPATCH_TIME_FOREVER);
|
||||||
|
if (local_error != NULL) {
|
||||||
|
free(local_error);
|
||||||
|
local_error = NULL;
|
||||||
|
}
|
||||||
|
return -2;
|
||||||
|
}
|
||||||
|
if (local_error != NULL) {
|
||||||
|
if (error_out != NULL) {
|
||||||
|
*error_out = local_error;
|
||||||
|
} else {
|
||||||
|
free(local_error);
|
||||||
|
}
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
return (ssize_t)buffer_len;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool box_apple_tls_client_copy_state(box_apple_tls_client_t *client, box_apple_tls_state_t *state, char **error_out) {
|
||||||
|
dispatch_queue_t queue = box_apple_tls_client_queue(client);
|
||||||
|
if (queue == nil || state == NULL) {
|
||||||
|
box_set_error_message(error_out, "apple TLS: invalid client");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
memset(state, 0, sizeof(box_apple_tls_state_t));
|
||||||
|
__block bool copied = false;
|
||||||
|
__block char *local_error = NULL;
|
||||||
|
dispatch_sync(queue, ^{
|
||||||
|
if (!atomic_load(&client->ready)) {
|
||||||
|
box_set_error_message(&local_error, "apple TLS: metadata unavailable");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!box_apple_tls_state_copy(&client->state, state)) {
|
||||||
|
box_set_error_message(&local_error, "apple TLS: out of memory");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
copied = true;
|
||||||
|
});
|
||||||
|
if (copied) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (local_error != NULL) {
|
||||||
|
if (error_out != NULL) {
|
||||||
|
*error_out = local_error;
|
||||||
|
} else {
|
||||||
|
free(local_error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
box_apple_tls_state_reset(state);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void box_apple_tls_state_free(box_apple_tls_state_t *state) {
|
||||||
|
box_apple_tls_state_reset(state);
|
||||||
|
}
|
||||||
453
common/tls/apple_client_platform_test.go
Normal file
453
common/tls/apple_client_platform_test.go
Normal file
@@ -0,0 +1,453 @@
|
|||||||
|
//go:build darwin && cgo
|
||||||
|
|
||||||
|
package tls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
stdtls "crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
"github.com/sagernet/sing/common/json/badoption"
|
||||||
|
"github.com/sagernet/sing/common/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
const appleTLSTestTimeout = 5 * time.Second
|
||||||
|
|
||||||
|
const (
|
||||||
|
appleTLSSuccessHandshakeLoops = 20
|
||||||
|
appleTLSFailureRecoveryLoops = 10
|
||||||
|
)
|
||||||
|
|
||||||
|
type appleTLSServerResult struct {
|
||||||
|
state stdtls.ConnectionState
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppleClientHandshakeAppliesALPNAndVersion(t *testing.T) {
|
||||||
|
serverCertificate, serverCertificatePEM := newAppleTestCertificate(t, "localhost")
|
||||||
|
for index := 0; index < appleTLSSuccessHandshakeLoops; index++ {
|
||||||
|
serverResult, serverAddress := startAppleTLSTestServer(t, &stdtls.Config{
|
||||||
|
Certificates: []stdtls.Certificate{serverCertificate},
|
||||||
|
MinVersion: stdtls.VersionTLS12,
|
||||||
|
MaxVersion: stdtls.VersionTLS12,
|
||||||
|
NextProtos: []string{"h2"},
|
||||||
|
})
|
||||||
|
|
||||||
|
clientConn, err := newAppleTestClientConn(t, serverAddress, option.OutboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
Engine: "apple",
|
||||||
|
ServerName: "localhost",
|
||||||
|
MinVersion: "1.2",
|
||||||
|
MaxVersion: "1.2",
|
||||||
|
ALPN: badoption.Listable[string]{"h2"},
|
||||||
|
Certificate: badoption.Listable[string]{serverCertificatePEM},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("iteration %d: %v", index, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientState := clientConn.ConnectionState()
|
||||||
|
if clientState.Version != stdtls.VersionTLS12 {
|
||||||
|
_ = clientConn.Close()
|
||||||
|
t.Fatalf("iteration %d: unexpected negotiated version: %x", index, clientState.Version)
|
||||||
|
}
|
||||||
|
if clientState.NegotiatedProtocol != "h2" {
|
||||||
|
_ = clientConn.Close()
|
||||||
|
t.Fatalf("iteration %d: unexpected negotiated protocol: %q", index, clientState.NegotiatedProtocol)
|
||||||
|
}
|
||||||
|
_ = clientConn.Close()
|
||||||
|
|
||||||
|
result := <-serverResult
|
||||||
|
if result.err != nil {
|
||||||
|
t.Fatalf("iteration %d: %v", index, result.err)
|
||||||
|
}
|
||||||
|
if result.state.Version != stdtls.VersionTLS12 {
|
||||||
|
t.Fatalf("iteration %d: server negotiated unexpected version: %x", index, result.state.Version)
|
||||||
|
}
|
||||||
|
if result.state.NegotiatedProtocol != "h2" {
|
||||||
|
t.Fatalf("iteration %d: server negotiated unexpected protocol: %q", index, result.state.NegotiatedProtocol)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppleClientHandshakeRejectsVersionMismatch(t *testing.T) {
|
||||||
|
serverCertificate, serverCertificatePEM := newAppleTestCertificate(t, "localhost")
|
||||||
|
serverResult, serverAddress := startAppleTLSTestServer(t, &stdtls.Config{
|
||||||
|
Certificates: []stdtls.Certificate{serverCertificate},
|
||||||
|
MinVersion: stdtls.VersionTLS13,
|
||||||
|
MaxVersion: stdtls.VersionTLS13,
|
||||||
|
})
|
||||||
|
|
||||||
|
clientConn, err := newAppleTestClientConn(t, serverAddress, option.OutboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
Engine: "apple",
|
||||||
|
ServerName: "localhost",
|
||||||
|
MaxVersion: "1.2",
|
||||||
|
Certificate: badoption.Listable[string]{serverCertificatePEM},
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
clientConn.Close()
|
||||||
|
t.Fatal("expected version mismatch handshake to fail")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result := <-serverResult; result.err == nil {
|
||||||
|
t.Fatal("expected server handshake to fail on version mismatch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppleClientHandshakeRejectsServerNameMismatch(t *testing.T) {
|
||||||
|
serverCertificate, serverCertificatePEM := newAppleTestCertificate(t, "localhost")
|
||||||
|
serverResult, serverAddress := startAppleTLSTestServer(t, &stdtls.Config{
|
||||||
|
Certificates: []stdtls.Certificate{serverCertificate},
|
||||||
|
})
|
||||||
|
|
||||||
|
clientConn, err := newAppleTestClientConn(t, serverAddress, option.OutboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
Engine: "apple",
|
||||||
|
ServerName: "example.com",
|
||||||
|
Certificate: badoption.Listable[string]{serverCertificatePEM},
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
clientConn.Close()
|
||||||
|
t.Fatal("expected server name mismatch handshake to fail")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result := <-serverResult; result.err == nil {
|
||||||
|
t.Fatal("expected server handshake to fail on server name mismatch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppleClientHandshakeRecoversAfterFailure(t *testing.T) {
|
||||||
|
serverCertificate, serverCertificatePEM := newAppleTestCertificate(t, "localhost")
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
serverConfig *stdtls.Config
|
||||||
|
clientOptions option.OutboundTLSOptions
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "version mismatch",
|
||||||
|
serverConfig: &stdtls.Config{
|
||||||
|
Certificates: []stdtls.Certificate{serverCertificate},
|
||||||
|
MinVersion: stdtls.VersionTLS13,
|
||||||
|
MaxVersion: stdtls.VersionTLS13,
|
||||||
|
},
|
||||||
|
clientOptions: option.OutboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
Engine: "apple",
|
||||||
|
ServerName: "localhost",
|
||||||
|
MaxVersion: "1.2",
|
||||||
|
Certificate: badoption.Listable[string]{serverCertificatePEM},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "server name mismatch",
|
||||||
|
serverConfig: &stdtls.Config{
|
||||||
|
Certificates: []stdtls.Certificate{serverCertificate},
|
||||||
|
},
|
||||||
|
clientOptions: option.OutboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
Engine: "apple",
|
||||||
|
ServerName: "example.com",
|
||||||
|
Certificate: badoption.Listable[string]{serverCertificatePEM},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
successClientOptions := option.OutboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
Engine: "apple",
|
||||||
|
ServerName: "localhost",
|
||||||
|
MinVersion: "1.2",
|
||||||
|
MaxVersion: "1.2",
|
||||||
|
ALPN: badoption.Listable[string]{"h2"},
|
||||||
|
Certificate: badoption.Listable[string]{serverCertificatePEM},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
for index := 0; index < appleTLSFailureRecoveryLoops; index++ {
|
||||||
|
failedResult, failedAddress := startAppleTLSTestServer(t, testCase.serverConfig)
|
||||||
|
failedConn, err := newAppleTestClientConn(t, failedAddress, testCase.clientOptions)
|
||||||
|
if err == nil {
|
||||||
|
_ = failedConn.Close()
|
||||||
|
t.Fatalf("iteration %d: expected handshake failure", index)
|
||||||
|
}
|
||||||
|
if result := <-failedResult; result.err == nil {
|
||||||
|
t.Fatalf("iteration %d: expected server handshake failure", index)
|
||||||
|
}
|
||||||
|
|
||||||
|
successResult, successAddress := startAppleTLSTestServer(t, &stdtls.Config{
|
||||||
|
Certificates: []stdtls.Certificate{serverCertificate},
|
||||||
|
MinVersion: stdtls.VersionTLS12,
|
||||||
|
MaxVersion: stdtls.VersionTLS12,
|
||||||
|
NextProtos: []string{"h2"},
|
||||||
|
})
|
||||||
|
successConn, err := newAppleTestClientConn(t, successAddress, successClientOptions)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("iteration %d: follow-up handshake failed: %v", index, err)
|
||||||
|
}
|
||||||
|
clientState := successConn.ConnectionState()
|
||||||
|
if clientState.NegotiatedProtocol != "h2" {
|
||||||
|
_ = successConn.Close()
|
||||||
|
t.Fatalf("iteration %d: unexpected negotiated protocol after failure: %q", index, clientState.NegotiatedProtocol)
|
||||||
|
}
|
||||||
|
_ = successConn.Close()
|
||||||
|
|
||||||
|
result := <-successResult
|
||||||
|
if result.err != nil {
|
||||||
|
t.Fatalf("iteration %d: follow-up server handshake failed: %v", index, result.err)
|
||||||
|
}
|
||||||
|
if result.state.NegotiatedProtocol != "h2" {
|
||||||
|
t.Fatalf("iteration %d: follow-up server negotiated unexpected protocol: %q", index, result.state.NegotiatedProtocol)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppleClientReadDeadline(t *testing.T) {
|
||||||
|
serverCertificate, serverCertificatePEM := newAppleTestCertificate(t, "localhost")
|
||||||
|
serverDone, serverAddress := startAppleTLSSilentServer(t, &stdtls.Config{
|
||||||
|
Certificates: []stdtls.Certificate{serverCertificate},
|
||||||
|
MinVersion: stdtls.VersionTLS12,
|
||||||
|
MaxVersion: stdtls.VersionTLS12,
|
||||||
|
})
|
||||||
|
|
||||||
|
clientConn, err := newAppleTestClientConn(t, serverAddress, option.OutboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
Engine: "apple",
|
||||||
|
ServerName: "localhost",
|
||||||
|
MinVersion: "1.2",
|
||||||
|
MaxVersion: "1.2",
|
||||||
|
Certificate: badoption.Listable[string]{serverCertificatePEM},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer clientConn.Close()
|
||||||
|
defer close(serverDone)
|
||||||
|
|
||||||
|
err = clientConn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SetReadDeadline: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
readDone := make(chan error, 1)
|
||||||
|
buffer := make([]byte, 64)
|
||||||
|
go func() {
|
||||||
|
_, readErr := clientConn.Read(buffer)
|
||||||
|
readDone <- readErr
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case readErr := <-readDone:
|
||||||
|
if !errors.Is(readErr, os.ErrDeadlineExceeded) {
|
||||||
|
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", readErr)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("Read did not return within 2s after deadline")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = clientConn.Read(buffer)
|
||||||
|
if !errors.Is(err, os.ErrDeadlineExceeded) {
|
||||||
|
t.Fatalf("sticky deadline: expected os.ErrDeadlineExceeded, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppleClientSetDeadlineClearsPreExpiredSticky(t *testing.T) {
|
||||||
|
serverCertificate, serverCertificatePEM := newAppleTestCertificate(t, "localhost")
|
||||||
|
serverDone, serverAddress := startAppleTLSSilentServer(t, &stdtls.Config{
|
||||||
|
Certificates: []stdtls.Certificate{serverCertificate},
|
||||||
|
MinVersion: stdtls.VersionTLS12,
|
||||||
|
MaxVersion: stdtls.VersionTLS12,
|
||||||
|
})
|
||||||
|
|
||||||
|
clientConn, err := newAppleTestClientConn(t, serverAddress, option.OutboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
Engine: "apple",
|
||||||
|
ServerName: "localhost",
|
||||||
|
MinVersion: "1.2",
|
||||||
|
MaxVersion: "1.2",
|
||||||
|
Certificate: badoption.Listable[string]{serverCertificatePEM},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer clientConn.Close()
|
||||||
|
defer close(serverDone)
|
||||||
|
|
||||||
|
err = clientConn.SetReadDeadline(time.Now().Add(-time.Second))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SetReadDeadline past: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pre-expired deadline trips sticky flag without cancelling nw_connection
|
||||||
|
// (prepareReadTimeout short-circuits before the C read is issued).
|
||||||
|
buffer := make([]byte, 64)
|
||||||
|
_, err = clientConn.Read(buffer)
|
||||||
|
if !errors.Is(err, os.ErrDeadlineExceeded) {
|
||||||
|
t.Fatalf("pre-expired: expected os.ErrDeadlineExceeded, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = clientConn.SetReadDeadline(time.Time{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SetReadDeadline zero: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
newDeadline := 300 * time.Millisecond
|
||||||
|
err = clientConn.SetReadDeadline(time.Now().Add(newDeadline))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SetReadDeadline future: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
readStart := time.Now()
|
||||||
|
_, err = clientConn.Read(buffer)
|
||||||
|
readElapsed := time.Since(readStart)
|
||||||
|
if !errors.Is(err, os.ErrDeadlineExceeded) {
|
||||||
|
t.Fatalf("after clear: expected os.ErrDeadlineExceeded, got %v", err)
|
||||||
|
}
|
||||||
|
if readElapsed < newDeadline-50*time.Millisecond {
|
||||||
|
t.Fatalf("sticky flag was not cleared: Read returned after %v, expected ~%v", readElapsed, newDeadline)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func startAppleTLSSilentServer(t *testing.T, tlsConfig *stdtls.Config) (chan<- struct{}, string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
listener.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
if tcpListener, isTCP := listener.(*net.TCPListener); isTCP {
|
||||||
|
err = tcpListener.SetDeadline(time.Now().Add(appleTLSTestTimeout))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
conn, acceptErr := listener.Accept()
|
||||||
|
if acceptErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
handshakeErr := conn.SetDeadline(time.Now().Add(appleTLSTestTimeout))
|
||||||
|
if handshakeErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tlsConn := stdtls.Server(conn, tlsConfig)
|
||||||
|
defer tlsConn.Close()
|
||||||
|
handshakeErr = tlsConn.Handshake()
|
||||||
|
if handshakeErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
handshakeErr = conn.SetDeadline(time.Time{})
|
||||||
|
if handshakeErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
<-done
|
||||||
|
}()
|
||||||
|
return done, listener.Addr().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAppleTestCertificate(t *testing.T, serverName string) (stdtls.Certificate, string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
privateKeyPEM, certificatePEM, err := GenerateCertificate(nil, nil, time.Now, serverName, time.Now().Add(time.Hour))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
certificate, err := stdtls.X509KeyPair(certificatePEM, privateKeyPEM)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return certificate, string(certificatePEM)
|
||||||
|
}
|
||||||
|
|
||||||
|
func startAppleTLSTestServer(t *testing.T, tlsConfig *stdtls.Config) (<-chan appleTLSServerResult, string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
listener.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
if tcpListener, isTCP := listener.(*net.TCPListener); isTCP {
|
||||||
|
err = tcpListener.SetDeadline(time.Now().Add(appleTLSTestTimeout))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make(chan appleTLSServerResult, 1)
|
||||||
|
go func() {
|
||||||
|
defer close(result)
|
||||||
|
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
result <- appleTLSServerResult{err: err}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
err = conn.SetDeadline(time.Now().Add(appleTLSTestTimeout))
|
||||||
|
if err != nil {
|
||||||
|
result <- appleTLSServerResult{err: err}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConn := stdtls.Server(conn, tlsConfig)
|
||||||
|
defer tlsConn.Close()
|
||||||
|
|
||||||
|
err = tlsConn.Handshake()
|
||||||
|
if err != nil {
|
||||||
|
result <- appleTLSServerResult{err: err}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
result <- appleTLSServerResult{state: tlsConn.ConnectionState()}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return result, listener.Addr().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAppleTestClientConn(t *testing.T, serverAddress string, options option.OutboundTLSOptions) (Conn, error) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), appleTLSTestTimeout)
|
||||||
|
t.Cleanup(cancel)
|
||||||
|
|
||||||
|
clientConfig, err := NewClientWithOptions(ClientOptions{
|
||||||
|
Context: ctx,
|
||||||
|
Logger: logger.NOP(),
|
||||||
|
ServerAddress: "",
|
||||||
|
Options: options,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := net.DialTimeout("tcp", serverAddress, appleTLSTestTimeout)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConn, err := ClientHandshake(ctx, conn, clientConfig)
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return tlsConn, nil
|
||||||
|
}
|
||||||
15
common/tls/apple_client_stub.go
Normal file
15
common/tls/apple_client_stub.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
//go:build !darwin || !cgo
|
||||||
|
|
||||||
|
package tls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
"github.com/sagernet/sing/common/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newAppleClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions, allowEmptyServerName bool) (Config, error) {
|
||||||
|
return nil, E.New("Apple TLS engine is not available on non-Apple platforms")
|
||||||
|
}
|
||||||
@@ -8,14 +8,49 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/sagernet/sing-box/common/badtls"
|
"github.com/sagernet/sing-box/common/badtls"
|
||||||
|
"github.com/sagernet/sing-box/common/tlsspoof"
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
"github.com/sagernet/sing-box/option"
|
"github.com/sagernet/sing-box/option"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
"github.com/sagernet/sing/common/logger"
|
"github.com/sagernet/sing/common/logger"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
aTLS "github.com/sagernet/sing/common/tls"
|
aTLS "github.com/sagernet/sing/common/tls"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var errMissingServerName = E.New("missing server_name or insecure=true")
|
||||||
|
|
||||||
|
func parseTLSSpoofOptions(serverName string, options option.OutboundTLSOptions) (string, tlsspoof.Method, error) {
|
||||||
|
if options.Spoof == "" {
|
||||||
|
if options.SpoofMethod != "" {
|
||||||
|
return "", 0, E.New("`spoof_method` requires `spoof`")
|
||||||
|
}
|
||||||
|
return "", 0, nil
|
||||||
|
}
|
||||||
|
if !tlsspoof.PlatformSupported {
|
||||||
|
return "", 0, E.New("`spoof` is not supported on this platform")
|
||||||
|
}
|
||||||
|
if options.DisableSNI || serverName == "" {
|
||||||
|
return "", 0, E.New("`spoof` requires TLS ClientHello with SNI")
|
||||||
|
}
|
||||||
|
method, err := tlsspoof.ParseMethod(options.SpoofMethod)
|
||||||
|
if err != nil {
|
||||||
|
return "", 0, err
|
||||||
|
}
|
||||||
|
return options.Spoof, method, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyTLSSpoof(conn net.Conn, spoof string, method tlsspoof.Method) (net.Conn, error) {
|
||||||
|
if spoof == "" {
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
spoofer, err := tlsspoof.NewSpoofer(conn, method)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return tlsspoof.NewConn(conn, spoofer, spoof), nil
|
||||||
|
}
|
||||||
|
|
||||||
func NewDialerFromOptions(ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, serverAddress string, options option.OutboundTLSOptions) (N.Dialer, error) {
|
func NewDialerFromOptions(ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, serverAddress string, options option.OutboundTLSOptions) (N.Dialer, error) {
|
||||||
if !options.Enabled {
|
if !options.Enabled {
|
||||||
return dialer, nil
|
return dialer, nil
|
||||||
@@ -46,6 +81,7 @@ type ClientOptions struct {
|
|||||||
Logger logger.ContextLogger
|
Logger logger.ContextLogger
|
||||||
ServerAddress string
|
ServerAddress string
|
||||||
Options option.OutboundTLSOptions
|
Options option.OutboundTLSOptions
|
||||||
|
AllowEmptyServerName bool
|
||||||
KTLSCompatible bool
|
KTLSCompatible bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -61,17 +97,22 @@ func NewClientWithOptions(options ClientOptions) (Config, error) {
|
|||||||
if options.Options.KernelRx {
|
if options.Options.KernelRx {
|
||||||
options.Logger.Warn("enabling kTLS RX will definitely reduce performance, please checkout https://sing-box.sagernet.org/configuration/shared/tls/#kernel_rx")
|
options.Logger.Warn("enabling kTLS RX will definitely reduce performance, please checkout https://sing-box.sagernet.org/configuration/shared/tls/#kernel_rx")
|
||||||
}
|
}
|
||||||
if options.Options.Reality != nil && options.Options.Reality.Enabled {
|
switch options.Options.Engine {
|
||||||
return NewRealityClient(options.Context, options.Logger, options.ServerAddress, options.Options)
|
case C.TLSEngineDefault, "go":
|
||||||
} else if options.Options.UTLS != nil && options.Options.UTLS.Enabled {
|
case C.TLSEngineApple:
|
||||||
return NewUTLSClient(options.Context, options.Logger, options.ServerAddress, options.Options)
|
return newAppleClient(options.Context, options.Logger, options.ServerAddress, options.Options, options.AllowEmptyServerName)
|
||||||
|
default:
|
||||||
|
return nil, E.New("unknown tls engine: ", options.Options.Engine)
|
||||||
}
|
}
|
||||||
return NewSTDClient(options.Context, options.Logger, options.ServerAddress, options.Options)
|
if options.Options.Reality != nil && options.Options.Reality.Enabled {
|
||||||
|
return newRealityClient(options.Context, options.Logger, options.ServerAddress, options.Options, options.AllowEmptyServerName)
|
||||||
|
} else if options.Options.UTLS != nil && options.Options.UTLS.Enabled {
|
||||||
|
return newUTLSClient(options.Context, options.Logger, options.ServerAddress, options.Options, options.AllowEmptyServerName)
|
||||||
|
}
|
||||||
|
return newSTDClient(options.Context, options.Logger, options.ServerAddress, options.Options, options.AllowEmptyServerName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ClientHandshake(ctx context.Context, conn net.Conn, config Config) (Conn, error) {
|
func ClientHandshake(ctx context.Context, conn net.Conn, config Config) (Conn, error) {
|
||||||
ctx, cancel := context.WithTimeout(ctx, C.TCPTimeout)
|
|
||||||
defer cancel()
|
|
||||||
tlsConn, err := aTLS.ClientHandshake(ctx, conn, config)
|
tlsConn, err := aTLS.ClientHandshake(ctx, conn, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -52,11 +52,18 @@ type RealityClientConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewRealityClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
func NewRealityClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||||
|
return newRealityClient(ctx, logger, serverAddress, options, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRealityClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions, allowEmptyServerName bool) (Config, error) {
|
||||||
if options.UTLS == nil || !options.UTLS.Enabled {
|
if options.UTLS == nil || !options.UTLS.Enabled {
|
||||||
return nil, E.New("uTLS is required by reality client")
|
return nil, E.New("uTLS is required by reality client")
|
||||||
}
|
}
|
||||||
|
if options.Spoof != "" || options.SpoofMethod != "" {
|
||||||
|
return nil, E.New("spoof is unsupported in reality")
|
||||||
|
}
|
||||||
|
|
||||||
uClient, err := NewUTLSClient(ctx, logger, serverAddress, options)
|
uClient, err := newUTLSClient(ctx, logger, serverAddress, options, allowEmptyServerName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -108,6 +115,14 @@ func (e *RealityClientConfig) SetNextProtos(nextProto []string) {
|
|||||||
e.uClient.SetNextProtos(nextProto)
|
e.uClient.SetNextProtos(nextProto)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *RealityClientConfig) HandshakeTimeout() time.Duration {
|
||||||
|
return e.uClient.HandshakeTimeout()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *RealityClientConfig) SetHandshakeTimeout(timeout time.Duration) {
|
||||||
|
e.uClient.SetHandshakeTimeout(timeout)
|
||||||
|
}
|
||||||
|
|
||||||
func (e *RealityClientConfig) STDConfig() (*STDConfig, error) {
|
func (e *RealityClientConfig) STDConfig() (*STDConfig, error) {
|
||||||
return nil, E.New("unsupported usage for reality")
|
return nil, E.New("unsupported usage for reality")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ var _ ServerConfigCompat = (*RealityServerConfig)(nil)
|
|||||||
|
|
||||||
type RealityServerConfig struct {
|
type RealityServerConfig struct {
|
||||||
config *utls.RealityConfig
|
config *utls.RealityConfig
|
||||||
|
handshakeTimeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRealityServer(ctx context.Context, logger log.ContextLogger, options option.InboundTLSOptions) (ServerConfig, error) {
|
func NewRealityServer(ctx context.Context, logger log.ContextLogger, options option.InboundTLSOptions) (ServerConfig, error) {
|
||||||
@@ -130,7 +131,16 @@ func NewRealityServer(ctx context.Context, logger log.ContextLogger, options opt
|
|||||||
if options.ECH != nil && options.ECH.Enabled {
|
if options.ECH != nil && options.ECH.Enabled {
|
||||||
return nil, E.New("Reality is conflict with ECH")
|
return nil, E.New("Reality is conflict with ECH")
|
||||||
}
|
}
|
||||||
var config ServerConfig = &RealityServerConfig{&tlsConfig}
|
var handshakeTimeout time.Duration
|
||||||
|
if options.HandshakeTimeout > 0 {
|
||||||
|
handshakeTimeout = options.HandshakeTimeout.Build()
|
||||||
|
} else {
|
||||||
|
handshakeTimeout = C.TCPTimeout
|
||||||
|
}
|
||||||
|
var config ServerConfig = &RealityServerConfig{
|
||||||
|
config: &tlsConfig,
|
||||||
|
handshakeTimeout: handshakeTimeout,
|
||||||
|
}
|
||||||
if options.KernelTx || options.KernelRx {
|
if options.KernelTx || options.KernelRx {
|
||||||
if !C.IsLinux {
|
if !C.IsLinux {
|
||||||
return nil, E.New("kTLS is only supported on Linux")
|
return nil, E.New("kTLS is only supported on Linux")
|
||||||
@@ -161,6 +171,14 @@ func (c *RealityServerConfig) SetNextProtos(nextProto []string) {
|
|||||||
c.config.NextProtos = nextProto
|
c.config.NextProtos = nextProto
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *RealityServerConfig) HandshakeTimeout() time.Duration {
|
||||||
|
return c.handshakeTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RealityServerConfig) SetHandshakeTimeout(timeout time.Duration) {
|
||||||
|
c.handshakeTimeout = timeout
|
||||||
|
}
|
||||||
|
|
||||||
func (c *RealityServerConfig) STDConfig() (*tls.Config, error) {
|
func (c *RealityServerConfig) STDConfig() (*tls.Config, error) {
|
||||||
return nil, E.New("unsupported usage for reality")
|
return nil, E.New("unsupported usage for reality")
|
||||||
}
|
}
|
||||||
@@ -192,6 +210,7 @@ func (c *RealityServerConfig) ServerHandshake(ctx context.Context, conn net.Conn
|
|||||||
func (c *RealityServerConfig) Clone() Config {
|
func (c *RealityServerConfig) Clone() Config {
|
||||||
return &RealityServerConfig{
|
return &RealityServerConfig{
|
||||||
config: c.config.Clone(),
|
config: c.config.Clone(),
|
||||||
|
handshakeTimeout: c.handshakeTimeout,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -46,8 +46,11 @@ func NewServerWithOptions(options ServerOptions) (ServerConfig, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func ServerHandshake(ctx context.Context, conn net.Conn, config ServerConfig) (Conn, error) {
|
func ServerHandshake(ctx context.Context, conn net.Conn, config ServerConfig) (Conn, error) {
|
||||||
ctx, cancel := context.WithTimeout(ctx, C.TCPTimeout)
|
if config.HandshakeTimeout() == 0 {
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
ctx, cancel = context.WithTimeout(ctx, C.TCPTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
}
|
||||||
tlsConn, err := aTLS.ServerHandshake(ctx, conn, config)
|
tlsConn, err := aTLS.ServerHandshake(ctx, conn, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
"github.com/sagernet/sing-box/common/tlsfragment"
|
"github.com/sagernet/sing-box/common/tlsfragment"
|
||||||
|
"github.com/sagernet/sing-box/common/tlsspoof"
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
"github.com/sagernet/sing-box/option"
|
"github.com/sagernet/sing-box/option"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
@@ -24,16 +25,32 @@ import (
|
|||||||
type STDClientConfig struct {
|
type STDClientConfig struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
config *tls.Config
|
config *tls.Config
|
||||||
|
serverName string
|
||||||
|
disableSNI bool
|
||||||
|
verifyServerName bool
|
||||||
|
handshakeTimeout time.Duration
|
||||||
fragment bool
|
fragment bool
|
||||||
fragmentFallbackDelay time.Duration
|
fragmentFallbackDelay time.Duration
|
||||||
recordFragment bool
|
recordFragment bool
|
||||||
|
spoof string
|
||||||
|
spoofMethod tlsspoof.Method
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *STDClientConfig) ServerName() string {
|
func (c *STDClientConfig) ServerName() string {
|
||||||
return c.config.ServerName
|
return c.serverName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *STDClientConfig) SetServerName(serverName string) {
|
func (c *STDClientConfig) SetServerName(serverName string) {
|
||||||
|
c.serverName = serverName
|
||||||
|
if c.disableSNI {
|
||||||
|
c.config.ServerName = ""
|
||||||
|
if c.verifyServerName {
|
||||||
|
c.config.VerifyConnection = verifyConnection(c.config.RootCAs, c.config.Time, serverName)
|
||||||
|
} else {
|
||||||
|
c.config.VerifyConnection = nil
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
c.config.ServerName = serverName
|
c.config.ServerName = serverName
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -45,6 +62,14 @@ func (c *STDClientConfig) SetNextProtos(nextProto []string) {
|
|||||||
c.config.NextProtos = nextProto
|
c.config.NextProtos = nextProto
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *STDClientConfig) HandshakeTimeout() time.Duration {
|
||||||
|
return c.handshakeTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *STDClientConfig) SetHandshakeTimeout(timeout time.Duration) {
|
||||||
|
c.handshakeTimeout = timeout
|
||||||
|
}
|
||||||
|
|
||||||
func (c *STDClientConfig) STDConfig() (*STDConfig, error) {
|
func (c *STDClientConfig) STDConfig() (*STDConfig, error) {
|
||||||
return c.config, nil
|
return c.config, nil
|
||||||
}
|
}
|
||||||
@@ -53,17 +78,29 @@ func (c *STDClientConfig) Client(conn net.Conn) (Conn, error) {
|
|||||||
if c.recordFragment {
|
if c.recordFragment {
|
||||||
conn = tf.NewConn(conn, c.ctx, c.fragment, c.recordFragment, c.fragmentFallbackDelay)
|
conn = tf.NewConn(conn, c.ctx, c.fragment, c.recordFragment, c.fragmentFallbackDelay)
|
||||||
}
|
}
|
||||||
|
conn, err := applyTLSSpoof(conn, c.spoof, c.spoofMethod)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return tls.Client(conn, c.config), nil
|
return tls.Client(conn, c.config), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *STDClientConfig) Clone() Config {
|
func (c *STDClientConfig) Clone() Config {
|
||||||
return &STDClientConfig{
|
cloned := &STDClientConfig{
|
||||||
ctx: c.ctx,
|
ctx: c.ctx,
|
||||||
config: c.config.Clone(),
|
config: c.config.Clone(),
|
||||||
|
serverName: c.serverName,
|
||||||
|
disableSNI: c.disableSNI,
|
||||||
|
verifyServerName: c.verifyServerName,
|
||||||
|
handshakeTimeout: c.handshakeTimeout,
|
||||||
fragment: c.fragment,
|
fragment: c.fragment,
|
||||||
fragmentFallbackDelay: c.fragmentFallbackDelay,
|
fragmentFallbackDelay: c.fragmentFallbackDelay,
|
||||||
recordFragment: c.recordFragment,
|
recordFragment: c.recordFragment,
|
||||||
|
spoof: c.spoof,
|
||||||
|
spoofMethod: c.spoofMethod,
|
||||||
}
|
}
|
||||||
|
cloned.SetServerName(cloned.serverName)
|
||||||
|
return cloned
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *STDClientConfig) ECHConfigList() []byte {
|
func (c *STDClientConfig) ECHConfigList() []byte {
|
||||||
@@ -75,41 +112,27 @@ func (c *STDClientConfig) SetECHConfigList(EncryptedClientHelloConfigList []byte
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewSTDClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
func NewSTDClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||||
|
return newSTDClient(ctx, logger, serverAddress, options, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSTDClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions, allowEmptyServerName bool) (Config, error) {
|
||||||
var serverName string
|
var serverName string
|
||||||
if options.ServerName != "" {
|
if options.ServerName != "" {
|
||||||
serverName = options.ServerName
|
serverName = options.ServerName
|
||||||
} else if serverAddress != "" {
|
} else if serverAddress != "" {
|
||||||
serverName = serverAddress
|
serverName = serverAddress
|
||||||
}
|
}
|
||||||
if serverName == "" && !options.Insecure {
|
if serverName == "" && !options.Insecure && !allowEmptyServerName {
|
||||||
return nil, E.New("missing server_name or insecure=true")
|
return nil, errMissingServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
var tlsConfig tls.Config
|
var tlsConfig tls.Config
|
||||||
tlsConfig.Time = ntp.TimeFuncFromContext(ctx)
|
tlsConfig.Time = ntp.TimeFuncFromContext(ctx)
|
||||||
tlsConfig.RootCAs = adapter.RootPoolFromContext(ctx)
|
tlsConfig.RootCAs = adapter.RootPoolFromContext(ctx)
|
||||||
if !options.DisableSNI {
|
|
||||||
tlsConfig.ServerName = serverName
|
|
||||||
}
|
|
||||||
if options.Insecure {
|
if options.Insecure {
|
||||||
tlsConfig.InsecureSkipVerify = options.Insecure
|
tlsConfig.InsecureSkipVerify = options.Insecure
|
||||||
} else if options.DisableSNI {
|
} else if options.DisableSNI {
|
||||||
tlsConfig.InsecureSkipVerify = true
|
tlsConfig.InsecureSkipVerify = true
|
||||||
tlsConfig.VerifyConnection = func(state tls.ConnectionState) error {
|
|
||||||
verifyOptions := x509.VerifyOptions{
|
|
||||||
Roots: tlsConfig.RootCAs,
|
|
||||||
DNSName: serverName,
|
|
||||||
Intermediates: x509.NewCertPool(),
|
|
||||||
}
|
|
||||||
for _, cert := range state.PeerCertificates[1:] {
|
|
||||||
verifyOptions.Intermediates.AddCert(cert)
|
|
||||||
}
|
|
||||||
if tlsConfig.Time != nil {
|
|
||||||
verifyOptions.CurrentTime = tlsConfig.Time()
|
|
||||||
}
|
|
||||||
_, err := state.PeerCertificates[0].Verify(verifyOptions)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if len(options.CertificatePublicKeySHA256) > 0 {
|
if len(options.CertificatePublicKeySHA256) > 0 {
|
||||||
if len(options.Certificate) > 0 || options.CertificatePath != "" {
|
if len(options.Certificate) > 0 || options.CertificatePath != "" {
|
||||||
@@ -117,7 +140,7 @@ func NewSTDClient(ctx context.Context, logger logger.ContextLogger, serverAddres
|
|||||||
}
|
}
|
||||||
tlsConfig.InsecureSkipVerify = true
|
tlsConfig.InsecureSkipVerify = true
|
||||||
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||||
return verifyPublicKeySHA256(options.CertificatePublicKeySHA256, rawCerts, tlsConfig.Time)
|
return VerifyPublicKeySHA256(options.CertificatePublicKeySHA256, rawCerts)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(options.ALPN) > 0 {
|
if len(options.ALPN) > 0 {
|
||||||
@@ -198,7 +221,30 @@ func NewSTDClient(ctx context.Context, logger logger.ContextLogger, serverAddres
|
|||||||
} else if len(clientCertificate) > 0 || len(clientKey) > 0 {
|
} else if len(clientCertificate) > 0 || len(clientKey) > 0 {
|
||||||
return nil, E.New("client certificate and client key must be provided together")
|
return nil, E.New("client certificate and client key must be provided together")
|
||||||
}
|
}
|
||||||
var config Config = &STDClientConfig{ctx, &tlsConfig, options.Fragment, time.Duration(options.FragmentFallbackDelay), options.RecordFragment}
|
var handshakeTimeout time.Duration
|
||||||
|
if options.HandshakeTimeout > 0 {
|
||||||
|
handshakeTimeout = options.HandshakeTimeout.Build()
|
||||||
|
} else {
|
||||||
|
handshakeTimeout = C.TCPTimeout
|
||||||
|
}
|
||||||
|
spoof, spoofMethod, err := parseTLSSpoofOptions(serverName, options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var config Config = &STDClientConfig{
|
||||||
|
ctx: ctx,
|
||||||
|
config: &tlsConfig,
|
||||||
|
serverName: serverName,
|
||||||
|
disableSNI: options.DisableSNI,
|
||||||
|
verifyServerName: options.DisableSNI && !options.Insecure,
|
||||||
|
handshakeTimeout: handshakeTimeout,
|
||||||
|
fragment: options.Fragment,
|
||||||
|
fragmentFallbackDelay: time.Duration(options.FragmentFallbackDelay),
|
||||||
|
recordFragment: options.RecordFragment,
|
||||||
|
spoof: spoof,
|
||||||
|
spoofMethod: spoofMethod,
|
||||||
|
}
|
||||||
|
config.SetServerName(serverName)
|
||||||
if options.ECH != nil && options.ECH.Enabled {
|
if options.ECH != nil && options.ECH.Enabled {
|
||||||
var err error
|
var err error
|
||||||
config, err = parseECHClientConfig(ctx, config.(ECHCapableConfig), options)
|
config, err = parseECHClientConfig(ctx, config.(ECHCapableConfig), options)
|
||||||
@@ -220,7 +266,28 @@ func NewSTDClient(ctx context.Context, logger logger.ContextLogger, serverAddres
|
|||||||
return config, nil
|
return config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func verifyPublicKeySHA256(knownHashValues [][]byte, rawCerts [][]byte, timeFunc func() time.Time) error {
|
func verifyConnection(rootCAs *x509.CertPool, timeFunc func() time.Time, serverName string) func(state tls.ConnectionState) error {
|
||||||
|
return func(state tls.ConnectionState) error {
|
||||||
|
if serverName == "" {
|
||||||
|
return errMissingServerName
|
||||||
|
}
|
||||||
|
verifyOptions := x509.VerifyOptions{
|
||||||
|
Roots: rootCAs,
|
||||||
|
DNSName: serverName,
|
||||||
|
Intermediates: x509.NewCertPool(),
|
||||||
|
}
|
||||||
|
for _, cert := range state.PeerCertificates[1:] {
|
||||||
|
verifyOptions.Intermediates.AddCert(cert)
|
||||||
|
}
|
||||||
|
if timeFunc != nil {
|
||||||
|
verifyOptions.CurrentTime = timeFunc()
|
||||||
|
}
|
||||||
|
_, err := state.PeerCertificates[0].Verify(verifyOptions)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func VerifyPublicKeySHA256(knownHashValues [][]byte, rawCerts [][]byte) error {
|
||||||
leafCertificate, err := x509.ParseCertificate(rawCerts[0])
|
leafCertificate, err := x509.ParseCertificate(rawCerts[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "failed to parse leaf certificate")
|
return E.Cause(err, "failed to parse leaf certificate")
|
||||||
|
|||||||
@@ -92,6 +92,7 @@ func getACMENextProtos(provider adapter.CertificateProvider) []string {
|
|||||||
type STDServerConfig struct {
|
type STDServerConfig struct {
|
||||||
access sync.RWMutex
|
access sync.RWMutex
|
||||||
config *tls.Config
|
config *tls.Config
|
||||||
|
handshakeTimeout time.Duration
|
||||||
logger log.Logger
|
logger log.Logger
|
||||||
certificateProvider managedCertificateProvider
|
certificateProvider managedCertificateProvider
|
||||||
acmeService adapter.SimpleLifecycle
|
acmeService adapter.SimpleLifecycle
|
||||||
@@ -139,6 +140,18 @@ func (c *STDServerConfig) SetNextProtos(nextProto []string) {
|
|||||||
c.config = config
|
c.config = config
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *STDServerConfig) HandshakeTimeout() time.Duration {
|
||||||
|
c.access.RLock()
|
||||||
|
defer c.access.RUnlock()
|
||||||
|
return c.handshakeTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *STDServerConfig) SetHandshakeTimeout(timeout time.Duration) {
|
||||||
|
c.access.Lock()
|
||||||
|
defer c.access.Unlock()
|
||||||
|
c.handshakeTimeout = timeout
|
||||||
|
}
|
||||||
|
|
||||||
func (c *STDServerConfig) hasACMEALPN() bool {
|
func (c *STDServerConfig) hasACMEALPN() bool {
|
||||||
if c.acmeService != nil {
|
if c.acmeService != nil {
|
||||||
return true
|
return true
|
||||||
@@ -166,6 +179,7 @@ func (c *STDServerConfig) Server(conn net.Conn) (Conn, error) {
|
|||||||
func (c *STDServerConfig) Clone() Config {
|
func (c *STDServerConfig) Clone() Config {
|
||||||
return &STDServerConfig{
|
return &STDServerConfig{
|
||||||
config: c.config.Clone(),
|
config: c.config.Clone(),
|
||||||
|
handshakeTimeout: c.handshakeTimeout,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -458,7 +472,7 @@ func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option.
|
|||||||
tlsConfig.ClientAuth = tls.RequestClientCert
|
tlsConfig.ClientAuth = tls.RequestClientCert
|
||||||
}
|
}
|
||||||
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||||
return verifyPublicKeySHA256(options.ClientCertificatePublicKeySHA256, rawCerts, tlsConfig.Time)
|
return VerifyPublicKeySHA256(options.ClientCertificatePublicKeySHA256, rawCerts)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return nil, E.New("missing client_certificate, client_certificate_path or client_certificate_public_key_sha256 for client authentication")
|
return nil, E.New("missing client_certificate, client_certificate_path or client_certificate_public_key_sha256 for client authentication")
|
||||||
@@ -471,8 +485,15 @@ func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option.
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
var handshakeTimeout time.Duration
|
||||||
|
if options.HandshakeTimeout > 0 {
|
||||||
|
handshakeTimeout = options.HandshakeTimeout.Build()
|
||||||
|
} else {
|
||||||
|
handshakeTimeout = C.TCPTimeout
|
||||||
|
}
|
||||||
serverConfig := &STDServerConfig{
|
serverConfig := &STDServerConfig{
|
||||||
config: tlsConfig,
|
config: tlsConfig,
|
||||||
|
handshakeTimeout: handshakeTimeout,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
certificateProvider: certificateProvider,
|
certificateProvider: certificateProvider,
|
||||||
acmeService: acmeService,
|
acmeService: acmeService,
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
"github.com/sagernet/sing-box/common/tlsfragment"
|
"github.com/sagernet/sing-box/common/tlsfragment"
|
||||||
|
"github.com/sagernet/sing-box/common/tlsspoof"
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
"github.com/sagernet/sing-box/option"
|
"github.com/sagernet/sing-box/option"
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
@@ -28,17 +29,33 @@ import (
|
|||||||
type UTLSClientConfig struct {
|
type UTLSClientConfig struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
config *utls.Config
|
config *utls.Config
|
||||||
|
serverName string
|
||||||
|
disableSNI bool
|
||||||
|
verifyServerName bool
|
||||||
|
handshakeTimeout time.Duration
|
||||||
id utls.ClientHelloID
|
id utls.ClientHelloID
|
||||||
fragment bool
|
fragment bool
|
||||||
fragmentFallbackDelay time.Duration
|
fragmentFallbackDelay time.Duration
|
||||||
recordFragment bool
|
recordFragment bool
|
||||||
|
spoof string
|
||||||
|
spoofMethod tlsspoof.Method
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *UTLSClientConfig) ServerName() string {
|
func (c *UTLSClientConfig) ServerName() string {
|
||||||
return c.config.ServerName
|
return c.serverName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *UTLSClientConfig) SetServerName(serverName string) {
|
func (c *UTLSClientConfig) SetServerName(serverName string) {
|
||||||
|
c.serverName = serverName
|
||||||
|
if c.disableSNI {
|
||||||
|
c.config.ServerName = ""
|
||||||
|
if c.verifyServerName {
|
||||||
|
c.config.InsecureServerNameToVerify = serverName
|
||||||
|
} else {
|
||||||
|
c.config.InsecureServerNameToVerify = ""
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
c.config.ServerName = serverName
|
c.config.ServerName = serverName
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -53,6 +70,14 @@ func (c *UTLSClientConfig) SetNextProtos(nextProto []string) {
|
|||||||
c.config.NextProtos = nextProto
|
c.config.NextProtos = nextProto
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *UTLSClientConfig) HandshakeTimeout() time.Duration {
|
||||||
|
return c.handshakeTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UTLSClientConfig) SetHandshakeTimeout(timeout time.Duration) {
|
||||||
|
c.handshakeTimeout = timeout
|
||||||
|
}
|
||||||
|
|
||||||
func (c *UTLSClientConfig) STDConfig() (*STDConfig, error) {
|
func (c *UTLSClientConfig) STDConfig() (*STDConfig, error) {
|
||||||
return nil, E.New("unsupported usage for uTLS")
|
return nil, E.New("unsupported usage for uTLS")
|
||||||
}
|
}
|
||||||
@@ -61,6 +86,10 @@ func (c *UTLSClientConfig) Client(conn net.Conn) (Conn, error) {
|
|||||||
if c.recordFragment {
|
if c.recordFragment {
|
||||||
conn = tf.NewConn(conn, c.ctx, c.fragment, c.recordFragment, c.fragmentFallbackDelay)
|
conn = tf.NewConn(conn, c.ctx, c.fragment, c.recordFragment, c.fragmentFallbackDelay)
|
||||||
}
|
}
|
||||||
|
conn, err := applyTLSSpoof(conn, c.spoof, c.spoofMethod)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return &utlsALPNWrapper{utlsConnWrapper{utls.UClient(conn, c.config.Clone(), c.id)}, c.config.NextProtos}, nil
|
return &utlsALPNWrapper{utlsConnWrapper{utls.UClient(conn, c.config.Clone(), c.id)}, c.config.NextProtos}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -69,9 +98,22 @@ func (c *UTLSClientConfig) SetSessionIDGenerator(generator func(clientHello []by
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *UTLSClientConfig) Clone() Config {
|
func (c *UTLSClientConfig) Clone() Config {
|
||||||
return &UTLSClientConfig{
|
cloned := &UTLSClientConfig{
|
||||||
c.ctx, c.config.Clone(), c.id, c.fragment, c.fragmentFallbackDelay, c.recordFragment,
|
ctx: c.ctx,
|
||||||
|
config: c.config.Clone(),
|
||||||
|
serverName: c.serverName,
|
||||||
|
disableSNI: c.disableSNI,
|
||||||
|
verifyServerName: c.verifyServerName,
|
||||||
|
handshakeTimeout: c.handshakeTimeout,
|
||||||
|
id: c.id,
|
||||||
|
fragment: c.fragment,
|
||||||
|
fragmentFallbackDelay: c.fragmentFallbackDelay,
|
||||||
|
recordFragment: c.recordFragment,
|
||||||
|
spoof: c.spoof,
|
||||||
|
spoofMethod: c.spoofMethod,
|
||||||
}
|
}
|
||||||
|
cloned.SetServerName(cloned.serverName)
|
||||||
|
return cloned
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *UTLSClientConfig) ECHConfigList() []byte {
|
func (c *UTLSClientConfig) ECHConfigList() []byte {
|
||||||
@@ -143,29 +185,29 @@ func (c *utlsALPNWrapper) HandshakeContext(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewUTLSClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
func NewUTLSClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||||
|
return newUTLSClient(ctx, logger, serverAddress, options, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUTLSClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions, allowEmptyServerName bool) (Config, error) {
|
||||||
var serverName string
|
var serverName string
|
||||||
if options.ServerName != "" {
|
if options.ServerName != "" {
|
||||||
serverName = options.ServerName
|
serverName = options.ServerName
|
||||||
} else if serverAddress != "" {
|
} else if serverAddress != "" {
|
||||||
serverName = serverAddress
|
serverName = serverAddress
|
||||||
}
|
}
|
||||||
if serverName == "" && !options.Insecure {
|
if serverName == "" && !options.Insecure && !allowEmptyServerName {
|
||||||
return nil, E.New("missing server_name or insecure=true")
|
return nil, errMissingServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
var tlsConfig utls.Config
|
var tlsConfig utls.Config
|
||||||
tlsConfig.Time = ntp.TimeFuncFromContext(ctx)
|
tlsConfig.Time = ntp.TimeFuncFromContext(ctx)
|
||||||
tlsConfig.RootCAs = adapter.RootPoolFromContext(ctx)
|
tlsConfig.RootCAs = adapter.RootPoolFromContext(ctx)
|
||||||
if !options.DisableSNI {
|
|
||||||
tlsConfig.ServerName = serverName
|
|
||||||
}
|
|
||||||
if options.Insecure {
|
if options.Insecure {
|
||||||
tlsConfig.InsecureSkipVerify = options.Insecure
|
tlsConfig.InsecureSkipVerify = options.Insecure
|
||||||
} else if options.DisableSNI {
|
} else if options.DisableSNI {
|
||||||
if options.Reality != nil && options.Reality.Enabled {
|
if options.Reality != nil && options.Reality.Enabled {
|
||||||
return nil, E.New("disable_sni is unsupported in reality")
|
return nil, E.New("disable_sni is unsupported in reality")
|
||||||
}
|
}
|
||||||
tlsConfig.InsecureServerNameToVerify = serverName
|
|
||||||
}
|
}
|
||||||
if len(options.CertificatePublicKeySHA256) > 0 {
|
if len(options.CertificatePublicKeySHA256) > 0 {
|
||||||
if len(options.Certificate) > 0 || options.CertificatePath != "" {
|
if len(options.Certificate) > 0 || options.CertificatePath != "" {
|
||||||
@@ -173,7 +215,7 @@ func NewUTLSClient(ctx context.Context, logger logger.ContextLogger, serverAddre
|
|||||||
}
|
}
|
||||||
tlsConfig.InsecureSkipVerify = true
|
tlsConfig.InsecureSkipVerify = true
|
||||||
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||||
return verifyPublicKeySHA256(options.CertificatePublicKeySHA256, rawCerts, tlsConfig.Time)
|
return VerifyPublicKeySHA256(options.CertificatePublicKeySHA256, rawCerts)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(options.ALPN) > 0 {
|
if len(options.ALPN) > 0 {
|
||||||
@@ -251,11 +293,35 @@ func NewUTLSClient(ctx context.Context, logger logger.ContextLogger, serverAddre
|
|||||||
} else if len(clientCertificate) > 0 || len(clientKey) > 0 {
|
} else if len(clientCertificate) > 0 || len(clientKey) > 0 {
|
||||||
return nil, E.New("client certificate and client key must be provided together")
|
return nil, E.New("client certificate and client key must be provided together")
|
||||||
}
|
}
|
||||||
|
var handshakeTimeout time.Duration
|
||||||
|
if options.HandshakeTimeout > 0 {
|
||||||
|
handshakeTimeout = options.HandshakeTimeout.Build()
|
||||||
|
} else {
|
||||||
|
handshakeTimeout = C.TCPTimeout
|
||||||
|
}
|
||||||
|
spoof, spoofMethod, err := parseTLSSpoofOptions(serverName, options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
id, err := uTLSClientHelloID(options.UTLS.Fingerprint)
|
id, err := uTLSClientHelloID(options.UTLS.Fingerprint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var config Config = &UTLSClientConfig{ctx, &tlsConfig, id, options.Fragment, time.Duration(options.FragmentFallbackDelay), options.RecordFragment}
|
var config Config = &UTLSClientConfig{
|
||||||
|
ctx: ctx,
|
||||||
|
config: &tlsConfig,
|
||||||
|
serverName: serverName,
|
||||||
|
disableSNI: options.DisableSNI,
|
||||||
|
verifyServerName: options.DisableSNI && !options.Insecure,
|
||||||
|
handshakeTimeout: handshakeTimeout,
|
||||||
|
id: id,
|
||||||
|
fragment: options.Fragment,
|
||||||
|
fragmentFallbackDelay: time.Duration(options.FragmentFallbackDelay),
|
||||||
|
recordFragment: options.RecordFragment,
|
||||||
|
spoof: spoof,
|
||||||
|
spoofMethod: spoofMethod,
|
||||||
|
}
|
||||||
|
config.SetServerName(serverName)
|
||||||
if options.ECH != nil && options.ECH.Enabled {
|
if options.ECH != nil && options.ECH.Enabled {
|
||||||
if options.Reality != nil && options.Reality.Enabled {
|
if options.Reality != nil && options.Reality.Enabled {
|
||||||
return nil, E.New("Reality is conflict with ECH")
|
return nil, E.New("Reality is conflict with ECH")
|
||||||
|
|||||||
@@ -12,10 +12,18 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func NewUTLSClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
func NewUTLSClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||||
|
return newUTLSClient(ctx, logger, serverAddress, options, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUTLSClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions, allowEmptyServerName bool) (Config, error) {
|
||||||
return nil, E.New(`uTLS is not included in this build, rebuild with -tags with_utls`)
|
return nil, E.New(`uTLS is not included in this build, rebuild with -tags with_utls`)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRealityClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
func NewRealityClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||||
|
return newRealityClient(ctx, logger, serverAddress, options, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRealityClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions, allowEmptyServerName bool) (Config, error) {
|
||||||
return nil, E.New(`uTLS, which is required by reality is not included in this build, rebuild with -tags with_utls`)
|
return nil, E.New(`uTLS, which is required by reality is not included in this build, rebuild with -tags with_utls`)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ type MyServerName struct {
|
|||||||
Index int
|
Index int
|
||||||
Length int
|
Length int
|
||||||
ServerName string
|
ServerName string
|
||||||
|
ExtensionsListLengthIndex int
|
||||||
}
|
}
|
||||||
|
|
||||||
func IndexTLSServerName(payload []byte) *MyServerName {
|
func IndexTLSServerName(payload []byte) *MyServerName {
|
||||||
@@ -41,6 +42,7 @@ func IndexTLSServerName(payload []byte) *MyServerName {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
serverName.Index += recordLayerHeaderLen
|
serverName.Index += recordLayerHeaderLen
|
||||||
|
serverName.ExtensionsListLengthIndex += recordLayerHeaderLen
|
||||||
return serverName
|
return serverName
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,6 +84,7 @@ func indexTLSServerNameFromHandshake(handshake []byte) *MyServerName {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
serverName.Index += currentIndex
|
serverName.Index += currentIndex
|
||||||
|
serverName.ExtensionsListLengthIndex = currentIndex
|
||||||
return serverName
|
return serverName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
86
common/tlsspoof/client_hello.go
Normal file
86
common/tlsspoof/client_hello.go
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
package tlsspoof
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
|
||||||
|
tf "github.com/sagernet/sing-box/common/tlsfragment"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
recordLengthOffset = 3
|
||||||
|
handshakeLengthOffset = 6
|
||||||
|
)
|
||||||
|
|
||||||
|
// server_name extension layout (RFC 6066 §3). Offsets are relative to the
|
||||||
|
// SNI host name (index returned by the parser):
|
||||||
|
//
|
||||||
|
// ... uint16 extension_type = 0x0000 (host_name - 9)
|
||||||
|
// ... uint16 extension_data_length (host_name - 7)
|
||||||
|
// ... uint16 server_name_list_length (host_name - 5)
|
||||||
|
// ... uint8 name_type = host_name (host_name - 3)
|
||||||
|
// ... uint16 host_name_length (host_name - 2)
|
||||||
|
// sni host_name (host_name)
|
||||||
|
const (
|
||||||
|
extensionDataLengthOffsetFromSNI = -7
|
||||||
|
listLengthOffsetFromSNI = -5
|
||||||
|
hostNameLengthOffsetFromSNI = -2
|
||||||
|
)
|
||||||
|
|
||||||
|
func rewriteSNI(record []byte, fakeSNI string) ([]byte, error) {
|
||||||
|
if len(fakeSNI) > 0xFFFF {
|
||||||
|
return nil, E.New("fake SNI too long: ", len(fakeSNI), " bytes")
|
||||||
|
}
|
||||||
|
serverName := tf.IndexTLSServerName(record)
|
||||||
|
if serverName == nil {
|
||||||
|
return nil, E.New("not a ClientHello with SNI")
|
||||||
|
}
|
||||||
|
|
||||||
|
delta := len(fakeSNI) - serverName.Length
|
||||||
|
out := make([]byte, len(record)+delta)
|
||||||
|
copy(out, record[:serverName.Index])
|
||||||
|
copy(out[serverName.Index:], fakeSNI)
|
||||||
|
copy(out[serverName.Index+len(fakeSNI):], record[serverName.Index+serverName.Length:])
|
||||||
|
|
||||||
|
err := patchUint16(out, recordLengthOffset, delta)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "patch record length")
|
||||||
|
}
|
||||||
|
err = patchUint24(out, handshakeLengthOffset, delta)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "patch handshake length")
|
||||||
|
}
|
||||||
|
for _, off := range []int{
|
||||||
|
serverName.ExtensionsListLengthIndex,
|
||||||
|
serverName.Index + extensionDataLengthOffsetFromSNI,
|
||||||
|
serverName.Index + listLengthOffsetFromSNI,
|
||||||
|
serverName.Index + hostNameLengthOffsetFromSNI,
|
||||||
|
} {
|
||||||
|
err = patchUint16(out, off, delta)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "patch length at offset ", off)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func patchUint16(data []byte, offset, delta int) error {
|
||||||
|
patched := int(binary.BigEndian.Uint16(data[offset:])) + delta
|
||||||
|
if patched < 0 || patched > 0xFFFF {
|
||||||
|
return E.New("uint16 out of range: ", patched)
|
||||||
|
}
|
||||||
|
binary.BigEndian.PutUint16(data[offset:], uint16(patched))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func patchUint24(data []byte, offset, delta int) error {
|
||||||
|
original := int(data[offset])<<16 | int(data[offset+1])<<8 | int(data[offset+2])
|
||||||
|
patched := original + delta
|
||||||
|
if patched < 0 || patched > 0xFFFFFF {
|
||||||
|
return E.New("uint24 out of range: ", patched)
|
||||||
|
}
|
||||||
|
data[offset] = byte(patched >> 16)
|
||||||
|
data[offset+1] = byte(patched >> 8)
|
||||||
|
data[offset+2] = byte(patched)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
79
common/tlsspoof/client_hello_test.go
Normal file
79
common/tlsspoof/client_hello_test.go
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
package tlsspoof
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/hex"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
tf "github.com/sagernet/sing-box/common/tlsfragment"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// realClientHello is a captured Chrome ClientHello for github.com,
|
||||||
|
// reused from common/tlsfragment/index_test.go.
|
||||||
|
const realClientHello = "16030105f8010005f403036e35de7389a679c54029cf452611f2211c70d9ac3897271de589ab6155f8e4ab20637d225f1ef969ad87ed78bfb9d171300bcb1703b6f314ccefb964f79b7d0961002a0a0a130213031301c02cc02bcca9c030c02fcca8c00ac009c014c013009d009c0035002fc008c012000a01000581baba00000000000f000d00000a6769746875622e636f6d00170000ff01000100000a000e000c3a3a11ec001d001700180019000b000201000010000e000c02683208687474702f312e31000500050100000000000d00160014040308040401050308050805050108060601020100120000003304ef04ed3a3a00010011ec04c0aeb2250c092a3463161cccb29d9183331a424964248579507ed23a180b0ceab2a5f5d9ce41547e497a89055471ea572867ba3a1fc3c9e45025274a20f60c6b60e62476b6afed0403af59ab83660ef4112ae20386a602010d0a5d454c0ed34c84ed4423e750213e6a2baab1bf9c4367a6007ab40a33d95220c2dcaa44f257024a5626b545db0510f4311b1a60714154909c6a61fdfca011fb2626d657aeb6070bf078508babe3b584555013e34acc56198ed4663742b3155a664a9901794c4586820a7dc162c01827291f3792e1237f801a8d1ef096013c181c4a58d2f6859ba75022d18cc4418bd4f351d5c18f83a58857d05af860c4b9ac018a5b63f17184e591532c6bc2cf2215d4a282c8a8a4f6f7aee110422c8bc9ebd3b1d609c568523aaae555db320e6c269473d87af38c256cbb9febc20aea6380c32a8916f7a373c8b1e37554e3260bf6621f6b804ee80b3c516b1d01985bf4c603b6daa9a5991de6a7a29f3a7122b8afb843a7660110fce62b43c615f5bcc2db688ba012649c0952b0a2c031e732d2b454c6b2968683cb8d244be2c9a7fa163222979eaf92722b92b862d81a3d94450c2b60c318421ebb4307c42d1f0473592a5c30e42039cc68cda9721e61aa63f49def17c15221680ed444896340133bbee67556f56b9f9d78a4df715f926a12add0cc9c862e46ea8b7316ae468282c18601b2771c9c9322f982228cf93effaacd3f80cbd12bce5fc36f56e2a3caf91e578a5fae00c9b23a8ed1a66764f4433c3628a70b8f0a6196adc60a4cb4226f07ba4c6b363fe9065563bfc1347452946386bab488686e837ab979c64f9047417fca635fe1bb4f074f256cc8af837c7b455e280426547755af90a61640169ef180aea3a77e662bb6dac1b6c3696027129b1a5edf495314e9c7f4b6110e16378ec893fa24642330a40aba1a85326101acb97c620fd8d71389e69eaed7bdb01bbe1fd428d66191150c7b2cd1ad4257391676a82ba8ce07fb2667c3b289f159003a7c7bc31d361b7b7f49a802961739d950dfcc0fa1c7abce5abdd2245101da391151490862028110465950b9e9c03d08a90998ab83267838d2e74a0593bc81f74cdf734519a05b351c0e5488c68dd810e6e9142ccc1e2f4a7f464297eb340e27acc6b9d64e12e38cce8492b3d939140b5a9e149a75597f10a23874c84323a07cdd657274378f887c85c4259b9c04cd33ba58ed630ef2a744f8e19dd34843dff331d2a6be7e2332c599289cd248a611c73d7481cd4a9bd43449a3836f14b2af18a1739e17999e4c67e85cc5bcecabb14185e5bcaff3c96098f03dc5aba819f29587758f49f940585354a2a780830528d68ccd166920dadcaa25cab5fc1907272a826aba3f08bc6b88757776812ecb6c7cec69a223ec0a13a7b62a2349a0f63ed7a27a3b15ba21d71fe6864ec6e089ae17cadd433fa3138f7ee24353c11365818f8fc34f43a05542d18efaac24bfccc1f748a0cc1a67ad379468b76fd34973dba785f5c91d618333cd810fe0700d1bbc8422029782628070a624c52c5309a4a64d625b11f8033ab28df34a1add297517fcc06b92b6817b3c5144438cf260867c57bde68c8c4b82e6a135ef676a52fbae5708002a404e6189a60e2836de565ad1b29e3819e5ed49f6810bcb28e1bd6de57306f94b79d9dae1cc4624d2a068499beef81cd5fe4b76dcbfff2a2008001d002001976128c6d5a934533f28b9914d2480aab2a8c1ab03d212529ce8b27640a716002d00020101002b000706caca03040303001b00030200015a5a000100"
|
||||||
|
|
||||||
|
func decodeClientHello(t *testing.T) []byte {
|
||||||
|
t.Helper()
|
||||||
|
payload, err := hex.DecodeString(realClientHello)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertConsistent(t *testing.T, payload []byte, expectedSNI string) {
|
||||||
|
t.Helper()
|
||||||
|
serverName := tf.IndexTLSServerName(payload)
|
||||||
|
require.NotNil(t, serverName, "parser should find SNI in rewritten payload")
|
||||||
|
require.Equal(t, expectedSNI, serverName.ServerName)
|
||||||
|
require.Equal(t, expectedSNI, string(payload[serverName.Index:serverName.Index+serverName.Length]))
|
||||||
|
// Record length must equal len(payload) - 5.
|
||||||
|
recordLen := binary.BigEndian.Uint16(payload[3:5])
|
||||||
|
require.Equal(t, len(payload)-5, int(recordLen), "record length must equal payload - 5")
|
||||||
|
// Handshake length must equal len(payload) - 5 - 4.
|
||||||
|
handshakeLen := int(payload[6])<<16 | int(payload[7])<<8 | int(payload[8])
|
||||||
|
require.Equal(t, len(payload)-5-4, handshakeLen, "handshake length must equal payload - 9")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRewriteSNI_ShorterReplacement(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
payload := decodeClientHello(t)
|
||||||
|
out, err := rewriteSNI(payload, "a.io")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, out, len(payload)-6) // original "github.com" is 10 bytes, "a.io" is 4 bytes.
|
||||||
|
assertConsistent(t, out, "a.io")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRewriteSNI_SameLengthReplacement(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
payload := decodeClientHello(t)
|
||||||
|
out, err := rewriteSNI(payload, "example.co")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, out, len(payload))
|
||||||
|
assertConsistent(t, out, "example.co")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRewriteSNI_LongerReplacement(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
payload := decodeClientHello(t)
|
||||||
|
out, err := rewriteSNI(payload, "letsencrypt.org")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, out, len(payload)+5) // "letsencrypt.org" is 15, original 10, delta 5.
|
||||||
|
assertConsistent(t, out, "letsencrypt.org")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRewriteSNI_NoSNIReturnsError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// Truncated payload — not a valid ClientHello.
|
||||||
|
_, err := rewriteSNI([]byte{0x16, 0x03, 0x01, 0x00, 0x01, 0x01}, "x.com")
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRewriteSNI_DoesNotMutateInput(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
payload := decodeClientHello(t)
|
||||||
|
original := append([]byte(nil), payload...)
|
||||||
|
_, err := rewriteSNI(payload, "letsencrypt.org")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, original, payload, "input payload must not be mutated")
|
||||||
|
}
|
||||||
126
common/tlsspoof/conn_test.go
Normal file
126
common/tlsspoof/conn_test.go
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
package tlsspoof
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
tf "github.com/sagernet/sing-box/common/tlsfragment"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fakeSpoofer struct {
|
||||||
|
injected [][]byte
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeSpoofer) Inject(payload []byte) error {
|
||||||
|
if f.err != nil {
|
||||||
|
return f.err
|
||||||
|
}
|
||||||
|
f.injected = append(f.injected, append([]byte(nil), payload...))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeSpoofer) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readAll(t *testing.T, conn net.Conn) []byte {
|
||||||
|
t.Helper()
|
||||||
|
data, err := io.ReadAll(conn)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConn_Write_InjectsThenForwards(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
payload, err := hex.DecodeString(realClientHello)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
client, server := net.Pipe()
|
||||||
|
spoofer := &fakeSpoofer{}
|
||||||
|
wrapped := NewConn(client, spoofer, "letsencrypt.org")
|
||||||
|
|
||||||
|
serverRead := make(chan []byte, 1)
|
||||||
|
go func() {
|
||||||
|
serverRead <- readAll(t, server)
|
||||||
|
}()
|
||||||
|
|
||||||
|
n, err := wrapped.Write(payload)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, len(payload), n)
|
||||||
|
require.NoError(t, wrapped.Close())
|
||||||
|
|
||||||
|
forwarded := <-serverRead
|
||||||
|
require.Equal(t, payload, forwarded, "underlying conn must receive the real ClientHello unchanged")
|
||||||
|
require.Len(t, spoofer.injected, 1)
|
||||||
|
|
||||||
|
injected := spoofer.injected[0]
|
||||||
|
serverName := tf.IndexTLSServerName(injected)
|
||||||
|
require.NotNil(t, serverName, "injected payload must parse as ClientHello")
|
||||||
|
require.Equal(t, "letsencrypt.org", serverName.ServerName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConn_Write_SecondWriteDoesNotInject(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
payload, err := hex.DecodeString(realClientHello)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
client, server := net.Pipe()
|
||||||
|
spoofer := &fakeSpoofer{}
|
||||||
|
wrapped := NewConn(client, spoofer, "letsencrypt.org")
|
||||||
|
|
||||||
|
serverRead := make(chan []byte, 1)
|
||||||
|
go func() {
|
||||||
|
serverRead <- readAll(t, server)
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err = wrapped.Write(payload)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = wrapped.Write([]byte("second"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, wrapped.Close())
|
||||||
|
|
||||||
|
forwarded := <-serverRead
|
||||||
|
require.Equal(t, append(append([]byte(nil), payload...), []byte("second")...), forwarded)
|
||||||
|
require.Len(t, spoofer.injected, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConn_Write_NonClientHelloReturnsError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
client, server := net.Pipe()
|
||||||
|
defer client.Close()
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
spoofer := &fakeSpoofer{}
|
||||||
|
wrapped := NewConn(client, spoofer, "letsencrypt.org")
|
||||||
|
|
||||||
|
_, err := wrapped.Write([]byte("not a ClientHello"))
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Empty(t, spoofer.injected)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseMethod(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
cases := map[string]struct {
|
||||||
|
want Method
|
||||||
|
ok bool
|
||||||
|
}{
|
||||||
|
"": {MethodWrongSequence, true},
|
||||||
|
"wrong-sequence": {MethodWrongSequence, true},
|
||||||
|
"wrong-checksum": {MethodWrongChecksum, true},
|
||||||
|
"nonsense": {0, false},
|
||||||
|
}
|
||||||
|
for input, expected := range cases {
|
||||||
|
m, err := ParseMethod(input)
|
||||||
|
if !expected.ok {
|
||||||
|
require.Error(t, err, "input=%q", input)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
require.NoError(t, err, "input=%q", input)
|
||||||
|
require.Equal(t, expected.want, m, "input=%q", input)
|
||||||
|
}
|
||||||
|
}
|
||||||
29
common/tlsspoof/endpoints.go
Normal file
29
common/tlsspoof/endpoints.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package tlsspoof
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
)
|
||||||
|
|
||||||
|
// The returned addresses are v4-unmapped and share the same family.
|
||||||
|
func tcpEndpoints(conn net.Conn) (*net.TCPConn, netip.AddrPort, netip.AddrPort, error) {
|
||||||
|
tcpConn, isTCP := common.Cast[*net.TCPConn](conn)
|
||||||
|
if !isTCP {
|
||||||
|
return nil, netip.AddrPort{}, netip.AddrPort{}, E.New("tls_spoof: underlying conn is not *net.TCPConn")
|
||||||
|
}
|
||||||
|
local := M.AddrPortFromNet(tcpConn.LocalAddr())
|
||||||
|
remote := M.AddrPortFromNet(tcpConn.RemoteAddr())
|
||||||
|
if !local.IsValid() || !remote.IsValid() {
|
||||||
|
return nil, netip.AddrPort{}, netip.AddrPort{}, E.New("tls_spoof: invalid conn address")
|
||||||
|
}
|
||||||
|
local = netip.AddrPortFrom(local.Addr().Unmap(), local.Port())
|
||||||
|
remote = netip.AddrPortFrom(remote.Addr().Unmap(), remote.Port())
|
||||||
|
if local.Addr().Is4() != remote.Addr().Is4() {
|
||||||
|
return nil, netip.AddrPort{}, netip.AddrPort{}, E.New("tls_spoof: local/remote address family mismatch")
|
||||||
|
}
|
||||||
|
return tcpConn, local, remote, nil
|
||||||
|
}
|
||||||
5
common/tlsspoof/integration_darwin_test.go
Normal file
5
common/tlsspoof/integration_darwin_test.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
//go:build darwin
|
||||||
|
|
||||||
|
package tlsspoof
|
||||||
|
|
||||||
|
const loopbackInterface = "lo0"
|
||||||
5
common/tlsspoof/integration_linux_test.go
Normal file
5
common/tlsspoof/integration_linux_test.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
//go:build linux
|
||||||
|
|
||||||
|
package tlsspoof
|
||||||
|
|
||||||
|
const loopbackInterface = "lo"
|
||||||
112
common/tlsspoof/integration_test.go
Normal file
112
common/tlsspoof/integration_test.go
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
//go:build linux || darwin
|
||||||
|
|
||||||
|
package tlsspoof
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func requireRoot(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
if os.Geteuid() != 0 {
|
||||||
|
t.Fatal("integration test requires root")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func tcpdumpObserver(t *testing.T, iface string, port uint16, needle string, do func(), wait time.Duration) bool {
|
||||||
|
t.Helper()
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), wait)
|
||||||
|
defer cancel()
|
||||||
|
cmd := exec.CommandContext(ctx, "tcpdump", "-i", iface, "-n", "-A", "-l",
|
||||||
|
"-s", "4096", fmt.Sprintf("tcp and port %d", port))
|
||||||
|
cmd.Cancel = func() error {
|
||||||
|
return cmd.Process.Signal(os.Interrupt)
|
||||||
|
}
|
||||||
|
stdout, err := cmd.StdoutPipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
stderr, err := cmd.StderrPipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, cmd.Start())
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = cmd.Process.Signal(os.Interrupt)
|
||||||
|
_ = cmd.Wait()
|
||||||
|
})
|
||||||
|
|
||||||
|
ready := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
scanner := bufio.NewScanner(stderr)
|
||||||
|
for scanner.Scan() {
|
||||||
|
if strings.Contains(scanner.Text(), "listening on") {
|
||||||
|
close(ready)
|
||||||
|
io.Copy(io.Discard, stderr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ready:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("tcpdump did not attach within 2s")
|
||||||
|
}
|
||||||
|
|
||||||
|
var found atomic.Bool
|
||||||
|
readerDone := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(readerDone)
|
||||||
|
scanner := bufio.NewScanner(stdout)
|
||||||
|
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||||
|
for scanner.Scan() {
|
||||||
|
if strings.Contains(scanner.Text(), needle) {
|
||||||
|
found.Store(true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
do()
|
||||||
|
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
_ = cmd.Process.Signal(os.Interrupt)
|
||||||
|
<-readerDone
|
||||||
|
return found.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func dialLocalEchoServer(t *testing.T) (client net.Conn, serverPort uint16) {
|
||||||
|
t.Helper()
|
||||||
|
listener, err := net.Listen("tcp4", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accepted := make(chan net.Conn, 1)
|
||||||
|
go func() {
|
||||||
|
c, err := listener.Accept()
|
||||||
|
if err == nil {
|
||||||
|
accepted <- c
|
||||||
|
}
|
||||||
|
close(accepted)
|
||||||
|
}()
|
||||||
|
addr := listener.Addr().(*net.TCPAddr)
|
||||||
|
client, err = net.Dial("tcp4", addr.String())
|
||||||
|
require.NoError(t, err)
|
||||||
|
server := <-accepted
|
||||||
|
require.NotNil(t, server)
|
||||||
|
|
||||||
|
go io.Copy(io.Discard, server)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
client.Close()
|
||||||
|
server.Close()
|
||||||
|
listener.Close()
|
||||||
|
})
|
||||||
|
return client, uint16(addr.Port)
|
||||||
|
}
|
||||||
100
common/tlsspoof/integration_unix_test.go
Normal file
100
common/tlsspoof/integration_unix_test.go
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
//go:build linux || darwin
|
||||||
|
|
||||||
|
package tlsspoof
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIntegrationSpoofer_WrongChecksum(t *testing.T) {
|
||||||
|
requireRoot(t)
|
||||||
|
client, serverPort := dialLocalEchoServer(t)
|
||||||
|
spoofer, err := NewSpoofer(client, MethodWrongChecksum)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer spoofer.Close()
|
||||||
|
|
||||||
|
payload, err := hex.DecodeString(realClientHello)
|
||||||
|
require.NoError(t, err)
|
||||||
|
fake, err := rewriteSNI(payload, "letsencrypt.org")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
captured := tcpdumpObserver(t, loopbackInterface, serverPort, "letsencrypt.org", func() {
|
||||||
|
require.NoError(t, spoofer.Inject(fake))
|
||||||
|
}, 3*time.Second)
|
||||||
|
require.True(t, captured, "injected fake ClientHello must be observable on loopback")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegrationSpoofer_WrongSequence(t *testing.T) {
|
||||||
|
requireRoot(t)
|
||||||
|
client, serverPort := dialLocalEchoServer(t)
|
||||||
|
spoofer, err := NewSpoofer(client, MethodWrongSequence)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer spoofer.Close()
|
||||||
|
|
||||||
|
payload, err := hex.DecodeString(realClientHello)
|
||||||
|
require.NoError(t, err)
|
||||||
|
fake, err := rewriteSNI(payload, "letsencrypt.org")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
captured := tcpdumpObserver(t, loopbackInterface, serverPort, "letsencrypt.org", func() {
|
||||||
|
require.NoError(t, spoofer.Inject(fake))
|
||||||
|
}, 3*time.Second)
|
||||||
|
require.True(t, captured, "injected fake ClientHello must be observable on loopback")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Loopback bypasses TCP checksum validation, so wrong-sequence is used instead.
|
||||||
|
func TestIntegrationConn_InjectsThenForwardsRealCH(t *testing.T) {
|
||||||
|
requireRoot(t)
|
||||||
|
|
||||||
|
listener, err := net.Listen("tcp4", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
serverReceived := make(chan []byte, 1)
|
||||||
|
go func() {
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
got, _ := io.ReadAll(conn)
|
||||||
|
serverReceived <- got
|
||||||
|
}()
|
||||||
|
|
||||||
|
addr := listener.Addr().(*net.TCPAddr)
|
||||||
|
serverPort := uint16(addr.Port)
|
||||||
|
client, err := net.Dial("tcp4", addr.String())
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
client.Close()
|
||||||
|
listener.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
spoofer, err := NewSpoofer(client, MethodWrongSequence)
|
||||||
|
require.NoError(t, err)
|
||||||
|
wrapped := NewConn(client, spoofer, "letsencrypt.org")
|
||||||
|
|
||||||
|
payload, err := hex.DecodeString(realClientHello)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
captured := tcpdumpObserver(t, loopbackInterface, serverPort, "letsencrypt.org", func() {
|
||||||
|
n, err := wrapped.Write(payload)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, len(payload), n)
|
||||||
|
}, 3*time.Second)
|
||||||
|
require.True(t, captured, "fake ClientHello with letsencrypt.org SNI must be on the wire")
|
||||||
|
|
||||||
|
_ = wrapped.Close()
|
||||||
|
select {
|
||||||
|
case got := <-serverReceived:
|
||||||
|
require.Equal(t, payload, got, "server must receive real ClientHello unchanged (wrong-sequence fake must be dropped)")
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("echo server did not receive real ClientHello")
|
||||||
|
}
|
||||||
|
}
|
||||||
139
common/tlsspoof/integration_windows_test.go
Normal file
139
common/tlsspoof/integration_windows_test.go
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
//go:build windows && (amd64 || 386)
|
||||||
|
|
||||||
|
package tlsspoof
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newSpoofer(t *testing.T, conn net.Conn, method Method) Spoofer {
|
||||||
|
t.Helper()
|
||||||
|
spoofer, err := NewSpoofer(conn, method)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return spoofer
|
||||||
|
}
|
||||||
|
|
||||||
|
// Basic lifecycle: opening a spoofer against a live TCP conn installs
|
||||||
|
// the driver, spawns run(), then shuts down cleanly without ever
|
||||||
|
// injecting. Exercises the close path that cancels an in-flight Recv.
|
||||||
|
func TestIntegrationSpooferOpenClose(t *testing.T) {
|
||||||
|
listener, err := net.Listen("tcp4", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() { listener.Close() })
|
||||||
|
|
||||||
|
accepted := make(chan net.Conn, 1)
|
||||||
|
go func() {
|
||||||
|
c, _ := listener.Accept()
|
||||||
|
accepted <- c
|
||||||
|
}()
|
||||||
|
client, err := net.Dial("tcp4", listener.Addr().String())
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() { client.Close() })
|
||||||
|
server := <-accepted
|
||||||
|
t.Cleanup(func() {
|
||||||
|
if server != nil {
|
||||||
|
server.Close()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
spoofer := newSpoofer(t, client, MethodWrongSequence)
|
||||||
|
require.NoError(t, spoofer.Close())
|
||||||
|
}
|
||||||
|
|
||||||
|
// End-to-end: Conn.Write injects a fake ClientHello with a rewritten
|
||||||
|
// SNI, then forwards the real ClientHello. With wrong-sequence, the
|
||||||
|
// fake lands before the connection's send-next sequence — the peer TCP
|
||||||
|
// stack treats it as already-received and only surfaces the real bytes
|
||||||
|
// to the echo server.
|
||||||
|
func TestIntegrationConnInjectsThenForwardsRealCH(t *testing.T) {
|
||||||
|
listener, err := net.Listen("tcp4", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() { listener.Close() })
|
||||||
|
|
||||||
|
serverReceived := make(chan []byte, 1)
|
||||||
|
go func() {
|
||||||
|
conn, acceptErr := listener.Accept()
|
||||||
|
if acceptErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||||
|
got, _ := io.ReadAll(conn)
|
||||||
|
serverReceived <- got
|
||||||
|
}()
|
||||||
|
|
||||||
|
client, err := net.Dial("tcp4", listener.Addr().String())
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() { client.Close() })
|
||||||
|
|
||||||
|
spoofer := newSpoofer(t, client, MethodWrongSequence)
|
||||||
|
wrapped := NewConn(client, spoofer, "letsencrypt.org")
|
||||||
|
|
||||||
|
payload, err := hex.DecodeString(realClientHello)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
n, err := wrapped.Write(payload)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, len(payload), n)
|
||||||
|
_ = wrapped.Close()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case got := <-serverReceived:
|
||||||
|
require.Equal(t, payload, got,
|
||||||
|
"server must receive real ClientHello unchanged (wrong-sequence fake must be dropped)")
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("echo server did not receive real ClientHello within 5s")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inject before any kernel payload: stages the fake, then Write flushes
|
||||||
|
// the real CH. Same terminal expectation as the Conn variant but via the
|
||||||
|
// Spoofer primitive directly.
|
||||||
|
func TestIntegrationSpooferInjectThenWrite(t *testing.T) {
|
||||||
|
listener, err := net.Listen("tcp4", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() { listener.Close() })
|
||||||
|
|
||||||
|
serverReceived := make(chan []byte, 1)
|
||||||
|
go func() {
|
||||||
|
conn, acceptErr := listener.Accept()
|
||||||
|
if acceptErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||||
|
got, _ := io.ReadAll(conn)
|
||||||
|
serverReceived <- got
|
||||||
|
}()
|
||||||
|
|
||||||
|
client, err := net.Dial("tcp4", listener.Addr().String())
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() { client.Close() })
|
||||||
|
|
||||||
|
spoofer := newSpoofer(t, client, MethodWrongSequence)
|
||||||
|
t.Cleanup(func() { spoofer.Close() })
|
||||||
|
|
||||||
|
payload, err := hex.DecodeString(realClientHello)
|
||||||
|
require.NoError(t, err)
|
||||||
|
fake, err := rewriteSNI(payload, "letsencrypt.org")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, spoofer.Inject(fake))
|
||||||
|
|
||||||
|
n, err := client.Write(payload)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, len(payload), n)
|
||||||
|
_ = client.Close()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case got := <-serverReceived:
|
||||||
|
require.Equal(t, payload, got)
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("echo server did not receive real ClientHello within 5s")
|
||||||
|
}
|
||||||
|
}
|
||||||
100
common/tlsspoof/packet.go
Normal file
100
common/tlsspoof/packet.go
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
package tlsspoof
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-tun/gtcpip/checksum"
|
||||||
|
"github.com/sagernet/sing-tun/gtcpip/header"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultTTL uint8 = 64
|
||||||
|
defaultWindowSize uint16 = 0xFFFF
|
||||||
|
tcpHeaderLen = header.TCPMinimumSize
|
||||||
|
)
|
||||||
|
|
||||||
|
func buildTCPSegment(
|
||||||
|
src netip.AddrPort,
|
||||||
|
dst netip.AddrPort,
|
||||||
|
seqNum uint32,
|
||||||
|
ackNum uint32,
|
||||||
|
payload []byte,
|
||||||
|
corruptChecksum bool,
|
||||||
|
) []byte {
|
||||||
|
if src.Addr().Is4() != dst.Addr().Is4() {
|
||||||
|
panic("tlsspoof: mixed IPv4/IPv6 address family")
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
frame []byte
|
||||||
|
ipHeaderLen int
|
||||||
|
)
|
||||||
|
if src.Addr().Is4() {
|
||||||
|
ipHeaderLen = header.IPv4MinimumSize
|
||||||
|
frame = make([]byte, ipHeaderLen+tcpHeaderLen+len(payload))
|
||||||
|
ip := header.IPv4(frame[:ipHeaderLen])
|
||||||
|
ip.Encode(&header.IPv4Fields{
|
||||||
|
TotalLength: uint16(len(frame)),
|
||||||
|
ID: 0,
|
||||||
|
TTL: defaultTTL,
|
||||||
|
Protocol: uint8(header.TCPProtocolNumber),
|
||||||
|
SrcAddr: src.Addr(),
|
||||||
|
DstAddr: dst.Addr(),
|
||||||
|
})
|
||||||
|
ip.SetChecksum(^ip.CalculateChecksum())
|
||||||
|
} else {
|
||||||
|
ipHeaderLen = header.IPv6MinimumSize
|
||||||
|
frame = make([]byte, ipHeaderLen+tcpHeaderLen+len(payload))
|
||||||
|
ip := header.IPv6(frame[:ipHeaderLen])
|
||||||
|
ip.Encode(&header.IPv6Fields{
|
||||||
|
PayloadLength: uint16(tcpHeaderLen + len(payload)),
|
||||||
|
TransportProtocol: header.TCPProtocolNumber,
|
||||||
|
HopLimit: defaultTTL,
|
||||||
|
SrcAddr: src.Addr(),
|
||||||
|
DstAddr: dst.Addr(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
encodeTCP(frame, ipHeaderLen, src, dst, seqNum, ackNum, payload, corruptChecksum)
|
||||||
|
return frame
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeTCP(frame []byte, ipHeaderLen int, src, dst netip.AddrPort, seqNum, ackNum uint32, payload []byte, corruptChecksum bool) {
|
||||||
|
tcp := header.TCP(frame[ipHeaderLen:])
|
||||||
|
copy(frame[ipHeaderLen+tcpHeaderLen:], payload)
|
||||||
|
tcp.Encode(&header.TCPFields{
|
||||||
|
SrcPort: src.Port(),
|
||||||
|
DstPort: dst.Port(),
|
||||||
|
SeqNum: seqNum,
|
||||||
|
AckNum: ackNum,
|
||||||
|
DataOffset: tcpHeaderLen,
|
||||||
|
Flags: header.TCPFlagAck | header.TCPFlagPsh,
|
||||||
|
WindowSize: defaultWindowSize,
|
||||||
|
})
|
||||||
|
applyTCPChecksum(tcp, src.Addr(), dst.Addr(), payload, corruptChecksum)
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildSpoofFrame(method Method, src, dst netip.AddrPort, sendNext, receiveNext uint32, payload []byte) ([]byte, error) {
|
||||||
|
var sequence uint32
|
||||||
|
corrupt := false
|
||||||
|
switch method {
|
||||||
|
case MethodWrongSequence:
|
||||||
|
sequence = sendNext - uint32(len(payload))
|
||||||
|
case MethodWrongChecksum:
|
||||||
|
sequence = sendNext
|
||||||
|
corrupt = true
|
||||||
|
default:
|
||||||
|
return nil, E.New("tls_spoof: unknown method ", method)
|
||||||
|
}
|
||||||
|
return buildTCPSegment(src, dst, sequence, receiveNext, payload, corrupt), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyTCPChecksum(tcp header.TCP, srcAddr, dstAddr netip.Addr, payload []byte, corrupt bool) {
|
||||||
|
tcpLen := tcpHeaderLen + len(payload)
|
||||||
|
pseudo := header.PseudoHeaderChecksum(header.TCPProtocolNumber, srcAddr.AsSlice(), dstAddr.AsSlice(), uint16(tcpLen))
|
||||||
|
payloadChecksum := checksum.Checksum(payload, 0)
|
||||||
|
tcpChecksum := ^tcp.CalculateChecksum(checksum.Combine(pseudo, payloadChecksum))
|
||||||
|
if corrupt {
|
||||||
|
tcpChecksum ^= 0xFFFF
|
||||||
|
}
|
||||||
|
tcp.SetChecksum(tcpChecksum)
|
||||||
|
}
|
||||||
77
common/tlsspoof/packet_test.go
Normal file
77
common/tlsspoof/packet_test.go
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
package tlsspoof
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-tun/gtcpip"
|
||||||
|
"github.com/sagernet/sing-tun/gtcpip/checksum"
|
||||||
|
"github.com/sagernet/sing-tun/gtcpip/header"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildTCPSegment_IPv4_ValidChecksum(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
src := netip.MustParseAddrPort("10.0.0.1:54321")
|
||||||
|
dst := netip.MustParseAddrPort("1.2.3.4:443")
|
||||||
|
payload := []byte("fake-client-hello")
|
||||||
|
frame := buildTCPSegment(src, dst, 100_000, 200_000, payload, false)
|
||||||
|
|
||||||
|
ip := header.IPv4(frame[:header.IPv4MinimumSize])
|
||||||
|
require.True(t, ip.IsChecksumValid())
|
||||||
|
|
||||||
|
tcp := header.TCP(frame[header.IPv4MinimumSize:])
|
||||||
|
payloadChecksum := checksum.Checksum(payload, 0)
|
||||||
|
require.True(t, tcp.IsChecksumValid(
|
||||||
|
tcpip.AddrFrom4(src.Addr().As4()),
|
||||||
|
tcpip.AddrFrom4(dst.Addr().As4()),
|
||||||
|
payloadChecksum,
|
||||||
|
uint16(len(payload)),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildTCPSegment_IPv4_CorruptChecksum(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
src := netip.MustParseAddrPort("10.0.0.1:54321")
|
||||||
|
dst := netip.MustParseAddrPort("1.2.3.4:443")
|
||||||
|
payload := []byte("fake-client-hello")
|
||||||
|
frame := buildTCPSegment(src, dst, 100_000, 200_000, payload, true)
|
||||||
|
|
||||||
|
tcp := header.TCP(frame[header.IPv4MinimumSize:])
|
||||||
|
payloadChecksum := checksum.Checksum(payload, 0)
|
||||||
|
require.False(t, tcp.IsChecksumValid(
|
||||||
|
tcpip.AddrFrom4(src.Addr().As4()),
|
||||||
|
tcpip.AddrFrom4(dst.Addr().As4()),
|
||||||
|
payloadChecksum,
|
||||||
|
uint16(len(payload)),
|
||||||
|
))
|
||||||
|
// IP checksum must still be valid so the router forwards the packet.
|
||||||
|
require.True(t, header.IPv4(frame[:header.IPv4MinimumSize]).IsChecksumValid())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildTCPSegment_IPv6_ValidChecksum(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
src := netip.MustParseAddrPort("[fe80::1]:54321")
|
||||||
|
dst := netip.MustParseAddrPort("[2606:4700::1]:443")
|
||||||
|
payload := []byte("fake-client-hello")
|
||||||
|
frame := buildTCPSegment(src, dst, 0xDEADBEEF, 0x12345678, payload, false)
|
||||||
|
|
||||||
|
tcp := header.TCP(frame[header.IPv6MinimumSize:])
|
||||||
|
payloadChecksum := checksum.Checksum(payload, 0)
|
||||||
|
require.True(t, tcp.IsChecksumValid(
|
||||||
|
tcpip.AddrFrom16(src.Addr().As16()),
|
||||||
|
tcpip.AddrFrom16(dst.Addr().As16()),
|
||||||
|
payloadChecksum,
|
||||||
|
uint16(len(payload)),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildTCPSegment_MixedFamilyPanics(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
src := netip.MustParseAddrPort("10.0.0.1:54321")
|
||||||
|
dst := netip.MustParseAddrPort("[2606:4700::1]:443")
|
||||||
|
require.Panics(t, func() {
|
||||||
|
buildTCPSegment(src, dst, 0, 0, nil, false)
|
||||||
|
})
|
||||||
|
}
|
||||||
161
common/tlsspoof/raw_darwin.go
Normal file
161
common/tlsspoof/raw_darwin.go
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
package tlsspoof
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
const PlatformSupported = true
|
||||||
|
|
||||||
|
// Offsets into xinpcb_n within each net.inet.tcp.pcblist_n record, identical
|
||||||
|
// to the values used by common/process/searcher_darwin_shared.go.
|
||||||
|
const (
|
||||||
|
darwinXinpgenSize = 24
|
||||||
|
darwinXsocketOffset = 104
|
||||||
|
darwinXinpcbForeignPort = 16
|
||||||
|
darwinXinpcbLocalPort = 18
|
||||||
|
darwinXinpcbVFlag = 44
|
||||||
|
darwinXinpcbForeignAddr = 48
|
||||||
|
darwinXinpcbLocalAddr = 64
|
||||||
|
darwinXinpcbIPv4Offset = 12
|
||||||
|
|
||||||
|
darwinTCPExtraSize = 208
|
||||||
|
|
||||||
|
darwinXtcpcbSndNxtOffset = 56
|
||||||
|
darwinXtcpcbRcvNxtOffset = 80
|
||||||
|
)
|
||||||
|
|
||||||
|
var darwinStructSize = sync.OnceValue(func() int {
|
||||||
|
value, _ := syscall.Sysctl("kern.osrelease")
|
||||||
|
major, _, _ := strings.Cut(value, ".")
|
||||||
|
n, _ := strconv.ParseInt(major, 10, 64)
|
||||||
|
if n >= 22 {
|
||||||
|
return 408
|
||||||
|
}
|
||||||
|
return 384
|
||||||
|
})
|
||||||
|
|
||||||
|
type darwinSpoofer struct {
|
||||||
|
method Method
|
||||||
|
src netip.AddrPort
|
||||||
|
dst netip.AddrPort
|
||||||
|
rawFD int
|
||||||
|
rawSockAddr unix.Sockaddr
|
||||||
|
sendNext uint32
|
||||||
|
receiveNext uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRawSpoofer(conn net.Conn, method Method) (Spoofer, error) {
|
||||||
|
_, src, dst, err := tcpEndpoints(conn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
fd, sockaddr, err := openDarwinRawSocket(dst)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
sendNext, receiveNext, err := readDarwinTCPSequence(src, dst)
|
||||||
|
if err != nil {
|
||||||
|
unix.Close(fd)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &darwinSpoofer{
|
||||||
|
method: method,
|
||||||
|
src: src,
|
||||||
|
dst: dst,
|
||||||
|
rawFD: fd,
|
||||||
|
rawSockAddr: sockaddr,
|
||||||
|
sendNext: sendNext,
|
||||||
|
receiveNext: receiveNext,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// readDarwinTCPSequence scans net.inet.tcp.pcblist_n for the PCB that matches
|
||||||
|
// src -> dst and returns (snd_nxt, rcv_nxt). These live in xtcpcb_n at the end
|
||||||
|
// of each record; see darwin-xnu bsd/netinet/in_pcblist.c:get_pcblist_n.
|
||||||
|
func readDarwinTCPSequence(src, dst netip.AddrPort) (uint32, uint32, error) {
|
||||||
|
buffer, err := unix.SysctlRaw("net.inet.tcp.pcblist_n")
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, E.Cause(err, "sysctl net.inet.tcp.pcblist_n")
|
||||||
|
}
|
||||||
|
structSize := darwinStructSize()
|
||||||
|
itemSize := structSize + darwinTCPExtraSize
|
||||||
|
for i := darwinXinpgenSize; i+itemSize <= len(buffer); i += itemSize {
|
||||||
|
inpcb := buffer[i : i+darwinXsocketOffset]
|
||||||
|
xtcpcb := buffer[i+structSize : i+itemSize]
|
||||||
|
localPort := binary.BigEndian.Uint16(inpcb[darwinXinpcbLocalPort : darwinXinpcbLocalPort+2])
|
||||||
|
remotePort := binary.BigEndian.Uint16(inpcb[darwinXinpcbForeignPort : darwinXinpcbForeignPort+2])
|
||||||
|
if localPort != src.Port() || remotePort != dst.Port() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
versionFlag := inpcb[darwinXinpcbVFlag]
|
||||||
|
var localAddr, remoteAddr netip.Addr
|
||||||
|
switch {
|
||||||
|
case versionFlag&0x1 != 0:
|
||||||
|
localAddr = netip.AddrFrom4([4]byte(inpcb[darwinXinpcbLocalAddr+darwinXinpcbIPv4Offset : darwinXinpcbLocalAddr+darwinXinpcbIPv4Offset+4]))
|
||||||
|
remoteAddr = netip.AddrFrom4([4]byte(inpcb[darwinXinpcbForeignAddr+darwinXinpcbIPv4Offset : darwinXinpcbForeignAddr+darwinXinpcbIPv4Offset+4]))
|
||||||
|
case versionFlag&0x2 != 0:
|
||||||
|
localAddr = netip.AddrFrom16([16]byte(inpcb[darwinXinpcbLocalAddr : darwinXinpcbLocalAddr+16]))
|
||||||
|
remoteAddr = netip.AddrFrom16([16]byte(inpcb[darwinXinpcbForeignAddr : darwinXinpcbForeignAddr+16]))
|
||||||
|
default:
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if localAddr.Unmap() != src.Addr() || remoteAddr.Unmap() != dst.Addr() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
sendNext := binary.NativeEndian.Uint32(xtcpcb[darwinXtcpcbSndNxtOffset : darwinXtcpcbSndNxtOffset+4])
|
||||||
|
receiveNext := binary.NativeEndian.Uint32(xtcpcb[darwinXtcpcbRcvNxtOffset : darwinXtcpcbRcvNxtOffset+4])
|
||||||
|
return sendNext, receiveNext, nil
|
||||||
|
}
|
||||||
|
return 0, 0, E.New("tls_spoof: connection ", src, "->", dst, " not found in pcblist_n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func openDarwinRawSocket(dst netip.AddrPort) (int, unix.Sockaddr, error) {
|
||||||
|
if !dst.Addr().Is4() {
|
||||||
|
// macOS does not expose IPV6_HDRINCL; raw AF_INET6 injection would
|
||||||
|
// require either BPF link-layer writes or kernel-side IPv6 header
|
||||||
|
// synthesis, neither of which is implemented here.
|
||||||
|
return -1, nil, E.New("tls_spoof: IPv6 not supported on darwin")
|
||||||
|
}
|
||||||
|
return openIPv4RawSocket(dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *darwinSpoofer) Inject(payload []byte) error {
|
||||||
|
frame, err := buildSpoofFrame(s.method, s.src, s.dst, s.sendNext, s.receiveNext, payload)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Darwin inherits the historical BSD quirk: with IP_HDRINCL the kernel
|
||||||
|
// expects ip_len and ip_off in host byte order, not network byte order.
|
||||||
|
// Apple's rip_output swaps them back before transmission. This does not
|
||||||
|
// apply to IPv6.
|
||||||
|
if s.src.Addr().Is4() {
|
||||||
|
totalLen := binary.BigEndian.Uint16(frame[2:4])
|
||||||
|
binary.NativeEndian.PutUint16(frame[2:4], totalLen)
|
||||||
|
fragOff := binary.BigEndian.Uint16(frame[6:8])
|
||||||
|
binary.NativeEndian.PutUint16(frame[6:8], fragOff)
|
||||||
|
}
|
||||||
|
err = unix.Sendto(s.rawFD, frame, 0, s.rawSockAddr)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "sendto raw socket")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *darwinSpoofer) Close() error {
|
||||||
|
if s.rawFD < 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
err := unix.Close(s.rawFD)
|
||||||
|
s.rawFD = -1
|
||||||
|
return err
|
||||||
|
}
|
||||||
127
common/tlsspoof/raw_linux.go
Normal file
127
common/tlsspoof/raw_linux.go
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
package tlsspoof
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing/common/control"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
const PlatformSupported = true
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Values of enum { TCP_NO_QUEUE, TCP_RECV_QUEUE, TCP_SEND_QUEUE } from
|
||||||
|
// include/net/tcp.h; not exported by golang.org/x/sys/unix.
|
||||||
|
tcpRecvQueue = 1
|
||||||
|
tcpSendQueue = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
type linuxSpoofer struct {
|
||||||
|
method Method
|
||||||
|
src netip.AddrPort
|
||||||
|
dst netip.AddrPort
|
||||||
|
rawFD int
|
||||||
|
rawSockAddr unix.Sockaddr
|
||||||
|
sendNext uint32
|
||||||
|
receiveNext uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRawSpoofer(conn net.Conn, method Method) (Spoofer, error) {
|
||||||
|
tcpConn, src, dst, err := tcpEndpoints(conn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
fd, sockaddr, err := openLinuxRawSocket(dst)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
spoofer := &linuxSpoofer{
|
||||||
|
method: method,
|
||||||
|
src: src,
|
||||||
|
dst: dst,
|
||||||
|
rawFD: fd,
|
||||||
|
rawSockAddr: sockaddr,
|
||||||
|
}
|
||||||
|
err = spoofer.loadSequenceNumbers(tcpConn)
|
||||||
|
if err != nil {
|
||||||
|
unix.Close(fd)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return spoofer, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func openLinuxRawSocket(dst netip.AddrPort) (int, unix.Sockaddr, error) {
|
||||||
|
if dst.Addr().Is4() {
|
||||||
|
return openIPv4RawSocket(dst)
|
||||||
|
}
|
||||||
|
fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_TCP)
|
||||||
|
if err != nil {
|
||||||
|
return -1, nil, E.Cause(err, "open AF_INET6 SOCK_RAW")
|
||||||
|
}
|
||||||
|
err = unix.SetsockoptInt(fd, unix.IPPROTO_IPV6, unix.IPV6_HDRINCL, 1)
|
||||||
|
if err != nil {
|
||||||
|
unix.Close(fd)
|
||||||
|
return -1, nil, E.Cause(err, "set IPV6_HDRINCL")
|
||||||
|
}
|
||||||
|
sockaddr := &unix.SockaddrInet6{Port: int(dst.Port())}
|
||||||
|
sockaddr.Addr = dst.Addr().As16()
|
||||||
|
return fd, sockaddr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadSequenceNumbers puts the socket briefly into TCP_REPAIR mode to read
|
||||||
|
// snd_nxt and rcv_nxt from the kernel. TCP_REPAIR requires CAP_NET_ADMIN;
|
||||||
|
// callers must run as root or grant both CAP_NET_RAW and CAP_NET_ADMIN.
|
||||||
|
func (s *linuxSpoofer) loadSequenceNumbers(tcpConn *net.TCPConn) error {
|
||||||
|
return control.Conn(tcpConn, func(raw uintptr) error {
|
||||||
|
fd := int(raw)
|
||||||
|
err := unix.SetsockoptInt(fd, unix.IPPROTO_TCP, unix.TCP_REPAIR, unix.TCP_REPAIR_ON)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "enter TCP_REPAIR (need CAP_NET_ADMIN)")
|
||||||
|
}
|
||||||
|
defer unix.SetsockoptInt(fd, unix.IPPROTO_TCP, unix.TCP_REPAIR, unix.TCP_REPAIR_OFF)
|
||||||
|
|
||||||
|
err = unix.SetsockoptInt(fd, unix.IPPROTO_TCP, unix.TCP_REPAIR_QUEUE, tcpSendQueue)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "select TCP_SEND_QUEUE")
|
||||||
|
}
|
||||||
|
sendSequence, err := unix.GetsockoptInt(fd, unix.IPPROTO_TCP, unix.TCP_QUEUE_SEQ)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "read send queue sequence")
|
||||||
|
}
|
||||||
|
err = unix.SetsockoptInt(fd, unix.IPPROTO_TCP, unix.TCP_REPAIR_QUEUE, tcpRecvQueue)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "select TCP_RECV_QUEUE")
|
||||||
|
}
|
||||||
|
receiveSequence, err := unix.GetsockoptInt(fd, unix.IPPROTO_TCP, unix.TCP_QUEUE_SEQ)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "read recv queue sequence")
|
||||||
|
}
|
||||||
|
s.sendNext = uint32(sendSequence)
|
||||||
|
s.receiveNext = uint32(receiveSequence)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *linuxSpoofer) Inject(payload []byte) error {
|
||||||
|
frame, err := buildSpoofFrame(s.method, s.src, s.dst, s.sendNext, s.receiveNext, payload)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = unix.Sendto(s.rawFD, frame, 0, s.rawSockAddr)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "sendto raw socket")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *linuxSpoofer) Close() error {
|
||||||
|
if s.rawFD < 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
err := unix.Close(s.rawFD)
|
||||||
|
s.rawFD = -1
|
||||||
|
return err
|
||||||
|
}
|
||||||
15
common/tlsspoof/raw_stub.go
Normal file
15
common/tlsspoof/raw_stub.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
//go:build !linux && !darwin && !(windows && (amd64 || 386))
|
||||||
|
|
||||||
|
package tlsspoof
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
)
|
||||||
|
|
||||||
|
const PlatformSupported = false
|
||||||
|
|
||||||
|
func newRawSpoofer(conn net.Conn, method Method) (Spoofer, error) {
|
||||||
|
return nil, E.New("tls_spoof: unsupported platform")
|
||||||
|
}
|
||||||
26
common/tlsspoof/raw_unix.go
Normal file
26
common/tlsspoof/raw_unix.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
//go:build linux || darwin
|
||||||
|
|
||||||
|
package tlsspoof
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func openIPv4RawSocket(dst netip.AddrPort) (int, unix.Sockaddr, error) {
|
||||||
|
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_TCP)
|
||||||
|
if err != nil {
|
||||||
|
return -1, nil, E.Cause(err, "open AF_INET SOCK_RAW")
|
||||||
|
}
|
||||||
|
err = unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_HDRINCL, 1)
|
||||||
|
if err != nil {
|
||||||
|
unix.Close(fd)
|
||||||
|
return -1, nil, E.Cause(err, "set IP_HDRINCL")
|
||||||
|
}
|
||||||
|
sockaddr := &unix.SockaddrInet4{Port: int(dst.Port())}
|
||||||
|
sockaddr.Addr = dst.Addr().As4()
|
||||||
|
return fd, sockaddr, nil
|
||||||
|
}
|
||||||
218
common/tlsspoof/raw_windows.go
Normal file
218
common/tlsspoof/raw_windows.go
Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
//go:build windows && (amd64 || 386)
|
||||||
|
|
||||||
|
package tlsspoof
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/common/windivert"
|
||||||
|
"github.com/sagernet/sing-tun/gtcpip/header"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
const PlatformSupported = true
|
||||||
|
|
||||||
|
// closeGracePeriod caps how long Close() waits for the divert goroutine to
|
||||||
|
// observe the kernel-emitted real ClientHello and perform the reorder
|
||||||
|
// (fake → real). In practice this completes in microseconds; the cap
|
||||||
|
// bounds the pathological case where the kernel buffers the packet.
|
||||||
|
const closeGracePeriod = 2 * time.Second
|
||||||
|
|
||||||
|
type windowsSpoofer struct {
|
||||||
|
method Method
|
||||||
|
src, dst netip.AddrPort
|
||||||
|
divertH *windivert.Handle
|
||||||
|
injectH *windivert.Handle
|
||||||
|
|
||||||
|
fakeReady chan []byte // buffered(1): staged by Inject
|
||||||
|
done chan struct{} // closed by run() on exit
|
||||||
|
closeOnce sync.Once
|
||||||
|
runErr atomic.Pointer[error]
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRawSpoofer(conn net.Conn, method Method) (Spoofer, error) {
|
||||||
|
_, src, dst, err := tcpEndpoints(conn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
filter, err := windivert.OutboundTCP(src, dst)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
divertH, err := windivert.Open(filter, windivert.LayerNetwork, 0, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "tls_spoof: open WinDivert")
|
||||||
|
}
|
||||||
|
injectH, err := windivert.Open(nil, windivert.LayerNetwork, 0, windivert.FlagSendOnly)
|
||||||
|
if err != nil {
|
||||||
|
divertH.Close()
|
||||||
|
return nil, E.Cause(err, "tls_spoof: open WinDivert")
|
||||||
|
}
|
||||||
|
s := &windowsSpoofer{
|
||||||
|
method: method,
|
||||||
|
src: src,
|
||||||
|
dst: dst,
|
||||||
|
divertH: divertH,
|
||||||
|
injectH: injectH,
|
||||||
|
fakeReady: make(chan []byte, 1),
|
||||||
|
done: make(chan struct{}),
|
||||||
|
}
|
||||||
|
go s.run()
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *windowsSpoofer) Inject(payload []byte) error {
|
||||||
|
select {
|
||||||
|
case s.fakeReady <- payload:
|
||||||
|
return nil
|
||||||
|
case <-s.done:
|
||||||
|
if p := s.runErr.Load(); p != nil {
|
||||||
|
return *p
|
||||||
|
}
|
||||||
|
return E.New("tls_spoof: spoofer closed before Inject")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *windowsSpoofer) Close() error {
|
||||||
|
s.closeOnce.Do(func() {
|
||||||
|
// Give run() a grace window to finish handling the real packet.
|
||||||
|
select {
|
||||||
|
case <-s.done:
|
||||||
|
case <-time.After(closeGracePeriod):
|
||||||
|
// Force Recv() to return by closing the divert handle.
|
||||||
|
s.divertH.Close()
|
||||||
|
<-s.done
|
||||||
|
}
|
||||||
|
s.injectH.Close()
|
||||||
|
})
|
||||||
|
if p := s.runErr.Load(); p != nil {
|
||||||
|
return *p
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *windowsSpoofer) recordErr(err error) { s.runErr.Store(&err) }
|
||||||
|
|
||||||
|
func (s *windowsSpoofer) run() {
|
||||||
|
defer close(s.done)
|
||||||
|
defer s.divertH.Close()
|
||||||
|
|
||||||
|
buf := make([]byte, windivert.MTUMax)
|
||||||
|
for {
|
||||||
|
n, addr, err := s.divertH.Recv(buf)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, windows.ERROR_OPERATION_ABORTED) ||
|
||||||
|
errors.Is(err, windows.ERROR_NO_DATA) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.recordErr(E.Cause(err, "windivert recv"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pkt := buf[:n]
|
||||||
|
seq, ack, payloadLen, ok := parseTCPFields(pkt, addr.IPv6())
|
||||||
|
if !ok {
|
||||||
|
// Malformed / not TCP — shouldn't match our filter, but be safe.
|
||||||
|
_, _ = s.divertH.Send(pkt, &addr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if payloadLen == 0 {
|
||||||
|
// Handshake ACK, keepalive, FIN — pass through unchanged.
|
||||||
|
_, err := s.divertH.Send(pkt, &addr)
|
||||||
|
if err != nil {
|
||||||
|
s.recordErr(E.Cause(err, "windivert re-inject empty"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Non-empty outbound TCP payload = the real ClientHello.
|
||||||
|
var fake []byte
|
||||||
|
select {
|
||||||
|
case fake = <-s.fakeReady:
|
||||||
|
default:
|
||||||
|
// Inject() not yet called — pass through and keep observing.
|
||||||
|
_, err := s.divertH.Send(pkt, &addr)
|
||||||
|
if err != nil {
|
||||||
|
s.recordErr(E.Cause(err, "windivert re-inject early data"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
frame, err := buildSpoofFrame(s.method, s.src, s.dst, seq, ack, fake)
|
||||||
|
if err != nil {
|
||||||
|
s.recordErr(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fakeAddr := addr // inherit Outbound, IfIdx
|
||||||
|
// buildSpoofFrame emits ready-to-wire bytes. The driver recomputes
|
||||||
|
// checksums on Send when TCPChecksum/IPChecksum are 0 — which would
|
||||||
|
// overwrite the intentionally corrupt checksum in WrongChecksum mode.
|
||||||
|
// Force both to 1 to keep our bytes intact.
|
||||||
|
fakeAddr.SetIPChecksum(true)
|
||||||
|
fakeAddr.SetTCPChecksum(true)
|
||||||
|
_, err = s.injectH.Send(frame, &fakeAddr)
|
||||||
|
if err != nil {
|
||||||
|
s.recordErr(E.Cause(err, "windivert inject fake"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, err = s.divertH.Send(pkt, &addr)
|
||||||
|
if err != nil {
|
||||||
|
s.recordErr(E.Cause(err, "windivert re-inject real"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return // single-shot reorder complete
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseTCPFields(pkt []byte, isV6 bool) (seq, ack uint32, payloadLen int, ok bool) {
|
||||||
|
if isV6 {
|
||||||
|
if len(pkt) < header.IPv6MinimumSize+header.TCPMinimumSize {
|
||||||
|
return 0, 0, 0, false
|
||||||
|
}
|
||||||
|
ip := header.IPv6(pkt)
|
||||||
|
if ip.TransportProtocol() != header.TCPProtocolNumber {
|
||||||
|
return 0, 0, 0, false
|
||||||
|
}
|
||||||
|
tcp := header.TCP(pkt[header.IPv6MinimumSize:])
|
||||||
|
tcpHdr := int(tcp.DataOffset())
|
||||||
|
if tcpHdr < header.TCPMinimumSize || header.IPv6MinimumSize+tcpHdr > len(pkt) {
|
||||||
|
return 0, 0, 0, false
|
||||||
|
}
|
||||||
|
return tcp.SequenceNumber(), tcp.AckNumber(),
|
||||||
|
len(pkt) - header.IPv6MinimumSize - tcpHdr, true
|
||||||
|
}
|
||||||
|
if len(pkt) < header.IPv4MinimumSize+header.TCPMinimumSize {
|
||||||
|
return 0, 0, 0, false
|
||||||
|
}
|
||||||
|
ip := header.IPv4(pkt)
|
||||||
|
if ip.Protocol() != uint8(header.TCPProtocolNumber) {
|
||||||
|
return 0, 0, 0, false
|
||||||
|
}
|
||||||
|
ihl := int(ip.HeaderLength())
|
||||||
|
// ihl+TCPMinimumSize guards the TCP-header field reads below; without
|
||||||
|
// this, an IPv4 packet with options (ihl>20) against a 40-byte buffer
|
||||||
|
// reads past the TCP slice when calling DataOffset.
|
||||||
|
if ihl < header.IPv4MinimumSize || ihl+header.TCPMinimumSize > len(pkt) {
|
||||||
|
return 0, 0, 0, false
|
||||||
|
}
|
||||||
|
tcp := header.TCP(pkt[ihl:])
|
||||||
|
tcpHdr := int(tcp.DataOffset())
|
||||||
|
if tcpHdr < header.TCPMinimumSize || ihl+tcpHdr > len(pkt) {
|
||||||
|
return 0, 0, 0, false
|
||||||
|
}
|
||||||
|
total := int(ip.TotalLength())
|
||||||
|
if total == 0 || total > len(pkt) {
|
||||||
|
total = len(pkt)
|
||||||
|
}
|
||||||
|
return tcp.SequenceNumber(), tcp.AckNumber(),
|
||||||
|
total - ihl - tcpHdr, true
|
||||||
|
}
|
||||||
112
common/tlsspoof/raw_windows_test.go
Normal file
112
common/tlsspoof/raw_windows_test.go
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
//go:build windows && (amd64 || 386)
|
||||||
|
|
||||||
|
package tlsspoof
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-tun/gtcpip/header"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseTCPFieldsIPv4Valid(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
src := netip.MustParseAddrPort("10.0.0.1:54321")
|
||||||
|
dst := netip.MustParseAddrPort("1.2.3.4:443")
|
||||||
|
payload := []byte("hello")
|
||||||
|
frame := buildTCPSegment(src, dst, 1000, 2000, payload, false)
|
||||||
|
|
||||||
|
seq, ack, payloadLen, ok := parseTCPFields(frame, false)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, uint32(1000), seq)
|
||||||
|
require.Equal(t, uint32(2000), ack)
|
||||||
|
require.Equal(t, len(payload), payloadLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTCPFieldsIPv4NoPayload(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
src := netip.MustParseAddrPort("10.0.0.1:54321")
|
||||||
|
dst := netip.MustParseAddrPort("1.2.3.4:443")
|
||||||
|
frame := buildTCPSegment(src, dst, 42, 100, nil, false)
|
||||||
|
|
||||||
|
seq, ack, payloadLen, ok := parseTCPFields(frame, false)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, uint32(42), seq)
|
||||||
|
require.Equal(t, uint32(100), ack)
|
||||||
|
require.Equal(t, 0, payloadLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTCPFieldsIPv6Valid(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
src := netip.MustParseAddrPort("[fe80::1]:54321")
|
||||||
|
dst := netip.MustParseAddrPort("[2606:4700::1]:443")
|
||||||
|
payload := []byte("hello-v6")
|
||||||
|
frame := buildTCPSegment(src, dst, 0xDEADBEEF, 0x12345678, payload, false)
|
||||||
|
|
||||||
|
seq, ack, payloadLen, ok := parseTCPFields(frame, true)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, uint32(0xDEADBEEF), seq)
|
||||||
|
require.Equal(t, uint32(0x12345678), ack)
|
||||||
|
require.Equal(t, len(payload), payloadLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTCPFieldsIPv4TooShort(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
_, _, _, ok := parseTCPFields(make([]byte, header.IPv4MinimumSize+header.TCPMinimumSize-1), false)
|
||||||
|
require.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTCPFieldsIPv6TooShort(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
_, _, _, ok := parseTCPFields(make([]byte, header.IPv6MinimumSize+header.TCPMinimumSize-1), true)
|
||||||
|
require.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildTCPSegment only produces TCP; a UDP packet hitting parseTCPFields
|
||||||
|
// (for example from a mis-specified filter) must be rejected.
|
||||||
|
func TestParseTCPFieldsIPv4WrongProtocol(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
frame := make([]byte, header.IPv4MinimumSize+header.TCPMinimumSize)
|
||||||
|
ip := header.IPv4(frame[:header.IPv4MinimumSize])
|
||||||
|
ip.Encode(&header.IPv4Fields{
|
||||||
|
TotalLength: uint16(len(frame)),
|
||||||
|
TTL: 64,
|
||||||
|
Protocol: 17, // UDP
|
||||||
|
SrcAddr: netip.MustParseAddr("10.0.0.1"),
|
||||||
|
DstAddr: netip.MustParseAddr("10.0.0.2"),
|
||||||
|
})
|
||||||
|
_, _, _, ok := parseTCPFields(frame, false)
|
||||||
|
require.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTCPFieldsIPv6WrongProtocol(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
frame := make([]byte, header.IPv6MinimumSize+header.TCPMinimumSize)
|
||||||
|
ip := header.IPv6(frame[:header.IPv6MinimumSize])
|
||||||
|
ip.Encode(&header.IPv6Fields{
|
||||||
|
PayloadLength: header.TCPMinimumSize,
|
||||||
|
TransportProtocol: 17, // UDP
|
||||||
|
HopLimit: 64,
|
||||||
|
SrcAddr: netip.MustParseAddr("fe80::1"),
|
||||||
|
DstAddr: netip.MustParseAddr("fe80::2"),
|
||||||
|
})
|
||||||
|
_, _, _, ok := parseTCPFields(frame, true)
|
||||||
|
require.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ihl > 20 must not read past the TCP slice. Build an IPv4 packet with
|
||||||
|
// options header but truncate so ihl*4 + TCPMinimumSize exceeds len.
|
||||||
|
func TestParseTCPFieldsIPv4OptionsOverflow(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// Start with a valid IPv4+TCP frame, then lie about the header length.
|
||||||
|
src := netip.MustParseAddrPort("10.0.0.1:1")
|
||||||
|
dst := netip.MustParseAddrPort("10.0.0.2:2")
|
||||||
|
frame := buildTCPSegment(src, dst, 0, 0, []byte("x"), false)
|
||||||
|
ip := header.IPv4(frame[:header.IPv4MinimumSize])
|
||||||
|
// ihl=15 → 60 bytes of IP header claimed, but buffer only has 20.
|
||||||
|
ip.SetHeaderLength(60)
|
||||||
|
_, _, _, ok := parseTCPFields(frame, false)
|
||||||
|
require.False(t, ok)
|
||||||
|
}
|
||||||
100
common/tlsspoof/spoof.go
Normal file
100
common/tlsspoof/spoof.go
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
package tlsspoof
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Method int
|
||||||
|
|
||||||
|
const (
|
||||||
|
MethodWrongSequence Method = iota
|
||||||
|
MethodWrongChecksum
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
MethodNameWrongSequence = "wrong-sequence"
|
||||||
|
MethodNameWrongChecksum = "wrong-checksum"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ParseMethod(s string) (Method, error) {
|
||||||
|
switch s {
|
||||||
|
case "", MethodNameWrongSequence:
|
||||||
|
return MethodWrongSequence, nil
|
||||||
|
case MethodNameWrongChecksum:
|
||||||
|
return MethodWrongChecksum, nil
|
||||||
|
default:
|
||||||
|
return 0, E.New("tls_spoof: unknown method: ", s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Method) String() string {
|
||||||
|
switch m {
|
||||||
|
case MethodWrongSequence:
|
||||||
|
return MethodNameWrongSequence
|
||||||
|
case MethodWrongChecksum:
|
||||||
|
return MethodNameWrongChecksum
|
||||||
|
default:
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Spoofer interface {
|
||||||
|
Inject(payload []byte) error
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSpoofer(conn net.Conn, method Method) (Spoofer, error) {
|
||||||
|
return newRawSpoofer(conn, method)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Conn struct {
|
||||||
|
net.Conn
|
||||||
|
spoofer Spoofer
|
||||||
|
fakeSNI string
|
||||||
|
injected bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConn(conn net.Conn, spoofer Spoofer, fakeSNI string) *Conn {
|
||||||
|
return &Conn{
|
||||||
|
Conn: conn,
|
||||||
|
spoofer: spoofer,
|
||||||
|
fakeSNI: fakeSNI,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Write(b []byte) (int, error) {
|
||||||
|
if c.injected {
|
||||||
|
return c.Conn.Write(b)
|
||||||
|
}
|
||||||
|
defer c.spoofer.Close()
|
||||||
|
fake, err := rewriteSNI(b, c.fakeSNI)
|
||||||
|
if err != nil {
|
||||||
|
return 0, E.Cause(err, "tls_spoof: rewrite SNI")
|
||||||
|
}
|
||||||
|
err = c.spoofer.Inject(fake)
|
||||||
|
if err != nil {
|
||||||
|
return 0, E.Cause(err, "tls_spoof: inject")
|
||||||
|
}
|
||||||
|
c.injected = true
|
||||||
|
return c.Conn.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Close() error {
|
||||||
|
return E.Append(c.Conn.Close(), c.spoofer.Close(), func(e error) error {
|
||||||
|
return E.Cause(e, "close spoofer")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) ReaderReplaceable() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) WriterReplaceable() bool {
|
||||||
|
return c.injected
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Upstream() any {
|
||||||
|
return c.Conn
|
||||||
|
}
|
||||||
53
common/windivert/address_test.go
Normal file
53
common/windivert/address_test.go
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
package windivert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAddressSize(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
require.Equal(t, uintptr(80), unsafe.Sizeof(Address{}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddressIPv6(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
var addr Address
|
||||||
|
require.False(t, addr.IPv6())
|
||||||
|
addr.bits = 1 << addrBitIPv6
|
||||||
|
require.True(t, addr.IPv6())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddressSetIPChecksum(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
var addr Address
|
||||||
|
addr.SetIPChecksum(true)
|
||||||
|
require.Equal(t, uint32(1<<addrBitIPChecksum), addr.bits)
|
||||||
|
addr.SetIPChecksum(false)
|
||||||
|
require.Equal(t, uint32(0), addr.bits)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddressSetTCPChecksum(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
var addr Address
|
||||||
|
addr.SetTCPChecksum(true)
|
||||||
|
require.Equal(t, uint32(1<<addrBitTCPChecksum), addr.bits)
|
||||||
|
addr.SetTCPChecksum(false)
|
||||||
|
require.Equal(t, uint32(0), addr.bits)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setters must not disturb sibling bits.
|
||||||
|
func TestAddressFlagBitsIndependent(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
var addr Address
|
||||||
|
addr.SetIPChecksum(true)
|
||||||
|
addr.SetTCPChecksum(true)
|
||||||
|
addr.bits |= 1 << addrBitIPv6
|
||||||
|
|
||||||
|
addr.SetIPChecksum(false)
|
||||||
|
require.False(t, addr.bits&(1<<addrBitIPChecksum) != 0)
|
||||||
|
require.True(t, addr.bits&(1<<addrBitTCPChecksum) != 0)
|
||||||
|
require.True(t, addr.bits&(1<<addrBitIPv6) != 0)
|
||||||
|
}
|
||||||
1191
common/windivert/assets/LICENSE.txt
Normal file
1191
common/windivert/assets/LICENSE.txt
Normal file
File diff suppressed because it is too large
Load Diff
BIN
common/windivert/assets/WinDivert32.sys
Normal file
BIN
common/windivert/assets/WinDivert32.sys
Normal file
Binary file not shown.
BIN
common/windivert/assets/WinDivert64.sys
Normal file
BIN
common/windivert/assets/WinDivert64.sys
Normal file
Binary file not shown.
14
common/windivert/assets_386.go
Normal file
14
common/windivert/assets_386.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
//go:build windows && 386
|
||||||
|
|
||||||
|
package windivert
|
||||||
|
|
||||||
|
import _ "embed"
|
||||||
|
|
||||||
|
//go:embed assets/WinDivert32.sys
|
||||||
|
var sysBytes []byte
|
||||||
|
|
||||||
|
func assetFiles() []assetFile {
|
||||||
|
return []assetFile{{"WinDivert32.sys", sysBytes}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func driverSysName() string { return "WinDivert32.sys" }
|
||||||
14
common/windivert/assets_amd64.go
Normal file
14
common/windivert/assets_amd64.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
//go:build windows && amd64
|
||||||
|
|
||||||
|
package windivert
|
||||||
|
|
||||||
|
import _ "embed"
|
||||||
|
|
||||||
|
//go:embed assets/WinDivert64.sys
|
||||||
|
var sysBytes []byte
|
||||||
|
|
||||||
|
func assetFiles() []assetFile {
|
||||||
|
return []assetFile{{"WinDivert64.sys", sysBytes}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func driverSysName() string { return "WinDivert64.sys" }
|
||||||
7
common/windivert/assets_unsupported.go
Normal file
7
common/windivert/assets_unsupported.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
//go:build windows && !amd64 && !386
|
||||||
|
|
||||||
|
package windivert
|
||||||
|
|
||||||
|
func assetFiles() []assetFile { return nil }
|
||||||
|
|
||||||
|
func driverSysName() string { return "" }
|
||||||
212
common/windivert/driver_windows.go
Normal file
212
common/windivert/driver_windows.go
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package windivert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
driverServiceName = "WinDivert"
|
||||||
|
driverDeviceName = `\\.\WinDivert`
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
driverOnce sync.Once
|
||||||
|
driverErr error
|
||||||
|
// driverDevName is ASCII-safe and must be available before ensureDriver
|
||||||
|
// so Open can try CreateFile first and only install on FILE_NOT_FOUND.
|
||||||
|
driverDevName, _ = windows.UTF16PtrFromString(driverDeviceName)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Requires SeLoadDriverPrivilege (Administrator). Running the 386 build
|
||||||
|
// under WOW64 on a 64-bit kernel is rejected — use the amd64 build.
|
||||||
|
func ensureDriver() error {
|
||||||
|
driverOnce.Do(func() {
|
||||||
|
driverErr = installDriver()
|
||||||
|
})
|
||||||
|
return driverErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func installDriver() error {
|
||||||
|
if runtime.GOARCH == "386" {
|
||||||
|
var isWow64 bool
|
||||||
|
err := windows.IsWow64Process(windows.CurrentProcess(), &isWow64)
|
||||||
|
if err == nil && isWow64 {
|
||||||
|
return E.New("windivert: 386 build detected running under WOW64 on a 64-bit kernel; use the amd64 build")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dir, err := ensureExtracted()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
sysPath := filepath.Join(dir, driverSysName())
|
||||||
|
sysPathW, err := windows.UTF16PtrFromString(sysPath)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "windivert: utf16 driver path")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serialize driver install across concurrent processes.
|
||||||
|
mutexName, _ := windows.UTF16PtrFromString("WinDivertDriverInstallMutex")
|
||||||
|
mutex, err := windows.CreateMutex(nil, false, mutexName)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "windivert: create install mutex")
|
||||||
|
}
|
||||||
|
defer windows.CloseHandle(mutex)
|
||||||
|
_, err = windows.WaitForSingleObject(mutex, windows.INFINITE)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "windivert: wait install mutex")
|
||||||
|
}
|
||||||
|
defer windows.ReleaseMutex(mutex)
|
||||||
|
|
||||||
|
manager, err := windows.OpenSCManager(nil, nil, windows.SC_MANAGER_ALL_ACCESS)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "windivert: open SCM")
|
||||||
|
}
|
||||||
|
defer windows.CloseServiceHandle(manager)
|
||||||
|
|
||||||
|
serviceNameW, _ := windows.UTF16PtrFromString(driverServiceName)
|
||||||
|
service, err := windows.OpenService(manager, serviceNameW, windows.SERVICE_ALL_ACCESS)
|
||||||
|
if err != nil {
|
||||||
|
service, err = windows.CreateService(
|
||||||
|
manager,
|
||||||
|
serviceNameW,
|
||||||
|
serviceNameW,
|
||||||
|
windows.SERVICE_ALL_ACCESS,
|
||||||
|
windows.SERVICE_KERNEL_DRIVER,
|
||||||
|
windows.SERVICE_DEMAND_START,
|
||||||
|
windows.SERVICE_ERROR_NORMAL,
|
||||||
|
sysPathW,
|
||||||
|
nil, nil, nil, nil, nil,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, windows.ERROR_SERVICE_EXISTS) {
|
||||||
|
service, err = windows.OpenService(manager, serviceNameW, windows.SERVICE_ALL_ACCESS)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return wrapDriverInstallError(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
defer windows.CloseServiceHandle(service)
|
||||||
|
|
||||||
|
err = windows.StartService(service, 0, nil)
|
||||||
|
if err != nil && errors.Is(err, windows.ERROR_SERVICE_DISABLED) {
|
||||||
|
// A prior process called DeleteService on a still-running kernel
|
||||||
|
// driver: SCM marks the record for deletion and flips START_TYPE
|
||||||
|
// to DISABLED until the last handle closes. Re-enable so we can
|
||||||
|
// start it instead of waiting for a reboot.
|
||||||
|
err = windows.ChangeServiceConfig(
|
||||||
|
service,
|
||||||
|
windows.SERVICE_NO_CHANGE,
|
||||||
|
windows.SERVICE_DEMAND_START,
|
||||||
|
windows.SERVICE_NO_CHANGE,
|
||||||
|
nil, nil, nil, nil, nil, nil, nil,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "windivert: re-enable disabled service")
|
||||||
|
}
|
||||||
|
err = windows.StartService(service, 0, nil)
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
// Mark for deletion so the driver unregisters when the last handle
|
||||||
|
// closes or on next reboot. Matches the upstream DLL's behavior:
|
||||||
|
// only the process that actually started the service takes on the
|
||||||
|
// cleanup responsibility. If another process already started it,
|
||||||
|
// we leave DeleteService to them.
|
||||||
|
_ = windows.DeleteService(service)
|
||||||
|
} else if !errors.Is(err, windows.ERROR_SERVICE_ALREADY_RUNNING) {
|
||||||
|
return E.Cause(err, "windivert: start service")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func wrapDriverInstallError(err error) error {
|
||||||
|
if errors.Is(err, windows.ERROR_ACCESS_DENIED) {
|
||||||
|
return E.Cause(err, "windivert: installing the kernel driver requires Administrator privileges")
|
||||||
|
}
|
||||||
|
return E.Cause(err, "windivert: create service")
|
||||||
|
}
|
||||||
|
|
||||||
|
type assetFile struct {
|
||||||
|
name string
|
||||||
|
data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
extractOnce sync.Once
|
||||||
|
extractErr error
|
||||||
|
extractDir string
|
||||||
|
)
|
||||||
|
|
||||||
|
// The on-disk copy is protected by Windows Authenticode signature
|
||||||
|
// enforcement, which rejects any tampered .sys at StartService time.
|
||||||
|
func ensureExtracted() (string, error) {
|
||||||
|
extractOnce.Do(func() {
|
||||||
|
extractDir, extractErr = extractImpl()
|
||||||
|
})
|
||||||
|
return extractDir, extractErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractImpl() (string, error) {
|
||||||
|
files := assetFiles()
|
||||||
|
if len(files) == 0 {
|
||||||
|
return "", E.New("windivert: unsupported architecture ", runtime.GOARCH)
|
||||||
|
}
|
||||||
|
|
||||||
|
base, err := os.UserCacheDir()
|
||||||
|
if err != nil {
|
||||||
|
return "", E.Cause(err, "windivert: locate user cache dir")
|
||||||
|
}
|
||||||
|
dir := filepath.Join(base, "sing-box", "windivert", "v"+AssetVersion)
|
||||||
|
err = os.MkdirAll(dir, 0o755)
|
||||||
|
if err != nil {
|
||||||
|
return "", E.Cause(err, "windivert: mkdir ", dir)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, asset := range files {
|
||||||
|
err = ensureAsset(dir, asset)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return dir, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Concurrent sing-box processes race on os.Rename (atomic on NTFS);
|
||||||
|
// whichever wins creates the final file. Writers that lose the race
|
||||||
|
// silently discard their temp copy.
|
||||||
|
func ensureAsset(dir string, asset assetFile) error {
|
||||||
|
target := filepath.Join(dir, asset.name)
|
||||||
|
_, err := os.Stat(target)
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !os.IsNotExist(err) {
|
||||||
|
return E.Cause(err, "windivert: stat ", asset.name)
|
||||||
|
}
|
||||||
|
tmp := target + ".tmp-" + strconv.Itoa(os.Getpid())
|
||||||
|
err = os.WriteFile(tmp, asset.data, 0o644)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "windivert: write ", asset.name)
|
||||||
|
}
|
||||||
|
err = os.Rename(tmp, target)
|
||||||
|
if err != nil {
|
||||||
|
os.Remove(tmp)
|
||||||
|
if _, statErr := os.Stat(target); statErr == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return E.Cause(err, "windivert: rename ", asset.name)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
182
common/windivert/filter.go
Normal file
182
common/windivert/filter.go
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
package windivert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WINDIVERT_FILTER VM instruction layout (24 bytes, #pragma pack(1)):
|
||||||
|
//
|
||||||
|
// word 0 (LE): field:11 | test:5 | success:16
|
||||||
|
// word 1 (LE): failure:16 | neg:1 | reserved:15
|
||||||
|
// words 2..5: arg[4] (native-endian uint32 each)
|
||||||
|
//
|
||||||
|
// The driver walks this as a decision tree: evaluate the test at inst i;
|
||||||
|
// on success jump to success; on failure jump to failure. Continuations
|
||||||
|
// 0x7FFE and 0x7FFF are ACCEPT and REJECT terminals.
|
||||||
|
const (
|
||||||
|
filterInstBytes = 24
|
||||||
|
filterMaxInsts = 256
|
||||||
|
|
||||||
|
fieldZero = 0
|
||||||
|
fieldOutbound = 2
|
||||||
|
fieldIP = 5
|
||||||
|
fieldIPv6 = 6
|
||||||
|
fieldTCP = 8
|
||||||
|
fieldIPSrcAddr = 21
|
||||||
|
fieldIPDstAddr = 22
|
||||||
|
fieldIPv6SrcAddr = 28
|
||||||
|
fieldIPv6DstAddr = 29
|
||||||
|
fieldTCPSrcPort = 38
|
||||||
|
fieldTCPDstPort = 39
|
||||||
|
|
||||||
|
testEQ = 0
|
||||||
|
|
||||||
|
resultAccept uint16 = 0x7FFE
|
||||||
|
resultReject uint16 = 0x7FFF
|
||||||
|
)
|
||||||
|
|
||||||
|
// Filter flags passed to IOCTL_WINDIVERT_STARTUP alongside the compiled
|
||||||
|
// filter. These tell the driver what *kinds* of packets the filter might
|
||||||
|
// match, used as a kernel-side fast-reject.
|
||||||
|
const (
|
||||||
|
filterFlagOutbound uint64 = 0x0020
|
||||||
|
filterFlagIP uint64 = 0x0040
|
||||||
|
filterFlagIPv6 uint64 = 0x0080
|
||||||
|
)
|
||||||
|
|
||||||
|
type filterInst struct {
|
||||||
|
field uint16 // 11 bits used
|
||||||
|
test uint8 // 5 bits used
|
||||||
|
success uint16
|
||||||
|
failure uint16
|
||||||
|
neg bool
|
||||||
|
arg [4]uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter is a typed specification of packets to capture. It replaces
|
||||||
|
// WinDivert's filter string language.
|
||||||
|
//
|
||||||
|
// Zero value = "reject all" (match nothing), suitable for send-only handles.
|
||||||
|
type Filter struct {
|
||||||
|
insts []filterInst
|
||||||
|
flags uint64 // filter flags for STARTUP ioctl
|
||||||
|
}
|
||||||
|
|
||||||
|
// reject returns a filter that matches no packet. The empty insts slice
|
||||||
|
// is encoded as a single rejecting instruction by encode().
|
||||||
|
func reject() *Filter {
|
||||||
|
return &Filter{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OutboundTCP returns a filter matching outbound TCP packets on the given
|
||||||
|
// 5-tuple. Both addresses must share an address family (IPv4 or IPv6).
|
||||||
|
func OutboundTCP(src, dst netip.AddrPort) (*Filter, error) {
|
||||||
|
if !src.IsValid() || !dst.IsValid() {
|
||||||
|
return nil, E.New("windivert: filter: invalid address port")
|
||||||
|
}
|
||||||
|
if src.Addr().Is4() != dst.Addr().Is4() {
|
||||||
|
return nil, E.New("windivert: filter: mixed IPv4/IPv6")
|
||||||
|
}
|
||||||
|
f := &Filter{
|
||||||
|
flags: filterFlagOutbound,
|
||||||
|
}
|
||||||
|
// Insts chain as AND: each test's failure = REJECT, success = next inst.
|
||||||
|
// The final inst's success = ACCEPT.
|
||||||
|
f.add(fieldOutbound, testEQ, argUint32(1))
|
||||||
|
if src.Addr().Is4() {
|
||||||
|
f.flags |= filterFlagIP
|
||||||
|
f.add(fieldIP, testEQ, argUint32(1))
|
||||||
|
f.add(fieldTCP, testEQ, argUint32(1))
|
||||||
|
f.add(fieldIPSrcAddr, testEQ, argIPv4(src.Addr()))
|
||||||
|
f.add(fieldIPDstAddr, testEQ, argIPv4(dst.Addr()))
|
||||||
|
} else {
|
||||||
|
f.flags |= filterFlagIPv6
|
||||||
|
f.add(fieldIPv6, testEQ, argUint32(1))
|
||||||
|
f.add(fieldTCP, testEQ, argUint32(1))
|
||||||
|
f.add(fieldIPv6SrcAddr, testEQ, argIPv6(src.Addr()))
|
||||||
|
f.add(fieldIPv6DstAddr, testEQ, argIPv6(dst.Addr()))
|
||||||
|
}
|
||||||
|
f.add(fieldTCPSrcPort, testEQ, argUint32(uint32(src.Port())))
|
||||||
|
f.add(fieldTCPDstPort, testEQ, argUint32(uint32(dst.Port())))
|
||||||
|
return f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Filter) add(field uint16, test uint8, arg [4]uint32) {
|
||||||
|
f.insts = append(f.insts, filterInst{field: field, test: test, arg: arg})
|
||||||
|
}
|
||||||
|
|
||||||
|
func argUint32(v uint32) [4]uint32 { return [4]uint32{v, 0, 0, 0} }
|
||||||
|
|
||||||
|
// argIPv4 encodes an IPv4 address for IP_SRCADDR/IP_DSTADDR. The driver
|
||||||
|
// compares against an IPv4-mapped-IPv6 form: {host_order_u32, 0x0000FFFF,
|
||||||
|
// 0, 0} (see sys/windivert.c windivert_get_ipv4_addr and the IPv4_SRCADDR
|
||||||
|
// val-word construction). Omitting the 0x0000FFFF marker causes the EQ
|
||||||
|
// test to fail for every packet.
|
||||||
|
func argIPv4(addr netip.Addr) [4]uint32 {
|
||||||
|
b := addr.As4()
|
||||||
|
return [4]uint32{binary.BigEndian.Uint32(b[:]), 0x0000FFFF, 0, 0}
|
||||||
|
}
|
||||||
|
|
||||||
|
// argIPv6 encodes an IPv6 address for IPV6_SRCADDR/IPV6_DSTADDR. The
|
||||||
|
// driver stores the address as four host-order uint32s in REVERSED word
|
||||||
|
// order: val[0]=low (bytes 12..15), val[3]=high (bytes 0..3). See
|
||||||
|
// sys/windivert.c windivert_outbound_network_v6_classify val-word
|
||||||
|
// construction.
|
||||||
|
func argIPv6(addr netip.Addr) [4]uint32 {
|
||||||
|
b := addr.As16()
|
||||||
|
return [4]uint32{
|
||||||
|
binary.BigEndian.Uint32(b[12:16]),
|
||||||
|
binary.BigEndian.Uint32(b[8:12]),
|
||||||
|
binary.BigEndian.Uint32(b[4:8]),
|
||||||
|
binary.BigEndian.Uint32(b[0:4]),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// encode serializes the Filter to the on-wire WINDIVERT_FILTER[] format
|
||||||
|
// plus the filter_flags for STARTUP ioctl.
|
||||||
|
func (f *Filter) encode() ([]byte, uint64, error) {
|
||||||
|
if len(f.insts) == 0 {
|
||||||
|
// "Reject all" — one instruction, ZERO == 0 is always true, but we
|
||||||
|
// invert by setting both success and failure to REJECT.
|
||||||
|
return encodeInst(filterInst{
|
||||||
|
field: fieldZero,
|
||||||
|
test: testEQ,
|
||||||
|
success: resultReject,
|
||||||
|
failure: resultReject,
|
||||||
|
}), 0, nil
|
||||||
|
}
|
||||||
|
if len(f.insts) > filterMaxInsts-1 {
|
||||||
|
return nil, 0, E.New("windivert: filter too long")
|
||||||
|
}
|
||||||
|
buf := make([]byte, 0, filterInstBytes*len(f.insts))
|
||||||
|
for i, inst := range f.insts {
|
||||||
|
if i == len(f.insts)-1 {
|
||||||
|
inst.success = resultAccept
|
||||||
|
} else {
|
||||||
|
inst.success = uint16(i + 1)
|
||||||
|
}
|
||||||
|
inst.failure = resultReject
|
||||||
|
buf = append(buf, encodeInst(inst)...)
|
||||||
|
}
|
||||||
|
return buf, f.flags, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeInst(inst filterInst) []byte {
|
||||||
|
out := make([]byte, filterInstBytes)
|
||||||
|
word0 := uint32(inst.field&0x7FF) | uint32(inst.test&0x1F)<<11 |
|
||||||
|
uint32(inst.success)<<16
|
||||||
|
word1 := uint32(inst.failure)
|
||||||
|
if inst.neg {
|
||||||
|
word1 |= 1 << 16
|
||||||
|
}
|
||||||
|
binary.LittleEndian.PutUint32(out[0:4], word0)
|
||||||
|
binary.LittleEndian.PutUint32(out[4:8], word1)
|
||||||
|
binary.LittleEndian.PutUint32(out[8:12], inst.arg[0])
|
||||||
|
binary.LittleEndian.PutUint32(out[12:16], inst.arg[1])
|
||||||
|
binary.LittleEndian.PutUint32(out[16:20], inst.arg[2])
|
||||||
|
binary.LittleEndian.PutUint32(out[20:24], inst.arg[3])
|
||||||
|
return out
|
||||||
|
}
|
||||||
140
common/windivert/filter_test.go
Normal file
140
common/windivert/filter_test.go
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
package windivert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRejectFilter(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
bin, flags, err := reject().encode()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(bin) != filterInstBytes {
|
||||||
|
t.Fatalf("reject filter len: got %d, want %d", len(bin), filterInstBytes)
|
||||||
|
}
|
||||||
|
if flags != 0 {
|
||||||
|
t.Fatalf("reject filter flags: got %x, want 0", flags)
|
||||||
|
}
|
||||||
|
// word0: field=ZERO=0, test=EQ=0, success=REJECT=0x7FFF
|
||||||
|
word0 := binary.LittleEndian.Uint32(bin[0:4])
|
||||||
|
if word0 != uint32(resultReject)<<16 {
|
||||||
|
t.Fatalf("reject word0 = %08x", word0)
|
||||||
|
}
|
||||||
|
// word1: failure=REJECT
|
||||||
|
word1 := binary.LittleEndian.Uint32(bin[4:8])
|
||||||
|
if word1 != uint32(resultReject) {
|
||||||
|
t.Fatalf("reject word1 = %08x", word1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOutboundTCPFilterIPv4(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
src := netip.MustParseAddrPort("10.1.2.3:54321")
|
||||||
|
dst := netip.MustParseAddrPort("1.2.3.4:443")
|
||||||
|
f, err := OutboundTCP(src, dst)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
bin, flags, err := f.encode()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if want := filterFlagOutbound | filterFlagIP; flags != want {
|
||||||
|
t.Fatalf("flags: got %x, want %x", flags, want)
|
||||||
|
}
|
||||||
|
// 7 instructions: OUTBOUND, IP, TCP, IP_SRCADDR, IP_DSTADDR, TCP_SRCPORT, TCP_DSTPORT
|
||||||
|
const wantInsts = 7
|
||||||
|
if len(bin) != wantInsts*filterInstBytes {
|
||||||
|
t.Fatalf("instruction count: got %d, want %d", len(bin)/filterInstBytes, wantInsts)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inst 0: OUTBOUND == 1, success=1, failure=REJECT
|
||||||
|
checkInst(t, bin[0*filterInstBytes:], 0, fieldOutbound, testEQ, 1, resultReject, 1)
|
||||||
|
// Inst 1: IP == 1, success=2
|
||||||
|
checkInst(t, bin[1*filterInstBytes:], 1, fieldIP, testEQ, 2, resultReject, 1)
|
||||||
|
// Inst 2: TCP == 1, success=3
|
||||||
|
checkInst(t, bin[2*filterInstBytes:], 2, fieldTCP, testEQ, 3, resultReject, 1)
|
||||||
|
// Inst 3: IP_SRCADDR == 10.1.2.3 (host-order uint32 = 0x0A010203, arg[1]=0x0000FFFF marker)
|
||||||
|
checkInst(t, bin[3*filterInstBytes:], 3, fieldIPSrcAddr, testEQ, 4, resultReject, 0x0A010203)
|
||||||
|
checkArg1(t, bin[3*filterInstBytes:], 3, 0x0000FFFF)
|
||||||
|
// Inst 4: IP_DSTADDR == 1.2.3.4
|
||||||
|
checkInst(t, bin[4*filterInstBytes:], 4, fieldIPDstAddr, testEQ, 5, resultReject, 0x01020304)
|
||||||
|
checkArg1(t, bin[4*filterInstBytes:], 4, 0x0000FFFF)
|
||||||
|
// Inst 5: TCP_SRCPORT == 54321
|
||||||
|
checkInst(t, bin[5*filterInstBytes:], 5, fieldTCPSrcPort, testEQ, 6, resultReject, 54321)
|
||||||
|
// Last inst 6: TCP_DSTPORT == 443, success=ACCEPT
|
||||||
|
checkInst(t, bin[6*filterInstBytes:], 6, fieldTCPDstPort, testEQ, resultAccept, resultReject, 443)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOutboundTCPFilterIPv6(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
src := netip.MustParseAddrPort("[2001:db8::1]:54321")
|
||||||
|
dst := netip.MustParseAddrPort("[2001:db8::2]:443")
|
||||||
|
f, err := OutboundTCP(src, dst)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
bin, flags, err := f.encode()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if want := filterFlagOutbound | filterFlagIPv6; flags != want {
|
||||||
|
t.Fatalf("flags: got %x, want %x", flags, want)
|
||||||
|
}
|
||||||
|
// Inst 3: IPv6_SRCADDR. The driver stores the address in reversed
|
||||||
|
// word order: arg[0]=low (bytes 12..15)=1, arg[3]=high (bytes 0..3)=0x20010db8.
|
||||||
|
off := 3 * filterInstBytes
|
||||||
|
a0 := binary.LittleEndian.Uint32(bin[off+8:])
|
||||||
|
a1 := binary.LittleEndian.Uint32(bin[off+12:])
|
||||||
|
a2 := binary.LittleEndian.Uint32(bin[off+16:])
|
||||||
|
a3 := binary.LittleEndian.Uint32(bin[off+20:])
|
||||||
|
if a0 != 1 || a1 != 0 || a2 != 0 || a3 != 0x20010db8 {
|
||||||
|
t.Fatalf("ipv6 src arg=[%08x %08x %08x %08x], want [1 0 0 0x20010db8]", a0, a1, a2, a3)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOutboundTCPFilterMixedFamily(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
src := netip.MustParseAddrPort("10.0.0.1:1234")
|
||||||
|
dst := netip.MustParseAddrPort("[2001:db8::1]:443")
|
||||||
|
if _, err := OutboundTCP(src, dst); err == nil {
|
||||||
|
t.Fatal("expected error for mixed families")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkArg1(t *testing.T, raw []byte, idx int, arg1 uint32) {
|
||||||
|
t.Helper()
|
||||||
|
got := binary.LittleEndian.Uint32(raw[12:16])
|
||||||
|
if got != arg1 {
|
||||||
|
t.Errorf("inst %d arg[1]: got %08x, want %08x", idx, got, arg1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkInst(t *testing.T, raw []byte, idx int, field uint16, test uint8, success, failure uint16, arg0 uint32) {
|
||||||
|
t.Helper()
|
||||||
|
word0 := binary.LittleEndian.Uint32(raw[0:4])
|
||||||
|
word1 := binary.LittleEndian.Uint32(raw[4:8])
|
||||||
|
a0 := binary.LittleEndian.Uint32(raw[8:12])
|
||||||
|
gotField := uint16(word0 & 0x7FF)
|
||||||
|
gotTest := uint8((word0 >> 11) & 0x1F)
|
||||||
|
gotSuccess := uint16(word0 >> 16)
|
||||||
|
gotFailure := uint16(word1 & 0xFFFF)
|
||||||
|
if gotField != field {
|
||||||
|
t.Errorf("inst %d field: got %d, want %d", idx, gotField, field)
|
||||||
|
}
|
||||||
|
if gotTest != test {
|
||||||
|
t.Errorf("inst %d test: got %d, want %d", idx, gotTest, test)
|
||||||
|
}
|
||||||
|
if gotSuccess != success {
|
||||||
|
t.Errorf("inst %d success: got %d, want %d", idx, gotSuccess, success)
|
||||||
|
}
|
||||||
|
if gotFailure != failure {
|
||||||
|
t.Errorf("inst %d failure: got %d, want %d", idx, gotFailure, failure)
|
||||||
|
}
|
||||||
|
if a0 != arg0 {
|
||||||
|
t.Errorf("inst %d arg[0]: got %08x, want %08x", idx, a0, arg0)
|
||||||
|
}
|
||||||
|
}
|
||||||
320
common/windivert/handle_windows.go
Normal file
320
common/windivert/handle_windows.go
Normal file
@@ -0,0 +1,320 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package windivert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Handle owns a WinDivert kernel device handle plus a private event for
|
||||||
|
// overlapped I/O. Methods on *Handle are not safe for concurrent use
|
||||||
|
// across goroutines (there is a single shared event per Handle).
|
||||||
|
//
|
||||||
|
// addr is a per-Handle Address buffer the IOCTL struct embeds a pointer
|
||||||
|
// to. It lives on the heap (as a field of a heap-allocated Handle) so
|
||||||
|
// the pointer value stored as bytes in the ioctl buffer remains valid
|
||||||
|
// across stack growth between buildIoctl* and the DeviceIoControl
|
||||||
|
// syscall — stack-local Address values are not safe for this pattern
|
||||||
|
// because Go's escape analysis does not see the pointer through the
|
||||||
|
// unsafe.Pointer → uintptr → bytes conversion.
|
||||||
|
type Handle struct {
|
||||||
|
device windows.Handle
|
||||||
|
event windows.Handle
|
||||||
|
closing sync.Once
|
||||||
|
closeErr error
|
||||||
|
addr Address
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter may be nil for "reject all", suitable for send-only handles.
|
||||||
|
// Requires Administrator on first call per process (installs the kernel
|
||||||
|
// driver via SCM); subsequent calls reuse the running driver.
|
||||||
|
func Open(filter *Filter, layer Layer, priority int16, flags Flag) (*Handle, error) {
|
||||||
|
err := validateOpenArgs(layer, priority, flags)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if filter == nil {
|
||||||
|
filter = reject()
|
||||||
|
}
|
||||||
|
filterBin, filterFlags, err := filter.encode()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
device, err := openDevice()
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, windows.ERROR_FILE_NOT_FOUND) &&
|
||||||
|
!errors.Is(err, windows.ERROR_PATH_NOT_FOUND) {
|
||||||
|
if errors.Is(err, windows.ERROR_ACCESS_DENIED) {
|
||||||
|
return nil, E.Cause(err, "windivert: open device (administrator required)")
|
||||||
|
}
|
||||||
|
return nil, E.Cause(err, "windivert: open device")
|
||||||
|
}
|
||||||
|
// Device node missing: kernel driver not loaded. Install + retry.
|
||||||
|
// Matches WinDivertOpen's lazy-install path; avoids racing StartService
|
||||||
|
// against a still-loaded driver whose SCM record is marked for deletion.
|
||||||
|
err = ensureDriver()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
device, err = openDevice()
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, windows.ERROR_ACCESS_DENIED) {
|
||||||
|
return nil, E.Cause(err, "windivert: open device (administrator required)")
|
||||||
|
}
|
||||||
|
return nil, E.Cause(err, "windivert: open device")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
event, err := windows.CreateEvent(nil, 1, 0, nil) // manual reset, unsignaled
|
||||||
|
if err != nil {
|
||||||
|
windows.CloseHandle(device)
|
||||||
|
return nil, E.Cause(err, "windivert: create event")
|
||||||
|
}
|
||||||
|
h := &Handle{device: device, event: event}
|
||||||
|
|
||||||
|
err = h.initialize(layer, priority, flags)
|
||||||
|
if err != nil {
|
||||||
|
h.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = h.startup(filterBin, filterFlags)
|
||||||
|
if err != nil {
|
||||||
|
h.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return h, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func openDevice() (windows.Handle, error) {
|
||||||
|
return windows.CreateFile(
|
||||||
|
driverDevName,
|
||||||
|
windows.GENERIC_READ|windows.GENERIC_WRITE,
|
||||||
|
0, nil,
|
||||||
|
windows.OPEN_EXISTING,
|
||||||
|
windows.FILE_ATTRIBUTE_NORMAL|windows.FILE_FLAG_OVERLAPPED,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateOpenArgs(layer Layer, priority int16, flags Flag) error {
|
||||||
|
if layer != LayerNetwork {
|
||||||
|
return E.New("windivert: invalid layer ", uint32(layer))
|
||||||
|
}
|
||||||
|
if priority < PriorityLowest || priority > PriorityHighest {
|
||||||
|
return E.New("windivert: priority out of range")
|
||||||
|
}
|
||||||
|
if flags&^FlagSendOnly != 0 {
|
||||||
|
return E.New("windivert: unknown flag bits")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handle) initialize(layer Layer, priority int16, flags Flag) error {
|
||||||
|
in := buildIoctlInitialize(layer, priority, flags)
|
||||||
|
// WINDIVERT_VERSION is a 64-byte packed struct; only the first 20
|
||||||
|
// bytes (magic, major, minor, bits) carry data, the rest is reserved.
|
||||||
|
var outBuf [versionStructSize]byte
|
||||||
|
binary.LittleEndian.PutUint64(outBuf[0:8], magicDLL)
|
||||||
|
binary.LittleEndian.PutUint32(outBuf[8:12], versionMajor)
|
||||||
|
binary.LittleEndian.PutUint32(outBuf[12:16], versionMinor)
|
||||||
|
binary.LittleEndian.PutUint32(outBuf[16:20], uint32(unsafe.Sizeof(uintptr(0))*8))
|
||||||
|
_, err := doIoctl(h.device, ioctlInitialize, in[:], outBuf[:], h.event)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "windivert: initialize ioctl")
|
||||||
|
}
|
||||||
|
gotMagic := binary.LittleEndian.Uint64(outBuf[0:8])
|
||||||
|
if gotMagic != magicSYS {
|
||||||
|
return E.New("windivert: driver magic mismatch (got ", gotMagic, ")")
|
||||||
|
}
|
||||||
|
gotMajor := binary.LittleEndian.Uint32(outBuf[8:12])
|
||||||
|
if gotMajor < versionMajor {
|
||||||
|
gotMinor := binary.LittleEndian.Uint32(outBuf[12:16])
|
||||||
|
return E.New("windivert: driver version too old: ", gotMajor, ".", gotMinor)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handle) startup(filterBin []byte, filterFlags uint64) error {
|
||||||
|
in := buildIoctlStartup(filterFlags)
|
||||||
|
_, err := doIoctl(h.device, ioctlStartup, in[:], filterBin, h.event)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "windivert: startup ioctl")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the handle is closed mid-Recv the error wraps ERROR_OPERATION_ABORTED.
|
||||||
|
func (h *Handle) Recv(buf []byte) (int, Address, error) {
|
||||||
|
if len(buf) == 0 {
|
||||||
|
return 0, Address{}, E.New("windivert: recv: zero-length buffer")
|
||||||
|
}
|
||||||
|
h.addr = Address{}
|
||||||
|
in := buildIoctlRecv(&h.addr)
|
||||||
|
n, err := doIoctl(h.device, ioctlRecv, in[:], buf, h.event)
|
||||||
|
runtime.KeepAlive(h)
|
||||||
|
if err != nil {
|
||||||
|
return 0, Address{}, err
|
||||||
|
}
|
||||||
|
return int(n), h.addr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// The address's Outbound flag controls whether the packet is sent toward
|
||||||
|
// the wire (outbound=true) or delivered up the stack (outbound=false).
|
||||||
|
// IfIdx and SubIfIdx can stay zero — the driver uses the routing table
|
||||||
|
// when IfIdx=0.
|
||||||
|
func (h *Handle) Send(packet []byte, addr *Address) (int, error) {
|
||||||
|
if len(packet) == 0 {
|
||||||
|
return 0, E.New("windivert: send: empty packet")
|
||||||
|
}
|
||||||
|
if addr == nil {
|
||||||
|
return 0, E.New("windivert: send: nil address")
|
||||||
|
}
|
||||||
|
h.addr = *addr
|
||||||
|
in := buildIoctlSend(&h.addr)
|
||||||
|
n, err := doIoctl(h.device, ioctlSend, in[:], packet, h.event)
|
||||||
|
runtime.KeepAlive(h)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return int(n), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Idempotent. Aborts any in-flight I/O on the handle.
|
||||||
|
func (h *Handle) Close() error {
|
||||||
|
h.closing.Do(func() {
|
||||||
|
var errs []error
|
||||||
|
if h.device != 0 {
|
||||||
|
err := windows.CloseHandle(h.device)
|
||||||
|
if err != nil {
|
||||||
|
errs = append(errs, err)
|
||||||
|
}
|
||||||
|
h.device = 0
|
||||||
|
}
|
||||||
|
if h.event != 0 {
|
||||||
|
err := windows.CloseHandle(h.event)
|
||||||
|
if err != nil {
|
||||||
|
errs = append(errs, err)
|
||||||
|
}
|
||||||
|
h.event = 0
|
||||||
|
}
|
||||||
|
h.closeErr = E.Errors(errs...)
|
||||||
|
})
|
||||||
|
return h.closeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// IOCTL codes from windivert_device.h. CTL_CODE macro layout:
|
||||||
|
//
|
||||||
|
// (DeviceType << 16) | (Access << 14) | (Function << 2) | Method
|
||||||
|
const (
|
||||||
|
fileDeviceNetwork uint32 = 0x12
|
||||||
|
accessReadWrite uint32 = 3 // FILE_READ_DATA | FILE_WRITE_DATA
|
||||||
|
accessRead uint32 = 1
|
||||||
|
|
||||||
|
methodInDirect uint32 = 1
|
||||||
|
methodOutDirect uint32 = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
func ctlCode(deviceType, access, function, method uint32) uint32 {
|
||||||
|
return (deviceType << 16) | (access << 14) | (function << 2) | method
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ioctlInitialize = ctlCode(fileDeviceNetwork, accessReadWrite, 0x921, methodOutDirect)
|
||||||
|
ioctlStartup = ctlCode(fileDeviceNetwork, accessReadWrite, 0x922, methodInDirect)
|
||||||
|
ioctlRecv = ctlCode(fileDeviceNetwork, accessRead, 0x923, methodOutDirect)
|
||||||
|
ioctlSend = ctlCode(fileDeviceNetwork, accessReadWrite, 0x924, methodInDirect)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Magic numbers exchanged during INITIALIZE. DLL sends magicDLL in the
|
||||||
|
// version struct; driver returns magicSYS on success.
|
||||||
|
const (
|
||||||
|
magicDLL uint64 = 0x4C4C447669645724 // "$WdivDLL" in LE bytes
|
||||||
|
magicSYS uint64 = 0x5359537669645723 // "#WdivSYS" in LE bytes
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
versionMajor uint32 = 2
|
||||||
|
versionMinor uint32 = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
// Size of the WINDIVERT_IOCTL union on wire (packed).
|
||||||
|
const ioctlSize = 16
|
||||||
|
|
||||||
|
// Size of WINDIVERT_VERSION on wire (packed). Only the first 20 bytes
|
||||||
|
// carry data; the rest is reserved zero padding.
|
||||||
|
const versionStructSize = 64
|
||||||
|
|
||||||
|
// doIoctl performs a single synchronous (blocking) overlapped
|
||||||
|
// DeviceIoControl. The handle is opened with FILE_FLAG_OVERLAPPED so
|
||||||
|
// DeviceIoControl returns ERROR_IO_PENDING; we then wait for completion
|
||||||
|
// via GetOverlappedResult. Event is passed in so callers can reuse it
|
||||||
|
// across calls on the same handle (avoids per-call CreateEvent).
|
||||||
|
func doIoctl(handle windows.Handle, code uint32, in []byte, out []byte, event windows.Handle) (uint32, error) {
|
||||||
|
var overlapped windows.Overlapped
|
||||||
|
overlapped.HEvent = event
|
||||||
|
_ = windows.ResetEvent(event)
|
||||||
|
|
||||||
|
var inPtr *byte
|
||||||
|
var inLen uint32
|
||||||
|
if len(in) > 0 {
|
||||||
|
inPtr = &in[0]
|
||||||
|
inLen = uint32(len(in))
|
||||||
|
}
|
||||||
|
var outPtr *byte
|
||||||
|
var outLen uint32
|
||||||
|
if len(out) > 0 {
|
||||||
|
outPtr = &out[0]
|
||||||
|
outLen = uint32(len(out))
|
||||||
|
}
|
||||||
|
var returned uint32
|
||||||
|
err := windows.DeviceIoControl(handle, code, inPtr, inLen, outPtr, outLen, &returned, &overlapped)
|
||||||
|
if err != nil && !errors.Is(err, windows.ERROR_IO_PENDING) {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
err = windows.GetOverlappedResult(handle, &overlapped, &returned, true)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return returned, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildIoctlInitialize(layer Layer, priority int16, flags Flag) [ioctlSize]byte {
|
||||||
|
var buf [ioctlSize]byte
|
||||||
|
binary.LittleEndian.PutUint32(buf[0:4], uint32(layer))
|
||||||
|
// The driver expects priority + WINDIVERT_PRIORITY_HIGHEST (30000) so
|
||||||
|
// the low range maps to non-negative integers.
|
||||||
|
binary.LittleEndian.PutUint32(buf[4:8], uint32(int32(priority)+int32(PriorityHighest)))
|
||||||
|
binary.LittleEndian.PutUint64(buf[8:16], uint64(flags))
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildIoctlStartup(filterFlags uint64) [ioctlSize]byte {
|
||||||
|
var buf [ioctlSize]byte
|
||||||
|
binary.LittleEndian.PutUint64(buf[0:8], filterFlags)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildIoctlRecv packs a user-space pointer to a WINDIVERT_ADDRESS into
|
||||||
|
// the ioctl struct. The driver dereferences it to write the address for
|
||||||
|
// the received packet. Caller must keep the Address alive via
|
||||||
|
// runtime.KeepAlive.
|
||||||
|
func buildIoctlRecv(addr *Address) [ioctlSize]byte {
|
||||||
|
var buf [ioctlSize]byte
|
||||||
|
binary.LittleEndian.PutUint64(buf[0:8], uint64(uintptr(unsafe.Pointer(addr))))
|
||||||
|
binary.LittleEndian.PutUint64(buf[8:16], 0)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildIoctlSend(addr *Address) [ioctlSize]byte {
|
||||||
|
var buf [ioctlSize]byte
|
||||||
|
binary.LittleEndian.PutUint64(buf[0:8], uint64(uintptr(unsafe.Pointer(addr))))
|
||||||
|
binary.LittleEndian.PutUint64(buf[8:16], uint64(unsafe.Sizeof(Address{})))
|
||||||
|
return buf
|
||||||
|
}
|
||||||
106
common/windivert/handle_windows_test.go
Normal file
106
common/windivert/handle_windows_test.go
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package windivert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"testing"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CTL_CODE macro from Windows DDK:
|
||||||
|
//
|
||||||
|
// (DeviceType<<16) | (Access<<14) | (Function<<2) | Method
|
||||||
|
func TestCtlCodeMatchesDDK(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// FILE_DEVICE_NETWORK=0x12, FILE_READ_DATA|FILE_WRITE_DATA=3, METHOD_OUT_DIRECT=2
|
||||||
|
require.Equal(t, uint32(0x12E486), ctlCode(0x12, 3, 0x921, 2))
|
||||||
|
// FILE_READ_DATA=1, METHOD_OUT_DIRECT=2
|
||||||
|
require.Equal(t, uint32(0x12648E), ctlCode(0x12, 1, 0x923, 2))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Baked-in against windivert_device.h @ v2.2.2. A mismatch here means the
|
||||||
|
// kernel will reject every ioctl with ERROR_INVALID_FUNCTION.
|
||||||
|
func TestIoctlCodesMatchUpstream(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
require.Equal(t, uint32(0x12E486), ioctlInitialize)
|
||||||
|
require.Equal(t, uint32(0x12E489), ioctlStartup)
|
||||||
|
require.Equal(t, uint32(0x12648E), ioctlRecv)
|
||||||
|
require.Equal(t, uint32(0x12E491), ioctlSend)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildIoctlInitialize(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
buf := buildIoctlInitialize(LayerNetwork, 100, FlagSendOnly)
|
||||||
|
require.Equal(t, uint32(LayerNetwork), binary.LittleEndian.Uint32(buf[0:4]))
|
||||||
|
// Driver expects priority+PriorityHighest(30000) so the range is non-negative.
|
||||||
|
require.Equal(t, uint32(30100), binary.LittleEndian.Uint32(buf[4:8]))
|
||||||
|
require.Equal(t, uint64(FlagSendOnly), binary.LittleEndian.Uint64(buf[8:16]))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildIoctlInitializePriorityRange(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
lowest := buildIoctlInitialize(LayerNetwork, PriorityLowest, 0)
|
||||||
|
require.Equal(t, uint32(0), binary.LittleEndian.Uint32(lowest[4:8]))
|
||||||
|
highest := buildIoctlInitialize(LayerNetwork, PriorityHighest, 0)
|
||||||
|
require.Equal(t, uint32(60000), binary.LittleEndian.Uint32(highest[4:8]))
|
||||||
|
zero := buildIoctlInitialize(LayerNetwork, 0, 0)
|
||||||
|
require.Equal(t, uint32(30000), binary.LittleEndian.Uint32(zero[4:8]))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildIoctlStartup(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
flags := filterFlagOutbound | filterFlagIP
|
||||||
|
buf := buildIoctlStartup(flags)
|
||||||
|
require.Equal(t, flags, binary.LittleEndian.Uint64(buf[0:8]))
|
||||||
|
// The second quad-word is unused for STARTUP.
|
||||||
|
require.Equal(t, uint64(0), binary.LittleEndian.Uint64(buf[8:16]))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildIoctlRecvEmbedsAddressPointer(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
addr := &Address{Timestamp: 0xCAFEBABE}
|
||||||
|
buf := buildIoctlRecv(addr)
|
||||||
|
require.Equal(t, uint64(uintptr(unsafe.Pointer(addr))),
|
||||||
|
binary.LittleEndian.Uint64(buf[0:8]))
|
||||||
|
// RECV does not carry an address length; driver writes full Address back.
|
||||||
|
require.Equal(t, uint64(0), binary.LittleEndian.Uint64(buf[8:16]))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildIoctlSendEmbedsAddressPointerAndSize(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
addr := &Address{}
|
||||||
|
buf := buildIoctlSend(addr)
|
||||||
|
require.Equal(t, uint64(uintptr(unsafe.Pointer(addr))),
|
||||||
|
binary.LittleEndian.Uint64(buf[0:8]))
|
||||||
|
require.Equal(t, uint64(unsafe.Sizeof(Address{})),
|
||||||
|
binary.LittleEndian.Uint64(buf[8:16]))
|
||||||
|
require.Equal(t, uint64(80), binary.LittleEndian.Uint64(buf[8:16]))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateOpenArgsLayer(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
require.NoError(t, validateOpenArgs(LayerNetwork, 0, 0))
|
||||||
|
require.Error(t, validateOpenArgs(Layer(1), 0, 0))
|
||||||
|
require.Error(t, validateOpenArgs(Layer(42), 0, 0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateOpenArgsPriorityBounds(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
require.NoError(t, validateOpenArgs(LayerNetwork, PriorityHighest, 0))
|
||||||
|
require.NoError(t, validateOpenArgs(LayerNetwork, PriorityLowest, 0))
|
||||||
|
require.NoError(t, validateOpenArgs(LayerNetwork, 0, 0))
|
||||||
|
require.Error(t, validateOpenArgs(LayerNetwork, PriorityHighest+1, 0))
|
||||||
|
require.Error(t, validateOpenArgs(LayerNetwork, PriorityLowest-1, 0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateOpenArgsFlags(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
require.NoError(t, validateOpenArgs(LayerNetwork, 0, 0))
|
||||||
|
require.NoError(t, validateOpenArgs(LayerNetwork, 0, FlagSendOnly))
|
||||||
|
// Unknown flag bits must be rejected to surface caller mistakes early.
|
||||||
|
require.Error(t, validateOpenArgs(LayerNetwork, 0, Flag(0x10)))
|
||||||
|
require.Error(t, validateOpenArgs(LayerNetwork, 0, FlagSendOnly|Flag(0x10)))
|
||||||
|
}
|
||||||
88
common/windivert/integration_windows_test.go
Normal file
88
common/windivert/integration_windows_test.go
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package windivert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
func openHandle(t *testing.T, filter *Filter, flags Flag) *Handle {
|
||||||
|
t.Helper()
|
||||||
|
h, err := Open(filter, LayerNetwork, 0, flags)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
// A send-only handle installs+opens the driver but does not attach a
|
||||||
|
// receive filter, so it exercises the full driver-install path without
|
||||||
|
// diverting any live traffic on the host.
|
||||||
|
func TestIntegrationOpenSendOnly(t *testing.T) {
|
||||||
|
h := openHandle(t, nil, FlagSendOnly)
|
||||||
|
require.NoError(t, h.Close())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close is idempotent per the doc contract.
|
||||||
|
func TestIntegrationCloseTwice(t *testing.T) {
|
||||||
|
h := openHandle(t, nil, FlagSendOnly)
|
||||||
|
require.NoError(t, h.Close())
|
||||||
|
require.NoError(t, h.Close())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recv must unblock when the handle is closed concurrently. Without this,
|
||||||
|
// the spoofer's run goroutine could deadlock on shutdown.
|
||||||
|
func TestIntegrationRecvAbortsOnClose(t *testing.T) {
|
||||||
|
// A filter no live traffic will match, so Recv blocks indefinitely
|
||||||
|
// until Close aborts the overlapped I/O.
|
||||||
|
filter, err := OutboundTCP(
|
||||||
|
netip.MustParseAddrPort("10.255.255.254:1"),
|
||||||
|
netip.MustParseAddrPort("10.255.255.253:2"),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
h := openHandle(t, filter, 0)
|
||||||
|
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
buf := make([]byte, MTUMax)
|
||||||
|
_, _, recvErr := h.Recv(buf)
|
||||||
|
errCh <- recvErr
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Let Recv reach the blocking DeviceIoControl before Close races in.
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
require.NoError(t, h.Close())
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errCh:
|
||||||
|
require.Error(t, err)
|
||||||
|
require.True(t, errors.Is(err, windows.ERROR_OPERATION_ABORTED),
|
||||||
|
"Recv should return ERROR_OPERATION_ABORTED, got %v", err)
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
|
t.Fatal("Recv did not unblock within 3s after Close")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Two concurrent Open calls must both succeed: the first wins the driver
|
||||||
|
// install race, the second reuses the already-running service.
|
||||||
|
func TestIntegrationConcurrentOpen(t *testing.T) {
|
||||||
|
errCh := make(chan error, 2)
|
||||||
|
handles := make(chan *Handle, 2)
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
go func() {
|
||||||
|
h, err := Open(nil, LayerNetwork, 0, FlagSendOnly)
|
||||||
|
handles <- h
|
||||||
|
errCh <- err
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
err := <-errCh
|
||||||
|
h := <-handles
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, h.Close())
|
||||||
|
}
|
||||||
|
}
|
||||||
71
common/windivert/windivert.go
Normal file
71
common/windivert/windivert.go
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
// Package windivert provides a pure-Go binding to the WinDivert kernel
|
||||||
|
// driver on Windows (amd64 and 386). User-mode WinDivert calls are
|
||||||
|
// reimplemented in Go; only the signed kernel driver is embedded as an
|
||||||
|
// asset, since SCM-installed drivers must live on disk and their
|
||||||
|
// Authenticode signature forbids modification.
|
||||||
|
//
|
||||||
|
// Administrator is required for the first Open in a process so SCM can
|
||||||
|
// load the driver. Upstream: https://github.com/basil00/WinDivert v2.2.2,
|
||||||
|
// redistributed under its LGPL v3 option; see assets/LICENSE.txt.
|
||||||
|
package windivert
|
||||||
|
|
||||||
|
import "unsafe"
|
||||||
|
|
||||||
|
const AssetVersion = "2.2.2"
|
||||||
|
|
||||||
|
// MTUMax is WINDIVERT_MTU_MAX from windivert.h (40 + 0xFFFF). Suitable as
|
||||||
|
// a single-packet receive buffer size.
|
||||||
|
const MTUMax = 40 + 0xFFFF
|
||||||
|
|
||||||
|
type Layer uint32
|
||||||
|
|
||||||
|
const LayerNetwork Layer = 0
|
||||||
|
|
||||||
|
type Flag uint64
|
||||||
|
|
||||||
|
const FlagSendOnly Flag = 0x0008
|
||||||
|
|
||||||
|
const (
|
||||||
|
PriorityHighest int16 = 30000
|
||||||
|
PriorityLowest int16 = -30000
|
||||||
|
)
|
||||||
|
|
||||||
|
// Address mirrors WINDIVERT_ADDRESS from windivert.h (80 bytes,
|
||||||
|
// little-endian on both amd64 and 386):
|
||||||
|
//
|
||||||
|
// 0: INT64 Timestamp
|
||||||
|
// 8: UINT32 bitfield: Layer:8 | Event:8 | flags | Reserved1:8
|
||||||
|
// 12: UINT32 Reserved2
|
||||||
|
// 16: 64 bytes union (WINDIVERT_DATA_NETWORK / FLOW / SOCKET / REFLECT)
|
||||||
|
type Address struct {
|
||||||
|
Timestamp int64
|
||||||
|
bits uint32
|
||||||
|
Reserved2 uint32
|
||||||
|
union [64]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ [80]byte = [unsafe.Sizeof(Address{})]byte{}
|
||||||
|
|
||||||
|
// Bit positions inside the Address's packed flags word.
|
||||||
|
const (
|
||||||
|
addrBitIPv6 = 20
|
||||||
|
addrBitIPChecksum = 21
|
||||||
|
addrBitTCPChecksum = 22
|
||||||
|
)
|
||||||
|
|
||||||
|
func getFlagBit(bits uint32, pos uint) bool { return bits&(1<<pos) != 0 }
|
||||||
|
func setFlagBit(bits uint32, pos uint, v bool) uint32 {
|
||||||
|
if v {
|
||||||
|
return bits | (1 << pos)
|
||||||
|
}
|
||||||
|
return bits &^ (1 << pos)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Address) IPv6() bool { return getFlagBit(a.bits, addrBitIPv6) }
|
||||||
|
func (a *Address) SetIPChecksum(v bool) {
|
||||||
|
a.bits = setFlagBit(a.bits, addrBitIPChecksum, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Address) SetTCPChecksum(v bool) {
|
||||||
|
a.bits = setFlagBit(a.bits, addrBitTCPChecksum, v)
|
||||||
|
}
|
||||||
@@ -1,3 +1,8 @@
|
|||||||
package constant
|
package constant
|
||||||
|
|
||||||
const ACMETLS1Protocol = "acme-tls/1"
|
const ACMETLS1Protocol = "acme-tls/1"
|
||||||
|
|
||||||
|
const (
|
||||||
|
TLSEngineDefault = ""
|
||||||
|
TLSEngineApple = "apple"
|
||||||
|
)
|
||||||
|
|||||||
462
dns/client.go
462
dns/client.go
@@ -30,59 +30,63 @@ var (
|
|||||||
var _ adapter.DNSClient = (*Client)(nil)
|
var _ adapter.DNSClient = (*Client)(nil)
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
|
ctx context.Context
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
disableCache bool
|
disableCache bool
|
||||||
disableExpire bool
|
disableExpire bool
|
||||||
independentCache bool
|
optimisticTimeout time.Duration
|
||||||
|
cacheCapacity uint32
|
||||||
clientSubnet netip.Prefix
|
clientSubnet netip.Prefix
|
||||||
rdrc adapter.RDRCStore
|
rdrc adapter.RDRCStore
|
||||||
initRDRCFunc func() adapter.RDRCStore
|
initRDRCFunc func() adapter.RDRCStore
|
||||||
|
dnsCache adapter.DNSCacheStore
|
||||||
|
initDNSCacheFunc func() adapter.DNSCacheStore
|
||||||
logger logger.ContextLogger
|
logger logger.ContextLogger
|
||||||
cache freelru.Cache[dns.Question, *dns.Msg]
|
cache freelru.Cache[dnsCacheKey, *dns.Msg]
|
||||||
cacheLock compatible.Map[dns.Question, chan struct{}]
|
cacheLock compatible.Map[dnsCacheKey, chan struct{}]
|
||||||
transportCache freelru.Cache[transportCacheKey, *dns.Msg]
|
backgroundRefresh compatible.Map[dnsCacheKey, struct{}]
|
||||||
transportCacheLock compatible.Map[dns.Question, chan struct{}]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClientOptions struct {
|
type ClientOptions struct {
|
||||||
|
Context context.Context
|
||||||
Timeout time.Duration
|
Timeout time.Duration
|
||||||
DisableCache bool
|
DisableCache bool
|
||||||
DisableExpire bool
|
DisableExpire bool
|
||||||
IndependentCache bool
|
OptimisticTimeout time.Duration
|
||||||
CacheCapacity uint32
|
CacheCapacity uint32
|
||||||
ClientSubnet netip.Prefix
|
ClientSubnet netip.Prefix
|
||||||
RDRC func() adapter.RDRCStore
|
RDRC func() adapter.RDRCStore
|
||||||
|
DNSCache func() adapter.DNSCacheStore
|
||||||
Logger logger.ContextLogger
|
Logger logger.ContextLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClient(options ClientOptions) *Client {
|
func NewClient(options ClientOptions) *Client {
|
||||||
|
cacheCapacity := options.CacheCapacity
|
||||||
|
if cacheCapacity < 1024 {
|
||||||
|
cacheCapacity = 1024
|
||||||
|
}
|
||||||
client := &Client{
|
client := &Client{
|
||||||
|
ctx: options.Context,
|
||||||
timeout: options.Timeout,
|
timeout: options.Timeout,
|
||||||
disableCache: options.DisableCache,
|
disableCache: options.DisableCache,
|
||||||
disableExpire: options.DisableExpire,
|
disableExpire: options.DisableExpire,
|
||||||
independentCache: options.IndependentCache,
|
optimisticTimeout: options.OptimisticTimeout,
|
||||||
|
cacheCapacity: cacheCapacity,
|
||||||
clientSubnet: options.ClientSubnet,
|
clientSubnet: options.ClientSubnet,
|
||||||
initRDRCFunc: options.RDRC,
|
initRDRCFunc: options.RDRC,
|
||||||
|
initDNSCacheFunc: options.DNSCache,
|
||||||
logger: options.Logger,
|
logger: options.Logger,
|
||||||
}
|
}
|
||||||
if client.timeout == 0 {
|
if client.timeout == 0 {
|
||||||
client.timeout = C.DNSTimeout
|
client.timeout = C.DNSTimeout
|
||||||
}
|
}
|
||||||
cacheCapacity := options.CacheCapacity
|
if !client.disableCache && client.initDNSCacheFunc == nil {
|
||||||
if cacheCapacity < 1024 {
|
client.initializeMemoryCache()
|
||||||
cacheCapacity = 1024
|
|
||||||
}
|
|
||||||
if !client.disableCache {
|
|
||||||
if !client.independentCache {
|
|
||||||
client.cache = common.Must1(freelru.NewSharded[dns.Question, *dns.Msg](cacheCapacity, maphash.NewHasher[dns.Question]().Hash32))
|
|
||||||
} else {
|
|
||||||
client.transportCache = common.Must1(freelru.NewSharded[transportCacheKey, *dns.Msg](cacheCapacity, maphash.NewHasher[transportCacheKey]().Hash32))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return client
|
return client
|
||||||
}
|
}
|
||||||
|
|
||||||
type transportCacheKey struct {
|
type dnsCacheKey struct {
|
||||||
dns.Question
|
dns.Question
|
||||||
transportTag string
|
transportTag string
|
||||||
}
|
}
|
||||||
@@ -91,6 +95,19 @@ func (c *Client) Start() {
|
|||||||
if c.initRDRCFunc != nil {
|
if c.initRDRCFunc != nil {
|
||||||
c.rdrc = c.initRDRCFunc()
|
c.rdrc = c.initRDRCFunc()
|
||||||
}
|
}
|
||||||
|
if c.initDNSCacheFunc != nil {
|
||||||
|
c.dnsCache = c.initDNSCacheFunc()
|
||||||
|
}
|
||||||
|
if c.dnsCache == nil {
|
||||||
|
c.initializeMemoryCache()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) initializeMemoryCache() {
|
||||||
|
if c.disableCache || c.cache != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.cache = common.Must1(freelru.NewSharded[dnsCacheKey, *dns.Msg](c.cacheCapacity, maphash.NewHasher[dnsCacheKey]().Hash32))
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractNegativeTTL(response *dns.Msg) (uint32, bool) {
|
func extractNegativeTTL(response *dns.Msg) (uint32, bool) {
|
||||||
@@ -107,6 +124,37 @@ func extractNegativeTTL(response *dns.Msg) (uint32, bool) {
|
|||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func computeTimeToLive(response *dns.Msg) uint32 {
|
||||||
|
var timeToLive uint32
|
||||||
|
if len(response.Answer) == 0 {
|
||||||
|
if soaTTL, hasSOA := extractNegativeTTL(response); hasSOA {
|
||||||
|
return soaTTL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
|
||||||
|
for _, record := range recordList {
|
||||||
|
if record.Header().Rrtype == dns.TypeOPT {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if timeToLive == 0 || record.Header().Ttl > 0 && record.Header().Ttl < timeToLive {
|
||||||
|
timeToLive = record.Header().Ttl
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return timeToLive
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeTTL(response *dns.Msg, timeToLive uint32) {
|
||||||
|
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
|
||||||
|
for _, record := range recordList {
|
||||||
|
if record.Header().Rrtype == dns.TypeOPT {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
record.Header().Ttl = timeToLive
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, message *dns.Msg, options adapter.DNSQueryOptions, responseChecker func(response *dns.Msg) bool) (*dns.Msg, error) {
|
func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, message *dns.Msg, options adapter.DNSQueryOptions, responseChecker func(response *dns.Msg) bool) (*dns.Msg, error) {
|
||||||
if len(message.Question) == 0 {
|
if len(message.Question) == 0 {
|
||||||
if c.logger != nil {
|
if c.logger != nil {
|
||||||
@@ -121,13 +169,7 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
|
|||||||
}
|
}
|
||||||
return FixedResponseStatus(message, dns.RcodeSuccess), nil
|
return FixedResponseStatus(message, dns.RcodeSuccess), nil
|
||||||
}
|
}
|
||||||
clientSubnet := options.ClientSubnet
|
message = c.prepareExchangeMessage(message, options)
|
||||||
if !clientSubnet.IsValid() {
|
|
||||||
clientSubnet = c.clientSubnet
|
|
||||||
}
|
|
||||||
if clientSubnet.IsValid() {
|
|
||||||
message = SetClientSubnet(message, clientSubnet)
|
|
||||||
}
|
|
||||||
|
|
||||||
isSimpleRequest := len(message.Question) == 1 &&
|
isSimpleRequest := len(message.Question) == 1 &&
|
||||||
len(message.Ns) == 0 &&
|
len(message.Ns) == 0 &&
|
||||||
@@ -139,8 +181,8 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
|
|||||||
!options.ClientSubnet.IsValid()
|
!options.ClientSubnet.IsValid()
|
||||||
disableCache := !isSimpleRequest || c.disableCache || options.DisableCache
|
disableCache := !isSimpleRequest || c.disableCache || options.DisableCache
|
||||||
if !disableCache {
|
if !disableCache {
|
||||||
if c.cache != nil {
|
cacheKey := dnsCacheKey{Question: question, transportTag: transport.Tag()}
|
||||||
cond, loaded := c.cacheLock.LoadOrStore(question, make(chan struct{}))
|
cond, loaded := c.cacheLock.LoadOrStore(cacheKey, make(chan struct{}))
|
||||||
if loaded {
|
if loaded {
|
||||||
select {
|
select {
|
||||||
case <-cond:
|
case <-cond:
|
||||||
@@ -149,32 +191,24 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
defer func() {
|
defer func() {
|
||||||
c.cacheLock.Delete(question)
|
c.cacheLock.Delete(cacheKey)
|
||||||
close(cond)
|
close(cond)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
} else if c.transportCache != nil {
|
response, ttl, isStale := c.loadResponse(question, transport)
|
||||||
cond, loaded := c.transportCacheLock.LoadOrStore(question, make(chan struct{}))
|
|
||||||
if loaded {
|
|
||||||
select {
|
|
||||||
case <-cond:
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, ctx.Err()
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
defer func() {
|
|
||||||
c.transportCacheLock.Delete(question)
|
|
||||||
close(cond)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
response, ttl := c.loadResponse(question, transport)
|
|
||||||
if response != nil {
|
if response != nil {
|
||||||
|
if isStale && !options.DisableOptimisticCache {
|
||||||
|
c.backgroundRefreshDNS(transport, question, message.Copy(), options, responseChecker)
|
||||||
|
logOptimisticResponse(c.logger, ctx, response)
|
||||||
|
response.Id = message.Id
|
||||||
|
return response, nil
|
||||||
|
} else if !isStale {
|
||||||
logCachedResponse(c.logger, ctx, response, ttl)
|
logCachedResponse(c.logger, ctx, response, ttl)
|
||||||
response.Id = message.Id
|
response.Id = message.Id
|
||||||
return response, nil
|
return response, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
messageId := message.Id
|
messageId := message.Id
|
||||||
contextTransport, clientSubnetLoaded := transportTagFromContext(ctx)
|
contextTransport, clientSubnetLoaded := transportTagFromContext(ctx)
|
||||||
@@ -188,52 +222,10 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
|
|||||||
return nil, ErrResponseRejectedCached
|
return nil, ErrResponseRejectedCached
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ctx, cancel := context.WithTimeout(ctx, c.timeout)
|
response, err := c.exchangeToTransport(ctx, transport, message)
|
||||||
response, err := transport.Exchange(ctx, message)
|
|
||||||
cancel()
|
|
||||||
if err != nil {
|
|
||||||
var rcodeError RcodeError
|
|
||||||
if errors.As(err, &rcodeError) {
|
|
||||||
response = FixedResponseStatus(message, int(rcodeError))
|
|
||||||
} else {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/*if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA {
|
|
||||||
validResponse := response
|
|
||||||
loop:
|
|
||||||
for {
|
|
||||||
var (
|
|
||||||
addresses int
|
|
||||||
queryCNAME string
|
|
||||||
)
|
|
||||||
for _, rawRR := range validResponse.Answer {
|
|
||||||
switch rr := rawRR.(type) {
|
|
||||||
case *dns.A:
|
|
||||||
break loop
|
|
||||||
case *dns.AAAA:
|
|
||||||
break loop
|
|
||||||
case *dns.CNAME:
|
|
||||||
queryCNAME = rr.Target
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if queryCNAME == "" {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
exMessage := *message
|
|
||||||
exMessage.Question = []dns.Question{{
|
|
||||||
Name: queryCNAME,
|
|
||||||
Qtype: question.Qtype,
|
|
||||||
}}
|
|
||||||
validResponse, err = c.Exchange(ctx, transport, &exMessage, options, responseChecker)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
|
||||||
if validResponse != response {
|
|
||||||
response.Answer = append(response.Answer, validResponse.Answer...)
|
|
||||||
}
|
|
||||||
}*/
|
|
||||||
disableCache = disableCache || (response.Rcode != dns.RcodeSuccess && response.Rcode != dns.RcodeNameError)
|
disableCache = disableCache || (response.Rcode != dns.RcodeSuccess && response.Rcode != dns.RcodeNameError)
|
||||||
if responseChecker != nil {
|
if responseChecker != nil {
|
||||||
var rejected bool
|
var rejected bool
|
||||||
@@ -250,54 +242,7 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
|
|||||||
return response, ErrResponseRejected
|
return response, ErrResponseRejected
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if question.Qtype == dns.TypeHTTPS {
|
timeToLive := applyResponseOptions(question, response, options)
|
||||||
if options.Strategy == C.DomainStrategyIPv4Only || options.Strategy == C.DomainStrategyIPv6Only {
|
|
||||||
for _, rr := range response.Answer {
|
|
||||||
https, isHTTPS := rr.(*dns.HTTPS)
|
|
||||||
if !isHTTPS {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
content := https.SVCB
|
|
||||||
content.Value = common.Filter(content.Value, func(it dns.SVCBKeyValue) bool {
|
|
||||||
if options.Strategy == C.DomainStrategyIPv4Only {
|
|
||||||
return it.Key() != dns.SVCB_IPV6HINT
|
|
||||||
} else {
|
|
||||||
return it.Key() != dns.SVCB_IPV4HINT
|
|
||||||
}
|
|
||||||
})
|
|
||||||
https.SVCB = content
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var timeToLive uint32
|
|
||||||
if len(response.Answer) == 0 {
|
|
||||||
if soaTTL, hasSOA := extractNegativeTTL(response); hasSOA {
|
|
||||||
timeToLive = soaTTL
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if timeToLive == 0 {
|
|
||||||
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
|
|
||||||
for _, record := range recordList {
|
|
||||||
if record.Header().Rrtype == dns.TypeOPT {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if timeToLive == 0 || record.Header().Ttl > 0 && record.Header().Ttl < timeToLive {
|
|
||||||
timeToLive = record.Header().Ttl
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if options.RewriteTTL != nil {
|
|
||||||
timeToLive = *options.RewriteTTL
|
|
||||||
}
|
|
||||||
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
|
|
||||||
for _, record := range recordList {
|
|
||||||
if record.Header().Rrtype == dns.TypeOPT {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
record.Header().Ttl = timeToLive
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !disableCache {
|
if !disableCache {
|
||||||
c.storeCache(transport, question, response, timeToLive)
|
c.storeCache(transport, question, response, timeToLive)
|
||||||
}
|
}
|
||||||
@@ -363,8 +308,12 @@ func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, dom
|
|||||||
func (c *Client) ClearCache() {
|
func (c *Client) ClearCache() {
|
||||||
if c.cache != nil {
|
if c.cache != nil {
|
||||||
c.cache.Purge()
|
c.cache.Purge()
|
||||||
} else if c.transportCache != nil {
|
}
|
||||||
c.transportCache.Purge()
|
if c.dnsCache != nil {
|
||||||
|
err := c.dnsCache.ClearDNSCache()
|
||||||
|
if err != nil && c.logger != nil {
|
||||||
|
c.logger.Warn("clear DNS cache: ", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -380,24 +329,22 @@ func (c *Client) storeCache(transport adapter.DNSTransport, question dns.Questio
|
|||||||
if timeToLive == 0 {
|
if timeToLive == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if c.dnsCache != nil {
|
||||||
|
packed, err := message.Pack()
|
||||||
|
if err == nil {
|
||||||
|
expireAt := time.Now().Add(time.Second * time.Duration(timeToLive))
|
||||||
|
c.dnsCache.SaveDNSCacheAsync(transport.Tag(), question.Name, question.Qtype, packed, expireAt, c.logger)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if c.cache == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
key := dnsCacheKey{Question: question, transportTag: transport.Tag()}
|
||||||
if c.disableExpire {
|
if c.disableExpire {
|
||||||
if !c.independentCache {
|
c.cache.Add(key, message.Copy())
|
||||||
c.cache.Add(question, message.Copy())
|
|
||||||
} else {
|
} else {
|
||||||
c.transportCache.Add(transportCacheKey{
|
c.cache.AddWithLifetime(key, message.Copy(), time.Second*time.Duration(timeToLive))
|
||||||
Question: question,
|
|
||||||
transportTag: transport.Tag(),
|
|
||||||
}, message.Copy())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if !c.independentCache {
|
|
||||||
c.cache.AddWithLifetime(question, message.Copy(), time.Second*time.Duration(timeToLive))
|
|
||||||
} else {
|
|
||||||
c.transportCache.AddWithLifetime(transportCacheKey{
|
|
||||||
Question: question,
|
|
||||||
transportTag: transport.Tag(),
|
|
||||||
}, message.Copy(), time.Second*time.Duration(timeToLive))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -407,19 +354,19 @@ func (c *Client) lookupToExchange(ctx context.Context, transport adapter.DNSTran
|
|||||||
Qtype: qType,
|
Qtype: qType,
|
||||||
Qclass: dns.ClassINET,
|
Qclass: dns.ClassINET,
|
||||||
}
|
}
|
||||||
disableCache := c.disableCache || options.DisableCache
|
|
||||||
if !disableCache {
|
|
||||||
cachedAddresses, err := c.questionCache(question, transport)
|
|
||||||
if err != ErrNotCached {
|
|
||||||
return cachedAddresses, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
message := dns.Msg{
|
message := dns.Msg{
|
||||||
MsgHdr: dns.MsgHdr{
|
MsgHdr: dns.MsgHdr{
|
||||||
RecursionDesired: true,
|
RecursionDesired: true,
|
||||||
},
|
},
|
||||||
Question: []dns.Question{question},
|
Question: []dns.Question{question},
|
||||||
}
|
}
|
||||||
|
disableCache := c.disableCache || options.DisableCache
|
||||||
|
if !disableCache {
|
||||||
|
cachedAddresses, err := c.questionCache(ctx, transport, &message, options, responseChecker)
|
||||||
|
if err != ErrNotCached {
|
||||||
|
return cachedAddresses, err
|
||||||
|
}
|
||||||
|
}
|
||||||
response, err := c.Exchange(ctx, transport, &message, options, responseChecker)
|
response, err := c.Exchange(ctx, transport, &message, options, responseChecker)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -430,98 +377,177 @@ func (c *Client) lookupToExchange(ctx context.Context, transport adapter.DNSTran
|
|||||||
return MessageToAddresses(response), nil
|
return MessageToAddresses(response), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) questionCache(question dns.Question, transport adapter.DNSTransport) ([]netip.Addr, error) {
|
func (c *Client) questionCache(ctx context.Context, transport adapter.DNSTransport, message *dns.Msg, options adapter.DNSQueryOptions, responseChecker func(response *dns.Msg) bool) ([]netip.Addr, error) {
|
||||||
response, _ := c.loadResponse(question, transport)
|
question := message.Question[0]
|
||||||
|
response, _, isStale := c.loadResponse(question, transport)
|
||||||
if response == nil {
|
if response == nil {
|
||||||
return nil, ErrNotCached
|
return nil, ErrNotCached
|
||||||
}
|
}
|
||||||
|
if isStale {
|
||||||
|
if options.DisableOptimisticCache {
|
||||||
|
return nil, ErrNotCached
|
||||||
|
}
|
||||||
|
c.backgroundRefreshDNS(transport, question, c.prepareExchangeMessage(message.Copy(), options), options, responseChecker)
|
||||||
|
logOptimisticResponse(c.logger, ctx, response)
|
||||||
|
}
|
||||||
if response.Rcode != dns.RcodeSuccess {
|
if response.Rcode != dns.RcodeSuccess {
|
||||||
return nil, RcodeError(response.Rcode)
|
return nil, RcodeError(response.Rcode)
|
||||||
}
|
}
|
||||||
return MessageToAddresses(response), nil
|
return MessageToAddresses(response), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) loadResponse(question dns.Question, transport adapter.DNSTransport) (*dns.Msg, int) {
|
func (c *Client) loadResponse(question dns.Question, transport adapter.DNSTransport) (*dns.Msg, int, bool) {
|
||||||
var (
|
if c.dnsCache != nil {
|
||||||
response *dns.Msg
|
return c.loadPersistentResponse(question, transport)
|
||||||
loaded bool
|
}
|
||||||
)
|
if c.cache == nil {
|
||||||
|
return nil, 0, false
|
||||||
|
}
|
||||||
|
key := dnsCacheKey{Question: question, transportTag: transport.Tag()}
|
||||||
if c.disableExpire {
|
if c.disableExpire {
|
||||||
if !c.independentCache {
|
response, loaded := c.cache.Get(key)
|
||||||
response, loaded = c.cache.Get(question)
|
|
||||||
} else {
|
|
||||||
response, loaded = c.transportCache.Get(transportCacheKey{
|
|
||||||
Question: question,
|
|
||||||
transportTag: transport.Tag(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
if !loaded {
|
if !loaded {
|
||||||
return nil, 0
|
return nil, 0, false
|
||||||
}
|
}
|
||||||
return response.Copy(), 0
|
return response.Copy(), 0, false
|
||||||
} else {
|
|
||||||
var expireAt time.Time
|
|
||||||
if !c.independentCache {
|
|
||||||
response, expireAt, loaded = c.cache.GetWithLifetime(question)
|
|
||||||
} else {
|
|
||||||
response, expireAt, loaded = c.transportCache.GetWithLifetime(transportCacheKey{
|
|
||||||
Question: question,
|
|
||||||
transportTag: transport.Tag(),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
response, expireAt, loaded := c.cache.GetWithLifetimeNoExpire(key)
|
||||||
if !loaded {
|
if !loaded {
|
||||||
return nil, 0
|
return nil, 0, false
|
||||||
}
|
}
|
||||||
timeNow := time.Now()
|
timeNow := time.Now()
|
||||||
if timeNow.After(expireAt) {
|
if timeNow.After(expireAt) {
|
||||||
if !c.independentCache {
|
if c.optimisticTimeout > 0 && timeNow.Before(expireAt.Add(c.optimisticTimeout)) {
|
||||||
c.cache.Remove(question)
|
response = response.Copy()
|
||||||
} else {
|
normalizeTTL(response, 1)
|
||||||
c.transportCache.Remove(transportCacheKey{
|
return response, 0, true
|
||||||
Question: question,
|
|
||||||
transportTag: transport.Tag(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return nil, 0
|
|
||||||
}
|
|
||||||
var originTTL int
|
|
||||||
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
|
|
||||||
for _, record := range recordList {
|
|
||||||
if record.Header().Rrtype == dns.TypeOPT {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if originTTL == 0 || record.Header().Ttl > 0 && int(record.Header().Ttl) < originTTL {
|
|
||||||
originTTL = int(record.Header().Ttl)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
c.cache.Remove(key)
|
||||||
|
return nil, 0, false
|
||||||
}
|
}
|
||||||
nowTTL := int(expireAt.Sub(timeNow).Seconds())
|
nowTTL := int(expireAt.Sub(timeNow).Seconds())
|
||||||
if nowTTL < 0 {
|
if nowTTL < 0 {
|
||||||
nowTTL = 0
|
nowTTL = 0
|
||||||
}
|
}
|
||||||
response = response.Copy()
|
response = response.Copy()
|
||||||
if originTTL > 0 {
|
normalizeTTL(response, uint32(nowTTL))
|
||||||
duration := uint32(originTTL - nowTTL)
|
return response, nowTTL, false
|
||||||
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
|
}
|
||||||
for _, record := range recordList {
|
|
||||||
if record.Header().Rrtype == dns.TypeOPT {
|
func (c *Client) loadPersistentResponse(question dns.Question, transport adapter.DNSTransport) (*dns.Msg, int, bool) {
|
||||||
|
rawMessage, expireAt, loaded := c.dnsCache.LoadDNSCache(transport.Tag(), question.Name, question.Qtype)
|
||||||
|
if !loaded {
|
||||||
|
return nil, 0, false
|
||||||
|
}
|
||||||
|
response := new(dns.Msg)
|
||||||
|
err := response.Unpack(rawMessage)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, false
|
||||||
|
}
|
||||||
|
if c.disableExpire {
|
||||||
|
return response, 0, false
|
||||||
|
}
|
||||||
|
timeNow := time.Now()
|
||||||
|
if timeNow.After(expireAt) {
|
||||||
|
if c.optimisticTimeout > 0 && timeNow.Before(expireAt.Add(c.optimisticTimeout)) {
|
||||||
|
normalizeTTL(response, 1)
|
||||||
|
return response, 0, true
|
||||||
|
}
|
||||||
|
return nil, 0, false
|
||||||
|
}
|
||||||
|
nowTTL := int(expireAt.Sub(timeNow).Seconds())
|
||||||
|
if nowTTL < 0 {
|
||||||
|
nowTTL = 0
|
||||||
|
}
|
||||||
|
normalizeTTL(response, uint32(nowTTL))
|
||||||
|
return response, nowTTL, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyResponseOptions(question dns.Question, response *dns.Msg, options adapter.DNSQueryOptions) uint32 {
|
||||||
|
if question.Qtype == dns.TypeHTTPS && (options.Strategy == C.DomainStrategyIPv4Only || options.Strategy == C.DomainStrategyIPv6Only) {
|
||||||
|
for _, rr := range response.Answer {
|
||||||
|
https, isHTTPS := rr.(*dns.HTTPS)
|
||||||
|
if !isHTTPS {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
record.Header().Ttl = record.Header().Ttl - duration
|
content := https.SVCB
|
||||||
|
content.Value = common.Filter(content.Value, func(it dns.SVCBKeyValue) bool {
|
||||||
|
if options.Strategy == C.DomainStrategyIPv4Only {
|
||||||
|
return it.Key() != dns.SVCB_IPV6HINT
|
||||||
|
}
|
||||||
|
return it.Key() != dns.SVCB_IPV4HINT
|
||||||
|
})
|
||||||
|
https.SVCB = content
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
timeToLive := computeTimeToLive(response)
|
||||||
|
if options.RewriteTTL != nil {
|
||||||
|
timeToLive = *options.RewriteTTL
|
||||||
|
}
|
||||||
|
normalizeTTL(response, timeToLive)
|
||||||
|
return timeToLive
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) backgroundRefreshDNS(transport adapter.DNSTransport, question dns.Question, message *dns.Msg, options adapter.DNSQueryOptions, responseChecker func(response *dns.Msg) bool) {
|
||||||
|
key := dnsCacheKey{Question: question, transportTag: transport.Tag()}
|
||||||
|
_, loaded := c.backgroundRefresh.LoadOrStore(key, struct{}{})
|
||||||
|
if loaded {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
defer c.backgroundRefresh.Delete(key)
|
||||||
|
ctx := contextWithTransportTag(c.ctx, transport.Tag())
|
||||||
|
response, err := c.exchangeToTransport(ctx, transport, message)
|
||||||
|
if err != nil {
|
||||||
|
if c.logger != nil {
|
||||||
|
c.logger.Debug("optimistic refresh failed for ", FqdnToDomain(question.Name), ": ", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if responseChecker != nil {
|
||||||
|
var rejected bool
|
||||||
|
if response.Rcode != dns.RcodeSuccess && response.Rcode != dns.RcodeNameError {
|
||||||
|
rejected = true
|
||||||
} else {
|
} else {
|
||||||
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
|
rejected = !responseChecker(response)
|
||||||
for _, record := range recordList {
|
|
||||||
if record.Header().Rrtype == dns.TypeOPT {
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
record.Header().Ttl = uint32(nowTTL)
|
if rejected {
|
||||||
|
if c.rdrc != nil {
|
||||||
|
c.rdrc.SaveRDRCAsync(transport.Tag(), question.Name, question.Qtype, c.logger)
|
||||||
}
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
} else if response.Rcode != dns.RcodeSuccess && response.Rcode != dns.RcodeNameError {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
return response, nowTTL
|
timeToLive := applyResponseOptions(question, response, options)
|
||||||
|
c.storeCache(transport, question, response, timeToLive)
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) prepareExchangeMessage(message *dns.Msg, options adapter.DNSQueryOptions) *dns.Msg {
|
||||||
|
clientSubnet := options.ClientSubnet
|
||||||
|
if !clientSubnet.IsValid() {
|
||||||
|
clientSubnet = c.clientSubnet
|
||||||
|
}
|
||||||
|
if clientSubnet.IsValid() {
|
||||||
|
message = SetClientSubnet(message, clientSubnet)
|
||||||
|
}
|
||||||
|
return message
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) exchangeToTransport(ctx context.Context, transport adapter.DNSTransport, message *dns.Msg) (*dns.Msg, error) {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, c.timeout)
|
||||||
|
defer cancel()
|
||||||
|
response, err := transport.Exchange(ctx, message)
|
||||||
|
if err == nil {
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
var rcodeError RcodeError
|
||||||
|
if errors.As(err, &rcodeError) {
|
||||||
|
return FixedResponseStatus(message, int(rcodeError)), nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func MessageToAddresses(response *dns.Msg) []netip.Addr {
|
func MessageToAddresses(response *dns.Msg) []netip.Addr {
|
||||||
|
|||||||
@@ -22,6 +22,19 @@ func logCachedResponse(logger logger.ContextLogger, ctx context.Context, respons
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func logOptimisticResponse(logger logger.ContextLogger, ctx context.Context, response *dns.Msg) {
|
||||||
|
if logger == nil || len(response.Question) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
domain := FqdnToDomain(response.Question[0].Name)
|
||||||
|
logger.DebugContext(ctx, "optimistic ", domain, " ", dns.RcodeToString[response.Rcode])
|
||||||
|
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
|
||||||
|
for _, record := range recordList {
|
||||||
|
logger.InfoContext(ctx, "optimistic ", dns.Type(record.Header().Rrtype).String(), " ", FormatQuestion(record.String()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func logExchangedResponse(logger logger.ContextLogger, ctx context.Context, response *dns.Msg, ttl uint32) {
|
func logExchangedResponse(logger logger.ContextLogger, ctx context.Context, response *dns.Msg, ttl uint32) {
|
||||||
if logger == nil || len(response.Question) == 0 {
|
if logger == nil || len(response.Question) == 0 {
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ type Router struct {
|
|||||||
closing bool
|
closing bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOptions) *Router {
|
func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOptions) (*Router, error) {
|
||||||
router := &Router{
|
router := &Router{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
logger: logFactory.NewLogger("dns"),
|
logger: logFactory.NewLogger("dns"),
|
||||||
@@ -61,10 +61,28 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOp
|
|||||||
rules: make([]adapter.DNSRule, 0, len(options.Rules)),
|
rules: make([]adapter.DNSRule, 0, len(options.Rules)),
|
||||||
defaultDomainStrategy: C.DomainStrategy(options.Strategy),
|
defaultDomainStrategy: C.DomainStrategy(options.Strategy),
|
||||||
}
|
}
|
||||||
|
if options.DNSClientOptions.IndependentCache {
|
||||||
|
deprecated.Report(ctx, deprecated.OptionIndependentDNSCache)
|
||||||
|
}
|
||||||
|
var optimisticTimeout time.Duration
|
||||||
|
optimisticOptions := common.PtrValueOrDefault(options.DNSClientOptions.Optimistic)
|
||||||
|
if optimisticOptions.Enabled {
|
||||||
|
if options.DNSClientOptions.DisableCache {
|
||||||
|
return nil, E.New("`optimistic` is conflict with `disable_cache`")
|
||||||
|
}
|
||||||
|
if options.DNSClientOptions.DisableExpire {
|
||||||
|
return nil, E.New("`optimistic` is conflict with `disable_expire`")
|
||||||
|
}
|
||||||
|
optimisticTimeout = time.Duration(optimisticOptions.Timeout)
|
||||||
|
if optimisticTimeout == 0 {
|
||||||
|
optimisticTimeout = 3 * 24 * time.Hour
|
||||||
|
}
|
||||||
|
}
|
||||||
router.client = NewClient(ClientOptions{
|
router.client = NewClient(ClientOptions{
|
||||||
|
Context: ctx,
|
||||||
DisableCache: options.DNSClientOptions.DisableCache,
|
DisableCache: options.DNSClientOptions.DisableCache,
|
||||||
DisableExpire: options.DNSClientOptions.DisableExpire,
|
DisableExpire: options.DNSClientOptions.DisableExpire,
|
||||||
IndependentCache: options.DNSClientOptions.IndependentCache,
|
OptimisticTimeout: optimisticTimeout,
|
||||||
CacheCapacity: options.DNSClientOptions.CacheCapacity,
|
CacheCapacity: options.DNSClientOptions.CacheCapacity,
|
||||||
ClientSubnet: options.DNSClientOptions.ClientSubnet.Build(netip.Prefix{}),
|
ClientSubnet: options.DNSClientOptions.ClientSubnet.Build(netip.Prefix{}),
|
||||||
RDRC: func() adapter.RDRCStore {
|
RDRC: func() adapter.RDRCStore {
|
||||||
@@ -77,12 +95,24 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOp
|
|||||||
}
|
}
|
||||||
return cacheFile
|
return cacheFile
|
||||||
},
|
},
|
||||||
|
DNSCache: func() adapter.DNSCacheStore {
|
||||||
|
cacheFile := service.FromContext[adapter.CacheFile](ctx)
|
||||||
|
if cacheFile == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !cacheFile.StoreDNS() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cacheFile.SetDisableExpire(options.DNSClientOptions.DisableExpire)
|
||||||
|
cacheFile.SetOptimisticTimeout(optimisticTimeout)
|
||||||
|
return cacheFile
|
||||||
|
},
|
||||||
Logger: router.logger,
|
Logger: router.logger,
|
||||||
})
|
})
|
||||||
if options.ReverseMapping {
|
if options.ReverseMapping {
|
||||||
router.dnsReverseMapping = common.Must1(freelru.NewSharded[netip.Addr, string](1024, maphash.NewHasher[netip.Addr]().Hash32))
|
router.dnsReverseMapping = common.Must1(freelru.NewSharded[netip.Addr, string](1024, maphash.NewHasher[netip.Addr]().Hash32))
|
||||||
}
|
}
|
||||||
return router
|
return router, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Router) Initialize(rules []option.DNSRule) error {
|
func (r *Router) Initialize(rules []option.DNSRule) error {
|
||||||
@@ -156,7 +186,7 @@ func (r *Router) buildRules(startRules bool) ([]adapter.DNSRule, bool, dnsRuleMo
|
|||||||
return nil, false, dnsRuleModeFlags{}, err
|
return nil, false, dnsRuleModeFlags{}, err
|
||||||
}
|
}
|
||||||
if !legacyDNSMode {
|
if !legacyDNSMode {
|
||||||
err = validateLegacyDNSModeDisabledRules(r.rawRules)
|
err = validateLegacyDNSModeDisabledRules(router, r.rawRules, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, false, dnsRuleModeFlags{}, err
|
return nil, false, dnsRuleModeFlags{}, err
|
||||||
}
|
}
|
||||||
@@ -218,7 +248,7 @@ func (r *Router) ValidateRuleSetMetadataUpdate(tag string, metadata adapter.Rule
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if !candidateLegacyDNSMode {
|
if !candidateLegacyDNSMode {
|
||||||
return validateLegacyDNSModeDisabledRules(r.rawRules)
|
return validateLegacyDNSModeDisabledRules(router, r.rawRules, overrides)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -228,7 +258,7 @@ func (r *Router) ValidateRuleSetMetadataUpdate(tag string, metadata adapter.Rule
|
|||||||
}
|
}
|
||||||
if legacyDNSMode {
|
if legacyDNSMode {
|
||||||
if !candidateLegacyDNSMode && flags.disabled {
|
if !candidateLegacyDNSMode && flags.disabled {
|
||||||
err := validateLegacyDNSModeDisabledRules(r.rawRules)
|
err := validateLegacyDNSModeDisabledRules(router, r.rawRules, overrides)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -239,7 +269,7 @@ func (r *Router) ValidateRuleSetMetadataUpdate(tag string, metadata adapter.Rule
|
|||||||
if candidateLegacyDNSMode {
|
if candidateLegacyDNSMode {
|
||||||
return E.New(deprecated.OptionLegacyDNSAddressFilter.MessageWithLink())
|
return E.New(deprecated.OptionLegacyDNSAddressFilter.MessageWithLink())
|
||||||
}
|
}
|
||||||
return nil
|
return validateLegacyDNSModeDisabledRules(router, r.rawRules, overrides)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Router) matchDNS(ctx context.Context, rules []adapter.DNSRule, allowFakeIP bool, ruleIndex int, isAddressQuery bool, options *adapter.DNSQueryOptions) (adapter.DNSTransport, adapter.DNSRule, int) {
|
func (r *Router) matchDNS(ctx context.Context, rules []adapter.DNSRule, allowFakeIP bool, ruleIndex int, isAddressQuery bool, options *adapter.DNSQueryOptions) (adapter.DNSTransport, adapter.DNSRule, int) {
|
||||||
@@ -319,6 +349,9 @@ func (r *Router) applyDNSRouteOptions(options *adapter.DNSQueryOptions, routeOpt
|
|||||||
if routeOptions.DisableCache {
|
if routeOptions.DisableCache {
|
||||||
options.DisableCache = true
|
options.DisableCache = true
|
||||||
}
|
}
|
||||||
|
if routeOptions.DisableOptimisticCache {
|
||||||
|
options.DisableOptimisticCache = true
|
||||||
|
}
|
||||||
if routeOptions.RewriteTTL != nil {
|
if routeOptions.RewriteTTL != nil {
|
||||||
options.RewriteTTL = routeOptions.RewriteTTL
|
options.RewriteTTL = routeOptions.RewriteTTL
|
||||||
}
|
}
|
||||||
@@ -907,7 +940,9 @@ func dnsRuleModeRequirementsInRule(router adapter.Router, rule option.DNSRule, m
|
|||||||
return dnsRuleModeRequirementsInDefaultRule(router, rule.DefaultOptions, metadataOverrides)
|
return dnsRuleModeRequirementsInDefaultRule(router, rule.DefaultOptions, metadataOverrides)
|
||||||
case C.RuleTypeLogical:
|
case C.RuleTypeLogical:
|
||||||
flags := dnsRuleModeFlags{
|
flags := dnsRuleModeFlags{
|
||||||
disabled: dnsRuleActionType(rule) == C.RuleActionTypeEvaluate || dnsRuleActionType(rule) == C.RuleActionTypeRespond,
|
disabled: dnsRuleActionType(rule) == C.RuleActionTypeEvaluate ||
|
||||||
|
dnsRuleActionType(rule) == C.RuleActionTypeRespond ||
|
||||||
|
dnsRuleActionDisablesLegacyDNSMode(rule.LogicalOptions.DNSRuleAction),
|
||||||
neededFromStrategy: dnsRuleActionHasStrategy(rule.LogicalOptions.DNSRuleAction),
|
neededFromStrategy: dnsRuleActionHasStrategy(rule.LogicalOptions.DNSRuleAction),
|
||||||
}
|
}
|
||||||
flags.needed = flags.neededFromStrategy
|
flags.needed = flags.neededFromStrategy
|
||||||
@@ -926,7 +961,7 @@ func dnsRuleModeRequirementsInRule(router adapter.Router, rule option.DNSRule, m
|
|||||||
|
|
||||||
func dnsRuleModeRequirementsInDefaultRule(router adapter.Router, rule option.DefaultDNSRule, metadataOverrides map[string]adapter.RuleSetMetadata) (dnsRuleModeFlags, error) {
|
func dnsRuleModeRequirementsInDefaultRule(router adapter.Router, rule option.DefaultDNSRule, metadataOverrides map[string]adapter.RuleSetMetadata) (dnsRuleModeFlags, error) {
|
||||||
flags := dnsRuleModeFlags{
|
flags := dnsRuleModeFlags{
|
||||||
disabled: defaultRuleDisablesLegacyDNSMode(rule),
|
disabled: defaultRuleDisablesLegacyDNSMode(rule) || dnsRuleActionDisablesLegacyDNSMode(rule.DNSRuleAction),
|
||||||
neededFromStrategy: dnsRuleActionHasStrategy(rule.DNSRuleAction),
|
neededFromStrategy: dnsRuleActionHasStrategy(rule.DNSRuleAction),
|
||||||
}
|
}
|
||||||
flags.needed = defaultRuleNeedsLegacyDNSModeFromAddressFilter(rule) || flags.neededFromStrategy
|
flags.needed = defaultRuleNeedsLegacyDNSModeFromAddressFilter(rule) || flags.neededFromStrategy
|
||||||
@@ -990,10 +1025,10 @@ func referencedDNSRuleSetTags(rules []option.DNSRule) []string {
|
|||||||
return tags
|
return tags
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateLegacyDNSModeDisabledRules(rules []option.DNSRule) error {
|
func validateLegacyDNSModeDisabledRules(router adapter.Router, rules []option.DNSRule, metadataOverrides map[string]adapter.RuleSetMetadata) error {
|
||||||
var seenEvaluate bool
|
var seenEvaluate bool
|
||||||
for i, rule := range rules {
|
for i, rule := range rules {
|
||||||
requiresPriorEvaluate, err := validateLegacyDNSModeDisabledRuleTree(rule)
|
requiresPriorEvaluate, err := validateLegacyDNSModeDisabledRuleTree(router, rule, metadataOverrides)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "validate dns rule[", i, "]")
|
return E.Cause(err, "validate dns rule[", i, "]")
|
||||||
}
|
}
|
||||||
@@ -1028,14 +1063,14 @@ func validateEvaluateFakeIPRules(rules []option.DNSRule, transportManager adapte
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateLegacyDNSModeDisabledRuleTree(rule option.DNSRule) (bool, error) {
|
func validateLegacyDNSModeDisabledRuleTree(router adapter.Router, rule option.DNSRule, metadataOverrides map[string]adapter.RuleSetMetadata) (bool, error) {
|
||||||
switch rule.Type {
|
switch rule.Type {
|
||||||
case "", C.RuleTypeDefault:
|
case "", C.RuleTypeDefault:
|
||||||
return validateLegacyDNSModeDisabledDefaultRule(rule.DefaultOptions)
|
return validateLegacyDNSModeDisabledDefaultRule(router, rule.DefaultOptions, metadataOverrides)
|
||||||
case C.RuleTypeLogical:
|
case C.RuleTypeLogical:
|
||||||
requiresPriorEvaluate := dnsRuleActionType(rule) == C.RuleActionTypeRespond
|
requiresPriorEvaluate := dnsRuleActionType(rule) == C.RuleActionTypeRespond
|
||||||
for i, subRule := range rule.LogicalOptions.Rules {
|
for i, subRule := range rule.LogicalOptions.Rules {
|
||||||
subRequiresPriorEvaluate, err := validateLegacyDNSModeDisabledRuleTree(subRule)
|
subRequiresPriorEvaluate, err := validateLegacyDNSModeDisabledRuleTree(router, subRule, metadataOverrides)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, E.Cause(err, "sub rule[", i, "]")
|
return false, E.Cause(err, "sub rule[", i, "]")
|
||||||
}
|
}
|
||||||
@@ -1047,22 +1082,42 @@ func validateLegacyDNSModeDisabledRuleTree(rule option.DNSRule) (bool, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateLegacyDNSModeDisabledDefaultRule(rule option.DefaultDNSRule) (bool, error) {
|
func validateLegacyDNSModeDisabledDefaultRule(router adapter.Router, rule option.DefaultDNSRule, metadataOverrides map[string]adapter.RuleSetMetadata) (bool, error) {
|
||||||
hasResponseRecords := hasResponseMatchFields(rule)
|
hasResponseRecords := hasResponseMatchFields(rule)
|
||||||
if (hasResponseRecords || len(rule.IPCIDR) > 0 || rule.IPIsPrivate || rule.IPAcceptAny) && !rule.MatchResponse {
|
if (hasResponseRecords || len(rule.IPCIDR) > 0 || rule.IPIsPrivate || rule.IPAcceptAny) && !rule.MatchResponse {
|
||||||
return false, E.New("Response Match Fields (ip_cidr, ip_is_private, ip_accept_any, response_rcode, response_answer, response_ns, response_extra) require match_response to be enabled")
|
return false, E.New("Response Match Fields (ip_cidr, ip_is_private, ip_accept_any, response_rcode, response_answer, response_ns, response_extra) require match_response to be enabled")
|
||||||
}
|
}
|
||||||
// Intentionally do not reject rule_set here. A referenced rule set may mix
|
// rule_set entries are only rejected when every referenced set is pure-IP;
|
||||||
// destination-IP predicates with pre-response predicates such as domain items.
|
// mixed sets still fall through because their non-IP branches remain matchable
|
||||||
// When match_response is false, those destination-IP branches fail closed during
|
// before a DNS response is available.
|
||||||
// pre-response evaluation instead of consuming DNS response state, while sibling
|
if !rule.MatchResponse && len(rule.RuleSet) > 0 {
|
||||||
// non-response branches remain matchable.
|
for _, tag := range rule.RuleSet {
|
||||||
|
metadata, err := lookupDNSRuleSetMetadata(router, tag, metadataOverrides)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if metadata.ContainsIPCIDRRule && !metadata.ContainsNonIPCIDRRule {
|
||||||
|
return false, E.New(deprecated.OptionLegacyDNSAddressFilter.MessageWithLink())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
if rule.RuleSetIPCIDRAcceptEmpty { //nolint:staticcheck
|
if rule.RuleSetIPCIDRAcceptEmpty { //nolint:staticcheck
|
||||||
return false, E.New(deprecated.OptionRuleSetIPCIDRAcceptEmpty.MessageWithLink())
|
return false, E.New(deprecated.OptionRuleSetIPCIDRAcceptEmpty.MessageWithLink())
|
||||||
}
|
}
|
||||||
return rule.MatchResponse || rule.Action == C.RuleActionTypeRespond, nil
|
return rule.MatchResponse || rule.Action == C.RuleActionTypeRespond, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func dnsRuleActionDisablesLegacyDNSMode(action option.DNSRuleAction) bool {
|
||||||
|
switch action.Action {
|
||||||
|
case "", C.RuleActionTypeRoute, C.RuleActionTypeEvaluate:
|
||||||
|
return action.RouteOptions.DisableOptimisticCache
|
||||||
|
case C.RuleActionTypeRouteOptions:
|
||||||
|
return action.RouteOptionsOptions.DisableOptimisticCache
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func dnsRuleActionHasStrategy(action option.DNSRuleAction) bool {
|
func dnsRuleActionHasStrategy(action option.DNSRuleAction) bool {
|
||||||
switch action.Action {
|
switch action.Action {
|
||||||
case "", C.RuleActionTypeRoute, C.RuleActionTypeEvaluate:
|
case "", C.RuleActionTypeRoute, C.RuleActionTypeEvaluate:
|
||||||
|
|||||||
@@ -762,6 +762,7 @@ func TestValidateRuleSetMetadataUpdateAllowsRuleSetThatKeepsNonLegacyDNSMode(t *
|
|||||||
|
|
||||||
err := router.ValidateRuleSetMetadataUpdate("dynamic-set", adapter.RuleSetMetadata{
|
err := router.ValidateRuleSetMetadataUpdate("dynamic-set", adapter.RuleSetMetadata{
|
||||||
ContainsIPCIDRRule: true,
|
ContainsIPCIDRRule: true,
|
||||||
|
ContainsNonIPCIDRRule: true,
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
@@ -808,6 +809,163 @@ func TestValidateRuleSetMetadataUpdateAllowsRelaxingLegacyRequirement(t *testing
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestInitializeRejectsPureIPRuleSetWhenLegacyDNSModeDisabled(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
fakeSet := &fakeRuleSet{
|
||||||
|
metadata: adapter.RuleSetMetadata{
|
||||||
|
ContainsIPCIDRRule: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
routerService := &fakeRouter{
|
||||||
|
ruleSets: map[string]adapter.RuleSet{
|
||||||
|
"pure-ip": fakeSet,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := service.ContextWith[adapter.Router](context.Background(), routerService)
|
||||||
|
router := &Router{
|
||||||
|
ctx: ctx,
|
||||||
|
logger: log.NewNOPFactory().NewLogger("dns"),
|
||||||
|
transport: &fakeDNSTransportManager{},
|
||||||
|
client: &fakeDNSClient{},
|
||||||
|
rawRules: make([]option.DNSRule, 0, 2),
|
||||||
|
rules: make([]adapter.DNSRule, 0, 2),
|
||||||
|
defaultDomainStrategy: C.DomainStrategyAsIS,
|
||||||
|
}
|
||||||
|
err := router.Initialize([]option.DNSRule{
|
||||||
|
{
|
||||||
|
Type: C.RuleTypeDefault,
|
||||||
|
DefaultOptions: option.DefaultDNSRule{
|
||||||
|
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||||
|
QueryType: badoption.Listable[option.DNSQueryType]{option.DNSQueryType(mDNS.TypeA)},
|
||||||
|
},
|
||||||
|
DNSRuleAction: option.DNSRuleAction{
|
||||||
|
Action: C.RuleActionTypeRoute,
|
||||||
|
RouteOptions: option.DNSRouteActionOptions{Server: "selected"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: C.RuleTypeDefault,
|
||||||
|
DefaultOptions: option.DefaultDNSRule{
|
||||||
|
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||||
|
RuleSet: badoption.Listable[string]{"pure-ip"},
|
||||||
|
},
|
||||||
|
DNSRuleAction: option.DNSRuleAction{
|
||||||
|
Action: C.RuleActionTypeRoute,
|
||||||
|
RouteOptions: option.DNSRouteActionOptions{Server: "selected"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.ErrorContains(t, err, "Address Filter Fields")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInitializeAllowsMixedRuleSetWhenLegacyDNSModeDisabled(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
fakeSet := &fakeRuleSet{
|
||||||
|
metadata: adapter.RuleSetMetadata{
|
||||||
|
ContainsIPCIDRRule: true,
|
||||||
|
ContainsNonIPCIDRRule: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
routerService := &fakeRouter{
|
||||||
|
ruleSets: map[string]adapter.RuleSet{
|
||||||
|
"mixed": fakeSet,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := service.ContextWith[adapter.Router](context.Background(), routerService)
|
||||||
|
router := newTestRouterWithContext(t, ctx, []option.DNSRule{
|
||||||
|
{
|
||||||
|
Type: C.RuleTypeDefault,
|
||||||
|
DefaultOptions: option.DefaultDNSRule{
|
||||||
|
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||||
|
QueryType: badoption.Listable[option.DNSQueryType]{option.DNSQueryType(mDNS.TypeA)},
|
||||||
|
},
|
||||||
|
DNSRuleAction: option.DNSRuleAction{
|
||||||
|
Action: C.RuleActionTypeRoute,
|
||||||
|
RouteOptions: option.DNSRouteActionOptions{Server: "selected"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: C.RuleTypeDefault,
|
||||||
|
DefaultOptions: option.DefaultDNSRule{
|
||||||
|
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||||
|
RuleSet: badoption.Listable[string]{"mixed"},
|
||||||
|
},
|
||||||
|
DNSRuleAction: option.DNSRuleAction{
|
||||||
|
Action: C.RuleActionTypeRoute,
|
||||||
|
RouteOptions: option.DNSRouteActionOptions{Server: "selected"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, &fakeDNSTransportManager{
|
||||||
|
defaultTransport: &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP},
|
||||||
|
transports: map[string]adapter.DNSTransport{
|
||||||
|
"default": &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP},
|
||||||
|
"selected": &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP},
|
||||||
|
},
|
||||||
|
}, &fakeDNSClient{})
|
||||||
|
require.False(t, router.legacyDNSMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRuleSetMetadataUpdateRejectsRuleSetFlippingToPureIP(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
fakeSet := &fakeRuleSet{
|
||||||
|
metadata: adapter.RuleSetMetadata{
|
||||||
|
ContainsIPCIDRRule: true,
|
||||||
|
ContainsNonIPCIDRRule: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
routerService := &fakeRouter{
|
||||||
|
ruleSets: map[string]adapter.RuleSet{
|
||||||
|
"mixed": fakeSet,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := service.ContextWith[adapter.Router](context.Background(), routerService)
|
||||||
|
router := newTestRouterWithContext(t, ctx, []option.DNSRule{
|
||||||
|
{
|
||||||
|
Type: C.RuleTypeDefault,
|
||||||
|
DefaultOptions: option.DefaultDNSRule{
|
||||||
|
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||||
|
QueryType: badoption.Listable[option.DNSQueryType]{option.DNSQueryType(mDNS.TypeA)},
|
||||||
|
},
|
||||||
|
DNSRuleAction: option.DNSRuleAction{
|
||||||
|
Action: C.RuleActionTypeRoute,
|
||||||
|
RouteOptions: option.DNSRouteActionOptions{Server: "selected"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: C.RuleTypeDefault,
|
||||||
|
DefaultOptions: option.DefaultDNSRule{
|
||||||
|
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||||
|
RuleSet: badoption.Listable[string]{"mixed"},
|
||||||
|
},
|
||||||
|
DNSRuleAction: option.DNSRuleAction{
|
||||||
|
Action: C.RuleActionTypeRoute,
|
||||||
|
RouteOptions: option.DNSRouteActionOptions{Server: "selected"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, &fakeDNSTransportManager{
|
||||||
|
defaultTransport: &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP},
|
||||||
|
transports: map[string]adapter.DNSTransport{
|
||||||
|
"default": &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP},
|
||||||
|
"selected": &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP},
|
||||||
|
},
|
||||||
|
}, &fakeDNSClient{})
|
||||||
|
require.False(t, router.legacyDNSMode)
|
||||||
|
|
||||||
|
err := router.ValidateRuleSetMetadataUpdate("mixed", adapter.RuleSetMetadata{
|
||||||
|
ContainsIPCIDRRule: true,
|
||||||
|
})
|
||||||
|
require.ErrorContains(t, err, "Address Filter Fields")
|
||||||
|
}
|
||||||
|
|
||||||
func TestCloseWaitsForInFlightLookupUntilContextCancellation(t *testing.T) {
|
func TestCloseWaitsForInFlightLookupUntilContextCancellation(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
@@ -25,14 +25,23 @@ type Transport struct {
|
|||||||
dns.TransportAdapter
|
dns.TransportAdapter
|
||||||
logger logger.ContextLogger
|
logger logger.ContextLogger
|
||||||
store adapter.FakeIPStore
|
store adapter.FakeIPStore
|
||||||
|
inet4Enabled bool
|
||||||
|
inet6Enabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, options option.FakeIPDNSServerOptions) (adapter.DNSTransport, error) {
|
func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, options option.FakeIPDNSServerOptions) (adapter.DNSTransport, error) {
|
||||||
store := NewStore(ctx, logger, options.Inet4Range.Build(netip.Prefix{}), options.Inet6Range.Build(netip.Prefix{}))
|
inet4Range := options.Inet4Range.Build(netip.Prefix{})
|
||||||
|
inet6Range := options.Inet6Range.Build(netip.Prefix{})
|
||||||
|
if !inet4Range.IsValid() && !inet6Range.IsValid() {
|
||||||
|
return nil, E.New("at least one of inet4_range or inet6_range must be set")
|
||||||
|
}
|
||||||
|
store := NewStore(ctx, logger, inet4Range, inet6Range)
|
||||||
return &Transport{
|
return &Transport{
|
||||||
TransportAdapter: dns.NewTransportAdapter(C.DNSTypeFakeIP, tag, nil),
|
TransportAdapter: dns.NewTransportAdapter(C.DNSTypeFakeIP, tag, nil),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
store: store,
|
store: store,
|
||||||
|
inet4Enabled: inet4Range.IsValid(),
|
||||||
|
inet6Enabled: inet6Range.IsValid(),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -55,6 +64,9 @@ func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg,
|
|||||||
if question.Qtype != mDNS.TypeA && question.Qtype != mDNS.TypeAAAA {
|
if question.Qtype != mDNS.TypeA && question.Qtype != mDNS.TypeAAAA {
|
||||||
return nil, E.New("only IP queries are supported by fakeip")
|
return nil, E.New("only IP queries are supported by fakeip")
|
||||||
}
|
}
|
||||||
|
if question.Qtype == mDNS.TypeA && !t.inet4Enabled || question.Qtype == mDNS.TypeAAAA && !t.inet6Enabled {
|
||||||
|
return dns.FixedResponseStatus(message, mDNS.RcodeSuccess), nil
|
||||||
|
}
|
||||||
address, err := t.store.Create(dns.FqdnToDomain(question.Name), question.Qtype == mDNS.TypeAAAA)
|
address, err := t.store.Create(dns.FqdnToDomain(question.Name), question.Qtype == mDNS.TypeAAAA)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -33,7 +33,11 @@ func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, opt
|
|||||||
predefined = make(map[string][]netip.Addr)
|
predefined = make(map[string][]netip.Addr)
|
||||||
)
|
)
|
||||||
if len(options.Path) == 0 {
|
if len(options.Path) == 0 {
|
||||||
files = append(files, NewFile(DefaultPath))
|
defaultFile, err := NewDefault()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
files = append(files, defaultFile)
|
||||||
} else {
|
} else {
|
||||||
for _, path := range options.Path {
|
for _, path := range options.Path {
|
||||||
files = append(files, NewFile(filemanager.BasePath(ctx, os.ExpandEnv(path))))
|
files = append(files, NewFile(filemanager.BasePath(ctx, os.ExpandEnv(path))))
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -30,6 +32,14 @@ func NewFile(path string) *File {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewDefault() (*File, error) {
|
||||||
|
defaultPathResolved, err := defaultPath()
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "resolve default hosts path")
|
||||||
|
}
|
||||||
|
return NewFile(defaultPathResolved), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (f *File) Lookup(name string) []netip.Addr {
|
func (f *File) Lookup(name string) []netip.Addr {
|
||||||
f.access.Lock()
|
f.access.Lock()
|
||||||
defer f.access.Unlock()
|
defer f.access.Unlock()
|
||||||
|
|||||||
@@ -1,16 +1,29 @@
|
|||||||
package hosts_test
|
package hosts
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sagernet/sing-box/dns/transport/hosts"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestHosts(t *testing.T) {
|
func TestHosts(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
require.Equal(t, []netip.Addr{netip.AddrFrom4([4]byte{127, 0, 0, 1}), netip.IPv6Loopback()}, hosts.NewFile("testdata/hosts").Lookup("localhost"))
|
require.Equal(t, []netip.Addr{netip.AddrFrom4([4]byte{127, 0, 0, 1}), netip.IPv6Loopback()}, NewFile("testdata/hosts").Lookup("localhost"))
|
||||||
require.NotEmpty(t, hosts.NewFile(hosts.DefaultPath).Lookup("localhost"))
|
if runtime.GOOS != "windows" {
|
||||||
|
defaultPathResolved, err := defaultPath()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(E.Cause(err, "resolve default hosts path"))
|
||||||
|
}
|
||||||
|
content, readErr := os.ReadFile(defaultPathResolved)
|
||||||
|
require.NoError(t, readErr)
|
||||||
|
hFile := NewFile(defaultPathResolved)
|
||||||
|
if len(hFile.Lookup("localhost")) == 0 {
|
||||||
|
t.Fatal("failed to resolve localhost: ", defaultPathResolved, ": \n", string(content))
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,4 +2,6 @@
|
|||||||
|
|
||||||
package hosts
|
package hosts
|
||||||
|
|
||||||
var DefaultPath = "/etc/hosts"
|
func defaultPath() (string, error) {
|
||||||
|
return "/etc/hosts", nil
|
||||||
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user