mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-11 17:47:20 +10:00
ccm,ocm: add request ID context to HTTP request logging
This commit is contained in:
@@ -52,7 +52,7 @@ func (l *yamuxNetListener) Addr() net.Addr {
|
|||||||
return l.session.Addr()
|
return l.session.Addr()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
|
func (s *Service) handleReverseConnect(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Header.Get("Upgrade") != "reverse-proxy" {
|
if r.Header.Get("Upgrade") != "reverse-proxy" {
|
||||||
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", "missing Upgrade header")
|
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", "missing Upgrade header")
|
||||||
return
|
return
|
||||||
@@ -71,21 +71,21 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
receiverCredential := s.findReceiverCredential(clientToken)
|
receiverCredential := s.findReceiverCredential(clientToken)
|
||||||
if receiverCredential == nil {
|
if receiverCredential == nil {
|
||||||
s.logger.Warn("reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential")
|
s.logger.WarnContext(ctx, "reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential")
|
||||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid reverse token")
|
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid reverse token")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
hijacker, ok := w.(http.Hijacker)
|
hijacker, ok := w.(http.Hijacker)
|
||||||
if !ok {
|
if !ok {
|
||||||
s.logger.Error("reverse connect: hijack not supported")
|
s.logger.ErrorContext(ctx, "reverse connect: hijack not supported")
|
||||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "hijack not supported")
|
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "hijack not supported")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, bufferedReadWriter, err := hijacker.Hijack()
|
conn, bufferedReadWriter, err := hijacker.Hijack()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Error("reverse connect: hijack: ", err)
|
s.logger.ErrorContext(ctx, "reverse connect: hijack: ", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -93,20 +93,20 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
|
|||||||
_, err = bufferedReadWriter.WriteString(response)
|
_, err = bufferedReadWriter.WriteString(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
s.logger.Error("reverse connect: write upgrade response: ", err)
|
s.logger.ErrorContext(ctx, "reverse connect: write upgrade response: ", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = bufferedReadWriter.Flush()
|
err = bufferedReadWriter.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
s.logger.Error("reverse connect: flush upgrade response: ", err)
|
s.logger.ErrorContext(ctx, "reverse connect: flush upgrade response: ", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
session, err := yamux.Client(conn, reverseYamuxConfig())
|
session, err := yamux.Client(conn, reverseYamuxConfig())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
s.logger.Error("reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err)
|
s.logger.ErrorContext(ctx, "reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -114,12 +114,12 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
|
|||||||
session.Close()
|
session.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.logger.Info("reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr)
|
s.logger.InfoContext(ctx, "reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
<-session.CloseChan()
|
<-session.CloseChan()
|
||||||
receiverCredential.clearReverseSession(session)
|
receiverCredential.clearReverseSession(session)
|
||||||
s.logger.Warn("reverse connection lost for ", receiverCredential.tagName())
|
s.logger.WarnContext(ctx, "reverse connection lost for ", receiverCredential.tagName())
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -325,13 +325,14 @@ func detectContextWindow(betaHeader string, totalInputTokens int64) int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := log.ContextWithNewID(r.Context())
|
||||||
if r.URL.Path == "/ccm/v1/status" {
|
if r.URL.Path == "/ccm/v1/status" {
|
||||||
s.handleStatusEndpoint(w, r)
|
s.handleStatusEndpoint(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.URL.Path == "/ccm/v1/reverse" {
|
if r.URL.Path == "/ccm/v1/reverse" {
|
||||||
s.handleReverseConnect(w, r)
|
s.handleReverseConnect(ctx, w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -344,20 +345,20 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
if len(s.options.Users) > 0 {
|
if len(s.options.Users) > 0 {
|
||||||
authHeader := r.Header.Get("Authorization")
|
authHeader := r.Header.Get("Authorization")
|
||||||
if authHeader == "" {
|
if authHeader == "" {
|
||||||
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
|
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
|
||||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||||
if clientToken == authHeader {
|
if clientToken == authHeader {
|
||||||
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
|
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
|
||||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var ok bool
|
var ok bool
|
||||||
username, ok = s.userManager.Authenticate(clientToken)
|
username, ok = s.userManager.Authenticate(clientToken)
|
||||||
if !ok {
|
if !ok {
|
||||||
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
|
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
|
||||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
|
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -373,7 +374,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
var err error
|
var err error
|
||||||
bodyBytes, err = io.ReadAll(r.Body)
|
bodyBytes, err = io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Error("read request body: ", err)
|
s.logger.ErrorContext(ctx, "read request body: ", err)
|
||||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body")
|
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -400,7 +401,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
var err error
|
var err error
|
||||||
provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
|
provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Error("resolve credential: ", err)
|
s.logger.ErrorContext(ctx, "resolve credential: ", err)
|
||||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
|
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -448,7 +449,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
logParts = append(logParts, ", model=", modelDisplay)
|
logParts = append(logParts, ", model=", modelDisplay)
|
||||||
}
|
}
|
||||||
s.logger.Debug(logParts...)
|
s.logger.DebugContext(ctx, logParts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
if isFastModeRequest(anthropicBetaHeader) && selectedCredential.isExternal() {
|
if isFastModeRequest(anthropicBetaHeader) && selectedCredential.isExternal() {
|
||||||
@@ -463,7 +464,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
}()
|
}()
|
||||||
proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Error("create proxy request: ", err)
|
s.logger.ErrorContext(ctx, "create proxy request: ", err)
|
||||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
|
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -493,12 +494,12 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
response.Body.Close()
|
response.Body.Close()
|
||||||
s.logger.Info("retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
|
s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
|
||||||
requestContext.cancelRequest()
|
requestContext.cancelRequest()
|
||||||
requestContext = nextCredential.wrapRequestContext(r.Context())
|
requestContext = nextCredential.wrapRequestContext(r.Context())
|
||||||
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||||
if buildErr != nil {
|
if buildErr != nil {
|
||||||
s.logger.Error("retry request: ", buildErr)
|
s.logger.ErrorContext(ctx, "retry request: ", buildErr)
|
||||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
|
writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -511,7 +512,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request")
|
writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.logger.Error("retry request: ", retryErr)
|
s.logger.ErrorContext(ctx, "retry request: ", retryErr)
|
||||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error())
|
writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -525,7 +526,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
|
if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
|
||||||
body, _ := io.ReadAll(response.Body)
|
body, _ := io.ReadAll(response.Body)
|
||||||
s.logger.Error("upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
|
s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
|
||||||
go selectedCredential.pollUsage(s.ctx)
|
go selectedCredential.pollUsage(s.ctx)
|
||||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error",
|
writeJSONError(w, r, http.StatusInternalServerError, "api_error",
|
||||||
"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
|
"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
|
||||||
@@ -546,7 +547,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
usageTracker := selectedCredential.usageTrackerOrNil()
|
usageTracker := selectedCredential.usageTrackerOrNil()
|
||||||
if usageTracker != nil && response.StatusCode == http.StatusOK {
|
if usageTracker != nil && response.StatusCode == http.StatusOK {
|
||||||
s.handleResponseWithTracking(w, response, usageTracker, requestModel, anthropicBetaHeader, messagesCount, username)
|
s.handleResponseWithTracking(ctx, w, response, usageTracker, requestModel, anthropicBetaHeader, messagesCount, username)
|
||||||
} else {
|
} else {
|
||||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||||
if err == nil && mediaType != "text/event-stream" {
|
if err == nil && mediaType != "text/event-stream" {
|
||||||
@@ -555,7 +556,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
flusher, ok := w.(http.Flusher)
|
flusher, ok := w.(http.Flusher)
|
||||||
if !ok {
|
if !ok {
|
||||||
s.logger.Error("streaming not supported")
|
s.logger.ErrorContext(ctx, "streaming not supported")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
buffer := make([]byte, buf.BufferSize)
|
buffer := make([]byte, buf.BufferSize)
|
||||||
@@ -564,7 +565,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
if n > 0 {
|
if n > 0 {
|
||||||
_, writeError := w.Write(buffer[:n])
|
_, writeError := w.Write(buffer[:n])
|
||||||
if writeError != nil {
|
if writeError != nil {
|
||||||
s.logger.Error("write streaming response: ", writeError)
|
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
@@ -576,7 +577,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, requestModel string, anthropicBetaHeader string, messagesCount int, username string) {
|
func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, requestModel string, anthropicBetaHeader string, messagesCount int, username string) {
|
||||||
weeklyCycleHint := extractWeeklyCycleHint(response.Header)
|
weeklyCycleHint := extractWeeklyCycleHint(response.Header)
|
||||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||||
isStreaming := err == nil && mediaType == "text/event-stream"
|
isStreaming := err == nil && mediaType == "text/event-stream"
|
||||||
@@ -584,7 +585,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
|||||||
if !isStreaming {
|
if !isStreaming {
|
||||||
bodyBytes, err := io.ReadAll(response.Body)
|
bodyBytes, err := io.ReadAll(response.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Error("read response body: ", err)
|
s.logger.ErrorContext(ctx, "read response body: ", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -627,7 +628,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
|||||||
|
|
||||||
flusher, ok := writer.(http.Flusher)
|
flusher, ok := writer.(http.Flusher)
|
||||||
if !ok {
|
if !ok {
|
||||||
s.logger.Error("streaming not supported")
|
s.logger.ErrorContext(ctx, "streaming not supported")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -690,7 +691,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
|||||||
|
|
||||||
_, writeError := writer.Write(buffer[:n])
|
_, writeError := writer.Write(buffer[:n])
|
||||||
if writeError != nil {
|
if writeError != nil {
|
||||||
s.logger.Error("write streaming response: ", writeError)
|
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ func (l *yamuxNetListener) Addr() net.Addr {
|
|||||||
return l.session.Addr()
|
return l.session.Addr()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
|
func (s *Service) handleReverseConnect(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Header.Get("Upgrade") != "reverse-proxy" {
|
if r.Header.Get("Upgrade") != "reverse-proxy" {
|
||||||
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", "missing Upgrade header")
|
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", "missing Upgrade header")
|
||||||
return
|
return
|
||||||
@@ -71,21 +71,21 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
receiverCredential := s.findReceiverCredential(clientToken)
|
receiverCredential := s.findReceiverCredential(clientToken)
|
||||||
if receiverCredential == nil {
|
if receiverCredential == nil {
|
||||||
s.logger.Warn("reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential")
|
s.logger.WarnContext(ctx, "reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential")
|
||||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid reverse token")
|
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid reverse token")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
hijacker, ok := w.(http.Hijacker)
|
hijacker, ok := w.(http.Hijacker)
|
||||||
if !ok {
|
if !ok {
|
||||||
s.logger.Error("reverse connect: hijack not supported")
|
s.logger.ErrorContext(ctx, "reverse connect: hijack not supported")
|
||||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "hijack not supported")
|
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "hijack not supported")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, bufferedReadWriter, err := hijacker.Hijack()
|
conn, bufferedReadWriter, err := hijacker.Hijack()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Error("reverse connect: hijack: ", err)
|
s.logger.ErrorContext(ctx, "reverse connect: hijack: ", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -93,20 +93,20 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
|
|||||||
_, err = bufferedReadWriter.WriteString(response)
|
_, err = bufferedReadWriter.WriteString(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
s.logger.Error("reverse connect: write upgrade response: ", err)
|
s.logger.ErrorContext(ctx, "reverse connect: write upgrade response: ", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = bufferedReadWriter.Flush()
|
err = bufferedReadWriter.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
s.logger.Error("reverse connect: flush upgrade response: ", err)
|
s.logger.ErrorContext(ctx, "reverse connect: flush upgrade response: ", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
session, err := yamux.Client(conn, reverseYamuxConfig())
|
session, err := yamux.Client(conn, reverseYamuxConfig())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
s.logger.Error("reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err)
|
s.logger.ErrorContext(ctx, "reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -114,12 +114,12 @@ func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
|
|||||||
session.Close()
|
session.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.logger.Info("reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr)
|
s.logger.InfoContext(ctx, "reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
<-session.CloseChan()
|
<-session.CloseChan()
|
||||||
receiverCredential.clearReverseSession(session)
|
receiverCredential.clearReverseSession(session)
|
||||||
s.logger.Warn("reverse connection lost for ", receiverCredential.tagName())
|
s.logger.WarnContext(ctx, "reverse connection lost for ", receiverCredential.tagName())
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -362,13 +362,14 @@ func (s *Service) resolveCredentialProvider(username string) (credentialProvider
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := log.ContextWithNewID(r.Context())
|
||||||
if r.URL.Path == "/ocm/v1/status" {
|
if r.URL.Path == "/ocm/v1/status" {
|
||||||
s.handleStatusEndpoint(w, r)
|
s.handleStatusEndpoint(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.URL.Path == "/ocm/v1/reverse" {
|
if r.URL.Path == "/ocm/v1/reverse" {
|
||||||
s.handleReverseConnect(w, r)
|
s.handleReverseConnect(ctx, w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -382,20 +383,20 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
if len(s.options.Users) > 0 {
|
if len(s.options.Users) > 0 {
|
||||||
authHeader := r.Header.Get("Authorization")
|
authHeader := r.Header.Get("Authorization")
|
||||||
if authHeader == "" {
|
if authHeader == "" {
|
||||||
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
|
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
|
||||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||||
if clientToken == authHeader {
|
if clientToken == authHeader {
|
||||||
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
|
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
|
||||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var ok bool
|
var ok bool
|
||||||
username, ok = s.userManager.Authenticate(clientToken)
|
username, ok = s.userManager.Authenticate(clientToken)
|
||||||
if !ok {
|
if !ok {
|
||||||
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
|
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
|
||||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
|
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -411,7 +412,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
var err error
|
var err error
|
||||||
provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
|
provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Error("resolve credential: ", err)
|
s.logger.ErrorContext(ctx, "resolve credential: ", err)
|
||||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
|
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -437,7 +438,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") {
|
if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") {
|
||||||
s.handleWebSocket(w, r, path, username, sessionID, userConfig, provider, selectedCredential, credentialFilter, isNew)
|
s.handleWebSocket(ctx, w, r, path, username, sessionID, userConfig, provider, selectedCredential, credentialFilter, isNew)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -465,7 +466,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
if isJSONRequest {
|
if isJSONRequest {
|
||||||
bodyBytes, err = io.ReadAll(r.Body)
|
bodyBytes, err = io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Error("read request body: ", err)
|
s.logger.ErrorContext(ctx, "read request body: ", err)
|
||||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body")
|
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -495,7 +496,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
if requestServiceTier == "priority" {
|
if requestServiceTier == "priority" {
|
||||||
logParts = append(logParts, ", fast")
|
logParts = append(logParts, ", fast")
|
||||||
}
|
}
|
||||||
s.logger.Debug(logParts...)
|
s.logger.DebugContext(ctx, logParts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
requestContext := selectedCredential.wrapRequestContext(r.Context())
|
requestContext := selectedCredential.wrapRequestContext(r.Context())
|
||||||
@@ -504,7 +505,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
}()
|
}()
|
||||||
proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Error("create proxy request: ", err)
|
s.logger.ErrorContext(ctx, "create proxy request: ", err)
|
||||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
|
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -535,12 +536,12 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
response.Body.Close()
|
response.Body.Close()
|
||||||
s.logger.Info("retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
|
s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
|
||||||
requestContext.cancelRequest()
|
requestContext.cancelRequest()
|
||||||
requestContext = nextCredential.wrapRequestContext(r.Context())
|
requestContext = nextCredential.wrapRequestContext(r.Context())
|
||||||
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||||
if buildErr != nil {
|
if buildErr != nil {
|
||||||
s.logger.Error("retry request: ", buildErr)
|
s.logger.ErrorContext(ctx, "retry request: ", buildErr)
|
||||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
|
writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -553,7 +554,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request")
|
writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.logger.Error("retry request: ", retryErr)
|
s.logger.ErrorContext(ctx, "retry request: ", retryErr)
|
||||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error())
|
writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -567,7 +568,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
|
if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
|
||||||
body, _ := io.ReadAll(response.Body)
|
body, _ := io.ReadAll(response.Body)
|
||||||
s.logger.Error("upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
|
s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
|
||||||
go selectedCredential.pollUsage(s.ctx)
|
go selectedCredential.pollUsage(s.ctx)
|
||||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error",
|
writeJSONError(w, r, http.StatusInternalServerError, "api_error",
|
||||||
"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
|
"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
|
||||||
@@ -589,7 +590,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
usageTracker := selectedCredential.usageTrackerOrNil()
|
usageTracker := selectedCredential.usageTrackerOrNil()
|
||||||
if usageTracker != nil && response.StatusCode == http.StatusOK &&
|
if usageTracker != nil && response.StatusCode == http.StatusOK &&
|
||||||
(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) {
|
(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) {
|
||||||
s.handleResponseWithTracking(w, response, usageTracker, path, requestModel, username)
|
s.handleResponseWithTracking(ctx, w, response, usageTracker, path, requestModel, username)
|
||||||
} else {
|
} else {
|
||||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||||
if err == nil && mediaType != "text/event-stream" {
|
if err == nil && mediaType != "text/event-stream" {
|
||||||
@@ -598,7 +599,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
flusher, ok := w.(http.Flusher)
|
flusher, ok := w.(http.Flusher)
|
||||||
if !ok {
|
if !ok {
|
||||||
s.logger.Error("streaming not supported")
|
s.logger.ErrorContext(ctx, "streaming not supported")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
buffer := make([]byte, buf.BufferSize)
|
buffer := make([]byte, buf.BufferSize)
|
||||||
@@ -607,7 +608,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
if n > 0 {
|
if n > 0 {
|
||||||
_, writeError := w.Write(buffer[:n])
|
_, writeError := w.Write(buffer[:n])
|
||||||
if writeError != nil {
|
if writeError != nil {
|
||||||
s.logger.Error("write streaming response: ", writeError)
|
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
@@ -619,7 +620,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, path string, requestModel string, username string) {
|
func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, path string, requestModel string, username string) {
|
||||||
isChatCompletions := path == "/v1/chat/completions"
|
isChatCompletions := path == "/v1/chat/completions"
|
||||||
weeklyCycleHint := extractWeeklyCycleHint(response.Header)
|
weeklyCycleHint := extractWeeklyCycleHint(response.Header)
|
||||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||||
@@ -630,7 +631,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
|||||||
if !isStreaming {
|
if !isStreaming {
|
||||||
bodyBytes, err := io.ReadAll(response.Body)
|
bodyBytes, err := io.ReadAll(response.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Error("read response body: ", err)
|
s.logger.ErrorContext(ctx, "read response body: ", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -683,7 +684,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
|||||||
|
|
||||||
flusher, ok := writer.(http.Flusher)
|
flusher, ok := writer.(http.Flusher)
|
||||||
if !ok {
|
if !ok {
|
||||||
s.logger.Error("streaming not supported")
|
s.logger.ErrorContext(ctx, "streaming not supported")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -760,7 +761,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
|||||||
|
|
||||||
_, writeError := writer.Write(buffer[:n])
|
_, writeError := writer.Write(buffer[:n])
|
||||||
if writeError != nil {
|
if writeError != nil {
|
||||||
s.logger.Error("write streaming response: ", writeError)
|
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
|
|||||||
@@ -82,6 +82,7 @@ func isForwardableWebSocketRequestHeader(key string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) handleWebSocket(
|
func (s *Service) handleWebSocket(
|
||||||
|
ctx context.Context,
|
||||||
w http.ResponseWriter,
|
w http.ResponseWriter,
|
||||||
r *http.Request,
|
r *http.Request,
|
||||||
path string,
|
path string,
|
||||||
@@ -105,7 +106,7 @@ func (s *Service) handleWebSocket(
|
|||||||
for {
|
for {
|
||||||
accessToken, accessErr := selectedCredential.getAccessToken()
|
accessToken, accessErr := selectedCredential.getAccessToken()
|
||||||
if accessErr != nil {
|
if accessErr != nil {
|
||||||
s.logger.Error("get access token for websocket: ", accessErr)
|
s.logger.ErrorContext(ctx, "get access token for websocket: ", accessErr)
|
||||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "authentication failed")
|
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "authentication failed")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -190,14 +191,14 @@ func (s *Service) handleWebSocket(
|
|||||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited")
|
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.logger.Info("retrying websocket with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
|
s.logger.InfoContext(ctx, "retrying websocket with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
|
||||||
selectedCredential = nextCredential
|
selectedCredential = nextCredential
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if statusCode > 0 && statusResponseBody != "" {
|
if statusCode > 0 && statusResponseBody != "" {
|
||||||
s.logger.Error("dial upstream websocket: status ", statusCode, " body: ", statusResponseBody)
|
s.logger.ErrorContext(ctx, "dial upstream websocket: status ", statusCode, " body: ", statusResponseBody)
|
||||||
} else {
|
} else {
|
||||||
s.logger.Error("dial upstream websocket: ", err)
|
s.logger.ErrorContext(ctx, "dial upstream websocket: ", err)
|
||||||
}
|
}
|
||||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", "upstream websocket connection failed")
|
writeJSONError(w, r, http.StatusBadGateway, "api_error", "upstream websocket connection failed")
|
||||||
return
|
return
|
||||||
@@ -226,7 +227,7 @@ func (s *Service) handleWebSocket(
|
|||||||
}
|
}
|
||||||
clientConn, _, _, err := clientUpgrader.Upgrade(r, w)
|
clientConn, _, _, err := clientUpgrader.Upgrade(r, w)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Error("upgrade client websocket: ", err)
|
s.logger.ErrorContext(ctx, "upgrade client websocket: ", err)
|
||||||
upstreamConn.Close()
|
upstreamConn.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -258,23 +259,23 @@ func (s *Service) handleWebSocket(
|
|||||||
go func() {
|
go func() {
|
||||||
defer waitGroup.Done()
|
defer waitGroup.Done()
|
||||||
defer session.Close()
|
defer session.Close()
|
||||||
s.proxyWebSocketClientToUpstream(clientConn, upstreamConn, selectedCredential, modelChannel, isNew, username, sessionID)
|
s.proxyWebSocketClientToUpstream(ctx, clientConn, upstreamConn, selectedCredential, modelChannel, isNew, username, sessionID)
|
||||||
}()
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
defer waitGroup.Done()
|
defer waitGroup.Done()
|
||||||
defer session.Close()
|
defer session.Close()
|
||||||
s.proxyWebSocketUpstreamToClient(upstreamReadWriter, clientConn, selectedCredential, userConfig, provider, modelChannel, username, weeklyCycleHint)
|
s.proxyWebSocketUpstreamToClient(ctx, upstreamReadWriter, clientConn, selectedCredential, userConfig, provider, modelChannel, username, weeklyCycleHint)
|
||||||
}()
|
}()
|
||||||
waitGroup.Wait()
|
waitGroup.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamConn net.Conn, selectedCredential credential, modelChannel chan<- string, isNew bool, username string, sessionID string) {
|
func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn net.Conn, upstreamConn net.Conn, selectedCredential credential, modelChannel chan<- string, isNew bool, username string, sessionID string) {
|
||||||
logged := false
|
logged := false
|
||||||
for {
|
for {
|
||||||
data, opCode, err := wsutil.ReadClientData(clientConn)
|
data, opCode, err := wsutil.ReadClientData(clientConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !E.IsClosedOrCanceled(err) {
|
if !E.IsClosedOrCanceled(err) {
|
||||||
s.logger.Debug("read client websocket: ", err)
|
s.logger.DebugContext(ctx, "read client websocket: ", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -299,7 +300,7 @@ func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamCo
|
|||||||
if request.ServiceTier == "priority" {
|
if request.ServiceTier == "priority" {
|
||||||
logParts = append(logParts, ", fast")
|
logParts = append(logParts, ", fast")
|
||||||
}
|
}
|
||||||
s.logger.Debug(logParts...)
|
s.logger.DebugContext(ctx, logParts...)
|
||||||
}
|
}
|
||||||
if selectedCredential.usageTrackerOrNil() != nil {
|
if selectedCredential.usageTrackerOrNil() != nil {
|
||||||
select {
|
select {
|
||||||
@@ -313,21 +314,21 @@ func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamCo
|
|||||||
err = wsutil.WriteClientMessage(upstreamConn, opCode, data)
|
err = wsutil.WriteClientMessage(upstreamConn, opCode, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !E.IsClosedOrCanceled(err) {
|
if !E.IsClosedOrCanceled(err) {
|
||||||
s.logger.Debug("write upstream websocket: ", err)
|
s.logger.DebugContext(ctx, "write upstream websocket: ", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWriter, clientConn net.Conn, selectedCredential credential, userConfig *option.OCMUser, provider credentialProvider, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) {
|
func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamReadWriter io.ReadWriter, clientConn net.Conn, selectedCredential credential, userConfig *option.OCMUser, provider credentialProvider, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) {
|
||||||
usageTracker := selectedCredential.usageTrackerOrNil()
|
usageTracker := selectedCredential.usageTrackerOrNil()
|
||||||
var requestModel string
|
var requestModel string
|
||||||
for {
|
for {
|
||||||
data, opCode, err := wsutil.ReadServerData(upstreamReadWriter)
|
data, opCode, err := wsutil.ReadServerData(upstreamReadWriter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !E.IsClosedOrCanceled(err) {
|
if !E.IsClosedOrCanceled(err) {
|
||||||
s.logger.Debug("read upstream websocket: ", err)
|
s.logger.DebugContext(ctx, "read upstream websocket: ", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -367,7 +368,7 @@ func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWrite
|
|||||||
err = wsutil.WriteServerMessage(clientConn, opCode, data)
|
err = wsutil.WriteServerMessage(clientConn, opCode, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !E.IsClosedOrCanceled(err) {
|
if !E.IsClosedOrCanceled(err) {
|
||||||
s.logger.Debug("write client websocket: ", err)
|
s.logger.DebugContext(ctx, "write client websocket: ", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user