fix(ccm): make refresh failure fail fast

This commit is contained in:
世界
2026-03-28 18:07:31 +08:00
parent e7478ce947
commit 471c9c3b47
4 changed files with 159 additions and 56 deletions

View File

@@ -164,10 +164,6 @@ func (c *defaultCredential) start() error {
if err != nil {
c.logger.Error("initial credential load for ", c.tag, ": ", err)
}
if c.credentials != nil && c.credentials.needsRefresh() &&
slices.Contains(c.credentials.Scopes, "user:inference") {
c.tryRefreshCredentials(false)
}
if c.usageTracker != nil {
err = c.usageTracker.Load()
if err != nil {
@@ -216,6 +212,31 @@ type statusSnapshot struct {
weight float64
}
type refreshFailureError struct {
err error
hard bool
}
func (e *refreshFailureError) Error() string {
return e.err.Error()
}
func (e *refreshFailureError) Unwrap() error {
return e.err
}
func newRefreshFailure(err error, hard bool) error {
if err == nil {
return nil
}
return &refreshFailureError{err: err, hard: hard}
}
func isHardRefreshFailure(err error) bool {
refreshErr, ok := err.(*refreshFailureError)
return ok && refreshErr.hard
}
func (c *defaultCredential) statusSnapshotLocked() statusSnapshot {
if c.state.unavailable {
return statusSnapshot{}
@@ -339,13 +360,11 @@ func (c *defaultCredential) currentCredentials() *oauthCredentials {
return cloneCredentials(c.credentials)
}
func (c *defaultCredential) persistCredentials(credentials *oauthCredentials) {
func (c *defaultCredential) persistCredentials(credentials *oauthCredentials) error {
if credentials == nil {
return
}
if err := platformWriteCredentials(credentials, c.credentialPath); err != nil {
c.logger.Error("persist refreshed token for ", c.tag, ": ", err)
return nil
}
return platformWriteCredentials(credentials, c.credentialPath)
}
func (c *defaultCredential) shouldAttemptRefresh(credentials *oauthCredentials, force bool) bool {
@@ -361,6 +380,18 @@ func (c *defaultCredential) shouldAttemptRefresh(credentials *oauthCredentials,
return credentials.needsRefresh()
}
func (c *defaultCredential) markRefreshUnavailable(err error) error {
return newRefreshFailure(c.markCredentialsUnavailable(err), true)
}
func (c *defaultCredential) refreshCredentialsIfNeeded(force bool) error {
currentCredentials := c.currentCredentials()
if !c.shouldAttemptRefresh(currentCredentials, force) {
return nil
}
return c.tryRefreshCredentials(force)
}
func (c *defaultCredential) tryRefreshCredentials(force bool) error {
latestCredentials, err := platformReadCredentials(c.credentialPath)
if err == nil && latestCredentials != nil {
@@ -378,8 +409,7 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) error {
if err != nil {
lockErr := E.Cause(err, "acquire credential lock for ", c.tag)
c.logger.Error(lockErr)
c.markCredentialsUnavailable(lockErr)
return lockErr
return c.markRefreshUnavailable(lockErr)
}
defer release()
@@ -397,8 +427,7 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) error {
if err != nil {
writeErr := E.Cause(err, "credential file not writable for ", c.tag)
c.logger.Error(writeErr)
c.markCredentialsUnavailable(writeErr)
return writeErr
return c.markRefreshUnavailable(writeErr)
}
baseCredentials := cloneCredentials(currentCredentials)
@@ -416,15 +445,20 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) error {
return nil
}
}
return E.Cause(err, "refresh token for ", c.tag)
return newRefreshFailure(E.Cause(err, "refresh token for ", c.tag), false)
}
if refreshResult == nil || refreshResult.Credentials == nil {
return E.New("refresh token for ", c.tag, ": empty result")
return newRefreshFailure(E.New("refresh token for ", c.tag, ": empty result"), false)
}
refreshedCredentials := cloneCredentials(refreshResult.Credentials)
err = c.persistCredentials(refreshedCredentials)
if err != nil {
persistErr := E.Cause(err, "persist refreshed token for ", c.tag)
c.logger.Error(persistErr)
return c.markRefreshUnavailable(persistErr)
}
c.absorbCredentials(refreshedCredentials)
c.persistCredentials(refreshedCredentials)
if refreshResult.TokenAccount != nil {
c.absorbOAuthAccount(refreshResult.TokenAccount)
@@ -438,27 +472,30 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) error {
credentialsChanged := c.applyProfileSnapshot(profileSnapshot)
c.persistOAuthAccount()
if credentialsChanged {
c.persistCredentials(c.currentCredentials())
err = c.persistCredentials(c.currentCredentials())
if err != nil {
c.logger.Error("persist credential metadata for ", c.tag, ": ", err)
}
}
}
}
return nil
}
func (c *defaultCredential) recoverAuthFailure(failedAccessToken string) bool {
func (c *defaultCredential) recoverAuthFailure(failedAccessToken string) (bool, error) {
latestCredentials, err := platformReadCredentials(c.credentialPath)
if err == nil && latestCredentials != nil {
c.absorbCredentials(latestCredentials)
if latestCredentials.AccessToken != "" && latestCredentials.AccessToken != failedAccessToken {
return true
return true, nil
}
}
err = c.tryRefreshCredentials(true)
if err != nil {
return false
return false, err
}
currentCredentials := c.currentCredentials()
return currentCredentials != nil && currentCredentials.AccessToken != "" && currentCredentials.AccessToken != failedAccessToken
return currentCredentials != nil && currentCredentials.AccessToken != "" && currentCredentials.AccessToken != failedAccessToken, nil
}
func (c *defaultCredential) applyProfileSnapshot(snapshot *claudeProfileSnapshot) bool {
@@ -895,7 +932,9 @@ func (c *defaultCredential) pollUsage() {
if !c.isPollBackoffAtCap() {
c.logger.Error("poll usage for ", c.tag, ": get token: ", err)
}
c.incrementPollFailures()
if !isHardRefreshFailure(err) {
c.incrementPollFailures()
}
return
}
@@ -905,55 +944,97 @@ func (c *defaultCredential) pollUsage() {
Timeout: 5 * time.Second,
}
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/usage", nil)
if err != nil {
return nil, err
}
request.Header.Set("Authorization", "Bearer "+accessToken)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("User-Agent", ccmUserAgentValue)
request.Header.Set("anthropic-beta", anthropicBetaOAuthValue)
return request, nil
})
if err != nil {
if !c.isPollBackoffAtCap() {
c.logger.Error("poll usage for ", c.tag, ": ", err)
}
c.incrementPollFailures()
return
doUsageRequest := func(token string) (*http.Response, error) {
return doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/usage", nil)
if err != nil {
return nil, err
}
request.Header.Set("Authorization", "Bearer "+token)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("User-Agent", ccmUserAgentValue)
request.Header.Set("anthropic-beta", anthropicBetaOAuthValue)
return request, nil
})
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
var response *http.Response
attemptedAuthRecovery := false
for {
response, err = doUsageRequest(accessToken)
if err != nil {
if !c.isPollBackoffAtCap() {
c.logger.Error("poll usage for ", c.tag, ": ", err)
}
c.incrementPollFailures()
return
}
if response.StatusCode == http.StatusOK {
break
}
if response.StatusCode == http.StatusTooManyRequests {
retryDelay := time.Minute
if retryAfter := response.Header.Get("Retry-After"); retryAfter != "" {
seconds, err := strconv.ParseInt(retryAfter, 10, 64)
if err == nil && seconds > 0 {
seconds, parseErr := strconv.ParseInt(retryAfter, 10, 64)
if parseErr == nil && seconds > 0 {
retryDelay = time.Duration(seconds) * time.Second
}
}
response.Body.Close()
c.logger.Warn("poll usage for ", c.tag, ": usage API rate limited, retry in ", log.FormatDuration(retryDelay))
c.stateAccess.Lock()
c.state.usageAPIRetryDelay = retryDelay
c.stateAccess.Unlock()
return
}
body, _ := io.ReadAll(response.Body)
if response.StatusCode == http.StatusUnauthorized {
c.logger.Error("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
if !c.recoverAuthFailure(accessToken) {
c.markCredentialsUnavailable(E.New("poll usage unauthorized for ", c.tag))
response.Body.Close()
recoverableAuthFailure := !attemptedAuthRecovery &&
(response.StatusCode == http.StatusUnauthorized ||
(response.StatusCode == http.StatusForbidden && bytes.Contains(body, []byte("OAuth token has been revoked"))))
if recoverableAuthFailure {
if !c.isPollBackoffAtCap() {
c.logger.Error("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
}
return
attemptedAuthRecovery = true
recovered, recoverErr := c.recoverAuthFailure(accessToken)
if recoverErr != nil {
if !isHardRefreshFailure(recoverErr) {
if !c.isPollBackoffAtCap() {
c.logger.Error("poll usage for ", c.tag, ": auth recovery: ", recoverErr)
}
c.incrementPollFailures()
}
return
}
if !recovered {
if !c.isPollBackoffAtCap() {
c.logger.Error("poll usage for ", c.tag, ": auth recovery did not produce a new token")
}
c.incrementPollFailures()
return
}
accessToken, err = c.getAccessToken()
if err != nil {
if !c.isPollBackoffAtCap() {
c.logger.Error("poll usage for ", c.tag, ": get token after auth recovery: ", err)
}
if !isHardRefreshFailure(err) {
c.incrementPollFailures()
}
return
}
continue
}
if !c.isPollBackoffAtCap() {
c.logger.Error("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
}
c.incrementPollFailures()
return
}
defer response.Body.Close()
var usageResponse struct {
FiveHour struct {

View File

@@ -14,14 +14,15 @@ func TestGetAccessTokenMarksUnavailableWhenLockFails(t *testing.T) {
directory := t.TempDir()
credentialPath := filepath.Join(directory, ".credentials.json")
writeTestCredentials(t, credentialPath, &oauthCredentials{
credentials := &oauthCredentials{
AccessToken: "old-token",
RefreshToken: "refresh-token",
ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(),
ExpiresAt: time.Now().Add(time.Hour).UnixMilli(),
Scopes: []string{"user:profile", "user:inference"},
SubscriptionType: optionalStringPointer("max"),
RateLimitTier: optionalStringPointer("default_claude_max_20x"),
})
}
writeTestCredentials(t, credentialPath, credentials)
credential := newTestDefaultCredential(t, credentialPath, roundTripFunc(func(request *http.Request) (*http.Response, error) {
t.Fatal("refresh should not be attempted when lock acquisition fails")
@@ -31,6 +32,11 @@ func TestGetAccessTokenMarksUnavailableWhenLockFails(t *testing.T) {
t.Fatal(err)
}
expiredCredentials := cloneCredentials(credentials)
expiredCredentials.ExpiresAt = time.Now().Add(-time.Minute).UnixMilli()
writeTestCredentials(t, credentialPath, expiredCredentials)
credential.absorbCredentials(expiredCredentials)
credential.acquireLock = func(string) (func(), error) {
return nil, errors.New("permission denied")
}
@@ -49,12 +55,13 @@ func TestGetAccessTokenMarksUnavailableOnUnwritableFile(t *testing.T) {
directory := t.TempDir()
credentialPath := filepath.Join(directory, ".credentials.json")
writeTestCredentials(t, credentialPath, &oauthCredentials{
credentials := &oauthCredentials{
AccessToken: "old-token",
RefreshToken: "refresh-token",
ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(),
ExpiresAt: time.Now().Add(time.Hour).UnixMilli(),
Scopes: []string{"user:profile", "user:inference"},
})
}
writeTestCredentials(t, credentialPath, credentials)
credential := newTestDefaultCredential(t, credentialPath, roundTripFunc(func(request *http.Request) (*http.Response, error) {
t.Fatal("refresh should not be attempted when file is not writable")
@@ -64,6 +71,11 @@ func TestGetAccessTokenMarksUnavailableOnUnwritableFile(t *testing.T) {
t.Fatal(err)
}
expiredCredentials := cloneCredentials(credentials)
expiredCredentials.ExpiresAt = time.Now().Add(-time.Minute).UnixMilli()
writeTestCredentials(t, credentialPath, expiredCredentials)
credential.absorbCredentials(expiredCredentials)
os.Chmod(credentialPath, 0o444)
t.Cleanup(func() { os.Chmod(credentialPath, 0o644) })

View File

@@ -107,7 +107,7 @@ func (c *defaultCredential) reloadCredentials(force bool) error {
}
c.absorbCredentials(credentials)
return nil
return c.refreshCredentialsIfNeeded(false)
}
func (c *defaultCredential) markCredentialsUnavailable(err error) error {

View File

@@ -422,6 +422,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
if shouldRetry {
recovered := false
var recoverErr error
if defaultCred, ok := selectedCredential.(*defaultCredential); ok {
failedAccessToken := ""
currentCredentials := defaultCred.currentCredentials()
@@ -429,7 +430,16 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
failedAccessToken = currentCredentials.AccessToken
}
s.logger.WarnContext(ctx, "upstream auth failure from ", selectedCredential.tagName(), ", reloading credentials and retrying")
recovered = defaultCred.recoverAuthFailure(failedAccessToken)
recovered, recoverErr = defaultCred.recoverAuthFailure(failedAccessToken)
}
if recoverErr != nil {
response.Body.Close()
if isHardRefreshFailure(recoverErr) || selectedCredential.unavailableError() != nil {
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable during auth recovery")
return
}
writeJSONError(w, r, http.StatusBadGateway, "api_error", E.Cause(recoverErr, "auth recovery").Error())
return
}
if recovered {
response.Body.Close()