diff --git a/option/ccm.go b/option/ccm.go index dd55a4ba4..b4be72ea7 100644 --- a/option/ccm.go +++ b/option/ccm.go @@ -102,6 +102,7 @@ type CCMExternalCredentialOptions struct { Token string `json:"token"` Reverse bool `json:"reverse,omitempty"` Detour string `json:"detour,omitempty"` + PlanWeight float64 `json:"plan_weight,omitempty"` UsagesPath string `json:"usages_path,omitempty"` PollInterval badoption.Duration `json:"poll_interval,omitempty"` } diff --git a/option/ocm.go b/option/ocm.go index e508abae7..0f364821f 100644 --- a/option/ocm.go +++ b/option/ocm.go @@ -102,6 +102,7 @@ type OCMExternalCredentialOptions struct { Token string `json:"token"` Reverse bool `json:"reverse,omitempty"` Detour string `json:"detour,omitempty"` + PlanWeight float64 `json:"plan_weight,omitempty"` UsagesPath string `json:"usages_path,omitempty"` PollInterval badoption.Duration `json:"poll_interval,omitempty"` } diff --git a/service/ccm/credential.go b/service/ccm/credential.go index 75ae62f97..8bfd27c23 100644 --- a/service/ccm/credential.go +++ b/service/ccm/credential.go @@ -133,6 +133,7 @@ type oauthCredentials struct { ExpiresAt int64 `json:"expiresAt"` Scopes []string `json:"scopes,omitempty"` SubscriptionType string `json:"subscriptionType,omitempty"` + RateLimitTier string `json:"rateLimitTier,omitempty"` IsMax bool `json:"isMax,omitempty"` } @@ -219,5 +220,6 @@ func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool { left.ExpiresAt == right.ExpiresAt && slices.Equal(left.Scopes, right.Scopes) && left.SubscriptionType == right.SubscriptionType && + left.RateLimitTier == right.RateLimitTier && left.IsMax == right.IsMax } diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index d6eb4c102..807a06fe8 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -29,16 +29,17 @@ import ( const reverseProxyBaseURL = "http://reverse-proxy" type externalCredential struct { - tag string - baseURL string - token string - httpClient *http.Client - state credentialState - stateMutex sync.RWMutex - pollAccess sync.Mutex - pollInterval time.Duration - usageTracker *AggregatedUsage - logger log.ContextLogger + tag string + baseURL string + token string + httpClient *http.Client + state credentialState + stateMutex sync.RWMutex + pollAccess sync.Mutex + pollInterval time.Duration + configuredPlanWeight float64 + usageTracker *AggregatedUsage + logger log.ContextLogger onBecameUnusable func() interrupted bool @@ -112,16 +113,22 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx requestContext, cancelRequests := context.WithCancel(context.Background()) reverseContext, reverseCancel := context.WithCancel(context.Background()) + configuredPlanWeight := options.PlanWeight + if configuredPlanWeight <= 0 { + configuredPlanWeight = 1 + } + cred := &externalCredential{ - tag: tag, - token: options.Token, - pollInterval: pollInterval, - logger: logger, - requestContext: requestContext, - cancelRequests: cancelRequests, - reverse: options.Reverse, - reverseContext: reverseContext, - reverseCancel: reverseCancel, + tag: tag, + token: options.Token, + pollInterval: pollInterval, + configuredPlanWeight: configuredPlanWeight, + logger: logger, + requestContext: requestContext, + cancelRequests: cancelRequests, + reverse: options.Reverse, + reverseContext: reverseContext, + reverseCancel: reverseCancel, } if options.URL == "" { @@ -283,6 +290,16 @@ func (c *externalCredential) weeklyCap() float64 { return 100 } +func (c *externalCredential) planWeight() float64 { + return c.configuredPlanWeight +} + +func (c *externalCredential) weeklyResetTime() time.Time { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + return c.state.weeklyReset +} + func (c *externalCredential) markRateLimited(resetAt time.Time) { c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt))) c.stateMutex.Lock() diff --git a/service/ccm/credential_file.go b/service/ccm/credential_file.go index da13fae10..eba920726 100644 --- a/service/ccm/credential_file.go +++ b/service/ccm/credential_file.go @@ -114,6 +114,7 @@ func (c *defaultCredential) reloadCredentials(force bool) error { c.state.unavailable = false c.state.lastCredentialLoadError = "" c.state.accountType = credentials.SubscriptionType + c.state.rateLimitTier = credentials.RateLimitTier c.checkTransitionLocked() c.stateMutex.Unlock() @@ -130,6 +131,7 @@ func (c *defaultCredential) markCredentialsUnavailable(err error) error { c.state.unavailable = true c.state.lastCredentialLoadError = err.Error() c.state.accountType = "" + c.state.rateLimitTier = "" shouldInterrupt := c.checkTransitionLocked() c.stateMutex.Unlock() diff --git a/service/ccm/credential_state.go b/service/ccm/credential_state.go index 5eacd19ee..87c9afde2 100644 --- a/service/ccm/credential_state.go +++ b/service/ccm/credential_state.go @@ -71,6 +71,7 @@ type credentialState struct { hardRateLimited bool rateLimitResetAt time.Time accountType string + rateLimitTier string lastUpdated time.Time consecutivePollFailures int unavailable bool @@ -134,6 +135,8 @@ type credential interface { weeklyUtilization() float64 fiveHourCap() float64 weeklyCap() float64 + planWeight() float64 + weeklyResetTime() time.Time markRateLimited(resetAt time.Time) earliestReset() time.Time unavailableError() error @@ -294,6 +297,7 @@ func (c *defaultCredential) getAccessToken() (string, error) { c.state.lastCredentialLoadAttempt = time.Now() c.state.lastCredentialLoadError = "" c.state.accountType = latestCredentials.SubscriptionType + c.state.rateLimitTier = latestCredentials.RateLimitTier c.checkTransitionLocked() c.stateMutex.Unlock() if !latestCredentials.needsRefresh() { @@ -308,6 +312,7 @@ func (c *defaultCredential) getAccessToken() (string, error) { c.state.lastCredentialLoadAttempt = time.Now() c.state.lastCredentialLoadError = "" c.state.accountType = newCredentials.SubscriptionType + c.state.rateLimitTier = newCredentials.RateLimitTier c.checkTransitionLocked() c.stateMutex.Unlock() @@ -510,6 +515,18 @@ func (c *defaultCredential) weeklyUtilization() float64 { return c.state.weeklyUtilization } +func (c *defaultCredential) planWeight() float64 { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + return ccmPlanWeight(c.state.accountType, c.state.rateLimitTier) +} + +func (c *defaultCredential) weeklyResetTime() time.Time { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + return c.state.weeklyReset +} + func (c *defaultCredential) isAvailable() bool { c.retryCredentialReloadIfNeeded() @@ -670,11 +687,72 @@ func (c *defaultCredential) pollUsage(ctx context.Context) { } c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix) } + needsProfileFetch := c.state.rateLimitTier == "" shouldInterrupt := c.checkTransitionLocked() c.stateMutex.Unlock() if shouldInterrupt { c.interruptConnections() } + + if needsProfileFetch { + c.fetchProfile(ctx, httpClient, accessToken) + } +} + +func (c *defaultCredential) fetchProfile(ctx context.Context, httpClient *http.Client, accessToken string) { + response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/profile", nil) + if err != nil { + return nil, err + } + request.Header.Set("Authorization", "Bearer "+accessToken) + request.Header.Set("Content-Type", "application/json") + request.Header.Set("User-Agent", ccmUserAgentValue) + return request, nil + }) + if err != nil { + c.logger.Debug("fetch profile for ", c.tag, ": ", err) + return + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + return + } + + var profileResponse struct { + Organization *struct { + OrganizationType string `json:"organization_type"` + RateLimitTier string `json:"rate_limit_tier"` + } `json:"organization"` + } + err = json.NewDecoder(response.Body).Decode(&profileResponse) + if err != nil || profileResponse.Organization == nil { + return + } + + accountType := "" + switch profileResponse.Organization.OrganizationType { + case "claude_pro": + accountType = "pro" + case "claude_max": + accountType = "max" + case "claude_team": + accountType = "team" + case "claude_enterprise": + accountType = "enterprise" + } + rateLimitTier := profileResponse.Organization.RateLimitTier + + c.stateMutex.Lock() + if accountType != "" && c.state.accountType == "" { + c.state.accountType = accountType + } + if rateLimitTier != "" { + c.state.rateLimitTier = rateLimitTier + } + c.stateMutex.Unlock() + c.logger.Info("fetched profile for ", c.tag, ": type=", c.state.accountType, ", tier=", rateLimitTier, ", weight=", ccmPlanWeight(c.state.accountType, rateLimitTier)) } func (c *defaultCredential) close() { @@ -928,7 +1006,8 @@ func (p *balancerProvider) pickCredential(filter func(credential) bool) credenti func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credential { var best credential - bestRemaining := float64(-1) + bestScore := float64(-1) + now := time.Now() for _, cred := range p.credentials { if filter != nil && !filter(cred) { continue @@ -937,14 +1016,46 @@ func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credentia continue } remaining := cred.weeklyCap() - cred.weeklyUtilization() - if remaining > bestRemaining { - bestRemaining = remaining + score := remaining * cred.planWeight() + resetTime := cred.weeklyResetTime() + if !resetTime.IsZero() { + timeUntilReset := resetTime.Sub(now) + if timeUntilReset < time.Hour { + timeUntilReset = time.Hour + } + score *= weeklyWindowDuration / timeUntilReset.Hours() + } + if score > bestScore { + bestScore = score best = cred } } return best } +const weeklyWindowDuration = 7 * 24 // hours + +func ccmPlanWeight(accountType string, rateLimitTier string) float64 { + switch accountType { + case "max": + switch rateLimitTier { + case "default_claude_max_20x": + return 10 + case "default_claude_max_5x": + return 5 + default: + return 5 + } + case "team": + if rateLimitTier == "default_claude_max_5x" { + return 5 + } + return 1 + default: + return 1 + } +} + func (p *balancerProvider) pickRoundRobin(filter func(credential) bool) credential { start := int(p.roundRobinIndex.Add(1) - 1) count := len(p.credentials) diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 52c1e7210..2c9dce46b 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -30,17 +30,18 @@ import ( const reverseProxyBaseURL = "http://reverse-proxy" type externalCredential struct { - tag string - baseURL string - token string - credDialer N.Dialer - httpClient *http.Client - state credentialState - stateMutex sync.RWMutex - pollAccess sync.Mutex - pollInterval time.Duration - usageTracker *AggregatedUsage - logger log.ContextLogger + tag string + baseURL string + token string + credDialer N.Dialer + httpClient *http.Client + state credentialState + stateMutex sync.RWMutex + pollAccess sync.Mutex + pollInterval time.Duration + configuredPlanWeight float64 + usageTracker *AggregatedUsage + logger log.ContextLogger onBecameUnusable func() interrupted bool @@ -129,16 +130,22 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx requestContext, cancelRequests := context.WithCancel(context.Background()) reverseContext, reverseCancel := context.WithCancel(context.Background()) + configuredPlanWeight := options.PlanWeight + if configuredPlanWeight <= 0 { + configuredPlanWeight = 1 + } + cred := &externalCredential{ - tag: tag, - token: options.Token, - pollInterval: pollInterval, - logger: logger, - requestContext: requestContext, - cancelRequests: cancelRequests, - reverse: options.Reverse, - reverseContext: reverseContext, - reverseCancel: reverseCancel, + tag: tag, + token: options.Token, + pollInterval: pollInterval, + configuredPlanWeight: configuredPlanWeight, + logger: logger, + requestContext: requestContext, + cancelRequests: cancelRequests, + reverse: options.Reverse, + reverseContext: reverseContext, + reverseCancel: reverseCancel, } if options.URL == "" { @@ -305,6 +312,16 @@ func (c *externalCredential) weeklyCap() float64 { return 100 } +func (c *externalCredential) planWeight() float64 { + return c.configuredPlanWeight +} + +func (c *externalCredential) weeklyResetTime() time.Time { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + return c.state.weeklyReset +} + func (c *externalCredential) markRateLimited(resetAt time.Time) { c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt))) c.stateMutex.Lock() diff --git a/service/ocm/credential_state.go b/service/ocm/credential_state.go index db06d05a9..3cb1f48b9 100644 --- a/service/ocm/credential_state.go +++ b/service/ocm/credential_state.go @@ -135,6 +135,8 @@ type credential interface { weeklyUtilization() float64 fiveHourCap() float64 weeklyCap() float64 + planWeight() float64 + weeklyResetTime() time.Time markRateLimited(resetAt time.Time) earliestReset() time.Time unavailableError() error @@ -527,6 +529,18 @@ func (c *defaultCredential) weeklyUtilization() float64 { return c.state.weeklyUtilization } +func (c *defaultCredential) planWeight() float64 { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + return ocmPlanWeight(c.state.accountType) +} + +func (c *defaultCredential) weeklyResetTime() time.Time { + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + return c.state.weeklyReset +} + func (c *defaultCredential) isAvailable() bool { c.retryCredentialReloadIfNeeded() @@ -991,7 +1005,8 @@ func (p *balancerProvider) pickCredential(filter func(credential) bool) credenti func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credential { var best credential - bestRemaining := float64(-1) + bestScore := float64(-1) + now := time.Now() for _, cred := range p.credentials { if filter != nil && !filter(cred) { continue @@ -1003,14 +1018,36 @@ func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credentia continue } remaining := cred.weeklyCap() - cred.weeklyUtilization() - if remaining > bestRemaining { - bestRemaining = remaining + score := remaining * cred.planWeight() + resetTime := cred.weeklyResetTime() + if !resetTime.IsZero() { + timeUntilReset := resetTime.Sub(now) + if timeUntilReset < time.Hour { + timeUntilReset = time.Hour + } + score *= weeklyWindowDuration / timeUntilReset.Hours() + } + if score > bestScore { + bestScore = score best = cred } } return best } +const weeklyWindowDuration = 7 * 24 // hours + +func ocmPlanWeight(accountType string) float64 { + switch accountType { + case "pro": + return 10 + case "plus": + return 1 + default: + return 1 + } +} + func (p *balancerProvider) pickRoundRobin(filter func(credential) bool) credential { start := int(p.roundRobinIndex.Add(1) - 1) count := len(p.credentials)