mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-11 17:47:20 +10:00
ccm,ocm: add request ID context to HTTP request logging
This commit is contained in:
@@ -52,7 +52,7 @@ func (l *yamuxNetListener) Addr() net.Addr {
|
||||
return l.session.Addr()
|
||||
}
|
||||
|
||||
func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *Service) handleReverseConnect(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Upgrade") != "reverse-proxy" {
|
||||
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", "missing Upgrade header")
|
||||
return
|
||||
@@ -71,21 +71,21 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
receiverCredential := s.findReceiverCredential(clientToken)
|
||||
if receiverCredential == nil {
|
||||
s.logger.Warn("reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential")
|
||||
s.logger.WarnContext(ctx, "reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid reverse token")
|
||||
return
|
||||
}
|
||||
|
||||
hijacker, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
s.logger.Error("reverse connect: hijack not supported")
|
||||
s.logger.ErrorContext(ctx, "reverse connect: hijack not supported")
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "hijack not supported")
|
||||
return
|
||||
}
|
||||
|
||||
conn, bufferedReadWriter, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
s.logger.Error("reverse connect: hijack: ", err)
|
||||
s.logger.ErrorContext(ctx, "reverse connect: hijack: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -93,20 +93,20 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
|
||||
_, err = bufferedReadWriter.WriteString(response)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
s.logger.Error("reverse connect: write upgrade response: ", err)
|
||||
s.logger.ErrorContext(ctx, "reverse connect: write upgrade response: ", err)
|
||||
return
|
||||
}
|
||||
err = bufferedReadWriter.Flush()
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
s.logger.Error("reverse connect: flush upgrade response: ", err)
|
||||
s.logger.ErrorContext(ctx, "reverse connect: flush upgrade response: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
session, err := yamux.Client(conn, reverseYamuxConfig())
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
s.logger.Error("reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err)
|
||||
s.logger.ErrorContext(ctx, "reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -114,12 +114,12 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
|
||||
session.Close()
|
||||
return
|
||||
}
|
||||
s.logger.Info("reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr)
|
||||
s.logger.InfoContext(ctx, "reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr)
|
||||
|
||||
go func() {
|
||||
<-session.CloseChan()
|
||||
receiverCredential.clearReverseSession(session)
|
||||
s.logger.Warn("reverse connection lost for ", receiverCredential.tagName())
|
||||
s.logger.WarnContext(ctx, "reverse connection lost for ", receiverCredential.tagName())
|
||||
}()
|
||||
}
|
||||
|
||||
|
||||
@@ -325,13 +325,14 @@ func detectContextWindow(betaHeader string, totalInputTokens int64) int {
|
||||
}
|
||||
|
||||
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := log.ContextWithNewID(r.Context())
|
||||
if r.URL.Path == "/ccm/v1/status" {
|
||||
s.handleStatusEndpoint(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Path == "/ccm/v1/reverse" {
|
||||
s.handleReverseConnect(w, r)
|
||||
s.handleReverseConnect(ctx, w, r)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -344,20 +345,20 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if len(s.options.Users) > 0 {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
|
||||
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
||||
return
|
||||
}
|
||||
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if clientToken == authHeader {
|
||||
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
|
||||
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
||||
return
|
||||
}
|
||||
var ok bool
|
||||
username, ok = s.userManager.Authenticate(clientToken)
|
||||
if !ok {
|
||||
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
|
||||
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
|
||||
return
|
||||
}
|
||||
@@ -373,7 +374,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
var err error
|
||||
bodyBytes, err = io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
s.logger.Error("read request body: ", err)
|
||||
s.logger.ErrorContext(ctx, "read request body: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body")
|
||||
return
|
||||
}
|
||||
@@ -400,7 +401,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
var err error
|
||||
provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
|
||||
if err != nil {
|
||||
s.logger.Error("resolve credential: ", err)
|
||||
s.logger.ErrorContext(ctx, "resolve credential: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
@@ -448,7 +449,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
logParts = append(logParts, ", model=", modelDisplay)
|
||||
}
|
||||
s.logger.Debug(logParts...)
|
||||
s.logger.DebugContext(ctx, logParts...)
|
||||
}
|
||||
|
||||
if isFastModeRequest(anthropicBetaHeader) && selectedCredential.isExternal() {
|
||||
@@ -463,7 +464,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}()
|
||||
proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||
if err != nil {
|
||||
s.logger.Error("create proxy request: ", err)
|
||||
s.logger.ErrorContext(ctx, "create proxy request: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
|
||||
return
|
||||
}
|
||||
@@ -493,12 +494,12 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
response.Body.Close()
|
||||
s.logger.Info("retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
|
||||
s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
|
||||
requestContext.cancelRequest()
|
||||
requestContext = nextCredential.wrapRequestContext(r.Context())
|
||||
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||
if buildErr != nil {
|
||||
s.logger.Error("retry request: ", buildErr)
|
||||
s.logger.ErrorContext(ctx, "retry request: ", buildErr)
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
|
||||
return
|
||||
}
|
||||
@@ -511,7 +512,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request")
|
||||
return
|
||||
}
|
||||
s.logger.Error("retry request: ", retryErr)
|
||||
s.logger.ErrorContext(ctx, "retry request: ", retryErr)
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error())
|
||||
return
|
||||
}
|
||||
@@ -525,7 +526,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
s.logger.Error("upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
|
||||
s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
|
||||
go selectedCredential.pollUsage(s.ctx)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error",
|
||||
"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
|
||||
@@ -546,7 +547,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
usageTracker := selectedCredential.usageTrackerOrNil()
|
||||
if usageTracker != nil && response.StatusCode == http.StatusOK {
|
||||
s.handleResponseWithTracking(w, response, usageTracker, requestModel, anthropicBetaHeader, messagesCount, username)
|
||||
s.handleResponseWithTracking(ctx, w, response, usageTracker, requestModel, anthropicBetaHeader, messagesCount, username)
|
||||
} else {
|
||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||
if err == nil && mediaType != "text/event-stream" {
|
||||
@@ -555,7 +556,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
s.logger.Error("streaming not supported")
|
||||
s.logger.ErrorContext(ctx, "streaming not supported")
|
||||
return
|
||||
}
|
||||
buffer := make([]byte, buf.BufferSize)
|
||||
@@ -564,7 +565,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if n > 0 {
|
||||
_, writeError := w.Write(buffer[:n])
|
||||
if writeError != nil {
|
||||
s.logger.Error("write streaming response: ", writeError)
|
||||
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
@@ -576,7 +577,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, requestModel string, anthropicBetaHeader string, messagesCount int, username string) {
|
||||
func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, requestModel string, anthropicBetaHeader string, messagesCount int, username string) {
|
||||
weeklyCycleHint := extractWeeklyCycleHint(response.Header)
|
||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||
isStreaming := err == nil && mediaType == "text/event-stream"
|
||||
@@ -584,7 +585,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
||||
if !isStreaming {
|
||||
bodyBytes, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
s.logger.Error("read response body: ", err)
|
||||
s.logger.ErrorContext(ctx, "read response body: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -627,7 +628,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
||||
|
||||
flusher, ok := writer.(http.Flusher)
|
||||
if !ok {
|
||||
s.logger.Error("streaming not supported")
|
||||
s.logger.ErrorContext(ctx, "streaming not supported")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -690,7 +691,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
||||
|
||||
_, writeError := writer.Write(buffer[:n])
|
||||
if writeError != nil {
|
||||
s.logger.Error("write streaming response: ", writeError)
|
||||
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
@@ -52,7 +52,7 @@ func (l *yamuxNetListener) Addr() net.Addr {
|
||||
return l.session.Addr()
|
||||
}
|
||||
|
||||
func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *Service) handleReverseConnect(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Upgrade") != "reverse-proxy" {
|
||||
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", "missing Upgrade header")
|
||||
return
|
||||
@@ -71,21 +71,21 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
receiverCredential := s.findReceiverCredential(clientToken)
|
||||
if receiverCredential == nil {
|
||||
s.logger.Warn("reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential")
|
||||
s.logger.WarnContext(ctx, "reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid reverse token")
|
||||
return
|
||||
}
|
||||
|
||||
hijacker, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
s.logger.Error("reverse connect: hijack not supported")
|
||||
s.logger.ErrorContext(ctx, "reverse connect: hijack not supported")
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "hijack not supported")
|
||||
return
|
||||
}
|
||||
|
||||
conn, bufferedReadWriter, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
s.logger.Error("reverse connect: hijack: ", err)
|
||||
s.logger.ErrorContext(ctx, "reverse connect: hijack: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -93,20 +93,20 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
|
||||
_, err = bufferedReadWriter.WriteString(response)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
s.logger.Error("reverse connect: write upgrade response: ", err)
|
||||
s.logger.ErrorContext(ctx, "reverse connect: write upgrade response: ", err)
|
||||
return
|
||||
}
|
||||
err = bufferedReadWriter.Flush()
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
s.logger.Error("reverse connect: flush upgrade response: ", err)
|
||||
s.logger.ErrorContext(ctx, "reverse connect: flush upgrade response: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
session, err := yamux.Client(conn, reverseYamuxConfig())
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
s.logger.Error("reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err)
|
||||
s.logger.ErrorContext(ctx, "reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -114,12 +114,12 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
|
||||
session.Close()
|
||||
return
|
||||
}
|
||||
s.logger.Info("reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr)
|
||||
s.logger.InfoContext(ctx, "reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr)
|
||||
|
||||
go func() {
|
||||
<-session.CloseChan()
|
||||
receiverCredential.clearReverseSession(session)
|
||||
s.logger.Warn("reverse connection lost for ", receiverCredential.tagName())
|
||||
s.logger.WarnContext(ctx, "reverse connection lost for ", receiverCredential.tagName())
|
||||
}()
|
||||
}
|
||||
|
||||
|
||||
@@ -362,13 +362,14 @@ func (s *Service) resolveCredentialProvider(username string) (credentialProvider
|
||||
}
|
||||
|
||||
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := log.ContextWithNewID(r.Context())
|
||||
if r.URL.Path == "/ocm/v1/status" {
|
||||
s.handleStatusEndpoint(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Path == "/ocm/v1/reverse" {
|
||||
s.handleReverseConnect(w, r)
|
||||
s.handleReverseConnect(ctx, w, r)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -382,20 +383,20 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if len(s.options.Users) > 0 {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
|
||||
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
||||
return
|
||||
}
|
||||
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if clientToken == authHeader {
|
||||
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
|
||||
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
||||
return
|
||||
}
|
||||
var ok bool
|
||||
username, ok = s.userManager.Authenticate(clientToken)
|
||||
if !ok {
|
||||
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
|
||||
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
|
||||
return
|
||||
}
|
||||
@@ -411,7 +412,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
var err error
|
||||
provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
|
||||
if err != nil {
|
||||
s.logger.Error("resolve credential: ", err)
|
||||
s.logger.ErrorContext(ctx, "resolve credential: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
@@ -437,7 +438,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") {
|
||||
s.handleWebSocket(w, r, path, username, sessionID, userConfig, provider, selectedCredential, credentialFilter, isNew)
|
||||
s.handleWebSocket(ctx, w, r, path, username, sessionID, userConfig, provider, selectedCredential, credentialFilter, isNew)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -465,7 +466,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if isJSONRequest {
|
||||
bodyBytes, err = io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
s.logger.Error("read request body: ", err)
|
||||
s.logger.ErrorContext(ctx, "read request body: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body")
|
||||
return
|
||||
}
|
||||
@@ -495,7 +496,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if requestServiceTier == "priority" {
|
||||
logParts = append(logParts, ", fast")
|
||||
}
|
||||
s.logger.Debug(logParts...)
|
||||
s.logger.DebugContext(ctx, logParts...)
|
||||
}
|
||||
|
||||
requestContext := selectedCredential.wrapRequestContext(r.Context())
|
||||
@@ -504,7 +505,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}()
|
||||
proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||
if err != nil {
|
||||
s.logger.Error("create proxy request: ", err)
|
||||
s.logger.ErrorContext(ctx, "create proxy request: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
|
||||
return
|
||||
}
|
||||
@@ -535,12 +536,12 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
response.Body.Close()
|
||||
s.logger.Info("retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
|
||||
s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
|
||||
requestContext.cancelRequest()
|
||||
requestContext = nextCredential.wrapRequestContext(r.Context())
|
||||
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||
if buildErr != nil {
|
||||
s.logger.Error("retry request: ", buildErr)
|
||||
s.logger.ErrorContext(ctx, "retry request: ", buildErr)
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
|
||||
return
|
||||
}
|
||||
@@ -553,7 +554,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request")
|
||||
return
|
||||
}
|
||||
s.logger.Error("retry request: ", retryErr)
|
||||
s.logger.ErrorContext(ctx, "retry request: ", retryErr)
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error())
|
||||
return
|
||||
}
|
||||
@@ -567,7 +568,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
s.logger.Error("upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
|
||||
s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
|
||||
go selectedCredential.pollUsage(s.ctx)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error",
|
||||
"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
|
||||
@@ -589,7 +590,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
usageTracker := selectedCredential.usageTrackerOrNil()
|
||||
if usageTracker != nil && response.StatusCode == http.StatusOK &&
|
||||
(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) {
|
||||
s.handleResponseWithTracking(w, response, usageTracker, path, requestModel, username)
|
||||
s.handleResponseWithTracking(ctx, w, response, usageTracker, path, requestModel, username)
|
||||
} else {
|
||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||
if err == nil && mediaType != "text/event-stream" {
|
||||
@@ -598,7 +599,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
s.logger.Error("streaming not supported")
|
||||
s.logger.ErrorContext(ctx, "streaming not supported")
|
||||
return
|
||||
}
|
||||
buffer := make([]byte, buf.BufferSize)
|
||||
@@ -607,7 +608,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if n > 0 {
|
||||
_, writeError := w.Write(buffer[:n])
|
||||
if writeError != nil {
|
||||
s.logger.Error("write streaming response: ", writeError)
|
||||
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
@@ -619,7 +620,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, path string, requestModel string, username string) {
|
||||
func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, path string, requestModel string, username string) {
|
||||
isChatCompletions := path == "/v1/chat/completions"
|
||||
weeklyCycleHint := extractWeeklyCycleHint(response.Header)
|
||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||
@@ -630,7 +631,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
||||
if !isStreaming {
|
||||
bodyBytes, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
s.logger.Error("read response body: ", err)
|
||||
s.logger.ErrorContext(ctx, "read response body: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -683,7 +684,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
||||
|
||||
flusher, ok := writer.(http.Flusher)
|
||||
if !ok {
|
||||
s.logger.Error("streaming not supported")
|
||||
s.logger.ErrorContext(ctx, "streaming not supported")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -760,7 +761,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
||||
|
||||
_, writeError := writer.Write(buffer[:n])
|
||||
if writeError != nil {
|
||||
s.logger.Error("write streaming response: ", writeError)
|
||||
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
@@ -82,6 +82,7 @@ func isForwardableWebSocketRequestHeader(key string) bool {
|
||||
}
|
||||
|
||||
func (s *Service) handleWebSocket(
|
||||
ctx context.Context,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
path string,
|
||||
@@ -105,7 +106,7 @@ func (s *Service) handleWebSocket(
|
||||
for {
|
||||
accessToken, accessErr := selectedCredential.getAccessToken()
|
||||
if accessErr != nil {
|
||||
s.logger.Error("get access token for websocket: ", accessErr)
|
||||
s.logger.ErrorContext(ctx, "get access token for websocket: ", accessErr)
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "authentication failed")
|
||||
return
|
||||
}
|
||||
@@ -190,14 +191,14 @@ func (s *Service) handleWebSocket(
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited")
|
||||
return
|
||||
}
|
||||
s.logger.Info("retrying websocket with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
|
||||
s.logger.InfoContext(ctx, "retrying websocket with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
|
||||
selectedCredential = nextCredential
|
||||
continue
|
||||
}
|
||||
if statusCode > 0 && statusResponseBody != "" {
|
||||
s.logger.Error("dial upstream websocket: status ", statusCode, " body: ", statusResponseBody)
|
||||
s.logger.ErrorContext(ctx, "dial upstream websocket: status ", statusCode, " body: ", statusResponseBody)
|
||||
} else {
|
||||
s.logger.Error("dial upstream websocket: ", err)
|
||||
s.logger.ErrorContext(ctx, "dial upstream websocket: ", err)
|
||||
}
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", "upstream websocket connection failed")
|
||||
return
|
||||
@@ -226,7 +227,7 @@ func (s *Service) handleWebSocket(
|
||||
}
|
||||
clientConn, _, _, err := clientUpgrader.Upgrade(r, w)
|
||||
if err != nil {
|
||||
s.logger.Error("upgrade client websocket: ", err)
|
||||
s.logger.ErrorContext(ctx, "upgrade client websocket: ", err)
|
||||
upstreamConn.Close()
|
||||
return
|
||||
}
|
||||
@@ -258,23 +259,23 @@ func (s *Service) handleWebSocket(
|
||||
go func() {
|
||||
defer waitGroup.Done()
|
||||
defer session.Close()
|
||||
s.proxyWebSocketClientToUpstream(clientConn, upstreamConn, selectedCredential, modelChannel, isNew, username, sessionID)
|
||||
s.proxyWebSocketClientToUpstream(ctx, clientConn, upstreamConn, selectedCredential, modelChannel, isNew, username, sessionID)
|
||||
}()
|
||||
go func() {
|
||||
defer waitGroup.Done()
|
||||
defer session.Close()
|
||||
s.proxyWebSocketUpstreamToClient(upstreamReadWriter, clientConn, selectedCredential, userConfig, provider, modelChannel, username, weeklyCycleHint)
|
||||
s.proxyWebSocketUpstreamToClient(ctx, upstreamReadWriter, clientConn, selectedCredential, userConfig, provider, modelChannel, username, weeklyCycleHint)
|
||||
}()
|
||||
waitGroup.Wait()
|
||||
}
|
||||
|
||||
func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamConn net.Conn, selectedCredential credential, modelChannel chan<- string, isNew bool, username string, sessionID string) {
|
||||
func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn net.Conn, upstreamConn net.Conn, selectedCredential credential, modelChannel chan<- string, isNew bool, username string, sessionID string) {
|
||||
logged := false
|
||||
for {
|
||||
data, opCode, err := wsutil.ReadClientData(clientConn)
|
||||
if err != nil {
|
||||
if !E.IsClosedOrCanceled(err) {
|
||||
s.logger.Debug("read client websocket: ", err)
|
||||
s.logger.DebugContext(ctx, "read client websocket: ", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -299,7 +300,7 @@ func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamCo
|
||||
if request.ServiceTier == "priority" {
|
||||
logParts = append(logParts, ", fast")
|
||||
}
|
||||
s.logger.Debug(logParts...)
|
||||
s.logger.DebugContext(ctx, logParts...)
|
||||
}
|
||||
if selectedCredential.usageTrackerOrNil() != nil {
|
||||
select {
|
||||
@@ -313,21 +314,21 @@ func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamCo
|
||||
err = wsutil.WriteClientMessage(upstreamConn, opCode, data)
|
||||
if err != nil {
|
||||
if !E.IsClosedOrCanceled(err) {
|
||||
s.logger.Debug("write upstream websocket: ", err)
|
||||
s.logger.DebugContext(ctx, "write upstream websocket: ", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWriter, clientConn net.Conn, selectedCredential credential, userConfig *option.OCMUser, provider credentialProvider, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) {
|
||||
func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamReadWriter io.ReadWriter, clientConn net.Conn, selectedCredential credential, userConfig *option.OCMUser, provider credentialProvider, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) {
|
||||
usageTracker := selectedCredential.usageTrackerOrNil()
|
||||
var requestModel string
|
||||
for {
|
||||
data, opCode, err := wsutil.ReadServerData(upstreamReadWriter)
|
||||
if err != nil {
|
||||
if !E.IsClosedOrCanceled(err) {
|
||||
s.logger.Debug("read upstream websocket: ", err)
|
||||
s.logger.DebugContext(ctx, "read upstream websocket: ", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -367,7 +368,7 @@ func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWrite
|
||||
err = wsutil.WriteServerMessage(clientConn, opCode, data)
|
||||
if err != nil {
|
||||
if !E.IsClosedOrCanceled(err) {
|
||||
s.logger.Debug("write client websocket: ", err)
|
||||
s.logger.DebugContext(ctx, "write client websocket: ", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user