mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-11 17:47:20 +10:00
fix(ccm): make refresh failure fail fast
This commit is contained in:
@@ -164,10 +164,6 @@ func (c *defaultCredential) start() error {
|
||||
if err != nil {
|
||||
c.logger.Error("initial credential load for ", c.tag, ": ", err)
|
||||
}
|
||||
if c.credentials != nil && c.credentials.needsRefresh() &&
|
||||
slices.Contains(c.credentials.Scopes, "user:inference") {
|
||||
c.tryRefreshCredentials(false)
|
||||
}
|
||||
if c.usageTracker != nil {
|
||||
err = c.usageTracker.Load()
|
||||
if err != nil {
|
||||
@@ -216,6 +212,31 @@ type statusSnapshot struct {
|
||||
weight float64
|
||||
}
|
||||
|
||||
// refreshFailureError wraps an error produced while refreshing credentials
// and tags it as either "hard" or "soft". Callers use isHardRefreshFailure to
// tell the two apart (for example, to decide whether the failure should count
// toward poll backoff).
type refreshFailureError struct {
	err  error // underlying cause; exposed through Unwrap
	hard bool  // classification chosen by the code that created the failure
}

// Error returns the wrapped error's message unchanged.
func (e *refreshFailureError) Error() string {
	return e.err.Error()
}

// Unwrap returns the wrapped error so error-chain inspection can see through
// this type.
func (e *refreshFailureError) Unwrap() error {
	return e.err
}

// newRefreshFailure wraps err in a refreshFailureError carrying the given
// hard flag. A nil err yields a literal nil error (not a typed nil pointer),
// so callers can keep using plain `err != nil` checks.
func newRefreshFailure(err error, hard bool) error {
	if err == nil {
		return nil
	}
	return &refreshFailureError{err: err, hard: hard}
}

// isHardRefreshFailure reports whether err is, or wraps, a
// *refreshFailureError that is flagged hard.
//
// The chain is followed through Unwrap() error, so the classification
// survives when a caller adds context around the failure. The outermost
// *refreshFailureError found decides the result. A nil error, or a chain that
// contains no *refreshFailureError, is not a hard failure.
func isHardRefreshFailure(err error) bool {
	for err != nil {
		if refreshErr, ok := err.(*refreshFailureError); ok {
			// Guard against a typed nil pointer stored in the interface.
			return refreshErr != nil && refreshErr.hard
		}
		wrapper, ok := err.(interface{ Unwrap() error })
		if !ok {
			return false
		}
		err = wrapper.Unwrap()
	}
	return false
}
|
||||
|
||||
func (c *defaultCredential) statusSnapshotLocked() statusSnapshot {
|
||||
if c.state.unavailable {
|
||||
return statusSnapshot{}
|
||||
@@ -339,13 +360,11 @@ func (c *defaultCredential) currentCredentials() *oauthCredentials {
|
||||
return cloneCredentials(c.credentials)
|
||||
}
|
||||
|
||||
func (c *defaultCredential) persistCredentials(credentials *oauthCredentials) {
|
||||
func (c *defaultCredential) persistCredentials(credentials *oauthCredentials) error {
|
||||
if credentials == nil {
|
||||
return
|
||||
}
|
||||
if err := platformWriteCredentials(credentials, c.credentialPath); err != nil {
|
||||
c.logger.Error("persist refreshed token for ", c.tag, ": ", err)
|
||||
return nil
|
||||
}
|
||||
return platformWriteCredentials(credentials, c.credentialPath)
|
||||
}
|
||||
|
||||
func (c *defaultCredential) shouldAttemptRefresh(credentials *oauthCredentials, force bool) bool {
|
||||
@@ -361,6 +380,18 @@ func (c *defaultCredential) shouldAttemptRefresh(credentials *oauthCredentials,
|
||||
return credentials.needsRefresh()
|
||||
}
|
||||
|
||||
func (c *defaultCredential) markRefreshUnavailable(err error) error {
|
||||
return newRefreshFailure(c.markCredentialsUnavailable(err), true)
|
||||
}
|
||||
|
||||
func (c *defaultCredential) refreshCredentialsIfNeeded(force bool) error {
|
||||
currentCredentials := c.currentCredentials()
|
||||
if !c.shouldAttemptRefresh(currentCredentials, force) {
|
||||
return nil
|
||||
}
|
||||
return c.tryRefreshCredentials(force)
|
||||
}
|
||||
|
||||
func (c *defaultCredential) tryRefreshCredentials(force bool) error {
|
||||
latestCredentials, err := platformReadCredentials(c.credentialPath)
|
||||
if err == nil && latestCredentials != nil {
|
||||
@@ -378,8 +409,7 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) error {
|
||||
if err != nil {
|
||||
lockErr := E.Cause(err, "acquire credential lock for ", c.tag)
|
||||
c.logger.Error(lockErr)
|
||||
c.markCredentialsUnavailable(lockErr)
|
||||
return lockErr
|
||||
return c.markRefreshUnavailable(lockErr)
|
||||
}
|
||||
defer release()
|
||||
|
||||
@@ -397,8 +427,7 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) error {
|
||||
if err != nil {
|
||||
writeErr := E.Cause(err, "credential file not writable for ", c.tag)
|
||||
c.logger.Error(writeErr)
|
||||
c.markCredentialsUnavailable(writeErr)
|
||||
return writeErr
|
||||
return c.markRefreshUnavailable(writeErr)
|
||||
}
|
||||
|
||||
baseCredentials := cloneCredentials(currentCredentials)
|
||||
@@ -416,15 +445,20 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) error {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return E.Cause(err, "refresh token for ", c.tag)
|
||||
return newRefreshFailure(E.Cause(err, "refresh token for ", c.tag), false)
|
||||
}
|
||||
if refreshResult == nil || refreshResult.Credentials == nil {
|
||||
return E.New("refresh token for ", c.tag, ": empty result")
|
||||
return newRefreshFailure(E.New("refresh token for ", c.tag, ": empty result"), false)
|
||||
}
|
||||
|
||||
refreshedCredentials := cloneCredentials(refreshResult.Credentials)
|
||||
err = c.persistCredentials(refreshedCredentials)
|
||||
if err != nil {
|
||||
persistErr := E.Cause(err, "persist refreshed token for ", c.tag)
|
||||
c.logger.Error(persistErr)
|
||||
return c.markRefreshUnavailable(persistErr)
|
||||
}
|
||||
c.absorbCredentials(refreshedCredentials)
|
||||
c.persistCredentials(refreshedCredentials)
|
||||
|
||||
if refreshResult.TokenAccount != nil {
|
||||
c.absorbOAuthAccount(refreshResult.TokenAccount)
|
||||
@@ -438,27 +472,30 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) error {
|
||||
credentialsChanged := c.applyProfileSnapshot(profileSnapshot)
|
||||
c.persistOAuthAccount()
|
||||
if credentialsChanged {
|
||||
c.persistCredentials(c.currentCredentials())
|
||||
err = c.persistCredentials(c.currentCredentials())
|
||||
if err != nil {
|
||||
c.logger.Error("persist credential metadata for ", c.tag, ": ", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *defaultCredential) recoverAuthFailure(failedAccessToken string) bool {
|
||||
func (c *defaultCredential) recoverAuthFailure(failedAccessToken string) (bool, error) {
|
||||
latestCredentials, err := platformReadCredentials(c.credentialPath)
|
||||
if err == nil && latestCredentials != nil {
|
||||
c.absorbCredentials(latestCredentials)
|
||||
if latestCredentials.AccessToken != "" && latestCredentials.AccessToken != failedAccessToken {
|
||||
return true
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
err = c.tryRefreshCredentials(true)
|
||||
if err != nil {
|
||||
return false
|
||||
return false, err
|
||||
}
|
||||
currentCredentials := c.currentCredentials()
|
||||
return currentCredentials != nil && currentCredentials.AccessToken != "" && currentCredentials.AccessToken != failedAccessToken
|
||||
return currentCredentials != nil && currentCredentials.AccessToken != "" && currentCredentials.AccessToken != failedAccessToken, nil
|
||||
}
|
||||
|
||||
func (c *defaultCredential) applyProfileSnapshot(snapshot *claudeProfileSnapshot) bool {
|
||||
@@ -895,7 +932,9 @@ func (c *defaultCredential) pollUsage() {
|
||||
if !c.isPollBackoffAtCap() {
|
||||
c.logger.Error("poll usage for ", c.tag, ": get token: ", err)
|
||||
}
|
||||
c.incrementPollFailures()
|
||||
if !isHardRefreshFailure(err) {
|
||||
c.incrementPollFailures()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -905,55 +944,97 @@ func (c *defaultCredential) pollUsage() {
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/usage", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request.Header.Set("User-Agent", ccmUserAgentValue)
|
||||
request.Header.Set("anthropic-beta", anthropicBetaOAuthValue)
|
||||
return request, nil
|
||||
})
|
||||
if err != nil {
|
||||
if !c.isPollBackoffAtCap() {
|
||||
c.logger.Error("poll usage for ", c.tag, ": ", err)
|
||||
}
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
doUsageRequest := func(token string) (*http.Response, error) {
|
||||
return doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/usage", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header.Set("Authorization", "Bearer "+token)
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request.Header.Set("User-Agent", ccmUserAgentValue)
|
||||
request.Header.Set("anthropic-beta", anthropicBetaOAuthValue)
|
||||
return request, nil
|
||||
})
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
var response *http.Response
|
||||
attemptedAuthRecovery := false
|
||||
for {
|
||||
response, err = doUsageRequest(accessToken)
|
||||
if err != nil {
|
||||
if !c.isPollBackoffAtCap() {
|
||||
c.logger.Error("poll usage for ", c.tag, ": ", err)
|
||||
}
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
if response.StatusCode == http.StatusOK {
|
||||
break
|
||||
}
|
||||
if response.StatusCode == http.StatusTooManyRequests {
|
||||
retryDelay := time.Minute
|
||||
if retryAfter := response.Header.Get("Retry-After"); retryAfter != "" {
|
||||
seconds, err := strconv.ParseInt(retryAfter, 10, 64)
|
||||
if err == nil && seconds > 0 {
|
||||
seconds, parseErr := strconv.ParseInt(retryAfter, 10, 64)
|
||||
if parseErr == nil && seconds > 0 {
|
||||
retryDelay = time.Duration(seconds) * time.Second
|
||||
}
|
||||
}
|
||||
response.Body.Close()
|
||||
c.logger.Warn("poll usage for ", c.tag, ": usage API rate limited, retry in ", log.FormatDuration(retryDelay))
|
||||
c.stateAccess.Lock()
|
||||
c.state.usageAPIRetryDelay = retryDelay
|
||||
c.stateAccess.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
if response.StatusCode == http.StatusUnauthorized {
|
||||
c.logger.Error("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
|
||||
if !c.recoverAuthFailure(accessToken) {
|
||||
c.markCredentialsUnavailable(E.New("poll usage unauthorized for ", c.tag))
|
||||
response.Body.Close()
|
||||
recoverableAuthFailure := !attemptedAuthRecovery &&
|
||||
(response.StatusCode == http.StatusUnauthorized ||
|
||||
(response.StatusCode == http.StatusForbidden && bytes.Contains(body, []byte("OAuth token has been revoked"))))
|
||||
if recoverableAuthFailure {
|
||||
if !c.isPollBackoffAtCap() {
|
||||
c.logger.Error("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
|
||||
}
|
||||
return
|
||||
attemptedAuthRecovery = true
|
||||
recovered, recoverErr := c.recoverAuthFailure(accessToken)
|
||||
if recoverErr != nil {
|
||||
if !isHardRefreshFailure(recoverErr) {
|
||||
if !c.isPollBackoffAtCap() {
|
||||
c.logger.Error("poll usage for ", c.tag, ": auth recovery: ", recoverErr)
|
||||
}
|
||||
c.incrementPollFailures()
|
||||
}
|
||||
return
|
||||
}
|
||||
if !recovered {
|
||||
if !c.isPollBackoffAtCap() {
|
||||
c.logger.Error("poll usage for ", c.tag, ": auth recovery did not produce a new token")
|
||||
}
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
accessToken, err = c.getAccessToken()
|
||||
if err != nil {
|
||||
if !c.isPollBackoffAtCap() {
|
||||
c.logger.Error("poll usage for ", c.tag, ": get token after auth recovery: ", err)
|
||||
}
|
||||
if !isHardRefreshFailure(err) {
|
||||
c.incrementPollFailures()
|
||||
}
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if !c.isPollBackoffAtCap() {
|
||||
c.logger.Error("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
|
||||
}
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
var usageResponse struct {
|
||||
FiveHour struct {
|
||||
|
||||
@@ -14,14 +14,15 @@ func TestGetAccessTokenMarksUnavailableWhenLockFails(t *testing.T) {
|
||||
|
||||
directory := t.TempDir()
|
||||
credentialPath := filepath.Join(directory, ".credentials.json")
|
||||
writeTestCredentials(t, credentialPath, &oauthCredentials{
|
||||
credentials := &oauthCredentials{
|
||||
AccessToken: "old-token",
|
||||
RefreshToken: "refresh-token",
|
||||
ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(),
|
||||
ExpiresAt: time.Now().Add(time.Hour).UnixMilli(),
|
||||
Scopes: []string{"user:profile", "user:inference"},
|
||||
SubscriptionType: optionalStringPointer("max"),
|
||||
RateLimitTier: optionalStringPointer("default_claude_max_20x"),
|
||||
})
|
||||
}
|
||||
writeTestCredentials(t, credentialPath, credentials)
|
||||
|
||||
credential := newTestDefaultCredential(t, credentialPath, roundTripFunc(func(request *http.Request) (*http.Response, error) {
|
||||
t.Fatal("refresh should not be attempted when lock acquisition fails")
|
||||
@@ -31,6 +32,11 @@ func TestGetAccessTokenMarksUnavailableWhenLockFails(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expiredCredentials := cloneCredentials(credentials)
|
||||
expiredCredentials.ExpiresAt = time.Now().Add(-time.Minute).UnixMilli()
|
||||
writeTestCredentials(t, credentialPath, expiredCredentials)
|
||||
credential.absorbCredentials(expiredCredentials)
|
||||
|
||||
credential.acquireLock = func(string) (func(), error) {
|
||||
return nil, errors.New("permission denied")
|
||||
}
|
||||
@@ -49,12 +55,13 @@ func TestGetAccessTokenMarksUnavailableOnUnwritableFile(t *testing.T) {
|
||||
|
||||
directory := t.TempDir()
|
||||
credentialPath := filepath.Join(directory, ".credentials.json")
|
||||
writeTestCredentials(t, credentialPath, &oauthCredentials{
|
||||
credentials := &oauthCredentials{
|
||||
AccessToken: "old-token",
|
||||
RefreshToken: "refresh-token",
|
||||
ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(),
|
||||
ExpiresAt: time.Now().Add(time.Hour).UnixMilli(),
|
||||
Scopes: []string{"user:profile", "user:inference"},
|
||||
})
|
||||
}
|
||||
writeTestCredentials(t, credentialPath, credentials)
|
||||
|
||||
credential := newTestDefaultCredential(t, credentialPath, roundTripFunc(func(request *http.Request) (*http.Response, error) {
|
||||
t.Fatal("refresh should not be attempted when file is not writable")
|
||||
@@ -64,6 +71,11 @@ func TestGetAccessTokenMarksUnavailableOnUnwritableFile(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expiredCredentials := cloneCredentials(credentials)
|
||||
expiredCredentials.ExpiresAt = time.Now().Add(-time.Minute).UnixMilli()
|
||||
writeTestCredentials(t, credentialPath, expiredCredentials)
|
||||
credential.absorbCredentials(expiredCredentials)
|
||||
|
||||
os.Chmod(credentialPath, 0o444)
|
||||
t.Cleanup(func() { os.Chmod(credentialPath, 0o644) })
|
||||
|
||||
|
||||
@@ -107,7 +107,7 @@ func (c *defaultCredential) reloadCredentials(force bool) error {
|
||||
}
|
||||
|
||||
c.absorbCredentials(credentials)
|
||||
return nil
|
||||
return c.refreshCredentialsIfNeeded(false)
|
||||
}
|
||||
|
||||
func (c *defaultCredential) markCredentialsUnavailable(err error) error {
|
||||
|
||||
@@ -422,6 +422,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
if shouldRetry {
|
||||
recovered := false
|
||||
var recoverErr error
|
||||
if defaultCred, ok := selectedCredential.(*defaultCredential); ok {
|
||||
failedAccessToken := ""
|
||||
currentCredentials := defaultCred.currentCredentials()
|
||||
@@ -429,7 +430,16 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
failedAccessToken = currentCredentials.AccessToken
|
||||
}
|
||||
s.logger.WarnContext(ctx, "upstream auth failure from ", selectedCredential.tagName(), ", reloading credentials and retrying")
|
||||
recovered = defaultCred.recoverAuthFailure(failedAccessToken)
|
||||
recovered, recoverErr = defaultCred.recoverAuthFailure(failedAccessToken)
|
||||
}
|
||||
if recoverErr != nil {
|
||||
response.Body.Close()
|
||||
if isHardRefreshFailure(recoverErr) || selectedCredential.unavailableError() != nil {
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable during auth recovery")
|
||||
return
|
||||
}
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", E.Cause(recoverErr, "auth recovery").Error())
|
||||
return
|
||||
}
|
||||
if recovered {
|
||||
response.Body.Close()
|
||||
|
||||
Reference in New Issue
Block a user