ccm,ocm: propagate reset times, rewrite headers for all users, add WS status push

- Add fiveHourReset/weeklyReset to statusPayload and aggregatedStatus
  with weight-averaged reset time aggregation across credential pools
- Rewrite response headers (utilization + reset times) for all users,
  not just external credential users
- Rewrite WebSocket rate_limits events for all users with aggregated values
- Add proactive WebSocket status push: synthetic codex.rate_limits events
  sent on connection start and on status changes via statusObserver
- Remove one-shot stream forward compatibility (statusStreamHeader,
  restoreLastUpdatedIfUnchanged, oneShot detection)
This commit is contained in:
世界
2026-03-17 18:13:54 +08:00
parent 7d15d9d282
commit 0a054b9aa4
13 changed files with 381 additions and 287 deletions

View File

@@ -105,6 +105,7 @@ type Credential interface {
fiveHourCap() float64
weeklyCap() float64
planWeight() float64
fiveHourResetTime() time.Time
weeklyResetTime() time.Time
markRateLimited(resetAt time.Time)
earliestReset() time.Time

View File

@@ -421,6 +421,12 @@ func (c *defaultCredential) planWeight() float64 {
return ccmPlanWeight(c.state.accountType, c.state.rateLimitTier)
}
// fiveHourResetTime returns the five-hour reset time currently stored in the
// credential state. The value is copied while holding the state read lock.
func (c *defaultCredential) fiveHourResetTime() time.Time {
	c.stateAccess.RLock()
	resetAt := c.state.fiveHourReset
	c.stateAccess.RUnlock()
	return resetAt
}
func (c *defaultCredential) weeklyResetTime() time.Time {
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()

View File

@@ -27,10 +27,7 @@ import (
"github.com/hashicorp/yamux"
)
const (
reverseProxyBaseURL = "http://reverse-proxy"
statusStreamHeader = "X-CCM-Status-Stream"
)
const reverseProxyBaseURL = "http://reverse-proxy"
type externalCredential struct {
tag string
@@ -70,7 +67,6 @@ type externalCredential struct {
type statusStreamResult struct {
duration time.Duration
frames int
oneShot bool
}
func externalCredentialURLPort(parsedURL *url.URL) uint16 {
@@ -325,6 +321,12 @@ func (c *externalCredential) planWeight() float64 {
return 10
}
// fiveHourResetTime returns the five-hour reset time currently stored in the
// external credential's state. The value is copied while holding the state
// read lock.
func (c *externalCredential) fiveHourResetTime() time.Time {
	c.stateAccess.RLock()
	resetAt := c.state.fiveHourReset
	c.stateAccess.RUnlock()
	return resetAt
}
func (c *externalCredential) weeklyResetTime() time.Time {
c.stateAccess.RLock()
defer c.stateAccess.RUnlock()
@@ -592,7 +594,9 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
var statusResponse struct {
FiveHourUtilization float64 `json:"five_hour_utilization"`
FiveHourReset int64 `json:"five_hour_reset"`
WeeklyUtilization float64 `json:"weekly_utilization"`
WeeklyReset int64 `json:"weekly_reset"`
PlanWeight float64 `json:"plan_weight"`
}
err = json.NewDecoder(response.Body).Decode(&statusResponse)
@@ -612,6 +616,12 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
if statusResponse.PlanWeight > 0 {
c.state.remotePlanWeight = statusResponse.PlanWeight
}
if statusResponse.FiveHourReset > 0 {
c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0)
}
if statusResponse.WeeklyReset > 0 {
c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0)
}
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
c.state.hardRateLimited = false
}
@@ -645,13 +655,8 @@ func (c *externalCredential) statusStreamLoop() {
return
}
var backoff time.Duration
var oneShot bool
consecutiveFailures, backoff, oneShot = c.nextStatusStreamBackoff(result, consecutiveFailures)
if oneShot {
c.logger.Debug("status stream for ", c.tag, " returned a single-frame response, retrying in ", backoff)
} else {
c.logger.Debug("status stream for ", c.tag, " disconnected: ", err, ", reconnecting in ", backoff)
}
consecutiveFailures, backoff = c.nextStatusStreamBackoff(result, consecutiveFailures)
c.logger.Debug("status stream for ", c.tag, " disconnected: ", err, ", reconnecting in ", backoff)
timer := time.NewTimer(backoff)
select {
case <-timer.C:
@@ -679,18 +684,11 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr
}
decoder := json.NewDecoder(response.Body)
isStatusStream := response.Header.Get(statusStreamHeader) == "true"
previousLastUpdated := c.lastUpdatedTime()
var firstFrameUpdatedAt time.Time
for {
var statusResponse statusPayload
err = decoder.Decode(&statusResponse)
if err != nil {
result.duration = time.Since(startTime)
if result.frames == 1 && err == io.EOF && !isStatusStream {
result.oneShot = true
c.restoreLastUpdatedIfUnchanged(firstFrameUpdatedAt, previousLastUpdated)
}
return result, err
}
@@ -701,6 +699,12 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr
if statusResponse.PlanWeight > 0 {
c.state.remotePlanWeight = statusResponse.PlanWeight
}
if statusResponse.FiveHourReset > 0 {
c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0)
}
if statusResponse.WeeklyReset > 0 {
c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0)
}
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
c.state.hardRateLimited = false
}
@@ -710,23 +714,17 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr
c.interruptConnections()
}
result.frames++
updatedAt := c.markUsageStreamUpdated()
if result.frames == 1 {
firstFrameUpdatedAt = updatedAt
}
c.markUsageStreamUpdated()
c.emitStatusUpdate()
}
}
func (c *externalCredential) nextStatusStreamBackoff(result statusStreamResult, consecutiveFailures int) (int, time.Duration, bool) {
if result.oneShot {
return 0, c.pollInterval, true
}
func (c *externalCredential) nextStatusStreamBackoff(result statusStreamResult, consecutiveFailures int) (int, time.Duration) {
if result.duration >= connectorBackoffResetThreshold {
consecutiveFailures = 0
}
consecutiveFailures++
return consecutiveFailures, connectorBackoff(consecutiveFailures), false
return consecutiveFailures, connectorBackoff(consecutiveFailures)
}
func (c *externalCredential) doStreamStatusRequest(ctx context.Context) (*http.Response, error) {
@@ -767,23 +765,10 @@ func (c *externalCredential) lastUpdatedTime() time.Time {
return c.state.lastUpdated
}
func (c *externalCredential) markUsageStreamUpdated() time.Time {
func (c *externalCredential) markUsageStreamUpdated() {
c.stateAccess.Lock()
defer c.stateAccess.Unlock()
now := time.Now()
c.state.lastUpdated = now
return now
}
func (c *externalCredential) restoreLastUpdatedIfUnchanged(expectedCurrent time.Time, previous time.Time) {
if expectedCurrent.IsZero() {
return
}
c.stateAccess.Lock()
defer c.stateAccess.Unlock()
if c.state.lastUpdated.Equal(expectedCurrent) {
c.state.lastUpdated = previous
}
c.state.lastUpdated = time.Now()
}
func (c *externalCredential) markUsagePollAttempted() {

View File

@@ -84,7 +84,7 @@ func newTestYamuxSessionPair(t *testing.T) (*yamux.Session, *yamux.Session) {
return clientSession, serverSession
}
func TestExternalCredentialConnectStatusStreamOneShotRestoresLastUpdated(t *testing.T) {
func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *testing.T) {
credential, subscription := newTestCCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", nil)
oldTime := time.Unix(123, 0)
credential.stateAccess.Lock()
@@ -95,50 +95,6 @@ func TestExternalCredentialConnectStatusStreamOneShotRestoresLastUpdated(t *test
if err != io.EOF {
t.Fatalf("expected EOF, got %v", err)
}
if !result.oneShot {
t.Fatal("expected one-shot result")
}
if result.frames != 1 {
t.Fatalf("expected 1 frame, got %d", result.frames)
}
if !credential.lastUpdatedTime().Equal(oldTime) {
t.Fatalf("expected lastUpdated restored to %v, got %v", oldTime, credential.lastUpdatedTime())
}
if credential.fiveHourUtilization() != 12 || credential.weeklyUtilization() != 34 {
t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization())
}
if count := drainStatusEvents(subscription); count != 1 {
t.Fatalf("expected 1 status event, got %d", count)
}
failures, backoff, oneShot := credential.nextStatusStreamBackoff(result, 3)
if !oneShot {
t.Fatal("expected one-shot backoff branch")
}
if failures != 0 {
t.Fatalf("expected failures reset, got %d", failures)
}
if backoff != credential.pollInterval {
t.Fatalf("expected poll interval backoff %v, got %v", credential.pollInterval, backoff)
}
}
func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *testing.T) {
headers := make(http.Header)
headers.Set(statusStreamHeader, "true")
credential, subscription := newTestCCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", headers)
oldTime := time.Unix(123, 0)
credential.stateAccess.Lock()
credential.state.lastUpdated = oldTime
credential.stateAccess.Unlock()
result, err := credential.connectStatusStream(context.Background())
if err != io.EOF {
t.Fatalf("expected EOF, got %v", err)
}
if result.oneShot {
t.Fatal("did not expect one-shot result")
}
if result.frames != 1 {
t.Fatalf("expected 1 frame, got %d", result.frames)
}
@@ -152,10 +108,7 @@ func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *tes
t.Fatalf("expected 1 status event, got %d", count)
}
failures, backoff, oneShot := credential.nextStatusStreamBackoff(result, 3)
if oneShot {
t.Fatal("did not expect one-shot backoff branch")
}
failures, backoff := credential.nextStatusStreamBackoff(result, 3)
if failures != 4 {
t.Fatalf("expected failures incremented to 4, got %d", failures)
}
@@ -178,9 +131,6 @@ func TestExternalCredentialConnectStatusStreamMultiFrameKeepsLastUpdated(t *test
if err != io.EOF {
t.Fatalf("expected EOF, got %v", err)
}
if result.oneShot {
t.Fatal("did not expect one-shot result")
}
if result.frames != 2 {
t.Fatalf("expected 2 frames, got %d", result.frames)
}

View File

@@ -311,10 +311,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
// Rewrite response headers for external users
if userConfig != nil && userConfig.ExternalCredential != "" {
s.rewriteResponseHeadersForExternalUser(response.Header, provider, userConfig)
}
s.rewriteResponseHeaders(response.Header, provider, userConfig)
for key, values := range response.Header {
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) {

View File

@@ -6,16 +6,52 @@ import (
"net/http"
"strconv"
"strings"
"time"
"github.com/sagernet/sing-box/option"
)
// statusPayload is the JSON body written by the status endpoint and by each
// frame of the status stream. It is also the type external credentials decode
// status stream frames into.
type statusPayload struct {
// FiveHourUtilization is the aggregated five-hour utilization.
FiveHourUtilization float64 `json:"five_hour_utilization"`
// FiveHourReset is the five-hour reset time in Unix seconds; 0 means no
// reset time is known (see resetToEpoch).
FiveHourReset int64 `json:"five_hour_reset"`
// WeeklyUtilization is the aggregated weekly utilization.
WeeklyUtilization float64 `json:"weekly_utilization"`
// WeeklyReset is the weekly reset time in Unix seconds; 0 means no reset
// time is known (see resetToEpoch).
WeeklyReset int64 `json:"weekly_reset"`
// PlanWeight carries the total weight of the credentials that were aggregated.
PlanWeight float64 `json:"plan_weight"`
}
// aggregatedStatus is the in-memory result of computeAggregatedUtilization:
// weight-averaged utilization and reset times across a credential pool.
type aggregatedStatus struct {
// Utilization values; both are 100 when no credential contributed any weight.
fiveHourUtilization float64
weeklyUtilization float64
// totalWeight is the sum of the weights of the credentials that were counted.
totalWeight float64
// Reset times are left as the zero time.Time when no counted credential
// reported one.
fiveHourReset time.Time
weeklyReset time.Time
}
// resetToEpoch converts a reset time to Unix seconds. The zero time.Time maps
// to 0 so that an unknown reset time is reported as 0 rather than as the Unix
// value of year 1.
func resetToEpoch(t time.Time) int64 {
	var epoch int64
	if !t.IsZero() {
		epoch = t.Unix()
	}
	return epoch
}
// equal reports whether two aggregated statuses would produce the same
// payload: utilizations and total weight must match exactly, and reset times
// are compared at whole-second (Unix epoch) resolution via resetToEpoch.
func (s aggregatedStatus) equal(other aggregatedStatus) bool {
	if s.fiveHourUtilization != other.fiveHourUtilization {
		return false
	}
	if s.weeklyUtilization != other.weeklyUtilization {
		return false
	}
	if s.totalWeight != other.totalWeight {
		return false
	}
	if resetToEpoch(s.fiveHourReset) != resetToEpoch(other.fiveHourReset) {
		return false
	}
	return resetToEpoch(s.weeklyReset) == resetToEpoch(other.weeklyReset)
}
// toPayload converts the aggregated status into its JSON wire form. Reset
// times are converted to Unix seconds with resetToEpoch, and the total weight
// is sent as the plan weight.
func (s aggregatedStatus) toPayload() statusPayload {
	var payload statusPayload
	payload.FiveHourUtilization = s.fiveHourUtilization
	payload.FiveHourReset = resetToEpoch(s.fiveHourReset)
	payload.WeeklyUtilization = s.weeklyUtilization
	payload.WeeklyReset = resetToEpoch(s.weeklyReset)
	payload.PlanWeight = s.totalWeight
	return payload
}
func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed")
@@ -68,15 +104,11 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
}
provider.pollIfStale(r.Context())
avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
status := s.computeAggregatedUtilization(provider, userConfig)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(statusPayload{
FiveHourUtilization: avgFiveHour,
WeeklyUtilization: avgWeekly,
PlanWeight: totalWeight,
})
json.NewEncoder(w).Encode(status.toPayload())
}
func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, provider credentialProvider, userConfig *option.CCMUser) {
@@ -96,16 +128,11 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro
provider.pollIfStale(r.Context())
w.Header().Set("Content-Type", "application/json")
w.Header().Set(statusStreamHeader, "true")
w.WriteHeader(http.StatusOK)
lastFiveHour, lastWeekly, lastWeight := s.computeAggregatedUtilization(provider, userConfig)
last := s.computeAggregatedUtilization(provider, userConfig)
buf := &bytes.Buffer{}
json.NewEncoder(buf).Encode(statusPayload{
FiveHourUtilization: lastFiveHour,
WeeklyUtilization: lastWeekly,
PlanWeight: lastWeight,
})
json.NewEncoder(buf).Encode(last.toPayload())
_, writeErr := w.Write(buf.Bytes())
if writeErr != nil {
return
@@ -127,19 +154,13 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro
}
}
drained:
fiveHour, weekly, weight := s.computeAggregatedUtilization(provider, userConfig)
if fiveHour == lastFiveHour && weekly == lastWeekly && weight == lastWeight {
current := s.computeAggregatedUtilization(provider, userConfig)
if current.equal(last) {
continue
}
lastFiveHour = fiveHour
lastWeekly = weekly
lastWeight = weight
last = current
buf.Reset()
json.NewEncoder(buf).Encode(statusPayload{
FiveHourUtilization: fiveHour,
WeeklyUtilization: weekly,
PlanWeight: weight,
})
json.NewEncoder(buf).Encode(current.toPayload())
_, writeErr = w.Write(buf.Bytes())
if writeErr != nil {
return
@@ -149,8 +170,11 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro
}
}
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) (float64, float64, float64) {
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) aggregatedStatus {
var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64
now := time.Now()
var totalWeightedHoursUntil5hReset, total5hResetWeight float64
var totalWeightedHoursUntilWeeklyReset, totalWeeklyResetWeight float64
for _, credential := range provider.allCredentials() {
if !credential.isAvailable() {
continue
@@ -173,21 +197,59 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user
totalWeightedRemaining5h += remaining5h * weight
totalWeightedRemainingWeekly += remainingWeekly * weight
totalWeight += weight
fiveHourReset := credential.fiveHourResetTime()
if !fiveHourReset.IsZero() {
hours := fiveHourReset.Sub(now).Hours()
if hours < 0 {
hours = 0
}
totalWeightedHoursUntil5hReset += hours * weight
total5hResetWeight += weight
}
weeklyReset := credential.weeklyResetTime()
if !weeklyReset.IsZero() {
hours := weeklyReset.Sub(now).Hours()
if hours < 0 {
hours = 0
}
totalWeightedHoursUntilWeeklyReset += hours * weight
totalWeeklyResetWeight += weight
}
}
if totalWeight == 0 {
return 100, 100, 0
return aggregatedStatus{
fiveHourUtilization: 100,
weeklyUtilization: 100,
}
}
return 100 - totalWeightedRemaining5h/totalWeight,
100 - totalWeightedRemainingWeekly/totalWeight,
totalWeight
result := aggregatedStatus{
fiveHourUtilization: 100 - totalWeightedRemaining5h/totalWeight,
weeklyUtilization: 100 - totalWeightedRemainingWeekly/totalWeight,
totalWeight: totalWeight,
}
if total5hResetWeight > 0 {
avgHours := totalWeightedHoursUntil5hReset / total5hResetWeight
result.fiveHourReset = now.Add(time.Duration(avgHours * float64(time.Hour)))
}
if totalWeeklyResetWeight > 0 {
avgHours := totalWeightedHoursUntilWeeklyReset / totalWeeklyResetWeight
result.weeklyReset = now.Add(time.Duration(avgHours * float64(time.Hour)))
}
return result
}
func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, provider credentialProvider, userConfig *option.CCMUser) {
avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
headers.Set("anthropic-ratelimit-unified-5h-utilization", strconv.FormatFloat(avgFiveHour/100, 'f', 6, 64))
headers.Set("anthropic-ratelimit-unified-7d-utilization", strconv.FormatFloat(avgWeekly/100, 'f', 6, 64))
if totalWeight > 0 {
headers.Set("X-CCM-Plan-Weight", strconv.FormatFloat(totalWeight, 'f', -1, 64))
func (s *Service) rewriteResponseHeaders(headers http.Header, provider credentialProvider, userConfig *option.CCMUser) {
status := s.computeAggregatedUtilization(provider, userConfig)
headers.Set("anthropic-ratelimit-unified-5h-utilization", strconv.FormatFloat(status.fiveHourUtilization/100, 'f', 6, 64))
headers.Set("anthropic-ratelimit-unified-7d-utilization", strconv.FormatFloat(status.weeklyUtilization/100, 'f', 6, 64))
if !status.fiveHourReset.IsZero() {
headers.Set("anthropic-ratelimit-unified-5h-reset", strconv.FormatInt(status.fiveHourReset.Unix(), 10))
}
if !status.weeklyReset.IsZero() {
headers.Set("anthropic-ratelimit-unified-7d-reset", strconv.FormatInt(status.weeklyReset.Unix(), 10))
}
if status.totalWeight > 0 {
headers.Set("X-CCM-Plan-Weight", strconv.FormatFloat(status.totalWeight, 'f', -1, 64))
}
}

View File

@@ -107,6 +107,7 @@ type Credential interface {
weeklyCap() float64
planWeight() float64
weeklyResetTime() time.Time
fiveHourResetTime() time.Time
markRateLimited(resetAt time.Time)
earliestReset() time.Time
unavailableError() error

View File

@@ -476,6 +476,12 @@ func (c *defaultCredential) weeklyResetTime() time.Time {
return c.state.weeklyReset
}
// fiveHourResetTime returns the five-hour reset time currently stored in the
// credential state. The value is copied while holding the state read lock.
func (c *defaultCredential) fiveHourResetTime() time.Time {
	c.stateAccess.RLock()
	resetAt := c.state.fiveHourReset
	c.stateAccess.RUnlock()
	return resetAt
}
func (c *defaultCredential) isAvailable() bool {
c.retryCredentialReloadIfNeeded()

View File

@@ -28,10 +28,7 @@ import (
"github.com/hashicorp/yamux"
)
const (
reverseProxyBaseURL = "http://reverse-proxy"
statusStreamHeader = "X-OCM-Status-Stream"
)
const reverseProxyBaseURL = "http://reverse-proxy"
type externalCredential struct {
tag string
@@ -77,7 +74,6 @@ type reverseSessionDialer struct {
type statusStreamResult struct {
duration time.Duration
frames int
oneShot bool
}
func (d reverseSessionDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
@@ -355,6 +351,12 @@ func (c *externalCredential) weeklyResetTime() time.Time {
return c.state.weeklyReset
}
// fiveHourResetTime returns the five-hour reset time currently stored in the
// external credential's state. The value is copied while holding the state
// read lock.
func (c *externalCredential) fiveHourResetTime() time.Time {
	c.stateAccess.RLock()
	resetAt := c.state.fiveHourReset
	c.stateAccess.RUnlock()
	return resetAt
}
func (c *externalCredential) markRateLimited(resetAt time.Time) {
c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt)))
c.stateAccess.Lock()
@@ -634,7 +636,9 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
var statusResponse struct {
FiveHourUtilization float64 `json:"five_hour_utilization"`
FiveHourReset int64 `json:"five_hour_reset"`
WeeklyUtilization float64 `json:"weekly_utilization"`
WeeklyReset int64 `json:"weekly_reset"`
PlanWeight float64 `json:"plan_weight"`
}
err = json.NewDecoder(response.Body).Decode(&statusResponse)
@@ -651,6 +655,12 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
c.state.consecutivePollFailures = 0
c.state.fiveHourUtilization = statusResponse.FiveHourUtilization
c.state.weeklyUtilization = statusResponse.WeeklyUtilization
if statusResponse.FiveHourReset > 0 {
c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0)
}
if statusResponse.WeeklyReset > 0 {
c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0)
}
if statusResponse.PlanWeight > 0 {
c.state.remotePlanWeight = statusResponse.PlanWeight
}
@@ -687,13 +697,8 @@ func (c *externalCredential) statusStreamLoop() {
return
}
var backoff time.Duration
var oneShot bool
consecutiveFailures, backoff, oneShot = c.nextStatusStreamBackoff(result, consecutiveFailures)
if oneShot {
c.logger.Debug("status stream for ", c.tag, " returned a single-frame response, retrying in ", backoff)
} else {
c.logger.Debug("status stream for ", c.tag, " disconnected: ", err, ", reconnecting in ", backoff)
}
consecutiveFailures, backoff = c.nextStatusStreamBackoff(result, consecutiveFailures)
c.logger.Debug("status stream for ", c.tag, " disconnected: ", err, ", reconnecting in ", backoff)
timer := time.NewTimer(backoff)
select {
case <-timer.C:
@@ -721,18 +726,11 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr
}
decoder := json.NewDecoder(response.Body)
isStatusStream := response.Header.Get(statusStreamHeader) == "true"
previousLastUpdated := c.lastUpdatedTime()
var firstFrameUpdatedAt time.Time
for {
var statusResponse statusPayload
err = decoder.Decode(&statusResponse)
if err != nil {
result.duration = time.Since(startTime)
if result.frames == 1 && err == io.EOF && !isStatusStream {
result.oneShot = true
c.restoreLastUpdatedIfUnchanged(firstFrameUpdatedAt, previousLastUpdated)
}
return result, err
}
@@ -740,6 +738,12 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr
c.state.consecutivePollFailures = 0
c.state.fiveHourUtilization = statusResponse.FiveHourUtilization
c.state.weeklyUtilization = statusResponse.WeeklyUtilization
if statusResponse.FiveHourReset > 0 {
c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0)
}
if statusResponse.WeeklyReset > 0 {
c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0)
}
if statusResponse.PlanWeight > 0 {
c.state.remotePlanWeight = statusResponse.PlanWeight
}
@@ -752,23 +756,17 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr
c.interruptConnections()
}
result.frames++
updatedAt := c.markUsageStreamUpdated()
if result.frames == 1 {
firstFrameUpdatedAt = updatedAt
}
c.markUsageStreamUpdated()
c.emitStatusUpdate()
}
}
func (c *externalCredential) nextStatusStreamBackoff(result statusStreamResult, consecutiveFailures int) (int, time.Duration, bool) {
if result.oneShot {
return 0, c.pollInterval, true
}
func (c *externalCredential) nextStatusStreamBackoff(result statusStreamResult, consecutiveFailures int) (int, time.Duration) {
if result.duration >= connectorBackoffResetThreshold {
consecutiveFailures = 0
}
consecutiveFailures++
return consecutiveFailures, connectorBackoff(consecutiveFailures), false
return consecutiveFailures, connectorBackoff(consecutiveFailures)
}
func (c *externalCredential) doStreamStatusRequest(ctx context.Context) (*http.Response, error) {
@@ -809,23 +807,10 @@ func (c *externalCredential) lastUpdatedTime() time.Time {
return c.state.lastUpdated
}
func (c *externalCredential) markUsageStreamUpdated() time.Time {
func (c *externalCredential) markUsageStreamUpdated() {
c.stateAccess.Lock()
defer c.stateAccess.Unlock()
now := time.Now()
c.state.lastUpdated = now
return now
}
func (c *externalCredential) restoreLastUpdatedIfUnchanged(expectedCurrent time.Time, previous time.Time) {
if expectedCurrent.IsZero() {
return
}
c.stateAccess.Lock()
defer c.stateAccess.Unlock()
if c.state.lastUpdated.Equal(expectedCurrent) {
c.state.lastUpdated = previous
}
c.state.lastUpdated = time.Now()
}
func (c *externalCredential) markUsagePollAttempted() {

View File

@@ -84,7 +84,7 @@ func newTestYamuxSessionPair(t *testing.T) (*yamux.Session, *yamux.Session) {
return clientSession, serverSession
}
func TestExternalCredentialConnectStatusStreamOneShotRestoresLastUpdated(t *testing.T) {
func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *testing.T) {
credential, subscription := newTestOCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", nil)
oldTime := time.Unix(123, 0)
credential.stateAccess.Lock()
@@ -95,50 +95,6 @@ func TestExternalCredentialConnectStatusStreamOneShotRestoresLastUpdated(t *test
if err != io.EOF {
t.Fatalf("expected EOF, got %v", err)
}
if !result.oneShot {
t.Fatal("expected one-shot result")
}
if result.frames != 1 {
t.Fatalf("expected 1 frame, got %d", result.frames)
}
if !credential.lastUpdatedTime().Equal(oldTime) {
t.Fatalf("expected lastUpdated restored to %v, got %v", oldTime, credential.lastUpdatedTime())
}
if credential.fiveHourUtilization() != 12 || credential.weeklyUtilization() != 34 {
t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization())
}
if count := drainStatusEvents(subscription); count != 1 {
t.Fatalf("expected 1 status event, got %d", count)
}
failures, backoff, oneShot := credential.nextStatusStreamBackoff(result, 3)
if !oneShot {
t.Fatal("expected one-shot backoff branch")
}
if failures != 0 {
t.Fatalf("expected failures reset, got %d", failures)
}
if backoff != credential.pollInterval {
t.Fatalf("expected poll interval backoff %v, got %v", credential.pollInterval, backoff)
}
}
func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *testing.T) {
headers := make(http.Header)
headers.Set(statusStreamHeader, "true")
credential, subscription := newTestOCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", headers)
oldTime := time.Unix(123, 0)
credential.stateAccess.Lock()
credential.state.lastUpdated = oldTime
credential.stateAccess.Unlock()
result, err := credential.connectStatusStream(context.Background())
if err != io.EOF {
t.Fatalf("expected EOF, got %v", err)
}
if result.oneShot {
t.Fatal("did not expect one-shot result")
}
if result.frames != 1 {
t.Fatalf("expected 1 frame, got %d", result.frames)
}
@@ -152,10 +108,7 @@ func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *tes
t.Fatalf("expected 1 status event, got %d", count)
}
failures, backoff, oneShot := credential.nextStatusStreamBackoff(result, 3)
if oneShot {
t.Fatal("did not expect one-shot backoff branch")
}
failures, backoff := credential.nextStatusStreamBackoff(result, 3)
if failures != 4 {
t.Fatalf("expected failures incremented to 4, got %d", failures)
}
@@ -178,9 +131,6 @@ func TestExternalCredentialConnectStatusStreamMultiFrameKeepsLastUpdated(t *test
if err != io.EOF {
t.Fatalf("expected EOF, got %v", err)
}
if result.oneShot {
t.Fatal("did not expect one-shot result")
}
if result.frames != 2 {
t.Fatalf("expected 2 frames, got %d", result.frames)
}

View File

@@ -291,10 +291,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
// Rewrite response headers for external users
if userConfig != nil && userConfig.ExternalCredential != "" {
s.rewriteResponseHeadersForExternalUser(response.Header, provider, userConfig)
}
s.rewriteResponseHeaders(response.Header, provider, userConfig)
for key, values := range response.Header {
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) {

View File

@@ -6,16 +6,52 @@ import (
"net/http"
"strconv"
"strings"
"time"
"github.com/sagernet/sing-box/option"
)
// statusPayload is the JSON body written by the status endpoint and by each
// frame of the status stream. It is also the type external credentials decode
// status stream frames into.
type statusPayload struct {
// FiveHourUtilization is the aggregated five-hour utilization.
FiveHourUtilization float64 `json:"five_hour_utilization"`
// FiveHourReset is the five-hour reset time in Unix seconds; 0 means no
// reset time is known (see resetToEpoch).
FiveHourReset int64 `json:"five_hour_reset"`
// WeeklyUtilization is the aggregated weekly utilization.
WeeklyUtilization float64 `json:"weekly_utilization"`
// WeeklyReset is the weekly reset time in Unix seconds; 0 means no reset
// time is known (see resetToEpoch).
WeeklyReset int64 `json:"weekly_reset"`
// PlanWeight carries the total weight of the credentials that were aggregated.
PlanWeight float64 `json:"plan_weight"`
}
// aggregatedStatus is the in-memory result of computeAggregatedUtilization:
// weight-averaged utilization and reset times across a credential pool.
type aggregatedStatus struct {
// Utilization values; both are 100 when no credential contributed any weight.
fiveHourUtilization float64
weeklyUtilization float64
// totalWeight is the sum of the weights of the credentials that were counted.
totalWeight float64
// Reset times are left as the zero time.Time when no counted credential
// reported one.
fiveHourReset time.Time
weeklyReset time.Time
}
// resetToEpoch converts a reset time to Unix seconds. The zero time.Time maps
// to 0 so that an unknown reset time is reported as 0 rather than as the Unix
// value of year 1.
func resetToEpoch(t time.Time) int64 {
	var epoch int64
	if !t.IsZero() {
		epoch = t.Unix()
	}
	return epoch
}
// equal reports whether two aggregated statuses would produce the same
// payload: utilizations and total weight must match exactly, and reset times
// are compared at whole-second (Unix epoch) resolution via resetToEpoch.
func (s aggregatedStatus) equal(other aggregatedStatus) bool {
	if s.fiveHourUtilization != other.fiveHourUtilization {
		return false
	}
	if s.weeklyUtilization != other.weeklyUtilization {
		return false
	}
	if s.totalWeight != other.totalWeight {
		return false
	}
	if resetToEpoch(s.fiveHourReset) != resetToEpoch(other.fiveHourReset) {
		return false
	}
	return resetToEpoch(s.weeklyReset) == resetToEpoch(other.weeklyReset)
}
// toPayload converts the aggregated status into its JSON wire form. Reset
// times are converted to Unix seconds with resetToEpoch, and the total weight
// is sent as the plan weight.
func (s aggregatedStatus) toPayload() statusPayload {
	var payload statusPayload
	payload.FiveHourUtilization = s.fiveHourUtilization
	payload.FiveHourReset = resetToEpoch(s.fiveHourReset)
	payload.WeeklyUtilization = s.weeklyUtilization
	payload.WeeklyReset = resetToEpoch(s.weeklyReset)
	payload.PlanWeight = s.totalWeight
	return payload
}
func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed")
@@ -68,15 +104,11 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
}
provider.pollIfStale(r.Context())
avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
status := s.computeAggregatedUtilization(provider, userConfig)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(statusPayload{
FiveHourUtilization: avgFiveHour,
WeeklyUtilization: avgWeekly,
PlanWeight: totalWeight,
})
json.NewEncoder(w).Encode(status.toPayload())
}
func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, provider credentialProvider, userConfig *option.OCMUser) {
@@ -96,16 +128,11 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro
provider.pollIfStale(r.Context())
w.Header().Set("Content-Type", "application/json")
w.Header().Set(statusStreamHeader, "true")
w.WriteHeader(http.StatusOK)
lastFiveHour, lastWeekly, lastWeight := s.computeAggregatedUtilization(provider, userConfig)
last := s.computeAggregatedUtilization(provider, userConfig)
buf := &bytes.Buffer{}
json.NewEncoder(buf).Encode(statusPayload{
FiveHourUtilization: lastFiveHour,
WeeklyUtilization: lastWeekly,
PlanWeight: lastWeight,
})
json.NewEncoder(buf).Encode(last.toPayload())
_, writeErr := w.Write(buf.Bytes())
if writeErr != nil {
return
@@ -127,19 +154,13 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro
}
}
drained:
fiveHour, weekly, weight := s.computeAggregatedUtilization(provider, userConfig)
if fiveHour == lastFiveHour && weekly == lastWeekly && weight == lastWeight {
current := s.computeAggregatedUtilization(provider, userConfig)
if current.equal(last) {
continue
}
lastFiveHour = fiveHour
lastWeekly = weekly
lastWeight = weight
last = current
buf.Reset()
json.NewEncoder(buf).Encode(statusPayload{
FiveHourUtilization: fiveHour,
WeeklyUtilization: weekly,
PlanWeight: weight,
})
json.NewEncoder(buf).Encode(current.toPayload())
_, writeErr = w.Write(buf.Bytes())
if writeErr != nil {
return
@@ -149,8 +170,11 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro
}
}
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) (float64, float64, float64) {
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) aggregatedStatus {
var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64
now := time.Now()
var totalWeightedHoursUntil5hReset, total5hResetWeight float64
var totalWeightedHoursUntilWeeklyReset, totalWeeklyResetWeight float64
for _, credential := range provider.allCredentials() {
if !credential.isAvailable() {
continue
@@ -173,26 +197,63 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user
totalWeightedRemaining5h += remaining5h * weight
totalWeightedRemainingWeekly += remainingWeekly * weight
totalWeight += weight
fiveHourReset := credential.fiveHourResetTime()
if !fiveHourReset.IsZero() {
hours := fiveHourReset.Sub(now).Hours()
if hours < 0 {
hours = 0
}
totalWeightedHoursUntil5hReset += hours * weight
total5hResetWeight += weight
}
weeklyReset := credential.weeklyResetTime()
if !weeklyReset.IsZero() {
hours := weeklyReset.Sub(now).Hours()
if hours < 0 {
hours = 0
}
totalWeightedHoursUntilWeeklyReset += hours * weight
totalWeeklyResetWeight += weight
}
}
if totalWeight == 0 {
return 100, 100, 0
return aggregatedStatus{
fiveHourUtilization: 100,
weeklyUtilization: 100,
}
}
return 100 - totalWeightedRemaining5h/totalWeight,
100 - totalWeightedRemainingWeekly/totalWeight,
totalWeight
result := aggregatedStatus{
fiveHourUtilization: 100 - totalWeightedRemaining5h/totalWeight,
weeklyUtilization: 100 - totalWeightedRemainingWeekly/totalWeight,
totalWeight: totalWeight,
}
if total5hResetWeight > 0 {
avgHours := totalWeightedHoursUntil5hReset / total5hResetWeight
result.fiveHourReset = now.Add(time.Duration(avgHours * float64(time.Hour)))
}
if totalWeeklyResetWeight > 0 {
avgHours := totalWeightedHoursUntilWeeklyReset / totalWeeklyResetWeight
result.weeklyReset = now.Add(time.Duration(avgHours * float64(time.Hour)))
}
return result
}
func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, provider credentialProvider, userConfig *option.OCMUser) {
avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
func (s *Service) rewriteResponseHeaders(headers http.Header, provider credentialProvider, userConfig *option.OCMUser) {
status := s.computeAggregatedUtilization(provider, userConfig)
activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
if activeLimitIdentifier == "" {
activeLimitIdentifier = "codex"
}
headers.Set("x-"+activeLimitIdentifier+"-primary-used-percent", strconv.FormatFloat(avgFiveHour, 'f', 2, 64))
headers.Set("x-"+activeLimitIdentifier+"-secondary-used-percent", strconv.FormatFloat(avgWeekly, 'f', 2, 64))
if totalWeight > 0 {
headers.Set("X-OCM-Plan-Weight", strconv.FormatFloat(totalWeight, 'f', -1, 64))
headers.Set("x-"+activeLimitIdentifier+"-primary-used-percent", strconv.FormatFloat(status.fiveHourUtilization, 'f', 2, 64))
headers.Set("x-"+activeLimitIdentifier+"-secondary-used-percent", strconv.FormatFloat(status.weeklyUtilization, 'f', 2, 64))
if !status.fiveHourReset.IsZero() {
headers.Set("x-"+activeLimitIdentifier+"-primary-reset-at", strconv.FormatInt(status.fiveHourReset.Unix(), 10))
}
if !status.weeklyReset.IsZero() {
headers.Set("x-"+activeLimitIdentifier+"-secondary-reset-at", strconv.FormatInt(status.weeklyReset.Unix(), 10))
}
if status.totalWeight > 0 {
headers.Set("X-OCM-Plan-Weight", strconv.FormatFloat(status.totalWeight, 'f', -1, 64))
}
}

View File

@@ -252,9 +252,7 @@ func (s *Service) handleWebSocket(
clientResponseHeaders[key] = append([]string(nil), values...)
}
}
if userConfig != nil && userConfig.ExternalCredential != "" {
s.rewriteResponseHeadersForExternalUser(clientResponseHeaders, provider, userConfig)
}
s.rewriteResponseHeaders(clientResponseHeaders, provider, userConfig)
clientUpgrader := ws.HTTPUpgrader{
Header: clientResponseHeaders,
@@ -292,10 +290,16 @@ func (s *Service) handleWebSocket(
upstreamReadWriter = upstreamConn
}
rateLimitIdentifier := normalizeRateLimitIdentifier(upstreamResponseHeaders.Get("x-codex-active-limit"))
if rateLimitIdentifier == "" {
rateLimitIdentifier = "codex"
}
var clientWriteAccess sync.Mutex
modelChannel := make(chan string, 1)
var waitGroup sync.WaitGroup
waitGroup.Add(2)
waitGroup.Add(3)
go func() {
defer waitGroup.Done()
defer session.Close()
@@ -304,7 +308,12 @@ func (s *Service) handleWebSocket(
go func() {
defer waitGroup.Done()
defer session.Close()
s.proxyWebSocketUpstreamToClient(ctx, upstreamReadWriter, clientConn, selectedCredential, userConfig, provider, modelChannel, username, weeklyCycleHint)
s.proxyWebSocketUpstreamToClient(ctx, upstreamReadWriter, clientConn, &clientWriteAccess, selectedCredential, userConfig, provider, modelChannel, username, weeklyCycleHint)
}()
go func() {
defer waitGroup.Done()
defer session.Close()
s.pushWebSocketAggregatedStatus(ctx, clientConn, &clientWriteAccess, provider, userConfig, rateLimitIdentifier)
}()
waitGroup.Wait()
}
@@ -363,7 +372,7 @@ func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn
}
}
func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamReadWriter io.ReadWriter, clientConn net.Conn, selectedCredential Credential, userConfig *option.OCMUser, provider credentialProvider, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) {
func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamReadWriter io.ReadWriter, clientConn net.Conn, clientWriteAccess *sync.Mutex, selectedCredential Credential, userConfig *option.OCMUser, provider credentialProvider, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) {
usageTracker := selectedCredential.usageTrackerOrNil()
var requestModel string
for {
@@ -384,11 +393,9 @@ func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamRe
switch event.Type {
case "codex.rate_limits":
s.handleWebSocketRateLimitsEvent(data, selectedCredential)
if userConfig != nil && userConfig.ExternalCredential != "" {
rewritten, rewriteErr := s.rewriteWebSocketRateLimitsForExternalUser(data, provider, userConfig)
if rewriteErr == nil {
data = rewritten
}
rewritten, rewriteErr := s.rewriteWebSocketRateLimits(data, provider, userConfig)
if rewriteErr == nil {
data = rewritten
}
case "error":
if event.StatusCode == http.StatusTooManyRequests {
@@ -407,7 +414,9 @@ func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamRe
}
}
clientWriteAccess.Lock()
err = wsutil.WriteServerMessage(clientConn, opCode, data)
clientWriteAccess.Unlock()
if err != nil {
if !E.IsClosedOrCanceled(err) {
s.logger.DebugContext(ctx, "write client websocket: ", err)
@@ -483,7 +492,7 @@ func (s *Service) handleWebSocketErrorRateLimited(data []byte, selectedCredentia
selectedCredential.markRateLimited(resetAt)
}
func (s *Service) rewriteWebSocketRateLimitsForExternalUser(data []byte, provider credentialProvider, userConfig *option.OCMUser) ([]byte, error) {
func (s *Service) rewriteWebSocketRateLimits(data []byte, provider credentialProvider, userConfig *option.OCMUser) ([]byte, error) {
var event map[string]json.RawMessage
err := json.Unmarshal(data, &event)
if err != nil {
@@ -501,13 +510,13 @@ func (s *Service) rewriteWebSocketRateLimitsForExternalUser(data []byte, provide
return nil, err
}
averageFiveHour, averageWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
status := s.computeAggregatedUtilization(provider, userConfig)
if totalWeight > 0 {
event["plan_weight"], _ = json.Marshal(totalWeight)
if status.totalWeight > 0 {
event["plan_weight"], _ = json.Marshal(status.totalWeight)
}
primaryData, err := rewriteWebSocketRateLimitWindow(rateLimits["primary"], averageFiveHour)
primaryData, err := rewriteWebSocketRateLimitWindow(rateLimits["primary"], status.fiveHourUtilization, resetToEpoch(status.fiveHourReset))
if err != nil {
return nil, err
}
@@ -515,7 +524,7 @@ func (s *Service) rewriteWebSocketRateLimitsForExternalUser(data []byte, provide
rateLimits["primary"] = primaryData
}
secondaryData, err := rewriteWebSocketRateLimitWindow(rateLimits["secondary"], averageWeekly)
secondaryData, err := rewriteWebSocketRateLimitWindow(rateLimits["secondary"], status.weeklyUtilization, resetToEpoch(status.weeklyReset))
if err != nil {
return nil, err
}
@@ -531,7 +540,7 @@ func (s *Service) rewriteWebSocketRateLimitsForExternalUser(data []byte, provide
return json.Marshal(event)
}
func rewriteWebSocketRateLimitWindow(data json.RawMessage, usedPercent float64) (json.RawMessage, error) {
func rewriteWebSocketRateLimitWindow(data json.RawMessage, usedPercent float64, resetAt int64) (json.RawMessage, error) {
if len(data) == 0 || string(data) == "null" {
return nil, nil
}
@@ -547,9 +556,93 @@ func rewriteWebSocketRateLimitWindow(data json.RawMessage, usedPercent float64)
return nil, err
}
if resetAt > 0 {
window["reset_at"], err = json.Marshal(resetAt)
if err != nil {
return nil, err
}
}
return json.Marshal(window)
}
// pushWebSocketAggregatedStatus proactively sends synthetic "codex.rate_limits"
// events carrying the aggregated status to the client WebSocket.
//
// One event is written as soon as the status observer subscription is
// established; afterwards an event is written each time the observer signals
// and the recomputed aggregate differs from the previously recorded one. Every
// write is performed while holding clientWriteAccess.
//
// The function returns, without reporting an error, when subscribing fails,
// when a write fails, when ctx is cancelled, or when the observer's done
// channel fires.
func (s *Service) pushWebSocketAggregatedStatus(ctx context.Context, clientConn net.Conn, clientWriteAccess *sync.Mutex, provider credentialProvider, userConfig *option.OCMUser, rateLimitIdentifier string) {
	subscription, done, err := s.statusObserver.Subscribe()
	if err != nil {
		return
	}
	defer s.statusObserver.UnSubscribe(subscription)

	// send serializes one snapshot and writes it as a text frame under the
	// shared client write lock.
	send := func(snapshot aggregatedStatus) error {
		payload := buildSyntheticRateLimitsEvent(rateLimitIdentifier, snapshot)
		clientWriteAccess.Lock()
		writeErr := wsutil.WriteServerMessage(clientConn, ws.OpText, payload)
		clientWriteAccess.Unlock()
		return writeErr
	}

	last := s.computeAggregatedUtilization(provider, userConfig)
	if send(last) != nil {
		return
	}

	for {
		select {
		case <-ctx.Done():
			return
		case <-done:
			return
		case <-subscription:
			// Swallow any notifications that are already queued so a burst
			// of updates results in a single recomputation.
		drain:
			for {
				select {
				case <-subscription:
				default:
					break drain
				}
			}
			current := s.computeAggregatedUtilization(provider, userConfig)
			if current.equal(last) {
				continue
			}
			last = current
			if send(current) != nil {
				return
			}
		}
	}
}
// buildSyntheticRateLimitsEvent encodes status as a JSON event of type
// "codex.rate_limits" with identifier as its limit_name.
//
// The primary window carries the five-hour utilization and reset, the
// secondary window the weekly ones. Both windows are always emitted; reset_at
// and plan_weight are omitted from the JSON when their value is zero.
func buildSyntheticRateLimitsEvent(identifier string, status aggregatedStatus) []byte {
	type window struct {
		UsedPercent float64 `json:"used_percent"`
		ResetAt     int64   `json:"reset_at,omitempty"`
	}
	type windowPair struct {
		Primary   *window `json:"primary,omitempty"`
		Secondary *window `json:"secondary,omitempty"`
	}
	type syntheticEvent struct {
		Type       string     `json:"type"`
		RateLimits windowPair `json:"rate_limits"`
		LimitName  string     `json:"limit_name"`
		PlanWeight float64    `json:"plan_weight,omitempty"`
	}

	// The marshal error is intentionally discarded; on failure the returned
	// slice is nil.
	payload, _ := json.Marshal(syntheticEvent{
		Type: "codex.rate_limits",
		RateLimits: windowPair{
			Primary: &window{
				UsedPercent: status.fiveHourUtilization,
				ResetAt:     resetToEpoch(status.fiveHourReset),
			},
			Secondary: &window{
				UsedPercent: status.weeklyUtilization,
				ResetAt:     resetToEpoch(status.weeklyReset),
			},
		},
		LimitName:  identifier,
		PlanWeight: status.totalWeight,
	})
	return payload
}
func (s *Service) handleWebSocketResponseCompleted(data []byte, usageTracker *AggregatedUsage, requestModel string, username string, weeklyCycleHint *WeeklyCycleHint) {
var streamEvent responses.ResponseStreamEventUnion
if json.Unmarshal(data, &streamEvent) != nil {