Add OpenAI Codex Multiplexer service

This commit is contained in:
世界
2025-12-16 22:09:02 +08:00
parent a89680fa2d
commit e8620587dd
24 changed files with 1673 additions and 8 deletions

View File

@@ -190,7 +190,7 @@ jobs:
- name: Set build tags - name: Set build tags
run: | run: |
set -xeuo pipefail set -xeuo pipefail
TAGS='with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,badlinkname,tfogo_checklinkname0' TAGS='with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,with_ocm,badlinkname,tfogo_checklinkname0'
if [[ "${{ matrix.naive }}" == "true" ]]; then if [[ "${{ matrix.naive }}" == "true" ]]; then
TAGS="${TAGS},with_naive_outbound" TAGS="${TAGS},with_naive_outbound"
fi fi
@@ -427,7 +427,7 @@ jobs:
- name: Set build tags - name: Set build tags
run: | run: |
set -xeuo pipefail set -xeuo pipefail
TAGS='with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,badlinkname,tfogo_checklinkname0' TAGS='with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,with_ocm,badlinkname,tfogo_checklinkname0'
if [[ "${{ matrix.legacy_go124 }}" != "true" ]]; then if [[ "${{ matrix.legacy_go124 }}" != "true" ]]; then
TAGS="${TAGS},with_naive_outbound" TAGS="${TAGS},with_naive_outbound"
fi fi
@@ -495,7 +495,7 @@ jobs:
- name: Build - name: Build
run: | run: |
mkdir -p dist mkdir -p dist
go build -v -trimpath -o dist/sing-box.exe -tags "with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,with_naive_outbound,with_purego,badlinkname,tfogo_checklinkname0" ` go build -v -trimpath -o dist/sing-box.exe -tags "with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,with_ocm,with_naive_outbound,with_purego,badlinkname,tfogo_checklinkname0" `
-ldflags "-s -buildid= -X github.com/sagernet/sing-box/constant.Version=${{ needs.calculate_version.outputs.version }} -checklinkname=0" ` -ldflags "-s -buildid= -X github.com/sagernet/sing-box/constant.Version=${{ needs.calculate_version.outputs.version }} -checklinkname=0" `
./cmd/sing-box ./cmd/sing-box
env: env:
@@ -885,6 +885,16 @@ jobs:
with: with:
path: dist path: dist
merge-multiple: true merge-multiple: true
- name: Generate SFA version metadata
run: |-
VERSION_CODE=$(grep VERSION_CODE clients/android/version.properties | cut -d= -f2)
cat > dist/SFA-version-metadata.json << EOF
{
"version_code": ${VERSION_CODE},
"version_name": "${VERSION}"
}
EOF
cat dist/SFA-version-metadata.json
- name: Upload builds - name: Upload builds
if: ${{ env.PUBLISHED == 'false' }} if: ${{ env.PUBLISHED == 'false' }}
run: |- run: |-

View File

@@ -93,7 +93,7 @@ jobs:
- name: Set build tags - name: Set build tags
run: | run: |
set -xeuo pipefail set -xeuo pipefail
TAGS='with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,badlinkname,tfogo_checklinkname0' TAGS='with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,with_ocm,badlinkname,tfogo_checklinkname0'
if [[ "${{ matrix.naive }}" == "true" ]]; then if [[ "${{ matrix.naive }}" == "true" ]]; then
TAGS="${TAGS},with_naive_outbound,with_musl" TAGS="${TAGS},with_naive_outbound,with_musl"
fi fi

View File

@@ -116,7 +116,7 @@ jobs:
- name: Set build tags - name: Set build tags
run: | run: |
set -xeuo pipefail set -xeuo pipefail
TAGS='with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,badlinkname,tfogo_checklinkname0' TAGS='with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,with_ocm,badlinkname,tfogo_checklinkname0'
if [[ "${{ matrix.naive }}" == "true" ]]; then if [[ "${{ matrix.naive }}" == "true" ]]; then
TAGS="${TAGS},with_naive_outbound,with_musl" TAGS="${TAGS},with_naive_outbound,with_musl"
fi fi

View File

@@ -13,7 +13,7 @@ RUN set -ex \
&& export COMMIT=$(git rev-parse --short HEAD) \ && export COMMIT=$(git rev-parse --short HEAD) \
&& export VERSION=$(go run ./cmd/internal/read_tag) \ && export VERSION=$(go run ./cmd/internal/read_tag) \
&& go build -v -trimpath -tags \ && go build -v -trimpath -tags \
"with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,badlinkname,tfogo_checklinkname0" \ "with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,with_ocm,badlinkname,tfogo_checklinkname0" \
-o /go/bin/sing-box \ -o /go/bin/sing-box \
-ldflags "-X \"github.com/sagernet/sing-box/constant.Version=$VERSION\" -s -w -buildid= -checklinkname=0" \ -ldflags "-X \"github.com/sagernet/sing-box/constant.Version=$VERSION\" -s -w -buildid= -checklinkname=0" \
./cmd/sing-box ./cmd/sing-box

View File

@@ -1,6 +1,6 @@
NAME = sing-box NAME = sing-box
COMMIT = $(shell git rev-parse --short HEAD) COMMIT = $(shell git rev-parse --short HEAD)
TAGS ?= with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,badlinkname,tfogo_checklinkname0 TAGS ?= with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,with_ocm,badlinkname,tfogo_checklinkname0
GOHOSTOS = $(shell go env GOHOSTOS) GOHOSTOS = $(shell go env GOHOSTOS)
GOHOSTARCH = $(shell go env GOHOSTARCH) GOHOSTARCH = $(shell go env GOHOSTARCH)

View File

@@ -29,6 +29,7 @@ const (
TypeResolved = "resolved" TypeResolved = "resolved"
TypeSSMAPI = "ssm-api" TypeSSMAPI = "ssm-api"
TypeCCM = "ccm" TypeCCM = "ccm"
TypeOCM = "ocm"
) )
const ( const (

View File

@@ -25,6 +25,7 @@ icon: material/new-box
|------------|------------------------| |------------|------------------------|
| `ccm` | [CCM](./ccm) | | `ccm` | [CCM](./ccm) |
| `derp` | [DERP](./derp) | | `derp` | [DERP](./derp) |
| `ocm` | [OCM](./ocm) |
| `resolved` | [Resolved](./resolved) | | `resolved` | [Resolved](./resolved) |
| `ssm-api` | [SSM API](./ssm-api) | | `ssm-api` | [SSM API](./ssm-api) |

View File

@@ -25,6 +25,7 @@ icon: material/new-box
|-----------|------------------------| |-----------|------------------------|
| `ccm` | [CCM](./ccm) | | `ccm` | [CCM](./ccm) |
| `derp` | [DERP](./derp) | | `derp` | [DERP](./derp) |
| `ocm` | [OCM](./ocm) |
| `resolved`| [Resolved](./resolved) | | `resolved`| [Resolved](./resolved) |
| `ssm-api` | [SSM API](./ssm-api) | | `ssm-api` | [SSM API](./ssm-api) |

View File

@@ -0,0 +1,171 @@
---
icon: material/new-box
---
!!! question "Since sing-box 1.13.0"
# OCM
OCM (OpenAI Codex Multiplexer) service is a multiplexing service that allows you to access your local OpenAI Codex subscription remotely through custom tokens.
It handles OAuth authentication with OpenAI's API on your local machine while allowing remote clients to authenticate using custom tokens.
### Structure
```json
{
"type": "ocm",
... // Listen Fields
"credential_path": "",
"usages_path": "",
"users": [],
"headers": {},
"detour": "",
"tls": {}
}
```
### Listen Fields
See [Listen Fields](/configuration/shared/listen/) for details.
### Fields
#### credential_path
Path to the OpenAI OAuth credentials file.
If not specified, defaults to `~/.codex/auth.json`.
Refreshed tokens are automatically written back to the same location.
#### usages_path
Path to the file for storing aggregated API usage statistics.
Usage tracking is disabled if not specified.
When enabled, the service tracks and saves comprehensive statistics including:
- Request counts
- Token usage (input, output, cached)
- Calculated costs in USD based on OpenAI API pricing
Statistics are organized by model and optionally by user when authentication is enabled.
The statistics file is automatically saved every minute and upon service shutdown.
#### users
List of authorized users for token authentication.
If empty, no authentication is required.
Object format:
```json
{
"name": "",
"token": ""
}
```
Object fields:
- `name`: Username identifier for tracking purposes.
- `token`: Bearer token for authentication. Clients authenticate by setting the `Authorization: Bearer <token>` header.
#### headers
Custom HTTP headers to send to the OpenAI API.
These headers will override any existing headers with the same name.
#### detour
Outbound tag for connecting to the OpenAI API.
#### tls
TLS configuration, see [TLS](/configuration/shared/tls/#inbound).
### Example
#### Server
```json
{
"services": [
{
"type": "ocm",
"listen": "127.0.0.1",
"listen_port": 8080
}
]
}
```
#### Client
Add to `~/.codex/config.toml`:
```toml
[model_providers.ocm]
name = "OCM Proxy"
base_url = "http://127.0.0.1:8080/v1"
wire_api = "responses"
requires_openai_auth = false
```
Then run:
```bash
codex --model-provider ocm
```
### Example with Authentication
#### Server
```json
{
"services": [
{
"type": "ocm",
"listen": "0.0.0.0",
"listen_port": 8080,
"usages_path": "./codex-usages.json",
"users": [
{
"name": "alice",
"token": "sk-alice-secret-token"
},
{
"name": "bob",
"token": "sk-bob-secret-token"
}
]
}
]
}
```
#### Client
Add to `~/.codex/config.toml`:
```toml
[model_providers.ocm]
name = "OCM Proxy"
base_url = "http://127.0.0.1:8080/v1"
wire_api = "responses"
requires_openai_auth = false
experimental_bearer_token = "sk-alice-secret-token"
```
Then run:
```bash
codex --model-provider ocm
```

View File

@@ -0,0 +1,171 @@
---
icon: material/new-box
---
!!! question "自 sing-box 1.13.0 起"
# OCM
OCMOpenAI Codex 多路复用器)服务是一个多路复用服务,允许您通过自定义令牌远程访问本地的 OpenAI Codex 订阅。
它在本地机器上处理与 OpenAI API 的 OAuth 身份验证,同时允许远程客户端使用自定义令牌进行身份验证。
### 结构
```json
{
"type": "ocm",
... // 监听字段
"credential_path": "",
"usages_path": "",
"users": [],
"headers": {},
"detour": "",
"tls": {}
}
```
### 监听字段
参阅 [监听字段](/zh/configuration/shared/listen/) 了解详情。
### 字段
#### credential_path
OpenAI OAuth 凭据文件的路径。
如果未指定,默认值为 `~/.codex/auth.json`
刷新的令牌会自动写回相同位置。
#### usages_path
用于存储聚合 API 使用统计信息的文件路径。
如果未指定,使用跟踪将被禁用。
启用后,服务会跟踪并保存全面的统计信息,包括:
- 请求计数
- 令牌使用量(输入、输出、缓存)
- 基于 OpenAI API 定价计算的美元成本
统计信息按模型以及可选的用户(启用身份验证时)进行组织。
统计文件每分钟自动保存一次,并在服务关闭时保存。
#### users
用于令牌身份验证的授权用户列表。
如果为空,则不需要身份验证。
对象格式:
```json
{
"name": "",
"token": ""
}
```
对象字段:
- `name`:用于跟踪的用户名标识符。
- `token`:用于身份验证的 Bearer 令牌。客户端通过设置 `Authorization: Bearer <token>` 头进行身份验证。
#### headers
发送到 OpenAI API 的自定义 HTTP 头。
这些头会覆盖同名的现有头。
#### detour
用于连接 OpenAI API 的出站标签。
#### tls
TLS 配置,参阅 [TLS](/zh/configuration/shared/tls/#inbound)。
### 示例
#### 服务端
```json
{
"services": [
{
"type": "ocm",
"listen": "127.0.0.1",
"listen_port": 8080
}
]
}
```
#### 客户端
`~/.codex/config.toml` 中添加:
```toml
[model_providers.ocm]
name = "OCM Proxy"
base_url = "http://127.0.0.1:8080/v1"
wire_api = "responses"
requires_openai_auth = false
```
然后运行:
```bash
codex --model-provider ocm
```
### 带身份验证的示例
#### 服务端
```json
{
"services": [
{
"type": "ocm",
"listen": "0.0.0.0",
"listen_port": 8080,
"usages_path": "./codex-usages.json",
"users": [
{
"name": "alice",
"token": "sk-alice-secret-token"
},
{
"name": "bob",
"token": "sk-bob-secret-token"
}
]
}
]
}
```
#### 客户端
`~/.codex/config.toml` 中添加:
```toml
[model_providers.ocm]
name = "OCM Proxy"
base_url = "http://127.0.0.1:8080/v1"
wire_api = "responses"
requires_openai_auth = false
experimental_bearer_token = "sk-alice-secret-token"
```
然后运行:
```bash
codex --model-provider ocm
```

1
go.mod
View File

@@ -21,6 +21,7 @@ require (
github.com/metacubex/utls v1.8.4 github.com/metacubex/utls v1.8.4
github.com/mholt/acmez/v3 v3.1.2 github.com/mholt/acmez/v3 v3.1.2
github.com/miekg/dns v1.1.67 github.com/miekg/dns v1.1.67
github.com/openai/openai-go/v3 v3.13.0
github.com/oschwald/maxminddb-golang v1.13.1 github.com/oschwald/maxminddb-golang v1.13.1
github.com/sagernet/asc-go v0.0.0-20241217030726-d563060fe4e1 github.com/sagernet/asc-go v0.0.0-20241217030726-d563060fe4e1
github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a

2
go.sum
View File

@@ -131,6 +131,8 @@ github.com/mitchellh/go-ps v1.0.0 h1:i6ampVEEF4wQFF+bkYfwYgY+F/uYJDktmvLPf7qIgjc
github.com/mitchellh/go-ps v1.0.0/go.mod h1:J4lOc8z8yJs6vUwklHw2XEIiT4z4C40KtWVN3nvg8Pg= github.com/mitchellh/go-ps v1.0.0/go.mod h1:J4lOc8z8yJs6vUwklHw2XEIiT4z4C40KtWVN3nvg8Pg=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/openai/openai-go/v3 v3.13.0 h1:arSFmVHcBHNVYG5iqspPJrLoin0Qqn2JcCLWWcTcM1Q=
github.com/openai/openai-go/v3 v3.13.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo=
github.com/oschwald/maxminddb-golang v1.13.1 h1:G3wwjdN9JmIK2o/ermkHM+98oX5fS+k5MbwsmL4MRQE= github.com/oschwald/maxminddb-golang v1.13.1 h1:G3wwjdN9JmIK2o/ermkHM+98oX5fS+k5MbwsmL4MRQE=
github.com/oschwald/maxminddb-golang v1.13.1/go.mod h1:K4pgV9N/GcK694KSTmVSDTODk4IsCNThNdTmnaBZ/F8= github.com/oschwald/maxminddb-golang v1.13.1/go.mod h1:K4pgV9N/GcK694KSTmVSDTODk4IsCNThNdTmnaBZ/F8=
github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ=

12
include/ocm.go Normal file
View File

@@ -0,0 +1,12 @@
//go:build with_ocm
package include
import (
"github.com/sagernet/sing-box/adapter/service"
"github.com/sagernet/sing-box/service/ocm"
)
func registerOCMService(registry *service.Registry) {
ocm.RegisterService(registry)
}

20
include/ocm_stub.go Normal file
View File

@@ -0,0 +1,20 @@
//go:build !with_ocm
package include
import (
"context"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/service"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
)
func registerOCMService(registry *service.Registry) {
service.Register[option.OCMServiceOptions](registry, C.TypeOCM, func(ctx context.Context, logger log.ContextLogger, tag string, options option.OCMServiceOptions) (adapter.Service, error) {
return nil, E.New(`OCM is not included in this build, rebuild with -tags with_ocm`)
})
}

View File

@@ -136,6 +136,7 @@ func ServiceRegistry() *service.Registry {
registerDERPService(registry) registerDERPService(registry)
registerCCMService(registry) registerCCMService(registry)
registerOCMService(registry)
return registry return registry
} }

View File

@@ -176,6 +176,8 @@ nav:
- DERP: configuration/service/derp.md - DERP: configuration/service/derp.md
- Resolved: configuration/service/resolved.md - Resolved: configuration/service/resolved.md
- SSM API: configuration/service/ssm-api.md - SSM API: configuration/service/ssm-api.md
- CCM: configuration/service/ccm.md
- OCM: configuration/service/ocm.md
markdown_extensions: markdown_extensions:
- pymdownx.inlinehilite - pymdownx.inlinehilite
- pymdownx.snippets - pymdownx.snippets

20
option/ocm.go Normal file
View File

@@ -0,0 +1,20 @@
package option
import (
"github.com/sagernet/sing/common/json/badoption"
)
type OCMServiceOptions struct {
ListenOptions
InboundTLSOptionsContainer
CredentialPath string `json:"credential_path,omitempty"`
Users []OCMUser `json:"users,omitempty"`
Headers badoption.HTTPHeader `json:"headers,omitempty"`
Detour string `json:"detour,omitempty"`
UsagesPath string `json:"usages_path,omitempty"`
}
type OCMUser struct {
Name string `json:"name,omitempty"`
Token string `json:"token,omitempty"`
}

View File

@@ -11,7 +11,7 @@ INSTALL_CONFIG_PATH="/usr/local/etc/sing-box"
INSTALL_DATA_PATH="/var/lib/sing-box" INSTALL_DATA_PATH="/var/lib/sing-box"
SYSTEMD_SERVICE_PATH="/etc/systemd/system" SYSTEMD_SERVICE_PATH="/etc/systemd/system"
DEFAULT_BUILD_TAGS="with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,badlinkname,tfogo_checklinkname0" DEFAULT_BUILD_TAGS="with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,with_ocm,badlinkname,tfogo_checklinkname0"
setup_environment() { setup_environment() {
if [ -d /usr/local/go ]; then if [ -d /usr/local/go ]; then

173
service/ocm/credential.go Normal file
View File

@@ -0,0 +1,173 @@
package ocm
import (
"bytes"
"encoding/json"
"io"
"net/http"
"os"
"os/user"
"path/filepath"
"time"
E "github.com/sagernet/sing/common/exceptions"
)
const (
oauth2ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
oauth2TokenURL = "https://auth.openai.com/oauth/token"
openaiAPIBaseURL = "https://api.openai.com"
chatGPTBackendURL = "https://chatgpt.com/backend-api/codex"
tokenRefreshIntervalDays = 8
)
func getRealUser() (*user.User, error) {
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
sudoUserInfo, err := user.Lookup(sudoUser)
if err == nil {
return sudoUserInfo, nil
}
}
return user.Current()
}
func getDefaultCredentialsPath() (string, error) {
if codexHome := os.Getenv("CODEX_HOME"); codexHome != "" {
return filepath.Join(codexHome, "auth.json"), nil
}
userInfo, err := getRealUser()
if err != nil {
return "", err
}
return filepath.Join(userInfo.HomeDir, ".codex", "auth.json"), nil
}
func readCredentialsFromFile(path string) (*oauthCredentials, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var credentials oauthCredentials
err = json.Unmarshal(data, &credentials)
if err != nil {
return nil, err
}
return &credentials, nil
}
func writeCredentialsToFile(credentials *oauthCredentials, path string) error {
data, err := json.MarshalIndent(credentials, "", " ")
if err != nil {
return err
}
return os.WriteFile(path, data, 0o600)
}
type oauthCredentials struct {
APIKey string `json:"OPENAI_API_KEY,omitempty"`
Tokens *tokenData `json:"tokens,omitempty"`
LastRefresh *time.Time `json:"last_refresh,omitempty"`
}
type tokenData struct {
IDToken string `json:"id_token,omitempty"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
AccountID string `json:"account_id,omitempty"`
}
func (c *oauthCredentials) isAPIKeyMode() bool {
return c.APIKey != ""
}
func (c *oauthCredentials) getAccessToken() string {
if c.APIKey != "" {
return c.APIKey
}
if c.Tokens != nil {
return c.Tokens.AccessToken
}
return ""
}
func (c *oauthCredentials) getAccountID() string {
if c.Tokens != nil {
return c.Tokens.AccountID
}
return ""
}
func (c *oauthCredentials) needsRefresh() bool {
if c.APIKey != "" {
return false
}
if c.Tokens == nil || c.Tokens.RefreshToken == "" {
return false
}
if c.LastRefresh == nil {
return true
}
return time.Since(*c.LastRefresh) >= time.Duration(tokenRefreshIntervalDays)*24*time.Hour
}
func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) {
if credentials.Tokens == nil || credentials.Tokens.RefreshToken == "" {
return nil, E.New("refresh token is empty")
}
requestBody, err := json.Marshal(map[string]string{
"grant_type": "refresh_token",
"refresh_token": credentials.Tokens.RefreshToken,
"client_id": oauth2ClientID,
"scope": "openid profile email",
})
if err != nil {
return nil, E.Cause(err, "marshal request")
}
request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
if err != nil {
return nil, err
}
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Accept", "application/json")
response, err := httpClient.Do(request)
if err != nil {
return nil, err
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
body, _ := io.ReadAll(response.Body)
return nil, E.New("refresh failed: ", response.Status, " ", string(body))
}
var tokenResponse struct {
IDToken string `json:"id_token"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
}
err = json.NewDecoder(response.Body).Decode(&tokenResponse)
if err != nil {
return nil, E.Cause(err, "decode response")
}
newCredentials := *credentials
if newCredentials.Tokens == nil {
newCredentials.Tokens = &tokenData{}
}
if tokenResponse.IDToken != "" {
newCredentials.Tokens.IDToken = tokenResponse.IDToken
}
if tokenResponse.AccessToken != "" {
newCredentials.Tokens.AccessToken = tokenResponse.AccessToken
}
if tokenResponse.RefreshToken != "" {
newCredentials.Tokens.RefreshToken = tokenResponse.RefreshToken
}
now := time.Now()
newCredentials.LastRefresh = &now
return &newCredentials, nil
}

View File

@@ -0,0 +1,25 @@
//go:build darwin
package ocm
func platformReadCredentials(customPath string) (*oauthCredentials, error) {
if customPath == "" {
var err error
customPath, err = getDefaultCredentialsPath()
if err != nil {
return nil, err
}
}
return readCredentialsFromFile(customPath)
}
func platformWriteCredentials(credentials *oauthCredentials, customPath string) error {
if customPath == "" {
var err error
customPath, err = getDefaultCredentialsPath()
if err != nil {
return err
}
}
return writeCredentialsToFile(credentials, customPath)
}

View File

@@ -0,0 +1,25 @@
//go:build !darwin
package ocm
func platformReadCredentials(customPath string) (*oauthCredentials, error) {
if customPath == "" {
var err error
customPath, err = getDefaultCredentialsPath()
if err != nil {
return nil, err
}
}
return readCredentialsFromFile(customPath)
}
func platformWriteCredentials(credentials *oauthCredentials, customPath string) error {
if customPath == "" {
var err error
customPath, err = getDefaultCredentialsPath()
if err != nil {
return err
}
}
return writeCredentialsToFile(credentials, customPath)
}

555
service/ocm/service.go Normal file
View File

@@ -0,0 +1,555 @@
package ocm
import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"mime"
"net"
"net/http"
"strings"
"sync"
"time"
"github.com/sagernet/sing-box/adapter"
boxService "github.com/sagernet/sing-box/adapter/service"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/common/listener"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
aTLS "github.com/sagernet/sing/common/tls"
"github.com/go-chi/chi/v5"
"github.com/openai/openai-go/v3"
"github.com/openai/openai-go/v3/responses"
"golang.org/x/net/http2"
)
func RegisterService(registry *boxService.Registry) {
boxService.Register[option.OCMServiceOptions](registry, C.TypeOCM, NewService)
}
type errorResponse struct {
Error errorDetails `json:"error"`
}
type errorDetails struct {
Type string `json:"type"`
Code string `json:"code,omitempty"`
Message string `json:"message"`
}
func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, errorType string, message string) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
json.NewEncoder(w).Encode(errorResponse{
Error: errorDetails{
Type: errorType,
Message: message,
},
})
}
func isHopByHopHeader(header string) bool {
switch strings.ToLower(header) {
case "connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade", "host":
return true
default:
return false
}
}
type Service struct {
boxService.Adapter
ctx context.Context
logger log.ContextLogger
credentialPath string
credentials *oauthCredentials
users []option.OCMUser
httpClient *http.Client
httpHeaders http.Header
listener *listener.Listener
tlsConfig tls.ServerConfig
httpServer *http.Server
userManager *UserManager
accessMutex sync.RWMutex
usageTracker *AggregatedUsage
trackingGroup sync.WaitGroup
shuttingDown bool
}
func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.OCMServiceOptions) (adapter.Service, error) {
serviceDialer, err := dialer.NewWithOptions(dialer.Options{
Context: ctx,
Options: option.DialerOptions{
Detour: options.Detour,
},
RemoteIsDomain: true,
})
if err != nil {
return nil, E.Cause(err, "create dialer")
}
httpClient := &http.Client{
Transport: &http.Transport{
ForceAttemptHTTP2: true,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return serviceDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
},
},
}
userManager := &UserManager{
tokenMap: make(map[string]string),
}
var usageTracker *AggregatedUsage
if options.UsagesPath != "" {
usageTracker = &AggregatedUsage{
LastUpdated: time.Now(),
Combinations: make([]CostCombination, 0),
filePath: options.UsagesPath,
logger: logger,
}
}
service := &Service{
Adapter: boxService.NewAdapter(C.TypeOCM, tag),
ctx: ctx,
logger: logger,
credentialPath: options.CredentialPath,
users: options.Users,
httpClient: httpClient,
httpHeaders: options.Headers.Build(),
listener: listener.New(listener.Options{
Context: ctx,
Logger: logger,
Network: []string{N.NetworkTCP},
Listen: options.ListenOptions,
}),
userManager: userManager,
usageTracker: usageTracker,
}
if options.TLS != nil {
tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
if err != nil {
return nil, err
}
service.tlsConfig = tlsConfig
}
return service, nil
}
func (s *Service) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
s.userManager.UpdateUsers(s.users)
credentials, err := platformReadCredentials(s.credentialPath)
if err != nil {
return E.Cause(err, "read credentials")
}
s.credentials = credentials
if s.usageTracker != nil {
err = s.usageTracker.Load()
if err != nil {
s.logger.Warn("load usage statistics: ", err)
}
}
router := chi.NewRouter()
router.Mount("/", s)
s.httpServer = &http.Server{Handler: router}
if s.tlsConfig != nil {
err = s.tlsConfig.Start()
if err != nil {
return E.Cause(err, "create TLS config")
}
}
tcpListener, err := s.listener.ListenTCP()
if err != nil {
return err
}
if s.tlsConfig != nil {
if !common.Contains(s.tlsConfig.NextProtos(), http2.NextProtoTLS) {
s.tlsConfig.SetNextProtos(append([]string{"h2"}, s.tlsConfig.NextProtos()...))
}
tcpListener = aTLS.NewListener(tcpListener, s.tlsConfig)
}
go func() {
serveErr := s.httpServer.Serve(tcpListener)
if serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
s.logger.Error("serve error: ", serveErr)
}
}()
return nil
}
func (s *Service) getAccessToken() (string, error) {
s.accessMutex.RLock()
if !s.credentials.needsRefresh() {
token := s.credentials.getAccessToken()
s.accessMutex.RUnlock()
return token, nil
}
s.accessMutex.RUnlock()
s.accessMutex.Lock()
defer s.accessMutex.Unlock()
if !s.credentials.needsRefresh() {
return s.credentials.getAccessToken(), nil
}
newCredentials, err := refreshToken(s.httpClient, s.credentials)
if err != nil {
return "", err
}
s.credentials = newCredentials
err = platformWriteCredentials(newCredentials, s.credentialPath)
if err != nil {
s.logger.Warn("persist refreshed token: ", err)
}
return newCredentials.getAccessToken(), nil
}
func (s *Service) getAccountID() string {
s.accessMutex.RLock()
defer s.accessMutex.RUnlock()
return s.credentials.getAccountID()
}
func (s *Service) isAPIKeyMode() bool {
s.accessMutex.RLock()
defer s.accessMutex.RUnlock()
return s.credentials.isAPIKeyMode()
}
func (s *Service) getBaseURL() string {
if s.isAPIKeyMode() {
return openaiAPIBaseURL
}
return chatGPTBackendURL
}
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
if !strings.HasPrefix(path, "/v1/") {
writeJSONError(w, r, http.StatusNotFound, "invalid_request_error", "path must start with /v1/")
return
}
var proxyPath string
if s.isAPIKeyMode() {
proxyPath = path
} else {
if path == "/v1/chat/completions" {
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
"chat completions endpoint is only available in API key mode")
return
}
proxyPath = strings.TrimPrefix(path, "/v1")
}
var username string
if len(s.users) > 0 {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
return
}
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
if clientToken == authHeader {
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
return
}
var ok bool
username, ok = s.userManager.Authenticate(clientToken)
if !ok {
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
return
}
}
var requestModel string
if s.usageTracker != nil && r.Body != nil {
bodyBytes, err := io.ReadAll(r.Body)
if err == nil {
var request struct {
Model string `json:"model"`
}
err := json.Unmarshal(bodyBytes, &request)
if err == nil {
requestModel = request.Model
}
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
}
}
accessToken, err := s.getAccessToken()
if err != nil {
s.logger.Error("get access token: ", err)
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "Authentication failed")
return
}
proxyURL := s.getBaseURL() + proxyPath
if r.URL.RawQuery != "" {
proxyURL += "?" + r.URL.RawQuery
}
proxyRequest, err := http.NewRequestWithContext(r.Context(), r.Method, proxyURL, r.Body)
if err != nil {
s.logger.Error("create proxy request: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
return
}
for key, values := range r.Header {
if !isHopByHopHeader(key) && key != "Authorization" {
proxyRequest.Header[key] = values
}
}
for key, values := range s.httpHeaders {
proxyRequest.Header.Del(key)
proxyRequest.Header[key] = values
}
proxyRequest.Header.Set("Authorization", "Bearer "+accessToken)
if accountID := s.getAccountID(); accountID != "" {
proxyRequest.Header.Set("ChatGPT-Account-Id", accountID)
}
response, err := s.httpClient.Do(proxyRequest)
if err != nil {
writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
return
}
defer response.Body.Close()
for key, values := range response.Header {
if !isHopByHopHeader(key) {
w.Header()[key] = values
}
}
w.WriteHeader(response.StatusCode)
trackUsage := s.usageTracker != nil && response.StatusCode == http.StatusOK &&
(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses"))
if trackUsage {
s.handleResponseWithTracking(w, response, path, requestModel, username)
} else {
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
if err == nil && mediaType != "text/event-stream" {
_, _ = io.Copy(w, response.Body)
return
}
flusher, ok := w.(http.Flusher)
if !ok {
s.logger.Error("streaming not supported")
return
}
buffer := make([]byte, buf.BufferSize)
for {
n, err := response.Body.Read(buffer)
if n > 0 {
_, writeError := w.Write(buffer[:n])
if writeError != nil {
s.logger.Error("write streaming response: ", writeError)
return
}
flusher.Flush()
}
if err != nil {
return
}
}
}
}
func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, path string, requestModel string, username string) {
isChatCompletions := path == "/v1/chat/completions"
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
isStreaming := err == nil && mediaType == "text/event-stream"
if !isStreaming {
bodyBytes, err := io.ReadAll(response.Body)
if err != nil {
s.logger.Error("read response body: ", err)
return
}
var responseModel string
var inputTokens, outputTokens, cachedTokens int64
if isChatCompletions {
var chatCompletion openai.ChatCompletion
if json.Unmarshal(bodyBytes, &chatCompletion) == nil {
responseModel = chatCompletion.Model
inputTokens = chatCompletion.Usage.PromptTokens
outputTokens = chatCompletion.Usage.CompletionTokens
cachedTokens = chatCompletion.Usage.PromptTokensDetails.CachedTokens
}
} else {
var responsesResponse responses.Response
if json.Unmarshal(bodyBytes, &responsesResponse) == nil {
responseModel = string(responsesResponse.Model)
inputTokens = responsesResponse.Usage.InputTokens
outputTokens = responsesResponse.Usage.OutputTokens
cachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
}
}
if inputTokens > 0 || outputTokens > 0 {
if responseModel == "" {
responseModel = requestModel
}
if responseModel != "" {
s.usageTracker.AddUsage(responseModel, inputTokens, outputTokens, cachedTokens, username)
}
}
_, _ = writer.Write(bodyBytes)
return
}
flusher, ok := writer.(http.Flusher)
if !ok {
s.logger.Error("streaming not supported")
return
}
var inputTokens, outputTokens, cachedTokens int64
var responseModel string
buffer := make([]byte, buf.BufferSize)
var leftover []byte
for {
n, err := response.Body.Read(buffer)
if n > 0 {
data := append(leftover, buffer[:n]...)
lines := bytes.Split(data, []byte("\n"))
if err == nil {
leftover = lines[len(lines)-1]
lines = lines[:len(lines)-1]
} else {
leftover = nil
}
for _, line := range lines {
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
if bytes.HasPrefix(line, []byte("data: ")) {
eventData := bytes.TrimPrefix(line, []byte("data: "))
if bytes.Equal(eventData, []byte("[DONE]")) {
continue
}
if isChatCompletions {
var chatChunk openai.ChatCompletionChunk
if json.Unmarshal(eventData, &chatChunk) == nil {
if chatChunk.Model != "" {
responseModel = chatChunk.Model
}
if chatChunk.Usage.PromptTokens > 0 {
inputTokens = chatChunk.Usage.PromptTokens
cachedTokens = chatChunk.Usage.PromptTokensDetails.CachedTokens
}
if chatChunk.Usage.CompletionTokens > 0 {
outputTokens = chatChunk.Usage.CompletionTokens
}
}
} else {
var streamEvent responses.ResponseStreamEventUnion
if json.Unmarshal(eventData, &streamEvent) == nil {
if streamEvent.Type == "response.completed" {
completedEvent := streamEvent.AsResponseCompleted()
if string(completedEvent.Response.Model) != "" {
responseModel = string(completedEvent.Response.Model)
}
if completedEvent.Response.Usage.InputTokens > 0 {
inputTokens = completedEvent.Response.Usage.InputTokens
cachedTokens = completedEvent.Response.Usage.InputTokensDetails.CachedTokens
}
if completedEvent.Response.Usage.OutputTokens > 0 {
outputTokens = completedEvent.Response.Usage.OutputTokens
}
}
}
}
}
}
_, writeError := writer.Write(buffer[:n])
if writeError != nil {
s.logger.Error("write streaming response: ", writeError)
return
}
flusher.Flush()
}
if err != nil {
if responseModel == "" {
responseModel = requestModel
}
if inputTokens > 0 || outputTokens > 0 {
if responseModel != "" {
s.usageTracker.AddUsage(responseModel, inputTokens, outputTokens, cachedTokens, username)
}
}
return
}
}
}
func (s *Service) Close() error {
err := common.Close(
common.PtrOrNil(s.httpServer),
common.PtrOrNil(s.listener),
s.tlsConfig,
)
if s.usageTracker != nil {
s.usageTracker.cancelPendingSave()
saveErr := s.usageTracker.Save()
if saveErr != nil {
s.logger.Error("save usage statistics: ", saveErr)
}
}
return err
}

View File

@@ -0,0 +1,445 @@
package ocm
import (
"encoding/json"
"math"
"os"
"regexp"
"sync"
"time"
"github.com/sagernet/sing-box/log"
E "github.com/sagernet/sing/common/exceptions"
)
type UsageStats struct {
RequestCount int `json:"request_count"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
CachedTokens int64 `json:"cached_tokens"`
}
func (u *UsageStats) UnmarshalJSON(data []byte) error {
type Alias UsageStats
aux := &struct {
*Alias
PromptTokens int64 `json:"prompt_tokens"`
CompletionTokens int64 `json:"completion_tokens"`
}{
Alias: (*Alias)(u),
}
err := json.Unmarshal(data, aux)
if err != nil {
return err
}
if u.InputTokens == 0 && aux.PromptTokens > 0 {
u.InputTokens = aux.PromptTokens
}
if u.OutputTokens == 0 && aux.CompletionTokens > 0 {
u.OutputTokens = aux.CompletionTokens
}
return nil
}
type CostCombination struct {
Model string `json:"model"`
Total UsageStats `json:"total"`
ByUser map[string]UsageStats `json:"by_user"`
}
type AggregatedUsage struct {
LastUpdated time.Time `json:"last_updated"`
Combinations []CostCombination `json:"combinations"`
mutex sync.Mutex
filePath string
logger log.ContextLogger
lastSaveTime time.Time
pendingSave bool
saveTimer *time.Timer
saveMutex sync.Mutex
}
type UsageStatsJSON struct {
RequestCount int `json:"request_count"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
CachedTokens int64 `json:"cached_tokens"`
CostUSD float64 `json:"cost_usd"`
}
type CostCombinationJSON struct {
Model string `json:"model"`
Total UsageStatsJSON `json:"total"`
ByUser map[string]UsageStatsJSON `json:"by_user"`
}
type CostsSummaryJSON struct {
TotalUSD float64 `json:"total_usd"`
ByUser map[string]float64 `json:"by_user"`
}
type AggregatedUsageJSON struct {
LastUpdated time.Time `json:"last_updated"`
Costs CostsSummaryJSON `json:"costs"`
Combinations []CostCombinationJSON `json:"combinations"`
}
type ModelPricing struct {
InputPrice float64
OutputPrice float64
CachedInputPrice float64
}
type modelFamily struct {
pattern *regexp.Regexp
pricing ModelPricing
}
var (
gpt4oPricing = ModelPricing{
InputPrice: 2.5,
OutputPrice: 10.0,
CachedInputPrice: 1.25,
}
gpt4oMiniPricing = ModelPricing{
InputPrice: 0.15,
OutputPrice: 0.6,
CachedInputPrice: 0.075,
}
gpt4oAudioPricing = ModelPricing{
InputPrice: 2.5,
OutputPrice: 10.0,
CachedInputPrice: 1.25,
}
o1Pricing = ModelPricing{
InputPrice: 15.0,
OutputPrice: 60.0,
CachedInputPrice: 7.5,
}
o1MiniPricing = ModelPricing{
InputPrice: 1.1,
OutputPrice: 4.4,
CachedInputPrice: 0.55,
}
o3MiniPricing = ModelPricing{
InputPrice: 1.1,
OutputPrice: 4.4,
CachedInputPrice: 0.55,
}
o3Pricing = ModelPricing{
InputPrice: 2.0,
OutputPrice: 8.0,
CachedInputPrice: 1.0,
}
o4MiniPricing = ModelPricing{
InputPrice: 1.1,
OutputPrice: 4.4,
CachedInputPrice: 0.55,
}
gpt41Pricing = ModelPricing{
InputPrice: 2.0,
OutputPrice: 8.0,
CachedInputPrice: 0.5,
}
gpt41MiniPricing = ModelPricing{
InputPrice: 0.4,
OutputPrice: 1.6,
CachedInputPrice: 0.1,
}
gpt41NanoPricing = ModelPricing{
InputPrice: 0.1,
OutputPrice: 0.4,
CachedInputPrice: 0.025,
}
modelFamilies = []modelFamily{
{
pattern: regexp.MustCompile(`^gpt-4\.1-nano`),
pricing: gpt41NanoPricing,
},
{
pattern: regexp.MustCompile(`^gpt-4\.1-mini`),
pricing: gpt41MiniPricing,
},
{
pattern: regexp.MustCompile(`^gpt-4\.1`),
pricing: gpt41Pricing,
},
{
pattern: regexp.MustCompile(`^o4-mini`),
pricing: o4MiniPricing,
},
{
pattern: regexp.MustCompile(`^o3-mini`),
pricing: o3MiniPricing,
},
{
pattern: regexp.MustCompile(`^o3`),
pricing: o3Pricing,
},
{
pattern: regexp.MustCompile(`^o1-mini`),
pricing: o1MiniPricing,
},
{
pattern: regexp.MustCompile(`^o1`),
pricing: o1Pricing,
},
{
pattern: regexp.MustCompile(`^gpt-4o-audio`),
pricing: gpt4oAudioPricing,
},
{
pattern: regexp.MustCompile(`^gpt-4o-mini`),
pricing: gpt4oMiniPricing,
},
{
pattern: regexp.MustCompile(`^gpt-4o`),
pricing: gpt4oPricing,
},
{
pattern: regexp.MustCompile(`^chatgpt-4o`),
pricing: gpt4oPricing,
},
}
)
func getPricing(model string) ModelPricing {
for _, family := range modelFamilies {
if family.pattern.MatchString(model) {
return family.pricing
}
}
return gpt4oPricing
}
func calculateCost(stats UsageStats, model string) float64 {
pricing := getPricing(model)
regularInputTokens := stats.InputTokens - stats.CachedTokens
if regularInputTokens < 0 {
regularInputTokens = 0
}
cost := (float64(regularInputTokens)*pricing.InputPrice +
float64(stats.OutputTokens)*pricing.OutputPrice +
float64(stats.CachedTokens)*pricing.CachedInputPrice) / 1_000_000
return math.Round(cost*100) / 100
}
func (u *AggregatedUsage) ToJSON() *AggregatedUsageJSON {
u.mutex.Lock()
defer u.mutex.Unlock()
result := &AggregatedUsageJSON{
LastUpdated: u.LastUpdated,
Combinations: make([]CostCombinationJSON, len(u.Combinations)),
Costs: CostsSummaryJSON{
TotalUSD: 0,
ByUser: make(map[string]float64),
},
}
for i, combo := range u.Combinations {
totalCost := calculateCost(combo.Total, combo.Model)
result.Costs.TotalUSD += totalCost
comboJSON := CostCombinationJSON{
Model: combo.Model,
Total: UsageStatsJSON{
RequestCount: combo.Total.RequestCount,
InputTokens: combo.Total.InputTokens,
OutputTokens: combo.Total.OutputTokens,
CachedTokens: combo.Total.CachedTokens,
CostUSD: totalCost,
},
ByUser: make(map[string]UsageStatsJSON),
}
for user, userStats := range combo.ByUser {
userCost := calculateCost(userStats, combo.Model)
result.Costs.ByUser[user] += userCost
comboJSON.ByUser[user] = UsageStatsJSON{
RequestCount: userStats.RequestCount,
InputTokens: userStats.InputTokens,
OutputTokens: userStats.OutputTokens,
CachedTokens: userStats.CachedTokens,
CostUSD: userCost,
}
}
result.Combinations[i] = comboJSON
}
result.Costs.TotalUSD = math.Round(result.Costs.TotalUSD*100) / 100
for user, cost := range result.Costs.ByUser {
result.Costs.ByUser[user] = math.Round(cost*100) / 100
}
return result
}
func (u *AggregatedUsage) Load() error {
u.mutex.Lock()
defer u.mutex.Unlock()
data, err := os.ReadFile(u.filePath)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
var temp struct {
LastUpdated time.Time `json:"last_updated"`
Combinations []CostCombination `json:"combinations"`
}
err = json.Unmarshal(data, &temp)
if err != nil {
return err
}
u.LastUpdated = temp.LastUpdated
u.Combinations = temp.Combinations
for i := range u.Combinations {
if u.Combinations[i].ByUser == nil {
u.Combinations[i].ByUser = make(map[string]UsageStats)
}
}
return nil
}
func (u *AggregatedUsage) Save() error {
jsonData := u.ToJSON()
data, err := json.MarshalIndent(jsonData, "", " ")
if err != nil {
return err
}
tmpFile := u.filePath + ".tmp"
err = os.WriteFile(tmpFile, data, 0o644)
if err != nil {
return err
}
defer os.Remove(tmpFile)
err = os.Rename(tmpFile, u.filePath)
if err == nil {
u.saveMutex.Lock()
u.lastSaveTime = time.Now()
u.saveMutex.Unlock()
}
return err
}
func (u *AggregatedUsage) AddUsage(model string, inputTokens, outputTokens, cachedTokens int64, user string) error {
if model == "" {
return E.New("model cannot be empty")
}
u.mutex.Lock()
defer u.mutex.Unlock()
u.LastUpdated = time.Now()
var combo *CostCombination
for i := range u.Combinations {
if u.Combinations[i].Model == model {
combo = &u.Combinations[i]
break
}
}
if combo == nil {
newCombo := CostCombination{
Model: model,
Total: UsageStats{},
ByUser: make(map[string]UsageStats),
}
u.Combinations = append(u.Combinations, newCombo)
combo = &u.Combinations[len(u.Combinations)-1]
}
combo.Total.RequestCount++
combo.Total.InputTokens += inputTokens
combo.Total.OutputTokens += outputTokens
combo.Total.CachedTokens += cachedTokens
if user != "" {
userStats := combo.ByUser[user]
userStats.RequestCount++
userStats.InputTokens += inputTokens
userStats.OutputTokens += outputTokens
userStats.CachedTokens += cachedTokens
combo.ByUser[user] = userStats
}
go u.scheduleSave()
return nil
}
func (u *AggregatedUsage) scheduleSave() {
const saveInterval = time.Minute
u.saveMutex.Lock()
defer u.saveMutex.Unlock()
timeSinceLastSave := time.Since(u.lastSaveTime)
if timeSinceLastSave >= saveInterval {
go u.saveAsync()
return
}
if u.pendingSave {
return
}
u.pendingSave = true
remainingTime := saveInterval - timeSinceLastSave
u.saveTimer = time.AfterFunc(remainingTime, func() {
u.saveMutex.Lock()
u.pendingSave = false
u.saveMutex.Unlock()
u.saveAsync()
})
}
func (u *AggregatedUsage) saveAsync() {
err := u.Save()
if err != nil {
if u.logger != nil {
u.logger.Error("save usage statistics: ", err)
}
}
}
func (u *AggregatedUsage) cancelPendingSave() {
u.saveMutex.Lock()
defer u.saveMutex.Unlock()
if u.saveTimer != nil {
u.saveTimer.Stop()
u.saveTimer = nil
}
u.pendingSave = false
}

View File

@@ -0,0 +1,29 @@
package ocm
import (
"sync"
"github.com/sagernet/sing-box/option"
)
type UserManager struct {
accessMutex sync.RWMutex
tokenMap map[string]string
}
func (m *UserManager) UpdateUsers(users []option.OCMUser) {
m.accessMutex.Lock()
defer m.accessMutex.Unlock()
tokenMap := make(map[string]string, len(users))
for _, user := range users {
tokenMap[user.Token] = user.Name
}
m.tokenMap = tokenMap
}
func (m *UserManager) Authenticate(token string) (string, bool) {
m.accessMutex.RLock()
username, found := m.tokenMap[token]
m.accessMutex.RUnlock()
return username, found
}