diff --git a/service/ocm/service.go b/service/ocm/service.go
index e05c95477..0c2e34301 100644
--- a/service/ocm/service.go
+++ b/service/ocm/service.go
@@ -139,7 +139,9 @@ type Service struct {
 	userManager    *UserManager
 	accessMutex    sync.RWMutex
 	usageTracker   *AggregatedUsage
-	trackingGroup  sync.WaitGroup
+	webSocketMutex sync.Mutex // held by the WebSocket registry helpers while they read or write webSocketConns and shuttingDown
+	webSocketGroup sync.WaitGroup // one count per registered WebSocket session; Close waits on it
+	webSocketConns map[*webSocketSession]struct{} // set of currently registered WebSocket sessions
 	shuttingDown   bool
 }
 
@@ -197,8 +199,9 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
 			Network: []string{N.NetworkTCP},
 			Listen:  options.ListenOptions,
 		}),
-		userManager:  userManager,
-		usageTracker: usageTracker,
+		userManager:    userManager,
+		usageTracker:   usageTracker,
+		webSocketConns: make(map[*webSocketSession]struct{}), // non-nil from the start so registerWebSocketSession can insert into it
 	}
 
 	if options.TLS != nil {
@@ -631,11 +634,17 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
 }
 
 func (s *Service) Close() error {
+	webSocketSessions := s.startWebSocketShutdown() // flag shutdown first so no new session can register, and snapshot the registered ones
+
 	err := common.Close(
 		common.PtrOrNil(s.httpServer),
 		common.PtrOrNil(s.listener),
 		s.tlsConfig,
 	)
+	for _, session := range webSocketSessions { // close the client and upstream connections of every snapshotted session
+		session.Close()
+	}
+	s.webSocketGroup.Wait() // block until every registered session has been unregistered
 
 	if s.usageTracker != nil {
 		s.usageTracker.cancelPendingSave()
@@ -647,3 +656,48 @@ func (s *Service) Close() error {
 
 	return err
 }
+
+func (s *Service) registerWebSocketSession(session *webSocketSession) bool { // adds session to the registry; returns false and registers nothing once shuttingDown is set
+	s.webSocketMutex.Lock()
+	defer s.webSocketMutex.Unlock()
+
+	if s.shuttingDown {
+		return false
+	}
+
+	s.webSocketConns[session] = struct{}{}
+	s.webSocketGroup.Add(1) // performed under the mutex, so it cannot run after startWebSocketShutdown has set shuttingDown
+	return true
+}
+
+func (s *Service) unregisterWebSocketSession(session *webSocketSession) { // removes session from the registry; does nothing if it is not registered
+	s.webSocketMutex.Lock()
+	_, loaded := s.webSocketConns[session]
+	if loaded {
+		delete(s.webSocketConns, session)
+	}
+	s.webSocketMutex.Unlock()
+
+	if loaded { // release the WaitGroup count only when this call actually removed the entry
+		s.webSocketGroup.Done()
+	}
+}
+
+func (s *Service) isShuttingDown() bool { // reports the shuttingDown flag, read under webSocketMutex
+	s.webSocketMutex.Lock()
+	defer s.webSocketMutex.Unlock()
+	return s.shuttingDown
+}
+
+func (s *Service) startWebSocketShutdown() []*webSocketSession { // sets shuttingDown and returns a snapshot of the registered sessions; the registry itself is left untouched
+	s.webSocketMutex.Lock()
+	defer s.webSocketMutex.Unlock()
+
+	s.shuttingDown = true
+
+	webSocketSessions := make([]*webSocketSession, 0, len(s.webSocketConns)) // copied into a slice so the caller can use it after the mutex is released
+	for session := range s.webSocketConns {
+		webSocketSessions = append(webSocketSessions, session)
+	}
+	return webSocketSessions
+}
diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go
index 5e8cb8bbd..c2e6148d2 100644
--- a/service/ocm/service_websocket.go
+++ b/service/ocm/service_websocket.go
@@ -21,6 +21,19 @@ import (
 	"github.com/openai/openai-go/v3/responses"
 )
 
+type webSocketSession struct { // holds the client-side and upstream-side connections of one proxied WebSocket
+	clientConn   net.Conn
+	upstreamConn net.Conn
+	closeOnce    sync.Once // ensures the body of Close runs at most once
+}
+
+func (s *webSocketSession) Close() { // closes both connections; repeated calls are no-ops and the Close errors are discarded
+	s.closeOnce.Do(func() {
+		s.clientConn.Close()
+		s.upstreamConn.Close()
+	})
+}
+
 func buildUpstreamWebSocketURL(baseURL string, proxyPath string) string {
 	upstreamURL := baseURL
 	if strings.HasPrefix(upstreamURL, "https://") {
@@ -47,6 +60,22 @@ func isForwardableResponseHeader(key string) bool {
 	}
 }
 
+func isForwardableWebSocketRequestHeader(key string) bool { // reports whether a client request header may be copied to the upstream handshake
+	if isHopByHopHeader(key) {
+		return false
+	}
+
+	lowerKey := strings.ToLower(key) // compare case-insensitively
+	switch {
+	case lowerKey == "authorization": // never forwarded; presumably the proxy supplies its own upstream credentials (TODO confirm)
+		return false
+	case strings.HasPrefix(lowerKey, "sec-websocket-"): // handshake headers are dropped; presumably regenerated by the upstream dialer (TODO confirm)
+		return false
+	default:
+		return true
+	}
+}
+
 func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyPath string, username string) {
 	accessToken, err := s.getAccessToken()
 	if err != nil {
@@ -61,18 +90,8 @@ func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyP
 	}
 
 	upstreamHeaders := make(http.Header)
-	forwardHeaders := []string{
-		"OpenAI-Beta",
-		"X-Conversation-ID",
-	}
-	for _, headerKey := range forwardHeaders {
-		if value := r.Header.Get(headerKey); value != "" {
-			upstreamHeaders.Set(headerKey, value)
-		}
-	}
 	for key, values := range r.Header {
-		lowerKey := strings.ToLower(key)
-		if strings.HasPrefix(lowerKey, "x-codex-") || strings.HasPrefix(lowerKey, "x-responsesapi-") {
+		if isForwardableWebSocketRequestHeader(key) { // deny-list filter: copy every header except those rejected by isHopByHopHeader, Authorization and Sec-WebSocket-*
 			upstreamHeaders[key] = values
 		}
 	}
@@ -87,8 +106,8 @@ func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyP
 
 	upstreamResponseHeaders := make(http.Header)
 	upstreamDialer := ws.Dialer{
-		NetDial: func(_ context.Context, network, addr string) (net.Conn, error) {
-			return s.dialer.DialContext(s.ctx, network, M.ParseSocksaddr(addr))
+		NetDial: func(ctx context.Context, network, addr string) (net.Conn, error) { // dial with the context handed in by the WebSocket dialer instead of s.ctx
+			return s.dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
 		},
 		TLSConfig: &stdTLS.Config{
 			RootCAs: adapter.RootPoolFromContext(s.ctx),
@@ -120,12 +139,26 @@ func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyP
 	clientUpgrader := ws.HTTPUpgrader{
 		Header: clientResponseHeaders,
 	}
+	if s.isShuttingDown() { // refuse before the upgrade, while a plain HTTP error can still be written to w
+		upstreamConn.Close()
+		writeJSONError(w, r, http.StatusServiceUnavailable, "api_error", "service is shutting down")
+		return
+	}
 	clientConn, _, _, err := clientUpgrader.Upgrade(r, w)
 	if err != nil {
 		s.logger.Error("upgrade client websocket: ", err)
 		upstreamConn.Close()
 		return
 	}
+	session := &webSocketSession{
+		clientConn:   clientConn,
+		upstreamConn: upstreamConn,
+	}
+	if !s.registerWebSocketSession(session) { // shutdown started after the isShuttingDown check above: close both connections and stop
+		session.Close()
+		return
+	}
+	defer s.unregisterWebSocketSession(session) // runs when handleWebSocket returns
 
 	var upstreamReadWriter io.ReadWriter
 	if upstreamBufferedReader != nil {
@@ -139,21 +172,16 @@ func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyP
 
 	modelChannel := make(chan string, 1)
 	var waitGroup sync.WaitGroup
-	var once sync.Once
-	closeAll := func() {
-		clientConn.Close()
-		upstreamConn.Close()
-	}
 
 	waitGroup.Add(2)
 	go func() {
 		defer waitGroup.Done()
-		defer once.Do(closeAll)
+		defer session.Close() // when this direction ends, close both connections
 		s.proxyWebSocketClientToUpstream(clientConn, upstreamConn, modelChannel)
 	}()
 	go func() {
 		defer waitGroup.Done()
-		defer once.Do(closeAll)
+		defer session.Close() // same session as above, so the connections are closed only once
 		s.proxyWebSocketUpstreamToClient(upstreamReadWriter, clientConn, modelChannel, username, weeklyCycleHint)
 	}()
 	waitGroup.Wait()