ccm,ocm: watch credential_path and allow delayed credentials

This commit is contained in:
世界
2026-03-13 02:21:34 +08:00
parent a8934be7cd
commit c8d593503f
14 changed files with 688 additions and 68 deletions

View File

@@ -51,6 +51,10 @@ On macOS, credentials are read from the system keychain first, then fall back to
Refreshed tokens are automatically written back to the same location.
When `credential_path` points to a file, the service can start before the file exists. The credential becomes available automatically after the file is created or updated, and becomes unavailable immediately if the file is later removed or becomes invalid.
On macOS without an explicit `credential_path`, keychain changes are not watched. Automatic reload only applies to the credential file path.
Conflict with `credentials`.
#### credentials
@@ -76,7 +80,7 @@ Each credential has a `type` field (`default`, `balancer`, or `fallback`) and a
}
```
A single OAuth credential file. The `type` field can be omitted (defaults to `default`).
A single OAuth credential file. The `type` field can be omitted (defaults to `default`). The service can start before the file exists, and reloads file updates automatically.
- `credential_path`: Path to the credentials file. Same defaults as top-level `credential_path`.
- `usages_path`: Optional usage tracking file for this credential.

View File

@@ -51,6 +51,10 @@ Claude Code OAuth 凭据文件的路径。
刷新的令牌会自动写回相同位置。
`credential_path` 指向文件时,即使文件尚不存在,服务也可以启动。文件被创建或更新后,凭据会自动变为可用;如果文件之后被删除或变为无效,该凭据会立即变为不可用。
在 macOS 上如果未显式设置 `credential_path`,不会监听钥匙串变化。自动重载只作用于凭据文件路径。
`credentials` 冲突。
#### credentials
@@ -76,7 +80,7 @@ Claude Code OAuth 凭据文件的路径。
}
```
单个 OAuth 凭据文件。`type` 字段可以省略(默认为 `default`)。
单个 OAuth 凭据文件。`type` 字段可以省略(默认为 `default`)。即使文件尚不存在,服务也可以启动,并会自动重载文件更新。
- `credential_path`:凭据文件的路径。默认值与顶层 `credential_path` 相同。
- `usages_path`:此凭据的可选使用跟踪文件。

View File

@@ -49,6 +49,8 @@ If not specified, defaults to:
Refreshed tokens are automatically written back to the same location.
When `credential_path` points to a file, the service can start before the file exists. The credential becomes available automatically after the file is created or updated, and becomes unavailable immediately if the file is later removed or becomes invalid.
Conflict with `credentials`.
#### credentials
@@ -74,7 +76,7 @@ Each credential has a `type` field (`default`, `balancer`, or `fallback`) and a
}
```
A single OAuth credential file. The `type` field can be omitted (defaults to `default`).
A single OAuth credential file. The `type` field can be omitted (defaults to `default`). The service can start before the file exists, and reloads file updates automatically.
- `credential_path`: Path to the credentials file. Same defaults as top-level `credential_path`.
- `usages_path`: Optional usage tracking file for this credential.

View File

@@ -49,6 +49,8 @@ OpenAI OAuth 凭据文件的路径。
刷新的令牌会自动写回相同位置。
`credential_path` 指向文件时,即使文件尚不存在,服务也可以启动。文件被创建或更新后,凭据会自动变为可用;如果文件之后被删除或变为无效,该凭据会立即变为不可用。
`credentials` 冲突。
#### credentials
@@ -74,7 +76,7 @@ OpenAI OAuth 凭据文件的路径。
}
```
单个 OAuth 凭据文件。`type` 字段可以省略(默认为 `default`)。
单个 OAuth 凭据文件。`type` 字段可以省略(默认为 `default`)。即使文件尚不存在,服务也可以启动,并会自动重载文件更新。
- `credential_path`:凭据文件的路径。默认值与顶层 `credential_path` 相同。
- `usages_path`:此凭据的可选使用跟踪文件。

View File

@@ -9,6 +9,7 @@ import (
"os/user"
"path/filepath"
"runtime"
"slices"
"sync"
"time"
@@ -189,3 +190,24 @@ func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oaut
return &newCredentials, nil
}
func cloneCredentials(credentials *oauthCredentials) *oauthCredentials {
if credentials == nil {
return nil
}
cloned := *credentials
cloned.Scopes = append([]string(nil), credentials.Scopes...)
return &cloned
}
func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool {
if left == nil || right == nil {
return left == right
}
return left.AccessToken == right.AccessToken &&
left.RefreshToken == right.RefreshToken &&
left.ExpiresAt == right.ExpiresAt &&
slices.Equal(left.Scopes, right.Scopes) &&
left.SubscriptionType == right.SubscriptionType &&
left.IsMax == right.IsMax
}

View File

@@ -151,6 +151,10 @@ func (c *externalCredential) isExternal() bool {
return true
}
func (c *externalCredential) isAvailable() bool {
return true
}
func (c *externalCredential) isUsable() bool {
c.stateMutex.RLock()
if c.state.hardRateLimited {
@@ -210,6 +214,10 @@ func (c *externalCredential) earliestReset() time.Time {
return earliest
}
func (c *externalCredential) unavailableError() error {
return nil
}
func (c *externalCredential) getAccessToken() (string, error) {
return c.token, nil
}

View File

@@ -0,0 +1,141 @@
package ccm
import (
"path/filepath"
"time"
"github.com/sagernet/fswatch"
E "github.com/sagernet/sing/common/exceptions"
)
const credentialReloadRetryInterval = 2 * time.Second
func resolveCredentialFilePath(customPath string) (string, error) {
if customPath == "" {
var err error
customPath, err = getDefaultCredentialsPath()
if err != nil {
return "", err
}
}
if filepath.IsAbs(customPath) {
return customPath, nil
}
return filepath.Abs(customPath)
}
func (c *defaultCredential) ensureCredentialWatcher() error {
c.watcherAccess.Lock()
defer c.watcherAccess.Unlock()
if c.watcher != nil || c.credentialFilePath == "" {
return nil
}
if !c.watcherRetryAt.IsZero() && time.Now().Before(c.watcherRetryAt) {
return nil
}
watcher, err := fswatch.NewWatcher(fswatch.Options{
Path: []string{c.credentialFilePath},
Logger: c.logger,
Callback: func(string) {
err := c.reloadCredentials(true)
if err != nil {
c.logger.Warn("reload credentials for ", c.tag, ": ", err)
}
},
})
if err != nil {
c.watcherRetryAt = time.Now().Add(credentialReloadRetryInterval)
return err
}
err = watcher.Start()
if err != nil {
c.watcherRetryAt = time.Now().Add(credentialReloadRetryInterval)
return err
}
c.watcher = watcher
c.watcherRetryAt = time.Time{}
return nil
}
func (c *defaultCredential) retryCredentialReloadIfNeeded() {
c.stateMutex.RLock()
unavailable := c.state.unavailable
lastAttempt := c.state.lastCredentialLoadAttempt
c.stateMutex.RUnlock()
if !unavailable {
return
}
if !lastAttempt.IsZero() && time.Since(lastAttempt) < credentialReloadRetryInterval {
return
}
err := c.ensureCredentialWatcher()
if err != nil {
c.logger.Debug("start credential watcher for ", c.tag, ": ", err)
}
_ = c.reloadCredentials(false)
}
func (c *defaultCredential) reloadCredentials(force bool) error {
c.reloadAccess.Lock()
defer c.reloadAccess.Unlock()
c.stateMutex.RLock()
unavailable := c.state.unavailable
lastAttempt := c.state.lastCredentialLoadAttempt
c.stateMutex.RUnlock()
if !force {
if !unavailable {
return nil
}
if !lastAttempt.IsZero() && time.Since(lastAttempt) < credentialReloadRetryInterval {
return c.unavailableError()
}
}
c.stateMutex.Lock()
c.state.lastCredentialLoadAttempt = time.Now()
c.stateMutex.Unlock()
credentials, err := platformReadCredentials(c.credentialPath)
if err != nil {
return c.markCredentialsUnavailable(E.Cause(err, "read credentials"))
}
c.accessMutex.Lock()
c.credentials = credentials
c.accessMutex.Unlock()
c.stateMutex.Lock()
c.state.unavailable = false
c.state.lastCredentialLoadError = ""
c.state.accountType = credentials.SubscriptionType
c.checkTransitionLocked()
c.stateMutex.Unlock()
return nil
}
func (c *defaultCredential) markCredentialsUnavailable(err error) error {
c.accessMutex.Lock()
hadCredentials := c.credentials != nil
c.credentials = nil
c.accessMutex.Unlock()
c.stateMutex.Lock()
c.state.unavailable = true
c.state.lastCredentialLoadError = err.Error()
c.state.accountType = ""
shouldInterrupt := c.checkTransitionLocked()
c.stateMutex.Unlock()
if shouldInterrupt && hadCredentials {
c.interruptConnections()
}
return err
}

View File

@@ -17,6 +17,7 @@ import (
"sync/atomic"
"time"
"github.com/sagernet/fswatch"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/log"
@@ -29,30 +30,38 @@ import (
const defaultPollInterval = 60 * time.Minute
type credentialState struct {
fiveHourUtilization float64
fiveHourReset time.Time
weeklyUtilization float64
weeklyReset time.Time
hardRateLimited bool
rateLimitResetAt time.Time
accountType string
lastUpdated time.Time
consecutivePollFailures int
fiveHourUtilization float64
fiveHourReset time.Time
weeklyUtilization float64
weeklyReset time.Time
hardRateLimited bool
rateLimitResetAt time.Time
accountType string
lastUpdated time.Time
consecutivePollFailures int
unavailable bool
lastCredentialLoadAttempt time.Time
lastCredentialLoadError string
}
type defaultCredential struct {
tag string
credentialPath string
credentials *oauthCredentials
accessMutex sync.RWMutex
state credentialState
stateMutex sync.RWMutex
pollAccess sync.Mutex
reserve5h uint8
reserveWeekly uint8
usageTracker *AggregatedUsage
httpClient *http.Client
logger log.ContextLogger
tag string
credentialPath string
credentialFilePath string
credentials *oauthCredentials
accessMutex sync.RWMutex
state credentialState
stateMutex sync.RWMutex
pollAccess sync.Mutex
reloadAccess sync.Mutex
watcherAccess sync.Mutex
reserve5h uint8
reserveWeekly uint8
usageTracker *AggregatedUsage
httpClient *http.Client
logger log.ContextLogger
watcher *fswatch.Watcher
watcherRetryAt time.Time
// Connection interruption
onBecameUnusable func()
@@ -83,12 +92,14 @@ func (c *credentialRequestContext) cancelRequest() {
type credential interface {
tagName() string
isAvailable() bool
isUsable() bool
isExternal() bool
fiveHourUtilization() float64
weeklyUtilization() float64
markRateLimited(resetAt time.Time)
earliestReset() time.Time
unavailableError() error
getAccessToken() (string, error)
buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error)
@@ -160,13 +171,18 @@ func newDefaultCredential(ctx context.Context, tag string, options option.CCMDef
}
func (c *defaultCredential) start() error {
credentials, err := platformReadCredentials(c.credentialPath)
credentialFilePath, err := resolveCredentialFilePath(c.credentialPath)
if err != nil {
return E.Cause(err, "read credentials for ", c.tag)
return E.Cause(err, "resolve credential path for ", c.tag)
}
c.credentials = credentials
if credentials.SubscriptionType != "" {
c.state.accountType = credentials.SubscriptionType
c.credentialFilePath = credentialFilePath
err = c.ensureCredentialWatcher()
if err != nil {
c.logger.Debug("start credential watcher for ", c.tag, ": ", err)
}
err = c.reloadCredentials(true)
if err != nil {
c.logger.Warn("initial credential load for ", c.tag, ": ", err)
}
if c.usageTracker != nil {
err = c.usageTracker.Load()
@@ -178,33 +194,68 @@ func (c *defaultCredential) start() error {
}
func (c *defaultCredential) getAccessToken() (string, error) {
c.retryCredentialReloadIfNeeded()
c.accessMutex.RLock()
if !c.credentials.needsRefresh() {
if c.credentials != nil && !c.credentials.needsRefresh() {
token := c.credentials.AccessToken
c.accessMutex.RUnlock()
return token, nil
}
c.accessMutex.RUnlock()
err := c.reloadCredentials(true)
if err == nil {
c.accessMutex.RLock()
if c.credentials != nil && !c.credentials.needsRefresh() {
token := c.credentials.AccessToken
c.accessMutex.RUnlock()
return token, nil
}
c.accessMutex.RUnlock()
}
c.accessMutex.Lock()
defer c.accessMutex.Unlock()
if c.credentials == nil {
return "", c.unavailableError()
}
if !c.credentials.needsRefresh() {
return c.credentials.AccessToken, nil
}
baseCredentials := cloneCredentials(c.credentials)
newCredentials, err := refreshToken(c.httpClient, c.credentials)
if err != nil {
return "", err
}
c.credentials = newCredentials
if newCredentials.SubscriptionType != "" {
latestCredentials, latestErr := platformReadCredentials(c.credentialPath)
if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) {
c.credentials = latestCredentials
c.stateMutex.Lock()
c.state.accountType = newCredentials.SubscriptionType
c.state.unavailable = false
c.state.lastCredentialLoadAttempt = time.Now()
c.state.lastCredentialLoadError = ""
c.state.accountType = latestCredentials.SubscriptionType
c.checkTransitionLocked()
c.stateMutex.Unlock()
if !latestCredentials.needsRefresh() {
return latestCredentials.AccessToken, nil
}
return "", E.New("credential ", c.tag, " changed while refreshing")
}
c.credentials = newCredentials
c.stateMutex.Lock()
c.state.unavailable = false
c.state.lastCredentialLoadAttempt = time.Now()
c.state.lastCredentialLoadError = ""
c.state.accountType = newCredentials.SubscriptionType
c.checkTransitionLocked()
c.stateMutex.Unlock()
err = platformWriteCredentials(newCredentials, c.credentialPath)
if err != nil {
c.logger.Warn("persist refreshed token for ", c.tag, ": ", err)
@@ -299,7 +350,13 @@ func (c *defaultCredential) markRateLimited(resetAt time.Time) {
}
func (c *defaultCredential) isUsable() bool {
c.retryCredentialReloadIfNeeded()
c.stateMutex.RLock()
if c.state.unavailable {
c.stateMutex.RUnlock()
return false
}
if c.state.hardRateLimited {
if time.Now().Before(c.state.rateLimitResetAt) {
c.stateMutex.RUnlock()
@@ -332,7 +389,7 @@ func (c *defaultCredential) checkReservesLocked() bool {
// checkTransitionLocked detects usable→unusable transition.
// Must be called with stateMutex write lock held.
func (c *defaultCredential) checkTransitionLocked() bool {
unusable := c.state.hardRateLimited || !c.checkReservesLocked()
unusable := c.state.unavailable || c.state.hardRateLimited || !c.checkReservesLocked()
if unusable && !c.interrupted {
c.interrupted = true
return true
@@ -375,6 +432,26 @@ func (c *defaultCredential) weeklyUtilization() float64 {
return c.state.weeklyUtilization
}
func (c *defaultCredential) isAvailable() bool {
c.retryCredentialReloadIfNeeded()
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
return !c.state.unavailable
}
func (c *defaultCredential) unavailableError() error {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
if !c.state.unavailable {
return nil
}
if c.state.lastCredentialLoadError == "" {
return E.New("credential ", c.tag, " is unavailable")
}
return E.New("credential ", c.tag, " is unavailable: ", c.state.lastCredentialLoadError)
}
func (c *defaultCredential) lastUpdatedTime() time.Time {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
@@ -403,6 +480,9 @@ func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duratio
func (c *defaultCredential) earliestReset() time.Time {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
if c.state.unavailable {
return time.Time{}
}
if c.state.hardRateLimited {
return c.state.rateLimitResetAt
}
@@ -430,6 +510,11 @@ func (c *defaultCredential) pollUsage(ctx context.Context) {
defer c.pollAccess.Unlock()
defer c.markUsagePollAttempted()
c.retryCredentialReloadIfNeeded()
if !c.isAvailable() {
return
}
accessToken, err := c.getAccessToken()
if err != nil {
c.logger.Error("poll usage for ", c.tag, ": get token: ", err)
@@ -528,6 +613,12 @@ func (c *defaultCredential) pollUsage(ctx context.Context) {
}
func (c *defaultCredential) close() {
if c.watcher != nil {
err := c.watcher.Close()
if err != nil {
c.logger.Error("close credential watcher for ", c.tag, ": ", err)
}
}
if c.usageTracker != nil {
c.usageTracker.cancelPendingSave()
err := c.usageTracker.Save()
@@ -622,6 +713,9 @@ func (p *singleCredentialProvider) selectCredential(_ string, filter func(creden
if filter != nil && !filter(p.cred) {
return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out")
}
if !p.cred.isAvailable() {
return nil, false, p.cred.unavailableError()
}
if !p.cred.isUsable() {
return nil, false, E.New("credential ", p.cred.tagName(), " is rate-limited")
}
@@ -866,13 +960,21 @@ func (p *fallbackProvider) allCredentials() []credential {
func (p *fallbackProvider) close() {}
func allCredentialsUnavailableError(credentials []credential) error {
var hasUnavailable bool
var earliest time.Time
for _, cred := range credentials {
if cred.unavailableError() != nil {
hasUnavailable = true
continue
}
resetAt := cred.earliestReset()
if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) {
earliest = resetAt
}
}
if hasUnavailable {
return E.New("all credentials unavailable")
}
if earliest.IsZero() {
return E.New("all credentials rate-limited")
}

View File

@@ -88,7 +88,11 @@ func unavailableCredentialMessage(provider credentialProvider, fallback string)
if provider == nil {
return fallback
}
return allCredentialsUnavailableError(provider.allCredentials()).Error()
message := allCredentialsUnavailableError(provider.allCredentials()).Error()
if message == "all credentials unavailable" && fallback != "" {
return fallback
}
return message
}
func writeRetryableUsageError(w http.ResponseWriter, r *http.Request) {
@@ -734,6 +738,9 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user
var totalFiveHour, totalWeekly float64
var count int
for _, cred := range provider.allCredentials() {
if !cred.isAvailable() {
continue
}
// Exclude the user's own external_credential (their contribution to us)
if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential {
continue

View File

@@ -175,3 +175,41 @@ func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oaut
return &newCredentials, nil
}
func cloneCredentials(credentials *oauthCredentials) *oauthCredentials {
if credentials == nil {
return nil
}
cloned := *credentials
if credentials.Tokens != nil {
clonedTokens := *credentials.Tokens
cloned.Tokens = &clonedTokens
}
if credentials.LastRefresh != nil {
lastRefresh := *credentials.LastRefresh
cloned.LastRefresh = &lastRefresh
}
return &cloned
}
func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool {
if left == nil || right == nil {
return left == right
}
if left.APIKey != right.APIKey {
return false
}
if (left.Tokens == nil) != (right.Tokens == nil) {
return false
}
if left.Tokens != nil && *left.Tokens != *right.Tokens {
return false
}
if (left.LastRefresh == nil) != (right.LastRefresh == nil) {
return false
}
if left.LastRefresh != nil && !left.LastRefresh.Equal(*right.LastRefresh) {
return false
}
return true
}

View File

@@ -157,6 +157,10 @@ func (c *externalCredential) isExternal() bool {
return true
}
func (c *externalCredential) isAvailable() bool {
return true
}
func (c *externalCredential) isUsable() bool {
c.stateMutex.RLock()
if c.state.hardRateLimited {
@@ -215,6 +219,10 @@ func (c *externalCredential) earliestReset() time.Time {
return earliest
}
func (c *externalCredential) unavailableError() error {
return nil
}
func (c *externalCredential) getAccessToken() (string, error) {
return c.token, nil
}

View File

@@ -0,0 +1,139 @@
package ocm
import (
"path/filepath"
"time"
"github.com/sagernet/fswatch"
E "github.com/sagernet/sing/common/exceptions"
)
const credentialReloadRetryInterval = 2 * time.Second
func resolveCredentialFilePath(customPath string) (string, error) {
if customPath == "" {
var err error
customPath, err = getDefaultCredentialsPath()
if err != nil {
return "", err
}
}
if filepath.IsAbs(customPath) {
return customPath, nil
}
return filepath.Abs(customPath)
}
func (c *defaultCredential) ensureCredentialWatcher() error {
c.watcherAccess.Lock()
defer c.watcherAccess.Unlock()
if c.watcher != nil || c.credentialFilePath == "" {
return nil
}
if !c.watcherRetryAt.IsZero() && time.Now().Before(c.watcherRetryAt) {
return nil
}
watcher, err := fswatch.NewWatcher(fswatch.Options{
Path: []string{c.credentialFilePath},
Logger: c.logger,
Callback: func(string) {
err := c.reloadCredentials(true)
if err != nil {
c.logger.Warn("reload credentials for ", c.tag, ": ", err)
}
},
})
if err != nil {
c.watcherRetryAt = time.Now().Add(credentialReloadRetryInterval)
return err
}
err = watcher.Start()
if err != nil {
c.watcherRetryAt = time.Now().Add(credentialReloadRetryInterval)
return err
}
c.watcher = watcher
c.watcherRetryAt = time.Time{}
return nil
}
func (c *defaultCredential) retryCredentialReloadIfNeeded() {
c.stateMutex.RLock()
unavailable := c.state.unavailable
lastAttempt := c.state.lastCredentialLoadAttempt
c.stateMutex.RUnlock()
if !unavailable {
return
}
if !lastAttempt.IsZero() && time.Since(lastAttempt) < credentialReloadRetryInterval {
return
}
err := c.ensureCredentialWatcher()
if err != nil {
c.logger.Debug("start credential watcher for ", c.tag, ": ", err)
}
_ = c.reloadCredentials(false)
}
func (c *defaultCredential) reloadCredentials(force bool) error {
c.reloadAccess.Lock()
defer c.reloadAccess.Unlock()
c.stateMutex.RLock()
unavailable := c.state.unavailable
lastAttempt := c.state.lastCredentialLoadAttempt
c.stateMutex.RUnlock()
if !force {
if !unavailable {
return nil
}
if !lastAttempt.IsZero() && time.Since(lastAttempt) < credentialReloadRetryInterval {
return c.unavailableError()
}
}
c.stateMutex.Lock()
c.state.lastCredentialLoadAttempt = time.Now()
c.stateMutex.Unlock()
credentials, err := platformReadCredentials(c.credentialPath)
if err != nil {
return c.markCredentialsUnavailable(E.Cause(err, "read credentials"))
}
c.accessMutex.Lock()
c.credentials = credentials
c.accessMutex.Unlock()
c.stateMutex.Lock()
c.state.unavailable = false
c.state.lastCredentialLoadError = ""
c.checkTransitionLocked()
c.stateMutex.Unlock()
return nil
}
func (c *defaultCredential) markCredentialsUnavailable(err error) error {
c.accessMutex.Lock()
hadCredentials := c.credentials != nil
c.credentials = nil
c.accessMutex.Unlock()
c.stateMutex.Lock()
c.state.unavailable = true
c.state.lastCredentialLoadError = err.Error()
shouldInterrupt := c.checkTransitionLocked()
c.stateMutex.Unlock()
if shouldInterrupt && hadCredentials {
c.interruptConnections()
}
return err
}

View File

@@ -16,6 +16,7 @@ import (
"sync/atomic"
"time"
"github.com/sagernet/fswatch"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/log"
@@ -29,31 +30,39 @@ import (
const defaultPollInterval = 60 * time.Minute
type credentialState struct {
fiveHourUtilization float64
fiveHourReset time.Time
weeklyUtilization float64
weeklyReset time.Time
hardRateLimited bool
rateLimitResetAt time.Time
accountType string
lastUpdated time.Time
consecutivePollFailures int
fiveHourUtilization float64
fiveHourReset time.Time
weeklyUtilization float64
weeklyReset time.Time
hardRateLimited bool
rateLimitResetAt time.Time
accountType string
lastUpdated time.Time
consecutivePollFailures int
unavailable bool
lastCredentialLoadAttempt time.Time
lastCredentialLoadError string
}
type defaultCredential struct {
tag string
credentialPath string
credentials *oauthCredentials
accessMutex sync.RWMutex
state credentialState
stateMutex sync.RWMutex
pollAccess sync.Mutex
reserve5h uint8
reserveWeekly uint8
usageTracker *AggregatedUsage
dialer N.Dialer
httpClient *http.Client
logger log.ContextLogger
tag string
credentialPath string
credentialFilePath string
credentials *oauthCredentials
accessMutex sync.RWMutex
state credentialState
stateMutex sync.RWMutex
pollAccess sync.Mutex
reloadAccess sync.Mutex
watcherAccess sync.Mutex
reserve5h uint8
reserveWeekly uint8
usageTracker *AggregatedUsage
dialer N.Dialer
httpClient *http.Client
logger log.ContextLogger
watcher *fswatch.Watcher
watcherRetryAt time.Time
// Connection interruption
onBecameUnusable func()
@@ -84,12 +93,14 @@ func (c *credentialRequestContext) cancelRequest() {
type credential interface {
tagName() string
isAvailable() bool
isUsable() bool
isExternal() bool
fiveHourUtilization() float64
weeklyUtilization() float64
markRateLimited(resetAt time.Time)
earliestReset() time.Time
unavailableError() error
getAccessToken() (string, error)
buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error)
@@ -169,11 +180,19 @@ func newDefaultCredential(ctx context.Context, tag string, options option.OCMDef
}
func (c *defaultCredential) start() error {
credentials, err := platformReadCredentials(c.credentialPath)
credentialFilePath, err := resolveCredentialFilePath(c.credentialPath)
if err != nil {
return E.Cause(err, "read credentials for ", c.tag)
return E.Cause(err, "resolve credential path for ", c.tag)
}
c.credentialFilePath = credentialFilePath
err = c.ensureCredentialWatcher()
if err != nil {
c.logger.Debug("start credential watcher for ", c.tag, ": ", err)
}
err = c.reloadCredentials(true)
if err != nil {
c.logger.Warn("initial credential load for ", c.tag, ": ", err)
}
c.credentials = credentials
if c.usageTracker != nil {
err = c.usageTracker.Load()
if err != nil {
@@ -184,27 +203,65 @@ func (c *defaultCredential) start() error {
}
func (c *defaultCredential) getAccessToken() (string, error) {
c.retryCredentialReloadIfNeeded()
c.accessMutex.RLock()
if !c.credentials.needsRefresh() {
if c.credentials != nil && !c.credentials.needsRefresh() {
token := c.credentials.getAccessToken()
c.accessMutex.RUnlock()
return token, nil
}
c.accessMutex.RUnlock()
err := c.reloadCredentials(true)
if err == nil {
c.accessMutex.RLock()
if c.credentials != nil && !c.credentials.needsRefresh() {
token := c.credentials.getAccessToken()
c.accessMutex.RUnlock()
return token, nil
}
c.accessMutex.RUnlock()
}
c.accessMutex.Lock()
defer c.accessMutex.Unlock()
if c.credentials == nil {
return "", c.unavailableError()
}
if !c.credentials.needsRefresh() {
return c.credentials.getAccessToken(), nil
}
baseCredentials := cloneCredentials(c.credentials)
newCredentials, err := refreshToken(c.httpClient, c.credentials)
if err != nil {
return "", err
}
latestCredentials, latestErr := platformReadCredentials(c.credentialPath)
if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) {
c.credentials = latestCredentials
c.stateMutex.Lock()
c.state.unavailable = false
c.state.lastCredentialLoadAttempt = time.Now()
c.state.lastCredentialLoadError = ""
c.checkTransitionLocked()
c.stateMutex.Unlock()
if !latestCredentials.needsRefresh() {
return latestCredentials.getAccessToken(), nil
}
return "", E.New("credential ", c.tag, " changed while refreshing")
}
c.credentials = newCredentials
c.stateMutex.Lock()
c.state.unavailable = false
c.state.lastCredentialLoadAttempt = time.Now()
c.state.lastCredentialLoadError = ""
c.checkTransitionLocked()
c.stateMutex.Unlock()
err = platformWriteCredentials(newCredentials, c.credentialPath)
if err != nil {
@@ -217,12 +274,18 @@ func (c *defaultCredential) getAccessToken() (string, error) {
func (c *defaultCredential) getAccountID() string {
c.accessMutex.RLock()
defer c.accessMutex.RUnlock()
if c.credentials == nil {
return ""
}
return c.credentials.getAccountID()
}
func (c *defaultCredential) isAPIKeyMode() bool {
c.accessMutex.RLock()
defer c.accessMutex.RUnlock()
if c.credentials == nil {
return false
}
return c.credentials.isAPIKeyMode()
}
@@ -296,7 +359,13 @@ func (c *defaultCredential) markRateLimited(resetAt time.Time) {
}
func (c *defaultCredential) isUsable() bool {
c.retryCredentialReloadIfNeeded()
c.stateMutex.RLock()
if c.state.unavailable {
c.stateMutex.RUnlock()
return false
}
if c.state.hardRateLimited {
if time.Now().Before(c.state.rateLimitResetAt) {
c.stateMutex.RUnlock()
@@ -329,7 +398,7 @@ func (c *defaultCredential) checkReservesLocked() bool {
// checkTransitionLocked detects usable→unusable transition.
// Must be called with stateMutex write lock held.
func (c *defaultCredential) checkTransitionLocked() bool {
unusable := c.state.hardRateLimited || !c.checkReservesLocked()
unusable := c.state.unavailable || c.state.hardRateLimited || !c.checkReservesLocked()
if unusable && !c.interrupted {
c.interrupted = true
return true
@@ -372,6 +441,26 @@ func (c *defaultCredential) weeklyUtilization() float64 {
return c.state.weeklyUtilization
}
func (c *defaultCredential) isAvailable() bool {
c.retryCredentialReloadIfNeeded()
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
return !c.state.unavailable
}
func (c *defaultCredential) unavailableError() error {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
if !c.state.unavailable {
return nil
}
if c.state.lastCredentialLoadError == "" {
return E.New("credential ", c.tag, " is unavailable")
}
return E.New("credential ", c.tag, " is unavailable: ", c.state.lastCredentialLoadError)
}
func (c *defaultCredential) lastUpdatedTime() time.Time {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
@@ -400,6 +489,9 @@ func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duratio
func (c *defaultCredential) earliestReset() time.Time {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
if c.state.unavailable {
return time.Time{}
}
if c.state.hardRateLimited {
return c.state.rateLimitResetAt
}
@@ -421,15 +513,20 @@ func isTimeoutError(err error) bool {
}
func (c *defaultCredential) pollUsage(ctx context.Context) {
if c.isAPIKeyMode() {
return
}
if !c.pollAccess.TryLock() {
return
}
defer c.pollAccess.Unlock()
defer c.markUsagePollAttempted()
c.retryCredentialReloadIfNeeded()
if !c.isAvailable() {
return
}
if c.isAPIKeyMode() {
return
}
accessToken, err := c.getAccessToken()
if err != nil {
c.logger.Error("poll usage for ", c.tag, ": get token: ", err)
@@ -546,6 +643,12 @@ func (c *defaultCredential) pollUsage(ctx context.Context) {
}
func (c *defaultCredential) close() {
if c.watcher != nil {
err := c.watcher.Close()
if err != nil {
c.logger.Error("close credential watcher for ", c.tag, ": ", err)
}
}
if c.usageTracker != nil {
c.usageTracker.cancelPendingSave()
err := c.usageTracker.Save()
@@ -662,6 +765,9 @@ func (p *singleCredentialProvider) selectCredential(_ string, filter func(creden
if filter != nil && !filter(p.cred) {
return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out")
}
if !p.cred.isAvailable() {
return nil, false, p.cred.unavailableError()
}
if !p.cred.isUsable() {
return nil, false, E.New("credential ", p.cred.tagName(), " is rate-limited")
}
@@ -702,6 +808,10 @@ type balancerProvider struct {
logger log.ContextLogger
}
func compositeCredentialSelectable(cred credential) bool {
return !cred.ocmIsAPIKeyMode()
}
func newBalancerProvider(credentials []credential, strategy string, pollInterval time.Duration, logger log.ContextLogger) *balancerProvider {
if pollInterval <= 0 {
pollInterval = defaultPollInterval
@@ -722,7 +832,7 @@ func (p *balancerProvider) selectCredential(sessionID string, filter func(creden
p.sessionMutex.RUnlock()
if exists {
for _, cred := range p.credentials {
if cred.tagName() == entry.tag && (filter == nil || filter(cred)) && cred.isUsable() {
if cred.tagName() == entry.tag && compositeCredentialSelectable(cred) && (filter == nil || filter(cred)) && cred.isUsable() {
return cred, false, nil
}
}
@@ -781,6 +891,9 @@ func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credentia
if filter != nil && !filter(cred) {
continue
}
if !compositeCredentialSelectable(cred) {
continue
}
if !cred.isUsable() {
continue
}
@@ -801,6 +914,9 @@ func (p *balancerProvider) pickRoundRobin(filter func(credential) bool) credenti
if filter != nil && !filter(candidate) {
continue
}
if !compositeCredentialSelectable(candidate) {
continue
}
if candidate.isUsable() {
return candidate
}
@@ -814,6 +930,9 @@ func (p *balancerProvider) pickRandom(filter func(credential) bool) credential {
if filter != nil && !filter(candidate) {
continue
}
if !compositeCredentialSelectable(candidate) {
continue
}
if candidate.isUsable() {
usable = append(usable, candidate)
}
@@ -869,6 +988,9 @@ func (p *fallbackProvider) selectCredential(_ string, filter func(credential) bo
if filter != nil && !filter(cred) {
continue
}
if !compositeCredentialSelectable(cred) {
continue
}
if cred.isUsable() {
return cred, false, nil
}
@@ -882,6 +1004,9 @@ func (p *fallbackProvider) onRateLimited(_ string, cred credential, resetAt time
if filter != nil && !filter(candidate) {
continue
}
if !compositeCredentialSelectable(candidate) {
continue
}
if candidate.isUsable() {
return candidate
}
@@ -904,13 +1029,21 @@ func (p *fallbackProvider) allCredentials() []credential {
func (p *fallbackProvider) close() {}
func allRateLimitedError(credentials []credential) error {
var hasUnavailable bool
var earliest time.Time
for _, cred := range credentials {
if cred.unavailableError() != nil {
hasUnavailable = true
continue
}
resetAt := cred.earliestReset()
if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) {
earliest = resetAt
}
}
if hasUnavailable {
return E.New("all credentials unavailable")
}
if earliest.IsZero() {
return E.New("all credentials rate-limited")
}
@@ -1090,6 +1223,9 @@ func validateOCMCompositeCredentialModes(
}
for _, subCred := range provider.allCredentials() {
if !subCred.isAvailable() {
continue
}
if subCred.ocmIsAPIKeyMode() {
return E.New(
"credential ", credOpt.Tag,

View File

@@ -96,7 +96,11 @@ func unavailableCredentialMessage(provider credentialProvider, fallback string)
if provider == nil {
return fallback
}
return allRateLimitedError(provider.allCredentials()).Error()
message := allRateLimitedError(provider.allCredentials()).Error()
if message == "all credentials unavailable" && fallback != "" {
return fallback
}
return message
}
func writeRetryableUsageError(w http.ResponseWriter, r *http.Request) {
@@ -806,6 +810,9 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user
var totalFiveHour, totalWeekly float64
var count int
for _, cred := range provider.allCredentials() {
if !cred.isAvailable() {
continue
}
if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential {
continue
}