mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-11 17:47:20 +10:00
fix(ccm): mark credential unavailable on refresh failure, handle poll 401
tryRefreshCredentials now returns error and calls markCredentialsUnavailable when lock acquisition or file write permission fails. getAccessToken propagates the error instead of silently returning the expired token. pollUsage handles 401 by attempting auth recovery and marking unavailable on failure. All credential error paths now use Error log level instead of Debug. Startup checks expired tokens eagerly via tryRefreshCredentials.
This commit is contained in:
@@ -158,11 +158,15 @@ func (c *defaultCredential) start() error {
|
||||
}
|
||||
err = c.ensureCredentialWatcher()
|
||||
if err != nil {
|
||||
c.logger.Debug("start credential watcher for ", c.tag, ": ", err)
|
||||
c.logger.Error("start credential watcher for ", c.tag, ": ", err)
|
||||
}
|
||||
err = c.reloadCredentials(true)
|
||||
if err != nil {
|
||||
c.logger.Warn("initial credential load for ", c.tag, ": ", err)
|
||||
c.logger.Error("initial credential load for ", c.tag, ": ", err)
|
||||
}
|
||||
if c.credentials != nil && c.credentials.needsRefresh() &&
|
||||
slices.Contains(c.credentials.Scopes, "user:inference") {
|
||||
c.tryRefreshCredentials(false)
|
||||
}
|
||||
if c.usageTracker != nil {
|
||||
err = c.usageTracker.Load()
|
||||
@@ -240,7 +244,10 @@ func (c *defaultCredential) getAccessToken() (string, error) {
|
||||
if !currentCredentials.needsRefresh() || !slices.Contains(currentCredentials.Scopes, "user:inference") {
|
||||
return currentCredentials.AccessToken, nil
|
||||
}
|
||||
c.tryRefreshCredentials(false)
|
||||
refreshErr := c.tryRefreshCredentials(false)
|
||||
if refreshErr != nil {
|
||||
return "", refreshErr
|
||||
}
|
||||
c.access.RLock()
|
||||
defer c.access.RUnlock()
|
||||
if c.credentials != nil && c.credentials.AccessToken != "" {
|
||||
@@ -354,14 +361,14 @@ func (c *defaultCredential) shouldAttemptRefresh(credentials *oauthCredentials,
|
||||
return credentials.needsRefresh()
|
||||
}
|
||||
|
||||
func (c *defaultCredential) tryRefreshCredentials(force bool) bool {
|
||||
func (c *defaultCredential) tryRefreshCredentials(force bool) error {
|
||||
latestCredentials, err := platformReadCredentials(c.credentialPath)
|
||||
if err == nil && latestCredentials != nil {
|
||||
c.absorbCredentials(latestCredentials)
|
||||
}
|
||||
currentCredentials := c.currentCredentials()
|
||||
if !c.shouldAttemptRefresh(currentCredentials, force) {
|
||||
return false
|
||||
return nil
|
||||
}
|
||||
acquireLock := c.acquireLock
|
||||
if acquireLock == nil {
|
||||
@@ -369,8 +376,10 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) bool {
|
||||
}
|
||||
release, err := acquireLock(c.configDir)
|
||||
if err != nil {
|
||||
c.logger.Debug("acquire credential lock for ", c.tag, ": ", err)
|
||||
return false
|
||||
lockErr := E.Cause(err, "acquire credential lock for ", c.tag)
|
||||
c.logger.Error(lockErr)
|
||||
c.markCredentialsUnavailable(lockErr)
|
||||
return lockErr
|
||||
}
|
||||
defer release()
|
||||
|
||||
@@ -382,30 +391,35 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) bool {
|
||||
currentCredentials = c.currentCredentials()
|
||||
}
|
||||
if !c.shouldAttemptRefresh(currentCredentials, force) {
|
||||
return false
|
||||
return nil
|
||||
}
|
||||
if err := platformCanWriteCredentials(c.credentialPath); err != nil {
|
||||
c.logger.Debug("credential file not writable for ", c.tag, ": ", err)
|
||||
return false
|
||||
err = platformCanWriteCredentials(c.credentialPath)
|
||||
if err != nil {
|
||||
writeErr := E.Cause(err, "credential file not writable for ", c.tag)
|
||||
c.logger.Error(writeErr)
|
||||
c.markCredentialsUnavailable(writeErr)
|
||||
return writeErr
|
||||
}
|
||||
|
||||
baseCredentials := cloneCredentials(currentCredentials)
|
||||
refreshResult, retryDelay, err := refreshToken(c.serviceContext, c.forwardHTTPClient, currentCredentials)
|
||||
if err != nil {
|
||||
if retryDelay != 0 {
|
||||
c.logger.Debug("refresh token for ", c.tag, ": retry delay=", retryDelay, ", error=", err)
|
||||
c.logger.Error("refresh token for ", c.tag, ": retry delay=", retryDelay, ", error=", err)
|
||||
} else {
|
||||
c.logger.Debug("refresh token for ", c.tag, ": ", err)
|
||||
c.logger.Error("refresh token for ", c.tag, ": ", err)
|
||||
}
|
||||
latestCredentials, readErr := platformReadCredentials(c.credentialPath)
|
||||
if readErr == nil && latestCredentials != nil {
|
||||
c.absorbCredentials(latestCredentials)
|
||||
return latestCredentials.AccessToken != "" && (latestCredentials.AccessToken != baseCredentials.AccessToken || !latestCredentials.needsRefresh())
|
||||
if latestCredentials.AccessToken != "" && (latestCredentials.AccessToken != baseCredentials.AccessToken || !latestCredentials.needsRefresh()) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return false
|
||||
return E.Cause(err, "refresh token for ", c.tag)
|
||||
}
|
||||
if refreshResult == nil || refreshResult.Credentials == nil {
|
||||
return false
|
||||
return E.New("refresh token for ", c.tag, ": empty result")
|
||||
}
|
||||
|
||||
refreshedCredentials := cloneCredentials(refreshResult.Credentials)
|
||||
@@ -419,7 +433,7 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) bool {
|
||||
if c.needsProfileHydration() {
|
||||
profileSnapshot, profileErr := c.fetchProfileSnapshot(c.forwardHTTPClient, refreshedCredentials.AccessToken)
|
||||
if profileErr != nil {
|
||||
c.logger.Debug("fetch profile for ", c.tag, ": ", profileErr)
|
||||
c.logger.Error("fetch profile for ", c.tag, ": ", profileErr)
|
||||
} else if profileSnapshot != nil {
|
||||
credentialsChanged := c.applyProfileSnapshot(profileSnapshot)
|
||||
c.persistOAuthAccount()
|
||||
@@ -428,7 +442,7 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) bool {
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *defaultCredential) recoverAuthFailure(failedAccessToken string) bool {
|
||||
@@ -439,7 +453,10 @@ func (c *defaultCredential) recoverAuthFailure(failedAccessToken string) bool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
c.tryRefreshCredentials(true)
|
||||
err = c.tryRefreshCredentials(true)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
currentCredentials := c.currentCredentials()
|
||||
return currentCredentials != nil && currentCredentials.AccessToken != "" && currentCredentials.AccessToken != failedAccessToken
|
||||
}
|
||||
@@ -924,7 +941,16 @@ func (c *defaultCredential) pollUsage() {
|
||||
return
|
||||
}
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
|
||||
if response.StatusCode == http.StatusUnauthorized {
|
||||
c.logger.Error("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
|
||||
if !c.recoverAuthFailure(accessToken) {
|
||||
c.markCredentialsUnavailable(E.New("poll usage unauthorized for ", c.tag))
|
||||
}
|
||||
return
|
||||
}
|
||||
if !c.isPollBackoffAtCap() {
|
||||
c.logger.Error("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
|
||||
}
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
@@ -941,7 +967,9 @@ func (c *defaultCredential) pollUsage() {
|
||||
}
|
||||
err = json.NewDecoder(response.Body).Decode(&usageResponse)
|
||||
if err != nil {
|
||||
c.logger.Debug("poll usage for ", c.tag, ": decode: ", err)
|
||||
if !c.isPollBackoffAtCap() {
|
||||
c.logger.Error("poll usage for ", c.tag, ": decode: ", err)
|
||||
}
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
@@ -982,7 +1010,7 @@ func (c *defaultCredential) pollUsage() {
|
||||
if needsProfileFetch {
|
||||
profileSnapshot, err := c.fetchProfileSnapshot(httpClient, accessToken)
|
||||
if err != nil {
|
||||
c.logger.Debug("fetch profile for ", c.tag, ": ", err)
|
||||
c.logger.Error("fetch profile for ", c.tag, ": ", err)
|
||||
return
|
||||
}
|
||||
if profileSnapshot != nil {
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGetAccessTokenReturnsExistingTokenWhenLockFails(t *testing.T) {
|
||||
func TestGetAccessTokenMarksUnavailableWhenLockFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
directory := t.TempDir()
|
||||
@@ -32,15 +32,47 @@ func TestGetAccessTokenReturnsExistingTokenWhenLockFails(t *testing.T) {
|
||||
}
|
||||
|
||||
credential.acquireLock = func(string) (func(), error) {
|
||||
return nil, errors.New("locked")
|
||||
return nil, errors.New("permission denied")
|
||||
}
|
||||
|
||||
token, err := credential.getAccessToken()
|
||||
if err != nil {
|
||||
_, err := credential.getAccessToken()
|
||||
if err == nil {
|
||||
t.Fatal("expected error when lock acquisition fails, got nil")
|
||||
}
|
||||
if credential.isUsable() {
|
||||
t.Fatal("credential should be marked unavailable after lock failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccessTokenMarksUnavailableOnUnwritableFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
directory := t.TempDir()
|
||||
credentialPath := filepath.Join(directory, ".credentials.json")
|
||||
writeTestCredentials(t, credentialPath, &oauthCredentials{
|
||||
AccessToken: "old-token",
|
||||
RefreshToken: "refresh-token",
|
||||
ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(),
|
||||
Scopes: []string{"user:profile", "user:inference"},
|
||||
})
|
||||
|
||||
credential := newTestDefaultCredential(t, credentialPath, roundTripFunc(func(request *http.Request) (*http.Response, error) {
|
||||
t.Fatal("refresh should not be attempted when file is not writable")
|
||||
return nil, nil
|
||||
}))
|
||||
if err := credential.reloadCredentials(true); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if token != "old-token" {
|
||||
t.Fatalf("expected old token, got %q", token)
|
||||
|
||||
os.Chmod(credentialPath, 0o444)
|
||||
t.Cleanup(func() { os.Chmod(credentialPath, 0o644) })
|
||||
|
||||
_, err := credential.getAccessToken()
|
||||
if err == nil {
|
||||
t.Fatal("expected error when credential file is not writable, got nil")
|
||||
}
|
||||
if credential.isUsable() {
|
||||
t.Fatal("credential should be marked unavailable after write permission failure")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -598,7 +598,7 @@ func (c *externalCredential) pollUsage() {
|
||||
ctx := c.getReverseContext()
|
||||
response, err := c.doPollUsageRequest(ctx)
|
||||
if err != nil {
|
||||
c.logger.Debug("poll usage for ", c.tag, ": ", err)
|
||||
c.logger.Error("poll usage for ", c.tag, ": ", err)
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
@@ -606,21 +606,21 @@ func (c *externalCredential) pollUsage() {
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
|
||||
c.logger.Error("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
c.logger.Debug("poll usage for ", c.tag, ": read body: ", err)
|
||||
c.logger.Error("poll usage for ", c.tag, ": read body: ", err)
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
var rawFields map[string]json.RawMessage
|
||||
err = json.Unmarshal(body, &rawFields)
|
||||
if err != nil {
|
||||
c.logger.Debug("poll usage for ", c.tag, ": decode: ", err)
|
||||
c.logger.Error("poll usage for ", c.tag, ": decode: ", err)
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
@@ -634,7 +634,7 @@ func (c *externalCredential) pollUsage() {
|
||||
var statusResponse statusPayload
|
||||
err = json.Unmarshal(body, &statusResponse)
|
||||
if err != nil {
|
||||
c.logger.Debug("poll usage for ", c.tag, ": decode: ", err)
|
||||
c.logger.Error("poll usage for ", c.tag, ": decode: ", err)
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -75,7 +75,7 @@ func (c *defaultCredential) retryCredentialReloadIfNeeded() {
|
||||
|
||||
err := c.ensureCredentialWatcher()
|
||||
if err != nil {
|
||||
c.logger.Debug("start credential watcher for ", c.tag, ": ", err)
|
||||
c.logger.Error("start credential watcher for ", c.tag, ": ", err)
|
||||
}
|
||||
_ = c.reloadCredentials(false)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user