ccm,ocm: reorganize files and improve naming conventions

Split credential_state.go (1500+ lines) into credential.go,
credential_default.go, credential_provider.go, credential_builder.go.

Split service.go (900+ lines) into service.go, service_handler.go,
service_status.go.

Rename credential.go to credential_oauth.go to avoid name conflict
with the credential interface.

Apply naming fixes: accessMutex→access, stateMutex→stateAccess,
sessionMutex→sessionAccess, webSocketMutex→webSocketAccess,
httpTransport()→httpClient(), httpClient field→forwardHTTPClient,
weeklyWindowDuration→weeklyWindowHours.
This commit is contained in:
世界
2026-03-14 20:17:23 +08:00
parent 51d564c9ff
commit 04bd63b455
24 changed files with 4877 additions and 4776 deletions

View File

@@ -1,224 +1,187 @@
package ccm
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"os"
"os/user"
"path/filepath"
"runtime"
"slices"
"strconv"
"sync"
"time"
"github.com/sagernet/sing-box/log"
E "github.com/sagernet/sing/common/exceptions"
)
// OAuth endpoints, protocol values and polling/retry intervals used by this package.
const (
// oauth2ClientID is sent as the client_id field of the token refresh request.
oauth2ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
// oauth2TokenURL is the endpoint the token refresh request is POSTed to.
oauth2TokenURL = "https://platform.claude.com/v1/oauth/token"
// claudeAPIBaseURL is the prefix of the usage polling endpoint (/api/oauth/usage).
claudeAPIBaseURL = "https://api.anthropic.com"
// tokenRefreshBufferMs is how many milliseconds before expiresAt a token is
// already treated as needing a refresh.
tokenRefreshBufferMs = 60000
// anthropicBetaOAuthValue is sent as the anthropic-beta header on usage polls.
anthropicBetaOAuthValue = "oauth-2025-04-20"
// NOTE(review): presumably the regular usage poll interval; its use is not
// visible in this part of the file — confirm against the caller.
defaultPollInterval = 60 * time.Minute
// failedPollRetryInterval is the base delay that is doubled per consecutive poll failure.
failedPollRetryInterval = time.Minute
// httpRetryMaxBackoff is the upper bound applied to the poll failure backoff.
httpRetryMaxBackoff = 5 * time.Minute
)
// ccmUserAgentFallback is the User-Agent value used when the installed
// Claude Code version cannot be detected.
const ccmUserAgentFallback = "claude-code/2.1.72"
var (
ccmUserAgentOnce sync.Once
ccmUserAgentValue string
const (
httpRetryMaxAttempts = 3
httpRetryInitialDelay = 200 * time.Millisecond
)
func initCCMUserAgent(logger log.ContextLogger) {
ccmUserAgentOnce.Do(func() {
version, err := detectClaudeCodeVersion()
if err != nil {
logger.Error("detect Claude Code version: ", err)
ccmUserAgentValue = ccmUserAgentFallback
return
const sessionExpiry = 24 * time.Hour
func doHTTPWithRetry(ctx context.Context, client *http.Client, buildRequest func() (*http.Request, error)) (*http.Response, error) {
var lastError error
for attempt := range httpRetryMaxAttempts {
if attempt > 0 {
delay := httpRetryInitialDelay * time.Duration(1<<(attempt-1))
select {
case <-ctx.Done():
return nil, lastError
case <-time.After(delay):
}
}
logger.Debug("detected Claude Code version: ", version)
ccmUserAgentValue = "claude-code/" + version
})
}
// detectClaudeCodeVersion derives the installed Claude Code version from the
// symlink at <home>/.local/bin/claude (claude.exe on Windows). The link must
// point at an entry located directly inside a directory named "versions";
// the base name of that entry is returned as the version string.
func detectClaudeCodeVersion() (string, error) {
	realUser, err := getRealUser()
	if err != nil {
		return "", E.Cause(err, "get user")
	}

	executable := "claude"
	if runtime.GOOS == "windows" {
		executable = "claude.exe"
	}

	symlink := filepath.Join(realUser.HomeDir, ".local", "bin", executable)
	resolved, err := os.Readlink(symlink)
	if err != nil {
		return "", E.Cause(err, "readlink ", symlink)
	}
	if !filepath.IsAbs(resolved) {
		// A relative link target is interpreted relative to the link's own directory.
		resolved = filepath.Join(filepath.Dir(symlink), resolved)
	}

	if filepath.Base(filepath.Dir(resolved)) != "versions" {
		return "", E.New("unexpected symlink target: ", resolved)
	}
	return filepath.Base(resolved), nil
}
// getRealUser returns the account named by the SUDO_USER environment variable
// when it is set and can be looked up, and the current process user otherwise.
func getRealUser() (*user.User, error) {
	sudoUser := os.Getenv("SUDO_USER")
	if sudoUser == "" {
		return user.Current()
	}
	if resolved, lookupErr := user.Lookup(sudoUser); lookupErr == nil {
		return resolved, nil
	}
	// A failed lookup is not fatal: fall back to the current user.
	return user.Current()
}
// getDefaultCredentialsPath returns the location of .credentials.json: inside
// CLAUDE_CONFIG_DIR when that variable is set, otherwise inside the .claude
// directory of the real user's home.
func getDefaultCredentialsPath() (string, error) {
	const credentialsFileName = ".credentials.json"

	configDir := os.Getenv("CLAUDE_CONFIG_DIR")
	if configDir != "" {
		return filepath.Join(configDir, credentialsFileName), nil
	}

	realUser, err := getRealUser()
	if err != nil {
		return "", err
	}
	return filepath.Join(realUser.HomeDir, ".claude", credentialsFileName), nil
}
// readCredentialsFromFile parses the JSON file at path and returns the object
// stored under its "claudeAiOauth" key. A missing key is reported as an error.
func readCredentialsFromFile(path string) (*oauthCredentials, error) {
	content, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}

	var container struct {
		ClaudeAIAuth *oauthCredentials `json:"claudeAiOauth,omitempty"`
	}
	if err = json.Unmarshal(content, &container); err != nil {
		return nil, err
	}

	if stored := container.ClaudeAIAuth; stored != nil {
		return stored, nil
	}
	return nil, E.New("claudeAiOauth field not found in credentials")
}
// checkCredentialFileWritable reports whether the existing file at path can be
// opened for writing. The file is neither created nor truncated.
func checkCredentialFileWritable(path string) error {
	handle, openErr := os.OpenFile(path, os.O_WRONLY, 0)
	if openErr == nil {
		return handle.Close()
	}
	return openErr
}
// writeCredentialsToFile serializes the credentials under a single
// "claudeAiOauth" key and writes the result to path with mode 0600.
func writeCredentialsToFile(oauthCredentials *oauthCredentials, path string) error {
	container := map[string]any{
		"claudeAiOauth": oauthCredentials,
	}
	encoded, err := json.MarshalIndent(container, "", " ")
	if err != nil {
		return err
	}
	return os.WriteFile(path, encoded, 0o600)
}
// oauthCredentials is the OAuth credential record stored under the
// "claudeAiOauth" key of the credentials file.
type oauthCredentials struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
// ExpiresAt is compared against time.Now().UnixMilli(), i.e. Unix
// milliseconds; zero means the token is never considered due for refresh.
ExpiresAt int64 `json:"expiresAt"`
Scopes []string `json:"scopes,omitempty"`
SubscriptionType string `json:"subscriptionType,omitempty"`
RateLimitTier string `json:"rateLimitTier,omitempty"`
IsMax bool `json:"isMax,omitempty"`
}
// needsRefresh reports whether the access token expires within the next
// tokenRefreshBufferMs milliseconds. Credentials without an expiry
// (ExpiresAt == 0) never need a refresh.
func (c *oauthCredentials) needsRefresh() bool {
	expiresAt := c.ExpiresAt
	if expiresAt == 0 {
		return false
	}
	refreshThreshold := expiresAt - tokenRefreshBufferMs
	return time.Now().UnixMilli() >= refreshThreshold
}
func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) {
if credentials.RefreshToken == "" {
return nil, E.New("refresh token is empty")
}
requestBody, err := json.Marshal(map[string]string{
"grant_type": "refresh_token",
"refresh_token": credentials.RefreshToken,
"client_id": oauth2ClientID,
})
if err != nil {
return nil, E.Cause(err, "marshal request")
}
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
request, err := buildRequest()
if err != nil {
return nil, err
}
request.Header.Set("Content-Type", "application/json")
request.Header.Set("User-Agent", ccmUserAgentValue)
return request, nil
response, err := client.Do(request)
if err == nil {
return response, nil
}
lastError = err
if ctx.Err() != nil {
return nil, lastError
}
}
return nil, lastError
}
// credentialState is the mutable usage and availability snapshot kept for a
// credential.
type credentialState struct {
// Utilization values are compared directly against the 5h/weekly caps.
fiveHourUtilization float64
fiveHourReset time.Time
weeklyUtilization float64
weeklyReset time.Time
// hardRateLimited is set by markRateLimited together with rateLimitResetAt
// and cleared once that time has passed.
hardRateLimited bool
rateLimitResetAt time.Time
// accountType and rateLimitTier mirror the subscription type and rate limit
// tier of the loaded OAuth credentials.
accountType string
rateLimitTier string
remotePlanWeight float64
// lastUpdated is refreshed on header updates that carried data and after
// every usage poll attempt.
lastUpdated time.Time
// consecutivePollFailures is reset to zero on a successful update; any
// non-zero value makes the credential unusable.
consecutivePollFailures int
unavailable bool
lastCredentialLoadAttempt time.Time
lastCredentialLoadError string
}
// credentialRequestContext is a request context that can be cancelled on its
// own and that remembers the "stop" functions linking it to credential-level
// interruption.
type credentialRequestContext struct {
	context.Context

	releaseOnce  sync.Once
	cancelOnce   sync.Once
	releaseFuncs []func() bool
	cancelFunc   context.CancelFunc
}

// addInterruptLink registers one more stop function to be invoked on release.
func (c *credentialRequestContext) addInterruptLink(stop func() bool) {
	c.releaseFuncs = append(c.releaseFuncs, stop)
}

// releaseCredentialInterrupt invokes every registered stop function in
// registration order. Only the first call has any effect.
func (c *credentialRequestContext) releaseCredentialInterrupt() {
	c.releaseOnce.Do(func() {
		links := c.releaseFuncs
		for index := 0; index < len(links); index++ {
			links[index]()
		}
	})
}

// cancelRequest releases the interrupt links and then cancels the request
// context; the cancel function runs at most once.
func (c *credentialRequestContext) cancelRequest() {
	c.releaseCredentialInterrupt()
	c.cancelOnce.Do(c.cancelFunc)
}
// credential is the common interface of the upstream credentials handled by
// this package. Values created by newDefaultCredential and
// newExternalCredential are both stored as credential.
type credential interface {
tagName() string
isAvailable() bool
isUsable() bool
isExternal() bool
fiveHourUtilization() float64
weeklyUtilization() float64
fiveHourCap() float64
weeklyCap() float64
planWeight() float64
weeklyResetTime() time.Time
markRateLimited(resetAt time.Time)
earliestReset() time.Time
unavailableError() error
getAccessToken() (string, error)
buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error)
updateStateFromHeaders(header http.Header)
wrapRequestContext(ctx context.Context) *credentialRequestContext
interruptConnections()
start() error
pollUsage(ctx context.Context)
lastUpdatedTime() time.Time
pollBackoff(base time.Duration) time.Duration
// usageTrackerOrNil: NOTE(review) the name suggests nil is a valid result
// when no usage tracking is configured — confirm in the implementations.
usageTrackerOrNil() *AggregatedUsage
httpClient() *http.Client
close()
}
// credentialSelectionScope names the set of credentials a selection applies to.
type credentialSelectionScope string
const (
// credentialSelectionScopeAll is the scope used when none is set
// (see credentialSelection.scopeOrDefault).
credentialSelectionScopeAll credentialSelectionScope = "all"
// NOTE(review): presumably restricts selection to credentials whose
// isExternal() is false — the consumer is not visible here; confirm.
credentialSelectionScopeNonExternal credentialSelectionScope = "non_external"
)
// credentialSelection describes which credentials may be picked: a scope name
// plus an optional predicate.
type credentialSelection struct {
// scope may be empty; scopeOrDefault then reports credentialSelectionScopeAll.
scope credentialSelectionScope
// filter may be nil, in which case every credential is allowed.
filter func(credential) bool
}
// allows reports whether cred passes the selection's filter. A selection
// without a filter allows every credential.
func (s credentialSelection) allows(cred credential) bool {
	if s.filter == nil {
		return true
	}
	return s.filter(cred)
}
// scopeOrDefault returns the configured scope, or credentialSelectionScopeAll
// when no scope has been set.
func (s credentialSelection) scopeOrDefault() credentialSelectionScope {
	if scope := s.scope; scope != "" {
		return scope
	}
	return credentialSelectionScopeAll
}
// Claude Code's unified rate-limit handling parses these reset headers with
// Number(...), compares them against Date.now()/1000, and renders them via
// new Date(seconds*1000), so keep the wire format pinned to Unix epoch seconds.
func parseAnthropicResetHeaderValue(headerName string, headerValue string) time.Time {
unixEpoch, err := strconv.ParseInt(headerValue, 10, 64)
if err != nil {
return nil, err
panic("invalid " + headerName + " header: expected Unix epoch seconds, got " + strconv.Quote(headerValue))
}
defer response.Body.Close()
if response.StatusCode == http.StatusTooManyRequests {
body, _ := io.ReadAll(response.Body)
return nil, E.New("refresh rate limited: ", response.Status, " ", string(body))
if unixEpoch <= 0 {
panic("invalid " + headerName + " header: expected positive Unix epoch seconds, got " + strconv.Quote(headerValue))
}
if response.StatusCode != http.StatusOK {
body, _ := io.ReadAll(response.Body)
return nil, E.New("refresh failed: ", response.Status, " ", string(body))
}
var tokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
}
err = json.NewDecoder(response.Body).Decode(&tokenResponse)
if err != nil {
return nil, E.Cause(err, "decode response")
}
newCredentials := *credentials
newCredentials.AccessToken = tokenResponse.AccessToken
if tokenResponse.RefreshToken != "" {
newCredentials.RefreshToken = tokenResponse.RefreshToken
}
newCredentials.ExpiresAt = time.Now().UnixMilli() + int64(tokenResponse.ExpiresIn)*1000
return &newCredentials, nil
return time.Unix(unixEpoch, 0)
}
func cloneCredentials(credentials *oauthCredentials) *oauthCredentials {
if credentials == nil {
return nil
func parseOptionalAnthropicResetHeader(headers http.Header, headerName string) (time.Time, bool) {
headerValue := headers.Get(headerName)
if headerValue == "" {
return time.Time{}, false
}
cloned := *credentials
cloned.Scopes = append([]string(nil), credentials.Scopes...)
return &cloned
return parseAnthropicResetHeaderValue(headerName, headerValue), true
}
func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool {
if left == nil || right == nil {
return left == right
// parseRequiredAnthropicResetHeader parses the named reset header and panics
// when the header is absent or empty.
func parseRequiredAnthropicResetHeader(headers http.Header, headerName string) time.Time {
	if headerValue := headers.Get(headerName); headerValue != "" {
		return parseAnthropicResetHeaderValue(headerName, headerValue)
	}
	panic("missing required " + headerName + " header")
}
func parseRateLimitResetFromHeaders(headers http.Header) time.Time {
claim := headers.Get("anthropic-ratelimit-unified-representative-claim")
switch claim {
case "5h":
return parseRequiredAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset")
case "7d":
return parseRequiredAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset")
default:
panic("invalid anthropic-ratelimit-unified-representative-claim header: " + strconv.Quote(claim))
}
return left.AccessToken == right.AccessToken &&
left.RefreshToken == right.RefreshToken &&
left.ExpiresAt == right.ExpiresAt &&
slices.Equal(left.Scopes, right.Scopes) &&
left.SubscriptionType == right.SubscriptionType &&
left.RateLimitTier == right.RateLimitTier &&
left.IsMax == right.IsMax
}

View File

@@ -0,0 +1,192 @@
package ccm
import (
"context"
"time"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
)
// buildCredentialProviders creates one provider per configured credential.
//
// "default" and "external" credentials are created first and each wrapped in
// a singleCredentialProvider; "balancer" entries are resolved afterwards so
// they can reference any of those credentials by tag. It returns the
// providers keyed by tag together with all created credentials in
// configuration order.
func buildCredentialProviders(
	ctx context.Context,
	options option.CCMServiceOptions,
	logger log.ContextLogger,
) (map[string]credentialProvider, []credential, error) {
	var allCreds []credential
	allCredentialMap := make(map[string]credential)
	providers := make(map[string]credentialProvider)

	// register records a freshly created credential under its tag.
	register := func(tag string, created credential) {
		allCredentialMap[tag] = created
		allCreds = append(allCreds, created)
		providers[tag] = &singleCredentialProvider{cred: created}
	}

	// First pass: concrete (default and external) credentials.
	for _, credOpt := range options.Credentials {
		if credOpt.Type == "default" {
			created, err := newDefaultCredential(ctx, credOpt.Tag, credOpt.DefaultOptions, logger)
			if err != nil {
				return nil, nil, err
			}
			register(credOpt.Tag, created)
		} else if credOpt.Type == "external" {
			created, err := newExternalCredential(ctx, credOpt.Tag, credOpt.ExternalOptions, logger)
			if err != nil {
				return nil, nil, err
			}
			register(credOpt.Tag, created)
		}
	}

	// Second pass: balancers, which only reference credentials from the first pass.
	for _, credOpt := range options.Credentials {
		if credOpt.Type != "balancer" {
			continue
		}
		members, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, allCredentialMap, credOpt.Tag)
		if err != nil {
			return nil, nil, err
		}
		providers[credOpt.Tag] = newBalancerProvider(
			members,
			credOpt.BalancerOptions.Strategy,
			time.Duration(credOpt.BalancerOptions.PollInterval),
			credOpt.BalancerOptions.RebalanceThreshold,
			logger,
		)
	}
	return providers, allCreds, nil
}
// resolveCredentialTags maps each tag to its credential, preserving order.
// An empty tag list and a tag missing from allCredentials are both errors;
// parentTag names the referencing credential in the error message.
func resolveCredentialTags(tags []string, allCredentials map[string]credential, parentTag string) ([]credential, error) {
	if len(tags) == 0 {
		return nil, E.New("credential ", parentTag, " has no sub-credentials")
	}
	resolved := make([]credential, len(tags))
	for index, tag := range tags {
		found, known := allCredentials[tag]
		if !known {
			return nil, E.New("credential ", parentTag, " references unknown credential: ", tag)
		}
		resolved[index] = found
	}
	return resolved, nil
}
// validateCCMOptions checks the service options for consistency.
//
// Without a credentials list nothing is validated. With one, the legacy
// top-level credential_path, usages_path and detour options are rejected,
// credential tags must be unique, per-type option limits are enforced, and
// every user must reference known credentials.
func validateCCMOptions(options option.CCMServiceOptions) error {
	if len(options.Credentials) == 0 {
		return nil
	}

	// The legacy single-credential options cannot be combined with the list.
	switch {
	case options.CredentialPath != "":
		return E.New("credential_path and credentials are mutually exclusive")
	case options.UsagesPath != "":
		return E.New("usages_path and credentials are mutually exclusive; use usages_path on individual credentials")
	case options.Detour != "":
		return E.New("detour and credentials are mutually exclusive; use detour on individual credentials")
	}

	knownTags := make(map[string]bool)
	credentialTypes := make(map[string]string)
	for _, cred := range options.Credentials {
		if knownTags[cred.Tag] {
			return E.New("duplicate credential tag: ", cred.Tag)
		}
		knownTags[cred.Tag] = true
		credentialTypes[cred.Tag] = cred.Type

		switch cred.Type {
		case "default", "":
			switch {
			case cred.DefaultOptions.Reserve5h > 99:
				return E.New("credential ", cred.Tag, ": reserve_5h must be at most 99")
			case cred.DefaultOptions.ReserveWeekly > 99:
				return E.New("credential ", cred.Tag, ": reserve_weekly must be at most 99")
			case cred.DefaultOptions.Limit5h > 100:
				return E.New("credential ", cred.Tag, ": limit_5h must be at most 100")
			case cred.DefaultOptions.LimitWeekly > 100:
				return E.New("credential ", cred.Tag, ": limit_weekly must be at most 100")
			case cred.DefaultOptions.Reserve5h > 0 && cred.DefaultOptions.Limit5h > 0:
				return E.New("credential ", cred.Tag, ": reserve_5h and limit_5h are mutually exclusive")
			case cred.DefaultOptions.ReserveWeekly > 0 && cred.DefaultOptions.LimitWeekly > 0:
				return E.New("credential ", cred.Tag, ": reserve_weekly and limit_weekly are mutually exclusive")
			}
		case "external":
			if cred.ExternalOptions.Token == "" {
				return E.New("credential ", cred.Tag, ": external credential requires token")
			}
			if cred.ExternalOptions.Reverse && cred.ExternalOptions.URL == "" {
				return E.New("credential ", cred.Tag, ": reverse external credential requires url")
			}
		case "balancer":
			switch cred.BalancerOptions.Strategy {
			case "", C.BalancerStrategyLeastUsed, C.BalancerStrategyRoundRobin, C.BalancerStrategyRandom, C.BalancerStrategyFallback:
			default:
				return E.New("credential ", cred.Tag, ": unknown balancer strategy: ", cred.BalancerOptions.Strategy)
			}
			if cred.BalancerOptions.RebalanceThreshold < 0 {
				return E.New("credential ", cred.Tag, ": rebalance_threshold must not be negative")
			}
		}
	}

	for _, user := range options.Users {
		switch {
		case user.Credential == "":
			return E.New("user ", user.Name, " must specify credential in multi-credential mode")
		case !knownTags[user.Credential]:
			return E.New("user ", user.Name, " references unknown credential: ", user.Credential)
		}
		if user.ExternalCredential == "" {
			continue
		}
		if !knownTags[user.ExternalCredential] {
			return E.New("user ", user.Name, " references unknown external_credential: ", user.ExternalCredential)
		}
		if credentialTypes[user.ExternalCredential] != "external" {
			return E.New("user ", user.Name, ": external_credential must reference an external type credential")
		}
	}
	return nil
}
// credentialForUser returns the provider serving username. A non-nil
// legacyProvider always wins; otherwise the user's configured credential tag
// is looked up in providers.
func credentialForUser(
	userConfigMap map[string]*option.CCMUser,
	providers map[string]credentialProvider,
	legacyProvider credentialProvider,
	username string,
) (credentialProvider, error) {
	if legacyProvider != nil {
		return legacyProvider, nil
	}
	userConfig, mapped := userConfigMap[username]
	if !mapped {
		return nil, E.New("no credential mapping for user: ", username)
	}
	if provider, known := providers[userConfig.Credential]; known {
		return provider, nil
	}
	return nil, E.New("unknown credential: ", userConfig.Credential)
}
// noUserCredentialProvider picks the provider used when no user is involved:
// the legacy provider if present, otherwise the provider registered under the
// tag of the first configured credential, otherwise nil.
func noUserCredentialProvider(
	providers map[string]credentialProvider,
	legacyProvider credentialProvider,
	options option.CCMServiceOptions,
) credentialProvider {
	if legacyProvider != nil {
		return legacyProvider
	}
	if len(options.Credentials) == 0 {
		return nil
	}
	return providers[options.Credentials[0].Tag]
}

View File

@@ -0,0 +1,726 @@
package ccm
import (
"bytes"
"context"
stdTLS "crypto/tls"
"encoding/json"
"io"
"math"
"net"
"net/http"
"strconv"
"sync"
"time"
"github.com/sagernet/fswatch"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/ntp"
)
// defaultCredential is a credential backed by locally stored OAuth
// credentials that are loaded, refreshed and persisted by this type.
type defaultCredential struct {
tag string
// serviceContext is the context handed to newDefaultCredential; it is used
// for token refresh requests.
serviceContext context.Context
// credentialPath is the path from the options; credentialFilePath is the
// path resolved from it in start.
credentialPath string
credentialFilePath string
// credentials is guarded by access.
credentials *oauthCredentials
access sync.RWMutex
// state (and the interrupted flag below) is guarded by stateAccess.
state credentialState
stateAccess sync.RWMutex
// pollAccess ensures only one pollUsage runs at a time (TryLock).
pollAccess sync.Mutex
// NOTE(review): reloadAccess and watcherAccess presumably serialize
// credential reloads and watcher setup; their use is outside this view.
reloadAccess sync.Mutex
watcherAccess sync.Mutex
// cap5h and capWeekly are the utilization thresholds at which the
// credential stops being usable.
cap5h float64
capWeekly float64
// usageTracker is nil unless a usages path was configured.
usageTracker *AggregatedUsage
// forwardHTTPClient dials through the credential's detour; it is used for
// token refresh and its transport is reused for usage polling.
forwardHTTPClient *http.Client
logger log.ContextLogger
watcher *fswatch.Watcher
watcherRetryAt time.Time
// Connection interruption
// onBecameUnusable, when set, is invoked after connections were interrupted.
onBecameUnusable func()
// interrupted remembers that the current unusable period already
// triggered an interruption.
interrupted bool
// requestContext/cancelRequests are replaced on every interruption and
// guarded by requestAccess.
requestContext context.Context
cancelRequests context.CancelFunc
requestAccess sync.Mutex
}
// newDefaultCredential builds a defaultCredential for tag.
//
// The credential gets an HTTP client that dials through options.Detour and
// uses the root pool and time source found in ctx. Utilization caps come from
// limit_5h/limit_weekly when positive, otherwise from 100 minus the matching
// reserve (a zero reserve counts as 1). A usage tracker is attached only when
// a usages path is configured. No credentials are loaded here.
func newDefaultCredential(ctx context.Context, tag string, options option.CCMDefaultCredentialOptions, logger log.ContextLogger) (*defaultCredential, error) {
	credentialDialer, err := dialer.NewWithOptions(dialer.Options{
		Context:        ctx,
		Options:        option.DialerOptions{Detour: options.Detour},
		RemoteIsDomain: true,
	})
	if err != nil {
		return nil, E.Cause(err, "create dialer for credential ", tag)
	}

	transport := &http.Transport{
		ForceAttemptHTTP2: true,
		TLSClientConfig: &stdTLS.Config{
			RootCAs: adapter.RootPoolFromContext(ctx),
			Time:    ntp.TimeFuncFromContext(ctx),
		},
		DialContext: func(dialContext context.Context, network, address string) (net.Conn, error) {
			return credentialDialer.DialContext(dialContext, network, M.ParseSocksaddr(address))
		},
	}

	// An unset reserve defaults to 1 percent.
	reserve5h, reserveWeekly := options.Reserve5h, options.ReserveWeekly
	if reserve5h == 0 {
		reserve5h = 1
	}
	if reserveWeekly == 0 {
		reserveWeekly = 1
	}

	// An explicit limit overrides the reserve-derived cap.
	cap5h := float64(100 - reserve5h)
	if options.Limit5h > 0 {
		cap5h = float64(options.Limit5h)
	}
	capWeekly := float64(100 - reserveWeekly)
	if options.LimitWeekly > 0 {
		capWeekly = float64(options.LimitWeekly)
	}

	requestContext, cancelRequests := context.WithCancel(context.Background())
	created := &defaultCredential{
		tag:               tag,
		serviceContext:    ctx,
		credentialPath:    options.CredentialPath,
		cap5h:             cap5h,
		capWeekly:         capWeekly,
		forwardHTTPClient: &http.Client{Transport: transport},
		logger:            logger,
		requestContext:    requestContext,
		cancelRequests:    cancelRequests,
	}
	if options.UsagesPath == "" {
		return created, nil
	}
	created.usageTracker = &AggregatedUsage{
		LastUpdated:  time.Now(),
		Combinations: make([]CostCombination, 0),
		filePath:     options.UsagesPath,
		logger:       logger,
	}
	return created, nil
}
// start resolves the credential file path, then tries to set up the file
// watcher, load the credentials and load usage statistics. Only a failure to
// resolve the path is returned; the other failures are logged.
func (c *defaultCredential) start() error {
	resolvedPath, err := resolveCredentialFilePath(c.credentialPath)
	if err != nil {
		return E.Cause(err, "resolve credential path for ", c.tag)
	}
	c.credentialFilePath = resolvedPath

	if watcherErr := c.ensureCredentialWatcher(); watcherErr != nil {
		c.logger.Debug("start credential watcher for ", c.tag, ": ", watcherErr)
	}
	if loadErr := c.reloadCredentials(true); loadErr != nil {
		c.logger.Warn("initial credential load for ", c.tag, ": ", loadErr)
	}
	if c.usageTracker == nil {
		return nil
	}
	if usageErr := c.usageTracker.Load(); usageErr != nil {
		c.logger.Warn("load usage statistics for ", c.tag, ": ", usageErr)
	}
	return nil
}
// getAccessToken returns a usable access token, refreshing it when needed.
//
// It first tries the cached credentials, then reloads them from storage, and
// only then refreshes the token. The refresh is refused when the credential
// store is not writable. If the stored credentials changed while the refresh
// was in flight, the stored version is adopted instead of the refreshed one.
func (c *defaultCredential) getAccessToken() (string, error) {
c.retryCredentialReloadIfNeeded()
// Fast path: cached token that is not close to expiry.
c.access.RLock()
if c.credentials != nil && !c.credentials.needsRefresh() {
token := c.credentials.AccessToken
c.access.RUnlock()
return token, nil
}
c.access.RUnlock()
// The stored credentials may already hold a fresher token; reload and re-check.
err := c.reloadCredentials(true)
if err == nil {
c.access.RLock()
if c.credentials != nil && !c.credentials.needsRefresh() {
token := c.credentials.AccessToken
c.access.RUnlock()
return token, nil
}
c.access.RUnlock()
}
// Slow path: hold the write lock for the whole refresh.
c.access.Lock()
defer c.access.Unlock()
if c.credentials == nil {
return "", c.unavailableError()
}
// Re-check under the write lock: the credentials may have changed since
// the read lock was released.
if !c.credentials.needsRefresh() {
return c.credentials.AccessToken, nil
}
// NOTE(review): the platform helpers receive credentialPath (the configured
// value), not the credentialFilePath resolved in start — confirm intended.
err = platformCanWriteCredentials(c.credentialPath)
if err != nil {
return "", E.Cause(err, "credential file not writable, refusing refresh to avoid invalidation")
}
// Snapshot used to detect a concurrent change of the stored credentials.
baseCredentials := cloneCredentials(c.credentials)
newCredentials, err := refreshToken(c.serviceContext, c.forwardHTTPClient, c.credentials)
if err != nil {
return "", err
}
latestCredentials, latestErr := platformReadCredentials(c.credentialPath)
if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) {
// The stored credentials changed during the refresh: adopt them and
// drop the refreshed token.
c.credentials = latestCredentials
c.stateAccess.Lock()
c.state.unavailable = false
c.state.lastCredentialLoadAttempt = time.Now()
c.state.lastCredentialLoadError = ""
c.state.accountType = latestCredentials.SubscriptionType
c.state.rateLimitTier = latestCredentials.RateLimitTier
c.checkTransitionLocked()
c.stateAccess.Unlock()
if !latestCredentials.needsRefresh() {
return latestCredentials.AccessToken, nil
}
return "", E.New("credential ", c.tag, " changed while refreshing")
}
c.credentials = newCredentials
c.stateAccess.Lock()
c.state.unavailable = false
c.state.lastCredentialLoadAttempt = time.Now()
c.state.lastCredentialLoadError = ""
c.state.accountType = newCredentials.SubscriptionType
c.state.rateLimitTier = newCredentials.RateLimitTier
c.checkTransitionLocked()
c.stateAccess.Unlock()
// A persistence failure is only logged; the refreshed token is still returned.
err = platformWriteCredentials(newCredentials, c.credentialPath)
if err != nil {
c.logger.Error("persist refreshed token for ", c.tag, ": ", err)
}
return newCredentials.AccessToken, nil
}
// updateStateFromHeaders merges the anthropic-ratelimit-unified-* response
// headers into the credential state and interrupts connections when the
// credential just became unusable.
//
// Reset times only ever move forward. A utilization value replaces the stored
// one only when it is not lower, unless the matching reset time advanced in
// the same update.
func (c *defaultCredential) updateStateFromHeaders(headers http.Header) {
c.stateAccess.Lock()
isFirstUpdate := c.state.lastUpdated.IsZero()
oldFiveHour := c.state.fiveHourUtilization
oldWeekly := c.state.weeklyUtilization
hadData := false
fiveHourResetChanged := false
if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset"); exists {
hadData = true
if value.After(c.state.fiveHourReset) {
fiveHourResetChanged = true
c.state.fiveHourReset = value
}
}
if utilization := headers.Get("anthropic-ratelimit-unified-5h-utilization"); utilization != "" {
value, err := strconv.ParseFloat(utilization, 64)
// An unparsable utilization header is ignored.
if err == nil {
hadData = true
// The header value is scaled by 100 and rounded up before it is stored.
newValue := math.Ceil(value * 100)
if newValue >= c.state.fiveHourUtilization || fiveHourResetChanged {
c.state.fiveHourUtilization = newValue
}
}
}
weeklyResetChanged := false
if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset"); exists {
hadData = true
if value.After(c.state.weeklyReset) {
weeklyResetChanged = true
c.state.weeklyReset = value
}
}
if utilization := headers.Get("anthropic-ratelimit-unified-7d-utilization"); utilization != "" {
value, err := strconv.ParseFloat(utilization, 64)
if err == nil {
hadData = true
newValue := math.Ceil(value * 100)
if newValue >= c.state.weeklyUtilization || weeklyResetChanged {
c.state.weeklyUtilization = newValue
}
}
}
// Any usable header counts as a successful update.
if hadData {
c.state.consecutivePollFailures = 0
c.state.lastUpdated = time.Now()
}
// Log only on the first update or when a utilization value changed.
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
resetSuffix := ""
if !c.state.weeklyReset.IsZero() {
resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset))
}
c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
}
shouldInterrupt := c.checkTransitionLocked()
c.stateAccess.Unlock()
// Interrupt only after stateAccess has been released.
if shouldInterrupt {
c.interruptConnections()
}
}
// markRateLimited records a hard rate limit lasting until resetAt and
// interrupts active connections if this made the credential unusable.
func (c *defaultCredential) markRateLimited(resetAt time.Time) {
	remaining := time.Until(resetAt)
	c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(remaining))

	c.stateAccess.Lock()
	c.state.rateLimitResetAt = resetAt
	c.state.hardRateLimited = true
	becameUnusable := c.checkTransitionLocked()
	c.stateAccess.Unlock()

	if !becameUnusable {
		return
	}
	c.interruptConnections()
}
// isUsable reports whether requests may currently be sent with this
// credential: it must be available, have no outstanding poll failures, not be
// inside a hard rate limit window, and be below both utilization caps.
func (c *defaultCredential) isUsable() bool {
c.retryCredentialReloadIfNeeded()
c.stateAccess.RLock()
if c.state.unavailable {
c.stateAccess.RUnlock()
return false
}
if c.state.consecutivePollFailures > 0 {
c.stateAccess.RUnlock()
return false
}
if c.state.hardRateLimited {
if time.Now().Before(c.state.rateLimitResetAt) {
c.stateAccess.RUnlock()
return false
}
// The rate limit window has passed: swap the read lock for the write
// lock to clear the flag. The condition is re-checked because the state
// may change between the two locks.
c.stateAccess.RUnlock()
c.stateAccess.Lock()
if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) {
c.state.hardRateLimited = false
}
usable := c.checkReservesLocked()
c.stateAccess.Unlock()
return usable
}
usable := c.checkReservesLocked()
c.stateAccess.RUnlock()
return usable
}
// checkReservesLocked reports whether both utilization values are still below
// their caps. The caller must hold stateAccess.
func (c *defaultCredential) checkReservesLocked() bool {
	fiveHourExhausted := c.state.fiveHourUtilization >= c.cap5h
	weeklyExhausted := c.state.weeklyUtilization >= c.capWeekly
	return !(fiveHourExhausted || weeklyExhausted)
}
// checkTransitionLocked keeps the interrupted flag in sync with the current
// usability and returns true exactly once per usable→unusable transition.
// The caller must hold the stateAccess write lock.
func (c *defaultCredential) checkTransitionLocked() bool {
	unusable := c.state.unavailable ||
		c.state.hardRateLimited ||
		!c.checkReservesLocked() ||
		c.state.consecutivePollFailures > 0

	// Only the first observation of an unusable period asks for an interrupt.
	becameUnusable := unusable && !c.interrupted
	c.interrupted = unusable
	return becameUnusable
}
// interruptConnections cancels the shared request context, installs a fresh
// one for future requests, and then invokes onBecameUnusable if it is set.
func (c *defaultCredential) interruptConnections() {
	c.logger.Warn("interrupting connections for ", c.tag)

	func() {
		c.requestAccess.Lock()
		defer c.requestAccess.Unlock()
		c.cancelRequests()
		c.requestContext, c.cancelRequests = context.WithCancel(context.Background())
	}()

	if c.onBecameUnusable == nil {
		return
	}
	c.onBecameUnusable()
}
// wrapRequestContext derives a cancellable context from parent that is also
// cancelled when the credential's current request context is cancelled. The
// stop function of that link is registered as the first release function.
func (c *defaultCredential) wrapRequestContext(parent context.Context) *credentialRequestContext {
	c.requestAccess.Lock()
	sharedContext := c.requestContext
	c.requestAccess.Unlock()

	requestContext, cancelRequest := context.WithCancel(parent)
	unlink := context.AfterFunc(sharedContext, cancelRequest)

	wrapped := &credentialRequestContext{
		Context:    requestContext,
		cancelFunc: cancelRequest,
	}
	wrapped.releaseFuncs = []func() bool{unlink}
	return wrapped
}
// weeklyUtilization returns the stored weekly utilization.
func (c *defaultCredential) weeklyUtilization() float64 {
	c.stateAccess.RLock()
	utilization := c.state.weeklyUtilization
	c.stateAccess.RUnlock()
	return utilization
}
// planWeight returns ccmPlanWeight evaluated for the stored account type and
// rate limit tier.
func (c *defaultCredential) planWeight() float64 {
	c.stateAccess.RLock()
	defer c.stateAccess.RUnlock()
	accountType, tier := c.state.accountType, c.state.rateLimitTier
	return ccmPlanWeight(accountType, tier)
}
// weeklyResetTime returns the stored reset time of the weekly window.
func (c *defaultCredential) weeklyResetTime() time.Time {
	c.stateAccess.RLock()
	resetAt := c.state.weeklyReset
	c.stateAccess.RUnlock()
	return resetAt
}
// isAvailable gives a pending credential reload a chance to run and then
// reports whether the credential is not marked unavailable.
func (c *defaultCredential) isAvailable() bool {
	c.retryCredentialReloadIfNeeded()

	c.stateAccess.RLock()
	unavailable := c.state.unavailable
	c.stateAccess.RUnlock()
	return !unavailable
}
// unavailableError returns nil while the credential is available. Otherwise
// it returns an error that includes the last credential load error, if any.
func (c *defaultCredential) unavailableError() error {
	c.stateAccess.RLock()
	defer c.stateAccess.RUnlock()

	if !c.state.unavailable {
		return nil
	}
	if reason := c.state.lastCredentialLoadError; reason != "" {
		return E.New("credential ", c.tag, " is unavailable: ", reason)
	}
	return E.New("credential ", c.tag, " is unavailable")
}
// lastUpdatedTime returns the time of the last state update or poll attempt.
func (c *defaultCredential) lastUpdatedTime() time.Time {
	c.stateAccess.RLock()
	updatedAt := c.state.lastUpdated
	c.stateAccess.RUnlock()
	return updatedAt
}
// markUsagePollAttempted stamps lastUpdated with the current time.
func (c *defaultCredential) markUsagePollAttempted() {
	c.stateAccess.Lock()
	c.state.lastUpdated = time.Now()
	c.stateAccess.Unlock()
}
// incrementPollFailures counts one more consecutive poll failure and
// interrupts active connections if this made the credential unusable.
func (c *defaultCredential) incrementPollFailures() {
	c.stateAccess.Lock()
	c.state.consecutivePollFailures += 1
	becameUnusable := c.checkTransitionLocked()
	c.stateAccess.Unlock()

	if !becameUnusable {
		return
	}
	c.interruptConnections()
}
// pollBackoff returns the delay before the next usage poll.
//
// Without recorded failures it returns baseInterval. Otherwise the delay
// starts at failedPollRetryInterval, doubles with every further consecutive
// failure and is capped at httpRetryMaxBackoff.
func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duration {
	c.stateAccess.RLock()
	failures := c.state.consecutivePollFailures
	c.stateAccess.RUnlock()
	if failures <= 0 {
		return baseInterval
	}
	// Double step by step and stop as soon as the cap is reached. Computing
	// interval * (1 << (failures-1)) directly overflows int64 after a few
	// dozen failures and yields a negative or zero delay that slips past the cap.
	backoff := failedPollRetryInterval
	for step := 1; step < failures && backoff < httpRetryMaxBackoff; step++ {
		backoff *= 2
	}
	return min(backoff, httpRetryMaxBackoff)
}
// isPollBackoffAtCap reports whether the failure backoff (the retry interval
// doubled per consecutive failure) has reached httpRetryMaxBackoff.
func (c *defaultCredential) isPollBackoffAtCap() bool {
	c.stateAccess.RLock()
	defer c.stateAccess.RUnlock()
	failures := c.state.consecutivePollFailures
	if failures <= 0 {
		return false
	}
	// Double step by step instead of shifting by the failure count: the
	// shifted product overflows int64 for large counts and would make this
	// report false again after a long outage.
	backoff := failedPollRetryInterval
	for step := 1; step < failures && backoff < httpRetryMaxBackoff; step++ {
		backoff *= 2
	}
	return backoff >= httpRetryMaxBackoff
}
// earliestReset returns the zero time for an unavailable credential and the
// rate limit reset time for a hard-rate-limited one. Otherwise it returns the
// earlier of the 5h and weekly reset times, ignoring whichever is unset.
func (c *defaultCredential) earliestReset() time.Time {
	c.stateAccess.RLock()
	defer c.stateAccess.RUnlock()

	switch {
	case c.state.unavailable:
		return time.Time{}
	case c.state.hardRateLimited:
		return c.state.rateLimitResetAt
	}

	fiveHour, weekly := c.state.fiveHourReset, c.state.weeklyReset
	if weekly.IsZero() {
		return fiveHour
	}
	if fiveHour.IsZero() || weekly.Before(fiveHour) {
		return weekly
	}
	return fiveHour
}
// pollUsage fetches the current utilization from the /api/oauth/usage
// endpoint and stores it in the credential state.
//
// Only one poll runs at a time; a concurrent call returns immediately. Every
// attempt updates lastUpdated on return. Token, transport, status and decode
// failures increment the consecutive failure counter. A profile fetch is
// triggered afterwards when no rate limit tier is known yet.
func (c *defaultCredential) pollUsage(ctx context.Context) {
if !c.pollAccess.TryLock() {
return
}
defer c.pollAccess.Unlock()
// Runs before the unlock above (deferred calls run in reverse order).
defer c.markUsagePollAttempted()
c.retryCredentialReloadIfNeeded()
if !c.isAvailable() {
return
}
accessToken, err := c.getAccessToken()
if err != nil {
// Once the backoff has reached its cap, repeated failures are no longer logged.
if !c.isPollBackoffAtCap() {
c.logger.Error("poll usage for ", c.tag, ": get token: ", err)
}
c.incrementPollFailures()
return
}
// Reuse the credential's transport but bound each request to 5 seconds.
httpClient := &http.Client{
Transport: c.forwardHTTPClient.Transport,
Timeout: 5 * time.Second,
}
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/usage", nil)
if err != nil {
return nil, err
}
request.Header.Set("Authorization", "Bearer "+accessToken)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("User-Agent", ccmUserAgentValue)
request.Header.Set("anthropic-beta", anthropicBetaOAuthValue)
return request, nil
})
if err != nil {
if !c.isPollBackoffAtCap() {
c.logger.Error("poll usage for ", c.tag, ": ", err)
}
c.incrementPollFailures()
return
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
if response.StatusCode == http.StatusTooManyRequests {
c.logger.Warn("poll usage for ", c.tag, ": rate limited")
}
// The body is only used for the debug log; a read error is ignored.
body, _ := io.ReadAll(response.Body)
c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
c.incrementPollFailures()
return
}
var usageResponse struct {
FiveHour struct {
Utilization float64 `json:"utilization"`
ResetsAt time.Time `json:"resets_at"`
} `json:"five_hour"`
SevenDay struct {
Utilization float64 `json:"utilization"`
ResetsAt time.Time `json:"resets_at"`
} `json:"seven_day"`
}
err = json.NewDecoder(response.Body).Decode(&usageResponse)
if err != nil {
c.logger.Debug("poll usage for ", c.tag, ": decode: ", err)
c.incrementPollFailures()
return
}
c.stateAccess.Lock()
isFirstUpdate := c.state.lastUpdated.IsZero()
oldFiveHour := c.state.fiveHourUtilization
oldWeekly := c.state.weeklyUtilization
c.state.consecutivePollFailures = 0
// Polled utilization is stored as received and always overwrites the
// previous value; reset times are only replaced when present.
c.state.fiveHourUtilization = usageResponse.FiveHour.Utilization
if !usageResponse.FiveHour.ResetsAt.IsZero() {
c.state.fiveHourReset = usageResponse.FiveHour.ResetsAt
}
c.state.weeklyUtilization = usageResponse.SevenDay.Utilization
if !usageResponse.SevenDay.ResetsAt.IsZero() {
c.state.weeklyReset = usageResponse.SevenDay.ResetsAt
}
// Clear an expired hard rate limit.
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
c.state.hardRateLimited = false
}
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
resetSuffix := ""
if !c.state.weeklyReset.IsZero() {
resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset))
}
c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
}
needsProfileFetch := c.state.rateLimitTier == ""
shouldInterrupt := c.checkTransitionLocked()
c.stateAccess.Unlock()
// Both follow-up actions run after stateAccess has been released.
if shouldInterrupt {
c.interruptConnections()
}
if needsProfileFetch {
c.fetchProfile(ctx, httpClient, accessToken)
}
}
// fetchProfile queries the OAuth profile endpoint and records the account
// type and rate-limit tier reported for the organization.
//
// The account type is only stored when none is set yet; a non-empty
// rate-limit tier always replaces the stored one. Transport errors are logged
// at debug level; non-200 responses, undecodable bodies, and responses
// without an organization are ignored silently.
func (c *defaultCredential) fetchProfile(ctx context.Context, httpClient *http.Client, accessToken string) {
	response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
		request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/profile", nil)
		if err != nil {
			return nil, err
		}
		request.Header.Set("Authorization", "Bearer "+accessToken)
		request.Header.Set("Content-Type", "application/json")
		request.Header.Set("User-Agent", ccmUserAgentValue)
		return request, nil
	})
	if err != nil {
		c.logger.Debug("fetch profile for ", c.tag, ": ", err)
		return
	}
	defer response.Body.Close()
	if response.StatusCode != http.StatusOK {
		return
	}
	var profileResponse struct {
		Organization *struct {
			OrganizationType string `json:"organization_type"`
			RateLimitTier    string `json:"rate_limit_tier"`
		} `json:"organization"`
	}
	err = json.NewDecoder(response.Body).Decode(&profileResponse)
	if err != nil || profileResponse.Organization == nil {
		return
	}
	accountType := ""
	switch profileResponse.Organization.OrganizationType {
	case "claude_pro":
		accountType = "pro"
	case "claude_max":
		accountType = "max"
	case "claude_team":
		accountType = "team"
	case "claude_enterprise":
		accountType = "enterprise"
	}
	rateLimitTier := profileResponse.Organization.RateLimitTier
	c.stateAccess.Lock()
	if accountType != "" && c.state.accountType == "" {
		c.state.accountType = accountType
	}
	if rateLimitTier != "" {
		c.state.rateLimitTier = rateLimitTier
	}
	// Snapshot the effective account type while the lock is still held:
	// reading c.state after Unlock would race with concurrent writers.
	effectiveAccountType := c.state.accountType
	c.stateAccess.Unlock()
	c.logger.Info("fetched profile for ", c.tag, ": type=", effectiveAccountType, ", tier=", rateLimitTier, ", weight=", ccmPlanWeight(effectiveAccountType, rateLimitTier))
}
// close shuts down the credential watcher, if any, and then flushes the usage
// tracker, if any: the pending save is cancelled and a final save is made.
// Failures are logged, not returned.
func (c *defaultCredential) close() {
	if watcher := c.watcher; watcher != nil {
		if closeErr := watcher.Close(); closeErr != nil {
			c.logger.Error("close credential watcher for ", c.tag, ": ", closeErr)
		}
	}
	tracker := c.usageTracker
	if tracker == nil {
		return
	}
	tracker.cancelPendingSave()
	if saveErr := tracker.Save(); saveErr != nil {
		c.logger.Error("save usage statistics for ", c.tag, ": ", saveErr)
	}
}
// tagName returns the credential's tag.
func (c *defaultCredential) tagName() string {
return c.tag
}
// isExternal always reports false for a default credential.
func (c *defaultCredential) isExternal() bool {
return false
}
// fiveHourUtilization returns the last recorded five-hour utilization, read
// under the state read lock.
func (c *defaultCredential) fiveHourUtilization() float64 {
	c.stateAccess.RLock()
	utilization := c.state.fiveHourUtilization
	c.stateAccess.RUnlock()
	return utilization
}
// fiveHourCap returns the five-hour cap stored in cap5h.
func (c *defaultCredential) fiveHourCap() float64 {
return c.cap5h
}
// weeklyCap returns the weekly cap stored in capWeekly.
func (c *defaultCredential) weeklyCap() float64 {
return c.capWeekly
}
// usageTrackerOrNil returns the usage tracker, which may be nil.
func (c *defaultCredential) usageTrackerOrNil() *AggregatedUsage {
return c.usageTracker
}
// httpClient returns the client held in forwardHTTPClient.
func (c *defaultCredential) httpClient() *http.Client {
return c.forwardHTTPClient
}
// buildProxyRequest builds the upstream request for original, targeting
// claudeAPIBaseURL with the original request URI.
//
// The body is bodyBytes when non-nil, otherwise original.Body. Headers are
// copied from original except hop-by-hop headers, reverse-proxy headers and
// Authorization. Accept-Encoding is dropped when a usage tracker is present
// and serviceHeaders does not supply its own. The OAuth beta value is
// prepended to any existing anthropic-beta header, serviceHeaders then
// replace same-named headers, and finally the bearer token is set.
func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) {
	token, err := c.getAccessToken()
	if err != nil {
		return nil, E.Cause(err, "get access token for ", c.tag)
	}

	var payload io.Reader = original.Body
	if bodyBytes != nil {
		payload = bytes.NewReader(bodyBytes)
	}
	upstream, err := http.NewRequestWithContext(ctx, original.Method, claudeAPIBaseURL+original.URL.RequestURI(), payload)
	if err != nil {
		return nil, err
	}

	for name, values := range original.Header {
		if isHopByHopHeader(name) || isReverseProxyHeader(name) || name == "Authorization" {
			continue
		}
		upstream.Header[name] = values
	}

	if c.usageTracker != nil && len(serviceHeaders.Values("Accept-Encoding")) == 0 {
		upstream.Header.Del("Accept-Encoding")
	}

	betaValue := anthropicBetaOAuthValue
	if existing := upstream.Header.Get("anthropic-beta"); existing != "" {
		betaValue = anthropicBetaOAuthValue + "," + existing
	}
	upstream.Header.Set("anthropic-beta", betaValue)

	for name, values := range serviceHeaders {
		upstream.Header.Del(name)
		upstream.Header[name] = values
	}
	upstream.Header.Set("Authorization", "Bearer "+token)
	return upstream, nil
}

View File

@@ -29,16 +29,16 @@ import (
const reverseProxyBaseURL = "http://reverse-proxy"
type externalCredential struct {
tag string
baseURL string
token string
httpClient *http.Client
state credentialState
stateMutex sync.RWMutex
pollAccess sync.Mutex
pollInterval time.Duration
usageTracker *AggregatedUsage
logger log.ContextLogger
tag string
baseURL string
token string
forwardHTTPClient *http.Client
state credentialState
stateAccess sync.RWMutex
pollAccess sync.Mutex
pollInterval time.Duration
usageTracker *AggregatedUsage
logger log.ContextLogger
onBecameUnusable func()
interrupted bool
@@ -128,7 +128,7 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx
if options.URL == "" {
// Receiver mode: no URL, wait for reverse connection
cred.baseURL = reverseProxyBaseURL
cred.httpClient = &http.Client{
cred.forwardHTTPClient = &http.Client{
Transport: &http.Transport{
ForceAttemptHTTP2: false,
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
@@ -192,10 +192,10 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx
Time: ntp.TimeFuncFromContext(ctx),
}
}
cred.httpClient = &http.Client{Transport: transport}
cred.forwardHTTPClient = &http.Client{Transport: transport}
} else {
// Normal mode: standard HTTP client for proxying
cred.httpClient = &http.Client{Transport: transport}
cred.forwardHTTPClient = &http.Client{Transport: transport}
cred.reverseHttpClient = &http.Client{
Transport: &http.Transport{
ForceAttemptHTTP2: false,
@@ -248,40 +248,40 @@ func (c *externalCredential) isUsable() bool {
if !c.isAvailable() {
return false
}
c.stateMutex.RLock()
c.stateAccess.RLock()
if c.state.consecutivePollFailures > 0 {
c.stateMutex.RUnlock()
c.stateAccess.RUnlock()
return false
}
if c.state.hardRateLimited {
if time.Now().Before(c.state.rateLimitResetAt) {
c.stateMutex.RUnlock()
c.stateAccess.RUnlock()
return false
}
c.stateMutex.RUnlock()
c.stateMutex.Lock()
c.stateAccess.RUnlock()
c.stateAccess.Lock()
if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) {
c.state.hardRateLimited = false
}
// No reserve for external: only 100% is unusable
usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100
c.stateMutex.Unlock()
c.stateAccess.Unlock()
return usable
}
usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100
c.stateMutex.RUnlock()
c.stateAccess.RUnlock()
return usable
}
func (c *externalCredential) fiveHourUtilization() float64 {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.fiveHourUtilization
}
func (c *externalCredential) weeklyUtilization() float64 {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.weeklyUtilization
}
@@ -294,8 +294,8 @@ func (c *externalCredential) weeklyCap() float64 {
}
func (c *externalCredential) planWeight() float64 {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
if c.state.remotePlanWeight > 0 {
return c.state.remotePlanWeight
}
@@ -303,26 +303,26 @@ func (c *externalCredential) planWeight() float64 {
}
func (c *externalCredential) weeklyResetTime() time.Time {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.weeklyReset
}
func (c *externalCredential) markRateLimited(resetAt time.Time) {
c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt)))
c.stateMutex.Lock()
c.stateAccess.Lock()
c.state.hardRateLimited = true
c.state.rateLimitResetAt = resetAt
shouldInterrupt := c.checkTransitionLocked()
c.stateMutex.Unlock()
c.stateAccess.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
}
func (c *externalCredential) earliestReset() time.Time {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
if c.state.hardRateLimited {
return c.state.rateLimitResetAt
}
@@ -408,7 +408,7 @@ func (c *externalCredential) openReverseConnection(ctx context.Context) (net.Con
}
func (c *externalCredential) updateStateFromHeaders(headers http.Header) {
c.stateMutex.Lock()
c.stateAccess.Lock()
isFirstUpdate := c.state.lastUpdated.IsZero()
oldFiveHour := c.state.fiveHourUtilization
oldWeekly := c.state.weeklyUtilization
@@ -455,7 +455,7 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) {
c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
}
shouldInterrupt := c.checkTransitionLocked()
c.stateMutex.Unlock()
c.stateAccess.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
@@ -530,9 +530,9 @@ func (c *externalCredential) doPollUsageRequest(ctx context.Context) (*http.Resp
}
}
// Forward transport with retries
if c.httpClient != nil {
if c.forwardHTTPClient != nil {
forwardClient := &http.Client{
Transport: c.httpClient.Transport,
Transport: c.forwardHTTPClient.Transport,
Timeout: 5 * time.Second,
}
return doHTTPWithRetry(ctx, forwardClient, buildRequest(c.baseURL))
@@ -563,10 +563,10 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
// 404 means the remote does not have a status endpoint yet;
// usage will be updated passively from response headers.
if response.StatusCode == http.StatusNotFound {
c.stateMutex.Lock()
c.stateAccess.Lock()
c.state.consecutivePollFailures = 0
c.checkTransitionLocked()
c.stateMutex.Unlock()
c.stateAccess.Unlock()
} else {
c.incrementPollFailures()
}
@@ -585,7 +585,7 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
return
}
c.stateMutex.Lock()
c.stateAccess.Lock()
isFirstUpdate := c.state.lastUpdated.IsZero()
oldFiveHour := c.state.fiveHourUtilization
oldWeekly := c.state.weeklyUtilization
@@ -606,28 +606,28 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
}
shouldInterrupt := c.checkTransitionLocked()
c.stateMutex.Unlock()
c.stateAccess.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
}
func (c *externalCredential) lastUpdatedTime() time.Time {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.lastUpdated
}
func (c *externalCredential) markUsagePollAttempted() {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()
c.stateAccess.Lock()
defer c.stateAccess.Unlock()
c.state.lastUpdated = time.Now()
}
func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Duration {
c.stateMutex.RLock()
c.stateAccess.RLock()
failures := c.state.consecutivePollFailures
c.stateMutex.RUnlock()
c.stateAccess.RUnlock()
if failures <= 0 {
return baseInterval
}
@@ -639,17 +639,17 @@ func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Durati
}
func (c *externalCredential) isPollBackoffAtCap() bool {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
failures := c.state.consecutivePollFailures
return failures > 0 && failedPollRetryInterval*time.Duration(1<<(failures-1)) >= httpRetryMaxBackoff
}
func (c *externalCredential) incrementPollFailures() {
c.stateMutex.Lock()
c.stateAccess.Lock()
c.state.consecutivePollFailures++
shouldInterrupt := c.checkTransitionLocked()
c.stateMutex.Unlock()
c.stateAccess.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
@@ -659,14 +659,14 @@ func (c *externalCredential) usageTrackerOrNil() *AggregatedUsage {
return c.usageTracker
}
func (c *externalCredential) httpTransport() *http.Client {
func (c *externalCredential) httpClient() *http.Client {
if c.reverseHttpClient != nil {
session := c.getReverseSession()
if session != nil && !session.IsClosed() {
return c.reverseHttpClient
}
}
return c.httpClient
return c.forwardHTTPClient
}
func (c *externalCredential) close() {

View File

@@ -62,10 +62,10 @@ func (c *defaultCredential) ensureCredentialWatcher() error {
}
func (c *defaultCredential) retryCredentialReloadIfNeeded() {
c.stateMutex.RLock()
c.stateAccess.RLock()
unavailable := c.state.unavailable
lastAttempt := c.state.lastCredentialLoadAttempt
c.stateMutex.RUnlock()
c.stateAccess.RUnlock()
if !unavailable {
return
}
@@ -84,10 +84,10 @@ func (c *defaultCredential) reloadCredentials(force bool) error {
c.reloadAccess.Lock()
defer c.reloadAccess.Unlock()
c.stateMutex.RLock()
c.stateAccess.RLock()
unavailable := c.state.unavailable
lastAttempt := c.state.lastCredentialLoadAttempt
c.stateMutex.RUnlock()
c.stateAccess.RUnlock()
if !force {
if !unavailable {
return nil
@@ -97,43 +97,43 @@ func (c *defaultCredential) reloadCredentials(force bool) error {
}
}
c.stateMutex.Lock()
c.stateAccess.Lock()
c.state.lastCredentialLoadAttempt = time.Now()
c.stateMutex.Unlock()
c.stateAccess.Unlock()
credentials, err := platformReadCredentials(c.credentialPath)
if err != nil {
return c.markCredentialsUnavailable(E.Cause(err, "read credentials"))
}
c.accessMutex.Lock()
c.access.Lock()
c.credentials = credentials
c.accessMutex.Unlock()
c.access.Unlock()
c.stateMutex.Lock()
c.stateAccess.Lock()
c.state.unavailable = false
c.state.lastCredentialLoadError = ""
c.state.accountType = credentials.SubscriptionType
c.state.rateLimitTier = credentials.RateLimitTier
c.checkTransitionLocked()
c.stateMutex.Unlock()
c.stateAccess.Unlock()
return nil
}
func (c *defaultCredential) markCredentialsUnavailable(err error) error {
c.accessMutex.Lock()
c.access.Lock()
hadCredentials := c.credentials != nil
c.credentials = nil
c.accessMutex.Unlock()
c.access.Unlock()
c.stateMutex.Lock()
c.stateAccess.Lock()
c.state.unavailable = true
c.state.lastCredentialLoadError = err.Error()
c.state.accountType = ""
c.state.rateLimitTier = ""
shouldInterrupt := c.checkTransitionLocked()
c.stateMutex.Unlock()
c.stateAccess.Unlock()
if shouldInterrupt && hadCredentials {
c.interruptConnections()

View File

@@ -0,0 +1,224 @@
package ccm
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"os"
"os/user"
"path/filepath"
"runtime"
"slices"
"sync"
"time"
"github.com/sagernet/sing-box/log"
E "github.com/sagernet/sing/common/exceptions"
)
// OAuth and API constants used by the credential code in this package.
const (
oauth2ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
oauth2TokenURL = "https://platform.claude.com/v1/oauth/token"
claudeAPIBaseURL = "https://api.anthropic.com"
// tokenRefreshBufferMs is subtracted from a token's expiry (milliseconds)
// when deciding whether it needs to be refreshed.
tokenRefreshBufferMs = 60000
anthropicBetaOAuthValue = "oauth-2025-04-20"
)
// ccmUserAgentFallback is the User-Agent used when the installed Claude Code
// version cannot be detected.
const ccmUserAgentFallback = "claude-code/2.1.72"
var (
// ccmUserAgentOnce guards the one-time initialization of ccmUserAgentValue.
ccmUserAgentOnce sync.Once
ccmUserAgentValue string
)
// initCCMUserAgent sets ccmUserAgentValue exactly once. When the Claude Code
// version is detected the value is "claude-code/<version>"; otherwise the
// error is logged and ccmUserAgentFallback is used.
func initCCMUserAgent(logger log.ContextLogger) {
	ccmUserAgentOnce.Do(func() {
		version, err := detectClaudeCodeVersion()
		if err == nil {
			logger.Debug("detected Claude Code version: ", version)
			ccmUserAgentValue = "claude-code/" + version
			return
		}
		logger.Error("detect Claude Code version: ", err)
		ccmUserAgentValue = ccmUserAgentFallback
	})
}
// detectClaudeCodeVersion resolves the ~/.local/bin/claude symlink
// (claude.exe on Windows) of the real user and returns the final path element
// of its target. The target's parent directory must be named "versions";
// relative targets are resolved against the symlink's directory.
func detectClaudeCodeVersion() (string, error) {
	userInfo, err := getRealUser()
	if err != nil {
		return "", E.Cause(err, "get user")
	}

	var binaryName string
	switch runtime.GOOS {
	case "windows":
		binaryName = "claude.exe"
	default:
		binaryName = "claude"
	}
	linkPath := filepath.Join(userInfo.HomeDir, ".local", "bin", binaryName)

	target, err := os.Readlink(linkPath)
	if err != nil {
		return "", E.Cause(err, "readlink ", linkPath)
	}
	if !filepath.IsAbs(target) {
		target = filepath.Join(filepath.Dir(linkPath), target)
	}
	if filepath.Base(filepath.Dir(target)) != "versions" {
		return "", E.New("unexpected symlink target: ", target)
	}
	return filepath.Base(target), nil
}
// getRealUser returns the user named by SUDO_USER when that variable is set
// and the lookup succeeds; in every other case it returns the current user.
func getRealUser() (*user.User, error) {
	sudoUser := os.Getenv("SUDO_USER")
	if sudoUser == "" {
		return user.Current()
	}
	if info, lookupErr := user.Lookup(sudoUser); lookupErr == nil {
		return info, nil
	}
	return user.Current()
}
// getDefaultCredentialsPath returns the location of .credentials.json:
// inside CLAUDE_CONFIG_DIR when that variable is non-empty, otherwise inside
// the .claude directory of the real user's home.
func getDefaultCredentialsPath() (string, error) {
	configDir := os.Getenv("CLAUDE_CONFIG_DIR")
	if configDir == "" {
		userInfo, err := getRealUser()
		if err != nil {
			return "", err
		}
		return filepath.Join(userInfo.HomeDir, ".claude", ".credentials.json"), nil
	}
	return filepath.Join(configDir, ".credentials.json"), nil
}
// readCredentialsFromFile reads the JSON file at path and returns the value
// of its "claudeAiOauth" field. It fails when the file cannot be read or
// parsed, or when that field is absent.
func readCredentialsFromFile(path string) (*oauthCredentials, error) {
	raw, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}
	var container struct {
		ClaudeAIAuth *oauthCredentials `json:"claudeAiOauth,omitempty"`
	}
	if err = json.Unmarshal(raw, &container); err != nil {
		return nil, err
	}
	if auth := container.ClaudeAIAuth; auth != nil {
		return auth, nil
	}
	return nil, E.New("claudeAiOauth field not found in credentials")
}
// checkCredentialFileWritable verifies that the existing file at path can be
// opened write-only. The file is neither created nor modified.
func checkCredentialFileWritable(path string) error {
	handle, openErr := os.OpenFile(path, os.O_WRONLY, 0)
	if openErr == nil {
		return handle.Close()
	}
	return openErr
}
// writeCredentialsToFile serializes oauthCredentials as indented JSON under
// the "claudeAiOauth" key and writes it to path with mode 0o600.
// NOTE(review): the written document contains only that key, so any other
// top-level content already in the file is replaced — confirm this is intended.
func writeCredentialsToFile(oauthCredentials *oauthCredentials, path string) error {
data, err := json.MarshalIndent(map[string]any{
"claudeAiOauth": oauthCredentials,
}, "", " ")
if err != nil {
return err
}
return os.WriteFile(path, data, 0o600)
}
// oauthCredentials is the JSON shape stored under "claudeAiOauth" in the
// credentials file.
type oauthCredentials struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
// ExpiresAt is compared against time.Now().UnixMilli(); zero disables the
// refresh check in needsRefresh.
ExpiresAt int64 `json:"expiresAt"`
Scopes []string `json:"scopes,omitempty"`
SubscriptionType string `json:"subscriptionType,omitempty"`
RateLimitTier string `json:"rateLimitTier,omitempty"`
IsMax bool `json:"isMax,omitempty"`
}
// needsRefresh reports whether the access token has expired or will expire
// within tokenRefreshBufferMs milliseconds. A zero ExpiresAt never needs a
// refresh.
func (c *oauthCredentials) needsRefresh() bool {
	if c.ExpiresAt == 0 {
		return false
	}
	refreshAt := c.ExpiresAt - tokenRefreshBufferMs
	return time.Now().UnixMilli() >= refreshAt
}
// refreshToken exchanges credentials.RefreshToken for a new access token at
// oauth2TokenURL and returns an updated copy of credentials; the input is not
// modified.
//
// The returned copy carries the new access token, the new refresh token when
// the server sent one (otherwise the old one is kept), and an ExpiresAt
// computed from the response's expires_in seconds. An empty refresh token, a
// transport failure, any non-200 status, or an undecodable body yields an
// error.
func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) {
	if credentials.RefreshToken == "" {
		return nil, E.New("refresh token is empty")
	}
	requestBody, err := json.Marshal(map[string]string{
		"grant_type":    "refresh_token",
		"refresh_token": credentials.RefreshToken,
		"client_id":     oauth2ClientID,
	})
	if err != nil {
		return nil, E.Cause(err, "marshal request")
	}
	response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
		// Bind the request to ctx so cancellation aborts the in-flight call,
		// matching how the other requests in this package are built.
		request, err := http.NewRequestWithContext(ctx, http.MethodPost, oauth2TokenURL, bytes.NewReader(requestBody))
		if err != nil {
			return nil, err
		}
		request.Header.Set("Content-Type", "application/json")
		request.Header.Set("User-Agent", ccmUserAgentValue)
		return request, nil
	})
	if err != nil {
		return nil, err
	}
	defer response.Body.Close()
	if response.StatusCode == http.StatusTooManyRequests {
		// The body is only used for the error message; a read error is ignored.
		body, _ := io.ReadAll(response.Body)
		return nil, E.New("refresh rate limited: ", response.Status, " ", string(body))
	}
	if response.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(response.Body)
		return nil, E.New("refresh failed: ", response.Status, " ", string(body))
	}
	var tokenResponse struct {
		AccessToken  string `json:"access_token"`
		RefreshToken string `json:"refresh_token"`
		ExpiresIn    int    `json:"expires_in"`
	}
	err = json.NewDecoder(response.Body).Decode(&tokenResponse)
	if err != nil {
		return nil, E.Cause(err, "decode response")
	}
	newCredentials := *credentials
	newCredentials.AccessToken = tokenResponse.AccessToken
	if tokenResponse.RefreshToken != "" {
		newCredentials.RefreshToken = tokenResponse.RefreshToken
	}
	newCredentials.ExpiresAt = time.Now().UnixMilli() + int64(tokenResponse.ExpiresIn)*1000
	return &newCredentials, nil
}
// cloneCredentials returns a copy of credentials whose Scopes slice does not
// share storage with the original. A nil input yields nil.
func cloneCredentials(credentials *oauthCredentials) *oauthCredentials {
	if credentials == nil {
		return nil
	}
	duplicate := new(oauthCredentials)
	*duplicate = *credentials
	var scopes []string
	duplicate.Scopes = append(scopes, credentials.Scopes...)
	return duplicate
}
// credentialsEqual reports whether left and right hold the same field values.
// Two nil pointers are equal; a nil and a non-nil pointer are not.
func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool {
	if left == nil || right == nil {
		return left == right
	}
	switch {
	case left.AccessToken != right.AccessToken,
		left.RefreshToken != right.RefreshToken,
		left.ExpiresAt != right.ExpiresAt,
		left.SubscriptionType != right.SubscriptionType,
		left.RateLimitTier != right.RateLimitTier,
		left.IsMax != right.IsMax:
		return false
	}
	return slices.Equal(left.Scopes, right.Scopes)
}

View File

@@ -0,0 +1,405 @@
package ccm
import (
"context"
"math/rand/v2"
"sync"
"sync/atomic"
"time"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
E "github.com/sagernet/sing/common/exceptions"
)
// credentialProvider selects credentials for requests and keeps their usage
// information fresh.
type credentialProvider interface {
// selectCredential returns the credential to use; the bool reports whether
// a new session entry was recorded for sessionID.
selectCredential(sessionID string, selection credentialSelection) (credential, bool, error)
// onRateLimited marks cred as rate limited until resetAt and returns a
// replacement credential, or nil when there is none.
onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential
linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool
// pollIfStale expires old sessions and polls credentials whose usage data
// is older than their poll backoff.
pollIfStale(ctx context.Context)
allCredentials() []credential
close()
}
// singleCredentialProvider serves every request from one credential.
type singleCredentialProvider struct {
cred credential
// sessionAccess guards sessions.
sessionAccess sync.RWMutex
// sessions maps a session ID to the time it was first seen; it is
// allocated lazily by selectCredential.
sessions map[string]time.Time
}
// selectCredential returns the provider's only credential, or an error when
// it is filtered out by selection, unavailable, or not usable. The bool is
// true only when a non-empty sessionID is seen for the first time.
func (p *singleCredentialProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) {
	switch {
	case !selection.allows(p.cred):
		return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out")
	case !p.cred.isAvailable():
		return nil, false, p.cred.unavailableError()
	case !p.cred.isUsable():
		return nil, false, E.New("credential ", p.cred.tagName(), " is rate-limited")
	}
	if sessionID == "" {
		return p.cred, false, nil
	}

	p.sessionAccess.Lock()
	defer p.sessionAccess.Unlock()
	if p.sessions == nil {
		p.sessions = make(map[string]time.Time)
	}
	if _, known := p.sessions[sessionID]; known {
		return p.cred, false, nil
	}
	p.sessions[sessionID] = time.Now()
	return p.cred, true, nil
}
// onRateLimited marks cred as rate limited until resetAt. It always returns
// nil because a single-credential provider has no alternative to offer.
func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ credentialSelection) credential {
cred.markRateLimited(resetAt)
return nil
}
// pollIfStale drops sessions older than sessionExpiry and polls the
// credential's usage when its last update is older than its poll backoff
// (based on defaultPollInterval).
func (p *singleCredentialProvider) pollIfStale(ctx context.Context) {
	reference := time.Now()
	p.sessionAccess.Lock()
	for sessionID, firstSeen := range p.sessions {
		if reference.Sub(firstSeen) > sessionExpiry {
			delete(p.sessions, sessionID)
		}
	}
	p.sessionAccess.Unlock()

	age := time.Since(p.cred.lastUpdatedTime())
	if age > p.cred.pollBackoff(defaultPollInterval) {
		p.cred.pollUsage(ctx)
	}
}
// allCredentials returns a one-element slice holding the provider's credential.
func (p *singleCredentialProvider) allCredentials() []credential {
return []credential{p.cred}
}
// linkProviderInterrupt ignores its arguments and never invokes the callback;
// the returned function always reports false.
func (p *singleCredentialProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool {
return func() bool {
return false
}
}
// close is a no-op for the single-credential provider.
func (p *singleCredentialProvider) close() {}
// sessionEntry records which credential a session was assigned to, under
// which selection scope, and when the assignment was made.
type sessionEntry struct {
tag string
selectionScope credentialSelectionScope
createdAt time.Time
}
// credentialInterruptKey identifies an interrupt context by credential tag
// and selection scope.
type credentialInterruptKey struct {
tag string
selectionScope credentialSelectionScope
}
// credentialInterruptEntry pairs a cancellable context with its cancel
// function; cancelling it fires the callbacks linked to that context.
type credentialInterruptEntry struct {
context context.Context
cancel context.CancelFunc
}
// balancerProvider distributes sessions across several credentials according
// to a strategy.
type balancerProvider struct {
credentials []credential
strategy string
// roundRobinIndex is advanced once per round-robin pick.
roundRobinIndex atomic.Uint64
pollInterval time.Duration
// rebalanceThreshold, when positive, enables moving sessions away from a
// credential whose weekly utilization exceeds a better candidate's by
// more than the threshold divided by the credential's plan weight.
rebalanceThreshold float64
// sessionAccess guards sessions.
sessionAccess sync.RWMutex
sessions map[string]sessionEntry
// interruptAccess guards credentialInterrupts.
interruptAccess sync.Mutex
credentialInterrupts map[credentialInterruptKey]credentialInterruptEntry
logger log.ContextLogger
}
// newBalancerProvider builds a balancerProvider with empty session and
// interrupt maps. A non-positive pollInterval is replaced by
// defaultPollInterval.
func newBalancerProvider(credentials []credential, strategy string, pollInterval time.Duration, rebalanceThreshold float64, logger log.ContextLogger) *balancerProvider {
	provider := &balancerProvider{
		credentials:          credentials,
		strategy:             strategy,
		pollInterval:         defaultPollInterval,
		rebalanceThreshold:   rebalanceThreshold,
		sessions:             make(map[string]sessionEntry),
		credentialInterrupts: make(map[credentialInterruptKey]credentialInterruptEntry),
		logger:               logger,
	}
	if pollInterval > 0 {
		provider.pollInterval = pollInterval
	}
	return provider
}
// selectCredential picks the credential for a request. With the fallback
// strategy sessions are not tracked and a credential is picked on every call.
// Otherwise a session keeps its credential while that credential stays
// allowed and usable and the selection scope is unchanged; in every other
// case the session entry is dropped and a fresh pick is made. The bool is
// true when a session entry was (re)created for sessionID.
func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) {
if p.strategy == C.BalancerStrategyFallback {
best := p.pickCredential(selection.filter)
if best == nil {
return nil, false, allCredentialsUnavailableError(p.credentials)
}
return best, false, nil
}
selectionScope := selection.scopeOrDefault()
if sessionID != "" {
p.sessionAccess.RLock()
entry, exists := p.sessions[sessionID]
p.sessionAccess.RUnlock()
if exists {
if entry.selectionScope == selectionScope {
for _, cred := range p.credentials {
if cred.tagName() == entry.tag && selection.allows(cred) && cred.isUsable() {
// Rebalancing applies only to the default and least-used strategies.
if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == C.BalancerStrategyLeastUsed) {
better := p.pickLeastUsed(selection.filter)
if better != nil && better.tagName() != cred.tagName() {
effectiveThreshold := p.rebalanceThreshold / cred.planWeight()
delta := cred.weeklyUtilization() - better.weeklyUtilization()
if delta > effectiveThreshold {
p.logger.Info("rebalancing away from ", cred.tagName(),
": utilization delta ", delta, "% exceeds effective threshold ",
effectiveThreshold, "% (weight ", cred.planWeight(), ")")
p.rebalanceCredential(cred.tagName(), selectionScope)
// Leave the loop so a new credential is picked below.
break
}
}
}
return cred, false, nil
}
}
}
// Reached when the scope changed, the remembered credential is no longer
// eligible, or a rebalance was triggered: forget the session.
p.sessionAccess.Lock()
delete(p.sessions, sessionID)
p.sessionAccess.Unlock()
}
}
best := p.pickCredential(selection.filter)
if best == nil {
return nil, false, allCredentialsUnavailableError(p.credentials)
}
isNew := sessionID != ""
if isNew {
p.sessionAccess.Lock()
p.sessions[sessionID] = sessionEntry{
tag: best.tagName(),
selectionScope: selectionScope,
createdAt: time.Now(),
}
p.sessionAccess.Unlock()
}
return best, isNew, nil
}
// rebalanceCredential cancels the interrupt context registered for the given
// tag and scope (if any), installs a fresh one, and then removes every
// session bound to that tag and scope.
func (p *balancerProvider) rebalanceCredential(tag string, selectionScope credentialSelectionScope) {
	interruptKey := credentialInterruptKey{tag: tag, selectionScope: selectionScope}

	p.interruptAccess.Lock()
	if previous, found := p.credentialInterrupts[interruptKey]; found {
		previous.cancel()
	}
	freshContext, freshCancel := context.WithCancel(context.Background())
	p.credentialInterrupts[interruptKey] = credentialInterruptEntry{context: freshContext, cancel: freshCancel}
	p.interruptAccess.Unlock()

	p.sessionAccess.Lock()
	for sessionID, session := range p.sessions {
		if session.tag != tag || session.selectionScope != selectionScope {
			continue
		}
		delete(p.sessions, sessionID)
	}
	p.sessionAccess.Unlock()
}
// linkProviderInterrupt arranges for onInterrupt to run when the interrupt
// context for cred's tag and the selection scope is cancelled, creating that
// context on first use. The returned function is the stop function of
// context.AfterFunc. With the fallback strategy nothing is registered and the
// returned function always reports false.
func (p *balancerProvider) linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool {
	if p.strategy == C.BalancerStrategyFallback {
		return func() bool { return false }
	}
	interruptKey := credentialInterruptKey{
		tag:            cred.tagName(),
		selectionScope: selection.scopeOrDefault(),
	}
	interruptContext := func() context.Context {
		p.interruptAccess.Lock()
		defer p.interruptAccess.Unlock()
		if existing, found := p.credentialInterrupts[interruptKey]; found {
			return existing.context
		}
		created, cancel := context.WithCancel(context.Background())
		p.credentialInterrupts[interruptKey] = credentialInterruptEntry{context: created, cancel: cancel}
		return created
	}()
	return context.AfterFunc(interruptContext, onInterrupt)
}
// onRateLimited marks cred as rate limited until resetAt and returns a
// replacement pick (nil when none is usable). Outside the fallback strategy
// the session's entry is removed first and, when a replacement exists,
// re-created pointing at it.
func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential {
	cred.markRateLimited(resetAt)
	if p.strategy == C.BalancerStrategyFallback {
		return p.pickCredential(selection.filter)
	}

	tracked := sessionID != ""
	if tracked {
		p.sessionAccess.Lock()
		delete(p.sessions, sessionID)
		p.sessionAccess.Unlock()
	}

	replacement := p.pickCredential(selection.filter)
	if replacement == nil || !tracked {
		return replacement
	}
	p.sessionAccess.Lock()
	p.sessions[sessionID] = sessionEntry{
		tag:            replacement.tagName(),
		selectionScope: selection.scopeOrDefault(),
		createdAt:      time.Now(),
	}
	p.sessionAccess.Unlock()
	return replacement
}
// pickCredential dispatches to the picker matching the provider's strategy;
// any strategy other than round-robin, random or fallback uses least-used.
func (p *balancerProvider) pickCredential(filter func(credential) bool) credential {
	if p.strategy == C.BalancerStrategyRoundRobin {
		return p.pickRoundRobin(filter)
	}
	if p.strategy == C.BalancerStrategyRandom {
		return p.pickRandom(filter)
	}
	if p.strategy == C.BalancerStrategyFallback {
		return p.pickFallback(filter)
	}
	return p.pickLeastUsed(filter)
}
// pickFallback returns the first credential, in configured order, that passes
// filter (a nil filter accepts everything) and is usable, or nil if none is.
func (p *balancerProvider) pickFallback(filter func(credential) bool) credential {
	for _, candidate := range p.credentials {
		if (filter == nil || filter(candidate)) && candidate.isUsable() {
			return candidate
		}
	}
	return nil
}
// weeklyWindowHours is the number of hours in a seven-day window.
const weeklyWindowHours = 7 * 24
// pickLeastUsed returns the usable credential (passing filter, when one is
// given) with the highest score, or nil when no candidate scores above -1.
//
// The score is the remaining weekly headroom (cap minus utilization) times
// the plan weight. When a weekly reset time is known, the score is further
// multiplied by weeklyWindowHours divided by the hours until that reset, with
// the time until reset clamped to at least one hour.
func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credential {
	var winner credential
	winnerScore := float64(-1)
	reference := time.Now()
	for _, candidate := range p.credentials {
		if filter != nil && !filter(candidate) {
			continue
		}
		if !candidate.isUsable() {
			continue
		}
		headroom := candidate.weeklyCap() - candidate.weeklyUtilization()
		score := headroom * candidate.planWeight()
		if resetAt := candidate.weeklyResetTime(); !resetAt.IsZero() {
			untilReset := max(resetAt.Sub(reference), time.Hour)
			score *= weeklyWindowHours / untilReset.Hours()
		}
		if score > winnerScore {
			winner, winnerScore = candidate, score
		}
	}
	return winner
}
// pickRoundRobin advances the shared round-robin counter and returns the
// first usable credential that passes filter (a nil filter accepts
// everything), scanning from the counter's position and wrapping around once.
// It returns nil when there are no credentials or none qualifies.
func (p *balancerProvider) pickRoundRobin(filter func(credential) bool) credential {
	ticket := p.roundRobinIndex.Add(1) - 1
	count := uint64(len(p.credentials))
	if count == 0 {
		return nil
	}
	// Reduce the counter modulo count while it is still unsigned. Converting
	// the raw uint64 to int first yields a negative value once the counter
	// exceeds the int range (2^31 picks on 32-bit targets), and a negative
	// start would produce a negative slice index.
	start := ticket % count
	for offset := range count {
		candidate := p.credentials[(start+offset)%count]
		if filter != nil && !filter(candidate) {
			continue
		}
		if candidate.isUsable() {
			return candidate
		}
	}
	return nil
}
// pickRandom returns a uniformly chosen credential among those that pass
// filter (a nil filter accepts everything) and are usable, or nil when there
// is none.
func (p *balancerProvider) pickRandom(filter func(credential) bool) credential {
	var eligible []credential
	for _, candidate := range p.credentials {
		if (filter == nil || filter(candidate)) && candidate.isUsable() {
			eligible = append(eligible, candidate)
		}
	}
	if total := len(eligible); total > 0 {
		return eligible[rand.IntN(total)]
	}
	return nil
}
// pollIfStale drops sessions older than sessionExpiry, then polls every
// credential whose last update is older than its poll backoff (based on the
// provider's poll interval).
func (p *balancerProvider) pollIfStale(ctx context.Context) {
	reference := time.Now()
	p.sessionAccess.Lock()
	for sessionID, session := range p.sessions {
		if reference.Sub(session.createdAt) > sessionExpiry {
			delete(p.sessions, sessionID)
		}
	}
	p.sessionAccess.Unlock()

	for _, candidate := range p.credentials {
		age := time.Since(candidate.lastUpdatedTime())
		if age <= candidate.pollBackoff(p.pollInterval) {
			continue
		}
		candidate.pollUsage(ctx)
	}
}
// allCredentials returns the provider's credential slice itself, not a copy.
func (p *balancerProvider) allCredentials() []credential {
return p.credentials
}
// close is a no-op for the balancer provider.
func (p *balancerProvider) close() {}
// ccmPlanWeight maps an account type and rate-limit tier to a relative
// weight: "max" accounts weigh 10 on the "default_claude_max_20x" tier and 5
// on any other tier; "team" accounts weigh 5 on the "default_claude_max_5x"
// tier; everything else weighs 1.
func ccmPlanWeight(accountType string, rateLimitTier string) float64 {
	if accountType == "max" {
		if rateLimitTier == "default_claude_max_20x" {
			return 10
		}
		return 5
	}
	if accountType == "team" && rateLimitTier == "default_claude_max_5x" {
		return 5
	}
	return 1
}
// allCredentialsUnavailableError builds the error returned when no credential
// can be selected. If any credential reports an unavailable error the message
// is "all credentials unavailable"; otherwise the credentials are treated as
// rate limited and the earliest known reset, if any, is included.
func allCredentialsUnavailableError(credentials []credential) error {
	var soonestReset time.Time
	anyUnavailable := false
	for _, candidate := range credentials {
		if candidate.unavailableError() != nil {
			anyUnavailable = true
			continue
		}
		resetAt := candidate.earliestReset()
		if resetAt.IsZero() {
			continue
		}
		if soonestReset.IsZero() || resetAt.Before(soonestReset) {
			soonestReset = resetAt
		}
	}
	switch {
	case anyUnavailable:
		return E.New("all credentials unavailable")
	case soonestReset.IsZero():
		return E.New("all credentials rate-limited")
	default:
		return E.New("all credentials rate-limited, earliest reset in ", log.FormatDuration(time.Until(soonestReset)))
	}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,17 +1,12 @@
package ccm
import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"mime"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/sagernet/sing-box/adapter"
boxService "github.com/sagernet/sing-box/adapter/service"
@@ -21,23 +16,16 @@ import (
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
aTLS "github.com/sagernet/sing/common/tls"
"github.com/anthropics/anthropic-sdk-go"
"github.com/go-chi/chi/v5"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
)
const (
contextWindowStandard = 200000
contextWindowPremium = 1000000
premiumContextThreshold = 200000
retryableUsageMessage = "current credential reached its usage limit; retry the request to use another credential"
)
const retryableUsageMessage = "current credential reached its usage limit; retry the request to use another credential"
func RegisterService(registry *boxService.Registry) {
boxService.Register[option.CCMServiceOptions](registry, C.TypeCCM, NewService)
@@ -152,23 +140,6 @@ func isReverseProxyHeader(header string) bool {
}
}
// Length of the weekly window: 604800 seconds is seven days, expressed below
// in minutes for WeeklyCycleHint.WindowMinutes.
const (
weeklyWindowSeconds = 604800
weeklyWindowMinutes = weeklyWindowSeconds / 60
)
// extractWeeklyCycleHint builds a WeeklyCycleHint from the
// "anthropic-ratelimit-unified-7d-reset" response header. It returns nil when
// parseOptionalAnthropicResetHeader reports that the header is not present;
// otherwise the hint carries the fixed weekly window length and the reset
// time converted to UTC.
func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint {
	const resetHeader = "anthropic-ratelimit-unified-7d-reset"
	if resetAt, found := parseOptionalAnthropicResetHeader(headers, resetHeader); found {
		return &WeeklyCycleHint{
			WindowMinutes: weeklyWindowMinutes,
			ResetAt:       resetAt.UTC(),
		}
	}
	return nil
}
type Service struct {
boxService.Adapter
ctx context.Context
@@ -308,545 +279,6 @@ func (s *Service) Start(stage adapter.StartStage) error {
return nil
}
// isExtendedContextRequest reports whether any comma-separated entry of
// betaHeader, after trimming surrounding whitespace, starts with "context-1m".
func isExtendedContextRequest(betaHeader string) bool {
	remaining := betaHeader
	for {
		// Peel off one entry at a time; hasMore is false on the last entry.
		entry, rest, hasMore := strings.Cut(remaining, ",")
		if strings.HasPrefix(strings.TrimSpace(entry), "context-1m") {
			return true
		}
		if !hasMore {
			return false
		}
		remaining = rest
	}
}
// isFastModeRequest reports whether any comma-separated entry of betaHeader,
// after trimming surrounding whitespace, starts with "fast-mode".
func isFastModeRequest(betaHeader string) bool {
	remaining := betaHeader
	for {
		// Peel off one entry at a time; hasMore is false on the last entry.
		entry, rest, hasMore := strings.Cut(remaining, ",")
		if strings.HasPrefix(strings.TrimSpace(entry), "fast-mode") {
			return true
		}
		if !hasMore {
			return false
		}
		remaining = rest
	}
}
// detectContextWindow returns contextWindowPremium when totalInputTokens is
// strictly greater than premiumContextThreshold and betaHeader marks the
// request as extended-context; in every other case it returns
// contextWindowStandard.
func detectContextWindow(betaHeader string, totalInputTokens int64) int {
	overThreshold := totalInputTokens > premiumContextThreshold
	// The header is only inspected once the token count is over the threshold.
	if overThreshold && isExtendedContextRequest(betaHeader) {
		return contextWindowPremium
	}
	return contextWindowStandard
}
// ServeHTTP is the service's HTTP entry point.
//
// "/ccm/v1/status" and "/ccm/v1/reverse" are handed to their own handlers.
// Any other path that does not start with "/v1/" is answered with 404.
// Everything else is proxied upstream:
//
//  1. when users are configured, the Bearer token is authenticated;
//  2. the body is buffered to extract the model, message count and session ID;
//  3. the credential provider is resolved and a credential is selected;
//  4. the request is sent, switching to another credential for as long as
//     upstream answers 429 and the provider offers one;
//  5. the response is relayed, with usage recorded when the selected
//     credential exposes a usage tracker.
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := log.ContextWithNewID(r.Context())
// Service-internal endpoints are dispatched before the proxy path's
// authentication below.
if r.URL.Path == "/ccm/v1/status" {
s.handleStatusEndpoint(w, r)
return
}
if r.URL.Path == "/ccm/v1/reverse" {
s.handleReverseConnect(ctx, w, r)
return
}
if !strings.HasPrefix(r.URL.Path, "/v1/") {
writeJSONError(w, r, http.StatusNotFound, "not_found_error", "Not found")
return
}
// Authentication is enforced only when at least one user is configured;
// otherwise username stays empty.
var username string
if len(s.options.Users) > 0 {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
return
}
// TrimPrefix returns its input unchanged when the prefix is absent, so
// equality means the header did not start with "Bearer ".
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
if clientToken == authHeader {
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
return
}
var ok bool
username, ok = s.userManager.Authenticate(clientToken)
if !ok {
// NOTE(review): the rejected bearer token is written to the log verbatim;
// confirm that is acceptable for wherever these logs end up.
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
return
}
}
// Always read body to extract model and session ID
var bodyBytes []byte
var requestModel string
var messagesCount int
var sessionID string
if r.Body != nil {
var err error
bodyBytes, err = io.ReadAll(r.Body)
if err != nil {
s.logger.ErrorContext(ctx, "read request body: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body")
return
}
var request struct {
Model string `json:"model"`
Messages []anthropic.MessageParam `json:"messages"`
}
// A body that does not decode is tolerated: model and message count
// simply stay at their zero values.
err = json.Unmarshal(bodyBytes, &request)
if err == nil {
requestModel = request.Model
messagesCount = len(request.Messages)
}
sessionID = extractCCMSessionID(bodyBytes)
// Replace the consumed body with a reader over the buffered bytes.
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
}
// Resolve credential provider and user config
var provider credentialProvider
var userConfig *option.CCMUser
if len(s.options.Users) > 0 {
userConfig = s.userConfigMap[username]
var err error
provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
if err != nil {
s.logger.ErrorContext(ctx, "resolve credential: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
return
}
} else {
provider = noUserCredentialProvider(s.providers, s.legacyProvider, s.options)
}
if provider == nil {
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available")
return
}
// Polling is given the service context, not the request context —
// presumably so it is not cancelled together with the request (confirm).
provider.pollIfStale(s.ctx)
anthropicBetaHeader := r.Header.Get("anthropic-beta")
// Fast-mode requests are only accepted when the provider is a
// singleCredentialProvider.
if isFastModeRequest(anthropicBetaHeader) {
if _, isSingle := provider.(*singleCredentialProvider); !isSingle {
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
"fast mode requests will consume Extra usage, please use a default credential directly")
return
}
}
selection := credentialSelectionForUser(userConfig)
selectedCredential, isNew, err := provider.selectCredential(sessionID, selection)
if err != nil {
writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, err.Error()))
return
}
// Log the assignment only when selectCredential reports it as new.
if isNew {
logParts := []any{"assigned credential ", selectedCredential.tagName()}
if sessionID != "" {
logParts = append(logParts, " for session ", sessionID)
}
if username != "" {
logParts = append(logParts, " by user ", username)
}
if requestModel != "" {
modelDisplay := requestModel
if isExtendedContextRequest(anthropicBetaHeader) {
modelDisplay += "[1m]"
}
logParts = append(logParts, ", model=", modelDisplay)
}
s.logger.DebugContext(ctx, logParts...)
}
if isFastModeRequest(anthropicBetaHeader) && selectedCredential.isExternal() {
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
"fast mode requests cannot be proxied through external credentials")
return
}
requestContext := selectedCredential.wrapRequestContext(ctx)
{
// The callback captures this attempt's context in its own variable because
// requestContext is reassigned when the request is retried.
currentRequestContext := requestContext
requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() {
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
}))
}
// Deferred through a closure so that the requestContext current at return
// time (possibly a retry's) is the one cancelled.
defer func() {
requestContext.cancelRequest()
}()
proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if err != nil {
s.logger.ErrorContext(ctx, "create proxy request: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
return
}
response, err := selectedCredential.httpTransport().Do(proxyRequest)
if err != nil {
// The inbound request's context is already done: return without writing.
if r.Context().Err() != nil {
return
}
// The per-credential request context was cancelled instead.
if requestContext.Err() != nil {
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable while processing the request")
return
}
writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
return
}
requestContext.releaseCredentialInterrupt()
// Transparent 429 retry
// The loop has no break: it ends either by returning or by falling through
// once a response with a status other than 429 has been obtained.
for response.StatusCode == http.StatusTooManyRequests {
resetAt := parseRateLimitResetFromHeaders(response.Header)
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection)
selectedCredential.updateStateFromHeaders(response.Header)
// Give up when there is no buffered body to resend or no other credential.
if bodyBytes == nil || nextCredential == nil {
response.Body.Close()
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited")
return
}
response.Body.Close()
s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
// Cancel the previous attempt's context before wrapping a new one.
requestContext.cancelRequest()
requestContext = nextCredential.wrapRequestContext(ctx)
{
currentRequestContext := requestContext
requestContext.addInterruptLink(provider.linkProviderInterrupt(nextCredential, selection, func() {
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
}))
}
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if buildErr != nil {
s.logger.ErrorContext(ctx, "retry request: ", buildErr)
writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
return
}
retryResponse, retryErr := nextCredential.httpTransport().Do(retryRequest)
if retryErr != nil {
if r.Context().Err() != nil {
return
}
if requestContext.Err() != nil {
writeCredentialUnavailableError(w, r, provider, nextCredential, selection, "credential became unavailable while retrying the request")
return
}
s.logger.ErrorContext(ctx, "retry request: ", retryErr)
writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error())
return
}
requestContext.releaseCredentialInterrupt()
response = retryResponse
selectedCredential = nextCredential
}
defer response.Body.Close()
selectedCredential.updateStateFromHeaders(response.Header)
// Any non-200 upstream status is reported to the client as a 500 whose
// message embeds the upstream status and body; usage is re-polled in the
// background. The 429 comparison is always true here because the retry loop
// only falls through with a non-429 status.
if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
body, _ := io.ReadAll(response.Body)
s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
go selectedCredential.pollUsage(s.ctx)
writeJSONError(w, r, http.StatusInternalServerError, "api_error",
"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
return
}
// Rewrite response headers for external users
if userConfig != nil && userConfig.ExternalCredential != "" {
s.rewriteResponseHeadersForExternalUser(response.Header, userConfig)
}
// Copy upstream headers, leaving out hop-by-hop and reverse-proxy headers.
for key, values := range response.Header {
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) {
w.Header()[key] = values
}
}
w.WriteHeader(response.StatusCode)
usageTracker := selectedCredential.usageTrackerOrNil()
if usageTracker != nil && response.StatusCode == http.StatusOK {
s.handleResponseWithTracking(ctx, w, response, usageTracker, requestModel, anthropicBetaHeader, messagesCount, username)
} else {
// Without a tracker the body is relayed untouched. A Content-Type that
// parses and is not an event stream is copied in one go; an event stream,
// or a Content-Type that fails to parse, takes the flushing loop below.
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
if err == nil && mediaType != "text/event-stream" {
_, _ = io.Copy(w, response.Body)
return
}
flusher, ok := w.(http.Flusher)
if !ok {
s.logger.ErrorContext(ctx, "streaming not supported")
return
}
buffer := make([]byte, buf.BufferSize)
for {
n, err := response.Body.Read(buffer)
// Bytes returned alongside an error are still forwarded before stopping.
if n > 0 {
_, writeError := w.Write(buffer[:n])
if writeError != nil {
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
return
}
flusher.Flush()
}
if err != nil {
return
}
}
}
}
// handleResponseWithTracking relays response's body to writer while
// extracting token usage from it and recording that usage on usageTracker.
//
// For a non-streaming response the whole body is read, decoded as an
// anthropic.Message, accounted, and then written out. For a
// "text/event-stream" response each chunk is forwarded and flushed as it is
// read, while its "data: " lines are scanned for message_start and
// message_delta events; usage is recorded once, when a read returns an error
// (which includes end of stream).
//
// In both modes usage is recorded only when input or output tokens are
// non-zero and a model name is known. requestModel is the fallback when the
// response does not name a model.
func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, requestModel string, anthropicBetaHeader string, messagesCount int, username string) {
weeklyCycleHint := extractWeeklyCycleHint(response.Header)
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
// A Content-Type that fails to parse is treated as non-streaming.
isStreaming := err == nil && mediaType == "text/event-stream"
if !isStreaming {
bodyBytes, err := io.ReadAll(response.Body)
if err != nil {
s.logger.ErrorContext(ctx, "read response body: ", err)
return
}
var message anthropic.Message
var usage anthropic.Usage
var responseModel string
// A body that does not decode is still forwarded; it just records no usage.
err = json.Unmarshal(bodyBytes, &message)
if err == nil {
responseModel = string(message.Model)
usage = message.Usage
}
if responseModel == "" {
responseModel = requestModel
}
if usage.InputTokens > 0 || usage.OutputTokens > 0 {
if responseModel != "" {
// The context window is chosen from the sum of all input token kinds.
totalInputTokens := usage.InputTokens + usage.CacheCreationInputTokens + usage.CacheReadInputTokens
contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens)
usageTracker.AddUsageWithCycleHint(
responseModel,
contextWindow,
messagesCount,
usage.InputTokens,
usage.OutputTokens,
usage.CacheReadInputTokens,
usage.CacheCreationInputTokens,
usage.CacheCreation.Ephemeral5mInputTokens,
usage.CacheCreation.Ephemeral1hInputTokens,
username,
time.Now(),
weeklyCycleHint,
)
}
}
_, _ = writer.Write(bodyBytes)
return
}
flusher, ok := writer.(http.Flusher)
if !ok {
s.logger.ErrorContext(ctx, "streaming not supported")
return
}
var accumulatedUsage anthropic.Usage
var responseModel string
buffer := make([]byte, buf.BufferSize)
// leftover holds the trailing partial line of the previous chunk.
var leftover []byte
for {
n, err := response.Body.Read(buffer)
if n > 0 {
data := append(leftover, buffer[:n]...)
lines := bytes.Split(data, []byte("\n"))
// After a successful read the last element may be an incomplete line, so
// it is carried into the next chunk. When the read also returned an error,
// every element is processed now.
if err == nil {
leftover = lines[len(lines)-1]
lines = lines[:len(lines)-1]
} else {
leftover = nil
}
for _, line := range lines {
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
if bytes.HasPrefix(line, []byte("data: ")) {
eventData := bytes.TrimPrefix(line, []byte("data: "))
if bytes.Equal(eventData, []byte("[DONE]")) {
continue
}
var event anthropic.MessageStreamEventUnion
// Events that do not decode are skipped for accounting purposes only.
err := json.Unmarshal(eventData, &event)
if err != nil {
continue
}
switch event.Type {
case "message_start":
messageStart := event.AsMessageStart()
if messageStart.Message.Model != "" {
responseModel = string(messageStart.Message.Model)
}
// The cache counters are copied only when InputTokens is positive.
if messageStart.Message.Usage.InputTokens > 0 {
accumulatedUsage.InputTokens = messageStart.Message.Usage.InputTokens
accumulatedUsage.CacheReadInputTokens = messageStart.Message.Usage.CacheReadInputTokens
accumulatedUsage.CacheCreationInputTokens = messageStart.Message.Usage.CacheCreationInputTokens
accumulatedUsage.CacheCreation.Ephemeral5mInputTokens = messageStart.Message.Usage.CacheCreation.Ephemeral5mInputTokens
accumulatedUsage.CacheCreation.Ephemeral1hInputTokens = messageStart.Message.Usage.CacheCreation.Ephemeral1hInputTokens
}
case "message_delta":
messageDelta := event.AsMessageDelta()
// Each positive value overwrites the previous one rather than adding to it.
if messageDelta.Usage.OutputTokens > 0 {
accumulatedUsage.OutputTokens = messageDelta.Usage.OutputTokens
}
}
}
}
// The raw chunk is forwarded exactly as read, independent of the parsing
// above.
_, writeError := writer.Write(buffer[:n])
if writeError != nil {
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
return
}
flusher.Flush()
}
// Any read error ends the stream; record what was accumulated and return.
if err != nil {
if responseModel == "" {
responseModel = requestModel
}
if accumulatedUsage.InputTokens > 0 || accumulatedUsage.OutputTokens > 0 {
if responseModel != "" {
totalInputTokens := accumulatedUsage.InputTokens + accumulatedUsage.CacheCreationInputTokens + accumulatedUsage.CacheReadInputTokens
contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens)
usageTracker.AddUsageWithCycleHint(
responseModel,
contextWindow,
messagesCount,
accumulatedUsage.InputTokens,
accumulatedUsage.OutputTokens,
accumulatedUsage.CacheReadInputTokens,
accumulatedUsage.CacheCreationInputTokens,
accumulatedUsage.CacheCreation.Ephemeral5mInputTokens,
accumulatedUsage.CacheCreation.Ephemeral1hInputTokens,
username,
time.Now(),
weeklyCycleHint,
)
}
}
return
}
}
}
// handleStatusEndpoint answers the status request with a JSON object holding
// the caller's aggregated five-hour utilization, weekly utilization and plan
// weight, as computed by computeAggregatedUtilization.
//
// Only GET is accepted, the service must have users configured, and the
// request must carry a valid "Authorization: Bearer <key>" header. Failures
// are reported through writeJSONError.
func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
	switch {
	case r.Method != http.MethodGet:
		writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed")
		return
	case len(s.options.Users) == 0:
		writeJSONError(w, r, http.StatusForbidden, "authentication_error", "status endpoint requires user authentication")
		return
	}

	authorization := r.Header.Get("Authorization")
	if authorization == "" {
		writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
		return
	}
	token, isBearer := strings.CutPrefix(authorization, "Bearer ")
	if !isBearer {
		writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
		return
	}
	username, authenticated := s.userManager.Authenticate(token)
	if !authenticated {
		writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
		return
	}

	userConfig := s.userConfigMap[username]
	if userConfig == nil {
		writeJSONError(w, r, http.StatusInternalServerError, "api_error", "user config not found")
		return
	}
	provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
	if err != nil {
		writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
		return
	}

	// Refresh stale credentials, bounded by this request's context, before
	// reading their utilization.
	provider.pollIfStale(r.Context())
	fiveHour, weekly, planWeight := s.computeAggregatedUtilization(provider, userConfig)

	payload := map[string]float64{
		"five_hour_utilization": fiveHour,
		"weekly_utilization":    weekly,
		"plan_weight":           planWeight,
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	// The status line is already sent, so an encoding error is not reported.
	_ = json.NewEncoder(w).Encode(payload)
}
// computeAggregatedUtilization returns the plan-weighted five-hour
// utilization, the plan-weighted weekly utilization, and the total plan
// weight over the provider's credentials that count for userConfig.
//
// A credential is left out when it is not available, when its tag equals the
// user's ExternalCredential, or when it is external and the user does not
// allow external usage. For each remaining credential the headroom
// (cap minus utilization, floored at zero) is weighted by planWeight; the
// result is 100 minus the weighted mean headroom. With no contributing
// credential the result is (100, 100, 0).
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) (float64, float64, float64) {
	var (
		fiveHourHeadroom float64
		weeklyHeadroom   float64
		weightSum        float64
	)
	for _, current := range provider.allCredentials() {
		switch {
		case !current.isAvailable():
			continue
		case userConfig.ExternalCredential != "" && current.tagName() == userConfig.ExternalCredential:
			continue
		case !userConfig.AllowExternalUsage && current.isExternal():
			continue
		}
		planWeight := current.planWeight()

		fiveHourLeft := current.fiveHourCap() - current.fiveHourUtilization()
		if fiveHourLeft < 0 {
			fiveHourLeft = 0
		}
		weeklyLeft := current.weeklyCap() - current.weeklyUtilization()
		if weeklyLeft < 0 {
			weeklyLeft = 0
		}

		fiveHourHeadroom += fiveHourLeft * planWeight
		weeklyHeadroom += weeklyLeft * planWeight
		weightSum += planWeight
	}
	if weightSum == 0 {
		// Nothing contributed: report full utilization and zero weight.
		return 100, 100, 0
	}
	return 100 - fiveHourHeadroom/weightSum,
		100 - weeklyHeadroom/weightSum,
		weightSum
}
// rewriteResponseHeadersForExternalUser overwrites the unified 5h and 7d
// utilization headers with the aggregated values computed for userConfig, and
// sets "X-CCM-Plan-Weight" when the total weight is positive. The headers are
// left untouched when the user's provider cannot be resolved.
func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.CCMUser) {
	provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, userConfig.Name)
	if err != nil {
		return
	}
	fiveHourPercent, weeklyPercent, planWeight := s.computeAggregatedUtilization(provider, userConfig)

	// The aggregate is a percentage; the headers carry a 0.0-1.0 fraction
	// formatted with six decimals.
	asFraction := func(percent float64) string {
		return strconv.FormatFloat(percent/100, 'f', 6, 64)
	}
	headers.Set("anthropic-ratelimit-unified-5h-utilization", asFraction(fiveHourPercent))
	headers.Set("anthropic-ratelimit-unified-7d-utilization", asFraction(weeklyPercent))
	if planWeight > 0 {
		headers.Set("X-CCM-Plan-Weight", strconv.FormatFloat(planWeight, 'f', -1, 64))
	}
}
func (s *Service) InterfaceUpdated() {
for _, cred := range s.allCredentials {
extCred, ok := cred.(*externalCredential)

View File

@@ -0,0 +1,499 @@
package ccm
import (
"bytes"
"context"
"encoding/json"
"io"
"mime"
"net/http"
"strconv"
"strings"
"time"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common/buf"
"github.com/anthropics/anthropic-sdk-go"
)
// Context-window sizes returned by detectContextWindow, and the total
// input-token count a request must exceed before the premium size can apply.
const (
contextWindowStandard = 200000
contextWindowPremium = 1000000
premiumContextThreshold = 200000
)
// Length of the weekly window: 604800 seconds is seven days, expressed below
// in minutes for WeeklyCycleHint.WindowMinutes.
const (
weeklyWindowSeconds = 604800
weeklyWindowMinutes = weeklyWindowSeconds / 60
)
// isExtendedContextRequest reports whether any comma-separated entry of
// betaHeader, after trimming surrounding whitespace, starts with "context-1m".
func isExtendedContextRequest(betaHeader string) bool {
	remaining := betaHeader
	for {
		// Peel off one entry at a time; hasMore is false on the last entry.
		entry, rest, hasMore := strings.Cut(remaining, ",")
		if strings.HasPrefix(strings.TrimSpace(entry), "context-1m") {
			return true
		}
		if !hasMore {
			return false
		}
		remaining = rest
	}
}
// isFastModeRequest reports whether any comma-separated entry of betaHeader,
// after trimming surrounding whitespace, starts with "fast-mode".
func isFastModeRequest(betaHeader string) bool {
	remaining := betaHeader
	for {
		// Peel off one entry at a time; hasMore is false on the last entry.
		entry, rest, hasMore := strings.Cut(remaining, ",")
		if strings.HasPrefix(strings.TrimSpace(entry), "fast-mode") {
			return true
		}
		if !hasMore {
			return false
		}
		remaining = rest
	}
}
// detectContextWindow returns contextWindowPremium when totalInputTokens is
// strictly greater than premiumContextThreshold and betaHeader marks the
// request as extended-context; in every other case it returns
// contextWindowStandard.
func detectContextWindow(betaHeader string, totalInputTokens int64) int {
	overThreshold := totalInputTokens > premiumContextThreshold
	// The header is only inspected once the token count is over the threshold.
	if overThreshold && isExtendedContextRequest(betaHeader) {
		return contextWindowPremium
	}
	return contextWindowStandard
}
// extractWeeklyCycleHint builds a WeeklyCycleHint from the
// "anthropic-ratelimit-unified-7d-reset" response header. It returns nil when
// parseOptionalAnthropicResetHeader reports that the header is not present;
// otherwise the hint carries the fixed weekly window length and the reset
// time converted to UTC.
func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint {
	const resetHeader = "anthropic-ratelimit-unified-7d-reset"
	if resetAt, found := parseOptionalAnthropicResetHeader(headers, resetHeader); found {
		return &WeeklyCycleHint{
			WindowMinutes: weeklyWindowMinutes,
			ResetAt:       resetAt.UTC(),
		}
	}
	return nil
}
// extractCCMSessionID returns the text that follows the last "_session_"
// marker in the request body's metadata.user_id field. It returns "" when the
// body is not valid JSON or the marker does not occur.
func extractCCMSessionID(bodyBytes []byte) string {
	const marker = "_session_"
	var payload struct {
		Metadata struct {
			UserID string `json:"user_id"`
		} `json:"metadata"`
	}
	if json.Unmarshal(bodyBytes, &payload) != nil {
		return ""
	}
	userID := payload.Metadata.UserID
	// LastIndex: only the final occurrence of the marker counts.
	if at := strings.LastIndex(userID, marker); at >= 0 {
		return userID[at+len(marker):]
	}
	return ""
}
// ServeHTTP is the service's HTTP entry point.
//
// "/ccm/v1/status" and "/ccm/v1/reverse" are handed to their own handlers.
// Any other path that does not start with "/v1/" is answered with 404.
// Everything else is proxied upstream:
//
//  1. when users are configured, the Bearer token is authenticated;
//  2. the body is buffered to extract the model, message count and session ID;
//  3. the credential provider is resolved and a credential is selected;
//  4. the request is sent, switching to another credential for as long as
//     upstream answers 429 and the provider offers one;
//  5. the response is relayed, with usage recorded when the selected
//     credential exposes a usage tracker.
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := log.ContextWithNewID(r.Context())
// Service-internal endpoints are dispatched before the proxy path's
// authentication below.
if r.URL.Path == "/ccm/v1/status" {
s.handleStatusEndpoint(w, r)
return
}
if r.URL.Path == "/ccm/v1/reverse" {
s.handleReverseConnect(ctx, w, r)
return
}
if !strings.HasPrefix(r.URL.Path, "/v1/") {
writeJSONError(w, r, http.StatusNotFound, "not_found_error", "Not found")
return
}
// Authentication is enforced only when at least one user is configured;
// otherwise username stays empty.
var username string
if len(s.options.Users) > 0 {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
return
}
// TrimPrefix returns its input unchanged when the prefix is absent, so
// equality means the header did not start with "Bearer ".
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
if clientToken == authHeader {
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
return
}
var ok bool
username, ok = s.userManager.Authenticate(clientToken)
if !ok {
// NOTE(review): the rejected bearer token is written to the log verbatim;
// confirm that is acceptable for wherever these logs end up.
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
return
}
}
// Always read body to extract model and session ID
var bodyBytes []byte
var requestModel string
var messagesCount int
var sessionID string
if r.Body != nil {
var err error
bodyBytes, err = io.ReadAll(r.Body)
if err != nil {
s.logger.ErrorContext(ctx, "read request body: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body")
return
}
var request struct {
Model string `json:"model"`
Messages []anthropic.MessageParam `json:"messages"`
}
// A body that does not decode is tolerated: model and message count
// simply stay at their zero values.
err = json.Unmarshal(bodyBytes, &request)
if err == nil {
requestModel = request.Model
messagesCount = len(request.Messages)
}
sessionID = extractCCMSessionID(bodyBytes)
// Replace the consumed body with a reader over the buffered bytes.
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
}
// Resolve credential provider and user config
var provider credentialProvider
var userConfig *option.CCMUser
if len(s.options.Users) > 0 {
userConfig = s.userConfigMap[username]
var err error
provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
if err != nil {
s.logger.ErrorContext(ctx, "resolve credential: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
return
}
} else {
provider = noUserCredentialProvider(s.providers, s.legacyProvider, s.options)
}
if provider == nil {
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available")
return
}
// Polling is given the service context, not the request context —
// presumably so it is not cancelled together with the request (confirm).
provider.pollIfStale(s.ctx)
anthropicBetaHeader := r.Header.Get("anthropic-beta")
// Fast-mode requests are only accepted when the provider is a
// singleCredentialProvider.
if isFastModeRequest(anthropicBetaHeader) {
if _, isSingle := provider.(*singleCredentialProvider); !isSingle {
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
"fast mode requests will consume Extra usage, please use a default credential directly")
return
}
}
selection := credentialSelectionForUser(userConfig)
selectedCredential, isNew, err := provider.selectCredential(sessionID, selection)
if err != nil {
writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, err.Error()))
return
}
// Log the assignment only when selectCredential reports it as new.
if isNew {
logParts := []any{"assigned credential ", selectedCredential.tagName()}
if sessionID != "" {
logParts = append(logParts, " for session ", sessionID)
}
if username != "" {
logParts = append(logParts, " by user ", username)
}
if requestModel != "" {
modelDisplay := requestModel
if isExtendedContextRequest(anthropicBetaHeader) {
modelDisplay += "[1m]"
}
logParts = append(logParts, ", model=", modelDisplay)
}
s.logger.DebugContext(ctx, logParts...)
}
if isFastModeRequest(anthropicBetaHeader) && selectedCredential.isExternal() {
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
"fast mode requests cannot be proxied through external credentials")
return
}
requestContext := selectedCredential.wrapRequestContext(ctx)
{
// The callback captures this attempt's context in its own variable because
// requestContext is reassigned when the request is retried.
currentRequestContext := requestContext
requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() {
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
}))
}
// Deferred through a closure so that the requestContext current at return
// time (possibly a retry's) is the one cancelled.
defer func() {
requestContext.cancelRequest()
}()
proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if err != nil {
s.logger.ErrorContext(ctx, "create proxy request: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
return
}
response, err := selectedCredential.httpClient().Do(proxyRequest)
if err != nil {
// The inbound request's context is already done: return without writing.
if r.Context().Err() != nil {
return
}
// The per-credential request context was cancelled instead.
if requestContext.Err() != nil {
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable while processing the request")
return
}
writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
return
}
requestContext.releaseCredentialInterrupt()
// Transparent 429 retry
// The loop has no break: it ends either by returning or by falling through
// once a response with a status other than 429 has been obtained.
for response.StatusCode == http.StatusTooManyRequests {
resetAt := parseRateLimitResetFromHeaders(response.Header)
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection)
selectedCredential.updateStateFromHeaders(response.Header)
// Give up when there is no buffered body to resend or no other credential.
if bodyBytes == nil || nextCredential == nil {
response.Body.Close()
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited")
return
}
response.Body.Close()
s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
// Cancel the previous attempt's context before wrapping a new one.
requestContext.cancelRequest()
requestContext = nextCredential.wrapRequestContext(ctx)
{
currentRequestContext := requestContext
requestContext.addInterruptLink(provider.linkProviderInterrupt(nextCredential, selection, func() {
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
}))
}
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if buildErr != nil {
s.logger.ErrorContext(ctx, "retry request: ", buildErr)
writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
return
}
retryResponse, retryErr := nextCredential.httpClient().Do(retryRequest)
if retryErr != nil {
if r.Context().Err() != nil {
return
}
if requestContext.Err() != nil {
writeCredentialUnavailableError(w, r, provider, nextCredential, selection, "credential became unavailable while retrying the request")
return
}
s.logger.ErrorContext(ctx, "retry request: ", retryErr)
writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error())
return
}
requestContext.releaseCredentialInterrupt()
response = retryResponse
selectedCredential = nextCredential
}
defer response.Body.Close()
selectedCredential.updateStateFromHeaders(response.Header)
// Any non-200 upstream status is reported to the client as a 500 whose
// message embeds the upstream status and body; usage is re-polled in the
// background. The 429 comparison is always true here because the retry loop
// only falls through with a non-429 status.
if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
body, _ := io.ReadAll(response.Body)
s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
go selectedCredential.pollUsage(s.ctx)
writeJSONError(w, r, http.StatusInternalServerError, "api_error",
"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
return
}
// Rewrite response headers for external users
if userConfig != nil && userConfig.ExternalCredential != "" {
s.rewriteResponseHeadersForExternalUser(response.Header, userConfig)
}
// Copy upstream headers, leaving out hop-by-hop and reverse-proxy headers.
for key, values := range response.Header {
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) {
w.Header()[key] = values
}
}
w.WriteHeader(response.StatusCode)
usageTracker := selectedCredential.usageTrackerOrNil()
if usageTracker != nil && response.StatusCode == http.StatusOK {
s.handleResponseWithTracking(ctx, w, response, usageTracker, requestModel, anthropicBetaHeader, messagesCount, username)
} else {
// Without a tracker the body is relayed untouched. A Content-Type that
// parses and is not an event stream is copied in one go; an event stream,
// or a Content-Type that fails to parse, takes the flushing loop below.
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
if err == nil && mediaType != "text/event-stream" {
_, _ = io.Copy(w, response.Body)
return
}
flusher, ok := w.(http.Flusher)
if !ok {
s.logger.ErrorContext(ctx, "streaming not supported")
return
}
buffer := make([]byte, buf.BufferSize)
for {
n, err := response.Body.Read(buffer)
// Bytes returned alongside an error are still forwarded before stopping.
if n > 0 {
_, writeError := w.Write(buffer[:n])
if writeError != nil {
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
return
}
flusher.Flush()
}
if err != nil {
return
}
}
}
}
// handleResponseWithTracking relays response's body to writer while
// extracting token usage from it and recording that usage on usageTracker.
//
// For a non-streaming response the whole body is read, decoded as an
// anthropic.Message, accounted, and then written out. For a
// "text/event-stream" response each chunk is forwarded and flushed as it is
// read, while its "data: " lines are scanned for message_start and
// message_delta events; usage is recorded once, when a read returns an error
// (which includes end of stream).
//
// In both modes usage is recorded only when input or output tokens are
// non-zero and a model name is known. requestModel is the fallback when the
// response does not name a model.
func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, requestModel string, anthropicBetaHeader string, messagesCount int, username string) {
weeklyCycleHint := extractWeeklyCycleHint(response.Header)
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
// A Content-Type that fails to parse is treated as non-streaming.
isStreaming := err == nil && mediaType == "text/event-stream"
if !isStreaming {
bodyBytes, err := io.ReadAll(response.Body)
if err != nil {
s.logger.ErrorContext(ctx, "read response body: ", err)
return
}
var message anthropic.Message
var usage anthropic.Usage
var responseModel string
// A body that does not decode is still forwarded; it just records no usage.
err = json.Unmarshal(bodyBytes, &message)
if err == nil {
responseModel = string(message.Model)
usage = message.Usage
}
if responseModel == "" {
responseModel = requestModel
}
if usage.InputTokens > 0 || usage.OutputTokens > 0 {
if responseModel != "" {
// The context window is chosen from the sum of all input token kinds.
totalInputTokens := usage.InputTokens + usage.CacheCreationInputTokens + usage.CacheReadInputTokens
contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens)
usageTracker.AddUsageWithCycleHint(
responseModel,
contextWindow,
messagesCount,
usage.InputTokens,
usage.OutputTokens,
usage.CacheReadInputTokens,
usage.CacheCreationInputTokens,
usage.CacheCreation.Ephemeral5mInputTokens,
usage.CacheCreation.Ephemeral1hInputTokens,
username,
time.Now(),
weeklyCycleHint,
)
}
}
_, _ = writer.Write(bodyBytes)
return
}
flusher, ok := writer.(http.Flusher)
if !ok {
s.logger.ErrorContext(ctx, "streaming not supported")
return
}
var accumulatedUsage anthropic.Usage
var responseModel string
buffer := make([]byte, buf.BufferSize)
// leftover holds the trailing partial line of the previous chunk.
var leftover []byte
for {
n, err := response.Body.Read(buffer)
if n > 0 {
data := append(leftover, buffer[:n]...)
lines := bytes.Split(data, []byte("\n"))
// After a successful read the last element may be an incomplete line, so
// it is carried into the next chunk. When the read also returned an error,
// every element is processed now.
if err == nil {
leftover = lines[len(lines)-1]
lines = lines[:len(lines)-1]
} else {
leftover = nil
}
for _, line := range lines {
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
if bytes.HasPrefix(line, []byte("data: ")) {
eventData := bytes.TrimPrefix(line, []byte("data: "))
if bytes.Equal(eventData, []byte("[DONE]")) {
continue
}
var event anthropic.MessageStreamEventUnion
// Events that do not decode are skipped for accounting purposes only.
err := json.Unmarshal(eventData, &event)
if err != nil {
continue
}
switch event.Type {
case "message_start":
messageStart := event.AsMessageStart()
if messageStart.Message.Model != "" {
responseModel = string(messageStart.Message.Model)
}
// The cache counters are copied only when InputTokens is positive.
if messageStart.Message.Usage.InputTokens > 0 {
accumulatedUsage.InputTokens = messageStart.Message.Usage.InputTokens
accumulatedUsage.CacheReadInputTokens = messageStart.Message.Usage.CacheReadInputTokens
accumulatedUsage.CacheCreationInputTokens = messageStart.Message.Usage.CacheCreationInputTokens
accumulatedUsage.CacheCreation.Ephemeral5mInputTokens = messageStart.Message.Usage.CacheCreation.Ephemeral5mInputTokens
accumulatedUsage.CacheCreation.Ephemeral1hInputTokens = messageStart.Message.Usage.CacheCreation.Ephemeral1hInputTokens
}
case "message_delta":
messageDelta := event.AsMessageDelta()
// Each positive value overwrites the previous one rather than adding to it.
if messageDelta.Usage.OutputTokens > 0 {
accumulatedUsage.OutputTokens = messageDelta.Usage.OutputTokens
}
}
}
}
// The raw chunk is forwarded exactly as read, independent of the parsing
// above.
_, writeError := writer.Write(buffer[:n])
if writeError != nil {
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
return
}
flusher.Flush()
}
// Any read error ends the stream; record what was accumulated and return.
if err != nil {
if responseModel == "" {
responseModel = requestModel
}
if accumulatedUsage.InputTokens > 0 || accumulatedUsage.OutputTokens > 0 {
if responseModel != "" {
totalInputTokens := accumulatedUsage.InputTokens + accumulatedUsage.CacheCreationInputTokens + accumulatedUsage.CacheReadInputTokens
contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens)
usageTracker.AddUsageWithCycleHint(
responseModel,
contextWindow,
messagesCount,
accumulatedUsage.InputTokens,
accumulatedUsage.OutputTokens,
accumulatedUsage.CacheReadInputTokens,
accumulatedUsage.CacheCreationInputTokens,
accumulatedUsage.CacheCreation.Ephemeral5mInputTokens,
accumulatedUsage.CacheCreation.Ephemeral1hInputTokens,
username,
time.Now(),
weeklyCycleHint,
)
}
}
return
}
}
}

View File

@@ -0,0 +1,109 @@
package ccm
import (
"encoding/json"
"net/http"
"strconv"
"strings"
"github.com/sagernet/sing-box/option"
)
// handleStatusEndpoint answers GET requests with the aggregated utilization
// of the credential provider mapped to the authenticated user.
//
// The request is rejected when the method is not GET, when the service has no
// users configured, when the Authorization header is missing or lacks the
// "Bearer " prefix, when the token does not authenticate, or when no user
// config / provider can be resolved. On success the provider is polled if
// stale and a JSON object with the three aggregate values is written.
func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed")
		return
	}
	if len(s.options.Users) == 0 {
		writeJSONError(w, r, http.StatusForbidden, "authentication_error", "status endpoint requires user authentication")
		return
	}

	authorization := r.Header.Get("Authorization")
	if authorization == "" {
		writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
		return
	}
	bearerToken, hasBearerPrefix := strings.CutPrefix(authorization, "Bearer ")
	if !hasBearerPrefix {
		writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
		return
	}
	username, authenticated := s.userManager.Authenticate(bearerToken)
	if !authenticated {
		writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
		return
	}

	userConfig := s.userConfigMap[username]
	if userConfig == nil {
		writeJSONError(w, r, http.StatusInternalServerError, "api_error", "user config not found")
		return
	}
	provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
	if err != nil {
		writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
		return
	}

	// Refresh usage data first so the reported numbers are not outdated.
	provider.pollIfStale(r.Context())
	fiveHour, weekly, weight := s.computeAggregatedUtilization(provider, userConfig)
	payload := map[string]float64{
		"five_hour_utilization": fiveHour,
		"weekly_utilization":    weekly,
		"plan_weight":           weight,
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	json.NewEncoder(w).Encode(payload)
}
// computeAggregatedUtilization returns the plan-weight-averaged five-hour and
// weekly utilization (as percentages) across the provider's credentials,
// followed by the total plan weight that contributed.
//
// Credentials are skipped when they are unavailable, when their tag equals the
// user's ExternalCredential, or when they are external and the user does not
// allow external usage. If nothing contributes any weight the result is
// (100, 100, 0).
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) (float64, float64, float64) {
	var (
		weightSum          float64
		remaining5hSum     float64
		remainingWeeklySum float64
	)
	for _, candidate := range provider.allCredentials() {
		switch {
		case !candidate.isAvailable():
			continue
		case userConfig.ExternalCredential != "" && candidate.tagName() == userConfig.ExternalCredential:
			continue
		case !userConfig.AllowExternalUsage && candidate.isExternal():
			continue
		}

		weight := candidate.planWeight()
		// Remaining headroom is clamped at zero so an over-cap credential
		// cannot offset the headroom of the others.
		remaining5h := max(candidate.fiveHourCap()-candidate.fiveHourUtilization(), 0)
		remainingWeekly := max(candidate.weeklyCap()-candidate.weeklyUtilization(), 0)

		remaining5hSum += remaining5h * weight
		remainingWeeklySum += remainingWeekly * weight
		weightSum += weight
	}
	if weightSum == 0 {
		return 100, 100, 0
	}
	fiveHour := 100 - remaining5hSum/weightSum
	weekly := 100 - remainingWeeklySum/weightSum
	return fiveHour, weekly, weightSum
}
// rewriteResponseHeadersForExternalUser overwrites the unified 5h/7d
// utilization headers with the aggregate computed for userConfig's provider,
// expressed as fractions (percentage / 100) with six decimals. The plan-weight
// header is only set when the total weight is positive. If no provider can be
// resolved for the user the headers are left untouched.
func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.CCMUser) {
	provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, userConfig.Name)
	if err != nil {
		return
	}
	fiveHourPercent, weeklyPercent, weight := s.computeAggregatedUtilization(provider, userConfig)

	fiveHourFraction := strconv.FormatFloat(fiveHourPercent/100, 'f', 6, 64)
	weeklyFraction := strconv.FormatFloat(weeklyPercent/100, 'f', 6, 64)
	headers.Set("anthropic-ratelimit-unified-5h-utilization", fiveHourFraction)
	headers.Set("anthropic-ratelimit-unified-7d-utilization", weeklyFraction)

	if weight <= 0 {
		return
	}
	headers.Set("X-CCM-Plan-Weight", strconv.FormatFloat(weight, 'f', -1, 64))
}

View File

@@ -7,13 +7,13 @@ import (
)
type UserManager struct {
accessMutex sync.RWMutex
access sync.RWMutex
tokenMap map[string]string
}
func (m *UserManager) UpdateUsers(users []option.CCMUser) {
m.accessMutex.Lock()
defer m.accessMutex.Unlock()
m.access.Lock()
defer m.access.Unlock()
tokenMap := make(map[string]string, len(users))
for _, user := range users {
tokenMap[user.Token] = user.Name
@@ -22,8 +22,8 @@ func (m *UserManager) UpdateUsers(users []option.CCMUser) {
}
func (m *UserManager) Authenticate(token string) (string, bool) {
m.accessMutex.RLock()
m.access.RLock()
username, found := m.tokenMap[token]
m.accessMutex.RUnlock()
m.access.RUnlock()
return username, found
}

View File

@@ -1,225 +1,194 @@
package ocm
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"os"
"os/user"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
)
const (
oauth2ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
oauth2TokenURL = "https://auth.openai.com/oauth/token"
openaiAPIBaseURL = "https://api.openai.com"
chatGPTBackendURL = "https://chatgpt.com/backend-api/codex"
tokenRefreshIntervalDays = 8
defaultPollInterval = 60 * time.Minute
failedPollRetryInterval = time.Minute
httpRetryMaxBackoff = 5 * time.Minute
)
func getRealUser() (*user.User, error) {
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
sudoUserInfo, err := user.Lookup(sudoUser)
if err == nil {
return sudoUserInfo, nil
const (
httpRetryMaxAttempts = 3
httpRetryInitialDelay = 200 * time.Millisecond
)
const sessionExpiry = 24 * time.Hour
func doHTTPWithRetry(ctx context.Context, client *http.Client, buildRequest func() (*http.Request, error)) (*http.Response, error) {
var lastError error
for attempt := range httpRetryMaxAttempts {
if attempt > 0 {
delay := httpRetryInitialDelay * time.Duration(1<<(attempt-1))
select {
case <-ctx.Done():
return nil, lastError
case <-time.After(delay):
}
}
}
return user.Current()
}
// getDefaultCredentialsPath returns the location of auth.json: inside
// $CODEX_HOME when that variable is non-empty, otherwise inside ".codex"
// under the home directory of the user reported by getRealUser.
func getDefaultCredentialsPath() (string, error) {
	codexHome := os.Getenv("CODEX_HOME")
	if codexHome == "" {
		userInfo, err := getRealUser()
		if err != nil {
			return "", err
		}
		return filepath.Join(userInfo.HomeDir, ".codex", "auth.json"), nil
	}
	return filepath.Join(codexHome, "auth.json"), nil
}
// readCredentialsFromFile reads the file at path and decodes its JSON content
// into an oauthCredentials value. Read and decode errors are returned as-is.
func readCredentialsFromFile(path string) (*oauthCredentials, error) {
	raw, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}
	parsed := new(oauthCredentials)
	if err = json.Unmarshal(raw, parsed); err != nil {
		return nil, err
	}
	return parsed, nil
}
// checkCredentialFileWritable reports whether the existing file at path can
// be opened write-only. The file is never created or modified; the handle is
// closed immediately and the close error (if any) is returned.
func checkCredentialFileWritable(path string) error {
	handle, openErr := os.OpenFile(path, os.O_WRONLY, 0)
	if openErr == nil {
		return handle.Close()
	}
	return openErr
}
// writeCredentialsToFile serializes credentials as indented JSON and writes
// the result to path with mode 0600.
func writeCredentialsToFile(credentials *oauthCredentials, path string) error {
	encoded, marshalErr := json.MarshalIndent(credentials, "", " ")
	if marshalErr == nil {
		return os.WriteFile(path, encoded, 0o600)
	}
	return marshalErr
}
// oauthCredentials is the JSON shape of the credential file: either a plain
// API key, or an OAuth token set together with the time of its last refresh.
type oauthCredentials struct {
	APIKey      string     `json:"OPENAI_API_KEY,omitempty"`
	Tokens      *tokenData `json:"tokens,omitempty"`
	LastRefresh *time.Time `json:"last_refresh,omitempty"`
}

// tokenData holds the OAuth tokens and the optional account identifier.
type tokenData struct {
	IDToken      string `json:"id_token,omitempty"`
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`
	AccountID    string `json:"account_id,omitempty"`
}

// isAPIKeyMode reports whether an API key is present.
func (c *oauthCredentials) isAPIKeyMode() bool {
	return len(c.APIKey) > 0
}

// getAccessToken returns the API key when one is set, otherwise the OAuth
// access token, otherwise the empty string.
func (c *oauthCredentials) getAccessToken() string {
	switch {
	case c.APIKey != "":
		return c.APIKey
	case c.Tokens != nil:
		return c.Tokens.AccessToken
	default:
		return ""
	}
}

// getAccountID returns the account ID stored with the tokens, or the empty
// string when no token set is present.
func (c *oauthCredentials) getAccountID() string {
	if c.Tokens == nil {
		return ""
	}
	return c.Tokens.AccountID
}
// needsRefresh reports whether the OAuth tokens should be refreshed. It is
// always false in API-key mode or when no refresh token exists, true when no
// refresh has ever been recorded, and otherwise true once
// tokenRefreshIntervalDays days have elapsed since LastRefresh.
func (c *oauthCredentials) needsRefresh() bool {
	switch {
	case c.APIKey != "":
		return false
	case c.Tokens == nil || c.Tokens.RefreshToken == "":
		return false
	case c.LastRefresh == nil:
		return true
	}
	refreshInterval := time.Duration(tokenRefreshIntervalDays) * 24 * time.Hour
	return time.Since(*c.LastRefresh) >= refreshInterval
}
func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) {
if credentials.Tokens == nil || credentials.Tokens.RefreshToken == "" {
return nil, E.New("refresh token is empty")
}
requestBody, err := json.Marshal(map[string]string{
"grant_type": "refresh_token",
"refresh_token": credentials.Tokens.RefreshToken,
"client_id": oauth2ClientID,
"scope": "openid profile email",
})
if err != nil {
return nil, E.Cause(err, "marshal request")
}
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
request, err := buildRequest()
if err != nil {
return nil, err
}
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Accept", "application/json")
return request, nil
response, err := client.Do(request)
if err == nil {
return response, nil
}
lastError = err
if ctx.Err() != nil {
return nil, lastError
}
}
return nil, lastError
}
type credentialState struct {
fiveHourUtilization float64
fiveHourReset time.Time
weeklyUtilization float64
weeklyReset time.Time
hardRateLimited bool
rateLimitResetAt time.Time
accountType string
remotePlanWeight float64
lastUpdated time.Time
consecutivePollFailures int
unavailable bool
lastCredentialLoadAttempt time.Time
lastCredentialLoadError string
}
type credentialRequestContext struct {
context.Context
releaseOnce sync.Once
cancelOnce sync.Once
releaseFuncs []func() bool
cancelFunc context.CancelFunc
}
func (c *credentialRequestContext) addInterruptLink(stop func() bool) {
c.releaseFuncs = append(c.releaseFuncs, stop)
}
func (c *credentialRequestContext) releaseCredentialInterrupt() {
c.releaseOnce.Do(func() {
for _, f := range c.releaseFuncs {
f()
}
})
if err != nil {
return nil, err
}
defer response.Body.Close()
if response.StatusCode == http.StatusTooManyRequests {
body, _ := io.ReadAll(response.Body)
return nil, E.New("refresh rate limited: ", response.Status, " ", string(body))
}
if response.StatusCode != http.StatusOK {
body, _ := io.ReadAll(response.Body)
return nil, E.New("refresh failed: ", response.Status, " ", string(body))
}
var tokenResponse struct {
IDToken string `json:"id_token"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
}
err = json.NewDecoder(response.Body).Decode(&tokenResponse)
if err != nil {
return nil, E.Cause(err, "decode response")
}
newCredentials := *credentials
if newCredentials.Tokens == nil {
newCredentials.Tokens = &tokenData{}
}
if tokenResponse.IDToken != "" {
newCredentials.Tokens.IDToken = tokenResponse.IDToken
}
if tokenResponse.AccessToken != "" {
newCredentials.Tokens.AccessToken = tokenResponse.AccessToken
}
if tokenResponse.RefreshToken != "" {
newCredentials.Tokens.RefreshToken = tokenResponse.RefreshToken
}
now := time.Now()
newCredentials.LastRefresh = &now
return &newCredentials, nil
}
func cloneCredentials(credentials *oauthCredentials) *oauthCredentials {
if credentials == nil {
return nil
}
cloned := *credentials
if credentials.Tokens != nil {
clonedTokens := *credentials.Tokens
cloned.Tokens = &clonedTokens
}
if credentials.LastRefresh != nil {
lastRefresh := *credentials.LastRefresh
cloned.LastRefresh = &lastRefresh
}
return &cloned
func (c *credentialRequestContext) cancelRequest() {
c.releaseCredentialInterrupt()
c.cancelOnce.Do(c.cancelFunc)
}
func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool {
if left == nil || right == nil {
return left == right
}
if left.APIKey != right.APIKey {
return false
}
if (left.Tokens == nil) != (right.Tokens == nil) {
return false
}
if left.Tokens != nil && *left.Tokens != *right.Tokens {
return false
}
if (left.LastRefresh == nil) != (right.LastRefresh == nil) {
return false
}
if left.LastRefresh != nil && !left.LastRefresh.Equal(*right.LastRefresh) {
return false
}
return true
// credential is the package-internal abstraction over one upstream account.
// It groups identity and availability queries, utilization accessors,
// rate-limit bookkeeping, request construction, connection interruption,
// lifecycle/polling hooks, and a few OCM-specific accessors.
type credential interface {
// Identity and availability.
tagName() string
isAvailable() bool
isUsable() bool
isExternal() bool
// Utilization and caps, compared against each other by callers.
fiveHourUtilization() float64
weeklyUtilization() float64
fiveHourCap() float64
weeklyCap() float64
planWeight() float64
weeklyResetTime() time.Time
// Rate-limit and error state.
markRateLimited(resetAt time.Time)
earliestReset() time.Time
unavailableError() error
// Request construction and response feedback.
getAccessToken() (string, error)
buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error)
updateStateFromHeaders(header http.Header)
// Connection interruption.
wrapRequestContext(ctx context.Context) *credentialRequestContext
interruptConnections()
setOnBecameUnusable(fn func())
// Lifecycle and usage polling.
start() error
pollUsage(ctx context.Context)
lastUpdatedTime() time.Time
pollBackoff(base time.Duration) time.Duration
// usageTrackerOrNil presumably returns nil when usage tracking is not
// configured — NOTE(review): inferred from the name, confirm against
// the implementations.
usageTrackerOrNil() *AggregatedUsage
httpClient() *http.Client
close()
// OCM-specific
ocmDialer() N.Dialer
ocmIsAPIKeyMode() bool
ocmGetAccountID() string
ocmGetBaseURL() string
}
// credentialSelectionScope names the set of credentials a selection applies to.
type credentialSelectionScope string

const (
	credentialSelectionScopeAll         credentialSelectionScope = "all"
	credentialSelectionScopeNonExternal credentialSelectionScope = "non_external"
)

// credentialSelection couples a scope with an optional per-credential filter.
type credentialSelection struct {
	scope  credentialSelectionScope
	filter func(credential) bool
}

// allows reports whether cred passes the filter; a nil filter accepts every
// credential.
func (s credentialSelection) allows(cred credential) bool {
	if s.filter == nil {
		return true
	}
	return s.filter(cred)
}

// scopeOrDefault returns the configured scope, falling back to
// credentialSelectionScopeAll when none is set.
func (s credentialSelection) scopeOrDefault() credentialSelectionScope {
	if s.scope != "" {
		return s.scope
	}
	return credentialSelectionScopeAll
}
// normalizeRateLimitIdentifier lower-cases limitIdentifier, strips surrounding
// whitespace and replaces underscores with hyphens. Blank input yields "".
func normalizeRateLimitIdentifier(limitIdentifier string) string {
	normalized := strings.ToLower(limitIdentifier)
	normalized = strings.TrimSpace(normalized)
	// ReplaceAll on an empty string is a no-op, so blank input stays "".
	return strings.ReplaceAll(normalized, "_", "-")
}
// parseInt64Header reads headerName from headers, trims surrounding
// whitespace and parses it as a base-10 int64. The boolean is false when the
// header is absent, blank, or not a valid integer.
func parseInt64Header(headers http.Header, headerName string) (int64, bool) {
	raw := strings.TrimSpace(headers.Get(headerName))
	if raw == "" {
		return 0, false
	}
	if number, err := strconv.ParseInt(raw, 10, 64); err == nil {
		return number, true
	}
	return 0, false
}
// parseOCMRateLimitResetFromHeaders derives the time at which a rate limit is
// expected to lift from the response headers.
//
// Sources are tried in this order:
//  1. "x-<active-limit>-primary-reset-at" (Unix seconds), where <active-limit>
//     is the normalized value of "x-codex-active-limit";
//  2. "Retry-After", given either as delay-seconds or as an HTTP-date;
//  3. a fallback of five minutes from now.
func parseOCMRateLimitResetFromHeaders(headers http.Header) time.Time {
	activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
	if activeLimitIdentifier != "" {
		resetHeader := "x-" + activeLimitIdentifier + "-primary-reset-at"
		if resetAt, ok := parseInt64Header(headers, resetHeader); ok {
			return time.Unix(resetAt, 0)
		}
	}
	if seconds, ok := parseInt64Header(headers, "Retry-After"); ok {
		return time.Now().Add(time.Duration(seconds) * time.Second)
	}
	// Retry-After may also carry an absolute HTTP-date instead of a delay.
	if retryAfter := strings.TrimSpace(headers.Get("Retry-After")); retryAfter != "" {
		if retryAt, err := http.ParseTime(retryAfter); err == nil {
			return retryAt
		}
	}
	return time.Now().Add(5 * time.Minute)
}

View File

@@ -0,0 +1,223 @@
package ocm
import (
"context"
"time"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
)
// buildOCMCredentialProviders instantiates every configured credential and
// returns the provider for each tag together with the flat list of concrete
// (non-balancer) credentials.
//
// Construction happens in two passes: concrete credentials first, then
// balancers, so that a balancer may reference credentials declared after it.
// An empty credential type is treated as "default", matching the rule applied
// by validateOCMOptions; any other unrecognized type is ignored.
func buildOCMCredentialProviders(
	ctx context.Context,
	options option.OCMServiceOptions,
	logger log.ContextLogger,
) (map[string]credentialProvider, []credential, error) {
	allCredentialMap := make(map[string]credential)
	var allCreds []credential
	providers := make(map[string]credentialProvider)

	// Pass 1: create default and external credentials.
	for _, credOpt := range options.Credentials {
		var cred credential
		switch credOpt.Type {
		case "default", "":
			defaultCred, err := newDefaultCredential(ctx, credOpt.Tag, credOpt.DefaultOptions, logger)
			if err != nil {
				return nil, nil, err
			}
			cred = defaultCred
		case "external":
			externalCred, err := newExternalCredential(ctx, credOpt.Tag, credOpt.ExternalOptions, logger)
			if err != nil {
				return nil, nil, err
			}
			cred = externalCred
		default:
			continue
		}
		allCredentialMap[credOpt.Tag] = cred
		allCreds = append(allCreds, cred)
		providers[credOpt.Tag] = &singleCredentialProvider{cred: cred}
	}

	// Pass 2: create balancer providers on top of the concrete credentials.
	for _, credOpt := range options.Credentials {
		if credOpt.Type != "balancer" {
			continue
		}
		subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, allCredentialMap, credOpt.Tag)
		if err != nil {
			return nil, nil, err
		}
		providers[credOpt.Tag] = newBalancerProvider(
			subCredentials,
			credOpt.BalancerOptions.Strategy,
			time.Duration(credOpt.BalancerOptions.PollInterval),
			credOpt.BalancerOptions.RebalanceThreshold,
			logger,
		)
	}
	return providers, allCreds, nil
}
// resolveCredentialTags maps each tag to its credential in allCredentials,
// preserving order. It fails when a tag is unknown or when tags is empty;
// parentTag only appears in the error messages.
func resolveCredentialTags(tags []string, allCredentials map[string]credential, parentTag string) ([]credential, error) {
	if len(tags) == 0 {
		return nil, E.New("credential ", parentTag, " has no sub-credentials")
	}
	resolved := make([]credential, 0, len(tags))
	for _, tag := range tags {
		cred, found := allCredentials[tag]
		if !found {
			return nil, E.New("credential ", parentTag, " references unknown credential: ", tag)
		}
		resolved = append(resolved, cred)
	}
	return resolved, nil
}
// validateOCMOptions checks the static consistency of the service options.
//
// Nothing is validated unless the credentials list is used. When it is, the
// legacy top-level credential_path, usages_path and detour must be unset,
// tags must be unique, per-type constraints must hold, and every user must
// reference a known credential (and, optionally, a known credential of type
// "external" as external_credential).
func validateOCMOptions(options option.OCMServiceOptions) error {
	if len(options.Credentials) == 0 {
		return nil
	}
	switch {
	case options.CredentialPath != "":
		return E.New("credential_path and credentials are mutually exclusive")
	case options.UsagesPath != "":
		return E.New("usages_path and credentials are mutually exclusive; use usages_path on individual credentials")
	case options.Detour != "":
		return E.New("detour and credentials are mutually exclusive; use detour on individual credentials")
	}

	seenTags := make(map[string]bool)
	typeByTag := make(map[string]string)
	for _, cred := range options.Credentials {
		if seenTags[cred.Tag] {
			return E.New("duplicate credential tag: ", cred.Tag)
		}
		seenTags[cred.Tag] = true
		typeByTag[cred.Tag] = cred.Type

		switch cred.Type {
		case "default", "":
			defaults := cred.DefaultOptions
			switch {
			case defaults.Reserve5h > 99:
				return E.New("credential ", cred.Tag, ": reserve_5h must be at most 99")
			case defaults.ReserveWeekly > 99:
				return E.New("credential ", cred.Tag, ": reserve_weekly must be at most 99")
			case defaults.Limit5h > 100:
				return E.New("credential ", cred.Tag, ": limit_5h must be at most 100")
			case defaults.LimitWeekly > 100:
				return E.New("credential ", cred.Tag, ": limit_weekly must be at most 100")
			case defaults.Reserve5h > 0 && defaults.Limit5h > 0:
				return E.New("credential ", cred.Tag, ": reserve_5h and limit_5h are mutually exclusive")
			case defaults.ReserveWeekly > 0 && defaults.LimitWeekly > 0:
				return E.New("credential ", cred.Tag, ": reserve_weekly and limit_weekly are mutually exclusive")
			}
		case "external":
			if cred.ExternalOptions.Token == "" {
				return E.New("credential ", cred.Tag, ": external credential requires token")
			}
			if cred.ExternalOptions.Reverse && cred.ExternalOptions.URL == "" {
				return E.New("credential ", cred.Tag, ": reverse external credential requires url")
			}
		case "balancer":
			switch cred.BalancerOptions.Strategy {
			case "", C.BalancerStrategyLeastUsed, C.BalancerStrategyRoundRobin, C.BalancerStrategyRandom, C.BalancerStrategyFallback:
			default:
				return E.New("credential ", cred.Tag, ": unknown balancer strategy: ", cred.BalancerOptions.Strategy)
			}
			if cred.BalancerOptions.RebalanceThreshold < 0 {
				return E.New("credential ", cred.Tag, ": rebalance_threshold must not be negative")
			}
		}
	}

	for _, user := range options.Users {
		switch {
		case user.Credential == "":
			return E.New("user ", user.Name, " must specify credential in multi-credential mode")
		case !seenTags[user.Credential]:
			return E.New("user ", user.Name, " references unknown credential: ", user.Credential)
		case user.ExternalCredential == "":
			// No external credential configured: nothing further to check.
		case !seenTags[user.ExternalCredential]:
			return E.New("user ", user.Name, " references unknown external_credential: ", user.ExternalCredential)
		case typeByTag[user.ExternalCredential] != "external":
			return E.New("user ", user.Name, ": external_credential must reference an external type credential")
		}
	}
	return nil
}
// validateOCMCompositeCredentialModes rejects balancer credentials that
// include an available sub-credential running in API-key mode. Unavailable
// sub-credentials are not inspected.
func validateOCMCompositeCredentialModes(
	options option.OCMServiceOptions,
	providers map[string]credentialProvider,
) error {
	for _, credOpt := range options.Credentials {
		if credOpt.Type == "balancer" {
			provider, found := providers[credOpt.Tag]
			if !found {
				return E.New("unknown credential: ", credOpt.Tag)
			}
			for _, member := range provider.allCredentials() {
				if member.isAvailable() && member.ocmIsAPIKeyMode() {
					return E.New(
						"credential ", credOpt.Tag,
						" references API key default credential ", member.tagName(),
						"; balancer and fallback only support OAuth default credentials",
					)
				}
			}
		}
	}
	return nil
}
// credentialForUser resolves the provider serving username. A non-nil
// legacyProvider always wins; otherwise the user's configured credential tag
// is looked up in providers. Unknown users and unknown tags are errors.
func credentialForUser(
	userConfigMap map[string]*option.OCMUser,
	providers map[string]credentialProvider,
	legacyProvider credentialProvider,
	username string,
) (credentialProvider, error) {
	if legacyProvider != nil {
		return legacyProvider, nil
	}
	userConfig, hasUser := userConfigMap[username]
	if !hasUser {
		return nil, E.New("no credential mapping for user: ", username)
	}
	if provider, hasProvider := providers[userConfig.Credential]; hasProvider {
		return provider, nil
	}
	return nil, E.New("unknown credential: ", userConfig.Credential)
}
// noUserCredentialProvider picks the provider to use when no user is
// involved: the legacy provider if present, otherwise the provider registered
// under the first configured credential's tag, otherwise nil.
func noUserCredentialProvider(
	providers map[string]credentialProvider,
	legacyProvider credentialProvider,
	options option.OCMServiceOptions,
) credentialProvider {
	switch {
	case legacyProvider != nil:
		return legacyProvider
	case len(options.Credentials) > 0:
		firstTag := options.Credentials[0].Tag
		return providers[firstTag]
	default:
		return nil
	}
}

View File

@@ -0,0 +1,749 @@
package ocm
import (
"bytes"
"context"
stdTLS "crypto/tls"
"encoding/json"
"io"
"net"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/sagernet/fswatch"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp"
)
// defaultCredential is a credential backed by a local credential file. It
// tracks utilization state reported by the upstream, enforces the configured
// 5h/weekly caps, and can interrupt in-flight requests when it becomes
// unusable.
type defaultCredential struct {
// tag is the configured name, returned by tagName.
tag string
// serviceContext is the context passed to newDefaultCredential; it is
// used for token refresh requests.
serviceContext context.Context
// credentialPath is the configured path; credentialFilePath is the
// resolved path, filled in by start.
credentialPath string
credentialFilePath string
// credentials is read and written under access.
credentials *oauthCredentials
access sync.RWMutex
// state is read and written under stateAccess.
state credentialState
stateAccess sync.RWMutex
// NOTE(review): pollAccess, reloadAccess and watcherAccess are not used in
// the visible part of the file; presumably they serialize polling,
// credential reloads and watcher setup respectively — confirm.
pollAccess sync.Mutex
reloadAccess sync.Mutex
watcherAccess sync.Mutex
// cap5h and capWeekly are the utilization percentages at or above which
// checkReservesLocked reports the credential as exhausted.
cap5h float64
capWeekly float64
// usageTracker is only set when a usages path is configured.
usageTracker *AggregatedUsage
dialer N.Dialer
forwardHTTPClient *http.Client
logger log.ContextLogger
watcher *fswatch.Watcher
watcherRetryAt time.Time
// Connection interruption
// onBecameUnusable, when set, is invoked at the end of interruptConnections.
onBecameUnusable func()
// interrupted records that the unusable transition was already signalled;
// it is maintained by checkTransitionLocked.
interrupted bool
// requestContext is cancelled and replaced by interruptConnections;
// both fields are accessed under requestAccess.
requestContext context.Context
cancelRequests context.CancelFunc
requestAccess sync.Mutex
}
// newDefaultCredential builds a defaultCredential for tag.
//
// It creates a dialer honouring options.Detour and an HTTP client that dials
// through it, derives the 5h/weekly caps (an explicit limit wins; otherwise
// 100 minus the reserve, where an unset reserve counts as 1), and attaches a
// usage tracker when options.UsagesPath is set. No file is read here.
func newDefaultCredential(ctx context.Context, tag string, options option.OCMDefaultCredentialOptions, logger log.ContextLogger) (*defaultCredential, error) {
	credentialDialer, err := dialer.NewWithOptions(dialer.Options{
		Context: ctx,
		Options: option.DialerOptions{
			Detour: options.Detour,
		},
		RemoteIsDomain: true,
	})
	if err != nil {
		return nil, E.Cause(err, "create dialer for credential ", tag)
	}

	transport := &http.Transport{
		ForceAttemptHTTP2: true,
		TLSClientConfig: &stdTLS.Config{
			RootCAs: adapter.RootPoolFromContext(ctx),
			Time:    ntp.TimeFuncFromContext(ctx),
		},
		DialContext: func(dialCtx context.Context, network, addr string) (net.Conn, error) {
			return credentialDialer.DialContext(dialCtx, network, M.ParseSocksaddr(addr))
		},
	}

	// An unset reserve defaults to 1.
	reserve5h := options.Reserve5h
	if reserve5h == 0 {
		reserve5h = 1
	}
	reserveWeekly := options.ReserveWeekly
	if reserveWeekly == 0 {
		reserveWeekly = 1
	}

	// Start from the reserve-derived cap and let an explicit limit override it.
	cap5h := float64(100 - reserve5h)
	if options.Limit5h > 0 {
		cap5h = float64(options.Limit5h)
	}
	capWeekly := float64(100 - reserveWeekly)
	if options.LimitWeekly > 0 {
		capWeekly = float64(options.LimitWeekly)
	}

	requestContext, cancelRequests := context.WithCancel(context.Background())
	created := &defaultCredential{
		tag:               tag,
		serviceContext:    ctx,
		credentialPath:    options.CredentialPath,
		cap5h:             cap5h,
		capWeekly:         capWeekly,
		dialer:            credentialDialer,
		forwardHTTPClient: &http.Client{Transport: transport},
		logger:            logger,
		requestContext:    requestContext,
		cancelRequests:    cancelRequests,
	}
	if options.UsagesPath != "" {
		created.usageTracker = &AggregatedUsage{
			LastUpdated:  time.Now(),
			Combinations: make([]CostCombination, 0),
			filePath:     options.UsagesPath,
			logger:       logger,
		}
	}
	return created, nil
}
// start resolves the credential file path, then makes best-effort attempts to
// set up the file watcher, load the credentials and load usage statistics.
// Only a path resolution failure is returned; the other failures are logged.
func (c *defaultCredential) start() error {
	resolvedPath, err := resolveCredentialFilePath(c.credentialPath)
	if err != nil {
		return E.Cause(err, "resolve credential path for ", c.tag)
	}
	c.credentialFilePath = resolvedPath

	if watcherErr := c.ensureCredentialWatcher(); watcherErr != nil {
		c.logger.Debug("start credential watcher for ", c.tag, ": ", watcherErr)
	}
	if loadErr := c.reloadCredentials(true); loadErr != nil {
		c.logger.Warn("initial credential load for ", c.tag, ": ", loadErr)
	}
	if c.usageTracker == nil {
		return nil
	}
	if usageErr := c.usageTracker.Load(); usageErr != nil {
		c.logger.Warn("load usage statistics for ", c.tag, ": ", usageErr)
	}
	return nil
}
// setOnBecameUnusable installs the callback that interruptConnections invokes
// after cancelling in-flight requests. The field is written without locking.
func (c *defaultCredential) setOnBecameUnusable(fn func()) {
	c.onBecameUnusable = fn
}
// tagName returns the tag this credential was created with.
func (c *defaultCredential) tagName() string {
	return c.tag
}
// isExternal is always false for a defaultCredential.
func (c *defaultCredential) isExternal() bool {
	return false
}
// getAccessToken returns a token that does not need refreshing, reloading
// from disk or refreshing via OAuth when necessary.
//
// Steps:
//  1. Fast path under the read lock: return the cached token if still fresh.
//  2. Reload the credential file and retry the fast path.
//  3. Under the write lock, refresh the token. The refresh is refused if the
//     credential file is not writable, because the refreshed token could then
//     not be persisted. If the file changed on disk while refreshing, the
//     on-disk credentials win.
//
// An error is always returned when no credentials are loaded, even if the
// state is not flagged unavailable, so callers never receive an empty token
// together with a nil error.
func (c *defaultCredential) getAccessToken() (string, error) {
	c.retryCredentialReloadIfNeeded()

	c.access.RLock()
	if c.credentials != nil && !c.credentials.needsRefresh() {
		token := c.credentials.getAccessToken()
		c.access.RUnlock()
		return token, nil
	}
	c.access.RUnlock()

	err := c.reloadCredentials(true)
	if err == nil {
		c.access.RLock()
		if c.credentials != nil && !c.credentials.needsRefresh() {
			token := c.credentials.getAccessToken()
			c.access.RUnlock()
			return token, nil
		}
		c.access.RUnlock()
	}

	c.access.Lock()
	defer c.access.Unlock()

	if c.credentials == nil {
		// unavailableError returns nil when the state is not flagged
		// unavailable; fall back to an explicit error in that case.
		if unavailableErr := c.unavailableError(); unavailableErr != nil {
			return "", unavailableErr
		}
		return "", E.New("credential ", c.tag, " is unavailable: credentials not loaded")
	}
	// Another goroutine may have refreshed while we waited for the lock.
	if !c.credentials.needsRefresh() {
		return c.credentials.getAccessToken(), nil
	}

	err = platformCanWriteCredentials(c.credentialPath)
	if err != nil {
		return "", E.Cause(err, "credential file not writable, refusing refresh to avoid invalidation")
	}

	// Snapshot the credentials used for the refresh so a concurrent on-disk
	// change can be detected afterwards.
	baseCredentials := cloneCredentials(c.credentials)
	newCredentials, err := refreshToken(c.serviceContext, c.forwardHTTPClient, c.credentials)
	if err != nil {
		return "", err
	}

	latestCredentials, latestErr := platformReadCredentials(c.credentialPath)
	if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) {
		// The file changed during the refresh: adopt what is on disk and do
		// not overwrite it with our result.
		c.credentials = latestCredentials
		c.stateAccess.Lock()
		c.state.unavailable = false
		c.state.lastCredentialLoadAttempt = time.Now()
		c.state.lastCredentialLoadError = ""
		// The transition result is intentionally not acted upon here.
		c.checkTransitionLocked()
		c.stateAccess.Unlock()
		if !latestCredentials.needsRefresh() {
			return latestCredentials.getAccessToken(), nil
		}
		return "", E.New("credential ", c.tag, " changed while refreshing")
	}

	c.credentials = newCredentials
	c.stateAccess.Lock()
	c.state.unavailable = false
	c.state.lastCredentialLoadAttempt = time.Now()
	c.state.lastCredentialLoadError = ""
	c.checkTransitionLocked()
	c.stateAccess.Unlock()

	// Persisting is best effort: the refreshed token is still usable in
	// memory when the write fails.
	err = platformWriteCredentials(newCredentials, c.credentialPath)
	if err != nil {
		c.logger.Error("persist refreshed token for ", c.tag, ": ", err)
	}
	return newCredentials.getAccessToken(), nil
}
// getAccountID returns the account ID of the loaded credentials, or "" when
// none are loaded. The credentials are read under the access read lock.
func (c *defaultCredential) getAccountID() string {
	c.access.RLock()
	defer c.access.RUnlock()
	if loaded := c.credentials; loaded != nil {
		return loaded.getAccountID()
	}
	return ""
}
// isAPIKeyMode reports whether credentials are loaded and carry an API key.
// The credentials are read under the access read lock.
func (c *defaultCredential) isAPIKeyMode() bool {
	c.access.RLock()
	defer c.access.RUnlock()
	return c.credentials != nil && c.credentials.isAPIKeyMode()
}
// getBaseURL selects the upstream base URL: openaiAPIBaseURL in API-key mode,
// chatGPTBackendURL otherwise.
func (c *defaultCredential) getBaseURL() string {
	baseURL := chatGPTBackendURL
	if c.isAPIKeyMode() {
		baseURL = openaiAPIBaseURL
	}
	return baseURL
}
// updateStateFromHeaders folds the rate-limit headers of an upstream response
// into the credential state.
//
// The header family is selected by "x-codex-active-limit" (normalized,
// defaulting to "codex"); "primary" headers feed the five-hour window and
// "secondary" headers feed the weekly window. Reset times only move forward.
// A utilization value is accepted when it is not lower than the stored one,
// or when the corresponding reset time just advanced. If the resulting state
// is a usable->unusable transition, in-flight connections are interrupted
// after the state lock has been released.
func (c *defaultCredential) updateStateFromHeaders(headers http.Header) {
c.stateAccess.Lock()
// Snapshot the previous values to decide later whether to log.
isFirstUpdate := c.state.lastUpdated.IsZero()
oldFiveHour := c.state.fiveHourUtilization
oldWeekly := c.state.weeklyUtilization
hadData := false
activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
if activeLimitIdentifier == "" {
activeLimitIdentifier = "codex"
}
// Five-hour window: reset time first, so a new window can lower the
// stored utilization below.
fiveHourResetChanged := false
fiveHourResetAt := headers.Get("x-" + activeLimitIdentifier + "-primary-reset-at")
if fiveHourResetAt != "" {
value, err := strconv.ParseInt(fiveHourResetAt, 10, 64)
if err == nil {
hadData = true
newReset := time.Unix(value, 0)
if newReset.After(c.state.fiveHourReset) {
fiveHourResetChanged = true
c.state.fiveHourReset = newReset
}
}
}
fiveHourPercent := headers.Get("x-" + activeLimitIdentifier + "-primary-used-percent")
if fiveHourPercent != "" {
value, err := strconv.ParseFloat(fiveHourPercent, 64)
if err == nil {
hadData = true
// Lower values are ignored unless the reset time advanced.
if value >= c.state.fiveHourUtilization || fiveHourResetChanged {
c.state.fiveHourUtilization = value
}
}
}
// Weekly window: same rules, using the "secondary" headers.
weeklyResetChanged := false
weeklyResetAt := headers.Get("x-" + activeLimitIdentifier + "-secondary-reset-at")
if weeklyResetAt != "" {
value, err := strconv.ParseInt(weeklyResetAt, 10, 64)
if err == nil {
hadData = true
newReset := time.Unix(value, 0)
if newReset.After(c.state.weeklyReset) {
weeklyResetChanged = true
c.state.weeklyReset = newReset
}
}
}
weeklyPercent := headers.Get("x-" + activeLimitIdentifier + "-secondary-used-percent")
if weeklyPercent != "" {
value, err := strconv.ParseFloat(weeklyPercent, 64)
if err == nil {
hadData = true
if value >= c.state.weeklyUtilization || weeklyResetChanged {
c.state.weeklyUtilization = value
}
}
}
// Any parseable header counts as a successful update and clears the
// poll-failure counter.
if hadData {
c.state.consecutivePollFailures = 0
c.state.lastUpdated = time.Now()
}
// Log on the first update or when either utilization changed at a
// granularity of 0.01 percentage points.
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
resetSuffix := ""
if !c.state.weeklyReset.IsZero() {
resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset))
}
c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
}
shouldInterrupt := c.checkTransitionLocked()
c.stateAccess.Unlock()
// Interrupt only after releasing stateAccess.
if shouldInterrupt {
c.interruptConnections()
}
}
// markRateLimited records a hard rate limit lasting until resetAt and
// interrupts in-flight connections if this makes the credential newly
// unusable. The interruption runs after the state lock is released.
func (c *defaultCredential) markRateLimited(resetAt time.Time) {
	c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt)))

	becameUnusable := func() bool {
		c.stateAccess.Lock()
		defer c.stateAccess.Unlock()
		c.state.hardRateLimited = true
		c.state.rateLimitResetAt = resetAt
		return c.checkTransitionLocked()
	}()
	if becameUnusable {
		c.interruptConnections()
	}
}
// isUsable reports whether the credential may serve a request right now: it
// must be available, have no outstanding poll failures, not be under an
// active hard rate limit, and be below both utilization caps.
//
// An expired hard rate limit is cleared here. Clearing requires the write
// lock, so every condition is re-evaluated after the lock upgrade: the state
// may have changed (for example, a new rate limit may have been recorded)
// between releasing the read lock and acquiring the write lock.
func (c *defaultCredential) isUsable() bool {
	c.retryCredentialReloadIfNeeded()

	c.stateAccess.RLock()
	if c.state.unavailable || c.state.consecutivePollFailures > 0 {
		c.stateAccess.RUnlock()
		return false
	}
	if !c.state.hardRateLimited {
		usable := c.checkReservesLocked()
		c.stateAccess.RUnlock()
		return usable
	}
	if time.Now().Before(c.state.rateLimitResetAt) {
		c.stateAccess.RUnlock()
		return false
	}
	c.stateAccess.RUnlock()

	// The rate limit looked expired; upgrade to the write lock to clear it.
	c.stateAccess.Lock()
	defer c.stateAccess.Unlock()
	if c.state.unavailable || c.state.consecutivePollFailures > 0 {
		return false
	}
	if c.state.hardRateLimited {
		if time.Now().Before(c.state.rateLimitResetAt) {
			// A newer rate limit was recorded during the lock upgrade.
			return false
		}
		c.state.hardRateLimited = false
	}
	return c.checkReservesLocked()
}
// checkReservesLocked reports whether both the five-hour and the weekly
// utilization are strictly below their caps. It reads c.state without taking
// a lock; callers in this file invoke it while holding stateAccess.
func (c *defaultCredential) checkReservesLocked() bool {
	belowFiveHourCap := c.state.fiveHourUtilization < c.cap5h
	belowWeeklyCap := c.state.weeklyUtilization < c.capWeekly
	return belowFiveHourCap && belowWeeklyCap
}
// checkTransitionLocked detects the usable->unusable transition and returns
// true exactly once per transition; the interrupted flag is re-armed as soon
// as the credential is usable again.
// Must be called with stateAccess write lock held.
func (c *defaultCredential) checkTransitionLocked() bool {
	usable := !c.state.unavailable &&
		!c.state.hardRateLimited &&
		c.checkReservesLocked() &&
		c.state.consecutivePollFailures <= 0
	switch {
	case usable:
		c.interrupted = false
		return false
	case c.interrupted:
		// Already signalled for this unusable period.
		return false
	default:
		c.interrupted = true
		return true
	}
}
// interruptConnections cancels the shared request context (and with it every
// request context derived through wrapRequestContext), installs a fresh one
// for future requests, and finally invokes the onBecameUnusable callback if
// one is set.
func (c *defaultCredential) interruptConnections() {
	c.logger.Warn("interrupting connections for ", c.tag)

	func() {
		c.requestAccess.Lock()
		defer c.requestAccess.Unlock()
		c.cancelRequests()
		c.requestContext, c.cancelRequests = context.WithCancel(context.Background())
	}()

	if notify := c.onBecameUnusable; notify != nil {
		notify()
	}
}
// wrapRequestContext derives a cancellable context from parent that is also
// cancelled when the credential's current request context is cancelled. The
// returned wrapper carries the function that detaches that link and the
// derived context's cancel function.
func (c *defaultCredential) wrapRequestContext(parent context.Context) *credentialRequestContext {
	c.requestAccess.Lock()
	interruptSource := c.requestContext
	c.requestAccess.Unlock()

	requestCtx, cancelRequest := context.WithCancel(parent)
	detach := context.AfterFunc(interruptSource, cancelRequest)

	wrapped := &credentialRequestContext{
		Context:    requestCtx,
		cancelFunc: cancelRequest,
	}
	wrapped.releaseFuncs = []func() bool{detach}
	return wrapped
}
// fiveHourUtilization returns the stored five-hour utilization, read under
// the state read lock.
func (c *defaultCredential) fiveHourUtilization() float64 {
	c.stateAccess.RLock()
	utilization := c.state.fiveHourUtilization
	c.stateAccess.RUnlock()
	return utilization
}
// weeklyUtilization returns the stored weekly utilization, read under the
// state read lock.
func (c *defaultCredential) weeklyUtilization() float64 {
	c.stateAccess.RLock()
	utilization := c.state.weeklyUtilization
	c.stateAccess.RUnlock()
	return utilization
}
// planWeight returns ocmPlanWeight applied to the stored account type. The
// lookup happens while the state read lock is held.
func (c *defaultCredential) planWeight() float64 {
	c.stateAccess.RLock()
	weight := ocmPlanWeight(c.state.accountType)
	c.stateAccess.RUnlock()
	return weight
}
// weeklyResetTime returns the recorded reset time of the weekly window; it is
// the zero time when no reset has been recorded.
func (c *defaultCredential) weeklyResetTime() time.Time {
	c.stateAccess.RLock()
	resetAt := c.state.weeklyReset
	c.stateAccess.RUnlock()
	return resetAt
}
// isAvailable first gives the credential a chance to reload itself
// (retryCredentialReloadIfNeeded) and then reports whether the state is not
// marked unavailable.
func (c *defaultCredential) isAvailable() bool {
	c.retryCredentialReloadIfNeeded()
	c.stateAccess.RLock()
	unavailable := c.state.unavailable
	c.stateAccess.RUnlock()
	return !unavailable
}
// unavailableError returns nil while the credential is available. Otherwise
// it returns an error naming the credential tag, including the last recorded
// credential load error when there is one.
func (c *defaultCredential) unavailableError() error {
	c.stateAccess.RLock()
	unavailable := c.state.unavailable
	loadError := c.state.lastCredentialLoadError
	c.stateAccess.RUnlock()
	switch {
	case !unavailable:
		return nil
	case loadError == "":
		return E.New("credential ", c.tag, " is unavailable")
	default:
		return E.New("credential ", c.tag, " is unavailable: ", loadError)
	}
}
// lastUpdatedTime returns the time stored by the most recent
// markUsagePollAttempted call (zero if none happened yet).
func (c *defaultCredential) lastUpdatedTime() time.Time {
	c.stateAccess.RLock()
	updatedAt := c.state.lastUpdated
	c.stateAccess.RUnlock()
	return updatedAt
}
// markUsagePollAttempted stamps the state with the current time so that
// staleness checks based on lastUpdatedTime start counting from now.
func (c *defaultCredential) markUsagePollAttempted() {
	c.stateAccess.Lock()
	c.state.lastUpdated = time.Now()
	c.stateAccess.Unlock()
}
// incrementPollFailures records one more consecutive poll failure and, when
// that flips the credential from usable to unusable, interrupts its
// connections. The interruption runs after stateAccess has been released.
func (c *defaultCredential) incrementPollFailures() {
	c.stateAccess.Lock()
	c.state.consecutivePollFailures++
	becameUnusable := c.checkTransitionLocked()
	c.stateAccess.Unlock()
	if !becameUnusable {
		return
	}
	c.interruptConnections()
}
// pollBackoff returns how long to wait before the next usage poll.
//
// With no consecutive poll failures it returns baseInterval. Otherwise it
// returns failedPollRetryInterval doubled once per failure after the first,
// capped at httpRetryMaxBackoff.
//
// The doubling is done step by step and stops as soon as the cap is reached.
// The previous closed form, failedPollRetryInterval*(1<<(failures-1)),
// overflowed int64 after a few dozen consecutive failures and produced a
// negative or zero duration, which bypassed the cap and made the poller
// fire on every tick during a long outage.
func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duration {
	c.stateAccess.RLock()
	failures := c.state.consecutivePollFailures
	c.stateAccess.RUnlock()
	if failures <= 0 {
		return baseInterval
	}
	backoff := failedPollRetryInterval
	for doublings := failures - 1; doublings > 0 && backoff < httpRetryMaxBackoff; doublings-- {
		backoff *= 2
	}
	if backoff > httpRetryMaxBackoff {
		return httpRetryMaxBackoff
	}
	return backoff
}
// isPollBackoffAtCap reports whether the failure backoff has reached
// httpRetryMaxBackoff. pollUsage uses it to stop logging repeated poll
// errors once the backoff is saturated.
//
// The backoff is computed by bounded doubling rather than
// failedPollRetryInterval*(1<<(failures-1)): the shift-and-multiply form
// overflows int64 after a few dozen consecutive failures, which made this
// predicate flip back to false in the middle of a long outage.
func (c *defaultCredential) isPollBackoffAtCap() bool {
	c.stateAccess.RLock()
	failures := c.state.consecutivePollFailures
	c.stateAccess.RUnlock()
	if failures <= 0 {
		return false
	}
	backoff := failedPollRetryInterval
	for doublings := failures - 1; doublings > 0 && backoff < httpRetryMaxBackoff; doublings-- {
		backoff *= 2
	}
	return backoff >= httpRetryMaxBackoff
}
// earliestReset returns the soonest time at which this credential may become
// usable again:
//   - the zero time when the credential is unavailable,
//   - the hard rate-limit reset time when it is hard rate limited,
//   - otherwise the earlier of the five-hour and weekly reset times, ignoring
//     whichever of the two is unset (zero).
func (c *defaultCredential) earliestReset() time.Time {
	c.stateAccess.RLock()
	defer c.stateAccess.RUnlock()
	switch {
	case c.state.unavailable:
		return time.Time{}
	case c.state.hardRateLimited:
		return c.state.rateLimitResetAt
	}
	fiveHour, weekly := c.state.fiveHourReset, c.state.weeklyReset
	if weekly.IsZero() {
		return fiveHour
	}
	if fiveHour.IsZero() || weekly.Before(fiveHour) {
		return weekly
	}
	return fiveHour
}
// fiveHourCap returns the configured five-hour utilization cap (cap5h), the
// threshold checkReservesLocked compares the five-hour utilization against.
func (c *defaultCredential) fiveHourCap() float64 {
return c.cap5h
}
// weeklyCap returns the configured weekly utilization cap (capWeekly), the
// threshold checkReservesLocked compares the weekly utilization against.
func (c *defaultCredential) weeklyCap() float64 {
return c.capWeekly
}
// usageTrackerOrNil returns the credential's usage tracker, which may be nil
// (close checks for nil before using it).
func (c *defaultCredential) usageTrackerOrNil() *AggregatedUsage {
return c.usageTracker
}
// httpClient returns the HTTP client used to forward requests for this
// credential (the forwardHTTPClient field).
func (c *defaultCredential) httpClient() *http.Client {
return c.forwardHTTPClient
}
// ocmDialer returns the dialer stored on this credential.
func (c *defaultCredential) ocmDialer() N.Dialer {
return c.dialer
}
// ocmIsAPIKeyMode forwards to isAPIKeyMode.
func (c *defaultCredential) ocmIsAPIKeyMode() bool {
return c.isAPIKeyMode()
}
// ocmGetAccountID forwards to getAccountID.
func (c *defaultCredential) ocmGetAccountID() string {
return c.getAccountID()
}
// ocmGetBaseURL forwards to getBaseURL.
func (c *defaultCredential) ocmGetBaseURL() string {
return c.getBaseURL()
}
// pollUsage queries the remote usage endpoint for this credential and stores
// the reported utilization, reset times and plan type in c.state.
//
// The call is a no-op when another poll is already running, when the
// credential is unavailable, or when it is in API-key mode. Every failure
// (token retrieval, HTTP error, non-200 status, undecodable body) increments
// the consecutive poll failure counter; a successful poll resets it. If the
// resulting state makes the credential unusable for the first time, its
// connections are interrupted.
func (c *defaultCredential) pollUsage(ctx context.Context) {
// Only one poll at a time; concurrent callers return immediately.
if !c.pollAccess.TryLock() {
return
}
defer c.pollAccess.Unlock()
// Stamp lastUpdated on every exit path below, including failures.
defer c.markUsagePollAttempted()
c.retryCredentialReloadIfNeeded()
if !c.isAvailable() {
return
}
// API-key credentials are not polled.
if c.isAPIKeyMode() {
return
}
accessToken, err := c.getAccessToken()
if err != nil {
// Stop repeating the error log once the backoff is saturated.
if !c.isPollBackoffAtCap() {
c.logger.Error("poll usage for ", c.tag, ": get token: ", err)
}
c.incrementPollFailures()
return
}
// NOTE(review): isAPIKeyMode() already caused an early return above, so the
// first branch looks unreachable unless the credentials change between the
// two calls. Confirm whether it can be removed.
var usageURL string
if c.isAPIKeyMode() {
usageURL = openaiAPIBaseURL + "/api/codex/usage"
} else {
usageURL = strings.TrimSuffix(chatGPTBackendURL, "/codex") + "/wham/usage"
}
accountID := c.getAccountID()
// Reuse the forwarding transport but bound each poll attempt to 5 seconds.
pollClient := &http.Client{
Transport: c.forwardHTTPClient.Transport,
Timeout: 5 * time.Second,
}
response, err := doHTTPWithRetry(ctx, pollClient, func() (*http.Request, error) {
request, err := http.NewRequestWithContext(ctx, http.MethodGet, usageURL, nil)
if err != nil {
return nil, err
}
request.Header.Set("Authorization", "Bearer "+accessToken)
if accountID != "" {
request.Header.Set("ChatGPT-Account-Id", accountID)
}
return request, nil
})
if err != nil {
if !c.isPollBackoffAtCap() {
c.logger.Error("poll usage for ", c.tag, ": ", err)
}
c.incrementPollFailures()
return
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
// A 429 gets an extra warning; every non-200 status counts as a failure.
if response.StatusCode == http.StatusTooManyRequests {
c.logger.Warn("poll usage for ", c.tag, ": rate limited")
}
// The body is read best-effort, for the debug log only.
body, _ := io.ReadAll(response.Body)
c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
c.incrementPollFailures()
return
}
// usageWindow is one rate-limit window as reported by the usage endpoint.
type usageWindow struct {
UsedPercent float64 `json:"used_percent"`
ResetAt int64 `json:"reset_at"`
}
var usageResponse struct {
PlanType string `json:"plan_type"`
RateLimit *struct {
PrimaryWindow *usageWindow `json:"primary_window"`
SecondaryWindow *usageWindow `json:"secondary_window"`
} `json:"rate_limit"`
}
err = json.NewDecoder(response.Body).Decode(&usageResponse)
if err != nil {
c.logger.Debug("poll usage for ", c.tag, ": decode: ", err)
c.incrementPollFailures()
return
}
c.stateAccess.Lock()
// Remember the previous values so the log line below is emitted only on change.
isFirstUpdate := c.state.lastUpdated.IsZero()
oldFiveHour := c.state.fiveHourUtilization
oldWeekly := c.state.weeklyUtilization
c.state.consecutivePollFailures = 0
if usageResponse.RateLimit != nil {
// The primary window is stored as the five-hour figures.
if w := usageResponse.RateLimit.PrimaryWindow; w != nil {
c.state.fiveHourUtilization = w.UsedPercent
if w.ResetAt > 0 {
c.state.fiveHourReset = time.Unix(w.ResetAt, 0)
}
}
// The secondary window is stored as the weekly figures.
if w := usageResponse.RateLimit.SecondaryWindow; w != nil {
c.state.weeklyUtilization = w.UsedPercent
if w.ResetAt > 0 {
c.state.weeklyReset = time.Unix(w.ResetAt, 0)
}
}
}
if usageResponse.PlanType != "" {
c.state.accountType = usageResponse.PlanType
}
// Clear an expired hard rate limit.
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
c.state.hardRateLimited = false
}
// Log on the first update or when either figure changed at 0.01 granularity.
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
resetSuffix := ""
if !c.state.weeklyReset.IsZero() {
resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset))
}
c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
}
shouldInterrupt := c.checkTransitionLocked()
c.stateAccess.Unlock()
// Interrupt outside the state lock.
if shouldInterrupt {
c.interruptConnections()
}
}
// buildProxyRequest creates the upstream request that forwards original
// through this credential.
//
// The target URL is the credential base URL plus the original path (with a
// leading "/v1" removed unless the credential is in API-key mode) and the
// original raw query. The body is bodyBytes when non-nil, otherwise the
// original request body. Headers are copied from original except hop-by-hop
// headers, reverse-proxy headers and Authorization; serviceHeaders then
// override any copied value, and finally the credential's bearer token and,
// when known, its account ID are set.
func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) {
	accessToken, err := c.getAccessToken()
	if err != nil {
		return nil, E.Cause(err, "get access token for ", c.tag)
	}

	proxyPath := original.URL.Path
	if !c.isAPIKeyMode() {
		proxyPath = strings.TrimPrefix(proxyPath, "/v1")
	}
	proxyURL := c.getBaseURL() + proxyPath
	if query := original.URL.RawQuery; query != "" {
		proxyURL += "?" + query
	}

	var body io.Reader = original.Body
	if bodyBytes != nil {
		body = bytes.NewReader(bodyBytes)
	}
	proxyRequest, err := http.NewRequestWithContext(ctx, original.Method, proxyURL, body)
	if err != nil {
		return nil, err
	}

	for name, values := range original.Header {
		if isHopByHopHeader(name) || isReverseProxyHeader(name) || name == "Authorization" {
			continue
		}
		proxyRequest.Header[name] = values
	}
	for name, values := range serviceHeaders {
		proxyRequest.Header.Del(name)
		proxyRequest.Header[name] = values
	}
	proxyRequest.Header.Set("Authorization", "Bearer "+accessToken)
	accountID := c.getAccountID()
	if accountID != "" {
		proxyRequest.Header.Set("ChatGPT-Account-Id", accountID)
	}
	return proxyRequest, nil
}
// close releases the credential's resources: it closes the credential file
// watcher when one exists and, when a usage tracker exists, cancels its
// pending save and writes the statistics one last time. Errors are logged,
// not returned.
func (c *defaultCredential) close() {
	if c.watcher != nil {
		if err := c.watcher.Close(); err != nil {
			c.logger.Error("close credential watcher for ", c.tag, ": ", err)
		}
	}
	if c.usageTracker == nil {
		return
	}
	c.usageTracker.cancelPendingSave()
	if err := c.usageTracker.Save(); err != nil {
		c.logger.Error("save usage statistics for ", c.tag, ": ", err)
	}
}

View File

@@ -30,17 +30,17 @@ import (
const reverseProxyBaseURL = "http://reverse-proxy"
type externalCredential struct {
tag string
baseURL string
token string
credDialer N.Dialer
httpClient *http.Client
state credentialState
stateMutex sync.RWMutex
pollAccess sync.Mutex
pollInterval time.Duration
usageTracker *AggregatedUsage
logger log.ContextLogger
tag string
baseURL string
token string
credDialer N.Dialer
forwardHTTPClient *http.Client
state credentialState
stateAccess sync.RWMutex
pollAccess sync.Mutex
pollInterval time.Duration
usageTracker *AggregatedUsage
logger log.ContextLogger
onBecameUnusable func()
interrupted bool
@@ -147,7 +147,7 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx
// Receiver mode: no URL, wait for reverse connection
cred.baseURL = reverseProxyBaseURL
cred.credDialer = reverseSessionDialer{credential: cred}
cred.httpClient = &http.Client{
cred.forwardHTTPClient = &http.Client{
Transport: &http.Transport{
ForceAttemptHTTP2: false,
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
@@ -211,11 +211,11 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx
Time: ntp.TimeFuncFromContext(ctx),
}
}
cred.httpClient = &http.Client{Transport: transport}
cred.forwardHTTPClient = &http.Client{Transport: transport}
} else {
// Normal mode: standard HTTP client for proxying
cred.credDialer = credentialDialer
cred.httpClient = &http.Client{Transport: transport}
cred.forwardHTTPClient = &http.Client{Transport: transport}
cred.reverseCredDialer = reverseSessionDialer{credential: cred}
cred.reverseHttpClient = &http.Client{
Transport: &http.Transport{
@@ -273,39 +273,39 @@ func (c *externalCredential) isUsable() bool {
if !c.isAvailable() {
return false
}
c.stateMutex.RLock()
c.stateAccess.RLock()
if c.state.consecutivePollFailures > 0 {
c.stateMutex.RUnlock()
c.stateAccess.RUnlock()
return false
}
if c.state.hardRateLimited {
if time.Now().Before(c.state.rateLimitResetAt) {
c.stateMutex.RUnlock()
c.stateAccess.RUnlock()
return false
}
c.stateMutex.RUnlock()
c.stateMutex.Lock()
c.stateAccess.RUnlock()
c.stateAccess.Lock()
if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) {
c.state.hardRateLimited = false
}
usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100
c.stateMutex.Unlock()
c.stateAccess.Unlock()
return usable
}
usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100
c.stateMutex.RUnlock()
c.stateAccess.RUnlock()
return usable
}
func (c *externalCredential) fiveHourUtilization() float64 {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.fiveHourUtilization
}
func (c *externalCredential) weeklyUtilization() float64 {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.weeklyUtilization
}
@@ -318,8 +318,8 @@ func (c *externalCredential) weeklyCap() float64 {
}
func (c *externalCredential) planWeight() float64 {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
if c.state.remotePlanWeight > 0 {
return c.state.remotePlanWeight
}
@@ -327,26 +327,26 @@ func (c *externalCredential) planWeight() float64 {
}
func (c *externalCredential) weeklyResetTime() time.Time {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.weeklyReset
}
func (c *externalCredential) markRateLimited(resetAt time.Time) {
c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt)))
c.stateMutex.Lock()
c.stateAccess.Lock()
c.state.hardRateLimited = true
c.state.rateLimitResetAt = resetAt
shouldInterrupt := c.checkTransitionLocked()
c.stateMutex.Unlock()
c.stateAccess.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
}
func (c *externalCredential) earliestReset() time.Time {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
if c.state.hardRateLimited {
return c.state.rateLimitResetAt
}
@@ -432,7 +432,7 @@ func (c *externalCredential) openReverseConnection(ctx context.Context) (net.Con
}
func (c *externalCredential) updateStateFromHeaders(headers http.Header) {
c.stateMutex.Lock()
c.stateAccess.Lock()
isFirstUpdate := c.state.lastUpdated.IsZero()
oldFiveHour := c.state.fiveHourUtilization
oldWeekly := c.state.weeklyUtilization
@@ -494,7 +494,7 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) {
c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
}
shouldInterrupt := c.checkTransitionLocked()
c.stateMutex.Unlock()
c.stateAccess.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
@@ -569,9 +569,9 @@ func (c *externalCredential) doPollUsageRequest(ctx context.Context) (*http.Resp
}
}
// Forward transport with retries
if c.httpClient != nil {
if c.forwardHTTPClient != nil {
forwardClient := &http.Client{
Transport: c.httpClient.Transport,
Transport: c.forwardHTTPClient.Transport,
Timeout: 5 * time.Second,
}
return doHTTPWithRetry(ctx, forwardClient, buildRequest(c.baseURL))
@@ -602,10 +602,10 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
// 404 means the remote does not have a status endpoint yet;
// usage will be updated passively from response headers.
if response.StatusCode == http.StatusNotFound {
c.stateMutex.Lock()
c.stateAccess.Lock()
c.state.consecutivePollFailures = 0
c.checkTransitionLocked()
c.stateMutex.Unlock()
c.stateAccess.Unlock()
} else {
c.incrementPollFailures()
}
@@ -624,7 +624,7 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
return
}
c.stateMutex.Lock()
c.stateAccess.Lock()
isFirstUpdate := c.state.lastUpdated.IsZero()
oldFiveHour := c.state.fiveHourUtilization
oldWeekly := c.state.weeklyUtilization
@@ -645,28 +645,28 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
}
shouldInterrupt := c.checkTransitionLocked()
c.stateMutex.Unlock()
c.stateAccess.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
}
func (c *externalCredential) lastUpdatedTime() time.Time {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.lastUpdated
}
func (c *externalCredential) markUsagePollAttempted() {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()
c.stateAccess.Lock()
defer c.stateAccess.Unlock()
c.state.lastUpdated = time.Now()
}
func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Duration {
c.stateMutex.RLock()
c.stateAccess.RLock()
failures := c.state.consecutivePollFailures
c.stateMutex.RUnlock()
c.stateAccess.RUnlock()
if failures <= 0 {
return baseInterval
}
@@ -678,17 +678,17 @@ func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Durati
}
func (c *externalCredential) isPollBackoffAtCap() bool {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
failures := c.state.consecutivePollFailures
return failures > 0 && failedPollRetryInterval*time.Duration(1<<(failures-1)) >= httpRetryMaxBackoff
}
func (c *externalCredential) incrementPollFailures() {
c.stateMutex.Lock()
c.stateAccess.Lock()
c.state.consecutivePollFailures++
shouldInterrupt := c.checkTransitionLocked()
c.stateMutex.Unlock()
c.stateAccess.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
@@ -698,14 +698,14 @@ func (c *externalCredential) usageTrackerOrNil() *AggregatedUsage {
return c.usageTracker
}
func (c *externalCredential) httpTransport() *http.Client {
func (c *externalCredential) httpClient() *http.Client {
if c.reverseHttpClient != nil {
session := c.getReverseSession()
if session != nil && !session.IsClosed() {
return c.reverseHttpClient
}
}
return c.httpClient
return c.forwardHTTPClient
}
func (c *externalCredential) ocmDialer() N.Dialer {

View File

@@ -62,10 +62,10 @@ func (c *defaultCredential) ensureCredentialWatcher() error {
}
func (c *defaultCredential) retryCredentialReloadIfNeeded() {
c.stateMutex.RLock()
c.stateAccess.RLock()
unavailable := c.state.unavailable
lastAttempt := c.state.lastCredentialLoadAttempt
c.stateMutex.RUnlock()
c.stateAccess.RUnlock()
if !unavailable {
return
}
@@ -84,10 +84,10 @@ func (c *defaultCredential) reloadCredentials(force bool) error {
c.reloadAccess.Lock()
defer c.reloadAccess.Unlock()
c.stateMutex.RLock()
c.stateAccess.RLock()
unavailable := c.state.unavailable
lastAttempt := c.state.lastCredentialLoadAttempt
c.stateMutex.RUnlock()
c.stateAccess.RUnlock()
if !force {
if !unavailable {
return nil
@@ -97,39 +97,39 @@ func (c *defaultCredential) reloadCredentials(force bool) error {
}
}
c.stateMutex.Lock()
c.stateAccess.Lock()
c.state.lastCredentialLoadAttempt = time.Now()
c.stateMutex.Unlock()
c.stateAccess.Unlock()
credentials, err := platformReadCredentials(c.credentialPath)
if err != nil {
return c.markCredentialsUnavailable(E.Cause(err, "read credentials"))
}
c.accessMutex.Lock()
c.access.Lock()
c.credentials = credentials
c.accessMutex.Unlock()
c.access.Unlock()
c.stateMutex.Lock()
c.stateAccess.Lock()
c.state.unavailable = false
c.state.lastCredentialLoadError = ""
c.checkTransitionLocked()
c.stateMutex.Unlock()
c.stateAccess.Unlock()
return nil
}
func (c *defaultCredential) markCredentialsUnavailable(err error) error {
c.accessMutex.Lock()
c.access.Lock()
hadCredentials := c.credentials != nil
c.credentials = nil
c.accessMutex.Unlock()
c.access.Unlock()
c.stateMutex.Lock()
c.stateAccess.Lock()
c.state.unavailable = true
c.state.lastCredentialLoadError = err.Error()
shouldInterrupt := c.checkTransitionLocked()
c.stateMutex.Unlock()
c.stateAccess.Unlock()
if shouldInterrupt && hadCredentials {
c.interruptConnections()

View File

@@ -0,0 +1,225 @@
package ocm
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"os"
"os/user"
"path/filepath"
"time"
E "github.com/sagernet/sing/common/exceptions"
)
// OAuth and endpoint constants used by the ocm credential code.
const (
// oauth2ClientID is sent as client_id in the token refresh request.
oauth2ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
// oauth2TokenURL is the endpoint refreshToken posts to.
oauth2TokenURL = "https://auth.openai.com/oauth/token"
// openaiAPIBaseURL is the base URL of the OpenAI API.
openaiAPIBaseURL = "https://api.openai.com"
// chatGPTBackendURL is the ChatGPT backend URL for Codex.
chatGPTBackendURL = "https://chatgpt.com/backend-api/codex"
// tokenRefreshIntervalDays is the age, in days, after which needsRefresh
// considers the stored tokens due for a refresh.
tokenRefreshIntervalDays = 8
)
// getRealUser returns the user named by the SUDO_USER environment variable
// when it is set and can be looked up; in every other case it returns the
// current user.
func getRealUser() (*user.User, error) {
	sudoUser := os.Getenv("SUDO_USER")
	if sudoUser == "" {
		return user.Current()
	}
	if info, err := user.Lookup(sudoUser); err == nil {
		return info, nil
	}
	// The lookup failed; fall back to whoever the process runs as.
	return user.Current()
}
// getDefaultCredentialsPath returns the location of the credentials file
// auth.json: inside $CODEX_HOME when that variable is set, otherwise inside
// the ".codex" directory of the home directory reported by getRealUser.
func getDefaultCredentialsPath() (string, error) {
	codexHome := os.Getenv("CODEX_HOME")
	if codexHome == "" {
		userInfo, err := getRealUser()
		if err != nil {
			return "", err
		}
		return filepath.Join(userInfo.HomeDir, ".codex", "auth.json"), nil
	}
	return filepath.Join(codexHome, "auth.json"), nil
}
// readCredentialsFromFile reads the file at path and decodes its JSON content
// into an oauthCredentials value. Read and decode errors are returned as is.
func readCredentialsFromFile(path string) (*oauthCredentials, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}
	var credentials oauthCredentials
	if err = json.Unmarshal(data, &credentials); err != nil {
		return nil, err
	}
	return &credentials, nil
}
// checkCredentialFileWritable verifies that the existing file at path can be
// opened for writing. The file is opened write-only without the create flag
// and closed again right away; nothing is written.
func checkCredentialFileWritable(path string) error {
	handle, err := os.OpenFile(path, os.O_WRONLY, 0)
	if err == nil {
		err = handle.Close()
	}
	return err
}
// writeCredentialsToFile encodes credentials as indented JSON and writes the
// result to path, passing permission 0o600 to os.WriteFile.
// NOTE(review): os.WriteFile applies the permission only when it creates the
// file and replaces the content in place (not atomically) - confirm that is
// acceptable for a credentials file shared with other tools.
func writeCredentialsToFile(credentials *oauthCredentials, path string) error {
data, err := json.MarshalIndent(credentials, "", " ")
if err != nil {
return err
}
return os.WriteFile(path, data, 0o600)
}
// oauthCredentials is the JSON shape of the credentials file: either an API
// key, or a set of OAuth tokens together with the time of their last refresh.
type oauthCredentials struct {
	APIKey      string     `json:"OPENAI_API_KEY,omitempty"`
	Tokens      *tokenData `json:"tokens,omitempty"`
	LastRefresh *time.Time `json:"last_refresh,omitempty"`
}

// tokenData holds the OAuth tokens stored under the "tokens" key.
type tokenData struct {
	IDToken      string `json:"id_token,omitempty"`
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`
	AccountID    string `json:"account_id,omitempty"`
}

// isAPIKeyMode reports whether an API key is configured.
func (c *oauthCredentials) isAPIKeyMode() bool {
	return len(c.APIKey) > 0
}

// getAccessToken returns the API key when one is configured, otherwise the
// OAuth access token, otherwise the empty string.
func (c *oauthCredentials) getAccessToken() string {
	switch {
	case c.APIKey != "":
		return c.APIKey
	case c.Tokens != nil:
		return c.Tokens.AccessToken
	default:
		return ""
	}
}

// getAccountID returns the account ID stored with the OAuth tokens, or the
// empty string when there are no tokens.
func (c *oauthCredentials) getAccountID() string {
	if c.Tokens == nil {
		return ""
	}
	return c.Tokens.AccountID
}
// needsRefresh reports whether the OAuth tokens are due for a refresh.
// API-key credentials and credentials without a refresh token never need
// one; tokens with no recorded refresh time always do; otherwise a refresh is
// due once tokenRefreshIntervalDays days have passed since the last one.
func (c *oauthCredentials) needsRefresh() bool {
	switch {
	case c.APIKey != "":
		return false
	case c.Tokens == nil || c.Tokens.RefreshToken == "":
		return false
	case c.LastRefresh == nil:
		return true
	}
	refreshInterval := time.Duration(tokenRefreshIntervalDays) * 24 * time.Hour
	return time.Since(*c.LastRefresh) >= refreshInterval
}
// refreshToken exchanges the refresh token in credentials for new tokens at
// oauth2TokenURL and returns an updated copy of the credentials with
// LastRefresh set to the current time.
//
// Token fields missing from the response keep their previous values. The
// input credentials are never modified: the token data is copied before it
// is updated, so a caller that still uses the old value (for example to
// compare it against the result) does not see it change underneath it.
//
// An error is returned when there is no refresh token, when the request
// fails, when the server answers with a non-200 status (the status and
// response body are included in the message), or when the response cannot
// be decoded.
func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) {
	if credentials.Tokens == nil || credentials.Tokens.RefreshToken == "" {
		return nil, E.New("refresh token is empty")
	}

	requestBody, err := json.Marshal(map[string]string{
		"grant_type":    "refresh_token",
		"refresh_token": credentials.Tokens.RefreshToken,
		"client_id":     oauth2ClientID,
		"scope":         "openid profile email",
	})
	if err != nil {
		return nil, E.Cause(err, "marshal request")
	}

	response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
		// Bind the request to ctx so cancellation and deadlines reach the
		// HTTP round trip. A fresh reader is created per attempt because a
		// retry must resend the whole body.
		request, err := http.NewRequestWithContext(ctx, http.MethodPost, oauth2TokenURL, bytes.NewReader(requestBody))
		if err != nil {
			return nil, err
		}
		request.Header.Set("Content-Type", "application/json")
		request.Header.Set("Accept", "application/json")
		return request, nil
	})
	if err != nil {
		return nil, err
	}
	defer response.Body.Close()

	if response.StatusCode == http.StatusTooManyRequests {
		// The body is read best-effort; it only enriches the error message.
		body, _ := io.ReadAll(response.Body)
		return nil, E.New("refresh rate limited: ", response.Status, " ", string(body))
	}
	if response.StatusCode != http.StatusOK {
		// The body is read best-effort; it only enriches the error message.
		body, _ := io.ReadAll(response.Body)
		return nil, E.New("refresh failed: ", response.Status, " ", string(body))
	}

	var tokenResponse struct {
		IDToken      string `json:"id_token"`
		AccessToken  string `json:"access_token"`
		RefreshToken string `json:"refresh_token"`
	}
	err = json.NewDecoder(response.Body).Decode(&tokenResponse)
	if err != nil {
		return nil, E.Cause(err, "decode response")
	}

	// Copy the struct AND the token data it points to; a plain struct copy
	// would share the *tokenData with the caller's credentials.
	newCredentials := *credentials
	newTokens := *credentials.Tokens
	newCredentials.Tokens = &newTokens
	if tokenResponse.IDToken != "" {
		newTokens.IDToken = tokenResponse.IDToken
	}
	if tokenResponse.AccessToken != "" {
		newTokens.AccessToken = tokenResponse.AccessToken
	}
	if tokenResponse.RefreshToken != "" {
		newTokens.RefreshToken = tokenResponse.RefreshToken
	}
	now := time.Now()
	newCredentials.LastRefresh = &now
	return &newCredentials, nil
}
// cloneCredentials returns a deep copy of credentials: the token data and the
// last-refresh timestamp are duplicated so the copy shares no pointers with
// the original. A nil input yields nil.
func cloneCredentials(credentials *oauthCredentials) *oauthCredentials {
	if credentials == nil {
		return nil
	}
	duplicate := new(oauthCredentials)
	*duplicate = *credentials
	if tokens := credentials.Tokens; tokens != nil {
		tokensCopy := *tokens
		duplicate.Tokens = &tokensCopy
	}
	if refreshedAt := credentials.LastRefresh; refreshedAt != nil {
		refreshedAtCopy := *refreshedAt
		duplicate.LastRefresh = &refreshedAtCopy
	}
	return duplicate
}
// credentialsEqual reports whether two credential values carry the same
// data. Two nil pointers are equal, nil and non-nil are not. Token data is
// compared field by field and refresh times are compared with Time.Equal.
func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool {
	switch {
	case left == nil || right == nil:
		return left == right
	case left.APIKey != right.APIKey:
		return false
	case (left.Tokens == nil) != (right.Tokens == nil):
		return false
	case left.Tokens != nil && *left.Tokens != *right.Tokens:
		return false
	case (left.LastRefresh == nil) != (right.LastRefresh == nil):
		return false
	case left.LastRefresh != nil && !left.LastRefresh.Equal(*right.LastRefresh):
		return false
	}
	return true
}

View File

@@ -0,0 +1,411 @@
package ocm
import (
"context"
"math/rand/v2"
"sync"
"sync/atomic"
"time"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
E "github.com/sagernet/sing/common/exceptions"
)
// credentialProvider chooses which credential serves a request and keeps the
// credentials it manages up to date. Implementations in this file:
// singleCredentialProvider and balancerProvider.
type credentialProvider interface {
// selectCredential returns the credential to use for sessionID under the
// given selection. The bool result reports whether the session was newly
// registered by this call.
selectCredential(sessionID string, selection credentialSelection) (credential, bool, error)
// onRateLimited marks cred as rate limited until resetAt and returns a
// replacement credential, or nil when there is none.
onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential
// linkProviderInterrupt registers onInterrupt to run when the provider
// interrupts cred for the given selection; the returned function undoes the
// registration in the style of the stop function of context.AfterFunc.
linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool
// pollIfStale drops expired sessions and polls usage for credentials whose
// data is older than their poll backoff.
pollIfStale(ctx context.Context)
// allCredentials returns every credential managed by the provider.
allCredentials() []credential
// close releases provider resources.
close()
}
// singleCredentialProvider is a credentialProvider backed by exactly one
// credential.
type singleCredentialProvider struct {
cred credential
// sessionAccess guards sessions.
sessionAccess sync.RWMutex
// sessions maps a session ID to the time it was first seen; it is created
// lazily by selectCredential.
sessions map[string]time.Time
}
// selectCredential returns the provider's only credential when the selection
// allows it and it is both available and usable; otherwise it returns an
// error describing why not. For a non-empty sessionID the second result is
// true the first time that session is seen.
func (p *singleCredentialProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) {
	switch {
	case !selection.allows(p.cred):
		return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out")
	case !p.cred.isAvailable():
		return nil, false, p.cred.unavailableError()
	case !p.cred.isUsable():
		return nil, false, E.New("credential ", p.cred.tagName(), " is rate-limited")
	}
	if sessionID == "" {
		return p.cred, false, nil
	}

	p.sessionAccess.Lock()
	defer p.sessionAccess.Unlock()
	if p.sessions == nil {
		p.sessions = make(map[string]time.Time)
	}
	if _, known := p.sessions[sessionID]; known {
		return p.cred, false, nil
	}
	p.sessions[sessionID] = time.Now()
	return p.cred, true, nil
}
// onRateLimited marks cred as rate limited until resetAt. With a single
// credential there is nothing to fail over to, so it always returns nil.
func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ credentialSelection) credential {
cred.markRateLimited(resetAt)
return nil
}
// pollIfStale removes sessions older than sessionExpiry and then polls the
// credential's usage when its last update is older than its current poll
// backoff (based on defaultPollInterval).
func (p *singleCredentialProvider) pollIfStale(ctx context.Context) {
	now := time.Now()
	p.sessionAccess.Lock()
	for sessionID, startedAt := range p.sessions {
		if now.Sub(startedAt) <= sessionExpiry {
			continue
		}
		delete(p.sessions, sessionID)
	}
	p.sessionAccess.Unlock()

	age := time.Since(p.cred.lastUpdatedTime())
	if age <= p.cred.pollBackoff(defaultPollInterval) {
		return
	}
	p.cred.pollUsage(ctx)
}
// allCredentials returns a one-element slice holding the provider's
// credential.
func (p *singleCredentialProvider) allCredentials() []credential {
return []credential{p.cred}
}
// linkProviderInterrupt is a no-op for the single-credential provider: the
// callback is never registered and the returned stop function always reports
// false.
func (p *singleCredentialProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool {
return func() bool {
return false
}
}
// close does nothing; the provider holds no resources of its own.
func (p *singleCredentialProvider) close() {}
// sessionEntry records which credential a session was assigned to by
// balancerProvider.
type sessionEntry struct {
// tag is the tagName of the assigned credential.
tag string
// selectionScope is the scope of the selection that created the entry.
selectionScope credentialSelectionScope
// createdAt is compared against sessionExpiry when stale sessions are pruned.
createdAt time.Time
}
// credentialInterruptKey identifies an interrupt context in
// balancerProvider.credentialInterrupts: one per credential tag and selection
// scope.
type credentialInterruptKey struct {
tag string
selectionScope credentialSelectionScope
}
// credentialInterruptEntry pairs a cancellable context with its cancel
// function. Callbacks linked through linkProviderInterrupt run when the
// context is cancelled, which rebalanceCredential does.
type credentialInterruptEntry struct {
context context.Context
cancel context.CancelFunc
}
// balancerProvider is a credentialProvider that distributes sessions over
// several credentials according to a strategy.
type balancerProvider struct {
credentials []credential
// strategy selects the picking function; see pickCredential.
strategy string
// roundRobinIndex is the counter advanced by pickRoundRobin.
roundRobinIndex atomic.Uint64
// pollInterval is the base interval passed to each credential's pollBackoff.
pollInterval time.Duration
// rebalanceThreshold, when positive, enables moving sessions away from a
// credential whose weekly utilization exceeds the best candidate's by more
// than the (plan-weight adjusted) threshold.
rebalanceThreshold float64
// sessionAccess guards sessions.
sessionAccess sync.RWMutex
sessions map[string]sessionEntry
// interruptAccess guards credentialInterrupts.
interruptAccess sync.Mutex
credentialInterrupts map[credentialInterruptKey]credentialInterruptEntry
logger log.ContextLogger
}
// compositeCredentialSelectable reports whether cred may be picked by a
// balancerProvider: credentials in API-key mode are excluded.
func compositeCredentialSelectable(cred credential) bool {
return !cred.ocmIsAPIKeyMode()
}
// newBalancerProvider builds a balancerProvider over credentials with empty
// session and interrupt tables. A non-positive pollInterval is replaced by
// defaultPollInterval.
func newBalancerProvider(credentials []credential, strategy string, pollInterval time.Duration, rebalanceThreshold float64, logger log.ContextLogger) *balancerProvider {
	provider := &balancerProvider{
		credentials:          credentials,
		strategy:             strategy,
		pollInterval:         pollInterval,
		rebalanceThreshold:   rebalanceThreshold,
		sessions:             make(map[string]sessionEntry),
		credentialInterrupts: make(map[credentialInterruptKey]credentialInterruptEntry),
		logger:               logger,
	}
	if provider.pollInterval <= 0 {
		provider.pollInterval = defaultPollInterval
	}
	return provider
}
// selectCredential picks the credential that should serve sessionID.
//
// With the fallback strategy no session state is kept: the first usable
// credential is returned on every call. With any other strategy a session
// stays on the credential it was assigned to as long as that credential is
// still selectable, allowed by the selection and usable, unless rebalancing
// moves it. In every other case a new credential is picked and, for a
// non-empty sessionID, recorded; the bool result is true in that case.
// When no credential can be picked the error from allRateLimitedError is
// returned.
func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) {
if p.strategy == C.BalancerStrategyFallback {
best := p.pickCredential(selection.filter)
if best == nil {
return nil, false, allRateLimitedError(p.credentials)
}
return best, false, nil
}
selectionScope := selection.scopeOrDefault()
if sessionID != "" {
p.sessionAccess.RLock()
entry, exists := p.sessions[sessionID]
p.sessionAccess.RUnlock()
if exists {
// A session recorded under a different scope is not reused.
if entry.selectionScope == selectionScope {
for _, cred := range p.credentials {
if cred.tagName() == entry.tag && compositeCredentialSelectable(cred) && selection.allows(cred) && cred.isUsable() {
// Rebalancing applies only to the least-used strategy (which is also
// the default when strategy is empty).
if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == C.BalancerStrategyLeastUsed) {
better := p.pickLeastUsed(selection.filter)
if better != nil && better.tagName() != cred.tagName() {
// A higher plan weight lowers the threshold at which sessions move away.
effectiveThreshold := p.rebalanceThreshold / cred.planWeight()
delta := cred.weeklyUtilization() - better.weeklyUtilization()
if delta > effectiveThreshold {
p.logger.Info("rebalancing away from ", cred.tagName(),
": utilization delta ", delta, "% exceeds effective threshold ",
effectiveThreshold, "% (weight ", cred.planWeight(), ")")
p.rebalanceCredential(cred.tagName(), selectionScope)
// Leave the loop so a new credential is picked below.
break
}
}
}
return cred, false, nil
}
}
}
// The recorded credential cannot be used any more; forget the session.
p.sessionAccess.Lock()
delete(p.sessions, sessionID)
p.sessionAccess.Unlock()
}
}
best := p.pickCredential(selection.filter)
if best == nil {
return nil, false, allRateLimitedError(p.credentials)
}
// Only named sessions are tracked.
isNew := sessionID != ""
if isNew {
p.sessionAccess.Lock()
p.sessions[sessionID] = sessionEntry{
tag: best.tagName(),
selectionScope: selectionScope,
createdAt: time.Now(),
}
p.sessionAccess.Unlock()
}
return best, isNew, nil
}
// rebalanceCredential moves all sessions away from the credential identified
// by tag within selectionScope: it cancels the current interrupt context for
// that pair (firing the callbacks linked to it), installs a fresh context for
// future links, and forgets every session assigned to the pair.
func (p *balancerProvider) rebalanceCredential(tag string, selectionScope credentialSelectionScope) {
	key := credentialInterruptKey{tag: tag, selectionScope: selectionScope}

	p.interruptAccess.Lock()
	if previous, found := p.credentialInterrupts[key]; found {
		previous.cancel()
	}
	var replacement credentialInterruptEntry
	replacement.context, replacement.cancel = context.WithCancel(context.Background())
	p.credentialInterrupts[key] = replacement
	p.interruptAccess.Unlock()

	p.sessionAccess.Lock()
	defer p.sessionAccess.Unlock()
	for sessionID, session := range p.sessions {
		if session.tag != tag || session.selectionScope != selectionScope {
			continue
		}
		delete(p.sessions, sessionID)
	}
}
// linkProviderInterrupt arranges for onInterrupt to run when the provider
// rebalances cred within the selection's scope, and returns the stop function
// of the underlying context.AfterFunc registration. With the fallback
// strategy nothing is registered and the returned function always reports
// false.
func (p *balancerProvider) linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool {
	if p.strategy == C.BalancerStrategyFallback {
		return func() bool { return false }
	}
	key := credentialInterruptKey{
		tag:            cred.tagName(),
		selectionScope: selection.scopeOrDefault(),
	}

	p.interruptAccess.Lock()
	entry, exists := p.credentialInterrupts[key]
	if !exists {
		// First link for this credential and scope: create its interrupt context.
		entry.context, entry.cancel = context.WithCancel(context.Background())
		p.credentialInterrupts[key] = entry
	}
	p.interruptAccess.Unlock()

	return context.AfterFunc(entry.context, onInterrupt)
}
// onRateLimited marks cred as rate limited until resetAt and picks a
// replacement. For a named session under a non-fallback strategy the old
// session assignment is dropped and, when a replacement exists, the session
// is re-recorded against it. Returns nil when no credential can be picked.
func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential {
	cred.markRateLimited(resetAt)
	if p.strategy == C.BalancerStrategyFallback || sessionID == "" {
		// No session bookkeeping in these cases.
		return p.pickCredential(selection.filter)
	}

	p.sessionAccess.Lock()
	delete(p.sessions, sessionID)
	p.sessionAccess.Unlock()

	replacement := p.pickCredential(selection.filter)
	if replacement == nil {
		return replacement
	}
	p.sessionAccess.Lock()
	p.sessions[sessionID] = sessionEntry{
		tag:            replacement.tagName(),
		selectionScope: selection.scopeOrDefault(),
		createdAt:      time.Now(),
	}
	p.sessionAccess.Unlock()
	return replacement
}
// pickCredential dispatches to the picking function for the configured
// strategy. Any strategy other than round-robin, random or fallback
// (including the empty string) uses pickLeastUsed. The result is nil when no
// credential qualifies.
func (p *balancerProvider) pickCredential(filter func(credential) bool) credential {
switch p.strategy {
case C.BalancerStrategyRoundRobin:
return p.pickRoundRobin(filter)
case C.BalancerStrategyRandom:
return p.pickRandom(filter)
case C.BalancerStrategyFallback:
return p.pickFallback(filter)
default:
return p.pickLeastUsed(filter)
}
}
// pickFallback returns the first credential, in configuration order, that
// passes the optional filter, is selectable by a balancer and is usable. It
// returns nil when there is none.
func (p *balancerProvider) pickFallback(filter func(credential) bool) credential {
	for _, candidate := range p.credentials {
		if filter != nil && !filter(candidate) {
			continue
		}
		if compositeCredentialSelectable(candidate) && candidate.isUsable() {
			return candidate
		}
	}
	return nil
}
// weeklyWindowHours is the length of the weekly usage window in hours.
const weeklyWindowHours = 7 * 24
// pickLeastUsed returns the eligible credential with the highest score, or
// nil when none is eligible. A credential is eligible when it passes the
// optional filter, is selectable by a balancer and is usable.
//
// The score is the remaining weekly headroom (weeklyCap - weeklyUtilization)
// multiplied by the plan weight. When a weekly reset time is known the score
// is further scaled by weeklyWindowHours / hours-until-reset, so headroom
// that is about to be replenished counts for more.
func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credential {
var best credential
// Only scores strictly greater than -1 can be selected.
bestScore := float64(-1)
now := time.Now()
for _, cred := range p.credentials {
if filter != nil && !filter(cred) {
continue
}
if !compositeCredentialSelectable(cred) {
continue
}
if !cred.isUsable() {
continue
}
remaining := cred.weeklyCap() - cred.weeklyUtilization()
score := remaining * cred.planWeight()
resetTime := cred.weeklyResetTime()
if !resetTime.IsZero() {
timeUntilReset := resetTime.Sub(now)
// Clamp to one hour so a reset that is imminent or already past does not
// produce a huge or negative multiplier.
if timeUntilReset < time.Hour {
timeUntilReset = time.Hour
}
score *= weeklyWindowHours / timeUntilReset.Hours()
}
if score > bestScore {
bestScore = score
best = cred
}
}
return best
}
// ocmPlanWeight maps an account (plan) type to its balancing weight: "pro"
// weighs 10, while "plus" and every other value weigh 1.
func ocmPlanWeight(accountType string) float64 {
	if accountType == "pro" {
		return 10
	}
	return 1
}
// pickRoundRobin returns the next eligible credential in rotation, or nil
// when none is eligible. Each call advances the shared counter by one and
// scans the credential list once, starting at the counter position, for a
// credential that passes the optional filter, is selectable by a balancer
// and is usable.
//
// The start position is reduced modulo the list length while still an
// unsigned 64-bit value. Converting the raw counter to int first would
// truncate it on 32-bit platforms and eventually yield a negative start,
// and therefore a negative slice index and a panic.
func (p *balancerProvider) pickRoundRobin(filter func(credential) bool) credential {
	// Advance the counter even when the list is empty, as before.
	ticket := p.roundRobinIndex.Add(1) - 1
	count := uint64(len(p.credentials))
	if count == 0 {
		return nil
	}
	start := ticket % count
	for offset := range count {
		candidate := p.credentials[(start+offset)%count]
		if filter != nil && !filter(candidate) {
			continue
		}
		if !compositeCredentialSelectable(candidate) {
			continue
		}
		if candidate.isUsable() {
			return candidate
		}
	}
	return nil
}
// pickRandom gathers every credential that passes the optional filter, is
// selectable and is usable, then returns one of them chosen with rand.IntN.
// It returns nil when no credential qualifies.
func (p *balancerProvider) pickRandom(filter func(credential) bool) credential {
	var pool []credential
	for _, cred := range p.credentials {
		if filter != nil && !filter(cred) {
			continue
		}
		if compositeCredentialSelectable(cred) && cred.isUsable() {
			pool = append(pool, cred)
		}
	}
	if size := len(pool); size > 0 {
		return pool[rand.IntN(size)]
	}
	return nil
}
// pollIfStale performs two housekeeping steps:
//
//  1. Under sessionAccess, it drops every session entry created more than
//     sessionExpiry ago.
//  2. For each credential whose last update is older than its poll backoff
//     (derived from p.pollInterval), it calls pollUsage with ctx.
func (p *balancerProvider) pollIfStale(ctx context.Context) {
	now := time.Now()

	func() {
		p.sessionAccess.Lock()
		defer p.sessionAccess.Unlock()
		for sessionID, session := range p.sessions {
			if now.Sub(session.createdAt) > sessionExpiry {
				delete(p.sessions, sessionID)
			}
		}
	}()

	for _, cred := range p.credentials {
		age := time.Since(cred.lastUpdatedTime())
		if age <= cred.pollBackoff(p.pollInterval) {
			continue
		}
		cred.pollUsage(ctx)
	}
}
// allCredentials returns the provider's credential slice as stored; the slice
// is not copied, so callers share its backing array with the provider.
func (p *balancerProvider) allCredentials() []credential {
	return p.credentials
}
// close is a no-op for balancerProvider.
func (p *balancerProvider) close() {}
// allRateLimitedError builds the error reported when none of the given
// credentials can serve a request.
//
//   - If at least one credential reports an unavailableError, the result is
//     "all credentials unavailable".
//   - Otherwise, if no credential reports a non-zero earliestReset, the result
//     is "all credentials rate-limited".
//   - Otherwise the message also states the time remaining until the earliest
//     reset among the credentials.
func allRateLimitedError(credentials []credential) error {
	var (
		soonest        time.Time
		anyUnavailable bool
	)
	for _, cred := range credentials {
		if cred.unavailableError() != nil {
			anyUnavailable = true
			continue
		}
		candidate := cred.earliestReset()
		if candidate.IsZero() {
			continue
		}
		if soonest.IsZero() || candidate.Before(soonest) {
			soonest = candidate
		}
	}
	switch {
	case anyUnavailable:
		return E.New("all credentials unavailable")
	case soonest.IsZero():
		return E.New("all credentials rate-limited")
	default:
		return E.New("all credentials rate-limited, earliest reset in ", log.FormatDuration(time.Until(soonest)))
	}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,17 +1,13 @@
package ocm
import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"mime"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/sagernet/sing-box/adapter"
boxService "github.com/sagernet/sing-box/adapter/service"
@@ -21,14 +17,11 @@ import (
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
aTLS "github.com/sagernet/sing/common/tls"
"github.com/go-chi/chi/v5"
"github.com/openai/openai-go/v3"
"github.com/openai/openai-go/v3/responses"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
)
@@ -160,71 +153,20 @@ func isReverseProxyHeader(header string) bool {
}
}
// normalizeRateLimitIdentifier canonicalises a rate-limit identifier so it can
// be embedded in a header name: the value is lower-cased, surrounding
// whitespace is trimmed and every underscore becomes a hyphen. Blank input
// yields the empty string.
func normalizeRateLimitIdentifier(limitIdentifier string) string {
	normalized := strings.ToLower(limitIdentifier)
	normalized = strings.TrimSpace(normalized)
	// Replacing inside an empty string is a no-op, so blank input stays "".
	return strings.ReplaceAll(normalized, "_", "-")
}
// parseInt64Header reads headerName from headers, trims surrounding
// whitespace and parses the result as a base-10 int64. The boolean is false
// (and the value 0) when the header is absent, blank or not a valid integer.
func parseInt64Header(headers http.Header, headerName string) (int64, bool) {
	raw := strings.TrimSpace(headers.Get(headerName))
	if raw == "" {
		return 0, false
	}
	if value, err := strconv.ParseInt(raw, 10, 64); err == nil {
		return value, true
	}
	return 0, false
}
// weeklyCycleHintForLimit derives a WeeklyCycleHint from the
// "x-<limit>-secondary-window-minutes" and "x-<limit>-secondary-reset-at"
// headers, where <limit> is the normalised limitIdentifier.
//
// It returns nil when the identifier normalises to the empty string, when
// either header is missing or unparsable, or when either value is not
// strictly positive. The reset header is interpreted as Unix seconds and the
// resulting time is expressed in UTC.
func weeklyCycleHintForLimit(headers http.Header, limitIdentifier string) *WeeklyCycleHint {
	identifier := normalizeRateLimitIdentifier(limitIdentifier)
	if identifier == "" {
		return nil
	}

	windowMinutes, windowPresent := parseInt64Header(headers, "x-"+identifier+"-secondary-window-minutes")
	resetAtUnix, resetPresent := parseInt64Header(headers, "x-"+identifier+"-secondary-reset-at")
	if !windowPresent || !resetPresent {
		return nil
	}
	if windowMinutes <= 0 || resetAtUnix <= 0 {
		return nil
	}

	hint := new(WeeklyCycleHint)
	hint.WindowMinutes = windowMinutes
	hint.ResetAt = time.Unix(resetAtUnix, 0).UTC()
	return hint
}
// extractWeeklyCycleHint returns the weekly cycle hint carried by the response
// headers. The limit named by "x-codex-active-limit" is consulted first; when
// that header is blank or yields no hint, the "codex" limit is used instead.
// The result is nil when neither lookup produces a hint.
func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint {
	active := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
	if active == "" {
		return weeklyCycleHintForLimit(headers, "codex")
	}
	hint := weeklyCycleHintForLimit(headers, active)
	if hint == nil {
		hint = weeklyCycleHintForLimit(headers, "codex")
	}
	return hint
}
type Service struct {
boxService.Adapter
ctx context.Context
logger log.ContextLogger
options option.OCMServiceOptions
httpHeaders http.Header
listener *listener.Listener
tlsConfig tls.ServerConfig
httpServer *http.Server
userManager *UserManager
webSocketMutex sync.Mutex
webSocketGroup sync.WaitGroup
webSocketConns map[*webSocketSession]struct{}
shuttingDown bool
ctx context.Context
logger log.ContextLogger
options option.OCMServiceOptions
httpHeaders http.Header
listener *listener.Listener
tlsConfig tls.ServerConfig
httpServer *http.Server
userManager *UserManager
webSocketAccess sync.Mutex
webSocketGroup sync.WaitGroup
webSocketConns map[*webSocketSession]struct{}
shuttingDown bool
// Legacy mode
legacyCredential *defaultCredential
@@ -361,562 +303,6 @@ func (s *Service) Start(stage adapter.StartStage) error {
return nil
}
// resolveCredentialProvider returns the credential provider that should serve
// username. When users are configured the lookup is delegated to
// credentialForUser; otherwise the provider comes from
// noUserCredentialProvider, and an error is returned if that yields nil.
func (s *Service) resolveCredentialProvider(username string) (credentialProvider, error) {
	if len(s.options.Users) == 0 {
		if provider := noUserCredentialProvider(s.providers, s.legacyProvider, s.options); provider != nil {
			return provider, nil
		}
		return nil, E.New("no credential available")
	}
	return credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
}
// ServeHTTP is the service's HTTP entry point.
//
// Requests to "/ocm/v1/status" and "/ocm/v1/reverse" are handed to their
// dedicated handlers. Every other request must target a "/v1/" path and is
// proxied upstream through a credential chosen by the resolved provider:
//
//  1. When users are configured, the caller is authenticated via a
//     "Authorization: Bearer <token>" header.
//  2. A credential provider is resolved and a credential selected for the
//     request's "session_id" header.
//  3. WebSocket upgrades on "/v1/responses" are delegated to handleWebSocket.
//  4. JSON request bodies are buffered when usage tracking or a retry may
//     need them.
//  5. The request is sent upstream; on HTTP 429 it is transparently retried
//     with the next credential offered by the provider.
//  6. The upstream response is relayed to the client, with usage tracking
//     where a usage tracker is available.
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	ctx := log.ContextWithNewID(r.Context())
	if r.URL.Path == "/ocm/v1/status" {
		s.handleStatusEndpoint(w, r)
		return
	}

	if r.URL.Path == "/ocm/v1/reverse" {
		s.handleReverseConnect(ctx, w, r)
		return
	}

	path := r.URL.Path
	if !strings.HasPrefix(path, "/v1/") {
		writeJSONError(w, r, http.StatusNotFound, "invalid_request_error", "path must start with /v1/")
		return
	}

	var username string
	if len(s.options.Users) > 0 {
		authHeader := r.Header.Get("Authorization")
		if authHeader == "" {
			s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
			writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
			return
		}
		clientToken := strings.TrimPrefix(authHeader, "Bearer ")
		if clientToken == authHeader {
			s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
			writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
			return
		}
		var ok bool
		username, ok = s.userManager.Authenticate(clientToken)
		if !ok {
			// The rejected token is intentionally not logged: it may be a
			// mistyped but otherwise valid secret, and log output should
			// never contain credentials.
			s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key")
			writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
			return
		}
	}

	sessionID := r.Header.Get("session_id")

	// Resolve credential provider and user config
	var provider credentialProvider
	var userConfig *option.OCMUser
	if len(s.options.Users) > 0 {
		userConfig = s.userConfigMap[username]
		var err error
		provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
		if err != nil {
			s.logger.ErrorContext(ctx, "resolve credential: ", err)
			writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
			return
		}
	} else {
		provider = noUserCredentialProvider(s.providers, s.legacyProvider, s.options)
	}
	if provider == nil {
		writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available")
		return
	}

	provider.pollIfStale(s.ctx)

	selection := credentialSelectionForUser(userConfig)

	selectedCredential, isNew, err := provider.selectCredential(sessionID, selection)
	if err != nil {
		writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, err.Error()))
		return
	}

	if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") {
		s.handleWebSocket(ctx, w, r, path, username, sessionID, userConfig, provider, selectedCredential, selection, isNew)
		return
	}

	// Local credentials that are not in API key mode cannot serve the chat
	// completions endpoint; external credentials are not checked here.
	if !selectedCredential.isExternal() && !selectedCredential.ocmIsAPIKeyMode() && path == "/v1/chat/completions" {
		writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
			"chat completions endpoint is only available in API key mode")
		return
	}

	shouldTrackUsage := selectedCredential.usageTrackerOrNil() != nil &&
		(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses"))
	canRetryRequest := len(provider.allCredentials()) > 1

	// Read body for model extraction and retry buffer when JSON replay is useful.
	var bodyBytes []byte
	var requestModel string
	var requestServiceTier string
	if r.Body != nil && (shouldTrackUsage || canRetryRequest) {
		mediaType, _, parseErr := mime.ParseMediaType(r.Header.Get("Content-Type"))
		isJSONRequest := parseErr == nil && (mediaType == "application/json" || strings.HasSuffix(mediaType, "+json"))
		if isJSONRequest {
			bodyBytes, err = io.ReadAll(r.Body)
			if err != nil {
				s.logger.ErrorContext(ctx, "read request body: ", err)
				writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body")
				return
			}
			var request struct {
				Model       string `json:"model"`
				ServiceTier string `json:"service_tier"`
			}
			// A body that is not valid JSON is still forwarded; only the
			// model/tier extraction is skipped.
			if json.Unmarshal(bodyBytes, &request) == nil {
				requestModel = request.Model
				requestServiceTier = request.ServiceTier
			}
			// Restore the body so it can be read again downstream.
			r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
		}
	}

	if isNew {
		logParts := []any{"assigned credential ", selectedCredential.tagName()}
		if sessionID != "" {
			logParts = append(logParts, " for session ", sessionID)
		}
		if username != "" {
			logParts = append(logParts, " by user ", username)
		}
		if requestModel != "" {
			logParts = append(logParts, ", model=", requestModel)
		}
		if requestServiceTier == "priority" {
			logParts = append(logParts, ", fast")
		}
		s.logger.DebugContext(ctx, logParts...)
	}

	requestContext := selectedCredential.wrapRequestContext(ctx)
	{
		// Capture the current context in its own variable: requestContext is
		// reassigned on retry, and the interrupt callback must cancel the
		// context it was registered for.
		currentRequestContext := requestContext
		requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() {
			currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
		}))
	}
	// The closure reads requestContext at return time, so the context of the
	// last attempt is the one cancelled here.
	defer func() {
		requestContext.cancelRequest()
	}()
	proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
	if err != nil {
		s.logger.ErrorContext(ctx, "create proxy request: ", err)
		writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
		return
	}

	response, err := selectedCredential.httpTransport().Do(proxyRequest)
	if err != nil {
		// The client went away; there is nobody left to answer.
		if r.Context().Err() != nil {
			return
		}
		if requestContext.Err() != nil {
			writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable while processing the request")
			return
		}
		writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
		return
	}
	requestContext.releaseCredentialInterrupt()

	// Transparent 429 retry
	for response.StatusCode == http.StatusTooManyRequests {
		resetAt := parseOCMRateLimitResetFromHeaders(response.Header)
		nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection)
		needsBodyReplay := r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodDelete
		selectedCredential.updateStateFromHeaders(response.Header)
		// Give up when there is no other credential, or when the request
		// carries a body that was not buffered and so cannot be replayed.
		if (needsBodyReplay && bodyBytes == nil) || nextCredential == nil {
			response.Body.Close()
			writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited")
			return
		}
		response.Body.Close()
		s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
		requestContext.cancelRequest()
		requestContext = nextCredential.wrapRequestContext(ctx)
		{
			currentRequestContext := requestContext
			requestContext.addInterruptLink(provider.linkProviderInterrupt(nextCredential, selection, func() {
				currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
			}))
		}
		retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
		if buildErr != nil {
			s.logger.ErrorContext(ctx, "retry request: ", buildErr)
			writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
			return
		}
		retryResponse, retryErr := nextCredential.httpTransport().Do(retryRequest)
		if retryErr != nil {
			if r.Context().Err() != nil {
				return
			}
			if requestContext.Err() != nil {
				writeCredentialUnavailableError(w, r, provider, nextCredential, selection, "credential became unavailable while retrying the request")
				return
			}
			s.logger.ErrorContext(ctx, "retry request: ", retryErr)
			writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error())
			return
		}
		requestContext.releaseCredentialInterrupt()
		response = retryResponse
		selectedCredential = nextCredential
	}
	defer response.Body.Close()

	selectedCredential.updateStateFromHeaders(response.Header)

	if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
		// Best effort: a read failure only shortens the diagnostic text.
		body, _ := io.ReadAll(response.Body)
		s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
		go selectedCredential.pollUsage(s.ctx)
		writeJSONError(w, r, http.StatusInternalServerError, "api_error",
			"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
		return
	}

	// Rewrite response headers for external users
	if userConfig != nil && userConfig.ExternalCredential != "" {
		s.rewriteResponseHeadersForExternalUser(response.Header, userConfig)
	}

	for key, values := range response.Header {
		if !isHopByHopHeader(key) && !isReverseProxyHeader(key) {
			w.Header()[key] = values
		}
	}
	w.WriteHeader(response.StatusCode)

	usageTracker := selectedCredential.usageTrackerOrNil()
	if usageTracker != nil && response.StatusCode == http.StatusOK &&
		(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) {
		s.handleResponseWithTracking(ctx, w, response, usageTracker, path, requestModel, username)
	} else {
		mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
		// Anything with a parsable content type other than an event stream
		// is copied in one go; event streams and responses without a usable
		// content type are relayed chunk by chunk with a flush after each.
		if err == nil && mediaType != "text/event-stream" {
			_, _ = io.Copy(w, response.Body)
			return
		}
		flusher, ok := w.(http.Flusher)
		if !ok {
			s.logger.ErrorContext(ctx, "streaming not supported")
			return
		}
		buffer := make([]byte, buf.BufferSize)
		for {
			n, err := response.Body.Read(buffer)
			if n > 0 {
				_, writeError := w.Write(buffer[:n])
				if writeError != nil {
					s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
					return
				}
				flusher.Flush()
			}
			if err != nil {
				return
			}
		}
	}
}
// handleResponseWithTracking relays the upstream response body to writer while
// extracting token usage from it and recording that usage on usageTracker.
//
// path selects the payload schema: "/v1/chat/completions" is decoded as chat
// completion objects, anything else as Responses API objects. requestModel is
// the fallback model name used when the response does not carry one, and
// username is passed through to the usage record.
//
// Non-streaming bodies are read in full, inspected and then written out.
// Streaming (text/event-stream) bodies are forwarded chunk by chunk while the
// "data: " lines are parsed on the side; usage is recorded once the upstream
// read returns an error (including io.EOF).
func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, path string, requestModel string, username string) {
	isChatCompletions := path == "/v1/chat/completions"
	weeklyCycleHint := extractWeeklyCycleHint(response.Header)
	mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
	isStreaming := err == nil && mediaType == "text/event-stream"
	// A Responses API reply without any Content-Type is treated as a stream.
	if !isStreaming && !isChatCompletions && response.Header.Get("Content-Type") == "" {
		isStreaming = true
	}
	if !isStreaming {
		bodyBytes, err := io.ReadAll(response.Body)
		if err != nil {
			s.logger.ErrorContext(ctx, "read response body: ", err)
			return
		}

		var responseModel, serviceTier string
		var inputTokens, outputTokens, cachedTokens int64

		// Decode failures are tolerated: the body is still forwarded below,
		// only the usage record is skipped.
		if isChatCompletions {
			var chatCompletion openai.ChatCompletion
			if json.Unmarshal(bodyBytes, &chatCompletion) == nil {
				responseModel = chatCompletion.Model
				serviceTier = string(chatCompletion.ServiceTier)
				inputTokens = chatCompletion.Usage.PromptTokens
				outputTokens = chatCompletion.Usage.CompletionTokens
				cachedTokens = chatCompletion.Usage.PromptTokensDetails.CachedTokens
			}
		} else {
			var responsesResponse responses.Response
			if json.Unmarshal(bodyBytes, &responsesResponse) == nil {
				responseModel = string(responsesResponse.Model)
				serviceTier = string(responsesResponse.ServiceTier)
				inputTokens = responsesResponse.Usage.InputTokens
				outputTokens = responsesResponse.Usage.OutputTokens
				cachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
			}
		}

		// Usage is recorded only when some tokens were reported and a model
		// name is known (from the response or, failing that, the request).
		if inputTokens > 0 || outputTokens > 0 {
			if responseModel == "" {
				responseModel = requestModel
			}
			if responseModel != "" {
				contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens)
				usageTracker.AddUsageWithCycleHint(
					responseModel,
					contextWindow,
					inputTokens,
					outputTokens,
					cachedTokens,
					serviceTier,
					username,
					time.Now(),
					weeklyCycleHint,
				)
			}
		}

		_, _ = writer.Write(bodyBytes)
		return
	}

	flusher, ok := writer.(http.Flusher)
	if !ok {
		s.logger.ErrorContext(ctx, "streaming not supported")
		return
	}

	var inputTokens, outputTokens, cachedTokens int64
	var responseModel, serviceTier string
	buffer := make([]byte, buf.BufferSize)
	// leftover holds the trailing partial line of the previous chunk so that
	// an event split across two reads is parsed as a whole.
	var leftover []byte

	for {
		n, err := response.Body.Read(buffer)
		if n > 0 {
			data := append(leftover, buffer[:n]...)
			lines := bytes.Split(data, []byte("\n"))

			if err == nil {
				// More data may follow: keep the last (possibly incomplete)
				// segment for the next iteration.
				leftover = lines[len(lines)-1]
				lines = lines[:len(lines)-1]
			} else {
				// Final chunk: every segment is processed now.
				leftover = nil
			}

			for _, line := range lines {
				line = bytes.TrimSpace(line)
				if len(line) == 0 {
					continue
				}

				if bytes.HasPrefix(line, []byte("data: ")) {
					eventData := bytes.TrimPrefix(line, []byte("data: "))
					if bytes.Equal(eventData, []byte("[DONE]")) {
						continue
					}

					if isChatCompletions {
						var chatChunk openai.ChatCompletionChunk
						if json.Unmarshal(eventData, &chatChunk) == nil {
							// Later chunks overwrite earlier values, so the
							// last non-empty value seen is the one recorded.
							if chatChunk.Model != "" {
								responseModel = chatChunk.Model
							}
							if chatChunk.ServiceTier != "" {
								serviceTier = string(chatChunk.ServiceTier)
							}
							if chatChunk.Usage.PromptTokens > 0 {
								inputTokens = chatChunk.Usage.PromptTokens
								cachedTokens = chatChunk.Usage.PromptTokensDetails.CachedTokens
							}
							if chatChunk.Usage.CompletionTokens > 0 {
								outputTokens = chatChunk.Usage.CompletionTokens
							}
						}
					} else {
						var streamEvent responses.ResponseStreamEventUnion
						if json.Unmarshal(eventData, &streamEvent) == nil {
							// Only the "response.completed" event is inspected
							// for model, tier and usage.
							if streamEvent.Type == "response.completed" {
								completedEvent := streamEvent.AsResponseCompleted()
								if string(completedEvent.Response.Model) != "" {
									responseModel = string(completedEvent.Response.Model)
								}
								if completedEvent.Response.ServiceTier != "" {
									serviceTier = string(completedEvent.Response.ServiceTier)
								}
								if completedEvent.Response.Usage.InputTokens > 0 {
									inputTokens = completedEvent.Response.Usage.InputTokens
									cachedTokens = completedEvent.Response.Usage.InputTokensDetails.CachedTokens
								}
								if completedEvent.Response.Usage.OutputTokens > 0 {
									outputTokens = completedEvent.Response.Usage.OutputTokens
								}
							}
						}
					}
				}
			}

			// Forward the raw chunk exactly as received; parsing above never
			// alters what the client sees.
			_, writeError := writer.Write(buffer[:n])
			if writeError != nil {
				s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
				return
			}
			flusher.Flush()
		}

		if err != nil {
			// End of stream (or read failure): record whatever usage was
			// collected, then stop.
			if responseModel == "" {
				responseModel = requestModel
			}

			if inputTokens > 0 || outputTokens > 0 {
				if responseModel != "" {
					contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens)
					usageTracker.AddUsageWithCycleHint(
						responseModel,
						contextWindow,
						inputTokens,
						outputTokens,
						cachedTokens,
						serviceTier,
						username,
						time.Now(),
						weeklyCycleHint,
					)
				}
			}
			return
		}
	}
}
// handleStatusEndpoint answers GET requests with the caller's aggregated
// utilization as a JSON object containing "five_hour_utilization",
// "weekly_utilization" and "plan_weight".
//
// The endpoint is only available when users are configured and the request
// carries a valid "Authorization: Bearer <token>" header; every failure is
// reported through writeJSONError with a matching HTTP status.
func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
	fail := func(status int, errorType string, message string) {
		writeJSONError(w, r, status, errorType, message)
	}

	if r.Method != http.MethodGet {
		fail(http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed")
		return
	}
	if len(s.options.Users) == 0 {
		fail(http.StatusForbidden, "authentication_error", "status endpoint requires user authentication")
		return
	}

	authorization := r.Header.Get("Authorization")
	if authorization == "" {
		fail(http.StatusUnauthorized, "authentication_error", "missing api key")
		return
	}
	token, hasBearer := strings.CutPrefix(authorization, "Bearer ")
	if !hasBearer {
		fail(http.StatusUnauthorized, "authentication_error", "invalid api key format")
		return
	}
	username, authenticated := s.userManager.Authenticate(token)
	if !authenticated {
		fail(http.StatusUnauthorized, "authentication_error", "invalid api key")
		return
	}

	userConfig := s.userConfigMap[username]
	if userConfig == nil {
		fail(http.StatusInternalServerError, "api_error", "user config not found")
		return
	}
	provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
	if err != nil {
		fail(http.StatusInternalServerError, "api_error", err.Error())
		return
	}

	// Refresh stale usage data before computing the figures.
	provider.pollIfStale(r.Context())
	fiveHour, weekly, weight := s.computeAggregatedUtilization(provider, userConfig)

	payload := map[string]float64{
		"five_hour_utilization": fiveHour,
		"weekly_utilization":    weekly,
		"plan_weight":           weight,
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	json.NewEncoder(w).Encode(payload)
}
// computeAggregatedUtilization returns the plan-weighted average five-hour and
// weekly utilization across the provider's credentials, followed by the sum of
// the plan weights that contributed.
//
// Credentials are skipped when they are unavailable, when they are the user's
// own external credential, or when they are external and the user is not
// allowed external usage. Per-credential remaining headroom (cap minus
// utilization) is floored at zero. When nothing contributes, the result is
// (100, 100, 0).
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) (float64, float64, float64) {
	floorAtZero := func(value float64) float64 {
		if value < 0 {
			return 0
		}
		return value
	}

	var (
		weightSum        float64
		fiveHourHeadroom float64
		weeklyHeadroom   float64
	)
	for _, cred := range provider.allCredentials() {
		if !cred.isAvailable() {
			continue
		}
		isOwnExternal := userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential
		if isOwnExternal {
			continue
		}
		if !userConfig.AllowExternalUsage && cred.isExternal() {
			continue
		}

		weight := cred.planWeight()
		fiveHourHeadroom += floorAtZero(cred.fiveHourCap()-cred.fiveHourUtilization()) * weight
		weeklyHeadroom += floorAtZero(cred.weeklyCap()-cred.weeklyUtilization()) * weight
		weightSum += weight
	}

	if weightSum == 0 {
		return 100, 100, 0
	}
	fiveHour := 100 - fiveHourHeadroom/weightSum
	weekly := 100 - weeklyHeadroom/weightSum
	return fiveHour, weekly, weightSum
}
// rewriteResponseHeadersForExternalUser overwrites the primary and secondary
// "used-percent" headers of the active limit (or "codex" when the
// "x-codex-active-limit" header is blank) with the user's aggregated
// five-hour and weekly utilization. When the total plan weight is positive it
// is also published as "X-OCM-Plan-Weight". The headers are left untouched if
// the user's credential provider cannot be resolved.
func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.OCMUser) {
	provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, userConfig.Name)
	if err != nil {
		return
	}
	fiveHour, weekly, weight := s.computeAggregatedUtilization(provider, userConfig)

	limitName := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
	if limitName == "" {
		limitName = "codex"
	}

	formatPercent := func(value float64) string {
		return strconv.FormatFloat(value, 'f', 2, 64)
	}
	headers.Set("x-"+limitName+"-primary-used-percent", formatPercent(fiveHour))
	headers.Set("x-"+limitName+"-secondary-used-percent", formatPercent(weekly))

	if weight > 0 {
		headers.Set("X-OCM-Plan-Weight", strconv.FormatFloat(weight, 'f', -1, 64))
	}
}
func (s *Service) InterfaceUpdated() {
for _, cred := range s.allCredentials {
extCred, ok := cred.(*externalCredential)
@@ -952,8 +338,8 @@ func (s *Service) Close() error {
}
func (s *Service) registerWebSocketSession(session *webSocketSession) bool {
s.webSocketMutex.Lock()
defer s.webSocketMutex.Unlock()
s.webSocketAccess.Lock()
defer s.webSocketAccess.Unlock()
if s.shuttingDown {
return false
@@ -965,12 +351,12 @@ func (s *Service) registerWebSocketSession(session *webSocketSession) bool {
}
func (s *Service) unregisterWebSocketSession(session *webSocketSession) {
s.webSocketMutex.Lock()
s.webSocketAccess.Lock()
_, loaded := s.webSocketConns[session]
if loaded {
delete(s.webSocketConns, session)
}
s.webSocketMutex.Unlock()
s.webSocketAccess.Unlock()
if loaded {
s.webSocketGroup.Done()
@@ -978,28 +364,28 @@ func (s *Service) unregisterWebSocketSession(session *webSocketSession) {
}
func (s *Service) isShuttingDown() bool {
s.webSocketMutex.Lock()
defer s.webSocketMutex.Unlock()
s.webSocketAccess.Lock()
defer s.webSocketAccess.Unlock()
return s.shuttingDown
}
func (s *Service) interruptWebSocketSessionsForCredential(tag string) {
s.webSocketMutex.Lock()
s.webSocketAccess.Lock()
var toClose []*webSocketSession
for session := range s.webSocketConns {
if session.credentialTag == tag {
toClose = append(toClose, session)
}
}
s.webSocketMutex.Unlock()
s.webSocketAccess.Unlock()
for _, session := range toClose {
session.Close()
}
}
func (s *Service) startWebSocketShutdown() []*webSocketSession {
s.webSocketMutex.Lock()
defer s.webSocketMutex.Unlock()
s.webSocketAccess.Lock()
defer s.webSocketAccess.Unlock()
s.shuttingDown = true

View File

@@ -0,0 +1,504 @@
package ocm
import (
"bytes"
"context"
"encoding/json"
"io"
"mime"
"net/http"
"strconv"
"strings"
"time"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/openai/openai-go/v3"
"github.com/openai/openai-go/v3/responses"
)
// weeklyCycleHintForLimit derives a WeeklyCycleHint from the
// "x-<limit>-secondary-window-minutes" and "x-<limit>-secondary-reset-at"
// headers, where <limit> is the normalised limitIdentifier.
//
// It returns nil when the identifier normalises to the empty string, when
// either header is missing or unparsable, or when either value is not
// strictly positive. The reset header is interpreted as Unix seconds and the
// resulting time is expressed in UTC.
func weeklyCycleHintForLimit(headers http.Header, limitIdentifier string) *WeeklyCycleHint {
	identifier := normalizeRateLimitIdentifier(limitIdentifier)
	if identifier == "" {
		return nil
	}

	windowMinutes, windowPresent := parseInt64Header(headers, "x-"+identifier+"-secondary-window-minutes")
	resetAtUnix, resetPresent := parseInt64Header(headers, "x-"+identifier+"-secondary-reset-at")
	if !windowPresent || !resetPresent {
		return nil
	}
	if windowMinutes <= 0 || resetAtUnix <= 0 {
		return nil
	}

	hint := new(WeeklyCycleHint)
	hint.WindowMinutes = windowMinutes
	hint.ResetAt = time.Unix(resetAtUnix, 0).UTC()
	return hint
}
// extractWeeklyCycleHint returns the weekly cycle hint carried by the response
// headers. The limit named by "x-codex-active-limit" is consulted first; when
// that header is blank or yields no hint, the "codex" limit is used instead.
// The result is nil when neither lookup produces a hint.
func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint {
	active := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
	if active == "" {
		return weeklyCycleHintForLimit(headers, "codex")
	}
	hint := weeklyCycleHintForLimit(headers, active)
	if hint == nil {
		hint = weeklyCycleHintForLimit(headers, "codex")
	}
	return hint
}
// resolveCredentialProvider returns the credential provider that should serve
// username. When users are configured the lookup is delegated to
// credentialForUser; otherwise the provider comes from
// noUserCredentialProvider, and an error is returned if that yields nil.
func (s *Service) resolveCredentialProvider(username string) (credentialProvider, error) {
	if len(s.options.Users) == 0 {
		if provider := noUserCredentialProvider(s.providers, s.legacyProvider, s.options); provider != nil {
			return provider, nil
		}
		return nil, E.New("no credential available")
	}
	return credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
}
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := log.ContextWithNewID(r.Context())
if r.URL.Path == "/ocm/v1/status" {
s.handleStatusEndpoint(w, r)
return
}
if r.URL.Path == "/ocm/v1/reverse" {
s.handleReverseConnect(ctx, w, r)
return
}
path := r.URL.Path
if !strings.HasPrefix(path, "/v1/") {
writeJSONError(w, r, http.StatusNotFound, "invalid_request_error", "path must start with /v1/")
return
}
var username string
if len(s.options.Users) > 0 {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
return
}
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
if clientToken == authHeader {
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
return
}
var ok bool
username, ok = s.userManager.Authenticate(clientToken)
if !ok {
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
return
}
}
sessionID := r.Header.Get("session_id")
// Resolve credential provider and user config
var provider credentialProvider
var userConfig *option.OCMUser
if len(s.options.Users) > 0 {
userConfig = s.userConfigMap[username]
var err error
provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
if err != nil {
s.logger.ErrorContext(ctx, "resolve credential: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
return
}
} else {
provider = noUserCredentialProvider(s.providers, s.legacyProvider, s.options)
}
if provider == nil {
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available")
return
}
provider.pollIfStale(s.ctx)
selection := credentialSelectionForUser(userConfig)
selectedCredential, isNew, err := provider.selectCredential(sessionID, selection)
if err != nil {
writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, err.Error()))
return
}
if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") {
s.handleWebSocket(ctx, w, r, path, username, sessionID, userConfig, provider, selectedCredential, selection, isNew)
return
}
if !selectedCredential.isExternal() && selectedCredential.ocmIsAPIKeyMode() {
// API key mode path handling
} else if !selectedCredential.isExternal() {
if path == "/v1/chat/completions" {
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
"chat completions endpoint is only available in API key mode")
return
}
}
shouldTrackUsage := selectedCredential.usageTrackerOrNil() != nil &&
(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses"))
canRetryRequest := len(provider.allCredentials()) > 1
// Read body for model extraction and retry buffer when JSON replay is useful.
var bodyBytes []byte
var requestModel string
var requestServiceTier string
if r.Body != nil && (shouldTrackUsage || canRetryRequest) {
mediaType, _, parseErr := mime.ParseMediaType(r.Header.Get("Content-Type"))
isJSONRequest := parseErr == nil && (mediaType == "application/json" || strings.HasSuffix(mediaType, "+json"))
if isJSONRequest {
bodyBytes, err = io.ReadAll(r.Body)
if err != nil {
s.logger.ErrorContext(ctx, "read request body: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body")
return
}
var request struct {
Model string `json:"model"`
ServiceTier string `json:"service_tier"`
}
if json.Unmarshal(bodyBytes, &request) == nil {
requestModel = request.Model
requestServiceTier = request.ServiceTier
}
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
}
}
if isNew {
logParts := []any{"assigned credential ", selectedCredential.tagName()}
if sessionID != "" {
logParts = append(logParts, " for session ", sessionID)
}
if username != "" {
logParts = append(logParts, " by user ", username)
}
if requestModel != "" {
logParts = append(logParts, ", model=", requestModel)
}
if requestServiceTier == "priority" {
logParts = append(logParts, ", fast")
}
s.logger.DebugContext(ctx, logParts...)
}
requestContext := selectedCredential.wrapRequestContext(ctx)
{
currentRequestContext := requestContext
requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() {
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
}))
}
defer func() {
requestContext.cancelRequest()
}()
proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if err != nil {
s.logger.ErrorContext(ctx, "create proxy request: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
return
}
response, err := selectedCredential.httpClient().Do(proxyRequest)
if err != nil {
if r.Context().Err() != nil {
return
}
if requestContext.Err() != nil {
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable while processing the request")
return
}
writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
return
}
requestContext.releaseCredentialInterrupt()
// Transparent 429 retry
for response.StatusCode == http.StatusTooManyRequests {
resetAt := parseOCMRateLimitResetFromHeaders(response.Header)
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection)
needsBodyReplay := r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodDelete
selectedCredential.updateStateFromHeaders(response.Header)
if (needsBodyReplay && bodyBytes == nil) || nextCredential == nil {
response.Body.Close()
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited")
return
}
response.Body.Close()
s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
requestContext.cancelRequest()
requestContext = nextCredential.wrapRequestContext(ctx)
{
currentRequestContext := requestContext
requestContext.addInterruptLink(provider.linkProviderInterrupt(nextCredential, selection, func() {
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
}))
}
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if buildErr != nil {
s.logger.ErrorContext(ctx, "retry request: ", buildErr)
writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
return
}
retryResponse, retryErr := nextCredential.httpClient().Do(retryRequest)
if retryErr != nil {
if r.Context().Err() != nil {
return
}
if requestContext.Err() != nil {
writeCredentialUnavailableError(w, r, provider, nextCredential, selection, "credential became unavailable while retrying the request")
return
}
s.logger.ErrorContext(ctx, "retry request: ", retryErr)
writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error())
return
}
requestContext.releaseCredentialInterrupt()
response = retryResponse
selectedCredential = nextCredential
}
defer response.Body.Close()
selectedCredential.updateStateFromHeaders(response.Header)
if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
body, _ := io.ReadAll(response.Body)
s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
go selectedCredential.pollUsage(s.ctx)
writeJSONError(w, r, http.StatusInternalServerError, "api_error",
"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
return
}
// Rewrite response headers for external users
if userConfig != nil && userConfig.ExternalCredential != "" {
s.rewriteResponseHeadersForExternalUser(response.Header, userConfig)
}
for key, values := range response.Header {
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) {
w.Header()[key] = values
}
}
w.WriteHeader(response.StatusCode)
usageTracker := selectedCredential.usageTrackerOrNil()
if usageTracker != nil && response.StatusCode == http.StatusOK &&
(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) {
s.handleResponseWithTracking(ctx, w, response, usageTracker, path, requestModel, username)
} else {
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
if err == nil && mediaType != "text/event-stream" {
_, _ = io.Copy(w, response.Body)
return
}
flusher, ok := w.(http.Flusher)
if !ok {
s.logger.ErrorContext(ctx, "streaming not supported")
return
}
buffer := make([]byte, buf.BufferSize)
for {
n, err := response.Body.Read(buffer)
if n > 0 {
_, writeError := w.Write(buffer[:n])
if writeError != nil {
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
return
}
flusher.Flush()
}
if err != nil {
return
}
}
}
}
// handleResponseWithTracking forwards the upstream response body to writer
// while extracting token usage from it and recording that usage on
// usageTracker.
//
// path selects the payload schema: "/v1/chat/completions" is decoded as chat
// completion objects, every other path as Responses API objects. requestModel
// is used as a fallback when the response does not carry a model name, and
// username is passed through to the usage tracker unchanged.
//
// The function does not write status or headers itself; it only writes body
// bytes (and flushes in the streaming case).
func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, path string, requestModel string, username string) {
isChatCompletions := path == "/v1/chat/completions"
weeklyCycleHint := extractWeeklyCycleHint(response.Header)
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
isStreaming := err == nil && mediaType == "text/event-stream"
// For non-chat-completions paths a response without any Content-Type header
// is also handled as a stream.
if !isStreaming && !isChatCompletions && response.Header.Get("Content-Type") == "" {
isStreaming = true
}
// Non-streaming: buffer the whole body, extract usage, then write it out.
if !isStreaming {
bodyBytes, err := io.ReadAll(response.Body)
if err != nil {
// Nothing is written to the client when the upstream body cannot be read.
s.logger.ErrorContext(ctx, "read response body: ", err)
return
}
var responseModel, serviceTier string
var inputTokens, outputTokens, cachedTokens int64
// A body that fails to decode is still forwarded below; it just records no usage.
if isChatCompletions {
var chatCompletion openai.ChatCompletion
if json.Unmarshal(bodyBytes, &chatCompletion) == nil {
responseModel = chatCompletion.Model
serviceTier = string(chatCompletion.ServiceTier)
inputTokens = chatCompletion.Usage.PromptTokens
outputTokens = chatCompletion.Usage.CompletionTokens
cachedTokens = chatCompletion.Usage.PromptTokensDetails.CachedTokens
}
} else {
var responsesResponse responses.Response
if json.Unmarshal(bodyBytes, &responsesResponse) == nil {
responseModel = string(responsesResponse.Model)
serviceTier = string(responsesResponse.ServiceTier)
inputTokens = responsesResponse.Usage.InputTokens
outputTokens = responsesResponse.Usage.OutputTokens
cachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
}
}
// Usage is recorded only when at least one token count is positive and a
// model name is known (from the response, else from the request).
if inputTokens > 0 || outputTokens > 0 {
if responseModel == "" {
responseModel = requestModel
}
if responseModel != "" {
contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens)
usageTracker.AddUsageWithCycleHint(
responseModel,
contextWindow,
inputTokens,
outputTokens,
cachedTokens,
serviceTier,
username,
time.Now(),
weeklyCycleHint,
)
}
}
// Write errors are deliberately ignored here.
_, _ = writer.Write(bodyBytes)
return
}
// Streaming: every chunk must be flushed to the client as it arrives.
flusher, ok := writer.(http.Flusher)
if !ok {
s.logger.ErrorContext(ctx, "streaming not supported")
return
}
var inputTokens, outputTokens, cachedTokens int64
var responseModel, serviceTier string
buffer := make([]byte, buf.BufferSize)
// leftover holds the trailing partial line of the previous chunk so that a
// line split across two reads is parsed as one line.
var leftover []byte
for {
n, err := response.Body.Read(buffer)
if n > 0 {
data := append(leftover, buffer[:n]...)
lines := bytes.Split(data, []byte("\n"))
if err == nil {
// The last element may be an incomplete line; defer it to the next read.
leftover = lines[len(lines)-1]
lines = lines[:len(lines)-1]
} else {
// No further read will follow, so the last element is parsed now as well.
leftover = nil
}
for _, line := range lines {
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
// Only lines starting with "data: " are inspected; the "[DONE]" payload is skipped.
if bytes.HasPrefix(line, []byte("data: ")) {
eventData := bytes.TrimPrefix(line, []byte("data: "))
if bytes.Equal(eventData, []byte("[DONE]")) {
continue
}
if isChatCompletions {
var chatChunk openai.ChatCompletionChunk
if json.Unmarshal(eventData, &chatChunk) == nil {
// Later chunks overwrite earlier values only when they carry non-empty /
// positive data, so the last meaningful value wins.
if chatChunk.Model != "" {
responseModel = chatChunk.Model
}
if chatChunk.ServiceTier != "" {
serviceTier = string(chatChunk.ServiceTier)
}
if chatChunk.Usage.PromptTokens > 0 {
inputTokens = chatChunk.Usage.PromptTokens
cachedTokens = chatChunk.Usage.PromptTokensDetails.CachedTokens
}
if chatChunk.Usage.CompletionTokens > 0 {
outputTokens = chatChunk.Usage.CompletionTokens
}
}
} else {
var streamEvent responses.ResponseStreamEventUnion
if json.Unmarshal(eventData, &streamEvent) == nil {
// Only "response.completed" events are read for model, tier and usage.
if streamEvent.Type == "response.completed" {
completedEvent := streamEvent.AsResponseCompleted()
if string(completedEvent.Response.Model) != "" {
responseModel = string(completedEvent.Response.Model)
}
if completedEvent.Response.ServiceTier != "" {
serviceTier = string(completedEvent.Response.ServiceTier)
}
if completedEvent.Response.Usage.InputTokens > 0 {
inputTokens = completedEvent.Response.Usage.InputTokens
cachedTokens = completedEvent.Response.Usage.InputTokensDetails.CachedTokens
}
if completedEvent.Response.Usage.OutputTokens > 0 {
outputTokens = completedEvent.Response.Usage.OutputTokens
}
}
}
}
}
}
// The raw bytes of this read are forwarded unmodified, independent of parsing.
_, writeError := writer.Write(buffer[:n])
if writeError != nil {
// Returning here skips the usage recording below.
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
return
}
flusher.Flush()
}
// Any read error (including end of stream) ends the loop; usage collected so
// far is recorded under the same conditions as in the non-streaming branch.
if err != nil {
if responseModel == "" {
responseModel = requestModel
}
if inputTokens > 0 || outputTokens > 0 {
if responseModel != "" {
contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens)
usageTracker.AddUsageWithCycleHint(
responseModel,
contextWindow,
inputTokens,
outputTokens,
cachedTokens,
serviceTier,
username,
time.Now(),
weeklyCycleHint,
)
}
}
return
}
}
}

View File

@@ -0,0 +1,114 @@
package ocm
import (
"encoding/json"
"net/http"
"strconv"
"strings"
"github.com/sagernet/sing-box/option"
)
// handleStatusEndpoint reports the aggregated credential utilization visible
// to the authenticated user as a JSON object with the keys
// "five_hour_utilization", "weekly_utilization" and "plan_weight".
//
// Only GET is accepted. The endpoint is refused with 403 when no users are
// configured, and with 401 when the Authorization header is missing, is not
// of the form "Bearer <token>", or carries an unknown token.
func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed")
		return
	}
	if len(s.options.Users) == 0 {
		writeJSONError(w, r, http.StatusForbidden, "authentication_error", "status endpoint requires user authentication")
		return
	}

	authHeader := r.Header.Get("Authorization")
	if authHeader == "" {
		writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
		return
	}
	clientToken, hasBearerPrefix := strings.CutPrefix(authHeader, "Bearer ")
	if !hasBearerPrefix {
		writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
		return
	}
	username, ok := s.userManager.Authenticate(clientToken)
	if !ok {
		writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
		return
	}

	userConfig := s.userConfigMap[username]
	if userConfig == nil {
		writeJSONError(w, r, http.StatusInternalServerError, "api_error", "user config not found")
		return
	}
	provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
	if err != nil {
		writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
		return
	}

	provider.pollIfStale(r.Context())
	avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	err = json.NewEncoder(w).Encode(map[string]float64{
		"five_hour_utilization": avgFiveHour,
		"weekly_utilization":    avgWeekly,
		"plan_weight":           totalWeight,
	})
	if err != nil {
		// The status line has already been sent, so the failure can only be logged.
		s.logger.ErrorContext(r.Context(), "write status response: ", err)
	}
}
// computeAggregatedUtilization returns the plan-weight-weighted average
// five-hour and weekly utilization over the provider's credentials, followed
// by the total plan weight that went into the average.
//
// A credential is left out when it is unavailable, when its tag equals the
// user's own ExternalCredential, or when it is external and the user is not
// allowed external usage. Per credential, the remaining headroom
// (cap - utilization) is clamped at zero before weighting. When no credential
// contributes any weight the result is (100, 100, 0).
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) (float64, float64, float64) {
	var (
		fiveHourRemainingSum float64
		weeklyRemainingSum   float64
		weightSum            float64
	)
	for _, candidate := range provider.allCredentials() {
		switch {
		case !candidate.isAvailable():
			continue
		case userConfig.ExternalCredential != "" && candidate.tagName() == userConfig.ExternalCredential:
			continue
		case !userConfig.AllowExternalUsage && candidate.isExternal():
			continue
		}

		planWeight := candidate.planWeight()

		fiveHourRemaining := candidate.fiveHourCap() - candidate.fiveHourUtilization()
		if fiveHourRemaining < 0 {
			fiveHourRemaining = 0
		}
		weeklyRemaining := candidate.weeklyCap() - candidate.weeklyUtilization()
		if weeklyRemaining < 0 {
			weeklyRemaining = 0
		}

		fiveHourRemainingSum += fiveHourRemaining * planWeight
		weeklyRemainingSum += weeklyRemaining * planWeight
		weightSum += planWeight
	}

	if weightSum != 0 {
		fiveHourUsed := 100 - fiveHourRemainingSum/weightSum
		weeklyUsed := 100 - weeklyRemainingSum/weightSum
		return fiveHourUsed, weeklyUsed, weightSum
	}
	// Nothing usable: report everything as fully used with zero weight.
	return 100, 100, 0
}
// rewriteResponseHeadersForExternalUser overwrites the upstream usage headers
// with the aggregated utilization computed for userConfig.
//
// The header family is taken from the normalized "x-codex-active-limit"
// response header and defaults to "codex" when that is empty. The
// "X-OCM-Plan-Weight" header is set only when the total weight is positive.
// If no provider can be resolved for the user, headers are left untouched.
func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.OCMUser) {
	provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, userConfig.Name)
	if err != nil {
		return
	}
	fiveHourPercent, weeklyPercent, planWeight := s.computeAggregatedUtilization(provider, userConfig)

	limitName := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
	if limitName == "" {
		limitName = "codex"
	}
	headerPrefix := "x-" + limitName

	primaryValue := strconv.FormatFloat(fiveHourPercent, 'f', 2, 64)
	headers.Set(headerPrefix+"-primary-used-percent", primaryValue)
	secondaryValue := strconv.FormatFloat(weeklyPercent, 'f', 2, 64)
	headers.Set(headerPrefix+"-secondary-used-percent", secondaryValue)

	if planWeight > 0 {
		headers.Set("X-OCM-Plan-Weight", strconv.FormatFloat(planWeight, 'f', -1, 64))
	}
}

View File

@@ -7,13 +7,13 @@ import (
)
// UserManager keeps a lookup table from client token to user name, guarded by
// a read/write mutex.
// NOTE(review): this view is a rendered diff without +/- markers; accessMutex
// looks like the pre-rename line and access its replacement (the commit
// message lists accessMutex→access) — confirm only one of them exists in the
// resulting file.
type UserManager struct {
accessMutex sync.RWMutex
access sync.RWMutex
// tokenMap maps a user's token to that user's name.
tokenMap map[string]string
}
func (m *UserManager) UpdateUsers(users []option.OCMUser) {
m.accessMutex.Lock()
defer m.accessMutex.Unlock()
m.access.Lock()
defer m.access.Unlock()
tokenMap := make(map[string]string, len(users))
for _, user := range users {
tokenMap[user.Token] = user.Name
@@ -22,8 +22,8 @@ func (m *UserManager) UpdateUsers(users []option.OCMUser) {
}
// Authenticate looks token up in the token map under a read lock and returns
// the associated user name, plus whether the token was found.
// NOTE(review): the accessMutex and access lock lines below appear to be the
// old and new side of the same diff hunk (rename accessMutex→access) — confirm
// that only the access variant remains in the resulting file.
func (m *UserManager) Authenticate(token string) (string, bool) {
m.accessMutex.RLock()
m.access.RLock()
username, found := m.tokenMap[token]
m.accessMutex.RUnlock()
m.access.RUnlock()
return username, found
}