fix(ccm): mark credential unavailable on refresh failure, handle poll 401

tryRefreshCredentials now returns an error and calls markCredentialsUnavailable
when lock acquisition or the credential file writability check fails.
getAccessToken propagates that error instead of silently returning the expired
token. pollUsage handles a 401 by attempting auth recovery and marking the
credential unavailable if recovery fails. Credential error paths now log at
Error level instead of Debug or Warn. At startup, an expired token is refreshed
eagerly via tryRefreshCredentials.
This commit is contained in:
世界
2026-03-28 16:58:52 +08:00
parent cf11e0e74a
commit e7478ce947
4 changed files with 94 additions and 34 deletions

View File

@@ -158,11 +158,15 @@ func (c *defaultCredential) start() error {
}
err = c.ensureCredentialWatcher()
if err != nil {
c.logger.Debug("start credential watcher for ", c.tag, ": ", err)
c.logger.Error("start credential watcher for ", c.tag, ": ", err)
}
err = c.reloadCredentials(true)
if err != nil {
c.logger.Warn("initial credential load for ", c.tag, ": ", err)
c.logger.Error("initial credential load for ", c.tag, ": ", err)
}
if c.credentials != nil && c.credentials.needsRefresh() &&
slices.Contains(c.credentials.Scopes, "user:inference") {
c.tryRefreshCredentials(false)
}
if c.usageTracker != nil {
err = c.usageTracker.Load()
@@ -240,7 +244,10 @@ func (c *defaultCredential) getAccessToken() (string, error) {
if !currentCredentials.needsRefresh() || !slices.Contains(currentCredentials.Scopes, "user:inference") {
return currentCredentials.AccessToken, nil
}
c.tryRefreshCredentials(false)
refreshErr := c.tryRefreshCredentials(false)
if refreshErr != nil {
return "", refreshErr
}
c.access.RLock()
defer c.access.RUnlock()
if c.credentials != nil && c.credentials.AccessToken != "" {
@@ -354,14 +361,14 @@ func (c *defaultCredential) shouldAttemptRefresh(credentials *oauthCredentials,
return credentials.needsRefresh()
}
func (c *defaultCredential) tryRefreshCredentials(force bool) bool {
func (c *defaultCredential) tryRefreshCredentials(force bool) error {
latestCredentials, err := platformReadCredentials(c.credentialPath)
if err == nil && latestCredentials != nil {
c.absorbCredentials(latestCredentials)
}
currentCredentials := c.currentCredentials()
if !c.shouldAttemptRefresh(currentCredentials, force) {
return false
return nil
}
acquireLock := c.acquireLock
if acquireLock == nil {
@@ -369,8 +376,10 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) bool {
}
release, err := acquireLock(c.configDir)
if err != nil {
c.logger.Debug("acquire credential lock for ", c.tag, ": ", err)
return false
lockErr := E.Cause(err, "acquire credential lock for ", c.tag)
c.logger.Error(lockErr)
c.markCredentialsUnavailable(lockErr)
return lockErr
}
defer release()
@@ -382,30 +391,35 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) bool {
currentCredentials = c.currentCredentials()
}
if !c.shouldAttemptRefresh(currentCredentials, force) {
return false
return nil
}
if err := platformCanWriteCredentials(c.credentialPath); err != nil {
c.logger.Debug("credential file not writable for ", c.tag, ": ", err)
return false
err = platformCanWriteCredentials(c.credentialPath)
if err != nil {
writeErr := E.Cause(err, "credential file not writable for ", c.tag)
c.logger.Error(writeErr)
c.markCredentialsUnavailable(writeErr)
return writeErr
}
baseCredentials := cloneCredentials(currentCredentials)
refreshResult, retryDelay, err := refreshToken(c.serviceContext, c.forwardHTTPClient, currentCredentials)
if err != nil {
if retryDelay != 0 {
c.logger.Debug("refresh token for ", c.tag, ": retry delay=", retryDelay, ", error=", err)
c.logger.Error("refresh token for ", c.tag, ": retry delay=", retryDelay, ", error=", err)
} else {
c.logger.Debug("refresh token for ", c.tag, ": ", err)
c.logger.Error("refresh token for ", c.tag, ": ", err)
}
latestCredentials, readErr := platformReadCredentials(c.credentialPath)
if readErr == nil && latestCredentials != nil {
c.absorbCredentials(latestCredentials)
return latestCredentials.AccessToken != "" && (latestCredentials.AccessToken != baseCredentials.AccessToken || !latestCredentials.needsRefresh())
if latestCredentials.AccessToken != "" && (latestCredentials.AccessToken != baseCredentials.AccessToken || !latestCredentials.needsRefresh()) {
return nil
}
}
return false
return E.Cause(err, "refresh token for ", c.tag)
}
if refreshResult == nil || refreshResult.Credentials == nil {
return false
return E.New("refresh token for ", c.tag, ": empty result")
}
refreshedCredentials := cloneCredentials(refreshResult.Credentials)
@@ -419,7 +433,7 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) bool {
if c.needsProfileHydration() {
profileSnapshot, profileErr := c.fetchProfileSnapshot(c.forwardHTTPClient, refreshedCredentials.AccessToken)
if profileErr != nil {
c.logger.Debug("fetch profile for ", c.tag, ": ", profileErr)
c.logger.Error("fetch profile for ", c.tag, ": ", profileErr)
} else if profileSnapshot != nil {
credentialsChanged := c.applyProfileSnapshot(profileSnapshot)
c.persistOAuthAccount()
@@ -428,7 +442,7 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) bool {
}
}
}
return true
return nil
}
func (c *defaultCredential) recoverAuthFailure(failedAccessToken string) bool {
@@ -439,7 +453,10 @@ func (c *defaultCredential) recoverAuthFailure(failedAccessToken string) bool {
return true
}
}
c.tryRefreshCredentials(true)
err = c.tryRefreshCredentials(true)
if err != nil {
return false
}
currentCredentials := c.currentCredentials()
return currentCredentials != nil && currentCredentials.AccessToken != "" && currentCredentials.AccessToken != failedAccessToken
}
@@ -924,7 +941,16 @@ func (c *defaultCredential) pollUsage() {
return
}
body, _ := io.ReadAll(response.Body)
c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
if response.StatusCode == http.StatusUnauthorized {
c.logger.Error("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
if !c.recoverAuthFailure(accessToken) {
c.markCredentialsUnavailable(E.New("poll usage unauthorized for ", c.tag))
}
return
}
if !c.isPollBackoffAtCap() {
c.logger.Error("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
}
c.incrementPollFailures()
return
}
@@ -941,7 +967,9 @@ func (c *defaultCredential) pollUsage() {
}
err = json.NewDecoder(response.Body).Decode(&usageResponse)
if err != nil {
c.logger.Debug("poll usage for ", c.tag, ": decode: ", err)
if !c.isPollBackoffAtCap() {
c.logger.Error("poll usage for ", c.tag, ": decode: ", err)
}
c.incrementPollFailures()
return
}
@@ -982,7 +1010,7 @@ func (c *defaultCredential) pollUsage() {
if needsProfileFetch {
profileSnapshot, err := c.fetchProfileSnapshot(httpClient, accessToken)
if err != nil {
c.logger.Debug("fetch profile for ", c.tag, ": ", err)
c.logger.Error("fetch profile for ", c.tag, ": ", err)
return
}
if profileSnapshot != nil {

View File

@@ -9,7 +9,7 @@ import (
"time"
)
func TestGetAccessTokenReturnsExistingTokenWhenLockFails(t *testing.T) {
func TestGetAccessTokenMarksUnavailableWhenLockFails(t *testing.T) {
t.Parallel()
directory := t.TempDir()
@@ -32,15 +32,47 @@ func TestGetAccessTokenReturnsExistingTokenWhenLockFails(t *testing.T) {
}
credential.acquireLock = func(string) (func(), error) {
return nil, errors.New("locked")
return nil, errors.New("permission denied")
}
token, err := credential.getAccessToken()
if err != nil {
_, err := credential.getAccessToken()
if err == nil {
t.Fatal("expected error when lock acquisition fails, got nil")
}
if credential.isUsable() {
t.Fatal("credential should be marked unavailable after lock failure")
}
}
func TestGetAccessTokenMarksUnavailableOnUnwritableFile(t *testing.T) {
t.Parallel()
directory := t.TempDir()
credentialPath := filepath.Join(directory, ".credentials.json")
writeTestCredentials(t, credentialPath, &oauthCredentials{
AccessToken: "old-token",
RefreshToken: "refresh-token",
ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(),
Scopes: []string{"user:profile", "user:inference"},
})
credential := newTestDefaultCredential(t, credentialPath, roundTripFunc(func(request *http.Request) (*http.Response, error) {
t.Fatal("refresh should not be attempted when file is not writable")
return nil, nil
}))
if err := credential.reloadCredentials(true); err != nil {
t.Fatal(err)
}
if token != "old-token" {
t.Fatalf("expected old token, got %q", token)
os.Chmod(credentialPath, 0o444)
t.Cleanup(func() { os.Chmod(credentialPath, 0o644) })
_, err := credential.getAccessToken()
if err == nil {
t.Fatal("expected error when credential file is not writable, got nil")
}
if credential.isUsable() {
t.Fatal("credential should be marked unavailable after write permission failure")
}
}

View File

@@ -598,7 +598,7 @@ func (c *externalCredential) pollUsage() {
ctx := c.getReverseContext()
response, err := c.doPollUsageRequest(ctx)
if err != nil {
c.logger.Debug("poll usage for ", c.tag, ": ", err)
c.logger.Error("poll usage for ", c.tag, ": ", err)
c.incrementPollFailures()
return
}
@@ -606,21 +606,21 @@ func (c *externalCredential) pollUsage() {
if response.StatusCode != http.StatusOK {
body, _ := io.ReadAll(response.Body)
c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
c.logger.Error("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
c.incrementPollFailures()
return
}
body, err := io.ReadAll(response.Body)
if err != nil {
c.logger.Debug("poll usage for ", c.tag, ": read body: ", err)
c.logger.Error("poll usage for ", c.tag, ": read body: ", err)
c.incrementPollFailures()
return
}
var rawFields map[string]json.RawMessage
err = json.Unmarshal(body, &rawFields)
if err != nil {
c.logger.Debug("poll usage for ", c.tag, ": decode: ", err)
c.logger.Error("poll usage for ", c.tag, ": decode: ", err)
c.incrementPollFailures()
return
}
@@ -634,7 +634,7 @@ func (c *externalCredential) pollUsage() {
var statusResponse statusPayload
err = json.Unmarshal(body, &statusResponse)
if err != nil {
c.logger.Debug("poll usage for ", c.tag, ": decode: ", err)
c.logger.Error("poll usage for ", c.tag, ": decode: ", err)
c.incrementPollFailures()
return
}

View File

@@ -75,7 +75,7 @@ func (c *defaultCredential) retryCredentialReloadIfNeeded() {
err := c.ensureCredentialWatcher()
if err != nil {
c.logger.Debug("start credential watcher for ", c.tag, ": ", err)
c.logger.Error("start credential watcher for ", c.tag, ": ", err)
}
_ = c.reloadCredentials(false)
}