package ocm

import (
	"bytes"
	"context"
	"encoding/json"
	"io"
	"mime"
	"net/http"
	"strconv"
	"strings"
	"time"

	"github.com/sagernet/sing-box/log"
	"github.com/sagernet/sing-box/option"
	"github.com/sagernet/sing/common/buf"
	E "github.com/sagernet/sing/common/exceptions"

	"github.com/openai/openai-go/v3"
	"github.com/openai/openai-go/v3/responses"
)

// weeklyCycleHintForLimit builds a WeeklyCycleHint from the
// "x-<limit>-secondary-window-minutes" and "x-<limit>-secondary-reset-at"
// response headers of the given rate-limit identifier. The reset header is
// interpreted as Unix seconds and converted to UTC. It returns nil when the
// identifier normalizes to "" or when either header is missing or not
// strictly positive.
func weeklyCycleHintForLimit(headers http.Header, limitIdentifier string) *WeeklyCycleHint {
	normalizedLimitIdentifier := normalizeRateLimitIdentifier(limitIdentifier)
	if normalizedLimitIdentifier == "" {
		return nil
	}
	windowHeader := "x-" + normalizedLimitIdentifier + "-secondary-window-minutes"
	resetHeader := "x-" + normalizedLimitIdentifier + "-secondary-reset-at"
	windowMinutes, hasWindowMinutes := parseInt64Header(headers, windowHeader)
	resetAtUnix, hasResetAt := parseInt64Header(headers, resetHeader)
	if !hasWindowMinutes || !hasResetAt || windowMinutes <= 0 || resetAtUnix <= 0 {
		return nil
	}
	return &WeeklyCycleHint{
		WindowMinutes: windowMinutes,
		ResetAt:       time.Unix(resetAtUnix, 0).UTC(),
	}
}

// extractWeeklyCycleHint returns the weekly cycle hint for the limit named by
// the "x-codex-active-limit" header when that limit carries a complete hint,
// and otherwise falls back to the "codex" limit headers. It returns nil when
// neither source yields a hint.
func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint {
	activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
	if activeLimitIdentifier != "" {
		if activeHint := weeklyCycleHintForLimit(headers, activeLimitIdentifier); activeHint != nil {
			return activeHint
		}
	}
	return weeklyCycleHintForLimit(headers, "codex")
}

// resolveCredentialProvider returns the credential provider for username.
// When users are configured, the per-user mapping decides (via
// credentialForUser); otherwise the provider of the first configured
// credential is used, and an error is returned if it is not registered.
func (s *Service) resolveCredentialProvider(username string) (credentialProvider, error) {
	if len(s.options.Users) > 0 {
		return credentialForUser(s.userConfigMap, s.providers, username)
	}
	provider := s.providers[s.options.Credentials[0].Tag]
	if provider == nil {
		return nil, E.New("no credential available")
	}
	return provider, nil
}

// ServeHTTP is the proxy entry point. It serves the /ocm/v1/status and
// /ocm/v1/reverse control endpoints, and for every other request:
//   - requires the path to start with /v1/ and rejects X-Api-Key/Api-Key auth;
//   - when users are configured, authenticates the Authorization bearer token;
//   - selects an upstream credential for the session (the "session_id" header);
//   - forwards the request, transparently retrying on another credential when
//     upstream answers 429 and the request body can be replayed;
//   - streams the upstream reply back, metering token usage for chat
//     completions and responses calls when the credential has a usage tracker.
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	ctx := log.ContextWithNewID(r.Context())
	if r.URL.Path == "/ocm/v1/status" {
		s.handleStatusEndpoint(w, r)
		return
	}
	if r.URL.Path == "/ocm/v1/reverse" {
		s.handleReverseConnect(ctx, w, r)
		return
	}

	path := r.URL.Path
	if !strings.HasPrefix(path, "/v1/") {
		writeJSONError(w, r, http.StatusNotFound, "invalid_request_error", "path must start with /v1/")
		return
	}

	if r.Header.Get("X-Api-Key") != "" || r.Header.Get("Api-Key") != "" {
		writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", "API key authentication is not supported; use Authorization: Bearer with an OCM user token")
		return
	}

	var username string
	if len(s.options.Users) > 0 {
		authHeader := r.Header.Get("Authorization")
		if authHeader == "" {
			s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
			writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
			return
		}
		// Authentication schemes are case-insensitive (RFC 7235 §2.1), so
		// "bearer <token>" is accepted as well as "Bearer <token>".
		const bearerPrefix = "Bearer "
		if len(authHeader) < len(bearerPrefix) || !strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) {
			s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
			writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
			return
		}
		clientToken := authHeader[len(bearerPrefix):]
		var ok bool
		username, ok = s.userManager.Authenticate(clientToken)
		if !ok {
			// Never log the presented token: it is a secret, and a near-miss
			// typo of a valid token would otherwise end up in the logs.
			s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key")
			writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
			return
		}
	}

	sessionID := r.Header.Get("session_id")

	// Resolve credential provider and user config.
	var userConfig *option.OCMUser
	if len(s.options.Users) > 0 {
		userConfig = s.userConfigMap[username]
	}
	provider, err := s.resolveCredentialProvider(username)
	if err != nil {
		s.logger.ErrorContext(ctx, "resolve credential: ", err)
		writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
		return
	}
	if provider == nil {
		writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available")
		return
	}

	provider.pollIfStale()
	selection := credentialSelectionForUser(userConfig)

	selectedCredential, isNew, err := provider.selectCredential(sessionID, selection)
	if err != nil {
		writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, err.Error()))
		return
	}

	if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") {
		s.handleWebSocket(ctx, w, r, path, username, sessionID, userConfig, provider, selectedCredential, selection, isNew)
		return
	}

	// Internal credentials serve chat completions only in API key mode;
	// external credentials are passed through unchanged.
	if !selectedCredential.isExternal() && !selectedCredential.ocmIsAPIKeyMode() && path == "/v1/chat/completions" {
		writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", "chat completions endpoint is only available in API key mode")
		return
	}

	isUsageTrackedPath := path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")
	shouldTrackUsage := selectedCredential.usageTrackerOrNil() != nil && isUsageTrackedPath
	canRetryRequest := len(provider.allCredentials()) > 1

	// Read body for model extraction and retry buffer when JSON replay is useful.
	var bodyBytes []byte
	var requestModel string
	var requestServiceTier string
	if r.Body != nil && (shouldTrackUsage || canRetryRequest) {
		mediaType, _, parseErr := mime.ParseMediaType(r.Header.Get("Content-Type"))
		isJSONRequest := parseErr == nil && (mediaType == "application/json" || strings.HasSuffix(mediaType, "+json"))
		if isJSONRequest {
			bodyBytes, err = io.ReadAll(r.Body)
			if err != nil {
				s.logger.ErrorContext(ctx, "read request body: ", err)
				writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body")
				return
			}
			var request struct {
				Model       string `json:"model"`
				ServiceTier string `json:"service_tier"`
			}
			if json.Unmarshal(bodyBytes, &request) == nil {
				requestModel = request.Model
				requestServiceTier = request.ServiceTier
			}
			r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
		}
	}

	if isNew {
		logParts := []any{"assigned credential ", selectedCredential.tagName()}
		if sessionID != "" {
			logParts = append(logParts, " for session ", sessionID)
		}
		if username != "" {
			logParts = append(logParts, " by user ", username)
		}
		if requestModel != "" {
			logParts = append(logParts, ", model=", requestModel)
		}
		if requestServiceTier == "priority" {
			logParts = append(logParts, ", fast")
		}
		s.logger.DebugContext(ctx, logParts...)
	}

	requestContext := selectedCredential.wrapRequestContext(ctx)
	{
		currentRequestContext := requestContext
		requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() {
			currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
		}))
	}
	// The closure reads requestContext at exit time, so it cancels whichever
	// context is current after any 429 retries below.
	defer func() {
		requestContext.cancelRequest()
	}()

	proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
	if err != nil {
		s.logger.ErrorContext(ctx, "create proxy request: ", err)
		writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
		return
	}

	response, err := selectedCredential.httpClient().Do(proxyRequest)
	if err != nil {
		if r.Context().Err() != nil {
			return
		}
		if requestContext.Err() != nil {
			writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable while processing the request")
			return
		}
		writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
		return
	}
	requestContext.releaseCredentialInterrupt()

	// Transparent 429 retry. A retry is only possible when the body was
	// buffered above (or the method carries no body).
	needsBodyReplay := r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodDelete
	for response.StatusCode == http.StatusTooManyRequests {
		resetAt := parseOCMRateLimitResetFromHeaders(response.Header)
		nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection)
		selectedCredential.updateStateFromHeaders(response.Header)
		response.Body.Close()
		if (needsBodyReplay && bodyBytes == nil) || nextCredential == nil {
			writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited")
			return
		}
		s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())

		requestContext.cancelRequest()
		requestContext = nextCredential.wrapRequestContext(ctx)
		{
			currentRequestContext := requestContext
			requestContext.addInterruptLink(provider.linkProviderInterrupt(nextCredential, selection, func() {
				currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
			}))
		}
		retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
		if buildErr != nil {
			s.logger.ErrorContext(ctx, "retry request: ", buildErr)
			writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
			return
		}
		retryResponse, retryErr := nextCredential.httpClient().Do(retryRequest)
		if retryErr != nil {
			if r.Context().Err() != nil {
				return
			}
			if requestContext.Err() != nil {
				writeCredentialUnavailableError(w, r, provider, nextCredential, selection, "credential became unavailable while retrying the request")
				return
			}
			s.logger.ErrorContext(ctx, "retry request: ", retryErr)
			writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error())
			return
		}
		requestContext.releaseCredentialInterrupt()
		response = retryResponse
		selectedCredential = nextCredential
	}
	defer response.Body.Close()

	selectedCredential.updateStateFromHeaders(response.Header)

	if response.StatusCode == http.StatusBadRequest {
		if selectedCredential.isExternal() {
			selectedCredential.markUpstreamRejected()
		} else {
			provider.pollCredentialIfStale(selectedCredential)
		}
		s.logger.ErrorContext(ctx, "upstream rejected from ", selectedCredential.tagName(), ": status ", response.StatusCode)
		writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "upstream rejected credential")
		return
	}

	// The retry loop above never exits with a 429, so anything but 200 here
	// is an upstream error.
	if response.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(response.Body)
		s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
		writeJSONError(w, r, http.StatusInternalServerError, "api_error",
			"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
		return
	}

	s.rewriteResponseHeaders(response.Header, provider, userConfig)

	for key, values := range response.Header {
		if !isHopByHopHeader(key) && !isReverseProxyHeader(key) {
			w.Header()[key] = values
		}
	}
	w.WriteHeader(response.StatusCode)

	usageTracker := selectedCredential.usageTrackerOrNil()
	if usageTracker != nil && response.StatusCode == http.StatusOK && isUsageTrackedPath {
		s.handleResponseWithTracking(ctx, w, response, usageTracker, path, requestModel, username)
		return
	}

	// Untracked pass-through: copy non-SSE bodies directly; stream SSE (or
	// bodies with an unparseable/absent Content-Type) with a flush per chunk.
	mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
	if err == nil && mediaType != "text/event-stream" {
		_, _ = io.Copy(w, response.Body)
		return
	}
	flusher, ok := w.(http.Flusher)
	if !ok {
		s.logger.ErrorContext(ctx, "streaming not supported")
		return
	}
	buffer := make([]byte, buf.BufferSize)
	for {
		n, err := response.Body.Read(buffer)
		if n > 0 {
			_, writeError := w.Write(buffer[:n])
			if writeError != nil {
				if E.IsClosedOrCanceled(writeError) {
					return
				}
				s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
				return
			}
			flusher.Flush()
		}
		if err != nil {
			return
		}
	}
}

// handleResponseWithTracking copies an upstream 200 response to writer while
// extracting the model, service tier and token usage for usageTracker.
//
// JSON bodies are read whole, parsed as a ChatCompletion (path
// "/v1/chat/completions") or a responses.Response, and then written out.
// Server-sent-event bodies (and Responses API replies with no Content-Type)
// are forwarded chunk by chunk with a flush after each write while every
// "data:" line is inspected for usage. Usage is recorded once, falling back
// to requestModel when upstream did not name a model, and only when at least
// one token count is positive.
func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, path string, requestModel string, username string) {
	isChatCompletions := path == "/v1/chat/completions"
	weeklyCycleHint := extractWeeklyCycleHint(response.Header)

	var responseModel, serviceTier string
	var inputTokens, outputTokens, cachedTokens int64

	// recordUsage reports the accumulated counters to usageTracker. It reads
	// the variables above at call time, so it may be deferred.
	recordUsage := func() {
		if inputTokens <= 0 && outputTokens <= 0 {
			return
		}
		model := responseModel
		if model == "" {
			model = requestModel
		}
		if model == "" {
			return
		}
		contextWindow := detectContextWindow(model, serviceTier, inputTokens)
		usageTracker.AddUsageWithCycleHint(
			model,
			contextWindow,
			inputTokens,
			outputTokens,
			cachedTokens,
			serviceTier,
			username,
			time.Now(),
			weeklyCycleHint,
		)
	}

	mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
	isStreaming := err == nil && mediaType == "text/event-stream"
	if !isStreaming && !isChatCompletions && response.Header.Get("Content-Type") == "" {
		// Responses API replies without a Content-Type are treated as SSE.
		isStreaming = true
	}

	if !isStreaming {
		bodyBytes, err := io.ReadAll(response.Body)
		if err != nil {
			s.logger.ErrorContext(ctx, "read response body: ", err)
			return
		}
		if isChatCompletions {
			var chatCompletion openai.ChatCompletion
			if json.Unmarshal(bodyBytes, &chatCompletion) == nil {
				responseModel = chatCompletion.Model
				serviceTier = string(chatCompletion.ServiceTier)
				inputTokens = chatCompletion.Usage.PromptTokens
				outputTokens = chatCompletion.Usage.CompletionTokens
				cachedTokens = chatCompletion.Usage.PromptTokensDetails.CachedTokens
			}
		} else {
			var responsesResponse responses.Response
			if json.Unmarshal(bodyBytes, &responsesResponse) == nil {
				responseModel = string(responsesResponse.Model)
				serviceTier = string(responsesResponse.ServiceTier)
				inputTokens = responsesResponse.Usage.InputTokens
				outputTokens = responsesResponse.Usage.OutputTokens
				cachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
			}
		}
		recordUsage()
		_, _ = writer.Write(bodyBytes)
		return
	}

	flusher, ok := writer.(http.Flusher)
	if !ok {
		s.logger.ErrorContext(ctx, "streaming not supported")
		return
	}

	// Record whatever usage was observed on every exit path, including a
	// client disconnect after the final usage event was already received.
	defer recordUsage()

	// processLine inspects one SSE line and updates the usage counters.
	processLine := func(line []byte) {
		line = bytes.TrimSpace(line)
		// Per the SSE format the space after "data:" is optional.
		if !bytes.HasPrefix(line, []byte("data:")) {
			return
		}
		eventData := bytes.TrimSpace(line[len("data:"):])
		if bytes.Equal(eventData, []byte("[DONE]")) {
			return
		}
		if isChatCompletions {
			var chatChunk openai.ChatCompletionChunk
			if json.Unmarshal(eventData, &chatChunk) != nil {
				return
			}
			if chatChunk.Model != "" {
				responseModel = chatChunk.Model
			}
			if chatChunk.ServiceTier != "" {
				serviceTier = string(chatChunk.ServiceTier)
			}
			if chatChunk.Usage.PromptTokens > 0 {
				inputTokens = chatChunk.Usage.PromptTokens
				cachedTokens = chatChunk.Usage.PromptTokensDetails.CachedTokens
			}
			if chatChunk.Usage.CompletionTokens > 0 {
				outputTokens = chatChunk.Usage.CompletionTokens
			}
			return
		}
		var streamEvent responses.ResponseStreamEventUnion
		if json.Unmarshal(eventData, &streamEvent) != nil || streamEvent.Type != "response.completed" {
			return
		}
		completedEvent := streamEvent.AsResponseCompleted()
		if string(completedEvent.Response.Model) != "" {
			responseModel = string(completedEvent.Response.Model)
		}
		if completedEvent.Response.ServiceTier != "" {
			serviceTier = string(completedEvent.Response.ServiceTier)
		}
		if completedEvent.Response.Usage.InputTokens > 0 {
			inputTokens = completedEvent.Response.Usage.InputTokens
			cachedTokens = completedEvent.Response.Usage.InputTokensDetails.CachedTokens
		}
		if completedEvent.Response.Usage.OutputTokens > 0 {
			outputTokens = completedEvent.Response.Usage.OutputTokens
		}
	}

	buffer := make([]byte, buf.BufferSize)
	// leftover holds an incomplete trailing line carried between reads.
	var leftover []byte
	for {
		n, err := response.Body.Read(buffer)
		if n > 0 {
			data := append(leftover, buffer[:n]...)
			lines := bytes.Split(data, []byte("\n"))
			if err == nil {
				// The last element has no terminating newline yet; keep it
				// until more data (or the end of the stream) arrives.
				leftover = lines[len(lines)-1]
				lines = lines[:len(lines)-1]
			} else {
				leftover = nil
			}
			for _, line := range lines {
				processLine(line)
			}
			_, writeError := writer.Write(buffer[:n])
			if writeError != nil {
				if E.IsClosedOrCanceled(writeError) {
					return
				}
				s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
				return
			}
			flusher.Flush()
		}
		if err != nil {
			// io.Reader may report EOF on a separate zero-byte read; parse a
			// final event that was not newline-terminated before returning.
			if len(leftover) > 0 {
				processLine(leftover)
			}
			return
		}
	}
}