From 471c9c3b470ce492b0496e0e7b6895f8c3f81117 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 28 Mar 2026 18:07:31 +0800 Subject: [PATCH] fix(ccm): make refresh failure fail fast --- service/ccm/credential_default.go | 177 ++++++++++++++++++------- service/ccm/credential_default_test.go | 24 +++- service/ccm/credential_file.go | 2 +- service/ccm/service_handler.go | 12 +- 4 files changed, 159 insertions(+), 56 deletions(-) diff --git a/service/ccm/credential_default.go b/service/ccm/credential_default.go index 4f5ecefd1..86c86879c 100644 --- a/service/ccm/credential_default.go +++ b/service/ccm/credential_default.go @@ -164,10 +164,6 @@ func (c *defaultCredential) start() error { if err != nil { c.logger.Error("initial credential load for ", c.tag, ": ", err) } - if c.credentials != nil && c.credentials.needsRefresh() && - slices.Contains(c.credentials.Scopes, "user:inference") { - c.tryRefreshCredentials(false) - } if c.usageTracker != nil { err = c.usageTracker.Load() if err != nil { @@ -216,6 +212,31 @@ type statusSnapshot struct { weight float64 } +type refreshFailureError struct { + err error + hard bool +} + +func (e *refreshFailureError) Error() string { + return e.err.Error() +} + +func (e *refreshFailureError) Unwrap() error { + return e.err +} + +func newRefreshFailure(err error, hard bool) error { + if err == nil { + return nil + } + return &refreshFailureError{err: err, hard: hard} +} + +func isHardRefreshFailure(err error) bool { + refreshErr, ok := err.(*refreshFailureError) + return ok && refreshErr.hard +} + func (c *defaultCredential) statusSnapshotLocked() statusSnapshot { if c.state.unavailable { return statusSnapshot{} @@ -339,13 +360,11 @@ func (c *defaultCredential) currentCredentials() *oauthCredentials { return cloneCredentials(c.credentials) } -func (c *defaultCredential) persistCredentials(credentials *oauthCredentials) { +func (c *defaultCredential) persistCredentials(credentials *oauthCredentials) error { if credentials == nil { - return - } - if err := platformWriteCredentials(credentials, c.credentialPath); err != nil { - c.logger.Error("persist refreshed token for ", c.tag, ": ", err) + return nil } + return platformWriteCredentials(credentials, c.credentialPath) } func (c *defaultCredential) shouldAttemptRefresh(credentials *oauthCredentials, force bool) bool { @@ -361,6 +380,18 @@ func (c *defaultCredential) shouldAttemptRefresh(credentials *oauthCredentials, return credentials.needsRefresh() } +func (c *defaultCredential) markRefreshUnavailable(err error) error { + return newRefreshFailure(c.markCredentialsUnavailable(err), true) +} + +func (c *defaultCredential) refreshCredentialsIfNeeded(force bool) error { + currentCredentials := c.currentCredentials() + if !c.shouldAttemptRefresh(currentCredentials, force) { + return nil + } + return c.tryRefreshCredentials(force) +} + func (c *defaultCredential) tryRefreshCredentials(force bool) error { latestCredentials, err := platformReadCredentials(c.credentialPath) if err == nil && latestCredentials != nil { @@ -378,8 +409,7 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) error { if err != nil { lockErr := E.Cause(err, "acquire credential lock for ", c.tag) c.logger.Error(lockErr) - c.markCredentialsUnavailable(lockErr) - return lockErr + return c.markRefreshUnavailable(lockErr) } defer release() @@ -397,8 +427,7 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) error { if err != nil { writeErr := E.Cause(err, "credential file not writable for ", c.tag) c.logger.Error(writeErr) - c.markCredentialsUnavailable(writeErr) - return writeErr + return c.markRefreshUnavailable(writeErr) } baseCredentials := cloneCredentials(currentCredentials) @@ -416,15 +445,20 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) error { return nil } } - return E.Cause(err, "refresh token for ", c.tag) + return newRefreshFailure(E.Cause(err, "refresh token for ", c.tag), false) } if refreshResult == nil || refreshResult.Credentials == nil { - return E.New("refresh token for ", c.tag, ": empty result") + return newRefreshFailure(E.New("refresh token for ", c.tag, ": empty result"), false) } refreshedCredentials := cloneCredentials(refreshResult.Credentials) + err = c.persistCredentials(refreshedCredentials) + if err != nil { + persistErr := E.Cause(err, "persist refreshed token for ", c.tag) + c.logger.Error(persistErr) + return c.markRefreshUnavailable(persistErr) + } c.absorbCredentials(refreshedCredentials) - c.persistCredentials(refreshedCredentials) if refreshResult.TokenAccount != nil { c.absorbOAuthAccount(refreshResult.TokenAccount) @@ -438,27 +472,30 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) error { credentialsChanged := c.applyProfileSnapshot(profileSnapshot) c.persistOAuthAccount() if credentialsChanged { - c.persistCredentials(c.currentCredentials()) + err = c.persistCredentials(c.currentCredentials()) + if err != nil { + c.logger.Error("persist credential metadata for ", c.tag, ": ", err) + } } } } return nil } -func (c *defaultCredential) recoverAuthFailure(failedAccessToken string) bool { +func (c *defaultCredential) recoverAuthFailure(failedAccessToken string) (bool, error) { latestCredentials, err := platformReadCredentials(c.credentialPath) if err == nil && latestCredentials != nil { c.absorbCredentials(latestCredentials) if latestCredentials.AccessToken != "" && latestCredentials.AccessToken != failedAccessToken { - return true + return true, nil } } err = c.tryRefreshCredentials(true) if err != nil { - return false + return false, err } currentCredentials := c.currentCredentials() - return currentCredentials != nil && currentCredentials.AccessToken != "" && currentCredentials.AccessToken != failedAccessToken + return currentCredentials != nil && currentCredentials.AccessToken != "" && currentCredentials.AccessToken != failedAccessToken, nil } func (c *defaultCredential) applyProfileSnapshot(snapshot *claudeProfileSnapshot) bool { @@ -895,7 +932,9 @@ func (c *defaultCredential) pollUsage() { if !c.isPollBackoffAtCap() { c.logger.Error("poll usage for ", c.tag, ": get token: ", err) } - c.incrementPollFailures() + if !isHardRefreshFailure(err) { + c.incrementPollFailures() + } return } @@ -905,55 +944,97 @@ func (c *defaultCredential) pollUsage() { Timeout: 5 * time.Second, } - response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { - request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/usage", nil) - if err != nil { - return nil, err - } - request.Header.Set("Authorization", "Bearer "+accessToken) - request.Header.Set("Content-Type", "application/json") - request.Header.Set("User-Agent", ccmUserAgentValue) - request.Header.Set("anthropic-beta", anthropicBetaOAuthValue) - return request, nil - }) - if err != nil { - if !c.isPollBackoffAtCap() { - c.logger.Error("poll usage for ", c.tag, ": ", err) - } - c.incrementPollFailures() - return + doUsageRequest := func(token string) (*http.Response, error) { + return doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/usage", nil) + if err != nil { + return nil, err + } + request.Header.Set("Authorization", "Bearer "+token) + request.Header.Set("Content-Type", "application/json") + request.Header.Set("User-Agent", ccmUserAgentValue) + request.Header.Set("anthropic-beta", anthropicBetaOAuthValue) + return request, nil + }) } - defer response.Body.Close() - if response.StatusCode != http.StatusOK { + var response *http.Response + attemptedAuthRecovery := false + for { + response, err = doUsageRequest(accessToken) + if err != nil { + if !c.isPollBackoffAtCap() { + c.logger.Error("poll usage for ", c.tag, ": ", err) + } + c.incrementPollFailures() + return + } + if response.StatusCode == http.StatusOK { + break + } if response.StatusCode == http.StatusTooManyRequests { retryDelay := time.Minute if retryAfter := response.Header.Get("Retry-After"); retryAfter != "" { - seconds, err := strconv.ParseInt(retryAfter, 10, 64) - if err == nil && seconds > 0 { + seconds, parseErr := strconv.ParseInt(retryAfter, 10, 64) + if parseErr == nil && seconds > 0 { retryDelay = time.Duration(seconds) * time.Second } } + response.Body.Close() c.logger.Warn("poll usage for ", c.tag, ": usage API rate limited, retry in ", log.FormatDuration(retryDelay)) c.stateAccess.Lock() c.state.usageAPIRetryDelay = retryDelay c.stateAccess.Unlock() return } + body, _ := io.ReadAll(response.Body) - if response.StatusCode == http.StatusUnauthorized { - c.logger.Error("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) - if !c.recoverAuthFailure(accessToken) { - c.markCredentialsUnavailable(E.New("poll usage unauthorized for ", c.tag)) + response.Body.Close() + recoverableAuthFailure := !attemptedAuthRecovery && + (response.StatusCode == http.StatusUnauthorized || + (response.StatusCode == http.StatusForbidden && bytes.Contains(body, []byte("OAuth token has been revoked")))) + if recoverableAuthFailure { + if !c.isPollBackoffAtCap() { + c.logger.Error("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) } - return + attemptedAuthRecovery = true + recovered, recoverErr := c.recoverAuthFailure(accessToken) + if recoverErr != nil { + if !isHardRefreshFailure(recoverErr) { + if !c.isPollBackoffAtCap() { + c.logger.Error("poll usage for ", c.tag, ": auth recovery: ", recoverErr) + } + c.incrementPollFailures() + } + return + } + if !recovered { + if !c.isPollBackoffAtCap() { + c.logger.Error("poll usage for ", c.tag, ": auth recovery did not produce a new token") + } + c.incrementPollFailures() + return + } + accessToken, err = c.getAccessToken() + if err != nil { + if !c.isPollBackoffAtCap() { + c.logger.Error("poll usage for ", c.tag, ": get token after auth recovery: ", err) + } + if !isHardRefreshFailure(err) { + c.incrementPollFailures() + } + return + } + continue } + if !c.isPollBackoffAtCap() { c.logger.Error("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) } c.incrementPollFailures() return } + defer response.Body.Close() var usageResponse struct { FiveHour struct { diff --git a/service/ccm/credential_default_test.go b/service/ccm/credential_default_test.go index 8da97dbea..90158afe0 100644 --- a/service/ccm/credential_default_test.go +++ b/service/ccm/credential_default_test.go @@ -14,14 +14,15 @@ func TestGetAccessTokenMarksUnavailableWhenLockFails(t *testing.T) { directory := t.TempDir() credentialPath := filepath.Join(directory, ".credentials.json") - writeTestCredentials(t, credentialPath, &oauthCredentials{ + credentials := &oauthCredentials{ AccessToken: "old-token", RefreshToken: "refresh-token", - ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(), + ExpiresAt: time.Now().Add(time.Hour).UnixMilli(), Scopes: []string{"user:profile", "user:inference"}, SubscriptionType: optionalStringPointer("max"), RateLimitTier: optionalStringPointer("default_claude_max_20x"), - }) + } + writeTestCredentials(t, credentialPath, credentials) credential := newTestDefaultCredential(t, credentialPath, roundTripFunc(func(request *http.Request) (*http.Response, error) { t.Fatal("refresh should not be attempted when lock acquisition fails") @@ -31,6 +32,11 @@ func TestGetAccessTokenMarksUnavailableWhenLockFails(t *testing.T) { t.Fatal(err) } + expiredCredentials := cloneCredentials(credentials) + expiredCredentials.ExpiresAt = time.Now().Add(-time.Minute).UnixMilli() + writeTestCredentials(t, credentialPath, expiredCredentials) + credential.absorbCredentials(expiredCredentials) + credential.acquireLock = func(string) (func(), error) { return nil, errors.New("permission denied") } @@ -49,12 +55,13 @@ func TestGetAccessTokenMarksUnavailableOnUnwritableFile(t *testing.T) { directory := t.TempDir() credentialPath := filepath.Join(directory, ".credentials.json") - writeTestCredentials(t, credentialPath, &oauthCredentials{ + credentials := &oauthCredentials{ AccessToken: "old-token", RefreshToken: "refresh-token", - ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(), + ExpiresAt: time.Now().Add(time.Hour).UnixMilli(), Scopes: []string{"user:profile", "user:inference"}, - }) + } + writeTestCredentials(t, credentialPath, credentials) credential := newTestDefaultCredential(t, credentialPath, roundTripFunc(func(request *http.Request) (*http.Response, error) { t.Fatal("refresh should not be attempted when file is not writable") @@ -64,6 +71,11 @@ func TestGetAccessTokenMarksUnavailableOnUnwritableFile(t *testing.T) { t.Fatal(err) } + expiredCredentials := cloneCredentials(credentials) + expiredCredentials.ExpiresAt = time.Now().Add(-time.Minute).UnixMilli() + writeTestCredentials(t, credentialPath, expiredCredentials) + credential.absorbCredentials(expiredCredentials) + os.Chmod(credentialPath, 0o444) t.Cleanup(func() { os.Chmod(credentialPath, 0o644) }) diff --git a/service/ccm/credential_file.go b/service/ccm/credential_file.go index 7258dd4e0..afff53d15 100644 --- a/service/ccm/credential_file.go +++ b/service/ccm/credential_file.go @@ -107,7 +107,7 @@ func (c *defaultCredential) reloadCredentials(force bool) error { } c.absorbCredentials(credentials) - return nil + return c.refreshCredentialsIfNeeded(false) } func (c *defaultCredential) markCredentialsUnavailable(err error) error { diff --git a/service/ccm/service_handler.go b/service/ccm/service_handler.go index aca9dc647..796d38a06 100644 --- a/service/ccm/service_handler.go +++ b/service/ccm/service_handler.go @@ -422,6 +422,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } if shouldRetry { recovered := false + var recoverErr error if defaultCred, ok := selectedCredential.(*defaultCredential); ok { failedAccessToken := "" currentCredentials := defaultCred.currentCredentials() @@ -429,7 +430,16 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { failedAccessToken = currentCredentials.AccessToken } s.logger.WarnContext(ctx, "upstream auth failure from ", selectedCredential.tagName(), ", reloading credentials and retrying") - recovered = defaultCred.recoverAuthFailure(failedAccessToken) + recovered, recoverErr = defaultCred.recoverAuthFailure(failedAccessToken) + } + if recoverErr != nil { + response.Body.Close() + if isHardRefreshFailure(recoverErr) || selectedCredential.unavailableError() != nil { + writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable during auth recovery") + return + } + writeJSONError(w, r, http.StatusBadGateway, "api_error", E.Cause(recoverErr, "auth recovery").Error()) + return } if recovered { response.Body.Close()