mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-17 13:23:06 +10:00
752 lines
23 KiB
Go
752 lines
23 KiB
Go
package ocm
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
stdTLS "crypto/tls"
|
|
"encoding/json"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/textproto"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/sagernet/sing-box/adapter"
|
|
"github.com/sagernet/sing-box/option"
|
|
E "github.com/sagernet/sing/common/exceptions"
|
|
M "github.com/sagernet/sing/common/metadata"
|
|
"github.com/sagernet/sing/common/ntp"
|
|
"github.com/sagernet/ws"
|
|
"github.com/sagernet/ws/wsutil"
|
|
|
|
"github.com/openai/openai-go/v3/responses"
|
|
)
|
|
|
|
type webSocketSession struct {
|
|
clientConn net.Conn
|
|
upstreamConn net.Conn
|
|
credentialTag string
|
|
releaseProviderInterrupt func()
|
|
closeOnce sync.Once
|
|
closed chan struct{}
|
|
}
|
|
|
|
func (s *webSocketSession) Close() {
|
|
s.closeOnce.Do(func() {
|
|
close(s.closed)
|
|
if s.releaseProviderInterrupt != nil {
|
|
s.releaseProviderInterrupt()
|
|
}
|
|
if s.clientConn != nil {
|
|
s.clientConn.Close()
|
|
}
|
|
if s.upstreamConn != nil {
|
|
s.upstreamConn.Close()
|
|
}
|
|
})
|
|
}
|
|
|
|
type webSocketResponseCreateRequest struct {
|
|
Type string `json:"type"`
|
|
Model string `json:"model"`
|
|
ServiceTier string `json:"service_tier"`
|
|
Generate *bool `json:"generate"`
|
|
}
|
|
|
|
func parseWebSocketResponseCreateRequest(data []byte) (webSocketResponseCreateRequest, bool) {
|
|
var request webSocketResponseCreateRequest
|
|
if json.Unmarshal(data, &request) != nil {
|
|
return webSocketResponseCreateRequest{}, false
|
|
}
|
|
if request.Type != "response.create" || request.Model == "" {
|
|
return webSocketResponseCreateRequest{}, false
|
|
}
|
|
return request, true
|
|
}
|
|
|
|
func (r webSocketResponseCreateRequest) isWarmup() bool {
|
|
return r.Generate != nil && !*r.Generate
|
|
}
|
|
|
|
// signalWebSocketReady closes channel the first time it is called with a
// given once; later calls with the same once do nothing. This avoids a
// double-close panic.
func signalWebSocketReady(channel chan struct{}, once *sync.Once) {
	closeChannel := func() { close(channel) }
	once.Do(closeChannel)
}
|
|
|
|
func buildUpstreamWebSocketURL(baseURL string, proxyPath string) string {
|
|
upstreamURL := baseURL
|
|
if strings.HasPrefix(upstreamURL, "https://") {
|
|
upstreamURL = "wss://" + upstreamURL[len("https://"):]
|
|
} else if strings.HasPrefix(upstreamURL, "http://") {
|
|
upstreamURL = "ws://" + upstreamURL[len("http://"):]
|
|
}
|
|
return upstreamURL + proxyPath
|
|
}
|
|
|
|
// isForwardableResponseHeader reports whether an upstream handshake response
// header should be copied to the client's upgrade response. The match is
// case-insensitive. It accepts "openai-model", any "x-codex-" or
// "x-reasoning" prefixed header, and any header containing "-secondary-".
func isForwardableResponseHeader(key string) bool {
	name := strings.ToLower(key)
	if name == "openai-model" {
		return true
	}
	for _, prefix := range [...]string{"x-codex-", "x-reasoning"} {
		if strings.HasPrefix(name, prefix) {
			return true
		}
	}
	return strings.Contains(name, "-secondary-")
}
|
|
|
|
// isForwardableWebSocketRequestHeader reports whether a client request header
// may be copied onto the upstream websocket handshake. It rejects:
//   - headers flagged by isHopByHopHeader or isReverseProxyHeader;
//   - client credentials (Authorization, X-Api-Key, Api-Key), because the
//     proxy sets its own Authorization for the upstream;
//   - any Sec-WebSocket-* header, presumably because the upstream dialer
//     produces its own handshake values.
func isForwardableWebSocketRequestHeader(key string) bool {
	if isHopByHopHeader(key) || isReverseProxyHeader(key) {
		return false
	}
	name := strings.ToLower(key)
	carriesCredential := name == "authorization" || name == "x-api-key" || name == "api-key"
	isHandshakeHeader := strings.HasPrefix(name, "sec-websocket-")
	return !carriesCredential && !isHandshakeHeader
}
|
|
|
|
// handleWebSocket proxies a client websocket upgrade request to the upstream
// websocket endpoint of selectedCredential.
//
// Sequence:
//  1. Dial upstream first, before upgrading the client, so that dial failures
//     can still be reported as ordinary JSON HTTP errors.
//  2. If the upstream handshake is rejected with 429, provider.onRateLimited
//     supplies the next credential and the dial is retried. The loop ends when
//     no credential is returned.
//  3. Once upstream is connected, upgrade the client connection, register a
//     webSocketSession, and run three goroutines: client→upstream relay,
//     upstream→client relay, and the aggregated rate-limit status pusher.
//     Each goroutine closes the session when it exits, and the handler
//     returns after all three finish.
//
// For credentials that are neither in API-key mode nor external, the "/v1"
// prefix is stripped from path before it is appended to the credential's
// base URL. isNew is passed to the client→upstream relay, which logs the
// credential assignment only when isNew is true.
func (s *Service) handleWebSocket(
	ctx context.Context,
	w http.ResponseWriter,
	r *http.Request,
	path string,
	username string,
	sessionID string,
	userConfig *option.OCMUser,
	provider credentialProvider,
	selectedCredential Credential,
	selection credentialSelection,
	isNew bool,
) {
	var (
		err                     error
		requestContext          *credentialRequestContext
		clientConn              net.Conn
		session                 *webSocketSession
		upstreamConn            net.Conn
		upstreamBufferedReader  *bufio.Reader
		upstreamResponseHeaders http.Header
		statusCode              int
		statusResponseBody      string
	)
	// Cancel the request context that is live when the handler returns. It is
	// reset to nil after each failed dial attempt, so only the final one is
	// cancelled here.
	defer func() {
		if requestContext != nil {
			requestContext.cancelRequest()
		}
	}()

	// Each iteration is one upstream dial attempt using selectedCredential.
	for {
		accessToken, accessErr := selectedCredential.getAccessToken()
		if accessErr != nil {
			s.logger.ErrorContext(ctx, "get access token for websocket: ", accessErr)
			writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "authentication failed")
			return
		}

		var proxyPath string
		if selectedCredential.ocmIsAPIKeyMode() || selectedCredential.isExternal() {
			proxyPath = path
		} else {
			proxyPath = strings.TrimPrefix(path, "/v1")
		}

		upstreamURL := buildUpstreamWebSocketURL(selectedCredential.ocmGetBaseURL(), proxyPath)
		if r.URL.RawQuery != "" {
			upstreamURL += "?" + r.URL.RawQuery
		}

		// Start from the forwardable client headers, then let the service's
		// configured headers replace any client header of the same canonical
		// name. The proxy's own credentials are applied last.
		upstreamHeaders := make(http.Header)
		for key, values := range r.Header {
			if isForwardableWebSocketRequestHeader(key) {
				upstreamHeaders[key] = values
			}
		}
		for key, values := range s.httpHeaders {
			upstreamHeaders.Del(key)
			upstreamHeaders[key] = values
		}
		upstreamHeaders.Set("Authorization", "Bearer "+accessToken)
		if accountID := selectedCredential.ocmGetAccountID(); accountID != "" {
			upstreamHeaders.Set("ChatGPT-Account-Id", accountID)
		}
		if upstreamHeaders.Get("OpenAI-Beta") == "" {
			upstreamHeaders.Set("OpenAI-Beta", "responses_websockets=2026-02-06")
		}

		// Reset per-attempt handshake results. They are filled in by the
		// dialer callbacks below.
		upstreamResponseHeaders = make(http.Header)
		statusCode = 0
		statusResponseBody = ""
		upstreamDialer := ws.Dialer{
			NetDial: func(ctx context.Context, network, addr string) (net.Conn, error) {
				return selectedCredential.ocmDialer().DialContext(ctx, network, M.ParseSocksaddr(addr))
			},
			TLSConfig: &stdTLS.Config{
				RootCAs: adapter.RootPoolFromContext(s.ctx),
				Time:    ntp.TimeFuncFromContext(s.ctx),
			},
			Header: ws.HandshakeHeaderHTTP(upstreamHeaders),
			// gobwas/ws@v1.4.0: the response io.Reader is
			// MultiReader(statusLine_without_CRLF, "\r\n", bufferedConn).
			// ReadString('\n') consumes the status line, then ReadMIMEHeader
			// parses the remaining headers.
			OnStatusError: func(status int, reason []byte, response io.Reader) {
				statusCode = status
				bufferedResponse := bufio.NewReader(response)
				_, readErr := bufferedResponse.ReadString('\n')
				if readErr != nil {
					return
				}
				mimeHeader, readErr := textproto.NewReader(bufferedResponse).ReadMIMEHeader()
				if readErr == nil {
					upstreamResponseHeaders = http.Header(mimeHeader)
				}
				// Keep at most 4 KiB of the error body, for logging only.
				body, readErr := io.ReadAll(io.LimitReader(bufferedResponse, 4096))
				if readErr == nil && len(body) > 0 {
					statusResponseBody = string(body)
				}
			},
			// Collects the headers of a successful handshake response.
			OnHeader: func(key, value []byte) error {
				upstreamResponseHeaders.Add(string(key), string(value))
				return nil
			},
		}

		requestContext = selectedCredential.wrapRequestContext(ctx)
		{
			// If the provider interrupts this credential, cancel this attempt's
			// context and close whatever exists so far: the whole session once
			// registered, otherwise the individual connections.
			// NOTE(review): this callback reads session/clientConn/upstreamConn,
			// which this goroutine assigns without synchronization. If
			// linkProviderInterrupt can invoke it from another goroutine, that
			// is a data race — confirm.
			currentRequestContext := requestContext
			requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() {
				currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
				if session != nil {
					session.Close()
					return
				}
				if clientConn != nil {
					clientConn.Close()
				}
				if upstreamConn != nil {
					upstreamConn.Close()
				}
			}))
		}
		upstreamConn, upstreamBufferedReader, _, err = upstreamDialer.Dial(requestContext, upstreamURL)
		if err == nil {
			break
		}
		// Dial failed: drop this attempt's context and connections before
		// deciding whether to retry.
		requestContext.cancelRequest()
		requestContext = nil
		upstreamConn = nil
		clientConn = nil
		if statusCode == http.StatusTooManyRequests {
			// onRateLimited returns the next credential to try, or nil when
			// none is available.
			resetAt := parseOCMRateLimitResetFromHeaders(upstreamResponseHeaders)
			nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection)
			selectedCredential.updateStateFromHeaders(upstreamResponseHeaders)
			if nextCredential == nil {
				writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited")
				return
			}
			s.logger.InfoContext(ctx, "retrying websocket with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
			selectedCredential = nextCredential
			continue
		}
		if statusCode == http.StatusBadRequest && selectedCredential.isExternal() {
			selectedCredential.markUpstreamRejected()
			s.logger.ErrorContext(ctx, "upstream rejected websocket from ", selectedCredential.tagName(), ": status ", statusCode)
			writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "upstream rejected credential")
			return
		}
		if statusCode > 0 && statusResponseBody != "" {
			s.logger.ErrorContext(ctx, "dial upstream websocket: status ", statusCode, " body: ", statusResponseBody)
		} else {
			s.logger.ErrorContext(ctx, "dial upstream websocket: ", err)
		}
		writeJSONError(w, r, http.StatusBadGateway, "api_error", "upstream websocket connection failed")
		return
	}

	// Upstream is connected: record its rate-limit headers and pass only the
	// allow-listed headers through to the client's upgrade response.
	selectedCredential.updateStateFromHeaders(upstreamResponseHeaders)
	weeklyCycleHint := extractWeeklyCycleHint(upstreamResponseHeaders)

	clientResponseHeaders := make(http.Header)
	for key, values := range upstreamResponseHeaders {
		if isForwardableResponseHeader(key) {
			clientResponseHeaders[key] = append([]string(nil), values...)
		}
	}
	s.rewriteResponseHeaders(clientResponseHeaders, provider, userConfig)

	clientUpgrader := ws.HTTPUpgrader{
		Header: clientResponseHeaders,
	}
	if s.isShuttingDown() {
		upstreamConn.Close()
		writeJSONError(w, r, http.StatusServiceUnavailable, "api_error", "service is shutting down")
		return
	}
	// NOTE(review): the *bufio.ReadWriter returned by Upgrade is discarded.
	// Go's http.Hijacker docs say the hijacked reader may hold unprocessed
	// client bytes, so frames a client sent before receiving 101 would be lost.
	// The upstream side keeps its buffered reader (upstreamBufferedReader) —
	// confirm whether the client side needs the same treatment.
	clientConn, _, _, err = clientUpgrader.Upgrade(r, w)
	if err != nil {
		s.logger.ErrorContext(ctx, "upgrade client websocket: ", err)
		upstreamConn.Close()
		return
	}
	session = &webSocketSession{
		clientConn:               clientConn,
		upstreamConn:             upstreamConn,
		credentialTag:            selectedCredential.tagName(),
		releaseProviderInterrupt: requestContext.releaseCredentialInterrupt,
		closed:                   make(chan struct{}),
	}
	if !s.registerWebSocketSession(session) {
		session.Close()
		return
	}
	defer s.unregisterWebSocketSession(session)

	// The dialer may return a buffered reader that already holds bytes read
	// past the handshake. Read from it first so those frames are not lost;
	// writes go straight to the connection.
	var upstreamReadWriter io.ReadWriter
	if upstreamBufferedReader != nil {
		upstreamReadWriter = struct {
			io.Reader
			io.Writer
		}{upstreamBufferedReader, upstreamConn}
	} else {
		upstreamReadWriter = upstreamConn
	}

	// clientWriteAccess serializes writes to clientConn between the
	// upstream→client relay and the status pusher. modelChannel passes the
	// requested model from the client relay to the upstream relay for usage
	// accounting. firstRealRequest is closed on the first non-warmup request.
	var clientWriteAccess sync.Mutex
	modelChannel := make(chan string, 1)
	firstRealRequest := make(chan struct{})
	var firstRealRequestOnce sync.Once
	var waitGroup sync.WaitGroup

	waitGroup.Add(3)
	go func() {
		defer waitGroup.Done()
		defer session.Close()
		s.proxyWebSocketClientToUpstream(ctx, clientConn, upstreamConn, selectedCredential, modelChannel, firstRealRequest, &firstRealRequestOnce, isNew, username, sessionID)
	}()
	go func() {
		defer waitGroup.Done()
		defer session.Close()
		s.proxyWebSocketUpstreamToClient(ctx, upstreamReadWriter, clientConn, &clientWriteAccess, selectedCredential, modelChannel, username, weeklyCycleHint)
	}()
	go func() {
		defer waitGroup.Done()
		defer session.Close()
		s.pushWebSocketAggregatedStatus(ctx, clientConn, &clientWriteAccess, session.closed, firstRealRequest, provider, userConfig)
	}()
	waitGroup.Wait()
}
|
|
|
|
// proxyWebSocketClientToUpstream relays client websocket messages to the
// upstream connection until a read or write fails.
//
// Text messages that parse as a non-warmup "response.create" request get
// extra handling:
//   - When isNew is true, the credential assignment is logged, at most once
//     per call.
//   - If the credential tracks usage, the model is offered to modelChannel
//     without blocking.
//   - After the message is forwarded successfully, firstRealRequest is closed
//     (at most once, via firstRealRequestOnce).
//
// NOTE(review): wsutil.ReadClientData answers control frames by writing to
// clientConn without holding the client write mutex used by the other relay.
// Confirm that interleaved writes cannot corrupt frames.
func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn net.Conn, upstreamConn net.Conn, selectedCredential Credential, modelChannel chan<- string, firstRealRequest chan struct{}, firstRealRequestOnce *sync.Once, isNew bool, username string, sessionID string) {
	assignmentLogged := false
	logAssignment := func(request webSocketResponseCreateRequest) {
		message := []any{"assigned credential ", selectedCredential.tagName()}
		if sessionID != "" {
			message = append(message, " for session ", sessionID)
		}
		if username != "" {
			message = append(message, " by user ", username)
		}
		message = append(message, ", model=", request.Model)
		if request.ServiceTier == "priority" {
			message = append(message, ", fast")
		}
		s.logger.DebugContext(ctx, message...)
	}

	for {
		payload, opCode, readErr := wsutil.ReadClientData(clientConn)
		if readErr != nil {
			if !E.IsClosedOrCanceled(readErr) {
				s.logger.DebugContext(ctx, "read client websocket: ", readErr)
			}
			return
		}

		realRequest := false
		if opCode == ws.OpText {
			request, ok := parseWebSocketResponseCreateRequest(payload)
			if ok && !request.isWarmup() {
				realRequest = true
				if isNew && !assignmentLogged {
					assignmentLogged = true
					logAssignment(request)
				}
				if selectedCredential.usageTrackerOrNil() != nil {
					// Best effort: the channel holds one model, so if it is
					// full this model is dropped rather than blocking.
					select {
					case modelChannel <- request.Model:
					default:
					}
				}
			}
		}

		if writeErr := wsutil.WriteClientMessage(upstreamConn, opCode, payload); writeErr != nil {
			if !E.IsClosedOrCanceled(writeErr) {
				s.logger.DebugContext(ctx, "write upstream websocket: ", writeErr)
			}
			return
		}
		if realRequest {
			signalWebSocketReady(firstRealRequest, firstRealRequestOnce)
		}
	}
}
|
|
|
|
// proxyWebSocketUpstreamToClient relays upstream websocket messages to the
// client until a read or write fails. Client writes are serialized through
// clientWriteAccess.
//
// Text messages are inspected by their JSON "type":
//   - "codex.rate_limits" updates the credential's state and is NOT forwarded
//     to the client;
//   - "error" is passed to handleWebSocketErrorEvent and then forwarded;
//   - "response.completed" records usage (when the credential has a usage
//     tracker) and is then forwarded.
//
// For usage accounting the model comes from modelChannel when one is queued;
// otherwise the previously received model is reused.
//
// NOTE(review): wsutil.ReadServerData answers control frames by writing to
// the upstream connection concurrently with the client→upstream relay.
// Confirm that interleaved writes are safe.
func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamReadWriter io.ReadWriter, clientConn net.Conn, clientWriteAccess *sync.Mutex, selectedCredential Credential, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) {
	usageTracker := selectedCredential.usageTrackerOrNil()
	var lastModel string
	forwardToClient := func(opCode ws.OpCode, payload []byte) error {
		clientWriteAccess.Lock()
		defer clientWriteAccess.Unlock()
		return wsutil.WriteServerMessage(clientConn, opCode, payload)
	}

	for {
		payload, opCode, readErr := wsutil.ReadServerData(upstreamReadWriter)
		if readErr != nil {
			if !E.IsClosedOrCanceled(readErr) {
				s.logger.DebugContext(ctx, "read upstream websocket: ", readErr)
			}
			return
		}

		if opCode == ws.OpText {
			// StatusCode is not read, but it stays in the struct because a
			// non-numeric status_code makes Unmarshal fail, which skips
			// inspection of that event.
			var envelope struct {
				Type       string `json:"type"`
				StatusCode int    `json:"status_code"`
			}
			if json.Unmarshal(payload, &envelope) == nil {
				if envelope.Type == "codex.rate_limits" {
					s.handleWebSocketRateLimitsEvent(payload, selectedCredential)
					continue
				}
				if envelope.Type == "error" {
					s.handleWebSocketErrorEvent(payload, selectedCredential)
				} else if envelope.Type == "response.completed" && usageTracker != nil {
					select {
					case model := <-modelChannel:
						lastModel = model
					default:
					}
					s.handleWebSocketResponseCompleted(payload, usageTracker, lastModel, username, weeklyCycleHint)
				}
			}
		}

		if writeErr := forwardToClient(opCode, payload); writeErr != nil {
			if !E.IsClosedOrCanceled(writeErr) {
				s.logger.DebugContext(ctx, "write client websocket: ", writeErr)
			}
			return
		}
	}
}
|
|
|
|
// handleWebSocketRateLimitsEvent converts an upstream "codex.rate_limits"
// websocket event into rate-limit headers and applies them with
// selectedCredential.updateStateFromHeaders. This is the same method used for
// HTTP handshake response headers.
//
// The limit ID is taken from metered_limit_name, then limit_name, and
// defaults to "codex". Credits are applied only when that ID normalizes to
// "codex". Events that fail to decode are ignored.
func (s *Service) handleWebSocketRateLimitsEvent(data []byte, selectedCredential Credential) {
	type rateWindow struct {
		UsedPercent   float64 `json:"used_percent"`
		WindowMinutes int64   `json:"window_minutes"`
		ResetAt       int64   `json:"reset_at"`
	}
	var payload struct {
		MeteredLimitName string `json:"metered_limit_name"`
		LimitName        string `json:"limit_name"`
		RateLimits       struct {
			Primary   *rateWindow `json:"primary"`
			Secondary *rateWindow `json:"secondary"`
		} `json:"rate_limits"`
		Credits    *creditsSnapshot `json:"credits"`
		PlanWeight float64          `json:"plan_weight"`
	}
	if json.Unmarshal(data, &payload) != nil {
		return
	}

	limitID := payload.MeteredLimitName
	if limitID == "" {
		limitID = payload.LimitName
	}
	if limitID == "" {
		limitID = "codex"
	}
	headerLimit := headerLimitID(limitID)

	headers := make(http.Header)
	headers.Set("x-codex-active-limit", headerLimit)

	// setWindow emits x-<limit>-<slot>-* headers. Window length and reset time
	// are emitted only when positive.
	setWindow := func(slot string, window *rateWindow) {
		if window == nil {
			return
		}
		prefix := "x-" + headerLimit + "-" + slot + "-"
		headers.Set(prefix+"used-percent", strconv.FormatFloat(window.UsedPercent, 'f', -1, 64))
		if window.WindowMinutes > 0 {
			headers.Set(prefix+"window-minutes", strconv.FormatInt(window.WindowMinutes, 10))
		}
		if window.ResetAt > 0 {
			headers.Set(prefix+"reset-at", strconv.FormatInt(window.ResetAt, 10))
		}
	}
	setWindow("primary", payload.RateLimits.Primary)
	setWindow("secondary", payload.RateLimits.Secondary)

	if payload.LimitName != "" {
		headers.Set("x-"+headerLimit+"-limit-name", payload.LimitName)
	}
	if credits := payload.Credits; credits != nil && normalizeStoredLimitID(limitID) == "codex" {
		headers.Set("x-codex-credits-has-credits", strconv.FormatBool(credits.HasCredits))
		headers.Set("x-codex-credits-unlimited", strconv.FormatBool(credits.Unlimited))
		if credits.Balance != "" {
			headers.Set("x-codex-credits-balance", credits.Balance)
		}
	}
	if payload.PlanWeight > 0 {
		headers.Set("X-OCM-Plan-Weight", strconv.FormatFloat(payload.PlanWeight, 'f', -1, 64))
	}
	selectedCredential.updateStateFromHeaders(headers)
}
|
|
|
|
// handleWebSocketErrorEvent updates selectedCredential from an upstream
// "error" websocket event:
//   - a 400 with code "websocket_connection_limit_reached" blocks the
//     credential for one minute;
//   - a 429 applies the event's "headers" map as rate-limit headers, then
//     marks the credential rate-limited until the reset time parsed from them;
//   - anything else, including undecodable payloads, is ignored.
func (s *Service) handleWebSocketErrorEvent(data []byte, selectedCredential Credential) {
	var payload struct {
		StatusCode int               `json:"status_code"`
		Headers    map[string]string `json:"headers"`
		Error      struct {
			Code string `json:"code"`
		} `json:"error"`
	}
	if json.Unmarshal(data, &payload) != nil {
		return
	}

	switch {
	case payload.StatusCode == http.StatusBadRequest && payload.Error.Code == "websocket_connection_limit_reached":
		selectedCredential.markTemporarilyBlocked(availabilityReasonConnectionLimit, time.Now().Add(time.Minute))
	case payload.StatusCode == http.StatusTooManyRequests:
		headers := make(http.Header, len(payload.Headers))
		for name, value := range payload.Headers {
			headers.Set(name, value)
		}
		selectedCredential.updateStateFromHeaders(headers)
		selectedCredential.markRateLimited(parseOCMRateLimitResetFromHeaders(headers))
	}
}
|
|
|
|
// writeWebSocketAggregatedStatus sends the synthetic rate-limit events for
// status to the client as text frames, in order. It holds clientWriteAccess
// while writing so the frames don't interleave with other client writes, and
// stops at the first write error and returns it.
func writeWebSocketAggregatedStatus(clientConn net.Conn, clientWriteAccess *sync.Mutex, status aggregatedStatus) error {
	// Building the payloads only reads status, so it can run before locking.
	events := buildSyntheticRateLimitsEvents(status)
	clientWriteAccess.Lock()
	defer clientWriteAccess.Unlock()
	for i := range events {
		if err := wsutil.WriteServerMessage(clientConn, ws.OpText, events[i]); err != nil {
			return err
		}
	}
	return nil
}
|
|
|
|
// pushWebSocketAggregatedStatus sends aggregated rate-limit status to the
// client over the websocket.
//
// Nothing is sent until firstRealRequest is closed. At that point the current
// aggregate is sent once. After that, each burst of status-observer
// notifications is drained and triggers a recompute, which is sent only if it
// differs from the last status sent.
//
// It returns when:
//   - ctx is done;
//   - the observer's done channel fires or subscribing fails;
//   - sessionClosed is closed;
//   - a write fails.
func (s *Service) pushWebSocketAggregatedStatus(ctx context.Context, clientConn net.Conn, clientWriteAccess *sync.Mutex, sessionClosed <-chan struct{}, firstRealRequest <-chan struct{}, provider credentialProvider, userConfig *option.OCMUser) {
	subscription, done, err := s.statusObserver.Subscribe()
	if err != nil {
		return
	}
	defer s.statusObserver.UnSubscribe(subscription)

	var (
		lastSent aggregatedStatus
		sentOnce bool
	)
	// pending is set to nil after it fires, so select never picks it again.
	pending := firstRealRequest
	for {
		select {
		case <-ctx.Done():
			return
		case <-done:
			return
		case <-sessionClosed:
			return
		case <-pending:
			pending = nil
			current := s.computeAggregatedUtilization(provider, userConfig)
			if writeWebSocketAggregatedStatus(clientConn, clientWriteAccess, current) != nil {
				return
			}
			lastSent, sentOnce = current, true
		case <-subscription:
			// Collapse any queued notifications into a single recompute.
		drainLoop:
			for {
				select {
				case <-subscription:
				default:
					break drainLoop
				}
			}
			if !sentOnce {
				continue
			}
			current := s.computeAggregatedUtilization(provider, userConfig)
			if current.equal(lastSent) {
				continue
			}
			lastSent = current
			if writeWebSocketAggregatedStatus(clientConn, clientWriteAccess, current) != nil {
				return
			}
		}
	}
}
|
|
|
|
// buildSyntheticRateLimitsEvents renders status as serialized
// "codex.rate_limits" websocket events.
//
// The first event is always for the "codex" limit. It uses the codex
// snapshot's windows when present. If the snapshot or one of its windows is
// missing, it falls back to the aggregated five-hour (primary) and weekly
// (secondary) utilization. One event follows for every other snapshot in
// status.limits. Every event carries status.totalWeight as plan_weight.
//
// json.Marshal rejects NaN and ±Inf floats. An event that cannot be encoded
// is skipped instead of being returned as a nil payload, which would
// otherwise go out as an empty text frame.
func buildSyntheticRateLimitsEvents(status aggregatedStatus) [][]byte {
	type rateLimitWindow struct {
		UsedPercent   float64 `json:"used_percent"`
		WindowMinutes int64   `json:"window_minutes,omitempty"`
		ResetAt       int64   `json:"reset_at,omitempty"`
	}
	type creditsEvent struct {
		HasCredits bool   `json:"has_credits"`
		Unlimited  bool   `json:"unlimited"`
		Balance    string `json:"balance,omitempty"`
	}
	type eventPayload struct {
		Type       string `json:"type"`
		RateLimits struct {
			Primary   *rateLimitWindow `json:"primary,omitempty"`
			Secondary *rateLimitWindow `json:"secondary,omitempty"`
		} `json:"rate_limits"`
		MeteredLimitName string        `json:"metered_limit_name,omitempty"`
		LimitName        string        `json:"limit_name,omitempty"`
		Credits          *creditsEvent `json:"credits,omitempty"`
		PlanWeight       float64       `json:"plan_weight,omitempty"`
	}

	// buildEvent serializes one event. ok is false when encoding fails.
	buildEvent := func(snapshot rateLimitSnapshot, primary *rateLimitWindow, secondary *rateLimitWindow) (data []byte, ok bool) {
		event := eventPayload{
			Type:             "codex.rate_limits",
			MeteredLimitName: snapshot.LimitID,
			LimitName:        snapshot.LimitName,
			PlanWeight:       status.totalWeight,
		}
		if event.MeteredLimitName == "" {
			event.MeteredLimitName = "codex"
		}
		if event.LimitName == "" {
			event.LimitName = strings.ReplaceAll(event.MeteredLimitName, "_", "-")
		}
		event.RateLimits.Primary = primary
		event.RateLimits.Secondary = secondary
		if snapshot.Credits != nil {
			event.Credits = &creditsEvent{
				HasCredits: snapshot.Credits.HasCredits,
				Unlimited:  snapshot.Credits.Unlimited,
				Balance:    snapshot.Credits.Balance,
			}
		}
		encoded, err := json.Marshal(event)
		if err != nil {
			return nil, false
		}
		return encoded, true
	}

	defaultPrimary := &rateLimitWindow{
		UsedPercent: status.fiveHourUtilization,
		ResetAt:     resetToEpoch(status.fiveHourReset),
	}
	defaultSecondary := &rateLimitWindow{
		UsedPercent: status.weeklyUtilization,
		ResetAt:     resetToEpoch(status.weeklyReset),
	}

	events := make([][]byte, 0, 1+len(status.limits))
	appendEvent := func(snapshot rateLimitSnapshot, primary *rateLimitWindow, secondary *rateLimitWindow) {
		if data, ok := buildEvent(snapshot, primary, secondary); ok {
			events = append(events, data)
		}
	}

	if snapshot := findSnapshotByLimitID(status.limits, "codex"); snapshot != nil {
		primary := defaultPrimary
		if snapshot.Primary != nil {
			primary = &rateLimitWindow{
				UsedPercent:   snapshot.Primary.UsedPercent,
				WindowMinutes: snapshot.Primary.WindowMinutes,
				ResetAt:       snapshot.Primary.ResetAt,
			}
		}
		secondary := defaultSecondary
		if snapshot.Secondary != nil {
			secondary = &rateLimitWindow{
				UsedPercent:   snapshot.Secondary.UsedPercent,
				WindowMinutes: snapshot.Secondary.WindowMinutes,
				ResetAt:       snapshot.Secondary.ResetAt,
			}
		}
		appendEvent(*snapshot, primary, secondary)
	} else {
		appendEvent(rateLimitSnapshot{LimitID: "codex", LimitName: "codex"}, defaultPrimary, defaultSecondary)
	}

	for _, snapshot := range status.limits {
		if snapshot.LimitID == "codex" {
			continue
		}
		var primary *rateLimitWindow
		if snapshot.Primary != nil {
			primary = &rateLimitWindow{
				UsedPercent:   snapshot.Primary.UsedPercent,
				WindowMinutes: snapshot.Primary.WindowMinutes,
				ResetAt:       snapshot.Primary.ResetAt,
			}
		}
		var secondary *rateLimitWindow
		if snapshot.Secondary != nil {
			secondary = &rateLimitWindow{
				UsedPercent:   snapshot.Secondary.UsedPercent,
				WindowMinutes: snapshot.Secondary.WindowMinutes,
				ResetAt:       snapshot.Secondary.ResetAt,
			}
		}
		appendEvent(snapshot, primary, secondary)
	}
	return events
}
|
|
|
|
// handleWebSocketResponseCompleted records token usage from a
// "response.completed" event in usageTracker.
//
// The model reported by the response wins. If it is empty, requestModel is
// used. Nothing is recorded when:
//   - the payload does not decode;
//   - both input and output token counts are zero or less;
//   - no model name is available.
func (s *Service) handleWebSocketResponseCompleted(data []byte, usageTracker *AggregatedUsage, requestModel string, username string, weeklyCycleHint *WeeklyCycleHint) {
	var streamEvent responses.ResponseStreamEventUnion
	if err := json.Unmarshal(data, &streamEvent); err != nil {
		return
	}
	response := streamEvent.AsResponseCompleted().Response
	usage := response.Usage
	if usage.InputTokens <= 0 && usage.OutputTokens <= 0 {
		return
	}

	model := string(response.Model)
	if model == "" {
		model = requestModel
	}
	if model == "" {
		return
	}

	serviceTier := string(response.ServiceTier)
	contextWindow := detectContextWindow(model, serviceTier, usage.InputTokens)
	usageTracker.AddUsageWithCycleHint(
		model,
		contextWindow,
		usage.InputTokens,
		usage.OutputTokens,
		usage.InputTokensDetails.CachedTokens,
		serviceTier,
		username,
		time.Now(),
		weeklyCycleHint,
	)
}
|