mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-14 04:38:28 +10:00
Compare commits
75 Commits
ccm-ocm-im
...
testing
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
334dd6e5c0 | ||
|
|
ccfdbf2d57 | ||
|
|
9b75d28ca4 | ||
|
|
2e64545db4 | ||
|
|
9675b0902a | ||
|
|
ebd31ca363 | ||
|
|
6ba7a6f001 | ||
|
|
b7e1a14974 | ||
|
|
a5c0112f0c | ||
|
|
e6427e8244 | ||
|
|
c0d9551bcf | ||
|
|
5cdf1aa000 | ||
|
|
6da0aa0c82 | ||
|
|
97f4723467 | ||
|
|
6c7fb1dad1 | ||
|
|
e0696f5e94 | ||
|
|
ddcaf040e2 | ||
|
|
57039ac11d | ||
|
|
abd6baf3cb | ||
|
|
a48fd106c3 | ||
|
|
6dfab9225f | ||
|
|
5e7e58f5e9 | ||
|
|
cfcc766d74 | ||
|
|
a24170638e | ||
|
|
ac9c0e7a81 | ||
|
|
51166f4601 | ||
|
|
5d254d9015 | ||
|
|
d3fc58ceb8 | ||
|
|
58d22df1be | ||
|
|
574852bdc1 | ||
|
|
ddc181f65a | ||
|
|
e2727d9556 | ||
|
|
f8b05790d1 | ||
|
|
c1203821f9 | ||
|
|
9805db343c | ||
|
|
b28083b131 | ||
|
|
0d1ce7957d | ||
|
|
025b947a24 | ||
|
|
76fa3c2e5e | ||
|
|
53db1f178c | ||
|
|
55ec8abf17 | ||
|
|
5a957fd750 | ||
|
|
7c3d8cf8db | ||
|
|
813b634d08 | ||
|
|
d9b435fb62 | ||
|
|
354b4b040e | ||
|
|
7ffdc48b49 | ||
|
|
e15bdf11eb | ||
|
|
e3bcb06c3e | ||
|
|
84d2280960 | ||
|
|
4fd2532b0a | ||
|
|
02ccde6c71 | ||
|
|
e98b4ad449 | ||
|
|
d09182614c | ||
|
|
6381de7bab | ||
|
|
b0c6762bc1 | ||
|
|
7425100bac | ||
|
|
d454aa0fdf | ||
|
|
a3623eb41a | ||
|
|
72bc4c1f87 | ||
|
|
9ac1e2ff32 | ||
|
|
0045103d14 | ||
|
|
d2a933784c | ||
|
|
3f05a37f65 | ||
|
|
b8e5a71450 | ||
|
|
c13faa8e3c | ||
|
|
7623bcd19e | ||
|
|
795d1c2892 | ||
|
|
6913b11e0a | ||
|
|
1e57c06295 | ||
|
|
ea464cef8d | ||
|
|
a8e3cd3256 | ||
|
|
686cf1f304 | ||
|
|
9fbfb87723 | ||
|
|
d2fa21d07b |
@@ -4,6 +4,7 @@
|
|||||||
--license GPL-3.0-or-later
|
--license GPL-3.0-or-later
|
||||||
--description "The universal proxy platform."
|
--description "The universal proxy platform."
|
||||||
--url "https://sing-box.sagernet.org/"
|
--url "https://sing-box.sagernet.org/"
|
||||||
|
--vendor SagerNet
|
||||||
--maintainer "nekohasekai <contact-git@sekai.icu>"
|
--maintainer "nekohasekai <contact-git@sekai.icu>"
|
||||||
--deb-field "Bug: https://github.com/SagerNet/sing-box/issues"
|
--deb-field "Bug: https://github.com/SagerNet/sing-box/issues"
|
||||||
--no-deb-generate-changes
|
--no-deb-generate-changes
|
||||||
|
|||||||
2
.github/CRONET_GO_VERSION
vendored
2
.github/CRONET_GO_VERSION
vendored
@@ -1 +1 @@
|
|||||||
ea7cd33752aed62603775af3df946c1b83f4b0b3
|
335e5bef5d88fc4474c9a70b865561f45a67de83
|
||||||
|
|||||||
33
.github/detect_track.sh
vendored
Executable file
33
.github/detect_track.sh
vendored
Executable file
@@ -0,0 +1,33 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
branches=$(git branch -r --contains HEAD)
|
||||||
|
if echo "$branches" | grep -q 'origin/stable'; then
|
||||||
|
track=stable
|
||||||
|
elif echo "$branches" | grep -q 'origin/testing'; then
|
||||||
|
track=testing
|
||||||
|
elif echo "$branches" | grep -q 'origin/oldstable'; then
|
||||||
|
track=oldstable
|
||||||
|
else
|
||||||
|
echo "ERROR: HEAD is not on any known release branch (stable/testing/oldstable)" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "$track" == "stable" ]]; then
|
||||||
|
tag=$(git describe --tags --exact-match HEAD 2>/dev/null || true)
|
||||||
|
if [[ -n "$tag" && "$tag" == *"-"* ]]; then
|
||||||
|
track=beta
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
case "$track" in
|
||||||
|
stable) name=sing-box; docker_tag=latest ;;
|
||||||
|
beta) name=sing-box-beta; docker_tag=latest-beta ;;
|
||||||
|
testing) name=sing-box-testing; docker_tag=latest-testing ;;
|
||||||
|
oldstable) name=sing-box-oldstable; docker_tag=latest-oldstable ;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
echo "track=${track} name=${name} docker_tag=${docker_tag}" >&2
|
||||||
|
echo "TRACK=${track}" >> "$GITHUB_ENV"
|
||||||
|
echo "NAME=${name}" >> "$GITHUB_ENV"
|
||||||
|
echo "DOCKER_TAG=${docker_tag}" >> "$GITHUB_ENV"
|
||||||
19
.github/workflows/docker.yml
vendored
19
.github/workflows/docker.yml
vendored
@@ -19,7 +19,6 @@ env:
|
|||||||
jobs:
|
jobs:
|
||||||
build_binary:
|
build_binary:
|
||||||
name: Build binary
|
name: Build binary
|
||||||
if: github.event_name != 'release' || github.event.release.target_commitish != 'oldstable'
|
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: true
|
fail-fast: true
|
||||||
@@ -260,13 +259,13 @@ jobs:
|
|||||||
fi
|
fi
|
||||||
echo "ref=$ref"
|
echo "ref=$ref"
|
||||||
echo "ref=$ref" >> $GITHUB_OUTPUT
|
echo "ref=$ref" >> $GITHUB_OUTPUT
|
||||||
if [[ $ref == *"-"* ]]; then
|
- name: Checkout
|
||||||
latest=latest-beta
|
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5
|
||||||
else
|
with:
|
||||||
latest=latest
|
ref: ${{ steps.ref.outputs.ref }}
|
||||||
fi
|
fetch-depth: 0
|
||||||
echo "latest=$latest"
|
- name: Detect track
|
||||||
echo "latest=$latest" >> $GITHUB_OUTPUT
|
run: bash .github/detect_track.sh
|
||||||
- name: Download digests
|
- name: Download digests
|
||||||
uses: actions/download-artifact@v5
|
uses: actions/download-artifact@v5
|
||||||
with:
|
with:
|
||||||
@@ -286,11 +285,11 @@ jobs:
|
|||||||
working-directory: /tmp/digests
|
working-directory: /tmp/digests
|
||||||
run: |
|
run: |
|
||||||
docker buildx imagetools create \
|
docker buildx imagetools create \
|
||||||
-t "${{ env.REGISTRY_IMAGE }}:${{ steps.ref.outputs.latest }}" \
|
-t "${{ env.REGISTRY_IMAGE }}:${{ env.DOCKER_TAG }}" \
|
||||||
-t "${{ env.REGISTRY_IMAGE }}:${{ steps.ref.outputs.ref }}" \
|
-t "${{ env.REGISTRY_IMAGE }}:${{ steps.ref.outputs.ref }}" \
|
||||||
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
|
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
|
||||||
- name: Inspect image
|
- name: Inspect image
|
||||||
if: github.event_name != 'push'
|
if: github.event_name != 'push'
|
||||||
run: |
|
run: |
|
||||||
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.ref.outputs.latest }}
|
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ env.DOCKER_TAG }}
|
||||||
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.ref.outputs.ref }}
|
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.ref.outputs.ref }}
|
||||||
|
|||||||
16
.github/workflows/linux.yml
vendored
16
.github/workflows/linux.yml
vendored
@@ -11,11 +11,6 @@ on:
|
|||||||
description: "Version name"
|
description: "Version name"
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
forceBeta:
|
|
||||||
description: "Force beta"
|
|
||||||
required: false
|
|
||||||
type: boolean
|
|
||||||
default: false
|
|
||||||
release:
|
release:
|
||||||
types:
|
types:
|
||||||
- published
|
- published
|
||||||
@@ -23,7 +18,6 @@ on:
|
|||||||
jobs:
|
jobs:
|
||||||
calculate_version:
|
calculate_version:
|
||||||
name: Calculate version
|
name: Calculate version
|
||||||
if: github.event_name != 'release' || github.event.release.target_commitish != 'oldstable'
|
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
outputs:
|
outputs:
|
||||||
version: ${{ steps.outputs.outputs.version }}
|
version: ${{ steps.outputs.outputs.version }}
|
||||||
@@ -168,14 +162,8 @@ jobs:
|
|||||||
- name: Set mtime
|
- name: Set mtime
|
||||||
run: |-
|
run: |-
|
||||||
TZ=UTC touch -t '197001010000' dist/sing-box
|
TZ=UTC touch -t '197001010000' dist/sing-box
|
||||||
- name: Set name
|
- name: Detect track
|
||||||
if: (! contains(needs.calculate_version.outputs.version, '-')) && !inputs.forceBeta
|
run: bash .github/detect_track.sh
|
||||||
run: |-
|
|
||||||
echo "NAME=sing-box" >> "$GITHUB_ENV"
|
|
||||||
- name: Set beta name
|
|
||||||
if: contains(needs.calculate_version.outputs.version, '-') || inputs.forceBeta
|
|
||||||
run: |-
|
|
||||||
echo "NAME=sing-box-beta" >> "$GITHUB_ENV"
|
|
||||||
- name: Set version
|
- name: Set version
|
||||||
run: |-
|
run: |-
|
||||||
PKG_VERSION="${{ needs.calculate_version.outputs.version }}"
|
PKG_VERSION="${{ needs.calculate_version.outputs.version }}"
|
||||||
|
|||||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -18,6 +18,6 @@
|
|||||||
.DS_Store
|
.DS_Store
|
||||||
/config.d/
|
/config.d/
|
||||||
/venv/
|
/venv/
|
||||||
/CLAUDE.md
|
CLAUDE.md
|
||||||
/AGENTS.md
|
AGENTS.md
|
||||||
/.claude/
|
/.claude/
|
||||||
|
|||||||
21
adapter/certificate/adapter.go
Normal file
21
adapter/certificate/adapter.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package certificate
|
||||||
|
|
||||||
|
type Adapter struct {
|
||||||
|
providerType string
|
||||||
|
providerTag string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAdapter(providerType string, providerTag string) Adapter {
|
||||||
|
return Adapter{
|
||||||
|
providerType: providerType,
|
||||||
|
providerTag: providerTag,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adapter) Type() string {
|
||||||
|
return a.providerType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adapter) Tag() string {
|
||||||
|
return a.providerTag
|
||||||
|
}
|
||||||
158
adapter/certificate/manager.go
Normal file
158
adapter/certificate/manager.go
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
package certificate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/common/taskmonitor"
|
||||||
|
C "github.com/sagernet/sing-box/constant"
|
||||||
|
"github.com/sagernet/sing-box/log"
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
F "github.com/sagernet/sing/common/format"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ adapter.CertificateProviderManager = (*Manager)(nil)
|
||||||
|
|
||||||
|
type Manager struct {
|
||||||
|
logger log.ContextLogger
|
||||||
|
registry adapter.CertificateProviderRegistry
|
||||||
|
access sync.Mutex
|
||||||
|
started bool
|
||||||
|
stage adapter.StartStage
|
||||||
|
providers []adapter.CertificateProviderService
|
||||||
|
providerByTag map[string]adapter.CertificateProviderService
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager(logger log.ContextLogger, registry adapter.CertificateProviderRegistry) *Manager {
|
||||||
|
return &Manager{
|
||||||
|
logger: logger,
|
||||||
|
registry: registry,
|
||||||
|
providerByTag: make(map[string]adapter.CertificateProviderService),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Start(stage adapter.StartStage) error {
|
||||||
|
m.access.Lock()
|
||||||
|
if m.started && m.stage >= stage {
|
||||||
|
panic("already started")
|
||||||
|
}
|
||||||
|
m.started = true
|
||||||
|
m.stage = stage
|
||||||
|
providers := m.providers
|
||||||
|
m.access.Unlock()
|
||||||
|
for _, provider := range providers {
|
||||||
|
name := "certificate-provider/" + provider.Type() + "[" + provider.Tag() + "]"
|
||||||
|
m.logger.Trace(stage, " ", name)
|
||||||
|
startTime := time.Now()
|
||||||
|
err := adapter.LegacyStart(provider, stage)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, stage, " ", name)
|
||||||
|
}
|
||||||
|
m.logger.Trace(stage, " ", name, " completed (", F.Seconds(time.Since(startTime).Seconds()), "s)")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Close() error {
|
||||||
|
m.access.Lock()
|
||||||
|
defer m.access.Unlock()
|
||||||
|
if !m.started {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
m.started = false
|
||||||
|
providers := m.providers
|
||||||
|
m.providers = nil
|
||||||
|
monitor := taskmonitor.New(m.logger, C.StopTimeout)
|
||||||
|
var err error
|
||||||
|
for _, provider := range providers {
|
||||||
|
name := "certificate-provider/" + provider.Type() + "[" + provider.Tag() + "]"
|
||||||
|
m.logger.Trace("close ", name)
|
||||||
|
startTime := time.Now()
|
||||||
|
monitor.Start("close ", name)
|
||||||
|
err = E.Append(err, provider.Close(), func(err error) error {
|
||||||
|
return E.Cause(err, "close ", name)
|
||||||
|
})
|
||||||
|
monitor.Finish()
|
||||||
|
m.logger.Trace("close ", name, " completed (", F.Seconds(time.Since(startTime).Seconds()), "s)")
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) CertificateProviders() []adapter.CertificateProviderService {
|
||||||
|
m.access.Lock()
|
||||||
|
defer m.access.Unlock()
|
||||||
|
return m.providers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Get(tag string) (adapter.CertificateProviderService, bool) {
|
||||||
|
m.access.Lock()
|
||||||
|
provider, found := m.providerByTag[tag]
|
||||||
|
m.access.Unlock()
|
||||||
|
return provider, found
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Remove(tag string) error {
|
||||||
|
m.access.Lock()
|
||||||
|
provider, found := m.providerByTag[tag]
|
||||||
|
if !found {
|
||||||
|
m.access.Unlock()
|
||||||
|
return os.ErrInvalid
|
||||||
|
}
|
||||||
|
delete(m.providerByTag, tag)
|
||||||
|
index := common.Index(m.providers, func(it adapter.CertificateProviderService) bool {
|
||||||
|
return it == provider
|
||||||
|
})
|
||||||
|
if index == -1 {
|
||||||
|
panic("invalid certificate provider index")
|
||||||
|
}
|
||||||
|
m.providers = append(m.providers[:index], m.providers[index+1:]...)
|
||||||
|
started := m.started
|
||||||
|
m.access.Unlock()
|
||||||
|
if started {
|
||||||
|
return provider.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Create(ctx context.Context, logger log.ContextLogger, tag string, providerType string, options any) error {
|
||||||
|
provider, err := m.registry.Create(ctx, logger, tag, providerType, options)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m.access.Lock()
|
||||||
|
defer m.access.Unlock()
|
||||||
|
if m.started {
|
||||||
|
name := "certificate-provider/" + provider.Type() + "[" + provider.Tag() + "]"
|
||||||
|
for _, stage := range adapter.ListStartStages {
|
||||||
|
m.logger.Trace(stage, " ", name)
|
||||||
|
startTime := time.Now()
|
||||||
|
err = adapter.LegacyStart(provider, stage)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, stage, " ", name)
|
||||||
|
}
|
||||||
|
m.logger.Trace(stage, " ", name, " completed (", F.Seconds(time.Since(startTime).Seconds()), "s)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if existsProvider, loaded := m.providerByTag[tag]; loaded {
|
||||||
|
if m.started {
|
||||||
|
err = existsProvider.Close()
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "close certificate-provider/", existsProvider.Type(), "[", existsProvider.Tag(), "]")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
existsIndex := common.Index(m.providers, func(it adapter.CertificateProviderService) bool {
|
||||||
|
return it == existsProvider
|
||||||
|
})
|
||||||
|
if existsIndex == -1 {
|
||||||
|
panic("invalid certificate provider index")
|
||||||
|
}
|
||||||
|
m.providers = append(m.providers[:existsIndex], m.providers[existsIndex+1:]...)
|
||||||
|
}
|
||||||
|
m.providers = append(m.providers, provider)
|
||||||
|
m.providerByTag[tag] = provider
|
||||||
|
return nil
|
||||||
|
}
|
||||||
72
adapter/certificate/registry.go
Normal file
72
adapter/certificate/registry.go
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
package certificate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/log"
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConstructorFunc[T any] func(ctx context.Context, logger log.ContextLogger, tag string, options T) (adapter.CertificateProviderService, error)
|
||||||
|
|
||||||
|
func Register[Options any](registry *Registry, providerType string, constructor ConstructorFunc[Options]) {
|
||||||
|
registry.register(providerType, func() any {
|
||||||
|
return new(Options)
|
||||||
|
}, func(ctx context.Context, logger log.ContextLogger, tag string, rawOptions any) (adapter.CertificateProviderService, error) {
|
||||||
|
var options *Options
|
||||||
|
if rawOptions != nil {
|
||||||
|
options = rawOptions.(*Options)
|
||||||
|
}
|
||||||
|
return constructor(ctx, logger, tag, common.PtrValueOrDefault(options))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ adapter.CertificateProviderRegistry = (*Registry)(nil)
|
||||||
|
|
||||||
|
type (
|
||||||
|
optionsConstructorFunc func() any
|
||||||
|
constructorFunc func(ctx context.Context, logger log.ContextLogger, tag string, options any) (adapter.CertificateProviderService, error)
|
||||||
|
)
|
||||||
|
|
||||||
|
type Registry struct {
|
||||||
|
access sync.Mutex
|
||||||
|
optionsType map[string]optionsConstructorFunc
|
||||||
|
constructor map[string]constructorFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRegistry() *Registry {
|
||||||
|
return &Registry{
|
||||||
|
optionsType: make(map[string]optionsConstructorFunc),
|
||||||
|
constructor: make(map[string]constructorFunc),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Registry) CreateOptions(providerType string) (any, bool) {
|
||||||
|
m.access.Lock()
|
||||||
|
defer m.access.Unlock()
|
||||||
|
optionsConstructor, loaded := m.optionsType[providerType]
|
||||||
|
if !loaded {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return optionsConstructor(), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Registry) Create(ctx context.Context, logger log.ContextLogger, tag string, providerType string, options any) (adapter.CertificateProviderService, error) {
|
||||||
|
m.access.Lock()
|
||||||
|
defer m.access.Unlock()
|
||||||
|
constructor, loaded := m.constructor[providerType]
|
||||||
|
if !loaded {
|
||||||
|
return nil, E.New("certificate provider type not found: " + providerType)
|
||||||
|
}
|
||||||
|
return constructor(ctx, logger, tag, options)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Registry) register(providerType string, optionsConstructor optionsConstructorFunc, constructor constructorFunc) {
|
||||||
|
m.access.Lock()
|
||||||
|
defer m.access.Unlock()
|
||||||
|
m.optionsType[providerType] = optionsConstructor
|
||||||
|
m.constructor[providerType] = constructor
|
||||||
|
}
|
||||||
38
adapter/certificate_provider.go
Normal file
38
adapter/certificate_provider.go
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
package adapter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/log"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CertificateProvider interface {
|
||||||
|
GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ACMECertificateProvider interface {
|
||||||
|
CertificateProvider
|
||||||
|
GetACMENextProtos() []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type CertificateProviderService interface {
|
||||||
|
Lifecycle
|
||||||
|
Type() string
|
||||||
|
Tag() string
|
||||||
|
CertificateProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
type CertificateProviderRegistry interface {
|
||||||
|
option.CertificateProviderOptionsRegistry
|
||||||
|
Create(ctx context.Context, logger log.ContextLogger, tag string, providerType string, options any) (CertificateProviderService, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type CertificateProviderManager interface {
|
||||||
|
Lifecycle
|
||||||
|
CertificateProviders() []CertificateProviderService
|
||||||
|
Get(tag string) (CertificateProviderService, bool)
|
||||||
|
Remove(tag string) error
|
||||||
|
Create(ctx context.Context, logger log.ContextLogger, tag string, providerType string, options any) error
|
||||||
|
}
|
||||||
@@ -3,6 +3,7 @@ package adapter
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
"github.com/sagernet/sing-box/log"
|
"github.com/sagernet/sing-box/log"
|
||||||
@@ -25,18 +26,19 @@ type DNSRouter interface {
|
|||||||
|
|
||||||
type DNSClient interface {
|
type DNSClient interface {
|
||||||
Start()
|
Start()
|
||||||
Exchange(ctx context.Context, transport DNSTransport, message *dns.Msg, options DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) (*dns.Msg, error)
|
Exchange(ctx context.Context, transport DNSTransport, message *dns.Msg, options DNSQueryOptions, responseChecker func(response *dns.Msg) bool) (*dns.Msg, error)
|
||||||
Lookup(ctx context.Context, transport DNSTransport, domain string, options DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error)
|
Lookup(ctx context.Context, transport DNSTransport, domain string, options DNSQueryOptions, responseChecker func(response *dns.Msg) bool) ([]netip.Addr, error)
|
||||||
ClearCache()
|
ClearCache()
|
||||||
}
|
}
|
||||||
|
|
||||||
type DNSQueryOptions struct {
|
type DNSQueryOptions struct {
|
||||||
Transport DNSTransport
|
Transport DNSTransport
|
||||||
Strategy C.DomainStrategy
|
Strategy C.DomainStrategy
|
||||||
LookupStrategy C.DomainStrategy
|
LookupStrategy C.DomainStrategy
|
||||||
DisableCache bool
|
DisableCache bool
|
||||||
RewriteTTL *uint32
|
DisableOptimisticCache bool
|
||||||
ClientSubnet netip.Prefix
|
RewriteTTL *uint32
|
||||||
|
ClientSubnet netip.Prefix
|
||||||
}
|
}
|
||||||
|
|
||||||
func DNSQueryOptionsFrom(ctx context.Context, options *option.DomainResolveOptions) (*DNSQueryOptions, error) {
|
func DNSQueryOptionsFrom(ctx context.Context, options *option.DomainResolveOptions) (*DNSQueryOptions, error) {
|
||||||
@@ -49,11 +51,12 @@ func DNSQueryOptionsFrom(ctx context.Context, options *option.DomainResolveOptio
|
|||||||
return nil, E.New("domain resolver not found: " + options.Server)
|
return nil, E.New("domain resolver not found: " + options.Server)
|
||||||
}
|
}
|
||||||
return &DNSQueryOptions{
|
return &DNSQueryOptions{
|
||||||
Transport: transport,
|
Transport: transport,
|
||||||
Strategy: C.DomainStrategy(options.Strategy),
|
Strategy: C.DomainStrategy(options.Strategy),
|
||||||
DisableCache: options.DisableCache,
|
DisableCache: options.DisableCache,
|
||||||
RewriteTTL: options.RewriteTTL,
|
DisableOptimisticCache: options.DisableOptimisticCache,
|
||||||
ClientSubnet: options.ClientSubnet.Build(netip.Prefix{}),
|
RewriteTTL: options.RewriteTTL,
|
||||||
|
ClientSubnet: options.ClientSubnet.Build(netip.Prefix{}),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -63,6 +66,13 @@ type RDRCStore interface {
|
|||||||
SaveRDRCAsync(transportName string, qName string, qType uint16, logger logger.Logger)
|
SaveRDRCAsync(transportName string, qName string, qType uint16, logger logger.Logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type DNSCacheStore interface {
|
||||||
|
LoadDNSCache(transportName string, qName string, qType uint16) (rawMessage []byte, expireAt time.Time, loaded bool)
|
||||||
|
SaveDNSCache(transportName string, qName string, qType uint16, rawMessage []byte, expireAt time.Time) error
|
||||||
|
SaveDNSCacheAsync(transportName string, qName string, qType uint16, rawMessage []byte, expireAt time.Time, logger logger.Logger)
|
||||||
|
ClearDNSCache() error
|
||||||
|
}
|
||||||
|
|
||||||
type DNSTransport interface {
|
type DNSTransport interface {
|
||||||
Lifecycle
|
Lifecycle
|
||||||
Type() string
|
Type() string
|
||||||
@@ -72,11 +82,6 @@ type DNSTransport interface {
|
|||||||
Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error)
|
Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type LegacyDNSTransport interface {
|
|
||||||
LegacyStrategy() C.DomainStrategy
|
|
||||||
LegacyClientSubnet() netip.Prefix
|
|
||||||
}
|
|
||||||
|
|
||||||
type DNSTransportRegistry interface {
|
type DNSTransportRegistry interface {
|
||||||
option.DNSTransportOptionsRegistry
|
option.DNSTransportOptionsRegistry
|
||||||
CreateDNSTransport(ctx context.Context, logger log.ContextLogger, tag string, transportType string, options any) (DNSTransport, error)
|
CreateDNSTransport(ctx context.Context, logger log.ContextLogger, tag string, transportType string, options any) (DNSTransport, error)
|
||||||
|
|||||||
@@ -47,6 +47,12 @@ type CacheFile interface {
|
|||||||
StoreRDRC() bool
|
StoreRDRC() bool
|
||||||
RDRCStore
|
RDRCStore
|
||||||
|
|
||||||
|
StoreDNS() bool
|
||||||
|
DNSCacheStore
|
||||||
|
|
||||||
|
SetDisableExpire(disableExpire bool)
|
||||||
|
SetOptimisticTimeout(timeout time.Duration)
|
||||||
|
|
||||||
LoadMode() string
|
LoadMode() string
|
||||||
StoreMode(mode string) error
|
StoreMode(mode string) error
|
||||||
LoadSelected(group string) string
|
LoadSelected(group string) string
|
||||||
|
|||||||
22
adapter/http.go
Normal file
22
adapter/http.go
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
package adapter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
"github.com/sagernet/sing/common/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
type HTTPTransport interface {
|
||||||
|
http.RoundTripper
|
||||||
|
CloseIdleConnections()
|
||||||
|
Clone() HTTPTransport
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
type HTTPClientManager interface {
|
||||||
|
ResolveTransport(ctx context.Context, logger logger.ContextLogger, options option.HTTPClientOptions) (HTTPTransport, error)
|
||||||
|
DefaultTransport() HTTPTransport
|
||||||
|
ResetNetwork()
|
||||||
|
}
|
||||||
@@ -10,6 +10,8 @@ import (
|
|||||||
"github.com/sagernet/sing-box/log"
|
"github.com/sagernet/sing-box/log"
|
||||||
"github.com/sagernet/sing-box/option"
|
"github.com/sagernet/sing-box/option"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Inbound interface {
|
type Inbound interface {
|
||||||
@@ -79,14 +81,16 @@ type InboundContext struct {
|
|||||||
FallbackNetworkType []C.InterfaceType
|
FallbackNetworkType []C.InterfaceType
|
||||||
FallbackDelay time.Duration
|
FallbackDelay time.Duration
|
||||||
|
|
||||||
DestinationAddresses []netip.Addr
|
DestinationAddresses []netip.Addr
|
||||||
SourceGeoIPCode string
|
DNSResponse *dns.Msg
|
||||||
GeoIPCode string
|
DestinationAddressMatchFromResponse bool
|
||||||
ProcessInfo *ConnectionOwner
|
SourceGeoIPCode string
|
||||||
SourceMACAddress net.HardwareAddr
|
GeoIPCode string
|
||||||
SourceHostname string
|
ProcessInfo *ConnectionOwner
|
||||||
QueryType uint16
|
SourceMACAddress net.HardwareAddr
|
||||||
FakeIP bool
|
SourceHostname string
|
||||||
|
QueryType uint16
|
||||||
|
FakeIP bool
|
||||||
|
|
||||||
// rule cache
|
// rule cache
|
||||||
|
|
||||||
@@ -104,6 +108,10 @@ type InboundContext struct {
|
|||||||
func (c *InboundContext) ResetRuleCache() {
|
func (c *InboundContext) ResetRuleCache() {
|
||||||
c.IPCIDRMatchSource = false
|
c.IPCIDRMatchSource = false
|
||||||
c.IPCIDRAcceptEmpty = false
|
c.IPCIDRAcceptEmpty = false
|
||||||
|
c.ResetRuleMatchCache()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *InboundContext) ResetRuleMatchCache() {
|
||||||
c.SourceAddressMatch = false
|
c.SourceAddressMatch = false
|
||||||
c.SourcePortMatch = false
|
c.SourcePortMatch = false
|
||||||
c.DestinationAddressMatch = false
|
c.DestinationAddressMatch = false
|
||||||
@@ -111,6 +119,51 @@ func (c *InboundContext) ResetRuleCache() {
|
|||||||
c.DidMatch = false
|
c.DidMatch = false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *InboundContext) DNSResponseAddressesForMatch() []netip.Addr {
|
||||||
|
return DNSResponseAddresses(c.DNSResponse)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DNSResponseAddresses(response *dns.Msg) []netip.Addr {
|
||||||
|
if response == nil || response.Rcode != dns.RcodeSuccess {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
addresses := make([]netip.Addr, 0, len(response.Answer))
|
||||||
|
for _, rawRecord := range response.Answer {
|
||||||
|
switch record := rawRecord.(type) {
|
||||||
|
case *dns.A:
|
||||||
|
addr := M.AddrFromIP(record.A)
|
||||||
|
if addr.IsValid() {
|
||||||
|
addresses = append(addresses, addr)
|
||||||
|
}
|
||||||
|
case *dns.AAAA:
|
||||||
|
addr := M.AddrFromIP(record.AAAA)
|
||||||
|
if addr.IsValid() {
|
||||||
|
addresses = append(addresses, addr)
|
||||||
|
}
|
||||||
|
case *dns.HTTPS:
|
||||||
|
for _, value := range record.SVCB.Value {
|
||||||
|
switch hint := value.(type) {
|
||||||
|
case *dns.SVCBIPv4Hint:
|
||||||
|
for _, ip := range hint.Hint {
|
||||||
|
addr := M.AddrFromIP(ip).Unmap()
|
||||||
|
if addr.IsValid() {
|
||||||
|
addresses = append(addresses, addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case *dns.SVCBIPv6Hint:
|
||||||
|
for _, ip := range hint.Hint {
|
||||||
|
addr := M.AddrFromIP(ip)
|
||||||
|
if addr.IsValid() {
|
||||||
|
addresses = append(addresses, addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return addresses
|
||||||
|
}
|
||||||
|
|
||||||
type inboundContextKey struct{}
|
type inboundContextKey struct{}
|
||||||
|
|
||||||
func WithContext(ctx context.Context, inboundContext *InboundContext) context.Context {
|
func WithContext(ctx context.Context, inboundContext *InboundContext) context.Context {
|
||||||
|
|||||||
45
adapter/inbound_test.go
Normal file
45
adapter/inbound_test.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
package adapter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDNSResponseAddressesUnmapsHTTPSIPv4Hints(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ipv4Hint := net.ParseIP("1.1.1.1")
|
||||||
|
require.NotNil(t, ipv4Hint)
|
||||||
|
|
||||||
|
response := &dns.Msg{
|
||||||
|
MsgHdr: dns.MsgHdr{
|
||||||
|
Response: true,
|
||||||
|
Rcode: dns.RcodeSuccess,
|
||||||
|
},
|
||||||
|
Answer: []dns.RR{
|
||||||
|
&dns.HTTPS{
|
||||||
|
SVCB: dns.SVCB{
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: dns.Fqdn("example.com"),
|
||||||
|
Rrtype: dns.TypeHTTPS,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
Ttl: 60,
|
||||||
|
},
|
||||||
|
Priority: 1,
|
||||||
|
Target: ".",
|
||||||
|
Value: []dns.SVCBKeyValue{
|
||||||
|
&dns.SVCBIPv4Hint{Hint: []net.IP{ipv4Hint}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
addresses := DNSResponseAddresses(response)
|
||||||
|
require.Equal(t, []netip.Addr{netip.MustParseAddr("1.1.1.1")}, addresses)
|
||||||
|
require.True(t, addresses[0].Is4())
|
||||||
|
}
|
||||||
@@ -51,11 +51,11 @@ type FindConnectionOwnerRequest struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ConnectionOwner struct {
|
type ConnectionOwner struct {
|
||||||
ProcessID uint32
|
ProcessID uint32
|
||||||
UserId int32
|
UserId int32
|
||||||
UserName string
|
UserName string
|
||||||
ProcessPath string
|
ProcessPath string
|
||||||
AndroidPackageName string
|
AndroidPackageNames []string
|
||||||
}
|
}
|
||||||
|
|
||||||
type Notification struct {
|
type Notification struct {
|
||||||
|
|||||||
@@ -2,17 +2,11 @@ package adapter
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
C "github.com/sagernet/sing-box/constant"
|
|
||||||
"github.com/sagernet/sing-tun"
|
"github.com/sagernet/sing-tun"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
"github.com/sagernet/sing/common/ntp"
|
|
||||||
"github.com/sagernet/sing/common/x/list"
|
"github.com/sagernet/sing/common/x/list"
|
||||||
|
|
||||||
"go4.org/netipx"
|
"go4.org/netipx"
|
||||||
@@ -51,7 +45,7 @@ type ConnectionRouterEx interface {
|
|||||||
|
|
||||||
type RuleSet interface {
|
type RuleSet interface {
|
||||||
Name() string
|
Name() string
|
||||||
StartContext(ctx context.Context, startContext *HTTPStartContext) error
|
StartContext(ctx context.Context) error
|
||||||
PostStart() error
|
PostStart() error
|
||||||
Metadata() RuleSetMetadata
|
Metadata() RuleSetMetadata
|
||||||
ExtractIPSet() []*netipx.IPSet
|
ExtractIPSet() []*netipx.IPSet
|
||||||
@@ -66,51 +60,14 @@ type RuleSet interface {
|
|||||||
|
|
||||||
type RuleSetUpdateCallback func(it RuleSet)
|
type RuleSetUpdateCallback func(it RuleSet)
|
||||||
|
|
||||||
|
type DNSRuleSetUpdateValidator interface {
|
||||||
|
ValidateRuleSetMetadataUpdate(tag string, metadata RuleSetMetadata) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// ip_version is not a headless-rule item, so ContainsIPVersionRule is intentionally absent.
|
||||||
type RuleSetMetadata struct {
|
type RuleSetMetadata struct {
|
||||||
ContainsProcessRule bool
|
ContainsProcessRule bool
|
||||||
ContainsWIFIRule bool
|
ContainsWIFIRule bool
|
||||||
ContainsIPCIDRRule bool
|
ContainsIPCIDRRule bool
|
||||||
}
|
ContainsDNSQueryTypeRule bool
|
||||||
type HTTPStartContext struct {
|
|
||||||
ctx context.Context
|
|
||||||
access sync.Mutex
|
|
||||||
httpClientCache map[string]*http.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewHTTPStartContext(ctx context.Context) *HTTPStartContext {
|
|
||||||
return &HTTPStartContext{
|
|
||||||
ctx: ctx,
|
|
||||||
httpClientCache: make(map[string]*http.Client),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HTTPStartContext) HTTPClient(detour string, dialer N.Dialer) *http.Client {
|
|
||||||
c.access.Lock()
|
|
||||||
defer c.access.Unlock()
|
|
||||||
if httpClient, loaded := c.httpClientCache[detour]; loaded {
|
|
||||||
return httpClient
|
|
||||||
}
|
|
||||||
httpClient := &http.Client{
|
|
||||||
Transport: &http.Transport{
|
|
||||||
ForceAttemptHTTP2: true,
|
|
||||||
TLSHandshakeTimeout: C.TCPTimeout,
|
|
||||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
||||||
return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
|
||||||
},
|
|
||||||
TLSClientConfig: &tls.Config{
|
|
||||||
Time: ntp.TimeFuncFromContext(c.ctx),
|
|
||||||
RootCAs: RootPoolFromContext(c.ctx),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
c.httpClientCache[detour] = httpClient
|
|
||||||
return httpClient
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HTTPStartContext) Close() {
|
|
||||||
c.access.Lock()
|
|
||||||
defer c.access.Unlock()
|
|
||||||
for _, client := range c.httpClientCache {
|
|
||||||
client.CloseIdleConnections()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package adapter
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
type HeadlessRule interface {
|
type HeadlessRule interface {
|
||||||
@@ -18,8 +20,9 @@ type Rule interface {
|
|||||||
|
|
||||||
type DNSRule interface {
|
type DNSRule interface {
|
||||||
Rule
|
Rule
|
||||||
|
LegacyPreMatch(metadata *InboundContext) bool
|
||||||
WithAddressLimit() bool
|
WithAddressLimit() bool
|
||||||
MatchAddressLimit(metadata *InboundContext) bool
|
MatchAddressLimit(metadata *InboundContext, response *dns.Msg) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type RuleAction interface {
|
type RuleAction interface {
|
||||||
@@ -29,7 +32,7 @@ type RuleAction interface {
|
|||||||
|
|
||||||
func IsFinalAction(action RuleAction) bool {
|
func IsFinalAction(action RuleAction) bool {
|
||||||
switch action.Type() {
|
switch action.Type() {
|
||||||
case C.RuleActionTypeSniff, C.RuleActionTypeResolve:
|
case C.RuleActionTypeSniff, C.RuleActionTypeResolve, C.RuleActionTypeEvaluate:
|
||||||
return false
|
return false
|
||||||
default:
|
default:
|
||||||
return true
|
return true
|
||||||
|
|||||||
49
adapter/tailscale.go
Normal file
49
adapter/tailscale.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
package adapter
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
type TailscaleEndpoint interface {
|
||||||
|
SubscribeTailscaleStatus(ctx context.Context, fn func(*TailscaleEndpointStatus)) error
|
||||||
|
StartTailscalePing(ctx context.Context, peerIP string, fn func(*TailscalePingResult)) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type TailscalePingResult struct {
|
||||||
|
LatencyMs float64
|
||||||
|
IsDirect bool
|
||||||
|
Endpoint string
|
||||||
|
DERPRegionID int32
|
||||||
|
DERPRegionCode string
|
||||||
|
Error string
|
||||||
|
}
|
||||||
|
|
||||||
|
type TailscaleEndpointStatus struct {
|
||||||
|
BackendState string
|
||||||
|
AuthURL string
|
||||||
|
NetworkName string
|
||||||
|
MagicDNSSuffix string
|
||||||
|
Self *TailscalePeer
|
||||||
|
UserGroups []*TailscaleUserGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
type TailscaleUserGroup struct {
|
||||||
|
UserID int64
|
||||||
|
LoginName string
|
||||||
|
DisplayName string
|
||||||
|
ProfilePicURL string
|
||||||
|
Peers []*TailscalePeer
|
||||||
|
}
|
||||||
|
|
||||||
|
type TailscalePeer struct {
|
||||||
|
HostName string
|
||||||
|
DNSName string
|
||||||
|
OS string
|
||||||
|
TailscaleIPs []string
|
||||||
|
Online bool
|
||||||
|
ExitNode bool
|
||||||
|
ExitNodeOption bool
|
||||||
|
Active bool
|
||||||
|
RxBytes int64
|
||||||
|
TxBytes int64
|
||||||
|
UserID int64
|
||||||
|
KeyExpiry int64
|
||||||
|
}
|
||||||
176
box.go
176
box.go
@@ -9,19 +9,21 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
boxCertificate "github.com/sagernet/sing-box/adapter/certificate"
|
||||||
"github.com/sagernet/sing-box/adapter/endpoint"
|
"github.com/sagernet/sing-box/adapter/endpoint"
|
||||||
"github.com/sagernet/sing-box/adapter/inbound"
|
"github.com/sagernet/sing-box/adapter/inbound"
|
||||||
"github.com/sagernet/sing-box/adapter/outbound"
|
"github.com/sagernet/sing-box/adapter/outbound"
|
||||||
boxService "github.com/sagernet/sing-box/adapter/service"
|
boxService "github.com/sagernet/sing-box/adapter/service"
|
||||||
"github.com/sagernet/sing-box/common/certificate"
|
"github.com/sagernet/sing-box/common/certificate"
|
||||||
"github.com/sagernet/sing-box/common/dialer"
|
"github.com/sagernet/sing-box/common/dialer"
|
||||||
|
"github.com/sagernet/sing-box/common/httpclient"
|
||||||
"github.com/sagernet/sing-box/common/taskmonitor"
|
"github.com/sagernet/sing-box/common/taskmonitor"
|
||||||
"github.com/sagernet/sing-box/common/tls"
|
"github.com/sagernet/sing-box/common/tls"
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
"github.com/sagernet/sing-box/dns"
|
"github.com/sagernet/sing-box/dns"
|
||||||
"github.com/sagernet/sing-box/dns/transport/local"
|
|
||||||
"github.com/sagernet/sing-box/experimental"
|
"github.com/sagernet/sing-box/experimental"
|
||||||
"github.com/sagernet/sing-box/experimental/cachefile"
|
"github.com/sagernet/sing-box/experimental/cachefile"
|
||||||
|
"github.com/sagernet/sing-box/experimental/deprecated"
|
||||||
"github.com/sagernet/sing-box/log"
|
"github.com/sagernet/sing-box/log"
|
||||||
"github.com/sagernet/sing-box/option"
|
"github.com/sagernet/sing-box/option"
|
||||||
"github.com/sagernet/sing-box/protocol/direct"
|
"github.com/sagernet/sing-box/protocol/direct"
|
||||||
@@ -37,20 +39,22 @@ import (
|
|||||||
var _ adapter.SimpleLifecycle = (*Box)(nil)
|
var _ adapter.SimpleLifecycle = (*Box)(nil)
|
||||||
|
|
||||||
type Box struct {
|
type Box struct {
|
||||||
createdAt time.Time
|
createdAt time.Time
|
||||||
logFactory log.Factory
|
logFactory log.Factory
|
||||||
logger log.ContextLogger
|
logger log.ContextLogger
|
||||||
network *route.NetworkManager
|
network *route.NetworkManager
|
||||||
endpoint *endpoint.Manager
|
endpoint *endpoint.Manager
|
||||||
inbound *inbound.Manager
|
inbound *inbound.Manager
|
||||||
outbound *outbound.Manager
|
outbound *outbound.Manager
|
||||||
service *boxService.Manager
|
service *boxService.Manager
|
||||||
dnsTransport *dns.TransportManager
|
certificateProvider *boxCertificate.Manager
|
||||||
dnsRouter *dns.Router
|
dnsTransport *dns.TransportManager
|
||||||
connection *route.ConnectionManager
|
dnsRouter *dns.Router
|
||||||
router *route.Router
|
connection *route.ConnectionManager
|
||||||
internalService []adapter.LifecycleService
|
router *route.Router
|
||||||
done chan struct{}
|
httpClientService adapter.LifecycleService
|
||||||
|
internalService []adapter.LifecycleService
|
||||||
|
done chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
type Options struct {
|
type Options struct {
|
||||||
@@ -66,6 +70,7 @@ func Context(
|
|||||||
endpointRegistry adapter.EndpointRegistry,
|
endpointRegistry adapter.EndpointRegistry,
|
||||||
dnsTransportRegistry adapter.DNSTransportRegistry,
|
dnsTransportRegistry adapter.DNSTransportRegistry,
|
||||||
serviceRegistry adapter.ServiceRegistry,
|
serviceRegistry adapter.ServiceRegistry,
|
||||||
|
certificateProviderRegistry adapter.CertificateProviderRegistry,
|
||||||
) context.Context {
|
) context.Context {
|
||||||
if service.FromContext[option.InboundOptionsRegistry](ctx) == nil ||
|
if service.FromContext[option.InboundOptionsRegistry](ctx) == nil ||
|
||||||
service.FromContext[adapter.InboundRegistry](ctx) == nil {
|
service.FromContext[adapter.InboundRegistry](ctx) == nil {
|
||||||
@@ -90,6 +95,10 @@ func Context(
|
|||||||
ctx = service.ContextWith[option.ServiceOptionsRegistry](ctx, serviceRegistry)
|
ctx = service.ContextWith[option.ServiceOptionsRegistry](ctx, serviceRegistry)
|
||||||
ctx = service.ContextWith[adapter.ServiceRegistry](ctx, serviceRegistry)
|
ctx = service.ContextWith[adapter.ServiceRegistry](ctx, serviceRegistry)
|
||||||
}
|
}
|
||||||
|
if service.FromContext[adapter.CertificateProviderRegistry](ctx) == nil {
|
||||||
|
ctx = service.ContextWith[option.CertificateProviderOptionsRegistry](ctx, certificateProviderRegistry)
|
||||||
|
ctx = service.ContextWith[adapter.CertificateProviderRegistry](ctx, certificateProviderRegistry)
|
||||||
|
}
|
||||||
return ctx
|
return ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -106,6 +115,7 @@ func New(options Options) (*Box, error) {
|
|||||||
outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx)
|
outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx)
|
||||||
dnsTransportRegistry := service.FromContext[adapter.DNSTransportRegistry](ctx)
|
dnsTransportRegistry := service.FromContext[adapter.DNSTransportRegistry](ctx)
|
||||||
serviceRegistry := service.FromContext[adapter.ServiceRegistry](ctx)
|
serviceRegistry := service.FromContext[adapter.ServiceRegistry](ctx)
|
||||||
|
certificateProviderRegistry := service.FromContext[adapter.CertificateProviderRegistry](ctx)
|
||||||
|
|
||||||
if endpointRegistry == nil {
|
if endpointRegistry == nil {
|
||||||
return nil, E.New("missing endpoint registry in context")
|
return nil, E.New("missing endpoint registry in context")
|
||||||
@@ -122,6 +132,9 @@ func New(options Options) (*Box, error) {
|
|||||||
if serviceRegistry == nil {
|
if serviceRegistry == nil {
|
||||||
return nil, E.New("missing service registry in context")
|
return nil, E.New("missing service registry in context")
|
||||||
}
|
}
|
||||||
|
if certificateProviderRegistry == nil {
|
||||||
|
return nil, E.New("missing certificate provider registry in context")
|
||||||
|
}
|
||||||
|
|
||||||
ctx = pause.WithDefaultManager(ctx)
|
ctx = pause.WithDefaultManager(ctx)
|
||||||
experimentalOptions := common.PtrValueOrDefault(options.Experimental)
|
experimentalOptions := common.PtrValueOrDefault(options.Experimental)
|
||||||
@@ -159,6 +172,7 @@ func New(options Options) (*Box, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var internalServices []adapter.LifecycleService
|
var internalServices []adapter.LifecycleService
|
||||||
|
routeOptions := common.PtrValueOrDefault(options.Route)
|
||||||
certificateOptions := common.PtrValueOrDefault(options.Certificate)
|
certificateOptions := common.PtrValueOrDefault(options.Certificate)
|
||||||
if C.IsAndroid || certificateOptions.Store != "" && certificateOptions.Store != C.CertificateStoreSystem ||
|
if C.IsAndroid || certificateOptions.Store != "" && certificateOptions.Store != C.CertificateStoreSystem ||
|
||||||
len(certificateOptions.Certificate) > 0 ||
|
len(certificateOptions.Certificate) > 0 ||
|
||||||
@@ -171,21 +185,25 @@ func New(options Options) (*Box, error) {
|
|||||||
service.MustRegister[adapter.CertificateStore](ctx, certificateStore)
|
service.MustRegister[adapter.CertificateStore](ctx, certificateStore)
|
||||||
internalServices = append(internalServices, certificateStore)
|
internalServices = append(internalServices, certificateStore)
|
||||||
}
|
}
|
||||||
|
|
||||||
routeOptions := common.PtrValueOrDefault(options.Route)
|
|
||||||
dnsOptions := common.PtrValueOrDefault(options.DNS)
|
dnsOptions := common.PtrValueOrDefault(options.DNS)
|
||||||
endpointManager := endpoint.NewManager(logFactory.NewLogger("endpoint"), endpointRegistry)
|
endpointManager := endpoint.NewManager(logFactory.NewLogger("endpoint"), endpointRegistry)
|
||||||
inboundManager := inbound.NewManager(logFactory.NewLogger("inbound"), inboundRegistry, endpointManager)
|
inboundManager := inbound.NewManager(logFactory.NewLogger("inbound"), inboundRegistry, endpointManager)
|
||||||
outboundManager := outbound.NewManager(logFactory.NewLogger("outbound"), outboundRegistry, endpointManager, routeOptions.Final)
|
outboundManager := outbound.NewManager(logFactory.NewLogger("outbound"), outboundRegistry, endpointManager, routeOptions.Final)
|
||||||
dnsTransportManager := dns.NewTransportManager(logFactory.NewLogger("dns/transport"), dnsTransportRegistry, outboundManager, dnsOptions.Final)
|
dnsTransportManager := dns.NewTransportManager(logFactory.NewLogger("dns/transport"), dnsTransportRegistry, outboundManager, dnsOptions.Final)
|
||||||
serviceManager := boxService.NewManager(logFactory.NewLogger("service"), serviceRegistry)
|
serviceManager := boxService.NewManager(logFactory.NewLogger("service"), serviceRegistry)
|
||||||
|
certificateProviderManager := boxCertificate.NewManager(logFactory.NewLogger("certificate-provider"), certificateProviderRegistry)
|
||||||
service.MustRegister[adapter.EndpointManager](ctx, endpointManager)
|
service.MustRegister[adapter.EndpointManager](ctx, endpointManager)
|
||||||
service.MustRegister[adapter.InboundManager](ctx, inboundManager)
|
service.MustRegister[adapter.InboundManager](ctx, inboundManager)
|
||||||
service.MustRegister[adapter.OutboundManager](ctx, outboundManager)
|
service.MustRegister[adapter.OutboundManager](ctx, outboundManager)
|
||||||
service.MustRegister[adapter.DNSTransportManager](ctx, dnsTransportManager)
|
service.MustRegister[adapter.DNSTransportManager](ctx, dnsTransportManager)
|
||||||
service.MustRegister[adapter.ServiceManager](ctx, serviceManager)
|
service.MustRegister[adapter.ServiceManager](ctx, serviceManager)
|
||||||
dnsRouter := dns.NewRouter(ctx, logFactory, dnsOptions)
|
service.MustRegister[adapter.CertificateProviderManager](ctx, certificateProviderManager)
|
||||||
|
dnsRouter, err := dns.NewRouter(ctx, logFactory, dnsOptions)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "initialize DNS router")
|
||||||
|
}
|
||||||
service.MustRegister[adapter.DNSRouter](ctx, dnsRouter)
|
service.MustRegister[adapter.DNSRouter](ctx, dnsRouter)
|
||||||
|
service.MustRegister[adapter.DNSRuleSetUpdateValidator](ctx, dnsRouter)
|
||||||
networkManager, err := route.NewNetworkManager(ctx, logFactory.NewLogger("network"), routeOptions, dnsOptions)
|
networkManager, err := route.NewNetworkManager(ctx, logFactory.NewLogger("network"), routeOptions, dnsOptions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, E.Cause(err, "initialize network manager")
|
return nil, E.Cause(err, "initialize network manager")
|
||||||
@@ -193,6 +211,10 @@ func New(options Options) (*Box, error) {
|
|||||||
service.MustRegister[adapter.NetworkManager](ctx, networkManager)
|
service.MustRegister[adapter.NetworkManager](ctx, networkManager)
|
||||||
connectionManager := route.NewConnectionManager(logFactory.NewLogger("connection"))
|
connectionManager := route.NewConnectionManager(logFactory.NewLogger("connection"))
|
||||||
service.MustRegister[adapter.ConnectionManager](ctx, connectionManager)
|
service.MustRegister[adapter.ConnectionManager](ctx, connectionManager)
|
||||||
|
// Must register after ConnectionManager: the Apple HTTP engine's proxy bridge reads it from the context when Manager.Start resolves the default client.
|
||||||
|
httpClientManager := httpclient.NewManager(ctx, logFactory.NewLogger("httpclient"), options.HTTPClients, routeOptions.DefaultHTTPClient)
|
||||||
|
service.MustRegister[adapter.HTTPClientManager](ctx, httpClientManager)
|
||||||
|
httpClientService := adapter.LifecycleService(httpClientManager)
|
||||||
router := route.NewRouter(ctx, logFactory, routeOptions, dnsOptions)
|
router := route.NewRouter(ctx, logFactory, routeOptions, dnsOptions)
|
||||||
service.MustRegister[adapter.Router](ctx, router)
|
service.MustRegister[adapter.Router](ctx, router)
|
||||||
err = router.Initialize(routeOptions.Rules, routeOptions.RuleSet)
|
err = router.Initialize(routeOptions.Rules, routeOptions.RuleSet)
|
||||||
@@ -272,6 +294,24 @@ func New(options Options) (*Box, error) {
|
|||||||
return nil, E.Cause(err, "initialize inbound[", i, "]")
|
return nil, E.Cause(err, "initialize inbound[", i, "]")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for i, serviceOptions := range options.Services {
|
||||||
|
var tag string
|
||||||
|
if serviceOptions.Tag != "" {
|
||||||
|
tag = serviceOptions.Tag
|
||||||
|
} else {
|
||||||
|
tag = F.ToString(i)
|
||||||
|
}
|
||||||
|
err = serviceManager.Create(
|
||||||
|
ctx,
|
||||||
|
logFactory.NewLogger(F.ToString("service/", serviceOptions.Type, "[", tag, "]")),
|
||||||
|
tag,
|
||||||
|
serviceOptions.Type,
|
||||||
|
serviceOptions.Options,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "initialize service[", i, "]")
|
||||||
|
}
|
||||||
|
}
|
||||||
for i, outboundOptions := range options.Outbounds {
|
for i, outboundOptions := range options.Outbounds {
|
||||||
var tag string
|
var tag string
|
||||||
if outboundOptions.Tag != "" {
|
if outboundOptions.Tag != "" {
|
||||||
@@ -298,22 +338,22 @@ func New(options Options) (*Box, error) {
|
|||||||
return nil, E.Cause(err, "initialize outbound[", i, "]")
|
return nil, E.Cause(err, "initialize outbound[", i, "]")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for i, serviceOptions := range options.Services {
|
for i, certificateProviderOptions := range options.CertificateProviders {
|
||||||
var tag string
|
var tag string
|
||||||
if serviceOptions.Tag != "" {
|
if certificateProviderOptions.Tag != "" {
|
||||||
tag = serviceOptions.Tag
|
tag = certificateProviderOptions.Tag
|
||||||
} else {
|
} else {
|
||||||
tag = F.ToString(i)
|
tag = F.ToString(i)
|
||||||
}
|
}
|
||||||
err = serviceManager.Create(
|
err = certificateProviderManager.Create(
|
||||||
ctx,
|
ctx,
|
||||||
logFactory.NewLogger(F.ToString("service/", serviceOptions.Type, "[", tag, "]")),
|
logFactory.NewLogger(F.ToString("certificate-provider/", certificateProviderOptions.Type, "[", tag, "]")),
|
||||||
tag,
|
tag,
|
||||||
serviceOptions.Type,
|
certificateProviderOptions.Type,
|
||||||
serviceOptions.Options,
|
certificateProviderOptions.Options,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, E.Cause(err, "initialize service[", i, "]")
|
return nil, E.Cause(err, "initialize certificate provider[", i, "]")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
outboundManager.Initialize(func() (adapter.Outbound, error) {
|
outboundManager.Initialize(func() (adapter.Outbound, error) {
|
||||||
@@ -326,13 +366,20 @@ func New(options Options) (*Box, error) {
|
|||||||
)
|
)
|
||||||
})
|
})
|
||||||
dnsTransportManager.Initialize(func() (adapter.DNSTransport, error) {
|
dnsTransportManager.Initialize(func() (adapter.DNSTransport, error) {
|
||||||
return local.NewTransport(
|
return dnsTransportRegistry.CreateDNSTransport(
|
||||||
ctx,
|
ctx,
|
||||||
logFactory.NewLogger("dns/local"),
|
logFactory.NewLogger("dns/local"),
|
||||||
"local",
|
"local",
|
||||||
option.LocalDNSServerOptions{},
|
C.DNSTypeLocal,
|
||||||
|
&option.LocalDNSServerOptions{},
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
httpClientManager.Initialize(func() (*httpclient.Transport, error) {
|
||||||
|
deprecated.Report(ctx, deprecated.OptionImplicitDefaultHTTPClient)
|
||||||
|
var httpClientOptions option.HTTPClientOptions
|
||||||
|
httpClientOptions.DefaultOutbound = true
|
||||||
|
return httpclient.NewTransport(ctx, logFactory.NewLogger("httpclient"), "", httpClientOptions)
|
||||||
|
})
|
||||||
if platformInterface != nil {
|
if platformInterface != nil {
|
||||||
err = platformInterface.Initialize(networkManager)
|
err = platformInterface.Initialize(networkManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -340,7 +387,7 @@ func New(options Options) (*Box, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if needCacheFile {
|
if needCacheFile {
|
||||||
cacheFile := cachefile.New(ctx, common.PtrValueOrDefault(experimentalOptions.CacheFile))
|
cacheFile := cachefile.New(ctx, logFactory.NewLogger("cache-file"), common.PtrValueOrDefault(experimentalOptions.CacheFile))
|
||||||
service.MustRegister[adapter.CacheFile](ctx, cacheFile)
|
service.MustRegister[adapter.CacheFile](ctx, cacheFile)
|
||||||
internalServices = append(internalServices, cacheFile)
|
internalServices = append(internalServices, cacheFile)
|
||||||
}
|
}
|
||||||
@@ -383,20 +430,22 @@ func New(options Options) (*Box, error) {
|
|||||||
internalServices = append(internalServices, adapter.NewLifecycleService(ntpService, "ntp service"))
|
internalServices = append(internalServices, adapter.NewLifecycleService(ntpService, "ntp service"))
|
||||||
}
|
}
|
||||||
return &Box{
|
return &Box{
|
||||||
network: networkManager,
|
network: networkManager,
|
||||||
endpoint: endpointManager,
|
endpoint: endpointManager,
|
||||||
inbound: inboundManager,
|
inbound: inboundManager,
|
||||||
outbound: outboundManager,
|
outbound: outboundManager,
|
||||||
dnsTransport: dnsTransportManager,
|
dnsTransport: dnsTransportManager,
|
||||||
service: serviceManager,
|
service: serviceManager,
|
||||||
dnsRouter: dnsRouter,
|
certificateProvider: certificateProviderManager,
|
||||||
connection: connectionManager,
|
dnsRouter: dnsRouter,
|
||||||
router: router,
|
connection: connectionManager,
|
||||||
createdAt: createdAt,
|
router: router,
|
||||||
logFactory: logFactory,
|
httpClientService: httpClientService,
|
||||||
logger: logFactory.Logger(),
|
createdAt: createdAt,
|
||||||
internalService: internalServices,
|
logFactory: logFactory,
|
||||||
done: make(chan struct{}),
|
logger: logFactory.Logger(),
|
||||||
|
internalService: internalServices,
|
||||||
|
done: make(chan struct{}),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -450,11 +499,19 @@ func (s *Box) preStart() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = adapter.Start(s.logger, adapter.StartStateInitialize, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint, s.service)
|
err = adapter.Start(s.logger, adapter.StartStateInitialize, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint, s.service, s.certificateProvider)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = adapter.Start(s.logger, adapter.StartStateStart, s.outbound, s.dnsTransport, s.dnsRouter, s.network, s.connection, s.router)
|
err = adapter.Start(s.logger, adapter.StartStateStart, s.outbound, s.dnsTransport, s.network, s.connection)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = adapter.StartNamed(s.logger, adapter.StartStateStart, []adapter.LifecycleService{s.httpClientService})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = adapter.Start(s.logger, adapter.StartStateStart, s.router, s.dnsRouter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -470,11 +527,19 @@ func (s *Box) start() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = adapter.Start(s.logger, adapter.StartStateStart, s.inbound, s.endpoint, s.service)
|
err = adapter.Start(s.logger, adapter.StartStateStart, s.endpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = adapter.Start(s.logger, adapter.StartStatePostStart, s.outbound, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.inbound, s.endpoint, s.service)
|
err = adapter.Start(s.logger, adapter.StartStateStart, s.certificateProvider)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = adapter.Start(s.logger, adapter.StartStateStart, s.inbound, s.service)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = adapter.Start(s.logger, adapter.StartStatePostStart, s.outbound, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.endpoint, s.certificateProvider, s.inbound, s.service)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -482,7 +547,7 @@ func (s *Box) start() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = adapter.Start(s.logger, adapter.StartStateStarted, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint, s.service)
|
err = adapter.Start(s.logger, adapter.StartStateStarted, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.endpoint, s.certificateProvider, s.inbound, s.service)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -506,8 +571,9 @@ func (s *Box) Close() error {
|
|||||||
service adapter.Lifecycle
|
service adapter.Lifecycle
|
||||||
}{
|
}{
|
||||||
{"service", s.service},
|
{"service", s.service},
|
||||||
{"endpoint", s.endpoint},
|
|
||||||
{"inbound", s.inbound},
|
{"inbound", s.inbound},
|
||||||
|
{"certificate-provider", s.certificateProvider},
|
||||||
|
{"endpoint", s.endpoint},
|
||||||
{"outbound", s.outbound},
|
{"outbound", s.outbound},
|
||||||
{"router", s.router},
|
{"router", s.router},
|
||||||
{"connection", s.connection},
|
{"connection", s.connection},
|
||||||
@@ -522,6 +588,14 @@ func (s *Box) Close() error {
|
|||||||
})
|
})
|
||||||
s.logger.Trace("close ", closeItem.name, " completed (", F.Seconds(time.Since(startTime).Seconds()), "s)")
|
s.logger.Trace("close ", closeItem.name, " completed (", F.Seconds(time.Since(startTime).Seconds()), "s)")
|
||||||
}
|
}
|
||||||
|
if s.httpClientService != nil {
|
||||||
|
s.logger.Trace("close ", s.httpClientService.Name())
|
||||||
|
startTime := time.Now()
|
||||||
|
err = E.Append(err, s.httpClientService.Close(), func(err error) error {
|
||||||
|
return E.Cause(err, "close ", s.httpClientService.Name())
|
||||||
|
})
|
||||||
|
s.logger.Trace("close ", s.httpClientService.Name(), " completed (", F.Seconds(time.Since(startTime).Seconds()), "s)")
|
||||||
|
}
|
||||||
for _, lifecycleService := range s.internalService {
|
for _, lifecycleService := range s.internalService {
|
||||||
s.logger.Trace("close ", lifecycleService.Name())
|
s.logger.Trace("close ", lifecycleService.Name())
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
@@ -555,6 +629,10 @@ func (s *Box) Outbound() adapter.OutboundManager {
|
|||||||
return s.outbound
|
return s.outbound
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Box) Endpoint() adapter.EndpointManager {
|
||||||
|
return s.endpoint
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Box) LogFactory() log.Factory {
|
func (s *Box) LogFactory() log.Factory {
|
||||||
return s.logFactory
|
return s.logFactory
|
||||||
}
|
}
|
||||||
|
|||||||
Submodule clients/android updated: 6f09892c71...fea0f3a7ba
Submodule clients/apple updated: f3b4b2238e...ffbf405b52
@@ -204,6 +204,9 @@ func buildApple() {
|
|||||||
"-target", bindTarget,
|
"-target", bindTarget,
|
||||||
"-libname=box",
|
"-libname=box",
|
||||||
"-tags-not-macos=with_low_memory",
|
"-tags-not-macos=with_low_memory",
|
||||||
|
"-iosversion=15.0",
|
||||||
|
"-macosversion=13.0",
|
||||||
|
"-tvosversion=17.0",
|
||||||
}
|
}
|
||||||
//if !withTailscale {
|
//if !withTailscale {
|
||||||
// args = append(args, "-tags-macos="+strings.Join(memcTags, ","))
|
// args = append(args, "-tags-macos="+strings.Join(memcTags, ","))
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/sagernet/sing-box/log"
|
"github.com/sagernet/sing-box/log"
|
||||||
@@ -35,21 +36,9 @@ func updateMozillaIncludedRootCAs() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
geoIndex := slices.Index(header, "Geographic Focus")
|
geoIndex := slices.Index(header, "Geographic Focus")
|
||||||
nameIndex := slices.Index(header, "Common Name or Certificate Name")
|
|
||||||
certIndex := slices.Index(header, "PEM Info")
|
certIndex := slices.Index(header, "PEM Info")
|
||||||
|
|
||||||
generated := strings.Builder{}
|
pemBundle := strings.Builder{}
|
||||||
generated.WriteString(`// Code generated by 'make update_certificates'. DO NOT EDIT.
|
|
||||||
|
|
||||||
package certificate
|
|
||||||
|
|
||||||
import "crypto/x509"
|
|
||||||
|
|
||||||
var mozillaIncluded *x509.CertPool
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
mozillaIncluded = x509.NewCertPool()
|
|
||||||
`)
|
|
||||||
for {
|
for {
|
||||||
record, err := reader.Read()
|
record, err := reader.Read()
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
@@ -60,18 +49,12 @@ func init() {
|
|||||||
if record[geoIndex] == "China" {
|
if record[geoIndex] == "China" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
generated.WriteString("\n // ")
|
|
||||||
generated.WriteString(record[nameIndex])
|
|
||||||
generated.WriteString("\n")
|
|
||||||
generated.WriteString(" mozillaIncluded.AppendCertsFromPEM([]byte(`")
|
|
||||||
cert := record[certIndex]
|
cert := record[certIndex]
|
||||||
// Remove single quotes
|
|
||||||
cert = cert[1 : len(cert)-1]
|
cert = cert[1 : len(cert)-1]
|
||||||
generated.WriteString(cert)
|
pemBundle.WriteString(cert)
|
||||||
generated.WriteString("`))\n")
|
pemBundle.WriteString("\n")
|
||||||
}
|
}
|
||||||
generated.WriteString("}\n")
|
return writeGeneratedCertificateBundle("mozilla", "mozillaIncluded", pemBundle.String())
|
||||||
return os.WriteFile("common/certificate/mozilla.go", []byte(generated.String()), 0o644)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchChinaFingerprints() (map[string]bool, error) {
|
func fetchChinaFingerprints() (map[string]bool, error) {
|
||||||
@@ -119,23 +102,11 @@ func updateChromeIncludedRootCAs() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
subjectIndex := slices.Index(header, "Subject")
|
|
||||||
statusIndex := slices.Index(header, "Google Chrome Status")
|
statusIndex := slices.Index(header, "Google Chrome Status")
|
||||||
certIndex := slices.Index(header, "X.509 Certificate (PEM)")
|
certIndex := slices.Index(header, "X.509 Certificate (PEM)")
|
||||||
fingerprintIndex := slices.Index(header, "SHA-256 Fingerprint")
|
fingerprintIndex := slices.Index(header, "SHA-256 Fingerprint")
|
||||||
|
|
||||||
generated := strings.Builder{}
|
pemBundle := strings.Builder{}
|
||||||
generated.WriteString(`// Code generated by 'make update_certificates'. DO NOT EDIT.
|
|
||||||
|
|
||||||
package certificate
|
|
||||||
|
|
||||||
import "crypto/x509"
|
|
||||||
|
|
||||||
var chromeIncluded *x509.CertPool
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
chromeIncluded = x509.NewCertPool()
|
|
||||||
`)
|
|
||||||
for {
|
for {
|
||||||
record, err := reader.Read()
|
record, err := reader.Read()
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
@@ -149,18 +120,39 @@ func init() {
|
|||||||
if chinaFingerprints[record[fingerprintIndex]] {
|
if chinaFingerprints[record[fingerprintIndex]] {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
generated.WriteString("\n // ")
|
|
||||||
generated.WriteString(record[subjectIndex])
|
|
||||||
generated.WriteString("\n")
|
|
||||||
generated.WriteString(" chromeIncluded.AppendCertsFromPEM([]byte(`")
|
|
||||||
cert := record[certIndex]
|
cert := record[certIndex]
|
||||||
// Remove single quotes if present
|
|
||||||
if len(cert) > 0 && cert[0] == '\'' {
|
if len(cert) > 0 && cert[0] == '\'' {
|
||||||
cert = cert[1 : len(cert)-1]
|
cert = cert[1 : len(cert)-1]
|
||||||
}
|
}
|
||||||
generated.WriteString(cert)
|
pemBundle.WriteString(cert)
|
||||||
generated.WriteString("`))\n")
|
pemBundle.WriteString("\n")
|
||||||
}
|
}
|
||||||
generated.WriteString("}\n")
|
return writeGeneratedCertificateBundle("chrome", "chromeIncluded", pemBundle.String())
|
||||||
return os.WriteFile("common/certificate/chrome.go", []byte(generated.String()), 0o644)
|
}
|
||||||
|
|
||||||
|
func writeGeneratedCertificateBundle(name string, variableName string, pemBundle string) error {
|
||||||
|
goSource := `// Code generated by 'make update_certificates'. DO NOT EDIT.
|
||||||
|
|
||||||
|
package certificate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
_ "embed"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed ` + name + `.pem
|
||||||
|
var ` + variableName + `PEM string
|
||||||
|
|
||||||
|
var ` + variableName + ` *x509.CertPool
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
` + variableName + ` = x509.NewCertPool()
|
||||||
|
` + variableName + `.AppendCertsFromPEM([]byte(` + variableName + `PEM))
|
||||||
|
}
|
||||||
|
`
|
||||||
|
err := os.WriteFile(filepath.Join("common/certificate", name+".pem"), []byte(pemBundle), 0o644)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return os.WriteFile(filepath.Join("common/certificate", name+".go"), []byte(goSource), 0o644)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -82,6 +82,11 @@ func compileRuleSet(sourcePath string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func downgradeRuleSetVersion(version uint8, options option.PlainRuleSet) uint8 {
|
func downgradeRuleSetVersion(version uint8, options option.PlainRuleSet) uint8 {
|
||||||
|
if version == C.RuleSetVersion5 && !rule.HasHeadlessRule(options.Rules, func(rule option.DefaultHeadlessRule) bool {
|
||||||
|
return len(rule.PackageNameRegex) > 0
|
||||||
|
}) {
|
||||||
|
version = C.RuleSetVersion4
|
||||||
|
}
|
||||||
if version == C.RuleSetVersion4 && !rule.HasHeadlessRule(options.Rules, func(rule option.DefaultHeadlessRule) bool {
|
if version == C.RuleSetVersion4 && !rule.HasHeadlessRule(options.Rules, func(rule option.DefaultHeadlessRule) bool {
|
||||||
return rule.NetworkInterfaceAddress != nil && rule.NetworkInterfaceAddress.Size() > 0 ||
|
return rule.NetworkInterfaceAddress != nil && rule.NetworkInterfaceAddress.Size() > 0 ||
|
||||||
len(rule.DefaultInterfaceAddress) > 0
|
len(rule.DefaultInterfaceAddress) > 0
|
||||||
|
|||||||
121
cmd/sing-box/cmd_tools_networkquality.go
Normal file
121
cmd/sing-box/cmd_tools_networkquality.go
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/common/networkquality"
|
||||||
|
"github.com/sagernet/sing-box/log"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
commandNetworkQualityFlagConfigURL string
|
||||||
|
commandNetworkQualityFlagSerial bool
|
||||||
|
commandNetworkQualityFlagMaxRuntime int
|
||||||
|
commandNetworkQualityFlagHTTP3 bool
|
||||||
|
)
|
||||||
|
|
||||||
|
var commandNetworkQuality = &cobra.Command{
|
||||||
|
Use: "networkquality",
|
||||||
|
Short: "Run a network quality test",
|
||||||
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
|
err := runNetworkQuality()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
commandNetworkQuality.Flags().StringVar(
|
||||||
|
&commandNetworkQualityFlagConfigURL,
|
||||||
|
"config-url", "",
|
||||||
|
"Network quality test config URL (default: Apple mensura)",
|
||||||
|
)
|
||||||
|
commandNetworkQuality.Flags().BoolVar(
|
||||||
|
&commandNetworkQualityFlagSerial,
|
||||||
|
"serial", false,
|
||||||
|
"Run download and upload tests sequentially instead of in parallel",
|
||||||
|
)
|
||||||
|
commandNetworkQuality.Flags().IntVar(
|
||||||
|
&commandNetworkQualityFlagMaxRuntime,
|
||||||
|
"max-runtime", int(networkquality.DefaultMaxRuntime/time.Second),
|
||||||
|
"Network quality maximum runtime in seconds",
|
||||||
|
)
|
||||||
|
commandNetworkQuality.Flags().BoolVar(
|
||||||
|
&commandNetworkQualityFlagHTTP3,
|
||||||
|
"http3", false,
|
||||||
|
"Use HTTP/3 (QUIC) for measurement traffic",
|
||||||
|
)
|
||||||
|
commandTools.AddCommand(commandNetworkQuality)
|
||||||
|
}
|
||||||
|
|
||||||
|
func runNetworkQuality() error {
|
||||||
|
instance, err := createPreStartedClient()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer instance.Close()
|
||||||
|
|
||||||
|
dialer, err := createDialer(instance, commandToolsFlagOutbound)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
httpClient := networkquality.NewHTTPClient(dialer)
|
||||||
|
defer httpClient.CloseIdleConnections()
|
||||||
|
|
||||||
|
measurementClientFactory, err := networkquality.NewOptionalHTTP3Factory(dialer, commandNetworkQualityFlagHTTP3)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintln(os.Stderr, "==== NETWORK QUALITY TEST ====")
|
||||||
|
|
||||||
|
result, err := networkquality.Run(networkquality.Options{
|
||||||
|
ConfigURL: commandNetworkQualityFlagConfigURL,
|
||||||
|
HTTPClient: httpClient,
|
||||||
|
NewMeasurementClient: measurementClientFactory,
|
||||||
|
Serial: commandNetworkQualityFlagSerial,
|
||||||
|
MaxRuntime: time.Duration(commandNetworkQualityFlagMaxRuntime) * time.Second,
|
||||||
|
Context: globalCtx,
|
||||||
|
OnProgress: func(p networkquality.Progress) {
|
||||||
|
if !commandNetworkQualityFlagSerial && p.Phase != networkquality.PhaseIdle {
|
||||||
|
fmt.Fprintf(os.Stderr, "\rDownload: %s RPM: %d Upload: %s RPM: %d",
|
||||||
|
networkquality.FormatBitrate(p.DownloadCapacity), p.DownloadRPM,
|
||||||
|
networkquality.FormatBitrate(p.UploadCapacity), p.UploadRPM)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch networkquality.Phase(p.Phase) {
|
||||||
|
case networkquality.PhaseIdle:
|
||||||
|
if p.IdleLatencyMs > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "\rIdle Latency: %d ms", p.IdleLatencyMs)
|
||||||
|
} else {
|
||||||
|
fmt.Fprint(os.Stderr, "\rMeasuring idle latency...")
|
||||||
|
}
|
||||||
|
case networkquality.PhaseDownload:
|
||||||
|
fmt.Fprintf(os.Stderr, "\rDownload: %s RPM: %d",
|
||||||
|
networkquality.FormatBitrate(p.DownloadCapacity), p.DownloadRPM)
|
||||||
|
case networkquality.PhaseUpload:
|
||||||
|
fmt.Fprintf(os.Stderr, "\rUpload: %s RPM: %d",
|
||||||
|
networkquality.FormatBitrate(p.UploadCapacity), p.UploadRPM)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintln(os.Stderr)
|
||||||
|
fmt.Fprintln(os.Stderr, strings.Repeat("-", 40))
|
||||||
|
fmt.Fprintf(os.Stderr, "Idle Latency: %d ms\n", result.IdleLatencyMs)
|
||||||
|
fmt.Fprintf(os.Stderr, "Download Capacity: %-20s Accuracy: %s\n", networkquality.FormatBitrate(result.DownloadCapacity), result.DownloadCapacityAccuracy)
|
||||||
|
fmt.Fprintf(os.Stderr, "Upload Capacity: %-20s Accuracy: %s\n", networkquality.FormatBitrate(result.UploadCapacity), result.UploadCapacityAccuracy)
|
||||||
|
fmt.Fprintf(os.Stderr, "Download Responsiveness: %-20s Accuracy: %s\n", fmt.Sprintf("%d RPM", result.DownloadRPM), result.DownloadRPMAccuracy)
|
||||||
|
fmt.Fprintf(os.Stderr, "Upload Responsiveness: %-20s Accuracy: %s\n", fmt.Sprintf("%d RPM", result.UploadRPM), result.UploadRPMAccuracy)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
79
cmd/sing-box/cmd_tools_stun.go
Normal file
79
cmd/sing-box/cmd_tools_stun.go
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/common/stun"
|
||||||
|
"github.com/sagernet/sing-box/log"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
var commandSTUNFlagServer string
|
||||||
|
|
||||||
|
var commandSTUN = &cobra.Command{
|
||||||
|
Use: "stun",
|
||||||
|
Short: "Run a STUN test",
|
||||||
|
Args: cobra.NoArgs,
|
||||||
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
|
err := runSTUN()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
commandSTUN.Flags().StringVarP(&commandSTUNFlagServer, "server", "s", stun.DefaultServer, "STUN server address")
|
||||||
|
commandTools.AddCommand(commandSTUN)
|
||||||
|
}
|
||||||
|
|
||||||
|
func runSTUN() error {
|
||||||
|
instance, err := createPreStartedClient()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer instance.Close()
|
||||||
|
|
||||||
|
dialer, err := createDialer(instance, commandToolsFlagOutbound)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintln(os.Stderr, "==== STUN TEST ====")
|
||||||
|
|
||||||
|
result, err := stun.Run(stun.Options{
|
||||||
|
Server: commandSTUNFlagServer,
|
||||||
|
Dialer: dialer,
|
||||||
|
Context: globalCtx,
|
||||||
|
OnProgress: func(p stun.Progress) {
|
||||||
|
switch p.Phase {
|
||||||
|
case stun.PhaseBinding:
|
||||||
|
if p.ExternalAddr != "" {
|
||||||
|
fmt.Fprintf(os.Stderr, "\rExternal Address: %s (%d ms)", p.ExternalAddr, p.LatencyMs)
|
||||||
|
} else {
|
||||||
|
fmt.Fprint(os.Stderr, "\rSending binding request...")
|
||||||
|
}
|
||||||
|
case stun.PhaseNATMapping:
|
||||||
|
fmt.Fprint(os.Stderr, "\rDetecting NAT mapping behavior...")
|
||||||
|
case stun.PhaseNATFiltering:
|
||||||
|
fmt.Fprint(os.Stderr, "\rDetecting NAT filtering behavior...")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintln(os.Stderr)
|
||||||
|
fmt.Fprintf(os.Stderr, "External Address: %s\n", result.ExternalAddr)
|
||||||
|
fmt.Fprintf(os.Stderr, "Latency: %d ms\n", result.LatencyMs)
|
||||||
|
if result.NATTypeSupported {
|
||||||
|
fmt.Fprintf(os.Stderr, "NAT Mapping: %s\n", result.NATMapping)
|
||||||
|
fmt.Fprintf(os.Stderr, "NAT Filtering: %s\n", result.NATFiltering)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintln(os.Stderr, "NAT Type Detection: not supported by server")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
2650
common/certificate/chrome.pem
Normal file
2650
common/certificate/chrome.pem
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
4256
common/certificate/mozilla.pem
Normal file
4256
common/certificate/mozilla.pem
Normal file
File diff suppressed because it is too large
Load Diff
@@ -22,8 +22,10 @@ var _ adapter.CertificateStore = (*Store)(nil)
|
|||||||
|
|
||||||
type Store struct {
|
type Store struct {
|
||||||
access sync.RWMutex
|
access sync.RWMutex
|
||||||
|
store string
|
||||||
systemPool *x509.CertPool
|
systemPool *x509.CertPool
|
||||||
currentPool *x509.CertPool
|
currentPool *x509.CertPool
|
||||||
|
currentPEM []string
|
||||||
certificate string
|
certificate string
|
||||||
certificatePaths []string
|
certificatePaths []string
|
||||||
certificateDirectoryPaths []string
|
certificateDirectoryPaths []string
|
||||||
@@ -61,6 +63,7 @@ func NewStore(ctx context.Context, logger logger.Logger, options option.Certific
|
|||||||
return nil, E.New("unknown certificate store: ", options.Store)
|
return nil, E.New("unknown certificate store: ", options.Store)
|
||||||
}
|
}
|
||||||
store := &Store{
|
store := &Store{
|
||||||
|
store: options.Store,
|
||||||
systemPool: systemPool,
|
systemPool: systemPool,
|
||||||
certificate: strings.Join(options.Certificate, "\n"),
|
certificate: strings.Join(options.Certificate, "\n"),
|
||||||
certificatePaths: options.CertificatePath,
|
certificatePaths: options.CertificatePath,
|
||||||
@@ -123,19 +126,37 @@ func (s *Store) Pool() *x509.CertPool {
|
|||||||
return s.currentPool
|
return s.currentPool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Store) StoreKind() string {
|
||||||
|
return s.store
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) CurrentPEM() []string {
|
||||||
|
s.access.RLock()
|
||||||
|
defer s.access.RUnlock()
|
||||||
|
return append([]string(nil), s.currentPEM...)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Store) update() error {
|
func (s *Store) update() error {
|
||||||
s.access.Lock()
|
s.access.Lock()
|
||||||
defer s.access.Unlock()
|
defer s.access.Unlock()
|
||||||
var currentPool *x509.CertPool
|
var currentPool *x509.CertPool
|
||||||
|
var currentPEM []string
|
||||||
if s.systemPool == nil {
|
if s.systemPool == nil {
|
||||||
currentPool = x509.NewCertPool()
|
currentPool = x509.NewCertPool()
|
||||||
} else {
|
} else {
|
||||||
currentPool = s.systemPool.Clone()
|
currentPool = s.systemPool.Clone()
|
||||||
}
|
}
|
||||||
|
switch s.store {
|
||||||
|
case C.CertificateStoreMozilla:
|
||||||
|
currentPEM = append(currentPEM, mozillaIncludedPEM)
|
||||||
|
case C.CertificateStoreChrome:
|
||||||
|
currentPEM = append(currentPEM, chromeIncludedPEM)
|
||||||
|
}
|
||||||
if s.certificate != "" {
|
if s.certificate != "" {
|
||||||
if !currentPool.AppendCertsFromPEM([]byte(s.certificate)) {
|
if !currentPool.AppendCertsFromPEM([]byte(s.certificate)) {
|
||||||
return E.New("invalid certificate PEM strings")
|
return E.New("invalid certificate PEM strings")
|
||||||
}
|
}
|
||||||
|
currentPEM = append(currentPEM, s.certificate)
|
||||||
}
|
}
|
||||||
for _, path := range s.certificatePaths {
|
for _, path := range s.certificatePaths {
|
||||||
pemContent, err := os.ReadFile(path)
|
pemContent, err := os.ReadFile(path)
|
||||||
@@ -145,6 +166,7 @@ func (s *Store) update() error {
|
|||||||
if !currentPool.AppendCertsFromPEM(pemContent) {
|
if !currentPool.AppendCertsFromPEM(pemContent) {
|
||||||
return E.New("invalid certificate PEM file: ", path)
|
return E.New("invalid certificate PEM file: ", path)
|
||||||
}
|
}
|
||||||
|
currentPEM = append(currentPEM, string(pemContent))
|
||||||
}
|
}
|
||||||
var firstErr error
|
var firstErr error
|
||||||
for _, directoryPath := range s.certificateDirectoryPaths {
|
for _, directoryPath := range s.certificateDirectoryPaths {
|
||||||
@@ -157,8 +179,8 @@ func (s *Store) update() error {
|
|||||||
}
|
}
|
||||||
for _, directoryEntry := range directoryEntries {
|
for _, directoryEntry := range directoryEntries {
|
||||||
pemContent, err := os.ReadFile(filepath.Join(directoryPath, directoryEntry.Name()))
|
pemContent, err := os.ReadFile(filepath.Join(directoryPath, directoryEntry.Name()))
|
||||||
if err == nil {
|
if err == nil && currentPool.AppendCertsFromPEM(pemContent) {
|
||||||
currentPool.AppendCertsFromPEM(pemContent)
|
currentPEM = append(currentPEM, string(pemContent))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -166,6 +188,7 @@ func (s *Store) update() error {
|
|||||||
return firstErr
|
return firstErr
|
||||||
}
|
}
|
||||||
s.currentPool = currentPool
|
s.currentPool = currentPool
|
||||||
|
s.currentPEM = currentPEM
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -149,7 +149,10 @@ func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDial
|
|||||||
} else {
|
} else {
|
||||||
dialer.Timeout = C.TCPConnectTimeout
|
dialer.Timeout = C.TCPConnectTimeout
|
||||||
}
|
}
|
||||||
if !options.DisableTCPKeepAlive {
|
if options.DisableTCPKeepAlive {
|
||||||
|
dialer.KeepAlive = -1
|
||||||
|
dialer.KeepAliveConfig.Enable = false
|
||||||
|
} else {
|
||||||
keepIdle := time.Duration(options.TCPKeepAlive)
|
keepIdle := time.Duration(options.TCPKeepAlive)
|
||||||
if keepIdle == 0 {
|
if keepIdle == 0 {
|
||||||
keepIdle = C.TCPKeepAliveInitial
|
keepIdle = C.TCPKeepAliveInitial
|
||||||
@@ -239,7 +242,7 @@ func setMarkWrapper(networkManager adapter.NetworkManager, mark uint32, isDefaul
|
|||||||
func (d *DefaultDialer) DialContext(ctx context.Context, network string, address M.Socksaddr) (net.Conn, error) {
|
func (d *DefaultDialer) DialContext(ctx context.Context, network string, address M.Socksaddr) (net.Conn, error) {
|
||||||
if !address.IsValid() {
|
if !address.IsValid() {
|
||||||
return nil, E.New("invalid address")
|
return nil, E.New("invalid address")
|
||||||
} else if address.IsFqdn() {
|
} else if address.IsDomain() {
|
||||||
return nil, E.New("domain not resolved")
|
return nil, E.New("domain not resolved")
|
||||||
}
|
}
|
||||||
if d.networkStrategy == nil {
|
if d.networkStrategy == nil {
|
||||||
@@ -329,9 +332,9 @@ func (d *DefaultDialer) ListenPacket(ctx context.Context, destination M.Socksadd
|
|||||||
|
|
||||||
func (d *DefaultDialer) DialerForICMPDestination(destination netip.Addr) net.Dialer {
|
func (d *DefaultDialer) DialerForICMPDestination(destination netip.Addr) net.Dialer {
|
||||||
if !destination.Is6() {
|
if !destination.Is6() {
|
||||||
return d.dialer6.Dialer
|
|
||||||
} else {
|
|
||||||
return d.dialer4.Dialer
|
return d.dialer4.Dialer
|
||||||
|
} else {
|
||||||
|
return d.dialer6.Dialer
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ type DirectDialer interface {
|
|||||||
type DetourDialer struct {
|
type DetourDialer struct {
|
||||||
outboundManager adapter.OutboundManager
|
outboundManager adapter.OutboundManager
|
||||||
detour string
|
detour string
|
||||||
|
defaultOutbound bool
|
||||||
legacyDNSDialer bool
|
legacyDNSDialer bool
|
||||||
dialer N.Dialer
|
dialer N.Dialer
|
||||||
initOnce sync.Once
|
initOnce sync.Once
|
||||||
@@ -33,6 +34,13 @@ func NewDetour(outboundManager adapter.OutboundManager, detour string, legacyDNS
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewDefaultOutboundDetour(outboundManager adapter.OutboundManager) N.Dialer {
|
||||||
|
return &DetourDialer{
|
||||||
|
outboundManager: outboundManager,
|
||||||
|
defaultOutbound: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func InitializeDetour(dialer N.Dialer) error {
|
func InitializeDetour(dialer N.Dialer) error {
|
||||||
detourDialer, isDetour := common.Cast[*DetourDialer](dialer)
|
detourDialer, isDetour := common.Cast[*DetourDialer](dialer)
|
||||||
if !isDetour {
|
if !isDetour {
|
||||||
@@ -47,12 +55,18 @@ func (d *DetourDialer) Dialer() (N.Dialer, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *DetourDialer) init() {
|
func (d *DetourDialer) init() {
|
||||||
dialer, loaded := d.outboundManager.Outbound(d.detour)
|
var dialer adapter.Outbound
|
||||||
if !loaded {
|
if d.detour != "" {
|
||||||
d.initErr = E.New("outbound detour not found: ", d.detour)
|
var loaded bool
|
||||||
return
|
dialer, loaded = d.outboundManager.Outbound(d.detour)
|
||||||
|
if !loaded {
|
||||||
|
d.initErr = E.New("outbound detour not found: ", d.detour)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
dialer = d.outboundManager.Default()
|
||||||
}
|
}
|
||||||
if !d.legacyDNSDialer {
|
if !d.defaultOutbound && !d.legacyDNSDialer {
|
||||||
if directDialer, isDirect := dialer.(DirectDialer); isDirect {
|
if directDialer, isDirect := dialer.(DirectDialer); isDirect {
|
||||||
if directDialer.IsEmpty() {
|
if directDialer.IsEmpty() {
|
||||||
d.initErr = E.New("detour to an empty direct outbound makes no sense")
|
d.initErr = E.New("detour to an empty direct outbound makes no sense")
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ type Options struct {
|
|||||||
NewDialer bool
|
NewDialer bool
|
||||||
LegacyDNSDialer bool
|
LegacyDNSDialer bool
|
||||||
DirectOutbound bool
|
DirectOutbound bool
|
||||||
|
DefaultOutbound bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: merge with NewWithOptions
|
// TODO: merge with NewWithOptions
|
||||||
@@ -42,19 +43,26 @@ func NewWithOptions(options Options) (N.Dialer, error) {
|
|||||||
dialer N.Dialer
|
dialer N.Dialer
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
hasDetour := dialOptions.Detour != "" || options.DefaultOutbound
|
||||||
if dialOptions.Detour != "" {
|
if dialOptions.Detour != "" {
|
||||||
outboundManager := service.FromContext[adapter.OutboundManager](options.Context)
|
outboundManager := service.FromContext[adapter.OutboundManager](options.Context)
|
||||||
if outboundManager == nil {
|
if outboundManager == nil {
|
||||||
return nil, E.New("missing outbound manager")
|
return nil, E.New("missing outbound manager")
|
||||||
}
|
}
|
||||||
dialer = NewDetour(outboundManager, dialOptions.Detour, options.LegacyDNSDialer)
|
dialer = NewDetour(outboundManager, dialOptions.Detour, options.LegacyDNSDialer)
|
||||||
|
} else if options.DefaultOutbound {
|
||||||
|
outboundManager := service.FromContext[adapter.OutboundManager](options.Context)
|
||||||
|
if outboundManager == nil {
|
||||||
|
return nil, E.New("missing outbound manager")
|
||||||
|
}
|
||||||
|
dialer = NewDefaultOutboundDetour(outboundManager)
|
||||||
} else {
|
} else {
|
||||||
dialer, err = NewDefault(options.Context, dialOptions)
|
dialer, err = NewDefault(options.Context, dialOptions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if options.RemoteIsDomain && (dialOptions.Detour == "" || options.ResolverOnDetour || dialOptions.DomainResolver != nil && dialOptions.DomainResolver.Server != "") {
|
if options.RemoteIsDomain && (!hasDetour || options.ResolverOnDetour || dialOptions.DomainResolver != nil && dialOptions.DomainResolver.Server != "") {
|
||||||
networkManager := service.FromContext[adapter.NetworkManager](options.Context)
|
networkManager := service.FromContext[adapter.NetworkManager](options.Context)
|
||||||
dnsTransport := service.FromContext[adapter.DNSTransportManager](options.Context)
|
dnsTransport := service.FromContext[adapter.DNSTransportManager](options.Context)
|
||||||
var defaultOptions adapter.NetworkOptions
|
var defaultOptions adapter.NetworkOptions
|
||||||
@@ -87,11 +95,12 @@ func NewWithOptions(options Options) (N.Dialer, error) {
|
|||||||
}
|
}
|
||||||
server = dialOptions.DomainResolver.Server
|
server = dialOptions.DomainResolver.Server
|
||||||
dnsQueryOptions = adapter.DNSQueryOptions{
|
dnsQueryOptions = adapter.DNSQueryOptions{
|
||||||
Transport: transport,
|
Transport: transport,
|
||||||
Strategy: strategy,
|
Strategy: strategy,
|
||||||
DisableCache: dialOptions.DomainResolver.DisableCache,
|
DisableCache: dialOptions.DomainResolver.DisableCache,
|
||||||
RewriteTTL: dialOptions.DomainResolver.RewriteTTL,
|
DisableOptimisticCache: dialOptions.DomainResolver.DisableOptimisticCache,
|
||||||
ClientSubnet: dialOptions.DomainResolver.ClientSubnet.Build(netip.Prefix{}),
|
RewriteTTL: dialOptions.DomainResolver.RewriteTTL,
|
||||||
|
ClientSubnet: dialOptions.DomainResolver.ClientSubnet.Build(netip.Prefix{}),
|
||||||
}
|
}
|
||||||
resolveFallbackDelay = time.Duration(dialOptions.FallbackDelay)
|
resolveFallbackDelay = time.Duration(dialOptions.FallbackDelay)
|
||||||
} else if options.DirectResolver {
|
} else if options.DirectResolver {
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ func (d *resolveDialer) DialContext(ctx context.Context, network string, destina
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if !destination.IsFqdn() {
|
if !destination.IsDomain() {
|
||||||
return d.dialer.DialContext(ctx, network, destination)
|
return d.dialer.DialContext(ctx, network, destination)
|
||||||
}
|
}
|
||||||
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
|
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
|
||||||
@@ -116,7 +116,7 @@ func (d *resolveDialer) ListenPacket(ctx context.Context, destination M.Socksadd
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if !destination.IsFqdn() {
|
if !destination.IsDomain() {
|
||||||
return d.dialer.ListenPacket(ctx, destination)
|
return d.dialer.ListenPacket(ctx, destination)
|
||||||
}
|
}
|
||||||
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
|
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
|
||||||
@@ -144,7 +144,7 @@ func (d *resolveParallelNetworkDialer) DialParallelInterface(ctx context.Context
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if !destination.IsFqdn() {
|
if !destination.IsDomain() {
|
||||||
return d.dialer.DialContext(ctx, network, destination)
|
return d.dialer.DialContext(ctx, network, destination)
|
||||||
}
|
}
|
||||||
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
|
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
|
||||||
@@ -167,7 +167,7 @@ func (d *resolveParallelNetworkDialer) ListenSerialInterfacePacket(ctx context.C
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if !destination.IsFqdn() {
|
if !destination.IsDomain() {
|
||||||
return d.dialer.ListenPacket(ctx, destination)
|
return d.dialer.ListenPacket(ctx, destination)
|
||||||
}
|
}
|
||||||
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
|
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
|
||||||
|
|||||||
442
common/httpclient/apple_transport_darwin.go
Normal file
442
common/httpclient/apple_transport_darwin.go
Normal file
@@ -0,0 +1,442 @@
|
|||||||
|
//go:build darwin && cgo
|
||||||
|
|
||||||
|
package httpclient
|
||||||
|
|
||||||
|
/*
|
||||||
|
#cgo CFLAGS: -x objective-c -fobjc-arc
|
||||||
|
#cgo LDFLAGS: -framework Foundation -framework Security
|
||||||
|
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include "apple_transport_darwin.h"
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/common/proxybridge"
|
||||||
|
boxTLS "github.com/sagernet/sing-box/common/tls"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
"github.com/sagernet/sing/common/logger"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
const applePinnedHashSize = sha256.Size
|
||||||
|
|
||||||
|
func verifyApplePinnedPublicKeySHA256(flatHashes []byte, leafCertificate []byte) error {
|
||||||
|
if len(flatHashes)%applePinnedHashSize != 0 {
|
||||||
|
return E.New("invalid pinned public key list")
|
||||||
|
}
|
||||||
|
knownHashes := make([][]byte, 0, len(flatHashes)/applePinnedHashSize)
|
||||||
|
for offset := 0; offset < len(flatHashes); offset += applePinnedHashSize {
|
||||||
|
knownHashes = append(knownHashes, append([]byte(nil), flatHashes[offset:offset+applePinnedHashSize]...))
|
||||||
|
}
|
||||||
|
return boxTLS.VerifyPublicKeySHA256(knownHashes, [][]byte{leafCertificate})
|
||||||
|
}
|
||||||
|
|
||||||
|
//export box_apple_http_verify_public_key_sha256
|
||||||
|
func box_apple_http_verify_public_key_sha256(knownHashValues *C.uint8_t, knownHashValuesLen C.size_t, leafCert *C.uint8_t, leafCertLen C.size_t) *C.char {
|
||||||
|
flatHashes := C.GoBytes(unsafe.Pointer(knownHashValues), C.int(knownHashValuesLen))
|
||||||
|
leafCertificate := C.GoBytes(unsafe.Pointer(leafCert), C.int(leafCertLen))
|
||||||
|
err := verifyApplePinnedPublicKeySHA256(flatHashes, leafCertificate)
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return C.CString(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
type appleSessionConfig struct {
|
||||||
|
serverName string
|
||||||
|
minVersion uint16
|
||||||
|
maxVersion uint16
|
||||||
|
insecure bool
|
||||||
|
anchorPEM string
|
||||||
|
anchorOnly bool
|
||||||
|
pinnedPublicKeySHA256s []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type appleTransportShared struct {
|
||||||
|
logger logger.ContextLogger
|
||||||
|
bridge *proxybridge.Bridge
|
||||||
|
config appleSessionConfig
|
||||||
|
refs atomic.Int32
|
||||||
|
}
|
||||||
|
|
||||||
|
type appleTransport struct {
|
||||||
|
shared *appleTransportShared
|
||||||
|
access sync.Mutex
|
||||||
|
session *C.box_apple_http_session_t
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type errorTransport struct {
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAppleTransport(ctx context.Context, logger logger.ContextLogger, rawDialer N.Dialer, options option.HTTPClientOptions) (adapter.HTTPTransport, error) {
|
||||||
|
sessionConfig, err := newAppleSessionConfig(ctx, options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
bridge, err := proxybridge.New(ctx, logger, "apple http proxy", rawDialer)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
shared := &appleTransportShared{
|
||||||
|
logger: logger,
|
||||||
|
bridge: bridge,
|
||||||
|
config: sessionConfig,
|
||||||
|
}
|
||||||
|
shared.refs.Store(1)
|
||||||
|
session, err := shared.newSession()
|
||||||
|
if err != nil {
|
||||||
|
bridge.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &appleTransport{
|
||||||
|
shared: shared,
|
||||||
|
session: session,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAppleSessionConfig(ctx context.Context, options option.HTTPClientOptions) (appleSessionConfig, error) {
|
||||||
|
version := options.Version
|
||||||
|
if version == 0 {
|
||||||
|
version = 2
|
||||||
|
}
|
||||||
|
switch version {
|
||||||
|
case 2:
|
||||||
|
case 1:
|
||||||
|
return appleSessionConfig{}, E.New("HTTP/1.1 is unsupported in Apple HTTP engine")
|
||||||
|
case 3:
|
||||||
|
return appleSessionConfig{}, E.New("HTTP/3 is unsupported in Apple HTTP engine")
|
||||||
|
default:
|
||||||
|
return appleSessionConfig{}, E.New("unknown HTTP version: ", version)
|
||||||
|
}
|
||||||
|
if options.DisableVersionFallback {
|
||||||
|
return appleSessionConfig{}, E.New("disable_version_fallback is unsupported in Apple HTTP engine")
|
||||||
|
}
|
||||||
|
if options.HTTP2Options != (option.HTTP2Options{}) {
|
||||||
|
return appleSessionConfig{}, E.New("HTTP/2 options are unsupported in Apple HTTP engine")
|
||||||
|
}
|
||||||
|
if options.HTTP3Options != (option.QUICOptions{}) {
|
||||||
|
return appleSessionConfig{}, E.New("QUIC options are unsupported in Apple HTTP engine")
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsOptions := common.PtrValueOrDefault(options.TLS)
|
||||||
|
if tlsOptions.Engine != "" {
|
||||||
|
return appleSessionConfig{}, E.New("tls.engine is unsupported in Apple HTTP engine")
|
||||||
|
}
|
||||||
|
if len(tlsOptions.ALPN) > 0 {
|
||||||
|
return appleSessionConfig{}, E.New("tls.alpn is unsupported in Apple HTTP engine")
|
||||||
|
}
|
||||||
|
validated, err := boxTLS.ValidateAppleTLSOptions(ctx, tlsOptions, "Apple HTTP engine")
|
||||||
|
if err != nil {
|
||||||
|
return appleSessionConfig{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
config := appleSessionConfig{
|
||||||
|
serverName: tlsOptions.ServerName,
|
||||||
|
minVersion: validated.MinVersion,
|
||||||
|
maxVersion: validated.MaxVersion,
|
||||||
|
insecure: tlsOptions.Insecure || len(tlsOptions.CertificatePublicKeySHA256) > 0,
|
||||||
|
anchorPEM: validated.AnchorPEM,
|
||||||
|
anchorOnly: validated.AnchorOnly,
|
||||||
|
}
|
||||||
|
if len(tlsOptions.CertificatePublicKeySHA256) > 0 {
|
||||||
|
config.pinnedPublicKeySHA256s = make([]byte, 0, len(tlsOptions.CertificatePublicKeySHA256)*applePinnedHashSize)
|
||||||
|
for _, hashValue := range tlsOptions.CertificatePublicKeySHA256 {
|
||||||
|
if len(hashValue) != applePinnedHashSize {
|
||||||
|
return appleSessionConfig{}, E.New("invalid certificate_public_key_sha256 length: ", len(hashValue))
|
||||||
|
}
|
||||||
|
config.pinnedPublicKeySHA256s = append(config.pinnedPublicKeySHA256s, hashValue...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appleTransportShared) retain() {
|
||||||
|
s.refs.Add(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appleTransportShared) release() error {
|
||||||
|
if s.refs.Add(-1) == 0 {
|
||||||
|
return s.bridge.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appleTransportShared) newSession() (*C.box_apple_http_session_t, error) {
|
||||||
|
cProxyHost := C.CString("127.0.0.1")
|
||||||
|
defer C.free(unsafe.Pointer(cProxyHost))
|
||||||
|
cProxyUsername := C.CString(s.bridge.Username())
|
||||||
|
defer C.free(unsafe.Pointer(cProxyUsername))
|
||||||
|
cProxyPassword := C.CString(s.bridge.Password())
|
||||||
|
defer C.free(unsafe.Pointer(cProxyPassword))
|
||||||
|
var cAnchorPEM *C.char
|
||||||
|
if s.config.anchorPEM != "" {
|
||||||
|
cAnchorPEM = C.CString(s.config.anchorPEM)
|
||||||
|
defer C.free(unsafe.Pointer(cAnchorPEM))
|
||||||
|
}
|
||||||
|
var pinnedPointer *C.uint8_t
|
||||||
|
if len(s.config.pinnedPublicKeySHA256s) > 0 {
|
||||||
|
pinnedPointer = (*C.uint8_t)(C.CBytes(s.config.pinnedPublicKeySHA256s))
|
||||||
|
defer C.free(unsafe.Pointer(pinnedPointer))
|
||||||
|
}
|
||||||
|
cConfig := C.box_apple_http_session_config_t{
|
||||||
|
proxy_host: cProxyHost,
|
||||||
|
proxy_port: C.int(s.bridge.Port()),
|
||||||
|
proxy_username: cProxyUsername,
|
||||||
|
proxy_password: cProxyPassword,
|
||||||
|
min_tls_version: C.uint16_t(s.config.minVersion),
|
||||||
|
max_tls_version: C.uint16_t(s.config.maxVersion),
|
||||||
|
insecure: C.bool(s.config.insecure),
|
||||||
|
anchor_pem: cAnchorPEM,
|
||||||
|
anchor_pem_len: C.size_t(len(s.config.anchorPEM)),
|
||||||
|
anchor_only: C.bool(s.config.anchorOnly),
|
||||||
|
pinned_public_key_sha256: pinnedPointer,
|
||||||
|
pinned_public_key_sha256_len: C.size_t(len(s.config.pinnedPublicKeySHA256s)),
|
||||||
|
}
|
||||||
|
var cErr *C.char
|
||||||
|
session := C.box_apple_http_session_create(&cConfig, &cErr)
|
||||||
|
if session != nil {
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
return nil, appleCStringError(cErr, "create Apple HTTP session")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *appleTransport) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||||
|
if requestRequiresHTTP1(request) {
|
||||||
|
return nil, E.New("HTTP upgrade requests are unsupported in Apple HTTP engine")
|
||||||
|
}
|
||||||
|
if request.URL == nil {
|
||||||
|
return nil, E.New("missing request URL")
|
||||||
|
}
|
||||||
|
switch request.URL.Scheme {
|
||||||
|
case "http", "https":
|
||||||
|
default:
|
||||||
|
return nil, E.New("unsupported URL scheme: ", request.URL.Scheme)
|
||||||
|
}
|
||||||
|
if request.URL.Scheme == "https" && t.shared.config.serverName != "" && !strings.EqualFold(t.shared.config.serverName, request.URL.Hostname()) {
|
||||||
|
return nil, E.New("tls.server_name is unsupported in Apple HTTP engine unless it matches request host")
|
||||||
|
}
|
||||||
|
var body []byte
|
||||||
|
if request.Body != nil && request.Body != http.NoBody {
|
||||||
|
defer request.Body.Close()
|
||||||
|
var err error
|
||||||
|
body, err = io.ReadAll(request.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
headerKeys, headerValues := flattenRequestHeaders(request)
|
||||||
|
cMethod := C.CString(request.Method)
|
||||||
|
defer C.free(unsafe.Pointer(cMethod))
|
||||||
|
cURL := C.CString(request.URL.String())
|
||||||
|
defer C.free(unsafe.Pointer(cURL))
|
||||||
|
cHeaderKeys := make([]*C.char, len(headerKeys))
|
||||||
|
cHeaderValues := make([]*C.char, len(headerValues))
|
||||||
|
defer func() {
|
||||||
|
for _, ptr := range cHeaderKeys {
|
||||||
|
C.free(unsafe.Pointer(ptr))
|
||||||
|
}
|
||||||
|
for _, ptr := range cHeaderValues {
|
||||||
|
C.free(unsafe.Pointer(ptr))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
for index, value := range headerKeys {
|
||||||
|
cHeaderKeys[index] = C.CString(value)
|
||||||
|
}
|
||||||
|
for index, value := range headerValues {
|
||||||
|
cHeaderValues[index] = C.CString(value)
|
||||||
|
}
|
||||||
|
var headerKeysPointer **C.char
|
||||||
|
var headerValuesPointer **C.char
|
||||||
|
if len(cHeaderKeys) > 0 {
|
||||||
|
pointerArraySize := C.size_t(len(cHeaderKeys)) * C.size_t(unsafe.Sizeof((*C.char)(nil)))
|
||||||
|
headerKeysPointer = (**C.char)(C.malloc(pointerArraySize))
|
||||||
|
defer C.free(unsafe.Pointer(headerKeysPointer))
|
||||||
|
headerValuesPointer = (**C.char)(C.malloc(pointerArraySize))
|
||||||
|
defer C.free(unsafe.Pointer(headerValuesPointer))
|
||||||
|
copy(unsafe.Slice(headerKeysPointer, len(cHeaderKeys)), cHeaderKeys)
|
||||||
|
copy(unsafe.Slice(headerValuesPointer, len(cHeaderValues)), cHeaderValues)
|
||||||
|
}
|
||||||
|
var bodyPointer *C.uint8_t
|
||||||
|
if len(body) > 0 {
|
||||||
|
bodyPointer = (*C.uint8_t)(C.CBytes(body))
|
||||||
|
defer C.free(unsafe.Pointer(bodyPointer))
|
||||||
|
}
|
||||||
|
cRequest := C.box_apple_http_request_t{
|
||||||
|
method: cMethod,
|
||||||
|
url: cURL,
|
||||||
|
header_keys: (**C.char)(headerKeysPointer),
|
||||||
|
header_values: (**C.char)(headerValuesPointer),
|
||||||
|
header_count: C.size_t(len(cHeaderKeys)),
|
||||||
|
body: bodyPointer,
|
||||||
|
body_len: C.size_t(len(body)),
|
||||||
|
}
|
||||||
|
var cErr *C.char
|
||||||
|
var task *C.box_apple_http_task_t
|
||||||
|
t.access.Lock()
|
||||||
|
if t.session == nil {
|
||||||
|
t.access.Unlock()
|
||||||
|
return nil, net.ErrClosed
|
||||||
|
}
|
||||||
|
// Keep the session attached until NSURLSession has created the task.
|
||||||
|
task = C.box_apple_http_session_send_async(t.session, &cRequest, &cErr)
|
||||||
|
t.access.Unlock()
|
||||||
|
if task == nil {
|
||||||
|
return nil, appleCStringError(cErr, "create Apple HTTP request")
|
||||||
|
}
|
||||||
|
cancelDone := make(chan struct{})
|
||||||
|
cancelExit := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(cancelExit)
|
||||||
|
select {
|
||||||
|
case <-request.Context().Done():
|
||||||
|
C.box_apple_http_task_cancel(task)
|
||||||
|
case <-cancelDone:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
cResponse := C.box_apple_http_task_wait(task, &cErr)
|
||||||
|
close(cancelDone)
|
||||||
|
<-cancelExit
|
||||||
|
C.box_apple_http_task_close(task)
|
||||||
|
if cResponse == nil {
|
||||||
|
err := appleCStringError(cErr, "Apple HTTP request failed")
|
||||||
|
if request.Context().Err() != nil {
|
||||||
|
return nil, request.Context().Err()
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer C.box_apple_http_response_free(cResponse)
|
||||||
|
return parseAppleHTTPResponse(request, cResponse), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *appleTransport) CloseIdleConnections() {
|
||||||
|
t.access.Lock()
|
||||||
|
if t.closed {
|
||||||
|
t.access.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.access.Unlock()
|
||||||
|
newSession, err := t.shared.newSession()
|
||||||
|
if err != nil {
|
||||||
|
t.shared.logger.Error(E.Cause(err, "reset Apple HTTP session"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.access.Lock()
|
||||||
|
if t.closed {
|
||||||
|
t.access.Unlock()
|
||||||
|
C.box_apple_http_session_close(newSession)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
oldSession := t.session
|
||||||
|
t.session = newSession
|
||||||
|
t.access.Unlock()
|
||||||
|
C.box_apple_http_session_retire(oldSession)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *appleTransport) Clone() adapter.HTTPTransport {
|
||||||
|
t.shared.retain()
|
||||||
|
session, err := t.shared.newSession()
|
||||||
|
if err != nil {
|
||||||
|
_ = t.shared.release()
|
||||||
|
return &errorTransport{err: err}
|
||||||
|
}
|
||||||
|
return &appleTransport{
|
||||||
|
shared: t.shared,
|
||||||
|
session: session,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *appleTransport) Close() error {
|
||||||
|
t.access.Lock()
|
||||||
|
if t.closed {
|
||||||
|
t.access.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
t.closed = true
|
||||||
|
session := t.session
|
||||||
|
t.session = nil
|
||||||
|
t.access.Unlock()
|
||||||
|
C.box_apple_http_session_close(session)
|
||||||
|
return t.shared.release()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *errorTransport) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||||
|
return nil, t.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *errorTransport) CloseIdleConnections() {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *errorTransport) Clone() adapter.HTTPTransport {
|
||||||
|
return &errorTransport{err: t.err}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *errorTransport) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func flattenRequestHeaders(request *http.Request) ([]string, []string) {
|
||||||
|
var (
|
||||||
|
keys []string
|
||||||
|
values []string
|
||||||
|
)
|
||||||
|
for key, headerValues := range request.Header {
|
||||||
|
for _, value := range headerValues {
|
||||||
|
keys = append(keys, key)
|
||||||
|
values = append(values, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if request.Host != "" {
|
||||||
|
keys = append(keys, "Host")
|
||||||
|
values = append(values, request.Host)
|
||||||
|
}
|
||||||
|
return keys, values
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAppleHTTPResponse(request *http.Request, response *C.box_apple_http_response_t) *http.Response {
|
||||||
|
headers := make(http.Header)
|
||||||
|
headerKeys := unsafe.Slice(response.header_keys, int(response.header_count))
|
||||||
|
headerValues := unsafe.Slice(response.header_values, int(response.header_count))
|
||||||
|
for index := range headerKeys {
|
||||||
|
headers.Add(C.GoString(headerKeys[index]), C.GoString(headerValues[index]))
|
||||||
|
}
|
||||||
|
body := bytes.NewReader(C.GoBytes(unsafe.Pointer(response.body), C.int(response.body_len)))
|
||||||
|
// NSURLSession's completion-handler API does not expose the negotiated protocol;
|
||||||
|
// callers that read Response.Proto will see HTTP/1.1 even when the wire was HTTP/2.
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: int(response.status_code),
|
||||||
|
Status: fmt.Sprintf("%d %s", int(response.status_code), http.StatusText(int(response.status_code))),
|
||||||
|
Proto: "HTTP/1.1",
|
||||||
|
ProtoMajor: 1,
|
||||||
|
ProtoMinor: 1,
|
||||||
|
Header: headers,
|
||||||
|
Body: io.NopCloser(body),
|
||||||
|
ContentLength: int64(body.Len()),
|
||||||
|
Request: request,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func appleCStringError(cErr *C.char, message string) error {
|
||||||
|
if cErr == nil {
|
||||||
|
return E.New(message)
|
||||||
|
}
|
||||||
|
defer C.free(unsafe.Pointer(cErr))
|
||||||
|
return E.New(message, ": ", C.GoString(cErr))
|
||||||
|
}
|
||||||
69
common/httpclient/apple_transport_darwin.h
Normal file
69
common/httpclient/apple_transport_darwin.h
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
#include <stdbool.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
typedef struct box_apple_http_session box_apple_http_session_t;
|
||||||
|
typedef struct box_apple_http_task box_apple_http_task_t;
|
||||||
|
|
||||||
|
typedef struct box_apple_http_session_config {
|
||||||
|
const char *proxy_host;
|
||||||
|
int proxy_port;
|
||||||
|
const char *proxy_username;
|
||||||
|
const char *proxy_password;
|
||||||
|
uint16_t min_tls_version;
|
||||||
|
uint16_t max_tls_version;
|
||||||
|
bool insecure;
|
||||||
|
const char *anchor_pem;
|
||||||
|
size_t anchor_pem_len;
|
||||||
|
bool anchor_only;
|
||||||
|
const uint8_t *pinned_public_key_sha256;
|
||||||
|
size_t pinned_public_key_sha256_len;
|
||||||
|
} box_apple_http_session_config_t;
|
||||||
|
|
||||||
|
typedef struct box_apple_http_request {
|
||||||
|
const char *method;
|
||||||
|
const char *url;
|
||||||
|
const char **header_keys;
|
||||||
|
const char **header_values;
|
||||||
|
size_t header_count;
|
||||||
|
const uint8_t *body;
|
||||||
|
size_t body_len;
|
||||||
|
} box_apple_http_request_t;
|
||||||
|
|
||||||
|
typedef struct box_apple_http_response {
|
||||||
|
int status_code;
|
||||||
|
char **header_keys;
|
||||||
|
char **header_values;
|
||||||
|
size_t header_count;
|
||||||
|
uint8_t *body;
|
||||||
|
size_t body_len;
|
||||||
|
char *error;
|
||||||
|
} box_apple_http_response_t;
|
||||||
|
|
||||||
|
box_apple_http_session_t *box_apple_http_session_create(
|
||||||
|
const box_apple_http_session_config_t *config,
|
||||||
|
char **error_out
|
||||||
|
);
|
||||||
|
void box_apple_http_session_retire(box_apple_http_session_t *session);
|
||||||
|
void box_apple_http_session_close(box_apple_http_session_t *session);
|
||||||
|
|
||||||
|
box_apple_http_task_t *box_apple_http_session_send_async(
|
||||||
|
box_apple_http_session_t *session,
|
||||||
|
const box_apple_http_request_t *request,
|
||||||
|
char **error_out
|
||||||
|
);
|
||||||
|
box_apple_http_response_t *box_apple_http_task_wait(
|
||||||
|
box_apple_http_task_t *task,
|
||||||
|
char **error_out
|
||||||
|
);
|
||||||
|
void box_apple_http_task_cancel(box_apple_http_task_t *task);
|
||||||
|
void box_apple_http_task_close(box_apple_http_task_t *task);
|
||||||
|
|
||||||
|
void box_apple_http_response_free(box_apple_http_response_t *response);
|
||||||
|
|
||||||
|
char *box_apple_http_verify_public_key_sha256(
|
||||||
|
uint8_t *known_hash_values,
|
||||||
|
size_t known_hash_values_len,
|
||||||
|
uint8_t *leaf_cert,
|
||||||
|
size_t leaf_cert_len
|
||||||
|
);
|
||||||
386
common/httpclient/apple_transport_darwin.m
Normal file
386
common/httpclient/apple_transport_darwin.m
Normal file
@@ -0,0 +1,386 @@
|
|||||||
|
#import "apple_transport_darwin.h"
|
||||||
|
|
||||||
|
#import <CoreFoundation/CFStream.h>
|
||||||
|
#import <Foundation/Foundation.h>
|
||||||
|
#import <Security/Security.h>
|
||||||
|
#import <dispatch/dispatch.h>
|
||||||
|
#import <stdlib.h>
|
||||||
|
#import <string.h>
|
||||||
|
|
||||||
|
typedef struct box_apple_http_session {
|
||||||
|
void *handle;
|
||||||
|
} box_apple_http_session_t;
|
||||||
|
|
||||||
|
typedef struct box_apple_http_task {
|
||||||
|
void *task;
|
||||||
|
void *done_semaphore;
|
||||||
|
box_apple_http_response_t *response;
|
||||||
|
char *error;
|
||||||
|
} box_apple_http_task_t;
|
||||||
|
|
||||||
|
static void box_set_error_string(char **error_out, NSString *message) {
|
||||||
|
if (error_out == NULL || *error_out != NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const char *utf8 = [message UTF8String];
|
||||||
|
*error_out = strdup(utf8 != NULL ? utf8 : "unknown error");
|
||||||
|
}
|
||||||
|
|
||||||
|
static void box_set_error_from_nserror(char **error_out, NSError *error) {
|
||||||
|
if (error == nil) {
|
||||||
|
box_set_error_string(error_out, @"unknown error");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
box_set_error_string(error_out, error.localizedDescription ?: error.description);
|
||||||
|
}
|
||||||
|
|
||||||
|
static NSArray *box_parse_certificates_from_pem(const char *pem, size_t pem_len) {
|
||||||
|
if (pem == NULL || pem_len == 0) {
|
||||||
|
return @[];
|
||||||
|
}
|
||||||
|
NSString *content = [[NSString alloc] initWithBytes:pem length:pem_len encoding:NSUTF8StringEncoding];
|
||||||
|
if (content == nil) {
|
||||||
|
return @[];
|
||||||
|
}
|
||||||
|
NSString *beginMarker = @"-----BEGIN CERTIFICATE-----";
|
||||||
|
NSString *endMarker = @"-----END CERTIFICATE-----";
|
||||||
|
NSMutableArray *certificates = [NSMutableArray array];
|
||||||
|
NSUInteger searchFrom = 0;
|
||||||
|
while (searchFrom < content.length) {
|
||||||
|
NSRange beginRange = [content rangeOfString:beginMarker options:0 range:NSMakeRange(searchFrom, content.length - searchFrom)];
|
||||||
|
if (beginRange.location == NSNotFound) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
NSUInteger bodyStart = beginRange.location + beginRange.length;
|
||||||
|
NSRange endRange = [content rangeOfString:endMarker options:0 range:NSMakeRange(bodyStart, content.length - bodyStart)];
|
||||||
|
if (endRange.location == NSNotFound) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
NSString *base64Section = [content substringWithRange:NSMakeRange(bodyStart, endRange.location - bodyStart)];
|
||||||
|
NSArray<NSString *> *components = [base64Section componentsSeparatedByCharactersInSet:[NSCharacterSet whitespaceAndNewlineCharacterSet]];
|
||||||
|
NSString *base64Content = [components componentsJoinedByString:@""];
|
||||||
|
NSData *der = [[NSData alloc] initWithBase64EncodedString:base64Content options:0];
|
||||||
|
if (der != nil) {
|
||||||
|
SecCertificateRef certificate = SecCertificateCreateWithData(NULL, (__bridge CFDataRef)der);
|
||||||
|
if (certificate != NULL) {
|
||||||
|
[certificates addObject:(__bridge id)certificate];
|
||||||
|
CFRelease(certificate);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
searchFrom = endRange.location + endRange.length;
|
||||||
|
}
|
||||||
|
return certificates;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool box_evaluate_trust(SecTrustRef trustRef, NSArray *anchors, bool anchor_only) {
|
||||||
|
if (trustRef == NULL) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (anchors.count > 0 || anchor_only) {
|
||||||
|
CFMutableArrayRef anchorArray = CFArrayCreateMutable(NULL, 0, &kCFTypeArrayCallBacks);
|
||||||
|
for (id certificate in anchors) {
|
||||||
|
CFArrayAppendValue(anchorArray, (__bridge const void *)certificate);
|
||||||
|
}
|
||||||
|
SecTrustSetAnchorCertificates(trustRef, anchorArray);
|
||||||
|
SecTrustSetAnchorCertificatesOnly(trustRef, anchor_only);
|
||||||
|
CFRelease(anchorArray);
|
||||||
|
}
|
||||||
|
CFErrorRef error = NULL;
|
||||||
|
bool result = SecTrustEvaluateWithError(trustRef, &error);
|
||||||
|
if (error != NULL) {
|
||||||
|
CFRelease(error);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
static box_apple_http_response_t *box_create_response(NSHTTPURLResponse *httpResponse, NSData *data) {
|
||||||
|
box_apple_http_response_t *response = calloc(1, sizeof(box_apple_http_response_t));
|
||||||
|
response->status_code = (int)httpResponse.statusCode;
|
||||||
|
NSDictionary *headers = httpResponse.allHeaderFields;
|
||||||
|
response->header_count = headers.count;
|
||||||
|
if (response->header_count > 0) {
|
||||||
|
response->header_keys = calloc(response->header_count, sizeof(char *));
|
||||||
|
response->header_values = calloc(response->header_count, sizeof(char *));
|
||||||
|
NSUInteger index = 0;
|
||||||
|
for (id key in headers) {
|
||||||
|
NSString *keyString = [[key description] copy];
|
||||||
|
NSString *valueString = [[headers[key] description] copy];
|
||||||
|
response->header_keys[index] = strdup(keyString.UTF8String ?: "");
|
||||||
|
response->header_values[index] = strdup(valueString.UTF8String ?: "");
|
||||||
|
index++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (data.length > 0) {
|
||||||
|
response->body_len = data.length;
|
||||||
|
response->body = malloc(data.length);
|
||||||
|
memcpy(response->body, data.bytes, data.length);
|
||||||
|
}
|
||||||
|
return response;
|
||||||
|
}
|
||||||
|
|
||||||
|
@interface BoxAppleHTTPSessionDelegate : NSObject <NSURLSessionTaskDelegate, NSURLSessionDataDelegate>
|
||||||
|
@property(nonatomic, assign) BOOL insecure;
|
||||||
|
@property(nonatomic, assign) BOOL anchorOnly;
|
||||||
|
@property(nonatomic, strong) NSArray *anchors;
|
||||||
|
@property(nonatomic, strong) NSData *pinnedPublicKeyHashes;
|
||||||
|
@end
|
||||||
|
|
||||||
|
@implementation BoxAppleHTTPSessionDelegate
|
||||||
|
|
||||||
|
- (void)URLSession:(NSURLSession *)session
|
||||||
|
task:(NSURLSessionTask *)task
|
||||||
|
willPerformHTTPRedirection:(NSHTTPURLResponse *)response
|
||||||
|
newRequest:(NSURLRequest *)request
|
||||||
|
completionHandler:(void (^)(NSURLRequest * _Nullable))completionHandler {
|
||||||
|
completionHandler(nil);
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)URLSession:(NSURLSession *)session
|
||||||
|
task:(NSURLSessionTask *)task
|
||||||
|
didReceiveChallenge:(NSURLAuthenticationChallenge *)challenge
|
||||||
|
completionHandler:(void (^)(NSURLSessionAuthChallengeDisposition disposition, NSURLCredential * _Nullable credential))completionHandler {
|
||||||
|
if (![challenge.protectionSpace.authenticationMethod isEqualToString:NSURLAuthenticationMethodServerTrust]) {
|
||||||
|
completionHandler(NSURLSessionAuthChallengePerformDefaultHandling, nil);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
SecTrustRef trustRef = challenge.protectionSpace.serverTrust;
|
||||||
|
if (trustRef == NULL) {
|
||||||
|
completionHandler(NSURLSessionAuthChallengeCancelAuthenticationChallenge, nil);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
BOOL needsCustomHandling = self.insecure || self.anchorOnly || self.anchors.count > 0 || self.pinnedPublicKeyHashes.length > 0;
|
||||||
|
if (!needsCustomHandling) {
|
||||||
|
completionHandler(NSURLSessionAuthChallengePerformDefaultHandling, nil);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
BOOL ok = YES;
|
||||||
|
if (!self.insecure) {
|
||||||
|
if (self.anchorOnly || self.anchors.count > 0) {
|
||||||
|
ok = box_evaluate_trust(trustRef, self.anchors, self.anchorOnly);
|
||||||
|
} else {
|
||||||
|
CFErrorRef error = NULL;
|
||||||
|
ok = SecTrustEvaluateWithError(trustRef, &error);
|
||||||
|
if (error != NULL) {
|
||||||
|
CFRelease(error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (ok && self.pinnedPublicKeyHashes.length > 0) {
|
||||||
|
CFArrayRef certificateChain = SecTrustCopyCertificateChain(trustRef);
|
||||||
|
SecCertificateRef leafCertificate = NULL;
|
||||||
|
if (certificateChain != NULL && CFArrayGetCount(certificateChain) > 0) {
|
||||||
|
leafCertificate = (SecCertificateRef)CFArrayGetValueAtIndex(certificateChain, 0);
|
||||||
|
}
|
||||||
|
if (leafCertificate == NULL) {
|
||||||
|
ok = NO;
|
||||||
|
} else {
|
||||||
|
NSData *leafData = CFBridgingRelease(SecCertificateCopyData(leafCertificate));
|
||||||
|
char *pinError = box_apple_http_verify_public_key_sha256(
|
||||||
|
(uint8_t *)self.pinnedPublicKeyHashes.bytes,
|
||||||
|
self.pinnedPublicKeyHashes.length,
|
||||||
|
(uint8_t *)leafData.bytes,
|
||||||
|
leafData.length
|
||||||
|
);
|
||||||
|
if (pinError != NULL) {
|
||||||
|
free(pinError);
|
||||||
|
ok = NO;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (certificateChain != NULL) {
|
||||||
|
CFRelease(certificateChain);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!ok) {
|
||||||
|
completionHandler(NSURLSessionAuthChallengeCancelAuthenticationChallenge, nil);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
completionHandler(NSURLSessionAuthChallengeUseCredential, [NSURLCredential credentialForTrust:trustRef]);
|
||||||
|
}
|
||||||
|
|
||||||
|
@end
|
||||||
|
|
||||||
|
@interface BoxAppleHTTPSessionHandle : NSObject
|
||||||
|
@property(nonatomic, strong) NSURLSession *session;
|
||||||
|
@property(nonatomic, strong) BoxAppleHTTPSessionDelegate *delegate;
|
||||||
|
@end
|
||||||
|
|
||||||
|
@implementation BoxAppleHTTPSessionHandle
|
||||||
|
@end
|
||||||
|
|
||||||
|
box_apple_http_session_t *box_apple_http_session_create(
|
||||||
|
const box_apple_http_session_config_t *config,
|
||||||
|
char **error_out
|
||||||
|
) {
|
||||||
|
@autoreleasepool {
|
||||||
|
NSURLSessionConfiguration *sessionConfig = [NSURLSessionConfiguration ephemeralSessionConfiguration];
|
||||||
|
sessionConfig.URLCache = nil;
|
||||||
|
sessionConfig.HTTPCookieStorage = nil;
|
||||||
|
sessionConfig.URLCredentialStorage = nil;
|
||||||
|
sessionConfig.HTTPShouldSetCookies = NO;
|
||||||
|
if (config != NULL && config->proxy_host != NULL && config->proxy_port > 0) {
|
||||||
|
NSMutableDictionary *proxyDictionary = [NSMutableDictionary dictionary];
|
||||||
|
proxyDictionary[(__bridge NSString *)kCFStreamPropertySOCKSProxyHost] = [NSString stringWithUTF8String:config->proxy_host];
|
||||||
|
proxyDictionary[(__bridge NSString *)kCFStreamPropertySOCKSProxyPort] = @(config->proxy_port);
|
||||||
|
proxyDictionary[(__bridge NSString *)kCFStreamPropertySOCKSVersion] = (__bridge NSString *)kCFStreamSocketSOCKSVersion5;
|
||||||
|
if (config->proxy_username != NULL) {
|
||||||
|
proxyDictionary[(__bridge NSString *)kCFStreamPropertySOCKSUser] = [NSString stringWithUTF8String:config->proxy_username];
|
||||||
|
}
|
||||||
|
if (config->proxy_password != NULL) {
|
||||||
|
proxyDictionary[(__bridge NSString *)kCFStreamPropertySOCKSPassword] = [NSString stringWithUTF8String:config->proxy_password];
|
||||||
|
}
|
||||||
|
sessionConfig.connectionProxyDictionary = proxyDictionary;
|
||||||
|
}
|
||||||
|
if (config != NULL && config->min_tls_version != 0) {
|
||||||
|
sessionConfig.TLSMinimumSupportedProtocolVersion = (tls_protocol_version_t)config->min_tls_version;
|
||||||
|
}
|
||||||
|
if (config != NULL && config->max_tls_version != 0) {
|
||||||
|
sessionConfig.TLSMaximumSupportedProtocolVersion = (tls_protocol_version_t)config->max_tls_version;
|
||||||
|
}
|
||||||
|
BoxAppleHTTPSessionDelegate *delegate = [[BoxAppleHTTPSessionDelegate alloc] init];
|
||||||
|
if (config != NULL) {
|
||||||
|
delegate.insecure = config->insecure;
|
||||||
|
delegate.anchorOnly = config->anchor_only;
|
||||||
|
delegate.anchors = box_parse_certificates_from_pem(config->anchor_pem, config->anchor_pem_len);
|
||||||
|
if (config->pinned_public_key_sha256 != NULL && config->pinned_public_key_sha256_len > 0) {
|
||||||
|
delegate.pinnedPublicKeyHashes = [NSData dataWithBytes:config->pinned_public_key_sha256 length:config->pinned_public_key_sha256_len];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
NSURLSession *session = [NSURLSession sessionWithConfiguration:sessionConfig delegate:delegate delegateQueue:nil];
|
||||||
|
if (session == nil) {
|
||||||
|
box_set_error_string(error_out, @"create URLSession");
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
BoxAppleHTTPSessionHandle *handle = [[BoxAppleHTTPSessionHandle alloc] init];
|
||||||
|
handle.session = session;
|
||||||
|
handle.delegate = delegate;
|
||||||
|
box_apple_http_session_t *sessionHandle = calloc(1, sizeof(box_apple_http_session_t));
|
||||||
|
sessionHandle->handle = (__bridge_retained void *)handle;
|
||||||
|
return sessionHandle;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void box_apple_http_session_retire(box_apple_http_session_t *session) {
|
||||||
|
if (session == NULL || session->handle == NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
BoxAppleHTTPSessionHandle *handle = (__bridge_transfer BoxAppleHTTPSessionHandle *)session->handle;
|
||||||
|
[handle.session finishTasksAndInvalidate];
|
||||||
|
free(session);
|
||||||
|
}
|
||||||
|
|
||||||
|
void box_apple_http_session_close(box_apple_http_session_t *session) {
|
||||||
|
if (session == NULL || session->handle == NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
BoxAppleHTTPSessionHandle *handle = (__bridge_transfer BoxAppleHTTPSessionHandle *)session->handle;
|
||||||
|
[handle.session invalidateAndCancel];
|
||||||
|
free(session);
|
||||||
|
}
|
||||||
|
|
||||||
|
box_apple_http_task_t *box_apple_http_session_send_async(
|
||||||
|
box_apple_http_session_t *session,
|
||||||
|
const box_apple_http_request_t *request,
|
||||||
|
char **error_out
|
||||||
|
) {
|
||||||
|
@autoreleasepool {
|
||||||
|
if (session == NULL || session->handle == NULL || request == NULL || request->method == NULL || request->url == NULL) {
|
||||||
|
box_set_error_string(error_out, @"invalid apple HTTP request");
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
BoxAppleHTTPSessionHandle *handle = (__bridge BoxAppleHTTPSessionHandle *)session->handle;
|
||||||
|
NSURL *requestURL = [NSURL URLWithString:[NSString stringWithUTF8String:request->url]];
|
||||||
|
if (requestURL == nil) {
|
||||||
|
box_set_error_string(error_out, @"invalid request URL");
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
NSMutableURLRequest *urlRequest = [NSMutableURLRequest requestWithURL:requestURL];
|
||||||
|
urlRequest.HTTPMethod = [NSString stringWithUTF8String:request->method];
|
||||||
|
for (size_t index = 0; index < request->header_count; index++) {
|
||||||
|
const char *key = request->header_keys[index];
|
||||||
|
const char *value = request->header_values[index];
|
||||||
|
if (key == NULL || value == NULL) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
[urlRequest addValue:[NSString stringWithUTF8String:value] forHTTPHeaderField:[NSString stringWithUTF8String:key]];
|
||||||
|
}
|
||||||
|
if (request->body != NULL && request->body_len > 0) {
|
||||||
|
urlRequest.HTTPBody = [NSData dataWithBytes:request->body length:request->body_len];
|
||||||
|
}
|
||||||
|
box_apple_http_task_t *task = calloc(1, sizeof(box_apple_http_task_t));
|
||||||
|
dispatch_semaphore_t doneSemaphore = dispatch_semaphore_create(0);
|
||||||
|
task->done_semaphore = (__bridge_retained void *)doneSemaphore;
|
||||||
|
NSURLSessionDataTask *dataTask = [handle.session dataTaskWithRequest:urlRequest completionHandler:^(NSData *data, NSURLResponse *response, NSError *error) {
|
||||||
|
if (error != nil) {
|
||||||
|
box_set_error_from_nserror(&task->error, error);
|
||||||
|
} else if (![response isKindOfClass:[NSHTTPURLResponse class]]) {
|
||||||
|
box_set_error_string(&task->error, @"unexpected HTTP response type");
|
||||||
|
} else {
|
||||||
|
task->response = box_create_response((NSHTTPURLResponse *)response, data ?: [NSData data]);
|
||||||
|
}
|
||||||
|
dispatch_semaphore_signal((__bridge dispatch_semaphore_t)task->done_semaphore);
|
||||||
|
}];
|
||||||
|
if (dataTask == nil) {
|
||||||
|
box_set_error_string(error_out, @"create data task");
|
||||||
|
box_apple_http_task_close(task);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
task->task = (__bridge_retained void *)dataTask;
|
||||||
|
[dataTask resume];
|
||||||
|
return task;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
box_apple_http_response_t *box_apple_http_task_wait(
|
||||||
|
box_apple_http_task_t *task,
|
||||||
|
char **error_out
|
||||||
|
) {
|
||||||
|
if (task == NULL || task->done_semaphore == NULL) {
|
||||||
|
box_set_error_string(error_out, @"invalid apple HTTP task");
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
dispatch_semaphore_wait((__bridge dispatch_semaphore_t)task->done_semaphore, DISPATCH_TIME_FOREVER);
|
||||||
|
if (task->error != NULL) {
|
||||||
|
box_set_error_string(error_out, [NSString stringWithUTF8String:task->error]);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
return task->response;
|
||||||
|
}
|
||||||
|
|
||||||
|
void box_apple_http_task_cancel(box_apple_http_task_t *task) {
|
||||||
|
if (task == NULL || task->task == NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
NSURLSessionTask *nsTask = (__bridge NSURLSessionTask *)task->task;
|
||||||
|
[nsTask cancel];
|
||||||
|
}
|
||||||
|
|
||||||
|
void box_apple_http_task_close(box_apple_http_task_t *task) {
|
||||||
|
if (task == NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (task->task != NULL) {
|
||||||
|
__unused NSURLSessionTask *nsTask = (__bridge_transfer NSURLSessionTask *)task->task;
|
||||||
|
task->task = NULL;
|
||||||
|
}
|
||||||
|
if (task->done_semaphore != NULL) {
|
||||||
|
__unused dispatch_semaphore_t doneSemaphore = (__bridge_transfer dispatch_semaphore_t)task->done_semaphore;
|
||||||
|
task->done_semaphore = NULL;
|
||||||
|
}
|
||||||
|
free(task->error);
|
||||||
|
free(task);
|
||||||
|
}
|
||||||
|
|
||||||
|
void box_apple_http_response_free(box_apple_http_response_t *response) {
|
||||||
|
if (response == NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (size_t index = 0; index < response->header_count; index++) {
|
||||||
|
free(response->header_keys[index]);
|
||||||
|
free(response->header_values[index]);
|
||||||
|
}
|
||||||
|
free(response->header_keys);
|
||||||
|
free(response->header_values);
|
||||||
|
free(response->body);
|
||||||
|
free(response->error);
|
||||||
|
free(response);
|
||||||
|
}
|
||||||
876
common/httpclient/apple_transport_darwin_test.go
Normal file
876
common/httpclient/apple_transport_darwin_test.go
Normal file
@@ -0,0 +1,876 @@
|
|||||||
|
//go:build darwin && cgo
|
||||||
|
|
||||||
|
package httpclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
stdtls "crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"slices"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
boxTLS "github.com/sagernet/sing-box/common/tls"
|
||||||
|
"github.com/sagernet/sing-box/log"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
"github.com/sagernet/sing-box/route"
|
||||||
|
"github.com/sagernet/sing/common/json/badoption"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
"github.com/sagernet/sing/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
const appleHTTPTestTimeout = 5 * time.Second
|
||||||
|
|
||||||
|
const appleHTTPRecoveryLoops = 5
|
||||||
|
|
||||||
|
type appleHTTPTestDialer struct {
|
||||||
|
dialer net.Dialer
|
||||||
|
listener net.ListenConfig
|
||||||
|
hostMap map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
type appleHTTPObservedRequest struct {
|
||||||
|
method string
|
||||||
|
body string
|
||||||
|
host string
|
||||||
|
values []string
|
||||||
|
protoMajor int
|
||||||
|
}
|
||||||
|
|
||||||
|
type appleHTTPTestServer struct {
|
||||||
|
server *httptest.Server
|
||||||
|
baseURL string
|
||||||
|
dialHost string
|
||||||
|
certificate stdtls.Certificate
|
||||||
|
certificatePEM string
|
||||||
|
publicKeyHash []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewAppleSessionConfig(t *testing.T) {
|
||||||
|
serverCertificate, serverCertificatePEM := newAppleHTTPTestCertificate(t, "localhost")
|
||||||
|
serverHash := certificatePublicKeySHA256(t, serverCertificate.Certificate[0])
|
||||||
|
otherHash := bytes.Repeat([]byte{0x7f}, applePinnedHashSize)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
options option.HTTPClientOptions
|
||||||
|
check func(t *testing.T, config appleSessionConfig)
|
||||||
|
wantErr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success with certificate anchors",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
Version: 2,
|
||||||
|
DialerOptions: option.DialerOptions{
|
||||||
|
ConnectTimeout: badoption.Duration(2 * time.Second),
|
||||||
|
},
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
ServerName: "localhost",
|
||||||
|
MinVersion: "1.2",
|
||||||
|
MaxVersion: "1.3",
|
||||||
|
Certificate: badoption.Listable[string]{serverCertificatePEM},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
check: func(t *testing.T, config appleSessionConfig) {
|
||||||
|
t.Helper()
|
||||||
|
if config.serverName != "localhost" {
|
||||||
|
t.Fatalf("unexpected server name: %q", config.serverName)
|
||||||
|
}
|
||||||
|
if config.minVersion != stdtls.VersionTLS12 {
|
||||||
|
t.Fatalf("unexpected min version: %x", config.minVersion)
|
||||||
|
}
|
||||||
|
if config.maxVersion != stdtls.VersionTLS13 {
|
||||||
|
t.Fatalf("unexpected max version: %x", config.maxVersion)
|
||||||
|
}
|
||||||
|
if config.insecure {
|
||||||
|
t.Fatal("unexpected insecure flag")
|
||||||
|
}
|
||||||
|
if !config.anchorOnly {
|
||||||
|
t.Fatal("expected anchor_only")
|
||||||
|
}
|
||||||
|
if !strings.Contains(config.anchorPEM, "BEGIN CERTIFICATE") {
|
||||||
|
t.Fatalf("unexpected anchor pem: %q", config.anchorPEM)
|
||||||
|
}
|
||||||
|
if len(config.pinnedPublicKeySHA256s) != 0 {
|
||||||
|
t.Fatalf("unexpected pinned hashes: %d", len(config.pinnedPublicKeySHA256s))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "success with flattened pins",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
Version: 2,
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
Insecure: true,
|
||||||
|
CertificatePublicKeySHA256: badoption.Listable[[]byte]{serverHash, otherHash},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
check: func(t *testing.T, config appleSessionConfig) {
|
||||||
|
t.Helper()
|
||||||
|
if !config.insecure {
|
||||||
|
t.Fatal("expected insecure flag")
|
||||||
|
}
|
||||||
|
if len(config.pinnedPublicKeySHA256s) != 2*applePinnedHashSize {
|
||||||
|
t.Fatalf("unexpected flattened pin length: %d", len(config.pinnedPublicKeySHA256s))
|
||||||
|
}
|
||||||
|
if !bytes.Equal(config.pinnedPublicKeySHA256s[:applePinnedHashSize], serverHash) {
|
||||||
|
t.Fatal("unexpected first pin")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(config.pinnedPublicKeySHA256s[applePinnedHashSize:], otherHash) {
|
||||||
|
t.Fatal("unexpected second pin")
|
||||||
|
}
|
||||||
|
if config.anchorPEM != "" {
|
||||||
|
t.Fatalf("unexpected anchor pem: %q", config.anchorPEM)
|
||||||
|
}
|
||||||
|
if config.anchorOnly {
|
||||||
|
t.Fatal("unexpected anchor_only")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "http11 unsupported",
|
||||||
|
options: option.HTTPClientOptions{Version: 1},
|
||||||
|
wantErr: "HTTP/1.1 is unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "http3 unsupported",
|
||||||
|
options: option.HTTPClientOptions{Version: 3},
|
||||||
|
wantErr: "HTTP/3 is unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown version",
|
||||||
|
options: option.HTTPClientOptions{Version: 9},
|
||||||
|
wantErr: "unknown HTTP version: 9",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "disable version fallback unsupported",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
DisableVersionFallback: true,
|
||||||
|
},
|
||||||
|
wantErr: "disable_version_fallback is unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "http2 options unsupported",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
HTTP2Options: option.HTTP2Options{
|
||||||
|
IdleTimeout: badoption.Duration(time.Second),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "HTTP/2 options are unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "quic options unsupported",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
HTTP3Options: option.QUICOptions{
|
||||||
|
InitialPacketSize: 1200,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "QUIC options are unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tls engine unsupported",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{Engine: "go"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "tls.engine is unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "disable sni unsupported",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{DisableSNI: true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "disable_sni is unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "alpn unsupported",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{
|
||||||
|
ALPN: badoption.Listable[string]{"h2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "tls.alpn is unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cipher suites unsupported",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{
|
||||||
|
CipherSuites: badoption.Listable[string]{"TLS_AES_128_GCM_SHA256"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "cipher_suites is unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "curve preferences unsupported",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{
|
||||||
|
CurvePreferences: badoption.Listable[option.CurvePreference]{option.CurvePreference(option.X25519)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "curve_preferences is unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "client certificate unsupported",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{
|
||||||
|
ClientCertificate: badoption.Listable[string]{"client-certificate"},
|
||||||
|
ClientKey: badoption.Listable[string]{"client-key"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "client certificate is unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tls fragment unsupported",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{Fragment: true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "tls fragment is unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ktls unsupported",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{KernelTx: true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "ktls is unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ech unsupported",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{
|
||||||
|
ECH: &option.OutboundECHOptions{Enabled: true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "ech is unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "utls unsupported",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{
|
||||||
|
UTLS: &option.OutboundUTLSOptions{Enabled: true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "utls is unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "reality unsupported",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{
|
||||||
|
Reality: &option.OutboundRealityOptions{Enabled: true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "reality is unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "pin and certificate conflict",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{
|
||||||
|
Certificate: badoption.Listable[string]{serverCertificatePEM},
|
||||||
|
CertificatePublicKeySHA256: badoption.Listable[[]byte]{serverHash},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "certificate_public_key_sha256 is conflict with certificate or certificate_path",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid min version",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{MinVersion: "bogus"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "parse min_version",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid max version",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{MaxVersion: "bogus"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "parse max_version",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid pin length",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{
|
||||||
|
CertificatePublicKeySHA256: badoption.Listable[[]byte]{{0x01, 0x02}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "invalid certificate_public_key_sha256 length: 2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
config, err := newAppleSessionConfig(context.Background(), testCase.options)
|
||||||
|
if testCase.wantErr != "" {
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), testCase.wantErr) {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if testCase.check != nil {
|
||||||
|
testCase.check(t, config)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppleTransportVerifyPublicKeySHA256(t *testing.T) {
|
||||||
|
serverCertificate, _ := newAppleHTTPTestCertificate(t, "localhost")
|
||||||
|
goodHash := certificatePublicKeySHA256(t, serverCertificate.Certificate[0])
|
||||||
|
badHash := append([]byte(nil), goodHash...)
|
||||||
|
badHash[0] ^= 0xff
|
||||||
|
|
||||||
|
err := verifyApplePinnedPublicKeySHA256(goodHash, serverCertificate.Certificate[0])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected correct pin to succeed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = verifyApplePinnedPublicKeySHA256(badHash, serverCertificate.Certificate[0])
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected incorrect pin to fail")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "unrecognized remote public key") {
|
||||||
|
t.Fatalf("unexpected pin mismatch error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = verifyApplePinnedPublicKeySHA256(goodHash[:applePinnedHashSize-1], serverCertificate.Certificate[0])
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected malformed pin list to fail")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "invalid pinned public key list") {
|
||||||
|
t.Fatalf("unexpected malformed pin error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppleTransportRoundTripHTTPS(t *testing.T) {
|
||||||
|
requests := make(chan appleHTTPObservedRequest, 1)
|
||||||
|
server := startAppleHTTPTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
requests <- appleHTTPObservedRequest{
|
||||||
|
method: r.Method,
|
||||||
|
body: string(body),
|
||||||
|
host: r.Host,
|
||||||
|
values: append([]string(nil), r.Header.Values("X-Test")...),
|
||||||
|
protoMajor: r.ProtoMajor,
|
||||||
|
}
|
||||||
|
w.Header().Set("X-Reply", "apple")
|
||||||
|
w.WriteHeader(http.StatusCreated)
|
||||||
|
_, _ = w.Write([]byte("response body"))
|
||||||
|
})
|
||||||
|
|
||||||
|
transport := newAppleHTTPTestTransport(t, server, option.HTTPClientOptions{
|
||||||
|
Version: 2,
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: appleHTTPServerTLSOptions(server),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
request, err := http.NewRequest(http.MethodPost, server.URL("/roundtrip"), bytes.NewReader([]byte("request body")))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
request.Header.Add("X-Test", "one")
|
||||||
|
request.Header.Add("X-Test", "two")
|
||||||
|
request.Host = "custom.example"
|
||||||
|
|
||||||
|
response, err := transport.RoundTrip(request)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer response.Body.Close()
|
||||||
|
|
||||||
|
responseBody := readResponseBody(t, response)
|
||||||
|
if response.StatusCode != http.StatusCreated {
|
||||||
|
t.Fatalf("unexpected status code: %d", response.StatusCode)
|
||||||
|
}
|
||||||
|
if response.Status != "201 Created" {
|
||||||
|
t.Fatalf("unexpected status: %q", response.Status)
|
||||||
|
}
|
||||||
|
if response.Header.Get("X-Reply") != "apple" {
|
||||||
|
t.Fatalf("unexpected response header: %q", response.Header.Get("X-Reply"))
|
||||||
|
}
|
||||||
|
if responseBody != "response body" {
|
||||||
|
t.Fatalf("unexpected response body: %q", responseBody)
|
||||||
|
}
|
||||||
|
if response.ContentLength != int64(len(responseBody)) {
|
||||||
|
t.Fatalf("unexpected content length: %d", response.ContentLength)
|
||||||
|
}
|
||||||
|
|
||||||
|
observed := waitObservedRequest(t, requests)
|
||||||
|
if observed.method != http.MethodPost {
|
||||||
|
t.Fatalf("unexpected method: %q", observed.method)
|
||||||
|
}
|
||||||
|
if observed.body != "request body" {
|
||||||
|
t.Fatalf("unexpected request body: %q", observed.body)
|
||||||
|
}
|
||||||
|
if observed.host != "custom.example" {
|
||||||
|
t.Fatalf("unexpected host: %q", observed.host)
|
||||||
|
}
|
||||||
|
if observed.protoMajor != 2 {
|
||||||
|
t.Fatalf("expected HTTP/2 request, got HTTP/%d", observed.protoMajor)
|
||||||
|
}
|
||||||
|
var normalizedValues []string
|
||||||
|
for _, value := range observed.values {
|
||||||
|
for _, part := range strings.Split(value, ",") {
|
||||||
|
normalizedValues = append(normalizedValues, strings.TrimSpace(part))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
slices.Sort(normalizedValues)
|
||||||
|
if !slices.Equal(normalizedValues, []string{"one", "two"}) {
|
||||||
|
t.Fatalf("unexpected header values: %#v", observed.values)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppleTransportPinnedPublicKey(t *testing.T) {
|
||||||
|
server := startAppleHTTPTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("pinned"))
|
||||||
|
})
|
||||||
|
|
||||||
|
goodTransport := newAppleHTTPTestTransport(t, server, option.HTTPClientOptions{
|
||||||
|
Version: 2,
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
ServerName: "localhost",
|
||||||
|
Insecure: true,
|
||||||
|
CertificatePublicKeySHA256: badoption.Listable[[]byte]{server.publicKeyHash},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
response, err := goodTransport.RoundTrip(newAppleHTTPRequest(t, http.MethodGet, server.URL("/good"), nil))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected pinned request to succeed: %v", err)
|
||||||
|
}
|
||||||
|
response.Body.Close()
|
||||||
|
|
||||||
|
badHash := append([]byte(nil), server.publicKeyHash...)
|
||||||
|
badHash[0] ^= 0xff
|
||||||
|
badTransport := newAppleHTTPTestTransport(t, server, option.HTTPClientOptions{
|
||||||
|
Version: 2,
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
ServerName: "localhost",
|
||||||
|
Insecure: true,
|
||||||
|
CertificatePublicKeySHA256: badoption.Listable[[]byte]{badHash},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
response, err = badTransport.RoundTrip(newAppleHTTPRequest(t, http.MethodGet, server.URL("/bad"), nil))
|
||||||
|
if err == nil {
|
||||||
|
response.Body.Close()
|
||||||
|
t.Fatal("expected incorrect pinned public key to fail")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppleTransportGuardrails(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
options option.HTTPClientOptions
|
||||||
|
buildRequest func(t *testing.T) *http.Request
|
||||||
|
wantErrSubstr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "websocket upgrade rejected",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
Version: 2,
|
||||||
|
},
|
||||||
|
buildRequest: func(t *testing.T) *http.Request {
|
||||||
|
t.Helper()
|
||||||
|
request := newAppleHTTPRequest(t, http.MethodGet, "https://localhost/socket", nil)
|
||||||
|
request.Header.Set("Connection", "Upgrade")
|
||||||
|
request.Header.Set("Upgrade", "websocket")
|
||||||
|
return request
|
||||||
|
},
|
||||||
|
wantErrSubstr: "HTTP upgrade requests are unsupported in Apple HTTP engine",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing url rejected",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
Version: 2,
|
||||||
|
},
|
||||||
|
buildRequest: func(t *testing.T) *http.Request {
|
||||||
|
t.Helper()
|
||||||
|
return &http.Request{Method: http.MethodGet}
|
||||||
|
},
|
||||||
|
wantErrSubstr: "missing request URL",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unsupported scheme rejected",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
Version: 2,
|
||||||
|
},
|
||||||
|
buildRequest: func(t *testing.T) *http.Request {
|
||||||
|
t.Helper()
|
||||||
|
return newAppleHTTPRequest(t, http.MethodGet, "ftp://localhost/file", nil)
|
||||||
|
},
|
||||||
|
wantErrSubstr: "unsupported URL scheme: ftp",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "server name mismatch rejected",
|
||||||
|
options: option.HTTPClientOptions{
|
||||||
|
Version: 2,
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: &option.OutboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
ServerName: "example.com",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
buildRequest: func(t *testing.T) *http.Request {
|
||||||
|
t.Helper()
|
||||||
|
return newAppleHTTPRequest(t, http.MethodGet, "https://localhost/path", nil)
|
||||||
|
},
|
||||||
|
wantErrSubstr: "tls.server_name is unsupported in Apple HTTP engine unless it matches request host",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
transport := newAppleHTTPTestTransport(t, nil, testCase.options)
|
||||||
|
response, err := transport.RoundTrip(testCase.buildRequest(t))
|
||||||
|
if err == nil {
|
||||||
|
response.Body.Close()
|
||||||
|
t.Fatal("expected error")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), testCase.wantErrSubstr) {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppleTransportCancellationRecovery(t *testing.T) {
|
||||||
|
server := startAppleHTTPTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/block":
|
||||||
|
select {
|
||||||
|
case <-r.Context().Done():
|
||||||
|
return
|
||||||
|
case <-time.After(appleHTTPTestTimeout):
|
||||||
|
http.Error(w, "request was not canceled", http.StatusGatewayTimeout)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
transport := newAppleHTTPTestTransport(t, server, option.HTTPClientOptions{
|
||||||
|
Version: 2,
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: appleHTTPServerTLSOptions(server),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
for index := 0; index < appleHTTPRecoveryLoops; index++ {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
request := newAppleHTTPRequestWithContext(t, ctx, http.MethodGet, server.URL("/block"), nil)
|
||||||
|
response, err := transport.RoundTrip(request)
|
||||||
|
cancel()
|
||||||
|
if err == nil {
|
||||||
|
response.Body.Close()
|
||||||
|
t.Fatalf("iteration %d: expected cancellation error", index)
|
||||||
|
}
|
||||||
|
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatalf("iteration %d: unexpected cancellation error: %v", index, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err = transport.RoundTrip(newAppleHTTPRequest(t, http.MethodGet, server.URL("/ok"), nil))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("iteration %d: follow-up request failed: %v", index, err)
|
||||||
|
}
|
||||||
|
if body := readResponseBody(t, response); body != "ok" {
|
||||||
|
response.Body.Close()
|
||||||
|
t.Fatalf("iteration %d: unexpected follow-up body: %q", index, body)
|
||||||
|
}
|
||||||
|
response.Body.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppleTransportLifecycle(t *testing.T) {
|
||||||
|
server := startAppleHTTPTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
})
|
||||||
|
|
||||||
|
transport := newAppleHTTPTestTransport(t, server, option.HTTPClientOptions{
|
||||||
|
Version: 2,
|
||||||
|
OutboundTLSOptionsContainer: option.OutboundTLSOptionsContainer{
|
||||||
|
TLS: appleHTTPServerTLSOptions(server),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
clone := transport.Clone()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = clone.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
assertAppleHTTPSucceeds(t, transport, server.URL("/original"))
|
||||||
|
assertAppleHTTPSucceeds(t, clone, server.URL("/clone"))
|
||||||
|
|
||||||
|
transport.CloseIdleConnections()
|
||||||
|
assertAppleHTTPSucceeds(t, transport, server.URL("/reset"))
|
||||||
|
|
||||||
|
if err := transport.Close(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := transport.RoundTrip(newAppleHTTPRequest(t, http.MethodGet, server.URL("/closed"), nil))
|
||||||
|
if err == nil {
|
||||||
|
response.Body.Close()
|
||||||
|
t.Fatal("expected closed transport to fail")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, net.ErrClosed) {
|
||||||
|
t.Fatalf("unexpected closed transport error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assertAppleHTTPSucceeds(t, clone, server.URL("/clone-after-original-close"))
|
||||||
|
|
||||||
|
clone.CloseIdleConnections()
|
||||||
|
assertAppleHTTPSucceeds(t, clone, server.URL("/clone-reset"))
|
||||||
|
|
||||||
|
if err := clone.Close(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
response, err = clone.RoundTrip(newAppleHTTPRequest(t, http.MethodGet, server.URL("/clone-closed"), nil))
|
||||||
|
if err == nil {
|
||||||
|
response.Body.Close()
|
||||||
|
t.Fatal("expected closed clone to fail")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, net.ErrClosed) {
|
||||||
|
t.Fatalf("unexpected closed clone error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func startAppleHTTPTestServer(t *testing.T, handler http.HandlerFunc) *appleHTTPTestServer {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
serverCertificate, serverCertificatePEM := newAppleHTTPTestCertificate(t, "localhost")
|
||||||
|
server := httptest.NewUnstartedServer(handler)
|
||||||
|
server.EnableHTTP2 = true
|
||||||
|
server.TLS = &stdtls.Config{
|
||||||
|
Certificates: []stdtls.Certificate{serverCertificate},
|
||||||
|
MinVersion: stdtls.VersionTLS12,
|
||||||
|
}
|
||||||
|
server.StartTLS()
|
||||||
|
t.Cleanup(server.Close)
|
||||||
|
|
||||||
|
parsedURL, err := url.Parse(server.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
baseURL := *parsedURL
|
||||||
|
baseURL.Host = net.JoinHostPort("localhost", parsedURL.Port())
|
||||||
|
|
||||||
|
return &appleHTTPTestServer{
|
||||||
|
server: server,
|
||||||
|
baseURL: baseURL.String(),
|
||||||
|
dialHost: parsedURL.Hostname(),
|
||||||
|
certificate: serverCertificate,
|
||||||
|
certificatePEM: serverCertificatePEM,
|
||||||
|
publicKeyHash: certificatePublicKeySHA256(t, serverCertificate.Certificate[0]),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appleHTTPTestServer) URL(path string) string {
|
||||||
|
if path == "" {
|
||||||
|
return s.baseURL
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(path, "/") {
|
||||||
|
return s.baseURL + path
|
||||||
|
}
|
||||||
|
return s.baseURL + "/" + path
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAppleHTTPTestTransport(t *testing.T, server *appleHTTPTestServer, options option.HTTPClientOptions) adapter.HTTPTransport {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
ctx := service.ContextWith[adapter.ConnectionManager](
|
||||||
|
context.Background(),
|
||||||
|
route.NewConnectionManager(log.NewNOPFactory().NewLogger("connection")),
|
||||||
|
)
|
||||||
|
dialer := &appleHTTPTestDialer{
|
||||||
|
hostMap: make(map[string]string),
|
||||||
|
}
|
||||||
|
if server != nil {
|
||||||
|
dialer.hostMap["localhost"] = server.dialHost
|
||||||
|
}
|
||||||
|
|
||||||
|
transport, err := newAppleTransport(ctx, log.NewNOPFactory().NewLogger("httpclient"), dialer, options)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = transport.Close()
|
||||||
|
})
|
||||||
|
return transport
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *appleHTTPTestDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||||
|
host := destination.AddrString()
|
||||||
|
if destination.IsDomain() {
|
||||||
|
host = destination.Fqdn
|
||||||
|
if mappedHost, loaded := d.hostMap[host]; loaded {
|
||||||
|
host = mappedHost
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return d.dialer.DialContext(ctx, network, net.JoinHostPort(host, strconv.Itoa(int(destination.Port))))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *appleHTTPTestDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||||
|
host := destination.AddrString()
|
||||||
|
if destination.IsDomain() {
|
||||||
|
host = destination.Fqdn
|
||||||
|
if mappedHost, loaded := d.hostMap[host]; loaded {
|
||||||
|
host = mappedHost
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if host == "" {
|
||||||
|
host = "127.0.0.1"
|
||||||
|
}
|
||||||
|
return d.listener.ListenPacket(ctx, N.NetworkUDP, net.JoinHostPort(host, strconv.Itoa(int(destination.Port))))
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAppleHTTPTestCertificate(t *testing.T, serverName string) (stdtls.Certificate, string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
privateKeyPEM, certificatePEM, err := boxTLS.GenerateCertificate(nil, nil, time.Now, serverName, time.Now().Add(time.Hour))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
certificate, err := stdtls.X509KeyPair(certificatePEM, privateKeyPEM)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return certificate, string(certificatePEM)
|
||||||
|
}
|
||||||
|
|
||||||
|
func certificatePublicKeySHA256(t *testing.T, certificateDER []byte) []byte {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
certificate, err := x509.ParseCertificate(certificateDER)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
publicKeyDER, err := x509.MarshalPKIXPublicKey(certificate.PublicKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
hashValue := sha256.Sum256(publicKeyDER)
|
||||||
|
return append([]byte(nil), hashValue[:]...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func appleHTTPServerTLSOptions(server *appleHTTPTestServer) *option.OutboundTLSOptions {
|
||||||
|
return &option.OutboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
ServerName: "localhost",
|
||||||
|
Certificate: badoption.Listable[string]{server.certificatePEM},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAppleHTTPRequest(t *testing.T, method string, rawURL string, body []byte) *http.Request {
|
||||||
|
t.Helper()
|
||||||
|
return newAppleHTTPRequestWithContext(t, context.Background(), method, rawURL, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAppleHTTPRequestWithContext(t *testing.T, ctx context.Context, method string, rawURL string, body []byte) *http.Request {
|
||||||
|
t.Helper()
|
||||||
|
request, err := http.NewRequestWithContext(ctx, method, rawURL, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return request
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitObservedRequest(t *testing.T, requests <-chan appleHTTPObservedRequest) appleHTTPObservedRequest {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case request := <-requests:
|
||||||
|
return request
|
||||||
|
case <-time.After(appleHTTPTestTimeout):
|
||||||
|
t.Fatal("timed out waiting for observed request")
|
||||||
|
return appleHTTPObservedRequest{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func readResponseBody(t *testing.T, response *http.Response) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(response.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return string(body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertAppleHTTPSucceeds(t *testing.T, transport adapter.HTTPTransport, rawURL string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
response, err := transport.RoundTrip(newAppleHTTPRequest(t, http.MethodGet, rawURL, nil))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer response.Body.Close()
|
||||||
|
if body := readResponseBody(t, response); body != "ok" {
|
||||||
|
t.Fatalf("unexpected response body: %q", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
17
common/httpclient/apple_transport_stub.go
Normal file
17
common/httpclient/apple_transport_stub.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
//go:build !darwin || !cgo
|
||||||
|
|
||||||
|
package httpclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
"github.com/sagernet/sing/common/logger"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newAppleTransport(ctx context.Context, logger logger.ContextLogger, rawDialer N.Dialer, options option.HTTPClientOptions) (adapter.HTTPTransport, error) {
|
||||||
|
return nil, E.New("Apple HTTP engine is not available on non-Apple platforms")
|
||||||
|
}
|
||||||
182
common/httpclient/client.go
Normal file
182
common/httpclient/client.go
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
package httpclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/common/dialer"
|
||||||
|
"github.com/sagernet/sing-box/common/tls"
|
||||||
|
C "github.com/sagernet/sing-box/constant"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
"github.com/sagernet/sing/common/logger"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Transport struct {
|
||||||
|
transport adapter.HTTPTransport
|
||||||
|
dialer N.Dialer
|
||||||
|
headers http.Header
|
||||||
|
host string
|
||||||
|
tag string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTransport(ctx context.Context, logger logger.ContextLogger, tag string, options option.HTTPClientOptions) (*Transport, error) {
|
||||||
|
rawDialer, err := dialer.NewWithOptions(dialer.Options{
|
||||||
|
Context: ctx,
|
||||||
|
Options: options.DialerOptions,
|
||||||
|
RemoteIsDomain: true,
|
||||||
|
DirectResolver: options.DirectResolver,
|
||||||
|
ResolverOnDetour: options.ResolveOnDetour,
|
||||||
|
NewDialer: options.ResolveOnDetour,
|
||||||
|
DefaultOutbound: options.DefaultOutbound,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
switch options.Engine {
|
||||||
|
case C.TLSEngineApple:
|
||||||
|
transport, transportErr := newAppleTransport(ctx, logger, rawDialer, options)
|
||||||
|
if transportErr != nil {
|
||||||
|
return nil, transportErr
|
||||||
|
}
|
||||||
|
headers := options.Headers.Build()
|
||||||
|
host := headers.Get("Host")
|
||||||
|
headers.Del("Host")
|
||||||
|
return &Transport{
|
||||||
|
transport: transport,
|
||||||
|
dialer: rawDialer,
|
||||||
|
headers: headers,
|
||||||
|
host: host,
|
||||||
|
tag: tag,
|
||||||
|
}, nil
|
||||||
|
case C.TLSEngineDefault, "go":
|
||||||
|
default:
|
||||||
|
return nil, E.New("unknown HTTP engine: ", options.Engine)
|
||||||
|
}
|
||||||
|
tlsOptions := common.PtrValueOrDefault(options.TLS)
|
||||||
|
tlsOptions.Enabled = true
|
||||||
|
baseTLSConfig, err := tls.NewClientWithOptions(tls.ClientOptions{
|
||||||
|
Context: ctx,
|
||||||
|
Logger: logger,
|
||||||
|
Options: tlsOptions,
|
||||||
|
AllowEmptyServerName: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return NewTransportWithDialer(rawDialer, baseTLSConfig, tag, options)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTransportWithDialer(rawDialer N.Dialer, baseTLSConfig tls.Config, tag string, options option.HTTPClientOptions) (*Transport, error) {
|
||||||
|
transport, err := newTransport(rawDialer, baseTLSConfig, options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
headers := options.Headers.Build()
|
||||||
|
host := headers.Get("Host")
|
||||||
|
headers.Del("Host")
|
||||||
|
return &Transport{
|
||||||
|
transport: transport,
|
||||||
|
dialer: rawDialer,
|
||||||
|
headers: headers,
|
||||||
|
host: host,
|
||||||
|
tag: tag,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTransport(rawDialer N.Dialer, baseTLSConfig tls.Config, options option.HTTPClientOptions) (adapter.HTTPTransport, error) {
|
||||||
|
version := options.Version
|
||||||
|
if version == 0 {
|
||||||
|
version = 2
|
||||||
|
}
|
||||||
|
fallbackDelay := time.Duration(options.DialerOptions.FallbackDelay)
|
||||||
|
if fallbackDelay == 0 {
|
||||||
|
fallbackDelay = 300 * time.Millisecond
|
||||||
|
}
|
||||||
|
var transport adapter.HTTPTransport
|
||||||
|
var err error
|
||||||
|
switch version {
|
||||||
|
case 1:
|
||||||
|
transport = newHTTP1Transport(rawDialer, baseTLSConfig)
|
||||||
|
case 2:
|
||||||
|
if options.DisableVersionFallback {
|
||||||
|
transport, err = newHTTP2Transport(rawDialer, baseTLSConfig, options.HTTP2Options)
|
||||||
|
} else {
|
||||||
|
transport, err = newHTTP2FallbackTransport(rawDialer, baseTLSConfig, options.HTTP2Options)
|
||||||
|
}
|
||||||
|
case 3:
|
||||||
|
if baseTLSConfig != nil {
|
||||||
|
_, err = baseTLSConfig.STDConfig()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if options.DisableVersionFallback {
|
||||||
|
transport, err = newHTTP3Transport(rawDialer, baseTLSConfig, options.HTTP3Options)
|
||||||
|
} else {
|
||||||
|
var h2Fallback adapter.HTTPTransport
|
||||||
|
h2Fallback, err = newHTTP2FallbackTransport(rawDialer, baseTLSConfig, options.HTTP2Options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
transport, err = newHTTP3FallbackTransport(rawDialer, baseTLSConfig, h2Fallback, options.HTTP3Options, fallbackDelay)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, E.New("unknown HTTP version: ", version)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return transport, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Transport) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||||
|
if c.tag == "" && len(c.headers) == 0 && c.host == "" {
|
||||||
|
return c.transport.RoundTrip(request)
|
||||||
|
}
|
||||||
|
if c.tag != "" {
|
||||||
|
if transportTag, loaded := transportTagFromContext(request.Context()); loaded && transportTag == c.tag {
|
||||||
|
return nil, E.New("HTTP request loopback in transport[", c.tag, "]")
|
||||||
|
}
|
||||||
|
request = request.Clone(contextWithTransportTag(request.Context(), c.tag))
|
||||||
|
} else {
|
||||||
|
request = request.Clone(request.Context())
|
||||||
|
}
|
||||||
|
applyHeaders(request, c.headers, c.host)
|
||||||
|
return c.transport.RoundTrip(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Transport) CloseIdleConnections() {
|
||||||
|
c.transport.CloseIdleConnections()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Transport) Clone() adapter.HTTPTransport {
|
||||||
|
return &Transport{
|
||||||
|
transport: c.transport.Clone(),
|
||||||
|
dialer: c.dialer,
|
||||||
|
headers: c.headers.Clone(),
|
||||||
|
host: c.host,
|
||||||
|
tag: c.tag,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Transport) Close() error {
|
||||||
|
return c.transport.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitializeDetour eagerly resolves the detour dialer backing transport so that
|
||||||
|
// detour misconfigurations surface at startup instead of on the first request.
|
||||||
|
func InitializeDetour(transport adapter.HTTPTransport) error {
|
||||||
|
if shared, isShared := transport.(*sharedTransport); isShared {
|
||||||
|
transport = shared.HTTPTransport
|
||||||
|
}
|
||||||
|
inner, isInner := transport.(*Transport)
|
||||||
|
if !isInner {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return dialer.InitializeDetour(inner.dialer)
|
||||||
|
}
|
||||||
14
common/httpclient/context.go
Normal file
14
common/httpclient/context.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
package httpclient
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
type transportKey struct{}
|
||||||
|
|
||||||
|
func contextWithTransportTag(ctx context.Context, transportTag string) context.Context {
|
||||||
|
return context.WithValue(ctx, transportKey{}, transportTag)
|
||||||
|
}
|
||||||
|
|
||||||
|
func transportTagFromContext(ctx context.Context) (string, bool) {
|
||||||
|
value, loaded := ctx.Value(transportKey{}).(string)
|
||||||
|
return value, loaded
|
||||||
|
}
|
||||||
86
common/httpclient/helpers.go
Normal file
86
common/httpclient/helpers.go
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
package httpclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
stdTLS "crypto/tls"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/common/tls"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
func dialTLS(ctx context.Context, rawDialer N.Dialer, baseTLSConfig tls.Config, destination M.Socksaddr, nextProtos []string, expectProto string) (net.Conn, error) {
|
||||||
|
if baseTLSConfig == nil {
|
||||||
|
return nil, E.New("TLS transport unavailable")
|
||||||
|
}
|
||||||
|
tlsConfig := baseTLSConfig.Clone()
|
||||||
|
if tlsConfig.ServerName() == "" && destination.IsValid() {
|
||||||
|
tlsConfig.SetServerName(destination.AddrString())
|
||||||
|
}
|
||||||
|
tlsConfig.SetNextProtos(nextProtos)
|
||||||
|
conn, err := rawDialer.DialContext(ctx, N.NetworkTCP, destination)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tlsConn, err := tls.ClientHandshake(ctx, conn, tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if expectProto != "" && tlsConn.ConnectionState().NegotiatedProtocol != expectProto {
|
||||||
|
tlsConn.Close()
|
||||||
|
return nil, errHTTP2Fallback
|
||||||
|
}
|
||||||
|
return tlsConn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyHeaders(request *http.Request, headers http.Header, host string) {
|
||||||
|
for header, values := range headers {
|
||||||
|
request.Header[header] = append([]string(nil), values...)
|
||||||
|
}
|
||||||
|
if host != "" {
|
||||||
|
request.Host = host
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestRequiresHTTP1(request *http.Request) bool {
|
||||||
|
return strings.Contains(strings.ToLower(request.Header.Get("Connection")), "upgrade") &&
|
||||||
|
strings.EqualFold(request.Header.Get("Upgrade"), "websocket")
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestReplayable(request *http.Request) bool {
|
||||||
|
return request.Body == nil || request.Body == http.NoBody || request.GetBody != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneRequestForRetry(request *http.Request) *http.Request {
|
||||||
|
cloned := request.Clone(request.Context())
|
||||||
|
if request.Body != nil && request.Body != http.NoBody && request.GetBody != nil {
|
||||||
|
cloned.Body = mustGetBody(request)
|
||||||
|
}
|
||||||
|
return cloned
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustGetBody(request *http.Request) io.ReadCloser {
|
||||||
|
body, err := request.GetBody()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildSTDTLSConfig(baseTLSConfig tls.Config, destination M.Socksaddr, nextProtos []string) (*stdTLS.Config, error) {
|
||||||
|
if baseTLSConfig == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
tlsConfig := baseTLSConfig.Clone()
|
||||||
|
if tlsConfig.ServerName() == "" && destination.IsValid() {
|
||||||
|
tlsConfig.SetServerName(destination.AddrString())
|
||||||
|
}
|
||||||
|
tlsConfig.SetNextProtos(nextProtos)
|
||||||
|
return tlsConfig.STDConfig()
|
||||||
|
}
|
||||||
47
common/httpclient/http1_transport.go
Normal file
47
common/httpclient/http1_transport.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package httpclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/common/tls"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
type http1Transport struct {
|
||||||
|
transport *http.Transport
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTP1Transport(rawDialer N.Dialer, baseTLSConfig tls.Config) *http1Transport {
|
||||||
|
transport := &http.Transport{
|
||||||
|
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
return rawDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if baseTLSConfig != nil {
|
||||||
|
transport.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
return dialTLS(ctx, rawDialer, baseTLSConfig, M.ParseSocksaddr(addr), []string{"http/1.1"}, "")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &http1Transport{transport: transport}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http1Transport) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||||
|
return t.transport.RoundTrip(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http1Transport) CloseIdleConnections() {
|
||||||
|
t.transport.CloseIdleConnections()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http1Transport) Clone() adapter.HTTPTransport {
|
||||||
|
return &http1Transport{transport: t.transport.Clone()}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http1Transport) Close() error {
|
||||||
|
t.CloseIdleConnections()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
42
common/httpclient/http2_config.go
Normal file
42
common/httpclient/http2_config.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package httpclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
stdTLS "crypto/tls"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CloneHTTP2Transport(transport *http2.Transport) *http2.Transport {
|
||||||
|
return &http2.Transport{
|
||||||
|
ReadIdleTimeout: transport.ReadIdleTimeout,
|
||||||
|
PingTimeout: transport.PingTimeout,
|
||||||
|
DialTLSContext: transport.DialTLSContext,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConfigureHTTP2Transport(options option.HTTP2Options) (*http2.Transport, error) {
|
||||||
|
stdTransport := &http.Transport{
|
||||||
|
TLSClientConfig: &stdTLS.Config{},
|
||||||
|
HTTP2: &http.HTTP2Config{
|
||||||
|
MaxReceiveBufferPerStream: int(options.StreamReceiveWindow.Value()),
|
||||||
|
MaxReceiveBufferPerConnection: int(options.ConnectionReceiveWindow.Value()),
|
||||||
|
MaxConcurrentStreams: options.MaxConcurrentStreams,
|
||||||
|
SendPingTimeout: time.Duration(options.KeepAlivePeriod),
|
||||||
|
PingTimeout: time.Duration(options.IdleTimeout),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h2Transport, err := http2.ConfigureTransports(stdTransport)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "configure HTTP/2 transport")
|
||||||
|
}
|
||||||
|
// ConfigureTransports binds ConnPool to the throwaway http.Transport; sever it so DialTLSContext is used directly.
|
||||||
|
h2Transport.ConnPool = nil
|
||||||
|
h2Transport.ReadIdleTimeout = time.Duration(options.KeepAlivePeriod)
|
||||||
|
h2Transport.PingTimeout = time.Duration(options.IdleTimeout)
|
||||||
|
return h2Transport, nil
|
||||||
|
}
|
||||||
93
common/httpclient/http2_fallback_transport.go
Normal file
93
common/httpclient/http2_fallback_transport.go
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
package httpclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
stdTLS "crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/common/tls"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errHTTP2Fallback = E.New("fallback to HTTP/1.1")
|
||||||
|
|
||||||
|
type http2FallbackTransport struct {
|
||||||
|
h2Transport *http2.Transport
|
||||||
|
h1Transport *http1Transport
|
||||||
|
h2Fallback *atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTP2FallbackTransport(rawDialer N.Dialer, baseTLSConfig tls.Config, options option.HTTP2Options) (*http2FallbackTransport, error) {
|
||||||
|
h1 := newHTTP1Transport(rawDialer, baseTLSConfig)
|
||||||
|
var fallback atomic.Bool
|
||||||
|
h2Transport, err := ConfigureHTTP2Transport(options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
h2Transport.DialTLSContext = func(ctx context.Context, network, addr string, _ *stdTLS.Config) (net.Conn, error) {
|
||||||
|
conn, dialErr := dialTLS(ctx, rawDialer, baseTLSConfig, M.ParseSocksaddr(addr), []string{http2.NextProtoTLS, "http/1.1"}, http2.NextProtoTLS)
|
||||||
|
if dialErr != nil {
|
||||||
|
if errors.Is(dialErr, errHTTP2Fallback) {
|
||||||
|
fallback.Store(true)
|
||||||
|
}
|
||||||
|
return nil, dialErr
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
return &http2FallbackTransport{
|
||||||
|
h2Transport: h2Transport,
|
||||||
|
h1Transport: h1,
|
||||||
|
h2Fallback: &fallback,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http2FallbackTransport) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||||
|
return t.roundTrip(request, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http2FallbackTransport) roundTrip(request *http.Request, allowHTTP1Fallback bool) (*http.Response, error) {
|
||||||
|
if request.URL.Scheme != "https" || requestRequiresHTTP1(request) {
|
||||||
|
return t.h1Transport.RoundTrip(request)
|
||||||
|
}
|
||||||
|
if t.h2Fallback.Load() {
|
||||||
|
if !allowHTTP1Fallback {
|
||||||
|
return nil, errHTTP2Fallback
|
||||||
|
}
|
||||||
|
return t.h1Transport.RoundTrip(request)
|
||||||
|
}
|
||||||
|
response, err := t.h2Transport.RoundTrip(request)
|
||||||
|
if err == nil {
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
if !errors.Is(err, errHTTP2Fallback) || !allowHTTP1Fallback {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return t.h1Transport.RoundTrip(cloneRequestForRetry(request))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http2FallbackTransport) CloseIdleConnections() {
|
||||||
|
t.h1Transport.CloseIdleConnections()
|
||||||
|
t.h2Transport.CloseIdleConnections()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http2FallbackTransport) Clone() adapter.HTTPTransport {
|
||||||
|
return &http2FallbackTransport{
|
||||||
|
h2Transport: CloneHTTP2Transport(t.h2Transport),
|
||||||
|
h1Transport: t.h1Transport.Clone().(*http1Transport),
|
||||||
|
h2Fallback: t.h2Fallback,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http2FallbackTransport) Close() error {
|
||||||
|
t.CloseIdleConnections()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
60
common/httpclient/http2_transport.go
Normal file
60
common/httpclient/http2_transport.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
package httpclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
stdTLS "crypto/tls"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/common/tls"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type http2Transport struct {
|
||||||
|
h2Transport *http2.Transport
|
||||||
|
h1Transport *http1Transport
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTP2Transport(rawDialer N.Dialer, baseTLSConfig tls.Config, options option.HTTP2Options) (*http2Transport, error) {
|
||||||
|
h1 := newHTTP1Transport(rawDialer, baseTLSConfig)
|
||||||
|
h2Transport, err := ConfigureHTTP2Transport(options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
h2Transport.DialTLSContext = func(ctx context.Context, network, addr string, _ *stdTLS.Config) (net.Conn, error) {
|
||||||
|
return dialTLS(ctx, rawDialer, baseTLSConfig, M.ParseSocksaddr(addr), []string{http2.NextProtoTLS}, http2.NextProtoTLS)
|
||||||
|
}
|
||||||
|
return &http2Transport{
|
||||||
|
h2Transport: h2Transport,
|
||||||
|
h1Transport: h1,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http2Transport) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||||
|
if request.URL.Scheme != "https" || requestRequiresHTTP1(request) {
|
||||||
|
return t.h1Transport.RoundTrip(request)
|
||||||
|
}
|
||||||
|
return t.h2Transport.RoundTrip(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http2Transport) CloseIdleConnections() {
|
||||||
|
t.h1Transport.CloseIdleConnections()
|
||||||
|
t.h2Transport.CloseIdleConnections()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http2Transport) Clone() adapter.HTTPTransport {
|
||||||
|
return &http2Transport{
|
||||||
|
h2Transport: CloneHTTP2Transport(t.h2Transport),
|
||||||
|
h1Transport: t.h1Transport.Clone().(*http1Transport),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http2Transport) Close() error {
|
||||||
|
t.CloseIdleConnections()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
312
common/httpclient/http3_transport.go
Normal file
312
common/httpclient/http3_transport.go
Normal file
@@ -0,0 +1,312 @@
|
|||||||
|
//go:build with_quic
|
||||||
|
|
||||||
|
package httpclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
stdTLS "crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/quic-go"
|
||||||
|
"github.com/sagernet/quic-go/http3"
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/common/tls"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
"github.com/sagernet/sing/common/bufio"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
type http3Transport struct {
|
||||||
|
h3Transport *http3.Transport
|
||||||
|
}
|
||||||
|
|
||||||
|
type http3FallbackTransport struct {
|
||||||
|
h3Transport *http3.Transport
|
||||||
|
h2Fallback adapter.HTTPTransport
|
||||||
|
fallbackDelay time.Duration
|
||||||
|
brokenAccess sync.Mutex
|
||||||
|
brokenUntil time.Time
|
||||||
|
brokenBackoff time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTP3RoundTripper(
|
||||||
|
rawDialer N.Dialer,
|
||||||
|
baseTLSConfig tls.Config,
|
||||||
|
options option.QUICOptions,
|
||||||
|
) *http3.Transport {
|
||||||
|
var handshakeTimeout time.Duration
|
||||||
|
if baseTLSConfig != nil {
|
||||||
|
handshakeTimeout = baseTLSConfig.HandshakeTimeout()
|
||||||
|
}
|
||||||
|
quicConfig := &quic.Config{
|
||||||
|
InitialStreamReceiveWindow: options.StreamReceiveWindow.Value(),
|
||||||
|
MaxStreamReceiveWindow: options.StreamReceiveWindow.Value(),
|
||||||
|
InitialConnectionReceiveWindow: options.ConnectionReceiveWindow.Value(),
|
||||||
|
MaxConnectionReceiveWindow: options.ConnectionReceiveWindow.Value(),
|
||||||
|
KeepAlivePeriod: time.Duration(options.KeepAlivePeriod),
|
||||||
|
MaxIdleTimeout: time.Duration(options.IdleTimeout),
|
||||||
|
DisablePathMTUDiscovery: options.DisablePathMTUDiscovery,
|
||||||
|
}
|
||||||
|
if options.InitialPacketSize > 0 {
|
||||||
|
quicConfig.InitialPacketSize = uint16(options.InitialPacketSize)
|
||||||
|
}
|
||||||
|
if options.MaxConcurrentStreams > 0 {
|
||||||
|
quicConfig.MaxIncomingStreams = int64(options.MaxConcurrentStreams)
|
||||||
|
}
|
||||||
|
if handshakeTimeout > 0 {
|
||||||
|
quicConfig.HandshakeIdleTimeout = handshakeTimeout
|
||||||
|
}
|
||||||
|
h3Transport := &http3.Transport{
|
||||||
|
TLSClientConfig: &stdTLS.Config{},
|
||||||
|
QUICConfig: quicConfig,
|
||||||
|
Dial: func(ctx context.Context, addr string, tlsConfig *stdTLS.Config, quicConfig *quic.Config) (*quic.Conn, error) {
|
||||||
|
if handshakeTimeout > 0 && quicConfig.HandshakeIdleTimeout == 0 {
|
||||||
|
quicConfig = quicConfig.Clone()
|
||||||
|
quicConfig.HandshakeIdleTimeout = handshakeTimeout
|
||||||
|
}
|
||||||
|
if baseTLSConfig != nil {
|
||||||
|
var err error
|
||||||
|
tlsConfig, err = buildSTDTLSConfig(baseTLSConfig, M.ParseSocksaddr(addr), []string{http3.NextProtoH3})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
tlsConfig = tlsConfig.Clone()
|
||||||
|
tlsConfig.NextProtos = []string{http3.NextProtoH3}
|
||||||
|
}
|
||||||
|
conn, err := rawDialer.DialContext(ctx, N.NetworkUDP, M.ParseSocksaddr(addr))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
quicConn, err := quic.DialEarly(ctx, bufio.NewUnbindPacketConn(conn), conn.RemoteAddr(), tlsConfig, quicConfig)
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return quicConn, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return h3Transport
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTP3Transport(
|
||||||
|
rawDialer N.Dialer,
|
||||||
|
baseTLSConfig tls.Config,
|
||||||
|
options option.QUICOptions,
|
||||||
|
) (adapter.HTTPTransport, error) {
|
||||||
|
return &http3Transport{
|
||||||
|
h3Transport: newHTTP3RoundTripper(rawDialer, baseTLSConfig, options),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTP3FallbackTransport(
|
||||||
|
rawDialer N.Dialer,
|
||||||
|
baseTLSConfig tls.Config,
|
||||||
|
h2Fallback adapter.HTTPTransport,
|
||||||
|
options option.QUICOptions,
|
||||||
|
fallbackDelay time.Duration,
|
||||||
|
) (adapter.HTTPTransport, error) {
|
||||||
|
return &http3FallbackTransport{
|
||||||
|
h3Transport: newHTTP3RoundTripper(rawDialer, baseTLSConfig, options),
|
||||||
|
h2Fallback: h2Fallback,
|
||||||
|
fallbackDelay: fallbackDelay,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http3Transport) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||||
|
return t.h3Transport.RoundTrip(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http3Transport) CloseIdleConnections() {
|
||||||
|
t.h3Transport.CloseIdleConnections()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http3Transport) Close() error {
|
||||||
|
t.CloseIdleConnections()
|
||||||
|
return t.h3Transport.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http3Transport) Clone() adapter.HTTPTransport {
|
||||||
|
return &http3Transport{
|
||||||
|
h3Transport: t.h3Transport,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http3FallbackTransport) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||||
|
if request.URL.Scheme != "https" || requestRequiresHTTP1(request) {
|
||||||
|
return t.h2Fallback.RoundTrip(request)
|
||||||
|
}
|
||||||
|
return t.roundTripHTTP3(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http3FallbackTransport) roundTripHTTP3(request *http.Request) (*http.Response, error) {
|
||||||
|
if t.h3Broken() {
|
||||||
|
return t.h2FallbackRoundTrip(request)
|
||||||
|
}
|
||||||
|
response, err := t.h3Transport.RoundTripOpt(request, http3.RoundTripOpt{OnlyCachedConn: true})
|
||||||
|
if err == nil {
|
||||||
|
t.clearH3Broken()
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
if !errors.Is(err, http3.ErrNoCachedConn) {
|
||||||
|
t.markH3Broken()
|
||||||
|
return t.h2FallbackRoundTrip(cloneRequestForRetry(request))
|
||||||
|
}
|
||||||
|
if !requestReplayable(request) {
|
||||||
|
response, err = t.h3Transport.RoundTrip(request)
|
||||||
|
if err == nil {
|
||||||
|
t.clearH3Broken()
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
t.markH3Broken()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return t.roundTripHTTP3Race(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http3FallbackTransport) roundTripHTTP3Race(request *http.Request) (*http.Response, error) {
|
||||||
|
ctx, cancel := context.WithCancel(request.Context())
|
||||||
|
defer cancel()
|
||||||
|
type result struct {
|
||||||
|
response *http.Response
|
||||||
|
err error
|
||||||
|
h3 bool
|
||||||
|
}
|
||||||
|
results := make(chan result, 2)
|
||||||
|
startRoundTrip := func(request *http.Request, useH3 bool) {
|
||||||
|
request = request.WithContext(ctx)
|
||||||
|
var (
|
||||||
|
response *http.Response
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if useH3 {
|
||||||
|
response, err = t.h3Transport.RoundTrip(request)
|
||||||
|
} else {
|
||||||
|
response, err = t.h2FallbackRoundTrip(request)
|
||||||
|
}
|
||||||
|
results <- result{response: response, err: err, h3: useH3}
|
||||||
|
}
|
||||||
|
goroutines := 1
|
||||||
|
received := 0
|
||||||
|
drainRemaining := func() {
|
||||||
|
cancel()
|
||||||
|
for range goroutines - received {
|
||||||
|
go func() {
|
||||||
|
loser := <-results
|
||||||
|
if loser.response != nil && loser.response.Body != nil {
|
||||||
|
loser.response.Body.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
go startRoundTrip(cloneRequestForRetry(request), true)
|
||||||
|
timer := time.NewTimer(t.fallbackDelay)
|
||||||
|
defer timer.Stop()
|
||||||
|
var (
|
||||||
|
h3Err error
|
||||||
|
fallbackErr error
|
||||||
|
)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
if goroutines == 1 {
|
||||||
|
goroutines++
|
||||||
|
go startRoundTrip(cloneRequestForRetry(request), false)
|
||||||
|
}
|
||||||
|
case raceResult := <-results:
|
||||||
|
received++
|
||||||
|
if raceResult.err == nil {
|
||||||
|
if raceResult.h3 {
|
||||||
|
t.clearH3Broken()
|
||||||
|
}
|
||||||
|
drainRemaining()
|
||||||
|
return raceResult.response, nil
|
||||||
|
}
|
||||||
|
if raceResult.h3 {
|
||||||
|
t.markH3Broken()
|
||||||
|
h3Err = raceResult.err
|
||||||
|
if goroutines == 1 {
|
||||||
|
goroutines++
|
||||||
|
if !timer.Stop() {
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
go startRoundTrip(cloneRequestForRetry(request), false)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
fallbackErr = raceResult.err
|
||||||
|
}
|
||||||
|
if received < goroutines {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
drainRemaining()
|
||||||
|
switch {
|
||||||
|
case h3Err != nil && fallbackErr != nil:
|
||||||
|
return nil, E.Errors(h3Err, fallbackErr)
|
||||||
|
case fallbackErr != nil:
|
||||||
|
return nil, fallbackErr
|
||||||
|
default:
|
||||||
|
return nil, h3Err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http3FallbackTransport) h2FallbackRoundTrip(request *http.Request) (*http.Response, error) {
|
||||||
|
if fallback, isFallback := t.h2Fallback.(*http2FallbackTransport); isFallback {
|
||||||
|
return fallback.roundTrip(request, true)
|
||||||
|
}
|
||||||
|
return t.h2Fallback.RoundTrip(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http3FallbackTransport) CloseIdleConnections() {
|
||||||
|
t.h3Transport.CloseIdleConnections()
|
||||||
|
t.h2Fallback.CloseIdleConnections()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http3FallbackTransport) Close() error {
|
||||||
|
t.CloseIdleConnections()
|
||||||
|
return t.h3Transport.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http3FallbackTransport) Clone() adapter.HTTPTransport {
|
||||||
|
return &http3FallbackTransport{
|
||||||
|
h3Transport: t.h3Transport,
|
||||||
|
h2Fallback: t.h2Fallback.Clone(),
|
||||||
|
fallbackDelay: t.fallbackDelay,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http3FallbackTransport) h3Broken() bool {
|
||||||
|
t.brokenAccess.Lock()
|
||||||
|
defer t.brokenAccess.Unlock()
|
||||||
|
return !t.brokenUntil.IsZero() && time.Now().Before(t.brokenUntil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http3FallbackTransport) clearH3Broken() {
|
||||||
|
t.brokenAccess.Lock()
|
||||||
|
t.brokenUntil = time.Time{}
|
||||||
|
t.brokenBackoff = 0
|
||||||
|
t.brokenAccess.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *http3FallbackTransport) markH3Broken() {
|
||||||
|
t.brokenAccess.Lock()
|
||||||
|
defer t.brokenAccess.Unlock()
|
||||||
|
if t.brokenBackoff == 0 {
|
||||||
|
t.brokenBackoff = 5 * time.Minute
|
||||||
|
} else {
|
||||||
|
t.brokenBackoff *= 2
|
||||||
|
if t.brokenBackoff > 48*time.Hour {
|
||||||
|
t.brokenBackoff = 48 * time.Hour
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.brokenUntil = time.Now().Add(t.brokenBackoff)
|
||||||
|
}
|
||||||
31
common/httpclient/http3_transport_stub.go
Normal file
31
common/httpclient/http3_transport_stub.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
//go:build !with_quic
|
||||||
|
|
||||||
|
package httpclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/common/tls"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newHTTP3FallbackTransport(
|
||||||
|
rawDialer N.Dialer,
|
||||||
|
baseTLSConfig tls.Config,
|
||||||
|
h2Fallback adapter.HTTPTransport,
|
||||||
|
options option.QUICOptions,
|
||||||
|
fallbackDelay time.Duration,
|
||||||
|
) (adapter.HTTPTransport, error) {
|
||||||
|
return nil, E.New("HTTP/3 requires building with the with_quic tag")
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTP3Transport(
|
||||||
|
rawDialer N.Dialer,
|
||||||
|
baseTLSConfig tls.Config,
|
||||||
|
options option.QUICOptions,
|
||||||
|
) (adapter.HTTPTransport, error) {
|
||||||
|
return nil, E.New("HTTP/3 requires building with the with_quic tag")
|
||||||
|
}
|
||||||
164
common/httpclient/manager.go
Normal file
164
common/httpclient/manager.go
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
package httpclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/log"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
"github.com/sagernet/sing/common/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ adapter.HTTPClientManager = (*Manager)(nil)
|
||||||
|
_ adapter.LifecycleService = (*Manager)(nil)
|
||||||
|
)
|
||||||
|
|
||||||
|
type Manager struct {
|
||||||
|
ctx context.Context
|
||||||
|
logger log.ContextLogger
|
||||||
|
access sync.Mutex
|
||||||
|
defines map[string]option.HTTPClient
|
||||||
|
transports map[string]*Transport
|
||||||
|
defaultTag string
|
||||||
|
defaultTransport adapter.HTTPTransport
|
||||||
|
defaultTransportFallback func() (*Transport, error)
|
||||||
|
fallbackTransport *Transport
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager(ctx context.Context, logger log.ContextLogger, clients []option.HTTPClient, defaultHTTPClient string) *Manager {
|
||||||
|
defines := make(map[string]option.HTTPClient, len(clients))
|
||||||
|
for _, client := range clients {
|
||||||
|
defines[client.Tag] = client
|
||||||
|
}
|
||||||
|
defaultTag := defaultHTTPClient
|
||||||
|
if defaultTag == "" && len(clients) > 0 {
|
||||||
|
defaultTag = clients[0].Tag
|
||||||
|
}
|
||||||
|
return &Manager{
|
||||||
|
ctx: ctx,
|
||||||
|
logger: logger,
|
||||||
|
defines: defines,
|
||||||
|
transports: make(map[string]*Transport),
|
||||||
|
defaultTag: defaultTag,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Initialize(defaultTransportFallback func() (*Transport, error)) {
|
||||||
|
m.defaultTransportFallback = defaultTransportFallback
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Name() string {
|
||||||
|
return "http-client"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Start(stage adapter.StartStage) error {
|
||||||
|
if stage != adapter.StartStateStart {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if m.defaultTag != "" {
|
||||||
|
transport, err := m.resolveShared(m.defaultTag)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "resolve default http client")
|
||||||
|
}
|
||||||
|
m.defaultTransport = transport
|
||||||
|
} else if m.defaultTransportFallback != nil {
|
||||||
|
transport, err := m.defaultTransportFallback()
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "create default http client")
|
||||||
|
}
|
||||||
|
m.defaultTransport = transport
|
||||||
|
m.fallbackTransport = transport
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) DefaultTransport() adapter.HTTPTransport {
|
||||||
|
if m.defaultTransport == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &sharedTransport{m.defaultTransport}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) ResolveTransport(ctx context.Context, logger logger.ContextLogger, options option.HTTPClientOptions) (adapter.HTTPTransport, error) {
|
||||||
|
if options.Tag != "" {
|
||||||
|
if options.ResolveOnDetour {
|
||||||
|
define, loaded := m.defines[options.Tag]
|
||||||
|
if !loaded {
|
||||||
|
return nil, E.New("http_client not found: ", options.Tag)
|
||||||
|
}
|
||||||
|
resolvedOptions := define.Options()
|
||||||
|
resolvedOptions.ResolveOnDetour = true
|
||||||
|
return NewTransport(ctx, logger, options.Tag, resolvedOptions)
|
||||||
|
}
|
||||||
|
transport, err := m.resolveShared(options.Tag)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &sharedTransport{transport}, nil
|
||||||
|
}
|
||||||
|
return NewTransport(ctx, logger, "", options)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) resolveShared(tag string) (adapter.HTTPTransport, error) {
|
||||||
|
m.access.Lock()
|
||||||
|
defer m.access.Unlock()
|
||||||
|
if transport, loaded := m.transports[tag]; loaded {
|
||||||
|
return transport, nil
|
||||||
|
}
|
||||||
|
define, loaded := m.defines[tag]
|
||||||
|
if !loaded {
|
||||||
|
return nil, E.New("http_client not found: ", tag)
|
||||||
|
}
|
||||||
|
transport, err := NewTransport(m.ctx, m.logger, tag, define.Options())
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "create shared http_client[", tag, "]")
|
||||||
|
}
|
||||||
|
m.transports[tag] = transport
|
||||||
|
return transport, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type sharedTransport struct {
|
||||||
|
adapter.HTTPTransport
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *sharedTransport) CloseIdleConnections() {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *sharedTransport) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) ResetNetwork() {
|
||||||
|
m.access.Lock()
|
||||||
|
defer m.access.Unlock()
|
||||||
|
for _, transport := range m.transports {
|
||||||
|
transport.CloseIdleConnections()
|
||||||
|
}
|
||||||
|
if m.fallbackTransport != nil {
|
||||||
|
m.fallbackTransport.CloseIdleConnections()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Close() error {
|
||||||
|
m.access.Lock()
|
||||||
|
defer m.access.Unlock()
|
||||||
|
if m.transports == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
for _, transport := range m.transports {
|
||||||
|
err = E.Append(err, transport.Close(), func(err error) error {
|
||||||
|
return E.Cause(err, "close http client")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if m.fallbackTransport != nil {
|
||||||
|
err = E.Append(err, m.fallbackTransport.Close(), func(err error) error {
|
||||||
|
return E.Cause(err, "close default http client")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
m.transports = nil
|
||||||
|
return err
|
||||||
|
}
|
||||||
@@ -37,7 +37,10 @@ func (l *Listener) ListenTCP() (net.Listener, error) {
|
|||||||
if l.listenOptions.ReuseAddr {
|
if l.listenOptions.ReuseAddr {
|
||||||
listenConfig.Control = control.Append(listenConfig.Control, control.ReuseAddr())
|
listenConfig.Control = control.Append(listenConfig.Control, control.ReuseAddr())
|
||||||
}
|
}
|
||||||
if !l.listenOptions.DisableTCPKeepAlive {
|
if l.listenOptions.DisableTCPKeepAlive {
|
||||||
|
listenConfig.KeepAlive = -1
|
||||||
|
listenConfig.KeepAliveConfig.Enable = false
|
||||||
|
} else {
|
||||||
keepIdle := time.Duration(l.listenOptions.TCPKeepAlive)
|
keepIdle := time.Duration(l.listenOptions.TCPKeepAlive)
|
||||||
if keepIdle == 0 {
|
if keepIdle == 0 {
|
||||||
keepIdle = C.TCPKeepAliveInitial
|
keepIdle = C.TCPKeepAliveInitial
|
||||||
|
|||||||
142
common/networkquality/http.go
Normal file
142
common/networkquality/http.go
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
package networkquality
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
C "github.com/sagernet/sing-box/constant"
|
||||||
|
sBufio "github.com/sagernet/sing/common/bufio"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
func FormatBitrate(bps int64) string {
|
||||||
|
switch {
|
||||||
|
case bps >= 1_000_000_000:
|
||||||
|
return fmt.Sprintf("%.1f Gbps", float64(bps)/1_000_000_000)
|
||||||
|
case bps >= 1_000_000:
|
||||||
|
return fmt.Sprintf("%.1f Mbps", float64(bps)/1_000_000)
|
||||||
|
case bps >= 1_000:
|
||||||
|
return fmt.Sprintf("%.1f Kbps", float64(bps)/1_000)
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("%d bps", bps)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHTTPClient(dialer N.Dialer) *http.Client {
|
||||||
|
transport := &http.Transport{
|
||||||
|
ForceAttemptHTTP2: true,
|
||||||
|
TLSHandshakeTimeout: C.TCPTimeout,
|
||||||
|
}
|
||||||
|
if dialer != nil {
|
||||||
|
transport.DialContext = func(ctx context.Context, network string, addr string) (net.Conn, error) {
|
||||||
|
return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &http.Client{Transport: transport}
|
||||||
|
}
|
||||||
|
|
||||||
|
func baseTransportFromClient(client *http.Client) (*http.Transport, error) {
|
||||||
|
if client == nil {
|
||||||
|
return nil, E.New("http client is nil")
|
||||||
|
}
|
||||||
|
if client.Transport == nil {
|
||||||
|
return http.DefaultTransport.(*http.Transport).Clone(), nil
|
||||||
|
}
|
||||||
|
transport, ok := client.Transport.(*http.Transport)
|
||||||
|
if !ok {
|
||||||
|
return nil, E.New("http client transport must be *http.Transport")
|
||||||
|
}
|
||||||
|
return transport.Clone(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMeasurementClient(
|
||||||
|
baseClient *http.Client,
|
||||||
|
connectEndpoint string,
|
||||||
|
singleConnection bool,
|
||||||
|
disableKeepAlives bool,
|
||||||
|
readCounters []N.CountFunc,
|
||||||
|
writeCounters []N.CountFunc,
|
||||||
|
) (*http.Client, error) {
|
||||||
|
transport, err := baseTransportFromClient(baseClient)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
transport.DisableCompression = true
|
||||||
|
transport.DisableKeepAlives = disableKeepAlives
|
||||||
|
if singleConnection {
|
||||||
|
transport.MaxConnsPerHost = 1
|
||||||
|
transport.MaxIdleConnsPerHost = 1
|
||||||
|
transport.MaxIdleConns = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
baseDialContext := transport.DialContext
|
||||||
|
if baseDialContext == nil {
|
||||||
|
dialer := &net.Dialer{}
|
||||||
|
baseDialContext = dialer.DialContext
|
||||||
|
}
|
||||||
|
transport.DialContext = func(ctx context.Context, network string, addr string) (net.Conn, error) {
|
||||||
|
dialAddr := addr
|
||||||
|
if connectEndpoint != "" {
|
||||||
|
dialAddr = rewriteDialAddress(addr, connectEndpoint)
|
||||||
|
}
|
||||||
|
conn, dialErr := baseDialContext(ctx, network, dialAddr)
|
||||||
|
if dialErr != nil {
|
||||||
|
return nil, dialErr
|
||||||
|
}
|
||||||
|
if len(readCounters) > 0 || len(writeCounters) > 0 {
|
||||||
|
return sBufio.NewCounterConn(conn, readCounters, writeCounters), nil
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &http.Client{
|
||||||
|
Transport: transport,
|
||||||
|
CheckRedirect: baseClient.CheckRedirect,
|
||||||
|
Jar: baseClient.Jar,
|
||||||
|
Timeout: baseClient.Timeout,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type MeasurementClientFactory func(
|
||||||
|
connectEndpoint string,
|
||||||
|
singleConnection bool,
|
||||||
|
disableKeepAlives bool,
|
||||||
|
readCounters []N.CountFunc,
|
||||||
|
writeCounters []N.CountFunc,
|
||||||
|
) (*http.Client, error)
|
||||||
|
|
||||||
|
func defaultMeasurementClientFactory(baseClient *http.Client) MeasurementClientFactory {
|
||||||
|
return func(connectEndpoint string, singleConnection, disableKeepAlives bool, readCounters, writeCounters []N.CountFunc) (*http.Client, error) {
|
||||||
|
return newMeasurementClient(baseClient, connectEndpoint, singleConnection, disableKeepAlives, readCounters, writeCounters)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOptionalHTTP3Factory(dialer N.Dialer, useHTTP3 bool) (MeasurementClientFactory, error) {
|
||||||
|
if !useHTTP3 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return NewHTTP3MeasurementClientFactory(dialer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func rewriteDialAddress(addr string, connectEndpoint string) string {
|
||||||
|
connectEndpoint = strings.TrimSpace(connectEndpoint)
|
||||||
|
host, port, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
endpointHost, endpointPort, err := net.SplitHostPort(connectEndpoint)
|
||||||
|
if err == nil {
|
||||||
|
host = endpointHost
|
||||||
|
if endpointPort != "" {
|
||||||
|
port = endpointPort
|
||||||
|
}
|
||||||
|
} else if connectEndpoint != "" {
|
||||||
|
host = connectEndpoint
|
||||||
|
}
|
||||||
|
return net.JoinHostPort(host, port)
|
||||||
|
}
|
||||||
55
common/networkquality/http3.go
Normal file
55
common/networkquality/http3.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
//go:build with_quic
|
||||||
|
|
||||||
|
package networkquality
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/sagernet/quic-go"
|
||||||
|
"github.com/sagernet/quic-go/http3"
|
||||||
|
sBufio "github.com/sagernet/sing/common/bufio"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewHTTP3MeasurementClientFactory(dialer N.Dialer) (MeasurementClientFactory, error) {
|
||||||
|
// singleConnection and disableKeepAlives are not applied:
|
||||||
|
// HTTP/3 multiplexes streams over a single QUIC connection by default.
|
||||||
|
return func(connectEndpoint string, _, _ bool, readCounters, writeCounters []N.CountFunc) (*http.Client, error) {
|
||||||
|
transport := &http3.Transport{
|
||||||
|
Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) {
|
||||||
|
dialAddr := addr
|
||||||
|
if connectEndpoint != "" {
|
||||||
|
dialAddr = rewriteDialAddress(addr, connectEndpoint)
|
||||||
|
}
|
||||||
|
destination := M.ParseSocksaddr(dialAddr)
|
||||||
|
var udpConn net.Conn
|
||||||
|
var dialErr error
|
||||||
|
if dialer != nil {
|
||||||
|
udpConn, dialErr = dialer.DialContext(ctx, N.NetworkUDP, destination)
|
||||||
|
} else {
|
||||||
|
var netDialer net.Dialer
|
||||||
|
udpConn, dialErr = netDialer.DialContext(ctx, N.NetworkUDP, destination.String())
|
||||||
|
}
|
||||||
|
if dialErr != nil {
|
||||||
|
return nil, dialErr
|
||||||
|
}
|
||||||
|
wrappedConn := udpConn
|
||||||
|
if len(readCounters) > 0 || len(writeCounters) > 0 {
|
||||||
|
wrappedConn = sBufio.NewCounterConn(udpConn, readCounters, writeCounters)
|
||||||
|
}
|
||||||
|
packetConn := sBufio.NewUnbindPacketConn(wrappedConn)
|
||||||
|
quicConn, dialErr := quic.DialEarly(ctx, packetConn, udpConn.RemoteAddr(), tlsCfg, cfg)
|
||||||
|
if dialErr != nil {
|
||||||
|
udpConn.Close()
|
||||||
|
return nil, dialErr
|
||||||
|
}
|
||||||
|
return quicConn, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return &http.Client{Transport: transport}, nil
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
12
common/networkquality/http3_stub.go
Normal file
12
common/networkquality/http3_stub.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
//go:build !with_quic
|
||||||
|
|
||||||
|
package networkquality
|
||||||
|
|
||||||
|
import (
|
||||||
|
C "github.com/sagernet/sing-box/constant"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewHTTP3MeasurementClientFactory(dialer N.Dialer) (MeasurementClientFactory, error) {
|
||||||
|
return nil, C.ErrQUICNotIncluded
|
||||||
|
}
|
||||||
1413
common/networkquality/networkquality.go
Normal file
1413
common/networkquality/networkquality.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
type Searcher interface {
|
type Searcher interface {
|
||||||
FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*adapter.ConnectionOwner, error)
|
FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*adapter.ConnectionOwner, error)
|
||||||
|
Close() error
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrNotFound = E.New("process not found")
|
var ErrNotFound = E.New("process not found")
|
||||||
@@ -28,7 +29,7 @@ func FindProcessInfo(searcher Searcher, ctx context.Context, network string, sou
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if info.UserId != -1 {
|
if info.UserId != -1 && info.UserName == "" {
|
||||||
osUser, _ := user.LookupId(F.ToString(info.UserId))
|
osUser, _ := user.LookupId(F.ToString(info.UserId))
|
||||||
if osUser != nil {
|
if osUser != nil {
|
||||||
info.UserName = osUser.Username
|
info.UserName = osUser.Username
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
"github.com/sagernet/sing-tun"
|
"github.com/sagernet/sing-tun"
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ Searcher = (*androidSearcher)(nil)
|
var _ Searcher = (*androidSearcher)(nil)
|
||||||
@@ -18,22 +19,30 @@ func NewSearcher(config Config) (Searcher, error) {
|
|||||||
return &androidSearcher{config.PackageManager}, nil
|
return &androidSearcher{config.PackageManager}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *androidSearcher) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *androidSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*adapter.ConnectionOwner, error) {
|
func (s *androidSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*adapter.ConnectionOwner, error) {
|
||||||
_, uid, err := resolveSocketByNetlink(network, source, destination)
|
family, protocol, err := socketDiagSettings(network, source)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if sharedPackage, loaded := s.packageManager.SharedPackageByID(uid % 100000); loaded {
|
_, uid, err := querySocketDiagOnce(family, protocol, source)
|
||||||
return &adapter.ConnectionOwner{
|
if err != nil {
|
||||||
UserId: int32(uid),
|
return nil, err
|
||||||
AndroidPackageName: sharedPackage,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
if packageName, loaded := s.packageManager.PackageByID(uid % 100000); loaded {
|
appID := uid % 100000
|
||||||
return &adapter.ConnectionOwner{
|
var packageNames []string
|
||||||
UserId: int32(uid),
|
if sharedPackage, loaded := s.packageManager.SharedPackageByID(appID); loaded {
|
||||||
AndroidPackageName: packageName,
|
packageNames = append(packageNames, sharedPackage)
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
return &adapter.ConnectionOwner{UserId: int32(uid)}, nil
|
if packages, loaded := s.packageManager.PackagesByID(appID); loaded {
|
||||||
|
packageNames = append(packageNames, packages...)
|
||||||
|
}
|
||||||
|
packageNames = common.Uniq(packageNames)
|
||||||
|
return &adapter.ConnectionOwner{
|
||||||
|
UserId: int32(uid),
|
||||||
|
AndroidPackageNames: packageNames,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,19 +1,15 @@
|
|||||||
|
//go:build darwin
|
||||||
|
|
||||||
package process
|
package process
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
N "github.com/sagernet/sing/common/network"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ Searcher = (*darwinSearcher)(nil)
|
var _ Searcher = (*darwinSearcher)(nil)
|
||||||
@@ -24,12 +20,12 @@ func NewSearcher(_ Config) (Searcher, error) {
|
|||||||
return &darwinSearcher{}, nil
|
return &darwinSearcher{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *darwinSearcher) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (d *darwinSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*adapter.ConnectionOwner, error) {
|
func (d *darwinSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*adapter.ConnectionOwner, error) {
|
||||||
processName, err := findProcessName(network, source.Addr(), int(source.Port()))
|
return FindDarwinConnectionOwner(network, source, destination)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &adapter.ConnectionOwner{ProcessPath: processName, UserId: -1}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var structSize = func() int {
|
var structSize = func() int {
|
||||||
@@ -47,107 +43,3 @@ var structSize = func() int {
|
|||||||
return 384
|
return 384
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
func findProcessName(network string, ip netip.Addr, port int) (string, error) {
|
|
||||||
var spath string
|
|
||||||
switch network {
|
|
||||||
case N.NetworkTCP:
|
|
||||||
spath = "net.inet.tcp.pcblist_n"
|
|
||||||
case N.NetworkUDP:
|
|
||||||
spath = "net.inet.udp.pcblist_n"
|
|
||||||
default:
|
|
||||||
return "", os.ErrInvalid
|
|
||||||
}
|
|
||||||
|
|
||||||
isIPv4 := ip.Is4()
|
|
||||||
|
|
||||||
value, err := unix.SysctlRaw(spath)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
buf := value
|
|
||||||
|
|
||||||
// from darwin-xnu/bsd/netinet/in_pcblist.c:get_pcblist_n
|
|
||||||
// size/offset are round up (aligned) to 8 bytes in darwin
|
|
||||||
// rup8(sizeof(xinpcb_n)) + rup8(sizeof(xsocket_n)) +
|
|
||||||
// 2 * rup8(sizeof(xsockbuf_n)) + rup8(sizeof(xsockstat_n))
|
|
||||||
itemSize := structSize
|
|
||||||
if network == N.NetworkTCP {
|
|
||||||
// rup8(sizeof(xtcpcb_n))
|
|
||||||
itemSize += 208
|
|
||||||
}
|
|
||||||
|
|
||||||
var fallbackUDPProcess string
|
|
||||||
// skip the first xinpgen(24 bytes) block
|
|
||||||
for i := 24; i+itemSize <= len(buf); i += itemSize {
|
|
||||||
// offset of xinpcb_n and xsocket_n
|
|
||||||
inp, so := i, i+104
|
|
||||||
|
|
||||||
srcPort := binary.BigEndian.Uint16(buf[inp+18 : inp+20])
|
|
||||||
if uint16(port) != srcPort {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// xinpcb_n.inp_vflag
|
|
||||||
flag := buf[inp+44]
|
|
||||||
|
|
||||||
var srcIP netip.Addr
|
|
||||||
srcIsIPv4 := false
|
|
||||||
switch {
|
|
||||||
case flag&0x1 > 0 && isIPv4:
|
|
||||||
// ipv4
|
|
||||||
srcIP = netip.AddrFrom4([4]byte(buf[inp+76 : inp+80]))
|
|
||||||
srcIsIPv4 = true
|
|
||||||
case flag&0x2 > 0 && !isIPv4:
|
|
||||||
// ipv6
|
|
||||||
srcIP = netip.AddrFrom16([16]byte(buf[inp+64 : inp+80]))
|
|
||||||
default:
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if ip == srcIP {
|
|
||||||
// xsocket_n.so_last_pid
|
|
||||||
pid := readNativeUint32(buf[so+68 : so+72])
|
|
||||||
return getExecPathFromPID(pid)
|
|
||||||
}
|
|
||||||
|
|
||||||
// udp packet connection may be not equal with srcIP
|
|
||||||
if network == N.NetworkUDP && srcIP.IsUnspecified() && isIPv4 == srcIsIPv4 {
|
|
||||||
pid := readNativeUint32(buf[so+68 : so+72])
|
|
||||||
fallbackUDPProcess, _ = getExecPathFromPID(pid)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if network == N.NetworkUDP && len(fallbackUDPProcess) > 0 {
|
|
||||||
return fallbackUDPProcess, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return "", ErrNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
func getExecPathFromPID(pid uint32) (string, error) {
|
|
||||||
const (
|
|
||||||
procpidpathinfo = 0xb
|
|
||||||
procpidpathinfosize = 1024
|
|
||||||
proccallnumpidinfo = 0x2
|
|
||||||
)
|
|
||||||
buf := make([]byte, procpidpathinfosize)
|
|
||||||
_, _, errno := syscall.Syscall6(
|
|
||||||
syscall.SYS_PROC_INFO,
|
|
||||||
proccallnumpidinfo,
|
|
||||||
uintptr(pid),
|
|
||||||
procpidpathinfo,
|
|
||||||
0,
|
|
||||||
uintptr(unsafe.Pointer(&buf[0])),
|
|
||||||
procpidpathinfosize)
|
|
||||||
if errno != 0 {
|
|
||||||
return "", errno
|
|
||||||
}
|
|
||||||
|
|
||||||
return unix.ByteSliceToString(buf), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func readNativeUint32(b []byte) uint32 {
|
|
||||||
return *(*uint32)(unsafe.Pointer(&b[0]))
|
|
||||||
}
|
|
||||||
|
|||||||
269
common/process/searcher_darwin_shared.go
Normal file
269
common/process/searcher_darwin_shared.go
Normal file
@@ -0,0 +1,269 @@
|
|||||||
|
//go:build darwin
|
||||||
|
|
||||||
|
package process
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
darwinSnapshotTTL = 200 * time.Millisecond
|
||||||
|
|
||||||
|
darwinXinpgenSize = 24
|
||||||
|
darwinXsocketOffset = 104
|
||||||
|
darwinXinpcbForeignPort = 16
|
||||||
|
darwinXinpcbLocalPort = 18
|
||||||
|
darwinXinpcbVFlag = 44
|
||||||
|
darwinXinpcbForeignAddr = 48
|
||||||
|
darwinXinpcbLocalAddr = 64
|
||||||
|
darwinXinpcbIPv4Addr = 12
|
||||||
|
darwinXsocketUID = 64
|
||||||
|
darwinXsocketLastPID = 68
|
||||||
|
darwinTCPExtraStructSize = 208
|
||||||
|
)
|
||||||
|
|
||||||
|
type darwinConnectionEntry struct {
|
||||||
|
localAddr netip.Addr
|
||||||
|
remoteAddr netip.Addr
|
||||||
|
localPort uint16
|
||||||
|
remotePort uint16
|
||||||
|
pid uint32
|
||||||
|
uid int32
|
||||||
|
}
|
||||||
|
|
||||||
|
type darwinConnectionMatchKind uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
darwinConnectionMatchExact darwinConnectionMatchKind = iota
|
||||||
|
darwinConnectionMatchLocalFallback
|
||||||
|
darwinConnectionMatchWildcardFallback
|
||||||
|
)
|
||||||
|
|
||||||
|
type darwinSnapshot struct {
|
||||||
|
createdAt time.Time
|
||||||
|
entries []darwinConnectionEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
type darwinConnectionFinder struct {
|
||||||
|
access sync.Mutex
|
||||||
|
ttl time.Duration
|
||||||
|
snapshots map[string]darwinSnapshot
|
||||||
|
builder func(string) (darwinSnapshot, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
var sharedDarwinConnectionFinder = newDarwinConnectionFinder(darwinSnapshotTTL)
|
||||||
|
|
||||||
|
func newDarwinConnectionFinder(ttl time.Duration) *darwinConnectionFinder {
|
||||||
|
return &darwinConnectionFinder{
|
||||||
|
ttl: ttl,
|
||||||
|
snapshots: make(map[string]darwinSnapshot),
|
||||||
|
builder: buildDarwinSnapshot,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func FindDarwinConnectionOwner(network string, source netip.AddrPort, destination netip.AddrPort) (*adapter.ConnectionOwner, error) {
|
||||||
|
return sharedDarwinConnectionFinder.find(network, source, destination)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *darwinConnectionFinder) find(network string, source netip.AddrPort, destination netip.AddrPort) (*adapter.ConnectionOwner, error) {
|
||||||
|
networkName := N.NetworkName(network)
|
||||||
|
source = normalizeDarwinAddrPort(source)
|
||||||
|
destination = normalizeDarwinAddrPort(destination)
|
||||||
|
var lastOwner *adapter.ConnectionOwner
|
||||||
|
for attempt := 0; attempt < 2; attempt++ {
|
||||||
|
snapshot, fromCache, err := f.loadSnapshot(networkName, attempt > 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
entry, matchKind, err := matchDarwinConnectionEntry(snapshot.entries, networkName, source, destination)
|
||||||
|
if err != nil {
|
||||||
|
if err == ErrNotFound && fromCache {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if fromCache && matchKind != darwinConnectionMatchExact {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
owner := &adapter.ConnectionOwner{
|
||||||
|
UserId: entry.uid,
|
||||||
|
}
|
||||||
|
lastOwner = owner
|
||||||
|
if entry.pid == 0 {
|
||||||
|
return owner, nil
|
||||||
|
}
|
||||||
|
processPath, err := getExecPathFromPID(entry.pid)
|
||||||
|
if err == nil {
|
||||||
|
owner.ProcessPath = processPath
|
||||||
|
return owner, nil
|
||||||
|
}
|
||||||
|
if fromCache {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return owner, nil
|
||||||
|
}
|
||||||
|
if lastOwner != nil {
|
||||||
|
return lastOwner, nil
|
||||||
|
}
|
||||||
|
return nil, ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *darwinConnectionFinder) loadSnapshot(network string, forceRefresh bool) (darwinSnapshot, bool, error) {
|
||||||
|
f.access.Lock()
|
||||||
|
defer f.access.Unlock()
|
||||||
|
if !forceRefresh {
|
||||||
|
if snapshot, loaded := f.snapshots[network]; loaded && time.Since(snapshot.createdAt) < f.ttl {
|
||||||
|
return snapshot, true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
snapshot, err := f.builder(network)
|
||||||
|
if err != nil {
|
||||||
|
return darwinSnapshot{}, false, err
|
||||||
|
}
|
||||||
|
f.snapshots[network] = snapshot
|
||||||
|
return snapshot, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildDarwinSnapshot(network string) (darwinSnapshot, error) {
|
||||||
|
spath, itemSize, err := darwinSnapshotSettings(network)
|
||||||
|
if err != nil {
|
||||||
|
return darwinSnapshot{}, err
|
||||||
|
}
|
||||||
|
value, err := unix.SysctlRaw(spath)
|
||||||
|
if err != nil {
|
||||||
|
return darwinSnapshot{}, err
|
||||||
|
}
|
||||||
|
return darwinSnapshot{
|
||||||
|
createdAt: time.Now(),
|
||||||
|
entries: parseDarwinSnapshot(value, itemSize),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func darwinSnapshotSettings(network string) (string, int, error) {
|
||||||
|
itemSize := structSize
|
||||||
|
switch network {
|
||||||
|
case N.NetworkTCP:
|
||||||
|
return "net.inet.tcp.pcblist_n", itemSize + darwinTCPExtraStructSize, nil
|
||||||
|
case N.NetworkUDP:
|
||||||
|
return "net.inet.udp.pcblist_n", itemSize, nil
|
||||||
|
default:
|
||||||
|
return "", 0, os.ErrInvalid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseDarwinSnapshot(buf []byte, itemSize int) []darwinConnectionEntry {
|
||||||
|
entries := make([]darwinConnectionEntry, 0, (len(buf)-darwinXinpgenSize)/itemSize)
|
||||||
|
for i := darwinXinpgenSize; i+itemSize <= len(buf); i += itemSize {
|
||||||
|
inp := i
|
||||||
|
so := i + darwinXsocketOffset
|
||||||
|
entry, ok := parseDarwinConnectionEntry(buf[inp:so], buf[so:so+structSize-darwinXsocketOffset])
|
||||||
|
if ok {
|
||||||
|
entries = append(entries, entry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return entries
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseDarwinConnectionEntry(inp []byte, so []byte) (darwinConnectionEntry, bool) {
|
||||||
|
if len(inp) < darwinXsocketOffset || len(so) < structSize-darwinXsocketOffset {
|
||||||
|
return darwinConnectionEntry{}, false
|
||||||
|
}
|
||||||
|
entry := darwinConnectionEntry{
|
||||||
|
remotePort: binary.BigEndian.Uint16(inp[darwinXinpcbForeignPort : darwinXinpcbForeignPort+2]),
|
||||||
|
localPort: binary.BigEndian.Uint16(inp[darwinXinpcbLocalPort : darwinXinpcbLocalPort+2]),
|
||||||
|
pid: binary.NativeEndian.Uint32(so[darwinXsocketLastPID : darwinXsocketLastPID+4]),
|
||||||
|
uid: int32(binary.NativeEndian.Uint32(so[darwinXsocketUID : darwinXsocketUID+4])),
|
||||||
|
}
|
||||||
|
flag := inp[darwinXinpcbVFlag]
|
||||||
|
switch {
|
||||||
|
case flag&0x1 != 0:
|
||||||
|
entry.remoteAddr = netip.AddrFrom4([4]byte(inp[darwinXinpcbForeignAddr+darwinXinpcbIPv4Addr : darwinXinpcbForeignAddr+darwinXinpcbIPv4Addr+4]))
|
||||||
|
entry.localAddr = netip.AddrFrom4([4]byte(inp[darwinXinpcbLocalAddr+darwinXinpcbIPv4Addr : darwinXinpcbLocalAddr+darwinXinpcbIPv4Addr+4]))
|
||||||
|
return entry, true
|
||||||
|
case flag&0x2 != 0:
|
||||||
|
entry.remoteAddr = netip.AddrFrom16([16]byte(inp[darwinXinpcbForeignAddr : darwinXinpcbForeignAddr+16]))
|
||||||
|
entry.localAddr = netip.AddrFrom16([16]byte(inp[darwinXinpcbLocalAddr : darwinXinpcbLocalAddr+16]))
|
||||||
|
return entry, true
|
||||||
|
default:
|
||||||
|
return darwinConnectionEntry{}, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func matchDarwinConnectionEntry(entries []darwinConnectionEntry, network string, source netip.AddrPort, destination netip.AddrPort) (darwinConnectionEntry, darwinConnectionMatchKind, error) {
|
||||||
|
sourceAddr := source.Addr()
|
||||||
|
if !sourceAddr.IsValid() {
|
||||||
|
return darwinConnectionEntry{}, darwinConnectionMatchExact, os.ErrInvalid
|
||||||
|
}
|
||||||
|
var localFallback darwinConnectionEntry
|
||||||
|
var hasLocalFallback bool
|
||||||
|
var wildcardFallback darwinConnectionEntry
|
||||||
|
var hasWildcardFallback bool
|
||||||
|
for _, entry := range entries {
|
||||||
|
if entry.localPort != source.Port() || sourceAddr.BitLen() != entry.localAddr.BitLen() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if entry.localAddr == sourceAddr && destination.IsValid() && entry.remotePort == destination.Port() && entry.remoteAddr == destination.Addr() {
|
||||||
|
return entry, darwinConnectionMatchExact, nil
|
||||||
|
}
|
||||||
|
if !destination.IsValid() && entry.localAddr == sourceAddr {
|
||||||
|
return entry, darwinConnectionMatchExact, nil
|
||||||
|
}
|
||||||
|
if network != N.NetworkUDP {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !hasLocalFallback && entry.localAddr == sourceAddr {
|
||||||
|
hasLocalFallback = true
|
||||||
|
localFallback = entry
|
||||||
|
}
|
||||||
|
if !hasWildcardFallback && entry.localAddr.IsUnspecified() {
|
||||||
|
hasWildcardFallback = true
|
||||||
|
wildcardFallback = entry
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if hasLocalFallback {
|
||||||
|
return localFallback, darwinConnectionMatchLocalFallback, nil
|
||||||
|
}
|
||||||
|
if hasWildcardFallback {
|
||||||
|
return wildcardFallback, darwinConnectionMatchWildcardFallback, nil
|
||||||
|
}
|
||||||
|
return darwinConnectionEntry{}, darwinConnectionMatchExact, ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeDarwinAddrPort(addrPort netip.AddrPort) netip.AddrPort {
|
||||||
|
if !addrPort.IsValid() {
|
||||||
|
return addrPort
|
||||||
|
}
|
||||||
|
return netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())
|
||||||
|
}
|
||||||
|
|
||||||
|
func getExecPathFromPID(pid uint32) (string, error) {
|
||||||
|
const (
|
||||||
|
procpidpathinfo = 0xb
|
||||||
|
procpidpathinfosize = 1024
|
||||||
|
proccallnumpidinfo = 0x2
|
||||||
|
)
|
||||||
|
buf := make([]byte, procpidpathinfosize)
|
||||||
|
_, _, errno := syscall.Syscall6(
|
||||||
|
syscall.SYS_PROC_INFO,
|
||||||
|
proccallnumpidinfo,
|
||||||
|
uintptr(pid),
|
||||||
|
procpidpathinfo,
|
||||||
|
0,
|
||||||
|
uintptr(unsafe.Pointer(&buf[0])),
|
||||||
|
procpidpathinfosize)
|
||||||
|
if errno != 0 {
|
||||||
|
return "", errno
|
||||||
|
}
|
||||||
|
return unix.ByteSliceToString(buf), nil
|
||||||
|
}
|
||||||
@@ -4,33 +4,82 @@ package process
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
"github.com/sagernet/sing-box/log"
|
"github.com/sagernet/sing-box/log"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ Searcher = (*linuxSearcher)(nil)
|
var _ Searcher = (*linuxSearcher)(nil)
|
||||||
|
|
||||||
type linuxSearcher struct {
|
type linuxSearcher struct {
|
||||||
logger log.ContextLogger
|
logger log.ContextLogger
|
||||||
|
diagConns [4]*socketDiagConn
|
||||||
|
processPathCache *uidProcessPathCache
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSearcher(config Config) (Searcher, error) {
|
func NewSearcher(config Config) (Searcher, error) {
|
||||||
return &linuxSearcher{config.Logger}, nil
|
searcher := &linuxSearcher{
|
||||||
|
logger: config.Logger,
|
||||||
|
processPathCache: newUIDProcessPathCache(time.Second),
|
||||||
|
}
|
||||||
|
for _, family := range []uint8{syscall.AF_INET, syscall.AF_INET6} {
|
||||||
|
for _, protocol := range []uint8{syscall.IPPROTO_TCP, syscall.IPPROTO_UDP} {
|
||||||
|
searcher.diagConns[socketDiagConnIndex(family, protocol)] = newSocketDiagConn(family, protocol)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return searcher, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *linuxSearcher) Close() error {
|
||||||
|
var errs []error
|
||||||
|
for _, conn := range s.diagConns {
|
||||||
|
if conn == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
errs = append(errs, conn.Close())
|
||||||
|
}
|
||||||
|
return E.Errors(errs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *linuxSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*adapter.ConnectionOwner, error) {
|
func (s *linuxSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*adapter.ConnectionOwner, error) {
|
||||||
inode, uid, err := resolveSocketByNetlink(network, source, destination)
|
inode, uid, err := s.resolveSocketByNetlink(network, source, destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
processPath, err := resolveProcessNameByProcSearch(inode, uid)
|
processInfo := &adapter.ConnectionOwner{
|
||||||
|
UserId: int32(uid),
|
||||||
|
}
|
||||||
|
processPath, err := s.processPathCache.findProcessPath(inode, uid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.DebugContext(ctx, "find process path: ", err)
|
s.logger.DebugContext(ctx, "find process path: ", err)
|
||||||
|
} else {
|
||||||
|
processInfo.ProcessPath = processPath
|
||||||
}
|
}
|
||||||
return &adapter.ConnectionOwner{
|
return processInfo, nil
|
||||||
UserId: int32(uid),
|
}
|
||||||
ProcessPath: processPath,
|
|
||||||
}, nil
|
func (s *linuxSearcher) resolveSocketByNetlink(network string, source netip.AddrPort, destination netip.AddrPort) (inode, uid uint32, err error) {
|
||||||
|
family, protocol, err := socketDiagSettings(network, source)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
conn := s.diagConns[socketDiagConnIndex(family, protocol)]
|
||||||
|
if conn == nil {
|
||||||
|
return 0, 0, E.New("missing socket diag connection for family=", family, " protocol=", protocol)
|
||||||
|
}
|
||||||
|
if destination.IsValid() && source.Addr().BitLen() == destination.Addr().BitLen() {
|
||||||
|
inode, uid, err = conn.query(source, destination)
|
||||||
|
if err == nil {
|
||||||
|
return inode, uid, nil
|
||||||
|
}
|
||||||
|
if !errors.Is(err, ErrNotFound) {
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return querySocketDiagOnce(family, protocol, source)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,43 +3,67 @@
|
|||||||
package process
|
package process
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"errors"
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"time"
|
||||||
"unicode"
|
"unicode"
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common/buf"
|
"github.com/sagernet/sing/common"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
"github.com/sagernet/sing/contrab/freelru"
|
||||||
|
"github.com/sagernet/sing/contrab/maphash"
|
||||||
)
|
)
|
||||||
|
|
||||||
// from https://github.com/vishvananda/netlink/blob/bca67dfc8220b44ef582c9da4e9172bf1c9ec973/nl/nl_linux.go#L52-L62
|
|
||||||
var nativeEndian = func() binary.ByteOrder {
|
|
||||||
var x uint32 = 0x01020304
|
|
||||||
if *(*byte)(unsafe.Pointer(&x)) == 0x01 {
|
|
||||||
return binary.BigEndian
|
|
||||||
}
|
|
||||||
|
|
||||||
return binary.LittleEndian
|
|
||||||
}()
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
sizeOfSocketDiagRequest = syscall.SizeofNlMsghdr + 8 + 48
|
sizeOfSocketDiagRequestData = 56
|
||||||
socketDiagByFamily = 20
|
sizeOfSocketDiagRequest = syscall.SizeofNlMsghdr + sizeOfSocketDiagRequestData
|
||||||
pathProc = "/proc"
|
socketDiagResponseMinSize = 72
|
||||||
|
socketDiagByFamily = 20
|
||||||
|
pathProc = "/proc"
|
||||||
)
|
)
|
||||||
|
|
||||||
func resolveSocketByNetlink(network string, source netip.AddrPort, destination netip.AddrPort) (inode, uid uint32, err error) {
|
type socketDiagConn struct {
|
||||||
var family uint8
|
access sync.Mutex
|
||||||
var protocol uint8
|
family uint8
|
||||||
|
protocol uint8
|
||||||
|
fd int
|
||||||
|
}
|
||||||
|
|
||||||
|
type uidProcessPathCache struct {
|
||||||
|
cache freelru.Cache[uint32, *uidProcessPaths]
|
||||||
|
}
|
||||||
|
|
||||||
|
type uidProcessPaths struct {
|
||||||
|
entries map[uint32]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSocketDiagConn(family, protocol uint8) *socketDiagConn {
|
||||||
|
return &socketDiagConn{
|
||||||
|
family: family,
|
||||||
|
protocol: protocol,
|
||||||
|
fd: -1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func socketDiagConnIndex(family, protocol uint8) int {
|
||||||
|
index := 0
|
||||||
|
if protocol == syscall.IPPROTO_UDP {
|
||||||
|
index += 2
|
||||||
|
}
|
||||||
|
if family == syscall.AF_INET6 {
|
||||||
|
index++
|
||||||
|
}
|
||||||
|
return index
|
||||||
|
}
|
||||||
|
|
||||||
|
func socketDiagSettings(network string, source netip.AddrPort) (family, protocol uint8, err error) {
|
||||||
switch network {
|
switch network {
|
||||||
case N.NetworkTCP:
|
case N.NetworkTCP:
|
||||||
protocol = syscall.IPPROTO_TCP
|
protocol = syscall.IPPROTO_TCP
|
||||||
@@ -48,151 +72,308 @@ func resolveSocketByNetlink(network string, source netip.AddrPort, destination n
|
|||||||
default:
|
default:
|
||||||
return 0, 0, os.ErrInvalid
|
return 0, 0, os.ErrInvalid
|
||||||
}
|
}
|
||||||
|
switch {
|
||||||
if source.Addr().Is4() {
|
case source.Addr().Is4():
|
||||||
family = syscall.AF_INET
|
family = syscall.AF_INET
|
||||||
} else {
|
case source.Addr().Is6():
|
||||||
family = syscall.AF_INET6
|
family = syscall.AF_INET6
|
||||||
|
default:
|
||||||
|
return 0, 0, os.ErrInvalid
|
||||||
}
|
}
|
||||||
|
return family, protocol, nil
|
||||||
req := packSocketDiagRequest(family, protocol, source)
|
|
||||||
|
|
||||||
socket, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_DGRAM, syscall.NETLINK_INET_DIAG)
|
|
||||||
if err != nil {
|
|
||||||
return 0, 0, E.Cause(err, "dial netlink")
|
|
||||||
}
|
|
||||||
defer syscall.Close(socket)
|
|
||||||
|
|
||||||
syscall.SetsockoptTimeval(socket, syscall.SOL_SOCKET, syscall.SO_SNDTIMEO, &syscall.Timeval{Usec: 100})
|
|
||||||
syscall.SetsockoptTimeval(socket, syscall.SOL_SOCKET, syscall.SO_RCVTIMEO, &syscall.Timeval{Usec: 100})
|
|
||||||
|
|
||||||
err = syscall.Connect(socket, &syscall.SockaddrNetlink{
|
|
||||||
Family: syscall.AF_NETLINK,
|
|
||||||
Pad: 0,
|
|
||||||
Pid: 0,
|
|
||||||
Groups: 0,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = syscall.Write(socket, req)
|
|
||||||
if err != nil {
|
|
||||||
return 0, 0, E.Cause(err, "write netlink request")
|
|
||||||
}
|
|
||||||
|
|
||||||
buffer := buf.New()
|
|
||||||
defer buffer.Release()
|
|
||||||
|
|
||||||
n, err := syscall.Read(socket, buffer.FreeBytes())
|
|
||||||
if err != nil {
|
|
||||||
return 0, 0, E.Cause(err, "read netlink response")
|
|
||||||
}
|
|
||||||
|
|
||||||
buffer.Truncate(n)
|
|
||||||
|
|
||||||
messages, err := syscall.ParseNetlinkMessage(buffer.Bytes())
|
|
||||||
if err != nil {
|
|
||||||
return 0, 0, E.Cause(err, "parse netlink message")
|
|
||||||
} else if len(messages) == 0 {
|
|
||||||
return 0, 0, E.New("unexcepted netlink response")
|
|
||||||
}
|
|
||||||
|
|
||||||
message := messages[0]
|
|
||||||
if message.Header.Type&syscall.NLMSG_ERROR != 0 {
|
|
||||||
return 0, 0, E.New("netlink message: NLMSG_ERROR")
|
|
||||||
}
|
|
||||||
|
|
||||||
inode, uid = unpackSocketDiagResponse(&messages[0])
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func packSocketDiagRequest(family, protocol byte, source netip.AddrPort) []byte {
|
func newUIDProcessPathCache(ttl time.Duration) *uidProcessPathCache {
|
||||||
s := make([]byte, 16)
|
cache := common.Must1(freelru.NewSharded[uint32, *uidProcessPaths](64, maphash.NewHasher[uint32]().Hash32))
|
||||||
copy(s, source.Addr().AsSlice())
|
cache.SetLifetime(ttl)
|
||||||
|
return &uidProcessPathCache{cache: cache}
|
||||||
buf := make([]byte, sizeOfSocketDiagRequest)
|
|
||||||
|
|
||||||
nativeEndian.PutUint32(buf[0:4], sizeOfSocketDiagRequest)
|
|
||||||
nativeEndian.PutUint16(buf[4:6], socketDiagByFamily)
|
|
||||||
nativeEndian.PutUint16(buf[6:8], syscall.NLM_F_REQUEST|syscall.NLM_F_DUMP)
|
|
||||||
nativeEndian.PutUint32(buf[8:12], 0)
|
|
||||||
nativeEndian.PutUint32(buf[12:16], 0)
|
|
||||||
|
|
||||||
buf[16] = family
|
|
||||||
buf[17] = protocol
|
|
||||||
buf[18] = 0
|
|
||||||
buf[19] = 0
|
|
||||||
nativeEndian.PutUint32(buf[20:24], 0xFFFFFFFF)
|
|
||||||
|
|
||||||
binary.BigEndian.PutUint16(buf[24:26], source.Port())
|
|
||||||
binary.BigEndian.PutUint16(buf[26:28], 0)
|
|
||||||
|
|
||||||
copy(buf[28:44], s)
|
|
||||||
copy(buf[44:60], net.IPv6zero)
|
|
||||||
|
|
||||||
nativeEndian.PutUint32(buf[60:64], 0)
|
|
||||||
nativeEndian.PutUint64(buf[64:72], 0xFFFFFFFFFFFFFFFF)
|
|
||||||
|
|
||||||
return buf
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func unpackSocketDiagResponse(msg *syscall.NetlinkMessage) (inode, uid uint32) {
|
func (c *uidProcessPathCache) findProcessPath(targetInode, uid uint32) (string, error) {
|
||||||
if len(msg.Data) < 72 {
|
if cached, ok := c.cache.Get(uid); ok {
|
||||||
return 0, 0
|
if processPath, found := cached.entries[targetInode]; found {
|
||||||
|
return processPath, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
processPaths, err := buildProcessPathByUIDCache(uid)
|
||||||
data := msg.Data
|
|
||||||
|
|
||||||
uid = nativeEndian.Uint32(data[64:68])
|
|
||||||
inode = nativeEndian.Uint32(data[68:72])
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func resolveProcessNameByProcSearch(inode, uid uint32) (string, error) {
|
|
||||||
files, err := os.ReadDir(pathProc)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
c.cache.Add(uid, &uidProcessPaths{entries: processPaths})
|
||||||
|
processPath, found := processPaths[targetInode]
|
||||||
|
if !found {
|
||||||
|
return "", E.New("process of uid(", uid, "), inode(", targetInode, ") not found")
|
||||||
|
}
|
||||||
|
return processPath, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *socketDiagConn) Close() error {
|
||||||
|
c.access.Lock()
|
||||||
|
defer c.access.Unlock()
|
||||||
|
return c.closeLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *socketDiagConn) query(source netip.AddrPort, destination netip.AddrPort) (inode, uid uint32, err error) {
|
||||||
|
c.access.Lock()
|
||||||
|
defer c.access.Unlock()
|
||||||
|
request := packSocketDiagRequest(c.family, c.protocol, source, destination, false)
|
||||||
|
for attempt := 0; attempt < 2; attempt++ {
|
||||||
|
err = c.ensureOpenLocked()
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, E.Cause(err, "dial netlink")
|
||||||
|
}
|
||||||
|
inode, uid, err = querySocketDiag(c.fd, request)
|
||||||
|
if err == nil || errors.Is(err, ErrNotFound) {
|
||||||
|
return inode, uid, err
|
||||||
|
}
|
||||||
|
if !shouldRetrySocketDiag(err) {
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
_ = c.closeLocked()
|
||||||
|
}
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func querySocketDiagOnce(family, protocol uint8, source netip.AddrPort) (inode, uid uint32, err error) {
|
||||||
|
fd, err := openSocketDiag()
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, E.Cause(err, "dial netlink")
|
||||||
|
}
|
||||||
|
defer syscall.Close(fd)
|
||||||
|
return querySocketDiag(fd, packSocketDiagRequest(family, protocol, source, netip.AddrPort{}, true))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *socketDiagConn) ensureOpenLocked() error {
|
||||||
|
if c.fd != -1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
fd, err := openSocketDiag()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.fd = fd
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func openSocketDiag() (int, error) {
|
||||||
|
fd, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_DGRAM|syscall.SOCK_CLOEXEC, syscall.NETLINK_INET_DIAG)
|
||||||
|
if err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
timeout := &syscall.Timeval{Usec: 100}
|
||||||
|
if err = syscall.SetsockoptTimeval(fd, syscall.SOL_SOCKET, syscall.SO_SNDTIMEO, timeout); err != nil {
|
||||||
|
syscall.Close(fd)
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
if err = syscall.SetsockoptTimeval(fd, syscall.SOL_SOCKET, syscall.SO_RCVTIMEO, timeout); err != nil {
|
||||||
|
syscall.Close(fd)
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
if err = syscall.Connect(fd, &syscall.SockaddrNetlink{
|
||||||
|
Family: syscall.AF_NETLINK,
|
||||||
|
Pid: 0,
|
||||||
|
Groups: 0,
|
||||||
|
}); err != nil {
|
||||||
|
syscall.Close(fd)
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
return fd, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *socketDiagConn) closeLocked() error {
|
||||||
|
if c.fd == -1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
err := syscall.Close(c.fd)
|
||||||
|
c.fd = -1
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func packSocketDiagRequest(family, protocol byte, source netip.AddrPort, destination netip.AddrPort, dump bool) []byte {
|
||||||
|
request := make([]byte, sizeOfSocketDiagRequest)
|
||||||
|
|
||||||
|
binary.NativeEndian.PutUint32(request[0:4], sizeOfSocketDiagRequest)
|
||||||
|
binary.NativeEndian.PutUint16(request[4:6], socketDiagByFamily)
|
||||||
|
flags := uint16(syscall.NLM_F_REQUEST)
|
||||||
|
if dump {
|
||||||
|
flags |= syscall.NLM_F_DUMP
|
||||||
|
}
|
||||||
|
binary.NativeEndian.PutUint16(request[6:8], flags)
|
||||||
|
binary.NativeEndian.PutUint32(request[8:12], 0)
|
||||||
|
binary.NativeEndian.PutUint32(request[12:16], 0)
|
||||||
|
|
||||||
|
request[16] = family
|
||||||
|
request[17] = protocol
|
||||||
|
request[18] = 0
|
||||||
|
request[19] = 0
|
||||||
|
if dump {
|
||||||
|
binary.NativeEndian.PutUint32(request[20:24], 0xFFFFFFFF)
|
||||||
|
}
|
||||||
|
requestSource := source
|
||||||
|
requestDestination := destination
|
||||||
|
if protocol == syscall.IPPROTO_UDP && !dump && destination.IsValid() {
|
||||||
|
// udp_dump_one expects the exact-match endpoints reversed for historical reasons.
|
||||||
|
requestSource, requestDestination = destination, source
|
||||||
|
}
|
||||||
|
binary.BigEndian.PutUint16(request[24:26], requestSource.Port())
|
||||||
|
binary.BigEndian.PutUint16(request[26:28], requestDestination.Port())
|
||||||
|
if family == syscall.AF_INET6 {
|
||||||
|
copy(request[28:44], requestSource.Addr().AsSlice())
|
||||||
|
if requestDestination.IsValid() {
|
||||||
|
copy(request[44:60], requestDestination.Addr().AsSlice())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
copy(request[28:32], requestSource.Addr().AsSlice())
|
||||||
|
if requestDestination.IsValid() {
|
||||||
|
copy(request[44:48], requestDestination.Addr().AsSlice())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
binary.NativeEndian.PutUint32(request[60:64], 0)
|
||||||
|
binary.NativeEndian.PutUint64(request[64:72], 0xFFFFFFFFFFFFFFFF)
|
||||||
|
return request
|
||||||
|
}
|
||||||
|
|
||||||
|
func querySocketDiag(fd int, request []byte) (inode, uid uint32, err error) {
|
||||||
|
_, err = syscall.Write(fd, request)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, E.Cause(err, "write netlink request")
|
||||||
|
}
|
||||||
|
buffer := make([]byte, 64<<10)
|
||||||
|
n, err := syscall.Read(fd, buffer)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, E.Cause(err, "read netlink response")
|
||||||
|
}
|
||||||
|
messages, err := syscall.ParseNetlinkMessage(buffer[:n])
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, E.Cause(err, "parse netlink message")
|
||||||
|
}
|
||||||
|
return unpackSocketDiagMessages(messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
func unpackSocketDiagMessages(messages []syscall.NetlinkMessage) (inode, uid uint32, err error) {
|
||||||
|
for _, message := range messages {
|
||||||
|
switch message.Header.Type {
|
||||||
|
case syscall.NLMSG_DONE:
|
||||||
|
continue
|
||||||
|
case syscall.NLMSG_ERROR:
|
||||||
|
err = unpackSocketDiagError(&message)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
case socketDiagByFamily:
|
||||||
|
inode, uid = unpackSocketDiagResponse(&message)
|
||||||
|
if inode != 0 || uid != 0 {
|
||||||
|
return inode, uid, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, 0, ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func unpackSocketDiagResponse(msg *syscall.NetlinkMessage) (inode, uid uint32) {
|
||||||
|
if len(msg.Data) < socketDiagResponseMinSize {
|
||||||
|
return 0, 0
|
||||||
|
}
|
||||||
|
uid = binary.NativeEndian.Uint32(msg.Data[64:68])
|
||||||
|
inode = binary.NativeEndian.Uint32(msg.Data[68:72])
|
||||||
|
return inode, uid
|
||||||
|
}
|
||||||
|
|
||||||
|
func unpackSocketDiagError(msg *syscall.NetlinkMessage) error {
|
||||||
|
if len(msg.Data) < 4 {
|
||||||
|
return E.New("netlink message: NLMSG_ERROR")
|
||||||
|
}
|
||||||
|
errno := int32(binary.NativeEndian.Uint32(msg.Data[:4]))
|
||||||
|
if errno == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if errno < 0 {
|
||||||
|
errno = -errno
|
||||||
|
}
|
||||||
|
sysErr := syscall.Errno(errno)
|
||||||
|
switch sysErr {
|
||||||
|
case syscall.ENOENT, syscall.ESRCH:
|
||||||
|
return ErrNotFound
|
||||||
|
default:
|
||||||
|
return E.New("netlink message: ", sysErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldRetrySocketDiag(err error) bool {
|
||||||
|
return err != nil && !errors.Is(err, ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildProcessPathByUIDCache(uid uint32) (map[uint32]string, error) {
|
||||||
|
files, err := os.ReadDir(pathProc)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
buffer := make([]byte, syscall.PathMax)
|
buffer := make([]byte, syscall.PathMax)
|
||||||
socket := []byte(fmt.Sprintf("socket:[%d]", inode))
|
processPaths := make(map[uint32]string)
|
||||||
|
for _, file := range files {
|
||||||
for _, f := range files {
|
if !file.IsDir() || !isPid(file.Name()) {
|
||||||
if !f.IsDir() || !isPid(f.Name()) {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
info, err := file.Info()
|
||||||
info, err := f.Info()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
if isIgnorableProcError(err) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
if info.Sys().(*syscall.Stat_t).Uid != uid {
|
if info.Sys().(*syscall.Stat_t).Uid != uid {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
processPath := filepath.Join(pathProc, file.Name())
|
||||||
processPath := path.Join(pathProc, f.Name())
|
fdPath := filepath.Join(processPath, "fd")
|
||||||
fdPath := path.Join(processPath, "fd")
|
exePath, err := os.Readlink(filepath.Join(processPath, "exe"))
|
||||||
|
if err != nil {
|
||||||
|
if isIgnorableProcError(err) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
fds, err := os.ReadDir(fdPath)
|
fds, err := os.ReadDir(fdPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, fd := range fds {
|
for _, fd := range fds {
|
||||||
n, err := syscall.Readlink(path.Join(fdPath, fd.Name()), buffer)
|
n, err := syscall.Readlink(filepath.Join(fdPath, fd.Name()), buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
inode, ok := parseSocketInode(buffer[:n])
|
||||||
if bytes.Equal(buffer[:n], socket) {
|
if !ok {
|
||||||
return os.Readlink(path.Join(processPath, "exe"))
|
continue
|
||||||
|
}
|
||||||
|
if _, loaded := processPaths[inode]; !loaded {
|
||||||
|
processPaths[inode] = exePath
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return processPaths, nil
|
||||||
|
}
|
||||||
|
|
||||||
return "", fmt.Errorf("process of uid(%d),inode(%d) not found", uid, inode)
|
func isIgnorableProcError(err error) bool {
|
||||||
|
return os.IsNotExist(err) || os.IsPermission(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseSocketInode(link []byte) (uint32, bool) {
|
||||||
|
const socketPrefix = "socket:["
|
||||||
|
if len(link) <= len(socketPrefix) || string(link[:len(socketPrefix)]) != socketPrefix || link[len(link)-1] != ']' {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
var inode uint64
|
||||||
|
for _, char := range link[len(socketPrefix) : len(link)-1] {
|
||||||
|
if char < '0' || char > '9' {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
inode = inode*10 + uint64(char-'0')
|
||||||
|
if inode > uint64(^uint32(0)) {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return uint32(inode), true
|
||||||
}
|
}
|
||||||
|
|
||||||
func isPid(s string) bool {
|
func isPid(s string) bool {
|
||||||
|
|||||||
60
common/process/searcher_linux_shared_test.go
Normal file
60
common/process/searcher_linux_shared_test.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
//go:build linux
|
||||||
|
|
||||||
|
package process
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"syscall"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestQuerySocketDiagUDPExact(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
server, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client, err := net.DialUDP("udp4", nil, server.LocalAddr().(*net.UDPAddr))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
err = client.SetDeadline(time.Now().Add(time.Second))
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = client.Write([]byte{0})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = server.SetReadDeadline(time.Now().Add(time.Second))
|
||||||
|
require.NoError(t, err)
|
||||||
|
buffer := make([]byte, 1)
|
||||||
|
_, _, err = server.ReadFromUDP(buffer)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
source := addrPortFromUDPAddr(t, client.LocalAddr())
|
||||||
|
destination := addrPortFromUDPAddr(t, client.RemoteAddr())
|
||||||
|
|
||||||
|
fd, err := openSocketDiag()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer syscall.Close(fd)
|
||||||
|
|
||||||
|
inode, uid, err := querySocketDiag(fd, packSocketDiagRequest(syscall.AF_INET, syscall.IPPROTO_UDP, source, destination, false))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotZero(t, inode)
|
||||||
|
require.EqualValues(t, os.Getuid(), uid)
|
||||||
|
}
|
||||||
|
|
||||||
|
func addrPortFromUDPAddr(t *testing.T, addr net.Addr) netip.AddrPort {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
udpAddr, ok := addr.(*net.UDPAddr)
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
ip, ok := netip.AddrFromSlice(udpAddr.IP)
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
return netip.AddrPortFrom(ip.Unmap(), uint16(udpAddr.Port))
|
||||||
|
}
|
||||||
@@ -28,6 +28,10 @@ func initWin32API() error {
|
|||||||
return winiphlpapi.LoadExtendedTable()
|
return winiphlpapi.LoadExtendedTable()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *windowsSearcher) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *windowsSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*adapter.ConnectionOwner, error) {
|
func (s *windowsSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*adapter.ConnectionOwner, error) {
|
||||||
pid, err := winiphlpapi.FindPid(network, source)
|
pid, err := winiphlpapi.FindPid(network, source)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
115
common/proxybridge/bridge.go
Normal file
115
common/proxybridge/bridge.go
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
package proxybridge
|
||||||
|
|
||||||
|
import (
|
||||||
|
std_bufio "bufio"
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/log"
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
"github.com/sagernet/sing/common/auth"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
"github.com/sagernet/sing/common/logger"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
"github.com/sagernet/sing/protocol/socks"
|
||||||
|
"github.com/sagernet/sing/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Bridge struct {
|
||||||
|
ctx context.Context
|
||||||
|
logger logger.ContextLogger
|
||||||
|
tag string
|
||||||
|
dialer N.Dialer
|
||||||
|
connection adapter.ConnectionManager
|
||||||
|
tcpListener *net.TCPListener
|
||||||
|
username string
|
||||||
|
password string
|
||||||
|
authenticator *auth.Authenticator
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(ctx context.Context, logger logger.ContextLogger, tag string, dialer N.Dialer) (*Bridge, error) {
|
||||||
|
username := randomHex(16)
|
||||||
|
password := randomHex(16)
|
||||||
|
tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1)})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
bridge := &Bridge{
|
||||||
|
ctx: ctx,
|
||||||
|
logger: logger,
|
||||||
|
tag: tag,
|
||||||
|
dialer: dialer,
|
||||||
|
connection: service.FromContext[adapter.ConnectionManager](ctx),
|
||||||
|
tcpListener: tcpListener,
|
||||||
|
username: username,
|
||||||
|
password: password,
|
||||||
|
authenticator: auth.NewAuthenticator([]auth.User{{Username: username, Password: password}}),
|
||||||
|
}
|
||||||
|
go bridge.acceptLoop()
|
||||||
|
return bridge, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func randomHex(size int) string {
|
||||||
|
raw := make([]byte, size)
|
||||||
|
rand.Read(raw)
|
||||||
|
return hex.EncodeToString(raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bridge) Port() uint16 {
|
||||||
|
return M.SocksaddrFromNet(b.tcpListener.Addr()).Port
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bridge) Username() string {
|
||||||
|
return b.username
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bridge) Password() string {
|
||||||
|
return b.password
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bridge) Close() error {
|
||||||
|
return common.Close(b.tcpListener)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bridge) acceptLoop() {
|
||||||
|
for {
|
||||||
|
tcpConn, err := b.tcpListener.AcceptTCP()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctx := log.ContextWithNewID(b.ctx)
|
||||||
|
go func() {
|
||||||
|
hErr := socks.HandleConnectionEx(ctx, tcpConn, std_bufio.NewReader(tcpConn), b.authenticator, b, nil, 0, M.SocksaddrFromNet(tcpConn.RemoteAddr()), nil)
|
||||||
|
if hErr == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if E.IsClosedOrCanceled(hErr) {
|
||||||
|
b.logger.DebugContext(ctx, E.Cause(hErr, b.tag, " connection closed"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.logger.ErrorContext(ctx, E.Cause(hErr, b.tag))
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bridge) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
|
||||||
|
var metadata adapter.InboundContext
|
||||||
|
metadata.Source = source
|
||||||
|
metadata.Destination = destination
|
||||||
|
metadata.Network = N.NetworkTCP
|
||||||
|
b.logger.InfoContext(ctx, b.tag, " connection to ", metadata.Destination)
|
||||||
|
b.connection.NewConnection(ctx, b.dialer, conn, metadata, onClose)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bridge) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
|
||||||
|
var metadata adapter.InboundContext
|
||||||
|
metadata.Source = source
|
||||||
|
metadata.Destination = destination
|
||||||
|
metadata.Network = N.NetworkUDP
|
||||||
|
b.logger.InfoContext(ctx, b.tag, " packet connection to ", metadata.Destination)
|
||||||
|
b.connection.NewPacketConnection(ctx, b.dialer, conn, metadata, onClose)
|
||||||
|
}
|
||||||
@@ -46,6 +46,7 @@ const (
|
|||||||
ruleItemNetworkIsConstrained
|
ruleItemNetworkIsConstrained
|
||||||
ruleItemNetworkInterfaceAddress
|
ruleItemNetworkInterfaceAddress
|
||||||
ruleItemDefaultInterfaceAddress
|
ruleItemDefaultInterfaceAddress
|
||||||
|
ruleItemPackageNameRegex
|
||||||
ruleItemFinal uint8 = 0xFF
|
ruleItemFinal uint8 = 0xFF
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -215,6 +216,8 @@ func readDefaultRule(reader varbin.Reader, recover bool) (rule option.DefaultHea
|
|||||||
rule.ProcessPathRegex, err = readRuleItemString(reader)
|
rule.ProcessPathRegex, err = readRuleItemString(reader)
|
||||||
case ruleItemPackageName:
|
case ruleItemPackageName:
|
||||||
rule.PackageName, err = readRuleItemString(reader)
|
rule.PackageName, err = readRuleItemString(reader)
|
||||||
|
case ruleItemPackageNameRegex:
|
||||||
|
rule.PackageNameRegex, err = readRuleItemString(reader)
|
||||||
case ruleItemWIFISSID:
|
case ruleItemWIFISSID:
|
||||||
rule.WIFISSID, err = readRuleItemString(reader)
|
rule.WIFISSID, err = readRuleItemString(reader)
|
||||||
case ruleItemWIFIBSSID:
|
case ruleItemWIFIBSSID:
|
||||||
@@ -394,6 +397,15 @@ func writeDefaultRule(writer varbin.Writer, rule option.DefaultHeadlessRule, gen
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if len(rule.PackageNameRegex) > 0 {
|
||||||
|
if generateVersion < C.RuleSetVersion5 {
|
||||||
|
return E.New("`package_name_regex` rule item is only supported in version 5 or later")
|
||||||
|
}
|
||||||
|
err = writeRuleItemString(writer, ruleItemPackageNameRegex, rule.PackageNameRegex)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
if len(rule.NetworkType) > 0 {
|
if len(rule.NetworkType) > 0 {
|
||||||
if generateVersion < C.RuleSetVersion3 {
|
if generateVersion < C.RuleSetVersion3 {
|
||||||
return E.New("`network_type` rule item is only supported in version 3 or later")
|
return E.New("`network_type` rule item is only supported in version 3 or later")
|
||||||
|
|||||||
612
common/stun/stun.go
Normal file
612
common/stun/stun.go
Normal file
@@ -0,0 +1,612 @@
|
|||||||
|
package stun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing/common/bufio"
|
||||||
|
"github.com/sagernet/sing/common/bufio/deadline"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DefaultServer = "stun.voipgate.com:3478"
|
||||||
|
|
||||||
|
magicCookie = 0x2112A442
|
||||||
|
headerSize = 20
|
||||||
|
|
||||||
|
bindingRequest = 0x0001
|
||||||
|
bindingSuccessResponse = 0x0101
|
||||||
|
bindingErrorResponse = 0x0111
|
||||||
|
|
||||||
|
attrMappedAddress = 0x0001
|
||||||
|
attrChangeRequest = 0x0003
|
||||||
|
attrErrorCode = 0x0009
|
||||||
|
attrXORMappedAddress = 0x0020
|
||||||
|
attrOtherAddress = 0x802c
|
||||||
|
|
||||||
|
familyIPv4 = 0x01
|
||||||
|
familyIPv6 = 0x02
|
||||||
|
|
||||||
|
changeIP = 0x04
|
||||||
|
changePort = 0x02
|
||||||
|
|
||||||
|
defaultRTO = 500 * time.Millisecond
|
||||||
|
minRTO = 250 * time.Millisecond
|
||||||
|
maxRetransmit = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
type Phase int32
|
||||||
|
|
||||||
|
const (
|
||||||
|
PhaseBinding Phase = iota
|
||||||
|
PhaseNATMapping
|
||||||
|
PhaseNATFiltering
|
||||||
|
PhaseDone
|
||||||
|
)
|
||||||
|
|
||||||
|
type NATMapping int32
|
||||||
|
|
||||||
|
const (
|
||||||
|
NATMappingUnknown NATMapping = iota
|
||||||
|
_ // reserved
|
||||||
|
NATMappingEndpointIndependent
|
||||||
|
NATMappingAddressDependent
|
||||||
|
NATMappingAddressAndPortDependent
|
||||||
|
)
|
||||||
|
|
||||||
|
func (m NATMapping) String() string {
|
||||||
|
switch m {
|
||||||
|
case NATMappingEndpointIndependent:
|
||||||
|
return "Endpoint Independent"
|
||||||
|
case NATMappingAddressDependent:
|
||||||
|
return "Address Dependent"
|
||||||
|
case NATMappingAddressAndPortDependent:
|
||||||
|
return "Address and Port Dependent"
|
||||||
|
default:
|
||||||
|
return "Unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type NATFiltering int32
|
||||||
|
|
||||||
|
const (
|
||||||
|
NATFilteringUnknown NATFiltering = iota
|
||||||
|
NATFilteringEndpointIndependent
|
||||||
|
NATFilteringAddressDependent
|
||||||
|
NATFilteringAddressAndPortDependent
|
||||||
|
)
|
||||||
|
|
||||||
|
func (f NATFiltering) String() string {
|
||||||
|
switch f {
|
||||||
|
case NATFilteringEndpointIndependent:
|
||||||
|
return "Endpoint Independent"
|
||||||
|
case NATFilteringAddressDependent:
|
||||||
|
return "Address Dependent"
|
||||||
|
case NATFilteringAddressAndPortDependent:
|
||||||
|
return "Address and Port Dependent"
|
||||||
|
default:
|
||||||
|
return "Unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type TransactionID [12]byte
|
||||||
|
|
||||||
|
type Options struct {
|
||||||
|
Server string
|
||||||
|
Dialer N.Dialer
|
||||||
|
Context context.Context
|
||||||
|
OnProgress func(Progress)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Progress struct {
|
||||||
|
Phase Phase
|
||||||
|
ExternalAddr string
|
||||||
|
LatencyMs int32
|
||||||
|
NATMapping NATMapping
|
||||||
|
NATFiltering NATFiltering
|
||||||
|
}
|
||||||
|
|
||||||
|
type Result struct {
|
||||||
|
ExternalAddr string
|
||||||
|
LatencyMs int32
|
||||||
|
NATMapping NATMapping
|
||||||
|
NATFiltering NATFiltering
|
||||||
|
NATTypeSupported bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type parsedResponse struct {
|
||||||
|
xorMappedAddr netip.AddrPort
|
||||||
|
mappedAddr netip.AddrPort
|
||||||
|
otherAddr netip.AddrPort
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *parsedResponse) externalAddr() (netip.AddrPort, bool) {
|
||||||
|
if r.xorMappedAddr.IsValid() {
|
||||||
|
return r.xorMappedAddr, true
|
||||||
|
}
|
||||||
|
if r.mappedAddr.IsValid() {
|
||||||
|
return r.mappedAddr, true
|
||||||
|
}
|
||||||
|
return netip.AddrPort{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
type stunAttribute struct {
|
||||||
|
typ uint16
|
||||||
|
value []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTransactionID() TransactionID {
|
||||||
|
var id TransactionID
|
||||||
|
_, _ = rand.Read(id[:])
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildBindingRequest(txID TransactionID, attrs ...stunAttribute) []byte {
|
||||||
|
attrLen := 0
|
||||||
|
for _, attr := range attrs {
|
||||||
|
attrLen += 4 + len(attr.value) + paddingLen(len(attr.value))
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, headerSize+attrLen)
|
||||||
|
binary.BigEndian.PutUint16(buf[0:2], bindingRequest)
|
||||||
|
binary.BigEndian.PutUint16(buf[2:4], uint16(attrLen))
|
||||||
|
binary.BigEndian.PutUint32(buf[4:8], magicCookie)
|
||||||
|
copy(buf[8:20], txID[:])
|
||||||
|
|
||||||
|
offset := headerSize
|
||||||
|
for _, attr := range attrs {
|
||||||
|
binary.BigEndian.PutUint16(buf[offset:offset+2], attr.typ)
|
||||||
|
binary.BigEndian.PutUint16(buf[offset+2:offset+4], uint16(len(attr.value)))
|
||||||
|
copy(buf[offset+4:offset+4+len(attr.value)], attr.value)
|
||||||
|
offset += 4 + len(attr.value) + paddingLen(len(attr.value))
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func changeRequestAttr(flags byte) stunAttribute {
|
||||||
|
return stunAttribute{
|
||||||
|
typ: attrChangeRequest,
|
||||||
|
value: []byte{0, 0, 0, flags},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseResponse(data []byte, expectedTxID TransactionID) (*parsedResponse, error) {
|
||||||
|
if len(data) < headerSize {
|
||||||
|
return nil, E.New("response too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
msgType := binary.BigEndian.Uint16(data[0:2])
|
||||||
|
if msgType&0xC000 != 0 {
|
||||||
|
return nil, E.New("invalid STUN message: top 2 bits not zero")
|
||||||
|
}
|
||||||
|
|
||||||
|
cookie := binary.BigEndian.Uint32(data[4:8])
|
||||||
|
if cookie != magicCookie {
|
||||||
|
return nil, E.New("invalid magic cookie")
|
||||||
|
}
|
||||||
|
|
||||||
|
var txID TransactionID
|
||||||
|
copy(txID[:], data[8:20])
|
||||||
|
if txID != expectedTxID {
|
||||||
|
return nil, E.New("transaction ID mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
msgLen := int(binary.BigEndian.Uint16(data[2:4]))
|
||||||
|
if msgLen > len(data)-headerSize {
|
||||||
|
return nil, E.New("message length exceeds data")
|
||||||
|
}
|
||||||
|
|
||||||
|
attrData := data[headerSize : headerSize+msgLen]
|
||||||
|
|
||||||
|
if msgType == bindingErrorResponse {
|
||||||
|
return nil, parseErrorResponse(attrData)
|
||||||
|
}
|
||||||
|
if msgType != bindingSuccessResponse {
|
||||||
|
return nil, E.New("unexpected message type: ", fmt.Sprintf("0x%04x", msgType))
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &parsedResponse{}
|
||||||
|
offset := 0
|
||||||
|
for offset+4 <= len(attrData) {
|
||||||
|
attrType := binary.BigEndian.Uint16(attrData[offset : offset+2])
|
||||||
|
attrLen := int(binary.BigEndian.Uint16(attrData[offset+2 : offset+4]))
|
||||||
|
if offset+4+attrLen > len(attrData) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
attrValue := attrData[offset+4 : offset+4+attrLen]
|
||||||
|
|
||||||
|
switch attrType {
|
||||||
|
case attrXORMappedAddress:
|
||||||
|
addr, err := parseXORMappedAddress(attrValue, txID)
|
||||||
|
if err == nil {
|
||||||
|
resp.xorMappedAddr = addr
|
||||||
|
}
|
||||||
|
case attrMappedAddress:
|
||||||
|
addr, err := parseMappedAddress(attrValue)
|
||||||
|
if err == nil {
|
||||||
|
resp.mappedAddr = addr
|
||||||
|
}
|
||||||
|
case attrOtherAddress:
|
||||||
|
addr, err := parseMappedAddress(attrValue)
|
||||||
|
if err == nil {
|
||||||
|
resp.otherAddr = addr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
offset += 4 + attrLen + paddingLen(attrLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseErrorResponse(data []byte) error {
|
||||||
|
offset := 0
|
||||||
|
for offset+4 <= len(data) {
|
||||||
|
attrType := binary.BigEndian.Uint16(data[offset : offset+2])
|
||||||
|
attrLen := int(binary.BigEndian.Uint16(data[offset+2 : offset+4]))
|
||||||
|
if offset+4+attrLen > len(data) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if attrType == attrErrorCode && attrLen >= 4 {
|
||||||
|
attrValue := data[offset+4 : offset+4+attrLen]
|
||||||
|
class := int(attrValue[2] & 0x07)
|
||||||
|
number := int(attrValue[3])
|
||||||
|
code := class*100 + number
|
||||||
|
if attrLen > 4 {
|
||||||
|
return E.New("STUN error ", code, ": ", string(attrValue[4:]))
|
||||||
|
}
|
||||||
|
return E.New("STUN error ", code)
|
||||||
|
}
|
||||||
|
offset += 4 + attrLen + paddingLen(attrLen)
|
||||||
|
}
|
||||||
|
return E.New("STUN error response")
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseXORMappedAddress(data []byte, txID TransactionID) (netip.AddrPort, error) {
|
||||||
|
if len(data) < 4 {
|
||||||
|
return netip.AddrPort{}, E.New("XOR-MAPPED-ADDRESS too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
family := data[1]
|
||||||
|
xPort := binary.BigEndian.Uint16(data[2:4])
|
||||||
|
port := xPort ^ uint16(magicCookie>>16)
|
||||||
|
|
||||||
|
switch family {
|
||||||
|
case familyIPv4:
|
||||||
|
if len(data) < 8 {
|
||||||
|
return netip.AddrPort{}, E.New("XOR-MAPPED-ADDRESS IPv4 too short")
|
||||||
|
}
|
||||||
|
var ip [4]byte
|
||||||
|
binary.BigEndian.PutUint32(ip[:], binary.BigEndian.Uint32(data[4:8])^magicCookie)
|
||||||
|
return netip.AddrPortFrom(netip.AddrFrom4(ip), port), nil
|
||||||
|
case familyIPv6:
|
||||||
|
if len(data) < 20 {
|
||||||
|
return netip.AddrPort{}, E.New("XOR-MAPPED-ADDRESS IPv6 too short")
|
||||||
|
}
|
||||||
|
var ip [16]byte
|
||||||
|
var xorKey [16]byte
|
||||||
|
binary.BigEndian.PutUint32(xorKey[0:4], magicCookie)
|
||||||
|
copy(xorKey[4:16], txID[:])
|
||||||
|
for i := range 16 {
|
||||||
|
ip[i] = data[4+i] ^ xorKey[i]
|
||||||
|
}
|
||||||
|
return netip.AddrPortFrom(netip.AddrFrom16(ip), port), nil
|
||||||
|
default:
|
||||||
|
return netip.AddrPort{}, E.New("unknown address family: ", family)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseMappedAddress(data []byte) (netip.AddrPort, error) {
|
||||||
|
if len(data) < 4 {
|
||||||
|
return netip.AddrPort{}, E.New("MAPPED-ADDRESS too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
family := data[1]
|
||||||
|
port := binary.BigEndian.Uint16(data[2:4])
|
||||||
|
|
||||||
|
switch family {
|
||||||
|
case familyIPv4:
|
||||||
|
if len(data) < 8 {
|
||||||
|
return netip.AddrPort{}, E.New("MAPPED-ADDRESS IPv4 too short")
|
||||||
|
}
|
||||||
|
return netip.AddrPortFrom(
|
||||||
|
netip.AddrFrom4([4]byte{data[4], data[5], data[6], data[7]}), port,
|
||||||
|
), nil
|
||||||
|
case familyIPv6:
|
||||||
|
if len(data) < 20 {
|
||||||
|
return netip.AddrPort{}, E.New("MAPPED-ADDRESS IPv6 too short")
|
||||||
|
}
|
||||||
|
var ip [16]byte
|
||||||
|
copy(ip[:], data[4:20])
|
||||||
|
return netip.AddrPortFrom(netip.AddrFrom16(ip), port), nil
|
||||||
|
default:
|
||||||
|
return netip.AddrPort{}, E.New("unknown address family: ", family)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func roundTrip(conn net.PacketConn, addr net.Addr, txID TransactionID, attrs []stunAttribute, rto time.Duration) (*parsedResponse, time.Duration, error) {
|
||||||
|
request := buildBindingRequest(txID, attrs...)
|
||||||
|
currentRTO := rto
|
||||||
|
retransmitCount := 0
|
||||||
|
|
||||||
|
sendTime := time.Now()
|
||||||
|
_, err := conn.WriteTo(request, addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, E.Cause(err, "send STUN request")
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
for {
|
||||||
|
err = conn.SetReadDeadline(sendTime.Add(currentRTO))
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, E.Cause(err, "set read deadline")
|
||||||
|
}
|
||||||
|
|
||||||
|
n, _, readErr := conn.ReadFrom(buf)
|
||||||
|
if readErr != nil {
|
||||||
|
if E.IsTimeout(readErr) && retransmitCount < maxRetransmit {
|
||||||
|
retransmitCount++
|
||||||
|
currentRTO *= 2
|
||||||
|
sendTime = time.Now()
|
||||||
|
_, err = conn.WriteTo(request, addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, E.Cause(err, "retransmit STUN request")
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return nil, 0, E.Cause(readErr, "read STUN response")
|
||||||
|
}
|
||||||
|
|
||||||
|
if n < headerSize || buf[0]&0xC0 != 0 ||
|
||||||
|
binary.BigEndian.Uint32(buf[4:8]) != magicCookie {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var receivedTxID TransactionID
|
||||||
|
copy(receivedTxID[:], buf[8:20])
|
||||||
|
if receivedTxID != txID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
latency := time.Since(sendTime)
|
||||||
|
|
||||||
|
resp, parseErr := parseResponse(buf[:n], txID)
|
||||||
|
if parseErr != nil {
|
||||||
|
return nil, 0, parseErr
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, latency, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Run(options Options) (*Result, error) {
|
||||||
|
ctx := options.Context
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
server := options.Server
|
||||||
|
if server == "" {
|
||||||
|
server = DefaultServer
|
||||||
|
}
|
||||||
|
serverSocksaddr := M.ParseSocksaddr(server)
|
||||||
|
if serverSocksaddr.Port == 0 {
|
||||||
|
serverSocksaddr.Port = 3478
|
||||||
|
}
|
||||||
|
|
||||||
|
reportProgress := options.OnProgress
|
||||||
|
if reportProgress == nil {
|
||||||
|
reportProgress = func(Progress) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
packetConn net.PacketConn
|
||||||
|
serverAddr net.Addr
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
if options.Dialer != nil {
|
||||||
|
packetConn, err = options.Dialer.ListenPacket(ctx, serverSocksaddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "create UDP socket")
|
||||||
|
}
|
||||||
|
serverAddr = serverSocksaddr
|
||||||
|
} else {
|
||||||
|
serverUDPAddr, resolveErr := net.ResolveUDPAddr("udp", serverSocksaddr.String())
|
||||||
|
if resolveErr != nil {
|
||||||
|
return nil, E.Cause(resolveErr, "resolve STUN server")
|
||||||
|
}
|
||||||
|
packetConn, err = net.ListenPacket("udp", "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "create UDP socket")
|
||||||
|
}
|
||||||
|
serverAddr = serverUDPAddr
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = packetConn.Close()
|
||||||
|
}()
|
||||||
|
if deadline.NeedAdditionalReadDeadline(packetConn) {
|
||||||
|
packetConn = deadline.NewPacketConn(bufio.NewPacketConn(packetConn))
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
rto := defaultRTO
|
||||||
|
|
||||||
|
// Phase 1: Binding
|
||||||
|
reportProgress(Progress{Phase: PhaseBinding})
|
||||||
|
|
||||||
|
txID := newTransactionID()
|
||||||
|
resp, latency, err := roundTrip(packetConn, serverAddr, txID, nil, rto)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "binding request")
|
||||||
|
}
|
||||||
|
|
||||||
|
rto = max(minRTO, 3*latency)
|
||||||
|
|
||||||
|
externalAddr, ok := resp.externalAddr()
|
||||||
|
if !ok {
|
||||||
|
return nil, E.New("no mapped address in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
result := &Result{
|
||||||
|
ExternalAddr: externalAddr.String(),
|
||||||
|
LatencyMs: int32(latency.Milliseconds()),
|
||||||
|
}
|
||||||
|
|
||||||
|
reportProgress(Progress{
|
||||||
|
Phase: PhaseBinding,
|
||||||
|
ExternalAddr: result.ExternalAddr,
|
||||||
|
LatencyMs: result.LatencyMs,
|
||||||
|
})
|
||||||
|
|
||||||
|
otherAddr := resp.otherAddr
|
||||||
|
if !otherAddr.IsValid() {
|
||||||
|
result.NATTypeSupported = false
|
||||||
|
reportProgress(Progress{
|
||||||
|
Phase: PhaseDone,
|
||||||
|
ExternalAddr: result.ExternalAddr,
|
||||||
|
LatencyMs: result.LatencyMs,
|
||||||
|
})
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
result.NATTypeSupported = true
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return result, nil
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 2: NAT Mapping Detection (RFC 5780 Section 4.3)
|
||||||
|
reportProgress(Progress{
|
||||||
|
Phase: PhaseNATMapping,
|
||||||
|
ExternalAddr: result.ExternalAddr,
|
||||||
|
LatencyMs: result.LatencyMs,
|
||||||
|
})
|
||||||
|
|
||||||
|
result.NATMapping = detectNATMapping(
|
||||||
|
packetConn, serverSocksaddr.Port, externalAddr, otherAddr, rto,
|
||||||
|
)
|
||||||
|
|
||||||
|
reportProgress(Progress{
|
||||||
|
Phase: PhaseNATMapping,
|
||||||
|
ExternalAddr: result.ExternalAddr,
|
||||||
|
LatencyMs: result.LatencyMs,
|
||||||
|
NATMapping: result.NATMapping,
|
||||||
|
})
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return result, nil
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 3: NAT Filtering Detection (RFC 5780 Section 4.4)
|
||||||
|
reportProgress(Progress{
|
||||||
|
Phase: PhaseNATFiltering,
|
||||||
|
ExternalAddr: result.ExternalAddr,
|
||||||
|
LatencyMs: result.LatencyMs,
|
||||||
|
NATMapping: result.NATMapping,
|
||||||
|
})
|
||||||
|
|
||||||
|
result.NATFiltering = detectNATFiltering(packetConn, serverAddr, rto)
|
||||||
|
|
||||||
|
reportProgress(Progress{
|
||||||
|
Phase: PhaseDone,
|
||||||
|
ExternalAddr: result.ExternalAddr,
|
||||||
|
LatencyMs: result.LatencyMs,
|
||||||
|
NATMapping: result.NATMapping,
|
||||||
|
NATFiltering: result.NATFiltering,
|
||||||
|
})
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func detectNATMapping(
|
||||||
|
conn net.PacketConn,
|
||||||
|
serverPort uint16,
|
||||||
|
externalAddr netip.AddrPort,
|
||||||
|
otherAddr netip.AddrPort,
|
||||||
|
rto time.Duration,
|
||||||
|
) NATMapping {
|
||||||
|
// Mapping Test II: Send to other_ip:server_port
|
||||||
|
testIIAddr := net.UDPAddrFromAddrPort(
|
||||||
|
netip.AddrPortFrom(otherAddr.Addr(), serverPort),
|
||||||
|
)
|
||||||
|
txID2 := newTransactionID()
|
||||||
|
resp2, _, err := roundTrip(conn, testIIAddr, txID2, nil, rto)
|
||||||
|
if err != nil {
|
||||||
|
return NATMappingUnknown
|
||||||
|
}
|
||||||
|
|
||||||
|
externalAddr2, ok := resp2.externalAddr()
|
||||||
|
if !ok {
|
||||||
|
return NATMappingUnknown
|
||||||
|
}
|
||||||
|
|
||||||
|
if externalAddr == externalAddr2 {
|
||||||
|
return NATMappingEndpointIndependent
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mapping Test III: Send to other_ip:other_port
|
||||||
|
testIIIAddr := net.UDPAddrFromAddrPort(otherAddr)
|
||||||
|
txID3 := newTransactionID()
|
||||||
|
resp3, _, err := roundTrip(conn, testIIIAddr, txID3, nil, rto)
|
||||||
|
if err != nil {
|
||||||
|
return NATMappingUnknown
|
||||||
|
}
|
||||||
|
|
||||||
|
externalAddr3, ok := resp3.externalAddr()
|
||||||
|
if !ok {
|
||||||
|
return NATMappingUnknown
|
||||||
|
}
|
||||||
|
|
||||||
|
if externalAddr2 == externalAddr3 {
|
||||||
|
return NATMappingAddressDependent
|
||||||
|
}
|
||||||
|
return NATMappingAddressAndPortDependent
|
||||||
|
}
|
||||||
|
|
||||||
|
func detectNATFiltering(
|
||||||
|
conn net.PacketConn,
|
||||||
|
serverAddr net.Addr,
|
||||||
|
rto time.Duration,
|
||||||
|
) NATFiltering {
|
||||||
|
// Filtering Test II: Request response from different IP and port
|
||||||
|
txID := newTransactionID()
|
||||||
|
_, _, err := roundTrip(conn, serverAddr, txID,
|
||||||
|
[]stunAttribute{changeRequestAttr(changeIP | changePort)}, rto)
|
||||||
|
if err == nil {
|
||||||
|
return NATFilteringEndpointIndependent
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filtering Test III: Request response from different port only
|
||||||
|
txID = newTransactionID()
|
||||||
|
_, _, err = roundTrip(conn, serverAddr, txID,
|
||||||
|
[]stunAttribute{changeRequestAttr(changePort)}, rto)
|
||||||
|
if err == nil {
|
||||||
|
return NATFilteringAddressDependent
|
||||||
|
}
|
||||||
|
|
||||||
|
return NATFilteringAddressAndPortDependent
|
||||||
|
}
|
||||||
|
|
||||||
|
func paddingLen(n int) int {
|
||||||
|
if n%4 == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return 4 - n%4
|
||||||
|
}
|
||||||
@@ -38,37 +38,6 @@ func (w *acmeWrapper) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type acmeLogWriter struct {
|
|
||||||
logger logger.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *acmeLogWriter) Write(p []byte) (n int, err error) {
|
|
||||||
logLine := strings.ReplaceAll(string(p), " ", ": ")
|
|
||||||
switch {
|
|
||||||
case strings.HasPrefix(logLine, "error: "):
|
|
||||||
w.logger.Error(logLine[7:])
|
|
||||||
case strings.HasPrefix(logLine, "warn: "):
|
|
||||||
w.logger.Warn(logLine[6:])
|
|
||||||
case strings.HasPrefix(logLine, "info: "):
|
|
||||||
w.logger.Info(logLine[6:])
|
|
||||||
case strings.HasPrefix(logLine, "debug: "):
|
|
||||||
w.logger.Debug(logLine[7:])
|
|
||||||
default:
|
|
||||||
w.logger.Debug(logLine)
|
|
||||||
}
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *acmeLogWriter) Sync() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func encoderConfig() zapcore.EncoderConfig {
|
|
||||||
config := zap.NewProductionEncoderConfig()
|
|
||||||
config.TimeKey = zapcore.OmitKey
|
|
||||||
return config
|
|
||||||
}
|
|
||||||
|
|
||||||
func startACME(ctx context.Context, logger logger.Logger, options option.InboundACMEOptions) (*tls.Config, adapter.SimpleLifecycle, error) {
|
func startACME(ctx context.Context, logger logger.Logger, options option.InboundACMEOptions) (*tls.Config, adapter.SimpleLifecycle, error) {
|
||||||
var acmeServer string
|
var acmeServer string
|
||||||
switch options.Provider {
|
switch options.Provider {
|
||||||
@@ -91,8 +60,8 @@ func startACME(ctx context.Context, logger logger.Logger, options option.Inbound
|
|||||||
storage = certmagic.Default.Storage
|
storage = certmagic.Default.Storage
|
||||||
}
|
}
|
||||||
zapLogger := zap.New(zapcore.NewCore(
|
zapLogger := zap.New(zapcore.NewCore(
|
||||||
zapcore.NewConsoleEncoder(encoderConfig()),
|
zapcore.NewConsoleEncoder(ACMEEncoderConfig()),
|
||||||
&acmeLogWriter{logger: logger},
|
&ACMELogWriter{Logger: logger},
|
||||||
zap.DebugLevel,
|
zap.DebugLevel,
|
||||||
))
|
))
|
||||||
config := &certmagic.Config{
|
config := &certmagic.Config{
|
||||||
@@ -158,7 +127,7 @@ func startACME(ctx context.Context, logger logger.Logger, options option.Inbound
|
|||||||
} else {
|
} else {
|
||||||
tlsConfig = &tls.Config{
|
tlsConfig = &tls.Config{
|
||||||
GetCertificate: config.GetCertificate,
|
GetCertificate: config.GetCertificate,
|
||||||
NextProtos: []string{ACMETLS1Protocol},
|
NextProtos: []string{C.ACMETLS1Protocol},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return tlsConfig, &acmeWrapper{ctx: ctx, cfg: config, cache: cache, domain: options.Domain}, nil
|
return tlsConfig, &acmeWrapper{ctx: ctx, cfg: config, cache: cache, domain: options.Domain}, nil
|
||||||
|
|||||||
@@ -1,3 +0,0 @@
|
|||||||
package tls
|
|
||||||
|
|
||||||
const ACMETLS1Protocol = "acme-tls/1"
|
|
||||||
41
common/tls/acme_logger.go
Normal file
41
common/tls/acme_logger.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
package tls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing/common/logger"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"go.uber.org/zap/zapcore"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ACMELogWriter struct {
|
||||||
|
Logger logger.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ACMELogWriter) Write(p []byte) (n int, err error) {
|
||||||
|
logLine := strings.ReplaceAll(string(p), " ", ": ")
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(logLine, "error: "):
|
||||||
|
w.Logger.Error(logLine[7:])
|
||||||
|
case strings.HasPrefix(logLine, "warn: "):
|
||||||
|
w.Logger.Warn(logLine[6:])
|
||||||
|
case strings.HasPrefix(logLine, "info: "):
|
||||||
|
w.Logger.Info(logLine[6:])
|
||||||
|
case strings.HasPrefix(logLine, "debug: "):
|
||||||
|
w.Logger.Debug(logLine[7:])
|
||||||
|
default:
|
||||||
|
w.Logger.Debug(logLine)
|
||||||
|
}
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ACMELogWriter) Sync() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ACMEEncoderConfig() zapcore.EncoderConfig {
|
||||||
|
config := zap.NewProductionEncoderConfig()
|
||||||
|
config.TimeKey = zapcore.OmitKey
|
||||||
|
return config
|
||||||
|
}
|
||||||
214
common/tls/apple_client.go
Normal file
214
common/tls/apple_client.go
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
//go:build darwin && cgo
|
||||||
|
|
||||||
|
package tls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
boxConstant "github.com/sagernet/sing-box/constant"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
"github.com/sagernet/sing/common/logger"
|
||||||
|
"github.com/sagernet/sing/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
type appleCertificateStore interface {
|
||||||
|
StoreKind() string
|
||||||
|
CurrentPEM() []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type appleClientConfig struct {
|
||||||
|
serverName string
|
||||||
|
nextProtos []string
|
||||||
|
handshakeTimeout time.Duration
|
||||||
|
minVersion uint16
|
||||||
|
maxVersion uint16
|
||||||
|
insecure bool
|
||||||
|
anchorPEM string
|
||||||
|
anchorOnly bool
|
||||||
|
certificatePublicKeySHA256 [][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleClientConfig) ServerName() string {
|
||||||
|
return c.serverName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleClientConfig) SetServerName(serverName string) {
|
||||||
|
c.serverName = serverName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleClientConfig) NextProtos() []string {
|
||||||
|
return c.nextProtos
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleClientConfig) SetNextProtos(nextProto []string) {
|
||||||
|
c.nextProtos = append(c.nextProtos[:0], nextProto...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleClientConfig) HandshakeTimeout() time.Duration {
|
||||||
|
return c.handshakeTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleClientConfig) SetHandshakeTimeout(timeout time.Duration) {
|
||||||
|
c.handshakeTimeout = timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleClientConfig) STDConfig() (*STDConfig, error) {
|
||||||
|
return nil, E.New("unsupported usage for Apple TLS engine")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleClientConfig) Client(conn net.Conn) (Conn, error) {
|
||||||
|
return nil, os.ErrInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleClientConfig) Clone() Config {
|
||||||
|
return &appleClientConfig{
|
||||||
|
serverName: c.serverName,
|
||||||
|
nextProtos: append([]string(nil), c.nextProtos...),
|
||||||
|
handshakeTimeout: c.handshakeTimeout,
|
||||||
|
minVersion: c.minVersion,
|
||||||
|
maxVersion: c.maxVersion,
|
||||||
|
insecure: c.insecure,
|
||||||
|
anchorPEM: c.anchorPEM,
|
||||||
|
anchorOnly: c.anchorOnly,
|
||||||
|
certificatePublicKeySHA256: append([][]byte(nil), c.certificatePublicKeySHA256...),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAppleClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions, allowEmptyServerName bool) (Config, error) {
|
||||||
|
validated, err := ValidateAppleTLSOptions(ctx, options, "Apple TLS engine")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var serverName string
|
||||||
|
if options.ServerName != "" {
|
||||||
|
serverName = options.ServerName
|
||||||
|
} else if serverAddress != "" {
|
||||||
|
serverName = serverAddress
|
||||||
|
}
|
||||||
|
if serverName == "" && !options.Insecure && !allowEmptyServerName {
|
||||||
|
return nil, errMissingServerName
|
||||||
|
}
|
||||||
|
|
||||||
|
var handshakeTimeout time.Duration
|
||||||
|
if options.HandshakeTimeout > 0 {
|
||||||
|
handshakeTimeout = options.HandshakeTimeout.Build()
|
||||||
|
} else {
|
||||||
|
handshakeTimeout = boxConstant.TCPTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
return &appleClientConfig{
|
||||||
|
serverName: serverName,
|
||||||
|
nextProtos: append([]string(nil), options.ALPN...),
|
||||||
|
handshakeTimeout: handshakeTimeout,
|
||||||
|
minVersion: validated.MinVersion,
|
||||||
|
maxVersion: validated.MaxVersion,
|
||||||
|
insecure: options.Insecure || len(options.CertificatePublicKeySHA256) > 0,
|
||||||
|
anchorPEM: validated.AnchorPEM,
|
||||||
|
anchorOnly: validated.AnchorOnly,
|
||||||
|
certificatePublicKeySHA256: append([][]byte(nil), options.CertificatePublicKeySHA256...),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type AppleTLSValidated struct {
|
||||||
|
MinVersion uint16
|
||||||
|
MaxVersion uint16
|
||||||
|
AnchorPEM string
|
||||||
|
AnchorOnly bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func ValidateAppleTLSOptions(ctx context.Context, options option.OutboundTLSOptions, engineName string) (AppleTLSValidated, error) {
|
||||||
|
if options.Reality != nil && options.Reality.Enabled {
|
||||||
|
return AppleTLSValidated{}, E.New("reality is unsupported in ", engineName)
|
||||||
|
}
|
||||||
|
if options.UTLS != nil && options.UTLS.Enabled {
|
||||||
|
return AppleTLSValidated{}, E.New("utls is unsupported in ", engineName)
|
||||||
|
}
|
||||||
|
if options.ECH != nil && options.ECH.Enabled {
|
||||||
|
return AppleTLSValidated{}, E.New("ech is unsupported in ", engineName)
|
||||||
|
}
|
||||||
|
if options.DisableSNI {
|
||||||
|
return AppleTLSValidated{}, E.New("disable_sni is unsupported in ", engineName)
|
||||||
|
}
|
||||||
|
if len(options.CipherSuites) > 0 {
|
||||||
|
return AppleTLSValidated{}, E.New("cipher_suites is unsupported in ", engineName)
|
||||||
|
}
|
||||||
|
if len(options.CurvePreferences) > 0 {
|
||||||
|
return AppleTLSValidated{}, E.New("curve_preferences is unsupported in ", engineName)
|
||||||
|
}
|
||||||
|
if len(options.ClientCertificate) > 0 || options.ClientCertificatePath != "" || len(options.ClientKey) > 0 || options.ClientKeyPath != "" {
|
||||||
|
return AppleTLSValidated{}, E.New("client certificate is unsupported in ", engineName)
|
||||||
|
}
|
||||||
|
if options.Fragment || options.RecordFragment {
|
||||||
|
return AppleTLSValidated{}, E.New("tls fragment is unsupported in ", engineName)
|
||||||
|
}
|
||||||
|
if options.KernelTx || options.KernelRx {
|
||||||
|
return AppleTLSValidated{}, E.New("ktls is unsupported in ", engineName)
|
||||||
|
}
|
||||||
|
if len(options.CertificatePublicKeySHA256) > 0 && (len(options.Certificate) > 0 || options.CertificatePath != "") {
|
||||||
|
return AppleTLSValidated{}, E.New("certificate_public_key_sha256 is conflict with certificate or certificate_path")
|
||||||
|
}
|
||||||
|
var minVersion uint16
|
||||||
|
if options.MinVersion != "" {
|
||||||
|
var err error
|
||||||
|
minVersion, err = ParseTLSVersion(options.MinVersion)
|
||||||
|
if err != nil {
|
||||||
|
return AppleTLSValidated{}, E.Cause(err, "parse min_version")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var maxVersion uint16
|
||||||
|
if options.MaxVersion != "" {
|
||||||
|
var err error
|
||||||
|
maxVersion, err = ParseTLSVersion(options.MaxVersion)
|
||||||
|
if err != nil {
|
||||||
|
return AppleTLSValidated{}, E.Cause(err, "parse max_version")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
anchorPEM, anchorOnly, err := AppleAnchorPEM(ctx, options)
|
||||||
|
if err != nil {
|
||||||
|
return AppleTLSValidated{}, err
|
||||||
|
}
|
||||||
|
return AppleTLSValidated{
|
||||||
|
MinVersion: minVersion,
|
||||||
|
MaxVersion: maxVersion,
|
||||||
|
AnchorPEM: anchorPEM,
|
||||||
|
AnchorOnly: anchorOnly,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func AppleAnchorPEM(ctx context.Context, options option.OutboundTLSOptions) (string, bool, error) {
|
||||||
|
if len(options.Certificate) > 0 {
|
||||||
|
return strings.Join(options.Certificate, "\n"), true, nil
|
||||||
|
}
|
||||||
|
if options.CertificatePath != "" {
|
||||||
|
content, err := os.ReadFile(options.CertificatePath)
|
||||||
|
if err != nil {
|
||||||
|
return "", false, E.Cause(err, "read certificate")
|
||||||
|
}
|
||||||
|
return string(content), true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
certificateStore := service.FromContext[adapter.CertificateStore](ctx)
|
||||||
|
if certificateStore == nil {
|
||||||
|
return "", false, nil
|
||||||
|
}
|
||||||
|
store, ok := certificateStore.(appleCertificateStore)
|
||||||
|
if !ok {
|
||||||
|
return "", false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch store.StoreKind() {
|
||||||
|
case boxConstant.CertificateStoreSystem, "":
|
||||||
|
return strings.Join(store.CurrentPEM(), "\n"), false, nil
|
||||||
|
case boxConstant.CertificateStoreMozilla, boxConstant.CertificateStoreChrome, boxConstant.CertificateStoreNone:
|
||||||
|
return strings.Join(store.CurrentPEM(), "\n"), true, nil
|
||||||
|
default:
|
||||||
|
return "", false, E.New("unsupported certificate store for Apple TLS engine: ", store.StoreKind())
|
||||||
|
}
|
||||||
|
}
|
||||||
414
common/tls/apple_client_platform.go
Normal file
414
common/tls/apple_client_platform.go
Normal file
@@ -0,0 +1,414 @@
|
|||||||
|
//go:build darwin && cgo
|
||||||
|
|
||||||
|
package tls
|
||||||
|
|
||||||
|
/*
|
||||||
|
#cgo CFLAGS: -x objective-c -fobjc-arc
|
||||||
|
#cgo LDFLAGS: -framework Foundation -framework Network -framework Security
|
||||||
|
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include "apple_client_platform.h"
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *appleClientConfig) ClientHandshake(ctx context.Context, conn net.Conn) (Conn, error) {
|
||||||
|
rawSyscallConn, ok := common.Cast[syscall.Conn](conn)
|
||||||
|
if !ok {
|
||||||
|
return nil, E.New("apple TLS: requires fd-backed TCP connection")
|
||||||
|
}
|
||||||
|
syscallConn, err := rawSyscallConn.SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "access raw connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
var dupFD int
|
||||||
|
controlErr := syscallConn.Control(func(fd uintptr) {
|
||||||
|
dupFD, err = unix.Dup(int(fd))
|
||||||
|
})
|
||||||
|
if controlErr != nil {
|
||||||
|
return nil, E.Cause(controlErr, "access raw connection")
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "duplicate raw connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
serverName := c.serverName
|
||||||
|
serverNamePtr := cStringOrNil(serverName)
|
||||||
|
defer cFree(serverNamePtr)
|
||||||
|
|
||||||
|
alpn := strings.Join(c.nextProtos, "\n")
|
||||||
|
alpnPtr := cStringOrNil(alpn)
|
||||||
|
defer cFree(alpnPtr)
|
||||||
|
|
||||||
|
anchorPEMPtr := cStringOrNil(c.anchorPEM)
|
||||||
|
defer cFree(anchorPEMPtr)
|
||||||
|
|
||||||
|
var errorPtr *C.char
|
||||||
|
client := C.box_apple_tls_client_create(
|
||||||
|
C.int(dupFD),
|
||||||
|
serverNamePtr,
|
||||||
|
alpnPtr,
|
||||||
|
C.size_t(len(alpn)),
|
||||||
|
C.uint16_t(c.minVersion),
|
||||||
|
C.uint16_t(c.maxVersion),
|
||||||
|
C.bool(c.insecure),
|
||||||
|
anchorPEMPtr,
|
||||||
|
C.size_t(len(c.anchorPEM)),
|
||||||
|
C.bool(c.anchorOnly),
|
||||||
|
&errorPtr,
|
||||||
|
)
|
||||||
|
if client == nil {
|
||||||
|
if errorPtr != nil {
|
||||||
|
defer C.free(unsafe.Pointer(errorPtr))
|
||||||
|
return nil, E.New(C.GoString(errorPtr))
|
||||||
|
}
|
||||||
|
return nil, E.New("apple TLS: create connection")
|
||||||
|
}
|
||||||
|
if err = waitAppleTLSClientReady(ctx, client); err != nil {
|
||||||
|
C.box_apple_tls_client_cancel(client)
|
||||||
|
C.box_apple_tls_client_free(client)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var state C.box_apple_tls_state_t
|
||||||
|
stateOK := C.box_apple_tls_client_copy_state(client, &state, &errorPtr)
|
||||||
|
if !bool(stateOK) {
|
||||||
|
C.box_apple_tls_client_cancel(client)
|
||||||
|
C.box_apple_tls_client_free(client)
|
||||||
|
if errorPtr != nil {
|
||||||
|
defer C.free(unsafe.Pointer(errorPtr))
|
||||||
|
return nil, E.New(C.GoString(errorPtr))
|
||||||
|
}
|
||||||
|
return nil, E.New("apple TLS: read metadata")
|
||||||
|
}
|
||||||
|
defer C.box_apple_tls_state_free(&state)
|
||||||
|
|
||||||
|
connectionState, rawCerts, err := parseAppleTLSState(&state)
|
||||||
|
if err != nil {
|
||||||
|
C.box_apple_tls_client_cancel(client)
|
||||||
|
C.box_apple_tls_client_free(client)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(c.certificatePublicKeySHA256) > 0 {
|
||||||
|
err = VerifyPublicKeySHA256(c.certificatePublicKeySHA256, rawCerts)
|
||||||
|
if err != nil {
|
||||||
|
C.box_apple_tls_client_cancel(client)
|
||||||
|
C.box_apple_tls_client_free(client)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &appleTLSConn{
|
||||||
|
rawConn: conn,
|
||||||
|
client: client,
|
||||||
|
state: connectionState,
|
||||||
|
closed: make(chan struct{}),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const appleTLSHandshakePollInterval = 100 * time.Millisecond
|
||||||
|
|
||||||
|
func waitAppleTLSClientReady(ctx context.Context, client *C.box_apple_tls_client_t) error {
|
||||||
|
for {
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
C.box_apple_tls_client_cancel(client)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
waitTimeout := appleTLSHandshakePollInterval
|
||||||
|
if deadline, loaded := ctx.Deadline(); loaded {
|
||||||
|
remaining := time.Until(deadline)
|
||||||
|
if remaining <= 0 {
|
||||||
|
C.box_apple_tls_client_cancel(client)
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return context.DeadlineExceeded
|
||||||
|
}
|
||||||
|
if remaining < waitTimeout {
|
||||||
|
waitTimeout = remaining
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var errorPtr *C.char
|
||||||
|
waitResult := C.box_apple_tls_client_wait_ready(client, C.int(timeoutFromDuration(waitTimeout)), &errorPtr)
|
||||||
|
switch waitResult {
|
||||||
|
case 1:
|
||||||
|
return nil
|
||||||
|
case -2:
|
||||||
|
continue
|
||||||
|
case 0:
|
||||||
|
if errorPtr != nil {
|
||||||
|
defer C.free(unsafe.Pointer(errorPtr))
|
||||||
|
return E.New(C.GoString(errorPtr))
|
||||||
|
}
|
||||||
|
return E.New("apple TLS: handshake failed")
|
||||||
|
default:
|
||||||
|
return E.New("apple TLS: invalid handshake state")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type appleTLSConn struct {
|
||||||
|
rawConn net.Conn
|
||||||
|
client *C.box_apple_tls_client_t
|
||||||
|
state tls.ConnectionState
|
||||||
|
|
||||||
|
readAccess sync.Mutex
|
||||||
|
writeAccess sync.Mutex
|
||||||
|
stateAccess sync.RWMutex
|
||||||
|
closeOnce sync.Once
|
||||||
|
ioAccess sync.Mutex
|
||||||
|
ioGroup sync.WaitGroup
|
||||||
|
closed chan struct{}
|
||||||
|
readEOF bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) Read(p []byte) (int, error) {
|
||||||
|
c.readAccess.Lock()
|
||||||
|
defer c.readAccess.Unlock()
|
||||||
|
if c.readEOF {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
if len(p) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := c.acquireClient()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer c.releaseClient()
|
||||||
|
|
||||||
|
var eof C.bool
|
||||||
|
var errorPtr *C.char
|
||||||
|
n := C.box_apple_tls_client_read(client, unsafe.Pointer(&p[0]), C.size_t(len(p)), &eof, &errorPtr)
|
||||||
|
switch {
|
||||||
|
case n >= 0:
|
||||||
|
if bool(eof) {
|
||||||
|
c.readEOF = true
|
||||||
|
if n == 0 {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return int(n), nil
|
||||||
|
default:
|
||||||
|
if errorPtr != nil {
|
||||||
|
defer C.free(unsafe.Pointer(errorPtr))
|
||||||
|
if c.isClosed() {
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
return 0, E.New(C.GoString(errorPtr))
|
||||||
|
}
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) Write(p []byte) (int, error) {
|
||||||
|
c.writeAccess.Lock()
|
||||||
|
defer c.writeAccess.Unlock()
|
||||||
|
if len(p) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := c.acquireClient()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer c.releaseClient()
|
||||||
|
|
||||||
|
var errorPtr *C.char
|
||||||
|
n := C.box_apple_tls_client_write(client, unsafe.Pointer(&p[0]), C.size_t(len(p)), &errorPtr)
|
||||||
|
if n >= 0 {
|
||||||
|
return int(n), nil
|
||||||
|
}
|
||||||
|
if errorPtr != nil {
|
||||||
|
defer C.free(unsafe.Pointer(errorPtr))
|
||||||
|
if c.isClosed() {
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
return 0, E.New(C.GoString(errorPtr))
|
||||||
|
}
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) Close() error {
|
||||||
|
var closeErr error
|
||||||
|
c.closeOnce.Do(func() {
|
||||||
|
close(c.closed)
|
||||||
|
C.box_apple_tls_client_cancel(c.client)
|
||||||
|
closeErr = c.rawConn.Close()
|
||||||
|
c.ioAccess.Lock()
|
||||||
|
c.ioGroup.Wait()
|
||||||
|
C.box_apple_tls_client_free(c.client)
|
||||||
|
c.client = nil
|
||||||
|
c.ioAccess.Unlock()
|
||||||
|
})
|
||||||
|
return closeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) LocalAddr() net.Addr {
|
||||||
|
return c.rawConn.LocalAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) RemoteAddr() net.Addr {
|
||||||
|
return c.rawConn.RemoteAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) SetDeadline(t time.Time) error {
|
||||||
|
return os.ErrInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) SetReadDeadline(t time.Time) error {
|
||||||
|
return os.ErrInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) SetWriteDeadline(t time.Time) error {
|
||||||
|
return os.ErrInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) NeedAdditionalReadDeadline() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) isClosed() bool {
|
||||||
|
select {
|
||||||
|
case <-c.closed:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) acquireClient() (*C.box_apple_tls_client_t, error) {
|
||||||
|
c.ioAccess.Lock()
|
||||||
|
defer c.ioAccess.Unlock()
|
||||||
|
if c.isClosed() {
|
||||||
|
return nil, net.ErrClosed
|
||||||
|
}
|
||||||
|
client := c.client
|
||||||
|
if client == nil {
|
||||||
|
return nil, net.ErrClosed
|
||||||
|
}
|
||||||
|
c.ioGroup.Add(1)
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) releaseClient() {
|
||||||
|
c.ioGroup.Done()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) NetConn() net.Conn {
|
||||||
|
return c.rawConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) HandshakeContext(ctx context.Context) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *appleTLSConn) ConnectionState() ConnectionState {
|
||||||
|
c.stateAccess.RLock()
|
||||||
|
defer c.stateAccess.RUnlock()
|
||||||
|
return c.state
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAppleTLSState(state *C.box_apple_tls_state_t) (tls.ConnectionState, [][]byte, error) {
|
||||||
|
rawCerts, peerCertificates, err := parseAppleCertChain(state.peer_cert_chain, state.peer_cert_chain_len)
|
||||||
|
if err != nil {
|
||||||
|
return tls.ConnectionState{}, nil, err
|
||||||
|
}
|
||||||
|
var negotiatedProtocol string
|
||||||
|
if state.alpn != nil {
|
||||||
|
negotiatedProtocol = C.GoString(state.alpn)
|
||||||
|
}
|
||||||
|
var serverName string
|
||||||
|
if state.server_name != nil {
|
||||||
|
serverName = C.GoString(state.server_name)
|
||||||
|
}
|
||||||
|
return tls.ConnectionState{
|
||||||
|
Version: uint16(state.version),
|
||||||
|
HandshakeComplete: true,
|
||||||
|
CipherSuite: uint16(state.cipher_suite),
|
||||||
|
NegotiatedProtocol: negotiatedProtocol,
|
||||||
|
ServerName: serverName,
|
||||||
|
PeerCertificates: peerCertificates,
|
||||||
|
}, rawCerts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAppleCertChain(chain *C.uint8_t, chainLen C.size_t) ([][]byte, []*x509.Certificate, error) {
|
||||||
|
if chain == nil || chainLen == 0 {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
chainBytes := C.GoBytes(unsafe.Pointer(chain), C.int(chainLen))
|
||||||
|
var (
|
||||||
|
rawCerts [][]byte
|
||||||
|
peerCertificates []*x509.Certificate
|
||||||
|
)
|
||||||
|
for len(chainBytes) >= 4 {
|
||||||
|
certificateLen := binary.BigEndian.Uint32(chainBytes[:4])
|
||||||
|
chainBytes = chainBytes[4:]
|
||||||
|
if len(chainBytes) < int(certificateLen) {
|
||||||
|
return nil, nil, E.New("apple TLS: invalid certificate chain")
|
||||||
|
}
|
||||||
|
certificateData := append([]byte(nil), chainBytes[:certificateLen]...)
|
||||||
|
certificate, err := x509.ParseCertificate(certificateData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, E.Cause(err, "parse peer certificate")
|
||||||
|
}
|
||||||
|
rawCerts = append(rawCerts, certificateData)
|
||||||
|
peerCertificates = append(peerCertificates, certificate)
|
||||||
|
chainBytes = chainBytes[certificateLen:]
|
||||||
|
}
|
||||||
|
if len(chainBytes) != 0 {
|
||||||
|
return nil, nil, E.New("apple TLS: invalid certificate chain")
|
||||||
|
}
|
||||||
|
return rawCerts, peerCertificates, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func timeoutFromDuration(timeout time.Duration) int {
|
||||||
|
if timeout <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
timeoutMilliseconds := int64(timeout / time.Millisecond)
|
||||||
|
if timeout%time.Millisecond != 0 {
|
||||||
|
timeoutMilliseconds++
|
||||||
|
}
|
||||||
|
if timeoutMilliseconds > math.MaxInt32 {
|
||||||
|
return math.MaxInt32
|
||||||
|
}
|
||||||
|
return int(timeoutMilliseconds)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cStringOrNil(value string) *C.char {
|
||||||
|
if value == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return C.CString(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cFree(pointer *C.char) {
|
||||||
|
if pointer != nil {
|
||||||
|
C.free(unsafe.Pointer(pointer))
|
||||||
|
}
|
||||||
|
}
|
||||||
37
common/tls/apple_client_platform.h
Normal file
37
common/tls/apple_client_platform.h
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
#include <stdbool.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
|
||||||
|
typedef struct box_apple_tls_client box_apple_tls_client_t;
|
||||||
|
|
||||||
|
typedef struct box_apple_tls_state {
|
||||||
|
uint16_t version;
|
||||||
|
uint16_t cipher_suite;
|
||||||
|
char *alpn;
|
||||||
|
char *server_name;
|
||||||
|
uint8_t *peer_cert_chain;
|
||||||
|
size_t peer_cert_chain_len;
|
||||||
|
} box_apple_tls_state_t;
|
||||||
|
|
||||||
|
box_apple_tls_client_t *box_apple_tls_client_create(
|
||||||
|
int connected_socket,
|
||||||
|
const char *server_name,
|
||||||
|
const char *alpn,
|
||||||
|
size_t alpn_len,
|
||||||
|
uint16_t min_version,
|
||||||
|
uint16_t max_version,
|
||||||
|
bool insecure,
|
||||||
|
const char *anchor_pem,
|
||||||
|
size_t anchor_pem_len,
|
||||||
|
bool anchor_only,
|
||||||
|
char **error_out
|
||||||
|
);
|
||||||
|
|
||||||
|
int box_apple_tls_client_wait_ready(box_apple_tls_client_t *client, int timeout_msec, char **error_out);
|
||||||
|
void box_apple_tls_client_cancel(box_apple_tls_client_t *client);
|
||||||
|
void box_apple_tls_client_free(box_apple_tls_client_t *client);
|
||||||
|
ssize_t box_apple_tls_client_read(box_apple_tls_client_t *client, void *buffer, size_t buffer_len, bool *eof_out, char **error_out);
|
||||||
|
ssize_t box_apple_tls_client_write(box_apple_tls_client_t *client, const void *buffer, size_t buffer_len, char **error_out);
|
||||||
|
bool box_apple_tls_client_copy_state(box_apple_tls_client_t *client, box_apple_tls_state_t *state, char **error_out);
|
||||||
|
void box_apple_tls_state_free(box_apple_tls_state_t *state);
|
||||||
631
common/tls/apple_client_platform.m
Normal file
631
common/tls/apple_client_platform.m
Normal file
@@ -0,0 +1,631 @@
|
|||||||
|
#import "apple_client_platform.h"
|
||||||
|
|
||||||
|
#import <Foundation/Foundation.h>
|
||||||
|
#import <Network/Network.h>
|
||||||
|
#import <Security/Security.h>
|
||||||
|
#import <Security/SecProtocolMetadata.h>
|
||||||
|
#import <Security/SecProtocolOptions.h>
|
||||||
|
#import <Security/SecProtocolTypes.h>
|
||||||
|
#import <arpa/inet.h>
|
||||||
|
#import <dlfcn.h>
|
||||||
|
#import <dispatch/dispatch.h>
|
||||||
|
#import <stdatomic.h>
|
||||||
|
#import <stdlib.h>
|
||||||
|
#import <string.h>
|
||||||
|
#import <unistd.h>
|
||||||
|
|
||||||
|
typedef nw_connection_t _Nullable (*box_nw_connection_create_with_connected_socket_and_parameters_f)(int connected_socket, nw_parameters_t parameters);
|
||||||
|
typedef const char * _Nullable (*box_sec_protocol_metadata_string_accessor_f)(sec_protocol_metadata_t metadata);
|
||||||
|
|
||||||
|
typedef struct box_apple_tls_client {
|
||||||
|
void *connection;
|
||||||
|
void *queue;
|
||||||
|
void *ready_semaphore;
|
||||||
|
atomic_int ref_count;
|
||||||
|
atomic_bool ready;
|
||||||
|
atomic_bool ready_done;
|
||||||
|
char *ready_error;
|
||||||
|
box_apple_tls_state_t state;
|
||||||
|
} box_apple_tls_client_t;
|
||||||
|
|
||||||
|
static nw_connection_t box_apple_tls_connection(box_apple_tls_client_t *client) {
|
||||||
|
if (client == NULL || client->connection == NULL) {
|
||||||
|
return nil;
|
||||||
|
}
|
||||||
|
return (__bridge nw_connection_t)client->connection;
|
||||||
|
}
|
||||||
|
|
||||||
|
static dispatch_queue_t box_apple_tls_client_queue(box_apple_tls_client_t *client) {
|
||||||
|
if (client == NULL || client->queue == NULL) {
|
||||||
|
return nil;
|
||||||
|
}
|
||||||
|
return (__bridge dispatch_queue_t)client->queue;
|
||||||
|
}
|
||||||
|
|
||||||
|
static dispatch_semaphore_t box_apple_tls_ready_semaphore(box_apple_tls_client_t *client) {
|
||||||
|
if (client == NULL || client->ready_semaphore == NULL) {
|
||||||
|
return nil;
|
||||||
|
}
|
||||||
|
return (__bridge dispatch_semaphore_t)client->ready_semaphore;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void box_apple_tls_state_reset(box_apple_tls_state_t *state) {
|
||||||
|
if (state == NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
free(state->alpn);
|
||||||
|
free(state->server_name);
|
||||||
|
free(state->peer_cert_chain);
|
||||||
|
memset(state, 0, sizeof(box_apple_tls_state_t));
|
||||||
|
}
|
||||||
|
|
||||||
|
static void box_apple_tls_client_destroy(box_apple_tls_client_t *client) {
|
||||||
|
free(client->ready_error);
|
||||||
|
box_apple_tls_state_reset(&client->state);
|
||||||
|
if (client->ready_semaphore != NULL) {
|
||||||
|
CFBridgingRelease(client->ready_semaphore);
|
||||||
|
}
|
||||||
|
if (client->connection != NULL) {
|
||||||
|
CFBridgingRelease(client->connection);
|
||||||
|
}
|
||||||
|
if (client->queue != NULL) {
|
||||||
|
CFBridgingRelease(client->queue);
|
||||||
|
}
|
||||||
|
free(client);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void box_apple_tls_client_release(box_apple_tls_client_t *client) {
|
||||||
|
if (client == NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (atomic_fetch_sub(&client->ref_count, 1) == 1) {
|
||||||
|
box_apple_tls_client_destroy(client);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void box_set_error_string(char **error_out, NSString *message) {
|
||||||
|
if (error_out == NULL || *error_out != NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const char *utf8 = [message UTF8String];
|
||||||
|
*error_out = strdup(utf8 != NULL ? utf8 : "unknown error");
|
||||||
|
}
|
||||||
|
|
||||||
|
static void box_set_error_message(char **error_out, const char *message) {
|
||||||
|
if (error_out == NULL || *error_out != NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
*error_out = strdup(message != NULL ? message : "unknown error");
|
||||||
|
}
|
||||||
|
|
||||||
|
static void box_set_error_from_nw_error(char **error_out, nw_error_t error) {
|
||||||
|
if (error == NULL) {
|
||||||
|
box_set_error_message(error_out, "unknown network error");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
CFErrorRef cfError = nw_error_copy_cf_error(error);
|
||||||
|
if (cfError == NULL) {
|
||||||
|
box_set_error_message(error_out, "unknown network error");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
NSString *description = [(__bridge NSError *)cfError description];
|
||||||
|
box_set_error_string(error_out, description);
|
||||||
|
CFRelease(cfError);
|
||||||
|
}
|
||||||
|
|
||||||
|
static char *box_apple_tls_metadata_copy_negotiated_protocol(sec_protocol_metadata_t metadata) {
|
||||||
|
static box_sec_protocol_metadata_string_accessor_f copy_fn;
|
||||||
|
static box_sec_protocol_metadata_string_accessor_f get_fn;
|
||||||
|
static dispatch_once_t onceToken;
|
||||||
|
dispatch_once(&onceToken, ^{
|
||||||
|
copy_fn = (box_sec_protocol_metadata_string_accessor_f)dlsym(RTLD_DEFAULT, "sec_protocol_metadata_copy_negotiated_protocol");
|
||||||
|
get_fn = (box_sec_protocol_metadata_string_accessor_f)dlsym(RTLD_DEFAULT, "sec_protocol_metadata_get_negotiated_protocol");
|
||||||
|
});
|
||||||
|
if (copy_fn != NULL) {
|
||||||
|
return (char *)copy_fn(metadata);
|
||||||
|
}
|
||||||
|
if (get_fn != NULL) {
|
||||||
|
const char *protocol = get_fn(metadata);
|
||||||
|
if (protocol != NULL) {
|
||||||
|
return strdup(protocol);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
static char *box_apple_tls_metadata_copy_server_name(sec_protocol_metadata_t metadata) {
|
||||||
|
static box_sec_protocol_metadata_string_accessor_f copy_fn;
|
||||||
|
static box_sec_protocol_metadata_string_accessor_f get_fn;
|
||||||
|
static dispatch_once_t onceToken;
|
||||||
|
dispatch_once(&onceToken, ^{
|
||||||
|
copy_fn = (box_sec_protocol_metadata_string_accessor_f)dlsym(RTLD_DEFAULT, "sec_protocol_metadata_copy_server_name");
|
||||||
|
get_fn = (box_sec_protocol_metadata_string_accessor_f)dlsym(RTLD_DEFAULT, "sec_protocol_metadata_get_server_name");
|
||||||
|
});
|
||||||
|
if (copy_fn != NULL) {
|
||||||
|
return (char *)copy_fn(metadata);
|
||||||
|
}
|
||||||
|
if (get_fn != NULL) {
|
||||||
|
const char *server_name = get_fn(metadata);
|
||||||
|
if (server_name != NULL) {
|
||||||
|
return strdup(server_name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
static NSArray<NSString *> *box_split_lines(const char *content, size_t content_len) {
|
||||||
|
if (content == NULL || content_len == 0) {
|
||||||
|
return @[];
|
||||||
|
}
|
||||||
|
NSString *string = [[NSString alloc] initWithBytes:content length:content_len encoding:NSUTF8StringEncoding];
|
||||||
|
if (string == nil) {
|
||||||
|
return @[];
|
||||||
|
}
|
||||||
|
NSMutableArray<NSString *> *lines = [NSMutableArray array];
|
||||||
|
[string enumerateLinesUsingBlock:^(NSString *line, BOOL *stop) {
|
||||||
|
if (line.length > 0) {
|
||||||
|
[lines addObject:line];
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
return lines;
|
||||||
|
}
|
||||||
|
|
||||||
|
static NSArray *box_parse_certificates_from_pem(const char *pem, size_t pem_len) {
|
||||||
|
if (pem == NULL || pem_len == 0) {
|
||||||
|
return @[];
|
||||||
|
}
|
||||||
|
NSString *content = [[NSString alloc] initWithBytes:pem length:pem_len encoding:NSUTF8StringEncoding];
|
||||||
|
if (content == nil) {
|
||||||
|
return @[];
|
||||||
|
}
|
||||||
|
NSString *beginMarker = @"-----BEGIN CERTIFICATE-----";
|
||||||
|
NSString *endMarker = @"-----END CERTIFICATE-----";
|
||||||
|
NSMutableArray *certificates = [NSMutableArray array];
|
||||||
|
NSUInteger searchFrom = 0;
|
||||||
|
while (searchFrom < content.length) {
|
||||||
|
NSRange beginRange = [content rangeOfString:beginMarker options:0 range:NSMakeRange(searchFrom, content.length - searchFrom)];
|
||||||
|
if (beginRange.location == NSNotFound) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
NSUInteger bodyStart = beginRange.location + beginRange.length;
|
||||||
|
NSRange endRange = [content rangeOfString:endMarker options:0 range:NSMakeRange(bodyStart, content.length - bodyStart)];
|
||||||
|
if (endRange.location == NSNotFound) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
NSString *base64Section = [content substringWithRange:NSMakeRange(bodyStart, endRange.location - bodyStart)];
|
||||||
|
NSArray<NSString *> *components = [base64Section componentsSeparatedByCharactersInSet:[NSCharacterSet whitespaceAndNewlineCharacterSet]];
|
||||||
|
NSString *base64Content = [components componentsJoinedByString:@""];
|
||||||
|
NSData *der = [[NSData alloc] initWithBase64EncodedString:base64Content options:0];
|
||||||
|
if (der != nil) {
|
||||||
|
SecCertificateRef certificate = SecCertificateCreateWithData(NULL, (__bridge CFDataRef)der);
|
||||||
|
if (certificate != NULL) {
|
||||||
|
[certificates addObject:(__bridge id)certificate];
|
||||||
|
CFRelease(certificate);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
searchFrom = endRange.location + endRange.length;
|
||||||
|
}
|
||||||
|
return certificates;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool box_evaluate_trust(sec_trust_t trust, NSArray *anchors, bool anchor_only) {
|
||||||
|
bool result = false;
|
||||||
|
SecTrustRef trustRef = sec_trust_copy_ref(trust);
|
||||||
|
if (trustRef == NULL) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (anchors.count > 0 || anchor_only) {
|
||||||
|
CFMutableArrayRef anchorArray = CFArrayCreateMutable(NULL, 0, &kCFTypeArrayCallBacks);
|
||||||
|
for (id certificate in anchors) {
|
||||||
|
CFArrayAppendValue(anchorArray, (__bridge const void *)certificate);
|
||||||
|
}
|
||||||
|
SecTrustSetAnchorCertificates(trustRef, anchorArray);
|
||||||
|
SecTrustSetAnchorCertificatesOnly(trustRef, anchor_only);
|
||||||
|
CFRelease(anchorArray);
|
||||||
|
}
|
||||||
|
CFErrorRef error = NULL;
|
||||||
|
result = SecTrustEvaluateWithError(trustRef, &error);
|
||||||
|
if (error != NULL) {
|
||||||
|
CFRelease(error);
|
||||||
|
}
|
||||||
|
CFRelease(trustRef);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
static nw_connection_t box_apple_tls_create_connection(int connected_socket, nw_parameters_t parameters) {
|
||||||
|
static box_nw_connection_create_with_connected_socket_and_parameters_f create_fn;
|
||||||
|
static dispatch_once_t onceToken;
|
||||||
|
dispatch_once(&onceToken, ^{
|
||||||
|
char name[] = "sretemarap_dna_tekcos_detcennoc_htiw_etaerc_noitcennoc_wn";
|
||||||
|
for (size_t i = 0, j = sizeof(name) - 2; i < j; i++, j--) {
|
||||||
|
char t = name[i];
|
||||||
|
name[i] = name[j];
|
||||||
|
name[j] = t;
|
||||||
|
}
|
||||||
|
create_fn = (box_nw_connection_create_with_connected_socket_and_parameters_f)dlsym(RTLD_DEFAULT, name);
|
||||||
|
});
|
||||||
|
if (create_fn == NULL) {
|
||||||
|
return nil;
|
||||||
|
}
|
||||||
|
return create_fn(connected_socket, parameters);
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool box_apple_tls_state_copy(const box_apple_tls_state_t *source, box_apple_tls_state_t *destination) {
|
||||||
|
memset(destination, 0, sizeof(box_apple_tls_state_t));
|
||||||
|
destination->version = source->version;
|
||||||
|
destination->cipher_suite = source->cipher_suite;
|
||||||
|
if (source->alpn != NULL) {
|
||||||
|
destination->alpn = strdup(source->alpn);
|
||||||
|
if (destination->alpn == NULL) {
|
||||||
|
goto oom;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (source->server_name != NULL) {
|
||||||
|
destination->server_name = strdup(source->server_name);
|
||||||
|
if (destination->server_name == NULL) {
|
||||||
|
goto oom;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (source->peer_cert_chain_len > 0) {
|
||||||
|
destination->peer_cert_chain = malloc(source->peer_cert_chain_len);
|
||||||
|
if (destination->peer_cert_chain == NULL) {
|
||||||
|
goto oom;
|
||||||
|
}
|
||||||
|
memcpy(destination->peer_cert_chain, source->peer_cert_chain, source->peer_cert_chain_len);
|
||||||
|
destination->peer_cert_chain_len = source->peer_cert_chain_len;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
|
||||||
|
oom:
|
||||||
|
box_apple_tls_state_reset(destination);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool box_apple_tls_state_load(nw_connection_t connection, box_apple_tls_state_t *state, char **error_out) {
|
||||||
|
box_apple_tls_state_reset(state);
|
||||||
|
if (connection == nil) {
|
||||||
|
box_set_error_message(error_out, "apple TLS: invalid client");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
nw_protocol_definition_t tls_definition = nw_protocol_copy_tls_definition();
|
||||||
|
nw_protocol_metadata_t metadata = nw_connection_copy_protocol_metadata(connection, tls_definition);
|
||||||
|
if (metadata == NULL || !nw_protocol_metadata_is_tls(metadata)) {
|
||||||
|
box_set_error_message(error_out, "apple TLS: metadata unavailable");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
sec_protocol_metadata_t sec_metadata = nw_tls_copy_sec_protocol_metadata(metadata);
|
||||||
|
if (sec_metadata == NULL) {
|
||||||
|
box_set_error_message(error_out, "apple TLS: metadata unavailable");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
state->version = (uint16_t)sec_protocol_metadata_get_negotiated_tls_protocol_version(sec_metadata);
|
||||||
|
state->cipher_suite = (uint16_t)sec_protocol_metadata_get_negotiated_tls_ciphersuite(sec_metadata);
|
||||||
|
state->alpn = box_apple_tls_metadata_copy_negotiated_protocol(sec_metadata);
|
||||||
|
state->server_name = box_apple_tls_metadata_copy_server_name(sec_metadata);
|
||||||
|
|
||||||
|
NSMutableData *chain_data = [NSMutableData data];
|
||||||
|
sec_protocol_metadata_access_peer_certificate_chain(sec_metadata, ^(sec_certificate_t certificate) {
|
||||||
|
SecCertificateRef certificate_ref = sec_certificate_copy_ref(certificate);
|
||||||
|
if (certificate_ref == NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
CFDataRef certificate_data = SecCertificateCopyData(certificate_ref);
|
||||||
|
CFRelease(certificate_ref);
|
||||||
|
if (certificate_data == NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
uint32_t certificate_len = (uint32_t)CFDataGetLength(certificate_data);
|
||||||
|
uint32_t network_len = htonl(certificate_len);
|
||||||
|
[chain_data appendBytes:&network_len length:sizeof(network_len)];
|
||||||
|
[chain_data appendBytes:CFDataGetBytePtr(certificate_data) length:certificate_len];
|
||||||
|
CFRelease(certificate_data);
|
||||||
|
});
|
||||||
|
if (chain_data.length > 0) {
|
||||||
|
state->peer_cert_chain = malloc(chain_data.length);
|
||||||
|
if (state->peer_cert_chain == NULL) {
|
||||||
|
box_set_error_message(error_out, "apple TLS: out of memory");
|
||||||
|
box_apple_tls_state_reset(state);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
memcpy(state->peer_cert_chain, chain_data.bytes, chain_data.length);
|
||||||
|
state->peer_cert_chain_len = chain_data.length;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
box_apple_tls_client_t *box_apple_tls_client_create(
|
||||||
|
int connected_socket,
|
||||||
|
const char *server_name,
|
||||||
|
const char *alpn,
|
||||||
|
size_t alpn_len,
|
||||||
|
uint16_t min_version,
|
||||||
|
uint16_t max_version,
|
||||||
|
bool insecure,
|
||||||
|
const char *anchor_pem,
|
||||||
|
size_t anchor_pem_len,
|
||||||
|
bool anchor_only,
|
||||||
|
char **error_out
|
||||||
|
) {
|
||||||
|
box_apple_tls_client_t *client = calloc(1, sizeof(box_apple_tls_client_t));
|
||||||
|
if (client == NULL) {
|
||||||
|
close(connected_socket);
|
||||||
|
box_set_error_message(error_out, "apple TLS: out of memory");
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
client->queue = (__bridge_retained void *)dispatch_queue_create("sing-box.apple-private-tls", DISPATCH_QUEUE_SERIAL);
|
||||||
|
client->ready_semaphore = (__bridge_retained void *)dispatch_semaphore_create(0);
|
||||||
|
atomic_init(&client->ref_count, 1);
|
||||||
|
atomic_init(&client->ready, false);
|
||||||
|
atomic_init(&client->ready_done, false);
|
||||||
|
|
||||||
|
NSArray<NSString *> *alpnList = box_split_lines(alpn, alpn_len);
|
||||||
|
NSArray *anchors = box_parse_certificates_from_pem(anchor_pem, anchor_pem_len);
|
||||||
|
nw_parameters_t parameters = nw_parameters_create_secure_tcp(^(nw_protocol_options_t tls_options) {
|
||||||
|
sec_protocol_options_t sec_options = nw_tls_copy_sec_protocol_options(tls_options);
|
||||||
|
if (min_version != 0) {
|
||||||
|
sec_protocol_options_set_min_tls_protocol_version(sec_options, (tls_protocol_version_t)min_version);
|
||||||
|
}
|
||||||
|
if (max_version != 0) {
|
||||||
|
sec_protocol_options_set_max_tls_protocol_version(sec_options, (tls_protocol_version_t)max_version);
|
||||||
|
}
|
||||||
|
if (server_name != NULL && server_name[0] != '\0') {
|
||||||
|
sec_protocol_options_set_tls_server_name(sec_options, server_name);
|
||||||
|
}
|
||||||
|
for (NSString *protocol in alpnList) {
|
||||||
|
sec_protocol_options_add_tls_application_protocol(sec_options, protocol.UTF8String);
|
||||||
|
}
|
||||||
|
sec_protocol_options_set_peer_authentication_required(sec_options, !insecure);
|
||||||
|
if (insecure) {
|
||||||
|
sec_protocol_options_set_verify_block(sec_options, ^(sec_protocol_metadata_t metadata, sec_trust_t trust, sec_protocol_verify_complete_t complete) {
|
||||||
|
complete(true);
|
||||||
|
}, box_apple_tls_client_queue(client));
|
||||||
|
} else if (anchors.count > 0 || anchor_only) {
|
||||||
|
sec_protocol_options_set_verify_block(sec_options, ^(sec_protocol_metadata_t metadata, sec_trust_t trust, sec_protocol_verify_complete_t complete) {
|
||||||
|
complete(box_evaluate_trust(trust, anchors, anchor_only));
|
||||||
|
}, box_apple_tls_client_queue(client));
|
||||||
|
}
|
||||||
|
}, NW_PARAMETERS_DEFAULT_CONFIGURATION);
|
||||||
|
|
||||||
|
nw_connection_t connection = box_apple_tls_create_connection(connected_socket, parameters);
|
||||||
|
if (connection == NULL) {
|
||||||
|
close(connected_socket);
|
||||||
|
if (client->ready_semaphore != NULL) {
|
||||||
|
CFBridgingRelease(client->ready_semaphore);
|
||||||
|
}
|
||||||
|
if (client->queue != NULL) {
|
||||||
|
CFBridgingRelease(client->queue);
|
||||||
|
}
|
||||||
|
free(client);
|
||||||
|
box_set_error_message(error_out, "apple TLS: failed to create connection");
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
client->connection = (__bridge_retained void *)connection;
|
||||||
|
atomic_fetch_add(&client->ref_count, 1);
|
||||||
|
|
||||||
|
nw_connection_set_state_changed_handler(connection, ^(nw_connection_state_t state, nw_error_t error) {
|
||||||
|
switch (state) {
|
||||||
|
case nw_connection_state_ready:
|
||||||
|
if (!atomic_load(&client->ready_done)) {
|
||||||
|
atomic_store(&client->ready, box_apple_tls_state_load(connection, &client->state, &client->ready_error));
|
||||||
|
atomic_store(&client->ready_done, true);
|
||||||
|
dispatch_semaphore_signal(box_apple_tls_ready_semaphore(client));
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case nw_connection_state_failed:
|
||||||
|
if (!atomic_load(&client->ready_done)) {
|
||||||
|
box_set_error_from_nw_error(&client->ready_error, error);
|
||||||
|
atomic_store(&client->ready_done, true);
|
||||||
|
dispatch_semaphore_signal(box_apple_tls_ready_semaphore(client));
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case nw_connection_state_cancelled:
|
||||||
|
if (!atomic_load(&client->ready_done)) {
|
||||||
|
box_set_error_from_nw_error(&client->ready_error, error);
|
||||||
|
atomic_store(&client->ready_done, true);
|
||||||
|
dispatch_semaphore_signal(box_apple_tls_ready_semaphore(client));
|
||||||
|
}
|
||||||
|
box_apple_tls_client_release(client);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
nw_connection_set_queue(connection, box_apple_tls_client_queue(client));
|
||||||
|
nw_connection_start(connection);
|
||||||
|
return client;
|
||||||
|
}
|
||||||
|
|
||||||
|
int box_apple_tls_client_wait_ready(box_apple_tls_client_t *client, int timeout_msec, char **error_out) {
|
||||||
|
dispatch_semaphore_t ready_semaphore = box_apple_tls_ready_semaphore(client);
|
||||||
|
if (ready_semaphore == nil) {
|
||||||
|
box_set_error_message(error_out, "apple TLS: invalid client");
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
if (!atomic_load(&client->ready_done)) {
|
||||||
|
dispatch_time_t timeout = DISPATCH_TIME_FOREVER;
|
||||||
|
if (timeout_msec >= 0) {
|
||||||
|
timeout = dispatch_time(DISPATCH_TIME_NOW, (int64_t)timeout_msec * NSEC_PER_MSEC);
|
||||||
|
}
|
||||||
|
long wait_result = dispatch_semaphore_wait(ready_semaphore, timeout);
|
||||||
|
if (wait_result != 0) {
|
||||||
|
return -2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (atomic_load(&client->ready)) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
if (client->ready_error != NULL) {
|
||||||
|
if (error_out != NULL) {
|
||||||
|
*error_out = client->ready_error;
|
||||||
|
client->ready_error = NULL;
|
||||||
|
} else {
|
||||||
|
free(client->ready_error);
|
||||||
|
client->ready_error = NULL;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
box_set_error_message(error_out, "apple TLS: handshake failed");
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void box_apple_tls_client_cancel(box_apple_tls_client_t *client) {
|
||||||
|
if (client == NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
nw_connection_t connection = box_apple_tls_connection(client);
|
||||||
|
if (connection != nil) {
|
||||||
|
nw_connection_cancel(connection);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void box_apple_tls_client_free(box_apple_tls_client_t *client) {
|
||||||
|
if (client == NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
nw_connection_t connection = box_apple_tls_connection(client);
|
||||||
|
if (connection != nil) {
|
||||||
|
nw_connection_cancel(connection);
|
||||||
|
}
|
||||||
|
box_apple_tls_client_release(client);
|
||||||
|
}
|
||||||
|
|
||||||
|
ssize_t box_apple_tls_client_read(box_apple_tls_client_t *client, void *buffer, size_t buffer_len, bool *eof_out, char **error_out) {
|
||||||
|
nw_connection_t connection = box_apple_tls_connection(client);
|
||||||
|
if (connection == nil) {
|
||||||
|
box_set_error_message(error_out, "apple TLS: invalid client");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch_semaphore_t read_semaphore = dispatch_semaphore_create(0);
|
||||||
|
__block NSData *content_data = nil;
|
||||||
|
__block bool read_eof = false;
|
||||||
|
__block char *local_error = NULL;
|
||||||
|
|
||||||
|
nw_connection_receive(connection, 1, (uint32_t)buffer_len, ^(dispatch_data_t content, nw_content_context_t context, bool is_complete, nw_error_t error) {
|
||||||
|
if (content != NULL) {
|
||||||
|
const void *mapped = NULL;
|
||||||
|
size_t mapped_len = 0;
|
||||||
|
dispatch_data_t mapped_data = dispatch_data_create_map(content, &mapped, &mapped_len);
|
||||||
|
if (mapped != NULL && mapped_len > 0) {
|
||||||
|
content_data = [NSData dataWithBytes:mapped length:mapped_len];
|
||||||
|
}
|
||||||
|
(void)mapped_data;
|
||||||
|
}
|
||||||
|
if (error != NULL && content_data.length == 0) {
|
||||||
|
box_set_error_from_nw_error(&local_error, error);
|
||||||
|
}
|
||||||
|
if (is_complete && (context == NULL || nw_content_context_get_is_final(context))) {
|
||||||
|
read_eof = true;
|
||||||
|
}
|
||||||
|
dispatch_semaphore_signal(read_semaphore);
|
||||||
|
});
|
||||||
|
|
||||||
|
dispatch_semaphore_wait(read_semaphore, DISPATCH_TIME_FOREVER);
|
||||||
|
if (local_error != NULL) {
|
||||||
|
if (error_out != NULL) {
|
||||||
|
*error_out = local_error;
|
||||||
|
} else {
|
||||||
|
free(local_error);
|
||||||
|
}
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
if (eof_out != NULL) {
|
||||||
|
*eof_out = read_eof;
|
||||||
|
}
|
||||||
|
if (content_data == nil || content_data.length == 0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
memcpy(buffer, content_data.bytes, content_data.length);
|
||||||
|
return (ssize_t)content_data.length;
|
||||||
|
}
|
||||||
|
|
||||||
|
ssize_t box_apple_tls_client_write(box_apple_tls_client_t *client, const void *buffer, size_t buffer_len, char **error_out) {
|
||||||
|
nw_connection_t connection = box_apple_tls_connection(client);
|
||||||
|
if (connection == nil) {
|
||||||
|
box_set_error_message(error_out, "apple TLS: invalid client");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
if (buffer_len == 0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void *content_copy = malloc(buffer_len);
|
||||||
|
dispatch_queue_t queue = box_apple_tls_client_queue(client);
|
||||||
|
if (content_copy == NULL) {
|
||||||
|
free(content_copy);
|
||||||
|
box_set_error_message(error_out, "apple TLS: out of memory");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
if (queue == nil) {
|
||||||
|
free(content_copy);
|
||||||
|
box_set_error_message(error_out, "apple TLS: invalid client");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
memcpy(content_copy, buffer, buffer_len);
|
||||||
|
dispatch_data_t content = dispatch_data_create(content_copy, buffer_len, queue, ^{
|
||||||
|
free(content_copy);
|
||||||
|
});
|
||||||
|
|
||||||
|
dispatch_semaphore_t write_semaphore = dispatch_semaphore_create(0);
|
||||||
|
__block char *local_error = NULL;
|
||||||
|
|
||||||
|
nw_connection_send(connection, content, NW_CONNECTION_DEFAULT_STREAM_CONTEXT, false, ^(nw_error_t error) {
|
||||||
|
if (error != NULL) {
|
||||||
|
box_set_error_from_nw_error(&local_error, error);
|
||||||
|
}
|
||||||
|
dispatch_semaphore_signal(write_semaphore);
|
||||||
|
});
|
||||||
|
|
||||||
|
dispatch_semaphore_wait(write_semaphore, DISPATCH_TIME_FOREVER);
|
||||||
|
if (local_error != NULL) {
|
||||||
|
if (error_out != NULL) {
|
||||||
|
*error_out = local_error;
|
||||||
|
} else {
|
||||||
|
free(local_error);
|
||||||
|
}
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
return (ssize_t)buffer_len;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool box_apple_tls_client_copy_state(box_apple_tls_client_t *client, box_apple_tls_state_t *state, char **error_out) {
|
||||||
|
dispatch_queue_t queue = box_apple_tls_client_queue(client);
|
||||||
|
if (queue == nil || state == NULL) {
|
||||||
|
box_set_error_message(error_out, "apple TLS: invalid client");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
memset(state, 0, sizeof(box_apple_tls_state_t));
|
||||||
|
__block bool copied = false;
|
||||||
|
__block char *local_error = NULL;
|
||||||
|
dispatch_sync(queue, ^{
|
||||||
|
if (!atomic_load(&client->ready)) {
|
||||||
|
box_set_error_message(&local_error, "apple TLS: metadata unavailable");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!box_apple_tls_state_copy(&client->state, state)) {
|
||||||
|
box_set_error_message(&local_error, "apple TLS: out of memory");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
copied = true;
|
||||||
|
});
|
||||||
|
if (copied) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (local_error != NULL) {
|
||||||
|
if (error_out != NULL) {
|
||||||
|
*error_out = local_error;
|
||||||
|
} else {
|
||||||
|
free(local_error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
box_apple_tls_state_reset(state);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void box_apple_tls_state_free(box_apple_tls_state_t *state) {
|
||||||
|
box_apple_tls_state_reset(state);
|
||||||
|
}
|
||||||
301
common/tls/apple_client_platform_test.go
Normal file
301
common/tls/apple_client_platform_test.go
Normal file
@@ -0,0 +1,301 @@
|
|||||||
|
//go:build darwin && cgo
|
||||||
|
|
||||||
|
package tls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
stdtls "crypto/tls"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
"github.com/sagernet/sing/common/json/badoption"
|
||||||
|
"github.com/sagernet/sing/common/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
const appleTLSTestTimeout = 5 * time.Second
|
||||||
|
|
||||||
|
const (
|
||||||
|
appleTLSSuccessHandshakeLoops = 20
|
||||||
|
appleTLSFailureRecoveryLoops = 10
|
||||||
|
)
|
||||||
|
|
||||||
|
type appleTLSServerResult struct {
|
||||||
|
state stdtls.ConnectionState
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppleClientHandshakeAppliesALPNAndVersion(t *testing.T) {
|
||||||
|
serverCertificate, serverCertificatePEM := newAppleTestCertificate(t, "localhost")
|
||||||
|
for index := 0; index < appleTLSSuccessHandshakeLoops; index++ {
|
||||||
|
serverResult, serverAddress := startAppleTLSTestServer(t, &stdtls.Config{
|
||||||
|
Certificates: []stdtls.Certificate{serverCertificate},
|
||||||
|
MinVersion: stdtls.VersionTLS12,
|
||||||
|
MaxVersion: stdtls.VersionTLS12,
|
||||||
|
NextProtos: []string{"h2"},
|
||||||
|
})
|
||||||
|
|
||||||
|
clientConn, err := newAppleTestClientConn(t, serverAddress, option.OutboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
Engine: "apple",
|
||||||
|
ServerName: "localhost",
|
||||||
|
MinVersion: "1.2",
|
||||||
|
MaxVersion: "1.2",
|
||||||
|
ALPN: badoption.Listable[string]{"h2"},
|
||||||
|
Certificate: badoption.Listable[string]{serverCertificatePEM},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("iteration %d: %v", index, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientState := clientConn.ConnectionState()
|
||||||
|
if clientState.Version != stdtls.VersionTLS12 {
|
||||||
|
_ = clientConn.Close()
|
||||||
|
t.Fatalf("iteration %d: unexpected negotiated version: %x", index, clientState.Version)
|
||||||
|
}
|
||||||
|
if clientState.NegotiatedProtocol != "h2" {
|
||||||
|
_ = clientConn.Close()
|
||||||
|
t.Fatalf("iteration %d: unexpected negotiated protocol: %q", index, clientState.NegotiatedProtocol)
|
||||||
|
}
|
||||||
|
_ = clientConn.Close()
|
||||||
|
|
||||||
|
result := <-serverResult
|
||||||
|
if result.err != nil {
|
||||||
|
t.Fatalf("iteration %d: %v", index, result.err)
|
||||||
|
}
|
||||||
|
if result.state.Version != stdtls.VersionTLS12 {
|
||||||
|
t.Fatalf("iteration %d: server negotiated unexpected version: %x", index, result.state.Version)
|
||||||
|
}
|
||||||
|
if result.state.NegotiatedProtocol != "h2" {
|
||||||
|
t.Fatalf("iteration %d: server negotiated unexpected protocol: %q", index, result.state.NegotiatedProtocol)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppleClientHandshakeRejectsVersionMismatch(t *testing.T) {
|
||||||
|
serverCertificate, serverCertificatePEM := newAppleTestCertificate(t, "localhost")
|
||||||
|
serverResult, serverAddress := startAppleTLSTestServer(t, &stdtls.Config{
|
||||||
|
Certificates: []stdtls.Certificate{serverCertificate},
|
||||||
|
MinVersion: stdtls.VersionTLS13,
|
||||||
|
MaxVersion: stdtls.VersionTLS13,
|
||||||
|
})
|
||||||
|
|
||||||
|
clientConn, err := newAppleTestClientConn(t, serverAddress, option.OutboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
Engine: "apple",
|
||||||
|
ServerName: "localhost",
|
||||||
|
MaxVersion: "1.2",
|
||||||
|
Certificate: badoption.Listable[string]{serverCertificatePEM},
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
clientConn.Close()
|
||||||
|
t.Fatal("expected version mismatch handshake to fail")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result := <-serverResult; result.err == nil {
|
||||||
|
t.Fatal("expected server handshake to fail on version mismatch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppleClientHandshakeRejectsServerNameMismatch(t *testing.T) {
|
||||||
|
serverCertificate, serverCertificatePEM := newAppleTestCertificate(t, "localhost")
|
||||||
|
serverResult, serverAddress := startAppleTLSTestServer(t, &stdtls.Config{
|
||||||
|
Certificates: []stdtls.Certificate{serverCertificate},
|
||||||
|
})
|
||||||
|
|
||||||
|
clientConn, err := newAppleTestClientConn(t, serverAddress, option.OutboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
Engine: "apple",
|
||||||
|
ServerName: "example.com",
|
||||||
|
Certificate: badoption.Listable[string]{serverCertificatePEM},
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
clientConn.Close()
|
||||||
|
t.Fatal("expected server name mismatch handshake to fail")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result := <-serverResult; result.err == nil {
|
||||||
|
t.Fatal("expected server handshake to fail on server name mismatch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppleClientHandshakeRecoversAfterFailure(t *testing.T) {
|
||||||
|
serverCertificate, serverCertificatePEM := newAppleTestCertificate(t, "localhost")
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
serverConfig *stdtls.Config
|
||||||
|
clientOptions option.OutboundTLSOptions
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "version mismatch",
|
||||||
|
serverConfig: &stdtls.Config{
|
||||||
|
Certificates: []stdtls.Certificate{serverCertificate},
|
||||||
|
MinVersion: stdtls.VersionTLS13,
|
||||||
|
MaxVersion: stdtls.VersionTLS13,
|
||||||
|
},
|
||||||
|
clientOptions: option.OutboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
Engine: "apple",
|
||||||
|
ServerName: "localhost",
|
||||||
|
MaxVersion: "1.2",
|
||||||
|
Certificate: badoption.Listable[string]{serverCertificatePEM},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "server name mismatch",
|
||||||
|
serverConfig: &stdtls.Config{
|
||||||
|
Certificates: []stdtls.Certificate{serverCertificate},
|
||||||
|
},
|
||||||
|
clientOptions: option.OutboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
Engine: "apple",
|
||||||
|
ServerName: "example.com",
|
||||||
|
Certificate: badoption.Listable[string]{serverCertificatePEM},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
successClientOptions := option.OutboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
Engine: "apple",
|
||||||
|
ServerName: "localhost",
|
||||||
|
MinVersion: "1.2",
|
||||||
|
MaxVersion: "1.2",
|
||||||
|
ALPN: badoption.Listable[string]{"h2"},
|
||||||
|
Certificate: badoption.Listable[string]{serverCertificatePEM},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
for index := 0; index < appleTLSFailureRecoveryLoops; index++ {
|
||||||
|
failedResult, failedAddress := startAppleTLSTestServer(t, testCase.serverConfig)
|
||||||
|
failedConn, err := newAppleTestClientConn(t, failedAddress, testCase.clientOptions)
|
||||||
|
if err == nil {
|
||||||
|
_ = failedConn.Close()
|
||||||
|
t.Fatalf("iteration %d: expected handshake failure", index)
|
||||||
|
}
|
||||||
|
if result := <-failedResult; result.err == nil {
|
||||||
|
t.Fatalf("iteration %d: expected server handshake failure", index)
|
||||||
|
}
|
||||||
|
|
||||||
|
successResult, successAddress := startAppleTLSTestServer(t, &stdtls.Config{
|
||||||
|
Certificates: []stdtls.Certificate{serverCertificate},
|
||||||
|
MinVersion: stdtls.VersionTLS12,
|
||||||
|
MaxVersion: stdtls.VersionTLS12,
|
||||||
|
NextProtos: []string{"h2"},
|
||||||
|
})
|
||||||
|
successConn, err := newAppleTestClientConn(t, successAddress, successClientOptions)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("iteration %d: follow-up handshake failed: %v", index, err)
|
||||||
|
}
|
||||||
|
clientState := successConn.ConnectionState()
|
||||||
|
if clientState.NegotiatedProtocol != "h2" {
|
||||||
|
_ = successConn.Close()
|
||||||
|
t.Fatalf("iteration %d: unexpected negotiated protocol after failure: %q", index, clientState.NegotiatedProtocol)
|
||||||
|
}
|
||||||
|
_ = successConn.Close()
|
||||||
|
|
||||||
|
result := <-successResult
|
||||||
|
if result.err != nil {
|
||||||
|
t.Fatalf("iteration %d: follow-up server handshake failed: %v", index, result.err)
|
||||||
|
}
|
||||||
|
if result.state.NegotiatedProtocol != "h2" {
|
||||||
|
t.Fatalf("iteration %d: follow-up server negotiated unexpected protocol: %q", index, result.state.NegotiatedProtocol)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAppleTestCertificate(t *testing.T, serverName string) (stdtls.Certificate, string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
privateKeyPEM, certificatePEM, err := GenerateCertificate(nil, nil, time.Now, serverName, time.Now().Add(time.Hour))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
certificate, err := stdtls.X509KeyPair(certificatePEM, privateKeyPEM)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return certificate, string(certificatePEM)
|
||||||
|
}
|
||||||
|
|
||||||
|
func startAppleTLSTestServer(t *testing.T, tlsConfig *stdtls.Config) (<-chan appleTLSServerResult, string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
listener.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
if tcpListener, isTCP := listener.(*net.TCPListener); isTCP {
|
||||||
|
err = tcpListener.SetDeadline(time.Now().Add(appleTLSTestTimeout))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make(chan appleTLSServerResult, 1)
|
||||||
|
go func() {
|
||||||
|
defer close(result)
|
||||||
|
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
result <- appleTLSServerResult{err: err}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
err = conn.SetDeadline(time.Now().Add(appleTLSTestTimeout))
|
||||||
|
if err != nil {
|
||||||
|
result <- appleTLSServerResult{err: err}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConn := stdtls.Server(conn, tlsConfig)
|
||||||
|
defer tlsConn.Close()
|
||||||
|
|
||||||
|
err = tlsConn.Handshake()
|
||||||
|
if err != nil {
|
||||||
|
result <- appleTLSServerResult{err: err}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
result <- appleTLSServerResult{state: tlsConn.ConnectionState()}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return result, listener.Addr().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAppleTestClientConn(t *testing.T, serverAddress string, options option.OutboundTLSOptions) (Conn, error) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), appleTLSTestTimeout)
|
||||||
|
t.Cleanup(cancel)
|
||||||
|
|
||||||
|
clientConfig, err := NewClientWithOptions(ClientOptions{
|
||||||
|
Context: ctx,
|
||||||
|
Logger: logger.NOP(),
|
||||||
|
ServerAddress: "",
|
||||||
|
Options: options,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := net.DialTimeout("tcp", serverAddress, appleTLSTestTimeout)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConn, err := ClientHandshake(ctx, conn, clientConfig)
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return tlsConn, nil
|
||||||
|
}
|
||||||
15
common/tls/apple_client_stub.go
Normal file
15
common/tls/apple_client_stub.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
//go:build !darwin || !cgo
|
||||||
|
|
||||||
|
package tls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
"github.com/sagernet/sing/common/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newAppleClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions, allowEmptyServerName bool) (Config, error) {
|
||||||
|
return nil, E.New("Apple TLS engine is not available on non-Apple platforms")
|
||||||
|
}
|
||||||
@@ -10,12 +10,15 @@ import (
|
|||||||
"github.com/sagernet/sing-box/common/badtls"
|
"github.com/sagernet/sing-box/common/badtls"
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
"github.com/sagernet/sing-box/option"
|
"github.com/sagernet/sing-box/option"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
"github.com/sagernet/sing/common/logger"
|
"github.com/sagernet/sing/common/logger"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
aTLS "github.com/sagernet/sing/common/tls"
|
aTLS "github.com/sagernet/sing/common/tls"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var errMissingServerName = E.New("missing server_name or insecure=true")
|
||||||
|
|
||||||
func NewDialerFromOptions(ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, serverAddress string, options option.OutboundTLSOptions) (N.Dialer, error) {
|
func NewDialerFromOptions(ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, serverAddress string, options option.OutboundTLSOptions) (N.Dialer, error) {
|
||||||
if !options.Enabled {
|
if !options.Enabled {
|
||||||
return dialer, nil
|
return dialer, nil
|
||||||
@@ -42,11 +45,12 @@ func NewClient(ctx context.Context, logger logger.ContextLogger, serverAddress s
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ClientOptions struct {
|
type ClientOptions struct {
|
||||||
Context context.Context
|
Context context.Context
|
||||||
Logger logger.ContextLogger
|
Logger logger.ContextLogger
|
||||||
ServerAddress string
|
ServerAddress string
|
||||||
Options option.OutboundTLSOptions
|
Options option.OutboundTLSOptions
|
||||||
KTLSCompatible bool
|
AllowEmptyServerName bool
|
||||||
|
KTLSCompatible bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClientWithOptions(options ClientOptions) (Config, error) {
|
func NewClientWithOptions(options ClientOptions) (Config, error) {
|
||||||
@@ -61,17 +65,22 @@ func NewClientWithOptions(options ClientOptions) (Config, error) {
|
|||||||
if options.Options.KernelRx {
|
if options.Options.KernelRx {
|
||||||
options.Logger.Warn("enabling kTLS RX will definitely reduce performance, please checkout https://sing-box.sagernet.org/configuration/shared/tls/#kernel_rx")
|
options.Logger.Warn("enabling kTLS RX will definitely reduce performance, please checkout https://sing-box.sagernet.org/configuration/shared/tls/#kernel_rx")
|
||||||
}
|
}
|
||||||
if options.Options.Reality != nil && options.Options.Reality.Enabled {
|
switch options.Options.Engine {
|
||||||
return NewRealityClient(options.Context, options.Logger, options.ServerAddress, options.Options)
|
case C.TLSEngineDefault, "go":
|
||||||
} else if options.Options.UTLS != nil && options.Options.UTLS.Enabled {
|
case C.TLSEngineApple:
|
||||||
return NewUTLSClient(options.Context, options.Logger, options.ServerAddress, options.Options)
|
return newAppleClient(options.Context, options.Logger, options.ServerAddress, options.Options, options.AllowEmptyServerName)
|
||||||
|
default:
|
||||||
|
return nil, E.New("unknown tls engine: ", options.Options.Engine)
|
||||||
}
|
}
|
||||||
return NewSTDClient(options.Context, options.Logger, options.ServerAddress, options.Options)
|
if options.Options.Reality != nil && options.Options.Reality.Enabled {
|
||||||
|
return newRealityClient(options.Context, options.Logger, options.ServerAddress, options.Options, options.AllowEmptyServerName)
|
||||||
|
} else if options.Options.UTLS != nil && options.Options.UTLS.Enabled {
|
||||||
|
return newUTLSClient(options.Context, options.Logger, options.ServerAddress, options.Options, options.AllowEmptyServerName)
|
||||||
|
}
|
||||||
|
return newSTDClient(options.Context, options.Logger, options.ServerAddress, options.Options, options.AllowEmptyServerName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ClientHandshake(ctx context.Context, conn net.Conn, config Config) (Conn, error) {
|
func ClientHandshake(ctx context.Context, conn net.Conn, config Config) (Conn, error) {
|
||||||
ctx, cancel := context.WithTimeout(ctx, C.TCPTimeout)
|
|
||||||
defer cancel()
|
|
||||||
tlsConn, err := aTLS.ClientHandshake(ctx, conn, config)
|
tlsConn, err := aTLS.ClientHandshake(ctx, conn, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -52,11 +52,15 @@ type RealityClientConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewRealityClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
func NewRealityClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||||
|
return newRealityClient(ctx, logger, serverAddress, options, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRealityClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions, allowEmptyServerName bool) (Config, error) {
|
||||||
if options.UTLS == nil || !options.UTLS.Enabled {
|
if options.UTLS == nil || !options.UTLS.Enabled {
|
||||||
return nil, E.New("uTLS is required by reality client")
|
return nil, E.New("uTLS is required by reality client")
|
||||||
}
|
}
|
||||||
|
|
||||||
uClient, err := NewUTLSClient(ctx, logger, serverAddress, options)
|
uClient, err := newUTLSClient(ctx, logger, serverAddress, options, allowEmptyServerName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -108,6 +112,14 @@ func (e *RealityClientConfig) SetNextProtos(nextProto []string) {
|
|||||||
e.uClient.SetNextProtos(nextProto)
|
e.uClient.SetNextProtos(nextProto)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *RealityClientConfig) HandshakeTimeout() time.Duration {
|
||||||
|
return e.uClient.HandshakeTimeout()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *RealityClientConfig) SetHandshakeTimeout(timeout time.Duration) {
|
||||||
|
e.uClient.SetHandshakeTimeout(timeout)
|
||||||
|
}
|
||||||
|
|
||||||
func (e *RealityClientConfig) STDConfig() (*STDConfig, error) {
|
func (e *RealityClientConfig) STDConfig() (*STDConfig, error) {
|
||||||
return nil, E.New("unsupported usage for reality")
|
return nil, E.New("unsupported usage for reality")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,12 +26,17 @@ import (
|
|||||||
var _ ServerConfigCompat = (*RealityServerConfig)(nil)
|
var _ ServerConfigCompat = (*RealityServerConfig)(nil)
|
||||||
|
|
||||||
type RealityServerConfig struct {
|
type RealityServerConfig struct {
|
||||||
config *utls.RealityConfig
|
config *utls.RealityConfig
|
||||||
|
handshakeTimeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRealityServer(ctx context.Context, logger log.ContextLogger, options option.InboundTLSOptions) (ServerConfig, error) {
|
func NewRealityServer(ctx context.Context, logger log.ContextLogger, options option.InboundTLSOptions) (ServerConfig, error) {
|
||||||
var tlsConfig utls.RealityConfig
|
var tlsConfig utls.RealityConfig
|
||||||
|
|
||||||
|
if options.CertificateProvider != nil {
|
||||||
|
return nil, E.New("certificate_provider is unavailable in reality")
|
||||||
|
}
|
||||||
|
//nolint:staticcheck
|
||||||
if options.ACME != nil && len(options.ACME.Domain) > 0 {
|
if options.ACME != nil && len(options.ACME.Domain) > 0 {
|
||||||
return nil, E.New("acme is unavailable in reality")
|
return nil, E.New("acme is unavailable in reality")
|
||||||
}
|
}
|
||||||
@@ -126,7 +131,16 @@ func NewRealityServer(ctx context.Context, logger log.ContextLogger, options opt
|
|||||||
if options.ECH != nil && options.ECH.Enabled {
|
if options.ECH != nil && options.ECH.Enabled {
|
||||||
return nil, E.New("Reality is conflict with ECH")
|
return nil, E.New("Reality is conflict with ECH")
|
||||||
}
|
}
|
||||||
var config ServerConfig = &RealityServerConfig{&tlsConfig}
|
var handshakeTimeout time.Duration
|
||||||
|
if options.HandshakeTimeout > 0 {
|
||||||
|
handshakeTimeout = options.HandshakeTimeout.Build()
|
||||||
|
} else {
|
||||||
|
handshakeTimeout = C.TCPTimeout
|
||||||
|
}
|
||||||
|
var config ServerConfig = &RealityServerConfig{
|
||||||
|
config: &tlsConfig,
|
||||||
|
handshakeTimeout: handshakeTimeout,
|
||||||
|
}
|
||||||
if options.KernelTx || options.KernelRx {
|
if options.KernelTx || options.KernelRx {
|
||||||
if !C.IsLinux {
|
if !C.IsLinux {
|
||||||
return nil, E.New("kTLS is only supported on Linux")
|
return nil, E.New("kTLS is only supported on Linux")
|
||||||
@@ -157,6 +171,14 @@ func (c *RealityServerConfig) SetNextProtos(nextProto []string) {
|
|||||||
c.config.NextProtos = nextProto
|
c.config.NextProtos = nextProto
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *RealityServerConfig) HandshakeTimeout() time.Duration {
|
||||||
|
return c.handshakeTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RealityServerConfig) SetHandshakeTimeout(timeout time.Duration) {
|
||||||
|
c.handshakeTimeout = timeout
|
||||||
|
}
|
||||||
|
|
||||||
func (c *RealityServerConfig) STDConfig() (*tls.Config, error) {
|
func (c *RealityServerConfig) STDConfig() (*tls.Config, error) {
|
||||||
return nil, E.New("unsupported usage for reality")
|
return nil, E.New("unsupported usage for reality")
|
||||||
}
|
}
|
||||||
@@ -187,7 +209,8 @@ func (c *RealityServerConfig) ServerHandshake(ctx context.Context, conn net.Conn
|
|||||||
|
|
||||||
func (c *RealityServerConfig) Clone() Config {
|
func (c *RealityServerConfig) Clone() Config {
|
||||||
return &RealityServerConfig{
|
return &RealityServerConfig{
|
||||||
config: c.config.Clone(),
|
config: c.config.Clone(),
|
||||||
|
handshakeTimeout: c.handshakeTimeout,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -46,8 +46,11 @@ func NewServerWithOptions(options ServerOptions) (ServerConfig, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func ServerHandshake(ctx context.Context, conn net.Conn, config ServerConfig) (Conn, error) {
|
func ServerHandshake(ctx context.Context, conn net.Conn, config ServerConfig) (Conn, error) {
|
||||||
ctx, cancel := context.WithTimeout(ctx, C.TCPTimeout)
|
if config.HandshakeTimeout() == 0 {
|
||||||
defer cancel()
|
var cancel context.CancelFunc
|
||||||
|
ctx, cancel = context.WithTimeout(ctx, C.TCPTimeout)
|
||||||
|
defer cancel()
|
||||||
|
}
|
||||||
tlsConn, err := aTLS.ServerHandshake(ctx, conn, config)
|
tlsConn, err := aTLS.ServerHandshake(ctx, conn, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -24,16 +24,30 @@ import (
|
|||||||
type STDClientConfig struct {
|
type STDClientConfig struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
config *tls.Config
|
config *tls.Config
|
||||||
|
serverName string
|
||||||
|
disableSNI bool
|
||||||
|
verifyServerName bool
|
||||||
|
handshakeTimeout time.Duration
|
||||||
fragment bool
|
fragment bool
|
||||||
fragmentFallbackDelay time.Duration
|
fragmentFallbackDelay time.Duration
|
||||||
recordFragment bool
|
recordFragment bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *STDClientConfig) ServerName() string {
|
func (c *STDClientConfig) ServerName() string {
|
||||||
return c.config.ServerName
|
return c.serverName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *STDClientConfig) SetServerName(serverName string) {
|
func (c *STDClientConfig) SetServerName(serverName string) {
|
||||||
|
c.serverName = serverName
|
||||||
|
if c.disableSNI {
|
||||||
|
c.config.ServerName = ""
|
||||||
|
if c.verifyServerName {
|
||||||
|
c.config.VerifyConnection = verifyConnection(c.config.RootCAs, c.config.Time, serverName)
|
||||||
|
} else {
|
||||||
|
c.config.VerifyConnection = nil
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
c.config.ServerName = serverName
|
c.config.ServerName = serverName
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -45,6 +59,14 @@ func (c *STDClientConfig) SetNextProtos(nextProto []string) {
|
|||||||
c.config.NextProtos = nextProto
|
c.config.NextProtos = nextProto
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *STDClientConfig) HandshakeTimeout() time.Duration {
|
||||||
|
return c.handshakeTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *STDClientConfig) SetHandshakeTimeout(timeout time.Duration) {
|
||||||
|
c.handshakeTimeout = timeout
|
||||||
|
}
|
||||||
|
|
||||||
func (c *STDClientConfig) STDConfig() (*STDConfig, error) {
|
func (c *STDClientConfig) STDConfig() (*STDConfig, error) {
|
||||||
return c.config, nil
|
return c.config, nil
|
||||||
}
|
}
|
||||||
@@ -57,13 +79,19 @@ func (c *STDClientConfig) Client(conn net.Conn) (Conn, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *STDClientConfig) Clone() Config {
|
func (c *STDClientConfig) Clone() Config {
|
||||||
return &STDClientConfig{
|
cloned := &STDClientConfig{
|
||||||
ctx: c.ctx,
|
ctx: c.ctx,
|
||||||
config: c.config.Clone(),
|
config: c.config.Clone(),
|
||||||
|
serverName: c.serverName,
|
||||||
|
disableSNI: c.disableSNI,
|
||||||
|
verifyServerName: c.verifyServerName,
|
||||||
|
handshakeTimeout: c.handshakeTimeout,
|
||||||
fragment: c.fragment,
|
fragment: c.fragment,
|
||||||
fragmentFallbackDelay: c.fragmentFallbackDelay,
|
fragmentFallbackDelay: c.fragmentFallbackDelay,
|
||||||
recordFragment: c.recordFragment,
|
recordFragment: c.recordFragment,
|
||||||
}
|
}
|
||||||
|
cloned.SetServerName(cloned.serverName)
|
||||||
|
return cloned
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *STDClientConfig) ECHConfigList() []byte {
|
func (c *STDClientConfig) ECHConfigList() []byte {
|
||||||
@@ -75,41 +103,27 @@ func (c *STDClientConfig) SetECHConfigList(EncryptedClientHelloConfigList []byte
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewSTDClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
func NewSTDClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||||
|
return newSTDClient(ctx, logger, serverAddress, options, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSTDClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions, allowEmptyServerName bool) (Config, error) {
|
||||||
var serverName string
|
var serverName string
|
||||||
if options.ServerName != "" {
|
if options.ServerName != "" {
|
||||||
serverName = options.ServerName
|
serverName = options.ServerName
|
||||||
} else if serverAddress != "" {
|
} else if serverAddress != "" {
|
||||||
serverName = serverAddress
|
serverName = serverAddress
|
||||||
}
|
}
|
||||||
if serverName == "" && !options.Insecure {
|
if serverName == "" && !options.Insecure && !allowEmptyServerName {
|
||||||
return nil, E.New("missing server_name or insecure=true")
|
return nil, errMissingServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
var tlsConfig tls.Config
|
var tlsConfig tls.Config
|
||||||
tlsConfig.Time = ntp.TimeFuncFromContext(ctx)
|
tlsConfig.Time = ntp.TimeFuncFromContext(ctx)
|
||||||
tlsConfig.RootCAs = adapter.RootPoolFromContext(ctx)
|
tlsConfig.RootCAs = adapter.RootPoolFromContext(ctx)
|
||||||
if !options.DisableSNI {
|
|
||||||
tlsConfig.ServerName = serverName
|
|
||||||
}
|
|
||||||
if options.Insecure {
|
if options.Insecure {
|
||||||
tlsConfig.InsecureSkipVerify = options.Insecure
|
tlsConfig.InsecureSkipVerify = options.Insecure
|
||||||
} else if options.DisableSNI {
|
} else if options.DisableSNI {
|
||||||
tlsConfig.InsecureSkipVerify = true
|
tlsConfig.InsecureSkipVerify = true
|
||||||
tlsConfig.VerifyConnection = func(state tls.ConnectionState) error {
|
|
||||||
verifyOptions := x509.VerifyOptions{
|
|
||||||
Roots: tlsConfig.RootCAs,
|
|
||||||
DNSName: serverName,
|
|
||||||
Intermediates: x509.NewCertPool(),
|
|
||||||
}
|
|
||||||
for _, cert := range state.PeerCertificates[1:] {
|
|
||||||
verifyOptions.Intermediates.AddCert(cert)
|
|
||||||
}
|
|
||||||
if tlsConfig.Time != nil {
|
|
||||||
verifyOptions.CurrentTime = tlsConfig.Time()
|
|
||||||
}
|
|
||||||
_, err := state.PeerCertificates[0].Verify(verifyOptions)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if len(options.CertificatePublicKeySHA256) > 0 {
|
if len(options.CertificatePublicKeySHA256) > 0 {
|
||||||
if len(options.Certificate) > 0 || options.CertificatePath != "" {
|
if len(options.Certificate) > 0 || options.CertificatePath != "" {
|
||||||
@@ -117,7 +131,7 @@ func NewSTDClient(ctx context.Context, logger logger.ContextLogger, serverAddres
|
|||||||
}
|
}
|
||||||
tlsConfig.InsecureSkipVerify = true
|
tlsConfig.InsecureSkipVerify = true
|
||||||
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||||
return verifyPublicKeySHA256(options.CertificatePublicKeySHA256, rawCerts, tlsConfig.Time)
|
return VerifyPublicKeySHA256(options.CertificatePublicKeySHA256, rawCerts)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(options.ALPN) > 0 {
|
if len(options.ALPN) > 0 {
|
||||||
@@ -198,7 +212,24 @@ func NewSTDClient(ctx context.Context, logger logger.ContextLogger, serverAddres
|
|||||||
} else if len(clientCertificate) > 0 || len(clientKey) > 0 {
|
} else if len(clientCertificate) > 0 || len(clientKey) > 0 {
|
||||||
return nil, E.New("client certificate and client key must be provided together")
|
return nil, E.New("client certificate and client key must be provided together")
|
||||||
}
|
}
|
||||||
var config Config = &STDClientConfig{ctx, &tlsConfig, options.Fragment, time.Duration(options.FragmentFallbackDelay), options.RecordFragment}
|
var handshakeTimeout time.Duration
|
||||||
|
if options.HandshakeTimeout > 0 {
|
||||||
|
handshakeTimeout = options.HandshakeTimeout.Build()
|
||||||
|
} else {
|
||||||
|
handshakeTimeout = C.TCPTimeout
|
||||||
|
}
|
||||||
|
var config Config = &STDClientConfig{
|
||||||
|
ctx: ctx,
|
||||||
|
config: &tlsConfig,
|
||||||
|
serverName: serverName,
|
||||||
|
disableSNI: options.DisableSNI,
|
||||||
|
verifyServerName: options.DisableSNI && !options.Insecure,
|
||||||
|
handshakeTimeout: handshakeTimeout,
|
||||||
|
fragment: options.Fragment,
|
||||||
|
fragmentFallbackDelay: time.Duration(options.FragmentFallbackDelay),
|
||||||
|
recordFragment: options.RecordFragment,
|
||||||
|
}
|
||||||
|
config.SetServerName(serverName)
|
||||||
if options.ECH != nil && options.ECH.Enabled {
|
if options.ECH != nil && options.ECH.Enabled {
|
||||||
var err error
|
var err error
|
||||||
config, err = parseECHClientConfig(ctx, config.(ECHCapableConfig), options)
|
config, err = parseECHClientConfig(ctx, config.(ECHCapableConfig), options)
|
||||||
@@ -220,7 +251,28 @@ func NewSTDClient(ctx context.Context, logger logger.ContextLogger, serverAddres
|
|||||||
return config, nil
|
return config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func verifyPublicKeySHA256(knownHashValues [][]byte, rawCerts [][]byte, timeFunc func() time.Time) error {
|
func verifyConnection(rootCAs *x509.CertPool, timeFunc func() time.Time, serverName string) func(state tls.ConnectionState) error {
|
||||||
|
return func(state tls.ConnectionState) error {
|
||||||
|
if serverName == "" {
|
||||||
|
return errMissingServerName
|
||||||
|
}
|
||||||
|
verifyOptions := x509.VerifyOptions{
|
||||||
|
Roots: rootCAs,
|
||||||
|
DNSName: serverName,
|
||||||
|
Intermediates: x509.NewCertPool(),
|
||||||
|
}
|
||||||
|
for _, cert := range state.PeerCertificates[1:] {
|
||||||
|
verifyOptions.Intermediates.AddCert(cert)
|
||||||
|
}
|
||||||
|
if timeFunc != nil {
|
||||||
|
verifyOptions.CurrentTime = timeFunc()
|
||||||
|
}
|
||||||
|
_, err := state.PeerCertificates[0].Verify(verifyOptions)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func VerifyPublicKeySHA256(knownHashValues [][]byte, rawCerts [][]byte) error {
|
||||||
leafCertificate, err := x509.ParseCertificate(rawCerts[0])
|
leafCertificate, err := x509.ParseCertificate(rawCerts[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "failed to parse leaf certificate")
|
return E.Cause(err, "failed to parse leaf certificate")
|
||||||
|
|||||||
@@ -13,19 +13,88 @@ import (
|
|||||||
"github.com/sagernet/fswatch"
|
"github.com/sagernet/fswatch"
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
|
"github.com/sagernet/sing-box/experimental/deprecated"
|
||||||
"github.com/sagernet/sing-box/log"
|
"github.com/sagernet/sing-box/log"
|
||||||
"github.com/sagernet/sing-box/option"
|
"github.com/sagernet/sing-box/option"
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
"github.com/sagernet/sing/common/ntp"
|
"github.com/sagernet/sing/common/ntp"
|
||||||
|
"github.com/sagernet/sing/service"
|
||||||
)
|
)
|
||||||
|
|
||||||
var errInsecureUnused = E.New("tls: insecure unused")
|
var errInsecureUnused = E.New("tls: insecure unused")
|
||||||
|
|
||||||
|
type managedCertificateProvider interface {
|
||||||
|
adapter.CertificateProvider
|
||||||
|
adapter.SimpleLifecycle
|
||||||
|
}
|
||||||
|
|
||||||
|
type sharedCertificateProvider struct {
|
||||||
|
tag string
|
||||||
|
manager adapter.CertificateProviderManager
|
||||||
|
provider adapter.CertificateProviderService
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *sharedCertificateProvider) Start() error {
|
||||||
|
provider, found := p.manager.Get(p.tag)
|
||||||
|
if !found {
|
||||||
|
return E.New("certificate provider not found: ", p.tag)
|
||||||
|
}
|
||||||
|
p.provider = provider
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *sharedCertificateProvider) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *sharedCertificateProvider) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
return p.provider.GetCertificate(hello)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *sharedCertificateProvider) GetACMENextProtos() []string {
|
||||||
|
return getACMENextProtos(p.provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
type inlineCertificateProvider struct {
|
||||||
|
provider adapter.CertificateProviderService
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *inlineCertificateProvider) Start() error {
|
||||||
|
for _, stage := range adapter.ListStartStages {
|
||||||
|
err := adapter.LegacyStart(p.provider, stage)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *inlineCertificateProvider) Close() error {
|
||||||
|
return p.provider.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *inlineCertificateProvider) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
return p.provider.GetCertificate(hello)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *inlineCertificateProvider) GetACMENextProtos() []string {
|
||||||
|
return getACMENextProtos(p.provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getACMENextProtos(provider adapter.CertificateProvider) []string {
|
||||||
|
if acmeProvider, isACME := provider.(adapter.ACMECertificateProvider); isACME {
|
||||||
|
return acmeProvider.GetACMENextProtos()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type STDServerConfig struct {
|
type STDServerConfig struct {
|
||||||
access sync.RWMutex
|
access sync.RWMutex
|
||||||
config *tls.Config
|
config *tls.Config
|
||||||
|
handshakeTimeout time.Duration
|
||||||
logger log.Logger
|
logger log.Logger
|
||||||
|
certificateProvider managedCertificateProvider
|
||||||
acmeService adapter.SimpleLifecycle
|
acmeService adapter.SimpleLifecycle
|
||||||
certificate []byte
|
certificate []byte
|
||||||
key []byte
|
key []byte
|
||||||
@@ -53,18 +122,17 @@ func (c *STDServerConfig) SetServerName(serverName string) {
|
|||||||
func (c *STDServerConfig) NextProtos() []string {
|
func (c *STDServerConfig) NextProtos() []string {
|
||||||
c.access.RLock()
|
c.access.RLock()
|
||||||
defer c.access.RUnlock()
|
defer c.access.RUnlock()
|
||||||
if c.acmeService != nil && len(c.config.NextProtos) > 1 && c.config.NextProtos[0] == ACMETLS1Protocol {
|
if c.hasACMEALPN() && len(c.config.NextProtos) > 1 && c.config.NextProtos[0] == C.ACMETLS1Protocol {
|
||||||
return c.config.NextProtos[1:]
|
return c.config.NextProtos[1:]
|
||||||
} else {
|
|
||||||
return c.config.NextProtos
|
|
||||||
}
|
}
|
||||||
|
return c.config.NextProtos
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *STDServerConfig) SetNextProtos(nextProto []string) {
|
func (c *STDServerConfig) SetNextProtos(nextProto []string) {
|
||||||
c.access.Lock()
|
c.access.Lock()
|
||||||
defer c.access.Unlock()
|
defer c.access.Unlock()
|
||||||
config := c.config.Clone()
|
config := c.config.Clone()
|
||||||
if c.acmeService != nil && len(c.config.NextProtos) > 1 && c.config.NextProtos[0] == ACMETLS1Protocol {
|
if c.hasACMEALPN() && len(c.config.NextProtos) > 1 && c.config.NextProtos[0] == C.ACMETLS1Protocol {
|
||||||
config.NextProtos = append(c.config.NextProtos[:1], nextProto...)
|
config.NextProtos = append(c.config.NextProtos[:1], nextProto...)
|
||||||
} else {
|
} else {
|
||||||
config.NextProtos = nextProto
|
config.NextProtos = nextProto
|
||||||
@@ -72,6 +140,30 @@ func (c *STDServerConfig) SetNextProtos(nextProto []string) {
|
|||||||
c.config = config
|
c.config = config
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *STDServerConfig) HandshakeTimeout() time.Duration {
|
||||||
|
c.access.RLock()
|
||||||
|
defer c.access.RUnlock()
|
||||||
|
return c.handshakeTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *STDServerConfig) SetHandshakeTimeout(timeout time.Duration) {
|
||||||
|
c.access.Lock()
|
||||||
|
defer c.access.Unlock()
|
||||||
|
c.handshakeTimeout = timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *STDServerConfig) hasACMEALPN() bool {
|
||||||
|
if c.acmeService != nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if c.certificateProvider != nil {
|
||||||
|
if acmeProvider, isACME := c.certificateProvider.(adapter.ACMECertificateProvider); isACME {
|
||||||
|
return len(acmeProvider.GetACMENextProtos()) > 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (c *STDServerConfig) STDConfig() (*STDConfig, error) {
|
func (c *STDServerConfig) STDConfig() (*STDConfig, error) {
|
||||||
return c.config, nil
|
return c.config, nil
|
||||||
}
|
}
|
||||||
@@ -86,20 +178,45 @@ func (c *STDServerConfig) Server(conn net.Conn) (Conn, error) {
|
|||||||
|
|
||||||
func (c *STDServerConfig) Clone() Config {
|
func (c *STDServerConfig) Clone() Config {
|
||||||
return &STDServerConfig{
|
return &STDServerConfig{
|
||||||
config: c.config.Clone(),
|
config: c.config.Clone(),
|
||||||
|
handshakeTimeout: c.handshakeTimeout,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *STDServerConfig) Start() error {
|
func (c *STDServerConfig) Start() error {
|
||||||
if c.acmeService != nil {
|
if c.certificateProvider != nil {
|
||||||
return c.acmeService.Start()
|
err := c.certificateProvider.Start()
|
||||||
} else {
|
|
||||||
err := c.startWatcher()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.logger.Warn("create fsnotify watcher: ", err)
|
return err
|
||||||
|
}
|
||||||
|
if acmeProvider, isACME := c.certificateProvider.(adapter.ACMECertificateProvider); isACME {
|
||||||
|
nextProtos := acmeProvider.GetACMENextProtos()
|
||||||
|
if len(nextProtos) > 0 {
|
||||||
|
c.access.Lock()
|
||||||
|
config := c.config.Clone()
|
||||||
|
mergedNextProtos := append([]string{}, nextProtos...)
|
||||||
|
for _, nextProto := range config.NextProtos {
|
||||||
|
if !common.Contains(mergedNextProtos, nextProto) {
|
||||||
|
mergedNextProtos = append(mergedNextProtos, nextProto)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
config.NextProtos = mergedNextProtos
|
||||||
|
c.config = config
|
||||||
|
c.access.Unlock()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
if c.acmeService != nil {
|
||||||
|
err := c.acmeService.Start()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err := c.startWatcher()
|
||||||
|
if err != nil {
|
||||||
|
c.logger.Warn("create fsnotify watcher: ", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *STDServerConfig) startWatcher() error {
|
func (c *STDServerConfig) startWatcher() error {
|
||||||
@@ -203,23 +320,34 @@ func (c *STDServerConfig) certificateUpdated(path string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *STDServerConfig) Close() error {
|
func (c *STDServerConfig) Close() error {
|
||||||
if c.acmeService != nil {
|
return common.Close(c.certificateProvider, c.acmeService, c.watcher)
|
||||||
return c.acmeService.Close()
|
|
||||||
}
|
|
||||||
if c.watcher != nil {
|
|
||||||
return c.watcher.Close()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option.InboundTLSOptions) (ServerConfig, error) {
|
func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option.InboundTLSOptions) (ServerConfig, error) {
|
||||||
if !options.Enabled {
|
if !options.Enabled {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
//nolint:staticcheck
|
||||||
|
if options.CertificateProvider != nil && options.ACME != nil {
|
||||||
|
return nil, E.New("certificate_provider and acme are mutually exclusive")
|
||||||
|
}
|
||||||
var tlsConfig *tls.Config
|
var tlsConfig *tls.Config
|
||||||
|
var certificateProvider managedCertificateProvider
|
||||||
var acmeService adapter.SimpleLifecycle
|
var acmeService adapter.SimpleLifecycle
|
||||||
var err error
|
var err error
|
||||||
if options.ACME != nil && len(options.ACME.Domain) > 0 {
|
if options.CertificateProvider != nil {
|
||||||
|
certificateProvider, err = newCertificateProvider(ctx, logger, options.CertificateProvider)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tlsConfig = &tls.Config{
|
||||||
|
GetCertificate: certificateProvider.GetCertificate,
|
||||||
|
}
|
||||||
|
if options.Insecure {
|
||||||
|
return nil, errInsecureUnused
|
||||||
|
}
|
||||||
|
} else if options.ACME != nil && len(options.ACME.Domain) > 0 { //nolint:staticcheck
|
||||||
|
deprecated.Report(ctx, deprecated.OptionInlineACME)
|
||||||
//nolint:staticcheck
|
//nolint:staticcheck
|
||||||
tlsConfig, acmeService, err = startACME(ctx, logger, common.PtrValueOrDefault(options.ACME))
|
tlsConfig, acmeService, err = startACME(ctx, logger, common.PtrValueOrDefault(options.ACME))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -272,7 +400,7 @@ func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option.
|
|||||||
certificate []byte
|
certificate []byte
|
||||||
key []byte
|
key []byte
|
||||||
)
|
)
|
||||||
if acmeService == nil {
|
if certificateProvider == nil && acmeService == nil {
|
||||||
if len(options.Certificate) > 0 {
|
if len(options.Certificate) > 0 {
|
||||||
certificate = []byte(strings.Join(options.Certificate, "\n"))
|
certificate = []byte(strings.Join(options.Certificate, "\n"))
|
||||||
} else if options.CertificatePath != "" {
|
} else if options.CertificatePath != "" {
|
||||||
@@ -344,7 +472,7 @@ func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option.
|
|||||||
tlsConfig.ClientAuth = tls.RequestClientCert
|
tlsConfig.ClientAuth = tls.RequestClientCert
|
||||||
}
|
}
|
||||||
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||||
return verifyPublicKeySHA256(options.ClientCertificatePublicKeySHA256, rawCerts, tlsConfig.Time)
|
return VerifyPublicKeySHA256(options.ClientCertificatePublicKeySHA256, rawCerts)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return nil, E.New("missing client_certificate, client_certificate_path or client_certificate_public_key_sha256 for client authentication")
|
return nil, E.New("missing client_certificate, client_certificate_path or client_certificate_public_key_sha256 for client authentication")
|
||||||
@@ -357,9 +485,17 @@ func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option.
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
var handshakeTimeout time.Duration
|
||||||
|
if options.HandshakeTimeout > 0 {
|
||||||
|
handshakeTimeout = options.HandshakeTimeout.Build()
|
||||||
|
} else {
|
||||||
|
handshakeTimeout = C.TCPTimeout
|
||||||
|
}
|
||||||
serverConfig := &STDServerConfig{
|
serverConfig := &STDServerConfig{
|
||||||
config: tlsConfig,
|
config: tlsConfig,
|
||||||
|
handshakeTimeout: handshakeTimeout,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
certificateProvider: certificateProvider,
|
||||||
acmeService: acmeService,
|
acmeService: acmeService,
|
||||||
certificate: certificate,
|
certificate: certificate,
|
||||||
key: key,
|
key: key,
|
||||||
@@ -369,8 +505,8 @@ func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option.
|
|||||||
echKeyPath: echKeyPath,
|
echKeyPath: echKeyPath,
|
||||||
}
|
}
|
||||||
serverConfig.config.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
serverConfig.config.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
serverConfig.access.Lock()
|
serverConfig.access.RLock()
|
||||||
defer serverConfig.access.Unlock()
|
defer serverConfig.access.RUnlock()
|
||||||
return serverConfig.config, nil
|
return serverConfig.config, nil
|
||||||
}
|
}
|
||||||
var config ServerConfig = serverConfig
|
var config ServerConfig = serverConfig
|
||||||
@@ -387,3 +523,27 @@ func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option.
|
|||||||
}
|
}
|
||||||
return config, nil
|
return config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newCertificateProvider(ctx context.Context, logger log.ContextLogger, options *option.CertificateProviderOptions) (managedCertificateProvider, error) {
|
||||||
|
if options.IsShared() {
|
||||||
|
manager := service.FromContext[adapter.CertificateProviderManager](ctx)
|
||||||
|
if manager == nil {
|
||||||
|
return nil, E.New("missing certificate provider manager in context")
|
||||||
|
}
|
||||||
|
return &sharedCertificateProvider{
|
||||||
|
tag: options.Tag,
|
||||||
|
manager: manager,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
registry := service.FromContext[adapter.CertificateProviderRegistry](ctx)
|
||||||
|
if registry == nil {
|
||||||
|
return nil, E.New("missing certificate provider registry in context")
|
||||||
|
}
|
||||||
|
provider, err := registry.Create(ctx, logger, "", options.Type, options.Options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "create inline certificate provider")
|
||||||
|
}
|
||||||
|
return &inlineCertificateProvider{
|
||||||
|
provider: provider,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -28,6 +28,10 @@ import (
|
|||||||
type UTLSClientConfig struct {
|
type UTLSClientConfig struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
config *utls.Config
|
config *utls.Config
|
||||||
|
serverName string
|
||||||
|
disableSNI bool
|
||||||
|
verifyServerName bool
|
||||||
|
handshakeTimeout time.Duration
|
||||||
id utls.ClientHelloID
|
id utls.ClientHelloID
|
||||||
fragment bool
|
fragment bool
|
||||||
fragmentFallbackDelay time.Duration
|
fragmentFallbackDelay time.Duration
|
||||||
@@ -35,10 +39,20 @@ type UTLSClientConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *UTLSClientConfig) ServerName() string {
|
func (c *UTLSClientConfig) ServerName() string {
|
||||||
return c.config.ServerName
|
return c.serverName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *UTLSClientConfig) SetServerName(serverName string) {
|
func (c *UTLSClientConfig) SetServerName(serverName string) {
|
||||||
|
c.serverName = serverName
|
||||||
|
if c.disableSNI {
|
||||||
|
c.config.ServerName = ""
|
||||||
|
if c.verifyServerName {
|
||||||
|
c.config.InsecureServerNameToVerify = serverName
|
||||||
|
} else {
|
||||||
|
c.config.InsecureServerNameToVerify = ""
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
c.config.ServerName = serverName
|
c.config.ServerName = serverName
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -53,6 +67,14 @@ func (c *UTLSClientConfig) SetNextProtos(nextProto []string) {
|
|||||||
c.config.NextProtos = nextProto
|
c.config.NextProtos = nextProto
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *UTLSClientConfig) HandshakeTimeout() time.Duration {
|
||||||
|
return c.handshakeTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UTLSClientConfig) SetHandshakeTimeout(timeout time.Duration) {
|
||||||
|
c.handshakeTimeout = timeout
|
||||||
|
}
|
||||||
|
|
||||||
func (c *UTLSClientConfig) STDConfig() (*STDConfig, error) {
|
func (c *UTLSClientConfig) STDConfig() (*STDConfig, error) {
|
||||||
return nil, E.New("unsupported usage for uTLS")
|
return nil, E.New("unsupported usage for uTLS")
|
||||||
}
|
}
|
||||||
@@ -69,9 +91,20 @@ func (c *UTLSClientConfig) SetSessionIDGenerator(generator func(clientHello []by
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *UTLSClientConfig) Clone() Config {
|
func (c *UTLSClientConfig) Clone() Config {
|
||||||
return &UTLSClientConfig{
|
cloned := &UTLSClientConfig{
|
||||||
c.ctx, c.config.Clone(), c.id, c.fragment, c.fragmentFallbackDelay, c.recordFragment,
|
ctx: c.ctx,
|
||||||
|
config: c.config.Clone(),
|
||||||
|
serverName: c.serverName,
|
||||||
|
disableSNI: c.disableSNI,
|
||||||
|
verifyServerName: c.verifyServerName,
|
||||||
|
handshakeTimeout: c.handshakeTimeout,
|
||||||
|
id: c.id,
|
||||||
|
fragment: c.fragment,
|
||||||
|
fragmentFallbackDelay: c.fragmentFallbackDelay,
|
||||||
|
recordFragment: c.recordFragment,
|
||||||
}
|
}
|
||||||
|
cloned.SetServerName(cloned.serverName)
|
||||||
|
return cloned
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *UTLSClientConfig) ECHConfigList() []byte {
|
func (c *UTLSClientConfig) ECHConfigList() []byte {
|
||||||
@@ -143,29 +176,29 @@ func (c *utlsALPNWrapper) HandshakeContext(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewUTLSClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
func NewUTLSClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||||
|
return newUTLSClient(ctx, logger, serverAddress, options, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUTLSClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions, allowEmptyServerName bool) (Config, error) {
|
||||||
var serverName string
|
var serverName string
|
||||||
if options.ServerName != "" {
|
if options.ServerName != "" {
|
||||||
serverName = options.ServerName
|
serverName = options.ServerName
|
||||||
} else if serverAddress != "" {
|
} else if serverAddress != "" {
|
||||||
serverName = serverAddress
|
serverName = serverAddress
|
||||||
}
|
}
|
||||||
if serverName == "" && !options.Insecure {
|
if serverName == "" && !options.Insecure && !allowEmptyServerName {
|
||||||
return nil, E.New("missing server_name or insecure=true")
|
return nil, errMissingServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
var tlsConfig utls.Config
|
var tlsConfig utls.Config
|
||||||
tlsConfig.Time = ntp.TimeFuncFromContext(ctx)
|
tlsConfig.Time = ntp.TimeFuncFromContext(ctx)
|
||||||
tlsConfig.RootCAs = adapter.RootPoolFromContext(ctx)
|
tlsConfig.RootCAs = adapter.RootPoolFromContext(ctx)
|
||||||
if !options.DisableSNI {
|
|
||||||
tlsConfig.ServerName = serverName
|
|
||||||
}
|
|
||||||
if options.Insecure {
|
if options.Insecure {
|
||||||
tlsConfig.InsecureSkipVerify = options.Insecure
|
tlsConfig.InsecureSkipVerify = options.Insecure
|
||||||
} else if options.DisableSNI {
|
} else if options.DisableSNI {
|
||||||
if options.Reality != nil && options.Reality.Enabled {
|
if options.Reality != nil && options.Reality.Enabled {
|
||||||
return nil, E.New("disable_sni is unsupported in reality")
|
return nil, E.New("disable_sni is unsupported in reality")
|
||||||
}
|
}
|
||||||
tlsConfig.InsecureServerNameToVerify = serverName
|
|
||||||
}
|
}
|
||||||
if len(options.CertificatePublicKeySHA256) > 0 {
|
if len(options.CertificatePublicKeySHA256) > 0 {
|
||||||
if len(options.Certificate) > 0 || options.CertificatePath != "" {
|
if len(options.Certificate) > 0 || options.CertificatePath != "" {
|
||||||
@@ -173,7 +206,7 @@ func NewUTLSClient(ctx context.Context, logger logger.ContextLogger, serverAddre
|
|||||||
}
|
}
|
||||||
tlsConfig.InsecureSkipVerify = true
|
tlsConfig.InsecureSkipVerify = true
|
||||||
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||||
return verifyPublicKeySHA256(options.CertificatePublicKeySHA256, rawCerts, tlsConfig.Time)
|
return VerifyPublicKeySHA256(options.CertificatePublicKeySHA256, rawCerts)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(options.ALPN) > 0 {
|
if len(options.ALPN) > 0 {
|
||||||
@@ -251,11 +284,29 @@ func NewUTLSClient(ctx context.Context, logger logger.ContextLogger, serverAddre
|
|||||||
} else if len(clientCertificate) > 0 || len(clientKey) > 0 {
|
} else if len(clientCertificate) > 0 || len(clientKey) > 0 {
|
||||||
return nil, E.New("client certificate and client key must be provided together")
|
return nil, E.New("client certificate and client key must be provided together")
|
||||||
}
|
}
|
||||||
|
var handshakeTimeout time.Duration
|
||||||
|
if options.HandshakeTimeout > 0 {
|
||||||
|
handshakeTimeout = options.HandshakeTimeout.Build()
|
||||||
|
} else {
|
||||||
|
handshakeTimeout = C.TCPTimeout
|
||||||
|
}
|
||||||
id, err := uTLSClientHelloID(options.UTLS.Fingerprint)
|
id, err := uTLSClientHelloID(options.UTLS.Fingerprint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var config Config = &UTLSClientConfig{ctx, &tlsConfig, id, options.Fragment, time.Duration(options.FragmentFallbackDelay), options.RecordFragment}
|
var config Config = &UTLSClientConfig{
|
||||||
|
ctx: ctx,
|
||||||
|
config: &tlsConfig,
|
||||||
|
serverName: serverName,
|
||||||
|
disableSNI: options.DisableSNI,
|
||||||
|
verifyServerName: options.DisableSNI && !options.Insecure,
|
||||||
|
handshakeTimeout: handshakeTimeout,
|
||||||
|
id: id,
|
||||||
|
fragment: options.Fragment,
|
||||||
|
fragmentFallbackDelay: time.Duration(options.FragmentFallbackDelay),
|
||||||
|
recordFragment: options.RecordFragment,
|
||||||
|
}
|
||||||
|
config.SetServerName(serverName)
|
||||||
if options.ECH != nil && options.ECH.Enabled {
|
if options.ECH != nil && options.ECH.Enabled {
|
||||||
if options.Reality != nil && options.Reality.Enabled {
|
if options.Reality != nil && options.Reality.Enabled {
|
||||||
return nil, E.New("Reality is conflict with ECH")
|
return nil, E.New("Reality is conflict with ECH")
|
||||||
|
|||||||
@@ -12,10 +12,18 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func NewUTLSClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
func NewUTLSClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||||
|
return newUTLSClient(ctx, logger, serverAddress, options, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUTLSClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions, allowEmptyServerName bool) (Config, error) {
|
||||||
return nil, E.New(`uTLS is not included in this build, rebuild with -tags with_utls`)
|
return nil, E.New(`uTLS is not included in this build, rebuild with -tags with_utls`)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRealityClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
func NewRealityClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||||
|
return newRealityClient(ctx, logger, serverAddress, options, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRealityClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions, allowEmptyServerName bool) (Config, error) {
|
||||||
return nil, E.New(`uTLS, which is required by reality is not included in this build, rebuild with -tags with_utls`)
|
return nil, E.New(`uTLS, which is required by reality is not included in this build, rebuild with -tags with_utls`)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,19 +15,18 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
DNSTypeLegacy = "legacy"
|
DNSTypeLegacy = "legacy"
|
||||||
DNSTypeLegacyRcode = "legacy_rcode"
|
DNSTypeUDP = "udp"
|
||||||
DNSTypeUDP = "udp"
|
DNSTypeTCP = "tcp"
|
||||||
DNSTypeTCP = "tcp"
|
DNSTypeTLS = "tls"
|
||||||
DNSTypeTLS = "tls"
|
DNSTypeHTTPS = "https"
|
||||||
DNSTypeHTTPS = "https"
|
DNSTypeQUIC = "quic"
|
||||||
DNSTypeQUIC = "quic"
|
DNSTypeHTTP3 = "h3"
|
||||||
DNSTypeHTTP3 = "h3"
|
DNSTypeLocal = "local"
|
||||||
DNSTypeLocal = "local"
|
DNSTypeHosts = "hosts"
|
||||||
DNSTypeHosts = "hosts"
|
DNSTypeFakeIP = "fakeip"
|
||||||
DNSTypeFakeIP = "fakeip"
|
DNSTypeDHCP = "dhcp"
|
||||||
DNSTypeDHCP = "dhcp"
|
DNSTypeTailscale = "tailscale"
|
||||||
DNSTypeTailscale = "tailscale"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -1,36 +1,39 @@
|
|||||||
package constant
|
package constant
|
||||||
|
|
||||||
const (
|
const (
|
||||||
TypeTun = "tun"
|
TypeTun = "tun"
|
||||||
TypeRedirect = "redirect"
|
TypeRedirect = "redirect"
|
||||||
TypeTProxy = "tproxy"
|
TypeTProxy = "tproxy"
|
||||||
TypeDirect = "direct"
|
TypeDirect = "direct"
|
||||||
TypeBlock = "block"
|
TypeBlock = "block"
|
||||||
TypeDNS = "dns"
|
TypeDNS = "dns"
|
||||||
TypeSOCKS = "socks"
|
TypeSOCKS = "socks"
|
||||||
TypeHTTP = "http"
|
TypeHTTP = "http"
|
||||||
TypeMixed = "mixed"
|
TypeMixed = "mixed"
|
||||||
TypeShadowsocks = "shadowsocks"
|
TypeShadowsocks = "shadowsocks"
|
||||||
TypeVMess = "vmess"
|
TypeVMess = "vmess"
|
||||||
TypeTrojan = "trojan"
|
TypeTrojan = "trojan"
|
||||||
TypeNaive = "naive"
|
TypeNaive = "naive"
|
||||||
TypeWireGuard = "wireguard"
|
TypeWireGuard = "wireguard"
|
||||||
TypeHysteria = "hysteria"
|
TypeHysteria = "hysteria"
|
||||||
TypeTor = "tor"
|
TypeTor = "tor"
|
||||||
TypeSSH = "ssh"
|
TypeSSH = "ssh"
|
||||||
TypeShadowTLS = "shadowtls"
|
TypeShadowTLS = "shadowtls"
|
||||||
TypeAnyTLS = "anytls"
|
TypeAnyTLS = "anytls"
|
||||||
TypeShadowsocksR = "shadowsocksr"
|
TypeShadowsocksR = "shadowsocksr"
|
||||||
TypeVLESS = "vless"
|
TypeVLESS = "vless"
|
||||||
TypeTUIC = "tuic"
|
TypeTUIC = "tuic"
|
||||||
TypeHysteria2 = "hysteria2"
|
TypeHysteria2 = "hysteria2"
|
||||||
TypeTailscale = "tailscale"
|
TypeTailscale = "tailscale"
|
||||||
TypeDERP = "derp"
|
TypeCloudflared = "cloudflared"
|
||||||
TypeResolved = "resolved"
|
TypeDERP = "derp"
|
||||||
TypeSSMAPI = "ssm-api"
|
TypeResolved = "resolved"
|
||||||
TypeCCM = "ccm"
|
TypeSSMAPI = "ssm-api"
|
||||||
TypeOCM = "ocm"
|
TypeCCM = "ccm"
|
||||||
TypeOOMKiller = "oom-killer"
|
TypeOCM = "ocm"
|
||||||
|
TypeOOMKiller = "oom-killer"
|
||||||
|
TypeACME = "acme"
|
||||||
|
TypeCloudflareOriginCA = "cloudflare-origin-ca"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -38,13 +41,6 @@ const (
|
|||||||
TypeURLTest = "urltest"
|
TypeURLTest = "urltest"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
BalancerStrategyLeastUsed = "least-used"
|
|
||||||
BalancerStrategyRoundRobin = "round-robin"
|
|
||||||
BalancerStrategyRandom = "random"
|
|
||||||
BalancerStrategyFallback = "fallback"
|
|
||||||
)
|
|
||||||
|
|
||||||
func ProxyDisplayName(proxyType string) string {
|
func ProxyDisplayName(proxyType string) string {
|
||||||
switch proxyType {
|
switch proxyType {
|
||||||
case TypeTun:
|
case TypeTun:
|
||||||
@@ -95,6 +91,8 @@ func ProxyDisplayName(proxyType string) string {
|
|||||||
return "AnyTLS"
|
return "AnyTLS"
|
||||||
case TypeTailscale:
|
case TypeTailscale:
|
||||||
return "Tailscale"
|
return "Tailscale"
|
||||||
|
case TypeCloudflared:
|
||||||
|
return "Cloudflared"
|
||||||
case TypeSelector:
|
case TypeSelector:
|
||||||
return "Selector"
|
return "Selector"
|
||||||
case TypeURLTest:
|
case TypeURLTest:
|
||||||
|
|||||||
@@ -23,12 +23,15 @@ const (
|
|||||||
RuleSetVersion2
|
RuleSetVersion2
|
||||||
RuleSetVersion3
|
RuleSetVersion3
|
||||||
RuleSetVersion4
|
RuleSetVersion4
|
||||||
RuleSetVersionCurrent = RuleSetVersion4
|
RuleSetVersion5
|
||||||
|
RuleSetVersionCurrent = RuleSetVersion5
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
RuleActionTypeRoute = "route"
|
RuleActionTypeRoute = "route"
|
||||||
RuleActionTypeRouteOptions = "route-options"
|
RuleActionTypeRouteOptions = "route-options"
|
||||||
|
RuleActionTypeEvaluate = "evaluate"
|
||||||
|
RuleActionTypeRespond = "respond"
|
||||||
RuleActionTypeDirect = "direct"
|
RuleActionTypeDirect = "direct"
|
||||||
RuleActionTypeBypass = "bypass"
|
RuleActionTypeBypass = "bypass"
|
||||||
RuleActionTypeReject = "reject"
|
RuleActionTypeReject = "reject"
|
||||||
|
|||||||
8
constant/tls.go
Normal file
8
constant/tls.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
package constant
|
||||||
|
|
||||||
|
const ACMETLS1Protocol = "acme-tls/1"
|
||||||
|
|
||||||
|
const (
|
||||||
|
TLSEngineDefault = ""
|
||||||
|
TLSEngineApple = "apple"
|
||||||
|
)
|
||||||
@@ -87,12 +87,17 @@ func (s *StartedService) newInstance(profileContent string, overrideOptions *Ove
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if s.oomKiller && C.IsIos {
|
if s.oomKillerEnabled {
|
||||||
if !common.Any(options.Services, func(it option.Service) bool {
|
if !common.Any(options.Services, func(it option.Service) bool {
|
||||||
return it.Type == C.TypeOOMKiller
|
return it.Type == C.TypeOOMKiller
|
||||||
}) {
|
}) {
|
||||||
|
oomOptions := &option.OOMKillerServiceOptions{
|
||||||
|
KillerDisabled: s.oomKillerDisabled,
|
||||||
|
MemoryLimitOverride: s.oomMemoryLimit,
|
||||||
|
}
|
||||||
options.Services = append(options.Services, option.Service{
|
options.Services = append(options.Services, option.Service{
|
||||||
Type: C.TypeOOMKiller,
|
Type: C.TypeOOMKiller,
|
||||||
|
Options: oomOptions,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,5 +5,6 @@ type PlatformHandler interface {
|
|||||||
ServiceReload() error
|
ServiceReload() error
|
||||||
SystemProxyStatus() (*SystemProxyStatus, error)
|
SystemProxyStatus() (*SystemProxyStatus, error)
|
||||||
SetSystemProxyEnabled(enabled bool) error
|
SetSystemProxyEnabled(enabled bool) error
|
||||||
|
TriggerNativeCrash() error
|
||||||
WriteDebugMessage(message string)
|
WriteDebugMessage(message string)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,14 +6,20 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/common/dialer"
|
||||||
|
"github.com/sagernet/sing-box/common/networkquality"
|
||||||
|
"github.com/sagernet/sing-box/common/stun"
|
||||||
"github.com/sagernet/sing-box/common/urltest"
|
"github.com/sagernet/sing-box/common/urltest"
|
||||||
|
C "github.com/sagernet/sing-box/constant"
|
||||||
"github.com/sagernet/sing-box/experimental/clashapi"
|
"github.com/sagernet/sing-box/experimental/clashapi"
|
||||||
"github.com/sagernet/sing-box/experimental/clashapi/trafficontrol"
|
"github.com/sagernet/sing-box/experimental/clashapi/trafficontrol"
|
||||||
"github.com/sagernet/sing-box/experimental/deprecated"
|
"github.com/sagernet/sing-box/experimental/deprecated"
|
||||||
"github.com/sagernet/sing-box/log"
|
"github.com/sagernet/sing-box/log"
|
||||||
"github.com/sagernet/sing-box/protocol/group"
|
"github.com/sagernet/sing-box/protocol/group"
|
||||||
|
"github.com/sagernet/sing-box/service/oomkiller"
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
"github.com/sagernet/sing/common/batch"
|
"github.com/sagernet/sing/common/batch"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
@@ -24,6 +30,8 @@ import (
|
|||||||
|
|
||||||
"github.com/gofrs/uuid/v5"
|
"github.com/gofrs/uuid/v5"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
"google.golang.org/protobuf/types/known/emptypb"
|
"google.golang.org/protobuf/types/known/emptypb"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -32,10 +40,12 @@ var _ StartedServiceServer = (*StartedService)(nil)
|
|||||||
type StartedService struct {
|
type StartedService struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
// platform adapter.PlatformInterface
|
// platform adapter.PlatformInterface
|
||||||
handler PlatformHandler
|
handler PlatformHandler
|
||||||
debug bool
|
debug bool
|
||||||
logMaxLines int
|
logMaxLines int
|
||||||
oomKiller bool
|
oomKillerEnabled bool
|
||||||
|
oomKillerDisabled bool
|
||||||
|
oomMemoryLimit uint64
|
||||||
// workingDirectory string
|
// workingDirectory string
|
||||||
// tempDirectory string
|
// tempDirectory string
|
||||||
// userID int
|
// userID int
|
||||||
@@ -64,10 +74,12 @@ type StartedService struct {
|
|||||||
type ServiceOptions struct {
|
type ServiceOptions struct {
|
||||||
Context context.Context
|
Context context.Context
|
||||||
// Platform adapter.PlatformInterface
|
// Platform adapter.PlatformInterface
|
||||||
Handler PlatformHandler
|
Handler PlatformHandler
|
||||||
Debug bool
|
Debug bool
|
||||||
LogMaxLines int
|
LogMaxLines int
|
||||||
OOMKiller bool
|
OOMKillerEnabled bool
|
||||||
|
OOMKillerDisabled bool
|
||||||
|
OOMMemoryLimit uint64
|
||||||
// WorkingDirectory string
|
// WorkingDirectory string
|
||||||
// TempDirectory string
|
// TempDirectory string
|
||||||
// UserID int
|
// UserID int
|
||||||
@@ -79,10 +91,12 @@ func NewStartedService(options ServiceOptions) *StartedService {
|
|||||||
s := &StartedService{
|
s := &StartedService{
|
||||||
ctx: options.Context,
|
ctx: options.Context,
|
||||||
// platform: options.Platform,
|
// platform: options.Platform,
|
||||||
handler: options.Handler,
|
handler: options.Handler,
|
||||||
debug: options.Debug,
|
debug: options.Debug,
|
||||||
logMaxLines: options.LogMaxLines,
|
logMaxLines: options.LogMaxLines,
|
||||||
oomKiller: options.OOMKiller,
|
oomKillerEnabled: options.OOMKillerEnabled,
|
||||||
|
oomKillerDisabled: options.OOMKillerDisabled,
|
||||||
|
oomMemoryLimit: options.OOMMemoryLimit,
|
||||||
// workingDirectory: options.WorkingDirectory,
|
// workingDirectory: options.WorkingDirectory,
|
||||||
// tempDirectory: options.TempDirectory,
|
// tempDirectory: options.TempDirectory,
|
||||||
// userID: options.UserID,
|
// userID: options.UserID,
|
||||||
@@ -168,7 +182,7 @@ func (s *StartedService) waitForStarted(ctx context.Context) error {
|
|||||||
func (s *StartedService) StartOrReloadService(profileContent string, options *OverrideOptions) error {
|
func (s *StartedService) StartOrReloadService(profileContent string, options *OverrideOptions) error {
|
||||||
s.serviceAccess.Lock()
|
s.serviceAccess.Lock()
|
||||||
switch s.serviceStatus.Status {
|
switch s.serviceStatus.Status {
|
||||||
case ServiceStatus_IDLE, ServiceStatus_STARTED, ServiceStatus_STARTING:
|
case ServiceStatus_IDLE, ServiceStatus_STARTED, ServiceStatus_STARTING, ServiceStatus_FATAL:
|
||||||
default:
|
default:
|
||||||
s.serviceAccess.Unlock()
|
s.serviceAccess.Unlock()
|
||||||
return os.ErrInvalid
|
return os.ErrInvalid
|
||||||
@@ -226,13 +240,14 @@ func (s *StartedService) CloseService() error {
|
|||||||
return os.ErrInvalid
|
return os.ErrInvalid
|
||||||
}
|
}
|
||||||
s.updateStatus(ServiceStatus_STOPPING)
|
s.updateStatus(ServiceStatus_STOPPING)
|
||||||
if s.instance != nil {
|
instance := s.instance
|
||||||
err := s.instance.Close()
|
s.instance = nil
|
||||||
|
if instance != nil {
|
||||||
|
err := instance.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return s.updateStatusError(err)
|
return s.updateStatusError(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
s.instance = nil
|
|
||||||
s.startedAt = time.Time{}
|
s.startedAt = time.Time{}
|
||||||
s.updateStatus(ServiceStatus_IDLE)
|
s.updateStatus(ServiceStatus_IDLE)
|
||||||
s.serviceAccess.Unlock()
|
s.serviceAccess.Unlock()
|
||||||
@@ -681,7 +696,42 @@ func (s *StartedService) SetSystemProxyEnabled(ctx context.Context, request *Set
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return nil, err
|
return &emptypb.Empty{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StartedService) TriggerDebugCrash(ctx context.Context, request *DebugCrashRequest) (*emptypb.Empty, error) {
|
||||||
|
if !s.debug {
|
||||||
|
return nil, status.Error(codes.PermissionDenied, "debug crash trigger unavailable")
|
||||||
|
}
|
||||||
|
if request == nil {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "missing debug crash request")
|
||||||
|
}
|
||||||
|
switch request.Type {
|
||||||
|
case DebugCrashRequest_GO:
|
||||||
|
time.AfterFunc(200*time.Millisecond, func() {
|
||||||
|
*(*int)(unsafe.Pointer(uintptr(0))) = 0
|
||||||
|
})
|
||||||
|
case DebugCrashRequest_NATIVE:
|
||||||
|
err := s.handler.TriggerNativeCrash()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "unknown debug crash type")
|
||||||
|
}
|
||||||
|
return &emptypb.Empty{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StartedService) TriggerOOMReport(ctx context.Context, _ *emptypb.Empty) (*emptypb.Empty, error) {
|
||||||
|
instance := s.Instance()
|
||||||
|
if instance == nil {
|
||||||
|
return nil, status.Error(codes.FailedPrecondition, "service not started")
|
||||||
|
}
|
||||||
|
reporter := service.FromContext[oomkiller.OOMReporter](instance.ctx)
|
||||||
|
if reporter == nil {
|
||||||
|
return nil, status.Error(codes.Unavailable, "OOM reporter not available")
|
||||||
|
}
|
||||||
|
return &emptypb.Empty{}, reporter.WriteReport(memory.Total())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *StartedService) SubscribeConnections(request *SubscribeConnectionsRequest, server grpc.ServerStreamingServer[ConnectionEvents]) error {
|
func (s *StartedService) SubscribeConnections(request *SubscribeConnectionsRequest, server grpc.ServerStreamingServer[ConnectionEvents]) error {
|
||||||
@@ -949,11 +999,11 @@ func buildConnectionProto(metadata *trafficontrol.TrackerMetadata) *Connection {
|
|||||||
var processInfo *ProcessInfo
|
var processInfo *ProcessInfo
|
||||||
if metadata.Metadata.ProcessInfo != nil {
|
if metadata.Metadata.ProcessInfo != nil {
|
||||||
processInfo = &ProcessInfo{
|
processInfo = &ProcessInfo{
|
||||||
ProcessId: metadata.Metadata.ProcessInfo.ProcessID,
|
ProcessId: metadata.Metadata.ProcessInfo.ProcessID,
|
||||||
UserId: metadata.Metadata.ProcessInfo.UserId,
|
UserId: metadata.Metadata.ProcessInfo.UserId,
|
||||||
UserName: metadata.Metadata.ProcessInfo.UserName,
|
UserName: metadata.Metadata.ProcessInfo.UserName,
|
||||||
ProcessPath: metadata.Metadata.ProcessInfo.ProcessPath,
|
ProcessPath: metadata.Metadata.ProcessInfo.ProcessPath,
|
||||||
PackageName: metadata.Metadata.ProcessInfo.AndroidPackageName,
|
PackageNames: metadata.Metadata.ProcessInfo.AndroidPackageNames,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return &Connection{
|
return &Connection{
|
||||||
@@ -1018,9 +1068,12 @@ func (s *StartedService) GetDeprecatedWarnings(ctx context.Context, empty *empty
|
|||||||
return &DeprecatedWarnings{
|
return &DeprecatedWarnings{
|
||||||
Warnings: common.Map(notes, func(it deprecated.Note) *DeprecatedWarning {
|
Warnings: common.Map(notes, func(it deprecated.Note) *DeprecatedWarning {
|
||||||
return &DeprecatedWarning{
|
return &DeprecatedWarning{
|
||||||
Message: it.Message(),
|
Message: it.Message(),
|
||||||
Impending: it.Impending(),
|
Impending: it.Impending(),
|
||||||
MigrationLink: it.MigrationLink,
|
MigrationLink: it.MigrationLink,
|
||||||
|
Description: it.Description,
|
||||||
|
DeprecatedVersion: it.DeprecatedVersion,
|
||||||
|
ScheduledVersion: it.ScheduledVersion,
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
}, nil
|
}, nil
|
||||||
@@ -1032,6 +1085,386 @@ func (s *StartedService) GetStartedAt(ctx context.Context, empty *emptypb.Empty)
|
|||||||
return &StartedAt{StartedAt: s.startedAt.UnixMilli()}, nil
|
return &StartedAt{StartedAt: s.startedAt.UnixMilli()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *StartedService) SubscribeOutbounds(_ *emptypb.Empty, server grpc.ServerStreamingServer[OutboundList]) error {
|
||||||
|
err := s.waitForStarted(server.Context())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
subscription, done, err := s.urlTestObserver.Subscribe()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer s.urlTestObserver.UnSubscribe(subscription)
|
||||||
|
for {
|
||||||
|
s.serviceAccess.RLock()
|
||||||
|
if s.serviceStatus.Status != ServiceStatus_STARTED {
|
||||||
|
s.serviceAccess.RUnlock()
|
||||||
|
return os.ErrInvalid
|
||||||
|
}
|
||||||
|
boxService := s.instance
|
||||||
|
s.serviceAccess.RUnlock()
|
||||||
|
historyStorage := boxService.urlTestHistoryStorage
|
||||||
|
var list OutboundList
|
||||||
|
for _, ob := range boxService.instance.Outbound().Outbounds() {
|
||||||
|
item := &GroupItem{
|
||||||
|
Tag: ob.Tag(),
|
||||||
|
Type: ob.Type(),
|
||||||
|
}
|
||||||
|
if history := historyStorage.LoadURLTestHistory(adapter.OutboundTag(ob)); history != nil {
|
||||||
|
item.UrlTestTime = history.Time.Unix()
|
||||||
|
item.UrlTestDelay = int32(history.Delay)
|
||||||
|
}
|
||||||
|
list.Outbounds = append(list.Outbounds, item)
|
||||||
|
}
|
||||||
|
for _, ep := range boxService.instance.Endpoint().Endpoints() {
|
||||||
|
item := &GroupItem{
|
||||||
|
Tag: ep.Tag(),
|
||||||
|
Type: ep.Type(),
|
||||||
|
}
|
||||||
|
if history := historyStorage.LoadURLTestHistory(adapter.OutboundTag(ep)); history != nil {
|
||||||
|
item.UrlTestTime = history.Time.Unix()
|
||||||
|
item.UrlTestDelay = int32(history.Delay)
|
||||||
|
}
|
||||||
|
list.Outbounds = append(list.Outbounds, item)
|
||||||
|
}
|
||||||
|
err = server.Send(&list)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-subscription:
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
return s.ctx.Err()
|
||||||
|
case <-server.Context().Done():
|
||||||
|
return server.Context().Err()
|
||||||
|
case <-done:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveOutbound(instance *Instance, tag string) (adapter.Outbound, error) {
|
||||||
|
if tag == "" {
|
||||||
|
return instance.instance.Outbound().Default(), nil
|
||||||
|
}
|
||||||
|
outbound, loaded := instance.instance.Outbound().Outbound(tag)
|
||||||
|
if !loaded {
|
||||||
|
return nil, E.New("outbound not found: ", tag)
|
||||||
|
}
|
||||||
|
return outbound, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StartedService) StartNetworkQualityTest(
|
||||||
|
request *NetworkQualityTestRequest,
|
||||||
|
server grpc.ServerStreamingServer[NetworkQualityTestProgress],
|
||||||
|
) error {
|
||||||
|
err := s.waitForStarted(server.Context())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.serviceAccess.RLock()
|
||||||
|
boxService := s.instance
|
||||||
|
s.serviceAccess.RUnlock()
|
||||||
|
|
||||||
|
outbound, err := resolveOutbound(boxService, request.OutboundTag)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
resolvedDialer := dialer.NewResolveDialer(boxService.ctx, outbound, true, "", adapter.DNSQueryOptions{}, 0)
|
||||||
|
httpClient := networkquality.NewHTTPClient(resolvedDialer)
|
||||||
|
defer httpClient.CloseIdleConnections()
|
||||||
|
|
||||||
|
measurementClientFactory, err := networkquality.NewOptionalHTTP3Factory(resolvedDialer, request.Http3)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
result, nqErr := networkquality.Run(networkquality.Options{
|
||||||
|
ConfigURL: request.ConfigURL,
|
||||||
|
HTTPClient: httpClient,
|
||||||
|
NewMeasurementClient: measurementClientFactory,
|
||||||
|
Serial: request.Serial,
|
||||||
|
MaxRuntime: time.Duration(request.MaxRuntimeSeconds) * time.Second,
|
||||||
|
Context: server.Context(),
|
||||||
|
OnProgress: func(p networkquality.Progress) {
|
||||||
|
_ = server.Send(&NetworkQualityTestProgress{
|
||||||
|
Phase: int32(p.Phase),
|
||||||
|
DownloadCapacity: p.DownloadCapacity,
|
||||||
|
UploadCapacity: p.UploadCapacity,
|
||||||
|
DownloadRPM: p.DownloadRPM,
|
||||||
|
UploadRPM: p.UploadRPM,
|
||||||
|
IdleLatencyMs: p.IdleLatencyMs,
|
||||||
|
ElapsedMs: p.ElapsedMs,
|
||||||
|
DownloadCapacityAccuracy: int32(p.DownloadCapacityAccuracy),
|
||||||
|
UploadCapacityAccuracy: int32(p.UploadCapacityAccuracy),
|
||||||
|
DownloadRPMAccuracy: int32(p.DownloadRPMAccuracy),
|
||||||
|
UploadRPMAccuracy: int32(p.UploadRPMAccuracy),
|
||||||
|
})
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if nqErr != nil {
|
||||||
|
return server.Send(&NetworkQualityTestProgress{
|
||||||
|
IsFinal: true,
|
||||||
|
Error: nqErr.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return server.Send(&NetworkQualityTestProgress{
|
||||||
|
Phase: int32(networkquality.PhaseDone),
|
||||||
|
DownloadCapacity: result.DownloadCapacity,
|
||||||
|
UploadCapacity: result.UploadCapacity,
|
||||||
|
DownloadRPM: result.DownloadRPM,
|
||||||
|
UploadRPM: result.UploadRPM,
|
||||||
|
IdleLatencyMs: result.IdleLatencyMs,
|
||||||
|
IsFinal: true,
|
||||||
|
DownloadCapacityAccuracy: int32(result.DownloadCapacityAccuracy),
|
||||||
|
UploadCapacityAccuracy: int32(result.UploadCapacityAccuracy),
|
||||||
|
DownloadRPMAccuracy: int32(result.DownloadRPMAccuracy),
|
||||||
|
UploadRPMAccuracy: int32(result.UploadRPMAccuracy),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StartedService) StartSTUNTest(
|
||||||
|
request *STUNTestRequest,
|
||||||
|
server grpc.ServerStreamingServer[STUNTestProgress],
|
||||||
|
) error {
|
||||||
|
err := s.waitForStarted(server.Context())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.serviceAccess.RLock()
|
||||||
|
boxService := s.instance
|
||||||
|
s.serviceAccess.RUnlock()
|
||||||
|
|
||||||
|
outbound, err := resolveOutbound(boxService, request.OutboundTag)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
resolvedDialer := dialer.NewResolveDialer(boxService.ctx, outbound, true, "", adapter.DNSQueryOptions{}, 0)
|
||||||
|
|
||||||
|
result, stunErr := stun.Run(stun.Options{
|
||||||
|
Server: request.Server,
|
||||||
|
Dialer: resolvedDialer,
|
||||||
|
Context: server.Context(),
|
||||||
|
OnProgress: func(p stun.Progress) {
|
||||||
|
_ = server.Send(&STUNTestProgress{
|
||||||
|
Phase: int32(p.Phase),
|
||||||
|
ExternalAddr: p.ExternalAddr,
|
||||||
|
LatencyMs: p.LatencyMs,
|
||||||
|
NatMapping: int32(p.NATMapping),
|
||||||
|
NatFiltering: int32(p.NATFiltering),
|
||||||
|
})
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if stunErr != nil {
|
||||||
|
return server.Send(&STUNTestProgress{
|
||||||
|
IsFinal: true,
|
||||||
|
Error: stunErr.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return server.Send(&STUNTestProgress{
|
||||||
|
Phase: int32(stun.PhaseDone),
|
||||||
|
ExternalAddr: result.ExternalAddr,
|
||||||
|
LatencyMs: result.LatencyMs,
|
||||||
|
NatMapping: int32(result.NATMapping),
|
||||||
|
NatFiltering: int32(result.NATFiltering),
|
||||||
|
IsFinal: true,
|
||||||
|
NatTypeSupported: result.NATTypeSupported,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StartedService) SubscribeTailscaleStatus(
|
||||||
|
_ *emptypb.Empty,
|
||||||
|
server grpc.ServerStreamingServer[TailscaleStatusUpdate],
|
||||||
|
) error {
|
||||||
|
err := s.waitForStarted(server.Context())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.serviceAccess.RLock()
|
||||||
|
boxService := s.instance
|
||||||
|
s.serviceAccess.RUnlock()
|
||||||
|
|
||||||
|
endpointManager := service.FromContext[adapter.EndpointManager](boxService.ctx)
|
||||||
|
if endpointManager == nil {
|
||||||
|
return status.Error(codes.FailedPrecondition, "endpoint manager not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
type tailscaleEndpoint struct {
|
||||||
|
tag string
|
||||||
|
provider adapter.TailscaleEndpoint
|
||||||
|
}
|
||||||
|
var endpoints []tailscaleEndpoint
|
||||||
|
for _, endpoint := range endpointManager.Endpoints() {
|
||||||
|
if endpoint.Type() != C.TypeTailscale {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
provider, loaded := endpoint.(adapter.TailscaleEndpoint)
|
||||||
|
if !loaded {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
endpoints = append(endpoints, tailscaleEndpoint{
|
||||||
|
tag: endpoint.Tag(),
|
||||||
|
provider: provider,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if len(endpoints) == 0 {
|
||||||
|
return status.Error(codes.NotFound, "no Tailscale endpoint found")
|
||||||
|
}
|
||||||
|
|
||||||
|
type taggedStatus struct {
|
||||||
|
tag string
|
||||||
|
status *adapter.TailscaleEndpointStatus
|
||||||
|
}
|
||||||
|
updates := make(chan taggedStatus, len(endpoints))
|
||||||
|
ctx, cancel := context.WithCancel(server.Context())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
var waitGroup sync.WaitGroup
|
||||||
|
for _, endpoint := range endpoints {
|
||||||
|
waitGroup.Add(1)
|
||||||
|
go func(tag string, provider adapter.TailscaleEndpoint) {
|
||||||
|
defer waitGroup.Done()
|
||||||
|
_ = provider.SubscribeTailscaleStatus(ctx, func(endpointStatus *adapter.TailscaleEndpointStatus) {
|
||||||
|
select {
|
||||||
|
case updates <- taggedStatus{tag: tag, status: endpointStatus}:
|
||||||
|
case <-ctx.Done():
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}(endpoint.tag, endpoint.provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
waitGroup.Wait()
|
||||||
|
close(updates)
|
||||||
|
}()
|
||||||
|
|
||||||
|
var tags []string
|
||||||
|
statuses := make(map[string]*adapter.TailscaleEndpointStatus, len(endpoints))
|
||||||
|
for update := range updates {
|
||||||
|
if _, exists := statuses[update.tag]; !exists {
|
||||||
|
tags = append(tags, update.tag)
|
||||||
|
}
|
||||||
|
statuses[update.tag] = update.status
|
||||||
|
protoEndpoints := make([]*TailscaleEndpointStatus, 0, len(statuses))
|
||||||
|
for _, tag := range tags {
|
||||||
|
protoEndpoints = append(protoEndpoints, tailscaleEndpointStatusToProto(tag, statuses[tag]))
|
||||||
|
}
|
||||||
|
sendErr := server.Send(&TailscaleStatusUpdate{
|
||||||
|
Endpoints: protoEndpoints,
|
||||||
|
})
|
||||||
|
if sendErr != nil {
|
||||||
|
return sendErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func tailscaleEndpointStatusToProto(tag string, s *adapter.TailscaleEndpointStatus) *TailscaleEndpointStatus {
|
||||||
|
userGroups := make([]*TailscaleUserGroup, len(s.UserGroups))
|
||||||
|
for i, group := range s.UserGroups {
|
||||||
|
peers := make([]*TailscalePeer, len(group.Peers))
|
||||||
|
for j, peer := range group.Peers {
|
||||||
|
peers[j] = tailscalePeerToProto(peer)
|
||||||
|
}
|
||||||
|
userGroups[i] = &TailscaleUserGroup{
|
||||||
|
UserID: group.UserID,
|
||||||
|
LoginName: group.LoginName,
|
||||||
|
DisplayName: group.DisplayName,
|
||||||
|
ProfilePicURL: group.ProfilePicURL,
|
||||||
|
Peers: peers,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result := &TailscaleEndpointStatus{
|
||||||
|
EndpointTag: tag,
|
||||||
|
BackendState: s.BackendState,
|
||||||
|
AuthURL: s.AuthURL,
|
||||||
|
NetworkName: s.NetworkName,
|
||||||
|
MagicDNSSuffix: s.MagicDNSSuffix,
|
||||||
|
UserGroups: userGroups,
|
||||||
|
}
|
||||||
|
if s.Self != nil {
|
||||||
|
result.Self = tailscalePeerToProto(s.Self)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func tailscalePeerToProto(peer *adapter.TailscalePeer) *TailscalePeer {
|
||||||
|
return &TailscalePeer{
|
||||||
|
HostName: peer.HostName,
|
||||||
|
DnsName: peer.DNSName,
|
||||||
|
Os: peer.OS,
|
||||||
|
TailscaleIPs: peer.TailscaleIPs,
|
||||||
|
Online: peer.Online,
|
||||||
|
ExitNode: peer.ExitNode,
|
||||||
|
ExitNodeOption: peer.ExitNodeOption,
|
||||||
|
Active: peer.Active,
|
||||||
|
RxBytes: peer.RxBytes,
|
||||||
|
TxBytes: peer.TxBytes,
|
||||||
|
KeyExpiry: peer.KeyExpiry,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StartedService) StartTailscalePing(
|
||||||
|
request *TailscalePingRequest,
|
||||||
|
server grpc.ServerStreamingServer[TailscalePingResponse],
|
||||||
|
) error {
|
||||||
|
err := s.waitForStarted(server.Context())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.serviceAccess.RLock()
|
||||||
|
boxService := s.instance
|
||||||
|
s.serviceAccess.RUnlock()
|
||||||
|
|
||||||
|
endpointManager := service.FromContext[adapter.EndpointManager](boxService.ctx)
|
||||||
|
if endpointManager == nil {
|
||||||
|
return status.Error(codes.FailedPrecondition, "endpoint manager not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
var provider adapter.TailscaleEndpoint
|
||||||
|
if request.EndpointTag != "" {
|
||||||
|
endpoint, loaded := endpointManager.Get(request.EndpointTag)
|
||||||
|
if !loaded {
|
||||||
|
return status.Error(codes.NotFound, "endpoint not found: "+request.EndpointTag)
|
||||||
|
}
|
||||||
|
if endpoint.Type() != C.TypeTailscale {
|
||||||
|
return status.Error(codes.InvalidArgument, "endpoint is not Tailscale: "+request.EndpointTag)
|
||||||
|
}
|
||||||
|
pingProvider, loaded := endpoint.(adapter.TailscaleEndpoint)
|
||||||
|
if !loaded {
|
||||||
|
return status.Error(codes.FailedPrecondition, "endpoint does not support ping")
|
||||||
|
}
|
||||||
|
provider = pingProvider
|
||||||
|
} else {
|
||||||
|
for _, endpoint := range endpointManager.Endpoints() {
|
||||||
|
if endpoint.Type() != C.TypeTailscale {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
pingProvider, loaded := endpoint.(adapter.TailscaleEndpoint)
|
||||||
|
if loaded {
|
||||||
|
provider = pingProvider
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if provider == nil {
|
||||||
|
return status.Error(codes.NotFound, "no Tailscale endpoint found")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return provider.StartTailscalePing(server.Context(), request.PeerIP, func(result *adapter.TailscalePingResult) {
|
||||||
|
_ = server.Send(&TailscalePingResponse{
|
||||||
|
LatencyMs: result.LatencyMs,
|
||||||
|
IsDirect: result.IsDirect,
|
||||||
|
Endpoint: result.Endpoint,
|
||||||
|
DerpRegionID: result.DERPRegionID,
|
||||||
|
DerpRegionCode: result.DERPRegionCode,
|
||||||
|
Error: result.Error,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (s *StartedService) mustEmbedUnimplementedStartedServiceServer() {
|
func (s *StartedService) mustEmbedUnimplementedStartedServiceServer() {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -26,12 +26,20 @@ service StartedService {
|
|||||||
|
|
||||||
rpc GetSystemProxyStatus(google.protobuf.Empty) returns(SystemProxyStatus) {}
|
rpc GetSystemProxyStatus(google.protobuf.Empty) returns(SystemProxyStatus) {}
|
||||||
rpc SetSystemProxyEnabled(SetSystemProxyEnabledRequest) returns(google.protobuf.Empty) {}
|
rpc SetSystemProxyEnabled(SetSystemProxyEnabledRequest) returns(google.protobuf.Empty) {}
|
||||||
|
rpc TriggerDebugCrash(DebugCrashRequest) returns(google.protobuf.Empty) {}
|
||||||
|
rpc TriggerOOMReport(google.protobuf.Empty) returns(google.protobuf.Empty) {}
|
||||||
|
|
||||||
rpc SubscribeConnections(SubscribeConnectionsRequest) returns(stream ConnectionEvents) {}
|
rpc SubscribeConnections(SubscribeConnectionsRequest) returns(stream ConnectionEvents) {}
|
||||||
rpc CloseConnection(CloseConnectionRequest) returns(google.protobuf.Empty) {}
|
rpc CloseConnection(CloseConnectionRequest) returns(google.protobuf.Empty) {}
|
||||||
rpc CloseAllConnections(google.protobuf.Empty) returns(google.protobuf.Empty) {}
|
rpc CloseAllConnections(google.protobuf.Empty) returns(google.protobuf.Empty) {}
|
||||||
rpc GetDeprecatedWarnings(google.protobuf.Empty) returns(DeprecatedWarnings) {}
|
rpc GetDeprecatedWarnings(google.protobuf.Empty) returns(DeprecatedWarnings) {}
|
||||||
rpc GetStartedAt(google.protobuf.Empty) returns(StartedAt) {}
|
rpc GetStartedAt(google.protobuf.Empty) returns(StartedAt) {}
|
||||||
|
|
||||||
|
rpc SubscribeOutbounds(google.protobuf.Empty) returns (stream OutboundList) {}
|
||||||
|
rpc StartNetworkQualityTest(NetworkQualityTestRequest) returns (stream NetworkQualityTestProgress) {}
|
||||||
|
rpc StartSTUNTest(STUNTestRequest) returns (stream STUNTestProgress) {}
|
||||||
|
rpc SubscribeTailscaleStatus(google.protobuf.Empty) returns (stream TailscaleStatusUpdate) {}
|
||||||
|
rpc StartTailscalePing(TailscalePingRequest) returns (stream TailscalePingResponse) {}
|
||||||
}
|
}
|
||||||
|
|
||||||
message ServiceStatus {
|
message ServiceStatus {
|
||||||
@@ -141,6 +149,15 @@ message SetSystemProxyEnabledRequest {
|
|||||||
bool enabled = 1;
|
bool enabled = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message DebugCrashRequest {
|
||||||
|
enum Type {
|
||||||
|
GO = 0;
|
||||||
|
NATIVE = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
Type type = 1;
|
||||||
|
}
|
||||||
|
|
||||||
message SubscribeConnectionsRequest {
|
message SubscribeConnectionsRequest {
|
||||||
int64 interval = 1;
|
int64 interval = 1;
|
||||||
}
|
}
|
||||||
@@ -195,7 +212,7 @@ message ProcessInfo {
|
|||||||
int32 userId = 2;
|
int32 userId = 2;
|
||||||
string userName = 3;
|
string userName = 3;
|
||||||
string processPath = 4;
|
string processPath = 4;
|
||||||
string packageName = 5;
|
repeated string packageNames = 5;
|
||||||
}
|
}
|
||||||
|
|
||||||
message CloseConnectionRequest {
|
message CloseConnectionRequest {
|
||||||
@@ -210,8 +227,105 @@ message DeprecatedWarning {
|
|||||||
string message = 1;
|
string message = 1;
|
||||||
bool impending = 2;
|
bool impending = 2;
|
||||||
string migrationLink = 3;
|
string migrationLink = 3;
|
||||||
|
string description = 4;
|
||||||
|
string deprecatedVersion = 5;
|
||||||
|
string scheduledVersion = 6;
|
||||||
}
|
}
|
||||||
|
|
||||||
message StartedAt {
|
message StartedAt {
|
||||||
int64 startedAt = 1;
|
int64 startedAt = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message OutboundList {
|
||||||
|
repeated GroupItem outbounds = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message NetworkQualityTestRequest {
|
||||||
|
string configURL = 1;
|
||||||
|
string outboundTag = 2;
|
||||||
|
bool serial = 3;
|
||||||
|
int32 maxRuntimeSeconds = 4;
|
||||||
|
bool http3 = 5;
|
||||||
|
}
|
||||||
|
|
||||||
|
message NetworkQualityTestProgress {
|
||||||
|
int32 phase = 1;
|
||||||
|
int64 downloadCapacity = 2;
|
||||||
|
int64 uploadCapacity = 3;
|
||||||
|
int32 downloadRPM = 4;
|
||||||
|
int32 uploadRPM = 5;
|
||||||
|
int32 idleLatencyMs = 6;
|
||||||
|
int64 elapsedMs = 7;
|
||||||
|
bool isFinal = 8;
|
||||||
|
string error = 9;
|
||||||
|
int32 downloadCapacityAccuracy = 10;
|
||||||
|
int32 uploadCapacityAccuracy = 11;
|
||||||
|
int32 downloadRPMAccuracy = 12;
|
||||||
|
int32 uploadRPMAccuracy = 13;
|
||||||
|
}
|
||||||
|
|
||||||
|
message STUNTestRequest {
|
||||||
|
string server = 1;
|
||||||
|
string outboundTag = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message STUNTestProgress {
|
||||||
|
int32 phase = 1;
|
||||||
|
string externalAddr = 2;
|
||||||
|
int32 latencyMs = 3;
|
||||||
|
int32 natMapping = 4;
|
||||||
|
int32 natFiltering = 5;
|
||||||
|
bool isFinal = 6;
|
||||||
|
string error = 7;
|
||||||
|
bool natTypeSupported = 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TailscaleStatusUpdate {
|
||||||
|
repeated TailscaleEndpointStatus endpoints = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TailscaleEndpointStatus {
|
||||||
|
string endpointTag = 1;
|
||||||
|
string backendState = 2;
|
||||||
|
string authURL = 3;
|
||||||
|
string networkName = 4;
|
||||||
|
string magicDNSSuffix = 5;
|
||||||
|
TailscalePeer self = 6;
|
||||||
|
repeated TailscaleUserGroup userGroups = 7;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TailscaleUserGroup {
|
||||||
|
int64 userID = 1;
|
||||||
|
string loginName = 2;
|
||||||
|
string displayName = 3;
|
||||||
|
string profilePicURL = 4;
|
||||||
|
repeated TailscalePeer peers = 5;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TailscalePeer {
|
||||||
|
string hostName = 1;
|
||||||
|
string dnsName = 2;
|
||||||
|
string os = 3;
|
||||||
|
repeated string tailscaleIPs = 4;
|
||||||
|
bool online = 5;
|
||||||
|
bool exitNode = 6;
|
||||||
|
bool exitNodeOption = 7;
|
||||||
|
bool active = 8;
|
||||||
|
int64 rxBytes = 9;
|
||||||
|
int64 txBytes = 10;
|
||||||
|
int64 keyExpiry = 11;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TailscalePingRequest {
|
||||||
|
string endpointTag = 1;
|
||||||
|
string peerIP = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TailscalePingResponse {
|
||||||
|
double latencyMs = 1;
|
||||||
|
bool isDirect = 2;
|
||||||
|
string endpoint = 3;
|
||||||
|
int32 derpRegionID = 4;
|
||||||
|
string derpRegionCode = 5;
|
||||||
|
string error = 6;
|
||||||
|
}
|
||||||
|
|||||||
@@ -15,27 +15,34 @@ import (
|
|||||||
const _ = grpc.SupportPackageIsVersion9
|
const _ = grpc.SupportPackageIsVersion9
|
||||||
|
|
||||||
const (
|
const (
|
||||||
StartedService_StopService_FullMethodName = "/daemon.StartedService/StopService"
|
StartedService_StopService_FullMethodName = "/daemon.StartedService/StopService"
|
||||||
StartedService_ReloadService_FullMethodName = "/daemon.StartedService/ReloadService"
|
StartedService_ReloadService_FullMethodName = "/daemon.StartedService/ReloadService"
|
||||||
StartedService_SubscribeServiceStatus_FullMethodName = "/daemon.StartedService/SubscribeServiceStatus"
|
StartedService_SubscribeServiceStatus_FullMethodName = "/daemon.StartedService/SubscribeServiceStatus"
|
||||||
StartedService_SubscribeLog_FullMethodName = "/daemon.StartedService/SubscribeLog"
|
StartedService_SubscribeLog_FullMethodName = "/daemon.StartedService/SubscribeLog"
|
||||||
StartedService_GetDefaultLogLevel_FullMethodName = "/daemon.StartedService/GetDefaultLogLevel"
|
StartedService_GetDefaultLogLevel_FullMethodName = "/daemon.StartedService/GetDefaultLogLevel"
|
||||||
StartedService_ClearLogs_FullMethodName = "/daemon.StartedService/ClearLogs"
|
StartedService_ClearLogs_FullMethodName = "/daemon.StartedService/ClearLogs"
|
||||||
StartedService_SubscribeStatus_FullMethodName = "/daemon.StartedService/SubscribeStatus"
|
StartedService_SubscribeStatus_FullMethodName = "/daemon.StartedService/SubscribeStatus"
|
||||||
StartedService_SubscribeGroups_FullMethodName = "/daemon.StartedService/SubscribeGroups"
|
StartedService_SubscribeGroups_FullMethodName = "/daemon.StartedService/SubscribeGroups"
|
||||||
StartedService_GetClashModeStatus_FullMethodName = "/daemon.StartedService/GetClashModeStatus"
|
StartedService_GetClashModeStatus_FullMethodName = "/daemon.StartedService/GetClashModeStatus"
|
||||||
StartedService_SubscribeClashMode_FullMethodName = "/daemon.StartedService/SubscribeClashMode"
|
StartedService_SubscribeClashMode_FullMethodName = "/daemon.StartedService/SubscribeClashMode"
|
||||||
StartedService_SetClashMode_FullMethodName = "/daemon.StartedService/SetClashMode"
|
StartedService_SetClashMode_FullMethodName = "/daemon.StartedService/SetClashMode"
|
||||||
StartedService_URLTest_FullMethodName = "/daemon.StartedService/URLTest"
|
StartedService_URLTest_FullMethodName = "/daemon.StartedService/URLTest"
|
||||||
StartedService_SelectOutbound_FullMethodName = "/daemon.StartedService/SelectOutbound"
|
StartedService_SelectOutbound_FullMethodName = "/daemon.StartedService/SelectOutbound"
|
||||||
StartedService_SetGroupExpand_FullMethodName = "/daemon.StartedService/SetGroupExpand"
|
StartedService_SetGroupExpand_FullMethodName = "/daemon.StartedService/SetGroupExpand"
|
||||||
StartedService_GetSystemProxyStatus_FullMethodName = "/daemon.StartedService/GetSystemProxyStatus"
|
StartedService_GetSystemProxyStatus_FullMethodName = "/daemon.StartedService/GetSystemProxyStatus"
|
||||||
StartedService_SetSystemProxyEnabled_FullMethodName = "/daemon.StartedService/SetSystemProxyEnabled"
|
StartedService_SetSystemProxyEnabled_FullMethodName = "/daemon.StartedService/SetSystemProxyEnabled"
|
||||||
StartedService_SubscribeConnections_FullMethodName = "/daemon.StartedService/SubscribeConnections"
|
StartedService_TriggerDebugCrash_FullMethodName = "/daemon.StartedService/TriggerDebugCrash"
|
||||||
StartedService_CloseConnection_FullMethodName = "/daemon.StartedService/CloseConnection"
|
StartedService_TriggerOOMReport_FullMethodName = "/daemon.StartedService/TriggerOOMReport"
|
||||||
StartedService_CloseAllConnections_FullMethodName = "/daemon.StartedService/CloseAllConnections"
|
StartedService_SubscribeConnections_FullMethodName = "/daemon.StartedService/SubscribeConnections"
|
||||||
StartedService_GetDeprecatedWarnings_FullMethodName = "/daemon.StartedService/GetDeprecatedWarnings"
|
StartedService_CloseConnection_FullMethodName = "/daemon.StartedService/CloseConnection"
|
||||||
StartedService_GetStartedAt_FullMethodName = "/daemon.StartedService/GetStartedAt"
|
StartedService_CloseAllConnections_FullMethodName = "/daemon.StartedService/CloseAllConnections"
|
||||||
|
StartedService_GetDeprecatedWarnings_FullMethodName = "/daemon.StartedService/GetDeprecatedWarnings"
|
||||||
|
StartedService_GetStartedAt_FullMethodName = "/daemon.StartedService/GetStartedAt"
|
||||||
|
StartedService_SubscribeOutbounds_FullMethodName = "/daemon.StartedService/SubscribeOutbounds"
|
||||||
|
StartedService_StartNetworkQualityTest_FullMethodName = "/daemon.StartedService/StartNetworkQualityTest"
|
||||||
|
StartedService_StartSTUNTest_FullMethodName = "/daemon.StartedService/StartSTUNTest"
|
||||||
|
StartedService_SubscribeTailscaleStatus_FullMethodName = "/daemon.StartedService/SubscribeTailscaleStatus"
|
||||||
|
StartedService_StartTailscalePing_FullMethodName = "/daemon.StartedService/StartTailscalePing"
|
||||||
)
|
)
|
||||||
|
|
||||||
// StartedServiceClient is the client API for StartedService service.
|
// StartedServiceClient is the client API for StartedService service.
|
||||||
@@ -58,11 +65,18 @@ type StartedServiceClient interface {
|
|||||||
SetGroupExpand(ctx context.Context, in *SetGroupExpandRequest, opts ...grpc.CallOption) (*emptypb.Empty, error)
|
SetGroupExpand(ctx context.Context, in *SetGroupExpandRequest, opts ...grpc.CallOption) (*emptypb.Empty, error)
|
||||||
GetSystemProxyStatus(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*SystemProxyStatus, error)
|
GetSystemProxyStatus(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*SystemProxyStatus, error)
|
||||||
SetSystemProxyEnabled(ctx context.Context, in *SetSystemProxyEnabledRequest, opts ...grpc.CallOption) (*emptypb.Empty, error)
|
SetSystemProxyEnabled(ctx context.Context, in *SetSystemProxyEnabledRequest, opts ...grpc.CallOption) (*emptypb.Empty, error)
|
||||||
|
TriggerDebugCrash(ctx context.Context, in *DebugCrashRequest, opts ...grpc.CallOption) (*emptypb.Empty, error)
|
||||||
|
TriggerOOMReport(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*emptypb.Empty, error)
|
||||||
SubscribeConnections(ctx context.Context, in *SubscribeConnectionsRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ConnectionEvents], error)
|
SubscribeConnections(ctx context.Context, in *SubscribeConnectionsRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ConnectionEvents], error)
|
||||||
CloseConnection(ctx context.Context, in *CloseConnectionRequest, opts ...grpc.CallOption) (*emptypb.Empty, error)
|
CloseConnection(ctx context.Context, in *CloseConnectionRequest, opts ...grpc.CallOption) (*emptypb.Empty, error)
|
||||||
CloseAllConnections(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*emptypb.Empty, error)
|
CloseAllConnections(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*emptypb.Empty, error)
|
||||||
GetDeprecatedWarnings(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*DeprecatedWarnings, error)
|
GetDeprecatedWarnings(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*DeprecatedWarnings, error)
|
||||||
GetStartedAt(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*StartedAt, error)
|
GetStartedAt(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*StartedAt, error)
|
||||||
|
SubscribeOutbounds(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (grpc.ServerStreamingClient[OutboundList], error)
|
||||||
|
StartNetworkQualityTest(ctx context.Context, in *NetworkQualityTestRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[NetworkQualityTestProgress], error)
|
||||||
|
StartSTUNTest(ctx context.Context, in *STUNTestRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[STUNTestProgress], error)
|
||||||
|
SubscribeTailscaleStatus(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (grpc.ServerStreamingClient[TailscaleStatusUpdate], error)
|
||||||
|
StartTailscalePing(ctx context.Context, in *TailscalePingRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[TailscalePingResponse], error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type startedServiceClient struct {
|
type startedServiceClient struct {
|
||||||
@@ -278,6 +292,26 @@ func (c *startedServiceClient) SetSystemProxyEnabled(ctx context.Context, in *Se
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *startedServiceClient) TriggerDebugCrash(ctx context.Context, in *DebugCrashRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) {
|
||||||
|
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||||
|
out := new(emptypb.Empty)
|
||||||
|
err := c.cc.Invoke(ctx, StartedService_TriggerDebugCrash_FullMethodName, in, out, cOpts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *startedServiceClient) TriggerOOMReport(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*emptypb.Empty, error) {
|
||||||
|
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||||
|
out := new(emptypb.Empty)
|
||||||
|
err := c.cc.Invoke(ctx, StartedService_TriggerOOMReport_FullMethodName, in, out, cOpts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *startedServiceClient) SubscribeConnections(ctx context.Context, in *SubscribeConnectionsRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ConnectionEvents], error) {
|
func (c *startedServiceClient) SubscribeConnections(ctx context.Context, in *SubscribeConnectionsRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ConnectionEvents], error) {
|
||||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||||
stream, err := c.cc.NewStream(ctx, &StartedService_ServiceDesc.Streams[5], StartedService_SubscribeConnections_FullMethodName, cOpts...)
|
stream, err := c.cc.NewStream(ctx, &StartedService_ServiceDesc.Streams[5], StartedService_SubscribeConnections_FullMethodName, cOpts...)
|
||||||
@@ -337,6 +371,101 @@ func (c *startedServiceClient) GetStartedAt(ctx context.Context, in *emptypb.Emp
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *startedServiceClient) SubscribeOutbounds(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (grpc.ServerStreamingClient[OutboundList], error) {
|
||||||
|
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||||
|
stream, err := c.cc.NewStream(ctx, &StartedService_ServiceDesc.Streams[6], StartedService_SubscribeOutbounds_FullMethodName, cOpts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
x := &grpc.GenericClientStream[emptypb.Empty, OutboundList]{ClientStream: stream}
|
||||||
|
if err := x.ClientStream.SendMsg(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := x.ClientStream.CloseSend(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return x, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||||
|
type StartedService_SubscribeOutboundsClient = grpc.ServerStreamingClient[OutboundList]
|
||||||
|
|
||||||
|
func (c *startedServiceClient) StartNetworkQualityTest(ctx context.Context, in *NetworkQualityTestRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[NetworkQualityTestProgress], error) {
|
||||||
|
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||||
|
stream, err := c.cc.NewStream(ctx, &StartedService_ServiceDesc.Streams[7], StartedService_StartNetworkQualityTest_FullMethodName, cOpts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
x := &grpc.GenericClientStream[NetworkQualityTestRequest, NetworkQualityTestProgress]{ClientStream: stream}
|
||||||
|
if err := x.ClientStream.SendMsg(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := x.ClientStream.CloseSend(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return x, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||||
|
type StartedService_StartNetworkQualityTestClient = grpc.ServerStreamingClient[NetworkQualityTestProgress]
|
||||||
|
|
||||||
|
func (c *startedServiceClient) StartSTUNTest(ctx context.Context, in *STUNTestRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[STUNTestProgress], error) {
|
||||||
|
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||||
|
stream, err := c.cc.NewStream(ctx, &StartedService_ServiceDesc.Streams[8], StartedService_StartSTUNTest_FullMethodName, cOpts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
x := &grpc.GenericClientStream[STUNTestRequest, STUNTestProgress]{ClientStream: stream}
|
||||||
|
if err := x.ClientStream.SendMsg(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := x.ClientStream.CloseSend(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return x, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||||
|
type StartedService_StartSTUNTestClient = grpc.ServerStreamingClient[STUNTestProgress]
|
||||||
|
|
||||||
|
func (c *startedServiceClient) SubscribeTailscaleStatus(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (grpc.ServerStreamingClient[TailscaleStatusUpdate], error) {
|
||||||
|
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||||
|
stream, err := c.cc.NewStream(ctx, &StartedService_ServiceDesc.Streams[9], StartedService_SubscribeTailscaleStatus_FullMethodName, cOpts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
x := &grpc.GenericClientStream[emptypb.Empty, TailscaleStatusUpdate]{ClientStream: stream}
|
||||||
|
if err := x.ClientStream.SendMsg(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := x.ClientStream.CloseSend(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return x, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||||
|
type StartedService_SubscribeTailscaleStatusClient = grpc.ServerStreamingClient[TailscaleStatusUpdate]
|
||||||
|
|
||||||
|
func (c *startedServiceClient) StartTailscalePing(ctx context.Context, in *TailscalePingRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[TailscalePingResponse], error) {
|
||||||
|
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||||
|
stream, err := c.cc.NewStream(ctx, &StartedService_ServiceDesc.Streams[10], StartedService_StartTailscalePing_FullMethodName, cOpts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
x := &grpc.GenericClientStream[TailscalePingRequest, TailscalePingResponse]{ClientStream: stream}
|
||||||
|
if err := x.ClientStream.SendMsg(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := x.ClientStream.CloseSend(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return x, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||||
|
type StartedService_StartTailscalePingClient = grpc.ServerStreamingClient[TailscalePingResponse]
|
||||||
|
|
||||||
// StartedServiceServer is the server API for StartedService service.
|
// StartedServiceServer is the server API for StartedService service.
|
||||||
// All implementations must embed UnimplementedStartedServiceServer
|
// All implementations must embed UnimplementedStartedServiceServer
|
||||||
// for forward compatibility.
|
// for forward compatibility.
|
||||||
@@ -357,11 +486,18 @@ type StartedServiceServer interface {
|
|||||||
SetGroupExpand(context.Context, *SetGroupExpandRequest) (*emptypb.Empty, error)
|
SetGroupExpand(context.Context, *SetGroupExpandRequest) (*emptypb.Empty, error)
|
||||||
GetSystemProxyStatus(context.Context, *emptypb.Empty) (*SystemProxyStatus, error)
|
GetSystemProxyStatus(context.Context, *emptypb.Empty) (*SystemProxyStatus, error)
|
||||||
SetSystemProxyEnabled(context.Context, *SetSystemProxyEnabledRequest) (*emptypb.Empty, error)
|
SetSystemProxyEnabled(context.Context, *SetSystemProxyEnabledRequest) (*emptypb.Empty, error)
|
||||||
|
TriggerDebugCrash(context.Context, *DebugCrashRequest) (*emptypb.Empty, error)
|
||||||
|
TriggerOOMReport(context.Context, *emptypb.Empty) (*emptypb.Empty, error)
|
||||||
SubscribeConnections(*SubscribeConnectionsRequest, grpc.ServerStreamingServer[ConnectionEvents]) error
|
SubscribeConnections(*SubscribeConnectionsRequest, grpc.ServerStreamingServer[ConnectionEvents]) error
|
||||||
CloseConnection(context.Context, *CloseConnectionRequest) (*emptypb.Empty, error)
|
CloseConnection(context.Context, *CloseConnectionRequest) (*emptypb.Empty, error)
|
||||||
CloseAllConnections(context.Context, *emptypb.Empty) (*emptypb.Empty, error)
|
CloseAllConnections(context.Context, *emptypb.Empty) (*emptypb.Empty, error)
|
||||||
GetDeprecatedWarnings(context.Context, *emptypb.Empty) (*DeprecatedWarnings, error)
|
GetDeprecatedWarnings(context.Context, *emptypb.Empty) (*DeprecatedWarnings, error)
|
||||||
GetStartedAt(context.Context, *emptypb.Empty) (*StartedAt, error)
|
GetStartedAt(context.Context, *emptypb.Empty) (*StartedAt, error)
|
||||||
|
SubscribeOutbounds(*emptypb.Empty, grpc.ServerStreamingServer[OutboundList]) error
|
||||||
|
StartNetworkQualityTest(*NetworkQualityTestRequest, grpc.ServerStreamingServer[NetworkQualityTestProgress]) error
|
||||||
|
StartSTUNTest(*STUNTestRequest, grpc.ServerStreamingServer[STUNTestProgress]) error
|
||||||
|
SubscribeTailscaleStatus(*emptypb.Empty, grpc.ServerStreamingServer[TailscaleStatusUpdate]) error
|
||||||
|
StartTailscalePing(*TailscalePingRequest, grpc.ServerStreamingServer[TailscalePingResponse]) error
|
||||||
mustEmbedUnimplementedStartedServiceServer()
|
mustEmbedUnimplementedStartedServiceServer()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -436,6 +572,14 @@ func (UnimplementedStartedServiceServer) SetSystemProxyEnabled(context.Context,
|
|||||||
return nil, status.Error(codes.Unimplemented, "method SetSystemProxyEnabled not implemented")
|
return nil, status.Error(codes.Unimplemented, "method SetSystemProxyEnabled not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (UnimplementedStartedServiceServer) TriggerDebugCrash(context.Context, *DebugCrashRequest) (*emptypb.Empty, error) {
|
||||||
|
return nil, status.Error(codes.Unimplemented, "method TriggerDebugCrash not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedStartedServiceServer) TriggerOOMReport(context.Context, *emptypb.Empty) (*emptypb.Empty, error) {
|
||||||
|
return nil, status.Error(codes.Unimplemented, "method TriggerOOMReport not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (UnimplementedStartedServiceServer) SubscribeConnections(*SubscribeConnectionsRequest, grpc.ServerStreamingServer[ConnectionEvents]) error {
|
func (UnimplementedStartedServiceServer) SubscribeConnections(*SubscribeConnectionsRequest, grpc.ServerStreamingServer[ConnectionEvents]) error {
|
||||||
return status.Error(codes.Unimplemented, "method SubscribeConnections not implemented")
|
return status.Error(codes.Unimplemented, "method SubscribeConnections not implemented")
|
||||||
}
|
}
|
||||||
@@ -455,6 +599,26 @@ func (UnimplementedStartedServiceServer) GetDeprecatedWarnings(context.Context,
|
|||||||
func (UnimplementedStartedServiceServer) GetStartedAt(context.Context, *emptypb.Empty) (*StartedAt, error) {
|
func (UnimplementedStartedServiceServer) GetStartedAt(context.Context, *emptypb.Empty) (*StartedAt, error) {
|
||||||
return nil, status.Error(codes.Unimplemented, "method GetStartedAt not implemented")
|
return nil, status.Error(codes.Unimplemented, "method GetStartedAt not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (UnimplementedStartedServiceServer) SubscribeOutbounds(*emptypb.Empty, grpc.ServerStreamingServer[OutboundList]) error {
|
||||||
|
return status.Error(codes.Unimplemented, "method SubscribeOutbounds not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedStartedServiceServer) StartNetworkQualityTest(*NetworkQualityTestRequest, grpc.ServerStreamingServer[NetworkQualityTestProgress]) error {
|
||||||
|
return status.Error(codes.Unimplemented, "method StartNetworkQualityTest not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedStartedServiceServer) StartSTUNTest(*STUNTestRequest, grpc.ServerStreamingServer[STUNTestProgress]) error {
|
||||||
|
return status.Error(codes.Unimplemented, "method StartSTUNTest not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedStartedServiceServer) SubscribeTailscaleStatus(*emptypb.Empty, grpc.ServerStreamingServer[TailscaleStatusUpdate]) error {
|
||||||
|
return status.Error(codes.Unimplemented, "method SubscribeTailscaleStatus not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedStartedServiceServer) StartTailscalePing(*TailscalePingRequest, grpc.ServerStreamingServer[TailscalePingResponse]) error {
|
||||||
|
return status.Error(codes.Unimplemented, "method StartTailscalePing not implemented")
|
||||||
|
}
|
||||||
func (UnimplementedStartedServiceServer) mustEmbedUnimplementedStartedServiceServer() {}
|
func (UnimplementedStartedServiceServer) mustEmbedUnimplementedStartedServiceServer() {}
|
||||||
func (UnimplementedStartedServiceServer) testEmbeddedByValue() {}
|
func (UnimplementedStartedServiceServer) testEmbeddedByValue() {}
|
||||||
|
|
||||||
@@ -729,6 +893,42 @@ func _StartedService_SetSystemProxyEnabled_Handler(srv interface{}, ctx context.
|
|||||||
return interceptor(ctx, in, info, handler)
|
return interceptor(ctx, in, info, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func _StartedService_TriggerDebugCrash_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
|
in := new(DebugCrashRequest)
|
||||||
|
if err := dec(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if interceptor == nil {
|
||||||
|
return srv.(StartedServiceServer).TriggerDebugCrash(ctx, in)
|
||||||
|
}
|
||||||
|
info := &grpc.UnaryServerInfo{
|
||||||
|
Server: srv,
|
||||||
|
FullMethod: StartedService_TriggerDebugCrash_FullMethodName,
|
||||||
|
}
|
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
return srv.(StartedServiceServer).TriggerDebugCrash(ctx, req.(*DebugCrashRequest))
|
||||||
|
}
|
||||||
|
return interceptor(ctx, in, info, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _StartedService_TriggerOOMReport_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
|
in := new(emptypb.Empty)
|
||||||
|
if err := dec(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if interceptor == nil {
|
||||||
|
return srv.(StartedServiceServer).TriggerOOMReport(ctx, in)
|
||||||
|
}
|
||||||
|
info := &grpc.UnaryServerInfo{
|
||||||
|
Server: srv,
|
||||||
|
FullMethod: StartedService_TriggerOOMReport_FullMethodName,
|
||||||
|
}
|
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
return srv.(StartedServiceServer).TriggerOOMReport(ctx, req.(*emptypb.Empty))
|
||||||
|
}
|
||||||
|
return interceptor(ctx, in, info, handler)
|
||||||
|
}
|
||||||
|
|
||||||
func _StartedService_SubscribeConnections_Handler(srv interface{}, stream grpc.ServerStream) error {
|
func _StartedService_SubscribeConnections_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||||
m := new(SubscribeConnectionsRequest)
|
m := new(SubscribeConnectionsRequest)
|
||||||
if err := stream.RecvMsg(m); err != nil {
|
if err := stream.RecvMsg(m); err != nil {
|
||||||
@@ -812,6 +1012,61 @@ func _StartedService_GetStartedAt_Handler(srv interface{}, ctx context.Context,
|
|||||||
return interceptor(ctx, in, info, handler)
|
return interceptor(ctx, in, info, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func _StartedService_SubscribeOutbounds_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||||
|
m := new(emptypb.Empty)
|
||||||
|
if err := stream.RecvMsg(m); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return srv.(StartedServiceServer).SubscribeOutbounds(m, &grpc.GenericServerStream[emptypb.Empty, OutboundList]{ServerStream: stream})
|
||||||
|
}
|
||||||
|
|
||||||
|
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||||
|
type StartedService_SubscribeOutboundsServer = grpc.ServerStreamingServer[OutboundList]
|
||||||
|
|
||||||
|
func _StartedService_StartNetworkQualityTest_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||||
|
m := new(NetworkQualityTestRequest)
|
||||||
|
if err := stream.RecvMsg(m); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return srv.(StartedServiceServer).StartNetworkQualityTest(m, &grpc.GenericServerStream[NetworkQualityTestRequest, NetworkQualityTestProgress]{ServerStream: stream})
|
||||||
|
}
|
||||||
|
|
||||||
|
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||||
|
type StartedService_StartNetworkQualityTestServer = grpc.ServerStreamingServer[NetworkQualityTestProgress]
|
||||||
|
|
||||||
|
func _StartedService_StartSTUNTest_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||||
|
m := new(STUNTestRequest)
|
||||||
|
if err := stream.RecvMsg(m); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return srv.(StartedServiceServer).StartSTUNTest(m, &grpc.GenericServerStream[STUNTestRequest, STUNTestProgress]{ServerStream: stream})
|
||||||
|
}
|
||||||
|
|
||||||
|
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||||
|
type StartedService_StartSTUNTestServer = grpc.ServerStreamingServer[STUNTestProgress]
|
||||||
|
|
||||||
|
func _StartedService_SubscribeTailscaleStatus_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||||
|
m := new(emptypb.Empty)
|
||||||
|
if err := stream.RecvMsg(m); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return srv.(StartedServiceServer).SubscribeTailscaleStatus(m, &grpc.GenericServerStream[emptypb.Empty, TailscaleStatusUpdate]{ServerStream: stream})
|
||||||
|
}
|
||||||
|
|
||||||
|
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||||
|
type StartedService_SubscribeTailscaleStatusServer = grpc.ServerStreamingServer[TailscaleStatusUpdate]
|
||||||
|
|
||||||
|
func _StartedService_StartTailscalePing_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||||
|
m := new(TailscalePingRequest)
|
||||||
|
if err := stream.RecvMsg(m); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return srv.(StartedServiceServer).StartTailscalePing(m, &grpc.GenericServerStream[TailscalePingRequest, TailscalePingResponse]{ServerStream: stream})
|
||||||
|
}
|
||||||
|
|
||||||
|
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||||
|
type StartedService_StartTailscalePingServer = grpc.ServerStreamingServer[TailscalePingResponse]
|
||||||
|
|
||||||
// StartedService_ServiceDesc is the grpc.ServiceDesc for StartedService service.
|
// StartedService_ServiceDesc is the grpc.ServiceDesc for StartedService service.
|
||||||
// It's only intended for direct use with grpc.RegisterService,
|
// It's only intended for direct use with grpc.RegisterService,
|
||||||
// and not to be introspected or modified (even as a copy)
|
// and not to be introspected or modified (even as a copy)
|
||||||
@@ -863,6 +1118,14 @@ var StartedService_ServiceDesc = grpc.ServiceDesc{
|
|||||||
MethodName: "SetSystemProxyEnabled",
|
MethodName: "SetSystemProxyEnabled",
|
||||||
Handler: _StartedService_SetSystemProxyEnabled_Handler,
|
Handler: _StartedService_SetSystemProxyEnabled_Handler,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
MethodName: "TriggerDebugCrash",
|
||||||
|
Handler: _StartedService_TriggerDebugCrash_Handler,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
MethodName: "TriggerOOMReport",
|
||||||
|
Handler: _StartedService_TriggerOOMReport_Handler,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
MethodName: "CloseConnection",
|
MethodName: "CloseConnection",
|
||||||
Handler: _StartedService_CloseConnection_Handler,
|
Handler: _StartedService_CloseConnection_Handler,
|
||||||
@@ -911,6 +1174,31 @@ var StartedService_ServiceDesc = grpc.ServiceDesc{
|
|||||||
Handler: _StartedService_SubscribeConnections_Handler,
|
Handler: _StartedService_SubscribeConnections_Handler,
|
||||||
ServerStreams: true,
|
ServerStreams: true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
StreamName: "SubscribeOutbounds",
|
||||||
|
Handler: _StartedService_SubscribeOutbounds_Handler,
|
||||||
|
ServerStreams: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
StreamName: "StartNetworkQualityTest",
|
||||||
|
Handler: _StartedService_StartNetworkQualityTest_Handler,
|
||||||
|
ServerStreams: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
StreamName: "StartSTUNTest",
|
||||||
|
Handler: _StartedService_StartSTUNTest_Handler,
|
||||||
|
ServerStreams: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
StreamName: "SubscribeTailscaleStatus",
|
||||||
|
Handler: _StartedService_SubscribeTailscaleStatus_Handler,
|
||||||
|
ServerStreams: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
StreamName: "StartTailscalePing",
|
||||||
|
Handler: _StartedService_StartTailscalePing_Handler,
|
||||||
|
ServerStreams: true,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Metadata: "daemon/started_service.proto",
|
Metadata: "daemon/started_service.proto",
|
||||||
}
|
}
|
||||||
|
|||||||
596
dns/client.go
596
dns/client.go
@@ -5,7 +5,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
@@ -14,7 +13,6 @@ import (
|
|||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
"github.com/sagernet/sing/common/logger"
|
"github.com/sagernet/sing/common/logger"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
|
||||||
"github.com/sagernet/sing/common/task"
|
"github.com/sagernet/sing/common/task"
|
||||||
"github.com/sagernet/sing/contrab/freelru"
|
"github.com/sagernet/sing/contrab/freelru"
|
||||||
"github.com/sagernet/sing/contrab/maphash"
|
"github.com/sagernet/sing/contrab/maphash"
|
||||||
@@ -32,59 +30,63 @@ var (
|
|||||||
var _ adapter.DNSClient = (*Client)(nil)
|
var _ adapter.DNSClient = (*Client)(nil)
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
timeout time.Duration
|
ctx context.Context
|
||||||
disableCache bool
|
timeout time.Duration
|
||||||
disableExpire bool
|
disableCache bool
|
||||||
independentCache bool
|
disableExpire bool
|
||||||
clientSubnet netip.Prefix
|
optimisticTimeout time.Duration
|
||||||
rdrc adapter.RDRCStore
|
cacheCapacity uint32
|
||||||
initRDRCFunc func() adapter.RDRCStore
|
clientSubnet netip.Prefix
|
||||||
logger logger.ContextLogger
|
rdrc adapter.RDRCStore
|
||||||
cache freelru.Cache[dns.Question, *dns.Msg]
|
initRDRCFunc func() adapter.RDRCStore
|
||||||
cacheLock compatible.Map[dns.Question, chan struct{}]
|
dnsCache adapter.DNSCacheStore
|
||||||
transportCache freelru.Cache[transportCacheKey, *dns.Msg]
|
initDNSCacheFunc func() adapter.DNSCacheStore
|
||||||
transportCacheLock compatible.Map[dns.Question, chan struct{}]
|
logger logger.ContextLogger
|
||||||
|
cache freelru.Cache[dnsCacheKey, *dns.Msg]
|
||||||
|
cacheLock compatible.Map[dnsCacheKey, chan struct{}]
|
||||||
|
backgroundRefresh compatible.Map[dnsCacheKey, struct{}]
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClientOptions struct {
|
type ClientOptions struct {
|
||||||
Timeout time.Duration
|
Context context.Context
|
||||||
DisableCache bool
|
Timeout time.Duration
|
||||||
DisableExpire bool
|
DisableCache bool
|
||||||
IndependentCache bool
|
DisableExpire bool
|
||||||
CacheCapacity uint32
|
OptimisticTimeout time.Duration
|
||||||
ClientSubnet netip.Prefix
|
CacheCapacity uint32
|
||||||
RDRC func() adapter.RDRCStore
|
ClientSubnet netip.Prefix
|
||||||
Logger logger.ContextLogger
|
RDRC func() adapter.RDRCStore
|
||||||
|
DNSCache func() adapter.DNSCacheStore
|
||||||
|
Logger logger.ContextLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClient(options ClientOptions) *Client {
|
func NewClient(options ClientOptions) *Client {
|
||||||
client := &Client{
|
|
||||||
timeout: options.Timeout,
|
|
||||||
disableCache: options.DisableCache,
|
|
||||||
disableExpire: options.DisableExpire,
|
|
||||||
independentCache: options.IndependentCache,
|
|
||||||
clientSubnet: options.ClientSubnet,
|
|
||||||
initRDRCFunc: options.RDRC,
|
|
||||||
logger: options.Logger,
|
|
||||||
}
|
|
||||||
if client.timeout == 0 {
|
|
||||||
client.timeout = C.DNSTimeout
|
|
||||||
}
|
|
||||||
cacheCapacity := options.CacheCapacity
|
cacheCapacity := options.CacheCapacity
|
||||||
if cacheCapacity < 1024 {
|
if cacheCapacity < 1024 {
|
||||||
cacheCapacity = 1024
|
cacheCapacity = 1024
|
||||||
}
|
}
|
||||||
if !client.disableCache {
|
client := &Client{
|
||||||
if !client.independentCache {
|
ctx: options.Context,
|
||||||
client.cache = common.Must1(freelru.NewSharded[dns.Question, *dns.Msg](cacheCapacity, maphash.NewHasher[dns.Question]().Hash32))
|
timeout: options.Timeout,
|
||||||
} else {
|
disableCache: options.DisableCache,
|
||||||
client.transportCache = common.Must1(freelru.NewSharded[transportCacheKey, *dns.Msg](cacheCapacity, maphash.NewHasher[transportCacheKey]().Hash32))
|
disableExpire: options.DisableExpire,
|
||||||
}
|
optimisticTimeout: options.OptimisticTimeout,
|
||||||
|
cacheCapacity: cacheCapacity,
|
||||||
|
clientSubnet: options.ClientSubnet,
|
||||||
|
initRDRCFunc: options.RDRC,
|
||||||
|
initDNSCacheFunc: options.DNSCache,
|
||||||
|
logger: options.Logger,
|
||||||
|
}
|
||||||
|
if client.timeout == 0 {
|
||||||
|
client.timeout = C.DNSTimeout
|
||||||
|
}
|
||||||
|
if !client.disableCache && client.initDNSCacheFunc == nil {
|
||||||
|
client.initializeMemoryCache()
|
||||||
}
|
}
|
||||||
return client
|
return client
|
||||||
}
|
}
|
||||||
|
|
||||||
type transportCacheKey struct {
|
type dnsCacheKey struct {
|
||||||
dns.Question
|
dns.Question
|
||||||
transportTag string
|
transportTag string
|
||||||
}
|
}
|
||||||
@@ -93,6 +95,19 @@ func (c *Client) Start() {
|
|||||||
if c.initRDRCFunc != nil {
|
if c.initRDRCFunc != nil {
|
||||||
c.rdrc = c.initRDRCFunc()
|
c.rdrc = c.initRDRCFunc()
|
||||||
}
|
}
|
||||||
|
if c.initDNSCacheFunc != nil {
|
||||||
|
c.dnsCache = c.initDNSCacheFunc()
|
||||||
|
}
|
||||||
|
if c.dnsCache == nil {
|
||||||
|
c.initializeMemoryCache()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) initializeMemoryCache() {
|
||||||
|
if c.disableCache || c.cache != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.cache = common.Must1(freelru.NewSharded[dnsCacheKey, *dns.Msg](c.cacheCapacity, maphash.NewHasher[dnsCacheKey]().Hash32))
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractNegativeTTL(response *dns.Msg) (uint32, bool) {
|
func extractNegativeTTL(response *dns.Msg) (uint32, bool) {
|
||||||
@@ -109,7 +124,38 @@ func extractNegativeTTL(response *dns.Msg) (uint32, bool) {
|
|||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, message *dns.Msg, options adapter.DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) (*dns.Msg, error) {
|
func computeTimeToLive(response *dns.Msg) uint32 {
|
||||||
|
var timeToLive uint32
|
||||||
|
if len(response.Answer) == 0 {
|
||||||
|
if soaTTL, hasSOA := extractNegativeTTL(response); hasSOA {
|
||||||
|
return soaTTL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
|
||||||
|
for _, record := range recordList {
|
||||||
|
if record.Header().Rrtype == dns.TypeOPT {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if timeToLive == 0 || record.Header().Ttl > 0 && record.Header().Ttl < timeToLive {
|
||||||
|
timeToLive = record.Header().Ttl
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return timeToLive
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeTTL(response *dns.Msg, timeToLive uint32) {
|
||||||
|
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
|
||||||
|
for _, record := range recordList {
|
||||||
|
if record.Header().Rrtype == dns.TypeOPT {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
record.Header().Ttl = timeToLive
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, message *dns.Msg, options adapter.DNSQueryOptions, responseChecker func(response *dns.Msg) bool) (*dns.Msg, error) {
|
||||||
if len(message.Question) == 0 {
|
if len(message.Question) == 0 {
|
||||||
if c.logger != nil {
|
if c.logger != nil {
|
||||||
c.logger.WarnContext(ctx, "bad question size: ", len(message.Question))
|
c.logger.WarnContext(ctx, "bad question size: ", len(message.Question))
|
||||||
@@ -123,13 +169,7 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
|
|||||||
}
|
}
|
||||||
return FixedResponseStatus(message, dns.RcodeSuccess), nil
|
return FixedResponseStatus(message, dns.RcodeSuccess), nil
|
||||||
}
|
}
|
||||||
clientSubnet := options.ClientSubnet
|
message = c.prepareExchangeMessage(message, options)
|
||||||
if !clientSubnet.IsValid() {
|
|
||||||
clientSubnet = c.clientSubnet
|
|
||||||
}
|
|
||||||
if clientSubnet.IsValid() {
|
|
||||||
message = SetClientSubnet(message, clientSubnet)
|
|
||||||
}
|
|
||||||
|
|
||||||
isSimpleRequest := len(message.Question) == 1 &&
|
isSimpleRequest := len(message.Question) == 1 &&
|
||||||
len(message.Ns) == 0 &&
|
len(message.Ns) == 0 &&
|
||||||
@@ -141,40 +181,32 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
|
|||||||
!options.ClientSubnet.IsValid()
|
!options.ClientSubnet.IsValid()
|
||||||
disableCache := !isSimpleRequest || c.disableCache || options.DisableCache
|
disableCache := !isSimpleRequest || c.disableCache || options.DisableCache
|
||||||
if !disableCache {
|
if !disableCache {
|
||||||
if c.cache != nil {
|
cacheKey := dnsCacheKey{Question: question, transportTag: transport.Tag()}
|
||||||
cond, loaded := c.cacheLock.LoadOrStore(question, make(chan struct{}))
|
cond, loaded := c.cacheLock.LoadOrStore(cacheKey, make(chan struct{}))
|
||||||
if loaded {
|
if loaded {
|
||||||
select {
|
select {
|
||||||
case <-cond:
|
case <-cond:
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return nil, ctx.Err()
|
return nil, ctx.Err()
|
||||||
}
|
|
||||||
} else {
|
|
||||||
defer func() {
|
|
||||||
c.cacheLock.Delete(question)
|
|
||||||
close(cond)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
} else if c.transportCache != nil {
|
|
||||||
cond, loaded := c.transportCacheLock.LoadOrStore(question, make(chan struct{}))
|
|
||||||
if loaded {
|
|
||||||
select {
|
|
||||||
case <-cond:
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, ctx.Err()
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
defer func() {
|
|
||||||
c.transportCacheLock.Delete(question)
|
|
||||||
close(cond)
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
defer func() {
|
||||||
|
c.cacheLock.Delete(cacheKey)
|
||||||
|
close(cond)
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
response, ttl := c.loadResponse(question, transport)
|
response, ttl, isStale := c.loadResponse(question, transport)
|
||||||
if response != nil {
|
if response != nil {
|
||||||
logCachedResponse(c.logger, ctx, response, ttl)
|
if isStale && !options.DisableOptimisticCache {
|
||||||
response.Id = message.Id
|
c.backgroundRefreshDNS(transport, question, message.Copy(), options, responseChecker)
|
||||||
return response, nil
|
logOptimisticResponse(c.logger, ctx, response)
|
||||||
|
response.Id = message.Id
|
||||||
|
return response, nil
|
||||||
|
} else if !isStale {
|
||||||
|
logCachedResponse(c.logger, ctx, response, ttl)
|
||||||
|
response.Id = message.Id
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -190,62 +222,17 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
|
|||||||
return nil, ErrResponseRejectedCached
|
return nil, ErrResponseRejectedCached
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ctx, cancel := context.WithTimeout(ctx, c.timeout)
|
response, err := c.exchangeToTransport(ctx, transport, message)
|
||||||
response, err := transport.Exchange(ctx, message)
|
|
||||||
cancel()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var rcodeError RcodeError
|
return nil, err
|
||||||
if errors.As(err, &rcodeError) {
|
|
||||||
response = FixedResponseStatus(message, int(rcodeError))
|
|
||||||
} else {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
/*if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA {
|
|
||||||
validResponse := response
|
|
||||||
loop:
|
|
||||||
for {
|
|
||||||
var (
|
|
||||||
addresses int
|
|
||||||
queryCNAME string
|
|
||||||
)
|
|
||||||
for _, rawRR := range validResponse.Answer {
|
|
||||||
switch rr := rawRR.(type) {
|
|
||||||
case *dns.A:
|
|
||||||
break loop
|
|
||||||
case *dns.AAAA:
|
|
||||||
break loop
|
|
||||||
case *dns.CNAME:
|
|
||||||
queryCNAME = rr.Target
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if queryCNAME == "" {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
exMessage := *message
|
|
||||||
exMessage.Question = []dns.Question{{
|
|
||||||
Name: queryCNAME,
|
|
||||||
Qtype: question.Qtype,
|
|
||||||
}}
|
|
||||||
validResponse, err = c.Exchange(ctx, transport, &exMessage, options, responseChecker)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if validResponse != response {
|
|
||||||
response.Answer = append(response.Answer, validResponse.Answer...)
|
|
||||||
}
|
|
||||||
}*/
|
|
||||||
disableCache = disableCache || (response.Rcode != dns.RcodeSuccess && response.Rcode != dns.RcodeNameError)
|
disableCache = disableCache || (response.Rcode != dns.RcodeSuccess && response.Rcode != dns.RcodeNameError)
|
||||||
if responseChecker != nil {
|
if responseChecker != nil {
|
||||||
var rejected bool
|
var rejected bool
|
||||||
// TODO: add accept_any rule and support to check response instead of addresses
|
|
||||||
if response.Rcode != dns.RcodeSuccess && response.Rcode != dns.RcodeNameError {
|
if response.Rcode != dns.RcodeSuccess && response.Rcode != dns.RcodeNameError {
|
||||||
rejected = true
|
rejected = true
|
||||||
} else if len(response.Answer) == 0 {
|
|
||||||
rejected = !responseChecker(nil)
|
|
||||||
} else {
|
} else {
|
||||||
rejected = !responseChecker(MessageToAddresses(response))
|
rejected = !responseChecker(response)
|
||||||
}
|
}
|
||||||
if rejected {
|
if rejected {
|
||||||
if !disableCache && c.rdrc != nil {
|
if !disableCache && c.rdrc != nil {
|
||||||
@@ -255,48 +242,7 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
|
|||||||
return response, ErrResponseRejected
|
return response, ErrResponseRejected
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if question.Qtype == dns.TypeHTTPS {
|
timeToLive := applyResponseOptions(question, response, options)
|
||||||
if options.Strategy == C.DomainStrategyIPv4Only || options.Strategy == C.DomainStrategyIPv6Only {
|
|
||||||
for _, rr := range response.Answer {
|
|
||||||
https, isHTTPS := rr.(*dns.HTTPS)
|
|
||||||
if !isHTTPS {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
content := https.SVCB
|
|
||||||
content.Value = common.Filter(content.Value, func(it dns.SVCBKeyValue) bool {
|
|
||||||
if options.Strategy == C.DomainStrategyIPv4Only {
|
|
||||||
return it.Key() != dns.SVCB_IPV6HINT
|
|
||||||
} else {
|
|
||||||
return it.Key() != dns.SVCB_IPV4HINT
|
|
||||||
}
|
|
||||||
})
|
|
||||||
https.SVCB = content
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var timeToLive uint32
|
|
||||||
if len(response.Answer) == 0 {
|
|
||||||
if soaTTL, hasSOA := extractNegativeTTL(response); hasSOA {
|
|
||||||
timeToLive = soaTTL
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if timeToLive == 0 {
|
|
||||||
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
|
|
||||||
for _, record := range recordList {
|
|
||||||
if timeToLive == 0 || record.Header().Ttl > 0 && record.Header().Ttl < timeToLive {
|
|
||||||
timeToLive = record.Header().Ttl
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if options.RewriteTTL != nil {
|
|
||||||
timeToLive = *options.RewriteTTL
|
|
||||||
}
|
|
||||||
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
|
|
||||||
for _, record := range recordList {
|
|
||||||
record.Header().Ttl = timeToLive
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !disableCache {
|
if !disableCache {
|
||||||
c.storeCache(transport, question, response, timeToLive)
|
c.storeCache(transport, question, response, timeToLive)
|
||||||
}
|
}
|
||||||
@@ -315,7 +261,7 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
|
|||||||
return response, nil
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) {
|
func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions, responseChecker func(response *dns.Msg) bool) ([]netip.Addr, error) {
|
||||||
domain = FqdnToDomain(domain)
|
domain = FqdnToDomain(domain)
|
||||||
dnsName := dns.Fqdn(domain)
|
dnsName := dns.Fqdn(domain)
|
||||||
var strategy C.DomainStrategy
|
var strategy C.DomainStrategy
|
||||||
@@ -362,8 +308,12 @@ func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, dom
|
|||||||
func (c *Client) ClearCache() {
|
func (c *Client) ClearCache() {
|
||||||
if c.cache != nil {
|
if c.cache != nil {
|
||||||
c.cache.Purge()
|
c.cache.Purge()
|
||||||
} else if c.transportCache != nil {
|
}
|
||||||
c.transportCache.Purge()
|
if c.dnsCache != nil {
|
||||||
|
err := c.dnsCache.ClearDNSCache()
|
||||||
|
if err != nil && c.logger != nil {
|
||||||
|
c.logger.Warn("clear DNS cache: ", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -379,46 +329,44 @@ func (c *Client) storeCache(transport adapter.DNSTransport, question dns.Questio
|
|||||||
if timeToLive == 0 {
|
if timeToLive == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if c.dnsCache != nil {
|
||||||
|
packed, err := message.Pack()
|
||||||
|
if err == nil {
|
||||||
|
expireAt := time.Now().Add(time.Second * time.Duration(timeToLive))
|
||||||
|
c.dnsCache.SaveDNSCacheAsync(transport.Tag(), question.Name, question.Qtype, packed, expireAt, c.logger)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if c.cache == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
key := dnsCacheKey{Question: question, transportTag: transport.Tag()}
|
||||||
if c.disableExpire {
|
if c.disableExpire {
|
||||||
if !c.independentCache {
|
c.cache.Add(key, message.Copy())
|
||||||
c.cache.Add(question, message)
|
|
||||||
} else {
|
|
||||||
c.transportCache.Add(transportCacheKey{
|
|
||||||
Question: question,
|
|
||||||
transportTag: transport.Tag(),
|
|
||||||
}, message)
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
if !c.independentCache {
|
c.cache.AddWithLifetime(key, message.Copy(), time.Second*time.Duration(timeToLive))
|
||||||
c.cache.AddWithLifetime(question, message, time.Second*time.Duration(timeToLive))
|
|
||||||
} else {
|
|
||||||
c.transportCache.AddWithLifetime(transportCacheKey{
|
|
||||||
Question: question,
|
|
||||||
transportTag: transport.Tag(),
|
|
||||||
}, message, time.Second*time.Duration(timeToLive))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) lookupToExchange(ctx context.Context, transport adapter.DNSTransport, name string, qType uint16, options adapter.DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) {
|
func (c *Client) lookupToExchange(ctx context.Context, transport adapter.DNSTransport, name string, qType uint16, options adapter.DNSQueryOptions, responseChecker func(response *dns.Msg) bool) ([]netip.Addr, error) {
|
||||||
question := dns.Question{
|
question := dns.Question{
|
||||||
Name: name,
|
Name: name,
|
||||||
Qtype: qType,
|
Qtype: qType,
|
||||||
Qclass: dns.ClassINET,
|
Qclass: dns.ClassINET,
|
||||||
}
|
}
|
||||||
disableCache := c.disableCache || options.DisableCache
|
|
||||||
if !disableCache {
|
|
||||||
cachedAddresses, err := c.questionCache(question, transport)
|
|
||||||
if err != ErrNotCached {
|
|
||||||
return cachedAddresses, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
message := dns.Msg{
|
message := dns.Msg{
|
||||||
MsgHdr: dns.MsgHdr{
|
MsgHdr: dns.MsgHdr{
|
||||||
RecursionDesired: true,
|
RecursionDesired: true,
|
||||||
},
|
},
|
||||||
Question: []dns.Question{question},
|
Question: []dns.Question{question},
|
||||||
}
|
}
|
||||||
|
disableCache := c.disableCache || options.DisableCache
|
||||||
|
if !disableCache {
|
||||||
|
cachedAddresses, err := c.questionCache(ctx, transport, &message, options, responseChecker)
|
||||||
|
if err != ErrNotCached {
|
||||||
|
return cachedAddresses, err
|
||||||
|
}
|
||||||
|
}
|
||||||
response, err := c.Exchange(ctx, transport, &message, options, responseChecker)
|
response, err := c.Exchange(ctx, transport, &message, options, responseChecker)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -429,111 +377,181 @@ func (c *Client) lookupToExchange(ctx context.Context, transport adapter.DNSTran
|
|||||||
return MessageToAddresses(response), nil
|
return MessageToAddresses(response), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) questionCache(question dns.Question, transport adapter.DNSTransport) ([]netip.Addr, error) {
|
func (c *Client) questionCache(ctx context.Context, transport adapter.DNSTransport, message *dns.Msg, options adapter.DNSQueryOptions, responseChecker func(response *dns.Msg) bool) ([]netip.Addr, error) {
|
||||||
response, _ := c.loadResponse(question, transport)
|
question := message.Question[0]
|
||||||
|
response, _, isStale := c.loadResponse(question, transport)
|
||||||
if response == nil {
|
if response == nil {
|
||||||
return nil, ErrNotCached
|
return nil, ErrNotCached
|
||||||
}
|
}
|
||||||
|
if isStale {
|
||||||
|
if options.DisableOptimisticCache {
|
||||||
|
return nil, ErrNotCached
|
||||||
|
}
|
||||||
|
c.backgroundRefreshDNS(transport, question, c.prepareExchangeMessage(message.Copy(), options), options, responseChecker)
|
||||||
|
logOptimisticResponse(c.logger, ctx, response)
|
||||||
|
}
|
||||||
if response.Rcode != dns.RcodeSuccess {
|
if response.Rcode != dns.RcodeSuccess {
|
||||||
return nil, RcodeError(response.Rcode)
|
return nil, RcodeError(response.Rcode)
|
||||||
}
|
}
|
||||||
return MessageToAddresses(response), nil
|
return MessageToAddresses(response), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) loadResponse(question dns.Question, transport adapter.DNSTransport) (*dns.Msg, int) {
|
func (c *Client) loadResponse(question dns.Question, transport adapter.DNSTransport) (*dns.Msg, int, bool) {
|
||||||
var (
|
if c.dnsCache != nil {
|
||||||
response *dns.Msg
|
return c.loadPersistentResponse(question, transport)
|
||||||
loaded bool
|
|
||||||
)
|
|
||||||
if c.disableExpire {
|
|
||||||
if !c.independentCache {
|
|
||||||
response, loaded = c.cache.Get(question)
|
|
||||||
} else {
|
|
||||||
response, loaded = c.transportCache.Get(transportCacheKey{
|
|
||||||
Question: question,
|
|
||||||
transportTag: transport.Tag(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
if !loaded {
|
|
||||||
return nil, 0
|
|
||||||
}
|
|
||||||
return response.Copy(), 0
|
|
||||||
} else {
|
|
||||||
var expireAt time.Time
|
|
||||||
if !c.independentCache {
|
|
||||||
response, expireAt, loaded = c.cache.GetWithLifetime(question)
|
|
||||||
} else {
|
|
||||||
response, expireAt, loaded = c.transportCache.GetWithLifetime(transportCacheKey{
|
|
||||||
Question: question,
|
|
||||||
transportTag: transport.Tag(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
if !loaded {
|
|
||||||
return nil, 0
|
|
||||||
}
|
|
||||||
timeNow := time.Now()
|
|
||||||
if timeNow.After(expireAt) {
|
|
||||||
if !c.independentCache {
|
|
||||||
c.cache.Remove(question)
|
|
||||||
} else {
|
|
||||||
c.transportCache.Remove(transportCacheKey{
|
|
||||||
Question: question,
|
|
||||||
transportTag: transport.Tag(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return nil, 0
|
|
||||||
}
|
|
||||||
var originTTL int
|
|
||||||
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
|
|
||||||
for _, record := range recordList {
|
|
||||||
if originTTL == 0 || record.Header().Ttl > 0 && int(record.Header().Ttl) < originTTL {
|
|
||||||
originTTL = int(record.Header().Ttl)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
nowTTL := int(expireAt.Sub(timeNow).Seconds())
|
|
||||||
if nowTTL < 0 {
|
|
||||||
nowTTL = 0
|
|
||||||
}
|
|
||||||
response = response.Copy()
|
|
||||||
if originTTL > 0 {
|
|
||||||
duration := uint32(originTTL - nowTTL)
|
|
||||||
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
|
|
||||||
for _, record := range recordList {
|
|
||||||
record.Header().Ttl = record.Header().Ttl - duration
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
|
|
||||||
for _, record := range recordList {
|
|
||||||
record.Header().Ttl = uint32(nowTTL)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return response, nowTTL
|
|
||||||
}
|
}
|
||||||
|
if c.cache == nil {
|
||||||
|
return nil, 0, false
|
||||||
|
}
|
||||||
|
key := dnsCacheKey{Question: question, transportTag: transport.Tag()}
|
||||||
|
if c.disableExpire {
|
||||||
|
response, loaded := c.cache.Get(key)
|
||||||
|
if !loaded {
|
||||||
|
return nil, 0, false
|
||||||
|
}
|
||||||
|
return response.Copy(), 0, false
|
||||||
|
}
|
||||||
|
response, expireAt, loaded := c.cache.GetWithLifetimeNoExpire(key)
|
||||||
|
if !loaded {
|
||||||
|
return nil, 0, false
|
||||||
|
}
|
||||||
|
timeNow := time.Now()
|
||||||
|
if timeNow.After(expireAt) {
|
||||||
|
if c.optimisticTimeout > 0 && timeNow.Before(expireAt.Add(c.optimisticTimeout)) {
|
||||||
|
response = response.Copy()
|
||||||
|
normalizeTTL(response, 1)
|
||||||
|
return response, 0, true
|
||||||
|
}
|
||||||
|
c.cache.Remove(key)
|
||||||
|
return nil, 0, false
|
||||||
|
}
|
||||||
|
nowTTL := int(expireAt.Sub(timeNow).Seconds())
|
||||||
|
if nowTTL < 0 {
|
||||||
|
nowTTL = 0
|
||||||
|
}
|
||||||
|
response = response.Copy()
|
||||||
|
normalizeTTL(response, uint32(nowTTL))
|
||||||
|
return response, nowTTL, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) loadPersistentResponse(question dns.Question, transport adapter.DNSTransport) (*dns.Msg, int, bool) {
|
||||||
|
rawMessage, expireAt, loaded := c.dnsCache.LoadDNSCache(transport.Tag(), question.Name, question.Qtype)
|
||||||
|
if !loaded {
|
||||||
|
return nil, 0, false
|
||||||
|
}
|
||||||
|
response := new(dns.Msg)
|
||||||
|
err := response.Unpack(rawMessage)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, false
|
||||||
|
}
|
||||||
|
if c.disableExpire {
|
||||||
|
return response, 0, false
|
||||||
|
}
|
||||||
|
timeNow := time.Now()
|
||||||
|
if timeNow.After(expireAt) {
|
||||||
|
if c.optimisticTimeout > 0 && timeNow.Before(expireAt.Add(c.optimisticTimeout)) {
|
||||||
|
normalizeTTL(response, 1)
|
||||||
|
return response, 0, true
|
||||||
|
}
|
||||||
|
return nil, 0, false
|
||||||
|
}
|
||||||
|
nowTTL := int(expireAt.Sub(timeNow).Seconds())
|
||||||
|
if nowTTL < 0 {
|
||||||
|
nowTTL = 0
|
||||||
|
}
|
||||||
|
normalizeTTL(response, uint32(nowTTL))
|
||||||
|
return response, nowTTL, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyResponseOptions(question dns.Question, response *dns.Msg, options adapter.DNSQueryOptions) uint32 {
|
||||||
|
if question.Qtype == dns.TypeHTTPS && (options.Strategy == C.DomainStrategyIPv4Only || options.Strategy == C.DomainStrategyIPv6Only) {
|
||||||
|
for _, rr := range response.Answer {
|
||||||
|
https, isHTTPS := rr.(*dns.HTTPS)
|
||||||
|
if !isHTTPS {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
content := https.SVCB
|
||||||
|
content.Value = common.Filter(content.Value, func(it dns.SVCBKeyValue) bool {
|
||||||
|
if options.Strategy == C.DomainStrategyIPv4Only {
|
||||||
|
return it.Key() != dns.SVCB_IPV6HINT
|
||||||
|
}
|
||||||
|
return it.Key() != dns.SVCB_IPV4HINT
|
||||||
|
})
|
||||||
|
https.SVCB = content
|
||||||
|
}
|
||||||
|
}
|
||||||
|
timeToLive := computeTimeToLive(response)
|
||||||
|
if options.RewriteTTL != nil {
|
||||||
|
timeToLive = *options.RewriteTTL
|
||||||
|
}
|
||||||
|
normalizeTTL(response, timeToLive)
|
||||||
|
return timeToLive
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) backgroundRefreshDNS(transport adapter.DNSTransport, question dns.Question, message *dns.Msg, options adapter.DNSQueryOptions, responseChecker func(response *dns.Msg) bool) {
|
||||||
|
key := dnsCacheKey{Question: question, transportTag: transport.Tag()}
|
||||||
|
_, loaded := c.backgroundRefresh.LoadOrStore(key, struct{}{})
|
||||||
|
if loaded {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
defer c.backgroundRefresh.Delete(key)
|
||||||
|
ctx := contextWithTransportTag(c.ctx, transport.Tag())
|
||||||
|
response, err := c.exchangeToTransport(ctx, transport, message)
|
||||||
|
if err != nil {
|
||||||
|
if c.logger != nil {
|
||||||
|
c.logger.Debug("optimistic refresh failed for ", FqdnToDomain(question.Name), ": ", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if responseChecker != nil {
|
||||||
|
var rejected bool
|
||||||
|
if response.Rcode != dns.RcodeSuccess && response.Rcode != dns.RcodeNameError {
|
||||||
|
rejected = true
|
||||||
|
} else {
|
||||||
|
rejected = !responseChecker(response)
|
||||||
|
}
|
||||||
|
if rejected {
|
||||||
|
if c.rdrc != nil {
|
||||||
|
c.rdrc.SaveRDRCAsync(transport.Tag(), question.Name, question.Qtype, c.logger)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else if response.Rcode != dns.RcodeSuccess && response.Rcode != dns.RcodeNameError {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
timeToLive := applyResponseOptions(question, response, options)
|
||||||
|
c.storeCache(transport, question, response, timeToLive)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) prepareExchangeMessage(message *dns.Msg, options adapter.DNSQueryOptions) *dns.Msg {
|
||||||
|
clientSubnet := options.ClientSubnet
|
||||||
|
if !clientSubnet.IsValid() {
|
||||||
|
clientSubnet = c.clientSubnet
|
||||||
|
}
|
||||||
|
if clientSubnet.IsValid() {
|
||||||
|
message = SetClientSubnet(message, clientSubnet)
|
||||||
|
}
|
||||||
|
return message
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) exchangeToTransport(ctx context.Context, transport adapter.DNSTransport, message *dns.Msg) (*dns.Msg, error) {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, c.timeout)
|
||||||
|
defer cancel()
|
||||||
|
response, err := transport.Exchange(ctx, message)
|
||||||
|
if err == nil {
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
var rcodeError RcodeError
|
||||||
|
if errors.As(err, &rcodeError) {
|
||||||
|
return FixedResponseStatus(message, int(rcodeError)), nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func MessageToAddresses(response *dns.Msg) []netip.Addr {
|
func MessageToAddresses(response *dns.Msg) []netip.Addr {
|
||||||
if response == nil || response.Rcode != dns.RcodeSuccess {
|
return adapter.DNSResponseAddresses(response)
|
||||||
return nil
|
|
||||||
}
|
|
||||||
addresses := make([]netip.Addr, 0, len(response.Answer))
|
|
||||||
for _, rawAnswer := range response.Answer {
|
|
||||||
switch answer := rawAnswer.(type) {
|
|
||||||
case *dns.A:
|
|
||||||
addresses = append(addresses, M.AddrFromIP(answer.A))
|
|
||||||
case *dns.AAAA:
|
|
||||||
addresses = append(addresses, M.AddrFromIP(answer.AAAA))
|
|
||||||
case *dns.HTTPS:
|
|
||||||
for _, value := range answer.SVCB.Value {
|
|
||||||
if value.Key() == dns.SVCB_IPV4HINT || value.Key() == dns.SVCB_IPV6HINT {
|
|
||||||
addresses = append(addresses, common.Map(strings.Split(value.String(), ","), M.ParseAddr)...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return addresses
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func wrapError(err error) error {
|
func wrapError(err error) error {
|
||||||
|
|||||||
@@ -22,6 +22,19 @@ func logCachedResponse(logger logger.ContextLogger, ctx context.Context, respons
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func logOptimisticResponse(logger logger.ContextLogger, ctx context.Context, response *dns.Msg) {
|
||||||
|
if logger == nil || len(response.Question) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
domain := FqdnToDomain(response.Question[0].Name)
|
||||||
|
logger.DebugContext(ctx, "optimistic ", domain, " ", dns.RcodeToString[response.Rcode])
|
||||||
|
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
|
||||||
|
for _, record := range recordList {
|
||||||
|
logger.InfoContext(ctx, "optimistic ", dns.Type(record.Header().Rrtype).String(), " ", FormatQuestion(record.String()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func logExchangedResponse(logger logger.ContextLogger, ctx context.Context, response *dns.Msg, ttl uint32) {
|
func logExchangedResponse(logger logger.ContextLogger, ctx context.Context, response *dns.Msg, ttl uint32) {
|
||||||
if logger == nil || len(response.Question) == 0 {
|
if logger == nil || len(response.Question) == 0 {
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -5,10 +5,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
RcodeSuccess RcodeError = mDNS.RcodeSuccess
|
RcodeSuccess RcodeError = mDNS.RcodeSuccess
|
||||||
RcodeFormatError RcodeError = mDNS.RcodeFormatError
|
RcodeServerFailure RcodeError = mDNS.RcodeServerFailure
|
||||||
RcodeNameError RcodeError = mDNS.RcodeNameError
|
RcodeFormatError RcodeError = mDNS.RcodeFormatError
|
||||||
RcodeRefused RcodeError = mDNS.RcodeRefused
|
RcodeNameError RcodeError = mDNS.RcodeNameError
|
||||||
|
RcodeRefused RcodeError = mDNS.RcodeRefused
|
||||||
)
|
)
|
||||||
|
|
||||||
type RcodeError int
|
type RcodeError int
|
||||||
|
|||||||
111
dns/repro_test.go
Normal file
111
dns/repro_test.go
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
C "github.com/sagernet/sing-box/constant"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
"github.com/sagernet/sing/common/json/badoption"
|
||||||
|
|
||||||
|
mDNS "github.com/miekg/dns"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestReproLookupWithRulesUsesRequestStrategy(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
|
||||||
|
var qTypes []uint16
|
||||||
|
router := newTestRouter(t, nil, &fakeDNSTransportManager{
|
||||||
|
defaultTransport: defaultTransport,
|
||||||
|
transports: map[string]adapter.DNSTransport{
|
||||||
|
"default": defaultTransport,
|
||||||
|
},
|
||||||
|
}, &fakeDNSClient{
|
||||||
|
exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||||||
|
qTypes = append(qTypes, message.Question[0].Qtype)
|
||||||
|
if message.Question[0].Qtype == mDNS.TypeA {
|
||||||
|
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2.2.2.2")}, 60), nil
|
||||||
|
}
|
||||||
|
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2001:db8::1")}, 60), nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{
|
||||||
|
Strategy: C.DomainStrategyIPv4Only,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, []uint16{mDNS.TypeA}, qTypes)
|
||||||
|
require.Equal(t, []netip.Addr{netip.MustParseAddr("2.2.2.2")}, addresses)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReproLogicalMatchResponseIPCIDR(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
transportManager := &fakeDNSTransportManager{
|
||||||
|
defaultTransport: &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP},
|
||||||
|
transports: map[string]adapter.DNSTransport{
|
||||||
|
"upstream": &fakeDNSTransport{tag: "upstream", transportType: C.DNSTypeUDP},
|
||||||
|
"selected": &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP},
|
||||||
|
"default": &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
client := &fakeDNSClient{
|
||||||
|
exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||||||
|
switch transport.Tag() {
|
||||||
|
case "upstream":
|
||||||
|
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("1.1.1.1")}, 60), nil
|
||||||
|
case "selected":
|
||||||
|
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60), nil
|
||||||
|
default:
|
||||||
|
return nil, E.New("unexpected transport")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
rules := []option.DNSRule{
|
||||||
|
{
|
||||||
|
Type: C.RuleTypeDefault,
|
||||||
|
DefaultOptions: option.DefaultDNSRule{
|
||||||
|
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||||
|
Domain: badoption.Listable[string]{"example.com"},
|
||||||
|
},
|
||||||
|
DNSRuleAction: option.DNSRuleAction{
|
||||||
|
Action: C.RuleActionTypeEvaluate,
|
||||||
|
RouteOptions: option.DNSRouteActionOptions{Server: "upstream"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: C.RuleTypeLogical,
|
||||||
|
LogicalOptions: option.LogicalDNSRule{
|
||||||
|
RawLogicalDNSRule: option.RawLogicalDNSRule{
|
||||||
|
Mode: C.LogicalTypeOr,
|
||||||
|
Rules: []option.DNSRule{{
|
||||||
|
Type: C.RuleTypeDefault,
|
||||||
|
DefaultOptions: option.DefaultDNSRule{
|
||||||
|
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||||
|
MatchResponse: true,
|
||||||
|
IPCIDR: badoption.Listable[string]{"1.1.1.0/24"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
DNSRuleAction: option.DNSRuleAction{
|
||||||
|
Action: C.RuleActionTypeRoute,
|
||||||
|
RouteOptions: option.DNSRouteActionOptions{Server: "selected"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
router := newTestRouter(t, rules, transportManager, client)
|
||||||
|
|
||||||
|
response, err := router.Exchange(context.Background(), &mDNS.Msg{
|
||||||
|
Question: []mDNS.Question{fixedQuestion("example.com", mDNS.TypeA)},
|
||||||
|
}, adapter.DNSQueryOptions{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, []netip.Addr{netip.MustParseAddr("8.8.8.8")}, MessageToAddresses(response))
|
||||||
|
}
|
||||||
847
dns/router.go
847
dns/router.go
File diff suppressed because it is too large
Load Diff
2547
dns/router_test.go
Normal file
2547
dns/router_test.go
Normal file
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user