diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 635642b74..ac4b9ec49 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -190,7 +190,7 @@ jobs: - name: Set build tags run: | set -xeuo pipefail - TAGS='with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,badlinkname,tfogo_checklinkname0' + TAGS='with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,with_ocm,badlinkname,tfogo_checklinkname0' if [[ "${{ matrix.naive }}" == "true" ]]; then TAGS="${TAGS},with_naive_outbound" fi @@ -427,7 +427,7 @@ jobs: - name: Set build tags run: | set -xeuo pipefail - TAGS='with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,badlinkname,tfogo_checklinkname0' + TAGS='with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,with_ocm,badlinkname,tfogo_checklinkname0' if [[ "${{ matrix.legacy_go124 }}" != "true" ]]; then TAGS="${TAGS},with_naive_outbound" fi @@ -495,7 +495,7 @@ jobs: - name: Build run: | mkdir -p dist - go build -v -trimpath -o dist/sing-box.exe -tags "with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,with_naive_outbound,with_purego,badlinkname,tfogo_checklinkname0" ` + go build -v -trimpath -o dist/sing-box.exe -tags "with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,with_ocm,with_naive_outbound,with_purego,badlinkname,tfogo_checklinkname0" ` -ldflags "-s -buildid= -X github.com/sagernet/sing-box/constant.Version=${{ needs.calculate_version.outputs.version }} -checklinkname=0" ` ./cmd/sing-box env: @@ -885,6 +885,16 @@ jobs: with: path: dist merge-multiple: true + - name: Generate SFA version metadata + run: |- + VERSION_CODE=$(grep VERSION_CODE clients/android/version.properties | cut -d= -f2) + cat > dist/SFA-version-metadata.json << EOF + { + "version_code": ${VERSION_CODE}, + "version_name": "${VERSION}" + } + EOF + cat dist/SFA-version-metadata.json - name: Upload builds if: ${{ env.PUBLISHED == 'false' }} run: |- diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index e4041b5b0..5447457e2 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -93,7 +93,7 @@ jobs: - name: Set build tags run: | set -xeuo pipefail - TAGS='with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,badlinkname,tfogo_checklinkname0' + TAGS='with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,with_ocm,badlinkname,tfogo_checklinkname0' if [[ "${{ matrix.naive }}" == "true" ]]; then TAGS="${TAGS},with_naive_outbound,with_musl" fi diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index f755f2df4..5b60f4e7d 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -116,7 +116,7 @@ jobs: - name: Set build tags run: | set -xeuo pipefail - TAGS='with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,badlinkname,tfogo_checklinkname0' + TAGS='with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,with_ocm,badlinkname,tfogo_checklinkname0' if [[ "${{ matrix.naive }}" == "true" ]]; then TAGS="${TAGS},with_naive_outbound,with_musl" fi diff --git a/Dockerfile b/Dockerfile index 5162d4613..fb39e8b60 100644 --- a/Dockerfile +++ b/Dockerfile @@ -13,7 +13,7 @@ RUN set -ex \ && export COMMIT=$(git rev-parse --short HEAD) \ && export VERSION=$(go run ./cmd/internal/read_tag) \ && go build -v -trimpath -tags \ - "with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,badlinkname,tfogo_checklinkname0" \ + "with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,with_ocm,badlinkname,tfogo_checklinkname0" \ -o /go/bin/sing-box \ -ldflags "-X \"github.com/sagernet/sing-box/constant.Version=$VERSION\" -s -w -buildid= -checklinkname=0" \ ./cmd/sing-box diff --git a/Makefile b/Makefile index c5e2d8ff8..bd70837eb 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ NAME = sing-box COMMIT = $(shell git rev-parse --short HEAD) -TAGS ?= with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,badlinkname,tfogo_checklinkname0 +TAGS ?= with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,with_ocm,badlinkname,tfogo_checklinkname0 GOHOSTOS = $(shell go env GOHOSTOS) GOHOSTARCH = $(shell go env GOHOSTARCH) diff --git a/constant/proxy.go b/constant/proxy.go index a54a3a75d..a51936234 100644 --- a/constant/proxy.go +++ b/constant/proxy.go @@ -29,6 +29,7 @@ const ( TypeResolved = "resolved" TypeSSMAPI = "ssm-api" TypeCCM = "ccm" + TypeOCM = "ocm" ) const ( diff --git a/docs/configuration/service/index.md b/docs/configuration/service/index.md index 2bd1a4a3f..de3583b2b 100644 --- a/docs/configuration/service/index.md +++ b/docs/configuration/service/index.md @@ -25,6 +25,7 @@ icon: material/new-box |------------|------------------------| | `ccm` | [CCM](./ccm) | | `derp` | [DERP](./derp) | +| `ocm` | [OCM](./ocm) | | `resolved` | [Resolved](./resolved) | | `ssm-api` | [SSM API](./ssm-api) | diff --git a/docs/configuration/service/index.zh.md b/docs/configuration/service/index.zh.md index b4a73eda9..a0d18cbba 100644 --- a/docs/configuration/service/index.zh.md +++ b/docs/configuration/service/index.zh.md @@ -25,6 +25,7 @@ icon: material/new-box |-----------|------------------------| | `ccm` | [CCM](./ccm) | | `derp` | [DERP](./derp) | +| `ocm` | [OCM](./ocm) | | `resolved`| [Resolved](./resolved) | | `ssm-api` | [SSM API](./ssm-api) | diff --git a/docs/configuration/service/ocm.md b/docs/configuration/service/ocm.md new file mode 100644 index 000000000..59dba7daa --- /dev/null +++ b/docs/configuration/service/ocm.md @@ -0,0 +1,171 @@ +--- +icon: material/new-box +--- + +!!! question "Since sing-box 1.13.0" + +# OCM + +OCM (OpenAI Codex Multiplexer) service is a multiplexing service that allows you to access your local OpenAI Codex subscription remotely through custom tokens. + +It handles OAuth authentication with OpenAI's API on your local machine while allowing remote clients to authenticate using custom tokens. + +### Structure + +```json +{ + "type": "ocm", + + ... // Listen Fields + + "credential_path": "", + "usages_path": "", + "users": [], + "headers": {}, + "detour": "", + "tls": {} +} +``` + +### Listen Fields + +See [Listen Fields](/configuration/shared/listen/) for details. + +### Fields + +#### credential_path + +Path to the OpenAI OAuth credentials file. + +If not specified, defaults to `~/.codex/auth.json`. + +Refreshed tokens are automatically written back to the same location. + +#### usages_path + +Path to the file for storing aggregated API usage statistics. + +Usage tracking is disabled if not specified. + +When enabled, the service tracks and saves comprehensive statistics including: +- Request counts +- Token usage (input, output, cached) +- Calculated costs in USD based on OpenAI API pricing + +Statistics are organized by model and optionally by user when authentication is enabled. + +The statistics file is automatically saved every minute and upon service shutdown. + +#### users + +List of authorized users for token authentication. + +If empty, no authentication is required. + +Object format: + +```json +{ + "name": "", + "token": "" +} +``` + +Object fields: + +- `name`: Username identifier for tracking purposes. +- `token`: Bearer token for authentication. Clients authenticate by setting the `Authorization: Bearer ` header. + +#### headers + +Custom HTTP headers to send to the OpenAI API. + +These headers will override any existing headers with the same name. + +#### detour + +Outbound tag for connecting to the OpenAI API. + +#### tls + +TLS configuration, see [TLS](/configuration/shared/tls/#inbound). + +### Example + +#### Server + +```json +{ + "services": [ + { + "type": "ocm", + "listen": "127.0.0.1", + "listen_port": 8080 + } + ] +} +``` + +#### Client + +Add to `~/.codex/config.toml`: + +```toml +[model_providers.ocm] +name = "OCM Proxy" +base_url = "http://127.0.0.1:8080/v1" +wire_api = "responses" +requires_openai_auth = false +``` + +Then run: + +```bash +codex --model-provider ocm +``` + +### Example with Authentication + +#### Server + +```json +{ + "services": [ + { + "type": "ocm", + "listen": "0.0.0.0", + "listen_port": 8080, + "usages_path": "./codex-usages.json", + "users": [ + { + "name": "alice", + "token": "sk-alice-secret-token" + }, + { + "name": "bob", + "token": "sk-bob-secret-token" + } + ] + } + ] +} +``` + +#### Client + +Add to `~/.codex/config.toml`: + +```toml +[model_providers.ocm] +name = "OCM Proxy" +base_url = "http://127.0.0.1:8080/v1" +wire_api = "responses" +requires_openai_auth = false +experimental_bearer_token = "sk-alice-secret-token" +``` + +Then run: + +```bash +codex --model-provider ocm +``` diff --git a/docs/configuration/service/ocm.zh.md b/docs/configuration/service/ocm.zh.md new file mode 100644 index 000000000..ee1d85101 --- /dev/null +++ b/docs/configuration/service/ocm.zh.md @@ -0,0 +1,171 @@ +--- +icon: material/new-box +--- + +!!! question "自 sing-box 1.13.0 起" + +# OCM + +OCM(OpenAI Codex 多路复用器)服务是一个多路复用服务,允许您通过自定义令牌远程访问本地的 OpenAI Codex 订阅。 + +它在本地机器上处理与 OpenAI API 的 OAuth 身份验证,同时允许远程客户端使用自定义令牌进行身份验证。 + +### 结构 + +```json +{ + "type": "ocm", + + ... // 监听字段 + + "credential_path": "", + "usages_path": "", + "users": [], + "headers": {}, + "detour": "", + "tls": {} +} +``` + +### 监听字段 + +参阅 [监听字段](/zh/configuration/shared/listen/) 了解详情。 + +### 字段 + +#### credential_path + +OpenAI OAuth 凭据文件的路径。 + +如果未指定,默认值为 `~/.codex/auth.json`。 + +刷新的令牌会自动写回相同位置。 + +#### usages_path + +用于存储聚合 API 使用统计信息的文件路径。 + +如果未指定,使用跟踪将被禁用。 + +启用后,服务会跟踪并保存全面的统计信息,包括: +- 请求计数 +- 令牌使用量(输入、输出、缓存) +- 基于 OpenAI API 定价计算的美元成本 + +统计信息按模型以及可选的用户(启用身份验证时)进行组织。 + +统计文件每分钟自动保存一次,并在服务关闭时保存。 + +#### users + +用于令牌身份验证的授权用户列表。 + +如果为空,则不需要身份验证。 + +对象格式: + +```json +{ + "name": "", + "token": "" +} +``` + +对象字段: + +- `name`:用于跟踪的用户名标识符。 +- `token`:用于身份验证的 Bearer 令牌。客户端通过设置 `Authorization: Bearer ` 头进行身份验证。 + +#### headers + +发送到 OpenAI API 的自定义 HTTP 头。 + +这些头会覆盖同名的现有头。 + +#### detour + +用于连接 OpenAI API 的出站标签。 + +#### tls + +TLS 配置,参阅 [TLS](/zh/configuration/shared/tls/#inbound)。 + +### 示例 + +#### 服务端 + +```json +{ + "services": [ + { + "type": "ocm", + "listen": "127.0.0.1", + "listen_port": 8080 + } + ] +} +``` + +#### 客户端 + +在 `~/.codex/config.toml` 中添加: + +```toml +[model_providers.ocm] +name = "OCM Proxy" +base_url = "http://127.0.0.1:8080/v1" +wire_api = "responses" +requires_openai_auth = false +``` + +然后运行: + +```bash +codex --model-provider ocm +``` + +### 带身份验证的示例 + +#### 服务端 + +```json +{ + "services": [ + { + "type": "ocm", + "listen": "0.0.0.0", + "listen_port": 8080, + "usages_path": "./codex-usages.json", + "users": [ + { + "name": "alice", + "token": "sk-alice-secret-token" + }, + { + "name": "bob", + "token": "sk-bob-secret-token" + } + ] + } + ] +} +``` + +#### 客户端 + +在 `~/.codex/config.toml` 中添加: + +```toml +[model_providers.ocm] +name = "OCM Proxy" +base_url = "http://127.0.0.1:8080/v1" +wire_api = "responses" +requires_openai_auth = false +experimental_bearer_token = "sk-alice-secret-token" +``` + +然后运行: + +```bash +codex --model-provider ocm +``` diff --git a/go.mod b/go.mod index a0b13a0a0..61d8deef9 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/metacubex/utls v1.8.4 github.com/mholt/acmez/v3 v3.1.2 github.com/miekg/dns v1.1.67 + github.com/openai/openai-go/v3 v3.13.0 github.com/oschwald/maxminddb-golang v1.13.1 github.com/sagernet/asc-go v0.0.0-20241217030726-d563060fe4e1 github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a diff --git a/go.sum b/go.sum index 84efb47ee..789fa8d80 100644 --- a/go.sum +++ b/go.sum @@ -131,6 +131,8 @@ github.com/mitchellh/go-ps v1.0.0 h1:i6ampVEEF4wQFF+bkYfwYgY+F/uYJDktmvLPf7qIgjc github.com/mitchellh/go-ps v1.0.0/go.mod h1:J4lOc8z8yJs6vUwklHw2XEIiT4z4C40KtWVN3nvg8Pg= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= +github.com/openai/openai-go/v3 v3.13.0 h1:arSFmVHcBHNVYG5iqspPJrLoin0Qqn2JcCLWWcTcM1Q= +github.com/openai/openai-go/v3 v3.13.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo= github.com/oschwald/maxminddb-golang v1.13.1 h1:G3wwjdN9JmIK2o/ermkHM+98oX5fS+k5MbwsmL4MRQE= github.com/oschwald/maxminddb-golang v1.13.1/go.mod h1:K4pgV9N/GcK694KSTmVSDTODk4IsCNThNdTmnaBZ/F8= github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= diff --git a/include/ocm.go b/include/ocm.go new file mode 100644 index 000000000..cdea9eeae --- /dev/null +++ b/include/ocm.go @@ -0,0 +1,12 @@ +//go:build with_ocm + +package include + +import ( + "github.com/sagernet/sing-box/adapter/service" + "github.com/sagernet/sing-box/service/ocm" +) + +func registerOCMService(registry *service.Registry) { + ocm.RegisterService(registry) +} diff --git a/include/ocm_stub.go b/include/ocm_stub.go new file mode 100644 index 000000000..d5a94fcba --- /dev/null +++ b/include/ocm_stub.go @@ -0,0 +1,20 @@ +//go:build !with_ocm + +package include + +import ( + "context" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/service" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" +) + +func registerOCMService(registry *service.Registry) { + service.Register[option.OCMServiceOptions](registry, C.TypeOCM, func(ctx context.Context, logger log.ContextLogger, tag string, options option.OCMServiceOptions) (adapter.Service, error) { + return nil, E.New(`OCM is not included in this build, rebuild with -tags with_ocm`) + }) +} diff --git a/include/registry.go b/include/registry.go index 8f08189d4..d909b8500 100644 --- a/include/registry.go +++ b/include/registry.go @@ -136,6 +136,7 @@ func ServiceRegistry() *service.Registry { registerDERPService(registry) registerCCMService(registry) + registerOCMService(registry) return registry } diff --git a/mkdocs.yml b/mkdocs.yml index c49bfa2a6..a505f3e4d 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -176,6 +176,8 @@ nav: - DERP: configuration/service/derp.md - Resolved: configuration/service/resolved.md - SSM API: configuration/service/ssm-api.md + - CCM: configuration/service/ccm.md + - OCM: configuration/service/ocm.md markdown_extensions: - pymdownx.inlinehilite - pymdownx.snippets diff --git a/option/ocm.go b/option/ocm.go new file mode 100644 index 000000000..c13a1c1f5 --- /dev/null +++ b/option/ocm.go @@ -0,0 +1,20 @@ +package option + +import ( + "github.com/sagernet/sing/common/json/badoption" +) + +type OCMServiceOptions struct { + ListenOptions + InboundTLSOptionsContainer + CredentialPath string `json:"credential_path,omitempty"` + Users []OCMUser `json:"users,omitempty"` + Headers badoption.HTTPHeader `json:"headers,omitempty"` + Detour string `json:"detour,omitempty"` + UsagesPath string `json:"usages_path,omitempty"` +} + +type OCMUser struct { + Name string `json:"name,omitempty"` + Token string `json:"token,omitempty"` +} diff --git a/release/local/common.sh b/release/local/common.sh index d24bba475..68a494bab 100755 --- a/release/local/common.sh +++ b/release/local/common.sh @@ -11,7 +11,7 @@ INSTALL_CONFIG_PATH="/usr/local/etc/sing-box" INSTALL_DATA_PATH="/var/lib/sing-box" SYSTEMD_SERVICE_PATH="/etc/systemd/system" -DEFAULT_BUILD_TAGS="with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,badlinkname,tfogo_checklinkname0" +DEFAULT_BUILD_TAGS="with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,with_ocm,badlinkname,tfogo_checklinkname0" setup_environment() { if [ -d /usr/local/go ]; then diff --git a/service/ocm/credential.go b/service/ocm/credential.go new file mode 100644 index 000000000..76651a8e1 --- /dev/null +++ b/service/ocm/credential.go @@ -0,0 +1,173 @@ +package ocm + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "os" + "os/user" + "path/filepath" + "time" + + E "github.com/sagernet/sing/common/exceptions" +) + +const ( + oauth2ClientID = "app_EMoamEEZ73f0CkXaXp7hrann" + oauth2TokenURL = "https://auth.openai.com/oauth/token" + openaiAPIBaseURL = "https://api.openai.com" + chatGPTBackendURL = "https://chatgpt.com/backend-api/codex" + tokenRefreshIntervalDays = 8 +) + +func getRealUser() (*user.User, error) { + if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" { + sudoUserInfo, err := user.Lookup(sudoUser) + if err == nil { + return sudoUserInfo, nil + } + } + return user.Current() +} + +func getDefaultCredentialsPath() (string, error) { + if codexHome := os.Getenv("CODEX_HOME"); codexHome != "" { + return filepath.Join(codexHome, "auth.json"), nil + } + userInfo, err := getRealUser() + if err != nil { + return "", err + } + return filepath.Join(userInfo.HomeDir, ".codex", "auth.json"), nil +} + +func readCredentialsFromFile(path string) (*oauthCredentials, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + var credentials oauthCredentials + err = json.Unmarshal(data, &credentials) + if err != nil { + return nil, err + } + return &credentials, nil +} + +func writeCredentialsToFile(credentials *oauthCredentials, path string) error { + data, err := json.MarshalIndent(credentials, "", " ") + if err != nil { + return err + } + return os.WriteFile(path, data, 0o600) +} + +type oauthCredentials struct { + APIKey string `json:"OPENAI_API_KEY,omitempty"` + Tokens *tokenData `json:"tokens,omitempty"` + LastRefresh *time.Time `json:"last_refresh,omitempty"` +} + +type tokenData struct { + IDToken string `json:"id_token,omitempty"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + AccountID string `json:"account_id,omitempty"` +} + +func (c *oauthCredentials) isAPIKeyMode() bool { + return c.APIKey != "" +} + +func (c *oauthCredentials) getAccessToken() string { + if c.APIKey != "" { + return c.APIKey + } + if c.Tokens != nil { + return c.Tokens.AccessToken + } + return "" +} + +func (c *oauthCredentials) getAccountID() string { + if c.Tokens != nil { + return c.Tokens.AccountID + } + return "" +} + +func (c *oauthCredentials) needsRefresh() bool { + if c.APIKey != "" { + return false + } + if c.Tokens == nil || c.Tokens.RefreshToken == "" { + return false + } + if c.LastRefresh == nil { + return true + } + return time.Since(*c.LastRefresh) >= time.Duration(tokenRefreshIntervalDays)*24*time.Hour +} + +func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) { + if credentials.Tokens == nil || credentials.Tokens.RefreshToken == "" { + return nil, E.New("refresh token is empty") + } + + requestBody, err := json.Marshal(map[string]string{ + "grant_type": "refresh_token", + "refresh_token": credentials.Tokens.RefreshToken, + "client_id": oauth2ClientID, + "scope": "openid profile email", + }) + if err != nil { + return nil, E.Cause(err, "marshal request") + } + + request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody)) + if err != nil { + return nil, err + } + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Accept", "application/json") + + response, err := httpClient.Do(request) + if err != nil { + return nil, err + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + body, _ := io.ReadAll(response.Body) + return nil, E.New("refresh failed: ", response.Status, " ", string(body)) + } + + var tokenResponse struct { + IDToken string `json:"id_token"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + } + err = json.NewDecoder(response.Body).Decode(&tokenResponse) + if err != nil { + return nil, E.Cause(err, "decode response") + } + + newCredentials := *credentials + if newCredentials.Tokens == nil { + newCredentials.Tokens = &tokenData{} + } + if tokenResponse.IDToken != "" { + newCredentials.Tokens.IDToken = tokenResponse.IDToken + } + if tokenResponse.AccessToken != "" { + newCredentials.Tokens.AccessToken = tokenResponse.AccessToken + } + if tokenResponse.RefreshToken != "" { + newCredentials.Tokens.RefreshToken = tokenResponse.RefreshToken + } + now := time.Now() + newCredentials.LastRefresh = &now + + return &newCredentials, nil +} diff --git a/service/ocm/credential_darwin.go b/service/ocm/credential_darwin.go new file mode 100644 index 000000000..f3da2a63e --- /dev/null +++ b/service/ocm/credential_darwin.go @@ -0,0 +1,25 @@ +//go:build darwin + +package ocm + +func platformReadCredentials(customPath string) (*oauthCredentials, error) { + if customPath == "" { + var err error + customPath, err = getDefaultCredentialsPath() + if err != nil { + return nil, err + } + } + return readCredentialsFromFile(customPath) +} + +func platformWriteCredentials(credentials *oauthCredentials, customPath string) error { + if customPath == "" { + var err error + customPath, err = getDefaultCredentialsPath() + if err != nil { + return err + } + } + return writeCredentialsToFile(credentials, customPath) +} diff --git a/service/ocm/credential_other.go b/service/ocm/credential_other.go new file mode 100644 index 000000000..22dfd0337 --- /dev/null +++ b/service/ocm/credential_other.go @@ -0,0 +1,25 @@ +//go:build !darwin + +package ocm + +func platformReadCredentials(customPath string) (*oauthCredentials, error) { + if customPath == "" { + var err error + customPath, err = getDefaultCredentialsPath() + if err != nil { + return nil, err + } + } + return readCredentialsFromFile(customPath) +} + +func platformWriteCredentials(credentials *oauthCredentials, customPath string) error { + if customPath == "" { + var err error + customPath, err = getDefaultCredentialsPath() + if err != nil { + return err + } + } + return writeCredentialsToFile(credentials, customPath) +} diff --git a/service/ocm/service.go b/service/ocm/service.go new file mode 100644 index 000000000..e8f954105 --- /dev/null +++ b/service/ocm/service.go @@ -0,0 +1,555 @@ +package ocm + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "mime" + "net" + "net/http" + "strings" + "sync" + "time" + + "github.com/sagernet/sing-box/adapter" + boxService "github.com/sagernet/sing-box/adapter/service" + "github.com/sagernet/sing-box/common/dialer" + "github.com/sagernet/sing-box/common/listener" + "github.com/sagernet/sing-box/common/tls" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + aTLS "github.com/sagernet/sing/common/tls" + + "github.com/go-chi/chi/v5" + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/responses" + "golang.org/x/net/http2" +) + +func RegisterService(registry *boxService.Registry) { + boxService.Register[option.OCMServiceOptions](registry, C.TypeOCM, NewService) +} + +type errorResponse struct { + Error errorDetails `json:"error"` +} + +type errorDetails struct { + Type string `json:"type"` + Code string `json:"code,omitempty"` + Message string `json:"message"` +} + +func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, errorType string, message string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + + json.NewEncoder(w).Encode(errorResponse{ + Error: errorDetails{ + Type: errorType, + Message: message, + }, + }) +} + +func isHopByHopHeader(header string) bool { + switch strings.ToLower(header) { + case "connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade", "host": + return true + default: + return false + } +} + +type Service struct { + boxService.Adapter + ctx context.Context + logger log.ContextLogger + credentialPath string + credentials *oauthCredentials + users []option.OCMUser + httpClient *http.Client + httpHeaders http.Header + listener *listener.Listener + tlsConfig tls.ServerConfig + httpServer *http.Server + userManager *UserManager + accessMutex sync.RWMutex + usageTracker *AggregatedUsage + trackingGroup sync.WaitGroup + shuttingDown bool +} + +func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.OCMServiceOptions) (adapter.Service, error) { + serviceDialer, err := dialer.NewWithOptions(dialer.Options{ + Context: ctx, + Options: option.DialerOptions{ + Detour: options.Detour, + }, + RemoteIsDomain: true, + }) + if err != nil { + return nil, E.Cause(err, "create dialer") + } + + httpClient := &http.Client{ + Transport: &http.Transport{ + ForceAttemptHTTP2: true, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return serviceDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + }, + }, + } + + userManager := &UserManager{ + tokenMap: make(map[string]string), + } + + var usageTracker *AggregatedUsage + if options.UsagesPath != "" { + usageTracker = &AggregatedUsage{ + LastUpdated: time.Now(), + Combinations: make([]CostCombination, 0), + filePath: options.UsagesPath, + logger: logger, + } + } + + service := &Service{ + Adapter: boxService.NewAdapter(C.TypeOCM, tag), + ctx: ctx, + logger: logger, + credentialPath: options.CredentialPath, + users: options.Users, + httpClient: httpClient, + httpHeaders: options.Headers.Build(), + listener: listener.New(listener.Options{ + Context: ctx, + Logger: logger, + Network: []string{N.NetworkTCP}, + Listen: options.ListenOptions, + }), + userManager: userManager, + usageTracker: usageTracker, + } + + if options.TLS != nil { + tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS)) + if err != nil { + return nil, err + } + service.tlsConfig = tlsConfig + } + + return service, nil +} + +func (s *Service) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } + + s.userManager.UpdateUsers(s.users) + + credentials, err := platformReadCredentials(s.credentialPath) + if err != nil { + return E.Cause(err, "read credentials") + } + s.credentials = credentials + + if s.usageTracker != nil { + err = s.usageTracker.Load() + if err != nil { + s.logger.Warn("load usage statistics: ", err) + } + } + + router := chi.NewRouter() + router.Mount("/", s) + + s.httpServer = &http.Server{Handler: router} + + if s.tlsConfig != nil { + err = s.tlsConfig.Start() + if err != nil { + return E.Cause(err, "create TLS config") + } + } + + tcpListener, err := s.listener.ListenTCP() + if err != nil { + return err + } + + if s.tlsConfig != nil { + if !common.Contains(s.tlsConfig.NextProtos(), http2.NextProtoTLS) { + s.tlsConfig.SetNextProtos(append([]string{"h2"}, s.tlsConfig.NextProtos()...)) + } + tcpListener = aTLS.NewListener(tcpListener, s.tlsConfig) + } + + go func() { + serveErr := s.httpServer.Serve(tcpListener) + if serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) { + s.logger.Error("serve error: ", serveErr) + } + }() + + return nil +} + +func (s *Service) getAccessToken() (string, error) { + s.accessMutex.RLock() + if !s.credentials.needsRefresh() { + token := s.credentials.getAccessToken() + s.accessMutex.RUnlock() + return token, nil + } + s.accessMutex.RUnlock() + + s.accessMutex.Lock() + defer s.accessMutex.Unlock() + + if !s.credentials.needsRefresh() { + return s.credentials.getAccessToken(), nil + } + + newCredentials, err := refreshToken(s.httpClient, s.credentials) + if err != nil { + return "", err + } + + s.credentials = newCredentials + + err = platformWriteCredentials(newCredentials, s.credentialPath) + if err != nil { + s.logger.Warn("persist refreshed token: ", err) + } + + return newCredentials.getAccessToken(), nil +} + +func (s *Service) getAccountID() string { + s.accessMutex.RLock() + defer s.accessMutex.RUnlock() + return s.credentials.getAccountID() +} + +func (s *Service) isAPIKeyMode() bool { + s.accessMutex.RLock() + defer s.accessMutex.RUnlock() + return s.credentials.isAPIKeyMode() +} + +func (s *Service) getBaseURL() string { + if s.isAPIKeyMode() { + return openaiAPIBaseURL + } + return chatGPTBackendURL +} + +func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { + path := r.URL.Path + if !strings.HasPrefix(path, "/v1/") { + writeJSONError(w, r, http.StatusNotFound, "invalid_request_error", "path must start with /v1/") + return + } + + var proxyPath string + if s.isAPIKeyMode() { + proxyPath = path + } else { + if path == "/v1/chat/completions" { + writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", + "chat completions endpoint is only available in API key mode") + return + } + proxyPath = strings.TrimPrefix(path, "/v1") + } + + var username string + if len(s.users) > 0 { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": missing Authorization header") + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") + return + } + clientToken := strings.TrimPrefix(authHeader, "Bearer ") + if clientToken == authHeader { + s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format") + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") + return + } + var ok bool + username, ok = s.userManager.Authenticate(clientToken) + if !ok { + s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken) + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key") + return + } + } + + var requestModel string + + if s.usageTracker != nil && r.Body != nil { + bodyBytes, err := io.ReadAll(r.Body) + if err == nil { + var request struct { + Model string `json:"model"` + } + err := json.Unmarshal(bodyBytes, &request) + if err == nil { + requestModel = request.Model + } + r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + } + } + + accessToken, err := s.getAccessToken() + if err != nil { + s.logger.Error("get access token: ", err) + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "Authentication failed") + return + } + + proxyURL := s.getBaseURL() + proxyPath + if r.URL.RawQuery != "" { + proxyURL += "?" + r.URL.RawQuery + } + proxyRequest, err := http.NewRequestWithContext(r.Context(), r.Method, proxyURL, r.Body) + if err != nil { + s.logger.Error("create proxy request: ", err) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error") + return + } + + for key, values := range r.Header { + if !isHopByHopHeader(key) && key != "Authorization" { + proxyRequest.Header[key] = values + } + } + + for key, values := range s.httpHeaders { + proxyRequest.Header.Del(key) + proxyRequest.Header[key] = values + } + + proxyRequest.Header.Set("Authorization", "Bearer "+accessToken) + + if accountID := s.getAccountID(); accountID != "" { + proxyRequest.Header.Set("ChatGPT-Account-Id", accountID) + } + + response, err := s.httpClient.Do(proxyRequest) + if err != nil { + writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error()) + return + } + defer response.Body.Close() + + for key, values := range response.Header { + if !isHopByHopHeader(key) { + w.Header()[key] = values + } + } + w.WriteHeader(response.StatusCode) + + trackUsage := s.usageTracker != nil && response.StatusCode == http.StatusOK && + (path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) + if trackUsage { + s.handleResponseWithTracking(w, response, path, requestModel, username) + } else { + mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) + if err == nil && mediaType != "text/event-stream" { + _, _ = io.Copy(w, response.Body) + return + } + flusher, ok := w.(http.Flusher) + if !ok { + s.logger.Error("streaming not supported") + return + } + buffer := make([]byte, buf.BufferSize) + for { + n, err := response.Body.Read(buffer) + if n > 0 { + _, writeError := w.Write(buffer[:n]) + if writeError != nil { + s.logger.Error("write streaming response: ", writeError) + return + } + flusher.Flush() + } + if err != nil { + return + } + } + } +} + +func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, path string, requestModel string, username string) { + isChatCompletions := path == "/v1/chat/completions" + mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) + isStreaming := err == nil && mediaType == "text/event-stream" + + if !isStreaming { + bodyBytes, err := io.ReadAll(response.Body) + if err != nil { + s.logger.Error("read response body: ", err) + return + } + + var responseModel string + var inputTokens, outputTokens, cachedTokens int64 + + if isChatCompletions { + var chatCompletion openai.ChatCompletion + if json.Unmarshal(bodyBytes, &chatCompletion) == nil { + responseModel = chatCompletion.Model + inputTokens = chatCompletion.Usage.PromptTokens + outputTokens = chatCompletion.Usage.CompletionTokens + cachedTokens = chatCompletion.Usage.PromptTokensDetails.CachedTokens + } + } else { + var responsesResponse responses.Response + if json.Unmarshal(bodyBytes, &responsesResponse) == nil { + responseModel = string(responsesResponse.Model) + inputTokens = responsesResponse.Usage.InputTokens + outputTokens = responsesResponse.Usage.OutputTokens + cachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens + } + } + + if inputTokens > 0 || outputTokens > 0 { + if responseModel == "" { + responseModel = requestModel + } + if responseModel != "" { + s.usageTracker.AddUsage(responseModel, inputTokens, outputTokens, cachedTokens, username) + } + } + + _, _ = writer.Write(bodyBytes) + return + } + + flusher, ok := writer.(http.Flusher) + if !ok { + s.logger.Error("streaming not supported") + return + } + + var inputTokens, outputTokens, cachedTokens int64 + var responseModel string + buffer := make([]byte, buf.BufferSize) + var leftover []byte + + for { + n, err := response.Body.Read(buffer) + if n > 0 { + data := append(leftover, buffer[:n]...) + lines := bytes.Split(data, []byte("\n")) + + if err == nil { + leftover = lines[len(lines)-1] + lines = lines[:len(lines)-1] + } else { + leftover = nil + } + + for _, line := range lines { + line = bytes.TrimSpace(line) + if len(line) == 0 { + continue + } + + if bytes.HasPrefix(line, []byte("data: ")) { + eventData := bytes.TrimPrefix(line, []byte("data: ")) + if bytes.Equal(eventData, []byte("[DONE]")) { + continue + } + + if isChatCompletions { + var chatChunk openai.ChatCompletionChunk + if json.Unmarshal(eventData, &chatChunk) == nil { + if chatChunk.Model != "" { + responseModel = chatChunk.Model + } + if chatChunk.Usage.PromptTokens > 0 { + inputTokens = chatChunk.Usage.PromptTokens + cachedTokens = chatChunk.Usage.PromptTokensDetails.CachedTokens + } + if chatChunk.Usage.CompletionTokens > 0 { + outputTokens = chatChunk.Usage.CompletionTokens + } + } + } else { + var streamEvent responses.ResponseStreamEventUnion + if json.Unmarshal(eventData, &streamEvent) == nil { + if streamEvent.Type == "response.completed" { + completedEvent := streamEvent.AsResponseCompleted() + if string(completedEvent.Response.Model) != "" { + responseModel = string(completedEvent.Response.Model) + } + if completedEvent.Response.Usage.InputTokens > 0 { + inputTokens = completedEvent.Response.Usage.InputTokens + cachedTokens = completedEvent.Response.Usage.InputTokensDetails.CachedTokens + } + if completedEvent.Response.Usage.OutputTokens > 0 { + outputTokens = completedEvent.Response.Usage.OutputTokens + } + } + } + } + } + } + + _, writeError := writer.Write(buffer[:n]) + if writeError != nil { + s.logger.Error("write streaming response: ", writeError) + return + } + flusher.Flush() + } + + if err != nil { + if responseModel == "" { + responseModel = requestModel + } + + if inputTokens > 0 || outputTokens > 0 { + if responseModel != "" { + s.usageTracker.AddUsage(responseModel, inputTokens, outputTokens, cachedTokens, username) + } + } + return + } + } +} + +func (s *Service) Close() error { + err := common.Close( + common.PtrOrNil(s.httpServer), + common.PtrOrNil(s.listener), + s.tlsConfig, + ) + + if s.usageTracker != nil { + s.usageTracker.cancelPendingSave() + saveErr := s.usageTracker.Save() + if saveErr != nil { + s.logger.Error("save usage statistics: ", saveErr) + } + } + + return err +} diff --git a/service/ocm/service_usage.go b/service/ocm/service_usage.go new file mode 100644 index 000000000..7089f4d39 --- /dev/null +++ b/service/ocm/service_usage.go @@ -0,0 +1,445 @@ +package ocm + +import ( + "encoding/json" + "math" + "os" + "regexp" + "sync" + "time" + + "github.com/sagernet/sing-box/log" + E "github.com/sagernet/sing/common/exceptions" +) + +type UsageStats struct { + RequestCount int `json:"request_count"` + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` + CachedTokens int64 `json:"cached_tokens"` +} + +func (u *UsageStats) UnmarshalJSON(data []byte) error { + type Alias UsageStats + aux := &struct { + *Alias + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + }{ + Alias: (*Alias)(u), + } + err := json.Unmarshal(data, aux) + if err != nil { + return err + } + if u.InputTokens == 0 && aux.PromptTokens > 0 { + u.InputTokens = aux.PromptTokens + } + if u.OutputTokens == 0 && aux.CompletionTokens > 0 { + u.OutputTokens = aux.CompletionTokens + } + return nil +} + +type CostCombination struct { + Model string `json:"model"` + Total UsageStats `json:"total"` + ByUser map[string]UsageStats `json:"by_user"` +} + +type AggregatedUsage struct { + LastUpdated time.Time `json:"last_updated"` + Combinations []CostCombination `json:"combinations"` + mutex sync.Mutex + filePath string + logger log.ContextLogger + lastSaveTime time.Time + pendingSave bool + saveTimer *time.Timer + saveMutex sync.Mutex +} + +type UsageStatsJSON struct { + RequestCount int `json:"request_count"` + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` + CachedTokens int64 `json:"cached_tokens"` + CostUSD float64 `json:"cost_usd"` +} + +type CostCombinationJSON struct { + Model string `json:"model"` + Total UsageStatsJSON `json:"total"` + ByUser map[string]UsageStatsJSON `json:"by_user"` +} + +type CostsSummaryJSON struct { + TotalUSD float64 `json:"total_usd"` + ByUser map[string]float64 `json:"by_user"` +} + +type AggregatedUsageJSON struct { + LastUpdated time.Time `json:"last_updated"` + Costs CostsSummaryJSON `json:"costs"` + Combinations []CostCombinationJSON `json:"combinations"` +} + +type ModelPricing struct { + InputPrice float64 + OutputPrice float64 + CachedInputPrice float64 +} + +type modelFamily struct { + pattern *regexp.Regexp + pricing ModelPricing +} + +var ( + gpt4oPricing = ModelPricing{ + InputPrice: 2.5, + OutputPrice: 10.0, + CachedInputPrice: 1.25, + } + + gpt4oMiniPricing = ModelPricing{ + InputPrice: 0.15, + OutputPrice: 0.6, + CachedInputPrice: 0.075, + } + + gpt4oAudioPricing = ModelPricing{ + InputPrice: 2.5, + OutputPrice: 10.0, + CachedInputPrice: 1.25, + } + + o1Pricing = ModelPricing{ + InputPrice: 15.0, + OutputPrice: 60.0, + CachedInputPrice: 7.5, + } + + o1MiniPricing = ModelPricing{ + InputPrice: 1.1, + OutputPrice: 4.4, + CachedInputPrice: 0.55, + } + + o3MiniPricing = ModelPricing{ + InputPrice: 1.1, + OutputPrice: 4.4, + CachedInputPrice: 0.55, + } + + o3Pricing = ModelPricing{ + InputPrice: 2.0, + OutputPrice: 8.0, + CachedInputPrice: 1.0, + } + + o4MiniPricing = ModelPricing{ + InputPrice: 1.1, + OutputPrice: 4.4, + CachedInputPrice: 0.55, + } + + gpt41Pricing = ModelPricing{ + InputPrice: 2.0, + OutputPrice: 8.0, + CachedInputPrice: 0.5, + } + + gpt41MiniPricing = ModelPricing{ + InputPrice: 0.4, + OutputPrice: 1.6, + CachedInputPrice: 0.1, + } + + gpt41NanoPricing = ModelPricing{ + InputPrice: 0.1, + OutputPrice: 0.4, + CachedInputPrice: 0.025, + } + + modelFamilies = []modelFamily{ + { + pattern: regexp.MustCompile(`^gpt-4\.1-nano`), + pricing: gpt41NanoPricing, + }, + { + pattern: regexp.MustCompile(`^gpt-4\.1-mini`), + pricing: gpt41MiniPricing, + }, + { + pattern: regexp.MustCompile(`^gpt-4\.1`), + pricing: gpt41Pricing, + }, + { + pattern: regexp.MustCompile(`^o4-mini`), + pricing: o4MiniPricing, + }, + { + pattern: regexp.MustCompile(`^o3-mini`), + pricing: o3MiniPricing, + }, + { + pattern: regexp.MustCompile(`^o3`), + pricing: o3Pricing, + }, + { + pattern: regexp.MustCompile(`^o1-mini`), + pricing: o1MiniPricing, + }, + { + pattern: regexp.MustCompile(`^o1`), + pricing: o1Pricing, + }, + { + pattern: regexp.MustCompile(`^gpt-4o-audio`), + pricing: gpt4oAudioPricing, + }, + { + pattern: regexp.MustCompile(`^gpt-4o-mini`), + pricing: gpt4oMiniPricing, + }, + { + pattern: regexp.MustCompile(`^gpt-4o`), + pricing: gpt4oPricing, + }, + { + pattern: regexp.MustCompile(`^chatgpt-4o`), + pricing: gpt4oPricing, + }, + } +) + +func getPricing(model string) ModelPricing { + for _, family := range modelFamilies { + if family.pattern.MatchString(model) { + return family.pricing + } + } + return gpt4oPricing +} + +func calculateCost(stats UsageStats, model string) float64 { + pricing := getPricing(model) + + regularInputTokens := stats.InputTokens - stats.CachedTokens + if regularInputTokens < 0 { + regularInputTokens = 0 + } + + cost := (float64(regularInputTokens)*pricing.InputPrice + + float64(stats.OutputTokens)*pricing.OutputPrice + + float64(stats.CachedTokens)*pricing.CachedInputPrice) / 1_000_000 + + return math.Round(cost*100) / 100 +} + +func (u *AggregatedUsage) ToJSON() *AggregatedUsageJSON { + u.mutex.Lock() + defer u.mutex.Unlock() + + result := &AggregatedUsageJSON{ + LastUpdated: u.LastUpdated, + Combinations: make([]CostCombinationJSON, len(u.Combinations)), + Costs: CostsSummaryJSON{ + TotalUSD: 0, + ByUser: make(map[string]float64), + }, + } + + for i, combo := range u.Combinations { + totalCost := calculateCost(combo.Total, combo.Model) + + result.Costs.TotalUSD += totalCost + + comboJSON := CostCombinationJSON{ + Model: combo.Model, + Total: UsageStatsJSON{ + RequestCount: combo.Total.RequestCount, + InputTokens: combo.Total.InputTokens, + OutputTokens: combo.Total.OutputTokens, + CachedTokens: combo.Total.CachedTokens, + CostUSD: totalCost, + }, + ByUser: make(map[string]UsageStatsJSON), + } + + for user, userStats := range combo.ByUser { + userCost := calculateCost(userStats, combo.Model) + result.Costs.ByUser[user] += userCost + + comboJSON.ByUser[user] = UsageStatsJSON{ + RequestCount: userStats.RequestCount, + InputTokens: userStats.InputTokens, + OutputTokens: userStats.OutputTokens, + CachedTokens: userStats.CachedTokens, + CostUSD: userCost, + } + } + + result.Combinations[i] = comboJSON + } + + result.Costs.TotalUSD = math.Round(result.Costs.TotalUSD*100) / 100 + for user, cost := range result.Costs.ByUser { + result.Costs.ByUser[user] = math.Round(cost*100) / 100 + } + + return result +} + +func (u *AggregatedUsage) Load() error { + u.mutex.Lock() + defer u.mutex.Unlock() + + data, err := os.ReadFile(u.filePath) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + + var temp struct { + LastUpdated time.Time `json:"last_updated"` + Combinations []CostCombination `json:"combinations"` + } + + err = json.Unmarshal(data, &temp) + if err != nil { + return err + } + + u.LastUpdated = temp.LastUpdated + u.Combinations = temp.Combinations + + for i := range u.Combinations { + if u.Combinations[i].ByUser == nil { + u.Combinations[i].ByUser = make(map[string]UsageStats) + } + } + + return nil +} + +func (u *AggregatedUsage) Save() error { + jsonData := u.ToJSON() + + data, err := json.MarshalIndent(jsonData, "", " ") + if err != nil { + return err + } + + tmpFile := u.filePath + ".tmp" + err = os.WriteFile(tmpFile, data, 0o644) + if err != nil { + return err + } + defer os.Remove(tmpFile) + err = os.Rename(tmpFile, u.filePath) + if err == nil { + u.saveMutex.Lock() + u.lastSaveTime = time.Now() + u.saveMutex.Unlock() + } + return err +} + +func (u *AggregatedUsage) AddUsage(model string, inputTokens, outputTokens, cachedTokens int64, user string) error { + if model == "" { + return E.New("model cannot be empty") + } + + u.mutex.Lock() + defer u.mutex.Unlock() + + u.LastUpdated = time.Now() + + var combo *CostCombination + for i := range u.Combinations { + if u.Combinations[i].Model == model { + combo = &u.Combinations[i] + break + } + } + + if combo == nil { + newCombo := CostCombination{ + Model: model, + Total: UsageStats{}, + ByUser: make(map[string]UsageStats), + } + u.Combinations = append(u.Combinations, newCombo) + combo = &u.Combinations[len(u.Combinations)-1] + } + + combo.Total.RequestCount++ + combo.Total.InputTokens += inputTokens + combo.Total.OutputTokens += outputTokens + combo.Total.CachedTokens += cachedTokens + + if user != "" { + userStats := combo.ByUser[user] + userStats.RequestCount++ + userStats.InputTokens += inputTokens + userStats.OutputTokens += outputTokens + userStats.CachedTokens += cachedTokens + combo.ByUser[user] = userStats + } + + go u.scheduleSave() + + return nil +} + +func (u *AggregatedUsage) scheduleSave() { + const saveInterval = time.Minute + + u.saveMutex.Lock() + defer u.saveMutex.Unlock() + + timeSinceLastSave := time.Since(u.lastSaveTime) + + if timeSinceLastSave >= saveInterval { + go u.saveAsync() + return + } + + if u.pendingSave { + return + } + + u.pendingSave = true + remainingTime := saveInterval - timeSinceLastSave + + u.saveTimer = time.AfterFunc(remainingTime, func() { + u.saveMutex.Lock() + u.pendingSave = false + u.saveMutex.Unlock() + u.saveAsync() + }) +} + +func (u *AggregatedUsage) saveAsync() { + err := u.Save() + if err != nil { + if u.logger != nil { + u.logger.Error("save usage statistics: ", err) + } + } +} + +func (u *AggregatedUsage) cancelPendingSave() { + u.saveMutex.Lock() + defer u.saveMutex.Unlock() + + if u.saveTimer != nil { + u.saveTimer.Stop() + u.saveTimer = nil + } + u.pendingSave = false +} diff --git a/service/ocm/service_user.go b/service/ocm/service_user.go new file mode 100644 index 000000000..494b981b9 --- /dev/null +++ b/service/ocm/service_user.go @@ -0,0 +1,29 @@ +package ocm + +import ( + "sync" + + "github.com/sagernet/sing-box/option" +) + +type UserManager struct { + accessMutex sync.RWMutex + tokenMap map[string]string +} + +func (m *UserManager) UpdateUsers(users []option.OCMUser) { + m.accessMutex.Lock() + defer m.accessMutex.Unlock() + tokenMap := make(map[string]string, len(users)) + for _, user := range users { + tokenMap[user.Token] = user.Name + } + m.tokenMap = tokenMap +} + +func (m *UserManager) Authenticate(token string) (string, bool) { + m.accessMutex.RLock() + username, found := m.tokenMap[token] + m.accessMutex.RUnlock() + return username, found +}