mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-13 20:28:32 +10:00
ccm,ocm: propagate reset times, rewrite headers for all users, add WS status push
- Add fiveHourReset/weeklyReset to statusPayload and aggregatedStatus with weight-averaged reset time aggregation across credential pools
- Rewrite response headers (utilization + reset times) for all users, not just external credential users
- Rewrite WebSocket rate_limits events for all users with aggregated values
- Add proactive WebSocket status push: synthetic codex.rate_limits events sent on connection start and on status changes via statusObserver
- Remove one-shot stream forward compatibility (statusStreamHeader, restoreLastUpdatedIfUnchanged, oneShot detection)
This commit is contained in:
@@ -105,6 +105,7 @@ type Credential interface {
|
||||
fiveHourCap() float64
|
||||
weeklyCap() float64
|
||||
planWeight() float64
|
||||
fiveHourResetTime() time.Time
|
||||
weeklyResetTime() time.Time
|
||||
markRateLimited(resetAt time.Time)
|
||||
earliestReset() time.Time
|
||||
|
||||
@@ -421,6 +421,12 @@ func (c *defaultCredential) planWeight() float64 {
|
||||
return ccmPlanWeight(c.state.accountType, c.state.rateLimitTier)
|
||||
}
|
||||
|
||||
func (c *defaultCredential) fiveHourResetTime() time.Time {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.fiveHourReset
|
||||
}
|
||||
|
||||
func (c *defaultCredential) weeklyResetTime() time.Time {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
|
||||
@@ -27,10 +27,7 @@ import (
|
||||
"github.com/hashicorp/yamux"
|
||||
)
|
||||
|
||||
const (
|
||||
reverseProxyBaseURL = "http://reverse-proxy"
|
||||
statusStreamHeader = "X-CCM-Status-Stream"
|
||||
)
|
||||
const reverseProxyBaseURL = "http://reverse-proxy"
|
||||
|
||||
type externalCredential struct {
|
||||
tag string
|
||||
@@ -70,7 +67,6 @@ type externalCredential struct {
|
||||
type statusStreamResult struct {
|
||||
duration time.Duration
|
||||
frames int
|
||||
oneShot bool
|
||||
}
|
||||
|
||||
func externalCredentialURLPort(parsedURL *url.URL) uint16 {
|
||||
@@ -325,6 +321,12 @@ func (c *externalCredential) planWeight() float64 {
|
||||
return 10
|
||||
}
|
||||
|
||||
func (c *externalCredential) fiveHourResetTime() time.Time {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.fiveHourReset
|
||||
}
|
||||
|
||||
func (c *externalCredential) weeklyResetTime() time.Time {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
@@ -592,7 +594,9 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
|
||||
|
||||
var statusResponse struct {
|
||||
FiveHourUtilization float64 `json:"five_hour_utilization"`
|
||||
FiveHourReset int64 `json:"five_hour_reset"`
|
||||
WeeklyUtilization float64 `json:"weekly_utilization"`
|
||||
WeeklyReset int64 `json:"weekly_reset"`
|
||||
PlanWeight float64 `json:"plan_weight"`
|
||||
}
|
||||
err = json.NewDecoder(response.Body).Decode(&statusResponse)
|
||||
@@ -612,6 +616,12 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
|
||||
if statusResponse.PlanWeight > 0 {
|
||||
c.state.remotePlanWeight = statusResponse.PlanWeight
|
||||
}
|
||||
if statusResponse.FiveHourReset > 0 {
|
||||
c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0)
|
||||
}
|
||||
if statusResponse.WeeklyReset > 0 {
|
||||
c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0)
|
||||
}
|
||||
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
|
||||
c.state.hardRateLimited = false
|
||||
}
|
||||
@@ -645,13 +655,8 @@ func (c *externalCredential) statusStreamLoop() {
|
||||
return
|
||||
}
|
||||
var backoff time.Duration
|
||||
var oneShot bool
|
||||
consecutiveFailures, backoff, oneShot = c.nextStatusStreamBackoff(result, consecutiveFailures)
|
||||
if oneShot {
|
||||
c.logger.Debug("status stream for ", c.tag, " returned a single-frame response, retrying in ", backoff)
|
||||
} else {
|
||||
c.logger.Debug("status stream for ", c.tag, " disconnected: ", err, ", reconnecting in ", backoff)
|
||||
}
|
||||
consecutiveFailures, backoff = c.nextStatusStreamBackoff(result, consecutiveFailures)
|
||||
c.logger.Debug("status stream for ", c.tag, " disconnected: ", err, ", reconnecting in ", backoff)
|
||||
timer := time.NewTimer(backoff)
|
||||
select {
|
||||
case <-timer.C:
|
||||
@@ -679,18 +684,11 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(response.Body)
|
||||
isStatusStream := response.Header.Get(statusStreamHeader) == "true"
|
||||
previousLastUpdated := c.lastUpdatedTime()
|
||||
var firstFrameUpdatedAt time.Time
|
||||
for {
|
||||
var statusResponse statusPayload
|
||||
err = decoder.Decode(&statusResponse)
|
||||
if err != nil {
|
||||
result.duration = time.Since(startTime)
|
||||
if result.frames == 1 && err == io.EOF && !isStatusStream {
|
||||
result.oneShot = true
|
||||
c.restoreLastUpdatedIfUnchanged(firstFrameUpdatedAt, previousLastUpdated)
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
@@ -701,6 +699,12 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr
|
||||
if statusResponse.PlanWeight > 0 {
|
||||
c.state.remotePlanWeight = statusResponse.PlanWeight
|
||||
}
|
||||
if statusResponse.FiveHourReset > 0 {
|
||||
c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0)
|
||||
}
|
||||
if statusResponse.WeeklyReset > 0 {
|
||||
c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0)
|
||||
}
|
||||
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
|
||||
c.state.hardRateLimited = false
|
||||
}
|
||||
@@ -710,23 +714,17 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr
|
||||
c.interruptConnections()
|
||||
}
|
||||
result.frames++
|
||||
updatedAt := c.markUsageStreamUpdated()
|
||||
if result.frames == 1 {
|
||||
firstFrameUpdatedAt = updatedAt
|
||||
}
|
||||
c.markUsageStreamUpdated()
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) nextStatusStreamBackoff(result statusStreamResult, consecutiveFailures int) (int, time.Duration, bool) {
|
||||
if result.oneShot {
|
||||
return 0, c.pollInterval, true
|
||||
}
|
||||
func (c *externalCredential) nextStatusStreamBackoff(result statusStreamResult, consecutiveFailures int) (int, time.Duration) {
|
||||
if result.duration >= connectorBackoffResetThreshold {
|
||||
consecutiveFailures = 0
|
||||
}
|
||||
consecutiveFailures++
|
||||
return consecutiveFailures, connectorBackoff(consecutiveFailures), false
|
||||
return consecutiveFailures, connectorBackoff(consecutiveFailures)
|
||||
}
|
||||
|
||||
func (c *externalCredential) doStreamStatusRequest(ctx context.Context) (*http.Response, error) {
|
||||
@@ -767,23 +765,10 @@ func (c *externalCredential) lastUpdatedTime() time.Time {
|
||||
return c.state.lastUpdated
|
||||
}
|
||||
|
||||
func (c *externalCredential) markUsageStreamUpdated() time.Time {
|
||||
func (c *externalCredential) markUsageStreamUpdated() {
|
||||
c.stateAccess.Lock()
|
||||
defer c.stateAccess.Unlock()
|
||||
now := time.Now()
|
||||
c.state.lastUpdated = now
|
||||
return now
|
||||
}
|
||||
|
||||
func (c *externalCredential) restoreLastUpdatedIfUnchanged(expectedCurrent time.Time, previous time.Time) {
|
||||
if expectedCurrent.IsZero() {
|
||||
return
|
||||
}
|
||||
c.stateAccess.Lock()
|
||||
defer c.stateAccess.Unlock()
|
||||
if c.state.lastUpdated.Equal(expectedCurrent) {
|
||||
c.state.lastUpdated = previous
|
||||
}
|
||||
c.state.lastUpdated = time.Now()
|
||||
}
|
||||
|
||||
func (c *externalCredential) markUsagePollAttempted() {
|
||||
|
||||
@@ -84,7 +84,7 @@ func newTestYamuxSessionPair(t *testing.T) (*yamux.Session, *yamux.Session) {
|
||||
return clientSession, serverSession
|
||||
}
|
||||
|
||||
func TestExternalCredentialConnectStatusStreamOneShotRestoresLastUpdated(t *testing.T) {
|
||||
func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *testing.T) {
|
||||
credential, subscription := newTestCCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", nil)
|
||||
oldTime := time.Unix(123, 0)
|
||||
credential.stateAccess.Lock()
|
||||
@@ -95,50 +95,6 @@ func TestExternalCredentialConnectStatusStreamOneShotRestoresLastUpdated(t *test
|
||||
if err != io.EOF {
|
||||
t.Fatalf("expected EOF, got %v", err)
|
||||
}
|
||||
if !result.oneShot {
|
||||
t.Fatal("expected one-shot result")
|
||||
}
|
||||
if result.frames != 1 {
|
||||
t.Fatalf("expected 1 frame, got %d", result.frames)
|
||||
}
|
||||
if !credential.lastUpdatedTime().Equal(oldTime) {
|
||||
t.Fatalf("expected lastUpdated restored to %v, got %v", oldTime, credential.lastUpdatedTime())
|
||||
}
|
||||
if credential.fiveHourUtilization() != 12 || credential.weeklyUtilization() != 34 {
|
||||
t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization())
|
||||
}
|
||||
if count := drainStatusEvents(subscription); count != 1 {
|
||||
t.Fatalf("expected 1 status event, got %d", count)
|
||||
}
|
||||
|
||||
failures, backoff, oneShot := credential.nextStatusStreamBackoff(result, 3)
|
||||
if !oneShot {
|
||||
t.Fatal("expected one-shot backoff branch")
|
||||
}
|
||||
if failures != 0 {
|
||||
t.Fatalf("expected failures reset, got %d", failures)
|
||||
}
|
||||
if backoff != credential.pollInterval {
|
||||
t.Fatalf("expected poll interval backoff %v, got %v", credential.pollInterval, backoff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *testing.T) {
|
||||
headers := make(http.Header)
|
||||
headers.Set(statusStreamHeader, "true")
|
||||
credential, subscription := newTestCCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", headers)
|
||||
oldTime := time.Unix(123, 0)
|
||||
credential.stateAccess.Lock()
|
||||
credential.state.lastUpdated = oldTime
|
||||
credential.stateAccess.Unlock()
|
||||
|
||||
result, err := credential.connectStatusStream(context.Background())
|
||||
if err != io.EOF {
|
||||
t.Fatalf("expected EOF, got %v", err)
|
||||
}
|
||||
if result.oneShot {
|
||||
t.Fatal("did not expect one-shot result")
|
||||
}
|
||||
if result.frames != 1 {
|
||||
t.Fatalf("expected 1 frame, got %d", result.frames)
|
||||
}
|
||||
@@ -152,10 +108,7 @@ func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *tes
|
||||
t.Fatalf("expected 1 status event, got %d", count)
|
||||
}
|
||||
|
||||
failures, backoff, oneShot := credential.nextStatusStreamBackoff(result, 3)
|
||||
if oneShot {
|
||||
t.Fatal("did not expect one-shot backoff branch")
|
||||
}
|
||||
failures, backoff := credential.nextStatusStreamBackoff(result, 3)
|
||||
if failures != 4 {
|
||||
t.Fatalf("expected failures incremented to 4, got %d", failures)
|
||||
}
|
||||
@@ -178,9 +131,6 @@ func TestExternalCredentialConnectStatusStreamMultiFrameKeepsLastUpdated(t *test
|
||||
if err != io.EOF {
|
||||
t.Fatalf("expected EOF, got %v", err)
|
||||
}
|
||||
if result.oneShot {
|
||||
t.Fatal("did not expect one-shot result")
|
||||
}
|
||||
if result.frames != 2 {
|
||||
t.Fatalf("expected 2 frames, got %d", result.frames)
|
||||
}
|
||||
|
||||
@@ -311,10 +311,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Rewrite response headers for external users
|
||||
if userConfig != nil && userConfig.ExternalCredential != "" {
|
||||
s.rewriteResponseHeadersForExternalUser(response.Header, provider, userConfig)
|
||||
}
|
||||
s.rewriteResponseHeaders(response.Header, provider, userConfig)
|
||||
|
||||
for key, values := range response.Header {
|
||||
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) {
|
||||
|
||||
@@ -6,16 +6,52 @@ import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/option"
|
||||
)
|
||||
|
||||
type statusPayload struct {
|
||||
FiveHourUtilization float64 `json:"five_hour_utilization"`
|
||||
FiveHourReset int64 `json:"five_hour_reset"`
|
||||
WeeklyUtilization float64 `json:"weekly_utilization"`
|
||||
WeeklyReset int64 `json:"weekly_reset"`
|
||||
PlanWeight float64 `json:"plan_weight"`
|
||||
}
|
||||
|
||||
type aggregatedStatus struct {
|
||||
fiveHourUtilization float64
|
||||
weeklyUtilization float64
|
||||
totalWeight float64
|
||||
fiveHourReset time.Time
|
||||
weeklyReset time.Time
|
||||
}
|
||||
|
||||
func resetToEpoch(t time.Time) int64 {
|
||||
if t.IsZero() {
|
||||
return 0
|
||||
}
|
||||
return t.Unix()
|
||||
}
|
||||
|
||||
func (s aggregatedStatus) equal(other aggregatedStatus) bool {
|
||||
return s.fiveHourUtilization == other.fiveHourUtilization &&
|
||||
s.weeklyUtilization == other.weeklyUtilization &&
|
||||
s.totalWeight == other.totalWeight &&
|
||||
resetToEpoch(s.fiveHourReset) == resetToEpoch(other.fiveHourReset) &&
|
||||
resetToEpoch(s.weeklyReset) == resetToEpoch(other.weeklyReset)
|
||||
}
|
||||
|
||||
func (s aggregatedStatus) toPayload() statusPayload {
|
||||
return statusPayload{
|
||||
FiveHourUtilization: s.fiveHourUtilization,
|
||||
FiveHourReset: resetToEpoch(s.fiveHourReset),
|
||||
WeeklyUtilization: s.weeklyUtilization,
|
||||
WeeklyReset: resetToEpoch(s.weeklyReset),
|
||||
PlanWeight: s.totalWeight,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed")
|
||||
@@ -68,15 +104,11 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
provider.pollIfStale(r.Context())
|
||||
avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
|
||||
status := s.computeAggregatedUtilization(provider, userConfig)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(statusPayload{
|
||||
FiveHourUtilization: avgFiveHour,
|
||||
WeeklyUtilization: avgWeekly,
|
||||
PlanWeight: totalWeight,
|
||||
})
|
||||
json.NewEncoder(w).Encode(status.toPayload())
|
||||
}
|
||||
|
||||
func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, provider credentialProvider, userConfig *option.CCMUser) {
|
||||
@@ -96,16 +128,11 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro
|
||||
provider.pollIfStale(r.Context())
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set(statusStreamHeader, "true")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
lastFiveHour, lastWeekly, lastWeight := s.computeAggregatedUtilization(provider, userConfig)
|
||||
last := s.computeAggregatedUtilization(provider, userConfig)
|
||||
buf := &bytes.Buffer{}
|
||||
json.NewEncoder(buf).Encode(statusPayload{
|
||||
FiveHourUtilization: lastFiveHour,
|
||||
WeeklyUtilization: lastWeekly,
|
||||
PlanWeight: lastWeight,
|
||||
})
|
||||
json.NewEncoder(buf).Encode(last.toPayload())
|
||||
_, writeErr := w.Write(buf.Bytes())
|
||||
if writeErr != nil {
|
||||
return
|
||||
@@ -127,19 +154,13 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro
|
||||
}
|
||||
}
|
||||
drained:
|
||||
fiveHour, weekly, weight := s.computeAggregatedUtilization(provider, userConfig)
|
||||
if fiveHour == lastFiveHour && weekly == lastWeekly && weight == lastWeight {
|
||||
current := s.computeAggregatedUtilization(provider, userConfig)
|
||||
if current.equal(last) {
|
||||
continue
|
||||
}
|
||||
lastFiveHour = fiveHour
|
||||
lastWeekly = weekly
|
||||
lastWeight = weight
|
||||
last = current
|
||||
buf.Reset()
|
||||
json.NewEncoder(buf).Encode(statusPayload{
|
||||
FiveHourUtilization: fiveHour,
|
||||
WeeklyUtilization: weekly,
|
||||
PlanWeight: weight,
|
||||
})
|
||||
json.NewEncoder(buf).Encode(current.toPayload())
|
||||
_, writeErr = w.Write(buf.Bytes())
|
||||
if writeErr != nil {
|
||||
return
|
||||
@@ -149,8 +170,11 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) (float64, float64, float64) {
|
||||
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) aggregatedStatus {
|
||||
var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64
|
||||
now := time.Now()
|
||||
var totalWeightedHoursUntil5hReset, total5hResetWeight float64
|
||||
var totalWeightedHoursUntilWeeklyReset, totalWeeklyResetWeight float64
|
||||
for _, credential := range provider.allCredentials() {
|
||||
if !credential.isAvailable() {
|
||||
continue
|
||||
@@ -173,21 +197,59 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user
|
||||
totalWeightedRemaining5h += remaining5h * weight
|
||||
totalWeightedRemainingWeekly += remainingWeekly * weight
|
||||
totalWeight += weight
|
||||
|
||||
fiveHourReset := credential.fiveHourResetTime()
|
||||
if !fiveHourReset.IsZero() {
|
||||
hours := fiveHourReset.Sub(now).Hours()
|
||||
if hours < 0 {
|
||||
hours = 0
|
||||
}
|
||||
totalWeightedHoursUntil5hReset += hours * weight
|
||||
total5hResetWeight += weight
|
||||
}
|
||||
weeklyReset := credential.weeklyResetTime()
|
||||
if !weeklyReset.IsZero() {
|
||||
hours := weeklyReset.Sub(now).Hours()
|
||||
if hours < 0 {
|
||||
hours = 0
|
||||
}
|
||||
totalWeightedHoursUntilWeeklyReset += hours * weight
|
||||
totalWeeklyResetWeight += weight
|
||||
}
|
||||
}
|
||||
if totalWeight == 0 {
|
||||
return 100, 100, 0
|
||||
return aggregatedStatus{
|
||||
fiveHourUtilization: 100,
|
||||
weeklyUtilization: 100,
|
||||
}
|
||||
}
|
||||
return 100 - totalWeightedRemaining5h/totalWeight,
|
||||
100 - totalWeightedRemainingWeekly/totalWeight,
|
||||
totalWeight
|
||||
result := aggregatedStatus{
|
||||
fiveHourUtilization: 100 - totalWeightedRemaining5h/totalWeight,
|
||||
weeklyUtilization: 100 - totalWeightedRemainingWeekly/totalWeight,
|
||||
totalWeight: totalWeight,
|
||||
}
|
||||
if total5hResetWeight > 0 {
|
||||
avgHours := totalWeightedHoursUntil5hReset / total5hResetWeight
|
||||
result.fiveHourReset = now.Add(time.Duration(avgHours * float64(time.Hour)))
|
||||
}
|
||||
if totalWeeklyResetWeight > 0 {
|
||||
avgHours := totalWeightedHoursUntilWeeklyReset / totalWeeklyResetWeight
|
||||
result.weeklyReset = now.Add(time.Duration(avgHours * float64(time.Hour)))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, provider credentialProvider, userConfig *option.CCMUser) {
|
||||
avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
|
||||
|
||||
headers.Set("anthropic-ratelimit-unified-5h-utilization", strconv.FormatFloat(avgFiveHour/100, 'f', 6, 64))
|
||||
headers.Set("anthropic-ratelimit-unified-7d-utilization", strconv.FormatFloat(avgWeekly/100, 'f', 6, 64))
|
||||
if totalWeight > 0 {
|
||||
headers.Set("X-CCM-Plan-Weight", strconv.FormatFloat(totalWeight, 'f', -1, 64))
|
||||
func (s *Service) rewriteResponseHeaders(headers http.Header, provider credentialProvider, userConfig *option.CCMUser) {
|
||||
status := s.computeAggregatedUtilization(provider, userConfig)
|
||||
headers.Set("anthropic-ratelimit-unified-5h-utilization", strconv.FormatFloat(status.fiveHourUtilization/100, 'f', 6, 64))
|
||||
headers.Set("anthropic-ratelimit-unified-7d-utilization", strconv.FormatFloat(status.weeklyUtilization/100, 'f', 6, 64))
|
||||
if !status.fiveHourReset.IsZero() {
|
||||
headers.Set("anthropic-ratelimit-unified-5h-reset", strconv.FormatInt(status.fiveHourReset.Unix(), 10))
|
||||
}
|
||||
if !status.weeklyReset.IsZero() {
|
||||
headers.Set("anthropic-ratelimit-unified-7d-reset", strconv.FormatInt(status.weeklyReset.Unix(), 10))
|
||||
}
|
||||
if status.totalWeight > 0 {
|
||||
headers.Set("X-CCM-Plan-Weight", strconv.FormatFloat(status.totalWeight, 'f', -1, 64))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -107,6 +107,7 @@ type Credential interface {
|
||||
weeklyCap() float64
|
||||
planWeight() float64
|
||||
weeklyResetTime() time.Time
|
||||
fiveHourResetTime() time.Time
|
||||
markRateLimited(resetAt time.Time)
|
||||
earliestReset() time.Time
|
||||
unavailableError() error
|
||||
|
||||
@@ -476,6 +476,12 @@ func (c *defaultCredential) weeklyResetTime() time.Time {
|
||||
return c.state.weeklyReset
|
||||
}
|
||||
|
||||
func (c *defaultCredential) fiveHourResetTime() time.Time {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.fiveHourReset
|
||||
}
|
||||
|
||||
func (c *defaultCredential) isAvailable() bool {
|
||||
c.retryCredentialReloadIfNeeded()
|
||||
|
||||
|
||||
@@ -28,10 +28,7 @@ import (
|
||||
"github.com/hashicorp/yamux"
|
||||
)
|
||||
|
||||
const (
|
||||
reverseProxyBaseURL = "http://reverse-proxy"
|
||||
statusStreamHeader = "X-OCM-Status-Stream"
|
||||
)
|
||||
const reverseProxyBaseURL = "http://reverse-proxy"
|
||||
|
||||
type externalCredential struct {
|
||||
tag string
|
||||
@@ -77,7 +74,6 @@ type reverseSessionDialer struct {
|
||||
type statusStreamResult struct {
|
||||
duration time.Duration
|
||||
frames int
|
||||
oneShot bool
|
||||
}
|
||||
|
||||
func (d reverseSessionDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
@@ -355,6 +351,12 @@ func (c *externalCredential) weeklyResetTime() time.Time {
|
||||
return c.state.weeklyReset
|
||||
}
|
||||
|
||||
func (c *externalCredential) fiveHourResetTime() time.Time {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.fiveHourReset
|
||||
}
|
||||
|
||||
func (c *externalCredential) markRateLimited(resetAt time.Time) {
|
||||
c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt)))
|
||||
c.stateAccess.Lock()
|
||||
@@ -634,7 +636,9 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
|
||||
|
||||
var statusResponse struct {
|
||||
FiveHourUtilization float64 `json:"five_hour_utilization"`
|
||||
FiveHourReset int64 `json:"five_hour_reset"`
|
||||
WeeklyUtilization float64 `json:"weekly_utilization"`
|
||||
WeeklyReset int64 `json:"weekly_reset"`
|
||||
PlanWeight float64 `json:"plan_weight"`
|
||||
}
|
||||
err = json.NewDecoder(response.Body).Decode(&statusResponse)
|
||||
@@ -651,6 +655,12 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
|
||||
c.state.consecutivePollFailures = 0
|
||||
c.state.fiveHourUtilization = statusResponse.FiveHourUtilization
|
||||
c.state.weeklyUtilization = statusResponse.WeeklyUtilization
|
||||
if statusResponse.FiveHourReset > 0 {
|
||||
c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0)
|
||||
}
|
||||
if statusResponse.WeeklyReset > 0 {
|
||||
c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0)
|
||||
}
|
||||
if statusResponse.PlanWeight > 0 {
|
||||
c.state.remotePlanWeight = statusResponse.PlanWeight
|
||||
}
|
||||
@@ -687,13 +697,8 @@ func (c *externalCredential) statusStreamLoop() {
|
||||
return
|
||||
}
|
||||
var backoff time.Duration
|
||||
var oneShot bool
|
||||
consecutiveFailures, backoff, oneShot = c.nextStatusStreamBackoff(result, consecutiveFailures)
|
||||
if oneShot {
|
||||
c.logger.Debug("status stream for ", c.tag, " returned a single-frame response, retrying in ", backoff)
|
||||
} else {
|
||||
c.logger.Debug("status stream for ", c.tag, " disconnected: ", err, ", reconnecting in ", backoff)
|
||||
}
|
||||
consecutiveFailures, backoff = c.nextStatusStreamBackoff(result, consecutiveFailures)
|
||||
c.logger.Debug("status stream for ", c.tag, " disconnected: ", err, ", reconnecting in ", backoff)
|
||||
timer := time.NewTimer(backoff)
|
||||
select {
|
||||
case <-timer.C:
|
||||
@@ -721,18 +726,11 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(response.Body)
|
||||
isStatusStream := response.Header.Get(statusStreamHeader) == "true"
|
||||
previousLastUpdated := c.lastUpdatedTime()
|
||||
var firstFrameUpdatedAt time.Time
|
||||
for {
|
||||
var statusResponse statusPayload
|
||||
err = decoder.Decode(&statusResponse)
|
||||
if err != nil {
|
||||
result.duration = time.Since(startTime)
|
||||
if result.frames == 1 && err == io.EOF && !isStatusStream {
|
||||
result.oneShot = true
|
||||
c.restoreLastUpdatedIfUnchanged(firstFrameUpdatedAt, previousLastUpdated)
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
@@ -740,6 +738,12 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr
|
||||
c.state.consecutivePollFailures = 0
|
||||
c.state.fiveHourUtilization = statusResponse.FiveHourUtilization
|
||||
c.state.weeklyUtilization = statusResponse.WeeklyUtilization
|
||||
if statusResponse.FiveHourReset > 0 {
|
||||
c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0)
|
||||
}
|
||||
if statusResponse.WeeklyReset > 0 {
|
||||
c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0)
|
||||
}
|
||||
if statusResponse.PlanWeight > 0 {
|
||||
c.state.remotePlanWeight = statusResponse.PlanWeight
|
||||
}
|
||||
@@ -752,23 +756,17 @@ func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStr
|
||||
c.interruptConnections()
|
||||
}
|
||||
result.frames++
|
||||
updatedAt := c.markUsageStreamUpdated()
|
||||
if result.frames == 1 {
|
||||
firstFrameUpdatedAt = updatedAt
|
||||
}
|
||||
c.markUsageStreamUpdated()
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) nextStatusStreamBackoff(result statusStreamResult, consecutiveFailures int) (int, time.Duration, bool) {
|
||||
if result.oneShot {
|
||||
return 0, c.pollInterval, true
|
||||
}
|
||||
func (c *externalCredential) nextStatusStreamBackoff(result statusStreamResult, consecutiveFailures int) (int, time.Duration) {
|
||||
if result.duration >= connectorBackoffResetThreshold {
|
||||
consecutiveFailures = 0
|
||||
}
|
||||
consecutiveFailures++
|
||||
return consecutiveFailures, connectorBackoff(consecutiveFailures), false
|
||||
return consecutiveFailures, connectorBackoff(consecutiveFailures)
|
||||
}
|
||||
|
||||
func (c *externalCredential) doStreamStatusRequest(ctx context.Context) (*http.Response, error) {
|
||||
@@ -809,23 +807,10 @@ func (c *externalCredential) lastUpdatedTime() time.Time {
|
||||
return c.state.lastUpdated
|
||||
}
|
||||
|
||||
func (c *externalCredential) markUsageStreamUpdated() time.Time {
|
||||
func (c *externalCredential) markUsageStreamUpdated() {
|
||||
c.stateAccess.Lock()
|
||||
defer c.stateAccess.Unlock()
|
||||
now := time.Now()
|
||||
c.state.lastUpdated = now
|
||||
return now
|
||||
}
|
||||
|
||||
func (c *externalCredential) restoreLastUpdatedIfUnchanged(expectedCurrent time.Time, previous time.Time) {
|
||||
if expectedCurrent.IsZero() {
|
||||
return
|
||||
}
|
||||
c.stateAccess.Lock()
|
||||
defer c.stateAccess.Unlock()
|
||||
if c.state.lastUpdated.Equal(expectedCurrent) {
|
||||
c.state.lastUpdated = previous
|
||||
}
|
||||
c.state.lastUpdated = time.Now()
|
||||
}
|
||||
|
||||
func (c *externalCredential) markUsagePollAttempted() {
|
||||
|
||||
@@ -84,7 +84,7 @@ func newTestYamuxSessionPair(t *testing.T) (*yamux.Session, *yamux.Session) {
|
||||
return clientSession, serverSession
|
||||
}
|
||||
|
||||
func TestExternalCredentialConnectStatusStreamOneShotRestoresLastUpdated(t *testing.T) {
|
||||
func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *testing.T) {
|
||||
credential, subscription := newTestOCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", nil)
|
||||
oldTime := time.Unix(123, 0)
|
||||
credential.stateAccess.Lock()
|
||||
@@ -95,50 +95,6 @@ func TestExternalCredentialConnectStatusStreamOneShotRestoresLastUpdated(t *test
|
||||
if err != io.EOF {
|
||||
t.Fatalf("expected EOF, got %v", err)
|
||||
}
|
||||
if !result.oneShot {
|
||||
t.Fatal("expected one-shot result")
|
||||
}
|
||||
if result.frames != 1 {
|
||||
t.Fatalf("expected 1 frame, got %d", result.frames)
|
||||
}
|
||||
if !credential.lastUpdatedTime().Equal(oldTime) {
|
||||
t.Fatalf("expected lastUpdated restored to %v, got %v", oldTime, credential.lastUpdatedTime())
|
||||
}
|
||||
if credential.fiveHourUtilization() != 12 || credential.weeklyUtilization() != 34 {
|
||||
t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization())
|
||||
}
|
||||
if count := drainStatusEvents(subscription); count != 1 {
|
||||
t.Fatalf("expected 1 status event, got %d", count)
|
||||
}
|
||||
|
||||
failures, backoff, oneShot := credential.nextStatusStreamBackoff(result, 3)
|
||||
if !oneShot {
|
||||
t.Fatal("expected one-shot backoff branch")
|
||||
}
|
||||
if failures != 0 {
|
||||
t.Fatalf("expected failures reset, got %d", failures)
|
||||
}
|
||||
if backoff != credential.pollInterval {
|
||||
t.Fatalf("expected poll interval backoff %v, got %v", credential.pollInterval, backoff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *testing.T) {
|
||||
headers := make(http.Header)
|
||||
headers.Set(statusStreamHeader, "true")
|
||||
credential, subscription := newTestOCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", headers)
|
||||
oldTime := time.Unix(123, 0)
|
||||
credential.stateAccess.Lock()
|
||||
credential.state.lastUpdated = oldTime
|
||||
credential.stateAccess.Unlock()
|
||||
|
||||
result, err := credential.connectStatusStream(context.Background())
|
||||
if err != io.EOF {
|
||||
t.Fatalf("expected EOF, got %v", err)
|
||||
}
|
||||
if result.oneShot {
|
||||
t.Fatal("did not expect one-shot result")
|
||||
}
|
||||
if result.frames != 1 {
|
||||
t.Fatalf("expected 1 frame, got %d", result.frames)
|
||||
}
|
||||
@@ -152,10 +108,7 @@ func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *tes
|
||||
t.Fatalf("expected 1 status event, got %d", count)
|
||||
}
|
||||
|
||||
failures, backoff, oneShot := credential.nextStatusStreamBackoff(result, 3)
|
||||
if oneShot {
|
||||
t.Fatal("did not expect one-shot backoff branch")
|
||||
}
|
||||
failures, backoff := credential.nextStatusStreamBackoff(result, 3)
|
||||
if failures != 4 {
|
||||
t.Fatalf("expected failures incremented to 4, got %d", failures)
|
||||
}
|
||||
@@ -178,9 +131,6 @@ func TestExternalCredentialConnectStatusStreamMultiFrameKeepsLastUpdated(t *test
|
||||
if err != io.EOF {
|
||||
t.Fatalf("expected EOF, got %v", err)
|
||||
}
|
||||
if result.oneShot {
|
||||
t.Fatal("did not expect one-shot result")
|
||||
}
|
||||
if result.frames != 2 {
|
||||
t.Fatalf("expected 2 frames, got %d", result.frames)
|
||||
}
|
||||
|
||||
@@ -291,10 +291,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Rewrite response headers for external users
|
||||
if userConfig != nil && userConfig.ExternalCredential != "" {
|
||||
s.rewriteResponseHeadersForExternalUser(response.Header, provider, userConfig)
|
||||
}
|
||||
s.rewriteResponseHeaders(response.Header, provider, userConfig)
|
||||
|
||||
for key, values := range response.Header {
|
||||
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) {
|
||||
|
||||
@@ -6,16 +6,52 @@ import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/option"
|
||||
)
|
||||
|
||||
// statusPayload is the JSON document emitted by the status endpoint and the
// status stream. It is built from an aggregatedStatus by toPayload.
//
// The utilization fields carry the aggregated used-percent values, the reset
// fields carry Unix seconds (0 when no reset time is known), and PlanWeight
// carries the aggregated total weight.
type statusPayload struct {
	FiveHourUtilization float64 `json:"five_hour_utilization"`
	FiveHourReset       int64   `json:"five_hour_reset"`
	WeeklyUtilization   float64 `json:"weekly_utilization"`
	WeeklyReset         int64   `json:"weekly_reset"`
	PlanWeight          float64 `json:"plan_weight"`
}
|
||||
|
||||
// aggregatedStatus is the result of computeAggregatedUtilization: the
// utilization and reset times aggregated over a provider's credentials.
type aggregatedStatus struct {
	// Aggregated used-percent for each window, and the summed weight of the
	// credentials that contributed. With no contributing weight the
	// utilizations are reported as 100 and totalWeight stays 0.
	fiveHourUtilization, weeklyUtilization, totalWeight float64
	// Weight-averaged reset instants; the zero time means no contributing
	// credential reported a reset time for that window.
	fiveHourReset, weeklyReset time.Time
}
|
||||
|
||||
// resetToEpoch converts a reset instant to Unix seconds. The zero time.Time,
// used here to mean "no reset time known", maps to 0 rather than to the large
// negative value time.Time.Unix would otherwise return for it.
func resetToEpoch(t time.Time) int64 {
	var epoch int64
	if !t.IsZero() {
		epoch = t.Unix()
	}
	return epoch
}
|
||||
|
||||
func (s aggregatedStatus) equal(other aggregatedStatus) bool {
|
||||
return s.fiveHourUtilization == other.fiveHourUtilization &&
|
||||
s.weeklyUtilization == other.weeklyUtilization &&
|
||||
s.totalWeight == other.totalWeight &&
|
||||
resetToEpoch(s.fiveHourReset) == resetToEpoch(other.fiveHourReset) &&
|
||||
resetToEpoch(s.weeklyReset) == resetToEpoch(other.weeklyReset)
|
||||
}
|
||||
|
||||
func (s aggregatedStatus) toPayload() statusPayload {
|
||||
return statusPayload{
|
||||
FiveHourUtilization: s.fiveHourUtilization,
|
||||
FiveHourReset: resetToEpoch(s.fiveHourReset),
|
||||
WeeklyUtilization: s.weeklyUtilization,
|
||||
WeeklyReset: resetToEpoch(s.weeklyReset),
|
||||
PlanWeight: s.totalWeight,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed")
|
||||
@@ -68,15 +104,11 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
provider.pollIfStale(r.Context())
|
||||
avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
|
||||
status := s.computeAggregatedUtilization(provider, userConfig)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(statusPayload{
|
||||
FiveHourUtilization: avgFiveHour,
|
||||
WeeklyUtilization: avgWeekly,
|
||||
PlanWeight: totalWeight,
|
||||
})
|
||||
json.NewEncoder(w).Encode(status.toPayload())
|
||||
}
|
||||
|
||||
func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, provider credentialProvider, userConfig *option.OCMUser) {
|
||||
@@ -96,16 +128,11 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro
|
||||
provider.pollIfStale(r.Context())
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set(statusStreamHeader, "true")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
lastFiveHour, lastWeekly, lastWeight := s.computeAggregatedUtilization(provider, userConfig)
|
||||
last := s.computeAggregatedUtilization(provider, userConfig)
|
||||
buf := &bytes.Buffer{}
|
||||
json.NewEncoder(buf).Encode(statusPayload{
|
||||
FiveHourUtilization: lastFiveHour,
|
||||
WeeklyUtilization: lastWeekly,
|
||||
PlanWeight: lastWeight,
|
||||
})
|
||||
json.NewEncoder(buf).Encode(last.toPayload())
|
||||
_, writeErr := w.Write(buf.Bytes())
|
||||
if writeErr != nil {
|
||||
return
|
||||
@@ -127,19 +154,13 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro
|
||||
}
|
||||
}
|
||||
drained:
|
||||
fiveHour, weekly, weight := s.computeAggregatedUtilization(provider, userConfig)
|
||||
if fiveHour == lastFiveHour && weekly == lastWeekly && weight == lastWeight {
|
||||
current := s.computeAggregatedUtilization(provider, userConfig)
|
||||
if current.equal(last) {
|
||||
continue
|
||||
}
|
||||
lastFiveHour = fiveHour
|
||||
lastWeekly = weekly
|
||||
lastWeight = weight
|
||||
last = current
|
||||
buf.Reset()
|
||||
json.NewEncoder(buf).Encode(statusPayload{
|
||||
FiveHourUtilization: fiveHour,
|
||||
WeeklyUtilization: weekly,
|
||||
PlanWeight: weight,
|
||||
})
|
||||
json.NewEncoder(buf).Encode(current.toPayload())
|
||||
_, writeErr = w.Write(buf.Bytes())
|
||||
if writeErr != nil {
|
||||
return
|
||||
@@ -149,8 +170,11 @@ func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, pro
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) (float64, float64, float64) {
|
||||
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) aggregatedStatus {
|
||||
var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64
|
||||
now := time.Now()
|
||||
var totalWeightedHoursUntil5hReset, total5hResetWeight float64
|
||||
var totalWeightedHoursUntilWeeklyReset, totalWeeklyResetWeight float64
|
||||
for _, credential := range provider.allCredentials() {
|
||||
if !credential.isAvailable() {
|
||||
continue
|
||||
@@ -173,26 +197,63 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user
|
||||
totalWeightedRemaining5h += remaining5h * weight
|
||||
totalWeightedRemainingWeekly += remainingWeekly * weight
|
||||
totalWeight += weight
|
||||
|
||||
fiveHourReset := credential.fiveHourResetTime()
|
||||
if !fiveHourReset.IsZero() {
|
||||
hours := fiveHourReset.Sub(now).Hours()
|
||||
if hours < 0 {
|
||||
hours = 0
|
||||
}
|
||||
totalWeightedHoursUntil5hReset += hours * weight
|
||||
total5hResetWeight += weight
|
||||
}
|
||||
weeklyReset := credential.weeklyResetTime()
|
||||
if !weeklyReset.IsZero() {
|
||||
hours := weeklyReset.Sub(now).Hours()
|
||||
if hours < 0 {
|
||||
hours = 0
|
||||
}
|
||||
totalWeightedHoursUntilWeeklyReset += hours * weight
|
||||
totalWeeklyResetWeight += weight
|
||||
}
|
||||
}
|
||||
if totalWeight == 0 {
|
||||
return 100, 100, 0
|
||||
return aggregatedStatus{
|
||||
fiveHourUtilization: 100,
|
||||
weeklyUtilization: 100,
|
||||
}
|
||||
}
|
||||
return 100 - totalWeightedRemaining5h/totalWeight,
|
||||
100 - totalWeightedRemainingWeekly/totalWeight,
|
||||
totalWeight
|
||||
result := aggregatedStatus{
|
||||
fiveHourUtilization: 100 - totalWeightedRemaining5h/totalWeight,
|
||||
weeklyUtilization: 100 - totalWeightedRemainingWeekly/totalWeight,
|
||||
totalWeight: totalWeight,
|
||||
}
|
||||
if total5hResetWeight > 0 {
|
||||
avgHours := totalWeightedHoursUntil5hReset / total5hResetWeight
|
||||
result.fiveHourReset = now.Add(time.Duration(avgHours * float64(time.Hour)))
|
||||
}
|
||||
if totalWeeklyResetWeight > 0 {
|
||||
avgHours := totalWeightedHoursUntilWeeklyReset / totalWeeklyResetWeight
|
||||
result.weeklyReset = now.Add(time.Duration(avgHours * float64(time.Hour)))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, provider credentialProvider, userConfig *option.OCMUser) {
|
||||
avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
|
||||
|
||||
func (s *Service) rewriteResponseHeaders(headers http.Header, provider credentialProvider, userConfig *option.OCMUser) {
|
||||
status := s.computeAggregatedUtilization(provider, userConfig)
|
||||
activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
|
||||
if activeLimitIdentifier == "" {
|
||||
activeLimitIdentifier = "codex"
|
||||
}
|
||||
|
||||
headers.Set("x-"+activeLimitIdentifier+"-primary-used-percent", strconv.FormatFloat(avgFiveHour, 'f', 2, 64))
|
||||
headers.Set("x-"+activeLimitIdentifier+"-secondary-used-percent", strconv.FormatFloat(avgWeekly, 'f', 2, 64))
|
||||
if totalWeight > 0 {
|
||||
headers.Set("X-OCM-Plan-Weight", strconv.FormatFloat(totalWeight, 'f', -1, 64))
|
||||
headers.Set("x-"+activeLimitIdentifier+"-primary-used-percent", strconv.FormatFloat(status.fiveHourUtilization, 'f', 2, 64))
|
||||
headers.Set("x-"+activeLimitIdentifier+"-secondary-used-percent", strconv.FormatFloat(status.weeklyUtilization, 'f', 2, 64))
|
||||
if !status.fiveHourReset.IsZero() {
|
||||
headers.Set("x-"+activeLimitIdentifier+"-primary-reset-at", strconv.FormatInt(status.fiveHourReset.Unix(), 10))
|
||||
}
|
||||
if !status.weeklyReset.IsZero() {
|
||||
headers.Set("x-"+activeLimitIdentifier+"-secondary-reset-at", strconv.FormatInt(status.weeklyReset.Unix(), 10))
|
||||
}
|
||||
if status.totalWeight > 0 {
|
||||
headers.Set("X-OCM-Plan-Weight", strconv.FormatFloat(status.totalWeight, 'f', -1, 64))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -252,9 +252,7 @@ func (s *Service) handleWebSocket(
|
||||
clientResponseHeaders[key] = append([]string(nil), values...)
|
||||
}
|
||||
}
|
||||
if userConfig != nil && userConfig.ExternalCredential != "" {
|
||||
s.rewriteResponseHeadersForExternalUser(clientResponseHeaders, provider, userConfig)
|
||||
}
|
||||
s.rewriteResponseHeaders(clientResponseHeaders, provider, userConfig)
|
||||
|
||||
clientUpgrader := ws.HTTPUpgrader{
|
||||
Header: clientResponseHeaders,
|
||||
@@ -292,10 +290,16 @@ func (s *Service) handleWebSocket(
|
||||
upstreamReadWriter = upstreamConn
|
||||
}
|
||||
|
||||
rateLimitIdentifier := normalizeRateLimitIdentifier(upstreamResponseHeaders.Get("x-codex-active-limit"))
|
||||
if rateLimitIdentifier == "" {
|
||||
rateLimitIdentifier = "codex"
|
||||
}
|
||||
|
||||
var clientWriteAccess sync.Mutex
|
||||
modelChannel := make(chan string, 1)
|
||||
var waitGroup sync.WaitGroup
|
||||
|
||||
waitGroup.Add(2)
|
||||
waitGroup.Add(3)
|
||||
go func() {
|
||||
defer waitGroup.Done()
|
||||
defer session.Close()
|
||||
@@ -304,7 +308,12 @@ func (s *Service) handleWebSocket(
|
||||
go func() {
|
||||
defer waitGroup.Done()
|
||||
defer session.Close()
|
||||
s.proxyWebSocketUpstreamToClient(ctx, upstreamReadWriter, clientConn, selectedCredential, userConfig, provider, modelChannel, username, weeklyCycleHint)
|
||||
s.proxyWebSocketUpstreamToClient(ctx, upstreamReadWriter, clientConn, &clientWriteAccess, selectedCredential, userConfig, provider, modelChannel, username, weeklyCycleHint)
|
||||
}()
|
||||
go func() {
|
||||
defer waitGroup.Done()
|
||||
defer session.Close()
|
||||
s.pushWebSocketAggregatedStatus(ctx, clientConn, &clientWriteAccess, provider, userConfig, rateLimitIdentifier)
|
||||
}()
|
||||
waitGroup.Wait()
|
||||
}
|
||||
@@ -363,7 +372,7 @@ func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamReadWriter io.ReadWriter, clientConn net.Conn, selectedCredential Credential, userConfig *option.OCMUser, provider credentialProvider, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) {
|
||||
func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamReadWriter io.ReadWriter, clientConn net.Conn, clientWriteAccess *sync.Mutex, selectedCredential Credential, userConfig *option.OCMUser, provider credentialProvider, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) {
|
||||
usageTracker := selectedCredential.usageTrackerOrNil()
|
||||
var requestModel string
|
||||
for {
|
||||
@@ -384,11 +393,9 @@ func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamRe
|
||||
switch event.Type {
|
||||
case "codex.rate_limits":
|
||||
s.handleWebSocketRateLimitsEvent(data, selectedCredential)
|
||||
if userConfig != nil && userConfig.ExternalCredential != "" {
|
||||
rewritten, rewriteErr := s.rewriteWebSocketRateLimitsForExternalUser(data, provider, userConfig)
|
||||
if rewriteErr == nil {
|
||||
data = rewritten
|
||||
}
|
||||
rewritten, rewriteErr := s.rewriteWebSocketRateLimits(data, provider, userConfig)
|
||||
if rewriteErr == nil {
|
||||
data = rewritten
|
||||
}
|
||||
case "error":
|
||||
if event.StatusCode == http.StatusTooManyRequests {
|
||||
@@ -407,7 +414,9 @@ func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamRe
|
||||
}
|
||||
}
|
||||
|
||||
clientWriteAccess.Lock()
|
||||
err = wsutil.WriteServerMessage(clientConn, opCode, data)
|
||||
clientWriteAccess.Unlock()
|
||||
if err != nil {
|
||||
if !E.IsClosedOrCanceled(err) {
|
||||
s.logger.DebugContext(ctx, "write client websocket: ", err)
|
||||
@@ -483,7 +492,7 @@ func (s *Service) handleWebSocketErrorRateLimited(data []byte, selectedCredentia
|
||||
selectedCredential.markRateLimited(resetAt)
|
||||
}
|
||||
|
||||
func (s *Service) rewriteWebSocketRateLimitsForExternalUser(data []byte, provider credentialProvider, userConfig *option.OCMUser) ([]byte, error) {
|
||||
func (s *Service) rewriteWebSocketRateLimits(data []byte, provider credentialProvider, userConfig *option.OCMUser) ([]byte, error) {
|
||||
var event map[string]json.RawMessage
|
||||
err := json.Unmarshal(data, &event)
|
||||
if err != nil {
|
||||
@@ -501,13 +510,13 @@ func (s *Service) rewriteWebSocketRateLimitsForExternalUser(data []byte, provide
|
||||
return nil, err
|
||||
}
|
||||
|
||||
averageFiveHour, averageWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
|
||||
status := s.computeAggregatedUtilization(provider, userConfig)
|
||||
|
||||
if totalWeight > 0 {
|
||||
event["plan_weight"], _ = json.Marshal(totalWeight)
|
||||
if status.totalWeight > 0 {
|
||||
event["plan_weight"], _ = json.Marshal(status.totalWeight)
|
||||
}
|
||||
|
||||
primaryData, err := rewriteWebSocketRateLimitWindow(rateLimits["primary"], averageFiveHour)
|
||||
primaryData, err := rewriteWebSocketRateLimitWindow(rateLimits["primary"], status.fiveHourUtilization, resetToEpoch(status.fiveHourReset))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -515,7 +524,7 @@ func (s *Service) rewriteWebSocketRateLimitsForExternalUser(data []byte, provide
|
||||
rateLimits["primary"] = primaryData
|
||||
}
|
||||
|
||||
secondaryData, err := rewriteWebSocketRateLimitWindow(rateLimits["secondary"], averageWeekly)
|
||||
secondaryData, err := rewriteWebSocketRateLimitWindow(rateLimits["secondary"], status.weeklyUtilization, resetToEpoch(status.weeklyReset))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -531,7 +540,7 @@ func (s *Service) rewriteWebSocketRateLimitsForExternalUser(data []byte, provide
|
||||
return json.Marshal(event)
|
||||
}
|
||||
|
||||
func rewriteWebSocketRateLimitWindow(data json.RawMessage, usedPercent float64) (json.RawMessage, error) {
|
||||
func rewriteWebSocketRateLimitWindow(data json.RawMessage, usedPercent float64, resetAt int64) (json.RawMessage, error) {
|
||||
if len(data) == 0 || string(data) == "null" {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -547,9 +556,93 @@ func rewriteWebSocketRateLimitWindow(data json.RawMessage, usedPercent float64)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resetAt > 0 {
|
||||
window["reset_at"], err = json.Marshal(resetAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return json.Marshal(window)
|
||||
}
|
||||
|
||||
func (s *Service) pushWebSocketAggregatedStatus(ctx context.Context, clientConn net.Conn, clientWriteAccess *sync.Mutex, provider credentialProvider, userConfig *option.OCMUser, rateLimitIdentifier string) {
|
||||
subscription, done, err := s.statusObserver.Subscribe()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer s.statusObserver.UnSubscribe(subscription)
|
||||
|
||||
last := s.computeAggregatedUtilization(provider, userConfig)
|
||||
data := buildSyntheticRateLimitsEvent(rateLimitIdentifier, last)
|
||||
clientWriteAccess.Lock()
|
||||
err = wsutil.WriteServerMessage(clientConn, ws.OpText, data)
|
||||
clientWriteAccess.Unlock()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-done:
|
||||
return
|
||||
case <-subscription:
|
||||
for {
|
||||
select {
|
||||
case <-subscription:
|
||||
default:
|
||||
goto drained
|
||||
}
|
||||
}
|
||||
drained:
|
||||
current := s.computeAggregatedUtilization(provider, userConfig)
|
||||
if current.equal(last) {
|
||||
continue
|
||||
}
|
||||
last = current
|
||||
data = buildSyntheticRateLimitsEvent(rateLimitIdentifier, current)
|
||||
clientWriteAccess.Lock()
|
||||
err = wsutil.WriteServerMessage(clientConn, ws.OpText, data)
|
||||
clientWriteAccess.Unlock()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func buildSyntheticRateLimitsEvent(identifier string, status aggregatedStatus) []byte {
|
||||
type rateLimitWindow struct {
|
||||
UsedPercent float64 `json:"used_percent"`
|
||||
ResetAt int64 `json:"reset_at,omitempty"`
|
||||
}
|
||||
event := struct {
|
||||
Type string `json:"type"`
|
||||
RateLimits struct {
|
||||
Primary *rateLimitWindow `json:"primary,omitempty"`
|
||||
Secondary *rateLimitWindow `json:"secondary,omitempty"`
|
||||
} `json:"rate_limits"`
|
||||
LimitName string `json:"limit_name"`
|
||||
PlanWeight float64 `json:"plan_weight,omitempty"`
|
||||
}{
|
||||
Type: "codex.rate_limits",
|
||||
LimitName: identifier,
|
||||
PlanWeight: status.totalWeight,
|
||||
}
|
||||
event.RateLimits.Primary = &rateLimitWindow{
|
||||
UsedPercent: status.fiveHourUtilization,
|
||||
ResetAt: resetToEpoch(status.fiveHourReset),
|
||||
}
|
||||
event.RateLimits.Secondary = &rateLimitWindow{
|
||||
UsedPercent: status.weeklyUtilization,
|
||||
ResetAt: resetToEpoch(status.weeklyReset),
|
||||
}
|
||||
data, _ := json.Marshal(event)
|
||||
return data
|
||||
}
|
||||
|
||||
func (s *Service) handleWebSocketResponseCompleted(data []byte, usageTracker *AggregatedUsage, requestModel string, username string, weeklyCycleHint *WeeklyCycleHint) {
|
||||
var streamEvent responses.ResponseStreamEventUnion
|
||||
if json.Unmarshal(data, &streamEvent) != nil {
|
||||
|
||||
Reference in New Issue
Block a user