diff --git a/docs/configuration/service/ccm.md b/docs/configuration/service/ccm.md index 337cacb10..59ef5f7c0 100644 --- a/docs/configuration/service/ccm.md +++ b/docs/configuration/service/ccm.md @@ -10,6 +10,11 @@ CCM (Claude Code Multiplexer) service is a multiplexing service that allows you It handles OAuth authentication with Claude's API on your local machine while allowing remote Claude Code to authenticate using Auth Tokens via the `ANTHROPIC_AUTH_TOKEN` environment variable. +!!! quote "Changes in sing-box 1.14.0" + + :material-plus: [credentials](#credentials) + :material-alert: [users](#users) + ### Structure ```json @@ -19,6 +24,7 @@ It handles OAuth authentication with Claude's API on your local machine while al ... // Listen Fields "credential_path": "", + "credentials": [], "usages_path": "", "users": [], "headers": {}, @@ -45,6 +51,73 @@ On macOS, credentials are read from the system keychain first, then fall back to Refreshed tokens are automatically written back to the same location. +Conflict with `credentials`. + +#### credentials + +!!! question "Since sing-box 1.14.0" + +List of credential configurations for multi-credential mode. + +When set, top-level `credential_path`, `usages_path`, and `detour` are forbidden. Each user must specify a `credential` tag. + +Each credential has a `type` field (`default`, `balancer`, or `fallback`) and a required `tag` field. + +##### Default Credential + +```json +{ + "tag": "a", + "credential_path": "/path/to/.credentials.json", + "usages_path": "/path/to/usages.json", + "detour": "", + "reserve_5h": 20, + "reserve_weekly": 20 +} +``` + +A single OAuth credential file. The `type` field can be omitted (defaults to `default`). + +- `credential_path`: Path to the credentials file. Same defaults as top-level `credential_path`. +- `usages_path`: Optional usage tracking file for this credential. +- `detour`: Outbound tag for connecting to the Claude API with this credential. 
+- `reserve_5h`: Reserve threshold (1-99) for 5-hour window. Credential pauses at (100-N)% utilization. +- `reserve_weekly`: Reserve threshold (1-99) for weekly window. Credential pauses at (100-N)% utilization. + +##### Balancer Credential + +```json +{ + "tag": "pool", + "type": "balancer", + "strategy": "", + "credentials": ["a", "b"], + "poll_interval": "60s" +} +``` + +Assigns sessions to default credentials based on the selected strategy. Sessions are sticky until the assigned credential hits a rate limit. + +- `strategy`: Selection strategy. One of `least_used` `round_robin` `random`. `least_used` will be used by default. +- `credentials`: ==Required== List of default credential tags. +- `poll_interval`: How often to poll upstream usage API. Default `60s`. + +##### Fallback Credential + +```json +{ + "tag": "backup", + "type": "fallback", + "credentials": ["a", "b"], + "poll_interval": "30s" +} +``` + +Uses credentials in order. Falls through to the next when the current one is exhausted. + +- `credentials`: ==Required== Ordered list of default credential tags. +- `poll_interval`: How often to poll upstream usage API. Default `60s`. + #### usages_path Path to the file for storing aggregated API usage statistics. @@ -60,6 +133,8 @@ Statistics are organized by model, context window (200k standard vs 1M premium), The statistics file is automatically saved every minute and upon service shutdown. +Conflict with `credentials`. In multi-credential mode, use `usages_path` on individual default credentials. + #### users List of authorized users for token authentication. @@ -71,7 +146,8 @@ Object format: ```json { "name": "", - "token": "" + "token": "", + "credential": "" } ``` @@ -79,6 +155,7 @@ Object fields: - `name`: Username identifier for tracking purposes. - `token`: Bearer token for authentication. Claude Code authenticates by setting the `ANTHROPIC_AUTH_TOKEN` environment variable to their token value. +- `credential`: Credential tag to use for this user. 
==Required== when `credentials` is set. #### headers @@ -90,6 +167,8 @@ These headers will override any existing headers with the same name. Outbound tag for connecting to the Claude API. +Conflict with `credentials`. In multi-credential mode, use `detour` on individual default credentials. + #### tls TLS configuration, see [TLS](/configuration/shared/tls/#inbound). @@ -129,3 +208,52 @@ export ANTHROPIC_AUTH_TOKEN="ak-ccm-hello-world" claude ``` + +### Example with Multiple Credentials + +#### Server + +```json +{ + "services": [ + { + "type": "ccm", + "listen": "0.0.0.0", + "listen_port": 8080, + "credentials": [ + { + "tag": "a", + "credential_path": "/home/user/.claude-a/.credentials.json", + "usages_path": "/data/usages-a.json", + "reserve_5h": 20, + "reserve_weekly": 20 + }, + { + "tag": "b", + "credential_path": "/home/user/.claude-b/.credentials.json", + "reserve_5h": 10, + "reserve_weekly": 10 + }, + { + "tag": "pool", + "type": "balancer", + "poll_interval": "60s", + "credentials": ["a", "b"] + } + ], + "users": [ + { + "name": "alice", + "token": "ak-ccm-hello-world", + "credential": "pool" + }, + { + "name": "bob", + "token": "ak-ccm-hello-bob", + "credential": "a" + } + ] + } + ] +} +``` diff --git a/docs/configuration/service/ccm.zh.md b/docs/configuration/service/ccm.zh.md index 7bba322c7..d9496986a 100644 --- a/docs/configuration/service/ccm.zh.md +++ b/docs/configuration/service/ccm.zh.md @@ -10,6 +10,11 @@ CCM(Claude Code 多路复用器)服务是一个多路复用服务,允许 它在本地机器上处理与 Claude API 的 OAuth 身份验证,同时允许远程 Claude Code 通过 `ANTHROPIC_AUTH_TOKEN` 环境变量使用认证令牌进行身份验证。 +!!! quote "sing-box 1.14.0 中的更改" + + :material-plus: [credentials](#credentials) + :material-alert: [users](#users) + ### 结构 ```json @@ -19,6 +24,7 @@ CCM(Claude Code 多路复用器)服务是一个多路复用服务,允许 ... // 监听字段 "credential_path": "", + "credentials": [], "usages_path": "", "users": [], "headers": {}, @@ -45,6 +51,73 @@ Claude Code OAuth 凭据文件的路径。 刷新的令牌会自动写回相同位置。 +与 `credentials` 冲突。 + +#### credentials + +!!! 
question "自 sing-box 1.14.0 起" + +多凭据模式的凭据配置列表。 + +设置后,顶层 `credential_path`、`usages_path` 和 `detour` 被禁止。每个用户必须指定 `credential` 标签。 + +每个凭据有一个 `type` 字段(`default`、`balancer` 或 `fallback`)和一个必填的 `tag` 字段。 + +##### 默认凭据 + +```json +{ + "tag": "a", + "credential_path": "/path/to/.credentials.json", + "usages_path": "/path/to/usages.json", + "detour": "", + "reserve_5h": 20, + "reserve_weekly": 20 +} +``` + +单个 OAuth 凭据文件。`type` 字段可以省略(默认为 `default`)。 + +- `credential_path`:凭据文件的路径。默认值与顶层 `credential_path` 相同。 +- `usages_path`:此凭据的可选使用跟踪文件。 +- `detour`:此凭据用于连接 Claude API 的出站标签。 +- `reserve_5h`:5 小时窗口的保留阈值(1-99)。凭据在利用率达到 (100-N)% 时暂停。 +- `reserve_weekly`:每周窗口的保留阈值(1-99)。凭据在利用率达到 (100-N)% 时暂停。 + +##### 均衡凭据 + +```json +{ + "tag": "pool", + "type": "balancer", + "strategy": "", + "credentials": ["a", "b"], + "poll_interval": "60s" +} +``` + +根据选择的策略将会话分配给默认凭据。会话保持粘性,直到分配的凭据触发速率限制。 + +- `strategy`:选择策略。可选值:`least_used` `round_robin` `random`。默认使用 `least_used`。 +- `credentials`:==必填== 默认凭据标签列表。 +- `poll_interval`:轮询上游使用 API 的间隔。默认 `60s`。 + +##### 回退凭据 + +```json +{ + "tag": "backup", + "type": "fallback", + "credentials": ["a", "b"], + "poll_interval": "30s" +} +``` + +按顺序使用凭据。当前凭据耗尽后切换到下一个。 + +- `credentials`:==必填== 有序的默认凭据标签列表。 +- `poll_interval`:轮询上游使用 API 的间隔。默认 `60s`。 + #### usages_path 用于存储聚合 API 使用统计信息的文件路径。 @@ -60,6 +133,8 @@ Claude Code OAuth 凭据文件的路径。 统计文件每分钟自动保存一次,并在服务关闭时保存。 +与 `credentials` 冲突。在多凭据模式下,在各个默认凭据上使用 `usages_path`。 + #### users 用于令牌身份验证的授权用户列表。 @@ -71,7 +146,8 @@ Claude Code OAuth 凭据文件的路径。 ```json { "name": "", - "token": "" + "token": "", + "credential": "" } ``` @@ -79,6 +155,7 @@ Claude Code OAuth 凭据文件的路径。 - `name`:用于跟踪的用户名标识符。 - `token`:用于身份验证的 Bearer 令牌。Claude Code 通过设置 `ANTHROPIC_AUTH_TOKEN` 环境变量为其令牌值进行身份验证。 +- `credential`:此用户使用的凭据标签。设置 `credentials` 时==必填==。 #### headers @@ -90,6 +167,8 @@ Claude Code OAuth 凭据文件的路径。 用于连接 Claude API 的出站标签。 +与 `credentials` 冲突。在多凭据模式下,在各个默认凭据上使用 `detour`。 + #### tls TLS 配置,参阅 
[TLS](/zh/configuration/shared/tls/#inbound)。 @@ -129,3 +208,52 @@ export ANTHROPIC_AUTH_TOKEN="ak-ccm-hello-world" claude ``` + +### 多凭据示例 + +#### 服务端 + +```json +{ + "services": [ + { + "type": "ccm", + "listen": "0.0.0.0", + "listen_port": 8080, + "credentials": [ + { + "tag": "a", + "credential_path": "/home/user/.claude-a/.credentials.json", + "usages_path": "/data/usages-a.json", + "reserve_5h": 20, + "reserve_weekly": 20 + }, + { + "tag": "b", + "credential_path": "/home/user/.claude-b/.credentials.json", + "reserve_5h": 10, + "reserve_weekly": 10 + }, + { + "tag": "pool", + "type": "balancer", + "poll_interval": "60s", + "credentials": ["a", "b"] + } + ], + "users": [ + { + "name": "alice", + "token": "ak-ccm-hello-world", + "credential": "pool" + }, + { + "name": "bob", + "token": "ak-ccm-hello-bob", + "credential": "a" + } + ] + } + ] +} +``` diff --git a/docs/configuration/service/ocm.md b/docs/configuration/service/ocm.md index 5fdf2b6b4..8dfd0e99e 100644 --- a/docs/configuration/service/ocm.md +++ b/docs/configuration/service/ocm.md @@ -10,6 +10,11 @@ OCM (OpenAI Codex Multiplexer) service is a multiplexing service that allows you It handles OAuth authentication with OpenAI's API on your local machine while allowing remote clients to authenticate using custom tokens. +!!! quote "Changes in sing-box 1.14.0" + + :material-plus: [credentials](#credentials) + :material-alert: [users](#users) + ### Structure ```json @@ -19,6 +24,7 @@ It handles OAuth authentication with OpenAI's API on your local machine while al ... // Listen Fields "credential_path": "", + "credentials": [], "usages_path": "", "users": [], "headers": {}, @@ -43,6 +49,73 @@ If not specified, defaults to: Refreshed tokens are automatically written back to the same location. +Conflict with `credentials`. + +#### credentials + +!!! question "Since sing-box 1.14.0" + +List of credential configurations for multi-credential mode. 
+ +When set, top-level `credential_path`, `usages_path`, and `detour` are forbidden. Each user must specify a `credential` tag. + +Each credential has a `type` field (`default`, `balancer`, or `fallback`) and a required `tag` field. + +##### Default Credential + +```json +{ + "tag": "a", + "credential_path": "/path/to/auth.json", + "usages_path": "/path/to/usages.json", + "detour": "", + "reserve_5h": 20, + "reserve_weekly": 20 +} +``` + +A single OAuth credential file. The `type` field can be omitted (defaults to `default`). + +- `credential_path`: Path to the credentials file. Same defaults as top-level `credential_path`. +- `usages_path`: Optional usage tracking file for this credential. +- `detour`: Outbound tag for connecting to the OpenAI API with this credential. +- `reserve_5h`: Reserve threshold (1-99) for primary rate limit window. Credential pauses at (100-N)% utilization. +- `reserve_weekly`: Reserve threshold (1-99) for secondary (weekly) rate limit window. Credential pauses at (100-N)% utilization. + +##### Balancer Credential + +```json +{ + "tag": "pool", + "type": "balancer", + "strategy": "", + "credentials": ["a", "b"], + "poll_interval": "60s" +} +``` + +Assigns sessions to default credentials based on the selected strategy. Sessions are sticky until the assigned credential hits a rate limit. + +- `strategy`: Selection strategy. One of `least_used` `round_robin` `random`. `least_used` will be used by default. +- `credentials`: ==Required== List of default credential tags. +- `poll_interval`: How often to poll upstream usage API. Default `60s`. + +##### Fallback Credential + +```json +{ + "tag": "backup", + "type": "fallback", + "credentials": ["a", "b"], + "poll_interval": "30s" +} +``` + +Uses credentials in order. Falls through to the next when the current one is exhausted. + +- `credentials`: ==Required== Ordered list of default credential tags. +- `poll_interval`: How often to poll upstream usage API. Default `60s`. 
+ #### usages_path Path to the file for storing aggregated API usage statistics. @@ -58,6 +131,8 @@ Statistics are organized by model and optionally by user when authentication is The statistics file is automatically saved every minute and upon service shutdown. +Conflict with `credentials`. In multi-credential mode, use `usages_path` on individual default credentials. + #### users List of authorized users for token authentication. @@ -69,7 +144,8 @@ Object format: ```json { "name": "", - "token": "" + "token": "", + "credential": "" } ``` @@ -77,6 +153,7 @@ Object fields: - `name`: Username identifier for tracking purposes. - `token`: Bearer token for authentication. Clients authenticate by setting the `Authorization: Bearer <token>` header. +- `credential`: Credential tag to use for this user. ==Required== when `credentials` is set. #### headers @@ -88,6 +165,8 @@ These headers will override any existing headers with the same name. Outbound tag for connecting to the OpenAI API. +Conflict with `credentials`. In multi-credential mode, use `detour` on individual default credentials. + #### tls TLS configuration, see [TLS](/configuration/shared/tls/#inbound). 
@@ -183,3 +262,52 @@ Then run: ```bash codex --profile ocm ``` + +### Example with Multiple Credentials + +#### Server + +```json +{ + "services": [ + { + "type": "ocm", + "listen": "0.0.0.0", + "listen_port": 8080, + "credentials": [ + { + "tag": "a", + "credential_path": "/home/user/.codex-a/auth.json", + "usages_path": "/data/usages-a.json", + "reserve_5h": 20, + "reserve_weekly": 20 + }, + { + "tag": "b", + "credential_path": "/home/user/.codex-b/auth.json", + "reserve_5h": 10, + "reserve_weekly": 10 + }, + { + "tag": "pool", + "type": "balancer", + "poll_interval": "60s", + "credentials": ["a", "b"] + } + ], + "users": [ + { + "name": "alice", + "token": "sk-ocm-hello-world", + "credential": "pool" + }, + { + "name": "bob", + "token": "sk-ocm-hello-bob", + "credential": "a" + } + ] + } + ] +} +``` diff --git a/docs/configuration/service/ocm.zh.md b/docs/configuration/service/ocm.zh.md index 2e02dc558..ee4ffa633 100644 --- a/docs/configuration/service/ocm.zh.md +++ b/docs/configuration/service/ocm.zh.md @@ -10,6 +10,11 @@ OCM(OpenAI Codex 多路复用器)服务是一个多路复用服务,允许 它在本地机器上处理与 OpenAI API 的 OAuth 身份验证,同时允许远程客户端使用自定义令牌进行身份验证。 +!!! quote "sing-box 1.14.0 中的更改" + + :material-plus: [credentials](#credentials) + :material-alert: [users](#users) + ### 结构 ```json @@ -19,6 +24,7 @@ OCM(OpenAI Codex 多路复用器)服务是一个多路复用服务,允许 ... // 监听字段 "credential_path": "", + "credentials": [], "usages_path": "", "users": [], "headers": {}, @@ -43,6 +49,73 @@ OpenAI OAuth 凭据文件的路径。 刷新的令牌会自动写回相同位置。 +与 `credentials` 冲突。 + +#### credentials + +!!! 
question "自 sing-box 1.14.0 起" + +多凭据模式的凭据配置列表。 + +设置后,顶层 `credential_path`、`usages_path` 和 `detour` 被禁止。每个用户必须指定 `credential` 标签。 + +每个凭据有一个 `type` 字段(`default`、`balancer` 或 `fallback`)和一个必填的 `tag` 字段。 + +##### 默认凭据 + +```json +{ + "tag": "a", + "credential_path": "/path/to/auth.json", + "usages_path": "/path/to/usages.json", + "detour": "", + "reserve_5h": 20, + "reserve_weekly": 20 +} +``` + +单个 OAuth 凭据文件。`type` 字段可以省略(默认为 `default`)。 + +- `credential_path`:凭据文件的路径。默认值与顶层 `credential_path` 相同。 +- `usages_path`:此凭据的可选使用跟踪文件。 +- `detour`:此凭据用于连接 OpenAI API 的出站标签。 +- `reserve_5h`:主要速率限制窗口的保留阈值(1-99)。凭据在利用率达到 (100-N)% 时暂停。 +- `reserve_weekly`:次要(每周)速率限制窗口的保留阈值(1-99)。凭据在利用率达到 (100-N)% 时暂停。 + +##### 均衡凭据 + +```json +{ + "tag": "pool", + "type": "balancer", + "strategy": "", + "credentials": ["a", "b"], + "poll_interval": "60s" +} +``` + +根据选择的策略将会话分配给默认凭据。会话保持粘性,直到分配的凭据触发速率限制。 + +- `strategy`:选择策略。可选值:`least_used` `round_robin` `random`。默认使用 `least_used`。 +- `credentials`:==必填== 默认凭据标签列表。 +- `poll_interval`:轮询上游使用 API 的间隔。默认 `60s`。 + +##### 回退凭据 + +```json +{ + "tag": "backup", + "type": "fallback", + "credentials": ["a", "b"], + "poll_interval": "30s" +} +``` + +按顺序使用凭据。当前凭据耗尽后切换到下一个。 + +- `credentials`:==必填== 有序的默认凭据标签列表。 +- `poll_interval`:轮询上游使用 API 的间隔。默认 `60s`。 + #### usages_path 用于存储聚合 API 使用统计信息的文件路径。 @@ -58,6 +131,8 @@ OpenAI OAuth 凭据文件的路径。 统计文件每分钟自动保存一次,并在服务关闭时保存。 +与 `credentials` 冲突。在多凭据模式下,在各个默认凭据上使用 `usages_path`。 + #### users 用于令牌身份验证的授权用户列表。 @@ -69,7 +144,8 @@ OpenAI OAuth 凭据文件的路径。 ```json { "name": "", - "token": "" + "token": "", + "credential": "" } ``` @@ -77,6 +153,7 @@ OpenAI OAuth 凭据文件的路径。 - `name`:用于跟踪的用户名标识符。 - `token`:用于身份验证的 Bearer 令牌。客户端通过设置 `Authorization: Bearer <token>` 头进行身份验证。 +- `credential`:此用户使用的凭据标签。设置 `credentials` 时==必填==。 #### headers @@ -88,6 +165,8 @@ OpenAI OAuth 凭据文件的路径。 用于连接 OpenAI API 的出站标签。 +与 `credentials` 冲突。在多凭据模式下,在各个默认凭据上使用 `detour`。 + #### tls TLS 配置,参阅 [TLS](/zh/configuration/shared/tls/#inbound)。 @@ -184,3 +263,52 @@ 
model_provider = "ocm" ```bash codex --profile ocm ``` + +### 多凭据示例 + +#### 服务端 + +```json +{ + "services": [ + { + "type": "ocm", + "listen": "0.0.0.0", + "listen_port": 8080, + "credentials": [ + { + "tag": "a", + "credential_path": "/home/user/.codex-a/auth.json", + "usages_path": "/data/usages-a.json", + "reserve_5h": 20, + "reserve_weekly": 20 + }, + { + "tag": "b", + "credential_path": "/home/user/.codex-b/auth.json", + "reserve_5h": 10, + "reserve_weekly": 10 + }, + { + "tag": "pool", + "type": "balancer", + "poll_interval": "60s", + "credentials": ["a", "b"] + } + ], + "users": [ + { + "name": "alice", + "token": "sk-ocm-hello-world", + "credential": "pool" + }, + { + "name": "bob", + "token": "sk-ocm-hello-bob", + "credential": "a" + } + ] + } + ] +} +``` diff --git a/go.mod b/go.mod index 98a7811a2..c380c99fa 100644 --- a/go.mod +++ b/go.mod @@ -35,7 +35,7 @@ require ( github.com/sagernet/gomobile v0.1.12 github.com/sagernet/gvisor v0.0.0-20250811.0-sing-box-mod.1 github.com/sagernet/quic-go v0.59.0-sing-box-mod.4 - github.com/sagernet/sing v0.8.2 + github.com/sagernet/sing v0.8.3-0.20260311155444-d39eb42a9f69 github.com/sagernet/sing-mux v0.3.4 github.com/sagernet/sing-quic v0.6.0 github.com/sagernet/sing-shadowsocks v0.2.8 diff --git a/go.sum b/go.sum index 1607560bf..ef9d6ea73 100644 --- a/go.sum +++ b/go.sum @@ -236,8 +236,8 @@ github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNen github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8= github.com/sagernet/quic-go v0.59.0-sing-box-mod.4 h1:6qvrUW79S+CrPwWz6cMePXohgjHoKxLo3c+MDhNwc3o= github.com/sagernet/quic-go v0.59.0-sing-box-mod.4/go.mod h1:OqILvS182CyOol5zNNo6bguvOGgXzV459+chpRaUC+4= -github.com/sagernet/sing v0.8.2 h1:kX1IH9SWJv4S0T9M8O+HNahWgbOuY1VauxbF7NU5lOg= -github.com/sagernet/sing v0.8.2/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= +github.com/sagernet/sing v0.8.3-0.20260311155444-d39eb42a9f69 
h1:h6UF2emeydBQMAso99Nr3APV6YustOs+JszVuCkcFy0= +github.com/sagernet/sing v0.8.3-0.20260311155444-d39eb42a9f69/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= github.com/sagernet/sing-mux v0.3.4 h1:ZQplKl8MNXutjzbMVtWvWG31fohhgOfCuUZR4dVQ8+s= github.com/sagernet/sing-mux v0.3.4/go.mod h1:QvlKMyNBNrQoyX4x+gq028uPbLM2XeRpWtDsWBJbFSk= github.com/sagernet/sing-quic v0.6.0 h1:dhrFnP45wgVKEOT1EvtsToxdzRnHIDIAgj6WHV9pLyM= diff --git a/log/format.go b/log/format.go index 6f4347b12..d2aaa2754 100644 --- a/log/format.go +++ b/log/format.go @@ -168,7 +168,11 @@ func FormatDuration(duration time.Duration) string { return F.ToString(duration.Milliseconds(), "ms") } else if duration < time.Minute { return F.ToString(int64(duration.Seconds()), ".", int64(duration.Seconds()*100)%100, "s") - } else { + } else if duration < time.Hour { return F.ToString(int64(duration.Minutes()), "m", int64(duration.Seconds())%60, "s") + } else if duration < 24*time.Hour { + return F.ToString(int64(duration.Hours()), "h", int64(duration.Minutes())%60, "m") + } else { + return F.ToString(int64(duration.Hours())/24, "d", int64(duration.Hours())%24, "h") } } diff --git a/option/ccm.go b/option/ccm.go index c916aaf22..edfe2e417 100644 --- a/option/ccm.go +++ b/option/ccm.go @@ -1,6 +1,9 @@ package option import ( + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/json" + "github.com/sagernet/sing/common/json/badjson" "github.com/sagernet/sing/common/json/badoption" ) @@ -8,6 +11,7 @@ type CCMServiceOptions struct { ListenOptions InboundTLSOptionsContainer CredentialPath string `json:"credential_path,omitempty"` + Credentials []CCMCredential `json:"credentials,omitempty"` Users []CCMUser `json:"users,omitempty"` Headers badoption.HTTPHeader `json:"headers,omitempty"` Detour string `json:"detour,omitempty"` @@ -15,6 +19,75 @@ type CCMServiceOptions struct { } type CCMUser struct { - Name string `json:"name,omitempty"` - Token string `json:"token,omitempty"` + Name 
string `json:"name,omitempty"` + Token string `json:"token,omitempty"` + Credential string `json:"credential,omitempty"` +} + +type _CCMCredential struct { + Type string `json:"type,omitempty"` + Tag string `json:"tag"` + DefaultOptions CCMDefaultCredentialOptions `json:"-"` + BalancerOptions CCMBalancerCredentialOptions `json:"-"` + FallbackOptions CCMFallbackCredentialOptions `json:"-"` +} + +type CCMCredential _CCMCredential + +func (c CCMCredential) MarshalJSON() ([]byte, error) { + var v any + switch c.Type { + case "", "default": + c.Type = "" + v = c.DefaultOptions + case "balancer": + v = c.BalancerOptions + case "fallback": + v = c.FallbackOptions + default: + return nil, E.New("unknown credential type: ", c.Type) + } + return badjson.MarshallObjects((_CCMCredential)(c), v) +} + +func (c *CCMCredential) UnmarshalJSON(bytes []byte) error { + err := json.Unmarshal(bytes, (*_CCMCredential)(c)) + if err != nil { + return err + } + if c.Tag == "" { + return E.New("missing credential tag") + } + var v any + switch c.Type { + case "", "default": + c.Type = "default" + v = &c.DefaultOptions + case "balancer": + v = &c.BalancerOptions + case "fallback": + v = &c.FallbackOptions + default: + return E.New("unknown credential type: ", c.Type) + } + return badjson.UnmarshallExcluded(bytes, (*_CCMCredential)(c), v) +} + +type CCMDefaultCredentialOptions struct { + CredentialPath string `json:"credential_path,omitempty"` + UsagesPath string `json:"usages_path,omitempty"` + Detour string `json:"detour,omitempty"` + Reserve5h uint8 `json:"reserve_5h"` + ReserveWeekly uint8 `json:"reserve_weekly"` +} + +type CCMBalancerCredentialOptions struct { + Strategy string `json:"strategy,omitempty"` + Credentials badoption.Listable[string] `json:"credentials"` + PollInterval badoption.Duration `json:"poll_interval,omitempty"` +} + +type CCMFallbackCredentialOptions struct { + Credentials badoption.Listable[string] `json:"credentials"` + PollInterval badoption.Duration 
`json:"poll_interval,omitempty"` } diff --git a/option/ocm.go b/option/ocm.go index c13a1c1f5..832b45528 100644 --- a/option/ocm.go +++ b/option/ocm.go @@ -1,6 +1,9 @@ package option import ( + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/json" + "github.com/sagernet/sing/common/json/badjson" "github.com/sagernet/sing/common/json/badoption" ) @@ -8,6 +11,7 @@ type OCMServiceOptions struct { ListenOptions InboundTLSOptionsContainer CredentialPath string `json:"credential_path,omitempty"` + Credentials []OCMCredential `json:"credentials,omitempty"` Users []OCMUser `json:"users,omitempty"` Headers badoption.HTTPHeader `json:"headers,omitempty"` Detour string `json:"detour,omitempty"` @@ -15,6 +19,75 @@ type OCMServiceOptions struct { } type OCMUser struct { - Name string `json:"name,omitempty"` - Token string `json:"token,omitempty"` + Name string `json:"name,omitempty"` + Token string `json:"token,omitempty"` + Credential string `json:"credential,omitempty"` +} + +type _OCMCredential struct { + Type string `json:"type,omitempty"` + Tag string `json:"tag"` + DefaultOptions OCMDefaultCredentialOptions `json:"-"` + BalancerOptions OCMBalancerCredentialOptions `json:"-"` + FallbackOptions OCMFallbackCredentialOptions `json:"-"` +} + +type OCMCredential _OCMCredential + +func (c OCMCredential) MarshalJSON() ([]byte, error) { + var v any + switch c.Type { + case "", "default": + c.Type = "" + v = c.DefaultOptions + case "balancer": + v = c.BalancerOptions + case "fallback": + v = c.FallbackOptions + default: + return nil, E.New("unknown credential type: ", c.Type) + } + return badjson.MarshallObjects((_OCMCredential)(c), v) +} + +func (c *OCMCredential) UnmarshalJSON(bytes []byte) error { + err := json.Unmarshal(bytes, (*_OCMCredential)(c)) + if err != nil { + return err + } + if c.Tag == "" { + return E.New("missing credential tag") + } + var v any + switch c.Type { + case "", "default": + c.Type = "default" + v = &c.DefaultOptions + 
case "balancer": + v = &c.BalancerOptions + case "fallback": + v = &c.FallbackOptions + default: + return E.New("unknown credential type: ", c.Type) + } + return badjson.UnmarshallExcluded(bytes, (*_OCMCredential)(c), v) +} + +type OCMDefaultCredentialOptions struct { + CredentialPath string `json:"credential_path,omitempty"` + UsagesPath string `json:"usages_path,omitempty"` + Detour string `json:"detour,omitempty"` + Reserve5h uint8 `json:"reserve_5h"` + ReserveWeekly uint8 `json:"reserve_weekly"` +} + +type OCMBalancerCredentialOptions struct { + Strategy string `json:"strategy,omitempty"` + Credentials badoption.Listable[string] `json:"credentials"` + PollInterval badoption.Duration `json:"poll_interval,omitempty"` +} + +type OCMFallbackCredentialOptions struct { + Credentials badoption.Listable[string] `json:"credentials"` + PollInterval badoption.Duration `json:"poll_interval,omitempty"` } diff --git a/protocol/tailscale/endpoint.go b/protocol/tailscale/endpoint.go index b6e64666e..b8f2003d2 100644 --- a/protocol/tailscale/endpoint.go +++ b/protocol/tailscale/endpoint.go @@ -847,4 +847,3 @@ func (c *dnsConfigurtor) GetBaseConfig() (tsDNS.OSConfig, error) { func (c *dnsConfigurtor) Close() error { return nil } - diff --git a/service/ccm/credential.go b/service/ccm/credential.go index 695efc7ae..0fe5e2b97 100644 --- a/service/ccm/credential.go +++ b/service/ccm/credential.go @@ -8,8 +8,11 @@ import ( "os" "os/user" "path/filepath" + "runtime" + "sync" "time" + "github.com/sagernet/sing-box/log" E "github.com/sagernet/sing/common/exceptions" ) @@ -21,6 +24,50 @@ const ( anthropicBetaOAuthValue = "oauth-2025-04-20" ) +const ccmUserAgentFallback = "claude-code/2.1.72" + +var ( + ccmUserAgentOnce sync.Once + ccmUserAgentValue string +) + +func initCCMUserAgent(logger log.ContextLogger) { + ccmUserAgentOnce.Do(func() { + version, err := detectClaudeCodeVersion() + if err != nil { + logger.Error("detect Claude Code version: ", err) + ccmUserAgentValue = 
ccmUserAgentFallback + return + } + logger.Debug("detected Claude Code version: ", version) + ccmUserAgentValue = "claude-code/" + version + }) +} + +func detectClaudeCodeVersion() (string, error) { + userInfo, err := getRealUser() + if err != nil { + return "", E.Cause(err, "get user") + } + binaryName := "claude" + if runtime.GOOS == "windows" { + binaryName = "claude.exe" + } + linkPath := filepath.Join(userInfo.HomeDir, ".local", "bin", binaryName) + target, err := os.Readlink(linkPath) + if err != nil { + return "", E.Cause(err, "readlink ", linkPath) + } + if !filepath.IsAbs(target) { + target = filepath.Join(filepath.Dir(linkPath), target) + } + parent := filepath.Base(filepath.Dir(target)) + if parent != "versions" { + return "", E.New("unexpected symlink target: ", target) + } + return filepath.Base(target), nil +} + func getRealUser() (*user.User, error) { if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" { sudoUserInfo, err := user.Lookup(sudoUser) @@ -106,6 +153,7 @@ func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oaut } request.Header.Set("Content-Type", "application/json") request.Header.Set("Accept", "application/json") + request.Header.Set("User-Agent", ccmUserAgentValue) response, err := httpClient.Do(request) if err != nil { @@ -113,6 +161,10 @@ func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oaut } defer response.Body.Close() + if response.StatusCode == http.StatusTooManyRequests { + body, _ := io.ReadAll(response.Body) + return nil, E.New("refresh rate limited: ", response.Status, " ", string(body)) + } if response.StatusCode != http.StatusOK { body, _ := io.ReadAll(response.Body) return nil, E.New("refresh failed: ", response.Status, " ", string(body)) diff --git a/service/ccm/credential_state.go b/service/ccm/credential_state.go new file mode 100644 index 000000000..1681b0855 --- /dev/null +++ b/service/ccm/credential_state.go @@ -0,0 +1,997 @@ +package ccm + +import ( + "bytes" + 
"context" + stdTLS "crypto/tls" + "encoding/json" + "errors" + "io" + "math" + "math/rand/v2" + "net" + "net/http" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/ntp" +) + +const defaultPollInterval = 60 * time.Minute + +type credentialState struct { + fiveHourUtilization float64 + fiveHourReset time.Time + weeklyUtilization float64 + weeklyReset time.Time + hardRateLimited bool + rateLimitResetAt time.Time + accountType string + lastUpdated time.Time + consecutivePollFailures int +} + +type defaultCredential struct { + tag string + credentialPath string + credentials *oauthCredentials + accessMutex sync.RWMutex + state credentialState + stateMutex sync.RWMutex + pollAccess sync.Mutex + reserve5h uint8 + reserveWeekly uint8 + usageTracker *AggregatedUsage + httpClient *http.Client + logger log.ContextLogger + + // Connection interruption + onBecameUnusable func() + interrupted bool + requestContext context.Context + cancelRequests context.CancelFunc + requestAccess sync.Mutex +} + +type credentialRequestContext struct { + context.Context + releaseOnce sync.Once + cancelOnce sync.Once + releaseFunc func() bool + cancelFunc context.CancelFunc +} + +func (c *credentialRequestContext) releaseCredentialInterrupt() { + c.releaseOnce.Do(func() { + c.releaseFunc() + }) +} + +func (c *credentialRequestContext) cancelRequest() { + c.releaseCredentialInterrupt() + c.cancelOnce.Do(c.cancelFunc) +} + +func newDefaultCredential(ctx context.Context, tag string, options option.CCMDefaultCredentialOptions, logger log.ContextLogger) (*defaultCredential, error) { + credentialDialer, err := dialer.NewWithOptions(dialer.Options{ + Context: ctx, + Options: option.DialerOptions{ + 
Detour: options.Detour, + }, + RemoteIsDomain: true, + }) + if err != nil { + return nil, E.Cause(err, "create dialer for credential ", tag) + } + httpClient := &http.Client{ + Transport: &http.Transport{ + ForceAttemptHTTP2: true, + TLSClientConfig: &stdTLS.Config{ + RootCAs: adapter.RootPoolFromContext(ctx), + Time: ntp.TimeFuncFromContext(ctx), + }, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + }, + }, + } + reserve5h := options.Reserve5h + if reserve5h == 0 { + reserve5h = 1 + } + reserveWeekly := options.ReserveWeekly + if reserveWeekly == 0 { + reserveWeekly = 10 + } + requestContext, cancelRequests := context.WithCancel(context.Background()) + credential := &defaultCredential{ + tag: tag, + credentialPath: options.CredentialPath, + reserve5h: reserve5h, + reserveWeekly: reserveWeekly, + httpClient: httpClient, + logger: logger, + requestContext: requestContext, + cancelRequests: cancelRequests, + } + if options.UsagesPath != "" { + credential.usageTracker = &AggregatedUsage{ + LastUpdated: time.Now(), + Combinations: make([]CostCombination, 0), + filePath: options.UsagesPath, + logger: logger, + } + } + return credential, nil +} + +func (c *defaultCredential) start() error { + credentials, err := platformReadCredentials(c.credentialPath) + if err != nil { + return E.Cause(err, "read credentials for ", c.tag) + } + c.credentials = credentials + if credentials.SubscriptionType != "" { + c.state.accountType = credentials.SubscriptionType + } + if c.usageTracker != nil { + err = c.usageTracker.Load() + if err != nil { + c.logger.Warn("load usage statistics for ", c.tag, ": ", err) + } + } + return nil +} + +func (c *defaultCredential) getAccessToken() (string, error) { + c.accessMutex.RLock() + if !c.credentials.needsRefresh() { + token := c.credentials.AccessToken + c.accessMutex.RUnlock() + return token, nil + } + c.accessMutex.RUnlock() + + 
c.accessMutex.Lock() + defer c.accessMutex.Unlock() + + if !c.credentials.needsRefresh() { + return c.credentials.AccessToken, nil + } + + newCredentials, err := refreshToken(c.httpClient, c.credentials) + if err != nil { + return "", err + } + + c.credentials = newCredentials + if newCredentials.SubscriptionType != "" { + c.stateMutex.Lock() + c.state.accountType = newCredentials.SubscriptionType + c.stateMutex.Unlock() + } + + err = platformWriteCredentials(newCredentials, c.credentialPath) + if err != nil { + c.logger.Warn("persist refreshed token for ", c.tag, ": ", err) + } + + return newCredentials.AccessToken, nil +} + +func parseResetTimestamp(value string) (time.Time, error) { + if value == "" { + return time.Time{}, nil + } + unixEpoch, err := strconv.ParseInt(value, 10, 64) + if err == nil { + return time.Unix(unixEpoch, 0), nil + } + return time.Parse(time.RFC3339Nano, value) +} + +func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { + c.stateMutex.Lock() + isFirstUpdate := c.state.lastUpdated.IsZero() + oldFiveHour := c.state.fiveHourUtilization + oldWeekly := c.state.weeklyUtilization + + if utilization := headers.Get("anthropic-ratelimit-unified-5h-utilization"); utilization != "" { + value, err := strconv.ParseFloat(utilization, 64) + if err == nil { + newValue := math.Ceil(value * 100) + if newValue < c.state.fiveHourUtilization { + c.logger.Error("header 5h utilization for ", c.tag, " is lower than current: ", newValue, " < ", c.state.fiveHourUtilization) + } + c.state.fiveHourUtilization = newValue + } + } + if resetAt := headers.Get("anthropic-ratelimit-unified-5h-reset"); resetAt != "" { + value, err := parseResetTimestamp(resetAt) + if err == nil { + c.state.fiveHourReset = value + } + } + if utilization := headers.Get("anthropic-ratelimit-unified-7d-utilization"); utilization != "" { + value, err := strconv.ParseFloat(utilization, 64) + if err == nil { + newValue := math.Ceil(value * 100) + if newValue < 
c.state.weeklyUtilization { + c.logger.Error("header weekly utilization for ", c.tag, " is lower than current: ", newValue, " < ", c.state.weeklyUtilization) + } + c.state.weeklyUtilization = newValue + } + } + if resetAt := headers.Get("anthropic-ratelimit-unified-7d-reset"); resetAt != "" { + value, err := parseResetTimestamp(resetAt) + if err == nil { + c.state.weeklyReset = value + } + } + c.state.lastUpdated = time.Now() + if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { + c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%") + } + shouldInterrupt := c.checkTransitionLocked() + c.stateMutex.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + +func (c *defaultCredential) markRateLimited(resetAt time.Time) { + c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt))) + c.stateMutex.Lock() + c.state.hardRateLimited = true + c.state.rateLimitResetAt = resetAt + shouldInterrupt := c.checkTransitionLocked() + c.stateMutex.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + +func (c *defaultCredential) isUsable() bool { + c.stateMutex.RLock() + if c.state.hardRateLimited { + if time.Now().Before(c.state.rateLimitResetAt) { + c.stateMutex.RUnlock() + return false + } + c.stateMutex.RUnlock() + c.stateMutex.Lock() + if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) { + c.state.hardRateLimited = false + } + usable := c.checkReservesLocked() + c.stateMutex.Unlock() + return usable + } + usable := c.checkReservesLocked() + c.stateMutex.RUnlock() + return usable +} + +func (c *defaultCredential) checkReservesLocked() bool { + if c.state.fiveHourUtilization >= float64(100-c.reserve5h) { + return false + } + if c.state.weeklyUtilization >= float64(100-c.reserveWeekly) { + return false + } + return true +} + +// 
checkTransitionLocked detects usable→unusable transition. +// Must be called with stateMutex write lock held. +func (c *defaultCredential) checkTransitionLocked() bool { + unusable := c.state.hardRateLimited || !c.checkReservesLocked() + if unusable && !c.interrupted { + c.interrupted = true + return true + } + if !unusable && c.interrupted { + c.interrupted = false + } + return false +} + +func (c *defaultCredential) interruptConnections() { + c.logger.Warn("interrupting connections for ", c.tag) + c.requestAccess.Lock() + c.cancelRequests() + c.requestContext, c.cancelRequests = context.WithCancel(context.Background()) + c.requestAccess.Unlock() + if c.onBecameUnusable != nil { + c.onBecameUnusable() + } +} + +func (c *defaultCredential) wrapRequestContext(parent context.Context) *credentialRequestContext { + c.requestAccess.Lock() + credentialContext := c.requestContext + c.requestAccess.Unlock() + derived, cancel := context.WithCancel(parent) + stop := context.AfterFunc(credentialContext, func() { + cancel() + }) + return &credentialRequestContext{ + Context: derived, + releaseFunc: stop, + cancelFunc: cancel, + } +} + +func (c *defaultCredential) weeklyUtilization() float64 { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + return c.state.weeklyUtilization +} + +func (c *defaultCredential) lastUpdatedTime() time.Time { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + return c.state.lastUpdated +} + +func (c *defaultCredential) markUsagePollAttempted() { + c.stateMutex.Lock() + defer c.stateMutex.Unlock() + c.state.lastUpdated = time.Now() +} + +func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duration { + c.stateMutex.RLock() + failures := c.state.consecutivePollFailures + c.stateMutex.RUnlock() + if failures <= 0 { + return baseInterval + } + if failures > 4 { + failures = 4 + } + return baseInterval * time.Duration(1< p.credential.pollBackoff(defaultPollInterval) { + p.credential.pollUsage(ctx) + } +} + +func (p 
*singleCredentialProvider) allDefaults() []*defaultCredential { + return []*defaultCredential{p.credential} +} + +func (p *singleCredentialProvider) close() {} + +const sessionExpiry = 24 * time.Hour + +type sessionEntry struct { + tag string + createdAt time.Time +} + +// balancerProvider assigns sessions to credentials based on a configurable strategy. +type balancerProvider struct { + credentials []*defaultCredential + strategy string + roundRobinIndex atomic.Uint64 + pollInterval time.Duration + sessionMutex sync.RWMutex + sessions map[string]sessionEntry + logger log.ContextLogger +} + +func newBalancerProvider(credentials []*defaultCredential, strategy string, pollInterval time.Duration, logger log.ContextLogger) *balancerProvider { + if pollInterval <= 0 { + pollInterval = defaultPollInterval + } + return &balancerProvider{ + credentials: credentials, + strategy: strategy, + pollInterval: pollInterval, + sessions: make(map[string]sessionEntry), + logger: logger, + } +} + +func (p *balancerProvider) selectCredential(sessionID string) (*defaultCredential, bool, error) { + if sessionID != "" { + p.sessionMutex.RLock() + entry, exists := p.sessions[sessionID] + p.sessionMutex.RUnlock() + if exists { + for _, credential := range p.credentials { + if credential.tag == entry.tag && credential.isUsable() { + return credential, false, nil + } + } + p.sessionMutex.Lock() + delete(p.sessions, sessionID) + p.sessionMutex.Unlock() + } + } + + best := p.pickCredential() + if best == nil { + return nil, false, allCredentialsUnavailableError(p.credentials) + } + + isNew := sessionID != "" + if isNew { + p.sessionMutex.Lock() + p.sessions[sessionID] = sessionEntry{tag: best.tag, createdAt: time.Now()} + p.sessionMutex.Unlock() + } + return best, isNew, nil +} + +func (p *balancerProvider) onRateLimited(sessionID string, credential *defaultCredential, resetAt time.Time) *defaultCredential { + credential.markRateLimited(resetAt) + if sessionID != "" { + p.sessionMutex.Lock() + 
delete(p.sessions, sessionID) + p.sessionMutex.Unlock() + } + + best := p.pickCredential() + if best != nil && sessionID != "" { + p.sessionMutex.Lock() + p.sessions[sessionID] = sessionEntry{tag: best.tag, createdAt: time.Now()} + p.sessionMutex.Unlock() + } + return best +} + +func (p *balancerProvider) pickCredential() *defaultCredential { + switch p.strategy { + case "round_robin": + return p.pickRoundRobin() + case "random": + return p.pickRandom() + default: + return p.pickLeastUsed() + } +} + +func (p *balancerProvider) pickLeastUsed() *defaultCredential { + var best *defaultCredential + bestUtilization := float64(101) + for _, credential := range p.credentials { + if !credential.isUsable() { + continue + } + utilization := credential.weeklyUtilization() + if utilization < bestUtilization { + bestUtilization = utilization + best = credential + } + } + return best +} + +func (p *balancerProvider) pickRoundRobin() *defaultCredential { + start := int(p.roundRobinIndex.Add(1) - 1) + count := len(p.credentials) + for offset := range count { + candidate := p.credentials[(start+offset)%count] + if candidate.isUsable() { + return candidate + } + } + return nil +} + +func (p *balancerProvider) pickRandom() *defaultCredential { + var usable []*defaultCredential + for _, candidate := range p.credentials { + if candidate.isUsable() { + usable = append(usable, candidate) + } + } + if len(usable) == 0 { + return nil + } + return usable[rand.IntN(len(usable))] +} + +func (p *balancerProvider) pollIfStale(ctx context.Context) { + now := time.Now() + p.sessionMutex.Lock() + for id, entry := range p.sessions { + if now.Sub(entry.createdAt) > sessionExpiry { + delete(p.sessions, id) + } + } + p.sessionMutex.Unlock() + + for _, credential := range p.credentials { + if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) { + credential.pollUsage(ctx) + } + } +} + +func (p *balancerProvider) allDefaults() []*defaultCredential { + return p.credentials 
+} + +func (p *balancerProvider) close() {} + +// fallbackProvider tries credentials in order. +type fallbackProvider struct { + credentials []*defaultCredential + pollInterval time.Duration + logger log.ContextLogger +} + +func newFallbackProvider(credentials []*defaultCredential, pollInterval time.Duration, logger log.ContextLogger) *fallbackProvider { + if pollInterval <= 0 { + pollInterval = defaultPollInterval + } + return &fallbackProvider{ + credentials: credentials, + pollInterval: pollInterval, + logger: logger, + } +} + +func (p *fallbackProvider) selectCredential(_ string) (*defaultCredential, bool, error) { + for _, credential := range p.credentials { + if credential.isUsable() { + return credential, false, nil + } + } + return nil, false, allCredentialsUnavailableError(p.credentials) +} + +func (p *fallbackProvider) onRateLimited(_ string, credential *defaultCredential, resetAt time.Time) *defaultCredential { + credential.markRateLimited(resetAt) + for _, candidate := range p.credentials { + if candidate.isUsable() { + return candidate + } + } + return nil +} + +func (p *fallbackProvider) pollIfStale(ctx context.Context) { + for _, credential := range p.credentials { + if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) { + credential.pollUsage(ctx) + } + } +} + +func (p *fallbackProvider) allDefaults() []*defaultCredential { + return p.credentials +} + +func (p *fallbackProvider) close() {} + +func allCredentialsUnavailableError(credentials []*defaultCredential) error { + var earliest time.Time + for _, credential := range credentials { + resetAt := credential.earliestReset() + if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) { + earliest = resetAt + } + } + if earliest.IsZero() { + return E.New("all credentials rate-limited") + } + return E.New("all credentials rate-limited, earliest reset in ", log.FormatDuration(time.Until(earliest))) +} + +func extractCCMSessionID(bodyBytes []byte) string { + 
var body struct { + Metadata struct { + UserID string `json:"user_id"` + } `json:"metadata"` + } + err := json.Unmarshal(bodyBytes, &body) + if err != nil { + return "" + } + userID := body.Metadata.UserID + sessionIndex := strings.LastIndex(userID, "_session_") + if sessionIndex < 0 { + return "" + } + return userID[sessionIndex+len("_session_"):] +} + +func buildCredentialProviders( + ctx context.Context, + options option.CCMServiceOptions, + logger log.ContextLogger, +) (map[string]credentialProvider, []*defaultCredential, error) { + defaultCredentials := make(map[string]*defaultCredential) + var allDefaults []*defaultCredential + providers := make(map[string]credentialProvider) + + for _, credOpt := range options.Credentials { + switch credOpt.Type { + case "default": + credential, err := newDefaultCredential(ctx, credOpt.Tag, credOpt.DefaultOptions, logger) + if err != nil { + return nil, nil, err + } + defaultCredentials[credOpt.Tag] = credential + allDefaults = append(allDefaults, credential) + providers[credOpt.Tag] = &singleCredentialProvider{credential: credential} + } + } + + for _, credOpt := range options.Credentials { + switch credOpt.Type { + case "balancer": + subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, defaultCredentials, credOpt.Tag) + if err != nil { + return nil, nil, err + } + providers[credOpt.Tag] = newBalancerProvider(subCredentials, credOpt.BalancerOptions.Strategy, time.Duration(credOpt.BalancerOptions.PollInterval), logger) + case "fallback": + subCredentials, err := resolveCredentialTags(credOpt.FallbackOptions.Credentials, defaultCredentials, credOpt.Tag) + if err != nil { + return nil, nil, err + } + providers[credOpt.Tag] = newFallbackProvider(subCredentials, time.Duration(credOpt.FallbackOptions.PollInterval), logger) + } + } + + return providers, allDefaults, nil +} + +func resolveCredentialTags(tags []string, defaults map[string]*defaultCredential, parentTag string) ([]*defaultCredential, 
error) { + credentials := make([]*defaultCredential, 0, len(tags)) + for _, tag := range tags { + credential, exists := defaults[tag] + if !exists { + return nil, E.New("credential ", parentTag, " references unknown default credential: ", tag) + } + credentials = append(credentials, credential) + } + if len(credentials) == 0 { + return nil, E.New("credential ", parentTag, " has no sub-credentials") + } + return credentials, nil +} + +func parseRateLimitResetFromHeaders(headers http.Header) time.Time { + claim := headers.Get("anthropic-ratelimit-unified-representative-claim") + switch claim { + case "5h": + if resetStr := headers.Get("anthropic-ratelimit-unified-5h-reset"); resetStr != "" { + value, err := strconv.ParseInt(resetStr, 10, 64) + if err == nil { + return time.Unix(value, 0) + } + } + case "7d": + if resetStr := headers.Get("anthropic-ratelimit-unified-7d-reset"); resetStr != "" { + value, err := strconv.ParseInt(resetStr, 10, 64) + if err == nil { + return time.Unix(value, 0) + } + } + } + if retryAfter := headers.Get("Retry-After"); retryAfter != "" { + seconds, err := strconv.ParseInt(retryAfter, 10, 64) + if err == nil { + return time.Now().Add(time.Duration(seconds) * time.Second) + } + } + return time.Now().Add(5 * time.Minute) +} + +func validateCCMOptions(options option.CCMServiceOptions) error { + hasCredentials := len(options.Credentials) > 0 + hasLegacyPath := options.CredentialPath != "" + hasLegacyUsages := options.UsagesPath != "" + hasLegacyDetour := options.Detour != "" + + if hasCredentials && hasLegacyPath { + return E.New("credential_path and credentials are mutually exclusive") + } + if hasCredentials && hasLegacyUsages { + return E.New("usages_path and credentials are mutually exclusive; use usages_path on individual credentials") + } + if hasCredentials && hasLegacyDetour { + return E.New("detour and credentials are mutually exclusive; use detour on individual credentials") + } + + if hasCredentials { + tags := make(map[string]bool) 
+ for _, credential := range options.Credentials { + if tags[credential.Tag] { + return E.New("duplicate credential tag: ", credential.Tag) + } + tags[credential.Tag] = true + if credential.Type == "default" || credential.Type == "" { + if credential.DefaultOptions.Reserve5h > 99 { + return E.New("credential ", credential.Tag, ": reserve_5h must be at most 99") + } + if credential.DefaultOptions.ReserveWeekly > 99 { + return E.New("credential ", credential.Tag, ": reserve_weekly must be at most 99") + } + } + if credential.Type == "balancer" { + switch credential.BalancerOptions.Strategy { + case "", "least_used", "round_robin", "random": + default: + return E.New("credential ", credential.Tag, ": unknown balancer strategy: ", credential.BalancerOptions.Strategy) + } + } + } + + for _, user := range options.Users { + if user.Credential == "" { + return E.New("user ", user.Name, " must specify credential in multi-credential mode") + } + if !tags[user.Credential] { + return E.New("user ", user.Name, " references unknown credential: ", user.Credential) + } + } + } + + return nil +} + +// retryRequestWithBody re-sends a buffered request body using a different credential. 
+func retryRequestWithBody( + ctx context.Context, + originalRequest *http.Request, + bodyBytes []byte, + credential *defaultCredential, + httpHeaders http.Header, +) (*http.Response, error) { + accessToken, err := credential.getAccessToken() + if err != nil { + return nil, E.Cause(err, "get access token for ", credential.tag) + } + + proxyURL := claudeAPIBaseURL + originalRequest.URL.RequestURI() + retryRequest, err := http.NewRequestWithContext(ctx, originalRequest.Method, proxyURL, bytes.NewReader(bodyBytes)) + if err != nil { + return nil, err + } + + for key, values := range originalRequest.Header { + if !isHopByHopHeader(key) && key != "Authorization" { + retryRequest.Header[key] = values + } + } + + serviceOverridesAcceptEncoding := len(httpHeaders.Values("Accept-Encoding")) > 0 + if credential.usageTracker != nil && !serviceOverridesAcceptEncoding { + retryRequest.Header.Del("Accept-Encoding") + } + + anthropicBetaHeader := retryRequest.Header.Get("anthropic-beta") + if anthropicBetaHeader != "" { + retryRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue+","+anthropicBetaHeader) + } else { + retryRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue) + } + + for key, values := range httpHeaders { + retryRequest.Header.Del(key) + retryRequest.Header[key] = values + } + retryRequest.Header.Set("Authorization", "Bearer "+accessToken) + + return credential.httpClient.Do(retryRequest) +} + +// credentialForUser finds the credential provider for a user. +// In legacy mode, returns the single provider. +// In multi-credential mode, returns the provider mapped to the user's credential tag. 
+func credentialForUser( + userCredentialMap map[string]string, + providers map[string]credentialProvider, + legacyProvider credentialProvider, + username string, +) (credentialProvider, error) { + if legacyProvider != nil { + return legacyProvider, nil + } + tag, exists := userCredentialMap[username] + if !exists { + return nil, E.New("no credential mapping for user: ", username) + } + provider, exists := providers[tag] + if !exists { + return nil, E.New("unknown credential: ", tag) + } + return provider, nil +} + +// noUserCredentialProvider returns the single provider for legacy mode or the first credential in multi-credential mode (no auth). +func noUserCredentialProvider( + providers map[string]credentialProvider, + legacyProvider credentialProvider, + options option.CCMServiceOptions, +) credentialProvider { + if legacyProvider != nil { + return legacyProvider + } + if len(options.Credentials) > 0 { + tag := options.Credentials[0].Tag + return providers[tag] + } + return nil +} diff --git a/service/ccm/service.go b/service/ccm/service.go index 34c38824c..ea81b1b76 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -3,12 +3,10 @@ package ccm import ( "bytes" "context" - stdTLS "crypto/tls" "encoding/json" "errors" "io" "mime" - "net" "net/http" "strconv" "strings" @@ -17,7 +15,6 @@ import ( "github.com/sagernet/sing-box/adapter" boxService "github.com/sagernet/sing-box/adapter/service" - "github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/common/listener" "github.com/sagernet/sing-box/common/tls" C "github.com/sagernet/sing-box/constant" @@ -26,9 +23,7 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/ntp" aTLS "github.com/sagernet/sing/common/tls" "github.com/anthropics/anthropic-sdk-go" @@ -40,6 +35,7 @@ const ( 
contextWindowStandard = 200000 contextWindowPremium = 1000000 premiumContextThreshold = 200000 + retryableUsageMessage = "current credential reached its usage limit; retry the request to use another credential" ) func RegisterService(registry *boxService.Registry) { @@ -60,7 +56,6 @@ type errorDetails struct { func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, errorType string, message string) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(statusCode) - json.NewEncoder(w).Encode(errorResponse{ Type: "error", Error: errorDetails{ @@ -71,6 +66,50 @@ func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, erro }) } +func hasAlternativeCredential(provider credentialProvider, currentCredential *defaultCredential) bool { + if provider == nil || currentCredential == nil { + return false + } + for _, credential := range provider.allDefaults() { + if credential == currentCredential { + continue + } + if credential.isUsable() { + return true + } + } + return false +} + +func unavailableCredentialMessage(provider credentialProvider, fallback string) string { + if provider == nil { + return fallback + } + return allCredentialsUnavailableError(provider.allDefaults()).Error() +} + +func writeRetryableUsageError(w http.ResponseWriter, r *http.Request) { + writeJSONError(w, r, http.StatusTooManyRequests, "rate_limit_error", retryableUsageMessage) +} + +func writeNonRetryableCredentialError(w http.ResponseWriter, r *http.Request, message string) { + writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", message) +} + +func writeCredentialUnavailableError( + w http.ResponseWriter, + r *http.Request, + provider credentialProvider, + currentCredential *defaultCredential, + fallback string, +) { + if hasAlternativeCredential(provider, currentCredential) { + writeRetryableUsageError(w, r) + return + } + writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, fallback)) +} + func 
isHopByHopHeader(header string) bool { switch strings.ToLower(header) { case "connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade", "host": @@ -111,78 +150,79 @@ func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint { type Service struct { boxService.Adapter - ctx context.Context - logger log.ContextLogger - credentialPath string - credentials *oauthCredentials - users []option.CCMUser - httpClient *http.Client - httpHeaders http.Header - listener *listener.Listener - tlsConfig tls.ServerConfig - httpServer *http.Server - userManager *UserManager - accessMutex sync.RWMutex - usageTracker *AggregatedUsage - trackingGroup sync.WaitGroup - shuttingDown bool + ctx context.Context + logger log.ContextLogger + options option.CCMServiceOptions + httpHeaders http.Header + listener *listener.Listener + tlsConfig tls.ServerConfig + httpServer *http.Server + userManager *UserManager + trackingGroup sync.WaitGroup + shuttingDown bool + + // Legacy mode (single credential) + legacyCredential *defaultCredential + legacyProvider credentialProvider + + // Multi-credential mode + providers map[string]credentialProvider + allDefaults []*defaultCredential + userCredentialMap map[string]string } func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.CCMServiceOptions) (adapter.Service, error) { - serviceDialer, err := dialer.NewWithOptions(dialer.Options{ - Context: ctx, - Options: option.DialerOptions{ - Detour: options.Detour, - }, - RemoteIsDomain: true, - }) - if err != nil { - return nil, E.Cause(err, "create dialer") - } + initCCMUserAgent(logger) - httpClient := &http.Client{ - Transport: &http.Transport{ - ForceAttemptHTTP2: true, - TLSClientConfig: &stdTLS.Config{ - RootCAs: adapter.RootPoolFromContext(ctx), - Time: ntp.TimeFuncFromContext(ctx), - }, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return 
serviceDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) - }, - }, + err := validateCCMOptions(options) + if err != nil { + return nil, E.Cause(err, "validate options") } userManager := &UserManager{ tokenMap: make(map[string]string), } - var usageTracker *AggregatedUsage - if options.UsagesPath != "" { - usageTracker = &AggregatedUsage{ - LastUpdated: time.Now(), - Combinations: make([]CostCombination, 0), - filePath: options.UsagesPath, - logger: logger, - } - } - service := &Service{ - Adapter: boxService.NewAdapter(C.TypeCCM, tag), - ctx: ctx, - logger: logger, - credentialPath: options.CredentialPath, - users: options.Users, - httpClient: httpClient, - httpHeaders: options.Headers.Build(), + Adapter: boxService.NewAdapter(C.TypeCCM, tag), + ctx: ctx, + logger: logger, + options: options, + httpHeaders: options.Headers.Build(), listener: listener.New(listener.Options{ Context: ctx, Logger: logger, Network: []string{N.NetworkTCP}, Listen: options.ListenOptions, }), - userManager: userManager, - usageTracker: usageTracker, + userManager: userManager, + } + + if len(options.Credentials) > 0 { + providers, allDefaults, err := buildCredentialProviders(ctx, options, logger) + if err != nil { + return nil, E.Cause(err, "build credential providers") + } + service.providers = providers + service.allDefaults = allDefaults + + userCredentialMap := make(map[string]string) + for _, user := range options.Users { + userCredentialMap[user.Name] = user.Credential + } + service.userCredentialMap = userCredentialMap + } else { + credential, err := newDefaultCredential(ctx, "default", option.CCMDefaultCredentialOptions{ + CredentialPath: options.CredentialPath, + UsagesPath: options.UsagesPath, + Detour: options.Detour, + }, logger) + if err != nil { + return nil, err + } + service.legacyCredential = credential + service.legacyProvider = &singleCredentialProvider{credential: credential} + service.allDefaults = []*defaultCredential{credential} } if options.TLS != nil { @@ 
-201,18 +241,12 @@ func (s *Service) Start(stage adapter.StartStage) error { return nil } - s.userManager.UpdateUsers(s.users) + s.userManager.UpdateUsers(s.options.Users) - credentials, err := platformReadCredentials(s.credentialPath) - if err != nil { - return E.Cause(err, "read credentials") - } - s.credentials = credentials - - if s.usageTracker != nil { - err = s.usageTracker.Load() + for _, credential := range s.allDefaults { + err := credential.start() if err != nil { - s.logger.Warn("load usage statistics: ", err) + return err } } @@ -222,7 +256,7 @@ func (s *Service) Start(stage adapter.StartStage) error { s.httpServer = &http.Server{Handler: router} if s.tlsConfig != nil { - err = s.tlsConfig.Start() + err := s.tlsConfig.Start() if err != nil { return E.Cause(err, "create TLS config") } @@ -250,44 +284,19 @@ func (s *Service) Start(stage adapter.StartStage) error { return nil } -func (s *Service) getAccessToken() (string, error) { - s.accessMutex.RLock() - if !s.credentials.needsRefresh() { - token := s.credentials.AccessToken - s.accessMutex.RUnlock() - return token, nil +func isExtendedContextRequest(betaHeader string) bool { + for _, feature := range strings.Split(betaHeader, ",") { + if strings.HasPrefix(strings.TrimSpace(feature), "context-1m") { + return true + } } - s.accessMutex.RUnlock() - - s.accessMutex.Lock() - defer s.accessMutex.Unlock() - - if !s.credentials.needsRefresh() { - return s.credentials.AccessToken, nil - } - - newCredentials, err := refreshToken(s.httpClient, s.credentials) - if err != nil { - return "", err - } - - s.credentials = newCredentials - - err = platformWriteCredentials(newCredentials, s.credentialPath) - if err != nil { - s.logger.Warn("persist refreshed token: ", err) - } - - return newCredentials.AccessToken, nil + return false } func detectContextWindow(betaHeader string, totalInputTokens int64) int { if totalInputTokens > premiumContextThreshold { - features := strings.Split(betaHeader, ",") - for _, feature := 
range features { - if strings.HasPrefix(strings.TrimSpace(feature), "context-1m") { - return contextWindowPremium - } + if isExtendedContextRequest(betaHeader) { + return contextWindowPremium } } return contextWindowStandard @@ -300,7 +309,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } var username string - if len(s.users) > 0 { + if len(s.options.Users) > 0 { authHeader := r.Header.Get("Authorization") if authHeader == "" { s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": missing Authorization header") @@ -322,26 +331,78 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } + // Always read body to extract model and session ID + var bodyBytes []byte var requestModel string var messagesCount int + var sessionID string - if s.usageTracker != nil && r.Body != nil { - bodyBytes, err := io.ReadAll(r.Body) + if r.Body != nil { + var err error + bodyBytes, err = io.ReadAll(r.Body) + if err != nil { + s.logger.Error("read request body: ", err) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body") + return + } + + var request struct { + Model string `json:"model"` + Messages []anthropic.MessageParam `json:"messages"` + } + err = json.Unmarshal(bodyBytes, &request) if err == nil { - var request struct { - Model string `json:"model"` - Messages []anthropic.MessageParam `json:"messages"` - } - err := json.Unmarshal(bodyBytes, &request) - if err == nil { - requestModel = request.Model - messagesCount = len(request.Messages) - } - r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + requestModel = request.Model + messagesCount = len(request.Messages) + } + + sessionID = extractCCMSessionID(bodyBytes) + r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + + // Resolve credential provider + var provider credentialProvider + if len(s.options.Users) > 0 { + var err error + provider, err = credentialForUser(s.userCredentialMap, s.providers, s.legacyProvider, 
username) + if err != nil { + s.logger.Error("resolve credential: ", err) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) + return + } + } else { + provider = noUserCredentialProvider(s.providers, s.legacyProvider, s.options) + } + if provider == nil { + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available") + return + } + + provider.pollIfStale(s.ctx) + + anthropicBetaHeader := r.Header.Get("anthropic-beta") + if isExtendedContextRequest(anthropicBetaHeader) { + if _, isSingle := provider.(*singleCredentialProvider); !isSingle { + writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", + "extended context (1m) requests will consume Extra usage, please use a default credential directly") + return } } - accessToken, err := s.getAccessToken() + credential, isNew, err := provider.selectCredential(sessionID) + if err != nil { + writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, err.Error())) + return + } + if isNew { + if username != "" { + s.logger.Debug("assigned credential ", credential.tag, " for session ", sessionID, " by user ", username) + } else { + s.logger.Debug("assigned credential ", credential.tag, " for session ", sessionID) + } + } + + accessToken, err := credential.getAccessToken() if err != nil { s.logger.Error("get access token: ", err) writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "Authentication failed") @@ -349,7 +410,11 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } proxyURL := claudeAPIBaseURL + r.URL.RequestURI() - proxyRequest, err := http.NewRequestWithContext(r.Context(), r.Method, proxyURL, r.Body) + requestContext := credential.wrapRequestContext(r.Context()) + defer func() { + requestContext.cancelRequest() + }() + proxyRequest, err := http.NewRequestWithContext(requestContext, r.Method, proxyURL, r.Body) if err != nil { s.logger.Error("create proxy request: ", err) 
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error") @@ -362,14 +427,12 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } + hasUsageTracker := credential.usageTracker != nil serviceOverridesAcceptEncoding := len(s.httpHeaders.Values("Accept-Encoding")) > 0 - if s.usageTracker != nil && !serviceOverridesAcceptEncoding { - // Strip Accept-Encoding so Go Transport adds it automatically - // and transparently decompresses the response for correct usage counting. + if hasUsageTracker && !serviceOverridesAcceptEncoding { proxyRequest.Header.Del("Accept-Encoding") } - anthropicBetaHeader := proxyRequest.Header.Get("anthropic-beta") if anthropicBetaHeader != "" { proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue+","+anthropicBetaHeader) } else { @@ -383,13 +446,65 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { proxyRequest.Header.Set("Authorization", "Bearer "+accessToken) - response, err := s.httpClient.Do(proxyRequest) + response, err := credential.httpClient.Do(proxyRequest) if err != nil { + if r.Context().Err() != nil { + return + } + if requestContext.Err() != nil { + writeCredentialUnavailableError(w, r, provider, credential, "credential became unavailable while processing the request") + return + } writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error()) return } + requestContext.releaseCredentialInterrupt() + + // Transparent 429 retry + for response.StatusCode == http.StatusTooManyRequests { + resetAt := parseRateLimitResetFromHeaders(response.Header) + nextCredential := provider.onRateLimited(sessionID, credential, resetAt) + credential.updateStateFromHeaders(response.Header) + if bodyBytes == nil || nextCredential == nil { + response.Body.Close() + writeCredentialUnavailableError(w, r, provider, credential, "all credentials rate-limited") + return + } + response.Body.Close() + s.logger.Info("retrying with credential ", nextCredential.tag, " 
after 429 from ", credential.tag) + requestContext.cancelRequest() + requestContext = nextCredential.wrapRequestContext(r.Context()) + retryResponse, retryErr := retryRequestWithBody(requestContext, r, bodyBytes, nextCredential, s.httpHeaders) + if retryErr != nil { + if r.Context().Err() != nil { + return + } + if requestContext.Err() != nil { + writeCredentialUnavailableError(w, r, provider, nextCredential, "credential became unavailable while retrying the request") + return + } + s.logger.Error("retry request: ", retryErr) + writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error()) + return + } + requestContext.releaseCredentialInterrupt() + response = retryResponse + credential = nextCredential + } defer response.Body.Close() + credential.updateStateFromHeaders(response.Header) + + if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests { + body, _ := io.ReadAll(response.Body) + s.logger.Error("upstream error from ", credential.tag, ": status ", response.StatusCode, " ", string(body)) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", + "proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body)) + return + } + + hasUsageTracker = credential.usageTracker != nil + for key, values := range response.Header { if !isHopByHopHeader(key) { w.Header()[key] = values @@ -397,8 +512,8 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } w.WriteHeader(response.StatusCode) - if s.usageTracker != nil && response.StatusCode == http.StatusOK { - s.handleResponseWithTracking(w, response, requestModel, anthropicBetaHeader, messagesCount, username) + if hasUsageTracker && response.StatusCode == http.StatusOK { + s.handleResponseWithTracking(w, response, credential.usageTracker, requestModel, anthropicBetaHeader, messagesCount, username) } else { mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) if err == nil && mediaType != "text/event-stream" 
{ @@ -428,7 +543,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, requestModel string, anthropicBetaHeader string, messagesCount int, username string) { +func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, requestModel string, anthropicBetaHeader string, messagesCount int, username string) { weeklyCycleHint := extractWeeklyCycleHint(response.Header) mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) isStreaming := err == nil && mediaType == "text/event-stream" @@ -456,7 +571,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons if responseModel != "" { totalInputTokens := usage.InputTokens + usage.CacheCreationInputTokens + usage.CacheReadInputTokens contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens) - s.usageTracker.AddUsageWithCycleHint( + usageTracker.AddUsageWithCycleHint( responseModel, contextWindow, messagesCount, @@ -557,7 +672,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons if responseModel != "" { totalInputTokens := accumulatedUsage.InputTokens + accumulatedUsage.CacheCreationInputTokens + accumulatedUsage.CacheReadInputTokens contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens) - s.usageTracker.AddUsageWithCycleHint( + usageTracker.AddUsageWithCycleHint( responseModel, contextWindow, messagesCount, @@ -585,12 +700,8 @@ func (s *Service) Close() error { s.tlsConfig, ) - if s.usageTracker != nil { - s.usageTracker.cancelPendingSave() - saveErr := s.usageTracker.Save() - if saveErr != nil { - s.logger.Error("save usage statistics: ", saveErr) - } + for _, credential := range s.allDefaults { + credential.close() } return err diff --git a/service/ocm/credential.go b/service/ocm/credential.go index 
76651a8e1..0cdbd6379 100644 --- a/service/ocm/credential.go +++ b/service/ocm/credential.go @@ -138,6 +138,10 @@ func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oaut } defer response.Body.Close() + if response.StatusCode == http.StatusTooManyRequests { + body, _ := io.ReadAll(response.Body) + return nil, E.New("refresh rate limited: ", response.Status, " ", string(body)) + } if response.StatusCode != http.StatusOK { body, _ := io.ReadAll(response.Body) return nil, E.New("refresh failed: ", response.Status, " ", string(body)) diff --git a/service/ocm/credential_state.go b/service/ocm/credential_state.go new file mode 100644 index 000000000..3c6cf4ed9 --- /dev/null +++ b/service/ocm/credential_state.go @@ -0,0 +1,1022 @@ +package ocm + +import ( + "bytes" + "context" + stdTLS "crypto/tls" + "encoding/json" + "errors" + "io" + "math/rand/v2" + "net" + "net/http" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/ntp" +) + +const defaultPollInterval = 60 * time.Minute + +type credentialState struct { + fiveHourUtilization float64 + fiveHourReset time.Time + weeklyUtilization float64 + weeklyReset time.Time + hardRateLimited bool + rateLimitResetAt time.Time + accountType string + lastUpdated time.Time + consecutivePollFailures int +} + +type defaultCredential struct { + tag string + credentialPath string + credentials *oauthCredentials + accessMutex sync.RWMutex + state credentialState + stateMutex sync.RWMutex + pollAccess sync.Mutex + reserve5h uint8 + reserveWeekly uint8 + usageTracker *AggregatedUsage + dialer N.Dialer + httpClient *http.Client + logger log.ContextLogger + + // Connection 
interruption + onBecameUnusable func() + interrupted bool + requestContext context.Context + cancelRequests context.CancelFunc + requestAccess sync.Mutex +} + +type credentialRequestContext struct { + context.Context + releaseOnce sync.Once + cancelOnce sync.Once + releaseFunc func() bool + cancelFunc context.CancelFunc +} + +func (c *credentialRequestContext) releaseCredentialInterrupt() { + c.releaseOnce.Do(func() { + c.releaseFunc() + }) +} + +func (c *credentialRequestContext) cancelRequest() { + c.releaseCredentialInterrupt() + c.cancelOnce.Do(c.cancelFunc) +} + +func newDefaultCredential(ctx context.Context, tag string, options option.OCMDefaultCredentialOptions, logger log.ContextLogger) (*defaultCredential, error) { + credentialDialer, err := dialer.NewWithOptions(dialer.Options{ + Context: ctx, + Options: option.DialerOptions{ + Detour: options.Detour, + }, + RemoteIsDomain: true, + }) + if err != nil { + return nil, E.Cause(err, "create dialer for credential ", tag) + } + httpClient := &http.Client{ + Transport: &http.Transport{ + ForceAttemptHTTP2: true, + TLSClientConfig: &stdTLS.Config{ + RootCAs: adapter.RootPoolFromContext(ctx), + Time: ntp.TimeFuncFromContext(ctx), + }, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + }, + }, + } + reserve5h := options.Reserve5h + if reserve5h == 0 { + reserve5h = 1 + } + reserveWeekly := options.ReserveWeekly + if reserveWeekly == 0 { + reserveWeekly = 10 + } + requestContext, cancelRequests := context.WithCancel(context.Background()) + credential := &defaultCredential{ + tag: tag, + credentialPath: options.CredentialPath, + reserve5h: reserve5h, + reserveWeekly: reserveWeekly, + dialer: credentialDialer, + httpClient: httpClient, + logger: logger, + requestContext: requestContext, + cancelRequests: cancelRequests, + } + if options.UsagesPath != "" { + credential.usageTracker = &AggregatedUsage{ + 
LastUpdated: time.Now(), + Combinations: make([]CostCombination, 0), + filePath: options.UsagesPath, + logger: logger, + } + } + return credential, nil +} + +func (c *defaultCredential) start() error { + credentials, err := platformReadCredentials(c.credentialPath) + if err != nil { + return E.Cause(err, "read credentials for ", c.tag) + } + c.credentials = credentials + if c.usageTracker != nil { + err = c.usageTracker.Load() + if err != nil { + c.logger.Warn("load usage statistics for ", c.tag, ": ", err) + } + } + return nil +} + +func (c *defaultCredential) getAccessToken() (string, error) { + c.accessMutex.RLock() + if !c.credentials.needsRefresh() { + token := c.credentials.getAccessToken() + c.accessMutex.RUnlock() + return token, nil + } + c.accessMutex.RUnlock() + + c.accessMutex.Lock() + defer c.accessMutex.Unlock() + + if !c.credentials.needsRefresh() { + return c.credentials.getAccessToken(), nil + } + + newCredentials, err := refreshToken(c.httpClient, c.credentials) + if err != nil { + return "", err + } + + c.credentials = newCredentials + + err = platformWriteCredentials(newCredentials, c.credentialPath) + if err != nil { + c.logger.Warn("persist refreshed token for ", c.tag, ": ", err) + } + + return newCredentials.getAccessToken(), nil +} + +func (c *defaultCredential) getAccountID() string { + c.accessMutex.RLock() + defer c.accessMutex.RUnlock() + return c.credentials.getAccountID() +} + +func (c *defaultCredential) isAPIKeyMode() bool { + c.accessMutex.RLock() + defer c.accessMutex.RUnlock() + return c.credentials.isAPIKeyMode() +} + +func (c *defaultCredential) getBaseURL() string { + if c.isAPIKeyMode() { + return openaiAPIBaseURL + } + return chatGPTBackendURL +} + +func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { + c.stateMutex.Lock() + isFirstUpdate := c.state.lastUpdated.IsZero() + oldFiveHour := c.state.fiveHourUtilization + oldWeekly := c.state.weeklyUtilization + + activeLimitIdentifier := 
normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) + if activeLimitIdentifier == "" { + activeLimitIdentifier = "codex" + } + + fiveHourPercent := headers.Get("x-" + activeLimitIdentifier + "-primary-used-percent") + if fiveHourPercent != "" { + value, err := strconv.ParseFloat(fiveHourPercent, 64) + if err == nil { + c.state.fiveHourUtilization = value + } + } + fiveHourResetAt := headers.Get("x-" + activeLimitIdentifier + "-primary-reset-at") + if fiveHourResetAt != "" { + value, err := strconv.ParseInt(fiveHourResetAt, 10, 64) + if err == nil { + c.state.fiveHourReset = time.Unix(value, 0) + } + } + weeklyPercent := headers.Get("x-" + activeLimitIdentifier + "-secondary-used-percent") + if weeklyPercent != "" { + value, err := strconv.ParseFloat(weeklyPercent, 64) + if err == nil { + c.state.weeklyUtilization = value + } + } + weeklyResetAt := headers.Get("x-" + activeLimitIdentifier + "-secondary-reset-at") + if weeklyResetAt != "" { + value, err := strconv.ParseInt(weeklyResetAt, 10, 64) + if err == nil { + c.state.weeklyReset = time.Unix(value, 0) + } + } + c.state.lastUpdated = time.Now() + if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { + c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%") + } + shouldInterrupt := c.checkTransitionLocked() + c.stateMutex.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + +func (c *defaultCredential) markRateLimited(resetAt time.Time) { + c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt))) + c.stateMutex.Lock() + c.state.hardRateLimited = true + c.state.rateLimitResetAt = resetAt + shouldInterrupt := c.checkTransitionLocked() + c.stateMutex.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + +func (c *defaultCredential) isUsable() bool { + c.stateMutex.RLock() + if 
c.state.hardRateLimited { + if time.Now().Before(c.state.rateLimitResetAt) { + c.stateMutex.RUnlock() + return false + } + c.stateMutex.RUnlock() + c.stateMutex.Lock() + if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) { + c.state.hardRateLimited = false + } + usable := c.checkReservesLocked() + c.stateMutex.Unlock() + return usable + } + usable := c.checkReservesLocked() + c.stateMutex.RUnlock() + return usable +} + +func (c *defaultCredential) checkReservesLocked() bool { + if c.state.fiveHourUtilization >= float64(100-c.reserve5h) { + return false + } + if c.state.weeklyUtilization >= float64(100-c.reserveWeekly) { + return false + } + return true +} + +// checkTransitionLocked detects usable→unusable transition. +// Must be called with stateMutex write lock held. +func (c *defaultCredential) checkTransitionLocked() bool { + unusable := c.state.hardRateLimited || !c.checkReservesLocked() + if unusable && !c.interrupted { + c.interrupted = true + return true + } + if !unusable && c.interrupted { + c.interrupted = false + } + return false +} + +func (c *defaultCredential) interruptConnections() { + c.logger.Warn("interrupting connections for ", c.tag) + c.requestAccess.Lock() + c.cancelRequests() + c.requestContext, c.cancelRequests = context.WithCancel(context.Background()) + c.requestAccess.Unlock() + if c.onBecameUnusable != nil { + c.onBecameUnusable() + } +} + +func (c *defaultCredential) wrapRequestContext(parent context.Context) *credentialRequestContext { + c.requestAccess.Lock() + credentialContext := c.requestContext + c.requestAccess.Unlock() + derived, cancel := context.WithCancel(parent) + stop := context.AfterFunc(credentialContext, func() { + cancel() + }) + return &credentialRequestContext{ + Context: derived, + releaseFunc: stop, + cancelFunc: cancel, + } +} + +func (c *defaultCredential) weeklyUtilization() float64 { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + return c.state.weeklyUtilization +} + +func (c 
*defaultCredential) lastUpdatedTime() time.Time { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + return c.state.lastUpdated +} + +func (c *defaultCredential) markUsagePollAttempted() { + c.stateMutex.Lock() + defer c.stateMutex.Unlock() + c.state.lastUpdated = time.Now() +} + +func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duration { + c.stateMutex.RLock() + failures := c.state.consecutivePollFailures + c.stateMutex.RUnlock() + if failures <= 0 { + return baseInterval + } + if failures > 4 { + failures = 4 + } + return baseInterval * time.Duration(1< 0 { + c.state.fiveHourReset = time.Unix(w.ResetAt, 0) + } + } + if w := usageResponse.RateLimit.SecondaryWindow; w != nil { + c.state.weeklyUtilization = w.UsedPercent + if w.ResetAt > 0 { + c.state.weeklyReset = time.Unix(w.ResetAt, 0) + } + } + } + if usageResponse.PlanType != "" { + c.state.accountType = usageResponse.PlanType + } + if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { + c.state.hardRateLimited = false + } + if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { + c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%") + } + shouldInterrupt := c.checkTransitionLocked() + c.stateMutex.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + +func (c *defaultCredential) close() { + if c.usageTracker != nil { + c.usageTracker.cancelPendingSave() + err := c.usageTracker.Save() + if err != nil { + c.logger.Error("save usage statistics for ", c.tag, ": ", err) + } + } +} + +type credentialProvider interface { + selectCredential(sessionID string) (*defaultCredential, bool, error) + onRateLimited(sessionID string, credential *defaultCredential, resetAt time.Time) *defaultCredential + pollIfStale(ctx context.Context) + allDefaults() []*defaultCredential + close() +} + +type 
singleCredentialProvider struct { + credential *defaultCredential +} + +func (p *singleCredentialProvider) selectCredential(_ string) (*defaultCredential, bool, error) { + if !p.credential.isUsable() { + return nil, false, E.New("credential ", p.credential.tag, " is rate-limited") + } + return p.credential, false, nil +} + +func (p *singleCredentialProvider) onRateLimited(_ string, credential *defaultCredential, resetAt time.Time) *defaultCredential { + credential.markRateLimited(resetAt) + return nil +} + +func (p *singleCredentialProvider) pollIfStale(ctx context.Context) { + if time.Since(p.credential.lastUpdatedTime()) > p.credential.pollBackoff(defaultPollInterval) { + p.credential.pollUsage(ctx) + } +} + +func (p *singleCredentialProvider) allDefaults() []*defaultCredential { + return []*defaultCredential{p.credential} +} + +func (p *singleCredentialProvider) close() {} + +const sessionExpiry = 24 * time.Hour + +type sessionEntry struct { + tag string + createdAt time.Time +} + +type balancerProvider struct { + credentials []*defaultCredential + strategy string + roundRobinIndex atomic.Uint64 + pollInterval time.Duration + sessionMutex sync.RWMutex + sessions map[string]sessionEntry + logger log.ContextLogger +} + +func newBalancerProvider(credentials []*defaultCredential, strategy string, pollInterval time.Duration, logger log.ContextLogger) *balancerProvider { + if pollInterval <= 0 { + pollInterval = defaultPollInterval + } + return &balancerProvider{ + credentials: credentials, + strategy: strategy, + pollInterval: pollInterval, + sessions: make(map[string]sessionEntry), + logger: logger, + } +} + +func (p *balancerProvider) selectCredential(sessionID string) (*defaultCredential, bool, error) { + if sessionID != "" { + p.sessionMutex.RLock() + entry, exists := p.sessions[sessionID] + p.sessionMutex.RUnlock() + if exists { + for _, credential := range p.credentials { + if credential.tag == entry.tag && credential.isUsable() { + return credential, false, 
nil + } + } + p.sessionMutex.Lock() + delete(p.sessions, sessionID) + p.sessionMutex.Unlock() + } + } + + best := p.pickCredential() + if best == nil { + return nil, false, allRateLimitedError(p.credentials) + } + + isNew := sessionID != "" + if isNew { + p.sessionMutex.Lock() + p.sessions[sessionID] = sessionEntry{tag: best.tag, createdAt: time.Now()} + p.sessionMutex.Unlock() + } + return best, isNew, nil +} + +func (p *balancerProvider) onRateLimited(sessionID string, credential *defaultCredential, resetAt time.Time) *defaultCredential { + credential.markRateLimited(resetAt) + if sessionID != "" { + p.sessionMutex.Lock() + delete(p.sessions, sessionID) + p.sessionMutex.Unlock() + } + + best := p.pickCredential() + if best != nil && sessionID != "" { + p.sessionMutex.Lock() + p.sessions[sessionID] = sessionEntry{tag: best.tag, createdAt: time.Now()} + p.sessionMutex.Unlock() + } + return best +} + +func (p *balancerProvider) pickCredential() *defaultCredential { + switch p.strategy { + case "round_robin": + return p.pickRoundRobin() + case "random": + return p.pickRandom() + default: + return p.pickLeastUsed() + } +} + +func (p *balancerProvider) pickLeastUsed() *defaultCredential { + var best *defaultCredential + bestUtilization := float64(101) + for _, credential := range p.credentials { + if !credential.isUsable() { + continue + } + utilization := credential.weeklyUtilization() + if utilization < bestUtilization { + bestUtilization = utilization + best = credential + } + } + return best +} + +func (p *balancerProvider) pickRoundRobin() *defaultCredential { + start := int(p.roundRobinIndex.Add(1) - 1) + count := len(p.credentials) + for offset := range count { + candidate := p.credentials[(start+offset)%count] + if candidate.isUsable() { + return candidate + } + } + return nil +} + +func (p *balancerProvider) pickRandom() *defaultCredential { + var usable []*defaultCredential + for _, candidate := range p.credentials { + if candidate.isUsable() { + usable = 
append(usable, candidate) + } + } + if len(usable) == 0 { + return nil + } + return usable[rand.IntN(len(usable))] +} + +func (p *balancerProvider) pollIfStale(ctx context.Context) { + now := time.Now() + p.sessionMutex.Lock() + for id, entry := range p.sessions { + if now.Sub(entry.createdAt) > sessionExpiry { + delete(p.sessions, id) + } + } + p.sessionMutex.Unlock() + + for _, credential := range p.credentials { + if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) { + credential.pollUsage(ctx) + } + } +} + +func (p *balancerProvider) allDefaults() []*defaultCredential { + return p.credentials +} + +func (p *balancerProvider) close() {} + +type fallbackProvider struct { + credentials []*defaultCredential + pollInterval time.Duration + logger log.ContextLogger +} + +func newFallbackProvider(credentials []*defaultCredential, pollInterval time.Duration, logger log.ContextLogger) *fallbackProvider { + if pollInterval <= 0 { + pollInterval = defaultPollInterval + } + return &fallbackProvider{ + credentials: credentials, + pollInterval: pollInterval, + logger: logger, + } +} + +func (p *fallbackProvider) selectCredential(_ string) (*defaultCredential, bool, error) { + for _, credential := range p.credentials { + if credential.isUsable() { + return credential, false, nil + } + } + return nil, false, allRateLimitedError(p.credentials) +} + +func (p *fallbackProvider) onRateLimited(_ string, credential *defaultCredential, resetAt time.Time) *defaultCredential { + credential.markRateLimited(resetAt) + for _, candidate := range p.credentials { + if candidate.isUsable() { + return candidate + } + } + return nil +} + +func (p *fallbackProvider) pollIfStale(ctx context.Context) { + for _, credential := range p.credentials { + if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) { + credential.pollUsage(ctx) + } + } +} + +func (p *fallbackProvider) allDefaults() []*defaultCredential { + return p.credentials +} + 
+func (p *fallbackProvider) close() {} + +func allRateLimitedError(credentials []*defaultCredential) error { + var earliest time.Time + for _, credential := range credentials { + resetAt := credential.earliestReset() + if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) { + earliest = resetAt + } + } + if earliest.IsZero() { + return E.New("all credentials rate-limited") + } + return E.New("all credentials rate-limited, earliest reset in ", log.FormatDuration(time.Until(earliest))) +} + +func buildOCMCredentialProviders( + ctx context.Context, + options option.OCMServiceOptions, + logger log.ContextLogger, +) (map[string]credentialProvider, []*defaultCredential, error) { + defaultCredentials := make(map[string]*defaultCredential) + var allDefaults []*defaultCredential + providers := make(map[string]credentialProvider) + + for _, credOpt := range options.Credentials { + switch credOpt.Type { + case "default": + credential, err := newDefaultCredential(ctx, credOpt.Tag, credOpt.DefaultOptions, logger) + if err != nil { + return nil, nil, err + } + defaultCredentials[credOpt.Tag] = credential + allDefaults = append(allDefaults, credential) + providers[credOpt.Tag] = &singleCredentialProvider{credential: credential} + } + } + + for _, credOpt := range options.Credentials { + switch credOpt.Type { + case "balancer": + subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, defaultCredentials, credOpt.Tag) + if err != nil { + return nil, nil, err + } + providers[credOpt.Tag] = newBalancerProvider(subCredentials, credOpt.BalancerOptions.Strategy, time.Duration(credOpt.BalancerOptions.PollInterval), logger) + case "fallback": + subCredentials, err := resolveCredentialTags(credOpt.FallbackOptions.Credentials, defaultCredentials, credOpt.Tag) + if err != nil { + return nil, nil, err + } + providers[credOpt.Tag] = newFallbackProvider(subCredentials, time.Duration(credOpt.FallbackOptions.PollInterval), logger) + } + } + + return 
providers, allDefaults, nil +} + +func resolveCredentialTags(tags []string, defaults map[string]*defaultCredential, parentTag string) ([]*defaultCredential, error) { + credentials := make([]*defaultCredential, 0, len(tags)) + for _, tag := range tags { + credential, exists := defaults[tag] + if !exists { + return nil, E.New("credential ", parentTag, " references unknown default credential: ", tag) + } + credentials = append(credentials, credential) + } + if len(credentials) == 0 { + return nil, E.New("credential ", parentTag, " has no sub-credentials") + } + return credentials, nil +} + +func parseOCMRateLimitResetFromHeaders(headers http.Header) time.Time { + activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) + if activeLimitIdentifier != "" { + resetHeader := "x-" + activeLimitIdentifier + "-primary-reset-at" + if resetStr := headers.Get(resetHeader); resetStr != "" { + value, err := strconv.ParseInt(resetStr, 10, 64) + if err == nil { + return time.Unix(value, 0) + } + } + } + if retryAfter := headers.Get("Retry-After"); retryAfter != "" { + seconds, err := strconv.ParseInt(retryAfter, 10, 64) + if err == nil { + return time.Now().Add(time.Duration(seconds) * time.Second) + } + } + return time.Now().Add(5 * time.Minute) +} + +func validateOCMOptions(options option.OCMServiceOptions) error { + hasCredentials := len(options.Credentials) > 0 + hasLegacyPath := options.CredentialPath != "" + hasLegacyUsages := options.UsagesPath != "" + hasLegacyDetour := options.Detour != "" + + if hasCredentials && hasLegacyPath { + return E.New("credential_path and credentials are mutually exclusive") + } + if hasCredentials && hasLegacyUsages { + return E.New("usages_path and credentials are mutually exclusive; use usages_path on individual credentials") + } + if hasCredentials && hasLegacyDetour { + return E.New("detour and credentials are mutually exclusive; use detour on individual credentials") + } + + if hasCredentials { + tags := 
make(map[string]bool) + for _, credential := range options.Credentials { + if tags[credential.Tag] { + return E.New("duplicate credential tag: ", credential.Tag) + } + tags[credential.Tag] = true + if credential.Type == "default" || credential.Type == "" { + if credential.DefaultOptions.Reserve5h > 99 { + return E.New("credential ", credential.Tag, ": reserve_5h must be at most 99") + } + if credential.DefaultOptions.ReserveWeekly > 99 { + return E.New("credential ", credential.Tag, ": reserve_weekly must be at most 99") + } + } + if credential.Type == "balancer" { + switch credential.BalancerOptions.Strategy { + case "", "least_used", "round_robin", "random": + default: + return E.New("credential ", credential.Tag, ": unknown balancer strategy: ", credential.BalancerOptions.Strategy) + } + } + } + + for _, user := range options.Users { + if user.Credential == "" { + return E.New("user ", user.Name, " must specify credential in multi-credential mode") + } + if !tags[user.Credential] { + return E.New("user ", user.Name, " references unknown credential: ", user.Credential) + } + } + } + + return nil +} + +func validateOCMCompositeCredentialModes( + options option.OCMServiceOptions, + providers map[string]credentialProvider, +) error { + for _, credential := range options.Credentials { + if credential.Type != "balancer" && credential.Type != "fallback" { + continue + } + + provider, exists := providers[credential.Tag] + if !exists { + return E.New("unknown credential: ", credential.Tag) + } + + for _, subCredential := range provider.allDefaults() { + if subCredential.isAPIKeyMode() { + return E.New( + "credential ", credential.Tag, + " references API key default credential ", subCredential.tag, + "; balancer and fallback only support OAuth default credentials", + ) + } + } + } + + return nil +} + +func retryOCMRequestWithBody( + ctx context.Context, + originalRequest *http.Request, + bodyBytes []byte, + credential *defaultCredential, + httpHeaders http.Header, +) 
(*http.Response, error) { + accessToken, err := credential.getAccessToken() + if err != nil { + return nil, E.Cause(err, "get access token for ", credential.tag) + } + + baseURL := credential.getBaseURL() + path := originalRequest.URL.Path + var proxyPath string + if credential.isAPIKeyMode() { + proxyPath = path + } else { + proxyPath = strings.TrimPrefix(path, "/v1") + } + + proxyURL := baseURL + proxyPath + if originalRequest.URL.RawQuery != "" { + proxyURL += "?" + originalRequest.URL.RawQuery + } + + var body io.Reader + if bodyBytes != nil { + body = bytes.NewReader(bodyBytes) + } + retryRequest, err := http.NewRequestWithContext(ctx, originalRequest.Method, proxyURL, body) + if err != nil { + return nil, err + } + + for key, values := range originalRequest.Header { + if !isHopByHopHeader(key) && key != "Authorization" { + retryRequest.Header[key] = values + } + } + for key, values := range httpHeaders { + retryRequest.Header.Del(key) + retryRequest.Header[key] = values + } + retryRequest.Header.Set("Authorization", "Bearer "+accessToken) + if accountID := credential.getAccountID(); accountID != "" { + retryRequest.Header.Set("ChatGPT-Account-Id", accountID) + } + + return credential.httpClient.Do(retryRequest) +} + +func credentialForUser( + userCredentialMap map[string]string, + providers map[string]credentialProvider, + legacyProvider credentialProvider, + username string, +) (credentialProvider, error) { + if legacyProvider != nil { + return legacyProvider, nil + } + tag, exists := userCredentialMap[username] + if !exists { + return nil, E.New("no credential mapping for user: ", username) + } + provider, exists := providers[tag] + if !exists { + return nil, E.New("unknown credential: ", tag) + } + return provider, nil +} + +func noUserCredentialProvider( + providers map[string]credentialProvider, + legacyProvider credentialProvider, + options option.OCMServiceOptions, +) credentialProvider { + if legacyProvider != nil { + return legacyProvider + } + if 
len(options.Credentials) > 0 { + tag := options.Credentials[0].Tag + return providers[tag] + } + return nil +} diff --git a/service/ocm/service.go b/service/ocm/service.go index 8b66964a9..75f28f2c1 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -3,12 +3,10 @@ package ocm import ( "bytes" "context" - stdTLS "crypto/tls" "encoding/json" "errors" "io" "mime" - "net" "net/http" "strconv" "strings" @@ -17,7 +15,6 @@ import ( "github.com/sagernet/sing-box/adapter" boxService "github.com/sagernet/sing-box/adapter/service" - "github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/common/listener" "github.com/sagernet/sing-box/common/tls" C "github.com/sagernet/sing-box/constant" @@ -26,9 +23,7 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/ntp" aTLS "github.com/sagernet/sing/common/tls" "github.com/go-chi/chi/v5" @@ -52,17 +47,77 @@ type errorDetails struct { } func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, errorType string, message string) { + writeJSONErrorWithCode(w, r, statusCode, errorType, "", message) +} + +func writeJSONErrorWithCode(w http.ResponseWriter, r *http.Request, statusCode int, errorType string, errorCode string, message string) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(statusCode) json.NewEncoder(w).Encode(errorResponse{ Error: errorDetails{ Type: errorType, + Code: errorCode, Message: message, }, }) } +func writePlainTextError(w http.ResponseWriter, statusCode int, message string) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(statusCode) + _, _ = io.WriteString(w, message) +} + +const ( + retryableUsageMessage = "current credential reached its usage limit; retry the request to use another credential" + 
retryableUsageCode = "credential_usage_exhausted" +) + +func hasAlternativeCredential(provider credentialProvider, currentCredential *defaultCredential) bool { + if provider == nil || currentCredential == nil { + return false + } + for _, credential := range provider.allDefaults() { + if credential == currentCredential { + continue + } + if credential.isUsable() { + return true + } + } + return false +} + +func unavailableCredentialMessage(provider credentialProvider, fallback string) string { + if provider == nil { + return fallback + } + return allRateLimitedError(provider.allDefaults()).Error() +} + +func writeRetryableUsageError(w http.ResponseWriter, r *http.Request) { + writeJSONErrorWithCode(w, r, http.StatusServiceUnavailable, "server_error", retryableUsageCode, retryableUsageMessage) +} + +func writeNonRetryableCredentialError(w http.ResponseWriter, message string) { + writePlainTextError(w, http.StatusBadRequest, message) +} + +func writeCredentialUnavailableError( + w http.ResponseWriter, + r *http.Request, + provider credentialProvider, + currentCredential *defaultCredential, + fallback string, +) { + if hasAlternativeCredential(provider, currentCredential) { + writeRetryableUsageError(w, r) + return + } + writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, fallback)) +} + func isHopByHopHeader(header string) bool { switch strings.ToLower(header) { case "connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade", "host": @@ -127,72 +182,43 @@ type Service struct { boxService.Adapter ctx context.Context logger log.ContextLogger - credentialPath string - credentials *oauthCredentials - users []option.OCMUser - dialer N.Dialer - httpClient *http.Client + options option.OCMServiceOptions httpHeaders http.Header listener *listener.Listener tlsConfig tls.ServerConfig httpServer *http.Server userManager *UserManager - accessMutex sync.RWMutex - usageTracker *AggregatedUsage 
webSocketMutex sync.Mutex webSocketGroup sync.WaitGroup webSocketConns map[*webSocketSession]struct{} shuttingDown bool + + // Legacy mode + legacyCredential *defaultCredential + legacyProvider credentialProvider + + // Multi-credential mode + providers map[string]credentialProvider + allDefaults []*defaultCredential + userCredentialMap map[string]string } func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.OCMServiceOptions) (adapter.Service, error) { - serviceDialer, err := dialer.NewWithOptions(dialer.Options{ - Context: ctx, - Options: option.DialerOptions{ - Detour: options.Detour, - }, - RemoteIsDomain: true, - }) + err := validateOCMOptions(options) if err != nil { - return nil, E.Cause(err, "create dialer") - } - - httpClient := &http.Client{ - Transport: &http.Transport{ - ForceAttemptHTTP2: true, - TLSClientConfig: &stdTLS.Config{ - RootCAs: adapter.RootPoolFromContext(ctx), - Time: ntp.TimeFuncFromContext(ctx), - }, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return serviceDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) - }, - }, + return nil, E.Cause(err, "validate options") } userManager := &UserManager{ tokenMap: make(map[string]string), } - var usageTracker *AggregatedUsage - if options.UsagesPath != "" { - usageTracker = &AggregatedUsage{ - LastUpdated: time.Now(), - Combinations: make([]CostCombination, 0), - filePath: options.UsagesPath, - logger: logger, - } - } - service := &Service{ - Adapter: boxService.NewAdapter(C.TypeOCM, tag), - ctx: ctx, - logger: logger, - credentialPath: options.CredentialPath, - users: options.Users, - dialer: serviceDialer, - httpClient: httpClient, - httpHeaders: options.Headers.Build(), + Adapter: boxService.NewAdapter(C.TypeOCM, tag), + ctx: ctx, + logger: logger, + options: options, + httpHeaders: options.Headers.Build(), listener: listener.New(listener.Options{ Context: ctx, Logger: logger, @@ -200,10 +226,36 @@ func 
NewService(ctx context.Context, logger log.ContextLogger, tag string, optio Listen: options.ListenOptions, }), userManager: userManager, - usageTracker: usageTracker, webSocketConns: make(map[*webSocketSession]struct{}), } + if len(options.Credentials) > 0 { + providers, allDefaults, err := buildOCMCredentialProviders(ctx, options, logger) + if err != nil { + return nil, E.Cause(err, "build credential providers") + } + service.providers = providers + service.allDefaults = allDefaults + + userCredentialMap := make(map[string]string) + for _, user := range options.Users { + userCredentialMap[user.Name] = user.Credential + } + service.userCredentialMap = userCredentialMap + } else { + credential, err := newDefaultCredential(ctx, "default", option.OCMDefaultCredentialOptions{ + CredentialPath: options.CredentialPath, + UsagesPath: options.UsagesPath, + Detour: options.Detour, + }, logger) + if err != nil { + return nil, err + } + service.legacyCredential = credential + service.legacyProvider = &singleCredentialProvider{credential: credential} + service.allDefaults = []*defaultCredential{credential} + } + if options.TLS != nil { tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS)) if err != nil { @@ -220,18 +272,22 @@ func (s *Service) Start(stage adapter.StartStage) error { return nil } - s.userManager.UpdateUsers(s.users) + s.userManager.UpdateUsers(s.options.Users) - credentials, err := platformReadCredentials(s.credentialPath) - if err != nil { - return E.Cause(err, "read credentials") - } - s.credentials = credentials - - if s.usageTracker != nil { - err = s.usageTracker.Load() + for _, credential := range s.allDefaults { + err := credential.start() if err != nil { - s.logger.Warn("load usage statistics: ", err) + return err + } + tag := credential.tag + credential.onBecameUnusable = func() { + s.interruptWebSocketSessionsForCredential(tag) + } + } + if len(s.options.Credentials) > 0 { + err := 
validateOCMCompositeCredentialModes(s.options, s.providers) + if err != nil { + return E.Cause(err, "validate loaded credentials") } } @@ -241,7 +297,7 @@ func (s *Service) Start(stage adapter.StartStage) error { s.httpServer = &http.Server{Handler: router} if s.tlsConfig != nil { - err = s.tlsConfig.Start() + err := s.tlsConfig.Start() if err != nil { return E.Cause(err, "create TLS config") } @@ -269,54 +325,15 @@ func (s *Service) Start(stage adapter.StartStage) error { return nil } -func (s *Service) getAccessToken() (string, error) { - s.accessMutex.RLock() - if !s.credentials.needsRefresh() { - token := s.credentials.getAccessToken() - s.accessMutex.RUnlock() - return token, nil +func (s *Service) resolveCredentialProvider(username string) (credentialProvider, error) { + if len(s.options.Users) > 0 { + return credentialForUser(s.userCredentialMap, s.providers, s.legacyProvider, username) } - s.accessMutex.RUnlock() - - s.accessMutex.Lock() - defer s.accessMutex.Unlock() - - if !s.credentials.needsRefresh() { - return s.credentials.getAccessToken(), nil + provider := noUserCredentialProvider(s.providers, s.legacyProvider, s.options) + if provider == nil { + return nil, E.New("no credential available") } - - newCredentials, err := refreshToken(s.httpClient, s.credentials) - if err != nil { - return "", err - } - - s.credentials = newCredentials - - err = platformWriteCredentials(newCredentials, s.credentialPath) - if err != nil { - s.logger.Warn("persist refreshed token: ", err) - } - - return newCredentials.getAccessToken(), nil -} - -func (s *Service) getAccountID() string { - s.accessMutex.RLock() - defer s.accessMutex.RUnlock() - return s.credentials.getAccountID() -} - -func (s *Service) isAPIKeyMode() bool { - s.accessMutex.RLock() - defer s.accessMutex.RUnlock() - return s.credentials.isAPIKeyMode() -} - -func (s *Service) getBaseURL() string { - if s.isAPIKeyMode() { - return openaiAPIBaseURL - } - return chatGPTBackendURL + return provider, nil } func 
(s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -326,20 +343,8 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - var proxyPath string - if s.isAPIKeyMode() { - proxyPath = path - } else { - if path == "/v1/chat/completions" { - writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", - "chat completions endpoint is only available in API key mode") - return - } - proxyPath = strings.TrimPrefix(path, "/v1") - } - var username string - if len(s.users) > 0 { + if len(s.options.Users) > 0 { authHeader := r.Header.Get("Authorization") if authHeader == "" { s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": missing Authorization header") @@ -361,39 +366,91 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } - if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") { - s.handleWebSocket(w, r, proxyPath, username) + sessionID := r.Header.Get("session_id") + + // Resolve credential provider + provider, err := s.resolveCredentialProvider(username) + if err != nil { + s.logger.Error("resolve credential: ", err) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) return } - var requestModel string + provider.pollIfStale(s.ctx) - if s.usageTracker != nil && r.Body != nil { - bodyBytes, err := io.ReadAll(r.Body) - if err == nil { - var request struct { - Model string `json:"model"` - } - err := json.Unmarshal(bodyBytes, &request) - if err == nil { - requestModel = request.Model - } - r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + credential, isNew, err := provider.selectCredential(sessionID) + if err != nil { + writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, err.Error())) + return + } + if isNew { + if username != "" { + s.logger.Debug("assigned credential ", credential.tag, " for session ", sessionID, " by user ", username) + } else { + s.logger.Debug("assigned 
credential ", credential.tag, " for session ", sessionID) } } - accessToken, err := s.getAccessToken() + if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") { + s.handleWebSocket(w, r, path, username, sessionID, provider, credential) + return + } + + var proxyPath string + if credential.isAPIKeyMode() { + proxyPath = path + } else { + if path == "/v1/chat/completions" { + writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", + "chat completions endpoint is only available in API key mode") + return + } + proxyPath = strings.TrimPrefix(path, "/v1") + } + + shouldTrackUsage := credential.usageTracker != nil && + (path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) + canRetryRequest := len(provider.allDefaults()) > 1 + + // Read body for model extraction and retry buffer when JSON replay is useful. + var bodyBytes []byte + var requestModel string + if r.Body != nil && (shouldTrackUsage || canRetryRequest) { + mediaType, _, parseErr := mime.ParseMediaType(r.Header.Get("Content-Type")) + isJSONRequest := parseErr == nil && (mediaType == "application/json" || strings.HasSuffix(mediaType, "+json")) + if isJSONRequest { + bodyBytes, err = io.ReadAll(r.Body) + if err != nil { + s.logger.Error("read request body: ", err) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body") + return + } + var request struct { + Model string `json:"model"` + } + if json.Unmarshal(bodyBytes, &request) == nil { + requestModel = request.Model + } + r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + } + + accessToken, err := credential.getAccessToken() if err != nil { s.logger.Error("get access token: ", err) writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "Authentication failed") return } - proxyURL := s.getBaseURL() + proxyPath + proxyURL := credential.getBaseURL() + proxyPath if r.URL.RawQuery != "" { proxyURL += "?" 
+ r.URL.RawQuery } - proxyRequest, err := http.NewRequestWithContext(r.Context(), r.Method, proxyURL, r.Body) + requestContext := credential.wrapRequestContext(r.Context()) + defer func() { + requestContext.cancelRequest() + }() + proxyRequest, err := http.NewRequestWithContext(requestContext, r.Method, proxyURL, r.Body) if err != nil { s.logger.Error("create proxy request: ", err) writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error") @@ -413,17 +470,68 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { proxyRequest.Header.Set("Authorization", "Bearer "+accessToken) - if accountID := s.getAccountID(); accountID != "" { + if accountID := credential.getAccountID(); accountID != "" { proxyRequest.Header.Set("ChatGPT-Account-Id", accountID) } - response, err := s.httpClient.Do(proxyRequest) + response, err := credential.httpClient.Do(proxyRequest) if err != nil { + if r.Context().Err() != nil { + return + } + if requestContext.Err() != nil { + writeCredentialUnavailableError(w, r, provider, credential, "credential became unavailable while processing the request") + return + } writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error()) return } + requestContext.releaseCredentialInterrupt() + + // Transparent 429 retry + for response.StatusCode == http.StatusTooManyRequests { + resetAt := parseOCMRateLimitResetFromHeaders(response.Header) + nextCredential := provider.onRateLimited(sessionID, credential, resetAt) + needsBodyReplay := r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodDelete + credential.updateStateFromHeaders(response.Header) + if (needsBodyReplay && bodyBytes == nil) || nextCredential == nil { + response.Body.Close() + writeCredentialUnavailableError(w, r, provider, credential, "all credentials rate-limited") + return + } + response.Body.Close() + s.logger.Info("retrying with credential ", nextCredential.tag, " after 429 from ", credential.tag) + 
requestContext.cancelRequest() + requestContext = nextCredential.wrapRequestContext(r.Context()) + retryResponse, retryErr := retryOCMRequestWithBody(requestContext, r, bodyBytes, nextCredential, s.httpHeaders) + if retryErr != nil { + if r.Context().Err() != nil { + return + } + if requestContext.Err() != nil { + writeCredentialUnavailableError(w, r, provider, nextCredential, "credential became unavailable while retrying the request") + return + } + s.logger.Error("retry request: ", retryErr) + writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error()) + return + } + requestContext.releaseCredentialInterrupt() + response = retryResponse + credential = nextCredential + } defer response.Body.Close() + credential.updateStateFromHeaders(response.Header) + + if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests { + body, _ := io.ReadAll(response.Body) + s.logger.Error("upstream error from ", credential.tag, ": status ", response.StatusCode, " ", string(body)) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", + "proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body)) + return + } + for key, values := range response.Header { if !isHopByHopHeader(key) { w.Header()[key] = values @@ -431,10 +539,10 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } w.WriteHeader(response.StatusCode) - trackUsage := s.usageTracker != nil && response.StatusCode == http.StatusOK && - (path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) - if trackUsage { - s.handleResponseWithTracking(w, response, path, requestModel, username) + hasUsageTracker := credential.usageTracker != nil + if hasUsageTracker && response.StatusCode == http.StatusOK && + (path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) { + s.handleResponseWithTracking(w, response, credential.usageTracker, path, requestModel, username) } else { mediaType, _, err := 
mime.ParseMediaType(response.Header.Get("Content-Type")) if err == nil && mediaType != "text/event-stream" { @@ -464,7 +572,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, path string, requestModel string, username string) { +func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, path string, requestModel string, username string) { isChatCompletions := path == "/v1/chat/completions" weeklyCycleHint := extractWeeklyCycleHint(response.Header) mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) @@ -508,7 +616,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons } if responseModel != "" { contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens) - s.usageTracker.AddUsageWithCycleHint( + usageTracker.AddUsageWithCycleHint( responseModel, contextWindow, inputTokens, @@ -619,7 +727,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons if inputTokens > 0 || outputTokens > 0 { if responseModel != "" { contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens) - s.usageTracker.AddUsageWithCycleHint( + usageTracker.AddUsageWithCycleHint( responseModel, contextWindow, inputTokens, @@ -650,12 +758,8 @@ func (s *Service) Close() error { } s.webSocketGroup.Wait() - if s.usageTracker != nil { - s.usageTracker.cancelPendingSave() - saveErr := s.usageTracker.Save() - if saveErr != nil { - s.logger.Error("save usage statistics: ", saveErr) - } + for _, credential := range s.allDefaults { + credential.close() } return err @@ -693,6 +797,20 @@ func (s *Service) isShuttingDown() bool { return s.shuttingDown } +func (s *Service) interruptWebSocketSessionsForCredential(tag string) { + s.webSocketMutex.Lock() + var toClose []*webSocketSession + for session := 
range s.webSocketConns { + if session.credentialTag == tag { + toClose = append(toClose, session) + } + } + s.webSocketMutex.Unlock() + for _, session := range toClose { + session.Close() + } +} + func (s *Service) startWebSocketShutdown() []*webSocketSession { s.webSocketMutex.Lock() defer s.webSocketMutex.Unlock() diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index d19f2df81..eafd37aae 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -1,12 +1,14 @@ package ocm import ( + "bufio" "context" stdTLS "crypto/tls" "encoding/json" "io" "net" "net/http" + "net/textproto" "strings" "sync" "time" @@ -22,9 +24,10 @@ import ( ) type webSocketSession struct { - clientConn net.Conn - upstreamConn net.Conn - closeOnce sync.Once + clientConn net.Conn + upstreamConn net.Conn + credentialTag string + closeOnce sync.Once } func (s *webSocketSession) Close() { @@ -76,57 +79,113 @@ func isForwardableWebSocketRequestHeader(key string) bool { } } -func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyPath string, username string) { - accessToken, err := s.getAccessToken() - if err != nil { - s.logger.Error("get access token for websocket: ", err) - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "authentication failed") - return - } +func (s *Service) handleWebSocket( + w http.ResponseWriter, + r *http.Request, + path string, + username string, + sessionID string, + provider credentialProvider, + credential *defaultCredential, +) { + var ( + err error + upstreamConn net.Conn + upstreamBufferedReader *bufio.Reader + upstreamResponseHeaders http.Header + statusCode int + ) - upstreamURL := buildUpstreamWebSocketURL(s.getBaseURL(), proxyPath) - if r.URL.RawQuery != "" { - upstreamURL += "?" 
+ r.URL.RawQuery - } + for { + accessToken, accessErr := credential.getAccessToken() + if accessErr != nil { + s.logger.Error("get access token for websocket: ", accessErr) + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "authentication failed") + return + } - upstreamHeaders := make(http.Header) - for key, values := range r.Header { - if isForwardableWebSocketRequestHeader(key) { + var proxyPath string + if credential.isAPIKeyMode() { + proxyPath = path + } else { + proxyPath = strings.TrimPrefix(path, "/v1") + } + + upstreamURL := buildUpstreamWebSocketURL(credential.getBaseURL(), proxyPath) + if r.URL.RawQuery != "" { + upstreamURL += "?" + r.URL.RawQuery + } + + upstreamHeaders := make(http.Header) + for key, values := range r.Header { + if isForwardableWebSocketRequestHeader(key) { + upstreamHeaders[key] = values + } + } + for key, values := range s.httpHeaders { + upstreamHeaders.Del(key) upstreamHeaders[key] = values } - } - for key, values := range s.httpHeaders { - upstreamHeaders.Del(key) - upstreamHeaders[key] = values - } - upstreamHeaders.Set("Authorization", "Bearer "+accessToken) - if accountID := s.getAccountID(); accountID != "" { - upstreamHeaders.Set("ChatGPT-Account-Id", accountID) - } + upstreamHeaders.Set("Authorization", "Bearer "+accessToken) + if accountID := credential.getAccountID(); accountID != "" { + upstreamHeaders.Set("ChatGPT-Account-Id", accountID) + } - upstreamResponseHeaders := make(http.Header) - upstreamDialer := ws.Dialer{ - NetDial: func(ctx context.Context, network, addr string) (net.Conn, error) { - return s.dialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) - }, - TLSConfig: &stdTLS.Config{ - RootCAs: adapter.RootPoolFromContext(s.ctx), - Time: ntp.TimeFuncFromContext(s.ctx), - }, - Header: ws.HandshakeHeaderHTTP(upstreamHeaders), - OnHeader: func(key, value []byte) error { - upstreamResponseHeaders.Add(string(key), string(value)) - return nil - }, - } + upstreamResponseHeaders = 
make(http.Header) + statusCode = 0 + upstreamDialer := ws.Dialer{ + NetDial: func(ctx context.Context, network, addr string) (net.Conn, error) { + return credential.dialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + }, + TLSConfig: &stdTLS.Config{ + RootCAs: adapter.RootPoolFromContext(s.ctx), + Time: ntp.TimeFuncFromContext(s.ctx), + }, + Header: ws.HandshakeHeaderHTTP(upstreamHeaders), + // gobwas/ws@v1.4.0: the response io.Reader is + // MultiReader(statusLine_without_CRLF, "\r\n", bufferedConn). + // ReadString('\n') consumes the status line, then ReadMIMEHeader + // parses the remaining headers. + OnStatusError: func(status int, reason []byte, response io.Reader) { + statusCode = status + bufferedResponse := bufio.NewReader(response) + _, readErr := bufferedResponse.ReadString('\n') + if readErr != nil { + return + } + mimeHeader, readErr := textproto.NewReader(bufferedResponse).ReadMIMEHeader() + if readErr == nil { + upstreamResponseHeaders = http.Header(mimeHeader) + } + }, + OnHeader: func(key, value []byte) error { + upstreamResponseHeaders.Add(string(key), string(value)) + return nil + }, + } - upstreamConn, upstreamBufferedReader, _, err := upstreamDialer.Dial(r.Context(), upstreamURL) - if err != nil { + upstreamConn, upstreamBufferedReader, _, err = upstreamDialer.Dial(s.ctx, upstreamURL) + if err == nil { + break + } + if statusCode == http.StatusTooManyRequests { + resetAt := parseOCMRateLimitResetFromHeaders(upstreamResponseHeaders) + nextCredential := provider.onRateLimited(sessionID, credential, resetAt) + if nextCredential == nil { + credential.updateStateFromHeaders(upstreamResponseHeaders) + writeCredentialUnavailableError(w, r, provider, credential, "all credentials rate-limited") + return + } + s.logger.Info("retrying websocket with credential ", nextCredential.tag, " after 429 from ", credential.tag) + credential = nextCredential + continue + } s.logger.Error("dial upstream websocket: ", err) writeJSONError(w, r, 
http.StatusBadGateway, "api_error", "upstream websocket connection failed") return } + credential.updateStateFromHeaders(upstreamResponseHeaders) weeklyCycleHint := extractWeeklyCycleHint(upstreamResponseHeaders) clientResponseHeaders := make(http.Header) @@ -151,8 +210,9 @@ func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyP return } session := &webSocketSession{ - clientConn: clientConn, - upstreamConn: upstreamConn, + clientConn: clientConn, + upstreamConn: upstreamConn, + credentialTag: credential.tag, } if !s.registerWebSocketSession(session) { session.Close() @@ -177,17 +237,17 @@ func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyP go func() { defer waitGroup.Done() defer session.Close() - s.proxyWebSocketClientToUpstream(clientConn, upstreamConn, modelChannel) + s.proxyWebSocketClientToUpstream(clientConn, upstreamConn, credential, modelChannel) }() go func() { defer waitGroup.Done() defer session.Close() - s.proxyWebSocketUpstreamToClient(upstreamReadWriter, clientConn, modelChannel, username, weeklyCycleHint) + s.proxyWebSocketUpstreamToClient(upstreamReadWriter, clientConn, credential, modelChannel, username, weeklyCycleHint) }() waitGroup.Wait() } -func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamConn net.Conn, modelChannel chan<- string) { +func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamConn net.Conn, credential *defaultCredential, modelChannel chan<- string) { for { data, opCode, err := wsutil.ReadClientData(clientConn) if err != nil { @@ -197,7 +257,7 @@ func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamCo return } - if opCode == ws.OpText && s.usageTracker != nil { + if opCode == ws.OpText && credential.usageTracker != nil { var request struct { Type string `json:"type"` Model string `json:"model"` @@ -220,7 +280,7 @@ func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamCo } } -func 
(s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWriter, clientConn net.Conn, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) { +func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWriter, clientConn net.Conn, credential *defaultCredential, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) { var requestModel string for { data, opCode, err := wsutil.ReadServerData(upstreamReadWriter) @@ -231,7 +291,7 @@ func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWrite return } - if opCode == ws.OpText && s.usageTracker != nil { + if opCode == ws.OpText && credential.usageTracker != nil { select { case model := <-modelChannel: requestModel = model @@ -257,7 +317,7 @@ func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWrite } if responseModel != "" { contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens) - s.usageTracker.AddUsageWithCycleHint( + credential.usageTracker.AddUsageWithCycleHint( responseModel, contextWindow, inputTokens,