From 0a054b9aa4fd3c5f60e47c3c139295166cb66227 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 17 Mar 2026 18:13:54 +0800 Subject: [PATCH] ccm,ocm: propagate reset times, rewrite headers for all users, add WS status push - Add fiveHourReset/weeklyReset to statusPayload and aggregatedStatus with weight-averaged reset time aggregation across credential pools - Rewrite response headers (utilization + reset times) for all users, not just external credential users - Rewrite WebSocket rate_limits events for all users with aggregated values - Add proactive WebSocket status push: synthetic codex.rate_limits events sent on connection start and on status changes via statusObserver - Remove one-shot stream forward compatibility (statusStreamHeader, restoreLastUpdatedIfUnchanged, oneShot detection) --- service/ccm/credential.go | 1 + service/ccm/credential_default.go | 6 ++ service/ccm/credential_external.go | 71 ++++++-------- service/ccm/credential_status_test.go | 54 +---------- service/ccm/service_handler.go | 5 +- service/ccm/service_status.go | 132 ++++++++++++++++++------- service/ocm/credential.go | 1 + service/ocm/credential_default.go | 6 ++ service/ocm/credential_external.go | 71 ++++++-------- service/ocm/credential_status_test.go | 54 +---------- service/ocm/service_handler.go | 5 +- service/ocm/service_status.go | 133 +++++++++++++++++++------- service/ocm/service_websocket.go | 129 +++++++++++++++++++++---- 13 files changed, 381 insertions(+), 287 deletions(-) diff --git a/service/ccm/credential.go b/service/ccm/credential.go index 9defed434..6f41ba128 100644 --- a/service/ccm/credential.go +++ b/service/ccm/credential.go @@ -105,6 +105,7 @@ type Credential interface { fiveHourCap() float64 weeklyCap() float64 planWeight() float64 + fiveHourResetTime() time.Time weeklyResetTime() time.Time markRateLimited(resetAt time.Time) earliestReset() time.Time diff --git a/service/ccm/credential_default.go b/service/ccm/credential_default.go index f23ba3f65..bf88fc836 100644 --- a/service/ccm/credential_default.go +++ b/service/ccm/credential_default.go @@ -421,6 +421,12 @@ func (c *defaultCredential) planWeight() float64 { return ccmPlanWeight(c.state.accountType, c.state.rateLimitTier) } +func (c *defaultCredential) fiveHourResetTime() time.Time { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.fiveHourReset +} + func (c *defaultCredential) weeklyResetTime() time.Time { c.stateAccess.RLock() defer c.stateAccess.RUnlock() diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index 3d39a88ca..cdbbc4844 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -27,10 +27,7 @@ import ( "github.com/hashicorp/yamux" ) -const ( - reverseProxyBaseURL = "http://reverse-proxy" - statusStreamHeader = "X-CCM-Status-Stream" -) +const reverseProxyBaseURL = "http://reverse-proxy" type externalCredential struct { tag string @@ -70,7 +67,6 @@ type externalCredential struct { type statusStreamResult struct { duration time.Duration frames int - oneShot bool } func externalCredentialURLPort(parsedURL *url.URL) uint16 { @@ -325,6 +321,12 @@ func (c *externalCredential) planWeight() float64 { return 10 } +func (c *externalCredential) fiveHourResetTime() time.Time { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.fiveHourReset +} + func (c *externalCredential) weeklyResetTime() time.Time { c.stateAccess.RLock() defer c.stateAccess.RUnlock() @@ -592,7 +594,9 @@ func (c *externalCredential) pollUsage(ctx context.Context) { var statusResponse struct { FiveHourUtilization float64 `json:"five_hour_utilization"` + FiveHourReset int64 `json:"five_hour_reset"` WeeklyUtilization float64 `json:"weekly_utilization"` + WeeklyReset int64 `json:"weekly_reset"` PlanWeight float64 `json:"plan_weight"` } err = json.NewDecoder(response.Body).Decode(&statusResponse) @@ -612,6 +616,12 @@ func (c *externalCredential) pollUsage(ctx context.Context) { if statusResponse.PlanWeight > 0 { c.state.remotePlanWeight = statusResponse.PlanWeight } + if statusResponse.FiveHourReset > 0 { + c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0) + } + if statusResponse.WeeklyReset > 0 { + c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0) + } if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { c.state.hardRateLimited = false } @@ -645,13 +655,8 @@ func (c *externalCredential) statusStreamLoop() { return } var backoff time.Duration - var oneShot bool - consecutiveFailures, backoff, oneShot = c.nextStatusStreamBackoff(result, consecutiveFailures) - if oneShot { - c.logger.Debug("status stream for ", c.tag, " returned a single-frame response, retrying in ", backoff) - } else { - c.logger.Debug("status stream for ", c.tag, " disconnected: ", err, ", reconnecting in ", backoff) - } + consecutiveFailures, backoff = c.nextStatusStreamBackoff(result, consecutiveFailures) + c.logger.Debug("status stream for ", c.tag, " disconnected: ", err, ", reconnecting in ", backoff) timer := time.NewTimer(backoff) select { case <-timer.C: @@ -679,18 +684,11 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr } decoder := json.NewDecoder(response.Body) - isStatusStream := response.Header.Get(statusStreamHeader) == "true" - previousLastUpdated := c.lastUpdatedTime() - var firstFrameUpdatedAt time.Time for { var statusResponse statusPayload err = decoder.Decode(&statusResponse) if err != nil { result.duration = time.Since(startTime) - if result.frames == 1 && err == io.EOF && !isStatusStream { - result.oneShot = true - c.restoreLastUpdatedIfUnchanged(firstFrameUpdatedAt, previousLastUpdated) - } return result, err } @@ -701,6 +699,12 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr if statusResponse.PlanWeight > 0 { c.state.remotePlanWeight = statusResponse.PlanWeight } + if statusResponse.FiveHourReset > 0 { + c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0) + } + if statusResponse.WeeklyReset > 0 { + c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0) + } if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { c.state.hardRateLimited = false } @@ -710,23 +714,17 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr c.interruptConnections() } result.frames++ - updatedAt := c.markUsageStreamUpdated() - if result.frames == 1 { - firstFrameUpdatedAt = updatedAt - } + c.markUsageStreamUpdated() c.emitStatusUpdate() } } -func (c *externalCredential) nextStatusStreamBackoff(result statusStreamResult, consecutiveFailures int) (int, time.Duration, bool) { - if result.oneShot { - return 0, c.pollInterval, true - } +func (c *externalCredential) nextStatusStreamBackoff(result statusStreamResult, consecutiveFailures int) (int, time.Duration) { if result.duration >= connectorBackoffResetThreshold { consecutiveFailures = 0 } consecutiveFailures++ - return consecutiveFailures, connectorBackoff(consecutiveFailures), false + return consecutiveFailures, connectorBackoff(consecutiveFailures) } func (c *externalCredential) doStreamStatusRequest(ctx context.Context) (*http.Response, error) { @@ -767,23 +765,10 @@ func (c *externalCredential) lastUpdatedTime() time.Time { return c.state.lastUpdated } -func (c *externalCredential) markUsageStreamUpdated() time.Time { +func (c *externalCredential) markUsageStreamUpdated() { c.stateAccess.Lock() defer c.stateAccess.Unlock() - now := time.Now() - c.state.lastUpdated = now - return now -} - -func (c *externalCredential) restoreLastUpdatedIfUnchanged(expectedCurrent time.Time, previous time.Time) { - if expectedCurrent.IsZero() { - return - } - c.stateAccess.Lock() - defer c.stateAccess.Unlock() - if c.state.lastUpdated.Equal(expectedCurrent) { - c.state.lastUpdated = previous - } + c.state.lastUpdated = time.Now() } func (c *externalCredential) markUsagePollAttempted() { diff --git a/service/ccm/credential_status_test.go b/service/ccm/credential_status_test.go index f92b27e85..9353f1d83 100644 --- a/service/ccm/credential_status_test.go +++ b/service/ccm/credential_status_test.go @@ -84,7 +84,7 @@ func newTestYamuxSessionPair(t *testing.T) (*yamux.Session, *yamux.Session) { return clientSession, serverSession } -func TestExternalCredentialConnectStatusStreamOneShotRestoresLastUpdated(t *testing.T) { +func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *testing.T) { credential, subscription := newTestCCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", nil) oldTime := time.Unix(123, 0) credential.stateAccess.Lock() @@ -95,50 +95,6 @@ func TestExternalCredentialConnectStatusStreamOneShotRestoresLastUpdated(t *test if err != io.EOF { t.Fatalf("expected EOF, got %v", err) } - if !result.oneShot { - t.Fatal("expected one-shot result") - } - if result.frames != 1 { - t.Fatalf("expected 1 frame, got %d", result.frames) - } - if !credential.lastUpdatedTime().Equal(oldTime) { - t.Fatalf("expected lastUpdated restored to %v, got %v", oldTime, credential.lastUpdatedTime()) - } - if credential.fiveHourUtilization() != 12 || credential.weeklyUtilization() != 34 { - t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization()) - } - if count := drainStatusEvents(subscription); count != 1 { - t.Fatalf("expected 1 status event, got %d", count) - } - - failures, backoff, oneShot := credential.nextStatusStreamBackoff(result, 3) - if !oneShot { - t.Fatal("expected one-shot backoff branch") - } - if failures != 0 { - t.Fatalf("expected failures reset, got %d", failures) - } - if backoff != credential.pollInterval { - t.Fatalf("expected poll interval backoff %v, got %v", credential.pollInterval, backoff) - } -} - -func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *testing.T) { - headers := make(http.Header) - headers.Set(statusStreamHeader, "true") - credential, subscription := newTestCCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", headers) - oldTime := time.Unix(123, 0) - credential.stateAccess.Lock() - credential.state.lastUpdated = oldTime - credential.stateAccess.Unlock() - - result, err := credential.connectStatusStream(context.Background()) - if err != io.EOF { - t.Fatalf("expected EOF, got %v", err) - } - if result.oneShot { - t.Fatal("did not expect one-shot result") - } if result.frames != 1 { t.Fatalf("expected 1 frame, got %d", result.frames) } @@ -152,10 +108,7 @@ func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *tes t.Fatalf("expected 1 status event, got %d", count) } - failures, backoff, oneShot := credential.nextStatusStreamBackoff(result, 3) - if oneShot { - t.Fatal("did not expect one-shot backoff branch") - } + failures, backoff := credential.nextStatusStreamBackoff(result, 3) if failures != 4 { t.Fatalf("expected failures incremented to 4, got %d", failures) } @@ -178,9 +131,6 @@ func TestExternalCredentialConnectStatusStreamMultiFrameKeepsLastUpdated(t *test if err != io.EOF { t.Fatalf("expected EOF, got %v", err) } - if result.oneShot { - t.Fatal("did not expect one-shot result") - } if result.frames != 2 { t.Fatalf("expected 2 frames, got %d", result.frames) } diff --git a/service/ccm/service_handler.go b/service/ccm/service_handler.go index 33d6317de..1ccbd83ff 100644 --- a/service/ccm/service_handler.go +++ b/service/ccm/service_handler.go @@ -311,10 +311,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - // Rewrite response headers for external users - if userConfig != nil && userConfig.ExternalCredential != "" { - s.rewriteResponseHeadersForExternalUser(response.Header, provider, userConfig) - } + s.rewriteResponseHeaders(response.Header, provider, userConfig) for key, values := range response.Header { if !isHopByHopHeader(key) && !isReverseProxyHeader(key) { diff --git a/service/ccm/service_status.go b/service/ccm/service_status.go index b5771d4b8..e3aa43e30 100644 --- a/service/ccm/service_status.go +++ b/service/ccm/service_status.go @@ -6,16 +6,52 @@ import ( "net/http" "strconv" "strings" + "time" "github.com/sagernet/sing-box/option" ) type statusPayload struct { FiveHourUtilization float64 `json:"five_hour_utilization"` + FiveHourReset int64 `json:"five_hour_reset"` WeeklyUtilization float64 `json:"weekly_utilization"` + WeeklyReset int64 `json:"weekly_reset"` PlanWeight float64 `json:"plan_weight"` } +type aggregatedStatus struct { + fiveHourUtilization float64 + weeklyUtilization float64 + totalWeight float64 + fiveHourReset time.Time + weeklyReset time.Time +} + +func resetToEpoch(t time.Time) int64 { + if t.IsZero() { + return 0 + } + return t.Unix() +} + +func (s aggregatedStatus) equal(other aggregatedStatus) bool { + return s.fiveHourUtilization == other.fiveHourUtilization && + s.weeklyUtilization == other.weeklyUtilization && + s.totalWeight == other.totalWeight && + resetToEpoch(s.fiveHourReset) == resetToEpoch(other.fiveHourReset) && + resetToEpoch(s.weeklyReset) == resetToEpoch(other.weeklyReset) +} + +func (s aggregatedStatus) toPayload() statusPayload { + return statusPayload{ + FiveHourUtilization: s.fiveHourUtilization, + FiveHourReset: resetToEpoch(s.fiveHourReset), + WeeklyUtilization: s.weeklyUtilization, + WeeklyReset: resetToEpoch(s.weeklyReset), + PlanWeight: s.totalWeight, + } +} + func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed") @@ -68,15 +104,11 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { } provider.pollIfStale(r.Context()) - avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) + status := s.computeAggregatedUtilization(provider, userConfig) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(statusPayload{ - FiveHourUtilization: avgFiveHour, - WeeklyUtilization: avgWeekly, - PlanWeight: totalWeight, - }) + json.NewEncoder(w).Encode(status.toPayload()) } func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, provider credentialProvider, userConfig *option.CCMUser) { @@ -96,16 +128,11 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro provider.pollIfStale(r.Context()) w.Header().Set("Content-Type", "application/json") - w.Header().Set(statusStreamHeader, "true") w.WriteHeader(http.StatusOK) - lastFiveHour, lastWeekly, lastWeight := s.computeAggregatedUtilization(provider, userConfig) + last := s.computeAggregatedUtilization(provider, userConfig) buf := &bytes.Buffer{} - json.NewEncoder(buf).Encode(statusPayload{ - FiveHourUtilization: lastFiveHour, - WeeklyUtilization: lastWeekly, - PlanWeight: lastWeight, - }) + json.NewEncoder(buf).Encode(last.toPayload()) _, writeErr := w.Write(buf.Bytes()) if writeErr != nil { return @@ -127,19 +154,13 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro } } drained: - fiveHour, weekly, weight := s.computeAggregatedUtilization(provider, userConfig) - if fiveHour == lastFiveHour && weekly == lastWeekly && weight == lastWeight { + current := s.computeAggregatedUtilization(provider, userConfig) + if current.equal(last) { continue } - lastFiveHour = fiveHour - lastWeekly = weekly - lastWeight = weight + last = current buf.Reset() - json.NewEncoder(buf).Encode(statusPayload{ - FiveHourUtilization: fiveHour, - WeeklyUtilization: weekly, - PlanWeight: weight, - }) + json.NewEncoder(buf).Encode(current.toPayload()) _, writeErr = w.Write(buf.Bytes()) if writeErr != nil { return @@ -149,8 +170,11 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro } } -func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) (float64, float64, float64) { +func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) aggregatedStatus { var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64 + now := time.Now() + var totalWeightedHoursUntil5hReset, total5hResetWeight float64 + var totalWeightedHoursUntilWeeklyReset, totalWeeklyResetWeight float64 for _, credential := range provider.allCredentials() { if !credential.isAvailable() { continue @@ -173,21 +197,59 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user totalWeightedRemaining5h += remaining5h * weight totalWeightedRemainingWeekly += remainingWeekly * weight totalWeight += weight + + fiveHourReset := credential.fiveHourResetTime() + if !fiveHourReset.IsZero() { + hours := fiveHourReset.Sub(now).Hours() + if hours < 0 { + hours = 0 + } + totalWeightedHoursUntil5hReset += hours * weight + total5hResetWeight += weight + } + weeklyReset := credential.weeklyResetTime() + if !weeklyReset.IsZero() { + hours := weeklyReset.Sub(now).Hours() + if hours < 0 { + hours = 0 + } + totalWeightedHoursUntilWeeklyReset += hours * weight + totalWeeklyResetWeight += weight + } } if totalWeight == 0 { - return 100, 100, 0 + return aggregatedStatus{ + fiveHourUtilization: 100, + weeklyUtilization: 100, + } } - return 100 - totalWeightedRemaining5h/totalWeight, - 100 - totalWeightedRemainingWeekly/totalWeight, - totalWeight + result := aggregatedStatus{ + fiveHourUtilization: 100 - totalWeightedRemaining5h/totalWeight, + weeklyUtilization: 100 - totalWeightedRemainingWeekly/totalWeight, + totalWeight: totalWeight, + } + if total5hResetWeight > 0 { + avgHours := totalWeightedHoursUntil5hReset / total5hResetWeight + result.fiveHourReset = now.Add(time.Duration(avgHours * float64(time.Hour))) + } + if totalWeeklyResetWeight > 0 { + avgHours := totalWeightedHoursUntilWeeklyReset / totalWeeklyResetWeight + result.weeklyReset = now.Add(time.Duration(avgHours * float64(time.Hour))) + } + return result } -func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, provider credentialProvider, userConfig *option.CCMUser) { - avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) - - headers.Set("anthropic-ratelimit-unified-5h-utilization", strconv.FormatFloat(avgFiveHour/100, 'f', 6, 64)) - headers.Set("anthropic-ratelimit-unified-7d-utilization", strconv.FormatFloat(avgWeekly/100, 'f', 6, 64)) - if totalWeight > 0 { - headers.Set("X-CCM-Plan-Weight", strconv.FormatFloat(totalWeight, 'f', -1, 64)) +func (s *Service) rewriteResponseHeaders(headers http.Header, provider credentialProvider, userConfig *option.CCMUser) { + status := s.computeAggregatedUtilization(provider, userConfig) + headers.Set("anthropic-ratelimit-unified-5h-utilization", strconv.FormatFloat(status.fiveHourUtilization/100, 'f', 6, 64)) + headers.Set("anthropic-ratelimit-unified-7d-utilization", strconv.FormatFloat(status.weeklyUtilization/100, 'f', 6, 64)) + if !status.fiveHourReset.IsZero() { + headers.Set("anthropic-ratelimit-unified-5h-reset", strconv.FormatInt(status.fiveHourReset.Unix(), 10)) + } + if !status.weeklyReset.IsZero() { + headers.Set("anthropic-ratelimit-unified-7d-reset", strconv.FormatInt(status.weeklyReset.Unix(), 10)) + } + if status.totalWeight > 0 { + headers.Set("X-CCM-Plan-Weight", strconv.FormatFloat(status.totalWeight, 'f', -1, 64)) } } diff --git a/service/ocm/credential.go b/service/ocm/credential.go index c777ea5c9..1478f5f19 100644 --- a/service/ocm/credential.go +++ b/service/ocm/credential.go @@ -107,6 +107,7 @@ type Credential interface { weeklyCap() float64 planWeight() float64 weeklyResetTime() time.Time + fiveHourResetTime() time.Time markRateLimited(resetAt time.Time) earliestReset() time.Time unavailableError() error diff --git a/service/ocm/credential_default.go b/service/ocm/credential_default.go index 89daf56e7..1e0af9847 100644 --- a/service/ocm/credential_default.go +++ b/service/ocm/credential_default.go @@ -476,6 +476,12 @@ func (c *defaultCredential) weeklyResetTime() time.Time { return c.state.weeklyReset } +func (c *defaultCredential) fiveHourResetTime() time.Time { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.fiveHourReset +} + func (c *defaultCredential) isAvailable() bool { c.retryCredentialReloadIfNeeded() diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index dd13aca60..b796ff0bb 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -28,10 +28,7 @@ import ( "github.com/hashicorp/yamux" ) -const ( - reverseProxyBaseURL = "http://reverse-proxy" - statusStreamHeader = "X-OCM-Status-Stream" -) +const reverseProxyBaseURL = "http://reverse-proxy" type externalCredential struct { tag string @@ -77,7 +74,6 @@ type reverseSessionDialer struct { type statusStreamResult struct { duration time.Duration frames int - oneShot bool } func (d reverseSessionDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { @@ -355,6 +351,12 @@ func (c *externalCredential) weeklyResetTime() time.Time { return c.state.weeklyReset } +func (c *externalCredential) fiveHourResetTime() time.Time { + c.stateAccess.RLock() + defer c.stateAccess.RUnlock() + return c.state.fiveHourReset +} + func (c *externalCredential) markRateLimited(resetAt time.Time) { c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt))) c.stateAccess.Lock() @@ -634,7 +636,9 @@ func (c *externalCredential) pollUsage(ctx context.Context) { var statusResponse struct { FiveHourUtilization float64 `json:"five_hour_utilization"` + FiveHourReset int64 `json:"five_hour_reset"` WeeklyUtilization float64 `json:"weekly_utilization"` + WeeklyReset int64 `json:"weekly_reset"` PlanWeight float64 `json:"plan_weight"` } err = json.NewDecoder(response.Body).Decode(&statusResponse) @@ -651,6 +655,12 @@ func (c *externalCredential) pollUsage(ctx context.Context) { c.state.consecutivePollFailures = 0 c.state.fiveHourUtilization = statusResponse.FiveHourUtilization c.state.weeklyUtilization = statusResponse.WeeklyUtilization + if statusResponse.FiveHourReset > 0 { + c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0) + } + if statusResponse.WeeklyReset > 0 { + c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0) + } if statusResponse.PlanWeight > 0 { c.state.remotePlanWeight = statusResponse.PlanWeight } @@ -687,13 +697,8 @@ func (c *externalCredential) statusStreamLoop() { return } var backoff time.Duration - var oneShot bool - consecutiveFailures, backoff, oneShot = c.nextStatusStreamBackoff(result, consecutiveFailures) - if oneShot { - c.logger.Debug("status stream for ", c.tag, " returned a single-frame response, retrying in ", backoff) - } else { - c.logger.Debug("status stream for ", c.tag, " disconnected: ", err, ", reconnecting in ", backoff) - } + consecutiveFailures, backoff = c.nextStatusStreamBackoff(result, consecutiveFailures) + c.logger.Debug("status stream for ", c.tag, " disconnected: ", err, ", reconnecting in ", backoff) timer := time.NewTimer(backoff) select { case <-timer.C: @@ -721,18 +726,11 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr } decoder := json.NewDecoder(response.Body) - isStatusStream := response.Header.Get(statusStreamHeader) == "true" - previousLastUpdated := c.lastUpdatedTime() - var firstFrameUpdatedAt time.Time for { var statusResponse statusPayload err = decoder.Decode(&statusResponse) if err != nil { result.duration = time.Since(startTime) - if result.frames == 1 && err == io.EOF && !isStatusStream { - result.oneShot = true - c.restoreLastUpdatedIfUnchanged(firstFrameUpdatedAt, previousLastUpdated) - } return result, err } @@ -740,6 +738,12 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr c.state.consecutivePollFailures = 0 c.state.fiveHourUtilization = statusResponse.FiveHourUtilization c.state.weeklyUtilization = statusResponse.WeeklyUtilization + if statusResponse.FiveHourReset > 0 { + c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0) + } + if statusResponse.WeeklyReset > 0 { + c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0) + } if statusResponse.PlanWeight > 0 { c.state.remotePlanWeight = statusResponse.PlanWeight } @@ -752,23 +756,17 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr c.interruptConnections() } result.frames++ - updatedAt := c.markUsageStreamUpdated() - if result.frames == 1 { - firstFrameUpdatedAt = updatedAt - } + c.markUsageStreamUpdated() c.emitStatusUpdate() } } -func (c *externalCredential) nextStatusStreamBackoff(result statusStreamResult, consecutiveFailures int) (int, time.Duration, bool) { - if result.oneShot { - return 0, c.pollInterval, true - } +func (c *externalCredential) nextStatusStreamBackoff(result statusStreamResult, consecutiveFailures int) (int, time.Duration) { if result.duration >= connectorBackoffResetThreshold { consecutiveFailures = 0 } consecutiveFailures++ - return consecutiveFailures, connectorBackoff(consecutiveFailures), false + return consecutiveFailures, connectorBackoff(consecutiveFailures) } func (c *externalCredential) doStreamStatusRequest(ctx context.Context) (*http.Response, error) { @@ -809,23 +807,10 @@ func (c *externalCredential) lastUpdatedTime() time.Time { return c.state.lastUpdated } -func (c *externalCredential) markUsageStreamUpdated() time.Time { +func (c *externalCredential) markUsageStreamUpdated() { c.stateAccess.Lock() defer c.stateAccess.Unlock() - now := time.Now() - c.state.lastUpdated = now - return now -} - -func (c *externalCredential) restoreLastUpdatedIfUnchanged(expectedCurrent time.Time, previous time.Time) { - if expectedCurrent.IsZero() { - return - } - c.stateAccess.Lock() - defer c.stateAccess.Unlock() - if c.state.lastUpdated.Equal(expectedCurrent) { - c.state.lastUpdated = previous - } + c.state.lastUpdated = time.Now() } func (c *externalCredential) markUsagePollAttempted() { diff --git a/service/ocm/credential_status_test.go b/service/ocm/credential_status_test.go index 955338fce..2865a2380 100644 --- a/service/ocm/credential_status_test.go +++ b/service/ocm/credential_status_test.go @@ -84,7 +84,7 @@ func newTestYamuxSessionPair(t *testing.T) (*yamux.Session, *yamux.Session) { return clientSession, serverSession } -func TestExternalCredentialConnectStatusStreamOneShotRestoresLastUpdated(t *testing.T) { +func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *testing.T) { credential, subscription := newTestOCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", nil) oldTime := time.Unix(123, 0) credential.stateAccess.Lock() @@ -95,50 +95,6 @@ func TestExternalCredentialConnectStatusStreamOneShotRestoresLastUpdated(t *test if err != io.EOF { t.Fatalf("expected EOF, got %v", err) } - if !result.oneShot { - t.Fatal("expected one-shot result") - } - if result.frames != 1 { - t.Fatalf("expected 1 frame, got %d", result.frames) - } - if !credential.lastUpdatedTime().Equal(oldTime) { - t.Fatalf("expected lastUpdated restored to %v, got %v", oldTime, credential.lastUpdatedTime()) - } - if credential.fiveHourUtilization() != 12 || credential.weeklyUtilization() != 34 { - t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization()) - } - if count := drainStatusEvents(subscription); count != 1 { - t.Fatalf("expected 1 status event, got %d", count) - } - - failures, backoff, oneShot := credential.nextStatusStreamBackoff(result, 3) - if !oneShot { - t.Fatal("expected one-shot backoff branch") - } - if failures != 0 { - t.Fatalf("expected failures reset, got %d", failures) - } - if backoff != credential.pollInterval { - t.Fatalf("expected poll interval backoff %v, got %v", credential.pollInterval, backoff) - } -} - -func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *testing.T) { - headers := make(http.Header) - headers.Set(statusStreamHeader, "true") - credential, subscription := newTestOCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", headers) - oldTime := time.Unix(123, 0) - credential.stateAccess.Lock() - credential.state.lastUpdated = oldTime - credential.stateAccess.Unlock() - - result, err := credential.connectStatusStream(context.Background()) - if err != io.EOF { - t.Fatalf("expected EOF, got %v", err) - } - if result.oneShot { - t.Fatal("did not expect one-shot result") - } if result.frames != 1 { t.Fatalf("expected 1 frame, got %d", result.frames) } @@ -152,10 +108,7 @@ func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *tes t.Fatalf("expected 1 status event, got %d", count) } - failures, backoff, oneShot := credential.nextStatusStreamBackoff(result, 3) - if oneShot { - t.Fatal("did not expect one-shot backoff branch") - } + failures, backoff := credential.nextStatusStreamBackoff(result, 3) if failures != 4 { t.Fatalf("expected failures incremented to 4, got %d", failures) } @@ -178,9 +131,6 @@ func TestExternalCredentialConnectStatusStreamMultiFrameKeepsLastUpdated(t *test if err != io.EOF { t.Fatalf("expected EOF, got %v", err) } - if result.oneShot { - t.Fatal("did not expect one-shot result") - } if result.frames != 2 { t.Fatalf("expected 2 frames, got %d", result.frames) } diff --git a/service/ocm/service_handler.go b/service/ocm/service_handler.go index 0a7698d00..52e35f39b 100644 --- a/service/ocm/service_handler.go +++ b/service/ocm/service_handler.go @@ -291,10 +291,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - // Rewrite response headers for external users - if userConfig != nil && userConfig.ExternalCredential != "" { - s.rewriteResponseHeadersForExternalUser(response.Header, provider, userConfig) - } + s.rewriteResponseHeaders(response.Header, provider, userConfig) for key, values := range response.Header { if !isHopByHopHeader(key) && !isReverseProxyHeader(key) { diff --git a/service/ocm/service_status.go b/service/ocm/service_status.go index bc167a65c..f66b469c2 100644 --- a/service/ocm/service_status.go +++ b/service/ocm/service_status.go @@ -6,16 +6,52 @@ import ( "net/http" "strconv" "strings" + "time" "github.com/sagernet/sing-box/option" ) type statusPayload struct { FiveHourUtilization float64 `json:"five_hour_utilization"` + FiveHourReset int64 `json:"five_hour_reset"` WeeklyUtilization float64 `json:"weekly_utilization"` + WeeklyReset int64 `json:"weekly_reset"` PlanWeight float64 `json:"plan_weight"` } +type aggregatedStatus struct { + fiveHourUtilization float64 + weeklyUtilization float64 + totalWeight float64 + fiveHourReset time.Time + weeklyReset time.Time +} + +func resetToEpoch(t time.Time) int64 { + if t.IsZero() { + return 0 + } + return t.Unix() +} + +func (s aggregatedStatus) equal(other aggregatedStatus) bool { + return s.fiveHourUtilization == other.fiveHourUtilization && + s.weeklyUtilization == other.weeklyUtilization && + s.totalWeight == other.totalWeight && + resetToEpoch(s.fiveHourReset) == resetToEpoch(other.fiveHourReset) && + resetToEpoch(s.weeklyReset) == resetToEpoch(other.weeklyReset) +} + +func (s aggregatedStatus) toPayload() statusPayload { + return statusPayload{ + FiveHourUtilization: s.fiveHourUtilization, + FiveHourReset: resetToEpoch(s.fiveHourReset), + WeeklyUtilization: s.weeklyUtilization, + WeeklyReset: resetToEpoch(s.weeklyReset), + PlanWeight: s.totalWeight, + } +} + func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed") @@ -68,15 +104,11 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { } provider.pollIfStale(r.Context()) - avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) + status := s.computeAggregatedUtilization(provider, userConfig) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(statusPayload{ - FiveHourUtilization: avgFiveHour, - WeeklyUtilization: avgWeekly, - PlanWeight: totalWeight, - }) + json.NewEncoder(w).Encode(status.toPayload()) } func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, provider credentialProvider, userConfig *option.OCMUser) { @@ -96,16 +128,11 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro provider.pollIfStale(r.Context()) w.Header().Set("Content-Type", "application/json") - w.Header().Set(statusStreamHeader, "true") w.WriteHeader(http.StatusOK) - lastFiveHour, lastWeekly, lastWeight := s.computeAggregatedUtilization(provider, userConfig) + last := s.computeAggregatedUtilization(provider, userConfig) buf := &bytes.Buffer{} - json.NewEncoder(buf).Encode(statusPayload{ - FiveHourUtilization: lastFiveHour, - WeeklyUtilization: lastWeekly, - PlanWeight: lastWeight, - }) + json.NewEncoder(buf).Encode(last.toPayload()) _, writeErr := w.Write(buf.Bytes()) if writeErr != nil { return @@ -127,19 +154,13 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro } } drained: - fiveHour, weekly, weight := s.computeAggregatedUtilization(provider, userConfig) - if fiveHour == lastFiveHour && weekly == lastWeekly && weight == lastWeight { + current := s.computeAggregatedUtilization(provider, userConfig) + if current.equal(last) { continue } - lastFiveHour = fiveHour - lastWeekly = weekly - lastWeight = weight + last = current buf.Reset() - json.NewEncoder(buf).Encode(statusPayload{ - FiveHourUtilization: fiveHour, - WeeklyUtilization: weekly, - PlanWeight: weight, - }) + json.NewEncoder(buf).Encode(current.toPayload()) _, writeErr = w.Write(buf.Bytes()) if writeErr != nil { return @@ -149,8 +170,11 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro } } -func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) (float64, float64, float64) { +func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) aggregatedStatus { var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64 + now := time.Now() + var totalWeightedHoursUntil5hReset, total5hResetWeight float64 + var totalWeightedHoursUntilWeeklyReset, totalWeeklyResetWeight float64 for _, credential := range provider.allCredentials() { if !credential.isAvailable() { continue @@ -173,26 +197,63 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user totalWeightedRemaining5h += remaining5h * weight totalWeightedRemainingWeekly += remainingWeekly * weight totalWeight += weight + + fiveHourReset := credential.fiveHourResetTime() + if !fiveHourReset.IsZero() { + hours := fiveHourReset.Sub(now).Hours() + if hours < 0 { + hours = 0 + } + totalWeightedHoursUntil5hReset += hours * weight + total5hResetWeight += weight + } + weeklyReset := credential.weeklyResetTime() + if !weeklyReset.IsZero() { + hours := weeklyReset.Sub(now).Hours() + if hours < 0 { + hours = 0 + } + totalWeightedHoursUntilWeeklyReset += hours * weight + totalWeeklyResetWeight += weight + } } if totalWeight == 0 { - return 100, 100, 0 + return aggregatedStatus{ + fiveHourUtilization: 100, + weeklyUtilization: 100, + } } - return 100 - totalWeightedRemaining5h/totalWeight, - 100 - totalWeightedRemainingWeekly/totalWeight, - totalWeight + result := aggregatedStatus{ + fiveHourUtilization: 100 - totalWeightedRemaining5h/totalWeight, + weeklyUtilization: 100 - totalWeightedRemainingWeekly/totalWeight, + totalWeight: totalWeight, + } + if total5hResetWeight > 0 { + avgHours := totalWeightedHoursUntil5hReset / total5hResetWeight + result.fiveHourReset = now.Add(time.Duration(avgHours * float64(time.Hour))) + } + if totalWeeklyResetWeight > 0 { + avgHours := totalWeightedHoursUntilWeeklyReset / totalWeeklyResetWeight + result.weeklyReset = now.Add(time.Duration(avgHours * float64(time.Hour))) + } + return result } -func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, provider credentialProvider, userConfig *option.OCMUser) { - avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) - +func (s *Service) rewriteResponseHeaders(headers http.Header, provider credentialProvider, userConfig *option.OCMUser) { + status := s.computeAggregatedUtilization(provider, userConfig) activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) if activeLimitIdentifier == "" { activeLimitIdentifier = "codex" } - - headers.Set("x-"+activeLimitIdentifier+"-primary-used-percent", strconv.FormatFloat(avgFiveHour, 'f', 2, 64)) - headers.Set("x-"+activeLimitIdentifier+"-secondary-used-percent", strconv.FormatFloat(avgWeekly, 'f', 2, 64)) - if totalWeight > 0 { - headers.Set("X-OCM-Plan-Weight", strconv.FormatFloat(totalWeight, 'f', -1, 64)) + headers.Set("x-"+activeLimitIdentifier+"-primary-used-percent", strconv.FormatFloat(status.fiveHourUtilization, 'f', 2, 64)) + headers.Set("x-"+activeLimitIdentifier+"-secondary-used-percent", strconv.FormatFloat(status.weeklyUtilization, 'f', 2, 64)) + if !status.fiveHourReset.IsZero() { + headers.Set("x-"+activeLimitIdentifier+"-primary-reset-at", strconv.FormatInt(status.fiveHourReset.Unix(), 10)) + } + if !status.weeklyReset.IsZero() { + headers.Set("x-"+activeLimitIdentifier+"-secondary-reset-at", strconv.FormatInt(status.weeklyReset.Unix(), 10)) + } + if status.totalWeight > 0 { + headers.Set("X-OCM-Plan-Weight", strconv.FormatFloat(status.totalWeight, 'f', -1, 64)) } } diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index 066692ace..bb9640d54 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -252,9 +252,7 @@ func (s *Service) handleWebSocket( clientResponseHeaders[key] = append([]string(nil), values...) } } - if userConfig != nil && userConfig.ExternalCredential != "" { - s.rewriteResponseHeadersForExternalUser(clientResponseHeaders, provider, userConfig) - } + s.rewriteResponseHeaders(clientResponseHeaders, provider, userConfig) clientUpgrader := ws.HTTPUpgrader{ Header: clientResponseHeaders, @@ -292,10 +290,16 @@ func (s *Service) handleWebSocket( upstreamReadWriter = upstreamConn } + rateLimitIdentifier := normalizeRateLimitIdentifier(upstreamResponseHeaders.Get("x-codex-active-limit")) + if rateLimitIdentifier == "" { + rateLimitIdentifier = "codex" + } + + var clientWriteAccess sync.Mutex modelChannel := make(chan string, 1) var waitGroup sync.WaitGroup - waitGroup.Add(2) + waitGroup.Add(3) go func() { defer waitGroup.Done() defer session.Close() @@ -304,7 +308,12 @@ func (s *Service) handleWebSocket( go func() { defer waitGroup.Done() defer session.Close() - s.proxyWebSocketUpstreamToClient(ctx, upstreamReadWriter, clientConn, selectedCredential, userConfig, provider, modelChannel, username, weeklyCycleHint) + s.proxyWebSocketUpstreamToClient(ctx, upstreamReadWriter, clientConn, &clientWriteAccess, selectedCredential, userConfig, provider, modelChannel, username, weeklyCycleHint) + }() + go func() { + defer waitGroup.Done() + defer session.Close() + s.pushWebSocketAggregatedStatus(ctx, clientConn, &clientWriteAccess, provider, userConfig, rateLimitIdentifier) }() waitGroup.Wait() } @@ -363,7 +372,7 @@ func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn } } -func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamReadWriter io.ReadWriter, clientConn net.Conn, selectedCredential Credential, userConfig *option.OCMUser, provider credentialProvider, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) { +func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamReadWriter io.ReadWriter, clientConn net.Conn, clientWriteAccess *sync.Mutex, selectedCredential Credential, userConfig *option.OCMUser, provider credentialProvider, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) { usageTracker := selectedCredential.usageTrackerOrNil() var requestModel string for { @@ -384,11 +393,9 @@ func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamRe switch event.Type { case "codex.rate_limits": s.handleWebSocketRateLimitsEvent(data, selectedCredential) - if userConfig != nil && userConfig.ExternalCredential != "" { - rewritten, rewriteErr := s.rewriteWebSocketRateLimitsForExternalUser(data, provider, userConfig) - if rewriteErr == nil { - data = rewritten - } + rewritten, rewriteErr := s.rewriteWebSocketRateLimits(data, provider, userConfig) + if rewriteErr == nil { + data = rewritten } case "error": if event.StatusCode == http.StatusTooManyRequests { @@ -407,7 +414,9 @@ func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamRe } } + clientWriteAccess.Lock() err = wsutil.WriteServerMessage(clientConn, opCode, data) + clientWriteAccess.Unlock() if err != nil { if !E.IsClosedOrCanceled(err) { s.logger.DebugContext(ctx, "write client websocket: ", err) @@ -483,7 +492,7 @@ func (s *Service) handleWebSocketErrorRateLimited(data []byte, selectedCredentia selectedCredential.markRateLimited(resetAt) } -func (s *Service) rewriteWebSocketRateLimitsForExternalUser(data []byte, provider credentialProvider, userConfig *option.OCMUser) ([]byte, error) { +func (s *Service) rewriteWebSocketRateLimits(data []byte, provider credentialProvider, userConfig *option.OCMUser) ([]byte, error) { var event map[string]json.RawMessage err := json.Unmarshal(data, &event) if err != nil { @@ -501,13 +510,13 @@ func (s *Service) rewriteWebSocketRateLimitsForExternalUser(data []byte, provide return nil, err } - averageFiveHour, averageWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) + status := s.computeAggregatedUtilization(provider, userConfig) - if totalWeight > 0 { - event["plan_weight"], _ = json.Marshal(totalWeight) + if status.totalWeight > 0 { + event["plan_weight"], _ = json.Marshal(status.totalWeight) } - primaryData, err := rewriteWebSocketRateLimitWindow(rateLimits["primary"], averageFiveHour) + primaryData, err := rewriteWebSocketRateLimitWindow(rateLimits["primary"], status.fiveHourUtilization, resetToEpoch(status.fiveHourReset)) if err != nil { return nil, err } @@ -515,7 +524,7 @@ func (s *Service) rewriteWebSocketRateLimitsForExternalUser(data []byte, provide rateLimits["primary"] = primaryData } - secondaryData, err := rewriteWebSocketRateLimitWindow(rateLimits["secondary"], averageWeekly) + secondaryData, err := rewriteWebSocketRateLimitWindow(rateLimits["secondary"], status.weeklyUtilization, resetToEpoch(status.weeklyReset)) if err != nil { return nil, err } @@ -531,7 +540,7 @@ func (s *Service) rewriteWebSocketRateLimitsForExternalUser(data []byte, provide return json.Marshal(event) } -func rewriteWebSocketRateLimitWindow(data json.RawMessage, usedPercent float64) (json.RawMessage, error) { +func rewriteWebSocketRateLimitWindow(data json.RawMessage, usedPercent float64, resetAt int64) (json.RawMessage, error) { if len(data) == 0 || string(data) == "null" { return nil, nil } @@ -547,9 +556,93 @@ func rewriteWebSocketRateLimitWindow(data json.RawMessage, usedPercent float64) return nil, err } + if resetAt > 0 { + window["reset_at"], err = json.Marshal(resetAt) + if err != nil { + return nil, err + } + } + return json.Marshal(window) } +func (s *Service) pushWebSocketAggregatedStatus(ctx context.Context, clientConn net.Conn, clientWriteAccess *sync.Mutex, provider credentialProvider, userConfig *option.OCMUser, rateLimitIdentifier string) { + subscription, done, err := s.statusObserver.Subscribe() + if err != nil { + return + } + defer s.statusObserver.UnSubscribe(subscription) + + last := s.computeAggregatedUtilization(provider, userConfig) + data := buildSyntheticRateLimitsEvent(rateLimitIdentifier, last) + clientWriteAccess.Lock() + err = wsutil.WriteServerMessage(clientConn, ws.OpText, data) + clientWriteAccess.Unlock() + if err != nil { + return + } + + for { + select { + case <-ctx.Done(): + return + case <-done: + return + case <-subscription: + for { + select { + case <-subscription: + default: + goto drained + } + } + drained: + current := s.computeAggregatedUtilization(provider, userConfig) + if current.equal(last) { + continue + } + last = current + data = buildSyntheticRateLimitsEvent(rateLimitIdentifier, current) + clientWriteAccess.Lock() + err = wsutil.WriteServerMessage(clientConn, ws.OpText, data) + clientWriteAccess.Unlock() + if err != nil { + return + } + } + } +} + +func buildSyntheticRateLimitsEvent(identifier string, status aggregatedStatus) []byte { + type rateLimitWindow struct { + UsedPercent float64 `json:"used_percent"` + ResetAt int64 `json:"reset_at,omitempty"` + } + event := struct { + Type string `json:"type"` + RateLimits struct { + Primary *rateLimitWindow `json:"primary,omitempty"` + Secondary *rateLimitWindow `json:"secondary,omitempty"` + } `json:"rate_limits"` + LimitName string `json:"limit_name"` + PlanWeight float64 `json:"plan_weight,omitempty"` + }{ + Type: "codex.rate_limits", + LimitName: identifier, + PlanWeight: status.totalWeight, + } + event.RateLimits.Primary = &rateLimitWindow{ + UsedPercent: status.fiveHourUtilization, + ResetAt: resetToEpoch(status.fiveHourReset), + } + event.RateLimits.Secondary = &rateLimitWindow{ + UsedPercent: status.weeklyUtilization, + ResetAt: resetToEpoch(status.weeklyReset), + } + data, _ := json.Marshal(event) + return data +} + func (s *Service) handleWebSocketResponseCompleted(data []byte, usageTracker *AggregatedUsage, requestModel string, username string, weeklyCycleHint *WeeklyCycleHint) { var streamEvent responses.ResponseStreamEventUnion if json.Unmarshal(data, &streamEvent) != nil {