From 04bd63b45573eb4f81153a3158dc9f6de8f50155 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 14 Mar 2026 20:17:23 +0800 Subject: [PATCH] ccm,ocm: reorganize files and improve naming conventions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Split credential_state.go (1500+ lines) into credential.go, credential_default.go, credential_provider.go, credential_builder.go. Split service.go (900+ lines) into service.go, service_handler.go, service_status.go. Rename credential.go to credential_oauth.go to avoid name conflict with the credential interface. Apply naming fixes: accessMutex→access, stateMutex→stateAccess, sessionMutex→sessionAccess, webSocketMutex→webSocketAccess, httpTransport()→httpClient(), httpClient field→forwardHTTPClient, weeklyWindowDuration→weeklyWindowHours. --- service/ccm/credential.go | 353 +++---- service/ccm/credential_builder.go | 192 ++++ service/ccm/credential_default.go | 726 +++++++++++++ service/ccm/credential_external.go | 104 +- service/ccm/credential_file.go | 28 +- service/ccm/credential_oauth.go | 224 ++++ service/ccm/credential_provider.go | 405 ++++++++ service/ccm/credential_state.go | 1506 --------------------------- service/ccm/service.go | 570 +---------- service/ccm/service_handler.go | 499 +++++++++ service/ccm/service_status.go | 109 ++ service/ccm/service_user.go | 10 +- service/ocm/credential.go | 375 ++++--- service/ocm/credential_builder.go | 223 ++++ service/ocm/credential_default.go | 749 ++++++++++++++ service/ocm/credential_external.go | 106 +- service/ocm/credential_file.go | 28 +- service/ocm/credential_oauth.go | 225 ++++ service/ocm/credential_provider.go | 411 ++++++++ service/ocm/credential_state.go | 1524 ---------------------------- service/ocm/service.go | 658 +----------- service/ocm/service_handler.go | 504 +++++++++ service/ocm/service_status.go | 114 +++ service/ocm/service_user.go | 10 +- 24 files changed, 4877 insertions(+), 4776 deletions(-) create mode 100644 service/ccm/credential_builder.go create mode 100644 service/ccm/credential_default.go create mode 100644 service/ccm/credential_oauth.go create mode 100644 service/ccm/credential_provider.go delete mode 100644 service/ccm/credential_state.go create mode 100644 service/ccm/service_handler.go create mode 100644 service/ccm/service_status.go create mode 100644 service/ocm/credential_builder.go create mode 100644 service/ocm/credential_default.go create mode 100644 service/ocm/credential_oauth.go create mode 100644 service/ocm/credential_provider.go delete mode 100644 service/ocm/credential_state.go create mode 100644 service/ocm/service_handler.go create mode 100644 service/ocm/service_status.go diff --git a/service/ccm/credential.go b/service/ccm/credential.go index da559c173..8589676a8 100644 --- a/service/ccm/credential.go +++ b/service/ccm/credential.go @@ -1,224 +1,187 @@ package ccm import ( - "bytes" "context" - "encoding/json" - "io" "net/http" - "os" - "os/user" - "path/filepath" - "runtime" - "slices" + "strconv" "sync" "time" - - "github.com/sagernet/sing-box/log" - E "github.com/sagernet/sing/common/exceptions" ) const ( - oauth2ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e" - oauth2TokenURL = "https://platform.claude.com/v1/oauth/token" - claudeAPIBaseURL = "https://api.anthropic.com" - tokenRefreshBufferMs = 60000 - anthropicBetaOAuthValue = "oauth-2025-04-20" + defaultPollInterval = 60 * time.Minute + failedPollRetryInterval = time.Minute + httpRetryMaxBackoff = 5 * time.Minute ) -const ccmUserAgentFallback = "claude-code/2.1.72" - -var ( - ccmUserAgentOnce sync.Once - ccmUserAgentValue string +const ( + httpRetryMaxAttempts = 3 + httpRetryInitialDelay = 200 * time.Millisecond ) -func initCCMUserAgent(logger log.ContextLogger) { - ccmUserAgentOnce.Do(func() { - version, err := detectClaudeCodeVersion() - if err != nil { - logger.Error("detect Claude Code version: ", err) - ccmUserAgentValue = ccmUserAgentFallback - return +const sessionExpiry = 24 * time.Hour + +func doHTTPWithRetry(ctx context.Context, client *http.Client, buildRequest func() (*http.Request, error)) (*http.Response, error) { + var lastError error + for attempt := range httpRetryMaxAttempts { + if attempt > 0 { + delay := httpRetryInitialDelay * time.Duration(1<<(attempt-1)) + select { + case <-ctx.Done(): + return nil, lastError + case <-time.After(delay): + } } - logger.Debug("detected Claude Code version: ", version) - ccmUserAgentValue = "claude-code/" + version - }) -} - -func detectClaudeCodeVersion() (string, error) { - userInfo, err := getRealUser() - if err != nil { - return "", E.Cause(err, "get user") - } - binaryName := "claude" - if runtime.GOOS == "windows" { - binaryName = "claude.exe" - } - linkPath := filepath.Join(userInfo.HomeDir, ".local", "bin", binaryName) - target, err := os.Readlink(linkPath) - if err != nil { - return "", E.Cause(err, "readlink ", linkPath) - } - if !filepath.IsAbs(target) { - target = filepath.Join(filepath.Dir(linkPath), target) - } - parent := filepath.Base(filepath.Dir(target)) - if parent != "versions" { - return "", E.New("unexpected symlink target: ", target) - } - return filepath.Base(target), nil -} - -func getRealUser() (*user.User, error) { - if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" { - sudoUserInfo, err := user.Lookup(sudoUser) - if err == nil { - return sudoUserInfo, nil - } - } - return user.Current() -} - -func getDefaultCredentialsPath() (string, error) { - if configDir := os.Getenv("CLAUDE_CONFIG_DIR"); configDir != "" { - return filepath.Join(configDir, ".credentials.json"), nil - } - userInfo, err := getRealUser() - if err != nil { - return "", err - } - return filepath.Join(userInfo.HomeDir, ".claude", ".credentials.json"), nil -} - -func readCredentialsFromFile(path string) (*oauthCredentials, error) { - data, err := os.ReadFile(path) - if err != nil { - return nil, err - } - var credentialsContainer struct { - ClaudeAIAuth *oauthCredentials `json:"claudeAiOauth,omitempty"` - } - err = json.Unmarshal(data, &credentialsContainer) - if err != nil { - return nil, err - } - if credentialsContainer.ClaudeAIAuth == nil { - return nil, E.New("claudeAiOauth field not found in credentials") - } - return credentialsContainer.ClaudeAIAuth, nil -} - -func checkCredentialFileWritable(path string) error { - file, err := os.OpenFile(path, os.O_WRONLY, 0) - if err != nil { - return err - } - return file.Close() -} - -func writeCredentialsToFile(oauthCredentials *oauthCredentials, path string) error { - data, err := json.MarshalIndent(map[string]any{ - "claudeAiOauth": oauthCredentials, - }, "", " ") - if err != nil { - return err - } - return os.WriteFile(path, data, 0o600) -} - -type oauthCredentials struct { - AccessToken string `json:"accessToken"` - RefreshToken string `json:"refreshToken"` - ExpiresAt int64 `json:"expiresAt"` - Scopes []string `json:"scopes,omitempty"` - SubscriptionType string `json:"subscriptionType,omitempty"` - RateLimitTier string `json:"rateLimitTier,omitempty"` - IsMax bool `json:"isMax,omitempty"` -} - -func (c *oauthCredentials) needsRefresh() bool { - if c.ExpiresAt == 0 { - return false - } - return time.Now().UnixMilli() >= c.ExpiresAt-tokenRefreshBufferMs -} - -func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) { - if credentials.RefreshToken == "" { - return nil, E.New("refresh token is empty") - } - - requestBody, err := json.Marshal(map[string]string{ - "grant_type": "refresh_token", - "refresh_token": credentials.RefreshToken, - "client_id": oauth2ClientID, - }) - if err != nil { - return nil, E.Cause(err, "marshal request") - } - - response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { - request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody)) + request, err := buildRequest() if err != nil { return nil, err } - request.Header.Set("Content-Type", "application/json") - request.Header.Set("User-Agent", ccmUserAgentValue) - return request, nil + response, err := client.Do(request) + if err == nil { + return response, nil + } + lastError = err + if ctx.Err() != nil { + return nil, lastError + } + } + return nil, lastError +} + +type credentialState struct { + fiveHourUtilization float64 + fiveHourReset time.Time + weeklyUtilization float64 + weeklyReset time.Time + hardRateLimited bool + rateLimitResetAt time.Time + accountType string + rateLimitTier string + remotePlanWeight float64 + lastUpdated time.Time + consecutivePollFailures int + unavailable bool + lastCredentialLoadAttempt time.Time + lastCredentialLoadError string +} + +type credentialRequestContext struct { + context.Context + releaseOnce sync.Once + cancelOnce sync.Once + releaseFuncs []func() bool + cancelFunc context.CancelFunc +} + +func (c *credentialRequestContext) addInterruptLink(stop func() bool) { + c.releaseFuncs = append(c.releaseFuncs, stop) +} + +func (c *credentialRequestContext) releaseCredentialInterrupt() { + c.releaseOnce.Do(func() { + for _, f := range c.releaseFuncs { + f() + } }) +} + +func (c *credentialRequestContext) cancelRequest() { + c.releaseCredentialInterrupt() + c.cancelOnce.Do(c.cancelFunc) +} + +type credential interface { + tagName() string + isAvailable() bool + isUsable() bool + isExternal() bool + fiveHourUtilization() float64 + weeklyUtilization() float64 + fiveHourCap() float64 + weeklyCap() float64 + planWeight() float64 + weeklyResetTime() time.Time + markRateLimited(resetAt time.Time) + earliestReset() time.Time + unavailableError() error + + getAccessToken() (string, error) + buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) + updateStateFromHeaders(header http.Header) + + wrapRequestContext(ctx context.Context) *credentialRequestContext + interruptConnections() + + start() error + pollUsage(ctx context.Context) + lastUpdatedTime() time.Time + pollBackoff(base time.Duration) time.Duration + usageTrackerOrNil() *AggregatedUsage + httpClient() *http.Client + close() +} + +type credentialSelectionScope string + +const ( + credentialSelectionScopeAll credentialSelectionScope = "all" + credentialSelectionScopeNonExternal credentialSelectionScope = "non_external" +) + +type credentialSelection struct { + scope credentialSelectionScope + filter func(credential) bool +} + +func (s credentialSelection) allows(cred credential) bool { + return s.filter == nil || s.filter(cred) +} + +func (s credentialSelection) scopeOrDefault() credentialSelectionScope { + if s.scope == "" { + return credentialSelectionScopeAll + } + return s.scope +} + +// Claude Code's unified rate-limit handling parses these reset headers with +// Number(...), compares them against Date.now()/1000, and renders them via +// new Date(seconds*1000), so keep the wire format pinned to Unix epoch seconds. +func parseAnthropicResetHeaderValue(headerName string, headerValue string) time.Time { + unixEpoch, err := strconv.ParseInt(headerValue, 10, 64) if err != nil { - return nil, err + panic("invalid " + headerName + " header: expected Unix epoch seconds, got " + strconv.Quote(headerValue)) } - defer response.Body.Close() - - if response.StatusCode == http.StatusTooManyRequests { - body, _ := io.ReadAll(response.Body) - return nil, E.New("refresh rate limited: ", response.Status, " ", string(body)) + if unixEpoch <= 0 { + panic("invalid " + headerName + " header: expected positive Unix epoch seconds, got " + strconv.Quote(headerValue)) } - if response.StatusCode != http.StatusOK { - body, _ := io.ReadAll(response.Body) - return nil, E.New("refresh failed: ", response.Status, " ", string(body)) - } - - var tokenResponse struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - ExpiresIn int `json:"expires_in"` - } - err = json.NewDecoder(response.Body).Decode(&tokenResponse) - if err != nil { - return nil, E.Cause(err, "decode response") - } - - newCredentials := *credentials - newCredentials.AccessToken = tokenResponse.AccessToken - if tokenResponse.RefreshToken != "" { - newCredentials.RefreshToken = tokenResponse.RefreshToken - } - newCredentials.ExpiresAt = time.Now().UnixMilli() + int64(tokenResponse.ExpiresIn)*1000 - - return &newCredentials, nil + return time.Unix(unixEpoch, 0) } -func cloneCredentials(credentials *oauthCredentials) *oauthCredentials { - if credentials == nil { - return nil +func parseOptionalAnthropicResetHeader(headers http.Header, headerName string) (time.Time, bool) { + headerValue := headers.Get(headerName) + if headerValue == "" { + return time.Time{}, false } - cloned := *credentials - cloned.Scopes = append([]string(nil), credentials.Scopes...) - return &cloned + return parseAnthropicResetHeaderValue(headerName, headerValue), true } -func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool { - if left == nil || right == nil { - return left == right +func parseRequiredAnthropicResetHeader(headers http.Header, headerName string) time.Time { + headerValue := headers.Get(headerName) + if headerValue == "" { + panic("missing required " + headerName + " header") + } + return parseAnthropicResetHeaderValue(headerName, headerValue) +} + +func parseRateLimitResetFromHeaders(headers http.Header) time.Time { + claim := headers.Get("anthropic-ratelimit-unified-representative-claim") + switch claim { + case "5h": + return parseRequiredAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset") + case "7d": + return parseRequiredAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset") + default: + panic("invalid anthropic-ratelimit-unified-representative-claim header: " + strconv.Quote(claim)) } - return left.AccessToken == right.AccessToken && - left.RefreshToken == right.RefreshToken && - left.ExpiresAt == right.ExpiresAt && - slices.Equal(left.Scopes, right.Scopes) && - left.SubscriptionType == right.SubscriptionType && - left.RateLimitTier == right.RateLimitTier && - left.IsMax == right.IsMax } diff --git a/service/ccm/credential_builder.go b/service/ccm/credential_builder.go new file mode 100644 index 000000000..c49a20195 --- /dev/null +++ b/service/ccm/credential_builder.go @@ -0,0 +1,192 @@ +package ccm + +import ( + "context" + "time" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" +) + +func buildCredentialProviders( + ctx context.Context, + options option.CCMServiceOptions, + logger log.ContextLogger, +) (map[string]credentialProvider, []credential, error) { + allCredentialMap := make(map[string]credential) + var allCreds []credential + providers := make(map[string]credentialProvider) + + // Pass 1: create default and external credentials + for _, credOpt := range options.Credentials { + switch credOpt.Type { + case "default": + cred, err := newDefaultCredential(ctx, credOpt.Tag, credOpt.DefaultOptions, logger) + if err != nil { + return nil, nil, err + } + allCredentialMap[credOpt.Tag] = cred + allCreds = append(allCreds, cred) + providers[credOpt.Tag] = &singleCredentialProvider{cred: cred} + case "external": + cred, err := newExternalCredential(ctx, credOpt.Tag, credOpt.ExternalOptions, logger) + if err != nil { + return nil, nil, err + } + allCredentialMap[credOpt.Tag] = cred + allCreds = append(allCreds, cred) + providers[credOpt.Tag] = &singleCredentialProvider{cred: cred} + } + } + + // Pass 2: create balancer providers + for _, credOpt := range options.Credentials { + if credOpt.Type == "balancer" { + subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, allCredentialMap, credOpt.Tag) + if err != nil { + return nil, nil, err + } + providers[credOpt.Tag] = newBalancerProvider(subCredentials, credOpt.BalancerOptions.Strategy, time.Duration(credOpt.BalancerOptions.PollInterval), credOpt.BalancerOptions.RebalanceThreshold, logger) + } + } + + return providers, allCreds, nil +} + +func resolveCredentialTags(tags []string, allCredentials map[string]credential, parentTag string) ([]credential, error) { + credentials := make([]credential, 0, len(tags)) + for _, tag := range tags { + cred, exists := allCredentials[tag] + if !exists { + return nil, E.New("credential ", parentTag, " references unknown credential: ", tag) + } + credentials = append(credentials, cred) + } + if len(credentials) == 0 { + return nil, E.New("credential ", parentTag, " has no sub-credentials") + } + return credentials, nil +} + +func validateCCMOptions(options option.CCMServiceOptions) error { + hasCredentials := len(options.Credentials) > 0 + hasLegacyPath := options.CredentialPath != "" + hasLegacyUsages := options.UsagesPath != "" + hasLegacyDetour := options.Detour != "" + + if hasCredentials && hasLegacyPath { + return E.New("credential_path and credentials are mutually exclusive") + } + if hasCredentials && hasLegacyUsages { + return E.New("usages_path and credentials are mutually exclusive; use usages_path on individual credentials") + } + if hasCredentials && hasLegacyDetour { + return E.New("detour and credentials are mutually exclusive; use detour on individual credentials") + } + + if hasCredentials { + tags := make(map[string]bool) + credentialTypes := make(map[string]string) + for _, cred := range options.Credentials { + if tags[cred.Tag] { + return E.New("duplicate credential tag: ", cred.Tag) + } + tags[cred.Tag] = true + credentialTypes[cred.Tag] = cred.Type + if cred.Type == "default" || cred.Type == "" { + if cred.DefaultOptions.Reserve5h > 99 { + return E.New("credential ", cred.Tag, ": reserve_5h must be at most 99") + } + if cred.DefaultOptions.ReserveWeekly > 99 { + return E.New("credential ", cred.Tag, ": reserve_weekly must be at most 99") + } + if cred.DefaultOptions.Limit5h > 100 { + return E.New("credential ", cred.Tag, ": limit_5h must be at most 100") + } + if cred.DefaultOptions.LimitWeekly > 100 { + return E.New("credential ", cred.Tag, ": limit_weekly must be at most 100") + } + if cred.DefaultOptions.Reserve5h > 0 && cred.DefaultOptions.Limit5h > 0 { + return E.New("credential ", cred.Tag, ": reserve_5h and limit_5h are mutually exclusive") + } + if cred.DefaultOptions.ReserveWeekly > 0 && cred.DefaultOptions.LimitWeekly > 0 { + return E.New("credential ", cred.Tag, ": reserve_weekly and limit_weekly are mutually exclusive") + } + } + if cred.Type == "external" { + if cred.ExternalOptions.Token == "" { + return E.New("credential ", cred.Tag, ": external credential requires token") + } + if cred.ExternalOptions.Reverse && cred.ExternalOptions.URL == "" { + return E.New("credential ", cred.Tag, ": reverse external credential requires url") + } + } + if cred.Type == "balancer" { + switch cred.BalancerOptions.Strategy { + case "", C.BalancerStrategyLeastUsed, C.BalancerStrategyRoundRobin, C.BalancerStrategyRandom, C.BalancerStrategyFallback: + default: + return E.New("credential ", cred.Tag, ": unknown balancer strategy: ", cred.BalancerOptions.Strategy) + } + if cred.BalancerOptions.RebalanceThreshold < 0 { + return E.New("credential ", cred.Tag, ": rebalance_threshold must not be negative") + } + } + } + + for _, user := range options.Users { + if user.Credential == "" { + return E.New("user ", user.Name, " must specify credential in multi-credential mode") + } + if !tags[user.Credential] { + return E.New("user ", user.Name, " references unknown credential: ", user.Credential) + } + if user.ExternalCredential != "" { + if !tags[user.ExternalCredential] { + return E.New("user ", user.Name, " references unknown external_credential: ", user.ExternalCredential) + } + if credentialTypes[user.ExternalCredential] != "external" { + return E.New("user ", user.Name, ": external_credential must reference an external type credential") + } + } + } + } + + return nil +} + +func credentialForUser( + userConfigMap map[string]*option.CCMUser, + providers map[string]credentialProvider, + legacyProvider credentialProvider, + username string, +) (credentialProvider, error) { + if legacyProvider != nil { + return legacyProvider, nil + } + userConfig, exists := userConfigMap[username] + if !exists { + return nil, E.New("no credential mapping for user: ", username) + } + provider, exists := providers[userConfig.Credential] + if !exists { + return nil, E.New("unknown credential: ", userConfig.Credential) + } + return provider, nil +} + +func noUserCredentialProvider( + providers map[string]credentialProvider, + legacyProvider credentialProvider, + options option.CCMServiceOptions, +) credentialProvider { + if legacyProvider != nil { + return legacyProvider + } + if len(options.Credentials) > 0 { + tag := options.Credentials[0].Tag + return providers[tag] + } + return nil +} diff --git a/service/ccm/credential_default.go b/service/ccm/credential_default.go new file mode 100644 index 000000000..c44ec4103 --- /dev/null +++ b/service/ccm/credential_default.go @@ -0,0 +1,726 @@ +package ccm + +import ( + "bytes" + "context" + stdTLS "crypto/tls" + "encoding/json" + "io" + "math" + "net" + "net/http" + "strconv" + "sync" + "time" + + "github.com/sagernet/fswatch" + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/ntp" +) + +type defaultCredential struct { + tag string + serviceContext context.Context + credentialPath string + credentialFilePath string + credentials *oauthCredentials + access sync.RWMutex + state credentialState + stateAccess sync.RWMutex + pollAccess sync.Mutex + reloadAccess sync.Mutex + watcherAccess sync.Mutex + cap5h float64 + capWeekly float64 + usageTracker *AggregatedUsage + forwardHTTPClient *http.Client + logger log.ContextLogger + watcher *fswatch.Watcher + watcherRetryAt time.Time + + // Connection interruption + onBecameUnusable func() + interrupted bool + requestContext context.Context + cancelRequests context.CancelFunc + requestAccess sync.Mutex +} + +func newDefaultCredential(ctx context.Context, tag string, options option.CCMDefaultCredentialOptions, logger log.ContextLogger) (*defaultCredential, error) { + credentialDialer, err := dialer.NewWithOptions(dialer.Options{ + Context: ctx, + Options: option.DialerOptions{ + Detour: options.Detour, + }, + RemoteIsDomain: true, + }) + if err != nil { + return nil, E.Cause(err, "create dialer for credential ", tag) + } + httpClient := &http.Client{ + Transport: &http.Transport{ + ForceAttemptHTTP2: true, + TLSClientConfig: &stdTLS.Config{ + RootCAs: adapter.RootPoolFromContext(ctx), + Time: ntp.TimeFuncFromContext(ctx), + }, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + }, + }, + } + reserve5h := options.Reserve5h + if reserve5h == 0 { + reserve5h = 1 + } + reserveWeekly := options.ReserveWeekly + if reserveWeekly == 0 { + reserveWeekly = 1 + } + var cap5h float64 + if options.Limit5h > 0 { + cap5h = float64(options.Limit5h) + } else { + cap5h = float64(100 - reserve5h) + } + var capWeekly float64 + if options.LimitWeekly > 0 { + capWeekly = float64(options.LimitWeekly) + } else { + capWeekly = float64(100 - reserveWeekly) + } + requestContext, cancelRequests := context.WithCancel(context.Background()) + credential := &defaultCredential{ + tag: tag, + serviceContext: ctx, + credentialPath: options.CredentialPath, + cap5h: cap5h, + capWeekly: capWeekly, + forwardHTTPClient: httpClient, + logger: logger, + requestContext: requestContext, + cancelRequests: cancelRequests, + } + if options.UsagesPath != "" { + credential.usageTracker = &AggregatedUsage{ + LastUpdated: time.Now(), + Combinations: make([]CostCombination, 0), + filePath: options.UsagesPath, + logger: logger, + } + } + return credential, nil +} + +func (c *defaultCredential) start() error { + credentialFilePath, err := resolveCredentialFilePath(c.credentialPath) + if err != nil { + return E.Cause(err, "resolve credential path for ", c.tag) + } + c.credentialFilePath = credentialFilePath + err = c.ensureCredentialWatcher() + if err != nil { + c.logger.Debug("start credential watcher for ", c.tag, ": ", err) + } + err = c.reloadCredentials(true) + if err != nil { + c.logger.Warn("initial credential load for ", c.tag, ": ", err) + } + if c.usageTracker != nil { + err = c.usageTracker.Load() + if err != nil { + c.logger.Warn("load usage statistics for ", c.tag, ": ", err) + } + } + return nil +} + +func (c *defaultCredential) getAccessToken() (string, error) { + c.retryCredentialReloadIfNeeded() + + c.access.RLock() + if c.credentials != nil && !c.credentials.needsRefresh() { + token := c.credentials.AccessToken + c.access.RUnlock() + return token, nil + } + c.access.RUnlock() + + err := c.reloadCredentials(true) + if err == nil { + c.access.RLock() + if c.credentials != nil && !c.credentials.needsRefresh() { + token := c.credentials.AccessToken + c.access.RUnlock() + return token, nil + } + c.access.RUnlock() + } + + c.access.Lock() + defer c.access.Unlock() + + if c.credentials == nil { + return "", c.unavailableError() + } + if !c.credentials.needsRefresh() { + return c.credentials.AccessToken, nil + } + + err = platformCanWriteCredentials(c.credentialPath) + if err != nil { + return "", E.Cause(err, "credential file not writable, refusing refresh to avoid invalidation") + } + + baseCredentials := cloneCredentials(c.credentials) + newCredentials, err := refreshToken(c.serviceContext, c.forwardHTTPClient, c.credentials) + if err != nil { + return "", err + } + + latestCredentials, latestErr := platformReadCredentials(c.credentialPath) + if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) { + c.credentials = latestCredentials + c.stateAccess.Lock() + c.state.unavailable = false + c.state.lastCredentialLoadAttempt = time.Now() + c.state.lastCredentialLoadError = "" + c.state.accountType = latestCredentials.SubscriptionType + c.state.rateLimitTier = latestCredentials.RateLimitTier + c.checkTransitionLocked() + c.stateAccess.Unlock() + if !latestCredentials.needsRefresh() { + return latestCredentials.AccessToken, nil + } + return "", E.New("credential ", c.tag, " changed while refreshing") + } + + c.credentials = newCredentials + c.stateAccess.Lock() + c.state.unavailable = false + c.state.lastCredentialLoadAttempt = time.Now() + c.state.lastCredentialLoadError = "" + c.state.accountType = newCredentials.SubscriptionType + c.state.rateLimitTier = newCredentials.RateLimitTier + c.checkTransitionLocked() + c.stateAccess.Unlock() + + err = platformWriteCredentials(newCredentials, c.credentialPath) + if err != nil { + c.logger.Error("persist refreshed token for ", c.tag, ": ", err) + } + + return newCredentials.AccessToken, nil +} + +func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { + c.stateAccess.Lock() + isFirstUpdate := c.state.lastUpdated.IsZero() + oldFiveHour := c.state.fiveHourUtilization + oldWeekly := c.state.weeklyUtilization + hadData := false + + fiveHourResetChanged := false + if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset"); exists { + hadData = true + if value.After(c.state.fiveHourReset) { + fiveHourResetChanged = true + c.state.fiveHourReset = value + } + } + if utilization := headers.Get("anthropic-ratelimit-unified-5h-utilization"); utilization != "" { + value, err := strconv.ParseFloat(utilization, 64) + if err == nil { + hadData = true + newValue := math.Ceil(value * 100) + if newValue >= c.state.fiveHourUtilization || fiveHourResetChanged { + c.state.fiveHourUtilization = newValue + } + } + } + + weeklyResetChanged := false + if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset"); exists { + hadData = true + if value.After(c.state.weeklyReset) { + weeklyResetChanged = true + c.state.weeklyReset = value + } + } + if utilization := headers.Get("anthropic-ratelimit-unified-7d-utilization"); utilization != "" { + value, err := strconv.ParseFloat(utilization, 64) + if err == nil { + hadData = true + newValue := math.Ceil(value * 100) + if newValue >= c.state.weeklyUtilization || weeklyResetChanged { + c.state.weeklyUtilization = newValue + } + } + } + if hadData { + c.state.consecutivePollFailures = 0 + c.state.lastUpdated = time.Now() + } + if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { + resetSuffix := "" + if !c.state.weeklyReset.IsZero() { + resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) + } + c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) + } + shouldInterrupt := c.checkTransitionLocked() + c.stateAccess.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + +func (c *defaultCredential) markRateLimited(resetAt time.Time) { + c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt))) + c.stateAccess.Lock() + c.state.hardRateLimited = true + c.state.rateLimitResetAt = resetAt + shouldInterrupt := c.checkTransitionLocked() + c.stateAccess.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + +func (c *defaultCredential) isUsable() bool { + c.retryCredentialReloadIfNeeded() + + c.stateAccess.RLock() + if c.state.unavailable { + c.stateAccess.RUnlock() + return false + } + if c.state.consecutivePollFailures > 0 { + c.stateAccess.RUnlock() + return false + } + if c.state.hardRateLimited { + if time.Now().Before(c.state.rateLimitResetAt) { + c.stateAccess.RUnlock() + return false + } + c.stateAccess.RUnlock() + c.stateAccess.Lock() + if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) { + c.state.hardRateLimited = false + } + usable := c.checkReservesLocked() + c.stateAccess.Unlock() + return usable + } + usable := c.checkReservesLocked() + c.stateAccess.RUnlock() + return usable +} + +func (c *defaultCredential) checkReservesLocked() bool { + if c.state.fiveHourUtilization >= c.cap5h { + return false + } + if c.state.weeklyUtilization >= c.capWeekly { + return false + } + return true +} + +// checkTransitionLocked detects usable→unusable transition. +// Must be called with stateAccess write lock held. +func (c *defaultCredential) checkTransitionLocked() bool { + unusable := c.state.unavailable || c.state.hardRateLimited || !c.checkReservesLocked() || c.state.consecutivePollFailures > 0 + if unusable && !c.interrupted { + c.interrupted = true + return true + } + if !unusable && c.interrupted { + c.interrupted = false + } + return false +} + +func (c *defaultCredential) interruptConnections() { + c.logger.Warn("interrupting connections for ", c.tag) + c.requestAccess.Lock() + c.cancelRequests() + c.requestContext, c.cancelRequests = context.WithCancel(context.Background()) + c.requestAccess.Unlock() + if c.onBecameUnusable != nil { + c.onBecameUnusable() + } +} + +func (c *defaultCredential) wrapRequestContext(parent context.Context) *credentialRequestContext { + c.requestAccess.Lock() + credentialContext := c.requestContext + c.requestAccess.Unlock() + derived, cancel := context.WithCancel(parent) + stop := context.AfterFunc(credentialContext, func() { + cancel() + }) + return &credentialRequestContext{ + Context: derived, + releaseFuncs: []func() bool{stop}, + cancelFunc: cancel, + } +} + +func (c *defaultCredential) weeklyUtilization() float64 { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.weeklyUtilization +} + +func (c *defaultCredential) planWeight() float64 { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return ccmPlanWeight(c.state.accountType, c.state.rateLimitTier) +} + +func (c *defaultCredential) weeklyResetTime() time.Time { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.weeklyReset +} + +func (c *defaultCredential) isAvailable() bool { + c.retryCredentialReloadIfNeeded() + + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return !c.state.unavailable +} + +func (c *defaultCredential) unavailableError() error { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + if !c.state.unavailable { + return nil + } + if c.state.lastCredentialLoadError == "" { + return E.New("credential ", c.tag, " is unavailable") + } + return E.New("credential ", c.tag, " is unavailable: ", c.state.lastCredentialLoadError) +} + +func (c *defaultCredential) lastUpdatedTime() time.Time { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.lastUpdated +} + +func (c *defaultCredential) markUsagePollAttempted() { + c.stateAccess.Lock() + defer c.stateAccess.Unlock() + c.state.lastUpdated = time.Now() +} + +func (c *defaultCredential) incrementPollFailures() { + c.stateAccess.Lock() + c.state.consecutivePollFailures++ + shouldInterrupt := c.checkTransitionLocked() + c.stateAccess.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + +func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duration { + c.stateAccess.RLock() + failures := c.state.consecutivePollFailures + c.stateAccess.RUnlock() + if failures <= 0 { + return baseInterval + } + backoff := failedPollRetryInterval * time.Duration(1<<(failures-1)) + if backoff > httpRetryMaxBackoff { + return httpRetryMaxBackoff + } + return backoff +} + +func (c *defaultCredential) isPollBackoffAtCap() bool { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + failures := c.state.consecutivePollFailures + return failures > 0 && failedPollRetryInterval*time.Duration(1<<(failures-1)) >= httpRetryMaxBackoff +} + +func (c *defaultCredential) earliestReset() time.Time { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + if c.state.unavailable { + return time.Time{} + } + if c.state.hardRateLimited { + return c.state.rateLimitResetAt + } + earliest := c.state.fiveHourReset + if !c.state.weeklyReset.IsZero() && (earliest.IsZero() || c.state.weeklyReset.Before(earliest)) { + earliest = c.state.weeklyReset + } + return earliest +} + +func (c *defaultCredential) pollUsage(ctx context.Context) { + if !c.pollAccess.TryLock() { + return + } + defer c.pollAccess.Unlock() + defer c.markUsagePollAttempted() + + c.retryCredentialReloadIfNeeded() + if !c.isAvailable() { + return + } + + accessToken, err := c.getAccessToken() + if err != nil { + if !c.isPollBackoffAtCap() { + c.logger.Error("poll usage for ", c.tag, ": get token: ", err) + } + c.incrementPollFailures() + return + } + + httpClient := &http.Client{ + Transport: c.forwardHTTPClient.Transport, + Timeout: 5 * time.Second, + } + + response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/usage", nil) + if err != nil { + return nil, err + } + request.Header.Set("Authorization", "Bearer "+accessToken) + request.Header.Set("Content-Type", "application/json") + request.Header.Set("User-Agent", ccmUserAgentValue) + request.Header.Set("anthropic-beta", anthropicBetaOAuthValue) + return request, nil + }) + if err != nil { + if !c.isPollBackoffAtCap() { + c.logger.Error("poll usage for ", c.tag, ": ", err) + } + c.incrementPollFailures() + return + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + if response.StatusCode == http.StatusTooManyRequests { + c.logger.Warn("poll usage for ", c.tag, ": rate limited") + } + body, _ := io.ReadAll(response.Body) + c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) + c.incrementPollFailures() + return + } + + var usageResponse struct { + FiveHour struct { + Utilization float64 `json:"utilization"` + ResetsAt time.Time `json:"resets_at"` + } `json:"five_hour"` + SevenDay struct { + Utilization float64 `json:"utilization"` + ResetsAt time.Time `json:"resets_at"` + } `json:"seven_day"` + } + err = json.NewDecoder(response.Body).Decode(&usageResponse) + if err != nil { + c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) + c.incrementPollFailures() + return + } + + c.stateAccess.Lock() + isFirstUpdate := c.state.lastUpdated.IsZero() + oldFiveHour := c.state.fiveHourUtilization + oldWeekly := c.state.weeklyUtilization + c.state.consecutivePollFailures = 0 + c.state.fiveHourUtilization = usageResponse.FiveHour.Utilization + if !usageResponse.FiveHour.ResetsAt.IsZero() { + c.state.fiveHourReset = usageResponse.FiveHour.ResetsAt + } + c.state.weeklyUtilization = usageResponse.SevenDay.Utilization + if !usageResponse.SevenDay.ResetsAt.IsZero() { + c.state.weeklyReset = usageResponse.SevenDay.ResetsAt + } + if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { + c.state.hardRateLimited = false + } + if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { + resetSuffix := "" + if !c.state.weeklyReset.IsZero() { + resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) + } + c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) + } + needsProfileFetch := c.state.rateLimitTier == "" + shouldInterrupt := c.checkTransitionLocked() + c.stateAccess.Unlock() + if shouldInterrupt { + c.interruptConnections() + } + + if needsProfileFetch { + c.fetchProfile(ctx, httpClient, accessToken) + } +} + +func (c *defaultCredential) fetchProfile(ctx context.Context, httpClient *http.Client, accessToken string) { + response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/profile", nil) + if err != nil { + return nil, err + } + request.Header.Set("Authorization", "Bearer "+accessToken) + request.Header.Set("Content-Type", "application/json") + request.Header.Set("User-Agent", ccmUserAgentValue) + return request, nil + }) + if err != nil { + c.logger.Debug("fetch profile for ", c.tag, ": ", err) + return + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + return + } + + var profileResponse struct { + Organization *struct { + OrganizationType string `json:"organization_type"` + RateLimitTier string `json:"rate_limit_tier"` + } `json:"organization"` + } + err = json.NewDecoder(response.Body).Decode(&profileResponse) + if err != nil || profileResponse.Organization == nil { + return + } + + accountType := "" + switch profileResponse.Organization.OrganizationType { + case "claude_pro": + accountType = "pro" + case "claude_max": + accountType = "max" + case "claude_team": + accountType = "team" + case "claude_enterprise": + accountType = "enterprise" + } + rateLimitTier := profileResponse.Organization.RateLimitTier + + c.stateAccess.Lock() + if accountType != "" && c.state.accountType == "" { + c.state.accountType = accountType + } + if rateLimitTier != "" { + c.state.rateLimitTier = rateLimitTier + } + c.stateAccess.Unlock() + c.logger.Info("fetched profile for ", c.tag, ": type=", c.state.accountType, ", tier=", rateLimitTier, ", weight=", ccmPlanWeight(c.state.accountType, rateLimitTier)) +} + +func (c *defaultCredential) close() { + if c.watcher != nil { + err := c.watcher.Close() + if err != nil { + c.logger.Error("close credential watcher for ", c.tag, ": ", err) + } + } + if c.usageTracker != nil { + c.usageTracker.cancelPendingSave() + err := c.usageTracker.Save() + if err != nil { + c.logger.Error("save usage statistics for ", c.tag, ": ", err) + } + } +} + +func (c *defaultCredential) tagName() string { + return c.tag +} + +func (c *defaultCredential) isExternal() bool { + return false +} + +func (c *defaultCredential) fiveHourUtilization() float64 { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.fiveHourUtilization +} + +func (c *defaultCredential) fiveHourCap() float64 { + return c.cap5h +} + +func (c *defaultCredential) weeklyCap() float64 { + return c.capWeekly +} + +func (c *defaultCredential) usageTrackerOrNil() *AggregatedUsage { + return c.usageTracker +} + +func (c *defaultCredential) httpClient() *http.Client { + return c.forwardHTTPClient +} + +func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) { + accessToken, err := c.getAccessToken() + if err != nil { + return nil, E.Cause(err, "get access token for ", c.tag) + } + + proxyURL := claudeAPIBaseURL + original.URL.RequestURI() + var body io.Reader + if bodyBytes != nil { + body = bytes.NewReader(bodyBytes) + } else { + body = original.Body + } + proxyRequest, err := http.NewRequestWithContext(ctx, original.Method, proxyURL, body) + if err != nil { + return nil, err + } + + for key, values := range original.Header { + if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && key != "Authorization" { + proxyRequest.Header[key] = values + } + } + + serviceOverridesAcceptEncoding := len(serviceHeaders.Values("Accept-Encoding")) > 0 + if c.usageTracker != nil && !serviceOverridesAcceptEncoding { + proxyRequest.Header.Del("Accept-Encoding") + } + + anthropicBetaHeader := proxyRequest.Header.Get("anthropic-beta") + if anthropicBetaHeader != "" { + proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue+","+anthropicBetaHeader) + } else { + proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue) + } + + for key, values := range serviceHeaders { + proxyRequest.Header.Del(key) + proxyRequest.Header[key] = values + } + proxyRequest.Header.Set("Authorization", "Bearer "+accessToken) + + return proxyRequest, nil +} diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index b7e04bad3..24ddf6c4a 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -29,16 +29,16 @@ import ( const reverseProxyBaseURL = "http://reverse-proxy" type externalCredential struct { - tag string - baseURL string - token string - httpClient *http.Client - state credentialState - stateMutex sync.RWMutex - pollAccess sync.Mutex - pollInterval time.Duration - usageTracker *AggregatedUsage - logger log.ContextLogger + tag string + baseURL string + token string + forwardHTTPClient *http.Client + state credentialState + stateAccess sync.RWMutex + pollAccess sync.Mutex + pollInterval time.Duration + usageTracker *AggregatedUsage + logger log.ContextLogger onBecameUnusable func() interrupted bool @@ -128,7 +128,7 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx if options.URL == "" { // Receiver mode: no URL, wait for reverse connection cred.baseURL = reverseProxyBaseURL - cred.httpClient = &http.Client{ + cred.forwardHTTPClient = &http.Client{ Transport: &http.Transport{ ForceAttemptHTTP2: false, DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { @@ -192,10 +192,10 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx Time: ntp.TimeFuncFromContext(ctx), } } - cred.httpClient = &http.Client{Transport: transport} + cred.forwardHTTPClient = &http.Client{Transport: transport} } else { // Normal mode: standard HTTP client for proxying - cred.httpClient = &http.Client{Transport: transport} + cred.forwardHTTPClient = &http.Client{Transport: transport} cred.reverseHttpClient = &http.Client{ Transport: &http.Transport{ ForceAttemptHTTP2: false, @@ -248,40 +248,40 @@ func (c *externalCredential) isUsable() bool { if !c.isAvailable() { return false } - c.stateMutex.RLock() + c.stateAccess.RLock() if c.state.consecutivePollFailures > 0 { - c.stateMutex.RUnlock() + c.stateAccess.RUnlock() return false } if c.state.hardRateLimited { if time.Now().Before(c.state.rateLimitResetAt) { - c.stateMutex.RUnlock() + c.stateAccess.RUnlock() return false } - c.stateMutex.RUnlock() - c.stateMutex.Lock() + c.stateAccess.RUnlock() + c.stateAccess.Lock() if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) { c.state.hardRateLimited = false } // No reserve for external: only 100% is unusable usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100 - c.stateMutex.Unlock() + c.stateAccess.Unlock() return usable } usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100 - c.stateMutex.RUnlock() + c.stateAccess.RUnlock() return usable } func (c *externalCredential) fiveHourUtilization() float64 { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() return c.state.fiveHourUtilization } func (c *externalCredential) weeklyUtilization() float64 { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() return c.state.weeklyUtilization } @@ -294,8 +294,8 @@ func (c *externalCredential) weeklyCap() float64 { } func (c *externalCredential) planWeight() float64 { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() if c.state.remotePlanWeight > 0 { return c.state.remotePlanWeight } @@ -303,26 +303,26 @@ func (c *externalCredential) planWeight() float64 { } func (c *externalCredential) weeklyResetTime() time.Time { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() return c.state.weeklyReset } func (c *externalCredential) markRateLimited(resetAt time.Time) { c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt))) - c.stateMutex.Lock() + c.stateAccess.Lock() c.state.hardRateLimited = true c.state.rateLimitResetAt = resetAt shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() + c.stateAccess.Unlock() if shouldInterrupt { c.interruptConnections() } } func (c *externalCredential) earliestReset() time.Time { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() if c.state.hardRateLimited { return c.state.rateLimitResetAt } @@ -408,7 +408,7 @@ func (c *externalCredential) openReverseConnection(ctx context.Context) (net.Con } func (c *externalCredential) updateStateFromHeaders(headers http.Header) { - c.stateMutex.Lock() + c.stateAccess.Lock() isFirstUpdate := c.state.lastUpdated.IsZero() oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization @@ -455,7 +455,7 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) } shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() + c.stateAccess.Unlock() if shouldInterrupt { c.interruptConnections() } @@ -530,9 +530,9 @@ func (c *externalCredential) doPollUsageRequest(ctx context.Context) (*http.Resp } } // Forward transport with retries - if c.httpClient != nil { + if c.forwardHTTPClient != nil { forwardClient := &http.Client{ - Transport: c.httpClient.Transport, + Transport: c.forwardHTTPClient.Transport, Timeout: 5 * time.Second, } return doHTTPWithRetry(ctx, forwardClient, buildRequest(c.baseURL)) @@ -563,10 +563,10 @@ func (c *externalCredential) pollUsage(ctx context.Context) { // 404 means the remote does not have a status endpoint yet; // usage will be updated passively from response headers. if response.StatusCode == http.StatusNotFound { - c.stateMutex.Lock() + c.stateAccess.Lock() c.state.consecutivePollFailures = 0 c.checkTransitionLocked() - c.stateMutex.Unlock() + c.stateAccess.Unlock() } else { c.incrementPollFailures() } @@ -585,7 +585,7 @@ func (c *externalCredential) pollUsage(ctx context.Context) { return } - c.stateMutex.Lock() + c.stateAccess.Lock() isFirstUpdate := c.state.lastUpdated.IsZero() oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization @@ -606,28 +606,28 @@ func (c *externalCredential) pollUsage(ctx context.Context) { c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) } shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() + c.stateAccess.Unlock() if shouldInterrupt { c.interruptConnections() } } func (c *externalCredential) lastUpdatedTime() time.Time { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() return c.state.lastUpdated } func (c *externalCredential) markUsagePollAttempted() { - c.stateMutex.Lock() - defer c.stateMutex.Unlock() + c.stateAccess.Lock() + defer c.stateAccess.Unlock() c.state.lastUpdated = time.Now() } func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Duration { - c.stateMutex.RLock() + c.stateAccess.RLock() failures := c.state.consecutivePollFailures - c.stateMutex.RUnlock() + c.stateAccess.RUnlock() if failures <= 0 { return baseInterval } @@ -639,17 +639,17 @@ func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Durati } func (c *externalCredential) isPollBackoffAtCap() bool { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() failures := c.state.consecutivePollFailures return failures > 0 && failedPollRetryInterval*time.Duration(1<<(failures-1)) >= httpRetryMaxBackoff } func (c *externalCredential) incrementPollFailures() { - c.stateMutex.Lock() + c.stateAccess.Lock() c.state.consecutivePollFailures++ shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() + c.stateAccess.Unlock() if shouldInterrupt { c.interruptConnections() } @@ -659,14 +659,14 @@ func (c *externalCredential) usageTrackerOrNil() *AggregatedUsage { return c.usageTracker } -func (c *externalCredential) httpTransport() *http.Client { +func (c *externalCredential) httpClient() *http.Client { if c.reverseHttpClient != nil { session := c.getReverseSession() if session != nil && !session.IsClosed() { return c.reverseHttpClient } } - return c.httpClient + return c.forwardHTTPClient } func (c *externalCredential) close() { diff --git a/service/ccm/credential_file.go b/service/ccm/credential_file.go index eba920726..72d9da010 100644 --- a/service/ccm/credential_file.go +++ b/service/ccm/credential_file.go @@ -62,10 +62,10 @@ func (c *defaultCredential) ensureCredentialWatcher() error { } func (c *defaultCredential) retryCredentialReloadIfNeeded() { - c.stateMutex.RLock() + c.stateAccess.RLock() unavailable := c.state.unavailable lastAttempt := c.state.lastCredentialLoadAttempt - c.stateMutex.RUnlock() + c.stateAccess.RUnlock() if !unavailable { return } @@ -84,10 +84,10 @@ func (c *defaultCredential) reloadCredentials(force bool) error { c.reloadAccess.Lock() defer c.reloadAccess.Unlock() - c.stateMutex.RLock() + c.stateAccess.RLock() unavailable := c.state.unavailable lastAttempt := c.state.lastCredentialLoadAttempt - c.stateMutex.RUnlock() + c.stateAccess.RUnlock() if !force { if !unavailable { return nil @@ -97,43 +97,43 @@ func (c *defaultCredential) reloadCredentials(force bool) error { } } - c.stateMutex.Lock() + c.stateAccess.Lock() c.state.lastCredentialLoadAttempt = time.Now() - c.stateMutex.Unlock() + c.stateAccess.Unlock() credentials, err := platformReadCredentials(c.credentialPath) if err != nil { return c.markCredentialsUnavailable(E.Cause(err, "read credentials")) } - c.accessMutex.Lock() + c.access.Lock() c.credentials = credentials - c.accessMutex.Unlock() + c.access.Unlock() - c.stateMutex.Lock() + c.stateAccess.Lock() c.state.unavailable = false c.state.lastCredentialLoadError = "" c.state.accountType = credentials.SubscriptionType c.state.rateLimitTier = credentials.RateLimitTier c.checkTransitionLocked() - c.stateMutex.Unlock() + c.stateAccess.Unlock() return nil } func (c *defaultCredential) markCredentialsUnavailable(err error) error { - c.accessMutex.Lock() + c.access.Lock() hadCredentials := c.credentials != nil c.credentials = nil - c.accessMutex.Unlock() + c.access.Unlock() - c.stateMutex.Lock() + c.stateAccess.Lock() c.state.unavailable = true c.state.lastCredentialLoadError = err.Error() c.state.accountType = "" c.state.rateLimitTier = "" shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() + c.stateAccess.Unlock() if shouldInterrupt && hadCredentials { c.interruptConnections() diff --git a/service/ccm/credential_oauth.go b/service/ccm/credential_oauth.go new file mode 100644 index 000000000..da559c173 --- /dev/null +++ b/service/ccm/credential_oauth.go @@ -0,0 +1,224 @@ +package ccm + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "os" + "os/user" + "path/filepath" + "runtime" + "slices" + "sync" + "time" + + "github.com/sagernet/sing-box/log" + E "github.com/sagernet/sing/common/exceptions" +) + +const ( + oauth2ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e" + oauth2TokenURL = "https://platform.claude.com/v1/oauth/token" + claudeAPIBaseURL = "https://api.anthropic.com" + tokenRefreshBufferMs = 60000 + anthropicBetaOAuthValue = "oauth-2025-04-20" +) + +const ccmUserAgentFallback = "claude-code/2.1.72" + +var ( + ccmUserAgentOnce sync.Once + ccmUserAgentValue string +) + +func initCCMUserAgent(logger log.ContextLogger) { + ccmUserAgentOnce.Do(func() { + version, err := detectClaudeCodeVersion() + if err != nil { + logger.Error("detect Claude Code version: ", err) + ccmUserAgentValue = ccmUserAgentFallback + return + } + logger.Debug("detected Claude Code version: ", version) + ccmUserAgentValue = "claude-code/" + version + }) +} + +func detectClaudeCodeVersion() (string, error) { + userInfo, err := getRealUser() + if err != nil { + return "", E.Cause(err, "get user") + } + binaryName := "claude" + if runtime.GOOS == "windows" { + binaryName = "claude.exe" + } + linkPath := filepath.Join(userInfo.HomeDir, ".local", "bin", binaryName) + target, err := os.Readlink(linkPath) + if err != nil { + return "", E.Cause(err, "readlink ", linkPath) + } + if !filepath.IsAbs(target) { + target = filepath.Join(filepath.Dir(linkPath), target) + } + parent := filepath.Base(filepath.Dir(target)) + if parent != "versions" { + return "", E.New("unexpected symlink target: ", target) + } + return filepath.Base(target), nil +} + +func getRealUser() (*user.User, error) { + if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" { + sudoUserInfo, err := user.Lookup(sudoUser) + if err == nil { + return sudoUserInfo, nil + } + } + return user.Current() +} + +func getDefaultCredentialsPath() (string, error) { + if configDir := os.Getenv("CLAUDE_CONFIG_DIR"); configDir != "" { + return filepath.Join(configDir, ".credentials.json"), nil + } + userInfo, err := getRealUser() + if err != nil { + return "", err + } + return filepath.Join(userInfo.HomeDir, ".claude", ".credentials.json"), nil +} + +func readCredentialsFromFile(path string) (*oauthCredentials, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + var credentialsContainer struct { + ClaudeAIAuth *oauthCredentials `json:"claudeAiOauth,omitempty"` + } + err = json.Unmarshal(data, &credentialsContainer) + if err != nil { + return nil, err + } + if credentialsContainer.ClaudeAIAuth == nil { + return nil, E.New("claudeAiOauth field not found in credentials") + } + return credentialsContainer.ClaudeAIAuth, nil +} + +func checkCredentialFileWritable(path string) error { + file, err := os.OpenFile(path, os.O_WRONLY, 0) + if err != nil { + return err + } + return file.Close() +} + +func writeCredentialsToFile(oauthCredentials *oauthCredentials, path string) error { + data, err := json.MarshalIndent(map[string]any{ + "claudeAiOauth": oauthCredentials, + }, "", " ") + if err != nil { + return err + } + return os.WriteFile(path, data, 0o600) +} + +type oauthCredentials struct { + AccessToken string `json:"accessToken"` + RefreshToken string `json:"refreshToken"` + ExpiresAt int64 `json:"expiresAt"` + Scopes []string `json:"scopes,omitempty"` + SubscriptionType string `json:"subscriptionType,omitempty"` + RateLimitTier string `json:"rateLimitTier,omitempty"` + IsMax bool `json:"isMax,omitempty"` +} + +func (c *oauthCredentials) needsRefresh() bool { + if c.ExpiresAt == 0 { + return false + } + return time.Now().UnixMilli() >= c.ExpiresAt-tokenRefreshBufferMs +} + +func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) { + if credentials.RefreshToken == "" { + return nil, E.New("refresh token is empty") + } + + requestBody, err := json.Marshal(map[string]string{ + "grant_type": "refresh_token", + "refresh_token": credentials.RefreshToken, + "client_id": oauth2ClientID, + }) + if err != nil { + return nil, E.Cause(err, "marshal request") + } + + response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { + request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody)) + if err != nil { + return nil, err + } + request.Header.Set("Content-Type", "application/json") + request.Header.Set("User-Agent", ccmUserAgentValue) + return request, nil + }) + if err != nil { + return nil, err + } + defer response.Body.Close() + + if response.StatusCode == http.StatusTooManyRequests { + body, _ := io.ReadAll(response.Body) + return nil, E.New("refresh rate limited: ", response.Status, " ", string(body)) + } + if response.StatusCode != http.StatusOK { + body, _ := io.ReadAll(response.Body) + return nil, E.New("refresh failed: ", response.Status, " ", string(body)) + } + + var tokenResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` + } + err = json.NewDecoder(response.Body).Decode(&tokenResponse) + if err != nil { + return nil, E.Cause(err, "decode response") + } + + newCredentials := *credentials + newCredentials.AccessToken = tokenResponse.AccessToken + if tokenResponse.RefreshToken != "" { + newCredentials.RefreshToken = tokenResponse.RefreshToken + } + newCredentials.ExpiresAt = time.Now().UnixMilli() + int64(tokenResponse.ExpiresIn)*1000 + + return &newCredentials, nil +} + +func cloneCredentials(credentials *oauthCredentials) *oauthCredentials { + if credentials == nil { + return nil + } + cloned := *credentials + cloned.Scopes = append([]string(nil), credentials.Scopes...) + return &cloned +} + +func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool { + if left == nil || right == nil { + return left == right + } + return left.AccessToken == right.AccessToken && + left.RefreshToken == right.RefreshToken && + left.ExpiresAt == right.ExpiresAt && + slices.Equal(left.Scopes, right.Scopes) && + left.SubscriptionType == right.SubscriptionType && + left.RateLimitTier == right.RateLimitTier && + left.IsMax == right.IsMax +} diff --git a/service/ccm/credential_provider.go b/service/ccm/credential_provider.go new file mode 100644 index 000000000..cd77bfcdc --- /dev/null +++ b/service/ccm/credential_provider.go @@ -0,0 +1,405 @@ +package ccm + +import ( + "context" + "math/rand/v2" + "sync" + "sync/atomic" + "time" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + E "github.com/sagernet/sing/common/exceptions" +) + +type credentialProvider interface { + selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) + onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential + linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool + pollIfStale(ctx context.Context) + allCredentials() []credential + close() +} + +type singleCredentialProvider struct { + cred credential + sessionAccess sync.RWMutex + sessions map[string]time.Time +} + +func (p *singleCredentialProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) { + if !selection.allows(p.cred) { + return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out") + } + if !p.cred.isAvailable() { + return nil, false, p.cred.unavailableError() + } + if !p.cred.isUsable() { + return nil, false, E.New("credential ", p.cred.tagName(), " is rate-limited") + } + var isNew bool + if sessionID != "" { + p.sessionAccess.Lock() + if p.sessions == nil { + p.sessions = make(map[string]time.Time) + } + _, exists := p.sessions[sessionID] + if !exists { + p.sessions[sessionID] = time.Now() + isNew = true + } + p.sessionAccess.Unlock() + } + return p.cred, isNew, nil +} + +func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ credentialSelection) credential { + cred.markRateLimited(resetAt) + return nil +} + +func (p *singleCredentialProvider) pollIfStale(ctx context.Context) { + now := time.Now() + p.sessionAccess.Lock() + for id, createdAt := range p.sessions { + if now.Sub(createdAt) > sessionExpiry { + delete(p.sessions, id) + } + } + p.sessionAccess.Unlock() + + if time.Since(p.cred.lastUpdatedTime()) > p.cred.pollBackoff(defaultPollInterval) { + p.cred.pollUsage(ctx) + } +} + +func (p *singleCredentialProvider) allCredentials() []credential { + return []credential{p.cred} +} + +func (p *singleCredentialProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool { + return func() bool { + return false + } +} + +func (p *singleCredentialProvider) close() {} + +type sessionEntry struct { + tag string + selectionScope credentialSelectionScope + createdAt time.Time +} + +type credentialInterruptKey struct { + tag string + selectionScope credentialSelectionScope +} + +type credentialInterruptEntry struct { + context context.Context + cancel context.CancelFunc +} + +type balancerProvider struct { + credentials []credential + strategy string + roundRobinIndex atomic.Uint64 + pollInterval time.Duration + rebalanceThreshold float64 + sessionAccess sync.RWMutex + sessions map[string]sessionEntry + interruptAccess sync.Mutex + credentialInterrupts map[credentialInterruptKey]credentialInterruptEntry + logger log.ContextLogger +} + +func newBalancerProvider(credentials []credential, strategy string, pollInterval time.Duration, rebalanceThreshold float64, logger log.ContextLogger) *balancerProvider { + if pollInterval <= 0 { + pollInterval = defaultPollInterval + } + return &balancerProvider{ + credentials: credentials, + strategy: strategy, + pollInterval: pollInterval, + rebalanceThreshold: rebalanceThreshold, + sessions: make(map[string]sessionEntry), + credentialInterrupts: make(map[credentialInterruptKey]credentialInterruptEntry), + logger: logger, + } +} + +func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) { + if p.strategy == C.BalancerStrategyFallback { + best := p.pickCredential(selection.filter) + if best == nil { + return nil, false, allCredentialsUnavailableError(p.credentials) + } + return best, false, nil + } + + selectionScope := selection.scopeOrDefault() + if sessionID != "" { + p.sessionAccess.RLock() + entry, exists := p.sessions[sessionID] + p.sessionAccess.RUnlock() + if exists { + if entry.selectionScope == selectionScope { + for _, cred := range p.credentials { + if cred.tagName() == entry.tag && selection.allows(cred) && cred.isUsable() { + if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == C.BalancerStrategyLeastUsed) { + better := p.pickLeastUsed(selection.filter) + if better != nil && better.tagName() != cred.tagName() { + effectiveThreshold := p.rebalanceThreshold / cred.planWeight() + delta := cred.weeklyUtilization() - better.weeklyUtilization() + if delta > effectiveThreshold { + p.logger.Info("rebalancing away from ", cred.tagName(), + ": utilization delta ", delta, "% exceeds effective threshold ", + effectiveThreshold, "% (weight ", cred.planWeight(), ")") + p.rebalanceCredential(cred.tagName(), selectionScope) + break + } + } + } + return cred, false, nil + } + } + } + p.sessionAccess.Lock() + delete(p.sessions, sessionID) + p.sessionAccess.Unlock() + } + } + + best := p.pickCredential(selection.filter) + if best == nil { + return nil, false, allCredentialsUnavailableError(p.credentials) + } + + isNew := sessionID != "" + if isNew { + p.sessionAccess.Lock() + p.sessions[sessionID] = sessionEntry{ + tag: best.tagName(), + selectionScope: selectionScope, + createdAt: time.Now(), + } + p.sessionAccess.Unlock() + } + return best, isNew, nil +} + +func (p *balancerProvider) rebalanceCredential(tag string, selectionScope credentialSelectionScope) { + key := credentialInterruptKey{tag: tag, selectionScope: selectionScope} + p.interruptAccess.Lock() + if entry, loaded := p.credentialInterrupts[key]; loaded { + entry.cancel() + } + ctx, cancel := context.WithCancel(context.Background()) + p.credentialInterrupts[key] = credentialInterruptEntry{context: ctx, cancel: cancel} + p.interruptAccess.Unlock() + + p.sessionAccess.Lock() + for id, entry := range p.sessions { + if entry.tag == tag && entry.selectionScope == selectionScope { + delete(p.sessions, id) + } + } + p.sessionAccess.Unlock() +} + +func (p *balancerProvider) linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool { + if p.strategy == C.BalancerStrategyFallback { + return func() bool { return false } + } + key := credentialInterruptKey{ + tag: cred.tagName(), + selectionScope: selection.scopeOrDefault(), + } + p.interruptAccess.Lock() + entry, loaded := p.credentialInterrupts[key] + if !loaded { + ctx, cancel := context.WithCancel(context.Background()) + entry = credentialInterruptEntry{context: ctx, cancel: cancel} + p.credentialInterrupts[key] = entry + } + p.interruptAccess.Unlock() + return context.AfterFunc(entry.context, onInterrupt) +} + +func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential { + cred.markRateLimited(resetAt) + if p.strategy == C.BalancerStrategyFallback { + return p.pickCredential(selection.filter) + } + if sessionID != "" { + p.sessionAccess.Lock() + delete(p.sessions, sessionID) + p.sessionAccess.Unlock() + } + + best := p.pickCredential(selection.filter) + if best != nil && sessionID != "" { + p.sessionAccess.Lock() + p.sessions[sessionID] = sessionEntry{ + tag: best.tagName(), + selectionScope: selection.scopeOrDefault(), + createdAt: time.Now(), + } + p.sessionAccess.Unlock() + } + return best +} + +func (p *balancerProvider) pickCredential(filter func(credential) bool) credential { + switch p.strategy { + case C.BalancerStrategyRoundRobin: + return p.pickRoundRobin(filter) + case C.BalancerStrategyRandom: + return p.pickRandom(filter) + case C.BalancerStrategyFallback: + return p.pickFallback(filter) + default: + return p.pickLeastUsed(filter) + } +} + +func (p *balancerProvider) pickFallback(filter func(credential) bool) credential { + for _, cred := range p.credentials { + if filter != nil && !filter(cred) { + continue + } + if cred.isUsable() { + return cred + } + } + return nil +} + +const weeklyWindowHours = 7 * 24 + +func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credential { + var best credential + bestScore := float64(-1) + now := time.Now() + for _, cred := range p.credentials { + if filter != nil && !filter(cred) { + continue + } + if !cred.isUsable() { + continue + } + remaining := cred.weeklyCap() - cred.weeklyUtilization() + score := remaining * cred.planWeight() + resetTime := cred.weeklyResetTime() + if !resetTime.IsZero() { + timeUntilReset := resetTime.Sub(now) + if timeUntilReset < time.Hour { + timeUntilReset = time.Hour + } + score *= weeklyWindowHours / timeUntilReset.Hours() + } + if score > bestScore { + bestScore = score + best = cred + } + } + return best +} + +func (p *balancerProvider) pickRoundRobin(filter func(credential) bool) credential { + start := int(p.roundRobinIndex.Add(1) - 1) + count := len(p.credentials) + for offset := range count { + candidate := p.credentials[(start+offset)%count] + if filter != nil && !filter(candidate) { + continue + } + if candidate.isUsable() { + return candidate + } + } + return nil +} + +func (p *balancerProvider) pickRandom(filter func(credential) bool) credential { + var usable []credential + for _, candidate := range p.credentials { + if filter != nil && !filter(candidate) { + continue + } + if candidate.isUsable() { + usable = append(usable, candidate) + } + } + if len(usable) == 0 { + return nil + } + return usable[rand.IntN(len(usable))] +} + +func (p *balancerProvider) pollIfStale(ctx context.Context) { + now := time.Now() + p.sessionAccess.Lock() + for id, entry := range p.sessions { + if now.Sub(entry.createdAt) > sessionExpiry { + delete(p.sessions, id) + } + } + p.sessionAccess.Unlock() + + for _, cred := range p.credentials { + if time.Since(cred.lastUpdatedTime()) > cred.pollBackoff(p.pollInterval) { + cred.pollUsage(ctx) + } + } +} + +func (p *balancerProvider) allCredentials() []credential { + return p.credentials +} + +func (p *balancerProvider) close() {} + +func ccmPlanWeight(accountType string, rateLimitTier string) float64 { + switch accountType { + case "max": + switch rateLimitTier { + case "default_claude_max_20x": + return 10 + case "default_claude_max_5x": + return 5 + default: + return 5 + } + case "team": + if rateLimitTier == "default_claude_max_5x" { + return 5 + } + return 1 + default: + return 1 + } +} + +func allCredentialsUnavailableError(credentials []credential) error { + var hasUnavailable bool + var earliest time.Time + for _, cred := range credentials { + if cred.unavailableError() != nil { + hasUnavailable = true + continue + } + resetAt := cred.earliestReset() + if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) { + earliest = resetAt + } + } + if hasUnavailable { + return E.New("all credentials unavailable") + } + if earliest.IsZero() { + return E.New("all credentials rate-limited") + } + return E.New("all credentials rate-limited, earliest reset in ", log.FormatDuration(time.Until(earliest))) +} diff --git a/service/ccm/credential_state.go b/service/ccm/credential_state.go deleted file mode 100644 index b07529eb0..000000000 --- a/service/ccm/credential_state.go +++ /dev/null @@ -1,1506 +0,0 @@ -package ccm - -import ( - "bytes" - "context" - stdTLS "crypto/tls" - "encoding/json" - "io" - "math" - "math/rand/v2" - "net" - "net/http" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/sagernet/fswatch" - "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/common/dialer" - C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing-box/log" - "github.com/sagernet/sing-box/option" - E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" - "github.com/sagernet/sing/common/ntp" -) - -const ( - defaultPollInterval = 60 * time.Minute - failedPollRetryInterval = time.Minute - httpRetryMaxBackoff = 5 * time.Minute -) - -const ( - httpRetryMaxAttempts = 3 - httpRetryInitialDelay = 200 * time.Millisecond -) - -func doHTTPWithRetry(ctx context.Context, client *http.Client, buildRequest func() (*http.Request, error)) (*http.Response, error) { - var lastError error - for attempt := range httpRetryMaxAttempts { - if attempt > 0 { - delay := httpRetryInitialDelay * time.Duration(1<<(attempt-1)) - select { - case <-ctx.Done(): - return nil, lastError - case <-time.After(delay): - } - } - request, err := buildRequest() - if err != nil { - return nil, err - } - response, err := client.Do(request) - if err == nil { - return response, nil - } - lastError = err - if ctx.Err() != nil { - return nil, lastError - } - } - return nil, lastError -} - -type credentialState struct { - fiveHourUtilization float64 - fiveHourReset time.Time - weeklyUtilization float64 - weeklyReset time.Time - hardRateLimited bool - rateLimitResetAt time.Time - accountType string - rateLimitTier string - remotePlanWeight float64 - lastUpdated time.Time - consecutivePollFailures int - unavailable bool - lastCredentialLoadAttempt time.Time - lastCredentialLoadError string -} - -type defaultCredential struct { - tag string - serviceContext context.Context - credentialPath string - credentialFilePath string - credentials *oauthCredentials - accessMutex sync.RWMutex - state credentialState - stateMutex sync.RWMutex - pollAccess sync.Mutex - reloadAccess sync.Mutex - watcherAccess sync.Mutex - cap5h float64 - capWeekly float64 - usageTracker *AggregatedUsage - httpClient *http.Client - logger log.ContextLogger - watcher *fswatch.Watcher - watcherRetryAt time.Time - - // Connection interruption - onBecameUnusable func() - interrupted bool - requestContext context.Context - cancelRequests context.CancelFunc - requestAccess sync.Mutex -} - -type credentialRequestContext struct { - context.Context - releaseOnce sync.Once - cancelOnce sync.Once - releaseFuncs []func() bool - cancelFunc context.CancelFunc -} - -func (c *credentialRequestContext) addInterruptLink(stop func() bool) { - c.releaseFuncs = append(c.releaseFuncs, stop) -} - -func (c *credentialRequestContext) releaseCredentialInterrupt() { - c.releaseOnce.Do(func() { - for _, f := range c.releaseFuncs { - f() - } - }) -} - -func (c *credentialRequestContext) cancelRequest() { - c.releaseCredentialInterrupt() - c.cancelOnce.Do(c.cancelFunc) -} - -type credential interface { - tagName() string - isAvailable() bool - isUsable() bool - isExternal() bool - fiveHourUtilization() float64 - weeklyUtilization() float64 - fiveHourCap() float64 - weeklyCap() float64 - planWeight() float64 - weeklyResetTime() time.Time - markRateLimited(resetAt time.Time) - earliestReset() time.Time - unavailableError() error - - getAccessToken() (string, error) - buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) - updateStateFromHeaders(header http.Header) - - wrapRequestContext(ctx context.Context) *credentialRequestContext - interruptConnections() - - start() error - pollUsage(ctx context.Context) - lastUpdatedTime() time.Time - pollBackoff(base time.Duration) time.Duration - usageTrackerOrNil() *AggregatedUsage - httpTransport() *http.Client - close() -} - -func newDefaultCredential(ctx context.Context, tag string, options option.CCMDefaultCredentialOptions, logger log.ContextLogger) (*defaultCredential, error) { - credentialDialer, err := dialer.NewWithOptions(dialer.Options{ - Context: ctx, - Options: option.DialerOptions{ - Detour: options.Detour, - }, - RemoteIsDomain: true, - }) - if err != nil { - return nil, E.Cause(err, "create dialer for credential ", tag) - } - httpClient := &http.Client{ - Transport: &http.Transport{ - ForceAttemptHTTP2: true, - TLSClientConfig: &stdTLS.Config{ - RootCAs: adapter.RootPoolFromContext(ctx), - Time: ntp.TimeFuncFromContext(ctx), - }, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) - }, - }, - } - reserve5h := options.Reserve5h - if reserve5h == 0 { - reserve5h = 1 - } - reserveWeekly := options.ReserveWeekly - if reserveWeekly == 0 { - reserveWeekly = 1 - } - var cap5h float64 - if options.Limit5h > 0 { - cap5h = float64(options.Limit5h) - } else { - cap5h = float64(100 - reserve5h) - } - var capWeekly float64 - if options.LimitWeekly > 0 { - capWeekly = float64(options.LimitWeekly) - } else { - capWeekly = float64(100 - reserveWeekly) - } - requestContext, cancelRequests := context.WithCancel(context.Background()) - credential := &defaultCredential{ - tag: tag, - serviceContext: ctx, - credentialPath: options.CredentialPath, - cap5h: cap5h, - capWeekly: capWeekly, - httpClient: httpClient, - logger: logger, - requestContext: requestContext, - cancelRequests: cancelRequests, - } - if options.UsagesPath != "" { - credential.usageTracker = &AggregatedUsage{ - LastUpdated: time.Now(), - Combinations: make([]CostCombination, 0), - filePath: options.UsagesPath, - logger: logger, - } - } - return credential, nil -} - -func (c *defaultCredential) start() error { - credentialFilePath, err := resolveCredentialFilePath(c.credentialPath) - if err != nil { - return E.Cause(err, "resolve credential path for ", c.tag) - } - c.credentialFilePath = credentialFilePath - err = c.ensureCredentialWatcher() - if err != nil { - c.logger.Debug("start credential watcher for ", c.tag, ": ", err) - } - err = c.reloadCredentials(true) - if err != nil { - c.logger.Warn("initial credential load for ", c.tag, ": ", err) - } - if c.usageTracker != nil { - err = c.usageTracker.Load() - if err != nil { - c.logger.Warn("load usage statistics for ", c.tag, ": ", err) - } - } - return nil -} - -func (c *defaultCredential) getAccessToken() (string, error) { - c.retryCredentialReloadIfNeeded() - - c.accessMutex.RLock() - if c.credentials != nil && !c.credentials.needsRefresh() { - token := c.credentials.AccessToken - c.accessMutex.RUnlock() - return token, nil - } - c.accessMutex.RUnlock() - - err := c.reloadCredentials(true) - if err == nil { - c.accessMutex.RLock() - if c.credentials != nil && !c.credentials.needsRefresh() { - token := c.credentials.AccessToken - c.accessMutex.RUnlock() - return token, nil - } - c.accessMutex.RUnlock() - } - - c.accessMutex.Lock() - defer c.accessMutex.Unlock() - - if c.credentials == nil { - return "", c.unavailableError() - } - if !c.credentials.needsRefresh() { - return c.credentials.AccessToken, nil - } - - err = platformCanWriteCredentials(c.credentialPath) - if err != nil { - return "", E.Cause(err, "credential file not writable, refusing refresh to avoid invalidation") - } - - baseCredentials := cloneCredentials(c.credentials) - newCredentials, err := refreshToken(c.serviceContext, c.httpClient, c.credentials) - if err != nil { - return "", err - } - - latestCredentials, latestErr := platformReadCredentials(c.credentialPath) - if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) { - c.credentials = latestCredentials - c.stateMutex.Lock() - c.state.unavailable = false - c.state.lastCredentialLoadAttempt = time.Now() - c.state.lastCredentialLoadError = "" - c.state.accountType = latestCredentials.SubscriptionType - c.state.rateLimitTier = latestCredentials.RateLimitTier - c.checkTransitionLocked() - c.stateMutex.Unlock() - if !latestCredentials.needsRefresh() { - return latestCredentials.AccessToken, nil - } - return "", E.New("credential ", c.tag, " changed while refreshing") - } - - c.credentials = newCredentials - c.stateMutex.Lock() - c.state.unavailable = false - c.state.lastCredentialLoadAttempt = time.Now() - c.state.lastCredentialLoadError = "" - c.state.accountType = newCredentials.SubscriptionType - c.state.rateLimitTier = newCredentials.RateLimitTier - c.checkTransitionLocked() - c.stateMutex.Unlock() - - err = platformWriteCredentials(newCredentials, c.credentialPath) - if err != nil { - c.logger.Error("persist refreshed token for ", c.tag, ": ", err) - } - - return newCredentials.AccessToken, nil -} - -// Claude Code's unified rate-limit handling parses these reset headers with -// Number(...), compares them against Date.now()/1000, and renders them via -// new Date(seconds*1000), so keep the wire format pinned to Unix epoch seconds. -func parseAnthropicResetHeaderValue(headerName string, headerValue string) time.Time { - unixEpoch, err := strconv.ParseInt(headerValue, 10, 64) - if err != nil { - panic("invalid " + headerName + " header: expected Unix epoch seconds, got " + strconv.Quote(headerValue)) - } - if unixEpoch <= 0 { - panic("invalid " + headerName + " header: expected positive Unix epoch seconds, got " + strconv.Quote(headerValue)) - } - return time.Unix(unixEpoch, 0) -} - -func parseOptionalAnthropicResetHeader(headers http.Header, headerName string) (time.Time, bool) { - headerValue := headers.Get(headerName) - if headerValue == "" { - return time.Time{}, false - } - return parseAnthropicResetHeaderValue(headerName, headerValue), true -} - -func parseRequiredAnthropicResetHeader(headers http.Header, headerName string) time.Time { - headerValue := headers.Get(headerName) - if headerValue == "" { - panic("missing required " + headerName + " header") - } - return parseAnthropicResetHeaderValue(headerName, headerValue) -} - -func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { - c.stateMutex.Lock() - isFirstUpdate := c.state.lastUpdated.IsZero() - oldFiveHour := c.state.fiveHourUtilization - oldWeekly := c.state.weeklyUtilization - hadData := false - - fiveHourResetChanged := false - if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset"); exists { - hadData = true - if value.After(c.state.fiveHourReset) { - fiveHourResetChanged = true - c.state.fiveHourReset = value - } - } - if utilization := headers.Get("anthropic-ratelimit-unified-5h-utilization"); utilization != "" { - value, err := strconv.ParseFloat(utilization, 64) - if err == nil { - hadData = true - newValue := math.Ceil(value * 100) - if newValue >= c.state.fiveHourUtilization || fiveHourResetChanged { - c.state.fiveHourUtilization = newValue - } - } - } - - weeklyResetChanged := false - if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset"); exists { - hadData = true - if value.After(c.state.weeklyReset) { - weeklyResetChanged = true - c.state.weeklyReset = value - } - } - if utilization := headers.Get("anthropic-ratelimit-unified-7d-utilization"); utilization != "" { - value, err := strconv.ParseFloat(utilization, 64) - if err == nil { - hadData = true - newValue := math.Ceil(value * 100) - if newValue >= c.state.weeklyUtilization || weeklyResetChanged { - c.state.weeklyUtilization = newValue - } - } - } - if hadData { - c.state.consecutivePollFailures = 0 - c.state.lastUpdated = time.Now() - } - if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { - resetSuffix := "" - if !c.state.weeklyReset.IsZero() { - resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) - } - c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) - } - shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() - if shouldInterrupt { - c.interruptConnections() - } -} - -func (c *defaultCredential) markRateLimited(resetAt time.Time) { - c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt))) - c.stateMutex.Lock() - c.state.hardRateLimited = true - c.state.rateLimitResetAt = resetAt - shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() - if shouldInterrupt { - c.interruptConnections() - } -} - -func (c *defaultCredential) isUsable() bool { - c.retryCredentialReloadIfNeeded() - - c.stateMutex.RLock() - if c.state.unavailable { - c.stateMutex.RUnlock() - return false - } - if c.state.consecutivePollFailures > 0 { - c.stateMutex.RUnlock() - return false - } - if c.state.hardRateLimited { - if time.Now().Before(c.state.rateLimitResetAt) { - c.stateMutex.RUnlock() - return false - } - c.stateMutex.RUnlock() - c.stateMutex.Lock() - if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) { - c.state.hardRateLimited = false - } - usable := c.checkReservesLocked() - c.stateMutex.Unlock() - return usable - } - usable := c.checkReservesLocked() - c.stateMutex.RUnlock() - return usable -} - -func (c *defaultCredential) checkReservesLocked() bool { - if c.state.fiveHourUtilization >= c.cap5h { - return false - } - if c.state.weeklyUtilization >= c.capWeekly { - return false - } - return true -} - -// checkTransitionLocked detects usable→unusable transition. -// Must be called with stateMutex write lock held. -func (c *defaultCredential) checkTransitionLocked() bool { - unusable := c.state.unavailable || c.state.hardRateLimited || !c.checkReservesLocked() || c.state.consecutivePollFailures > 0 - if unusable && !c.interrupted { - c.interrupted = true - return true - } - if !unusable && c.interrupted { - c.interrupted = false - } - return false -} - -func (c *defaultCredential) interruptConnections() { - c.logger.Warn("interrupting connections for ", c.tag) - c.requestAccess.Lock() - c.cancelRequests() - c.requestContext, c.cancelRequests = context.WithCancel(context.Background()) - c.requestAccess.Unlock() - if c.onBecameUnusable != nil { - c.onBecameUnusable() - } -} - -func (c *defaultCredential) wrapRequestContext(parent context.Context) *credentialRequestContext { - c.requestAccess.Lock() - credentialContext := c.requestContext - c.requestAccess.Unlock() - derived, cancel := context.WithCancel(parent) - stop := context.AfterFunc(credentialContext, func() { - cancel() - }) - return &credentialRequestContext{ - Context: derived, - releaseFuncs: []func() bool{stop}, - cancelFunc: cancel, - } -} - -func (c *defaultCredential) weeklyUtilization() float64 { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - return c.state.weeklyUtilization -} - -func (c *defaultCredential) planWeight() float64 { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - return ccmPlanWeight(c.state.accountType, c.state.rateLimitTier) -} - -func (c *defaultCredential) weeklyResetTime() time.Time { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - return c.state.weeklyReset -} - -func (c *defaultCredential) isAvailable() bool { - c.retryCredentialReloadIfNeeded() - - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - return !c.state.unavailable -} - -func (c *defaultCredential) unavailableError() error { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - if !c.state.unavailable { - return nil - } - if c.state.lastCredentialLoadError == "" { - return E.New("credential ", c.tag, " is unavailable") - } - return E.New("credential ", c.tag, " is unavailable: ", c.state.lastCredentialLoadError) -} - -func (c *defaultCredential) lastUpdatedTime() time.Time { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - return c.state.lastUpdated -} - -func (c *defaultCredential) markUsagePollAttempted() { - c.stateMutex.Lock() - defer c.stateMutex.Unlock() - c.state.lastUpdated = time.Now() -} - -func (c *defaultCredential) incrementPollFailures() { - c.stateMutex.Lock() - c.state.consecutivePollFailures++ - shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() - if shouldInterrupt { - c.interruptConnections() - } -} - -func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duration { - c.stateMutex.RLock() - failures := c.state.consecutivePollFailures - c.stateMutex.RUnlock() - if failures <= 0 { - return baseInterval - } - backoff := failedPollRetryInterval * time.Duration(1<<(failures-1)) - if backoff > httpRetryMaxBackoff { - return httpRetryMaxBackoff - } - return backoff -} - -func (c *defaultCredential) isPollBackoffAtCap() bool { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - failures := c.state.consecutivePollFailures - return failures > 0 && failedPollRetryInterval*time.Duration(1<<(failures-1)) >= httpRetryMaxBackoff -} - -func (c *defaultCredential) earliestReset() time.Time { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - if c.state.unavailable { - return time.Time{} - } - if c.state.hardRateLimited { - return c.state.rateLimitResetAt - } - earliest := c.state.fiveHourReset - if !c.state.weeklyReset.IsZero() && (earliest.IsZero() || c.state.weeklyReset.Before(earliest)) { - earliest = c.state.weeklyReset - } - return earliest -} - -func (c *defaultCredential) pollUsage(ctx context.Context) { - if !c.pollAccess.TryLock() { - return - } - defer c.pollAccess.Unlock() - defer c.markUsagePollAttempted() - - c.retryCredentialReloadIfNeeded() - if !c.isAvailable() { - return - } - - accessToken, err := c.getAccessToken() - if err != nil { - if !c.isPollBackoffAtCap() { - c.logger.Error("poll usage for ", c.tag, ": get token: ", err) - } - c.incrementPollFailures() - return - } - - httpClient := &http.Client{ - Transport: c.httpClient.Transport, - Timeout: 5 * time.Second, - } - - response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { - request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/usage", nil) - if err != nil { - return nil, err - } - request.Header.Set("Authorization", "Bearer "+accessToken) - request.Header.Set("Content-Type", "application/json") - request.Header.Set("User-Agent", ccmUserAgentValue) - request.Header.Set("anthropic-beta", anthropicBetaOAuthValue) - return request, nil - }) - if err != nil { - if !c.isPollBackoffAtCap() { - c.logger.Error("poll usage for ", c.tag, ": ", err) - } - c.incrementPollFailures() - return - } - defer response.Body.Close() - - if response.StatusCode != http.StatusOK { - if response.StatusCode == http.StatusTooManyRequests { - c.logger.Warn("poll usage for ", c.tag, ": rate limited") - } - body, _ := io.ReadAll(response.Body) - c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) - c.incrementPollFailures() - return - } - - var usageResponse struct { - FiveHour struct { - Utilization float64 `json:"utilization"` - ResetsAt time.Time `json:"resets_at"` - } `json:"five_hour"` - SevenDay struct { - Utilization float64 `json:"utilization"` - ResetsAt time.Time `json:"resets_at"` - } `json:"seven_day"` - } - err = json.NewDecoder(response.Body).Decode(&usageResponse) - if err != nil { - c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) - c.incrementPollFailures() - return - } - - c.stateMutex.Lock() - isFirstUpdate := c.state.lastUpdated.IsZero() - oldFiveHour := c.state.fiveHourUtilization - oldWeekly := c.state.weeklyUtilization - c.state.consecutivePollFailures = 0 - c.state.fiveHourUtilization = usageResponse.FiveHour.Utilization - if !usageResponse.FiveHour.ResetsAt.IsZero() { - c.state.fiveHourReset = usageResponse.FiveHour.ResetsAt - } - c.state.weeklyUtilization = usageResponse.SevenDay.Utilization - if !usageResponse.SevenDay.ResetsAt.IsZero() { - c.state.weeklyReset = usageResponse.SevenDay.ResetsAt - } - if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { - c.state.hardRateLimited = false - } - if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { - resetSuffix := "" - if !c.state.weeklyReset.IsZero() { - resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) - } - c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) - } - needsProfileFetch := c.state.rateLimitTier == "" - shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() - if shouldInterrupt { - c.interruptConnections() - } - - if needsProfileFetch { - c.fetchProfile(ctx, httpClient, accessToken) - } -} - -func (c *defaultCredential) fetchProfile(ctx context.Context, httpClient *http.Client, accessToken string) { - response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { - request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/profile", nil) - if err != nil { - return nil, err - } - request.Header.Set("Authorization", "Bearer "+accessToken) - request.Header.Set("Content-Type", "application/json") - request.Header.Set("User-Agent", ccmUserAgentValue) - return request, nil - }) - if err != nil { - c.logger.Debug("fetch profile for ", c.tag, ": ", err) - return - } - defer response.Body.Close() - - if response.StatusCode != http.StatusOK { - return - } - - var profileResponse struct { - Organization *struct { - OrganizationType string `json:"organization_type"` - RateLimitTier string `json:"rate_limit_tier"` - } `json:"organization"` - } - err = json.NewDecoder(response.Body).Decode(&profileResponse) - if err != nil || profileResponse.Organization == nil { - return - } - - accountType := "" - switch profileResponse.Organization.OrganizationType { - case "claude_pro": - accountType = "pro" - case "claude_max": - accountType = "max" - case "claude_team": - accountType = "team" - case "claude_enterprise": - accountType = "enterprise" - } - rateLimitTier := profileResponse.Organization.RateLimitTier - - c.stateMutex.Lock() - if accountType != "" && c.state.accountType == "" { - c.state.accountType = accountType - } - if rateLimitTier != "" { - c.state.rateLimitTier = rateLimitTier - } - c.stateMutex.Unlock() - c.logger.Info("fetched profile for ", c.tag, ": type=", c.state.accountType, ", tier=", rateLimitTier, ", weight=", ccmPlanWeight(c.state.accountType, rateLimitTier)) -} - -func (c *defaultCredential) close() { - if c.watcher != nil { - err := c.watcher.Close() - if err != nil { - c.logger.Error("close credential watcher for ", c.tag, ": ", err) - } - } - if c.usageTracker != nil { - c.usageTracker.cancelPendingSave() - err := c.usageTracker.Save() - if err != nil { - c.logger.Error("save usage statistics for ", c.tag, ": ", err) - } - } -} - -func (c *defaultCredential) tagName() string { - return c.tag -} - -func (c *defaultCredential) isExternal() bool { - return false -} - -func (c *defaultCredential) fiveHourUtilization() float64 { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - return c.state.fiveHourUtilization -} - -func (c *defaultCredential) fiveHourCap() float64 { - return c.cap5h -} - -func (c *defaultCredential) weeklyCap() float64 { - return c.capWeekly -} - -func (c *defaultCredential) usageTrackerOrNil() *AggregatedUsage { - return c.usageTracker -} - -func (c *defaultCredential) httpTransport() *http.Client { - return c.httpClient -} - -func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) { - accessToken, err := c.getAccessToken() - if err != nil { - return nil, E.Cause(err, "get access token for ", c.tag) - } - - proxyURL := claudeAPIBaseURL + original.URL.RequestURI() - var body io.Reader - if bodyBytes != nil { - body = bytes.NewReader(bodyBytes) - } else { - body = original.Body - } - proxyRequest, err := http.NewRequestWithContext(ctx, original.Method, proxyURL, body) - if err != nil { - return nil, err - } - - for key, values := range original.Header { - if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && key != "Authorization" { - proxyRequest.Header[key] = values - } - } - - serviceOverridesAcceptEncoding := len(serviceHeaders.Values("Accept-Encoding")) > 0 - if c.usageTracker != nil && !serviceOverridesAcceptEncoding { - proxyRequest.Header.Del("Accept-Encoding") - } - - anthropicBetaHeader := proxyRequest.Header.Get("anthropic-beta") - if anthropicBetaHeader != "" { - proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue+","+anthropicBetaHeader) - } else { - proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue) - } - - for key, values := range serviceHeaders { - proxyRequest.Header.Del(key) - proxyRequest.Header[key] = values - } - proxyRequest.Header.Set("Authorization", "Bearer "+accessToken) - - return proxyRequest, nil -} - -// credentialProvider is the interface for all credential types. -type credentialProvider interface { - selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) - onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential - linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool - pollIfStale(ctx context.Context) - allCredentials() []credential - close() -} - -type credentialSelectionScope string - -const ( - credentialSelectionScopeAll credentialSelectionScope = "all" - credentialSelectionScopeNonExternal credentialSelectionScope = "non_external" -) - -type credentialSelection struct { - scope credentialSelectionScope - filter func(credential) bool -} - -func (s credentialSelection) allows(cred credential) bool { - return s.filter == nil || s.filter(cred) -} - -func (s credentialSelection) scopeOrDefault() credentialSelectionScope { - if s.scope == "" { - return credentialSelectionScopeAll - } - return s.scope -} - -// singleCredentialProvider wraps a single credential (legacy or single default). -type singleCredentialProvider struct { - cred credential - sessionAccess sync.RWMutex - sessions map[string]time.Time -} - -func (p *singleCredentialProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) { - if !selection.allows(p.cred) { - return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out") - } - if !p.cred.isAvailable() { - return nil, false, p.cred.unavailableError() - } - if !p.cred.isUsable() { - return nil, false, E.New("credential ", p.cred.tagName(), " is rate-limited") - } - var isNew bool - if sessionID != "" { - p.sessionAccess.Lock() - if p.sessions == nil { - p.sessions = make(map[string]time.Time) - } - _, exists := p.sessions[sessionID] - if !exists { - p.sessions[sessionID] = time.Now() - isNew = true - } - p.sessionAccess.Unlock() - } - return p.cred, isNew, nil -} - -func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ credentialSelection) credential { - cred.markRateLimited(resetAt) - return nil -} - -func (p *singleCredentialProvider) pollIfStale(ctx context.Context) { - now := time.Now() - p.sessionAccess.Lock() - for id, createdAt := range p.sessions { - if now.Sub(createdAt) > sessionExpiry { - delete(p.sessions, id) - } - } - p.sessionAccess.Unlock() - - if time.Since(p.cred.lastUpdatedTime()) > p.cred.pollBackoff(defaultPollInterval) { - p.cred.pollUsage(ctx) - } -} - -func (p *singleCredentialProvider) allCredentials() []credential { - return []credential{p.cred} -} - -func (p *singleCredentialProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool { - return func() bool { - return false - } -} - -func (p *singleCredentialProvider) close() {} - -const sessionExpiry = 24 * time.Hour - -type sessionEntry struct { - tag string - selectionScope credentialSelectionScope - createdAt time.Time -} - -type credentialInterruptKey struct { - tag string - selectionScope credentialSelectionScope -} - -type credentialInterruptEntry struct { - context context.Context - cancel context.CancelFunc -} - -// balancerProvider assigns sessions to credentials based on a configurable strategy. -type balancerProvider struct { - credentials []credential - strategy string - roundRobinIndex atomic.Uint64 - pollInterval time.Duration - rebalanceThreshold float64 - sessionMutex sync.RWMutex - sessions map[string]sessionEntry - interruptAccess sync.Mutex - credentialInterrupts map[credentialInterruptKey]credentialInterruptEntry - logger log.ContextLogger -} - -func newBalancerProvider(credentials []credential, strategy string, pollInterval time.Duration, rebalanceThreshold float64, logger log.ContextLogger) *balancerProvider { - if pollInterval <= 0 { - pollInterval = defaultPollInterval - } - return &balancerProvider{ - credentials: credentials, - strategy: strategy, - pollInterval: pollInterval, - rebalanceThreshold: rebalanceThreshold, - sessions: make(map[string]sessionEntry), - credentialInterrupts: make(map[credentialInterruptKey]credentialInterruptEntry), - logger: logger, - } -} - -func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) { - if p.strategy == C.BalancerStrategyFallback { - best := p.pickCredential(selection.filter) - if best == nil { - return nil, false, allCredentialsUnavailableError(p.credentials) - } - return best, false, nil - } - - selectionScope := selection.scopeOrDefault() - if sessionID != "" { - p.sessionMutex.RLock() - entry, exists := p.sessions[sessionID] - p.sessionMutex.RUnlock() - if exists { - if entry.selectionScope == selectionScope { - for _, cred := range p.credentials { - if cred.tagName() == entry.tag && selection.allows(cred) && cred.isUsable() { - if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == C.BalancerStrategyLeastUsed) { - better := p.pickLeastUsed(selection.filter) - if better != nil && better.tagName() != cred.tagName() { - effectiveThreshold := p.rebalanceThreshold / cred.planWeight() - delta := cred.weeklyUtilization() - better.weeklyUtilization() - if delta > effectiveThreshold { - p.logger.Info("rebalancing away from ", cred.tagName(), - ": utilization delta ", delta, "% exceeds effective threshold ", - effectiveThreshold, "% (weight ", cred.planWeight(), ")") - p.rebalanceCredential(cred.tagName(), selectionScope) - break - } - } - } - return cred, false, nil - } - } - } - p.sessionMutex.Lock() - delete(p.sessions, sessionID) - p.sessionMutex.Unlock() - } - } - - best := p.pickCredential(selection.filter) - if best == nil { - return nil, false, allCredentialsUnavailableError(p.credentials) - } - - isNew := sessionID != "" - if isNew { - p.sessionMutex.Lock() - p.sessions[sessionID] = sessionEntry{ - tag: best.tagName(), - selectionScope: selectionScope, - createdAt: time.Now(), - } - p.sessionMutex.Unlock() - } - return best, isNew, nil -} - -func (p *balancerProvider) rebalanceCredential(tag string, selectionScope credentialSelectionScope) { - key := credentialInterruptKey{tag: tag, selectionScope: selectionScope} - p.interruptAccess.Lock() - if entry, loaded := p.credentialInterrupts[key]; loaded { - entry.cancel() - } - ctx, cancel := context.WithCancel(context.Background()) - p.credentialInterrupts[key] = credentialInterruptEntry{context: ctx, cancel: cancel} - p.interruptAccess.Unlock() - - p.sessionMutex.Lock() - for id, entry := range p.sessions { - if entry.tag == tag && entry.selectionScope == selectionScope { - delete(p.sessions, id) - } - } - p.sessionMutex.Unlock() -} - -func (p *balancerProvider) linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool { - if p.strategy == C.BalancerStrategyFallback { - return func() bool { return false } - } - key := credentialInterruptKey{ - tag: cred.tagName(), - selectionScope: selection.scopeOrDefault(), - } - p.interruptAccess.Lock() - entry, loaded := p.credentialInterrupts[key] - if !loaded { - ctx, cancel := context.WithCancel(context.Background()) - entry = credentialInterruptEntry{context: ctx, cancel: cancel} - p.credentialInterrupts[key] = entry - } - p.interruptAccess.Unlock() - return context.AfterFunc(entry.context, onInterrupt) -} - -func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential { - cred.markRateLimited(resetAt) - if p.strategy == C.BalancerStrategyFallback { - return p.pickCredential(selection.filter) - } - if sessionID != "" { - p.sessionMutex.Lock() - delete(p.sessions, sessionID) - p.sessionMutex.Unlock() - } - - best := p.pickCredential(selection.filter) - if best != nil && sessionID != "" { - p.sessionMutex.Lock() - p.sessions[sessionID] = sessionEntry{ - tag: best.tagName(), - selectionScope: selection.scopeOrDefault(), - createdAt: time.Now(), - } - p.sessionMutex.Unlock() - } - return best -} - -func (p *balancerProvider) pickCredential(filter func(credential) bool) credential { - switch p.strategy { - case C.BalancerStrategyRoundRobin: - return p.pickRoundRobin(filter) - case C.BalancerStrategyRandom: - return p.pickRandom(filter) - case C.BalancerStrategyFallback: - return p.pickFallback(filter) - default: - return p.pickLeastUsed(filter) - } -} - -func (p *balancerProvider) pickFallback(filter func(credential) bool) credential { - for _, cred := range p.credentials { - if filter != nil && !filter(cred) { - continue - } - if cred.isUsable() { - return cred - } - } - return nil -} - -func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credential { - var best credential - bestScore := float64(-1) - now := time.Now() - for _, cred := range p.credentials { - if filter != nil && !filter(cred) { - continue - } - if !cred.isUsable() { - continue - } - remaining := cred.weeklyCap() - cred.weeklyUtilization() - score := remaining * cred.planWeight() - resetTime := cred.weeklyResetTime() - if !resetTime.IsZero() { - timeUntilReset := resetTime.Sub(now) - if timeUntilReset < time.Hour { - timeUntilReset = time.Hour - } - score *= weeklyWindowDuration / timeUntilReset.Hours() - } - if score > bestScore { - bestScore = score - best = cred - } - } - return best -} - -const weeklyWindowDuration = 7 * 24 // hours - -func ccmPlanWeight(accountType string, rateLimitTier string) float64 { - switch accountType { - case "max": - switch rateLimitTier { - case "default_claude_max_20x": - return 10 - case "default_claude_max_5x": - return 5 - default: - return 5 - } - case "team": - if rateLimitTier == "default_claude_max_5x" { - return 5 - } - return 1 - default: - return 1 - } -} - -func (p *balancerProvider) pickRoundRobin(filter func(credential) bool) credential { - start := int(p.roundRobinIndex.Add(1) - 1) - count := len(p.credentials) - for offset := range count { - candidate := p.credentials[(start+offset)%count] - if filter != nil && !filter(candidate) { - continue - } - if candidate.isUsable() { - return candidate - } - } - return nil -} - -func (p *balancerProvider) pickRandom(filter func(credential) bool) credential { - var usable []credential - for _, candidate := range p.credentials { - if filter != nil && !filter(candidate) { - continue - } - if candidate.isUsable() { - usable = append(usable, candidate) - } - } - if len(usable) == 0 { - return nil - } - return usable[rand.IntN(len(usable))] -} - -func (p *balancerProvider) pollIfStale(ctx context.Context) { - now := time.Now() - p.sessionMutex.Lock() - for id, entry := range p.sessions { - if now.Sub(entry.createdAt) > sessionExpiry { - delete(p.sessions, id) - } - } - p.sessionMutex.Unlock() - - for _, cred := range p.credentials { - if time.Since(cred.lastUpdatedTime()) > cred.pollBackoff(p.pollInterval) { - cred.pollUsage(ctx) - } - } -} - -func (p *balancerProvider) allCredentials() []credential { - return p.credentials -} - -func (p *balancerProvider) close() {} - -func allCredentialsUnavailableError(credentials []credential) error { - var hasUnavailable bool - var earliest time.Time - for _, cred := range credentials { - if cred.unavailableError() != nil { - hasUnavailable = true - continue - } - resetAt := cred.earliestReset() - if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) { - earliest = resetAt - } - } - if hasUnavailable { - return E.New("all credentials unavailable") - } - if earliest.IsZero() { - return E.New("all credentials rate-limited") - } - return E.New("all credentials rate-limited, earliest reset in ", log.FormatDuration(time.Until(earliest))) -} - -func extractCCMSessionID(bodyBytes []byte) string { - var body struct { - Metadata struct { - UserID string `json:"user_id"` - } `json:"metadata"` - } - err := json.Unmarshal(bodyBytes, &body) - if err != nil { - return "" - } - userID := body.Metadata.UserID - sessionIndex := strings.LastIndex(userID, "_session_") - if sessionIndex < 0 { - return "" - } - return userID[sessionIndex+len("_session_"):] -} - -func buildCredentialProviders( - ctx context.Context, - options option.CCMServiceOptions, - logger log.ContextLogger, -) (map[string]credentialProvider, []credential, error) { - allCredentialMap := make(map[string]credential) - var allCreds []credential - providers := make(map[string]credentialProvider) - - // Pass 1: create default and external credentials - for _, credOpt := range options.Credentials { - switch credOpt.Type { - case "default": - cred, err := newDefaultCredential(ctx, credOpt.Tag, credOpt.DefaultOptions, logger) - if err != nil { - return nil, nil, err - } - allCredentialMap[credOpt.Tag] = cred - allCreds = append(allCreds, cred) - providers[credOpt.Tag] = &singleCredentialProvider{cred: cred} - case "external": - cred, err := newExternalCredential(ctx, credOpt.Tag, credOpt.ExternalOptions, logger) - if err != nil { - return nil, nil, err - } - allCredentialMap[credOpt.Tag] = cred - allCreds = append(allCreds, cred) - providers[credOpt.Tag] = &singleCredentialProvider{cred: cred} - } - } - - // Pass 2: create balancer providers - for _, credOpt := range options.Credentials { - if credOpt.Type == "balancer" { - subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, allCredentialMap, credOpt.Tag) - if err != nil { - return nil, nil, err - } - providers[credOpt.Tag] = newBalancerProvider(subCredentials, credOpt.BalancerOptions.Strategy, time.Duration(credOpt.BalancerOptions.PollInterval), credOpt.BalancerOptions.RebalanceThreshold, logger) - } - } - - return providers, allCreds, nil -} - -func resolveCredentialTags(tags []string, allCredentials map[string]credential, parentTag string) ([]credential, error) { - credentials := make([]credential, 0, len(tags)) - for _, tag := range tags { - cred, exists := allCredentials[tag] - if !exists { - return nil, E.New("credential ", parentTag, " references unknown credential: ", tag) - } - credentials = append(credentials, cred) - } - if len(credentials) == 0 { - return nil, E.New("credential ", parentTag, " has no sub-credentials") - } - return credentials, nil -} - -func parseRateLimitResetFromHeaders(headers http.Header) time.Time { - claim := headers.Get("anthropic-ratelimit-unified-representative-claim") - switch claim { - case "5h": - return parseRequiredAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset") - case "7d": - return parseRequiredAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset") - default: - panic("invalid anthropic-ratelimit-unified-representative-claim header: " + strconv.Quote(claim)) - } -} - -func validateCCMOptions(options option.CCMServiceOptions) error { - hasCredentials := len(options.Credentials) > 0 - hasLegacyPath := options.CredentialPath != "" - hasLegacyUsages := options.UsagesPath != "" - hasLegacyDetour := options.Detour != "" - - if hasCredentials && hasLegacyPath { - return E.New("credential_path and credentials are mutually exclusive") - } - if hasCredentials && hasLegacyUsages { - return E.New("usages_path and credentials are mutually exclusive; use usages_path on individual credentials") - } - if hasCredentials && hasLegacyDetour { - return E.New("detour and credentials are mutually exclusive; use detour on individual credentials") - } - - if hasCredentials { - tags := make(map[string]bool) - credentialTypes := make(map[string]string) - for _, cred := range options.Credentials { - if tags[cred.Tag] { - return E.New("duplicate credential tag: ", cred.Tag) - } - tags[cred.Tag] = true - credentialTypes[cred.Tag] = cred.Type - if cred.Type == "default" || cred.Type == "" { - if cred.DefaultOptions.Reserve5h > 99 { - return E.New("credential ", cred.Tag, ": reserve_5h must be at most 99") - } - if cred.DefaultOptions.ReserveWeekly > 99 { - return E.New("credential ", cred.Tag, ": reserve_weekly must be at most 99") - } - if cred.DefaultOptions.Limit5h > 100 { - return E.New("credential ", cred.Tag, ": limit_5h must be at most 100") - } - if cred.DefaultOptions.LimitWeekly > 100 { - return E.New("credential ", cred.Tag, ": limit_weekly must be at most 100") - } - if cred.DefaultOptions.Reserve5h > 0 && cred.DefaultOptions.Limit5h > 0 { - return E.New("credential ", cred.Tag, ": reserve_5h and limit_5h are mutually exclusive") - } - if cred.DefaultOptions.ReserveWeekly > 0 && cred.DefaultOptions.LimitWeekly > 0 { - return E.New("credential ", cred.Tag, ": reserve_weekly and limit_weekly are mutually exclusive") - } - } - if cred.Type == "external" { - if cred.ExternalOptions.Token == "" { - return E.New("credential ", cred.Tag, ": external credential requires token") - } - if cred.ExternalOptions.Reverse && cred.ExternalOptions.URL == "" { - return E.New("credential ", cred.Tag, ": reverse external credential requires url") - } - } - if cred.Type == "balancer" { - switch cred.BalancerOptions.Strategy { - case "", C.BalancerStrategyLeastUsed, C.BalancerStrategyRoundRobin, C.BalancerStrategyRandom, C.BalancerStrategyFallback: - default: - return E.New("credential ", cred.Tag, ": unknown balancer strategy: ", cred.BalancerOptions.Strategy) - } - if cred.BalancerOptions.RebalanceThreshold < 0 { - return E.New("credential ", cred.Tag, ": rebalance_threshold must not be negative") - } - } - } - - for _, user := range options.Users { - if user.Credential == "" { - return E.New("user ", user.Name, " must specify credential in multi-credential mode") - } - if !tags[user.Credential] { - return E.New("user ", user.Name, " references unknown credential: ", user.Credential) - } - if user.ExternalCredential != "" { - if !tags[user.ExternalCredential] { - return E.New("user ", user.Name, " references unknown external_credential: ", user.ExternalCredential) - } - if credentialTypes[user.ExternalCredential] != "external" { - return E.New("user ", user.Name, ": external_credential must reference an external type credential") - } - } - } - } - - return nil -} - -// credentialForUser finds the credential provider for a user. -// In legacy mode, returns the single provider. -// In multi-credential mode, returns the provider mapped to the user's credential tag. -func credentialForUser( - userConfigMap map[string]*option.CCMUser, - providers map[string]credentialProvider, - legacyProvider credentialProvider, - username string, -) (credentialProvider, error) { - if legacyProvider != nil { - return legacyProvider, nil - } - userConfig, exists := userConfigMap[username] - if !exists { - return nil, E.New("no credential mapping for user: ", username) - } - provider, exists := providers[userConfig.Credential] - if !exists { - return nil, E.New("unknown credential: ", userConfig.Credential) - } - return provider, nil -} - -// noUserCredentialProvider returns the single provider for legacy mode or the first credential in multi-credential mode (no auth). -func noUserCredentialProvider( - providers map[string]credentialProvider, - legacyProvider credentialProvider, - options option.CCMServiceOptions, -) credentialProvider { - if legacyProvider != nil { - return legacyProvider - } - if len(options.Credentials) > 0 { - tag := options.Credentials[0].Tag - return providers[tag] - } - return nil -} diff --git a/service/ccm/service.go b/service/ccm/service.go index 6a2aa2b74..6dce1931b 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -1,17 +1,12 @@ package ccm import ( - "bytes" "context" "encoding/json" "errors" - "io" - "mime" "net/http" - "strconv" "strings" "sync" - "time" "github.com/sagernet/sing-box/adapter" boxService "github.com/sagernet/sing-box/adapter/service" @@ -21,23 +16,16 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" N "github.com/sagernet/sing/common/network" aTLS "github.com/sagernet/sing/common/tls" - "github.com/anthropics/anthropic-sdk-go" "github.com/go-chi/chi/v5" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" ) -const ( - contextWindowStandard = 200000 - contextWindowPremium = 1000000 - premiumContextThreshold = 200000 - retryableUsageMessage = "current credential reached its usage limit; retry the request to use another credential" -) +const retryableUsageMessage = "current credential reached its usage limit; retry the request to use another credential" func RegisterService(registry *boxService.Registry) { boxService.Register[option.CCMServiceOptions](registry, C.TypeCCM, NewService) @@ -152,23 +140,6 @@ func isReverseProxyHeader(header string) bool { } } -const ( - weeklyWindowSeconds = 604800 - weeklyWindowMinutes = weeklyWindowSeconds / 60 -) - -func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint { - resetAt, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset") - if !exists { - return nil - } - - return &WeeklyCycleHint{ - WindowMinutes: weeklyWindowMinutes, - ResetAt: resetAt.UTC(), - } -} - type Service struct { boxService.Adapter ctx context.Context @@ -308,545 +279,6 @@ func (s *Service) Start(stage adapter.StartStage) error { return nil } -func isExtendedContextRequest(betaHeader string) bool { - for _, feature := range strings.Split(betaHeader, ",") { - if strings.HasPrefix(strings.TrimSpace(feature), "context-1m") { - return true - } - } - return false -} - -func isFastModeRequest(betaHeader string) bool { - for _, feature := range strings.Split(betaHeader, ",") { - if strings.HasPrefix(strings.TrimSpace(feature), "fast-mode") { - return true - } - } - return false -} - -func detectContextWindow(betaHeader string, totalInputTokens int64) int { - if totalInputTokens > premiumContextThreshold { - if isExtendedContextRequest(betaHeader) { - return contextWindowPremium - } - } - return contextWindowStandard -} - -func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := log.ContextWithNewID(r.Context()) - if r.URL.Path == "/ccm/v1/status" { - s.handleStatusEndpoint(w, r) - return - } - - if r.URL.Path == "/ccm/v1/reverse" { - s.handleReverseConnect(ctx, w, r) - return - } - - if !strings.HasPrefix(r.URL.Path, "/v1/") { - writeJSONError(w, r, http.StatusNotFound, "not_found_error", "Not found") - return - } - - var username string - if len(s.options.Users) > 0 { - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header") - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") - return - } - clientToken := strings.TrimPrefix(authHeader, "Bearer ") - if clientToken == authHeader { - s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format") - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") - return - } - var ok bool - username, ok = s.userManager.Authenticate(clientToken) - if !ok { - s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken) - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key") - return - } - } - - // Always read body to extract model and session ID - var bodyBytes []byte - var requestModel string - var messagesCount int - var sessionID string - - if r.Body != nil { - var err error - bodyBytes, err = io.ReadAll(r.Body) - if err != nil { - s.logger.ErrorContext(ctx, "read request body: ", err) - writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body") - return - } - - var request struct { - Model string `json:"model"` - Messages []anthropic.MessageParam `json:"messages"` - } - err = json.Unmarshal(bodyBytes, &request) - if err == nil { - requestModel = request.Model - messagesCount = len(request.Messages) - } - - sessionID = extractCCMSessionID(bodyBytes) - r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - } - - // Resolve credential provider and user config - var provider credentialProvider - var userConfig *option.CCMUser - if len(s.options.Users) > 0 { - userConfig = s.userConfigMap[username] - var err error - provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) - if err != nil { - s.logger.ErrorContext(ctx, "resolve credential: ", err) - writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) - return - } - } else { - provider = noUserCredentialProvider(s.providers, s.legacyProvider, s.options) - } - if provider == nil { - writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available") - return - } - - provider.pollIfStale(s.ctx) - - anthropicBetaHeader := r.Header.Get("anthropic-beta") - if isFastModeRequest(anthropicBetaHeader) { - if _, isSingle := provider.(*singleCredentialProvider); !isSingle { - writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", - "fast mode requests will consume Extra usage, please use a default credential directly") - return - } - } - - selection := credentialSelectionForUser(userConfig) - - selectedCredential, isNew, err := provider.selectCredential(sessionID, selection) - if err != nil { - writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, err.Error())) - return - } - if isNew { - logParts := []any{"assigned credential ", selectedCredential.tagName()} - if sessionID != "" { - logParts = append(logParts, " for session ", sessionID) - } - if username != "" { - logParts = append(logParts, " by user ", username) - } - if requestModel != "" { - modelDisplay := requestModel - if isExtendedContextRequest(anthropicBetaHeader) { - modelDisplay += "[1m]" - } - logParts = append(logParts, ", model=", modelDisplay) - } - s.logger.DebugContext(ctx, logParts...) - } - - if isFastModeRequest(anthropicBetaHeader) && selectedCredential.isExternal() { - writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", - "fast mode requests cannot be proxied through external credentials") - return - } - - requestContext := selectedCredential.wrapRequestContext(ctx) - { - currentRequestContext := requestContext - requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() { - currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc) - })) - } - defer func() { - requestContext.cancelRequest() - }() - proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) - if err != nil { - s.logger.ErrorContext(ctx, "create proxy request: ", err) - writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error") - return - } - - response, err := selectedCredential.httpTransport().Do(proxyRequest) - if err != nil { - if r.Context().Err() != nil { - return - } - if requestContext.Err() != nil { - writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable while processing the request") - return - } - writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error()) - return - } - requestContext.releaseCredentialInterrupt() - - // Transparent 429 retry - for response.StatusCode == http.StatusTooManyRequests { - resetAt := parseRateLimitResetFromHeaders(response.Header) - nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection) - selectedCredential.updateStateFromHeaders(response.Header) - if bodyBytes == nil || nextCredential == nil { - response.Body.Close() - writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited") - return - } - response.Body.Close() - s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) - requestContext.cancelRequest() - requestContext = nextCredential.wrapRequestContext(ctx) - { - currentRequestContext := requestContext - requestContext.addInterruptLink(provider.linkProviderInterrupt(nextCredential, selection, func() { - currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc) - })) - } - retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) - if buildErr != nil { - s.logger.ErrorContext(ctx, "retry request: ", buildErr) - writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error()) - return - } - retryResponse, retryErr := nextCredential.httpTransport().Do(retryRequest) - if retryErr != nil { - if r.Context().Err() != nil { - return - } - if requestContext.Err() != nil { - writeCredentialUnavailableError(w, r, provider, nextCredential, selection, "credential became unavailable while retrying the request") - return - } - s.logger.ErrorContext(ctx, "retry request: ", retryErr) - writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error()) - return - } - requestContext.releaseCredentialInterrupt() - response = retryResponse - selectedCredential = nextCredential - } - defer response.Body.Close() - - selectedCredential.updateStateFromHeaders(response.Header) - - if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests { - body, _ := io.ReadAll(response.Body) - s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body)) - go selectedCredential.pollUsage(s.ctx) - writeJSONError(w, r, http.StatusInternalServerError, "api_error", - "proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body)) - return - } - - // Rewrite response headers for external users - if userConfig != nil && userConfig.ExternalCredential != "" { - s.rewriteResponseHeadersForExternalUser(response.Header, userConfig) - } - - for key, values := range response.Header { - if !isHopByHopHeader(key) && !isReverseProxyHeader(key) { - w.Header()[key] = values - } - } - w.WriteHeader(response.StatusCode) - - usageTracker := selectedCredential.usageTrackerOrNil() - if usageTracker != nil && response.StatusCode == http.StatusOK { - s.handleResponseWithTracking(ctx, w, response, usageTracker, requestModel, anthropicBetaHeader, messagesCount, username) - } else { - mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) - if err == nil && mediaType != "text/event-stream" { - _, _ = io.Copy(w, response.Body) - return - } - flusher, ok := w.(http.Flusher) - if !ok { - s.logger.ErrorContext(ctx, "streaming not supported") - return - } - buffer := make([]byte, buf.BufferSize) - for { - n, err := response.Body.Read(buffer) - if n > 0 { - _, writeError := w.Write(buffer[:n]) - if writeError != nil { - s.logger.ErrorContext(ctx, "write streaming response: ", writeError) - return - } - flusher.Flush() - } - if err != nil { - return - } - } - } -} - -func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, requestModel string, anthropicBetaHeader string, messagesCount int, username string) { - weeklyCycleHint := extractWeeklyCycleHint(response.Header) - mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) - isStreaming := err == nil && mediaType == "text/event-stream" - - if !isStreaming { - bodyBytes, err := io.ReadAll(response.Body) - if err != nil { - s.logger.ErrorContext(ctx, "read response body: ", err) - return - } - - var message anthropic.Message - var usage anthropic.Usage - var responseModel string - err = json.Unmarshal(bodyBytes, &message) - if err == nil { - responseModel = string(message.Model) - usage = message.Usage - } - if responseModel == "" { - responseModel = requestModel - } - - if usage.InputTokens > 0 || usage.OutputTokens > 0 { - if responseModel != "" { - totalInputTokens := usage.InputTokens + usage.CacheCreationInputTokens + usage.CacheReadInputTokens - contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens) - usageTracker.AddUsageWithCycleHint( - responseModel, - contextWindow, - messagesCount, - usage.InputTokens, - usage.OutputTokens, - usage.CacheReadInputTokens, - usage.CacheCreationInputTokens, - usage.CacheCreation.Ephemeral5mInputTokens, - usage.CacheCreation.Ephemeral1hInputTokens, - username, - time.Now(), - weeklyCycleHint, - ) - } - } - - _, _ = writer.Write(bodyBytes) - return - } - - flusher, ok := writer.(http.Flusher) - if !ok { - s.logger.ErrorContext(ctx, "streaming not supported") - return - } - - var accumulatedUsage anthropic.Usage - var responseModel string - buffer := make([]byte, buf.BufferSize) - var leftover []byte - - for { - n, err := response.Body.Read(buffer) - if n > 0 { - data := append(leftover, buffer[:n]...) - lines := bytes.Split(data, []byte("\n")) - - if err == nil { - leftover = lines[len(lines)-1] - lines = lines[:len(lines)-1] - } else { - leftover = nil - } - - for _, line := range lines { - line = bytes.TrimSpace(line) - if len(line) == 0 { - continue - } - - if bytes.HasPrefix(line, []byte("data: ")) { - eventData := bytes.TrimPrefix(line, []byte("data: ")) - if bytes.Equal(eventData, []byte("[DONE]")) { - continue - } - - var event anthropic.MessageStreamEventUnion - err := json.Unmarshal(eventData, &event) - if err != nil { - continue - } - switch event.Type { - case "message_start": - messageStart := event.AsMessageStart() - if messageStart.Message.Model != "" { - responseModel = string(messageStart.Message.Model) - } - if messageStart.Message.Usage.InputTokens > 0 { - accumulatedUsage.InputTokens = messageStart.Message.Usage.InputTokens - accumulatedUsage.CacheReadInputTokens = messageStart.Message.Usage.CacheReadInputTokens - accumulatedUsage.CacheCreationInputTokens = messageStart.Message.Usage.CacheCreationInputTokens - accumulatedUsage.CacheCreation.Ephemeral5mInputTokens = messageStart.Message.Usage.CacheCreation.Ephemeral5mInputTokens - accumulatedUsage.CacheCreation.Ephemeral1hInputTokens = messageStart.Message.Usage.CacheCreation.Ephemeral1hInputTokens - } - case "message_delta": - messageDelta := event.AsMessageDelta() - if messageDelta.Usage.OutputTokens > 0 { - accumulatedUsage.OutputTokens = messageDelta.Usage.OutputTokens - } - } - } - } - - _, writeError := writer.Write(buffer[:n]) - if writeError != nil { - s.logger.ErrorContext(ctx, "write streaming response: ", writeError) - return - } - flusher.Flush() - } - - if err != nil { - if responseModel == "" { - responseModel = requestModel - } - - if accumulatedUsage.InputTokens > 0 || accumulatedUsage.OutputTokens > 0 { - if responseModel != "" { - totalInputTokens := accumulatedUsage.InputTokens + accumulatedUsage.CacheCreationInputTokens + accumulatedUsage.CacheReadInputTokens - contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens) - usageTracker.AddUsageWithCycleHint( - responseModel, - contextWindow, - messagesCount, - accumulatedUsage.InputTokens, - accumulatedUsage.OutputTokens, - accumulatedUsage.CacheReadInputTokens, - accumulatedUsage.CacheCreationInputTokens, - accumulatedUsage.CacheCreation.Ephemeral5mInputTokens, - accumulatedUsage.CacheCreation.Ephemeral1hInputTokens, - username, - time.Now(), - weeklyCycleHint, - ) - } - } - return - } - } -} - -func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed") - return - } - - if len(s.options.Users) == 0 { - writeJSONError(w, r, http.StatusForbidden, "authentication_error", "status endpoint requires user authentication") - return - } - - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") - return - } - clientToken := strings.TrimPrefix(authHeader, "Bearer ") - if clientToken == authHeader { - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") - return - } - username, ok := s.userManager.Authenticate(clientToken) - if !ok { - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key") - return - } - - userConfig := s.userConfigMap[username] - if userConfig == nil { - writeJSONError(w, r, http.StatusInternalServerError, "api_error", "user config not found") - return - } - - provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) - if err != nil { - writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) - return - } - - provider.pollIfStale(r.Context()) - avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]float64{ - "five_hour_utilization": avgFiveHour, - "weekly_utilization": avgWeekly, - "plan_weight": totalWeight, - }) -} - -func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) (float64, float64, float64) { - var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64 - for _, cred := range provider.allCredentials() { - if !cred.isAvailable() { - continue - } - if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential { - continue - } - if !userConfig.AllowExternalUsage && cred.isExternal() { - continue - } - weight := cred.planWeight() - remaining5h := cred.fiveHourCap() - cred.fiveHourUtilization() - if remaining5h < 0 { - remaining5h = 0 - } - remainingWeekly := cred.weeklyCap() - cred.weeklyUtilization() - if remainingWeekly < 0 { - remainingWeekly = 0 - } - totalWeightedRemaining5h += remaining5h * weight - totalWeightedRemainingWeekly += remainingWeekly * weight - totalWeight += weight - } - if totalWeight == 0 { - return 100, 100, 0 - } - return 100 - totalWeightedRemaining5h/totalWeight, - 100 - totalWeightedRemainingWeekly/totalWeight, - totalWeight -} - -func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.CCMUser) { - provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, userConfig.Name) - if err != nil { - return - } - - avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) - - // Rewrite utilization headers to aggregated average (convert back to 0.0-1.0 range) - headers.Set("anthropic-ratelimit-unified-5h-utilization", strconv.FormatFloat(avgFiveHour/100, 'f', 6, 64)) - headers.Set("anthropic-ratelimit-unified-7d-utilization", strconv.FormatFloat(avgWeekly/100, 'f', 6, 64)) - if totalWeight > 0 { - headers.Set("X-CCM-Plan-Weight", strconv.FormatFloat(totalWeight, 'f', -1, 64)) - } -} - func (s *Service) InterfaceUpdated() { for _, cred := range s.allCredentials { extCred, ok := cred.(*externalCredential) diff --git a/service/ccm/service_handler.go b/service/ccm/service_handler.go new file mode 100644 index 000000000..7dd0c6411 --- /dev/null +++ b/service/ccm/service_handler.go @@ -0,0 +1,499 @@ +package ccm + +import ( + "bytes" + "context" + "encoding/json" + "io" + "mime" + "net/http" + "strconv" + "strings" + "time" + + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common/buf" + + "github.com/anthropics/anthropic-sdk-go" +) + +const ( + contextWindowStandard = 200000 + contextWindowPremium = 1000000 + premiumContextThreshold = 200000 +) + +const ( + weeklyWindowSeconds = 604800 + weeklyWindowMinutes = weeklyWindowSeconds / 60 +) + +func isExtendedContextRequest(betaHeader string) bool { + for _, feature := range strings.Split(betaHeader, ",") { + if strings.HasPrefix(strings.TrimSpace(feature), "context-1m") { + return true + } + } + return false +} + +func isFastModeRequest(betaHeader string) bool { + for _, feature := range strings.Split(betaHeader, ",") { + if strings.HasPrefix(strings.TrimSpace(feature), "fast-mode") { + return true + } + } + return false +} + +func detectContextWindow(betaHeader string, totalInputTokens int64) int { + if totalInputTokens > premiumContextThreshold { + if isExtendedContextRequest(betaHeader) { + return contextWindowPremium + } + } + return contextWindowStandard +} + +func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint { + resetAt, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset") + if !exists { + return nil + } + + return &WeeklyCycleHint{ + WindowMinutes: weeklyWindowMinutes, + ResetAt: resetAt.UTC(), + } +} + +func extractCCMSessionID(bodyBytes []byte) string { + var body struct { + Metadata struct { + UserID string `json:"user_id"` + } `json:"metadata"` + } + err := json.Unmarshal(bodyBytes, &body) + if err != nil { + return "" + } + userID := body.Metadata.UserID + sessionIndex := strings.LastIndex(userID, "_session_") + if sessionIndex < 0 { + return "" + } + return userID[sessionIndex+len("_session_"):] +} + +func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := log.ContextWithNewID(r.Context()) + if r.URL.Path == "/ccm/v1/status" { + s.handleStatusEndpoint(w, r) + return + } + + if r.URL.Path == "/ccm/v1/reverse" { + s.handleReverseConnect(ctx, w, r) + return + } + + if !strings.HasPrefix(r.URL.Path, "/v1/") { + writeJSONError(w, r, http.StatusNotFound, "not_found_error", "Not found") + return + } + + var username string + if len(s.options.Users) > 0 { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header") + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") + return + } + clientToken := strings.TrimPrefix(authHeader, "Bearer ") + if clientToken == authHeader { + s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format") + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") + return + } + var ok bool + username, ok = s.userManager.Authenticate(clientToken) + if !ok { + s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken) + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key") + return + } + } + + // Always read body to extract model and session ID + var bodyBytes []byte + var requestModel string + var messagesCount int + var sessionID string + + if r.Body != nil { + var err error + bodyBytes, err = io.ReadAll(r.Body) + if err != nil { + s.logger.ErrorContext(ctx, "read request body: ", err) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body") + return + } + + var request struct { + Model string `json:"model"` + Messages []anthropic.MessageParam `json:"messages"` + } + err = json.Unmarshal(bodyBytes, &request) + if err == nil { + requestModel = request.Model + messagesCount = len(request.Messages) + } + + sessionID = extractCCMSessionID(bodyBytes) + r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + + // Resolve credential provider and user config + var provider credentialProvider + var userConfig *option.CCMUser + if len(s.options.Users) > 0 { + userConfig = s.userConfigMap[username] + var err error + provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) + if err != nil { + s.logger.ErrorContext(ctx, "resolve credential: ", err) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) + return + } + } else { + provider = noUserCredentialProvider(s.providers, s.legacyProvider, s.options) + } + if provider == nil { + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available") + return + } + + provider.pollIfStale(s.ctx) + + anthropicBetaHeader := r.Header.Get("anthropic-beta") + if isFastModeRequest(anthropicBetaHeader) { + if _, isSingle := provider.(*singleCredentialProvider); !isSingle { + writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", + "fast mode requests will consume Extra usage, please use a default credential directly") + return + } + } + + selection := credentialSelectionForUser(userConfig) + + selectedCredential, isNew, err := provider.selectCredential(sessionID, selection) + if err != nil { + writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, err.Error())) + return + } + if isNew { + logParts := []any{"assigned credential ", selectedCredential.tagName()} + if sessionID != "" { + logParts = append(logParts, " for session ", sessionID) + } + if username != "" { + logParts = append(logParts, " by user ", username) + } + if requestModel != "" { + modelDisplay := requestModel + if isExtendedContextRequest(anthropicBetaHeader) { + modelDisplay += "[1m]" + } + logParts = append(logParts, ", model=", modelDisplay) + } + s.logger.DebugContext(ctx, logParts...) + } + + if isFastModeRequest(anthropicBetaHeader) && selectedCredential.isExternal() { + writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", + "fast mode requests cannot be proxied through external credentials") + return + } + + requestContext := selectedCredential.wrapRequestContext(ctx) + { + currentRequestContext := requestContext + requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() { + currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc) + })) + } + defer func() { + requestContext.cancelRequest() + }() + proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) + if err != nil { + s.logger.ErrorContext(ctx, "create proxy request: ", err) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error") + return + } + + response, err := selectedCredential.httpClient().Do(proxyRequest) + if err != nil { + if r.Context().Err() != nil { + return + } + if requestContext.Err() != nil { + writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable while processing the request") + return + } + writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error()) + return + } + requestContext.releaseCredentialInterrupt() + + // Transparent 429 retry + for response.StatusCode == http.StatusTooManyRequests { + resetAt := parseRateLimitResetFromHeaders(response.Header) + nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection) + selectedCredential.updateStateFromHeaders(response.Header) + if bodyBytes == nil || nextCredential == nil { + response.Body.Close() + writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited") + return + } + response.Body.Close() + s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) + requestContext.cancelRequest() + requestContext = nextCredential.wrapRequestContext(ctx) + { + currentRequestContext := requestContext + requestContext.addInterruptLink(provider.linkProviderInterrupt(nextCredential, selection, func() { + currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc) + })) + } + retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) + if buildErr != nil { + s.logger.ErrorContext(ctx, "retry request: ", buildErr) + writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error()) + return + } + retryResponse, retryErr := nextCredential.httpClient().Do(retryRequest) + if retryErr != nil { + if r.Context().Err() != nil { + return + } + if requestContext.Err() != nil { + writeCredentialUnavailableError(w, r, provider, nextCredential, selection, "credential became unavailable while retrying the request") + return + } + s.logger.ErrorContext(ctx, "retry request: ", retryErr) + writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error()) + return + } + requestContext.releaseCredentialInterrupt() + response = retryResponse + selectedCredential = nextCredential + } + defer response.Body.Close() + + selectedCredential.updateStateFromHeaders(response.Header) + + if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests { + body, _ := io.ReadAll(response.Body) + s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body)) + go selectedCredential.pollUsage(s.ctx) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", + "proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body)) + return + } + + // Rewrite response headers for external users + if userConfig != nil && userConfig.ExternalCredential != "" { + s.rewriteResponseHeadersForExternalUser(response.Header, userConfig) + } + + for key, values := range response.Header { + if !isHopByHopHeader(key) && !isReverseProxyHeader(key) { + w.Header()[key] = values + } + } + w.WriteHeader(response.StatusCode) + + usageTracker := selectedCredential.usageTrackerOrNil() + if usageTracker != nil && response.StatusCode == http.StatusOK { + s.handleResponseWithTracking(ctx, w, response, usageTracker, requestModel, anthropicBetaHeader, messagesCount, username) + } else { + mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) + if err == nil && mediaType != "text/event-stream" { + _, _ = io.Copy(w, response.Body) + return + } + flusher, ok := w.(http.Flusher) + if !ok { + s.logger.ErrorContext(ctx, "streaming not supported") + return + } + buffer := make([]byte, buf.BufferSize) + for { + n, err := response.Body.Read(buffer) + if n > 0 { + _, writeError := w.Write(buffer[:n]) + if writeError != nil { + s.logger.ErrorContext(ctx, "write streaming response: ", writeError) + return + } + flusher.Flush() + } + if err != nil { + return + } + } + } +} + +func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, requestModel string, anthropicBetaHeader string, messagesCount int, username string) { + weeklyCycleHint := extractWeeklyCycleHint(response.Header) + mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) + isStreaming := err == nil && mediaType == "text/event-stream" + + if !isStreaming { + bodyBytes, err := io.ReadAll(response.Body) + if err != nil { + s.logger.ErrorContext(ctx, "read response body: ", err) + return + } + + var message anthropic.Message + var usage anthropic.Usage + var responseModel string + err = json.Unmarshal(bodyBytes, &message) + if err == nil { + responseModel = string(message.Model) + usage = message.Usage + } + if responseModel == "" { + responseModel = requestModel + } + + if usage.InputTokens > 0 || usage.OutputTokens > 0 { + if responseModel != "" { + totalInputTokens := usage.InputTokens + usage.CacheCreationInputTokens + usage.CacheReadInputTokens + contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens) + usageTracker.AddUsageWithCycleHint( + responseModel, + contextWindow, + messagesCount, + usage.InputTokens, + usage.OutputTokens, + usage.CacheReadInputTokens, + usage.CacheCreationInputTokens, + usage.CacheCreation.Ephemeral5mInputTokens, + usage.CacheCreation.Ephemeral1hInputTokens, + username, + time.Now(), + weeklyCycleHint, + ) + } + } + + _, _ = writer.Write(bodyBytes) + return + } + + flusher, ok := writer.(http.Flusher) + if !ok { + s.logger.ErrorContext(ctx, "streaming not supported") + return + } + + var accumulatedUsage anthropic.Usage + var responseModel string + buffer := make([]byte, buf.BufferSize) + var leftover []byte + + for { + n, err := response.Body.Read(buffer) + if n > 0 { + data := append(leftover, buffer[:n]...) + lines := bytes.Split(data, []byte("\n")) + + if err == nil { + leftover = lines[len(lines)-1] + lines = lines[:len(lines)-1] + } else { + leftover = nil + } + + for _, line := range lines { + line = bytes.TrimSpace(line) + if len(line) == 0 { + continue + } + + if bytes.HasPrefix(line, []byte("data: ")) { + eventData := bytes.TrimPrefix(line, []byte("data: ")) + if bytes.Equal(eventData, []byte("[DONE]")) { + continue + } + + var event anthropic.MessageStreamEventUnion + err := json.Unmarshal(eventData, &event) + if err != nil { + continue + } + switch event.Type { + case "message_start": + messageStart := event.AsMessageStart() + if messageStart.Message.Model != "" { + responseModel = string(messageStart.Message.Model) + } + if messageStart.Message.Usage.InputTokens > 0 { + accumulatedUsage.InputTokens = messageStart.Message.Usage.InputTokens + accumulatedUsage.CacheReadInputTokens = messageStart.Message.Usage.CacheReadInputTokens + accumulatedUsage.CacheCreationInputTokens = messageStart.Message.Usage.CacheCreationInputTokens + accumulatedUsage.CacheCreation.Ephemeral5mInputTokens = messageStart.Message.Usage.CacheCreation.Ephemeral5mInputTokens + accumulatedUsage.CacheCreation.Ephemeral1hInputTokens = messageStart.Message.Usage.CacheCreation.Ephemeral1hInputTokens + } + case "message_delta": + messageDelta := event.AsMessageDelta() + if messageDelta.Usage.OutputTokens > 0 { + accumulatedUsage.OutputTokens = messageDelta.Usage.OutputTokens + } + } + } + } + + _, writeError := writer.Write(buffer[:n]) + if writeError != nil { + s.logger.ErrorContext(ctx, "write streaming response: ", writeError) + return + } + flusher.Flush() + } + + if err != nil { + if responseModel == "" { + responseModel = requestModel + } + + if accumulatedUsage.InputTokens > 0 || accumulatedUsage.OutputTokens > 0 { + if responseModel != "" { + totalInputTokens := accumulatedUsage.InputTokens + accumulatedUsage.CacheCreationInputTokens + accumulatedUsage.CacheReadInputTokens + contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens) + usageTracker.AddUsageWithCycleHint( + responseModel, + contextWindow, + messagesCount, + accumulatedUsage.InputTokens, + accumulatedUsage.OutputTokens, + accumulatedUsage.CacheReadInputTokens, + accumulatedUsage.CacheCreationInputTokens, + accumulatedUsage.CacheCreation.Ephemeral5mInputTokens, + accumulatedUsage.CacheCreation.Ephemeral1hInputTokens, + username, + time.Now(), + weeklyCycleHint, + ) + } + } + return + } + } +} diff --git a/service/ccm/service_status.go b/service/ccm/service_status.go new file mode 100644 index 000000000..3f91b4614 --- /dev/null +++ b/service/ccm/service_status.go @@ -0,0 +1,109 @@ +package ccm + +import ( + "encoding/json" + "net/http" + "strconv" + "strings" + + "github.com/sagernet/sing-box/option" +) + +func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed") + return + } + + if len(s.options.Users) == 0 { + writeJSONError(w, r, http.StatusForbidden, "authentication_error", "status endpoint requires user authentication") + return + } + + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") + return + } + clientToken := strings.TrimPrefix(authHeader, "Bearer ") + if clientToken == authHeader { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") + return + } + username, ok := s.userManager.Authenticate(clientToken) + if !ok { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key") + return + } + + userConfig := s.userConfigMap[username] + if userConfig == nil { + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "user config not found") + return + } + + provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) + if err != nil { + writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) + return + } + + provider.pollIfStale(r.Context()) + avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]float64{ + "five_hour_utilization": avgFiveHour, + "weekly_utilization": avgWeekly, + "plan_weight": totalWeight, + }) +} + +func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) (float64, float64, float64) { + var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64 + for _, cred := range provider.allCredentials() { + if !cred.isAvailable() { + continue + } + if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential { + continue + } + if !userConfig.AllowExternalUsage && cred.isExternal() { + continue + } + weight := cred.planWeight() + remaining5h := cred.fiveHourCap() - cred.fiveHourUtilization() + if remaining5h < 0 { + remaining5h = 0 + } + remainingWeekly := cred.weeklyCap() - cred.weeklyUtilization() + if remainingWeekly < 0 { + remainingWeekly = 0 + } + totalWeightedRemaining5h += remaining5h * weight + totalWeightedRemainingWeekly += remainingWeekly * weight + totalWeight += weight + } + if totalWeight == 0 { + return 100, 100, 0 + } + return 100 - totalWeightedRemaining5h/totalWeight, + 100 - totalWeightedRemainingWeekly/totalWeight, + totalWeight +} + +func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.CCMUser) { + provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, userConfig.Name) + if err != nil { + return + } + + avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) + + headers.Set("anthropic-ratelimit-unified-5h-utilization", strconv.FormatFloat(avgFiveHour/100, 'f', 6, 64)) + headers.Set("anthropic-ratelimit-unified-7d-utilization", strconv.FormatFloat(avgWeekly/100, 'f', 6, 64)) + if totalWeight > 0 { + headers.Set("X-CCM-Plan-Weight", strconv.FormatFloat(totalWeight, 'f', -1, 64)) + } +} diff --git a/service/ccm/service_user.go b/service/ccm/service_user.go index 94637ed81..149894c04 100644 --- a/service/ccm/service_user.go +++ b/service/ccm/service_user.go @@ -7,13 +7,13 @@ import ( ) type UserManager struct { - accessMutex sync.RWMutex + access sync.RWMutex tokenMap map[string]string } func (m *UserManager) UpdateUsers(users []option.CCMUser) { - m.accessMutex.Lock() - defer m.accessMutex.Unlock() + m.access.Lock() + defer m.access.Unlock() tokenMap := make(map[string]string, len(users)) for _, user := range users { tokenMap[user.Token] = user.Name @@ -22,8 +22,8 @@ func (m *UserManager) UpdateUsers(users []option.CCMUser) { } func (m *UserManager) Authenticate(token string) (string, bool) { - m.accessMutex.RLock() + m.access.RLock() username, found := m.tokenMap[token] - m.accessMutex.RUnlock() + m.access.RUnlock() return username, found } diff --git a/service/ocm/credential.go b/service/ocm/credential.go index bb240b5ab..27a889470 100644 --- a/service/ocm/credential.go +++ b/service/ocm/credential.go @@ -1,225 +1,194 @@ package ocm import ( - "bytes" "context" - "encoding/json" - "io" "net/http" - "os" - "os/user" - "path/filepath" + "strconv" + "strings" + "sync" "time" - E "github.com/sagernet/sing/common/exceptions" + N "github.com/sagernet/sing/common/network" ) const ( - oauth2ClientID = "app_EMoamEEZ73f0CkXaXp7hrann" - oauth2TokenURL = "https://auth.openai.com/oauth/token" - openaiAPIBaseURL = "https://api.openai.com" - chatGPTBackendURL = "https://chatgpt.com/backend-api/codex" - tokenRefreshIntervalDays = 8 + defaultPollInterval = 60 * time.Minute + failedPollRetryInterval = time.Minute + httpRetryMaxBackoff = 5 * time.Minute ) -func getRealUser() (*user.User, error) { - if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" { - sudoUserInfo, err := user.Lookup(sudoUser) - if err == nil { - return sudoUserInfo, nil +const ( + httpRetryMaxAttempts = 3 + httpRetryInitialDelay = 200 * time.Millisecond +) + +const sessionExpiry = 24 * time.Hour + +func doHTTPWithRetry(ctx context.Context, client *http.Client, buildRequest func() (*http.Request, error)) (*http.Response, error) { + var lastError error + for attempt := range httpRetryMaxAttempts { + if attempt > 0 { + delay := httpRetryInitialDelay * time.Duration(1<<(attempt-1)) + select { + case <-ctx.Done(): + return nil, lastError + case <-time.After(delay): + } } - } - return user.Current() -} - -func getDefaultCredentialsPath() (string, error) { - if codexHome := os.Getenv("CODEX_HOME"); codexHome != "" { - return filepath.Join(codexHome, "auth.json"), nil - } - userInfo, err := getRealUser() - if err != nil { - return "", err - } - return filepath.Join(userInfo.HomeDir, ".codex", "auth.json"), nil -} - -func readCredentialsFromFile(path string) (*oauthCredentials, error) { - data, err := os.ReadFile(path) - if err != nil { - return nil, err - } - var credentials oauthCredentials - err = json.Unmarshal(data, &credentials) - if err != nil { - return nil, err - } - return &credentials, nil -} - -func checkCredentialFileWritable(path string) error { - file, err := os.OpenFile(path, os.O_WRONLY, 0) - if err != nil { - return err - } - return file.Close() -} - -func writeCredentialsToFile(credentials *oauthCredentials, path string) error { - data, err := json.MarshalIndent(credentials, "", " ") - if err != nil { - return err - } - return os.WriteFile(path, data, 0o600) -} - -type oauthCredentials struct { - APIKey string `json:"OPENAI_API_KEY,omitempty"` - Tokens *tokenData `json:"tokens,omitempty"` - LastRefresh *time.Time `json:"last_refresh,omitempty"` -} - -type tokenData struct { - IDToken string `json:"id_token,omitempty"` - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - AccountID string `json:"account_id,omitempty"` -} - -func (c *oauthCredentials) isAPIKeyMode() bool { - return c.APIKey != "" -} - -func (c *oauthCredentials) getAccessToken() string { - if c.APIKey != "" { - return c.APIKey - } - if c.Tokens != nil { - return c.Tokens.AccessToken - } - return "" -} - -func (c *oauthCredentials) getAccountID() string { - if c.Tokens != nil { - return c.Tokens.AccountID - } - return "" -} - -func (c *oauthCredentials) needsRefresh() bool { - if c.APIKey != "" { - return false - } - if c.Tokens == nil || c.Tokens.RefreshToken == "" { - return false - } - if c.LastRefresh == nil { - return true - } - return time.Since(*c.LastRefresh) >= time.Duration(tokenRefreshIntervalDays)*24*time.Hour -} - -func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) { - if credentials.Tokens == nil || credentials.Tokens.RefreshToken == "" { - return nil, E.New("refresh token is empty") - } - - requestBody, err := json.Marshal(map[string]string{ - "grant_type": "refresh_token", - "refresh_token": credentials.Tokens.RefreshToken, - "client_id": oauth2ClientID, - "scope": "openid profile email", - }) - if err != nil { - return nil, E.Cause(err, "marshal request") - } - - response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { - request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody)) + request, err := buildRequest() if err != nil { return nil, err } - request.Header.Set("Content-Type", "application/json") - request.Header.Set("Accept", "application/json") - return request, nil + response, err := client.Do(request) + if err == nil { + return response, nil + } + lastError = err + if ctx.Err() != nil { + return nil, lastError + } + } + return nil, lastError +} + +type credentialState struct { + fiveHourUtilization float64 + fiveHourReset time.Time + weeklyUtilization float64 + weeklyReset time.Time + hardRateLimited bool + rateLimitResetAt time.Time + accountType string + remotePlanWeight float64 + lastUpdated time.Time + consecutivePollFailures int + unavailable bool + lastCredentialLoadAttempt time.Time + lastCredentialLoadError string +} + +type credentialRequestContext struct { + context.Context + releaseOnce sync.Once + cancelOnce sync.Once + releaseFuncs []func() bool + cancelFunc context.CancelFunc +} + +func (c *credentialRequestContext) addInterruptLink(stop func() bool) { + c.releaseFuncs = append(c.releaseFuncs, stop) +} + +func (c *credentialRequestContext) releaseCredentialInterrupt() { + c.releaseOnce.Do(func() { + for _, f := range c.releaseFuncs { + f() + } }) - if err != nil { - return nil, err - } - defer response.Body.Close() - - if response.StatusCode == http.StatusTooManyRequests { - body, _ := io.ReadAll(response.Body) - return nil, E.New("refresh rate limited: ", response.Status, " ", string(body)) - } - if response.StatusCode != http.StatusOK { - body, _ := io.ReadAll(response.Body) - return nil, E.New("refresh failed: ", response.Status, " ", string(body)) - } - - var tokenResponse struct { - IDToken string `json:"id_token"` - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - } - err = json.NewDecoder(response.Body).Decode(&tokenResponse) - if err != nil { - return nil, E.Cause(err, "decode response") - } - - newCredentials := *credentials - if newCredentials.Tokens == nil { - newCredentials.Tokens = &tokenData{} - } - if tokenResponse.IDToken != "" { - newCredentials.Tokens.IDToken = tokenResponse.IDToken - } - if tokenResponse.AccessToken != "" { - newCredentials.Tokens.AccessToken = tokenResponse.AccessToken - } - if tokenResponse.RefreshToken != "" { - newCredentials.Tokens.RefreshToken = tokenResponse.RefreshToken - } - now := time.Now() - newCredentials.LastRefresh = &now - - return &newCredentials, nil } -func cloneCredentials(credentials *oauthCredentials) *oauthCredentials { - if credentials == nil { - return nil - } - cloned := *credentials - if credentials.Tokens != nil { - clonedTokens := *credentials.Tokens - cloned.Tokens = &clonedTokens - } - if credentials.LastRefresh != nil { - lastRefresh := *credentials.LastRefresh - cloned.LastRefresh = &lastRefresh - } - return &cloned +func (c *credentialRequestContext) cancelRequest() { + c.releaseCredentialInterrupt() + c.cancelOnce.Do(c.cancelFunc) } -func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool { - if left == nil || right == nil { - return left == right - } - if left.APIKey != right.APIKey { - return false - } - if (left.Tokens == nil) != (right.Tokens == nil) { - return false - } - if left.Tokens != nil && *left.Tokens != *right.Tokens { - return false - } - if (left.LastRefresh == nil) != (right.LastRefresh == nil) { - return false - } - if left.LastRefresh != nil && !left.LastRefresh.Equal(*right.LastRefresh) { - return false - } - return true +type credential interface { + tagName() string + isAvailable() bool + isUsable() bool + isExternal() bool + fiveHourUtilization() float64 + weeklyUtilization() float64 + fiveHourCap() float64 + weeklyCap() float64 + planWeight() float64 + weeklyResetTime() time.Time + markRateLimited(resetAt time.Time) + earliestReset() time.Time + unavailableError() error + + getAccessToken() (string, error) + buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) + updateStateFromHeaders(header http.Header) + + wrapRequestContext(ctx context.Context) *credentialRequestContext + interruptConnections() + + setOnBecameUnusable(fn func()) + start() error + pollUsage(ctx context.Context) + lastUpdatedTime() time.Time + pollBackoff(base time.Duration) time.Duration + usageTrackerOrNil() *AggregatedUsage + httpClient() *http.Client + close() + + // OCM-specific + ocmDialer() N.Dialer + ocmIsAPIKeyMode() bool + ocmGetAccountID() string + ocmGetBaseURL() string +} + +type credentialSelectionScope string + +const ( + credentialSelectionScopeAll credentialSelectionScope = "all" + credentialSelectionScopeNonExternal credentialSelectionScope = "non_external" +) + +type credentialSelection struct { + scope credentialSelectionScope + filter func(credential) bool +} + +func (s credentialSelection) allows(cred credential) bool { + return s.filter == nil || s.filter(cred) +} + +func (s credentialSelection) scopeOrDefault() credentialSelectionScope { + if s.scope == "" { + return credentialSelectionScopeAll + } + return s.scope +} + +func normalizeRateLimitIdentifier(limitIdentifier string) string { + trimmedIdentifier := strings.TrimSpace(strings.ToLower(limitIdentifier)) + if trimmedIdentifier == "" { + return "" + } + return strings.ReplaceAll(trimmedIdentifier, "_", "-") +} + +func parseInt64Header(headers http.Header, headerName string) (int64, bool) { + headerValue := strings.TrimSpace(headers.Get(headerName)) + if headerValue == "" { + return 0, false + } + parsedValue, parseError := strconv.ParseInt(headerValue, 10, 64) + if parseError != nil { + return 0, false + } + return parsedValue, true +} + +func parseOCMRateLimitResetFromHeaders(headers http.Header) time.Time { + activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) + if activeLimitIdentifier != "" { + resetHeader := "x-" + activeLimitIdentifier + "-primary-reset-at" + if resetStr := headers.Get(resetHeader); resetStr != "" { + value, err := strconv.ParseInt(resetStr, 10, 64) + if err == nil { + return time.Unix(value, 0) + } + } + } + if retryAfter := headers.Get("Retry-After"); retryAfter != "" { + seconds, err := strconv.ParseInt(retryAfter, 10, 64) + if err == nil { + return time.Now().Add(time.Duration(seconds) * time.Second) + } + } + return time.Now().Add(5 * time.Minute) } diff --git a/service/ocm/credential_builder.go b/service/ocm/credential_builder.go new file mode 100644 index 000000000..5faaf67c6 --- /dev/null +++ b/service/ocm/credential_builder.go @@ -0,0 +1,223 @@ +package ocm + +import ( + "context" + "time" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" +) + +func buildOCMCredentialProviders( + ctx context.Context, + options option.OCMServiceOptions, + logger log.ContextLogger, +) (map[string]credentialProvider, []credential, error) { + allCredentialMap := make(map[string]credential) + var allCreds []credential + providers := make(map[string]credentialProvider) + + // Pass 1: create default and external credentials + for _, credOpt := range options.Credentials { + switch credOpt.Type { + case "default": + cred, err := newDefaultCredential(ctx, credOpt.Tag, credOpt.DefaultOptions, logger) + if err != nil { + return nil, nil, err + } + allCredentialMap[credOpt.Tag] = cred + allCreds = append(allCreds, cred) + providers[credOpt.Tag] = &singleCredentialProvider{cred: cred} + case "external": + cred, err := newExternalCredential(ctx, credOpt.Tag, credOpt.ExternalOptions, logger) + if err != nil { + return nil, nil, err + } + allCredentialMap[credOpt.Tag] = cred + allCreds = append(allCreds, cred) + providers[credOpt.Tag] = &singleCredentialProvider{cred: cred} + } + } + + // Pass 2: create balancer providers + for _, credOpt := range options.Credentials { + if credOpt.Type == "balancer" { + subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, allCredentialMap, credOpt.Tag) + if err != nil { + return nil, nil, err + } + providers[credOpt.Tag] = newBalancerProvider(subCredentials, credOpt.BalancerOptions.Strategy, time.Duration(credOpt.BalancerOptions.PollInterval), credOpt.BalancerOptions.RebalanceThreshold, logger) + } + } + + return providers, allCreds, nil +} + +func resolveCredentialTags(tags []string, allCredentials map[string]credential, parentTag string) ([]credential, error) { + credentials := make([]credential, 0, len(tags)) + for _, tag := range tags { + cred, exists := allCredentials[tag] + if !exists { + return nil, E.New("credential ", parentTag, " references unknown credential: ", tag) + } + credentials = append(credentials, cred) + } + if len(credentials) == 0 { + return nil, E.New("credential ", parentTag, " has no sub-credentials") + } + return credentials, nil +} + +func validateOCMOptions(options option.OCMServiceOptions) error { + hasCredentials := len(options.Credentials) > 0 + hasLegacyPath := options.CredentialPath != "" + hasLegacyUsages := options.UsagesPath != "" + hasLegacyDetour := options.Detour != "" + + if hasCredentials && hasLegacyPath { + return E.New("credential_path and credentials are mutually exclusive") + } + if hasCredentials && hasLegacyUsages { + return E.New("usages_path and credentials are mutually exclusive; use usages_path on individual credentials") + } + if hasCredentials && hasLegacyDetour { + return E.New("detour and credentials are mutually exclusive; use detour on individual credentials") + } + + if hasCredentials { + tags := make(map[string]bool) + credentialTypes := make(map[string]string) + for _, cred := range options.Credentials { + if tags[cred.Tag] { + return E.New("duplicate credential tag: ", cred.Tag) + } + tags[cred.Tag] = true + credentialTypes[cred.Tag] = cred.Type + if cred.Type == "default" || cred.Type == "" { + if cred.DefaultOptions.Reserve5h > 99 { + return E.New("credential ", cred.Tag, ": reserve_5h must be at most 99") + } + if cred.DefaultOptions.ReserveWeekly > 99 { + return E.New("credential ", cred.Tag, ": reserve_weekly must be at most 99") + } + if cred.DefaultOptions.Limit5h > 100 { + return E.New("credential ", cred.Tag, ": limit_5h must be at most 100") + } + if cred.DefaultOptions.LimitWeekly > 100 { + return E.New("credential ", cred.Tag, ": limit_weekly must be at most 100") + } + if cred.DefaultOptions.Reserve5h > 0 && cred.DefaultOptions.Limit5h > 0 { + return E.New("credential ", cred.Tag, ": reserve_5h and limit_5h are mutually exclusive") + } + if cred.DefaultOptions.ReserveWeekly > 0 && cred.DefaultOptions.LimitWeekly > 0 { + return E.New("credential ", cred.Tag, ": reserve_weekly and limit_weekly are mutually exclusive") + } + } + if cred.Type == "external" { + if cred.ExternalOptions.Token == "" { + return E.New("credential ", cred.Tag, ": external credential requires token") + } + if cred.ExternalOptions.Reverse && cred.ExternalOptions.URL == "" { + return E.New("credential ", cred.Tag, ": reverse external credential requires url") + } + } + if cred.Type == "balancer" { + switch cred.BalancerOptions.Strategy { + case "", C.BalancerStrategyLeastUsed, C.BalancerStrategyRoundRobin, C.BalancerStrategyRandom, C.BalancerStrategyFallback: + default: + return E.New("credential ", cred.Tag, ": unknown balancer strategy: ", cred.BalancerOptions.Strategy) + } + if cred.BalancerOptions.RebalanceThreshold < 0 { + return E.New("credential ", cred.Tag, ": rebalance_threshold must not be negative") + } + } + } + + for _, user := range options.Users { + if user.Credential == "" { + return E.New("user ", user.Name, " must specify credential in multi-credential mode") + } + if !tags[user.Credential] { + return E.New("user ", user.Name, " references unknown credential: ", user.Credential) + } + if user.ExternalCredential != "" { + if !tags[user.ExternalCredential] { + return E.New("user ", user.Name, " references unknown external_credential: ", user.ExternalCredential) + } + if credentialTypes[user.ExternalCredential] != "external" { + return E.New("user ", user.Name, ": external_credential must reference an external type credential") + } + } + } + } + + return nil +} + +func validateOCMCompositeCredentialModes( + options option.OCMServiceOptions, + providers map[string]credentialProvider, +) error { + for _, credOpt := range options.Credentials { + if credOpt.Type != "balancer" { + continue + } + + provider, exists := providers[credOpt.Tag] + if !exists { + return E.New("unknown credential: ", credOpt.Tag) + } + + for _, subCred := range provider.allCredentials() { + if !subCred.isAvailable() { + continue + } + if subCred.ocmIsAPIKeyMode() { + return E.New( + "credential ", credOpt.Tag, + " references API key default credential ", subCred.tagName(), + "; balancer and fallback only support OAuth default credentials", + ) + } + } + } + + return nil +} + +func credentialForUser( + userConfigMap map[string]*option.OCMUser, + providers map[string]credentialProvider, + legacyProvider credentialProvider, + username string, +) (credentialProvider, error) { + if legacyProvider != nil { + return legacyProvider, nil + } + userConfig, exists := userConfigMap[username] + if !exists { + return nil, E.New("no credential mapping for user: ", username) + } + provider, exists := providers[userConfig.Credential] + if !exists { + return nil, E.New("unknown credential: ", userConfig.Credential) + } + return provider, nil +} + +func noUserCredentialProvider( + providers map[string]credentialProvider, + legacyProvider credentialProvider, + options option.OCMServiceOptions, +) credentialProvider { + if legacyProvider != nil { + return legacyProvider + } + if len(options.Credentials) > 0 { + tag := options.Credentials[0].Tag + return providers[tag] + } + return nil +} diff --git a/service/ocm/credential_default.go b/service/ocm/credential_default.go new file mode 100644 index 000000000..b82af9d20 --- /dev/null +++ b/service/ocm/credential_default.go @@ -0,0 +1,749 @@ +package ocm + +import ( + "bytes" + "context" + stdTLS "crypto/tls" + "encoding/json" + "io" + "net" + "net/http" + "strconv" + "strings" + "sync" + "time" + + "github.com/sagernet/fswatch" + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/ntp" +) + +type defaultCredential struct { + tag string + serviceContext context.Context + credentialPath string + credentialFilePath string + credentials *oauthCredentials + access sync.RWMutex + state credentialState + stateAccess sync.RWMutex + pollAccess sync.Mutex + reloadAccess sync.Mutex + watcherAccess sync.Mutex + cap5h float64 + capWeekly float64 + usageTracker *AggregatedUsage + dialer N.Dialer + forwardHTTPClient *http.Client + logger log.ContextLogger + watcher *fswatch.Watcher + watcherRetryAt time.Time + + // Connection interruption + onBecameUnusable func() + interrupted bool + requestContext context.Context + cancelRequests context.CancelFunc + requestAccess sync.Mutex +} + +func newDefaultCredential(ctx context.Context, tag string, options option.OCMDefaultCredentialOptions, logger log.ContextLogger) (*defaultCredential, error) { + credentialDialer, err := dialer.NewWithOptions(dialer.Options{ + Context: ctx, + Options: option.DialerOptions{ + Detour: options.Detour, + }, + RemoteIsDomain: true, + }) + if err != nil { + return nil, E.Cause(err, "create dialer for credential ", tag) + } + httpClient := &http.Client{ + Transport: &http.Transport{ + ForceAttemptHTTP2: true, + TLSClientConfig: &stdTLS.Config{ + RootCAs: adapter.RootPoolFromContext(ctx), + Time: ntp.TimeFuncFromContext(ctx), + }, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + }, + }, + } + reserve5h := options.Reserve5h + if reserve5h == 0 { + reserve5h = 1 + } + reserveWeekly := options.ReserveWeekly + if reserveWeekly == 0 { + reserveWeekly = 1 + } + var cap5h float64 + if options.Limit5h > 0 { + cap5h = float64(options.Limit5h) + } else { + cap5h = float64(100 - reserve5h) + } + var capWeekly float64 + if options.LimitWeekly > 0 { + capWeekly = float64(options.LimitWeekly) + } else { + capWeekly = float64(100 - reserveWeekly) + } + requestContext, cancelRequests := context.WithCancel(context.Background()) + credential := &defaultCredential{ + tag: tag, + serviceContext: ctx, + credentialPath: options.CredentialPath, + cap5h: cap5h, + capWeekly: capWeekly, + dialer: credentialDialer, + forwardHTTPClient: httpClient, + logger: logger, + requestContext: requestContext, + cancelRequests: cancelRequests, + } + if options.UsagesPath != "" { + credential.usageTracker = &AggregatedUsage{ + LastUpdated: time.Now(), + Combinations: make([]CostCombination, 0), + filePath: options.UsagesPath, + logger: logger, + } + } + return credential, nil +} + +func (c *defaultCredential) start() error { + credentialFilePath, err := resolveCredentialFilePath(c.credentialPath) + if err != nil { + return E.Cause(err, "resolve credential path for ", c.tag) + } + c.credentialFilePath = credentialFilePath + err = c.ensureCredentialWatcher() + if err != nil { + c.logger.Debug("start credential watcher for ", c.tag, ": ", err) + } + err = c.reloadCredentials(true) + if err != nil { + c.logger.Warn("initial credential load for ", c.tag, ": ", err) + } + if c.usageTracker != nil { + err = c.usageTracker.Load() + if err != nil { + c.logger.Warn("load usage statistics for ", c.tag, ": ", err) + } + } + return nil +} + +func (c *defaultCredential) setOnBecameUnusable(fn func()) { + c.onBecameUnusable = fn +} + +func (c *defaultCredential) tagName() string { + return c.tag +} + +func (c *defaultCredential) isExternal() bool { + return false +} + +func (c *defaultCredential) getAccessToken() (string, error) { + c.retryCredentialReloadIfNeeded() + + c.access.RLock() + if c.credentials != nil && !c.credentials.needsRefresh() { + token := c.credentials.getAccessToken() + c.access.RUnlock() + return token, nil + } + c.access.RUnlock() + + err := c.reloadCredentials(true) + if err == nil { + c.access.RLock() + if c.credentials != nil && !c.credentials.needsRefresh() { + token := c.credentials.getAccessToken() + c.access.RUnlock() + return token, nil + } + c.access.RUnlock() + } + + c.access.Lock() + defer c.access.Unlock() + + if c.credentials == nil { + return "", c.unavailableError() + } + if !c.credentials.needsRefresh() { + return c.credentials.getAccessToken(), nil + } + + err = platformCanWriteCredentials(c.credentialPath) + if err != nil { + return "", E.Cause(err, "credential file not writable, refusing refresh to avoid invalidation") + } + + baseCredentials := cloneCredentials(c.credentials) + newCredentials, err := refreshToken(c.serviceContext, c.forwardHTTPClient, c.credentials) + if err != nil { + return "", err + } + + latestCredentials, latestErr := platformReadCredentials(c.credentialPath) + if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) { + c.credentials = latestCredentials + c.stateAccess.Lock() + c.state.unavailable = false + c.state.lastCredentialLoadAttempt = time.Now() + c.state.lastCredentialLoadError = "" + c.checkTransitionLocked() + c.stateAccess.Unlock() + if !latestCredentials.needsRefresh() { + return latestCredentials.getAccessToken(), nil + } + return "", E.New("credential ", c.tag, " changed while refreshing") + } + + c.credentials = newCredentials + c.stateAccess.Lock() + c.state.unavailable = false + c.state.lastCredentialLoadAttempt = time.Now() + c.state.lastCredentialLoadError = "" + c.checkTransitionLocked() + c.stateAccess.Unlock() + + err = platformWriteCredentials(newCredentials, c.credentialPath) + if err != nil { + c.logger.Error("persist refreshed token for ", c.tag, ": ", err) + } + + return newCredentials.getAccessToken(), nil +} + +func (c *defaultCredential) getAccountID() string { + c.access.RLock() + defer c.access.RUnlock() + if c.credentials == nil { + return "" + } + return c.credentials.getAccountID() +} + +func (c *defaultCredential) isAPIKeyMode() bool { + c.access.RLock() + defer c.access.RUnlock() + if c.credentials == nil { + return false + } + return c.credentials.isAPIKeyMode() +} + +func (c *defaultCredential) getBaseURL() string { + if c.isAPIKeyMode() { + return openaiAPIBaseURL + } + return chatGPTBackendURL +} + +func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { + c.stateAccess.Lock() + isFirstUpdate := c.state.lastUpdated.IsZero() + oldFiveHour := c.state.fiveHourUtilization + oldWeekly := c.state.weeklyUtilization + hadData := false + + activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) + if activeLimitIdentifier == "" { + activeLimitIdentifier = "codex" + } + + fiveHourResetChanged := false + fiveHourResetAt := headers.Get("x-" + activeLimitIdentifier + "-primary-reset-at") + if fiveHourResetAt != "" { + value, err := strconv.ParseInt(fiveHourResetAt, 10, 64) + if err == nil { + hadData = true + newReset := time.Unix(value, 0) + if newReset.After(c.state.fiveHourReset) { + fiveHourResetChanged = true + c.state.fiveHourReset = newReset + } + } + } + fiveHourPercent := headers.Get("x-" + activeLimitIdentifier + "-primary-used-percent") + if fiveHourPercent != "" { + value, err := strconv.ParseFloat(fiveHourPercent, 64) + if err == nil { + hadData = true + if value >= c.state.fiveHourUtilization || fiveHourResetChanged { + c.state.fiveHourUtilization = value + } + } + } + + weeklyResetChanged := false + weeklyResetAt := headers.Get("x-" + activeLimitIdentifier + "-secondary-reset-at") + if weeklyResetAt != "" { + value, err := strconv.ParseInt(weeklyResetAt, 10, 64) + if err == nil { + hadData = true + newReset := time.Unix(value, 0) + if newReset.After(c.state.weeklyReset) { + weeklyResetChanged = true + c.state.weeklyReset = newReset + } + } + } + weeklyPercent := headers.Get("x-" + activeLimitIdentifier + "-secondary-used-percent") + if weeklyPercent != "" { + value, err := strconv.ParseFloat(weeklyPercent, 64) + if err == nil { + hadData = true + if value >= c.state.weeklyUtilization || weeklyResetChanged { + c.state.weeklyUtilization = value + } + } + } + if hadData { + c.state.consecutivePollFailures = 0 + c.state.lastUpdated = time.Now() + } + if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { + resetSuffix := "" + if !c.state.weeklyReset.IsZero() { + resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) + } + c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) + } + shouldInterrupt := c.checkTransitionLocked() + c.stateAccess.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + +func (c *defaultCredential) markRateLimited(resetAt time.Time) { + c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt))) + c.stateAccess.Lock() + c.state.hardRateLimited = true + c.state.rateLimitResetAt = resetAt + shouldInterrupt := c.checkTransitionLocked() + c.stateAccess.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + +func (c *defaultCredential) isUsable() bool { + c.retryCredentialReloadIfNeeded() + + c.stateAccess.RLock() + if c.state.unavailable { + c.stateAccess.RUnlock() + return false + } + if c.state.consecutivePollFailures > 0 { + c.stateAccess.RUnlock() + return false + } + if c.state.hardRateLimited { + if time.Now().Before(c.state.rateLimitResetAt) { + c.stateAccess.RUnlock() + return false + } + c.stateAccess.RUnlock() + c.stateAccess.Lock() + if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) { + c.state.hardRateLimited = false + } + usable := c.checkReservesLocked() + c.stateAccess.Unlock() + return usable + } + usable := c.checkReservesLocked() + c.stateAccess.RUnlock() + return usable +} + +func (c *defaultCredential) checkReservesLocked() bool { + if c.state.fiveHourUtilization >= c.cap5h { + return false + } + if c.state.weeklyUtilization >= c.capWeekly { + return false + } + return true +} + +// checkTransitionLocked detects usable->unusable transition. +// Must be called with stateAccess write lock held. +func (c *defaultCredential) checkTransitionLocked() bool { + unusable := c.state.unavailable || c.state.hardRateLimited || !c.checkReservesLocked() || c.state.consecutivePollFailures > 0 + if unusable && !c.interrupted { + c.interrupted = true + return true + } + if !unusable && c.interrupted { + c.interrupted = false + } + return false +} + +func (c *defaultCredential) interruptConnections() { + c.logger.Warn("interrupting connections for ", c.tag) + c.requestAccess.Lock() + c.cancelRequests() + c.requestContext, c.cancelRequests = context.WithCancel(context.Background()) + c.requestAccess.Unlock() + if c.onBecameUnusable != nil { + c.onBecameUnusable() + } +} + +func (c *defaultCredential) wrapRequestContext(parent context.Context) *credentialRequestContext { + c.requestAccess.Lock() + credentialContext := c.requestContext + c.requestAccess.Unlock() + derived, cancel := context.WithCancel(parent) + stop := context.AfterFunc(credentialContext, func() { + cancel() + }) + return &credentialRequestContext{ + Context: derived, + releaseFuncs: []func() bool{stop}, + cancelFunc: cancel, + } +} + +func (c *defaultCredential) fiveHourUtilization() float64 { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.fiveHourUtilization +} + +func (c *defaultCredential) weeklyUtilization() float64 { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.weeklyUtilization +} + +func (c *defaultCredential) planWeight() float64 { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return ocmPlanWeight(c.state.accountType) +} + +func (c *defaultCredential) weeklyResetTime() time.Time { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.weeklyReset +} + +func (c *defaultCredential) isAvailable() bool { + c.retryCredentialReloadIfNeeded() + + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return !c.state.unavailable +} + +func (c *defaultCredential) unavailableError() error { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + if !c.state.unavailable { + return nil + } + if c.state.lastCredentialLoadError == "" { + return E.New("credential ", c.tag, " is unavailable") + } + return E.New("credential ", c.tag, " is unavailable: ", c.state.lastCredentialLoadError) +} + +func (c *defaultCredential) lastUpdatedTime() time.Time { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.lastUpdated +} + +func (c *defaultCredential) markUsagePollAttempted() { + c.stateAccess.Lock() + defer c.stateAccess.Unlock() + c.state.lastUpdated = time.Now() +} + +func (c *defaultCredential) incrementPollFailures() { + c.stateAccess.Lock() + c.state.consecutivePollFailures++ + shouldInterrupt := c.checkTransitionLocked() + c.stateAccess.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + +func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duration { + c.stateAccess.RLock() + failures := c.state.consecutivePollFailures + c.stateAccess.RUnlock() + if failures <= 0 { + return baseInterval + } + backoff := failedPollRetryInterval * time.Duration(1<<(failures-1)) + if backoff > httpRetryMaxBackoff { + return httpRetryMaxBackoff + } + return backoff +} + +func (c *defaultCredential) isPollBackoffAtCap() bool { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + failures := c.state.consecutivePollFailures + return failures > 0 && failedPollRetryInterval*time.Duration(1<<(failures-1)) >= httpRetryMaxBackoff +} + +func (c *defaultCredential) earliestReset() time.Time { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + if c.state.unavailable { + return time.Time{} + } + if c.state.hardRateLimited { + return c.state.rateLimitResetAt + } + earliest := c.state.fiveHourReset + if !c.state.weeklyReset.IsZero() && (earliest.IsZero() || c.state.weeklyReset.Before(earliest)) { + earliest = c.state.weeklyReset + } + return earliest +} + +func (c *defaultCredential) fiveHourCap() float64 { + return c.cap5h +} + +func (c *defaultCredential) weeklyCap() float64 { + return c.capWeekly +} + +func (c *defaultCredential) usageTrackerOrNil() *AggregatedUsage { + return c.usageTracker +} + +func (c *defaultCredential) httpClient() *http.Client { + return c.forwardHTTPClient +} + +func (c *defaultCredential) ocmDialer() N.Dialer { + return c.dialer +} + +func (c *defaultCredential) ocmIsAPIKeyMode() bool { + return c.isAPIKeyMode() +} + +func (c *defaultCredential) ocmGetAccountID() string { + return c.getAccountID() +} + +func (c *defaultCredential) ocmGetBaseURL() string { + return c.getBaseURL() +} + +func (c *defaultCredential) pollUsage(ctx context.Context) { + if !c.pollAccess.TryLock() { + return + } + defer c.pollAccess.Unlock() + defer c.markUsagePollAttempted() + + c.retryCredentialReloadIfNeeded() + if !c.isAvailable() { + return + } + if c.isAPIKeyMode() { + return + } + + accessToken, err := c.getAccessToken() + if err != nil { + if !c.isPollBackoffAtCap() { + c.logger.Error("poll usage for ", c.tag, ": get token: ", err) + } + c.incrementPollFailures() + return + } + + var usageURL string + if c.isAPIKeyMode() { + usageURL = openaiAPIBaseURL + "/api/codex/usage" + } else { + usageURL = strings.TrimSuffix(chatGPTBackendURL, "/codex") + "/wham/usage" + } + + accountID := c.getAccountID() + pollClient := &http.Client{ + Transport: c.forwardHTTPClient.Transport, + Timeout: 5 * time.Second, + } + + response, err := doHTTPWithRetry(ctx, pollClient, func() (*http.Request, error) { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, usageURL, nil) + if err != nil { + return nil, err + } + request.Header.Set("Authorization", "Bearer "+accessToken) + if accountID != "" { + request.Header.Set("ChatGPT-Account-Id", accountID) + } + return request, nil + }) + if err != nil { + if !c.isPollBackoffAtCap() { + c.logger.Error("poll usage for ", c.tag, ": ", err) + } + c.incrementPollFailures() + return + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + if response.StatusCode == http.StatusTooManyRequests { + c.logger.Warn("poll usage for ", c.tag, ": rate limited") + } + body, _ := io.ReadAll(response.Body) + c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) + c.incrementPollFailures() + return + } + + type usageWindow struct { + UsedPercent float64 `json:"used_percent"` + ResetAt int64 `json:"reset_at"` + } + var usageResponse struct { + PlanType string `json:"plan_type"` + RateLimit *struct { + PrimaryWindow *usageWindow `json:"primary_window"` + SecondaryWindow *usageWindow `json:"secondary_window"` + } `json:"rate_limit"` + } + err = json.NewDecoder(response.Body).Decode(&usageResponse) + if err != nil { + c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) + c.incrementPollFailures() + return + } + + c.stateAccess.Lock() + isFirstUpdate := c.state.lastUpdated.IsZero() + oldFiveHour := c.state.fiveHourUtilization + oldWeekly := c.state.weeklyUtilization + c.state.consecutivePollFailures = 0 + if usageResponse.RateLimit != nil { + if w := usageResponse.RateLimit.PrimaryWindow; w != nil { + c.state.fiveHourUtilization = w.UsedPercent + if w.ResetAt > 0 { + c.state.fiveHourReset = time.Unix(w.ResetAt, 0) + } + } + if w := usageResponse.RateLimit.SecondaryWindow; w != nil { + c.state.weeklyUtilization = w.UsedPercent + if w.ResetAt > 0 { + c.state.weeklyReset = time.Unix(w.ResetAt, 0) + } + } + } + if usageResponse.PlanType != "" { + c.state.accountType = usageResponse.PlanType + } + if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { + c.state.hardRateLimited = false + } + if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { + resetSuffix := "" + if !c.state.weeklyReset.IsZero() { + resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) + } + c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) + } + shouldInterrupt := c.checkTransitionLocked() + c.stateAccess.Unlock() + if shouldInterrupt { + c.interruptConnections() + } +} + +func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) { + accessToken, err := c.getAccessToken() + if err != nil { + return nil, E.Cause(err, "get access token for ", c.tag) + } + + path := original.URL.Path + var proxyPath string + if c.isAPIKeyMode() { + proxyPath = path + } else { + proxyPath = strings.TrimPrefix(path, "/v1") + } + + proxyURL := c.getBaseURL() + proxyPath + if original.URL.RawQuery != "" { + proxyURL += "?" + original.URL.RawQuery + } + + var body io.Reader + if bodyBytes != nil { + body = bytes.NewReader(bodyBytes) + } else { + body = original.Body + } + proxyRequest, err := http.NewRequestWithContext(ctx, original.Method, proxyURL, body) + if err != nil { + return nil, err + } + + for key, values := range original.Header { + if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && key != "Authorization" { + proxyRequest.Header[key] = values + } + } + + for key, values := range serviceHeaders { + proxyRequest.Header.Del(key) + proxyRequest.Header[key] = values + } + proxyRequest.Header.Set("Authorization", "Bearer "+accessToken) + + if accountID := c.getAccountID(); accountID != "" { + proxyRequest.Header.Set("ChatGPT-Account-Id", accountID) + } + + return proxyRequest, nil +} + +func (c *defaultCredential) close() { + if c.watcher != nil { + err := c.watcher.Close() + if err != nil { + c.logger.Error("close credential watcher for ", c.tag, ": ", err) + } + } + if c.usageTracker != nil { + c.usageTracker.cancelPendingSave() + err := c.usageTracker.Save() + if err != nil { + c.logger.Error("save usage statistics for ", c.tag, ": ", err) + } + } +} diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 0e0556be7..968bf904d 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -30,17 +30,17 @@ import ( const reverseProxyBaseURL = "http://reverse-proxy" type externalCredential struct { - tag string - baseURL string - token string - credDialer N.Dialer - httpClient *http.Client - state credentialState - stateMutex sync.RWMutex - pollAccess sync.Mutex - pollInterval time.Duration - usageTracker *AggregatedUsage - logger log.ContextLogger + tag string + baseURL string + token string + credDialer N.Dialer + forwardHTTPClient *http.Client + state credentialState + stateAccess sync.RWMutex + pollAccess sync.Mutex + pollInterval time.Duration + usageTracker *AggregatedUsage + logger log.ContextLogger onBecameUnusable func() interrupted bool @@ -147,7 +147,7 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx // Receiver mode: no URL, wait for reverse connection cred.baseURL = reverseProxyBaseURL cred.credDialer = reverseSessionDialer{credential: cred} - cred.httpClient = &http.Client{ + cred.forwardHTTPClient = &http.Client{ Transport: &http.Transport{ ForceAttemptHTTP2: false, DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { @@ -211,11 +211,11 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx Time: ntp.TimeFuncFromContext(ctx), } } - cred.httpClient = &http.Client{Transport: transport} + cred.forwardHTTPClient = &http.Client{Transport: transport} } else { // Normal mode: standard HTTP client for proxying cred.credDialer = credentialDialer - cred.httpClient = &http.Client{Transport: transport} + cred.forwardHTTPClient = &http.Client{Transport: transport} cred.reverseCredDialer = reverseSessionDialer{credential: cred} cred.reverseHttpClient = &http.Client{ Transport: &http.Transport{ @@ -273,39 +273,39 @@ func (c *externalCredential) isUsable() bool { if !c.isAvailable() { return false } - c.stateMutex.RLock() + c.stateAccess.RLock() if c.state.consecutivePollFailures > 0 { - c.stateMutex.RUnlock() + c.stateAccess.RUnlock() return false } if c.state.hardRateLimited { if time.Now().Before(c.state.rateLimitResetAt) { - c.stateMutex.RUnlock() + c.stateAccess.RUnlock() return false } - c.stateMutex.RUnlock() - c.stateMutex.Lock() + c.stateAccess.RUnlock() + c.stateAccess.Lock() if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) { c.state.hardRateLimited = false } usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100 - c.stateMutex.Unlock() + c.stateAccess.Unlock() return usable } usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100 - c.stateMutex.RUnlock() + c.stateAccess.RUnlock() return usable } func (c *externalCredential) fiveHourUtilization() float64 { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() return c.state.fiveHourUtilization } func (c *externalCredential) weeklyUtilization() float64 { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() return c.state.weeklyUtilization } @@ -318,8 +318,8 @@ func (c *externalCredential) weeklyCap() float64 { } func (c *externalCredential) planWeight() float64 { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() if c.state.remotePlanWeight > 0 { return c.state.remotePlanWeight } @@ -327,26 +327,26 @@ func (c *externalCredential) planWeight() float64 { } func (c *externalCredential) weeklyResetTime() time.Time { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() return c.state.weeklyReset } func (c *externalCredential) markRateLimited(resetAt time.Time) { c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt))) - c.stateMutex.Lock() + c.stateAccess.Lock() c.state.hardRateLimited = true c.state.rateLimitResetAt = resetAt shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() + c.stateAccess.Unlock() if shouldInterrupt { c.interruptConnections() } } func (c *externalCredential) earliestReset() time.Time { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() if c.state.hardRateLimited { return c.state.rateLimitResetAt } @@ -432,7 +432,7 @@ func (c *externalCredential) openReverseConnection(ctx context.Context) (net.Con } func (c *externalCredential) updateStateFromHeaders(headers http.Header) { - c.stateMutex.Lock() + c.stateAccess.Lock() isFirstUpdate := c.state.lastUpdated.IsZero() oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization @@ -494,7 +494,7 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) } shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() + c.stateAccess.Unlock() if shouldInterrupt { c.interruptConnections() } @@ -569,9 +569,9 @@ func (c *externalCredential) doPollUsageRequest(ctx context.Context) (*http.Resp } } // Forward transport with retries - if c.httpClient != nil { + if c.forwardHTTPClient != nil { forwardClient := &http.Client{ - Transport: c.httpClient.Transport, + Transport: c.forwardHTTPClient.Transport, Timeout: 5 * time.Second, } return doHTTPWithRetry(ctx, forwardClient, buildRequest(c.baseURL)) @@ -602,10 +602,10 @@ func (c *externalCredential) pollUsage(ctx context.Context) { // 404 means the remote does not have a status endpoint yet; // usage will be updated passively from response headers. if response.StatusCode == http.StatusNotFound { - c.stateMutex.Lock() + c.stateAccess.Lock() c.state.consecutivePollFailures = 0 c.checkTransitionLocked() - c.stateMutex.Unlock() + c.stateAccess.Unlock() } else { c.incrementPollFailures() } @@ -624,7 +624,7 @@ func (c *externalCredential) pollUsage(ctx context.Context) { return } - c.stateMutex.Lock() + c.stateAccess.Lock() isFirstUpdate := c.state.lastUpdated.IsZero() oldFiveHour := c.state.fiveHourUtilization oldWeekly := c.state.weeklyUtilization @@ -645,28 +645,28 @@ func (c *externalCredential) pollUsage(ctx context.Context) { c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) } shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() + c.stateAccess.Unlock() if shouldInterrupt { c.interruptConnections() } } func (c *externalCredential) lastUpdatedTime() time.Time { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() return c.state.lastUpdated } func (c *externalCredential) markUsagePollAttempted() { - c.stateMutex.Lock() - defer c.stateMutex.Unlock() + c.stateAccess.Lock() + defer c.stateAccess.Unlock() c.state.lastUpdated = time.Now() } func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Duration { - c.stateMutex.RLock() + c.stateAccess.RLock() failures := c.state.consecutivePollFailures - c.stateMutex.RUnlock() + c.stateAccess.RUnlock() if failures <= 0 { return baseInterval } @@ -678,17 +678,17 @@ func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Durati } func (c *externalCredential) isPollBackoffAtCap() bool { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() failures := c.state.consecutivePollFailures return failures > 0 && failedPollRetryInterval*time.Duration(1<<(failures-1)) >= httpRetryMaxBackoff } func (c *externalCredential) incrementPollFailures() { - c.stateMutex.Lock() + c.stateAccess.Lock() c.state.consecutivePollFailures++ shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() + c.stateAccess.Unlock() if shouldInterrupt { c.interruptConnections() } @@ -698,14 +698,14 @@ func (c *externalCredential) usageTrackerOrNil() *AggregatedUsage { return c.usageTracker } -func (c *externalCredential) httpTransport() *http.Client { +func (c *externalCredential) httpClient() *http.Client { if c.reverseHttpClient != nil { session := c.getReverseSession() if session != nil && !session.IsClosed() { return c.reverseHttpClient } } - return c.httpClient + return c.forwardHTTPClient } func (c *externalCredential) ocmDialer() N.Dialer { diff --git a/service/ocm/credential_file.go b/service/ocm/credential_file.go index b8252904e..861dbdb86 100644 --- a/service/ocm/credential_file.go +++ b/service/ocm/credential_file.go @@ -62,10 +62,10 @@ func (c *defaultCredential) ensureCredentialWatcher() error { } func (c *defaultCredential) retryCredentialReloadIfNeeded() { - c.stateMutex.RLock() + c.stateAccess.RLock() unavailable := c.state.unavailable lastAttempt := c.state.lastCredentialLoadAttempt - c.stateMutex.RUnlock() + c.stateAccess.RUnlock() if !unavailable { return } @@ -84,10 +84,10 @@ func (c *defaultCredential) reloadCredentials(force bool) error { c.reloadAccess.Lock() defer c.reloadAccess.Unlock() - c.stateMutex.RLock() + c.stateAccess.RLock() unavailable := c.state.unavailable lastAttempt := c.state.lastCredentialLoadAttempt - c.stateMutex.RUnlock() + c.stateAccess.RUnlock() if !force { if !unavailable { return nil @@ -97,39 +97,39 @@ func (c *defaultCredential) reloadCredentials(force bool) error { } } - c.stateMutex.Lock() + c.stateAccess.Lock() c.state.lastCredentialLoadAttempt = time.Now() - c.stateMutex.Unlock() + c.stateAccess.Unlock() credentials, err := platformReadCredentials(c.credentialPath) if err != nil { return c.markCredentialsUnavailable(E.Cause(err, "read credentials")) } - c.accessMutex.Lock() + c.access.Lock() c.credentials = credentials - c.accessMutex.Unlock() + c.access.Unlock() - c.stateMutex.Lock() + c.stateAccess.Lock() c.state.unavailable = false c.state.lastCredentialLoadError = "" c.checkTransitionLocked() - c.stateMutex.Unlock() + c.stateAccess.Unlock() return nil } func (c *defaultCredential) markCredentialsUnavailable(err error) error { - c.accessMutex.Lock() + c.access.Lock() hadCredentials := c.credentials != nil c.credentials = nil - c.accessMutex.Unlock() + c.access.Unlock() - c.stateMutex.Lock() + c.stateAccess.Lock() c.state.unavailable = true c.state.lastCredentialLoadError = err.Error() shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() + c.stateAccess.Unlock() if shouldInterrupt && hadCredentials { c.interruptConnections() diff --git a/service/ocm/credential_oauth.go b/service/ocm/credential_oauth.go new file mode 100644 index 000000000..bb240b5ab --- /dev/null +++ b/service/ocm/credential_oauth.go @@ -0,0 +1,225 @@ +package ocm + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "os" + "os/user" + "path/filepath" + "time" + + E "github.com/sagernet/sing/common/exceptions" +) + +const ( + oauth2ClientID = "app_EMoamEEZ73f0CkXaXp7hrann" + oauth2TokenURL = "https://auth.openai.com/oauth/token" + openaiAPIBaseURL = "https://api.openai.com" + chatGPTBackendURL = "https://chatgpt.com/backend-api/codex" + tokenRefreshIntervalDays = 8 +) + +func getRealUser() (*user.User, error) { + if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" { + sudoUserInfo, err := user.Lookup(sudoUser) + if err == nil { + return sudoUserInfo, nil + } + } + return user.Current() +} + +func getDefaultCredentialsPath() (string, error) { + if codexHome := os.Getenv("CODEX_HOME"); codexHome != "" { + return filepath.Join(codexHome, "auth.json"), nil + } + userInfo, err := getRealUser() + if err != nil { + return "", err + } + return filepath.Join(userInfo.HomeDir, ".codex", "auth.json"), nil +} + +func readCredentialsFromFile(path string) (*oauthCredentials, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + var credentials oauthCredentials + err = json.Unmarshal(data, &credentials) + if err != nil { + return nil, err + } + return &credentials, nil +} + +func checkCredentialFileWritable(path string) error { + file, err := os.OpenFile(path, os.O_WRONLY, 0) + if err != nil { + return err + } + return file.Close() +} + +func writeCredentialsToFile(credentials *oauthCredentials, path string) error { + data, err := json.MarshalIndent(credentials, "", " ") + if err != nil { + return err + } + return os.WriteFile(path, data, 0o600) +} + +type oauthCredentials struct { + APIKey string `json:"OPENAI_API_KEY,omitempty"` + Tokens *tokenData `json:"tokens,omitempty"` + LastRefresh *time.Time `json:"last_refresh,omitempty"` +} + +type tokenData struct { + IDToken string `json:"id_token,omitempty"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + AccountID string `json:"account_id,omitempty"` +} + +func (c *oauthCredentials) isAPIKeyMode() bool { + return c.APIKey != "" +} + +func (c *oauthCredentials) getAccessToken() string { + if c.APIKey != "" { + return c.APIKey + } + if c.Tokens != nil { + return c.Tokens.AccessToken + } + return "" +} + +func (c *oauthCredentials) getAccountID() string { + if c.Tokens != nil { + return c.Tokens.AccountID + } + return "" +} + +func (c *oauthCredentials) needsRefresh() bool { + if c.APIKey != "" { + return false + } + if c.Tokens == nil || c.Tokens.RefreshToken == "" { + return false + } + if c.LastRefresh == nil { + return true + } + return time.Since(*c.LastRefresh) >= time.Duration(tokenRefreshIntervalDays)*24*time.Hour +} + +func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) { + if credentials.Tokens == nil || credentials.Tokens.RefreshToken == "" { + return nil, E.New("refresh token is empty") + } + + requestBody, err := json.Marshal(map[string]string{ + "grant_type": "refresh_token", + "refresh_token": credentials.Tokens.RefreshToken, + "client_id": oauth2ClientID, + "scope": "openid profile email", + }) + if err != nil { + return nil, E.Cause(err, "marshal request") + } + + response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { + request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody)) + if err != nil { + return nil, err + } + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Accept", "application/json") + return request, nil + }) + if err != nil { + return nil, err + } + defer response.Body.Close() + + if response.StatusCode == http.StatusTooManyRequests { + body, _ := io.ReadAll(response.Body) + return nil, E.New("refresh rate limited: ", response.Status, " ", string(body)) + } + if response.StatusCode != http.StatusOK { + body, _ := io.ReadAll(response.Body) + return nil, E.New("refresh failed: ", response.Status, " ", string(body)) + } + + var tokenResponse struct { + IDToken string `json:"id_token"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + } + err = json.NewDecoder(response.Body).Decode(&tokenResponse) + if err != nil { + return nil, E.Cause(err, "decode response") + } + + newCredentials := *credentials + if newCredentials.Tokens == nil { + newCredentials.Tokens = &tokenData{} + } + if tokenResponse.IDToken != "" { + newCredentials.Tokens.IDToken = tokenResponse.IDToken + } + if tokenResponse.AccessToken != "" { + newCredentials.Tokens.AccessToken = tokenResponse.AccessToken + } + if tokenResponse.RefreshToken != "" { + newCredentials.Tokens.RefreshToken = tokenResponse.RefreshToken + } + now := time.Now() + newCredentials.LastRefresh = &now + + return &newCredentials, nil +} + +func cloneCredentials(credentials *oauthCredentials) *oauthCredentials { + if credentials == nil { + return nil + } + cloned := *credentials + if credentials.Tokens != nil { + clonedTokens := *credentials.Tokens + cloned.Tokens = &clonedTokens + } + if credentials.LastRefresh != nil { + lastRefresh := *credentials.LastRefresh + cloned.LastRefresh = &lastRefresh + } + return &cloned +} + +func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool { + if left == nil || right == nil { + return left == right + } + if left.APIKey != right.APIKey { + return false + } + if (left.Tokens == nil) != (right.Tokens == nil) { + return false + } + if left.Tokens != nil && *left.Tokens != *right.Tokens { + return false + } + if (left.LastRefresh == nil) != (right.LastRefresh == nil) { + return false + } + if left.LastRefresh != nil && !left.LastRefresh.Equal(*right.LastRefresh) { + return false + } + return true +} diff --git a/service/ocm/credential_provider.go b/service/ocm/credential_provider.go new file mode 100644 index 000000000..53383e368 --- /dev/null +++ b/service/ocm/credential_provider.go @@ -0,0 +1,411 @@ +package ocm + +import ( + "context" + "math/rand/v2" + "sync" + "sync/atomic" + "time" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + E "github.com/sagernet/sing/common/exceptions" +) + +type credentialProvider interface { + selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) + onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential + linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool + pollIfStale(ctx context.Context) + allCredentials() []credential + close() +} + +type singleCredentialProvider struct { + cred credential + sessionAccess sync.RWMutex + sessions map[string]time.Time +} + +func (p *singleCredentialProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) { + if !selection.allows(p.cred) { + return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out") + } + if !p.cred.isAvailable() { + return nil, false, p.cred.unavailableError() + } + if !p.cred.isUsable() { + return nil, false, E.New("credential ", p.cred.tagName(), " is rate-limited") + } + var isNew bool + if sessionID != "" { + p.sessionAccess.Lock() + if p.sessions == nil { + p.sessions = make(map[string]time.Time) + } + _, exists := p.sessions[sessionID] + if !exists { + p.sessions[sessionID] = time.Now() + isNew = true + } + p.sessionAccess.Unlock() + } + return p.cred, isNew, nil +} + +func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ credentialSelection) credential { + cred.markRateLimited(resetAt) + return nil +} + +func (p *singleCredentialProvider) pollIfStale(ctx context.Context) { + now := time.Now() + p.sessionAccess.Lock() + for id, createdAt := range p.sessions { + if now.Sub(createdAt) > sessionExpiry { + delete(p.sessions, id) + } + } + p.sessionAccess.Unlock() + + if time.Since(p.cred.lastUpdatedTime()) > p.cred.pollBackoff(defaultPollInterval) { + p.cred.pollUsage(ctx) + } +} + +func (p *singleCredentialProvider) allCredentials() []credential { + return []credential{p.cred} +} + +func (p *singleCredentialProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool { + return func() bool { + return false + } +} + +func (p *singleCredentialProvider) close() {} + +type sessionEntry struct { + tag string + selectionScope credentialSelectionScope + createdAt time.Time +} + +type credentialInterruptKey struct { + tag string + selectionScope credentialSelectionScope +} + +type credentialInterruptEntry struct { + context context.Context + cancel context.CancelFunc +} + +type balancerProvider struct { + credentials []credential + strategy string + roundRobinIndex atomic.Uint64 + pollInterval time.Duration + rebalanceThreshold float64 + sessionAccess sync.RWMutex + sessions map[string]sessionEntry + interruptAccess sync.Mutex + credentialInterrupts map[credentialInterruptKey]credentialInterruptEntry + logger log.ContextLogger +} + +func compositeCredentialSelectable(cred credential) bool { + return !cred.ocmIsAPIKeyMode() +} + +func newBalancerProvider(credentials []credential, strategy string, pollInterval time.Duration, rebalanceThreshold float64, logger log.ContextLogger) *balancerProvider { + if pollInterval <= 0 { + pollInterval = defaultPollInterval + } + return &balancerProvider{ + credentials: credentials, + strategy: strategy, + pollInterval: pollInterval, + rebalanceThreshold: rebalanceThreshold, + sessions: make(map[string]sessionEntry), + credentialInterrupts: make(map[credentialInterruptKey]credentialInterruptEntry), + logger: logger, + } +} + +func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) { + if p.strategy == C.BalancerStrategyFallback { + best := p.pickCredential(selection.filter) + if best == nil { + return nil, false, allRateLimitedError(p.credentials) + } + return best, false, nil + } + + selectionScope := selection.scopeOrDefault() + if sessionID != "" { + p.sessionAccess.RLock() + entry, exists := p.sessions[sessionID] + p.sessionAccess.RUnlock() + if exists { + if entry.selectionScope == selectionScope { + for _, cred := range p.credentials { + if cred.tagName() == entry.tag && compositeCredentialSelectable(cred) && selection.allows(cred) && cred.isUsable() { + if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == C.BalancerStrategyLeastUsed) { + better := p.pickLeastUsed(selection.filter) + if better != nil && better.tagName() != cred.tagName() { + effectiveThreshold := p.rebalanceThreshold / cred.planWeight() + delta := cred.weeklyUtilization() - better.weeklyUtilization() + if delta > effectiveThreshold { + p.logger.Info("rebalancing away from ", cred.tagName(), + ": utilization delta ", delta, "% exceeds effective threshold ", + effectiveThreshold, "% (weight ", cred.planWeight(), ")") + p.rebalanceCredential(cred.tagName(), selectionScope) + break + } + } + } + return cred, false, nil + } + } + } + p.sessionAccess.Lock() + delete(p.sessions, sessionID) + p.sessionAccess.Unlock() + } + } + + best := p.pickCredential(selection.filter) + if best == nil { + return nil, false, allRateLimitedError(p.credentials) + } + + isNew := sessionID != "" + if isNew { + p.sessionAccess.Lock() + p.sessions[sessionID] = sessionEntry{ + tag: best.tagName(), + selectionScope: selectionScope, + createdAt: time.Now(), + } + p.sessionAccess.Unlock() + } + return best, isNew, nil +} + +func (p *balancerProvider) rebalanceCredential(tag string, selectionScope credentialSelectionScope) { + key := credentialInterruptKey{tag: tag, selectionScope: selectionScope} + p.interruptAccess.Lock() + if entry, loaded := p.credentialInterrupts[key]; loaded { + entry.cancel() + } + ctx, cancel := context.WithCancel(context.Background()) + p.credentialInterrupts[key] = credentialInterruptEntry{context: ctx, cancel: cancel} + p.interruptAccess.Unlock() + + p.sessionAccess.Lock() + for id, entry := range p.sessions { + if entry.tag == tag && entry.selectionScope == selectionScope { + delete(p.sessions, id) + } + } + p.sessionAccess.Unlock() +} + +func (p *balancerProvider) linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool { + if p.strategy == C.BalancerStrategyFallback { + return func() bool { return false } + } + key := credentialInterruptKey{ + tag: cred.tagName(), + selectionScope: selection.scopeOrDefault(), + } + p.interruptAccess.Lock() + entry, loaded := p.credentialInterrupts[key] + if !loaded { + ctx, cancel := context.WithCancel(context.Background()) + entry = credentialInterruptEntry{context: ctx, cancel: cancel} + p.credentialInterrupts[key] = entry + } + p.interruptAccess.Unlock() + return context.AfterFunc(entry.context, onInterrupt) +} + +func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential { + cred.markRateLimited(resetAt) + if p.strategy == C.BalancerStrategyFallback { + return p.pickCredential(selection.filter) + } + if sessionID != "" { + p.sessionAccess.Lock() + delete(p.sessions, sessionID) + p.sessionAccess.Unlock() + } + + best := p.pickCredential(selection.filter) + if best != nil && sessionID != "" { + p.sessionAccess.Lock() + p.sessions[sessionID] = sessionEntry{ + tag: best.tagName(), + selectionScope: selection.scopeOrDefault(), + createdAt: time.Now(), + } + p.sessionAccess.Unlock() + } + return best +} + +func (p *balancerProvider) pickCredential(filter func(credential) bool) credential { + switch p.strategy { + case C.BalancerStrategyRoundRobin: + return p.pickRoundRobin(filter) + case C.BalancerStrategyRandom: + return p.pickRandom(filter) + case C.BalancerStrategyFallback: + return p.pickFallback(filter) + default: + return p.pickLeastUsed(filter) + } +} + +func (p *balancerProvider) pickFallback(filter func(credential) bool) credential { + for _, cred := range p.credentials { + if filter != nil && !filter(cred) { + continue + } + if !compositeCredentialSelectable(cred) { + continue + } + if cred.isUsable() { + return cred + } + } + return nil +} + +const weeklyWindowHours = 7 * 24 + +func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credential { + var best credential + bestScore := float64(-1) + now := time.Now() + for _, cred := range p.credentials { + if filter != nil && !filter(cred) { + continue + } + if !compositeCredentialSelectable(cred) { + continue + } + if !cred.isUsable() { + continue + } + remaining := cred.weeklyCap() - cred.weeklyUtilization() + score := remaining * cred.planWeight() + resetTime := cred.weeklyResetTime() + if !resetTime.IsZero() { + timeUntilReset := resetTime.Sub(now) + if timeUntilReset < time.Hour { + timeUntilReset = time.Hour + } + score *= weeklyWindowHours / timeUntilReset.Hours() + } + if score > bestScore { + bestScore = score + best = cred + } + } + return best +} + +func ocmPlanWeight(accountType string) float64 { + switch accountType { + case "pro": + return 10 + case "plus": + return 1 + default: + return 1 + } +} + +func (p *balancerProvider) pickRoundRobin(filter func(credential) bool) credential { + start := int(p.roundRobinIndex.Add(1) - 1) + count := len(p.credentials) + for offset := range count { + candidate := p.credentials[(start+offset)%count] + if filter != nil && !filter(candidate) { + continue + } + if !compositeCredentialSelectable(candidate) { + continue + } + if candidate.isUsable() { + return candidate + } + } + return nil +} + +func (p *balancerProvider) pickRandom(filter func(credential) bool) credential { + var usable []credential + for _, candidate := range p.credentials { + if filter != nil && !filter(candidate) { + continue + } + if !compositeCredentialSelectable(candidate) { + continue + } + if candidate.isUsable() { + usable = append(usable, candidate) + } + } + if len(usable) == 0 { + return nil + } + return usable[rand.IntN(len(usable))] +} + +func (p *balancerProvider) pollIfStale(ctx context.Context) { + now := time.Now() + p.sessionAccess.Lock() + for id, entry := range p.sessions { + if now.Sub(entry.createdAt) > sessionExpiry { + delete(p.sessions, id) + } + } + p.sessionAccess.Unlock() + + for _, cred := range p.credentials { + if time.Since(cred.lastUpdatedTime()) > cred.pollBackoff(p.pollInterval) { + cred.pollUsage(ctx) + } + } +} + +func (p *balancerProvider) allCredentials() []credential { + return p.credentials +} + +func (p *balancerProvider) close() {} + +func allRateLimitedError(credentials []credential) error { + var hasUnavailable bool + var earliest time.Time + for _, cred := range credentials { + if cred.unavailableError() != nil { + hasUnavailable = true + continue + } + resetAt := cred.earliestReset() + if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) { + earliest = resetAt + } + } + if hasUnavailable { + return E.New("all credentials unavailable") + } + if earliest.IsZero() { + return E.New("all credentials rate-limited") + } + return E.New("all credentials rate-limited, earliest reset in ", log.FormatDuration(time.Until(earliest))) +} diff --git a/service/ocm/credential_state.go b/service/ocm/credential_state.go deleted file mode 100644 index 181132f09..000000000 --- a/service/ocm/credential_state.go +++ /dev/null @@ -1,1524 +0,0 @@ -package ocm - -import ( - "bytes" - "context" - stdTLS "crypto/tls" - "encoding/json" - "io" - "math/rand/v2" - "net" - "net/http" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/sagernet/fswatch" - "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/common/dialer" - C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing-box/log" - "github.com/sagernet/sing-box/option" - E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/ntp" -) - -const ( - defaultPollInterval = 60 * time.Minute - failedPollRetryInterval = time.Minute - httpRetryMaxBackoff = 5 * time.Minute -) - -const ( - httpRetryMaxAttempts = 3 - httpRetryInitialDelay = 200 * time.Millisecond -) - -func doHTTPWithRetry(ctx context.Context, client *http.Client, buildRequest func() (*http.Request, error)) (*http.Response, error) { - var lastError error - for attempt := range httpRetryMaxAttempts { - if attempt > 0 { - delay := httpRetryInitialDelay * time.Duration(1<<(attempt-1)) - select { - case <-ctx.Done(): - return nil, lastError - case <-time.After(delay): - } - } - request, err := buildRequest() - if err != nil { - return nil, err - } - response, err := client.Do(request) - if err == nil { - return response, nil - } - lastError = err - if ctx.Err() != nil { - return nil, lastError - } - } - return nil, lastError -} - -type credentialState struct { - fiveHourUtilization float64 - fiveHourReset time.Time - weeklyUtilization float64 - weeklyReset time.Time - hardRateLimited bool - rateLimitResetAt time.Time - accountType string - remotePlanWeight float64 - lastUpdated time.Time - consecutivePollFailures int - unavailable bool - lastCredentialLoadAttempt time.Time - lastCredentialLoadError string -} - -type defaultCredential struct { - tag string - serviceContext context.Context - credentialPath string - credentialFilePath string - credentials *oauthCredentials - accessMutex sync.RWMutex - state credentialState - stateMutex sync.RWMutex - pollAccess sync.Mutex - reloadAccess sync.Mutex - watcherAccess sync.Mutex - cap5h float64 - capWeekly float64 - usageTracker *AggregatedUsage - dialer N.Dialer - httpClient *http.Client - logger log.ContextLogger - watcher *fswatch.Watcher - watcherRetryAt time.Time - - // Connection interruption - onBecameUnusable func() - interrupted bool - requestContext context.Context - cancelRequests context.CancelFunc - requestAccess sync.Mutex -} - -type credentialRequestContext struct { - context.Context - releaseOnce sync.Once - cancelOnce sync.Once - releaseFuncs []func() bool - cancelFunc context.CancelFunc -} - -func (c *credentialRequestContext) addInterruptLink(stop func() bool) { - c.releaseFuncs = append(c.releaseFuncs, stop) -} - -func (c *credentialRequestContext) releaseCredentialInterrupt() { - c.releaseOnce.Do(func() { - for _, f := range c.releaseFuncs { - f() - } - }) -} - -func (c *credentialRequestContext) cancelRequest() { - c.releaseCredentialInterrupt() - c.cancelOnce.Do(c.cancelFunc) -} - -type credential interface { - tagName() string - isAvailable() bool - isUsable() bool - isExternal() bool - fiveHourUtilization() float64 - weeklyUtilization() float64 - fiveHourCap() float64 - weeklyCap() float64 - planWeight() float64 - weeklyResetTime() time.Time - markRateLimited(resetAt time.Time) - earliestReset() time.Time - unavailableError() error - - getAccessToken() (string, error) - buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) - updateStateFromHeaders(header http.Header) - - wrapRequestContext(ctx context.Context) *credentialRequestContext - interruptConnections() - - setOnBecameUnusable(fn func()) - start() error - pollUsage(ctx context.Context) - lastUpdatedTime() time.Time - pollBackoff(base time.Duration) time.Duration - usageTrackerOrNil() *AggregatedUsage - httpTransport() *http.Client - close() - - // OCM-specific - ocmDialer() N.Dialer - ocmIsAPIKeyMode() bool - ocmGetAccountID() string - ocmGetBaseURL() string -} - -func newDefaultCredential(ctx context.Context, tag string, options option.OCMDefaultCredentialOptions, logger log.ContextLogger) (*defaultCredential, error) { - credentialDialer, err := dialer.NewWithOptions(dialer.Options{ - Context: ctx, - Options: option.DialerOptions{ - Detour: options.Detour, - }, - RemoteIsDomain: true, - }) - if err != nil { - return nil, E.Cause(err, "create dialer for credential ", tag) - } - httpClient := &http.Client{ - Transport: &http.Transport{ - ForceAttemptHTTP2: true, - TLSClientConfig: &stdTLS.Config{ - RootCAs: adapter.RootPoolFromContext(ctx), - Time: ntp.TimeFuncFromContext(ctx), - }, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) - }, - }, - } - reserve5h := options.Reserve5h - if reserve5h == 0 { - reserve5h = 1 - } - reserveWeekly := options.ReserveWeekly - if reserveWeekly == 0 { - reserveWeekly = 1 - } - var cap5h float64 - if options.Limit5h > 0 { - cap5h = float64(options.Limit5h) - } else { - cap5h = float64(100 - reserve5h) - } - var capWeekly float64 - if options.LimitWeekly > 0 { - capWeekly = float64(options.LimitWeekly) - } else { - capWeekly = float64(100 - reserveWeekly) - } - requestContext, cancelRequests := context.WithCancel(context.Background()) - credential := &defaultCredential{ - tag: tag, - serviceContext: ctx, - credentialPath: options.CredentialPath, - cap5h: cap5h, - capWeekly: capWeekly, - dialer: credentialDialer, - httpClient: httpClient, - logger: logger, - requestContext: requestContext, - cancelRequests: cancelRequests, - } - if options.UsagesPath != "" { - credential.usageTracker = &AggregatedUsage{ - LastUpdated: time.Now(), - Combinations: make([]CostCombination, 0), - filePath: options.UsagesPath, - logger: logger, - } - } - return credential, nil -} - -func (c *defaultCredential) start() error { - credentialFilePath, err := resolveCredentialFilePath(c.credentialPath) - if err != nil { - return E.Cause(err, "resolve credential path for ", c.tag) - } - c.credentialFilePath = credentialFilePath - err = c.ensureCredentialWatcher() - if err != nil { - c.logger.Debug("start credential watcher for ", c.tag, ": ", err) - } - err = c.reloadCredentials(true) - if err != nil { - c.logger.Warn("initial credential load for ", c.tag, ": ", err) - } - if c.usageTracker != nil { - err = c.usageTracker.Load() - if err != nil { - c.logger.Warn("load usage statistics for ", c.tag, ": ", err) - } - } - return nil -} - -func (c *defaultCredential) getAccessToken() (string, error) { - c.retryCredentialReloadIfNeeded() - - c.accessMutex.RLock() - if c.credentials != nil && !c.credentials.needsRefresh() { - token := c.credentials.getAccessToken() - c.accessMutex.RUnlock() - return token, nil - } - c.accessMutex.RUnlock() - - err := c.reloadCredentials(true) - if err == nil { - c.accessMutex.RLock() - if c.credentials != nil && !c.credentials.needsRefresh() { - token := c.credentials.getAccessToken() - c.accessMutex.RUnlock() - return token, nil - } - c.accessMutex.RUnlock() - } - - c.accessMutex.Lock() - defer c.accessMutex.Unlock() - - if c.credentials == nil { - return "", c.unavailableError() - } - if !c.credentials.needsRefresh() { - return c.credentials.getAccessToken(), nil - } - - err = platformCanWriteCredentials(c.credentialPath) - if err != nil { - return "", E.Cause(err, "credential file not writable, refusing refresh to avoid invalidation") - } - - baseCredentials := cloneCredentials(c.credentials) - newCredentials, err := refreshToken(c.serviceContext, c.httpClient, c.credentials) - if err != nil { - return "", err - } - - latestCredentials, latestErr := platformReadCredentials(c.credentialPath) - if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) { - c.credentials = latestCredentials - c.stateMutex.Lock() - c.state.unavailable = false - c.state.lastCredentialLoadAttempt = time.Now() - c.state.lastCredentialLoadError = "" - c.checkTransitionLocked() - c.stateMutex.Unlock() - if !latestCredentials.needsRefresh() { - return latestCredentials.getAccessToken(), nil - } - return "", E.New("credential ", c.tag, " changed while refreshing") - } - - c.credentials = newCredentials - c.stateMutex.Lock() - c.state.unavailable = false - c.state.lastCredentialLoadAttempt = time.Now() - c.state.lastCredentialLoadError = "" - c.checkTransitionLocked() - c.stateMutex.Unlock() - - err = platformWriteCredentials(newCredentials, c.credentialPath) - if err != nil { - c.logger.Error("persist refreshed token for ", c.tag, ": ", err) - } - - return newCredentials.getAccessToken(), nil -} - -func (c *defaultCredential) getAccountID() string { - c.accessMutex.RLock() - defer c.accessMutex.RUnlock() - if c.credentials == nil { - return "" - } - return c.credentials.getAccountID() -} - -func (c *defaultCredential) isAPIKeyMode() bool { - c.accessMutex.RLock() - defer c.accessMutex.RUnlock() - if c.credentials == nil { - return false - } - return c.credentials.isAPIKeyMode() -} - -func (c *defaultCredential) getBaseURL() string { - if c.isAPIKeyMode() { - return openaiAPIBaseURL - } - return chatGPTBackendURL -} - -func (c *defaultCredential) updateStateFromHeaders(headers http.Header) { - c.stateMutex.Lock() - isFirstUpdate := c.state.lastUpdated.IsZero() - oldFiveHour := c.state.fiveHourUtilization - oldWeekly := c.state.weeklyUtilization - hadData := false - - activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) - if activeLimitIdentifier == "" { - activeLimitIdentifier = "codex" - } - - fiveHourResetChanged := false - fiveHourResetAt := headers.Get("x-" + activeLimitIdentifier + "-primary-reset-at") - if fiveHourResetAt != "" { - value, err := strconv.ParseInt(fiveHourResetAt, 10, 64) - if err == nil { - hadData = true - newReset := time.Unix(value, 0) - if newReset.After(c.state.fiveHourReset) { - fiveHourResetChanged = true - c.state.fiveHourReset = newReset - } - } - } - fiveHourPercent := headers.Get("x-" + activeLimitIdentifier + "-primary-used-percent") - if fiveHourPercent != "" { - value, err := strconv.ParseFloat(fiveHourPercent, 64) - if err == nil { - hadData = true - if value >= c.state.fiveHourUtilization || fiveHourResetChanged { - c.state.fiveHourUtilization = value - } - } - } - - weeklyResetChanged := false - weeklyResetAt := headers.Get("x-" + activeLimitIdentifier + "-secondary-reset-at") - if weeklyResetAt != "" { - value, err := strconv.ParseInt(weeklyResetAt, 10, 64) - if err == nil { - hadData = true - newReset := time.Unix(value, 0) - if newReset.After(c.state.weeklyReset) { - weeklyResetChanged = true - c.state.weeklyReset = newReset - } - } - } - weeklyPercent := headers.Get("x-" + activeLimitIdentifier + "-secondary-used-percent") - if weeklyPercent != "" { - value, err := strconv.ParseFloat(weeklyPercent, 64) - if err == nil { - hadData = true - if value >= c.state.weeklyUtilization || weeklyResetChanged { - c.state.weeklyUtilization = value - } - } - } - if hadData { - c.state.consecutivePollFailures = 0 - c.state.lastUpdated = time.Now() - } - if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { - resetSuffix := "" - if !c.state.weeklyReset.IsZero() { - resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) - } - c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) - } - shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() - if shouldInterrupt { - c.interruptConnections() - } -} - -func (c *defaultCredential) markRateLimited(resetAt time.Time) { - c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt))) - c.stateMutex.Lock() - c.state.hardRateLimited = true - c.state.rateLimitResetAt = resetAt - shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() - if shouldInterrupt { - c.interruptConnections() - } -} - -func (c *defaultCredential) isUsable() bool { - c.retryCredentialReloadIfNeeded() - - c.stateMutex.RLock() - if c.state.unavailable { - c.stateMutex.RUnlock() - return false - } - if c.state.consecutivePollFailures > 0 { - c.stateMutex.RUnlock() - return false - } - if c.state.hardRateLimited { - if time.Now().Before(c.state.rateLimitResetAt) { - c.stateMutex.RUnlock() - return false - } - c.stateMutex.RUnlock() - c.stateMutex.Lock() - if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) { - c.state.hardRateLimited = false - } - usable := c.checkReservesLocked() - c.stateMutex.Unlock() - return usable - } - usable := c.checkReservesLocked() - c.stateMutex.RUnlock() - return usable -} - -func (c *defaultCredential) checkReservesLocked() bool { - if c.state.fiveHourUtilization >= c.cap5h { - return false - } - if c.state.weeklyUtilization >= c.capWeekly { - return false - } - return true -} - -// checkTransitionLocked detects usable→unusable transition. -// Must be called with stateMutex write lock held. -func (c *defaultCredential) checkTransitionLocked() bool { - unusable := c.state.unavailable || c.state.hardRateLimited || !c.checkReservesLocked() || c.state.consecutivePollFailures > 0 - if unusable && !c.interrupted { - c.interrupted = true - return true - } - if !unusable && c.interrupted { - c.interrupted = false - } - return false -} - -func (c *defaultCredential) interruptConnections() { - c.logger.Warn("interrupting connections for ", c.tag) - c.requestAccess.Lock() - c.cancelRequests() - c.requestContext, c.cancelRequests = context.WithCancel(context.Background()) - c.requestAccess.Unlock() - if c.onBecameUnusable != nil { - c.onBecameUnusable() - } -} - -func (c *defaultCredential) wrapRequestContext(parent context.Context) *credentialRequestContext { - c.requestAccess.Lock() - credentialContext := c.requestContext - c.requestAccess.Unlock() - derived, cancel := context.WithCancel(parent) - stop := context.AfterFunc(credentialContext, func() { - cancel() - }) - return &credentialRequestContext{ - Context: derived, - releaseFuncs: []func() bool{stop}, - cancelFunc: cancel, - } -} - -func (c *defaultCredential) weeklyUtilization() float64 { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - return c.state.weeklyUtilization -} - -func (c *defaultCredential) planWeight() float64 { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - return ocmPlanWeight(c.state.accountType) -} - -func (c *defaultCredential) weeklyResetTime() time.Time { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - return c.state.weeklyReset -} - -func (c *defaultCredential) isAvailable() bool { - c.retryCredentialReloadIfNeeded() - - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - return !c.state.unavailable -} - -func (c *defaultCredential) unavailableError() error { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - if !c.state.unavailable { - return nil - } - if c.state.lastCredentialLoadError == "" { - return E.New("credential ", c.tag, " is unavailable") - } - return E.New("credential ", c.tag, " is unavailable: ", c.state.lastCredentialLoadError) -} - -func (c *defaultCredential) lastUpdatedTime() time.Time { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - return c.state.lastUpdated -} - -func (c *defaultCredential) markUsagePollAttempted() { - c.stateMutex.Lock() - defer c.stateMutex.Unlock() - c.state.lastUpdated = time.Now() -} - -func (c *defaultCredential) incrementPollFailures() { - c.stateMutex.Lock() - c.state.consecutivePollFailures++ - shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() - if shouldInterrupt { - c.interruptConnections() - } -} - -func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duration { - c.stateMutex.RLock() - failures := c.state.consecutivePollFailures - c.stateMutex.RUnlock() - if failures <= 0 { - return baseInterval - } - backoff := failedPollRetryInterval * time.Duration(1<<(failures-1)) - if backoff > httpRetryMaxBackoff { - return httpRetryMaxBackoff - } - return backoff -} - -func (c *defaultCredential) isPollBackoffAtCap() bool { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - failures := c.state.consecutivePollFailures - return failures > 0 && failedPollRetryInterval*time.Duration(1<<(failures-1)) >= httpRetryMaxBackoff -} - -func (c *defaultCredential) earliestReset() time.Time { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - if c.state.unavailable { - return time.Time{} - } - if c.state.hardRateLimited { - return c.state.rateLimitResetAt - } - earliest := c.state.fiveHourReset - if !c.state.weeklyReset.IsZero() && (earliest.IsZero() || c.state.weeklyReset.Before(earliest)) { - earliest = c.state.weeklyReset - } - return earliest -} - -func (c *defaultCredential) pollUsage(ctx context.Context) { - if !c.pollAccess.TryLock() { - return - } - defer c.pollAccess.Unlock() - defer c.markUsagePollAttempted() - - c.retryCredentialReloadIfNeeded() - if !c.isAvailable() { - return - } - if c.isAPIKeyMode() { - return - } - - accessToken, err := c.getAccessToken() - if err != nil { - if !c.isPollBackoffAtCap() { - c.logger.Error("poll usage for ", c.tag, ": get token: ", err) - } - c.incrementPollFailures() - return - } - - var usageURL string - if c.isAPIKeyMode() { - usageURL = openaiAPIBaseURL + "/api/codex/usage" - } else { - usageURL = strings.TrimSuffix(chatGPTBackendURL, "/codex") + "/wham/usage" - } - - accountID := c.getAccountID() - httpClient := &http.Client{ - Transport: c.httpClient.Transport, - Timeout: 5 * time.Second, - } - - response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { - request, err := http.NewRequestWithContext(ctx, http.MethodGet, usageURL, nil) - if err != nil { - return nil, err - } - request.Header.Set("Authorization", "Bearer "+accessToken) - if accountID != "" { - request.Header.Set("ChatGPT-Account-Id", accountID) - } - return request, nil - }) - if err != nil { - if !c.isPollBackoffAtCap() { - c.logger.Error("poll usage for ", c.tag, ": ", err) - } - c.incrementPollFailures() - return - } - defer response.Body.Close() - - if response.StatusCode != http.StatusOK { - if response.StatusCode == http.StatusTooManyRequests { - c.logger.Warn("poll usage for ", c.tag, ": rate limited") - } - body, _ := io.ReadAll(response.Body) - c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) - c.incrementPollFailures() - return - } - - type usageWindow struct { - UsedPercent float64 `json:"used_percent"` - ResetAt int64 `json:"reset_at"` - } - var usageResponse struct { - PlanType string `json:"plan_type"` - RateLimit *struct { - PrimaryWindow *usageWindow `json:"primary_window"` - SecondaryWindow *usageWindow `json:"secondary_window"` - } `json:"rate_limit"` - } - err = json.NewDecoder(response.Body).Decode(&usageResponse) - if err != nil { - c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) - c.incrementPollFailures() - return - } - - c.stateMutex.Lock() - isFirstUpdate := c.state.lastUpdated.IsZero() - oldFiveHour := c.state.fiveHourUtilization - oldWeekly := c.state.weeklyUtilization - c.state.consecutivePollFailures = 0 - if usageResponse.RateLimit != nil { - if w := usageResponse.RateLimit.PrimaryWindow; w != nil { - c.state.fiveHourUtilization = w.UsedPercent - if w.ResetAt > 0 { - c.state.fiveHourReset = time.Unix(w.ResetAt, 0) - } - } - if w := usageResponse.RateLimit.SecondaryWindow; w != nil { - c.state.weeklyUtilization = w.UsedPercent - if w.ResetAt > 0 { - c.state.weeklyReset = time.Unix(w.ResetAt, 0) - } - } - } - if usageResponse.PlanType != "" { - c.state.accountType = usageResponse.PlanType - } - if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { - c.state.hardRateLimited = false - } - if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) { - resetSuffix := "" - if !c.state.weeklyReset.IsZero() { - resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset)) - } - c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) - } - shouldInterrupt := c.checkTransitionLocked() - c.stateMutex.Unlock() - if shouldInterrupt { - c.interruptConnections() - } -} - -func (c *defaultCredential) close() { - if c.watcher != nil { - err := c.watcher.Close() - if err != nil { - c.logger.Error("close credential watcher for ", c.tag, ": ", err) - } - } - if c.usageTracker != nil { - c.usageTracker.cancelPendingSave() - err := c.usageTracker.Save() - if err != nil { - c.logger.Error("save usage statistics for ", c.tag, ": ", err) - } - } -} - -func (c *defaultCredential) setOnBecameUnusable(fn func()) { - c.onBecameUnusable = fn -} - -func (c *defaultCredential) tagName() string { - return c.tag -} - -func (c *defaultCredential) isExternal() bool { - return false -} - -func (c *defaultCredential) fiveHourUtilization() float64 { - c.stateMutex.RLock() - defer c.stateMutex.RUnlock() - return c.state.fiveHourUtilization -} - -func (c *defaultCredential) fiveHourCap() float64 { - return c.cap5h -} - -func (c *defaultCredential) weeklyCap() float64 { - return c.capWeekly -} - -func (c *defaultCredential) usageTrackerOrNil() *AggregatedUsage { - return c.usageTracker -} - -func (c *defaultCredential) httpTransport() *http.Client { - return c.httpClient -} - -func (c *defaultCredential) ocmDialer() N.Dialer { - return c.dialer -} - -func (c *defaultCredential) ocmIsAPIKeyMode() bool { - return c.isAPIKeyMode() -} - -func (c *defaultCredential) ocmGetAccountID() string { - return c.getAccountID() -} - -func (c *defaultCredential) ocmGetBaseURL() string { - return c.getBaseURL() -} - -func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) { - accessToken, err := c.getAccessToken() - if err != nil { - return nil, E.Cause(err, "get access token for ", c.tag) - } - - path := original.URL.Path - var proxyPath string - if c.isAPIKeyMode() { - proxyPath = path - } else { - proxyPath = strings.TrimPrefix(path, "/v1") - } - - proxyURL := c.getBaseURL() + proxyPath - if original.URL.RawQuery != "" { - proxyURL += "?" + original.URL.RawQuery - } - - var body io.Reader - if bodyBytes != nil { - body = bytes.NewReader(bodyBytes) - } else { - body = original.Body - } - proxyRequest, err := http.NewRequestWithContext(ctx, original.Method, proxyURL, body) - if err != nil { - return nil, err - } - - for key, values := range original.Header { - if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && key != "Authorization" { - proxyRequest.Header[key] = values - } - } - - for key, values := range serviceHeaders { - proxyRequest.Header.Del(key) - proxyRequest.Header[key] = values - } - proxyRequest.Header.Set("Authorization", "Bearer "+accessToken) - - if accountID := c.getAccountID(); accountID != "" { - proxyRequest.Header.Set("ChatGPT-Account-Id", accountID) - } - - return proxyRequest, nil -} - -type credentialProvider interface { - selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) - onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential - linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool - pollIfStale(ctx context.Context) - allCredentials() []credential - close() -} - -type credentialSelectionScope string - -const ( - credentialSelectionScopeAll credentialSelectionScope = "all" - credentialSelectionScopeNonExternal credentialSelectionScope = "non_external" -) - -type credentialSelection struct { - scope credentialSelectionScope - filter func(credential) bool -} - -func (s credentialSelection) allows(cred credential) bool { - return s.filter == nil || s.filter(cred) -} - -func (s credentialSelection) scopeOrDefault() credentialSelectionScope { - if s.scope == "" { - return credentialSelectionScopeAll - } - return s.scope -} - -type singleCredentialProvider struct { - cred credential - sessionAccess sync.RWMutex - sessions map[string]time.Time -} - -func (p *singleCredentialProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) { - if !selection.allows(p.cred) { - return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out") - } - if !p.cred.isAvailable() { - return nil, false, p.cred.unavailableError() - } - if !p.cred.isUsable() { - return nil, false, E.New("credential ", p.cred.tagName(), " is rate-limited") - } - var isNew bool - if sessionID != "" { - p.sessionAccess.Lock() - if p.sessions == nil { - p.sessions = make(map[string]time.Time) - } - _, exists := p.sessions[sessionID] - if !exists { - p.sessions[sessionID] = time.Now() - isNew = true - } - p.sessionAccess.Unlock() - } - return p.cred, isNew, nil -} - -func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ credentialSelection) credential { - cred.markRateLimited(resetAt) - return nil -} - -func (p *singleCredentialProvider) pollIfStale(ctx context.Context) { - now := time.Now() - p.sessionAccess.Lock() - for id, createdAt := range p.sessions { - if now.Sub(createdAt) > sessionExpiry { - delete(p.sessions, id) - } - } - p.sessionAccess.Unlock() - - if time.Since(p.cred.lastUpdatedTime()) > p.cred.pollBackoff(defaultPollInterval) { - p.cred.pollUsage(ctx) - } -} - -func (p *singleCredentialProvider) allCredentials() []credential { - return []credential{p.cred} -} - -func (p *singleCredentialProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool { - return func() bool { - return false - } -} - -func (p *singleCredentialProvider) close() {} - -const sessionExpiry = 24 * time.Hour - -type sessionEntry struct { - tag string - selectionScope credentialSelectionScope - createdAt time.Time -} - -type credentialInterruptKey struct { - tag string - selectionScope credentialSelectionScope -} - -type credentialInterruptEntry struct { - context context.Context - cancel context.CancelFunc -} - -type balancerProvider struct { - credentials []credential - strategy string - roundRobinIndex atomic.Uint64 - pollInterval time.Duration - rebalanceThreshold float64 - sessionMutex sync.RWMutex - sessions map[string]sessionEntry - interruptAccess sync.Mutex - credentialInterrupts map[credentialInterruptKey]credentialInterruptEntry - logger log.ContextLogger -} - -func compositeCredentialSelectable(cred credential) bool { - return !cred.ocmIsAPIKeyMode() -} - -func newBalancerProvider(credentials []credential, strategy string, pollInterval time.Duration, rebalanceThreshold float64, logger log.ContextLogger) *balancerProvider { - if pollInterval <= 0 { - pollInterval = defaultPollInterval - } - return &balancerProvider{ - credentials: credentials, - strategy: strategy, - pollInterval: pollInterval, - rebalanceThreshold: rebalanceThreshold, - sessions: make(map[string]sessionEntry), - credentialInterrupts: make(map[credentialInterruptKey]credentialInterruptEntry), - logger: logger, - } -} - -func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) { - if p.strategy == C.BalancerStrategyFallback { - best := p.pickCredential(selection.filter) - if best == nil { - return nil, false, allRateLimitedError(p.credentials) - } - return best, false, nil - } - - selectionScope := selection.scopeOrDefault() - if sessionID != "" { - p.sessionMutex.RLock() - entry, exists := p.sessions[sessionID] - p.sessionMutex.RUnlock() - if exists { - if entry.selectionScope == selectionScope { - for _, cred := range p.credentials { - if cred.tagName() == entry.tag && compositeCredentialSelectable(cred) && selection.allows(cred) && cred.isUsable() { - if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == C.BalancerStrategyLeastUsed) { - better := p.pickLeastUsed(selection.filter) - if better != nil && better.tagName() != cred.tagName() { - effectiveThreshold := p.rebalanceThreshold / cred.planWeight() - delta := cred.weeklyUtilization() - better.weeklyUtilization() - if delta > effectiveThreshold { - p.logger.Info("rebalancing away from ", cred.tagName(), - ": utilization delta ", delta, "% exceeds effective threshold ", - effectiveThreshold, "% (weight ", cred.planWeight(), ")") - p.rebalanceCredential(cred.tagName(), selectionScope) - break - } - } - } - return cred, false, nil - } - } - } - p.sessionMutex.Lock() - delete(p.sessions, sessionID) - p.sessionMutex.Unlock() - } - } - - best := p.pickCredential(selection.filter) - if best == nil { - return nil, false, allRateLimitedError(p.credentials) - } - - isNew := sessionID != "" - if isNew { - p.sessionMutex.Lock() - p.sessions[sessionID] = sessionEntry{ - tag: best.tagName(), - selectionScope: selectionScope, - createdAt: time.Now(), - } - p.sessionMutex.Unlock() - } - return best, isNew, nil -} - -func (p *balancerProvider) rebalanceCredential(tag string, selectionScope credentialSelectionScope) { - key := credentialInterruptKey{tag: tag, selectionScope: selectionScope} - p.interruptAccess.Lock() - if entry, loaded := p.credentialInterrupts[key]; loaded { - entry.cancel() - } - ctx, cancel := context.WithCancel(context.Background()) - p.credentialInterrupts[key] = credentialInterruptEntry{context: ctx, cancel: cancel} - p.interruptAccess.Unlock() - - p.sessionMutex.Lock() - for id, entry := range p.sessions { - if entry.tag == tag && entry.selectionScope == selectionScope { - delete(p.sessions, id) - } - } - p.sessionMutex.Unlock() -} - -func (p *balancerProvider) linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool { - if p.strategy == C.BalancerStrategyFallback { - return func() bool { return false } - } - key := credentialInterruptKey{ - tag: cred.tagName(), - selectionScope: selection.scopeOrDefault(), - } - p.interruptAccess.Lock() - entry, loaded := p.credentialInterrupts[key] - if !loaded { - ctx, cancel := context.WithCancel(context.Background()) - entry = credentialInterruptEntry{context: ctx, cancel: cancel} - p.credentialInterrupts[key] = entry - } - p.interruptAccess.Unlock() - return context.AfterFunc(entry.context, onInterrupt) -} - -func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential { - cred.markRateLimited(resetAt) - if p.strategy == C.BalancerStrategyFallback { - return p.pickCredential(selection.filter) - } - if sessionID != "" { - p.sessionMutex.Lock() - delete(p.sessions, sessionID) - p.sessionMutex.Unlock() - } - - best := p.pickCredential(selection.filter) - if best != nil && sessionID != "" { - p.sessionMutex.Lock() - p.sessions[sessionID] = sessionEntry{ - tag: best.tagName(), - selectionScope: selection.scopeOrDefault(), - createdAt: time.Now(), - } - p.sessionMutex.Unlock() - } - return best -} - -func (p *balancerProvider) pickCredential(filter func(credential) bool) credential { - switch p.strategy { - case C.BalancerStrategyRoundRobin: - return p.pickRoundRobin(filter) - case C.BalancerStrategyRandom: - return p.pickRandom(filter) - case C.BalancerStrategyFallback: - return p.pickFallback(filter) - default: - return p.pickLeastUsed(filter) - } -} - -func (p *balancerProvider) pickFallback(filter func(credential) bool) credential { - for _, cred := range p.credentials { - if filter != nil && !filter(cred) { - continue - } - if !compositeCredentialSelectable(cred) { - continue - } - if cred.isUsable() { - return cred - } - } - return nil -} - -func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credential { - var best credential - bestScore := float64(-1) - now := time.Now() - for _, cred := range p.credentials { - if filter != nil && !filter(cred) { - continue - } - if !compositeCredentialSelectable(cred) { - continue - } - if !cred.isUsable() { - continue - } - remaining := cred.weeklyCap() - cred.weeklyUtilization() - score := remaining * cred.planWeight() - resetTime := cred.weeklyResetTime() - if !resetTime.IsZero() { - timeUntilReset := resetTime.Sub(now) - if timeUntilReset < time.Hour { - timeUntilReset = time.Hour - } - score *= weeklyWindowDuration / timeUntilReset.Hours() - } - if score > bestScore { - bestScore = score - best = cred - } - } - return best -} - -const weeklyWindowDuration = 7 * 24 // hours - -func ocmPlanWeight(accountType string) float64 { - switch accountType { - case "pro": - return 10 - case "plus": - return 1 - default: - return 1 - } -} - -func (p *balancerProvider) pickRoundRobin(filter func(credential) bool) credential { - start := int(p.roundRobinIndex.Add(1) - 1) - count := len(p.credentials) - for offset := range count { - candidate := p.credentials[(start+offset)%count] - if filter != nil && !filter(candidate) { - continue - } - if !compositeCredentialSelectable(candidate) { - continue - } - if candidate.isUsable() { - return candidate - } - } - return nil -} - -func (p *balancerProvider) pickRandom(filter func(credential) bool) credential { - var usable []credential - for _, candidate := range p.credentials { - if filter != nil && !filter(candidate) { - continue - } - if !compositeCredentialSelectable(candidate) { - continue - } - if candidate.isUsable() { - usable = append(usable, candidate) - } - } - if len(usable) == 0 { - return nil - } - return usable[rand.IntN(len(usable))] -} - -func (p *balancerProvider) pollIfStale(ctx context.Context) { - now := time.Now() - p.sessionMutex.Lock() - for id, entry := range p.sessions { - if now.Sub(entry.createdAt) > sessionExpiry { - delete(p.sessions, id) - } - } - p.sessionMutex.Unlock() - - for _, cred := range p.credentials { - if time.Since(cred.lastUpdatedTime()) > cred.pollBackoff(p.pollInterval) { - cred.pollUsage(ctx) - } - } -} - -func (p *balancerProvider) allCredentials() []credential { - return p.credentials -} - -func (p *balancerProvider) close() {} - -func allRateLimitedError(credentials []credential) error { - var hasUnavailable bool - var earliest time.Time - for _, cred := range credentials { - if cred.unavailableError() != nil { - hasUnavailable = true - continue - } - resetAt := cred.earliestReset() - if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) { - earliest = resetAt - } - } - if hasUnavailable { - return E.New("all credentials unavailable") - } - if earliest.IsZero() { - return E.New("all credentials rate-limited") - } - return E.New("all credentials rate-limited, earliest reset in ", log.FormatDuration(time.Until(earliest))) -} - -func buildOCMCredentialProviders( - ctx context.Context, - options option.OCMServiceOptions, - logger log.ContextLogger, -) (map[string]credentialProvider, []credential, error) { - allCredentialMap := make(map[string]credential) - var allCreds []credential - providers := make(map[string]credentialProvider) - - // Pass 1: create default and external credentials - for _, credOpt := range options.Credentials { - switch credOpt.Type { - case "default": - cred, err := newDefaultCredential(ctx, credOpt.Tag, credOpt.DefaultOptions, logger) - if err != nil { - return nil, nil, err - } - allCredentialMap[credOpt.Tag] = cred - allCreds = append(allCreds, cred) - providers[credOpt.Tag] = &singleCredentialProvider{cred: cred} - case "external": - cred, err := newExternalCredential(ctx, credOpt.Tag, credOpt.ExternalOptions, logger) - if err != nil { - return nil, nil, err - } - allCredentialMap[credOpt.Tag] = cred - allCreds = append(allCreds, cred) - providers[credOpt.Tag] = &singleCredentialProvider{cred: cred} - } - } - - // Pass 2: create balancer providers - for _, credOpt := range options.Credentials { - if credOpt.Type == "balancer" { - subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, allCredentialMap, credOpt.Tag) - if err != nil { - return nil, nil, err - } - providers[credOpt.Tag] = newBalancerProvider(subCredentials, credOpt.BalancerOptions.Strategy, time.Duration(credOpt.BalancerOptions.PollInterval), credOpt.BalancerOptions.RebalanceThreshold, logger) - } - } - - return providers, allCreds, nil -} - -func resolveCredentialTags(tags []string, allCredentials map[string]credential, parentTag string) ([]credential, error) { - credentials := make([]credential, 0, len(tags)) - for _, tag := range tags { - cred, exists := allCredentials[tag] - if !exists { - return nil, E.New("credential ", parentTag, " references unknown credential: ", tag) - } - credentials = append(credentials, cred) - } - if len(credentials) == 0 { - return nil, E.New("credential ", parentTag, " has no sub-credentials") - } - return credentials, nil -} - -func parseOCMRateLimitResetFromHeaders(headers http.Header) time.Time { - activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) - if activeLimitIdentifier != "" { - resetHeader := "x-" + activeLimitIdentifier + "-primary-reset-at" - if resetStr := headers.Get(resetHeader); resetStr != "" { - value, err := strconv.ParseInt(resetStr, 10, 64) - if err == nil { - return time.Unix(value, 0) - } - } - } - if retryAfter := headers.Get("Retry-After"); retryAfter != "" { - seconds, err := strconv.ParseInt(retryAfter, 10, 64) - if err == nil { - return time.Now().Add(time.Duration(seconds) * time.Second) - } - } - return time.Now().Add(5 * time.Minute) -} - -func validateOCMOptions(options option.OCMServiceOptions) error { - hasCredentials := len(options.Credentials) > 0 - hasLegacyPath := options.CredentialPath != "" - hasLegacyUsages := options.UsagesPath != "" - hasLegacyDetour := options.Detour != "" - - if hasCredentials && hasLegacyPath { - return E.New("credential_path and credentials are mutually exclusive") - } - if hasCredentials && hasLegacyUsages { - return E.New("usages_path and credentials are mutually exclusive; use usages_path on individual credentials") - } - if hasCredentials && hasLegacyDetour { - return E.New("detour and credentials are mutually exclusive; use detour on individual credentials") - } - - if hasCredentials { - tags := make(map[string]bool) - credentialTypes := make(map[string]string) - for _, cred := range options.Credentials { - if tags[cred.Tag] { - return E.New("duplicate credential tag: ", cred.Tag) - } - tags[cred.Tag] = true - credentialTypes[cred.Tag] = cred.Type - if cred.Type == "default" || cred.Type == "" { - if cred.DefaultOptions.Reserve5h > 99 { - return E.New("credential ", cred.Tag, ": reserve_5h must be at most 99") - } - if cred.DefaultOptions.ReserveWeekly > 99 { - return E.New("credential ", cred.Tag, ": reserve_weekly must be at most 99") - } - if cred.DefaultOptions.Limit5h > 100 { - return E.New("credential ", cred.Tag, ": limit_5h must be at most 100") - } - if cred.DefaultOptions.LimitWeekly > 100 { - return E.New("credential ", cred.Tag, ": limit_weekly must be at most 100") - } - if cred.DefaultOptions.Reserve5h > 0 && cred.DefaultOptions.Limit5h > 0 { - return E.New("credential ", cred.Tag, ": reserve_5h and limit_5h are mutually exclusive") - } - if cred.DefaultOptions.ReserveWeekly > 0 && cred.DefaultOptions.LimitWeekly > 0 { - return E.New("credential ", cred.Tag, ": reserve_weekly and limit_weekly are mutually exclusive") - } - } - if cred.Type == "external" { - if cred.ExternalOptions.Token == "" { - return E.New("credential ", cred.Tag, ": external credential requires token") - } - if cred.ExternalOptions.Reverse && cred.ExternalOptions.URL == "" { - return E.New("credential ", cred.Tag, ": reverse external credential requires url") - } - } - if cred.Type == "balancer" { - switch cred.BalancerOptions.Strategy { - case "", C.BalancerStrategyLeastUsed, C.BalancerStrategyRoundRobin, C.BalancerStrategyRandom, C.BalancerStrategyFallback: - default: - return E.New("credential ", cred.Tag, ": unknown balancer strategy: ", cred.BalancerOptions.Strategy) - } - if cred.BalancerOptions.RebalanceThreshold < 0 { - return E.New("credential ", cred.Tag, ": rebalance_threshold must not be negative") - } - } - } - - for _, user := range options.Users { - if user.Credential == "" { - return E.New("user ", user.Name, " must specify credential in multi-credential mode") - } - if !tags[user.Credential] { - return E.New("user ", user.Name, " references unknown credential: ", user.Credential) - } - if user.ExternalCredential != "" { - if !tags[user.ExternalCredential] { - return E.New("user ", user.Name, " references unknown external_credential: ", user.ExternalCredential) - } - if credentialTypes[user.ExternalCredential] != "external" { - return E.New("user ", user.Name, ": external_credential must reference an external type credential") - } - } - } - } - - return nil -} - -func validateOCMCompositeCredentialModes( - options option.OCMServiceOptions, - providers map[string]credentialProvider, -) error { - for _, credOpt := range options.Credentials { - if credOpt.Type != "balancer" { - continue - } - - provider, exists := providers[credOpt.Tag] - if !exists { - return E.New("unknown credential: ", credOpt.Tag) - } - - for _, subCred := range provider.allCredentials() { - if !subCred.isAvailable() { - continue - } - if subCred.ocmIsAPIKeyMode() { - return E.New( - "credential ", credOpt.Tag, - " references API key default credential ", subCred.tagName(), - "; balancer and fallback only support OAuth default credentials", - ) - } - } - } - - return nil -} - -func credentialForUser( - userConfigMap map[string]*option.OCMUser, - providers map[string]credentialProvider, - legacyProvider credentialProvider, - username string, -) (credentialProvider, error) { - if legacyProvider != nil { - return legacyProvider, nil - } - userConfig, exists := userConfigMap[username] - if !exists { - return nil, E.New("no credential mapping for user: ", username) - } - provider, exists := providers[userConfig.Credential] - if !exists { - return nil, E.New("unknown credential: ", userConfig.Credential) - } - return provider, nil -} - -func noUserCredentialProvider( - providers map[string]credentialProvider, - legacyProvider credentialProvider, - options option.OCMServiceOptions, -) credentialProvider { - if legacyProvider != nil { - return legacyProvider - } - if len(options.Credentials) > 0 { - tag := options.Credentials[0].Tag - return providers[tag] - } - return nil -} diff --git a/service/ocm/service.go b/service/ocm/service.go index 071cec8cc..101f90492 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -1,17 +1,13 @@ package ocm import ( - "bytes" "context" "encoding/json" "errors" "io" - "mime" "net/http" - "strconv" "strings" "sync" - "time" "github.com/sagernet/sing-box/adapter" boxService "github.com/sagernet/sing-box/adapter/service" @@ -21,14 +17,11 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" N "github.com/sagernet/sing/common/network" aTLS "github.com/sagernet/sing/common/tls" "github.com/go-chi/chi/v5" - "github.com/openai/openai-go/v3" - "github.com/openai/openai-go/v3/responses" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" ) @@ -160,71 +153,20 @@ func isReverseProxyHeader(header string) bool { } } -func normalizeRateLimitIdentifier(limitIdentifier string) string { - trimmedIdentifier := strings.TrimSpace(strings.ToLower(limitIdentifier)) - if trimmedIdentifier == "" { - return "" - } - return strings.ReplaceAll(trimmedIdentifier, "_", "-") -} - -func parseInt64Header(headers http.Header, headerName string) (int64, bool) { - headerValue := strings.TrimSpace(headers.Get(headerName)) - if headerValue == "" { - return 0, false - } - parsedValue, parseError := strconv.ParseInt(headerValue, 10, 64) - if parseError != nil { - return 0, false - } - return parsedValue, true -} - -func weeklyCycleHintForLimit(headers http.Header, limitIdentifier string) *WeeklyCycleHint { - normalizedLimitIdentifier := normalizeRateLimitIdentifier(limitIdentifier) - if normalizedLimitIdentifier == "" { - return nil - } - - windowHeader := "x-" + normalizedLimitIdentifier + "-secondary-window-minutes" - resetHeader := "x-" + normalizedLimitIdentifier + "-secondary-reset-at" - - windowMinutes, hasWindowMinutes := parseInt64Header(headers, windowHeader) - resetAtUnix, hasResetAt := parseInt64Header(headers, resetHeader) - if !hasWindowMinutes || !hasResetAt || windowMinutes <= 0 || resetAtUnix <= 0 { - return nil - } - - return &WeeklyCycleHint{ - WindowMinutes: windowMinutes, - ResetAt: time.Unix(resetAtUnix, 0).UTC(), - } -} - -func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint { - activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) - if activeLimitIdentifier != "" { - if activeHint := weeklyCycleHintForLimit(headers, activeLimitIdentifier); activeHint != nil { - return activeHint - } - } - return weeklyCycleHintForLimit(headers, "codex") -} - type Service struct { boxService.Adapter - ctx context.Context - logger log.ContextLogger - options option.OCMServiceOptions - httpHeaders http.Header - listener *listener.Listener - tlsConfig tls.ServerConfig - httpServer *http.Server - userManager *UserManager - webSocketMutex sync.Mutex - webSocketGroup sync.WaitGroup - webSocketConns map[*webSocketSession]struct{} - shuttingDown bool + ctx context.Context + logger log.ContextLogger + options option.OCMServiceOptions + httpHeaders http.Header + listener *listener.Listener + tlsConfig tls.ServerConfig + httpServer *http.Server + userManager *UserManager + webSocketAccess sync.Mutex + webSocketGroup sync.WaitGroup + webSocketConns map[*webSocketSession]struct{} + shuttingDown bool // Legacy mode legacyCredential *defaultCredential @@ -361,562 +303,6 @@ func (s *Service) Start(stage adapter.StartStage) error { return nil } -func (s *Service) resolveCredentialProvider(username string) (credentialProvider, error) { - if len(s.options.Users) > 0 { - return credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) - } - provider := noUserCredentialProvider(s.providers, s.legacyProvider, s.options) - if provider == nil { - return nil, E.New("no credential available") - } - return provider, nil -} - -func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := log.ContextWithNewID(r.Context()) - if r.URL.Path == "/ocm/v1/status" { - s.handleStatusEndpoint(w, r) - return - } - - if r.URL.Path == "/ocm/v1/reverse" { - s.handleReverseConnect(ctx, w, r) - return - } - - path := r.URL.Path - if !strings.HasPrefix(path, "/v1/") { - writeJSONError(w, r, http.StatusNotFound, "invalid_request_error", "path must start with /v1/") - return - } - - var username string - if len(s.options.Users) > 0 { - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header") - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") - return - } - clientToken := strings.TrimPrefix(authHeader, "Bearer ") - if clientToken == authHeader { - s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format") - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") - return - } - var ok bool - username, ok = s.userManager.Authenticate(clientToken) - if !ok { - s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken) - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key") - return - } - } - - sessionID := r.Header.Get("session_id") - - // Resolve credential provider and user config - var provider credentialProvider - var userConfig *option.OCMUser - if len(s.options.Users) > 0 { - userConfig = s.userConfigMap[username] - var err error - provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) - if err != nil { - s.logger.ErrorContext(ctx, "resolve credential: ", err) - writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) - return - } - } else { - provider = noUserCredentialProvider(s.providers, s.legacyProvider, s.options) - } - if provider == nil { - writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available") - return - } - - provider.pollIfStale(s.ctx) - - selection := credentialSelectionForUser(userConfig) - - selectedCredential, isNew, err := provider.selectCredential(sessionID, selection) - if err != nil { - writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, err.Error())) - return - } - - if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") { - s.handleWebSocket(ctx, w, r, path, username, sessionID, userConfig, provider, selectedCredential, selection, isNew) - return - } - - if !selectedCredential.isExternal() && selectedCredential.ocmIsAPIKeyMode() { - // API key mode path handling - } else if !selectedCredential.isExternal() { - if path == "/v1/chat/completions" { - writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", - "chat completions endpoint is only available in API key mode") - return - } - } - - shouldTrackUsage := selectedCredential.usageTrackerOrNil() != nil && - (path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) - canRetryRequest := len(provider.allCredentials()) > 1 - - // Read body for model extraction and retry buffer when JSON replay is useful. - var bodyBytes []byte - var requestModel string - var requestServiceTier string - if r.Body != nil && (shouldTrackUsage || canRetryRequest) { - mediaType, _, parseErr := mime.ParseMediaType(r.Header.Get("Content-Type")) - isJSONRequest := parseErr == nil && (mediaType == "application/json" || strings.HasSuffix(mediaType, "+json")) - if isJSONRequest { - bodyBytes, err = io.ReadAll(r.Body) - if err != nil { - s.logger.ErrorContext(ctx, "read request body: ", err) - writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body") - return - } - var request struct { - Model string `json:"model"` - ServiceTier string `json:"service_tier"` - } - if json.Unmarshal(bodyBytes, &request) == nil { - requestModel = request.Model - requestServiceTier = request.ServiceTier - } - r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - } - } - - if isNew { - logParts := []any{"assigned credential ", selectedCredential.tagName()} - if sessionID != "" { - logParts = append(logParts, " for session ", sessionID) - } - if username != "" { - logParts = append(logParts, " by user ", username) - } - if requestModel != "" { - logParts = append(logParts, ", model=", requestModel) - } - if requestServiceTier == "priority" { - logParts = append(logParts, ", fast") - } - s.logger.DebugContext(ctx, logParts...) - } - - requestContext := selectedCredential.wrapRequestContext(ctx) - { - currentRequestContext := requestContext - requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() { - currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc) - })) - } - defer func() { - requestContext.cancelRequest() - }() - proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) - if err != nil { - s.logger.ErrorContext(ctx, "create proxy request: ", err) - writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error") - return - } - - response, err := selectedCredential.httpTransport().Do(proxyRequest) - if err != nil { - if r.Context().Err() != nil { - return - } - if requestContext.Err() != nil { - writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable while processing the request") - return - } - writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error()) - return - } - requestContext.releaseCredentialInterrupt() - - // Transparent 429 retry - for response.StatusCode == http.StatusTooManyRequests { - resetAt := parseOCMRateLimitResetFromHeaders(response.Header) - nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection) - needsBodyReplay := r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodDelete - selectedCredential.updateStateFromHeaders(response.Header) - if (needsBodyReplay && bodyBytes == nil) || nextCredential == nil { - response.Body.Close() - writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited") - return - } - response.Body.Close() - s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) - requestContext.cancelRequest() - requestContext = nextCredential.wrapRequestContext(ctx) - { - currentRequestContext := requestContext - requestContext.addInterruptLink(provider.linkProviderInterrupt(nextCredential, selection, func() { - currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc) - })) - } - retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) - if buildErr != nil { - s.logger.ErrorContext(ctx, "retry request: ", buildErr) - writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error()) - return - } - retryResponse, retryErr := nextCredential.httpTransport().Do(retryRequest) - if retryErr != nil { - if r.Context().Err() != nil { - return - } - if requestContext.Err() != nil { - writeCredentialUnavailableError(w, r, provider, nextCredential, selection, "credential became unavailable while retrying the request") - return - } - s.logger.ErrorContext(ctx, "retry request: ", retryErr) - writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error()) - return - } - requestContext.releaseCredentialInterrupt() - response = retryResponse - selectedCredential = nextCredential - } - defer response.Body.Close() - - selectedCredential.updateStateFromHeaders(response.Header) - - if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests { - body, _ := io.ReadAll(response.Body) - s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body)) - go selectedCredential.pollUsage(s.ctx) - writeJSONError(w, r, http.StatusInternalServerError, "api_error", - "proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body)) - return - } - - // Rewrite response headers for external users - if userConfig != nil && userConfig.ExternalCredential != "" { - s.rewriteResponseHeadersForExternalUser(response.Header, userConfig) - } - - for key, values := range response.Header { - if !isHopByHopHeader(key) && !isReverseProxyHeader(key) { - w.Header()[key] = values - } - } - w.WriteHeader(response.StatusCode) - - usageTracker := selectedCredential.usageTrackerOrNil() - if usageTracker != nil && response.StatusCode == http.StatusOK && - (path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) { - s.handleResponseWithTracking(ctx, w, response, usageTracker, path, requestModel, username) - } else { - mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) - if err == nil && mediaType != "text/event-stream" { - _, _ = io.Copy(w, response.Body) - return - } - flusher, ok := w.(http.Flusher) - if !ok { - s.logger.ErrorContext(ctx, "streaming not supported") - return - } - buffer := make([]byte, buf.BufferSize) - for { - n, err := response.Body.Read(buffer) - if n > 0 { - _, writeError := w.Write(buffer[:n]) - if writeError != nil { - s.logger.ErrorContext(ctx, "write streaming response: ", writeError) - return - } - flusher.Flush() - } - if err != nil { - return - } - } - } -} - -func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, path string, requestModel string, username string) { - isChatCompletions := path == "/v1/chat/completions" - weeklyCycleHint := extractWeeklyCycleHint(response.Header) - mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) - isStreaming := err == nil && mediaType == "text/event-stream" - if !isStreaming && !isChatCompletions && response.Header.Get("Content-Type") == "" { - isStreaming = true - } - if !isStreaming { - bodyBytes, err := io.ReadAll(response.Body) - if err != nil { - s.logger.ErrorContext(ctx, "read response body: ", err) - return - } - - var responseModel, serviceTier string - var inputTokens, outputTokens, cachedTokens int64 - - if isChatCompletions { - var chatCompletion openai.ChatCompletion - if json.Unmarshal(bodyBytes, &chatCompletion) == nil { - responseModel = chatCompletion.Model - serviceTier = string(chatCompletion.ServiceTier) - inputTokens = chatCompletion.Usage.PromptTokens - outputTokens = chatCompletion.Usage.CompletionTokens - cachedTokens = chatCompletion.Usage.PromptTokensDetails.CachedTokens - } - } else { - var responsesResponse responses.Response - if json.Unmarshal(bodyBytes, &responsesResponse) == nil { - responseModel = string(responsesResponse.Model) - serviceTier = string(responsesResponse.ServiceTier) - inputTokens = responsesResponse.Usage.InputTokens - outputTokens = responsesResponse.Usage.OutputTokens - cachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens - } - } - - if inputTokens > 0 || outputTokens > 0 { - if responseModel == "" { - responseModel = requestModel - } - if responseModel != "" { - contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens) - usageTracker.AddUsageWithCycleHint( - responseModel, - contextWindow, - inputTokens, - outputTokens, - cachedTokens, - serviceTier, - username, - time.Now(), - weeklyCycleHint, - ) - } - } - - _, _ = writer.Write(bodyBytes) - return - } - - flusher, ok := writer.(http.Flusher) - if !ok { - s.logger.ErrorContext(ctx, "streaming not supported") - return - } - - var inputTokens, outputTokens, cachedTokens int64 - var responseModel, serviceTier string - buffer := make([]byte, buf.BufferSize) - var leftover []byte - - for { - n, err := response.Body.Read(buffer) - if n > 0 { - data := append(leftover, buffer[:n]...) - lines := bytes.Split(data, []byte("\n")) - - if err == nil { - leftover = lines[len(lines)-1] - lines = lines[:len(lines)-1] - } else { - leftover = nil - } - - for _, line := range lines { - line = bytes.TrimSpace(line) - if len(line) == 0 { - continue - } - - if bytes.HasPrefix(line, []byte("data: ")) { - eventData := bytes.TrimPrefix(line, []byte("data: ")) - if bytes.Equal(eventData, []byte("[DONE]")) { - continue - } - - if isChatCompletions { - var chatChunk openai.ChatCompletionChunk - if json.Unmarshal(eventData, &chatChunk) == nil { - if chatChunk.Model != "" { - responseModel = chatChunk.Model - } - if chatChunk.ServiceTier != "" { - serviceTier = string(chatChunk.ServiceTier) - } - if chatChunk.Usage.PromptTokens > 0 { - inputTokens = chatChunk.Usage.PromptTokens - cachedTokens = chatChunk.Usage.PromptTokensDetails.CachedTokens - } - if chatChunk.Usage.CompletionTokens > 0 { - outputTokens = chatChunk.Usage.CompletionTokens - } - } - } else { - var streamEvent responses.ResponseStreamEventUnion - if json.Unmarshal(eventData, &streamEvent) == nil { - if streamEvent.Type == "response.completed" { - completedEvent := streamEvent.AsResponseCompleted() - if string(completedEvent.Response.Model) != "" { - responseModel = string(completedEvent.Response.Model) - } - if completedEvent.Response.ServiceTier != "" { - serviceTier = string(completedEvent.Response.ServiceTier) - } - if completedEvent.Response.Usage.InputTokens > 0 { - inputTokens = completedEvent.Response.Usage.InputTokens - cachedTokens = completedEvent.Response.Usage.InputTokensDetails.CachedTokens - } - if completedEvent.Response.Usage.OutputTokens > 0 { - outputTokens = completedEvent.Response.Usage.OutputTokens - } - } - } - } - } - } - - _, writeError := writer.Write(buffer[:n]) - if writeError != nil { - s.logger.ErrorContext(ctx, "write streaming response: ", writeError) - return - } - flusher.Flush() - } - - if err != nil { - if responseModel == "" { - responseModel = requestModel - } - - if inputTokens > 0 || outputTokens > 0 { - if responseModel != "" { - contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens) - usageTracker.AddUsageWithCycleHint( - responseModel, - contextWindow, - inputTokens, - outputTokens, - cachedTokens, - serviceTier, - username, - time.Now(), - weeklyCycleHint, - ) - } - } - return - } - } -} - -func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed") - return - } - - if len(s.options.Users) == 0 { - writeJSONError(w, r, http.StatusForbidden, "authentication_error", "status endpoint requires user authentication") - return - } - - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") - return - } - clientToken := strings.TrimPrefix(authHeader, "Bearer ") - if clientToken == authHeader { - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") - return - } - username, ok := s.userManager.Authenticate(clientToken) - if !ok { - writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key") - return - } - - userConfig := s.userConfigMap[username] - if userConfig == nil { - writeJSONError(w, r, http.StatusInternalServerError, "api_error", "user config not found") - return - } - - provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) - if err != nil { - writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) - return - } - - provider.pollIfStale(r.Context()) - avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]float64{ - "five_hour_utilization": avgFiveHour, - "weekly_utilization": avgWeekly, - "plan_weight": totalWeight, - }) -} - -func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) (float64, float64, float64) { - var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64 - for _, cred := range provider.allCredentials() { - if !cred.isAvailable() { - continue - } - if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential { - continue - } - if !userConfig.AllowExternalUsage && cred.isExternal() { - continue - } - weight := cred.planWeight() - remaining5h := cred.fiveHourCap() - cred.fiveHourUtilization() - if remaining5h < 0 { - remaining5h = 0 - } - remainingWeekly := cred.weeklyCap() - cred.weeklyUtilization() - if remainingWeekly < 0 { - remainingWeekly = 0 - } - totalWeightedRemaining5h += remaining5h * weight - totalWeightedRemainingWeekly += remainingWeekly * weight - totalWeight += weight - } - if totalWeight == 0 { - return 100, 100, 0 - } - return 100 - totalWeightedRemaining5h/totalWeight, - 100 - totalWeightedRemainingWeekly/totalWeight, - totalWeight -} - -func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.OCMUser) { - provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, userConfig.Name) - if err != nil { - return - } - - avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) - - activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) - if activeLimitIdentifier == "" { - activeLimitIdentifier = "codex" - } - - headers.Set("x-"+activeLimitIdentifier+"-primary-used-percent", strconv.FormatFloat(avgFiveHour, 'f', 2, 64)) - headers.Set("x-"+activeLimitIdentifier+"-secondary-used-percent", strconv.FormatFloat(avgWeekly, 'f', 2, 64)) - if totalWeight > 0 { - headers.Set("X-OCM-Plan-Weight", strconv.FormatFloat(totalWeight, 'f', -1, 64)) - } -} - func (s *Service) InterfaceUpdated() { for _, cred := range s.allCredentials { extCred, ok := cred.(*externalCredential) @@ -952,8 +338,8 @@ func (s *Service) Close() error { } func (s *Service) registerWebSocketSession(session *webSocketSession) bool { - s.webSocketMutex.Lock() - defer s.webSocketMutex.Unlock() + s.webSocketAccess.Lock() + defer s.webSocketAccess.Unlock() if s.shuttingDown { return false @@ -965,12 +351,12 @@ func (s *Service) registerWebSocketSession(session *webSocketSession) bool { } func (s *Service) unregisterWebSocketSession(session *webSocketSession) { - s.webSocketMutex.Lock() + s.webSocketAccess.Lock() _, loaded := s.webSocketConns[session] if loaded { delete(s.webSocketConns, session) } - s.webSocketMutex.Unlock() + s.webSocketAccess.Unlock() if loaded { s.webSocketGroup.Done() @@ -978,28 +364,28 @@ func (s *Service) unregisterWebSocketSession(session *webSocketSession) { } func (s *Service) isShuttingDown() bool { - s.webSocketMutex.Lock() - defer s.webSocketMutex.Unlock() + s.webSocketAccess.Lock() + defer s.webSocketAccess.Unlock() return s.shuttingDown } func (s *Service) interruptWebSocketSessionsForCredential(tag string) { - s.webSocketMutex.Lock() + s.webSocketAccess.Lock() var toClose []*webSocketSession for session := range s.webSocketConns { if session.credentialTag == tag { toClose = append(toClose, session) } } - s.webSocketMutex.Unlock() + s.webSocketAccess.Unlock() for _, session := range toClose { session.Close() } } func (s *Service) startWebSocketShutdown() []*webSocketSession { - s.webSocketMutex.Lock() - defer s.webSocketMutex.Unlock() + s.webSocketAccess.Lock() + defer s.webSocketAccess.Unlock() s.shuttingDown = true diff --git a/service/ocm/service_handler.go b/service/ocm/service_handler.go new file mode 100644 index 000000000..9fb9c96d7 --- /dev/null +++ b/service/ocm/service_handler.go @@ -0,0 +1,504 @@ +package ocm + +import ( + "bytes" + "context" + "encoding/json" + "io" + "mime" + "net/http" + "strconv" + "strings" + "time" + + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/responses" +) + +func weeklyCycleHintForLimit(headers http.Header, limitIdentifier string) *WeeklyCycleHint { + normalizedLimitIdentifier := normalizeRateLimitIdentifier(limitIdentifier) + if normalizedLimitIdentifier == "" { + return nil + } + + windowHeader := "x-" + normalizedLimitIdentifier + "-secondary-window-minutes" + resetHeader := "x-" + normalizedLimitIdentifier + "-secondary-reset-at" + + windowMinutes, hasWindowMinutes := parseInt64Header(headers, windowHeader) + resetAtUnix, hasResetAt := parseInt64Header(headers, resetHeader) + if !hasWindowMinutes || !hasResetAt || windowMinutes <= 0 || resetAtUnix <= 0 { + return nil + } + + return &WeeklyCycleHint{ + WindowMinutes: windowMinutes, + ResetAt: time.Unix(resetAtUnix, 0).UTC(), + } +} + +func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint { + activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) + if activeLimitIdentifier != "" { + if activeHint := weeklyCycleHintForLimit(headers, activeLimitIdentifier); activeHint != nil { + return activeHint + } + } + return weeklyCycleHintForLimit(headers, "codex") +} + +func (s *Service) resolveCredentialProvider(username string) (credentialProvider, error) { + if len(s.options.Users) > 0 { + return credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) + } + provider := noUserCredentialProvider(s.providers, s.legacyProvider, s.options) + if provider == nil { + return nil, E.New("no credential available") + } + return provider, nil +} + +func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := log.ContextWithNewID(r.Context()) + if r.URL.Path == "/ocm/v1/status" { + s.handleStatusEndpoint(w, r) + return + } + + if r.URL.Path == "/ocm/v1/reverse" { + s.handleReverseConnect(ctx, w, r) + return + } + + path := r.URL.Path + if !strings.HasPrefix(path, "/v1/") { + writeJSONError(w, r, http.StatusNotFound, "invalid_request_error", "path must start with /v1/") + return + } + + var username string + if len(s.options.Users) > 0 { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header") + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") + return + } + clientToken := strings.TrimPrefix(authHeader, "Bearer ") + if clientToken == authHeader { + s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format") + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") + return + } + var ok bool + username, ok = s.userManager.Authenticate(clientToken) + if !ok { + s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken) + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key") + return + } + } + + sessionID := r.Header.Get("session_id") + + // Resolve credential provider and user config + var provider credentialProvider + var userConfig *option.OCMUser + if len(s.options.Users) > 0 { + userConfig = s.userConfigMap[username] + var err error + provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) + if err != nil { + s.logger.ErrorContext(ctx, "resolve credential: ", err) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) + return + } + } else { + provider = noUserCredentialProvider(s.providers, s.legacyProvider, s.options) + } + if provider == nil { + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available") + return + } + + provider.pollIfStale(s.ctx) + + selection := credentialSelectionForUser(userConfig) + + selectedCredential, isNew, err := provider.selectCredential(sessionID, selection) + if err != nil { + writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, err.Error())) + return + } + + if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") { + s.handleWebSocket(ctx, w, r, path, username, sessionID, userConfig, provider, selectedCredential, selection, isNew) + return + } + + if !selectedCredential.isExternal() && selectedCredential.ocmIsAPIKeyMode() { + // API key mode path handling + } else if !selectedCredential.isExternal() { + if path == "/v1/chat/completions" { + writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", + "chat completions endpoint is only available in API key mode") + return + } + } + + shouldTrackUsage := selectedCredential.usageTrackerOrNil() != nil && + (path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) + canRetryRequest := len(provider.allCredentials()) > 1 + + // Read body for model extraction and retry buffer when JSON replay is useful. + var bodyBytes []byte + var requestModel string + var requestServiceTier string + if r.Body != nil && (shouldTrackUsage || canRetryRequest) { + mediaType, _, parseErr := mime.ParseMediaType(r.Header.Get("Content-Type")) + isJSONRequest := parseErr == nil && (mediaType == "application/json" || strings.HasSuffix(mediaType, "+json")) + if isJSONRequest { + bodyBytes, err = io.ReadAll(r.Body) + if err != nil { + s.logger.ErrorContext(ctx, "read request body: ", err) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body") + return + } + var request struct { + Model string `json:"model"` + ServiceTier string `json:"service_tier"` + } + if json.Unmarshal(bodyBytes, &request) == nil { + requestModel = request.Model + requestServiceTier = request.ServiceTier + } + r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + } + + if isNew { + logParts := []any{"assigned credential ", selectedCredential.tagName()} + if sessionID != "" { + logParts = append(logParts, " for session ", sessionID) + } + if username != "" { + logParts = append(logParts, " by user ", username) + } + if requestModel != "" { + logParts = append(logParts, ", model=", requestModel) + } + if requestServiceTier == "priority" { + logParts = append(logParts, ", fast") + } + s.logger.DebugContext(ctx, logParts...) + } + + requestContext := selectedCredential.wrapRequestContext(ctx) + { + currentRequestContext := requestContext + requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() { + currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc) + })) + } + defer func() { + requestContext.cancelRequest() + }() + proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) + if err != nil { + s.logger.ErrorContext(ctx, "create proxy request: ", err) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error") + return + } + + response, err := selectedCredential.httpClient().Do(proxyRequest) + if err != nil { + if r.Context().Err() != nil { + return + } + if requestContext.Err() != nil { + writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable while processing the request") + return + } + writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error()) + return + } + requestContext.releaseCredentialInterrupt() + + // Transparent 429 retry + for response.StatusCode == http.StatusTooManyRequests { + resetAt := parseOCMRateLimitResetFromHeaders(response.Header) + nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection) + needsBodyReplay := r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodDelete + selectedCredential.updateStateFromHeaders(response.Header) + if (needsBodyReplay && bodyBytes == nil) || nextCredential == nil { + response.Body.Close() + writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited") + return + } + response.Body.Close() + s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) + requestContext.cancelRequest() + requestContext = nextCredential.wrapRequestContext(ctx) + { + currentRequestContext := requestContext + requestContext.addInterruptLink(provider.linkProviderInterrupt(nextCredential, selection, func() { + currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc) + })) + } + retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) + if buildErr != nil { + s.logger.ErrorContext(ctx, "retry request: ", buildErr) + writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error()) + return + } + retryResponse, retryErr := nextCredential.httpClient().Do(retryRequest) + if retryErr != nil { + if r.Context().Err() != nil { + return + } + if requestContext.Err() != nil { + writeCredentialUnavailableError(w, r, provider, nextCredential, selection, "credential became unavailable while retrying the request") + return + } + s.logger.ErrorContext(ctx, "retry request: ", retryErr) + writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error()) + return + } + requestContext.releaseCredentialInterrupt() + response = retryResponse + selectedCredential = nextCredential + } + defer response.Body.Close() + + selectedCredential.updateStateFromHeaders(response.Header) + + if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests { + body, _ := io.ReadAll(response.Body) + s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body)) + go selectedCredential.pollUsage(s.ctx) + writeJSONError(w, r, http.StatusInternalServerError, "api_error", + "proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body)) + return + } + + // Rewrite response headers for external users + if userConfig != nil && userConfig.ExternalCredential != "" { + s.rewriteResponseHeadersForExternalUser(response.Header, userConfig) + } + + for key, values := range response.Header { + if !isHopByHopHeader(key) && !isReverseProxyHeader(key) { + w.Header()[key] = values + } + } + w.WriteHeader(response.StatusCode) + + usageTracker := selectedCredential.usageTrackerOrNil() + if usageTracker != nil && response.StatusCode == http.StatusOK && + (path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) { + s.handleResponseWithTracking(ctx, w, response, usageTracker, path, requestModel, username) + } else { + mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) + if err == nil && mediaType != "text/event-stream" { + _, _ = io.Copy(w, response.Body) + return + } + flusher, ok := w.(http.Flusher) + if !ok { + s.logger.ErrorContext(ctx, "streaming not supported") + return + } + buffer := make([]byte, buf.BufferSize) + for { + n, err := response.Body.Read(buffer) + if n > 0 { + _, writeError := w.Write(buffer[:n]) + if writeError != nil { + s.logger.ErrorContext(ctx, "write streaming response: ", writeError) + return + } + flusher.Flush() + } + if err != nil { + return + } + } + } +} + +func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, path string, requestModel string, username string) { + isChatCompletions := path == "/v1/chat/completions" + weeklyCycleHint := extractWeeklyCycleHint(response.Header) + mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) + isStreaming := err == nil && mediaType == "text/event-stream" + if !isStreaming && !isChatCompletions && response.Header.Get("Content-Type") == "" { + isStreaming = true + } + if !isStreaming { + bodyBytes, err := io.ReadAll(response.Body) + if err != nil { + s.logger.ErrorContext(ctx, "read response body: ", err) + return + } + + var responseModel, serviceTier string + var inputTokens, outputTokens, cachedTokens int64 + + if isChatCompletions { + var chatCompletion openai.ChatCompletion + if json.Unmarshal(bodyBytes, &chatCompletion) == nil { + responseModel = chatCompletion.Model + serviceTier = string(chatCompletion.ServiceTier) + inputTokens = chatCompletion.Usage.PromptTokens + outputTokens = chatCompletion.Usage.CompletionTokens + cachedTokens = chatCompletion.Usage.PromptTokensDetails.CachedTokens + } + } else { + var responsesResponse responses.Response + if json.Unmarshal(bodyBytes, &responsesResponse) == nil { + responseModel = string(responsesResponse.Model) + serviceTier = string(responsesResponse.ServiceTier) + inputTokens = responsesResponse.Usage.InputTokens + outputTokens = responsesResponse.Usage.OutputTokens + cachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens + } + } + + if inputTokens > 0 || outputTokens > 0 { + if responseModel == "" { + responseModel = requestModel + } + if responseModel != "" { + contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens) + usageTracker.AddUsageWithCycleHint( + responseModel, + contextWindow, + inputTokens, + outputTokens, + cachedTokens, + serviceTier, + username, + time.Now(), + weeklyCycleHint, + ) + } + } + + _, _ = writer.Write(bodyBytes) + return + } + + flusher, ok := writer.(http.Flusher) + if !ok { + s.logger.ErrorContext(ctx, "streaming not supported") + return + } + + var inputTokens, outputTokens, cachedTokens int64 + var responseModel, serviceTier string + buffer := make([]byte, buf.BufferSize) + var leftover []byte + + for { + n, err := response.Body.Read(buffer) + if n > 0 { + data := append(leftover, buffer[:n]...) + lines := bytes.Split(data, []byte("\n")) + + if err == nil { + leftover = lines[len(lines)-1] + lines = lines[:len(lines)-1] + } else { + leftover = nil + } + + for _, line := range lines { + line = bytes.TrimSpace(line) + if len(line) == 0 { + continue + } + + if bytes.HasPrefix(line, []byte("data: ")) { + eventData := bytes.TrimPrefix(line, []byte("data: ")) + if bytes.Equal(eventData, []byte("[DONE]")) { + continue + } + + if isChatCompletions { + var chatChunk openai.ChatCompletionChunk + if json.Unmarshal(eventData, &chatChunk) == nil { + if chatChunk.Model != "" { + responseModel = chatChunk.Model + } + if chatChunk.ServiceTier != "" { + serviceTier = string(chatChunk.ServiceTier) + } + if chatChunk.Usage.PromptTokens > 0 { + inputTokens = chatChunk.Usage.PromptTokens + cachedTokens = chatChunk.Usage.PromptTokensDetails.CachedTokens + } + if chatChunk.Usage.CompletionTokens > 0 { + outputTokens = chatChunk.Usage.CompletionTokens + } + } + } else { + var streamEvent responses.ResponseStreamEventUnion + if json.Unmarshal(eventData, &streamEvent) == nil { + if streamEvent.Type == "response.completed" { + completedEvent := streamEvent.AsResponseCompleted() + if string(completedEvent.Response.Model) != "" { + responseModel = string(completedEvent.Response.Model) + } + if completedEvent.Response.ServiceTier != "" { + serviceTier = string(completedEvent.Response.ServiceTier) + } + if completedEvent.Response.Usage.InputTokens > 0 { + inputTokens = completedEvent.Response.Usage.InputTokens + cachedTokens = completedEvent.Response.Usage.InputTokensDetails.CachedTokens + } + if completedEvent.Response.Usage.OutputTokens > 0 { + outputTokens = completedEvent.Response.Usage.OutputTokens + } + } + } + } + } + } + + _, writeError := writer.Write(buffer[:n]) + if writeError != nil { + s.logger.ErrorContext(ctx, "write streaming response: ", writeError) + return + } + flusher.Flush() + } + + if err != nil { + if responseModel == "" { + responseModel = requestModel + } + + if inputTokens > 0 || outputTokens > 0 { + if responseModel != "" { + contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens) + usageTracker.AddUsageWithCycleHint( + responseModel, + contextWindow, + inputTokens, + outputTokens, + cachedTokens, + serviceTier, + username, + time.Now(), + weeklyCycleHint, + ) + } + } + return + } + } +} diff --git a/service/ocm/service_status.go b/service/ocm/service_status.go new file mode 100644 index 000000000..29b95d063 --- /dev/null +++ b/service/ocm/service_status.go @@ -0,0 +1,114 @@ +package ocm + +import ( + "encoding/json" + "net/http" + "strconv" + "strings" + + "github.com/sagernet/sing-box/option" +) + +func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed") + return + } + + if len(s.options.Users) == 0 { + writeJSONError(w, r, http.StatusForbidden, "authentication_error", "status endpoint requires user authentication") + return + } + + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") + return + } + clientToken := strings.TrimPrefix(authHeader, "Bearer ") + if clientToken == authHeader { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") + return + } + username, ok := s.userManager.Authenticate(clientToken) + if !ok { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key") + return + } + + userConfig := s.userConfigMap[username] + if userConfig == nil { + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "user config not found") + return + } + + provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) + if err != nil { + writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) + return + } + + provider.pollIfStale(r.Context()) + avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]float64{ + "five_hour_utilization": avgFiveHour, + "weekly_utilization": avgWeekly, + "plan_weight": totalWeight, + }) +} + +func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) (float64, float64, float64) { + var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64 + for _, cred := range provider.allCredentials() { + if !cred.isAvailable() { + continue + } + if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential { + continue + } + if !userConfig.AllowExternalUsage && cred.isExternal() { + continue + } + weight := cred.planWeight() + remaining5h := cred.fiveHourCap() - cred.fiveHourUtilization() + if remaining5h < 0 { + remaining5h = 0 + } + remainingWeekly := cred.weeklyCap() - cred.weeklyUtilization() + if remainingWeekly < 0 { + remainingWeekly = 0 + } + totalWeightedRemaining5h += remaining5h * weight + totalWeightedRemainingWeekly += remainingWeekly * weight + totalWeight += weight + } + if totalWeight == 0 { + return 100, 100, 0 + } + return 100 - totalWeightedRemaining5h/totalWeight, + 100 - totalWeightedRemainingWeekly/totalWeight, + totalWeight +} + +func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.OCMUser) { + provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, userConfig.Name) + if err != nil { + return + } + + avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) + + activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) + if activeLimitIdentifier == "" { + activeLimitIdentifier = "codex" + } + + headers.Set("x-"+activeLimitIdentifier+"-primary-used-percent", strconv.FormatFloat(avgFiveHour, 'f', 2, 64)) + headers.Set("x-"+activeLimitIdentifier+"-secondary-used-percent", strconv.FormatFloat(avgWeekly, 'f', 2, 64)) + if totalWeight > 0 { + headers.Set("X-OCM-Plan-Weight", strconv.FormatFloat(totalWeight, 'f', -1, 64)) + } +} diff --git a/service/ocm/service_user.go b/service/ocm/service_user.go index 494b981b9..b69655e9a 100644 --- a/service/ocm/service_user.go +++ b/service/ocm/service_user.go @@ -7,13 +7,13 @@ import ( ) type UserManager struct { - accessMutex sync.RWMutex + access sync.RWMutex tokenMap map[string]string } func (m *UserManager) UpdateUsers(users []option.OCMUser) { - m.accessMutex.Lock() - defer m.accessMutex.Unlock() + m.access.Lock() + defer m.access.Unlock() tokenMap := make(map[string]string, len(users)) for _, user := range users { tokenMap[user.Token] = user.Name @@ -22,8 +22,8 @@ func (m *UserManager) UpdateUsers(users []option.OCMUser) { } func (m *UserManager) Authenticate(token string) (string, bool) { - m.accessMutex.RLock() + m.access.RLock() username, found := m.tokenMap[token] - m.accessMutex.RUnlock() + m.access.RUnlock() return username, found }