Compare commits

..

96 Commits

Author SHA1 Message Date
Bruce Wayne
db8b42c107 fix(ocm): map prolite to lower pro weight 2026-04-12 01:44:42 +08:00
世界
bf9e390cf4 fix(ccm,ocm): allow reverse connector credentials to serve local requests
Connector mode credentials were unconditionally blocked from local use
by unavailableError(), despite having a working forwardHTTPClient. Also
set credentialDialer in OCM connector mode to prevent nil panic in
WebSocket handler.
2026-03-28 18:46:10 +08:00
世界
471c9c3b47 fix(ccm): make refresh failure fail fast 2026-03-28 18:07:31 +08:00
世界
e7478ce947 fix(ccm): mark credential unavailable on refresh failure, handle poll 401
tryRefreshCredentials now returns error and calls markCredentialsUnavailable
when lock acquisition or file write permission fails. getAccessToken propagates
the error instead of silently returning the expired token. pollUsage handles
401 by attempting auth recovery and marking unavailable on failure. All
credential error paths now use Error log level instead of Debug. Startup
checks expired tokens eagerly via tryRefreshCredentials.
2026-03-28 16:58:52 +08:00
世界
cf11e0e74a Reuse SDK JSON types in ccm and ocm 2026-03-28 11:20:24 +08:00
世界
a87a2b0e2b ocm: log think level 2026-03-28 02:00:27 +08:00
世界
d9c298af1e fix(ccm,ocm): remove upstream rate limit header forwarding, compute locally
Strip all upstream rate limit headers and compute unified-status,
representative-claim, reset times, and surpassed-threshold from
aggregated utilization data. Never expose per-account overage or
fallback information. Remove per-credential unified state storage,
snapshot aggregation, and WebSocket synthetic rate limit events.
2026-03-26 23:42:14 +08:00
世界
cd5007ffbb fix(ccm,ocm): track external credential poll failures and re-poll on user connect
External credentials now properly increment consecutivePollFailures on
poll errors (matching defaultCredential behavior), marking the credential
as temporarily blocked. When a user with external_credential connects
and the credential is not usable, a forced poll is triggered to check
recovery.
2026-03-26 22:16:02 +08:00
世界
e49d0685ad ccm: Fix token refresh 2026-03-26 21:37:12 +08:00
世界
e1c9667319 Revert "fix(ocm): send rate limit status immediately on WebSocket connect"
This reverts commit 6721dff48a.
2026-03-26 16:03:46 +08:00
世界
6f45ea9c27 Revert "fix(ocm): inject synthetic rate limits inline when intercepting upstream events"
This reverts commit ca60f93184.
2026-03-26 16:03:39 +08:00
世界
ca60f93184 fix(ocm): inject synthetic rate limits inline when intercepting upstream events
The initial synthetic event from 6721dff48 arrives before the Codex CLI's
response stream reader is active. Additionally, the shouldEmit gate in
updateStateFromHeaders suppresses the async replacement when values haven't
changed. Send aggregated status inline in proxyWebSocketUpstreamToClient
so the client receives it at the exact protocol position it expects.
2026-03-26 15:42:51 +08:00
世界
1774d98793 fix(ccm,ocm): restore fixed usage polling
Remove the poll_interval config surface from CCM and OCM so both services fall back to the built-in 1h polling cadence again. Also isolate CCM credential lock mocking per test instance so the access-token refresh tests stop racing on shared global state.
2026-03-26 14:01:24 +08:00
世界
6721dff48a fix(ocm): send rate limit status immediately on WebSocket connect
Codex CLI ignores x-codex-* headers in the WebSocket upgrade response
and only reads rate limits from in-band codex.rate_limits events.
Previously, the first synthetic event was gated by firstRealRequest
(after warmup), delaying usage display. Now send aggregated status
right after subscribing, so the client sees rate limits before the
first turn begins.
2026-03-26 12:33:53 +08:00
世界
4592164a7a Align CCM and OCM rate limits 2026-03-24 22:06:10 +08:00
世界
92c8f4c5c8 fix(ccm): align default credential with Claude Code 2026-03-24 21:15:46 +08:00
世界
441c98890d feat(ccm): add claude_directory option to read Claude Code config 2026-03-22 16:59:36 +08:00
世界
bb2169bc17 release: Fix install_go.sh 2026-03-22 06:35:34 +08:00
世界
d996b60f44 ccm/ocm: Add CLAUDE.md 2026-03-22 06:23:13 +08:00
世界
084a6f1302 fix(ccm): align OAuth token refresh with Claude Code v2.1.81
After re-login with newer Claude Code (v2.1.75+), CCM refresh requests
returned persistent 429s. Root cause: CCM omitted the `scope` parameter
that the server now requires for tokens with `user:file_upload` scope.

Changes to fully match Claude Code's OAuth behavior:

- Add `scope` parameter to token refresh request body
- Parse `scope` from refresh response and store back
- Add `subscriptionType`/`rateLimitTier` to credential struct to
  preserve Claude Code's profile state on write-back
- Change credential file write to read-modify-write, preserving
  other top-level JSON keys (matches Claude Code's BP6 pattern)
- Same for macOS keychain write path
- Increase token expiry buffer from 1 min to 5 min (matching CC's
  isOAuthTokenExpired with 300s buffer)
- Add cross-process mkdir-based file lock compatible with Claude
  Code's proper-lockfile protocol (~/.claude.lock)
- Add post-failure recovery: re-read credentials from disk after
  refresh failure in case another process succeeded
- Add 401/403 "OAuth token has been revoked" recovery in proxy
  handler: reload credentials and retry once
2026-03-22 06:02:55 +08:00
世界
0950783479 fix(ccm,ocm): exclude unusable credentials from status aggregation
computeAggregatedUtilization used isAvailable() which only checks
permanent unavailability, so credentials rejected by upstream 400
still had their planWeight included in the total, inflating reported
capacity and diluting utilization.
2026-03-21 11:42:49 +08:00
世界
f172a575b7 fix(ccm): log assigned credential for each distinct model per session 2026-03-21 11:07:43 +08:00
世界
29b901a8b3 fix(ccm): robust account UUID injection and session ID validation
Replace bytes.Replace-based UUID injection with proper JSON
unmarshal/re-marshal through map[string]json.RawMessage — the old
approach silently failed when the body used non-canonical JSON escaping.

Return 500 when metadata.user_id is present but in an unrecognized
format, instead of silently passing through with an empty session ID.
2026-03-21 11:00:05 +08:00
世界
53f832330d fix(ccm): adapt to Claude Code v2.1.78 metadata format, separate state from credentials
Claude Code v2.1.78 changed metadata.user_id from a template literal
(`user_${id}_account_${uuid}_session_${sid}`) to a JSON-encoded object
(`JSON.stringify({device_id, account_uuid, session_id})`), breaking
session ID extraction via `_session_` substring match.

- Fix extractCCMSessionID to try JSON parse first, fallback to legacy
- Remove subscriptionType/rateLimitTier/isMax from oauthCredentials
  (profile state does not belong in auth credentials)
- Add state_path option for persisting profile state across restarts
- Parse account.uuid from /api/oauth/profile response
- Inject account_uuid into forwarded requests when client sends it empty
  (happens when using ANTHROPIC_AUTH_TOKEN instead of Claude AI OAuth)
2026-03-21 10:45:24 +08:00
世界
99d9e06dd0 fix(ccm,ocm): handle upstream 400 by marking external credentials rejected and polling default credentials
External credentials returning 400 are marked unavailable for pollInterval
duration; status stream/poll success clears the rejection early. Default
credentials trigger a stale poll to let the usage API detect account issues
without causing 429 storms.
2026-03-21 10:31:17 +08:00
世界
608b7e7fa2 fix(ccm,ocm): stop cascading 429 retry storm on token refresh
When the access token expires and refreshToken() gets 429, getAccessToken()
returned the error but left credentials unchanged with no cooldown. Every
subsequent request re-attempted the refresh, creating a burst that overwhelmed
the token endpoint.

- refreshToken() now returns Retry-After duration from 429 response headers
  (-1 when no header present, meaning permanently blocked)
- getAccessToken() caches the 429 and blocks further refresh attempts until
  Retry-After expires (or permanently if no header)
- reloadCredentials() clears the block when new credentials are loaded from file
- Remove go pollUsage() on upstream errors (unrelated to usage state)
2026-03-21 09:31:05 +08:00
世界
7acba74755 fix(ccm): forward 529 upstream overloaded response transparently 2026-03-18 15:53:36 +08:00
世界
2fe1e37b17 fix(ccm,ocm): add missing isFirstUpdate to external credential usage logging 2026-03-18 01:00:55 +08:00
世界
3bcfdd5455 fix(ccm,ocm): remove external context from pollUsage/pollIfStale
pollUsage(ctx) accepted caller context, and service_status.go passed
r.Context() which gets canceled on client disconnect or service shutdown.
This caused incrementPollFailures → interruptConnections on transient
cancellations. Each implementation now uses its own persistent context:
defaultCredential uses serviceContext, externalCredential uses
getReverseContext().
2026-03-18 00:54:01 +08:00
世界
b119d08764 fix(ccm,ocm): add usage logging to status stream, remove redundant isFirstUpdate
connectStatusStream updated credential state silently — no log on
first frame or value changes. After restart, external credentials
get usage via stream before any request, so pollIfStale skips them
and no usage log ever appears.

Add the same change-detection log to connectStatusStream. Also remove
redundant isFirstUpdate guards from pollUsage and updateStateFromHeaders:
when old values are zero, any non-zero new value already satisfies the
integer-percent comparison.
2026-03-17 22:37:38 +08:00
世界
6b8838d323 fix(ccm,ocm): restart status stream when receiver gets reverse session
statusStreamLoop started on start() before any reverse session existed,
got a non-retryable error, and exited permanently. Restart it when
setReverseSession transitions receiver credentials to available.
2026-03-17 22:08:30 +08:00
世界
b3429ef1f3 fix(ocm): strip non-active rate-limit headers from forwarded responses 2026-03-17 22:01:30 +08:00
世界
a2d6cf9715 fix(ocm): defer initial websocket rate-limit push 2026-03-17 21:14:14 +08:00
世界
99e19e7033 service: stop retrying fatal watch status errors 2026-03-17 20:47:52 +08:00
世界
969defeef0 ccm,ocm: validate external status response fields 2026-03-17 20:17:56 +08:00
世界
f57eff33bb ccm,ocm: fix WS push lifecycle, deduplicate rate_limits, stabilize reset aggregation
- Add closed channel to webSocketSession for push goroutine shutdown
  on connection close, preventing session leak and Service.Close() hang
- Intercept upstream codex.rate_limits events instead of forwarding;
  push goroutine is now the sole sender of aggregated rate_limits
- Emit status updates on reset-only changes (fiveHourResetChanged,
  weeklyResetChanged) so push goroutine picks up reset advances
- Skip expired resets (hours <= 0) in aggregation instead of clamping
  to now, avoiding unstable reset_at output and spurious status ticks
- Delete stale upstream reset headers when aggregated reset is zero
- Hardcode "codex" identifier everywhere: handleWebSocketRateLimitsEvent,
  buildSyntheticRateLimitsEvent, rewriteResponseHeaders
- Remove rewriteWebSocketRateLimits, rewriteWebSocketRateLimitWindow,
  identifier tracking (TypedValue), and unused imports
2026-03-17 20:00:54 +08:00
世界
0a054b9aa4 ccm,ocm: propagate reset times, rewrite headers for all users, add WS status push
- Add fiveHourReset/weeklyReset to statusPayload and aggregatedStatus
  with weight-averaged reset time aggregation across credential pools
- Rewrite response headers (utilization + reset times) for all users,
  not just external credential users
- Rewrite WebSocket rate_limits events for all users with aggregated values
- Add proactive WebSocket status push: synthetic codex.rate_limits events
  sent on connection start and on status changes via statusObserver
- Remove one-shot stream forward compatibility (statusStreamHeader,
  restoreLastUpdatedIfUnchanged, oneShot detection)
2026-03-17 18:13:54 +08:00
世界
7d15d9d282 ccm: emit status updates for plan-weight-only changes 2026-03-17 16:46:54 +08:00
世界
cf2d677043 ocm: emit status updates for plan-weight-only changes 2026-03-17 16:32:03 +08:00
世界
4a6a211775 ccm,ocm: reduce status emission noise, simplify emit-guard pattern
Guard updateStateFromHeaders emission with value-change detection to
avoid unnecessary computeAggregatedUtilization scans on every proxied
response. Replace statusAggregateStateLocked two-value return with
comparable statusSnapshot struct. Define statusPayload type for the
status wire format, replacing anonymous structs and map literals.
2026-03-17 16:10:59 +08:00
世界
f84832a369 Add stream watch endpoint 2026-03-17 16:03:35 +08:00
世界
f3c3022094 ccm,ocm: fix session race, track fallback sessions, skip warmup logging
Fix data race in selectCredential where concurrent goroutines could
overwrite each other's session entries by adding compare-and-delete
and store-if-absent patterns with retry loop. Track sessions for
fallback strategy so isNew is reported correctly. Skip logging and
usage tracking for websocket warmup requests (generate: false).
2026-03-16 22:10:10 +08:00
世界
2dd093a32e ccm,ocm: fix data race, remove dead code, clean up inefficiencies 2026-03-15 21:20:29 +08:00
世界
14ade76956 ccm,ocm: remove dead code, fix timer leaks, eliminate redundant lookups
- Remove unused onBecameUnusable field from CCM credential structs
  (OCM wires it for WebSocket interruption; CCM has no equivalent)
- Replace time.After with time.NewTimer in doHTTPWithRetry and
  connectorLoop to avoid timer leaks on context cancellation
- Pass already-resolved provider to rewriteResponseHeadersForExternalUser
  instead of re-resolving via credentialForUser
- Hoist reverseYamuxConfig to package-level var (immutable, no need to
  allocate on every call)
2026-03-15 20:42:41 +08:00
世界
9e3ec30d72 docs: fix ccm and ocm credential docs 2026-03-15 20:41:47 +08:00
世界
763e0af010 docs: complete ccm/ocm documentation for 1.14.0 features 2026-03-15 18:49:00 +08:00
世界
656b09d1be ccm,ocm: never treat external usage endpoint failures as over-limit 2026-03-15 18:48:53 +08:00
世界
8e9c61e624 ccm,ocm: normalize legacy fields into credentials at init, remove dual code path 2026-03-15 18:48:53 +08:00
世界
bc6e72408d ccm,ocm: block API key headers from being forwarded upstream 2026-03-15 18:48:52 +08:00
世界
56af7313b2 ccm,ocm: don't treat usage API 429 as account over-limit
The usage API itself has rate limits. A 429 from it means "poll less
frequently", not that the account exceeded its usage quota. Previously
incrementPollFailures() was called, marking the credential unusable and
interrupting in-flight connections.

Now: parse Retry-After, store as usageAPIRetryDelay, and retry after
that delay. The credential stays usable and relies on passive header
updates for usage data in the meantime.
2026-03-15 18:48:52 +08:00
世界
6878ad0d35 ccm,ocm: fix naming and error-handling convention violations
- Rename credential interface to Credential (exported), cred to credential
- Rename mutex/saveMutex to access/saveAccess per go-syntax.md
- Fix abbreviations: reverseHttpClient, allCreds, credOpt, extCred,
  credDialer, reverseCredDialer, portStr
- Replace errors.Is(http.ErrServerClosed) with E.IsClosed
- Add E.IsClosedOrCanceled guard before streaming write error logs
2026-03-15 18:48:51 +08:00
世界
04bd63b455 ccm,ocm: reorganize files and improve naming conventions
Split credential_state.go (1500+ lines) into credential.go,
credential_default.go, credential_provider.go, credential_builder.go.

Split service.go (900+ lines) into service.go, service_handler.go,
service_status.go.

Rename credential.go to credential_oauth.go to avoid name conflict
with the credential interface.

Apply naming fixes: accessMutex→access, stateMutex→stateAccess,
sessionMutex→sessionAccess, webSocketMutex→webSocketAccess,
httpTransport()→httpClient(), httpClient field→forwardHTTPClient,
weeklyWindowDuration→weeklyWindowHours.
2026-03-15 18:48:51 +08:00
世界
51d564c9ff ccm,ocm: merge fallback into balancer strategy, use hyphenated constant names
Merge the fallback credential type into balancer as a strategy
(C.BalancerStrategyFallback). Replace raw string literals with
C.BalancerStrategyXxx constants and switch to hyphens (least-used,
round-robin) per project convention.
2026-03-15 18:48:50 +08:00
世界
4d8baf7175 ccm: fix nil pointer in pollUsage for connector-mode credentials
Connector-mode credentials (URL + reverse: true) never assigned
httpClient, causing a nil dereference when pollUsage accessed
httpClient.Transport.

Also extract poll request logic into doPollUsageRequest to try
reverse transport first (single attempt), then fall back to
forward transport with retries if the reverse session disconnects.
2026-03-15 18:48:50 +08:00
世界
d1e5426bc8 ccm,ocm: add exponential backoff with cap for poll retry
Replace flat 1-minute poll retry interval with exponential backoff
(1m → 2m → 4m → 5m cap). Suppress error logs after reaching the cap.
2026-03-15 18:48:50 +08:00
世界
4d907bc49d ccm,ocm: allow URL-based credentials to accept reverse connections
Previously, findReceiverCredential required baseURL == reverseProxyBaseURL,
so only credentials with no URL could accept incoming reverse connections.
Now credentials with a normal URL also accept reverse connections, preferring
the reverse session when active and falling back to the direct URL when not.
2026-03-15 18:48:49 +08:00
世界
2c907bef2c Fix scoped rebalance interrupts 2026-03-15 18:48:49 +08:00
世界
d2300353fd Propagate request context to upstream requests 2026-03-15 18:48:49 +08:00
世界
f871113832 ccm,ocm: add balancer session rebalancing with per-credential interrupt
When a sticky session's credential utilization exceeds the least-used
credential by a weight-adjusted threshold, force reassign all sessions
on that credential and cancel in-flight requests scoped to the balancer.

Threshold formula: effective = rebalance_threshold / planWeight, so a
config value of 20 triggers at 2% delta for Max 20x (w=10), 4% for
Max 5x (w=5), and 20% for Pro (w=1).
2026-03-15 18:48:49 +08:00
世界
b97b9d9cfd ccm,ocm: add request ID context to HTTP request logging 2026-03-15 18:48:48 +08:00
世界
badeeb91fe service/ocm: add default OpenAI-Beta header and log websocket error body
The upstream OpenAI WebSocket endpoint requires the
OpenAI-Beta: responses_websockets=2026-02-06 header. Set it
automatically when the client doesn't provide it.

Also capture and log the response body on non-429 WebSocket
handshake failures to surface the actual error from upstream.
2026-03-15 18:48:48 +08:00
世界
f4aaf33bf2 ccm,ocm: strip reverse proxy headers from upstream responses 2026-03-15 18:48:48 +08:00
世界
8fe8e238b3 service/ocm: unify websocket logging with HTTP request logging 2026-03-15 18:48:47 +08:00
世界
6f433937ba ccm,ocm: auto-detect plan weight for external credentials via status endpoint 2026-03-15 18:48:47 +08:00
世界
80d5432654 service/ccm: update oauth token URL and remove unnecessary Accept header 2026-03-15 18:48:46 +08:00
世界
8984b45ded ccm,ocm: improve balancer least_used with plan-weighted scoring and reset urgency
Scale remaining capacity by plan weight (Pro=1, Max 5x=5, Max 20x=10
for CCM; Plus=1, Pro=10 for OCM) so higher-tier accounts contribute
proportionally more. Factor in weekly reset proximity so credentials
about to reset are preferred ("use it or lose it").

Auto-detect plan weight from subscriptionType + rateLimitTier (CCM)
or plan_type (OCM). Fetch /api/oauth/profile when rateLimitTier is
missing from the credential file. External credentials accept a
manual plan_weight option.
2026-03-15 18:48:46 +08:00
世界
25a9e4ce59 service/ocm: only log new credential assignments and add websocket logging 2026-03-15 18:48:46 +08:00
世界
615a7e05b4 service/ccm: only log new credential assignments and show context window in model 2026-03-15 18:48:46 +08:00
世界
1628272507 ccm,ocm: mark credentials unusable on usage poll failure and trigger poll on upstream error 2026-03-15 18:48:46 +08:00
世界
ee65b375cb service/ccm: allow extended context (1m) for all credentials
1m context is now available to all subscribers and no longer
consumes Extra Usage.
2026-03-15 18:48:45 +08:00
世界
a09174a9a2 service/ccm: reject fast-mode external credentials 2026-03-15 18:48:45 +08:00
世界
ce543a935f ccm,ocm: fix reserveWeekly default and remove dead reserve fields 2026-03-15 18:48:45 +08:00
世界
7f93c76b1a ccm,ocm: add limit options and fix aggregated utilization scaling
Add limit_5h and limit_weekly options as alternatives to reserve_5h
and reserve_weekly for capping credential utilization. The two are
mutually exclusive per window.

Fix computeAggregatedUtilization to scale per-credential utilization
relative to each credential's cap before averaging, so external users
see correct available capacity regardless of per-credential caps.

Fix pickLeastUsed to compare remaining capacity (cap - utilization)
instead of raw utilization, ensuring fair comparison across credentials
with different caps.
2026-03-15 18:48:44 +08:00
世界
df6e47f5f1 ocm: preserve websocket rate limit event fields 2026-03-15 18:48:44 +08:00
世界
1993da3735 ocm: rewrite codex.rate_limits WebSocket events for external users
The HTTP path rewrites utilization headers for external users via
rewriteResponseHeadersForExternalUser to show aggregated values.
The WebSocket upgrade headers were also rewritten, but in-band
codex.rate_limits events were forwarded unmodified, leaking
per-credential utilization to external users.
2026-03-15 18:48:43 +08:00
世界
22376472d0 ccm,ocm: fix passive usage update for WebSocket connections
WebSocket 101 upgrade responses do not include utilization headers
(confirmed via codex CLI source). Rate limit data is delivered
exclusively through in-band events (codex.rate_limits and error
events with status 429).

Previously, updateStateFromHeaders unconditionally bumped lastUpdated
even when no utilization headers were found, which suppressed polling
and left credential utilization permanently stale during WebSocket
sessions.

- Only bump lastUpdated when actual utilization data is parsed
- Parse in-band codex.rate_limits events to update credential state
- Detect in-band 429 error events to markRateLimited
- Fix WebSocket 429 retry to update old credential state before retry
2026-03-15 18:48:43 +08:00
世界
74bf20d349 ccm,ocm: fix reverse session shutdown race 2026-03-15 18:48:43 +08:00
世界
ff8585f7c6 ccm,ocm: block utilization decrease within same rate-limit window
updateStateFromHeaders unconditionally applied header utilization
values even when they were lower than the current state, causing
poll-sourced values to be overwritten by stale header values.

Parse reset timestamps before utilization and only allow decreases
when the reset timestamp changes (indicating a new rate-limit
window). Also add math.Ceil to CCM external credential for
consistency with default credential.
2026-03-15 18:48:42 +08:00
世界
4d5108fe7f ccm,ocm: fix connector-side bufio data loss in reverse proxy
connectorConnect() creates a bufio.NewReader to read the HTTP 101
upgrade response, but then passes the raw conn to yamux.Server().
If TCP coalesces the 101 response with initial yamux frames, the
bufio reader over-reads into its buffer and those bytes are lost
to yamux, causing session failure.

Wrap the bufio.Reader and raw conn into a bufferedConn so yamux
reads through the buffer first.
2026-03-15 18:48:42 +08:00
世界
3b177df05e ccm,ocm: fix data race on reverseContext/reverseCancel
InterfaceUpdated() writes reverseContext and reverseCancel without
synchronization while connectorLoop/connectorConnect goroutines
read them concurrently. close() also accesses reverseCancel without
a lock.

Fix by extending reverseAccess mutex to protect these fields:
- Add getReverseContext()/resetReverseContext() methods
- Pass context as parameter to connectorConnect
- Merge close() into a single lock acquisition
- Use resetReverseContext() in InterfaceUpdated()
2026-03-15 18:48:42 +08:00
世界
1824881719 ccm,ocm: reset connector backoff after successful connection
The consecutiveFailures counter in connectorLoop never resets,
causing backoff to permanently cap at 30-45s even after a
connection that served successfully for hours.

Reset the counter when connectorConnect ran for at least one
minute, indicating a successful session rather than a transient
dial/handshake failure.
2026-03-15 18:48:41 +08:00
世界
02a1409e9a ccm,ocm: unify HTTP request retry with fast retry and exponential backoff 2026-03-15 18:48:41 +08:00
世界
af94ea9089 Fix reverse external credential handling 2026-03-15 18:48:41 +08:00
世界
970951f369 ccm,ocm: add reverse proxy support for external credentials
Allow two CCM/OCM instances to share credentials when only one has a
public IP, using yamux-multiplexed reverse connections.

Three credential modes:
- Normal: URL set, reverse=false — standard HTTP proxy
- Receiver: URL empty — waits for incoming reverse connection
- Connector: URL set, reverse=true — dials out to establish connection

Extend InterfaceUpdated to services so network changes trigger
reverse connection reconnection.
2026-03-15 18:48:40 +08:00
世界
15f3619995 ccm,ocm: strip reverse proxy headers before forwarding to upstream 2026-03-15 18:48:40 +08:00
世界
b96ab4fef9 ccm,ocm,ssmapi: fix HTTP/2 over TLS with h2c handler
aTLS.NewListener returns *LazyConn, not *tls.Conn, so Go's
http.Server cannot detect TLS via type assertion and falls back
to HTTP/1.x. When ALPN negotiates h2, the client sends HTTP/2
frames that the server fails to parse, causing HTTP 520 errors
behind Cloudflare.

Wrap HTTP handlers with h2c.NewHandler to intercept the HTTP/2
client preface and dispatch to http2.Server.ServeConn, consistent
with DERP, v2rayhttp, naive, and v2raygrpclite services.
2026-03-15 18:48:40 +08:00
世界
6829f91a06 ccm,ocm: check credential file writability before token refresh
Refuse to refresh tokens when the credential file is not writable,
preventing server-side invalidation of the old refresh token that
would make the credential permanently unusable after restart.
2026-03-15 18:48:40 +08:00
世界
8e5811a8c7 ccm,ocm: watch credential_path and allow delayed credentials 2026-03-15 18:48:40 +08:00
世界
da8ff6f578 ccm/ocm: Add external credential support for cross-instance usage sharing
Extract credential interface from *defaultCredential to support both
default (OAuth) and external (remote proxy) credential types. External
credentials proxy requests to a remote ccm/ocm instance with bearer
token auth, poll a /status endpoint for utilization, and parse
aggregated rate limit headers from responses.

Add allow_external_usage user flag to control whether balancer/fallback
providers may select external credentials. Add status endpoint
(/ccm/v1/status, /ocm/v1/status) returning averaged utilization across
eligible credentials. Rewrite response rate limit headers for external
users with aggregated values.
2026-03-15 18:48:39 +08:00
世界
2801bce815 ccm/ocm: Add multi-credential support with balancer and fallback strategies 2026-03-15 18:48:39 +08:00
世界
a11cd1e0c6 Bump version 2026-03-15 17:57:54 +08:00
世界
bd0fb83d2d cronet-go: Update chromium to 145.0.7632.159 2026-03-15 17:57:54 +08:00
世界
9462b1deeb documentation: Update descriptions for neighbor rules 2026-03-15 17:57:53 +08:00
世界
44d1c86b1b Add macOS support for MAC and hostname rule items 2026-03-15 17:57:53 +08:00
世界
f802668915 Add Android support for MAC and hostname rule items 2026-03-15 17:57:53 +08:00
世界
4d217b7481 Add MAC and hostname rule items 2026-03-15 17:57:53 +08:00
415 changed files with 24390 additions and 34188 deletions

View File

@@ -4,7 +4,6 @@
--license GPL-3.0-or-later --license GPL-3.0-or-later
--description "The universal proxy platform." --description "The universal proxy platform."
--url "https://sing-box.sagernet.org/" --url "https://sing-box.sagernet.org/"
--vendor SagerNet
--maintainer "nekohasekai <contact-git@sekai.icu>" --maintainer "nekohasekai <contact-git@sekai.icu>"
--deb-field "Bug: https://github.com/SagerNet/sing-box/issues" --deb-field "Bug: https://github.com/SagerNet/sing-box/issues"
--no-deb-generate-changes --no-deb-generate-changes

View File

@@ -1 +1 @@
335e5bef5d88fc4474c9a70b865561f45a67de83 ea7cd33752aed62603775af3df946c1b83f4b0b3

View File

@@ -1,33 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail
branches=$(git branch -r --contains HEAD)
if echo "$branches" | grep -q 'origin/stable'; then
track=stable
elif echo "$branches" | grep -q 'origin/testing'; then
track=testing
elif echo "$branches" | grep -q 'origin/oldstable'; then
track=oldstable
else
echo "ERROR: HEAD is not on any known release branch (stable/testing/oldstable)" >&2
exit 1
fi
if [[ "$track" == "stable" ]]; then
tag=$(git describe --tags --exact-match HEAD 2>/dev/null || true)
if [[ -n "$tag" && "$tag" == *"-"* ]]; then
track=beta
fi
fi
case "$track" in
stable) name=sing-box; docker_tag=latest ;;
beta) name=sing-box-beta; docker_tag=latest-beta ;;
testing) name=sing-box-testing; docker_tag=latest-testing ;;
oldstable) name=sing-box-oldstable; docker_tag=latest-oldstable ;;
esac
echo "track=${track} name=${name} docker_tag=${docker_tag}" >&2
echo "TRACK=${track}" >> "$GITHUB_ENV"
echo "NAME=${name}" >> "$GITHUB_ENV"
echo "DOCKER_TAG=${docker_tag}" >> "$GITHUB_ENV"

View File

@@ -19,6 +19,7 @@ env:
jobs: jobs:
build_binary: build_binary:
name: Build binary name: Build binary
if: github.event_name != 'release' || github.event.release.target_commitish != 'oldstable'
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
fail-fast: true fail-fast: true
@@ -259,13 +260,13 @@ jobs:
fi fi
echo "ref=$ref" echo "ref=$ref"
echo "ref=$ref" >> $GITHUB_OUTPUT echo "ref=$ref" >> $GITHUB_OUTPUT
- name: Checkout if [[ $ref == *"-"* ]]; then
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5 latest=latest-beta
with: else
ref: ${{ steps.ref.outputs.ref }} latest=latest
fetch-depth: 0 fi
- name: Detect track echo "latest=$latest"
run: bash .github/detect_track.sh echo "latest=$latest" >> $GITHUB_OUTPUT
- name: Download digests - name: Download digests
uses: actions/download-artifact@v5 uses: actions/download-artifact@v5
with: with:
@@ -285,11 +286,11 @@ jobs:
working-directory: /tmp/digests working-directory: /tmp/digests
run: | run: |
docker buildx imagetools create \ docker buildx imagetools create \
-t "${{ env.REGISTRY_IMAGE }}:${{ env.DOCKER_TAG }}" \ -t "${{ env.REGISTRY_IMAGE }}:${{ steps.ref.outputs.latest }}" \
-t "${{ env.REGISTRY_IMAGE }}:${{ steps.ref.outputs.ref }}" \ -t "${{ env.REGISTRY_IMAGE }}:${{ steps.ref.outputs.ref }}" \
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *) $(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
- name: Inspect image - name: Inspect image
if: github.event_name != 'push' if: github.event_name != 'push'
run: | run: |
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ env.DOCKER_TAG }} docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.ref.outputs.latest }}
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.ref.outputs.ref }} docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.ref.outputs.ref }}

View File

@@ -11,6 +11,11 @@ on:
description: "Version name" description: "Version name"
required: true required: true
type: string type: string
forceBeta:
description: "Force beta"
required: false
type: boolean
default: false
release: release:
types: types:
- published - published
@@ -18,6 +23,7 @@ on:
jobs: jobs:
calculate_version: calculate_version:
name: Calculate version name: Calculate version
if: github.event_name != 'release' || github.event.release.target_commitish != 'oldstable'
runs-on: ubuntu-latest runs-on: ubuntu-latest
outputs: outputs:
version: ${{ steps.outputs.outputs.version }} version: ${{ steps.outputs.outputs.version }}
@@ -162,8 +168,14 @@ jobs:
- name: Set mtime - name: Set mtime
run: |- run: |-
TZ=UTC touch -t '197001010000' dist/sing-box TZ=UTC touch -t '197001010000' dist/sing-box
- name: Detect track - name: Set name
run: bash .github/detect_track.sh if: (! contains(needs.calculate_version.outputs.version, '-')) && !inputs.forceBeta
run: |-
echo "NAME=sing-box" >> "$GITHUB_ENV"
- name: Set beta name
if: contains(needs.calculate_version.outputs.version, '-') || inputs.forceBeta
run: |-
echo "NAME=sing-box-beta" >> "$GITHUB_ENV"
- name: Set version - name: Set version
run: |- run: |-
PKG_VERSION="${{ needs.calculate_version.outputs.version }}" PKG_VERSION="${{ needs.calculate_version.outputs.version }}"

4
.gitignore vendored
View File

@@ -18,6 +18,6 @@
.DS_Store .DS_Store
/config.d/ /config.d/
/venv/ /venv/
CLAUDE.md /CLAUDE.md
AGENTS.md /AGENTS.md
/.claude/ /.claude/

View File

@@ -1,21 +0,0 @@
package certificate
type Adapter struct {
providerType string
providerTag string
}
func NewAdapter(providerType string, providerTag string) Adapter {
return Adapter{
providerType: providerType,
providerTag: providerTag,
}
}
func (a *Adapter) Type() string {
return a.providerType
}
func (a *Adapter) Tag() string {
return a.providerTag
}

View File

@@ -1,158 +0,0 @@
package certificate
import (
"context"
"os"
"sync"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/taskmonitor"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
)
var _ adapter.CertificateProviderManager = (*Manager)(nil)
type Manager struct {
logger log.ContextLogger
registry adapter.CertificateProviderRegistry
access sync.Mutex
started bool
stage adapter.StartStage
providers []adapter.CertificateProviderService
providerByTag map[string]adapter.CertificateProviderService
}
func NewManager(logger log.ContextLogger, registry adapter.CertificateProviderRegistry) *Manager {
return &Manager{
logger: logger,
registry: registry,
providerByTag: make(map[string]adapter.CertificateProviderService),
}
}
func (m *Manager) Start(stage adapter.StartStage) error {
m.access.Lock()
if m.started && m.stage >= stage {
panic("already started")
}
m.started = true
m.stage = stage
providers := m.providers
m.access.Unlock()
for _, provider := range providers {
name := "certificate-provider/" + provider.Type() + "[" + provider.Tag() + "]"
m.logger.Trace(stage, " ", name)
startTime := time.Now()
err := adapter.LegacyStart(provider, stage)
if err != nil {
return E.Cause(err, stage, " ", name)
}
m.logger.Trace(stage, " ", name, " completed (", F.Seconds(time.Since(startTime).Seconds()), "s)")
}
return nil
}
func (m *Manager) Close() error {
m.access.Lock()
defer m.access.Unlock()
if !m.started {
return nil
}
m.started = false
providers := m.providers
m.providers = nil
monitor := taskmonitor.New(m.logger, C.StopTimeout)
var err error
for _, provider := range providers {
name := "certificate-provider/" + provider.Type() + "[" + provider.Tag() + "]"
m.logger.Trace("close ", name)
startTime := time.Now()
monitor.Start("close ", name)
err = E.Append(err, provider.Close(), func(err error) error {
return E.Cause(err, "close ", name)
})
monitor.Finish()
m.logger.Trace("close ", name, " completed (", F.Seconds(time.Since(startTime).Seconds()), "s)")
}
return err
}
func (m *Manager) CertificateProviders() []adapter.CertificateProviderService {
m.access.Lock()
defer m.access.Unlock()
return m.providers
}
func (m *Manager) Get(tag string) (adapter.CertificateProviderService, bool) {
m.access.Lock()
provider, found := m.providerByTag[tag]
m.access.Unlock()
return provider, found
}
func (m *Manager) Remove(tag string) error {
m.access.Lock()
provider, found := m.providerByTag[tag]
if !found {
m.access.Unlock()
return os.ErrInvalid
}
delete(m.providerByTag, tag)
index := common.Index(m.providers, func(it adapter.CertificateProviderService) bool {
return it == provider
})
if index == -1 {
panic("invalid certificate provider index")
}
m.providers = append(m.providers[:index], m.providers[index+1:]...)
started := m.started
m.access.Unlock()
if started {
return provider.Close()
}
return nil
}
func (m *Manager) Create(ctx context.Context, logger log.ContextLogger, tag string, providerType string, options any) error {
provider, err := m.registry.Create(ctx, logger, tag, providerType, options)
if err != nil {
return err
}
m.access.Lock()
defer m.access.Unlock()
if m.started {
name := "certificate-provider/" + provider.Type() + "[" + provider.Tag() + "]"
for _, stage := range adapter.ListStartStages {
m.logger.Trace(stage, " ", name)
startTime := time.Now()
err = adapter.LegacyStart(provider, stage)
if err != nil {
return E.Cause(err, stage, " ", name)
}
m.logger.Trace(stage, " ", name, " completed (", F.Seconds(time.Since(startTime).Seconds()), "s)")
}
}
if existsProvider, loaded := m.providerByTag[tag]; loaded {
if m.started {
err = existsProvider.Close()
if err != nil {
return E.Cause(err, "close certificate-provider/", existsProvider.Type(), "[", existsProvider.Tag(), "]")
}
}
existsIndex := common.Index(m.providers, func(it adapter.CertificateProviderService) bool {
return it == existsProvider
})
if existsIndex == -1 {
panic("invalid certificate provider index")
}
m.providers = append(m.providers[:existsIndex], m.providers[existsIndex+1:]...)
}
m.providers = append(m.providers, provider)
m.providerByTag[tag] = provider
return nil
}

View File

@@ -1,72 +0,0 @@
package certificate
import (
"context"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
)
type ConstructorFunc[T any] func(ctx context.Context, logger log.ContextLogger, tag string, options T) (adapter.CertificateProviderService, error)
func Register[Options any](registry *Registry, providerType string, constructor ConstructorFunc[Options]) {
registry.register(providerType, func() any {
return new(Options)
}, func(ctx context.Context, logger log.ContextLogger, tag string, rawOptions any) (adapter.CertificateProviderService, error) {
var options *Options
if rawOptions != nil {
options = rawOptions.(*Options)
}
return constructor(ctx, logger, tag, common.PtrValueOrDefault(options))
})
}
var _ adapter.CertificateProviderRegistry = (*Registry)(nil)
type (
optionsConstructorFunc func() any
constructorFunc func(ctx context.Context, logger log.ContextLogger, tag string, options any) (adapter.CertificateProviderService, error)
)
type Registry struct {
access sync.Mutex
optionsType map[string]optionsConstructorFunc
constructor map[string]constructorFunc
}
func NewRegistry() *Registry {
return &Registry{
optionsType: make(map[string]optionsConstructorFunc),
constructor: make(map[string]constructorFunc),
}
}
func (m *Registry) CreateOptions(providerType string) (any, bool) {
m.access.Lock()
defer m.access.Unlock()
optionsConstructor, loaded := m.optionsType[providerType]
if !loaded {
return nil, false
}
return optionsConstructor(), true
}
func (m *Registry) Create(ctx context.Context, logger log.ContextLogger, tag string, providerType string, options any) (adapter.CertificateProviderService, error) {
m.access.Lock()
defer m.access.Unlock()
constructor, loaded := m.constructor[providerType]
if !loaded {
return nil, E.New("certificate provider type not found: " + providerType)
}
return constructor(ctx, logger, tag, options)
}
func (m *Registry) register(providerType string, optionsConstructor optionsConstructorFunc, constructor constructorFunc) {
m.access.Lock()
defer m.access.Unlock()
m.optionsType[providerType] = optionsConstructor
m.constructor[providerType] = constructor
}

View File

@@ -1,38 +0,0 @@
package adapter
import (
"context"
"crypto/tls"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
)
type CertificateProvider interface {
GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error)
}
type ACMECertificateProvider interface {
CertificateProvider
GetACMENextProtos() []string
}
type CertificateProviderService interface {
Lifecycle
Type() string
Tag() string
CertificateProvider
}
type CertificateProviderRegistry interface {
option.CertificateProviderOptionsRegistry
Create(ctx context.Context, logger log.ContextLogger, tag string, providerType string, options any) (CertificateProviderService, error)
}
type CertificateProviderManager interface {
Lifecycle
CertificateProviders() []CertificateProviderService
Get(tag string) (CertificateProviderService, bool)
Remove(tag string) error
Create(ctx context.Context, logger log.ContextLogger, tag string, providerType string, options any) error
}

View File

@@ -3,7 +3,6 @@ package adapter
import ( import (
"context" "context"
"net/netip" "net/netip"
"time"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
@@ -26,19 +25,18 @@ type DNSRouter interface {
type DNSClient interface { type DNSClient interface {
Start() Start()
Exchange(ctx context.Context, transport DNSTransport, message *dns.Msg, options DNSQueryOptions, responseChecker func(response *dns.Msg) bool) (*dns.Msg, error) Exchange(ctx context.Context, transport DNSTransport, message *dns.Msg, options DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) (*dns.Msg, error)
Lookup(ctx context.Context, transport DNSTransport, domain string, options DNSQueryOptions, responseChecker func(response *dns.Msg) bool) ([]netip.Addr, error) Lookup(ctx context.Context, transport DNSTransport, domain string, options DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error)
ClearCache() ClearCache()
} }
type DNSQueryOptions struct { type DNSQueryOptions struct {
Transport DNSTransport Transport DNSTransport
Strategy C.DomainStrategy Strategy C.DomainStrategy
LookupStrategy C.DomainStrategy LookupStrategy C.DomainStrategy
DisableCache bool DisableCache bool
DisableOptimisticCache bool RewriteTTL *uint32
RewriteTTL *uint32 ClientSubnet netip.Prefix
ClientSubnet netip.Prefix
} }
func DNSQueryOptionsFrom(ctx context.Context, options *option.DomainResolveOptions) (*DNSQueryOptions, error) { func DNSQueryOptionsFrom(ctx context.Context, options *option.DomainResolveOptions) (*DNSQueryOptions, error) {
@@ -51,12 +49,11 @@ func DNSQueryOptionsFrom(ctx context.Context, options *option.DomainResolveOptio
return nil, E.New("domain resolver not found: " + options.Server) return nil, E.New("domain resolver not found: " + options.Server)
} }
return &DNSQueryOptions{ return &DNSQueryOptions{
Transport: transport, Transport: transport,
Strategy: C.DomainStrategy(options.Strategy), Strategy: C.DomainStrategy(options.Strategy),
DisableCache: options.DisableCache, DisableCache: options.DisableCache,
DisableOptimisticCache: options.DisableOptimisticCache, RewriteTTL: options.RewriteTTL,
RewriteTTL: options.RewriteTTL, ClientSubnet: options.ClientSubnet.Build(netip.Prefix{}),
ClientSubnet: options.ClientSubnet.Build(netip.Prefix{}),
}, nil }, nil
} }
@@ -66,13 +63,6 @@ type RDRCStore interface {
SaveRDRCAsync(transportName string, qName string, qType uint16, logger logger.Logger) SaveRDRCAsync(transportName string, qName string, qType uint16, logger logger.Logger)
} }
type DNSCacheStore interface {
LoadDNSCache(transportName string, qName string, qType uint16) (rawMessage []byte, expireAt time.Time, loaded bool)
SaveDNSCache(transportName string, qName string, qType uint16, rawMessage []byte, expireAt time.Time) error
SaveDNSCacheAsync(transportName string, qName string, qType uint16, rawMessage []byte, expireAt time.Time, logger logger.Logger)
ClearDNSCache() error
}
type DNSTransport interface { type DNSTransport interface {
Lifecycle Lifecycle
Type() string Type() string
@@ -82,6 +72,11 @@ type DNSTransport interface {
Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error) Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error)
} }
type LegacyDNSTransport interface {
LegacyStrategy() C.DomainStrategy
LegacyClientSubnet() netip.Prefix
}
type DNSTransportRegistry interface { type DNSTransportRegistry interface {
option.DNSTransportOptionsRegistry option.DNSTransportOptionsRegistry
CreateDNSTransport(ctx context.Context, logger log.ContextLogger, tag string, transportType string, options any) (DNSTransport, error) CreateDNSTransport(ctx context.Context, logger log.ContextLogger, tag string, transportType string, options any) (DNSTransport, error)

View File

@@ -47,12 +47,6 @@ type CacheFile interface {
StoreRDRC() bool StoreRDRC() bool
RDRCStore RDRCStore
StoreDNS() bool
DNSCacheStore
SetDisableExpire(disableExpire bool)
SetOptimisticTimeout(timeout time.Duration)
LoadMode() string LoadMode() string
StoreMode(mode string) error StoreMode(mode string) error
LoadSelected(group string) string LoadSelected(group string) string

View File

@@ -1,13 +0,0 @@
package adapter
import (
"net/http"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common/logger"
)
type HTTPClientManager interface {
ResolveTransport(logger logger.ContextLogger, options option.HTTPClientOptions) (http.RoundTripper, error)
DefaultTransport() http.RoundTripper
}

View File

@@ -10,8 +10,6 @@ import (
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
"github.com/miekg/dns"
) )
type Inbound interface { type Inbound interface {
@@ -81,16 +79,14 @@ type InboundContext struct {
FallbackNetworkType []C.InterfaceType FallbackNetworkType []C.InterfaceType
FallbackDelay time.Duration FallbackDelay time.Duration
DestinationAddresses []netip.Addr DestinationAddresses []netip.Addr
DNSResponse *dns.Msg SourceGeoIPCode string
DestinationAddressMatchFromResponse bool GeoIPCode string
SourceGeoIPCode string ProcessInfo *ConnectionOwner
GeoIPCode string SourceMACAddress net.HardwareAddr
ProcessInfo *ConnectionOwner SourceHostname string
SourceMACAddress net.HardwareAddr QueryType uint16
SourceHostname string FakeIP bool
QueryType uint16
FakeIP bool
// rule cache // rule cache
@@ -108,10 +104,6 @@ type InboundContext struct {
func (c *InboundContext) ResetRuleCache() { func (c *InboundContext) ResetRuleCache() {
c.IPCIDRMatchSource = false c.IPCIDRMatchSource = false
c.IPCIDRAcceptEmpty = false c.IPCIDRAcceptEmpty = false
c.ResetRuleMatchCache()
}
func (c *InboundContext) ResetRuleMatchCache() {
c.SourceAddressMatch = false c.SourceAddressMatch = false
c.SourcePortMatch = false c.SourcePortMatch = false
c.DestinationAddressMatch = false c.DestinationAddressMatch = false
@@ -119,51 +111,6 @@ func (c *InboundContext) ResetRuleMatchCache() {
c.DidMatch = false c.DidMatch = false
} }
func (c *InboundContext) DNSResponseAddressesForMatch() []netip.Addr {
return DNSResponseAddresses(c.DNSResponse)
}
func DNSResponseAddresses(response *dns.Msg) []netip.Addr {
if response == nil || response.Rcode != dns.RcodeSuccess {
return nil
}
addresses := make([]netip.Addr, 0, len(response.Answer))
for _, rawRecord := range response.Answer {
switch record := rawRecord.(type) {
case *dns.A:
addr := M.AddrFromIP(record.A)
if addr.IsValid() {
addresses = append(addresses, addr)
}
case *dns.AAAA:
addr := M.AddrFromIP(record.AAAA)
if addr.IsValid() {
addresses = append(addresses, addr)
}
case *dns.HTTPS:
for _, value := range record.SVCB.Value {
switch hint := value.(type) {
case *dns.SVCBIPv4Hint:
for _, ip := range hint.Hint {
addr := M.AddrFromIP(ip).Unmap()
if addr.IsValid() {
addresses = append(addresses, addr)
}
}
case *dns.SVCBIPv6Hint:
for _, ip := range hint.Hint {
addr := M.AddrFromIP(ip)
if addr.IsValid() {
addresses = append(addresses, addr)
}
}
}
}
}
}
return addresses
}
type inboundContextKey struct{} type inboundContextKey struct{}
func WithContext(ctx context.Context, inboundContext *InboundContext) context.Context { func WithContext(ctx context.Context, inboundContext *InboundContext) context.Context {

View File

@@ -1,45 +0,0 @@
package adapter
import (
"net"
"net/netip"
"testing"
"github.com/miekg/dns"
"github.com/stretchr/testify/require"
)
func TestDNSResponseAddressesUnmapsHTTPSIPv4Hints(t *testing.T) {
t.Parallel()
ipv4Hint := net.ParseIP("1.1.1.1")
require.NotNil(t, ipv4Hint)
response := &dns.Msg{
MsgHdr: dns.MsgHdr{
Response: true,
Rcode: dns.RcodeSuccess,
},
Answer: []dns.RR{
&dns.HTTPS{
SVCB: dns.SVCB{
Hdr: dns.RR_Header{
Name: dns.Fqdn("example.com"),
Rrtype: dns.TypeHTTPS,
Class: dns.ClassINET,
Ttl: 60,
},
Priority: 1,
Target: ".",
Value: []dns.SVCBKeyValue{
&dns.SVCBIPv4Hint{Hint: []net.IP{ipv4Hint}},
},
},
},
},
}
addresses := DNSResponseAddresses(response)
require.Equal(t, []netip.Addr{netip.MustParseAddr("1.1.1.1")}, addresses)
require.True(t, addresses[0].Is4())
}

View File

@@ -51,11 +51,11 @@ type FindConnectionOwnerRequest struct {
} }
type ConnectionOwner struct { type ConnectionOwner struct {
ProcessID uint32 ProcessID uint32
UserId int32 UserId int32
UserName string UserName string
ProcessPath string ProcessPath string
AndroidPackageNames []string AndroidPackageName string
} }
type Notification struct { type Notification struct {

View File

@@ -2,11 +2,17 @@ package adapter
import ( import (
"context" "context"
"crypto/tls"
"net" "net"
"net/http"
"sync"
"time" "time"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-tun" "github.com/sagernet/sing-tun"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp"
"github.com/sagernet/sing/common/x/list" "github.com/sagernet/sing/common/x/list"
"go4.org/netipx" "go4.org/netipx"
@@ -45,7 +51,7 @@ type ConnectionRouterEx interface {
type RuleSet interface { type RuleSet interface {
Name() string Name() string
StartContext(ctx context.Context) error StartContext(ctx context.Context, startContext *HTTPStartContext) error
PostStart() error PostStart() error
Metadata() RuleSetMetadata Metadata() RuleSetMetadata
ExtractIPSet() []*netipx.IPSet ExtractIPSet() []*netipx.IPSet
@@ -60,14 +66,51 @@ type RuleSet interface {
type RuleSetUpdateCallback func(it RuleSet) type RuleSetUpdateCallback func(it RuleSet)
type DNSRuleSetUpdateValidator interface { type RuleSetMetadata struct {
ValidateRuleSetMetadataUpdate(tag string, metadata RuleSetMetadata) error ContainsProcessRule bool
ContainsWIFIRule bool
ContainsIPCIDRRule bool
}
type HTTPStartContext struct {
ctx context.Context
access sync.Mutex
httpClientCache map[string]*http.Client
} }
// ip_version is not a headless-rule item, so ContainsIPVersionRule is intentionally absent. func NewHTTPStartContext(ctx context.Context) *HTTPStartContext {
type RuleSetMetadata struct { return &HTTPStartContext{
ContainsProcessRule bool ctx: ctx,
ContainsWIFIRule bool httpClientCache: make(map[string]*http.Client),
ContainsIPCIDRRule bool }
ContainsDNSQueryTypeRule bool }
func (c *HTTPStartContext) HTTPClient(detour string, dialer N.Dialer) *http.Client {
c.access.Lock()
defer c.access.Unlock()
if httpClient, loaded := c.httpClientCache[detour]; loaded {
return httpClient
}
httpClient := &http.Client{
Transport: &http.Transport{
ForceAttemptHTTP2: true,
TLSHandshakeTimeout: C.TCPTimeout,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
},
TLSClientConfig: &tls.Config{
Time: ntp.TimeFuncFromContext(c.ctx),
RootCAs: RootPoolFromContext(c.ctx),
},
},
}
c.httpClientCache[detour] = httpClient
return httpClient
}
func (c *HTTPStartContext) Close() {
c.access.Lock()
defer c.access.Unlock()
for _, client := range c.httpClientCache {
client.CloseIdleConnections()
}
} }

View File

@@ -2,8 +2,6 @@ package adapter
import ( import (
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/miekg/dns"
) )
type HeadlessRule interface { type HeadlessRule interface {
@@ -20,9 +18,8 @@ type Rule interface {
type DNSRule interface { type DNSRule interface {
Rule Rule
LegacyPreMatch(metadata *InboundContext) bool
WithAddressLimit() bool WithAddressLimit() bool
MatchAddressLimit(metadata *InboundContext, response *dns.Msg) bool MatchAddressLimit(metadata *InboundContext) bool
} }
type RuleAction interface { type RuleAction interface {
@@ -32,7 +29,7 @@ type RuleAction interface {
func IsFinalAction(action RuleAction) bool { func IsFinalAction(action RuleAction) bool {
switch action.Type() { switch action.Type() {
case C.RuleActionTypeSniff, C.RuleActionTypeResolve, C.RuleActionTypeEvaluate: case C.RuleActionTypeSniff, C.RuleActionTypeResolve:
return false return false
default: default:
return true return true

View File

@@ -1,49 +0,0 @@
package adapter
import "context"
type TailscaleEndpoint interface {
SubscribeTailscaleStatus(ctx context.Context, fn func(*TailscaleEndpointStatus)) error
StartTailscalePing(ctx context.Context, peerIP string, fn func(*TailscalePingResult)) error
}
type TailscalePingResult struct {
LatencyMs float64
IsDirect bool
Endpoint string
DERPRegionID int32
DERPRegionCode string
Error string
}
type TailscaleEndpointStatus struct {
BackendState string
AuthURL string
NetworkName string
MagicDNSSuffix string
Self *TailscalePeer
UserGroups []*TailscaleUserGroup
}
type TailscaleUserGroup struct {
UserID int64
LoginName string
DisplayName string
ProfilePicURL string
Peers []*TailscalePeer
}
type TailscalePeer struct {
HostName string
DNSName string
OS string
TailscaleIPs []string
Online bool
ExitNode bool
ExitNodeOption bool
Active bool
RxBytes int64
TxBytes int64
UserID int64
KeyExpiry int64
}

175
box.go
View File

@@ -9,21 +9,19 @@ import (
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
boxCertificate "github.com/sagernet/sing-box/adapter/certificate"
"github.com/sagernet/sing-box/adapter/endpoint" "github.com/sagernet/sing-box/adapter/endpoint"
"github.com/sagernet/sing-box/adapter/inbound" "github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/adapter/outbound" "github.com/sagernet/sing-box/adapter/outbound"
boxService "github.com/sagernet/sing-box/adapter/service" boxService "github.com/sagernet/sing-box/adapter/service"
"github.com/sagernet/sing-box/common/certificate" "github.com/sagernet/sing-box/common/certificate"
"github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/common/httpclient"
"github.com/sagernet/sing-box/common/taskmonitor" "github.com/sagernet/sing-box/common/taskmonitor"
"github.com/sagernet/sing-box/common/tls" "github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns" "github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/dns/transport/local"
"github.com/sagernet/sing-box/experimental" "github.com/sagernet/sing-box/experimental"
"github.com/sagernet/sing-box/experimental/cachefile" "github.com/sagernet/sing-box/experimental/cachefile"
"github.com/sagernet/sing-box/experimental/deprecated"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/protocol/direct" "github.com/sagernet/sing-box/protocol/direct"
@@ -39,22 +37,20 @@ import (
var _ adapter.SimpleLifecycle = (*Box)(nil) var _ adapter.SimpleLifecycle = (*Box)(nil)
type Box struct { type Box struct {
createdAt time.Time createdAt time.Time
logFactory log.Factory logFactory log.Factory
logger log.ContextLogger logger log.ContextLogger
network *route.NetworkManager network *route.NetworkManager
endpoint *endpoint.Manager endpoint *endpoint.Manager
inbound *inbound.Manager inbound *inbound.Manager
outbound *outbound.Manager outbound *outbound.Manager
service *boxService.Manager service *boxService.Manager
certificateProvider *boxCertificate.Manager dnsTransport *dns.TransportManager
dnsTransport *dns.TransportManager dnsRouter *dns.Router
dnsRouter *dns.Router connection *route.ConnectionManager
connection *route.ConnectionManager router *route.Router
router *route.Router internalService []adapter.LifecycleService
httpClientService adapter.LifecycleService done chan struct{}
internalService []adapter.LifecycleService
done chan struct{}
} }
type Options struct { type Options struct {
@@ -70,7 +66,6 @@ func Context(
endpointRegistry adapter.EndpointRegistry, endpointRegistry adapter.EndpointRegistry,
dnsTransportRegistry adapter.DNSTransportRegistry, dnsTransportRegistry adapter.DNSTransportRegistry,
serviceRegistry adapter.ServiceRegistry, serviceRegistry adapter.ServiceRegistry,
certificateProviderRegistry adapter.CertificateProviderRegistry,
) context.Context { ) context.Context {
if service.FromContext[option.InboundOptionsRegistry](ctx) == nil || if service.FromContext[option.InboundOptionsRegistry](ctx) == nil ||
service.FromContext[adapter.InboundRegistry](ctx) == nil { service.FromContext[adapter.InboundRegistry](ctx) == nil {
@@ -95,10 +90,6 @@ func Context(
ctx = service.ContextWith[option.ServiceOptionsRegistry](ctx, serviceRegistry) ctx = service.ContextWith[option.ServiceOptionsRegistry](ctx, serviceRegistry)
ctx = service.ContextWith[adapter.ServiceRegistry](ctx, serviceRegistry) ctx = service.ContextWith[adapter.ServiceRegistry](ctx, serviceRegistry)
} }
if service.FromContext[adapter.CertificateProviderRegistry](ctx) == nil {
ctx = service.ContextWith[option.CertificateProviderOptionsRegistry](ctx, certificateProviderRegistry)
ctx = service.ContextWith[adapter.CertificateProviderRegistry](ctx, certificateProviderRegistry)
}
return ctx return ctx
} }
@@ -115,7 +106,6 @@ func New(options Options) (*Box, error) {
outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx) outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx)
dnsTransportRegistry := service.FromContext[adapter.DNSTransportRegistry](ctx) dnsTransportRegistry := service.FromContext[adapter.DNSTransportRegistry](ctx)
serviceRegistry := service.FromContext[adapter.ServiceRegistry](ctx) serviceRegistry := service.FromContext[adapter.ServiceRegistry](ctx)
certificateProviderRegistry := service.FromContext[adapter.CertificateProviderRegistry](ctx)
if endpointRegistry == nil { if endpointRegistry == nil {
return nil, E.New("missing endpoint registry in context") return nil, E.New("missing endpoint registry in context")
@@ -132,9 +122,6 @@ func New(options Options) (*Box, error) {
if serviceRegistry == nil { if serviceRegistry == nil {
return nil, E.New("missing service registry in context") return nil, E.New("missing service registry in context")
} }
if certificateProviderRegistry == nil {
return nil, E.New("missing certificate provider registry in context")
}
ctx = pause.WithDefaultManager(ctx) ctx = pause.WithDefaultManager(ctx)
experimentalOptions := common.PtrValueOrDefault(options.Experimental) experimentalOptions := common.PtrValueOrDefault(options.Experimental)
@@ -172,10 +159,6 @@ func New(options Options) (*Box, error) {
} }
var internalServices []adapter.LifecycleService var internalServices []adapter.LifecycleService
routeOptions := common.PtrValueOrDefault(options.Route)
httpClientManager := httpclient.NewManager(ctx, logFactory.NewLogger("httpclient"), options.HTTPClients, routeOptions.DefaultHTTPClient)
service.MustRegister[adapter.HTTPClientManager](ctx, httpClientManager)
httpClientService := adapter.LifecycleService(httpClientManager)
certificateOptions := common.PtrValueOrDefault(options.Certificate) certificateOptions := common.PtrValueOrDefault(options.Certificate)
if C.IsAndroid || certificateOptions.Store != "" && certificateOptions.Store != C.CertificateStoreSystem || if C.IsAndroid || certificateOptions.Store != "" && certificateOptions.Store != C.CertificateStoreSystem ||
len(certificateOptions.Certificate) > 0 || len(certificateOptions.Certificate) > 0 ||
@@ -188,25 +171,21 @@ func New(options Options) (*Box, error) {
service.MustRegister[adapter.CertificateStore](ctx, certificateStore) service.MustRegister[adapter.CertificateStore](ctx, certificateStore)
internalServices = append(internalServices, certificateStore) internalServices = append(internalServices, certificateStore)
} }
routeOptions := common.PtrValueOrDefault(options.Route)
dnsOptions := common.PtrValueOrDefault(options.DNS) dnsOptions := common.PtrValueOrDefault(options.DNS)
endpointManager := endpoint.NewManager(logFactory.NewLogger("endpoint"), endpointRegistry) endpointManager := endpoint.NewManager(logFactory.NewLogger("endpoint"), endpointRegistry)
inboundManager := inbound.NewManager(logFactory.NewLogger("inbound"), inboundRegistry, endpointManager) inboundManager := inbound.NewManager(logFactory.NewLogger("inbound"), inboundRegistry, endpointManager)
outboundManager := outbound.NewManager(logFactory.NewLogger("outbound"), outboundRegistry, endpointManager, routeOptions.Final) outboundManager := outbound.NewManager(logFactory.NewLogger("outbound"), outboundRegistry, endpointManager, routeOptions.Final)
dnsTransportManager := dns.NewTransportManager(logFactory.NewLogger("dns/transport"), dnsTransportRegistry, outboundManager, dnsOptions.Final) dnsTransportManager := dns.NewTransportManager(logFactory.NewLogger("dns/transport"), dnsTransportRegistry, outboundManager, dnsOptions.Final)
serviceManager := boxService.NewManager(logFactory.NewLogger("service"), serviceRegistry) serviceManager := boxService.NewManager(logFactory.NewLogger("service"), serviceRegistry)
certificateProviderManager := boxCertificate.NewManager(logFactory.NewLogger("certificate-provider"), certificateProviderRegistry)
service.MustRegister[adapter.EndpointManager](ctx, endpointManager) service.MustRegister[adapter.EndpointManager](ctx, endpointManager)
service.MustRegister[adapter.InboundManager](ctx, inboundManager) service.MustRegister[adapter.InboundManager](ctx, inboundManager)
service.MustRegister[adapter.OutboundManager](ctx, outboundManager) service.MustRegister[adapter.OutboundManager](ctx, outboundManager)
service.MustRegister[adapter.DNSTransportManager](ctx, dnsTransportManager) service.MustRegister[adapter.DNSTransportManager](ctx, dnsTransportManager)
service.MustRegister[adapter.ServiceManager](ctx, serviceManager) service.MustRegister[adapter.ServiceManager](ctx, serviceManager)
service.MustRegister[adapter.CertificateProviderManager](ctx, certificateProviderManager) dnsRouter := dns.NewRouter(ctx, logFactory, dnsOptions)
dnsRouter, err := dns.NewRouter(ctx, logFactory, dnsOptions)
if err != nil {
return nil, E.Cause(err, "initialize DNS router")
}
service.MustRegister[adapter.DNSRouter](ctx, dnsRouter) service.MustRegister[adapter.DNSRouter](ctx, dnsRouter)
service.MustRegister[adapter.DNSRuleSetUpdateValidator](ctx, dnsRouter)
networkManager, err := route.NewNetworkManager(ctx, logFactory.NewLogger("network"), routeOptions, dnsOptions) networkManager, err := route.NewNetworkManager(ctx, logFactory.NewLogger("network"), routeOptions, dnsOptions)
if err != nil { if err != nil {
return nil, E.Cause(err, "initialize network manager") return nil, E.Cause(err, "initialize network manager")
@@ -293,24 +272,6 @@ func New(options Options) (*Box, error) {
return nil, E.Cause(err, "initialize inbound[", i, "]") return nil, E.Cause(err, "initialize inbound[", i, "]")
} }
} }
for i, serviceOptions := range options.Services {
var tag string
if serviceOptions.Tag != "" {
tag = serviceOptions.Tag
} else {
tag = F.ToString(i)
}
err = serviceManager.Create(
ctx,
logFactory.NewLogger(F.ToString("service/", serviceOptions.Type, "[", tag, "]")),
tag,
serviceOptions.Type,
serviceOptions.Options,
)
if err != nil {
return nil, E.Cause(err, "initialize service[", i, "]")
}
}
for i, outboundOptions := range options.Outbounds { for i, outboundOptions := range options.Outbounds {
var tag string var tag string
if outboundOptions.Tag != "" { if outboundOptions.Tag != "" {
@@ -337,22 +298,22 @@ func New(options Options) (*Box, error) {
return nil, E.Cause(err, "initialize outbound[", i, "]") return nil, E.Cause(err, "initialize outbound[", i, "]")
} }
} }
for i, certificateProviderOptions := range options.CertificateProviders { for i, serviceOptions := range options.Services {
var tag string var tag string
if certificateProviderOptions.Tag != "" { if serviceOptions.Tag != "" {
tag = certificateProviderOptions.Tag tag = serviceOptions.Tag
} else { } else {
tag = F.ToString(i) tag = F.ToString(i)
} }
err = certificateProviderManager.Create( err = serviceManager.Create(
ctx, ctx,
logFactory.NewLogger(F.ToString("certificate-provider/", certificateProviderOptions.Type, "[", tag, "]")), logFactory.NewLogger(F.ToString("service/", serviceOptions.Type, "[", tag, "]")),
tag, tag,
certificateProviderOptions.Type, serviceOptions.Type,
certificateProviderOptions.Options, serviceOptions.Options,
) )
if err != nil { if err != nil {
return nil, E.Cause(err, "initialize certificate provider[", i, "]") return nil, E.Cause(err, "initialize service[", i, "]")
} }
} }
outboundManager.Initialize(func() (adapter.Outbound, error) { outboundManager.Initialize(func() (adapter.Outbound, error) {
@@ -365,20 +326,13 @@ func New(options Options) (*Box, error) {
) )
}) })
dnsTransportManager.Initialize(func() (adapter.DNSTransport, error) { dnsTransportManager.Initialize(func() (adapter.DNSTransport, error) {
return dnsTransportRegistry.CreateDNSTransport( return local.NewTransport(
ctx, ctx,
logFactory.NewLogger("dns/local"), logFactory.NewLogger("dns/local"),
"local", "local",
C.DNSTypeLocal, option.LocalDNSServerOptions{},
&option.LocalDNSServerOptions{},
) )
}) })
httpClientManager.Initialize(func() (*httpclient.Client, error) {
deprecated.Report(ctx, deprecated.OptionImplicitDefaultHTTPClient)
var httpClientOptions option.HTTPClientOptions
httpClientOptions.DefaultOutbound = true
return httpclient.NewClient(ctx, logFactory.NewLogger("httpclient"), "", httpClientOptions)
})
if platformInterface != nil { if platformInterface != nil {
err = platformInterface.Initialize(networkManager) err = platformInterface.Initialize(networkManager)
if err != nil { if err != nil {
@@ -386,7 +340,7 @@ func New(options Options) (*Box, error) {
} }
} }
if needCacheFile { if needCacheFile {
cacheFile := cachefile.New(ctx, logFactory.NewLogger("cache-file"), common.PtrValueOrDefault(experimentalOptions.CacheFile)) cacheFile := cachefile.New(ctx, common.PtrValueOrDefault(experimentalOptions.CacheFile))
service.MustRegister[adapter.CacheFile](ctx, cacheFile) service.MustRegister[adapter.CacheFile](ctx, cacheFile)
internalServices = append(internalServices, cacheFile) internalServices = append(internalServices, cacheFile)
} }
@@ -429,22 +383,20 @@ func New(options Options) (*Box, error) {
internalServices = append(internalServices, adapter.NewLifecycleService(ntpService, "ntp service")) internalServices = append(internalServices, adapter.NewLifecycleService(ntpService, "ntp service"))
} }
return &Box{ return &Box{
network: networkManager, network: networkManager,
endpoint: endpointManager, endpoint: endpointManager,
inbound: inboundManager, inbound: inboundManager,
outbound: outboundManager, outbound: outboundManager,
dnsTransport: dnsTransportManager, dnsTransport: dnsTransportManager,
service: serviceManager, service: serviceManager,
certificateProvider: certificateProviderManager, dnsRouter: dnsRouter,
dnsRouter: dnsRouter, connection: connectionManager,
connection: connectionManager, router: router,
router: router, createdAt: createdAt,
httpClientService: httpClientService, logFactory: logFactory,
createdAt: createdAt, logger: logFactory.Logger(),
logFactory: logFactory, internalService: internalServices,
logger: logFactory.Logger(), done: make(chan struct{}),
internalService: internalServices,
done: make(chan struct{}),
}, nil }, nil
} }
@@ -498,19 +450,11 @@ func (s *Box) preStart() error {
if err != nil { if err != nil {
return err return err
} }
err = adapter.Start(s.logger, adapter.StartStateInitialize, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint, s.service, s.certificateProvider) err = adapter.Start(s.logger, adapter.StartStateInitialize, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint, s.service)
if err != nil { if err != nil {
return err return err
} }
err = adapter.Start(s.logger, adapter.StartStateStart, s.outbound, s.dnsTransport, s.network, s.connection) err = adapter.Start(s.logger, adapter.StartStateStart, s.outbound, s.dnsTransport, s.dnsRouter, s.network, s.connection, s.router)
if err != nil {
return err
}
err = adapter.StartNamed(s.logger, adapter.StartStateStart, []adapter.LifecycleService{s.httpClientService})
if err != nil {
return err
}
err = adapter.Start(s.logger, adapter.StartStateStart, s.router, s.dnsRouter)
if err != nil { if err != nil {
return err return err
} }
@@ -526,19 +470,11 @@ func (s *Box) start() error {
if err != nil { if err != nil {
return err return err
} }
err = adapter.Start(s.logger, adapter.StartStateStart, s.endpoint) err = adapter.Start(s.logger, adapter.StartStateStart, s.inbound, s.endpoint, s.service)
if err != nil { if err != nil {
return err return err
} }
err = adapter.Start(s.logger, adapter.StartStateStart, s.certificateProvider) err = adapter.Start(s.logger, adapter.StartStatePostStart, s.outbound, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.inbound, s.endpoint, s.service)
if err != nil {
return err
}
err = adapter.Start(s.logger, adapter.StartStateStart, s.inbound, s.service)
if err != nil {
return err
}
err = adapter.Start(s.logger, adapter.StartStatePostStart, s.outbound, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.endpoint, s.certificateProvider, s.inbound, s.service)
if err != nil { if err != nil {
return err return err
} }
@@ -546,7 +482,7 @@ func (s *Box) start() error {
if err != nil { if err != nil {
return err return err
} }
err = adapter.Start(s.logger, adapter.StartStateStarted, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.endpoint, s.certificateProvider, s.inbound, s.service) err = adapter.Start(s.logger, adapter.StartStateStarted, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint, s.service)
if err != nil { if err != nil {
return err return err
} }
@@ -570,9 +506,8 @@ func (s *Box) Close() error {
service adapter.Lifecycle service adapter.Lifecycle
}{ }{
{"service", s.service}, {"service", s.service},
{"inbound", s.inbound},
{"certificate-provider", s.certificateProvider},
{"endpoint", s.endpoint}, {"endpoint", s.endpoint},
{"inbound", s.inbound},
{"outbound", s.outbound}, {"outbound", s.outbound},
{"router", s.router}, {"router", s.router},
{"connection", s.connection}, {"connection", s.connection},
@@ -587,14 +522,6 @@ func (s *Box) Close() error {
}) })
s.logger.Trace("close ", closeItem.name, " completed (", F.Seconds(time.Since(startTime).Seconds()), "s)") s.logger.Trace("close ", closeItem.name, " completed (", F.Seconds(time.Since(startTime).Seconds()), "s)")
} }
if s.httpClientService != nil {
s.logger.Trace("close ", s.httpClientService.Name())
startTime := time.Now()
err = E.Append(err, s.httpClientService.Close(), func(err error) error {
return E.Cause(err, "close ", s.httpClientService.Name())
})
s.logger.Trace("close ", s.httpClientService.Name(), " completed (", F.Seconds(time.Since(startTime).Seconds()), "s)")
}
for _, lifecycleService := range s.internalService { for _, lifecycleService := range s.internalService {
s.logger.Trace("close ", lifecycleService.Name()) s.logger.Trace("close ", lifecycleService.Name())
startTime := time.Now() startTime := time.Now()
@@ -628,10 +555,6 @@ func (s *Box) Outbound() adapter.OutboundManager {
return s.outbound return s.outbound
} }
func (s *Box) Endpoint() adapter.EndpointManager {
return s.endpoint
}
func (s *Box) LogFactory() log.Factory { func (s *Box) LogFactory() log.Factory {
return s.logFactory return s.logFactory
} }

View File

@@ -5,7 +5,6 @@ import (
"io" "io"
"net/http" "net/http"
"os" "os"
"path/filepath"
"strings" "strings"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
@@ -36,9 +35,21 @@ func updateMozillaIncludedRootCAs() error {
return err return err
} }
geoIndex := slices.Index(header, "Geographic Focus") geoIndex := slices.Index(header, "Geographic Focus")
nameIndex := slices.Index(header, "Common Name or Certificate Name")
certIndex := slices.Index(header, "PEM Info") certIndex := slices.Index(header, "PEM Info")
pemBundle := strings.Builder{} generated := strings.Builder{}
generated.WriteString(`// Code generated by 'make update_certificates'. DO NOT EDIT.
package certificate
import "crypto/x509"
var mozillaIncluded *x509.CertPool
func init() {
mozillaIncluded = x509.NewCertPool()
`)
for { for {
record, err := reader.Read() record, err := reader.Read()
if err == io.EOF { if err == io.EOF {
@@ -49,12 +60,18 @@ func updateMozillaIncludedRootCAs() error {
if record[geoIndex] == "China" { if record[geoIndex] == "China" {
continue continue
} }
generated.WriteString("\n // ")
generated.WriteString(record[nameIndex])
generated.WriteString("\n")
generated.WriteString(" mozillaIncluded.AppendCertsFromPEM([]byte(`")
cert := record[certIndex] cert := record[certIndex]
// Remove single quotes
cert = cert[1 : len(cert)-1] cert = cert[1 : len(cert)-1]
pemBundle.WriteString(cert) generated.WriteString(cert)
pemBundle.WriteString("\n") generated.WriteString("`))\n")
} }
return writeGeneratedCertificateBundle("mozilla", "mozillaIncluded", pemBundle.String()) generated.WriteString("}\n")
return os.WriteFile("common/certificate/mozilla.go", []byte(generated.String()), 0o644)
} }
func fetchChinaFingerprints() (map[string]bool, error) { func fetchChinaFingerprints() (map[string]bool, error) {
@@ -102,11 +119,23 @@ func updateChromeIncludedRootCAs() error {
if err != nil { if err != nil {
return err return err
} }
subjectIndex := slices.Index(header, "Subject")
statusIndex := slices.Index(header, "Google Chrome Status") statusIndex := slices.Index(header, "Google Chrome Status")
certIndex := slices.Index(header, "X.509 Certificate (PEM)") certIndex := slices.Index(header, "X.509 Certificate (PEM)")
fingerprintIndex := slices.Index(header, "SHA-256 Fingerprint") fingerprintIndex := slices.Index(header, "SHA-256 Fingerprint")
pemBundle := strings.Builder{} generated := strings.Builder{}
generated.WriteString(`// Code generated by 'make update_certificates'. DO NOT EDIT.
package certificate
import "crypto/x509"
var chromeIncluded *x509.CertPool
func init() {
chromeIncluded = x509.NewCertPool()
`)
for { for {
record, err := reader.Read() record, err := reader.Read()
if err == io.EOF { if err == io.EOF {
@@ -120,39 +149,18 @@ func updateChromeIncludedRootCAs() error {
if chinaFingerprints[record[fingerprintIndex]] { if chinaFingerprints[record[fingerprintIndex]] {
continue continue
} }
generated.WriteString("\n // ")
generated.WriteString(record[subjectIndex])
generated.WriteString("\n")
generated.WriteString(" chromeIncluded.AppendCertsFromPEM([]byte(`")
cert := record[certIndex] cert := record[certIndex]
// Remove single quotes if present
if len(cert) > 0 && cert[0] == '\'' { if len(cert) > 0 && cert[0] == '\'' {
cert = cert[1 : len(cert)-1] cert = cert[1 : len(cert)-1]
} }
pemBundle.WriteString(cert) generated.WriteString(cert)
pemBundle.WriteString("\n") generated.WriteString("`))\n")
} }
return writeGeneratedCertificateBundle("chrome", "chromeIncluded", pemBundle.String()) generated.WriteString("}\n")
} return os.WriteFile("common/certificate/chrome.go", []byte(generated.String()), 0o644)
func writeGeneratedCertificateBundle(name string, variableName string, pemBundle string) error {
goSource := `// Code generated by 'make update_certificates'. DO NOT EDIT.
package certificate
import (
"crypto/x509"
_ "embed"
)
//go:embed ` + name + `.pem
var ` + variableName + `PEM string
var ` + variableName + ` *x509.CertPool
func init() {
` + variableName + ` = x509.NewCertPool()
` + variableName + `.AppendCertsFromPEM([]byte(` + variableName + `PEM))
}
`
err := os.WriteFile(filepath.Join("common/certificate", name+".pem"), []byte(pemBundle), 0o644)
if err != nil {
return err
}
return os.WriteFile(filepath.Join("common/certificate", name+".go"), []byte(goSource), 0o644)
} }

View File

@@ -82,11 +82,6 @@ func compileRuleSet(sourcePath string) error {
} }
func downgradeRuleSetVersion(version uint8, options option.PlainRuleSet) uint8 { func downgradeRuleSetVersion(version uint8, options option.PlainRuleSet) uint8 {
if version == C.RuleSetVersion5 && !rule.HasHeadlessRule(options.Rules, func(rule option.DefaultHeadlessRule) bool {
return len(rule.PackageNameRegex) > 0
}) {
version = C.RuleSetVersion4
}
if version == C.RuleSetVersion4 && !rule.HasHeadlessRule(options.Rules, func(rule option.DefaultHeadlessRule) bool { if version == C.RuleSetVersion4 && !rule.HasHeadlessRule(options.Rules, func(rule option.DefaultHeadlessRule) bool {
return rule.NetworkInterfaceAddress != nil && rule.NetworkInterfaceAddress.Size() > 0 || return rule.NetworkInterfaceAddress != nil && rule.NetworkInterfaceAddress.Size() > 0 ||
len(rule.DefaultInterfaceAddress) > 0 len(rule.DefaultInterfaceAddress) > 0

View File

@@ -1,121 +0,0 @@
package main
import (
"fmt"
"os"
"strings"
"time"
"github.com/sagernet/sing-box/common/networkquality"
"github.com/sagernet/sing-box/log"
"github.com/spf13/cobra"
)
var (
commandNetworkQualityFlagConfigURL string
commandNetworkQualityFlagSerial bool
commandNetworkQualityFlagMaxRuntime int
commandNetworkQualityFlagHTTP3 bool
)
var commandNetworkQuality = &cobra.Command{
Use: "networkquality",
Short: "Run a network quality test",
Run: func(cmd *cobra.Command, args []string) {
err := runNetworkQuality()
if err != nil {
log.Fatal(err)
}
},
}
func init() {
commandNetworkQuality.Flags().StringVar(
&commandNetworkQualityFlagConfigURL,
"config-url", "",
"Network quality test config URL (default: Apple mensura)",
)
commandNetworkQuality.Flags().BoolVar(
&commandNetworkQualityFlagSerial,
"serial", false,
"Run download and upload tests sequentially instead of in parallel",
)
commandNetworkQuality.Flags().IntVar(
&commandNetworkQualityFlagMaxRuntime,
"max-runtime", int(networkquality.DefaultMaxRuntime/time.Second),
"Network quality maximum runtime in seconds",
)
commandNetworkQuality.Flags().BoolVar(
&commandNetworkQualityFlagHTTP3,
"http3", false,
"Use HTTP/3 (QUIC) for measurement traffic",
)
commandTools.AddCommand(commandNetworkQuality)
}
func runNetworkQuality() error {
instance, err := createPreStartedClient()
if err != nil {
return err
}
defer instance.Close()
dialer, err := createDialer(instance, commandToolsFlagOutbound)
if err != nil {
return err
}
httpClient := networkquality.NewHTTPClient(dialer)
defer httpClient.CloseIdleConnections()
measurementClientFactory, err := networkquality.NewOptionalHTTP3Factory(dialer, commandNetworkQualityFlagHTTP3)
if err != nil {
return err
}
fmt.Fprintln(os.Stderr, "==== NETWORK QUALITY TEST ====")
result, err := networkquality.Run(networkquality.Options{
ConfigURL: commandNetworkQualityFlagConfigURL,
HTTPClient: httpClient,
NewMeasurementClient: measurementClientFactory,
Serial: commandNetworkQualityFlagSerial,
MaxRuntime: time.Duration(commandNetworkQualityFlagMaxRuntime) * time.Second,
Context: globalCtx,
OnProgress: func(p networkquality.Progress) {
if !commandNetworkQualityFlagSerial && p.Phase != networkquality.PhaseIdle {
fmt.Fprintf(os.Stderr, "\rDownload: %s RPM: %d Upload: %s RPM: %d",
networkquality.FormatBitrate(p.DownloadCapacity), p.DownloadRPM,
networkquality.FormatBitrate(p.UploadCapacity), p.UploadRPM)
return
}
switch networkquality.Phase(p.Phase) {
case networkquality.PhaseIdle:
if p.IdleLatencyMs > 0 {
fmt.Fprintf(os.Stderr, "\rIdle Latency: %d ms", p.IdleLatencyMs)
} else {
fmt.Fprint(os.Stderr, "\rMeasuring idle latency...")
}
case networkquality.PhaseDownload:
fmt.Fprintf(os.Stderr, "\rDownload: %s RPM: %d",
networkquality.FormatBitrate(p.DownloadCapacity), p.DownloadRPM)
case networkquality.PhaseUpload:
fmt.Fprintf(os.Stderr, "\rUpload: %s RPM: %d",
networkquality.FormatBitrate(p.UploadCapacity), p.UploadRPM)
}
},
})
if err != nil {
return err
}
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, strings.Repeat("-", 40))
fmt.Fprintf(os.Stderr, "Idle Latency: %d ms\n", result.IdleLatencyMs)
fmt.Fprintf(os.Stderr, "Download Capacity: %-20s Accuracy: %s\n", networkquality.FormatBitrate(result.DownloadCapacity), result.DownloadCapacityAccuracy)
fmt.Fprintf(os.Stderr, "Upload Capacity: %-20s Accuracy: %s\n", networkquality.FormatBitrate(result.UploadCapacity), result.UploadCapacityAccuracy)
fmt.Fprintf(os.Stderr, "Download Responsiveness: %-20s Accuracy: %s\n", fmt.Sprintf("%d RPM", result.DownloadRPM), result.DownloadRPMAccuracy)
fmt.Fprintf(os.Stderr, "Upload Responsiveness: %-20s Accuracy: %s\n", fmt.Sprintf("%d RPM", result.UploadRPM), result.UploadRPMAccuracy)
return nil
}

View File

@@ -1,79 +0,0 @@
package main
import (
"fmt"
"os"
"github.com/sagernet/sing-box/common/stun"
"github.com/sagernet/sing-box/log"
"github.com/spf13/cobra"
)
var commandSTUNFlagServer string
var commandSTUN = &cobra.Command{
Use: "stun",
Short: "Run a STUN test",
Args: cobra.NoArgs,
Run: func(cmd *cobra.Command, args []string) {
err := runSTUN()
if err != nil {
log.Fatal(err)
}
},
}
func init() {
commandSTUN.Flags().StringVarP(&commandSTUNFlagServer, "server", "s", stun.DefaultServer, "STUN server address")
commandTools.AddCommand(commandSTUN)
}
func runSTUN() error {
instance, err := createPreStartedClient()
if err != nil {
return err
}
defer instance.Close()
dialer, err := createDialer(instance, commandToolsFlagOutbound)
if err != nil {
return err
}
fmt.Fprintln(os.Stderr, "==== STUN TEST ====")
result, err := stun.Run(stun.Options{
Server: commandSTUNFlagServer,
Dialer: dialer,
Context: globalCtx,
OnProgress: func(p stun.Progress) {
switch p.Phase {
case stun.PhaseBinding:
if p.ExternalAddr != "" {
fmt.Fprintf(os.Stderr, "\rExternal Address: %s (%d ms)", p.ExternalAddr, p.LatencyMs)
} else {
fmt.Fprint(os.Stderr, "\rSending binding request...")
}
case stun.PhaseNATMapping:
fmt.Fprint(os.Stderr, "\rDetecting NAT mapping behavior...")
case stun.PhaseNATFiltering:
fmt.Fprint(os.Stderr, "\rDetecting NAT filtering behavior...")
}
},
})
if err != nil {
return err
}
fmt.Fprintln(os.Stderr)
fmt.Fprintf(os.Stderr, "External Address: %s\n", result.ExternalAddr)
fmt.Fprintf(os.Stderr, "Latency: %d ms\n", result.LatencyMs)
if result.NATTypeSupported {
fmt.Fprintf(os.Stderr, "NAT Mapping: %s\n", result.NATMapping)
fmt.Fprintf(os.Stderr, "NAT Filtering: %s\n", result.NATFiltering)
} else {
fmt.Fprintln(os.Stderr, "NAT Type Detection: not supported by server")
}
return nil
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -22,10 +22,8 @@ var _ adapter.CertificateStore = (*Store)(nil)
type Store struct { type Store struct {
access sync.RWMutex access sync.RWMutex
store string
systemPool *x509.CertPool systemPool *x509.CertPool
currentPool *x509.CertPool currentPool *x509.CertPool
currentPEM []string
certificate string certificate string
certificatePaths []string certificatePaths []string
certificateDirectoryPaths []string certificateDirectoryPaths []string
@@ -63,7 +61,6 @@ func NewStore(ctx context.Context, logger logger.Logger, options option.Certific
return nil, E.New("unknown certificate store: ", options.Store) return nil, E.New("unknown certificate store: ", options.Store)
} }
store := &Store{ store := &Store{
store: options.Store,
systemPool: systemPool, systemPool: systemPool,
certificate: strings.Join(options.Certificate, "\n"), certificate: strings.Join(options.Certificate, "\n"),
certificatePaths: options.CertificatePath, certificatePaths: options.CertificatePath,
@@ -126,37 +123,19 @@ func (s *Store) Pool() *x509.CertPool {
return s.currentPool return s.currentPool
} }
func (s *Store) StoreKind() string {
return s.store
}
func (s *Store) CurrentPEM() []string {
s.access.RLock()
defer s.access.RUnlock()
return append([]string(nil), s.currentPEM...)
}
func (s *Store) update() error { func (s *Store) update() error {
s.access.Lock() s.access.Lock()
defer s.access.Unlock() defer s.access.Unlock()
var currentPool *x509.CertPool var currentPool *x509.CertPool
var currentPEM []string
if s.systemPool == nil { if s.systemPool == nil {
currentPool = x509.NewCertPool() currentPool = x509.NewCertPool()
} else { } else {
currentPool = s.systemPool.Clone() currentPool = s.systemPool.Clone()
} }
switch s.store {
case C.CertificateStoreMozilla:
currentPEM = append(currentPEM, mozillaIncludedPEM)
case C.CertificateStoreChrome:
currentPEM = append(currentPEM, chromeIncludedPEM)
}
if s.certificate != "" { if s.certificate != "" {
if !currentPool.AppendCertsFromPEM([]byte(s.certificate)) { if !currentPool.AppendCertsFromPEM([]byte(s.certificate)) {
return E.New("invalid certificate PEM strings") return E.New("invalid certificate PEM strings")
} }
currentPEM = append(currentPEM, s.certificate)
} }
for _, path := range s.certificatePaths { for _, path := range s.certificatePaths {
pemContent, err := os.ReadFile(path) pemContent, err := os.ReadFile(path)
@@ -166,7 +145,6 @@ func (s *Store) update() error {
if !currentPool.AppendCertsFromPEM(pemContent) { if !currentPool.AppendCertsFromPEM(pemContent) {
return E.New("invalid certificate PEM file: ", path) return E.New("invalid certificate PEM file: ", path)
} }
currentPEM = append(currentPEM, string(pemContent))
} }
var firstErr error var firstErr error
for _, directoryPath := range s.certificateDirectoryPaths { for _, directoryPath := range s.certificateDirectoryPaths {
@@ -179,8 +157,8 @@ func (s *Store) update() error {
} }
for _, directoryEntry := range directoryEntries { for _, directoryEntry := range directoryEntries {
pemContent, err := os.ReadFile(filepath.Join(directoryPath, directoryEntry.Name())) pemContent, err := os.ReadFile(filepath.Join(directoryPath, directoryEntry.Name()))
if err == nil && currentPool.AppendCertsFromPEM(pemContent) { if err == nil {
currentPEM = append(currentPEM, string(pemContent)) currentPool.AppendCertsFromPEM(pemContent)
} }
} }
} }
@@ -188,7 +166,6 @@ func (s *Store) update() error {
return firstErr return firstErr
} }
s.currentPool = currentPool s.currentPool = currentPool
s.currentPEM = currentPEM
return nil return nil
} }

View File

@@ -149,10 +149,7 @@ func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDial
} else { } else {
dialer.Timeout = C.TCPConnectTimeout dialer.Timeout = C.TCPConnectTimeout
} }
if options.DisableTCPKeepAlive { if !options.DisableTCPKeepAlive {
dialer.KeepAlive = -1
dialer.KeepAliveConfig.Enable = false
} else {
keepIdle := time.Duration(options.TCPKeepAlive) keepIdle := time.Duration(options.TCPKeepAlive)
if keepIdle == 0 { if keepIdle == 0 {
keepIdle = C.TCPKeepAliveInitial keepIdle = C.TCPKeepAliveInitial
@@ -242,7 +239,7 @@ func setMarkWrapper(networkManager adapter.NetworkManager, mark uint32, isDefaul
func (d *DefaultDialer) DialContext(ctx context.Context, network string, address M.Socksaddr) (net.Conn, error) { func (d *DefaultDialer) DialContext(ctx context.Context, network string, address M.Socksaddr) (net.Conn, error) {
if !address.IsValid() { if !address.IsValid() {
return nil, E.New("invalid address") return nil, E.New("invalid address")
} else if address.IsDomain() { } else if address.IsFqdn() {
return nil, E.New("domain not resolved") return nil, E.New("domain not resolved")
} }
if d.networkStrategy == nil { if d.networkStrategy == nil {
@@ -332,9 +329,9 @@ func (d *DefaultDialer) ListenPacket(ctx context.Context, destination M.Socksadd
func (d *DefaultDialer) DialerForICMPDestination(destination netip.Addr) net.Dialer { func (d *DefaultDialer) DialerForICMPDestination(destination netip.Addr) net.Dialer {
if !destination.Is6() { if !destination.Is6() {
return d.dialer4.Dialer
} else {
return d.dialer6.Dialer return d.dialer6.Dialer
} else {
return d.dialer4.Dialer
} }
} }

View File

@@ -19,7 +19,6 @@ type DirectDialer interface {
type DetourDialer struct { type DetourDialer struct {
outboundManager adapter.OutboundManager outboundManager adapter.OutboundManager
detour string detour string
defaultOutbound bool
legacyDNSDialer bool legacyDNSDialer bool
dialer N.Dialer dialer N.Dialer
initOnce sync.Once initOnce sync.Once
@@ -34,13 +33,6 @@ func NewDetour(outboundManager adapter.OutboundManager, detour string, legacyDNS
} }
} }
func NewDefaultOutboundDetour(outboundManager adapter.OutboundManager) N.Dialer {
return &DetourDialer{
outboundManager: outboundManager,
defaultOutbound: true,
}
}
func InitializeDetour(dialer N.Dialer) error { func InitializeDetour(dialer N.Dialer) error {
detourDialer, isDetour := common.Cast[*DetourDialer](dialer) detourDialer, isDetour := common.Cast[*DetourDialer](dialer)
if !isDetour { if !isDetour {
@@ -55,18 +47,12 @@ func (d *DetourDialer) Dialer() (N.Dialer, error) {
} }
func (d *DetourDialer) init() { func (d *DetourDialer) init() {
var dialer adapter.Outbound dialer, loaded := d.outboundManager.Outbound(d.detour)
if d.detour != "" { if !loaded {
var loaded bool d.initErr = E.New("outbound detour not found: ", d.detour)
dialer, loaded = d.outboundManager.Outbound(d.detour) return
if !loaded {
d.initErr = E.New("outbound detour not found: ", d.detour)
return
}
} else {
dialer = d.outboundManager.Default()
} }
if !d.defaultOutbound && !d.legacyDNSDialer { if !d.legacyDNSDialer {
if directDialer, isDirect := dialer.(DirectDialer); isDirect { if directDialer, isDirect := dialer.(DirectDialer); isDirect {
if directDialer.IsEmpty() { if directDialer.IsEmpty() {
d.initErr = E.New("detour to an empty direct outbound makes no sense") d.initErr = E.New("detour to an empty direct outbound makes no sense")

View File

@@ -25,7 +25,6 @@ type Options struct {
NewDialer bool NewDialer bool
LegacyDNSDialer bool LegacyDNSDialer bool
DirectOutbound bool DirectOutbound bool
DefaultOutbound bool
} }
// TODO: merge with NewWithOptions // TODO: merge with NewWithOptions
@@ -43,26 +42,19 @@ func NewWithOptions(options Options) (N.Dialer, error) {
dialer N.Dialer dialer N.Dialer
err error err error
) )
hasDetour := dialOptions.Detour != "" || options.DefaultOutbound
if dialOptions.Detour != "" { if dialOptions.Detour != "" {
outboundManager := service.FromContext[adapter.OutboundManager](options.Context) outboundManager := service.FromContext[adapter.OutboundManager](options.Context)
if outboundManager == nil { if outboundManager == nil {
return nil, E.New("missing outbound manager") return nil, E.New("missing outbound manager")
} }
dialer = NewDetour(outboundManager, dialOptions.Detour, options.LegacyDNSDialer) dialer = NewDetour(outboundManager, dialOptions.Detour, options.LegacyDNSDialer)
} else if options.DefaultOutbound {
outboundManager := service.FromContext[adapter.OutboundManager](options.Context)
if outboundManager == nil {
return nil, E.New("missing outbound manager")
}
dialer = NewDefaultOutboundDetour(outboundManager)
} else { } else {
dialer, err = NewDefault(options.Context, dialOptions) dialer, err = NewDefault(options.Context, dialOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
if options.RemoteIsDomain && (!hasDetour || options.ResolverOnDetour || dialOptions.DomainResolver != nil && dialOptions.DomainResolver.Server != "") { if options.RemoteIsDomain && (dialOptions.Detour == "" || options.ResolverOnDetour || dialOptions.DomainResolver != nil && dialOptions.DomainResolver.Server != "") {
networkManager := service.FromContext[adapter.NetworkManager](options.Context) networkManager := service.FromContext[adapter.NetworkManager](options.Context)
dnsTransport := service.FromContext[adapter.DNSTransportManager](options.Context) dnsTransport := service.FromContext[adapter.DNSTransportManager](options.Context)
var defaultOptions adapter.NetworkOptions var defaultOptions adapter.NetworkOptions
@@ -95,12 +87,11 @@ func NewWithOptions(options Options) (N.Dialer, error) {
} }
server = dialOptions.DomainResolver.Server server = dialOptions.DomainResolver.Server
dnsQueryOptions = adapter.DNSQueryOptions{ dnsQueryOptions = adapter.DNSQueryOptions{
Transport: transport, Transport: transport,
Strategy: strategy, Strategy: strategy,
DisableCache: dialOptions.DomainResolver.DisableCache, DisableCache: dialOptions.DomainResolver.DisableCache,
DisableOptimisticCache: dialOptions.DomainResolver.DisableOptimisticCache, RewriteTTL: dialOptions.DomainResolver.RewriteTTL,
RewriteTTL: dialOptions.DomainResolver.RewriteTTL, ClientSubnet: dialOptions.DomainResolver.ClientSubnet.Build(netip.Prefix{}),
ClientSubnet: dialOptions.DomainResolver.ClientSubnet.Build(netip.Prefix{}),
} }
resolveFallbackDelay = time.Duration(dialOptions.FallbackDelay) resolveFallbackDelay = time.Duration(dialOptions.FallbackDelay)
} else if options.DirectResolver { } else if options.DirectResolver {

View File

@@ -96,7 +96,7 @@ func (d *resolveDialer) DialContext(ctx context.Context, network string, destina
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !destination.IsDomain() { if !destination.IsFqdn() {
return d.dialer.DialContext(ctx, network, destination) return d.dialer.DialContext(ctx, network, destination)
} }
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug) ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
@@ -116,7 +116,7 @@ func (d *resolveDialer) ListenPacket(ctx context.Context, destination M.Socksadd
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !destination.IsDomain() { if !destination.IsFqdn() {
return d.dialer.ListenPacket(ctx, destination) return d.dialer.ListenPacket(ctx, destination)
} }
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug) ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
@@ -144,7 +144,7 @@ func (d *resolveParallelNetworkDialer) DialParallelInterface(ctx context.Context
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !destination.IsDomain() { if !destination.IsFqdn() {
return d.dialer.DialContext(ctx, network, destination) return d.dialer.DialContext(ctx, network, destination)
} }
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug) ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
@@ -167,7 +167,7 @@ func (d *resolveParallelNetworkDialer) ListenSerialInterfacePacket(ctx context.C
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !destination.IsDomain() { if !destination.IsFqdn() {
return d.dialer.ListenPacket(ctx, destination) return d.dialer.ListenPacket(ctx, destination)
} }
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug) ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)

View File

@@ -1,154 +0,0 @@
package httpclient
import (
"context"
"io"
"net/http"
"time"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/common/tls"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
N "github.com/sagernet/sing/common/network"
)
type httpTransport interface {
http.RoundTripper
CloseIdleConnections()
Clone() httpTransport
}
type Client struct {
transport httpTransport
headers http.Header
host string
tag string
}
func NewClient(ctx context.Context, logger logger.ContextLogger, tag string, options option.HTTPClientOptions) (*Client, error) {
rawDialer, err := dialer.NewWithOptions(dialer.Options{
Context: ctx,
Options: options.DialerOptions,
RemoteIsDomain: true,
ResolverOnDetour: options.ResolveOnDetour,
NewDialer: options.ResolveOnDetour,
DefaultOutbound: options.DefaultOutbound,
})
if err != nil {
return nil, err
}
tlsOptions := common.PtrValueOrDefault(options.TLS)
tlsOptions.Enabled = true
baseTLSConfig, err := tls.NewClientWithOptions(tls.ClientOptions{
Context: ctx,
Logger: logger,
Options: tlsOptions,
AllowEmptyServerName: true,
})
if err != nil {
return nil, err
}
return NewClientWithDialer(rawDialer, baseTLSConfig, tag, options)
}
func NewClientWithDialer(rawDialer N.Dialer, baseTLSConfig tls.Config, tag string, options option.HTTPClientOptions) (*Client, error) {
headers := options.Headers.Build()
host := headers.Get("Host")
headers.Del("Host")
transport, err := newTransport(rawDialer, baseTLSConfig, options)
if err != nil {
return nil, err
}
return &Client{
transport: transport,
headers: headers,
host: host,
tag: tag,
}, nil
}
func newTransport(rawDialer N.Dialer, baseTLSConfig tls.Config, options option.HTTPClientOptions) (httpTransport, error) {
version := options.Version
if version == 0 {
version = 2
}
fallbackDelay := time.Duration(options.DialerOptions.FallbackDelay)
if fallbackDelay == 0 {
fallbackDelay = 300 * time.Millisecond
}
var transport httpTransport
var err error
switch version {
case 1:
transport = newHTTP1Transport(rawDialer, baseTLSConfig)
case 2:
if options.DisableVersionFallback {
transport, err = newHTTP2Transport(rawDialer, baseTLSConfig, options.HTTP2Options)
} else {
transport, err = newHTTP2FallbackTransport(rawDialer, baseTLSConfig, options.HTTP2Options)
}
case 3:
if baseTLSConfig != nil {
_, err = baseTLSConfig.STDConfig()
if err != nil {
return nil, err
}
}
if options.DisableVersionFallback {
transport, err = newHTTP3Transport(rawDialer, baseTLSConfig, options.HTTP3Options)
} else {
var h2Fallback httpTransport
h2Fallback, err = newHTTP2FallbackTransport(rawDialer, baseTLSConfig, options.HTTP2Options)
if err != nil {
return nil, err
}
transport, err = newHTTP3FallbackTransport(rawDialer, baseTLSConfig, h2Fallback, options.HTTP3Options, fallbackDelay)
}
default:
return nil, E.New("unknown HTTP version: ", version)
}
if err != nil {
return nil, err
}
return transport, nil
}
func (c *Client) RoundTrip(request *http.Request) (*http.Response, error) {
if c.tag == "" && len(c.headers) == 0 && c.host == "" {
return c.transport.RoundTrip(request)
}
if c.tag != "" {
if transportTag, loaded := transportTagFromContext(request.Context()); loaded && transportTag == c.tag {
return nil, E.New("HTTP request loopback in transport[", c.tag, "]")
}
request = request.Clone(contextWithTransportTag(request.Context(), c.tag))
} else {
request = request.Clone(request.Context())
}
applyHeaders(request, c.headers, c.host)
return c.transport.RoundTrip(request)
}
func (c *Client) CloseIdleConnections() {
c.transport.CloseIdleConnections()
}
func (c *Client) Clone() *Client {
return &Client{
transport: c.transport.Clone(),
headers: c.headers.Clone(),
host: c.host,
tag: c.tag,
}
}
func (c *Client) Close() error {
c.CloseIdleConnections()
if closer, isCloser := c.transport.(io.Closer); isCloser {
return closer.Close()
}
return nil
}

View File

@@ -1,14 +0,0 @@
package httpclient
import "context"
type transportKey struct{}
func contextWithTransportTag(ctx context.Context, transportTag string) context.Context {
return context.WithValue(ctx, transportKey{}, transportTag)
}
func transportTagFromContext(ctx context.Context) (string, bool) {
value, loaded := ctx.Value(transportKey{}).(string)
return value, loaded
}

View File

@@ -1,86 +0,0 @@
package httpclient
import (
"context"
stdTLS "crypto/tls"
"io"
"net"
"net/http"
"strings"
"github.com/sagernet/sing-box/common/tls"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func dialTLS(ctx context.Context, rawDialer N.Dialer, baseTLSConfig tls.Config, destination M.Socksaddr, nextProtos []string, expectProto string) (net.Conn, error) {
if baseTLSConfig == nil {
return nil, E.New("TLS transport unavailable")
}
tlsConfig := baseTLSConfig.Clone()
if tlsConfig.ServerName() == "" && destination.IsValid() {
tlsConfig.SetServerName(destination.AddrString())
}
tlsConfig.SetNextProtos(nextProtos)
conn, err := rawDialer.DialContext(ctx, N.NetworkTCP, destination)
if err != nil {
return nil, err
}
tlsConn, err := tls.ClientHandshake(ctx, conn, tlsConfig)
if err != nil {
conn.Close()
return nil, err
}
if expectProto != "" && tlsConn.ConnectionState().NegotiatedProtocol != expectProto {
tlsConn.Close()
return nil, errHTTP2Fallback
}
return tlsConn, nil
}
func applyHeaders(request *http.Request, headers http.Header, host string) {
for header, values := range headers {
request.Header[header] = append([]string(nil), values...)
}
if host != "" {
request.Host = host
}
}
func requestRequiresHTTP1(request *http.Request) bool {
return strings.Contains(strings.ToLower(request.Header.Get("Connection")), "upgrade") &&
strings.EqualFold(request.Header.Get("Upgrade"), "websocket")
}
func requestReplayable(request *http.Request) bool {
return request.Body == nil || request.Body == http.NoBody || request.GetBody != nil
}
func cloneRequestForRetry(request *http.Request) *http.Request {
cloned := request.Clone(request.Context())
if request.Body != nil && request.Body != http.NoBody && request.GetBody != nil {
cloned.Body = mustGetBody(request)
}
return cloned
}
func mustGetBody(request *http.Request) io.ReadCloser {
body, err := request.GetBody()
if err != nil {
panic(err)
}
return body
}
func buildSTDTLSConfig(baseTLSConfig tls.Config, destination M.Socksaddr, nextProtos []string) (*stdTLS.Config, error) {
if baseTLSConfig == nil {
return nil, nil
}
tlsConfig := baseTLSConfig.Clone()
if tlsConfig.ServerName() == "" && destination.IsValid() {
tlsConfig.SetServerName(destination.AddrString())
}
tlsConfig.SetNextProtos(nextProtos)
return tlsConfig.STDConfig()
}

View File

@@ -1,41 +0,0 @@
package httpclient
import (
"context"
"net"
"net/http"
"github.com/sagernet/sing-box/common/tls"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type http1Transport struct {
transport *http.Transport
}
func newHTTP1Transport(rawDialer N.Dialer, baseTLSConfig tls.Config) *http1Transport {
transport := &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return rawDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
},
}
if baseTLSConfig != nil {
transport.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialTLS(ctx, rawDialer, baseTLSConfig, M.ParseSocksaddr(addr), []string{"http/1.1"}, "")
}
}
return &http1Transport{transport: transport}
}
func (t *http1Transport) RoundTrip(request *http.Request) (*http.Response, error) {
return t.transport.RoundTrip(request)
}
func (t *http1Transport) CloseIdleConnections() {
t.transport.CloseIdleConnections()
}
func (t *http1Transport) Clone() httpTransport {
return &http1Transport{transport: t.transport.Clone()}
}

View File

@@ -1,42 +0,0 @@
package httpclient
import (
stdTLS "crypto/tls"
"net/http"
"time"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"golang.org/x/net/http2"
)
func CloneHTTP2Transport(transport *http2.Transport) *http2.Transport {
return &http2.Transport{
ReadIdleTimeout: transport.ReadIdleTimeout,
PingTimeout: transport.PingTimeout,
DialTLSContext: transport.DialTLSContext,
}
}
func ConfigureHTTP2Transport(options option.HTTP2Options) (*http2.Transport, error) {
stdTransport := &http.Transport{
TLSClientConfig: &stdTLS.Config{},
HTTP2: &http.HTTP2Config{
MaxReceiveBufferPerStream: int(options.StreamReceiveWindow.Value()),
MaxReceiveBufferPerConnection: int(options.ConnectionReceiveWindow.Value()),
MaxConcurrentStreams: options.MaxConcurrentStreams,
SendPingTimeout: time.Duration(options.KeepAlivePeriod),
PingTimeout: time.Duration(options.IdleTimeout),
},
}
h2Transport, err := http2.ConfigureTransports(stdTransport)
if err != nil {
return nil, E.Cause(err, "configure HTTP/2 transport")
}
// ConfigureTransports binds ConnPool to the throwaway http.Transport; sever it so DialTLSContext is used directly.
h2Transport.ConnPool = nil
h2Transport.ReadIdleTimeout = time.Duration(options.KeepAlivePeriod)
h2Transport.PingTimeout = time.Duration(options.IdleTimeout)
return h2Transport, nil
}

View File

@@ -1,87 +0,0 @@
package httpclient
import (
"context"
stdTLS "crypto/tls"
"errors"
"net"
"net/http"
"sync/atomic"
"github.com/sagernet/sing-box/common/tls"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"golang.org/x/net/http2"
)
var errHTTP2Fallback = E.New("fallback to HTTP/1.1")
type http2FallbackTransport struct {
h2Transport *http2.Transport
h1Transport *http1Transport
h2Fallback *atomic.Bool
}
func newHTTP2FallbackTransport(rawDialer N.Dialer, baseTLSConfig tls.Config, options option.HTTP2Options) (*http2FallbackTransport, error) {
h1 := newHTTP1Transport(rawDialer, baseTLSConfig)
var fallback atomic.Bool
h2Transport, err := ConfigureHTTP2Transport(options)
if err != nil {
return nil, err
}
h2Transport.DialTLSContext = func(ctx context.Context, network, addr string, _ *stdTLS.Config) (net.Conn, error) {
conn, dialErr := dialTLS(ctx, rawDialer, baseTLSConfig, M.ParseSocksaddr(addr), []string{http2.NextProtoTLS, "http/1.1"}, http2.NextProtoTLS)
if dialErr != nil {
if errors.Is(dialErr, errHTTP2Fallback) {
fallback.Store(true)
}
return nil, dialErr
}
return conn, nil
}
return &http2FallbackTransport{
h2Transport: h2Transport,
h1Transport: h1,
h2Fallback: &fallback,
}, nil
}
func (t *http2FallbackTransport) RoundTrip(request *http.Request) (*http.Response, error) {
return t.roundTrip(request, true)
}
func (t *http2FallbackTransport) roundTrip(request *http.Request, allowHTTP1Fallback bool) (*http.Response, error) {
if request.URL.Scheme != "https" || requestRequiresHTTP1(request) {
return t.h1Transport.RoundTrip(request)
}
if t.h2Fallback.Load() {
if !allowHTTP1Fallback {
return nil, errHTTP2Fallback
}
return t.h1Transport.RoundTrip(request)
}
response, err := t.h2Transport.RoundTrip(request)
if err == nil {
return response, nil
}
if !errors.Is(err, errHTTP2Fallback) || !allowHTTP1Fallback {
return nil, err
}
return t.h1Transport.RoundTrip(cloneRequestForRetry(request))
}
func (t *http2FallbackTransport) CloseIdleConnections() {
t.h1Transport.CloseIdleConnections()
t.h2Transport.CloseIdleConnections()
}
func (t *http2FallbackTransport) Clone() httpTransport {
return &http2FallbackTransport{
h2Transport: CloneHTTP2Transport(t.h2Transport),
h1Transport: t.h1Transport.Clone().(*http1Transport),
h2Fallback: t.h2Fallback,
}
}

View File

@@ -1,54 +0,0 @@
package httpclient
import (
"context"
stdTLS "crypto/tls"
"net"
"net/http"
"github.com/sagernet/sing-box/common/tls"
"github.com/sagernet/sing-box/option"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"golang.org/x/net/http2"
)
type http2Transport struct {
h2Transport *http2.Transport
h1Transport *http1Transport
}
func newHTTP2Transport(rawDialer N.Dialer, baseTLSConfig tls.Config, options option.HTTP2Options) (*http2Transport, error) {
h1 := newHTTP1Transport(rawDialer, baseTLSConfig)
h2Transport, err := ConfigureHTTP2Transport(options)
if err != nil {
return nil, err
}
h2Transport.DialTLSContext = func(ctx context.Context, network, addr string, _ *stdTLS.Config) (net.Conn, error) {
return dialTLS(ctx, rawDialer, baseTLSConfig, M.ParseSocksaddr(addr), []string{http2.NextProtoTLS}, http2.NextProtoTLS)
}
return &http2Transport{
h2Transport: h2Transport,
h1Transport: h1,
}, nil
}
func (t *http2Transport) RoundTrip(request *http.Request) (*http.Response, error) {
if request.URL.Scheme != "https" || requestRequiresHTTP1(request) {
return t.h1Transport.RoundTrip(request)
}
return t.h2Transport.RoundTrip(request)
}
func (t *http2Transport) CloseIdleConnections() {
t.h1Transport.CloseIdleConnections()
t.h2Transport.CloseIdleConnections()
}
func (t *http2Transport) Clone() httpTransport {
return &http2Transport{
h2Transport: CloneHTTP2Transport(t.h2Transport),
h1Transport: t.h1Transport.Clone().(*http1Transport),
}
}

View File

@@ -1,311 +0,0 @@
//go:build with_quic
package httpclient
import (
"context"
stdTLS "crypto/tls"
"errors"
"net/http"
"sync"
"time"
"github.com/sagernet/quic-go"
"github.com/sagernet/quic-go/http3"
"github.com/sagernet/sing-box/common/tls"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type http3Transport struct {
h3Transport *http3.Transport
}
type http3FallbackTransport struct {
h3Transport *http3.Transport
h2Fallback httpTransport
fallbackDelay time.Duration
brokenAccess sync.Mutex
brokenUntil time.Time
brokenBackoff time.Duration
}
func newHTTP3RoundTripper(
rawDialer N.Dialer,
baseTLSConfig tls.Config,
options option.QUICOptions,
) *http3.Transport {
var handshakeTimeout time.Duration
if baseTLSConfig != nil {
handshakeTimeout = baseTLSConfig.HandshakeTimeout()
}
quicConfig := &quic.Config{
InitialStreamReceiveWindow: options.StreamReceiveWindow.Value(),
MaxStreamReceiveWindow: options.StreamReceiveWindow.Value(),
InitialConnectionReceiveWindow: options.ConnectionReceiveWindow.Value(),
MaxConnectionReceiveWindow: options.ConnectionReceiveWindow.Value(),
KeepAlivePeriod: time.Duration(options.KeepAlivePeriod),
MaxIdleTimeout: time.Duration(options.IdleTimeout),
DisablePathMTUDiscovery: options.DisablePathMTUDiscovery,
}
if options.InitialPacketSize > 0 {
quicConfig.InitialPacketSize = uint16(options.InitialPacketSize)
}
if options.MaxConcurrentStreams > 0 {
quicConfig.MaxIncomingStreams = int64(options.MaxConcurrentStreams)
}
if handshakeTimeout > 0 {
quicConfig.HandshakeIdleTimeout = handshakeTimeout
}
h3Transport := &http3.Transport{
TLSClientConfig: &stdTLS.Config{},
QUICConfig: quicConfig,
Dial: func(ctx context.Context, addr string, tlsConfig *stdTLS.Config, quicConfig *quic.Config) (*quic.Conn, error) {
if handshakeTimeout > 0 && quicConfig.HandshakeIdleTimeout == 0 {
quicConfig = quicConfig.Clone()
quicConfig.HandshakeIdleTimeout = handshakeTimeout
}
if baseTLSConfig != nil {
var err error
tlsConfig, err = buildSTDTLSConfig(baseTLSConfig, M.ParseSocksaddr(addr), []string{http3.NextProtoH3})
if err != nil {
return nil, err
}
} else {
tlsConfig = tlsConfig.Clone()
tlsConfig.NextProtos = []string{http3.NextProtoH3}
}
conn, err := rawDialer.DialContext(ctx, N.NetworkUDP, M.ParseSocksaddr(addr))
if err != nil {
return nil, err
}
quicConn, err := quic.DialEarly(ctx, bufio.NewUnbindPacketConn(conn), conn.RemoteAddr(), tlsConfig, quicConfig)
if err != nil {
conn.Close()
return nil, err
}
return quicConn, nil
},
}
return h3Transport
}
func newHTTP3Transport(
rawDialer N.Dialer,
baseTLSConfig tls.Config,
options option.QUICOptions,
) (httpTransport, error) {
return &http3Transport{
h3Transport: newHTTP3RoundTripper(rawDialer, baseTLSConfig, options),
}, nil
}
func newHTTP3FallbackTransport(
rawDialer N.Dialer,
baseTLSConfig tls.Config,
h2Fallback httpTransport,
options option.QUICOptions,
fallbackDelay time.Duration,
) (httpTransport, error) {
return &http3FallbackTransport{
h3Transport: newHTTP3RoundTripper(rawDialer, baseTLSConfig, options),
h2Fallback: h2Fallback,
fallbackDelay: fallbackDelay,
}, nil
}
func (t *http3Transport) RoundTrip(request *http.Request) (*http.Response, error) {
return t.h3Transport.RoundTrip(request)
}
func (t *http3Transport) CloseIdleConnections() {
t.h3Transport.CloseIdleConnections()
}
func (t *http3Transport) Close() error {
t.CloseIdleConnections()
return t.h3Transport.Close()
}
func (t *http3Transport) Clone() httpTransport {
return &http3Transport{
h3Transport: t.h3Transport,
}
}
func (t *http3FallbackTransport) RoundTrip(request *http.Request) (*http.Response, error) {
if request.URL.Scheme != "https" || requestRequiresHTTP1(request) {
return t.h2Fallback.RoundTrip(request)
}
return t.roundTripHTTP3(request)
}
func (t *http3FallbackTransport) roundTripHTTP3(request *http.Request) (*http.Response, error) {
if t.h3Broken() {
return t.h2FallbackRoundTrip(request)
}
response, err := t.h3Transport.RoundTripOpt(request, http3.RoundTripOpt{OnlyCachedConn: true})
if err == nil {
t.clearH3Broken()
return response, nil
}
if !errors.Is(err, http3.ErrNoCachedConn) {
t.markH3Broken()
return t.h2FallbackRoundTrip(cloneRequestForRetry(request))
}
if !requestReplayable(request) {
response, err = t.h3Transport.RoundTrip(request)
if err == nil {
t.clearH3Broken()
return response, nil
}
t.markH3Broken()
return nil, err
}
return t.roundTripHTTP3Race(request)
}
func (t *http3FallbackTransport) roundTripHTTP3Race(request *http.Request) (*http.Response, error) {
ctx, cancel := context.WithCancel(request.Context())
defer cancel()
type result struct {
response *http.Response
err error
h3 bool
}
results := make(chan result, 2)
startRoundTrip := func(request *http.Request, useH3 bool) {
request = request.WithContext(ctx)
var (
response *http.Response
err error
)
if useH3 {
response, err = t.h3Transport.RoundTrip(request)
} else {
response, err = t.h2FallbackRoundTrip(request)
}
results <- result{response: response, err: err, h3: useH3}
}
goroutines := 1
received := 0
drainRemaining := func() {
cancel()
for range goroutines - received {
go func() {
loser := <-results
if loser.response != nil && loser.response.Body != nil {
loser.response.Body.Close()
}
}()
}
}
go startRoundTrip(cloneRequestForRetry(request), true)
timer := time.NewTimer(t.fallbackDelay)
defer timer.Stop()
var (
h3Err error
fallbackErr error
)
for {
select {
case <-timer.C:
if goroutines == 1 {
goroutines++
go startRoundTrip(cloneRequestForRetry(request), false)
}
case raceResult := <-results:
received++
if raceResult.err == nil {
if raceResult.h3 {
t.clearH3Broken()
}
drainRemaining()
return raceResult.response, nil
}
if raceResult.h3 {
t.markH3Broken()
h3Err = raceResult.err
if goroutines == 1 {
goroutines++
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
go startRoundTrip(cloneRequestForRetry(request), false)
}
} else {
fallbackErr = raceResult.err
}
if received < goroutines {
continue
}
drainRemaining()
switch {
case h3Err != nil && fallbackErr != nil:
return nil, E.Errors(h3Err, fallbackErr)
case fallbackErr != nil:
return nil, fallbackErr
default:
return nil, h3Err
}
}
}
}
func (t *http3FallbackTransport) h2FallbackRoundTrip(request *http.Request) (*http.Response, error) {
if fallback, isFallback := t.h2Fallback.(*http2FallbackTransport); isFallback {
return fallback.roundTrip(request, true)
}
return t.h2Fallback.RoundTrip(request)
}
func (t *http3FallbackTransport) CloseIdleConnections() {
t.h3Transport.CloseIdleConnections()
t.h2Fallback.CloseIdleConnections()
}
func (t *http3FallbackTransport) Close() error {
t.CloseIdleConnections()
return t.h3Transport.Close()
}
func (t *http3FallbackTransport) Clone() httpTransport {
return &http3FallbackTransport{
h3Transport: t.h3Transport,
h2Fallback: t.h2Fallback.Clone(),
fallbackDelay: t.fallbackDelay,
}
}
func (t *http3FallbackTransport) h3Broken() bool {
t.brokenAccess.Lock()
defer t.brokenAccess.Unlock()
return !t.brokenUntil.IsZero() && time.Now().Before(t.brokenUntil)
}
func (t *http3FallbackTransport) clearH3Broken() {
t.brokenAccess.Lock()
t.brokenUntil = time.Time{}
t.brokenBackoff = 0
t.brokenAccess.Unlock()
}
func (t *http3FallbackTransport) markH3Broken() {
t.brokenAccess.Lock()
defer t.brokenAccess.Unlock()
if t.brokenBackoff == 0 {
t.brokenBackoff = 5 * time.Minute
} else {
t.brokenBackoff *= 2
if t.brokenBackoff > 48*time.Hour {
t.brokenBackoff = 48 * time.Hour
}
}
t.brokenUntil = time.Now().Add(t.brokenBackoff)
}

View File

@@ -1,30 +0,0 @@
//go:build !with_quic
package httpclient
import (
"time"
"github.com/sagernet/sing-box/common/tls"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
)
func newHTTP3FallbackTransport(
rawDialer N.Dialer,
baseTLSConfig tls.Config,
h2Fallback httpTransport,
options option.QUICOptions,
fallbackDelay time.Duration,
) (httpTransport, error) {
return nil, E.New("HTTP/3 requires building with the with_quic tag")
}
func newHTTP3Transport(
rawDialer N.Dialer,
baseTLSConfig tls.Config,
options option.QUICOptions,
) (httpTransport, error) {
return nil, E.New("HTTP/3 requires building with the with_quic tag")
}

View File

@@ -1,136 +0,0 @@
package httpclient
import (
"context"
"net/http"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
)
var (
_ adapter.HTTPClientManager = (*Manager)(nil)
_ adapter.LifecycleService = (*Manager)(nil)
)
type Manager struct {
ctx context.Context
logger log.ContextLogger
access sync.Mutex
defines map[string]option.HTTPClient
clients map[string]*Client
defaultTag string
defaultTransport http.RoundTripper
defaultTransportFallback func() (*Client, error)
fallbackClient *Client
}
func NewManager(ctx context.Context, logger log.ContextLogger, clients []option.HTTPClient, defaultHTTPClient string) *Manager {
defines := make(map[string]option.HTTPClient, len(clients))
for _, client := range clients {
defines[client.Tag] = client
}
defaultTag := defaultHTTPClient
if defaultTag == "" && len(clients) > 0 {
defaultTag = clients[0].Tag
}
return &Manager{
ctx: ctx,
logger: logger,
defines: defines,
clients: make(map[string]*Client),
defaultTag: defaultTag,
}
}
func (m *Manager) Initialize(defaultTransportFallback func() (*Client, error)) {
m.defaultTransportFallback = defaultTransportFallback
}
func (m *Manager) Name() string {
return "http-client"
}
func (m *Manager) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if m.defaultTag != "" {
transport, err := m.resolveShared(m.defaultTag)
if err != nil {
return E.Cause(err, "resolve default http client")
}
m.defaultTransport = transport
} else if m.defaultTransportFallback != nil {
client, err := m.defaultTransportFallback()
if err != nil {
return E.Cause(err, "create default http client")
}
m.defaultTransport = client
m.fallbackClient = client
}
return nil
}
func (m *Manager) DefaultTransport() http.RoundTripper {
return m.defaultTransport
}
func (m *Manager) ResolveTransport(logger logger.ContextLogger, options option.HTTPClientOptions) (http.RoundTripper, error) {
if options.Tag != "" {
if options.ResolveOnDetour {
define, loaded := m.defines[options.Tag]
if !loaded {
return nil, E.New("http_client not found: ", options.Tag)
}
resolvedOptions := define.Options()
resolvedOptions.ResolveOnDetour = true
return NewClient(m.ctx, logger, options.Tag, resolvedOptions)
}
return m.resolveShared(options.Tag)
}
return NewClient(m.ctx, logger, "", options)
}
func (m *Manager) resolveShared(tag string) (http.RoundTripper, error) {
m.access.Lock()
defer m.access.Unlock()
if client, loaded := m.clients[tag]; loaded {
return client, nil
}
define, loaded := m.defines[tag]
if !loaded {
return nil, E.New("http_client not found: ", tag)
}
client, err := NewClient(m.ctx, m.logger, tag, define.Options())
if err != nil {
return nil, E.Cause(err, "create shared http_client[", tag, "]")
}
m.clients[tag] = client
return client, nil
}
func (m *Manager) Close() error {
m.access.Lock()
defer m.access.Unlock()
if m.clients == nil {
return nil
}
var err error
for _, client := range m.clients {
err = E.Append(err, client.Close(), func(err error) error {
return E.Cause(err, "close http client")
})
}
if m.fallbackClient != nil {
err = E.Append(err, m.fallbackClient.Close(), func(err error) error {
return E.Cause(err, "close default http client")
})
}
m.clients = nil
return err
}

View File

@@ -37,10 +37,7 @@ func (l *Listener) ListenTCP() (net.Listener, error) {
if l.listenOptions.ReuseAddr { if l.listenOptions.ReuseAddr {
listenConfig.Control = control.Append(listenConfig.Control, control.ReuseAddr()) listenConfig.Control = control.Append(listenConfig.Control, control.ReuseAddr())
} }
if l.listenOptions.DisableTCPKeepAlive { if !l.listenOptions.DisableTCPKeepAlive {
listenConfig.KeepAlive = -1
listenConfig.KeepAliveConfig.Enable = false
} else {
keepIdle := time.Duration(l.listenOptions.TCPKeepAlive) keepIdle := time.Duration(l.listenOptions.TCPKeepAlive)
if keepIdle == 0 { if keepIdle == 0 {
keepIdle = C.TCPKeepAliveInitial keepIdle = C.TCPKeepAliveInitial

View File

@@ -1,142 +0,0 @@
package networkquality
import (
"context"
"fmt"
"net"
"net/http"
"strings"
C "github.com/sagernet/sing-box/constant"
sBufio "github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func FormatBitrate(bps int64) string {
switch {
case bps >= 1_000_000_000:
return fmt.Sprintf("%.1f Gbps", float64(bps)/1_000_000_000)
case bps >= 1_000_000:
return fmt.Sprintf("%.1f Mbps", float64(bps)/1_000_000)
case bps >= 1_000:
return fmt.Sprintf("%.1f Kbps", float64(bps)/1_000)
default:
return fmt.Sprintf("%d bps", bps)
}
}
func NewHTTPClient(dialer N.Dialer) *http.Client {
transport := &http.Transport{
ForceAttemptHTTP2: true,
TLSHandshakeTimeout: C.TCPTimeout,
}
if dialer != nil {
transport.DialContext = func(ctx context.Context, network string, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
}
}
return &http.Client{Transport: transport}
}
func baseTransportFromClient(client *http.Client) (*http.Transport, error) {
if client == nil {
return nil, E.New("http client is nil")
}
if client.Transport == nil {
return http.DefaultTransport.(*http.Transport).Clone(), nil
}
transport, ok := client.Transport.(*http.Transport)
if !ok {
return nil, E.New("http client transport must be *http.Transport")
}
return transport.Clone(), nil
}
func newMeasurementClient(
baseClient *http.Client,
connectEndpoint string,
singleConnection bool,
disableKeepAlives bool,
readCounters []N.CountFunc,
writeCounters []N.CountFunc,
) (*http.Client, error) {
transport, err := baseTransportFromClient(baseClient)
if err != nil {
return nil, err
}
transport.DisableCompression = true
transport.DisableKeepAlives = disableKeepAlives
if singleConnection {
transport.MaxConnsPerHost = 1
transport.MaxIdleConnsPerHost = 1
transport.MaxIdleConns = 1
}
baseDialContext := transport.DialContext
if baseDialContext == nil {
dialer := &net.Dialer{}
baseDialContext = dialer.DialContext
}
transport.DialContext = func(ctx context.Context, network string, addr string) (net.Conn, error) {
dialAddr := addr
if connectEndpoint != "" {
dialAddr = rewriteDialAddress(addr, connectEndpoint)
}
conn, dialErr := baseDialContext(ctx, network, dialAddr)
if dialErr != nil {
return nil, dialErr
}
if len(readCounters) > 0 || len(writeCounters) > 0 {
return sBufio.NewCounterConn(conn, readCounters, writeCounters), nil
}
return conn, nil
}
return &http.Client{
Transport: transport,
CheckRedirect: baseClient.CheckRedirect,
Jar: baseClient.Jar,
Timeout: baseClient.Timeout,
}, nil
}
type MeasurementClientFactory func(
connectEndpoint string,
singleConnection bool,
disableKeepAlives bool,
readCounters []N.CountFunc,
writeCounters []N.CountFunc,
) (*http.Client, error)
func defaultMeasurementClientFactory(baseClient *http.Client) MeasurementClientFactory {
return func(connectEndpoint string, singleConnection, disableKeepAlives bool, readCounters, writeCounters []N.CountFunc) (*http.Client, error) {
return newMeasurementClient(baseClient, connectEndpoint, singleConnection, disableKeepAlives, readCounters, writeCounters)
}
}
func NewOptionalHTTP3Factory(dialer N.Dialer, useHTTP3 bool) (MeasurementClientFactory, error) {
if !useHTTP3 {
return nil, nil
}
return NewHTTP3MeasurementClientFactory(dialer)
}
func rewriteDialAddress(addr string, connectEndpoint string) string {
connectEndpoint = strings.TrimSpace(connectEndpoint)
host, port, err := net.SplitHostPort(addr)
if err != nil {
return addr
}
endpointHost, endpointPort, err := net.SplitHostPort(connectEndpoint)
if err == nil {
host = endpointHost
if endpointPort != "" {
port = endpointPort
}
} else if connectEndpoint != "" {
host = connectEndpoint
}
return net.JoinHostPort(host, port)
}

View File

@@ -1,55 +0,0 @@
//go:build with_quic
package networkquality
import (
"context"
"crypto/tls"
"net"
"net/http"
"github.com/sagernet/quic-go"
"github.com/sagernet/quic-go/http3"
sBufio "github.com/sagernet/sing/common/bufio"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func NewHTTP3MeasurementClientFactory(dialer N.Dialer) (MeasurementClientFactory, error) {
// singleConnection and disableKeepAlives are not applied:
// HTTP/3 multiplexes streams over a single QUIC connection by default.
return func(connectEndpoint string, _, _ bool, readCounters, writeCounters []N.CountFunc) (*http.Client, error) {
transport := &http3.Transport{
Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) {
dialAddr := addr
if connectEndpoint != "" {
dialAddr = rewriteDialAddress(addr, connectEndpoint)
}
destination := M.ParseSocksaddr(dialAddr)
var udpConn net.Conn
var dialErr error
if dialer != nil {
udpConn, dialErr = dialer.DialContext(ctx, N.NetworkUDP, destination)
} else {
var netDialer net.Dialer
udpConn, dialErr = netDialer.DialContext(ctx, N.NetworkUDP, destination.String())
}
if dialErr != nil {
return nil, dialErr
}
wrappedConn := udpConn
if len(readCounters) > 0 || len(writeCounters) > 0 {
wrappedConn = sBufio.NewCounterConn(udpConn, readCounters, writeCounters)
}
packetConn := sBufio.NewUnbindPacketConn(wrappedConn)
quicConn, dialErr := quic.DialEarly(ctx, packetConn, udpConn.RemoteAddr(), tlsCfg, cfg)
if dialErr != nil {
udpConn.Close()
return nil, dialErr
}
return quicConn, nil
},
}
return &http.Client{Transport: transport}, nil
}, nil
}

View File

@@ -1,12 +0,0 @@
//go:build !with_quic
package networkquality
import (
C "github.com/sagernet/sing-box/constant"
N "github.com/sagernet/sing/common/network"
)
func NewHTTP3MeasurementClientFactory(dialer N.Dialer) (MeasurementClientFactory, error) {
return nil, C.ErrQUICNotIncluded
}

File diff suppressed because it is too large Load Diff

View File

@@ -14,7 +14,6 @@ import (
type Searcher interface { type Searcher interface {
FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*adapter.ConnectionOwner, error) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*adapter.ConnectionOwner, error)
Close() error
} }
var ErrNotFound = E.New("process not found") var ErrNotFound = E.New("process not found")
@@ -29,7 +28,7 @@ func FindProcessInfo(searcher Searcher, ctx context.Context, network string, sou
if err != nil { if err != nil {
return nil, err return nil, err
} }
if info.UserId != -1 && info.UserName == "" { if info.UserId != -1 {
osUser, _ := user.LookupId(F.ToString(info.UserId)) osUser, _ := user.LookupId(F.ToString(info.UserId))
if osUser != nil { if osUser != nil {
info.UserName = osUser.Username info.UserName = osUser.Username

View File

@@ -6,7 +6,6 @@ import (
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-tun" "github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common"
) )
var _ Searcher = (*androidSearcher)(nil) var _ Searcher = (*androidSearcher)(nil)
@@ -19,30 +18,22 @@ func NewSearcher(config Config) (Searcher, error) {
return &androidSearcher{config.PackageManager}, nil return &androidSearcher{config.PackageManager}, nil
} }
func (s *androidSearcher) Close() error {
return nil
}
func (s *androidSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*adapter.ConnectionOwner, error) { func (s *androidSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*adapter.ConnectionOwner, error) {
family, protocol, err := socketDiagSettings(network, source) _, uid, err := resolveSocketByNetlink(network, source, destination)
if err != nil { if err != nil {
return nil, err return nil, err
} }
_, uid, err := querySocketDiagOnce(family, protocol, source) if sharedPackage, loaded := s.packageManager.SharedPackageByID(uid % 100000); loaded {
if err != nil { return &adapter.ConnectionOwner{
return nil, err UserId: int32(uid),
AndroidPackageName: sharedPackage,
}, nil
} }
appID := uid % 100000 if packageName, loaded := s.packageManager.PackageByID(uid % 100000); loaded {
var packageNames []string return &adapter.ConnectionOwner{
if sharedPackage, loaded := s.packageManager.SharedPackageByID(appID); loaded { UserId: int32(uid),
packageNames = append(packageNames, sharedPackage) AndroidPackageName: packageName,
}, nil
} }
if packages, loaded := s.packageManager.PackagesByID(appID); loaded { return &adapter.ConnectionOwner{UserId: int32(uid)}, nil
packageNames = append(packageNames, packages...)
}
packageNames = common.Uniq(packageNames)
return &adapter.ConnectionOwner{
UserId: int32(uid),
AndroidPackageNames: packageNames,
}, nil
} }

View File

@@ -1,15 +1,19 @@
//go:build darwin
package process package process
import ( import (
"context" "context"
"encoding/binary"
"net/netip" "net/netip"
"os"
"strconv" "strconv"
"strings" "strings"
"syscall" "syscall"
"unsafe"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
N "github.com/sagernet/sing/common/network"
"golang.org/x/sys/unix"
) )
var _ Searcher = (*darwinSearcher)(nil) var _ Searcher = (*darwinSearcher)(nil)
@@ -20,12 +24,12 @@ func NewSearcher(_ Config) (Searcher, error) {
return &darwinSearcher{}, nil return &darwinSearcher{}, nil
} }
func (d *darwinSearcher) Close() error {
return nil
}
func (d *darwinSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*adapter.ConnectionOwner, error) { func (d *darwinSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*adapter.ConnectionOwner, error) {
return FindDarwinConnectionOwner(network, source, destination) processName, err := findProcessName(network, source.Addr(), int(source.Port()))
if err != nil {
return nil, err
}
return &adapter.ConnectionOwner{ProcessPath: processName, UserId: -1}, nil
} }
var structSize = func() int { var structSize = func() int {
@@ -43,3 +47,107 @@ var structSize = func() int {
return 384 return 384
} }
}() }()
func findProcessName(network string, ip netip.Addr, port int) (string, error) {
var spath string
switch network {
case N.NetworkTCP:
spath = "net.inet.tcp.pcblist_n"
case N.NetworkUDP:
spath = "net.inet.udp.pcblist_n"
default:
return "", os.ErrInvalid
}
isIPv4 := ip.Is4()
value, err := unix.SysctlRaw(spath)
if err != nil {
return "", err
}
buf := value
// from darwin-xnu/bsd/netinet/in_pcblist.c:get_pcblist_n
// size/offset are round up (aligned) to 8 bytes in darwin
// rup8(sizeof(xinpcb_n)) + rup8(sizeof(xsocket_n)) +
// 2 * rup8(sizeof(xsockbuf_n)) + rup8(sizeof(xsockstat_n))
itemSize := structSize
if network == N.NetworkTCP {
// rup8(sizeof(xtcpcb_n))
itemSize += 208
}
var fallbackUDPProcess string
// skip the first xinpgen(24 bytes) block
for i := 24; i+itemSize <= len(buf); i += itemSize {
// offset of xinpcb_n and xsocket_n
inp, so := i, i+104
srcPort := binary.BigEndian.Uint16(buf[inp+18 : inp+20])
if uint16(port) != srcPort {
continue
}
// xinpcb_n.inp_vflag
flag := buf[inp+44]
var srcIP netip.Addr
srcIsIPv4 := false
switch {
case flag&0x1 > 0 && isIPv4:
// ipv4
srcIP = netip.AddrFrom4([4]byte(buf[inp+76 : inp+80]))
srcIsIPv4 = true
case flag&0x2 > 0 && !isIPv4:
// ipv6
srcIP = netip.AddrFrom16([16]byte(buf[inp+64 : inp+80]))
default:
continue
}
if ip == srcIP {
// xsocket_n.so_last_pid
pid := readNativeUint32(buf[so+68 : so+72])
return getExecPathFromPID(pid)
}
// udp packet connection may be not equal with srcIP
if network == N.NetworkUDP && srcIP.IsUnspecified() && isIPv4 == srcIsIPv4 {
pid := readNativeUint32(buf[so+68 : so+72])
fallbackUDPProcess, _ = getExecPathFromPID(pid)
}
}
if network == N.NetworkUDP && len(fallbackUDPProcess) > 0 {
return fallbackUDPProcess, nil
}
return "", ErrNotFound
}
func getExecPathFromPID(pid uint32) (string, error) {
const (
procpidpathinfo = 0xb
procpidpathinfosize = 1024
proccallnumpidinfo = 0x2
)
buf := make([]byte, procpidpathinfosize)
_, _, errno := syscall.Syscall6(
syscall.SYS_PROC_INFO,
proccallnumpidinfo,
uintptr(pid),
procpidpathinfo,
0,
uintptr(unsafe.Pointer(&buf[0])),
procpidpathinfosize)
if errno != 0 {
return "", errno
}
return unix.ByteSliceToString(buf), nil
}
func readNativeUint32(b []byte) uint32 {
return *(*uint32)(unsafe.Pointer(&b[0]))
}

View File

@@ -1,269 +0,0 @@
//go:build darwin
package process
import (
"encoding/binary"
"net/netip"
"os"
"sync"
"syscall"
"time"
"unsafe"
"github.com/sagernet/sing-box/adapter"
N "github.com/sagernet/sing/common/network"
"golang.org/x/sys/unix"
)
const (
darwinSnapshotTTL = 200 * time.Millisecond
darwinXinpgenSize = 24
darwinXsocketOffset = 104
darwinXinpcbForeignPort = 16
darwinXinpcbLocalPort = 18
darwinXinpcbVFlag = 44
darwinXinpcbForeignAddr = 48
darwinXinpcbLocalAddr = 64
darwinXinpcbIPv4Addr = 12
darwinXsocketUID = 64
darwinXsocketLastPID = 68
darwinTCPExtraStructSize = 208
)
type darwinConnectionEntry struct {
localAddr netip.Addr
remoteAddr netip.Addr
localPort uint16
remotePort uint16
pid uint32
uid int32
}
type darwinConnectionMatchKind uint8
const (
darwinConnectionMatchExact darwinConnectionMatchKind = iota
darwinConnectionMatchLocalFallback
darwinConnectionMatchWildcardFallback
)
type darwinSnapshot struct {
createdAt time.Time
entries []darwinConnectionEntry
}
type darwinConnectionFinder struct {
access sync.Mutex
ttl time.Duration
snapshots map[string]darwinSnapshot
builder func(string) (darwinSnapshot, error)
}
var sharedDarwinConnectionFinder = newDarwinConnectionFinder(darwinSnapshotTTL)
func newDarwinConnectionFinder(ttl time.Duration) *darwinConnectionFinder {
return &darwinConnectionFinder{
ttl: ttl,
snapshots: make(map[string]darwinSnapshot),
builder: buildDarwinSnapshot,
}
}
func FindDarwinConnectionOwner(network string, source netip.AddrPort, destination netip.AddrPort) (*adapter.ConnectionOwner, error) {
return sharedDarwinConnectionFinder.find(network, source, destination)
}
func (f *darwinConnectionFinder) find(network string, source netip.AddrPort, destination netip.AddrPort) (*adapter.ConnectionOwner, error) {
networkName := N.NetworkName(network)
source = normalizeDarwinAddrPort(source)
destination = normalizeDarwinAddrPort(destination)
var lastOwner *adapter.ConnectionOwner
for attempt := 0; attempt < 2; attempt++ {
snapshot, fromCache, err := f.loadSnapshot(networkName, attempt > 0)
if err != nil {
return nil, err
}
entry, matchKind, err := matchDarwinConnectionEntry(snapshot.entries, networkName, source, destination)
if err != nil {
if err == ErrNotFound && fromCache {
continue
}
return nil, err
}
if fromCache && matchKind != darwinConnectionMatchExact {
continue
}
owner := &adapter.ConnectionOwner{
UserId: entry.uid,
}
lastOwner = owner
if entry.pid == 0 {
return owner, nil
}
processPath, err := getExecPathFromPID(entry.pid)
if err == nil {
owner.ProcessPath = processPath
return owner, nil
}
if fromCache {
continue
}
return owner, nil
}
if lastOwner != nil {
return lastOwner, nil
}
return nil, ErrNotFound
}
func (f *darwinConnectionFinder) loadSnapshot(network string, forceRefresh bool) (darwinSnapshot, bool, error) {
f.access.Lock()
defer f.access.Unlock()
if !forceRefresh {
if snapshot, loaded := f.snapshots[network]; loaded && time.Since(snapshot.createdAt) < f.ttl {
return snapshot, true, nil
}
}
snapshot, err := f.builder(network)
if err != nil {
return darwinSnapshot{}, false, err
}
f.snapshots[network] = snapshot
return snapshot, false, nil
}
func buildDarwinSnapshot(network string) (darwinSnapshot, error) {
spath, itemSize, err := darwinSnapshotSettings(network)
if err != nil {
return darwinSnapshot{}, err
}
value, err := unix.SysctlRaw(spath)
if err != nil {
return darwinSnapshot{}, err
}
return darwinSnapshot{
createdAt: time.Now(),
entries: parseDarwinSnapshot(value, itemSize),
}, nil
}
func darwinSnapshotSettings(network string) (string, int, error) {
itemSize := structSize
switch network {
case N.NetworkTCP:
return "net.inet.tcp.pcblist_n", itemSize + darwinTCPExtraStructSize, nil
case N.NetworkUDP:
return "net.inet.udp.pcblist_n", itemSize, nil
default:
return "", 0, os.ErrInvalid
}
}
func parseDarwinSnapshot(buf []byte, itemSize int) []darwinConnectionEntry {
entries := make([]darwinConnectionEntry, 0, (len(buf)-darwinXinpgenSize)/itemSize)
for i := darwinXinpgenSize; i+itemSize <= len(buf); i += itemSize {
inp := i
so := i + darwinXsocketOffset
entry, ok := parseDarwinConnectionEntry(buf[inp:so], buf[so:so+structSize-darwinXsocketOffset])
if ok {
entries = append(entries, entry)
}
}
return entries
}
func parseDarwinConnectionEntry(inp []byte, so []byte) (darwinConnectionEntry, bool) {
if len(inp) < darwinXsocketOffset || len(so) < structSize-darwinXsocketOffset {
return darwinConnectionEntry{}, false
}
entry := darwinConnectionEntry{
remotePort: binary.BigEndian.Uint16(inp[darwinXinpcbForeignPort : darwinXinpcbForeignPort+2]),
localPort: binary.BigEndian.Uint16(inp[darwinXinpcbLocalPort : darwinXinpcbLocalPort+2]),
pid: binary.NativeEndian.Uint32(so[darwinXsocketLastPID : darwinXsocketLastPID+4]),
uid: int32(binary.NativeEndian.Uint32(so[darwinXsocketUID : darwinXsocketUID+4])),
}
flag := inp[darwinXinpcbVFlag]
switch {
case flag&0x1 != 0:
entry.remoteAddr = netip.AddrFrom4([4]byte(inp[darwinXinpcbForeignAddr+darwinXinpcbIPv4Addr : darwinXinpcbForeignAddr+darwinXinpcbIPv4Addr+4]))
entry.localAddr = netip.AddrFrom4([4]byte(inp[darwinXinpcbLocalAddr+darwinXinpcbIPv4Addr : darwinXinpcbLocalAddr+darwinXinpcbIPv4Addr+4]))
return entry, true
case flag&0x2 != 0:
entry.remoteAddr = netip.AddrFrom16([16]byte(inp[darwinXinpcbForeignAddr : darwinXinpcbForeignAddr+16]))
entry.localAddr = netip.AddrFrom16([16]byte(inp[darwinXinpcbLocalAddr : darwinXinpcbLocalAddr+16]))
return entry, true
default:
return darwinConnectionEntry{}, false
}
}
func matchDarwinConnectionEntry(entries []darwinConnectionEntry, network string, source netip.AddrPort, destination netip.AddrPort) (darwinConnectionEntry, darwinConnectionMatchKind, error) {
sourceAddr := source.Addr()
if !sourceAddr.IsValid() {
return darwinConnectionEntry{}, darwinConnectionMatchExact, os.ErrInvalid
}
var localFallback darwinConnectionEntry
var hasLocalFallback bool
var wildcardFallback darwinConnectionEntry
var hasWildcardFallback bool
for _, entry := range entries {
if entry.localPort != source.Port() || sourceAddr.BitLen() != entry.localAddr.BitLen() {
continue
}
if entry.localAddr == sourceAddr && destination.IsValid() && entry.remotePort == destination.Port() && entry.remoteAddr == destination.Addr() {
return entry, darwinConnectionMatchExact, nil
}
if !destination.IsValid() && entry.localAddr == sourceAddr {
return entry, darwinConnectionMatchExact, nil
}
if network != N.NetworkUDP {
continue
}
if !hasLocalFallback && entry.localAddr == sourceAddr {
hasLocalFallback = true
localFallback = entry
}
if !hasWildcardFallback && entry.localAddr.IsUnspecified() {
hasWildcardFallback = true
wildcardFallback = entry
}
}
if hasLocalFallback {
return localFallback, darwinConnectionMatchLocalFallback, nil
}
if hasWildcardFallback {
return wildcardFallback, darwinConnectionMatchWildcardFallback, nil
}
return darwinConnectionEntry{}, darwinConnectionMatchExact, ErrNotFound
}
func normalizeDarwinAddrPort(addrPort netip.AddrPort) netip.AddrPort {
if !addrPort.IsValid() {
return addrPort
}
return netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())
}
func getExecPathFromPID(pid uint32) (string, error) {
const (
procpidpathinfo = 0xb
procpidpathinfosize = 1024
proccallnumpidinfo = 0x2
)
buf := make([]byte, procpidpathinfosize)
_, _, errno := syscall.Syscall6(
syscall.SYS_PROC_INFO,
proccallnumpidinfo,
uintptr(pid),
procpidpathinfo,
0,
uintptr(unsafe.Pointer(&buf[0])),
procpidpathinfosize)
if errno != 0 {
return "", errno
}
return unix.ByteSliceToString(buf), nil
}

View File

@@ -4,82 +4,33 @@ package process
import ( import (
"context" "context"
"errors"
"net/netip" "net/netip"
"syscall"
"time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
E "github.com/sagernet/sing/common/exceptions"
) )
var _ Searcher = (*linuxSearcher)(nil) var _ Searcher = (*linuxSearcher)(nil)
type linuxSearcher struct { type linuxSearcher struct {
logger log.ContextLogger logger log.ContextLogger
diagConns [4]*socketDiagConn
processPathCache *uidProcessPathCache
} }
func NewSearcher(config Config) (Searcher, error) { func NewSearcher(config Config) (Searcher, error) {
searcher := &linuxSearcher{ return &linuxSearcher{config.Logger}, nil
logger: config.Logger,
processPathCache: newUIDProcessPathCache(time.Second),
}
for _, family := range []uint8{syscall.AF_INET, syscall.AF_INET6} {
for _, protocol := range []uint8{syscall.IPPROTO_TCP, syscall.IPPROTO_UDP} {
searcher.diagConns[socketDiagConnIndex(family, protocol)] = newSocketDiagConn(family, protocol)
}
}
return searcher, nil
}
func (s *linuxSearcher) Close() error {
var errs []error
for _, conn := range s.diagConns {
if conn == nil {
continue
}
errs = append(errs, conn.Close())
}
return E.Errors(errs...)
} }
func (s *linuxSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*adapter.ConnectionOwner, error) { func (s *linuxSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*adapter.ConnectionOwner, error) {
inode, uid, err := s.resolveSocketByNetlink(network, source, destination) inode, uid, err := resolveSocketByNetlink(network, source, destination)
if err != nil { if err != nil {
return nil, err return nil, err
} }
processInfo := &adapter.ConnectionOwner{ processPath, err := resolveProcessNameByProcSearch(inode, uid)
UserId: int32(uid),
}
processPath, err := s.processPathCache.findProcessPath(inode, uid)
if err != nil { if err != nil {
s.logger.DebugContext(ctx, "find process path: ", err) s.logger.DebugContext(ctx, "find process path: ", err)
} else {
processInfo.ProcessPath = processPath
} }
return processInfo, nil return &adapter.ConnectionOwner{
} UserId: int32(uid),
ProcessPath: processPath,
func (s *linuxSearcher) resolveSocketByNetlink(network string, source netip.AddrPort, destination netip.AddrPort) (inode, uid uint32, err error) { }, nil
family, protocol, err := socketDiagSettings(network, source)
if err != nil {
return 0, 0, err
}
conn := s.diagConns[socketDiagConnIndex(family, protocol)]
if conn == nil {
return 0, 0, E.New("missing socket diag connection for family=", family, " protocol=", protocol)
}
if destination.IsValid() && source.Addr().BitLen() == destination.Addr().BitLen() {
inode, uid, err = conn.query(source, destination)
if err == nil {
return inode, uid, nil
}
if !errors.Is(err, ErrNotFound) {
return 0, 0, err
}
}
return querySocketDiagOnce(family, protocol, source)
} }

View File

@@ -3,67 +3,43 @@
package process package process
import ( import (
"bytes"
"encoding/binary" "encoding/binary"
"errors" "fmt"
"net"
"net/netip" "net/netip"
"os" "os"
"path/filepath" "path"
"strings" "strings"
"sync"
"syscall" "syscall"
"time"
"unicode" "unicode"
"unsafe"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/contrab/freelru"
"github.com/sagernet/sing/contrab/maphash"
) )
// from https://github.com/vishvananda/netlink/blob/bca67dfc8220b44ef582c9da4e9172bf1c9ec973/nl/nl_linux.go#L52-L62
var nativeEndian = func() binary.ByteOrder {
var x uint32 = 0x01020304
if *(*byte)(unsafe.Pointer(&x)) == 0x01 {
return binary.BigEndian
}
return binary.LittleEndian
}()
const ( const (
sizeOfSocketDiagRequestData = 56 sizeOfSocketDiagRequest = syscall.SizeofNlMsghdr + 8 + 48
sizeOfSocketDiagRequest = syscall.SizeofNlMsghdr + sizeOfSocketDiagRequestData socketDiagByFamily = 20
socketDiagResponseMinSize = 72 pathProc = "/proc"
socketDiagByFamily = 20
pathProc = "/proc"
) )
type socketDiagConn struct { func resolveSocketByNetlink(network string, source netip.AddrPort, destination netip.AddrPort) (inode, uid uint32, err error) {
access sync.Mutex var family uint8
family uint8 var protocol uint8
protocol uint8
fd int
}
type uidProcessPathCache struct {
cache freelru.Cache[uint32, *uidProcessPaths]
}
type uidProcessPaths struct {
entries map[uint32]string
}
func newSocketDiagConn(family, protocol uint8) *socketDiagConn {
return &socketDiagConn{
family: family,
protocol: protocol,
fd: -1,
}
}
func socketDiagConnIndex(family, protocol uint8) int {
index := 0
if protocol == syscall.IPPROTO_UDP {
index += 2
}
if family == syscall.AF_INET6 {
index++
}
return index
}
func socketDiagSettings(network string, source netip.AddrPort) (family, protocol uint8, err error) {
switch network { switch network {
case N.NetworkTCP: case N.NetworkTCP:
protocol = syscall.IPPROTO_TCP protocol = syscall.IPPROTO_TCP
@@ -72,308 +48,151 @@ func socketDiagSettings(network string, source netip.AddrPort) (family, protocol
default: default:
return 0, 0, os.ErrInvalid return 0, 0, os.ErrInvalid
} }
switch {
case source.Addr().Is4(): if source.Addr().Is4() {
family = syscall.AF_INET family = syscall.AF_INET
case source.Addr().Is6(): } else {
family = syscall.AF_INET6 family = syscall.AF_INET6
default:
return 0, 0, os.ErrInvalid
} }
return family, protocol, nil
}
func newUIDProcessPathCache(ttl time.Duration) *uidProcessPathCache { req := packSocketDiagRequest(family, protocol, source)
cache := common.Must1(freelru.NewSharded[uint32, *uidProcessPaths](64, maphash.NewHasher[uint32]().Hash32))
cache.SetLifetime(ttl)
return &uidProcessPathCache{cache: cache}
}
func (c *uidProcessPathCache) findProcessPath(targetInode, uid uint32) (string, error) { socket, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_DGRAM, syscall.NETLINK_INET_DIAG)
if cached, ok := c.cache.Get(uid); ok {
if processPath, found := cached.entries[targetInode]; found {
return processPath, nil
}
}
processPaths, err := buildProcessPathByUIDCache(uid)
if err != nil {
return "", err
}
c.cache.Add(uid, &uidProcessPaths{entries: processPaths})
processPath, found := processPaths[targetInode]
if !found {
return "", E.New("process of uid(", uid, "), inode(", targetInode, ") not found")
}
return processPath, nil
}
func (c *socketDiagConn) Close() error {
c.access.Lock()
defer c.access.Unlock()
return c.closeLocked()
}
func (c *socketDiagConn) query(source netip.AddrPort, destination netip.AddrPort) (inode, uid uint32, err error) {
c.access.Lock()
defer c.access.Unlock()
request := packSocketDiagRequest(c.family, c.protocol, source, destination, false)
for attempt := 0; attempt < 2; attempt++ {
err = c.ensureOpenLocked()
if err != nil {
return 0, 0, E.Cause(err, "dial netlink")
}
inode, uid, err = querySocketDiag(c.fd, request)
if err == nil || errors.Is(err, ErrNotFound) {
return inode, uid, err
}
if !shouldRetrySocketDiag(err) {
return 0, 0, err
}
_ = c.closeLocked()
}
return 0, 0, err
}
func querySocketDiagOnce(family, protocol uint8, source netip.AddrPort) (inode, uid uint32, err error) {
fd, err := openSocketDiag()
if err != nil { if err != nil {
return 0, 0, E.Cause(err, "dial netlink") return 0, 0, E.Cause(err, "dial netlink")
} }
defer syscall.Close(fd) defer syscall.Close(socket)
return querySocketDiag(fd, packSocketDiagRequest(family, protocol, source, netip.AddrPort{}, true))
}
func (c *socketDiagConn) ensureOpenLocked() error { syscall.SetsockoptTimeval(socket, syscall.SOL_SOCKET, syscall.SO_SNDTIMEO, &syscall.Timeval{Usec: 100})
if c.fd != -1 { syscall.SetsockoptTimeval(socket, syscall.SOL_SOCKET, syscall.SO_RCVTIMEO, &syscall.Timeval{Usec: 100})
return nil
}
fd, err := openSocketDiag()
if err != nil {
return err
}
c.fd = fd
return nil
}
func openSocketDiag() (int, error) { err = syscall.Connect(socket, &syscall.SockaddrNetlink{
fd, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_DGRAM|syscall.SOCK_CLOEXEC, syscall.NETLINK_INET_DIAG)
if err != nil {
return -1, err
}
timeout := &syscall.Timeval{Usec: 100}
if err = syscall.SetsockoptTimeval(fd, syscall.SOL_SOCKET, syscall.SO_SNDTIMEO, timeout); err != nil {
syscall.Close(fd)
return -1, err
}
if err = syscall.SetsockoptTimeval(fd, syscall.SOL_SOCKET, syscall.SO_RCVTIMEO, timeout); err != nil {
syscall.Close(fd)
return -1, err
}
if err = syscall.Connect(fd, &syscall.SockaddrNetlink{
Family: syscall.AF_NETLINK, Family: syscall.AF_NETLINK,
Pad: 0,
Pid: 0, Pid: 0,
Groups: 0, Groups: 0,
}); err != nil { })
syscall.Close(fd) if err != nil {
return -1, err return
} }
return fd, nil
}
func (c *socketDiagConn) closeLocked() error { _, err = syscall.Write(socket, req)
if c.fd == -1 {
return nil
}
err := syscall.Close(c.fd)
c.fd = -1
return err
}
func packSocketDiagRequest(family, protocol byte, source netip.AddrPort, destination netip.AddrPort, dump bool) []byte {
request := make([]byte, sizeOfSocketDiagRequest)
binary.NativeEndian.PutUint32(request[0:4], sizeOfSocketDiagRequest)
binary.NativeEndian.PutUint16(request[4:6], socketDiagByFamily)
flags := uint16(syscall.NLM_F_REQUEST)
if dump {
flags |= syscall.NLM_F_DUMP
}
binary.NativeEndian.PutUint16(request[6:8], flags)
binary.NativeEndian.PutUint32(request[8:12], 0)
binary.NativeEndian.PutUint32(request[12:16], 0)
request[16] = family
request[17] = protocol
request[18] = 0
request[19] = 0
if dump {
binary.NativeEndian.PutUint32(request[20:24], 0xFFFFFFFF)
}
requestSource := source
requestDestination := destination
if protocol == syscall.IPPROTO_UDP && !dump && destination.IsValid() {
// udp_dump_one expects the exact-match endpoints reversed for historical reasons.
requestSource, requestDestination = destination, source
}
binary.BigEndian.PutUint16(request[24:26], requestSource.Port())
binary.BigEndian.PutUint16(request[26:28], requestDestination.Port())
if family == syscall.AF_INET6 {
copy(request[28:44], requestSource.Addr().AsSlice())
if requestDestination.IsValid() {
copy(request[44:60], requestDestination.Addr().AsSlice())
}
} else {
copy(request[28:32], requestSource.Addr().AsSlice())
if requestDestination.IsValid() {
copy(request[44:48], requestDestination.Addr().AsSlice())
}
}
binary.NativeEndian.PutUint32(request[60:64], 0)
binary.NativeEndian.PutUint64(request[64:72], 0xFFFFFFFFFFFFFFFF)
return request
}
func querySocketDiag(fd int, request []byte) (inode, uid uint32, err error) {
_, err = syscall.Write(fd, request)
if err != nil { if err != nil {
return 0, 0, E.Cause(err, "write netlink request") return 0, 0, E.Cause(err, "write netlink request")
} }
buffer := make([]byte, 64<<10)
n, err := syscall.Read(fd, buffer) buffer := buf.New()
defer buffer.Release()
n, err := syscall.Read(socket, buffer.FreeBytes())
if err != nil { if err != nil {
return 0, 0, E.Cause(err, "read netlink response") return 0, 0, E.Cause(err, "read netlink response")
} }
messages, err := syscall.ParseNetlinkMessage(buffer[:n])
buffer.Truncate(n)
messages, err := syscall.ParseNetlinkMessage(buffer.Bytes())
if err != nil { if err != nil {
return 0, 0, E.Cause(err, "parse netlink message") return 0, 0, E.Cause(err, "parse netlink message")
} else if len(messages) == 0 {
return 0, 0, E.New("unexcepted netlink response")
} }
return unpackSocketDiagMessages(messages)
message := messages[0]
if message.Header.Type&syscall.NLMSG_ERROR != 0 {
return 0, 0, E.New("netlink message: NLMSG_ERROR")
}
inode, uid = unpackSocketDiagResponse(&messages[0])
return
} }
func unpackSocketDiagMessages(messages []syscall.NetlinkMessage) (inode, uid uint32, err error) { func packSocketDiagRequest(family, protocol byte, source netip.AddrPort) []byte {
for _, message := range messages { s := make([]byte, 16)
switch message.Header.Type { copy(s, source.Addr().AsSlice())
case syscall.NLMSG_DONE:
continue buf := make([]byte, sizeOfSocketDiagRequest)
case syscall.NLMSG_ERROR:
err = unpackSocketDiagError(&message) nativeEndian.PutUint32(buf[0:4], sizeOfSocketDiagRequest)
if err != nil { nativeEndian.PutUint16(buf[4:6], socketDiagByFamily)
return 0, 0, err nativeEndian.PutUint16(buf[6:8], syscall.NLM_F_REQUEST|syscall.NLM_F_DUMP)
} nativeEndian.PutUint32(buf[8:12], 0)
case socketDiagByFamily: nativeEndian.PutUint32(buf[12:16], 0)
inode, uid = unpackSocketDiagResponse(&message)
if inode != 0 || uid != 0 { buf[16] = family
return inode, uid, nil buf[17] = protocol
} buf[18] = 0
} buf[19] = 0
} nativeEndian.PutUint32(buf[20:24], 0xFFFFFFFF)
return 0, 0, ErrNotFound
binary.BigEndian.PutUint16(buf[24:26], source.Port())
binary.BigEndian.PutUint16(buf[26:28], 0)
copy(buf[28:44], s)
copy(buf[44:60], net.IPv6zero)
nativeEndian.PutUint32(buf[60:64], 0)
nativeEndian.PutUint64(buf[64:72], 0xFFFFFFFFFFFFFFFF)
return buf
} }
func unpackSocketDiagResponse(msg *syscall.NetlinkMessage) (inode, uid uint32) { func unpackSocketDiagResponse(msg *syscall.NetlinkMessage) (inode, uid uint32) {
if len(msg.Data) < socketDiagResponseMinSize { if len(msg.Data) < 72 {
return 0, 0 return 0, 0
} }
uid = binary.NativeEndian.Uint32(msg.Data[64:68])
inode = binary.NativeEndian.Uint32(msg.Data[68:72]) data := msg.Data
return inode, uid
uid = nativeEndian.Uint32(data[64:68])
inode = nativeEndian.Uint32(data[68:72])
return
} }
func unpackSocketDiagError(msg *syscall.NetlinkMessage) error { func resolveProcessNameByProcSearch(inode, uid uint32) (string, error) {
if len(msg.Data) < 4 {
return E.New("netlink message: NLMSG_ERROR")
}
errno := int32(binary.NativeEndian.Uint32(msg.Data[:4]))
if errno == 0 {
return nil
}
if errno < 0 {
errno = -errno
}
sysErr := syscall.Errno(errno)
switch sysErr {
case syscall.ENOENT, syscall.ESRCH:
return ErrNotFound
default:
return E.New("netlink message: ", sysErr)
}
}
func shouldRetrySocketDiag(err error) bool {
return err != nil && !errors.Is(err, ErrNotFound)
}
func buildProcessPathByUIDCache(uid uint32) (map[uint32]string, error) {
files, err := os.ReadDir(pathProc) files, err := os.ReadDir(pathProc)
if err != nil { if err != nil {
return nil, err return "", err
} }
buffer := make([]byte, syscall.PathMax) buffer := make([]byte, syscall.PathMax)
processPaths := make(map[uint32]string) socket := []byte(fmt.Sprintf("socket:[%d]", inode))
for _, file := range files {
if !file.IsDir() || !isPid(file.Name()) { for _, f := range files {
if !f.IsDir() || !isPid(f.Name()) {
continue continue
} }
info, err := file.Info()
info, err := f.Info()
if err != nil { if err != nil {
if isIgnorableProcError(err) { return "", err
continue
}
return nil, err
} }
if info.Sys().(*syscall.Stat_t).Uid != uid { if info.Sys().(*syscall.Stat_t).Uid != uid {
continue continue
} }
processPath := filepath.Join(pathProc, file.Name())
fdPath := filepath.Join(processPath, "fd") processPath := path.Join(pathProc, f.Name())
exePath, err := os.Readlink(filepath.Join(processPath, "exe")) fdPath := path.Join(processPath, "fd")
if err != nil {
if isIgnorableProcError(err) {
continue
}
return nil, err
}
fds, err := os.ReadDir(fdPath) fds, err := os.ReadDir(fdPath)
if err != nil { if err != nil {
continue continue
} }
for _, fd := range fds { for _, fd := range fds {
n, err := syscall.Readlink(filepath.Join(fdPath, fd.Name()), buffer) n, err := syscall.Readlink(path.Join(fdPath, fd.Name()), buffer)
if err != nil { if err != nil {
continue continue
} }
inode, ok := parseSocketInode(buffer[:n])
if !ok { if bytes.Equal(buffer[:n], socket) {
continue return os.Readlink(path.Join(processPath, "exe"))
}
if _, loaded := processPaths[inode]; !loaded {
processPaths[inode] = exePath
} }
} }
} }
return processPaths, nil
}
func isIgnorableProcError(err error) bool { return "", fmt.Errorf("process of uid(%d),inode(%d) not found", uid, inode)
return os.IsNotExist(err) || os.IsPermission(err)
}
func parseSocketInode(link []byte) (uint32, bool) {
const socketPrefix = "socket:["
if len(link) <= len(socketPrefix) || string(link[:len(socketPrefix)]) != socketPrefix || link[len(link)-1] != ']' {
return 0, false
}
var inode uint64
for _, char := range link[len(socketPrefix) : len(link)-1] {
if char < '0' || char > '9' {
return 0, false
}
inode = inode*10 + uint64(char-'0')
if inode > uint64(^uint32(0)) {
return 0, false
}
}
return uint32(inode), true
} }
func isPid(s string) bool { func isPid(s string) bool {

View File

@@ -1,60 +0,0 @@
//go:build linux
package process
import (
"net"
"net/netip"
"os"
"syscall"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestQuerySocketDiagUDPExact(t *testing.T) {
t.Parallel()
server, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
defer server.Close()
client, err := net.DialUDP("udp4", nil, server.LocalAddr().(*net.UDPAddr))
require.NoError(t, err)
defer client.Close()
err = client.SetDeadline(time.Now().Add(time.Second))
require.NoError(t, err)
_, err = client.Write([]byte{0})
require.NoError(t, err)
err = server.SetReadDeadline(time.Now().Add(time.Second))
require.NoError(t, err)
buffer := make([]byte, 1)
_, _, err = server.ReadFromUDP(buffer)
require.NoError(t, err)
source := addrPortFromUDPAddr(t, client.LocalAddr())
destination := addrPortFromUDPAddr(t, client.RemoteAddr())
fd, err := openSocketDiag()
require.NoError(t, err)
defer syscall.Close(fd)
inode, uid, err := querySocketDiag(fd, packSocketDiagRequest(syscall.AF_INET, syscall.IPPROTO_UDP, source, destination, false))
require.NoError(t, err)
require.NotZero(t, inode)
require.EqualValues(t, os.Getuid(), uid)
}
func addrPortFromUDPAddr(t *testing.T, addr net.Addr) netip.AddrPort {
t.Helper()
udpAddr, ok := addr.(*net.UDPAddr)
require.True(t, ok)
ip, ok := netip.AddrFromSlice(udpAddr.IP)
require.True(t, ok)
return netip.AddrPortFrom(ip.Unmap(), uint16(udpAddr.Port))
}

View File

@@ -28,10 +28,6 @@ func initWin32API() error {
return winiphlpapi.LoadExtendedTable() return winiphlpapi.LoadExtendedTable()
} }
func (s *windowsSearcher) Close() error {
return nil
}
func (s *windowsSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*adapter.ConnectionOwner, error) { func (s *windowsSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*adapter.ConnectionOwner, error) {
pid, err := winiphlpapi.FindPid(network, source) pid, err := winiphlpapi.FindPid(network, source)
if err != nil { if err != nil {

View File

@@ -46,7 +46,6 @@ const (
ruleItemNetworkIsConstrained ruleItemNetworkIsConstrained
ruleItemNetworkInterfaceAddress ruleItemNetworkInterfaceAddress
ruleItemDefaultInterfaceAddress ruleItemDefaultInterfaceAddress
ruleItemPackageNameRegex
ruleItemFinal uint8 = 0xFF ruleItemFinal uint8 = 0xFF
) )
@@ -216,8 +215,6 @@ func readDefaultRule(reader varbin.Reader, recover bool) (rule option.DefaultHea
rule.ProcessPathRegex, err = readRuleItemString(reader) rule.ProcessPathRegex, err = readRuleItemString(reader)
case ruleItemPackageName: case ruleItemPackageName:
rule.PackageName, err = readRuleItemString(reader) rule.PackageName, err = readRuleItemString(reader)
case ruleItemPackageNameRegex:
rule.PackageNameRegex, err = readRuleItemString(reader)
case ruleItemWIFISSID: case ruleItemWIFISSID:
rule.WIFISSID, err = readRuleItemString(reader) rule.WIFISSID, err = readRuleItemString(reader)
case ruleItemWIFIBSSID: case ruleItemWIFIBSSID:
@@ -397,15 +394,6 @@ func writeDefaultRule(writer varbin.Writer, rule option.DefaultHeadlessRule, gen
return err return err
} }
} }
if len(rule.PackageNameRegex) > 0 {
if generateVersion < C.RuleSetVersion5 {
return E.New("`package_name_regex` rule item is only supported in version 5 or later")
}
err = writeRuleItemString(writer, ruleItemPackageNameRegex, rule.PackageNameRegex)
if err != nil {
return err
}
}
if len(rule.NetworkType) > 0 { if len(rule.NetworkType) > 0 {
if generateVersion < C.RuleSetVersion3 { if generateVersion < C.RuleSetVersion3 {
return E.New("`network_type` rule item is only supported in version 3 or later") return E.New("`network_type` rule item is only supported in version 3 or later")

View File

@@ -1,612 +0,0 @@
package stun
import (
"context"
"crypto/rand"
"encoding/binary"
"fmt"
"net"
"net/netip"
"time"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/bufio/deadline"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
const (
DefaultServer = "stun.voipgate.com:3478"
magicCookie = 0x2112A442
headerSize = 20
bindingRequest = 0x0001
bindingSuccessResponse = 0x0101
bindingErrorResponse = 0x0111
attrMappedAddress = 0x0001
attrChangeRequest = 0x0003
attrErrorCode = 0x0009
attrXORMappedAddress = 0x0020
attrOtherAddress = 0x802c
familyIPv4 = 0x01
familyIPv6 = 0x02
changeIP = 0x04
changePort = 0x02
defaultRTO = 500 * time.Millisecond
minRTO = 250 * time.Millisecond
maxRetransmit = 2
)
type Phase int32
const (
PhaseBinding Phase = iota
PhaseNATMapping
PhaseNATFiltering
PhaseDone
)
type NATMapping int32
const (
NATMappingUnknown NATMapping = iota
_ // reserved
NATMappingEndpointIndependent
NATMappingAddressDependent
NATMappingAddressAndPortDependent
)
func (m NATMapping) String() string {
switch m {
case NATMappingEndpointIndependent:
return "Endpoint Independent"
case NATMappingAddressDependent:
return "Address Dependent"
case NATMappingAddressAndPortDependent:
return "Address and Port Dependent"
default:
return "Unknown"
}
}
type NATFiltering int32
const (
NATFilteringUnknown NATFiltering = iota
NATFilteringEndpointIndependent
NATFilteringAddressDependent
NATFilteringAddressAndPortDependent
)
func (f NATFiltering) String() string {
switch f {
case NATFilteringEndpointIndependent:
return "Endpoint Independent"
case NATFilteringAddressDependent:
return "Address Dependent"
case NATFilteringAddressAndPortDependent:
return "Address and Port Dependent"
default:
return "Unknown"
}
}
type TransactionID [12]byte
type Options struct {
Server string
Dialer N.Dialer
Context context.Context
OnProgress func(Progress)
}
type Progress struct {
Phase Phase
ExternalAddr string
LatencyMs int32
NATMapping NATMapping
NATFiltering NATFiltering
}
type Result struct {
ExternalAddr string
LatencyMs int32
NATMapping NATMapping
NATFiltering NATFiltering
NATTypeSupported bool
}
type parsedResponse struct {
xorMappedAddr netip.AddrPort
mappedAddr netip.AddrPort
otherAddr netip.AddrPort
}
func (r *parsedResponse) externalAddr() (netip.AddrPort, bool) {
if r.xorMappedAddr.IsValid() {
return r.xorMappedAddr, true
}
if r.mappedAddr.IsValid() {
return r.mappedAddr, true
}
return netip.AddrPort{}, false
}
type stunAttribute struct {
typ uint16
value []byte
}
func newTransactionID() TransactionID {
var id TransactionID
_, _ = rand.Read(id[:])
return id
}
func buildBindingRequest(txID TransactionID, attrs ...stunAttribute) []byte {
attrLen := 0
for _, attr := range attrs {
attrLen += 4 + len(attr.value) + paddingLen(len(attr.value))
}
buf := make([]byte, headerSize+attrLen)
binary.BigEndian.PutUint16(buf[0:2], bindingRequest)
binary.BigEndian.PutUint16(buf[2:4], uint16(attrLen))
binary.BigEndian.PutUint32(buf[4:8], magicCookie)
copy(buf[8:20], txID[:])
offset := headerSize
for _, attr := range attrs {
binary.BigEndian.PutUint16(buf[offset:offset+2], attr.typ)
binary.BigEndian.PutUint16(buf[offset+2:offset+4], uint16(len(attr.value)))
copy(buf[offset+4:offset+4+len(attr.value)], attr.value)
offset += 4 + len(attr.value) + paddingLen(len(attr.value))
}
return buf
}
func changeRequestAttr(flags byte) stunAttribute {
return stunAttribute{
typ: attrChangeRequest,
value: []byte{0, 0, 0, flags},
}
}
func parseResponse(data []byte, expectedTxID TransactionID) (*parsedResponse, error) {
if len(data) < headerSize {
return nil, E.New("response too short")
}
msgType := binary.BigEndian.Uint16(data[0:2])
if msgType&0xC000 != 0 {
return nil, E.New("invalid STUN message: top 2 bits not zero")
}
cookie := binary.BigEndian.Uint32(data[4:8])
if cookie != magicCookie {
return nil, E.New("invalid magic cookie")
}
var txID TransactionID
copy(txID[:], data[8:20])
if txID != expectedTxID {
return nil, E.New("transaction ID mismatch")
}
msgLen := int(binary.BigEndian.Uint16(data[2:4]))
if msgLen > len(data)-headerSize {
return nil, E.New("message length exceeds data")
}
attrData := data[headerSize : headerSize+msgLen]
if msgType == bindingErrorResponse {
return nil, parseErrorResponse(attrData)
}
if msgType != bindingSuccessResponse {
return nil, E.New("unexpected message type: ", fmt.Sprintf("0x%04x", msgType))
}
resp := &parsedResponse{}
offset := 0
for offset+4 <= len(attrData) {
attrType := binary.BigEndian.Uint16(attrData[offset : offset+2])
attrLen := int(binary.BigEndian.Uint16(attrData[offset+2 : offset+4]))
if offset+4+attrLen > len(attrData) {
break
}
attrValue := attrData[offset+4 : offset+4+attrLen]
switch attrType {
case attrXORMappedAddress:
addr, err := parseXORMappedAddress(attrValue, txID)
if err == nil {
resp.xorMappedAddr = addr
}
case attrMappedAddress:
addr, err := parseMappedAddress(attrValue)
if err == nil {
resp.mappedAddr = addr
}
case attrOtherAddress:
addr, err := parseMappedAddress(attrValue)
if err == nil {
resp.otherAddr = addr
}
}
offset += 4 + attrLen + paddingLen(attrLen)
}
return resp, nil
}
func parseErrorResponse(data []byte) error {
offset := 0
for offset+4 <= len(data) {
attrType := binary.BigEndian.Uint16(data[offset : offset+2])
attrLen := int(binary.BigEndian.Uint16(data[offset+2 : offset+4]))
if offset+4+attrLen > len(data) {
break
}
if attrType == attrErrorCode && attrLen >= 4 {
attrValue := data[offset+4 : offset+4+attrLen]
class := int(attrValue[2] & 0x07)
number := int(attrValue[3])
code := class*100 + number
if attrLen > 4 {
return E.New("STUN error ", code, ": ", string(attrValue[4:]))
}
return E.New("STUN error ", code)
}
offset += 4 + attrLen + paddingLen(attrLen)
}
return E.New("STUN error response")
}
func parseXORMappedAddress(data []byte, txID TransactionID) (netip.AddrPort, error) {
if len(data) < 4 {
return netip.AddrPort{}, E.New("XOR-MAPPED-ADDRESS too short")
}
family := data[1]
xPort := binary.BigEndian.Uint16(data[2:4])
port := xPort ^ uint16(magicCookie>>16)
switch family {
case familyIPv4:
if len(data) < 8 {
return netip.AddrPort{}, E.New("XOR-MAPPED-ADDRESS IPv4 too short")
}
var ip [4]byte
binary.BigEndian.PutUint32(ip[:], binary.BigEndian.Uint32(data[4:8])^magicCookie)
return netip.AddrPortFrom(netip.AddrFrom4(ip), port), nil
case familyIPv6:
if len(data) < 20 {
return netip.AddrPort{}, E.New("XOR-MAPPED-ADDRESS IPv6 too short")
}
var ip [16]byte
var xorKey [16]byte
binary.BigEndian.PutUint32(xorKey[0:4], magicCookie)
copy(xorKey[4:16], txID[:])
for i := range 16 {
ip[i] = data[4+i] ^ xorKey[i]
}
return netip.AddrPortFrom(netip.AddrFrom16(ip), port), nil
default:
return netip.AddrPort{}, E.New("unknown address family: ", family)
}
}
func parseMappedAddress(data []byte) (netip.AddrPort, error) {
if len(data) < 4 {
return netip.AddrPort{}, E.New("MAPPED-ADDRESS too short")
}
family := data[1]
port := binary.BigEndian.Uint16(data[2:4])
switch family {
case familyIPv4:
if len(data) < 8 {
return netip.AddrPort{}, E.New("MAPPED-ADDRESS IPv4 too short")
}
return netip.AddrPortFrom(
netip.AddrFrom4([4]byte{data[4], data[5], data[6], data[7]}), port,
), nil
case familyIPv6:
if len(data) < 20 {
return netip.AddrPort{}, E.New("MAPPED-ADDRESS IPv6 too short")
}
var ip [16]byte
copy(ip[:], data[4:20])
return netip.AddrPortFrom(netip.AddrFrom16(ip), port), nil
default:
return netip.AddrPort{}, E.New("unknown address family: ", family)
}
}
func roundTrip(conn net.PacketConn, addr net.Addr, txID TransactionID, attrs []stunAttribute, rto time.Duration) (*parsedResponse, time.Duration, error) {
request := buildBindingRequest(txID, attrs...)
currentRTO := rto
retransmitCount := 0
sendTime := time.Now()
_, err := conn.WriteTo(request, addr)
if err != nil {
return nil, 0, E.Cause(err, "send STUN request")
}
buf := make([]byte, 1024)
for {
err = conn.SetReadDeadline(sendTime.Add(currentRTO))
if err != nil {
return nil, 0, E.Cause(err, "set read deadline")
}
n, _, readErr := conn.ReadFrom(buf)
if readErr != nil {
if E.IsTimeout(readErr) && retransmitCount < maxRetransmit {
retransmitCount++
currentRTO *= 2
sendTime = time.Now()
_, err = conn.WriteTo(request, addr)
if err != nil {
return nil, 0, E.Cause(err, "retransmit STUN request")
}
continue
}
return nil, 0, E.Cause(readErr, "read STUN response")
}
if n < headerSize || buf[0]&0xC0 != 0 ||
binary.BigEndian.Uint32(buf[4:8]) != magicCookie {
continue
}
var receivedTxID TransactionID
copy(receivedTxID[:], buf[8:20])
if receivedTxID != txID {
continue
}
latency := time.Since(sendTime)
resp, parseErr := parseResponse(buf[:n], txID)
if parseErr != nil {
return nil, 0, parseErr
}
return resp, latency, nil
}
}
func Run(options Options) (*Result, error) {
ctx := options.Context
if ctx == nil {
ctx = context.Background()
}
server := options.Server
if server == "" {
server = DefaultServer
}
serverSocksaddr := M.ParseSocksaddr(server)
if serverSocksaddr.Port == 0 {
serverSocksaddr.Port = 3478
}
reportProgress := options.OnProgress
if reportProgress == nil {
reportProgress = func(Progress) {}
}
var (
packetConn net.PacketConn
serverAddr net.Addr
err error
)
if options.Dialer != nil {
packetConn, err = options.Dialer.ListenPacket(ctx, serverSocksaddr)
if err != nil {
return nil, E.Cause(err, "create UDP socket")
}
serverAddr = serverSocksaddr
} else {
serverUDPAddr, resolveErr := net.ResolveUDPAddr("udp", serverSocksaddr.String())
if resolveErr != nil {
return nil, E.Cause(resolveErr, "resolve STUN server")
}
packetConn, err = net.ListenPacket("udp", "")
if err != nil {
return nil, E.Cause(err, "create UDP socket")
}
serverAddr = serverUDPAddr
}
defer func() {
_ = packetConn.Close()
}()
if deadline.NeedAdditionalReadDeadline(packetConn) {
packetConn = deadline.NewPacketConn(bufio.NewPacketConn(packetConn))
}
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
rto := defaultRTO
// Phase 1: Binding
reportProgress(Progress{Phase: PhaseBinding})
txID := newTransactionID()
resp, latency, err := roundTrip(packetConn, serverAddr, txID, nil, rto)
if err != nil {
return nil, E.Cause(err, "binding request")
}
rto = max(minRTO, 3*latency)
externalAddr, ok := resp.externalAddr()
if !ok {
return nil, E.New("no mapped address in response")
}
result := &Result{
ExternalAddr: externalAddr.String(),
LatencyMs: int32(latency.Milliseconds()),
}
reportProgress(Progress{
Phase: PhaseBinding,
ExternalAddr: result.ExternalAddr,
LatencyMs: result.LatencyMs,
})
otherAddr := resp.otherAddr
if !otherAddr.IsValid() {
result.NATTypeSupported = false
reportProgress(Progress{
Phase: PhaseDone,
ExternalAddr: result.ExternalAddr,
LatencyMs: result.LatencyMs,
})
return result, nil
}
result.NATTypeSupported = true
select {
case <-ctx.Done():
return result, nil
default:
}
// Phase 2: NAT Mapping Detection (RFC 5780 Section 4.3)
reportProgress(Progress{
Phase: PhaseNATMapping,
ExternalAddr: result.ExternalAddr,
LatencyMs: result.LatencyMs,
})
result.NATMapping = detectNATMapping(
packetConn, serverSocksaddr.Port, externalAddr, otherAddr, rto,
)
reportProgress(Progress{
Phase: PhaseNATMapping,
ExternalAddr: result.ExternalAddr,
LatencyMs: result.LatencyMs,
NATMapping: result.NATMapping,
})
select {
case <-ctx.Done():
return result, nil
default:
}
// Phase 3: NAT Filtering Detection (RFC 5780 Section 4.4)
reportProgress(Progress{
Phase: PhaseNATFiltering,
ExternalAddr: result.ExternalAddr,
LatencyMs: result.LatencyMs,
NATMapping: result.NATMapping,
})
result.NATFiltering = detectNATFiltering(packetConn, serverAddr, rto)
reportProgress(Progress{
Phase: PhaseDone,
ExternalAddr: result.ExternalAddr,
LatencyMs: result.LatencyMs,
NATMapping: result.NATMapping,
NATFiltering: result.NATFiltering,
})
return result, nil
}
func detectNATMapping(
conn net.PacketConn,
serverPort uint16,
externalAddr netip.AddrPort,
otherAddr netip.AddrPort,
rto time.Duration,
) NATMapping {
// Mapping Test II: Send to other_ip:server_port
testIIAddr := net.UDPAddrFromAddrPort(
netip.AddrPortFrom(otherAddr.Addr(), serverPort),
)
txID2 := newTransactionID()
resp2, _, err := roundTrip(conn, testIIAddr, txID2, nil, rto)
if err != nil {
return NATMappingUnknown
}
externalAddr2, ok := resp2.externalAddr()
if !ok {
return NATMappingUnknown
}
if externalAddr == externalAddr2 {
return NATMappingEndpointIndependent
}
// Mapping Test III: Send to other_ip:other_port
testIIIAddr := net.UDPAddrFromAddrPort(otherAddr)
txID3 := newTransactionID()
resp3, _, err := roundTrip(conn, testIIIAddr, txID3, nil, rto)
if err != nil {
return NATMappingUnknown
}
externalAddr3, ok := resp3.externalAddr()
if !ok {
return NATMappingUnknown
}
if externalAddr2 == externalAddr3 {
return NATMappingAddressDependent
}
return NATMappingAddressAndPortDependent
}
func detectNATFiltering(
conn net.PacketConn,
serverAddr net.Addr,
rto time.Duration,
) NATFiltering {
// Filtering Test II: Request response from different IP and port
txID := newTransactionID()
_, _, err := roundTrip(conn, serverAddr, txID,
[]stunAttribute{changeRequestAttr(changeIP | changePort)}, rto)
if err == nil {
return NATFilteringEndpointIndependent
}
// Filtering Test III: Request response from different port only
txID = newTransactionID()
_, _, err = roundTrip(conn, serverAddr, txID,
[]stunAttribute{changeRequestAttr(changePort)}, rto)
if err == nil {
return NATFilteringAddressDependent
}
return NATFilteringAddressAndPortDependent
}
func paddingLen(n int) int {
if n%4 == 0 {
return 0
}
return 4 - n%4
}

View File

@@ -38,6 +38,37 @@ func (w *acmeWrapper) Close() error {
return nil return nil
} }
type acmeLogWriter struct {
logger logger.Logger
}
func (w *acmeLogWriter) Write(p []byte) (n int, err error) {
logLine := strings.ReplaceAll(string(p), " ", ": ")
switch {
case strings.HasPrefix(logLine, "error: "):
w.logger.Error(logLine[7:])
case strings.HasPrefix(logLine, "warn: "):
w.logger.Warn(logLine[6:])
case strings.HasPrefix(logLine, "info: "):
w.logger.Info(logLine[6:])
case strings.HasPrefix(logLine, "debug: "):
w.logger.Debug(logLine[7:])
default:
w.logger.Debug(logLine)
}
return len(p), nil
}
func (w *acmeLogWriter) Sync() error {
return nil
}
func encoderConfig() zapcore.EncoderConfig {
config := zap.NewProductionEncoderConfig()
config.TimeKey = zapcore.OmitKey
return config
}
func startACME(ctx context.Context, logger logger.Logger, options option.InboundACMEOptions) (*tls.Config, adapter.SimpleLifecycle, error) { func startACME(ctx context.Context, logger logger.Logger, options option.InboundACMEOptions) (*tls.Config, adapter.SimpleLifecycle, error) {
var acmeServer string var acmeServer string
switch options.Provider { switch options.Provider {
@@ -60,8 +91,8 @@ func startACME(ctx context.Context, logger logger.Logger, options option.Inbound
storage = certmagic.Default.Storage storage = certmagic.Default.Storage
} }
zapLogger := zap.New(zapcore.NewCore( zapLogger := zap.New(zapcore.NewCore(
zapcore.NewConsoleEncoder(ACMEEncoderConfig()), zapcore.NewConsoleEncoder(encoderConfig()),
&ACMELogWriter{Logger: logger}, &acmeLogWriter{logger: logger},
zap.DebugLevel, zap.DebugLevel,
)) ))
config := &certmagic.Config{ config := &certmagic.Config{
@@ -127,7 +158,7 @@ func startACME(ctx context.Context, logger logger.Logger, options option.Inbound
} else { } else {
tlsConfig = &tls.Config{ tlsConfig = &tls.Config{
GetCertificate: config.GetCertificate, GetCertificate: config.GetCertificate,
NextProtos: []string{C.ACMETLS1Protocol}, NextProtos: []string{ACMETLS1Protocol},
} }
} }
return tlsConfig, &acmeWrapper{ctx: ctx, cfg: config, cache: cache, domain: options.Domain}, nil return tlsConfig, &acmeWrapper{ctx: ctx, cfg: config, cache: cache, domain: options.Domain}, nil

View File

@@ -1,3 +1,3 @@
package constant package tls
const ACMETLS1Protocol = "acme-tls/1" const ACMETLS1Protocol = "acme-tls/1"

View File

@@ -1,41 +0,0 @@
package tls
import (
"strings"
"github.com/sagernet/sing/common/logger"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
type ACMELogWriter struct {
Logger logger.Logger
}
func (w *ACMELogWriter) Write(p []byte) (n int, err error) {
logLine := strings.ReplaceAll(string(p), " ", ": ")
switch {
case strings.HasPrefix(logLine, "error: "):
w.Logger.Error(logLine[7:])
case strings.HasPrefix(logLine, "warn: "):
w.Logger.Warn(logLine[6:])
case strings.HasPrefix(logLine, "info: "):
w.Logger.Info(logLine[6:])
case strings.HasPrefix(logLine, "debug: "):
w.Logger.Debug(logLine[7:])
default:
w.Logger.Debug(logLine)
}
return len(p), nil
}
func (w *ACMELogWriter) Sync() error {
return nil
}
func ACMEEncoderConfig() zapcore.EncoderConfig {
config := zap.NewProductionEncoderConfig()
config.TimeKey = zapcore.OmitKey
return config
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,205 +0,0 @@
//go:build darwin && cgo
package tls
import (
"context"
stdtls "crypto/tls"
"net"
"testing"
"time"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common/json/badoption"
"github.com/sagernet/sing/common/logger"
)
const appleTLSTestTimeout = 5 * time.Second
type appleTLSServerResult struct {
state stdtls.ConnectionState
err error
}
func TestAppleClientHandshakeAppliesALPNAndVersion(t *testing.T) {
serverCertificate, serverCertificatePEM := newAppleTestCertificate(t, "localhost")
serverResult, serverAddress := startAppleTLSTestServer(t, &stdtls.Config{
Certificates: []stdtls.Certificate{serverCertificate},
MinVersion: stdtls.VersionTLS12,
MaxVersion: stdtls.VersionTLS12,
NextProtos: []string{"h2"},
})
clientConn, err := newAppleTestClientConn(t, serverAddress, option.OutboundTLSOptions{
Enabled: true,
Engine: "apple",
ServerName: "localhost",
MinVersion: "1.2",
MaxVersion: "1.2",
ALPN: badoption.Listable[string]{"h2"},
Certificate: badoption.Listable[string]{serverCertificatePEM},
})
if err != nil {
t.Fatal(err)
}
defer clientConn.Close()
clientState := clientConn.ConnectionState()
if clientState.Version != stdtls.VersionTLS12 {
t.Fatalf("unexpected negotiated version: %x", clientState.Version)
}
if clientState.NegotiatedProtocol != "h2" {
t.Fatalf("unexpected negotiated protocol: %q", clientState.NegotiatedProtocol)
}
result := <-serverResult
if result.err != nil {
t.Fatal(result.err)
}
if result.state.Version != stdtls.VersionTLS12 {
t.Fatalf("server negotiated unexpected version: %x", result.state.Version)
}
if result.state.NegotiatedProtocol != "h2" {
t.Fatalf("server negotiated unexpected protocol: %q", result.state.NegotiatedProtocol)
}
}
func TestAppleClientHandshakeRejectsVersionMismatch(t *testing.T) {
serverCertificate, serverCertificatePEM := newAppleTestCertificate(t, "localhost")
serverResult, serverAddress := startAppleTLSTestServer(t, &stdtls.Config{
Certificates: []stdtls.Certificate{serverCertificate},
MinVersion: stdtls.VersionTLS13,
MaxVersion: stdtls.VersionTLS13,
})
clientConn, err := newAppleTestClientConn(t, serverAddress, option.OutboundTLSOptions{
Enabled: true,
Engine: "apple",
ServerName: "localhost",
MaxVersion: "1.2",
Certificate: badoption.Listable[string]{serverCertificatePEM},
})
if err == nil {
clientConn.Close()
t.Fatal("expected version mismatch handshake to fail")
}
if result := <-serverResult; result.err == nil {
t.Fatal("expected server handshake to fail on version mismatch")
}
}
func TestAppleClientHandshakeRejectsServerNameMismatch(t *testing.T) {
serverCertificate, serverCertificatePEM := newAppleTestCertificate(t, "localhost")
serverResult, serverAddress := startAppleTLSTestServer(t, &stdtls.Config{
Certificates: []stdtls.Certificate{serverCertificate},
})
clientConn, err := newAppleTestClientConn(t, serverAddress, option.OutboundTLSOptions{
Enabled: true,
Engine: "apple",
ServerName: "example.com",
Certificate: badoption.Listable[string]{serverCertificatePEM},
})
if err == nil {
clientConn.Close()
t.Fatal("expected server name mismatch handshake to fail")
}
if result := <-serverResult; result.err == nil {
t.Fatal("expected server handshake to fail on server name mismatch")
}
}
func newAppleTestCertificate(t *testing.T, serverName string) (stdtls.Certificate, string) {
t.Helper()
privateKeyPEM, certificatePEM, err := GenerateCertificate(nil, nil, time.Now, serverName, time.Now().Add(time.Hour))
if err != nil {
t.Fatal(err)
}
certificate, err := stdtls.X509KeyPair(certificatePEM, privateKeyPEM)
if err != nil {
t.Fatal(err)
}
return certificate, string(certificatePEM)
}
func startAppleTLSTestServer(t *testing.T, tlsConfig *stdtls.Config) (<-chan appleTLSServerResult, string) {
t.Helper()
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
listener.Close()
})
if tcpListener, isTCP := listener.(*net.TCPListener); isTCP {
err = tcpListener.SetDeadline(time.Now().Add(appleTLSTestTimeout))
if err != nil {
t.Fatal(err)
}
}
result := make(chan appleTLSServerResult, 1)
go func() {
defer close(result)
conn, err := listener.Accept()
if err != nil {
result <- appleTLSServerResult{err: err}
return
}
defer conn.Close()
err = conn.SetDeadline(time.Now().Add(appleTLSTestTimeout))
if err != nil {
result <- appleTLSServerResult{err: err}
return
}
tlsConn := stdtls.Server(conn, tlsConfig)
defer tlsConn.Close()
err = tlsConn.Handshake()
if err != nil {
result <- appleTLSServerResult{err: err}
return
}
result <- appleTLSServerResult{state: tlsConn.ConnectionState()}
}()
return result, listener.Addr().String()
}
func newAppleTestClientConn(t *testing.T, serverAddress string, options option.OutboundTLSOptions) (Conn, error) {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), appleTLSTestTimeout)
t.Cleanup(cancel)
clientConfig, err := NewClientWithOptions(ClientOptions{
Context: ctx,
Logger: logger.NOP(),
ServerAddress: "",
Options: options,
})
if err != nil {
return nil, err
}
conn, err := net.DialTimeout("tcp", serverAddress, appleTLSTestTimeout)
if err != nil {
return nil, err
}
tlsConn, err := ClientHandshake(ctx, conn, clientConfig)
if err != nil {
conn.Close()
return nil, err
}
return tlsConn, nil
}

View File

@@ -1,15 +0,0 @@
//go:build !darwin || !cgo
package tls
import (
"context"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
)
func newAppleClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions, allowEmptyServerName bool) (Config, error) {
return nil, E.New("Apple TLS engine is not available on non-Apple platforms")
}

View File

@@ -8,16 +8,14 @@ import (
"os" "os"
"github.com/sagernet/sing-box/common/badtls" "github.com/sagernet/sing-box/common/badtls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
aTLS "github.com/sagernet/sing/common/tls" aTLS "github.com/sagernet/sing/common/tls"
) )
var errMissingServerName = E.New("missing server_name or insecure=true")
func NewDialerFromOptions(ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, serverAddress string, options option.OutboundTLSOptions) (N.Dialer, error) { func NewDialerFromOptions(ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, serverAddress string, options option.OutboundTLSOptions) (N.Dialer, error) {
if !options.Enabled { if !options.Enabled {
return dialer, nil return dialer, nil
@@ -44,12 +42,11 @@ func NewClient(ctx context.Context, logger logger.ContextLogger, serverAddress s
} }
type ClientOptions struct { type ClientOptions struct {
Context context.Context Context context.Context
Logger logger.ContextLogger Logger logger.ContextLogger
ServerAddress string ServerAddress string
Options option.OutboundTLSOptions Options option.OutboundTLSOptions
AllowEmptyServerName bool KTLSCompatible bool
KTLSCompatible bool
} }
func NewClientWithOptions(options ClientOptions) (Config, error) { func NewClientWithOptions(options ClientOptions) (Config, error) {
@@ -64,22 +61,17 @@ func NewClientWithOptions(options ClientOptions) (Config, error) {
if options.Options.KernelRx { if options.Options.KernelRx {
options.Logger.Warn("enabling kTLS RX will definitely reduce performance, please checkout https://sing-box.sagernet.org/configuration/shared/tls/#kernel_rx") options.Logger.Warn("enabling kTLS RX will definitely reduce performance, please checkout https://sing-box.sagernet.org/configuration/shared/tls/#kernel_rx")
} }
switch options.Options.Engine {
case "", "go":
case "apple":
return newAppleClient(options.Context, options.Logger, options.ServerAddress, options.Options, options.AllowEmptyServerName)
default:
return nil, E.New("unknown tls engine: ", options.Options.Engine)
}
if options.Options.Reality != nil && options.Options.Reality.Enabled { if options.Options.Reality != nil && options.Options.Reality.Enabled {
return newRealityClient(options.Context, options.Logger, options.ServerAddress, options.Options, options.AllowEmptyServerName) return NewRealityClient(options.Context, options.Logger, options.ServerAddress, options.Options)
} else if options.Options.UTLS != nil && options.Options.UTLS.Enabled { } else if options.Options.UTLS != nil && options.Options.UTLS.Enabled {
return newUTLSClient(options.Context, options.Logger, options.ServerAddress, options.Options, options.AllowEmptyServerName) return NewUTLSClient(options.Context, options.Logger, options.ServerAddress, options.Options)
} }
return newSTDClient(options.Context, options.Logger, options.ServerAddress, options.Options, options.AllowEmptyServerName) return NewSTDClient(options.Context, options.Logger, options.ServerAddress, options.Options)
} }
func ClientHandshake(ctx context.Context, conn net.Conn, config Config) (Conn, error) { func ClientHandshake(ctx context.Context, conn net.Conn, config Config) (Conn, error) {
ctx, cancel := context.WithTimeout(ctx, C.TCPTimeout)
defer cancel()
tlsConn, err := aTLS.ClientHandshake(ctx, conn, config) tlsConn, err := aTLS.ClientHandshake(ctx, conn, config)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -52,15 +52,11 @@ type RealityClientConfig struct {
} }
func NewRealityClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) { func NewRealityClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
return newRealityClient(ctx, logger, serverAddress, options, false)
}
func newRealityClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions, allowEmptyServerName bool) (Config, error) {
if options.UTLS == nil || !options.UTLS.Enabled { if options.UTLS == nil || !options.UTLS.Enabled {
return nil, E.New("uTLS is required by reality client") return nil, E.New("uTLS is required by reality client")
} }
uClient, err := newUTLSClient(ctx, logger, serverAddress, options, allowEmptyServerName) uClient, err := NewUTLSClient(ctx, logger, serverAddress, options)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -112,14 +108,6 @@ func (e *RealityClientConfig) SetNextProtos(nextProto []string) {
e.uClient.SetNextProtos(nextProto) e.uClient.SetNextProtos(nextProto)
} }
func (e *RealityClientConfig) HandshakeTimeout() time.Duration {
return e.uClient.HandshakeTimeout()
}
func (e *RealityClientConfig) SetHandshakeTimeout(timeout time.Duration) {
e.uClient.SetHandshakeTimeout(timeout)
}
func (e *RealityClientConfig) STDConfig() (*STDConfig, error) { func (e *RealityClientConfig) STDConfig() (*STDConfig, error) {
return nil, E.New("unsupported usage for reality") return nil, E.New("unsupported usage for reality")
} }

View File

@@ -26,17 +26,12 @@ import (
var _ ServerConfigCompat = (*RealityServerConfig)(nil) var _ ServerConfigCompat = (*RealityServerConfig)(nil)
type RealityServerConfig struct { type RealityServerConfig struct {
config *utls.RealityConfig config *utls.RealityConfig
handshakeTimeout time.Duration
} }
func NewRealityServer(ctx context.Context, logger log.ContextLogger, options option.InboundTLSOptions) (ServerConfig, error) { func NewRealityServer(ctx context.Context, logger log.ContextLogger, options option.InboundTLSOptions) (ServerConfig, error) {
var tlsConfig utls.RealityConfig var tlsConfig utls.RealityConfig
if options.CertificateProvider != nil {
return nil, E.New("certificate_provider is unavailable in reality")
}
//nolint:staticcheck
if options.ACME != nil && len(options.ACME.Domain) > 0 { if options.ACME != nil && len(options.ACME.Domain) > 0 {
return nil, E.New("acme is unavailable in reality") return nil, E.New("acme is unavailable in reality")
} }
@@ -131,16 +126,7 @@ func NewRealityServer(ctx context.Context, logger log.ContextLogger, options opt
if options.ECH != nil && options.ECH.Enabled { if options.ECH != nil && options.ECH.Enabled {
return nil, E.New("Reality is conflict with ECH") return nil, E.New("Reality is conflict with ECH")
} }
var handshakeTimeout time.Duration var config ServerConfig = &RealityServerConfig{&tlsConfig}
if options.HandshakeTimeout > 0 {
handshakeTimeout = options.HandshakeTimeout.Build()
} else {
handshakeTimeout = C.TCPTimeout
}
var config ServerConfig = &RealityServerConfig{
config: &tlsConfig,
handshakeTimeout: handshakeTimeout,
}
if options.KernelTx || options.KernelRx { if options.KernelTx || options.KernelRx {
if !C.IsLinux { if !C.IsLinux {
return nil, E.New("kTLS is only supported on Linux") return nil, E.New("kTLS is only supported on Linux")
@@ -171,14 +157,6 @@ func (c *RealityServerConfig) SetNextProtos(nextProto []string) {
c.config.NextProtos = nextProto c.config.NextProtos = nextProto
} }
func (c *RealityServerConfig) HandshakeTimeout() time.Duration {
return c.handshakeTimeout
}
func (c *RealityServerConfig) SetHandshakeTimeout(timeout time.Duration) {
c.handshakeTimeout = timeout
}
func (c *RealityServerConfig) STDConfig() (*tls.Config, error) { func (c *RealityServerConfig) STDConfig() (*tls.Config, error) {
return nil, E.New("unsupported usage for reality") return nil, E.New("unsupported usage for reality")
} }
@@ -209,8 +187,7 @@ func (c *RealityServerConfig) ServerHandshake(ctx context.Context, conn net.Conn
func (c *RealityServerConfig) Clone() Config { func (c *RealityServerConfig) Clone() Config {
return &RealityServerConfig{ return &RealityServerConfig{
config: c.config.Clone(), config: c.config.Clone(),
handshakeTimeout: c.handshakeTimeout,
} }
} }

View File

@@ -46,11 +46,8 @@ func NewServerWithOptions(options ServerOptions) (ServerConfig, error) {
} }
func ServerHandshake(ctx context.Context, conn net.Conn, config ServerConfig) (Conn, error) { func ServerHandshake(ctx context.Context, conn net.Conn, config ServerConfig) (Conn, error) {
if config.HandshakeTimeout() == 0 { ctx, cancel := context.WithTimeout(ctx, C.TCPTimeout)
var cancel context.CancelFunc defer cancel()
ctx, cancel = context.WithTimeout(ctx, C.TCPTimeout)
defer cancel()
}
tlsConn, err := aTLS.ServerHandshake(ctx, conn, config) tlsConn, err := aTLS.ServerHandshake(ctx, conn, config)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -24,30 +24,16 @@ import (
type STDClientConfig struct { type STDClientConfig struct {
ctx context.Context ctx context.Context
config *tls.Config config *tls.Config
serverName string
disableSNI bool
verifyServerName bool
handshakeTimeout time.Duration
fragment bool fragment bool
fragmentFallbackDelay time.Duration fragmentFallbackDelay time.Duration
recordFragment bool recordFragment bool
} }
func (c *STDClientConfig) ServerName() string { func (c *STDClientConfig) ServerName() string {
return c.serverName return c.config.ServerName
} }
func (c *STDClientConfig) SetServerName(serverName string) { func (c *STDClientConfig) SetServerName(serverName string) {
c.serverName = serverName
if c.disableSNI {
c.config.ServerName = ""
if c.verifyServerName {
c.config.VerifyConnection = verifyConnection(c.config.RootCAs, c.config.Time, serverName)
} else {
c.config.VerifyConnection = nil
}
return
}
c.config.ServerName = serverName c.config.ServerName = serverName
} }
@@ -59,14 +45,6 @@ func (c *STDClientConfig) SetNextProtos(nextProto []string) {
c.config.NextProtos = nextProto c.config.NextProtos = nextProto
} }
func (c *STDClientConfig) HandshakeTimeout() time.Duration {
return c.handshakeTimeout
}
func (c *STDClientConfig) SetHandshakeTimeout(timeout time.Duration) {
c.handshakeTimeout = timeout
}
func (c *STDClientConfig) STDConfig() (*STDConfig, error) { func (c *STDClientConfig) STDConfig() (*STDConfig, error) {
return c.config, nil return c.config, nil
} }
@@ -79,19 +57,13 @@ func (c *STDClientConfig) Client(conn net.Conn) (Conn, error) {
} }
func (c *STDClientConfig) Clone() Config { func (c *STDClientConfig) Clone() Config {
cloned := &STDClientConfig{ return &STDClientConfig{
ctx: c.ctx, ctx: c.ctx,
config: c.config.Clone(), config: c.config.Clone(),
serverName: c.serverName,
disableSNI: c.disableSNI,
verifyServerName: c.verifyServerName,
handshakeTimeout: c.handshakeTimeout,
fragment: c.fragment, fragment: c.fragment,
fragmentFallbackDelay: c.fragmentFallbackDelay, fragmentFallbackDelay: c.fragmentFallbackDelay,
recordFragment: c.recordFragment, recordFragment: c.recordFragment,
} }
cloned.SetServerName(cloned.serverName)
return cloned
} }
func (c *STDClientConfig) ECHConfigList() []byte { func (c *STDClientConfig) ECHConfigList() []byte {
@@ -103,27 +75,41 @@ func (c *STDClientConfig) SetECHConfigList(EncryptedClientHelloConfigList []byte
} }
func NewSTDClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) { func NewSTDClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
return newSTDClient(ctx, logger, serverAddress, options, false)
}
func newSTDClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions, allowEmptyServerName bool) (Config, error) {
var serverName string var serverName string
if options.ServerName != "" { if options.ServerName != "" {
serverName = options.ServerName serverName = options.ServerName
} else if serverAddress != "" { } else if serverAddress != "" {
serverName = serverAddress serverName = serverAddress
} }
if serverName == "" && !options.Insecure && !allowEmptyServerName { if serverName == "" && !options.Insecure {
return nil, errMissingServerName return nil, E.New("missing server_name or insecure=true")
} }
var tlsConfig tls.Config var tlsConfig tls.Config
tlsConfig.Time = ntp.TimeFuncFromContext(ctx) tlsConfig.Time = ntp.TimeFuncFromContext(ctx)
tlsConfig.RootCAs = adapter.RootPoolFromContext(ctx) tlsConfig.RootCAs = adapter.RootPoolFromContext(ctx)
if !options.DisableSNI {
tlsConfig.ServerName = serverName
}
if options.Insecure { if options.Insecure {
tlsConfig.InsecureSkipVerify = options.Insecure tlsConfig.InsecureSkipVerify = options.Insecure
} else if options.DisableSNI { } else if options.DisableSNI {
tlsConfig.InsecureSkipVerify = true tlsConfig.InsecureSkipVerify = true
tlsConfig.VerifyConnection = func(state tls.ConnectionState) error {
verifyOptions := x509.VerifyOptions{
Roots: tlsConfig.RootCAs,
DNSName: serverName,
Intermediates: x509.NewCertPool(),
}
for _, cert := range state.PeerCertificates[1:] {
verifyOptions.Intermediates.AddCert(cert)
}
if tlsConfig.Time != nil {
verifyOptions.CurrentTime = tlsConfig.Time()
}
_, err := state.PeerCertificates[0].Verify(verifyOptions)
return err
}
} }
if len(options.CertificatePublicKeySHA256) > 0 { if len(options.CertificatePublicKeySHA256) > 0 {
if len(options.Certificate) > 0 || options.CertificatePath != "" { if len(options.Certificate) > 0 || options.CertificatePath != "" {
@@ -212,24 +198,7 @@ func newSTDClient(ctx context.Context, logger logger.ContextLogger, serverAddres
} else if len(clientCertificate) > 0 || len(clientKey) > 0 { } else if len(clientCertificate) > 0 || len(clientKey) > 0 {
return nil, E.New("client certificate and client key must be provided together") return nil, E.New("client certificate and client key must be provided together")
} }
var handshakeTimeout time.Duration var config Config = &STDClientConfig{ctx, &tlsConfig, options.Fragment, time.Duration(options.FragmentFallbackDelay), options.RecordFragment}
if options.HandshakeTimeout > 0 {
handshakeTimeout = options.HandshakeTimeout.Build()
} else {
handshakeTimeout = C.TCPTimeout
}
var config Config = &STDClientConfig{
ctx: ctx,
config: &tlsConfig,
serverName: serverName,
disableSNI: options.DisableSNI,
verifyServerName: options.DisableSNI && !options.Insecure,
handshakeTimeout: handshakeTimeout,
fragment: options.Fragment,
fragmentFallbackDelay: time.Duration(options.FragmentFallbackDelay),
recordFragment: options.RecordFragment,
}
config.SetServerName(serverName)
if options.ECH != nil && options.ECH.Enabled { if options.ECH != nil && options.ECH.Enabled {
var err error var err error
config, err = parseECHClientConfig(ctx, config.(ECHCapableConfig), options) config, err = parseECHClientConfig(ctx, config.(ECHCapableConfig), options)
@@ -251,27 +220,6 @@ func newSTDClient(ctx context.Context, logger logger.ContextLogger, serverAddres
return config, nil return config, nil
} }
func verifyConnection(rootCAs *x509.CertPool, timeFunc func() time.Time, serverName string) func(state tls.ConnectionState) error {
return func(state tls.ConnectionState) error {
if serverName == "" {
return errMissingServerName
}
verifyOptions := x509.VerifyOptions{
Roots: rootCAs,
DNSName: serverName,
Intermediates: x509.NewCertPool(),
}
for _, cert := range state.PeerCertificates[1:] {
verifyOptions.Intermediates.AddCert(cert)
}
if timeFunc != nil {
verifyOptions.CurrentTime = timeFunc()
}
_, err := state.PeerCertificates[0].Verify(verifyOptions)
return err
}
}
func verifyPublicKeySHA256(knownHashValues [][]byte, rawCerts [][]byte, timeFunc func() time.Time) error { func verifyPublicKeySHA256(knownHashValues [][]byte, rawCerts [][]byte, timeFunc func() time.Time) error {
leafCertificate, err := x509.ParseCertificate(rawCerts[0]) leafCertificate, err := x509.ParseCertificate(rawCerts[0])
if err != nil { if err != nil {

View File

@@ -13,88 +13,19 @@ import (
"github.com/sagernet/fswatch" "github.com/sagernet/fswatch"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/experimental/deprecated"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/ntp" "github.com/sagernet/sing/common/ntp"
"github.com/sagernet/sing/service"
) )
var errInsecureUnused = E.New("tls: insecure unused") var errInsecureUnused = E.New("tls: insecure unused")
type managedCertificateProvider interface {
adapter.CertificateProvider
adapter.SimpleLifecycle
}
type sharedCertificateProvider struct {
tag string
manager adapter.CertificateProviderManager
provider adapter.CertificateProviderService
}
func (p *sharedCertificateProvider) Start() error {
provider, found := p.manager.Get(p.tag)
if !found {
return E.New("certificate provider not found: ", p.tag)
}
p.provider = provider
return nil
}
func (p *sharedCertificateProvider) Close() error {
return nil
}
func (p *sharedCertificateProvider) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
return p.provider.GetCertificate(hello)
}
func (p *sharedCertificateProvider) GetACMENextProtos() []string {
return getACMENextProtos(p.provider)
}
type inlineCertificateProvider struct {
provider adapter.CertificateProviderService
}
func (p *inlineCertificateProvider) Start() error {
for _, stage := range adapter.ListStartStages {
err := adapter.LegacyStart(p.provider, stage)
if err != nil {
return err
}
}
return nil
}
func (p *inlineCertificateProvider) Close() error {
return p.provider.Close()
}
func (p *inlineCertificateProvider) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
return p.provider.GetCertificate(hello)
}
func (p *inlineCertificateProvider) GetACMENextProtos() []string {
return getACMENextProtos(p.provider)
}
func getACMENextProtos(provider adapter.CertificateProvider) []string {
if acmeProvider, isACME := provider.(adapter.ACMECertificateProvider); isACME {
return acmeProvider.GetACMENextProtos()
}
return nil
}
type STDServerConfig struct { type STDServerConfig struct {
access sync.RWMutex access sync.RWMutex
config *tls.Config config *tls.Config
handshakeTimeout time.Duration
logger log.Logger logger log.Logger
certificateProvider managedCertificateProvider
acmeService adapter.SimpleLifecycle acmeService adapter.SimpleLifecycle
certificate []byte certificate []byte
key []byte key []byte
@@ -122,17 +53,18 @@ func (c *STDServerConfig) SetServerName(serverName string) {
func (c *STDServerConfig) NextProtos() []string { func (c *STDServerConfig) NextProtos() []string {
c.access.RLock() c.access.RLock()
defer c.access.RUnlock() defer c.access.RUnlock()
if c.hasACMEALPN() && len(c.config.NextProtos) > 1 && c.config.NextProtos[0] == C.ACMETLS1Protocol { if c.acmeService != nil && len(c.config.NextProtos) > 1 && c.config.NextProtos[0] == ACMETLS1Protocol {
return c.config.NextProtos[1:] return c.config.NextProtos[1:]
} else {
return c.config.NextProtos
} }
return c.config.NextProtos
} }
func (c *STDServerConfig) SetNextProtos(nextProto []string) { func (c *STDServerConfig) SetNextProtos(nextProto []string) {
c.access.Lock() c.access.Lock()
defer c.access.Unlock() defer c.access.Unlock()
config := c.config.Clone() config := c.config.Clone()
if c.hasACMEALPN() && len(c.config.NextProtos) > 1 && c.config.NextProtos[0] == C.ACMETLS1Protocol { if c.acmeService != nil && len(c.config.NextProtos) > 1 && c.config.NextProtos[0] == ACMETLS1Protocol {
config.NextProtos = append(c.config.NextProtos[:1], nextProto...) config.NextProtos = append(c.config.NextProtos[:1], nextProto...)
} else { } else {
config.NextProtos = nextProto config.NextProtos = nextProto
@@ -140,30 +72,6 @@ func (c *STDServerConfig) SetNextProtos(nextProto []string) {
c.config = config c.config = config
} }
func (c *STDServerConfig) HandshakeTimeout() time.Duration {
c.access.RLock()
defer c.access.RUnlock()
return c.handshakeTimeout
}
func (c *STDServerConfig) SetHandshakeTimeout(timeout time.Duration) {
c.access.Lock()
defer c.access.Unlock()
c.handshakeTimeout = timeout
}
func (c *STDServerConfig) hasACMEALPN() bool {
if c.acmeService != nil {
return true
}
if c.certificateProvider != nil {
if acmeProvider, isACME := c.certificateProvider.(adapter.ACMECertificateProvider); isACME {
return len(acmeProvider.GetACMENextProtos()) > 0
}
}
return false
}
func (c *STDServerConfig) STDConfig() (*STDConfig, error) { func (c *STDServerConfig) STDConfig() (*STDConfig, error) {
return c.config, nil return c.config, nil
} }
@@ -178,45 +86,20 @@ func (c *STDServerConfig) Server(conn net.Conn) (Conn, error) {
func (c *STDServerConfig) Clone() Config { func (c *STDServerConfig) Clone() Config {
return &STDServerConfig{ return &STDServerConfig{
config: c.config.Clone(), config: c.config.Clone(),
handshakeTimeout: c.handshakeTimeout,
} }
} }
func (c *STDServerConfig) Start() error { func (c *STDServerConfig) Start() error {
if c.certificateProvider != nil {
err := c.certificateProvider.Start()
if err != nil {
return err
}
if acmeProvider, isACME := c.certificateProvider.(adapter.ACMECertificateProvider); isACME {
nextProtos := acmeProvider.GetACMENextProtos()
if len(nextProtos) > 0 {
c.access.Lock()
config := c.config.Clone()
mergedNextProtos := append([]string{}, nextProtos...)
for _, nextProto := range config.NextProtos {
if !common.Contains(mergedNextProtos, nextProto) {
mergedNextProtos = append(mergedNextProtos, nextProto)
}
}
config.NextProtos = mergedNextProtos
c.config = config
c.access.Unlock()
}
}
}
if c.acmeService != nil { if c.acmeService != nil {
err := c.acmeService.Start() return c.acmeService.Start()
} else {
err := c.startWatcher()
if err != nil { if err != nil {
return err c.logger.Warn("create fsnotify watcher: ", err)
} }
return nil
} }
err := c.startWatcher()
if err != nil {
c.logger.Warn("create fsnotify watcher: ", err)
}
return nil
} }
func (c *STDServerConfig) startWatcher() error { func (c *STDServerConfig) startWatcher() error {
@@ -320,34 +203,23 @@ func (c *STDServerConfig) certificateUpdated(path string) error {
} }
func (c *STDServerConfig) Close() error { func (c *STDServerConfig) Close() error {
return common.Close(c.certificateProvider, c.acmeService, c.watcher) if c.acmeService != nil {
return c.acmeService.Close()
}
if c.watcher != nil {
return c.watcher.Close()
}
return nil
} }
func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option.InboundTLSOptions) (ServerConfig, error) { func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option.InboundTLSOptions) (ServerConfig, error) {
if !options.Enabled { if !options.Enabled {
return nil, nil return nil, nil
} }
//nolint:staticcheck
if options.CertificateProvider != nil && options.ACME != nil {
return nil, E.New("certificate_provider and acme are mutually exclusive")
}
var tlsConfig *tls.Config var tlsConfig *tls.Config
var certificateProvider managedCertificateProvider
var acmeService adapter.SimpleLifecycle var acmeService adapter.SimpleLifecycle
var err error var err error
if options.CertificateProvider != nil { if options.ACME != nil && len(options.ACME.Domain) > 0 {
certificateProvider, err = newCertificateProvider(ctx, logger, options.CertificateProvider)
if err != nil {
return nil, err
}
tlsConfig = &tls.Config{
GetCertificate: certificateProvider.GetCertificate,
}
if options.Insecure {
return nil, errInsecureUnused
}
} else if options.ACME != nil && len(options.ACME.Domain) > 0 { //nolint:staticcheck
deprecated.Report(ctx, deprecated.OptionInlineACME)
//nolint:staticcheck //nolint:staticcheck
tlsConfig, acmeService, err = startACME(ctx, logger, common.PtrValueOrDefault(options.ACME)) tlsConfig, acmeService, err = startACME(ctx, logger, common.PtrValueOrDefault(options.ACME))
if err != nil { if err != nil {
@@ -400,7 +272,7 @@ func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option.
certificate []byte certificate []byte
key []byte key []byte
) )
if certificateProvider == nil && acmeService == nil { if acmeService == nil {
if len(options.Certificate) > 0 { if len(options.Certificate) > 0 {
certificate = []byte(strings.Join(options.Certificate, "\n")) certificate = []byte(strings.Join(options.Certificate, "\n"))
} else if options.CertificatePath != "" { } else if options.CertificatePath != "" {
@@ -485,17 +357,9 @@ func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option.
return nil, err return nil, err
} }
} }
var handshakeTimeout time.Duration
if options.HandshakeTimeout > 0 {
handshakeTimeout = options.HandshakeTimeout.Build()
} else {
handshakeTimeout = C.TCPTimeout
}
serverConfig := &STDServerConfig{ serverConfig := &STDServerConfig{
config: tlsConfig, config: tlsConfig,
handshakeTimeout: handshakeTimeout,
logger: logger, logger: logger,
certificateProvider: certificateProvider,
acmeService: acmeService, acmeService: acmeService,
certificate: certificate, certificate: certificate,
key: key, key: key,
@@ -505,8 +369,8 @@ func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option.
echKeyPath: echKeyPath, echKeyPath: echKeyPath,
} }
serverConfig.config.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { serverConfig.config.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
serverConfig.access.RLock() serverConfig.access.Lock()
defer serverConfig.access.RUnlock() defer serverConfig.access.Unlock()
return serverConfig.config, nil return serverConfig.config, nil
} }
var config ServerConfig = serverConfig var config ServerConfig = serverConfig
@@ -523,27 +387,3 @@ func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option.
} }
return config, nil return config, nil
} }
func newCertificateProvider(ctx context.Context, logger log.ContextLogger, options *option.CertificateProviderOptions) (managedCertificateProvider, error) {
if options.IsShared() {
manager := service.FromContext[adapter.CertificateProviderManager](ctx)
if manager == nil {
return nil, E.New("missing certificate provider manager in context")
}
return &sharedCertificateProvider{
tag: options.Tag,
manager: manager,
}, nil
}
registry := service.FromContext[adapter.CertificateProviderRegistry](ctx)
if registry == nil {
return nil, E.New("missing certificate provider registry in context")
}
provider, err := registry.Create(ctx, logger, "", options.Type, options.Options)
if err != nil {
return nil, E.Cause(err, "create inline certificate provider")
}
return &inlineCertificateProvider{
provider: provider,
}, nil
}

View File

@@ -28,10 +28,6 @@ import (
type UTLSClientConfig struct { type UTLSClientConfig struct {
ctx context.Context ctx context.Context
config *utls.Config config *utls.Config
serverName string
disableSNI bool
verifyServerName bool
handshakeTimeout time.Duration
id utls.ClientHelloID id utls.ClientHelloID
fragment bool fragment bool
fragmentFallbackDelay time.Duration fragmentFallbackDelay time.Duration
@@ -39,20 +35,10 @@ type UTLSClientConfig struct {
} }
func (c *UTLSClientConfig) ServerName() string { func (c *UTLSClientConfig) ServerName() string {
return c.serverName return c.config.ServerName
} }
func (c *UTLSClientConfig) SetServerName(serverName string) { func (c *UTLSClientConfig) SetServerName(serverName string) {
c.serverName = serverName
if c.disableSNI {
c.config.ServerName = ""
if c.verifyServerName {
c.config.InsecureServerNameToVerify = serverName
} else {
c.config.InsecureServerNameToVerify = ""
}
return
}
c.config.ServerName = serverName c.config.ServerName = serverName
} }
@@ -67,14 +53,6 @@ func (c *UTLSClientConfig) SetNextProtos(nextProto []string) {
c.config.NextProtos = nextProto c.config.NextProtos = nextProto
} }
func (c *UTLSClientConfig) HandshakeTimeout() time.Duration {
return c.handshakeTimeout
}
func (c *UTLSClientConfig) SetHandshakeTimeout(timeout time.Duration) {
c.handshakeTimeout = timeout
}
func (c *UTLSClientConfig) STDConfig() (*STDConfig, error) { func (c *UTLSClientConfig) STDConfig() (*STDConfig, error) {
return nil, E.New("unsupported usage for uTLS") return nil, E.New("unsupported usage for uTLS")
} }
@@ -91,20 +69,9 @@ func (c *UTLSClientConfig) SetSessionIDGenerator(generator func(clientHello []by
} }
func (c *UTLSClientConfig) Clone() Config { func (c *UTLSClientConfig) Clone() Config {
cloned := &UTLSClientConfig{ return &UTLSClientConfig{
ctx: c.ctx, c.ctx, c.config.Clone(), c.id, c.fragment, c.fragmentFallbackDelay, c.recordFragment,
config: c.config.Clone(),
serverName: c.serverName,
disableSNI: c.disableSNI,
verifyServerName: c.verifyServerName,
handshakeTimeout: c.handshakeTimeout,
id: c.id,
fragment: c.fragment,
fragmentFallbackDelay: c.fragmentFallbackDelay,
recordFragment: c.recordFragment,
} }
cloned.SetServerName(cloned.serverName)
return cloned
} }
func (c *UTLSClientConfig) ECHConfigList() []byte { func (c *UTLSClientConfig) ECHConfigList() []byte {
@@ -176,29 +143,29 @@ func (c *utlsALPNWrapper) HandshakeContext(ctx context.Context) error {
} }
func NewUTLSClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) { func NewUTLSClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
return newUTLSClient(ctx, logger, serverAddress, options, false)
}
func newUTLSClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions, allowEmptyServerName bool) (Config, error) {
var serverName string var serverName string
if options.ServerName != "" { if options.ServerName != "" {
serverName = options.ServerName serverName = options.ServerName
} else if serverAddress != "" { } else if serverAddress != "" {
serverName = serverAddress serverName = serverAddress
} }
if serverName == "" && !options.Insecure && !allowEmptyServerName { if serverName == "" && !options.Insecure {
return nil, errMissingServerName return nil, E.New("missing server_name or insecure=true")
} }
var tlsConfig utls.Config var tlsConfig utls.Config
tlsConfig.Time = ntp.TimeFuncFromContext(ctx) tlsConfig.Time = ntp.TimeFuncFromContext(ctx)
tlsConfig.RootCAs = adapter.RootPoolFromContext(ctx) tlsConfig.RootCAs = adapter.RootPoolFromContext(ctx)
if !options.DisableSNI {
tlsConfig.ServerName = serverName
}
if options.Insecure { if options.Insecure {
tlsConfig.InsecureSkipVerify = options.Insecure tlsConfig.InsecureSkipVerify = options.Insecure
} else if options.DisableSNI { } else if options.DisableSNI {
if options.Reality != nil && options.Reality.Enabled { if options.Reality != nil && options.Reality.Enabled {
return nil, E.New("disable_sni is unsupported in reality") return nil, E.New("disable_sni is unsupported in reality")
} }
tlsConfig.InsecureServerNameToVerify = serverName
} }
if len(options.CertificatePublicKeySHA256) > 0 { if len(options.CertificatePublicKeySHA256) > 0 {
if len(options.Certificate) > 0 || options.CertificatePath != "" { if len(options.Certificate) > 0 || options.CertificatePath != "" {
@@ -284,29 +251,11 @@ func newUTLSClient(ctx context.Context, logger logger.ContextLogger, serverAddre
} else if len(clientCertificate) > 0 || len(clientKey) > 0 { } else if len(clientCertificate) > 0 || len(clientKey) > 0 {
return nil, E.New("client certificate and client key must be provided together") return nil, E.New("client certificate and client key must be provided together")
} }
var handshakeTimeout time.Duration
if options.HandshakeTimeout > 0 {
handshakeTimeout = options.HandshakeTimeout.Build()
} else {
handshakeTimeout = C.TCPTimeout
}
id, err := uTLSClientHelloID(options.UTLS.Fingerprint) id, err := uTLSClientHelloID(options.UTLS.Fingerprint)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var config Config = &UTLSClientConfig{ var config Config = &UTLSClientConfig{ctx, &tlsConfig, id, options.Fragment, time.Duration(options.FragmentFallbackDelay), options.RecordFragment}
ctx: ctx,
config: &tlsConfig,
serverName: serverName,
disableSNI: options.DisableSNI,
verifyServerName: options.DisableSNI && !options.Insecure,
handshakeTimeout: handshakeTimeout,
id: id,
fragment: options.Fragment,
fragmentFallbackDelay: time.Duration(options.FragmentFallbackDelay),
recordFragment: options.RecordFragment,
}
config.SetServerName(serverName)
if options.ECH != nil && options.ECH.Enabled { if options.ECH != nil && options.ECH.Enabled {
if options.Reality != nil && options.Reality.Enabled { if options.Reality != nil && options.Reality.Enabled {
return nil, E.New("Reality is conflict with ECH") return nil, E.New("Reality is conflict with ECH")

View File

@@ -12,18 +12,10 @@ import (
) )
func NewUTLSClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) { func NewUTLSClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
return newUTLSClient(ctx, logger, serverAddress, options, false)
}
func newUTLSClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions, allowEmptyServerName bool) (Config, error) {
return nil, E.New(`uTLS is not included in this build, rebuild with -tags with_utls`) return nil, E.New(`uTLS is not included in this build, rebuild with -tags with_utls`)
} }
func NewRealityClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) { func NewRealityClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
return newRealityClient(ctx, logger, serverAddress, options, false)
}
func newRealityClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions, allowEmptyServerName bool) (Config, error) {
return nil, E.New(`uTLS, which is required by reality is not included in this build, rebuild with -tags with_utls`) return nil, E.New(`uTLS, which is required by reality is not included in this build, rebuild with -tags with_utls`)
} }

View File

@@ -15,18 +15,19 @@ const (
) )
const ( const (
DNSTypeLegacy = "legacy" DNSTypeLegacy = "legacy"
DNSTypeUDP = "udp" DNSTypeLegacyRcode = "legacy_rcode"
DNSTypeTCP = "tcp" DNSTypeUDP = "udp"
DNSTypeTLS = "tls" DNSTypeTCP = "tcp"
DNSTypeHTTPS = "https" DNSTypeTLS = "tls"
DNSTypeQUIC = "quic" DNSTypeHTTPS = "https"
DNSTypeHTTP3 = "h3" DNSTypeQUIC = "quic"
DNSTypeLocal = "local" DNSTypeHTTP3 = "h3"
DNSTypeHosts = "hosts" DNSTypeLocal = "local"
DNSTypeFakeIP = "fakeip" DNSTypeHosts = "hosts"
DNSTypeDHCP = "dhcp" DNSTypeFakeIP = "fakeip"
DNSTypeTailscale = "tailscale" DNSTypeDHCP = "dhcp"
DNSTypeTailscale = "tailscale"
) )
const ( const (

View File

@@ -1,39 +1,36 @@
package constant package constant
const ( const (
TypeTun = "tun" TypeTun = "tun"
TypeRedirect = "redirect" TypeRedirect = "redirect"
TypeTProxy = "tproxy" TypeTProxy = "tproxy"
TypeDirect = "direct" TypeDirect = "direct"
TypeBlock = "block" TypeBlock = "block"
TypeDNS = "dns" TypeDNS = "dns"
TypeSOCKS = "socks" TypeSOCKS = "socks"
TypeHTTP = "http" TypeHTTP = "http"
TypeMixed = "mixed" TypeMixed = "mixed"
TypeShadowsocks = "shadowsocks" TypeShadowsocks = "shadowsocks"
TypeVMess = "vmess" TypeVMess = "vmess"
TypeTrojan = "trojan" TypeTrojan = "trojan"
TypeNaive = "naive" TypeNaive = "naive"
TypeWireGuard = "wireguard" TypeWireGuard = "wireguard"
TypeHysteria = "hysteria" TypeHysteria = "hysteria"
TypeTor = "tor" TypeTor = "tor"
TypeSSH = "ssh" TypeSSH = "ssh"
TypeShadowTLS = "shadowtls" TypeShadowTLS = "shadowtls"
TypeAnyTLS = "anytls" TypeAnyTLS = "anytls"
TypeShadowsocksR = "shadowsocksr" TypeShadowsocksR = "shadowsocksr"
TypeVLESS = "vless" TypeVLESS = "vless"
TypeTUIC = "tuic" TypeTUIC = "tuic"
TypeHysteria2 = "hysteria2" TypeHysteria2 = "hysteria2"
TypeTailscale = "tailscale" TypeTailscale = "tailscale"
TypeCloudflared = "cloudflared" TypeDERP = "derp"
TypeDERP = "derp" TypeResolved = "resolved"
TypeResolved = "resolved" TypeSSMAPI = "ssm-api"
TypeSSMAPI = "ssm-api" TypeCCM = "ccm"
TypeCCM = "ccm" TypeOCM = "ocm"
TypeOCM = "ocm" TypeOOMKiller = "oom-killer"
TypeOOMKiller = "oom-killer"
TypeACME = "acme"
TypeCloudflareOriginCA = "cloudflare-origin-ca"
) )
const ( const (
@@ -41,6 +38,13 @@ const (
TypeURLTest = "urltest" TypeURLTest = "urltest"
) )
const (
BalancerStrategyLeastUsed = "least-used"
BalancerStrategyRoundRobin = "round-robin"
BalancerStrategyRandom = "random"
BalancerStrategyFallback = "fallback"
)
func ProxyDisplayName(proxyType string) string { func ProxyDisplayName(proxyType string) string {
switch proxyType { switch proxyType {
case TypeTun: case TypeTun:
@@ -91,8 +95,6 @@ func ProxyDisplayName(proxyType string) string {
return "AnyTLS" return "AnyTLS"
case TypeTailscale: case TypeTailscale:
return "Tailscale" return "Tailscale"
case TypeCloudflared:
return "Cloudflared"
case TypeSelector: case TypeSelector:
return "Selector" return "Selector"
case TypeURLTest: case TypeURLTest:

View File

@@ -23,15 +23,12 @@ const (
RuleSetVersion2 RuleSetVersion2
RuleSetVersion3 RuleSetVersion3
RuleSetVersion4 RuleSetVersion4
RuleSetVersion5 RuleSetVersionCurrent = RuleSetVersion4
RuleSetVersionCurrent = RuleSetVersion5
) )
const ( const (
RuleActionTypeRoute = "route" RuleActionTypeRoute = "route"
RuleActionTypeRouteOptions = "route-options" RuleActionTypeRouteOptions = "route-options"
RuleActionTypeEvaluate = "evaluate"
RuleActionTypeRespond = "respond"
RuleActionTypeDirect = "direct" RuleActionTypeDirect = "direct"
RuleActionTypeBypass = "bypass" RuleActionTypeBypass = "bypass"
RuleActionTypeReject = "reject" RuleActionTypeReject = "reject"

View File

@@ -87,17 +87,12 @@ func (s *StartedService) newInstance(profileContent string, overrideOptions *Ove
} }
} }
} }
if s.oomKillerEnabled { if s.oomKiller && C.IsIos {
if !common.Any(options.Services, func(it option.Service) bool { if !common.Any(options.Services, func(it option.Service) bool {
return it.Type == C.TypeOOMKiller return it.Type == C.TypeOOMKiller
}) { }) {
oomOptions := &option.OOMKillerServiceOptions{
KillerDisabled: s.oomKillerDisabled,
MemoryLimitOverride: s.oomMemoryLimit,
}
options.Services = append(options.Services, option.Service{ options.Services = append(options.Services, option.Service{
Type: C.TypeOOMKiller, Type: C.TypeOOMKiller,
Options: oomOptions,
}) })
} }
} }

View File

@@ -5,6 +5,5 @@ type PlatformHandler interface {
ServiceReload() error ServiceReload() error
SystemProxyStatus() (*SystemProxyStatus, error) SystemProxyStatus() (*SystemProxyStatus, error)
SetSystemProxyEnabled(enabled bool) error SetSystemProxyEnabled(enabled bool) error
TriggerNativeCrash() error
WriteDebugMessage(message string) WriteDebugMessage(message string)
} }

View File

@@ -6,20 +6,14 @@ import (
"runtime" "runtime"
"sync" "sync"
"time" "time"
"unsafe"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/common/networkquality"
"github.com/sagernet/sing-box/common/stun"
"github.com/sagernet/sing-box/common/urltest" "github.com/sagernet/sing-box/common/urltest"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/experimental/clashapi" "github.com/sagernet/sing-box/experimental/clashapi"
"github.com/sagernet/sing-box/experimental/clashapi/trafficontrol" "github.com/sagernet/sing-box/experimental/clashapi/trafficontrol"
"github.com/sagernet/sing-box/experimental/deprecated" "github.com/sagernet/sing-box/experimental/deprecated"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/protocol/group" "github.com/sagernet/sing-box/protocol/group"
"github.com/sagernet/sing-box/service/oomkiller"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/batch" "github.com/sagernet/sing/common/batch"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
@@ -30,8 +24,6 @@ import (
"github.com/gofrs/uuid/v5" "github.com/gofrs/uuid/v5"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb" "google.golang.org/protobuf/types/known/emptypb"
) )
@@ -40,12 +32,10 @@ var _ StartedServiceServer = (*StartedService)(nil)
type StartedService struct { type StartedService struct {
ctx context.Context ctx context.Context
// platform adapter.PlatformInterface // platform adapter.PlatformInterface
handler PlatformHandler handler PlatformHandler
debug bool debug bool
logMaxLines int logMaxLines int
oomKillerEnabled bool oomKiller bool
oomKillerDisabled bool
oomMemoryLimit uint64
// workingDirectory string // workingDirectory string
// tempDirectory string // tempDirectory string
// userID int // userID int
@@ -74,12 +64,10 @@ type StartedService struct {
type ServiceOptions struct { type ServiceOptions struct {
Context context.Context Context context.Context
// Platform adapter.PlatformInterface // Platform adapter.PlatformInterface
Handler PlatformHandler Handler PlatformHandler
Debug bool Debug bool
LogMaxLines int LogMaxLines int
OOMKillerEnabled bool OOMKiller bool
OOMKillerDisabled bool
OOMMemoryLimit uint64
// WorkingDirectory string // WorkingDirectory string
// TempDirectory string // TempDirectory string
// UserID int // UserID int
@@ -91,12 +79,10 @@ func NewStartedService(options ServiceOptions) *StartedService {
s := &StartedService{ s := &StartedService{
ctx: options.Context, ctx: options.Context,
// platform: options.Platform, // platform: options.Platform,
handler: options.Handler, handler: options.Handler,
debug: options.Debug, debug: options.Debug,
logMaxLines: options.LogMaxLines, logMaxLines: options.LogMaxLines,
oomKillerEnabled: options.OOMKillerEnabled, oomKiller: options.OOMKiller,
oomKillerDisabled: options.OOMKillerDisabled,
oomMemoryLimit: options.OOMMemoryLimit,
// workingDirectory: options.WorkingDirectory, // workingDirectory: options.WorkingDirectory,
// tempDirectory: options.TempDirectory, // tempDirectory: options.TempDirectory,
// userID: options.UserID, // userID: options.UserID,
@@ -182,7 +168,7 @@ func (s *StartedService) waitForStarted(ctx context.Context) error {
func (s *StartedService) StartOrReloadService(profileContent string, options *OverrideOptions) error { func (s *StartedService) StartOrReloadService(profileContent string, options *OverrideOptions) error {
s.serviceAccess.Lock() s.serviceAccess.Lock()
switch s.serviceStatus.Status { switch s.serviceStatus.Status {
case ServiceStatus_IDLE, ServiceStatus_STARTED, ServiceStatus_STARTING, ServiceStatus_FATAL: case ServiceStatus_IDLE, ServiceStatus_STARTED, ServiceStatus_STARTING:
default: default:
s.serviceAccess.Unlock() s.serviceAccess.Unlock()
return os.ErrInvalid return os.ErrInvalid
@@ -240,14 +226,13 @@ func (s *StartedService) CloseService() error {
return os.ErrInvalid return os.ErrInvalid
} }
s.updateStatus(ServiceStatus_STOPPING) s.updateStatus(ServiceStatus_STOPPING)
instance := s.instance if s.instance != nil {
s.instance = nil err := s.instance.Close()
if instance != nil {
err := instance.Close()
if err != nil { if err != nil {
return s.updateStatusError(err) return s.updateStatusError(err)
} }
} }
s.instance = nil
s.startedAt = time.Time{} s.startedAt = time.Time{}
s.updateStatus(ServiceStatus_IDLE) s.updateStatus(ServiceStatus_IDLE)
s.serviceAccess.Unlock() s.serviceAccess.Unlock()
@@ -696,42 +681,7 @@ func (s *StartedService) SetSystemProxyEnabled(ctx context.Context, request *Set
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &emptypb.Empty{}, nil return nil, err
}
func (s *StartedService) TriggerDebugCrash(ctx context.Context, request *DebugCrashRequest) (*emptypb.Empty, error) {
if !s.debug {
return nil, status.Error(codes.PermissionDenied, "debug crash trigger unavailable")
}
if request == nil {
return nil, status.Error(codes.InvalidArgument, "missing debug crash request")
}
switch request.Type {
case DebugCrashRequest_GO:
time.AfterFunc(200*time.Millisecond, func() {
*(*int)(unsafe.Pointer(uintptr(0))) = 0
})
case DebugCrashRequest_NATIVE:
err := s.handler.TriggerNativeCrash()
if err != nil {
return nil, err
}
default:
return nil, status.Error(codes.InvalidArgument, "unknown debug crash type")
}
return &emptypb.Empty{}, nil
}
func (s *StartedService) TriggerOOMReport(ctx context.Context, _ *emptypb.Empty) (*emptypb.Empty, error) {
instance := s.Instance()
if instance == nil {
return nil, status.Error(codes.FailedPrecondition, "service not started")
}
reporter := service.FromContext[oomkiller.OOMReporter](instance.ctx)
if reporter == nil {
return nil, status.Error(codes.Unavailable, "OOM reporter not available")
}
return &emptypb.Empty{}, reporter.WriteReport(memory.Total())
} }
func (s *StartedService) SubscribeConnections(request *SubscribeConnectionsRequest, server grpc.ServerStreamingServer[ConnectionEvents]) error { func (s *StartedService) SubscribeConnections(request *SubscribeConnectionsRequest, server grpc.ServerStreamingServer[ConnectionEvents]) error {
@@ -999,11 +949,11 @@ func buildConnectionProto(metadata *trafficontrol.TrackerMetadata) *Connection {
var processInfo *ProcessInfo var processInfo *ProcessInfo
if metadata.Metadata.ProcessInfo != nil { if metadata.Metadata.ProcessInfo != nil {
processInfo = &ProcessInfo{ processInfo = &ProcessInfo{
ProcessId: metadata.Metadata.ProcessInfo.ProcessID, ProcessId: metadata.Metadata.ProcessInfo.ProcessID,
UserId: metadata.Metadata.ProcessInfo.UserId, UserId: metadata.Metadata.ProcessInfo.UserId,
UserName: metadata.Metadata.ProcessInfo.UserName, UserName: metadata.Metadata.ProcessInfo.UserName,
ProcessPath: metadata.Metadata.ProcessInfo.ProcessPath, ProcessPath: metadata.Metadata.ProcessInfo.ProcessPath,
PackageNames: metadata.Metadata.ProcessInfo.AndroidPackageNames, PackageName: metadata.Metadata.ProcessInfo.AndroidPackageName,
} }
} }
return &Connection{ return &Connection{
@@ -1068,12 +1018,9 @@ func (s *StartedService) GetDeprecatedWarnings(ctx context.Context, empty *empty
return &DeprecatedWarnings{ return &DeprecatedWarnings{
Warnings: common.Map(notes, func(it deprecated.Note) *DeprecatedWarning { Warnings: common.Map(notes, func(it deprecated.Note) *DeprecatedWarning {
return &DeprecatedWarning{ return &DeprecatedWarning{
Message: it.Message(), Message: it.Message(),
Impending: it.Impending(), Impending: it.Impending(),
MigrationLink: it.MigrationLink, MigrationLink: it.MigrationLink,
Description: it.Description,
DeprecatedVersion: it.DeprecatedVersion,
ScheduledVersion: it.ScheduledVersion,
} }
}), }),
}, nil }, nil
@@ -1085,386 +1032,6 @@ func (s *StartedService) GetStartedAt(ctx context.Context, empty *emptypb.Empty)
return &StartedAt{StartedAt: s.startedAt.UnixMilli()}, nil return &StartedAt{StartedAt: s.startedAt.UnixMilli()}, nil
} }
func (s *StartedService) SubscribeOutbounds(_ *emptypb.Empty, server grpc.ServerStreamingServer[OutboundList]) error {
err := s.waitForStarted(server.Context())
if err != nil {
return err
}
subscription, done, err := s.urlTestObserver.Subscribe()
if err != nil {
return err
}
defer s.urlTestObserver.UnSubscribe(subscription)
for {
s.serviceAccess.RLock()
if s.serviceStatus.Status != ServiceStatus_STARTED {
s.serviceAccess.RUnlock()
return os.ErrInvalid
}
boxService := s.instance
s.serviceAccess.RUnlock()
historyStorage := boxService.urlTestHistoryStorage
var list OutboundList
for _, ob := range boxService.instance.Outbound().Outbounds() {
item := &GroupItem{
Tag: ob.Tag(),
Type: ob.Type(),
}
if history := historyStorage.LoadURLTestHistory(adapter.OutboundTag(ob)); history != nil {
item.UrlTestTime = history.Time.Unix()
item.UrlTestDelay = int32(history.Delay)
}
list.Outbounds = append(list.Outbounds, item)
}
for _, ep := range boxService.instance.Endpoint().Endpoints() {
item := &GroupItem{
Tag: ep.Tag(),
Type: ep.Type(),
}
if history := historyStorage.LoadURLTestHistory(adapter.OutboundTag(ep)); history != nil {
item.UrlTestTime = history.Time.Unix()
item.UrlTestDelay = int32(history.Delay)
}
list.Outbounds = append(list.Outbounds, item)
}
err = server.Send(&list)
if err != nil {
return err
}
select {
case <-subscription:
case <-s.ctx.Done():
return s.ctx.Err()
case <-server.Context().Done():
return server.Context().Err()
case <-done:
return nil
}
}
}
func resolveOutbound(instance *Instance, tag string) (adapter.Outbound, error) {
if tag == "" {
return instance.instance.Outbound().Default(), nil
}
outbound, loaded := instance.instance.Outbound().Outbound(tag)
if !loaded {
return nil, E.New("outbound not found: ", tag)
}
return outbound, nil
}
func (s *StartedService) StartNetworkQualityTest(
request *NetworkQualityTestRequest,
server grpc.ServerStreamingServer[NetworkQualityTestProgress],
) error {
err := s.waitForStarted(server.Context())
if err != nil {
return err
}
s.serviceAccess.RLock()
boxService := s.instance
s.serviceAccess.RUnlock()
outbound, err := resolveOutbound(boxService, request.OutboundTag)
if err != nil {
return err
}
resolvedDialer := dialer.NewResolveDialer(boxService.ctx, outbound, true, "", adapter.DNSQueryOptions{}, 0)
httpClient := networkquality.NewHTTPClient(resolvedDialer)
defer httpClient.CloseIdleConnections()
measurementClientFactory, err := networkquality.NewOptionalHTTP3Factory(resolvedDialer, request.Http3)
if err != nil {
return err
}
result, nqErr := networkquality.Run(networkquality.Options{
ConfigURL: request.ConfigURL,
HTTPClient: httpClient,
NewMeasurementClient: measurementClientFactory,
Serial: request.Serial,
MaxRuntime: time.Duration(request.MaxRuntimeSeconds) * time.Second,
Context: server.Context(),
OnProgress: func(p networkquality.Progress) {
_ = server.Send(&NetworkQualityTestProgress{
Phase: int32(p.Phase),
DownloadCapacity: p.DownloadCapacity,
UploadCapacity: p.UploadCapacity,
DownloadRPM: p.DownloadRPM,
UploadRPM: p.UploadRPM,
IdleLatencyMs: p.IdleLatencyMs,
ElapsedMs: p.ElapsedMs,
DownloadCapacityAccuracy: int32(p.DownloadCapacityAccuracy),
UploadCapacityAccuracy: int32(p.UploadCapacityAccuracy),
DownloadRPMAccuracy: int32(p.DownloadRPMAccuracy),
UploadRPMAccuracy: int32(p.UploadRPMAccuracy),
})
},
})
if nqErr != nil {
return server.Send(&NetworkQualityTestProgress{
IsFinal: true,
Error: nqErr.Error(),
})
}
return server.Send(&NetworkQualityTestProgress{
Phase: int32(networkquality.PhaseDone),
DownloadCapacity: result.DownloadCapacity,
UploadCapacity: result.UploadCapacity,
DownloadRPM: result.DownloadRPM,
UploadRPM: result.UploadRPM,
IdleLatencyMs: result.IdleLatencyMs,
IsFinal: true,
DownloadCapacityAccuracy: int32(result.DownloadCapacityAccuracy),
UploadCapacityAccuracy: int32(result.UploadCapacityAccuracy),
DownloadRPMAccuracy: int32(result.DownloadRPMAccuracy),
UploadRPMAccuracy: int32(result.UploadRPMAccuracy),
})
}
func (s *StartedService) StartSTUNTest(
request *STUNTestRequest,
server grpc.ServerStreamingServer[STUNTestProgress],
) error {
err := s.waitForStarted(server.Context())
if err != nil {
return err
}
s.serviceAccess.RLock()
boxService := s.instance
s.serviceAccess.RUnlock()
outbound, err := resolveOutbound(boxService, request.OutboundTag)
if err != nil {
return err
}
resolvedDialer := dialer.NewResolveDialer(boxService.ctx, outbound, true, "", adapter.DNSQueryOptions{}, 0)
result, stunErr := stun.Run(stun.Options{
Server: request.Server,
Dialer: resolvedDialer,
Context: server.Context(),
OnProgress: func(p stun.Progress) {
_ = server.Send(&STUNTestProgress{
Phase: int32(p.Phase),
ExternalAddr: p.ExternalAddr,
LatencyMs: p.LatencyMs,
NatMapping: int32(p.NATMapping),
NatFiltering: int32(p.NATFiltering),
})
},
})
if stunErr != nil {
return server.Send(&STUNTestProgress{
IsFinal: true,
Error: stunErr.Error(),
})
}
return server.Send(&STUNTestProgress{
Phase: int32(stun.PhaseDone),
ExternalAddr: result.ExternalAddr,
LatencyMs: result.LatencyMs,
NatMapping: int32(result.NATMapping),
NatFiltering: int32(result.NATFiltering),
IsFinal: true,
NatTypeSupported: result.NATTypeSupported,
})
}
func (s *StartedService) SubscribeTailscaleStatus(
_ *emptypb.Empty,
server grpc.ServerStreamingServer[TailscaleStatusUpdate],
) error {
err := s.waitForStarted(server.Context())
if err != nil {
return err
}
s.serviceAccess.RLock()
boxService := s.instance
s.serviceAccess.RUnlock()
endpointManager := service.FromContext[adapter.EndpointManager](boxService.ctx)
if endpointManager == nil {
return status.Error(codes.FailedPrecondition, "endpoint manager not available")
}
type tailscaleEndpoint struct {
tag string
provider adapter.TailscaleEndpoint
}
var endpoints []tailscaleEndpoint
for _, endpoint := range endpointManager.Endpoints() {
if endpoint.Type() != C.TypeTailscale {
continue
}
provider, loaded := endpoint.(adapter.TailscaleEndpoint)
if !loaded {
continue
}
endpoints = append(endpoints, tailscaleEndpoint{
tag: endpoint.Tag(),
provider: provider,
})
}
if len(endpoints) == 0 {
return status.Error(codes.NotFound, "no Tailscale endpoint found")
}
type taggedStatus struct {
tag string
status *adapter.TailscaleEndpointStatus
}
updates := make(chan taggedStatus, len(endpoints))
ctx, cancel := context.WithCancel(server.Context())
defer cancel()
var waitGroup sync.WaitGroup
for _, endpoint := range endpoints {
waitGroup.Add(1)
go func(tag string, provider adapter.TailscaleEndpoint) {
defer waitGroup.Done()
_ = provider.SubscribeTailscaleStatus(ctx, func(endpointStatus *adapter.TailscaleEndpointStatus) {
select {
case updates <- taggedStatus{tag: tag, status: endpointStatus}:
case <-ctx.Done():
}
})
}(endpoint.tag, endpoint.provider)
}
go func() {
waitGroup.Wait()
close(updates)
}()
var tags []string
statuses := make(map[string]*adapter.TailscaleEndpointStatus, len(endpoints))
for update := range updates {
if _, exists := statuses[update.tag]; !exists {
tags = append(tags, update.tag)
}
statuses[update.tag] = update.status
protoEndpoints := make([]*TailscaleEndpointStatus, 0, len(statuses))
for _, tag := range tags {
protoEndpoints = append(protoEndpoints, tailscaleEndpointStatusToProto(tag, statuses[tag]))
}
sendErr := server.Send(&TailscaleStatusUpdate{
Endpoints: protoEndpoints,
})
if sendErr != nil {
return sendErr
}
}
return nil
}
func tailscaleEndpointStatusToProto(tag string, s *adapter.TailscaleEndpointStatus) *TailscaleEndpointStatus {
userGroups := make([]*TailscaleUserGroup, len(s.UserGroups))
for i, group := range s.UserGroups {
peers := make([]*TailscalePeer, len(group.Peers))
for j, peer := range group.Peers {
peers[j] = tailscalePeerToProto(peer)
}
userGroups[i] = &TailscaleUserGroup{
UserID: group.UserID,
LoginName: group.LoginName,
DisplayName: group.DisplayName,
ProfilePicURL: group.ProfilePicURL,
Peers: peers,
}
}
result := &TailscaleEndpointStatus{
EndpointTag: tag,
BackendState: s.BackendState,
AuthURL: s.AuthURL,
NetworkName: s.NetworkName,
MagicDNSSuffix: s.MagicDNSSuffix,
UserGroups: userGroups,
}
if s.Self != nil {
result.Self = tailscalePeerToProto(s.Self)
}
return result
}
func tailscalePeerToProto(peer *adapter.TailscalePeer) *TailscalePeer {
return &TailscalePeer{
HostName: peer.HostName,
DnsName: peer.DNSName,
Os: peer.OS,
TailscaleIPs: peer.TailscaleIPs,
Online: peer.Online,
ExitNode: peer.ExitNode,
ExitNodeOption: peer.ExitNodeOption,
Active: peer.Active,
RxBytes: peer.RxBytes,
TxBytes: peer.TxBytes,
KeyExpiry: peer.KeyExpiry,
}
}
func (s *StartedService) StartTailscalePing(
request *TailscalePingRequest,
server grpc.ServerStreamingServer[TailscalePingResponse],
) error {
err := s.waitForStarted(server.Context())
if err != nil {
return err
}
s.serviceAccess.RLock()
boxService := s.instance
s.serviceAccess.RUnlock()
endpointManager := service.FromContext[adapter.EndpointManager](boxService.ctx)
if endpointManager == nil {
return status.Error(codes.FailedPrecondition, "endpoint manager not available")
}
var provider adapter.TailscaleEndpoint
if request.EndpointTag != "" {
endpoint, loaded := endpointManager.Get(request.EndpointTag)
if !loaded {
return status.Error(codes.NotFound, "endpoint not found: "+request.EndpointTag)
}
if endpoint.Type() != C.TypeTailscale {
return status.Error(codes.InvalidArgument, "endpoint is not Tailscale: "+request.EndpointTag)
}
pingProvider, loaded := endpoint.(adapter.TailscaleEndpoint)
if !loaded {
return status.Error(codes.FailedPrecondition, "endpoint does not support ping")
}
provider = pingProvider
} else {
for _, endpoint := range endpointManager.Endpoints() {
if endpoint.Type() != C.TypeTailscale {
continue
}
pingProvider, loaded := endpoint.(adapter.TailscaleEndpoint)
if loaded {
provider = pingProvider
break
}
}
if provider == nil {
return status.Error(codes.NotFound, "no Tailscale endpoint found")
}
}
return provider.StartTailscalePing(server.Context(), request.PeerIP, func(result *adapter.TailscalePingResult) {
_ = server.Send(&TailscalePingResponse{
LatencyMs: result.LatencyMs,
IsDirect: result.IsDirect,
Endpoint: result.Endpoint,
DerpRegionID: result.DERPRegionID,
DerpRegionCode: result.DERPRegionCode,
Error: result.Error,
})
})
}
func (s *StartedService) mustEmbedUnimplementedStartedServiceServer() { func (s *StartedService) mustEmbedUnimplementedStartedServiceServer() {
} }

File diff suppressed because it is too large Load Diff

View File

@@ -26,20 +26,12 @@ service StartedService {
rpc GetSystemProxyStatus(google.protobuf.Empty) returns(SystemProxyStatus) {} rpc GetSystemProxyStatus(google.protobuf.Empty) returns(SystemProxyStatus) {}
rpc SetSystemProxyEnabled(SetSystemProxyEnabledRequest) returns(google.protobuf.Empty) {} rpc SetSystemProxyEnabled(SetSystemProxyEnabledRequest) returns(google.protobuf.Empty) {}
rpc TriggerDebugCrash(DebugCrashRequest) returns(google.protobuf.Empty) {}
rpc TriggerOOMReport(google.protobuf.Empty) returns(google.protobuf.Empty) {}
rpc SubscribeConnections(SubscribeConnectionsRequest) returns(stream ConnectionEvents) {} rpc SubscribeConnections(SubscribeConnectionsRequest) returns(stream ConnectionEvents) {}
rpc CloseConnection(CloseConnectionRequest) returns(google.protobuf.Empty) {} rpc CloseConnection(CloseConnectionRequest) returns(google.protobuf.Empty) {}
rpc CloseAllConnections(google.protobuf.Empty) returns(google.protobuf.Empty) {} rpc CloseAllConnections(google.protobuf.Empty) returns(google.protobuf.Empty) {}
rpc GetDeprecatedWarnings(google.protobuf.Empty) returns(DeprecatedWarnings) {} rpc GetDeprecatedWarnings(google.protobuf.Empty) returns(DeprecatedWarnings) {}
rpc GetStartedAt(google.protobuf.Empty) returns(StartedAt) {} rpc GetStartedAt(google.protobuf.Empty) returns(StartedAt) {}
rpc SubscribeOutbounds(google.protobuf.Empty) returns (stream OutboundList) {}
rpc StartNetworkQualityTest(NetworkQualityTestRequest) returns (stream NetworkQualityTestProgress) {}
rpc StartSTUNTest(STUNTestRequest) returns (stream STUNTestProgress) {}
rpc SubscribeTailscaleStatus(google.protobuf.Empty) returns (stream TailscaleStatusUpdate) {}
rpc StartTailscalePing(TailscalePingRequest) returns (stream TailscalePingResponse) {}
} }
message ServiceStatus { message ServiceStatus {
@@ -149,15 +141,6 @@ message SetSystemProxyEnabledRequest {
bool enabled = 1; bool enabled = 1;
} }
message DebugCrashRequest {
enum Type {
GO = 0;
NATIVE = 1;
}
Type type = 1;
}
message SubscribeConnectionsRequest { message SubscribeConnectionsRequest {
int64 interval = 1; int64 interval = 1;
} }
@@ -212,7 +195,7 @@ message ProcessInfo {
int32 userId = 2; int32 userId = 2;
string userName = 3; string userName = 3;
string processPath = 4; string processPath = 4;
repeated string packageNames = 5; string packageName = 5;
} }
message CloseConnectionRequest { message CloseConnectionRequest {
@@ -227,105 +210,8 @@ message DeprecatedWarning {
string message = 1; string message = 1;
bool impending = 2; bool impending = 2;
string migrationLink = 3; string migrationLink = 3;
string description = 4;
string deprecatedVersion = 5;
string scheduledVersion = 6;
} }
message StartedAt { message StartedAt {
int64 startedAt = 1; int64 startedAt = 1;
} }
message OutboundList {
repeated GroupItem outbounds = 1;
}
message NetworkQualityTestRequest {
string configURL = 1;
string outboundTag = 2;
bool serial = 3;
int32 maxRuntimeSeconds = 4;
bool http3 = 5;
}
message NetworkQualityTestProgress {
int32 phase = 1;
int64 downloadCapacity = 2;
int64 uploadCapacity = 3;
int32 downloadRPM = 4;
int32 uploadRPM = 5;
int32 idleLatencyMs = 6;
int64 elapsedMs = 7;
bool isFinal = 8;
string error = 9;
int32 downloadCapacityAccuracy = 10;
int32 uploadCapacityAccuracy = 11;
int32 downloadRPMAccuracy = 12;
int32 uploadRPMAccuracy = 13;
}
message STUNTestRequest {
string server = 1;
string outboundTag = 2;
}
message STUNTestProgress {
int32 phase = 1;
string externalAddr = 2;
int32 latencyMs = 3;
int32 natMapping = 4;
int32 natFiltering = 5;
bool isFinal = 6;
string error = 7;
bool natTypeSupported = 8;
}
message TailscaleStatusUpdate {
repeated TailscaleEndpointStatus endpoints = 1;
}
message TailscaleEndpointStatus {
string endpointTag = 1;
string backendState = 2;
string authURL = 3;
string networkName = 4;
string magicDNSSuffix = 5;
TailscalePeer self = 6;
repeated TailscaleUserGroup userGroups = 7;
}
message TailscaleUserGroup {
int64 userID = 1;
string loginName = 2;
string displayName = 3;
string profilePicURL = 4;
repeated TailscalePeer peers = 5;
}
message TailscalePeer {
string hostName = 1;
string dnsName = 2;
string os = 3;
repeated string tailscaleIPs = 4;
bool online = 5;
bool exitNode = 6;
bool exitNodeOption = 7;
bool active = 8;
int64 rxBytes = 9;
int64 txBytes = 10;
int64 keyExpiry = 11;
}
message TailscalePingRequest {
string endpointTag = 1;
string peerIP = 2;
}
message TailscalePingResponse {
double latencyMs = 1;
bool isDirect = 2;
string endpoint = 3;
int32 derpRegionID = 4;
string derpRegionCode = 5;
string error = 6;
}

View File

@@ -15,34 +15,27 @@ import (
const _ = grpc.SupportPackageIsVersion9 const _ = grpc.SupportPackageIsVersion9
const ( const (
StartedService_StopService_FullMethodName = "/daemon.StartedService/StopService" StartedService_StopService_FullMethodName = "/daemon.StartedService/StopService"
StartedService_ReloadService_FullMethodName = "/daemon.StartedService/ReloadService" StartedService_ReloadService_FullMethodName = "/daemon.StartedService/ReloadService"
StartedService_SubscribeServiceStatus_FullMethodName = "/daemon.StartedService/SubscribeServiceStatus" StartedService_SubscribeServiceStatus_FullMethodName = "/daemon.StartedService/SubscribeServiceStatus"
StartedService_SubscribeLog_FullMethodName = "/daemon.StartedService/SubscribeLog" StartedService_SubscribeLog_FullMethodName = "/daemon.StartedService/SubscribeLog"
StartedService_GetDefaultLogLevel_FullMethodName = "/daemon.StartedService/GetDefaultLogLevel" StartedService_GetDefaultLogLevel_FullMethodName = "/daemon.StartedService/GetDefaultLogLevel"
StartedService_ClearLogs_FullMethodName = "/daemon.StartedService/ClearLogs" StartedService_ClearLogs_FullMethodName = "/daemon.StartedService/ClearLogs"
StartedService_SubscribeStatus_FullMethodName = "/daemon.StartedService/SubscribeStatus" StartedService_SubscribeStatus_FullMethodName = "/daemon.StartedService/SubscribeStatus"
StartedService_SubscribeGroups_FullMethodName = "/daemon.StartedService/SubscribeGroups" StartedService_SubscribeGroups_FullMethodName = "/daemon.StartedService/SubscribeGroups"
StartedService_GetClashModeStatus_FullMethodName = "/daemon.StartedService/GetClashModeStatus" StartedService_GetClashModeStatus_FullMethodName = "/daemon.StartedService/GetClashModeStatus"
StartedService_SubscribeClashMode_FullMethodName = "/daemon.StartedService/SubscribeClashMode" StartedService_SubscribeClashMode_FullMethodName = "/daemon.StartedService/SubscribeClashMode"
StartedService_SetClashMode_FullMethodName = "/daemon.StartedService/SetClashMode" StartedService_SetClashMode_FullMethodName = "/daemon.StartedService/SetClashMode"
StartedService_URLTest_FullMethodName = "/daemon.StartedService/URLTest" StartedService_URLTest_FullMethodName = "/daemon.StartedService/URLTest"
StartedService_SelectOutbound_FullMethodName = "/daemon.StartedService/SelectOutbound" StartedService_SelectOutbound_FullMethodName = "/daemon.StartedService/SelectOutbound"
StartedService_SetGroupExpand_FullMethodName = "/daemon.StartedService/SetGroupExpand" StartedService_SetGroupExpand_FullMethodName = "/daemon.StartedService/SetGroupExpand"
StartedService_GetSystemProxyStatus_FullMethodName = "/daemon.StartedService/GetSystemProxyStatus" StartedService_GetSystemProxyStatus_FullMethodName = "/daemon.StartedService/GetSystemProxyStatus"
StartedService_SetSystemProxyEnabled_FullMethodName = "/daemon.StartedService/SetSystemProxyEnabled" StartedService_SetSystemProxyEnabled_FullMethodName = "/daemon.StartedService/SetSystemProxyEnabled"
StartedService_TriggerDebugCrash_FullMethodName = "/daemon.StartedService/TriggerDebugCrash" StartedService_SubscribeConnections_FullMethodName = "/daemon.StartedService/SubscribeConnections"
StartedService_TriggerOOMReport_FullMethodName = "/daemon.StartedService/TriggerOOMReport" StartedService_CloseConnection_FullMethodName = "/daemon.StartedService/CloseConnection"
StartedService_SubscribeConnections_FullMethodName = "/daemon.StartedService/SubscribeConnections" StartedService_CloseAllConnections_FullMethodName = "/daemon.StartedService/CloseAllConnections"
StartedService_CloseConnection_FullMethodName = "/daemon.StartedService/CloseConnection" StartedService_GetDeprecatedWarnings_FullMethodName = "/daemon.StartedService/GetDeprecatedWarnings"
StartedService_CloseAllConnections_FullMethodName = "/daemon.StartedService/CloseAllConnections" StartedService_GetStartedAt_FullMethodName = "/daemon.StartedService/GetStartedAt"
StartedService_GetDeprecatedWarnings_FullMethodName = "/daemon.StartedService/GetDeprecatedWarnings"
StartedService_GetStartedAt_FullMethodName = "/daemon.StartedService/GetStartedAt"
StartedService_SubscribeOutbounds_FullMethodName = "/daemon.StartedService/SubscribeOutbounds"
StartedService_StartNetworkQualityTest_FullMethodName = "/daemon.StartedService/StartNetworkQualityTest"
StartedService_StartSTUNTest_FullMethodName = "/daemon.StartedService/StartSTUNTest"
StartedService_SubscribeTailscaleStatus_FullMethodName = "/daemon.StartedService/SubscribeTailscaleStatus"
StartedService_StartTailscalePing_FullMethodName = "/daemon.StartedService/StartTailscalePing"
) )
// StartedServiceClient is the client API for StartedService service. // StartedServiceClient is the client API for StartedService service.
@@ -65,18 +58,11 @@ type StartedServiceClient interface {
SetGroupExpand(ctx context.Context, in *SetGroupExpandRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) SetGroupExpand(ctx context.Context, in *SetGroupExpandRequest, opts ...grpc.CallOption) (*emptypb.Empty, error)
GetSystemProxyStatus(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*SystemProxyStatus, error) GetSystemProxyStatus(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*SystemProxyStatus, error)
SetSystemProxyEnabled(ctx context.Context, in *SetSystemProxyEnabledRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) SetSystemProxyEnabled(ctx context.Context, in *SetSystemProxyEnabledRequest, opts ...grpc.CallOption) (*emptypb.Empty, error)
TriggerDebugCrash(ctx context.Context, in *DebugCrashRequest, opts ...grpc.CallOption) (*emptypb.Empty, error)
TriggerOOMReport(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*emptypb.Empty, error)
SubscribeConnections(ctx context.Context, in *SubscribeConnectionsRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ConnectionEvents], error) SubscribeConnections(ctx context.Context, in *SubscribeConnectionsRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ConnectionEvents], error)
CloseConnection(ctx context.Context, in *CloseConnectionRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) CloseConnection(ctx context.Context, in *CloseConnectionRequest, opts ...grpc.CallOption) (*emptypb.Empty, error)
CloseAllConnections(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*emptypb.Empty, error) CloseAllConnections(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*emptypb.Empty, error)
GetDeprecatedWarnings(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*DeprecatedWarnings, error) GetDeprecatedWarnings(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*DeprecatedWarnings, error)
GetStartedAt(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*StartedAt, error) GetStartedAt(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*StartedAt, error)
SubscribeOutbounds(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (grpc.ServerStreamingClient[OutboundList], error)
StartNetworkQualityTest(ctx context.Context, in *NetworkQualityTestRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[NetworkQualityTestProgress], error)
StartSTUNTest(ctx context.Context, in *STUNTestRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[STUNTestProgress], error)
SubscribeTailscaleStatus(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (grpc.ServerStreamingClient[TailscaleStatusUpdate], error)
StartTailscalePing(ctx context.Context, in *TailscalePingRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[TailscalePingResponse], error)
} }
type startedServiceClient struct { type startedServiceClient struct {
@@ -292,26 +278,6 @@ func (c *startedServiceClient) SetSystemProxyEnabled(ctx context.Context, in *Se
return out, nil return out, nil
} }
func (c *startedServiceClient) TriggerDebugCrash(ctx context.Context, in *DebugCrashRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(emptypb.Empty)
err := c.cc.Invoke(ctx, StartedService_TriggerDebugCrash_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *startedServiceClient) TriggerOOMReport(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*emptypb.Empty, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(emptypb.Empty)
err := c.cc.Invoke(ctx, StartedService_TriggerOOMReport_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *startedServiceClient) SubscribeConnections(ctx context.Context, in *SubscribeConnectionsRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ConnectionEvents], error) { func (c *startedServiceClient) SubscribeConnections(ctx context.Context, in *SubscribeConnectionsRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ConnectionEvents], error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
stream, err := c.cc.NewStream(ctx, &StartedService_ServiceDesc.Streams[5], StartedService_SubscribeConnections_FullMethodName, cOpts...) stream, err := c.cc.NewStream(ctx, &StartedService_ServiceDesc.Streams[5], StartedService_SubscribeConnections_FullMethodName, cOpts...)
@@ -371,101 +337,6 @@ func (c *startedServiceClient) GetStartedAt(ctx context.Context, in *emptypb.Emp
return out, nil return out, nil
} }
func (c *startedServiceClient) SubscribeOutbounds(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (grpc.ServerStreamingClient[OutboundList], error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
stream, err := c.cc.NewStream(ctx, &StartedService_ServiceDesc.Streams[6], StartedService_SubscribeOutbounds_FullMethodName, cOpts...)
if err != nil {
return nil, err
}
x := &grpc.GenericClientStream[emptypb.Empty, OutboundList]{ClientStream: stream}
if err := x.ClientStream.SendMsg(in); err != nil {
return nil, err
}
if err := x.ClientStream.CloseSend(); err != nil {
return nil, err
}
return x, nil
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type StartedService_SubscribeOutboundsClient = grpc.ServerStreamingClient[OutboundList]
func (c *startedServiceClient) StartNetworkQualityTest(ctx context.Context, in *NetworkQualityTestRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[NetworkQualityTestProgress], error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
stream, err := c.cc.NewStream(ctx, &StartedService_ServiceDesc.Streams[7], StartedService_StartNetworkQualityTest_FullMethodName, cOpts...)
if err != nil {
return nil, err
}
x := &grpc.GenericClientStream[NetworkQualityTestRequest, NetworkQualityTestProgress]{ClientStream: stream}
if err := x.ClientStream.SendMsg(in); err != nil {
return nil, err
}
if err := x.ClientStream.CloseSend(); err != nil {
return nil, err
}
return x, nil
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type StartedService_StartNetworkQualityTestClient = grpc.ServerStreamingClient[NetworkQualityTestProgress]
func (c *startedServiceClient) StartSTUNTest(ctx context.Context, in *STUNTestRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[STUNTestProgress], error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
stream, err := c.cc.NewStream(ctx, &StartedService_ServiceDesc.Streams[8], StartedService_StartSTUNTest_FullMethodName, cOpts...)
if err != nil {
return nil, err
}
x := &grpc.GenericClientStream[STUNTestRequest, STUNTestProgress]{ClientStream: stream}
if err := x.ClientStream.SendMsg(in); err != nil {
return nil, err
}
if err := x.ClientStream.CloseSend(); err != nil {
return nil, err
}
return x, nil
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type StartedService_StartSTUNTestClient = grpc.ServerStreamingClient[STUNTestProgress]
func (c *startedServiceClient) SubscribeTailscaleStatus(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (grpc.ServerStreamingClient[TailscaleStatusUpdate], error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
stream, err := c.cc.NewStream(ctx, &StartedService_ServiceDesc.Streams[9], StartedService_SubscribeTailscaleStatus_FullMethodName, cOpts...)
if err != nil {
return nil, err
}
x := &grpc.GenericClientStream[emptypb.Empty, TailscaleStatusUpdate]{ClientStream: stream}
if err := x.ClientStream.SendMsg(in); err != nil {
return nil, err
}
if err := x.ClientStream.CloseSend(); err != nil {
return nil, err
}
return x, nil
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type StartedService_SubscribeTailscaleStatusClient = grpc.ServerStreamingClient[TailscaleStatusUpdate]
func (c *startedServiceClient) StartTailscalePing(ctx context.Context, in *TailscalePingRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[TailscalePingResponse], error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
stream, err := c.cc.NewStream(ctx, &StartedService_ServiceDesc.Streams[10], StartedService_StartTailscalePing_FullMethodName, cOpts...)
if err != nil {
return nil, err
}
x := &grpc.GenericClientStream[TailscalePingRequest, TailscalePingResponse]{ClientStream: stream}
if err := x.ClientStream.SendMsg(in); err != nil {
return nil, err
}
if err := x.ClientStream.CloseSend(); err != nil {
return nil, err
}
return x, nil
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type StartedService_StartTailscalePingClient = grpc.ServerStreamingClient[TailscalePingResponse]
// StartedServiceServer is the server API for StartedService service. // StartedServiceServer is the server API for StartedService service.
// All implementations must embed UnimplementedStartedServiceServer // All implementations must embed UnimplementedStartedServiceServer
// for forward compatibility. // for forward compatibility.
@@ -486,18 +357,11 @@ type StartedServiceServer interface {
SetGroupExpand(context.Context, *SetGroupExpandRequest) (*emptypb.Empty, error) SetGroupExpand(context.Context, *SetGroupExpandRequest) (*emptypb.Empty, error)
GetSystemProxyStatus(context.Context, *emptypb.Empty) (*SystemProxyStatus, error) GetSystemProxyStatus(context.Context, *emptypb.Empty) (*SystemProxyStatus, error)
SetSystemProxyEnabled(context.Context, *SetSystemProxyEnabledRequest) (*emptypb.Empty, error) SetSystemProxyEnabled(context.Context, *SetSystemProxyEnabledRequest) (*emptypb.Empty, error)
TriggerDebugCrash(context.Context, *DebugCrashRequest) (*emptypb.Empty, error)
TriggerOOMReport(context.Context, *emptypb.Empty) (*emptypb.Empty, error)
SubscribeConnections(*SubscribeConnectionsRequest, grpc.ServerStreamingServer[ConnectionEvents]) error SubscribeConnections(*SubscribeConnectionsRequest, grpc.ServerStreamingServer[ConnectionEvents]) error
CloseConnection(context.Context, *CloseConnectionRequest) (*emptypb.Empty, error) CloseConnection(context.Context, *CloseConnectionRequest) (*emptypb.Empty, error)
CloseAllConnections(context.Context, *emptypb.Empty) (*emptypb.Empty, error) CloseAllConnections(context.Context, *emptypb.Empty) (*emptypb.Empty, error)
GetDeprecatedWarnings(context.Context, *emptypb.Empty) (*DeprecatedWarnings, error) GetDeprecatedWarnings(context.Context, *emptypb.Empty) (*DeprecatedWarnings, error)
GetStartedAt(context.Context, *emptypb.Empty) (*StartedAt, error) GetStartedAt(context.Context, *emptypb.Empty) (*StartedAt, error)
SubscribeOutbounds(*emptypb.Empty, grpc.ServerStreamingServer[OutboundList]) error
StartNetworkQualityTest(*NetworkQualityTestRequest, grpc.ServerStreamingServer[NetworkQualityTestProgress]) error
StartSTUNTest(*STUNTestRequest, grpc.ServerStreamingServer[STUNTestProgress]) error
SubscribeTailscaleStatus(*emptypb.Empty, grpc.ServerStreamingServer[TailscaleStatusUpdate]) error
StartTailscalePing(*TailscalePingRequest, grpc.ServerStreamingServer[TailscalePingResponse]) error
mustEmbedUnimplementedStartedServiceServer() mustEmbedUnimplementedStartedServiceServer()
} }
@@ -572,14 +436,6 @@ func (UnimplementedStartedServiceServer) SetSystemProxyEnabled(context.Context,
return nil, status.Error(codes.Unimplemented, "method SetSystemProxyEnabled not implemented") return nil, status.Error(codes.Unimplemented, "method SetSystemProxyEnabled not implemented")
} }
func (UnimplementedStartedServiceServer) TriggerDebugCrash(context.Context, *DebugCrashRequest) (*emptypb.Empty, error) {
return nil, status.Error(codes.Unimplemented, "method TriggerDebugCrash not implemented")
}
func (UnimplementedStartedServiceServer) TriggerOOMReport(context.Context, *emptypb.Empty) (*emptypb.Empty, error) {
return nil, status.Error(codes.Unimplemented, "method TriggerOOMReport not implemented")
}
func (UnimplementedStartedServiceServer) SubscribeConnections(*SubscribeConnectionsRequest, grpc.ServerStreamingServer[ConnectionEvents]) error { func (UnimplementedStartedServiceServer) SubscribeConnections(*SubscribeConnectionsRequest, grpc.ServerStreamingServer[ConnectionEvents]) error {
return status.Error(codes.Unimplemented, "method SubscribeConnections not implemented") return status.Error(codes.Unimplemented, "method SubscribeConnections not implemented")
} }
@@ -599,26 +455,6 @@ func (UnimplementedStartedServiceServer) GetDeprecatedWarnings(context.Context,
func (UnimplementedStartedServiceServer) GetStartedAt(context.Context, *emptypb.Empty) (*StartedAt, error) { func (UnimplementedStartedServiceServer) GetStartedAt(context.Context, *emptypb.Empty) (*StartedAt, error) {
return nil, status.Error(codes.Unimplemented, "method GetStartedAt not implemented") return nil, status.Error(codes.Unimplemented, "method GetStartedAt not implemented")
} }
func (UnimplementedStartedServiceServer) SubscribeOutbounds(*emptypb.Empty, grpc.ServerStreamingServer[OutboundList]) error {
return status.Error(codes.Unimplemented, "method SubscribeOutbounds not implemented")
}
func (UnimplementedStartedServiceServer) StartNetworkQualityTest(*NetworkQualityTestRequest, grpc.ServerStreamingServer[NetworkQualityTestProgress]) error {
return status.Error(codes.Unimplemented, "method StartNetworkQualityTest not implemented")
}
func (UnimplementedStartedServiceServer) StartSTUNTest(*STUNTestRequest, grpc.ServerStreamingServer[STUNTestProgress]) error {
return status.Error(codes.Unimplemented, "method StartSTUNTest not implemented")
}
func (UnimplementedStartedServiceServer) SubscribeTailscaleStatus(*emptypb.Empty, grpc.ServerStreamingServer[TailscaleStatusUpdate]) error {
return status.Error(codes.Unimplemented, "method SubscribeTailscaleStatus not implemented")
}
func (UnimplementedStartedServiceServer) StartTailscalePing(*TailscalePingRequest, grpc.ServerStreamingServer[TailscalePingResponse]) error {
return status.Error(codes.Unimplemented, "method StartTailscalePing not implemented")
}
func (UnimplementedStartedServiceServer) mustEmbedUnimplementedStartedServiceServer() {} func (UnimplementedStartedServiceServer) mustEmbedUnimplementedStartedServiceServer() {}
func (UnimplementedStartedServiceServer) testEmbeddedByValue() {} func (UnimplementedStartedServiceServer) testEmbeddedByValue() {}
@@ -893,42 +729,6 @@ func _StartedService_SetSystemProxyEnabled_Handler(srv interface{}, ctx context.
return interceptor(ctx, in, info, handler) return interceptor(ctx, in, info, handler)
} }
func _StartedService_TriggerDebugCrash_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(DebugCrashRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(StartedServiceServer).TriggerDebugCrash(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: StartedService_TriggerDebugCrash_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(StartedServiceServer).TriggerDebugCrash(ctx, req.(*DebugCrashRequest))
}
return interceptor(ctx, in, info, handler)
}
func _StartedService_TriggerOOMReport_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(emptypb.Empty)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(StartedServiceServer).TriggerOOMReport(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: StartedService_TriggerOOMReport_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(StartedServiceServer).TriggerOOMReport(ctx, req.(*emptypb.Empty))
}
return interceptor(ctx, in, info, handler)
}
func _StartedService_SubscribeConnections_Handler(srv interface{}, stream grpc.ServerStream) error { func _StartedService_SubscribeConnections_Handler(srv interface{}, stream grpc.ServerStream) error {
m := new(SubscribeConnectionsRequest) m := new(SubscribeConnectionsRequest)
if err := stream.RecvMsg(m); err != nil { if err := stream.RecvMsg(m); err != nil {
@@ -1012,61 +812,6 @@ func _StartedService_GetStartedAt_Handler(srv interface{}, ctx context.Context,
return interceptor(ctx, in, info, handler) return interceptor(ctx, in, info, handler)
} }
func _StartedService_SubscribeOutbounds_Handler(srv interface{}, stream grpc.ServerStream) error {
m := new(emptypb.Empty)
if err := stream.RecvMsg(m); err != nil {
return err
}
return srv.(StartedServiceServer).SubscribeOutbounds(m, &grpc.GenericServerStream[emptypb.Empty, OutboundList]{ServerStream: stream})
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type StartedService_SubscribeOutboundsServer = grpc.ServerStreamingServer[OutboundList]
func _StartedService_StartNetworkQualityTest_Handler(srv interface{}, stream grpc.ServerStream) error {
m := new(NetworkQualityTestRequest)
if err := stream.RecvMsg(m); err != nil {
return err
}
return srv.(StartedServiceServer).StartNetworkQualityTest(m, &grpc.GenericServerStream[NetworkQualityTestRequest, NetworkQualityTestProgress]{ServerStream: stream})
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type StartedService_StartNetworkQualityTestServer = grpc.ServerStreamingServer[NetworkQualityTestProgress]
func _StartedService_StartSTUNTest_Handler(srv interface{}, stream grpc.ServerStream) error {
m := new(STUNTestRequest)
if err := stream.RecvMsg(m); err != nil {
return err
}
return srv.(StartedServiceServer).StartSTUNTest(m, &grpc.GenericServerStream[STUNTestRequest, STUNTestProgress]{ServerStream: stream})
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type StartedService_StartSTUNTestServer = grpc.ServerStreamingServer[STUNTestProgress]
func _StartedService_SubscribeTailscaleStatus_Handler(srv interface{}, stream grpc.ServerStream) error {
m := new(emptypb.Empty)
if err := stream.RecvMsg(m); err != nil {
return err
}
return srv.(StartedServiceServer).SubscribeTailscaleStatus(m, &grpc.GenericServerStream[emptypb.Empty, TailscaleStatusUpdate]{ServerStream: stream})
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type StartedService_SubscribeTailscaleStatusServer = grpc.ServerStreamingServer[TailscaleStatusUpdate]
func _StartedService_StartTailscalePing_Handler(srv interface{}, stream grpc.ServerStream) error {
m := new(TailscalePingRequest)
if err := stream.RecvMsg(m); err != nil {
return err
}
return srv.(StartedServiceServer).StartTailscalePing(m, &grpc.GenericServerStream[TailscalePingRequest, TailscalePingResponse]{ServerStream: stream})
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type StartedService_StartTailscalePingServer = grpc.ServerStreamingServer[TailscalePingResponse]
// StartedService_ServiceDesc is the grpc.ServiceDesc for StartedService service. // StartedService_ServiceDesc is the grpc.ServiceDesc for StartedService service.
// It's only intended for direct use with grpc.RegisterService, // It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy) // and not to be introspected or modified (even as a copy)
@@ -1118,14 +863,6 @@ var StartedService_ServiceDesc = grpc.ServiceDesc{
MethodName: "SetSystemProxyEnabled", MethodName: "SetSystemProxyEnabled",
Handler: _StartedService_SetSystemProxyEnabled_Handler, Handler: _StartedService_SetSystemProxyEnabled_Handler,
}, },
{
MethodName: "TriggerDebugCrash",
Handler: _StartedService_TriggerDebugCrash_Handler,
},
{
MethodName: "TriggerOOMReport",
Handler: _StartedService_TriggerOOMReport_Handler,
},
{ {
MethodName: "CloseConnection", MethodName: "CloseConnection",
Handler: _StartedService_CloseConnection_Handler, Handler: _StartedService_CloseConnection_Handler,
@@ -1174,31 +911,6 @@ var StartedService_ServiceDesc = grpc.ServiceDesc{
Handler: _StartedService_SubscribeConnections_Handler, Handler: _StartedService_SubscribeConnections_Handler,
ServerStreams: true, ServerStreams: true,
}, },
{
StreamName: "SubscribeOutbounds",
Handler: _StartedService_SubscribeOutbounds_Handler,
ServerStreams: true,
},
{
StreamName: "StartNetworkQualityTest",
Handler: _StartedService_StartNetworkQualityTest_Handler,
ServerStreams: true,
},
{
StreamName: "StartSTUNTest",
Handler: _StartedService_StartSTUNTest_Handler,
ServerStreams: true,
},
{
StreamName: "SubscribeTailscaleStatus",
Handler: _StartedService_SubscribeTailscaleStatus_Handler,
ServerStreams: true,
},
{
StreamName: "StartTailscalePing",
Handler: _StartedService_StartTailscalePing_Handler,
ServerStreams: true,
},
}, },
Metadata: "daemon/started_service.proto", Metadata: "daemon/started_service.proto",
} }

View File

@@ -5,6 +5,7 @@ import (
"errors" "errors"
"net" "net"
"net/netip" "net/netip"
"strings"
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
@@ -13,6 +14,7 @@ import (
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/task" "github.com/sagernet/sing/common/task"
"github.com/sagernet/sing/contrab/freelru" "github.com/sagernet/sing/contrab/freelru"
"github.com/sagernet/sing/contrab/maphash" "github.com/sagernet/sing/contrab/maphash"
@@ -30,63 +32,59 @@ var (
var _ adapter.DNSClient = (*Client)(nil) var _ adapter.DNSClient = (*Client)(nil)
type Client struct { type Client struct {
ctx context.Context timeout time.Duration
timeout time.Duration disableCache bool
disableCache bool disableExpire bool
disableExpire bool independentCache bool
optimisticTimeout time.Duration clientSubnet netip.Prefix
cacheCapacity uint32 rdrc adapter.RDRCStore
clientSubnet netip.Prefix initRDRCFunc func() adapter.RDRCStore
rdrc adapter.RDRCStore logger logger.ContextLogger
initRDRCFunc func() adapter.RDRCStore cache freelru.Cache[dns.Question, *dns.Msg]
dnsCache adapter.DNSCacheStore cacheLock compatible.Map[dns.Question, chan struct{}]
initDNSCacheFunc func() adapter.DNSCacheStore transportCache freelru.Cache[transportCacheKey, *dns.Msg]
logger logger.ContextLogger transportCacheLock compatible.Map[dns.Question, chan struct{}]
cache freelru.Cache[dnsCacheKey, *dns.Msg]
cacheLock compatible.Map[dnsCacheKey, chan struct{}]
backgroundRefresh compatible.Map[dnsCacheKey, struct{}]
} }
type ClientOptions struct { type ClientOptions struct {
Context context.Context Timeout time.Duration
Timeout time.Duration DisableCache bool
DisableCache bool DisableExpire bool
DisableExpire bool IndependentCache bool
OptimisticTimeout time.Duration CacheCapacity uint32
CacheCapacity uint32 ClientSubnet netip.Prefix
ClientSubnet netip.Prefix RDRC func() adapter.RDRCStore
RDRC func() adapter.RDRCStore Logger logger.ContextLogger
DNSCache func() adapter.DNSCacheStore
Logger logger.ContextLogger
} }
func NewClient(options ClientOptions) *Client { func NewClient(options ClientOptions) *Client {
cacheCapacity := options.CacheCapacity
if cacheCapacity < 1024 {
cacheCapacity = 1024
}
client := &Client{ client := &Client{
ctx: options.Context, timeout: options.Timeout,
timeout: options.Timeout, disableCache: options.DisableCache,
disableCache: options.DisableCache, disableExpire: options.DisableExpire,
disableExpire: options.DisableExpire, independentCache: options.IndependentCache,
optimisticTimeout: options.OptimisticTimeout, clientSubnet: options.ClientSubnet,
cacheCapacity: cacheCapacity, initRDRCFunc: options.RDRC,
clientSubnet: options.ClientSubnet, logger: options.Logger,
initRDRCFunc: options.RDRC,
initDNSCacheFunc: options.DNSCache,
logger: options.Logger,
} }
if client.timeout == 0 { if client.timeout == 0 {
client.timeout = C.DNSTimeout client.timeout = C.DNSTimeout
} }
if !client.disableCache && client.initDNSCacheFunc == nil { cacheCapacity := options.CacheCapacity
client.initializeMemoryCache() if cacheCapacity < 1024 {
cacheCapacity = 1024
}
if !client.disableCache {
if !client.independentCache {
client.cache = common.Must1(freelru.NewSharded[dns.Question, *dns.Msg](cacheCapacity, maphash.NewHasher[dns.Question]().Hash32))
} else {
client.transportCache = common.Must1(freelru.NewSharded[transportCacheKey, *dns.Msg](cacheCapacity, maphash.NewHasher[transportCacheKey]().Hash32))
}
} }
return client return client
} }
type dnsCacheKey struct { type transportCacheKey struct {
dns.Question dns.Question
transportTag string transportTag string
} }
@@ -95,19 +93,6 @@ func (c *Client) Start() {
if c.initRDRCFunc != nil { if c.initRDRCFunc != nil {
c.rdrc = c.initRDRCFunc() c.rdrc = c.initRDRCFunc()
} }
if c.initDNSCacheFunc != nil {
c.dnsCache = c.initDNSCacheFunc()
}
if c.dnsCache == nil {
c.initializeMemoryCache()
}
}
func (c *Client) initializeMemoryCache() {
if c.disableCache || c.cache != nil {
return
}
c.cache = common.Must1(freelru.NewSharded[dnsCacheKey, *dns.Msg](c.cacheCapacity, maphash.NewHasher[dnsCacheKey]().Hash32))
} }
func extractNegativeTTL(response *dns.Msg) (uint32, bool) { func extractNegativeTTL(response *dns.Msg) (uint32, bool) {
@@ -124,38 +109,7 @@ func extractNegativeTTL(response *dns.Msg) (uint32, bool) {
return 0, false return 0, false
} }
func computeTimeToLive(response *dns.Msg) uint32 { func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, message *dns.Msg, options adapter.DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) (*dns.Msg, error) {
var timeToLive uint32
if len(response.Answer) == 0 {
if soaTTL, hasSOA := extractNegativeTTL(response); hasSOA {
return soaTTL
}
}
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
for _, record := range recordList {
if record.Header().Rrtype == dns.TypeOPT {
continue
}
if timeToLive == 0 || record.Header().Ttl > 0 && record.Header().Ttl < timeToLive {
timeToLive = record.Header().Ttl
}
}
}
return timeToLive
}
func normalizeTTL(response *dns.Msg, timeToLive uint32) {
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
for _, record := range recordList {
if record.Header().Rrtype == dns.TypeOPT {
continue
}
record.Header().Ttl = timeToLive
}
}
}
func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, message *dns.Msg, options adapter.DNSQueryOptions, responseChecker func(response *dns.Msg) bool) (*dns.Msg, error) {
if len(message.Question) == 0 { if len(message.Question) == 0 {
if c.logger != nil { if c.logger != nil {
c.logger.WarnContext(ctx, "bad question size: ", len(message.Question)) c.logger.WarnContext(ctx, "bad question size: ", len(message.Question))
@@ -169,7 +123,13 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
} }
return FixedResponseStatus(message, dns.RcodeSuccess), nil return FixedResponseStatus(message, dns.RcodeSuccess), nil
} }
message = c.prepareExchangeMessage(message, options) clientSubnet := options.ClientSubnet
if !clientSubnet.IsValid() {
clientSubnet = c.clientSubnet
}
if clientSubnet.IsValid() {
message = SetClientSubnet(message, clientSubnet)
}
isSimpleRequest := len(message.Question) == 1 && isSimpleRequest := len(message.Question) == 1 &&
len(message.Ns) == 0 && len(message.Ns) == 0 &&
@@ -181,32 +141,40 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
!options.ClientSubnet.IsValid() !options.ClientSubnet.IsValid()
disableCache := !isSimpleRequest || c.disableCache || options.DisableCache disableCache := !isSimpleRequest || c.disableCache || options.DisableCache
if !disableCache { if !disableCache {
cacheKey := dnsCacheKey{Question: question, transportTag: transport.Tag()} if c.cache != nil {
cond, loaded := c.cacheLock.LoadOrStore(cacheKey, make(chan struct{})) cond, loaded := c.cacheLock.LoadOrStore(question, make(chan struct{}))
if loaded { if loaded {
select { select {
case <-cond: case <-cond:
case <-ctx.Done(): case <-ctx.Done():
return nil, ctx.Err() return nil, ctx.Err()
}
} else {
defer func() {
c.cacheLock.Delete(question)
close(cond)
}()
}
} else if c.transportCache != nil {
cond, loaded := c.transportCacheLock.LoadOrStore(question, make(chan struct{}))
if loaded {
select {
case <-cond:
case <-ctx.Done():
return nil, ctx.Err()
}
} else {
defer func() {
c.transportCacheLock.Delete(question)
close(cond)
}()
} }
} else {
defer func() {
c.cacheLock.Delete(cacheKey)
close(cond)
}()
} }
response, ttl, isStale := c.loadResponse(question, transport) response, ttl := c.loadResponse(question, transport)
if response != nil { if response != nil {
if isStale && !options.DisableOptimisticCache { logCachedResponse(c.logger, ctx, response, ttl)
c.backgroundRefreshDNS(transport, question, message.Copy(), options, responseChecker) response.Id = message.Id
logOptimisticResponse(c.logger, ctx, response) return response, nil
response.Id = message.Id
return response, nil
} else if !isStale {
logCachedResponse(c.logger, ctx, response, ttl)
response.Id = message.Id
return response, nil
}
} }
} }
@@ -222,17 +190,62 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
return nil, ErrResponseRejectedCached return nil, ErrResponseRejectedCached
} }
} }
response, err := c.exchangeToTransport(ctx, transport, message) ctx, cancel := context.WithTimeout(ctx, c.timeout)
response, err := transport.Exchange(ctx, message)
cancel()
if err != nil { if err != nil {
return nil, err var rcodeError RcodeError
if errors.As(err, &rcodeError) {
response = FixedResponseStatus(message, int(rcodeError))
} else {
return nil, err
}
} }
/*if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA {
validResponse := response
loop:
for {
var (
addresses int
queryCNAME string
)
for _, rawRR := range validResponse.Answer {
switch rr := rawRR.(type) {
case *dns.A:
break loop
case *dns.AAAA:
break loop
case *dns.CNAME:
queryCNAME = rr.Target
}
}
if queryCNAME == "" {
break
}
exMessage := *message
exMessage.Question = []dns.Question{{
Name: queryCNAME,
Qtype: question.Qtype,
}}
validResponse, err = c.Exchange(ctx, transport, &exMessage, options, responseChecker)
if err != nil {
return nil, err
}
}
if validResponse != response {
response.Answer = append(response.Answer, validResponse.Answer...)
}
}*/
disableCache = disableCache || (response.Rcode != dns.RcodeSuccess && response.Rcode != dns.RcodeNameError) disableCache = disableCache || (response.Rcode != dns.RcodeSuccess && response.Rcode != dns.RcodeNameError)
if responseChecker != nil { if responseChecker != nil {
var rejected bool var rejected bool
// TODO: add accept_any rule and support to check response instead of addresses
if response.Rcode != dns.RcodeSuccess && response.Rcode != dns.RcodeNameError { if response.Rcode != dns.RcodeSuccess && response.Rcode != dns.RcodeNameError {
rejected = true rejected = true
} else if len(response.Answer) == 0 {
rejected = !responseChecker(nil)
} else { } else {
rejected = !responseChecker(response) rejected = !responseChecker(MessageToAddresses(response))
} }
if rejected { if rejected {
if !disableCache && c.rdrc != nil { if !disableCache && c.rdrc != nil {
@@ -242,7 +255,48 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
return response, ErrResponseRejected return response, ErrResponseRejected
} }
} }
timeToLive := applyResponseOptions(question, response, options) if question.Qtype == dns.TypeHTTPS {
if options.Strategy == C.DomainStrategyIPv4Only || options.Strategy == C.DomainStrategyIPv6Only {
for _, rr := range response.Answer {
https, isHTTPS := rr.(*dns.HTTPS)
if !isHTTPS {
continue
}
content := https.SVCB
content.Value = common.Filter(content.Value, func(it dns.SVCBKeyValue) bool {
if options.Strategy == C.DomainStrategyIPv4Only {
return it.Key() != dns.SVCB_IPV6HINT
} else {
return it.Key() != dns.SVCB_IPV4HINT
}
})
https.SVCB = content
}
}
}
var timeToLive uint32
if len(response.Answer) == 0 {
if soaTTL, hasSOA := extractNegativeTTL(response); hasSOA {
timeToLive = soaTTL
}
}
if timeToLive == 0 {
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
for _, record := range recordList {
if timeToLive == 0 || record.Header().Ttl > 0 && record.Header().Ttl < timeToLive {
timeToLive = record.Header().Ttl
}
}
}
}
if options.RewriteTTL != nil {
timeToLive = *options.RewriteTTL
}
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
for _, record := range recordList {
record.Header().Ttl = timeToLive
}
}
if !disableCache { if !disableCache {
c.storeCache(transport, question, response, timeToLive) c.storeCache(transport, question, response, timeToLive)
} }
@@ -261,7 +315,7 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
return response, nil return response, nil
} }
func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions, responseChecker func(response *dns.Msg) bool) ([]netip.Addr, error) { func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) {
domain = FqdnToDomain(domain) domain = FqdnToDomain(domain)
dnsName := dns.Fqdn(domain) dnsName := dns.Fqdn(domain)
var strategy C.DomainStrategy var strategy C.DomainStrategy
@@ -308,12 +362,8 @@ func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, dom
func (c *Client) ClearCache() { func (c *Client) ClearCache() {
if c.cache != nil { if c.cache != nil {
c.cache.Purge() c.cache.Purge()
} } else if c.transportCache != nil {
if c.dnsCache != nil { c.transportCache.Purge()
err := c.dnsCache.ClearDNSCache()
if err != nil && c.logger != nil {
c.logger.Warn("clear DNS cache: ", err)
}
} }
} }
@@ -329,44 +379,46 @@ func (c *Client) storeCache(transport adapter.DNSTransport, question dns.Questio
if timeToLive == 0 { if timeToLive == 0 {
return return
} }
if c.dnsCache != nil {
packed, err := message.Pack()
if err == nil {
expireAt := time.Now().Add(time.Second * time.Duration(timeToLive))
c.dnsCache.SaveDNSCacheAsync(transport.Tag(), question.Name, question.Qtype, packed, expireAt, c.logger)
}
return
}
if c.cache == nil {
return
}
key := dnsCacheKey{Question: question, transportTag: transport.Tag()}
if c.disableExpire { if c.disableExpire {
c.cache.Add(key, message.Copy()) if !c.independentCache {
c.cache.Add(question, message)
} else {
c.transportCache.Add(transportCacheKey{
Question: question,
transportTag: transport.Tag(),
}, message)
}
} else { } else {
c.cache.AddWithLifetime(key, message.Copy(), time.Second*time.Duration(timeToLive)) if !c.independentCache {
c.cache.AddWithLifetime(question, message, time.Second*time.Duration(timeToLive))
} else {
c.transportCache.AddWithLifetime(transportCacheKey{
Question: question,
transportTag: transport.Tag(),
}, message, time.Second*time.Duration(timeToLive))
}
} }
} }
func (c *Client) lookupToExchange(ctx context.Context, transport adapter.DNSTransport, name string, qType uint16, options adapter.DNSQueryOptions, responseChecker func(response *dns.Msg) bool) ([]netip.Addr, error) { func (c *Client) lookupToExchange(ctx context.Context, transport adapter.DNSTransport, name string, qType uint16, options adapter.DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) {
question := dns.Question{ question := dns.Question{
Name: name, Name: name,
Qtype: qType, Qtype: qType,
Qclass: dns.ClassINET, Qclass: dns.ClassINET,
} }
disableCache := c.disableCache || options.DisableCache
if !disableCache {
cachedAddresses, err := c.questionCache(question, transport)
if err != ErrNotCached {
return cachedAddresses, err
}
}
message := dns.Msg{ message := dns.Msg{
MsgHdr: dns.MsgHdr{ MsgHdr: dns.MsgHdr{
RecursionDesired: true, RecursionDesired: true,
}, },
Question: []dns.Question{question}, Question: []dns.Question{question},
} }
disableCache := c.disableCache || options.DisableCache
if !disableCache {
cachedAddresses, err := c.questionCache(ctx, transport, &message, options, responseChecker)
if err != ErrNotCached {
return cachedAddresses, err
}
}
response, err := c.Exchange(ctx, transport, &message, options, responseChecker) response, err := c.Exchange(ctx, transport, &message, options, responseChecker)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -377,181 +429,111 @@ func (c *Client) lookupToExchange(ctx context.Context, transport adapter.DNSTran
return MessageToAddresses(response), nil return MessageToAddresses(response), nil
} }
func (c *Client) questionCache(ctx context.Context, transport adapter.DNSTransport, message *dns.Msg, options adapter.DNSQueryOptions, responseChecker func(response *dns.Msg) bool) ([]netip.Addr, error) { func (c *Client) questionCache(question dns.Question, transport adapter.DNSTransport) ([]netip.Addr, error) {
question := message.Question[0] response, _ := c.loadResponse(question, transport)
response, _, isStale := c.loadResponse(question, transport)
if response == nil { if response == nil {
return nil, ErrNotCached return nil, ErrNotCached
} }
if isStale {
if options.DisableOptimisticCache {
return nil, ErrNotCached
}
c.backgroundRefreshDNS(transport, question, c.prepareExchangeMessage(message.Copy(), options), options, responseChecker)
logOptimisticResponse(c.logger, ctx, response)
}
if response.Rcode != dns.RcodeSuccess { if response.Rcode != dns.RcodeSuccess {
return nil, RcodeError(response.Rcode) return nil, RcodeError(response.Rcode)
} }
return MessageToAddresses(response), nil return MessageToAddresses(response), nil
} }
func (c *Client) loadResponse(question dns.Question, transport adapter.DNSTransport) (*dns.Msg, int, bool) { func (c *Client) loadResponse(question dns.Question, transport adapter.DNSTransport) (*dns.Msg, int) {
if c.dnsCache != nil { var (
return c.loadPersistentResponse(question, transport) response *dns.Msg
} loaded bool
if c.cache == nil { )
return nil, 0, false
}
key := dnsCacheKey{Question: question, transportTag: transport.Tag()}
if c.disableExpire { if c.disableExpire {
response, loaded := c.cache.Get(key) if !c.independentCache {
if !loaded { response, loaded = c.cache.Get(question)
return nil, 0, false } else {
} response, loaded = c.transportCache.Get(transportCacheKey{
return response.Copy(), 0, false Question: question,
} transportTag: transport.Tag(),
response, expireAt, loaded := c.cache.GetWithLifetimeNoExpire(key)
if !loaded {
return nil, 0, false
}
timeNow := time.Now()
if timeNow.After(expireAt) {
if c.optimisticTimeout > 0 && timeNow.Before(expireAt.Add(c.optimisticTimeout)) {
response = response.Copy()
normalizeTTL(response, 1)
return response, 0, true
}
c.cache.Remove(key)
return nil, 0, false
}
nowTTL := int(expireAt.Sub(timeNow).Seconds())
if nowTTL < 0 {
nowTTL = 0
}
response = response.Copy()
normalizeTTL(response, uint32(nowTTL))
return response, nowTTL, false
}
func (c *Client) loadPersistentResponse(question dns.Question, transport adapter.DNSTransport) (*dns.Msg, int, bool) {
rawMessage, expireAt, loaded := c.dnsCache.LoadDNSCache(transport.Tag(), question.Name, question.Qtype)
if !loaded {
return nil, 0, false
}
response := new(dns.Msg)
err := response.Unpack(rawMessage)
if err != nil {
return nil, 0, false
}
if c.disableExpire {
return response, 0, false
}
timeNow := time.Now()
if timeNow.After(expireAt) {
if c.optimisticTimeout > 0 && timeNow.Before(expireAt.Add(c.optimisticTimeout)) {
normalizeTTL(response, 1)
return response, 0, true
}
return nil, 0, false
}
nowTTL := int(expireAt.Sub(timeNow).Seconds())
if nowTTL < 0 {
nowTTL = 0
}
normalizeTTL(response, uint32(nowTTL))
return response, nowTTL, false
}
func applyResponseOptions(question dns.Question, response *dns.Msg, options adapter.DNSQueryOptions) uint32 {
if question.Qtype == dns.TypeHTTPS && (options.Strategy == C.DomainStrategyIPv4Only || options.Strategy == C.DomainStrategyIPv6Only) {
for _, rr := range response.Answer {
https, isHTTPS := rr.(*dns.HTTPS)
if !isHTTPS {
continue
}
content := https.SVCB
content.Value = common.Filter(content.Value, func(it dns.SVCBKeyValue) bool {
if options.Strategy == C.DomainStrategyIPv4Only {
return it.Key() != dns.SVCB_IPV6HINT
}
return it.Key() != dns.SVCB_IPV4HINT
}) })
https.SVCB = content
} }
} if !loaded {
timeToLive := computeTimeToLive(response) return nil, 0
if options.RewriteTTL != nil {
timeToLive = *options.RewriteTTL
}
normalizeTTL(response, timeToLive)
return timeToLive
}
func (c *Client) backgroundRefreshDNS(transport adapter.DNSTransport, question dns.Question, message *dns.Msg, options adapter.DNSQueryOptions, responseChecker func(response *dns.Msg) bool) {
key := dnsCacheKey{Question: question, transportTag: transport.Tag()}
_, loaded := c.backgroundRefresh.LoadOrStore(key, struct{}{})
if loaded {
return
}
go func() {
defer c.backgroundRefresh.Delete(key)
ctx := contextWithTransportTag(c.ctx, transport.Tag())
response, err := c.exchangeToTransport(ctx, transport, message)
if err != nil {
if c.logger != nil {
c.logger.Debug("optimistic refresh failed for ", FqdnToDomain(question.Name), ": ", err)
}
return
} }
if responseChecker != nil { return response.Copy(), 0
var rejected bool } else {
if response.Rcode != dns.RcodeSuccess && response.Rcode != dns.RcodeNameError { var expireAt time.Time
rejected = true if !c.independentCache {
response, expireAt, loaded = c.cache.GetWithLifetime(question)
} else {
response, expireAt, loaded = c.transportCache.GetWithLifetime(transportCacheKey{
Question: question,
transportTag: transport.Tag(),
})
}
if !loaded {
return nil, 0
}
timeNow := time.Now()
if timeNow.After(expireAt) {
if !c.independentCache {
c.cache.Remove(question)
} else { } else {
rejected = !responseChecker(response) c.transportCache.Remove(transportCacheKey{
Question: question,
transportTag: transport.Tag(),
})
} }
if rejected { return nil, 0
if c.rdrc != nil {
c.rdrc.SaveRDRCAsync(transport.Tag(), question.Name, question.Qtype, c.logger)
}
return
}
} else if response.Rcode != dns.RcodeSuccess && response.Rcode != dns.RcodeNameError {
return
} }
timeToLive := applyResponseOptions(question, response, options) var originTTL int
c.storeCache(transport, question, response, timeToLive) for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
}() for _, record := range recordList {
} if originTTL == 0 || record.Header().Ttl > 0 && int(record.Header().Ttl) < originTTL {
originTTL = int(record.Header().Ttl)
func (c *Client) prepareExchangeMessage(message *dns.Msg, options adapter.DNSQueryOptions) *dns.Msg { }
clientSubnet := options.ClientSubnet }
if !clientSubnet.IsValid() { }
clientSubnet = c.clientSubnet nowTTL := int(expireAt.Sub(timeNow).Seconds())
if nowTTL < 0 {
nowTTL = 0
}
response = response.Copy()
if originTTL > 0 {
duration := uint32(originTTL - nowTTL)
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
for _, record := range recordList {
record.Header().Ttl = record.Header().Ttl - duration
}
}
} else {
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
for _, record := range recordList {
record.Header().Ttl = uint32(nowTTL)
}
}
}
return response, nowTTL
} }
if clientSubnet.IsValid() {
message = SetClientSubnet(message, clientSubnet)
}
return message
}
func (c *Client) exchangeToTransport(ctx context.Context, transport adapter.DNSTransport, message *dns.Msg) (*dns.Msg, error) {
ctx, cancel := context.WithTimeout(ctx, c.timeout)
defer cancel()
response, err := transport.Exchange(ctx, message)
if err == nil {
return response, nil
}
var rcodeError RcodeError
if errors.As(err, &rcodeError) {
return FixedResponseStatus(message, int(rcodeError)), nil
}
return nil, err
} }
func MessageToAddresses(response *dns.Msg) []netip.Addr { func MessageToAddresses(response *dns.Msg) []netip.Addr {
return adapter.DNSResponseAddresses(response) if response == nil || response.Rcode != dns.RcodeSuccess {
return nil
}
addresses := make([]netip.Addr, 0, len(response.Answer))
for _, rawAnswer := range response.Answer {
switch answer := rawAnswer.(type) {
case *dns.A:
addresses = append(addresses, M.AddrFromIP(answer.A))
case *dns.AAAA:
addresses = append(addresses, M.AddrFromIP(answer.AAAA))
case *dns.HTTPS:
for _, value := range answer.SVCB.Value {
if value.Key() == dns.SVCB_IPV4HINT || value.Key() == dns.SVCB_IPV6HINT {
addresses = append(addresses, common.Map(strings.Split(value.String(), ","), M.ParseAddr)...)
}
}
}
}
return addresses
} }
func wrapError(err error) error { func wrapError(err error) error {

View File

@@ -22,19 +22,6 @@ func logCachedResponse(logger logger.ContextLogger, ctx context.Context, respons
} }
} }
func logOptimisticResponse(logger logger.ContextLogger, ctx context.Context, response *dns.Msg) {
if logger == nil || len(response.Question) == 0 {
return
}
domain := FqdnToDomain(response.Question[0].Name)
logger.DebugContext(ctx, "optimistic ", domain, " ", dns.RcodeToString[response.Rcode])
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
for _, record := range recordList {
logger.InfoContext(ctx, "optimistic ", dns.Type(record.Header().Rrtype).String(), " ", FormatQuestion(record.String()))
}
}
}
func logExchangedResponse(logger logger.ContextLogger, ctx context.Context, response *dns.Msg, ttl uint32) { func logExchangedResponse(logger logger.ContextLogger, ctx context.Context, response *dns.Msg, ttl uint32) {
if logger == nil || len(response.Question) == 0 { if logger == nil || len(response.Question) == 0 {
return return

View File

@@ -5,11 +5,10 @@ import (
) )
const ( const (
RcodeSuccess RcodeError = mDNS.RcodeSuccess RcodeSuccess RcodeError = mDNS.RcodeSuccess
RcodeServerFailure RcodeError = mDNS.RcodeServerFailure RcodeFormatError RcodeError = mDNS.RcodeFormatError
RcodeFormatError RcodeError = mDNS.RcodeFormatError RcodeNameError RcodeError = mDNS.RcodeNameError
RcodeNameError RcodeError = mDNS.RcodeNameError RcodeRefused RcodeError = mDNS.RcodeRefused
RcodeRefused RcodeError = mDNS.RcodeRefused
) )
type RcodeError int type RcodeError int

View File

@@ -1,111 +0,0 @@
package dns
import (
"context"
"net/netip"
"testing"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json/badoption"
mDNS "github.com/miekg/dns"
"github.com/stretchr/testify/require"
)
func TestReproLookupWithRulesUsesRequestStrategy(t *testing.T) {
t.Parallel()
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
var qTypes []uint16
router := newTestRouter(t, nil, &fakeDNSTransportManager{
defaultTransport: defaultTransport,
transports: map[string]adapter.DNSTransport{
"default": defaultTransport,
},
}, &fakeDNSClient{
exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) {
qTypes = append(qTypes, message.Question[0].Qtype)
if message.Question[0].Qtype == mDNS.TypeA {
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2.2.2.2")}, 60), nil
}
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2001:db8::1")}, 60), nil
},
})
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{
Strategy: C.DomainStrategyIPv4Only,
})
require.NoError(t, err)
require.Equal(t, []uint16{mDNS.TypeA}, qTypes)
require.Equal(t, []netip.Addr{netip.MustParseAddr("2.2.2.2")}, addresses)
}
func TestReproLogicalMatchResponseIPCIDR(t *testing.T) {
t.Parallel()
transportManager := &fakeDNSTransportManager{
defaultTransport: &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP},
transports: map[string]adapter.DNSTransport{
"upstream": &fakeDNSTransport{tag: "upstream", transportType: C.DNSTypeUDP},
"selected": &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP},
"default": &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP},
},
}
client := &fakeDNSClient{
exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) {
switch transport.Tag() {
case "upstream":
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("1.1.1.1")}, 60), nil
case "selected":
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60), nil
default:
return nil, E.New("unexpected transport")
}
},
}
rules := []option.DNSRule{
{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
RawDefaultDNSRule: option.RawDefaultDNSRule{
Domain: badoption.Listable[string]{"example.com"},
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeEvaluate,
RouteOptions: option.DNSRouteActionOptions{Server: "upstream"},
},
},
},
{
Type: C.RuleTypeLogical,
LogicalOptions: option.LogicalDNSRule{
RawLogicalDNSRule: option.RawLogicalDNSRule{
Mode: C.LogicalTypeOr,
Rules: []option.DNSRule{{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
RawDefaultDNSRule: option.RawDefaultDNSRule{
MatchResponse: true,
IPCIDR: badoption.Listable[string]{"1.1.1.0/24"},
},
},
}},
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeRoute,
RouteOptions: option.DNSRouteActionOptions{Server: "selected"},
},
},
},
}
router := newTestRouter(t, rules, transportManager, client)
response, err := router.Exchange(context.Background(), &mDNS.Msg{
Question: []mDNS.Question{fixedQuestion("example.com", mDNS.TypeA)},
}, adapter.DNSQueryOptions{})
require.NoError(t, err)
require.Equal(t, []netip.Addr{netip.MustParseAddr("8.8.8.8")}, MessageToAddresses(response))
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -7,6 +7,7 @@ import (
"strings" "strings"
"syscall" "syscall"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/dns/transport" "github.com/sagernet/sing-box/dns/transport"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
@@ -39,6 +40,13 @@ func (t *Transport) exchangeParallel(ctx context.Context, servers []M.Socksaddr,
results := make(chan queryResult) results := make(chan queryResult)
startRacer := func(ctx context.Context, fqdn string) { startRacer := func(ctx context.Context, fqdn string) {
response, err := t.tryOneName(ctx, servers, fqdn, message) response, err := t.tryOneName(ctx, servers, fqdn, message)
if err == nil {
if response.Rcode != mDNS.RcodeSuccess {
err = dns.RcodeError(response.Rcode)
} else if len(dns.MessageToAddresses(response)) == 0 {
err = dns.RcodeSuccess
}
}
select { select {
case results <- queryResult{response, err}: case results <- queryResult{response, err}:
case <-returned: case <-returned:

View File

@@ -3,18 +3,17 @@ package transport
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/base64"
"errors" "errors"
"io" "io"
"net"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strconv"
"sync" "sync"
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/common/httpclient"
"github.com/sagernet/sing-box/common/tls" "github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns" "github.com/sagernet/sing-box/dns"
@@ -45,20 +44,14 @@ type HTTPSTransport struct {
logger logger.ContextLogger logger logger.ContextLogger
dialer N.Dialer dialer N.Dialer
destination *url.URL destination *url.URL
method string headers http.Header
host string
queryHeaders http.Header
transportAccess sync.Mutex transportAccess sync.Mutex
transport *httpclient.Client transport *HTTPSTransportWrapper
transportResetAt time.Time transportResetAt time.Time
} }
func NewHTTPS(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteHTTPSDNSServerOptions) (adapter.DNSTransport, error) { func NewHTTPS(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteHTTPSDNSServerOptions) (adapter.DNSTransport, error) {
remoteOptions := option.RemoteDNSServerOptions{ transportDialer, err := dns.NewRemoteDialer(ctx, options.RemoteDNSServerOptions)
DNSServerAddressOptions: options.DNSServerAddressOptions,
}
remoteOptions.DialerOptions = options.DialerOptions
transportDialer, err := dns.NewRemoteDialer(ctx, remoteOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -69,21 +62,28 @@ func NewHTTPS(ctx context.Context, logger log.ContextLogger, tag string, options
return nil, err return nil, err
} }
if len(tlsConfig.NextProtos()) == 0 { if len(tlsConfig.NextProtos()) == 0 {
tlsConfig.SetNextProtos([]string{http2.NextProtoTLS}) tlsConfig.SetNextProtos([]string{http2.NextProtoTLS, "http/1.1"})
} else if !common.Contains(tlsConfig.NextProtos(), http2.NextProtoTLS) {
tlsConfig.SetNextProtos(append([]string{http2.NextProtoTLS}, tlsConfig.NextProtos()...))
} }
headers := options.Headers.Build() headers := options.Headers.Build()
serverAddr := options.DNSServerAddressOptions.Build() host := headers.Get("Host")
if serverAddr.Port == 0 { if host != "" {
serverAddr.Port = 443 headers.Del("Host")
} } else {
if !serverAddr.IsValid() { if tlsConfig.ServerName() != "" {
return nil, E.New("invalid server address: ", serverAddr) host = tlsConfig.ServerName()
} else {
host = options.Server
}
} }
destinationURL := url.URL{ destinationURL := url.URL{
Scheme: "https", Scheme: "https",
Host: doHURLHost(serverAddr, 443), Host: host,
}
if destinationURL.Host == "" {
destinationURL.Host = options.Server
}
if options.ServerPort != 0 && options.ServerPort != 443 {
destinationURL.Host = net.JoinHostPort(destinationURL.Host, strconv.Itoa(int(options.ServerPort)))
} }
path := options.Path path := options.Path
if path == "" { if path == "" {
@@ -93,67 +93,41 @@ func NewHTTPS(ctx context.Context, logger log.ContextLogger, tag string, options
if err != nil { if err != nil {
return nil, err return nil, err
} }
method := strings.ToUpper(options.Method) serverAddr := options.DNSServerAddressOptions.Build()
if method == "" { if serverAddr.Port == 0 {
method = http.MethodPost serverAddr.Port = 443
} }
switch method { if !serverAddr.IsValid() {
case http.MethodGet, http.MethodPost: return nil, E.New("invalid server address: ", serverAddr)
default:
return nil, E.New("unsupported HTTPS DNS method: ", options.Method)
} }
httpClientOptions := options.HTTPClientOptions return NewHTTPSRaw(
return NewHTTPRaw( dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeHTTPS, tag, options.RemoteDNSServerOptions),
dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeHTTPS, tag, remoteOptions),
logger, logger,
transportDialer, transportDialer,
&destinationURL, &destinationURL,
headers, headers,
serverAddr,
tlsConfig, tlsConfig,
httpClientOptions, ), nil
method,
)
} }
func NewHTTPRaw( func NewHTTPSRaw(
adapter dns.TransportAdapter, adapter dns.TransportAdapter,
logger logger.ContextLogger, logger log.ContextLogger,
dialer N.Dialer, dialer N.Dialer,
destination *url.URL, destination *url.URL,
headers http.Header, headers http.Header,
serverAddr M.Socksaddr,
tlsConfig tls.Config, tlsConfig tls.Config,
httpClientOptions option.HTTPClientOptions, ) *HTTPSTransport {
method string,
) (*HTTPSTransport, error) {
if destination.Scheme == "https" && tlsConfig == nil {
return nil, E.New("TLS transport unavailable")
}
queryHeaders := headers.Clone()
if queryHeaders == nil {
queryHeaders = make(http.Header)
}
host := queryHeaders.Get("Host")
queryHeaders.Del("Host")
queryHeaders.Set("Accept", MimeType)
if method == http.MethodPost {
queryHeaders.Set("Content-Type", MimeType)
}
httpClientOptions.Tag = ""
httpClientOptions.Headers = nil
currentTransport, err := httpclient.NewClientWithDialer(dialer, tlsConfig, "", httpClientOptions)
if err != nil {
return nil, err
}
return &HTTPSTransport{ return &HTTPSTransport{
TransportAdapter: adapter, TransportAdapter: adapter,
logger: logger, logger: logger,
dialer: dialer, dialer: dialer,
destination: destination, destination: destination,
method: method, headers: headers,
host: host, transport: NewHTTPSTransportWrapper(tls.NewDialer(dialer, tlsConfig), serverAddr),
queryHeaders: queryHeaders, }
transport: currentTransport,
}, nil
} }
func (t *HTTPSTransport) Start(stage adapter.StartStage) error { func (t *HTTPSTransport) Start(stage adapter.StartStage) error {
@@ -207,25 +181,14 @@ func (t *HTTPSTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS
requestBuffer.Release() requestBuffer.Release()
return nil, err return nil, err
} }
requestURL := *t.destination request, err := http.NewRequestWithContext(ctx, http.MethodPost, t.destination.String(), bytes.NewReader(rawMessage))
var request *http.Request
switch t.method {
case http.MethodGet:
query := requestURL.Query()
query.Set("dns", base64.RawURLEncoding.EncodeToString(rawMessage))
requestURL.RawQuery = query.Encode()
request, err = http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
default:
request, err = http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), bytes.NewReader(rawMessage))
}
if err != nil { if err != nil {
requestBuffer.Release() requestBuffer.Release()
return nil, err return nil, err
} }
request.Header = t.queryHeaders.Clone() request.Header = t.headers.Clone()
if t.host != "" { request.Header.Set("Content-Type", MimeType)
request.Host = t.host request.Header.Set("Accept", MimeType)
}
t.transportAccess.Lock() t.transportAccess.Lock()
currentTransport := t.transport currentTransport := t.transport
t.transportAccess.Unlock() t.transportAccess.Unlock()
@@ -259,13 +222,3 @@ func (t *HTTPSTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS
} }
return &responseMessage, nil return &responseMessage, nil
} }
func doHURLHost(serverAddr M.Socksaddr, defaultPort uint16) string {
if serverAddr.Port != defaultPort {
return serverAddr.String()
}
if serverAddr.IsIPv6() {
return "[" + serverAddr.AddrString() + "]"
}
return serverAddr.AddrString()
}

View File

@@ -0,0 +1,80 @@
package transport
import (
"context"
"errors"
"net"
"net/http"
"sync/atomic"
"github.com/sagernet/sing-box/common/tls"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"golang.org/x/net/http2"
)
var errFallback = E.New("fallback to HTTP/1.1")
type HTTPSTransportWrapper struct {
http2Transport *http2.Transport
httpTransport *http.Transport
fallback *atomic.Bool
}
func NewHTTPSTransportWrapper(dialer tls.Dialer, serverAddr M.Socksaddr) *HTTPSTransportWrapper {
var fallback atomic.Bool
return &HTTPSTransportWrapper{
http2Transport: &http2.Transport{
DialTLSContext: func(ctx context.Context, _, _ string, _ *tls.STDConfig) (net.Conn, error) {
tlsConn, err := dialer.DialTLSContext(ctx, serverAddr)
if err != nil {
return nil, err
}
state := tlsConn.ConnectionState()
if state.NegotiatedProtocol == http2.NextProtoTLS {
return tlsConn, nil
}
tlsConn.Close()
fallback.Store(true)
return nil, errFallback
},
},
httpTransport: &http.Transport{
DialTLSContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
return dialer.DialTLSContext(ctx, serverAddr)
},
},
fallback: &fallback,
}
}
func (h *HTTPSTransportWrapper) RoundTrip(request *http.Request) (*http.Response, error) {
if h.fallback.Load() {
return h.httpTransport.RoundTrip(request)
} else {
response, err := h.http2Transport.RoundTrip(request)
if err != nil {
if errors.Is(err, errFallback) {
return h.httpTransport.RoundTrip(request)
}
return nil, err
}
return response, nil
}
}
func (h *HTTPSTransportWrapper) CloseIdleConnections() {
h.http2Transport.CloseIdleConnections()
h.httpTransport.CloseIdleConnections()
}
func (h *HTTPSTransportWrapper) Clone() *HTTPSTransportWrapper {
return &HTTPSTransportWrapper{
httpTransport: h.httpTransport,
http2Transport: &http2.Transport{
DialTLSContext: h.http2Transport.DialTLSContext,
},
fallback: h.fallback,
}
}

View File

@@ -4,6 +4,8 @@ package local
import ( import (
"context" "context"
"errors"
"net"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
@@ -12,6 +14,7 @@ import (
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
@@ -32,8 +35,10 @@ type Transport struct {
logger logger.ContextLogger logger logger.ContextLogger
hosts *hosts.File hosts *hosts.File
dialer N.Dialer dialer N.Dialer
preferGo bool
fallback bool fallback bool
dhcpTransport dhcpTransport dhcpTransport dhcpTransport
resolver net.Resolver
} }
type dhcpTransport interface { type dhcpTransport interface {
@@ -47,12 +52,14 @@ func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, opt
if err != nil { if err != nil {
return nil, err return nil, err
} }
transportAdapter := dns.NewTransportAdapterWithLocalOptions(C.DNSTypeLocal, tag, options)
return &Transport{ return &Transport{
TransportAdapter: dns.NewTransportAdapterWithLocalOptions(C.DNSTypeLocal, tag, options), TransportAdapter: transportAdapter,
ctx: ctx, ctx: ctx,
logger: logger, logger: logger,
hosts: hosts.NewFile(hosts.DefaultPath), hosts: hosts.NewFile(hosts.DefaultPath),
dialer: transportDialer, dialer: transportDialer,
preferGo: options.PreferGo,
}, nil }, nil
} }
@@ -90,3 +97,44 @@ func (t *Transport) Reset() {
t.dhcpTransport.Reset() t.dhcpTransport.Reset()
} }
} }
func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
question := message.Question[0]
if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA {
addresses := t.hosts.Lookup(dns.FqdnToDomain(question.Name))
if len(addresses) > 0 {
return dns.FixedResponse(message.Id, question, addresses, C.DefaultDNSTTL), nil
}
}
if !t.fallback {
return t.exchange(ctx, message, question.Name)
}
if t.dhcpTransport != nil {
dhcpTransports := t.dhcpTransport.Fetch()
if len(dhcpTransports) > 0 {
return t.dhcpTransport.Exchange0(ctx, message, dhcpTransports)
}
}
if t.preferGo {
// Assuming the user knows what they are doing, we still execute the query which will fail.
return t.exchange(ctx, message, question.Name)
}
if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA {
var network string
if question.Qtype == mDNS.TypeA {
network = "ip4"
} else {
network = "ip6"
}
addresses, err := t.resolver.LookupNetIP(ctx, network, question.Name)
if err != nil {
var dnsError *net.DNSError
if errors.As(err, &dnsError) && dnsError.IsNotFound {
return nil, dns.RcodeRefused
}
return nil, err
}
return dns.FixedResponse(message.Id, question, addresses, C.DefaultDNSTTL), nil
}
return nil, E.New("only A and AAAA queries are supported on Apple platforms when using TUN and DHCP unavailable.")
}

View File

@@ -1,249 +0,0 @@
//go:build darwin
package local
/*
#include <stdlib.h>
#include <dns.h>
#include <resolv.h>
static void *cgo_dns_open_super() {
return (void *)dns_open(NULL);
}
static void cgo_dns_close(void *opaque) {
if (opaque != NULL) dns_free((dns_handle_t)opaque);
}
static int cgo_dns_search(void *opaque, const char *name, int class, int type,
unsigned char *answer, int anslen) {
dns_handle_t handle = (dns_handle_t)opaque;
struct sockaddr_storage from;
uint32_t fromlen = sizeof(from);
return dns_search(handle, name, class, type, (char *)answer, anslen, (struct sockaddr *)&from, &fromlen);
}
static void *cgo_res_init() {
res_state state = calloc(1, sizeof(struct __res_state));
if (state == NULL) return NULL;
if (res_ninit(state) != 0) {
free(state);
return NULL;
}
return state;
}
static void cgo_res_destroy(void *opaque) {
res_state state = (res_state)opaque;
res_ndestroy(state);
free(state);
}
static int cgo_res_nsearch(void *opaque, const char *dname, int class, int type,
unsigned char *answer, int anslen,
int timeout_seconds,
int *out_h_errno) {
res_state state = (res_state)opaque;
state->retrans = timeout_seconds;
state->retry = 1;
int n = res_nsearch(state, dname, class, type, answer, anslen);
if (n < 0) {
*out_h_errno = state->res_h_errno;
}
return n;
}
*/
import "C"
import (
"context"
"errors"
"time"
"unsafe"
boxC "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
E "github.com/sagernet/sing/common/exceptions"
mDNS "github.com/miekg/dns"
)
const (
darwinResolverHostNotFound = 1
darwinResolverTryAgain = 2
darwinResolverNoRecovery = 3
darwinResolverNoData = 4
darwinResolverMaxPacketSize = 65535
)
var errDarwinNeedLargerBuffer = errors.New("darwin resolver response truncated")
func darwinLookupSystemDNS(name string, class, qtype, timeoutSeconds int) (*mDNS.Msg, error) {
response, err := darwinSearchWithSystemRouting(name, class, qtype)
if err == nil {
return response, nil
}
fallbackResponse, fallbackErr := darwinSearchWithResolv(name, class, qtype, timeoutSeconds)
if fallbackErr == nil || fallbackResponse != nil {
return fallbackResponse, fallbackErr
}
return nil, E.Errors(
E.Cause(err, "dns_search"),
E.Cause(fallbackErr, "res_nsearch"),
)
}
func darwinSearchWithSystemRouting(name string, class, qtype int) (*mDNS.Msg, error) {
handle := C.cgo_dns_open_super()
if handle == nil {
return nil, E.New("dns_open failed")
}
defer C.cgo_dns_close(handle)
cName := C.CString(name)
defer C.free(unsafe.Pointer(cName))
bufSize := 1232
for {
answer := make([]byte, bufSize)
n := C.cgo_dns_search(handle, cName, C.int(class), C.int(qtype),
(*C.uchar)(unsafe.Pointer(&answer[0])), C.int(len(answer)))
if n <= 0 {
return nil, E.New("dns_search failed for ", name)
}
if int(n) > bufSize {
bufSize = int(n)
continue
}
return unpackDarwinResolverMessage(answer[:int(n)], "dns_search")
}
}
func darwinSearchWithResolv(name string, class, qtype int, timeoutSeconds int) (*mDNS.Msg, error) {
state := C.cgo_res_init()
if state == nil {
return nil, E.New("res_ninit failed")
}
defer C.cgo_res_destroy(state)
cName := C.CString(name)
defer C.free(unsafe.Pointer(cName))
bufSize := 1232
for {
answer := make([]byte, bufSize)
var hErrno C.int
n := C.cgo_res_nsearch(state, cName, C.int(class), C.int(qtype),
(*C.uchar)(unsafe.Pointer(&answer[0])), C.int(len(answer)),
C.int(timeoutSeconds),
&hErrno)
if n >= 0 {
if int(n) > bufSize {
bufSize = int(n)
continue
}
return unpackDarwinResolverMessage(answer[:int(n)], "res_nsearch")
}
response, err := handleDarwinResolvFailure(name, answer, int(hErrno))
if err == nil {
return response, nil
}
if errors.Is(err, errDarwinNeedLargerBuffer) && bufSize < darwinResolverMaxPacketSize {
bufSize *= 2
if bufSize > darwinResolverMaxPacketSize {
bufSize = darwinResolverMaxPacketSize
}
continue
}
return nil, err
}
}
func unpackDarwinResolverMessage(packet []byte, source string) (*mDNS.Msg, error) {
var response mDNS.Msg
err := response.Unpack(packet)
if err != nil {
return nil, E.Cause(err, "unpack ", source, " response")
}
return &response, nil
}
func handleDarwinResolvFailure(name string, answer []byte, hErrno int) (*mDNS.Msg, error) {
response, err := unpackDarwinResolverMessage(answer, "res_nsearch failure")
if err == nil && response.Response {
if response.Truncated && len(answer) < darwinResolverMaxPacketSize {
return nil, errDarwinNeedLargerBuffer
}
return response, nil
}
return nil, darwinResolverHErrno(name, hErrno)
}
func darwinResolverHErrno(name string, hErrno int) error {
switch hErrno {
case darwinResolverHostNotFound:
return dns.RcodeNameError
case darwinResolverTryAgain:
return dns.RcodeServerFailure
case darwinResolverNoRecovery:
return dns.RcodeServerFailure
case darwinResolverNoData:
return dns.RcodeSuccess
default:
return E.New("res_nsearch: unknown error ", hErrno, " for ", name)
}
}
func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
question := message.Question[0]
if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA {
addresses := t.hosts.Lookup(dns.FqdnToDomain(question.Name))
if len(addresses) > 0 {
return dns.FixedResponse(message.Id, question, addresses, boxC.DefaultDNSTTL), nil
}
}
if t.fallback && t.dhcpTransport != nil {
dhcpServers := t.dhcpTransport.Fetch()
if len(dhcpServers) > 0 {
return t.dhcpTransport.Exchange0(ctx, message, dhcpServers)
}
}
name := question.Name
timeoutSeconds := int(boxC.DNSTimeout / time.Second)
if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
remaining := time.Until(deadline)
if remaining <= 0 {
return nil, context.DeadlineExceeded
}
seconds := int(remaining.Seconds())
if seconds < 1 {
seconds = 1
}
timeoutSeconds = seconds
}
type resolvResult struct {
response *mDNS.Msg
err error
}
resultCh := make(chan resolvResult, 1)
go func() {
response, err := darwinLookupSystemDNS(name, int(question.Qclass), int(question.Qtype), timeoutSeconds)
resultCh <- resolvResult{response, err}
}()
var result resolvResult
select {
case <-ctx.Done():
return nil, ctx.Err()
case result = <-resultCh:
}
if result.err != nil {
var rcodeError dns.RcodeError
if errors.As(result.err, &rcodeError) {
return dns.FixedResponseStatus(message, int(rcodeError)), nil
}
return nil, result.err
}
result.response.Id = message.Id
return result.response, nil
}

View File

@@ -1,5 +1,3 @@
//go:build !darwin
package local package local
import ( import (
@@ -9,6 +7,7 @@ import (
"syscall" "syscall"
"time" "time"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/dns/transport" "github.com/sagernet/sing-box/dns/transport"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
@@ -50,6 +49,13 @@ func (t *Transport) exchangeParallel(ctx context.Context, systemConfig *dnsConfi
results := make(chan queryResult) results := make(chan queryResult)
startRacer := func(ctx context.Context, fqdn string) { startRacer := func(ctx context.Context, fqdn string) {
response, err := t.tryOneName(ctx, systemConfig, fqdn, message) response, err := t.tryOneName(ctx, systemConfig, fqdn, message)
if err == nil {
if response.Rcode != mDNS.RcodeSuccess {
err = dns.RcodeError(response.Rcode)
} else if len(dns.MessageToAddresses(response)) == 0 {
err = E.New(fqdn, ": empty result")
}
}
select { select {
case results <- queryResult{response, err}: case results <- queryResult{response, err}:
case <-returned: case <-returned:

View File

@@ -9,7 +9,6 @@ import (
"net/url" "net/url"
"strconv" "strconv"
"sync" "sync"
"time"
"github.com/sagernet/quic-go" "github.com/sagernet/quic-go"
"github.com/sagernet/quic-go/http3" "github.com/sagernet/quic-go/http3"
@@ -41,23 +40,18 @@ func RegisterHTTP3Transport(registry *dns.TransportRegistry) {
type HTTP3Transport struct { type HTTP3Transport struct {
dns.TransportAdapter dns.TransportAdapter
logger logger.ContextLogger logger logger.ContextLogger
dialer N.Dialer dialer N.Dialer
destination *url.URL destination *url.URL
headers http.Header headers http.Header
handshakeTimeout time.Duration serverAddr M.Socksaddr
serverAddr M.Socksaddr tlsConfig *tls.STDConfig
tlsConfig *tls.STDConfig transportAccess sync.Mutex
transportAccess sync.Mutex transport *http3.Transport
transport *http3.Transport
} }
func NewHTTP3(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteHTTPSDNSServerOptions) (adapter.DNSTransport, error) { func NewHTTP3(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteHTTPSDNSServerOptions) (adapter.DNSTransport, error) {
remoteOptions := option.RemoteDNSServerOptions{ transportDialer, err := dns.NewRemoteDialer(ctx, options.RemoteDNSServerOptions)
DNSServerAddressOptions: options.DNSServerAddressOptions,
}
remoteOptions.DialerOptions = options.DialerOptions
transportDialer, err := dns.NewRemoteDialer(ctx, remoteOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -67,7 +61,6 @@ func NewHTTP3(ctx context.Context, logger log.ContextLogger, tag string, options
if err != nil { if err != nil {
return nil, err return nil, err
} }
handshakeTimeout := tlsConfig.HandshakeTimeout()
stdConfig, err := tlsConfig.STDConfig() stdConfig, err := tlsConfig.STDConfig()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -109,12 +102,11 @@ func NewHTTP3(ctx context.Context, logger log.ContextLogger, tag string, options
return nil, E.New("invalid server address: ", serverAddr) return nil, E.New("invalid server address: ", serverAddr)
} }
t := &HTTP3Transport{ t := &HTTP3Transport{
TransportAdapter: dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeHTTP3, tag, remoteOptions), TransportAdapter: dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeHTTP3, tag, options.RemoteDNSServerOptions),
logger: logger, logger: logger,
dialer: transportDialer, dialer: transportDialer,
destination: &destinationURL, destination: &destinationURL,
headers: headers, headers: headers,
handshakeTimeout: handshakeTimeout,
serverAddr: serverAddr, serverAddr: serverAddr,
tlsConfig: stdConfig, tlsConfig: stdConfig,
} }
@@ -123,17 +115,8 @@ func NewHTTP3(ctx context.Context, logger log.ContextLogger, tag string, options
} }
func (t *HTTP3Transport) newTransport() *http3.Transport { func (t *HTTP3Transport) newTransport() *http3.Transport {
quicConfig := &quic.Config{}
if t.handshakeTimeout > 0 {
quicConfig.HandshakeIdleTimeout = t.handshakeTimeout
}
return &http3.Transport{ return &http3.Transport{
QUICConfig: quicConfig,
Dial: func(ctx context.Context, addr string, tlsCfg *tls.STDConfig, cfg *quic.Config) (*quic.Conn, error) { Dial: func(ctx context.Context, addr string, tlsCfg *tls.STDConfig, cfg *quic.Config) (*quic.Conn, error) {
if t.handshakeTimeout > 0 && cfg.HandshakeIdleTimeout == 0 {
cfg = cfg.Clone()
cfg.HandshakeIdleTimeout = t.handshakeTimeout
}
conn, dialErr := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr) conn, dialErr := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr)
if dialErr != nil { if dialErr != nil {
return nil, dialErr return nil, dialErr

View File

@@ -1,13 +1,21 @@
package dns package dns
import ( import (
"net/netip"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
) )
var _ adapter.LegacyDNSTransport = (*TransportAdapter)(nil)
type TransportAdapter struct { type TransportAdapter struct {
transportType string transportType string
transportTag string transportTag string
dependencies []string dependencies []string
strategy C.DomainStrategy
clientSubnet netip.Prefix
} }
func NewTransportAdapter(transportType string, transportTag string, dependencies []string) TransportAdapter { func NewTransportAdapter(transportType string, transportTag string, dependencies []string) TransportAdapter {
@@ -27,6 +35,8 @@ func NewTransportAdapterWithLocalOptions(transportType string, transportTag stri
transportType: transportType, transportType: transportType,
transportTag: transportTag, transportTag: transportTag,
dependencies: dependencies, dependencies: dependencies,
strategy: C.DomainStrategy(localOptions.LegacyStrategy),
clientSubnet: localOptions.LegacyClientSubnet,
} }
} }
@@ -35,10 +45,15 @@ func NewTransportAdapterWithRemoteOptions(transportType string, transportTag str
if remoteOptions.DomainResolver != nil && remoteOptions.DomainResolver.Server != "" { if remoteOptions.DomainResolver != nil && remoteOptions.DomainResolver.Server != "" {
dependencies = append(dependencies, remoteOptions.DomainResolver.Server) dependencies = append(dependencies, remoteOptions.DomainResolver.Server)
} }
if remoteOptions.LegacyAddressResolver != "" {
dependencies = append(dependencies, remoteOptions.LegacyAddressResolver)
}
return TransportAdapter{ return TransportAdapter{
transportType: transportType, transportType: transportType,
transportTag: transportTag, transportTag: transportTag,
dependencies: dependencies, dependencies: dependencies,
strategy: C.DomainStrategy(remoteOptions.LegacyStrategy),
clientSubnet: remoteOptions.LegacyClientSubnet,
} }
} }
@@ -53,3 +68,11 @@ func (a *TransportAdapter) Tag() string {
func (a *TransportAdapter) Dependencies() []string { func (a *TransportAdapter) Dependencies() []string {
return a.dependencies return a.dependencies
} }
func (a *TransportAdapter) LegacyStrategy() C.DomainStrategy {
return a.strategy
}
func (a *TransportAdapter) LegacyClientSubnet() netip.Prefix {
return a.clientSubnet
}

View File

@@ -2,25 +2,104 @@ package dns
import ( import (
"context" "context"
"net"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/common/dialer"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
) )
func NewLocalDialer(ctx context.Context, options option.LocalDNSServerOptions) (N.Dialer, error) { func NewLocalDialer(ctx context.Context, options option.LocalDNSServerOptions) (N.Dialer, error) {
return dialer.NewWithOptions(dialer.Options{ if options.LegacyDefaultDialer {
Context: ctx, return dialer.NewDefaultOutbound(ctx), nil
Options: options.DialerOptions, } else {
DirectResolver: true, return dialer.NewWithOptions(dialer.Options{
}) Context: ctx,
Options: options.DialerOptions,
DirectResolver: true,
LegacyDNSDialer: options.Legacy,
})
}
} }
func NewRemoteDialer(ctx context.Context, options option.RemoteDNSServerOptions) (N.Dialer, error) { func NewRemoteDialer(ctx context.Context, options option.RemoteDNSServerOptions) (N.Dialer, error) {
return dialer.NewWithOptions(dialer.Options{ if options.LegacyDefaultDialer {
Context: ctx, transportDialer := dialer.NewDefaultOutbound(ctx)
Options: options.DialerOptions, if options.LegacyAddressResolver != "" {
RemoteIsDomain: options.ServerIsDomain(), transport := service.FromContext[adapter.DNSTransportManager](ctx)
DirectResolver: true, resolverTransport, loaded := transport.Transport(options.LegacyAddressResolver)
}) if !loaded {
return nil, E.New("address resolver not found: ", options.LegacyAddressResolver)
}
transportDialer = newTransportDialer(transportDialer, service.FromContext[adapter.DNSRouter](ctx), resolverTransport, C.DomainStrategy(options.LegacyAddressStrategy), time.Duration(options.LegacyAddressFallbackDelay))
} else if options.ServerIsDomain() {
return nil, E.New("missing address resolver for server: ", options.Server)
}
return transportDialer, nil
} else {
return dialer.NewWithOptions(dialer.Options{
Context: ctx,
Options: options.DialerOptions,
RemoteIsDomain: options.ServerIsDomain(),
DirectResolver: true,
LegacyDNSDialer: options.Legacy,
})
}
}
type legacyTransportDialer struct {
dialer N.Dialer
dnsRouter adapter.DNSRouter
transport adapter.DNSTransport
strategy C.DomainStrategy
fallbackDelay time.Duration
}
func newTransportDialer(dialer N.Dialer, dnsRouter adapter.DNSRouter, transport adapter.DNSTransport, strategy C.DomainStrategy, fallbackDelay time.Duration) *legacyTransportDialer {
return &legacyTransportDialer{
dialer,
dnsRouter,
transport,
strategy,
fallbackDelay,
}
}
func (d *legacyTransportDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
if destination.IsIP() {
return d.dialer.DialContext(ctx, network, destination)
}
addresses, err := d.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{
Transport: d.transport,
Strategy: d.strategy,
})
if err != nil {
return nil, err
}
return N.DialParallel(ctx, d.dialer, network, destination, addresses, d.strategy == C.DomainStrategyPreferIPv6, d.fallbackDelay)
}
func (d *legacyTransportDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
if destination.IsIP() {
return d.dialer.ListenPacket(ctx, destination)
}
addresses, err := d.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{
Transport: d.transport,
Strategy: d.strategy,
})
if err != nil {
return nil, err
}
conn, _, err := N.ListenSerial(ctx, d.dialer, destination, addresses)
return conn, err
}
func (d *legacyTransportDialer) Upstream() any {
return d.dialer
} }

View File

@@ -2,132 +2,10 @@
icon: material/alert-decagram icon: material/alert-decagram
--- ---
#### 1.14.0-alpha.11 #### 1.14.0-alpha.3
* Add optimistic DNS cache **1**
* Update NaiveProxy to 147.0.7727.49
* Fixes and improvements
**1**:
Optimistic DNS cache returns an expired cached response immediately while
refreshing it in the background, reducing tail latency for repeated
queries. Enabled via [`optimistic`](/configuration/dns/#optimistic)
in DNS options, and can be persisted across restarts with the new
[`store_dns`](/configuration/experimental/cache-file/#store_dns) cache
file option. A per-query
[`disable_optimistic_cache`](/configuration/dns/rule_action/#disable_optimistic_cache)
field is also available on DNS rule actions and the `resolve` route rule
action.
This deprecates the `independent_cache` DNS option (the DNS cache now
always keys by transport) and the `store_rdrc` cache file option
(replaced by `store_dns`); both will be removed in sing-box 1.16.0.
See [Migration](/migration/#migrate-independent-dns-cache).
#### 1.14.0-alpha.10
* Add `evaluate` DNS rule action and Response Match Fields **1**
* `ip_version` and `query_type` now also take effect on internal DNS lookups **2**
* Add `package_name_regex` route, DNS and headless rule item **3**
* Add cloudflared inbound **4**
* Fixes and improvements
**1**:
Response Match Fields
([`response_rcode`](/configuration/dns/rule/#response_rcode),
[`response_answer`](/configuration/dns/rule/#response_answer),
[`response_ns`](/configuration/dns/rule/#response_ns),
and [`response_extra`](/configuration/dns/rule/#response_extra))
match the evaluated DNS response. They are gated by the new
[`match_response`](/configuration/dns/rule/#match_response) field and
populated by a preceding
[`evaluate`](/configuration/dns/rule_action/#evaluate) DNS rule action;
the evaluated response can also be returned directly by a
[`respond`](/configuration/dns/rule_action/#respond) action.
This deprecates the Legacy Address Filter Fields (`ip_cidr`,
`ip_is_private` without `match_response`) in DNS rules, the Legacy
`strategy` DNS rule action option, and the Legacy
`rule_set_ip_cidr_accept_empty` DNS rule item; all three will be removed
in sing-box 1.16.0.
See [Migration](/migration/#migrate-address-filter-fields-to-response-matching).
**2**:
`ip_version` and `query_type` in DNS rules, together with `query_type` in
referenced rule-sets, now take effect on every DNS rule evaluation,
including matches from internal domain resolutions that do not target a
specific DNS server (for example a `resolve` route rule action without
`server` set). In earlier versions they were silently ignored in that
path. Combining these fields with any of the legacy DNS fields deprecated
in **1** in the same DNS configuration is no longer supported and is
rejected at startup.
See [Migration](/migration/#ip_version-and-query_type-behavior-changes-in-dns-rules).
**3**:
See [Route Rule](/configuration/route/rule/#package_name_regex),
[DNS Rule](/configuration/dns/rule/#package_name_regex) and
[Headless Rule](/configuration/rule-set/headless-rule/#package_name_regex).
**4**:
See [Cloudflared](/configuration/inbound/cloudflared/).
#### 1.13.7
* Fixes and improvement
#### 1.13.6
* Fixes and improvements * Fixes and improvements
#### 1.14.0-alpha.8
* Add BBR profile and hop interval randomization for Hysteria2 **1**
* Fixes and improvements
**1**:
See [Hysteria2 Inbound](/configuration/inbound/hysteria2/#bbr_profile) and [Hysteria2 Outbound](/configuration/outbound/hysteria2/#bbr_profile).
#### 1.14.0-alpha.8
* Fixes and improvements
#### 1.13.5
* Fixes and improvements
#### 1.14.0-alpha.7
* Fixes and improvements
#### 1.13.4
* Fixes and improvements
#### 1.14.0-alpha.4
* Refactor ACME support to certificate provider system **1**
* Add Cloudflare Origin CA certificate provider **2**
* Add Tailscale certificate provider **3**
* Fixes and improvements
**1**:
See [Certificate Provider](/configuration/shared/certificate-provider/) and [Migration](/migration/#migrate-inline-acme-to-certificate-provider).
**2**:
See [Cloudflare Origin CA](/configuration/shared/certificate-provider/cloudflare-origin-ca).
**3**:
See [Tailscale](/configuration/shared/certificate-provider/tailscale).
#### 1.13.3 #### 1.13.3
* Add OpenWrt and Alpine APK packages to release **1** * Add OpenWrt and Alpine APK packages to release **1**
@@ -826,7 +704,7 @@ DNS servers are refactored for better performance and scalability.
See [DNS server](/configuration/dns/server/). See [DNS server](/configuration/dns/server/).
For migration, see [Migrate to new DNS server formats](/migration/#migrate-to-new-dns-server-formats). For migration, see [Migrate to new DNS server formats](/migration/#migrate-to-new-dns-servers).
Compatibility for old formats will be removed in sing-box 1.14.0. Compatibility for old formats will be removed in sing-box 1.14.0.
@@ -1296,7 +1174,7 @@ DNS servers are refactored for better performance and scalability.
See [DNS server](/configuration/dns/server/). See [DNS server](/configuration/dns/server/).
For migration, see [Migrate to new DNS server formats](/migration/#migrate-to-new-dns-server-formats). For migration, see [Migrate to new DNS server formats](/migration/#migrate-to-new-dns-servers).
Compatibility for old formats will be removed in sing-box 1.14.0. Compatibility for old formats will be removed in sing-box 1.14.0.
@@ -2132,7 +2010,7 @@ See [Migration](/migration/#process_path-format-update-on-windows).
The new DNS feature allows you to more precisely bypass Chinese websites via **DNS leaks**. Do not use plain local DNS The new DNS feature allows you to more precisely bypass Chinese websites via **DNS leaks**. Do not use plain local DNS
if using this method. if using this method.
See [Legacy Address Filter Fields](/configuration/dns/rule#legacy-address-filter-fields). See [Address Filter Fields](/configuration/dns/rule#address-filter-fields).
[Client example](/manual/proxy/client#traffic-bypass-usage-for-chinese-users) updated. [Client example](/manual/proxy/client#traffic-bypass-usage-for-chinese-users) updated.
@@ -2146,7 +2024,7 @@ the [Client example](/manual/proxy/client#traffic-bypass-usage-for-chinese-users
**5**: **5**:
The new feature allows you to cache the check results of The new feature allows you to cache the check results of
[Legacy Address Filter Fields](/configuration/dns/rule/#legacy-address-filter-fields) until expiration. [Address filter DNS rule items](/configuration/dns/rule/#address-filter-fields) until expiration.
**6**: **6**:
@@ -2327,7 +2205,7 @@ See [TUN](/configuration/inbound/tun) inbound.
**1**: **1**:
The new feature allows you to cache the check results of The new feature allows you to cache the check results of
[Legacy Address Filter Fields](/configuration/dns/rule/#legacy-address-filter-fields) until expiration. [Address filter DNS rule items](/configuration/dns/rule/#address-filter-fields) until expiration.
#### 1.9.0-alpha.7 #### 1.9.0-alpha.7
@@ -2374,7 +2252,7 @@ See [Migration](/migration/#process_path-format-update-on-windows).
The new DNS feature allows you to more precisely bypass Chinese websites via **DNS leaks**. Do not use plain local DNS The new DNS feature allows you to more precisely bypass Chinese websites via **DNS leaks**. Do not use plain local DNS
if using this method. if using this method.
See [Legacy Address Filter Fields](/configuration/dns/rule#legacy-address-filter-fields). See [Address Filter Fields](/configuration/dns/rule#address-filter-fields).
[Client example](/manual/proxy/client#traffic-bypass-usage-for-chinese-users) updated. [Client example](/manual/proxy/client#traffic-bypass-usage-for-chinese-users) updated.

View File

@@ -42,7 +42,6 @@ SFA provides an unprivileged TUN implementation through Android VpnService.
| `process_path` | :material-close: | No permission | | `process_path` | :material-close: | No permission |
| `process_path_regex` | :material-close: | No permission | | `process_path_regex` | :material-close: | No permission |
| `package_name` | :material-check: | / | | `package_name` | :material-check: | / |
| `package_name_regex` | :material-check: | / |
| `user` | :material-close: | Use `package_name` instead | | `user` | :material-close: | Use `package_name` instead |
| `user_id` | :material-close: | Use `package_name` instead | | `user_id` | :material-close: | Use `package_name` instead |
| `wifi_ssid` | :material-check: | Fine location permission required | | `wifi_ssid` | :material-check: | Fine location permission required |

Some files were not shown because too many files have changed in this diff Show More