Align CCM and OCM rate limits

This commit is contained in:
世界
2026-03-24 22:06:10 +08:00
parent 92c8f4c5c8
commit 4592164a7a
15 changed files with 2097 additions and 196 deletions

View File

@@ -59,6 +59,17 @@ type credentialState struct {
weeklyReset time.Time
hardRateLimited bool
rateLimitResetAt time.Time
availabilityState availabilityState
availabilityReason availabilityReason
availabilityResetAt time.Time
lastKnownDataAt time.Time
unifiedStatus unifiedRateLimitStatus
unifiedResetAt time.Time
representativeClaim string
unifiedFallbackAvailable bool
overageStatus string
overageResetAt time.Time
overageDisabledReason string
accountUUID string
accountType string
rateLimitTier string
@@ -103,6 +114,7 @@ type Credential interface {
isAvailable() bool
isUsable() bool
isExternal() bool
hasSnapshotData() bool
fiveHourUtilization() float64
weeklyUtilization() float64
fiveHourCap() float64
@@ -112,6 +124,8 @@ type Credential interface {
weeklyResetTime() time.Time
markRateLimited(resetAt time.Time)
markUpstreamRejected()
availabilityStatus() availabilityStatus
unifiedRateLimitState() unifiedRateLimitInfo
earliestReset() time.Time
unavailableError() error
@@ -185,6 +199,71 @@ func parseRequiredAnthropicResetHeader(headers http.Header, headerName string) t
return parseAnthropicResetHeaderValue(headerName, headerValue)
}
// noteSnapshotData stamps the state with the current wall-clock time as the
// moment usage data was last recorded.
func (s *credentialState) noteSnapshotData() {
	observedAt := time.Now()
	s.lastKnownDataAt = observedAt
}
// hasSnapshotData reports whether the state carries any usage information:
// a recorded snapshot timestamp, a positive utilization figure, or a known
// reset time for either window.
func (s credentialState) hasSnapshotData() bool {
	switch {
	case !s.lastKnownDataAt.IsZero():
		return true
	case s.fiveHourUtilization > 0, s.weeklyUtilization > 0:
		return true
	case !s.fiveHourReset.IsZero(), !s.weeklyReset.IsZero():
		return true
	default:
		return false
	}
}
// setAvailability stores the given availability state, its reason and the
// time at which it is expected to reset. It does not touch any other field.
func (s *credentialState) setAvailability(state availabilityState, reason availabilityReason, resetAt time.Time) {
	s.availabilityState, s.availabilityReason, s.availabilityResetAt = state, reason, resetAt
}
// currentAvailability derives the availability of the credential from its
// live state flags, evaluated against the current time.
//
// The checks are ordered by precedence:
//  1. unavailable flag                -> unavailable
//  2. active hard rate limit          -> rate_limited
//  3. active upstream rejection       -> temporarily_blocked (upstream_rejected)
//  4. outstanding poll failures       -> temporarily_blocked (poll_failed)
//  5. otherwise                       -> usable
//
// A hard rate limit without a reset time is treated as still active.
func (s credentialState) currentAvailability() availabilityStatus {
	now := time.Now()
	switch {
	case s.unavailable:
		return availabilityStatus{
			State:   availabilityStateUnavailable,
			Reason:  availabilityReasonUnknown,
			ResetAt: s.availabilityResetAt,
		}
	case s.hardRateLimited && (s.rateLimitResetAt.IsZero() || now.Before(s.rateLimitResetAt)):
		// The stored reason is only meaningful here if it was recorded for a
		// rate-limited state. Otherwise it is a leftover from an unrelated
		// transition (poll failure, upstream rejection) and would mislabel
		// the rate limit.
		reason := availabilityReasonHardRateLimit
		if s.availabilityState == availabilityStateRateLimited && s.availabilityReason != "" {
			reason = s.availabilityReason
		}
		return availabilityStatus{
			State:   availabilityStateRateLimited,
			Reason:  reason,
			ResetAt: s.rateLimitResetAt,
		}
	case !s.upstreamRejectedUntil.IsZero() && now.Before(s.upstreamRejectedUntil):
		return availabilityStatus{
			State:   availabilityStateTemporarilyBlocked,
			Reason:  availabilityReasonUpstreamRejected,
			ResetAt: s.upstreamRejectedUntil,
		}
	case s.consecutivePollFailures > 0:
		return availabilityStatus{
			State:  availabilityStateTemporarilyBlocked,
			Reason: availabilityReasonPollFailed,
		}
	default:
		return availabilityStatus{State: availabilityStateUsable}
	}
}
// currentUnifiedRateLimit copies the unified rate-limit fields of the state
// into a unifiedRateLimitInfo and returns its normalized form.
func (s credentialState) currentUnifiedRateLimit() unifiedRateLimitInfo {
	var info unifiedRateLimitInfo
	info.Status = s.unifiedStatus
	info.ResetAt = s.unifiedResetAt
	info.RepresentativeClaim = s.representativeClaim
	info.FallbackAvailable = s.unifiedFallbackAvailable
	info.OverageStatus = s.overageStatus
	info.OverageResetAt = s.overageResetAt
	info.OverageDisabledReason = s.overageDisabledReason
	return info.normalized()
}
func parseRateLimitResetFromHeaders(headers http.Header) time.Time {
claim := headers.Get("anthropic-ratelimit-unified-representative-claim")
switch claim {

View File

@@ -623,7 +623,21 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) {
if hadData {
c.state.consecutivePollFailures = 0
c.state.lastUpdated = time.Now()
c.state.noteSnapshotData()
}
if unifiedStatus := unifiedRateLimitStatus(headers.Get("anthropic-ratelimit-unified-status")); unifiedStatus != "" {
c.state.unifiedStatus = unifiedStatus
}
if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-reset"); exists {
c.state.unifiedResetAt = value
}
c.state.representativeClaim = headers.Get("anthropic-ratelimit-unified-representative-claim")
c.state.unifiedFallbackAvailable = headers.Get("anthropic-ratelimit-unified-fallback") == "available"
c.state.overageStatus = headers.Get("anthropic-ratelimit-unified-overage-status")
if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-overage-reset"); exists {
c.state.overageResetAt = value
}
c.state.overageDisabledReason = headers.Get("anthropic-ratelimit-unified-overage-disabled-reason")
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
resetSuffix := ""
if !c.state.weeklyReset.IsZero() {
@@ -647,6 +661,9 @@ func (c *defaultCredential) markRateLimited(resetAt time.Time) {
c.stateAccess.Lock()
c.state.hardRateLimited = true
c.state.rateLimitResetAt = resetAt
c.state.setAvailability(availabilityStateRateLimited, availabilityReasonHardRateLimit, resetAt)
c.state.unifiedStatus = unifiedRateLimitStatusRejected
c.state.unifiedResetAt = resetAt
shouldInterrupt := c.checkTransitionLocked()
c.stateAccess.Unlock()
if shouldInterrupt {
@@ -741,6 +758,12 @@ func (c *defaultCredential) weeklyUtilization() float64 {
return c.state.weeklyUtilization
}
// hasSnapshotData reports whether the credential state holds usage data,
// reading it while holding the state read lock.
func (c *defaultCredential) hasSnapshotData() bool {
	c.stateAccess.RLock()
	defer c.stateAccess.RUnlock()
	hasData := c.state.hasSnapshotData()
	return hasData
}
func (c *defaultCredential) planWeight() float64 {
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
@@ -767,6 +790,18 @@ func (c *defaultCredential) isAvailable() bool {
return !c.state.unavailable
}
// availabilityStatus returns the availability derived from the credential
// state, computed while holding the state read lock.
func (c *defaultCredential) availabilityStatus() availabilityStatus {
	c.stateAccess.RLock()
	defer c.stateAccess.RUnlock()
	current := c.state.currentAvailability()
	return current
}
// unifiedRateLimitState returns the normalized unified rate-limit view of
// the credential state, computed while holding the state read lock.
func (c *defaultCredential) unifiedRateLimitState() unifiedRateLimitInfo {
	c.stateAccess.RLock()
	defer c.stateAccess.RUnlock()
	info := c.state.currentUnifiedRateLimit()
	return info
}
func (c *defaultCredential) unavailableError() error {
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
@@ -794,6 +829,7 @@ func (c *defaultCredential) markUsagePollAttempted() {
func (c *defaultCredential) incrementPollFailures() {
c.stateAccess.Lock()
c.state.consecutivePollFailures++
c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReasonPollFailed, time.Time{})
shouldInterrupt := c.checkTransitionLocked()
c.stateAccess.Unlock()
if shouldInterrupt {
@@ -944,6 +980,7 @@ func (c *defaultCredential) pollUsage() {
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
c.state.hardRateLimited = false
}
c.state.noteSnapshotData()
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
resetSuffix := ""
if !c.state.weeklyReset.IsZero() {

View File

@@ -343,6 +343,9 @@ func (c *externalCredential) markRateLimited(resetAt time.Time) {
c.stateAccess.Lock()
c.state.hardRateLimited = true
c.state.rateLimitResetAt = resetAt
c.state.setAvailability(availabilityStateRateLimited, availabilityReasonHardRateLimit, resetAt)
c.state.unifiedStatus = unifiedRateLimitStatusRejected
c.state.unifiedResetAt = resetAt
shouldInterrupt := c.checkTransitionLocked()
c.stateAccess.Unlock()
if shouldInterrupt {
@@ -355,6 +358,7 @@ func (c *externalCredential) markUpstreamRejected() {
c.logger.Warn("upstream rejected credential ", c.tag, ", marking unavailable for ", log.FormatDuration(c.pollInterval))
c.stateAccess.Lock()
c.state.upstreamRejectedUntil = time.Now().Add(c.pollInterval)
c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReasonUpstreamRejected, c.state.upstreamRejectedUntil)
shouldInterrupt := c.checkTransitionLocked()
c.stateAccess.Unlock()
if shouldInterrupt {
@@ -493,7 +497,21 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) {
c.state.consecutivePollFailures = 0
c.state.upstreamRejectedUntil = time.Time{}
c.state.lastUpdated = time.Now()
c.state.noteSnapshotData()
}
if unifiedStatus := unifiedRateLimitStatus(headers.Get("anthropic-ratelimit-unified-status")); unifiedStatus != "" {
c.state.unifiedStatus = unifiedStatus
}
if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-reset"); exists {
c.state.unifiedResetAt = value
}
c.state.representativeClaim = headers.Get("anthropic-ratelimit-unified-representative-claim")
c.state.unifiedFallbackAvailable = headers.Get("anthropic-ratelimit-unified-fallback") == "available"
c.state.overageStatus = headers.Get("anthropic-ratelimit-unified-overage-status")
if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-overage-reset"); exists {
c.state.overageResetAt = value
}
c.state.overageDisabledReason = headers.Get("anthropic-ratelimit-unified-overage-disabled-reason")
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
resetSuffix := ""
if !c.state.weeklyReset.IsZero() {
@@ -635,13 +653,7 @@ func (c *externalCredential) pollUsage() {
c.clearPollFailures()
return
}
var statusResponse struct {
FiveHourUtilization float64 `json:"five_hour_utilization"`
FiveHourReset int64 `json:"five_hour_reset"`
WeeklyUtilization float64 `json:"weekly_utilization"`
WeeklyReset int64 `json:"weekly_reset"`
PlanWeight float64 `json:"plan_weight"`
}
var statusResponse statusPayload
err = json.Unmarshal(body, &statusResponse)
if err != nil {
c.logger.Debug("poll usage for ", c.tag, ": decode: ", err)
@@ -657,6 +669,11 @@ func (c *externalCredential) pollUsage() {
c.state.upstreamRejectedUntil = time.Time{}
c.state.fiveHourUtilization = statusResponse.FiveHourUtilization
c.state.weeklyUtilization = statusResponse.WeeklyUtilization
c.state.unifiedStatus = unifiedRateLimitStatus(statusResponse.UnifiedStatus)
c.state.representativeClaim = statusResponse.RepresentativeClaim
c.state.unifiedFallbackAvailable = statusResponse.FallbackAvailable
c.state.overageStatus = statusResponse.OverageStatus
c.state.overageDisabledReason = statusResponse.OverageDisabledReason
if statusResponse.PlanWeight > 0 {
c.state.remotePlanWeight = statusResponse.PlanWeight
}
@@ -666,6 +683,30 @@ func (c *externalCredential) pollUsage() {
if statusResponse.WeeklyReset > 0 {
c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0)
}
if statusResponse.UnifiedReset > 0 {
c.state.unifiedResetAt = time.Unix(statusResponse.UnifiedReset, 0)
}
if statusResponse.OverageReset > 0 {
c.state.overageResetAt = time.Unix(statusResponse.OverageReset, 0)
}
if statusResponse.Availability != nil {
switch availabilityState(statusResponse.Availability.State) {
case availabilityStateRateLimited:
c.state.hardRateLimited = true
if statusResponse.Availability.ResetAt > 0 {
c.state.rateLimitResetAt = time.Unix(statusResponse.Availability.ResetAt, 0)
}
case availabilityStateTemporarilyBlocked:
resetAt := time.Time{}
if statusResponse.Availability.ResetAt > 0 {
resetAt = time.Unix(statusResponse.Availability.ResetAt, 0)
}
c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReason(statusResponse.Availability.Reason), resetAt)
if availabilityReason(statusResponse.Availability.Reason) == availabilityReasonUpstreamRejected && !resetAt.IsZero() {
c.state.upstreamRejectedUntil = resetAt
}
}
}
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
c.state.hardRateLimited = false
}
@@ -766,6 +807,11 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr
c.state.upstreamRejectedUntil = time.Time{}
c.state.fiveHourUtilization = statusResponse.FiveHourUtilization
c.state.weeklyUtilization = statusResponse.WeeklyUtilization
c.state.unifiedStatus = unifiedRateLimitStatus(statusResponse.UnifiedStatus)
c.state.representativeClaim = statusResponse.RepresentativeClaim
c.state.unifiedFallbackAvailable = statusResponse.FallbackAvailable
c.state.overageStatus = statusResponse.OverageStatus
c.state.overageDisabledReason = statusResponse.OverageDisabledReason
if statusResponse.PlanWeight > 0 {
c.state.remotePlanWeight = statusResponse.PlanWeight
}
@@ -775,6 +821,30 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr
if statusResponse.WeeklyReset > 0 {
c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0)
}
if statusResponse.UnifiedReset > 0 {
c.state.unifiedResetAt = time.Unix(statusResponse.UnifiedReset, 0)
}
if statusResponse.OverageReset > 0 {
c.state.overageResetAt = time.Unix(statusResponse.OverageReset, 0)
}
if statusResponse.Availability != nil {
switch availabilityState(statusResponse.Availability.State) {
case availabilityStateRateLimited:
c.state.hardRateLimited = true
if statusResponse.Availability.ResetAt > 0 {
c.state.rateLimitResetAt = time.Unix(statusResponse.Availability.ResetAt, 0)
}
case availabilityStateTemporarilyBlocked:
resetAt := time.Time{}
if statusResponse.Availability.ResetAt > 0 {
resetAt = time.Unix(statusResponse.Availability.ResetAt, 0)
}
c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReason(statusResponse.Availability.Reason), resetAt)
if availabilityReason(statusResponse.Availability.Reason) == availabilityReasonUpstreamRejected && !resetAt.IsZero() {
c.state.upstreamRejectedUntil = resetAt
}
}
}
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
c.state.hardRateLimited = false
}
@@ -846,6 +916,24 @@ func (c *externalCredential) lastUpdatedTime() time.Time {
return c.state.lastUpdated
}
// hasSnapshotData reports whether the credential state holds usage data,
// reading it while holding the state read lock.
func (c *externalCredential) hasSnapshotData() bool {
	c.stateAccess.RLock()
	defer c.stateAccess.RUnlock()
	hasData := c.state.hasSnapshotData()
	return hasData
}
// availabilityStatus returns the availability derived from the credential
// state, computed while holding the state read lock.
func (c *externalCredential) availabilityStatus() availabilityStatus {
	c.stateAccess.RLock()
	defer c.stateAccess.RUnlock()
	current := c.state.currentAvailability()
	return current
}
// unifiedRateLimitState returns the normalized unified rate-limit view of
// the credential state, computed while holding the state read lock.
func (c *externalCredential) unifiedRateLimitState() unifiedRateLimitInfo {
	c.stateAccess.RLock()
	defer c.stateAccess.RUnlock()
	info := c.state.currentUnifiedRateLimit()
	return info
}
func (c *externalCredential) markUsageStreamUpdated() {
c.stateAccess.Lock()
defer c.stateAccess.Unlock()

View File

@@ -0,0 +1,124 @@
package ccm
import "time"
// availabilityState is a string-typed identifier for how usable a credential
// (or an aggregate of credentials) currently is. Its value is written
// verbatim into availabilityPayload.State.
type availabilityState string
// Known availabilityState values.
const (
// Reported when no blocking condition applies.
availabilityStateUsable availabilityState = "usable"
// Reported while a hard rate limit is in effect.
availabilityStateRateLimited availabilityState = "rate_limited"
// Reported for an active upstream rejection or outstanding poll failures.
availabilityStateTemporarilyBlocked availabilityState = "temporarily_blocked"
// Reported when the credential is flagged unavailable, or when an
// aggregate has no credentials at all.
availabilityStateUnavailable availabilityState = "unavailable"
// Substituted by availabilityStatus.normalized for an empty state.
availabilityStateUnknown availabilityState = "unknown"
)
// availabilityReason is a string-typed identifier explaining why a
// credential is in a non-usable availabilityState.
type availabilityReason string
// Known availabilityReason values.
const (
// Default reason attached to a rate-limited state.
availabilityReasonHardRateLimit availabilityReason = "hard_rate_limit"
// NOTE(review): not referenced in the visible code; presumably set by
// connection-limit handling elsewhere — confirm.
availabilityReasonConnectionLimit availabilityReason = "connection_limit"
// Recorded when usage polling has failed.
availabilityReasonPollFailed availabilityReason = "poll_failed"
// Recorded when the upstream rejected the credential.
availabilityReasonUpstreamRejected availabilityReason = "upstream_rejected"
// Used by aggregateAvailability when there are no inputs.
availabilityReasonNoCredentials availabilityReason = "no_credentials"
// Fallback reason; toPayload omits it from the serialized payload.
availabilityReasonUnknown availabilityReason = "unknown"
)
// availabilityStatus is the in-memory description of a credential's
// availability: the state, the reason for it, and when it is expected to
// reset (zero when unknown).
type availabilityStatus struct {
State availabilityState
Reason availabilityReason
ResetAt time.Time
}
// availabilityPayload is the JSON form of availabilityStatus. ResetAt is a
// Unix timestamp in seconds (see toPayload); Reason and ResetAt are omitted
// from the encoding when empty/zero.
type availabilityPayload struct {
State string `json:"state"`
Reason string `json:"reason,omitempty"`
ResetAt int64 `json:"reset_at,omitempty"`
}
// normalized returns a copy of the status with defaults filled in: an empty
// state becomes "unknown", and an empty reason becomes "unknown" for every
// state other than "usable".
func (s availabilityStatus) normalized() availabilityStatus {
	result := s
	if result.State == "" {
		result.State = availabilityStateUnknown
	}
	if result.State != availabilityStateUsable && result.Reason == "" {
		result.Reason = availabilityReasonUnknown
	}
	return result
}
// toPayload converts the status into its JSON payload form.
//
// The status is normalized first, so the returned payload always carries a
// non-empty State and the result is never nil. The "unknown" reason is left
// out of the payload, and ResetAt is only set (as Unix seconds) when a reset
// time is known.
func (s availabilityStatus) toPayload() *availabilityPayload {
	status := s.normalized()
	payload := &availabilityPayload{State: string(status.State)}
	if status.Reason != "" && status.Reason != availabilityReasonUnknown {
		payload.Reason = string(status.Reason)
	}
	if !status.ResetAt.IsZero() {
		payload.ResetAt = status.ResetAt.Unix()
	}
	return payload
}
// unifiedRateLimitStatus is the string-typed value carried by the
// "anthropic-ratelimit-unified-status" header.
type unifiedRateLimitStatus string
// Known unifiedRateLimitStatus values.
const (
// Default status; substituted by unifiedRateLimitInfo.normalized when empty.
unifiedRateLimitStatusAllowed unifiedRateLimitStatus = "allowed"
// Requests are allowed but a usage warning threshold has been reached.
unifiedRateLimitStatusAllowedWarning unifiedRateLimitStatus = "allowed_warning"
// Set when a credential is marked rate limited.
unifiedRateLimitStatusRejected unifiedRateLimitStatus = "rejected"
)
// unifiedRateLimitInfo groups the unified rate-limit fields tracked per
// credential and emitted in aggregated form: status and reset time, the
// representative claim (the visible code uses "5h" and "7d"), whether a
// fallback is available, and the overage status, reset time and disabled
// reason.
type unifiedRateLimitInfo struct {
Status unifiedRateLimitStatus
ResetAt time.Time
RepresentativeClaim string
FallbackAvailable bool
OverageStatus string
OverageResetAt time.Time
OverageDisabledReason string
}
// normalized returns a copy of the info whose Status defaults to "allowed"
// when it was left empty. All other fields are returned unchanged.
func (s unifiedRateLimitInfo) normalized() unifiedRateLimitInfo {
	if s.Status != "" {
		return s
	}
	s.Status = unifiedRateLimitStatusAllowed
	return s
}
// claudeWindowProgress returns how far now lies into the usage window that
// ends at resetAt and lasts windowSeconds, as a fraction clamped to [0, 1].
//
// It returns 0 when resetAt is zero, when windowSeconds is not positive, or
// when now is before the start of the window.
func claudeWindowProgress(resetAt time.Time, windowSeconds float64, now time.Time) float64 {
	if windowSeconds <= 0 || resetAt.IsZero() {
		return 0
	}
	windowLength := time.Duration(windowSeconds * float64(time.Second))
	windowStart := resetAt.Add(-windowLength)
	if now.Before(windowStart) {
		return 0
	}
	fraction := now.Sub(windowStart).Seconds() / windowSeconds
	switch {
	case fraction < 0:
		return 0
	case fraction > 1:
		return 1
	}
	return fraction
}
// claudeFiveHourWarning reports whether the five-hour window is in warning
// territory: utilization of at least 90 percent with at least 72% of the
// five-hour window elapsed.
func claudeFiveHourWarning(utilizationPercent float64, resetAt time.Time, now time.Time) bool {
	const fiveHourWindowSeconds = 5 * 60 * 60
	if !(utilizationPercent >= 90) {
		return false
	}
	elapsed := claudeWindowProgress(resetAt, fiveHourWindowSeconds, now)
	return elapsed >= 0.72
}
// claudeWeeklyWarning reports whether the seven-day window is in warning
// territory. The highest utilization tier that applies decides the minimum
// window progress required:
//
//	>= 75% utilization -> progress >= 0.60
//	>= 50% utilization -> progress >= 0.35
//	>= 25% utilization -> progress >= 0.15
//
// Below 25% utilization there is never a warning.
func claudeWeeklyWarning(utilizationPercent float64, resetAt time.Time, now time.Time) bool {
	elapsed := claudeWindowProgress(resetAt, 7*24*60*60, now)
	tiers := [...]struct {
		minUtilization float64
		minProgress    float64
	}{
		{75, 0.60},
		{50, 0.35},
		{25, 0.15},
	}
	for _, tier := range tiers {
		if utilizationPercent >= tier.minUtilization {
			return elapsed >= tier.minProgress
		}
	}
	return false
}

View File

@@ -105,6 +105,10 @@ func writeCredentialUnavailableError(
writeRetryableUsageError(w, r)
return
}
if provider != nil && strings.HasPrefix(allCredentialsUnavailableError(provider.allCredentials()).Error(), "all credentials rate-limited") {
writeRetryableUsageError(w, r)
return
}
writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, fallback))
}

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"encoding/json"
"net/http"
"reflect"
"strconv"
"strings"
"time"
@@ -12,11 +13,19 @@ import (
)
type statusPayload struct {
FiveHourUtilization float64 `json:"five_hour_utilization"`
FiveHourReset int64 `json:"five_hour_reset"`
WeeklyUtilization float64 `json:"weekly_utilization"`
WeeklyReset int64 `json:"weekly_reset"`
PlanWeight float64 `json:"plan_weight"`
FiveHourUtilization float64 `json:"five_hour_utilization"`
FiveHourReset int64 `json:"five_hour_reset"`
WeeklyUtilization float64 `json:"weekly_utilization"`
WeeklyReset int64 `json:"weekly_reset"`
PlanWeight float64 `json:"plan_weight"`
UnifiedStatus string `json:"unified_status,omitempty"`
UnifiedReset int64 `json:"unified_reset,omitempty"`
RepresentativeClaim string `json:"representative_claim,omitempty"`
FallbackAvailable bool `json:"fallback_available,omitempty"`
OverageStatus string `json:"overage_status,omitempty"`
OverageReset int64 `json:"overage_reset,omitempty"`
OverageDisabledReason string `json:"overage_disabled_reason,omitempty"`
Availability *availabilityPayload `json:"availability,omitempty"`
}
type aggregatedStatus struct {
@@ -25,6 +34,8 @@ type aggregatedStatus struct {
totalWeight float64
fiveHourReset time.Time
weeklyReset time.Time
unifiedRateLimit unifiedRateLimitInfo
availability availabilityStatus
}
func resetToEpoch(t time.Time) int64 {
@@ -35,23 +46,176 @@ func resetToEpoch(t time.Time) int64 {
}
func (s aggregatedStatus) equal(other aggregatedStatus) bool {
return s.fiveHourUtilization == other.fiveHourUtilization &&
s.weeklyUtilization == other.weeklyUtilization &&
s.totalWeight == other.totalWeight &&
resetToEpoch(s.fiveHourReset) == resetToEpoch(other.fiveHourReset) &&
resetToEpoch(s.weeklyReset) == resetToEpoch(other.weeklyReset)
return reflect.DeepEqual(s.toPayload(), other.toPayload())
}
func (s aggregatedStatus) toPayload() statusPayload {
unified := s.unifiedRateLimit.normalized()
return statusPayload{
FiveHourUtilization: s.fiveHourUtilization,
FiveHourReset: resetToEpoch(s.fiveHourReset),
WeeklyUtilization: s.weeklyUtilization,
WeeklyReset: resetToEpoch(s.weeklyReset),
PlanWeight: s.totalWeight,
FiveHourUtilization: s.fiveHourUtilization,
FiveHourReset: resetToEpoch(s.fiveHourReset),
WeeklyUtilization: s.weeklyUtilization,
WeeklyReset: resetToEpoch(s.weeklyReset),
PlanWeight: s.totalWeight,
UnifiedStatus: string(unified.Status),
UnifiedReset: resetToEpoch(unified.ResetAt),
RepresentativeClaim: unified.RepresentativeClaim,
FallbackAvailable: unified.FallbackAvailable,
OverageStatus: unified.OverageStatus,
OverageReset: resetToEpoch(unified.OverageResetAt),
OverageDisabledReason: unified.OverageDisabledReason,
Availability: s.availability.toPayload(),
}
}
// aggregateInput is the per-credential input to aggregateAvailability and
// aggregateUnifiedRateLimit: one credential's availability status and its
// unified rate-limit info.
type aggregateInput struct {
availability availabilityStatus
unified unifiedRateLimitInfo
}
// aggregateAvailability folds the availability of several credentials into a
// single status.
//
// Precedence, highest first:
//  1. no inputs                      -> unavailable (no_credentials)
//  2. any input usable               -> usable
//  3. any input rate limited         -> rate_limited (hard_rate_limit) with
//     the earliest known rate-limit reset time
//  4. any input temporarily blocked  -> the first such status, with the
//     earliest known reset time among blocked inputs
//  5. any input unavailable          -> unavailable (unknown)
//  6. otherwise                      -> unknown (unknown)
//
// Rate-limited and temporarily-blocked inputs are tracked separately so that
// a mix of both never yields a blocked state carrying a rate-limit reset.
func aggregateAvailability(inputs []aggregateInput) availabilityStatus {
	if len(inputs) == 0 {
		return availabilityStatus{
			State:  availabilityStateUnavailable,
			Reason: availabilityReasonNoCredentials,
		}
	}
	var (
		earliestRateLimit time.Time
		hasRateLimited    bool
		blocked           availabilityStatus
		hasBlocked        bool
		hasUnavailable    bool
	)
	for _, input := range inputs {
		availability := input.availability.normalized()
		switch availability.State {
		case availabilityStateUsable:
			// One usable credential is enough for the aggregate to be usable.
			return availabilityStatus{State: availabilityStateUsable}
		case availabilityStateRateLimited:
			hasRateLimited = true
			if !availability.ResetAt.IsZero() && (earliestRateLimit.IsZero() || availability.ResetAt.Before(earliestRateLimit)) {
				earliestRateLimit = availability.ResetAt
			}
		case availabilityStateTemporarilyBlocked:
			if !hasBlocked {
				blocked = availability
				hasBlocked = true
			}
			if !availability.ResetAt.IsZero() && (blocked.ResetAt.IsZero() || availability.ResetAt.Before(blocked.ResetAt)) {
				blocked.ResetAt = availability.ResetAt
			}
		case availabilityStateUnavailable:
			hasUnavailable = true
		}
	}
	if hasRateLimited {
		return availabilityStatus{
			State:   availabilityStateRateLimited,
			Reason:  availabilityReasonHardRateLimit,
			ResetAt: earliestRateLimit,
		}
	}
	if hasBlocked {
		return blocked
	}
	if hasUnavailable {
		return availabilityStatus{
			State:  availabilityStateUnavailable,
			Reason: availabilityReasonUnknown,
		}
	}
	return availabilityStatus{
		State:  availabilityStateUnknown,
		Reason: availabilityReasonUnknown,
	}
}
// chooseRepresentativeClaim picks which usage window ("5h" or "7d") should
// be reported as the representative claim.
//
// Each window gets a rank: 2 when the status is rejected and the window's
// utilization is at least 100, 1 when the window is in warning territory,
// 0 otherwise. The higher rank wins; ties are broken by higher utilization,
// then by whichever window has a known reset time (five-hour first), and
// finally default to "5h".
func chooseRepresentativeClaim(status unifiedRateLimitStatus, fiveHourUtilization float64, fiveHourReset time.Time, weeklyUtilization float64, weeklyReset time.Time, now time.Time) string {
	const (
		fiveHourClaim = "5h"
		weeklyClaim   = "7d"
	)
	rank := func(utilization float64, warning bool) int {
		if status == unifiedRateLimitStatusRejected && utilization >= 100 {
			return 2
		}
		if warning {
			return 1
		}
		return 0
	}
	fiveHourRank := rank(fiveHourUtilization, claudeFiveHourWarning(fiveHourUtilization, fiveHourReset, now))
	weeklyRank := rank(weeklyUtilization, claudeWeeklyWarning(weeklyUtilization, weeklyReset, now))
	if fiveHourRank > weeklyRank {
		return fiveHourClaim
	}
	if weeklyRank > fiveHourRank {
		return weeklyClaim
	}
	if fiveHourUtilization > weeklyUtilization {
		return fiveHourClaim
	}
	if weeklyUtilization > fiveHourUtilization {
		return weeklyClaim
	}
	if !fiveHourReset.IsZero() {
		return fiveHourClaim
	}
	if !weeklyReset.IsZero() {
		return weeklyClaim
	}
	return "5h"
}
// aggregateUnifiedRateLimit builds the unified rate-limit info reported for
// a set of credentials.
//
// From the inputs it takes the first non-empty overage status (with its
// reset time and disabled reason), and marks the result rejected if any
// input is rejected, keeping the earliest known reset time among rejected
// inputs together with that input's representative claim.
//
// When no input is rejected the status is derived from the aggregate
// figures: rejected if availability is rate limited or either utilization
// is at least 100, allowed_warning if either window is in warning
// territory, allowed otherwise.
//
// FallbackAvailable is set when there is more than one input and at least
// one of them is usable. A missing representative claim is chosen by
// chooseRepresentativeClaim, and a missing reset time falls back to the
// reset of the representative window.
func aggregateUnifiedRateLimit(inputs []aggregateInput, fiveHourUtilization float64, fiveHourReset time.Time, weeklyUtilization float64, weeklyReset time.Time, availability availabilityStatus) unifiedRateLimitInfo {
	now := time.Now()
	var result unifiedRateLimitInfo
	usable := 0
	for _, input := range inputs {
		if input.availability.State == availabilityStateUsable {
			usable++
		}
		upstream := input.unified
		if result.OverageStatus == "" && upstream.OverageStatus != "" {
			result.OverageStatus = upstream.OverageStatus
			result.OverageResetAt = upstream.OverageResetAt
			result.OverageDisabledReason = upstream.OverageDisabledReason
		}
		if upstream.Status != unifiedRateLimitStatusRejected {
			continue
		}
		result.Status = unifiedRateLimitStatusRejected
		if upstream.ResetAt.IsZero() {
			continue
		}
		if result.ResetAt.IsZero() || upstream.ResetAt.Before(result.ResetAt) {
			result.ResetAt = upstream.ResetAt
			result.RepresentativeClaim = upstream.RepresentativeClaim
		}
	}
	if result.Status == "" {
		// No credential reported a rejection; derive the status from the
		// aggregated availability and utilization instead.
		if availability.State == availabilityStateRateLimited || fiveHourUtilization >= 100 || weeklyUtilization >= 100 {
			result.Status = unifiedRateLimitStatusRejected
			result.ResetAt = availability.ResetAt
		} else if claudeFiveHourWarning(fiveHourUtilization, fiveHourReset, now) || claudeWeeklyWarning(weeklyUtilization, weeklyReset, now) {
			result.Status = unifiedRateLimitStatusAllowedWarning
		} else {
			result.Status = unifiedRateLimitStatusAllowed
		}
	}
	result.FallbackAvailable = usable > 0 && len(inputs) > 1
	if result.RepresentativeClaim == "" {
		result.RepresentativeClaim = chooseRepresentativeClaim(result.Status, fiveHourUtilization, fiveHourReset, weeklyUtilization, weeklyReset, now)
	}
	if result.ResetAt.IsZero() {
		if result.RepresentativeClaim == "7d" {
			result.ResetAt = weeklyReset
		} else {
			result.ResetAt = fiveHourReset
		}
	}
	return result.normalized()
}
func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed")
@@ -171,20 +335,27 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro
}
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) aggregatedStatus {
visibleInputs := make([]aggregateInput, 0, len(provider.allCredentials()))
var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64
now := time.Now()
var totalWeightedHoursUntil5hReset, total5hResetWeight float64
var totalWeightedHoursUntilWeeklyReset, totalWeeklyResetWeight float64
var hasSnapshotData bool
for _, credential := range provider.allCredentials() {
if !credential.isUsable() {
continue
}
if userConfig != nil && userConfig.ExternalCredential != "" && credential.tagName() == userConfig.ExternalCredential {
continue
}
if userConfig != nil && !userConfig.AllowExternalUsage && credential.isExternal() {
continue
}
visibleInputs = append(visibleInputs, aggregateInput{
availability: credential.availabilityStatus(),
unified: credential.unifiedRateLimitState(),
})
if !credential.hasSnapshotData() {
continue
}
hasSnapshotData = true
weight := credential.planWeight()
remaining5h := credential.fiveHourCap() - credential.fiveHourUtilization()
if remaining5h < 0 {
@@ -215,16 +386,21 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user
}
}
}
availability := aggregateAvailability(visibleInputs)
if totalWeight == 0 {
return aggregatedStatus{
fiveHourUtilization: 100,
weeklyUtilization: 100,
result := aggregatedStatus{availability: availability}
if !hasSnapshotData {
result.fiveHourUtilization = 100
result.weeklyUtilization = 100
}
result.unifiedRateLimit = aggregateUnifiedRateLimit(visibleInputs, result.fiveHourUtilization, result.fiveHourReset, result.weeklyUtilization, result.weeklyReset, availability)
return result
}
result := aggregatedStatus{
fiveHourUtilization: 100 - totalWeightedRemaining5h/totalWeight,
weeklyUtilization: 100 - totalWeightedRemainingWeekly/totalWeight,
totalWeight: totalWeight,
availability: availability,
}
if total5hResetWeight > 0 {
avgHours := totalWeightedHoursUntil5hReset / total5hResetWeight
@@ -234,6 +410,7 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user
avgHours := totalWeightedHoursUntilWeeklyReset / totalWeeklyResetWeight
result.weeklyReset = now.Add(time.Duration(avgHours * float64(time.Hour)))
}
result.unifiedRateLimit = aggregateUnifiedRateLimit(visibleInputs, result.fiveHourUtilization, result.fiveHourReset, result.weeklyUtilization, result.weeklyReset, availability)
return result
}
@@ -254,4 +431,45 @@ func (s *Service) rewriteResponseHeaders(headers http.Header, provider credentia
if status.totalWeight > 0 {
headers.Set("X-CCM-Plan-Weight", strconv.FormatFloat(status.totalWeight, 'f', -1, 64))
}
headers.Set("anthropic-ratelimit-unified-status", string(status.unifiedRateLimit.normalized().Status))
if !status.unifiedRateLimit.ResetAt.IsZero() {
headers.Set("anthropic-ratelimit-unified-reset", strconv.FormatInt(status.unifiedRateLimit.ResetAt.Unix(), 10))
} else {
headers.Del("anthropic-ratelimit-unified-reset")
}
if status.unifiedRateLimit.RepresentativeClaim != "" {
headers.Set("anthropic-ratelimit-unified-representative-claim", status.unifiedRateLimit.RepresentativeClaim)
} else {
headers.Del("anthropic-ratelimit-unified-representative-claim")
}
if status.unifiedRateLimit.FallbackAvailable {
headers.Set("anthropic-ratelimit-unified-fallback", "available")
} else {
headers.Del("anthropic-ratelimit-unified-fallback")
}
if status.unifiedRateLimit.OverageStatus != "" {
headers.Set("anthropic-ratelimit-unified-overage-status", status.unifiedRateLimit.OverageStatus)
} else {
headers.Del("anthropic-ratelimit-unified-overage-status")
}
if !status.unifiedRateLimit.OverageResetAt.IsZero() {
headers.Set("anthropic-ratelimit-unified-overage-reset", strconv.FormatInt(status.unifiedRateLimit.OverageResetAt.Unix(), 10))
} else {
headers.Del("anthropic-ratelimit-unified-overage-reset")
}
if status.unifiedRateLimit.OverageDisabledReason != "" {
headers.Set("anthropic-ratelimit-unified-overage-disabled-reason", status.unifiedRateLimit.OverageDisabledReason)
} else {
headers.Del("anthropic-ratelimit-unified-overage-disabled-reason")
}
if claudeFiveHourWarning(status.fiveHourUtilization, status.fiveHourReset, time.Now()) || status.fiveHourUtilization >= 100 {
headers.Set("anthropic-ratelimit-unified-5h-surpassed-threshold", "true")
} else {
headers.Del("anthropic-ratelimit-unified-5h-surpassed-threshold")
}
if claudeWeeklyWarning(status.weeklyUtilization, status.weeklyReset, time.Now()) || status.weeklyUtilization >= 100 {
headers.Set("anthropic-ratelimit-unified-7d-surpassed-threshold", "true")
} else {
headers.Del("anthropic-ratelimit-unified-7d-surpassed-threshold")
}
}

View File

@@ -0,0 +1,173 @@
package ccm
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/sagernet/sing/common/observable"
)
// testCredential is a test double used as a Credential. Its fields hold the
// canned values that the corresponding accessor methods return.
type testCredential struct {
tag string
external bool
available bool
usable bool
hasData bool
// Utilization figures returned by fiveHourUtilization / weeklyUtilization.
fiveHour float64
weekly float64
// Caps returned by fiveHourCap / weeklyCap.
fiveHourCapV float64
weeklyCapV float64
weight float64
fiveReset time.Time
weeklyReset time.Time
availability availabilityStatus
unified unifiedRateLimitInfo
}
// Accessors below return the matching testCredential field unchanged.
func (c *testCredential) tagName() string { return c.tag }
func (c *testCredential) isAvailable() bool { return c.available }
func (c *testCredential) isUsable() bool { return c.usable }
func (c *testCredential) isExternal() bool { return c.external }
func (c *testCredential) hasSnapshotData() bool { return c.hasData }
func (c *testCredential) fiveHourUtilization() float64 { return c.fiveHour }
func (c *testCredential) weeklyUtilization() float64 { return c.weekly }
func (c *testCredential) fiveHourCap() float64 { return c.fiveHourCapV }
func (c *testCredential) weeklyCap() float64 { return c.weeklyCapV }
func (c *testCredential) planWeight() float64 { return c.weight }
func (c *testCredential) fiveHourResetTime() time.Time { return c.fiveReset }
func (c *testCredential) weeklyResetTime() time.Time { return c.weeklyReset }
// State-changing hooks are no-ops in the test double.
func (c *testCredential) markRateLimited(time.Time) {}
func (c *testCredential) markUpstreamRejected() {}
func (c *testCredential) availabilityStatus() availabilityStatus { return c.availability }
func (c *testCredential) unifiedRateLimitState() unifiedRateLimitInfo { return c.unified }
// earliestReset returns the five-hour reset time.
func (c *testCredential) earliestReset() time.Time { return c.fiveReset }
// The remaining methods are stubs returning zero values (lastUpdatedTime
// returns the current time).
func (c *testCredential) unavailableError() error { return nil }
func (c *testCredential) getAccessToken() (string, error) { return "", nil }
func (c *testCredential) buildProxyRequest(context.Context, *http.Request, []byte, http.Header) (*http.Request, error) {
return nil, nil
}
func (c *testCredential) updateStateFromHeaders(http.Header) {}
func (c *testCredential) wrapRequestContext(context.Context) *credentialRequestContext { return nil }
func (c *testCredential) interruptConnections() {}
func (c *testCredential) setStatusSubscriber(*observable.Subscriber[struct{}]) {}
func (c *testCredential) start() error { return nil }
func (c *testCredential) pollUsage() {}
func (c *testCredential) lastUpdatedTime() time.Time { return time.Now() }
func (c *testCredential) pollBackoff(time.Duration) time.Duration { return 0 }
func (c *testCredential) usageTrackerOrNil() *AggregatedUsage { return nil }
func (c *testCredential) httpClient() *http.Client { return nil }
func (c *testCredential) close() {}
// testProvider is a minimal credential provider for tests: it only exposes a
// fixed credential list through allCredentials. Selection and rate-limit
// handling return zero values and the polling hooks do nothing.
type testProvider struct {
credentials []Credential
}
func (p *testProvider) selectCredential(string, credentialSelection) (Credential, bool, error) {
return nil, false, nil
}
func (p *testProvider) onRateLimited(string, Credential, time.Time, credentialSelection) Credential {
return nil
}
// linkProviderInterrupt returns a function that always reports true.
func (p *testProvider) linkProviderInterrupt(Credential, credentialSelection, func()) func() bool {
return func() bool { return true }
}
func (p *testProvider) pollIfStale() {}
func (p *testProvider) pollCredentialIfStale(Credential) {}
func (p *testProvider) allCredentials() []Credential { return p.credentials }
func (p *testProvider) close() {}
func TestComputeAggregatedUtilizationPreservesSnapshotForRateLimitedCredential(t *testing.T) {
t.Parallel()
reset := time.Now().Add(15 * time.Minute)
service := &Service{}
status := service.computeAggregatedUtilization(&testProvider{credentials: []Credential{
&testCredential{
tag: "a",
available: true,
usable: false,
hasData: true,
fiveHour: 42,
weekly: 18,
fiveHourCapV: 100,
weeklyCapV: 100,
weight: 1,
fiveReset: reset,
weeklyReset: reset.Add(2 * time.Hour),
availability: availabilityStatus{State: availabilityStateRateLimited, Reason: availabilityReasonHardRateLimit, ResetAt: reset},
unified: unifiedRateLimitInfo{Status: unifiedRateLimitStatusRejected, ResetAt: reset, RepresentativeClaim: "5h"},
},
}}, nil)
if status.fiveHourUtilization != 42 || status.weeklyUtilization != 18 {
t.Fatalf("expected preserved utilization, got 5h=%v weekly=%v", status.fiveHourUtilization, status.weeklyUtilization)
}
if status.unifiedRateLimit.Status != unifiedRateLimitStatusRejected {
t.Fatalf("expected rejected unified status, got %q", status.unifiedRateLimit.Status)
}
if status.availability.State != availabilityStateRateLimited {
t.Fatalf("expected rate-limited availability, got %#v", status.availability)
}
}
func TestRewriteResponseHeadersIncludesUnifiedHeaders(t *testing.T) {
t.Parallel()
reset := time.Now().Add(80 * time.Minute)
service := &Service{}
headers := make(http.Header)
service.rewriteResponseHeaders(headers, &testProvider{credentials: []Credential{
&testCredential{
tag: "a",
available: true,
usable: true,
hasData: true,
fiveHour: 92,
weekly: 30,
fiveHourCapV: 100,
weeklyCapV: 100,
weight: 1,
fiveReset: reset,
weeklyReset: time.Now().Add(4 * 24 * time.Hour),
availability: availabilityStatus{State: availabilityStateUsable},
},
}}, nil)
if headers.Get("anthropic-ratelimit-unified-status") != "allowed_warning" {
t.Fatalf("expected allowed_warning, got %q", headers.Get("anthropic-ratelimit-unified-status"))
}
if headers.Get("anthropic-ratelimit-unified-representative-claim") != "5h" {
t.Fatalf("expected 5h representative claim, got %q", headers.Get("anthropic-ratelimit-unified-representative-claim"))
}
if headers.Get("anthropic-ratelimit-unified-5h-surpassed-threshold") != "true" {
t.Fatalf("expected 5h threshold header")
}
}
func TestWriteCredentialUnavailableErrorReturns429ForRateLimitedCredentials(t *testing.T) {
t.Parallel()
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, "/v1/messages", nil)
provider := &testProvider{credentials: []Credential{
&testCredential{
tag: "a",
available: true,
usable: false,
hasData: true,
fiveHourCapV: 100,
weeklyCapV: 100,
weight: 1,
availability: availabilityStatus{State: availabilityStateRateLimited, Reason: availabilityReasonHardRateLimit, ResetAt: time.Now().Add(time.Minute)},
},
}}
writeCredentialUnavailableError(recorder, request, provider, provider.credentials[0], credentialSelection{}, "all credentials rate-limited")
if recorder.Code != http.StatusTooManyRequests {
t.Fatalf("expected 429, got %d", recorder.Code)
}
}

View File

@@ -61,8 +61,14 @@ type credentialState struct {
weeklyReset time.Time
hardRateLimited bool
rateLimitResetAt time.Time
availabilityState availabilityState
availabilityReason availabilityReason
availabilityResetAt time.Time
lastKnownDataAt time.Time
accountType string
remotePlanWeight float64
activeLimitID string
rateLimitSnapshots map[string]rateLimitSnapshot
lastUpdated time.Time
consecutivePollFailures int
usageAPIRetryDelay time.Duration
@@ -102,6 +108,7 @@ type Credential interface {
isAvailable() bool
isUsable() bool
isExternal() bool
hasSnapshotData() bool
fiveHourUtilization() float64
weeklyUtilization() float64
fiveHourCap() float64
@@ -111,6 +118,10 @@ type Credential interface {
fiveHourResetTime() time.Time
markRateLimited(resetAt time.Time)
markUpstreamRejected()
markTemporarilyBlocked(reason availabilityReason, resetAt time.Time)
availabilityStatus() availabilityStatus
rateLimitSnapshots() []rateLimitSnapshot
activeLimitID() string
earliestReset() time.Time
unavailableError() error
@@ -200,3 +211,67 @@ func parseOCMRateLimitResetFromHeaders(headers http.Header) time.Time {
}
return time.Now().Add(5 * time.Minute)
}
// noteSnapshotData stamps lastKnownDataAt with the current time, recording
// that usage data has been received for this credential.
func (s *credentialState) noteSnapshotData() {
s.lastKnownDataAt = time.Now()
}
// hasSnapshotData reports whether the state holds any usage information: a
// recorded data timestamp, a positive five-hour or weekly utilization, a
// non-zero reset time for either window, or at least one stored rate-limit
// snapshot.
func (s credentialState) hasSnapshotData() bool {
	switch {
	case !s.lastKnownDataAt.IsZero():
		return true
	case s.fiveHourUtilization > 0, s.weeklyUtilization > 0:
		return true
	case !s.fiveHourReset.IsZero(), !s.weeklyReset.IsZero():
		return true
	}
	return len(s.rateLimitSnapshots) > 0
}
// setAvailability overwrites the stored availability state, reason and reset
// time in one step. It does not touch hardRateLimited, rateLimitResetAt or
// upstreamRejectedUntil; callers update those separately.
func (s *credentialState) setAvailability(state availabilityState, reason availabilityReason, resetAt time.Time) {
s.availabilityState = state
s.availabilityReason = reason
s.availabilityResetAt = resetAt
}
// currentAvailability derives the effective availability from the state
// fields, evaluated against the current time. The first matching case wins,
// so the case order is the precedence:
//
//  1. unavailable flag set -> unavailable
//  2. explicit temporary block (no expiry or not expired) -> temporarily blocked
//  3. hard rate limit (no expiry or not expired) -> rate limited
//  4. upstream rejection window still open -> temporarily blocked
//  5. at least one consecutive poll failure -> temporarily blocked
//  6. otherwise -> usable
func (s credentialState) currentAvailability() availabilityStatus {
now := time.Now()
switch {
case s.unavailable:
return availabilityStatus{
State: availabilityStateUnavailable,
Reason: availabilityReasonUnknown,
}
// A zero reset time means the block has no expiry and stays in force until
// the stored state is overwritten.
case s.availabilityState == availabilityStateTemporarilyBlocked &&
(s.availabilityResetAt.IsZero() || now.Before(s.availabilityResetAt)):
reason := s.availabilityReason
if reason == "" {
reason = availabilityReasonUnknown
}
return availabilityStatus{
State: availabilityStateTemporarilyBlocked,
Reason: reason,
ResetAt: s.availabilityResetAt,
}
case s.hardRateLimited && (s.rateLimitResetAt.IsZero() || now.Before(s.rateLimitResetAt)):
// The stored reason is reused when present; it defaults to the hard
// rate-limit reason otherwise.
reason := s.availabilityReason
if reason == "" {
reason = availabilityReasonHardRateLimit
}
return availabilityStatus{
State: availabilityStateRateLimited,
Reason: reason,
ResetAt: s.rateLimitResetAt,
}
case !s.upstreamRejectedUntil.IsZero() && now.Before(s.upstreamRejectedUntil):
return availabilityStatus{
State: availabilityStateTemporarilyBlocked,
Reason: availabilityReasonUpstreamRejected,
ResetAt: s.upstreamRejectedUntil,
}
// Poll failures carry no reset time.
case s.consecutivePollFailures > 0:
return availabilityStatus{
State: availabilityStateTemporarilyBlocked,
Reason: availabilityReasonPollFailed,
}
default:
return availabilityStatus{State: availabilityStateUsable}
}
}

View File

@@ -359,9 +359,14 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) {
}
}
}
if snapshots := parseRateLimitSnapshotsFromHeaders(headers); len(snapshots) > 0 {
hadData = true
applyRateLimitSnapshotsLocked(&c.state, snapshots, headers.Get("x-codex-active-limit"), c.state.remotePlanWeight, c.state.accountType)
}
if hadData {
c.state.consecutivePollFailures = 0
c.state.lastUpdated = time.Now()
c.state.noteSnapshotData()
}
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
resetSuffix := ""
@@ -386,6 +391,7 @@ func (c *defaultCredential) markRateLimited(resetAt time.Time) {
c.stateAccess.Lock()
c.state.hardRateLimited = true
c.state.rateLimitResetAt = resetAt
c.state.setAvailability(availabilityStateRateLimited, availabilityReasonHardRateLimit, resetAt)
shouldInterrupt := c.checkTransitionLocked()
c.stateAccess.Unlock()
if shouldInterrupt {
@@ -396,6 +402,17 @@ func (c *defaultCredential) markRateLimited(resetAt time.Time) {
// markUpstreamRejected is a no-op for defaultCredential.
func (c *defaultCredential) markUpstreamRejected() {}
// markTemporarilyBlocked records a temporary block with the given reason and
// reset time, then emits a status update. The state change and the transition
// check happen under the write lock; connections are interrupted only after
// the lock is released, and only when checkTransitionLocked reports true.
func (c *defaultCredential) markTemporarilyBlocked(reason availabilityReason, resetAt time.Time) {
c.stateAccess.Lock()
c.state.setAvailability(availabilityStateTemporarilyBlocked, reason, resetAt)
shouldInterrupt := c.checkTransitionLocked()
c.stateAccess.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
c.emitStatusUpdate()
}
func (c *defaultCredential) isUsable() bool {
c.retryCredentialReloadIfNeeded()
@@ -483,6 +500,12 @@ func (c *defaultCredential) fiveHourUtilization() float64 {
return c.state.fiveHourUtilization
}
// hasSnapshotData reports, under the read lock, whether the credential state
// holds any usage data (see credentialState.hasSnapshotData).
func (c *defaultCredential) hasSnapshotData() bool {
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.hasSnapshotData()
}
func (c *defaultCredential) weeklyUtilization() float64 {
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
@@ -515,6 +538,32 @@ func (c *defaultCredential) isAvailable() bool {
return !c.state.unavailable
}
// availabilityStatus returns the availability derived from the current state,
// computed under the read lock.
func (c *defaultCredential) availabilityStatus() availabilityStatus {
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.currentAvailability()
}
// rateLimitSnapshots returns clones of all stored rate-limit snapshots in the
// order produced by sortRateLimitSnapshots, or nil when none are stored. The
// stored map is read under the read lock.
func (c *defaultCredential) rateLimitSnapshots() []rateLimitSnapshot {
	c.stateAccess.RLock()
	defer c.stateAccess.RUnlock()

	stored := c.state.rateLimitSnapshots
	if len(stored) == 0 {
		return nil
	}
	result := make([]rateLimitSnapshot, 0, len(stored))
	for limitID := range stored {
		result = append(result, cloneRateLimitSnapshot(stored[limitID]))
	}
	sortRateLimitSnapshots(result)
	return result
}
// activeLimitID returns the stored active limit identifier, read under the
// read lock.
func (c *defaultCredential) activeLimitID() string {
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.activeLimitID
}
func (c *defaultCredential) unavailableError() error {
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
@@ -542,6 +591,7 @@ func (c *defaultCredential) markUsagePollAttempted() {
func (c *defaultCredential) incrementPollFailures() {
c.stateAccess.Lock()
c.state.consecutivePollFailures++
c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReasonPollFailed, time.Time{})
shouldInterrupt := c.checkTransitionLocked()
c.stateAccess.Unlock()
if shouldInterrupt {
@@ -696,17 +746,7 @@ func (c *defaultCredential) pollUsage() {
return
}
type usageWindow struct {
UsedPercent float64 `json:"used_percent"`
ResetAt int64 `json:"reset_at"`
}
var usageResponse struct {
PlanType string `json:"plan_type"`
RateLimit *struct {
PrimaryWindow *usageWindow `json:"primary_window"`
SecondaryWindow *usageWindow `json:"secondary_window"`
} `json:"rate_limit"`
}
var usageResponse usageRateLimitStatusPayload
err = json.NewDecoder(response.Body).Decode(&usageResponse)
if err != nil {
c.logger.Debug("poll usage for ", c.tag, ": decode: ", err)
@@ -720,26 +760,11 @@ func (c *defaultCredential) pollUsage() {
oldWeekly := c.state.weeklyUtilization
c.state.consecutivePollFailures = 0
c.state.usageAPIRetryDelay = 0
if usageResponse.RateLimit != nil {
if w := usageResponse.RateLimit.PrimaryWindow; w != nil {
c.state.fiveHourUtilization = w.UsedPercent
if w.ResetAt > 0 {
c.state.fiveHourReset = time.Unix(w.ResetAt, 0)
}
}
if w := usageResponse.RateLimit.SecondaryWindow; w != nil {
c.state.weeklyUtilization = w.UsedPercent
if w.ResetAt > 0 {
c.state.weeklyReset = time.Unix(w.ResetAt, 0)
}
}
}
if usageResponse.PlanType != "" {
c.state.accountType = usageResponse.PlanType
}
applyRateLimitSnapshotsLocked(&c.state, snapshotsFromUsagePayload(usageResponse), c.state.activeLimitID, c.state.remotePlanWeight, usageResponse.PlanType)
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
c.state.hardRateLimited = false
}
c.state.noteSnapshotData()
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
resetSuffix := ""
if !c.state.weeklyReset.IsZero() {

View File

@@ -367,6 +367,7 @@ func (c *externalCredential) markRateLimited(resetAt time.Time) {
c.stateAccess.Lock()
c.state.hardRateLimited = true
c.state.rateLimitResetAt = resetAt
c.state.setAvailability(availabilityStateRateLimited, availabilityReasonHardRateLimit, resetAt)
shouldInterrupt := c.checkTransitionLocked()
c.stateAccess.Unlock()
if shouldInterrupt {
@@ -379,6 +380,18 @@ func (c *externalCredential) markUpstreamRejected() {
c.logger.Warn("upstream rejected credential ", c.tag, ", marking unavailable for ", log.FormatDuration(c.pollInterval))
c.stateAccess.Lock()
c.state.upstreamRejectedUntil = time.Now().Add(c.pollInterval)
c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReasonUpstreamRejected, c.state.upstreamRejectedUntil)
shouldInterrupt := c.checkTransitionLocked()
c.stateAccess.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
c.emitStatusUpdate()
}
func (c *externalCredential) markTemporarilyBlocked(reason availabilityReason, resetAt time.Time) {
c.stateAccess.Lock()
c.state.setAvailability(availabilityStateTemporarilyBlocked, reason, resetAt)
shouldInterrupt := c.checkTransitionLocked()
c.stateAccess.Unlock()
if shouldInterrupt {
@@ -528,10 +541,15 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) {
c.state.remotePlanWeight = value
}
}
if snapshots := parseRateLimitSnapshotsFromHeaders(headers); len(snapshots) > 0 {
hadData = true
applyRateLimitSnapshotsLocked(&c.state, snapshots, headers.Get("x-codex-active-limit"), c.state.remotePlanWeight, c.state.accountType)
}
if hadData {
c.state.consecutivePollFailures = 0
c.state.upstreamRejectedUntil = time.Time{}
c.state.lastUpdated = time.Now()
c.state.noteSnapshotData()
}
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
resetSuffix := ""
@@ -670,20 +688,14 @@ func (c *externalCredential) pollUsage() {
c.clearPollFailures()
return
}
if rawFields["five_hour_utilization"] == nil || rawFields["five_hour_reset"] == nil ||
if rawFields["limits"] == nil && (rawFields["five_hour_utilization"] == nil || rawFields["five_hour_reset"] == nil ||
rawFields["weekly_utilization"] == nil || rawFields["weekly_reset"] == nil ||
rawFields["plan_weight"] == nil {
rawFields["plan_weight"] == nil) {
c.logger.Error("poll usage for ", c.tag, ": invalid response")
c.clearPollFailures()
return
}
var statusResponse struct {
FiveHourUtilization float64 `json:"five_hour_utilization"`
FiveHourReset int64 `json:"five_hour_reset"`
WeeklyUtilization float64 `json:"weekly_utilization"`
WeeklyReset int64 `json:"weekly_reset"`
PlanWeight float64 `json:"plan_weight"`
}
var statusResponse statusPayload
err = json.Unmarshal(body, &statusResponse)
if err != nil {
c.logger.Debug("poll usage for ", c.tag, ": decode: ", err)
@@ -697,16 +709,38 @@ func (c *externalCredential) pollUsage() {
oldWeekly := c.state.weeklyUtilization
c.state.consecutivePollFailures = 0
c.state.upstreamRejectedUntil = time.Time{}
c.state.fiveHourUtilization = statusResponse.FiveHourUtilization
c.state.weeklyUtilization = statusResponse.WeeklyUtilization
if statusResponse.FiveHourReset > 0 {
c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0)
if len(statusResponse.Limits) > 0 {
applyRateLimitSnapshotsLocked(&c.state, statusResponse.Limits, statusResponse.ActiveLimit, statusResponse.PlanWeight, c.state.accountType)
} else {
c.state.fiveHourUtilization = statusResponse.FiveHourUtilization
c.state.weeklyUtilization = statusResponse.WeeklyUtilization
if statusResponse.FiveHourReset > 0 {
c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0)
}
if statusResponse.WeeklyReset > 0 {
c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0)
}
if statusResponse.PlanWeight > 0 {
c.state.remotePlanWeight = statusResponse.PlanWeight
}
}
if statusResponse.WeeklyReset > 0 {
c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0)
}
if statusResponse.PlanWeight > 0 {
c.state.remotePlanWeight = statusResponse.PlanWeight
if statusResponse.Availability != nil {
switch availabilityState(statusResponse.Availability.State) {
case availabilityStateRateLimited:
c.state.hardRateLimited = true
if statusResponse.Availability.ResetAt > 0 {
c.state.rateLimitResetAt = time.Unix(statusResponse.Availability.ResetAt, 0)
}
case availabilityStateTemporarilyBlocked:
resetAt := time.Time{}
if statusResponse.Availability.ResetAt > 0 {
resetAt = time.Unix(statusResponse.Availability.ResetAt, 0)
}
c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReason(statusResponse.Availability.Reason), resetAt)
if availabilityReason(statusResponse.Availability.Reason) == availabilityReasonUpstreamRejected && !resetAt.IsZero() {
c.state.upstreamRejectedUntil = resetAt
}
}
}
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
c.state.hardRateLimited = false
@@ -787,9 +821,9 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr
result.duration = time.Since(startTime)
return result, E.Cause(err, "decode status frame")
}
if rawFields["five_hour_utilization"] == nil || rawFields["five_hour_reset"] == nil ||
if rawFields["limits"] == nil && (rawFields["five_hour_utilization"] == nil || rawFields["five_hour_reset"] == nil ||
rawFields["weekly_utilization"] == nil || rawFields["weekly_reset"] == nil ||
rawFields["plan_weight"] == nil {
rawFields["plan_weight"] == nil) {
result.duration = time.Since(startTime)
return result, E.New("invalid response")
}
@@ -806,16 +840,38 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr
oldWeekly := c.state.weeklyUtilization
c.state.consecutivePollFailures = 0
c.state.upstreamRejectedUntil = time.Time{}
c.state.fiveHourUtilization = statusResponse.FiveHourUtilization
c.state.weeklyUtilization = statusResponse.WeeklyUtilization
if statusResponse.FiveHourReset > 0 {
c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0)
if len(statusResponse.Limits) > 0 {
applyRateLimitSnapshotsLocked(&c.state, statusResponse.Limits, statusResponse.ActiveLimit, statusResponse.PlanWeight, c.state.accountType)
} else {
c.state.fiveHourUtilization = statusResponse.FiveHourUtilization
c.state.weeklyUtilization = statusResponse.WeeklyUtilization
if statusResponse.FiveHourReset > 0 {
c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0)
}
if statusResponse.WeeklyReset > 0 {
c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0)
}
if statusResponse.PlanWeight > 0 {
c.state.remotePlanWeight = statusResponse.PlanWeight
}
}
if statusResponse.WeeklyReset > 0 {
c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0)
}
if statusResponse.PlanWeight > 0 {
c.state.remotePlanWeight = statusResponse.PlanWeight
if statusResponse.Availability != nil {
switch availabilityState(statusResponse.Availability.State) {
case availabilityStateRateLimited:
c.state.hardRateLimited = true
if statusResponse.Availability.ResetAt > 0 {
c.state.rateLimitResetAt = time.Unix(statusResponse.Availability.ResetAt, 0)
}
case availabilityStateTemporarilyBlocked:
resetAt := time.Time{}
if statusResponse.Availability.ResetAt > 0 {
resetAt = time.Unix(statusResponse.Availability.ResetAt, 0)
}
c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReason(statusResponse.Availability.Reason), resetAt)
if availabilityReason(statusResponse.Availability.Reason) == availabilityReasonUpstreamRejected && !resetAt.IsZero() {
c.state.upstreamRejectedUntil = resetAt
}
}
}
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
c.state.hardRateLimited = false
@@ -888,6 +944,38 @@ func (c *externalCredential) lastUpdatedTime() time.Time {
return c.state.lastUpdated
}
// hasSnapshotData reports, under the read lock, whether the credential state
// holds any usage data (see credentialState.hasSnapshotData).
func (c *externalCredential) hasSnapshotData() bool {
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.hasSnapshotData()
}
// availabilityStatus returns the availability derived from the current state,
// computed under the read lock.
func (c *externalCredential) availabilityStatus() availabilityStatus {
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.currentAvailability()
}
// rateLimitSnapshots returns clones of all stored rate-limit snapshots in the
// order produced by sortRateLimitSnapshots, or nil when none are stored. The
// stored map is read under the read lock.
func (c *externalCredential) rateLimitSnapshots() []rateLimitSnapshot {
	c.stateAccess.RLock()
	defer c.stateAccess.RUnlock()

	stored := c.state.rateLimitSnapshots
	if len(stored) == 0 {
		return nil
	}
	result := make([]rateLimitSnapshot, 0, len(stored))
	for limitID := range stored {
		result = append(result, cloneRateLimitSnapshot(stored[limitID]))
	}
	sortRateLimitSnapshots(result)
	return result
}
// activeLimitID returns the stored active limit identifier, read under the
// read lock.
func (c *externalCredential) activeLimitID() string {
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
return c.state.activeLimitID
}
func (c *externalCredential) markUsageStreamUpdated() {
c.stateAccess.Lock()
defer c.stateAccess.Unlock()

View File

@@ -0,0 +1,384 @@
package ocm
import (
"net/http"
"slices"
"strconv"
"strings"
"time"
)
// availabilityState classifies whether a credential can currently serve
// requests. The string values are what toPayload writes into the "state"
// field of availabilityPayload.
type availabilityState string
const (
availabilityStateUsable availabilityState = "usable"
availabilityStateRateLimited availabilityState = "rate_limited"
availabilityStateTemporarilyBlocked availabilityState = "temporarily_blocked"
availabilityStateUnavailable availabilityState = "unavailable"
// availabilityStateUnknown is what normalized substitutes for an empty state.
availabilityStateUnknown availabilityState = "unknown"
)
// availabilityReason explains why a credential is in a non-usable
// availabilityState. The string values are what toPayload writes into the
// "reason" field of availabilityPayload.
type availabilityReason string
const (
availabilityReasonHardRateLimit availabilityReason = "hard_rate_limit"
availabilityReasonConnectionLimit availabilityReason = "connection_limit"
availabilityReasonPollFailed availabilityReason = "poll_failed"
availabilityReasonUpstreamRejected availabilityReason = "upstream_rejected"
availabilityReasonNoCredentials availabilityReason = "no_credentials"
// availabilityReasonUnknown is the normalized placeholder for a missing
// reason; toPayload omits it from the serialized form.
availabilityReasonUnknown availabilityReason = "unknown"
)
// availabilityStatus is the in-memory availability of a credential: its
// state, the reason for it, and the time the condition is expected to clear.
// A zero ResetAt means no reset time is known.
type availabilityStatus struct {
State availabilityState
Reason availabilityReason
ResetAt time.Time
}
// availabilityPayload is the JSON form of availabilityStatus. ResetAt is in
// Unix seconds; Reason and ResetAt are omitted when empty or zero.
type availabilityPayload struct {
State string `json:"state"`
Reason string `json:"reason,omitempty"`
ResetAt int64 `json:"reset_at,omitempty"`
}
// normalized returns a copy of the status with blanks filled in: an empty
// State becomes availabilityStateUnknown, and an empty Reason becomes
// availabilityReasonUnknown unless the (resulting) state is usable.
func (s availabilityStatus) normalized() availabilityStatus {
	result := s
	if result.State == "" {
		result.State = availabilityStateUnknown
	}
	if result.State != availabilityStateUsable && result.Reason == "" {
		result.Reason = availabilityReasonUnknown
	}
	return result
}
// toPayload converts the normalized status into its JSON payload. The reason
// is left empty when it is missing or unknown, and ResetAt is only set (as
// Unix seconds) when a reset time is known.
func (s availabilityStatus) toPayload() *availabilityPayload {
	status := s.normalized()

	var reason string
	if status.Reason != availabilityReasonUnknown {
		// An empty reason converts to the empty string, matching "not set".
		reason = string(status.Reason)
	}
	var resetAt int64
	if !status.ResetAt.IsZero() {
		resetAt = status.ResetAt.Unix()
	}
	return &availabilityPayload{
		State:   string(status.State),
		Reason:  reason,
		ResetAt: resetAt,
	}
}
// creditsSnapshot holds the credit information attached to a rate-limit
// snapshot. Balance is kept as the raw string taken from the source (header
// or usage payload) and is omitted from JSON when empty.
type creditsSnapshot struct {
HasCredits bool `json:"has_credits"`
Unlimited bool `json:"unlimited"`
Balance string `json:"balance,omitempty"`
}
// rateLimitWindow describes one rate-limit window: the used percentage, the
// window length in minutes and the reset time in Unix seconds. Zero
// WindowMinutes/ResetAt mean "not provided" and are omitted from JSON.
type rateLimitWindow struct {
UsedPercent float64 `json:"used_percent"`
WindowMinutes int64 `json:"window_minutes,omitempty"`
ResetAt int64 `json:"reset_at,omitempty"`
}
// rateLimitSnapshot is the rate-limit state of a single limit, identified by
// LimitID. Primary and Secondary are the two windows (mapped to the legacy
// five-hour and weekly fields by applyRateLimitSnapshotsLocked); any of the
// pointer fields may be nil when the source did not report them.
type rateLimitSnapshot struct {
LimitID string `json:"limit_id,omitempty"`
LimitName string `json:"limit_name,omitempty"`
Primary *rateLimitWindow `json:"primary,omitempty"`
Secondary *rateLimitWindow `json:"secondary,omitempty"`
Credits *creditsSnapshot `json:"credits,omitempty"`
PlanType string `json:"plan_type,omitempty"`
}
// normalizeStoredLimitID returns the storage form of a limit identifier: the
// result of normalizeRateLimitIdentifier with every "-" replaced by "_". An
// identifier that normalizes to the empty string yields "".
func normalizeStoredLimitID(limitID string) string {
	if id := normalizeRateLimitIdentifier(limitID); id != "" {
		return strings.ReplaceAll(id, "-", "_")
	}
	return ""
}
// headerLimitID returns the form of a limit identifier used inside header
// names: the stored form with "_" replaced by "-". An empty identifier maps
// to "codex".
func headerLimitID(limitID string) string {
	switch limitID {
	case "":
		return "codex"
	default:
		stored := normalizeStoredLimitID(limitID)
		return strings.ReplaceAll(stored, "_", "-")
	}
}
// defaultRateLimitSnapshot returns an otherwise empty snapshot whose LimitID
// is the stored form of limitID; an empty limitID is treated as "codex".
func defaultRateLimitSnapshot(limitID string) rateLimitSnapshot {
	id := limitID
	if id == "" {
		id = "codex"
	}
	var snapshot rateLimitSnapshot
	snapshot.LimitID = normalizeStoredLimitID(id)
	return snapshot
}
// cloneCreditsSnapshot returns a pointer to a copy of *snapshot, or nil when
// snapshot is nil.
func cloneCreditsSnapshot(snapshot *creditsSnapshot) *creditsSnapshot {
	if snapshot != nil {
		copied := *snapshot
		return &copied
	}
	return nil
}
// cloneRateLimitWindow returns a pointer to a copy of *window, or nil when
// window is nil.
func cloneRateLimitWindow(window *rateLimitWindow) *rateLimitWindow {
	if window != nil {
		copied := *window
		return &copied
	}
	return nil
}
// cloneRateLimitSnapshot returns a copy of snapshot whose Primary, Secondary
// and Credits pointers refer to fresh copies, so the result shares no
// pointed-to data with the argument.
func cloneRateLimitSnapshot(snapshot rateLimitSnapshot) rateLimitSnapshot {
	cloned := snapshot
	cloned.Primary = cloneRateLimitWindow(snapshot.Primary)
	cloned.Secondary = cloneRateLimitWindow(snapshot.Secondary)
	cloned.Credits = cloneCreditsSnapshot(snapshot.Credits)
	return cloned
}
// sortRateLimitSnapshots sorts snapshots in place by ascending LimitID.
func sortRateLimitSnapshots(snapshots []rateLimitSnapshot) {
	byLimitID := func(left, right rateLimitSnapshot) int {
		return strings.Compare(left.LimitID, right.LimitID)
	}
	slices.SortFunc(snapshots, byLimitID)
}
// parseHeaderFloat reads the named header, trims surrounding whitespace and
// parses it as a 64-bit float. The boolean result is false when the header is
// missing or blank, fails to parse, or parses to NaN or an infinity
// (strconv.ParseFloat accepts the literals "nan" and "inf" without error).
func parseHeaderFloat(headers http.Header, name string) (float64, bool) {
	value := strings.TrimSpace(headers.Get(name))
	if value == "" {
		return 0, false
	}
	parsed, err := strconv.ParseFloat(value, 64)
	if err != nil || !isFinite(parsed) {
		return 0, false
	}
	return parsed, true
}

// isFinite reports whether value is neither NaN nor ±Inf.
//
// For every finite float x-x is exactly 0, while Inf-Inf and NaN-NaN are both
// NaN (which never compares equal to 0). This avoids a magnitude cutoff that
// would misclassify large finite values such as 1.5e308 as infinite.
func isFinite(value float64) bool {
	return value-value == 0
}
// parseCreditsSnapshotFromHeaders builds a creditsSnapshot from the
// x-codex-credits-* headers. It returns nil unless both the has-credits and
// unlimited headers are present and non-blank. A flag counts as true when its
// value is "1" or equals "true" ignoring case; the balance is copied as a
// trimmed string.
func parseCreditsSnapshotFromHeaders(headers http.Header) *creditsSnapshot {
	trimmed := func(name string) string {
		return strings.TrimSpace(headers.Get(name))
	}
	truthy := func(raw string) bool {
		return strings.EqualFold(raw, "true") || raw == "1"
	}

	hasCreditsRaw := trimmed("x-codex-credits-has-credits")
	unlimitedRaw := trimmed("x-codex-credits-unlimited")
	if hasCreditsRaw == "" || unlimitedRaw == "" {
		return nil
	}
	return &creditsSnapshot{
		HasCredits: truthy(hasCreditsRaw),
		Unlimited:  truthy(unlimitedRaw),
		Balance:    trimmed("x-codex-credits-balance"),
	}
}
// parseRateLimitWindowFromHeaders reads one rate-limit window from the
// headers "<prefix>-<windowName>-used-percent", "-window-minutes" and
// "-reset-at". It returns nil when none of the three parses; otherwise only
// the fields whose header parsed are filled in.
func parseRateLimitWindowFromHeaders(headers http.Header, prefix string, windowName string) *rateLimitWindow {
	base := prefix + "-" + windowName
	usedPercent, hasPercent := parseHeaderFloat(headers, base+"-used-percent")
	windowMinutes, hasWindow := parseInt64Header(headers, base+"-window-minutes")
	resetAt, hasReset := parseInt64Header(headers, base+"-reset-at")
	if !(hasPercent || hasWindow || hasReset) {
		return nil
	}

	var parsed rateLimitWindow
	if hasPercent {
		parsed.UsedPercent = usedPercent
	}
	if hasWindow {
		parsed.WindowMinutes = windowMinutes
	}
	if hasReset {
		parsed.ResetAt = resetAt
	}
	return &parsed
}
// parseRateLimitSnapshotsFromHeaders extracts every rate-limit snapshot
// advertised in the response headers, sorted by limit ID.
//
// Limit IDs are discovered from three sources:
//   - any header named "x-<limit>-primary-..." or "x-<limit>-secondary-...",
//   - the "x-codex-active-limit" header,
//   - the credits headers, which always belong to the "codex" limit.
//
// For each discovered limit the primary/secondary windows are parsed from the
// "x-<limit>-..." headers; limits for which no window (and, for "codex", no
// credits) could be parsed are dropped. It returns nil when no limit ID is
// discovered at all.
func parseRateLimitSnapshotsFromHeaders(headers http.Header) []rateLimitSnapshot {
	limitIDs := map[string]struct{}{}
	for key := range headers {
		rest, ok := strings.CutPrefix(strings.ToLower(key), "x-")
		if !ok {
			continue
		}
		for _, marker := range [...]string{"-primary-", "-secondary-"} {
			idx := strings.Index(rest, marker)
			if idx <= 0 {
				continue
			}
			// Skip IDs that normalize to nothing: an empty ID would be
			// treated as "codex" below and yield a duplicate codex snapshot.
			if limitID := normalizeStoredLimitID(rest[:idx]); limitID != "" {
				limitIDs[limitID] = struct{}{}
			}
		}
	}
	if activeLimit := normalizeStoredLimitID(headers.Get("x-codex-active-limit")); activeLimit != "" {
		limitIDs[activeLimit] = struct{}{}
	}
	// Credits are parsed once and attached to the codex snapshot below.
	credits := parseCreditsSnapshotFromHeaders(headers)
	if credits != nil {
		limitIDs["codex"] = struct{}{}
	}
	if len(limitIDs) == 0 {
		return nil
	}

	snapshots := make([]rateLimitSnapshot, 0, len(limitIDs))
	for limitID := range limitIDs {
		prefix := "x-" + headerLimitID(limitID)
		snapshot := defaultRateLimitSnapshot(limitID)
		snapshot.LimitName = strings.TrimSpace(headers.Get(prefix + "-limit-name"))
		snapshot.Primary = parseRateLimitWindowFromHeaders(headers, prefix, "primary")
		snapshot.Secondary = parseRateLimitWindowFromHeaders(headers, prefix, "secondary")
		if limitID == "codex" {
			snapshot.Credits = credits
		}
		if snapshot.Primary == nil && snapshot.Secondary == nil && snapshot.Credits == nil {
			continue
		}
		snapshots = append(snapshots, snapshot)
	}
	// Map iteration order is random; sort for a stable result.
	sortRateLimitSnapshots(snapshots)
	return snapshots
}
// usageRateLimitWindowPayload is the JSON shape of one rate-limit window in
// the usage response. The window length is in seconds; windowFromUsagePayload
// converts it to minutes.
type usageRateLimitWindowPayload struct {
UsedPercent float64 `json:"used_percent"`
LimitWindowSeconds int64 `json:"limit_window_seconds"`
ResetAt int64 `json:"reset_at"`
}
// usageRateLimitDetailsPayload groups the primary and secondary windows of a
// single limit in the usage response; either may be absent (nil).
type usageRateLimitDetailsPayload struct {
PrimaryWindow *usageRateLimitWindowPayload `json:"primary_window"`
SecondaryWindow *usageRateLimitWindowPayload `json:"secondary_window"`
}
// usageCreditsPayload is the JSON shape of the credits section in the usage
// response. Balance is a pointer so a missing or null value can be told
// apart from an empty string.
type usageCreditsPayload struct {
HasCredits bool `json:"has_credits"`
Unlimited bool `json:"unlimited"`
Balance *string `json:"balance"`
}
// additionalRateLimitPayload is one entry of "additional_rate_limits" in the
// usage response. snapshotsFromUsagePayload uses MeteredFeature as the limit
// ID and LimitName as the snapshot's name.
type additionalRateLimitPayload struct {
LimitName string `json:"limit_name"`
MeteredFeature string `json:"metered_feature"`
RateLimit *usageRateLimitDetailsPayload `json:"rate_limit"`
}
// usageRateLimitStatusPayload is the top-level JSON document decoded from the
// usage endpoint: the plan type, the main ("codex") limit and credits, plus
// any additional per-feature limits.
type usageRateLimitStatusPayload struct {
PlanType string `json:"plan_type"`
RateLimit *usageRateLimitDetailsPayload `json:"rate_limit"`
Credits *usageCreditsPayload `json:"credits"`
AdditionalRateLimits []additionalRateLimitPayload `json:"additional_rate_limits"`
}
// windowFromUsagePayload converts a usage-response window into a
// rateLimitWindow, returning nil for a nil input. A positive window length in
// seconds is rounded up to whole minutes; non-positive lengths and reset
// times are left at zero.
func windowFromUsagePayload(window *usageRateLimitWindowPayload) *rateLimitWindow {
	if window == nil {
		return nil
	}
	var converted rateLimitWindow
	converted.UsedPercent = window.UsedPercent
	if seconds := window.LimitWindowSeconds; seconds > 0 {
		// Ceiling division: 61 seconds counts as 2 minutes.
		converted.WindowMinutes = (seconds + 59) / 60
	}
	if resetAt := window.ResetAt; resetAt > 0 {
		converted.ResetAt = resetAt
	}
	return &converted
}
// snapshotsFromUsagePayload converts a decoded usage response into rate-limit
// snapshots sorted by limit ID.
//
// The main limit becomes the "codex" snapshot and is included when it carries
// a window, credits or a plan type. Each additional limit is keyed by its
// metered feature and included only when it has at least one window. Every
// snapshot inherits the payload's plan type.
func snapshotsFromUsagePayload(payload usageRateLimitStatusPayload) []rateLimitSnapshot {
	result := make([]rateLimitSnapshot, 0, 1+len(payload.AdditionalRateLimits))

	base := defaultRateLimitSnapshot("codex")
	base.PlanType = payload.PlanType
	if limits := payload.RateLimit; limits != nil {
		base.Primary = windowFromUsagePayload(limits.PrimaryWindow)
		base.Secondary = windowFromUsagePayload(limits.SecondaryWindow)
	}
	if credits := payload.Credits; credits != nil {
		converted := &creditsSnapshot{
			HasCredits: credits.HasCredits,
			Unlimited:  credits.Unlimited,
		}
		if credits.Balance != nil {
			converted.Balance = *credits.Balance
		}
		base.Credits = converted
	}
	hasBaseData := base.Primary != nil || base.Secondary != nil || base.Credits != nil || base.PlanType != ""
	if hasBaseData {
		result = append(result, base)
	}

	for index := range payload.AdditionalRateLimits {
		extra := payload.AdditionalRateLimits[index]
		entry := defaultRateLimitSnapshot(extra.MeteredFeature)
		entry.LimitName = extra.LimitName
		entry.PlanType = payload.PlanType
		if limits := extra.RateLimit; limits != nil {
			entry.Primary = windowFromUsagePayload(limits.PrimaryWindow)
			entry.Secondary = windowFromUsagePayload(limits.SecondaryWindow)
		}
		if entry.Primary != nil || entry.Secondary != nil {
			result = append(result, entry)
		}
	}

	sortRateLimitSnapshots(result)
	return result
}
// applyRateLimitSnapshotsLocked replaces the snapshots stored in state with
// the given set and refreshes the fields derived from them. It is a no-op
// when snapshots is empty.
//
// The "Locked" suffix suggests the caller must already hold the lock guarding
// state; this function performs no locking itself.
//
// Parameters:
//   - snapshots: the complete new set; previously stored snapshots are dropped.
//   - activeLimitID: the limit reported as active; when it normalizes to ""
//     the previously stored active limit is kept, and if there is none a
//     default is chosen ("codex" if present, else the smallest limit ID).
//   - planWeight: stored as remotePlanWeight when positive.
//   - planType: stored as accountType when non-empty, and used as the plan
//     type of snapshots that do not carry one.
//
// The legacy five-hour/weekly utilization and reset fields are rebuilt from
// the "codex" snapshot, falling back to the active limit's snapshot, and are
// zeroed for any window that snapshot lacks.
func applyRateLimitSnapshotsLocked(state *credentialState, snapshots []rateLimitSnapshot, activeLimitID string, planWeight float64, planType string) {
	if len(snapshots) == 0 {
		return
	}
	if state.rateLimitSnapshots == nil {
		state.rateLimitSnapshots = make(map[string]rateLimitSnapshot, len(snapshots))
	} else {
		clear(state.rateLimitSnapshots)
	}
	for _, snapshot := range snapshots {
		// Clone so the stored copy never aliases the caller's pointers.
		snapshot = cloneRateLimitSnapshot(snapshot)
		if snapshot.LimitID == "" {
			snapshot.LimitID = "codex"
		}
		if snapshot.LimitName == "" && snapshot.LimitID != "codex" {
			snapshot.LimitName = strings.ReplaceAll(snapshot.LimitID, "_", "-")
		}
		if snapshot.PlanType == "" {
			snapshot.PlanType = planType
		}
		state.rateLimitSnapshots[snapshot.LimitID] = snapshot
	}
	if planWeight > 0 {
		state.remotePlanWeight = planWeight
	}
	if planType != "" {
		state.accountType = planType
	}
	if normalizedActive := normalizeStoredLimitID(activeLimitID); normalizedActive != "" {
		state.activeLimitID = normalizedActive
	} else if state.activeLimitID == "" {
		if _, exists := state.rateLimitSnapshots["codex"]; exists {
			state.activeLimitID = "codex"
		} else {
			// Pick the lexicographically smallest limit ID. Taking the
			// first key of a map range would be random per call, making the
			// legacy utilization fields below non-deterministic.
			first := true
			for limitID := range state.rateLimitSnapshots {
				if first || limitID < state.activeLimitID {
					state.activeLimitID = limitID
					first = false
				}
			}
		}
	}
	legacy := state.rateLimitSnapshots["codex"]
	if legacy.LimitID == "" && state.activeLimitID != "" {
		legacy = state.rateLimitSnapshots[state.activeLimitID]
	}
	state.fiveHourUtilization = 0
	state.fiveHourReset = time.Time{}
	state.weeklyUtilization = 0
	state.weeklyReset = time.Time{}
	if legacy.Primary != nil {
		state.fiveHourUtilization = legacy.Primary.UsedPercent
		if legacy.Primary.ResetAt > 0 {
			state.fiveHourReset = time.Unix(legacy.Primary.ResetAt, 0)
		}
	}
	if legacy.Secondary != nil {
		state.weeklyUtilization = legacy.Secondary.UsedPercent
		if legacy.Secondary.ResetAt > 0 {
			state.weeklyReset = time.Unix(legacy.Secondary.ResetAt, 0)
		}
	}
	state.noteSnapshotData()
}

View File

@@ -65,7 +65,6 @@ func writePlainTextError(w http.ResponseWriter, statusCode int, message string)
const (
retryableUsageMessage = "current credential reached its usage limit; retry the request to use another credential"
retryableUsageCode = "credential_usage_exhausted"
)
func hasAlternativeCredential(provider credentialProvider, currentCredential Credential, selection credentialSelection) bool {
@@ -98,7 +97,7 @@ func unavailableCredentialMessage(provider credentialProvider, fallback string)
}
func writeRetryableUsageError(w http.ResponseWriter, r *http.Request) {
writeJSONErrorWithCode(w, r, http.StatusServiceUnavailable, "server_error", retryableUsageCode, retryableUsageMessage)
writeJSONErrorWithCode(w, r, http.StatusTooManyRequests, "usage_limit_reached", "", retryableUsageMessage)
}
func writeNonRetryableCredentialError(w http.ResponseWriter, message string) {
@@ -117,6 +116,10 @@ func writeCredentialUnavailableError(
writeRetryableUsageError(w, r)
return
}
if provider != nil && strings.HasPrefix(allRateLimitedError(provider.allCredentials()).Error(), "all credentials rate-limited") {
writeRetryableUsageError(w, r)
return
}
writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, fallback))
}

View File

@@ -4,6 +4,8 @@ import (
"bytes"
"encoding/json"
"net/http"
"reflect"
"slices"
"strconv"
"strings"
"time"
@@ -12,11 +14,14 @@ import (
)
type statusPayload struct {
FiveHourUtilization float64 `json:"five_hour_utilization"`
FiveHourReset int64 `json:"five_hour_reset"`
WeeklyUtilization float64 `json:"weekly_utilization"`
WeeklyReset int64 `json:"weekly_reset"`
PlanWeight float64 `json:"plan_weight"`
FiveHourUtilization float64 `json:"five_hour_utilization"`
FiveHourReset int64 `json:"five_hour_reset"`
WeeklyUtilization float64 `json:"weekly_utilization"`
WeeklyReset int64 `json:"weekly_reset"`
PlanWeight float64 `json:"plan_weight"`
ActiveLimit string `json:"active_limit,omitempty"`
Limits []rateLimitSnapshot `json:"limits,omitempty"`
Availability *availabilityPayload `json:"availability,omitempty"`
}
type aggregatedStatus struct {
@@ -25,6 +30,9 @@ type aggregatedStatus struct {
totalWeight float64
fiveHourReset time.Time
weeklyReset time.Time
activeLimitID string
limits []rateLimitSnapshot
availability availabilityStatus
}
func resetToEpoch(t time.Time) int64 {
@@ -35,11 +43,7 @@ func resetToEpoch(t time.Time) int64 {
}
func (s aggregatedStatus) equal(other aggregatedStatus) bool {
return s.fiveHourUtilization == other.fiveHourUtilization &&
s.weeklyUtilization == other.weeklyUtilization &&
s.totalWeight == other.totalWeight &&
resetToEpoch(s.fiveHourReset) == resetToEpoch(other.fiveHourReset) &&
resetToEpoch(s.weeklyReset) == resetToEpoch(other.weeklyReset)
return reflect.DeepEqual(s.toPayload(), other.toPayload())
}
func (s aggregatedStatus) toPayload() statusPayload {
@@ -49,9 +53,253 @@ func (s aggregatedStatus) toPayload() statusPayload {
WeeklyUtilization: s.weeklyUtilization,
WeeklyReset: resetToEpoch(s.weeklyReset),
PlanWeight: s.totalWeight,
ActiveLimit: s.activeLimitID,
Limits: slices.Clone(s.limits),
Availability: s.availability.toPayload(),
}
}
// aggregateInput carries the per-credential data consumed by the status
// aggregation helpers (aggregateAvailability, aggregateSnapshots and
// selectActiveLimitID).
type aggregateInput struct {
// weight scales this credential's contribution to weighted averages.
weight float64
// snapshots are the credential's rate-limit snapshots, one per limit ID.
snapshots []rateLimitSnapshot
// activeLimit is the limit ID the credential reports as active; may be empty.
activeLimit string
// availability is the credential's own availability status.
availability availabilityStatus
}
// snapshotContribution pairs one credential's snapshot for a limit with the
// weight applied to it during aggregation.
type snapshotContribution struct {
// weight is copied from the aggregateInput the snapshot came from.
weight float64
// snapshot is the snapshot being aggregated.
snapshot rateLimitSnapshot
}
// aggregateAvailability folds the availability of all inputs into a single
// status, using the following precedence:
//
//  1. no inputs            -> unavailable / no credentials
//  2. any usable input     -> usable (short-circuits)
//  3. any rate-limited     -> rate-limited with the earliest known reset
//  4. any temporarily
//     blocked input        -> the highest-priority blocked status (connection
//     limit > poll failed > upstream rejected > anything else; the first
//     input wins among equal priorities) with the earliest reset among all
//     blocked inputs
//  5. any unavailable      -> unavailable / unknown reason
//  6. otherwise            -> unknown
//
// The blocked reset is tracked independently of the chosen status so the
// result does not depend on the order of inputs.
func aggregateAvailability(inputs []aggregateInput) availabilityStatus {
	if len(inputs) == 0 {
		return availabilityStatus{
			State:  availabilityStateUnavailable,
			Reason: availabilityReasonNoCredentials,
		}
	}
	var (
		earliestRateLimited  time.Time
		hasRateLimited       bool
		bestBlocked          availabilityStatus
		earliestBlockedReset time.Time
		hasBlocked           bool
		hasUnavailable       bool
	)
	blockedPriority := func(reason availabilityReason) int {
		switch reason {
		case availabilityReasonConnectionLimit:
			return 3
		case availabilityReasonPollFailed:
			return 2
		case availabilityReasonUpstreamRejected:
			return 1
		default:
			return 0
		}
	}
	for _, input := range inputs {
		availability := input.availability.normalized()
		switch availability.State {
		case availabilityStateUsable:
			return availabilityStatus{State: availabilityStateUsable}
		case availabilityStateRateLimited:
			hasRateLimited = true
			if !availability.ResetAt.IsZero() && (earliestRateLimited.IsZero() || availability.ResetAt.Before(earliestRateLimited)) {
				earliestRateLimited = availability.ResetAt
			}
		case availabilityStateTemporarilyBlocked:
			if !hasBlocked || blockedPriority(availability.Reason) > blockedPriority(bestBlocked.Reason) {
				bestBlocked = availability
				hasBlocked = true
			}
			// Keep the minimum outside bestBlocked: assigning a new
			// bestBlocked above must not discard an earlier reset that a
			// lower-priority input already contributed.
			if !availability.ResetAt.IsZero() && (earliestBlockedReset.IsZero() || availability.ResetAt.Before(earliestBlockedReset)) {
				earliestBlockedReset = availability.ResetAt
			}
		case availabilityStateUnavailable:
			hasUnavailable = true
		}
	}
	if hasRateLimited {
		return availabilityStatus{
			State:   availabilityStateRateLimited,
			Reason:  availabilityReasonHardRateLimit,
			ResetAt: earliestRateLimited,
		}
	}
	if hasBlocked {
		bestBlocked.ResetAt = earliestBlockedReset
		return bestBlocked
	}
	if hasUnavailable {
		return availabilityStatus{
			State:  availabilityStateUnavailable,
			Reason: availabilityReasonUnknown,
		}
	}
	return availabilityStatus{
		State:  availabilityStateUnknown,
		Reason: availabilityReasonUnknown,
	}
}
// aggregateRateLimitWindow combines one window (chosen by selector, e.g. the
// primary or secondary window) across all contributions into a single
// weighted window.
//
//   - UsedPercent is 100 minus the weight-averaged remaining percentage over
//     every contribution that has the window.
//   - WindowMinutes is the weight-averaged length over the contributions that
//     actually report one (> 0), rounded to the nearest minute. Contributions
//     without a length do not dilute the average.
//   - ResetAt is now plus the weight-averaged time until reset over the
//     contributions whose reset lies in the future.
//
// It returns nil when no contribution has the window or when the total weight
// of those that do is zero.
func aggregateRateLimitWindow(contributions []snapshotContribution, selector func(rateLimitSnapshot) *rateLimitWindow) *rateLimitWindow {
	var (
		totalWeight        float64
		totalRemaining     float64
		totalWindowMinutes float64
		windowWeight       float64
		totalResetHours    float64
		resetWeight        float64
	)
	now := time.Now()
	for _, contribution := range contributions {
		window := selector(contribution.snapshot)
		if window == nil {
			continue
		}
		totalWeight += contribution.weight
		totalRemaining += (100 - window.UsedPercent) * contribution.weight
		if window.WindowMinutes > 0 {
			totalWindowMinutes += float64(window.WindowMinutes) * contribution.weight
			windowWeight += contribution.weight
		}
		if window.ResetAt > 0 {
			resetTime := time.Unix(window.ResetAt, 0)
			hours := resetTime.Sub(now).Hours()
			if hours > 0 {
				totalResetHours += hours * contribution.weight
				resetWeight += contribution.weight
			}
		}
	}
	if totalWeight == 0 {
		return nil
	}
	window := &rateLimitWindow{
		UsedPercent: 100 - totalRemaining/totalWeight,
	}
	if windowWeight > 0 {
		// Adding 0.5 before truncating rounds to the nearest minute, so
		// floating-point noise (e.g. 299.9999) cannot shave a minute off.
		window.WindowMinutes = int64(totalWindowMinutes/windowWeight + 0.5)
	}
	if resetWeight > 0 {
		window.ResetAt = now.Add(time.Duration(totalResetHours / resetWeight * float64(time.Hour))).Unix()
	}
	return window
}
// aggregateCredits merges the credit information of all contributions.
//
// HasCredits and Unlimited are OR-ed across contributions. Balances that
// parse as floats (after trimming whitespace) are summed; unparsable or empty
// balances are ignored. The summed balance is only reported when no
// contribution is unlimited. The result is nil when nothing contributed any
// credit flag or balance.
func aggregateCredits(contributions []snapshotContribution) *creditsSnapshot {
	var (
		anyCredits   bool
		anyUnlimited bool
		sawBalance   bool
		sum          float64
	)
	for index := range contributions {
		entry := contributions[index].snapshot.Credits
		if entry == nil {
			continue
		}
		if entry.HasCredits {
			anyCredits = true
		}
		if entry.Unlimited {
			anyUnlimited = true
		}
		trimmed := strings.TrimSpace(entry.Balance)
		if trimmed == "" {
			continue
		}
		if amount, err := strconv.ParseFloat(trimmed, 64); err == nil {
			sum += amount
			sawBalance = true
		}
	}
	if !(anyCredits || anyUnlimited || sawBalance) {
		return nil
	}
	merged := &creditsSnapshot{
		HasCredits: anyCredits,
		Unlimited:  anyUnlimited,
	}
	if sawBalance && !anyUnlimited {
		merged.Balance = strconv.FormatFloat(sum, 'f', -1, 64)
	}
	return merged
}
// aggregateSnapshots merges the snapshots of all inputs into one snapshot per
// limit ID (snapshots without an ID are grouped under "codex").
//
// For each limit the result starts from defaultRateLimitSnapshot and then
// takes the first non-empty LimitName (unless the default already has one),
// the PlanType of the heaviest contribution that has one (later entries win
// ties), and the weighted primary/secondary windows and credits. Limits that
// end up with no window and no credits are dropped. The result is ordered by
// sortRateLimitSnapshots; nil is returned when no input carries a snapshot.
func aggregateSnapshots(inputs []aggregateInput) []rateLimitSnapshot {
	byLimit := make(map[string][]snapshotContribution)
	for _, input := range inputs {
		for _, entry := range input.snapshots {
			key := entry.LimitID
			if key == "" {
				key = "codex"
			}
			byLimit[key] = append(byLimit[key], snapshotContribution{
				weight:   input.weight,
				snapshot: entry,
			})
		}
	}
	if len(byLimit) == 0 {
		return nil
	}
	primaryOf := func(s rateLimitSnapshot) *rateLimitWindow { return s.Primary }
	secondaryOf := func(s rateLimitSnapshot) *rateLimitWindow { return s.Secondary }
	merged := make([]rateLimitSnapshot, 0, len(byLimit))
	for key, members := range byLimit {
		combined := defaultRateLimitSnapshot(key)
		var heaviest float64
		for _, member := range members {
			if combined.LimitName == "" && member.snapshot.LimitName != "" {
				combined.LimitName = member.snapshot.LimitName
			}
			if member.snapshot.PlanType == "" || !(member.weight >= heaviest) {
				continue
			}
			heaviest = member.weight
			combined.PlanType = member.snapshot.PlanType
		}
		combined.Primary = aggregateRateLimitWindow(members, primaryOf)
		combined.Secondary = aggregateRateLimitWindow(members, secondaryOf)
		combined.Credits = aggregateCredits(members)
		if combined.Primary != nil || combined.Secondary != nil || combined.Credits != nil {
			merged = append(merged, combined)
		}
	}
	sortRateLimitSnapshots(merged)
	return merged
}
// selectActiveLimitID picks the limit ID to report as active for the
// aggregated status.
//
// Each input votes for its (normalized) active limit with its weight; the ID
// with the largest positive total wins. Ties are broken by the
// lexicographically smaller ID so the result is stable across calls (map
// iteration order is unspecified). When no input casts a positive-weight
// vote, "codex" is used if present in snapshots, otherwise the first
// snapshot's ID. An empty snapshots slice always yields "".
func selectActiveLimitID(inputs []aggregateInput, snapshots []rateLimitSnapshot) string {
	if len(snapshots) == 0 {
		return ""
	}
	weights := make(map[string]float64)
	for _, input := range inputs {
		if input.activeLimit == "" {
			continue
		}
		weights[normalizeStoredLimitID(input.activeLimit)] += input.weight
	}
	var (
		bestID     string
		bestWeight float64
	)
	for limitID, weight := range weights {
		switch {
		case weight > bestWeight:
			bestID = limitID
			bestWeight = weight
		case weight == bestWeight && bestID != "" && limitID < bestID:
			// Deterministic tie-break; bestID != "" keeps zero-weight votes
			// from ever being selected, as before.
			bestID = limitID
		}
	}
	if bestID != "" {
		return bestID
	}
	for _, snapshot := range snapshots {
		if snapshot.LimitID == "codex" {
			return "codex"
		}
	}
	return snapshots[0].LimitID
}
// findSnapshotByLimitID returns a pointer to a copy of the first snapshot
// whose LimitID equals limitID, or nil when there is none. Because the result
// is a copy, writing through it does not modify the slice element.
func findSnapshotByLimitID(snapshots []rateLimitSnapshot, limitID string) *rateLimitSnapshot {
	for index := range snapshots {
		if snapshots[index].LimitID != limitID {
			continue
		}
		match := snapshots[index]
		return &match
	}
	return nil
}
func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed")
@@ -171,74 +419,86 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro
}
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) aggregatedStatus {
var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64
now := time.Now()
var totalWeightedHoursUntil5hReset, total5hResetWeight float64
var totalWeightedHoursUntilWeeklyReset, totalWeeklyResetWeight float64
inputs := make([]aggregateInput, 0, len(provider.allCredentials()))
var totalWeight float64
var hasSnapshotData bool
for _, credential := range provider.allCredentials() {
if !credential.isUsable() {
continue
}
if userConfig != nil && userConfig.ExternalCredential != "" && credential.tagName() == userConfig.ExternalCredential {
continue
}
if userConfig != nil && !userConfig.AllowExternalUsage && credential.isExternal() {
continue
}
weight := credential.planWeight()
remaining5h := credential.fiveHourCap() - credential.fiveHourUtilization()
if remaining5h < 0 {
remaining5h = 0
input := aggregateInput{
weight: credential.planWeight(),
snapshots: credential.rateLimitSnapshots(),
activeLimit: credential.activeLimitID(),
availability: credential.availabilityStatus(),
}
remainingWeekly := credential.weeklyCap() - credential.weeklyUtilization()
if remainingWeekly < 0 {
remainingWeekly = 0
}
totalWeightedRemaining5h += remaining5h * weight
totalWeightedRemainingWeekly += remainingWeekly * weight
totalWeight += weight
fiveHourReset := credential.fiveHourResetTime()
if !fiveHourReset.IsZero() {
hours := fiveHourReset.Sub(now).Hours()
if hours > 0 {
totalWeightedHoursUntil5hReset += hours * weight
total5hResetWeight += weight
}
}
weeklyReset := credential.weeklyResetTime()
if !weeklyReset.IsZero() {
hours := weeklyReset.Sub(now).Hours()
if hours > 0 {
totalWeightedHoursUntilWeeklyReset += hours * weight
totalWeeklyResetWeight += weight
}
}
}
if totalWeight == 0 {
return aggregatedStatus{
fiveHourUtilization: 100,
weeklyUtilization: 100,
inputs = append(inputs, input)
if credential.hasSnapshotData() {
hasSnapshotData = true
}
totalWeight += input.weight
}
limits := aggregateSnapshots(inputs)
result := aggregatedStatus{
fiveHourUtilization: 100 - totalWeightedRemaining5h/totalWeight,
weeklyUtilization: 100 - totalWeightedRemainingWeekly/totalWeight,
totalWeight: totalWeight,
totalWeight: totalWeight,
availability: aggregateAvailability(inputs),
limits: limits,
activeLimitID: selectActiveLimitID(inputs, limits),
}
if total5hResetWeight > 0 {
avgHours := totalWeightedHoursUntil5hReset / total5hResetWeight
result.fiveHourReset = now.Add(time.Duration(avgHours * float64(time.Hour)))
if legacy := findSnapshotByLimitID(result.limits, "codex"); legacy != nil {
if legacy.Primary != nil {
result.fiveHourUtilization = legacy.Primary.UsedPercent
if legacy.Primary.ResetAt > 0 {
result.fiveHourReset = time.Unix(legacy.Primary.ResetAt, 0)
}
}
if legacy.Secondary != nil {
result.weeklyUtilization = legacy.Secondary.UsedPercent
if legacy.Secondary.ResetAt > 0 {
result.weeklyReset = time.Unix(legacy.Secondary.ResetAt, 0)
}
}
} else if legacy := findSnapshotByLimitID(result.limits, result.activeLimitID); legacy != nil {
if legacy.Primary != nil {
result.fiveHourUtilization = legacy.Primary.UsedPercent
if legacy.Primary.ResetAt > 0 {
result.fiveHourReset = time.Unix(legacy.Primary.ResetAt, 0)
}
}
if legacy.Secondary != nil {
result.weeklyUtilization = legacy.Secondary.UsedPercent
if legacy.Secondary.ResetAt > 0 {
result.weeklyReset = time.Unix(legacy.Secondary.ResetAt, 0)
}
}
}
if totalWeeklyResetWeight > 0 {
avgHours := totalWeightedHoursUntilWeeklyReset / totalWeeklyResetWeight
result.weeklyReset = now.Add(time.Duration(avgHours * float64(time.Hour)))
if len(result.limits) == 0 && !hasSnapshotData {
result.fiveHourUtilization = 100
result.weeklyUtilization = 100
}
return result
}
func (s *Service) rewriteResponseHeaders(headers http.Header, provider credentialProvider, userConfig *option.OCMUser) {
status := s.computeAggregatedUtilization(provider, userConfig)
for key := range headers {
lowerKey := strings.ToLower(key)
if lowerKey == "x-codex-active-limit" ||
strings.HasSuffix(lowerKey, "-primary-used-percent") ||
strings.HasSuffix(lowerKey, "-primary-window-minutes") ||
strings.HasSuffix(lowerKey, "-primary-reset-at") ||
strings.HasSuffix(lowerKey, "-secondary-used-percent") ||
strings.HasSuffix(lowerKey, "-secondary-window-minutes") ||
strings.HasSuffix(lowerKey, "-secondary-reset-at") ||
strings.HasSuffix(lowerKey, "-limit-name") ||
strings.HasPrefix(lowerKey, "x-codex-credits-") {
headers.Del(key)
}
}
headers.Set("x-codex-active-limit", headerLimitID(status.activeLimitID))
headers.Set("x-codex-primary-used-percent", strconv.FormatFloat(status.fiveHourUtilization, 'f', 2, 64))
headers.Set("x-codex-secondary-used-percent", strconv.FormatFloat(status.weeklyUtilization, 'f', 2, 64))
if !status.fiveHourReset.IsZero() {
@@ -254,25 +514,34 @@ func (s *Service) rewriteResponseHeaders(headers http.Header, provider credentia
if status.totalWeight > 0 {
headers.Set("X-OCM-Plan-Weight", strconv.FormatFloat(status.totalWeight, 'f', -1, 64))
}
rateLimitSuffixes := [...]string{
"-primary-used-percent",
"-primary-reset-at",
"-secondary-used-percent",
"-secondary-reset-at",
"-secondary-window-minutes",
"-limit-name",
}
for key := range headers {
lowerKey := strings.ToLower(key)
if !strings.HasPrefix(lowerKey, "x-") {
continue
for _, snapshot := range status.limits {
prefix := "x-" + headerLimitID(snapshot.LimitID)
if snapshot.Primary != nil {
headers.Set(prefix+"-primary-used-percent", strconv.FormatFloat(snapshot.Primary.UsedPercent, 'f', 2, 64))
if snapshot.Primary.WindowMinutes > 0 {
headers.Set(prefix+"-primary-window-minutes", strconv.FormatInt(snapshot.Primary.WindowMinutes, 10))
}
if snapshot.Primary.ResetAt > 0 {
headers.Set(prefix+"-primary-reset-at", strconv.FormatInt(snapshot.Primary.ResetAt, 10))
}
}
for _, suffix := range rateLimitSuffixes {
if strings.HasSuffix(lowerKey, suffix) {
if strings.TrimSuffix(lowerKey, suffix) != "x-codex" {
headers.Del(key)
}
break
if snapshot.Secondary != nil {
headers.Set(prefix+"-secondary-used-percent", strconv.FormatFloat(snapshot.Secondary.UsedPercent, 'f', 2, 64))
if snapshot.Secondary.WindowMinutes > 0 {
headers.Set(prefix+"-secondary-window-minutes", strconv.FormatInt(snapshot.Secondary.WindowMinutes, 10))
}
if snapshot.Secondary.ResetAt > 0 {
headers.Set(prefix+"-secondary-reset-at", strconv.FormatInt(snapshot.Secondary.ResetAt, 10))
}
}
if snapshot.LimitName != "" {
headers.Set(prefix+"-limit-name", snapshot.LimitName)
}
if snapshot.LimitID == "codex" && snapshot.Credits != nil {
headers.Set("x-codex-credits-has-credits", strconv.FormatBool(snapshot.Credits.HasCredits))
headers.Set("x-codex-credits-unlimited", strconv.FormatBool(snapshot.Credits.Unlimited))
if snapshot.Credits.Balance != "" {
headers.Set("x-codex-credits-balance", snapshot.Credits.Balance)
}
}
}

View File

@@ -0,0 +1,220 @@
package ocm
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/observable"
)
// testCredential is a test double for the Credential interface: each accessor
// method returns the corresponding field, so a test configures behavior by
// filling in the struct literal.
type testCredential struct {
tag string
external bool
available bool
usable bool
hasData bool
fiveHour float64
weekly float64
fiveHourCapV float64
weeklyCapV float64
weight float64
fiveReset time.Time
weeklyReset time.Time
// availability is returned by availabilityStatus and overwritten by
// markTemporarilyBlocked.
availability availabilityStatus
activeLimit string
snapshots []rateLimitSnapshot
}
// The methods below are the testCredential implementations used by the tests
// in this file. Accessors return the matching struct field; the remaining
// methods are no-ops or return zero values.
func (c *testCredential) tagName() string { return c.tag }
func (c *testCredential) isAvailable() bool { return c.available }
func (c *testCredential) isUsable() bool { return c.usable }
func (c *testCredential) isExternal() bool { return c.external }
func (c *testCredential) hasSnapshotData() bool { return c.hasData }
func (c *testCredential) fiveHourUtilization() float64 { return c.fiveHour }
func (c *testCredential) weeklyUtilization() float64 { return c.weekly }
func (c *testCredential) fiveHourCap() float64 { return c.fiveHourCapV }
func (c *testCredential) weeklyCap() float64 { return c.weeklyCapV }
func (c *testCredential) planWeight() float64 { return c.weight }
func (c *testCredential) weeklyResetTime() time.Time { return c.weeklyReset }
func (c *testCredential) fiveHourResetTime() time.Time { return c.fiveReset }
func (c *testCredential) markRateLimited(time.Time) {}
func (c *testCredential) markUpstreamRejected() {}
// markTemporarilyBlocked records the block in c.availability so tests can
// assert on the reason and reset time that were passed in.
func (c *testCredential) markTemporarilyBlocked(reason availabilityReason, resetAt time.Time) {
c.availability = availabilityStatus{State: availabilityStateTemporarilyBlocked, Reason: reason, ResetAt: resetAt}
}
func (c *testCredential) availabilityStatus() availabilityStatus { return c.availability }
// rateLimitSnapshots returns a clone so callers cannot mutate the fixture.
func (c *testCredential) rateLimitSnapshots() []rateLimitSnapshot {
return slicesCloneSnapshots(c.snapshots)
}
func (c *testCredential) activeLimitID() string { return c.activeLimit }
func (c *testCredential) earliestReset() time.Time { return c.fiveReset }
func (c *testCredential) unavailableError() error { return nil }
func (c *testCredential) getAccessToken() (string, error) { return "", nil }
func (c *testCredential) buildProxyRequest(context.Context, *http.Request, []byte, http.Header) (*http.Request, error) {
return nil, nil
}
func (c *testCredential) updateStateFromHeaders(http.Header) {}
func (c *testCredential) wrapRequestContext(context.Context) *credentialRequestContext { return nil }
func (c *testCredential) interruptConnections() {}
func (c *testCredential) setOnBecameUnusable(func()) {}
func (c *testCredential) setStatusSubscriber(*observable.Subscriber[struct{}]) {}
func (c *testCredential) start() error { return nil }
func (c *testCredential) pollUsage() {}
func (c *testCredential) lastUpdatedTime() time.Time { return time.Now() }
func (c *testCredential) pollBackoff(time.Duration) time.Duration { return 0 }
func (c *testCredential) usageTrackerOrNil() *AggregatedUsage { return nil }
func (c *testCredential) httpClient() *http.Client { return nil }
func (c *testCredential) close() {}
func (c *testCredential) ocmDialer() N.Dialer { return nil }
func (c *testCredential) ocmIsAPIKeyMode() bool { return false }
func (c *testCredential) ocmGetAccountID() string { return "" }
func (c *testCredential) ocmGetBaseURL() string { return "" }
// slicesCloneSnapshots returns a new slice in which every element is the
// cloneRateLimitSnapshot copy of the corresponding input element. An empty or
// nil input yields nil.
func slicesCloneSnapshots(snapshots []rateLimitSnapshot) []rateLimitSnapshot {
	count := len(snapshots)
	if count == 0 {
		return nil
	}
	copies := make([]rateLimitSnapshot, count)
	for index := range snapshots {
		copies[index] = cloneRateLimitSnapshot(snapshots[index])
	}
	return copies
}
// testProvider is a minimal credentialProvider test double. Only
// allCredentials returns configured data; every other method is a no-op or
// returns zero values.
type testProvider struct {
// credentials is the fixed list returned by allCredentials.
credentials []Credential
}
func (p *testProvider) selectCredential(string, credentialSelection) (Credential, bool, error) {
return nil, false, nil
}
func (p *testProvider) onRateLimited(string, Credential, time.Time, credentialSelection) Credential {
return nil
}
// linkProviderInterrupt returns a function that always reports true.
func (p *testProvider) linkProviderInterrupt(Credential, credentialSelection, func()) func() bool {
return func() bool { return true }
}
func (p *testProvider) pollIfStale() {}
func (p *testProvider) pollCredentialIfStale(Credential) {}
func (p *testProvider) allCredentials() []Credential { return p.credentials }
func (p *testProvider) close() {}
// TestComputeAggregatedUtilizationPreservesStoredSnapshots checks that a
// single credential whose isUsable() is false but which has a stored "codex"
// snapshot (44% primary, 12% secondary) still drives the aggregated
// utilization, and that its rate-limited availability is reported.
func TestComputeAggregatedUtilizationPreservesStoredSnapshots(t *testing.T) {
t.Parallel()
service := &Service{}
status := service.computeAggregatedUtilization(&testProvider{credentials: []Credential{
&testCredential{
tag: "a",
available: true,
usable: false,
hasData: true,
weight: 1,
activeLimit: "codex",
availability: availabilityStatus{State: availabilityStateRateLimited, Reason: availabilityReasonHardRateLimit, ResetAt: time.Now().Add(time.Minute)},
snapshots: []rateLimitSnapshot{
{
LimitID: "codex",
Primary: &rateLimitWindow{UsedPercent: 44, WindowMinutes: 300, ResetAt: time.Now().Add(time.Hour).Unix()},
Secondary: &rateLimitWindow{UsedPercent: 12, WindowMinutes: 10080, ResetAt: time.Now().Add(24 * time.Hour).Unix()},
},
},
},
}}, nil)
// With a single weight-1 credential the aggregate equals its own values.
if status.fiveHourUtilization != 44 || status.weeklyUtilization != 12 {
t.Fatalf("expected stored snapshot utilization, got 5h=%v weekly=%v", status.fiveHourUtilization, status.weeklyUtilization)
}
if status.availability.State != availabilityStateRateLimited {
t.Fatalf("expected rate-limited availability, got %#v", status.availability)
}
}
// TestRewriteResponseHeadersIncludesAdditionalLimitFamiliesAndCredits checks
// that rewriteResponseHeaders, given a credential with a "codex" snapshot
// (including credits) and a "codex_other" snapshot marked active, emits the
// active-limit header in its hyphenated form, headers for the additional
// limit family, and the credits balance header.
func TestRewriteResponseHeadersIncludesAdditionalLimitFamiliesAndCredits(t *testing.T) {
t.Parallel()
service := &Service{}
headers := make(http.Header)
service.rewriteResponseHeaders(headers, &testProvider{credentials: []Credential{
&testCredential{
tag: "a",
available: true,
usable: true,
hasData: true,
weight: 1,
activeLimit: "codex_other",
availability: availabilityStatus{State: availabilityStateUsable},
snapshots: []rateLimitSnapshot{
{
LimitID: "codex",
Primary: &rateLimitWindow{UsedPercent: 20, WindowMinutes: 300, ResetAt: time.Now().Add(time.Hour).Unix()},
Secondary: &rateLimitWindow{UsedPercent: 40, WindowMinutes: 10080, ResetAt: time.Now().Add(24 * time.Hour).Unix()},
Credits: &creditsSnapshot{HasCredits: true, Unlimited: false, Balance: "12"},
},
{
LimitID: "codex_other",
LimitName: "codex-other",
Primary: &rateLimitWindow{UsedPercent: 60, WindowMinutes: 60, ResetAt: time.Now().Add(30 * time.Minute).Unix()},
},
},
},
}}, nil)
// The stored ID "codex_other" is expected as "codex-other" in headers.
if headers.Get("x-codex-active-limit") != "codex-other" {
t.Fatalf("expected active limit header, got %q", headers.Get("x-codex-active-limit"))
}
if headers.Get("x-codex-other-primary-used-percent") == "" {
t.Fatal("expected additional rate-limit family header")
}
if headers.Get("x-codex-credits-balance") != "12" {
t.Fatalf("expected credits balance header, got %q", headers.Get("x-codex-credits-balance"))
}
}
// TestHandleWebSocketErrorEventConnectionLimitDoesNotUseRateLimitPath feeds
// handleWebSocketErrorEvent an error event with status 400 and code
// "websocket_connection_limit_reached" and expects the credential to end up
// temporarily blocked with the connection-limit reason.
func TestHandleWebSocketErrorEventConnectionLimitDoesNotUseRateLimitPath(t *testing.T) {
t.Parallel()
credential := &testCredential{availability: availabilityStatus{State: availabilityStateUsable}}
service := &Service{}
service.handleWebSocketErrorEvent([]byte(`{"type":"error","status_code":400,"error":{"code":"websocket_connection_limit_reached"}}`), credential)
if credential.availability.State != availabilityStateTemporarilyBlocked || credential.availability.Reason != availabilityReasonConnectionLimit {
t.Fatalf("expected temporary connection limit block, got %#v", credential.availability)
}
}
// TestWriteCredentialUnavailableErrorReturns429ForRateLimitedCredentials
// calls writeCredentialUnavailableError with a provider whose only credential
// is rate-limited and expects an HTTP 429 response whose JSON body has
// error.type "usage_limit_reached".
func TestWriteCredentialUnavailableErrorReturns429ForRateLimitedCredentials(t *testing.T) {
t.Parallel()
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, "/v1/responses", nil)
provider := &testProvider{credentials: []Credential{
&testCredential{
tag: "a",
available: true,
usable: false,
hasData: true,
weight: 1,
availability: availabilityStatus{State: availabilityStateRateLimited, Reason: availabilityReasonHardRateLimit, ResetAt: time.Now().Add(time.Minute)},
snapshots: []rateLimitSnapshot{{LimitID: "codex", Primary: &rateLimitWindow{UsedPercent: 80}}},
},
}}
writeCredentialUnavailableError(recorder, request, provider, provider.credentials[0], credentialSelection{}, "all credentials rate-limited")
if recorder.Code != http.StatusTooManyRequests {
t.Fatalf("expected 429, got %d", recorder.Code)
}
// Decoding into string-valued maps assumes every field of the error object
// is a JSON string.
var body map[string]map[string]string
if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil {
t.Fatal(err)
}
if body["error"]["type"] != "usage_limit_reached" {
t.Fatalf("expected usage_limit_reached type, got %#v", body)
}
}

View File

@@ -430,9 +430,7 @@ func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamRe
s.handleWebSocketRateLimitsEvent(data, selectedCredential)
continue
case "error":
if event.StatusCode == http.StatusTooManyRequests {
s.handleWebSocketErrorRateLimited(data, selectedCredential)
}
s.handleWebSocketErrorEvent(data, selectedCredential)
case "response.completed":
if usageTracker != nil {
select {
@@ -460,17 +458,22 @@ func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamRe
func (s *Service) handleWebSocketRateLimitsEvent(data []byte, selectedCredential Credential) {
var rateLimitsEvent struct {
RateLimits struct {
MeteredLimitName string `json:"metered_limit_name"`
LimitName string `json:"limit_name"`
RateLimits struct {
Primary *struct {
UsedPercent float64 `json:"used_percent"`
ResetAt int64 `json:"reset_at"`
UsedPercent float64 `json:"used_percent"`
WindowMinutes int64 `json:"window_minutes"`
ResetAt int64 `json:"reset_at"`
} `json:"primary"`
Secondary *struct {
UsedPercent float64 `json:"used_percent"`
ResetAt int64 `json:"reset_at"`
UsedPercent float64 `json:"used_percent"`
WindowMinutes int64 `json:"window_minutes"`
ResetAt int64 `json:"reset_at"`
} `json:"secondary"`
} `json:"rate_limits"`
PlanWeight float64 `json:"plan_weight"`
Credits *creditsSnapshot `json:"credits"`
PlanWeight float64 `json:"plan_weight"`
}
err := json.Unmarshal(data, &rateLimitsEvent)
if err != nil {
@@ -478,17 +481,41 @@ func (s *Service) handleWebSocketRateLimitsEvent(data []byte, selectedCredential
}
headers := make(http.Header)
headers.Set("x-codex-active-limit", "codex")
limitID := rateLimitsEvent.MeteredLimitName
if limitID == "" {
limitID = rateLimitsEvent.LimitName
}
if limitID == "" {
limitID = "codex"
}
headerLimit := headerLimitID(limitID)
headers.Set("x-codex-active-limit", headerLimit)
if w := rateLimitsEvent.RateLimits.Primary; w != nil {
headers.Set("x-codex-primary-used-percent", strconv.FormatFloat(w.UsedPercent, 'f', -1, 64))
headers.Set("x-"+headerLimit+"-primary-used-percent", strconv.FormatFloat(w.UsedPercent, 'f', -1, 64))
if w.WindowMinutes > 0 {
headers.Set("x-"+headerLimit+"-primary-window-minutes", strconv.FormatInt(w.WindowMinutes, 10))
}
if w.ResetAt > 0 {
headers.Set("x-codex-primary-reset-at", strconv.FormatInt(w.ResetAt, 10))
headers.Set("x-"+headerLimit+"-primary-reset-at", strconv.FormatInt(w.ResetAt, 10))
}
}
if w := rateLimitsEvent.RateLimits.Secondary; w != nil {
headers.Set("x-codex-secondary-used-percent", strconv.FormatFloat(w.UsedPercent, 'f', -1, 64))
headers.Set("x-"+headerLimit+"-secondary-used-percent", strconv.FormatFloat(w.UsedPercent, 'f', -1, 64))
if w.WindowMinutes > 0 {
headers.Set("x-"+headerLimit+"-secondary-window-minutes", strconv.FormatInt(w.WindowMinutes, 10))
}
if w.ResetAt > 0 {
headers.Set("x-codex-secondary-reset-at", strconv.FormatInt(w.ResetAt, 10))
headers.Set("x-"+headerLimit+"-secondary-reset-at", strconv.FormatInt(w.ResetAt, 10))
}
}
if rateLimitsEvent.LimitName != "" {
headers.Set("x-"+headerLimit+"-limit-name", rateLimitsEvent.LimitName)
}
if rateLimitsEvent.Credits != nil && normalizeStoredLimitID(limitID) == "codex" {
headers.Set("x-codex-credits-has-credits", strconv.FormatBool(rateLimitsEvent.Credits.HasCredits))
headers.Set("x-codex-credits-unlimited", strconv.FormatBool(rateLimitsEvent.Credits.Unlimited))
if rateLimitsEvent.Credits.Balance != "" {
headers.Set("x-codex-credits-balance", rateLimitsEvent.Credits.Balance)
}
}
if rateLimitsEvent.PlanWeight > 0 {
@@ -497,14 +524,25 @@ func (s *Service) handleWebSocketRateLimitsEvent(data []byte, selectedCredential
selectedCredential.updateStateFromHeaders(headers)
}
func (s *Service) handleWebSocketErrorRateLimited(data []byte, selectedCredential Credential) {
func (s *Service) handleWebSocketErrorEvent(data []byte, selectedCredential Credential) {
var errorEvent struct {
Headers map[string]string `json:"headers"`
StatusCode int `json:"status_code"`
Headers map[string]string `json:"headers"`
Error struct {
Code string `json:"code"`
} `json:"error"`
}
err := json.Unmarshal(data, &errorEvent)
if err != nil {
return
}
if errorEvent.StatusCode == http.StatusBadRequest && errorEvent.Error.Code == "websocket_connection_limit_reached" {
selectedCredential.markTemporarilyBlocked(availabilityReasonConnectionLimit, time.Now().Add(time.Minute))
return
}
if errorEvent.StatusCode != http.StatusTooManyRequests {
return
}
headers := make(http.Header)
for key, value := range errorEvent.Headers {
headers.Set(key, value)
@@ -515,10 +553,14 @@ func (s *Service) handleWebSocketErrorRateLimited(data []byte, selectedCredentia
}
func writeWebSocketAggregatedStatus(clientConn net.Conn, clientWriteAccess *sync.Mutex, status aggregatedStatus) error {
data := buildSyntheticRateLimitsEvent(status)
clientWriteAccess.Lock()
defer clientWriteAccess.Unlock()
return wsutil.WriteServerMessage(clientConn, ws.OpText, data)
for _, data := range buildSyntheticRateLimitsEvents(status) {
if err := wsutil.WriteServerMessage(clientConn, ws.OpText, data); err != nil {
return err
}
}
return nil
}
func (s *Service) pushWebSocketAggregatedStatus(ctx context.Context, clientConn net.Conn, clientWriteAccess *sync.Mutex, sessionClosed <-chan struct{}, firstRealRequest <-chan struct{}, provider credentialProvider, userConfig *option.OCMUser) {
@@ -573,34 +615,106 @@ func (s *Service) pushWebSocketAggregatedStatus(ctx context.Context, clientConn
}
}
// buildSyntheticRateLimitsEvents returns JSON-encoded "codex.rate_limits"
// events built from status. The first element is always the event for limit
// ID "codex"; it is followed by one event per entry in status.limits whose
// LimitID is not "codex", in slice order. Every event carries
// status.totalWeight as plan_weight.
//
// NOTE(review): this span looks like a diff hunk rendered without +/- markers.
// Two function headers and duplicated struct fields are visible, so lines of
// an older buildSyntheticRateLimitsEvent (single []byte result) appear mixed
// with the newer multi-event implementation. Lines flagged below as
// "pre-change" are that presumed residue; confirm against the real file.
func buildSyntheticRateLimitsEvent(status aggregatedStatus) []byte {
func buildSyntheticRateLimitsEvents(status aggregatedStatus) [][]byte {
// rateLimitWindow is the JSON shape of one window inside "rate_limits".
type rateLimitWindow struct {
// NOTE(review): the next two fields look like the pre-change field list; the
// same fields are repeated directly after with WindowMinutes added.
UsedPercent float64 `json:"used_percent"`
ResetAt int64 `json:"reset_at,omitempty"`
UsedPercent float64 `json:"used_percent"`
WindowMinutes int64 `json:"window_minutes,omitempty"`
ResetAt int64 `json:"reset_at,omitempty"`
}
// NOTE(review): dangling opener, presumably from the pre-change anonymous
// struct literal.
event := struct {
// creditsEvent is the JSON shape of the optional "credits" object.
type creditsEvent struct {
HasCredits bool `json:"has_credits"`
Unlimited bool `json:"unlimited"`
Balance string `json:"balance,omitempty"`
}
// eventPayload is the JSON shape of one emitted event.
type eventPayload struct {
Type string `json:"type"`
RateLimits struct {
Primary *rateLimitWindow `json:"primary,omitempty"`
Secondary *rateLimitWindow `json:"secondary,omitempty"`
} `json:"rate_limits"`
// NOTE(review): the next six lines look like pre-change residue (old field
// list and old literal with a hard-coded "codex" limit name).
LimitName string `json:"limit_name"`
PlanWeight float64 `json:"plan_weight,omitempty"`
}{
Type: "codex.rate_limits",
LimitName: "codex",
PlanWeight: status.totalWeight,
MeteredLimitName string `json:"metered_limit_name,omitempty"`
LimitName string `json:"limit_name,omitempty"`
Credits *creditsEvent `json:"credits,omitempty"`
PlanWeight float64 `json:"plan_weight,omitempty"`
}
// NOTE(review): presumably a pre-change line; its fields match the
// defaultPrimary literal below.
event.RateLimits.Primary = &rateLimitWindow{
// buildEvent encodes one event for snapshot using the given windows. A nil
// window pointer is left out of the JSON because of the omitempty tags.
buildEvent := func(snapshot rateLimitSnapshot, primary *rateLimitWindow, secondary *rateLimitWindow) []byte {
event := eventPayload{
Type: "codex.rate_limits",
MeteredLimitName: snapshot.LimitID,
LimitName: snapshot.LimitName,
PlanWeight: status.totalWeight,
}
// An empty limit ID falls back to "codex".
if event.MeteredLimitName == "" {
event.MeteredLimitName = "codex"
}
// An empty display name is derived from the metered name, with underscores
// replaced by hyphens.
if event.LimitName == "" {
event.LimitName = strings.ReplaceAll(event.MeteredLimitName, "_", "-")
}
event.RateLimits.Primary = primary
event.RateLimits.Secondary = secondary
// Credits are copied field by field and only when the snapshot has them.
if snapshot.Credits != nil {
event.Credits = &creditsEvent{
HasCredits: snapshot.Credits.HasCredits,
Unlimited: snapshot.Credits.Unlimited,
Balance: snapshot.Credits.Balance,
}
}
// The marshal error is discarded here.
data, _ := json.Marshal(event)
return data
}
// Fallback windows taken from the aggregated five-hour and weekly figures.
// WindowMinutes is left at zero, so it is omitted from the JSON.
defaultPrimary := &rateLimitWindow{
UsedPercent: status.fiveHourUtilization,
ResetAt: resetToEpoch(status.fiveHourReset),
}
// NOTE(review): presumably a pre-change line; its fields match the
// defaultSecondary literal below.
event.RateLimits.Secondary = &rateLimitWindow{
defaultSecondary := &rateLimitWindow{
UsedPercent: status.weeklyUtilization,
ResetAt: resetToEpoch(status.weeklyReset),
}
// NOTE(review): the next two lines look like the pre-change tail, which
// returned a single encoded event.
data, _ := json.Marshal(event)
return data
// Capacity allows for the "codex" event plus one event per entry in
// status.limits.
events := make([][]byte, 0, 1+len(status.limits))
// The "codex" event goes first. When a snapshot with that ID is found, its
// own windows are used and any missing window falls back to the aggregated
// default; when none is found, both defaults are used.
if snapshot := findSnapshotByLimitID(status.limits, "codex"); snapshot != nil {
primary := defaultPrimary
if snapshot.Primary != nil {
primary = &rateLimitWindow{
UsedPercent: snapshot.Primary.UsedPercent,
WindowMinutes: snapshot.Primary.WindowMinutes,
ResetAt: snapshot.Primary.ResetAt,
}
}
secondary := defaultSecondary
if snapshot.Secondary != nil {
secondary = &rateLimitWindow{
UsedPercent: snapshot.Secondary.UsedPercent,
WindowMinutes: snapshot.Secondary.WindowMinutes,
ResetAt: snapshot.Secondary.ResetAt,
}
}
events = append(events, buildEvent(*snapshot, primary, secondary))
} else {
events = append(events, buildEvent(rateLimitSnapshot{LimitID: "codex", LimitName: "codex"}, defaultPrimary, defaultSecondary))
}
// The remaining limits each get their own event. Unlike the "codex" event,
// a missing window stays nil here instead of falling back to the defaults.
for _, snapshot := range status.limits {
// Already emitted above.
if snapshot.LimitID == "codex" {
continue
}
var primary *rateLimitWindow
if snapshot.Primary != nil {
primary = &rateLimitWindow{
UsedPercent: snapshot.Primary.UsedPercent,
WindowMinutes: snapshot.Primary.WindowMinutes,
ResetAt: snapshot.Primary.ResetAt,
}
}
var secondary *rateLimitWindow
if snapshot.Secondary != nil {
secondary = &rateLimitWindow{
UsedPercent: snapshot.Secondary.UsedPercent,
WindowMinutes: snapshot.Secondary.WindowMinutes,
ResetAt: snapshot.Secondary.ResetAt,
}
}
events = append(events, buildEvent(snapshot, primary, secondary))
}
return events
}
func (s *Service) handleWebSocketResponseCompleted(data []byte, usageTracker *AggregatedUsage, requestModel string, username string, weeklyCycleHint *WeeklyCycleHint) {