fix(ccm): align default credential with Claude Code
This commit is contained in:
@@ -62,6 +62,7 @@ type credentialState struct {
|
||||
accountUUID string
|
||||
accountType string
|
||||
rateLimitTier string
|
||||
oauthAccount *claudeOAuthAccount
|
||||
remotePlanWeight float64
|
||||
lastUpdated time.Time
|
||||
consecutivePollFailures int
|
||||
|
||||
@@ -19,7 +19,14 @@ type claudeCodeConfig struct {
|
||||
}
|
||||
|
||||
type claudeOAuthAccount struct {
|
||||
AccountUUID string `json:"accountUuid"`
|
||||
AccountUUID string `json:"accountUuid,omitempty"`
|
||||
EmailAddress string `json:"emailAddress,omitempty"`
|
||||
OrganizationUUID string `json:"organizationUuid,omitempty"`
|
||||
DisplayName *string `json:"displayName,omitempty"`
|
||||
HasExtraUsageEnabled *bool `json:"hasExtraUsageEnabled,omitempty"`
|
||||
BillingType *string `json:"billingType,omitempty"`
|
||||
AccountCreatedAt *string `json:"accountCreatedAt,omitempty"`
|
||||
SubscriptionCreatedAt *string `json:"subscriptionCreatedAt,omitempty"`
|
||||
}
|
||||
|
||||
// resolveClaudeConfigFile finds the Claude Code config file within the given directory.
|
||||
@@ -33,8 +40,8 @@ type claudeOAuthAccount struct {
|
||||
func resolveClaudeConfigFile(claudeDirectory string) string {
|
||||
candidates := []string{
|
||||
filepath.Join(claudeDirectory, ".config.json"),
|
||||
filepath.Join(claudeDirectory, ".claude.json"),
|
||||
filepath.Join(filepath.Dir(claudeDirectory), ".claude.json"),
|
||||
filepath.Join(claudeDirectory, claudeCodeLegacyConfigFileName()),
|
||||
filepath.Join(filepath.Dir(claudeDirectory), claudeCodeLegacyConfigFileName()),
|
||||
}
|
||||
for _, candidate := range candidates {
|
||||
_, err := os.Stat(candidate)
|
||||
@@ -57,3 +64,84 @@ func readClaudeCodeConfig(path string) (*claudeCodeConfig, error) {
|
||||
}
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
func resolveClaudeConfigWritePath(claudeDirectory string) string {
|
||||
if claudeDirectory == "" {
|
||||
return ""
|
||||
}
|
||||
existingPath := resolveClaudeConfigFile(claudeDirectory)
|
||||
if existingPath != "" {
|
||||
return existingPath
|
||||
}
|
||||
if os.Getenv("CLAUDE_CONFIG_DIR") != "" {
|
||||
return filepath.Join(claudeDirectory, claudeCodeLegacyConfigFileName())
|
||||
}
|
||||
defaultClaudeDirectory := filepath.Join(filepath.Dir(claudeDirectory), ".claude")
|
||||
if claudeDirectory != defaultClaudeDirectory {
|
||||
return filepath.Join(claudeDirectory, claudeCodeLegacyConfigFileName())
|
||||
}
|
||||
return filepath.Join(filepath.Dir(claudeDirectory), claudeCodeLegacyConfigFileName())
|
||||
}
|
||||
|
||||
func writeClaudeCodeOAuthAccount(path string, account *claudeOAuthAccount) error {
|
||||
if path == "" || account == nil {
|
||||
return nil
|
||||
}
|
||||
storage := jsonFileStorage{path: path}
|
||||
return writeStorageValue(storage, "oauthAccount", account)
|
||||
}
|
||||
|
||||
func claudeCodeLegacyConfigFileName() string {
|
||||
if os.Getenv("CLAUDE_CODE_CUSTOM_OAUTH_URL") != "" {
|
||||
return ".claude-custom-oauth.json"
|
||||
}
|
||||
return ".claude.json"
|
||||
}
|
||||
|
||||
func cloneClaudeOAuthAccount(account *claudeOAuthAccount) *claudeOAuthAccount {
|
||||
if account == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *account
|
||||
cloned.DisplayName = cloneStringPointer(account.DisplayName)
|
||||
cloned.HasExtraUsageEnabled = cloneBoolPointer(account.HasExtraUsageEnabled)
|
||||
cloned.BillingType = cloneStringPointer(account.BillingType)
|
||||
cloned.AccountCreatedAt = cloneStringPointer(account.AccountCreatedAt)
|
||||
cloned.SubscriptionCreatedAt = cloneStringPointer(account.SubscriptionCreatedAt)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func mergeClaudeOAuthAccount(base *claudeOAuthAccount, update *claudeOAuthAccount) *claudeOAuthAccount {
|
||||
if update == nil {
|
||||
return cloneClaudeOAuthAccount(base)
|
||||
}
|
||||
if base == nil {
|
||||
return cloneClaudeOAuthAccount(update)
|
||||
}
|
||||
merged := cloneClaudeOAuthAccount(base)
|
||||
if update.AccountUUID != "" {
|
||||
merged.AccountUUID = update.AccountUUID
|
||||
}
|
||||
if update.EmailAddress != "" {
|
||||
merged.EmailAddress = update.EmailAddress
|
||||
}
|
||||
if update.OrganizationUUID != "" {
|
||||
merged.OrganizationUUID = update.OrganizationUUID
|
||||
}
|
||||
if update.DisplayName != nil {
|
||||
merged.DisplayName = cloneStringPointer(update.DisplayName)
|
||||
}
|
||||
if update.HasExtraUsageEnabled != nil {
|
||||
merged.HasExtraUsageEnabled = cloneBoolPointer(update.HasExtraUsageEnabled)
|
||||
}
|
||||
if update.BillingType != nil {
|
||||
merged.BillingType = cloneStringPointer(update.BillingType)
|
||||
}
|
||||
if update.AccountCreatedAt != nil {
|
||||
merged.AccountCreatedAt = cloneStringPointer(update.AccountCreatedAt)
|
||||
}
|
||||
if update.SubscriptionCreatedAt != nil {
|
||||
merged.SubscriptionCreatedAt = cloneStringPointer(update.SubscriptionCreatedAt)
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
@@ -14,6 +14,11 @@ import (
|
||||
"github.com/keybase/go-keychain"
|
||||
)
|
||||
|
||||
type keychainStorage struct {
|
||||
service string
|
||||
account string
|
||||
}
|
||||
|
||||
func getKeychainServiceName() string {
|
||||
configDirectory := os.Getenv("CLAUDE_CONFIG_DIR")
|
||||
if configDirectory == "" {
|
||||
@@ -76,72 +81,90 @@ func platformCanWriteCredentials(customPath string) error {
|
||||
return checkCredentialFileWritable(customPath)
|
||||
}
|
||||
|
||||
// platformWriteCredentials performs a read-modify-write on the keychain entry,
|
||||
// preserving any fields or top-level keys not managed by CCM.
|
||||
//
|
||||
// ref (@anthropic-ai/claude-code @2.1.81): cli.js BP6 (line 179444-179454) — read-modify-write
|
||||
func platformWriteCredentials(credentials *oauthCredentials, customPath string) error {
|
||||
if customPath != "" {
|
||||
return writeCredentialsToFile(credentials, customPath)
|
||||
}
|
||||
|
||||
userInfo, err := getRealUser()
|
||||
if err == nil {
|
||||
serviceName := getKeychainServiceName()
|
||||
|
||||
existing := make(map[string]json.RawMessage)
|
||||
query := keychain.NewItem()
|
||||
query.SetSecClass(keychain.SecClassGenericPassword)
|
||||
query.SetService(serviceName)
|
||||
query.SetAccount(userInfo.Username)
|
||||
query.SetMatchLimit(keychain.MatchLimitOne)
|
||||
query.SetReturnData(true)
|
||||
results, queryErr := keychain.QueryItem(query)
|
||||
if queryErr == nil && len(results) == 1 {
|
||||
_ = json.Unmarshal(results[0].Data, &existing)
|
||||
}
|
||||
|
||||
credentialData, err := json.Marshal(credentials)
|
||||
if err != nil {
|
||||
return E.Cause(err, "marshal credentials")
|
||||
}
|
||||
existing["claudeAiOauth"] = credentialData
|
||||
data, err := json.Marshal(existing)
|
||||
if err != nil {
|
||||
return E.Cause(err, "marshal credential container")
|
||||
}
|
||||
|
||||
item := keychain.NewItem()
|
||||
item.SetSecClass(keychain.SecClassGenericPassword)
|
||||
item.SetService(serviceName)
|
||||
item.SetAccount(userInfo.Username)
|
||||
item.SetData(data)
|
||||
item.SetAccessible(keychain.AccessibleWhenUnlocked)
|
||||
|
||||
err = keychain.AddItem(item)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err == keychain.ErrorDuplicateItem {
|
||||
updateQuery := keychain.NewItem()
|
||||
updateQuery.SetSecClass(keychain.SecClassGenericPassword)
|
||||
updateQuery.SetService(serviceName)
|
||||
updateQuery.SetAccount(userInfo.Username)
|
||||
|
||||
updateItem := keychain.NewItem()
|
||||
updateItem.SetData(data)
|
||||
|
||||
updateErr := keychain.UpdateItem(updateQuery, updateItem)
|
||||
if updateErr == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
defaultPath, err := getDefaultCredentialsPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeCredentialsToFile(credentials, defaultPath)
|
||||
fileStorage := jsonFileStorage{path: defaultPath}
|
||||
|
||||
userInfo, err := getRealUser()
|
||||
if err != nil {
|
||||
return writeCredentialsToFile(credentials, defaultPath)
|
||||
}
|
||||
return persistStorageValue(keychainStorage{
|
||||
service: getKeychainServiceName(),
|
||||
account: userInfo.Username,
|
||||
}, fileStorage, "claudeAiOauth", credentials)
|
||||
}
|
||||
|
||||
func (s keychainStorage) readContainer() (map[string]json.RawMessage, bool, error) {
|
||||
query := keychain.NewItem()
|
||||
query.SetSecClass(keychain.SecClassGenericPassword)
|
||||
query.SetService(s.service)
|
||||
query.SetAccount(s.account)
|
||||
query.SetMatchLimit(keychain.MatchLimitOne)
|
||||
query.SetReturnData(true)
|
||||
|
||||
results, err := keychain.QueryItem(query)
|
||||
if err != nil {
|
||||
if err == keychain.ErrorItemNotFound {
|
||||
return make(map[string]json.RawMessage), false, nil
|
||||
}
|
||||
return nil, false, E.Cause(err, "query keychain")
|
||||
}
|
||||
if len(results) != 1 {
|
||||
return make(map[string]json.RawMessage), false, nil
|
||||
}
|
||||
|
||||
container := make(map[string]json.RawMessage)
|
||||
if len(results[0].Data) == 0 {
|
||||
return container, true, nil
|
||||
}
|
||||
if err := json.Unmarshal(results[0].Data, &container); err != nil {
|
||||
return nil, true, err
|
||||
}
|
||||
return container, true, nil
|
||||
}
|
||||
|
||||
func (s keychainStorage) writeContainer(container map[string]json.RawMessage) error {
|
||||
data, err := json.Marshal(container)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
item := keychain.NewItem()
|
||||
item.SetSecClass(keychain.SecClassGenericPassword)
|
||||
item.SetService(s.service)
|
||||
item.SetAccount(s.account)
|
||||
item.SetData(data)
|
||||
item.SetAccessible(keychain.AccessibleWhenUnlocked)
|
||||
err = keychain.AddItem(item)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if err != keychain.ErrorDuplicateItem {
|
||||
return err
|
||||
}
|
||||
|
||||
updateQuery := keychain.NewItem()
|
||||
updateQuery.SetSecClass(keychain.SecClassGenericPassword)
|
||||
updateQuery.SetService(s.service)
|
||||
updateQuery.SetAccount(s.account)
|
||||
|
||||
updateItem := keychain.NewItem()
|
||||
updateItem.SetData(data)
|
||||
return keychain.UpdateItem(updateQuery, updateItem)
|
||||
}
|
||||
|
||||
func (s keychainStorage) delete() error {
|
||||
err := keychain.DeleteGenericPasswordItem(s.service, s.account)
|
||||
if err != nil && err != keychain.ErrorItemNotFound {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -26,6 +26,15 @@ import (
|
||||
"github.com/sagernet/sing/common/observable"
|
||||
)
|
||||
|
||||
var acquireCredentialLockFunc = acquireCredentialLock
|
||||
|
||||
type claudeProfileSnapshot struct {
|
||||
OAuthAccount *claudeOAuthAccount
|
||||
AccountType string
|
||||
RateLimitTier string
|
||||
SubscriptionType *string
|
||||
}
|
||||
|
||||
type defaultCredential struct {
|
||||
tag string
|
||||
serviceContext context.Context
|
||||
@@ -33,8 +42,9 @@ type defaultCredential struct {
|
||||
claudeDirectory string
|
||||
credentialFilePath string
|
||||
configDir string
|
||||
claudeConfigPath string
|
||||
syncClaudeConfig bool
|
||||
deviceID string
|
||||
configLoaded bool
|
||||
credentials *oauthCredentials
|
||||
access sync.RWMutex
|
||||
state credentialState
|
||||
@@ -52,11 +62,6 @@ type defaultCredential struct {
|
||||
|
||||
statusSubscriber *observable.Subscriber[struct{}]
|
||||
|
||||
// Refresh rate-limit cooldown (protected by access mutex)
|
||||
refreshRetryAt time.Time
|
||||
refreshRetryError error
|
||||
refreshBlocked bool
|
||||
|
||||
// Connection interruption
|
||||
interrupted bool
|
||||
requestContext context.Context
|
||||
@@ -113,6 +118,7 @@ func newDefaultCredential(ctx context.Context, tag string, options option.CCMDef
|
||||
serviceContext: ctx,
|
||||
credentialPath: options.CredentialPath,
|
||||
claudeDirectory: options.ClaudeDirectory,
|
||||
syncClaudeConfig: options.ClaudeDirectory != "" || options.CredentialPath == "",
|
||||
cap5h: cap5h,
|
||||
capWeekly: capWeekly,
|
||||
forwardHTTPClient: httpClient,
|
||||
@@ -133,7 +139,6 @@ func newDefaultCredential(ctx context.Context, tag string, options option.CCMDef
|
||||
|
||||
func (c *defaultCredential) start() error {
|
||||
if c.claudeDirectory != "" {
|
||||
c.loadClaudeCodeConfig()
|
||||
if c.credentialPath == "" {
|
||||
c.credentialPath = filepath.Join(c.claudeDirectory, ".credentials.json")
|
||||
}
|
||||
@@ -144,6 +149,13 @@ func (c *defaultCredential) start() error {
|
||||
}
|
||||
c.credentialFilePath = credentialFilePath
|
||||
c.configDir = resolveConfigDir(c.credentialPath, credentialFilePath)
|
||||
if c.syncClaudeConfig {
|
||||
if c.claudeDirectory == "" {
|
||||
c.claudeDirectory = c.configDir
|
||||
}
|
||||
c.claudeConfigPath = resolveClaudeConfigWritePath(c.claudeDirectory)
|
||||
c.loadClaudeCodeConfig()
|
||||
}
|
||||
err = c.ensureCredentialWatcher()
|
||||
if err != nil {
|
||||
c.logger.Debug("start credential watcher for ", c.tag, ": ", err)
|
||||
@@ -173,6 +185,7 @@ func (c *defaultCredential) loadClaudeCodeConfig() {
|
||||
return
|
||||
}
|
||||
c.stateAccess.Lock()
|
||||
c.state.oauthAccount = cloneClaudeOAuthAccount(config.OAuthAccount)
|
||||
if config.OAuthAccount != nil && config.OAuthAccount.AccountUUID != "" {
|
||||
c.state.accountUUID = config.OAuthAccount.AccountUUID
|
||||
}
|
||||
@@ -180,7 +193,7 @@ func (c *defaultCredential) loadClaudeCodeConfig() {
|
||||
if config.UserID != "" {
|
||||
c.deviceID = config.UserID
|
||||
}
|
||||
c.configLoaded = true
|
||||
c.claudeConfigPath = configFilePath
|
||||
c.logger.Debug("loaded claude code config for ", c.tag, ": account=", c.state.accountUUID, ", device=", c.deviceID)
|
||||
}
|
||||
|
||||
@@ -209,147 +222,358 @@ func (c *defaultCredential) statusSnapshotLocked() statusSnapshot {
|
||||
func (c *defaultCredential) getAccessToken() (string, error) {
|
||||
c.retryCredentialReloadIfNeeded()
|
||||
|
||||
// Fast path: cached token is still valid
|
||||
c.access.RLock()
|
||||
if c.credentials != nil && !c.credentials.needsRefresh() {
|
||||
token := c.credentials.AccessToken
|
||||
c.access.RUnlock()
|
||||
return token, nil
|
||||
}
|
||||
currentCredentials := cloneCredentials(c.credentials)
|
||||
c.access.RUnlock()
|
||||
|
||||
// Reload from disk — Claude Code or another process may have refreshed
|
||||
err := c.reloadCredentials(true)
|
||||
if err == nil {
|
||||
c.access.RLock()
|
||||
if c.credentials != nil && !c.credentials.needsRefresh() {
|
||||
token := c.credentials.AccessToken
|
||||
c.access.RUnlock()
|
||||
return token, nil
|
||||
if currentCredentials == nil {
|
||||
err := c.reloadCredentials(true)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
c.access.RLock()
|
||||
currentCredentials = cloneCredentials(c.credentials)
|
||||
c.access.RUnlock()
|
||||
}
|
||||
|
||||
// ref (@anthropic-ai/claude-code @2.1.81): cli.js _P1 line 179526
|
||||
// Claude Code skips refresh for tokens without user:inference scope.
|
||||
// Return existing token (may be expired); 401 recovery is the safety net.
|
||||
c.access.RLock()
|
||||
if c.credentials != nil && !slices.Contains(c.credentials.Scopes, "user:inference") {
|
||||
token := c.credentials.AccessToken
|
||||
c.access.RUnlock()
|
||||
return token, nil
|
||||
}
|
||||
c.access.RUnlock()
|
||||
|
||||
// Acquire cross-process lock before refresh (outside Go mutex to avoid holding mutex during sleep)
|
||||
// ref: cli.js _P1 (line 179534-179536) — proper-lockfile lock on config dir
|
||||
release, lockErr := acquireCredentialLock(c.configDir)
|
||||
if lockErr != nil {
|
||||
c.logger.Debug("acquire credential lock for ", c.tag, ": ", lockErr)
|
||||
release = func() {}
|
||||
}
|
||||
defer release()
|
||||
|
||||
// ref: cli.js _P1 (line 179559-179562) — re-read after lock, skip if race resolved
|
||||
_ = c.reloadCredentials(true)
|
||||
c.access.RLock()
|
||||
noRefreshToken := c.credentials == nil || c.credentials.RefreshToken == ""
|
||||
raceResolved := !noRefreshToken && !c.credentials.needsRefresh()
|
||||
var racedToken string
|
||||
if (noRefreshToken || raceResolved) && c.credentials != nil {
|
||||
racedToken = c.credentials.AccessToken
|
||||
}
|
||||
c.access.RUnlock()
|
||||
if noRefreshToken || raceResolved {
|
||||
return racedToken, nil
|
||||
}
|
||||
|
||||
// Slow path: acquire Go mutex and refresh
|
||||
c.access.Lock()
|
||||
defer c.access.Unlock()
|
||||
|
||||
if c.credentials == nil {
|
||||
if currentCredentials == nil {
|
||||
return "", c.unavailableError()
|
||||
}
|
||||
if !c.credentials.needsRefresh() {
|
||||
if !currentCredentials.needsRefresh() || !slices.Contains(currentCredentials.Scopes, "user:inference") {
|
||||
return currentCredentials.AccessToken, nil
|
||||
}
|
||||
c.tryRefreshCredentials(false)
|
||||
c.access.RLock()
|
||||
defer c.access.RUnlock()
|
||||
if c.credentials != nil && c.credentials.AccessToken != "" {
|
||||
return c.credentials.AccessToken, nil
|
||||
}
|
||||
return "", c.unavailableError()
|
||||
}
|
||||
|
||||
if c.refreshBlocked {
|
||||
return "", c.refreshRetryError
|
||||
}
|
||||
if !c.refreshRetryAt.IsZero() && time.Now().Before(c.refreshRetryAt) {
|
||||
return "", c.refreshRetryError
|
||||
}
|
||||
func (c *defaultCredential) shouldUseClaudeConfig() bool {
|
||||
return c.syncClaudeConfig && c.claudeConfigPath != ""
|
||||
}
|
||||
|
||||
err = platformCanWriteCredentials(c.credentialPath)
|
||||
if err != nil {
|
||||
return "", E.Cause(err, "credential file not writable, refusing refresh to avoid invalidation")
|
||||
}
|
||||
func (c *defaultCredential) absorbCredentials(credentials *oauthCredentials) {
|
||||
c.access.Lock()
|
||||
c.credentials = cloneCredentials(credentials)
|
||||
c.access.Unlock()
|
||||
|
||||
baseCredentials := cloneCredentials(c.credentials)
|
||||
newCredentials, retryDelay, err := refreshToken(c.serviceContext, c.forwardHTTPClient, c.credentials)
|
||||
if err != nil {
|
||||
if retryDelay < 0 {
|
||||
c.refreshBlocked = true
|
||||
c.refreshRetryError = err
|
||||
} else if retryDelay > 0 {
|
||||
c.refreshRetryAt = time.Now().Add(retryDelay)
|
||||
c.refreshRetryError = err
|
||||
}
|
||||
// ref: cli.js _P1 (line 179568-179573) — post-failure recovery:
|
||||
// re-read from disk; if another process refreshed successfully, use that.
|
||||
// Cannot call reloadCredentials here (deadlock: already holding c.access).
|
||||
latestCredentials, readErr := platformReadCredentials(c.credentialPath)
|
||||
if readErr == nil && latestCredentials != nil && !latestCredentials.needsRefresh() {
|
||||
c.credentials = latestCredentials
|
||||
return latestCredentials.AccessToken, nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
c.refreshRetryAt = time.Time{}
|
||||
c.refreshRetryError = nil
|
||||
c.refreshBlocked = false
|
||||
|
||||
latestCredentials, latestErr := platformReadCredentials(c.credentialPath)
|
||||
if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) {
|
||||
c.credentials = latestCredentials
|
||||
c.stateAccess.Lock()
|
||||
before := c.statusSnapshotLocked()
|
||||
c.state.unavailable = false
|
||||
c.state.lastCredentialLoadAttempt = time.Now()
|
||||
c.state.lastCredentialLoadError = ""
|
||||
c.checkTransitionLocked()
|
||||
shouldEmit := before != c.statusSnapshotLocked()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldEmit {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
if !latestCredentials.needsRefresh() {
|
||||
return latestCredentials.AccessToken, nil
|
||||
}
|
||||
return "", E.New("credential ", c.tag, " changed while refreshing")
|
||||
}
|
||||
|
||||
c.credentials = newCredentials
|
||||
c.stateAccess.Lock()
|
||||
before := c.statusSnapshotLocked()
|
||||
c.state.unavailable = false
|
||||
c.state.lastCredentialLoadAttempt = time.Now()
|
||||
c.state.lastCredentialLoadError = ""
|
||||
c.applyCredentialMetadataLocked(credentials)
|
||||
c.checkTransitionLocked()
|
||||
shouldEmit := before != c.statusSnapshotLocked()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldEmit {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
err = platformWriteCredentials(newCredentials, c.credentialPath)
|
||||
if err != nil {
|
||||
func (c *defaultCredential) applyCredentialMetadataLocked(credentials *oauthCredentials) {
|
||||
if credentials == nil {
|
||||
return
|
||||
}
|
||||
if credentials.SubscriptionType != nil && *credentials.SubscriptionType != "" {
|
||||
c.state.accountType = *credentials.SubscriptionType
|
||||
}
|
||||
if credentials.RateLimitTier != nil && *credentials.RateLimitTier != "" {
|
||||
c.state.rateLimitTier = *credentials.RateLimitTier
|
||||
}
|
||||
}
|
||||
|
||||
func (c *defaultCredential) absorbOAuthAccount(account *claudeOAuthAccount) {
|
||||
c.stateAccess.Lock()
|
||||
c.state.oauthAccount = mergeClaudeOAuthAccount(c.state.oauthAccount, account)
|
||||
if c.state.oauthAccount != nil && c.state.oauthAccount.AccountUUID != "" {
|
||||
c.state.accountUUID = c.state.oauthAccount.AccountUUID
|
||||
}
|
||||
c.stateAccess.Unlock()
|
||||
}
|
||||
|
||||
func (c *defaultCredential) persistOAuthAccount() {
|
||||
if !c.shouldUseClaudeConfig() {
|
||||
return
|
||||
}
|
||||
c.stateAccess.RLock()
|
||||
account := cloneClaudeOAuthAccount(c.state.oauthAccount)
|
||||
c.stateAccess.RUnlock()
|
||||
if account == nil {
|
||||
return
|
||||
}
|
||||
if err := writeClaudeCodeOAuthAccount(c.claudeConfigPath, account); err != nil {
|
||||
c.logger.Debug("write claude code config for ", c.tag, ": ", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *defaultCredential) needsProfileHydration() bool {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.needsProfileHydrationLocked()
|
||||
}
|
||||
|
||||
func (c *defaultCredential) needsProfileHydrationLocked() bool {
|
||||
if c.state.accountUUID == "" || c.state.accountType == "" || c.state.rateLimitTier == "" {
|
||||
return true
|
||||
}
|
||||
if c.state.oauthAccount == nil {
|
||||
return true
|
||||
}
|
||||
return c.state.oauthAccount.BillingType == nil ||
|
||||
c.state.oauthAccount.AccountCreatedAt == nil ||
|
||||
c.state.oauthAccount.SubscriptionCreatedAt == nil
|
||||
}
|
||||
|
||||
func (c *defaultCredential) currentCredentials() *oauthCredentials {
|
||||
c.access.RLock()
|
||||
defer c.access.RUnlock()
|
||||
return cloneCredentials(c.credentials)
|
||||
}
|
||||
|
||||
func (c *defaultCredential) persistCredentials(credentials *oauthCredentials) {
|
||||
if credentials == nil {
|
||||
return
|
||||
}
|
||||
if err := platformWriteCredentials(credentials, c.credentialPath); err != nil {
|
||||
c.logger.Error("persist refreshed token for ", c.tag, ": ", err)
|
||||
}
|
||||
}
|
||||
|
||||
return newCredentials.AccessToken, nil
|
||||
func (c *defaultCredential) shouldAttemptRefresh(credentials *oauthCredentials, force bool) bool {
|
||||
if credentials == nil || credentials.RefreshToken == "" {
|
||||
return false
|
||||
}
|
||||
if !slices.Contains(credentials.Scopes, "user:inference") {
|
||||
return false
|
||||
}
|
||||
if force {
|
||||
return true
|
||||
}
|
||||
return credentials.needsRefresh()
|
||||
}
|
||||
|
||||
func (c *defaultCredential) tryRefreshCredentials(force bool) bool {
|
||||
latestCredentials, err := platformReadCredentials(c.credentialPath)
|
||||
if err == nil && latestCredentials != nil {
|
||||
c.absorbCredentials(latestCredentials)
|
||||
}
|
||||
currentCredentials := c.currentCredentials()
|
||||
if !c.shouldAttemptRefresh(currentCredentials, force) {
|
||||
return false
|
||||
}
|
||||
release, err := acquireCredentialLockFunc(c.configDir)
|
||||
if err != nil {
|
||||
c.logger.Debug("acquire credential lock for ", c.tag, ": ", err)
|
||||
return false
|
||||
}
|
||||
defer release()
|
||||
|
||||
latestCredentials, err = platformReadCredentials(c.credentialPath)
|
||||
if err == nil && latestCredentials != nil {
|
||||
c.absorbCredentials(latestCredentials)
|
||||
currentCredentials = latestCredentials
|
||||
} else {
|
||||
currentCredentials = c.currentCredentials()
|
||||
}
|
||||
if !c.shouldAttemptRefresh(currentCredentials, force) {
|
||||
return false
|
||||
}
|
||||
if err := platformCanWriteCredentials(c.credentialPath); err != nil {
|
||||
c.logger.Debug("credential file not writable for ", c.tag, ": ", err)
|
||||
return false
|
||||
}
|
||||
|
||||
baseCredentials := cloneCredentials(currentCredentials)
|
||||
refreshResult, retryDelay, err := refreshToken(c.serviceContext, c.forwardHTTPClient, currentCredentials)
|
||||
if err != nil {
|
||||
if retryDelay != 0 {
|
||||
c.logger.Debug("refresh token for ", c.tag, ": retry delay=", retryDelay, ", error=", err)
|
||||
} else {
|
||||
c.logger.Debug("refresh token for ", c.tag, ": ", err)
|
||||
}
|
||||
latestCredentials, readErr := platformReadCredentials(c.credentialPath)
|
||||
if readErr == nil && latestCredentials != nil {
|
||||
c.absorbCredentials(latestCredentials)
|
||||
return latestCredentials.AccessToken != "" && (latestCredentials.AccessToken != baseCredentials.AccessToken || !latestCredentials.needsRefresh())
|
||||
}
|
||||
return false
|
||||
}
|
||||
if refreshResult == nil || refreshResult.Credentials == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
refreshedCredentials := cloneCredentials(refreshResult.Credentials)
|
||||
c.absorbCredentials(refreshedCredentials)
|
||||
c.persistCredentials(refreshedCredentials)
|
||||
|
||||
if refreshResult.TokenAccount != nil {
|
||||
c.absorbOAuthAccount(refreshResult.TokenAccount)
|
||||
c.persistOAuthAccount()
|
||||
}
|
||||
if c.needsProfileHydration() {
|
||||
profileSnapshot, profileErr := c.fetchProfileSnapshot(c.forwardHTTPClient, refreshedCredentials.AccessToken)
|
||||
if profileErr != nil {
|
||||
c.logger.Debug("fetch profile for ", c.tag, ": ", profileErr)
|
||||
} else if profileSnapshot != nil {
|
||||
credentialsChanged := c.applyProfileSnapshot(profileSnapshot)
|
||||
c.persistOAuthAccount()
|
||||
if credentialsChanged {
|
||||
c.persistCredentials(c.currentCredentials())
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *defaultCredential) recoverAuthFailure(failedAccessToken string) bool {
|
||||
latestCredentials, err := platformReadCredentials(c.credentialPath)
|
||||
if err == nil && latestCredentials != nil {
|
||||
c.absorbCredentials(latestCredentials)
|
||||
if latestCredentials.AccessToken != "" && latestCredentials.AccessToken != failedAccessToken {
|
||||
return true
|
||||
}
|
||||
}
|
||||
c.tryRefreshCredentials(true)
|
||||
currentCredentials := c.currentCredentials()
|
||||
return currentCredentials != nil && currentCredentials.AccessToken != "" && currentCredentials.AccessToken != failedAccessToken
|
||||
}
|
||||
|
||||
func (c *defaultCredential) applyProfileSnapshot(snapshot *claudeProfileSnapshot) bool {
|
||||
if snapshot == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
credentialsChanged := false
|
||||
c.access.Lock()
|
||||
if c.credentials != nil {
|
||||
updatedCredentials := cloneCredentials(c.credentials)
|
||||
if snapshot.SubscriptionType != nil {
|
||||
updatedCredentials.SubscriptionType = cloneStringPointer(snapshot.SubscriptionType)
|
||||
}
|
||||
if snapshot.RateLimitTier != "" {
|
||||
updatedCredentials.RateLimitTier = cloneStringPointer(&snapshot.RateLimitTier)
|
||||
}
|
||||
credentialsChanged = !credentialsEqual(c.credentials, updatedCredentials)
|
||||
c.credentials = updatedCredentials
|
||||
}
|
||||
c.access.Unlock()
|
||||
|
||||
c.stateAccess.Lock()
|
||||
before := c.statusSnapshotLocked()
|
||||
if snapshot.OAuthAccount != nil {
|
||||
c.state.oauthAccount = mergeClaudeOAuthAccount(c.state.oauthAccount, snapshot.OAuthAccount)
|
||||
if c.state.oauthAccount != nil && c.state.oauthAccount.AccountUUID != "" {
|
||||
c.state.accountUUID = c.state.oauthAccount.AccountUUID
|
||||
}
|
||||
}
|
||||
if snapshot.AccountType != "" {
|
||||
c.state.accountType = snapshot.AccountType
|
||||
}
|
||||
if snapshot.RateLimitTier != "" {
|
||||
c.state.rateLimitTier = snapshot.RateLimitTier
|
||||
}
|
||||
c.checkTransitionLocked()
|
||||
shouldEmit := before != c.statusSnapshotLocked()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldEmit {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
return credentialsChanged
|
||||
}
|
||||
|
||||
func (c *defaultCredential) fetchProfileSnapshot(httpClient *http.Client, accessToken string) (*claudeProfileSnapshot, error) {
|
||||
ctx := c.serviceContext
|
||||
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/profile", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request.Header.Set("User-Agent", ccmUserAgentValue)
|
||||
return request, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return nil, E.New("status ", response.StatusCode, " ", string(body))
|
||||
}
|
||||
|
||||
var profileResponse struct {
|
||||
Account *struct {
|
||||
UUID string `json:"uuid"`
|
||||
Email string `json:"email"`
|
||||
DisplayName string `json:"display_name"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
} `json:"account"`
|
||||
Organization *struct {
|
||||
UUID string `json:"uuid"`
|
||||
OrganizationType string `json:"organization_type"`
|
||||
RateLimitTier string `json:"rate_limit_tier"`
|
||||
HasExtraUsageEnabled *bool `json:"has_extra_usage_enabled"`
|
||||
BillingType *string `json:"billing_type"`
|
||||
SubscriptionCreatedAt *string `json:"subscription_created_at"`
|
||||
} `json:"organization"`
|
||||
}
|
||||
if err := json.NewDecoder(response.Body).Decode(&profileResponse); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if profileResponse.Organization == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
accountType := normalizeClaudeOrganizationType(profileResponse.Organization.OrganizationType)
|
||||
snapshot := &claudeProfileSnapshot{
|
||||
AccountType: accountType,
|
||||
RateLimitTier: profileResponse.Organization.RateLimitTier,
|
||||
}
|
||||
if accountType != "" {
|
||||
snapshot.SubscriptionType = cloneStringPointer(&accountType)
|
||||
}
|
||||
account := &claudeOAuthAccount{}
|
||||
if profileResponse.Account != nil {
|
||||
account.AccountUUID = profileResponse.Account.UUID
|
||||
account.EmailAddress = profileResponse.Account.Email
|
||||
account.DisplayName = optionalStringPointer(profileResponse.Account.DisplayName)
|
||||
account.AccountCreatedAt = optionalStringPointer(profileResponse.Account.CreatedAt)
|
||||
}
|
||||
account.OrganizationUUID = profileResponse.Organization.UUID
|
||||
account.HasExtraUsageEnabled = cloneBoolPointer(profileResponse.Organization.HasExtraUsageEnabled)
|
||||
account.BillingType = cloneStringPointer(profileResponse.Organization.BillingType)
|
||||
account.SubscriptionCreatedAt = cloneStringPointer(profileResponse.Organization.SubscriptionCreatedAt)
|
||||
if account.AccountUUID != "" || account.EmailAddress != "" || account.OrganizationUUID != "" || account.DisplayName != nil ||
|
||||
account.HasExtraUsageEnabled != nil || account.BillingType != nil || account.AccountCreatedAt != nil || account.SubscriptionCreatedAt != nil {
|
||||
snapshot.OAuthAccount = account
|
||||
}
|
||||
return snapshot, nil
|
||||
}
|
||||
|
||||
func normalizeClaudeOrganizationType(organizationType string) string {
|
||||
switch organizationType {
|
||||
case "claude_pro":
|
||||
return "pro"
|
||||
case "claude_max":
|
||||
return "max"
|
||||
case "claude_team":
|
||||
return "team"
|
||||
case "claude_enterprise":
|
||||
return "enterprise"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func optionalStringPointer(value string) *string {
|
||||
if value == "" {
|
||||
return nil
|
||||
}
|
||||
return &value
|
||||
}
|
||||
|
||||
func (c *defaultCredential) updateStateFromHeaders(headers http.Header) {
|
||||
@@ -727,7 +951,7 @@ func (c *defaultCredential) pollUsage() {
|
||||
}
|
||||
c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
|
||||
}
|
||||
needsProfileFetch := !c.configLoaded && c.state.rateLimitTier == ""
|
||||
needsProfileFetch := c.needsProfileHydrationLocked()
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldInterrupt {
|
||||
@@ -736,83 +960,19 @@ func (c *defaultCredential) pollUsage() {
|
||||
c.emitStatusUpdate()
|
||||
|
||||
if needsProfileFetch {
|
||||
c.fetchProfile(httpClient, accessToken)
|
||||
}
|
||||
}
|
||||
|
||||
// fetchProfile calls GET /api/oauth/profile to retrieve account and organization info.
|
||||
// Same endpoint used by Claude Code (@anthropic-ai/claude-code @2.1.81):
|
||||
//
|
||||
// ref: cli.js GB() — fetches profile
|
||||
// ref: cli.js AH8() / fetchProfileInfo — parses organization_type, rate_limit_tier
|
||||
// ref: cli.js EX1() / populateOAuthAccountInfoIfNeeded — stores account.uuid
|
||||
func (c *defaultCredential) fetchProfile(httpClient *http.Client, accessToken string) {
|
||||
ctx := c.serviceContext
|
||||
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/profile", nil)
|
||||
profileSnapshot, err := c.fetchProfileSnapshot(httpClient, accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
c.logger.Debug("fetch profile for ", c.tag, ": ", err)
|
||||
return
|
||||
}
|
||||
if profileSnapshot != nil {
|
||||
credentialsChanged := c.applyProfileSnapshot(profileSnapshot)
|
||||
c.persistOAuthAccount()
|
||||
if credentialsChanged {
|
||||
c.persistCredentials(c.currentCredentials())
|
||||
}
|
||||
}
|
||||
request.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request.Header.Set("User-Agent", ccmUserAgentValue)
|
||||
return request, nil
|
||||
})
|
||||
if err != nil {
|
||||
c.logger.Debug("fetch profile for ", c.tag, ": ", err)
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
return
|
||||
}
|
||||
|
||||
var profileResponse struct {
|
||||
Account *struct {
|
||||
UUID string `json:"uuid"`
|
||||
} `json:"account"`
|
||||
Organization *struct {
|
||||
OrganizationType string `json:"organization_type"`
|
||||
RateLimitTier string `json:"rate_limit_tier"`
|
||||
} `json:"organization"`
|
||||
}
|
||||
err = json.NewDecoder(response.Body).Decode(&profileResponse)
|
||||
if err != nil || profileResponse.Organization == nil {
|
||||
return
|
||||
}
|
||||
|
||||
accountType := ""
|
||||
switch profileResponse.Organization.OrganizationType {
|
||||
case "claude_pro":
|
||||
accountType = "pro"
|
||||
case "claude_max":
|
||||
accountType = "max"
|
||||
case "claude_team":
|
||||
accountType = "team"
|
||||
case "claude_enterprise":
|
||||
accountType = "enterprise"
|
||||
}
|
||||
rateLimitTier := profileResponse.Organization.RateLimitTier
|
||||
|
||||
c.stateAccess.Lock()
|
||||
before := c.statusSnapshotLocked()
|
||||
if profileResponse.Account != nil && profileResponse.Account.UUID != "" {
|
||||
c.state.accountUUID = profileResponse.Account.UUID
|
||||
}
|
||||
if accountType != "" && c.state.accountType == "" {
|
||||
c.state.accountType = accountType
|
||||
}
|
||||
if rateLimitTier != "" {
|
||||
c.state.rateLimitTier = rateLimitTier
|
||||
}
|
||||
resolvedAccountType := c.state.accountType
|
||||
shouldEmit := before != c.statusSnapshotLocked()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldEmit {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
c.logger.Info("fetched profile for ", c.tag, ": type=", resolvedAccountType, ", tier=", rateLimitTier, ", weight=", ccmPlanWeight(resolvedAccountType, rateLimitTier))
|
||||
}
|
||||
|
||||
func (c *defaultCredential) close() {
|
||||
|
||||
205
service/ccm/credential_default_test.go
Normal file
205
service/ccm/credential_default_test.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGetAccessTokenReturnsExistingTokenWhenLockFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
directory := t.TempDir()
|
||||
credentialPath := filepath.Join(directory, ".credentials.json")
|
||||
writeTestCredentials(t, credentialPath, &oauthCredentials{
|
||||
AccessToken: "old-token",
|
||||
RefreshToken: "refresh-token",
|
||||
ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(),
|
||||
Scopes: []string{"user:profile", "user:inference"},
|
||||
SubscriptionType: optionalStringPointer("max"),
|
||||
RateLimitTier: optionalStringPointer("default_claude_max_20x"),
|
||||
})
|
||||
|
||||
credential := newTestDefaultCredential(t, credentialPath, roundTripFunc(func(request *http.Request) (*http.Response, error) {
|
||||
t.Fatal("refresh should not be attempted when lock acquisition fails")
|
||||
return nil, nil
|
||||
}))
|
||||
if err := credential.reloadCredentials(true); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
originalLockFunc := acquireCredentialLockFunc
|
||||
acquireCredentialLockFunc = func(string) (func(), error) {
|
||||
return nil, errors.New("locked")
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
acquireCredentialLockFunc = originalLockFunc
|
||||
})
|
||||
|
||||
token, err := credential.getAccessToken()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if token != "old-token" {
|
||||
t.Fatalf("expected old token, got %q", token)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccessTokenAbsorbsRefreshDoneByAnotherProcess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
directory := t.TempDir()
|
||||
credentialPath := filepath.Join(directory, ".credentials.json")
|
||||
oldCredentials := &oauthCredentials{
|
||||
AccessToken: "old-token",
|
||||
RefreshToken: "refresh-token",
|
||||
ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(),
|
||||
Scopes: []string{"user:profile", "user:inference"},
|
||||
SubscriptionType: optionalStringPointer("max"),
|
||||
RateLimitTier: optionalStringPointer("default_claude_max_20x"),
|
||||
}
|
||||
writeTestCredentials(t, credentialPath, oldCredentials)
|
||||
|
||||
newCredentials := cloneCredentials(oldCredentials)
|
||||
newCredentials.AccessToken = "new-token"
|
||||
newCredentials.ExpiresAt = time.Now().Add(time.Hour).UnixMilli()
|
||||
transport := roundTripFunc(func(request *http.Request) (*http.Response, error) {
|
||||
if request.URL.Path == "/v1/oauth/token" {
|
||||
writeTestCredentials(t, credentialPath, newCredentials)
|
||||
return newJSONResponse(http.StatusInternalServerError, `{"error":"boom"}`), nil
|
||||
}
|
||||
t.Fatalf("unexpected path %s", request.URL.Path)
|
||||
return nil, nil
|
||||
})
|
||||
|
||||
credential := newTestDefaultCredential(t, credentialPath, transport)
|
||||
if err := credential.reloadCredentials(true); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
token, err := credential.getAccessToken()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if token != "new-token" {
|
||||
t.Fatalf("expected refreshed token from disk, got %q", token)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomCredentialPathDoesNotEnableClaudeConfigSync(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
directory := t.TempDir()
|
||||
credentialPath := filepath.Join(directory, ".credentials.json")
|
||||
writeTestCredentials(t, credentialPath, &oauthCredentials{
|
||||
AccessToken: "token",
|
||||
ExpiresAt: time.Now().Add(time.Hour).UnixMilli(),
|
||||
Scopes: []string{"user:profile"},
|
||||
})
|
||||
|
||||
credential := newTestDefaultCredential(t, credentialPath, roundTripFunc(func(request *http.Request) (*http.Response, error) {
|
||||
t.Fatalf("unexpected request to %s", request.URL.Path)
|
||||
return nil, nil
|
||||
}))
|
||||
if err := credential.reloadCredentials(true); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
token, err := credential.getAccessToken()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if token != "token" {
|
||||
t.Fatalf("expected token, got %q", token)
|
||||
}
|
||||
if credential.shouldUseClaudeConfig() {
|
||||
t.Fatal("custom credential path should not enable Claude config sync")
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(directory, ".claude.json")); !os.IsNotExist(err) {
|
||||
t.Fatalf("did not expect config file to be created, stat err=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultCredentialHydratesProfileAndWritesConfig(t *testing.T) {
|
||||
configDir := t.TempDir()
|
||||
credentialPath := filepath.Join(configDir, ".credentials.json")
|
||||
|
||||
writeTestCredentials(t, credentialPath, &oauthCredentials{
|
||||
AccessToken: "old-token",
|
||||
RefreshToken: "refresh-token",
|
||||
ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(),
|
||||
Scopes: []string{"user:profile", "user:inference"},
|
||||
})
|
||||
|
||||
transport := roundTripFunc(func(request *http.Request) (*http.Response, error) {
|
||||
switch request.URL.Path {
|
||||
case "/v1/oauth/token":
|
||||
return newJSONResponse(http.StatusOK, `{
|
||||
"access_token":"new-token",
|
||||
"refresh_token":"new-refresh",
|
||||
"expires_in":3600,
|
||||
"account":{"uuid":"account","email_address":"user@example.com"},
|
||||
"organization":{"uuid":"org"}
|
||||
}`), nil
|
||||
case "/api/oauth/profile":
|
||||
return newJSONResponse(http.StatusOK, `{
|
||||
"account":{
|
||||
"uuid":"account",
|
||||
"email":"user@example.com",
|
||||
"display_name":"User",
|
||||
"created_at":"2024-01-01T00:00:00Z"
|
||||
},
|
||||
"organization":{
|
||||
"uuid":"org",
|
||||
"organization_type":"claude_max",
|
||||
"rate_limit_tier":"default_claude_max_20x",
|
||||
"has_extra_usage_enabled":true,
|
||||
"billing_type":"individual",
|
||||
"subscription_created_at":"2024-01-02T00:00:00Z"
|
||||
}
|
||||
}`), nil
|
||||
default:
|
||||
t.Fatalf("unexpected path %s", request.URL.Path)
|
||||
return nil, nil
|
||||
}
|
||||
})
|
||||
|
||||
credential := newTestDefaultCredential(t, credentialPath, transport)
|
||||
credential.syncClaudeConfig = true
|
||||
credential.claudeDirectory = configDir
|
||||
credential.claudeConfigPath = resolveClaudeConfigWritePath(configDir)
|
||||
if err := credential.reloadCredentials(true); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
token, err := credential.getAccessToken()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if token != "new-token" {
|
||||
t.Fatalf("expected refreshed token, got %q", token)
|
||||
}
|
||||
|
||||
updatedCredentials := readTestCredentials(t, credentialPath)
|
||||
if updatedCredentials.SubscriptionType == nil || *updatedCredentials.SubscriptionType != "max" {
|
||||
t.Fatalf("expected subscription type to be persisted, got %#v", updatedCredentials.SubscriptionType)
|
||||
}
|
||||
if updatedCredentials.RateLimitTier == nil || *updatedCredentials.RateLimitTier != "default_claude_max_20x" {
|
||||
t.Fatalf("expected rate limit tier to be persisted, got %#v", updatedCredentials.RateLimitTier)
|
||||
}
|
||||
|
||||
configPath := tempConfigPath(t, configDir)
|
||||
config, err := readClaudeCodeConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if config.OAuthAccount == nil || config.OAuthAccount.AccountUUID != "account" || config.OAuthAccount.EmailAddress != "user@example.com" {
|
||||
t.Fatalf("unexpected oauth account: %#v", config.OAuthAccount)
|
||||
}
|
||||
if config.OAuthAccount.BillingType == nil || *config.OAuthAccount.BillingType != "individual" {
|
||||
t.Fatalf("expected billing type to be hydrated, got %#v", config.OAuthAccount.BillingType)
|
||||
}
|
||||
}
|
||||
@@ -106,24 +106,7 @@ func (c *defaultCredential) reloadCredentials(force bool) error {
|
||||
return c.markCredentialsUnavailable(E.Cause(err, "read credentials"))
|
||||
}
|
||||
|
||||
c.access.Lock()
|
||||
c.credentials = credentials
|
||||
c.refreshRetryAt = time.Time{}
|
||||
c.refreshRetryError = nil
|
||||
c.refreshBlocked = false
|
||||
c.access.Unlock()
|
||||
|
||||
c.stateAccess.Lock()
|
||||
before := c.statusSnapshotLocked()
|
||||
c.state.unavailable = false
|
||||
c.state.lastCredentialLoadError = ""
|
||||
c.checkTransitionLocked()
|
||||
shouldEmit := before != c.statusSnapshotLocked()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldEmit {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
|
||||
c.absorbCredentials(credentials)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -164,21 +164,7 @@ func checkCredentialFileWritable(path string) error {
|
||||
// ref (@anthropic-ai/claude-code @2.1.81): cli.js BP6 (line 179444-179454) — read-modify-write
|
||||
// ref: cli.js qD1.update (line 176156) — writeFileSync + chmod 0o600
|
||||
func writeCredentialsToFile(credentials *oauthCredentials, path string) error {
|
||||
existing := make(map[string]json.RawMessage)
|
||||
data, readErr := os.ReadFile(path)
|
||||
if readErr == nil {
|
||||
_ = json.Unmarshal(data, &existing)
|
||||
}
|
||||
credentialData, err := json.Marshal(credentials)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
existing["claudeAiOauth"] = credentialData
|
||||
data, err = json.MarshalIndent(existing, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(path, data, 0o600)
|
||||
return writeStorageValue(jsonFileStorage{path: path}, "claudeAiOauth", credentials)
|
||||
}
|
||||
|
||||
// oauthCredentials mirrors the claudeAiOauth object in Claude Code's
|
||||
@@ -194,6 +180,12 @@ type oauthCredentials struct {
|
||||
RateLimitTier *string `json:"rateLimitTier"` // ref: cli.js line 179452 (?? null)
|
||||
}
|
||||
|
||||
type oauthRefreshResult struct {
|
||||
Credentials *oauthCredentials
|
||||
TokenAccount *claudeOAuthAccount
|
||||
Profile *claudeProfileSnapshot
|
||||
}
|
||||
|
||||
func (c *oauthCredentials) needsRefresh() bool {
|
||||
if c.ExpiresAt == 0 {
|
||||
return false
|
||||
@@ -201,7 +193,7 @@ func (c *oauthCredentials) needsRefresh() bool {
|
||||
return time.Now().UnixMilli() >= c.ExpiresAt-tokenRefreshBufferMs
|
||||
}
|
||||
|
||||
func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, time.Duration, error) {
|
||||
func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthRefreshResult, time.Duration, error) {
|
||||
if credentials.RefreshToken == "" {
|
||||
return nil, 0, E.New("refresh token is empty")
|
||||
}
|
||||
@@ -249,10 +241,17 @@ func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oau
|
||||
|
||||
// ref (@anthropic-ai/claude-code @2.1.81): cli.js mB6 response (line 172769-172772)
|
||||
var tokenResponse struct {
|
||||
AccessToken string `json:"access_token"` // ref: cli.js line 172770 z
|
||||
RefreshToken string `json:"refresh_token"` // ref: cli.js line 172770 w (defaults to input)
|
||||
ExpiresIn int `json:"expires_in"` // ref: cli.js line 172770 O
|
||||
Scope string `json:"scope"` // ref: cli.js line 172772 uB6(Y.scope)
|
||||
AccessToken string `json:"access_token"` // ref: cli.js line 172770 z
|
||||
RefreshToken string `json:"refresh_token"` // ref: cli.js line 172770 w (defaults to input)
|
||||
ExpiresIn int `json:"expires_in"` // ref: cli.js line 172770 O
|
||||
Scope *string `json:"scope"` // ref: cli.js line 172772 uB6(Y.scope)
|
||||
Account *struct {
|
||||
UUID string `json:"uuid"`
|
||||
EmailAddress string `json:"email_address"`
|
||||
} `json:"account"`
|
||||
Organization *struct {
|
||||
UUID string `json:"uuid"`
|
||||
} `json:"organization"`
|
||||
}
|
||||
err = json.NewDecoder(response.Body).Decode(&tokenResponse)
|
||||
if err != nil {
|
||||
@@ -267,11 +266,14 @@ func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oau
|
||||
newCredentials.ExpiresAt = time.Now().UnixMilli() + int64(tokenResponse.ExpiresIn)*1000
|
||||
// ref: cli.js uB6 (line 172696-172697): A?.split(" ").filter(Boolean)
|
||||
// strings.Fields matches .filter(Boolean): splits on whitespace runs, removes empty strings
|
||||
if tokenResponse.Scope != "" {
|
||||
newCredentials.Scopes = strings.Fields(tokenResponse.Scope)
|
||||
if tokenResponse.Scope != nil {
|
||||
newCredentials.Scopes = strings.Fields(*tokenResponse.Scope)
|
||||
}
|
||||
|
||||
return &newCredentials, 0, nil
|
||||
return &oauthRefreshResult{
|
||||
Credentials: &newCredentials,
|
||||
TokenAccount: extractTokenAccount(tokenResponse.Account, tokenResponse.Organization),
|
||||
}, 0, nil
|
||||
}
|
||||
|
||||
func cloneCredentials(credentials *oauthCredentials) *oauthCredentials {
|
||||
@@ -280,6 +282,8 @@ func cloneCredentials(credentials *oauthCredentials) *oauthCredentials {
|
||||
}
|
||||
cloned := *credentials
|
||||
cloned.Scopes = append([]string(nil), credentials.Scopes...)
|
||||
cloned.SubscriptionType = cloneStringPointer(credentials.SubscriptionType)
|
||||
cloned.RateLimitTier = cloneStringPointer(credentials.RateLimitTier)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
@@ -290,5 +294,31 @@ func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool {
|
||||
return left.AccessToken == right.AccessToken &&
|
||||
left.RefreshToken == right.RefreshToken &&
|
||||
left.ExpiresAt == right.ExpiresAt &&
|
||||
slices.Equal(left.Scopes, right.Scopes)
|
||||
slices.Equal(left.Scopes, right.Scopes) &&
|
||||
equalStringPointer(left.SubscriptionType, right.SubscriptionType) &&
|
||||
equalStringPointer(left.RateLimitTier, right.RateLimitTier)
|
||||
}
|
||||
|
||||
func extractTokenAccount(account *struct {
|
||||
UUID string `json:"uuid"`
|
||||
EmailAddress string `json:"email_address"`
|
||||
}, organization *struct {
|
||||
UUID string `json:"uuid"`
|
||||
},
|
||||
) *claudeOAuthAccount {
|
||||
if account == nil && organization == nil {
|
||||
return nil
|
||||
}
|
||||
tokenAccount := &claudeOAuthAccount{}
|
||||
if account != nil {
|
||||
tokenAccount.AccountUUID = account.UUID
|
||||
tokenAccount.EmailAddress = account.EmailAddress
|
||||
}
|
||||
if organization != nil {
|
||||
tokenAccount.OrganizationUUID = organization.UUID
|
||||
}
|
||||
if tokenAccount.AccountUUID == "" && tokenAccount.EmailAddress == "" && tokenAccount.OrganizationUUID == "" {
|
||||
return nil
|
||||
}
|
||||
return tokenAccount
|
||||
}
|
||||
|
||||
141
service/ccm/credential_oauth_test.go
Normal file
141
service/ccm/credential_oauth_test.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRefreshTokenScopeParsing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
storedScopes []string
|
||||
responseBody string
|
||||
expectedScope string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "missing scope preserves stored scopes",
|
||||
storedScopes: []string{"user:profile", "user:inference"},
|
||||
responseBody: `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`,
|
||||
expectedScope: strings.Join(defaultOAuthScopes, " "),
|
||||
expected: []string{"user:profile", "user:inference"},
|
||||
},
|
||||
{
|
||||
name: "empty scope clears stored scopes",
|
||||
storedScopes: []string{"user:profile", "user:inference"},
|
||||
responseBody: `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600,"scope":""}`,
|
||||
expectedScope: strings.Join(defaultOAuthScopes, " "),
|
||||
expected: []string{},
|
||||
},
|
||||
{
|
||||
name: "stored non inference scopes are sent verbatim",
|
||||
storedScopes: []string{"user:profile"},
|
||||
responseBody: `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600,"scope":"user:profile user:file_upload"}`,
|
||||
expectedScope: "user:profile",
|
||||
expected: []string{"user:profile", "user:file_upload"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var seenScope string
|
||||
client := &http.Client{Transport: roundTripFunc(func(request *http.Request) (*http.Response, error) {
|
||||
body, err := io.ReadAll(request.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var payload map[string]string
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
seenScope = payload["scope"]
|
||||
return newJSONResponse(http.StatusOK, testCase.responseBody), nil
|
||||
})}
|
||||
|
||||
result, _, err := refreshToken(context.Background(), client, &oauthCredentials{
|
||||
AccessToken: "old-token",
|
||||
RefreshToken: "refresh-token",
|
||||
ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(),
|
||||
Scopes: testCase.storedScopes,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if seenScope != testCase.expectedScope {
|
||||
t.Fatalf("expected request scope %q, got %q", testCase.expectedScope, seenScope)
|
||||
}
|
||||
if result == nil || result.Credentials == nil {
|
||||
t.Fatal("expected refresh result credentials")
|
||||
}
|
||||
if !slices.Equal(result.Credentials.Scopes, testCase.expected) {
|
||||
t.Fatalf("expected scopes %v, got %v", testCase.expected, result.Credentials.Scopes)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshTokenExtractsTokenAccount(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := &http.Client{Transport: roundTripFunc(func(request *http.Request) (*http.Response, error) {
|
||||
return newJSONResponse(http.StatusOK, `{
|
||||
"access_token":"new-token",
|
||||
"refresh_token":"new-refresh",
|
||||
"expires_in":3600,
|
||||
"account":{"uuid":"account","email_address":"user@example.com"},
|
||||
"organization":{"uuid":"org"}
|
||||
}`), nil
|
||||
})}
|
||||
|
||||
result, _, err := refreshToken(context.Background(), client, &oauthCredentials{
|
||||
AccessToken: "old-token",
|
||||
RefreshToken: "refresh-token",
|
||||
ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(),
|
||||
Scopes: []string{"user:profile", "user:inference"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result == nil || result.TokenAccount == nil {
|
||||
t.Fatal("expected token account")
|
||||
}
|
||||
if result.TokenAccount.AccountUUID != "account" || result.TokenAccount.EmailAddress != "user@example.com" || result.TokenAccount.OrganizationUUID != "org" {
|
||||
t.Fatalf("unexpected token account: %#v", result.TokenAccount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCredentialsEqualIncludesProfileFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
subscriptionType := "max"
|
||||
rateLimitTier := "default_claude_max_20x"
|
||||
left := &oauthCredentials{
|
||||
AccessToken: "token",
|
||||
RefreshToken: "refresh",
|
||||
ExpiresAt: 123,
|
||||
Scopes: []string{"user:inference"},
|
||||
SubscriptionType: &subscriptionType,
|
||||
RateLimitTier: &rateLimitTier,
|
||||
}
|
||||
right := cloneCredentials(left)
|
||||
if !credentialsEqual(left, right) {
|
||||
t.Fatal("expected cloned credentials to be equal")
|
||||
}
|
||||
|
||||
otherTier := "default_claude_max_5x"
|
||||
right.RateLimitTier = &otherTier
|
||||
if credentialsEqual(left, right) {
|
||||
t.Fatal("expected different rate limit tier to break equality")
|
||||
}
|
||||
}
|
||||
124
service/ccm/credential_storage.go
Normal file
124
service/ccm/credential_storage.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
type jsonContainerStorage interface {
|
||||
readContainer() (map[string]json.RawMessage, bool, error)
|
||||
writeContainer(map[string]json.RawMessage) error
|
||||
delete() error
|
||||
}
|
||||
|
||||
type jsonFileStorage struct {
|
||||
path string
|
||||
}
|
||||
|
||||
func (s jsonFileStorage) readContainer() (map[string]json.RawMessage, bool, error) {
|
||||
data, err := os.ReadFile(s.path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return make(map[string]json.RawMessage), false, nil
|
||||
}
|
||||
return nil, false, err
|
||||
}
|
||||
container := make(map[string]json.RawMessage)
|
||||
if len(data) == 0 {
|
||||
return container, true, nil
|
||||
}
|
||||
if err := json.Unmarshal(data, &container); err != nil {
|
||||
return nil, true, err
|
||||
}
|
||||
return container, true, nil
|
||||
}
|
||||
|
||||
func (s jsonFileStorage) writeContainer(container map[string]json.RawMessage) error {
|
||||
if err := os.MkdirAll(filepath.Dir(s.path), 0o700); err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := json.MarshalIndent(container, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(s.path, data, 0o600)
|
||||
}
|
||||
|
||||
func (s jsonFileStorage) delete() error {
|
||||
err := os.Remove(s.path)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeStorageValue(storage jsonContainerStorage, key string, value any) error {
|
||||
container, _, err := storage.readContainer()
|
||||
if err != nil {
|
||||
var syntaxError *json.SyntaxError
|
||||
var typeError *json.UnmarshalTypeError
|
||||
if !errors.As(err, &syntaxError) && !errors.As(err, &typeError) {
|
||||
return err
|
||||
}
|
||||
container = make(map[string]json.RawMessage)
|
||||
}
|
||||
if container == nil {
|
||||
container = make(map[string]json.RawMessage)
|
||||
}
|
||||
encodedValue, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
container[key] = encodedValue
|
||||
return storage.writeContainer(container)
|
||||
}
|
||||
|
||||
func persistStorageValue(primary jsonContainerStorage, fallback jsonContainerStorage, key string, value any) error {
|
||||
primaryErr := writeStorageValue(primary, key, value)
|
||||
if primaryErr == nil {
|
||||
if fallback != nil {
|
||||
_ = fallback.delete()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if fallback == nil {
|
||||
return primaryErr
|
||||
}
|
||||
if err := writeStorageValue(fallback, key, value); err != nil {
|
||||
return err
|
||||
}
|
||||
_ = primary.delete()
|
||||
return nil
|
||||
}
|
||||
|
||||
func cloneStringPointer(value *string) *string {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *value
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func cloneBoolPointer(value *bool) *bool {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *value
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func equalStringPointer(left *string, right *string) bool {
|
||||
if left == nil || right == nil {
|
||||
return left == right
|
||||
}
|
||||
return *left == *right
|
||||
}
|
||||
|
||||
func equalBoolPointer(left *bool, right *bool) bool {
|
||||
if left == nil || right == nil {
|
||||
return left == right
|
||||
}
|
||||
return *left == *right
|
||||
}
|
||||
125
service/ccm/credential_storage_test.go
Normal file
125
service/ccm/credential_storage_test.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type fakeJSONStorage struct {
|
||||
container map[string]json.RawMessage
|
||||
writeErr error
|
||||
deleted bool
|
||||
}
|
||||
|
||||
func (s *fakeJSONStorage) readContainer() (map[string]json.RawMessage, bool, error) {
|
||||
if s.container == nil {
|
||||
return make(map[string]json.RawMessage), false, nil
|
||||
}
|
||||
cloned := make(map[string]json.RawMessage, len(s.container))
|
||||
for key, value := range s.container {
|
||||
cloned[key] = value
|
||||
}
|
||||
return cloned, true, nil
|
||||
}
|
||||
|
||||
func (s *fakeJSONStorage) writeContainer(container map[string]json.RawMessage) error {
|
||||
if s.writeErr != nil {
|
||||
return s.writeErr
|
||||
}
|
||||
s.container = make(map[string]json.RawMessage, len(container))
|
||||
for key, value := range container {
|
||||
s.container[key] = value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fakeJSONStorage) delete() error {
|
||||
s.deleted = true
|
||||
s.container = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestPersistStorageValueDeletesFallbackOnPrimarySuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
primary := &fakeJSONStorage{}
|
||||
fallback := &fakeJSONStorage{container: map[string]json.RawMessage{"stale": json.RawMessage(`true`)}}
|
||||
if err := persistStorageValue(primary, fallback, "claudeAiOauth", &oauthCredentials{AccessToken: "token"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !fallback.deleted {
|
||||
t.Fatal("expected fallback storage to be deleted after primary write")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPersistStorageValueDeletesPrimaryAfterFallbackSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
primary := &fakeJSONStorage{
|
||||
container: map[string]json.RawMessage{"claudeAiOauth": json.RawMessage(`{"accessToken":"old"}`)},
|
||||
writeErr: os.ErrPermission,
|
||||
}
|
||||
fallback := &fakeJSONStorage{}
|
||||
if err := persistStorageValue(primary, fallback, "claudeAiOauth", &oauthCredentials{AccessToken: "new"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !primary.deleted {
|
||||
t.Fatal("expected primary storage to be deleted after fallback write")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteCredentialsToFilePreservesTopLevelKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
directory := t.TempDir()
|
||||
path := filepath.Join(directory, ".credentials.json")
|
||||
initial := []byte(`{"keep":{"nested":true},"claudeAiOauth":{"accessToken":"old"}}`)
|
||||
if err := os.WriteFile(path, initial, 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := writeCredentialsToFile(&oauthCredentials{AccessToken: "new"}, path); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var container map[string]json.RawMessage
|
||||
if err := json.Unmarshal(data, &container); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, exists := container["keep"]; !exists {
|
||||
t.Fatal("expected unknown top-level key to be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteClaudeCodeOAuthAccountPreservesTopLevelKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
directory := t.TempDir()
|
||||
path := filepath.Join(directory, ".claude.json")
|
||||
initial := []byte(`{"keep":{"nested":true},"oauthAccount":{"accountUuid":"old"}}`)
|
||||
if err := os.WriteFile(path, initial, 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := writeClaudeCodeOAuthAccount(path, &claudeOAuthAccount{AccountUUID: "new"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var container map[string]json.RawMessage
|
||||
if err := json.Unmarshal(data, &container); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, exists := container["keep"]; !exists {
|
||||
t.Fatal("expected unknown config key to be preserved")
|
||||
}
|
||||
}
|
||||
@@ -377,8 +377,9 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if !selectedCredential.isExternal() && bodyBytes != nil &&
|
||||
(response.StatusCode == http.StatusUnauthorized || response.StatusCode == http.StatusForbidden) {
|
||||
shouldRetry := response.StatusCode == http.StatusUnauthorized
|
||||
var peekBody []byte
|
||||
if response.StatusCode == http.StatusForbidden {
|
||||
peekBody, _ := io.ReadAll(response.Body)
|
||||
peekBody, _ = io.ReadAll(response.Body)
|
||||
shouldRetry = strings.Contains(string(peekBody), "OAuth token has been revoked")
|
||||
if !shouldRetry {
|
||||
response.Body.Close()
|
||||
@@ -389,23 +390,33 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
if shouldRetry {
|
||||
response.Body.Close()
|
||||
s.logger.WarnContext(ctx, "upstream auth failure from ", selectedCredential.tagName(), ", reloading credentials and retrying")
|
||||
recovered := false
|
||||
if defaultCred, ok := selectedCredential.(*defaultCredential); ok {
|
||||
_ = defaultCred.reloadCredentials(true)
|
||||
failedAccessToken := ""
|
||||
currentCredentials := defaultCred.currentCredentials()
|
||||
if currentCredentials != nil {
|
||||
failedAccessToken = currentCredentials.AccessToken
|
||||
}
|
||||
s.logger.WarnContext(ctx, "upstream auth failure from ", selectedCredential.tagName(), ", reloading credentials and retrying")
|
||||
recovered = defaultCred.recoverAuthFailure(failedAccessToken)
|
||||
}
|
||||
retryRequest, buildErr := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||
if buildErr != nil {
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", E.Cause(buildErr, "rebuild request after auth recovery").Error())
|
||||
return
|
||||
if recovered {
|
||||
response.Body.Close()
|
||||
retryRequest, buildErr := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||
if buildErr != nil {
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", E.Cause(buildErr, "rebuild request after auth recovery").Error())
|
||||
return
|
||||
}
|
||||
retryResponse, retryErr := selectedCredential.httpClient().Do(retryRequest)
|
||||
if retryErr != nil {
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", E.Cause(retryErr, "retry request after auth recovery").Error())
|
||||
return
|
||||
}
|
||||
response = retryResponse
|
||||
defer retryResponse.Body.Close()
|
||||
} else if response.StatusCode == http.StatusForbidden {
|
||||
response.Body = io.NopCloser(bytes.NewReader(peekBody))
|
||||
}
|
||||
retryResponse, retryErr := selectedCredential.httpClient().Do(retryRequest)
|
||||
if retryErr != nil {
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", E.Cause(retryErr, "retry request after auth recovery").Error())
|
||||
return
|
||||
}
|
||||
response = retryResponse
|
||||
defer retryResponse.Body.Close()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
221
service/ccm/service_handler_test.go
Normal file
221
service/ccm/service_handler_test.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func newHandlerCredential(t *testing.T, transport http.RoundTripper) (*defaultCredential, string) {
|
||||
t.Helper()
|
||||
directory := t.TempDir()
|
||||
credentialPath := filepath.Join(directory, ".credentials.json")
|
||||
writeTestCredentials(t, credentialPath, &oauthCredentials{
|
||||
AccessToken: "old-token",
|
||||
RefreshToken: "refresh-token",
|
||||
ExpiresAt: time.Now().Add(time.Hour).UnixMilli(),
|
||||
Scopes: []string{"user:profile", "user:inference"},
|
||||
SubscriptionType: optionalStringPointer("max"),
|
||||
RateLimitTier: optionalStringPointer("default_claude_max_20x"),
|
||||
})
|
||||
credential := newTestDefaultCredential(t, credentialPath, transport)
|
||||
if err := credential.reloadCredentials(true); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
seedTestCredentialState(credential)
|
||||
return credential, credentialPath
|
||||
}
|
||||
|
||||
func TestServiceHandlerRecoversFrom401(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var messageRequests atomic.Int32
|
||||
var refreshRequests atomic.Int32
|
||||
credential, _ := newHandlerCredential(t, roundTripFunc(func(request *http.Request) (*http.Response, error) {
|
||||
switch request.URL.Path {
|
||||
case "/v1/messages":
|
||||
call := messageRequests.Add(1)
|
||||
switch request.Header.Get("Authorization") {
|
||||
case "Bearer old-token":
|
||||
if call != 1 {
|
||||
t.Fatalf("unexpected old-token call count %d", call)
|
||||
}
|
||||
return newTextResponse(http.StatusUnauthorized, "unauthorized"), nil
|
||||
case "Bearer new-token":
|
||||
return newJSONResponse(http.StatusOK, `{}`), nil
|
||||
default:
|
||||
t.Fatalf("unexpected authorization header %q", request.Header.Get("Authorization"))
|
||||
}
|
||||
case "/v1/oauth/token":
|
||||
refreshRequests.Add(1)
|
||||
return newJSONResponse(http.StatusOK, `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`), nil
|
||||
default:
|
||||
t.Fatalf("unexpected path %s", request.URL.Path)
|
||||
}
|
||||
return nil, nil
|
||||
}))
|
||||
|
||||
service := newTestService(credential)
|
||||
recorder := httptest.NewRecorder()
|
||||
service.ServeHTTP(recorder, newMessageRequest(`{"model":"claude","messages":[],"metadata":{"user_id":"{\"session_id\":\"session\"}"}}`))
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", recorder.Code, recorder.Body.String())
|
||||
}
|
||||
if messageRequests.Load() != 2 {
|
||||
t.Fatalf("expected two upstream message requests, got %d", messageRequests.Load())
|
||||
}
|
||||
if refreshRequests.Load() != 1 {
|
||||
t.Fatalf("expected one refresh request, got %d", refreshRequests.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceHandlerRecoversFromRevoked403(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var messageRequests atomic.Int32
|
||||
var refreshRequests atomic.Int32
|
||||
credential, _ := newHandlerCredential(t, roundTripFunc(func(request *http.Request) (*http.Response, error) {
|
||||
switch request.URL.Path {
|
||||
case "/v1/messages":
|
||||
messageRequests.Add(1)
|
||||
if request.Header.Get("Authorization") == "Bearer old-token" {
|
||||
return newTextResponse(http.StatusForbidden, "OAuth token has been revoked"), nil
|
||||
}
|
||||
return newJSONResponse(http.StatusOK, `{}`), nil
|
||||
case "/v1/oauth/token":
|
||||
refreshRequests.Add(1)
|
||||
return newJSONResponse(http.StatusOK, `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`), nil
|
||||
default:
|
||||
t.Fatalf("unexpected path %s", request.URL.Path)
|
||||
}
|
||||
return nil, nil
|
||||
}))
|
||||
|
||||
service := newTestService(credential)
|
||||
recorder := httptest.NewRecorder()
|
||||
service.ServeHTTP(recorder, newMessageRequest(`{"model":"claude","messages":[],"metadata":{"user_id":"{\"session_id\":\"session\"}"}}`))
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", recorder.Code, recorder.Body.String())
|
||||
}
|
||||
if refreshRequests.Load() != 1 {
|
||||
t.Fatalf("expected one refresh request, got %d", refreshRequests.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceHandlerDoesNotRecoverFromOrdinary403(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var refreshRequests atomic.Int32
|
||||
credential, _ := newHandlerCredential(t, roundTripFunc(func(request *http.Request) (*http.Response, error) {
|
||||
switch request.URL.Path {
|
||||
case "/v1/messages":
|
||||
return newTextResponse(http.StatusForbidden, "forbidden"), nil
|
||||
case "/v1/oauth/token":
|
||||
refreshRequests.Add(1)
|
||||
return newJSONResponse(http.StatusOK, `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`), nil
|
||||
default:
|
||||
t.Fatalf("unexpected path %s", request.URL.Path)
|
||||
}
|
||||
return nil, nil
|
||||
}))
|
||||
|
||||
service := newTestService(credential)
|
||||
recorder := httptest.NewRecorder()
|
||||
service.ServeHTTP(recorder, newMessageRequest(`{"model":"claude","messages":[],"metadata":{"user_id":"{\"session_id\":\"session\"}"}}`))
|
||||
|
||||
if recorder.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("expected 500, got %d", recorder.Code)
|
||||
}
|
||||
if refreshRequests.Load() != 0 {
|
||||
t.Fatalf("expected no refresh request, got %d", refreshRequests.Load())
|
||||
}
|
||||
if !strings.Contains(recorder.Body.String(), "forbidden") {
|
||||
t.Fatalf("expected forbidden body, got %s", recorder.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceHandlerUsesReloadedTokenBeforeRefreshing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var messageRequests atomic.Int32
|
||||
var refreshRequests atomic.Int32
|
||||
var credentialPath string
|
||||
var credential *defaultCredential
|
||||
credential, credentialPath = newHandlerCredential(t, roundTripFunc(func(request *http.Request) (*http.Response, error) {
|
||||
switch request.URL.Path {
|
||||
case "/v1/messages":
|
||||
call := messageRequests.Add(1)
|
||||
if request.Header.Get("Authorization") == "Bearer old-token" {
|
||||
updatedCredentials := readTestCredentials(t, credentialPath)
|
||||
updatedCredentials.AccessToken = "disk-token"
|
||||
updatedCredentials.ExpiresAt = time.Now().Add(time.Hour).UnixMilli()
|
||||
writeTestCredentials(t, credentialPath, updatedCredentials)
|
||||
if call != 1 {
|
||||
t.Fatalf("unexpected old-token call count %d", call)
|
||||
}
|
||||
return newTextResponse(http.StatusUnauthorized, "unauthorized"), nil
|
||||
}
|
||||
if request.Header.Get("Authorization") != "Bearer disk-token" {
|
||||
t.Fatalf("expected disk token retry, got %q", request.Header.Get("Authorization"))
|
||||
}
|
||||
return newJSONResponse(http.StatusOK, `{}`), nil
|
||||
case "/v1/oauth/token":
|
||||
refreshRequests.Add(1)
|
||||
return newJSONResponse(http.StatusOK, `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`), nil
|
||||
default:
|
||||
t.Fatalf("unexpected path %s", request.URL.Path)
|
||||
}
|
||||
return nil, nil
|
||||
}))
|
||||
|
||||
service := newTestService(credential)
|
||||
recorder := httptest.NewRecorder()
|
||||
service.ServeHTTP(recorder, newMessageRequest(`{"model":"claude","messages":[],"metadata":{"user_id":"{\"session_id\":\"session\"}"}}`))
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", recorder.Code, recorder.Body.String())
|
||||
}
|
||||
if refreshRequests.Load() != 0 {
|
||||
t.Fatalf("expected zero refresh requests, got %d", refreshRequests.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceHandlerRetriesAuthRecoveryOnlyOnce(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var messageRequests atomic.Int32
|
||||
var refreshRequests atomic.Int32
|
||||
credential, _ := newHandlerCredential(t, roundTripFunc(func(request *http.Request) (*http.Response, error) {
|
||||
switch request.URL.Path {
|
||||
case "/v1/messages":
|
||||
messageRequests.Add(1)
|
||||
return newTextResponse(http.StatusUnauthorized, "still unauthorized"), nil
|
||||
case "/v1/oauth/token":
|
||||
refreshRequests.Add(1)
|
||||
return newJSONResponse(http.StatusOK, `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`), nil
|
||||
default:
|
||||
t.Fatalf("unexpected path %s", request.URL.Path)
|
||||
}
|
||||
return nil, nil
|
||||
}))
|
||||
|
||||
service := newTestService(credential)
|
||||
recorder := httptest.NewRecorder()
|
||||
service.ServeHTTP(recorder, newMessageRequest(`{"model":"claude","messages":[],"metadata":{"user_id":"{\"session_id\":\"session\"}"}}`))
|
||||
|
||||
if recorder.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("expected 500, got %d", recorder.Code)
|
||||
}
|
||||
if messageRequests.Load() != 2 {
|
||||
t.Fatalf("expected exactly two upstream attempts, got %d", messageRequests.Load())
|
||||
}
|
||||
if refreshRequests.Load() != 1 {
|
||||
t.Fatalf("expected exactly one refresh request, got %d", refreshRequests.Load())
|
||||
}
|
||||
}
|
||||
138
service/ccm/test_helpers_test.go
Normal file
138
service/ccm/test_helpers_test.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
)
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||
return f(request)
|
||||
}
|
||||
|
||||
func newJSONResponse(statusCode int, body string) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: statusCode,
|
||||
Status: http.StatusText(statusCode),
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
}
|
||||
}
|
||||
|
||||
func newTextResponse(statusCode int, body string) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: statusCode,
|
||||
Status: http.StatusText(statusCode),
|
||||
Header: http.Header{"Content-Type": []string{"text/plain"}},
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
}
|
||||
}
|
||||
|
||||
func writeTestCredentials(t *testing.T, path string, credentials *oauthCredentials) {
|
||||
t.Helper()
|
||||
if path == "" {
|
||||
var err error
|
||||
path, err = getDefaultCredentialsPath()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
if err := writeCredentialsToFile(credentials, path); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func readTestCredentials(t *testing.T, path string) *oauthCredentials {
|
||||
t.Helper()
|
||||
if path == "" {
|
||||
var err error
|
||||
path, err = getDefaultCredentialsPath()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
credentials, err := readCredentialsFromFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return credentials
|
||||
}
|
||||
|
||||
func newTestDefaultCredential(t *testing.T, credentialPath string, transport http.RoundTripper) *defaultCredential {
|
||||
t.Helper()
|
||||
credentialFilePath, err := resolveCredentialFilePath(credentialPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
requestContext, cancelRequests := context.WithCancel(context.Background())
|
||||
credential := &defaultCredential{
|
||||
tag: "test",
|
||||
serviceContext: context.Background(),
|
||||
credentialPath: credentialPath,
|
||||
credentialFilePath: credentialFilePath,
|
||||
configDir: resolveConfigDir(credentialPath, credentialFilePath),
|
||||
syncClaudeConfig: credentialPath == "",
|
||||
cap5h: 99,
|
||||
capWeekly: 99,
|
||||
forwardHTTPClient: &http.Client{Transport: transport},
|
||||
logger: log.NewNOPFactory().Logger(),
|
||||
requestContext: requestContext,
|
||||
cancelRequests: cancelRequests,
|
||||
}
|
||||
if credential.syncClaudeConfig {
|
||||
credential.claudeDirectory = credential.configDir
|
||||
credential.claudeConfigPath = resolveClaudeConfigWritePath(credential.claudeDirectory)
|
||||
}
|
||||
credential.state.lastUpdated = time.Now()
|
||||
return credential
|
||||
}
|
||||
|
||||
func seedTestCredentialState(credential *defaultCredential) {
|
||||
billingType := "individual"
|
||||
accountCreatedAt := "2024-01-01T00:00:00Z"
|
||||
subscriptionCreatedAt := "2024-01-02T00:00:00Z"
|
||||
credential.stateAccess.Lock()
|
||||
credential.state.accountUUID = "account"
|
||||
credential.state.accountType = "max"
|
||||
credential.state.rateLimitTier = "default_claude_max_20x"
|
||||
credential.state.oauthAccount = &claudeOAuthAccount{
|
||||
AccountUUID: "account",
|
||||
EmailAddress: "user@example.com",
|
||||
OrganizationUUID: "org",
|
||||
BillingType: &billingType,
|
||||
AccountCreatedAt: &accountCreatedAt,
|
||||
SubscriptionCreatedAt: &subscriptionCreatedAt,
|
||||
}
|
||||
credential.stateAccess.Unlock()
|
||||
}
|
||||
|
||||
func newTestService(credential *defaultCredential) *Service {
|
||||
return &Service{
|
||||
logger: log.NewNOPFactory().Logger(),
|
||||
options: option.CCMServiceOptions{Credentials: []option.CCMCredential{{Tag: "default"}}},
|
||||
httpHeaders: make(http.Header),
|
||||
providers: map[string]credentialProvider{"default": &singleCredentialProvider{credential: credential}},
|
||||
sessionModels: make(map[sessionModelKey]time.Time),
|
||||
}
|
||||
}
|
||||
|
||||
func newMessageRequest(body string) *http.Request {
|
||||
request := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
return request
|
||||
}
|
||||
|
||||
func tempConfigPath(t *testing.T, dir string) string {
|
||||
t.Helper()
|
||||
return filepath.Join(dir, claudeCodeLegacyConfigFileName())
|
||||
}
|
||||
Reference in New Issue
Block a user