Files
sing-box/service/ocm/service_websocket.go
世界 99d9e06dd0 fix(ccm,ocm): handle upstream 400 by marking external credentials rejected and polling default credentials
External credentials returning 400 are marked unavailable for pollInterval
duration; status stream/poll success clears the rejection early. Default
credentials trigger a stale poll to let the usage API detect account issues
without causing 429 storms.
2026-03-21 10:31:17 +08:00

638 lines
19 KiB
Go

package ocm
import (
"bufio"
"context"
stdTLS "crypto/tls"
"encoding/json"
"io"
"net"
"net/http"
"net/textproto"
"strconv"
"strings"
"sync"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/ntp"
"github.com/sagernet/ws"
"github.com/sagernet/ws/wsutil"
"github.com/openai/openai-go/v3/responses"
)
// webSocketSession groups the two connections of one proxied WebSocket so
// they can be torn down together.
type webSocketSession struct {
	// clientConn is the connection to the downstream client.
	clientConn net.Conn
	// upstreamConn is the connection to the upstream server.
	upstreamConn net.Conn
	// credentialTag is the tag name of the credential serving this session.
	credentialTag string
	// releaseProviderInterrupt, when non-nil, is invoked once during Close.
	releaseProviderInterrupt func()
	// closeOnce makes Close idempotent.
	closeOnce sync.Once
	// closed is closed by Close; receivers use it as the shutdown signal.
	closed chan struct{}
}

// Close shuts the session down. The first call closes the closed channel,
// runs releaseProviderInterrupt if set, and closes the client and upstream
// connections (in that order); later calls do nothing.
func (s *webSocketSession) Close() {
	teardown := func() {
		close(s.closed)
		if release := s.releaseProviderInterrupt; release != nil {
			release()
		}
		for _, conn := range [...]net.Conn{s.clientConn, s.upstreamConn} {
			if conn != nil {
				conn.Close()
			}
		}
	}
	s.closeOnce.Do(teardown)
}
// webSocketResponseCreateRequest holds the fields of a client "response.create"
// message that the proxy inspects.
type webSocketResponseCreateRequest struct {
	Type        string `json:"type"`
	Model       string `json:"model"`
	ServiceTier string `json:"service_tier"`
	// Generate is a pointer so that an absent field can be told apart from
	// an explicit false.
	Generate *bool `json:"generate"`
}

// parseWebSocketResponseCreateRequest decodes data and reports whether it is a
// "response.create" message carrying a non-empty model. On any failure the
// zero value is returned together with false.
func parseWebSocketResponseCreateRequest(data []byte) (webSocketResponseCreateRequest, bool) {
	var parsed webSocketResponseCreateRequest
	decodeErr := json.Unmarshal(data, &parsed)
	accepted := decodeErr == nil && parsed.Type == "response.create" && parsed.Model != ""
	if !accepted {
		return webSocketResponseCreateRequest{}, false
	}
	return parsed, true
}

// isWarmup reports whether the request explicitly sets generate to false.
func (r webSocketResponseCreateRequest) isWarmup() bool {
	if r.Generate == nil {
		return false
	}
	return !*r.Generate
}
// signalWebSocketReady closes channel the first time it is called for a given
// once; subsequent calls with the same once are no-ops.
func signalWebSocketReady(channel chan struct{}, once *sync.Once) {
	closeChannel := func() { close(channel) }
	once.Do(closeChannel)
}
// buildUpstreamWebSocketURL turns an HTTP(S) base URL into the matching
// WebSocket URL and appends proxyPath.
//
// "https" becomes "wss" and "http" becomes "ws". The scheme is matched
// case-insensitively because URI schemes are case-insensitive, so a base URL
// such as "HTTPS://host" is converted as well. Any other base URL (including
// one that already uses ws/wss or has no scheme) is returned unchanged with
// proxyPath appended.
func buildUpstreamWebSocketURL(baseURL string, proxyPath string) string {
	if scheme, rest, found := strings.Cut(baseURL, "://"); found {
		switch strings.ToLower(scheme) {
		case "https":
			return "wss://" + rest + proxyPath
		case "http":
			return "ws://" + rest + proxyPath
		}
	}
	return baseURL + proxyPath
}
// isForwardableResponseHeader reports whether an upstream handshake response
// header is copied to the client. The comparison ignores case. Accepted are
// names starting with "x-codex-" or "x-reasoning", the name "openai-model",
// and any name containing "-secondary-".
func isForwardableResponseHeader(key string) bool {
	name := strings.ToLower(key)
	if name == "openai-model" {
		return true
	}
	for _, prefix := range [...]string{"x-codex-", "x-reasoning"} {
		if strings.HasPrefix(name, prefix) {
			return true
		}
	}
	return strings.Contains(name, "-secondary-")
}
// isForwardableWebSocketRequestHeader reports whether a client request header
// is passed on to the upstream handshake. Hop-by-hop and reverse-proxy
// headers, the client's own authentication headers ("authorization",
// "x-api-key", "api-key") and every "sec-websocket-*" header are withheld;
// everything else is forwarded. Names are compared case-insensitively.
func isForwardableWebSocketRequestHeader(key string) bool {
	if isHopByHopHeader(key) || isReverseProxyHeader(key) {
		return false
	}
	switch name := strings.ToLower(key); name {
	case "authorization", "x-api-key", "api-key":
		return false
	default:
		return !strings.HasPrefix(name, "sec-websocket-")
	}
}
// handleWebSocket proxies one client WebSocket request to the upstream of the
// selected credential.
//
// The upstream connection is dialed first. A 429 handshake response makes the
// function ask provider.onRateLimited for another credential and retry; a 400
// from an external credential marks that credential as rejected; any other
// dial failure is answered with 502. The client connection is upgraded only
// after the upstream handshake succeeded, so every failure before that point
// can still be reported as a JSON error on the plain HTTP response. Once both
// sides are connected, three goroutines are started (client to upstream,
// upstream to client, and the aggregated status pusher) and the function
// blocks until all of them have returned.
func (s *Service) handleWebSocket(
ctx context.Context,
w http.ResponseWriter,
r *http.Request,
path string,
username string,
sessionID string,
userConfig *option.OCMUser,
provider credentialProvider,
selectedCredential Credential,
selection credentialSelection,
isNew bool,
) {
var (
err error
requestContext *credentialRequestContext
clientConn net.Conn
session *webSocketSession
upstreamConn net.Conn
upstreamBufferedReader *bufio.Reader
upstreamResponseHeaders http.Header
statusCode int
statusResponseBody string
)
// Cancel the request context of the last attempt (if any is still set) when
// the handler returns.
defer func() {
if requestContext != nil {
requestContext.cancelRequest()
}
}()
// Dial loop: each iteration tries one credential and breaks on success. The
// only path that iterates again is the 429 retry below.
for {
accessToken, accessErr := selectedCredential.getAccessToken()
if accessErr != nil {
s.logger.ErrorContext(ctx, "get access token for websocket: ", accessErr)
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "authentication failed")
return
}
// API-key and external credentials receive the path unchanged; for all
// other credentials the "/v1" prefix is stripped.
var proxyPath string
if selectedCredential.ocmIsAPIKeyMode() || selectedCredential.isExternal() {
proxyPath = path
} else {
proxyPath = strings.TrimPrefix(path, "/v1")
}
upstreamURL := buildUpstreamWebSocketURL(selectedCredential.ocmGetBaseURL(), proxyPath)
if r.URL.RawQuery != "" {
upstreamURL += "?" + r.URL.RawQuery
}
// Start from the forwardable client headers, overlay the service-level
// headers, then set the credential's own authentication headers.
upstreamHeaders := make(http.Header)
for key, values := range r.Header {
if isForwardableWebSocketRequestHeader(key) {
upstreamHeaders[key] = values
}
}
for key, values := range s.httpHeaders {
upstreamHeaders.Del(key)
upstreamHeaders[key] = values
}
upstreamHeaders.Set("Authorization", "Bearer "+accessToken)
if accountID := selectedCredential.ocmGetAccountID(); accountID != "" {
upstreamHeaders.Set("ChatGPT-Account-Id", accountID)
}
// Only supply a default OpenAI-Beta value when none was forwarded or configured.
if upstreamHeaders.Get("OpenAI-Beta") == "" {
upstreamHeaders.Set("OpenAI-Beta", "responses_websockets=2026-02-06")
}
// Reset the per-attempt handshake results that the dialer callbacks fill in.
upstreamResponseHeaders = make(http.Header)
statusCode = 0
statusResponseBody = ""
upstreamDialer := ws.Dialer{
NetDial: func(ctx context.Context, network, addr string) (net.Conn, error) {
return selectedCredential.ocmDialer().DialContext(ctx, network, M.ParseSocksaddr(addr))
},
TLSConfig: &stdTLS.Config{
RootCAs: adapter.RootPoolFromContext(s.ctx),
Time: ntp.TimeFuncFromContext(s.ctx),
},
Header: ws.HandshakeHeaderHTTP(upstreamHeaders),
// gobwas/ws@v1.4.0: the response io.Reader is
// MultiReader(statusLine_without_CRLF, "\r\n", bufferedConn).
// ReadString('\n') consumes the status line, then ReadMIMEHeader
// parses the remaining headers.
OnStatusError: func(status int, reason []byte, response io.Reader) {
statusCode = status
bufferedResponse := bufio.NewReader(response)
_, readErr := bufferedResponse.ReadString('\n')
if readErr != nil {
return
}
mimeHeader, readErr := textproto.NewReader(bufferedResponse).ReadMIMEHeader()
if readErr == nil {
upstreamResponseHeaders = http.Header(mimeHeader)
}
// Keep at most 4096 bytes of the error body for the log message.
// NOTE(review): this reads until EOF or the limit, not until
// Content-Length; if the upstream keeps the connection open after a
// short body the read may wait for the connection to end — confirm.
body, readErr := io.ReadAll(io.LimitReader(bufferedResponse, 4096))
if readErr == nil && len(body) > 0 {
statusResponseBody = string(body)
}
},
// Collect the handshake response headers reported by the dialer.
OnHeader: func(key, value []byte) error {
upstreamResponseHeaders.Add(string(key), string(value))
return nil
},
}
requestContext = selectedCredential.wrapRequestContext(ctx)
{
// Capture this attempt's context: requestContext itself is reassigned on retry.
currentRequestContext := requestContext
// The interrupt callback cancels this attempt and closes the session if
// it already exists, otherwise whichever connections exist at that point.
requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() {
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
if session != nil {
session.Close()
return
}
if clientConn != nil {
clientConn.Close()
}
if upstreamConn != nil {
upstreamConn.Close()
}
}))
}
upstreamConn, upstreamBufferedReader, _, err = upstreamDialer.Dial(requestContext, upstreamURL)
if err == nil {
break
}
// The attempt failed: drop its context and connection state before
// deciding how to answer.
requestContext.cancelRequest()
requestContext = nil
upstreamConn = nil
clientConn = nil
// 429: let the provider pick another credential and retry the dial with it.
if statusCode == http.StatusTooManyRequests {
resetAt := parseOCMRateLimitResetFromHeaders(upstreamResponseHeaders)
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection)
selectedCredential.updateStateFromHeaders(upstreamResponseHeaders)
if nextCredential == nil {
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited")
return
}
s.logger.InfoContext(ctx, "retrying websocket with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
selectedCredential = nextCredential
continue
}
// 400 from an external credential: mark it rejected and give up without retrying.
if statusCode == http.StatusBadRequest && selectedCredential.isExternal() {
selectedCredential.markUpstreamRejected()
s.logger.ErrorContext(ctx, "upstream rejected websocket from ", selectedCredential.tagName(), ": status ", statusCode)
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "upstream rejected credential")
return
}
// Any other failure: log the captured status and body when available,
// otherwise the dial error, and answer 502.
if statusCode > 0 && statusResponseBody != "" {
s.logger.ErrorContext(ctx, "dial upstream websocket: status ", statusCode, " body: ", statusResponseBody)
} else {
s.logger.ErrorContext(ctx, "dial upstream websocket: ", err)
}
writeJSONError(w, r, http.StatusBadGateway, "api_error", "upstream websocket connection failed")
return
}
// Upstream handshake succeeded: record its headers on the credential and
// copy the forwardable ones (as independent slices) into the client response.
selectedCredential.updateStateFromHeaders(upstreamResponseHeaders)
weeklyCycleHint := extractWeeklyCycleHint(upstreamResponseHeaders)
clientResponseHeaders := make(http.Header)
for key, values := range upstreamResponseHeaders {
if isForwardableResponseHeader(key) {
clientResponseHeaders[key] = append([]string(nil), values...)
}
}
s.rewriteResponseHeaders(clientResponseHeaders, provider, userConfig)
clientUpgrader := ws.HTTPUpgrader{
Header: clientResponseHeaders,
}
// Refuse to upgrade while shutting down; the response is still plain HTTP here.
if s.isShuttingDown() {
upstreamConn.Close()
writeJSONError(w, r, http.StatusServiceUnavailable, "api_error", "service is shutting down")
return
}
clientConn, _, _, err = clientUpgrader.Upgrade(r, w)
if err != nil {
s.logger.ErrorContext(ctx, "upgrade client websocket: ", err)
upstreamConn.Close()
return
}
session = &webSocketSession{
clientConn: clientConn,
upstreamConn: upstreamConn,
credentialTag: selectedCredential.tagName(),
releaseProviderInterrupt: requestContext.releaseCredentialInterrupt,
closed: make(chan struct{}),
}
// If the service refuses to register the session, close it right away.
if !s.registerWebSocketSession(session) {
session.Close()
return
}
defer s.unregisterWebSocketSession(session)
// When the dialer returned a buffered reader, read upstream data through it
// and write directly to the connection.
var upstreamReadWriter io.ReadWriter
if upstreamBufferedReader != nil {
upstreamReadWriter = struct {
io.Reader
io.Writer
}{upstreamBufferedReader, upstreamConn}
} else {
upstreamReadWriter = upstreamConn
}
// clientWriteAccess serializes writes to the client, which come from both
// the upstream-to-client proxy and the status pusher.
var clientWriteAccess sync.Mutex
modelChannel := make(chan string, 1)
firstRealRequest := make(chan struct{})
var firstRealRequestOnce sync.Once
var waitGroup sync.WaitGroup
waitGroup.Add(3)
// Each goroutine closes the session when it returns, so the end of any one
// of them closes both connections.
go func() {
defer waitGroup.Done()
defer session.Close()
s.proxyWebSocketClientToUpstream(ctx, clientConn, upstreamConn, selectedCredential, modelChannel, firstRealRequest, &firstRealRequestOnce, isNew, username, sessionID)
}()
go func() {
defer waitGroup.Done()
defer session.Close()
s.proxyWebSocketUpstreamToClient(ctx, upstreamReadWriter, clientConn, &clientWriteAccess, selectedCredential, modelChannel, username, weeklyCycleHint)
}()
go func() {
defer waitGroup.Done()
defer session.Close()
s.pushWebSocketAggregatedStatus(ctx, clientConn, &clientWriteAccess, session.closed, firstRealRequest, provider, userConfig)
}()
waitGroup.Wait()
}
// proxyWebSocketClientToUpstream forwards every client message to the upstream
// connection until a read or write fails.
//
// Text frames are inspected for "response.create" requests. For the first
// non-warmup request of a new session the credential assignment is logged; for
// every non-warmup request the model is offered to modelChannel (without
// blocking, and only when the credential has a usage tracker), and once the
// frame has been written upstream firstRealRequest is signalled.
func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn net.Conn, upstreamConn net.Conn, selectedCredential Credential, modelChannel chan<- string, firstRealRequest chan struct{}, firstRealRequestOnce *sync.Once, isNew bool, username string, sessionID string) {
	assignmentLogged := false
	for {
		payload, opCode, readErr := wsutil.ReadClientData(clientConn)
		if readErr != nil {
			if !E.IsClosedOrCanceled(readErr) {
				s.logger.DebugContext(ctx, "read client websocket: ", readErr)
			}
			return
		}
		realRequest := false
		if opCode == ws.OpText {
			request, parsed := parseWebSocketResponseCreateRequest(payload)
			if parsed && !request.isWarmup() {
				realRequest = true
				if isNew && !assignmentLogged {
					assignmentLogged = true
					message := []any{"assigned credential ", selectedCredential.tagName()}
					if sessionID != "" {
						message = append(message, " for session ", sessionID)
					}
					if username != "" {
						message = append(message, " by user ", username)
					}
					message = append(message, ", model=", request.Model)
					if request.ServiceTier == "priority" {
						message = append(message, ", fast")
					}
					s.logger.DebugContext(ctx, message...)
				}
				if selectedCredential.usageTrackerOrNil() != nil {
					// Non-blocking: the model is dropped when the channel is full.
					select {
					case modelChannel <- request.Model:
					default:
					}
				}
			}
		}
		if writeErr := wsutil.WriteClientMessage(upstreamConn, opCode, payload); writeErr != nil {
			if !E.IsClosedOrCanceled(writeErr) {
				s.logger.DebugContext(ctx, "write upstream websocket: ", writeErr)
			}
			return
		}
		// Signal only after the request actually reached the upstream.
		if realRequest {
			signalWebSocketReady(firstRealRequest, firstRealRequestOnce)
		}
	}
}
// proxyWebSocketUpstreamToClient forwards upstream messages to the client until
// a read or write fails.
//
// Text frames are decoded far enough to read their "type" and "status_code":
//   - "codex.rate_limits" events update the credential state and are NOT
//     forwarded to the client;
//   - "error" events with status 429 mark the credential as rate limited;
//   - "response.completed" events are passed to the usage tracker (when the
//     credential has one) together with the most recent model received on
//     modelChannel.
//
// Client writes are serialized through clientWriteAccess.
func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamReadWriter io.ReadWriter, clientConn net.Conn, clientWriteAccess *sync.Mutex, selectedCredential Credential, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) {
	tracker := selectedCredential.usageTrackerOrNil()
	var lastModel string
	for {
		payload, opCode, readErr := wsutil.ReadServerData(upstreamReadWriter)
		if readErr != nil {
			if !E.IsClosedOrCanceled(readErr) {
				s.logger.DebugContext(ctx, "read upstream websocket: ", readErr)
			}
			return
		}
		if opCode == ws.OpText {
			var envelope struct {
				Type       string `json:"type"`
				StatusCode int    `json:"status_code"`
			}
			if json.Unmarshal(payload, &envelope) == nil {
				if envelope.Type == "codex.rate_limits" {
					s.handleWebSocketRateLimitsEvent(payload, selectedCredential)
					// Consumed here; the client never sees this frame.
					continue
				}
				if envelope.Type == "error" && envelope.StatusCode == http.StatusTooManyRequests {
					s.handleWebSocketErrorRateLimited(payload, selectedCredential)
				}
				if envelope.Type == "response.completed" && tracker != nil {
					// Pick up a newer model if one is pending; otherwise keep the last one.
					select {
					case pending := <-modelChannel:
						lastModel = pending
					default:
					}
					s.handleWebSocketResponseCompleted(payload, tracker, lastModel, username, weeklyCycleHint)
				}
			}
		}
		clientWriteAccess.Lock()
		writeErr := wsutil.WriteServerMessage(clientConn, opCode, payload)
		clientWriteAccess.Unlock()
		if writeErr != nil {
			if !E.IsClosedOrCanceled(writeErr) {
				s.logger.DebugContext(ctx, "write client websocket: ", writeErr)
			}
			return
		}
	}
}
// handleWebSocketRateLimitsEvent converts a "codex.rate_limits" WebSocket event
// into the equivalent rate-limit headers and feeds them to
// selectedCredential.updateStateFromHeaders. Undecodable events are ignored.
// A window's reset time and the plan weight are only emitted when positive.
func (s *Service) handleWebSocketRateLimitsEvent(data []byte, selectedCredential Credential) {
	type usageWindow struct {
		UsedPercent float64 `json:"used_percent"`
		ResetAt     int64   `json:"reset_at"`
	}
	var event struct {
		RateLimits struct {
			Primary   *usageWindow `json:"primary"`
			Secondary *usageWindow `json:"secondary"`
		} `json:"rate_limits"`
		PlanWeight float64 `json:"plan_weight"`
	}
	if json.Unmarshal(data, &event) != nil {
		return
	}
	synthetic := make(http.Header)
	synthetic.Set("x-codex-active-limit", "codex")
	// applyWindow writes the two headers of one window, skipping absent windows.
	applyWindow := func(window *usageWindow, usedPercentKey string, resetAtKey string) {
		if window == nil {
			return
		}
		synthetic.Set(usedPercentKey, strconv.FormatFloat(window.UsedPercent, 'f', -1, 64))
		if window.ResetAt > 0 {
			synthetic.Set(resetAtKey, strconv.FormatInt(window.ResetAt, 10))
		}
	}
	applyWindow(event.RateLimits.Primary, "x-codex-primary-used-percent", "x-codex-primary-reset-at")
	applyWindow(event.RateLimits.Secondary, "x-codex-secondary-used-percent", "x-codex-secondary-reset-at")
	if event.PlanWeight > 0 {
		synthetic.Set("X-OCM-Plan-Weight", strconv.FormatFloat(event.PlanWeight, 'f', -1, 64))
	}
	selectedCredential.updateStateFromHeaders(synthetic)
}
// handleWebSocketErrorRateLimited processes an upstream "error" event that
// carries status 429: the headers embedded in the event update the credential
// state, and the credential is marked rate limited until the reset time parsed
// from those headers. Undecodable events are ignored.
func (s *Service) handleWebSocketErrorRateLimited(data []byte, selectedCredential Credential) {
	var payload struct {
		Headers map[string]string `json:"headers"`
	}
	if json.Unmarshal(data, &payload) != nil {
		return
	}
	embedded := make(http.Header)
	for name, value := range payload.Headers {
		embedded.Set(name, value)
	}
	selectedCredential.updateStateFromHeaders(embedded)
	selectedCredential.markRateLimited(parseOCMRateLimitResetFromHeaders(embedded))
}
// writeWebSocketAggregatedStatus sends status to the client as a synthetic
// rate-limits text frame. The frame is built before clientWriteAccess is
// taken, so the lock is held only for the write itself.
func writeWebSocketAggregatedStatus(clientConn net.Conn, clientWriteAccess *sync.Mutex, status aggregatedStatus) error {
	frame := buildSyntheticRateLimitsEvent(status)

	clientWriteAccess.Lock()
	defer clientWriteAccess.Unlock()

	return wsutil.WriteServerMessage(clientConn, ws.OpText, frame)
}
// pushWebSocketAggregatedStatus pushes the aggregated utilization to the client
// as synthetic rate-limit events.
//
// Nothing is sent until firstRealRequest fires; that triggers the initial
// push. Afterwards each status notification causes a recomputation, and a new
// event is written only when the result differs from the last one sent. The
// function returns when ctx is done, the observer subscription ends, the
// session is closed, or a client write fails.
func (s *Service) pushWebSocketAggregatedStatus(ctx context.Context, clientConn net.Conn, clientWriteAccess *sync.Mutex, sessionClosed <-chan struct{}, firstRealRequest <-chan struct{}, provider credentialProvider, userConfig *option.OCMUser) {
	subscription, done, err := s.statusObserver.Subscribe()
	if err != nil {
		return
	}
	defer s.statusObserver.UnSubscribe(subscription)

	var (
		lastSent aggregatedStatus
		sentOnce bool
	)
	for {
		select {
		case <-ctx.Done():
			return
		case <-done:
			return
		case <-sessionClosed:
			return
		case <-firstRealRequest:
			snapshot := s.computeAggregatedUtilization(provider, userConfig)
			if writeWebSocketAggregatedStatus(clientConn, clientWriteAccess, snapshot) != nil {
				return
			}
			lastSent, sentOnce = snapshot, true
			// A nil channel is never ready, so this case cannot fire again.
			firstRealRequest = nil
		case <-subscription:
			// Collapse any burst of pending notifications into a single update.
		drain:
			for {
				select {
				case <-subscription:
				default:
					break drain
				}
			}
			if !sentOnce {
				continue
			}
			snapshot := s.computeAggregatedUtilization(provider, userConfig)
			if snapshot.equal(lastSent) {
				continue
			}
			lastSent = snapshot
			if writeWebSocketAggregatedStatus(clientConn, clientWriteAccess, snapshot) != nil {
				return
			}
		}
	}
}
// buildSyntheticRateLimitsEvent encodes status as a "codex.rate_limits" JSON
// event: the five-hour figures become the primary window, the weekly figures
// the secondary window, and the total weight the plan weight. Zero reset
// times and a zero plan weight are omitted from the output.
func buildSyntheticRateLimitsEvent(status aggregatedStatus) []byte {
	type window struct {
		UsedPercent float64 `json:"used_percent"`
		ResetAt     int64   `json:"reset_at,omitempty"`
	}
	type windowPair struct {
		Primary   *window `json:"primary,omitempty"`
		Secondary *window `json:"secondary,omitempty"`
	}
	type syntheticEvent struct {
		Type       string     `json:"type"`
		RateLimits windowPair `json:"rate_limits"`
		LimitName  string     `json:"limit_name"`
		PlanWeight float64    `json:"plan_weight,omitempty"`
	}
	primary := &window{
		UsedPercent: status.fiveHourUtilization,
		ResetAt:     resetToEpoch(status.fiveHourReset),
	}
	secondary := &window{
		UsedPercent: status.weeklyUtilization,
		ResetAt:     resetToEpoch(status.weeklyReset),
	}
	// The marshal error is deliberately discarded; on failure the returned
	// slice is nil.
	encoded, _ := json.Marshal(syntheticEvent{
		Type:       "codex.rate_limits",
		RateLimits: windowPair{Primary: primary, Secondary: secondary},
		LimitName:  "codex",
		PlanWeight: status.totalWeight,
	})
	return encoded
}
// handleWebSocketResponseCompleted records the token usage reported by a
// "response.completed" event on usageTracker.
//
// Events that cannot be decoded, or that report neither input nor output
// tokens, are ignored. The model comes from the response and falls back to
// requestModel; when both are empty nothing is recorded.
func (s *Service) handleWebSocketResponseCompleted(data []byte, usageTracker *AggregatedUsage, requestModel string, username string, weeklyCycleHint *WeeklyCycleHint) {
	var union responses.ResponseStreamEventUnion
	if decodeErr := json.Unmarshal(data, &union); decodeErr != nil {
		return
	}
	completed := union.AsResponseCompleted()
	inputTokens := completed.Response.Usage.InputTokens
	outputTokens := completed.Response.Usage.OutputTokens
	if inputTokens <= 0 && outputTokens <= 0 {
		return
	}
	model := string(completed.Response.Model)
	if model == "" {
		model = requestModel
	}
	if model == "" {
		return
	}
	tier := string(completed.Response.ServiceTier)
	cachedTokens := completed.Response.Usage.InputTokensDetails.CachedTokens
	contextWindow := detectContextWindow(model, tier, inputTokens)
	usageTracker.AddUsageWithCycleHint(
		model,
		contextWindow,
		inputTokens,
		outputTokens,
		cachedTokens,
		tier,
		username,
		time.Now(),
		weeklyCycleHint,
	)
}