fix(ccm): align default credential with Claude Code

This commit is contained in:
世界
2026-03-24 21:15:46 +08:00
parent 441c98890d
commit 92c8f4c5c8
13 changed files with 1566 additions and 316 deletions

View File

@@ -62,6 +62,7 @@ type credentialState struct {
accountUUID string
accountType string
rateLimitTier string
oauthAccount *claudeOAuthAccount
remotePlanWeight float64
lastUpdated time.Time
consecutivePollFailures int

View File

@@ -19,7 +19,14 @@ type claudeCodeConfig struct {
}
type claudeOAuthAccount struct {
AccountUUID string `json:"accountUuid"`
AccountUUID string `json:"accountUuid,omitempty"`
EmailAddress string `json:"emailAddress,omitempty"`
OrganizationUUID string `json:"organizationUuid,omitempty"`
DisplayName *string `json:"displayName,omitempty"`
HasExtraUsageEnabled *bool `json:"hasExtraUsageEnabled,omitempty"`
BillingType *string `json:"billingType,omitempty"`
AccountCreatedAt *string `json:"accountCreatedAt,omitempty"`
SubscriptionCreatedAt *string `json:"subscriptionCreatedAt,omitempty"`
}
// resolveClaudeConfigFile finds the Claude Code config file within the given directory.
@@ -33,8 +40,8 @@ type claudeOAuthAccount struct {
func resolveClaudeConfigFile(claudeDirectory string) string {
candidates := []string{
filepath.Join(claudeDirectory, ".config.json"),
filepath.Join(claudeDirectory, ".claude.json"),
filepath.Join(filepath.Dir(claudeDirectory), ".claude.json"),
filepath.Join(claudeDirectory, claudeCodeLegacyConfigFileName()),
filepath.Join(filepath.Dir(claudeDirectory), claudeCodeLegacyConfigFileName()),
}
for _, candidate := range candidates {
_, err := os.Stat(candidate)
@@ -57,3 +64,84 @@ func readClaudeCodeConfig(path string) (*claudeCodeConfig, error) {
}
return &config, nil
}
func resolveClaudeConfigWritePath(claudeDirectory string) string {
if claudeDirectory == "" {
return ""
}
existingPath := resolveClaudeConfigFile(claudeDirectory)
if existingPath != "" {
return existingPath
}
if os.Getenv("CLAUDE_CONFIG_DIR") != "" {
return filepath.Join(claudeDirectory, claudeCodeLegacyConfigFileName())
}
defaultClaudeDirectory := filepath.Join(filepath.Dir(claudeDirectory), ".claude")
if claudeDirectory != defaultClaudeDirectory {
return filepath.Join(claudeDirectory, claudeCodeLegacyConfigFileName())
}
return filepath.Join(filepath.Dir(claudeDirectory), claudeCodeLegacyConfigFileName())
}
func writeClaudeCodeOAuthAccount(path string, account *claudeOAuthAccount) error {
if path == "" || account == nil {
return nil
}
storage := jsonFileStorage{path: path}
return writeStorageValue(storage, "oauthAccount", account)
}
func claudeCodeLegacyConfigFileName() string {
if os.Getenv("CLAUDE_CODE_CUSTOM_OAUTH_URL") != "" {
return ".claude-custom-oauth.json"
}
return ".claude.json"
}
func cloneClaudeOAuthAccount(account *claudeOAuthAccount) *claudeOAuthAccount {
if account == nil {
return nil
}
cloned := *account
cloned.DisplayName = cloneStringPointer(account.DisplayName)
cloned.HasExtraUsageEnabled = cloneBoolPointer(account.HasExtraUsageEnabled)
cloned.BillingType = cloneStringPointer(account.BillingType)
cloned.AccountCreatedAt = cloneStringPointer(account.AccountCreatedAt)
cloned.SubscriptionCreatedAt = cloneStringPointer(account.SubscriptionCreatedAt)
return &cloned
}
func mergeClaudeOAuthAccount(base *claudeOAuthAccount, update *claudeOAuthAccount) *claudeOAuthAccount {
if update == nil {
return cloneClaudeOAuthAccount(base)
}
if base == nil {
return cloneClaudeOAuthAccount(update)
}
merged := cloneClaudeOAuthAccount(base)
if update.AccountUUID != "" {
merged.AccountUUID = update.AccountUUID
}
if update.EmailAddress != "" {
merged.EmailAddress = update.EmailAddress
}
if update.OrganizationUUID != "" {
merged.OrganizationUUID = update.OrganizationUUID
}
if update.DisplayName != nil {
merged.DisplayName = cloneStringPointer(update.DisplayName)
}
if update.HasExtraUsageEnabled != nil {
merged.HasExtraUsageEnabled = cloneBoolPointer(update.HasExtraUsageEnabled)
}
if update.BillingType != nil {
merged.BillingType = cloneStringPointer(update.BillingType)
}
if update.AccountCreatedAt != nil {
merged.AccountCreatedAt = cloneStringPointer(update.AccountCreatedAt)
}
if update.SubscriptionCreatedAt != nil {
merged.SubscriptionCreatedAt = cloneStringPointer(update.SubscriptionCreatedAt)
}
return merged
}

View File

@@ -14,6 +14,11 @@ import (
"github.com/keybase/go-keychain"
)
type keychainStorage struct {
service string
account string
}
func getKeychainServiceName() string {
configDirectory := os.Getenv("CLAUDE_CONFIG_DIR")
if configDirectory == "" {
@@ -76,72 +81,90 @@ func platformCanWriteCredentials(customPath string) error {
return checkCredentialFileWritable(customPath)
}
// platformWriteCredentials performs a read-modify-write on the keychain entry,
// preserving any fields or top-level keys not managed by CCM.
//
// ref (@anthropic-ai/claude-code @2.1.81): cli.js BP6 (line 179444-179454) — read-modify-write
func platformWriteCredentials(credentials *oauthCredentials, customPath string) error {
if customPath != "" {
return writeCredentialsToFile(credentials, customPath)
}
userInfo, err := getRealUser()
if err == nil {
serviceName := getKeychainServiceName()
existing := make(map[string]json.RawMessage)
query := keychain.NewItem()
query.SetSecClass(keychain.SecClassGenericPassword)
query.SetService(serviceName)
query.SetAccount(userInfo.Username)
query.SetMatchLimit(keychain.MatchLimitOne)
query.SetReturnData(true)
results, queryErr := keychain.QueryItem(query)
if queryErr == nil && len(results) == 1 {
_ = json.Unmarshal(results[0].Data, &existing)
}
credentialData, err := json.Marshal(credentials)
if err != nil {
return E.Cause(err, "marshal credentials")
}
existing["claudeAiOauth"] = credentialData
data, err := json.Marshal(existing)
if err != nil {
return E.Cause(err, "marshal credential container")
}
item := keychain.NewItem()
item.SetSecClass(keychain.SecClassGenericPassword)
item.SetService(serviceName)
item.SetAccount(userInfo.Username)
item.SetData(data)
item.SetAccessible(keychain.AccessibleWhenUnlocked)
err = keychain.AddItem(item)
if err == nil {
return nil
}
if err == keychain.ErrorDuplicateItem {
updateQuery := keychain.NewItem()
updateQuery.SetSecClass(keychain.SecClassGenericPassword)
updateQuery.SetService(serviceName)
updateQuery.SetAccount(userInfo.Username)
updateItem := keychain.NewItem()
updateItem.SetData(data)
updateErr := keychain.UpdateItem(updateQuery, updateItem)
if updateErr == nil {
return nil
}
}
}
defaultPath, err := getDefaultCredentialsPath()
if err != nil {
return err
}
return writeCredentialsToFile(credentials, defaultPath)
fileStorage := jsonFileStorage{path: defaultPath}
userInfo, err := getRealUser()
if err != nil {
return writeCredentialsToFile(credentials, defaultPath)
}
return persistStorageValue(keychainStorage{
service: getKeychainServiceName(),
account: userInfo.Username,
}, fileStorage, "claudeAiOauth", credentials)
}
func (s keychainStorage) readContainer() (map[string]json.RawMessage, bool, error) {
query := keychain.NewItem()
query.SetSecClass(keychain.SecClassGenericPassword)
query.SetService(s.service)
query.SetAccount(s.account)
query.SetMatchLimit(keychain.MatchLimitOne)
query.SetReturnData(true)
results, err := keychain.QueryItem(query)
if err != nil {
if err == keychain.ErrorItemNotFound {
return make(map[string]json.RawMessage), false, nil
}
return nil, false, E.Cause(err, "query keychain")
}
if len(results) != 1 {
return make(map[string]json.RawMessage), false, nil
}
container := make(map[string]json.RawMessage)
if len(results[0].Data) == 0 {
return container, true, nil
}
if err := json.Unmarshal(results[0].Data, &container); err != nil {
return nil, true, err
}
return container, true, nil
}
func (s keychainStorage) writeContainer(container map[string]json.RawMessage) error {
data, err := json.Marshal(container)
if err != nil {
return err
}
item := keychain.NewItem()
item.SetSecClass(keychain.SecClassGenericPassword)
item.SetService(s.service)
item.SetAccount(s.account)
item.SetData(data)
item.SetAccessible(keychain.AccessibleWhenUnlocked)
err = keychain.AddItem(item)
if err == nil {
return nil
}
if err != keychain.ErrorDuplicateItem {
return err
}
updateQuery := keychain.NewItem()
updateQuery.SetSecClass(keychain.SecClassGenericPassword)
updateQuery.SetService(s.service)
updateQuery.SetAccount(s.account)
updateItem := keychain.NewItem()
updateItem.SetData(data)
return keychain.UpdateItem(updateQuery, updateItem)
}
func (s keychainStorage) delete() error {
err := keychain.DeleteGenericPasswordItem(s.service, s.account)
if err != nil && err != keychain.ErrorItemNotFound {
return err
}
return nil
}

View File

@@ -26,6 +26,15 @@ import (
"github.com/sagernet/sing/common/observable"
)
var acquireCredentialLockFunc = acquireCredentialLock
type claudeProfileSnapshot struct {
OAuthAccount *claudeOAuthAccount
AccountType string
RateLimitTier string
SubscriptionType *string
}
type defaultCredential struct {
tag string
serviceContext context.Context
@@ -33,8 +42,9 @@ type defaultCredential struct {
claudeDirectory string
credentialFilePath string
configDir string
claudeConfigPath string
syncClaudeConfig bool
deviceID string
configLoaded bool
credentials *oauthCredentials
access sync.RWMutex
state credentialState
@@ -52,11 +62,6 @@ type defaultCredential struct {
statusSubscriber *observable.Subscriber[struct{}]
// Refresh rate-limit cooldown (protected by access mutex)
refreshRetryAt time.Time
refreshRetryError error
refreshBlocked bool
// Connection interruption
interrupted bool
requestContext context.Context
@@ -113,6 +118,7 @@ func newDefaultCredential(ctx context.Context, tag string, options option.CCMDef
serviceContext: ctx,
credentialPath: options.CredentialPath,
claudeDirectory: options.ClaudeDirectory,
syncClaudeConfig: options.ClaudeDirectory != "" || options.CredentialPath == "",
cap5h: cap5h,
capWeekly: capWeekly,
forwardHTTPClient: httpClient,
@@ -133,7 +139,6 @@ func newDefaultCredential(ctx context.Context, tag string, options option.CCMDef
func (c *defaultCredential) start() error {
if c.claudeDirectory != "" {
c.loadClaudeCodeConfig()
if c.credentialPath == "" {
c.credentialPath = filepath.Join(c.claudeDirectory, ".credentials.json")
}
@@ -144,6 +149,13 @@ func (c *defaultCredential) start() error {
}
c.credentialFilePath = credentialFilePath
c.configDir = resolveConfigDir(c.credentialPath, credentialFilePath)
if c.syncClaudeConfig {
if c.claudeDirectory == "" {
c.claudeDirectory = c.configDir
}
c.claudeConfigPath = resolveClaudeConfigWritePath(c.claudeDirectory)
c.loadClaudeCodeConfig()
}
err = c.ensureCredentialWatcher()
if err != nil {
c.logger.Debug("start credential watcher for ", c.tag, ": ", err)
@@ -173,6 +185,7 @@ func (c *defaultCredential) loadClaudeCodeConfig() {
return
}
c.stateAccess.Lock()
c.state.oauthAccount = cloneClaudeOAuthAccount(config.OAuthAccount)
if config.OAuthAccount != nil && config.OAuthAccount.AccountUUID != "" {
c.state.accountUUID = config.OAuthAccount.AccountUUID
}
@@ -180,7 +193,7 @@ func (c *defaultCredential) loadClaudeCodeConfig() {
if config.UserID != "" {
c.deviceID = config.UserID
}
c.configLoaded = true
c.claudeConfigPath = configFilePath
c.logger.Debug("loaded claude code config for ", c.tag, ": account=", c.state.accountUUID, ", device=", c.deviceID)
}
@@ -209,147 +222,358 @@ func (c *defaultCredential) statusSnapshotLocked() statusSnapshot {
func (c *defaultCredential) getAccessToken() (string, error) {
c.retryCredentialReloadIfNeeded()
// Fast path: cached token is still valid
c.access.RLock()
if c.credentials != nil && !c.credentials.needsRefresh() {
token := c.credentials.AccessToken
c.access.RUnlock()
return token, nil
}
currentCredentials := cloneCredentials(c.credentials)
c.access.RUnlock()
// Reload from disk — Claude Code or another process may have refreshed
err := c.reloadCredentials(true)
if err == nil {
c.access.RLock()
if c.credentials != nil && !c.credentials.needsRefresh() {
token := c.credentials.AccessToken
c.access.RUnlock()
return token, nil
if currentCredentials == nil {
err := c.reloadCredentials(true)
if err != nil {
return "", err
}
c.access.RLock()
currentCredentials = cloneCredentials(c.credentials)
c.access.RUnlock()
}
// ref (@anthropic-ai/claude-code @2.1.81): cli.js _P1 line 179526
// Claude Code skips refresh for tokens without user:inference scope.
// Return existing token (may be expired); 401 recovery is the safety net.
c.access.RLock()
if c.credentials != nil && !slices.Contains(c.credentials.Scopes, "user:inference") {
token := c.credentials.AccessToken
c.access.RUnlock()
return token, nil
}
c.access.RUnlock()
// Acquire cross-process lock before refresh (outside Go mutex to avoid holding mutex during sleep)
// ref: cli.js _P1 (line 179534-179536) — proper-lockfile lock on config dir
release, lockErr := acquireCredentialLock(c.configDir)
if lockErr != nil {
c.logger.Debug("acquire credential lock for ", c.tag, ": ", lockErr)
release = func() {}
}
defer release()
// ref: cli.js _P1 (line 179559-179562) — re-read after lock, skip if race resolved
_ = c.reloadCredentials(true)
c.access.RLock()
noRefreshToken := c.credentials == nil || c.credentials.RefreshToken == ""
raceResolved := !noRefreshToken && !c.credentials.needsRefresh()
var racedToken string
if (noRefreshToken || raceResolved) && c.credentials != nil {
racedToken = c.credentials.AccessToken
}
c.access.RUnlock()
if noRefreshToken || raceResolved {
return racedToken, nil
}
// Slow path: acquire Go mutex and refresh
c.access.Lock()
defer c.access.Unlock()
if c.credentials == nil {
if currentCredentials == nil {
return "", c.unavailableError()
}
if !c.credentials.needsRefresh() {
if !currentCredentials.needsRefresh() || !slices.Contains(currentCredentials.Scopes, "user:inference") {
return currentCredentials.AccessToken, nil
}
c.tryRefreshCredentials(false)
c.access.RLock()
defer c.access.RUnlock()
if c.credentials != nil && c.credentials.AccessToken != "" {
return c.credentials.AccessToken, nil
}
return "", c.unavailableError()
}
if c.refreshBlocked {
return "", c.refreshRetryError
}
if !c.refreshRetryAt.IsZero() && time.Now().Before(c.refreshRetryAt) {
return "", c.refreshRetryError
}
func (c *defaultCredential) shouldUseClaudeConfig() bool {
return c.syncClaudeConfig && c.claudeConfigPath != ""
}
err = platformCanWriteCredentials(c.credentialPath)
if err != nil {
return "", E.Cause(err, "credential file not writable, refusing refresh to avoid invalidation")
}
func (c *defaultCredential) absorbCredentials(credentials *oauthCredentials) {
c.access.Lock()
c.credentials = cloneCredentials(credentials)
c.access.Unlock()
baseCredentials := cloneCredentials(c.credentials)
newCredentials, retryDelay, err := refreshToken(c.serviceContext, c.forwardHTTPClient, c.credentials)
if err != nil {
if retryDelay < 0 {
c.refreshBlocked = true
c.refreshRetryError = err
} else if retryDelay > 0 {
c.refreshRetryAt = time.Now().Add(retryDelay)
c.refreshRetryError = err
}
// ref: cli.js _P1 (line 179568-179573) — post-failure recovery:
// re-read from disk; if another process refreshed successfully, use that.
// Cannot call reloadCredentials here (deadlock: already holding c.access).
latestCredentials, readErr := platformReadCredentials(c.credentialPath)
if readErr == nil && latestCredentials != nil && !latestCredentials.needsRefresh() {
c.credentials = latestCredentials
return latestCredentials.AccessToken, nil
}
return "", err
}
c.refreshRetryAt = time.Time{}
c.refreshRetryError = nil
c.refreshBlocked = false
latestCredentials, latestErr := platformReadCredentials(c.credentialPath)
if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) {
c.credentials = latestCredentials
c.stateAccess.Lock()
before := c.statusSnapshotLocked()
c.state.unavailable = false
c.state.lastCredentialLoadAttempt = time.Now()
c.state.lastCredentialLoadError = ""
c.checkTransitionLocked()
shouldEmit := before != c.statusSnapshotLocked()
c.stateAccess.Unlock()
if shouldEmit {
c.emitStatusUpdate()
}
if !latestCredentials.needsRefresh() {
return latestCredentials.AccessToken, nil
}
return "", E.New("credential ", c.tag, " changed while refreshing")
}
c.credentials = newCredentials
c.stateAccess.Lock()
before := c.statusSnapshotLocked()
c.state.unavailable = false
c.state.lastCredentialLoadAttempt = time.Now()
c.state.lastCredentialLoadError = ""
c.applyCredentialMetadataLocked(credentials)
c.checkTransitionLocked()
shouldEmit := before != c.statusSnapshotLocked()
c.stateAccess.Unlock()
if shouldEmit {
c.emitStatusUpdate()
}
}
err = platformWriteCredentials(newCredentials, c.credentialPath)
if err != nil {
func (c *defaultCredential) applyCredentialMetadataLocked(credentials *oauthCredentials) {
if credentials == nil {
return
}
if credentials.SubscriptionType != nil && *credentials.SubscriptionType != "" {
c.state.accountType = *credentials.SubscriptionType
}
if credentials.RateLimitTier != nil && *credentials.RateLimitTier != "" {
c.state.rateLimitTier = *credentials.RateLimitTier
}
}
func (c *defaultCredential) absorbOAuthAccount(account *claudeOAuthAccount) {
c.stateAccess.Lock()
c.state.oauthAccount = mergeClaudeOAuthAccount(c.state.oauthAccount, account)
if c.state.oauthAccount != nil && c.state.oauthAccount.AccountUUID != "" {
c.state.accountUUID = c.state.oauthAccount.AccountUUID
}
c.stateAccess.Unlock()
}
func (c *defaultCredential) persistOAuthAccount() {
if !c.shouldUseClaudeConfig() {
return
}
c.stateAccess.RLock()
account := cloneClaudeOAuthAccount(c.state.oauthAccount)
c.stateAccess.RUnlock()
if account == nil {
return
}
if err := writeClaudeCodeOAuthAccount(c.claudeConfigPath, account); err != nil {
c.logger.Debug("write claude code config for ", c.tag, ": ", err)
}
}
func (c *defaultCredential) needsProfileHydration() bool {
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.needsProfileHydrationLocked()
}
func (c *defaultCredential) needsProfileHydrationLocked() bool {
if c.state.accountUUID == "" || c.state.accountType == "" || c.state.rateLimitTier == "" {
return true
}
if c.state.oauthAccount == nil {
return true
}
return c.state.oauthAccount.BillingType == nil ||
c.state.oauthAccount.AccountCreatedAt == nil ||
c.state.oauthAccount.SubscriptionCreatedAt == nil
}
func (c *defaultCredential) currentCredentials() *oauthCredentials {
c.access.RLock()
defer c.access.RUnlock()
return cloneCredentials(c.credentials)
}
func (c *defaultCredential) persistCredentials(credentials *oauthCredentials) {
if credentials == nil {
return
}
if err := platformWriteCredentials(credentials, c.credentialPath); err != nil {
c.logger.Error("persist refreshed token for ", c.tag, ": ", err)
}
}
return newCredentials.AccessToken, nil
func (c *defaultCredential) shouldAttemptRefresh(credentials *oauthCredentials, force bool) bool {
if credentials == nil || credentials.RefreshToken == "" {
return false
}
if !slices.Contains(credentials.Scopes, "user:inference") {
return false
}
if force {
return true
}
return credentials.needsRefresh()
}
func (c *defaultCredential) tryRefreshCredentials(force bool) bool {
latestCredentials, err := platformReadCredentials(c.credentialPath)
if err == nil && latestCredentials != nil {
c.absorbCredentials(latestCredentials)
}
currentCredentials := c.currentCredentials()
if !c.shouldAttemptRefresh(currentCredentials, force) {
return false
}
release, err := acquireCredentialLockFunc(c.configDir)
if err != nil {
c.logger.Debug("acquire credential lock for ", c.tag, ": ", err)
return false
}
defer release()
latestCredentials, err = platformReadCredentials(c.credentialPath)
if err == nil && latestCredentials != nil {
c.absorbCredentials(latestCredentials)
currentCredentials = latestCredentials
} else {
currentCredentials = c.currentCredentials()
}
if !c.shouldAttemptRefresh(currentCredentials, force) {
return false
}
if err := platformCanWriteCredentials(c.credentialPath); err != nil {
c.logger.Debug("credential file not writable for ", c.tag, ": ", err)
return false
}
baseCredentials := cloneCredentials(currentCredentials)
refreshResult, retryDelay, err := refreshToken(c.serviceContext, c.forwardHTTPClient, currentCredentials)
if err != nil {
if retryDelay != 0 {
c.logger.Debug("refresh token for ", c.tag, ": retry delay=", retryDelay, ", error=", err)
} else {
c.logger.Debug("refresh token for ", c.tag, ": ", err)
}
latestCredentials, readErr := platformReadCredentials(c.credentialPath)
if readErr == nil && latestCredentials != nil {
c.absorbCredentials(latestCredentials)
return latestCredentials.AccessToken != "" && (latestCredentials.AccessToken != baseCredentials.AccessToken || !latestCredentials.needsRefresh())
}
return false
}
if refreshResult == nil || refreshResult.Credentials == nil {
return false
}
refreshedCredentials := cloneCredentials(refreshResult.Credentials)
c.absorbCredentials(refreshedCredentials)
c.persistCredentials(refreshedCredentials)
if refreshResult.TokenAccount != nil {
c.absorbOAuthAccount(refreshResult.TokenAccount)
c.persistOAuthAccount()
}
if c.needsProfileHydration() {
profileSnapshot, profileErr := c.fetchProfileSnapshot(c.forwardHTTPClient, refreshedCredentials.AccessToken)
if profileErr != nil {
c.logger.Debug("fetch profile for ", c.tag, ": ", profileErr)
} else if profileSnapshot != nil {
credentialsChanged := c.applyProfileSnapshot(profileSnapshot)
c.persistOAuthAccount()
if credentialsChanged {
c.persistCredentials(c.currentCredentials())
}
}
}
return true
}
func (c *defaultCredential) recoverAuthFailure(failedAccessToken string) bool {
latestCredentials, err := platformReadCredentials(c.credentialPath)
if err == nil && latestCredentials != nil {
c.absorbCredentials(latestCredentials)
if latestCredentials.AccessToken != "" && latestCredentials.AccessToken != failedAccessToken {
return true
}
}
c.tryRefreshCredentials(true)
currentCredentials := c.currentCredentials()
return currentCredentials != nil && currentCredentials.AccessToken != "" && currentCredentials.AccessToken != failedAccessToken
}
func (c *defaultCredential) applyProfileSnapshot(snapshot *claudeProfileSnapshot) bool {
if snapshot == nil {
return false
}
credentialsChanged := false
c.access.Lock()
if c.credentials != nil {
updatedCredentials := cloneCredentials(c.credentials)
if snapshot.SubscriptionType != nil {
updatedCredentials.SubscriptionType = cloneStringPointer(snapshot.SubscriptionType)
}
if snapshot.RateLimitTier != "" {
updatedCredentials.RateLimitTier = cloneStringPointer(&snapshot.RateLimitTier)
}
credentialsChanged = !credentialsEqual(c.credentials, updatedCredentials)
c.credentials = updatedCredentials
}
c.access.Unlock()
c.stateAccess.Lock()
before := c.statusSnapshotLocked()
if snapshot.OAuthAccount != nil {
c.state.oauthAccount = mergeClaudeOAuthAccount(c.state.oauthAccount, snapshot.OAuthAccount)
if c.state.oauthAccount != nil && c.state.oauthAccount.AccountUUID != "" {
c.state.accountUUID = c.state.oauthAccount.AccountUUID
}
}
if snapshot.AccountType != "" {
c.state.accountType = snapshot.AccountType
}
if snapshot.RateLimitTier != "" {
c.state.rateLimitTier = snapshot.RateLimitTier
}
c.checkTransitionLocked()
shouldEmit := before != c.statusSnapshotLocked()
c.stateAccess.Unlock()
if shouldEmit {
c.emitStatusUpdate()
}
return credentialsChanged
}
func (c *defaultCredential) fetchProfileSnapshot(httpClient *http.Client, accessToken string) (*claudeProfileSnapshot, error) {
ctx := c.serviceContext
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/profile", nil)
if err != nil {
return nil, err
}
request.Header.Set("Authorization", "Bearer "+accessToken)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("User-Agent", ccmUserAgentValue)
return request, nil
})
if err != nil {
return nil, err
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
body, _ := io.ReadAll(response.Body)
return nil, E.New("status ", response.StatusCode, " ", string(body))
}
var profileResponse struct {
Account *struct {
UUID string `json:"uuid"`
Email string `json:"email"`
DisplayName string `json:"display_name"`
CreatedAt string `json:"created_at"`
} `json:"account"`
Organization *struct {
UUID string `json:"uuid"`
OrganizationType string `json:"organization_type"`
RateLimitTier string `json:"rate_limit_tier"`
HasExtraUsageEnabled *bool `json:"has_extra_usage_enabled"`
BillingType *string `json:"billing_type"`
SubscriptionCreatedAt *string `json:"subscription_created_at"`
} `json:"organization"`
}
if err := json.NewDecoder(response.Body).Decode(&profileResponse); err != nil {
return nil, err
}
if profileResponse.Organization == nil {
return nil, nil
}
accountType := normalizeClaudeOrganizationType(profileResponse.Organization.OrganizationType)
snapshot := &claudeProfileSnapshot{
AccountType: accountType,
RateLimitTier: profileResponse.Organization.RateLimitTier,
}
if accountType != "" {
snapshot.SubscriptionType = cloneStringPointer(&accountType)
}
account := &claudeOAuthAccount{}
if profileResponse.Account != nil {
account.AccountUUID = profileResponse.Account.UUID
account.EmailAddress = profileResponse.Account.Email
account.DisplayName = optionalStringPointer(profileResponse.Account.DisplayName)
account.AccountCreatedAt = optionalStringPointer(profileResponse.Account.CreatedAt)
}
account.OrganizationUUID = profileResponse.Organization.UUID
account.HasExtraUsageEnabled = cloneBoolPointer(profileResponse.Organization.HasExtraUsageEnabled)
account.BillingType = cloneStringPointer(profileResponse.Organization.BillingType)
account.SubscriptionCreatedAt = cloneStringPointer(profileResponse.Organization.SubscriptionCreatedAt)
if account.AccountUUID != "" || account.EmailAddress != "" || account.OrganizationUUID != "" || account.DisplayName != nil ||
account.HasExtraUsageEnabled != nil || account.BillingType != nil || account.AccountCreatedAt != nil || account.SubscriptionCreatedAt != nil {
snapshot.OAuthAccount = account
}
return snapshot, nil
}
func normalizeClaudeOrganizationType(organizationType string) string {
switch organizationType {
case "claude_pro":
return "pro"
case "claude_max":
return "max"
case "claude_team":
return "team"
case "claude_enterprise":
return "enterprise"
default:
return ""
}
}
func optionalStringPointer(value string) *string {
if value == "" {
return nil
}
return &value
}
func (c *defaultCredential) updateStateFromHeaders(headers http.Header) {
@@ -727,7 +951,7 @@ func (c *defaultCredential) pollUsage() {
}
c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
}
needsProfileFetch := !c.configLoaded && c.state.rateLimitTier == ""
needsProfileFetch := c.needsProfileHydrationLocked()
shouldInterrupt := c.checkTransitionLocked()
c.stateAccess.Unlock()
if shouldInterrupt {
@@ -736,83 +960,19 @@ func (c *defaultCredential) pollUsage() {
c.emitStatusUpdate()
if needsProfileFetch {
c.fetchProfile(httpClient, accessToken)
}
}
// fetchProfile calls GET /api/oauth/profile to retrieve account and organization info.
// Same endpoint used by Claude Code (@anthropic-ai/claude-code @2.1.81):
//
// ref: cli.js GB() — fetches profile
// ref: cli.js AH8() / fetchProfileInfo — parses organization_type, rate_limit_tier
// ref: cli.js EX1() / populateOAuthAccountInfoIfNeeded — stores account.uuid
func (c *defaultCredential) fetchProfile(httpClient *http.Client, accessToken string) {
ctx := c.serviceContext
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/profile", nil)
profileSnapshot, err := c.fetchProfileSnapshot(httpClient, accessToken)
if err != nil {
return nil, err
c.logger.Debug("fetch profile for ", c.tag, ": ", err)
return
}
if profileSnapshot != nil {
credentialsChanged := c.applyProfileSnapshot(profileSnapshot)
c.persistOAuthAccount()
if credentialsChanged {
c.persistCredentials(c.currentCredentials())
}
}
request.Header.Set("Authorization", "Bearer "+accessToken)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("User-Agent", ccmUserAgentValue)
return request, nil
})
if err != nil {
c.logger.Debug("fetch profile for ", c.tag, ": ", err)
return
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
return
}
var profileResponse struct {
Account *struct {
UUID string `json:"uuid"`
} `json:"account"`
Organization *struct {
OrganizationType string `json:"organization_type"`
RateLimitTier string `json:"rate_limit_tier"`
} `json:"organization"`
}
err = json.NewDecoder(response.Body).Decode(&profileResponse)
if err != nil || profileResponse.Organization == nil {
return
}
accountType := ""
switch profileResponse.Organization.OrganizationType {
case "claude_pro":
accountType = "pro"
case "claude_max":
accountType = "max"
case "claude_team":
accountType = "team"
case "claude_enterprise":
accountType = "enterprise"
}
rateLimitTier := profileResponse.Organization.RateLimitTier
c.stateAccess.Lock()
before := c.statusSnapshotLocked()
if profileResponse.Account != nil && profileResponse.Account.UUID != "" {
c.state.accountUUID = profileResponse.Account.UUID
}
if accountType != "" && c.state.accountType == "" {
c.state.accountType = accountType
}
if rateLimitTier != "" {
c.state.rateLimitTier = rateLimitTier
}
resolvedAccountType := c.state.accountType
shouldEmit := before != c.statusSnapshotLocked()
c.stateAccess.Unlock()
if shouldEmit {
c.emitStatusUpdate()
}
c.logger.Info("fetched profile for ", c.tag, ": type=", resolvedAccountType, ", tier=", rateLimitTier, ", weight=", ccmPlanWeight(resolvedAccountType, rateLimitTier))
}
func (c *defaultCredential) close() {

View File

@@ -0,0 +1,205 @@
package ccm
import (
"errors"
"net/http"
"os"
"path/filepath"
"testing"
"time"
)
func TestGetAccessTokenReturnsExistingTokenWhenLockFails(t *testing.T) {
t.Parallel()
directory := t.TempDir()
credentialPath := filepath.Join(directory, ".credentials.json")
writeTestCredentials(t, credentialPath, &oauthCredentials{
AccessToken: "old-token",
RefreshToken: "refresh-token",
ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(),
Scopes: []string{"user:profile", "user:inference"},
SubscriptionType: optionalStringPointer("max"),
RateLimitTier: optionalStringPointer("default_claude_max_20x"),
})
credential := newTestDefaultCredential(t, credentialPath, roundTripFunc(func(request *http.Request) (*http.Response, error) {
t.Fatal("refresh should not be attempted when lock acquisition fails")
return nil, nil
}))
if err := credential.reloadCredentials(true); err != nil {
t.Fatal(err)
}
originalLockFunc := acquireCredentialLockFunc
acquireCredentialLockFunc = func(string) (func(), error) {
return nil, errors.New("locked")
}
t.Cleanup(func() {
acquireCredentialLockFunc = originalLockFunc
})
token, err := credential.getAccessToken()
if err != nil {
t.Fatal(err)
}
if token != "old-token" {
t.Fatalf("expected old token, got %q", token)
}
}
func TestGetAccessTokenAbsorbsRefreshDoneByAnotherProcess(t *testing.T) {
t.Parallel()
directory := t.TempDir()
credentialPath := filepath.Join(directory, ".credentials.json")
oldCredentials := &oauthCredentials{
AccessToken: "old-token",
RefreshToken: "refresh-token",
ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(),
Scopes: []string{"user:profile", "user:inference"},
SubscriptionType: optionalStringPointer("max"),
RateLimitTier: optionalStringPointer("default_claude_max_20x"),
}
writeTestCredentials(t, credentialPath, oldCredentials)
newCredentials := cloneCredentials(oldCredentials)
newCredentials.AccessToken = "new-token"
newCredentials.ExpiresAt = time.Now().Add(time.Hour).UnixMilli()
transport := roundTripFunc(func(request *http.Request) (*http.Response, error) {
if request.URL.Path == "/v1/oauth/token" {
writeTestCredentials(t, credentialPath, newCredentials)
return newJSONResponse(http.StatusInternalServerError, `{"error":"boom"}`), nil
}
t.Fatalf("unexpected path %s", request.URL.Path)
return nil, nil
})
credential := newTestDefaultCredential(t, credentialPath, transport)
if err := credential.reloadCredentials(true); err != nil {
t.Fatal(err)
}
token, err := credential.getAccessToken()
if err != nil {
t.Fatal(err)
}
if token != "new-token" {
t.Fatalf("expected refreshed token from disk, got %q", token)
}
}
func TestCustomCredentialPathDoesNotEnableClaudeConfigSync(t *testing.T) {
t.Parallel()
directory := t.TempDir()
credentialPath := filepath.Join(directory, ".credentials.json")
writeTestCredentials(t, credentialPath, &oauthCredentials{
AccessToken: "token",
ExpiresAt: time.Now().Add(time.Hour).UnixMilli(),
Scopes: []string{"user:profile"},
})
credential := newTestDefaultCredential(t, credentialPath, roundTripFunc(func(request *http.Request) (*http.Response, error) {
t.Fatalf("unexpected request to %s", request.URL.Path)
return nil, nil
}))
if err := credential.reloadCredentials(true); err != nil {
t.Fatal(err)
}
token, err := credential.getAccessToken()
if err != nil {
t.Fatal(err)
}
if token != "token" {
t.Fatalf("expected token, got %q", token)
}
if credential.shouldUseClaudeConfig() {
t.Fatal("custom credential path should not enable Claude config sync")
}
if _, err := os.Stat(filepath.Join(directory, ".claude.json")); !os.IsNotExist(err) {
t.Fatalf("did not expect config file to be created, stat err=%v", err)
}
}
func TestDefaultCredentialHydratesProfileAndWritesConfig(t *testing.T) {
configDir := t.TempDir()
credentialPath := filepath.Join(configDir, ".credentials.json")
writeTestCredentials(t, credentialPath, &oauthCredentials{
AccessToken: "old-token",
RefreshToken: "refresh-token",
ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(),
Scopes: []string{"user:profile", "user:inference"},
})
transport := roundTripFunc(func(request *http.Request) (*http.Response, error) {
switch request.URL.Path {
case "/v1/oauth/token":
return newJSONResponse(http.StatusOK, `{
"access_token":"new-token",
"refresh_token":"new-refresh",
"expires_in":3600,
"account":{"uuid":"account","email_address":"user@example.com"},
"organization":{"uuid":"org"}
}`), nil
case "/api/oauth/profile":
return newJSONResponse(http.StatusOK, `{
"account":{
"uuid":"account",
"email":"user@example.com",
"display_name":"User",
"created_at":"2024-01-01T00:00:00Z"
},
"organization":{
"uuid":"org",
"organization_type":"claude_max",
"rate_limit_tier":"default_claude_max_20x",
"has_extra_usage_enabled":true,
"billing_type":"individual",
"subscription_created_at":"2024-01-02T00:00:00Z"
}
}`), nil
default:
t.Fatalf("unexpected path %s", request.URL.Path)
return nil, nil
}
})
credential := newTestDefaultCredential(t, credentialPath, transport)
credential.syncClaudeConfig = true
credential.claudeDirectory = configDir
credential.claudeConfigPath = resolveClaudeConfigWritePath(configDir)
if err := credential.reloadCredentials(true); err != nil {
t.Fatal(err)
}
token, err := credential.getAccessToken()
if err != nil {
t.Fatal(err)
}
if token != "new-token" {
t.Fatalf("expected refreshed token, got %q", token)
}
updatedCredentials := readTestCredentials(t, credentialPath)
if updatedCredentials.SubscriptionType == nil || *updatedCredentials.SubscriptionType != "max" {
t.Fatalf("expected subscription type to be persisted, got %#v", updatedCredentials.SubscriptionType)
}
if updatedCredentials.RateLimitTier == nil || *updatedCredentials.RateLimitTier != "default_claude_max_20x" {
t.Fatalf("expected rate limit tier to be persisted, got %#v", updatedCredentials.RateLimitTier)
}
configPath := tempConfigPath(t, configDir)
config, err := readClaudeCodeConfig(configPath)
if err != nil {
t.Fatal(err)
}
if config.OAuthAccount == nil || config.OAuthAccount.AccountUUID != "account" || config.OAuthAccount.EmailAddress != "user@example.com" {
t.Fatalf("unexpected oauth account: %#v", config.OAuthAccount)
}
if config.OAuthAccount.BillingType == nil || *config.OAuthAccount.BillingType != "individual" {
t.Fatalf("expected billing type to be hydrated, got %#v", config.OAuthAccount.BillingType)
}
}

View File

@@ -106,24 +106,7 @@ func (c *defaultCredential) reloadCredentials(force bool) error {
return c.markCredentialsUnavailable(E.Cause(err, "read credentials"))
}
c.access.Lock()
c.credentials = credentials
c.refreshRetryAt = time.Time{}
c.refreshRetryError = nil
c.refreshBlocked = false
c.access.Unlock()
c.stateAccess.Lock()
before := c.statusSnapshotLocked()
c.state.unavailable = false
c.state.lastCredentialLoadError = ""
c.checkTransitionLocked()
shouldEmit := before != c.statusSnapshotLocked()
c.stateAccess.Unlock()
if shouldEmit {
c.emitStatusUpdate()
}
c.absorbCredentials(credentials)
return nil
}

View File

@@ -164,21 +164,7 @@ func checkCredentialFileWritable(path string) error {
// ref (@anthropic-ai/claude-code @2.1.81): cli.js BP6 (line 179444-179454) — read-modify-write
// ref: cli.js qD1.update (line 176156) — writeFileSync + chmod 0o600
func writeCredentialsToFile(credentials *oauthCredentials, path string) error {
existing := make(map[string]json.RawMessage)
data, readErr := os.ReadFile(path)
if readErr == nil {
_ = json.Unmarshal(data, &existing)
}
credentialData, err := json.Marshal(credentials)
if err != nil {
return err
}
existing["claudeAiOauth"] = credentialData
data, err = json.MarshalIndent(existing, "", " ")
if err != nil {
return err
}
return os.WriteFile(path, data, 0o600)
return writeStorageValue(jsonFileStorage{path: path}, "claudeAiOauth", credentials)
}
// oauthCredentials mirrors the claudeAiOauth object in Claude Code's
@@ -194,6 +180,12 @@ type oauthCredentials struct {
RateLimitTier *string `json:"rateLimitTier"` // ref: cli.js line 179452 (?? null)
}
type oauthRefreshResult struct {
Credentials *oauthCredentials
TokenAccount *claudeOAuthAccount
Profile *claudeProfileSnapshot
}
func (c *oauthCredentials) needsRefresh() bool {
if c.ExpiresAt == 0 {
return false
@@ -201,7 +193,7 @@ func (c *oauthCredentials) needsRefresh() bool {
return time.Now().UnixMilli() >= c.ExpiresAt-tokenRefreshBufferMs
}
func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, time.Duration, error) {
func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthRefreshResult, time.Duration, error) {
if credentials.RefreshToken == "" {
return nil, 0, E.New("refresh token is empty")
}
@@ -249,10 +241,17 @@ func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oau
// ref (@anthropic-ai/claude-code @2.1.81): cli.js mB6 response (line 172769-172772)
var tokenResponse struct {
AccessToken string `json:"access_token"` // ref: cli.js line 172770 z
RefreshToken string `json:"refresh_token"` // ref: cli.js line 172770 w (defaults to input)
ExpiresIn int `json:"expires_in"` // ref: cli.js line 172770 O
Scope string `json:"scope"` // ref: cli.js line 172772 uB6(Y.scope)
AccessToken string `json:"access_token"` // ref: cli.js line 172770 z
RefreshToken string `json:"refresh_token"` // ref: cli.js line 172770 w (defaults to input)
ExpiresIn int `json:"expires_in"` // ref: cli.js line 172770 O
Scope *string `json:"scope"` // ref: cli.js line 172772 uB6(Y.scope)
Account *struct {
UUID string `json:"uuid"`
EmailAddress string `json:"email_address"`
} `json:"account"`
Organization *struct {
UUID string `json:"uuid"`
} `json:"organization"`
}
err = json.NewDecoder(response.Body).Decode(&tokenResponse)
if err != nil {
@@ -267,11 +266,14 @@ func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oau
newCredentials.ExpiresAt = time.Now().UnixMilli() + int64(tokenResponse.ExpiresIn)*1000
// ref: cli.js uB6 (line 172696-172697): A?.split(" ").filter(Boolean)
// strings.Fields matches .filter(Boolean): splits on whitespace runs, removes empty strings
if tokenResponse.Scope != "" {
newCredentials.Scopes = strings.Fields(tokenResponse.Scope)
if tokenResponse.Scope != nil {
newCredentials.Scopes = strings.Fields(*tokenResponse.Scope)
}
return &newCredentials, 0, nil
return &oauthRefreshResult{
Credentials: &newCredentials,
TokenAccount: extractTokenAccount(tokenResponse.Account, tokenResponse.Organization),
}, 0, nil
}
func cloneCredentials(credentials *oauthCredentials) *oauthCredentials {
@@ -280,6 +282,8 @@ func cloneCredentials(credentials *oauthCredentials) *oauthCredentials {
}
cloned := *credentials
cloned.Scopes = append([]string(nil), credentials.Scopes...)
cloned.SubscriptionType = cloneStringPointer(credentials.SubscriptionType)
cloned.RateLimitTier = cloneStringPointer(credentials.RateLimitTier)
return &cloned
}
@@ -290,5 +294,31 @@ func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool {
return left.AccessToken == right.AccessToken &&
left.RefreshToken == right.RefreshToken &&
left.ExpiresAt == right.ExpiresAt &&
slices.Equal(left.Scopes, right.Scopes)
slices.Equal(left.Scopes, right.Scopes) &&
equalStringPointer(left.SubscriptionType, right.SubscriptionType) &&
equalStringPointer(left.RateLimitTier, right.RateLimitTier)
}
func extractTokenAccount(account *struct {
UUID string `json:"uuid"`
EmailAddress string `json:"email_address"`
}, organization *struct {
UUID string `json:"uuid"`
},
) *claudeOAuthAccount {
if account == nil && organization == nil {
return nil
}
tokenAccount := &claudeOAuthAccount{}
if account != nil {
tokenAccount.AccountUUID = account.UUID
tokenAccount.EmailAddress = account.EmailAddress
}
if organization != nil {
tokenAccount.OrganizationUUID = organization.UUID
}
if tokenAccount.AccountUUID == "" && tokenAccount.EmailAddress == "" && tokenAccount.OrganizationUUID == "" {
return nil
}
return tokenAccount
}

View File

@@ -0,0 +1,141 @@
package ccm
import (
"context"
"encoding/json"
"io"
"net/http"
"slices"
"strings"
"testing"
"time"
)
func TestRefreshTokenScopeParsing(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
storedScopes []string
responseBody string
expectedScope string
expected []string
}{
{
name: "missing scope preserves stored scopes",
storedScopes: []string{"user:profile", "user:inference"},
responseBody: `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`,
expectedScope: strings.Join(defaultOAuthScopes, " "),
expected: []string{"user:profile", "user:inference"},
},
{
name: "empty scope clears stored scopes",
storedScopes: []string{"user:profile", "user:inference"},
responseBody: `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600,"scope":""}`,
expectedScope: strings.Join(defaultOAuthScopes, " "),
expected: []string{},
},
{
name: "stored non inference scopes are sent verbatim",
storedScopes: []string{"user:profile"},
responseBody: `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600,"scope":"user:profile user:file_upload"}`,
expectedScope: "user:profile",
expected: []string{"user:profile", "user:file_upload"},
},
}
for _, testCase := range testCases {
testCase := testCase
t.Run(testCase.name, func(t *testing.T) {
t.Parallel()
var seenScope string
client := &http.Client{Transport: roundTripFunc(func(request *http.Request) (*http.Response, error) {
body, err := io.ReadAll(request.Body)
if err != nil {
t.Fatal(err)
}
var payload map[string]string
if err := json.Unmarshal(body, &payload); err != nil {
t.Fatal(err)
}
seenScope = payload["scope"]
return newJSONResponse(http.StatusOK, testCase.responseBody), nil
})}
result, _, err := refreshToken(context.Background(), client, &oauthCredentials{
AccessToken: "old-token",
RefreshToken: "refresh-token",
ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(),
Scopes: testCase.storedScopes,
})
if err != nil {
t.Fatal(err)
}
if seenScope != testCase.expectedScope {
t.Fatalf("expected request scope %q, got %q", testCase.expectedScope, seenScope)
}
if result == nil || result.Credentials == nil {
t.Fatal("expected refresh result credentials")
}
if !slices.Equal(result.Credentials.Scopes, testCase.expected) {
t.Fatalf("expected scopes %v, got %v", testCase.expected, result.Credentials.Scopes)
}
})
}
}
func TestRefreshTokenExtractsTokenAccount(t *testing.T) {
t.Parallel()
client := &http.Client{Transport: roundTripFunc(func(request *http.Request) (*http.Response, error) {
return newJSONResponse(http.StatusOK, `{
"access_token":"new-token",
"refresh_token":"new-refresh",
"expires_in":3600,
"account":{"uuid":"account","email_address":"user@example.com"},
"organization":{"uuid":"org"}
}`), nil
})}
result, _, err := refreshToken(context.Background(), client, &oauthCredentials{
AccessToken: "old-token",
RefreshToken: "refresh-token",
ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(),
Scopes: []string{"user:profile", "user:inference"},
})
if err != nil {
t.Fatal(err)
}
if result == nil || result.TokenAccount == nil {
t.Fatal("expected token account")
}
if result.TokenAccount.AccountUUID != "account" || result.TokenAccount.EmailAddress != "user@example.com" || result.TokenAccount.OrganizationUUID != "org" {
t.Fatalf("unexpected token account: %#v", result.TokenAccount)
}
}
func TestCredentialsEqualIncludesProfileFields(t *testing.T) {
t.Parallel()
subscriptionType := "max"
rateLimitTier := "default_claude_max_20x"
left := &oauthCredentials{
AccessToken: "token",
RefreshToken: "refresh",
ExpiresAt: 123,
Scopes: []string{"user:inference"},
SubscriptionType: &subscriptionType,
RateLimitTier: &rateLimitTier,
}
right := cloneCredentials(left)
if !credentialsEqual(left, right) {
t.Fatal("expected cloned credentials to be equal")
}
otherTier := "default_claude_max_5x"
right.RateLimitTier = &otherTier
if credentialsEqual(left, right) {
t.Fatal("expected different rate limit tier to break equality")
}
}

View File

@@ -0,0 +1,124 @@
package ccm
import (
"encoding/json"
"errors"
"os"
"path/filepath"
)
type jsonContainerStorage interface {
readContainer() (map[string]json.RawMessage, bool, error)
writeContainer(map[string]json.RawMessage) error
delete() error
}
type jsonFileStorage struct {
path string
}
func (s jsonFileStorage) readContainer() (map[string]json.RawMessage, bool, error) {
data, err := os.ReadFile(s.path)
if err != nil {
if os.IsNotExist(err) {
return make(map[string]json.RawMessage), false, nil
}
return nil, false, err
}
container := make(map[string]json.RawMessage)
if len(data) == 0 {
return container, true, nil
}
if err := json.Unmarshal(data, &container); err != nil {
return nil, true, err
}
return container, true, nil
}
func (s jsonFileStorage) writeContainer(container map[string]json.RawMessage) error {
if err := os.MkdirAll(filepath.Dir(s.path), 0o700); err != nil {
return err
}
data, err := json.MarshalIndent(container, "", " ")
if err != nil {
return err
}
return os.WriteFile(s.path, data, 0o600)
}
func (s jsonFileStorage) delete() error {
err := os.Remove(s.path)
if err != nil && !os.IsNotExist(err) {
return err
}
return nil
}
func writeStorageValue(storage jsonContainerStorage, key string, value any) error {
container, _, err := storage.readContainer()
if err != nil {
var syntaxError *json.SyntaxError
var typeError *json.UnmarshalTypeError
if !errors.As(err, &syntaxError) && !errors.As(err, &typeError) {
return err
}
container = make(map[string]json.RawMessage)
}
if container == nil {
container = make(map[string]json.RawMessage)
}
encodedValue, err := json.Marshal(value)
if err != nil {
return err
}
container[key] = encodedValue
return storage.writeContainer(container)
}
func persistStorageValue(primary jsonContainerStorage, fallback jsonContainerStorage, key string, value any) error {
primaryErr := writeStorageValue(primary, key, value)
if primaryErr == nil {
if fallback != nil {
_ = fallback.delete()
}
return nil
}
if fallback == nil {
return primaryErr
}
if err := writeStorageValue(fallback, key, value); err != nil {
return err
}
_ = primary.delete()
return nil
}
func cloneStringPointer(value *string) *string {
if value == nil {
return nil
}
cloned := *value
return &cloned
}
func cloneBoolPointer(value *bool) *bool {
if value == nil {
return nil
}
cloned := *value
return &cloned
}
func equalStringPointer(left *string, right *string) bool {
if left == nil || right == nil {
return left == right
}
return *left == *right
}
func equalBoolPointer(left *bool, right *bool) bool {
if left == nil || right == nil {
return left == right
}
return *left == *right
}

View File

@@ -0,0 +1,125 @@
package ccm
import (
"encoding/json"
"os"
"path/filepath"
"testing"
)
type fakeJSONStorage struct {
container map[string]json.RawMessage
writeErr error
deleted bool
}
func (s *fakeJSONStorage) readContainer() (map[string]json.RawMessage, bool, error) {
if s.container == nil {
return make(map[string]json.RawMessage), false, nil
}
cloned := make(map[string]json.RawMessage, len(s.container))
for key, value := range s.container {
cloned[key] = value
}
return cloned, true, nil
}
func (s *fakeJSONStorage) writeContainer(container map[string]json.RawMessage) error {
if s.writeErr != nil {
return s.writeErr
}
s.container = make(map[string]json.RawMessage, len(container))
for key, value := range container {
s.container[key] = value
}
return nil
}
func (s *fakeJSONStorage) delete() error {
s.deleted = true
s.container = nil
return nil
}
func TestPersistStorageValueDeletesFallbackOnPrimarySuccess(t *testing.T) {
t.Parallel()
primary := &fakeJSONStorage{}
fallback := &fakeJSONStorage{container: map[string]json.RawMessage{"stale": json.RawMessage(`true`)}}
if err := persistStorageValue(primary, fallback, "claudeAiOauth", &oauthCredentials{AccessToken: "token"}); err != nil {
t.Fatal(err)
}
if !fallback.deleted {
t.Fatal("expected fallback storage to be deleted after primary write")
}
}
func TestPersistStorageValueDeletesPrimaryAfterFallbackSuccess(t *testing.T) {
t.Parallel()
primary := &fakeJSONStorage{
container: map[string]json.RawMessage{"claudeAiOauth": json.RawMessage(`{"accessToken":"old"}`)},
writeErr: os.ErrPermission,
}
fallback := &fakeJSONStorage{}
if err := persistStorageValue(primary, fallback, "claudeAiOauth", &oauthCredentials{AccessToken: "new"}); err != nil {
t.Fatal(err)
}
if !primary.deleted {
t.Fatal("expected primary storage to be deleted after fallback write")
}
}
func TestWriteCredentialsToFilePreservesTopLevelKeys(t *testing.T) {
t.Parallel()
directory := t.TempDir()
path := filepath.Join(directory, ".credentials.json")
initial := []byte(`{"keep":{"nested":true},"claudeAiOauth":{"accessToken":"old"}}`)
if err := os.WriteFile(path, initial, 0o600); err != nil {
t.Fatal(err)
}
if err := writeCredentialsToFile(&oauthCredentials{AccessToken: "new"}, path); err != nil {
t.Fatal(err)
}
data, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
var container map[string]json.RawMessage
if err := json.Unmarshal(data, &container); err != nil {
t.Fatal(err)
}
if _, exists := container["keep"]; !exists {
t.Fatal("expected unknown top-level key to be preserved")
}
}
func TestWriteClaudeCodeOAuthAccountPreservesTopLevelKeys(t *testing.T) {
t.Parallel()
directory := t.TempDir()
path := filepath.Join(directory, ".claude.json")
initial := []byte(`{"keep":{"nested":true},"oauthAccount":{"accountUuid":"old"}}`)
if err := os.WriteFile(path, initial, 0o600); err != nil {
t.Fatal(err)
}
if err := writeClaudeCodeOAuthAccount(path, &claudeOAuthAccount{AccountUUID: "new"}); err != nil {
t.Fatal(err)
}
data, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
var container map[string]json.RawMessage
if err := json.Unmarshal(data, &container); err != nil {
t.Fatal(err)
}
if _, exists := container["keep"]; !exists {
t.Fatal("expected unknown config key to be preserved")
}
}

View File

@@ -377,8 +377,9 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if !selectedCredential.isExternal() && bodyBytes != nil &&
(response.StatusCode == http.StatusUnauthorized || response.StatusCode == http.StatusForbidden) {
shouldRetry := response.StatusCode == http.StatusUnauthorized
var peekBody []byte
if response.StatusCode == http.StatusForbidden {
peekBody, _ := io.ReadAll(response.Body)
peekBody, _ = io.ReadAll(response.Body)
shouldRetry = strings.Contains(string(peekBody), "OAuth token has been revoked")
if !shouldRetry {
response.Body.Close()
@@ -389,23 +390,33 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
if shouldRetry {
response.Body.Close()
s.logger.WarnContext(ctx, "upstream auth failure from ", selectedCredential.tagName(), ", reloading credentials and retrying")
recovered := false
if defaultCred, ok := selectedCredential.(*defaultCredential); ok {
_ = defaultCred.reloadCredentials(true)
failedAccessToken := ""
currentCredentials := defaultCred.currentCredentials()
if currentCredentials != nil {
failedAccessToken = currentCredentials.AccessToken
}
s.logger.WarnContext(ctx, "upstream auth failure from ", selectedCredential.tagName(), ", reloading credentials and retrying")
recovered = defaultCred.recoverAuthFailure(failedAccessToken)
}
retryRequest, buildErr := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if buildErr != nil {
writeJSONError(w, r, http.StatusBadGateway, "api_error", E.Cause(buildErr, "rebuild request after auth recovery").Error())
return
if recovered {
response.Body.Close()
retryRequest, buildErr := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if buildErr != nil {
writeJSONError(w, r, http.StatusBadGateway, "api_error", E.Cause(buildErr, "rebuild request after auth recovery").Error())
return
}
retryResponse, retryErr := selectedCredential.httpClient().Do(retryRequest)
if retryErr != nil {
writeJSONError(w, r, http.StatusBadGateway, "api_error", E.Cause(retryErr, "retry request after auth recovery").Error())
return
}
response = retryResponse
defer retryResponse.Body.Close()
} else if response.StatusCode == http.StatusForbidden {
response.Body = io.NopCloser(bytes.NewReader(peekBody))
}
retryResponse, retryErr := selectedCredential.httpClient().Do(retryRequest)
if retryErr != nil {
writeJSONError(w, r, http.StatusBadGateway, "api_error", E.Cause(retryErr, "retry request after auth recovery").Error())
return
}
response = retryResponse
defer retryResponse.Body.Close()
}
}

View File

@@ -0,0 +1,221 @@
package ccm
import (
"net/http"
"net/http/httptest"
"path/filepath"
"strings"
"sync/atomic"
"testing"
"time"
)
func newHandlerCredential(t *testing.T, transport http.RoundTripper) (*defaultCredential, string) {
t.Helper()
directory := t.TempDir()
credentialPath := filepath.Join(directory, ".credentials.json")
writeTestCredentials(t, credentialPath, &oauthCredentials{
AccessToken: "old-token",
RefreshToken: "refresh-token",
ExpiresAt: time.Now().Add(time.Hour).UnixMilli(),
Scopes: []string{"user:profile", "user:inference"},
SubscriptionType: optionalStringPointer("max"),
RateLimitTier: optionalStringPointer("default_claude_max_20x"),
})
credential := newTestDefaultCredential(t, credentialPath, transport)
if err := credential.reloadCredentials(true); err != nil {
t.Fatal(err)
}
seedTestCredentialState(credential)
return credential, credentialPath
}
func TestServiceHandlerRecoversFrom401(t *testing.T) {
t.Parallel()
var messageRequests atomic.Int32
var refreshRequests atomic.Int32
credential, _ := newHandlerCredential(t, roundTripFunc(func(request *http.Request) (*http.Response, error) {
switch request.URL.Path {
case "/v1/messages":
call := messageRequests.Add(1)
switch request.Header.Get("Authorization") {
case "Bearer old-token":
if call != 1 {
t.Fatalf("unexpected old-token call count %d", call)
}
return newTextResponse(http.StatusUnauthorized, "unauthorized"), nil
case "Bearer new-token":
return newJSONResponse(http.StatusOK, `{}`), nil
default:
t.Fatalf("unexpected authorization header %q", request.Header.Get("Authorization"))
}
case "/v1/oauth/token":
refreshRequests.Add(1)
return newJSONResponse(http.StatusOK, `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`), nil
default:
t.Fatalf("unexpected path %s", request.URL.Path)
}
return nil, nil
}))
service := newTestService(credential)
recorder := httptest.NewRecorder()
service.ServeHTTP(recorder, newMessageRequest(`{"model":"claude","messages":[],"metadata":{"user_id":"{\"session_id\":\"session\"}"}}`))
if recorder.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", recorder.Code, recorder.Body.String())
}
if messageRequests.Load() != 2 {
t.Fatalf("expected two upstream message requests, got %d", messageRequests.Load())
}
if refreshRequests.Load() != 1 {
t.Fatalf("expected one refresh request, got %d", refreshRequests.Load())
}
}
func TestServiceHandlerRecoversFromRevoked403(t *testing.T) {
t.Parallel()
var messageRequests atomic.Int32
var refreshRequests atomic.Int32
credential, _ := newHandlerCredential(t, roundTripFunc(func(request *http.Request) (*http.Response, error) {
switch request.URL.Path {
case "/v1/messages":
messageRequests.Add(1)
if request.Header.Get("Authorization") == "Bearer old-token" {
return newTextResponse(http.StatusForbidden, "OAuth token has been revoked"), nil
}
return newJSONResponse(http.StatusOK, `{}`), nil
case "/v1/oauth/token":
refreshRequests.Add(1)
return newJSONResponse(http.StatusOK, `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`), nil
default:
t.Fatalf("unexpected path %s", request.URL.Path)
}
return nil, nil
}))
service := newTestService(credential)
recorder := httptest.NewRecorder()
service.ServeHTTP(recorder, newMessageRequest(`{"model":"claude","messages":[],"metadata":{"user_id":"{\"session_id\":\"session\"}"}}`))
if recorder.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", recorder.Code, recorder.Body.String())
}
if refreshRequests.Load() != 1 {
t.Fatalf("expected one refresh request, got %d", refreshRequests.Load())
}
}
func TestServiceHandlerDoesNotRecoverFromOrdinary403(t *testing.T) {
t.Parallel()
var refreshRequests atomic.Int32
credential, _ := newHandlerCredential(t, roundTripFunc(func(request *http.Request) (*http.Response, error) {
switch request.URL.Path {
case "/v1/messages":
return newTextResponse(http.StatusForbidden, "forbidden"), nil
case "/v1/oauth/token":
refreshRequests.Add(1)
return newJSONResponse(http.StatusOK, `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`), nil
default:
t.Fatalf("unexpected path %s", request.URL.Path)
}
return nil, nil
}))
service := newTestService(credential)
recorder := httptest.NewRecorder()
service.ServeHTTP(recorder, newMessageRequest(`{"model":"claude","messages":[],"metadata":{"user_id":"{\"session_id\":\"session\"}"}}`))
if recorder.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d", recorder.Code)
}
if refreshRequests.Load() != 0 {
t.Fatalf("expected no refresh request, got %d", refreshRequests.Load())
}
if !strings.Contains(recorder.Body.String(), "forbidden") {
t.Fatalf("expected forbidden body, got %s", recorder.Body.String())
}
}
func TestServiceHandlerUsesReloadedTokenBeforeRefreshing(t *testing.T) {
t.Parallel()
var messageRequests atomic.Int32
var refreshRequests atomic.Int32
var credentialPath string
var credential *defaultCredential
credential, credentialPath = newHandlerCredential(t, roundTripFunc(func(request *http.Request) (*http.Response, error) {
switch request.URL.Path {
case "/v1/messages":
call := messageRequests.Add(1)
if request.Header.Get("Authorization") == "Bearer old-token" {
updatedCredentials := readTestCredentials(t, credentialPath)
updatedCredentials.AccessToken = "disk-token"
updatedCredentials.ExpiresAt = time.Now().Add(time.Hour).UnixMilli()
writeTestCredentials(t, credentialPath, updatedCredentials)
if call != 1 {
t.Fatalf("unexpected old-token call count %d", call)
}
return newTextResponse(http.StatusUnauthorized, "unauthorized"), nil
}
if request.Header.Get("Authorization") != "Bearer disk-token" {
t.Fatalf("expected disk token retry, got %q", request.Header.Get("Authorization"))
}
return newJSONResponse(http.StatusOK, `{}`), nil
case "/v1/oauth/token":
refreshRequests.Add(1)
return newJSONResponse(http.StatusOK, `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`), nil
default:
t.Fatalf("unexpected path %s", request.URL.Path)
}
return nil, nil
}))
service := newTestService(credential)
recorder := httptest.NewRecorder()
service.ServeHTTP(recorder, newMessageRequest(`{"model":"claude","messages":[],"metadata":{"user_id":"{\"session_id\":\"session\"}"}}`))
if recorder.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", recorder.Code, recorder.Body.String())
}
if refreshRequests.Load() != 0 {
t.Fatalf("expected zero refresh requests, got %d", refreshRequests.Load())
}
}
func TestServiceHandlerRetriesAuthRecoveryOnlyOnce(t *testing.T) {
t.Parallel()
var messageRequests atomic.Int32
var refreshRequests atomic.Int32
credential, _ := newHandlerCredential(t, roundTripFunc(func(request *http.Request) (*http.Response, error) {
switch request.URL.Path {
case "/v1/messages":
messageRequests.Add(1)
return newTextResponse(http.StatusUnauthorized, "still unauthorized"), nil
case "/v1/oauth/token":
refreshRequests.Add(1)
return newJSONResponse(http.StatusOK, `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`), nil
default:
t.Fatalf("unexpected path %s", request.URL.Path)
}
return nil, nil
}))
service := newTestService(credential)
recorder := httptest.NewRecorder()
service.ServeHTTP(recorder, newMessageRequest(`{"model":"claude","messages":[],"metadata":{"user_id":"{\"session_id\":\"session\"}"}}`))
if recorder.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d", recorder.Code)
}
if messageRequests.Load() != 2 {
t.Fatalf("expected exactly two upstream attempts, got %d", messageRequests.Load())
}
if refreshRequests.Load() != 1 {
t.Fatalf("expected exactly one refresh request, got %d", refreshRequests.Load())
}
}

View File

@@ -0,0 +1,138 @@
package ccm
import (
"context"
"io"
"net/http"
"net/http/httptest"
"path/filepath"
"strings"
"testing"
"time"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(request *http.Request) (*http.Response, error) {
return f(request)
}
func newJSONResponse(statusCode int, body string) *http.Response {
return &http.Response{
StatusCode: statusCode,
Status: http.StatusText(statusCode),
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(body)),
}
}
func newTextResponse(statusCode int, body string) *http.Response {
return &http.Response{
StatusCode: statusCode,
Status: http.StatusText(statusCode),
Header: http.Header{"Content-Type": []string{"text/plain"}},
Body: io.NopCloser(strings.NewReader(body)),
}
}
func writeTestCredentials(t *testing.T, path string, credentials *oauthCredentials) {
t.Helper()
if path == "" {
var err error
path, err = getDefaultCredentialsPath()
if err != nil {
t.Fatal(err)
}
}
if err := writeCredentialsToFile(credentials, path); err != nil {
t.Fatal(err)
}
}
func readTestCredentials(t *testing.T, path string) *oauthCredentials {
t.Helper()
if path == "" {
var err error
path, err = getDefaultCredentialsPath()
if err != nil {
t.Fatal(err)
}
}
credentials, err := readCredentialsFromFile(path)
if err != nil {
t.Fatal(err)
}
return credentials
}
func newTestDefaultCredential(t *testing.T, credentialPath string, transport http.RoundTripper) *defaultCredential {
t.Helper()
credentialFilePath, err := resolveCredentialFilePath(credentialPath)
if err != nil {
t.Fatal(err)
}
requestContext, cancelRequests := context.WithCancel(context.Background())
credential := &defaultCredential{
tag: "test",
serviceContext: context.Background(),
credentialPath: credentialPath,
credentialFilePath: credentialFilePath,
configDir: resolveConfigDir(credentialPath, credentialFilePath),
syncClaudeConfig: credentialPath == "",
cap5h: 99,
capWeekly: 99,
forwardHTTPClient: &http.Client{Transport: transport},
logger: log.NewNOPFactory().Logger(),
requestContext: requestContext,
cancelRequests: cancelRequests,
}
if credential.syncClaudeConfig {
credential.claudeDirectory = credential.configDir
credential.claudeConfigPath = resolveClaudeConfigWritePath(credential.claudeDirectory)
}
credential.state.lastUpdated = time.Now()
return credential
}
func seedTestCredentialState(credential *defaultCredential) {
billingType := "individual"
accountCreatedAt := "2024-01-01T00:00:00Z"
subscriptionCreatedAt := "2024-01-02T00:00:00Z"
credential.stateAccess.Lock()
credential.state.accountUUID = "account"
credential.state.accountType = "max"
credential.state.rateLimitTier = "default_claude_max_20x"
credential.state.oauthAccount = &claudeOAuthAccount{
AccountUUID: "account",
EmailAddress: "user@example.com",
OrganizationUUID: "org",
BillingType: &billingType,
AccountCreatedAt: &accountCreatedAt,
SubscriptionCreatedAt: &subscriptionCreatedAt,
}
credential.stateAccess.Unlock()
}
func newTestService(credential *defaultCredential) *Service {
return &Service{
logger: log.NewNOPFactory().Logger(),
options: option.CCMServiceOptions{Credentials: []option.CCMCredential{{Tag: "default"}}},
httpHeaders: make(http.Header),
providers: map[string]credentialProvider{"default": &singleCredentialProvider{credential: credential}},
sessionModels: make(map[sessionModelKey]time.Time),
}
}
func newMessageRequest(body string) *http.Request {
request := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
request.Header.Set("Content-Type", "application/json")
return request
}
func tempConfigPath(t *testing.T, dir string) string {
t.Helper()
return filepath.Join(dir, claudeCodeLegacyConfigFileName())
}