ccm,ocm: add request ID context to HTTP request logging

This commit is contained in:
世界
2026-03-14 16:14:24 +08:00
parent 016e5e1b12
commit 5b29fd3be4
5 changed files with 74 additions and 71 deletions

View File

@@ -52,7 +52,7 @@ func (l *yamuxNetListener) Addr() net.Addr {
return l.session.Addr()
}
func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
func (s *Service) handleReverseConnect(ctx context.Context, w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Upgrade") != "reverse-proxy" {
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", "missing Upgrade header")
return
@@ -71,21 +71,21 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
receiverCredential := s.findReceiverCredential(clientToken)
if receiverCredential == nil {
s.logger.Warn("reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential")
s.logger.WarnContext(ctx, "reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid reverse token")
return
}
hijacker, ok := w.(http.Hijacker)
if !ok {
s.logger.Error("reverse connect: hijack not supported")
s.logger.ErrorContext(ctx, "reverse connect: hijack not supported")
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "hijack not supported")
return
}
conn, bufferedReadWriter, err := hijacker.Hijack()
if err != nil {
s.logger.Error("reverse connect: hijack: ", err)
s.logger.ErrorContext(ctx, "reverse connect: hijack: ", err)
return
}
@@ -93,20 +93,20 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
_, err = bufferedReadWriter.WriteString(response)
if err != nil {
conn.Close()
s.logger.Error("reverse connect: write upgrade response: ", err)
s.logger.ErrorContext(ctx, "reverse connect: write upgrade response: ", err)
return
}
err = bufferedReadWriter.Flush()
if err != nil {
conn.Close()
s.logger.Error("reverse connect: flush upgrade response: ", err)
s.logger.ErrorContext(ctx, "reverse connect: flush upgrade response: ", err)
return
}
session, err := yamux.Client(conn, reverseYamuxConfig())
if err != nil {
conn.Close()
s.logger.Error("reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err)
s.logger.ErrorContext(ctx, "reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err)
return
}
@@ -114,12 +114,12 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
session.Close()
return
}
s.logger.Info("reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr)
s.logger.InfoContext(ctx, "reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr)
go func() {
<-session.CloseChan()
receiverCredential.clearReverseSession(session)
s.logger.Warn("reverse connection lost for ", receiverCredential.tagName())
s.logger.WarnContext(ctx, "reverse connection lost for ", receiverCredential.tagName())
}()
}

View File

@@ -325,13 +325,14 @@ func detectContextWindow(betaHeader string, totalInputTokens int64) int {
}
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := log.ContextWithNewID(r.Context())
if r.URL.Path == "/ccm/v1/status" {
s.handleStatusEndpoint(w, r)
return
}
if r.URL.Path == "/ccm/v1/reverse" {
s.handleReverseConnect(w, r)
s.handleReverseConnect(ctx, w, r)
return
}
@@ -344,20 +345,20 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if len(s.options.Users) > 0 {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
return
}
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
if clientToken == authHeader {
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
return
}
var ok bool
username, ok = s.userManager.Authenticate(clientToken)
if !ok {
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
return
}
@@ -373,7 +374,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var err error
bodyBytes, err = io.ReadAll(r.Body)
if err != nil {
s.logger.Error("read request body: ", err)
s.logger.ErrorContext(ctx, "read request body: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body")
return
}
@@ -400,7 +401,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var err error
provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
if err != nil {
s.logger.Error("resolve credential: ", err)
s.logger.ErrorContext(ctx, "resolve credential: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
return
}
@@ -448,7 +449,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
logParts = append(logParts, ", model=", modelDisplay)
}
s.logger.Debug(logParts...)
s.logger.DebugContext(ctx, logParts...)
}
if isFastModeRequest(anthropicBetaHeader) && selectedCredential.isExternal() {
@@ -463,7 +464,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}()
proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if err != nil {
s.logger.Error("create proxy request: ", err)
s.logger.ErrorContext(ctx, "create proxy request: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
return
}
@@ -493,12 +494,12 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
response.Body.Close()
s.logger.Info("retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
requestContext.cancelRequest()
requestContext = nextCredential.wrapRequestContext(r.Context())
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if buildErr != nil {
s.logger.Error("retry request: ", buildErr)
s.logger.ErrorContext(ctx, "retry request: ", buildErr)
writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
return
}
@@ -511,7 +512,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request")
return
}
s.logger.Error("retry request: ", retryErr)
s.logger.ErrorContext(ctx, "retry request: ", retryErr)
writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error())
return
}
@@ -525,7 +526,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
body, _ := io.ReadAll(response.Body)
s.logger.Error("upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
go selectedCredential.pollUsage(s.ctx)
writeJSONError(w, r, http.StatusInternalServerError, "api_error",
"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
@@ -546,7 +547,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
usageTracker := selectedCredential.usageTrackerOrNil()
if usageTracker != nil && response.StatusCode == http.StatusOK {
s.handleResponseWithTracking(w, response, usageTracker, requestModel, anthropicBetaHeader, messagesCount, username)
s.handleResponseWithTracking(ctx, w, response, usageTracker, requestModel, anthropicBetaHeader, messagesCount, username)
} else {
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
if err == nil && mediaType != "text/event-stream" {
@@ -555,7 +556,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
flusher, ok := w.(http.Flusher)
if !ok {
s.logger.Error("streaming not supported")
s.logger.ErrorContext(ctx, "streaming not supported")
return
}
buffer := make([]byte, buf.BufferSize)
@@ -564,7 +565,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if n > 0 {
_, writeError := w.Write(buffer[:n])
if writeError != nil {
s.logger.Error("write streaming response: ", writeError)
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
return
}
flusher.Flush()
@@ -576,7 +577,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, requestModel string, anthropicBetaHeader string, messagesCount int, username string) {
func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, requestModel string, anthropicBetaHeader string, messagesCount int, username string) {
weeklyCycleHint := extractWeeklyCycleHint(response.Header)
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
isStreaming := err == nil && mediaType == "text/event-stream"
@@ -584,7 +585,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
if !isStreaming {
bodyBytes, err := io.ReadAll(response.Body)
if err != nil {
s.logger.Error("read response body: ", err)
s.logger.ErrorContext(ctx, "read response body: ", err)
return
}
@@ -627,7 +628,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
flusher, ok := writer.(http.Flusher)
if !ok {
s.logger.Error("streaming not supported")
s.logger.ErrorContext(ctx, "streaming not supported")
return
}
@@ -690,7 +691,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
_, writeError := writer.Write(buffer[:n])
if writeError != nil {
s.logger.Error("write streaming response: ", writeError)
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
return
}
flusher.Flush()

View File

@@ -52,7 +52,7 @@ func (l *yamuxNetListener) Addr() net.Addr {
return l.session.Addr()
}
func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
func (s *Service) handleReverseConnect(ctx context.Context, w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Upgrade") != "reverse-proxy" {
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", "missing Upgrade header")
return
@@ -71,21 +71,21 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
receiverCredential := s.findReceiverCredential(clientToken)
if receiverCredential == nil {
s.logger.Warn("reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential")
s.logger.WarnContext(ctx, "reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid reverse token")
return
}
hijacker, ok := w.(http.Hijacker)
if !ok {
s.logger.Error("reverse connect: hijack not supported")
s.logger.ErrorContext(ctx, "reverse connect: hijack not supported")
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "hijack not supported")
return
}
conn, bufferedReadWriter, err := hijacker.Hijack()
if err != nil {
s.logger.Error("reverse connect: hijack: ", err)
s.logger.ErrorContext(ctx, "reverse connect: hijack: ", err)
return
}
@@ -93,20 +93,20 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
_, err = bufferedReadWriter.WriteString(response)
if err != nil {
conn.Close()
s.logger.Error("reverse connect: write upgrade response: ", err)
s.logger.ErrorContext(ctx, "reverse connect: write upgrade response: ", err)
return
}
err = bufferedReadWriter.Flush()
if err != nil {
conn.Close()
s.logger.Error("reverse connect: flush upgrade response: ", err)
s.logger.ErrorContext(ctx, "reverse connect: flush upgrade response: ", err)
return
}
session, err := yamux.Client(conn, reverseYamuxConfig())
if err != nil {
conn.Close()
s.logger.Error("reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err)
s.logger.ErrorContext(ctx, "reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err)
return
}
@@ -114,12 +114,12 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
session.Close()
return
}
s.logger.Info("reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr)
s.logger.InfoContext(ctx, "reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr)
go func() {
<-session.CloseChan()
receiverCredential.clearReverseSession(session)
s.logger.Warn("reverse connection lost for ", receiverCredential.tagName())
s.logger.WarnContext(ctx, "reverse connection lost for ", receiverCredential.tagName())
}()
}

View File

@@ -362,13 +362,14 @@ func (s *Service) resolveCredentialProvider(username string) (credentialProvider
}
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := log.ContextWithNewID(r.Context())
if r.URL.Path == "/ocm/v1/status" {
s.handleStatusEndpoint(w, r)
return
}
if r.URL.Path == "/ocm/v1/reverse" {
s.handleReverseConnect(w, r)
s.handleReverseConnect(ctx, w, r)
return
}
@@ -382,20 +383,20 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if len(s.options.Users) > 0 {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
return
}
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
if clientToken == authHeader {
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
return
}
var ok bool
username, ok = s.userManager.Authenticate(clientToken)
if !ok {
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
return
}
@@ -411,7 +412,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var err error
provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
if err != nil {
s.logger.Error("resolve credential: ", err)
s.logger.ErrorContext(ctx, "resolve credential: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
return
}
@@ -437,7 +438,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") {
s.handleWebSocket(w, r, path, username, sessionID, userConfig, provider, selectedCredential, credentialFilter, isNew)
s.handleWebSocket(ctx, w, r, path, username, sessionID, userConfig, provider, selectedCredential, credentialFilter, isNew)
return
}
@@ -465,7 +466,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if isJSONRequest {
bodyBytes, err = io.ReadAll(r.Body)
if err != nil {
s.logger.Error("read request body: ", err)
s.logger.ErrorContext(ctx, "read request body: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body")
return
}
@@ -495,7 +496,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if requestServiceTier == "priority" {
logParts = append(logParts, ", fast")
}
s.logger.Debug(logParts...)
s.logger.DebugContext(ctx, logParts...)
}
requestContext := selectedCredential.wrapRequestContext(r.Context())
@@ -504,7 +505,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}()
proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if err != nil {
s.logger.Error("create proxy request: ", err)
s.logger.ErrorContext(ctx, "create proxy request: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
return
}
@@ -535,12 +536,12 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
response.Body.Close()
s.logger.Info("retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
requestContext.cancelRequest()
requestContext = nextCredential.wrapRequestContext(r.Context())
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if buildErr != nil {
s.logger.Error("retry request: ", buildErr)
s.logger.ErrorContext(ctx, "retry request: ", buildErr)
writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
return
}
@@ -553,7 +554,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request")
return
}
s.logger.Error("retry request: ", retryErr)
s.logger.ErrorContext(ctx, "retry request: ", retryErr)
writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error())
return
}
@@ -567,7 +568,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
body, _ := io.ReadAll(response.Body)
s.logger.Error("upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
go selectedCredential.pollUsage(s.ctx)
writeJSONError(w, r, http.StatusInternalServerError, "api_error",
"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
@@ -589,7 +590,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
usageTracker := selectedCredential.usageTrackerOrNil()
if usageTracker != nil && response.StatusCode == http.StatusOK &&
(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) {
s.handleResponseWithTracking(w, response, usageTracker, path, requestModel, username)
s.handleResponseWithTracking(ctx, w, response, usageTracker, path, requestModel, username)
} else {
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
if err == nil && mediaType != "text/event-stream" {
@@ -598,7 +599,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
flusher, ok := w.(http.Flusher)
if !ok {
s.logger.Error("streaming not supported")
s.logger.ErrorContext(ctx, "streaming not supported")
return
}
buffer := make([]byte, buf.BufferSize)
@@ -607,7 +608,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if n > 0 {
_, writeError := w.Write(buffer[:n])
if writeError != nil {
s.logger.Error("write streaming response: ", writeError)
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
return
}
flusher.Flush()
@@ -619,7 +620,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, path string, requestModel string, username string) {
func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, path string, requestModel string, username string) {
isChatCompletions := path == "/v1/chat/completions"
weeklyCycleHint := extractWeeklyCycleHint(response.Header)
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
@@ -630,7 +631,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
if !isStreaming {
bodyBytes, err := io.ReadAll(response.Body)
if err != nil {
s.logger.Error("read response body: ", err)
s.logger.ErrorContext(ctx, "read response body: ", err)
return
}
@@ -683,7 +684,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
flusher, ok := writer.(http.Flusher)
if !ok {
s.logger.Error("streaming not supported")
s.logger.ErrorContext(ctx, "streaming not supported")
return
}
@@ -760,7 +761,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
_, writeError := writer.Write(buffer[:n])
if writeError != nil {
s.logger.Error("write streaming response: ", writeError)
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
return
}
flusher.Flush()

View File

@@ -82,6 +82,7 @@ func isForwardableWebSocketRequestHeader(key string) bool {
}
func (s *Service) handleWebSocket(
ctx context.Context,
w http.ResponseWriter,
r *http.Request,
path string,
@@ -105,7 +106,7 @@ func (s *Service) handleWebSocket(
for {
accessToken, accessErr := selectedCredential.getAccessToken()
if accessErr != nil {
s.logger.Error("get access token for websocket: ", accessErr)
s.logger.ErrorContext(ctx, "get access token for websocket: ", accessErr)
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "authentication failed")
return
}
@@ -190,14 +191,14 @@ func (s *Service) handleWebSocket(
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited")
return
}
s.logger.Info("retrying websocket with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
s.logger.InfoContext(ctx, "retrying websocket with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
selectedCredential = nextCredential
continue
}
if statusCode > 0 && statusResponseBody != "" {
s.logger.Error("dial upstream websocket: status ", statusCode, " body: ", statusResponseBody)
s.logger.ErrorContext(ctx, "dial upstream websocket: status ", statusCode, " body: ", statusResponseBody)
} else {
s.logger.Error("dial upstream websocket: ", err)
s.logger.ErrorContext(ctx, "dial upstream websocket: ", err)
}
writeJSONError(w, r, http.StatusBadGateway, "api_error", "upstream websocket connection failed")
return
@@ -226,7 +227,7 @@ func (s *Service) handleWebSocket(
}
clientConn, _, _, err := clientUpgrader.Upgrade(r, w)
if err != nil {
s.logger.Error("upgrade client websocket: ", err)
s.logger.ErrorContext(ctx, "upgrade client websocket: ", err)
upstreamConn.Close()
return
}
@@ -258,23 +259,23 @@ func (s *Service) handleWebSocket(
go func() {
defer waitGroup.Done()
defer session.Close()
s.proxyWebSocketClientToUpstream(clientConn, upstreamConn, selectedCredential, modelChannel, isNew, username, sessionID)
s.proxyWebSocketClientToUpstream(ctx, clientConn, upstreamConn, selectedCredential, modelChannel, isNew, username, sessionID)
}()
go func() {
defer waitGroup.Done()
defer session.Close()
s.proxyWebSocketUpstreamToClient(upstreamReadWriter, clientConn, selectedCredential, userConfig, provider, modelChannel, username, weeklyCycleHint)
s.proxyWebSocketUpstreamToClient(ctx, upstreamReadWriter, clientConn, selectedCredential, userConfig, provider, modelChannel, username, weeklyCycleHint)
}()
waitGroup.Wait()
}
func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamConn net.Conn, selectedCredential credential, modelChannel chan<- string, isNew bool, username string, sessionID string) {
func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn net.Conn, upstreamConn net.Conn, selectedCredential credential, modelChannel chan<- string, isNew bool, username string, sessionID string) {
logged := false
for {
data, opCode, err := wsutil.ReadClientData(clientConn)
if err != nil {
if !E.IsClosedOrCanceled(err) {
s.logger.Debug("read client websocket: ", err)
s.logger.DebugContext(ctx, "read client websocket: ", err)
}
return
}
@@ -299,7 +300,7 @@ func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamCo
if request.ServiceTier == "priority" {
logParts = append(logParts, ", fast")
}
s.logger.Debug(logParts...)
s.logger.DebugContext(ctx, logParts...)
}
if selectedCredential.usageTrackerOrNil() != nil {
select {
@@ -313,21 +314,21 @@ func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamCo
err = wsutil.WriteClientMessage(upstreamConn, opCode, data)
if err != nil {
if !E.IsClosedOrCanceled(err) {
s.logger.Debug("write upstream websocket: ", err)
s.logger.DebugContext(ctx, "write upstream websocket: ", err)
}
return
}
}
}
func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWriter, clientConn net.Conn, selectedCredential credential, userConfig *option.OCMUser, provider credentialProvider, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) {
func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamReadWriter io.ReadWriter, clientConn net.Conn, selectedCredential credential, userConfig *option.OCMUser, provider credentialProvider, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) {
usageTracker := selectedCredential.usageTrackerOrNil()
var requestModel string
for {
data, opCode, err := wsutil.ReadServerData(upstreamReadWriter)
if err != nil {
if !E.IsClosedOrCanceled(err) {
s.logger.Debug("read upstream websocket: ", err)
s.logger.DebugContext(ctx, "read upstream websocket: ", err)
}
return
}
@@ -367,7 +368,7 @@ func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWrite
err = wsutil.WriteServerMessage(clientConn, opCode, data)
if err != nil {
if !E.IsClosedOrCanceled(err) {
s.logger.Debug("write client websocket: ", err)
s.logger.DebugContext(ctx, "write client websocket: ", err)
}
return
}