Compare commits

..

85 Commits

Author SHA1 Message Date
世界
43f96c4279 documentation: Bump version 2025-02-05 10:08:04 +08:00
世界
d3fc35a529 Use x/net publicsuffix library instead 2025-02-05 10:08:04 +08:00
世界
bc200170b2 Fix UDP DNS not saving connections 2025-02-05 10:08:04 +08:00
世界
6158af1858 Fix dns router crash 2025-02-05 10:08:04 +08:00
世界
f4bab7ee92 Fix fakeip not started (2) 2025-02-05 10:08:04 +08:00
世界
e5bc72f8ed Fix WireGuard panic 2025-02-05 10:08:04 +08:00
世界
985a48d399 Fix query options leakage 2025-02-05 10:08:04 +08:00
世界
f34c3c937e Fix domain resolver for DNS server 2025-02-05 10:08:04 +08:00
世界
faeed289a8 documentation: Fix fakeip example 2025-02-05 10:08:04 +08:00
世界
3e6124fb17 release: Skip testflight when another build in review 2025-02-05 10:08:04 +08:00
世界
f2f4bd4efb Update quic-go to v0.49.0-beta.1 2025-02-05 10:08:04 +08:00
世界
87317565a3 Fix fakeip not started 2025-02-05 10:08:04 +08:00
世界
146793bbe5 Fix missing default domain resolver in route 2025-02-05 10:08:04 +08:00
世界
a0b6fcf5c6 Fix missing default store value 2025-02-05 10:08:04 +08:00
世界
5cf02687e1 documentation: Remove outdated icons 2025-02-05 10:04:34 +08:00
世界
29d157d383 documentation: Certificate store 2025-02-05 10:04:33 +08:00
世界
89afd175e0 documentation: TLS fragment 2025-02-05 10:04:33 +08:00
世界
961277a988 documentation: Outbound domain resolver 2025-02-05 10:04:33 +08:00
世界
1e9dfb8335 documentation: Refactor DNS 2025-02-05 10:04:32 +08:00
世界
9982f29cd7 Add certificate store 2025-02-05 10:04:32 +08:00
世界
fc90e49924 Add TLS fragment support 2025-02-05 10:04:22 +08:00
世界
c017a69bc9 refactor: Outbound domain resolver 2025-02-05 10:04:22 +08:00
世界
4377d381c8 refactor: DNS 2025-02-05 10:04:21 +08:00
世界
92d245ad04 Bump version 2025-02-05 09:59:52 +08:00
世界
0908627297 Fix crash on remote rule-set stop 2025-02-05 08:58:10 +08:00
世界
7f79458b4f Minor updates 2025-02-01 19:49:33 +08:00
世界
9b4c11ba95 Fix rule-set not closed 2025-02-01 19:49:33 +08:00
世界
27c31eac5d Fix local rule-set not updated 2025-02-01 19:42:21 +08:00
世界
bab8dc0b82 Fix missing handshake for early conn 2025-01-31 12:57:35 +08:00
世界
d09d2fb665 Bump version 2025-01-30 15:09:54 +08:00
世界
e64cf3b7df Do not set address sets to routes on Apple platforms
Network Extension was observed to stop for unknown reasons
2025-01-27 13:40:39 +08:00
世界
9b73222314 documentation: Bump version 2025-01-27 10:53:14 +08:00
世界
3923b57abf release: Update NDK to r28-beta3 2025-01-27 10:53:14 +08:00
世界
4807e64609 documentation: Fix typo 2025-01-27 10:04:07 +08:00
世界
eeb37d89f1 Fix rule-set upgrade command 2025-01-27 09:50:27 +08:00
世界
08c1ec4b7e Fix legacy routes 2025-01-27 09:50:27 +08:00
HystericalDragon
6b4cf67add Fix endpoints not close 2025-01-27 09:50:27 +08:00
世界
e65926fd08 Fix tests 2025-01-13 15:14:30 +08:00
世界
f2ec319fe1 Fix system time 2025-01-13 15:14:30 +08:00
世界
32377a61b7 Add port hopping for hysteria2 2025-01-13 15:14:30 +08:00
世界
7aac801ccd tun: Set address sets to routes 2025-01-13 15:14:30 +08:00
世界
96fdf59ee4 Fix default dialer on legacy xiaomi systems 2025-01-13 15:14:30 +08:00
世界
50b8f3ab94 Add rule-set merge command 2025-01-13 15:14:30 +08:00
世界
ff7aaf977b Fix DNS match 2025-01-13 15:14:30 +08:00
世界
9a1efbe54d Fix domain strategy 2025-01-13 15:14:30 +08:00
世界
906c21f458 Fix time service 2025-01-13 15:14:30 +08:00
世界
d5e7af7a7e Fix socks5 UDP implementation 2025-01-13 15:14:30 +08:00
世界
4d41f03bd5 clash-api: Fix missing endpoints 2025-01-13 15:14:30 +08:00
世界
30704a15a7 hysteria2: Add more masquerade options 2025-01-13 15:14:30 +08:00
世界
83889178ed Improve timeouts 2025-01-13 15:14:30 +08:00
世界
1d2720bf5e Add UDP timeout route option 2025-01-13 15:14:30 +08:00
世界
c4b6d0eadb Make GSO adaptive 2025-01-13 15:14:30 +08:00
世界
0c66888691 Fix lint 2025-01-13 15:14:30 +08:00
世界
68781387fe refactor: WireGuard endpoint 2025-01-13 15:14:30 +08:00
世界
fd299a0961 refactor: connection manager 2025-01-13 15:14:30 +08:00
世界
285a82050c documentation: Fix typo 2025-01-13 15:14:30 +08:00
世界
2dbb8c55c9 Add override destination to route options 2025-01-13 15:14:30 +08:00
世界
effcf39469 Add dns.cache_capacity 2025-01-13 15:14:30 +08:00
世界
9db9484863 Refactor multi networks strategy 2025-01-13 15:14:30 +08:00
世界
ca813f461b documentation: Remove unused titles 2025-01-13 15:14:30 +08:00
世界
bb46cdb2b3 Add multi network dialing 2025-01-13 15:14:30 +08:00
世界
dcb10c21a1 documentation: Merge route options to route actions 2025-01-13 15:14:29 +08:00
世界
05ea0ca00e Add network_[type/is_expensive/is_constrained] rule items 2025-01-13 15:14:29 +08:00
世界
c098f282b1 Merge route options to route actions 2025-01-13 15:14:29 +08:00
世界
ecf82d197c refactor: Platform Interfaces 2025-01-13 15:14:29 +08:00
世界
9afe75586a refactor: Extract services form router 2025-01-13 15:14:29 +08:00
世界
a1be455202 refactor: Modular network manager 2025-01-13 15:14:29 +08:00
世界
19fb214226 refactor: Modular inbound/outbound manager 2025-01-13 15:14:29 +08:00
世界
28ec898a8c documentation: Add rule action 2025-01-13 15:14:29 +08:00
世界
467b1bbeeb documentation: Update the scheduled removal time of deprecated features 2025-01-13 15:14:29 +08:00
世界
02ab8ce806 documentation: Remove outdated icons 2025-01-13 15:14:29 +08:00
世界
ce69e620e9 Migrate bad options to library 2025-01-13 15:14:29 +08:00
世界
1133cf3ef5 Implement udp connect 2025-01-13 15:14:29 +08:00
世界
59a607e303 Implement new deprecated warnings 2025-01-13 15:14:29 +08:00
世界
313be3d7a4 Improve rule actions 2025-01-13 15:14:29 +08:00
世界
4fe40fcee0 Remove unused reject methods 2025-01-13 15:14:29 +08:00
世界
e233fd4fe5 refactor: Modular inbounds/outbounds 2025-01-13 15:14:29 +08:00
世界
9f7683818f Implement dns-hijack 2025-01-13 15:14:29 +08:00
世界
179e3cb2f5 Implement resolve(server) 2025-01-13 15:14:29 +08:00
世界
41b960552d Implement TCP and ICMP rejects 2025-01-13 15:14:29 +08:00
世界
8304295c48 Crazy sekai overturns the small pond 2025-01-13 15:14:29 +08:00
世界
253b41936e Bump version 2025-01-11 20:58:00 +08:00
世界
ce5b4b06b5 auto-redirect: Fix fetch interfaces 2025-01-07 21:24:52 +08:00
世界
50f5006c43 Fix leak in reality server 2025-01-07 17:27:10 +08:00
世界
e42ff22c2e quic: Fix source not unwrapped 2025-01-07 17:27:10 +08:00
235 changed files with 13477 additions and 3798 deletions

View File

@@ -144,7 +144,7 @@ jobs:
~/go/go1.20.14
key: go120
- name: Setup legacy Go
if: matrix.require_legacy_go == 'true' && steps.cache-legacy-go.outputs.cache-hit != 'true'
if: matrix.require_legacy_go && steps.cache-legacy-go.outputs.cache-hit != 'true'
run: |-
wget https://dl.google.com/go/go1.20.14.linux-amd64.tar.gz
tar -xzf go1.20.14.linux-amd64.tar.gz
@@ -159,7 +159,7 @@ jobs:
uses: goreleaser/goreleaser-action@v6
with:
distribution: goreleaser-pro
version: latest
version: 2.5.1
install-only: true
- name: Extract signing key
run: |-
@@ -224,7 +224,7 @@ jobs:
id: setup-ndk
uses: nttld/setup-ndk@v1
with:
ndk-version: r28-beta2
ndk-version: r28-beta3
- name: Setup OpenJDK
run: |-
sudo apt update && sudo apt install -y openjdk-17-jdk-headless
@@ -256,8 +256,7 @@ jobs:
with:
path: ~/.gradle
key: gradle-${{ hashFiles('**/*.gradle') }}
- name: Build release
if: github.event_name == 'workflow_dispatch'
- name: Build
run: |-
go run -v ./cmd/internal/update_android_version --ci
mkdir clients/android/app/libs
@@ -268,47 +267,18 @@ jobs:
JAVA_HOME: /usr/lib/jvm/java-17-openjdk-amd64
ANDROID_NDK_HOME: ${{ steps.setup-ndk.outputs.ndk-path }}
LOCAL_PROPERTIES: ${{ secrets.LOCAL_PROPERTIES }}
- name: Build debug
if: github.event_name != 'workflow_dispatch'
run: |-
go run -v ./cmd/internal/update_android_version --ci
mkdir clients/android/app/libs
cp libbox.aar clients/android/app/libs
cd clients/android
./gradlew :app:assemblePlayRelease
env:
JAVA_HOME: /usr/lib/jvm/java-17-openjdk-amd64
ANDROID_NDK_HOME: ${{ steps.setup-ndk.outputs.ndk-path }}
LOCAL_PROPERTIES: ${{ secrets.LOCAL_PROPERTIES }}
- name: Prepare release upload
- name: Prepare upload
if: github.event_name == 'workflow_dispatch'
run: |-
mkdir -p dist/release
cp clients/android/app/build/outputs/apk/play/release/*.apk dist/release
cp clients/android/app/build/outputs/apk/other/release/*-universal.apk dist/release
- name: Prepare debug upload
if: github.event_name != 'workflow_dispatch'
run: |-
mkdir -p dist/release
cp clients/android/app/build/outputs/apk/play/release/*.apk dist/release
- name: Upload artifact
if: github.event_name == 'workflow_dispatch'
uses: actions/upload-artifact@v4
with:
name: binary-android-apks
path: 'dist'
- name: Upload debug apk (arm64-v8a)
if: github.event_name != 'workflow_dispatch'
uses: actions/upload-artifact@v4
with:
name: "SFA-${{ needs.calculate_version.outputs.version }}-arm64-v8a.apk"
path: 'dist/release/*-arm64-v8a.apk'
- name: Upload debug apk (universal)
if: github.event_name != 'workflow_dispatch'
uses: actions/upload-artifact@v4
with:
name: "SFA-${{ needs.calculate_version.outputs.version }}-universal.apk"
path: 'dist/release/*-universal.apk'
publish_android:
name: Publish Android
if: github.event_name == 'workflow_dispatch' && inputs.build == 'publish-android'
@@ -329,7 +299,7 @@ jobs:
id: setup-ndk
uses: nttld/setup-ndk@v1
with:
ndk-version: r28-beta2
ndk-version: r28-beta3
- name: Setup OpenJDK
run: |-
sudo apt update && sudo apt install -y openjdk-17-jdk-headless
@@ -578,7 +548,7 @@ jobs:
uses: goreleaser/goreleaser-action@v6
with:
distribution: goreleaser-pro
version: latest
version: 2.5.1
install-only: true
- name: Cache ghr
uses: actions/cache@v4

View File

@@ -6,7 +6,9 @@ builds:
- -v
- -trimpath
ldflags:
- -X github.com/sagernet/sing-box/constant.Version={{ .Version }} -s -w -buildid=
- -X github.com/sagernet/sing-box/constant.Version={{ .Version }}
- -s
- -buildid=
tags:
- with_gvisor
- with_quic

View File

@@ -52,7 +52,7 @@ builds:
env:
- CGO_ENABLED=0
- GOROOT={{ .Env.GOPATH }}/go1.20.14
gobinary: "{{ .Env.GOPATH }}/go1.20.14/bin/go"
tool: "{{ .Env.GOPATH }}/go1.20.14/bin/go"
targets:
- windows_amd64_v1
- windows_386

View File

@@ -61,6 +61,9 @@ proto_install:
go install -v google.golang.org/protobuf/cmd/protoc-gen-go@latest
go install -v google.golang.org/grpc/cmd/protoc-gen-go-grpc@latest
update_certificates:
go run ./cmd/internal/update_certificates
release:
go run ./cmd/internal/build goreleaser release --clean --skip publish
mkdir dist/release

21
adapter/certificate.go Normal file
View File

@@ -0,0 +1,21 @@
package adapter
import (
"context"
"crypto/x509"
"github.com/sagernet/sing/service"
)
type CertificateStore interface {
LifecycleService
Pool() *x509.CertPool
}
func RootPoolFromContext(ctx context.Context) *x509.CertPool {
store := service.FromContext[CertificateStore](ctx)
if store == nil {
return nil
}
return store.Pool()
}

73
adapter/dns.go Normal file
View File

@@ -0,0 +1,73 @@
package adapter
import (
"context"
"net/netip"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common/logger"
"github.com/miekg/dns"
)
type DNSRouter interface {
Lifecycle
Exchange(ctx context.Context, message *dns.Msg, options DNSQueryOptions) (*dns.Msg, error)
Lookup(ctx context.Context, domain string, options DNSQueryOptions) ([]netip.Addr, error)
ClearCache()
LookupReverseMapping(ip netip.Addr) (string, bool)
ResetNetwork()
}
type DNSClient interface {
Start()
Exchange(ctx context.Context, transport DNSTransport, message *dns.Msg, options DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) (*dns.Msg, error)
Lookup(ctx context.Context, transport DNSTransport, domain string, options DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error)
LookupCache(domain string, strategy C.DomainStrategy) ([]netip.Addr, bool)
ExchangeCache(ctx context.Context, message *dns.Msg) (*dns.Msg, bool)
ClearCache()
}
type DNSQueryOptions struct {
Transport DNSTransport
Strategy C.DomainStrategy
DisableCache bool
RewriteTTL *uint32
ClientSubnet netip.Prefix
}
type RDRCStore interface {
LoadRDRC(transportName string, qName string, qType uint16) (rejected bool)
SaveRDRC(transportName string, qName string, qType uint16) error
SaveRDRCAsync(transportName string, qName string, qType uint16, logger logger.Logger)
}
type DNSTransport interface {
Type() string
Tag() string
Dependencies() []string
Reset()
Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error)
}
type LegacyDNSTransport interface {
LegacyStrategy() C.DomainStrategy
LegacyClientSubnet() netip.Prefix
}
type DNSTransportRegistry interface {
option.DNSTransportOptionsRegistry
CreateDNSTransport(ctx context.Context, logger log.ContextLogger, tag string, transportType string, options any) (DNSTransport, error)
}
type DNSTransportManager interface {
Lifecycle
Transports() []DNSTransport
Transport(tag string) (DNSTransport, bool)
Default() DNSTransport
FakeIP() FakeIPTransport
Remove(tag string) error
Create(ctx context.Context, logger log.ContextLogger, tag string, outboundType string, options any) error
}

View File

@@ -6,8 +6,6 @@ import (
"encoding/binary"
"time"
"github.com/sagernet/sing-box/common/urltest"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common/varbin"
)
@@ -16,7 +14,20 @@ type ClashServer interface {
ConnectionTracker
Mode() string
ModeList() []string
HistoryStorage() *urltest.HistoryStorage
HistoryStorage() URLTestHistoryStorage
}
type URLTestHistory struct {
Time time.Time `json:"time"`
Delay uint16 `json:"delay"`
}
type URLTestHistoryStorage interface {
SetHook(hook chan<- struct{})
LoadURLTestHistory(tag string) *URLTestHistory
DeleteURLTestHistory(tag string)
StoreURLTestHistory(tag string, history *URLTestHistory)
Close() error
}
type V2RayServer interface {
@@ -31,7 +42,7 @@ type CacheFile interface {
FakeIPStorage
StoreRDRC() bool
dns.RDRCStore
RDRCStore
LoadMode() string
StoreMode(mode string) error
@@ -39,17 +50,17 @@ type CacheFile interface {
StoreSelected(group string, selected string) error
LoadGroupExpand(group string) (isExpand bool, loaded bool)
StoreGroupExpand(group string, expand bool) error
LoadRuleSet(tag string) *SavedRuleSet
SaveRuleSet(tag string, set *SavedRuleSet) error
LoadRuleSet(tag string) *SavedBinary
SaveRuleSet(tag string, set *SavedBinary) error
}
type SavedRuleSet struct {
type SavedBinary struct {
Content []byte
LastUpdated time.Time
LastEtag string
}
func (s *SavedRuleSet) MarshalBinary() ([]byte, error) {
func (s *SavedBinary) MarshalBinary() ([]byte, error) {
var buffer bytes.Buffer
err := binary.Write(&buffer, binary.BigEndian, uint8(1))
if err != nil {
@@ -70,7 +81,7 @@ func (s *SavedRuleSet) MarshalBinary() ([]byte, error) {
return buffer.Bytes(), nil
}
func (s *SavedRuleSet) UnmarshalBinary(data []byte) error {
func (s *SavedBinary) UnmarshalBinary(data []byte) error {
reader := bytes.NewReader(data)
var version uint8
err := binary.Read(reader, binary.BigEndian, &version)

View File

@@ -3,7 +3,6 @@ package adapter
import (
"net/netip"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common/logger"
)
@@ -27,6 +26,6 @@ type FakeIPStorage interface {
}
type FakeIPTransport interface {
dns.Transport
DNSTransport
Store() FakeIPStore
}

View File

@@ -71,14 +71,14 @@ type InboundContext struct {
UDPDisableDomainUnmapping bool
UDPConnect bool
UDPTimeout time.Duration
TLSFragment bool
TLSFragmentFallbackDelay time.Duration
NetworkStrategy *C.NetworkStrategy
NetworkType []C.InterfaceType
FallbackNetworkType []C.InterfaceType
FallbackDelay time.Duration
DNSServer string
DestinationAddresses []netip.Addr
SourceGeoIPCode string
GeoIPCode string

View File

@@ -28,12 +28,14 @@ type NetworkManager interface {
}
type NetworkOptions struct {
NetworkStrategy *C.NetworkStrategy
NetworkType []C.InterfaceType
FallbackNetworkType []C.InterfaceType
FallbackDelay time.Duration
BindInterface string
RoutingMark uint32
BindInterface string
RoutingMark uint32
DomainResolver string
DomainResolveOptions DNSQueryOptions
NetworkStrategy *C.NetworkStrategy
NetworkType []C.InterfaceType
FallbackNetworkType []C.InterfaceType
FallbackDelay time.Duration
}
type InterfaceUpdateListener interface {

View File

@@ -23,7 +23,7 @@ type Manager struct {
registry adapter.OutboundRegistry
endpoint adapter.EndpointManager
defaultTag string
access sync.Mutex
access sync.RWMutex
started bool
stage adapter.StartStage
outbounds []adapter.Outbound
@@ -169,15 +169,15 @@ func (m *Manager) Close() error {
}
func (m *Manager) Outbounds() []adapter.Outbound {
m.access.Lock()
defer m.access.Unlock()
m.access.RLock()
defer m.access.RUnlock()
return m.outbounds
}
func (m *Manager) Outbound(tag string) (adapter.Outbound, bool) {
m.access.Lock()
m.access.RLock()
outbound, found := m.outboundByTag[tag]
m.access.Unlock()
m.access.RUnlock()
if found {
return outbound, true
}
@@ -185,8 +185,8 @@ func (m *Manager) Outbound(tag string) (adapter.Outbound, bool) {
}
func (m *Manager) Default() adapter.Outbound {
m.access.Lock()
defer m.access.Unlock()
m.access.RLock()
defer m.access.RUnlock()
if m.defaultOutbound != nil {
return m.defaultOutbound
} else {
@@ -196,9 +196,9 @@ func (m *Manager) Default() adapter.Outbound {
func (m *Manager) Remove(tag string) error {
m.access.Lock()
defer m.access.Unlock()
outbound, found := m.outboundByTag[tag]
if !found {
m.access.Unlock()
return os.ErrInvalid
}
delete(m.outboundByTag, tag)
@@ -232,7 +232,6 @@ func (m *Manager) Remove(tag string) error {
})
}
}
m.access.Unlock()
if started {
return common.Close(outbound)
}

View File

@@ -2,44 +2,29 @@ package adapter
import (
"context"
"crypto/tls"
"net"
"net/http"
"net/netip"
"sync"
"github.com/sagernet/sing-box/common/geoip"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-dns"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp"
"github.com/sagernet/sing/common/x/list"
mdns "github.com/miekg/dns"
"go4.org/netipx"
)
type Router interface {
Lifecycle
FakeIPStore() FakeIPStore
ConnectionRouter
PreMatch(metadata InboundContext) error
ConnectionRouterEx
GeoIPReader() *geoip.Reader
LoadGeosite(code string) (Rule, error)
RuleSet(tag string) (RuleSet, bool)
NeedWIFIState() bool
Exchange(ctx context.Context, message *mdns.Msg) (*mdns.Msg, error)
Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error)
LookupDefault(ctx context.Context, domain string) ([]netip.Addr, error)
ClearDNSCache()
Rules() []Rule
SetTracker(tracker ConnectionTracker)
ResetNetwork()
}
@@ -83,12 +68,14 @@ type RuleSetMetadata struct {
ContainsIPCIDRRule bool
}
type HTTPStartContext struct {
ctx context.Context
access sync.Mutex
httpClientCache map[string]*http.Client
}
func NewHTTPStartContext() *HTTPStartContext {
func NewHTTPStartContext(ctx context.Context) *HTTPStartContext {
return &HTTPStartContext{
ctx: ctx,
httpClientCache: make(map[string]*http.Client),
}
}
@@ -106,6 +93,10 @@ func (c *HTTPStartContext) HTTPClient(detour string, dialer N.Dialer) *http.Clie
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
},
TLSClientConfig: &tls.Config{
Time: ntp.TimeFuncFromContext(c.ctx),
RootCAs: RootPoolFromContext(c.ctx),
},
},
}
c.httpClientCache[detour] = httpClient

View File

@@ -13,7 +13,6 @@ type Rule interface {
HeadlessRule
Service
Type() string
UpdateGeosite() error
Action() RuleAction
}

140
box.go
View File

@@ -12,11 +12,13 @@ import (
"github.com/sagernet/sing-box/adapter/endpoint"
"github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/adapter/outbound"
"github.com/sagernet/sing-box/common/conntrack"
"github.com/sagernet/sing-box/common/certificate"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/common/taskmonitor"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/dns/transport/local"
"github.com/sagernet/sing-box/experimental"
"github.com/sagernet/sing-box/experimental/cachefile"
"github.com/sagernet/sing-box/experimental/libbox/platform"
@@ -35,17 +37,19 @@ import (
var _ adapter.Service = (*Box)(nil)
type Box struct {
createdAt time.Time
logFactory log.Factory
logger log.ContextLogger
network *route.NetworkManager
endpoint *endpoint.Manager
inbound *inbound.Manager
outbound *outbound.Manager
connection *route.ConnectionManager
router *route.Router
services []adapter.LifecycleService
done chan struct{}
createdAt time.Time
logFactory log.Factory
logger log.ContextLogger
network *route.NetworkManager
endpoint *endpoint.Manager
inbound *inbound.Manager
outbound *outbound.Manager
dnsTransport *dns.TransportManager
dnsRouter *dns.Router
connection *route.ConnectionManager
router *route.Router
services []adapter.LifecycleService
done chan struct{}
}
type Options struct {
@@ -59,6 +63,7 @@ func Context(
inboundRegistry adapter.InboundRegistry,
outboundRegistry adapter.OutboundRegistry,
endpointRegistry adapter.EndpointRegistry,
dnsTransportRegistry adapter.DNSTransportRegistry,
) context.Context {
if service.FromContext[option.InboundOptionsRegistry](ctx) == nil ||
service.FromContext[adapter.InboundRegistry](ctx) == nil {
@@ -75,6 +80,10 @@ func Context(
ctx = service.ContextWith[option.EndpointOptionsRegistry](ctx, endpointRegistry)
ctx = service.ContextWith[adapter.EndpointRegistry](ctx, endpointRegistry)
}
if service.FromContext[adapter.DNSTransportRegistry](ctx) == nil {
ctx = service.ContextWith[option.DNSTransportOptionsRegistry](ctx, dnsTransportRegistry)
ctx = service.ContextWith[adapter.DNSTransportRegistry](ctx, dnsTransportRegistry)
}
return ctx
}
@@ -85,9 +94,11 @@ func New(options Options) (*Box, error) {
ctx = context.Background()
}
ctx = service.ContextWithDefaultRegistry(ctx)
endpointRegistry := service.FromContext[adapter.EndpointRegistry](ctx)
inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx)
outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx)
dnsTransportRegistry := service.FromContext[adapter.DNSTransportRegistry](ctx)
if endpointRegistry == nil {
return nil, E.New("missing endpoint registry in context")
@@ -101,10 +112,7 @@ func New(options Options) (*Box, error) {
ctx = pause.WithDefaultManager(ctx)
experimentalOptions := common.PtrValueOrDefault(options.Experimental)
debugOptions := common.PtrValueOrDefault(experimentalOptions.Debug)
applyDebugOptions(debugOptions)
ctx = conntrack.ContextWithDefaultTracker(ctx, debugOptions.OOMKiller, uint64(debugOptions.MemoryLimit))
applyDebugOptions(common.PtrValueOrDefault(experimentalOptions.Debug))
var needCacheFile bool
var needClashAPI bool
var needV2RayAPI bool
@@ -134,14 +142,32 @@ func New(options Options) (*Box, error) {
return nil, E.Cause(err, "create log factory")
}
var services []adapter.LifecycleService
certificateOptions := common.PtrValueOrDefault(options.Certificate)
if C.IsAndroid || certificateOptions.Store != "" && certificateOptions.Store != C.CertificateStoreSystem ||
len(certificateOptions.Certificate) > 0 ||
len(certificateOptions.CertificatePath) > 0 ||
len(certificateOptions.CertificateDirectoryPath) > 0 {
certificateStore, err := certificate.NewStore(ctx, logFactory.NewLogger("certificate"), certificateOptions)
if err != nil {
return nil, err
}
service.MustRegister[adapter.CertificateStore](ctx, certificateStore)
services = append(services, certificateStore)
}
routeOptions := common.PtrValueOrDefault(options.Route)
dnsOptions := common.PtrValueOrDefault(options.DNS)
endpointManager := endpoint.NewManager(logFactory.NewLogger("endpoint"), endpointRegistry)
inboundManager := inbound.NewManager(logFactory.NewLogger("inbound"), inboundRegistry, endpointManager)
outboundManager := outbound.NewManager(logFactory.NewLogger("outbound"), outboundRegistry, endpointManager, routeOptions.Final)
dnsTransportManager := dns.NewTransportManager(logFactory.NewLogger("dns/transport"), dnsTransportRegistry, outboundManager, dnsOptions.Final)
service.MustRegister[adapter.EndpointManager](ctx, endpointManager)
service.MustRegister[adapter.InboundManager](ctx, inboundManager)
service.MustRegister[adapter.OutboundManager](ctx, outboundManager)
service.MustRegister[adapter.DNSTransportManager](ctx, dnsTransportManager)
dnsRouter := dns.NewRouter(ctx, logFactory, dnsOptions)
service.MustRegister[adapter.DNSRouter](ctx, dnsRouter)
networkManager, err := route.NewNetworkManager(ctx, logFactory.NewLogger("network"), routeOptions)
if err != nil {
return nil, E.Cause(err, "initialize network manager")
@@ -149,18 +175,40 @@ func New(options Options) (*Box, error) {
service.MustRegister[adapter.NetworkManager](ctx, networkManager)
connectionManager := route.NewConnectionManager(logFactory.NewLogger("connection"))
service.MustRegister[adapter.ConnectionManager](ctx, connectionManager)
router, err := route.NewRouter(ctx, logFactory, routeOptions, common.PtrValueOrDefault(options.DNS))
router := route.NewRouter(ctx, logFactory, routeOptions, dnsOptions)
service.MustRegister[adapter.Router](ctx, router)
err = router.Initialize(routeOptions.Rules, routeOptions.RuleSet)
if err != nil {
return nil, E.Cause(err, "initialize router")
}
ntpOptions := common.PtrValueOrDefault(options.NTP)
var timeService *tls.TimeServiceWrapper
if ntpOptions.Enabled {
timeService = new(tls.TimeServiceWrapper)
service.MustRegister[ntp.TimeService](ctx, timeService)
}
for i, transportOptions := range dnsOptions.Servers {
var tag string
if transportOptions.Tag != "" {
tag = transportOptions.Tag
} else {
tag = F.ToString(i)
}
err = dnsTransportManager.Create(
ctx,
logFactory.NewLogger(F.ToString("dns/", transportOptions.Type, "[", tag, "]")),
tag,
transportOptions.Type,
transportOptions.Options,
)
if err != nil {
return nil, E.Cause(err, "initialize DNS server[", i, "]")
}
}
err = dnsRouter.Initialize(dnsOptions.Rules)
if err != nil {
return nil, E.Cause(err, "initialize dns router")
}
for i, endpointOptions := range options.Endpoints {
var tag string
if endpointOptions.Tag != "" {
@@ -168,7 +216,8 @@ func New(options Options) (*Box, error) {
} else {
tag = F.ToString(i)
}
err = endpointManager.Create(ctx,
err = endpointManager.Create(
ctx,
router,
logFactory.NewLogger(F.ToString("endpoint/", endpointOptions.Type, "[", tag, "]")),
tag,
@@ -176,7 +225,7 @@ func New(options Options) (*Box, error) {
endpointOptions.Options,
)
if err != nil {
return nil, E.Cause(err, "initialize inbound[", i, "]")
return nil, E.Cause(err, "initialize endpoint[", i, "]")
}
}
for i, inboundOptions := range options.Inbounds {
@@ -186,7 +235,8 @@ func New(options Options) (*Box, error) {
} else {
tag = F.ToString(i)
}
err = inboundManager.Create(ctx,
err = inboundManager.Create(
ctx,
router,
logFactory.NewLogger(F.ToString("inbound/", inboundOptions.Type, "[", tag, "]")),
tag,
@@ -232,13 +282,19 @@ func New(options Options) (*Box, error) {
option.DirectOutboundOptions{},
),
))
dnsTransportManager.Initialize(common.Must1(
local.NewTransport(
ctx,
logFactory.NewLogger("dns/local"),
"local",
option.LocalDNSServerOptions{},
)))
if platformInterface != nil {
err = platformInterface.Initialize(networkManager)
if err != nil {
return nil, E.Cause(err, "initialize platform interface")
}
}
var services []adapter.LifecycleService
if needCacheFile {
cacheFile := cachefile.New(ctx, common.PtrValueOrDefault(experimentalOptions.CacheFile))
service.MustRegister[adapter.CacheFile](ctx, cacheFile)
@@ -267,7 +323,7 @@ func New(options Options) (*Box, error) {
}
}
if ntpOptions.Enabled {
ntpDialer, err := dialer.New(ctx, ntpOptions.DialerOptions)
ntpDialer, err := dialer.New(ctx, ntpOptions.DialerOptions, ntpOptions.ServerIsDomain())
if err != nil {
return nil, E.Cause(err, "create NTP service")
}
@@ -283,17 +339,19 @@ func New(options Options) (*Box, error) {
services = append(services, adapter.NewLifecycleService(ntpService, "ntp service"))
}
return &Box{
network: networkManager,
endpoint: endpointManager,
inbound: inboundManager,
outbound: outboundManager,
connection: connectionManager,
router: router,
createdAt: createdAt,
logFactory: logFactory,
logger: logFactory.Logger(),
services: services,
done: make(chan struct{}),
network: networkManager,
endpoint: endpointManager,
inbound: inboundManager,
outbound: outboundManager,
dnsTransport: dnsTransportManager,
dnsRouter: dnsRouter,
connection: connectionManager,
router: router,
createdAt: createdAt,
logFactory: logFactory,
logger: logFactory.Logger(),
services: services,
done: make(chan struct{}),
}, nil
}
@@ -347,11 +405,11 @@ func (s *Box) preStart() error {
if err != nil {
return err
}
err = adapter.Start(adapter.StartStateInitialize, s.network, s.connection, s.router, s.outbound, s.inbound, s.endpoint)
err = adapter.Start(adapter.StartStateInitialize, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint)
if err != nil {
return err
}
err = adapter.Start(adapter.StartStateStart, s.outbound, s.network, s.connection, s.router)
err = adapter.Start(adapter.StartStateStart, s.outbound, s.dnsTransport, s.dnsRouter, s.network, s.connection, s.router)
if err != nil {
return err
}
@@ -375,7 +433,7 @@ func (s *Box) start() error {
if err != nil {
return err
}
err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.connection, s.router, s.inbound, s.endpoint)
err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.inbound, s.endpoint)
if err != nil {
return err
}
@@ -383,7 +441,7 @@ func (s *Box) start() error {
if err != nil {
return err
}
err = adapter.Start(adapter.StartStateStarted, s.network, s.connection, s.router, s.outbound, s.inbound, s.endpoint)
err = adapter.Start(adapter.StartStateStarted, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint)
if err != nil {
return err
}
@@ -402,7 +460,7 @@ func (s *Box) Close() error {
close(s.done)
}
err := common.Close(
s.inbound, s.outbound, s.router, s.connection, s.network,
s.inbound, s.outbound, s.endpoint, s.router, s.connection, s.dnsRouter, s.dnsTransport, s.network,
)
for _, lifecycleService := range s.services {
err = E.Append(err, lifecycleService.Close(), func(err error) error {

View File

@@ -5,6 +5,7 @@ import (
"net/http"
"os"
"strconv"
"strings"
"time"
"github.com/sagernet/asc-go/asc"
@@ -194,6 +195,10 @@ func publishTestflight(ctx context.Context) error {
log.Info(string(platform), " ", tag, " create submission")
_, _, err = client.TestFlight.CreateBetaAppReviewSubmission(ctx, build.ID)
if err != nil {
if strings.Contains(err.Error(), "ANOTHER_BUILD_IN_REVIEW") {
log.Error(err)
break
}
return err
}
}

View File

@@ -48,7 +48,7 @@ func FindSDK() {
}
func findNDK() bool {
const fixedVersion = "28.0.12674087"
const fixedVersion = "28.0.12916984"
const versionFile = "source.properties"
if fixedPath := filepath.Join(androidSDKPath, "ndk", fixedVersion); rw.IsFile(filepath.Join(fixedPath, versionFile)) {
androidNDKPath = fixedPath

View File

@@ -0,0 +1,68 @@
package main
import (
"encoding/csv"
"io"
"net/http"
"os"
"strings"
"github.com/sagernet/sing-box/log"
"golang.org/x/exp/slices"
)
func main() {
err := updateMozillaIncludedRootCAs()
if err != nil {
log.Error(err)
}
}
func updateMozillaIncludedRootCAs() error {
response, err := http.Get("https://ccadb.my.salesforce-sites.com/mozilla/IncludedCACertificateReportPEMCSV")
if err != nil {
return err
}
defer response.Body.Close()
reader := csv.NewReader(response.Body)
header, err := reader.Read()
if err != nil {
return err
}
geoIndex := slices.Index(header, "Geographic Focus")
nameIndex := slices.Index(header, "Common Name or Certificate Name")
certIndex := slices.Index(header, "PEM Info")
generated := strings.Builder{}
generated.WriteString(`// Code generated by 'make update_certificates'. DO NOT EDIT.
package certificate
import "crypto/x509"
var mozillaIncluded *x509.CertPool
func init() {
mozillaIncluded = x509.NewCertPool()
`)
for {
record, err := reader.Read()
if err == io.EOF {
break
} else if err != nil {
return err
}
if record[geoIndex] == "China" {
continue
}
generated.WriteString("\n // ")
generated.WriteString(record[nameIndex])
generated.WriteString("\n")
generated.WriteString(" mozillaIncluded.AppendCertsFromPEM([]byte(`")
generated.WriteString(record[certIndex])
generated.WriteString("`))\n")
}
generated.WriteString("}\n")
return os.WriteFile("common/certificate/mozilla.go", []byte(generated.String()), 0o644)
}

View File

@@ -69,5 +69,5 @@ func preRun(cmd *cobra.Command, args []string) {
configPaths = append(configPaths, "config.json")
}
globalCtx = service.ContextWith(globalCtx, deprecated.NewStderrManager(log.StdLogger()))
globalCtx = box.Context(globalCtx, include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry())
globalCtx = box.Context(globalCtx, include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry(), include.DNSTransportRegistry())
}

View File

@@ -30,7 +30,7 @@ func init() {
}
func generateTLSKeyPair(serverName string) error {
privateKeyPem, publicKeyPem, err := tls.GenerateKeyPair(time.Now, serverName, time.Now().AddDate(0, flagGenerateTLSKeyPairMonths, 0))
privateKeyPem, publicKeyPem, err := tls.GenerateCertificate(nil, nil, time.Now, serverName, time.Now().AddDate(0, flagGenerateTLSKeyPairMonths, 0))
if err != nil {
return err
}

View File

@@ -61,14 +61,15 @@ func upgradeRuleSet(sourcePath string) error {
log.Info("already up-to-date")
return nil
}
plainRuleSet, err := plainRuleSetCompat.Upgrade()
plainRuleSetCompat.Options, err = plainRuleSetCompat.Upgrade()
if err != nil {
return err
}
plainRuleSetCompat.Version = C.RuleSetVersionCurrent
buffer := new(bytes.Buffer)
encoder := json.NewEncoder(buffer)
encoder.SetIndent("", " ")
err = encoder.Encode(plainRuleSet)
err = encoder.Encode(plainRuleSetCompat)
if err != nil {
return E.Cause(err, "encode config")
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"os"
"github.com/sagernet/sing-box/common/settings"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
E "github.com/sagernet/sing/common/exceptions"
@@ -58,7 +57,7 @@ func syncTime() error {
return err
}
if commandSyncTimeWrite {
err = settings.SetSystemTime(response.Time)
err = ntp.SetSystemTime(response.Time)
if err != nil {
return E.Cause(err, "write time to system")
}

File diff suppressed because it is too large Load Diff

184
common/certificate/store.go Normal file
View File

@@ -0,0 +1,184 @@
package certificate
import (
"context"
"crypto/x509"
"io/fs"
"os"
"path/filepath"
"strings"
"github.com/sagernet/fswatch"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/experimental/libbox/platform"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/service"
)
var _ adapter.CertificateStore = (*Store)(nil)
type Store struct {
systemPool *x509.CertPool
currentPool *x509.CertPool
certificate string
certificatePaths []string
certificateDirectoryPaths []string
watcher *fswatch.Watcher
}
func NewStore(ctx context.Context, logger logger.Logger, options option.CertificateOptions) (*Store, error) {
var systemPool *x509.CertPool
switch options.Store {
case C.CertificateStoreSystem, "":
platformInterface := service.FromContext[platform.Interface](ctx)
systemCertificates := platformInterface.SystemCertificates()
if len(systemCertificates) > 0 {
systemPool = x509.NewCertPool()
for _, cert := range systemCertificates {
if !systemPool.AppendCertsFromPEM([]byte(cert)) {
return nil, E.New("invalid system certificate PEM: ", cert)
}
}
} else {
certPool, err := x509.SystemCertPool()
if err != nil {
return nil, err
}
systemPool = certPool
}
case C.CertificateStoreMozilla:
systemPool = mozillaIncluded
case C.CertificateStoreNone:
systemPool = nil
default:
return nil, E.New("unknown certificate store: ", options.Store)
}
store := &Store{
systemPool: systemPool,
certificate: strings.Join(options.Certificate, "\n"),
certificatePaths: options.CertificatePath,
certificateDirectoryPaths: options.CertificateDirectoryPath,
}
var watchPaths []string
for _, target := range options.CertificatePath {
watchPaths = append(watchPaths, target)
}
for _, target := range options.CertificateDirectoryPath {
watchPaths = append(watchPaths, target)
}
if len(watchPaths) > 0 {
watcher, err := fswatch.NewWatcher(fswatch.Options{
Path: watchPaths,
Logger: logger,
Callback: func(_ string) {
err := store.update()
if err != nil {
logger.Error(E.Cause(err, "reload certificates"))
}
},
})
if err != nil {
return nil, E.Cause(err, "fswatch: create fsnotify watcher")
}
store.watcher = watcher
}
err := store.update()
if err != nil {
return nil, E.Cause(err, "initializing certificate store")
}
return store, nil
}
func (s *Store) Name() string {
return "certificate"
}
func (s *Store) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if s.watcher != nil {
return s.watcher.Start()
}
return nil
}
func (s *Store) Close() error {
if s.watcher != nil {
return s.watcher.Close()
}
return nil
}
func (s *Store) Pool() *x509.CertPool {
return s.currentPool
}
func (s *Store) update() error {
var currentPool *x509.CertPool
if s.systemPool == nil {
currentPool = x509.NewCertPool()
} else {
currentPool = s.systemPool.Clone()
}
if s.certificate != "" {
if !currentPool.AppendCertsFromPEM([]byte(s.certificate)) {
return E.New("invalid certificate PEM strings")
}
}
for _, path := range s.certificatePaths {
pemContent, err := os.ReadFile(path)
if err != nil {
return err
}
if !currentPool.AppendCertsFromPEM(pemContent) {
return E.New("invalid certificate PEM file: ", path)
}
}
var firstErr error
for _, directoryPath := range s.certificateDirectoryPaths {
directoryEntries, err := readUniqueDirectoryEntries(directoryPath)
if err != nil {
if firstErr == nil && !os.IsNotExist(err) {
firstErr = E.Cause(err, "invalid certificate directory: ", directoryPath)
}
continue
}
for _, directoryEntry := range directoryEntries {
pemContent, err := os.ReadFile(filepath.Join(directoryPath, directoryEntry.Name()))
if err == nil {
currentPool.AppendCertsFromPEM(pemContent)
}
}
}
if firstErr != nil {
return firstErr
}
s.currentPool = currentPool
return nil
}
func readUniqueDirectoryEntries(dir string) ([]fs.DirEntry, error) {
files, err := os.ReadDir(dir)
if err != nil {
return nil, err
}
uniq := files[:0]
for _, f := range files {
if !isSameDirSymlink(f, dir) {
uniq = append(uniq, f)
}
}
return uniq, nil
}
func isSameDirSymlink(f fs.DirEntry, dir string) bool {
if f.Type()&fs.ModeSymlink == 0 {
return false
}
target, err := os.Readlink(filepath.Join(dir, f.Name()))
return err == nil && !strings.Contains(target, "/")
}

54
common/conntrack/conn.go Normal file
View File

@@ -0,0 +1,54 @@
package conntrack
import (
"io"
"net"
"github.com/sagernet/sing/common/x/list"
)
type Conn struct {
net.Conn
element *list.Element[io.Closer]
}
func NewConn(conn net.Conn) (net.Conn, error) {
connAccess.Lock()
element := openConnection.PushBack(conn)
connAccess.Unlock()
if KillerEnabled {
err := KillerCheck()
if err != nil {
conn.Close()
return nil, err
}
}
return &Conn{
Conn: conn,
element: element,
}, nil
}
func (c *Conn) Close() error {
if c.element.Value != nil {
connAccess.Lock()
if c.element.Value != nil {
openConnection.Remove(c.element)
c.element.Value = nil
}
connAccess.Unlock()
}
return c.Conn.Close()
}
func (c *Conn) Upstream() any {
return c.Conn
}
func (c *Conn) ReaderReplaceable() bool {
return true
}
func (c *Conn) WriterReplaceable() bool {
return true
}

View File

@@ -1,14 +0,0 @@
package conntrack
import (
"context"
"github.com/sagernet/sing/service"
)
func ContextWithDefaultTracker(ctx context.Context, killerEnabled bool, memoryLimit uint64) context.Context {
if service.FromContext[Tracker](ctx) != nil {
return ctx
}
return service.ContextWith[Tracker](ctx, NewDefaultTracker(killerEnabled, memoryLimit))
}

View File

@@ -1,245 +0,0 @@
package conntrack
import (
"net"
"net/netip"
runtimeDebug "runtime/debug"
"sync"
"time"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/memory"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list"
)
var _ Tracker = (*DefaultTracker)(nil)
type DefaultTracker struct {
connAccess sync.RWMutex
connList list.List[net.Conn]
connAddress map[netip.AddrPort]netip.AddrPort
packetConnAccess sync.RWMutex
packetConnList list.List[AbstractPacketConn]
packetConnAddress map[netip.AddrPort]bool
pendingAccess sync.RWMutex
pendingList list.List[netip.AddrPort]
killerEnabled bool
memoryLimit uint64
killerLastCheck time.Time
}
func NewDefaultTracker(killerEnabled bool, memoryLimit uint64) *DefaultTracker {
return &DefaultTracker{
connAddress: make(map[netip.AddrPort]netip.AddrPort),
packetConnAddress: make(map[netip.AddrPort]bool),
killerEnabled: killerEnabled,
memoryLimit: memoryLimit,
}
}
func (t *DefaultTracker) NewConn(conn net.Conn) (net.Conn, error) {
err := t.KillerCheck()
if err != nil {
conn.Close()
return nil, err
}
t.connAccess.Lock()
element := t.connList.PushBack(conn)
t.connAddress[M.AddrPortFromNet(conn.LocalAddr())] = M.AddrPortFromNet(conn.RemoteAddr())
t.connAccess.Unlock()
return &Conn{
Conn: conn,
closeFunc: common.OnceFunc(func() {
t.removeConn(element)
}),
}, nil
}
func (t *DefaultTracker) NewConnEx(conn net.Conn) (N.CloseHandlerFunc, error) {
err := t.KillerCheck()
if err != nil {
conn.Close()
return nil, err
}
t.connAccess.Lock()
element := t.connList.PushBack(conn)
t.connAddress[M.AddrPortFromNet(conn.LocalAddr())] = M.AddrPortFromNet(conn.RemoteAddr())
t.connAccess.Unlock()
return N.OnceClose(func(it error) {
t.removeConn(element)
}), nil
}
func (t *DefaultTracker) NewPacketConn(conn net.PacketConn) (net.PacketConn, error) {
err := t.KillerCheck()
if err != nil {
conn.Close()
return nil, err
}
t.packetConnAccess.Lock()
element := t.packetConnList.PushBack(conn)
t.packetConnAddress[M.AddrPortFromNet(conn.LocalAddr())] = true
t.packetConnAccess.Unlock()
return &PacketConn{
PacketConn: conn,
closeFunc: common.OnceFunc(func() {
t.removePacketConn(element)
}),
}, nil
}
func (t *DefaultTracker) NewPacketConnEx(conn AbstractPacketConn) (N.CloseHandlerFunc, error) {
err := t.KillerCheck()
if err != nil {
conn.Close()
return nil, err
}
t.packetConnAccess.Lock()
element := t.packetConnList.PushBack(conn)
t.packetConnAddress[M.AddrPortFromNet(conn.LocalAddr())] = true
t.packetConnAccess.Unlock()
return N.OnceClose(func(it error) {
t.removePacketConn(element)
}), nil
}
func (t *DefaultTracker) CheckConn(source netip.AddrPort, destination netip.AddrPort) bool {
t.connAccess.RLock()
defer t.connAccess.RUnlock()
return t.connAddress[source] == destination
}
func (t *DefaultTracker) CheckPacketConn(source netip.AddrPort) bool {
t.packetConnAccess.RLock()
defer t.packetConnAccess.RUnlock()
return t.packetConnAddress[source]
}
func (t *DefaultTracker) AddPendingDestination(destination netip.AddrPort) func() {
t.pendingAccess.Lock()
defer t.pendingAccess.Unlock()
element := t.pendingList.PushBack(destination)
return func() {
t.pendingAccess.Lock()
defer t.pendingAccess.Unlock()
t.pendingList.Remove(element)
}
}
func (t *DefaultTracker) CheckDestination(destination netip.AddrPort) bool {
t.pendingAccess.RLock()
defer t.pendingAccess.RUnlock()
for element := t.pendingList.Front(); element != nil; element = element.Next() {
if element.Value == destination {
return true
}
}
return false
}
func (t *DefaultTracker) KillerCheck() error {
if !t.killerEnabled {
return nil
}
nowTime := time.Now()
if nowTime.Sub(t.killerLastCheck) < 3*time.Second {
return nil
}
t.killerLastCheck = nowTime
if memory.Total() > t.memoryLimit {
t.Close()
go func() {
time.Sleep(time.Second)
runtimeDebug.FreeOSMemory()
}()
return E.New("out of memory")
}
return nil
}
func (t *DefaultTracker) Count() int {
t.connAccess.RLock()
defer t.connAccess.RUnlock()
t.packetConnAccess.RLock()
defer t.packetConnAccess.RUnlock()
return t.connList.Len() + t.packetConnList.Len()
}
func (t *DefaultTracker) Close() {
t.connAccess.Lock()
for element := t.connList.Front(); element != nil; element = element.Next() {
element.Value.Close()
}
t.connList.Init()
t.connAccess.Unlock()
t.packetConnAccess.Lock()
for element := t.packetConnList.Front(); element != nil; element = element.Next() {
element.Value.Close()
}
t.packetConnList.Init()
t.packetConnAccess.Unlock()
}
func (t *DefaultTracker) removeConn(element *list.Element[net.Conn]) {
t.connAccess.Lock()
defer t.connAccess.Unlock()
delete(t.connAddress, M.AddrPortFromNet(element.Value.LocalAddr()))
t.connList.Remove(element)
}
func (t *DefaultTracker) removePacketConn(element *list.Element[AbstractPacketConn]) {
t.packetConnAccess.Lock()
defer t.packetConnAccess.Unlock()
delete(t.packetConnAddress, M.AddrPortFromNet(element.Value.LocalAddr()))
t.packetConnList.Remove(element)
}
type Conn struct {
net.Conn
closeFunc func()
}
func (c *Conn) Close() error {
c.closeFunc()
return c.Conn.Close()
}
func (c *Conn) Upstream() any {
return c.Conn
}
func (c *Conn) ReaderReplaceable() bool {
return true
}
func (c *Conn) WriterReplaceable() bool {
return true
}
type PacketConn struct {
net.PacketConn
closeFunc func()
}
func (c *PacketConn) Close() error {
c.closeFunc()
return c.PacketConn.Close()
}
func (c *PacketConn) Upstream() any {
return c.PacketConn
}
func (c *PacketConn) ReaderReplaceable() bool {
return true
}
func (c *PacketConn) WriterReplaceable() bool {
return true
}

View File

@@ -0,0 +1,35 @@
package conntrack
import (
runtimeDebug "runtime/debug"
"time"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/memory"
)
var (
KillerEnabled bool
MemoryLimit uint64
killerLastCheck time.Time
)
func KillerCheck() error {
if !KillerEnabled {
return nil
}
nowTime := time.Now()
if nowTime.Sub(killerLastCheck) < 3*time.Second {
return nil
}
killerLastCheck = nowTime
if memory.Total() > MemoryLimit {
Close()
go func() {
time.Sleep(time.Second)
runtimeDebug.FreeOSMemory()
}()
return E.New("out of memory")
}
return nil
}

View File

@@ -0,0 +1,55 @@
package conntrack
import (
"io"
"net"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/x/list"
)
type PacketConn struct {
net.PacketConn
element *list.Element[io.Closer]
}
func NewPacketConn(conn net.PacketConn) (net.PacketConn, error) {
connAccess.Lock()
element := openConnection.PushBack(conn)
connAccess.Unlock()
if KillerEnabled {
err := KillerCheck()
if err != nil {
conn.Close()
return nil, err
}
}
return &PacketConn{
PacketConn: conn,
element: element,
}, nil
}
func (c *PacketConn) Close() error {
if c.element.Value != nil {
connAccess.Lock()
if c.element.Value != nil {
openConnection.Remove(c.element)
c.element.Value = nil
}
connAccess.Unlock()
}
return c.PacketConn.Close()
}
func (c *PacketConn) Upstream() any {
return bufio.NewPacketConn(c.PacketConn)
}
func (c *PacketConn) ReaderReplaceable() bool {
return true
}
func (c *PacketConn) WriterReplaceable() bool {
return true
}

47
common/conntrack/track.go Normal file
View File

@@ -0,0 +1,47 @@
package conntrack
import (
"io"
"sync"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/x/list"
)
var (
connAccess sync.RWMutex
openConnection list.List[io.Closer]
)
func Count() int {
if !Enabled {
return 0
}
return openConnection.Len()
}
func List() []io.Closer {
if !Enabled {
return nil
}
connAccess.RLock()
defer connAccess.RUnlock()
connList := make([]io.Closer, 0, openConnection.Len())
for element := openConnection.Front(); element != nil; element = element.Next() {
connList = append(connList, element.Value)
}
return connList
}
func Close() {
if !Enabled {
return
}
connAccess.Lock()
defer connAccess.Unlock()
for element := openConnection.Front(); element != nil; element = element.Next() {
common.Close(element.Value)
element.Value = nil
}
openConnection.Init()
}

View File

@@ -0,0 +1,5 @@
//go:build !with_conntrack
package conntrack
const Enabled = false

View File

@@ -0,0 +1,5 @@
//go:build with_conntrack
package conntrack
const Enabled = true

View File

@@ -1,32 +0,0 @@
package conntrack
import (
"net"
"net/netip"
"time"
N "github.com/sagernet/sing/common/network"
)
// TODO: add to N
type AbstractPacketConn interface {
Close() error
LocalAddr() net.Addr
SetDeadline(t time.Time) error
SetReadDeadline(t time.Time) error
SetWriteDeadline(t time.Time) error
}
type Tracker interface {
NewConn(conn net.Conn) (net.Conn, error)
NewPacketConn(conn net.PacketConn) (net.PacketConn, error)
NewConnEx(conn net.Conn) (N.CloseHandlerFunc, error)
NewPacketConnEx(conn AbstractPacketConn) (N.CloseHandlerFunc, error)
CheckConn(source netip.AddrPort, destination netip.AddrPort) bool
CheckPacketConn(source netip.AddrPort) bool
AddPendingDestination(destination netip.AddrPort) func()
CheckDestination(destination netip.AddrPort) bool
KillerCheck() error
Count() int
Close()
}

View File

@@ -28,7 +28,6 @@ var (
)
type DefaultDialer struct {
tracker conntrack.Tracker
dialer4 tcpDialer
dialer6 tcpDialer
udpDialer4 net.Dialer
@@ -36,7 +35,6 @@ type DefaultDialer struct {
udpListener net.ListenConfig
udpAddr4 string
udpAddr6 string
isWireGuardListener bool
networkManager adapter.NetworkManager
networkStrategy *C.NetworkStrategy
defaultNetworkStrategy bool
@@ -47,7 +45,6 @@ type DefaultDialer struct {
}
func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDialer, error) {
tracker := service.FromContext[conntrack.Tracker](ctx)
networkManager := service.FromContext[adapter.NetworkManager](ctx)
platformInterface := service.FromContext[platform.Interface](ctx)
@@ -185,11 +182,6 @@ func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDial
}
setMultiPathTCP(&dialer4)
}
if options.IsWireGuardListener {
for _, controlFn := range WgControlFns {
listener.Control = control.Append(listener.Control, controlFn)
}
}
tcpDialer4, err := newTCPDialer(dialer4, options.TCPFastOpen)
if err != nil {
return nil, err
@@ -199,7 +191,6 @@ func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDial
return nil, err
}
return &DefaultDialer{
tracker: tracker,
dialer4: tcpDialer4,
dialer6: tcpDialer6,
udpDialer4: udpDialer4,
@@ -207,7 +198,6 @@ func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDial
udpListener: listener,
udpAddr4: udpAddr4,
udpAddr6: udpAddr6,
isWireGuardListener: options.IsWireGuardListener,
networkManager: networkManager,
networkStrategy: networkStrategy,
defaultNetworkStrategy: defaultNetworkStrategy,
@@ -222,26 +212,18 @@ func (d *DefaultDialer) DialContext(ctx context.Context, network string, address
return nil, E.New("invalid address")
}
if d.networkStrategy == nil {
if address.IsFqdn() {
return nil, E.New("unexpected domain destination")
}
// Since pending check is only used by ndis, it is not performed for non-windows connections which are only supported on platform clients
if d.tracker != nil {
done := d.tracker.AddPendingDestination(address.AddrPort())
defer done()
}
switch N.NetworkName(network) {
case N.NetworkUDP:
if !address.IsIPv6() {
return d.trackConn(d.udpDialer4.DialContext(ctx, network, address.String()))
return trackConn(d.udpDialer4.DialContext(ctx, network, address.String()))
} else {
return d.trackConn(d.udpDialer6.DialContext(ctx, network, address.String()))
return trackConn(d.udpDialer6.DialContext(ctx, network, address.String()))
}
}
if !address.IsIPv6() {
return d.trackConn(DialSlowContext(&d.dialer4, ctx, network, address))
return trackConn(DialSlowContext(&d.dialer4, ctx, network, address))
} else {
return d.trackConn(DialSlowContext(&d.dialer6, ctx, network, address))
return trackConn(DialSlowContext(&d.dialer6, ctx, network, address))
}
} else {
return d.DialParallelInterface(ctx, network, address, d.networkStrategy, d.networkType, d.fallbackNetworkType, d.networkFallbackDelay)
@@ -293,17 +275,17 @@ func (d *DefaultDialer) DialParallelInterface(ctx context.Context, network strin
if !fastFallback && !isPrimary {
d.networkLastFallback.Store(time.Now())
}
return d.trackConn(conn, nil)
return trackConn(conn, nil)
}
func (d *DefaultDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
if d.networkStrategy == nil {
if destination.IsIPv6() {
return d.trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr6))
return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr6))
} else if destination.IsIPv4() && !destination.Addr.IsUnspecified() {
return d.trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP+"4", d.udpAddr4))
return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP+"4", d.udpAddr4))
} else {
return d.trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr4))
return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr4))
}
} else {
return d.ListenSerialInterfacePacket(ctx, destination, d.networkStrategy, d.networkType, d.fallbackNetworkType, d.networkFallbackDelay)
@@ -340,23 +322,23 @@ func (d *DefaultDialer) ListenSerialInterfacePacket(ctx context.Context, destina
return nil, err
}
}
return d.trackPacketConn(packetConn, nil)
return trackPacketConn(packetConn, nil)
}
func (d *DefaultDialer) ListenPacketCompat(network, address string) (net.PacketConn, error) {
return d.udpListener.ListenPacket(context.Background(), network, address)
}
func (d *DefaultDialer) trackConn(conn net.Conn, err error) (net.Conn, error) {
if d.tracker == nil || err != nil {
func trackConn(conn net.Conn, err error) (net.Conn, error) {
if !conntrack.Enabled || err != nil {
return conn, err
}
return d.tracker.NewConn(conn)
return conntrack.NewConn(conn)
}
func (d *DefaultDialer) trackPacketConn(conn net.PacketConn, err error) (net.PacketConn, error) {
if err != nil {
func trackPacketConn(conn net.PacketConn, err error) (net.PacketConn, error) {
if !conntrack.Enabled || err != nil {
return conn, err
}
return d.tracker.NewPacketConn(conn)
return conntrack.NewPacketConn(conn)
}

View File

@@ -29,16 +29,18 @@ func (d *DetourDialer) Start() error {
}
func (d *DetourDialer) Dialer() (N.Dialer, error) {
d.initOnce.Do(func() {
var loaded bool
d.dialer, loaded = d.outboundManager.Outbound(d.detour)
if !loaded {
d.initErr = E.New("outbound detour not found: ", d.detour)
}
})
d.initOnce.Do(d.init)
return d.dialer, d.initErr
}
func (d *DetourDialer) init() {
var loaded bool
d.dialer, loaded = d.outboundManager.Outbound(d.detour)
if !loaded {
d.initErr = E.New("outbound detour not found: ", d.detour)
}
}
func (d *DetourDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
dialer, err := d.Dialer()
if err != nil {

View File

@@ -8,68 +8,113 @@ import (
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/experimental/deprecated"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-dns"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
)
func New(ctx context.Context, options option.DialerOptions) (N.Dialer, error) {
if options.IsWireGuardListener {
return NewDefault(ctx, options)
}
type Options struct {
Context context.Context
Options option.DialerOptions
RemoteIsDomain bool
DirectResolver bool
ResolverOnDetour bool
}
// TODO: merge with NewWithOptions
func New(ctx context.Context, options option.DialerOptions, remoteIsDomain bool) (N.Dialer, error) {
return NewWithOptions(Options{
Context: ctx,
Options: options,
RemoteIsDomain: remoteIsDomain,
})
}
func NewWithOptions(options Options) (N.Dialer, error) {
dialOptions := options.Options
var (
dialer N.Dialer
err error
)
if options.Detour == "" {
dialer, err = NewDefault(ctx, options)
if err != nil {
return nil, err
}
} else {
outboundManager := service.FromContext[adapter.OutboundManager](ctx)
if dialOptions.Detour != "" {
outboundManager := service.FromContext[adapter.OutboundManager](options.Context)
if outboundManager == nil {
return nil, E.New("missing outbound manager")
}
dialer = NewDetour(outboundManager, options.Detour)
}
if options.Detour == "" {
router := service.FromContext[adapter.Router](ctx)
if router != nil {
dialer = NewResolveDialer(
router,
dialer,
options.Detour == "" && !options.TCPFastOpen,
dns.DomainStrategy(options.DomainStrategy),
time.Duration(options.FallbackDelay))
dialer = NewDetour(outboundManager, dialOptions.Detour)
} else {
dialer, err = NewDefault(options.Context, dialOptions)
if err != nil {
return nil, err
}
}
if options.RemoteIsDomain && (dialOptions.Detour == "" || options.ResolverOnDetour) {
networkManager := service.FromContext[adapter.NetworkManager](options.Context)
dnsTransport := service.FromContext[adapter.DNSTransportManager](options.Context)
var defaultOptions adapter.NetworkOptions
if networkManager != nil {
defaultOptions = networkManager.DefaultOptions()
}
var (
server string
dnsQueryOptions adapter.DNSQueryOptions
resolveFallbackDelay time.Duration
)
if dialOptions.DomainResolver != nil && dialOptions.DomainResolver.Server != "" {
var transport adapter.DNSTransport
if !options.DirectResolver {
var loaded bool
transport, loaded = dnsTransport.Transport(dialOptions.DomainResolver.Server)
if !loaded {
return nil, E.New("domain resolver not found: " + dialOptions.DomainResolver.Server)
}
}
var strategy C.DomainStrategy
if dialOptions.DomainResolver.Strategy != option.DomainStrategy(C.DomainStrategyAsIS) {
strategy = C.DomainStrategy(dialOptions.DomainResolver.Strategy)
} else if
//nolint:staticcheck
dialOptions.DomainStrategy != option.DomainStrategy(C.DomainStrategyAsIS) {
//nolint:staticcheck
strategy = C.DomainStrategy(dialOptions.DomainStrategy)
}
server = dialOptions.DomainResolver.Server
dnsQueryOptions = adapter.DNSQueryOptions{
Transport: transport,
Strategy: strategy,
DisableCache: dialOptions.DomainResolver.DisableCache,
RewriteTTL: dialOptions.DomainResolver.RewriteTTL,
ClientSubnet: dialOptions.DomainResolver.ClientSubnet.Build(netip.Prefix{}),
}
resolveFallbackDelay = time.Duration(dialOptions.FallbackDelay)
} else if options.DirectResolver {
return nil, E.New("missing domain resolver for domain server address")
} else if defaultOptions.DomainResolver != "" {
dnsQueryOptions = defaultOptions.DomainResolveOptions
transport, loaded := dnsTransport.Transport(defaultOptions.DomainResolver)
if !loaded {
return nil, E.New("default domain resolver not found: " + defaultOptions.DomainResolver)
}
dnsQueryOptions.Transport = transport
resolveFallbackDelay = time.Duration(dialOptions.FallbackDelay)
} else {
deprecated.Report(options.Context, deprecated.OptionMissingDomainResolver)
}
dialer = NewResolveDialer(
options.Context,
dialer,
dialOptions.Detour == "" && !dialOptions.TCPFastOpen,
server,
dnsQueryOptions,
resolveFallbackDelay,
)
}
return dialer, nil
}
func NewDirect(ctx context.Context, options option.DialerOptions) (ParallelInterfaceDialer, error) {
if options.Detour != "" {
return nil, E.New("`detour` is not supported in direct context")
}
if options.IsWireGuardListener {
return NewDefault(ctx, options)
}
dialer, err := NewDefault(ctx, options)
if err != nil {
return nil, err
}
return NewResolveParallelInterfaceDialer(
service.FromContext[adapter.Router](ctx),
dialer,
true,
dns.DomainStrategy(options.DomainStrategy),
time.Duration(options.FallbackDelay),
), nil
}
type ParallelInterfaceDialer interface {
N.Dialer
DialParallelInterface(ctx context.Context, network string, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error)

View File

@@ -3,16 +3,17 @@ package dialer
import (
"context"
"net"
"net/netip"
"sync"
"time"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
)
var (
@@ -20,21 +21,37 @@ var (
_ ParallelInterfaceDialer = (*resolveParallelNetworkDialer)(nil)
)
type ResolveDialer interface {
N.Dialer
QueryOptions() adapter.DNSQueryOptions
}
type ParallelInterfaceResolveDialer interface {
ParallelInterfaceDialer
QueryOptions() adapter.DNSQueryOptions
}
type resolveDialer struct {
transport adapter.DNSTransportManager
router adapter.DNSRouter
dialer N.Dialer
parallel bool
router adapter.Router
strategy dns.DomainStrategy
server string
initOnce sync.Once
initErr error
queryOptions adapter.DNSQueryOptions
fallbackDelay time.Duration
}
func NewResolveDialer(router adapter.Router, dialer N.Dialer, parallel bool, strategy dns.DomainStrategy, fallbackDelay time.Duration) N.Dialer {
func NewResolveDialer(ctx context.Context, dialer N.Dialer, parallel bool, server string, queryOptions adapter.DNSQueryOptions, fallbackDelay time.Duration) ResolveDialer {
return &resolveDialer{
dialer,
parallel,
router,
strategy,
fallbackDelay,
transport: service.FromContext[adapter.DNSTransportManager](ctx),
router: service.FromContext[adapter.DNSRouter](ctx),
dialer: dialer,
parallel: parallel,
server: server,
queryOptions: queryOptions,
fallbackDelay: fallbackDelay,
}
}
@@ -43,59 +60,68 @@ type resolveParallelNetworkDialer struct {
dialer ParallelInterfaceDialer
}
func NewResolveParallelInterfaceDialer(router adapter.Router, dialer ParallelInterfaceDialer, parallel bool, strategy dns.DomainStrategy, fallbackDelay time.Duration) ParallelInterfaceDialer {
func NewResolveParallelInterfaceDialer(ctx context.Context, dialer ParallelInterfaceDialer, parallel bool, server string, queryOptions adapter.DNSQueryOptions, fallbackDelay time.Duration) ParallelInterfaceResolveDialer {
return &resolveParallelNetworkDialer{
resolveDialer{
dialer,
parallel,
router,
strategy,
fallbackDelay,
transport: service.FromContext[adapter.DNSTransportManager](ctx),
router: service.FromContext[adapter.DNSRouter](ctx),
dialer: dialer,
parallel: parallel,
server: server,
queryOptions: queryOptions,
fallbackDelay: fallbackDelay,
},
dialer,
}
}
func (d *resolveDialer) initialize() error {
d.initOnce.Do(d.initServer)
return d.initErr
}
func (d *resolveDialer) initServer() {
if d.server == "" {
return
}
transport, loaded := d.transport.Transport(d.server)
if !loaded {
d.initErr = E.New("domain resolver not found: " + d.server)
return
}
d.queryOptions.Transport = transport
}
func (d *resolveDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
err := d.initialize()
if err != nil {
return nil, err
}
if !destination.IsFqdn() {
return d.dialer.DialContext(ctx, network, destination)
}
ctx, metadata := adapter.ExtendContext(ctx)
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
metadata.Destination = destination
metadata.Domain = ""
var addresses []netip.Addr
var err error
if d.strategy == dns.DomainStrategyAsIS {
addresses, err = d.router.LookupDefault(ctx, destination.Fqdn)
} else {
addresses, err = d.router.Lookup(ctx, destination.Fqdn, d.strategy)
}
addresses, err := d.router.Lookup(ctx, destination.Fqdn, d.queryOptions)
if err != nil {
return nil, err
}
if d.parallel {
return N.DialParallel(ctx, d.dialer, network, destination, addresses, d.strategy == dns.DomainStrategyPreferIPv6, d.fallbackDelay)
return N.DialParallel(ctx, d.dialer, network, destination, addresses, d.queryOptions.Strategy == C.DomainStrategyPreferIPv6, d.fallbackDelay)
} else {
return N.DialSerial(ctx, d.dialer, network, destination, addresses)
}
}
func (d *resolveDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
err := d.initialize()
if err != nil {
return nil, err
}
if !destination.IsFqdn() {
return d.dialer.ListenPacket(ctx, destination)
}
ctx, metadata := adapter.ExtendContext(ctx)
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
metadata.Destination = destination
metadata.Domain = ""
var addresses []netip.Addr
var err error
if d.strategy == dns.DomainStrategyAsIS {
addresses, err = d.router.LookupDefault(ctx, destination.Fqdn)
} else {
addresses, err = d.router.Lookup(ctx, destination.Fqdn, d.strategy)
}
addresses, err := d.router.Lookup(ctx, destination.Fqdn, d.queryOptions)
if err != nil {
return nil, err
}
@@ -106,21 +132,24 @@ func (d *resolveDialer) ListenPacket(ctx context.Context, destination M.Socksadd
return bufio.NewNATPacketConn(bufio.NewPacketConn(conn), M.SocksaddrFrom(destinationAddress, destination.Port), destination), nil
}
func (d *resolveDialer) QueryOptions() adapter.DNSQueryOptions {
return d.queryOptions
}
func (d *resolveDialer) Upstream() any {
return d.dialer
}
func (d *resolveParallelNetworkDialer) DialParallelInterface(ctx context.Context, network string, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) {
err := d.initialize()
if err != nil {
return nil, err
}
if !destination.IsFqdn() {
return d.dialer.DialContext(ctx, network, destination)
}
ctx, metadata := adapter.ExtendContext(ctx)
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
metadata.Destination = destination
metadata.Domain = ""
var addresses []netip.Addr
var err error
if d.strategy == dns.DomainStrategyAsIS {
addresses, err = d.router.LookupDefault(ctx, destination.Fqdn)
} else {
addresses, err = d.router.Lookup(ctx, destination.Fqdn, d.strategy)
}
addresses, err := d.router.Lookup(ctx, destination.Fqdn, d.queryOptions)
if err != nil {
return nil, err
}
@@ -128,30 +157,28 @@ func (d *resolveParallelNetworkDialer) DialParallelInterface(ctx context.Context
fallbackDelay = d.fallbackDelay
}
if d.parallel {
return DialParallelNetwork(ctx, d.dialer, network, destination, addresses, d.strategy == dns.DomainStrategyPreferIPv6, strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
return DialParallelNetwork(ctx, d.dialer, network, destination, addresses, d.queryOptions.Strategy == C.DomainStrategyPreferIPv6, strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
} else {
return DialSerialNetwork(ctx, d.dialer, network, destination, addresses, strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
}
}
func (d *resolveParallelNetworkDialer) ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, error) {
err := d.initialize()
if err != nil {
return nil, err
}
if !destination.IsFqdn() {
return d.dialer.ListenPacket(ctx, destination)
}
ctx, metadata := adapter.ExtendContext(ctx)
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
metadata.Destination = destination
metadata.Domain = ""
var addresses []netip.Addr
var err error
if d.strategy == dns.DomainStrategyAsIS {
addresses, err = d.router.LookupDefault(ctx, destination.Fqdn)
} else {
addresses, err = d.router.Lookup(ctx, destination.Fqdn, d.strategy)
}
addresses, err := d.router.Lookup(ctx, destination.Fqdn, d.queryOptions)
if err != nil {
return nil, err
}
if fallbackDelay == 0 {
fallbackDelay = d.fallbackDelay
}
conn, destinationAddress, err := ListenSerialNetworkPacket(ctx, d.dialer, destination, addresses, strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
if err != nil {
return nil, err
@@ -159,6 +186,10 @@ func (d *resolveParallelNetworkDialer) ListenSerialInterfacePacket(ctx context.C
return bufio.NewNATPacketConn(bufio.NewPacketConn(conn), M.SocksaddrFrom(destinationAddress, destination.Port), destination), nil
}
func (d *resolveDialer) Upstream() any {
func (d *resolveParallelNetworkDialer) QueryOptions() adapter.DNSQueryOptions {
return d.queryOptions
}
func (d *resolveParallelNetworkDialer) Upstream() any {
return d.dialer
}

View File

@@ -7,24 +7,27 @@ import (
"github.com/sagernet/sing-box/adapter"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
)
type DefaultOutboundDialer struct {
outboundManager adapter.OutboundManager
outbound adapter.OutboundManager
}
func NewDefaultOutbound(outboundManager adapter.OutboundManager) N.Dialer {
return &DefaultOutboundDialer{outboundManager: outboundManager}
func NewDefaultOutbound(ctx context.Context) N.Dialer {
return &DefaultOutboundDialer{
outbound: service.FromContext[adapter.OutboundManager](ctx),
}
}
func (d *DefaultOutboundDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
return d.outboundManager.Default().DialContext(ctx, network, destination)
return d.outbound.Default().DialContext(ctx, network, destination)
}
func (d *DefaultOutboundDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
return d.outboundManager.Default().ListenPacket(ctx, destination)
return d.outbound.Default().ListenPacket(ctx, destination)
}
func (d *DefaultOutboundDialer) Upstream() any {
return d.outboundManager.Default()
return d.outbound.Default()
}

View File

@@ -23,6 +23,7 @@ type Config struct {
}
type Info struct {
ProcessID uint32
ProcessPath string
PackageName string
User string

View File

@@ -2,14 +2,11 @@ package process
import (
"context"
"fmt"
"net/netip"
"os"
"syscall"
"unsafe"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/winiphlpapi"
"golang.org/x/sys/windows"
)
@@ -26,209 +23,39 @@ func NewSearcher(_ Config) (Searcher, error) {
return &windowsSearcher{}, nil
}
var (
modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll")
procGetExtendedTcpTable = modiphlpapi.NewProc("GetExtendedTcpTable")
procGetExtendedUdpTable = modiphlpapi.NewProc("GetExtendedUdpTable")
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
procQueryFullProcessImageNameW = modkernel32.NewProc("QueryFullProcessImageNameW")
)
func initWin32API() error {
err := modiphlpapi.Load()
if err != nil {
return E.Cause(err, "load iphlpapi.dll")
}
err = procGetExtendedTcpTable.Find()
if err != nil {
return E.Cause(err, "load iphlpapi::GetExtendedTcpTable")
}
err = procGetExtendedUdpTable.Find()
if err != nil {
return E.Cause(err, "load iphlpapi::GetExtendedUdpTable")
}
err = modkernel32.Load()
if err != nil {
return E.Cause(err, "load kernel32.dll")
}
err = procQueryFullProcessImageNameW.Find()
if err != nil {
return E.Cause(err, "load kernel32::QueryFullProcessImageNameW")
}
return nil
return winiphlpapi.LoadExtendedTable()
}
func (s *windowsSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*Info, error) {
processName, err := findProcessName(network, source.Addr(), int(source.Port()))
pid, err := winiphlpapi.FindPid(network, source)
if err != nil {
return nil, err
}
return &Info{ProcessPath: processName, UserId: -1}, nil
}
func findProcessName(network string, ip netip.Addr, srcPort int) (string, error) {
family := windows.AF_INET
if ip.Is6() {
family = windows.AF_INET6
}
const (
tcpTablePidConn = 4
udpTablePid = 1
)
var class int
var fn uintptr
switch network {
case N.NetworkTCP:
fn = procGetExtendedTcpTable.Addr()
class = tcpTablePidConn
case N.NetworkUDP:
fn = procGetExtendedUdpTable.Addr()
class = udpTablePid
default:
return "", os.ErrInvalid
}
buf, err := getTransportTable(fn, family, class)
path, err := getProcessPath(pid)
if err != nil {
return "", err
return &Info{ProcessID: pid, UserId: -1}, err
}
s := newSearcher(family == windows.AF_INET, network == N.NetworkTCP)
pid, err := s.Search(buf, ip, uint16(srcPort))
if err != nil {
return "", err
}
return getExecPathFromPID(pid)
return &Info{ProcessID: pid, ProcessPath: path, UserId: -1}, nil
}
type searcher struct {
itemSize int
port int
ip int
ipSize int
pid int
tcpState int
}
func (s *searcher) Search(b []byte, ip netip.Addr, port uint16) (uint32, error) {
n := int(readNativeUint32(b[:4]))
itemSize := s.itemSize
for i := 0; i < n; i++ {
row := b[4+itemSize*i : 4+itemSize*(i+1)]
if s.tcpState >= 0 {
tcpState := readNativeUint32(row[s.tcpState : s.tcpState+4])
// MIB_TCP_STATE_ESTAB, only check established connections for TCP
if tcpState != 5 {
continue
}
}
// according to MSDN, only the lower 16 bits of dwLocalPort are used and the port number is in network endian.
// this field can be illustrated as follows depends on different machine endianess:
// little endian: [ MSB LSB 0 0 ] interpret as native uint32 is ((LSB<<8)|MSB)
// big endian: [ 0 0 MSB LSB ] interpret as native uint32 is ((MSB<<8)|LSB)
// so we need an syscall.Ntohs on the lower 16 bits after read the port as native uint32
srcPort := syscall.Ntohs(uint16(readNativeUint32(row[s.port : s.port+4])))
if srcPort != port {
continue
}
srcIP, _ := netip.AddrFromSlice(row[s.ip : s.ip+s.ipSize])
// windows binds an unbound udp socket to 0.0.0.0/[::] while first sendto
if ip != srcIP && (!srcIP.IsUnspecified() || s.tcpState != -1) {
continue
}
pid := readNativeUint32(row[s.pid : s.pid+4])
return pid, nil
}
return 0, ErrNotFound
}
func newSearcher(isV4, isTCP bool) *searcher {
var itemSize, port, ip, ipSize, pid int
tcpState := -1
switch {
case isV4 && isTCP:
// struct MIB_TCPROW_OWNER_PID
itemSize, port, ip, ipSize, pid, tcpState = 24, 8, 4, 4, 20, 0
case isV4 && !isTCP:
// struct MIB_UDPROW_OWNER_PID
itemSize, port, ip, ipSize, pid = 12, 4, 0, 4, 8
case !isV4 && isTCP:
// struct MIB_TCP6ROW_OWNER_PID
itemSize, port, ip, ipSize, pid, tcpState = 56, 20, 0, 16, 52, 48
case !isV4 && !isTCP:
// struct MIB_UDP6ROW_OWNER_PID
itemSize, port, ip, ipSize, pid = 28, 20, 0, 16, 24
}
return &searcher{
itemSize: itemSize,
port: port,
ip: ip,
ipSize: ipSize,
pid: pid,
tcpState: tcpState,
}
}
func getTransportTable(fn uintptr, family int, class int) ([]byte, error) {
for size, buf := uint32(8), make([]byte, 8); ; {
ptr := unsafe.Pointer(&buf[0])
err, _, _ := syscall.SyscallN(fn, uintptr(ptr), uintptr(unsafe.Pointer(&size)), 0, uintptr(family), uintptr(class), 0)
switch err {
case 0:
return buf, nil
case uintptr(syscall.ERROR_INSUFFICIENT_BUFFER):
buf = make([]byte, size)
default:
return nil, fmt.Errorf("syscall error: %d", err)
}
}
}
func readNativeUint32(b []byte) uint32 {
return *(*uint32)(unsafe.Pointer(&b[0]))
}
func getExecPathFromPID(pid uint32) (string, error) {
// kernel process starts with a colon in order to distinguish with normal processes
func getProcessPath(pid uint32) (string, error) {
switch pid {
case 0:
// reserved pid for system idle process
return ":System Idle Process", nil
case 4:
// reserved pid for windows kernel image
return ":System", nil
}
h, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION, false, pid)
handle, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION, false, pid)
if err != nil {
return "", err
}
defer windows.CloseHandle(h)
defer windows.CloseHandle(handle)
size := uint32(syscall.MAX_LONG_PATH)
buf := make([]uint16, syscall.MAX_LONG_PATH)
size := uint32(len(buf))
r1, _, err := syscall.SyscallN(
procQueryFullProcessImageNameW.Addr(),
uintptr(h),
uintptr(0),
uintptr(unsafe.Pointer(&buf[0])),
uintptr(unsafe.Pointer(&size)),
)
if r1 == 0 {
err = windows.QueryFullProcessImageName(handle, 0, &buf[0], &size)
if err != nil {
return "", err
}
return syscall.UTF16ToString(buf[:size]), nil
return windows.UTF16ToString(buf[:size]), nil
}

View File

@@ -1,12 +0,0 @@
//go:build !(windows || linux || darwin)
package settings
import (
"os"
"time"
)
func SetSystemTime(nowTime time.Time) error {
return os.ErrInvalid
}

View File

@@ -1,14 +0,0 @@
//go:build linux || darwin
package settings
import (
"time"
"golang.org/x/sys/unix"
)
func SetSystemTime(nowTime time.Time) error {
timeVal := unix.NsecToTimeval(nowTime.UnixNano())
return unix.Settimeofday(&timeVal)
}

View File

@@ -1,32 +0,0 @@
package settings
import (
"time"
"unsafe"
"golang.org/x/sys/windows"
)
func SetSystemTime(nowTime time.Time) error {
var systemTime windows.Systemtime
systemTime.Year = uint16(nowTime.Year())
systemTime.Month = uint16(nowTime.Month())
systemTime.Day = uint16(nowTime.Day())
systemTime.Hour = uint16(nowTime.Hour())
systemTime.Minute = uint16(nowTime.Minute())
systemTime.Second = uint16(nowTime.Second())
systemTime.Milliseconds = uint16(nowTime.UnixMilli() - nowTime.Unix()*1000)
dllKernel32 := windows.NewLazySystemDLL("kernel32.dll")
proc := dllKernel32.NewProc("SetSystemTime")
_, _, err := proc.Call(
uintptr(unsafe.Pointer(&systemTime)),
)
if err != nil && err.Error() != "The operation completed successfully." {
return err
}
return nil
}

View File

@@ -15,8 +15,8 @@ import (
cftls "github.com/sagernet/cloudflare-tls"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-dns"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/ntp"
"github.com/sagernet/sing/service"
@@ -100,6 +100,7 @@ func NewECHClient(ctx context.Context, serverAddress string, options option.Outb
var tlsConfig cftls.Config
tlsConfig.Time = ntp.TimeFuncFromContext(ctx)
tlsConfig.RootCAs = adapter.RootPoolFromContext(ctx)
if options.DisableSNI {
tlsConfig.ServerName = "127.0.0.1"
} else {
@@ -215,7 +216,7 @@ func fetchECHClientConfig(ctx context.Context) func(_ context.Context, serverNam
},
},
}
response, err := service.FromContext[adapter.Router](ctx).Exchange(ctx, message)
response, err := service.FromContext[adapter.DNSRouter](ctx).Exchange(ctx, message, adapter.DNSQueryOptions{})
if err != nil {
return nil, err
}

View File

@@ -90,7 +90,7 @@ func (c *echServerConfig) startWatcher() error {
Callback: func(path string) {
err := c.credentialsUpdated(path)
if err != nil {
c.logger.Error(E.Cause(err, "reload credentials from ", path))
c.logger.Error(E.Cause(err, "reload credentials"))
}
},
})

View File

@@ -11,8 +11,8 @@ import (
"time"
)
func GenerateCertificate(timeFunc func() time.Time, serverName string) (*tls.Certificate, error) {
privateKeyPem, publicKeyPem, err := GenerateKeyPair(timeFunc, serverName, timeFunc().Add(time.Hour))
func GenerateKeyPair(parent *x509.Certificate, parentKey any, timeFunc func() time.Time, serverName string) (*tls.Certificate, error) {
privateKeyPem, publicKeyPem, err := GenerateCertificate(parent, parentKey, timeFunc, serverName, timeFunc().Add(time.Hour))
if err != nil {
return nil, err
}
@@ -23,7 +23,7 @@ func GenerateCertificate(timeFunc func() time.Time, serverName string) (*tls.Cer
return &certificate, err
}
func GenerateKeyPair(timeFunc func() time.Time, serverName string, expire time.Time) (privateKeyPem []byte, publicKeyPem []byte, err error) {
func GenerateCertificate(parent *x509.Certificate, parentKey any, timeFunc func() time.Time, serverName string, expire time.Time) (privateKeyPem []byte, publicKeyPem []byte, err error) {
if timeFunc == nil {
timeFunc = time.Now
}
@@ -47,7 +47,11 @@ func GenerateKeyPair(timeFunc func() time.Time, serverName string, expire time.T
},
DNSNames: []string{serverName},
}
publicDer, err := x509.CreateCertificate(rand.Reader, template, template, key.Public(), key)
if parent == nil {
parent = template
parentKey = key
}
publicDer, err := x509.CreateCertificate(rand.Reader, template, parent, key.Public(), parentKey)
if err != nil {
return
}

View File

@@ -27,9 +27,11 @@ import (
"time"
"unsafe"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common/debug"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/ntp"
aTLS "github.com/sagernet/sing/common/tls"
utls "github.com/sagernet/utls"
@@ -40,6 +42,7 @@ import (
var _ ConfigCompat = (*RealityClientConfig)(nil)
type RealityClientConfig struct {
ctx context.Context
uClient *UTLSClientConfig
publicKey []byte
shortID [8]byte
@@ -70,7 +73,7 @@ func NewRealityClient(ctx context.Context, serverAddress string, options option.
if decodedLen > 8 {
return nil, E.New("invalid short_id")
}
return &RealityClientConfig{uClient, publicKey, shortID}, nil
return &RealityClientConfig{ctx, uClient, publicKey, shortID}, nil
}
func (e *RealityClientConfig) ServerName() string {
@@ -180,20 +183,24 @@ func (e *RealityClientConfig) ClientHandshake(ctx context.Context, conn net.Conn
}
if !verifier.verified {
go realityClientFallback(uConn, e.uClient.ServerName(), e.uClient.id)
go realityClientFallback(e.ctx, uConn, e.uClient.ServerName(), e.uClient.id)
return nil, E.New("reality verification failed")
}
return &utlsConnWrapper{uConn}, nil
return &realityClientConnWrapper{uConn}, nil
}
func realityClientFallback(uConn net.Conn, serverName string, fingerprint utls.ClientHelloID) {
func realityClientFallback(ctx context.Context, uConn net.Conn, serverName string, fingerprint utls.ClientHelloID) {
defer uConn.Close()
client := &http.Client{
Transport: &http2.Transport{
DialTLSContext: func(ctx context.Context, network, addr string, config *tls.Config) (net.Conn, error) {
return uConn, nil
},
TLSClientConfig: &tls.Config{
Time: ntp.TimeFuncFromContext(ctx),
RootCAs: adapter.RootPoolFromContext(ctx),
},
},
}
request, _ := http.NewRequest("GET", "https://"+serverName, nil)
@@ -213,6 +220,7 @@ func (e *RealityClientConfig) SetSessionIDGenerator(generator func(clientHello [
func (e *RealityClientConfig) Clone() Config {
return &RealityClientConfig{
e.ctx,
e.uClient.Clone().(*UTLSClientConfig),
e.publicKey,
e.shortID,
@@ -249,3 +257,36 @@ func (c *realityVerifier) VerifyPeerCertificate(rawCerts [][]byte, verifiedChain
}
return nil
}
type realityClientConnWrapper struct {
*utls.UConn
}
func (c *realityClientConnWrapper) ConnectionState() tls.ConnectionState {
state := c.Conn.ConnectionState()
//nolint:staticcheck
return tls.ConnectionState{
Version: state.Version,
HandshakeComplete: state.HandshakeComplete,
DidResume: state.DidResume,
CipherSuite: state.CipherSuite,
NegotiatedProtocol: state.NegotiatedProtocol,
NegotiatedProtocolIsMutual: state.NegotiatedProtocolIsMutual,
ServerName: state.ServerName,
PeerCertificates: state.PeerCertificates,
VerifiedChains: state.VerifiedChains,
SignedCertificateTimestamps: state.SignedCertificateTimestamps,
OCSPResponse: state.OCSPResponse,
TLSUnique: state.TLSUnique,
}
}
func (c *realityClientConnWrapper) Upstream() any {
return c.UConn
}
// Due to low implementation quality, the reality server intercepted half close and caused memory leaks.
// We fixed it by calling Close() directly.
func (c *realityClientConnWrapper) CloseWrite() error {
return c.Close()
}

View File

@@ -101,7 +101,7 @@ func NewRealityServer(ctx context.Context, logger log.Logger, options option.Inb
tlsConfig.ShortIds[shortID] = true
}
handshakeDialer, err := dialer.New(ctx, options.Reality.Handshake.DialerOptions)
handshakeDialer, err := dialer.New(ctx, options.Reality.Handshake.DialerOptions, options.Reality.Handshake.ServerIsDomain())
if err != nil {
return nil, err
}
@@ -194,3 +194,9 @@ func (c *realityConnWrapper) ConnectionState() ConnectionState {
func (c *realityConnWrapper) Upstream() any {
return c.Conn
}
// Due to low implementation quality, the reality server intercepted half close and caused memory leaks.
// We fixed it by calling Close() directly.
func (c *realityConnWrapper) CloseWrite() error {
return c.Close()
}

View File

@@ -5,10 +5,10 @@ import (
"crypto/tls"
"crypto/x509"
"net"
"net/netip"
"os"
"strings"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/ntp"
@@ -51,9 +51,7 @@ func NewSTDClient(ctx context.Context, serverAddress string, options option.Outb
if options.ServerName != "" {
serverName = options.ServerName
} else if serverAddress != "" {
if _, err := netip.ParseAddr(serverName); err != nil {
serverName = serverAddress
}
serverName = serverAddress
}
if serverName == "" && !options.Insecure {
return nil, E.New("missing server_name or insecure=true")
@@ -61,6 +59,7 @@ func NewSTDClient(ctx context.Context, serverAddress string, options option.Outb
var tlsConfig tls.Config
tlsConfig.Time = ntp.TimeFuncFromContext(ctx)
tlsConfig.RootCAs = adapter.RootPoolFromContext(ctx)
if options.DisableSNI {
tlsConfig.ServerName = "127.0.0.1"
} else {

View File

@@ -99,7 +99,7 @@ func (c *STDServerConfig) startWatcher() error {
Callback: func(path string) {
err := c.certificateUpdated(path)
if err != nil {
c.logger.Error(err)
c.logger.Error(E.Cause(err, "reload certificate"))
}
},
})
@@ -222,7 +222,7 @@ func NewSTDServer(ctx context.Context, logger log.Logger, options option.Inbound
}
if certificate == nil && key == nil && options.Insecure {
tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return GenerateCertificate(ntp.TimeFuncFromContext(ctx), info.ServerName)
return GenerateKeyPair(nil, nil, ntp.TimeFuncFromContext(ctx), info.ServerName)
}
} else {
if certificate == nil {

View File

@@ -12,6 +12,7 @@ import (
"os"
"strings"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/ntp"
@@ -130,6 +131,7 @@ func NewUTLSClient(ctx context.Context, serverAddress string, options option.Out
var tlsConfig utls.Config
tlsConfig.Time = ntp.TimeFuncFromContext(ctx)
tlsConfig.RootCAs = adapter.RootPoolFromContext(ctx)
if options.DisableSNI {
tlsConfig.ServerName = "127.0.0.1"
} else {

107
common/tlsfragment/conn.go Normal file
View File

@@ -0,0 +1,107 @@
package tf
import (
"context"
"math/rand"
"net"
"strings"
"time"
N "github.com/sagernet/sing/common/network"
"golang.org/x/net/publicsuffix"
)
type Conn struct {
net.Conn
tcpConn *net.TCPConn
ctx context.Context
firstPacketWritten bool
fallbackDelay time.Duration
}
func NewConn(conn net.Conn, ctx context.Context, fallbackDelay time.Duration) (*Conn, error) {
tcpConn, _ := N.UnwrapReader(conn).(*net.TCPConn)
return &Conn{
Conn: conn,
tcpConn: tcpConn,
ctx: ctx,
fallbackDelay: fallbackDelay,
}, nil
}
func (c *Conn) Write(b []byte) (n int, err error) {
if !c.firstPacketWritten {
defer func() {
c.firstPacketWritten = true
}()
serverName := indexTLSServerName(b)
if serverName != nil {
if c.tcpConn != nil {
err = c.tcpConn.SetNoDelay(true)
if err != nil {
return
}
}
splits := strings.Split(serverName.ServerName, ".")
currentIndex := serverName.Index
if publicSuffix := publicsuffix.List.PublicSuffix(serverName.ServerName); publicSuffix != "" {
splits = splits[:len(splits)-strings.Count(serverName.ServerName, ".")]
}
if len(splits) > 1 && splits[0] == "..." {
currentIndex += len(splits[0]) + 1
splits = splits[1:]
}
var splitIndexes []int
for i, split := range splits {
splitAt := rand.Intn(len(split))
splitIndexes = append(splitIndexes, currentIndex+splitAt)
currentIndex += len(split)
if i != len(splits)-1 {
currentIndex++
}
}
for i := 0; i <= len(splitIndexes); i++ {
var payload []byte
if i == 0 {
payload = b[:splitIndexes[i]]
} else if i == len(splitIndexes) {
payload = b[splitIndexes[i-1]:]
} else {
payload = b[splitIndexes[i-1]:splitIndexes[i]]
}
if c.tcpConn != nil && i != len(splitIndexes) {
err = writeAndWaitAck(c.ctx, c.tcpConn, payload, c.fallbackDelay)
if err != nil {
return
}
} else {
_, err = c.Conn.Write(payload)
if err != nil {
return
}
}
}
if c.tcpConn != nil {
err = c.tcpConn.SetNoDelay(false)
if err != nil {
return
}
}
return len(b), nil
}
}
return c.Conn.Write(b)
}
func (c *Conn) ReaderReplaceable() bool {
return true
}
func (c *Conn) WriterReplaceable() bool {
return c.firstPacketWritten
}
func (c *Conn) Upstream() any {
return c.Conn
}

131
common/tlsfragment/index.go Normal file
View File

@@ -0,0 +1,131 @@
package tf
import (
"encoding/binary"
)
const (
recordLayerHeaderLen int = 5
handshakeHeaderLen int = 6
randomDataLen int = 32
sessionIDHeaderLen int = 1
cipherSuiteHeaderLen int = 2
compressMethodHeaderLen int = 1
extensionsHeaderLen int = 2
extensionHeaderLen int = 4
sniExtensionHeaderLen int = 5
contentType uint8 = 22
handshakeType uint8 = 1
sniExtensionType uint16 = 0
sniNameDNSHostnameType uint8 = 0
tlsVersionBitmask uint16 = 0xFFFC
tls13 uint16 = 0x0304
)
type myServerName struct {
Index int
Length int
ServerName string
}
func indexTLSServerName(payload []byte) *myServerName {
if len(payload) < recordLayerHeaderLen || payload[0] != contentType {
return nil
}
segmentLen := binary.BigEndian.Uint16(payload[3:5])
if len(payload) < recordLayerHeaderLen+int(segmentLen) {
return nil
}
serverName := indexTLSServerNameFromHandshake(payload[recordLayerHeaderLen : recordLayerHeaderLen+int(segmentLen)])
if serverName == nil {
return nil
}
serverName.Length += recordLayerHeaderLen
return serverName
}
func indexTLSServerNameFromHandshake(hs []byte) *myServerName {
if len(hs) < handshakeHeaderLen+randomDataLen+sessionIDHeaderLen {
return nil
}
if hs[0] != handshakeType {
return nil
}
handshakeLen := uint32(hs[1])<<16 | uint32(hs[2])<<8 | uint32(hs[3])
if len(hs[4:]) != int(handshakeLen) {
return nil
}
tlsVersion := uint16(hs[4])<<8 | uint16(hs[5])
if tlsVersion&tlsVersionBitmask != 0x0300 && tlsVersion != tls13 {
return nil
}
sessionIDLen := hs[38]
if len(hs) < handshakeHeaderLen+randomDataLen+sessionIDHeaderLen+int(sessionIDLen) {
return nil
}
cs := hs[handshakeHeaderLen+randomDataLen+sessionIDHeaderLen+int(sessionIDLen):]
if len(cs) < cipherSuiteHeaderLen {
return nil
}
csLen := uint16(cs[0])<<8 | uint16(cs[1])
if len(cs) < cipherSuiteHeaderLen+int(csLen)+compressMethodHeaderLen {
return nil
}
compressMethodLen := uint16(cs[cipherSuiteHeaderLen+int(csLen)])
if len(cs) < cipherSuiteHeaderLen+int(csLen)+compressMethodHeaderLen+int(compressMethodLen) {
return nil
}
currentIndex := cipherSuiteHeaderLen + int(csLen) + compressMethodHeaderLen + int(compressMethodLen)
serverName := indexTLSServerNameFromExtensions(cs[currentIndex:])
if serverName == nil {
return nil
}
serverName.Index += currentIndex
return serverName
}
func indexTLSServerNameFromExtensions(exs []byte) *myServerName {
if len(exs) == 0 {
return nil
}
if len(exs) < extensionsHeaderLen {
return nil
}
exsLen := uint16(exs[0])<<8 | uint16(exs[1])
exs = exs[extensionsHeaderLen:]
if len(exs) < int(exsLen) {
return nil
}
for currentIndex := extensionsHeaderLen; len(exs) > 0; {
if len(exs) < extensionHeaderLen {
return nil
}
exType := uint16(exs[0])<<8 | uint16(exs[1])
exLen := uint16(exs[2])<<8 | uint16(exs[3])
if len(exs) < extensionHeaderLen+int(exLen) {
return nil
}
sex := exs[extensionHeaderLen : extensionHeaderLen+int(exLen)]
switch exType {
case sniExtensionType:
if len(sex) < sniExtensionHeaderLen {
return nil
}
sniType := sex[2]
if sniType != sniNameDNSHostnameType {
return nil
}
sniLen := uint16(sex[3])<<8 | uint16(sex[4])
sex = sex[sniExtensionHeaderLen:]
return &myServerName{
Index: currentIndex + extensionHeaderLen + sniExtensionHeaderLen,
Length: int(sniLen),
ServerName: string(sex),
}
}
exs = exs[4+exLen:]
currentIndex += 4 + int(exLen)
}
return nil
}

View File

@@ -0,0 +1,93 @@
package tf
import (
"context"
"net"
"time"
"github.com/sagernet/sing/common/control"
"golang.org/x/sys/unix"
)
/*
const tcpMaxNotifyAck = 10
type tcpNotifyAckID uint32
type tcpNotifyAckComplete struct {
NotifyPending uint32
NotifyCompleteCount uint32
NotifyCompleteID [tcpMaxNotifyAck]tcpNotifyAckID
}
var sizeOfTCPNotifyAckComplete = int(unsafe.Sizeof(tcpNotifyAckComplete{}))
func getsockoptTCPNotifyAckComplete(fd, level, opt int) (*tcpNotifyAckComplete, error) {
var value tcpNotifyAckComplete
vallen := uint32(sizeOfTCPNotifyAckComplete)
err := getsockopt(fd, level, opt, unsafe.Pointer(&value), &vallen)
return &value, err
}
//go:linkname getsockopt golang.org/x/sys/unix.getsockopt
func getsockopt(s int, level int, name int, val unsafe.Pointer, vallen *uint32) error
func waitAck(ctx context.Context, conn *net.TCPConn, _ time.Duration) error {
const TCP_NOTIFY_ACKNOWLEDGEMENT = 0x212
return control.Conn(conn, func(fd uintptr) error {
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_TCP, TCP_NOTIFY_ACKNOWLEDGEMENT, 1)
if err != nil {
if errors.Is(err, unix.EINVAL) {
return waitAckFallback(ctx, conn, 0)
}
return err
}
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
var ackComplete *tcpNotifyAckComplete
ackComplete, err = getsockoptTCPNotifyAckComplete(int(fd), unix.IPPROTO_TCP, TCP_NOTIFY_ACKNOWLEDGEMENT)
if err != nil {
return err
}
if ackComplete.NotifyPending == 0 {
return nil
}
time.Sleep(10 * time.Millisecond)
}
})
}
*/
func writeAndWaitAck(ctx context.Context, conn *net.TCPConn, payload []byte, fallbackDelay time.Duration) error {
_, err := conn.Write(payload)
if err != nil {
return err
}
return control.Conn(conn, func(fd uintptr) error {
start := time.Now()
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
unacked, err := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_NWRITE)
if err != nil {
return err
}
if unacked == 0 {
if time.Since(start) <= 20*time.Millisecond {
// under transparent proxy
time.Sleep(fallbackDelay)
}
return nil
}
time.Sleep(10 * time.Millisecond)
}
})
}

View File

@@ -0,0 +1,40 @@
package tf
import (
"context"
"net"
"time"
"github.com/sagernet/sing/common/control"
"golang.org/x/sys/unix"
)
func writeAndWaitAck(ctx context.Context, conn *net.TCPConn, payload []byte, fallbackDelay time.Duration) error {
_, err := conn.Write(payload)
if err != nil {
return err
}
return control.Conn(conn, func(fd uintptr) error {
start := time.Now()
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
tcpInfo, err := unix.GetsockoptTCPInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_INFO)
if err != nil {
return err
}
if tcpInfo.Unacked == 0 {
if time.Since(start) <= 20*time.Millisecond {
// under transparent proxy
time.Sleep(fallbackDelay)
}
return nil
}
time.Sleep(10 * time.Millisecond)
}
})
}

View File

@@ -0,0 +1,14 @@
//go:build !(linux || darwin || windows)
package tf
import (
"context"
"net"
"time"
)
func writeAndWaitAck(ctx context.Context, conn *net.TCPConn, payload []byte, fallbackDelay time.Duration) error {
time.Sleep(fallbackDelay)
return nil
}

View File

@@ -0,0 +1,28 @@
package tf
import (
"context"
"errors"
"net"
"time"
"github.com/sagernet/sing/common/winiphlpapi"
"golang.org/x/sys/windows"
)
func writeAndWaitAck(ctx context.Context, conn *net.TCPConn, payload []byte, fallbackDelay time.Duration) error {
start := time.Now()
err := winiphlpapi.WriteAndWaitAck(ctx, conn, payload)
if err != nil {
if errors.Is(err, windows.ERROR_ACCESS_DENIED) {
time.Sleep(fallbackDelay)
return nil
}
return err
}
if time.Since(start) <= 20*time.Millisecond {
time.Sleep(fallbackDelay)
}
return nil
}

View File

@@ -2,32 +2,32 @@ package urltest
import (
"context"
"crypto/tls"
"net"
"net/http"
"net/url"
"sync"
"time"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing/common"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp"
)
type History struct {
Time time.Time `json:"time"`
Delay uint16 `json:"delay"`
}
var _ adapter.URLTestHistoryStorage = (*HistoryStorage)(nil)
type HistoryStorage struct {
access sync.RWMutex
delayHistory map[string]*History
delayHistory map[string]*adapter.URLTestHistory
updateHook chan<- struct{}
}
func NewHistoryStorage() *HistoryStorage {
return &HistoryStorage{
delayHistory: make(map[string]*History),
delayHistory: make(map[string]*adapter.URLTestHistory),
}
}
@@ -35,7 +35,7 @@ func (s *HistoryStorage) SetHook(hook chan<- struct{}) {
s.updateHook = hook
}
func (s *HistoryStorage) LoadURLTestHistory(tag string) *History {
func (s *HistoryStorage) LoadURLTestHistory(tag string) *adapter.URLTestHistory {
if s == nil {
return nil
}
@@ -51,7 +51,7 @@ func (s *HistoryStorage) DeleteURLTestHistory(tag string) {
s.notifyUpdated()
}
func (s *HistoryStorage) StoreURLTestHistory(tag string, history *History) {
func (s *HistoryStorage) StoreURLTestHistory(tag string, history *adapter.URLTestHistory) {
s.access.Lock()
s.delayHistory[tag] = history
s.access.Unlock()
@@ -110,6 +110,10 @@ func URLTest(ctx context.Context, link string, detour N.Dialer) (t uint16, err e
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return instance, nil
},
TLSClientConfig: &tls.Config{
Time: ntp.TimeFuncFromContext(ctx),
RootCAs: adapter.RootPoolFromContext(ctx),
},
},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse

7
constant/certificate.go Normal file
View File

@@ -0,0 +1,7 @@
package constant
const (
CertificateStoreSystem = "system"
CertificateStoreMozilla = "mozilla"
CertificateStoreNone = "none"
)

View File

@@ -1,5 +1,34 @@
package constant
const (
DefaultDNSTTL = 600
)
type DomainStrategy = uint8
const (
DomainStrategyAsIS DomainStrategy = iota
DomainStrategyPreferIPv4
DomainStrategyPreferIPv6
DomainStrategyIPv4Only
DomainStrategyIPv6Only
)
const (
DNSTypeLegacy = "legacy"
DNSTypeUDP = "udp"
DNSTypeTCP = "tcp"
DNSTypeTLS = "tls"
DNSTypeHTTPS = "https"
DNSTypeQUIC = "quic"
DNSTypeHTTP3 = "h3"
DNSTypeHosts = "hosts"
DNSTypeLocal = "local"
DNSTypePreDefined = "predefined"
DNSTypeFakeIP = "fakeip"
DNSTypeDHCP = "dhcp"
)
const (
DNSProviderAliDNS = "alidns"
DNSProviderCloudflare = "cloudflare"

View File

@@ -23,7 +23,6 @@ const (
TypeVLESS = "vless"
TypeTUIC = "tuic"
TypeHysteria2 = "hysteria2"
TypeNDIS = "ndis"
)
const (
@@ -81,8 +80,6 @@ func ProxyDisplayName(proxyType string) string {
return "Selector"
case TypeURLTest:
return "URLTest"
case TypeNDIS:
return "NDIS"
default:
return "Unknown"
}

View File

@@ -16,6 +16,7 @@ const (
StopTimeout = 5 * time.Second
FatalStopTimeout = 10 * time.Second
FakeIPMetadataSaveInterval = 10 * time.Second
TLSFragmentFallbackDelay = 500 * time.Millisecond
)
var PortProtocols = map[uint16]string{

View File

@@ -3,6 +3,7 @@ package box
import (
"runtime/debug"
"github.com/sagernet/sing-box/common/conntrack"
"github.com/sagernet/sing-box/option"
)
@@ -25,5 +26,9 @@ func applyDebugOptions(options option.DebugOptions) {
}
if options.MemoryLimit != 0 {
debug.SetMemoryLimit(int64(float64(options.MemoryLimit) / 1.5))
conntrack.MemoryLimit = uint64(options.MemoryLimit)
}
if options.OOMKiller != nil {
conntrack.KillerEnabled = *options.OOMKiller
}
}

563
dns/client.go Normal file
View File

@@ -0,0 +1,563 @@
package dns
import (
"context"
"net"
"net/netip"
"strings"
"time"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/task"
"github.com/sagernet/sing/contrab/freelru"
"github.com/sagernet/sing/contrab/maphash"
"github.com/miekg/dns"
)
var (
ErrNoRawSupport = E.New("no raw query support by current transport")
ErrNotCached = E.New("not cached")
ErrResponseRejected = E.New("response rejected")
ErrResponseRejectedCached = E.Extend(ErrResponseRejected, "cached")
)
var _ adapter.DNSClient = (*Client)(nil)
type Client struct {
timeout time.Duration
disableCache bool
disableExpire bool
independentCache bool
rdrc adapter.RDRCStore
initRDRCFunc func() adapter.RDRCStore
logger logger.ContextLogger
cache freelru.Cache[dns.Question, *dns.Msg]
transportCache freelru.Cache[transportCacheKey, *dns.Msg]
}
type ClientOptions struct {
Timeout time.Duration
DisableCache bool
DisableExpire bool
IndependentCache bool
CacheCapacity uint32
RDRC func() adapter.RDRCStore
Logger logger.ContextLogger
}
func NewClient(options ClientOptions) *Client {
client := &Client{
timeout: options.Timeout,
disableCache: options.DisableCache,
disableExpire: options.DisableExpire,
independentCache: options.IndependentCache,
initRDRCFunc: options.RDRC,
logger: options.Logger,
}
if client.timeout == 0 {
client.timeout = C.DNSTimeout
}
cacheCapacity := options.CacheCapacity
if cacheCapacity < 1024 {
cacheCapacity = 1024
}
if !client.disableCache {
if !client.independentCache {
client.cache = common.Must1(freelru.NewSharded[dns.Question, *dns.Msg](cacheCapacity, maphash.NewHasher[dns.Question]().Hash32))
} else {
client.transportCache = common.Must1(freelru.NewSharded[transportCacheKey, *dns.Msg](cacheCapacity, maphash.NewHasher[transportCacheKey]().Hash32))
}
}
return client
}
type transportCacheKey struct {
dns.Question
transportTag string
}
func (c *Client) Start() {
if c.initRDRCFunc != nil {
c.rdrc = c.initRDRCFunc()
}
}
func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, message *dns.Msg, options adapter.DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) (*dns.Msg, error) {
if len(message.Question) == 0 {
if c.logger != nil {
c.logger.WarnContext(ctx, "bad question size: ", len(message.Question))
}
responseMessage := dns.Msg{
MsgHdr: dns.MsgHdr{
Id: message.Id,
Response: true,
Rcode: dns.RcodeFormatError,
},
Question: message.Question,
}
return &responseMessage, nil
}
question := message.Question[0]
if options.ClientSubnet.IsValid() {
message = SetClientSubnet(message, options.ClientSubnet, true)
}
isSimpleRequest := len(message.Question) == 1 &&
len(message.Ns) == 0 &&
len(message.Extra) == 0 &&
!options.ClientSubnet.IsValid()
disableCache := !isSimpleRequest || c.disableCache || options.DisableCache
if !disableCache {
response, ttl := c.loadResponse(question, transport)
if response != nil {
logCachedResponse(c.logger, ctx, response, ttl)
response.Id = message.Id
return response, nil
}
}
if question.Qtype == dns.TypeA && options.Strategy == C.DomainStrategyIPv6Only || question.Qtype == dns.TypeAAAA && options.Strategy == C.DomainStrategyIPv4Only {
responseMessage := dns.Msg{
MsgHdr: dns.MsgHdr{
Id: message.Id,
Response: true,
Rcode: dns.RcodeSuccess,
},
Question: []dns.Question{question},
}
if c.logger != nil {
c.logger.DebugContext(ctx, "strategy rejected")
}
return &responseMessage, nil
}
messageId := message.Id
contextTransport, clientSubnetLoaded := transportTagFromContext(ctx)
if clientSubnetLoaded && transport.Tag() == contextTransport {
return nil, E.New("DNS query loopback in transport[", contextTransport, "]")
}
ctx = contextWithTransportTag(ctx, transport.Tag())
if responseChecker != nil && c.rdrc != nil {
rejected := c.rdrc.LoadRDRC(transport.Tag(), question.Name, question.Qtype)
if rejected {
return nil, ErrResponseRejectedCached
}
}
ctx, cancel := context.WithTimeout(ctx, c.timeout)
response, err := transport.Exchange(ctx, message)
cancel()
if err != nil {
return nil, err
}
/*if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA {
validResponse := response
loop:
for {
var (
addresses int
queryCNAME string
)
for _, rawRR := range validResponse.Answer {
switch rr := rawRR.(type) {
case *dns.A:
break loop
case *dns.AAAA:
break loop
case *dns.CNAME:
queryCNAME = rr.Target
}
}
if queryCNAME == "" {
break
}
exMessage := *message
exMessage.Question = []dns.Question{{
Name: queryCNAME,
Qtype: question.Qtype,
}}
validResponse, err = c.Exchange(ctx, transport, &exMessage, options, responseChecker)
if err != nil {
return nil, err
}
}
if validResponse != response {
response.Answer = append(response.Answer, validResponse.Answer...)
}
}*/
if responseChecker != nil {
addr, addrErr := MessageToAddresses(response)
if addrErr != nil || !responseChecker(addr) {
if c.rdrc != nil {
c.rdrc.SaveRDRCAsync(transport.Tag(), question.Name, question.Qtype, c.logger)
}
logRejectedResponse(c.logger, ctx, response)
return response, ErrResponseRejected
}
}
if question.Qtype == dns.TypeHTTPS {
if options.Strategy == C.DomainStrategyIPv4Only || options.Strategy == C.DomainStrategyIPv6Only {
for _, rr := range response.Answer {
https, isHTTPS := rr.(*dns.HTTPS)
if !isHTTPS {
continue
}
content := https.SVCB
content.Value = common.Filter(content.Value, func(it dns.SVCBKeyValue) bool {
if options.Strategy == C.DomainStrategyIPv4Only {
return it.Key() != dns.SVCB_IPV6HINT
} else {
return it.Key() != dns.SVCB_IPV4HINT
}
})
https.SVCB = content
}
}
}
var timeToLive uint32
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
for _, record := range recordList {
if timeToLive == 0 || record.Header().Ttl > 0 && record.Header().Ttl < timeToLive {
timeToLive = record.Header().Ttl
}
}
}
if options.RewriteTTL != nil {
timeToLive = *options.RewriteTTL
}
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
for _, record := range recordList {
record.Header().Ttl = timeToLive
}
}
response.Id = messageId
if !disableCache {
c.storeCache(transport, question, response, timeToLive)
}
logExchangedResponse(c.logger, ctx, response, timeToLive)
return response, err
}
func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) {
domain = FqdnToDomain(domain)
dnsName := dns.Fqdn(domain)
if options.Strategy == C.DomainStrategyIPv4Only {
return c.lookupToExchange(ctx, transport, dnsName, dns.TypeA, options, responseChecker)
} else if options.Strategy == C.DomainStrategyIPv6Only {
return c.lookupToExchange(ctx, transport, dnsName, dns.TypeAAAA, options, responseChecker)
}
var response4 []netip.Addr
var response6 []netip.Addr
var group task.Group
group.Append("exchange4", func(ctx context.Context) error {
response, err := c.lookupToExchange(ctx, transport, dnsName, dns.TypeA, options, responseChecker)
if err != nil {
return err
}
response4 = response
return nil
})
group.Append("exchange6", func(ctx context.Context) error {
response, err := c.lookupToExchange(ctx, transport, dnsName, dns.TypeAAAA, options, responseChecker)
if err != nil {
return err
}
response6 = response
return nil
})
err := group.Run(ctx)
if len(response4) == 0 && len(response6) == 0 {
return nil, err
}
return sortAddresses(response4, response6, options.Strategy), nil
}
func (c *Client) ClearCache() {
if c.cache != nil {
c.cache.Purge()
}
if c.transportCache != nil {
c.transportCache.Purge()
}
}
func (c *Client) LookupCache(domain string, strategy C.DomainStrategy) ([]netip.Addr, bool) {
if c.disableCache || c.independentCache {
return nil, false
}
if dns.IsFqdn(domain) {
domain = domain[:len(domain)-1]
}
dnsName := dns.Fqdn(domain)
if strategy == C.DomainStrategyIPv4Only {
response, err := c.questionCache(dns.Question{
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}, nil)
if err != ErrNotCached {
return response, true
}
} else if strategy == C.DomainStrategyIPv6Only {
response, err := c.questionCache(dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}, nil)
if err != ErrNotCached {
return response, true
}
} else {
response4, _ := c.questionCache(dns.Question{
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}, nil)
response6, _ := c.questionCache(dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}, nil)
if len(response4) > 0 || len(response6) > 0 {
return sortAddresses(response4, response6, strategy), true
}
}
return nil, false
}
func (c *Client) ExchangeCache(ctx context.Context, message *dns.Msg) (*dns.Msg, bool) {
if c.disableCache || c.independentCache || len(message.Question) != 1 {
return nil, false
}
question := message.Question[0]
response, ttl := c.loadResponse(question, nil)
if response == nil {
return nil, false
}
logCachedResponse(c.logger, ctx, response, ttl)
response.Id = message.Id
return response, true
}
func sortAddresses(response4 []netip.Addr, response6 []netip.Addr, strategy C.DomainStrategy) []netip.Addr {
if strategy == C.DomainStrategyPreferIPv6 {
return append(response6, response4...)
} else {
return append(response4, response6...)
}
}
func (c *Client) storeCache(transport adapter.DNSTransport, question dns.Question, message *dns.Msg, timeToLive uint32) {
if timeToLive == 0 {
return
}
if c.disableExpire {
if !c.independentCache {
c.cache.Add(question, message)
} else {
c.transportCache.Add(transportCacheKey{
Question: question,
transportTag: transport.Tag(),
}, message)
}
return
}
if !c.independentCache {
c.cache.AddWithLifetime(question, message, time.Second*time.Duration(timeToLive))
} else {
c.transportCache.AddWithLifetime(transportCacheKey{
Question: question,
transportTag: transport.Tag(),
}, message, time.Second*time.Duration(timeToLive))
}
}
func (c *Client) lookupToExchange(ctx context.Context, transport adapter.DNSTransport, name string, qType uint16, options adapter.DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) {
question := dns.Question{
Name: name,
Qtype: qType,
Qclass: dns.ClassINET,
}
disableCache := c.disableCache || options.DisableCache
if !disableCache {
cachedAddresses, err := c.questionCache(question, transport)
if err != ErrNotCached {
return cachedAddresses, err
}
}
message := dns.Msg{
MsgHdr: dns.MsgHdr{
RecursionDesired: true,
},
Question: []dns.Question{question},
}
response, err := c.Exchange(ctx, transport, &message, options, responseChecker)
if err != nil {
return nil, err
}
return MessageToAddresses(response)
}
func (c *Client) questionCache(question dns.Question, transport adapter.DNSTransport) ([]netip.Addr, error) {
response, _ := c.loadResponse(question, transport)
if response == nil {
return nil, ErrNotCached
}
return MessageToAddresses(response)
}
func (c *Client) loadResponse(question dns.Question, transport adapter.DNSTransport) (*dns.Msg, int) {
var (
response *dns.Msg
loaded bool
)
if c.disableExpire {
if !c.independentCache {
response, loaded = c.cache.Get(question)
} else {
response, loaded = c.transportCache.Get(transportCacheKey{
Question: question,
transportTag: transport.Tag(),
})
}
if !loaded {
return nil, 0
}
return response.Copy(), 0
} else {
var expireAt time.Time
if !c.independentCache {
response, expireAt, loaded = c.cache.GetWithLifetime(question)
} else {
response, expireAt, loaded = c.transportCache.GetWithLifetime(transportCacheKey{
Question: question,
transportTag: transport.Tag(),
})
}
if !loaded {
return nil, 0
}
timeNow := time.Now()
if timeNow.After(expireAt) {
if !c.independentCache {
c.cache.Remove(question)
} else {
c.transportCache.Remove(transportCacheKey{
Question: question,
transportTag: transport.Tag(),
})
}
return nil, 0
}
var originTTL int
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
for _, record := range recordList {
if originTTL == 0 || record.Header().Ttl > 0 && int(record.Header().Ttl) < originTTL {
originTTL = int(record.Header().Ttl)
}
}
}
nowTTL := int(expireAt.Sub(timeNow).Seconds())
if nowTTL < 0 {
nowTTL = 0
}
response = response.Copy()
if originTTL > 0 {
duration := uint32(originTTL - nowTTL)
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
for _, record := range recordList {
record.Header().Ttl = record.Header().Ttl - duration
}
}
} else {
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
for _, record := range recordList {
record.Header().Ttl = uint32(nowTTL)
}
}
}
return response, nowTTL
}
}
func MessageToAddresses(response *dns.Msg) ([]netip.Addr, error) {
if response.Rcode != dns.RcodeSuccess && response.Rcode != dns.RcodeNameError {
return nil, RCodeError(response.Rcode)
}
addresses := make([]netip.Addr, 0, len(response.Answer))
for _, rawAnswer := range response.Answer {
switch answer := rawAnswer.(type) {
case *dns.A:
addresses = append(addresses, M.AddrFromIP(answer.A))
case *dns.AAAA:
addresses = append(addresses, M.AddrFromIP(answer.AAAA))
case *dns.HTTPS:
for _, value := range answer.SVCB.Value {
if value.Key() == dns.SVCB_IPV4HINT || value.Key() == dns.SVCB_IPV6HINT {
addresses = append(addresses, common.Map(strings.Split(value.String(), ","), M.ParseAddr)...)
}
}
}
}
return addresses, nil
}
func wrapError(err error) error {
switch dnsErr := err.(type) {
case *net.DNSError:
if dnsErr.IsNotFound {
return RCodeNameError
}
case *net.AddrError:
return RCodeNameError
}
return err
}
type transportKey struct{}
func contextWithTransportTag(ctx context.Context, transportTag string) context.Context {
return context.WithValue(ctx, transportKey{}, transportTag)
}
func transportTagFromContext(ctx context.Context) (string, bool) {
value, loaded := ctx.Value(transportKey{}).(string)
return value, loaded
}
func FixedResponse(id uint16, question dns.Question, addresses []netip.Addr, timeToLive uint32) *dns.Msg {
response := dns.Msg{
MsgHdr: dns.MsgHdr{
Id: id,
Rcode: dns.RcodeSuccess,
Response: true,
},
Question: []dns.Question{question},
}
for _, address := range addresses {
if address.Is4() {
response.Answer = append(response.Answer, &dns.A{
Hdr: dns.RR_Header{
Name: question.Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: timeToLive,
},
A: address.AsSlice(),
})
} else {
response.Answer = append(response.Answer, &dns.AAAA{
Hdr: dns.RR_Header{
Name: question.Name,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: timeToLive,
},
AAAA: address.AsSlice(),
})
}
}
return &response
}

69
dns/client_log.go Normal file
View File

@@ -0,0 +1,69 @@
package dns
import (
"context"
"strings"
"github.com/sagernet/sing/common/logger"
"github.com/miekg/dns"
)
func logCachedResponse(logger logger.ContextLogger, ctx context.Context, response *dns.Msg, ttl int) {
if logger == nil || len(response.Question) == 0 {
return
}
domain := FqdnToDomain(response.Question[0].Name)
logger.DebugContext(ctx, "cached ", domain, " ", dns.RcodeToString[response.Rcode], " ", ttl)
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
for _, record := range recordList {
logger.InfoContext(ctx, "cached ", dns.Type(record.Header().Rrtype).String(), " ", FormatQuestion(record.String()))
}
}
}
func logExchangedResponse(logger logger.ContextLogger, ctx context.Context, response *dns.Msg, ttl uint32) {
if logger == nil || len(response.Question) == 0 {
return
}
domain := FqdnToDomain(response.Question[0].Name)
logger.DebugContext(ctx, "exchanged ", domain, " ", dns.RcodeToString[response.Rcode], " ", ttl)
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
for _, record := range recordList {
logger.InfoContext(ctx, "exchanged ", dns.Type(record.Header().Rrtype).String(), " ", FormatQuestion(record.String()))
}
}
}
func logRejectedResponse(logger logger.ContextLogger, ctx context.Context, response *dns.Msg) {
if logger == nil || len(response.Question) == 0 {
return
}
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
for _, record := range recordList {
logger.InfoContext(ctx, "rejected ", dns.Type(record.Header().Rrtype).String(), " ", FormatQuestion(record.String()))
}
}
}
func FqdnToDomain(fqdn string) string {
if dns.IsFqdn(fqdn) {
return fqdn[:len(fqdn)-1]
}
return fqdn
}
func FormatQuestion(string string) string {
for strings.HasPrefix(string, ";") {
string = string[1:]
}
string = strings.ReplaceAll(string, "\t", " ")
string = strings.ReplaceAll(string, "\n", " ")
string = strings.ReplaceAll(string, ";; ", " ")
string = strings.ReplaceAll(string, "; ", " ")
for strings.Contains(string, " ") {
string = strings.ReplaceAll(string, " ", " ")
}
return strings.TrimSpace(string)
}

29
dns/client_truncate.go Normal file
View File

@@ -0,0 +1,29 @@
package dns
import (
"github.com/sagernet/sing/common/buf"
"github.com/miekg/dns"
)
func TruncateDNSMessage(request *dns.Msg, response *dns.Msg, headroom int) (*buf.Buffer, error) {
maxLen := 512
if edns0Option := request.IsEdns0(); edns0Option != nil {
if udpSize := int(edns0Option.UDPSize()); udpSize > 512 {
maxLen = udpSize
}
}
responseLen := response.Len()
if responseLen > maxLen {
response.Truncate(maxLen)
}
buffer := buf.NewSize(headroom*2 + 1 + responseLen)
buffer.Resize(headroom, 0)
rawMessage, err := response.PackBuffer(buffer.FreeBytes())
if err != nil {
buffer.Release()
return nil, err
}
buffer.Truncate(len(rawMessage))
return buffer, nil
}

View File

@@ -0,0 +1,56 @@
package dns
import (
"net/netip"
"github.com/miekg/dns"
)
func SetClientSubnet(message *dns.Msg, clientSubnet netip.Prefix, override bool) *dns.Msg {
var (
optRecord *dns.OPT
subnetOption *dns.EDNS0_SUBNET
)
findExists:
for _, record := range message.Extra {
var isOPTRecord bool
if optRecord, isOPTRecord = record.(*dns.OPT); isOPTRecord {
for _, option := range optRecord.Option {
var isEDNS0Subnet bool
subnetOption, isEDNS0Subnet = option.(*dns.EDNS0_SUBNET)
if isEDNS0Subnet {
if !override {
return message
}
break findExists
}
}
}
}
if optRecord == nil {
exMessage := *message
message = &exMessage
optRecord = &dns.OPT{
Hdr: dns.RR_Header{
Name: ".",
Rrtype: dns.TypeOPT,
},
}
message.Extra = append(message.Extra, optRecord)
} else {
message = message.Copy()
}
if subnetOption == nil {
subnetOption = new(dns.EDNS0_SUBNET)
optRecord.Option = append(optRecord.Option, subnetOption)
}
subnetOption.Code = dns.EDNS0SUBNET
if clientSubnet.Addr().Is4() {
subnetOption.Family = 1
} else {
subnetOption.Family = 2
}
subnetOption.SourceNetmask = uint8(clientSubnet.Bits())
subnetOption.Address = clientSubnet.Addr().AsSlice()
return message
}

33
dns/rcode.go Normal file
View File

@@ -0,0 +1,33 @@
package dns
import F "github.com/sagernet/sing/common/format"
const (
RCodeSuccess RCodeError = 0 // NoError
RCodeFormatError RCodeError = 1 // FormErr
RCodeServerFailure RCodeError = 2 // ServFail
RCodeNameError RCodeError = 3 // NXDomain
RCodeNotImplemented RCodeError = 4 // NotImp
RCodeRefused RCodeError = 5 // Refused
)
type RCodeError uint16
func (e RCodeError) Error() string {
switch e {
case RCodeSuccess:
return "success"
case RCodeFormatError:
return "format error"
case RCodeServerFailure:
return "server failure"
case RCodeNameError:
return "name error"
case RCodeNotImplemented:
return "not implemented"
case RCodeRefused:
return "refused"
default:
return F.ToString("unknown error: ", uint16(e))
}
}

437
dns/router.go Normal file
View File

@@ -0,0 +1,437 @@
package dns
import (
"context"
"errors"
"net/netip"
"strings"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/taskmonitor"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/experimental/libbox/platform"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
R "github.com/sagernet/sing-box/route/rule"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/contrab/freelru"
"github.com/sagernet/sing/contrab/maphash"
"github.com/sagernet/sing/service"
mDNS "github.com/miekg/dns"
)
var _ adapter.DNSRouter = (*Router)(nil)
type Router struct {
ctx context.Context
logger logger.ContextLogger
transport adapter.DNSTransportManager
outbound adapter.OutboundManager
client adapter.DNSClient
rules []adapter.DNSRule
defaultDomainStrategy C.DomainStrategy
dnsReverseMapping freelru.Cache[netip.Addr, string]
platformInterface platform.Interface
}
func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOptions) *Router {
router := &Router{
ctx: ctx,
logger: logFactory.NewLogger("dns"),
transport: service.FromContext[adapter.DNSTransportManager](ctx),
outbound: service.FromContext[adapter.OutboundManager](ctx),
rules: make([]adapter.DNSRule, 0, len(options.Rules)),
defaultDomainStrategy: C.DomainStrategy(options.Strategy),
}
router.client = NewClient(ClientOptions{
DisableCache: options.DNSClientOptions.DisableCache,
DisableExpire: options.DNSClientOptions.DisableExpire,
IndependentCache: options.DNSClientOptions.IndependentCache,
CacheCapacity: options.DNSClientOptions.CacheCapacity,
RDRC: func() adapter.RDRCStore {
cacheFile := service.FromContext[adapter.CacheFile](ctx)
if cacheFile == nil {
return nil
}
if !cacheFile.StoreRDRC() {
return nil
}
return cacheFile
},
Logger: router.logger,
})
if options.ReverseMapping {
router.dnsReverseMapping = common.Must1(freelru.NewSharded[netip.Addr, string](1024, maphash.NewHasher[netip.Addr]().Hash32))
}
return router
}
func (r *Router) Initialize(rules []option.DNSRule) error {
for i, ruleOptions := range rules {
dnsRule, err := R.NewDNSRule(r.ctx, r.logger, ruleOptions, true)
if err != nil {
return E.Cause(err, "parse dns rule[", i, "]")
}
r.rules = append(r.rules, dnsRule)
}
return nil
}
func (r *Router) Start(stage adapter.StartStage) error {
monitor := taskmonitor.New(r.logger, C.StartTimeout)
switch stage {
case adapter.StartStateStart:
monitor.Start("initialize DNS client")
r.client.Start()
monitor.Finish()
for i, rule := range r.rules {
monitor.Start("initialize DNS rule[", i, "]")
err := rule.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "initialize DNS rule[", i, "]")
}
}
}
return nil
}
func (r *Router) Close() error {
monitor := taskmonitor.New(r.logger, C.StopTimeout)
var err error
for i, rule := range r.rules {
monitor.Start("close dns rule[", i, "]")
err = E.Append(err, rule.Close(), func(err error) error {
return E.Cause(err, "close dns rule[", i, "]")
})
monitor.Finish()
}
return err
}
func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int, isAddressQuery bool, options *adapter.DNSQueryOptions) (adapter.DNSTransport, adapter.DNSRule, int) {
metadata := adapter.ContextFrom(ctx)
if metadata == nil {
panic("no context")
}
var currentRuleIndex int
if ruleIndex != -1 {
currentRuleIndex = ruleIndex + 1
}
for ; currentRuleIndex < len(r.rules); currentRuleIndex++ {
currentRule := r.rules[currentRuleIndex]
if currentRule.WithAddressLimit() && !isAddressQuery {
continue
}
metadata.ResetRuleCache()
if currentRule.Match(metadata) {
displayRuleIndex := currentRuleIndex
if displayRuleIndex != -1 {
displayRuleIndex += displayRuleIndex + 1
}
ruleDescription := currentRule.String()
if ruleDescription != "" {
r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] ", currentRule, " => ", currentRule.Action())
} else {
r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] => ", currentRule.Action())
}
switch action := currentRule.Action().(type) {
case *R.RuleActionDNSRoute:
transport, loaded := r.transport.Transport(action.Server)
if !loaded {
r.logger.ErrorContext(ctx, "transport not found: ", action.Server)
continue
}
isFakeIP := transport.Type() == C.DNSTypeFakeIP
if isFakeIP && !allowFakeIP {
continue
}
if action.Strategy != C.DomainStrategyAsIS {
options.Strategy = action.Strategy
}
if isFakeIP || action.DisableCache {
options.DisableCache = true
}
if action.RewriteTTL != nil {
options.RewriteTTL = action.RewriteTTL
}
if action.ClientSubnet.IsValid() {
options.ClientSubnet = action.ClientSubnet
}
if legacyTransport, isLegacy := transport.(adapter.LegacyDNSTransport); isLegacy {
if options.Strategy == C.DomainStrategyAsIS {
options.Strategy = legacyTransport.LegacyStrategy()
}
if !options.ClientSubnet.IsValid() {
options.ClientSubnet = legacyTransport.LegacyClientSubnet()
}
}
r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] => ", currentRule.Action())
return transport, currentRule, currentRuleIndex
case *R.RuleActionDNSRouteOptions:
if action.Strategy != C.DomainStrategyAsIS {
options.Strategy = action.Strategy
}
if action.DisableCache {
options.DisableCache = true
}
if action.RewriteTTL != nil {
options.RewriteTTL = action.RewriteTTL
}
if action.ClientSubnet.IsValid() {
options.ClientSubnet = action.ClientSubnet
}
r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] => ", currentRule.Action())
case *R.RuleActionReject:
r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] => ", currentRule.Action())
return nil, currentRule, currentRuleIndex
}
}
}
return r.transport.Default(), nil, -1
}
func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapter.DNSQueryOptions) (*mDNS.Msg, error) {
if len(message.Question) != 1 {
r.logger.WarnContext(ctx, "bad question size: ", len(message.Question))
responseMessage := mDNS.Msg{
MsgHdr: mDNS.MsgHdr{
Id: message.Id,
Response: true,
Rcode: mDNS.RcodeFormatError,
},
Question: message.Question,
}
return &responseMessage, nil
}
r.logger.DebugContext(ctx, "exchange ", FormatQuestion(message.Question[0].String()))
var (
transport adapter.DNSTransport
err error
)
response, cached := r.client.ExchangeCache(ctx, message)
if !cached {
var metadata *adapter.InboundContext
ctx, metadata = adapter.ExtendContext(ctx)
metadata.Destination = M.Socksaddr{}
metadata.QueryType = message.Question[0].Qtype
switch metadata.QueryType {
case mDNS.TypeA:
metadata.IPVersion = 4
case mDNS.TypeAAAA:
metadata.IPVersion = 6
}
metadata.Domain = FqdnToDomain(message.Question[0].Name)
if options.Transport != nil {
transport = options.Transport
if legacyTransport, isLegacy := transport.(adapter.LegacyDNSTransport); isLegacy {
if options.Strategy == C.DomainStrategyAsIS {
options.Strategy = legacyTransport.LegacyStrategy()
}
if !options.ClientSubnet.IsValid() {
options.ClientSubnet = legacyTransport.LegacyClientSubnet()
}
}
if options.Strategy == C.DomainStrategyAsIS {
options.Strategy = r.defaultDomainStrategy
}
response, err = r.client.Exchange(ctx, transport, message, options, nil)
} else {
var (
rule adapter.DNSRule
ruleIndex int
)
ruleIndex = -1
for {
dnsCtx := adapter.OverrideContext(ctx)
dnsOptions := options
transport, rule, ruleIndex = r.matchDNS(ctx, true, ruleIndex, isAddressQuery(message), &dnsOptions)
if rule != nil {
switch action := rule.Action().(type) {
case *R.RuleActionReject:
switch action.Method {
case C.RuleActionRejectMethodDefault:
return FixedResponse(message.Id, message.Question[0], nil, 0), nil
case C.RuleActionRejectMethodDrop:
return nil, tun.ErrDrop
}
}
}
var responseCheck func(responseAddrs []netip.Addr) bool
if rule != nil && rule.WithAddressLimit() {
responseCheck = func(responseAddrs []netip.Addr) bool {
metadata.DestinationAddresses = responseAddrs
return rule.MatchAddressLimit(metadata)
}
}
if dnsOptions.Strategy == C.DomainStrategyAsIS {
dnsOptions.Strategy = r.defaultDomainStrategy
}
response, err = r.client.Exchange(dnsCtx, transport, message, dnsOptions, responseCheck)
var rejected bool
if err != nil {
if errors.Is(err, ErrResponseRejectedCached) {
rejected = true
r.logger.DebugContext(ctx, E.Cause(err, "response rejected for ", FormatQuestion(message.Question[0].String())), " (cached)")
} else if errors.Is(err, ErrResponseRejected) {
rejected = true
r.logger.DebugContext(ctx, E.Cause(err, "response rejected for ", FormatQuestion(message.Question[0].String())))
} else if len(message.Question) > 0 {
r.logger.ErrorContext(ctx, E.Cause(err, "exchange failed for ", FormatQuestion(message.Question[0].String())))
} else {
r.logger.ErrorContext(ctx, E.Cause(err, "exchange failed for <empty query>"))
}
}
if responseCheck != nil && rejected {
continue
}
break
}
}
}
if err != nil {
return nil, err
}
if r.dnsReverseMapping != nil && len(message.Question) > 0 && response != nil && len(response.Answer) > 0 {
if transport == nil || transport.Type() != C.DNSTypeFakeIP {
for _, answer := range response.Answer {
switch record := answer.(type) {
case *mDNS.A:
r.dnsReverseMapping.AddWithLifetime(M.AddrFromIP(record.A), FqdnToDomain(record.Hdr.Name), time.Duration(record.Hdr.Ttl)*time.Second)
case *mDNS.AAAA:
r.dnsReverseMapping.AddWithLifetime(M.AddrFromIP(record.AAAA), FqdnToDomain(record.Hdr.Name), time.Duration(record.Hdr.Ttl)*time.Second)
}
}
}
}
return response, nil
}
func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, error) {
var (
responseAddrs []netip.Addr
cached bool
err error
)
printResult := func() {
if err != nil {
if errors.Is(err, ErrResponseRejectedCached) {
r.logger.DebugContext(ctx, "response rejected for ", domain, " (cached)")
} else if errors.Is(err, ErrResponseRejected) {
r.logger.DebugContext(ctx, "response rejected for ", domain)
} else {
r.logger.ErrorContext(ctx, E.Cause(err, "lookup failed for ", domain))
}
} else if len(responseAddrs) == 0 {
r.logger.ErrorContext(ctx, "lookup failed for ", domain, ": empty result")
err = RCodeNameError
}
}
responseAddrs, cached = r.client.LookupCache(domain, options.Strategy)
if cached {
if len(responseAddrs) == 0 {
return nil, RCodeNameError
}
return responseAddrs, nil
}
r.logger.DebugContext(ctx, "lookup domain ", domain)
ctx, metadata := adapter.ExtendContext(ctx)
metadata.Destination = M.Socksaddr{}
metadata.Domain = domain
if options.Transport != nil {
transport := options.Transport
if legacyTransport, isLegacy := transport.(adapter.LegacyDNSTransport); isLegacy {
if options.Strategy == C.DomainStrategyAsIS {
options.Strategy = r.defaultDomainStrategy
}
if !options.ClientSubnet.IsValid() {
options.ClientSubnet = legacyTransport.LegacyClientSubnet()
}
}
if options.Strategy == C.DomainStrategyAsIS {
options.Strategy = r.defaultDomainStrategy
}
responseAddrs, err = r.client.Lookup(ctx, transport, domain, options, nil)
} else {
var (
transport adapter.DNSTransport
rule adapter.DNSRule
ruleIndex int
)
ruleIndex = -1
for {
dnsCtx := adapter.OverrideContext(ctx)
transport, rule, ruleIndex = r.matchDNS(ctx, false, ruleIndex, true, &options)
if rule != nil {
switch action := rule.Action().(type) {
case *R.RuleActionReject:
switch action.Method {
case C.RuleActionRejectMethodDefault:
return nil, nil
case C.RuleActionRejectMethodDrop:
return nil, tun.ErrDrop
}
}
}
var responseCheck func(responseAddrs []netip.Addr) bool
if rule != nil && rule.WithAddressLimit() {
responseCheck = func(responseAddrs []netip.Addr) bool {
metadata.DestinationAddresses = responseAddrs
return rule.MatchAddressLimit(metadata)
}
}
if options.Strategy == C.DomainStrategyAsIS {
options.Strategy = r.defaultDomainStrategy
}
responseAddrs, err = r.client.Lookup(dnsCtx, transport, domain, options, responseCheck)
if responseCheck == nil || err == nil {
break
}
printResult()
}
}
printResult()
if len(responseAddrs) > 0 {
r.logger.InfoContext(ctx, "lookup succeed for ", domain, ": ", strings.Join(F.MapToString(responseAddrs), " "))
}
return responseAddrs, err
}
func isAddressQuery(message *mDNS.Msg) bool {
for _, question := range message.Question {
if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA || question.Qtype == mDNS.TypeHTTPS {
return true
}
}
return false
}
func (r *Router) ClearCache() {
r.client.ClearCache()
if r.platformInterface != nil {
r.platformInterface.ClearDNSCache()
}
}
func (r *Router) LookupReverseMapping(ip netip.Addr) (string, bool) {
if r.dnsReverseMapping == nil {
return "", false
}
domain, loaded := r.dnsReverseMapping.Get(ip)
return domain, loaded
}
func (r *Router) ResetNetwork() {
r.ClearCache()
for _, transport := range r.transport.Transports() {
transport.Reset()
}
}

View File

@@ -3,9 +3,6 @@ package dhcp
import (
"context"
"net"
"net/netip"
"net/url"
"os"
"runtime"
"strings"
"sync"
@@ -14,13 +11,18 @@ import (
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/dns/transport"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/control"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/task"
"github.com/sagernet/sing/common/x/list"
"github.com/sagernet/sing/service"
@@ -29,76 +31,70 @@ import (
mDNS "github.com/miekg/dns"
)
func init() {
dns.RegisterTransport([]string{"dhcp"}, func(options dns.TransportOptions) (dns.Transport, error) {
return NewTransport(options)
})
func RegisterTransport(registry *dns.TransportRegistry) {
dns.RegisterTransport[option.DHCPDNSServerOptions](registry, C.DNSTypeDHCP, NewTransport)
}
var _ adapter.DNSTransport = (*Transport)(nil)
type Transport struct {
options dns.TransportOptions
router adapter.Router
dns.TransportAdapter
ctx context.Context
dialer N.Dialer
logger logger.ContextLogger
networkManager adapter.NetworkManager
interfaceName string
autoInterface bool
interfaceCallback *list.Element[tun.DefaultInterfaceUpdateCallback]
transports []dns.Transport
transports []adapter.DNSTransport
updateAccess sync.Mutex
updatedAt time.Time
}
func NewTransport(options dns.TransportOptions) (*Transport, error) {
linkURL, err := url.Parse(options.Address)
func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, options option.DHCPDNSServerOptions) (adapter.DNSTransport, error) {
transportDialer, err := dns.NewLocalDialer(ctx, options.LocalDNSServerOptions)
if err != nil {
return nil, err
}
if linkURL.Host == "" {
return nil, E.New("missing interface name for DHCP")
}
transport := &Transport{
options: options,
networkManager: service.FromContext[adapter.NetworkManager](options.Context),
interfaceName: linkURL.Host,
autoInterface: linkURL.Host == "auto",
}
return transport, nil
return &Transport{
TransportAdapter: dns.NewTransportAdapterWithLocalOptions(C.DNSTypeDHCP, tag, options.LocalDNSServerOptions),
ctx: ctx,
dialer: transportDialer,
logger: logger,
networkManager: service.FromContext[adapter.NetworkManager](ctx),
interfaceName: options.Interface,
}, nil
}
func (t *Transport) Name() string {
return t.options.Name
}
func (t *Transport) Start() error {
func (t *Transport) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
err := t.fetchServers()
if err != nil {
return err
}
if t.autoInterface {
if t.interfaceName == "" {
t.interfaceCallback = t.networkManager.InterfaceMonitor().RegisterCallback(t.interfaceUpdated)
}
return nil
}
func (t *Transport) Close() error {
for _, transport := range t.transports {
transport.Reset()
}
if t.interfaceCallback != nil {
t.networkManager.InterfaceMonitor().UnregisterCallback(t.interfaceCallback)
}
return nil
}
func (t *Transport) Reset() {
for _, transport := range t.transports {
transport.Reset()
}
}
func (t *Transport) Close() error {
for _, transport := range t.transports {
transport.Close()
}
if t.interfaceCallback != nil {
t.networkManager.InterfaceMonitor().UnregisterCallback(t.interfaceCallback)
}
return nil
}
func (t *Transport) Raw() bool {
return true
}
func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
err := t.fetchServers()
if err != nil {
@@ -120,7 +116,7 @@ func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg,
}
func (t *Transport) fetchInterface() (*control.Interface, error) {
if t.autoInterface {
if t.interfaceName == "" {
if t.networkManager.InterfaceMonitor() == nil {
return nil, E.New("missing monitor for auto DHCP, set route.auto_detect_interface")
}
@@ -152,8 +148,8 @@ func (t *Transport) updateServers() error {
return E.Cause(err, "dhcp: prepare interface")
}
t.options.Logger.Info("dhcp: query DNS servers on ", iface.Name)
fetchCtx, cancel := context.WithTimeout(t.options.Context, C.DHCPTimeout)
t.logger.Info("dhcp: query DNS servers on ", iface.Name)
fetchCtx, cancel := context.WithTimeout(t.ctx, C.DHCPTimeout)
err = t.fetchServers0(fetchCtx, iface)
cancel()
if err != nil {
@@ -169,7 +165,7 @@ func (t *Transport) updateServers() error {
func (t *Transport) interfaceUpdated(defaultInterface *control.Interface, flags int) {
err := t.updateServers()
if err != nil {
t.options.Logger.Error("update servers: ", err)
t.logger.Error("update servers: ", err)
}
}
@@ -181,7 +177,7 @@ func (t *Transport) fetchServers0(ctx context.Context, iface *control.Interface)
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
listenAddr = "255.255.255.255:68"
}
packetConn, err := listener.ListenPacket(t.options.Context, "udp4", listenAddr)
packetConn, err := listener.ListenPacket(t.ctx, "udp4", listenAddr)
if err != nil {
return err
}
@@ -219,17 +215,17 @@ func (t *Transport) fetchServersResponse(iface *control.Interface, packetConn ne
dhcpPacket, err := dhcpv4.FromBytes(buffer.Bytes())
if err != nil {
t.options.Logger.Trace("dhcp: parse DHCP response: ", err)
t.logger.Trace("dhcp: parse DHCP response: ", err)
return err
}
if dhcpPacket.MessageType() != dhcpv4.MessageTypeOffer {
t.options.Logger.Trace("dhcp: expected OFFER response, but got ", dhcpPacket.MessageType())
t.logger.Trace("dhcp: expected OFFER response, but got ", dhcpPacket.MessageType())
continue
}
if dhcpPacket.TransactionID != transactionID {
t.options.Logger.Trace("dhcp: expected transaction ID ", transactionID, ", but got ", dhcpPacket.TransactionID)
t.logger.Trace("dhcp: expected transaction ID ", transactionID, ", but got ", dhcpPacket.TransactionID)
continue
}
@@ -237,44 +233,27 @@ func (t *Transport) fetchServersResponse(iface *control.Interface, packetConn ne
if len(dns) == 0 {
return nil
}
var addrs []netip.Addr
for _, ip := range dns {
addr, _ := netip.AddrFromSlice(ip)
addrs = append(addrs, addr.Unmap())
}
return t.recreateServers(iface, addrs)
return t.recreateServers(iface, common.Map(dns, func(it net.IP) M.Socksaddr {
return M.SocksaddrFrom(M.AddrFromIP(it), 53)
}))
}
}
func (t *Transport) recreateServers(iface *control.Interface, serverAddrs []netip.Addr) error {
func (t *Transport) recreateServers(iface *control.Interface, serverAddrs []M.Socksaddr) error {
if len(serverAddrs) > 0 {
t.options.Logger.Info("dhcp: updated DNS servers from ", iface.Name, ": [", strings.Join(common.Map(serverAddrs, func(it netip.Addr) string {
return it.String()
}), ","), "]")
t.logger.Info("dhcp: updated DNS servers from ", iface.Name, ": [", strings.Join(common.Map(serverAddrs, M.Socksaddr.String), ","), "]")
}
serverDialer := common.Must1(dialer.NewDefault(t.options.Context, option.DialerOptions{
serverDialer := common.Must1(dialer.NewDefault(t.ctx, option.DialerOptions{
BindInterface: iface.Name,
UDPFragmentDefault: true,
}))
var transports []dns.Transport
var transports []adapter.DNSTransport
for _, serverAddr := range serverAddrs {
newOptions := t.options
newOptions.Address = serverAddr.String()
newOptions.Dialer = serverDialer
serverTransport, err := dns.NewUDPTransport(newOptions)
if err != nil {
return E.Cause(err, "create UDP transport from DHCP result: ", serverAddr)
}
transports = append(transports, serverTransport)
transports = append(transports, transport.NewUDPRaw(t.logger, t.TransportAdapter, serverDialer, serverAddr))
}
for _, transport := range t.transports {
transport.Close()
transport.Reset()
}
t.transports = transports
return nil
}
func (t *Transport) Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error) {
return nil, os.ErrInvalid
}

View File

@@ -0,0 +1,67 @@
package fakeip
import (
"context"
"net/netip"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
mDNS "github.com/miekg/dns"
)
func RegisterTransport(registry *dns.TransportRegistry) {
dns.RegisterTransport[option.FakeIPDNSServerOptions](registry, C.DNSTypeFakeIP, NewTransport)
}
var _ adapter.FakeIPTransport = (*Transport)(nil)
type Transport struct {
dns.TransportAdapter
logger logger.ContextLogger
store adapter.FakeIPStore
}
func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, options option.FakeIPDNSServerOptions) (adapter.DNSTransport, error) {
store := NewStore(ctx, logger, options.Inet4Range.Build(netip.Prefix{}), options.Inet6Range.Build(netip.Prefix{}))
return &Transport{
TransportAdapter: dns.NewTransportAdapter(C.DNSTypeFakeIP, tag, nil),
logger: logger,
store: store,
}, nil
}
func (t *Transport) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
return t.store.Start()
}
func (t *Transport) Close() error {
return t.store.Close()
}
func (t *Transport) Reset() {
}
func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
question := message.Question[0]
if question.Qtype != mDNS.TypeA && question.Qtype != mDNS.TypeAAAA {
return nil, E.New("only IP queries are supported by fakeip")
}
address, err := t.store.Create(question.Name, question.Qtype == mDNS.TypeAAAA)
if err != nil {
return nil, err
}
return dns.FixedResponse(message.Id, question, []netip.Addr{address}, C.DefaultDNSTTL), nil
}
func (t *Transport) Store() adapter.FakeIPStore {
return t.store
}

View File

@@ -0,0 +1,63 @@
package hosts
import (
"context"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
mDNS "github.com/miekg/dns"
)
func RegisterTransport(registry *dns.TransportRegistry) {
dns.RegisterTransport[option.HostsDNSServerOptions](registry, C.DNSTypeHosts, NewTransport)
}
var _ adapter.DNSTransport = (*Transport)(nil)
type Transport struct {
dns.TransportAdapter
files []*File
}
func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, options option.HostsDNSServerOptions) (adapter.DNSTransport, error) {
var files []*File
if len(options.Path) == 0 {
files = append(files, NewFile(DefaultPath))
} else {
for _, path := range options.Path {
files = append(files, NewFile(path))
}
}
return &Transport{
TransportAdapter: dns.NewTransportAdapter(C.DNSTypeHosts, tag, nil),
files: files,
}, nil
}
func (t *Transport) Reset() {
}
func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
question := message.Question[0]
domain := dns.FqdnToDomain(question.Name)
if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA {
for _, file := range t.files {
addresses := file.Lookup(domain)
if len(addresses) > 0 {
return dns.FixedResponse(message.Id, question, addresses, C.DefaultDNSTTL), nil
}
}
}
return &mDNS.Msg{
MsgHdr: mDNS.MsgHdr{
Id: message.Id,
Rcode: mDNS.RcodeNameError,
Response: true,
},
Question: []mDNS.Question{question},
}, nil
}

View File

@@ -0,0 +1,102 @@
package hosts
import (
"bufio"
"errors"
"io"
"net/netip"
"os"
"strings"
"sync"
"time"
"github.com/miekg/dns"
)
const cacheMaxAge = 5 * time.Second
type File struct {
path string
access sync.Mutex
byName map[string][]netip.Addr
expire time.Time
modTime time.Time
size int64
}
func NewFile(path string) *File {
return &File{
path: path,
}
}
func (f *File) Lookup(name string) []netip.Addr {
f.access.Lock()
defer f.access.Unlock()
f.update()
return f.byName[name]
}
func (f *File) update() {
now := time.Now()
if now.Before(f.expire) && len(f.byName) > 0 {
return
}
stat, err := os.Stat(f.path)
if err != nil {
return
}
if f.modTime.Equal(stat.ModTime()) && f.size == stat.Size() {
f.expire = now.Add(cacheMaxAge)
return
}
byName := make(map[string][]netip.Addr)
file, err := os.Open(f.path)
if err != nil {
return
}
defer file.Close()
reader := bufio.NewReader(file)
var (
prefix []byte
line []byte
isPrefix bool
)
for {
line, isPrefix, err = reader.ReadLine()
if err != nil {
if errors.Is(err, io.EOF) {
break
}
return
}
if isPrefix {
prefix = append(prefix, line...)
continue
} else if len(prefix) > 0 {
line = append(prefix, line...)
prefix = nil
}
commentIndex := strings.IndexRune(string(line), '#')
if commentIndex != -1 {
line = line[:commentIndex]
}
fields := strings.Fields(string(line))
if len(fields) < 2 {
continue
}
var addr netip.Addr
addr, err = netip.ParseAddr(fields[0])
if err != nil {
continue
}
for index := 1; index < len(fields); index++ {
canonicalName := dns.CanonicalName(fields[index])
byName[canonicalName] = append(byName[canonicalName], addr)
}
}
f.expire = now.Add(cacheMaxAge)
f.modTime = stat.ModTime()
f.size = stat.Size()
f.byName = byName
}

View File

@@ -0,0 +1,16 @@
package hosts_test
import (
"net/netip"
"testing"
"github.com/sagernet/sing-box/dns/transport/hosts"
"github.com/stretchr/testify/require"
)
func TestHosts(t *testing.T) {
t.Parallel()
require.Equal(t, []netip.Addr{netip.AddrFrom4([4]byte{127, 0, 0, 1}), netip.IPv6Loopback()}, hosts.NewFile("testdata/hosts").Lookup("localhost."))
require.NotEmpty(t, hosts.NewFile(hosts.DefaultPath).Lookup("localhost."))
}

View File

@@ -0,0 +1,5 @@
//go:build !windows
package hosts
var DefaultPath = "/etc/hosts"

View File

@@ -0,0 +1,17 @@
package hosts
import (
"path/filepath"
"golang.org/x/sys/windows"
)
var DefaultPath string
func init() {
systemDirectory, err := windows.GetSystemDirectory()
if err != nil {
systemDirectory = "C:\\Windows\\System32"
}
DefaultPath = filepath.Join(systemDirectory, "Drivers/etc/hosts")
}

2
dns/transport/hosts/testdata/hosts vendored Normal file
View File

@@ -0,0 +1,2 @@
127.0.0.1 localhost
::1 localhost

204
dns/transport/https.go Normal file
View File

@@ -0,0 +1,204 @@
package transport
import (
"bytes"
"context"
"io"
"net"
"net/http"
"net/url"
"strconv"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
aTLS "github.com/sagernet/sing/common/tls"
sHTTP "github.com/sagernet/sing/protocol/http"
mDNS "github.com/miekg/dns"
"golang.org/x/net/http2"
)
const MimeType = "application/dns-message"
var _ adapter.DNSTransport = (*HTTPSTransport)(nil)
func RegisterHTTPS(registry *dns.TransportRegistry) {
dns.RegisterTransport[option.RemoteHTTPSDNSServerOptions](registry, C.DNSTypeHTTPS, NewHTTPS)
}
type HTTPSTransport struct {
dns.TransportAdapter
logger logger.ContextLogger
dialer N.Dialer
destination *url.URL
headers http.Header
transport *http.Transport
}
func NewHTTPS(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteHTTPSDNSServerOptions) (adapter.DNSTransport, error) {
transportDialer, err := dns.NewRemoteDialer(ctx, options.RemoteDNSServerOptions)
if err != nil {
return nil, err
}
tlsOptions := common.PtrValueOrDefault(options.TLS)
tlsOptions.Enabled = true
tlsConfig, err := tls.NewClient(ctx, options.Server, tlsOptions)
if err != nil {
return nil, err
}
if common.Error(tlsConfig.Config()) == nil && !common.Contains(tlsConfig.NextProtos(), http2.NextProtoTLS) {
tlsConfig.SetNextProtos(append(tlsConfig.NextProtos(), http2.NextProtoTLS))
}
if !common.Contains(tlsConfig.NextProtos(), "http/1.1") {
tlsConfig.SetNextProtos(append(tlsConfig.NextProtos(), "http/1.1"))
}
headers := options.Headers.Build()
host := headers.Get("Host")
if host != "" {
headers.Del("Host")
} else {
if tlsConfig.ServerName() != "" {
host = tlsConfig.ServerName()
} else {
host = options.Server
}
}
destinationURL := url.URL{
Scheme: "https",
Host: host,
}
if destinationURL.Host == "" {
destinationURL.Host = options.Server
}
if options.ServerPort != 0 && options.ServerPort != 443 {
destinationURL.Host = net.JoinHostPort(destinationURL.Host, strconv.Itoa(int(options.ServerPort)))
}
path := options.Path
if path == "" {
path = "/dns-query"
}
err = sHTTP.URLSetPath(&destinationURL, path)
if err != nil {
return nil, err
}
serverAddr := options.ServerOptions.Build()
if serverAddr.Port == 0 {
serverAddr.Port = 443
}
return NewHTTPSRaw(
dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeHTTPS, tag, options.RemoteDNSServerOptions),
logger,
transportDialer,
&destinationURL,
headers,
serverAddr,
tlsConfig,
), nil
}
func NewHTTPSRaw(
adapter dns.TransportAdapter,
logger log.ContextLogger,
dialer N.Dialer,
destination *url.URL,
headers http.Header,
serverAddr M.Socksaddr,
tlsConfig tls.Config,
) *HTTPSTransport {
var transport *http.Transport
if tlsConfig != nil {
transport = &http.Transport{
ForceAttemptHTTP2: true,
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
tcpConn, hErr := dialer.DialContext(ctx, network, serverAddr)
if hErr != nil {
return nil, hErr
}
tlsConn, hErr := aTLS.ClientHandshake(ctx, tcpConn, tlsConfig)
if hErr != nil {
tcpConn.Close()
return nil, hErr
}
return tlsConn, nil
},
}
} else {
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, network, serverAddr)
},
}
}
return &HTTPSTransport{
TransportAdapter: adapter,
logger: logger,
dialer: dialer,
destination: destination,
headers: headers,
transport: transport,
}
}
func (t *HTTPSTransport) Reset() {
t.transport.CloseIdleConnections()
t.transport = t.transport.Clone()
}
func (t *HTTPSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
exMessage := *message
exMessage.Id = 0
exMessage.Compress = true
requestBuffer := buf.NewSize(1 + message.Len())
rawMessage, err := exMessage.PackBuffer(requestBuffer.FreeBytes())
if err != nil {
requestBuffer.Release()
return nil, err
}
request, err := http.NewRequestWithContext(ctx, http.MethodPost, t.destination.String(), bytes.NewReader(rawMessage))
if err != nil {
requestBuffer.Release()
return nil, err
}
request.Header = t.headers.Clone()
request.Header.Set("Content-Type", MimeType)
request.Header.Set("Accept", MimeType)
response, err := t.transport.RoundTrip(request)
requestBuffer.Release()
if err != nil {
return nil, err
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
return nil, E.New("unexpected status: ", response.Status)
}
var responseMessage mDNS.Msg
if response.ContentLength > 0 {
responseBuffer := buf.NewSize(int(response.ContentLength))
_, err = responseBuffer.ReadFullFrom(response.Body, int(response.ContentLength))
if err != nil {
return nil, err
}
err = responseMessage.Unpack(responseBuffer.Bytes())
responseBuffer.Release()
} else {
rawMessage, err = io.ReadAll(response.Body)
if err != nil {
return nil, err
}
err = responseMessage.Unpack(rawMessage)
}
if err != nil {
return nil, err
}
return &responseMessage, nil
}

View File

@@ -0,0 +1,195 @@
package local
import (
"context"
"math/rand"
"time"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/dns/transport/hosts"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
mDNS "github.com/miekg/dns"
)
func RegisterTransport(registry *dns.TransportRegistry) {
dns.RegisterTransport[option.LocalDNSServerOptions](registry, C.DNSTypeLocal, NewTransport)
}
var _ adapter.DNSTransport = (*Transport)(nil)
type Transport struct {
dns.TransportAdapter
hosts *hosts.File
dialer N.Dialer
}
func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, options option.LocalDNSServerOptions) (adapter.DNSTransport, error) {
transportDialer, err := dns.NewLocalDialer(ctx, options)
if err != nil {
return nil, err
}
return &Transport{
TransportAdapter: dns.NewTransportAdapterWithLocalOptions(C.DNSTypeLocal, tag, options),
hosts: hosts.NewFile(hosts.DefaultPath),
dialer: transportDialer,
}, nil
}
func (t *Transport) Reset() {
}
func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
question := message.Question[0]
domain := dns.FqdnToDomain(question.Name)
if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA {
addresses := t.hosts.Lookup(domain)
if len(addresses) > 0 {
return dns.FixedResponse(message.Id, question, addresses, C.DefaultDNSTTL), nil
}
}
systemConfig := getSystemDNSConfig()
if systemConfig.singleRequest || !(message.Question[0].Qtype == mDNS.TypeA || message.Question[0].Qtype == mDNS.TypeAAAA) {
return t.exchangeSingleRequest(ctx, systemConfig, message, domain)
} else {
return t.exchangeParallel(ctx, systemConfig, message, domain)
}
}
func (t *Transport) exchangeSingleRequest(ctx context.Context, systemConfig *dnsConfig, message *mDNS.Msg, domain string) (*mDNS.Msg, error) {
var lastErr error
for _, fqdn := range systemConfig.nameList(domain) {
response, err := t.tryOneName(ctx, systemConfig, fqdn, message)
if err != nil {
lastErr = err
continue
}
return response, nil
}
return nil, lastErr
}
func (t *Transport) exchangeParallel(ctx context.Context, systemConfig *dnsConfig, message *mDNS.Msg, domain string) (*mDNS.Msg, error) {
returned := make(chan struct{})
defer close(returned)
type queryResult struct {
response *mDNS.Msg
err error
}
results := make(chan queryResult)
startRacer := func(ctx context.Context, fqdn string) {
response, err := t.tryOneName(ctx, systemConfig, fqdn, message)
addresses, _ := dns.MessageToAddresses(response)
if len(addresses) == 0 {
err = E.New(fqdn, ": empty result")
}
select {
case results <- queryResult{response, err}:
case <-returned:
}
}
queryCtx, queryCancel := context.WithCancel(ctx)
defer queryCancel()
var nameCount int
for _, fqdn := range systemConfig.nameList(domain) {
nameCount++
go startRacer(queryCtx, fqdn)
}
var errors []error
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
case result := <-results:
if result.err == nil {
return result.response, nil
}
errors = append(errors, result.err)
if len(errors) == nameCount {
return nil, E.Errors(errors...)
}
}
}
}
func (t *Transport) tryOneName(ctx context.Context, config *dnsConfig, fqdn string, message *mDNS.Msg) (*mDNS.Msg, error) {
serverOffset := config.serverOffset()
sLen := uint32(len(config.servers))
var lastErr error
for i := 0; i < config.attempts; i++ {
for j := uint32(0); j < sLen; j++ {
server := config.servers[(serverOffset+j)%sLen]
question := message.Question[0]
question.Name = fqdn
response, err := t.exchangeOne(ctx, M.ParseSocksaddr(server), question, config.timeout, config.useTCP, config.trustAD)
if err != nil {
lastErr = err
continue
}
return response, nil
}
}
return nil, E.Cause(lastErr, fqdn)
}
func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, question mDNS.Question, timeout time.Duration, useTCP, ad bool) (*mDNS.Msg, error) {
var networks []string
if useTCP {
networks = []string{N.NetworkTCP}
} else {
networks = []string{N.NetworkUDP, N.NetworkTCP}
}
request := &mDNS.Msg{
MsgHdr: mDNS.MsgHdr{
Id: uint16(rand.Uint32()),
RecursionDesired: true,
AuthenticatedData: ad,
},
Question: []mDNS.Question{question},
Compress: true,
}
request.SetEdns0(maxDNSPacketSize, false)
buffer := buf.Get(buf.UDPBufferSize)
defer buf.Put(buffer)
for _, network := range networks {
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
defer cancel()
conn, err := t.dialer.DialContext(ctx, network, server)
if err != nil {
return nil, err
}
defer conn.Close()
if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() {
conn.SetDeadline(deadline)
}
rawMessage, err := request.PackBuffer(buffer)
if err != nil {
return nil, E.Cause(err, "pack request")
}
_, err = conn.Write(rawMessage)
if err != nil {
return nil, E.Cause(err, "write request")
}
n, err := conn.Read(buffer)
if err != nil {
return nil, E.Cause(err, "read response")
}
var response mDNS.Msg
err = response.Unpack(buffer[:n])
if err != nil {
return nil, E.Cause(err, "unpack response")
}
if response.Truncated && network == N.NetworkUDP {
continue
}
return &response, nil
}
panic("unexpected")
}

View File

@@ -0,0 +1,146 @@
package local
import (
"os"
"runtime"
"strings"
"sync"
"sync/atomic"
"time"
)
const (
// net.maxDNSPacketSize
maxDNSPacketSize = 1232
)
type resolverConfig struct {
initOnce sync.Once
ch chan struct{}
lastChecked time.Time
dnsConfig atomic.Pointer[dnsConfig]
}
var resolvConf resolverConfig
func getSystemDNSConfig() *dnsConfig {
resolvConf.tryUpdate("/etc/resolv.conf")
return resolvConf.dnsConfig.Load()
}
func (conf *resolverConfig) init() {
conf.dnsConfig.Store(dnsReadConfig("/etc/resolv.conf"))
conf.lastChecked = time.Now()
conf.ch = make(chan struct{}, 1)
}
func (conf *resolverConfig) tryUpdate(name string) {
conf.initOnce.Do(conf.init)
if conf.dnsConfig.Load().noReload {
return
}
if !conf.tryAcquireSema() {
return
}
defer conf.releaseSema()
now := time.Now()
if conf.lastChecked.After(now.Add(-5 * time.Second)) {
return
}
conf.lastChecked = now
if runtime.GOOS != "windows" {
var mtime time.Time
if fi, err := os.Stat(name); err == nil {
mtime = fi.ModTime()
}
if mtime.Equal(conf.dnsConfig.Load().mtime) {
return
}
}
dnsConf := dnsReadConfig(name)
conf.dnsConfig.Store(dnsConf)
}
func (conf *resolverConfig) tryAcquireSema() bool {
select {
case conf.ch <- struct{}{}:
return true
default:
return false
}
}
func (conf *resolverConfig) releaseSema() {
<-conf.ch
}
type dnsConfig struct {
servers []string
search []string
ndots int
timeout time.Duration
attempts int
rotate bool
unknownOpt bool
lookup []string
err error
mtime time.Time
soffset uint32
singleRequest bool
useTCP bool
trustAD bool
noReload bool
}
func (c *dnsConfig) serverOffset() uint32 {
if c.rotate {
return atomic.AddUint32(&c.soffset, 1) - 1 // return 0 to start
}
return 0
}
func (conf *dnsConfig) nameList(name string) []string {
l := len(name)
rooted := l > 0 && name[l-1] == '.'
if l > 254 || l == 254 && !rooted {
return nil
}
if rooted {
if avoidDNS(name) {
return nil
}
return []string{name}
}
hasNdots := strings.Count(name, ".") >= conf.ndots
name += "."
// l++
names := make([]string, 0, 1+len(conf.search))
if hasNdots && !avoidDNS(name) {
names = append(names, name)
}
for _, suffix := range conf.search {
fqdn := name + suffix
if !avoidDNS(fqdn) && len(fqdn) <= 254 {
names = append(names, fqdn)
}
}
if !hasNdots && !avoidDNS(name) {
names = append(names, name)
}
return names
}
func avoidDNS(name string) bool {
if name == "" {
return true
}
if name[len(name)-1] == '.' {
name = name[:len(name)-1]
}
return strings.HasSuffix(name, ".onion")
}

View File

@@ -0,0 +1,175 @@
//go:build !windows
package local
import (
"bufio"
"net"
"net/netip"
"os"
"strings"
"time"
_ "unsafe"
)
func dnsReadConfig(name string) *dnsConfig {
conf := &dnsConfig{
ndots: 1,
timeout: 5 * time.Second,
attempts: 2,
}
file, err := os.Open(name)
if err != nil {
conf.servers = defaultNS
conf.search = dnsDefaultSearch()
conf.err = err
return conf
}
defer file.Close()
fi, err := file.Stat()
if err == nil {
conf.mtime = fi.ModTime()
} else {
conf.servers = defaultNS
conf.search = dnsDefaultSearch()
conf.err = err
return conf
}
reader := bufio.NewReader(file)
var (
prefix []byte
line []byte
isPrefix bool
)
for {
line, isPrefix, err = reader.ReadLine()
if err != nil {
break
}
if isPrefix {
prefix = append(prefix, line...)
continue
} else if len(prefix) > 0 {
line = append(prefix, line...)
prefix = nil
}
if len(line) > 0 && (line[0] == ';' || line[0] == '#') {
continue
}
f := strings.Fields(string(line))
if len(f) < 1 {
continue
}
switch f[0] {
case "nameserver":
if len(f) > 1 && len(conf.servers) < 3 {
if _, err := netip.ParseAddr(f[1]); err == nil {
conf.servers = append(conf.servers, net.JoinHostPort(f[1], "53"))
}
}
case "domain":
if len(f) > 1 {
conf.search = []string{ensureRooted(f[1])}
}
case "search":
conf.search = make([]string, 0, len(f)-1)
for i := 1; i < len(f); i++ {
name := ensureRooted(f[i])
if name == "." {
continue
}
conf.search = append(conf.search, name)
}
case "options":
for _, s := range f[1:] {
switch {
case strings.HasPrefix(s, "ndots:"):
n, _, _ := dtoi(s[6:])
if n < 0 {
n = 0
} else if n > 15 {
n = 15
}
conf.ndots = n
case strings.HasPrefix(s, "timeout:"):
n, _, _ := dtoi(s[8:])
if n < 1 {
n = 1
}
conf.timeout = time.Duration(n) * time.Second
case strings.HasPrefix(s, "attempts:"):
n, _, _ := dtoi(s[9:])
if n < 1 {
n = 1
}
conf.attempts = n
case s == "rotate":
conf.rotate = true
case s == "single-request" || s == "single-request-reopen":
conf.singleRequest = true
case s == "use-vc" || s == "usevc" || s == "tcp":
conf.useTCP = true
case s == "trust-ad":
conf.trustAD = true
case s == "edns0":
case s == "no-reload":
conf.noReload = true
default:
conf.unknownOpt = true
}
}
case "lookup":
conf.lookup = f[1:]
default:
conf.unknownOpt = true
}
}
if len(conf.servers) == 0 {
conf.servers = defaultNS
}
if len(conf.search) == 0 {
conf.search = dnsDefaultSearch()
}
return conf
}
//go:linkname defaultNS net.defaultNS
var defaultNS []string
func dnsDefaultSearch() []string {
hn, err := os.Hostname()
if err != nil {
return nil
}
if i := strings.IndexRune(hn, '.'); i >= 0 && i < len(hn)-1 {
return []string{ensureRooted(hn[i+1:])}
}
return nil
}
func ensureRooted(s string) string {
if len(s) > 0 && s[len(s)-1] == '.' {
return s
}
return s + "."
}
const big = 0xFFFFFF
func dtoi(s string) (n int, i int, ok bool) {
n = 0
for i = 0; i < len(s) && '0' <= s[i] && s[i] <= '9'; i++ {
n = n*10 + int(s[i]-'0')
if n >= big {
return big, i, false
}
}
if i == 0 {
return 0, 0, false
}
return n, i, true
}

View File

@@ -0,0 +1,100 @@
package local
import (
"net"
"net/netip"
"os"
"syscall"
"time"
"unsafe"
"golang.org/x/sys/windows"
)
func dnsReadConfig(_ string) *dnsConfig {
conf := &dnsConfig{
ndots: 1,
timeout: 5 * time.Second,
attempts: 2,
}
defer func() {
if len(conf.servers) == 0 {
conf.servers = defaultNS
}
}()
aas, err := adapterAddresses()
if err != nil {
return nil
}
for _, aa := range aas {
// Only take interfaces whose OperStatus is IfOperStatusUp(0x01) into DNS configs.
if aa.OperStatus != windows.IfOperStatusUp {
continue
}
// Only take interfaces which have at least one gateway
if aa.FirstGatewayAddress == nil {
continue
}
for dns := aa.FirstDnsServerAddress; dns != nil; dns = dns.Next {
sa, err := dns.Address.Sockaddr.Sockaddr()
if err != nil {
continue
}
var ip netip.Addr
switch sa := sa.(type) {
case *syscall.SockaddrInet4:
ip = netip.AddrFrom4([4]byte{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]})
case *syscall.SockaddrInet6:
var addr16 [16]byte
copy(addr16[:], sa.Addr[:])
if addr16[0] == 0xfe && addr16[1] == 0xc0 {
// fec0/10 IPv6 addresses are site local anycast DNS
// addresses Microsoft sets by default if no other
// IPv6 DNS address is set. Site local anycast is
// deprecated since 2004, see
// https://datatracker.ietf.org/doc/html/rfc3879
continue
}
ip = netip.AddrFrom16(addr16)
default:
// Unexpected type.
continue
}
conf.servers = append(conf.servers, net.JoinHostPort(ip.String(), "53"))
}
}
return conf
}
//go:linkname defaultNS net.defaultNS
var defaultNS []string
func adapterAddresses() ([]*windows.IpAdapterAddresses, error) {
var b []byte
l := uint32(15000) // recommended initial size
for {
b = make([]byte, l)
const flags = windows.GAA_FLAG_INCLUDE_PREFIX | windows.GAA_FLAG_INCLUDE_GATEWAYS
err := windows.GetAdaptersAddresses(syscall.AF_UNSPEC, flags, 0, (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])), &l)
if err == nil {
if l == 0 {
return nil, nil
}
break
}
if err.(syscall.Errno) != syscall.ERROR_BUFFER_OVERFLOW {
return nil, os.NewSyscallError("getadaptersaddresses", err)
}
if l <= uint32(len(b)) {
return nil, os.NewSyscallError("getadaptersaddresses", err)
}
}
var aas []*windows.IpAdapterAddresses
for aa := (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])); aa != nil; aa = aa.Next {
aas = append(aas, aa)
}
return aas, nil
}

View File

@@ -0,0 +1,82 @@
package transport
import (
"context"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
mDNS "github.com/miekg/dns"
)
var _ adapter.DNSTransport = (*PredefinedTransport)(nil)
func RegisterPredefined(registry *dns.TransportRegistry) {
dns.RegisterTransport[option.PredefinedDNSServerOptions](registry, C.DNSTypePreDefined, NewPredefined)
}
type PredefinedTransport struct {
dns.TransportAdapter
responses []*predefinedResponse
}
type predefinedResponse struct {
questions []mDNS.Question
answer *mDNS.Msg
}
func NewPredefined(ctx context.Context, logger log.ContextLogger, tag string, options option.PredefinedDNSServerOptions) (adapter.DNSTransport, error) {
var responses []*predefinedResponse
for _, response := range options.Responses {
questions, msg, err := response.Build()
if err != nil {
return nil, err
}
responses = append(responses, &predefinedResponse{
questions: questions,
answer: msg,
})
}
if len(responses) == 0 {
return nil, E.New("empty predefined responses")
}
return &PredefinedTransport{
TransportAdapter: dns.NewTransportAdapter(C.DNSTypePreDefined, tag, nil),
responses: responses,
}, nil
}
func (t *PredefinedTransport) Reset() {
}
func (t *PredefinedTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
for _, response := range t.responses {
for _, question := range response.questions {
if func() bool {
if question.Name == "" && question.Qtype == mDNS.TypeNone {
return true
} else if question.Name == "" {
return common.Any(message.Question, func(it mDNS.Question) bool {
return it.Qtype == question.Qtype
})
} else if question.Qtype == mDNS.TypeNone {
return common.Any(message.Question, func(it mDNS.Question) bool {
return it.Name == question.Name
})
} else {
return common.Contains(message.Question, question)
}
}() {
copyAnswer := *response.answer
copyAnswer.Id = message.Id
return &copyAnswer, nil
}
}
}
return nil, dns.RCodeNameError
}

167
dns/transport/quic/http3.go Normal file
View File

@@ -0,0 +1,167 @@
package quic
import (
"bytes"
"context"
"io"
"net"
"net/http"
"net/url"
"strconv"
"github.com/sagernet/quic-go"
"github.com/sagernet/quic-go/http3"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/dns/transport"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
sHTTP "github.com/sagernet/sing/protocol/http"
mDNS "github.com/miekg/dns"
)
var _ adapter.DNSTransport = (*HTTP3Transport)(nil)
func RegisterHTTP3Transport(registry *dns.TransportRegistry) {
dns.RegisterTransport[option.RemoteHTTPSDNSServerOptions](registry, C.DNSTypeHTTP3, NewHTTP3)
}
type HTTP3Transport struct {
dns.TransportAdapter
logger logger.ContextLogger
dialer N.Dialer
destination *url.URL
headers http.Header
transport *http3.Transport
}
func NewHTTP3(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteHTTPSDNSServerOptions) (adapter.DNSTransport, error) {
transportDialer, err := dns.NewRemoteDialer(ctx, options.RemoteDNSServerOptions)
if err != nil {
return nil, err
}
tlsOptions := common.PtrValueOrDefault(options.TLS)
tlsOptions.Enabled = true
tlsConfig, err := tls.NewClient(ctx, options.Server, tlsOptions)
if err != nil {
return nil, err
}
stdConfig, err := tlsConfig.Config()
if err != nil {
return nil, err
}
headers := options.Headers.Build()
host := headers.Get("Host")
if host != "" {
headers.Del("Host")
} else {
if tlsConfig.ServerName() != "" {
host = tlsConfig.ServerName()
} else {
host = options.Server
}
}
destinationURL := url.URL{
Scheme: "HTTP3",
Host: host,
}
if destinationURL.Host == "" {
destinationURL.Host = options.Server
}
if options.ServerPort != 0 && options.ServerPort != 443 {
destinationURL.Host = net.JoinHostPort(destinationURL.Host, strconv.Itoa(int(options.ServerPort)))
}
path := options.Path
if path == "" {
path = "/dns-query"
}
err = sHTTP.URLSetPath(&destinationURL, path)
if err != nil {
return nil, err
}
serverAddr := options.ServerOptions.Build()
if serverAddr.Port == 0 {
serverAddr.Port = 443
}
return &HTTP3Transport{
TransportAdapter: dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeHTTP3, tag, options.RemoteDNSServerOptions),
logger: logger,
dialer: transportDialer,
destination: &destinationURL,
headers: headers,
transport: &http3.Transport{
Dial: func(ctx context.Context, addr string, tlsCfg *tls.STDConfig, cfg *quic.Config) (quic.EarlyConnection, error) {
destinationAddr := M.ParseSocksaddr(addr)
conn, dialErr := transportDialer.DialContext(ctx, N.NetworkUDP, destinationAddr)
if dialErr != nil {
return nil, dialErr
}
return quic.DialEarly(ctx, bufio.NewUnbindPacketConn(conn), conn.RemoteAddr(), tlsCfg, cfg)
},
TLSClientConfig: stdConfig,
},
}, nil
}
func (t *HTTP3Transport) Reset() {
t.transport.Close()
}
func (t *HTTP3Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
exMessage := *message
exMessage.Id = 0
exMessage.Compress = true
requestBuffer := buf.NewSize(1 + message.Len())
rawMessage, err := exMessage.PackBuffer(requestBuffer.FreeBytes())
if err != nil {
requestBuffer.Release()
return nil, err
}
request, err := http.NewRequestWithContext(ctx, http.MethodPost, t.destination.String(), bytes.NewReader(rawMessage))
if err != nil {
requestBuffer.Release()
return nil, err
}
request.Header = t.headers.Clone()
request.Header.Set("Content-Type", transport.MimeType)
request.Header.Set("Accept", transport.MimeType)
response, err := t.transport.RoundTrip(request)
requestBuffer.Release()
if err != nil {
return nil, err
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
return nil, E.New("unexpected status: ", response.Status)
}
var responseMessage mDNS.Msg
if response.ContentLength > 0 {
responseBuffer := buf.NewSize(int(response.ContentLength))
_, err = responseBuffer.ReadFullFrom(response.Body, int(response.ContentLength))
if err != nil {
return nil, err
}
err = responseMessage.Unpack(responseBuffer.Bytes())
responseBuffer.Release()
} else {
rawMessage, err = io.ReadAll(response.Body)
if err != nil {
return nil, err
}
err = responseMessage.Unpack(rawMessage)
}
if err != nil {
return nil, err
}
return &responseMessage, nil
}

174
dns/transport/quic/quic.go Normal file
View File

@@ -0,0 +1,174 @@
package quic
import (
"context"
"errors"
"sync"
"github.com/sagernet/quic-go"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/dns/transport"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
sQUIC "github.com/sagernet/sing-quic"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
mDNS "github.com/miekg/dns"
)
var _ adapter.DNSTransport = (*Transport)(nil)
func RegisterTransport(registry *dns.TransportRegistry) {
dns.RegisterTransport[option.RemoteTLSDNSServerOptions](registry, C.DNSTypeQUIC, NewQUIC)
}
type Transport struct {
dns.TransportAdapter
ctx context.Context
logger logger.ContextLogger
dialer N.Dialer
serverAddr M.Socksaddr
tlsConfig tls.Config
access sync.Mutex
connection quic.EarlyConnection
}
func NewQUIC(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteTLSDNSServerOptions) (adapter.DNSTransport, error) {
transportDialer, err := dns.NewRemoteDialer(ctx, options.RemoteDNSServerOptions)
if err != nil {
return nil, err
}
tlsOptions := common.PtrValueOrDefault(options.TLS)
tlsOptions.Enabled = true
tlsConfig, err := tls.NewClient(ctx, options.Server, tlsOptions)
if err != nil {
return nil, err
}
if len(tlsConfig.NextProtos()) == 0 {
tlsConfig.SetNextProtos([]string{"doq"})
}
serverAddr := options.ServerOptions.Build()
if serverAddr.Port == 0 {
serverAddr.Port = 853
}
return &Transport{
TransportAdapter: dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeQUIC, tag, options.RemoteDNSServerOptions),
ctx: ctx,
logger: logger,
dialer: transportDialer,
serverAddr: serverAddr,
tlsConfig: tlsConfig,
}, nil
}
func (t *Transport) Reset() {
t.access.Lock()
defer t.access.Unlock()
connection := t.connection
if connection != nil {
connection.CloseWithError(0, "")
}
}
func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
var (
conn quic.Connection
err error
response *mDNS.Msg
)
for i := 0; i < 2; i++ {
conn, err = t.openConnection()
if err != nil {
return nil, err
}
response, err = t.exchange(ctx, message, conn)
if err == nil {
return response, nil
} else if !isQUICRetryError(err) {
return nil, err
} else {
conn.CloseWithError(quic.ApplicationErrorCode(0), "")
continue
}
}
return nil, err
}
func (t *Transport) openConnection() (quic.EarlyConnection, error) {
connection := t.connection
if connection != nil && !common.Done(connection.Context()) {
return connection, nil
}
t.access.Lock()
defer t.access.Unlock()
connection = t.connection
if connection != nil && !common.Done(connection.Context()) {
return connection, nil
}
conn, err := t.dialer.DialContext(t.ctx, N.NetworkUDP, t.serverAddr)
if err != nil {
return nil, err
}
earlyConnection, err := sQUIC.DialEarly(
t.ctx,
bufio.NewUnbindPacketConn(conn),
t.serverAddr.UDPAddr(),
t.tlsConfig,
nil,
)
if err != nil {
return nil, err
}
t.connection = earlyConnection
return earlyConnection, nil
}
func (t *Transport) exchange(ctx context.Context, message *mDNS.Msg, conn quic.Connection) (*mDNS.Msg, error) {
stream, err := conn.OpenStreamSync(ctx)
if err != nil {
return nil, err
}
defer stream.Close()
defer stream.CancelRead(0)
err = transport.WriteMessage(stream, 0, message)
if err != nil {
return nil, err
}
return transport.ReadMessage(stream)
}
// https://github.com/AdguardTeam/dnsproxy/blob/fd1868577652c639cce3da00e12ca548f421baf1/upstream/upstream_quic.go#L394
func isQUICRetryError(err error) (ok bool) {
var qAppErr *quic.ApplicationError
if errors.As(err, &qAppErr) && qAppErr.ErrorCode == 0 {
return true
}
var qIdleErr *quic.IdleTimeoutError
if errors.As(err, &qIdleErr) {
return true
}
var resetErr *quic.StatelessResetError
if errors.As(err, &resetErr) {
return true
}
var qTransportError *quic.TransportError
if errors.As(err, &qTransportError) && qTransportError.ErrorCode == quic.NoError {
return true
}
if errors.Is(err, quic.Err0RTTRejected) {
return true
}
return false
}

99
dns/transport/tcp.go Normal file
View File

@@ -0,0 +1,99 @@
package transport
import (
"context"
"encoding/binary"
"io"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
mDNS "github.com/miekg/dns"
)
var _ adapter.DNSTransport = (*TCPTransport)(nil)
func RegisterTCP(registry *dns.TransportRegistry) {
dns.RegisterTransport[option.RemoteDNSServerOptions](registry, C.DNSTypeTCP, NewTCP)
}
type TCPTransport struct {
dns.TransportAdapter
dialer N.Dialer
serverAddr M.Socksaddr
}
func NewTCP(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteDNSServerOptions) (adapter.DNSTransport, error) {
transportDialer, err := dns.NewRemoteDialer(ctx, options)
if err != nil {
return nil, err
}
serverAddr := options.ServerOptions.Build()
if serverAddr.Port == 0 {
serverAddr.Port = 53
}
return &TCPTransport{
TransportAdapter: dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeTCP, tag, options),
dialer: transportDialer,
serverAddr: serverAddr,
}, nil
}
func (t *TCPTransport) Reset() {
}
func (t *TCPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
conn, err := t.dialer.DialContext(ctx, N.NetworkTCP, t.serverAddr)
if err != nil {
return nil, err
}
defer conn.Close()
err = WriteMessage(conn, 0, message)
if err != nil {
return nil, err
}
return ReadMessage(conn)
}
func ReadMessage(reader io.Reader) (*mDNS.Msg, error) {
var responseLen uint16
err := binary.Read(reader, binary.BigEndian, &responseLen)
if err != nil {
return nil, err
}
if responseLen < 10 {
return nil, mDNS.ErrShortRead
}
buffer := buf.NewSize(int(responseLen))
defer buffer.Release()
_, err = buffer.ReadFullFrom(reader, int(responseLen))
if err != nil {
return nil, err
}
var message mDNS.Msg
err = message.Unpack(buffer.Bytes())
return &message, err
}
func WriteMessage(writer io.Writer, messageId uint16, message *mDNS.Msg) error {
requestLen := message.Len()
buffer := buf.NewSize(3 + requestLen)
defer buffer.Release()
common.Must(binary.Write(buffer, binary.BigEndian, uint16(requestLen)))
exMessage := *message
exMessage.Id = messageId
exMessage.Compress = true
rawMessage, err := exMessage.PackBuffer(buffer.FreeBytes())
if err != nil {
return err
}
buffer.Truncate(2 + len(rawMessage))
return common.Error(writer.Write(buffer.Bytes()))
}

115
dns/transport/tls.go Normal file
View File

@@ -0,0 +1,115 @@
package transport
import (
"context"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list"
mDNS "github.com/miekg/dns"
)
var _ adapter.DNSTransport = (*TLSTransport)(nil)
func RegisterTLS(registry *dns.TransportRegistry) {
dns.RegisterTransport[option.RemoteTLSDNSServerOptions](registry, C.DNSTypeTLS, NewTLS)
}
type TLSTransport struct {
dns.TransportAdapter
logger logger.ContextLogger
dialer N.Dialer
serverAddr M.Socksaddr
tlsConfig tls.Config
access sync.Mutex
connections list.List[*tlsDNSConn]
}
type tlsDNSConn struct {
tls.Conn
queryId uint16
}
func NewTLS(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteTLSDNSServerOptions) (adapter.DNSTransport, error) {
transportDialer, err := dns.NewRemoteDialer(ctx, options.RemoteDNSServerOptions)
if err != nil {
return nil, err
}
tlsOptions := common.PtrValueOrDefault(options.TLS)
tlsOptions.Enabled = true
tlsConfig, err := tls.NewClient(ctx, options.Server, tlsOptions)
if err != nil {
return nil, err
}
serverAddr := options.ServerOptions.Build()
if serverAddr.Port == 0 {
serverAddr.Port = 853
}
return &TLSTransport{
TransportAdapter: dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeTLS, tag, options.RemoteDNSServerOptions),
logger: logger,
dialer: transportDialer,
serverAddr: serverAddr,
tlsConfig: tlsConfig,
}, nil
}
func (t *TLSTransport) Reset() {
t.access.Lock()
defer t.access.Unlock()
for connection := t.connections.Front(); connection != nil; connection = connection.Next() {
connection.Value.Close()
}
t.connections.Init()
}
func (t *TLSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
t.access.Lock()
conn := t.connections.PopFront()
t.access.Unlock()
if conn != nil {
response, err := t.exchange(message, conn)
if err == nil {
return response, nil
}
}
tcpConn, err := t.dialer.DialContext(ctx, N.NetworkTCP, t.serverAddr)
if err != nil {
return nil, err
}
tlsConn, err := tls.ClientHandshake(ctx, tcpConn, t.tlsConfig)
if err != nil {
tcpConn.Close()
return nil, err
}
return t.exchange(message, &tlsDNSConn{Conn: tlsConn})
}
func (t *TLSTransport) exchange(message *mDNS.Msg, conn *tlsDNSConn) (*mDNS.Msg, error) {
conn.queryId++
err := WriteMessage(conn, conn.queryId, message)
if err != nil {
conn.Close()
return nil, E.Cause(err, "write request")
}
response, err := ReadMessage(conn)
if err != nil {
conn.Close()
return nil, E.Cause(err, "read response")
}
t.access.Lock()
t.connections.PushBack(conn)
t.access.Unlock()
return response, nil
}

223
dns/transport/udp.go Normal file
View File

@@ -0,0 +1,223 @@
package transport
import (
"context"
"net"
"os"
"sync"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
mDNS "github.com/miekg/dns"
)
var _ adapter.DNSTransport = (*UDPTransport)(nil)
func RegisterUDP(registry *dns.TransportRegistry) {
dns.RegisterTransport[option.RemoteDNSServerOptions](registry, C.DNSTypeUDP, NewUDP)
}
type UDPTransport struct {
dns.TransportAdapter
logger logger.ContextLogger
dialer N.Dialer
serverAddr M.Socksaddr
udpSize int
tcpTransport *TCPTransport
access sync.Mutex
conn *dnsConnection
done chan struct{}
}
func NewUDP(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteDNSServerOptions) (adapter.DNSTransport, error) {
transportDialer, err := dns.NewRemoteDialer(ctx, options)
if err != nil {
return nil, err
}
serverAddr := options.ServerOptions.Build()
if serverAddr.Port == 0 {
serverAddr.Port = 53
}
return NewUDPRaw(logger, dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeUDP, tag, options), transportDialer, serverAddr), nil
}
func NewUDPRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialer N.Dialer, serverAddr M.Socksaddr) *UDPTransport {
return &UDPTransport{
TransportAdapter: adapter,
logger: logger,
dialer: dialer,
serverAddr: serverAddr,
udpSize: 512,
tcpTransport: &TCPTransport{
dialer: dialer,
serverAddr: serverAddr,
},
done: make(chan struct{}),
}
}
func (t *UDPTransport) Reset() {
t.access.Lock()
defer t.access.Unlock()
close(t.done)
t.done = make(chan struct{})
}
func (t *UDPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
response, err := t.exchange(ctx, message)
if err != nil {
return nil, err
}
if response.Truncated {
t.logger.InfoContext(ctx, "response truncated, retrying with TCP")
return t.tcpTransport.Exchange(ctx, message)
}
return response, nil
}
func (t *UDPTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
conn, err := t.open(ctx)
if err != nil {
return nil, err
}
if edns0Opt := message.IsEdns0(); edns0Opt != nil {
if udpSize := int(edns0Opt.UDPSize()); udpSize > t.udpSize {
t.udpSize = udpSize
}
}
buffer := buf.NewSize(1 + message.Len())
defer buffer.Release()
exMessage := *message
exMessage.Compress = true
messageId := message.Id
callback := &dnsCallback{
done: make(chan struct{}),
}
conn.access.Lock()
conn.queryId++
exMessage.Id = conn.queryId
conn.callbacks[exMessage.Id] = callback
conn.access.Unlock()
defer func() {
conn.access.Lock()
delete(conn.callbacks, messageId)
conn.access.Unlock()
callback.access.Lock()
select {
case <-callback.done:
default:
close(callback.done)
}
callback.access.Unlock()
}()
rawMessage, err := exMessage.PackBuffer(buffer.FreeBytes())
if err != nil {
return nil, err
}
_, err = conn.Write(rawMessage)
if err != nil {
conn.Close(err)
return nil, err
}
select {
case <-callback.done:
callback.message.Id = messageId
return callback.message, nil
case <-conn.done:
return nil, conn.err
case <-t.done:
return nil, os.ErrClosed
case <-ctx.Done():
conn.Close(ctx.Err())
return nil, ctx.Err()
}
}
func (t *UDPTransport) open(ctx context.Context) (*dnsConnection, error) {
t.access.Lock()
defer t.access.Unlock()
if t.conn != nil {
select {
case <-t.conn.done:
default:
return t.conn, nil
}
}
conn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr)
if err != nil {
return nil, err
}
dnsConn := &dnsConnection{
Conn: conn,
done: make(chan struct{}),
callbacks: make(map[uint16]*dnsCallback),
}
go t.recvLoop(dnsConn)
t.conn = dnsConn
return dnsConn, nil
}
func (t *UDPTransport) recvLoop(conn *dnsConnection) {
for {
buffer := buf.NewSize(t.udpSize)
_, err := buffer.ReadOnceFrom(conn)
if err != nil {
buffer.Release()
conn.Close(err)
return
}
var message mDNS.Msg
err = message.Unpack(buffer.Bytes())
buffer.Release()
if err != nil {
conn.Close(err)
return
}
conn.access.RLock()
callback, loaded := conn.callbacks[message.Id]
conn.access.RUnlock()
if !loaded {
continue
}
callback.access.Lock()
select {
case <-callback.done:
default:
callback.message = &message
close(callback.done)
}
callback.access.Unlock()
}
}
type dnsConnection struct {
net.Conn
access sync.RWMutex
done chan struct{}
closeOnce sync.Once
err error
queryId uint16
callbacks map[uint16]*dnsCallback
}
func (c *dnsConnection) Close(err error) {
c.closeOnce.Do(func() {
close(c.done)
c.err = err
})
c.Conn.Close()
}
type dnsCallback struct {
access sync.Mutex
message *mDNS.Msg
done chan struct{}
}

78
dns/transport_adapter.go Normal file
View File

@@ -0,0 +1,78 @@
package dns
import (
"net/netip"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
)
var _ adapter.LegacyDNSTransport = (*TransportAdapter)(nil)
type TransportAdapter struct {
transportType string
transportTag string
dependencies []string
strategy C.DomainStrategy
clientSubnet netip.Prefix
}
func NewTransportAdapter(transportType string, transportTag string, dependencies []string) TransportAdapter {
return TransportAdapter{
transportType: transportType,
transportTag: transportTag,
dependencies: dependencies,
}
}
func NewTransportAdapterWithLocalOptions(transportType string, transportTag string, localOptions option.LocalDNSServerOptions) TransportAdapter {
var dependencies []string
if localOptions.DomainResolver != nil && localOptions.DomainResolver.Server != "" {
dependencies = append(dependencies, localOptions.DomainResolver.Server)
}
return TransportAdapter{
transportType: transportType,
transportTag: transportTag,
dependencies: dependencies,
strategy: C.DomainStrategy(localOptions.LegacyStrategy),
clientSubnet: localOptions.LegacyClientSubnet,
}
}
func NewTransportAdapterWithRemoteOptions(transportType string, transportTag string, remoteOptions option.RemoteDNSServerOptions) TransportAdapter {
var dependencies []string
if remoteOptions.DomainResolver != nil && remoteOptions.DomainResolver.Server != "" {
dependencies = append(dependencies, remoteOptions.DomainResolver.Server)
}
if remoteOptions.LegacyAddressResolver != "" {
dependencies = append(dependencies, remoteOptions.LegacyAddressResolver)
}
return TransportAdapter{
transportType: transportType,
transportTag: transportTag,
dependencies: dependencies,
strategy: C.DomainStrategy(remoteOptions.LegacyStrategy),
clientSubnet: remoteOptions.LegacyClientSubnet,
}
}
func (a *TransportAdapter) Type() string {
return a.transportType
}
func (a *TransportAdapter) Tag() string {
return a.transportTag
}
func (a *TransportAdapter) Dependencies() []string {
return a.dependencies
}
func (a *TransportAdapter) LegacyStrategy() C.DomainStrategy {
return a.strategy
}
func (a *TransportAdapter) LegacyClientSubnet() netip.Prefix {
return a.clientSubnet
}

103
dns/transport_dialer.go Normal file
View File

@@ -0,0 +1,103 @@
package dns
import (
"context"
"net"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
)
func NewLocalDialer(ctx context.Context, options option.LocalDNSServerOptions) (N.Dialer, error) {
if options.LegacyDefaultDialer {
return dialer.NewDefaultOutbound(ctx), nil
} else {
return dialer.NewWithOptions(dialer.Options{
Context: ctx,
Options: options.DialerOptions,
DirectResolver: true,
})
}
}
func NewRemoteDialer(ctx context.Context, options option.RemoteDNSServerOptions) (N.Dialer, error) {
if options.LegacyDefaultDialer {
transportDialer := dialer.NewDefaultOutbound(ctx)
if options.LegacyAddressResolver != "" {
transport := service.FromContext[adapter.DNSTransportManager](ctx)
resolverTransport, loaded := transport.Transport(options.LegacyAddressResolver)
if !loaded {
return nil, E.New("address resolver not found: ", options.LegacyAddressResolver)
}
transportDialer = newTransportDialer(transportDialer, service.FromContext[adapter.DNSRouter](ctx), resolverTransport, C.DomainStrategy(options.LegacyAddressStrategy), time.Duration(options.LegacyAddressFallbackDelay))
} else if options.ServerIsDomain() {
return nil, E.New("missing address resolver for server: ", options.Server)
}
return transportDialer, nil
} else {
return dialer.NewWithOptions(dialer.Options{
Context: ctx,
Options: options.DialerOptions,
RemoteIsDomain: options.ServerIsDomain(),
DirectResolver: true,
})
}
}
type legacyTransportDialer struct {
dialer N.Dialer
dnsRouter adapter.DNSRouter
transport adapter.DNSTransport
strategy C.DomainStrategy
fallbackDelay time.Duration
}
func newTransportDialer(dialer N.Dialer, dnsRouter adapter.DNSRouter, transport adapter.DNSTransport, strategy C.DomainStrategy, fallbackDelay time.Duration) *legacyTransportDialer {
return &legacyTransportDialer{
dialer,
dnsRouter,
transport,
strategy,
fallbackDelay,
}
}
func (d *legacyTransportDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
if destination.IsIP() {
return d.dialer.DialContext(ctx, network, destination)
}
addresses, err := d.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{
Transport: d.transport,
Strategy: d.strategy,
})
if err != nil {
return nil, err
}
return N.DialParallel(ctx, d.dialer, network, destination, addresses, d.strategy == C.DomainStrategyPreferIPv6, d.fallbackDelay)
}
func (d *legacyTransportDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
if destination.IsIP() {
return d.dialer.ListenPacket(ctx, destination)
}
addresses, err := d.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{
Transport: d.transport,
Strategy: d.strategy,
})
if err != nil {
return nil, err
}
conn, _, err := N.ListenSerial(ctx, d.dialer, destination, addresses)
return conn, err
}
func (d *legacyTransportDialer) Upstream() any {
return d.dialer
}

288
dns/transport_manager.go Normal file
View File

@@ -0,0 +1,288 @@
package dns
import (
"context"
"io"
"os"
"strings"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/taskmonitor"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
)
var _ adapter.DNSTransportManager = (*TransportManager)(nil)
type TransportManager struct {
logger log.ContextLogger
registry adapter.DNSTransportRegistry
outbound adapter.OutboundManager
defaultTag string
access sync.RWMutex
started bool
stage adapter.StartStage
transports []adapter.DNSTransport
transportByTag map[string]adapter.DNSTransport
dependByTag map[string][]string
defaultTransport adapter.DNSTransport
defaultTransportFallback adapter.DNSTransport
fakeIPTransport adapter.FakeIPTransport
}
func NewTransportManager(logger logger.ContextLogger, registry adapter.DNSTransportRegistry, outbound adapter.OutboundManager, defaultTag string) *TransportManager {
return &TransportManager{
logger: logger,
registry: registry,
outbound: outbound,
defaultTag: defaultTag,
transportByTag: make(map[string]adapter.DNSTransport),
dependByTag: make(map[string][]string),
}
}
func (m *TransportManager) Initialize(defaultTransportFallback adapter.DNSTransport) {
m.defaultTransportFallback = defaultTransportFallback
}
func (m *TransportManager) Start(stage adapter.StartStage) error {
m.access.Lock()
if m.started && m.stage >= stage {
panic("already started")
}
m.started = true
m.stage = stage
outbounds := m.transports
m.access.Unlock()
if stage == adapter.StartStateStart {
return m.startTransports(m.transports)
} else {
for _, outbound := range outbounds {
err := adapter.LegacyStart(outbound, stage)
if err != nil {
return E.Cause(err, stage, " dns/", outbound.Type(), "[", outbound.Tag(), "]")
}
}
}
return nil
}
func (m *TransportManager) startTransports(transports []adapter.DNSTransport) error {
monitor := taskmonitor.New(m.logger, C.StartTimeout)
started := make(map[string]bool)
for {
canContinue := false
startOne:
for _, transportToStart := range transports {
transportTag := transportToStart.Tag()
if started[transportTag] {
continue
}
dependencies := transportToStart.Dependencies()
for _, dependency := range dependencies {
if !started[dependency] {
continue startOne
}
}
started[transportTag] = true
canContinue = true
if starter, isStarter := transportToStart.(adapter.Lifecycle); isStarter {
monitor.Start("start dns/", transportToStart.Type(), "[", transportTag, "]")
err := starter.Start(adapter.StartStateStart)
monitor.Finish()
if err != nil {
return E.Cause(err, "start dns/", transportToStart.Type(), "[", transportTag, "]")
}
}
}
if len(started) == len(transports) {
break
}
if canContinue {
continue
}
currentTransport := common.Find(transports, func(it adapter.DNSTransport) bool {
return !started[it.Tag()]
})
var lintTransport func(oTree []string, oCurrent adapter.DNSTransport) error
lintTransport = func(oTree []string, oCurrent adapter.DNSTransport) error {
problemTransportTag := common.Find(oCurrent.Dependencies(), func(it string) bool {
return !started[it]
})
if common.Contains(oTree, problemTransportTag) {
return E.New("circular server dependency: ", strings.Join(oTree, " -> "), " -> ", problemTransportTag)
}
m.access.Lock()
problemTransport := m.transportByTag[problemTransportTag]
m.access.Unlock()
if problemTransport == nil {
return E.New("dependency[", problemTransportTag, "] not found for server[", oCurrent.Tag(), "]")
}
return lintTransport(append(oTree, problemTransportTag), problemTransport)
}
return lintTransport([]string{currentTransport.Tag()}, currentTransport)
}
return nil
}
func (m *TransportManager) Close() error {
monitor := taskmonitor.New(m.logger, C.StopTimeout)
m.access.Lock()
if !m.started {
m.access.Unlock()
return nil
}
m.started = false
transports := m.transports
m.transports = nil
m.access.Unlock()
var err error
for _, transport := range transports {
if closer, isCloser := transport.(io.Closer); isCloser {
monitor.Start("close server/", transport.Type(), "[", transport.Tag(), "]")
err = E.Append(err, closer.Close(), func(err error) error {
return E.Cause(err, "close server/", transport.Type(), "[", transport.Tag(), "]")
})
monitor.Finish()
}
}
return nil
}
func (m *TransportManager) Transports() []adapter.DNSTransport {
m.access.RLock()
defer m.access.RUnlock()
return m.transports
}
func (m *TransportManager) Transport(tag string) (adapter.DNSTransport, bool) {
m.access.RLock()
outbound, found := m.transportByTag[tag]
m.access.RUnlock()
return outbound, found
}
func (m *TransportManager) Default() adapter.DNSTransport {
m.access.RLock()
defer m.access.RUnlock()
if m.defaultTransport != nil {
return m.defaultTransport
} else {
return m.defaultTransportFallback
}
}
func (m *TransportManager) FakeIP() adapter.FakeIPTransport {
m.access.RLock()
defer m.access.RUnlock()
return m.fakeIPTransport
}
func (m *TransportManager) Remove(tag string) error {
m.access.Lock()
defer m.access.Unlock()
transport, found := m.transportByTag[tag]
if !found {
return os.ErrInvalid
}
delete(m.transportByTag, tag)
index := common.Index(m.transports, func(it adapter.DNSTransport) bool {
return it == transport
})
if index == -1 {
panic("invalid inbound index")
}
m.transports = append(m.transports[:index], m.transports[index+1:]...)
started := m.started
if m.defaultTransport == transport {
if len(m.transports) > 0 {
nextTransport := m.transports[0]
if nextTransport.Type() != C.DNSTypeFakeIP {
return E.New("default server cannot be fakeip")
}
m.defaultTransport = nextTransport
m.logger.Info("updated default server to ", m.defaultTransport.Tag())
} else {
m.defaultTransport = nil
}
}
dependBy := m.dependByTag[tag]
if len(dependBy) > 0 {
return E.New("server[", tag, "] is depended by ", strings.Join(dependBy, ", "))
}
dependencies := transport.Dependencies()
for _, dependency := range dependencies {
if len(m.dependByTag[dependency]) == 1 {
delete(m.dependByTag, dependency)
} else {
m.dependByTag[dependency] = common.Filter(m.dependByTag[dependency], func(it string) bool {
return it != tag
})
}
}
if started {
transport.Reset()
}
return nil
}
func (m *TransportManager) Create(ctx context.Context, logger log.ContextLogger, tag string, transportType string, options any) error {
if tag == "" {
return os.ErrInvalid
}
transport, err := m.registry.CreateDNSTransport(ctx, logger, tag, transportType, options)
if err != nil {
return err
}
m.access.Lock()
defer m.access.Unlock()
if m.started {
for _, stage := range adapter.ListStartStages {
err = adapter.LegacyStart(transport, stage)
if err != nil {
return E.Cause(err, stage, " dns/", transport.Type(), "[", transport.Tag(), "]")
}
}
}
if existsTransport, loaded := m.transportByTag[tag]; loaded {
if m.started {
err = common.Close(existsTransport)
if err != nil {
return E.Cause(err, "close dns/", existsTransport.Type(), "[", existsTransport.Tag(), "]")
}
}
existsIndex := common.Index(m.transports, func(it adapter.DNSTransport) bool {
return it == existsTransport
})
if existsIndex == -1 {
panic("invalid inbound index")
}
m.transports = append(m.transports[:existsIndex], m.transports[existsIndex+1:]...)
}
m.transports = append(m.transports, transport)
m.transportByTag[tag] = transport
dependencies := transport.Dependencies()
for _, dependency := range dependencies {
m.dependByTag[dependency] = append(m.dependByTag[dependency], tag)
}
if tag == m.defaultTag || (m.defaultTag == "" && m.defaultTransport == nil) {
if transport.Type() == C.DNSTypeFakeIP {
return E.New("default server cannot be fakeip")
}
m.defaultTransport = transport
if m.started {
m.logger.Info("updated default server to ", transport.Tag())
}
}
if transport.Type() == C.DNSTypeFakeIP {
if m.fakeIPTransport != nil {
return E.New("multiple fakeip server are not supported")
}
m.fakeIPTransport = transport.(adapter.FakeIPTransport)
}
return nil
}

72
dns/transport_registry.go Normal file
View File

@@ -0,0 +1,72 @@
package dns
import (
"context"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
)
type TransportConstructorFunc[T any] func(ctx context.Context, logger log.ContextLogger, tag string, options T) (adapter.DNSTransport, error)
func RegisterTransport[Options any](registry *TransportRegistry, transportType string, constructor TransportConstructorFunc[Options]) {
registry.register(transportType, func() any {
return new(Options)
}, func(ctx context.Context, logger log.ContextLogger, tag string, rawOptions any) (adapter.DNSTransport, error) {
var options *Options
if rawOptions != nil {
options = rawOptions.(*Options)
}
return constructor(ctx, logger, tag, common.PtrValueOrDefault(options))
})
}
var _ adapter.DNSTransportRegistry = (*TransportRegistry)(nil)
type (
optionsConstructorFunc func() any
constructorFunc func(ctx context.Context, logger log.ContextLogger, tag string, options any) (adapter.DNSTransport, error)
)
type TransportRegistry struct {
access sync.Mutex
optionsType map[string]optionsConstructorFunc
constructors map[string]constructorFunc
}
func NewTransportRegistry() *TransportRegistry {
return &TransportRegistry{
optionsType: make(map[string]optionsConstructorFunc),
constructors: make(map[string]constructorFunc),
}
}
func (r *TransportRegistry) CreateOptions(transportType string) (any, bool) {
r.access.Lock()
defer r.access.Unlock()
optionsConstructor, loaded := r.optionsType[transportType]
if !loaded {
return nil, false
}
return optionsConstructor(), true
}
func (r *TransportRegistry) CreateDNSTransport(ctx context.Context, logger log.ContextLogger, tag string, transportType string, options any) (adapter.DNSTransport, error) {
r.access.Lock()
defer r.access.Unlock()
constructor, loaded := r.constructors[transportType]
if !loaded {
return nil, E.New("transport type not found: " + transportType)
}
return constructor(ctx, logger, tag, options)
}
func (r *TransportRegistry) register(transportType string, optionsConstructor optionsConstructorFunc, constructor constructorFunc) {
r.access.Lock()
defer r.access.Unlock()
r.optionsType[transportType] = optionsConstructor
r.constructors[transportType] = constructor
}

View File

@@ -2,6 +2,190 @@
icon: material/alert-decagram
---
#### 1.12.0-alpha.3
* Fixes and improvements
### 1.11.1
* Fixes and improvements
#### 1.12.0-alpha.2
* Update quic-go to v0.49.0
* Fixes and improvements
#### 1.12.0-alpha.1
* Refactor DNS servers **1**
* Add domain resolver options**2**
* Add TLS fragment route options **3**
* Add certificate options **4**
**1**:
DNS servers are refactored for better performance and scalability.
See [DNS server](/configuration/dns/server/).
For migration, see [Migrate to new DNS server formats](/migration/#migrate-to-new-dns-servers).
Compatibility for old formats will be removed in sing-box 1.14.0.
**2**:
Legacy `outbound` DNS rules are deprecated
and can be replaced by the new `domain_resolver` option.
See [Dial Fields](/configuration/shared/dial/#domain_resolver) and
[Route](/configuration/route/#default_domain_resolver).
For migration, see [Migrate outbound DNS rule items to domain resolver](/migration/#migrate-outbound-dns-rule-items-to-domain-resolver).
**3**:
The new TLS fragment route options allow you to fragment TLS handshakes to bypass firewalls.
This feature is intended to circumvent simple firewalls based on **plaintext packet matching**, and should not be used to circumvent real censorship.
Since it is not designed for performance, it should not be applied to all connections, but only to server names that are known to be blocked.
See [Route Action](/configuration/route/rule_action/#tls_fragment).
**4**:
New certificate options allow you to manage the default list of trusted X509 CA certificates.
For the system certificate list, fixed Go not reading Android trusted certificates correctly.
You can also use the Mozilla Included List instead, or add trusted certificates yourself.
See [Certificate](/configuration/certificate/).
### 1.11.0
Important changes since 1.10:
* Introducing rule actions **1**
* Improve tun compatibility **3**
* Merge route options to route actions **4**
* Add `network_type`, `network_is_expensive` and `network_is_constrainted` rule items **5**
* Add multi network dialing **6**
* Add `cache_capacity` DNS option **7**
* Add `override_address` and `override_port` route options **8**
* Upgrade WireGuard outbound to endpoint **9**
* Add UDP GSO support for WireGuard
* Make GSO adaptive **10**
* Add UDP timeout route option **11**
* Add more masquerade options for hysteria2 **12**
* Add `rule-set merge` command
* Add port hopping support for Hysteria2 **13**
* Hysteria2 `ignore_client_bandwidth` behavior update **14**
**1**:
New rule actions replace legacy inbound fields and special outbound fields,
and can be used for pre-matching **2**.
See [Rule](/configuration/route/rule/),
[Rule Action](/configuration/route/rule_action/),
[DNS Rule](/configuration/dns/rule/) and
[DNS Rule Action](/configuration/dns/rule_action/).
For migration, see
[Migrate legacy special outbounds to rule actions](/migration/#migrate-legacy-special-outbounds-to-rule-actions),
[Migrate legacy inbound fields to rule actions](/migration/#migrate-legacy-inbound-fields-to-rule-actions)
and [Migrate legacy DNS route options to rule actions](/migration/#migrate-legacy-dns-route-options-to-rule-actions).
**2**:
Similar to Surge's pre-matching.
Specifically, new rule actions allow you to reject connections with
TCP RST (for TCP connections) and ICMP port unreachable (for UDP packets)
before connection established to improve tun's compatibility.
See [Rule Action](/configuration/route/rule_action/).
**3**:
When `gvisor` tun stack is enabled, even if the request passes routing,
if the outbound connection establishment fails,
the connection still does not need to be established and a TCP RST is replied.
**4**:
Route options in DNS route actions will no longer be considered deprecated,
see [DNS Route Action](/configuration/dns/rule_action/).
Also, now `udp_disable_domain_unmapping` and `udp_connect` can also be configured in route action,
see [Route Action](/configuration/route/rule_action/).
**5**:
When using in graphical clients, new routing rule items allow you to match on
network type (WIFI, cellular, etc.), whether the network is expensive, and whether Low Data Mode is enabled.
See [Route Rule](/configuration/route/rule/), [DNS Route Rule](/configuration/dns/rule/)
and [Headless Rule](/configuration/rule-set/headless-rule/).
**6**:
Similar to Surge's strategy.
New options allow you to connect using multiple network interfaces,
prefer or only use one type of interface,
and configure a timeout to fallback to other interfaces.
See [Dial Fields](/configuration/shared/dial/#network_strategy),
[Rule Action](/configuration/route/rule_action/#network_strategy)
and [Route](/configuration/route/#default_network_strategy).
**7**:
See [DNS](/configuration/dns/#cache_capacity).
**8**:
See [Rule Action](/configuration/route/#override_address) and
[Migrate destination override fields to route options](/migration/#migrate-destination-override-fields-to-route-options).
**9**:
The new WireGuard endpoint combines inbound and outbound capabilities,
and the old outbound will be removed in sing-box 1.13.0.
See [Endpoint](/configuration/endpoint/), [WireGuard Endpoint](/configuration/endpoint/wireguard/)
and [Migrate WireGuard outbound fields to route options](/migration/#migrate-wireguard-outbound-to-endpoint).
**10**:
For WireGuard outbound and endpoint, GSO will be automatically enabled when available,
see [WireGuard Outbound](/configuration/outbound/wireguard/#gso).
For TUN, GSO has been removed,
see [Deprecated](/deprecated/#gso-option-in-tun).
**11**:
See [Rule Action](/configuration/route/rule_action/#udp_timeout).
**12**:
See [Hysteria2](/configuration/inbound/hysteria2/#masquerade).
**13**:
See [Hysteria2](/configuration/outbound/hysteria2/).
**14**:
When `up_mbps` and `down_mbps` are set, `ignore_client_bandwidth` instead denies clients from using BBR CC.
### 1.10.7
* Fixes and improvements
#### 1.11.0-beta.20
* Hysteria2 `ignore_client_bandwidth` behavior update **1**

View File

@@ -0,0 +1,54 @@
---
icon: material/new-box
---
!!! question "Since sing-box 1.12.0"
# Certificate
### Structure
```json
{
"store": "",
"certificate": [],
"certificate_path": [],
"certificate_directory_path": []
}
```
!!! note ""
You can ignore the JSON Array [] tag when the content is only one item
### Fields
#### store
The default X509 trusted CA certificate list.
| Type | Description |
|--------------------|---------------------------------------------------------------------------------------------------------------|
| `system` (default) | System trusted CA certificates |
| `mozilla` | [Mozilla Included List](https://wiki.mozilla.org/CA/Included_Certificates) with China CA certificates removed |
| `none` | Empty list |
#### certificate
The certificate line array to trust, in PEM format.
#### certificate_path
!!! note ""
Will be automatically reloaded if file modified.
The paths to certificates to trust, in PEM format.
#### certificate_directory_path
!!! note ""
Will be automatically reloaded if file modified.
The directory path to search for certificates to trust,in PEM format.

View File

@@ -1,3 +1,11 @@
---
icon: material/delete-clock
---
!!! failure "Deprecated in sing-box 1.12.0"
Legacy fake-ip configuration is deprecated and will be removed in sing-box 1.14.0, check [Migration](/migration/#migrate-to-new-dns-servers).
### Structure
```json

View File

@@ -1,3 +1,11 @@
---
icon: material/delete-clock
---
!!! failure "已在 sing-box 1.12.0 废弃"
旧的 fake-ip 配置已废弃且将在 sing-box 1.14.0 中被移除,参阅 [迁移指南](/migration/#migrate-to-new-dns-servers)。
### 结构
```json

View File

@@ -49,8 +49,6 @@ Default domain strategy for resolving the domain names.
One of `prefer_ipv4` `prefer_ipv6` `ipv4_only` `ipv6_only`.
Take no effect if `server.strategy` is set.
#### disable_cache
Disable dns cache.

Some files were not shown because too many files have changed in this diff Show More