diff --git a/service/ccm/reverse.go b/service/ccm/reverse.go index 625e55a9d..62a101117 100644 --- a/service/ccm/reverse.go +++ b/service/ccm/reverse.go @@ -52,7 +52,7 @@ func (l *yamuxNetListener) Addr() net.Addr { return l.session.Addr() } -func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) { +func (s *Service) handleReverseConnect(ctx context.Context, w http.ResponseWriter, r *http.Request) { if r.Header.Get("Upgrade") != "reverse-proxy" { writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", "missing Upgrade header") return @@ -71,21 +71,21 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) { receiverCredential := s.findReceiverCredential(clientToken) if receiverCredential == nil { - s.logger.Warn("reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential") + s.logger.WarnContext(ctx, "reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential") writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid reverse token") return } hijacker, ok := w.(http.Hijacker) if !ok { - s.logger.Error("reverse connect: hijack not supported") + s.logger.ErrorContext(ctx, "reverse connect: hijack not supported") writeJSONError(w, r, http.StatusInternalServerError, "api_error", "hijack not supported") return } conn, bufferedReadWriter, err := hijacker.Hijack() if err != nil { - s.logger.Error("reverse connect: hijack: ", err) + s.logger.ErrorContext(ctx, "reverse connect: hijack: ", err) return } @@ -93,20 +93,20 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) { _, err = bufferedReadWriter.WriteString(response) if err != nil { conn.Close() - s.logger.Error("reverse connect: write upgrade response: ", err) + s.logger.ErrorContext(ctx, "reverse connect: write upgrade response: ", err) return } err = bufferedReadWriter.Flush() if err != nil { conn.Close() - s.logger.Error("reverse connect: flush upgrade response: ", err) + s.logger.ErrorContext(ctx, "reverse connect: flush upgrade response: ", err) return } session, err := yamux.Client(conn, reverseYamuxConfig()) if err != nil { conn.Close() - s.logger.Error("reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err) + s.logger.ErrorContext(ctx, "reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err) return } @@ -114,12 +114,12 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) { session.Close() return } - s.logger.Info("reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr) + s.logger.InfoContext(ctx, "reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr) go func() { <-session.CloseChan() receiverCredential.clearReverseSession(session) - s.logger.Warn("reverse connection lost for ", receiverCredential.tagName()) + s.logger.WarnContext(ctx, "reverse connection lost for ", receiverCredential.tagName()) }() } diff --git a/service/ccm/service.go b/service/ccm/service.go index e8aaebdc4..81e3b38a5 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -325,13 +325,14 @@ func detectContextWindow(betaHeader string, totalInputTokens int64) int { } func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := log.ContextWithNewID(r.Context()) if r.URL.Path == "/ccm/v1/status" { s.handleStatusEndpoint(w, r) return } if r.URL.Path == "/ccm/v1/reverse" { - s.handleReverseConnect(w, r) + s.handleReverseConnect(ctx, w, r) return } @@ -344,20 +345,20 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { if len(s.options.Users) > 0 { authHeader := r.Header.Get("Authorization") if authHeader == "" { - s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": missing Authorization header") + s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header") writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") return } clientToken := strings.TrimPrefix(authHeader, "Bearer ") if clientToken == authHeader { - s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format") + s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format") writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") return } var ok bool username, ok = s.userManager.Authenticate(clientToken) if !ok { - s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken) + s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken) writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key") return } @@ -373,7 +374,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { var err error bodyBytes, err = io.ReadAll(r.Body) if err != nil { - s.logger.Error("read request body: ", err) + s.logger.ErrorContext(ctx, "read request body: ", err) writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body") return } @@ -400,7 +401,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { var err error provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) if err != nil { - s.logger.Error("resolve credential: ", err) + s.logger.ErrorContext(ctx, "resolve credential: ", err) writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) return } @@ -448,7 +449,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } logParts = append(logParts, ", model=", modelDisplay) } - s.logger.Debug(logParts...) + s.logger.DebugContext(ctx, logParts...) } if isFastModeRequest(anthropicBetaHeader) && selectedCredential.isExternal() { @@ -463,7 +464,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { }() proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) if err != nil { - s.logger.Error("create proxy request: ", err) + s.logger.ErrorContext(ctx, "create proxy request: ", err) writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error") return } @@ -493,12 +494,12 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } response.Body.Close() - s.logger.Info("retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) + s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) requestContext.cancelRequest() requestContext = nextCredential.wrapRequestContext(r.Context()) retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) if buildErr != nil { - s.logger.Error("retry request: ", buildErr) + s.logger.ErrorContext(ctx, "retry request: ", buildErr) writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error()) return } @@ -511,7 +512,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request") return } - s.logger.Error("retry request: ", retryErr) + s.logger.ErrorContext(ctx, "retry request: ", retryErr) writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error()) return } @@ -525,7 +526,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests { body, _ := io.ReadAll(response.Body) - s.logger.Error("upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body)) + s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body)) go selectedCredential.pollUsage(s.ctx) writeJSONError(w, r, http.StatusInternalServerError, "api_error", "proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body)) @@ -546,7 +547,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { usageTracker := selectedCredential.usageTrackerOrNil() if usageTracker != nil && response.StatusCode == http.StatusOK { - s.handleResponseWithTracking(w, response, usageTracker, requestModel, anthropicBetaHeader, messagesCount, username) + s.handleResponseWithTracking(ctx, w, response, usageTracker, requestModel, anthropicBetaHeader, messagesCount, username) } else { mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) if err == nil && mediaType != "text/event-stream" { @@ -555,7 +556,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } flusher, ok := w.(http.Flusher) if !ok { - s.logger.Error("streaming not supported") + s.logger.ErrorContext(ctx, "streaming not supported") return } buffer := make([]byte, buf.BufferSize) @@ -564,7 +565,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { if n > 0 { _, writeError := w.Write(buffer[:n]) if writeError != nil { - s.logger.Error("write streaming response: ", writeError) + s.logger.ErrorContext(ctx, "write streaming response: ", writeError) return } flusher.Flush() @@ -576,7 +577,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, requestModel string, anthropicBetaHeader string, messagesCount int, username string) { +func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, requestModel string, anthropicBetaHeader string, messagesCount int, username string) { weeklyCycleHint := extractWeeklyCycleHint(response.Header) mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) isStreaming := err == nil && mediaType == "text/event-stream" @@ -584,7 +585,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons if !isStreaming { bodyBytes, err := io.ReadAll(response.Body) if err != nil { - s.logger.Error("read response body: ", err) + s.logger.ErrorContext(ctx, "read response body: ", err) return } @@ -627,7 +628,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons flusher, ok := writer.(http.Flusher) if !ok { - s.logger.Error("streaming not supported") + s.logger.ErrorContext(ctx, "streaming not supported") return } @@ -690,7 +691,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons _, writeError := writer.Write(buffer[:n]) if writeError != nil { - s.logger.Error("write streaming response: ", writeError) + s.logger.ErrorContext(ctx, "write streaming response: ", writeError) return } flusher.Flush() diff --git a/service/ocm/reverse.go b/service/ocm/reverse.go index 906778df5..1ed274f6d 100644 --- a/service/ocm/reverse.go +++ b/service/ocm/reverse.go @@ -52,7 +52,7 @@ func (l *yamuxNetListener) Addr() net.Addr { return l.session.Addr() } -func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) { +func (s *Service) handleReverseConnect(ctx context.Context, w http.ResponseWriter, r *http.Request) { if r.Header.Get("Upgrade") != "reverse-proxy" { writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", "missing Upgrade header") return @@ -71,21 +71,21 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) { receiverCredential := s.findReceiverCredential(clientToken) if receiverCredential == nil { - s.logger.Warn("reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential") + s.logger.WarnContext(ctx, "reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential") writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid reverse token") return } hijacker, ok := w.(http.Hijacker) if !ok { - s.logger.Error("reverse connect: hijack not supported") + s.logger.ErrorContext(ctx, "reverse connect: hijack not supported") writeJSONError(w, r, http.StatusInternalServerError, "api_error", "hijack not supported") return } conn, bufferedReadWriter, err := hijacker.Hijack() if err != nil { - s.logger.Error("reverse connect: hijack: ", err) + s.logger.ErrorContext(ctx, "reverse connect: hijack: ", err) return } @@ -93,20 +93,20 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) { _, err = bufferedReadWriter.WriteString(response) if err != nil { conn.Close() - s.logger.Error("reverse connect: write upgrade response: ", err) + s.logger.ErrorContext(ctx, "reverse connect: write upgrade response: ", err) return } err = bufferedReadWriter.Flush() if err != nil { conn.Close() - s.logger.Error("reverse connect: flush upgrade response: ", err) + s.logger.ErrorContext(ctx, "reverse connect: flush upgrade response: ", err) return } session, err := yamux.Client(conn, reverseYamuxConfig()) if err != nil { conn.Close() - s.logger.Error("reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err) + s.logger.ErrorContext(ctx, "reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err) return } @@ -114,12 +114,12 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) { session.Close() return } - s.logger.Info("reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr) + s.logger.InfoContext(ctx, "reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr) go func() { <-session.CloseChan() receiverCredential.clearReverseSession(session) - s.logger.Warn("reverse connection lost for ", receiverCredential.tagName()) + s.logger.WarnContext(ctx, "reverse connection lost for ", receiverCredential.tagName()) }() } diff --git a/service/ocm/service.go b/service/ocm/service.go index 858595b1f..94a98c665 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -362,13 +362,14 @@ func (s *Service) resolveCredentialProvider(username string) (credentialProvider } func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := log.ContextWithNewID(r.Context()) if r.URL.Path == "/ocm/v1/status" { s.handleStatusEndpoint(w, r) return } if r.URL.Path == "/ocm/v1/reverse" { - s.handleReverseConnect(w, r) + s.handleReverseConnect(ctx, w, r) return } @@ -382,20 +383,20 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { if len(s.options.Users) > 0 { authHeader := r.Header.Get("Authorization") if authHeader == "" { - s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": missing Authorization header") + s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header") writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") return } clientToken := strings.TrimPrefix(authHeader, "Bearer ") if clientToken == authHeader { - s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format") + s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format") writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") return } var ok bool username, ok = s.userManager.Authenticate(clientToken) if !ok { - s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken) + s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken) writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key") return } @@ -411,7 +412,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { var err error provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username) if err != nil { - s.logger.Error("resolve credential: ", err) + s.logger.ErrorContext(ctx, "resolve credential: ", err) writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error()) return } @@ -437,7 +438,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") { - s.handleWebSocket(w, r, path, username, sessionID, userConfig, provider, selectedCredential, credentialFilter, isNew) + s.handleWebSocket(ctx, w, r, path, username, sessionID, userConfig, provider, selectedCredential, credentialFilter, isNew) return } @@ -465,7 +466,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { if isJSONRequest { bodyBytes, err = io.ReadAll(r.Body) if err != nil { - s.logger.Error("read request body: ", err) + s.logger.ErrorContext(ctx, "read request body: ", err) writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body") return } @@ -495,7 +496,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { if requestServiceTier == "priority" { logParts = append(logParts, ", fast") } - s.logger.Debug(logParts...) + s.logger.DebugContext(ctx, logParts...) } requestContext := selectedCredential.wrapRequestContext(r.Context()) @@ -504,7 +505,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { }() proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) if err != nil { - s.logger.Error("create proxy request: ", err) + s.logger.ErrorContext(ctx, "create proxy request: ", err) writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error") return } @@ -535,12 +536,12 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } response.Body.Close() - s.logger.Info("retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) + s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) requestContext.cancelRequest() requestContext = nextCredential.wrapRequestContext(r.Context()) retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) if buildErr != nil { - s.logger.Error("retry request: ", buildErr) + s.logger.ErrorContext(ctx, "retry request: ", buildErr) writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error()) return } @@ -553,7 +554,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request") return } - s.logger.Error("retry request: ", retryErr) + s.logger.ErrorContext(ctx, "retry request: ", retryErr) writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error()) return } @@ -567,7 +568,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests { body, _ := io.ReadAll(response.Body) - s.logger.Error("upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body)) + s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body)) go selectedCredential.pollUsage(s.ctx) writeJSONError(w, r, http.StatusInternalServerError, "api_error", "proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body)) @@ -589,7 +590,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { usageTracker := selectedCredential.usageTrackerOrNil() if usageTracker != nil && response.StatusCode == http.StatusOK && (path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) { - s.handleResponseWithTracking(w, response, usageTracker, path, requestModel, username) + s.handleResponseWithTracking(ctx, w, response, usageTracker, path, requestModel, username) } else { mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) if err == nil && mediaType != "text/event-stream" { @@ -598,7 +599,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } flusher, ok := w.(http.Flusher) if !ok { - s.logger.Error("streaming not supported") + s.logger.ErrorContext(ctx, "streaming not supported") return } buffer := make([]byte, buf.BufferSize) @@ -607,7 +608,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { if n > 0 { _, writeError := w.Write(buffer[:n]) if writeError != nil { - s.logger.Error("write streaming response: ", writeError) + s.logger.ErrorContext(ctx, "write streaming response: ", writeError) return } flusher.Flush() @@ -619,7 +620,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, path string, requestModel string, username string) { +func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, path string, requestModel string, username string) { isChatCompletions := path == "/v1/chat/completions" weeklyCycleHint := extractWeeklyCycleHint(response.Header) mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type")) @@ -630,7 +631,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons if !isStreaming { bodyBytes, err := io.ReadAll(response.Body) if err != nil { - s.logger.Error("read response body: ", err) + s.logger.ErrorContext(ctx, "read response body: ", err) return } @@ -683,7 +684,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons flusher, ok := writer.(http.Flusher) if !ok { - s.logger.Error("streaming not supported") + s.logger.ErrorContext(ctx, "streaming not supported") return } @@ -760,7 +761,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons _, writeError := writer.Write(buffer[:n]) if writeError != nil { - s.logger.Error("write streaming response: ", writeError) + s.logger.ErrorContext(ctx, "write streaming response: ", writeError) return } flusher.Flush() diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index 21b25bafc..17178e8c2 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -82,6 +82,7 @@ func isForwardableWebSocketRequestHeader(key string) bool { } func (s *Service) handleWebSocket( + ctx context.Context, w http.ResponseWriter, r *http.Request, path string, @@ -105,7 +106,7 @@ func (s *Service) handleWebSocket( for { accessToken, accessErr := selectedCredential.getAccessToken() if accessErr != nil { - s.logger.Error("get access token for websocket: ", accessErr) + s.logger.ErrorContext(ctx, "get access token for websocket: ", accessErr) writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "authentication failed") return } @@ -190,14 +191,14 @@ func (s *Service) handleWebSocket( writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited") return } - s.logger.Info("retrying websocket with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) + s.logger.InfoContext(ctx, "retrying websocket with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) selectedCredential = nextCredential continue } if statusCode > 0 && statusResponseBody != "" { - s.logger.Error("dial upstream websocket: status ", statusCode, " body: ", statusResponseBody) + s.logger.ErrorContext(ctx, "dial upstream websocket: status ", statusCode, " body: ", statusResponseBody) } else { - s.logger.Error("dial upstream websocket: ", err) + s.logger.ErrorContext(ctx, "dial upstream websocket: ", err) } writeJSONError(w, r, http.StatusBadGateway, "api_error", "upstream websocket connection failed") return @@ -226,7 +227,7 @@ func (s *Service) handleWebSocket( } clientConn, _, _, err := clientUpgrader.Upgrade(r, w) if err != nil { - s.logger.Error("upgrade client websocket: ", err) + s.logger.ErrorContext(ctx, "upgrade client websocket: ", err) upstreamConn.Close() return } @@ -258,23 +259,23 @@ func (s *Service) handleWebSocket( go func() { defer waitGroup.Done() defer session.Close() - s.proxyWebSocketClientToUpstream(clientConn, upstreamConn, selectedCredential, modelChannel, isNew, username, sessionID) + s.proxyWebSocketClientToUpstream(ctx, clientConn, upstreamConn, selectedCredential, modelChannel, isNew, username, sessionID) }() go func() { defer waitGroup.Done() defer session.Close() - s.proxyWebSocketUpstreamToClient(upstreamReadWriter, clientConn, selectedCredential, userConfig, provider, modelChannel, username, weeklyCycleHint) + s.proxyWebSocketUpstreamToClient(ctx, upstreamReadWriter, clientConn, selectedCredential, userConfig, provider, modelChannel, username, weeklyCycleHint) }() waitGroup.Wait() } -func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamConn net.Conn, selectedCredential credential, modelChannel chan<- string, isNew bool, username string, sessionID string) { +func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn net.Conn, upstreamConn net.Conn, selectedCredential credential, modelChannel chan<- string, isNew bool, username string, sessionID string) { logged := false for { data, opCode, err := wsutil.ReadClientData(clientConn) if err != nil { if !E.IsClosedOrCanceled(err) { - s.logger.Debug("read client websocket: ", err) + s.logger.DebugContext(ctx, "read client websocket: ", err) } return } @@ -299,7 +300,7 @@ func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamCo if request.ServiceTier == "priority" { logParts = append(logParts, ", fast") } - s.logger.Debug(logParts...) + s.logger.DebugContext(ctx, logParts...) } if selectedCredential.usageTrackerOrNil() != nil { select { @@ -313,21 +314,21 @@ func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamCo err = wsutil.WriteClientMessage(upstreamConn, opCode, data) if err != nil { if !E.IsClosedOrCanceled(err) { - s.logger.Debug("write upstream websocket: ", err) + s.logger.DebugContext(ctx, "write upstream websocket: ", err) } return } } } -func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWriter, clientConn net.Conn, selectedCredential credential, userConfig *option.OCMUser, provider credentialProvider, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) { +func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamReadWriter io.ReadWriter, clientConn net.Conn, selectedCredential credential, userConfig *option.OCMUser, provider credentialProvider, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) { usageTracker := selectedCredential.usageTrackerOrNil() var requestModel string for { data, opCode, err := wsutil.ReadServerData(upstreamReadWriter) if err != nil { if !E.IsClosedOrCanceled(err) { - s.logger.Debug("read upstream websocket: ", err) + s.logger.DebugContext(ctx, "read upstream websocket: ", err) } return } @@ -367,7 +368,7 @@ func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWrite err = wsutil.WriteServerMessage(clientConn, opCode, data) if err != nil { if !E.IsClosedOrCanceled(err) { - s.logger.Debug("write client websocket: ", err) + s.logger.DebugContext(ctx, "write client websocket: ", err) } return }