Files
sing-box/service/ocm/service_handler.go
世界 cd5007ffbb fix(ccm,ocm): track external credential poll failures and re-poll on user connect
External credentials now properly increment consecutivePollFailures on
poll errors (matching defaultCredential behavior), marking the credential
as temporarily blocked. When a user with external_credential connects
and the credential is not usable, a forced poll is triggered to check
recovery.
2026-03-26 22:16:02 +08:00

532 lines
17 KiB
Go

package ocm
import (
"bytes"
"context"
"encoding/json"
"io"
"mime"
"net/http"
"strconv"
"strings"
"time"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/openai/openai-go/v3"
"github.com/openai/openai-go/v3/responses"
)
// weeklyCycleHintForLimit builds a WeeklyCycleHint from the pair of headers
// "x-<limit>-secondary-window-minutes" and "x-<limit>-secondary-reset-at",
// where <limit> is limitIdentifier after normalizeRateLimitIdentifier.
//
// It returns nil when the normalized identifier is empty, when either header
// is not reported as present by parseInt64Header, or when either value is not
// strictly positive. The reset header value is interpreted as Unix seconds
// and stored in UTC.
func weeklyCycleHintForLimit(headers http.Header, limitIdentifier string) *WeeklyCycleHint {
	identifier := normalizeRateLimitIdentifier(limitIdentifier)
	if identifier == "" {
		return nil
	}
	headerPrefix := "x-" + identifier
	windowMinutes, windowFound := parseInt64Header(headers, headerPrefix+"-secondary-window-minutes")
	resetAtUnix, resetFound := parseInt64Header(headers, headerPrefix+"-secondary-reset-at")
	switch {
	case !windowFound, !resetFound:
		// One of the two headers is missing or unparsable.
		return nil
	case windowMinutes <= 0, resetAtUnix <= 0:
		// Non-positive values carry no usable cycle information.
		return nil
	}
	resetAt := time.Unix(resetAtUnix, 0).UTC()
	return &WeeklyCycleHint{
		WindowMinutes: windowMinutes,
		ResetAt:       resetAt,
	}
}
// extractWeeklyCycleHint derives a WeeklyCycleHint from response headers.
//
// The limit named by the "x-codex-active-limit" header is tried first; when
// that header normalizes to an empty identifier, or yields no hint, the
// "codex" limit is used instead. The result may be nil.
func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint {
	if activeLimit := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")); activeLimit != "" {
		hint := weeklyCycleHintForLimit(headers, activeLimit)
		if hint != nil {
			return hint
		}
	}
	return weeklyCycleHintForLimit(headers, "codex")
}
// resolveCredentialProvider returns the credential provider that serves
// username.
//
// When users are configured the lookup is delegated to credentialForUser.
// Otherwise the provider registered under the tag of the first configured
// credential is returned, or an error if no provider exists for that tag.
func (s *Service) resolveCredentialProvider(username string) (credentialProvider, error) {
	if len(s.options.Users) == 0 {
		// No user table: every request shares the first credential's provider.
		defaultTag := s.options.Credentials[0].Tag
		if defaultProvider := s.providers[defaultTag]; defaultProvider != nil {
			return defaultProvider, nil
		}
		return nil, E.New("no credential available")
	}
	return credentialForUser(s.userConfigMap, s.providers, username)
}
// ServeHTTP is the HTTP entry point of the service.
//
// Requests for /ocm/v1/status and /ocm/v1/reverse are handed to their own
// handlers. Every other request must target a path starting with /v1/, must
// not carry an X-Api-Key or Api-Key header, and, when users are configured,
// must authenticate with "Authorization: Bearer <token>".
//
// After authentication a credential is selected from the user's provider and
// the request is proxied upstream with it. WebSocket upgrades under
// /v1/responses are delegated to handleWebSocket. An upstream 429 is retried
// with the credential returned by the provider's onRateLimited until a
// non-429 response arrives or no further credential is available. The final
// response is copied to the client; for 200 responses on the chat completions
// and responses endpoints, usage is tracked when the credential has a usage
// tracker.
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	ctx := log.ContextWithNewID(r.Context())
	// Service-internal endpoints are dispatched before the /v1/ prefix check
	// and the bearer-token authentication below.
	if r.URL.Path == "/ocm/v1/status" {
		s.handleStatusEndpoint(w, r)
		return
	}
	if r.URL.Path == "/ocm/v1/reverse" {
		s.handleReverseConnect(ctx, w, r)
		return
	}
	path := r.URL.Path
	if !strings.HasPrefix(path, "/v1/") {
		writeJSONError(w, r, http.StatusNotFound, "invalid_request_error", "path must start with /v1/")
		return
	}
	// API-key style headers are rejected outright, whether or not users are configured.
	if r.Header.Get("X-Api-Key") != "" || r.Header.Get("Api-Key") != "" {
		writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
			"API key authentication is not supported; use Authorization: Bearer with an OCM user token")
		return
	}
	// username stays empty when no users are configured; authentication is skipped in that case.
	var username string
	if len(s.options.Users) > 0 {
		authHeader := r.Header.Get("Authorization")
		if authHeader == "" {
			s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
			writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
			return
		}
		// TrimPrefix returns its input unchanged when the prefix is absent,
		// which is how a header without "Bearer " is detected.
		clientToken := strings.TrimPrefix(authHeader, "Bearer ")
		if clientToken == authHeader {
			s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
			writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
			return
		}
		var ok bool
		username, ok = s.userManager.Authenticate(clientToken)
		if !ok {
			// NOTE(review): the rejected bearer token is written to the log verbatim — confirm this is acceptable.
			s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
			writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
			return
		}
	}
	sessionID := r.Header.Get("session_id")
	// Resolve credential provider and user config
	var provider credentialProvider
	var userConfig *option.OCMUser
	if len(s.options.Users) > 0 {
		userConfig = s.userConfigMap[username]
		var err error
		provider, err = credentialForUser(s.userConfigMap, s.providers, username)
		if err != nil {
			s.logger.ErrorContext(ctx, "resolve credential: ", err)
			writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
			return
		}
	} else {
		// Without users, the provider registered for the first configured credential serves every request.
		provider = s.providers[s.options.Credentials[0].Tag]
	}
	if provider == nil {
		writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available")
		return
	}
	provider.pollIfStale()
	// When the user is bound to an external credential that is currently not
	// usable, poll its usage once before selecting a credential.
	if userConfig != nil && userConfig.ExternalCredential != "" {
		for _, credential := range s.allCredentials {
			if credential.tagName() == userConfig.ExternalCredential && !credential.isUsable() {
				credential.pollUsage()
				break
			}
		}
	}
	selection := credentialSelectionForUser(userConfig)
	selectedCredential, isNew, err := provider.selectCredential(sessionID, selection)
	if err != nil {
		writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, err.Error()))
		return
	}
	// WebSocket upgrades on the responses endpoint leave this handler here.
	if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") {
		s.handleWebSocket(ctx, w, r, path, username, sessionID, userConfig, provider, selectedCredential, selection, isNew)
		return
	}
	if !selectedCredential.isExternal() && selectedCredential.ocmIsAPIKeyMode() {
		// API key mode: no path restriction is applied in this branch.
	} else if !selectedCredential.isExternal() {
		// Non-external credential outside API key mode: chat completions is refused.
		if path == "/v1/chat/completions" {
			writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
				"chat completions endpoint is only available in API key mode")
			return
		}
	}
	shouldTrackUsage := selectedCredential.usageTrackerOrNil() != nil &&
		(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses"))
	// A retry is only possible when the provider holds more than one credential.
	canRetryRequest := len(provider.allCredentials()) > 1
	// Read body for model extraction and retry buffer when JSON replay is useful.
	// bodyBytes stays nil when the body is not buffered (non-JSON content type,
	// or neither tracking nor retry applies).
	var bodyBytes []byte
	var requestModel string
	var requestServiceTier string
	if r.Body != nil && (shouldTrackUsage || canRetryRequest) {
		mediaType, _, parseErr := mime.ParseMediaType(r.Header.Get("Content-Type"))
		isJSONRequest := parseErr == nil && (mediaType == "application/json" || strings.HasSuffix(mediaType, "+json"))
		if isJSONRequest {
			bodyBytes, err = io.ReadAll(r.Body)
			if err != nil {
				s.logger.ErrorContext(ctx, "read request body: ", err)
				writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body")
				return
			}
			var request struct {
				Model       string `json:"model"`
				ServiceTier string `json:"service_tier"`
			}
			// A body that fails to decode is still forwarded; only the log metadata is lost.
			if json.Unmarshal(bodyBytes, &request) == nil {
				requestModel = request.Model
				requestServiceTier = request.ServiceTier
			}
			// Restore the consumed body so r can still be read downstream.
			r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
		}
	}
	// Log the assignment only when selectCredential reported it as new.
	if isNew {
		logParts := []any{"assigned credential ", selectedCredential.tagName()}
		if sessionID != "" {
			logParts = append(logParts, " for session ", sessionID)
		}
		if username != "" {
			logParts = append(logParts, " by user ", username)
		}
		if requestModel != "" {
			logParts = append(logParts, ", model=", requestModel)
		}
		if requestServiceTier == "priority" {
			logParts = append(logParts, ", fast")
		}
		s.logger.DebugContext(ctx, logParts...)
	}
	requestContext := selectedCredential.wrapRequestContext(ctx)
	{
		// Capture the current context in a block-local variable so the
		// interrupt callback cancels this context even after requestContext
		// is reassigned by the retry loop below.
		currentRequestContext := requestContext
		requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() {
			currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
		}))
	}
	// The closure reads requestContext when the handler returns, so it cancels
	// whichever context is current after any retries.
	defer func() {
		requestContext.cancelRequest()
	}()
	proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
	if err != nil {
		s.logger.ErrorContext(ctx, "create proxy request: ", err)
		writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
		return
	}
	response, err := selectedCredential.httpClient().Do(proxyRequest)
	if err != nil {
		// The incoming request's context is already done: return without writing a response.
		if r.Context().Err() != nil {
			return
		}
		// Only the credential-scoped context is done: report the credential as unavailable.
		if requestContext.Err() != nil {
			writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable while processing the request")
			return
		}
		writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
		return
	}
	// NOTE(review): presumably detaches the interrupt link now that response
	// headers have arrived — confirm against the request context implementation.
	requestContext.releaseCredentialInterrupt()
	// Transparent 429 retry
	for response.StatusCode == http.StatusTooManyRequests {
		resetAt := parseOCMRateLimitResetFromHeaders(response.Header)
		nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection)
		// Every method except GET, HEAD and DELETE is treated as carrying a
		// body that must be replayed.
		needsBodyReplay := r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodDelete
		selectedCredential.updateStateFromHeaders(response.Header)
		// Give up when the body was not buffered for replay or no other credential is offered.
		if (needsBodyReplay && bodyBytes == nil) || nextCredential == nil {
			response.Body.Close()
			writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited")
			return
		}
		response.Body.Close()
		s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
		// Cancel the previous attempt's context before wrapping a fresh one for the next credential.
		requestContext.cancelRequest()
		requestContext = nextCredential.wrapRequestContext(ctx)
		{
			// Same per-attempt capture as for the first request.
			currentRequestContext := requestContext
			requestContext.addInterruptLink(provider.linkProviderInterrupt(nextCredential, selection, func() {
				currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
			}))
		}
		retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
		if buildErr != nil {
			s.logger.ErrorContext(ctx, "retry request: ", buildErr)
			writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
			return
		}
		retryResponse, retryErr := nextCredential.httpClient().Do(retryRequest)
		if retryErr != nil {
			if r.Context().Err() != nil {
				return
			}
			if requestContext.Err() != nil {
				writeCredentialUnavailableError(w, r, provider, nextCredential, selection, "credential became unavailable while retrying the request")
				return
			}
			s.logger.ErrorContext(ctx, "retry request: ", retryErr)
			writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error())
			return
		}
		requestContext.releaseCredentialInterrupt()
		// The retry becomes the current response; the loop re-checks its status.
		response = retryResponse
		selectedCredential = nextCredential
	}
	defer response.Body.Close()
	selectedCredential.updateStateFromHeaders(response.Header)
	// An upstream 400 is reported to the client as a credential problem:
	// external credentials are marked rejected, others are re-polled if stale.
	if response.StatusCode == http.StatusBadRequest {
		if selectedCredential.isExternal() {
			selectedCredential.markUpstreamRejected()
		} else {
			provider.pollCredentialIfStale(selectedCredential)
		}
		s.logger.ErrorContext(ctx, "upstream rejected from ", selectedCredential.tagName(), ": status ", response.StatusCode)
		writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "upstream rejected credential")
		return
	}
	// Any other non-200 status is surfaced to the client as a 500 that embeds
	// the upstream status and body.
	// NOTE(review): the retry loop above only exits with a non-429 status, so
	// the 429 comparison here cannot be true.
	if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
		body, _ := io.ReadAll(response.Body)
		s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
		writeJSONError(w, r, http.StatusInternalServerError, "api_error",
			"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
		return
	}
	s.rewriteResponseHeaders(response.Header, provider, userConfig)
	// Copy upstream headers, skipping those flagged by isHopByHopHeader or isReverseProxyHeader.
	for key, values := range response.Header {
		if !isHopByHopHeader(key) && !isReverseProxyHeader(key) {
			w.Header()[key] = values
		}
	}
	w.WriteHeader(response.StatusCode)
	usageTracker := selectedCredential.usageTrackerOrNil()
	if usageTracker != nil && response.StatusCode == http.StatusOK &&
		(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) {
		s.handleResponseWithTracking(ctx, w, response, usageTracker, path, requestModel, username)
	} else {
		// Untracked response: a parseable Content-Type other than
		// text/event-stream is copied in one go; everything else (event
		// streams and unparseable or missing Content-Type) is relayed with a
		// flush after each chunk.
		mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
		if err == nil && mediaType != "text/event-stream" {
			_, _ = io.Copy(w, response.Body)
			return
		}
		flusher, ok := w.(http.Flusher)
		if !ok {
			s.logger.ErrorContext(ctx, "streaming not supported")
			return
		}
		buffer := make([]byte, buf.BufferSize)
		for {
			n, err := response.Body.Read(buffer)
			// Forward whatever was read before looking at the read error.
			if n > 0 {
				_, writeError := w.Write(buffer[:n])
				if writeError != nil {
					// A closed or canceled client connection is not logged.
					if E.IsClosedOrCanceled(writeError) {
						return
					}
					s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
					return
				}
				flusher.Flush()
			}
			// Any read error, including io.EOF, ends the relay.
			if err != nil {
				return
			}
		}
	}
}
// handleResponseWithTracking relays an upstream response body to writer while
// extracting token usage from it and recording that usage on usageTracker.
//
// path selects the payload schema: "/v1/chat/completions" is decoded as chat
// completion objects, anything else as Responses API objects. requestModel is
// used as the model name when the response does not carry one, and username
// is passed through to the tracker.
//
// The response is treated as a stream when its Content-Type is
// text/event-stream, or when it is not a chat completions response and has no
// Content-Type header at all. Non-streaming bodies are read fully, inspected,
// and then written. Streaming bodies are forwarded and flushed chunk by chunk
// while their "data: " lines are parsed; usage is recorded once the upstream
// read ends. Usage is recorded only when at least one of the input/output
// token counts is positive and a model name is known.
func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, path string, requestModel string, username string) {
	isChatCompletions := path == "/v1/chat/completions"
	weeklyCycleHint := extractWeeklyCycleHint(response.Header)

	// recordUsage reports one request's usage to the tracker, applying the
	// same guards on both the buffered and the streaming path.
	recordUsage := func(model, tier string, input, output, cached int64) {
		if input <= 0 && output <= 0 {
			return
		}
		if model == "" {
			model = requestModel
		}
		if model == "" {
			return
		}
		contextWindow := detectContextWindow(model, tier, input)
		usageTracker.AddUsageWithCycleHint(
			model,
			contextWindow,
			input,
			output,
			cached,
			tier,
			username,
			time.Now(),
			weeklyCycleHint,
		)
	}

	mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
	isStreaming := err == nil && mediaType == "text/event-stream"
	if !isStreaming && !isChatCompletions && response.Header.Get("Content-Type") == "" {
		// Responses without any Content-Type are handled as a stream.
		isStreaming = true
	}

	if !isStreaming {
		bodyBytes, err := io.ReadAll(response.Body)
		if err != nil {
			s.logger.ErrorContext(ctx, "read response body: ", err)
			return
		}
		var responseModel, serviceTier string
		var inputTokens, outputTokens, cachedTokens int64
		// A body that does not decode is still forwarded; it just yields no usage.
		if isChatCompletions {
			var chatCompletion openai.ChatCompletion
			if json.Unmarshal(bodyBytes, &chatCompletion) == nil {
				responseModel = chatCompletion.Model
				serviceTier = string(chatCompletion.ServiceTier)
				inputTokens = chatCompletion.Usage.PromptTokens
				outputTokens = chatCompletion.Usage.CompletionTokens
				cachedTokens = chatCompletion.Usage.PromptTokensDetails.CachedTokens
			}
		} else {
			var responsesResponse responses.Response
			if json.Unmarshal(bodyBytes, &responsesResponse) == nil {
				responseModel = string(responsesResponse.Model)
				serviceTier = string(responsesResponse.ServiceTier)
				inputTokens = responsesResponse.Usage.InputTokens
				outputTokens = responsesResponse.Usage.OutputTokens
				cachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
			}
		}
		recordUsage(responseModel, serviceTier, inputTokens, outputTokens, cachedTokens)
		_, _ = writer.Write(bodyBytes)
		return
	}

	flusher, ok := writer.(http.Flusher)
	if !ok {
		s.logger.ErrorContext(ctx, "streaming not supported")
		return
	}

	var inputTokens, outputTokens, cachedTokens int64
	var responseModel, serviceTier string
	dataPrefix := []byte("data: ")

	// consumeLine inspects one line of the event stream and folds any model,
	// service tier, or usage figures it carries into the accumulators above.
	// Lines that are not "data: " payloads, the "[DONE]" marker, and payloads
	// that fail to decode are ignored.
	consumeLine := func(line []byte) {
		line = bytes.TrimSpace(line)
		if !bytes.HasPrefix(line, dataPrefix) {
			return
		}
		eventData := line[len(dataPrefix):]
		if bytes.Equal(eventData, []byte("[DONE]")) {
			return
		}
		if isChatCompletions {
			var chatChunk openai.ChatCompletionChunk
			if json.Unmarshal(eventData, &chatChunk) != nil {
				return
			}
			if chatChunk.Model != "" {
				responseModel = chatChunk.Model
			}
			if chatChunk.ServiceTier != "" {
				serviceTier = string(chatChunk.ServiceTier)
			}
			if chatChunk.Usage.PromptTokens > 0 {
				inputTokens = chatChunk.Usage.PromptTokens
				cachedTokens = chatChunk.Usage.PromptTokensDetails.CachedTokens
			}
			if chatChunk.Usage.CompletionTokens > 0 {
				outputTokens = chatChunk.Usage.CompletionTokens
			}
			return
		}
		var streamEvent responses.ResponseStreamEventUnion
		if json.Unmarshal(eventData, &streamEvent) != nil {
			return
		}
		// Only the "response.completed" event is inspected for usage.
		if streamEvent.Type != "response.completed" {
			return
		}
		completedEvent := streamEvent.AsResponseCompleted()
		if string(completedEvent.Response.Model) != "" {
			responseModel = string(completedEvent.Response.Model)
		}
		if completedEvent.Response.ServiceTier != "" {
			serviceTier = string(completedEvent.Response.ServiceTier)
		}
		if completedEvent.Response.Usage.InputTokens > 0 {
			inputTokens = completedEvent.Response.Usage.InputTokens
			cachedTokens = completedEvent.Response.Usage.InputTokensDetails.CachedTokens
		}
		if completedEvent.Response.Usage.OutputTokens > 0 {
			outputTokens = completedEvent.Response.Usage.OutputTokens
		}
	}

	buffer := make([]byte, buf.BufferSize)
	// leftover holds the bytes after the last newline seen so far, i.e. a line
	// whose end has not arrived yet.
	var leftover []byte
	for {
		n, readErr := response.Body.Read(buffer)
		if n > 0 {
			data := append(leftover, buffer[:n]...)
			lines := bytes.Split(data, []byte("\n"))
			// bytes.Split always returns at least one element; the last one is
			// the (possibly empty) unterminated tail and is held back.
			leftover = lines[len(lines)-1]
			for _, line := range lines[:len(lines)-1] {
				consumeLine(line)
			}
			// Forward the raw chunk unchanged, independent of line boundaries.
			_, writeError := writer.Write(buffer[:n])
			if writeError != nil {
				if E.IsClosedOrCanceled(writeError) {
					return
				}
				s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
				return
			}
			flusher.Flush()
		}
		if readErr != nil {
			// The stream is over, so the held-back tail is final. Parse it
			// here regardless of whether the error arrived together with the
			// last bytes or on a separate zero-length read; otherwise an
			// unterminated final line would be dropped depending on how the
			// reader happened to chunk the data.
			if len(leftover) > 0 {
				consumeLine(leftover)
			}
			recordUsage(responseModel, serviceTier, inputTokens, outputTokens, cachedTokens)
			return
		}
	}
}