mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-14 12:48:28 +10:00
Guard updateStateFromHeaders emission with value-change detection to avoid unnecessary computeAggregatedUtilization scans on every proxied response. Replace statusAggregateStateLocked two-value return with comparable statusSnapshot struct. Define statusPayload type for the status wire format, replacing anonymous structs and map literals.
786 lines
22 KiB
Go
786 lines
22 KiB
Go
package ccm
|
|
|
|
import (
	"bytes"
	"context"
	stdTLS "crypto/tls"
	"encoding/json"
	"io"
	"math"
	"net"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/sagernet/fswatch"
	"github.com/sagernet/sing-box/adapter"
	"github.com/sagernet/sing-box/common/dialer"
	"github.com/sagernet/sing-box/log"
	"github.com/sagernet/sing-box/option"
	E "github.com/sagernet/sing/common/exceptions"
	M "github.com/sagernet/sing/common/metadata"
	"github.com/sagernet/sing/common/ntp"
	"github.com/sagernet/sing/common/observable"
)
|
|
|
|
type defaultCredential struct {
|
|
tag string
|
|
serviceContext context.Context
|
|
credentialPath string
|
|
credentialFilePath string
|
|
credentials *oauthCredentials
|
|
access sync.RWMutex
|
|
state credentialState
|
|
stateAccess sync.RWMutex
|
|
pollAccess sync.Mutex
|
|
reloadAccess sync.Mutex
|
|
watcherAccess sync.Mutex
|
|
cap5h float64
|
|
capWeekly float64
|
|
usageTracker *AggregatedUsage
|
|
forwardHTTPClient *http.Client
|
|
logger log.ContextLogger
|
|
watcher *fswatch.Watcher
|
|
watcherRetryAt time.Time
|
|
|
|
statusSubscriber *observable.Subscriber[struct{}]
|
|
|
|
// Connection interruption
|
|
interrupted bool
|
|
requestContext context.Context
|
|
cancelRequests context.CancelFunc
|
|
requestAccess sync.Mutex
|
|
}
|
|
|
|
func newDefaultCredential(ctx context.Context, tag string, options option.CCMDefaultCredentialOptions, logger log.ContextLogger) (*defaultCredential, error) {
|
|
credentialDialer, err := dialer.NewWithOptions(dialer.Options{
|
|
Context: ctx,
|
|
Options: option.DialerOptions{
|
|
Detour: options.Detour,
|
|
},
|
|
RemoteIsDomain: true,
|
|
})
|
|
if err != nil {
|
|
return nil, E.Cause(err, "create dialer for credential ", tag)
|
|
}
|
|
httpClient := &http.Client{
|
|
Transport: &http.Transport{
|
|
ForceAttemptHTTP2: true,
|
|
TLSClientConfig: &stdTLS.Config{
|
|
RootCAs: adapter.RootPoolFromContext(ctx),
|
|
Time: ntp.TimeFuncFromContext(ctx),
|
|
},
|
|
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
|
},
|
|
},
|
|
}
|
|
reserve5h := options.Reserve5h
|
|
if reserve5h == 0 {
|
|
reserve5h = 1
|
|
}
|
|
reserveWeekly := options.ReserveWeekly
|
|
if reserveWeekly == 0 {
|
|
reserveWeekly = 1
|
|
}
|
|
var cap5h float64
|
|
if options.Limit5h > 0 {
|
|
cap5h = float64(options.Limit5h)
|
|
} else {
|
|
cap5h = float64(100 - reserve5h)
|
|
}
|
|
var capWeekly float64
|
|
if options.LimitWeekly > 0 {
|
|
capWeekly = float64(options.LimitWeekly)
|
|
} else {
|
|
capWeekly = float64(100 - reserveWeekly)
|
|
}
|
|
requestContext, cancelRequests := context.WithCancel(context.Background())
|
|
credential := &defaultCredential{
|
|
tag: tag,
|
|
serviceContext: ctx,
|
|
credentialPath: options.CredentialPath,
|
|
cap5h: cap5h,
|
|
capWeekly: capWeekly,
|
|
forwardHTTPClient: httpClient,
|
|
logger: logger,
|
|
requestContext: requestContext,
|
|
cancelRequests: cancelRequests,
|
|
}
|
|
if options.UsagesPath != "" {
|
|
credential.usageTracker = &AggregatedUsage{
|
|
LastUpdated: time.Now(),
|
|
Combinations: make([]CostCombination, 0),
|
|
filePath: options.UsagesPath,
|
|
logger: logger,
|
|
}
|
|
}
|
|
return credential, nil
|
|
}
|
|
|
|
func (c *defaultCredential) start() error {
|
|
credentialFilePath, err := resolveCredentialFilePath(c.credentialPath)
|
|
if err != nil {
|
|
return E.Cause(err, "resolve credential path for ", c.tag)
|
|
}
|
|
c.credentialFilePath = credentialFilePath
|
|
err = c.ensureCredentialWatcher()
|
|
if err != nil {
|
|
c.logger.Debug("start credential watcher for ", c.tag, ": ", err)
|
|
}
|
|
err = c.reloadCredentials(true)
|
|
if err != nil {
|
|
c.logger.Warn("initial credential load for ", c.tag, ": ", err)
|
|
}
|
|
if c.usageTracker != nil {
|
|
err = c.usageTracker.Load()
|
|
if err != nil {
|
|
c.logger.Warn("load usage statistics for ", c.tag, ": ", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *defaultCredential) setStatusSubscriber(subscriber *observable.Subscriber[struct{}]) {
|
|
c.statusSubscriber = subscriber
|
|
}
|
|
|
|
func (c *defaultCredential) emitStatusUpdate() {
|
|
if c.statusSubscriber != nil {
|
|
c.statusSubscriber.Emit(struct{}{})
|
|
}
|
|
}
|
|
|
|
// statusSnapshot captures the status-relevant part of credential state as a
// comparable value, so callers can take one before and one after a mutation
// and emit a status update only when the two differ.
type statusSnapshot struct {
	available bool    // false while state.unavailable is set
	weight    float64 // ccmPlanWeight of account type and tier; zero when unavailable
}
|
|
|
|
func (c *defaultCredential) statusSnapshotLocked() statusSnapshot {
|
|
if c.state.unavailable {
|
|
return statusSnapshot{}
|
|
}
|
|
return statusSnapshot{true, ccmPlanWeight(c.state.accountType, c.state.rateLimitTier)}
|
|
}
|
|
|
|
func (c *defaultCredential) getAccessToken() (string, error) {
|
|
c.retryCredentialReloadIfNeeded()
|
|
|
|
c.access.RLock()
|
|
if c.credentials != nil && !c.credentials.needsRefresh() {
|
|
token := c.credentials.AccessToken
|
|
c.access.RUnlock()
|
|
return token, nil
|
|
}
|
|
c.access.RUnlock()
|
|
|
|
err := c.reloadCredentials(true)
|
|
if err == nil {
|
|
c.access.RLock()
|
|
if c.credentials != nil && !c.credentials.needsRefresh() {
|
|
token := c.credentials.AccessToken
|
|
c.access.RUnlock()
|
|
return token, nil
|
|
}
|
|
c.access.RUnlock()
|
|
}
|
|
|
|
c.access.Lock()
|
|
defer c.access.Unlock()
|
|
|
|
if c.credentials == nil {
|
|
return "", c.unavailableError()
|
|
}
|
|
if !c.credentials.needsRefresh() {
|
|
return c.credentials.AccessToken, nil
|
|
}
|
|
|
|
err = platformCanWriteCredentials(c.credentialPath)
|
|
if err != nil {
|
|
return "", E.Cause(err, "credential file not writable, refusing refresh to avoid invalidation")
|
|
}
|
|
|
|
baseCredentials := cloneCredentials(c.credentials)
|
|
newCredentials, err := refreshToken(c.serviceContext, c.forwardHTTPClient, c.credentials)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
latestCredentials, latestErr := platformReadCredentials(c.credentialPath)
|
|
if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) {
|
|
c.credentials = latestCredentials
|
|
c.stateAccess.Lock()
|
|
before := c.statusSnapshotLocked()
|
|
c.state.unavailable = false
|
|
c.state.lastCredentialLoadAttempt = time.Now()
|
|
c.state.lastCredentialLoadError = ""
|
|
c.state.accountType = latestCredentials.SubscriptionType
|
|
c.state.rateLimitTier = latestCredentials.RateLimitTier
|
|
c.checkTransitionLocked()
|
|
shouldEmit := before != c.statusSnapshotLocked()
|
|
c.stateAccess.Unlock()
|
|
if shouldEmit {
|
|
c.emitStatusUpdate()
|
|
}
|
|
if !latestCredentials.needsRefresh() {
|
|
return latestCredentials.AccessToken, nil
|
|
}
|
|
return "", E.New("credential ", c.tag, " changed while refreshing")
|
|
}
|
|
|
|
c.credentials = newCredentials
|
|
c.stateAccess.Lock()
|
|
before := c.statusSnapshotLocked()
|
|
c.state.unavailable = false
|
|
c.state.lastCredentialLoadAttempt = time.Now()
|
|
c.state.lastCredentialLoadError = ""
|
|
c.state.accountType = newCredentials.SubscriptionType
|
|
c.state.rateLimitTier = newCredentials.RateLimitTier
|
|
c.checkTransitionLocked()
|
|
shouldEmit := before != c.statusSnapshotLocked()
|
|
c.stateAccess.Unlock()
|
|
if shouldEmit {
|
|
c.emitStatusUpdate()
|
|
}
|
|
|
|
err = platformWriteCredentials(newCredentials, c.credentialPath)
|
|
if err != nil {
|
|
c.logger.Error("persist refreshed token for ", c.tag, ": ", err)
|
|
}
|
|
|
|
return newCredentials.AccessToken, nil
|
|
}
|
|
|
|
func (c *defaultCredential) updateStateFromHeaders(headers http.Header) {
|
|
c.stateAccess.Lock()
|
|
isFirstUpdate := c.state.lastUpdated.IsZero()
|
|
oldFiveHour := c.state.fiveHourUtilization
|
|
oldWeekly := c.state.weeklyUtilization
|
|
hadData := false
|
|
|
|
fiveHourResetChanged := false
|
|
if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset"); exists {
|
|
hadData = true
|
|
if value.After(c.state.fiveHourReset) {
|
|
fiveHourResetChanged = true
|
|
c.state.fiveHourReset = value
|
|
}
|
|
}
|
|
if utilization := headers.Get("anthropic-ratelimit-unified-5h-utilization"); utilization != "" {
|
|
value, err := strconv.ParseFloat(utilization, 64)
|
|
if err == nil {
|
|
hadData = true
|
|
newValue := math.Ceil(value * 100)
|
|
if newValue >= c.state.fiveHourUtilization || fiveHourResetChanged {
|
|
c.state.fiveHourUtilization = newValue
|
|
}
|
|
}
|
|
}
|
|
|
|
weeklyResetChanged := false
|
|
if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset"); exists {
|
|
hadData = true
|
|
if value.After(c.state.weeklyReset) {
|
|
weeklyResetChanged = true
|
|
c.state.weeklyReset = value
|
|
}
|
|
}
|
|
if utilization := headers.Get("anthropic-ratelimit-unified-7d-utilization"); utilization != "" {
|
|
value, err := strconv.ParseFloat(utilization, 64)
|
|
if err == nil {
|
|
hadData = true
|
|
newValue := math.Ceil(value * 100)
|
|
if newValue >= c.state.weeklyUtilization || weeklyResetChanged {
|
|
c.state.weeklyUtilization = newValue
|
|
}
|
|
}
|
|
}
|
|
if hadData {
|
|
c.state.consecutivePollFailures = 0
|
|
c.state.lastUpdated = time.Now()
|
|
}
|
|
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
|
|
resetSuffix := ""
|
|
if !c.state.weeklyReset.IsZero() {
|
|
resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset))
|
|
}
|
|
c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
|
|
}
|
|
shouldEmit := hadData && (c.state.fiveHourUtilization != oldFiveHour || c.state.weeklyUtilization != oldWeekly)
|
|
shouldInterrupt := c.checkTransitionLocked()
|
|
c.stateAccess.Unlock()
|
|
if shouldInterrupt {
|
|
c.interruptConnections()
|
|
}
|
|
if shouldEmit {
|
|
c.emitStatusUpdate()
|
|
}
|
|
}
|
|
|
|
func (c *defaultCredential) markRateLimited(resetAt time.Time) {
|
|
c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt)))
|
|
c.stateAccess.Lock()
|
|
c.state.hardRateLimited = true
|
|
c.state.rateLimitResetAt = resetAt
|
|
shouldInterrupt := c.checkTransitionLocked()
|
|
c.stateAccess.Unlock()
|
|
if shouldInterrupt {
|
|
c.interruptConnections()
|
|
}
|
|
c.emitStatusUpdate()
|
|
}
|
|
|
|
func (c *defaultCredential) isUsable() bool {
|
|
c.retryCredentialReloadIfNeeded()
|
|
|
|
c.stateAccess.RLock()
|
|
if c.state.unavailable {
|
|
c.stateAccess.RUnlock()
|
|
return false
|
|
}
|
|
if c.state.consecutivePollFailures > 0 {
|
|
c.stateAccess.RUnlock()
|
|
return false
|
|
}
|
|
if c.state.hardRateLimited {
|
|
if time.Now().Before(c.state.rateLimitResetAt) {
|
|
c.stateAccess.RUnlock()
|
|
return false
|
|
}
|
|
c.stateAccess.RUnlock()
|
|
c.stateAccess.Lock()
|
|
if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) {
|
|
c.state.hardRateLimited = false
|
|
}
|
|
usable := c.checkReservesLocked()
|
|
c.stateAccess.Unlock()
|
|
return usable
|
|
}
|
|
usable := c.checkReservesLocked()
|
|
c.stateAccess.RUnlock()
|
|
return usable
|
|
}
|
|
|
|
func (c *defaultCredential) checkReservesLocked() bool {
|
|
if c.state.fiveHourUtilization >= c.cap5h {
|
|
return false
|
|
}
|
|
if c.state.weeklyUtilization >= c.capWeekly {
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
// checkTransitionLocked detects usable→unusable transition.
|
|
// Must be called with stateAccess write lock held.
|
|
func (c *defaultCredential) checkTransitionLocked() bool {
|
|
unusable := c.state.unavailable || c.state.hardRateLimited || !c.checkReservesLocked() || c.state.consecutivePollFailures > 0
|
|
if unusable && !c.interrupted {
|
|
c.interrupted = true
|
|
return true
|
|
}
|
|
if !unusable && c.interrupted {
|
|
c.interrupted = false
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (c *defaultCredential) interruptConnections() {
|
|
c.logger.Warn("interrupting connections for ", c.tag)
|
|
c.requestAccess.Lock()
|
|
c.cancelRequests()
|
|
c.requestContext, c.cancelRequests = context.WithCancel(context.Background())
|
|
c.requestAccess.Unlock()
|
|
}
|
|
|
|
func (c *defaultCredential) wrapRequestContext(parent context.Context) *credentialRequestContext {
|
|
c.requestAccess.Lock()
|
|
credentialContext := c.requestContext
|
|
c.requestAccess.Unlock()
|
|
derived, cancel := context.WithCancel(parent)
|
|
stop := context.AfterFunc(credentialContext, func() {
|
|
cancel()
|
|
})
|
|
return &credentialRequestContext{
|
|
Context: derived,
|
|
releaseFuncs: []func() bool{stop},
|
|
cancelFunc: cancel,
|
|
}
|
|
}
|
|
|
|
func (c *defaultCredential) weeklyUtilization() float64 {
|
|
c.stateAccess.RLock()
|
|
defer c.stateAccess.RUnlock()
|
|
return c.state.weeklyUtilization
|
|
}
|
|
|
|
func (c *defaultCredential) planWeight() float64 {
|
|
c.stateAccess.RLock()
|
|
defer c.stateAccess.RUnlock()
|
|
return ccmPlanWeight(c.state.accountType, c.state.rateLimitTier)
|
|
}
|
|
|
|
func (c *defaultCredential) weeklyResetTime() time.Time {
|
|
c.stateAccess.RLock()
|
|
defer c.stateAccess.RUnlock()
|
|
return c.state.weeklyReset
|
|
}
|
|
|
|
func (c *defaultCredential) isAvailable() bool {
|
|
c.retryCredentialReloadIfNeeded()
|
|
|
|
c.stateAccess.RLock()
|
|
defer c.stateAccess.RUnlock()
|
|
return !c.state.unavailable
|
|
}
|
|
|
|
func (c *defaultCredential) unavailableError() error {
|
|
c.stateAccess.RLock()
|
|
defer c.stateAccess.RUnlock()
|
|
if !c.state.unavailable {
|
|
return nil
|
|
}
|
|
if c.state.lastCredentialLoadError == "" {
|
|
return E.New("credential ", c.tag, " is unavailable")
|
|
}
|
|
return E.New("credential ", c.tag, " is unavailable: ", c.state.lastCredentialLoadError)
|
|
}
|
|
|
|
func (c *defaultCredential) lastUpdatedTime() time.Time {
|
|
c.stateAccess.RLock()
|
|
defer c.stateAccess.RUnlock()
|
|
return c.state.lastUpdated
|
|
}
|
|
|
|
func (c *defaultCredential) markUsagePollAttempted() {
|
|
c.stateAccess.Lock()
|
|
defer c.stateAccess.Unlock()
|
|
c.state.lastUpdated = time.Now()
|
|
}
|
|
|
|
func (c *defaultCredential) incrementPollFailures() {
|
|
c.stateAccess.Lock()
|
|
c.state.consecutivePollFailures++
|
|
shouldInterrupt := c.checkTransitionLocked()
|
|
c.stateAccess.Unlock()
|
|
if shouldInterrupt {
|
|
c.interruptConnections()
|
|
}
|
|
}
|
|
|
|
func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duration {
|
|
c.stateAccess.RLock()
|
|
failures := c.state.consecutivePollFailures
|
|
retryDelay := c.state.usageAPIRetryDelay
|
|
c.stateAccess.RUnlock()
|
|
if failures <= 0 {
|
|
if retryDelay > 0 {
|
|
return retryDelay
|
|
}
|
|
return baseInterval
|
|
}
|
|
backoff := failedPollRetryInterval * time.Duration(1<<(failures-1))
|
|
if backoff > httpRetryMaxBackoff {
|
|
return httpRetryMaxBackoff
|
|
}
|
|
return backoff
|
|
}
|
|
|
|
func (c *defaultCredential) isPollBackoffAtCap() bool {
|
|
c.stateAccess.RLock()
|
|
defer c.stateAccess.RUnlock()
|
|
failures := c.state.consecutivePollFailures
|
|
return failures > 0 && failedPollRetryInterval*time.Duration(1<<(failures-1)) >= httpRetryMaxBackoff
|
|
}
|
|
|
|
func (c *defaultCredential) earliestReset() time.Time {
|
|
c.stateAccess.RLock()
|
|
defer c.stateAccess.RUnlock()
|
|
if c.state.unavailable {
|
|
return time.Time{}
|
|
}
|
|
if c.state.hardRateLimited {
|
|
return c.state.rateLimitResetAt
|
|
}
|
|
earliest := c.state.fiveHourReset
|
|
if !c.state.weeklyReset.IsZero() && (earliest.IsZero() || c.state.weeklyReset.Before(earliest)) {
|
|
earliest = c.state.weeklyReset
|
|
}
|
|
return earliest
|
|
}
|
|
|
|
func (c *defaultCredential) pollUsage(ctx context.Context) {
|
|
if !c.pollAccess.TryLock() {
|
|
return
|
|
}
|
|
defer c.pollAccess.Unlock()
|
|
defer c.markUsagePollAttempted()
|
|
|
|
c.retryCredentialReloadIfNeeded()
|
|
if !c.isAvailable() {
|
|
return
|
|
}
|
|
|
|
accessToken, err := c.getAccessToken()
|
|
if err != nil {
|
|
if !c.isPollBackoffAtCap() {
|
|
c.logger.Error("poll usage for ", c.tag, ": get token: ", err)
|
|
}
|
|
c.incrementPollFailures()
|
|
return
|
|
}
|
|
|
|
httpClient := &http.Client{
|
|
Transport: c.forwardHTTPClient.Transport,
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
|
|
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
|
|
request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/usage", nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
request.Header.Set("Authorization", "Bearer "+accessToken)
|
|
request.Header.Set("Content-Type", "application/json")
|
|
request.Header.Set("User-Agent", ccmUserAgentValue)
|
|
request.Header.Set("anthropic-beta", anthropicBetaOAuthValue)
|
|
return request, nil
|
|
})
|
|
if err != nil {
|
|
if !c.isPollBackoffAtCap() {
|
|
c.logger.Error("poll usage for ", c.tag, ": ", err)
|
|
}
|
|
c.incrementPollFailures()
|
|
return
|
|
}
|
|
defer response.Body.Close()
|
|
|
|
if response.StatusCode != http.StatusOK {
|
|
if response.StatusCode == http.StatusTooManyRequests {
|
|
retryDelay := time.Minute
|
|
if retryAfter := response.Header.Get("Retry-After"); retryAfter != "" {
|
|
seconds, err := strconv.ParseInt(retryAfter, 10, 64)
|
|
if err == nil && seconds > 0 {
|
|
retryDelay = time.Duration(seconds) * time.Second
|
|
}
|
|
}
|
|
c.logger.Warn("poll usage for ", c.tag, ": usage API rate limited, retry in ", log.FormatDuration(retryDelay))
|
|
c.stateAccess.Lock()
|
|
c.state.usageAPIRetryDelay = retryDelay
|
|
c.stateAccess.Unlock()
|
|
return
|
|
}
|
|
body, _ := io.ReadAll(response.Body)
|
|
c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
|
|
c.incrementPollFailures()
|
|
return
|
|
}
|
|
|
|
var usageResponse struct {
|
|
FiveHour struct {
|
|
Utilization float64 `json:"utilization"`
|
|
ResetsAt time.Time `json:"resets_at"`
|
|
} `json:"five_hour"`
|
|
SevenDay struct {
|
|
Utilization float64 `json:"utilization"`
|
|
ResetsAt time.Time `json:"resets_at"`
|
|
} `json:"seven_day"`
|
|
}
|
|
err = json.NewDecoder(response.Body).Decode(&usageResponse)
|
|
if err != nil {
|
|
c.logger.Debug("poll usage for ", c.tag, ": decode: ", err)
|
|
c.incrementPollFailures()
|
|
return
|
|
}
|
|
|
|
c.stateAccess.Lock()
|
|
isFirstUpdate := c.state.lastUpdated.IsZero()
|
|
oldFiveHour := c.state.fiveHourUtilization
|
|
oldWeekly := c.state.weeklyUtilization
|
|
c.state.consecutivePollFailures = 0
|
|
c.state.usageAPIRetryDelay = 0
|
|
c.state.fiveHourUtilization = usageResponse.FiveHour.Utilization
|
|
if !usageResponse.FiveHour.ResetsAt.IsZero() {
|
|
c.state.fiveHourReset = usageResponse.FiveHour.ResetsAt
|
|
}
|
|
c.state.weeklyUtilization = usageResponse.SevenDay.Utilization
|
|
if !usageResponse.SevenDay.ResetsAt.IsZero() {
|
|
c.state.weeklyReset = usageResponse.SevenDay.ResetsAt
|
|
}
|
|
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
|
|
c.state.hardRateLimited = false
|
|
}
|
|
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
|
|
resetSuffix := ""
|
|
if !c.state.weeklyReset.IsZero() {
|
|
resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset))
|
|
}
|
|
c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
|
|
}
|
|
needsProfileFetch := c.state.rateLimitTier == ""
|
|
shouldInterrupt := c.checkTransitionLocked()
|
|
c.stateAccess.Unlock()
|
|
if shouldInterrupt {
|
|
c.interruptConnections()
|
|
}
|
|
c.emitStatusUpdate()
|
|
|
|
if needsProfileFetch {
|
|
c.fetchProfile(ctx, httpClient, accessToken)
|
|
}
|
|
}
|
|
|
|
func (c *defaultCredential) fetchProfile(ctx context.Context, httpClient *http.Client, accessToken string) {
|
|
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
|
|
request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/profile", nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
request.Header.Set("Authorization", "Bearer "+accessToken)
|
|
request.Header.Set("Content-Type", "application/json")
|
|
request.Header.Set("User-Agent", ccmUserAgentValue)
|
|
return request, nil
|
|
})
|
|
if err != nil {
|
|
c.logger.Debug("fetch profile for ", c.tag, ": ", err)
|
|
return
|
|
}
|
|
defer response.Body.Close()
|
|
|
|
if response.StatusCode != http.StatusOK {
|
|
return
|
|
}
|
|
|
|
var profileResponse struct {
|
|
Organization *struct {
|
|
OrganizationType string `json:"organization_type"`
|
|
RateLimitTier string `json:"rate_limit_tier"`
|
|
} `json:"organization"`
|
|
}
|
|
err = json.NewDecoder(response.Body).Decode(&profileResponse)
|
|
if err != nil || profileResponse.Organization == nil {
|
|
return
|
|
}
|
|
|
|
accountType := ""
|
|
switch profileResponse.Organization.OrganizationType {
|
|
case "claude_pro":
|
|
accountType = "pro"
|
|
case "claude_max":
|
|
accountType = "max"
|
|
case "claude_team":
|
|
accountType = "team"
|
|
case "claude_enterprise":
|
|
accountType = "enterprise"
|
|
}
|
|
rateLimitTier := profileResponse.Organization.RateLimitTier
|
|
|
|
c.stateAccess.Lock()
|
|
before := c.statusSnapshotLocked()
|
|
if accountType != "" && c.state.accountType == "" {
|
|
c.state.accountType = accountType
|
|
}
|
|
if rateLimitTier != "" {
|
|
c.state.rateLimitTier = rateLimitTier
|
|
}
|
|
resolvedAccountType := c.state.accountType
|
|
shouldEmit := before != c.statusSnapshotLocked()
|
|
c.stateAccess.Unlock()
|
|
if shouldEmit {
|
|
c.emitStatusUpdate()
|
|
}
|
|
c.logger.Info("fetched profile for ", c.tag, ": type=", resolvedAccountType, ", tier=", rateLimitTier, ", weight=", ccmPlanWeight(resolvedAccountType, rateLimitTier))
|
|
}
|
|
|
|
func (c *defaultCredential) close() {
|
|
if c.watcher != nil {
|
|
err := c.watcher.Close()
|
|
if err != nil {
|
|
c.logger.Error("close credential watcher for ", c.tag, ": ", err)
|
|
}
|
|
}
|
|
if c.usageTracker != nil {
|
|
c.usageTracker.cancelPendingSave()
|
|
err := c.usageTracker.Save()
|
|
if err != nil {
|
|
c.logger.Error("save usage statistics for ", c.tag, ": ", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *defaultCredential) tagName() string {
|
|
return c.tag
|
|
}
|
|
|
|
func (c *defaultCredential) isExternal() bool {
|
|
return false
|
|
}
|
|
|
|
func (c *defaultCredential) fiveHourUtilization() float64 {
|
|
c.stateAccess.RLock()
|
|
defer c.stateAccess.RUnlock()
|
|
return c.state.fiveHourUtilization
|
|
}
|
|
|
|
func (c *defaultCredential) fiveHourCap() float64 {
|
|
return c.cap5h
|
|
}
|
|
|
|
func (c *defaultCredential) weeklyCap() float64 {
|
|
return c.capWeekly
|
|
}
|
|
|
|
func (c *defaultCredential) usageTrackerOrNil() *AggregatedUsage {
|
|
return c.usageTracker
|
|
}
|
|
|
|
func (c *defaultCredential) httpClient() *http.Client {
|
|
return c.forwardHTTPClient
|
|
}
|
|
|
|
func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) {
|
|
accessToken, err := c.getAccessToken()
|
|
if err != nil {
|
|
return nil, E.Cause(err, "get access token for ", c.tag)
|
|
}
|
|
|
|
proxyURL := claudeAPIBaseURL + original.URL.RequestURI()
|
|
var body io.Reader
|
|
if bodyBytes != nil {
|
|
body = bytes.NewReader(bodyBytes)
|
|
} else {
|
|
body = original.Body
|
|
}
|
|
proxyRequest, err := http.NewRequestWithContext(ctx, original.Method, proxyURL, body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for key, values := range original.Header {
|
|
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && !isAPIKeyHeader(key) && key != "Authorization" {
|
|
proxyRequest.Header[key] = values
|
|
}
|
|
}
|
|
|
|
serviceOverridesAcceptEncoding := len(serviceHeaders.Values("Accept-Encoding")) > 0
|
|
if c.usageTracker != nil && !serviceOverridesAcceptEncoding {
|
|
proxyRequest.Header.Del("Accept-Encoding")
|
|
}
|
|
|
|
anthropicBetaHeader := proxyRequest.Header.Get("anthropic-beta")
|
|
if anthropicBetaHeader != "" {
|
|
proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue+","+anthropicBetaHeader)
|
|
} else {
|
|
proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue)
|
|
}
|
|
|
|
for key, values := range serviceHeaders {
|
|
proxyRequest.Header.Del(key)
|
|
proxyRequest.Header[key] = values
|
|
}
|
|
proxyRequest.Header.Set("Authorization", "Bearer "+accessToken)
|
|
|
|
return proxyRequest, nil
|
|
}
|