mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-11 17:47:20 +10:00
ccm,ocm: watch credential_path and allow delayed credentials
This commit is contained in:
@@ -51,6 +51,10 @@ On macOS, credentials are read from the system keychain first, then fall back to
|
||||
|
||||
Refreshed tokens are automatically written back to the same location.
|
||||
|
||||
When `credential_path` points to a file, the service can start before the file exists. The credential becomes available automatically after the file is created or updated, and becomes unavailable immediately if the file is later removed or becomes invalid.
|
||||
|
||||
On macOS without an explicit `credential_path`, keychain changes are not watched. Automatic reload only applies to the credential file path.
|
||||
|
||||
Conflict with `credentials`.
|
||||
|
||||
#### credentials
|
||||
@@ -76,7 +80,7 @@ Each credential has a `type` field (`default`, `balancer`, or `fallback`) and a
|
||||
}
|
||||
```
|
||||
|
||||
A single OAuth credential file. The `type` field can be omitted (defaults to `default`).
|
||||
A single OAuth credential file. The `type` field can be omitted (defaults to `default`). The service can start before the file exists, and reloads file updates automatically.
|
||||
|
||||
- `credential_path`: Path to the credentials file. Same defaults as top-level `credential_path`.
|
||||
- `usages_path`: Optional usage tracking file for this credential.
|
||||
|
||||
@@ -51,6 +51,10 @@ Claude Code OAuth 凭据文件的路径。
|
||||
|
||||
刷新的令牌会自动写回相同位置。
|
||||
|
||||
当 `credential_path` 指向文件时,即使文件尚不存在,服务也可以启动。文件被创建或更新后,凭据会自动变为可用;如果文件之后被删除或变为无效,该凭据会立即变为不可用。
|
||||
|
||||
在 macOS 上如果未显式设置 `credential_path`,不会监听钥匙串变化。自动重载只作用于凭据文件路径。
|
||||
|
||||
与 `credentials` 冲突。
|
||||
|
||||
#### credentials
|
||||
@@ -76,7 +80,7 @@ Claude Code OAuth 凭据文件的路径。
|
||||
}
|
||||
```
|
||||
|
||||
单个 OAuth 凭据文件。`type` 字段可以省略(默认为 `default`)。
|
||||
单个 OAuth 凭据文件。`type` 字段可以省略(默认为 `default`)。即使文件尚不存在,服务也可以启动,并会自动重载文件更新。
|
||||
|
||||
- `credential_path`:凭据文件的路径。默认值与顶层 `credential_path` 相同。
|
||||
- `usages_path`:此凭据的可选使用跟踪文件。
|
||||
|
||||
@@ -49,6 +49,8 @@ If not specified, defaults to:
|
||||
|
||||
Refreshed tokens are automatically written back to the same location.
|
||||
|
||||
When `credential_path` points to a file, the service can start before the file exists. The credential becomes available automatically after the file is created or updated, and becomes unavailable immediately if the file is later removed or becomes invalid.
|
||||
|
||||
Conflict with `credentials`.
|
||||
|
||||
#### credentials
|
||||
@@ -74,7 +76,7 @@ Each credential has a `type` field (`default`, `balancer`, or `fallback`) and a
|
||||
}
|
||||
```
|
||||
|
||||
A single OAuth credential file. The `type` field can be omitted (defaults to `default`).
|
||||
A single OAuth credential file. The `type` field can be omitted (defaults to `default`). The service can start before the file exists, and reloads file updates automatically.
|
||||
|
||||
- `credential_path`: Path to the credentials file. Same defaults as top-level `credential_path`.
|
||||
- `usages_path`: Optional usage tracking file for this credential.
|
||||
|
||||
@@ -49,6 +49,8 @@ OpenAI OAuth 凭据文件的路径。
|
||||
|
||||
刷新的令牌会自动写回相同位置。
|
||||
|
||||
当 `credential_path` 指向文件时,即使文件尚不存在,服务也可以启动。文件被创建或更新后,凭据会自动变为可用;如果文件之后被删除或变为无效,该凭据会立即变为不可用。
|
||||
|
||||
与 `credentials` 冲突。
|
||||
|
||||
#### credentials
|
||||
@@ -74,7 +76,7 @@ OpenAI OAuth 凭据文件的路径。
|
||||
}
|
||||
```
|
||||
|
||||
单个 OAuth 凭据文件。`type` 字段可以省略(默认为 `default`)。
|
||||
单个 OAuth 凭据文件。`type` 字段可以省略(默认为 `default`)。即使文件尚不存在,服务也可以启动,并会自动重载文件更新。
|
||||
|
||||
- `credential_path`:凭据文件的路径。默认值与顶层 `credential_path` 相同。
|
||||
- `usages_path`:此凭据的可选使用跟踪文件。
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -189,3 +190,24 @@ func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oaut
|
||||
|
||||
return &newCredentials, nil
|
||||
}
|
||||
|
||||
func cloneCredentials(credentials *oauthCredentials) *oauthCredentials {
|
||||
if credentials == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *credentials
|
||||
cloned.Scopes = append([]string(nil), credentials.Scopes...)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool {
|
||||
if left == nil || right == nil {
|
||||
return left == right
|
||||
}
|
||||
return left.AccessToken == right.AccessToken &&
|
||||
left.RefreshToken == right.RefreshToken &&
|
||||
left.ExpiresAt == right.ExpiresAt &&
|
||||
slices.Equal(left.Scopes, right.Scopes) &&
|
||||
left.SubscriptionType == right.SubscriptionType &&
|
||||
left.IsMax == right.IsMax
|
||||
}
|
||||
|
||||
@@ -151,6 +151,10 @@ func (c *externalCredential) isExternal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *externalCredential) isAvailable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *externalCredential) isUsable() bool {
|
||||
c.stateMutex.RLock()
|
||||
if c.state.hardRateLimited {
|
||||
@@ -210,6 +214,10 @@ func (c *externalCredential) earliestReset() time.Time {
|
||||
return earliest
|
||||
}
|
||||
|
||||
func (c *externalCredential) unavailableError() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) getAccessToken() (string, error) {
|
||||
return c.token, nil
|
||||
}
|
||||
|
||||
141
service/ccm/credential_file.go
Normal file
141
service/ccm/credential_file.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/fswatch"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
const credentialReloadRetryInterval = 2 * time.Second
|
||||
|
||||
func resolveCredentialFilePath(customPath string) (string, error) {
|
||||
if customPath == "" {
|
||||
var err error
|
||||
customPath, err = getDefaultCredentialsPath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
if filepath.IsAbs(customPath) {
|
||||
return customPath, nil
|
||||
}
|
||||
return filepath.Abs(customPath)
|
||||
}
|
||||
|
||||
func (c *defaultCredential) ensureCredentialWatcher() error {
|
||||
c.watcherAccess.Lock()
|
||||
defer c.watcherAccess.Unlock()
|
||||
|
||||
if c.watcher != nil || c.credentialFilePath == "" {
|
||||
return nil
|
||||
}
|
||||
if !c.watcherRetryAt.IsZero() && time.Now().Before(c.watcherRetryAt) {
|
||||
return nil
|
||||
}
|
||||
|
||||
watcher, err := fswatch.NewWatcher(fswatch.Options{
|
||||
Path: []string{c.credentialFilePath},
|
||||
Logger: c.logger,
|
||||
Callback: func(string) {
|
||||
err := c.reloadCredentials(true)
|
||||
if err != nil {
|
||||
c.logger.Warn("reload credentials for ", c.tag, ": ", err)
|
||||
}
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
c.watcherRetryAt = time.Now().Add(credentialReloadRetryInterval)
|
||||
return err
|
||||
}
|
||||
|
||||
err = watcher.Start()
|
||||
if err != nil {
|
||||
c.watcherRetryAt = time.Now().Add(credentialReloadRetryInterval)
|
||||
return err
|
||||
}
|
||||
|
||||
c.watcher = watcher
|
||||
c.watcherRetryAt = time.Time{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *defaultCredential) retryCredentialReloadIfNeeded() {
|
||||
c.stateMutex.RLock()
|
||||
unavailable := c.state.unavailable
|
||||
lastAttempt := c.state.lastCredentialLoadAttempt
|
||||
c.stateMutex.RUnlock()
|
||||
if !unavailable {
|
||||
return
|
||||
}
|
||||
if !lastAttempt.IsZero() && time.Since(lastAttempt) < credentialReloadRetryInterval {
|
||||
return
|
||||
}
|
||||
|
||||
err := c.ensureCredentialWatcher()
|
||||
if err != nil {
|
||||
c.logger.Debug("start credential watcher for ", c.tag, ": ", err)
|
||||
}
|
||||
_ = c.reloadCredentials(false)
|
||||
}
|
||||
|
||||
func (c *defaultCredential) reloadCredentials(force bool) error {
|
||||
c.reloadAccess.Lock()
|
||||
defer c.reloadAccess.Unlock()
|
||||
|
||||
c.stateMutex.RLock()
|
||||
unavailable := c.state.unavailable
|
||||
lastAttempt := c.state.lastCredentialLoadAttempt
|
||||
c.stateMutex.RUnlock()
|
||||
if !force {
|
||||
if !unavailable {
|
||||
return nil
|
||||
}
|
||||
if !lastAttempt.IsZero() && time.Since(lastAttempt) < credentialReloadRetryInterval {
|
||||
return c.unavailableError()
|
||||
}
|
||||
}
|
||||
|
||||
c.stateMutex.Lock()
|
||||
c.state.lastCredentialLoadAttempt = time.Now()
|
||||
c.stateMutex.Unlock()
|
||||
|
||||
credentials, err := platformReadCredentials(c.credentialPath)
|
||||
if err != nil {
|
||||
return c.markCredentialsUnavailable(E.Cause(err, "read credentials"))
|
||||
}
|
||||
|
||||
c.accessMutex.Lock()
|
||||
c.credentials = credentials
|
||||
c.accessMutex.Unlock()
|
||||
|
||||
c.stateMutex.Lock()
|
||||
c.state.unavailable = false
|
||||
c.state.lastCredentialLoadError = ""
|
||||
c.state.accountType = credentials.SubscriptionType
|
||||
c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *defaultCredential) markCredentialsUnavailable(err error) error {
|
||||
c.accessMutex.Lock()
|
||||
hadCredentials := c.credentials != nil
|
||||
c.credentials = nil
|
||||
c.accessMutex.Unlock()
|
||||
|
||||
c.stateMutex.Lock()
|
||||
c.state.unavailable = true
|
||||
c.state.lastCredentialLoadError = err.Error()
|
||||
c.state.accountType = ""
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
|
||||
if shouldInterrupt && hadCredentials {
|
||||
c.interruptConnections()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/fswatch"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
@@ -29,30 +30,38 @@ import (
|
||||
const defaultPollInterval = 60 * time.Minute
|
||||
|
||||
type credentialState struct {
|
||||
fiveHourUtilization float64
|
||||
fiveHourReset time.Time
|
||||
weeklyUtilization float64
|
||||
weeklyReset time.Time
|
||||
hardRateLimited bool
|
||||
rateLimitResetAt time.Time
|
||||
accountType string
|
||||
lastUpdated time.Time
|
||||
consecutivePollFailures int
|
||||
fiveHourUtilization float64
|
||||
fiveHourReset time.Time
|
||||
weeklyUtilization float64
|
||||
weeklyReset time.Time
|
||||
hardRateLimited bool
|
||||
rateLimitResetAt time.Time
|
||||
accountType string
|
||||
lastUpdated time.Time
|
||||
consecutivePollFailures int
|
||||
unavailable bool
|
||||
lastCredentialLoadAttempt time.Time
|
||||
lastCredentialLoadError string
|
||||
}
|
||||
|
||||
type defaultCredential struct {
|
||||
tag string
|
||||
credentialPath string
|
||||
credentials *oauthCredentials
|
||||
accessMutex sync.RWMutex
|
||||
state credentialState
|
||||
stateMutex sync.RWMutex
|
||||
pollAccess sync.Mutex
|
||||
reserve5h uint8
|
||||
reserveWeekly uint8
|
||||
usageTracker *AggregatedUsage
|
||||
httpClient *http.Client
|
||||
logger log.ContextLogger
|
||||
tag string
|
||||
credentialPath string
|
||||
credentialFilePath string
|
||||
credentials *oauthCredentials
|
||||
accessMutex sync.RWMutex
|
||||
state credentialState
|
||||
stateMutex sync.RWMutex
|
||||
pollAccess sync.Mutex
|
||||
reloadAccess sync.Mutex
|
||||
watcherAccess sync.Mutex
|
||||
reserve5h uint8
|
||||
reserveWeekly uint8
|
||||
usageTracker *AggregatedUsage
|
||||
httpClient *http.Client
|
||||
logger log.ContextLogger
|
||||
watcher *fswatch.Watcher
|
||||
watcherRetryAt time.Time
|
||||
|
||||
// Connection interruption
|
||||
onBecameUnusable func()
|
||||
@@ -83,12 +92,14 @@ func (c *credentialRequestContext) cancelRequest() {
|
||||
|
||||
type credential interface {
|
||||
tagName() string
|
||||
isAvailable() bool
|
||||
isUsable() bool
|
||||
isExternal() bool
|
||||
fiveHourUtilization() float64
|
||||
weeklyUtilization() float64
|
||||
markRateLimited(resetAt time.Time)
|
||||
earliestReset() time.Time
|
||||
unavailableError() error
|
||||
|
||||
getAccessToken() (string, error)
|
||||
buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error)
|
||||
@@ -160,13 +171,18 @@ func newDefaultCredential(ctx context.Context, tag string, options option.CCMDef
|
||||
}
|
||||
|
||||
func (c *defaultCredential) start() error {
|
||||
credentials, err := platformReadCredentials(c.credentialPath)
|
||||
credentialFilePath, err := resolveCredentialFilePath(c.credentialPath)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read credentials for ", c.tag)
|
||||
return E.Cause(err, "resolve credential path for ", c.tag)
|
||||
}
|
||||
c.credentials = credentials
|
||||
if credentials.SubscriptionType != "" {
|
||||
c.state.accountType = credentials.SubscriptionType
|
||||
c.credentialFilePath = credentialFilePath
|
||||
err = c.ensureCredentialWatcher()
|
||||
if err != nil {
|
||||
c.logger.Debug("start credential watcher for ", c.tag, ": ", err)
|
||||
}
|
||||
err = c.reloadCredentials(true)
|
||||
if err != nil {
|
||||
c.logger.Warn("initial credential load for ", c.tag, ": ", err)
|
||||
}
|
||||
if c.usageTracker != nil {
|
||||
err = c.usageTracker.Load()
|
||||
@@ -178,33 +194,68 @@ func (c *defaultCredential) start() error {
|
||||
}
|
||||
|
||||
func (c *defaultCredential) getAccessToken() (string, error) {
|
||||
c.retryCredentialReloadIfNeeded()
|
||||
|
||||
c.accessMutex.RLock()
|
||||
if !c.credentials.needsRefresh() {
|
||||
if c.credentials != nil && !c.credentials.needsRefresh() {
|
||||
token := c.credentials.AccessToken
|
||||
c.accessMutex.RUnlock()
|
||||
return token, nil
|
||||
}
|
||||
c.accessMutex.RUnlock()
|
||||
|
||||
err := c.reloadCredentials(true)
|
||||
if err == nil {
|
||||
c.accessMutex.RLock()
|
||||
if c.credentials != nil && !c.credentials.needsRefresh() {
|
||||
token := c.credentials.AccessToken
|
||||
c.accessMutex.RUnlock()
|
||||
return token, nil
|
||||
}
|
||||
c.accessMutex.RUnlock()
|
||||
}
|
||||
|
||||
c.accessMutex.Lock()
|
||||
defer c.accessMutex.Unlock()
|
||||
|
||||
if c.credentials == nil {
|
||||
return "", c.unavailableError()
|
||||
}
|
||||
if !c.credentials.needsRefresh() {
|
||||
return c.credentials.AccessToken, nil
|
||||
}
|
||||
|
||||
baseCredentials := cloneCredentials(c.credentials)
|
||||
newCredentials, err := refreshToken(c.httpClient, c.credentials)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
c.credentials = newCredentials
|
||||
if newCredentials.SubscriptionType != "" {
|
||||
latestCredentials, latestErr := platformReadCredentials(c.credentialPath)
|
||||
if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) {
|
||||
c.credentials = latestCredentials
|
||||
c.stateMutex.Lock()
|
||||
c.state.accountType = newCredentials.SubscriptionType
|
||||
c.state.unavailable = false
|
||||
c.state.lastCredentialLoadAttempt = time.Now()
|
||||
c.state.lastCredentialLoadError = ""
|
||||
c.state.accountType = latestCredentials.SubscriptionType
|
||||
c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
if !latestCredentials.needsRefresh() {
|
||||
return latestCredentials.AccessToken, nil
|
||||
}
|
||||
return "", E.New("credential ", c.tag, " changed while refreshing")
|
||||
}
|
||||
|
||||
c.credentials = newCredentials
|
||||
c.stateMutex.Lock()
|
||||
c.state.unavailable = false
|
||||
c.state.lastCredentialLoadAttempt = time.Now()
|
||||
c.state.lastCredentialLoadError = ""
|
||||
c.state.accountType = newCredentials.SubscriptionType
|
||||
c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
|
||||
err = platformWriteCredentials(newCredentials, c.credentialPath)
|
||||
if err != nil {
|
||||
c.logger.Warn("persist refreshed token for ", c.tag, ": ", err)
|
||||
@@ -299,7 +350,13 @@ func (c *defaultCredential) markRateLimited(resetAt time.Time) {
|
||||
}
|
||||
|
||||
func (c *defaultCredential) isUsable() bool {
|
||||
c.retryCredentialReloadIfNeeded()
|
||||
|
||||
c.stateMutex.RLock()
|
||||
if c.state.unavailable {
|
||||
c.stateMutex.RUnlock()
|
||||
return false
|
||||
}
|
||||
if c.state.hardRateLimited {
|
||||
if time.Now().Before(c.state.rateLimitResetAt) {
|
||||
c.stateMutex.RUnlock()
|
||||
@@ -332,7 +389,7 @@ func (c *defaultCredential) checkReservesLocked() bool {
|
||||
// checkTransitionLocked detects usable→unusable transition.
|
||||
// Must be called with stateMutex write lock held.
|
||||
func (c *defaultCredential) checkTransitionLocked() bool {
|
||||
unusable := c.state.hardRateLimited || !c.checkReservesLocked()
|
||||
unusable := c.state.unavailable || c.state.hardRateLimited || !c.checkReservesLocked()
|
||||
if unusable && !c.interrupted {
|
||||
c.interrupted = true
|
||||
return true
|
||||
@@ -375,6 +432,26 @@ func (c *defaultCredential) weeklyUtilization() float64 {
|
||||
return c.state.weeklyUtilization
|
||||
}
|
||||
|
||||
func (c *defaultCredential) isAvailable() bool {
|
||||
c.retryCredentialReloadIfNeeded()
|
||||
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
return !c.state.unavailable
|
||||
}
|
||||
|
||||
func (c *defaultCredential) unavailableError() error {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
if !c.state.unavailable {
|
||||
return nil
|
||||
}
|
||||
if c.state.lastCredentialLoadError == "" {
|
||||
return E.New("credential ", c.tag, " is unavailable")
|
||||
}
|
||||
return E.New("credential ", c.tag, " is unavailable: ", c.state.lastCredentialLoadError)
|
||||
}
|
||||
|
||||
func (c *defaultCredential) lastUpdatedTime() time.Time {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
@@ -403,6 +480,9 @@ func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duratio
|
||||
func (c *defaultCredential) earliestReset() time.Time {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
if c.state.unavailable {
|
||||
return time.Time{}
|
||||
}
|
||||
if c.state.hardRateLimited {
|
||||
return c.state.rateLimitResetAt
|
||||
}
|
||||
@@ -430,6 +510,11 @@ func (c *defaultCredential) pollUsage(ctx context.Context) {
|
||||
defer c.pollAccess.Unlock()
|
||||
defer c.markUsagePollAttempted()
|
||||
|
||||
c.retryCredentialReloadIfNeeded()
|
||||
if !c.isAvailable() {
|
||||
return
|
||||
}
|
||||
|
||||
accessToken, err := c.getAccessToken()
|
||||
if err != nil {
|
||||
c.logger.Error("poll usage for ", c.tag, ": get token: ", err)
|
||||
@@ -528,6 +613,12 @@ func (c *defaultCredential) pollUsage(ctx context.Context) {
|
||||
}
|
||||
|
||||
func (c *defaultCredential) close() {
|
||||
if c.watcher != nil {
|
||||
err := c.watcher.Close()
|
||||
if err != nil {
|
||||
c.logger.Error("close credential watcher for ", c.tag, ": ", err)
|
||||
}
|
||||
}
|
||||
if c.usageTracker != nil {
|
||||
c.usageTracker.cancelPendingSave()
|
||||
err := c.usageTracker.Save()
|
||||
@@ -622,6 +713,9 @@ func (p *singleCredentialProvider) selectCredential(_ string, filter func(creden
|
||||
if filter != nil && !filter(p.cred) {
|
||||
return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out")
|
||||
}
|
||||
if !p.cred.isAvailable() {
|
||||
return nil, false, p.cred.unavailableError()
|
||||
}
|
||||
if !p.cred.isUsable() {
|
||||
return nil, false, E.New("credential ", p.cred.tagName(), " is rate-limited")
|
||||
}
|
||||
@@ -866,13 +960,21 @@ func (p *fallbackProvider) allCredentials() []credential {
|
||||
func (p *fallbackProvider) close() {}
|
||||
|
||||
func allCredentialsUnavailableError(credentials []credential) error {
|
||||
var hasUnavailable bool
|
||||
var earliest time.Time
|
||||
for _, cred := range credentials {
|
||||
if cred.unavailableError() != nil {
|
||||
hasUnavailable = true
|
||||
continue
|
||||
}
|
||||
resetAt := cred.earliestReset()
|
||||
if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) {
|
||||
earliest = resetAt
|
||||
}
|
||||
}
|
||||
if hasUnavailable {
|
||||
return E.New("all credentials unavailable")
|
||||
}
|
||||
if earliest.IsZero() {
|
||||
return E.New("all credentials rate-limited")
|
||||
}
|
||||
|
||||
@@ -88,7 +88,11 @@ func unavailableCredentialMessage(provider credentialProvider, fallback string)
|
||||
if provider == nil {
|
||||
return fallback
|
||||
}
|
||||
return allCredentialsUnavailableError(provider.allCredentials()).Error()
|
||||
message := allCredentialsUnavailableError(provider.allCredentials()).Error()
|
||||
if message == "all credentials unavailable" && fallback != "" {
|
||||
return fallback
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
||||
func writeRetryableUsageError(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -734,6 +738,9 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user
|
||||
var totalFiveHour, totalWeekly float64
|
||||
var count int
|
||||
for _, cred := range provider.allCredentials() {
|
||||
if !cred.isAvailable() {
|
||||
continue
|
||||
}
|
||||
// Exclude the user's own external_credential (their contribution to us)
|
||||
if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential {
|
||||
continue
|
||||
|
||||
@@ -175,3 +175,41 @@ func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oaut
|
||||
|
||||
return &newCredentials, nil
|
||||
}
|
||||
|
||||
func cloneCredentials(credentials *oauthCredentials) *oauthCredentials {
|
||||
if credentials == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *credentials
|
||||
if credentials.Tokens != nil {
|
||||
clonedTokens := *credentials.Tokens
|
||||
cloned.Tokens = &clonedTokens
|
||||
}
|
||||
if credentials.LastRefresh != nil {
|
||||
lastRefresh := *credentials.LastRefresh
|
||||
cloned.LastRefresh = &lastRefresh
|
||||
}
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool {
|
||||
if left == nil || right == nil {
|
||||
return left == right
|
||||
}
|
||||
if left.APIKey != right.APIKey {
|
||||
return false
|
||||
}
|
||||
if (left.Tokens == nil) != (right.Tokens == nil) {
|
||||
return false
|
||||
}
|
||||
if left.Tokens != nil && *left.Tokens != *right.Tokens {
|
||||
return false
|
||||
}
|
||||
if (left.LastRefresh == nil) != (right.LastRefresh == nil) {
|
||||
return false
|
||||
}
|
||||
if left.LastRefresh != nil && !left.LastRefresh.Equal(*right.LastRefresh) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -157,6 +157,10 @@ func (c *externalCredential) isExternal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *externalCredential) isAvailable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *externalCredential) isUsable() bool {
|
||||
c.stateMutex.RLock()
|
||||
if c.state.hardRateLimited {
|
||||
@@ -215,6 +219,10 @@ func (c *externalCredential) earliestReset() time.Time {
|
||||
return earliest
|
||||
}
|
||||
|
||||
func (c *externalCredential) unavailableError() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) getAccessToken() (string, error) {
|
||||
return c.token, nil
|
||||
}
|
||||
|
||||
139
service/ocm/credential_file.go
Normal file
139
service/ocm/credential_file.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/fswatch"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
const credentialReloadRetryInterval = 2 * time.Second
|
||||
|
||||
func resolveCredentialFilePath(customPath string) (string, error) {
|
||||
if customPath == "" {
|
||||
var err error
|
||||
customPath, err = getDefaultCredentialsPath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
if filepath.IsAbs(customPath) {
|
||||
return customPath, nil
|
||||
}
|
||||
return filepath.Abs(customPath)
|
||||
}
|
||||
|
||||
func (c *defaultCredential) ensureCredentialWatcher() error {
|
||||
c.watcherAccess.Lock()
|
||||
defer c.watcherAccess.Unlock()
|
||||
|
||||
if c.watcher != nil || c.credentialFilePath == "" {
|
||||
return nil
|
||||
}
|
||||
if !c.watcherRetryAt.IsZero() && time.Now().Before(c.watcherRetryAt) {
|
||||
return nil
|
||||
}
|
||||
|
||||
watcher, err := fswatch.NewWatcher(fswatch.Options{
|
||||
Path: []string{c.credentialFilePath},
|
||||
Logger: c.logger,
|
||||
Callback: func(string) {
|
||||
err := c.reloadCredentials(true)
|
||||
if err != nil {
|
||||
c.logger.Warn("reload credentials for ", c.tag, ": ", err)
|
||||
}
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
c.watcherRetryAt = time.Now().Add(credentialReloadRetryInterval)
|
||||
return err
|
||||
}
|
||||
|
||||
err = watcher.Start()
|
||||
if err != nil {
|
||||
c.watcherRetryAt = time.Now().Add(credentialReloadRetryInterval)
|
||||
return err
|
||||
}
|
||||
|
||||
c.watcher = watcher
|
||||
c.watcherRetryAt = time.Time{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *defaultCredential) retryCredentialReloadIfNeeded() {
|
||||
c.stateMutex.RLock()
|
||||
unavailable := c.state.unavailable
|
||||
lastAttempt := c.state.lastCredentialLoadAttempt
|
||||
c.stateMutex.RUnlock()
|
||||
if !unavailable {
|
||||
return
|
||||
}
|
||||
if !lastAttempt.IsZero() && time.Since(lastAttempt) < credentialReloadRetryInterval {
|
||||
return
|
||||
}
|
||||
|
||||
err := c.ensureCredentialWatcher()
|
||||
if err != nil {
|
||||
c.logger.Debug("start credential watcher for ", c.tag, ": ", err)
|
||||
}
|
||||
_ = c.reloadCredentials(false)
|
||||
}
|
||||
|
||||
func (c *defaultCredential) reloadCredentials(force bool) error {
|
||||
c.reloadAccess.Lock()
|
||||
defer c.reloadAccess.Unlock()
|
||||
|
||||
c.stateMutex.RLock()
|
||||
unavailable := c.state.unavailable
|
||||
lastAttempt := c.state.lastCredentialLoadAttempt
|
||||
c.stateMutex.RUnlock()
|
||||
if !force {
|
||||
if !unavailable {
|
||||
return nil
|
||||
}
|
||||
if !lastAttempt.IsZero() && time.Since(lastAttempt) < credentialReloadRetryInterval {
|
||||
return c.unavailableError()
|
||||
}
|
||||
}
|
||||
|
||||
c.stateMutex.Lock()
|
||||
c.state.lastCredentialLoadAttempt = time.Now()
|
||||
c.stateMutex.Unlock()
|
||||
|
||||
credentials, err := platformReadCredentials(c.credentialPath)
|
||||
if err != nil {
|
||||
return c.markCredentialsUnavailable(E.Cause(err, "read credentials"))
|
||||
}
|
||||
|
||||
c.accessMutex.Lock()
|
||||
c.credentials = credentials
|
||||
c.accessMutex.Unlock()
|
||||
|
||||
c.stateMutex.Lock()
|
||||
c.state.unavailable = false
|
||||
c.state.lastCredentialLoadError = ""
|
||||
c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *defaultCredential) markCredentialsUnavailable(err error) error {
|
||||
c.accessMutex.Lock()
|
||||
hadCredentials := c.credentials != nil
|
||||
c.credentials = nil
|
||||
c.accessMutex.Unlock()
|
||||
|
||||
c.stateMutex.Lock()
|
||||
c.state.unavailable = true
|
||||
c.state.lastCredentialLoadError = err.Error()
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
|
||||
if shouldInterrupt && hadCredentials {
|
||||
c.interruptConnections()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/fswatch"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
@@ -29,31 +30,39 @@ import (
|
||||
const defaultPollInterval = 60 * time.Minute
|
||||
|
||||
type credentialState struct {
|
||||
fiveHourUtilization float64
|
||||
fiveHourReset time.Time
|
||||
weeklyUtilization float64
|
||||
weeklyReset time.Time
|
||||
hardRateLimited bool
|
||||
rateLimitResetAt time.Time
|
||||
accountType string
|
||||
lastUpdated time.Time
|
||||
consecutivePollFailures int
|
||||
fiveHourUtilization float64
|
||||
fiveHourReset time.Time
|
||||
weeklyUtilization float64
|
||||
weeklyReset time.Time
|
||||
hardRateLimited bool
|
||||
rateLimitResetAt time.Time
|
||||
accountType string
|
||||
lastUpdated time.Time
|
||||
consecutivePollFailures int
|
||||
unavailable bool
|
||||
lastCredentialLoadAttempt time.Time
|
||||
lastCredentialLoadError string
|
||||
}
|
||||
|
||||
type defaultCredential struct {
|
||||
tag string
|
||||
credentialPath string
|
||||
credentials *oauthCredentials
|
||||
accessMutex sync.RWMutex
|
||||
state credentialState
|
||||
stateMutex sync.RWMutex
|
||||
pollAccess sync.Mutex
|
||||
reserve5h uint8
|
||||
reserveWeekly uint8
|
||||
usageTracker *AggregatedUsage
|
||||
dialer N.Dialer
|
||||
httpClient *http.Client
|
||||
logger log.ContextLogger
|
||||
tag string
|
||||
credentialPath string
|
||||
credentialFilePath string
|
||||
credentials *oauthCredentials
|
||||
accessMutex sync.RWMutex
|
||||
state credentialState
|
||||
stateMutex sync.RWMutex
|
||||
pollAccess sync.Mutex
|
||||
reloadAccess sync.Mutex
|
||||
watcherAccess sync.Mutex
|
||||
reserve5h uint8
|
||||
reserveWeekly uint8
|
||||
usageTracker *AggregatedUsage
|
||||
dialer N.Dialer
|
||||
httpClient *http.Client
|
||||
logger log.ContextLogger
|
||||
watcher *fswatch.Watcher
|
||||
watcherRetryAt time.Time
|
||||
|
||||
// Connection interruption
|
||||
onBecameUnusable func()
|
||||
@@ -84,12 +93,14 @@ func (c *credentialRequestContext) cancelRequest() {
|
||||
|
||||
type credential interface {
|
||||
tagName() string
|
||||
isAvailable() bool
|
||||
isUsable() bool
|
||||
isExternal() bool
|
||||
fiveHourUtilization() float64
|
||||
weeklyUtilization() float64
|
||||
markRateLimited(resetAt time.Time)
|
||||
earliestReset() time.Time
|
||||
unavailableError() error
|
||||
|
||||
getAccessToken() (string, error)
|
||||
buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error)
|
||||
@@ -169,11 +180,19 @@ func newDefaultCredential(ctx context.Context, tag string, options option.OCMDef
|
||||
}
|
||||
|
||||
func (c *defaultCredential) start() error {
|
||||
credentials, err := platformReadCredentials(c.credentialPath)
|
||||
credentialFilePath, err := resolveCredentialFilePath(c.credentialPath)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read credentials for ", c.tag)
|
||||
return E.Cause(err, "resolve credential path for ", c.tag)
|
||||
}
|
||||
c.credentialFilePath = credentialFilePath
|
||||
err = c.ensureCredentialWatcher()
|
||||
if err != nil {
|
||||
c.logger.Debug("start credential watcher for ", c.tag, ": ", err)
|
||||
}
|
||||
err = c.reloadCredentials(true)
|
||||
if err != nil {
|
||||
c.logger.Warn("initial credential load for ", c.tag, ": ", err)
|
||||
}
|
||||
c.credentials = credentials
|
||||
if c.usageTracker != nil {
|
||||
err = c.usageTracker.Load()
|
||||
if err != nil {
|
||||
@@ -184,27 +203,65 @@ func (c *defaultCredential) start() error {
|
||||
}
|
||||
|
||||
func (c *defaultCredential) getAccessToken() (string, error) {
|
||||
c.retryCredentialReloadIfNeeded()
|
||||
|
||||
c.accessMutex.RLock()
|
||||
if !c.credentials.needsRefresh() {
|
||||
if c.credentials != nil && !c.credentials.needsRefresh() {
|
||||
token := c.credentials.getAccessToken()
|
||||
c.accessMutex.RUnlock()
|
||||
return token, nil
|
||||
}
|
||||
c.accessMutex.RUnlock()
|
||||
|
||||
err := c.reloadCredentials(true)
|
||||
if err == nil {
|
||||
c.accessMutex.RLock()
|
||||
if c.credentials != nil && !c.credentials.needsRefresh() {
|
||||
token := c.credentials.getAccessToken()
|
||||
c.accessMutex.RUnlock()
|
||||
return token, nil
|
||||
}
|
||||
c.accessMutex.RUnlock()
|
||||
}
|
||||
|
||||
c.accessMutex.Lock()
|
||||
defer c.accessMutex.Unlock()
|
||||
|
||||
if c.credentials == nil {
|
||||
return "", c.unavailableError()
|
||||
}
|
||||
if !c.credentials.needsRefresh() {
|
||||
return c.credentials.getAccessToken(), nil
|
||||
}
|
||||
|
||||
baseCredentials := cloneCredentials(c.credentials)
|
||||
newCredentials, err := refreshToken(c.httpClient, c.credentials)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
latestCredentials, latestErr := platformReadCredentials(c.credentialPath)
|
||||
if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) {
|
||||
c.credentials = latestCredentials
|
||||
c.stateMutex.Lock()
|
||||
c.state.unavailable = false
|
||||
c.state.lastCredentialLoadAttempt = time.Now()
|
||||
c.state.lastCredentialLoadError = ""
|
||||
c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
if !latestCredentials.needsRefresh() {
|
||||
return latestCredentials.getAccessToken(), nil
|
||||
}
|
||||
return "", E.New("credential ", c.tag, " changed while refreshing")
|
||||
}
|
||||
|
||||
c.credentials = newCredentials
|
||||
c.stateMutex.Lock()
|
||||
c.state.unavailable = false
|
||||
c.state.lastCredentialLoadAttempt = time.Now()
|
||||
c.state.lastCredentialLoadError = ""
|
||||
c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
|
||||
err = platformWriteCredentials(newCredentials, c.credentialPath)
|
||||
if err != nil {
|
||||
@@ -217,12 +274,18 @@ func (c *defaultCredential) getAccessToken() (string, error) {
|
||||
func (c *defaultCredential) getAccountID() string {
|
||||
c.accessMutex.RLock()
|
||||
defer c.accessMutex.RUnlock()
|
||||
if c.credentials == nil {
|
||||
return ""
|
||||
}
|
||||
return c.credentials.getAccountID()
|
||||
}
|
||||
|
||||
func (c *defaultCredential) isAPIKeyMode() bool {
|
||||
c.accessMutex.RLock()
|
||||
defer c.accessMutex.RUnlock()
|
||||
if c.credentials == nil {
|
||||
return false
|
||||
}
|
||||
return c.credentials.isAPIKeyMode()
|
||||
}
|
||||
|
||||
@@ -296,7 +359,13 @@ func (c *defaultCredential) markRateLimited(resetAt time.Time) {
|
||||
}
|
||||
|
||||
func (c *defaultCredential) isUsable() bool {
|
||||
c.retryCredentialReloadIfNeeded()
|
||||
|
||||
c.stateMutex.RLock()
|
||||
if c.state.unavailable {
|
||||
c.stateMutex.RUnlock()
|
||||
return false
|
||||
}
|
||||
if c.state.hardRateLimited {
|
||||
if time.Now().Before(c.state.rateLimitResetAt) {
|
||||
c.stateMutex.RUnlock()
|
||||
@@ -329,7 +398,7 @@ func (c *defaultCredential) checkReservesLocked() bool {
|
||||
// checkTransitionLocked detects usable→unusable transition.
|
||||
// Must be called with stateMutex write lock held.
|
||||
func (c *defaultCredential) checkTransitionLocked() bool {
|
||||
unusable := c.state.hardRateLimited || !c.checkReservesLocked()
|
||||
unusable := c.state.unavailable || c.state.hardRateLimited || !c.checkReservesLocked()
|
||||
if unusable && !c.interrupted {
|
||||
c.interrupted = true
|
||||
return true
|
||||
@@ -372,6 +441,26 @@ func (c *defaultCredential) weeklyUtilization() float64 {
|
||||
return c.state.weeklyUtilization
|
||||
}
|
||||
|
||||
func (c *defaultCredential) isAvailable() bool {
|
||||
c.retryCredentialReloadIfNeeded()
|
||||
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
return !c.state.unavailable
|
||||
}
|
||||
|
||||
func (c *defaultCredential) unavailableError() error {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
if !c.state.unavailable {
|
||||
return nil
|
||||
}
|
||||
if c.state.lastCredentialLoadError == "" {
|
||||
return E.New("credential ", c.tag, " is unavailable")
|
||||
}
|
||||
return E.New("credential ", c.tag, " is unavailable: ", c.state.lastCredentialLoadError)
|
||||
}
|
||||
|
||||
func (c *defaultCredential) lastUpdatedTime() time.Time {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
@@ -400,6 +489,9 @@ func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duratio
|
||||
func (c *defaultCredential) earliestReset() time.Time {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
if c.state.unavailable {
|
||||
return time.Time{}
|
||||
}
|
||||
if c.state.hardRateLimited {
|
||||
return c.state.rateLimitResetAt
|
||||
}
|
||||
@@ -421,15 +513,20 @@ func isTimeoutError(err error) bool {
|
||||
}
|
||||
|
||||
func (c *defaultCredential) pollUsage(ctx context.Context) {
|
||||
if c.isAPIKeyMode() {
|
||||
return
|
||||
}
|
||||
if !c.pollAccess.TryLock() {
|
||||
return
|
||||
}
|
||||
defer c.pollAccess.Unlock()
|
||||
defer c.markUsagePollAttempted()
|
||||
|
||||
c.retryCredentialReloadIfNeeded()
|
||||
if !c.isAvailable() {
|
||||
return
|
||||
}
|
||||
if c.isAPIKeyMode() {
|
||||
return
|
||||
}
|
||||
|
||||
accessToken, err := c.getAccessToken()
|
||||
if err != nil {
|
||||
c.logger.Error("poll usage for ", c.tag, ": get token: ", err)
|
||||
@@ -546,6 +643,12 @@ func (c *defaultCredential) pollUsage(ctx context.Context) {
|
||||
}
|
||||
|
||||
func (c *defaultCredential) close() {
|
||||
if c.watcher != nil {
|
||||
err := c.watcher.Close()
|
||||
if err != nil {
|
||||
c.logger.Error("close credential watcher for ", c.tag, ": ", err)
|
||||
}
|
||||
}
|
||||
if c.usageTracker != nil {
|
||||
c.usageTracker.cancelPendingSave()
|
||||
err := c.usageTracker.Save()
|
||||
@@ -662,6 +765,9 @@ func (p *singleCredentialProvider) selectCredential(_ string, filter func(creden
|
||||
if filter != nil && !filter(p.cred) {
|
||||
return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out")
|
||||
}
|
||||
if !p.cred.isAvailable() {
|
||||
return nil, false, p.cred.unavailableError()
|
||||
}
|
||||
if !p.cred.isUsable() {
|
||||
return nil, false, E.New("credential ", p.cred.tagName(), " is rate-limited")
|
||||
}
|
||||
@@ -702,6 +808,10 @@ type balancerProvider struct {
|
||||
logger log.ContextLogger
|
||||
}
|
||||
|
||||
func compositeCredentialSelectable(cred credential) bool {
|
||||
return !cred.ocmIsAPIKeyMode()
|
||||
}
|
||||
|
||||
func newBalancerProvider(credentials []credential, strategy string, pollInterval time.Duration, logger log.ContextLogger) *balancerProvider {
|
||||
if pollInterval <= 0 {
|
||||
pollInterval = defaultPollInterval
|
||||
@@ -722,7 +832,7 @@ func (p *balancerProvider) selectCredential(sessionID string, filter func(creden
|
||||
p.sessionMutex.RUnlock()
|
||||
if exists {
|
||||
for _, cred := range p.credentials {
|
||||
if cred.tagName() == entry.tag && (filter == nil || filter(cred)) && cred.isUsable() {
|
||||
if cred.tagName() == entry.tag && compositeCredentialSelectable(cred) && (filter == nil || filter(cred)) && cred.isUsable() {
|
||||
return cred, false, nil
|
||||
}
|
||||
}
|
||||
@@ -781,6 +891,9 @@ func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credentia
|
||||
if filter != nil && !filter(cred) {
|
||||
continue
|
||||
}
|
||||
if !compositeCredentialSelectable(cred) {
|
||||
continue
|
||||
}
|
||||
if !cred.isUsable() {
|
||||
continue
|
||||
}
|
||||
@@ -801,6 +914,9 @@ func (p *balancerProvider) pickRoundRobin(filter func(credential) bool) credenti
|
||||
if filter != nil && !filter(candidate) {
|
||||
continue
|
||||
}
|
||||
if !compositeCredentialSelectable(candidate) {
|
||||
continue
|
||||
}
|
||||
if candidate.isUsable() {
|
||||
return candidate
|
||||
}
|
||||
@@ -814,6 +930,9 @@ func (p *balancerProvider) pickRandom(filter func(credential) bool) credential {
|
||||
if filter != nil && !filter(candidate) {
|
||||
continue
|
||||
}
|
||||
if !compositeCredentialSelectable(candidate) {
|
||||
continue
|
||||
}
|
||||
if candidate.isUsable() {
|
||||
usable = append(usable, candidate)
|
||||
}
|
||||
@@ -869,6 +988,9 @@ func (p *fallbackProvider) selectCredential(_ string, filter func(credential) bo
|
||||
if filter != nil && !filter(cred) {
|
||||
continue
|
||||
}
|
||||
if !compositeCredentialSelectable(cred) {
|
||||
continue
|
||||
}
|
||||
if cred.isUsable() {
|
||||
return cred, false, nil
|
||||
}
|
||||
@@ -882,6 +1004,9 @@ func (p *fallbackProvider) onRateLimited(_ string, cred credential, resetAt time
|
||||
if filter != nil && !filter(candidate) {
|
||||
continue
|
||||
}
|
||||
if !compositeCredentialSelectable(candidate) {
|
||||
continue
|
||||
}
|
||||
if candidate.isUsable() {
|
||||
return candidate
|
||||
}
|
||||
@@ -904,13 +1029,21 @@ func (p *fallbackProvider) allCredentials() []credential {
|
||||
func (p *fallbackProvider) close() {}
|
||||
|
||||
func allRateLimitedError(credentials []credential) error {
|
||||
var hasUnavailable bool
|
||||
var earliest time.Time
|
||||
for _, cred := range credentials {
|
||||
if cred.unavailableError() != nil {
|
||||
hasUnavailable = true
|
||||
continue
|
||||
}
|
||||
resetAt := cred.earliestReset()
|
||||
if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) {
|
||||
earliest = resetAt
|
||||
}
|
||||
}
|
||||
if hasUnavailable {
|
||||
return E.New("all credentials unavailable")
|
||||
}
|
||||
if earliest.IsZero() {
|
||||
return E.New("all credentials rate-limited")
|
||||
}
|
||||
@@ -1090,6 +1223,9 @@ func validateOCMCompositeCredentialModes(
|
||||
}
|
||||
|
||||
for _, subCred := range provider.allCredentials() {
|
||||
if !subCred.isAvailable() {
|
||||
continue
|
||||
}
|
||||
if subCred.ocmIsAPIKeyMode() {
|
||||
return E.New(
|
||||
"credential ", credOpt.Tag,
|
||||
|
||||
@@ -96,7 +96,11 @@ func unavailableCredentialMessage(provider credentialProvider, fallback string)
|
||||
if provider == nil {
|
||||
return fallback
|
||||
}
|
||||
return allRateLimitedError(provider.allCredentials()).Error()
|
||||
message := allRateLimitedError(provider.allCredentials()).Error()
|
||||
if message == "all credentials unavailable" && fallback != "" {
|
||||
return fallback
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
||||
func writeRetryableUsageError(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -806,6 +810,9 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user
|
||||
var totalFiveHour, totalWeekly float64
|
||||
var count int
|
||||
for _, cred := range provider.allCredentials() {
|
||||
if !cred.isAvailable() {
|
||||
continue
|
||||
}
|
||||
if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential {
|
||||
continue
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user