Compare commits

..

104 Commits

Author SHA1 Message Date
世界
bf9e390cf4 fix(ccm,ocm): allow reverse connector credentials to serve local requests
Connector mode credentials were unconditionally blocked from local use
by unavailableError(), despite having a working forwardHTTPClient. Also
set credentialDialer in OCM connector mode to prevent nil panic in
WebSocket handler.
2026-03-28 18:46:10 +08:00
世界
471c9c3b47 fix(ccm): make refresh failure fail fast 2026-03-28 18:07:31 +08:00
世界
e7478ce947 fix(ccm): mark credential unavailable on refresh failure, handle poll 401
tryRefreshCredentials now returns error and calls markCredentialsUnavailable
when lock acquisition or file write permission fails. getAccessToken propagates
the error instead of silently returning the expired token. pollUsage handles
401 by attempting auth recovery and marking unavailable on failure. All
credential error paths now use Error log level instead of Debug. Startup
checks expired tokens eagerly via tryRefreshCredentials.
2026-03-28 16:58:52 +08:00
世界
cf11e0e74a Reuse SDK JSON types in ccm and ocm 2026-03-28 11:20:24 +08:00
世界
a87a2b0e2b ocm: log think level 2026-03-28 02:00:27 +08:00
世界
d9c298af1e fix(ccm,ocm): remove upstream rate limit header forwarding, compute locally
Strip all upstream rate limit headers and compute unified-status,
representative-claim, reset times, and surpassed-threshold from
aggregated utilization data. Never expose per-account overage or
fallback information. Remove per-credential unified state storage,
snapshot aggregation, and WebSocket synthetic rate limit events.
2026-03-26 23:42:14 +08:00
世界
cd5007ffbb fix(ccm,ocm): track external credential poll failures and re-poll on user connect
External credentials now properly increment consecutivePollFailures on
poll errors (matching defaultCredential behavior), marking the credential
as temporarily blocked. When a user with external_credential connects
and the credential is not usable, a forced poll is triggered to check
recovery.
2026-03-26 22:16:02 +08:00
世界
e49d0685ad ccm: Fix token refresh 2026-03-26 21:37:12 +08:00
世界
e1c9667319 Revert "fix(ocm): send rate limit status immediately on WebSocket connect"
This reverts commit 6721dff48a.
2026-03-26 16:03:46 +08:00
世界
6f45ea9c27 Revert "fix(ocm): inject synthetic rate limits inline when intercepting upstream events"
This reverts commit ca60f93184.
2026-03-26 16:03:39 +08:00
世界
ca60f93184 fix(ocm): inject synthetic rate limits inline when intercepting upstream events
The initial synthetic event from 6721dff48 arrives before the Codex CLI's
response stream reader is active. Additionally, the shouldEmit gate in
updateStateFromHeaders suppresses the async replacement when values haven't
changed. Send aggregated status inline in proxyWebSocketUpstreamToClient
so the client receives it at the exact protocol position it expects.
2026-03-26 15:42:51 +08:00
世界
1774d98793 fix(ccm,ocm): restore fixed usage polling
Remove the poll_interval config surface from CCM and OCM so both services fall back to the built-in 1h polling cadence again. Also isolate CCM credential lock mocking per test instance so the access-token refresh tests stop racing on shared global state.
2026-03-26 14:01:24 +08:00
世界
6721dff48a fix(ocm): send rate limit status immediately on WebSocket connect
Codex CLI ignores x-codex-* headers in the WebSocket upgrade response
and only reads rate limits from in-band codex.rate_limits events.
Previously, the first synthetic event was gated by firstRealRequest
(after warmup), delaying usage display. Now send aggregated status
right after subscribing, so the client sees rate limits before the
first turn begins.
2026-03-26 12:33:53 +08:00
世界
4592164a7a Align CCM and OCM rate limits 2026-03-24 22:06:10 +08:00
世界
92c8f4c5c8 fix(ccm): align default credential with Claude Code 2026-03-24 21:15:46 +08:00
世界
441c98890d feat(ccm): add claude_directory option to read Claude Code config 2026-03-22 16:59:36 +08:00
世界
bb2169bc17 release: Fix install_go.sh 2026-03-22 06:35:34 +08:00
世界
d996b60f44 ccm/ocm: Add CLAUDE.md 2026-03-22 06:23:13 +08:00
世界
084a6f1302 fix(ccm): align OAuth token refresh with Claude Code v2.1.81
After re-login with newer Claude Code (v2.1.75+), CCM refresh requests
returned persistent 429s. Root cause: CCM omitted the `scope` parameter
that the server now requires for tokens with `user:file_upload` scope.

Changes to fully match Claude Code's OAuth behavior:

- Add `scope` parameter to token refresh request body
- Parse `scope` from refresh response and store back
- Add `subscriptionType`/`rateLimitTier` to credential struct to
  preserve Claude Code's profile state on write-back
- Change credential file write to read-modify-write, preserving
  other top-level JSON keys (matches Claude Code's BP6 pattern)
- Same for macOS keychain write path
- Increase token expiry buffer from 1 min to 5 min (matching CC's
  isOAuthTokenExpired with 300s buffer)
- Add cross-process mkdir-based file lock compatible with Claude
  Code's proper-lockfile protocol (~/.claude.lock)
- Add post-failure recovery: re-read credentials from disk after
  refresh failure in case another process succeeded
- Add 401/403 "OAuth token has been revoked" recovery in proxy
  handler: reload credentials and retry once
2026-03-22 06:02:55 +08:00
世界
0950783479 fix(ccm,ocm): exclude unusable credentials from status aggregation
computeAggregatedUtilization used isAvailable() which only checks
permanent unavailability, so credentials rejected by upstream 400
still had their planWeight included in the total, inflating reported
capacity and diluting utilization.
2026-03-21 11:42:49 +08:00
世界
f172a575b7 fix(ccm): log assigned credential for each distinct model per session 2026-03-21 11:07:43 +08:00
世界
29b901a8b3 fix(ccm): robust account UUID injection and session ID validation
Replace bytes.Replace-based UUID injection with proper JSON
unmarshal/re-marshal through map[string]json.RawMessage — the old
approach silently failed when the body used non-canonical JSON escaping.

Return 500 when metadata.user_id is present but in an unrecognized
format, instead of silently passing through with an empty session ID.
2026-03-21 11:00:05 +08:00
世界
53f832330d fix(ccm): adapt to Claude Code v2.1.78 metadata format, separate state from credentials
Claude Code v2.1.78 changed metadata.user_id from a template literal
(`user_${id}_account_${uuid}_session_${sid}`) to a JSON-encoded object
(`JSON.stringify({device_id, account_uuid, session_id})`), breaking
session ID extraction via `_session_` substring match.

- Fix extractCCMSessionID to try JSON parse first, fallback to legacy
- Remove subscriptionType/rateLimitTier/isMax from oauthCredentials
  (profile state does not belong in auth credentials)
- Add state_path option for persisting profile state across restarts
- Parse account.uuid from /api/oauth/profile response
- Inject account_uuid into forwarded requests when client sends it empty
  (happens when using ANTHROPIC_AUTH_TOKEN instead of Claude AI OAuth)
2026-03-21 10:45:24 +08:00
世界
99d9e06dd0 fix(ccm,ocm): handle upstream 400 by marking external credentials rejected and polling default credentials
External credentials returning 400 are marked unavailable for pollInterval
duration; status stream/poll success clears the rejection early. Default
credentials trigger a stale poll to let the usage API detect account issues
without causing 429 storms.
2026-03-21 10:31:17 +08:00
世界
608b7e7fa2 fix(ccm,ocm): stop cascading 429 retry storm on token refresh
When the access token expires and refreshToken() gets 429, getAccessToken()
returned the error but left credentials unchanged with no cooldown. Every
subsequent request re-attempted the refresh, creating a burst that overwhelmed
the token endpoint.

- refreshToken() now returns Retry-After duration from 429 response headers
  (-1 when no header present, meaning permanently blocked)
- getAccessToken() caches the 429 and blocks further refresh attempts until
  Retry-After expires (or permanently if no header)
- reloadCredentials() clears the block when new credentials are loaded from file
- Remove go pollUsage() on upstream errors (unrelated to usage state)
2026-03-21 09:31:05 +08:00
世界
7acba74755 fix(ccm): forward 529 upstream overloaded response transparently 2026-03-18 15:53:36 +08:00
世界
2fe1e37b17 fix(ccm,ocm): add missing isFirstUpdate to external credential usage logging 2026-03-18 01:00:55 +08:00
世界
3bcfdd5455 fix(ccm,ocm): remove external context from pollUsage/pollIfStale
pollUsage(ctx) accepted caller context, and service_status.go passed
r.Context() which gets canceled on client disconnect or service shutdown.
This caused incrementPollFailures → interruptConnections on transient
cancellations. Each implementation now uses its own persistent context:
defaultCredential uses serviceContext, externalCredential uses
getReverseContext().
2026-03-18 00:54:01 +08:00
世界
b119d08764 fix(ccm,ocm): add usage logging to status stream, remove redundant isFirstUpdate
connectStatusStream updated credential state silently — no log on
first frame or value changes. After restart, external credentials
get usage via stream before any request, so pollIfStale skips them
and no usage log ever appears.

Add the same change-detection log to connectStatusStream. Also remove
redundant isFirstUpdate guards from pollUsage and updateStateFromHeaders:
when old values are zero, any non-zero new value already satisfies the
integer-percent comparison.
2026-03-17 22:37:38 +08:00
世界
6b8838d323 fix(ccm,ocm): restart status stream when receiver gets reverse session
statusStreamLoop started on start() before any reverse session existed,
got a non-retryable error, and exited permanently. Restart it when
setReverseSession transitions receiver credentials to available.
2026-03-17 22:08:30 +08:00
世界
b3429ef1f3 fix(ocm): strip non-active rate-limit headers from forwarded responses 2026-03-17 22:01:30 +08:00
世界
a2d6cf9715 fix(ocm): defer initial websocket rate-limit push 2026-03-17 21:14:14 +08:00
世界
99e19e7033 service: stop retrying fatal watch status errors 2026-03-17 20:47:52 +08:00
世界
969defeef0 ccm,ocm: validate external status response fields 2026-03-17 20:17:56 +08:00
世界
f57eff33bb ccm,ocm: fix WS push lifecycle, deduplicate rate_limits, stabilize reset aggregation
- Add closed channel to webSocketSession for push goroutine shutdown
  on connection close, preventing session leak and Service.Close() hang
- Intercept upstream codex.rate_limits events instead of forwarding;
  push goroutine is now the sole sender of aggregated rate_limits
- Emit status updates on reset-only changes (fiveHourResetChanged,
  weeklyResetChanged) so push goroutine picks up reset advances
- Skip expired resets (hours <= 0) in aggregation instead of clamping
  to now, avoiding unstable reset_at output and spurious status ticks
- Delete stale upstream reset headers when aggregated reset is zero
- Hardcode "codex" identifier everywhere: handleWebSocketRateLimitsEvent,
  buildSyntheticRateLimitsEvent, rewriteResponseHeaders
- Remove rewriteWebSocketRateLimits, rewriteWebSocketRateLimitWindow,
  identifier tracking (TypedValue), and unused imports
2026-03-17 20:00:54 +08:00
世界
0a054b9aa4 ccm,ocm: propagate reset times, rewrite headers for all users, add WS status push
- Add fiveHourReset/weeklyReset to statusPayload and aggregatedStatus
  with weight-averaged reset time aggregation across credential pools
- Rewrite response headers (utilization + reset times) for all users,
  not just external credential users
- Rewrite WebSocket rate_limits events for all users with aggregated values
- Add proactive WebSocket status push: synthetic codex.rate_limits events
  sent on connection start and on status changes via statusObserver
- Remove one-shot stream forward compatibility (statusStreamHeader,
  restoreLastUpdatedIfUnchanged, oneShot detection)
2026-03-17 18:13:54 +08:00
世界
7d15d9d282 ccm: emit status updates for plan-weight-only changes 2026-03-17 16:46:54 +08:00
世界
cf2d677043 ocm: emit status updates for plan-weight-only changes 2026-03-17 16:32:03 +08:00
世界
4a6a211775 ccm,ocm: reduce status emission noise, simplify emit-guard pattern
Guard updateStateFromHeaders emission with value-change detection to
avoid unnecessary computeAggregatedUtilization scans on every proxied
response. Replace statusAggregateStateLocked two-value return with
comparable statusSnapshot struct. Define statusPayload type for the
status wire format, replacing anonymous structs and map literals.
2026-03-17 16:10:59 +08:00
世界
f84832a369 Add stream watch endpoint 2026-03-17 16:03:35 +08:00
世界
f3c3022094 ccm,ocm: fix session race, track fallback sessions, skip warmup logging
Fix data race in selectCredential where concurrent goroutines could
overwrite each other's session entries by adding compare-and-delete
and store-if-absent patterns with retry loop. Track sessions for
fallback strategy so isNew is reported correctly. Skip logging and
usage tracking for websocket warmup requests (generate: false).
2026-03-16 22:10:10 +08:00
世界
2dd093a32e ccm,ocm: fix data race, remove dead code, clean up inefficiencies 2026-03-15 21:20:29 +08:00
世界
14ade76956 ccm,ocm: remove dead code, fix timer leaks, eliminate redundant lookups
- Remove unused onBecameUnusable field from CCM credential structs
  (OCM wires it for WebSocket interruption; CCM has no equivalent)
- Replace time.After with time.NewTimer in doHTTPWithRetry and
  connectorLoop to avoid timer leaks on context cancellation
- Pass already-resolved provider to rewriteResponseHeadersForExternalUser
  instead of re-resolving via credentialForUser
- Hoist reverseYamuxConfig to package-level var (immutable, no need to
  allocate on every call)
2026-03-15 20:42:41 +08:00
世界
9e3ec30d72 docs: fix ccm and ocm credential docs 2026-03-15 20:41:47 +08:00
世界
763e0af010 docs: complete ccm/ocm documentation for 1.14.0 features 2026-03-15 18:49:00 +08:00
世界
656b09d1be ccm,ocm: never treat external usage endpoint failures as over-limit 2026-03-15 18:48:53 +08:00
世界
8e9c61e624 ccm,ocm: normalize legacy fields into credentials at init, remove dual code path 2026-03-15 18:48:53 +08:00
世界
bc6e72408d ccm,ocm: block API key headers from being forwarded upstream 2026-03-15 18:48:52 +08:00
世界
56af7313b2 ccm,ocm: don't treat usage API 429 as account over-limit
The usage API itself has rate limits. A 429 from it means "poll less
frequently", not that the account exceeded its usage quota. Previously
incrementPollFailures() was called, marking the credential unusable and
interrupting in-flight connections.

Now: parse Retry-After, store as usageAPIRetryDelay, and retry after
that delay. The credential stays usable and relies on passive header
updates for usage data in the meantime.
2026-03-15 18:48:52 +08:00
世界
6878ad0d35 ccm,ocm: fix naming and error-handling convention violations
- Rename credential interface to Credential (exported), cred to credential
- Rename mutex/saveMutex to access/saveAccess per go-syntax.md
- Fix abbreviations: reverseHttpClient, allCreds, credOpt, extCred,
  credDialer, reverseCredDialer, portStr
- Replace errors.Is(http.ErrServerClosed) with E.IsClosed
- Add E.IsClosedOrCanceled guard before streaming write error logs
2026-03-15 18:48:51 +08:00
世界
04bd63b455 ccm,ocm: reorganize files and improve naming conventions
Split credential_state.go (1500+ lines) into credential.go,
credential_default.go, credential_provider.go, credential_builder.go.

Split service.go (900+ lines) into service.go, service_handler.go,
service_status.go.

Rename credential.go to credential_oauth.go to avoid name conflict
with the credential interface.

Apply naming fixes: accessMutex→access, stateMutex→stateAccess,
sessionMutex→sessionAccess, webSocketMutex→webSocketAccess,
httpTransport()→httpClient(), httpClient field→forwardHTTPClient,
weeklyWindowDuration→weeklyWindowHours.
2026-03-15 18:48:51 +08:00
世界
51d564c9ff ccm,ocm: merge fallback into balancer strategy, use hyphenated constant names
Merge the fallback credential type into balancer as a strategy
(C.BalancerStrategyFallback). Replace raw string literals with
C.BalancerStrategyXxx constants and switch to hyphens (least-used,
round-robin) per project convention.
2026-03-15 18:48:50 +08:00
世界
4d8baf7175 ccm: fix nil pointer in pollUsage for connector-mode credentials
Connector-mode credentials (URL + reverse: true) never assigned
httpClient, causing a nil dereference when pollUsage accessed
httpClient.Transport.

Also extract poll request logic into doPollUsageRequest to try
reverse transport first (single attempt), then fall back to
forward transport with retries if the reverse session disconnects.
2026-03-15 18:48:50 +08:00
世界
d1e5426bc8 ccm,ocm: add exponential backoff with cap for poll retry
Replace flat 1-minute poll retry interval with exponential backoff
(1m → 2m → 4m → 5m cap). Suppress error logs after reaching the cap.
2026-03-15 18:48:50 +08:00
世界
4d907bc49d ccm,ocm: allow URL-based credentials to accept reverse connections
Previously, findReceiverCredential required baseURL == reverseProxyBaseURL,
so only credentials with no URL could accept incoming reverse connections.
Now credentials with a normal URL also accept reverse connections, preferring
the reverse session when active and falling back to the direct URL when not.
2026-03-15 18:48:49 +08:00
世界
2c907bef2c Fix scoped rebalance interrupts 2026-03-15 18:48:49 +08:00
世界
d2300353fd Propagate request context to upstream requests 2026-03-15 18:48:49 +08:00
世界
f871113832 ccm,ocm: add balancer session rebalancing with per-credential interrupt
When a sticky session's credential utilization exceeds the least-used
credential by a weight-adjusted threshold, force reassign all sessions
on that credential and cancel in-flight requests scoped to the balancer.

Threshold formula: effective = rebalance_threshold / planWeight, so a
config value of 20 triggers at 2% delta for Max 20x (w=10), 4% for
Max 5x (w=5), and 20% for Pro (w=1).
2026-03-15 18:48:49 +08:00
世界
b97b9d9cfd ccm,ocm: add request ID context to HTTP request logging 2026-03-15 18:48:48 +08:00
世界
badeeb91fe service/ocm: add default OpenAI-Beta header and log websocket error body
The upstream OpenAI WebSocket endpoint requires the
OpenAI-Beta: responses_websockets=2026-02-06 header. Set it
automatically when the client doesn't provide it.

Also capture and log the response body on non-429 WebSocket
handshake failures to surface the actual error from upstream.
2026-03-15 18:48:48 +08:00
世界
f4aaf33bf2 ccm,ocm: strip reverse proxy headers from upstream responses 2026-03-15 18:48:48 +08:00
世界
8fe8e238b3 service/ocm: unify websocket logging with HTTP request logging 2026-03-15 18:48:47 +08:00
世界
6f433937ba ccm,ocm: auto-detect plan weight for external credentials via status endpoint 2026-03-15 18:48:47 +08:00
世界
80d5432654 service/ccm: update oauth token URL and remove unnecessary Accept header 2026-03-15 18:48:46 +08:00
世界
8984b45ded ccm,ocm: improve balancer least_used with plan-weighted scoring and reset urgency
Scale remaining capacity by plan weight (Pro=1, Max 5x=5, Max 20x=10
for CCM; Plus=1, Pro=10 for OCM) so higher-tier accounts contribute
proportionally more. Factor in weekly reset proximity so credentials
about to reset are preferred ("use it or lose it").

Auto-detect plan weight from subscriptionType + rateLimitTier (CCM)
or plan_type (OCM). Fetch /api/oauth/profile when rateLimitTier is
missing from the credential file. External credentials accept a
manual plan_weight option.
2026-03-15 18:48:46 +08:00
世界
25a9e4ce59 service/ocm: only log new credential assignments and add websocket logging 2026-03-15 18:48:46 +08:00
世界
615a7e05b4 service/ccm: only log new credential assignments and show context window in model 2026-03-15 18:48:46 +08:00
世界
1628272507 ccm,ocm: mark credentials unusable on usage poll failure and trigger poll on upstream error 2026-03-15 18:48:46 +08:00
世界
ee65b375cb service/ccm: allow extended context (1m) for all credentials
1m context is now available to all subscribers and no longer
consumes Extra Usage.
2026-03-15 18:48:45 +08:00
世界
a09174a9a2 service/ccm: reject fast-mode external credentials 2026-03-15 18:48:45 +08:00
世界
ce543a935f ccm,ocm: fix reserveWeekly default and remove dead reserve fields 2026-03-15 18:48:45 +08:00
世界
7f93c76b1a ccm,ocm: add limit options and fix aggregated utilization scaling
Add limit_5h and limit_weekly options as alternatives to reserve_5h
and reserve_weekly for capping credential utilization. The two are
mutually exclusive per window.

Fix computeAggregatedUtilization to scale per-credential utilization
relative to each credential's cap before averaging, so external users
see correct available capacity regardless of per-credential caps.

Fix pickLeastUsed to compare remaining capacity (cap - utilization)
instead of raw utilization, ensuring fair comparison across credentials
with different caps.
2026-03-15 18:48:44 +08:00
世界
df6e47f5f1 ocm: preserve websocket rate limit event fields 2026-03-15 18:48:44 +08:00
世界
1993da3735 ocm: rewrite codex.rate_limits WebSocket events for external users
The HTTP path rewrites utilization headers for external users via
rewriteResponseHeadersForExternalUser to show aggregated values.
The WebSocket upgrade headers were also rewritten, but in-band
codex.rate_limits events were forwarded unmodified, leaking
per-credential utilization to external users.
2026-03-15 18:48:43 +08:00
世界
22376472d0 ccm,ocm: fix passive usage update for WebSocket connections
WebSocket 101 upgrade responses do not include utilization headers
(confirmed via codex CLI source). Rate limit data is delivered
exclusively through in-band events (codex.rate_limits and error
events with status 429).

Previously, updateStateFromHeaders unconditionally bumped lastUpdated
even when no utilization headers were found, which suppressed polling
and left credential utilization permanently stale during WebSocket
sessions.

- Only bump lastUpdated when actual utilization data is parsed
- Parse in-band codex.rate_limits events to update credential state
- Detect in-band 429 error events to markRateLimited
- Fix WebSocket 429 retry to update old credential state before retry
2026-03-15 18:48:43 +08:00
世界
74bf20d349 ccm,ocm: fix reverse session shutdown race 2026-03-15 18:48:43 +08:00
世界
ff8585f7c6 ccm,ocm: block utilization decrease within same rate-limit window
updateStateFromHeaders unconditionally applied header utilization
values even when they were lower than the current state, causing
poll-sourced values to be overwritten by stale header values.

Parse reset timestamps before utilization and only allow decreases
when the reset timestamp changes (indicating a new rate-limit
window). Also add math.Ceil to CCM external credential for
consistency with default credential.
2026-03-15 18:48:42 +08:00
世界
4d5108fe7f ccm,ocm: fix connector-side bufio data loss in reverse proxy
connectorConnect() creates a bufio.NewReader to read the HTTP 101
upgrade response, but then passes the raw conn to yamux.Server().
If TCP coalesces the 101 response with initial yamux frames, the
bufio reader over-reads into its buffer and those bytes are lost
to yamux, causing session failure.

Wrap the bufio.Reader and raw conn into a bufferedConn so yamux
reads through the buffer first.
2026-03-15 18:48:42 +08:00
世界
3b177df05e ccm,ocm: fix data race on reverseContext/reverseCancel
InterfaceUpdated() writes reverseContext and reverseCancel without
synchronization while connectorLoop/connectorConnect goroutines
read them concurrently. close() also accesses reverseCancel without
a lock.

Fix by extending reverseAccess mutex to protect these fields:
- Add getReverseContext()/resetReverseContext() methods
- Pass context as parameter to connectorConnect
- Merge close() into a single lock acquisition
- Use resetReverseContext() in InterfaceUpdated()
2026-03-15 18:48:42 +08:00
世界
1824881719 ccm,ocm: reset connector backoff after successful connection
The consecutiveFailures counter in connectorLoop never resets,
causing backoff to permanently cap at 30-45s even after a
connection that served successfully for hours.

Reset the counter when connectorConnect ran for at least one
minute, indicating a successful session rather than a transient
dial/handshake failure.
2026-03-15 18:48:41 +08:00
世界
02a1409e9a ccm,ocm: unify HTTP request retry with fast retry and exponential backoff 2026-03-15 18:48:41 +08:00
世界
af94ea9089 Fix reverse external credential handling 2026-03-15 18:48:41 +08:00
世界
970951f369 ccm,ocm: add reverse proxy support for external credentials
Allow two CCM/OCM instances to share credentials when only one has a
public IP, using yamux-multiplexed reverse connections.

Three credential modes:
- Normal: URL set, reverse=false — standard HTTP proxy
- Receiver: URL empty — waits for incoming reverse connection
- Connector: URL set, reverse=true — dials out to establish connection

Extend InterfaceUpdated to services so network changes trigger
reverse connection reconnection.
2026-03-15 18:48:40 +08:00
世界
15f3619995 ccm,ocm: strip reverse proxy headers before forwarding to upstream 2026-03-15 18:48:40 +08:00
世界
b96ab4fef9 ccm,ocm,ssmapi: fix HTTP/2 over TLS with h2c handler
aTLS.NewListener returns *LazyConn, not *tls.Conn, so Go's
http.Server cannot detect TLS via type assertion and falls back
to HTTP/1.x. When ALPN negotiates h2, the client sends HTTP/2
frames that the server fails to parse, causing HTTP 520 errors
behind Cloudflare.

Wrap HTTP handlers with h2c.NewHandler to intercept the HTTP/2
client preface and dispatch to http2.Server.ServeConn, consistent
with DERP, v2rayhttp, naive, and v2raygrpclite services.
2026-03-15 18:48:40 +08:00
世界
6829f91a06 ccm,ocm: check credential file writability before token refresh
Refuse to refresh tokens when the credential file is not writable,
preventing server-side invalidation of the old refresh token that
would make the credential permanently unusable after restart.
2026-03-15 18:48:40 +08:00
世界
8e5811a8c7 ccm,ocm: watch credential_path and allow delayed credentials 2026-03-15 18:48:40 +08:00
世界
da8ff6f578 ccm/ocm: Add external credential support for cross-instance usage sharing
Extract credential interface from *defaultCredential to support both
default (OAuth) and external (remote proxy) credential types. External
credentials proxy requests to a remote ccm/ocm instance with bearer
token auth, poll a /status endpoint for utilization, and parse
aggregated rate limit headers from responses.

Add allow_external_usage user flag to control whether balancer/fallback
providers may select external credentials. Add status endpoint
(/ccm/v1/status, /ocm/v1/status) returning averaged utilization across
eligible credentials. Rewrite response rate limit headers for external
users with aggregated values.
2026-03-15 18:48:39 +08:00
世界
2801bce815 ccm/ocm: Add multi-credential support with balancer and fallback strategies 2026-03-15 18:48:39 +08:00
世界
a11cd1e0c6 Bump version 2026-03-15 17:57:54 +08:00
世界
bd0fb83d2d cronet-go: Update chromium to 145.0.7632.159 2026-03-15 17:57:54 +08:00
世界
9462b1deeb documentation: Update descriptions for neighbor rules 2026-03-15 17:57:53 +08:00
世界
44d1c86b1b Add macOS support for MAC and hostname rule items 2026-03-15 17:57:53 +08:00
世界
f802668915 Add Android support for MAC and hostname rule items 2026-03-15 17:57:53 +08:00
世界
4d217b7481 Add MAC and hostname rule items 2026-03-15 17:57:53 +08:00
世界
d3768cca36 Bump version 2026-03-15 17:56:37 +08:00
世界
0889ddd001 Fix connector canceled dial cleanup 2026-03-15 17:56:37 +08:00
深鸣
f46fbf188a documentation: Minor fixes 2026-03-15 17:56:37 +08:00
世界
f2d15139f5 tun: Fix nftables single include_uid not working 2026-03-15 16:58:34 +08:00
世界
041646b728 Fix kTLS crash 2026-03-14 21:38:38 +08:00
世界
b990de2e12 tun: Fix "Fix auto_redirect dropping SO_BINDTODEVICE traffic" 2026-03-14 21:38:38 +08:00
世界
fe585157d2 Bump version 2026-03-14 21:38:38 +08:00
世界
eed6a36e5d tun:Fix auto_redirect dropping SO_BINDTODEVICE traffic 2026-03-14 21:38:38 +08:00
世界
eb0f38544c tailscale: Fix system interface rules 2026-03-14 21:38:38 +08:00
73 changed files with 10883 additions and 5285 deletions

View File

@@ -41,13 +41,13 @@ jobs:
version: ${{ steps.outputs.outputs.version }}
steps:
- name: Checkout
uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ~1.26.0
go-version: ~1.25.8
- name: Check input version
if: github.event_name == 'workflow_dispatch'
run: |-
@@ -117,14 +117,14 @@ jobs:
- { os: android, arch: "386", ndk: "i686-linux-android23" }
steps:
- name: Checkout
uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5
with:
fetch-depth: 0
- name: Setup Go
if: ${{ ! matrix.legacy_win7 }}
uses: actions/setup-go@v5
with:
go-version: ~1.26.0
go-version: ~1.25.8
- name: Cache Go for Windows 7
if: matrix.legacy_win7
id: cache-go-for-windows7
@@ -458,7 +458,7 @@ jobs:
- { arch: amd64, legacy_osx: true, legacy_name: "macos-10.13" }
steps:
- name: Checkout
uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5
with:
fetch-depth: 0
- name: Setup Go
@@ -551,7 +551,7 @@ jobs:
- { arch: arm64, naive: true }
steps:
- name: Checkout
uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5
with:
fetch-depth: 0
- name: Setup Go
@@ -634,14 +634,14 @@ jobs:
- calculate_version
steps:
- name: Checkout
uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5
with:
fetch-depth: 0
submodules: 'recursive'
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ~1.26.0
go-version: ~1.25.8
- name: Setup Android NDK
id: setup-ndk
uses: nttld/setup-ndk@v1
@@ -724,14 +724,14 @@ jobs:
- calculate_version
steps:
- name: Checkout
uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5
with:
fetch-depth: 0
submodules: 'recursive'
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ~1.26.0
go-version: ~1.25.8
- name: Setup Android NDK
id: setup-ndk
uses: nttld/setup-ndk@v1
@@ -822,7 +822,7 @@ jobs:
steps:
- name: Checkout
if: matrix.if
uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5
with:
fetch-depth: 0
submodules: 'recursive'
@@ -830,7 +830,7 @@ jobs:
if: matrix.if
uses: actions/setup-go@v5
with:
go-version: ~1.26.0
go-version: ~1.25.8
- name: Set tag
if: matrix.if
run: |-
@@ -976,7 +976,7 @@ jobs:
- build_apple
steps:
- name: Checkout
uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5
with:
fetch-depth: 0
- name: Cache ghr

View File

@@ -49,14 +49,14 @@ jobs:
echo "ref=$ref"
echo "ref=$ref" >> $GITHUB_OUTPUT
- name: Checkout
uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5
with:
ref: ${{ steps.ref.outputs.ref }}
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ~1.26.0
go-version: ~1.25.8
- name: Clone cronet-go
if: matrix.naive
run: |
@@ -188,7 +188,7 @@ jobs:
echo "ref=$ref"
echo "ref=$ref" >> $GITHUB_OUTPUT
- name: Checkout
uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5
with:
ref: ${{ steps.ref.outputs.ref }}
fetch-depth: 0

View File

@@ -24,7 +24,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5
with:
fetch-depth: 0
- name: Setup Go

View File

@@ -29,13 +29,13 @@ jobs:
version: ${{ steps.outputs.outputs.version }}
steps:
- name: Checkout
uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ~1.26.0
go-version: ~1.25.8
- name: Check input version
if: github.event_name == 'workflow_dispatch'
run: |-
@@ -72,13 +72,13 @@ jobs:
- { os: linux, arch: ppc64le, debian: ppc64el, rpm: ppc64le }
steps:
- name: Checkout
uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ~1.26.0
go-version: ~1.25.8
- name: Clone cronet-go
if: matrix.naive
run: |
@@ -236,7 +236,7 @@ jobs:
- build
steps:
- name: Checkout
uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5
with:
fetch-depth: 0
- name: Set tag

4
.gitignore vendored
View File

@@ -18,6 +18,6 @@
.DS_Store
/config.d/
/venv/
CLAUDE.md
AGENTS.md
/CLAUDE.md
/AGENTS.md
/.claude/

View File

@@ -12,6 +12,7 @@ import (
"fmt"
"io"
"net"
"unsafe"
)
func (c *Conn) Read(b []byte) (int, error) {
@@ -229,7 +230,7 @@ func (c *Conn) readRawRecord() (typ uint8, data []byte, err error) {
record := c.rawConn.RawInput.Next(recordHeaderLen + n)
data, typ, err = c.rawConn.In.Decrypt(record)
if err != nil {
err = c.rawConn.In.SetErrorLocked(c.sendAlert(uint8(err.(tls.AlertError))))
err = c.rawConn.In.SetErrorLocked(c.sendAlert(*(*uint8)((*[2]unsafe.Pointer)(unsafe.Pointer(&err))[1])))
return
}
return

View File

@@ -38,6 +38,13 @@ const (
TypeURLTest = "urltest"
)
const (
BalancerStrategyLeastUsed = "least-used"
BalancerStrategyRoundRobin = "round-robin"
BalancerStrategyRandom = "random"
BalancerStrategyFallback = "fallback"
)
func ProxyDisplayName(proxyType string) string {
switch proxyType {
case TypeTun:

View File

@@ -55,6 +55,12 @@ type contextKeyConnecting struct{}
var errRecursiveConnectorDial = E.New("recursive connector dial")
type connectorDialResult[T any] struct {
connection T
cancel context.CancelFunc
err error
}
func (c *Connector[T]) Get(ctx context.Context) (T, error) {
var zero T
for {
@@ -100,41 +106,37 @@ func (c *Connector[T]) Get(ctx context.Context) (T, error) {
return zero, err
}
c.connecting = make(chan struct{})
connecting := make(chan struct{})
c.connecting = connecting
dialContext := context.WithValue(ctx, contextKeyConnecting{}, c)
dialResult := make(chan connectorDialResult[T], 1)
c.access.Unlock()
dialContext := context.WithValue(ctx, contextKeyConnecting{}, c)
connection, cancel, err := c.dialWithCancellation(dialContext)
go func() {
connection, cancel, err := c.dialWithCancellation(dialContext)
dialResult <- connectorDialResult[T]{
connection: connection,
cancel: cancel,
err: err,
}
}()
c.access.Lock()
close(c.connecting)
c.connecting = nil
if err != nil {
c.access.Unlock()
return zero, err
}
if c.closed {
cancel()
c.callbacks.Close(connection)
c.access.Unlock()
select {
case result := <-dialResult:
return c.completeDial(ctx, connecting, result)
case <-ctx.Done():
go func() {
result := <-dialResult
_, _ = c.completeDial(ctx, connecting, result)
}()
return zero, ctx.Err()
case <-c.closeCtx.Done():
go func() {
result := <-dialResult
_, _ = c.completeDial(ctx, connecting, result)
}()
return zero, ErrTransportClosed
}
if err = ctx.Err(); err != nil {
cancel()
c.callbacks.Close(connection)
c.access.Unlock()
return zero, err
}
c.connection = connection
c.hasConnection = true
c.connectionCancel = cancel
result := c.connection
c.access.Unlock()
return result, nil
}
}
@@ -143,6 +145,38 @@ func isRecursiveConnectorDial[T any](ctx context.Context, connector *Connector[T
return loaded && dialConnector == connector
}
func (c *Connector[T]) completeDial(ctx context.Context, connecting chan struct{}, result connectorDialResult[T]) (T, error) {
var zero T
c.access.Lock()
defer c.access.Unlock()
defer func() {
if c.connecting == connecting {
c.connecting = nil
}
close(connecting)
}()
if result.err != nil {
return zero, result.err
}
if c.closed || c.closeCtx.Err() != nil {
result.cancel()
c.callbacks.Close(result.connection)
return zero, ErrTransportClosed
}
if err := ctx.Err(); err != nil {
result.cancel()
c.callbacks.Close(result.connection)
return zero, err
}
c.connection = result.connection
c.hasConnection = true
c.connectionCancel = result.cancel
return c.connection, nil
}
func (c *Connector[T]) dialWithCancellation(ctx context.Context) (T, context.CancelFunc, error) {
var zero T
if err := ctx.Err(); err != nil {

View File

@@ -188,13 +188,157 @@ func TestConnectorCanceledRequestDoesNotCacheConnection(t *testing.T) {
err := <-result
require.ErrorIs(t, err, context.Canceled)
require.EqualValues(t, 1, dialCount.Load())
require.EqualValues(t, 1, closeCount.Load())
require.Eventually(t, func() bool {
return closeCount.Load() == 1
}, time.Second, 10*time.Millisecond)
_, err = connector.Get(context.Background())
require.NoError(t, err)
require.EqualValues(t, 2, dialCount.Load())
}
func TestConnectorCanceledRequestReturnsBeforeIgnoredDialCompletes(t *testing.T) {
t.Parallel()
var (
dialCount atomic.Int32
closeCount atomic.Int32
)
dialStarted := make(chan struct{}, 1)
releaseDial := make(chan struct{})
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
dialCount.Add(1)
select {
case dialStarted <- struct{}{}:
default:
}
<-releaseDial
return &testConnectorConnection{}, nil
}, ConnectorCallbacks[*testConnectorConnection]{
IsClosed: func(connection *testConnectorConnection) bool {
return false
},
Close: func(connection *testConnectorConnection) {
closeCount.Add(1)
},
Reset: func(connection *testConnectorConnection) {},
})
requestContext, cancel := context.WithCancel(context.Background())
result := make(chan error, 1)
go func() {
_, err := connector.Get(requestContext)
result <- err
}()
<-dialStarted
cancel()
select {
case err := <-result:
require.ErrorIs(t, err, context.Canceled)
case <-time.After(time.Second):
t.Fatal("Get did not return after request cancel")
}
require.EqualValues(t, 1, dialCount.Load())
require.EqualValues(t, 0, closeCount.Load())
close(releaseDial)
require.Eventually(t, func() bool {
return closeCount.Load() == 1
}, time.Second, 10*time.Millisecond)
_, err := connector.Get(context.Background())
require.NoError(t, err)
require.EqualValues(t, 2, dialCount.Load())
}
func TestConnectorWaiterDoesNotStartNewDialBeforeCanceledDialCompletes(t *testing.T) {
t.Parallel()
var (
dialCount atomic.Int32
closeCount atomic.Int32
)
firstDialStarted := make(chan struct{}, 1)
secondDialStarted := make(chan struct{}, 1)
releaseFirstDial := make(chan struct{})
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
attempt := dialCount.Add(1)
switch attempt {
case 1:
select {
case firstDialStarted <- struct{}{}:
default:
}
<-releaseFirstDial
case 2:
select {
case secondDialStarted <- struct{}{}:
default:
}
}
return &testConnectorConnection{}, nil
}, ConnectorCallbacks[*testConnectorConnection]{
IsClosed: func(connection *testConnectorConnection) bool {
return false
},
Close: func(connection *testConnectorConnection) {
closeCount.Add(1)
},
Reset: func(connection *testConnectorConnection) {},
})
requestContext, cancel := context.WithCancel(context.Background())
firstResult := make(chan error, 1)
go func() {
_, err := connector.Get(requestContext)
firstResult <- err
}()
<-firstDialStarted
cancel()
secondResult := make(chan error, 1)
go func() {
_, err := connector.Get(context.Background())
secondResult <- err
}()
select {
case <-secondDialStarted:
t.Fatal("second dial started before first dial completed")
case <-time.After(100 * time.Millisecond):
}
select {
case err := <-firstResult:
require.ErrorIs(t, err, context.Canceled)
case <-time.After(time.Second):
t.Fatal("first Get did not return after request cancel")
}
close(releaseFirstDial)
require.Eventually(t, func() bool {
return closeCount.Load() == 1
}, time.Second, 10*time.Millisecond)
select {
case <-secondDialStarted:
case <-time.After(time.Second):
t.Fatal("second dial did not start after first dial completed")
}
err := <-secondResult
require.NoError(t, err)
require.EqualValues(t, 2, dialCount.Load())
}
func TestConnectorDialContextNotCanceledByRequestContextAfterDial(t *testing.T) {
t.Parallel()

View File

@@ -2,7 +2,11 @@
icon: material/alert-decagram
---
#### 1.14.0-alpha.2
#### 1.14.0-alpha.3
* Fixes and improvements
#### 1.13.3
* Add OpenWrt and Alpine APK packages to release **1**
* Backport to macOS 10.13 High Sierra **2**
@@ -26,7 +30,11 @@ from [SagerNet/go](https://github.com/SagerNet/go).
See [OCM](/configuration/service/ocm).
#### 1.13.3-beta.1
#### 1.12.24
* Fixes and improvements
#### 1.14.0-alpha.2
* Add OpenWrt and Alpine APK packages to release **1**
* Backport to macOS 10.13 High Sierra **2**

View File

@@ -38,7 +38,7 @@ icon: material/alert-decagram
!!! warning "与官方 Hysteria2 的区别"
官方程序支持一种名为 **userpass** 的验证方式,
本质上是将用户名与密码的组合 `<username>:<password>` 作为实际上的密码,而 sing-box 不提供此别名。
本质上是将用户名与密码的组合 `<username>:<password>` 作为实际上的密码,而 sing-box 不提供此别名。
要将 sing-box 与官方程序一起使用, 您需要填写该组合作为实际密码。
### 监听字段

View File

@@ -4,8 +4,8 @@ icon: material/new-box
!!! quote "Changes in sing-box 1.14.0"
:material-plus: [include_mac_address](#include_mac_address)
:material-plus: [exclude_mac_address](#exclude_mac_address)
:material-plus: [include_mac_address](#include_mac_address)
:material-plus: [exclude_mac_address](#exclude_mac_address)
!!! quote "Changes in sing-box 1.13.3"

View File

@@ -4,7 +4,7 @@ icon: material/new-box
!!! quote "sing-box 1.14.0 中的更改"
:material-plus: [include_mac_address](#include_mac_address)
:material-plus: [include_mac_address](#include_mac_address)
:material-plus: [exclude_mac_address](#exclude_mac_address)
!!! quote "sing-box 1.13.3 中的更改"

View File

@@ -38,7 +38,7 @@
!!! warning "与官方 Hysteria2 的区别"
官方程序支持一种名为 **userpass** 的验证方式,
本质上是将用户名与密码的组合 `<username>:<password>` 作为实际上的密码,而 sing-box 不提供此别名。
本质上是将用户名与密码的组合 `<username>:<password>` 作为实际上的密码,而 sing-box 不提供此别名。
要将 sing-box 与官方程序一起使用, 您需要填写该组合作为实际密码。
### 字段

View File

@@ -13,7 +13,10 @@ It handles OAuth authentication with Claude's API on your local machine while al
!!! quote "Changes in sing-box 1.14.0"
:material-plus: [credentials](#credentials)
:material-alert: [users](#users)
:material-alert: [credential_path](#credential_path)
:material-alert: [usages_path](#usages_path)
:material-alert: [users](#users)
:material-alert: [detour](#detour)
### Structure
@@ -51,6 +54,8 @@ On macOS, credentials are read from the system keychain first, then fall back to
Refreshed tokens are automatically written back to the same location.
!!! question "Since sing-box 1.14.0"
When `credential_path` points to a file, the service can start before the file exists. The credential becomes available automatically after the file is created or updated, and becomes unavailable immediately if the file is later removed or becomes invalid.
On macOS without an explicit `credential_path`, keychain changes are not watched. Automatic reload only applies to the credential file path.
@@ -65,7 +70,7 @@ List of credential configurations for multi-credential mode.
When set, top-level `credential_path`, `usages_path`, and `detour` are forbidden. Each user must specify a `credential` tag.
Each credential has a `type` field (`default`, `balancer`, or `fallback`) and a required `tag` field.
Each credential has a `type` field (`default`, `external`, or `balancer`) and a required `tag` field.
##### Default Credential
@@ -76,7 +81,9 @@ Each credential has a `type` field (`default`, `balancer`, or `fallback`) and a
"usages_path": "/path/to/usages.json",
"detour": "",
"reserve_5h": 20,
"reserve_weekly": 20
"reserve_weekly": 20,
"limit_5h": 0,
"limit_weekly": 0
}
```
@@ -85,8 +92,10 @@ A single OAuth credential file. The `type` field can be omitted (defaults to `de
- `credential_path`: Path to the credentials file. Same defaults as top-level `credential_path`.
- `usages_path`: Optional usage tracking file for this credential.
- `detour`: Outbound tag for connecting to the Claude API with this credential.
- `reserve_5h`: Reserve threshold (1-99) for 5-hour window. Credential pauses at (100-N)% utilization.
- `reserve_weekly`: Reserve threshold (1-99) for weekly window. Credential pauses at (100-N)% utilization.
- `reserve_5h`: Reserve threshold (1-99) for 5-hour window. Credential pauses at (100-N)% utilization. Conflict with `limit_5h`.
- `reserve_weekly`: Reserve threshold (1-99) for weekly window. Credential pauses at (100-N)% utilization. Conflict with `limit_weekly`.
- `limit_5h`: Explicit utilization cap (0-100) for 5-hour window. `0` means unset. Credential pauses when utilization reaches this value. Conflict with `reserve_5h`.
- `limit_weekly`: Explicit utilization cap (0-100) for weekly window. `0` means unset. Credential pauses when utilization reaches this value. Conflict with `reserve_weekly`.
##### Balancer Credential
@@ -95,32 +104,55 @@ A single OAuth credential file. The `type` field can be omitted (defaults to `de
"tag": "pool",
"type": "balancer",
"strategy": "",
"credentials": ["a", "b"],
"poll_interval": "60s"
"credentials": ["a", "b"]
}
```
Assigns sessions to default credentials based on the selected strategy. Sessions are sticky until the assigned credential hits a rate limit.
- `strategy`: Selection strategy. One of `least_used` `round_robin` `random`. `least_used` will be used by default.
- `strategy`: Selection strategy. One of `least_used` `round_robin` `random` `fallback`. `least_used` will be used by default.
- `credentials`: ==Required== List of default credential tags.
- `poll_interval`: How often to poll upstream usage API. Default `60s`.
##### Fallback Credential
##### Fallback Strategy
```json
{
"tag": "backup",
"type": "fallback",
"credentials": ["a", "b"],
"poll_interval": "30s"
"type": "balancer",
"strategy": "fallback",
"credentials": ["a", "b"]
}
```
Uses credentials in order. Falls through to the next when the current one is exhausted.
A balancer with `strategy: "fallback"` uses credentials in order. It falls through to the next when the current one is exhausted.
- `credentials`: ==Required== Ordered list of default credential tags.
- `poll_interval`: How often to poll upstream usage API. Default `60s`.
##### External Credential
```json
{
"tag": "remote",
"type": "external",
"url": "",
"server": "",
"server_port": 0,
"token": "",
"reverse": false,
"detour": "",
"usages_path": ""
}
```
Proxies requests through a remote CCM instance instead of using a local OAuth credential.
- `url`: URL of the remote CCM instance. Omit to create a receiver that only waits for inbound reverse connections.
- `server`: Override server address for dialing, separate from URL hostname.
- `server_port`: Override server port for dialing.
- `token`: ==Required== Authentication token for the remote instance.
- `reverse`: Enable connector mode. Requires `url`. A connector dials out to `/ccm/v1/reverse` on the remote instance and cannot serve local requests directly. When `url` is set without `reverse`, the credential proxies requests through the remote instance normally and prefers an established reverse connection when one is available.
- `detour`: Outbound tag for connecting to the remote instance.
- `usages_path`: Optional usage tracking file.
#### usages_path
@@ -137,6 +169,8 @@ Statistics are organized by model, context window (200k standard vs 1M premium),
The statistics file is automatically saved every minute and upon service shutdown.
!!! question "Since sing-box 1.14.0"
Conflict with `credentials`. In multi-credential mode, use `usages_path` on individual default credentials.
#### users
@@ -151,7 +185,9 @@ Object format:
{
"name": "",
"token": "",
"credential": ""
"credential": "",
"external_credential": "",
"allow_external_usage": false
}
```
@@ -159,7 +195,12 @@ Object fields:
- `name`: Username identifier for tracking purposes.
- `token`: Bearer token for authentication. Claude Code authenticates by setting the `ANTHROPIC_AUTH_TOKEN` environment variable to their token value.
!!! question "Since sing-box 1.14.0"
- `credential`: Credential tag to use for this user. ==Required== when `credentials` is set.
- `external_credential`: Tag of an external credential used only to rewrite response rate-limit headers with aggregated utilization from this user's other available credentials. It does not control request routing; request selection still comes from `credential` and `allow_external_usage`.
- `allow_external_usage`: Allow this user to use external credentials. `false` by default.
#### headers
@@ -171,6 +212,8 @@ These headers will override any existing headers with the same name.
Outbound tag for connecting to the Claude API.
!!! question "Since sing-box 1.14.0"
Conflict with `credentials`. In multi-credential mode, use `detour` on individual default credentials.
#### tls
@@ -241,7 +284,6 @@ claude
{
"tag": "pool",
"type": "balancer",
"poll_interval": "60s",
"credentials": ["a", "b"]
}
],

View File

@@ -13,7 +13,10 @@ CCMClaude Code 多路复用器)服务是一个多路复用服务,允许
!!! quote "sing-box 1.14.0 中的更改"
:material-plus: [credentials](#credentials)
:material-alert: [users](#users)
:material-alert: [credential_path](#credential_path)
:material-alert: [usages_path](#usages_path)
:material-alert: [users](#users)
:material-alert: [detour](#detour)
### 结构
@@ -51,6 +54,8 @@ Claude Code OAuth 凭据文件的路径。
刷新的令牌会自动写回相同位置。
!!! question "自 sing-box 1.14.0 起"
`credential_path` 指向文件时,即使文件尚不存在,服务也可以启动。文件被创建或更新后,凭据会自动变为可用;如果文件之后被删除或变为无效,该凭据会立即变为不可用。
在 macOS 上如果未显式设置 `credential_path`,不会监听钥匙串变化。自动重载只作用于凭据文件路径。
@@ -65,7 +70,7 @@ Claude Code OAuth 凭据文件的路径。
设置后,顶层 `credential_path``usages_path``detour` 被禁止。每个用户必须指定 `credential` 标签。
每个凭据有一个 `type` 字段(`default``balancer``fallback`)和一个必填的 `tag` 字段。
每个凭据有一个 `type` 字段(`default``external``balancer`)和一个必填的 `tag` 字段。
##### 默认凭据
@@ -76,7 +81,9 @@ Claude Code OAuth 凭据文件的路径。
"usages_path": "/path/to/usages.json",
"detour": "",
"reserve_5h": 20,
"reserve_weekly": 20
"reserve_weekly": 20,
"limit_5h": 0,
"limit_weekly": 0
}
```
@@ -85,8 +92,10 @@ Claude Code OAuth 凭据文件的路径。
- `credential_path`:凭据文件的路径。默认值与顶层 `credential_path` 相同。
- `usages_path`:此凭据的可选使用跟踪文件。
- `detour`:此凭据用于连接 Claude API 的出站标签。
- `reserve_5h`5 小时窗口的保留阈值1-99。凭据在利用率达到 (100-N)% 时暂停。
- `reserve_weekly`每周窗口的保留阈值1-99。凭据在利用率达到 (100-N)% 时暂停。
- `reserve_5h`5 小时窗口的保留阈值1-99。凭据在利用率达到 (100-N)% 时暂停。`limit_5h` 冲突。
- `reserve_weekly`每周窗口的保留阈值1-99。凭据在利用率达到 (100-N)% 时暂停。`limit_weekly` 冲突。
- `limit_5h`5 小时窗口的显式利用率上限0-100`0` 表示未设置显式上限。凭据在利用率达到此值时暂停。与 `reserve_5h` 冲突。
- `limit_weekly`每周窗口的显式利用率上限0-100`0` 表示未设置显式上限。凭据在利用率达到此值时暂停。与 `reserve_weekly` 冲突。
##### 均衡凭据
@@ -95,32 +104,55 @@ Claude Code OAuth 凭据文件的路径。
"tag": "pool",
"type": "balancer",
"strategy": "",
"credentials": ["a", "b"],
"poll_interval": "60s"
"credentials": ["a", "b"]
}
```
根据选择的策略将会话分配给默认凭据。会话保持粘性,直到分配的凭据触发速率限制。
- `strategy`:选择策略。可选值:`least_used` `round_robin` `random`。默认使用 `least_used`
- `strategy`:选择策略。可选值:`least_used` `round_robin` `random` `fallback`。默认使用 `least_used`
- `credentials`==必填== 默认凭据标签列表。
- `poll_interval`:轮询上游使用 API 的间隔。默认 `60s`
##### 回退凭据
##### 回退策略
```json
{
"tag": "backup",
"type": "fallback",
"credentials": ["a", "b"],
"poll_interval": "30s"
"type": "balancer",
"strategy": "fallback",
"credentials": ["a", "b"]
}
```
按顺序使用凭据。当前凭据耗尽后切换到下一个。
`strategy` 设为 `fallback` 的均衡凭据会按顺序使用凭据。当前凭据耗尽后切换到下一个。
- `credentials`==必填== 有序的默认凭据标签列表。
- `poll_interval`:轮询上游使用 API 的间隔。默认 `60s`
##### 外部凭据
```json
{
"tag": "remote",
"type": "external",
"url": "",
"server": "",
"server_port": 0,
"token": "",
"reverse": false,
"detour": "",
"usages_path": ""
}
```
通过远程 CCM 实例代理请求,而非使用本地 OAuth 凭据。
- `url`:远程 CCM 实例的 URL。省略时此凭据作为仅等待入站反向连接的接收器。
- `server`:覆盖拨号的服务器地址,与 URL 主机名分开。
- `server_port`:覆盖拨号的服务器端口。
- `token`==必填== 远程实例的身份验证令牌。
- `reverse`:启用连接器模式。要求设置 `url`。启用后,此凭据会主动拨出到远程实例的 `/ccm/v1/reverse`,且不能直接为本地请求提供服务。当设置了 `url` 但未启用 `reverse` 时,此凭据会正常通过远程实例转发请求,并在反向连接建立后优先使用该反向连接。
- `detour`:用于连接远程实例的出站标签。
- `usages_path`:可选的使用跟踪文件。
#### usages_path
@@ -137,6 +169,8 @@ Claude Code OAuth 凭据文件的路径。
统计文件每分钟自动保存一次,并在服务关闭时保存。
!!! question "自 sing-box 1.14.0 起"
`credentials` 冲突。在多凭据模式下,在各个默认凭据上使用 `usages_path`
#### users
@@ -151,7 +185,9 @@ Claude Code OAuth 凭据文件的路径。
{
"name": "",
"token": "",
"credential": ""
"credential": "",
"external_credential": "",
"allow_external_usage": false
}
```
@@ -159,7 +195,12 @@ Claude Code OAuth 凭据文件的路径。
- `name`:用于跟踪的用户名标识符。
- `token`:用于身份验证的 Bearer 令牌。Claude Code 通过设置 `ANTHROPIC_AUTH_TOKEN` 环境变量为其令牌值进行身份验证。
!!! question "自 sing-box 1.14.0 起"
- `credential`:此用户使用的凭据标签。设置 `credentials` 时==必填==。
- `external_credential`:仅用于用此用户其他可用凭据的聚合利用率重写响应速率限制头的外部凭据标签。它不参与请求路由;请求选择仍由 `credential``allow_external_usage` 决定。
- `allow_external_usage`:允许此用户使用外部凭据。默认为 `false`
#### headers
@@ -171,6 +212,8 @@ Claude Code OAuth 凭据文件的路径。
用于连接 Claude API 的出站标签。
!!! question "自 sing-box 1.14.0 起"
`credentials` 冲突。在多凭据模式下,在各个默认凭据上使用 `detour`
#### tls
@@ -241,7 +284,6 @@ claude
{
"tag": "pool",
"type": "balancer",
"poll_interval": "60s",
"credentials": ["a", "b"]
}
],

View File

@@ -13,7 +13,10 @@ It handles OAuth authentication with OpenAI's API on your local machine while al
!!! quote "Changes in sing-box 1.14.0"
:material-plus: [credentials](#credentials)
:material-alert: [users](#users)
:material-alert: [credential_path](#credential_path)
:material-alert: [usages_path](#usages_path)
:material-alert: [users](#users)
:material-alert: [detour](#detour)
### Structure
@@ -49,6 +52,8 @@ If not specified, defaults to:
Refreshed tokens are automatically written back to the same location.
!!! question "Since sing-box 1.14.0"
When `credential_path` points to a file, the service can start before the file exists. The credential becomes available automatically after the file is created or updated, and becomes unavailable immediately if the file is later removed or becomes invalid.
Conflict with `credentials`.
@@ -61,7 +66,7 @@ List of credential configurations for multi-credential mode.
When set, top-level `credential_path`, `usages_path`, and `detour` are forbidden. Each user must specify a `credential` tag.
Each credential has a `type` field (`default`, `balancer`, or `fallback`) and a required `tag` field.
Each credential has a `type` field (`default`, `external`, or `balancer`) and a required `tag` field.
##### Default Credential
@@ -72,7 +77,9 @@ Each credential has a `type` field (`default`, `balancer`, or `fallback`) and a
"usages_path": "/path/to/usages.json",
"detour": "",
"reserve_5h": 20,
"reserve_weekly": 20
"reserve_weekly": 20,
"limit_5h": 0,
"limit_weekly": 0
}
```
@@ -81,8 +88,10 @@ A single OAuth credential file. The `type` field can be omitted (defaults to `de
- `credential_path`: Path to the credentials file. Same defaults as top-level `credential_path`.
- `usages_path`: Optional usage tracking file for this credential.
- `detour`: Outbound tag for connecting to the OpenAI API with this credential.
- `reserve_5h`: Reserve threshold (1-99) for primary rate limit window. Credential pauses at (100-N)% utilization.
- `reserve_weekly`: Reserve threshold (1-99) for secondary (weekly) rate limit window. Credential pauses at (100-N)% utilization.
- `reserve_5h`: Reserve threshold (1-99) for primary rate limit window. Credential pauses at (100-N)% utilization. Conflict with `limit_5h`.
- `reserve_weekly`: Reserve threshold (1-99) for secondary (weekly) rate limit window. Credential pauses at (100-N)% utilization. Conflict with `limit_weekly`.
- `limit_5h`: Explicit utilization cap (0-100) for primary rate limit window. `0` means unset. Credential pauses when utilization reaches this value. Conflict with `reserve_5h`.
- `limit_weekly`: Explicit utilization cap (0-100) for secondary (weekly) rate limit window. `0` means unset. Credential pauses when utilization reaches this value. Conflict with `reserve_weekly`.
##### Balancer Credential
@@ -91,32 +100,55 @@ A single OAuth credential file. The `type` field can be omitted (defaults to `de
"tag": "pool",
"type": "balancer",
"strategy": "",
"credentials": ["a", "b"],
"poll_interval": "60s"
"credentials": ["a", "b"]
}
```
Assigns sessions to default credentials based on the selected strategy. Sessions are sticky until the assigned credential hits a rate limit.
- `strategy`: Selection strategy. One of `least_used` `round_robin` `random`. `least_used` will be used by default.
- `strategy`: Selection strategy. One of `least_used` `round_robin` `random` `fallback`. `least_used` will be used by default.
- `credentials`: ==Required== List of default credential tags.
- `poll_interval`: How often to poll upstream usage API. Default `60s`.
##### Fallback Credential
##### Fallback Strategy
```json
{
"tag": "backup",
"type": "fallback",
"credentials": ["a", "b"],
"poll_interval": "30s"
"type": "balancer",
"strategy": "fallback",
"credentials": ["a", "b"]
}
```
Uses credentials in order. Falls through to the next when the current one is exhausted.
A balancer with `strategy: "fallback"` uses credentials in order. It falls through to the next when the current one is exhausted.
- `credentials`: ==Required== Ordered list of default credential tags.
- `poll_interval`: How often to poll upstream usage API. Default `60s`.
##### External Credential
```json
{
"tag": "remote",
"type": "external",
"url": "",
"server": "",
"server_port": 0,
"token": "",
"reverse": false,
"detour": "",
"usages_path": ""
}
```
Proxies requests through a remote OCM instance instead of using a local OAuth credential.
- `url`: URL of the remote OCM instance. Omit to create a receiver that only waits for inbound reverse connections.
- `server`: Override server address for dialing, separate from URL hostname.
- `server_port`: Override server port for dialing.
- `token`: ==Required== Authentication token for the remote instance.
- `reverse`: Enable connector mode. Requires `url`. A connector dials out to `/ocm/v1/reverse` on the remote instance and cannot serve local requests directly. When `url` is set without `reverse`, the credential proxies requests through the remote instance normally and prefers an established reverse connection when one is available.
- `detour`: Outbound tag for connecting to the remote instance.
- `usages_path`: Optional usage tracking file.
#### usages_path
@@ -133,6 +165,8 @@ Statistics are organized by model and optionally by user when authentication is
The statistics file is automatically saved every minute and upon service shutdown.
!!! question "Since sing-box 1.14.0"
Conflict with `credentials`. In multi-credential mode, use `usages_path` on individual default credentials.
#### users
@@ -147,7 +181,9 @@ Object format:
{
"name": "",
"token": "",
"credential": ""
"credential": "",
"external_credential": "",
"allow_external_usage": false
}
```
@@ -155,7 +191,12 @@ Object fields:
- `name`: Username identifier for tracking purposes.
- `token`: Bearer token for authentication. Clients authenticate by setting the `Authorization: Bearer <token>` header.
!!! question "Since sing-box 1.14.0"
- `credential`: Credential tag to use for this user. ==Required== when `credentials` is set.
- `external_credential`: Tag of an external credential used only to rewrite response rate-limit headers with aggregated utilization from this user's other available credentials. It does not control request routing; request selection still comes from `credential` and `allow_external_usage`.
- `allow_external_usage`: Allow this user to use external credentials. `false` by default.
#### headers
@@ -167,6 +208,8 @@ These headers will override any existing headers with the same name.
Outbound tag for connecting to the OpenAI API.
!!! question "Since sing-box 1.14.0"
Conflict with `credentials`. In multi-credential mode, use `detour` on individual default credentials.
#### tls
@@ -293,7 +336,6 @@ codex --profile ocm
{
"tag": "pool",
"type": "balancer",
"poll_interval": "60s",
"credentials": ["a", "b"]
}
],

View File

@@ -13,7 +13,10 @@ OCMOpenAI Codex 多路复用器)服务是一个多路复用服务,允许
!!! quote "sing-box 1.14.0 中的更改"
:material-plus: [credentials](#credentials)
:material-alert: [users](#users)
:material-alert: [credential_path](#credential_path)
:material-alert: [usages_path](#usages_path)
:material-alert: [users](#users)
:material-alert: [detour](#detour)
### 结构
@@ -49,6 +52,8 @@ OpenAI OAuth 凭据文件的路径。
刷新的令牌会自动写回相同位置。
!!! question "自 sing-box 1.14.0 起"
`credential_path` 指向文件时,即使文件尚不存在,服务也可以启动。文件被创建或更新后,凭据会自动变为可用;如果文件之后被删除或变为无效,该凭据会立即变为不可用。
`credentials` 冲突。
@@ -61,7 +66,7 @@ OpenAI OAuth 凭据文件的路径。
设置后,顶层 `credential_path``usages_path``detour` 被禁止。每个用户必须指定 `credential` 标签。
每个凭据有一个 `type` 字段(`default``balancer``fallback`)和一个必填的 `tag` 字段。
每个凭据有一个 `type` 字段(`default``external``balancer`)和一个必填的 `tag` 字段。
##### 默认凭据
@@ -72,7 +77,9 @@ OpenAI OAuth 凭据文件的路径。
"usages_path": "/path/to/usages.json",
"detour": "",
"reserve_5h": 20,
"reserve_weekly": 20
"reserve_weekly": 20,
"limit_5h": 0,
"limit_weekly": 0
}
```
@@ -81,8 +88,10 @@ OpenAI OAuth 凭据文件的路径。
- `credential_path`:凭据文件的路径。默认值与顶层 `credential_path` 相同。
- `usages_path`:此凭据的可选使用跟踪文件。
- `detour`:此凭据用于连接 OpenAI API 的出站标签。
- `reserve_5h`主要速率限制窗口的保留阈值1-99。凭据在利用率达到 (100-N)% 时暂停。
- `reserve_weekly`次要每周速率限制窗口的保留阈值1-99。凭据在利用率达到 (100-N)% 时暂停。
- `reserve_5h`主要速率限制窗口的保留阈值1-99。凭据在利用率达到 (100-N)% 时暂停。`limit_5h` 冲突。
- `reserve_weekly`次要每周速率限制窗口的保留阈值1-99。凭据在利用率达到 (100-N)% 时暂停。`limit_weekly` 冲突。
- `limit_5h`主要速率限制窗口的显式利用率上限0-100`0` 表示未设置显式上限。凭据在利用率达到此值时暂停。与 `reserve_5h` 冲突。
- `limit_weekly`次要每周速率限制窗口的显式利用率上限0-100`0` 表示未设置显式上限。凭据在利用率达到此值时暂停。与 `reserve_weekly` 冲突。
##### 均衡凭据
@@ -91,32 +100,55 @@ OpenAI OAuth 凭据文件的路径。
"tag": "pool",
"type": "balancer",
"strategy": "",
"credentials": ["a", "b"],
"poll_interval": "60s"
"credentials": ["a", "b"]
}
```
根据选择的策略将会话分配给默认凭据。会话保持粘性,直到分配的凭据触发速率限制。
- `strategy`:选择策略。可选值:`least_used` `round_robin` `random`。默认使用 `least_used`
- `strategy`:选择策略。可选值:`least_used` `round_robin` `random` `fallback`。默认使用 `least_used`
- `credentials`==必填== 默认凭据标签列表。
- `poll_interval`:轮询上游使用 API 的间隔。默认 `60s`
##### 回退凭据
##### 回退策略
```json
{
"tag": "backup",
"type": "fallback",
"credentials": ["a", "b"],
"poll_interval": "30s"
"type": "balancer",
"strategy": "fallback",
"credentials": ["a", "b"]
}
```
按顺序使用凭据。当前凭据耗尽后切换到下一个。
`strategy` 设为 `fallback` 的均衡凭据会按顺序使用凭据。当前凭据耗尽后切换到下一个。
- `credentials`==必填== 有序的默认凭据标签列表。
- `poll_interval`:轮询上游使用 API 的间隔。默认 `60s`
##### 外部凭据
```json
{
"tag": "remote",
"type": "external",
"url": "",
"server": "",
"server_port": 0,
"token": "",
"reverse": false,
"detour": "",
"usages_path": ""
}
```
通过远程 OCM 实例代理请求,而非使用本地 OAuth 凭据。
- `url`:远程 OCM 实例的 URL。省略时此凭据作为仅等待入站反向连接的接收器。
- `server`:覆盖拨号的服务器地址,与 URL 主机名分开。
- `server_port`:覆盖拨号的服务器端口。
- `token`==必填== 远程实例的身份验证令牌。
- `reverse`:启用连接器模式。要求设置 `url`。启用后,此凭据会主动拨出到远程实例的 `/ocm/v1/reverse`,且不能直接为本地请求提供服务。当设置了 `url` 但未启用 `reverse` 时,此凭据会正常通过远程实例转发请求,并在反向连接建立后优先使用该反向连接。
- `detour`:用于连接远程实例的出站标签。
- `usages_path`:可选的使用跟踪文件。
#### usages_path
@@ -133,6 +165,8 @@ OpenAI OAuth 凭据文件的路径。
统计文件每分钟自动保存一次,并在服务关闭时保存。
!!! question "自 sing-box 1.14.0 起"
`credentials` 冲突。在多凭据模式下,在各个默认凭据上使用 `usages_path`
#### users
@@ -147,7 +181,9 @@ OpenAI OAuth 凭据文件的路径。
{
"name": "",
"token": "",
"credential": ""
"credential": "",
"external_credential": "",
"allow_external_usage": false
}
```
@@ -155,7 +191,12 @@ OpenAI OAuth 凭据文件的路径。
- `name`:用于跟踪的用户名标识符。
- `token`:用于身份验证的 Bearer 令牌。客户端通过设置 `Authorization: Bearer <token>` 头进行身份验证。
!!! question "自 sing-box 1.14.0 起"
- `credential`:此用户使用的凭据标签。设置 `credentials` 时==必填==。
- `external_credential`:仅用于用此用户其他可用凭据的聚合利用率重写响应速率限制头的外部凭据标签。它不参与请求路由;请求选择仍由 `credential``allow_external_usage` 决定。
- `allow_external_usage`:允许此用户使用外部凭据。默认为 `false`
#### headers
@@ -167,6 +208,8 @@ OpenAI OAuth 凭据文件的路径。
用于连接 OpenAI API 的出站标签。
!!! question "自 sing-box 1.14.0 起"
`credentials` 冲突。在多凭据模式下,在各个默认凭据上使用 `detour`
#### tls
@@ -294,7 +337,6 @@ codex --profile ocm
{
"tag": "pool",
"type": "balancer",
"poll_interval": "60s",
"credentials": ["a", "b"]
}
],

2
go.mod
View File

@@ -41,7 +41,7 @@ require (
github.com/sagernet/sing-shadowsocks v0.2.8
github.com/sagernet/sing-shadowsocks2 v0.2.1
github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11
github.com/sagernet/sing-tun v0.8.3-0.20260311132553-5485872f601f
github.com/sagernet/sing-tun v0.8.4-0.20260315091454-bbe21100c226
github.com/sagernet/sing-vmess v0.2.8-0.20250909125414-3aed155119a1
github.com/sagernet/smux v1.5.50-sing-box-mod.1
github.com/sagernet/tailscale v1.92.4-sing-box-1.13-mod.6.0.20260311131347-f88b27eeb76e

4
go.sum
View File

@@ -248,8 +248,8 @@ github.com/sagernet/sing-shadowsocks2 v0.2.1 h1:dWV9OXCeFPuYGHb6IRqlSptVnSzOelnq
github.com/sagernet/sing-shadowsocks2 v0.2.1/go.mod h1:RnXS0lExcDAovvDeniJ4IKa2IuChrdipolPYWBv9hWQ=
github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11 h1:tK+75l64tm9WvEFrYRE1t0YxoFdWQqw/h7Uhzj0vJ+w=
github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11/go.mod h1:sWqKnGlMipCHaGsw1sTTlimyUpgzP4WP3pjhCsYt9oA=
github.com/sagernet/sing-tun v0.8.3-0.20260311132553-5485872f601f h1:uj3rzedphq1AiL0PpuVoob5RtKsPBcMRd8aqo+q0rqA=
github.com/sagernet/sing-tun v0.8.3-0.20260311132553-5485872f601f/go.mod h1:pLCo4o+LacXEzz0bhwhJkKBjLlKOGPBNOAZ97ZVZWzs=
github.com/sagernet/sing-tun v0.8.4-0.20260315091454-bbe21100c226 h1:Shy/fsm+pqVq6OkBAWPaOmOiPT/AwoRxQLiV1357Y0Y=
github.com/sagernet/sing-tun v0.8.4-0.20260315091454-bbe21100c226/go.mod h1:pLCo4o+LacXEzz0bhwhJkKBjLlKOGPBNOAZ97ZVZWzs=
github.com/sagernet/sing-vmess v0.2.8-0.20250909125414-3aed155119a1 h1:aSwUNYUkVyVvdmBSufR8/nRFonwJeKSIROxHcm5br9o=
github.com/sagernet/sing-vmess v0.2.8-0.20250909125414-3aed155119a1/go.mod h1:P11scgTxMxVVQ8dlM27yNm3Cro40mD0+gHbnqrNGDuY=
github.com/sagernet/smux v1.5.50-sing-box-mod.1 h1:XkJcivBC9V4wBjiGXIXZ229aZCU1hzcbp6kSkkyQ478=

View File

@@ -32,7 +32,6 @@ type _CCMCredential struct {
DefaultOptions CCMDefaultCredentialOptions `json:"-"`
ExternalOptions CCMExternalCredentialOptions `json:"-"`
BalancerOptions CCMBalancerCredentialOptions `json:"-"`
FallbackOptions CCMFallbackCredentialOptions `json:"-"`
}
type CCMCredential _CCMCredential
@@ -47,8 +46,6 @@ func (c CCMCredential) MarshalJSON() ([]byte, error) {
v = c.ExternalOptions
case "balancer":
v = c.BalancerOptions
case "fallback":
v = c.FallbackOptions
default:
return nil, E.New("unknown credential type: ", c.Type)
}
@@ -72,8 +69,6 @@ func (c *CCMCredential) UnmarshalJSON(bytes []byte) error {
v = &c.ExternalOptions
case "balancer":
v = &c.BalancerOptions
case "fallback":
v = &c.FallbackOptions
default:
return E.New("unknown credential type: ", c.Type)
}
@@ -81,32 +76,27 @@ func (c *CCMCredential) UnmarshalJSON(bytes []byte) error {
}
type CCMDefaultCredentialOptions struct {
CredentialPath string `json:"credential_path,omitempty"`
UsagesPath string `json:"usages_path,omitempty"`
Detour string `json:"detour,omitempty"`
Reserve5h uint8 `json:"reserve_5h"`
ReserveWeekly uint8 `json:"reserve_weekly"`
Limit5h uint8 `json:"limit_5h,omitempty"`
LimitWeekly uint8 `json:"limit_weekly,omitempty"`
CredentialPath string `json:"credential_path,omitempty"`
ClaudeDirectory string `json:"claude_directory,omitempty"`
UsagesPath string `json:"usages_path,omitempty"`
Detour string `json:"detour,omitempty"`
Reserve5h uint8 `json:"reserve_5h"`
ReserveWeekly uint8 `json:"reserve_weekly"`
Limit5h uint8 `json:"limit_5h,omitempty"`
LimitWeekly uint8 `json:"limit_weekly,omitempty"`
}
type CCMBalancerCredentialOptions struct {
Strategy string `json:"strategy,omitempty"`
Credentials badoption.Listable[string] `json:"credentials"`
PollInterval badoption.Duration `json:"poll_interval,omitempty"`
Strategy string `json:"strategy,omitempty"`
Credentials badoption.Listable[string] `json:"credentials"`
RebalanceThreshold float64 `json:"rebalance_threshold,omitempty"`
}
type CCMExternalCredentialOptions struct {
URL string `json:"url,omitempty"`
ServerOptions
Token string `json:"token"`
Reverse bool `json:"reverse,omitempty"`
Detour string `json:"detour,omitempty"`
UsagesPath string `json:"usages_path,omitempty"`
PollInterval badoption.Duration `json:"poll_interval,omitempty"`
}
type CCMFallbackCredentialOptions struct {
Credentials badoption.Listable[string] `json:"credentials"`
PollInterval badoption.Duration `json:"poll_interval,omitempty"`
Token string `json:"token"`
Reverse bool `json:"reverse,omitempty"`
Detour string `json:"detour,omitempty"`
UsagesPath string `json:"usages_path,omitempty"`
}

View File

@@ -32,7 +32,6 @@ type _OCMCredential struct {
DefaultOptions OCMDefaultCredentialOptions `json:"-"`
ExternalOptions OCMExternalCredentialOptions `json:"-"`
BalancerOptions OCMBalancerCredentialOptions `json:"-"`
FallbackOptions OCMFallbackCredentialOptions `json:"-"`
}
type OCMCredential _OCMCredential
@@ -47,8 +46,6 @@ func (c OCMCredential) MarshalJSON() ([]byte, error) {
v = c.ExternalOptions
case "balancer":
v = c.BalancerOptions
case "fallback":
v = c.FallbackOptions
default:
return nil, E.New("unknown credential type: ", c.Type)
}
@@ -72,8 +69,6 @@ func (c *OCMCredential) UnmarshalJSON(bytes []byte) error {
v = &c.ExternalOptions
case "balancer":
v = &c.BalancerOptions
case "fallback":
v = &c.FallbackOptions
default:
return E.New("unknown credential type: ", c.Type)
}
@@ -91,22 +86,16 @@ type OCMDefaultCredentialOptions struct {
}
type OCMBalancerCredentialOptions struct {
Strategy string `json:"strategy,omitempty"`
Credentials badoption.Listable[string] `json:"credentials"`
PollInterval badoption.Duration `json:"poll_interval,omitempty"`
Strategy string `json:"strategy,omitempty"`
Credentials badoption.Listable[string] `json:"credentials"`
RebalanceThreshold float64 `json:"rebalance_threshold,omitempty"`
}
type OCMExternalCredentialOptions struct {
URL string `json:"url,omitempty"`
ServerOptions
Token string `json:"token"`
Reverse bool `json:"reverse,omitempty"`
Detour string `json:"detour,omitempty"`
UsagesPath string `json:"usages_path,omitempty"`
PollInterval badoption.Duration `json:"poll_interval,omitempty"`
}
type OCMFallbackCredentialOptions struct {
Credentials badoption.Listable[string] `json:"credentials"`
PollInterval badoption.Duration `json:"poll_interval,omitempty"`
Token string `json:"token"`
Reverse bool `json:"reverse,omitempty"`
Detour string `json:"detour,omitempty"`
UsagesPath string `json:"usages_path,omitempty"`
}

View File

@@ -2,8 +2,16 @@
set -e -o pipefail
go_version=$(curl -s https://raw.githubusercontent.com/actions/go-versions/main/versions-manifest.json | grep -oE '"version": "[0-9]{1}.[0-9]{1,}(.[0-9]{1,})?"' | head -1 | cut -d':' -f2 | sed 's/ //g; s/"//g')
curl -Lo go.tar.gz "https://go.dev/dl/go$go_version.linux-amd64.tar.gz"
manifest=$(curl -fS 'https://go.dev/VERSION?m=text')
go_version=$(echo "$manifest" | head -1 | sed 's/^go//')
os=$(uname -s | tr '[:upper:]' '[:lower:]')
arch=$(uname -m)
case "$arch" in
x86_64) arch="amd64" ;;
aarch64|arm64) arch="arm64" ;;
esac
curl -Lo go.tar.gz "https://go.dev/dl/go$go_version.$os-$arch.tar.gz"
sudo rm -rf /usr/local/go
sudo tar -C /usr/local -xzf go.tar.gz
rm go.tar.gz
echo "Installed Go $go_version"

13
service/ccm/CLAUDE.md Normal file
View File

@@ -0,0 +1,13 @@
# Claude Code Multiplexer
### Reverse Claude Code
Claude distributes a huge binary by default in a Bun, which is difficult to reverse engineer (and is very likely the one the user have installed now).
You must obtain the npm version of the Claude Code js source code:
Example:
```bash
cd /tmp && npm pack @anthropic-ai/claude-code && tar xzf anthropic-ai-claude-code-*.tgz && npx prettier --write package/cli.js
```

View File

@@ -1,224 +1,258 @@
package ccm
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"os"
"os/user"
"path/filepath"
"runtime"
"slices"
"strconv"
"sync"
"time"
"github.com/sagernet/sing-box/log"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/observable"
)
const (
oauth2ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
oauth2TokenURL = "https://platform.claude.com/v1/oauth/token"
claudeAPIBaseURL = "https://api.anthropic.com"
tokenRefreshBufferMs = 60000
anthropicBetaOAuthValue = "oauth-2025-04-20"
defaultPollInterval = 60 * time.Minute
failedPollRetryInterval = time.Minute
httpRetryMaxBackoff = 5 * time.Minute
)
const ccmUserAgentFallback = "claude-code/2.1.72"
var (
ccmUserAgentOnce sync.Once
ccmUserAgentValue string
const (
httpRetryMaxAttempts = 3
httpRetryInitialDelay = 200 * time.Millisecond
)
func initCCMUserAgent(logger log.ContextLogger) {
ccmUserAgentOnce.Do(func() {
version, err := detectClaudeCodeVersion()
if err != nil {
logger.Error("detect Claude Code version: ", err)
ccmUserAgentValue = ccmUserAgentFallback
return
const sessionExpiry = 24 * time.Hour
func doHTTPWithRetry(ctx context.Context, client *http.Client, buildRequest func() (*http.Request, error)) (*http.Response, error) {
var lastError error
for attempt := range httpRetryMaxAttempts {
if attempt > 0 {
delay := httpRetryInitialDelay * time.Duration(1<<(attempt-1))
timer := time.NewTimer(delay)
select {
case <-ctx.Done():
timer.Stop()
return nil, lastError
case <-timer.C:
}
}
logger.Debug("detected Claude Code version: ", version)
ccmUserAgentValue = "claude-code/" + version
})
}
func detectClaudeCodeVersion() (string, error) {
userInfo, err := getRealUser()
if err != nil {
return "", E.Cause(err, "get user")
}
binaryName := "claude"
if runtime.GOOS == "windows" {
binaryName = "claude.exe"
}
linkPath := filepath.Join(userInfo.HomeDir, ".local", "bin", binaryName)
target, err := os.Readlink(linkPath)
if err != nil {
return "", E.Cause(err, "readlink ", linkPath)
}
if !filepath.IsAbs(target) {
target = filepath.Join(filepath.Dir(linkPath), target)
}
parent := filepath.Base(filepath.Dir(target))
if parent != "versions" {
return "", E.New("unexpected symlink target: ", target)
}
return filepath.Base(target), nil
}
func getRealUser() (*user.User, error) {
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
sudoUserInfo, err := user.Lookup(sudoUser)
if err == nil {
return sudoUserInfo, nil
}
}
return user.Current()
}
func getDefaultCredentialsPath() (string, error) {
if configDir := os.Getenv("CLAUDE_CONFIG_DIR"); configDir != "" {
return filepath.Join(configDir, ".credentials.json"), nil
}
userInfo, err := getRealUser()
if err != nil {
return "", err
}
return filepath.Join(userInfo.HomeDir, ".claude", ".credentials.json"), nil
}
func readCredentialsFromFile(path string) (*oauthCredentials, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var credentialsContainer struct {
ClaudeAIAuth *oauthCredentials `json:"claudeAiOauth,omitempty"`
}
err = json.Unmarshal(data, &credentialsContainer)
if err != nil {
return nil, err
}
if credentialsContainer.ClaudeAIAuth == nil {
return nil, E.New("claudeAiOauth field not found in credentials")
}
return credentialsContainer.ClaudeAIAuth, nil
}
func checkCredentialFileWritable(path string) error {
file, err := os.OpenFile(path, os.O_WRONLY, 0)
if err != nil {
return err
}
return file.Close()
}
func writeCredentialsToFile(oauthCredentials *oauthCredentials, path string) error {
data, err := json.MarshalIndent(map[string]any{
"claudeAiOauth": oauthCredentials,
}, "", " ")
if err != nil {
return err
}
return os.WriteFile(path, data, 0o600)
}
type oauthCredentials struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
ExpiresAt int64 `json:"expiresAt"`
Scopes []string `json:"scopes,omitempty"`
SubscriptionType string `json:"subscriptionType,omitempty"`
RateLimitTier string `json:"rateLimitTier,omitempty"`
IsMax bool `json:"isMax,omitempty"`
}
func (c *oauthCredentials) needsRefresh() bool {
if c.ExpiresAt == 0 {
return false
}
return time.Now().UnixMilli() >= c.ExpiresAt-tokenRefreshBufferMs
}
func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) {
if credentials.RefreshToken == "" {
return nil, E.New("refresh token is empty")
}
requestBody, err := json.Marshal(map[string]string{
"grant_type": "refresh_token",
"refresh_token": credentials.RefreshToken,
"client_id": oauth2ClientID,
})
if err != nil {
return nil, E.Cause(err, "marshal request")
}
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
request, err := buildRequest()
if err != nil {
return nil, err
}
request.Header.Set("Content-Type", "application/json")
request.Header.Set("User-Agent", ccmUserAgentValue)
return request, nil
response, err := client.Do(request)
if err == nil {
return response, nil
}
lastError = err
if ctx.Err() != nil {
return nil, lastError
}
}
return nil, lastError
}
type credentialState struct {
fiveHourUtilization float64
fiveHourReset time.Time
weeklyUtilization float64
weeklyReset time.Time
hardRateLimited bool
rateLimitResetAt time.Time
availabilityState availabilityState
availabilityReason availabilityReason
availabilityResetAt time.Time
lastKnownDataAt time.Time
accountUUID string
accountType string
rateLimitTier string
oauthAccount *claudeOAuthAccount
remotePlanWeight float64
lastUpdated time.Time
consecutivePollFailures int
usageAPIRetryDelay time.Duration
unavailable bool
upstreamRejectedUntil time.Time
lastCredentialLoadAttempt time.Time
lastCredentialLoadError string
}
type credentialRequestContext struct {
context.Context
releaseOnce sync.Once
cancelOnce sync.Once
releaseFuncs []func() bool
cancelFunc context.CancelFunc
}
func (c *credentialRequestContext) addInterruptLink(stop func() bool) {
c.releaseFuncs = append(c.releaseFuncs, stop)
}
func (c *credentialRequestContext) releaseCredentialInterrupt() {
c.releaseOnce.Do(func() {
for _, f := range c.releaseFuncs {
f()
}
})
}
func (c *credentialRequestContext) cancelRequest() {
c.releaseCredentialInterrupt()
c.cancelOnce.Do(c.cancelFunc)
}
type Credential interface {
tagName() string
isAvailable() bool
isUsable() bool
isExternal() bool
hasSnapshotData() bool
fiveHourUtilization() float64
weeklyUtilization() float64
fiveHourCap() float64
weeklyCap() float64
planWeight() float64
fiveHourResetTime() time.Time
weeklyResetTime() time.Time
markRateLimited(resetAt time.Time)
markUpstreamRejected()
availabilityStatus() availabilityStatus
earliestReset() time.Time
unavailableError() error
getAccessToken() (string, error)
buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error)
updateStateFromHeaders(header http.Header)
wrapRequestContext(ctx context.Context) *credentialRequestContext
interruptConnections()
setStatusSubscriber(*observable.Subscriber[struct{}])
start() error
pollUsage()
lastUpdatedTime() time.Time
pollBackoff(base time.Duration) time.Duration
usageTrackerOrNil() *AggregatedUsage
httpClient() *http.Client
close()
}
type credentialSelectionScope string
const (
credentialSelectionScopeAll credentialSelectionScope = "all"
credentialSelectionScopeNonExternal credentialSelectionScope = "non_external"
)
type credentialSelection struct {
scope credentialSelectionScope
filter func(Credential) bool
}
func (s credentialSelection) allows(credential Credential) bool {
return s.filter == nil || s.filter(credential)
}
func (s credentialSelection) scopeOrDefault() credentialSelectionScope {
if s.scope == "" {
return credentialSelectionScopeAll
}
return s.scope
}
// Claude Code's unified rate-limit handling parses these reset headers with
// Number(...), compares them against Date.now()/1000, and renders them via
// new Date(seconds*1000), so keep the wire format pinned to Unix epoch seconds.
func parseAnthropicResetHeaderValue(headerName string, headerValue string) time.Time {
unixEpoch, err := strconv.ParseInt(headerValue, 10, 64)
if err != nil {
return nil, err
panic("invalid " + headerName + " header: expected Unix epoch seconds, got " + strconv.Quote(headerValue))
}
defer response.Body.Close()
if response.StatusCode == http.StatusTooManyRequests {
body, _ := io.ReadAll(response.Body)
return nil, E.New("refresh rate limited: ", response.Status, " ", string(body))
if unixEpoch <= 0 {
panic("invalid " + headerName + " header: expected positive Unix epoch seconds, got " + strconv.Quote(headerValue))
}
if response.StatusCode != http.StatusOK {
body, _ := io.ReadAll(response.Body)
return nil, E.New("refresh failed: ", response.Status, " ", string(body))
}
var tokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
}
err = json.NewDecoder(response.Body).Decode(&tokenResponse)
if err != nil {
return nil, E.Cause(err, "decode response")
}
newCredentials := *credentials
newCredentials.AccessToken = tokenResponse.AccessToken
if tokenResponse.RefreshToken != "" {
newCredentials.RefreshToken = tokenResponse.RefreshToken
}
newCredentials.ExpiresAt = time.Now().UnixMilli() + int64(tokenResponse.ExpiresIn)*1000
return &newCredentials, nil
return time.Unix(unixEpoch, 0)
}
func cloneCredentials(credentials *oauthCredentials) *oauthCredentials {
if credentials == nil {
return nil
func parseOptionalAnthropicResetHeader(headers http.Header, headerName string) (time.Time, bool) {
headerValue := headers.Get(headerName)
if headerValue == "" {
return time.Time{}, false
}
cloned := *credentials
cloned.Scopes = append([]string(nil), credentials.Scopes...)
return &cloned
return parseAnthropicResetHeaderValue(headerName, headerValue), true
}
func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool {
if left == nil || right == nil {
return left == right
func parseRequiredAnthropicResetHeader(headers http.Header, headerName string) time.Time {
headerValue := headers.Get(headerName)
if headerValue == "" {
panic("missing required " + headerName + " header")
}
return parseAnthropicResetHeaderValue(headerName, headerValue)
}
func (s *credentialState) noteSnapshotData() {
s.lastKnownDataAt = time.Now()
}
func (s credentialState) hasSnapshotData() bool {
return !s.lastKnownDataAt.IsZero() ||
s.fiveHourUtilization > 0 ||
s.weeklyUtilization > 0 ||
!s.fiveHourReset.IsZero() ||
!s.weeklyReset.IsZero()
}
func (s *credentialState) setAvailability(state availabilityState, reason availabilityReason, resetAt time.Time) {
s.availabilityState = state
s.availabilityReason = reason
s.availabilityResetAt = resetAt
}
func (s credentialState) currentAvailability() availabilityStatus {
now := time.Now()
switch {
case s.unavailable:
return availabilityStatus{
State: availabilityStateUnavailable,
Reason: availabilityReasonUnknown,
ResetAt: s.availabilityResetAt,
}
case s.hardRateLimited && (s.rateLimitResetAt.IsZero() || now.Before(s.rateLimitResetAt)):
reason := s.availabilityReason
if reason == "" {
reason = availabilityReasonHardRateLimit
}
return availabilityStatus{
State: availabilityStateRateLimited,
Reason: reason,
ResetAt: s.rateLimitResetAt,
}
case !s.upstreamRejectedUntil.IsZero() && now.Before(s.upstreamRejectedUntil):
return availabilityStatus{
State: availabilityStateTemporarilyBlocked,
Reason: availabilityReasonUpstreamRejected,
ResetAt: s.upstreamRejectedUntil,
}
case s.consecutivePollFailures > 0:
return availabilityStatus{
State: availabilityStateTemporarilyBlocked,
Reason: availabilityReasonPollFailed,
}
default:
return availabilityStatus{State: availabilityStateUsable}
}
}
func parseRateLimitResetFromHeaders(headers http.Header) time.Time {
claim := headers.Get("anthropic-ratelimit-unified-representative-claim")
switch claim {
case "5h":
return parseRequiredAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset")
case "7d":
return parseRequiredAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset")
default:
panic("invalid anthropic-ratelimit-unified-representative-claim header: " + strconv.Quote(claim))
}
return left.AccessToken == right.AccessToken &&
left.RefreshToken == right.RefreshToken &&
left.ExpiresAt == right.ExpiresAt &&
slices.Equal(left.Scopes, right.Scopes) &&
left.SubscriptionType == right.SubscriptionType &&
left.RateLimitTier == right.RateLimitTier &&
left.IsMax == right.IsMax
}

View File

@@ -0,0 +1,162 @@
package ccm
import (
"context"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
)
func buildCredentialProviders(
ctx context.Context,
options option.CCMServiceOptions,
logger log.ContextLogger,
) (map[string]credentialProvider, []Credential, error) {
allCredentialMap := make(map[string]Credential)
var allCredentials []Credential
providers := make(map[string]credentialProvider)
// Pass 1: create default and external credentials
for _, credentialOption := range options.Credentials {
switch credentialOption.Type {
case "default":
credential, err := newDefaultCredential(ctx, credentialOption.Tag, credentialOption.DefaultOptions, logger)
if err != nil {
return nil, nil, err
}
allCredentialMap[credentialOption.Tag] = credential
allCredentials = append(allCredentials, credential)
providers[credentialOption.Tag] = &singleCredentialProvider{credential: credential}
case "external":
credential, err := newExternalCredential(ctx, credentialOption.Tag, credentialOption.ExternalOptions, logger)
if err != nil {
return nil, nil, err
}
allCredentialMap[credentialOption.Tag] = credential
allCredentials = append(allCredentials, credential)
providers[credentialOption.Tag] = &singleCredentialProvider{credential: credential}
}
}
// Pass 2: create balancer providers
for _, credentialOption := range options.Credentials {
if credentialOption.Type == "balancer" {
subCredentials, err := resolveCredentialTags(credentialOption.BalancerOptions.Credentials, allCredentialMap, credentialOption.Tag)
if err != nil {
return nil, nil, err
}
providers[credentialOption.Tag] = newBalancerProvider(subCredentials, credentialOption.BalancerOptions.Strategy, credentialOption.BalancerOptions.RebalanceThreshold, logger)
}
}
return providers, allCredentials, nil
}
func resolveCredentialTags(tags []string, allCredentials map[string]Credential, parentTag string) ([]Credential, error) {
credentials := make([]Credential, 0, len(tags))
for _, tag := range tags {
credential, exists := allCredentials[tag]
if !exists {
return nil, E.New("credential ", parentTag, " references unknown credential: ", tag)
}
credentials = append(credentials, credential)
}
if len(credentials) == 0 {
return nil, E.New("credential ", parentTag, " has no sub-credentials")
}
return credentials, nil
}
func validateCCMOptions(options option.CCMServiceOptions) error {
tags := make(map[string]bool)
credentialTypes := make(map[string]string)
for _, credential := range options.Credentials {
if tags[credential.Tag] {
return E.New("duplicate credential tag: ", credential.Tag)
}
tags[credential.Tag] = true
credentialTypes[credential.Tag] = credential.Type
if credential.Type == "default" || credential.Type == "" {
if credential.DefaultOptions.Reserve5h > 99 {
return E.New("credential ", credential.Tag, ": reserve_5h must be at most 99")
}
if credential.DefaultOptions.ReserveWeekly > 99 {
return E.New("credential ", credential.Tag, ": reserve_weekly must be at most 99")
}
if credential.DefaultOptions.Limit5h > 100 {
return E.New("credential ", credential.Tag, ": limit_5h must be at most 100")
}
if credential.DefaultOptions.LimitWeekly > 100 {
return E.New("credential ", credential.Tag, ": limit_weekly must be at most 100")
}
if credential.DefaultOptions.Reserve5h > 0 && credential.DefaultOptions.Limit5h > 0 {
return E.New("credential ", credential.Tag, ": reserve_5h and limit_5h are mutually exclusive")
}
if credential.DefaultOptions.ReserveWeekly > 0 && credential.DefaultOptions.LimitWeekly > 0 {
return E.New("credential ", credential.Tag, ": reserve_weekly and limit_weekly are mutually exclusive")
}
}
if credential.Type == "external" {
if credential.ExternalOptions.Token == "" {
return E.New("credential ", credential.Tag, ": external credential requires token")
}
if credential.ExternalOptions.Reverse && credential.ExternalOptions.URL == "" {
return E.New("credential ", credential.Tag, ": reverse external credential requires url")
}
}
if credential.Type == "balancer" {
switch credential.BalancerOptions.Strategy {
case "", C.BalancerStrategyLeastUsed, C.BalancerStrategyRoundRobin, C.BalancerStrategyRandom, C.BalancerStrategyFallback:
default:
return E.New("credential ", credential.Tag, ": unknown balancer strategy: ", credential.BalancerOptions.Strategy)
}
if credential.BalancerOptions.RebalanceThreshold < 0 {
return E.New("credential ", credential.Tag, ": rebalance_threshold must not be negative")
}
}
}
singleCredential := len(options.Credentials) == 1
for _, user := range options.Users {
if user.Credential == "" && !singleCredential {
return E.New("user ", user.Name, " must specify credential in multi-credential mode")
}
if user.Credential != "" && !tags[user.Credential] {
return E.New("user ", user.Name, " references unknown credential: ", user.Credential)
}
if user.ExternalCredential != "" {
if !tags[user.ExternalCredential] {
return E.New("user ", user.Name, " references unknown external_credential: ", user.ExternalCredential)
}
if credentialTypes[user.ExternalCredential] != "external" {
return E.New("user ", user.Name, ": external_credential must reference an external type credential")
}
}
}
return nil
}
func credentialForUser(
userConfigMap map[string]*option.CCMUser,
providers map[string]credentialProvider,
username string,
) (credentialProvider, error) {
userConfig, exists := userConfigMap[username]
if !exists {
return nil, E.New("no credential mapping for user: ", username)
}
if userConfig.Credential == "" {
for _, provider := range providers {
return provider, nil
}
return nil, E.New("no credential available")
}
provider, exists := providers[userConfig.Credential]
if !exists {
return nil, E.New("unknown credential: ", userConfig.Credential)
}
return provider, nil
}

View File

@@ -0,0 +1,147 @@
package ccm
import (
"encoding/json"
"os"
"path/filepath"
)
// claudeCodeConfig represents the persisted config written by Claude Code.
//
// ref (@anthropic-ai/claude-code @2.1.81):
//
// ref: cli.js P8() (line 174997) — reads config
// ref: cli.js c8() (line 174919) — writes config
// ref: cli.js _D() (line 39158-39163) — config file path resolution
type claudeCodeConfig struct {
UserID string `json:"userID"` // ref: cli.js XL() (line 175325) — random 32-byte hex, generated once
OAuthAccount *claudeOAuthAccount `json:"oauthAccount"` // ref: cli.js fP6() / storeOAuthAccountInfo — from /api/oauth/profile
}
type claudeOAuthAccount struct {
AccountUUID string `json:"accountUuid,omitempty"`
EmailAddress string `json:"emailAddress,omitempty"`
OrganizationUUID string `json:"organizationUuid,omitempty"`
DisplayName *string `json:"displayName,omitempty"`
HasExtraUsageEnabled *bool `json:"hasExtraUsageEnabled,omitempty"`
BillingType *string `json:"billingType,omitempty"`
AccountCreatedAt *string `json:"accountCreatedAt,omitempty"`
SubscriptionCreatedAt *string `json:"subscriptionCreatedAt,omitempty"`
}
// resolveClaudeConfigFile finds the Claude Code config file within the given directory.
//
// Config file path resolution mirrors cli.js _D() (line 39158-39163):
// 1. claudeDirectory/.config.json — newer format, checked first
// 2. claudeDirectory/.claude.json — used when CLAUDE_CONFIG_DIR is set
// 3. filepath.Dir(claudeDirectory)/.claude.json — default ~/.claude case → ~/.claude.json
//
// Returns the first path that exists, or "" if none found.
func resolveClaudeConfigFile(claudeDirectory string) string {
candidates := []string{
filepath.Join(claudeDirectory, ".config.json"),
filepath.Join(claudeDirectory, claudeCodeLegacyConfigFileName()),
filepath.Join(filepath.Dir(claudeDirectory), claudeCodeLegacyConfigFileName()),
}
for _, candidate := range candidates {
_, err := os.Stat(candidate)
if err == nil {
return candidate
}
}
return ""
}
func readClaudeCodeConfig(path string) (*claudeCodeConfig, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var config claudeCodeConfig
err = json.Unmarshal(data, &config)
if err != nil {
return nil, err
}
return &config, nil
}
func resolveClaudeConfigWritePath(claudeDirectory string) string {
if claudeDirectory == "" {
return ""
}
existingPath := resolveClaudeConfigFile(claudeDirectory)
if existingPath != "" {
return existingPath
}
if os.Getenv("CLAUDE_CONFIG_DIR") != "" {
return filepath.Join(claudeDirectory, claudeCodeLegacyConfigFileName())
}
defaultClaudeDirectory := filepath.Join(filepath.Dir(claudeDirectory), ".claude")
if claudeDirectory != defaultClaudeDirectory {
return filepath.Join(claudeDirectory, claudeCodeLegacyConfigFileName())
}
return filepath.Join(filepath.Dir(claudeDirectory), claudeCodeLegacyConfigFileName())
}
func writeClaudeCodeOAuthAccount(path string, account *claudeOAuthAccount) error {
if path == "" || account == nil {
return nil
}
storage := jsonFileStorage{path: path}
return writeStorageValue(storage, "oauthAccount", account)
}
func claudeCodeLegacyConfigFileName() string {
if os.Getenv("CLAUDE_CODE_CUSTOM_OAUTH_URL") != "" {
return ".claude-custom-oauth.json"
}
return ".claude.json"
}
func cloneClaudeOAuthAccount(account *claudeOAuthAccount) *claudeOAuthAccount {
if account == nil {
return nil
}
cloned := *account
cloned.DisplayName = cloneStringPointer(account.DisplayName)
cloned.HasExtraUsageEnabled = cloneBoolPointer(account.HasExtraUsageEnabled)
cloned.BillingType = cloneStringPointer(account.BillingType)
cloned.AccountCreatedAt = cloneStringPointer(account.AccountCreatedAt)
cloned.SubscriptionCreatedAt = cloneStringPointer(account.SubscriptionCreatedAt)
return &cloned
}
func mergeClaudeOAuthAccount(base *claudeOAuthAccount, update *claudeOAuthAccount) *claudeOAuthAccount {
if update == nil {
return cloneClaudeOAuthAccount(base)
}
if base == nil {
return cloneClaudeOAuthAccount(update)
}
merged := cloneClaudeOAuthAccount(base)
if update.AccountUUID != "" {
merged.AccountUUID = update.AccountUUID
}
if update.EmailAddress != "" {
merged.EmailAddress = update.EmailAddress
}
if update.OrganizationUUID != "" {
merged.OrganizationUUID = update.OrganizationUUID
}
if update.DisplayName != nil {
merged.DisplayName = cloneStringPointer(update.DisplayName)
}
if update.HasExtraUsageEnabled != nil {
merged.HasExtraUsageEnabled = cloneBoolPointer(update.HasExtraUsageEnabled)
}
if update.BillingType != nil {
merged.BillingType = cloneStringPointer(update.BillingType)
}
if update.AccountCreatedAt != nil {
merged.AccountCreatedAt = cloneStringPointer(update.AccountCreatedAt)
}
if update.SubscriptionCreatedAt != nil {
merged.SubscriptionCreatedAt = cloneStringPointer(update.SubscriptionCreatedAt)
}
return merged
}

View File

@@ -14,6 +14,11 @@ import (
"github.com/keybase/go-keychain"
)
type keychainStorage struct {
service string
account string
}
func getKeychainServiceName() string {
configDirectory := os.Getenv("CLAUDE_CONFIG_DIR")
if configDirectory == "" {
@@ -76,48 +81,90 @@ func platformCanWriteCredentials(customPath string) error {
return checkCredentialFileWritable(customPath)
}
func platformWriteCredentials(oauthCredentials *oauthCredentials, customPath string) error {
func platformWriteCredentials(credentials *oauthCredentials, customPath string) error {
if customPath != "" {
return writeCredentialsToFile(oauthCredentials, customPath)
}
userInfo, err := getRealUser()
if err == nil {
data, err := json.Marshal(map[string]any{"claudeAiOauth": oauthCredentials})
if err == nil {
serviceName := getKeychainServiceName()
item := keychain.NewItem()
item.SetSecClass(keychain.SecClassGenericPassword)
item.SetService(serviceName)
item.SetAccount(userInfo.Username)
item.SetData(data)
item.SetAccessible(keychain.AccessibleWhenUnlocked)
err = keychain.AddItem(item)
if err == nil {
return nil
}
if err == keychain.ErrorDuplicateItem {
query := keychain.NewItem()
query.SetSecClass(keychain.SecClassGenericPassword)
query.SetService(serviceName)
query.SetAccount(userInfo.Username)
updateItem := keychain.NewItem()
updateItem.SetData(data)
updateErr := keychain.UpdateItem(query, updateItem)
if updateErr == nil {
return nil
}
}
}
return writeCredentialsToFile(credentials, customPath)
}
defaultPath, err := getDefaultCredentialsPath()
if err != nil {
return err
}
return writeCredentialsToFile(oauthCredentials, defaultPath)
fileStorage := jsonFileStorage{path: defaultPath}
userInfo, err := getRealUser()
if err != nil {
return writeCredentialsToFile(credentials, defaultPath)
}
return persistStorageValue(keychainStorage{
service: getKeychainServiceName(),
account: userInfo.Username,
}, fileStorage, "claudeAiOauth", credentials)
}
func (s keychainStorage) readContainer() (map[string]json.RawMessage, bool, error) {
query := keychain.NewItem()
query.SetSecClass(keychain.SecClassGenericPassword)
query.SetService(s.service)
query.SetAccount(s.account)
query.SetMatchLimit(keychain.MatchLimitOne)
query.SetReturnData(true)
results, err := keychain.QueryItem(query)
if err != nil {
if err == keychain.ErrorItemNotFound {
return make(map[string]json.RawMessage), false, nil
}
return nil, false, E.Cause(err, "query keychain")
}
if len(results) != 1 {
return make(map[string]json.RawMessage), false, nil
}
container := make(map[string]json.RawMessage)
if len(results[0].Data) == 0 {
return container, true, nil
}
if err := json.Unmarshal(results[0].Data, &container); err != nil {
return nil, true, err
}
return container, true, nil
}
func (s keychainStorage) writeContainer(container map[string]json.RawMessage) error {
data, err := json.Marshal(container)
if err != nil {
return err
}
item := keychain.NewItem()
item.SetSecClass(keychain.SecClassGenericPassword)
item.SetService(s.service)
item.SetAccount(s.account)
item.SetData(data)
item.SetAccessible(keychain.AccessibleWhenUnlocked)
err = keychain.AddItem(item)
if err == nil {
return nil
}
if err != keychain.ErrorDuplicateItem {
return err
}
updateQuery := keychain.NewItem()
updateQuery.SetSecClass(keychain.SecClassGenericPassword)
updateQuery.SetService(s.service)
updateQuery.SetAccount(s.account)
updateItem := keychain.NewItem()
updateItem.SetData(data)
return keychain.UpdateItem(updateQuery, updateItem)
}
func (s keychainStorage) delete() error {
err := keychain.DeleteGenericPasswordItem(s.service, s.account)
if err != nil && err != keychain.ErrorItemNotFound {
return err
}
return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,245 @@
package ccm
import (
"errors"
"net/http"
"os"
"path/filepath"
"testing"
"time"
)
func TestGetAccessTokenMarksUnavailableWhenLockFails(t *testing.T) {
t.Parallel()
directory := t.TempDir()
credentialPath := filepath.Join(directory, ".credentials.json")
credentials := &oauthCredentials{
AccessToken: "old-token",
RefreshToken: "refresh-token",
ExpiresAt: time.Now().Add(time.Hour).UnixMilli(),
Scopes: []string{"user:profile", "user:inference"},
SubscriptionType: optionalStringPointer("max"),
RateLimitTier: optionalStringPointer("default_claude_max_20x"),
}
writeTestCredentials(t, credentialPath, credentials)
credential := newTestDefaultCredential(t, credentialPath, roundTripFunc(func(request *http.Request) (*http.Response, error) {
t.Fatal("refresh should not be attempted when lock acquisition fails")
return nil, nil
}))
if err := credential.reloadCredentials(true); err != nil {
t.Fatal(err)
}
expiredCredentials := cloneCredentials(credentials)
expiredCredentials.ExpiresAt = time.Now().Add(-time.Minute).UnixMilli()
writeTestCredentials(t, credentialPath, expiredCredentials)
credential.absorbCredentials(expiredCredentials)
credential.acquireLock = func(string) (func(), error) {
return nil, errors.New("permission denied")
}
_, err := credential.getAccessToken()
if err == nil {
t.Fatal("expected error when lock acquisition fails, got nil")
}
if credential.isUsable() {
t.Fatal("credential should be marked unavailable after lock failure")
}
}
func TestGetAccessTokenMarksUnavailableOnUnwritableFile(t *testing.T) {
t.Parallel()
directory := t.TempDir()
credentialPath := filepath.Join(directory, ".credentials.json")
credentials := &oauthCredentials{
AccessToken: "old-token",
RefreshToken: "refresh-token",
ExpiresAt: time.Now().Add(time.Hour).UnixMilli(),
Scopes: []string{"user:profile", "user:inference"},
}
writeTestCredentials(t, credentialPath, credentials)
credential := newTestDefaultCredential(t, credentialPath, roundTripFunc(func(request *http.Request) (*http.Response, error) {
t.Fatal("refresh should not be attempted when file is not writable")
return nil, nil
}))
if err := credential.reloadCredentials(true); err != nil {
t.Fatal(err)
}
expiredCredentials := cloneCredentials(credentials)
expiredCredentials.ExpiresAt = time.Now().Add(-time.Minute).UnixMilli()
writeTestCredentials(t, credentialPath, expiredCredentials)
credential.absorbCredentials(expiredCredentials)
os.Chmod(credentialPath, 0o444)
t.Cleanup(func() { os.Chmod(credentialPath, 0o644) })
_, err := credential.getAccessToken()
if err == nil {
t.Fatal("expected error when credential file is not writable, got nil")
}
if credential.isUsable() {
t.Fatal("credential should be marked unavailable after write permission failure")
}
}
func TestGetAccessTokenAbsorbsRefreshDoneByAnotherProcess(t *testing.T) {
t.Parallel()
directory := t.TempDir()
credentialPath := filepath.Join(directory, ".credentials.json")
oldCredentials := &oauthCredentials{
AccessToken: "old-token",
RefreshToken: "refresh-token",
ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(),
Scopes: []string{"user:profile", "user:inference"},
SubscriptionType: optionalStringPointer("max"),
RateLimitTier: optionalStringPointer("default_claude_max_20x"),
}
writeTestCredentials(t, credentialPath, oldCredentials)
newCredentials := cloneCredentials(oldCredentials)
newCredentials.AccessToken = "new-token"
newCredentials.ExpiresAt = time.Now().Add(time.Hour).UnixMilli()
transport := roundTripFunc(func(request *http.Request) (*http.Response, error) {
if request.URL.Path == "/v1/oauth/token" {
writeTestCredentials(t, credentialPath, newCredentials)
return newJSONResponse(http.StatusInternalServerError, `{"error":"boom"}`), nil
}
t.Fatalf("unexpected path %s", request.URL.Path)
return nil, nil
})
credential := newTestDefaultCredential(t, credentialPath, transport)
if err := credential.reloadCredentials(true); err != nil {
t.Fatal(err)
}
token, err := credential.getAccessToken()
if err != nil {
t.Fatal(err)
}
if token != "new-token" {
t.Fatalf("expected refreshed token from disk, got %q", token)
}
}
func TestCustomCredentialPathDoesNotEnableClaudeConfigSync(t *testing.T) {
t.Parallel()
directory := t.TempDir()
credentialPath := filepath.Join(directory, ".credentials.json")
writeTestCredentials(t, credentialPath, &oauthCredentials{
AccessToken: "token",
ExpiresAt: time.Now().Add(time.Hour).UnixMilli(),
Scopes: []string{"user:profile"},
})
credential := newTestDefaultCredential(t, credentialPath, roundTripFunc(func(request *http.Request) (*http.Response, error) {
t.Fatalf("unexpected request to %s", request.URL.Path)
return nil, nil
}))
if err := credential.reloadCredentials(true); err != nil {
t.Fatal(err)
}
token, err := credential.getAccessToken()
if err != nil {
t.Fatal(err)
}
if token != "token" {
t.Fatalf("expected token, got %q", token)
}
if credential.shouldUseClaudeConfig() {
t.Fatal("custom credential path should not enable Claude config sync")
}
if _, err := os.Stat(filepath.Join(directory, ".claude.json")); !os.IsNotExist(err) {
t.Fatalf("did not expect config file to be created, stat err=%v", err)
}
}
func TestDefaultCredentialHydratesProfileAndWritesConfig(t *testing.T) {
configDir := t.TempDir()
credentialPath := filepath.Join(configDir, ".credentials.json")
writeTestCredentials(t, credentialPath, &oauthCredentials{
AccessToken: "old-token",
RefreshToken: "refresh-token",
ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(),
Scopes: []string{"user:profile", "user:inference"},
})
transport := roundTripFunc(func(request *http.Request) (*http.Response, error) {
switch request.URL.Path {
case "/v1/oauth/token":
return newJSONResponse(http.StatusOK, `{
"access_token":"new-token",
"refresh_token":"new-refresh",
"expires_in":3600,
"account":{"uuid":"account","email_address":"user@example.com"},
"organization":{"uuid":"org"}
}`), nil
case "/api/oauth/profile":
return newJSONResponse(http.StatusOK, `{
"account":{
"uuid":"account",
"email":"user@example.com",
"display_name":"User",
"created_at":"2024-01-01T00:00:00Z"
},
"organization":{
"uuid":"org",
"organization_type":"claude_max",
"rate_limit_tier":"default_claude_max_20x",
"has_extra_usage_enabled":true,
"billing_type":"individual",
"subscription_created_at":"2024-01-02T00:00:00Z"
}
}`), nil
default:
t.Fatalf("unexpected path %s", request.URL.Path)
return nil, nil
}
})
credential := newTestDefaultCredential(t, credentialPath, transport)
credential.syncClaudeConfig = true
credential.claudeDirectory = configDir
credential.claudeConfigPath = resolveClaudeConfigWritePath(configDir)
if err := credential.reloadCredentials(true); err != nil {
t.Fatal(err)
}
token, err := credential.getAccessToken()
if err != nil {
t.Fatal(err)
}
if token != "new-token" {
t.Fatalf("expected refreshed token, got %q", token)
}
updatedCredentials := readTestCredentials(t, credentialPath)
if updatedCredentials.SubscriptionType == nil || *updatedCredentials.SubscriptionType != "max" {
t.Fatalf("expected subscription type to be persisted, got %#v", updatedCredentials.SubscriptionType)
}
if updatedCredentials.RateLimitTier == nil || *updatedCredentials.RateLimitTier != "default_claude_max_20x" {
t.Fatalf("expected rate limit tier to be persisted, got %#v", updatedCredentials.RateLimitTier)
}
configPath := tempConfigPath(t, configDir)
config, err := readClaudeCodeConfig(configPath)
if err != nil {
t.Fatal(err)
}
if config.OAuthAccount == nil || config.OAuthAccount.AccountUUID != "account" || config.OAuthAccount.EmailAddress != "user@example.com" {
t.Fatalf("unexpected oauth account: %#v", config.OAuthAccount)
}
if config.OAuthAccount.BillingType == nil || *config.OAuthAccount.BillingType != "individual" {
t.Fatalf("expected billing type to be hydrated, got %#v", config.OAuthAccount.BillingType)
}
}

View File

@@ -5,6 +5,7 @@ import (
"context"
stdTLS "crypto/tls"
"encoding/json"
"errors"
"io"
"net"
"net/http"
@@ -22,6 +23,7 @@ import (
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp"
"github.com/sagernet/sing/common/observable"
"github.com/hashicorp/yamux"
)
@@ -29,25 +31,26 @@ import (
const reverseProxyBaseURL = "http://reverse-proxy"
type externalCredential struct {
tag string
baseURL string
token string
httpClient *http.Client
state credentialState
stateMutex sync.RWMutex
pollAccess sync.Mutex
pollInterval time.Duration
usageTracker *AggregatedUsage
logger log.ContextLogger
tag string
baseURL string
token string
forwardHTTPClient *http.Client
state credentialState
stateAccess sync.RWMutex
pollAccess sync.Mutex
usageTracker *AggregatedUsage
logger log.ContextLogger
onBecameUnusable func()
interrupted bool
requestContext context.Context
cancelRequests context.CancelFunc
requestAccess sync.Mutex
statusSubscriber *observable.Subscriber[struct{}]
interrupted bool
requestContext context.Context
cancelRequests context.CancelFunc
requestAccess sync.Mutex
// Reverse proxy fields
reverse bool
reverseHTTPClient *http.Client
reverseSession *yamux.Session
reverseAccess sync.RWMutex
closed bool
@@ -61,10 +64,15 @@ type externalCredential struct {
reverseService http.Handler
}
type statusStreamResult struct {
duration time.Duration
frames int
}
func externalCredentialURLPort(parsedURL *url.URL) uint16 {
portStr := parsedURL.Port()
if portStr != "" {
port, err := strconv.ParseUint(portStr, 10, 16)
portString := parsedURL.Port()
if portString != "" {
port, err := strconv.ParseUint(portString, 10, 16)
if err == nil {
return uint16(port)
}
@@ -104,18 +112,12 @@ func externalCredentialReversePath(parsedURL *url.URL, endpointPath string) stri
}
func newExternalCredential(ctx context.Context, tag string, options option.CCMExternalCredentialOptions, logger log.ContextLogger) (*externalCredential, error) {
pollInterval := time.Duration(options.PollInterval)
if pollInterval <= 0 {
pollInterval = 30 * time.Minute
}
requestContext, cancelRequests := context.WithCancel(context.Background())
reverseContext, reverseCancel := context.WithCancel(context.Background())
cred := &externalCredential{
credential := &externalCredential{
tag: tag,
token: options.Token,
pollInterval: pollInterval,
logger: logger,
requestContext: requestContext,
cancelRequests: cancelRequests,
@@ -126,12 +128,12 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx
if options.URL == "" {
// Receiver mode: no URL, wait for reverse connection
cred.baseURL = reverseProxyBaseURL
cred.httpClient = &http.Client{
credential.baseURL = reverseProxyBaseURL
credential.forwardHTTPClient = &http.Client{
Transport: &http.Transport{
ForceAttemptHTTP2: false,
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return cred.openReverseConnection(ctx)
return credential.openReverseConnection(ctx)
},
},
}
@@ -172,33 +174,42 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx
}
}
cred.baseURL = externalCredentialBaseURL(parsedURL)
credential.baseURL = externalCredentialBaseURL(parsedURL)
if options.Reverse {
// Connector mode: we dial out to serve, not to proxy
cred.connectorDialer = credentialDialer
credential.connectorDialer = credentialDialer
if options.Server != "" {
cred.connectorDestination = M.ParseSocksaddrHostPort(options.Server, externalCredentialServerPort(parsedURL, options.ServerPort))
credential.connectorDestination = M.ParseSocksaddrHostPort(options.Server, externalCredentialServerPort(parsedURL, options.ServerPort))
} else {
cred.connectorDestination = M.ParseSocksaddrHostPort(parsedURL.Hostname(), externalCredentialURLPort(parsedURL))
credential.connectorDestination = M.ParseSocksaddrHostPort(parsedURL.Hostname(), externalCredentialURLPort(parsedURL))
}
cred.connectorRequestPath = externalCredentialReversePath(parsedURL, "/ccm/v1/reverse")
cred.connectorURL = parsedURL
credential.connectorRequestPath = externalCredentialReversePath(parsedURL, "/ccm/v1/reverse")
credential.connectorURL = parsedURL
if parsedURL.Scheme == "https" {
cred.connectorTLS = &stdTLS.Config{
credential.connectorTLS = &stdTLS.Config{
ServerName: parsedURL.Hostname(),
RootCAs: adapter.RootPoolFromContext(ctx),
Time: ntp.TimeFuncFromContext(ctx),
}
}
credential.forwardHTTPClient = &http.Client{Transport: transport}
} else {
// Normal mode: standard HTTP client for proxying
cred.httpClient = &http.Client{Transport: transport}
credential.forwardHTTPClient = &http.Client{Transport: transport}
credential.reverseHTTPClient = &http.Client{
Transport: &http.Transport{
ForceAttemptHTTP2: false,
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return credential.openReverseConnection(ctx)
},
},
}
}
}
if options.UsagesPath != "" {
cred.usageTracker = &AggregatedUsage{
credential.usageTracker = &AggregatedUsage{
LastUpdated: time.Now(),
Combinations: make([]CostCombination, 0),
filePath: options.UsagesPath,
@@ -206,7 +217,17 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx
}
}
return cred, nil
return credential, nil
}
func (c *externalCredential) setStatusSubscriber(subscriber *observable.Subscriber[struct{}]) {
c.statusSubscriber = subscriber
}
func (c *externalCredential) emitStatusUpdate() {
if c.statusSubscriber != nil {
c.statusSubscriber.Emit(struct{}{})
}
}
func (c *externalCredential) start() error {
@@ -218,6 +239,8 @@ func (c *externalCredential) start() error {
}
if c.reverse && c.connectorURL != nil {
go c.connectorLoop()
} else {
go c.statusStreamLoop()
}
return nil
}
@@ -238,40 +261,44 @@ func (c *externalCredential) isUsable() bool {
if !c.isAvailable() {
return false
}
c.stateMutex.RLock()
c.stateAccess.RLock()
if c.state.consecutivePollFailures > 0 {
c.stateMutex.RUnlock()
c.stateAccess.RUnlock()
return false
}
if !c.state.upstreamRejectedUntil.IsZero() && time.Now().Before(c.state.upstreamRejectedUntil) {
c.stateAccess.RUnlock()
return false
}
if c.state.hardRateLimited {
if time.Now().Before(c.state.rateLimitResetAt) {
c.stateMutex.RUnlock()
c.stateAccess.RUnlock()
return false
}
c.stateMutex.RUnlock()
c.stateMutex.Lock()
c.stateAccess.RUnlock()
c.stateAccess.Lock()
if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) {
c.state.hardRateLimited = false
}
// No reserve for external: only 100% is unusable
usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100
c.stateMutex.Unlock()
c.stateAccess.Unlock()
return usable
}
usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100
c.stateMutex.RUnlock()
c.stateAccess.RUnlock()
return usable
}
func (c *externalCredential) fiveHourUtilization() float64 {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.fiveHourUtilization
}
func (c *externalCredential) weeklyUtilization() float64 {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.weeklyUtilization
}
@@ -284,35 +311,56 @@ func (c *externalCredential) weeklyCap() float64 {
}
func (c *externalCredential) planWeight() float64 {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
if c.state.remotePlanWeight > 0 {
return c.state.remotePlanWeight
}
return 10
}
func (c *externalCredential) fiveHourResetTime() time.Time {
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.fiveHourReset
}
func (c *externalCredential) weeklyResetTime() time.Time {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.weeklyReset
}
func (c *externalCredential) markRateLimited(resetAt time.Time) {
c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt)))
c.stateMutex.Lock()
c.stateAccess.Lock()
c.state.hardRateLimited = true
c.state.rateLimitResetAt = resetAt
c.state.setAvailability(availabilityStateRateLimited, availabilityReasonHardRateLimit, resetAt)
shouldInterrupt := c.checkTransitionLocked()
c.stateMutex.Unlock()
c.stateAccess.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
c.emitStatusUpdate()
}
func (c *externalCredential) markUpstreamRejected() {
c.logger.Warn("upstream rejected credential ", c.tag, ", marking unavailable for ", log.FormatDuration(defaultPollInterval))
c.stateAccess.Lock()
c.state.upstreamRejectedUntil = time.Now().Add(defaultPollInterval)
c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReasonUpstreamRejected, c.state.upstreamRejectedUntil)
shouldInterrupt := c.checkTransitionLocked()
c.stateAccess.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
c.emitStatusUpdate()
}
func (c *externalCredential) earliestReset() time.Time {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
if c.state.hardRateLimited {
return c.state.rateLimitResetAt
}
@@ -324,9 +372,6 @@ func (c *externalCredential) earliestReset() time.Time {
}
func (c *externalCredential) unavailableError() error {
if c.reverse && c.connectorURL != nil {
return E.New("credential ", c.tag, " is unavailable: reverse connector credentials cannot serve local requests")
}
if c.baseURL == reverseProxyBaseURL {
session := c.getReverseSession()
if session == nil || session.IsClosed() {
@@ -341,7 +386,14 @@ func (c *externalCredential) getAccessToken() (string, error) {
}
func (c *externalCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, _ http.Header) (*http.Request, error) {
proxyURL := c.baseURL + original.URL.RequestURI()
baseURL := c.baseURL
if c.reverseHTTPClient != nil {
session := c.getReverseSession()
if session != nil && !session.IsClosed() {
baseURL = reverseProxyBaseURL
}
}
proxyURL := baseURL + original.URL.RequestURI()
var body io.Reader
if bodyBytes != nil {
body = bytes.NewReader(bodyBytes)
@@ -354,7 +406,7 @@ func (c *externalCredential) buildProxyRequest(ctx context.Context, original *ht
}
for key, values := range original.Header {
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && key != "Authorization" {
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && !isAPIKeyHeader(key) && key != "Authorization" {
proxyRequest.Header[key] = values
}
}
@@ -391,10 +443,13 @@ func (c *externalCredential) openReverseConnection(ctx context.Context) (net.Con
}
func (c *externalCredential) updateStateFromHeaders(headers http.Header) {
c.stateMutex.Lock()
c.stateAccess.Lock()
isFirstUpdate := c.state.lastUpdated.IsZero()
oldFiveHour := c.state.fiveHourUtilization
oldWeekly := c.state.weeklyUtilization
oldPlanWeight := c.state.remotePlanWeight
oldFiveHourReset := c.state.fiveHourReset
oldWeeklyReset := c.state.weeklyReset
hadData := false
if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset"); exists {
@@ -428,7 +483,9 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) {
}
if hadData {
c.state.consecutivePollFailures = 0
c.state.upstreamRejectedUntil = time.Time{}
c.state.lastUpdated = time.Now()
c.state.noteSnapshotData()
}
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
resetSuffix := ""
@@ -437,15 +494,23 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) {
}
c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
}
utilizationChanged := c.state.fiveHourUtilization != oldFiveHour || c.state.weeklyUtilization != oldWeekly
planWeightChanged := c.state.remotePlanWeight != oldPlanWeight
resetChanged := c.state.fiveHourReset != oldFiveHourReset || c.state.weeklyReset != oldWeeklyReset
shouldEmit := (hadData && (utilizationChanged || resetChanged)) || planWeightChanged
shouldInterrupt := c.checkTransitionLocked()
c.stateMutex.Unlock()
c.stateAccess.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
if shouldEmit {
c.emitStatusUpdate()
}
}
func (c *externalCredential) checkTransitionLocked() bool {
unusable := c.state.hardRateLimited || c.state.fiveHourUtilization >= 100 || c.state.weeklyUtilization >= 100 || c.state.consecutivePollFailures > 0
upstreamRejected := !c.state.upstreamRejectedUntil.IsZero() && time.Now().Before(c.state.upstreamRejectedUntil)
unusable := c.state.hardRateLimited || c.state.fiveHourUtilization >= 100 || c.state.weeklyUtilization >= 100 || c.state.consecutivePollFailures > 0 || upstreamRejected
if unusable && !c.interrupted {
c.interrupted = true
return true
@@ -465,9 +530,9 @@ func (c *externalCredential) wrapRequestContext(parent context.Context) *credent
cancel()
})
return &credentialRequestContext{
Context: derived,
releaseFunc: stop,
cancelFunc: cancel,
Context: derived,
releaseFuncs: []func() bool{stop},
cancelFunc: cancel,
}
}
@@ -477,32 +542,58 @@ func (c *externalCredential) interruptConnections() {
c.cancelRequests()
c.requestContext, c.cancelRequests = context.WithCancel(context.Background())
c.requestAccess.Unlock()
if c.onBecameUnusable != nil {
c.onBecameUnusable()
}
}
func (c *externalCredential) pollUsage(ctx context.Context) {
func (c *externalCredential) doPollUsageRequest(ctx context.Context) (*http.Response, error) {
buildRequest := func(baseURL string) func() (*http.Request, error) {
return func() (*http.Request, error) {
request, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/ccm/v1/status", nil)
if err != nil {
return nil, err
}
request.Header.Set("Authorization", "Bearer "+c.token)
return request, nil
}
}
// Try reverse transport first (single attempt, no retry)
if c.reverseHTTPClient != nil {
session := c.getReverseSession()
if session != nil && !session.IsClosed() {
request, err := buildRequest(reverseProxyBaseURL)()
if err != nil {
return nil, err
}
reverseClient := &http.Client{
Transport: c.reverseHTTPClient.Transport,
Timeout: 5 * time.Second,
}
response, err := reverseClient.Do(request)
if err == nil {
return response, nil
}
// Reverse failed, fall through to forward if available
}
}
// Forward transport with retries
if c.forwardHTTPClient != nil {
forwardClient := &http.Client{
Transport: c.forwardHTTPClient.Transport,
Timeout: 5 * time.Second,
}
return doHTTPWithRetry(ctx, forwardClient, buildRequest(c.baseURL))
}
return nil, E.New("no transport available")
}
func (c *externalCredential) pollUsage() {
if !c.pollAccess.TryLock() {
return
}
defer c.pollAccess.Unlock()
defer c.markUsagePollAttempted()
statusURL := c.baseURL + "/ccm/v1/status"
httpClient := &http.Client{
Transport: c.httpClient.Transport,
Timeout: 5 * time.Second,
}
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
request, err := http.NewRequestWithContext(ctx, http.MethodGet, statusURL, nil)
if err != nil {
return nil, err
}
request.Header.Set("Authorization", "Bearer "+c.token)
return request, nil
})
ctx := c.getReverseContext()
response, err := c.doPollUsageRequest(ctx)
if err != nil {
c.logger.Error("poll usage for ", c.tag, ": ", err)
c.incrementPollFailures()
@@ -512,42 +603,56 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
if response.StatusCode != http.StatusOK {
body, _ := io.ReadAll(response.Body)
c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
// 404 means the remote does not have a status endpoint yet;
// usage will be updated passively from response headers.
if response.StatusCode == http.StatusNotFound {
c.stateMutex.Lock()
c.state.consecutivePollFailures = 0
c.checkTransitionLocked()
c.stateMutex.Unlock()
} else {
c.incrementPollFailures()
}
return
}
var statusResponse struct {
FiveHourUtilization float64 `json:"five_hour_utilization"`
WeeklyUtilization float64 `json:"weekly_utilization"`
PlanWeight float64 `json:"plan_weight"`
}
err = json.NewDecoder(response.Body).Decode(&statusResponse)
if err != nil {
c.logger.Debug("poll usage for ", c.tag, ": decode: ", err)
c.logger.Error("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
c.incrementPollFailures()
return
}
c.stateMutex.Lock()
body, err := io.ReadAll(response.Body)
if err != nil {
c.logger.Error("poll usage for ", c.tag, ": read body: ", err)
c.incrementPollFailures()
return
}
var rawFields map[string]json.RawMessage
err = json.Unmarshal(body, &rawFields)
if err != nil {
c.logger.Error("poll usage for ", c.tag, ": decode: ", err)
c.incrementPollFailures()
return
}
if rawFields["five_hour_utilization"] == nil || rawFields["five_hour_reset"] == nil ||
rawFields["weekly_utilization"] == nil || rawFields["weekly_reset"] == nil ||
rawFields["plan_weight"] == nil {
c.logger.Error("poll usage for ", c.tag, ": invalid response")
c.incrementPollFailures()
return
}
var statusResponse statusPayload
err = json.Unmarshal(body, &statusResponse)
if err != nil {
c.logger.Error("poll usage for ", c.tag, ": decode: ", err)
c.incrementPollFailures()
return
}
c.stateAccess.Lock()
isFirstUpdate := c.state.lastUpdated.IsZero()
oldFiveHour := c.state.fiveHourUtilization
oldWeekly := c.state.weeklyUtilization
c.state.consecutivePollFailures = 0
c.state.upstreamRejectedUntil = time.Time{}
c.state.fiveHourUtilization = statusResponse.FiveHourUtilization
c.state.weeklyUtilization = statusResponse.WeeklyUtilization
if statusResponse.PlanWeight > 0 {
c.state.remotePlanWeight = statusResponse.PlanWeight
}
if statusResponse.FiveHourReset > 0 {
c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0)
}
if statusResponse.WeeklyReset > 0 {
c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0)
}
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
c.state.hardRateLimited = false
}
@@ -559,50 +664,228 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
}
shouldInterrupt := c.checkTransitionLocked()
c.stateMutex.Unlock()
c.stateAccess.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
c.emitStatusUpdate()
}
func (c *externalCredential) statusStreamLoop() {
var consecutiveFailures int
ctx := c.getReverseContext()
for {
select {
case <-ctx.Done():
return
default:
}
result, err := c.connectStatusStream(ctx)
if ctx.Err() != nil {
return
}
if !shouldRetryStatusStreamError(err) {
c.logger.Warn("status stream for ", c.tag, " disconnected: ", err, ", not retrying")
return
}
var backoff time.Duration
consecutiveFailures, backoff = c.nextStatusStreamBackoff(result, consecutiveFailures)
c.logger.Debug("status stream for ", c.tag, " disconnected: ", err, ", reconnecting in ", backoff)
timer := time.NewTimer(backoff)
select {
case <-timer.C:
case <-ctx.Done():
timer.Stop()
return
}
}
}
func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStreamResult, error) {
startTime := time.Now()
result := statusStreamResult{}
response, err := c.doStreamStatusRequest(ctx)
if err != nil {
result.duration = time.Since(startTime)
return result, err
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
body, _ := io.ReadAll(response.Body)
result.duration = time.Since(startTime)
return result, E.New("status ", response.StatusCode, " ", string(body))
}
decoder := json.NewDecoder(response.Body)
for {
var rawMessage json.RawMessage
err = decoder.Decode(&rawMessage)
if err != nil {
result.duration = time.Since(startTime)
return result, err
}
var rawFields map[string]json.RawMessage
err = json.Unmarshal(rawMessage, &rawFields)
if err != nil {
result.duration = time.Since(startTime)
return result, E.Cause(err, "decode status frame")
}
if rawFields["five_hour_utilization"] == nil || rawFields["five_hour_reset"] == nil ||
rawFields["weekly_utilization"] == nil || rawFields["weekly_reset"] == nil ||
rawFields["plan_weight"] == nil {
result.duration = time.Since(startTime)
return result, E.New("invalid response")
}
var statusResponse statusPayload
err = json.Unmarshal(rawMessage, &statusResponse)
if err != nil {
result.duration = time.Since(startTime)
return result, E.Cause(err, "decode status frame")
}
c.stateAccess.Lock()
isFirstUpdate := c.state.lastUpdated.IsZero()
oldFiveHour := c.state.fiveHourUtilization
oldWeekly := c.state.weeklyUtilization
c.state.consecutivePollFailures = 0
c.state.upstreamRejectedUntil = time.Time{}
c.state.fiveHourUtilization = statusResponse.FiveHourUtilization
c.state.weeklyUtilization = statusResponse.WeeklyUtilization
if statusResponse.PlanWeight > 0 {
c.state.remotePlanWeight = statusResponse.PlanWeight
}
if statusResponse.FiveHourReset > 0 {
c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0)
}
if statusResponse.WeeklyReset > 0 {
c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0)
}
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
c.state.hardRateLimited = false
}
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
resetSuffix := ""
if !c.state.weeklyReset.IsZero() {
resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset))
}
c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
}
shouldInterrupt := c.checkTransitionLocked()
c.stateAccess.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
result.frames++
c.markUsageStreamUpdated()
c.emitStatusUpdate()
}
}
func shouldRetryStatusStreamError(err error) bool {
return errors.Is(err, io.ErrUnexpectedEOF) || E.IsClosedOrCanceled(err)
}
func (c *externalCredential) nextStatusStreamBackoff(result statusStreamResult, consecutiveFailures int) (int, time.Duration) {
if result.duration >= connectorBackoffResetThreshold {
consecutiveFailures = 0
}
consecutiveFailures++
return consecutiveFailures, connectorBackoff(consecutiveFailures)
}
func (c *externalCredential) doStreamStatusRequest(ctx context.Context) (*http.Response, error) {
buildRequest := func(baseURL string) (*http.Request, error) {
request, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/ccm/v1/status?watch=true", nil)
if err != nil {
return nil, err
}
request.Header.Set("Authorization", "Bearer "+c.token)
return request, nil
}
if c.reverseHTTPClient != nil {
session := c.getReverseSession()
if session != nil && !session.IsClosed() {
request, err := buildRequest(reverseProxyBaseURL)
if err != nil {
return nil, err
}
response, err := c.reverseHTTPClient.Do(request)
if err == nil {
return response, nil
}
}
}
if c.forwardHTTPClient != nil {
request, err := buildRequest(c.baseURL)
if err != nil {
return nil, err
}
return c.forwardHTTPClient.Do(request)
}
return nil, E.New("no transport available")
}
func (c *externalCredential) lastUpdatedTime() time.Time {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.lastUpdated
}
func (c *externalCredential) hasSnapshotData() bool {
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.hasSnapshotData()
}
func (c *externalCredential) availabilityStatus() availabilityStatus {
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.currentAvailability()
}
func (c *externalCredential) markUsageStreamUpdated() {
c.stateAccess.Lock()
defer c.stateAccess.Unlock()
c.state.lastUpdated = time.Now()
}
func (c *externalCredential) markUsagePollAttempted() {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()
c.stateAccess.Lock()
defer c.stateAccess.Unlock()
c.state.lastUpdated = time.Now()
}
func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Duration {
c.stateMutex.RLock()
failures := c.state.consecutivePollFailures
c.stateMutex.RUnlock()
if failures <= 0 {
return baseInterval
}
return failedPollRetryInterval
return baseInterval
}
func (c *externalCredential) incrementPollFailures() {
c.stateMutex.Lock()
c.stateAccess.Lock()
c.state.consecutivePollFailures++
c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReasonPollFailed, time.Time{})
shouldInterrupt := c.checkTransitionLocked()
c.stateMutex.Unlock()
c.stateAccess.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
c.emitStatusUpdate()
}
func (c *externalCredential) usageTrackerOrNil() *AggregatedUsage {
return c.usageTracker
}
func (c *externalCredential) httpTransport() *http.Client {
return c.httpClient
func (c *externalCredential) httpClient() *http.Client {
if c.reverseHTTPClient != nil {
session := c.getReverseSession()
if session != nil && !session.IsClosed() {
return c.reverseHTTPClient
}
}
return c.forwardHTTPClient
}
func (c *externalCredential) close() {
@@ -636,26 +919,55 @@ func (c *externalCredential) getReverseSession() *yamux.Session {
}
func (c *externalCredential) setReverseSession(session *yamux.Session) bool {
var emitStatus bool
var restartStatusStream bool
var triggerUsageRefresh bool
c.reverseAccess.Lock()
if c.closed {
c.reverseAccess.Unlock()
return false
}
wasAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed()
old := c.reverseSession
c.reverseSession = session
isAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed()
emitStatus = wasAvailable != isAvailable
if isAvailable && !wasAvailable {
c.reverseCancel()
c.reverseContext, c.reverseCancel = context.WithCancel(context.Background())
restartStatusStream = true
triggerUsageRefresh = true
}
c.reverseAccess.Unlock()
if old != nil {
old.Close()
}
if restartStatusStream {
c.logger.Debug("poll usage for ", c.tag, ": reverse session ready, restarting status stream")
go c.statusStreamLoop()
}
if triggerUsageRefresh {
go c.pollUsage()
}
if emitStatus {
c.emitStatusUpdate()
}
return true
}
func (c *externalCredential) clearReverseSession(session *yamux.Session) {
var emitStatus bool
c.reverseAccess.Lock()
wasAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed()
if c.reverseSession == session {
c.reverseSession = nil
}
isAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed()
emitStatus = wasAvailable != isAvailable
c.reverseAccess.Unlock()
if emitStatus {
c.emitStatusUpdate()
}
}
func (c *externalCredential) getReverseContext() context.Context {

View File

@@ -62,10 +62,10 @@ func (c *defaultCredential) ensureCredentialWatcher() error {
}
func (c *defaultCredential) retryCredentialReloadIfNeeded() {
c.stateMutex.RLock()
c.stateAccess.RLock()
unavailable := c.state.unavailable
lastAttempt := c.state.lastCredentialLoadAttempt
c.stateMutex.RUnlock()
c.stateAccess.RUnlock()
if !unavailable {
return
}
@@ -75,7 +75,7 @@ func (c *defaultCredential) retryCredentialReloadIfNeeded() {
err := c.ensureCredentialWatcher()
if err != nil {
c.logger.Debug("start credential watcher for ", c.tag, ": ", err)
c.logger.Error("start credential watcher for ", c.tag, ": ", err)
}
_ = c.reloadCredentials(false)
}
@@ -84,10 +84,10 @@ func (c *defaultCredential) reloadCredentials(force bool) error {
c.reloadAccess.Lock()
defer c.reloadAccess.Unlock()
c.stateMutex.RLock()
c.stateAccess.RLock()
unavailable := c.state.unavailable
lastAttempt := c.state.lastCredentialLoadAttempt
c.stateMutex.RUnlock()
c.stateAccess.RUnlock()
if !force {
if !unavailable {
return nil
@@ -97,47 +97,41 @@ func (c *defaultCredential) reloadCredentials(force bool) error {
}
}
c.stateMutex.Lock()
c.stateAccess.Lock()
c.state.lastCredentialLoadAttempt = time.Now()
c.stateMutex.Unlock()
c.stateAccess.Unlock()
credentials, err := platformReadCredentials(c.credentialPath)
if err != nil {
return c.markCredentialsUnavailable(E.Cause(err, "read credentials"))
}
c.accessMutex.Lock()
c.credentials = credentials
c.accessMutex.Unlock()
c.stateMutex.Lock()
c.state.unavailable = false
c.state.lastCredentialLoadError = ""
c.state.accountType = credentials.SubscriptionType
c.state.rateLimitTier = credentials.RateLimitTier
c.checkTransitionLocked()
c.stateMutex.Unlock()
return nil
c.absorbCredentials(credentials)
return c.refreshCredentialsIfNeeded(false)
}
func (c *defaultCredential) markCredentialsUnavailable(err error) error {
c.accessMutex.Lock()
c.access.Lock()
hadCredentials := c.credentials != nil
c.credentials = nil
c.accessMutex.Unlock()
c.access.Unlock()
c.stateMutex.Lock()
c.stateAccess.Lock()
before := c.statusSnapshotLocked()
c.state.unavailable = true
c.state.lastCredentialLoadError = err.Error()
c.state.accountType = ""
c.state.rateLimitTier = ""
shouldInterrupt := c.checkTransitionLocked()
c.stateMutex.Unlock()
shouldEmit := before != c.statusSnapshotLocked()
c.stateAccess.Unlock()
if shouldInterrupt && hadCredentials {
c.interruptConnections()
}
if shouldEmit {
c.emitStatusUpdate()
}
return err
}

View File

@@ -0,0 +1,84 @@
package ccm
import (
"math/rand/v2"
"os"
"path/filepath"
"time"
E "github.com/sagernet/sing/common/exceptions"
)
// acquireCredentialLock acquires a cross-process lock compatible with Claude Code's
// proper-lockfile protocol. The lock is a directory created via mkdir (atomic on
// POSIX filesystems).
//
// ref (@anthropic-ai/claude-code @2.1.81): cli.js _P1 (line 179530-179577)
// ref: proper-lockfile mkdir protocol (cli.js:43570)
// ref: proper-lockfile default options — stale=10s, update=stale/2=5s, realpath=true (cli.js:43661-43664)
//
// Claude Code locks d1() (= ~/.claude config dir). The lock directory is
// <realpath(configDir)>.lock (proper-lockfile default: <path>.lock).
// Manual retry: initial + 5 retries = 6 total, delay 1+rand(1s) per retry.
func acquireCredentialLock(configDir string) (func(), error) {
// ref: cli.js _P1 line 179531 — mkdir -p configDir before locking
os.MkdirAll(configDir, 0o700)
// ref: proper-lockfile realpath:true (cli.js:43664) — resolve symlinks before appending .lock
resolved, err := filepath.EvalSymlinks(configDir)
if err != nil {
resolved = filepath.Clean(configDir)
}
lockPath := resolved + ".lock"
// ref: cli.js _P1 line 179539-179543 — initial + 5 retries = 6 total attempts
for attempt := 0; attempt < 6; attempt++ {
if attempt > 0 {
// ref: cli.js _P1 line 179542 — 1000 + Math.random() * 1000
delay := time.Second + time.Duration(rand.IntN(1000))*time.Millisecond
time.Sleep(delay)
}
err = os.Mkdir(lockPath, 0o755)
if err == nil {
return startLockHeartbeat(lockPath), nil
}
if !os.IsExist(err) {
return nil, E.Cause(err, "create lock directory")
}
// ref: proper-lockfile stale check (cli.js:43603-43604)
// stale threshold = 10s (cli.js:43662)
info, statErr := os.Stat(lockPath)
if statErr != nil {
continue
}
if time.Since(info.ModTime()) > 10*time.Second {
os.Remove(lockPath)
}
}
return nil, E.New("credential lock timeout")
}
// startLockHeartbeat spawns a goroutine that touches the lock directory's mtime
// every 5 seconds to prevent stale detection by other processes.
//
// ref: proper-lockfile update interval = stale/2 = 5s (cli.js:43662-43663)
//
// Returns a release function that stops the heartbeat and removes the lock directory.
func startLockHeartbeat(lockPath string) func() {
done := make(chan struct{})
go func() {
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
now := time.Now()
os.Chtimes(lockPath, now, now)
case <-done:
return
}
}
}()
return func() {
close(done)
os.Remove(lockPath)
}
}

View File

@@ -0,0 +1,327 @@
package ccm
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"os"
"os/user"
"path/filepath"
"runtime"
"slices"
"strconv"
"strings"
"sync"
"time"
"github.com/sagernet/sing-box/log"
E "github.com/sagernet/sing/common/exceptions"
)
const (
oauth2ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
oauth2TokenURL = "https://platform.claude.com/v1/oauth/token"
claudeAPIBaseURL = "https://api.anthropic.com"
anthropicBetaOAuthValue = "oauth-2025-04-20"
// ref (@anthropic-ai/claude-code @2.1.81): cli.js vB (line 172879)
tokenRefreshBufferMs = 300000
)
// ref (@anthropic-ai/claude-code @2.1.81): cli.js q78 (line 33167)
// These scopes may change across Claude Code versions.
var defaultOAuthScopes = []string{
"user:profile", "user:inference", "user:sessions:claude_code",
"user:mcp_servers", "user:file_upload",
}
// resolveRefreshScopes determines which scopes to send in the token refresh request.
//
// ref (@anthropic-ai/claude-code @2.1.81): cli.js NR() (line 172693) + mB6 scope logic (line 172761)
//
// Claude Code behavior: if stored scopes include "user:inference", send default
// scopes; otherwise send the stored scopes verbatim.
func resolveRefreshScopes(stored []string) string {
if len(stored) == 0 || slices.Contains(stored, "user:inference") {
return strings.Join(defaultOAuthScopes, " ")
}
return strings.Join(stored, " ")
}
const (
ccmRefreshUserAgent = "axios/1.13.6"
ccmUserAgentFallback = "claude-code/2.1.85"
)
var (
ccmUserAgentOnce sync.Once
ccmUserAgentValue string
)
func initCCMUserAgent(logger log.ContextLogger) {
ccmUserAgentOnce.Do(func() {
version, err := detectClaudeCodeVersion()
if err != nil {
logger.Error("detect Claude Code version: ", err)
ccmUserAgentValue = ccmUserAgentFallback
return
}
logger.Debug("detected Claude Code version: ", version)
ccmUserAgentValue = "claude-code/" + version
})
}
func detectClaudeCodeVersion() (string, error) {
userInfo, err := getRealUser()
if err != nil {
return "", E.Cause(err, "get user")
}
binaryName := "claude"
if runtime.GOOS == "windows" {
binaryName = "claude.exe"
}
linkPath := filepath.Join(userInfo.HomeDir, ".local", "bin", binaryName)
target, err := os.Readlink(linkPath)
if err != nil {
return "", E.Cause(err, "readlink ", linkPath)
}
if !filepath.IsAbs(target) {
target = filepath.Join(filepath.Dir(linkPath), target)
}
parent := filepath.Base(filepath.Dir(target))
if parent != "versions" {
return "", E.New("unexpected symlink target: ", target)
}
return filepath.Base(target), nil
}
// resolveConfigDir returns the Claude config directory for lock coordination.
//
// ref (@anthropic-ai/claude-code @2.1.81): cli.js d1() (line 2983) — config dir used for locking
func resolveConfigDir(credentialPath string, credentialFilePath string) string {
if credentialPath == "" {
if configDir := os.Getenv("CLAUDE_CONFIG_DIR"); configDir != "" {
return configDir
}
userInfo, err := getRealUser()
if err == nil {
return filepath.Join(userInfo.HomeDir, ".claude")
}
}
return filepath.Dir(credentialFilePath)
}
func getRealUser() (*user.User, error) {
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
sudoUserInfo, err := user.Lookup(sudoUser)
if err == nil {
return sudoUserInfo, nil
}
}
return user.Current()
}
func getDefaultCredentialsPath() (string, error) {
if configDir := os.Getenv("CLAUDE_CONFIG_DIR"); configDir != "" {
return filepath.Join(configDir, ".credentials.json"), nil
}
userInfo, err := getRealUser()
if err != nil {
return "", err
}
return filepath.Join(userInfo.HomeDir, ".claude", ".credentials.json"), nil
}
func readCredentialsFromFile(path string) (*oauthCredentials, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var credentialsContainer struct {
ClaudeAIAuth *oauthCredentials `json:"claudeAiOauth,omitempty"`
}
err = json.Unmarshal(data, &credentialsContainer)
if err != nil {
return nil, err
}
if credentialsContainer.ClaudeAIAuth == nil {
return nil, E.New("claudeAiOauth field not found in credentials")
}
return credentialsContainer.ClaudeAIAuth, nil
}
func checkCredentialFileWritable(path string) error {
file, err := os.OpenFile(path, os.O_WRONLY, 0)
if err != nil {
return err
}
return file.Close()
}
// writeCredentialsToFile performs a read-modify-write: reads the existing JSON,
// replaces only the claudeAiOauth key, and writes back. This preserves any
// other top-level keys in the credential file.
//
// ref (@anthropic-ai/claude-code @2.1.81): cli.js BP6 (line 179444-179454) — read-modify-write
// ref: cli.js qD1.update (line 176156) — writeFileSync + chmod 0o600
func writeCredentialsToFile(credentials *oauthCredentials, path string) error {
return writeStorageValue(jsonFileStorage{path: path}, "claudeAiOauth", credentials)
}
// oauthCredentials mirrors the claudeAiOauth object in Claude Code's
// credential file ($CLAUDE_CONFIG_DIR/.credentials.json).
//
// ref (@anthropic-ai/claude-code @2.1.81): cli.js BP6 (line 179446-179452)
type oauthCredentials struct {
AccessToken string `json:"accessToken"` // ref: cli.js line 179447
RefreshToken string `json:"refreshToken"` // ref: cli.js line 179448
ExpiresAt int64 `json:"expiresAt"` // ref: cli.js line 179449 (epoch ms)
Scopes []string `json:"scopes"` // ref: cli.js line 179450
SubscriptionType *string `json:"subscriptionType"` // ref: cli.js line 179451 (?? null)
RateLimitTier *string `json:"rateLimitTier"` // ref: cli.js line 179452 (?? null)
}
type oauthRefreshResult struct {
Credentials *oauthCredentials
TokenAccount *claudeOAuthAccount
Profile *claudeProfileSnapshot
}
func (c *oauthCredentials) needsRefresh() bool {
if c.ExpiresAt == 0 {
return false
}
return time.Now().UnixMilli() >= c.ExpiresAt-tokenRefreshBufferMs
}
func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthRefreshResult, time.Duration, error) {
if credentials.RefreshToken == "" {
return nil, 0, E.New("refresh token is empty")
}
// ref (@anthropic-ai/claude-code @2.1.81): cli.js mB6 (line 172757-172761)
requestBody, err := json.Marshal(map[string]string{
"grant_type": "refresh_token",
"refresh_token": credentials.RefreshToken,
"client_id": oauth2ClientID,
"scope": resolveRefreshScopes(credentials.Scopes),
})
if err != nil {
return nil, 0, E.Cause(err, "marshal request")
}
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
if err != nil {
return nil, err
}
request.Header.Set("Content-Type", "application/json")
request.Header.Set("User-Agent", ccmRefreshUserAgent)
return request, nil
})
if err != nil {
return nil, 0, err
}
defer response.Body.Close()
if response.StatusCode == http.StatusTooManyRequests {
body, _ := io.ReadAll(response.Body)
retryDelay := time.Duration(-1)
if retryAfter := response.Header.Get("Retry-After"); retryAfter != "" {
seconds, parseErr := strconv.ParseInt(retryAfter, 10, 64)
if parseErr == nil && seconds > 0 {
retryDelay = time.Duration(seconds) * time.Second
}
}
return nil, retryDelay, E.New("refresh rate limited: ", response.Status, " ", string(body))
}
if response.StatusCode != http.StatusOK {
body, _ := io.ReadAll(response.Body)
return nil, 0, E.New("refresh failed: ", response.Status, " ", string(body))
}
// ref (@anthropic-ai/claude-code @2.1.81): cli.js mB6 response (line 172769-172772)
var tokenResponse struct {
AccessToken string `json:"access_token"` // ref: cli.js line 172770 z
RefreshToken string `json:"refresh_token"` // ref: cli.js line 172770 w (defaults to input)
ExpiresIn int `json:"expires_in"` // ref: cli.js line 172770 O
Scope *string `json:"scope"` // ref: cli.js line 172772 uB6(Y.scope)
Account *struct {
UUID string `json:"uuid"`
EmailAddress string `json:"email_address"`
} `json:"account"`
Organization *struct {
UUID string `json:"uuid"`
} `json:"organization"`
}
err = json.NewDecoder(response.Body).Decode(&tokenResponse)
if err != nil {
return nil, 0, E.Cause(err, "decode response")
}
newCredentials := *credentials
newCredentials.AccessToken = tokenResponse.AccessToken
if tokenResponse.RefreshToken != "" {
newCredentials.RefreshToken = tokenResponse.RefreshToken
}
newCredentials.ExpiresAt = time.Now().UnixMilli() + int64(tokenResponse.ExpiresIn)*1000
// ref: cli.js uB6 (line 172696-172697): A?.split(" ").filter(Boolean)
// strings.Fields matches .filter(Boolean): splits on whitespace runs, removes empty strings
if tokenResponse.Scope != nil {
newCredentials.Scopes = strings.Fields(*tokenResponse.Scope)
}
return &oauthRefreshResult{
Credentials: &newCredentials,
TokenAccount: extractTokenAccount(tokenResponse.Account, tokenResponse.Organization),
}, 0, nil
}
func cloneCredentials(credentials *oauthCredentials) *oauthCredentials {
if credentials == nil {
return nil
}
cloned := *credentials
cloned.Scopes = append([]string(nil), credentials.Scopes...)
cloned.SubscriptionType = cloneStringPointer(credentials.SubscriptionType)
cloned.RateLimitTier = cloneStringPointer(credentials.RateLimitTier)
return &cloned
}
func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool {
if left == nil || right == nil {
return left == right
}
return left.AccessToken == right.AccessToken &&
left.RefreshToken == right.RefreshToken &&
left.ExpiresAt == right.ExpiresAt &&
slices.Equal(left.Scopes, right.Scopes) &&
equalStringPointer(left.SubscriptionType, right.SubscriptionType) &&
equalStringPointer(left.RateLimitTier, right.RateLimitTier)
}
func extractTokenAccount(account *struct {
UUID string `json:"uuid"`
EmailAddress string `json:"email_address"`
}, organization *struct {
UUID string `json:"uuid"`
},
) *claudeOAuthAccount {
if account == nil && organization == nil {
return nil
}
tokenAccount := &claudeOAuthAccount{}
if account != nil {
tokenAccount.AccountUUID = account.UUID
tokenAccount.EmailAddress = account.EmailAddress
}
if organization != nil {
tokenAccount.OrganizationUUID = organization.UUID
}
if tokenAccount.AccountUUID == "" && tokenAccount.EmailAddress == "" && tokenAccount.OrganizationUUID == "" {
return nil
}
return tokenAccount
}

View File

@@ -0,0 +1,141 @@
package ccm
import (
"context"
"encoding/json"
"io"
"net/http"
"slices"
"strings"
"testing"
"time"
)
func TestRefreshTokenScopeParsing(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
storedScopes []string
responseBody string
expectedScope string
expected []string
}{
{
name: "missing scope preserves stored scopes",
storedScopes: []string{"user:profile", "user:inference"},
responseBody: `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`,
expectedScope: strings.Join(defaultOAuthScopes, " "),
expected: []string{"user:profile", "user:inference"},
},
{
name: "empty scope clears stored scopes",
storedScopes: []string{"user:profile", "user:inference"},
responseBody: `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600,"scope":""}`,
expectedScope: strings.Join(defaultOAuthScopes, " "),
expected: []string{},
},
{
name: "stored non inference scopes are sent verbatim",
storedScopes: []string{"user:profile"},
responseBody: `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600,"scope":"user:profile user:file_upload"}`,
expectedScope: "user:profile",
expected: []string{"user:profile", "user:file_upload"},
},
}
for _, testCase := range testCases {
testCase := testCase
t.Run(testCase.name, func(t *testing.T) {
t.Parallel()
var seenScope string
client := &http.Client{Transport: roundTripFunc(func(request *http.Request) (*http.Response, error) {
body, err := io.ReadAll(request.Body)
if err != nil {
t.Fatal(err)
}
var payload map[string]string
if err := json.Unmarshal(body, &payload); err != nil {
t.Fatal(err)
}
seenScope = payload["scope"]
return newJSONResponse(http.StatusOK, testCase.responseBody), nil
})}
result, _, err := refreshToken(context.Background(), client, &oauthCredentials{
AccessToken: "old-token",
RefreshToken: "refresh-token",
ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(),
Scopes: testCase.storedScopes,
})
if err != nil {
t.Fatal(err)
}
if seenScope != testCase.expectedScope {
t.Fatalf("expected request scope %q, got %q", testCase.expectedScope, seenScope)
}
if result == nil || result.Credentials == nil {
t.Fatal("expected refresh result credentials")
}
if !slices.Equal(result.Credentials.Scopes, testCase.expected) {
t.Fatalf("expected scopes %v, got %v", testCase.expected, result.Credentials.Scopes)
}
})
}
}
func TestRefreshTokenExtractsTokenAccount(t *testing.T) {
t.Parallel()
client := &http.Client{Transport: roundTripFunc(func(request *http.Request) (*http.Response, error) {
return newJSONResponse(http.StatusOK, `{
"access_token":"new-token",
"refresh_token":"new-refresh",
"expires_in":3600,
"account":{"uuid":"account","email_address":"user@example.com"},
"organization":{"uuid":"org"}
}`), nil
})}
result, _, err := refreshToken(context.Background(), client, &oauthCredentials{
AccessToken: "old-token",
RefreshToken: "refresh-token",
ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(),
Scopes: []string{"user:profile", "user:inference"},
})
if err != nil {
t.Fatal(err)
}
if result == nil || result.TokenAccount == nil {
t.Fatal("expected token account")
}
if result.TokenAccount.AccountUUID != "account" || result.TokenAccount.EmailAddress != "user@example.com" || result.TokenAccount.OrganizationUUID != "org" {
t.Fatalf("unexpected token account: %#v", result.TokenAccount)
}
}
func TestCredentialsEqualIncludesProfileFields(t *testing.T) {
t.Parallel()
subscriptionType := "max"
rateLimitTier := "default_claude_max_20x"
left := &oauthCredentials{
AccessToken: "token",
RefreshToken: "refresh",
ExpiresAt: 123,
Scopes: []string{"user:inference"},
SubscriptionType: &subscriptionType,
RateLimitTier: &rateLimitTier,
}
right := cloneCredentials(left)
if !credentialsEqual(left, right) {
t.Fatal("expected cloned credentials to be equal")
}
otherTier := "default_claude_max_5x"
right.RateLimitTier = &otherTier
if credentialsEqual(left, right) {
t.Fatal("expected different rate limit tier to break equality")
}
}

View File

@@ -0,0 +1,440 @@
package ccm
import (
"context"
"math/rand/v2"
"sync"
"sync/atomic"
"time"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
E "github.com/sagernet/sing/common/exceptions"
)
type credentialProvider interface {
selectCredential(sessionID string, selection credentialSelection) (Credential, bool, error)
onRateLimited(sessionID string, credential Credential, resetAt time.Time, selection credentialSelection) Credential
linkProviderInterrupt(credential Credential, selection credentialSelection, onInterrupt func()) func() bool
pollIfStale()
pollCredentialIfStale(credential Credential)
allCredentials() []Credential
close()
}
type singleCredentialProvider struct {
credential Credential
sessionAccess sync.RWMutex
sessions map[string]time.Time
}
func (p *singleCredentialProvider) selectCredential(sessionID string, selection credentialSelection) (Credential, bool, error) {
if !selection.allows(p.credential) {
return nil, false, E.New("credential ", p.credential.tagName(), " is filtered out")
}
if !p.credential.isAvailable() {
return nil, false, p.credential.unavailableError()
}
if !p.credential.isUsable() {
return nil, false, E.New("credential ", p.credential.tagName(), " is rate-limited")
}
var isNew bool
if sessionID != "" {
p.sessionAccess.Lock()
if p.sessions == nil {
p.sessions = make(map[string]time.Time)
}
_, exists := p.sessions[sessionID]
if !exists {
p.sessions[sessionID] = time.Now()
isNew = true
}
p.sessionAccess.Unlock()
}
return p.credential, isNew, nil
}
func (p *singleCredentialProvider) onRateLimited(_ string, credential Credential, resetAt time.Time, _ credentialSelection) Credential {
credential.markRateLimited(resetAt)
return nil
}
func (p *singleCredentialProvider) pollIfStale() {
now := time.Now()
p.sessionAccess.Lock()
for id, createdAt := range p.sessions {
if now.Sub(createdAt) > sessionExpiry {
delete(p.sessions, id)
}
}
p.sessionAccess.Unlock()
if time.Since(p.credential.lastUpdatedTime()) > p.credential.pollBackoff(defaultPollInterval) {
p.credential.pollUsage()
}
}
func (p *singleCredentialProvider) allCredentials() []Credential {
return []Credential{p.credential}
}
func (p *singleCredentialProvider) linkProviderInterrupt(_ Credential, _ credentialSelection, _ func()) func() bool {
return func() bool {
return false
}
}
func (p *singleCredentialProvider) pollCredentialIfStale(credential Credential) {
if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(defaultPollInterval) {
credential.pollUsage()
}
}
func (p *singleCredentialProvider) close() {}
type sessionEntry struct {
tag string
selectionScope credentialSelectionScope
createdAt time.Time
}
type credentialInterruptKey struct {
tag string
selectionScope credentialSelectionScope
}
type credentialInterruptEntry struct {
context context.Context
cancel context.CancelFunc
}
type balancerProvider struct {
credentials []Credential
strategy string
roundRobinIndex atomic.Uint64
rebalanceThreshold float64
sessionAccess sync.RWMutex
sessions map[string]sessionEntry
interruptAccess sync.Mutex
credentialInterrupts map[credentialInterruptKey]credentialInterruptEntry
logger log.ContextLogger
}
func newBalancerProvider(credentials []Credential, strategy string, rebalanceThreshold float64, logger log.ContextLogger) *balancerProvider {
return &balancerProvider{
credentials: credentials,
strategy: strategy,
rebalanceThreshold: rebalanceThreshold,
sessions: make(map[string]sessionEntry),
credentialInterrupts: make(map[credentialInterruptKey]credentialInterruptEntry),
logger: logger,
}
}
func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (Credential, bool, error) {
selectionScope := selection.scopeOrDefault()
for {
if p.strategy == C.BalancerStrategyFallback {
best := p.pickCredential(selection.filter)
if best == nil {
return nil, false, allCredentialsUnavailableError(p.credentials)
}
return best, p.storeSessionIfAbsent(sessionID, sessionEntry{createdAt: time.Now()}), nil
}
if sessionID != "" {
p.sessionAccess.RLock()
entry, exists := p.sessions[sessionID]
p.sessionAccess.RUnlock()
if exists {
if entry.selectionScope == selectionScope {
for _, credential := range p.credentials {
if credential.tagName() == entry.tag && selection.allows(credential) && credential.isUsable() {
if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == C.BalancerStrategyLeastUsed) {
better := p.pickLeastUsed(selection.filter)
if better != nil && better.tagName() != credential.tagName() {
effectiveThreshold := p.rebalanceThreshold / credential.planWeight()
delta := credential.weeklyUtilization() - better.weeklyUtilization()
if delta > effectiveThreshold {
p.logger.Info("rebalancing away from ", credential.tagName(),
": utilization delta ", delta, "% exceeds effective threshold ",
effectiveThreshold, "% (weight ", credential.planWeight(), ")")
p.rebalanceCredential(credential.tagName(), selectionScope)
break
}
}
}
return credential, false, nil
}
}
}
p.sessionAccess.Lock()
currentEntry, stillExists := p.sessions[sessionID]
if stillExists && currentEntry == entry {
delete(p.sessions, sessionID)
p.sessionAccess.Unlock()
} else {
p.sessionAccess.Unlock()
continue
}
}
}
best := p.pickCredential(selection.filter)
if best == nil {
return nil, false, allCredentialsUnavailableError(p.credentials)
}
if p.storeSessionIfAbsent(sessionID, sessionEntry{
tag: best.tagName(),
selectionScope: selectionScope,
createdAt: time.Now(),
}) {
return best, true, nil
}
if sessionID == "" {
return best, false, nil
}
}
}
func (p *balancerProvider) storeSessionIfAbsent(sessionID string, entry sessionEntry) bool {
if sessionID == "" {
return false
}
p.sessionAccess.Lock()
defer p.sessionAccess.Unlock()
if _, exists := p.sessions[sessionID]; exists {
return false
}
p.sessions[sessionID] = entry
return true
}
func (p *balancerProvider) rebalanceCredential(tag string, selectionScope credentialSelectionScope) {
key := credentialInterruptKey{tag: tag, selectionScope: selectionScope}
p.interruptAccess.Lock()
if entry, loaded := p.credentialInterrupts[key]; loaded {
entry.cancel()
}
ctx, cancel := context.WithCancel(context.Background())
p.credentialInterrupts[key] = credentialInterruptEntry{context: ctx, cancel: cancel}
p.interruptAccess.Unlock()
p.sessionAccess.Lock()
for id, entry := range p.sessions {
if entry.tag == tag && entry.selectionScope == selectionScope {
delete(p.sessions, id)
}
}
p.sessionAccess.Unlock()
}
func (p *balancerProvider) linkProviderInterrupt(credential Credential, selection credentialSelection, onInterrupt func()) func() bool {
if p.strategy == C.BalancerStrategyFallback {
return func() bool { return false }
}
key := credentialInterruptKey{
tag: credential.tagName(),
selectionScope: selection.scopeOrDefault(),
}
p.interruptAccess.Lock()
entry, loaded := p.credentialInterrupts[key]
if !loaded {
ctx, cancel := context.WithCancel(context.Background())
entry = credentialInterruptEntry{context: ctx, cancel: cancel}
p.credentialInterrupts[key] = entry
}
p.interruptAccess.Unlock()
return context.AfterFunc(entry.context, onInterrupt)
}
func (p *balancerProvider) onRateLimited(sessionID string, credential Credential, resetAt time.Time, selection credentialSelection) Credential {
credential.markRateLimited(resetAt)
if p.strategy == C.BalancerStrategyFallback {
return p.pickCredential(selection.filter)
}
if sessionID != "" {
p.sessionAccess.Lock()
delete(p.sessions, sessionID)
p.sessionAccess.Unlock()
}
best := p.pickCredential(selection.filter)
if best != nil && sessionID != "" {
p.sessionAccess.Lock()
p.sessions[sessionID] = sessionEntry{
tag: best.tagName(),
selectionScope: selection.scopeOrDefault(),
createdAt: time.Now(),
}
p.sessionAccess.Unlock()
}
return best
}
func (p *balancerProvider) pickCredential(filter func(Credential) bool) Credential {
switch p.strategy {
case C.BalancerStrategyRoundRobin:
return p.pickRoundRobin(filter)
case C.BalancerStrategyRandom:
return p.pickRandom(filter)
case C.BalancerStrategyFallback:
return p.pickFallback(filter)
default:
return p.pickLeastUsed(filter)
}
}
func (p *balancerProvider) pickFallback(filter func(Credential) bool) Credential {
for _, credential := range p.credentials {
if filter != nil && !filter(credential) {
continue
}
if credential.isUsable() {
return credential
}
}
return nil
}
const weeklyWindowHours = 7 * 24
func (p *balancerProvider) pickLeastUsed(filter func(Credential) bool) Credential {
var best Credential
bestScore := float64(-1)
now := time.Now()
for _, credential := range p.credentials {
if filter != nil && !filter(credential) {
continue
}
if !credential.isUsable() {
continue
}
remaining := credential.weeklyCap() - credential.weeklyUtilization()
score := remaining * credential.planWeight()
resetTime := credential.weeklyResetTime()
if !resetTime.IsZero() {
timeUntilReset := resetTime.Sub(now)
if timeUntilReset < time.Hour {
timeUntilReset = time.Hour
}
score *= weeklyWindowHours / timeUntilReset.Hours()
}
if score > bestScore {
bestScore = score
best = credential
}
}
return best
}
func (p *balancerProvider) pickRoundRobin(filter func(Credential) bool) Credential {
start := int(p.roundRobinIndex.Add(1) - 1)
count := len(p.credentials)
for offset := range count {
candidate := p.credentials[(start+offset)%count]
if filter != nil && !filter(candidate) {
continue
}
if candidate.isUsable() {
return candidate
}
}
return nil
}
func (p *balancerProvider) pickRandom(filter func(Credential) bool) Credential {
var usable []Credential
for _, candidate := range p.credentials {
if filter != nil && !filter(candidate) {
continue
}
if candidate.isUsable() {
usable = append(usable, candidate)
}
}
if len(usable) == 0 {
return nil
}
return usable[rand.IntN(len(usable))]
}
func (p *balancerProvider) pollIfStale() {
now := time.Now()
p.sessionAccess.Lock()
for id, entry := range p.sessions {
if now.Sub(entry.createdAt) > sessionExpiry {
delete(p.sessions, id)
}
}
p.sessionAccess.Unlock()
p.interruptAccess.Lock()
for key, entry := range p.credentialInterrupts {
if entry.context.Err() != nil {
delete(p.credentialInterrupts, key)
}
}
p.interruptAccess.Unlock()
for _, credential := range p.credentials {
if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(defaultPollInterval) {
credential.pollUsage()
}
}
}
func (p *balancerProvider) pollCredentialIfStale(credential Credential) {
if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(defaultPollInterval) {
credential.pollUsage()
}
}
func (p *balancerProvider) allCredentials() []Credential {
return p.credentials
}
func (p *balancerProvider) close() {}
func ccmPlanWeight(accountType string, rateLimitTier string) float64 {
switch accountType {
case "max":
switch rateLimitTier {
case "default_claude_max_20x":
return 10
case "default_claude_max_5x":
return 5
default:
return 5
}
case "team":
if rateLimitTier == "default_claude_max_5x" {
return 5
}
return 1
default:
return 1
}
}
func allCredentialsUnavailableError(credentials []Credential) error {
var hasUnavailable bool
var earliest time.Time
for _, credential := range credentials {
if credential.unavailableError() != nil {
hasUnavailable = true
continue
}
resetAt := credential.earliestReset()
if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) {
earliest = resetAt
}
}
if hasUnavailable {
return E.New("all credentials unavailable")
}
if earliest.IsZero() {
return E.New("all credentials rate-limited")
}
return E.New("all credentials rate-limited, earliest reset in ", log.FormatDuration(time.Until(earliest)))
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,124 @@
package ccm
import (
"encoding/json"
"errors"
"os"
"path/filepath"
)
type jsonContainerStorage interface {
readContainer() (map[string]json.RawMessage, bool, error)
writeContainer(map[string]json.RawMessage) error
delete() error
}
type jsonFileStorage struct {
path string
}
func (s jsonFileStorage) readContainer() (map[string]json.RawMessage, bool, error) {
data, err := os.ReadFile(s.path)
if err != nil {
if os.IsNotExist(err) {
return make(map[string]json.RawMessage), false, nil
}
return nil, false, err
}
container := make(map[string]json.RawMessage)
if len(data) == 0 {
return container, true, nil
}
if err := json.Unmarshal(data, &container); err != nil {
return nil, true, err
}
return container, true, nil
}
func (s jsonFileStorage) writeContainer(container map[string]json.RawMessage) error {
if err := os.MkdirAll(filepath.Dir(s.path), 0o700); err != nil {
return err
}
data, err := json.MarshalIndent(container, "", " ")
if err != nil {
return err
}
return os.WriteFile(s.path, data, 0o600)
}
func (s jsonFileStorage) delete() error {
err := os.Remove(s.path)
if err != nil && !os.IsNotExist(err) {
return err
}
return nil
}
func writeStorageValue(storage jsonContainerStorage, key string, value any) error {
container, _, err := storage.readContainer()
if err != nil {
var syntaxError *json.SyntaxError
var typeError *json.UnmarshalTypeError
if !errors.As(err, &syntaxError) && !errors.As(err, &typeError) {
return err
}
container = make(map[string]json.RawMessage)
}
if container == nil {
container = make(map[string]json.RawMessage)
}
encodedValue, err := json.Marshal(value)
if err != nil {
return err
}
container[key] = encodedValue
return storage.writeContainer(container)
}
func persistStorageValue(primary jsonContainerStorage, fallback jsonContainerStorage, key string, value any) error {
primaryErr := writeStorageValue(primary, key, value)
if primaryErr == nil {
if fallback != nil {
_ = fallback.delete()
}
return nil
}
if fallback == nil {
return primaryErr
}
if err := writeStorageValue(fallback, key, value); err != nil {
return err
}
_ = primary.delete()
return nil
}
func cloneStringPointer(value *string) *string {
if value == nil {
return nil
}
cloned := *value
return &cloned
}
func cloneBoolPointer(value *bool) *bool {
if value == nil {
return nil
}
cloned := *value
return &cloned
}
func equalStringPointer(left *string, right *string) bool {
if left == nil || right == nil {
return left == right
}
return *left == *right
}
func equalBoolPointer(left *bool, right *bool) bool {
if left == nil || right == nil {
return left == right
}
return *left == *right
}

View File

@@ -0,0 +1,125 @@
package ccm
import (
"encoding/json"
"os"
"path/filepath"
"testing"
)
type fakeJSONStorage struct {
container map[string]json.RawMessage
writeErr error
deleted bool
}
func (s *fakeJSONStorage) readContainer() (map[string]json.RawMessage, bool, error) {
if s.container == nil {
return make(map[string]json.RawMessage), false, nil
}
cloned := make(map[string]json.RawMessage, len(s.container))
for key, value := range s.container {
cloned[key] = value
}
return cloned, true, nil
}
func (s *fakeJSONStorage) writeContainer(container map[string]json.RawMessage) error {
if s.writeErr != nil {
return s.writeErr
}
s.container = make(map[string]json.RawMessage, len(container))
for key, value := range container {
s.container[key] = value
}
return nil
}
func (s *fakeJSONStorage) delete() error {
s.deleted = true
s.container = nil
return nil
}
func TestPersistStorageValueDeletesFallbackOnPrimarySuccess(t *testing.T) {
t.Parallel()
primary := &fakeJSONStorage{}
fallback := &fakeJSONStorage{container: map[string]json.RawMessage{"stale": json.RawMessage(`true`)}}
if err := persistStorageValue(primary, fallback, "claudeAiOauth", &oauthCredentials{AccessToken: "token"}); err != nil {
t.Fatal(err)
}
if !fallback.deleted {
t.Fatal("expected fallback storage to be deleted after primary write")
}
}
func TestPersistStorageValueDeletesPrimaryAfterFallbackSuccess(t *testing.T) {
t.Parallel()
primary := &fakeJSONStorage{
container: map[string]json.RawMessage{"claudeAiOauth": json.RawMessage(`{"accessToken":"old"}`)},
writeErr: os.ErrPermission,
}
fallback := &fakeJSONStorage{}
if err := persistStorageValue(primary, fallback, "claudeAiOauth", &oauthCredentials{AccessToken: "new"}); err != nil {
t.Fatal(err)
}
if !primary.deleted {
t.Fatal("expected primary storage to be deleted after fallback write")
}
}
func TestWriteCredentialsToFilePreservesTopLevelKeys(t *testing.T) {
t.Parallel()
directory := t.TempDir()
path := filepath.Join(directory, ".credentials.json")
initial := []byte(`{"keep":{"nested":true},"claudeAiOauth":{"accessToken":"old"}}`)
if err := os.WriteFile(path, initial, 0o600); err != nil {
t.Fatal(err)
}
if err := writeCredentialsToFile(&oauthCredentials{AccessToken: "new"}, path); err != nil {
t.Fatal(err)
}
data, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
var container map[string]json.RawMessage
if err := json.Unmarshal(data, &container); err != nil {
t.Fatal(err)
}
if _, exists := container["keep"]; !exists {
t.Fatal("expected unknown top-level key to be preserved")
}
}
func TestWriteClaudeCodeOAuthAccountPreservesTopLevelKeys(t *testing.T) {
t.Parallel()
directory := t.TempDir()
path := filepath.Join(directory, ".claude.json")
initial := []byte(`{"keep":{"nested":true},"oauthAccount":{"accountUuid":"old"}}`)
if err := os.WriteFile(path, initial, 0o600); err != nil {
t.Fatal(err)
}
if err := writeClaudeCodeOAuthAccount(path, &claudeOAuthAccount{AccountUUID: "new"}); err != nil {
t.Fatal(err)
}
data, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
var container map[string]json.RawMessage
if err := json.Unmarshal(data, &container); err != nil {
t.Fatal(err)
}
if _, exists := container["keep"]; !exists {
t.Fatal("expected unknown config key to be preserved")
}
}

View File

@@ -0,0 +1,76 @@
package ccm
import "time"
type availabilityState string
const (
availabilityStateUsable availabilityState = "usable"
availabilityStateRateLimited availabilityState = "rate_limited"
availabilityStateTemporarilyBlocked availabilityState = "temporarily_blocked"
availabilityStateUnavailable availabilityState = "unavailable"
availabilityStateUnknown availabilityState = "unknown"
)
type availabilityReason string
const (
availabilityReasonHardRateLimit availabilityReason = "hard_rate_limit"
availabilityReasonConnectionLimit availabilityReason = "connection_limit"
availabilityReasonPollFailed availabilityReason = "poll_failed"
availabilityReasonUpstreamRejected availabilityReason = "upstream_rejected"
availabilityReasonNoCredentials availabilityReason = "no_credentials"
availabilityReasonUnknown availabilityReason = "unknown"
)
type availabilityStatus struct {
State availabilityState
Reason availabilityReason
ResetAt time.Time
}
func (s availabilityStatus) normalized() availabilityStatus {
if s.State == "" {
s.State = availabilityStateUnknown
}
if s.Reason == "" && s.State != availabilityStateUsable {
s.Reason = availabilityReasonUnknown
}
return s
}
func claudeWindowProgress(resetAt time.Time, windowSeconds float64, now time.Time) float64 {
if resetAt.IsZero() || windowSeconds <= 0 {
return 0
}
windowStart := resetAt.Add(-time.Duration(windowSeconds * float64(time.Second)))
if now.Before(windowStart) {
return 0
}
progress := now.Sub(windowStart).Seconds() / windowSeconds
if progress < 0 {
return 0
}
if progress > 1 {
return 1
}
return progress
}
func claudeFiveHourWarning(utilizationPercent float64, resetAt time.Time, now time.Time) bool {
return utilizationPercent >= 90 && claudeWindowProgress(resetAt, 5*60*60, now) >= 0.72
}
func claudeWeeklyWarning(utilizationPercent float64, resetAt time.Time, now time.Time) bool {
progress := claudeWindowProgress(resetAt, 7*24*60*60, now)
switch {
case utilizationPercent >= 75:
return progress >= 0.60
case utilizationPercent >= 50:
return progress >= 0.35
case utilizationPercent >= 25:
return progress >= 0.15
default:
return false
}
}

View File

@@ -4,7 +4,6 @@ import (
"bufio"
"context"
stdTLS "crypto/tls"
"errors"
"io"
"math/rand/v2"
"net"
@@ -18,14 +17,14 @@ import (
"github.com/hashicorp/yamux"
)
func reverseYamuxConfig() *yamux.Config {
var defaultYamuxConfig = func() *yamux.Config {
config := yamux.DefaultConfig()
config.KeepAliveInterval = 15 * time.Second
config.ConnectionWriteTimeout = 10 * time.Second
config.MaxStreamWindowSize = 512 * 1024
config.LogOutput = io.Discard
return config
}
}()
type bufferedConn struct {
reader *bufio.Reader
@@ -58,6 +57,12 @@ func (s *Service) handleReverseConnect(ctx context.Context, w http.ResponseWrite
return
}
if r.Header.Get("X-Api-Key") != "" || r.Header.Get("Api-Key") != "" {
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
"API key authentication is not supported; use Authorization: Bearer with a CCM user token")
return
}
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
@@ -103,7 +108,7 @@ func (s *Service) handleReverseConnect(ctx context.Context, w http.ResponseWrite
return
}
session, err := yamux.Client(conn, reverseYamuxConfig())
session, err := yamux.Client(conn, defaultYamuxConfig)
if err != nil {
conn.Close()
s.logger.ErrorContext(ctx, "reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err)
@@ -124,13 +129,13 @@ func (s *Service) handleReverseConnect(ctx context.Context, w http.ResponseWrite
}
func (s *Service) findReceiverCredential(token string) *externalCredential {
for _, cred := range s.allCredentials {
extCred, ok := cred.(*externalCredential)
if !ok {
for _, credential := range s.allCredentials {
external, ok := credential.(*externalCredential)
if !ok || external.connectorURL != nil {
continue
}
if extCred.baseURL == reverseProxyBaseURL && extCred.token == token {
return extCred
if external.token == token {
return external
}
}
return nil
@@ -156,9 +161,11 @@ func (c *externalCredential) connectorLoop() {
consecutiveFailures++
backoff := connectorBackoff(consecutiveFailures)
c.logger.Warn("reverse connection for ", c.tag, " lost: ", err, ", reconnecting in ", backoff)
timer := time.NewTimer(backoff)
select {
case <-time.After(backoff):
case <-timer.C:
case <-ctx.Done():
timer.Stop()
return
}
}
@@ -231,7 +238,7 @@ func (c *externalCredential) connectorConnect(ctx context.Context) (time.Duratio
}
}
session, err := yamux.Server(&bufferedConn{reader: reader, Conn: conn}, reverseYamuxConfig())
session, err := yamux.Server(&bufferedConn{reader: reader, Conn: conn}, defaultYamuxConfig)
if err != nil {
conn.Close()
return 0, E.Cause(err, "create yamux server")
@@ -248,7 +255,7 @@ func (c *externalCredential) connectorConnect(ctx context.Context) (time.Duratio
}
err = httpServer.Serve(&yamuxNetListener{session: session})
sessionLifetime := time.Since(serveStart)
if err != nil && !errors.Is(err, http.ErrServerClosed) && ctx.Err() == nil {
if err != nil && !E.IsClosed(err) && ctx.Err() == nil {
return sessionLifetime, E.Cause(err, "serve")
}
return sessionLifetime, E.New("connection closed")

View File

@@ -1,14 +1,9 @@
package ccm
import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"mime"
"net/http"
"strconv"
"strings"
"sync"
"time"
@@ -21,45 +16,30 @@ import (
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/observable"
aTLS "github.com/sagernet/sing/common/tls"
"github.com/anthropics/anthropic-sdk-go"
anthropicconstant "github.com/anthropics/anthropic-sdk-go/shared/constant"
"github.com/go-chi/chi/v5"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
)
const (
contextWindowStandard = 200000
contextWindowPremium = 1000000
premiumContextThreshold = 200000
retryableUsageMessage = "current credential reached its usage limit; retry the request to use another credential"
)
const retryableUsageMessage = "current credential reached its usage limit; retry the request to use another credential"
func RegisterService(registry *boxService.Registry) {
boxService.Register[option.CCMServiceOptions](registry, C.TypeCCM, NewService)
}
type errorResponse struct {
Type string `json:"type"`
Error errorDetails `json:"error"`
RequestID string `json:"request_id,omitempty"`
}
type errorDetails struct {
Type string `json:"type"`
Message string `json:"message"`
}
func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, errorType string, message string) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
json.NewEncoder(w).Encode(errorResponse{
Type: "error",
Error: errorDetails{
json.NewEncoder(w).Encode(anthropic.ErrorResponse{
Type: anthropicconstant.Error("").Default(),
Error: anthropic.ErrorObjectUnion{
Type: errorType,
Message: message,
},
@@ -67,18 +47,18 @@ func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, erro
})
}
func hasAlternativeCredential(provider credentialProvider, currentCredential credential, filter func(credential) bool) bool {
func hasAlternativeCredential(provider credentialProvider, currentCredential Credential, selection credentialSelection) bool {
if provider == nil || currentCredential == nil {
return false
}
for _, cred := range provider.allCredentials() {
if cred == currentCredential {
for _, credential := range provider.allCredentials() {
if credential == currentCredential {
continue
}
if filter != nil && !filter(cred) {
if !selection.allows(credential) {
continue
}
if cred.isUsable() {
if credential.isUsable() {
return true
}
}
@@ -108,17 +88,32 @@ func writeCredentialUnavailableError(
w http.ResponseWriter,
r *http.Request,
provider credentialProvider,
currentCredential credential,
filter func(credential) bool,
currentCredential Credential,
selection credentialSelection,
fallback string,
) {
if hasAlternativeCredential(provider, currentCredential, filter) {
if hasAlternativeCredential(provider, currentCredential, selection) {
writeRetryableUsageError(w, r)
return
}
if provider != nil && strings.HasPrefix(allCredentialsUnavailableError(provider.allCredentials()).Error(), "all credentials rate-limited") {
writeRetryableUsageError(w, r)
return
}
writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, fallback))
}
func credentialSelectionForUser(userConfig *option.CCMUser) credentialSelection {
selection := credentialSelection{scope: credentialSelectionScopeAll}
if userConfig != nil && !userConfig.AllowExternalUsage {
selection.scope = credentialSelectionScopeNonExternal
selection.filter = func(credential Credential) bool {
return !credential.isExternal()
}
}
return selection
}
func isHopByHopHeader(header string) bool {
switch strings.ToLower(header) {
case "connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade", "host":
@@ -141,49 +136,75 @@ func isReverseProxyHeader(header string) bool {
}
}
const (
weeklyWindowSeconds = 604800
weeklyWindowMinutes = weeklyWindowSeconds / 60
)
func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint {
resetAt, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset")
if !exists {
return nil
}
return &WeeklyCycleHint{
WindowMinutes: weeklyWindowMinutes,
ResetAt: resetAt.UTC(),
func isAPIKeyHeader(header string) bool {
switch strings.ToLower(header) {
case "x-api-key", "api-key":
return true
default:
return false
}
}
type Service struct {
boxService.Adapter
ctx context.Context
logger log.ContextLogger
options option.CCMServiceOptions
httpHeaders http.Header
listener *listener.Listener
tlsConfig tls.ServerConfig
httpServer *http.Server
userManager *UserManager
trackingGroup sync.WaitGroup
shuttingDown bool
ctx context.Context
logger log.ContextLogger
options option.CCMServiceOptions
httpHeaders http.Header
listener *listener.Listener
tlsConfig tls.ServerConfig
httpServer *http.Server
userManager *UserManager
// Legacy mode (single credential)
legacyCredential *defaultCredential
legacyProvider credentialProvider
// Multi-credential mode
providers map[string]credentialProvider
allCredentials []credential
allCredentials []Credential
userConfigMap map[string]*option.CCMUser
sessionModelAccess sync.Mutex
sessionModels map[sessionModelKey]time.Time
statusSubscriber *observable.Subscriber[struct{}]
statusObserver *observable.Observer[struct{}]
}
type sessionModelKey struct {
sessionID string
model string
}
func (s *Service) cleanSessionModels() {
now := time.Now()
s.sessionModelAccess.Lock()
for key, createdAt := range s.sessionModels {
if now.Sub(createdAt) > sessionExpiry {
delete(s.sessionModels, key)
}
}
s.sessionModelAccess.Unlock()
}
func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.CCMServiceOptions) (adapter.Service, error) {
initCCMUserAgent(logger)
hasLegacy := options.CredentialPath != "" || options.UsagesPath != "" || options.Detour != ""
if hasLegacy && len(options.Credentials) > 0 {
return nil, E.New("credential_path/usages_path/detour and credentials are mutually exclusive")
}
if len(options.Credentials) == 0 {
options.Credentials = []option.CCMCredential{{
Type: "default",
Tag: "default",
DefaultOptions: option.CCMDefaultCredentialOptions{
CredentialPath: options.CredentialPath,
UsagesPath: options.UsagesPath,
Detour: options.Detour,
},
}}
options.CredentialPath = ""
options.UsagesPath = ""
options.Detour = ""
}
err := validateCCMOptions(options)
if err != nil {
return nil, E.Cause(err, "validate options")
@@ -193,6 +214,7 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
tokenMap: make(map[string]string),
}
statusSubscriber := observable.NewSubscriber[struct{}](16)
service := &Service{
Adapter: boxService.NewAdapter(C.TypeCCM, tag),
ctx: ctx,
@@ -205,35 +227,24 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
Network: []string{N.NetworkTCP},
Listen: options.ListenOptions,
}),
userManager: userManager,
userManager: userManager,
sessionModels: make(map[sessionModelKey]time.Time),
statusSubscriber: statusSubscriber,
statusObserver: observable.NewObserver[struct{}](statusSubscriber, 8),
}
if len(options.Credentials) > 0 {
providers, allCredentials, err := buildCredentialProviders(ctx, options, logger)
if err != nil {
return nil, E.Cause(err, "build credential providers")
}
service.providers = providers
service.allCredentials = allCredentials
userConfigMap := make(map[string]*option.CCMUser)
for i := range options.Users {
userConfigMap[options.Users[i].Name] = &options.Users[i]
}
service.userConfigMap = userConfigMap
} else {
cred, err := newDefaultCredential(ctx, "default", option.CCMDefaultCredentialOptions{
CredentialPath: options.CredentialPath,
UsagesPath: options.UsagesPath,
Detour: options.Detour,
}, logger)
if err != nil {
return nil, err
}
service.legacyCredential = cred
service.legacyProvider = &singleCredentialProvider{cred: cred}
service.allCredentials = []credential{cred}
providers, allCredentials, err := buildCredentialProviders(ctx, options, logger)
if err != nil {
return nil, E.Cause(err, "build credential providers")
}
service.providers = providers
service.allCredentials = allCredentials
userConfigMap := make(map[string]*option.CCMUser)
for i := range options.Users {
userConfigMap[options.Users[i].Name] = &options.Users[i]
}
service.userConfigMap = userConfigMap
if options.TLS != nil {
tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
@@ -253,11 +264,12 @@ func (s *Service) Start(stage adapter.StartStage) error {
s.userManager.UpdateUsers(s.options.Users)
for _, cred := range s.allCredentials {
if extCred, ok := cred.(*externalCredential); ok && extCred.reverse && extCred.connectorURL != nil {
extCred.reverseService = s
for _, credential := range s.allCredentials {
credential.setStatusSubscriber(s.statusSubscriber)
if external, ok := credential.(*externalCredential); ok && external.reverse && external.connectorURL != nil {
external.reverseService = s
}
err := cred.start()
err := credential.start()
if err != nil {
return err
}
@@ -289,7 +301,7 @@ func (s *Service) Start(stage adapter.StartStage) error {
go func() {
serveErr := s.httpServer.Serve(tcpListener)
if serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
if serveErr != nil && !E.IsClosed(serveErr) {
s.logger.Error("serve error: ", serveErr)
}
}()
@@ -297,559 +309,30 @@ func (s *Service) Start(stage adapter.StartStage) error {
return nil
}
func isExtendedContextRequest(betaHeader string) bool {
for _, feature := range strings.Split(betaHeader, ",") {
if strings.HasPrefix(strings.TrimSpace(feature), "context-1m") {
return true
}
}
return false
}
func isFastModeRequest(betaHeader string) bool {
for _, feature := range strings.Split(betaHeader, ",") {
if strings.HasPrefix(strings.TrimSpace(feature), "fast-mode") {
return true
}
}
return false
}
func detectContextWindow(betaHeader string, totalInputTokens int64) int {
if totalInputTokens > premiumContextThreshold {
if isExtendedContextRequest(betaHeader) {
return contextWindowPremium
}
}
return contextWindowStandard
}
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := log.ContextWithNewID(r.Context())
if r.URL.Path == "/ccm/v1/status" {
s.handleStatusEndpoint(w, r)
return
}
if r.URL.Path == "/ccm/v1/reverse" {
s.handleReverseConnect(ctx, w, r)
return
}
if !strings.HasPrefix(r.URL.Path, "/v1/") {
writeJSONError(w, r, http.StatusNotFound, "not_found_error", "Not found")
return
}
var username string
if len(s.options.Users) > 0 {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
return
}
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
if clientToken == authHeader {
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
return
}
var ok bool
username, ok = s.userManager.Authenticate(clientToken)
if !ok {
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
return
}
}
// Always read body to extract model and session ID
var bodyBytes []byte
var requestModel string
var messagesCount int
var sessionID string
if r.Body != nil {
var err error
bodyBytes, err = io.ReadAll(r.Body)
if err != nil {
s.logger.ErrorContext(ctx, "read request body: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body")
return
}
var request struct {
Model string `json:"model"`
Messages []anthropic.MessageParam `json:"messages"`
}
err = json.Unmarshal(bodyBytes, &request)
if err == nil {
requestModel = request.Model
messagesCount = len(request.Messages)
}
sessionID = extractCCMSessionID(bodyBytes)
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
}
// Resolve credential provider and user config
var provider credentialProvider
var userConfig *option.CCMUser
if len(s.options.Users) > 0 {
userConfig = s.userConfigMap[username]
var err error
provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
if err != nil {
s.logger.ErrorContext(ctx, "resolve credential: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
return
}
} else {
provider = noUserCredentialProvider(s.providers, s.legacyProvider, s.options)
}
if provider == nil {
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available")
return
}
provider.pollIfStale(s.ctx)
anthropicBetaHeader := r.Header.Get("anthropic-beta")
if isFastModeRequest(anthropicBetaHeader) {
if _, isSingle := provider.(*singleCredentialProvider); !isSingle {
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
"fast mode requests will consume Extra usage, please use a default credential directly")
return
}
}
var credentialFilter func(credential) bool
if userConfig != nil && !userConfig.AllowExternalUsage {
credentialFilter = func(c credential) bool { return !c.isExternal() }
}
selectedCredential, isNew, err := provider.selectCredential(sessionID, credentialFilter)
if err != nil {
writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, err.Error()))
return
}
if isNew {
logParts := []any{"assigned credential ", selectedCredential.tagName()}
if sessionID != "" {
logParts = append(logParts, " for session ", sessionID)
}
if username != "" {
logParts = append(logParts, " by user ", username)
}
if requestModel != "" {
modelDisplay := requestModel
if isExtendedContextRequest(anthropicBetaHeader) {
modelDisplay += "[1m]"
}
logParts = append(logParts, ", model=", modelDisplay)
}
s.logger.DebugContext(ctx, logParts...)
}
if isFastModeRequest(anthropicBetaHeader) && selectedCredential.isExternal() {
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
"fast mode requests cannot be proxied through external credentials")
return
}
requestContext := selectedCredential.wrapRequestContext(r.Context())
defer func() {
requestContext.cancelRequest()
}()
proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if err != nil {
s.logger.ErrorContext(ctx, "create proxy request: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
return
}
response, err := selectedCredential.httpTransport().Do(proxyRequest)
if err != nil {
if r.Context().Err() != nil {
return
}
if requestContext.Err() != nil {
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "credential became unavailable while processing the request")
return
}
writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
return
}
requestContext.releaseCredentialInterrupt()
// Transparent 429 retry
for response.StatusCode == http.StatusTooManyRequests {
resetAt := parseRateLimitResetFromHeaders(response.Header)
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, credentialFilter)
selectedCredential.updateStateFromHeaders(response.Header)
if bodyBytes == nil || nextCredential == nil {
response.Body.Close()
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited")
return
}
response.Body.Close()
s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
requestContext.cancelRequest()
requestContext = nextCredential.wrapRequestContext(r.Context())
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if buildErr != nil {
s.logger.ErrorContext(ctx, "retry request: ", buildErr)
writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
return
}
retryResponse, retryErr := nextCredential.httpTransport().Do(retryRequest)
if retryErr != nil {
if r.Context().Err() != nil {
return
}
if requestContext.Err() != nil {
writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request")
return
}
s.logger.ErrorContext(ctx, "retry request: ", retryErr)
writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error())
return
}
requestContext.releaseCredentialInterrupt()
response = retryResponse
selectedCredential = nextCredential
}
defer response.Body.Close()
selectedCredential.updateStateFromHeaders(response.Header)
if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
body, _ := io.ReadAll(response.Body)
s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
go selectedCredential.pollUsage(s.ctx)
writeJSONError(w, r, http.StatusInternalServerError, "api_error",
"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
return
}
// Rewrite response headers for external users
if userConfig != nil && userConfig.ExternalCredential != "" {
s.rewriteResponseHeadersForExternalUser(response.Header, userConfig)
}
for key, values := range response.Header {
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) {
w.Header()[key] = values
}
}
w.WriteHeader(response.StatusCode)
usageTracker := selectedCredential.usageTrackerOrNil()
if usageTracker != nil && response.StatusCode == http.StatusOK {
s.handleResponseWithTracking(ctx, w, response, usageTracker, requestModel, anthropicBetaHeader, messagesCount, username)
} else {
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
if err == nil && mediaType != "text/event-stream" {
_, _ = io.Copy(w, response.Body)
return
}
flusher, ok := w.(http.Flusher)
if !ok {
s.logger.ErrorContext(ctx, "streaming not supported")
return
}
buffer := make([]byte, buf.BufferSize)
for {
n, err := response.Body.Read(buffer)
if n > 0 {
_, writeError := w.Write(buffer[:n])
if writeError != nil {
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
return
}
flusher.Flush()
}
if err != nil {
return
}
}
}
}
func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, requestModel string, anthropicBetaHeader string, messagesCount int, username string) {
weeklyCycleHint := extractWeeklyCycleHint(response.Header)
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
isStreaming := err == nil && mediaType == "text/event-stream"
if !isStreaming {
bodyBytes, err := io.ReadAll(response.Body)
if err != nil {
s.logger.ErrorContext(ctx, "read response body: ", err)
return
}
var message anthropic.Message
var usage anthropic.Usage
var responseModel string
err = json.Unmarshal(bodyBytes, &message)
if err == nil {
responseModel = string(message.Model)
usage = message.Usage
}
if responseModel == "" {
responseModel = requestModel
}
if usage.InputTokens > 0 || usage.OutputTokens > 0 {
if responseModel != "" {
totalInputTokens := usage.InputTokens + usage.CacheCreationInputTokens + usage.CacheReadInputTokens
contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens)
usageTracker.AddUsageWithCycleHint(
responseModel,
contextWindow,
messagesCount,
usage.InputTokens,
usage.OutputTokens,
usage.CacheReadInputTokens,
usage.CacheCreationInputTokens,
usage.CacheCreation.Ephemeral5mInputTokens,
usage.CacheCreation.Ephemeral1hInputTokens,
username,
time.Now(),
weeklyCycleHint,
)
}
}
_, _ = writer.Write(bodyBytes)
return
}
flusher, ok := writer.(http.Flusher)
if !ok {
s.logger.ErrorContext(ctx, "streaming not supported")
return
}
var accumulatedUsage anthropic.Usage
var responseModel string
buffer := make([]byte, buf.BufferSize)
var leftover []byte
for {
n, err := response.Body.Read(buffer)
if n > 0 {
data := append(leftover, buffer[:n]...)
lines := bytes.Split(data, []byte("\n"))
if err == nil {
leftover = lines[len(lines)-1]
lines = lines[:len(lines)-1]
} else {
leftover = nil
}
for _, line := range lines {
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
if bytes.HasPrefix(line, []byte("data: ")) {
eventData := bytes.TrimPrefix(line, []byte("data: "))
if bytes.Equal(eventData, []byte("[DONE]")) {
continue
}
var event anthropic.MessageStreamEventUnion
err := json.Unmarshal(eventData, &event)
if err != nil {
continue
}
switch event.Type {
case "message_start":
messageStart := event.AsMessageStart()
if messageStart.Message.Model != "" {
responseModel = string(messageStart.Message.Model)
}
if messageStart.Message.Usage.InputTokens > 0 {
accumulatedUsage.InputTokens = messageStart.Message.Usage.InputTokens
accumulatedUsage.CacheReadInputTokens = messageStart.Message.Usage.CacheReadInputTokens
accumulatedUsage.CacheCreationInputTokens = messageStart.Message.Usage.CacheCreationInputTokens
accumulatedUsage.CacheCreation.Ephemeral5mInputTokens = messageStart.Message.Usage.CacheCreation.Ephemeral5mInputTokens
accumulatedUsage.CacheCreation.Ephemeral1hInputTokens = messageStart.Message.Usage.CacheCreation.Ephemeral1hInputTokens
}
case "message_delta":
messageDelta := event.AsMessageDelta()
if messageDelta.Usage.OutputTokens > 0 {
accumulatedUsage.OutputTokens = messageDelta.Usage.OutputTokens
}
}
}
}
_, writeError := writer.Write(buffer[:n])
if writeError != nil {
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
return
}
flusher.Flush()
}
if err != nil {
if responseModel == "" {
responseModel = requestModel
}
if accumulatedUsage.InputTokens > 0 || accumulatedUsage.OutputTokens > 0 {
if responseModel != "" {
totalInputTokens := accumulatedUsage.InputTokens + accumulatedUsage.CacheCreationInputTokens + accumulatedUsage.CacheReadInputTokens
contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens)
usageTracker.AddUsageWithCycleHint(
responseModel,
contextWindow,
messagesCount,
accumulatedUsage.InputTokens,
accumulatedUsage.OutputTokens,
accumulatedUsage.CacheReadInputTokens,
accumulatedUsage.CacheCreationInputTokens,
accumulatedUsage.CacheCreation.Ephemeral5mInputTokens,
accumulatedUsage.CacheCreation.Ephemeral1hInputTokens,
username,
time.Now(),
weeklyCycleHint,
)
}
}
return
}
}
}
func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed")
return
}
if len(s.options.Users) == 0 {
writeJSONError(w, r, http.StatusForbidden, "authentication_error", "status endpoint requires user authentication")
return
}
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
return
}
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
if clientToken == authHeader {
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
return
}
username, ok := s.userManager.Authenticate(clientToken)
if !ok {
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
return
}
userConfig := s.userConfigMap[username]
if userConfig == nil {
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "user config not found")
return
}
provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
if err != nil {
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
return
}
provider.pollIfStale(r.Context())
avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]float64{
"five_hour_utilization": avgFiveHour,
"weekly_utilization": avgWeekly,
"plan_weight": totalWeight,
})
}
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) (float64, float64, float64) {
var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64
for _, cred := range provider.allCredentials() {
if !cred.isAvailable() {
continue
}
if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential {
continue
}
if !userConfig.AllowExternalUsage && cred.isExternal() {
continue
}
weight := cred.planWeight()
remaining5h := cred.fiveHourCap() - cred.fiveHourUtilization()
if remaining5h < 0 {
remaining5h = 0
}
remainingWeekly := cred.weeklyCap() - cred.weeklyUtilization()
if remainingWeekly < 0 {
remainingWeekly = 0
}
totalWeightedRemaining5h += remaining5h * weight
totalWeightedRemainingWeekly += remainingWeekly * weight
totalWeight += weight
}
if totalWeight == 0 {
return 100, 100, 0
}
return 100 - totalWeightedRemaining5h/totalWeight,
100 - totalWeightedRemainingWeekly/totalWeight,
totalWeight
}
func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.CCMUser) {
provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, userConfig.Name)
if err != nil {
return
}
avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
// Rewrite utilization headers to aggregated average (convert back to 0.0-1.0 range)
headers.Set("anthropic-ratelimit-unified-5h-utilization", strconv.FormatFloat(avgFiveHour/100, 'f', 6, 64))
headers.Set("anthropic-ratelimit-unified-7d-utilization", strconv.FormatFloat(avgWeekly/100, 'f', 6, 64))
if totalWeight > 0 {
headers.Set("X-CCM-Plan-Weight", strconv.FormatFloat(totalWeight, 'f', -1, 64))
}
}
func (s *Service) InterfaceUpdated() {
for _, cred := range s.allCredentials {
extCred, ok := cred.(*externalCredential)
for _, credential := range s.allCredentials {
external, ok := credential.(*externalCredential)
if !ok {
continue
}
if extCred.reverse && extCred.connectorURL != nil {
extCred.reverseService = s
extCred.resetReverseContext()
go extCred.connectorLoop()
if external.reverse && external.connectorURL != nil {
external.reverseService = s
external.resetReverseContext()
go external.connectorLoop()
}
}
}
func (s *Service) Close() error {
s.statusObserver.Close()
err := common.Close(
common.PtrOrNil(s.httpServer),
common.PtrOrNil(s.listener),
s.tlsConfig,
)
for _, cred := range s.allCredentials {
cred.close()
for _, credential := range s.allCredentials {
credential.close()
}
return err

View File

@@ -0,0 +1,667 @@
package ccm
import (
"bytes"
"context"
"encoding/json"
"io"
"mime"
"net/http"
"strconv"
"strings"
"time"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/anthropics/anthropic-sdk-go"
)
const (
contextWindowStandard = 200000
contextWindowPremium = 1000000
premiumContextThreshold = 200000
)
const (
weeklyWindowSeconds = 604800
weeklyWindowMinutes = weeklyWindowSeconds / 60
)
type ccmRequestMetadata struct {
Model string
MessagesCount int
SessionID string
}
func isExtendedContextRequest(betaHeader string) bool {
for _, feature := range strings.Split(betaHeader, ",") {
if strings.HasPrefix(strings.TrimSpace(feature), "context-1m") {
return true
}
}
return false
}
func isFastModeRequest(betaHeader string) bool {
for _, feature := range strings.Split(betaHeader, ",") {
if strings.HasPrefix(strings.TrimSpace(feature), "fast-mode") {
return true
}
}
return false
}
func detectContextWindow(betaHeader string, totalInputTokens int64) int {
if totalInputTokens > premiumContextThreshold {
if isExtendedContextRequest(betaHeader) {
return contextWindowPremium
}
}
return contextWindowStandard
}
func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint {
resetAt, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset")
if !exists {
return nil
}
return &WeeklyCycleHint{
WindowMinutes: weeklyWindowMinutes,
ResetAt: resetAt.UTC(),
}
}
// extractCCMSessionID extracts the session ID from the metadata.user_id field.
//
// Claude Code >= 2.1.78 (@anthropic-ai/claude-code) encodes user_id as:
//
// JSON.stringify({device_id, account_uuid, session_id, ...extras})
//
// ref: cli.js L66() — metadata constructor
//
// Claude Code < 2.1.78 used a template literal:
//
// `user_${deviceId}_account_${accountUuid}_session_${sessionId}`
//
// ref: cli.js qs() — old metadata constructor
//
// Returns ("", nil) when userID is empty.
// Returns error when user_id is present but in an unrecognized format.
func extractCCMSessionID(userID string) (string, error) {
if userID == "" {
return "", nil
}
// v2.1.78+ JSON object format
var userIDObject struct {
SessionID string `json:"session_id"`
}
if json.Unmarshal([]byte(userID), &userIDObject) == nil && userIDObject.SessionID != "" {
return userIDObject.SessionID, nil
}
// legacy template literal format
sessionIndex := strings.LastIndex(userID, "_session_")
if sessionIndex >= 0 {
return userID[sessionIndex+len("_session_"):], nil
}
return "", E.New("unrecognized metadata.user_id format: ", userID)
}
func extractCCMRequestMetadata(path string, bodyBytes []byte) (ccmRequestMetadata, error) {
switch path {
case "/v1/messages":
var request anthropic.MessageNewParams
if json.Unmarshal(bodyBytes, &request) != nil {
return ccmRequestMetadata{}, nil
}
metadata := ccmRequestMetadata{
Model: string(request.Model),
MessagesCount: len(request.Messages),
}
if request.Metadata.UserID.Valid() {
sessionID, err := extractCCMSessionID(request.Metadata.UserID.Value)
if err != nil {
return ccmRequestMetadata{}, err
}
metadata.SessionID = sessionID
}
return metadata, nil
case "/v1/messages/count_tokens":
var request anthropic.MessageCountTokensParams
if json.Unmarshal(bodyBytes, &request) != nil {
return ccmRequestMetadata{}, nil
}
return ccmRequestMetadata{
Model: string(request.Model),
MessagesCount: len(request.Messages),
}, nil
default:
return ccmRequestMetadata{}, nil
}
}
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := log.ContextWithNewID(r.Context())
if r.URL.Path == "/ccm/v1/status" {
s.handleStatusEndpoint(w, r)
return
}
if r.URL.Path == "/ccm/v1/reverse" {
s.handleReverseConnect(ctx, w, r)
return
}
if !strings.HasPrefix(r.URL.Path, "/v1/") {
writeJSONError(w, r, http.StatusNotFound, "not_found_error", "Not found")
return
}
if r.Header.Get("X-Api-Key") != "" || r.Header.Get("Api-Key") != "" {
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
"API key authentication is not supported; use Authorization: Bearer with a CCM user token")
return
}
var username string
if len(s.options.Users) > 0 {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
return
}
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
if clientToken == authHeader {
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
return
}
var ok bool
username, ok = s.userManager.Authenticate(clientToken)
if !ok {
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
return
}
}
// Always read body to extract model and session ID
var bodyBytes []byte
var requestModel string
var messagesCount int
var sessionID string
if r.Body != nil {
var err error
bodyBytes, err = io.ReadAll(r.Body)
if err != nil {
s.logger.ErrorContext(ctx, "read request body: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body")
return
}
requestMetadata, err := extractCCMRequestMetadata(r.URL.Path, bodyBytes)
if err != nil {
s.logger.ErrorContext(ctx, "invalid metadata format: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "invalid metadata format")
return
}
requestModel = requestMetadata.Model
messagesCount = requestMetadata.MessagesCount
sessionID = requestMetadata.SessionID
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
}
// Resolve credential provider and user config
var provider credentialProvider
var userConfig *option.CCMUser
if len(s.options.Users) > 0 {
userConfig = s.userConfigMap[username]
var err error
provider, err = credentialForUser(s.userConfigMap, s.providers, username)
if err != nil {
s.logger.ErrorContext(ctx, "resolve credential: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
return
}
} else {
provider = s.providers[s.options.Credentials[0].Tag]
}
if provider == nil {
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available")
return
}
provider.pollIfStale()
if userConfig != nil && userConfig.ExternalCredential != "" {
for _, credential := range s.allCredentials {
if credential.tagName() == userConfig.ExternalCredential && !credential.isUsable() {
credential.pollUsage()
break
}
}
}
s.cleanSessionModels()
anthropicBetaHeader := r.Header.Get("anthropic-beta")
if isFastModeRequest(anthropicBetaHeader) {
if _, isSingle := provider.(*singleCredentialProvider); !isSingle {
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
"fast mode requests will consume Extra usage, please use a default credential directly")
return
}
}
selection := credentialSelectionForUser(userConfig)
selectedCredential, isNew, err := provider.selectCredential(sessionID, selection)
if err != nil {
writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, err.Error()))
return
}
modelDisplay := requestModel
if requestModel != "" && isExtendedContextRequest(anthropicBetaHeader) {
modelDisplay += "[1m]"
}
isNewModel := false
if sessionID != "" && modelDisplay != "" {
key := sessionModelKey{sessionID, modelDisplay}
s.sessionModelAccess.Lock()
_, exists := s.sessionModels[key]
if !exists {
s.sessionModels[key] = time.Now()
isNewModel = true
}
s.sessionModelAccess.Unlock()
}
if isNew || isNewModel {
logParts := []any{"assigned credential ", selectedCredential.tagName()}
if sessionID != "" {
logParts = append(logParts, " for session ", sessionID)
}
if username != "" {
logParts = append(logParts, " by user ", username)
}
if modelDisplay != "" {
logParts = append(logParts, ", model=", modelDisplay)
}
s.logger.DebugContext(ctx, logParts...)
}
if isFastModeRequest(anthropicBetaHeader) && selectedCredential.isExternal() {
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
"fast mode requests cannot be proxied through external credentials")
return
}
requestContext := selectedCredential.wrapRequestContext(ctx)
{
currentRequestContext := requestContext
requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() {
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
}))
}
defer func() {
requestContext.cancelRequest()
}()
proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if err != nil {
s.logger.ErrorContext(ctx, "create proxy request: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
return
}
response, err := selectedCredential.httpClient().Do(proxyRequest)
if err != nil {
if r.Context().Err() != nil {
return
}
if requestContext.Err() != nil {
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable while processing the request")
return
}
writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
return
}
requestContext.releaseCredentialInterrupt()
// Transparent 429 retry
for response.StatusCode == http.StatusTooManyRequests {
resetAt := parseRateLimitResetFromHeaders(response.Header)
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection)
selectedCredential.updateStateFromHeaders(response.Header)
if bodyBytes == nil || nextCredential == nil {
response.Body.Close()
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited")
return
}
response.Body.Close()
s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
requestContext.cancelRequest()
requestContext = nextCredential.wrapRequestContext(ctx)
{
currentRequestContext := requestContext
requestContext.addInterruptLink(provider.linkProviderInterrupt(nextCredential, selection, func() {
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
}))
}
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if buildErr != nil {
s.logger.ErrorContext(ctx, "retry request: ", buildErr)
writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
return
}
retryResponse, retryErr := nextCredential.httpClient().Do(retryRequest)
if retryErr != nil {
if r.Context().Err() != nil {
return
}
if requestContext.Err() != nil {
writeCredentialUnavailableError(w, r, provider, nextCredential, selection, "credential became unavailable while retrying the request")
return
}
s.logger.ErrorContext(ctx, "retry request: ", retryErr)
writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error())
return
}
requestContext.releaseCredentialInterrupt()
response = retryResponse
selectedCredential = nextCredential
}
defer response.Body.Close()
selectedCredential.updateStateFromHeaders(response.Header)
if response.StatusCode == 529 {
s.logger.WarnContext(ctx, "upstream overloaded from ", selectedCredential.tagName())
for key, values := range response.Header {
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) {
w.Header()[key] = values
}
}
w.WriteHeader(response.StatusCode)
io.Copy(w, response.Body)
return
}
if response.StatusCode == http.StatusBadRequest {
if selectedCredential.isExternal() {
selectedCredential.markUpstreamRejected()
} else {
provider.pollCredentialIfStale(selectedCredential)
}
s.logger.ErrorContext(ctx, "upstream rejected from ", selectedCredential.tagName(), ": status ", response.StatusCode)
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "upstream rejected credential")
return
}
// ref (@anthropic-ai/claude-code @2.1.81): cli.js NA9 (line 179488-179494) — 401 recovery
// ref: cli.js CR1 (line 314268-314273) — 403 "OAuth token has been revoked" recovery
if !selectedCredential.isExternal() && bodyBytes != nil &&
(response.StatusCode == http.StatusUnauthorized || response.StatusCode == http.StatusForbidden) {
shouldRetry := response.StatusCode == http.StatusUnauthorized
var peekBody []byte
if response.StatusCode == http.StatusForbidden {
peekBody, _ = io.ReadAll(response.Body)
shouldRetry = strings.Contains(string(peekBody), "OAuth token has been revoked")
if !shouldRetry {
response.Body.Close()
s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(peekBody))
writeJSONError(w, r, http.StatusInternalServerError, "api_error",
"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(peekBody))
return
}
}
if shouldRetry {
recovered := false
var recoverErr error
if defaultCred, ok := selectedCredential.(*defaultCredential); ok {
failedAccessToken := ""
currentCredentials := defaultCred.currentCredentials()
if currentCredentials != nil {
failedAccessToken = currentCredentials.AccessToken
}
s.logger.WarnContext(ctx, "upstream auth failure from ", selectedCredential.tagName(), ", reloading credentials and retrying")
recovered, recoverErr = defaultCred.recoverAuthFailure(failedAccessToken)
}
if recoverErr != nil {
response.Body.Close()
if isHardRefreshFailure(recoverErr) || selectedCredential.unavailableError() != nil {
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable during auth recovery")
return
}
writeJSONError(w, r, http.StatusBadGateway, "api_error", E.Cause(recoverErr, "auth recovery").Error())
return
}
if recovered {
response.Body.Close()
retryRequest, buildErr := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if buildErr != nil {
writeJSONError(w, r, http.StatusBadGateway, "api_error", E.Cause(buildErr, "rebuild request after auth recovery").Error())
return
}
retryResponse, retryErr := selectedCredential.httpClient().Do(retryRequest)
if retryErr != nil {
writeJSONError(w, r, http.StatusBadGateway, "api_error", E.Cause(retryErr, "retry request after auth recovery").Error())
return
}
response = retryResponse
defer retryResponse.Body.Close()
} else if response.StatusCode == http.StatusForbidden {
response.Body = io.NopCloser(bytes.NewReader(peekBody))
}
}
}
if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
body, _ := io.ReadAll(response.Body)
s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
writeJSONError(w, r, http.StatusInternalServerError, "api_error",
"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
return
}
s.rewriteResponseHeaders(response.Header, provider, userConfig)
for key, values := range response.Header {
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) {
w.Header()[key] = values
}
}
w.WriteHeader(response.StatusCode)
usageTracker := selectedCredential.usageTrackerOrNil()
if usageTracker != nil && response.StatusCode == http.StatusOK {
s.handleResponseWithTracking(ctx, w, response, usageTracker, requestModel, anthropicBetaHeader, messagesCount, username)
} else {
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
if err == nil && mediaType != "text/event-stream" {
_, _ = io.Copy(w, response.Body)
return
}
flusher, ok := w.(http.Flusher)
if !ok {
s.logger.ErrorContext(ctx, "streaming not supported")
return
}
buffer := make([]byte, buf.BufferSize)
for {
n, err := response.Body.Read(buffer)
if n > 0 {
_, writeError := w.Write(buffer[:n])
if writeError != nil {
if E.IsClosedOrCanceled(writeError) {
return
}
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
return
}
flusher.Flush()
}
if err != nil {
return
}
}
}
}
func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, requestModel string, anthropicBetaHeader string, messagesCount int, username string) {
weeklyCycleHint := extractWeeklyCycleHint(response.Header)
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
isStreaming := err == nil && mediaType == "text/event-stream"
if !isStreaming {
bodyBytes, err := io.ReadAll(response.Body)
if err != nil {
s.logger.ErrorContext(ctx, "read response body: ", err)
return
}
var message anthropic.Message
var usage anthropic.Usage
var responseModel string
err = json.Unmarshal(bodyBytes, &message)
if err == nil {
responseModel = string(message.Model)
usage = message.Usage
}
if responseModel == "" {
responseModel = requestModel
}
if usage.InputTokens > 0 || usage.OutputTokens > 0 {
if responseModel != "" {
totalInputTokens := usage.InputTokens + usage.CacheCreationInputTokens + usage.CacheReadInputTokens
contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens)
usageTracker.AddUsageWithCycleHint(
responseModel,
contextWindow,
messagesCount,
usage.InputTokens,
usage.OutputTokens,
usage.CacheReadInputTokens,
usage.CacheCreationInputTokens,
usage.CacheCreation.Ephemeral5mInputTokens,
usage.CacheCreation.Ephemeral1hInputTokens,
username,
time.Now(),
weeklyCycleHint,
)
}
}
_, _ = writer.Write(bodyBytes)
return
}
flusher, ok := writer.(http.Flusher)
if !ok {
s.logger.ErrorContext(ctx, "streaming not supported")
return
}
var accumulatedUsage anthropic.Usage
var responseModel string
buffer := make([]byte, buf.BufferSize)
var leftover []byte
for {
n, err := response.Body.Read(buffer)
if n > 0 {
data := append(leftover, buffer[:n]...)
lines := bytes.Split(data, []byte("\n"))
if err == nil {
leftover = lines[len(lines)-1]
lines = lines[:len(lines)-1]
} else {
leftover = nil
}
for _, line := range lines {
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
if bytes.HasPrefix(line, []byte("data: ")) {
eventData := bytes.TrimPrefix(line, []byte("data: "))
if bytes.Equal(eventData, []byte("[DONE]")) {
continue
}
var event anthropic.MessageStreamEventUnion
err := json.Unmarshal(eventData, &event)
if err != nil {
continue
}
switch event.Type {
case "message_start":
messageStart := event.AsMessageStart()
if messageStart.Message.Model != "" {
responseModel = string(messageStart.Message.Model)
}
if messageStart.Message.Usage.InputTokens > 0 {
accumulatedUsage.InputTokens = messageStart.Message.Usage.InputTokens
accumulatedUsage.CacheReadInputTokens = messageStart.Message.Usage.CacheReadInputTokens
accumulatedUsage.CacheCreationInputTokens = messageStart.Message.Usage.CacheCreationInputTokens
accumulatedUsage.CacheCreation.Ephemeral5mInputTokens = messageStart.Message.Usage.CacheCreation.Ephemeral5mInputTokens
accumulatedUsage.CacheCreation.Ephemeral1hInputTokens = messageStart.Message.Usage.CacheCreation.Ephemeral1hInputTokens
}
case "message_delta":
messageDelta := event.AsMessageDelta()
if messageDelta.Usage.OutputTokens > 0 {
accumulatedUsage.OutputTokens = messageDelta.Usage.OutputTokens
}
}
}
}
_, writeError := writer.Write(buffer[:n])
if writeError != nil {
if E.IsClosedOrCanceled(writeError) {
return
}
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
return
}
flusher.Flush()
}
if err != nil {
if responseModel == "" {
responseModel = requestModel
}
if accumulatedUsage.InputTokens > 0 || accumulatedUsage.OutputTokens > 0 {
if responseModel != "" {
totalInputTokens := accumulatedUsage.InputTokens + accumulatedUsage.CacheCreationInputTokens + accumulatedUsage.CacheReadInputTokens
contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens)
usageTracker.AddUsageWithCycleHint(
responseModel,
contextWindow,
messagesCount,
accumulatedUsage.InputTokens,
accumulatedUsage.OutputTokens,
accumulatedUsage.CacheReadInputTokens,
accumulatedUsage.CacheCreationInputTokens,
accumulatedUsage.CacheCreation.Ephemeral5mInputTokens,
accumulatedUsage.CacheCreation.Ephemeral1hInputTokens,
username,
time.Now(),
weeklyCycleHint,
)
}
}
return
}
}
}

View File

@@ -0,0 +1,221 @@
package ccm
import (
"net/http"
"net/http/httptest"
"path/filepath"
"strings"
"sync/atomic"
"testing"
"time"
)
func newHandlerCredential(t *testing.T, transport http.RoundTripper) (*defaultCredential, string) {
t.Helper()
directory := t.TempDir()
credentialPath := filepath.Join(directory, ".credentials.json")
writeTestCredentials(t, credentialPath, &oauthCredentials{
AccessToken: "old-token",
RefreshToken: "refresh-token",
ExpiresAt: time.Now().Add(time.Hour).UnixMilli(),
Scopes: []string{"user:profile", "user:inference"},
SubscriptionType: optionalStringPointer("max"),
RateLimitTier: optionalStringPointer("default_claude_max_20x"),
})
credential := newTestDefaultCredential(t, credentialPath, transport)
if err := credential.reloadCredentials(true); err != nil {
t.Fatal(err)
}
seedTestCredentialState(credential)
return credential, credentialPath
}
func TestServiceHandlerRecoversFrom401(t *testing.T) {
t.Parallel()
var messageRequests atomic.Int32
var refreshRequests atomic.Int32
credential, _ := newHandlerCredential(t, roundTripFunc(func(request *http.Request) (*http.Response, error) {
switch request.URL.Path {
case "/v1/messages":
call := messageRequests.Add(1)
switch request.Header.Get("Authorization") {
case "Bearer old-token":
if call != 1 {
t.Fatalf("unexpected old-token call count %d", call)
}
return newTextResponse(http.StatusUnauthorized, "unauthorized"), nil
case "Bearer new-token":
return newJSONResponse(http.StatusOK, `{}`), nil
default:
t.Fatalf("unexpected authorization header %q", request.Header.Get("Authorization"))
}
case "/v1/oauth/token":
refreshRequests.Add(1)
return newJSONResponse(http.StatusOK, `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`), nil
default:
t.Fatalf("unexpected path %s", request.URL.Path)
}
return nil, nil
}))
service := newTestService(credential)
recorder := httptest.NewRecorder()
service.ServeHTTP(recorder, newMessageRequest(`{"model":"claude","messages":[],"metadata":{"user_id":"{\"session_id\":\"session\"}"}}`))
if recorder.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", recorder.Code, recorder.Body.String())
}
if messageRequests.Load() != 2 {
t.Fatalf("expected two upstream message requests, got %d", messageRequests.Load())
}
if refreshRequests.Load() != 1 {
t.Fatalf("expected one refresh request, got %d", refreshRequests.Load())
}
}
func TestServiceHandlerRecoversFromRevoked403(t *testing.T) {
t.Parallel()
var messageRequests atomic.Int32
var refreshRequests atomic.Int32
credential, _ := newHandlerCredential(t, roundTripFunc(func(request *http.Request) (*http.Response, error) {
switch request.URL.Path {
case "/v1/messages":
messageRequests.Add(1)
if request.Header.Get("Authorization") == "Bearer old-token" {
return newTextResponse(http.StatusForbidden, "OAuth token has been revoked"), nil
}
return newJSONResponse(http.StatusOK, `{}`), nil
case "/v1/oauth/token":
refreshRequests.Add(1)
return newJSONResponse(http.StatusOK, `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`), nil
default:
t.Fatalf("unexpected path %s", request.URL.Path)
}
return nil, nil
}))
service := newTestService(credential)
recorder := httptest.NewRecorder()
service.ServeHTTP(recorder, newMessageRequest(`{"model":"claude","messages":[],"metadata":{"user_id":"{\"session_id\":\"session\"}"}}`))
if recorder.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", recorder.Code, recorder.Body.String())
}
if refreshRequests.Load() != 1 {
t.Fatalf("expected one refresh request, got %d", refreshRequests.Load())
}
}
func TestServiceHandlerDoesNotRecoverFromOrdinary403(t *testing.T) {
t.Parallel()
var refreshRequests atomic.Int32
credential, _ := newHandlerCredential(t, roundTripFunc(func(request *http.Request) (*http.Response, error) {
switch request.URL.Path {
case "/v1/messages":
return newTextResponse(http.StatusForbidden, "forbidden"), nil
case "/v1/oauth/token":
refreshRequests.Add(1)
return newJSONResponse(http.StatusOK, `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`), nil
default:
t.Fatalf("unexpected path %s", request.URL.Path)
}
return nil, nil
}))
service := newTestService(credential)
recorder := httptest.NewRecorder()
service.ServeHTTP(recorder, newMessageRequest(`{"model":"claude","messages":[],"metadata":{"user_id":"{\"session_id\":\"session\"}"}}`))
if recorder.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d", recorder.Code)
}
if refreshRequests.Load() != 0 {
t.Fatalf("expected no refresh request, got %d", refreshRequests.Load())
}
if !strings.Contains(recorder.Body.String(), "forbidden") {
t.Fatalf("expected forbidden body, got %s", recorder.Body.String())
}
}
func TestServiceHandlerUsesReloadedTokenBeforeRefreshing(t *testing.T) {
t.Parallel()
var messageRequests atomic.Int32
var refreshRequests atomic.Int32
var credentialPath string
var credential *defaultCredential
credential, credentialPath = newHandlerCredential(t, roundTripFunc(func(request *http.Request) (*http.Response, error) {
switch request.URL.Path {
case "/v1/messages":
call := messageRequests.Add(1)
if request.Header.Get("Authorization") == "Bearer old-token" {
updatedCredentials := readTestCredentials(t, credentialPath)
updatedCredentials.AccessToken = "disk-token"
updatedCredentials.ExpiresAt = time.Now().Add(time.Hour).UnixMilli()
writeTestCredentials(t, credentialPath, updatedCredentials)
if call != 1 {
t.Fatalf("unexpected old-token call count %d", call)
}
return newTextResponse(http.StatusUnauthorized, "unauthorized"), nil
}
if request.Header.Get("Authorization") != "Bearer disk-token" {
t.Fatalf("expected disk token retry, got %q", request.Header.Get("Authorization"))
}
return newJSONResponse(http.StatusOK, `{}`), nil
case "/v1/oauth/token":
refreshRequests.Add(1)
return newJSONResponse(http.StatusOK, `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`), nil
default:
t.Fatalf("unexpected path %s", request.URL.Path)
}
return nil, nil
}))
service := newTestService(credential)
recorder := httptest.NewRecorder()
service.ServeHTTP(recorder, newMessageRequest(`{"model":"claude","messages":[],"metadata":{"user_id":"{\"session_id\":\"session\"}"}}`))
if recorder.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", recorder.Code, recorder.Body.String())
}
if refreshRequests.Load() != 0 {
t.Fatalf("expected zero refresh requests, got %d", refreshRequests.Load())
}
}
func TestServiceHandlerRetriesAuthRecoveryOnlyOnce(t *testing.T) {
t.Parallel()
var messageRequests atomic.Int32
var refreshRequests atomic.Int32
credential, _ := newHandlerCredential(t, roundTripFunc(func(request *http.Request) (*http.Response, error) {
switch request.URL.Path {
case "/v1/messages":
messageRequests.Add(1)
return newTextResponse(http.StatusUnauthorized, "still unauthorized"), nil
case "/v1/oauth/token":
refreshRequests.Add(1)
return newJSONResponse(http.StatusOK, `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`), nil
default:
t.Fatalf("unexpected path %s", request.URL.Path)
}
return nil, nil
}))
service := newTestService(credential)
recorder := httptest.NewRecorder()
service.ServeHTTP(recorder, newMessageRequest(`{"model":"claude","messages":[],"metadata":{"user_id":"{\"session_id\":\"session\"}"}}`))
if recorder.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d", recorder.Code)
}
if messageRequests.Load() != 2 {
t.Fatalf("expected exactly two upstream attempts, got %d", messageRequests.Load())
}
if refreshRequests.Load() != 1 {
t.Fatalf("expected exactly one refresh request, got %d", refreshRequests.Load())
}
}

View File

@@ -0,0 +1,115 @@
package ccm
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/anthropics/anthropic-sdk-go"
)
func TestWriteJSONErrorUsesAnthropicShape(t *testing.T) {
t.Parallel()
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, "/v1/messages", nil)
request.Header.Set("Request-Id", "req_123")
writeJSONError(recorder, request, http.StatusBadRequest, "invalid_request_error", "broken")
if recorder.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", recorder.Code)
}
var body anthropic.ErrorResponse
if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil {
t.Fatal(err)
}
if string(body.Type) != "error" {
t.Fatalf("expected error type, got %q", body.Type)
}
if body.RequestID != "req_123" {
t.Fatalf("expected req_123 request ID, got %q", body.RequestID)
}
if body.Error.Type != "invalid_request_error" {
t.Fatalf("expected invalid_request_error, got %q", body.Error.Type)
}
if body.Error.Message != "broken" {
t.Fatalf("expected broken message, got %q", body.Error.Message)
}
}
func TestExtractCCMRequestMetadataFromMessagesJSONSession(t *testing.T) {
t.Parallel()
metadata, err := extractCCMRequestMetadata("/v1/messages", []byte(`{
"model":"claude-sonnet-4-5",
"max_tokens":1,
"messages":[{"role":"user","content":"hello"}],
"metadata":{"user_id":"{\"session_id\":\"session-1\"}"}
}`))
if err != nil {
t.Fatal(err)
}
if metadata.Model != "claude-sonnet-4-5" {
t.Fatalf("expected model, got %#v", metadata)
}
if metadata.MessagesCount != 1 {
t.Fatalf("expected one message, got %#v", metadata)
}
if metadata.SessionID != "session-1" {
t.Fatalf("expected session-1, got %#v", metadata)
}
}
func TestExtractCCMRequestMetadataFromMessagesLegacySession(t *testing.T) {
t.Parallel()
metadata, err := extractCCMRequestMetadata("/v1/messages", []byte(`{
"model":"claude-sonnet-4-5",
"max_tokens":1,
"messages":[{"role":"user","content":"hello"}],
"metadata":{"user_id":"user_device_account_account_session_session-legacy"}
}`))
if err != nil {
t.Fatal(err)
}
if metadata.SessionID != "session-legacy" {
t.Fatalf("expected session-legacy, got %#v", metadata)
}
}
func TestExtractCCMRequestMetadataFromCountTokens(t *testing.T) {
t.Parallel()
metadata, err := extractCCMRequestMetadata("/v1/messages/count_tokens", []byte(`{
"model":"claude-sonnet-4-5",
"messages":[{"role":"user","content":"hello"}]
}`))
if err != nil {
t.Fatal(err)
}
if metadata.Model != "claude-sonnet-4-5" {
t.Fatalf("expected model, got %#v", metadata)
}
if metadata.MessagesCount != 1 {
t.Fatalf("expected one message, got %#v", metadata)
}
if metadata.SessionID != "" {
t.Fatalf("expected empty session ID, got %#v", metadata)
}
}
func TestExtractCCMRequestMetadataIgnoresUnsupportedPath(t *testing.T) {
t.Parallel()
metadata, err := extractCCMRequestMetadata("/v1/models", []byte(`{"model":"claude"}`))
if err != nil {
t.Fatal(err)
}
if metadata != (ccmRequestMetadata{}) {
t.Fatalf("expected zero metadata, got %#v", metadata)
}
}

View File

@@ -0,0 +1,398 @@
package ccm
import (
"bytes"
"encoding/json"
"net/http"
"reflect"
"strconv"
"strings"
"time"
"github.com/sagernet/sing-box/option"
)
type statusPayload struct {
FiveHourUtilization float64 `json:"five_hour_utilization"`
FiveHourReset int64 `json:"five_hour_reset"`
WeeklyUtilization float64 `json:"weekly_utilization"`
WeeklyReset int64 `json:"weekly_reset"`
PlanWeight float64 `json:"plan_weight"`
}
type aggregatedStatus struct {
fiveHourUtilization float64
weeklyUtilization float64
totalWeight float64
fiveHourReset time.Time
weeklyReset time.Time
availability availabilityStatus
}
func resetToEpoch(t time.Time) int64 {
if t.IsZero() {
return 0
}
return t.Unix()
}
func (s aggregatedStatus) equal(other aggregatedStatus) bool {
return reflect.DeepEqual(s.toPayload(), other.toPayload())
}
func (s aggregatedStatus) toPayload() statusPayload {
return statusPayload{
FiveHourUtilization: s.fiveHourUtilization,
FiveHourReset: resetToEpoch(s.fiveHourReset),
WeeklyUtilization: s.weeklyUtilization,
WeeklyReset: resetToEpoch(s.weeklyReset),
PlanWeight: s.totalWeight,
}
}
type aggregateInput struct {
availability availabilityStatus
}
func aggregateAvailability(inputs []aggregateInput) availabilityStatus {
if len(inputs) == 0 {
return availabilityStatus{
State: availabilityStateUnavailable,
Reason: availabilityReasonNoCredentials,
}
}
var earliestRateLimit time.Time
var hasRateLimited bool
var blocked availabilityStatus
var hasBlocked bool
var hasUnavailable bool
for _, input := range inputs {
availability := input.availability.normalized()
switch availability.State {
case availabilityStateUsable:
return availabilityStatus{State: availabilityStateUsable}
case availabilityStateRateLimited:
hasRateLimited = true
if !availability.ResetAt.IsZero() && (earliestRateLimit.IsZero() || availability.ResetAt.Before(earliestRateLimit)) {
earliestRateLimit = availability.ResetAt
}
if blocked.State == "" {
blocked = availabilityStatus{
State: availabilityStateRateLimited,
Reason: availabilityReasonHardRateLimit,
ResetAt: earliestRateLimit,
}
}
case availabilityStateTemporarilyBlocked:
if !hasBlocked {
blocked = availability
hasBlocked = true
}
if !availability.ResetAt.IsZero() && (blocked.ResetAt.IsZero() || availability.ResetAt.Before(blocked.ResetAt)) {
blocked.ResetAt = availability.ResetAt
}
case availabilityStateUnavailable:
hasUnavailable = true
}
}
if hasRateLimited {
blocked.ResetAt = earliestRateLimit
return blocked
}
if hasBlocked {
return blocked
}
if hasUnavailable {
return availabilityStatus{
State: availabilityStateUnavailable,
Reason: availabilityReasonUnknown,
}
}
return availabilityStatus{
State: availabilityStateUnknown,
Reason: availabilityReasonUnknown,
}
}
func chooseRepresentativeClaim(fiveHourUtilization float64, fiveHourReset time.Time, weeklyUtilization float64, weeklyReset time.Time, now time.Time) string {
fiveHourWarning := claudeFiveHourWarning(fiveHourUtilization, fiveHourReset, now)
weeklyWarning := claudeWeeklyWarning(weeklyUtilization, weeklyReset, now)
type claimCandidate struct {
name string
priority int
utilization float64
}
candidateFor := func(name string, utilization float64, warning bool) claimCandidate {
priority := 0
switch {
case utilization >= 100:
priority = 2
case warning:
priority = 1
}
return claimCandidate{name: name, priority: priority, utilization: utilization}
}
five := candidateFor("5h", fiveHourUtilization, fiveHourWarning)
weekly := candidateFor("7d", weeklyUtilization, weeklyWarning)
switch {
case five.priority > weekly.priority:
return five.name
case weekly.priority > five.priority:
return weekly.name
case five.utilization > weekly.utilization:
return five.name
case weekly.utilization > five.utilization:
return weekly.name
case !fiveHourReset.IsZero():
return five.name
case !weeklyReset.IsZero():
return weekly.name
default:
return "5h"
}
}
func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed")
return
}
var provider credentialProvider
var userConfig *option.CCMUser
if len(s.options.Users) > 0 {
if r.Header.Get("X-Api-Key") != "" || r.Header.Get("Api-Key") != "" {
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
"API key authentication is not supported; use Authorization: Bearer with a CCM user token")
return
}
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
return
}
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
if clientToken == authHeader {
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
return
}
username, ok := s.userManager.Authenticate(clientToken)
if !ok {
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
return
}
userConfig = s.userConfigMap[username]
var err error
provider, err = credentialForUser(s.userConfigMap, s.providers, username)
if err != nil {
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
return
}
} else {
provider = s.providers[s.options.Credentials[0].Tag]
}
if provider == nil {
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available")
return
}
if r.URL.Query().Get("watch") == "true" {
s.handleStatusStream(w, r, provider, userConfig)
return
}
provider.pollIfStale()
status := s.computeAggregatedUtilization(provider, userConfig)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(status.toPayload())
}
func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, provider credentialProvider, userConfig *option.CCMUser) {
flusher, ok := w.(http.Flusher)
if !ok {
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "streaming not supported")
return
}
subscription, done, err := s.statusObserver.Subscribe()
if err != nil {
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "service closing")
return
}
defer s.statusObserver.UnSubscribe(subscription)
provider.pollIfStale()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
last := s.computeAggregatedUtilization(provider, userConfig)
buf := &bytes.Buffer{}
json.NewEncoder(buf).Encode(last.toPayload())
_, writeErr := w.Write(buf.Bytes())
if writeErr != nil {
return
}
flusher.Flush()
for {
select {
case <-r.Context().Done():
return
case <-done:
return
case <-subscription:
for {
select {
case <-subscription:
default:
goto drained
}
}
drained:
current := s.computeAggregatedUtilization(provider, userConfig)
if current.equal(last) {
continue
}
last = current
buf.Reset()
json.NewEncoder(buf).Encode(current.toPayload())
_, writeErr = w.Write(buf.Bytes())
if writeErr != nil {
return
}
flusher.Flush()
}
}
}
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) aggregatedStatus {
visibleInputs := make([]aggregateInput, 0, len(provider.allCredentials()))
var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64
now := time.Now()
var totalWeightedHoursUntil5hReset, total5hResetWeight float64
var totalWeightedHoursUntilWeeklyReset, totalWeeklyResetWeight float64
var hasSnapshotData bool
for _, credential := range provider.allCredentials() {
if userConfig != nil && userConfig.ExternalCredential != "" && credential.tagName() == userConfig.ExternalCredential {
continue
}
if userConfig != nil && !userConfig.AllowExternalUsage && credential.isExternal() {
continue
}
visibleInputs = append(visibleInputs, aggregateInput{
availability: credential.availabilityStatus(),
})
if !credential.hasSnapshotData() {
continue
}
hasSnapshotData = true
weight := credential.planWeight()
remaining5h := credential.fiveHourCap() - credential.fiveHourUtilization()
if remaining5h < 0 {
remaining5h = 0
}
remainingWeekly := credential.weeklyCap() - credential.weeklyUtilization()
if remainingWeekly < 0 {
remainingWeekly = 0
}
totalWeightedRemaining5h += remaining5h * weight
totalWeightedRemainingWeekly += remainingWeekly * weight
totalWeight += weight
fiveHourReset := credential.fiveHourResetTime()
if !fiveHourReset.IsZero() {
hours := fiveHourReset.Sub(now).Hours()
if hours > 0 {
totalWeightedHoursUntil5hReset += hours * weight
total5hResetWeight += weight
}
}
weeklyReset := credential.weeklyResetTime()
if !weeklyReset.IsZero() {
hours := weeklyReset.Sub(now).Hours()
if hours > 0 {
totalWeightedHoursUntilWeeklyReset += hours * weight
totalWeeklyResetWeight += weight
}
}
}
availability := aggregateAvailability(visibleInputs)
if totalWeight == 0 {
result := aggregatedStatus{availability: availability}
if !hasSnapshotData {
result.fiveHourUtilization = 100
result.weeklyUtilization = 100
}
return result
}
result := aggregatedStatus{
fiveHourUtilization: 100 - totalWeightedRemaining5h/totalWeight,
weeklyUtilization: 100 - totalWeightedRemainingWeekly/totalWeight,
totalWeight: totalWeight,
availability: availability,
}
if total5hResetWeight > 0 {
avgHours := totalWeightedHoursUntil5hReset / total5hResetWeight
result.fiveHourReset = now.Add(time.Duration(avgHours * float64(time.Hour)))
}
if totalWeeklyResetWeight > 0 {
avgHours := totalWeightedHoursUntilWeeklyReset / totalWeeklyResetWeight
result.weeklyReset = now.Add(time.Duration(avgHours * float64(time.Hour)))
}
return result
}
func (s *Service) rewriteResponseHeaders(headers http.Header, provider credentialProvider, userConfig *option.CCMUser) {
for key := range headers {
if strings.HasPrefix(strings.ToLower(key), "anthropic-ratelimit-unified-") {
headers.Del(key)
}
}
status := s.computeAggregatedUtilization(provider, userConfig)
now := time.Now()
headers.Set("anthropic-ratelimit-unified-5h-utilization", strconv.FormatFloat(status.fiveHourUtilization/100, 'f', 6, 64))
headers.Set("anthropic-ratelimit-unified-7d-utilization", strconv.FormatFloat(status.weeklyUtilization/100, 'f', 6, 64))
if !status.fiveHourReset.IsZero() {
headers.Set("anthropic-ratelimit-unified-5h-reset", strconv.FormatInt(status.fiveHourReset.Unix(), 10))
}
if !status.weeklyReset.IsZero() {
headers.Set("anthropic-ratelimit-unified-7d-reset", strconv.FormatInt(status.weeklyReset.Unix(), 10))
}
if status.totalWeight > 0 {
headers.Set("X-CCM-Plan-Weight", strconv.FormatFloat(status.totalWeight, 'f', -1, 64))
}
fiveHourWarning := claudeFiveHourWarning(status.fiveHourUtilization, status.fiveHourReset, now)
weeklyWarning := claudeWeeklyWarning(status.weeklyUtilization, status.weeklyReset, now)
switch {
case status.fiveHourUtilization >= 100 || status.weeklyUtilization >= 100 ||
status.availability.State == availabilityStateRateLimited:
headers.Set("anthropic-ratelimit-unified-status", "rejected")
case fiveHourWarning || weeklyWarning:
headers.Set("anthropic-ratelimit-unified-status", "allowed_warning")
default:
headers.Set("anthropic-ratelimit-unified-status", "allowed")
}
claim := chooseRepresentativeClaim(status.fiveHourUtilization, status.fiveHourReset, status.weeklyUtilization, status.weeklyReset, now)
headers.Set("anthropic-ratelimit-unified-representative-claim", claim)
switch claim {
case "7d":
if !status.weeklyReset.IsZero() {
headers.Set("anthropic-ratelimit-unified-reset", strconv.FormatInt(status.weeklyReset.Unix(), 10))
}
default:
if !status.fiveHourReset.IsZero() {
headers.Set("anthropic-ratelimit-unified-reset", strconv.FormatInt(status.fiveHourReset.Unix(), 10))
}
}
if fiveHourWarning || status.fiveHourUtilization >= 100 {
headers.Set("anthropic-ratelimit-unified-5h-surpassed-threshold", "true")
}
if weeklyWarning || status.weeklyUtilization >= 100 {
headers.Set("anthropic-ratelimit-unified-7d-surpassed-threshold", "true")
}
}

View File

@@ -0,0 +1,234 @@
package ccm
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/sagernet/sing/common/observable"
)
type testCredential struct {
tag string
external bool
available bool
usable bool
hasData bool
fiveHour float64
weekly float64
fiveHourCapV float64
weeklyCapV float64
weight float64
fiveReset time.Time
weeklyReset time.Time
availability availabilityStatus
}
func (c *testCredential) tagName() string { return c.tag }
func (c *testCredential) isAvailable() bool { return c.available }
func (c *testCredential) isUsable() bool { return c.usable }
func (c *testCredential) isExternal() bool { return c.external }
func (c *testCredential) hasSnapshotData() bool { return c.hasData }
func (c *testCredential) fiveHourUtilization() float64 { return c.fiveHour }
func (c *testCredential) weeklyUtilization() float64 { return c.weekly }
func (c *testCredential) fiveHourCap() float64 { return c.fiveHourCapV }
func (c *testCredential) weeklyCap() float64 { return c.weeklyCapV }
func (c *testCredential) planWeight() float64 { return c.weight }
func (c *testCredential) fiveHourResetTime() time.Time { return c.fiveReset }
func (c *testCredential) weeklyResetTime() time.Time { return c.weeklyReset }
func (c *testCredential) markRateLimited(time.Time) {}
func (c *testCredential) markUpstreamRejected() {}
func (c *testCredential) availabilityStatus() availabilityStatus { return c.availability }
func (c *testCredential) earliestReset() time.Time { return c.fiveReset }
func (c *testCredential) unavailableError() error { return nil }
func (c *testCredential) getAccessToken() (string, error) { return "", nil }
func (c *testCredential) buildProxyRequest(context.Context, *http.Request, []byte, http.Header) (*http.Request, error) {
return nil, nil
}
func (c *testCredential) updateStateFromHeaders(http.Header) {}
func (c *testCredential) wrapRequestContext(context.Context) *credentialRequestContext { return nil }
func (c *testCredential) interruptConnections() {}
func (c *testCredential) setStatusSubscriber(*observable.Subscriber[struct{}]) {}
func (c *testCredential) start() error { return nil }
func (c *testCredential) pollUsage() {}
func (c *testCredential) lastUpdatedTime() time.Time { return time.Now() }
func (c *testCredential) pollBackoff(time.Duration) time.Duration { return 0 }
func (c *testCredential) usageTrackerOrNil() *AggregatedUsage { return nil }
func (c *testCredential) httpClient() *http.Client { return nil }
func (c *testCredential) close() {}
type testProvider struct {
credentials []Credential
}
func (p *testProvider) selectCredential(string, credentialSelection) (Credential, bool, error) {
return nil, false, nil
}
func (p *testProvider) onRateLimited(string, Credential, time.Time, credentialSelection) Credential {
return nil
}
func (p *testProvider) linkProviderInterrupt(Credential, credentialSelection, func()) func() bool {
return func() bool { return true }
}
func (p *testProvider) pollIfStale() {}
func (p *testProvider) pollCredentialIfStale(Credential) {}
func (p *testProvider) allCredentials() []Credential { return p.credentials }
func (p *testProvider) close() {}
func TestComputeAggregatedUtilizationPreservesSnapshotForRateLimitedCredential(t *testing.T) {
t.Parallel()
reset := time.Now().Add(15 * time.Minute)
service := &Service{}
status := service.computeAggregatedUtilization(&testProvider{credentials: []Credential{
&testCredential{
tag: "a",
available: true,
usable: false,
hasData: true,
fiveHour: 42,
weekly: 18,
fiveHourCapV: 100,
weeklyCapV: 100,
weight: 1,
fiveReset: reset,
weeklyReset: reset.Add(2 * time.Hour),
availability: availabilityStatus{State: availabilityStateRateLimited, Reason: availabilityReasonHardRateLimit, ResetAt: reset},
},
}}, nil)
if status.fiveHourUtilization != 42 || status.weeklyUtilization != 18 {
t.Fatalf("expected preserved utilization, got 5h=%v weekly=%v", status.fiveHourUtilization, status.weeklyUtilization)
}
if status.availability.State != availabilityStateRateLimited {
t.Fatalf("expected rate-limited availability, got %#v", status.availability)
}
}
func TestRewriteResponseHeadersComputesUnifiedStatus(t *testing.T) {
t.Parallel()
reset := time.Now().Add(80 * time.Minute)
service := &Service{}
headers := make(http.Header)
service.rewriteResponseHeaders(headers, &testProvider{credentials: []Credential{
&testCredential{
tag: "a",
available: true,
usable: true,
hasData: true,
fiveHour: 92,
weekly: 30,
fiveHourCapV: 100,
weeklyCapV: 100,
weight: 1,
fiveReset: reset,
weeklyReset: time.Now().Add(4 * 24 * time.Hour),
availability: availabilityStatus{State: availabilityStateUsable},
},
}}, nil)
if headers.Get("anthropic-ratelimit-unified-status") != "allowed_warning" {
t.Fatalf("expected allowed_warning, got %q", headers.Get("anthropic-ratelimit-unified-status"))
}
if headers.Get("anthropic-ratelimit-unified-representative-claim") != "5h" {
t.Fatalf("expected 5h representative claim, got %q", headers.Get("anthropic-ratelimit-unified-representative-claim"))
}
if headers.Get("anthropic-ratelimit-unified-5h-surpassed-threshold") != "true" {
t.Fatalf("expected 5h threshold header")
}
}
func TestRewriteResponseHeadersStripsUpstreamHeaders(t *testing.T) {
t.Parallel()
service := &Service{}
headers := make(http.Header)
headers.Set("anthropic-ratelimit-unified-overage-status", "rejected")
headers.Set("anthropic-ratelimit-unified-overage-disabled-reason", "org_level_disabled")
headers.Set("anthropic-ratelimit-unified-fallback", "available")
service.rewriteResponseHeaders(headers, &testProvider{credentials: []Credential{
&testCredential{
tag: "a",
available: true,
usable: true,
hasData: true,
fiveHour: 10,
weekly: 5,
fiveHourCapV: 100,
weeklyCapV: 100,
weight: 1,
fiveReset: time.Now().Add(3 * time.Hour),
weeklyReset: time.Now().Add(5 * 24 * time.Hour),
availability: availabilityStatus{State: availabilityStateUsable},
},
}}, nil)
if headers.Get("anthropic-ratelimit-unified-overage-status") != "" {
t.Fatalf("expected overage-status stripped, got %q", headers.Get("anthropic-ratelimit-unified-overage-status"))
}
if headers.Get("anthropic-ratelimit-unified-overage-disabled-reason") != "" {
t.Fatalf("expected overage-disabled-reason stripped, got %q", headers.Get("anthropic-ratelimit-unified-overage-disabled-reason"))
}
if headers.Get("anthropic-ratelimit-unified-fallback") != "" {
t.Fatalf("expected fallback stripped, got %q", headers.Get("anthropic-ratelimit-unified-fallback"))
}
if headers.Get("anthropic-ratelimit-unified-status") != "allowed" {
t.Fatalf("expected allowed status, got %q", headers.Get("anthropic-ratelimit-unified-status"))
}
}
func TestRewriteResponseHeadersRejectedOnHardRateLimit(t *testing.T) {
t.Parallel()
reset := time.Now().Add(10 * time.Minute)
service := &Service{}
headers := make(http.Header)
service.rewriteResponseHeaders(headers, &testProvider{credentials: []Credential{
&testCredential{
tag: "a",
available: true,
usable: false,
hasData: true,
fiveHour: 50,
weekly: 20,
fiveHourCapV: 100,
weeklyCapV: 100,
weight: 1,
fiveReset: reset,
weeklyReset: time.Now().Add(5 * 24 * time.Hour),
availability: availabilityStatus{State: availabilityStateRateLimited, Reason: availabilityReasonHardRateLimit, ResetAt: reset},
},
}}, nil)
if headers.Get("anthropic-ratelimit-unified-status") != "rejected" {
t.Fatalf("expected rejected (hard rate limited), got %q", headers.Get("anthropic-ratelimit-unified-status"))
}
}
func TestWriteCredentialUnavailableErrorReturns429ForRateLimitedCredentials(t *testing.T) {
t.Parallel()
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, "/v1/messages", nil)
provider := &testProvider{credentials: []Credential{
&testCredential{
tag: "a",
available: true,
usable: false,
hasData: true,
fiveHourCapV: 100,
weeklyCapV: 100,
weight: 1,
availability: availabilityStatus{State: availabilityStateRateLimited, Reason: availabilityReasonHardRateLimit, ResetAt: time.Now().Add(time.Minute)},
},
}}
writeCredentialUnavailableError(recorder, request, provider, provider.credentials[0], credentialSelection{}, "all credentials rate-limited")
if recorder.Code != http.StatusTooManyRequests {
t.Fatalf("expected 429, got %d", recorder.Code)
}
}

View File

@@ -35,13 +35,13 @@ type CostCombination struct {
type AggregatedUsage struct {
LastUpdated time.Time `json:"last_updated"`
Combinations []CostCombination `json:"combinations"`
mutex sync.Mutex
access sync.Mutex
filePath string
logger log.ContextLogger
lastSaveTime time.Time
pendingSave bool
saveTimer *time.Timer
saveMutex sync.Mutex
saveAccess sync.Mutex
}
type UsageStatsJSON struct {
@@ -527,8 +527,8 @@ func deriveWeekStartUnix(cycleHint *WeeklyCycleHint) int64 {
}
func (u *AggregatedUsage) ToJSON() *AggregatedUsageJSON {
u.mutex.Lock()
defer u.mutex.Unlock()
u.access.Lock()
defer u.access.Unlock()
result := &AggregatedUsageJSON{
LastUpdated: u.LastUpdated,
@@ -561,8 +561,8 @@ func (u *AggregatedUsage) ToJSON() *AggregatedUsageJSON {
}
func (u *AggregatedUsage) Load() error {
u.mutex.Lock()
defer u.mutex.Unlock()
u.access.Lock()
defer u.access.Unlock()
u.LastUpdated = time.Time{}
u.Combinations = nil
@@ -608,9 +608,9 @@ func (u *AggregatedUsage) Save() error {
defer os.Remove(tmpFile)
err = os.Rename(tmpFile, u.filePath)
if err == nil {
u.saveMutex.Lock()
u.saveAccess.Lock()
u.lastSaveTime = time.Now()
u.saveMutex.Unlock()
u.saveAccess.Unlock()
}
return err
}
@@ -644,15 +644,15 @@ func (u *AggregatedUsage) AddUsageWithCycleHint(
observedAt = time.Now()
}
u.mutex.Lock()
defer u.mutex.Unlock()
u.access.Lock()
defer u.access.Unlock()
u.LastUpdated = observedAt
weekStartUnix := deriveWeekStartUnix(cycleHint)
addUsageToCombinations(&u.Combinations, model, contextWindow, weekStartUnix, messagesCount, inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens, cacheCreation5MinuteTokens, cacheCreation1HourTokens, user)
go u.scheduleSave()
u.scheduleSave()
return nil
}
@@ -660,8 +660,8 @@ func (u *AggregatedUsage) AddUsageWithCycleHint(
func (u *AggregatedUsage) scheduleSave() {
const saveInterval = time.Minute
u.saveMutex.Lock()
defer u.saveMutex.Unlock()
u.saveAccess.Lock()
defer u.saveAccess.Unlock()
timeSinceLastSave := time.Since(u.lastSaveTime)
@@ -678,9 +678,9 @@ func (u *AggregatedUsage) scheduleSave() {
remainingTime := saveInterval - timeSinceLastSave
u.saveTimer = time.AfterFunc(remainingTime, func() {
u.saveMutex.Lock()
u.saveAccess.Lock()
u.pendingSave = false
u.saveMutex.Unlock()
u.saveAccess.Unlock()
u.saveAsync()
})
}
@@ -695,8 +695,8 @@ func (u *AggregatedUsage) saveAsync() {
}
func (u *AggregatedUsage) cancelPendingSave() {
u.saveMutex.Lock()
defer u.saveMutex.Unlock()
u.saveAccess.Lock()
defer u.saveAccess.Unlock()
if u.saveTimer != nil {
u.saveTimer.Stop()

View File

@@ -7,13 +7,13 @@ import (
)
type UserManager struct {
accessMutex sync.RWMutex
tokenMap map[string]string
access sync.RWMutex
tokenMap map[string]string
}
func (m *UserManager) UpdateUsers(users []option.CCMUser) {
m.accessMutex.Lock()
defer m.accessMutex.Unlock()
m.access.Lock()
defer m.access.Unlock()
tokenMap := make(map[string]string, len(users))
for _, user := range users {
tokenMap[user.Token] = user.Name
@@ -22,8 +22,8 @@ func (m *UserManager) UpdateUsers(users []option.CCMUser) {
}
func (m *UserManager) Authenticate(token string) (string, bool) {
m.accessMutex.RLock()
m.access.RLock()
username, found := m.tokenMap[token]
m.accessMutex.RUnlock()
m.access.RUnlock()
return username, found
}

View File

@@ -0,0 +1,139 @@
package ccm
import (
"context"
"io"
"net/http"
"net/http/httptest"
"path/filepath"
"strings"
"testing"
"time"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(request *http.Request) (*http.Response, error) {
return f(request)
}
func newJSONResponse(statusCode int, body string) *http.Response {
return &http.Response{
StatusCode: statusCode,
Status: http.StatusText(statusCode),
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(body)),
}
}
func newTextResponse(statusCode int, body string) *http.Response {
return &http.Response{
StatusCode: statusCode,
Status: http.StatusText(statusCode),
Header: http.Header{"Content-Type": []string{"text/plain"}},
Body: io.NopCloser(strings.NewReader(body)),
}
}
func writeTestCredentials(t *testing.T, path string, credentials *oauthCredentials) {
t.Helper()
if path == "" {
var err error
path, err = getDefaultCredentialsPath()
if err != nil {
t.Fatal(err)
}
}
if err := writeCredentialsToFile(credentials, path); err != nil {
t.Fatal(err)
}
}
func readTestCredentials(t *testing.T, path string) *oauthCredentials {
t.Helper()
if path == "" {
var err error
path, err = getDefaultCredentialsPath()
if err != nil {
t.Fatal(err)
}
}
credentials, err := readCredentialsFromFile(path)
if err != nil {
t.Fatal(err)
}
return credentials
}
func newTestDefaultCredential(t *testing.T, credentialPath string, transport http.RoundTripper) *defaultCredential {
t.Helper()
credentialFilePath, err := resolveCredentialFilePath(credentialPath)
if err != nil {
t.Fatal(err)
}
requestContext, cancelRequests := context.WithCancel(context.Background())
credential := &defaultCredential{
tag: "test",
serviceContext: context.Background(),
credentialPath: credentialPath,
credentialFilePath: credentialFilePath,
configDir: resolveConfigDir(credentialPath, credentialFilePath),
syncClaudeConfig: credentialPath == "",
cap5h: 99,
capWeekly: 99,
forwardHTTPClient: &http.Client{Transport: transport},
acquireLock: acquireCredentialLock,
logger: log.NewNOPFactory().Logger(),
requestContext: requestContext,
cancelRequests: cancelRequests,
}
if credential.syncClaudeConfig {
credential.claudeDirectory = credential.configDir
credential.claudeConfigPath = resolveClaudeConfigWritePath(credential.claudeDirectory)
}
credential.state.lastUpdated = time.Now()
return credential
}
func seedTestCredentialState(credential *defaultCredential) {
billingType := "individual"
accountCreatedAt := "2024-01-01T00:00:00Z"
subscriptionCreatedAt := "2024-01-02T00:00:00Z"
credential.stateAccess.Lock()
credential.state.accountUUID = "account"
credential.state.accountType = "max"
credential.state.rateLimitTier = "default_claude_max_20x"
credential.state.oauthAccount = &claudeOAuthAccount{
AccountUUID: "account",
EmailAddress: "user@example.com",
OrganizationUUID: "org",
BillingType: &billingType,
AccountCreatedAt: &accountCreatedAt,
SubscriptionCreatedAt: &subscriptionCreatedAt,
}
credential.stateAccess.Unlock()
}
func newTestService(credential *defaultCredential) *Service {
return &Service{
logger: log.NewNOPFactory().Logger(),
options: option.CCMServiceOptions{Credentials: []option.CCMCredential{{Tag: "default"}}},
httpHeaders: make(http.Header),
providers: map[string]credentialProvider{"default": &singleCredentialProvider{credential: credential}},
sessionModels: make(map[sessionModelKey]time.Time),
}
}
func newMessageRequest(body string) *http.Request {
request := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
request.Header.Set("Content-Type", "application/json")
return request
}
func tempConfigPath(t *testing.T, dir string) string {
t.Helper()
return filepath.Join(dir, claudeCodeLegacyConfigFileName())
}

7
service/ocm/CLAUDE.md Normal file
View File

@@ -0,0 +1,7 @@
# OpenAI Codex Multiplexer
### Reverse Codex
Oh, Codex is just open source.
Clone it and study its code: https://github.com/openai/codex

View File

@@ -1,225 +1,275 @@
package ocm
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"os"
"os/user"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/observable"
)
const (
oauth2ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
oauth2TokenURL = "https://auth.openai.com/oauth/token"
openaiAPIBaseURL = "https://api.openai.com"
chatGPTBackendURL = "https://chatgpt.com/backend-api/codex"
tokenRefreshIntervalDays = 8
defaultPollInterval = 60 * time.Minute
failedPollRetryInterval = time.Minute
httpRetryMaxBackoff = 5 * time.Minute
)
func getRealUser() (*user.User, error) {
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
sudoUserInfo, err := user.Lookup(sudoUser)
if err == nil {
return sudoUserInfo, nil
const (
httpRetryMaxAttempts = 3
httpRetryInitialDelay = 200 * time.Millisecond
)
const sessionExpiry = 24 * time.Hour
func doHTTPWithRetry(ctx context.Context, client *http.Client, buildRequest func() (*http.Request, error)) (*http.Response, error) {
var lastError error
for attempt := range httpRetryMaxAttempts {
if attempt > 0 {
delay := httpRetryInitialDelay * time.Duration(1<<(attempt-1))
timer := time.NewTimer(delay)
select {
case <-ctx.Done():
timer.Stop()
return nil, lastError
case <-timer.C:
}
}
}
return user.Current()
}
func getDefaultCredentialsPath() (string, error) {
if codexHome := os.Getenv("CODEX_HOME"); codexHome != "" {
return filepath.Join(codexHome, "auth.json"), nil
}
userInfo, err := getRealUser()
if err != nil {
return "", err
}
return filepath.Join(userInfo.HomeDir, ".codex", "auth.json"), nil
}
func readCredentialsFromFile(path string) (*oauthCredentials, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var credentials oauthCredentials
err = json.Unmarshal(data, &credentials)
if err != nil {
return nil, err
}
return &credentials, nil
}
func checkCredentialFileWritable(path string) error {
file, err := os.OpenFile(path, os.O_WRONLY, 0)
if err != nil {
return err
}
return file.Close()
}
func writeCredentialsToFile(credentials *oauthCredentials, path string) error {
data, err := json.MarshalIndent(credentials, "", " ")
if err != nil {
return err
}
return os.WriteFile(path, data, 0o600)
}
type oauthCredentials struct {
APIKey string `json:"OPENAI_API_KEY,omitempty"`
Tokens *tokenData `json:"tokens,omitempty"`
LastRefresh *time.Time `json:"last_refresh,omitempty"`
}
type tokenData struct {
IDToken string `json:"id_token,omitempty"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
AccountID string `json:"account_id,omitempty"`
}
func (c *oauthCredentials) isAPIKeyMode() bool {
return c.APIKey != ""
}
func (c *oauthCredentials) getAccessToken() string {
if c.APIKey != "" {
return c.APIKey
}
if c.Tokens != nil {
return c.Tokens.AccessToken
}
return ""
}
func (c *oauthCredentials) getAccountID() string {
if c.Tokens != nil {
return c.Tokens.AccountID
}
return ""
}
func (c *oauthCredentials) needsRefresh() bool {
if c.APIKey != "" {
return false
}
if c.Tokens == nil || c.Tokens.RefreshToken == "" {
return false
}
if c.LastRefresh == nil {
return true
}
return time.Since(*c.LastRefresh) >= time.Duration(tokenRefreshIntervalDays)*24*time.Hour
}
func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) {
if credentials.Tokens == nil || credentials.Tokens.RefreshToken == "" {
return nil, E.New("refresh token is empty")
}
requestBody, err := json.Marshal(map[string]string{
"grant_type": "refresh_token",
"refresh_token": credentials.Tokens.RefreshToken,
"client_id": oauth2ClientID,
"scope": "openid profile email",
})
if err != nil {
return nil, E.Cause(err, "marshal request")
}
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
request, err := buildRequest()
if err != nil {
return nil, err
}
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Accept", "application/json")
return request, nil
response, err := client.Do(request)
if err == nil {
return response, nil
}
lastError = err
if ctx.Err() != nil {
return nil, lastError
}
}
return nil, lastError
}
type credentialState struct {
fiveHourUtilization float64
fiveHourReset time.Time
weeklyUtilization float64
weeklyReset time.Time
hardRateLimited bool
rateLimitResetAt time.Time
availabilityState availabilityState
availabilityReason availabilityReason
availabilityResetAt time.Time
lastKnownDataAt time.Time
accountType string
remotePlanWeight float64
activeLimitID string
rateLimitSnapshots map[string]rateLimitSnapshot
lastUpdated time.Time
consecutivePollFailures int
usageAPIRetryDelay time.Duration
unavailable bool
upstreamRejectedUntil time.Time
lastCredentialLoadAttempt time.Time
lastCredentialLoadError string
}
type credentialRequestContext struct {
context.Context
releaseOnce sync.Once
cancelOnce sync.Once
releaseFuncs []func() bool
cancelFunc context.CancelFunc
}
func (c *credentialRequestContext) addInterruptLink(stop func() bool) {
c.releaseFuncs = append(c.releaseFuncs, stop)
}
func (c *credentialRequestContext) releaseCredentialInterrupt() {
c.releaseOnce.Do(func() {
for _, f := range c.releaseFuncs {
f()
}
})
if err != nil {
return nil, err
}
defer response.Body.Close()
}
if response.StatusCode == http.StatusTooManyRequests {
body, _ := io.ReadAll(response.Body)
return nil, E.New("refresh rate limited: ", response.Status, " ", string(body))
}
if response.StatusCode != http.StatusOK {
body, _ := io.ReadAll(response.Body)
return nil, E.New("refresh failed: ", response.Status, " ", string(body))
}
func (c *credentialRequestContext) cancelRequest() {
c.releaseCredentialInterrupt()
c.cancelOnce.Do(c.cancelFunc)
}
var tokenResponse struct {
IDToken string `json:"id_token"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
}
err = json.NewDecoder(response.Body).Decode(&tokenResponse)
if err != nil {
return nil, E.Cause(err, "decode response")
}
type Credential interface {
tagName() string
isAvailable() bool
isUsable() bool
isExternal() bool
hasSnapshotData() bool
fiveHourUtilization() float64
weeklyUtilization() float64
fiveHourCap() float64
weeklyCap() float64
planWeight() float64
weeklyResetTime() time.Time
fiveHourResetTime() time.Time
markRateLimited(resetAt time.Time)
markUpstreamRejected()
markTemporarilyBlocked(reason availabilityReason, resetAt time.Time)
availabilityStatus() availabilityStatus
earliestReset() time.Time
unavailableError() error
newCredentials := *credentials
if newCredentials.Tokens == nil {
newCredentials.Tokens = &tokenData{}
getAccessToken() (string, error)
buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error)
updateStateFromHeaders(header http.Header)
wrapRequestContext(ctx context.Context) *credentialRequestContext
interruptConnections()
setOnBecameUnusable(fn func())
setStatusSubscriber(*observable.Subscriber[struct{}])
start() error
pollUsage()
lastUpdatedTime() time.Time
pollBackoff(base time.Duration) time.Duration
usageTrackerOrNil() *AggregatedUsage
httpClient() *http.Client
close()
// OCM-specific
ocmDialer() N.Dialer
ocmIsAPIKeyMode() bool
ocmGetAccountID() string
ocmGetBaseURL() string
}
type credentialSelectionScope string
const (
credentialSelectionScopeAll credentialSelectionScope = "all"
credentialSelectionScopeNonExternal credentialSelectionScope = "non_external"
)
type credentialSelection struct {
scope credentialSelectionScope
filter func(Credential) bool
}
func (s credentialSelection) allows(credential Credential) bool {
return s.filter == nil || s.filter(credential)
}
func (s credentialSelection) scopeOrDefault() credentialSelectionScope {
if s.scope == "" {
return credentialSelectionScopeAll
}
if tokenResponse.IDToken != "" {
newCredentials.Tokens.IDToken = tokenResponse.IDToken
return s.scope
}
func normalizeRateLimitIdentifier(limitIdentifier string) string {
trimmedIdentifier := strings.TrimSpace(strings.ToLower(limitIdentifier))
if trimmedIdentifier == "" {
return ""
}
if tokenResponse.AccessToken != "" {
newCredentials.Tokens.AccessToken = tokenResponse.AccessToken
return strings.ReplaceAll(trimmedIdentifier, "_", "-")
}
func parseInt64Header(headers http.Header, headerName string) (int64, bool) {
headerValue := strings.TrimSpace(headers.Get(headerName))
if headerValue == "" {
return 0, false
}
if tokenResponse.RefreshToken != "" {
newCredentials.Tokens.RefreshToken = tokenResponse.RefreshToken
parsedValue, parseError := strconv.ParseInt(headerValue, 10, 64)
if parseError != nil {
return 0, false
}
return parsedValue, true
}
func parseOCMRateLimitResetFromHeaders(headers http.Header) time.Time {
activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
if activeLimitIdentifier != "" {
resetHeader := "x-" + activeLimitIdentifier + "-primary-reset-at"
if resetStr := headers.Get(resetHeader); resetStr != "" {
value, err := strconv.ParseInt(resetStr, 10, 64)
if err == nil {
return time.Unix(value, 0)
}
}
}
if retryAfter := headers.Get("Retry-After"); retryAfter != "" {
seconds, err := strconv.ParseInt(retryAfter, 10, 64)
if err == nil {
return time.Now().Add(time.Duration(seconds) * time.Second)
}
}
return time.Now().Add(5 * time.Minute)
}
func (s *credentialState) noteSnapshotData() {
s.lastKnownDataAt = time.Now()
}
func (s credentialState) hasSnapshotData() bool {
return !s.lastKnownDataAt.IsZero() ||
s.fiveHourUtilization > 0 ||
s.weeklyUtilization > 0 ||
!s.fiveHourReset.IsZero() ||
!s.weeklyReset.IsZero() ||
len(s.rateLimitSnapshots) > 0
}
func (s *credentialState) setAvailability(state availabilityState, reason availabilityReason, resetAt time.Time) {
s.availabilityState = state
s.availabilityReason = reason
s.availabilityResetAt = resetAt
}
func (s credentialState) currentAvailability() availabilityStatus {
now := time.Now()
newCredentials.LastRefresh = &now
return &newCredentials, nil
}
func cloneCredentials(credentials *oauthCredentials) *oauthCredentials {
if credentials == nil {
return nil
}
cloned := *credentials
if credentials.Tokens != nil {
clonedTokens := *credentials.Tokens
cloned.Tokens = &clonedTokens
}
if credentials.LastRefresh != nil {
lastRefresh := *credentials.LastRefresh
cloned.LastRefresh = &lastRefresh
}
return &cloned
}
func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool {
if left == nil || right == nil {
return left == right
}
if left.APIKey != right.APIKey {
return false
}
if (left.Tokens == nil) != (right.Tokens == nil) {
return false
}
if left.Tokens != nil && *left.Tokens != *right.Tokens {
return false
}
if (left.LastRefresh == nil) != (right.LastRefresh == nil) {
return false
}
if left.LastRefresh != nil && !left.LastRefresh.Equal(*right.LastRefresh) {
return false
}
return true
switch {
case s.unavailable:
return availabilityStatus{
State: availabilityStateUnavailable,
Reason: availabilityReasonUnknown,
}
case s.availabilityState == availabilityStateTemporarilyBlocked &&
(s.availabilityResetAt.IsZero() || now.Before(s.availabilityResetAt)):
reason := s.availabilityReason
if reason == "" {
reason = availabilityReasonUnknown
}
return availabilityStatus{
State: availabilityStateTemporarilyBlocked,
Reason: reason,
ResetAt: s.availabilityResetAt,
}
case s.hardRateLimited && (s.rateLimitResetAt.IsZero() || now.Before(s.rateLimitResetAt)):
reason := s.availabilityReason
if reason == "" {
reason = availabilityReasonHardRateLimit
}
return availabilityStatus{
State: availabilityStateRateLimited,
Reason: reason,
ResetAt: s.rateLimitResetAt,
}
case !s.upstreamRejectedUntil.IsZero() && now.Before(s.upstreamRejectedUntil):
return availabilityStatus{
State: availabilityStateTemporarilyBlocked,
Reason: availabilityReasonUpstreamRejected,
ResetAt: s.upstreamRejectedUntil,
}
case s.consecutivePollFailures > 0:
return availabilityStatus{
State: availabilityStateTemporarilyBlocked,
Reason: availabilityReasonPollFailed,
}
default:
return availabilityStatus{State: availabilityStateUsable}
}
}

View File

@@ -0,0 +1,193 @@
package ocm
import (
"context"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
)
func buildOCMCredentialProviders(
ctx context.Context,
options option.OCMServiceOptions,
logger log.ContextLogger,
) (map[string]credentialProvider, []Credential, error) {
allCredentialMap := make(map[string]Credential)
var allCredentials []Credential
providers := make(map[string]credentialProvider)
// Pass 1: create default and external credentials
for _, credentialOption := range options.Credentials {
switch credentialOption.Type {
case "default":
credential, err := newDefaultCredential(ctx, credentialOption.Tag, credentialOption.DefaultOptions, logger)
if err != nil {
return nil, nil, err
}
allCredentialMap[credentialOption.Tag] = credential
allCredentials = append(allCredentials, credential)
providers[credentialOption.Tag] = &singleCredentialProvider{credential: credential}
case "external":
credential, err := newExternalCredential(ctx, credentialOption.Tag, credentialOption.ExternalOptions, logger)
if err != nil {
return nil, nil, err
}
allCredentialMap[credentialOption.Tag] = credential
allCredentials = append(allCredentials, credential)
providers[credentialOption.Tag] = &singleCredentialProvider{credential: credential}
}
}
// Pass 2: create balancer providers
for _, credentialOption := range options.Credentials {
if credentialOption.Type == "balancer" {
subCredentials, err := resolveCredentialTags(credentialOption.BalancerOptions.Credentials, allCredentialMap, credentialOption.Tag)
if err != nil {
return nil, nil, err
}
providers[credentialOption.Tag] = newBalancerProvider(subCredentials, credentialOption.BalancerOptions.Strategy, credentialOption.BalancerOptions.RebalanceThreshold, logger)
}
}
return providers, allCredentials, nil
}
func resolveCredentialTags(tags []string, allCredentials map[string]Credential, parentTag string) ([]Credential, error) {
credentials := make([]Credential, 0, len(tags))
for _, tag := range tags {
credential, exists := allCredentials[tag]
if !exists {
return nil, E.New("credential ", parentTag, " references unknown credential: ", tag)
}
credentials = append(credentials, credential)
}
if len(credentials) == 0 {
return nil, E.New("credential ", parentTag, " has no sub-credentials")
}
return credentials, nil
}
func validateOCMOptions(options option.OCMServiceOptions) error {
tags := make(map[string]bool)
credentialTypes := make(map[string]string)
for _, credential := range options.Credentials {
if tags[credential.Tag] {
return E.New("duplicate credential tag: ", credential.Tag)
}
tags[credential.Tag] = true
credentialTypes[credential.Tag] = credential.Type
if credential.Type == "default" || credential.Type == "" {
if credential.DefaultOptions.Reserve5h > 99 {
return E.New("credential ", credential.Tag, ": reserve_5h must be at most 99")
}
if credential.DefaultOptions.ReserveWeekly > 99 {
return E.New("credential ", credential.Tag, ": reserve_weekly must be at most 99")
}
if credential.DefaultOptions.Limit5h > 100 {
return E.New("credential ", credential.Tag, ": limit_5h must be at most 100")
}
if credential.DefaultOptions.LimitWeekly > 100 {
return E.New("credential ", credential.Tag, ": limit_weekly must be at most 100")
}
if credential.DefaultOptions.Reserve5h > 0 && credential.DefaultOptions.Limit5h > 0 {
return E.New("credential ", credential.Tag, ": reserve_5h and limit_5h are mutually exclusive")
}
if credential.DefaultOptions.ReserveWeekly > 0 && credential.DefaultOptions.LimitWeekly > 0 {
return E.New("credential ", credential.Tag, ": reserve_weekly and limit_weekly are mutually exclusive")
}
}
if credential.Type == "external" {
if credential.ExternalOptions.Token == "" {
return E.New("credential ", credential.Tag, ": external credential requires token")
}
if credential.ExternalOptions.Reverse && credential.ExternalOptions.URL == "" {
return E.New("credential ", credential.Tag, ": reverse external credential requires url")
}
}
if credential.Type == "balancer" {
switch credential.BalancerOptions.Strategy {
case "", C.BalancerStrategyLeastUsed, C.BalancerStrategyRoundRobin, C.BalancerStrategyRandom, C.BalancerStrategyFallback:
default:
return E.New("credential ", credential.Tag, ": unknown balancer strategy: ", credential.BalancerOptions.Strategy)
}
if credential.BalancerOptions.RebalanceThreshold < 0 {
return E.New("credential ", credential.Tag, ": rebalance_threshold must not be negative")
}
}
}
singleCredential := len(options.Credentials) == 1
for _, user := range options.Users {
if user.Credential == "" && !singleCredential {
return E.New("user ", user.Name, " must specify credential in multi-credential mode")
}
if user.Credential != "" && !tags[user.Credential] {
return E.New("user ", user.Name, " references unknown credential: ", user.Credential)
}
if user.ExternalCredential != "" {
if !tags[user.ExternalCredential] {
return E.New("user ", user.Name, " references unknown external_credential: ", user.ExternalCredential)
}
if credentialTypes[user.ExternalCredential] != "external" {
return E.New("user ", user.Name, ": external_credential must reference an external type credential")
}
}
}
return nil
}
func validateOCMCompositeCredentialModes(
options option.OCMServiceOptions,
providers map[string]credentialProvider,
) error {
for _, credentialOption := range options.Credentials {
if credentialOption.Type != "balancer" {
continue
}
provider, exists := providers[credentialOption.Tag]
if !exists {
return E.New("unknown credential: ", credentialOption.Tag)
}
for _, subCred := range provider.allCredentials() {
if !subCred.isAvailable() {
continue
}
if subCred.ocmIsAPIKeyMode() {
return E.New(
"credential ", credentialOption.Tag,
" references API key default credential ", subCred.tagName(),
"; balancer and fallback only support OAuth default credentials",
)
}
}
}
return nil
}
func credentialForUser(
userConfigMap map[string]*option.OCMUser,
providers map[string]credentialProvider,
username string,
) (credentialProvider, error) {
userConfig, exists := userConfigMap[username]
if !exists {
return nil, E.New("no credential mapping for user: ", username)
}
if userConfig.Credential == "" {
for _, provider := range providers {
return provider, nil
}
return nil, E.New("no credential available")
}
provider, exists := providers[userConfig.Credential]
if !exists {
return nil, E.New("unknown credential: ", userConfig.Credential)
}
return provider, nil
}

View File

@@ -0,0 +1,826 @@
package ocm
import (
"bytes"
"context"
stdTLS "crypto/tls"
"encoding/json"
"io"
"net"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/sagernet/fswatch"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp"
"github.com/sagernet/sing/common/observable"
)
type defaultCredential struct {
tag string
serviceContext context.Context
credentialPath string
credentialFilePath string
credentials *oauthCredentials
access sync.RWMutex
state credentialState
stateAccess sync.RWMutex
pollAccess sync.Mutex
reloadAccess sync.Mutex
watcherAccess sync.Mutex
cap5h float64
capWeekly float64
usageTracker *AggregatedUsage
dialer N.Dialer
forwardHTTPClient *http.Client
logger log.ContextLogger
watcher *fswatch.Watcher
watcherRetryAt time.Time
statusSubscriber *observable.Subscriber[struct{}]
// Refresh rate-limit cooldown (protected by access mutex)
refreshRetryAt time.Time
refreshRetryError error
refreshBlocked bool
// Connection interruption
onBecameUnusable func()
interrupted bool
requestContext context.Context
cancelRequests context.CancelFunc
requestAccess sync.Mutex
}
func newDefaultCredential(ctx context.Context, tag string, options option.OCMDefaultCredentialOptions, logger log.ContextLogger) (*defaultCredential, error) {
credentialDialer, err := dialer.NewWithOptions(dialer.Options{
Context: ctx,
Options: option.DialerOptions{
Detour: options.Detour,
},
RemoteIsDomain: true,
})
if err != nil {
return nil, E.Cause(err, "create dialer for credential ", tag)
}
httpClient := &http.Client{
Transport: &http.Transport{
ForceAttemptHTTP2: true,
TLSClientConfig: &stdTLS.Config{
RootCAs: adapter.RootPoolFromContext(ctx),
Time: ntp.TimeFuncFromContext(ctx),
},
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
},
},
}
reserve5h := options.Reserve5h
if reserve5h == 0 {
reserve5h = 1
}
reserveWeekly := options.ReserveWeekly
if reserveWeekly == 0 {
reserveWeekly = 1
}
var cap5h float64
if options.Limit5h > 0 {
cap5h = float64(options.Limit5h)
} else {
cap5h = float64(100 - reserve5h)
}
var capWeekly float64
if options.LimitWeekly > 0 {
capWeekly = float64(options.LimitWeekly)
} else {
capWeekly = float64(100 - reserveWeekly)
}
requestContext, cancelRequests := context.WithCancel(context.Background())
credential := &defaultCredential{
tag: tag,
serviceContext: ctx,
credentialPath: options.CredentialPath,
cap5h: cap5h,
capWeekly: capWeekly,
dialer: credentialDialer,
forwardHTTPClient: httpClient,
logger: logger,
requestContext: requestContext,
cancelRequests: cancelRequests,
}
if options.UsagesPath != "" {
credential.usageTracker = &AggregatedUsage{
LastUpdated: time.Now(),
Combinations: make([]CostCombination, 0),
filePath: options.UsagesPath,
logger: logger,
}
}
return credential, nil
}
func (c *defaultCredential) start() error {
credentialFilePath, err := resolveCredentialFilePath(c.credentialPath)
if err != nil {
return E.Cause(err, "resolve credential path for ", c.tag)
}
c.credentialFilePath = credentialFilePath
err = c.ensureCredentialWatcher()
if err != nil {
c.logger.Debug("start credential watcher for ", c.tag, ": ", err)
}
err = c.reloadCredentials(true)
if err != nil {
c.logger.Warn("initial credential load for ", c.tag, ": ", err)
}
if c.usageTracker != nil {
err = c.usageTracker.Load()
if err != nil {
c.logger.Warn("load usage statistics for ", c.tag, ": ", err)
}
}
go c.pollUsage()
return nil
}
func (c *defaultCredential) setOnBecameUnusable(fn func()) {
c.onBecameUnusable = fn
}
func (c *defaultCredential) setStatusSubscriber(subscriber *observable.Subscriber[struct{}]) {
c.statusSubscriber = subscriber
}
func (c *defaultCredential) emitStatusUpdate() {
if c.statusSubscriber != nil {
c.statusSubscriber.Emit(struct{}{})
}
}
func (c *defaultCredential) tagName() string {
return c.tag
}
func (c *defaultCredential) isExternal() bool {
return false
}
func (c *defaultCredential) getAccessToken() (string, error) {
c.retryCredentialReloadIfNeeded()
c.access.RLock()
if c.credentials != nil && !c.credentials.needsRefresh() {
token := c.credentials.getAccessToken()
c.access.RUnlock()
return token, nil
}
c.access.RUnlock()
err := c.reloadCredentials(true)
if err == nil {
c.access.RLock()
if c.credentials != nil && !c.credentials.needsRefresh() {
token := c.credentials.getAccessToken()
c.access.RUnlock()
return token, nil
}
c.access.RUnlock()
}
c.access.Lock()
defer c.access.Unlock()
if c.credentials == nil {
return "", c.unavailableError()
}
if !c.credentials.needsRefresh() {
return c.credentials.getAccessToken(), nil
}
if c.refreshBlocked {
return "", c.refreshRetryError
}
if !c.refreshRetryAt.IsZero() && time.Now().Before(c.refreshRetryAt) {
return "", c.refreshRetryError
}
err = platformCanWriteCredentials(c.credentialPath)
if err != nil {
return "", E.Cause(err, "credential file not writable, refusing refresh to avoid invalidation")
}
baseCredentials := cloneCredentials(c.credentials)
newCredentials, retryDelay, err := refreshToken(c.serviceContext, c.forwardHTTPClient, c.credentials)
if err != nil {
if retryDelay < 0 {
c.refreshBlocked = true
c.refreshRetryError = err
} else if retryDelay > 0 {
c.refreshRetryAt = time.Now().Add(retryDelay)
c.refreshRetryError = err
}
return "", err
}
c.refreshRetryAt = time.Time{}
c.refreshRetryError = nil
c.refreshBlocked = false
latestCredentials, latestErr := platformReadCredentials(c.credentialPath)
if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) {
c.credentials = latestCredentials
c.stateAccess.Lock()
wasAvailable := !c.state.unavailable
c.state.unavailable = false
c.state.lastCredentialLoadAttempt = time.Now()
c.state.lastCredentialLoadError = ""
c.checkTransitionLocked()
shouldEmit := wasAvailable != !c.state.unavailable
c.stateAccess.Unlock()
if shouldEmit {
c.emitStatusUpdate()
}
if !latestCredentials.needsRefresh() {
return latestCredentials.getAccessToken(), nil
}
return "", E.New("credential ", c.tag, " changed while refreshing")
}
c.credentials = newCredentials
c.stateAccess.Lock()
wasAvailable := !c.state.unavailable
c.state.unavailable = false
c.state.lastCredentialLoadAttempt = time.Now()
c.state.lastCredentialLoadError = ""
c.checkTransitionLocked()
shouldEmit := wasAvailable != !c.state.unavailable
c.stateAccess.Unlock()
if shouldEmit {
c.emitStatusUpdate()
}
err = platformWriteCredentials(newCredentials, c.credentialPath)
if err != nil {
c.logger.Error("persist refreshed token for ", c.tag, ": ", err)
}
return newCredentials.getAccessToken(), nil
}
func (c *defaultCredential) getAccountID() string {
c.access.RLock()
defer c.access.RUnlock()
if c.credentials == nil {
return ""
}
return c.credentials.getAccountID()
}
func (c *defaultCredential) isAPIKeyMode() bool {
c.access.RLock()
defer c.access.RUnlock()
if c.credentials == nil {
return false
}
return c.credentials.isAPIKeyMode()
}
func (c *defaultCredential) getBaseURL() string {
if c.isAPIKeyMode() {
return openaiAPIBaseURL
}
return chatGPTBackendURL
}
func (c *defaultCredential) updateStateFromHeaders(headers http.Header) {
c.stateAccess.Lock()
isFirstUpdate := c.state.lastUpdated.IsZero()
oldFiveHour := c.state.fiveHourUtilization
oldWeekly := c.state.weeklyUtilization
hadData := false
activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
if activeLimitIdentifier == "" {
activeLimitIdentifier = "codex"
}
fiveHourResetChanged := false
fiveHourResetAt := headers.Get("x-" + activeLimitIdentifier + "-primary-reset-at")
if fiveHourResetAt != "" {
value, err := strconv.ParseInt(fiveHourResetAt, 10, 64)
if err == nil {
hadData = true
newReset := time.Unix(value, 0)
if newReset.After(c.state.fiveHourReset) {
fiveHourResetChanged = true
c.state.fiveHourReset = newReset
}
}
}
fiveHourPercent := headers.Get("x-" + activeLimitIdentifier + "-primary-used-percent")
if fiveHourPercent != "" {
value, err := strconv.ParseFloat(fiveHourPercent, 64)
if err == nil {
hadData = true
if value >= c.state.fiveHourUtilization || fiveHourResetChanged {
c.state.fiveHourUtilization = value
}
}
}
weeklyResetChanged := false
weeklyResetAt := headers.Get("x-" + activeLimitIdentifier + "-secondary-reset-at")
if weeklyResetAt != "" {
value, err := strconv.ParseInt(weeklyResetAt, 10, 64)
if err == nil {
hadData = true
newReset := time.Unix(value, 0)
if newReset.After(c.state.weeklyReset) {
weeklyResetChanged = true
c.state.weeklyReset = newReset
}
}
}
weeklyPercent := headers.Get("x-" + activeLimitIdentifier + "-secondary-used-percent")
if weeklyPercent != "" {
value, err := strconv.ParseFloat(weeklyPercent, 64)
if err == nil {
hadData = true
if value >= c.state.weeklyUtilization || weeklyResetChanged {
c.state.weeklyUtilization = value
}
}
}
if snapshots := parseRateLimitSnapshotsFromHeaders(headers); len(snapshots) > 0 {
hadData = true
applyRateLimitSnapshotsLocked(&c.state, snapshots, headers.Get("x-codex-active-limit"), c.state.remotePlanWeight, c.state.accountType)
}
if hadData {
c.state.consecutivePollFailures = 0
c.state.lastUpdated = time.Now()
c.state.noteSnapshotData()
}
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
resetSuffix := ""
if !c.state.weeklyReset.IsZero() {
resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset))
}
c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
}
shouldEmit := hadData && (c.state.fiveHourUtilization != oldFiveHour || c.state.weeklyUtilization != oldWeekly || fiveHourResetChanged || weeklyResetChanged)
shouldInterrupt := c.checkTransitionLocked()
c.stateAccess.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
if shouldEmit {
c.emitStatusUpdate()
}
}
func (c *defaultCredential) markRateLimited(resetAt time.Time) {
c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt)))
c.stateAccess.Lock()
c.state.hardRateLimited = true
c.state.rateLimitResetAt = resetAt
c.state.setAvailability(availabilityStateRateLimited, availabilityReasonHardRateLimit, resetAt)
shouldInterrupt := c.checkTransitionLocked()
c.stateAccess.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
c.emitStatusUpdate()
}
func (c *defaultCredential) markUpstreamRejected() {}
func (c *defaultCredential) markTemporarilyBlocked(reason availabilityReason, resetAt time.Time) {
c.stateAccess.Lock()
c.state.setAvailability(availabilityStateTemporarilyBlocked, reason, resetAt)
shouldInterrupt := c.checkTransitionLocked()
c.stateAccess.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
c.emitStatusUpdate()
}
func (c *defaultCredential) isUsable() bool {
c.retryCredentialReloadIfNeeded()
c.stateAccess.RLock()
if c.state.unavailable {
c.stateAccess.RUnlock()
return false
}
if c.state.consecutivePollFailures > 0 {
c.stateAccess.RUnlock()
return false
}
if c.state.hardRateLimited {
if time.Now().Before(c.state.rateLimitResetAt) {
c.stateAccess.RUnlock()
return false
}
c.stateAccess.RUnlock()
c.stateAccess.Lock()
if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) {
c.state.hardRateLimited = false
}
usable := c.checkReservesLocked()
c.stateAccess.Unlock()
return usable
}
usable := c.checkReservesLocked()
c.stateAccess.RUnlock()
return usable
}
func (c *defaultCredential) checkReservesLocked() bool {
if c.state.fiveHourUtilization >= c.cap5h {
return false
}
if c.state.weeklyUtilization >= c.capWeekly {
return false
}
return true
}
// checkTransitionLocked detects usable->unusable transition.
// Must be called with stateAccess write lock held.
func (c *defaultCredential) checkTransitionLocked() bool {
unusable := c.state.unavailable || c.state.hardRateLimited || !c.checkReservesLocked() || c.state.consecutivePollFailures > 0
if unusable && !c.interrupted {
c.interrupted = true
return true
}
if !unusable && c.interrupted {
c.interrupted = false
}
return false
}
func (c *defaultCredential) interruptConnections() {
c.logger.Warn("interrupting connections for ", c.tag)
c.requestAccess.Lock()
c.cancelRequests()
c.requestContext, c.cancelRequests = context.WithCancel(context.Background())
c.requestAccess.Unlock()
if c.onBecameUnusable != nil {
c.onBecameUnusable()
}
}
func (c *defaultCredential) wrapRequestContext(parent context.Context) *credentialRequestContext {
c.requestAccess.Lock()
credentialContext := c.requestContext
c.requestAccess.Unlock()
derived, cancel := context.WithCancel(parent)
stop := context.AfterFunc(credentialContext, func() {
cancel()
})
return &credentialRequestContext{
Context: derived,
releaseFuncs: []func() bool{stop},
cancelFunc: cancel,
}
}
func (c *defaultCredential) fiveHourUtilization() float64 {
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.fiveHourUtilization
}
func (c *defaultCredential) hasSnapshotData() bool {
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.hasSnapshotData()
}
func (c *defaultCredential) weeklyUtilization() float64 {
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.weeklyUtilization
}
func (c *defaultCredential) planWeight() float64 {
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return ocmPlanWeight(c.state.accountType)
}
func (c *defaultCredential) weeklyResetTime() time.Time {
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.weeklyReset
}
func (c *defaultCredential) fiveHourResetTime() time.Time {
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.fiveHourReset
}
func (c *defaultCredential) isAvailable() bool {
c.retryCredentialReloadIfNeeded()
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return !c.state.unavailable
}
func (c *defaultCredential) availabilityStatus() availabilityStatus {
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.currentAvailability()
}
func (c *defaultCredential) unavailableError() error {
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
if !c.state.unavailable {
return nil
}
if c.state.lastCredentialLoadError == "" {
return E.New("credential ", c.tag, " is unavailable")
}
return E.New("credential ", c.tag, " is unavailable: ", c.state.lastCredentialLoadError)
}
func (c *defaultCredential) lastUpdatedTime() time.Time {
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.lastUpdated
}
func (c *defaultCredential) markUsagePollAttempted() {
c.stateAccess.Lock()
defer c.stateAccess.Unlock()
c.state.lastUpdated = time.Now()
}
func (c *defaultCredential) incrementPollFailures() {
c.stateAccess.Lock()
c.state.consecutivePollFailures++
c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReasonPollFailed, time.Time{})
shouldInterrupt := c.checkTransitionLocked()
c.stateAccess.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
}
func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duration {
c.stateAccess.RLock()
failures := c.state.consecutivePollFailures
retryDelay := c.state.usageAPIRetryDelay
c.stateAccess.RUnlock()
if failures <= 0 {
if retryDelay > 0 {
return retryDelay
}
return baseInterval
}
backoff := failedPollRetryInterval * time.Duration(1<<(failures-1))
if backoff > httpRetryMaxBackoff {
return httpRetryMaxBackoff
}
return backoff
}
func (c *defaultCredential) isPollBackoffAtCap() bool {
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
failures := c.state.consecutivePollFailures
return failures > 0 && failedPollRetryInterval*time.Duration(1<<(failures-1)) >= httpRetryMaxBackoff
}
func (c *defaultCredential) earliestReset() time.Time {
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
if c.state.unavailable {
return time.Time{}
}
if c.state.hardRateLimited {
return c.state.rateLimitResetAt
}
earliest := c.state.fiveHourReset
if !c.state.weeklyReset.IsZero() && (earliest.IsZero() || c.state.weeklyReset.Before(earliest)) {
earliest = c.state.weeklyReset
}
return earliest
}
func (c *defaultCredential) fiveHourCap() float64 {
return c.cap5h
}
func (c *defaultCredential) weeklyCap() float64 {
return c.capWeekly
}
func (c *defaultCredential) usageTrackerOrNil() *AggregatedUsage {
return c.usageTracker
}
func (c *defaultCredential) httpClient() *http.Client {
return c.forwardHTTPClient
}
func (c *defaultCredential) ocmDialer() N.Dialer {
return c.dialer
}
func (c *defaultCredential) ocmIsAPIKeyMode() bool {
return c.isAPIKeyMode()
}
func (c *defaultCredential) ocmGetAccountID() string {
return c.getAccountID()
}
func (c *defaultCredential) ocmGetBaseURL() string {
return c.getBaseURL()
}
func (c *defaultCredential) pollUsage() {
if !c.pollAccess.TryLock() {
return
}
defer c.pollAccess.Unlock()
defer c.markUsagePollAttempted()
c.retryCredentialReloadIfNeeded()
if !c.isAvailable() {
return
}
if c.isAPIKeyMode() {
return
}
accessToken, err := c.getAccessToken()
if err != nil {
if !c.isPollBackoffAtCap() {
c.logger.Error("poll usage for ", c.tag, ": get token: ", err)
}
c.incrementPollFailures()
return
}
ctx := c.serviceContext
usageURL := strings.TrimSuffix(chatGPTBackendURL, "/codex") + "/wham/usage"
accountID := c.getAccountID()
pollClient := &http.Client{
Transport: c.forwardHTTPClient.Transport,
Timeout: 5 * time.Second,
}
response, err := doHTTPWithRetry(ctx, pollClient, func() (*http.Request, error) {
request, err := http.NewRequestWithContext(ctx, http.MethodGet, usageURL, nil)
if err != nil {
return nil, err
}
request.Header.Set("Authorization", "Bearer "+accessToken)
if accountID != "" {
request.Header.Set("ChatGPT-Account-Id", accountID)
}
return request, nil
})
if err != nil {
if !c.isPollBackoffAtCap() {
c.logger.Error("poll usage for ", c.tag, ": ", err)
}
c.incrementPollFailures()
return
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
if response.StatusCode == http.StatusTooManyRequests {
retryDelay := time.Minute
if retryAfter := response.Header.Get("Retry-After"); retryAfter != "" {
seconds, err := strconv.ParseInt(retryAfter, 10, 64)
if err == nil && seconds > 0 {
retryDelay = time.Duration(seconds) * time.Second
}
}
c.logger.Warn("poll usage for ", c.tag, ": usage API rate limited, retry in ", log.FormatDuration(retryDelay))
c.stateAccess.Lock()
c.state.usageAPIRetryDelay = retryDelay
c.stateAccess.Unlock()
return
}
body, _ := io.ReadAll(response.Body)
c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
c.incrementPollFailures()
return
}
var usageResponse usageRateLimitStatusPayload
err = json.NewDecoder(response.Body).Decode(&usageResponse)
if err != nil {
c.logger.Debug("poll usage for ", c.tag, ": decode: ", err)
c.incrementPollFailures()
return
}
c.stateAccess.Lock()
isFirstUpdate := c.state.lastUpdated.IsZero()
oldFiveHour := c.state.fiveHourUtilization
oldWeekly := c.state.weeklyUtilization
c.state.consecutivePollFailures = 0
c.state.usageAPIRetryDelay = 0
applyRateLimitSnapshotsLocked(&c.state, snapshotsFromUsagePayload(usageResponse), c.state.activeLimitID, c.state.remotePlanWeight, usageResponse.PlanType)
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
c.state.hardRateLimited = false
}
c.state.noteSnapshotData()
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
resetSuffix := ""
if !c.state.weeklyReset.IsZero() {
resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset))
}
c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
}
shouldInterrupt := c.checkTransitionLocked()
c.stateAccess.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
c.emitStatusUpdate()
}
func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) {
accessToken, err := c.getAccessToken()
if err != nil {
return nil, E.Cause(err, "get access token for ", c.tag)
}
path := original.URL.Path
var proxyPath string
if c.isAPIKeyMode() {
proxyPath = path
} else {
proxyPath = strings.TrimPrefix(path, "/v1")
}
proxyURL := c.getBaseURL() + proxyPath
if original.URL.RawQuery != "" {
proxyURL += "?" + original.URL.RawQuery
}
var body io.Reader
if bodyBytes != nil {
body = bytes.NewReader(bodyBytes)
} else {
body = original.Body
}
proxyRequest, err := http.NewRequestWithContext(ctx, original.Method, proxyURL, body)
if err != nil {
return nil, err
}
for key, values := range original.Header {
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && !isAPIKeyHeader(key) && key != "Authorization" {
proxyRequest.Header[key] = values
}
}
for key, values := range serviceHeaders {
proxyRequest.Header.Del(key)
proxyRequest.Header[key] = values
}
proxyRequest.Header.Set("Authorization", "Bearer "+accessToken)
if accountID := c.getAccountID(); accountID != "" {
proxyRequest.Header.Set("ChatGPT-Account-Id", accountID)
}
return proxyRequest, nil
}
func (c *defaultCredential) close() {
if c.watcher != nil {
err := c.watcher.Close()
if err != nil {
c.logger.Error("close credential watcher for ", c.tag, ": ", err)
}
}
if c.usageTracker != nil {
c.usageTracker.cancelPendingSave()
err := c.usageTracker.Save()
if err != nil {
c.logger.Error("save usage statistics for ", c.tag, ": ", err)
}
}
}

View File

@@ -5,6 +5,7 @@ import (
"context"
stdTLS "crypto/tls"
"encoding/json"
"errors"
"io"
"net"
"net/http"
@@ -23,6 +24,7 @@ import (
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp"
"github.com/sagernet/sing/common/observable"
"github.com/hashicorp/yamux"
)
@@ -30,18 +32,18 @@ import (
const reverseProxyBaseURL = "http://reverse-proxy"
type externalCredential struct {
tag string
baseURL string
token string
credDialer N.Dialer
httpClient *http.Client
state credentialState
stateMutex sync.RWMutex
pollAccess sync.Mutex
pollInterval time.Duration
usageTracker *AggregatedUsage
logger log.ContextLogger
tag string
baseURL string
token string
credentialDialer N.Dialer
forwardHTTPClient *http.Client
state credentialState
stateAccess sync.RWMutex
pollAccess sync.Mutex
usageTracker *AggregatedUsage
logger log.ContextLogger
statusSubscriber *observable.Subscriber[struct{}]
onBecameUnusable func()
interrupted bool
requestContext context.Context
@@ -49,24 +51,31 @@ type externalCredential struct {
requestAccess sync.Mutex
// Reverse proxy fields
reverse bool
reverseSession *yamux.Session
reverseAccess sync.RWMutex
closed bool
reverseContext context.Context
reverseCancel context.CancelFunc
connectorDialer N.Dialer
connectorDestination M.Socksaddr
connectorRequestPath string
connectorURL *url.URL
connectorTLS *stdTLS.Config
reverseService http.Handler
reverse bool
reverseHTTPClient *http.Client
reverseCredentialDialer N.Dialer
reverseSession *yamux.Session
reverseAccess sync.RWMutex
closed bool
reverseContext context.Context
reverseCancel context.CancelFunc
connectorDialer N.Dialer
connectorDestination M.Socksaddr
connectorRequestPath string
connectorURL *url.URL
connectorTLS *stdTLS.Config
reverseService http.Handler
}
type reverseSessionDialer struct {
credential *externalCredential
}
type statusStreamResult struct {
duration time.Duration
frames int
}
func (d reverseSessionDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
if N.NetworkName(network) != N.NetworkTCP {
return nil, os.ErrInvalid
@@ -79,9 +88,9 @@ func (d reverseSessionDialer) ListenPacket(ctx context.Context, destination M.So
}
func externalCredentialURLPort(parsedURL *url.URL) uint16 {
portStr := parsedURL.Port()
if portStr != "" {
port, err := strconv.ParseUint(portStr, 10, 16)
portString := parsedURL.Port()
if portString != "" {
port, err := strconv.ParseUint(portString, 10, 16)
if err == nil {
return uint16(port)
}
@@ -121,18 +130,12 @@ func externalCredentialReversePath(parsedURL *url.URL, endpointPath string) stri
}
func newExternalCredential(ctx context.Context, tag string, options option.OCMExternalCredentialOptions, logger log.ContextLogger) (*externalCredential, error) {
pollInterval := time.Duration(options.PollInterval)
if pollInterval <= 0 {
pollInterval = 30 * time.Minute
}
requestContext, cancelRequests := context.WithCancel(context.Background())
reverseContext, reverseCancel := context.WithCancel(context.Background())
cred := &externalCredential{
credential := &externalCredential{
tag: tag,
token: options.Token,
pollInterval: pollInterval,
logger: logger,
requestContext: requestContext,
cancelRequests: cancelRequests,
@@ -143,13 +146,13 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx
if options.URL == "" {
// Receiver mode: no URL, wait for reverse connection
cred.baseURL = reverseProxyBaseURL
cred.credDialer = reverseSessionDialer{credential: cred}
cred.httpClient = &http.Client{
credential.baseURL = reverseProxyBaseURL
credential.credentialDialer = reverseSessionDialer{credential: credential}
credential.forwardHTTPClient = &http.Client{
Transport: &http.Transport{
ForceAttemptHTTP2: false,
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return cred.openReverseConnection(ctx)
return credential.openReverseConnection(ctx)
},
},
}
@@ -190,34 +193,45 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx
}
}
cred.baseURL = externalCredentialBaseURL(parsedURL)
credential.baseURL = externalCredentialBaseURL(parsedURL)
if options.Reverse {
// Connector mode: we dial out to serve, not to proxy
cred.connectorDialer = credentialDialer
credential.credentialDialer = credentialDialer
credential.connectorDialer = credentialDialer
if options.Server != "" {
cred.connectorDestination = M.ParseSocksaddrHostPort(options.Server, externalCredentialServerPort(parsedURL, options.ServerPort))
credential.connectorDestination = M.ParseSocksaddrHostPort(options.Server, externalCredentialServerPort(parsedURL, options.ServerPort))
} else {
cred.connectorDestination = M.ParseSocksaddrHostPort(parsedURL.Hostname(), externalCredentialURLPort(parsedURL))
credential.connectorDestination = M.ParseSocksaddrHostPort(parsedURL.Hostname(), externalCredentialURLPort(parsedURL))
}
cred.connectorRequestPath = externalCredentialReversePath(parsedURL, "/ocm/v1/reverse")
cred.connectorURL = parsedURL
credential.connectorRequestPath = externalCredentialReversePath(parsedURL, "/ocm/v1/reverse")
credential.connectorURL = parsedURL
if parsedURL.Scheme == "https" {
cred.connectorTLS = &stdTLS.Config{
credential.connectorTLS = &stdTLS.Config{
ServerName: parsedURL.Hostname(),
RootCAs: adapter.RootPoolFromContext(ctx),
Time: ntp.TimeFuncFromContext(ctx),
}
}
credential.forwardHTTPClient = &http.Client{Transport: transport}
} else {
// Normal mode: standard HTTP client for proxying
cred.credDialer = credentialDialer
cred.httpClient = &http.Client{Transport: transport}
credential.credentialDialer = credentialDialer
credential.forwardHTTPClient = &http.Client{Transport: transport}
credential.reverseCredentialDialer = reverseSessionDialer{credential: credential}
credential.reverseHTTPClient = &http.Client{
Transport: &http.Transport{
ForceAttemptHTTP2: false,
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return credential.openReverseConnection(ctx)
},
},
}
}
}
if options.UsagesPath != "" {
cred.usageTracker = &AggregatedUsage{
credential.usageTracker = &AggregatedUsage{
LastUpdated: time.Now(),
Combinations: make([]CostCombination, 0),
filePath: options.UsagesPath,
@@ -225,7 +239,7 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx
}
}
return cred, nil
return credential, nil
}
func (c *externalCredential) start() error {
@@ -237,6 +251,8 @@ func (c *externalCredential) start() error {
}
if c.reverse && c.connectorURL != nil {
go c.connectorLoop()
} else {
go c.statusStreamLoop()
}
return nil
}
@@ -245,6 +261,16 @@ func (c *externalCredential) setOnBecameUnusable(fn func()) {
c.onBecameUnusable = fn
}
func (c *externalCredential) setStatusSubscriber(subscriber *observable.Subscriber[struct{}]) {
c.statusSubscriber = subscriber
}
func (c *externalCredential) emitStatusUpdate() {
if c.statusSubscriber != nil {
c.statusSubscriber.Emit(struct{}{})
}
}
func (c *externalCredential) tagName() string {
return c.tag
}
@@ -261,39 +287,43 @@ func (c *externalCredential) isUsable() bool {
if !c.isAvailable() {
return false
}
c.stateMutex.RLock()
c.stateAccess.RLock()
if c.state.consecutivePollFailures > 0 {
c.stateMutex.RUnlock()
c.stateAccess.RUnlock()
return false
}
if !c.state.upstreamRejectedUntil.IsZero() && time.Now().Before(c.state.upstreamRejectedUntil) {
c.stateAccess.RUnlock()
return false
}
if c.state.hardRateLimited {
if time.Now().Before(c.state.rateLimitResetAt) {
c.stateMutex.RUnlock()
c.stateAccess.RUnlock()
return false
}
c.stateMutex.RUnlock()
c.stateMutex.Lock()
c.stateAccess.RUnlock()
c.stateAccess.Lock()
if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) {
c.state.hardRateLimited = false
}
usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100
c.stateMutex.Unlock()
c.stateAccess.Unlock()
return usable
}
usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100
c.stateMutex.RUnlock()
c.stateAccess.RUnlock()
return usable
}
func (c *externalCredential) fiveHourUtilization() float64 {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.fiveHourUtilization
}
func (c *externalCredential) weeklyUtilization() float64 {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.weeklyUtilization
}
@@ -306,8 +336,8 @@ func (c *externalCredential) weeklyCap() float64 {
}
func (c *externalCredential) planWeight() float64 {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
if c.state.remotePlanWeight > 0 {
return c.state.remotePlanWeight
}
@@ -315,26 +345,58 @@ func (c *externalCredential) planWeight() float64 {
}
func (c *externalCredential) weeklyResetTime() time.Time {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.weeklyReset
}
func (c *externalCredential) fiveHourResetTime() time.Time {
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.fiveHourReset
}
func (c *externalCredential) markRateLimited(resetAt time.Time) {
c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt)))
c.stateMutex.Lock()
c.stateAccess.Lock()
c.state.hardRateLimited = true
c.state.rateLimitResetAt = resetAt
c.state.setAvailability(availabilityStateRateLimited, availabilityReasonHardRateLimit, resetAt)
shouldInterrupt := c.checkTransitionLocked()
c.stateMutex.Unlock()
c.stateAccess.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
c.emitStatusUpdate()
}
func (c *externalCredential) markUpstreamRejected() {
c.logger.Warn("upstream rejected credential ", c.tag, ", marking unavailable for ", log.FormatDuration(defaultPollInterval))
c.stateAccess.Lock()
c.state.upstreamRejectedUntil = time.Now().Add(defaultPollInterval)
c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReasonUpstreamRejected, c.state.upstreamRejectedUntil)
shouldInterrupt := c.checkTransitionLocked()
c.stateAccess.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
c.emitStatusUpdate()
}
func (c *externalCredential) markTemporarilyBlocked(reason availabilityReason, resetAt time.Time) {
c.stateAccess.Lock()
c.state.setAvailability(availabilityStateTemporarilyBlocked, reason, resetAt)
shouldInterrupt := c.checkTransitionLocked()
c.stateAccess.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
c.emitStatusUpdate()
}
func (c *externalCredential) earliestReset() time.Time {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
if c.state.hardRateLimited {
return c.state.rateLimitResetAt
}
@@ -346,9 +408,6 @@ func (c *externalCredential) earliestReset() time.Time {
}
func (c *externalCredential) unavailableError() error {
if c.reverse && c.connectorURL != nil {
return E.New("credential ", c.tag, " is unavailable: reverse connector credentials cannot serve local requests")
}
if c.baseURL == reverseProxyBaseURL {
session := c.getReverseSession()
if session == nil || session.IsClosed() {
@@ -363,7 +422,14 @@ func (c *externalCredential) getAccessToken() (string, error) {
}
func (c *externalCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, _ http.Header) (*http.Request, error) {
proxyURL := c.baseURL + original.URL.RequestURI()
baseURL := c.baseURL
if c.reverseHTTPClient != nil {
session := c.getReverseSession()
if session != nil && !session.IsClosed() {
baseURL = reverseProxyBaseURL
}
}
proxyURL := baseURL + original.URL.RequestURI()
var body io.Reader
if bodyBytes != nil {
body = bytes.NewReader(bodyBytes)
@@ -376,7 +442,7 @@ func (c *externalCredential) buildProxyRequest(ctx context.Context, original *ht
}
for key, values := range original.Header {
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && key != "Authorization" {
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && !isAPIKeyHeader(key) && key != "Authorization" {
proxyRequest.Header[key] = values
}
}
@@ -413,10 +479,13 @@ func (c *externalCredential) openReverseConnection(ctx context.Context) (net.Con
}
func (c *externalCredential) updateStateFromHeaders(headers http.Header) {
c.stateMutex.Lock()
c.stateAccess.Lock()
isFirstUpdate := c.state.lastUpdated.IsZero()
oldFiveHour := c.state.fiveHourUtilization
oldWeekly := c.state.weeklyUtilization
oldPlanWeight := c.state.remotePlanWeight
oldFiveHourReset := c.state.fiveHourReset
oldWeeklyReset := c.state.weeklyReset
hadData := false
activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
@@ -463,9 +532,15 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) {
c.state.remotePlanWeight = value
}
}
if snapshots := parseRateLimitSnapshotsFromHeaders(headers); len(snapshots) > 0 {
hadData = true
applyRateLimitSnapshotsLocked(&c.state, snapshots, headers.Get("x-codex-active-limit"), c.state.remotePlanWeight, c.state.accountType)
}
if hadData {
c.state.consecutivePollFailures = 0
c.state.upstreamRejectedUntil = time.Time{}
c.state.lastUpdated = time.Now()
c.state.noteSnapshotData()
}
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
resetSuffix := ""
@@ -474,15 +549,23 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) {
}
c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
}
utilizationChanged := c.state.fiveHourUtilization != oldFiveHour || c.state.weeklyUtilization != oldWeekly
planWeightChanged := c.state.remotePlanWeight != oldPlanWeight
resetChanged := c.state.fiveHourReset != oldFiveHourReset || c.state.weeklyReset != oldWeeklyReset
shouldEmit := (hadData && (utilizationChanged || resetChanged)) || planWeightChanged
shouldInterrupt := c.checkTransitionLocked()
c.stateMutex.Unlock()
c.stateAccess.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
if shouldEmit {
c.emitStatusUpdate()
}
}
func (c *externalCredential) checkTransitionLocked() bool {
unusable := c.state.hardRateLimited || c.state.fiveHourUtilization >= 100 || c.state.weeklyUtilization >= 100 || c.state.consecutivePollFailures > 0
upstreamRejected := !c.state.upstreamRejectedUntil.IsZero() && time.Now().Before(c.state.upstreamRejectedUntil)
unusable := c.state.hardRateLimited || c.state.fiveHourUtilization >= 100 || c.state.weeklyUtilization >= 100 || c.state.consecutivePollFailures > 0 || upstreamRejected
if unusable && !c.interrupted {
c.interrupted = true
return true
@@ -502,9 +585,9 @@ func (c *externalCredential) wrapRequestContext(parent context.Context) *credent
cancel()
})
return &credentialRequestContext{
Context: derived,
releaseFunc: stop,
cancelFunc: cancel,
Context: derived,
releaseFuncs: []func() bool{stop},
cancelFunc: cancel,
}
}
@@ -519,29 +602,58 @@ func (c *externalCredential) interruptConnections() {
}
}
func (c *externalCredential) pollUsage(ctx context.Context) {
func (c *externalCredential) doPollUsageRequest(ctx context.Context) (*http.Response, error) {
buildRequest := func(baseURL string) func() (*http.Request, error) {
return func() (*http.Request, error) {
request, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/ocm/v1/status", nil)
if err != nil {
return nil, err
}
request.Header.Set("Authorization", "Bearer "+c.token)
return request, nil
}
}
// Try reverse transport first (single attempt, no retry)
if c.reverseHTTPClient != nil {
session := c.getReverseSession()
if session != nil && !session.IsClosed() {
request, err := buildRequest(reverseProxyBaseURL)()
if err != nil {
return nil, err
}
reverseClient := &http.Client{
Transport: c.reverseHTTPClient.Transport,
Timeout: 5 * time.Second,
}
response, err := reverseClient.Do(request)
if err == nil {
return response, nil
}
// Reverse failed, fall through to forward if available
}
}
// Forward transport with retries
if c.forwardHTTPClient != nil {
forwardClient := &http.Client{
Transport: c.forwardHTTPClient.Transport,
Timeout: 5 * time.Second,
}
return doHTTPWithRetry(ctx, forwardClient, buildRequest(c.baseURL))
}
return nil, E.New("no transport available")
}
func (c *externalCredential) pollUsage() {
if !c.pollAccess.TryLock() {
return
}
defer c.pollAccess.Unlock()
defer c.markUsagePollAttempted()
statusURL := c.baseURL + "/ocm/v1/status"
httpClient := &http.Client{
Transport: c.httpClient.Transport,
Timeout: 5 * time.Second,
}
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
request, err := http.NewRequestWithContext(ctx, http.MethodGet, statusURL, nil)
if err != nil {
return nil, err
}
request.Header.Set("Authorization", "Bearer "+c.token)
return request, nil
})
ctx := c.getReverseContext()
response, err := c.doPollUsageRequest(ctx)
if err != nil {
c.logger.Error("poll usage for ", c.tag, ": ", err)
c.logger.Debug("poll usage for ", c.tag, ": ", err)
c.incrementPollFailures()
return
}
@@ -550,38 +662,52 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
if response.StatusCode != http.StatusOK {
body, _ := io.ReadAll(response.Body)
c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
// 404 means the remote does not have a status endpoint yet;
// usage will be updated passively from response headers.
if response.StatusCode == http.StatusNotFound {
c.stateMutex.Lock()
c.state.consecutivePollFailures = 0
c.checkTransitionLocked()
c.stateMutex.Unlock()
} else {
c.incrementPollFailures()
}
c.incrementPollFailures()
return
}
var statusResponse struct {
FiveHourUtilization float64 `json:"five_hour_utilization"`
WeeklyUtilization float64 `json:"weekly_utilization"`
PlanWeight float64 `json:"plan_weight"`
body, err := io.ReadAll(response.Body)
if err != nil {
c.logger.Debug("poll usage for ", c.tag, ": read body: ", err)
c.incrementPollFailures()
return
}
err = json.NewDecoder(response.Body).Decode(&statusResponse)
var rawFields map[string]json.RawMessage
err = json.Unmarshal(body, &rawFields)
if err != nil {
c.logger.Debug("poll usage for ", c.tag, ": decode: ", err)
c.incrementPollFailures()
return
}
if rawFields["limits"] == nil && (rawFields["five_hour_utilization"] == nil || rawFields["five_hour_reset"] == nil ||
rawFields["weekly_utilization"] == nil || rawFields["weekly_reset"] == nil ||
rawFields["plan_weight"] == nil) {
c.logger.Error("poll usage for ", c.tag, ": invalid response")
c.incrementPollFailures()
return
}
var statusResponse statusPayload
err = json.Unmarshal(body, &statusResponse)
if err != nil {
c.logger.Debug("poll usage for ", c.tag, ": decode: ", err)
c.incrementPollFailures()
return
}
c.stateMutex.Lock()
c.stateAccess.Lock()
isFirstUpdate := c.state.lastUpdated.IsZero()
oldFiveHour := c.state.fiveHourUtilization
oldWeekly := c.state.weeklyUtilization
c.state.consecutivePollFailures = 0
c.state.upstreamRejectedUntil = time.Time{}
c.state.fiveHourUtilization = statusResponse.FiveHourUtilization
c.state.weeklyUtilization = statusResponse.WeeklyUtilization
if statusResponse.FiveHourReset > 0 {
c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0)
}
if statusResponse.WeeklyReset > 0 {
c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0)
}
if statusResponse.PlanWeight > 0 {
c.state.remotePlanWeight = statusResponse.PlanWeight
}
@@ -596,54 +722,238 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
}
shouldInterrupt := c.checkTransitionLocked()
c.stateMutex.Unlock()
c.stateAccess.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
c.emitStatusUpdate()
}
func (c *externalCredential) statusStreamLoop() {
var consecutiveFailures int
ctx := c.getReverseContext()
for {
select {
case <-ctx.Done():
return
default:
}
result, err := c.connectStatusStream(ctx)
if ctx.Err() != nil {
return
}
if !shouldRetryStatusStreamError(err) {
c.logger.Warn("status stream for ", c.tag, " disconnected: ", err, ", not retrying")
return
}
var backoff time.Duration
consecutiveFailures, backoff = c.nextStatusStreamBackoff(result, consecutiveFailures)
c.logger.Debug("status stream for ", c.tag, " disconnected: ", err, ", reconnecting in ", backoff)
timer := time.NewTimer(backoff)
select {
case <-timer.C:
case <-ctx.Done():
timer.Stop()
return
}
}
}
func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStreamResult, error) {
startTime := time.Now()
result := statusStreamResult{}
response, err := c.doStreamStatusRequest(ctx)
if err != nil {
result.duration = time.Since(startTime)
return result, err
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
body, _ := io.ReadAll(response.Body)
result.duration = time.Since(startTime)
return result, E.New("status ", response.StatusCode, " ", string(body))
}
decoder := json.NewDecoder(response.Body)
for {
var rawMessage json.RawMessage
err = decoder.Decode(&rawMessage)
if err != nil {
result.duration = time.Since(startTime)
return result, err
}
var rawFields map[string]json.RawMessage
err = json.Unmarshal(rawMessage, &rawFields)
if err != nil {
result.duration = time.Since(startTime)
return result, E.Cause(err, "decode status frame")
}
if rawFields["limits"] == nil && (rawFields["five_hour_utilization"] == nil || rawFields["five_hour_reset"] == nil ||
rawFields["weekly_utilization"] == nil || rawFields["weekly_reset"] == nil ||
rawFields["plan_weight"] == nil) {
result.duration = time.Since(startTime)
return result, E.New("invalid response")
}
var statusResponse statusPayload
err = json.Unmarshal(rawMessage, &statusResponse)
if err != nil {
result.duration = time.Since(startTime)
return result, E.Cause(err, "decode status frame")
}
c.stateAccess.Lock()
isFirstUpdate := c.state.lastUpdated.IsZero()
oldFiveHour := c.state.fiveHourUtilization
oldWeekly := c.state.weeklyUtilization
c.state.consecutivePollFailures = 0
c.state.upstreamRejectedUntil = time.Time{}
c.state.fiveHourUtilization = statusResponse.FiveHourUtilization
c.state.weeklyUtilization = statusResponse.WeeklyUtilization
if statusResponse.FiveHourReset > 0 {
c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0)
}
if statusResponse.WeeklyReset > 0 {
c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0)
}
if statusResponse.PlanWeight > 0 {
c.state.remotePlanWeight = statusResponse.PlanWeight
}
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
c.state.hardRateLimited = false
}
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
resetSuffix := ""
if !c.state.weeklyReset.IsZero() {
resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset))
}
c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
}
shouldInterrupt := c.checkTransitionLocked()
c.stateAccess.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
result.frames++
c.markUsageStreamUpdated()
c.emitStatusUpdate()
}
}
func shouldRetryStatusStreamError(err error) bool {
return errors.Is(err, io.ErrUnexpectedEOF) || E.IsClosedOrCanceled(err)
}
func (c *externalCredential) nextStatusStreamBackoff(result statusStreamResult, consecutiveFailures int) (int, time.Duration) {
if result.duration >= connectorBackoffResetThreshold {
consecutiveFailures = 0
}
consecutiveFailures++
return consecutiveFailures, connectorBackoff(consecutiveFailures)
}
func (c *externalCredential) doStreamStatusRequest(ctx context.Context) (*http.Response, error) {
buildRequest := func(baseURL string) (*http.Request, error) {
request, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/ocm/v1/status?watch=true", nil)
if err != nil {
return nil, err
}
request.Header.Set("Authorization", "Bearer "+c.token)
return request, nil
}
if c.reverseHTTPClient != nil {
session := c.getReverseSession()
if session != nil && !session.IsClosed() {
request, err := buildRequest(reverseProxyBaseURL)
if err != nil {
return nil, err
}
response, err := c.reverseHTTPClient.Do(request)
if err == nil {
return response, nil
}
}
}
if c.forwardHTTPClient != nil {
request, err := buildRequest(c.baseURL)
if err != nil {
return nil, err
}
return c.forwardHTTPClient.Do(request)
}
return nil, E.New("no transport available")
}
func (c *externalCredential) lastUpdatedTime() time.Time {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.lastUpdated
}
func (c *externalCredential) hasSnapshotData() bool {
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.hasSnapshotData()
}
func (c *externalCredential) availabilityStatus() availabilityStatus {
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.currentAvailability()
}
func (c *externalCredential) markUsageStreamUpdated() {
c.stateAccess.Lock()
defer c.stateAccess.Unlock()
c.state.lastUpdated = time.Now()
}
func (c *externalCredential) markUsagePollAttempted() {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()
c.stateAccess.Lock()
defer c.stateAccess.Unlock()
c.state.lastUpdated = time.Now()
}
func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Duration {
c.stateMutex.RLock()
failures := c.state.consecutivePollFailures
c.stateMutex.RUnlock()
if failures <= 0 {
return baseInterval
}
return failedPollRetryInterval
return baseInterval
}
func (c *externalCredential) incrementPollFailures() {
c.stateMutex.Lock()
c.stateAccess.Lock()
c.state.consecutivePollFailures++
c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReasonPollFailed, time.Time{})
shouldInterrupt := c.checkTransitionLocked()
c.stateMutex.Unlock()
c.stateAccess.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
c.emitStatusUpdate()
}
func (c *externalCredential) usageTrackerOrNil() *AggregatedUsage {
return c.usageTracker
}
func (c *externalCredential) httpTransport() *http.Client {
return c.httpClient
func (c *externalCredential) httpClient() *http.Client {
if c.reverseHTTPClient != nil {
session := c.getReverseSession()
if session != nil && !session.IsClosed() {
return c.reverseHTTPClient
}
}
return c.forwardHTTPClient
}
func (c *externalCredential) ocmDialer() N.Dialer {
return c.credDialer
if c.reverseCredentialDialer != nil {
session := c.getReverseSession()
if session != nil && !session.IsClosed() {
return c.reverseCredentialDialer
}
}
return c.credentialDialer
}
func (c *externalCredential) ocmIsAPIKeyMode() bool {
@@ -655,6 +965,12 @@ func (c *externalCredential) ocmGetAccountID() string {
}
func (c *externalCredential) ocmGetBaseURL() string {
if c.reverseHTTPClient != nil {
session := c.getReverseSession()
if session != nil && !session.IsClosed() {
return reverseProxyBaseURL
}
}
return c.baseURL
}
@@ -689,26 +1005,55 @@ func (c *externalCredential) getReverseSession() *yamux.Session {
}
func (c *externalCredential) setReverseSession(session *yamux.Session) bool {
var emitStatus bool
var restartStatusStream bool
var triggerUsageRefresh bool
c.reverseAccess.Lock()
if c.closed {
c.reverseAccess.Unlock()
return false
}
wasAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed()
old := c.reverseSession
c.reverseSession = session
isAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed()
emitStatus = wasAvailable != isAvailable
if isAvailable && !wasAvailable {
c.reverseCancel()
c.reverseContext, c.reverseCancel = context.WithCancel(context.Background())
restartStatusStream = true
triggerUsageRefresh = true
}
c.reverseAccess.Unlock()
if old != nil {
old.Close()
}
if restartStatusStream {
c.logger.Debug("poll usage for ", c.tag, ": reverse session ready, restarting status stream")
go c.statusStreamLoop()
}
if triggerUsageRefresh {
go c.pollUsage()
}
if emitStatus {
c.emitStatusUpdate()
}
return true
}
func (c *externalCredential) clearReverseSession(session *yamux.Session) {
var emitStatus bool
c.reverseAccess.Lock()
wasAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed()
if c.reverseSession == session {
c.reverseSession = nil
}
isAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed()
emitStatus = wasAvailable != isAvailable
c.reverseAccess.Unlock()
if emitStatus {
c.emitStatusUpdate()
}
}
func (c *externalCredential) getReverseContext() context.Context {

View File

@@ -62,10 +62,10 @@ func (c *defaultCredential) ensureCredentialWatcher() error {
}
func (c *defaultCredential) retryCredentialReloadIfNeeded() {
c.stateMutex.RLock()
c.stateAccess.RLock()
unavailable := c.state.unavailable
lastAttempt := c.state.lastCredentialLoadAttempt
c.stateMutex.RUnlock()
c.stateAccess.RUnlock()
if !unavailable {
return
}
@@ -84,10 +84,10 @@ func (c *defaultCredential) reloadCredentials(force bool) error {
c.reloadAccess.Lock()
defer c.reloadAccess.Unlock()
c.stateMutex.RLock()
c.stateAccess.RLock()
unavailable := c.state.unavailable
lastAttempt := c.state.lastCredentialLoadAttempt
c.stateMutex.RUnlock()
c.stateAccess.RUnlock()
if !force {
if !unavailable {
return nil
@@ -97,43 +97,56 @@ func (c *defaultCredential) reloadCredentials(force bool) error {
}
}
c.stateMutex.Lock()
c.stateAccess.Lock()
c.state.lastCredentialLoadAttempt = time.Now()
c.stateMutex.Unlock()
c.stateAccess.Unlock()
credentials, err := platformReadCredentials(c.credentialPath)
if err != nil {
return c.markCredentialsUnavailable(E.Cause(err, "read credentials"))
}
c.accessMutex.Lock()
c.access.Lock()
c.credentials = credentials
c.accessMutex.Unlock()
c.refreshRetryAt = time.Time{}
c.refreshRetryError = nil
c.refreshBlocked = false
c.access.Unlock()
c.stateMutex.Lock()
c.stateAccess.Lock()
wasAvailable := !c.state.unavailable
c.state.unavailable = false
c.state.lastCredentialLoadError = ""
c.checkTransitionLocked()
c.stateMutex.Unlock()
shouldEmit := wasAvailable != !c.state.unavailable
c.stateAccess.Unlock()
if shouldEmit {
c.emitStatusUpdate()
}
return nil
}
func (c *defaultCredential) markCredentialsUnavailable(err error) error {
c.accessMutex.Lock()
c.access.Lock()
hadCredentials := c.credentials != nil
c.credentials = nil
c.accessMutex.Unlock()
c.access.Unlock()
c.stateMutex.Lock()
c.stateAccess.Lock()
wasAvailable := !c.state.unavailable
c.state.unavailable = true
c.state.lastCredentialLoadError = err.Error()
shouldInterrupt := c.checkTransitionLocked()
c.stateMutex.Unlock()
shouldEmit := wasAvailable != !c.state.unavailable
c.stateAccess.Unlock()
if shouldInterrupt && hadCredentials {
c.interruptConnections()
}
if shouldEmit {
c.emitStatusUpdate()
}
return err
}

View File

@@ -0,0 +1,233 @@
package ocm
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"os"
"os/user"
"path/filepath"
"strconv"
"time"
E "github.com/sagernet/sing/common/exceptions"
)
const (
oauth2ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
oauth2TokenURL = "https://auth.openai.com/oauth/token"
openaiAPIBaseURL = "https://api.openai.com"
chatGPTBackendURL = "https://chatgpt.com/backend-api/codex"
tokenRefreshIntervalDays = 8
)
func getRealUser() (*user.User, error) {
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
sudoUserInfo, err := user.Lookup(sudoUser)
if err == nil {
return sudoUserInfo, nil
}
}
return user.Current()
}
func getDefaultCredentialsPath() (string, error) {
if codexHome := os.Getenv("CODEX_HOME"); codexHome != "" {
return filepath.Join(codexHome, "auth.json"), nil
}
userInfo, err := getRealUser()
if err != nil {
return "", err
}
return filepath.Join(userInfo.HomeDir, ".codex", "auth.json"), nil
}
func readCredentialsFromFile(path string) (*oauthCredentials, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var credentials oauthCredentials
err = json.Unmarshal(data, &credentials)
if err != nil {
return nil, err
}
return &credentials, nil
}
func checkCredentialFileWritable(path string) error {
file, err := os.OpenFile(path, os.O_WRONLY, 0)
if err != nil {
return err
}
return file.Close()
}
func writeCredentialsToFile(credentials *oauthCredentials, path string) error {
data, err := json.MarshalIndent(credentials, "", " ")
if err != nil {
return err
}
return os.WriteFile(path, data, 0o600)
}
type oauthCredentials struct {
APIKey string `json:"OPENAI_API_KEY,omitempty"`
Tokens *tokenData `json:"tokens,omitempty"`
LastRefresh *time.Time `json:"last_refresh,omitempty"`
}
type tokenData struct {
IDToken string `json:"id_token,omitempty"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
AccountID string `json:"account_id,omitempty"`
}
func (c *oauthCredentials) isAPIKeyMode() bool {
return c.APIKey != ""
}
func (c *oauthCredentials) getAccessToken() string {
if c.APIKey != "" {
return c.APIKey
}
if c.Tokens != nil {
return c.Tokens.AccessToken
}
return ""
}
func (c *oauthCredentials) getAccountID() string {
if c.Tokens != nil {
return c.Tokens.AccountID
}
return ""
}
func (c *oauthCredentials) needsRefresh() bool {
if c.APIKey != "" {
return false
}
if c.Tokens == nil || c.Tokens.RefreshToken == "" {
return false
}
if c.LastRefresh == nil {
return true
}
return time.Since(*c.LastRefresh) >= time.Duration(tokenRefreshIntervalDays)*24*time.Hour
}
func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, time.Duration, error) {
if credentials.Tokens == nil || credentials.Tokens.RefreshToken == "" {
return nil, 0, E.New("refresh token is empty")
}
requestBody, err := json.Marshal(map[string]string{
"grant_type": "refresh_token",
"refresh_token": credentials.Tokens.RefreshToken,
"client_id": oauth2ClientID,
"scope": "openid profile email",
})
if err != nil {
return nil, 0, E.Cause(err, "marshal request")
}
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
if err != nil {
return nil, err
}
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Accept", "application/json")
return request, nil
})
if err != nil {
return nil, 0, err
}
defer response.Body.Close()
if response.StatusCode == http.StatusTooManyRequests {
body, _ := io.ReadAll(response.Body)
retryDelay := time.Duration(-1)
if retryAfter := response.Header.Get("Retry-After"); retryAfter != "" {
seconds, parseErr := strconv.ParseInt(retryAfter, 10, 64)
if parseErr == nil && seconds > 0 {
retryDelay = time.Duration(seconds) * time.Second
}
}
return nil, retryDelay, E.New("refresh rate limited: ", response.Status, " ", string(body))
}
if response.StatusCode != http.StatusOK {
body, _ := io.ReadAll(response.Body)
return nil, 0, E.New("refresh failed: ", response.Status, " ", string(body))
}
var tokenResponse struct {
IDToken string `json:"id_token"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
}
err = json.NewDecoder(response.Body).Decode(&tokenResponse)
if err != nil {
return nil, 0, E.Cause(err, "decode response")
}
newCredentials := *credentials
if newCredentials.Tokens == nil {
newCredentials.Tokens = &tokenData{}
}
if tokenResponse.IDToken != "" {
newCredentials.Tokens.IDToken = tokenResponse.IDToken
}
if tokenResponse.AccessToken != "" {
newCredentials.Tokens.AccessToken = tokenResponse.AccessToken
}
if tokenResponse.RefreshToken != "" {
newCredentials.Tokens.RefreshToken = tokenResponse.RefreshToken
}
now := time.Now()
newCredentials.LastRefresh = &now
return &newCredentials, 0, nil
}
func cloneCredentials(credentials *oauthCredentials) *oauthCredentials {
if credentials == nil {
return nil
}
cloned := *credentials
if credentials.Tokens != nil {
clonedTokens := *credentials.Tokens
cloned.Tokens = &clonedTokens
}
if credentials.LastRefresh != nil {
lastRefresh := *credentials.LastRefresh
cloned.LastRefresh = &lastRefresh
}
return &cloned
}
func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool {
if left == nil || right == nil {
return left == right
}
if left.APIKey != right.APIKey {
return false
}
if (left.Tokens == nil) != (right.Tokens == nil) {
return false
}
if left.Tokens != nil && *left.Tokens != *right.Tokens {
return false
}
if (left.LastRefresh == nil) != (right.LastRefresh == nil) {
return false
}
if left.LastRefresh != nil && !left.LastRefresh.Equal(*right.LastRefresh) {
return false
}
return true
}

View File

@@ -0,0 +1,446 @@
package ocm
import (
"context"
"math/rand/v2"
"sync"
"sync/atomic"
"time"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
E "github.com/sagernet/sing/common/exceptions"
)
type credentialProvider interface {
selectCredential(sessionID string, selection credentialSelection) (Credential, bool, error)
onRateLimited(sessionID string, credential Credential, resetAt time.Time, selection credentialSelection) Credential
linkProviderInterrupt(credential Credential, selection credentialSelection, onInterrupt func()) func() bool
pollIfStale()
pollCredentialIfStale(credential Credential)
allCredentials() []Credential
close()
}
type singleCredentialProvider struct {
credential Credential
sessionAccess sync.RWMutex
sessions map[string]time.Time
}
func (p *singleCredentialProvider) selectCredential(sessionID string, selection credentialSelection) (Credential, bool, error) {
if !selection.allows(p.credential) {
return nil, false, E.New("credential ", p.credential.tagName(), " is filtered out")
}
if !p.credential.isAvailable() {
return nil, false, p.credential.unavailableError()
}
if !p.credential.isUsable() {
return nil, false, E.New("credential ", p.credential.tagName(), " is rate-limited")
}
var isNew bool
if sessionID != "" {
p.sessionAccess.Lock()
if p.sessions == nil {
p.sessions = make(map[string]time.Time)
}
_, exists := p.sessions[sessionID]
if !exists {
p.sessions[sessionID] = time.Now()
isNew = true
}
p.sessionAccess.Unlock()
}
return p.credential, isNew, nil
}
func (p *singleCredentialProvider) onRateLimited(_ string, credential Credential, resetAt time.Time, _ credentialSelection) Credential {
credential.markRateLimited(resetAt)
return nil
}
func (p *singleCredentialProvider) pollIfStale() {
now := time.Now()
p.sessionAccess.Lock()
for id, createdAt := range p.sessions {
if now.Sub(createdAt) > sessionExpiry {
delete(p.sessions, id)
}
}
p.sessionAccess.Unlock()
if time.Since(p.credential.lastUpdatedTime()) > p.credential.pollBackoff(defaultPollInterval) {
p.credential.pollUsage()
}
}
func (p *singleCredentialProvider) allCredentials() []Credential {
return []Credential{p.credential}
}
func (p *singleCredentialProvider) linkProviderInterrupt(_ Credential, _ credentialSelection, _ func()) func() bool {
return func() bool {
return false
}
}
func (p *singleCredentialProvider) pollCredentialIfStale(credential Credential) {
if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(defaultPollInterval) {
credential.pollUsage()
}
}
func (p *singleCredentialProvider) close() {}
type sessionEntry struct {
tag string
selectionScope credentialSelectionScope
createdAt time.Time
}
type credentialInterruptKey struct {
tag string
selectionScope credentialSelectionScope
}
type credentialInterruptEntry struct {
context context.Context
cancel context.CancelFunc
}
type balancerProvider struct {
credentials []Credential
strategy string
roundRobinIndex atomic.Uint64
rebalanceThreshold float64
sessionAccess sync.RWMutex
sessions map[string]sessionEntry
interruptAccess sync.Mutex
credentialInterrupts map[credentialInterruptKey]credentialInterruptEntry
logger log.ContextLogger
}
func compositeCredentialSelectable(credential Credential) bool {
return !credential.ocmIsAPIKeyMode()
}
func newBalancerProvider(credentials []Credential, strategy string, rebalanceThreshold float64, logger log.ContextLogger) *balancerProvider {
return &balancerProvider{
credentials: credentials,
strategy: strategy,
rebalanceThreshold: rebalanceThreshold,
sessions: make(map[string]sessionEntry),
credentialInterrupts: make(map[credentialInterruptKey]credentialInterruptEntry),
logger: logger,
}
}
func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (Credential, bool, error) {
selectionScope := selection.scopeOrDefault()
for {
if p.strategy == C.BalancerStrategyFallback {
best := p.pickCredential(selection.filter)
if best == nil {
return nil, false, allRateLimitedError(p.credentials)
}
return best, p.storeSessionIfAbsent(sessionID, sessionEntry{createdAt: time.Now()}), nil
}
if sessionID != "" {
p.sessionAccess.RLock()
entry, exists := p.sessions[sessionID]
p.sessionAccess.RUnlock()
if exists {
if entry.selectionScope == selectionScope {
for _, credential := range p.credentials {
if credential.tagName() == entry.tag && compositeCredentialSelectable(credential) && selection.allows(credential) && credential.isUsable() {
if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == C.BalancerStrategyLeastUsed) {
better := p.pickLeastUsed(selection.filter)
if better != nil && better.tagName() != credential.tagName() {
effectiveThreshold := p.rebalanceThreshold / credential.planWeight()
delta := credential.weeklyUtilization() - better.weeklyUtilization()
if delta > effectiveThreshold {
p.logger.Info("rebalancing away from ", credential.tagName(),
": utilization delta ", delta, "% exceeds effective threshold ",
effectiveThreshold, "% (weight ", credential.planWeight(), ")")
p.rebalanceCredential(credential.tagName(), selectionScope)
break
}
}
}
return credential, false, nil
}
}
}
p.sessionAccess.Lock()
currentEntry, stillExists := p.sessions[sessionID]
if stillExists && currentEntry == entry {
delete(p.sessions, sessionID)
p.sessionAccess.Unlock()
} else {
p.sessionAccess.Unlock()
continue
}
}
}
best := p.pickCredential(selection.filter)
if best == nil {
return nil, false, allRateLimitedError(p.credentials)
}
if p.storeSessionIfAbsent(sessionID, sessionEntry{
tag: best.tagName(),
selectionScope: selectionScope,
createdAt: time.Now(),
}) {
return best, true, nil
}
if sessionID == "" {
return best, false, nil
}
}
}
func (p *balancerProvider) storeSessionIfAbsent(sessionID string, entry sessionEntry) bool {
if sessionID == "" {
return false
}
p.sessionAccess.Lock()
defer p.sessionAccess.Unlock()
if _, exists := p.sessions[sessionID]; exists {
return false
}
p.sessions[sessionID] = entry
return true
}
func (p *balancerProvider) rebalanceCredential(tag string, selectionScope credentialSelectionScope) {
key := credentialInterruptKey{tag: tag, selectionScope: selectionScope}
p.interruptAccess.Lock()
if entry, loaded := p.credentialInterrupts[key]; loaded {
entry.cancel()
}
ctx, cancel := context.WithCancel(context.Background())
p.credentialInterrupts[key] = credentialInterruptEntry{context: ctx, cancel: cancel}
p.interruptAccess.Unlock()
p.sessionAccess.Lock()
for id, entry := range p.sessions {
if entry.tag == tag && entry.selectionScope == selectionScope {
delete(p.sessions, id)
}
}
p.sessionAccess.Unlock()
}
func (p *balancerProvider) linkProviderInterrupt(credential Credential, selection credentialSelection, onInterrupt func()) func() bool {
if p.strategy == C.BalancerStrategyFallback {
return func() bool { return false }
}
key := credentialInterruptKey{
tag: credential.tagName(),
selectionScope: selection.scopeOrDefault(),
}
p.interruptAccess.Lock()
entry, loaded := p.credentialInterrupts[key]
if !loaded {
ctx, cancel := context.WithCancel(context.Background())
entry = credentialInterruptEntry{context: ctx, cancel: cancel}
p.credentialInterrupts[key] = entry
}
p.interruptAccess.Unlock()
return context.AfterFunc(entry.context, onInterrupt)
}
func (p *balancerProvider) onRateLimited(sessionID string, credential Credential, resetAt time.Time, selection credentialSelection) Credential {
credential.markRateLimited(resetAt)
if p.strategy == C.BalancerStrategyFallback {
return p.pickCredential(selection.filter)
}
if sessionID != "" {
p.sessionAccess.Lock()
delete(p.sessions, sessionID)
p.sessionAccess.Unlock()
}
best := p.pickCredential(selection.filter)
if best != nil && sessionID != "" {
p.sessionAccess.Lock()
p.sessions[sessionID] = sessionEntry{
tag: best.tagName(),
selectionScope: selection.scopeOrDefault(),
createdAt: time.Now(),
}
p.sessionAccess.Unlock()
}
return best
}
func (p *balancerProvider) pickCredential(filter func(Credential) bool) Credential {
switch p.strategy {
case C.BalancerStrategyRoundRobin:
return p.pickRoundRobin(filter)
case C.BalancerStrategyRandom:
return p.pickRandom(filter)
case C.BalancerStrategyFallback:
return p.pickFallback(filter)
default:
return p.pickLeastUsed(filter)
}
}
func (p *balancerProvider) pickFallback(filter func(Credential) bool) Credential {
for _, credential := range p.credentials {
if filter != nil && !filter(credential) {
continue
}
if !compositeCredentialSelectable(credential) {
continue
}
if credential.isUsable() {
return credential
}
}
return nil
}
const weeklyWindowHours = 7 * 24
func (p *balancerProvider) pickLeastUsed(filter func(Credential) bool) Credential {
var best Credential
bestScore := float64(-1)
now := time.Now()
for _, credential := range p.credentials {
if filter != nil && !filter(credential) {
continue
}
if !compositeCredentialSelectable(credential) {
continue
}
if !credential.isUsable() {
continue
}
remaining := credential.weeklyCap() - credential.weeklyUtilization()
score := remaining * credential.planWeight()
resetTime := credential.weeklyResetTime()
if !resetTime.IsZero() {
timeUntilReset := resetTime.Sub(now)
if timeUntilReset < time.Hour {
timeUntilReset = time.Hour
}
score *= weeklyWindowHours / timeUntilReset.Hours()
}
if score > bestScore {
bestScore = score
best = credential
}
}
return best
}
func ocmPlanWeight(accountType string) float64 {
switch accountType {
case "pro":
return 10
case "plus":
return 1
default:
return 1
}
}
func (p *balancerProvider) pickRoundRobin(filter func(Credential) bool) Credential {
start := int(p.roundRobinIndex.Add(1) - 1)
count := len(p.credentials)
for offset := range count {
candidate := p.credentials[(start+offset)%count]
if filter != nil && !filter(candidate) {
continue
}
if !compositeCredentialSelectable(candidate) {
continue
}
if candidate.isUsable() {
return candidate
}
}
return nil
}
func (p *balancerProvider) pickRandom(filter func(Credential) bool) Credential {
var usable []Credential
for _, candidate := range p.credentials {
if filter != nil && !filter(candidate) {
continue
}
if !compositeCredentialSelectable(candidate) {
continue
}
if candidate.isUsable() {
usable = append(usable, candidate)
}
}
if len(usable) == 0 {
return nil
}
return usable[rand.IntN(len(usable))]
}
func (p *balancerProvider) pollIfStale() {
now := time.Now()
p.sessionAccess.Lock()
for id, entry := range p.sessions {
if now.Sub(entry.createdAt) > sessionExpiry {
delete(p.sessions, id)
}
}
p.sessionAccess.Unlock()
p.interruptAccess.Lock()
for key, entry := range p.credentialInterrupts {
if entry.context.Err() != nil {
delete(p.credentialInterrupts, key)
}
}
p.interruptAccess.Unlock()
for _, credential := range p.credentials {
if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(defaultPollInterval) {
credential.pollUsage()
}
}
}
func (p *balancerProvider) pollCredentialIfStale(credential Credential) {
if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(defaultPollInterval) {
credential.pollUsage()
}
}
func (p *balancerProvider) allCredentials() []Credential {
return p.credentials
}
func (p *balancerProvider) close() {}
func allRateLimitedError(credentials []Credential) error {
var hasUnavailable bool
var earliest time.Time
for _, credential := range credentials {
if credential.unavailableError() != nil {
hasUnavailable = true
continue
}
resetAt := credential.earliestReset()
if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) {
earliest = resetAt
}
}
if hasUnavailable {
return E.New("all credentials unavailable")
}
if earliest.IsZero() {
return E.New("all credentials rate-limited")
}
return E.New("all credentials rate-limited, earliest reset in ", log.FormatDuration(time.Until(earliest)))
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,364 @@
package ocm
import (
"net/http"
"slices"
"strconv"
"strings"
"time"
)
type availabilityState string
const (
availabilityStateUsable availabilityState = "usable"
availabilityStateRateLimited availabilityState = "rate_limited"
availabilityStateTemporarilyBlocked availabilityState = "temporarily_blocked"
availabilityStateUnavailable availabilityState = "unavailable"
availabilityStateUnknown availabilityState = "unknown"
)
type availabilityReason string
const (
availabilityReasonHardRateLimit availabilityReason = "hard_rate_limit"
availabilityReasonConnectionLimit availabilityReason = "connection_limit"
availabilityReasonPollFailed availabilityReason = "poll_failed"
availabilityReasonUpstreamRejected availabilityReason = "upstream_rejected"
availabilityReasonNoCredentials availabilityReason = "no_credentials"
availabilityReasonUnknown availabilityReason = "unknown"
)
type availabilityStatus struct {
State availabilityState
Reason availabilityReason
ResetAt time.Time
}
func (s availabilityStatus) normalized() availabilityStatus {
if s.State == "" {
s.State = availabilityStateUnknown
}
if s.Reason == "" && s.State != availabilityStateUsable {
s.Reason = availabilityReasonUnknown
}
return s
}
type creditsSnapshot struct {
HasCredits bool `json:"has_credits"`
Unlimited bool `json:"unlimited"`
Balance string `json:"balance,omitempty"`
}
type rateLimitWindow struct {
UsedPercent float64 `json:"used_percent"`
WindowMinutes int64 `json:"window_minutes,omitempty"`
ResetAt int64 `json:"reset_at,omitempty"`
}
type rateLimitSnapshot struct {
LimitID string `json:"limit_id,omitempty"`
LimitName string `json:"limit_name,omitempty"`
Primary *rateLimitWindow `json:"primary,omitempty"`
Secondary *rateLimitWindow `json:"secondary,omitempty"`
Credits *creditsSnapshot `json:"credits,omitempty"`
PlanType string `json:"plan_type,omitempty"`
}
func normalizeStoredLimitID(limitID string) string {
normalized := normalizeRateLimitIdentifier(limitID)
if normalized == "" {
return ""
}
return strings.ReplaceAll(normalized, "-", "_")
}
func headerLimitID(limitID string) string {
if limitID == "" {
return "codex"
}
return strings.ReplaceAll(normalizeStoredLimitID(limitID), "_", "-")
}
func defaultRateLimitSnapshot(limitID string) rateLimitSnapshot {
if limitID == "" {
limitID = "codex"
}
return rateLimitSnapshot{LimitID: normalizeStoredLimitID(limitID)}
}
func cloneCreditsSnapshot(snapshot *creditsSnapshot) *creditsSnapshot {
if snapshot == nil {
return nil
}
cloned := *snapshot
return &cloned
}
func cloneRateLimitWindow(window *rateLimitWindow) *rateLimitWindow {
if window == nil {
return nil
}
cloned := *window
return &cloned
}
func cloneRateLimitSnapshot(snapshot rateLimitSnapshot) rateLimitSnapshot {
snapshot.Primary = cloneRateLimitWindow(snapshot.Primary)
snapshot.Secondary = cloneRateLimitWindow(snapshot.Secondary)
snapshot.Credits = cloneCreditsSnapshot(snapshot.Credits)
return snapshot
}
func sortRateLimitSnapshots(snapshots []rateLimitSnapshot) {
slices.SortFunc(snapshots, func(a, b rateLimitSnapshot) int {
return strings.Compare(a.LimitID, b.LimitID)
})
}
func parseHeaderFloat(headers http.Header, name string) (float64, bool) {
value := strings.TrimSpace(headers.Get(name))
if value == "" {
return 0, false
}
parsed, err := strconv.ParseFloat(value, 64)
if err != nil {
return 0, false
}
if !isFinite(parsed) {
return 0, false
}
return parsed, true
}
func isFinite(value float64) bool {
return !((value != value) || value > 1e308 || value < -1e308)
}
func parseCreditsSnapshotFromHeaders(headers http.Header) *creditsSnapshot {
hasCreditsValue := strings.TrimSpace(headers.Get("x-codex-credits-has-credits"))
unlimitedValue := strings.TrimSpace(headers.Get("x-codex-credits-unlimited"))
if hasCreditsValue == "" || unlimitedValue == "" {
return nil
}
hasCredits := strings.EqualFold(hasCreditsValue, "true") || hasCreditsValue == "1"
unlimited := strings.EqualFold(unlimitedValue, "true") || unlimitedValue == "1"
return &creditsSnapshot{
HasCredits: hasCredits,
Unlimited: unlimited,
Balance: strings.TrimSpace(headers.Get("x-codex-credits-balance")),
}
}
func parseRateLimitWindowFromHeaders(headers http.Header, prefix string, windowName string) *rateLimitWindow {
usedPercent, hasPercent := parseHeaderFloat(headers, prefix+"-"+windowName+"-used-percent")
windowMinutes, hasWindow := parseInt64Header(headers, prefix+"-"+windowName+"-window-minutes")
resetAt, hasReset := parseInt64Header(headers, prefix+"-"+windowName+"-reset-at")
if !hasPercent && !hasWindow && !hasReset {
return nil
}
window := &rateLimitWindow{}
if hasPercent {
window.UsedPercent = usedPercent
}
if hasWindow {
window.WindowMinutes = windowMinutes
}
if hasReset {
window.ResetAt = resetAt
}
return window
}
func parseRateLimitSnapshotsFromHeaders(headers http.Header) []rateLimitSnapshot {
limitIDs := map[string]struct{}{}
for key := range headers {
lowerKey := strings.ToLower(key)
if strings.HasPrefix(lowerKey, "x-") && strings.Contains(lowerKey, "-primary-") {
limitID := strings.TrimPrefix(lowerKey, "x-")
if suffix := strings.Index(limitID, "-primary-"); suffix > 0 {
limitIDs[normalizeStoredLimitID(limitID[:suffix])] = struct{}{}
}
}
if strings.HasPrefix(lowerKey, "x-") && strings.Contains(lowerKey, "-secondary-") {
limitID := strings.TrimPrefix(lowerKey, "x-")
if suffix := strings.Index(limitID, "-secondary-"); suffix > 0 {
limitIDs[normalizeStoredLimitID(limitID[:suffix])] = struct{}{}
}
}
}
if activeLimit := normalizeStoredLimitID(headers.Get("x-codex-active-limit")); activeLimit != "" {
limitIDs[activeLimit] = struct{}{}
}
if credits := parseCreditsSnapshotFromHeaders(headers); credits != nil {
_ = credits
limitIDs["codex"] = struct{}{}
}
if len(limitIDs) == 0 {
return nil
}
snapshots := make([]rateLimitSnapshot, 0, len(limitIDs))
for limitID := range limitIDs {
prefix := "x-" + headerLimitID(limitID)
snapshot := defaultRateLimitSnapshot(limitID)
snapshot.LimitName = strings.TrimSpace(headers.Get(prefix + "-limit-name"))
snapshot.Primary = parseRateLimitWindowFromHeaders(headers, prefix, "primary")
snapshot.Secondary = parseRateLimitWindowFromHeaders(headers, prefix, "secondary")
if limitID == "codex" {
snapshot.Credits = parseCreditsSnapshotFromHeaders(headers)
}
if snapshot.Primary == nil && snapshot.Secondary == nil && snapshot.Credits == nil {
continue
}
snapshots = append(snapshots, snapshot)
}
sortRateLimitSnapshots(snapshots)
return snapshots
}
type usageRateLimitWindowPayload struct {
UsedPercent float64 `json:"used_percent"`
LimitWindowSeconds int64 `json:"limit_window_seconds"`
ResetAt int64 `json:"reset_at"`
}
type usageRateLimitDetailsPayload struct {
PrimaryWindow *usageRateLimitWindowPayload `json:"primary_window"`
SecondaryWindow *usageRateLimitWindowPayload `json:"secondary_window"`
}
type usageCreditsPayload struct {
HasCredits bool `json:"has_credits"`
Unlimited bool `json:"unlimited"`
Balance *string `json:"balance"`
}
type additionalRateLimitPayload struct {
LimitName string `json:"limit_name"`
MeteredFeature string `json:"metered_feature"`
RateLimit *usageRateLimitDetailsPayload `json:"rate_limit"`
}
type usageRateLimitStatusPayload struct {
PlanType string `json:"plan_type"`
RateLimit *usageRateLimitDetailsPayload `json:"rate_limit"`
Credits *usageCreditsPayload `json:"credits"`
AdditionalRateLimits []additionalRateLimitPayload `json:"additional_rate_limits"`
}
func windowFromUsagePayload(window *usageRateLimitWindowPayload) *rateLimitWindow {
if window == nil {
return nil
}
result := &rateLimitWindow{
UsedPercent: window.UsedPercent,
}
if window.LimitWindowSeconds > 0 {
result.WindowMinutes = (window.LimitWindowSeconds + 59) / 60
}
if window.ResetAt > 0 {
result.ResetAt = window.ResetAt
}
return result
}
func snapshotsFromUsagePayload(payload usageRateLimitStatusPayload) []rateLimitSnapshot {
snapshots := make([]rateLimitSnapshot, 0, 1+len(payload.AdditionalRateLimits))
codex := defaultRateLimitSnapshot("codex")
codex.PlanType = payload.PlanType
if payload.RateLimit != nil {
codex.Primary = windowFromUsagePayload(payload.RateLimit.PrimaryWindow)
codex.Secondary = windowFromUsagePayload(payload.RateLimit.SecondaryWindow)
}
if payload.Credits != nil {
codex.Credits = &creditsSnapshot{
HasCredits: payload.Credits.HasCredits,
Unlimited: payload.Credits.Unlimited,
}
if payload.Credits.Balance != nil {
codex.Credits.Balance = *payload.Credits.Balance
}
}
if codex.Primary != nil || codex.Secondary != nil || codex.Credits != nil || codex.PlanType != "" {
snapshots = append(snapshots, codex)
}
for _, additional := range payload.AdditionalRateLimits {
snapshot := defaultRateLimitSnapshot(additional.MeteredFeature)
snapshot.LimitName = additional.LimitName
snapshot.PlanType = payload.PlanType
if additional.RateLimit != nil {
snapshot.Primary = windowFromUsagePayload(additional.RateLimit.PrimaryWindow)
snapshot.Secondary = windowFromUsagePayload(additional.RateLimit.SecondaryWindow)
}
if snapshot.Primary == nil && snapshot.Secondary == nil {
continue
}
snapshots = append(snapshots, snapshot)
}
sortRateLimitSnapshots(snapshots)
return snapshots
}
func applyRateLimitSnapshotsLocked(state *credentialState, snapshots []rateLimitSnapshot, activeLimitID string, planWeight float64, planType string) {
if len(snapshots) == 0 {
return
}
if state.rateLimitSnapshots == nil {
state.rateLimitSnapshots = make(map[string]rateLimitSnapshot, len(snapshots))
} else {
clear(state.rateLimitSnapshots)
}
for _, snapshot := range snapshots {
snapshot = cloneRateLimitSnapshot(snapshot)
if snapshot.LimitID == "" {
snapshot.LimitID = "codex"
}
if snapshot.LimitName == "" && snapshot.LimitID != "codex" {
snapshot.LimitName = strings.ReplaceAll(snapshot.LimitID, "_", "-")
}
if snapshot.PlanType == "" {
snapshot.PlanType = planType
}
state.rateLimitSnapshots[snapshot.LimitID] = snapshot
}
if planWeight > 0 {
state.remotePlanWeight = planWeight
}
if planType != "" {
state.accountType = planType
}
if normalizedActive := normalizeStoredLimitID(activeLimitID); normalizedActive != "" {
state.activeLimitID = normalizedActive
} else if state.activeLimitID == "" {
if _, exists := state.rateLimitSnapshots["codex"]; exists {
state.activeLimitID = "codex"
} else {
for limitID := range state.rateLimitSnapshots {
state.activeLimitID = limitID
break
}
}
}
legacy := state.rateLimitSnapshots["codex"]
if legacy.LimitID == "" && state.activeLimitID != "" {
legacy = state.rateLimitSnapshots[state.activeLimitID]
}
state.fiveHourUtilization = 0
state.fiveHourReset = time.Time{}
state.weeklyUtilization = 0
state.weeklyReset = time.Time{}
if legacy.Primary != nil {
state.fiveHourUtilization = legacy.Primary.UsedPercent
if legacy.Primary.ResetAt > 0 {
state.fiveHourReset = time.Unix(legacy.Primary.ResetAt, 0)
}
}
if legacy.Secondary != nil {
state.weeklyUtilization = legacy.Secondary.UsedPercent
if legacy.Secondary.ResetAt > 0 {
state.weeklyReset = time.Unix(legacy.Secondary.ResetAt, 0)
}
}
state.noteSnapshotData()
}

View File

@@ -0,0 +1,88 @@
package ocm
import (
"encoding/json"
"strings"
"github.com/openai/openai-go/v3"
"github.com/openai/openai-go/v3/responses"
)
type requestLogMetadata struct {
Model string
ServiceTier string
ReasoningEffort string
}
type legacyReasoningEffortPayload struct {
ReasoningEffort string `json:"reasoning_effort"`
}
func requestLogMetadataFromChatCompletionRequest(request openai.ChatCompletionNewParams) requestLogMetadata {
return requestLogMetadata{
Model: string(request.Model),
ServiceTier: string(request.ServiceTier),
ReasoningEffort: string(request.ReasoningEffort),
}
}
func requestLogMetadataFromResponsesRequest(request responses.ResponseNewParams, legacyReasoningEffort string) requestLogMetadata {
metadata := requestLogMetadata{
Model: string(request.Model),
ServiceTier: string(request.ServiceTier),
}
if request.Reasoning.Effort != "" {
metadata.ReasoningEffort = string(request.Reasoning.Effort)
}
if metadata.ReasoningEffort == "" {
metadata.ReasoningEffort = legacyReasoningEffort
}
return metadata
}
func parseLegacyReasoningEffort(data []byte) string {
var legacy legacyReasoningEffortPayload
if json.Unmarshal(data, &legacy) != nil {
return ""
}
return legacy.ReasoningEffort
}
func parseRequestLogMetadata(path string, data []byte) requestLogMetadata {
switch {
case path == "/v1/chat/completions":
var request openai.ChatCompletionNewParams
if json.Unmarshal(data, &request) != nil {
return requestLogMetadata{}
}
return requestLogMetadataFromChatCompletionRequest(request)
case strings.HasPrefix(path, "/v1/responses"):
var request responses.ResponseNewParams
if json.Unmarshal(data, &request) != nil {
return requestLogMetadata{}
}
return requestLogMetadataFromResponsesRequest(request, parseLegacyReasoningEffort(data))
default:
return requestLogMetadata{}
}
}
func buildAssignedCredentialLogParts(credentialTag string, sessionID string, username string, metadata requestLogMetadata) []any {
logParts := []any{"assigned credential ", credentialTag}
if sessionID != "" {
logParts = append(logParts, " for session ", sessionID)
}
if username != "" {
logParts = append(logParts, " by user ", username)
}
if metadata.Model != "" {
logParts = append(logParts, ", model=", metadata.Model)
}
if metadata.ReasoningEffort != "" {
logParts = append(logParts, ", think=", metadata.ReasoningEffort)
}
if metadata.ServiceTier == "priority" {
logParts = append(logParts, ", fast")
}
return logParts
}

View File

@@ -0,0 +1,126 @@
package ocm
import (
"strings"
"testing"
F "github.com/sagernet/sing/common/format"
)
func TestParseRequestLogMetadata(t *testing.T) {
t.Parallel()
metadata := parseRequestLogMetadata("/v1/responses", []byte(`{
"model":"gpt-5.4",
"service_tier":"priority",
"reasoning":{"effort":"xhigh"}
}`))
if metadata.Model != "gpt-5.4" {
t.Fatalf("expected model gpt-5.4, got %q", metadata.Model)
}
if metadata.ServiceTier != "priority" {
t.Fatalf("expected priority service tier, got %q", metadata.ServiceTier)
}
if metadata.ReasoningEffort != "xhigh" {
t.Fatalf("expected xhigh reasoning effort, got %q", metadata.ReasoningEffort)
}
}
func TestParseRequestLogMetadataFallsBackToTopLevelReasoningEffort(t *testing.T) {
t.Parallel()
metadata := parseRequestLogMetadata("/v1/responses", []byte(`{
"model":"gpt-5.4",
"reasoning_effort":"high"
}`))
if metadata.ReasoningEffort != "high" {
t.Fatalf("expected high reasoning effort, got %q", metadata.ReasoningEffort)
}
}
func TestParseRequestLogMetadataFromChatCompletions(t *testing.T) {
t.Parallel()
metadata := parseRequestLogMetadata("/v1/chat/completions", []byte(`{
"model":"gpt-5.4",
"service_tier":"priority",
"reasoning_effort":"xhigh",
"messages":[{"role":"user","content":"hi"}]
}`))
if metadata.Model != "gpt-5.4" {
t.Fatalf("expected model gpt-5.4, got %q", metadata.Model)
}
if metadata.ServiceTier != "priority" {
t.Fatalf("expected priority service tier, got %q", metadata.ServiceTier)
}
if metadata.ReasoningEffort != "xhigh" {
t.Fatalf("expected xhigh reasoning effort, got %q", metadata.ReasoningEffort)
}
}
func TestParseRequestLogMetadataIgnoresUnsupportedPath(t *testing.T) {
t.Parallel()
metadata := parseRequestLogMetadata("/v1/files", []byte(`{"model":"gpt-5.4"}`))
if metadata != (requestLogMetadata{}) {
t.Fatalf("expected zero metadata, got %#v", metadata)
}
}
func TestBuildAssignedCredentialLogPartsIncludesThinkLevel(t *testing.T) {
t.Parallel()
message := F.ToString(buildAssignedCredentialLogParts("a", "session-1", "alice", requestLogMetadata{
Model: "gpt-5.4",
ServiceTier: "priority",
ReasoningEffort: "xhigh",
})...)
for _, fragment := range []string{
"assigned credential a",
"for session session-1",
"by user alice",
"model=gpt-5.4",
"think=xhigh",
"fast",
} {
if !strings.Contains(message, fragment) {
t.Fatalf("expected %q in %q", fragment, message)
}
}
}
func TestParseWebSocketResponseCreateRequestIncludesThinkLevel(t *testing.T) {
t.Parallel()
request, ok := parseWebSocketResponseCreateRequest([]byte(`{
"type":"response.create",
"model":"gpt-5.4",
"reasoning":{"effort":"xhigh"}
}`))
if !ok {
t.Fatal("expected websocket response.create request to parse")
}
if request.metadata().ReasoningEffort != "xhigh" {
t.Fatalf("expected xhigh reasoning effort, got %q", request.metadata().ReasoningEffort)
}
}
func TestParseWebSocketResponseCreateRequestFallsBackToLegacyReasoningEffort(t *testing.T) {
t.Parallel()
request, ok := parseWebSocketResponseCreateRequest([]byte(`{
"type":"response.create",
"model":"gpt-5.4",
"reasoning_effort":"high"
}`))
if !ok {
t.Fatal("expected websocket response.create request to parse")
}
if request.metadata().ReasoningEffort != "high" {
t.Fatalf("expected high reasoning effort, got %q", request.metadata().ReasoningEffort)
}
}

View File

@@ -4,7 +4,6 @@ import (
"bufio"
"context"
stdTLS "crypto/tls"
"errors"
"io"
"math/rand/v2"
"net"
@@ -18,14 +17,14 @@ import (
"github.com/hashicorp/yamux"
)
func reverseYamuxConfig() *yamux.Config {
var defaultYamuxConfig = func() *yamux.Config {
config := yamux.DefaultConfig()
config.KeepAliveInterval = 15 * time.Second
config.ConnectionWriteTimeout = 10 * time.Second
config.MaxStreamWindowSize = 512 * 1024
config.LogOutput = io.Discard
return config
}
}()
type bufferedConn struct {
reader *bufio.Reader
@@ -58,6 +57,12 @@ func (s *Service) handleReverseConnect(ctx context.Context, w http.ResponseWrite
return
}
if r.Header.Get("X-Api-Key") != "" || r.Header.Get("Api-Key") != "" {
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
"API key authentication is not supported; use Authorization: Bearer with an OCM user token")
return
}
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
@@ -103,7 +108,7 @@ func (s *Service) handleReverseConnect(ctx context.Context, w http.ResponseWrite
return
}
session, err := yamux.Client(conn, reverseYamuxConfig())
session, err := yamux.Client(conn, defaultYamuxConfig)
if err != nil {
conn.Close()
s.logger.ErrorContext(ctx, "reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err)
@@ -124,13 +129,13 @@ func (s *Service) handleReverseConnect(ctx context.Context, w http.ResponseWrite
}
func (s *Service) findReceiverCredential(token string) *externalCredential {
for _, cred := range s.allCredentials {
extCred, ok := cred.(*externalCredential)
if !ok {
for _, credential := range s.allCredentials {
external, ok := credential.(*externalCredential)
if !ok || external.connectorURL != nil {
continue
}
if extCred.baseURL == reverseProxyBaseURL && extCred.token == token {
return extCred
if external.token == token {
return external
}
}
return nil
@@ -156,9 +161,11 @@ func (c *externalCredential) connectorLoop() {
consecutiveFailures++
backoff := connectorBackoff(consecutiveFailures)
c.logger.Warn("reverse connection for ", c.tag, " lost: ", err, ", reconnecting in ", backoff)
timer := time.NewTimer(backoff)
select {
case <-time.After(backoff):
case <-timer.C:
case <-ctx.Done():
timer.Stop()
return
}
}
@@ -231,7 +238,7 @@ func (c *externalCredential) connectorConnect(ctx context.Context) (time.Duratio
}
}
session, err := yamux.Server(&bufferedConn{reader: reader, Conn: conn}, reverseYamuxConfig())
session, err := yamux.Server(&bufferedConn{reader: reader, Conn: conn}, defaultYamuxConfig)
if err != nil {
conn.Close()
return 0, E.Cause(err, "create yamux server")
@@ -248,7 +255,7 @@ func (c *externalCredential) connectorConnect(ctx context.Context) (time.Duratio
}
err = httpServer.Serve(&yamuxNetListener{session: session})
sessionLifetime := time.Since(serveStart)
if err != nil && !errors.Is(err, http.ErrServerClosed) && ctx.Err() == nil {
if err != nil && !E.IsClosed(err) && ctx.Err() == nil {
return sessionLifetime, E.Cause(err, "serve")
}
return sessionLifetime, E.New("connection closed")

View File

@@ -1,17 +1,12 @@
package ocm
import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"mime"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/sagernet/sing-box/adapter"
boxService "github.com/sagernet/sing-box/adapter/service"
@@ -21,14 +16,13 @@ import (
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/observable"
aTLS "github.com/sagernet/sing/common/tls"
"github.com/go-chi/chi/v5"
"github.com/openai/openai-go/v3"
"github.com/openai/openai-go/v3/responses"
openaishared "github.com/openai/openai-go/v3/shared"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
)
@@ -38,13 +32,7 @@ func RegisterService(registry *boxService.Registry) {
}
type errorResponse struct {
Error errorDetails `json:"error"`
}
type errorDetails struct {
Type string `json:"type"`
Code string `json:"code,omitempty"`
Message string `json:"message"`
Error openaishared.ErrorObject `json:"error"`
}
func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, errorType string, message string) {
@@ -56,10 +44,11 @@ func writeJSONErrorWithCode(w http.ResponseWriter, r *http.Request, statusCode i
w.WriteHeader(statusCode)
json.NewEncoder(w).Encode(errorResponse{
Error: errorDetails{
Error: openaishared.ErrorObject{
Type: errorType,
Code: errorCode,
Message: message,
Param: "",
},
})
}
@@ -72,21 +61,20 @@ func writePlainTextError(w http.ResponseWriter, statusCode int, message string)
const (
retryableUsageMessage = "current credential reached its usage limit; retry the request to use another credential"
retryableUsageCode = "credential_usage_exhausted"
)
func hasAlternativeCredential(provider credentialProvider, currentCredential credential, filter func(credential) bool) bool {
func hasAlternativeCredential(provider credentialProvider, currentCredential Credential, selection credentialSelection) bool {
if provider == nil || currentCredential == nil {
return false
}
for _, cred := range provider.allCredentials() {
if cred == currentCredential {
for _, credential := range provider.allCredentials() {
if credential == currentCredential {
continue
}
if filter != nil && !filter(cred) {
if !selection.allows(credential) {
continue
}
if cred.isUsable() {
if credential.isUsable() {
return true
}
}
@@ -105,7 +93,7 @@ func unavailableCredentialMessage(provider credentialProvider, fallback string)
}
func writeRetryableUsageError(w http.ResponseWriter, r *http.Request) {
writeJSONErrorWithCode(w, r, http.StatusServiceUnavailable, "server_error", retryableUsageCode, retryableUsageMessage)
writeJSONErrorWithCode(w, r, http.StatusTooManyRequests, "usage_limit_reached", "", retryableUsageMessage)
}
func writeNonRetryableCredentialError(w http.ResponseWriter, message string) {
@@ -116,17 +104,32 @@ func writeCredentialUnavailableError(
w http.ResponseWriter,
r *http.Request,
provider credentialProvider,
currentCredential credential,
filter func(credential) bool,
currentCredential Credential,
selection credentialSelection,
fallback string,
) {
if hasAlternativeCredential(provider, currentCredential, filter) {
if hasAlternativeCredential(provider, currentCredential, selection) {
writeRetryableUsageError(w, r)
return
}
if provider != nil && strings.HasPrefix(allRateLimitedError(provider.allCredentials()).Error(), "all credentials rate-limited") {
writeRetryableUsageError(w, r)
return
}
writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, fallback))
}
func credentialSelectionForUser(userConfig *option.OCMUser) credentialSelection {
selection := credentialSelection{scope: credentialSelectionScopeAll}
if userConfig != nil && !userConfig.AllowExternalUsage {
selection.scope = credentialSelectionScopeNonExternal
selection.filter = func(credential Credential) bool {
return !credential.isExternal()
}
}
return selection
}
func isHopByHopHeader(header string) bool {
switch strings.ToLower(header) {
case "connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade", "host":
@@ -149,83 +152,57 @@ func isReverseProxyHeader(header string) bool {
}
}
func normalizeRateLimitIdentifier(limitIdentifier string) string {
trimmedIdentifier := strings.TrimSpace(strings.ToLower(limitIdentifier))
if trimmedIdentifier == "" {
return ""
func isAPIKeyHeader(header string) bool {
switch strings.ToLower(header) {
case "x-api-key", "api-key":
return true
default:
return false
}
return strings.ReplaceAll(trimmedIdentifier, "_", "-")
}
func parseInt64Header(headers http.Header, headerName string) (int64, bool) {
headerValue := strings.TrimSpace(headers.Get(headerName))
if headerValue == "" {
return 0, false
}
parsedValue, parseError := strconv.ParseInt(headerValue, 10, 64)
if parseError != nil {
return 0, false
}
return parsedValue, true
}
func weeklyCycleHintForLimit(headers http.Header, limitIdentifier string) *WeeklyCycleHint {
normalizedLimitIdentifier := normalizeRateLimitIdentifier(limitIdentifier)
if normalizedLimitIdentifier == "" {
return nil
}
windowHeader := "x-" + normalizedLimitIdentifier + "-secondary-window-minutes"
resetHeader := "x-" + normalizedLimitIdentifier + "-secondary-reset-at"
windowMinutes, hasWindowMinutes := parseInt64Header(headers, windowHeader)
resetAtUnix, hasResetAt := parseInt64Header(headers, resetHeader)
if !hasWindowMinutes || !hasResetAt || windowMinutes <= 0 || resetAtUnix <= 0 {
return nil
}
return &WeeklyCycleHint{
WindowMinutes: windowMinutes,
ResetAt: time.Unix(resetAtUnix, 0).UTC(),
}
}
func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint {
activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
if activeLimitIdentifier != "" {
if activeHint := weeklyCycleHintForLimit(headers, activeLimitIdentifier); activeHint != nil {
return activeHint
}
}
return weeklyCycleHintForLimit(headers, "codex")
}
type Service struct {
boxService.Adapter
ctx context.Context
logger log.ContextLogger
options option.OCMServiceOptions
httpHeaders http.Header
listener *listener.Listener
tlsConfig tls.ServerConfig
httpServer *http.Server
userManager *UserManager
webSocketMutex sync.Mutex
webSocketGroup sync.WaitGroup
webSocketConns map[*webSocketSession]struct{}
shuttingDown bool
ctx context.Context
logger log.ContextLogger
options option.OCMServiceOptions
httpHeaders http.Header
listener *listener.Listener
tlsConfig tls.ServerConfig
httpServer *http.Server
userManager *UserManager
webSocketAccess sync.Mutex
webSocketGroup sync.WaitGroup
webSocketConns map[*webSocketSession]struct{}
shuttingDown bool
// Legacy mode
legacyCredential *defaultCredential
legacyProvider credentialProvider
// Multi-credential mode
providers map[string]credentialProvider
allCredentials []credential
userConfigMap map[string]*option.OCMUser
providers map[string]credentialProvider
allCredentials []Credential
userConfigMap map[string]*option.OCMUser
statusSubscriber *observable.Subscriber[struct{}]
statusObserver *observable.Observer[struct{}]
}
func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.OCMServiceOptions) (adapter.Service, error) {
hasLegacy := options.CredentialPath != "" || options.UsagesPath != "" || options.Detour != ""
if hasLegacy && len(options.Credentials) > 0 {
return nil, E.New("credential_path/usages_path/detour and credentials are mutually exclusive")
}
if len(options.Credentials) == 0 {
options.Credentials = []option.OCMCredential{{
Type: "default",
Tag: "default",
DefaultOptions: option.OCMDefaultCredentialOptions{
CredentialPath: options.CredentialPath,
UsagesPath: options.UsagesPath,
Detour: options.Detour,
},
}}
options.CredentialPath = ""
options.UsagesPath = ""
options.Detour = ""
}
err := validateOCMOptions(options)
if err != nil {
return nil, E.Cause(err, "validate options")
@@ -235,6 +212,8 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
tokenMap: make(map[string]string),
}
statusSubscriber := observable.NewSubscriber[struct{}](16)
service := &Service{
Adapter: boxService.NewAdapter(C.TypeOCM, tag),
ctx: ctx,
@@ -247,36 +226,24 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
Network: []string{N.NetworkTCP},
Listen: options.ListenOptions,
}),
userManager: userManager,
webSocketConns: make(map[*webSocketSession]struct{}),
userManager: userManager,
statusSubscriber: statusSubscriber,
statusObserver: observable.NewObserver[struct{}](statusSubscriber, 8),
webSocketConns: make(map[*webSocketSession]struct{}),
}
if len(options.Credentials) > 0 {
providers, allCredentials, err := buildOCMCredentialProviders(ctx, options, logger)
if err != nil {
return nil, E.Cause(err, "build credential providers")
}
service.providers = providers
service.allCredentials = allCredentials
userConfigMap := make(map[string]*option.OCMUser)
for i := range options.Users {
userConfigMap[options.Users[i].Name] = &options.Users[i]
}
service.userConfigMap = userConfigMap
} else {
cred, err := newDefaultCredential(ctx, "default", option.OCMDefaultCredentialOptions{
CredentialPath: options.CredentialPath,
UsagesPath: options.UsagesPath,
Detour: options.Detour,
}, logger)
if err != nil {
return nil, err
}
service.legacyCredential = cred
service.legacyProvider = &singleCredentialProvider{cred: cred}
service.allCredentials = []credential{cred}
providers, allCredentials, err := buildOCMCredentialProviders(ctx, options, logger)
if err != nil {
return nil, E.Cause(err, "build credential providers")
}
service.providers = providers
service.allCredentials = allCredentials
userConfigMap := make(map[string]*option.OCMUser)
for i := range options.Users {
userConfigMap[options.Users[i].Name] = &options.Users[i]
}
service.userConfigMap = userConfigMap
if options.TLS != nil {
tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
@@ -296,24 +263,23 @@ func (s *Service) Start(stage adapter.StartStage) error {
s.userManager.UpdateUsers(s.options.Users)
for _, cred := range s.allCredentials {
if extCred, ok := cred.(*externalCredential); ok && extCred.reverse && extCred.connectorURL != nil {
extCred.reverseService = s
for _, credential := range s.allCredentials {
credential.setStatusSubscriber(s.statusSubscriber)
if external, ok := credential.(*externalCredential); ok && external.reverse && external.connectorURL != nil {
external.reverseService = s
}
err := cred.start()
err := credential.start()
if err != nil {
return err
}
tag := cred.tagName()
cred.setOnBecameUnusable(func() {
tag := credential.tagName()
credential.setOnBecameUnusable(func() {
s.interruptWebSocketSessionsForCredential(tag)
})
}
if len(s.options.Credentials) > 0 {
err := validateOCMCompositeCredentialModes(s.options, s.providers)
if err != nil {
return E.Cause(err, "validate loaded credentials")
}
err := validateOCMCompositeCredentialModes(s.options, s.providers)
if err != nil {
return E.Cause(err, "validate loaded credentials")
}
router := chi.NewRouter()
@@ -342,7 +308,7 @@ func (s *Service) Start(stage adapter.StartStage) error {
go func() {
serveErr := s.httpServer.Serve(tcpListener)
if serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
if serveErr != nil && !E.IsClosed(serveErr) {
s.logger.Error("serve error: ", serveErr)
}
}()
@@ -350,568 +316,22 @@ func (s *Service) Start(stage adapter.StartStage) error {
return nil
}
func (s *Service) resolveCredentialProvider(username string) (credentialProvider, error) {
if len(s.options.Users) > 0 {
return credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
}
provider := noUserCredentialProvider(s.providers, s.legacyProvider, s.options)
if provider == nil {
return nil, E.New("no credential available")
}
return provider, nil
}
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := log.ContextWithNewID(r.Context())
if r.URL.Path == "/ocm/v1/status" {
s.handleStatusEndpoint(w, r)
return
}
if r.URL.Path == "/ocm/v1/reverse" {
s.handleReverseConnect(ctx, w, r)
return
}
path := r.URL.Path
if !strings.HasPrefix(path, "/v1/") {
writeJSONError(w, r, http.StatusNotFound, "invalid_request_error", "path must start with /v1/")
return
}
var username string
if len(s.options.Users) > 0 {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
return
}
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
if clientToken == authHeader {
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
return
}
var ok bool
username, ok = s.userManager.Authenticate(clientToken)
if !ok {
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
return
}
}
sessionID := r.Header.Get("session_id")
// Resolve credential provider and user config
var provider credentialProvider
var userConfig *option.OCMUser
if len(s.options.Users) > 0 {
userConfig = s.userConfigMap[username]
var err error
provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
if err != nil {
s.logger.ErrorContext(ctx, "resolve credential: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
return
}
} else {
provider = noUserCredentialProvider(s.providers, s.legacyProvider, s.options)
}
if provider == nil {
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available")
return
}
provider.pollIfStale(s.ctx)
var credentialFilter func(credential) bool
if userConfig != nil && !userConfig.AllowExternalUsage {
credentialFilter = func(c credential) bool { return !c.isExternal() }
}
selectedCredential, isNew, err := provider.selectCredential(sessionID, credentialFilter)
if err != nil {
writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, err.Error()))
return
}
if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") {
s.handleWebSocket(ctx, w, r, path, username, sessionID, userConfig, provider, selectedCredential, credentialFilter, isNew)
return
}
if !selectedCredential.isExternal() && selectedCredential.ocmIsAPIKeyMode() {
// API key mode path handling
} else if !selectedCredential.isExternal() {
if path == "/v1/chat/completions" {
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
"chat completions endpoint is only available in API key mode")
return
}
}
shouldTrackUsage := selectedCredential.usageTrackerOrNil() != nil &&
(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses"))
canRetryRequest := len(provider.allCredentials()) > 1
// Read body for model extraction and retry buffer when JSON replay is useful.
var bodyBytes []byte
var requestModel string
var requestServiceTier string
if r.Body != nil && (shouldTrackUsage || canRetryRequest) {
mediaType, _, parseErr := mime.ParseMediaType(r.Header.Get("Content-Type"))
isJSONRequest := parseErr == nil && (mediaType == "application/json" || strings.HasSuffix(mediaType, "+json"))
if isJSONRequest {
bodyBytes, err = io.ReadAll(r.Body)
if err != nil {
s.logger.ErrorContext(ctx, "read request body: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body")
return
}
var request struct {
Model string `json:"model"`
ServiceTier string `json:"service_tier"`
}
if json.Unmarshal(bodyBytes, &request) == nil {
requestModel = request.Model
requestServiceTier = request.ServiceTier
}
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
}
}
if isNew {
logParts := []any{"assigned credential ", selectedCredential.tagName()}
if sessionID != "" {
logParts = append(logParts, " for session ", sessionID)
}
if username != "" {
logParts = append(logParts, " by user ", username)
}
if requestModel != "" {
logParts = append(logParts, ", model=", requestModel)
}
if requestServiceTier == "priority" {
logParts = append(logParts, ", fast")
}
s.logger.DebugContext(ctx, logParts...)
}
requestContext := selectedCredential.wrapRequestContext(r.Context())
defer func() {
requestContext.cancelRequest()
}()
proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if err != nil {
s.logger.ErrorContext(ctx, "create proxy request: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
return
}
response, err := selectedCredential.httpTransport().Do(proxyRequest)
if err != nil {
if r.Context().Err() != nil {
return
}
if requestContext.Err() != nil {
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "credential became unavailable while processing the request")
return
}
writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
return
}
requestContext.releaseCredentialInterrupt()
// Transparent 429 retry
for response.StatusCode == http.StatusTooManyRequests {
resetAt := parseOCMRateLimitResetFromHeaders(response.Header)
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, credentialFilter)
needsBodyReplay := r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodDelete
selectedCredential.updateStateFromHeaders(response.Header)
if (needsBodyReplay && bodyBytes == nil) || nextCredential == nil {
response.Body.Close()
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited")
return
}
response.Body.Close()
s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
requestContext.cancelRequest()
requestContext = nextCredential.wrapRequestContext(r.Context())
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if buildErr != nil {
s.logger.ErrorContext(ctx, "retry request: ", buildErr)
writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
return
}
retryResponse, retryErr := nextCredential.httpTransport().Do(retryRequest)
if retryErr != nil {
if r.Context().Err() != nil {
return
}
if requestContext.Err() != nil {
writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request")
return
}
s.logger.ErrorContext(ctx, "retry request: ", retryErr)
writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error())
return
}
requestContext.releaseCredentialInterrupt()
response = retryResponse
selectedCredential = nextCredential
}
defer response.Body.Close()
selectedCredential.updateStateFromHeaders(response.Header)
if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
body, _ := io.ReadAll(response.Body)
s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
go selectedCredential.pollUsage(s.ctx)
writeJSONError(w, r, http.StatusInternalServerError, "api_error",
"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
return
}
// Rewrite response headers for external users
if userConfig != nil && userConfig.ExternalCredential != "" {
s.rewriteResponseHeadersForExternalUser(response.Header, userConfig)
}
for key, values := range response.Header {
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) {
w.Header()[key] = values
}
}
w.WriteHeader(response.StatusCode)
usageTracker := selectedCredential.usageTrackerOrNil()
if usageTracker != nil && response.StatusCode == http.StatusOK &&
(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) {
s.handleResponseWithTracking(ctx, w, response, usageTracker, path, requestModel, username)
} else {
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
if err == nil && mediaType != "text/event-stream" {
_, _ = io.Copy(w, response.Body)
return
}
flusher, ok := w.(http.Flusher)
if !ok {
s.logger.ErrorContext(ctx, "streaming not supported")
return
}
buffer := make([]byte, buf.BufferSize)
for {
n, err := response.Body.Read(buffer)
if n > 0 {
_, writeError := w.Write(buffer[:n])
if writeError != nil {
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
return
}
flusher.Flush()
}
if err != nil {
return
}
}
}
}
func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, path string, requestModel string, username string) {
isChatCompletions := path == "/v1/chat/completions"
weeklyCycleHint := extractWeeklyCycleHint(response.Header)
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
isStreaming := err == nil && mediaType == "text/event-stream"
if !isStreaming && !isChatCompletions && response.Header.Get("Content-Type") == "" {
isStreaming = true
}
if !isStreaming {
bodyBytes, err := io.ReadAll(response.Body)
if err != nil {
s.logger.ErrorContext(ctx, "read response body: ", err)
return
}
var responseModel, serviceTier string
var inputTokens, outputTokens, cachedTokens int64
if isChatCompletions {
var chatCompletion openai.ChatCompletion
if json.Unmarshal(bodyBytes, &chatCompletion) == nil {
responseModel = chatCompletion.Model
serviceTier = string(chatCompletion.ServiceTier)
inputTokens = chatCompletion.Usage.PromptTokens
outputTokens = chatCompletion.Usage.CompletionTokens
cachedTokens = chatCompletion.Usage.PromptTokensDetails.CachedTokens
}
} else {
var responsesResponse responses.Response
if json.Unmarshal(bodyBytes, &responsesResponse) == nil {
responseModel = string(responsesResponse.Model)
serviceTier = string(responsesResponse.ServiceTier)
inputTokens = responsesResponse.Usage.InputTokens
outputTokens = responsesResponse.Usage.OutputTokens
cachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
}
}
if inputTokens > 0 || outputTokens > 0 {
if responseModel == "" {
responseModel = requestModel
}
if responseModel != "" {
contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens)
usageTracker.AddUsageWithCycleHint(
responseModel,
contextWindow,
inputTokens,
outputTokens,
cachedTokens,
serviceTier,
username,
time.Now(),
weeklyCycleHint,
)
}
}
_, _ = writer.Write(bodyBytes)
return
}
flusher, ok := writer.(http.Flusher)
if !ok {
s.logger.ErrorContext(ctx, "streaming not supported")
return
}
var inputTokens, outputTokens, cachedTokens int64
var responseModel, serviceTier string
buffer := make([]byte, buf.BufferSize)
var leftover []byte
for {
n, err := response.Body.Read(buffer)
if n > 0 {
data := append(leftover, buffer[:n]...)
lines := bytes.Split(data, []byte("\n"))
if err == nil {
leftover = lines[len(lines)-1]
lines = lines[:len(lines)-1]
} else {
leftover = nil
}
for _, line := range lines {
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
if bytes.HasPrefix(line, []byte("data: ")) {
eventData := bytes.TrimPrefix(line, []byte("data: "))
if bytes.Equal(eventData, []byte("[DONE]")) {
continue
}
if isChatCompletions {
var chatChunk openai.ChatCompletionChunk
if json.Unmarshal(eventData, &chatChunk) == nil {
if chatChunk.Model != "" {
responseModel = chatChunk.Model
}
if chatChunk.ServiceTier != "" {
serviceTier = string(chatChunk.ServiceTier)
}
if chatChunk.Usage.PromptTokens > 0 {
inputTokens = chatChunk.Usage.PromptTokens
cachedTokens = chatChunk.Usage.PromptTokensDetails.CachedTokens
}
if chatChunk.Usage.CompletionTokens > 0 {
outputTokens = chatChunk.Usage.CompletionTokens
}
}
} else {
var streamEvent responses.ResponseStreamEventUnion
if json.Unmarshal(eventData, &streamEvent) == nil {
if streamEvent.Type == "response.completed" {
completedEvent := streamEvent.AsResponseCompleted()
if string(completedEvent.Response.Model) != "" {
responseModel = string(completedEvent.Response.Model)
}
if completedEvent.Response.ServiceTier != "" {
serviceTier = string(completedEvent.Response.ServiceTier)
}
if completedEvent.Response.Usage.InputTokens > 0 {
inputTokens = completedEvent.Response.Usage.InputTokens
cachedTokens = completedEvent.Response.Usage.InputTokensDetails.CachedTokens
}
if completedEvent.Response.Usage.OutputTokens > 0 {
outputTokens = completedEvent.Response.Usage.OutputTokens
}
}
}
}
}
}
_, writeError := writer.Write(buffer[:n])
if writeError != nil {
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
return
}
flusher.Flush()
}
if err != nil {
if responseModel == "" {
responseModel = requestModel
}
if inputTokens > 0 || outputTokens > 0 {
if responseModel != "" {
contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens)
usageTracker.AddUsageWithCycleHint(
responseModel,
contextWindow,
inputTokens,
outputTokens,
cachedTokens,
serviceTier,
username,
time.Now(),
weeklyCycleHint,
)
}
}
return
}
}
}
func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed")
return
}
if len(s.options.Users) == 0 {
writeJSONError(w, r, http.StatusForbidden, "authentication_error", "status endpoint requires user authentication")
return
}
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
return
}
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
if clientToken == authHeader {
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
return
}
username, ok := s.userManager.Authenticate(clientToken)
if !ok {
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
return
}
userConfig := s.userConfigMap[username]
if userConfig == nil {
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "user config not found")
return
}
provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
if err != nil {
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
return
}
provider.pollIfStale(r.Context())
avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]float64{
"five_hour_utilization": avgFiveHour,
"weekly_utilization": avgWeekly,
"plan_weight": totalWeight,
})
}
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) (float64, float64, float64) {
var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64
for _, cred := range provider.allCredentials() {
if !cred.isAvailable() {
continue
}
if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential {
continue
}
if !userConfig.AllowExternalUsage && cred.isExternal() {
continue
}
weight := cred.planWeight()
remaining5h := cred.fiveHourCap() - cred.fiveHourUtilization()
if remaining5h < 0 {
remaining5h = 0
}
remainingWeekly := cred.weeklyCap() - cred.weeklyUtilization()
if remainingWeekly < 0 {
remainingWeekly = 0
}
totalWeightedRemaining5h += remaining5h * weight
totalWeightedRemainingWeekly += remainingWeekly * weight
totalWeight += weight
}
if totalWeight == 0 {
return 100, 100, 0
}
return 100 - totalWeightedRemaining5h/totalWeight,
100 - totalWeightedRemainingWeekly/totalWeight,
totalWeight
}
func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.OCMUser) {
provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, userConfig.Name)
if err != nil {
return
}
avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
if activeLimitIdentifier == "" {
activeLimitIdentifier = "codex"
}
headers.Set("x-"+activeLimitIdentifier+"-primary-used-percent", strconv.FormatFloat(avgFiveHour, 'f', 2, 64))
headers.Set("x-"+activeLimitIdentifier+"-secondary-used-percent", strconv.FormatFloat(avgWeekly, 'f', 2, 64))
if totalWeight > 0 {
headers.Set("X-OCM-Plan-Weight", strconv.FormatFloat(totalWeight, 'f', -1, 64))
}
}
func (s *Service) InterfaceUpdated() {
for _, cred := range s.allCredentials {
extCred, ok := cred.(*externalCredential)
for _, credential := range s.allCredentials {
external, ok := credential.(*externalCredential)
if !ok {
continue
}
if extCred.reverse && extCred.connectorURL != nil {
extCred.reverseService = s
extCred.resetReverseContext()
go extCred.connectorLoop()
if external.reverse && external.connectorURL != nil {
external.reverseService = s
external.resetReverseContext()
go external.connectorLoop()
}
}
}
func (s *Service) Close() error {
s.statusObserver.Close()
webSocketSessions := s.startWebSocketShutdown()
err := common.Close(
@@ -924,16 +344,16 @@ func (s *Service) Close() error {
}
s.webSocketGroup.Wait()
for _, cred := range s.allCredentials {
cred.close()
for _, credential := range s.allCredentials {
credential.close()
}
return err
}
func (s *Service) registerWebSocketSession(session *webSocketSession) bool {
s.webSocketMutex.Lock()
defer s.webSocketMutex.Unlock()
s.webSocketAccess.Lock()
defer s.webSocketAccess.Unlock()
if s.shuttingDown {
return false
@@ -945,12 +365,12 @@ func (s *Service) registerWebSocketSession(session *webSocketSession) bool {
}
func (s *Service) unregisterWebSocketSession(session *webSocketSession) {
s.webSocketMutex.Lock()
s.webSocketAccess.Lock()
_, loaded := s.webSocketConns[session]
if loaded {
delete(s.webSocketConns, session)
}
s.webSocketMutex.Unlock()
s.webSocketAccess.Unlock()
if loaded {
s.webSocketGroup.Done()
@@ -958,28 +378,28 @@ func (s *Service) unregisterWebSocketSession(session *webSocketSession) {
}
func (s *Service) isShuttingDown() bool {
s.webSocketMutex.Lock()
defer s.webSocketMutex.Unlock()
s.webSocketAccess.Lock()
defer s.webSocketAccess.Unlock()
return s.shuttingDown
}
func (s *Service) interruptWebSocketSessionsForCredential(tag string) {
s.webSocketMutex.Lock()
s.webSocketAccess.Lock()
var toClose []*webSocketSession
for session := range s.webSocketConns {
if session.credentialTag == tag {
toClose = append(toClose, session)
}
}
s.webSocketMutex.Unlock()
s.webSocketAccess.Unlock()
for _, session := range toClose {
session.Close()
}
}
func (s *Service) startWebSocketShutdown() []*webSocketSession {
s.webSocketMutex.Lock()
defer s.webSocketMutex.Unlock()
s.webSocketAccess.Lock()
defer s.webSocketAccess.Unlock()
s.shuttingDown = true

View File

@@ -0,0 +1,512 @@
package ocm
import (
"bytes"
"context"
"encoding/json"
"io"
"mime"
"net/http"
"strconv"
"strings"
"time"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/openai/openai-go/v3"
"github.com/openai/openai-go/v3/responses"
)
func weeklyCycleHintForLimit(headers http.Header, limitIdentifier string) *WeeklyCycleHint {
normalizedLimitIdentifier := normalizeRateLimitIdentifier(limitIdentifier)
if normalizedLimitIdentifier == "" {
return nil
}
windowHeader := "x-" + normalizedLimitIdentifier + "-secondary-window-minutes"
resetHeader := "x-" + normalizedLimitIdentifier + "-secondary-reset-at"
windowMinutes, hasWindowMinutes := parseInt64Header(headers, windowHeader)
resetAtUnix, hasResetAt := parseInt64Header(headers, resetHeader)
if !hasWindowMinutes || !hasResetAt || windowMinutes <= 0 || resetAtUnix <= 0 {
return nil
}
return &WeeklyCycleHint{
WindowMinutes: windowMinutes,
ResetAt: time.Unix(resetAtUnix, 0).UTC(),
}
}
func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint {
activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
if activeLimitIdentifier != "" {
if activeHint := weeklyCycleHintForLimit(headers, activeLimitIdentifier); activeHint != nil {
return activeHint
}
}
return weeklyCycleHintForLimit(headers, "codex")
}
func (s *Service) resolveCredentialProvider(username string) (credentialProvider, error) {
if len(s.options.Users) > 0 {
return credentialForUser(s.userConfigMap, s.providers, username)
}
provider := s.providers[s.options.Credentials[0].Tag]
if provider == nil {
return nil, E.New("no credential available")
}
return provider, nil
}
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := log.ContextWithNewID(r.Context())
if r.URL.Path == "/ocm/v1/status" {
s.handleStatusEndpoint(w, r)
return
}
if r.URL.Path == "/ocm/v1/reverse" {
s.handleReverseConnect(ctx, w, r)
return
}
path := r.URL.Path
if !strings.HasPrefix(path, "/v1/") {
writeJSONError(w, r, http.StatusNotFound, "invalid_request_error", "path must start with /v1/")
return
}
if r.Header.Get("X-Api-Key") != "" || r.Header.Get("Api-Key") != "" {
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
"API key authentication is not supported; use Authorization: Bearer with an OCM user token")
return
}
var username string
if len(s.options.Users) > 0 {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
return
}
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
if clientToken == authHeader {
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
return
}
var ok bool
username, ok = s.userManager.Authenticate(clientToken)
if !ok {
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
return
}
}
sessionID := r.Header.Get("session_id")
// Resolve credential provider and user config
var provider credentialProvider
var userConfig *option.OCMUser
if len(s.options.Users) > 0 {
userConfig = s.userConfigMap[username]
var err error
provider, err = credentialForUser(s.userConfigMap, s.providers, username)
if err != nil {
s.logger.ErrorContext(ctx, "resolve credential: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
return
}
} else {
provider = s.providers[s.options.Credentials[0].Tag]
}
if provider == nil {
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available")
return
}
provider.pollIfStale()
if userConfig != nil && userConfig.ExternalCredential != "" {
for _, credential := range s.allCredentials {
if credential.tagName() == userConfig.ExternalCredential && !credential.isUsable() {
credential.pollUsage()
break
}
}
}
selection := credentialSelectionForUser(userConfig)
selectedCredential, isNew, err := provider.selectCredential(sessionID, selection)
if err != nil {
writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, err.Error()))
return
}
if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") {
s.handleWebSocket(ctx, w, r, path, username, sessionID, userConfig, provider, selectedCredential, selection, isNew)
return
}
if !selectedCredential.isExternal() && selectedCredential.ocmIsAPIKeyMode() {
// API key mode path handling
} else if !selectedCredential.isExternal() {
if path == "/v1/chat/completions" {
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
"chat completions endpoint is only available in API key mode")
return
}
}
shouldTrackUsage := selectedCredential.usageTrackerOrNil() != nil &&
(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses"))
canRetryRequest := len(provider.allCredentials()) > 1
// Read body for model extraction and retry buffer when JSON replay is useful.
var bodyBytes []byte
var requestMetadata requestLogMetadata
var requestModel string
if r.Body != nil && (isNew || shouldTrackUsage || canRetryRequest) {
mediaType, _, parseErr := mime.ParseMediaType(r.Header.Get("Content-Type"))
isJSONRequest := parseErr == nil && (mediaType == "application/json" || strings.HasSuffix(mediaType, "+json"))
if isJSONRequest {
bodyBytes, err = io.ReadAll(r.Body)
if err != nil {
s.logger.ErrorContext(ctx, "read request body: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body")
return
}
requestMetadata = parseRequestLogMetadata(path, bodyBytes)
requestModel = requestMetadata.Model
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
}
}
if isNew {
s.logger.DebugContext(ctx, buildAssignedCredentialLogParts(selectedCredential.tagName(), sessionID, username, requestMetadata)...)
}
requestContext := selectedCredential.wrapRequestContext(ctx)
{
currentRequestContext := requestContext
requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() {
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
}))
}
defer func() {
requestContext.cancelRequest()
}()
proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if err != nil {
s.logger.ErrorContext(ctx, "create proxy request: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
return
}
response, err := selectedCredential.httpClient().Do(proxyRequest)
if err != nil {
if r.Context().Err() != nil {
return
}
if requestContext.Err() != nil {
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable while processing the request")
return
}
writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
return
}
requestContext.releaseCredentialInterrupt()
// Transparent 429 retry
for response.StatusCode == http.StatusTooManyRequests {
resetAt := parseOCMRateLimitResetFromHeaders(response.Header)
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection)
needsBodyReplay := r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodDelete
selectedCredential.updateStateFromHeaders(response.Header)
if (needsBodyReplay && bodyBytes == nil) || nextCredential == nil {
response.Body.Close()
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited")
return
}
response.Body.Close()
s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
requestContext.cancelRequest()
requestContext = nextCredential.wrapRequestContext(ctx)
{
currentRequestContext := requestContext
requestContext.addInterruptLink(provider.linkProviderInterrupt(nextCredential, selection, func() {
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
}))
}
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if buildErr != nil {
s.logger.ErrorContext(ctx, "retry request: ", buildErr)
writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
return
}
retryResponse, retryErr := nextCredential.httpClient().Do(retryRequest)
if retryErr != nil {
if r.Context().Err() != nil {
return
}
if requestContext.Err() != nil {
writeCredentialUnavailableError(w, r, provider, nextCredential, selection, "credential became unavailable while retrying the request")
return
}
s.logger.ErrorContext(ctx, "retry request: ", retryErr)
writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error())
return
}
requestContext.releaseCredentialInterrupt()
response = retryResponse
selectedCredential = nextCredential
}
defer response.Body.Close()
selectedCredential.updateStateFromHeaders(response.Header)
if response.StatusCode == http.StatusBadRequest {
if selectedCredential.isExternal() {
selectedCredential.markUpstreamRejected()
} else {
provider.pollCredentialIfStale(selectedCredential)
}
s.logger.ErrorContext(ctx, "upstream rejected from ", selectedCredential.tagName(), ": status ", response.StatusCode)
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "upstream rejected credential")
return
}
if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
body, _ := io.ReadAll(response.Body)
s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
writeJSONError(w, r, http.StatusInternalServerError, "api_error",
"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
return
}
s.rewriteResponseHeaders(response.Header, provider, userConfig)
for key, values := range response.Header {
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) {
w.Header()[key] = values
}
}
w.WriteHeader(response.StatusCode)
usageTracker := selectedCredential.usageTrackerOrNil()
if usageTracker != nil && response.StatusCode == http.StatusOK &&
(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) {
s.handleResponseWithTracking(ctx, w, response, usageTracker, path, requestModel, username)
} else {
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
if err == nil && mediaType != "text/event-stream" {
_, _ = io.Copy(w, response.Body)
return
}
flusher, ok := w.(http.Flusher)
if !ok {
s.logger.ErrorContext(ctx, "streaming not supported")
return
}
buffer := make([]byte, buf.BufferSize)
for {
n, err := response.Body.Read(buffer)
if n > 0 {
_, writeError := w.Write(buffer[:n])
if writeError != nil {
if E.IsClosedOrCanceled(writeError) {
return
}
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
return
}
flusher.Flush()
}
if err != nil {
return
}
}
}
}
func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, path string, requestModel string, username string) {
isChatCompletions := path == "/v1/chat/completions"
weeklyCycleHint := extractWeeklyCycleHint(response.Header)
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
isStreaming := err == nil && mediaType == "text/event-stream"
if !isStreaming && !isChatCompletions && response.Header.Get("Content-Type") == "" {
isStreaming = true
}
if !isStreaming {
bodyBytes, err := io.ReadAll(response.Body)
if err != nil {
s.logger.ErrorContext(ctx, "read response body: ", err)
return
}
var responseModel, serviceTier string
var inputTokens, outputTokens, cachedTokens int64
if isChatCompletions {
var chatCompletion openai.ChatCompletion
if json.Unmarshal(bodyBytes, &chatCompletion) == nil {
responseModel = chatCompletion.Model
serviceTier = string(chatCompletion.ServiceTier)
inputTokens = chatCompletion.Usage.PromptTokens
outputTokens = chatCompletion.Usage.CompletionTokens
cachedTokens = chatCompletion.Usage.PromptTokensDetails.CachedTokens
}
} else {
var responsesResponse responses.Response
if json.Unmarshal(bodyBytes, &responsesResponse) == nil {
responseModel = string(responsesResponse.Model)
serviceTier = string(responsesResponse.ServiceTier)
inputTokens = responsesResponse.Usage.InputTokens
outputTokens = responsesResponse.Usage.OutputTokens
cachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
}
}
if inputTokens > 0 || outputTokens > 0 {
if responseModel == "" {
responseModel = requestModel
}
if responseModel != "" {
contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens)
usageTracker.AddUsageWithCycleHint(
responseModel,
contextWindow,
inputTokens,
outputTokens,
cachedTokens,
serviceTier,
username,
time.Now(),
weeklyCycleHint,
)
}
}
_, _ = writer.Write(bodyBytes)
return
}
flusher, ok := writer.(http.Flusher)
if !ok {
s.logger.ErrorContext(ctx, "streaming not supported")
return
}
var inputTokens, outputTokens, cachedTokens int64
var responseModel, serviceTier string
buffer := make([]byte, buf.BufferSize)
var leftover []byte
for {
n, err := response.Body.Read(buffer)
if n > 0 {
data := append(leftover, buffer[:n]...)
lines := bytes.Split(data, []byte("\n"))
if err == nil {
leftover = lines[len(lines)-1]
lines = lines[:len(lines)-1]
} else {
leftover = nil
}
for _, line := range lines {
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
if bytes.HasPrefix(line, []byte("data: ")) {
eventData := bytes.TrimPrefix(line, []byte("data: "))
if bytes.Equal(eventData, []byte("[DONE]")) {
continue
}
if isChatCompletions {
var chatChunk openai.ChatCompletionChunk
if json.Unmarshal(eventData, &chatChunk) == nil {
if chatChunk.Model != "" {
responseModel = chatChunk.Model
}
if chatChunk.ServiceTier != "" {
serviceTier = string(chatChunk.ServiceTier)
}
if chatChunk.Usage.PromptTokens > 0 {
inputTokens = chatChunk.Usage.PromptTokens
cachedTokens = chatChunk.Usage.PromptTokensDetails.CachedTokens
}
if chatChunk.Usage.CompletionTokens > 0 {
outputTokens = chatChunk.Usage.CompletionTokens
}
}
} else {
var streamEvent responses.ResponseStreamEventUnion
if json.Unmarshal(eventData, &streamEvent) == nil {
if streamEvent.Type == "response.completed" {
completedEvent := streamEvent.AsResponseCompleted()
if string(completedEvent.Response.Model) != "" {
responseModel = string(completedEvent.Response.Model)
}
if completedEvent.Response.ServiceTier != "" {
serviceTier = string(completedEvent.Response.ServiceTier)
}
if completedEvent.Response.Usage.InputTokens > 0 {
inputTokens = completedEvent.Response.Usage.InputTokens
cachedTokens = completedEvent.Response.Usage.InputTokensDetails.CachedTokens
}
if completedEvent.Response.Usage.OutputTokens > 0 {
outputTokens = completedEvent.Response.Usage.OutputTokens
}
}
}
}
}
}
_, writeError := writer.Write(buffer[:n])
if writeError != nil {
if E.IsClosedOrCanceled(writeError) {
return
}
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
return
}
flusher.Flush()
}
if err != nil {
if responseModel == "" {
responseModel = requestModel
}
if inputTokens > 0 || outputTokens > 0 {
if responseModel != "" {
contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens)
usageTracker.AddUsageWithCycleHint(
responseModel,
contextWindow,
inputTokens,
outputTokens,
cachedTokens,
serviceTier,
username,
time.Now(),
weeklyCycleHint,
)
}
}
return
}
}
}

View File

@@ -0,0 +1,78 @@
package ocm
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"
)
func TestWriteJSONErrorIncludesSDKErrorFields(t *testing.T) {
t.Parallel()
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, "/v1/responses", nil)
writeJSONErrorWithCode(recorder, request, http.StatusBadRequest, "invalid_request_error", "bad_thing", "broken")
if recorder.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", recorder.Code)
}
var body struct {
Error map[string]any `json:"error"`
}
if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil {
t.Fatal(err)
}
for _, key := range []string{"type", "message", "code", "param"} {
if _, exists := body.Error[key]; !exists {
t.Fatalf("expected error.%s to be present, got %#v", key, body.Error)
}
}
if body.Error["type"] != "invalid_request_error" {
t.Fatalf("expected invalid_request_error type, got %#v", body.Error["type"])
}
if body.Error["message"] != "broken" {
t.Fatalf("expected broken message, got %#v", body.Error["message"])
}
if body.Error["code"] != "bad_thing" {
t.Fatalf("expected bad_thing code, got %#v", body.Error["code"])
}
if body.Error["param"] != "" {
t.Fatalf("expected empty param, got %#v", body.Error["param"])
}
}
func TestHandleWebSocketErrorEventRateLimitTracksHeadersAndReset(t *testing.T) {
t.Parallel()
credential := &testCredential{availability: availabilityStatus{State: availabilityStateUsable}}
service := &Service{}
resetAt := time.Now().Add(time.Minute).Unix()
service.handleWebSocketErrorEvent([]byte(`{
"type":"error",
"status_code":429,
"headers":{
"x-codex-active-limit":"codex",
"x-codex-primary-reset-at":"`+strconv.FormatInt(resetAt, 10)+`"
},
"error":{
"type":"rate_limit_error",
"code":"rate_limited",
"message":"limit hit",
"param":""
}
}`), credential)
if credential.lastHeaders.Get("x-codex-active-limit") != "codex" {
t.Fatalf("expected headers to be forwarded, got %#v", credential.lastHeaders)
}
if credential.rateLimitedAt.Unix() != resetAt {
t.Fatalf("expected reset %d, got %d", resetAt, credential.rateLimitedAt.Unix())
}
}

View File

@@ -0,0 +1,347 @@
package ocm
import (
"bytes"
"encoding/json"
"net/http"
"reflect"
"strconv"
"strings"
"time"
"github.com/sagernet/sing-box/option"
)
type statusPayload struct {
FiveHourUtilization float64 `json:"five_hour_utilization"`
FiveHourReset int64 `json:"five_hour_reset"`
WeeklyUtilization float64 `json:"weekly_utilization"`
WeeklyReset int64 `json:"weekly_reset"`
PlanWeight float64 `json:"plan_weight"`
}
type aggregatedStatus struct {
fiveHourUtilization float64
weeklyUtilization float64
totalWeight float64
fiveHourReset time.Time
weeklyReset time.Time
availability availabilityStatus
}
func resetToEpoch(t time.Time) int64 {
if t.IsZero() {
return 0
}
return t.Unix()
}
func (s aggregatedStatus) equal(other aggregatedStatus) bool {
return reflect.DeepEqual(s.toPayload(), other.toPayload())
}
func (s aggregatedStatus) toPayload() statusPayload {
return statusPayload{
FiveHourUtilization: s.fiveHourUtilization,
FiveHourReset: resetToEpoch(s.fiveHourReset),
WeeklyUtilization: s.weeklyUtilization,
WeeklyReset: resetToEpoch(s.weeklyReset),
PlanWeight: s.totalWeight,
}
}
type aggregateInput struct {
availability availabilityStatus
}
func aggregateAvailability(inputs []aggregateInput) availabilityStatus {
if len(inputs) == 0 {
return availabilityStatus{
State: availabilityStateUnavailable,
Reason: availabilityReasonNoCredentials,
}
}
var earliestRateLimited time.Time
var hasRateLimited bool
var bestBlocked availabilityStatus
var hasBlocked bool
var hasUnavailable bool
blockedPriority := func(reason availabilityReason) int {
switch reason {
case availabilityReasonConnectionLimit:
return 3
case availabilityReasonPollFailed:
return 2
case availabilityReasonUpstreamRejected:
return 1
default:
return 0
}
}
for _, input := range inputs {
availability := input.availability.normalized()
switch availability.State {
case availabilityStateUsable:
return availabilityStatus{State: availabilityStateUsable}
case availabilityStateRateLimited:
hasRateLimited = true
if !availability.ResetAt.IsZero() && (earliestRateLimited.IsZero() || availability.ResetAt.Before(earliestRateLimited)) {
earliestRateLimited = availability.ResetAt
}
case availabilityStateTemporarilyBlocked:
if !hasBlocked || blockedPriority(availability.Reason) > blockedPriority(bestBlocked.Reason) {
bestBlocked = availability
hasBlocked = true
}
if hasBlocked && !availability.ResetAt.IsZero() && (bestBlocked.ResetAt.IsZero() || availability.ResetAt.Before(bestBlocked.ResetAt)) {
bestBlocked.ResetAt = availability.ResetAt
}
case availabilityStateUnavailable:
hasUnavailable = true
}
}
if hasRateLimited {
return availabilityStatus{
State: availabilityStateRateLimited,
Reason: availabilityReasonHardRateLimit,
ResetAt: earliestRateLimited,
}
}
if hasBlocked {
return bestBlocked
}
if hasUnavailable {
return availabilityStatus{
State: availabilityStateUnavailable,
Reason: availabilityReasonUnknown,
}
}
return availabilityStatus{
State: availabilityStateUnknown,
Reason: availabilityReasonUnknown,
}
}
func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed")
return
}
var provider credentialProvider
var userConfig *option.OCMUser
if len(s.options.Users) > 0 {
if r.Header.Get("X-Api-Key") != "" || r.Header.Get("Api-Key") != "" {
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
"API key authentication is not supported; use Authorization: Bearer with an OCM user token")
return
}
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
return
}
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
if clientToken == authHeader {
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
return
}
username, ok := s.userManager.Authenticate(clientToken)
if !ok {
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
return
}
userConfig = s.userConfigMap[username]
var err error
provider, err = credentialForUser(s.userConfigMap, s.providers, username)
if err != nil {
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
return
}
} else {
provider = s.providers[s.options.Credentials[0].Tag]
}
if provider == nil {
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available")
return
}
if r.URL.Query().Get("watch") == "true" {
s.handleStatusStream(w, r, provider, userConfig)
return
}
provider.pollIfStale()
status := s.computeAggregatedUtilization(provider, userConfig)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(status.toPayload())
}
func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, provider credentialProvider, userConfig *option.OCMUser) {
flusher, ok := w.(http.Flusher)
if !ok {
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "streaming not supported")
return
}
subscription, done, err := s.statusObserver.Subscribe()
if err != nil {
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "service closing")
return
}
defer s.statusObserver.UnSubscribe(subscription)
provider.pollIfStale()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
last := s.computeAggregatedUtilization(provider, userConfig)
buf := &bytes.Buffer{}
json.NewEncoder(buf).Encode(last.toPayload())
_, writeErr := w.Write(buf.Bytes())
if writeErr != nil {
return
}
flusher.Flush()
for {
select {
case <-r.Context().Done():
return
case <-done:
return
case <-subscription:
for {
select {
case <-subscription:
default:
goto drained
}
}
drained:
current := s.computeAggregatedUtilization(provider, userConfig)
if current.equal(last) {
continue
}
last = current
buf.Reset()
json.NewEncoder(buf).Encode(current.toPayload())
_, writeErr = w.Write(buf.Bytes())
if writeErr != nil {
return
}
flusher.Flush()
}
}
}
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) aggregatedStatus {
inputs := make([]aggregateInput, 0, len(provider.allCredentials()))
var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64
now := time.Now()
var totalWeightedHoursUntil5hReset, total5hResetWeight float64
var totalWeightedHoursUntilWeeklyReset, totalWeeklyResetWeight float64
var hasSnapshotData bool
for _, credential := range provider.allCredentials() {
if userConfig != nil && userConfig.ExternalCredential != "" && credential.tagName() == userConfig.ExternalCredential {
continue
}
if userConfig != nil && !userConfig.AllowExternalUsage && credential.isExternal() {
continue
}
inputs = append(inputs, aggregateInput{
availability: credential.availabilityStatus(),
})
if !credential.hasSnapshotData() {
continue
}
hasSnapshotData = true
weight := credential.planWeight()
remaining5h := credential.fiveHourCap() - credential.fiveHourUtilization()
if remaining5h < 0 {
remaining5h = 0
}
remainingWeekly := credential.weeklyCap() - credential.weeklyUtilization()
if remainingWeekly < 0 {
remainingWeekly = 0
}
totalWeightedRemaining5h += remaining5h * weight
totalWeightedRemainingWeekly += remainingWeekly * weight
totalWeight += weight
fiveHourReset := credential.fiveHourResetTime()
if !fiveHourReset.IsZero() {
hours := fiveHourReset.Sub(now).Hours()
if hours > 0 {
totalWeightedHoursUntil5hReset += hours * weight
total5hResetWeight += weight
}
}
weeklyReset := credential.weeklyResetTime()
if !weeklyReset.IsZero() {
hours := weeklyReset.Sub(now).Hours()
if hours > 0 {
totalWeightedHoursUntilWeeklyReset += hours * weight
totalWeeklyResetWeight += weight
}
}
}
availability := aggregateAvailability(inputs)
if totalWeight == 0 {
result := aggregatedStatus{availability: availability}
if !hasSnapshotData {
result.fiveHourUtilization = 100
result.weeklyUtilization = 100
}
return result
}
result := aggregatedStatus{
fiveHourUtilization: 100 - totalWeightedRemaining5h/totalWeight,
weeklyUtilization: 100 - totalWeightedRemainingWeekly/totalWeight,
totalWeight: totalWeight,
availability: availability,
}
if total5hResetWeight > 0 {
avgHours := totalWeightedHoursUntil5hReset / total5hResetWeight
result.fiveHourReset = now.Add(time.Duration(avgHours * float64(time.Hour)))
}
if totalWeeklyResetWeight > 0 {
avgHours := totalWeightedHoursUntilWeeklyReset / totalWeeklyResetWeight
result.weeklyReset = now.Add(time.Duration(avgHours * float64(time.Hour)))
}
return result
}
func (s *Service) rewriteResponseHeaders(headers http.Header, provider credentialProvider, userConfig *option.OCMUser) {
for key := range headers {
lowerKey := strings.ToLower(key)
if lowerKey == "x-codex-active-limit" ||
strings.HasSuffix(lowerKey, "-primary-used-percent") ||
strings.HasSuffix(lowerKey, "-primary-window-minutes") ||
strings.HasSuffix(lowerKey, "-primary-reset-at") ||
strings.HasSuffix(lowerKey, "-secondary-used-percent") ||
strings.HasSuffix(lowerKey, "-secondary-window-minutes") ||
strings.HasSuffix(lowerKey, "-secondary-reset-at") ||
strings.HasSuffix(lowerKey, "-limit-name") ||
strings.HasPrefix(lowerKey, "x-codex-credits-") {
headers.Del(key)
}
}
status := s.computeAggregatedUtilization(provider, userConfig)
headers.Set("x-codex-primary-used-percent", strconv.FormatFloat(status.fiveHourUtilization, 'f', 2, 64))
headers.Set("x-codex-secondary-used-percent", strconv.FormatFloat(status.weeklyUtilization, 'f', 2, 64))
if !status.fiveHourReset.IsZero() {
headers.Set("x-codex-primary-reset-at", strconv.FormatInt(status.fiveHourReset.Unix(), 10))
}
if !status.weeklyReset.IsZero() {
headers.Set("x-codex-secondary-reset-at", strconv.FormatInt(status.weeklyReset.Unix(), 10))
}
if status.totalWeight > 0 {
headers.Set("X-OCM-Plan-Weight", strconv.FormatFloat(status.totalWeight, 'f', -1, 64))
}
}

View File

@@ -0,0 +1,136 @@
package ocm
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/observable"
)
type testCredential struct {
tag string
external bool
available bool
usable bool
hasData bool
fiveHour float64
weekly float64
fiveHourCapV float64
weeklyCapV float64
weight float64
fiveReset time.Time
weeklyReset time.Time
availability availabilityStatus
lastHeaders http.Header
rateLimitedAt time.Time
}
func (c *testCredential) tagName() string { return c.tag }
func (c *testCredential) isAvailable() bool { return c.available }
func (c *testCredential) isUsable() bool { return c.usable }
func (c *testCredential) isExternal() bool { return c.external }
func (c *testCredential) hasSnapshotData() bool { return c.hasData }
func (c *testCredential) fiveHourUtilization() float64 { return c.fiveHour }
func (c *testCredential) weeklyUtilization() float64 { return c.weekly }
func (c *testCredential) fiveHourCap() float64 { return c.fiveHourCapV }
func (c *testCredential) weeklyCap() float64 { return c.weeklyCapV }
func (c *testCredential) planWeight() float64 { return c.weight }
func (c *testCredential) weeklyResetTime() time.Time { return c.weeklyReset }
func (c *testCredential) fiveHourResetTime() time.Time { return c.fiveReset }
func (c *testCredential) markRateLimited(resetAt time.Time) {
c.rateLimitedAt = resetAt
}
func (c *testCredential) markUpstreamRejected() {}
func (c *testCredential) markTemporarilyBlocked(reason availabilityReason, resetAt time.Time) {
c.availability = availabilityStatus{State: availabilityStateTemporarilyBlocked, Reason: reason, ResetAt: resetAt}
}
func (c *testCredential) availabilityStatus() availabilityStatus { return c.availability }
func (c *testCredential) earliestReset() time.Time { return c.fiveReset }
func (c *testCredential) unavailableError() error { return nil }
func (c *testCredential) getAccessToken() (string, error) { return "", nil }
func (c *testCredential) buildProxyRequest(context.Context, *http.Request, []byte, http.Header) (*http.Request, error) {
return nil, nil
}
func (c *testCredential) updateStateFromHeaders(headers http.Header) {
c.lastHeaders = headers.Clone()
}
func (c *testCredential) wrapRequestContext(context.Context) *credentialRequestContext { return nil }
func (c *testCredential) interruptConnections() {}
func (c *testCredential) setOnBecameUnusable(func()) {}
func (c *testCredential) setStatusSubscriber(*observable.Subscriber[struct{}]) {}
func (c *testCredential) start() error { return nil }
func (c *testCredential) pollUsage() {}
func (c *testCredential) lastUpdatedTime() time.Time { return time.Now() }
func (c *testCredential) pollBackoff(time.Duration) time.Duration { return 0 }
func (c *testCredential) usageTrackerOrNil() *AggregatedUsage { return nil }
func (c *testCredential) httpClient() *http.Client { return nil }
func (c *testCredential) close() {}
func (c *testCredential) ocmDialer() N.Dialer { return nil }
func (c *testCredential) ocmIsAPIKeyMode() bool { return false }
func (c *testCredential) ocmGetAccountID() string { return "" }
func (c *testCredential) ocmGetBaseURL() string { return "" }
type testProvider struct {
credentials []Credential
}
func (p *testProvider) selectCredential(string, credentialSelection) (Credential, bool, error) {
return nil, false, nil
}
func (p *testProvider) onRateLimited(string, Credential, time.Time, credentialSelection) Credential {
return nil
}
func (p *testProvider) linkProviderInterrupt(Credential, credentialSelection, func()) func() bool {
return func() bool { return true }
}
func (p *testProvider) pollIfStale() {}
func (p *testProvider) pollCredentialIfStale(Credential) {}
func (p *testProvider) allCredentials() []Credential { return p.credentials }
func (p *testProvider) close() {}
func TestHandleWebSocketErrorEventConnectionLimitDoesNotUseRateLimitPath(t *testing.T) {
t.Parallel()
credential := &testCredential{availability: availabilityStatus{State: availabilityStateUsable}}
service := &Service{}
service.handleWebSocketErrorEvent([]byte(`{"type":"error","status_code":400,"error":{"code":"websocket_connection_limit_reached"}}`), credential)
if credential.availability.State != availabilityStateTemporarilyBlocked || credential.availability.Reason != availabilityReasonConnectionLimit {
t.Fatalf("expected temporary connection limit block, got %#v", credential.availability)
}
}
func TestWriteCredentialUnavailableErrorReturns429ForRateLimitedCredentials(t *testing.T) {
t.Parallel()
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, "/v1/responses", nil)
provider := &testProvider{credentials: []Credential{
&testCredential{
tag: "a",
available: true,
usable: false,
hasData: true,
weight: 1,
availability: availabilityStatus{State: availabilityStateRateLimited, Reason: availabilityReasonHardRateLimit, ResetAt: time.Now().Add(time.Minute)},
},
}}
writeCredentialUnavailableError(recorder, request, provider, provider.credentials[0], credentialSelection{}, "all credentials rate-limited")
if recorder.Code != http.StatusTooManyRequests {
t.Fatalf("expected 429, got %d", recorder.Code)
}
var body map[string]map[string]string
if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil {
t.Fatal(err)
}
if body["error"]["type"] != "usage_limit_reached" {
t.Fatalf("expected usage_limit_reached type, got %#v", body)
}
}

View File

@@ -55,13 +55,13 @@ type CostCombination struct {
type AggregatedUsage struct {
LastUpdated time.Time `json:"last_updated"`
Combinations []CostCombination `json:"combinations"`
mutex sync.Mutex
access sync.Mutex
filePath string
logger log.ContextLogger
lastSaveTime time.Time
pendingSave bool
saveTimer *time.Timer
saveMutex sync.Mutex
saveAccess sync.Mutex
}
type UsageStatsJSON struct {
@@ -1035,8 +1035,8 @@ func deriveWeekStartUnix(cycleHint *WeeklyCycleHint) int64 {
}
func (u *AggregatedUsage) ToJSON() *AggregatedUsageJSON {
u.mutex.Lock()
defer u.mutex.Unlock()
u.access.Lock()
defer u.access.Unlock()
result := &AggregatedUsageJSON{
LastUpdated: u.LastUpdated,
@@ -1069,8 +1069,8 @@ func (u *AggregatedUsage) ToJSON() *AggregatedUsageJSON {
}
func (u *AggregatedUsage) Load() error {
u.mutex.Lock()
defer u.mutex.Unlock()
u.access.Lock()
defer u.access.Unlock()
u.LastUpdated = time.Time{}
u.Combinations = nil
@@ -1116,9 +1116,9 @@ func (u *AggregatedUsage) Save() error {
defer os.Remove(tmpFile)
err = os.Rename(tmpFile, u.filePath)
if err == nil {
u.saveMutex.Lock()
u.saveAccess.Lock()
u.lastSaveTime = time.Now()
u.saveMutex.Unlock()
u.saveAccess.Unlock()
}
return err
}
@@ -1140,15 +1140,15 @@ func (u *AggregatedUsage) AddUsageWithCycleHint(model string, contextWindow int,
observedAt = time.Now()
}
u.mutex.Lock()
defer u.mutex.Unlock()
u.access.Lock()
defer u.access.Unlock()
u.LastUpdated = observedAt
weekStartUnix := deriveWeekStartUnix(cycleHint)
addUsageToCombinations(&u.Combinations, model, normalizedServiceTier, contextWindow, weekStartUnix, user, inputTokens, outputTokens, cachedTokens)
go u.scheduleSave()
u.scheduleSave()
return nil
}
@@ -1156,8 +1156,8 @@ func (u *AggregatedUsage) AddUsageWithCycleHint(model string, contextWindow int,
func (u *AggregatedUsage) scheduleSave() {
const saveInterval = time.Minute
u.saveMutex.Lock()
defer u.saveMutex.Unlock()
u.saveAccess.Lock()
defer u.saveAccess.Unlock()
timeSinceLastSave := time.Since(u.lastSaveTime)
@@ -1174,9 +1174,9 @@ func (u *AggregatedUsage) scheduleSave() {
remainingTime := saveInterval - timeSinceLastSave
u.saveTimer = time.AfterFunc(remainingTime, func() {
u.saveMutex.Lock()
u.saveAccess.Lock()
u.pendingSave = false
u.saveMutex.Unlock()
u.saveAccess.Unlock()
u.saveAsync()
})
}
@@ -1191,8 +1191,8 @@ func (u *AggregatedUsage) saveAsync() {
}
func (u *AggregatedUsage) cancelPendingSave() {
u.saveMutex.Lock()
defer u.saveMutex.Unlock()
u.saveAccess.Lock()
defer u.saveAccess.Unlock()
if u.saveTimer != nil {
u.saveTimer.Stop()

View File

@@ -7,13 +7,13 @@ import (
)
type UserManager struct {
accessMutex sync.RWMutex
tokenMap map[string]string
access sync.RWMutex
tokenMap map[string]string
}
func (m *UserManager) UpdateUsers(users []option.OCMUser) {
m.accessMutex.Lock()
defer m.accessMutex.Unlock()
m.access.Lock()
defer m.access.Unlock()
tokenMap := make(map[string]string, len(users))
for _, user := range users {
tokenMap[user.Token] = user.Name
@@ -22,8 +22,8 @@ func (m *UserManager) UpdateUsers(users []option.OCMUser) {
}
func (m *UserManager) Authenticate(token string) (string, bool) {
m.accessMutex.RLock()
m.access.RLock()
username, found := m.tokenMap[token]
m.accessMutex.RUnlock()
m.access.RUnlock()
return username, found
}

View File

@@ -23,19 +23,87 @@ import (
"github.com/sagernet/ws/wsutil"
"github.com/openai/openai-go/v3/responses"
openaishared "github.com/openai/openai-go/v3/shared"
openaiconstant "github.com/openai/openai-go/v3/shared/constant"
)
type webSocketSession struct {
clientConn net.Conn
upstreamConn net.Conn
credentialTag string
closeOnce sync.Once
clientConn net.Conn
upstreamConn net.Conn
credentialTag string
releaseProviderInterrupt func()
closeOnce sync.Once
closed chan struct{}
}
func (s *webSocketSession) Close() {
s.closeOnce.Do(func() {
s.clientConn.Close()
s.upstreamConn.Close()
close(s.closed)
if s.releaseProviderInterrupt != nil {
s.releaseProviderInterrupt()
}
if s.clientConn != nil {
s.clientConn.Close()
}
if s.upstreamConn != nil {
s.upstreamConn.Close()
}
})
}
type webSocketResponseCreateRequest struct {
responses.ResponseNewParams
legacyReasoningEffortPayload
Type string `json:"type"`
Generate *bool `json:"generate"`
}
func (r *webSocketResponseCreateRequest) UnmarshalJSON(data []byte) error {
type requestEnvelope struct {
Type string `json:"type"`
Generate *bool `json:"generate"`
legacyReasoningEffortPayload
}
var envelope requestEnvelope
if err := json.Unmarshal(data, &envelope); err != nil {
return err
}
var params responses.ResponseNewParams
if err := json.Unmarshal(data, &params); err != nil {
return err
}
r.ResponseNewParams = params
r.legacyReasoningEffortPayload = envelope.legacyReasoningEffortPayload
r.Type = envelope.Type
r.Generate = envelope.Generate
return nil
}
func (r webSocketResponseCreateRequest) metadata() requestLogMetadata {
return requestLogMetadataFromResponsesRequest(r.ResponseNewParams, r.ReasoningEffort)
}
func parseWebSocketResponseCreateRequest(data []byte) (webSocketResponseCreateRequest, bool) {
var request webSocketResponseCreateRequest
if json.Unmarshal(data, &request) != nil {
return webSocketResponseCreateRequest{}, false
}
if request.Type != string(openaiconstant.ResponseCreate("").Default()) || request.Model == "" {
return webSocketResponseCreateRequest{}, false
}
return request, true
}
func (r webSocketResponseCreateRequest) isWarmup() bool {
return r.Generate != nil && !*r.Generate
}
func signalWebSocketReady(channel chan struct{}, once *sync.Once) {
once.Do(func() {
close(channel)
})
}
@@ -74,6 +142,8 @@ func isForwardableWebSocketRequestHeader(key string) bool {
switch {
case lowerKey == "authorization":
return false
case lowerKey == "x-api-key" || lowerKey == "api-key":
return false
case strings.HasPrefix(lowerKey, "sec-websocket-"):
return false
default:
@@ -90,18 +160,26 @@ func (s *Service) handleWebSocket(
sessionID string,
userConfig *option.OCMUser,
provider credentialProvider,
selectedCredential credential,
credentialFilter func(credential) bool,
selectedCredential Credential,
selection credentialSelection,
isNew bool,
) {
var (
err error
requestContext *credentialRequestContext
clientConn net.Conn
session *webSocketSession
upstreamConn net.Conn
upstreamBufferedReader *bufio.Reader
upstreamResponseHeaders http.Header
statusCode int
statusResponseBody string
)
defer func() {
if requestContext != nil {
requestContext.cancelRequest()
}
}()
for {
accessToken, accessErr := selectedCredential.getAccessToken()
@@ -179,22 +257,49 @@ func (s *Service) handleWebSocket(
},
}
upstreamConn, upstreamBufferedReader, _, err = upstreamDialer.Dial(s.ctx, upstreamURL)
requestContext = selectedCredential.wrapRequestContext(ctx)
{
currentRequestContext := requestContext
requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() {
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
if session != nil {
session.Close()
return
}
if clientConn != nil {
clientConn.Close()
}
if upstreamConn != nil {
upstreamConn.Close()
}
}))
}
upstreamConn, upstreamBufferedReader, _, err = upstreamDialer.Dial(requestContext, upstreamURL)
if err == nil {
break
}
requestContext.cancelRequest()
requestContext = nil
upstreamConn = nil
clientConn = nil
if statusCode == http.StatusTooManyRequests {
resetAt := parseOCMRateLimitResetFromHeaders(upstreamResponseHeaders)
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, credentialFilter)
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection)
selectedCredential.updateStateFromHeaders(upstreamResponseHeaders)
if nextCredential == nil {
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited")
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited")
return
}
s.logger.InfoContext(ctx, "retrying websocket with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
selectedCredential = nextCredential
continue
}
if statusCode == http.StatusBadRequest && selectedCredential.isExternal() {
selectedCredential.markUpstreamRejected()
s.logger.ErrorContext(ctx, "upstream rejected websocket from ", selectedCredential.tagName(), ": status ", statusCode)
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "upstream rejected credential")
return
}
if statusCode > 0 && statusResponseBody != "" {
s.logger.ErrorContext(ctx, "dial upstream websocket: status ", statusCode, " body: ", statusResponseBody)
} else {
@@ -213,9 +318,7 @@ func (s *Service) handleWebSocket(
clientResponseHeaders[key] = append([]string(nil), values...)
}
}
if userConfig != nil && userConfig.ExternalCredential != "" {
s.rewriteResponseHeadersForExternalUser(clientResponseHeaders, userConfig)
}
s.rewriteResponseHeaders(clientResponseHeaders, provider, userConfig)
clientUpgrader := ws.HTTPUpgrader{
Header: clientResponseHeaders,
@@ -225,16 +328,18 @@ func (s *Service) handleWebSocket(
writeJSONError(w, r, http.StatusServiceUnavailable, "api_error", "service is shutting down")
return
}
clientConn, _, _, err := clientUpgrader.Upgrade(r, w)
clientConn, _, _, err = clientUpgrader.Upgrade(r, w)
if err != nil {
s.logger.ErrorContext(ctx, "upgrade client websocket: ", err)
upstreamConn.Close()
return
}
session := &webSocketSession{
clientConn: clientConn,
upstreamConn: upstreamConn,
credentialTag: selectedCredential.tagName(),
session = &webSocketSession{
clientConn: clientConn,
upstreamConn: upstreamConn,
credentialTag: selectedCredential.tagName(),
releaseProviderInterrupt: requestContext.releaseCredentialInterrupt,
closed: make(chan struct{}),
}
if !s.registerWebSocketSession(session) {
session.Close()
@@ -252,24 +357,27 @@ func (s *Service) handleWebSocket(
upstreamReadWriter = upstreamConn
}
var clientWriteAccess sync.Mutex
modelChannel := make(chan string, 1)
firstRealRequest := make(chan struct{})
var firstRealRequestOnce sync.Once
var waitGroup sync.WaitGroup
waitGroup.Add(2)
go func() {
defer waitGroup.Done()
defer session.Close()
s.proxyWebSocketClientToUpstream(ctx, clientConn, upstreamConn, selectedCredential, modelChannel, isNew, username, sessionID)
s.proxyWebSocketClientToUpstream(ctx, clientConn, upstreamConn, selectedCredential, modelChannel, firstRealRequest, &firstRealRequestOnce, isNew, username, sessionID)
}()
go func() {
defer waitGroup.Done()
defer session.Close()
s.proxyWebSocketUpstreamToClient(ctx, upstreamReadWriter, clientConn, selectedCredential, userConfig, provider, modelChannel, username, weeklyCycleHint)
s.proxyWebSocketUpstreamToClient(ctx, upstreamReadWriter, clientConn, &clientWriteAccess, selectedCredential, modelChannel, username, weeklyCycleHint)
}()
waitGroup.Wait()
}
func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn net.Conn, upstreamConn net.Conn, selectedCredential credential, modelChannel chan<- string, isNew bool, username string, sessionID string) {
func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn net.Conn, upstreamConn net.Conn, selectedCredential Credential, modelChannel chan<- string, firstRealRequest chan struct{}, firstRealRequestOnce *sync.Once, isNew bool, username string, sessionID string) {
logged := false
for {
data, opCode, err := wsutil.ReadClientData(clientConn)
@@ -280,34 +388,23 @@ func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn
return
}
shouldSignalFirstRealRequest := false
if opCode == ws.OpText {
var request struct {
Type string `json:"type"`
Model string `json:"model"`
ServiceTier string `json:"service_tier"`
}
if json.Unmarshal(data, &request) == nil && request.Type == "response.create" && request.Model != "" {
if isNew && !logged {
if request, ok := parseWebSocketResponseCreateRequest(data); ok {
isWarmup := request.isWarmup()
if !isWarmup && isNew && !logged {
logged = true
logParts := []any{"assigned credential ", selectedCredential.tagName()}
if sessionID != "" {
logParts = append(logParts, " for session ", sessionID)
}
if username != "" {
logParts = append(logParts, " by user ", username)
}
logParts = append(logParts, ", model=", request.Model)
if request.ServiceTier == "priority" {
logParts = append(logParts, ", fast")
}
s.logger.DebugContext(ctx, logParts...)
s.logger.DebugContext(ctx, buildAssignedCredentialLogParts(selectedCredential.tagName(), sessionID, username, request.metadata())...)
}
if selectedCredential.usageTrackerOrNil() != nil {
if !isWarmup && selectedCredential.usageTrackerOrNil() != nil {
select {
case modelChannel <- request.Model:
default:
}
}
if !isWarmup {
shouldSignalFirstRealRequest = true
}
}
}
@@ -318,10 +415,13 @@ func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn
}
return
}
if shouldSignalFirstRealRequest {
signalWebSocketReady(firstRealRequest, firstRealRequestOnce)
}
}
}
func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamReadWriter io.ReadWriter, clientConn net.Conn, selectedCredential credential, userConfig *option.OCMUser, provider credentialProvider, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) {
func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamReadWriter io.ReadWriter, clientConn net.Conn, clientWriteAccess *sync.Mutex, selectedCredential Credential, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) {
usageTracker := selectedCredential.usageTrackerOrNil()
var requestModel string
for {
@@ -342,16 +442,9 @@ func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamRe
switch event.Type {
case "codex.rate_limits":
s.handleWebSocketRateLimitsEvent(data, selectedCredential)
if userConfig != nil && userConfig.ExternalCredential != "" {
rewritten, rewriteErr := s.rewriteWebSocketRateLimitsForExternalUser(data, provider, userConfig)
if rewriteErr == nil {
data = rewritten
}
}
continue
case "error":
if event.StatusCode == http.StatusTooManyRequests {
s.handleWebSocketErrorRateLimited(data, selectedCredential)
}
s.handleWebSocketErrorEvent(data, selectedCredential)
case "response.completed":
if usageTracker != nil {
select {
@@ -365,7 +458,9 @@ func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamRe
}
}
clientWriteAccess.Lock()
err = wsutil.WriteServerMessage(clientConn, opCode, data)
clientWriteAccess.Unlock()
if err != nil {
if !E.IsClosedOrCanceled(err) {
s.logger.DebugContext(ctx, "write client websocket: ", err)
@@ -375,47 +470,66 @@ func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamRe
}
}
func (s *Service) handleWebSocketRateLimitsEvent(data []byte, selectedCredential credential) {
func (s *Service) handleWebSocketRateLimitsEvent(data []byte, selectedCredential Credential) {
var rateLimitsEvent struct {
RateLimits struct {
MeteredLimitName string `json:"metered_limit_name"`
LimitName string `json:"limit_name"`
RateLimits struct {
Primary *struct {
UsedPercent float64 `json:"used_percent"`
ResetAt int64 `json:"reset_at"`
UsedPercent float64 `json:"used_percent"`
WindowMinutes int64 `json:"window_minutes"`
ResetAt int64 `json:"reset_at"`
} `json:"primary"`
Secondary *struct {
UsedPercent float64 `json:"used_percent"`
ResetAt int64 `json:"reset_at"`
UsedPercent float64 `json:"used_percent"`
WindowMinutes int64 `json:"window_minutes"`
ResetAt int64 `json:"reset_at"`
} `json:"secondary"`
} `json:"rate_limits"`
LimitName string `json:"limit_name"`
MeteredLimitName string `json:"metered_limit_name"`
PlanWeight float64 `json:"plan_weight"`
Credits *creditsSnapshot `json:"credits"`
PlanWeight float64 `json:"plan_weight"`
}
err := json.Unmarshal(data, &rateLimitsEvent)
if err != nil {
return
}
identifier := rateLimitsEvent.MeteredLimitName
if identifier == "" {
identifier = rateLimitsEvent.LimitName
}
if identifier == "" {
identifier = "codex"
}
identifier = normalizeRateLimitIdentifier(identifier)
headers := make(http.Header)
headers.Set("x-codex-active-limit", identifier)
limitID := rateLimitsEvent.MeteredLimitName
if limitID == "" {
limitID = rateLimitsEvent.LimitName
}
if limitID == "" {
limitID = "codex"
}
headerLimit := headerLimitID(limitID)
headers.Set("x-codex-active-limit", headerLimit)
if w := rateLimitsEvent.RateLimits.Primary; w != nil {
headers.Set("x-"+identifier+"-primary-used-percent", strconv.FormatFloat(w.UsedPercent, 'f', -1, 64))
headers.Set("x-"+headerLimit+"-primary-used-percent", strconv.FormatFloat(w.UsedPercent, 'f', -1, 64))
if w.WindowMinutes > 0 {
headers.Set("x-"+headerLimit+"-primary-window-minutes", strconv.FormatInt(w.WindowMinutes, 10))
}
if w.ResetAt > 0 {
headers.Set("x-"+identifier+"-primary-reset-at", strconv.FormatInt(w.ResetAt, 10))
headers.Set("x-"+headerLimit+"-primary-reset-at", strconv.FormatInt(w.ResetAt, 10))
}
}
if w := rateLimitsEvent.RateLimits.Secondary; w != nil {
headers.Set("x-"+identifier+"-secondary-used-percent", strconv.FormatFloat(w.UsedPercent, 'f', -1, 64))
headers.Set("x-"+headerLimit+"-secondary-used-percent", strconv.FormatFloat(w.UsedPercent, 'f', -1, 64))
if w.WindowMinutes > 0 {
headers.Set("x-"+headerLimit+"-secondary-window-minutes", strconv.FormatInt(w.WindowMinutes, 10))
}
if w.ResetAt > 0 {
headers.Set("x-"+identifier+"-secondary-reset-at", strconv.FormatInt(w.ResetAt, 10))
headers.Set("x-"+headerLimit+"-secondary-reset-at", strconv.FormatInt(w.ResetAt, 10))
}
}
if rateLimitsEvent.LimitName != "" {
headers.Set("x-"+headerLimit+"-limit-name", rateLimitsEvent.LimitName)
}
if rateLimitsEvent.Credits != nil && normalizeStoredLimitID(limitID) == "codex" {
headers.Set("x-codex-credits-has-credits", strconv.FormatBool(rateLimitsEvent.Credits.HasCredits))
headers.Set("x-codex-credits-unlimited", strconv.FormatBool(rateLimitsEvent.Credits.Unlimited))
if rateLimitsEvent.Credits.Balance != "" {
headers.Set("x-codex-credits-balance", rateLimitsEvent.Credits.Balance)
}
}
if rateLimitsEvent.PlanWeight > 0 {
@@ -424,14 +538,23 @@ func (s *Service) handleWebSocketRateLimitsEvent(data []byte, selectedCredential
selectedCredential.updateStateFromHeaders(headers)
}
func (s *Service) handleWebSocketErrorRateLimited(data []byte, selectedCredential credential) {
func (s *Service) handleWebSocketErrorEvent(data []byte, selectedCredential Credential) {
var errorEvent struct {
Headers map[string]string `json:"headers"`
StatusCode int `json:"status_code"`
Headers map[string]string `json:"headers"`
Error openaishared.ErrorObject `json:"error"`
}
err := json.Unmarshal(data, &errorEvent)
if err != nil {
return
}
if errorEvent.StatusCode == http.StatusBadRequest && errorEvent.Error.Code == "websocket_connection_limit_reached" {
selectedCredential.markTemporarilyBlocked(availabilityReasonConnectionLimit, time.Now().Add(time.Minute))
return
}
if errorEvent.StatusCode != http.StatusTooManyRequests {
return
}
headers := make(http.Header)
for key, value := range errorEvent.Headers {
headers.Set(key, value)
@@ -441,73 +564,6 @@ func (s *Service) handleWebSocketErrorRateLimited(data []byte, selectedCredentia
selectedCredential.markRateLimited(resetAt)
}
func (s *Service) rewriteWebSocketRateLimitsForExternalUser(data []byte, provider credentialProvider, userConfig *option.OCMUser) ([]byte, error) {
var event map[string]json.RawMessage
err := json.Unmarshal(data, &event)
if err != nil {
return nil, err
}
rateLimitsData, exists := event["rate_limits"]
if !exists || len(rateLimitsData) == 0 || string(rateLimitsData) == "null" {
return data, nil
}
var rateLimits map[string]json.RawMessage
err = json.Unmarshal(rateLimitsData, &rateLimits)
if err != nil {
return nil, err
}
averageFiveHour, averageWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
if totalWeight > 0 {
event["plan_weight"], _ = json.Marshal(totalWeight)
}
primaryData, err := rewriteWebSocketRateLimitWindow(rateLimits["primary"], averageFiveHour)
if err != nil {
return nil, err
}
if primaryData != nil {
rateLimits["primary"] = primaryData
}
secondaryData, err := rewriteWebSocketRateLimitWindow(rateLimits["secondary"], averageWeekly)
if err != nil {
return nil, err
}
if secondaryData != nil {
rateLimits["secondary"] = secondaryData
}
event["rate_limits"], err = json.Marshal(rateLimits)
if err != nil {
return nil, err
}
return json.Marshal(event)
}
func rewriteWebSocketRateLimitWindow(data json.RawMessage, usedPercent float64) (json.RawMessage, error) {
if len(data) == 0 || string(data) == "null" {
return nil, nil
}
var window map[string]json.RawMessage
err := json.Unmarshal(data, &window)
if err != nil {
return nil, err
}
window["used_percent"], err = json.Marshal(usedPercent)
if err != nil {
return nil, err
}
return json.Marshal(window)
}
func (s *Service) handleWebSocketResponseCompleted(data []byte, usageTracker *AggregatedUsage, requestModel string, username string, weeklyCycleHint *WeeklyCycleHint) {
var streamEvent responses.ResponseStreamEventUnion
if json.Unmarshal(data, &streamEvent) != nil {