ccm,ocm: add request ID context to HTTP request logging

This commit is contained in:
世界
2026-03-14 16:14:24 +08:00
parent 016e5e1b12
commit 5b29fd3be4
5 changed files with 74 additions and 71 deletions

View File

@@ -52,7 +52,7 @@ func (l *yamuxNetListener) Addr() net.Addr {
return l.session.Addr() return l.session.Addr()
} }
func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) { func (s *Service) handleReverseConnect(ctx context.Context, w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Upgrade") != "reverse-proxy" { if r.Header.Get("Upgrade") != "reverse-proxy" {
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", "missing Upgrade header") writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", "missing Upgrade header")
return return
@@ -71,21 +71,21 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
receiverCredential := s.findReceiverCredential(clientToken) receiverCredential := s.findReceiverCredential(clientToken)
if receiverCredential == nil { if receiverCredential == nil {
s.logger.Warn("reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential") s.logger.WarnContext(ctx, "reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid reverse token") writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid reverse token")
return return
} }
hijacker, ok := w.(http.Hijacker) hijacker, ok := w.(http.Hijacker)
if !ok { if !ok {
s.logger.Error("reverse connect: hijack not supported") s.logger.ErrorContext(ctx, "reverse connect: hijack not supported")
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "hijack not supported") writeJSONError(w, r, http.StatusInternalServerError, "api_error", "hijack not supported")
return return
} }
conn, bufferedReadWriter, err := hijacker.Hijack() conn, bufferedReadWriter, err := hijacker.Hijack()
if err != nil { if err != nil {
s.logger.Error("reverse connect: hijack: ", err) s.logger.ErrorContext(ctx, "reverse connect: hijack: ", err)
return return
} }
@@ -93,20 +93,20 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
_, err = bufferedReadWriter.WriteString(response) _, err = bufferedReadWriter.WriteString(response)
if err != nil { if err != nil {
conn.Close() conn.Close()
s.logger.Error("reverse connect: write upgrade response: ", err) s.logger.ErrorContext(ctx, "reverse connect: write upgrade response: ", err)
return return
} }
err = bufferedReadWriter.Flush() err = bufferedReadWriter.Flush()
if err != nil { if err != nil {
conn.Close() conn.Close()
s.logger.Error("reverse connect: flush upgrade response: ", err) s.logger.ErrorContext(ctx, "reverse connect: flush upgrade response: ", err)
return return
} }
session, err := yamux.Client(conn, reverseYamuxConfig()) session, err := yamux.Client(conn, reverseYamuxConfig())
if err != nil { if err != nil {
conn.Close() conn.Close()
s.logger.Error("reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err) s.logger.ErrorContext(ctx, "reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err)
return return
} }
@@ -114,12 +114,12 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
session.Close() session.Close()
return return
} }
s.logger.Info("reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr) s.logger.InfoContext(ctx, "reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr)
go func() { go func() {
<-session.CloseChan() <-session.CloseChan()
receiverCredential.clearReverseSession(session) receiverCredential.clearReverseSession(session)
s.logger.Warn("reverse connection lost for ", receiverCredential.tagName()) s.logger.WarnContext(ctx, "reverse connection lost for ", receiverCredential.tagName())
}() }()
} }

View File

@@ -325,13 +325,14 @@ func detectContextWindow(betaHeader string, totalInputTokens int64) int {
} }
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := log.ContextWithNewID(r.Context())
if r.URL.Path == "/ccm/v1/status" { if r.URL.Path == "/ccm/v1/status" {
s.handleStatusEndpoint(w, r) s.handleStatusEndpoint(w, r)
return return
} }
if r.URL.Path == "/ccm/v1/reverse" { if r.URL.Path == "/ccm/v1/reverse" {
s.handleReverseConnect(w, r) s.handleReverseConnect(ctx, w, r)
return return
} }
@@ -344,20 +345,20 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if len(s.options.Users) > 0 { if len(s.options.Users) > 0 {
authHeader := r.Header.Get("Authorization") authHeader := r.Header.Get("Authorization")
if authHeader == "" { if authHeader == "" {
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": missing Authorization header") s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
return return
} }
clientToken := strings.TrimPrefix(authHeader, "Bearer ") clientToken := strings.TrimPrefix(authHeader, "Bearer ")
if clientToken == authHeader { if clientToken == authHeader {
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format") s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
return return
} }
var ok bool var ok bool
username, ok = s.userManager.Authenticate(clientToken) username, ok = s.userManager.Authenticate(clientToken)
if !ok { if !ok {
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken) s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key") writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
return return
} }
@@ -373,7 +374,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var err error var err error
bodyBytes, err = io.ReadAll(r.Body) bodyBytes, err = io.ReadAll(r.Body)
if err != nil { if err != nil {
s.logger.Error("read request body: ", err) s.logger.ErrorContext(ctx, "read request body: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body") writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body")
return return
} }
@@ -400,7 +401,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var err error var err error
provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
if err != nil { if err != nil {
s.logger.Error("resolve credential: ", err) s.logger.ErrorContext(ctx, "resolve credential: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
return return
} }
@@ -448,7 +449,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
logParts = append(logParts, ", model=", modelDisplay) logParts = append(logParts, ", model=", modelDisplay)
} }
s.logger.Debug(logParts...) s.logger.DebugContext(ctx, logParts...)
} }
if isFastModeRequest(anthropicBetaHeader) && selectedCredential.isExternal() { if isFastModeRequest(anthropicBetaHeader) && selectedCredential.isExternal() {
@@ -463,7 +464,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}() }()
proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if err != nil { if err != nil {
s.logger.Error("create proxy request: ", err) s.logger.ErrorContext(ctx, "create proxy request: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error") writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
return return
} }
@@ -493,12 +494,12 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
response.Body.Close() response.Body.Close()
s.logger.Info("retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
requestContext.cancelRequest() requestContext.cancelRequest()
requestContext = nextCredential.wrapRequestContext(r.Context()) requestContext = nextCredential.wrapRequestContext(r.Context())
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if buildErr != nil { if buildErr != nil {
s.logger.Error("retry request: ", buildErr) s.logger.ErrorContext(ctx, "retry request: ", buildErr)
writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error()) writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
return return
} }
@@ -511,7 +512,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request") writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request")
return return
} }
s.logger.Error("retry request: ", retryErr) s.logger.ErrorContext(ctx, "retry request: ", retryErr)
writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error()) writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error())
return return
} }
@@ -525,7 +526,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests { if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
body, _ := io.ReadAll(response.Body) body, _ := io.ReadAll(response.Body)
s.logger.Error("upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body)) s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
go selectedCredential.pollUsage(s.ctx) go selectedCredential.pollUsage(s.ctx)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", writeJSONError(w, r, http.StatusInternalServerError, "api_error",
"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body)) "proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
@@ -546,7 +547,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
usageTracker := selectedCredential.usageTrackerOrNil() usageTracker := selectedCredential.usageTrackerOrNil()
if usageTracker != nil && response.StatusCode == http.StatusOK { if usageTracker != nil && response.StatusCode == http.StatusOK {
s.handleResponseWithTracking(w, response, usageTracker, requestModel, anthropicBetaHeader, messagesCount, username) s.handleResponseWithTracking(ctx, w, response, usageTracker, requestModel, anthropicBetaHeader, messagesCount, username)
} else { } else {
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
if err == nil && mediaType != "text/event-stream" { if err == nil && mediaType != "text/event-stream" {
@@ -555,7 +556,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
flusher, ok := w.(http.Flusher) flusher, ok := w.(http.Flusher)
if !ok { if !ok {
s.logger.Error("streaming not supported") s.logger.ErrorContext(ctx, "streaming not supported")
return return
} }
buffer := make([]byte, buf.BufferSize) buffer := make([]byte, buf.BufferSize)
@@ -564,7 +565,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if n > 0 { if n > 0 {
_, writeError := w.Write(buffer[:n]) _, writeError := w.Write(buffer[:n])
if writeError != nil { if writeError != nil {
s.logger.Error("write streaming response: ", writeError) s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
return return
} }
flusher.Flush() flusher.Flush()
@@ -576,7 +577,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
} }
func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, requestModel string, anthropicBetaHeader string, messagesCount int, username string) { func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, requestModel string, anthropicBetaHeader string, messagesCount int, username string) {
weeklyCycleHint := extractWeeklyCycleHint(response.Header) weeklyCycleHint := extractWeeklyCycleHint(response.Header)
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
isStreaming := err == nil && mediaType == "text/event-stream" isStreaming := err == nil && mediaType == "text/event-stream"
@@ -584,7 +585,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
if !isStreaming { if !isStreaming {
bodyBytes, err := io.ReadAll(response.Body) bodyBytes, err := io.ReadAll(response.Body)
if err != nil { if err != nil {
s.logger.Error("read response body: ", err) s.logger.ErrorContext(ctx, "read response body: ", err)
return return
} }
@@ -627,7 +628,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
flusher, ok := writer.(http.Flusher) flusher, ok := writer.(http.Flusher)
if !ok { if !ok {
s.logger.Error("streaming not supported") s.logger.ErrorContext(ctx, "streaming not supported")
return return
} }
@@ -690,7 +691,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
_, writeError := writer.Write(buffer[:n]) _, writeError := writer.Write(buffer[:n])
if writeError != nil { if writeError != nil {
s.logger.Error("write streaming response: ", writeError) s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
return return
} }
flusher.Flush() flusher.Flush()

View File

@@ -52,7 +52,7 @@ func (l *yamuxNetListener) Addr() net.Addr {
return l.session.Addr() return l.session.Addr()
} }
func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) { func (s *Service) handleReverseConnect(ctx context.Context, w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Upgrade") != "reverse-proxy" { if r.Header.Get("Upgrade") != "reverse-proxy" {
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", "missing Upgrade header") writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", "missing Upgrade header")
return return
@@ -71,21 +71,21 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
receiverCredential := s.findReceiverCredential(clientToken) receiverCredential := s.findReceiverCredential(clientToken)
if receiverCredential == nil { if receiverCredential == nil {
s.logger.Warn("reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential") s.logger.WarnContext(ctx, "reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid reverse token") writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid reverse token")
return return
} }
hijacker, ok := w.(http.Hijacker) hijacker, ok := w.(http.Hijacker)
if !ok { if !ok {
s.logger.Error("reverse connect: hijack not supported") s.logger.ErrorContext(ctx, "reverse connect: hijack not supported")
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "hijack not supported") writeJSONError(w, r, http.StatusInternalServerError, "api_error", "hijack not supported")
return return
} }
conn, bufferedReadWriter, err := hijacker.Hijack() conn, bufferedReadWriter, err := hijacker.Hijack()
if err != nil { if err != nil {
s.logger.Error("reverse connect: hijack: ", err) s.logger.ErrorContext(ctx, "reverse connect: hijack: ", err)
return return
} }
@@ -93,20 +93,20 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
_, err = bufferedReadWriter.WriteString(response) _, err = bufferedReadWriter.WriteString(response)
if err != nil { if err != nil {
conn.Close() conn.Close()
s.logger.Error("reverse connect: write upgrade response: ", err) s.logger.ErrorContext(ctx, "reverse connect: write upgrade response: ", err)
return return
} }
err = bufferedReadWriter.Flush() err = bufferedReadWriter.Flush()
if err != nil { if err != nil {
conn.Close() conn.Close()
s.logger.Error("reverse connect: flush upgrade response: ", err) s.logger.ErrorContext(ctx, "reverse connect: flush upgrade response: ", err)
return return
} }
session, err := yamux.Client(conn, reverseYamuxConfig()) session, err := yamux.Client(conn, reverseYamuxConfig())
if err != nil { if err != nil {
conn.Close() conn.Close()
s.logger.Error("reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err) s.logger.ErrorContext(ctx, "reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err)
return return
} }
@@ -114,12 +114,12 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
session.Close() session.Close()
return return
} }
s.logger.Info("reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr) s.logger.InfoContext(ctx, "reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr)
go func() { go func() {
<-session.CloseChan() <-session.CloseChan()
receiverCredential.clearReverseSession(session) receiverCredential.clearReverseSession(session)
s.logger.Warn("reverse connection lost for ", receiverCredential.tagName()) s.logger.WarnContext(ctx, "reverse connection lost for ", receiverCredential.tagName())
}() }()
} }

View File

@@ -362,13 +362,14 @@ func (s *Service) resolveCredentialProvider(username string) (credentialProvider
} }
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := log.ContextWithNewID(r.Context())
if r.URL.Path == "/ocm/v1/status" { if r.URL.Path == "/ocm/v1/status" {
s.handleStatusEndpoint(w, r) s.handleStatusEndpoint(w, r)
return return
} }
if r.URL.Path == "/ocm/v1/reverse" { if r.URL.Path == "/ocm/v1/reverse" {
s.handleReverseConnect(w, r) s.handleReverseConnect(ctx, w, r)
return return
} }
@@ -382,20 +383,20 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if len(s.options.Users) > 0 { if len(s.options.Users) > 0 {
authHeader := r.Header.Get("Authorization") authHeader := r.Header.Get("Authorization")
if authHeader == "" { if authHeader == "" {
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": missing Authorization header") s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
return return
} }
clientToken := strings.TrimPrefix(authHeader, "Bearer ") clientToken := strings.TrimPrefix(authHeader, "Bearer ")
if clientToken == authHeader { if clientToken == authHeader {
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format") s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
return return
} }
var ok bool var ok bool
username, ok = s.userManager.Authenticate(clientToken) username, ok = s.userManager.Authenticate(clientToken)
if !ok { if !ok {
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken) s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key") writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
return return
} }
@@ -411,7 +412,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var err error var err error
provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
if err != nil { if err != nil {
s.logger.Error("resolve credential: ", err) s.logger.ErrorContext(ctx, "resolve credential: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
return return
} }
@@ -437,7 +438,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") { if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") {
s.handleWebSocket(w, r, path, username, sessionID, userConfig, provider, selectedCredential, credentialFilter, isNew) s.handleWebSocket(ctx, w, r, path, username, sessionID, userConfig, provider, selectedCredential, credentialFilter, isNew)
return return
} }
@@ -465,7 +466,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if isJSONRequest { if isJSONRequest {
bodyBytes, err = io.ReadAll(r.Body) bodyBytes, err = io.ReadAll(r.Body)
if err != nil { if err != nil {
s.logger.Error("read request body: ", err) s.logger.ErrorContext(ctx, "read request body: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body") writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body")
return return
} }
@@ -495,7 +496,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if requestServiceTier == "priority" { if requestServiceTier == "priority" {
logParts = append(logParts, ", fast") logParts = append(logParts, ", fast")
} }
s.logger.Debug(logParts...) s.logger.DebugContext(ctx, logParts...)
} }
requestContext := selectedCredential.wrapRequestContext(r.Context()) requestContext := selectedCredential.wrapRequestContext(r.Context())
@@ -504,7 +505,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}() }()
proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if err != nil { if err != nil {
s.logger.Error("create proxy request: ", err) s.logger.ErrorContext(ctx, "create proxy request: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error") writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
return return
} }
@@ -535,12 +536,12 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
response.Body.Close() response.Body.Close()
s.logger.Info("retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
requestContext.cancelRequest() requestContext.cancelRequest()
requestContext = nextCredential.wrapRequestContext(r.Context()) requestContext = nextCredential.wrapRequestContext(r.Context())
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if buildErr != nil { if buildErr != nil {
s.logger.Error("retry request: ", buildErr) s.logger.ErrorContext(ctx, "retry request: ", buildErr)
writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error()) writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
return return
} }
@@ -553,7 +554,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request") writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request")
return return
} }
s.logger.Error("retry request: ", retryErr) s.logger.ErrorContext(ctx, "retry request: ", retryErr)
writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error()) writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error())
return return
} }
@@ -567,7 +568,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests { if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
body, _ := io.ReadAll(response.Body) body, _ := io.ReadAll(response.Body)
s.logger.Error("upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body)) s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
go selectedCredential.pollUsage(s.ctx) go selectedCredential.pollUsage(s.ctx)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", writeJSONError(w, r, http.StatusInternalServerError, "api_error",
"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body)) "proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
@@ -589,7 +590,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
usageTracker := selectedCredential.usageTrackerOrNil() usageTracker := selectedCredential.usageTrackerOrNil()
if usageTracker != nil && response.StatusCode == http.StatusOK && if usageTracker != nil && response.StatusCode == http.StatusOK &&
(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) { (path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) {
s.handleResponseWithTracking(w, response, usageTracker, path, requestModel, username) s.handleResponseWithTracking(ctx, w, response, usageTracker, path, requestModel, username)
} else { } else {
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
if err == nil && mediaType != "text/event-stream" { if err == nil && mediaType != "text/event-stream" {
@@ -598,7 +599,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
flusher, ok := w.(http.Flusher) flusher, ok := w.(http.Flusher)
if !ok { if !ok {
s.logger.Error("streaming not supported") s.logger.ErrorContext(ctx, "streaming not supported")
return return
} }
buffer := make([]byte, buf.BufferSize) buffer := make([]byte, buf.BufferSize)
@@ -607,7 +608,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if n > 0 { if n > 0 {
_, writeError := w.Write(buffer[:n]) _, writeError := w.Write(buffer[:n])
if writeError != nil { if writeError != nil {
s.logger.Error("write streaming response: ", writeError) s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
return return
} }
flusher.Flush() flusher.Flush()
@@ -619,7 +620,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
} }
func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, path string, requestModel string, username string) { func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, path string, requestModel string, username string) {
isChatCompletions := path == "/v1/chat/completions" isChatCompletions := path == "/v1/chat/completions"
weeklyCycleHint := extractWeeklyCycleHint(response.Header) weeklyCycleHint := extractWeeklyCycleHint(response.Header)
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
@@ -630,7 +631,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
if !isStreaming { if !isStreaming {
bodyBytes, err := io.ReadAll(response.Body) bodyBytes, err := io.ReadAll(response.Body)
if err != nil { if err != nil {
s.logger.Error("read response body: ", err) s.logger.ErrorContext(ctx, "read response body: ", err)
return return
} }
@@ -683,7 +684,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
flusher, ok := writer.(http.Flusher) flusher, ok := writer.(http.Flusher)
if !ok { if !ok {
s.logger.Error("streaming not supported") s.logger.ErrorContext(ctx, "streaming not supported")
return return
} }
@@ -760,7 +761,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
_, writeError := writer.Write(buffer[:n]) _, writeError := writer.Write(buffer[:n])
if writeError != nil { if writeError != nil {
s.logger.Error("write streaming response: ", writeError) s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
return return
} }
flusher.Flush() flusher.Flush()

View File

@@ -82,6 +82,7 @@ func isForwardableWebSocketRequestHeader(key string) bool {
} }
func (s *Service) handleWebSocket( func (s *Service) handleWebSocket(
ctx context.Context,
w http.ResponseWriter, w http.ResponseWriter,
r *http.Request, r *http.Request,
path string, path string,
@@ -105,7 +106,7 @@ func (s *Service) handleWebSocket(
for { for {
accessToken, accessErr := selectedCredential.getAccessToken() accessToken, accessErr := selectedCredential.getAccessToken()
if accessErr != nil { if accessErr != nil {
s.logger.Error("get access token for websocket: ", accessErr) s.logger.ErrorContext(ctx, "get access token for websocket: ", accessErr)
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "authentication failed") writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "authentication failed")
return return
} }
@@ -190,14 +191,14 @@ func (s *Service) handleWebSocket(
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited") writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited")
return return
} }
s.logger.Info("retrying websocket with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) s.logger.InfoContext(ctx, "retrying websocket with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
selectedCredential = nextCredential selectedCredential = nextCredential
continue continue
} }
if statusCode > 0 && statusResponseBody != "" { if statusCode > 0 && statusResponseBody != "" {
s.logger.Error("dial upstream websocket: status ", statusCode, " body: ", statusResponseBody) s.logger.ErrorContext(ctx, "dial upstream websocket: status ", statusCode, " body: ", statusResponseBody)
} else { } else {
s.logger.Error("dial upstream websocket: ", err) s.logger.ErrorContext(ctx, "dial upstream websocket: ", err)
} }
writeJSONError(w, r, http.StatusBadGateway, "api_error", "upstream websocket connection failed") writeJSONError(w, r, http.StatusBadGateway, "api_error", "upstream websocket connection failed")
return return
@@ -226,7 +227,7 @@ func (s *Service) handleWebSocket(
} }
clientConn, _, _, err := clientUpgrader.Upgrade(r, w) clientConn, _, _, err := clientUpgrader.Upgrade(r, w)
if err != nil { if err != nil {
s.logger.Error("upgrade client websocket: ", err) s.logger.ErrorContext(ctx, "upgrade client websocket: ", err)
upstreamConn.Close() upstreamConn.Close()
return return
} }
@@ -258,23 +259,23 @@ func (s *Service) handleWebSocket(
go func() { go func() {
defer waitGroup.Done() defer waitGroup.Done()
defer session.Close() defer session.Close()
s.proxyWebSocketClientToUpstream(clientConn, upstreamConn, selectedCredential, modelChannel, isNew, username, sessionID) s.proxyWebSocketClientToUpstream(ctx, clientConn, upstreamConn, selectedCredential, modelChannel, isNew, username, sessionID)
}() }()
go func() { go func() {
defer waitGroup.Done() defer waitGroup.Done()
defer session.Close() defer session.Close()
s.proxyWebSocketUpstreamToClient(upstreamReadWriter, clientConn, selectedCredential, userConfig, provider, modelChannel, username, weeklyCycleHint) s.proxyWebSocketUpstreamToClient(ctx, upstreamReadWriter, clientConn, selectedCredential, userConfig, provider, modelChannel, username, weeklyCycleHint)
}() }()
waitGroup.Wait() waitGroup.Wait()
} }
func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamConn net.Conn, selectedCredential credential, modelChannel chan<- string, isNew bool, username string, sessionID string) { func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn net.Conn, upstreamConn net.Conn, selectedCredential credential, modelChannel chan<- string, isNew bool, username string, sessionID string) {
logged := false logged := false
for { for {
data, opCode, err := wsutil.ReadClientData(clientConn) data, opCode, err := wsutil.ReadClientData(clientConn)
if err != nil { if err != nil {
if !E.IsClosedOrCanceled(err) { if !E.IsClosedOrCanceled(err) {
s.logger.Debug("read client websocket: ", err) s.logger.DebugContext(ctx, "read client websocket: ", err)
} }
return return
} }
@@ -299,7 +300,7 @@ func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamCo
if request.ServiceTier == "priority" { if request.ServiceTier == "priority" {
logParts = append(logParts, ", fast") logParts = append(logParts, ", fast")
} }
s.logger.Debug(logParts...) s.logger.DebugContext(ctx, logParts...)
} }
if selectedCredential.usageTrackerOrNil() != nil { if selectedCredential.usageTrackerOrNil() != nil {
select { select {
@@ -313,21 +314,21 @@ func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamCo
err = wsutil.WriteClientMessage(upstreamConn, opCode, data) err = wsutil.WriteClientMessage(upstreamConn, opCode, data)
if err != nil { if err != nil {
if !E.IsClosedOrCanceled(err) { if !E.IsClosedOrCanceled(err) {
s.logger.Debug("write upstream websocket: ", err) s.logger.DebugContext(ctx, "write upstream websocket: ", err)
} }
return return
} }
} }
} }
func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWriter, clientConn net.Conn, selectedCredential credential, userConfig *option.OCMUser, provider credentialProvider, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) { func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamReadWriter io.ReadWriter, clientConn net.Conn, selectedCredential credential, userConfig *option.OCMUser, provider credentialProvider, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) {
usageTracker := selectedCredential.usageTrackerOrNil() usageTracker := selectedCredential.usageTrackerOrNil()
var requestModel string var requestModel string
for { for {
data, opCode, err := wsutil.ReadServerData(upstreamReadWriter) data, opCode, err := wsutil.ReadServerData(upstreamReadWriter)
if err != nil { if err != nil {
if !E.IsClosedOrCanceled(err) { if !E.IsClosedOrCanceled(err) {
s.logger.Debug("read upstream websocket: ", err) s.logger.DebugContext(ctx, "read upstream websocket: ", err)
} }
return return
} }
@@ -367,7 +368,7 @@ func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWrite
err = wsutil.WriteServerMessage(clientConn, opCode, data) err = wsutil.WriteServerMessage(clientConn, opCode, data)
if err != nil { if err != nil {
if !E.IsClosedOrCanceled(err) { if !E.IsClosedOrCanceled(err) {
s.logger.Debug("write client websocket: ", err) s.logger.DebugContext(ctx, "write client websocket: ", err)
} }
return return
} }