Fix OCM websocket proxy lifecycle and headers

This commit is contained in:
世界
2026-03-10 21:17:35 +08:00
parent a09ffe6a0f
commit 67621ee6ba
2 changed files with 105 additions and 23 deletions

View File

@@ -139,7 +139,9 @@ type Service struct {
userManager *UserManager
accessMutex sync.RWMutex
usageTracker *AggregatedUsage
trackingGroup sync.WaitGroup
webSocketMutex sync.Mutex
webSocketGroup sync.WaitGroup
webSocketConns map[*webSocketSession]struct{}
shuttingDown bool
}
@@ -197,8 +199,9 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
Network: []string{N.NetworkTCP},
Listen: options.ListenOptions,
}),
userManager: userManager,
usageTracker: usageTracker,
userManager: userManager,
usageTracker: usageTracker,
webSocketConns: make(map[*webSocketSession]struct{}),
}
if options.TLS != nil {
@@ -631,11 +634,17 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
}
func (s *Service) Close() error {
webSocketSessions := s.startWebSocketShutdown()
err := common.Close(
common.PtrOrNil(s.httpServer),
common.PtrOrNil(s.listener),
s.tlsConfig,
)
for _, session := range webSocketSessions {
session.Close()
}
s.webSocketGroup.Wait()
if s.usageTracker != nil {
s.usageTracker.cancelPendingSave()
@@ -647,3 +656,48 @@ func (s *Service) Close() error {
return err
}
func (s *Service) registerWebSocketSession(session *webSocketSession) bool {
s.webSocketMutex.Lock()
defer s.webSocketMutex.Unlock()
if s.shuttingDown {
return false
}
s.webSocketConns[session] = struct{}{}
s.webSocketGroup.Add(1)
return true
}
func (s *Service) unregisterWebSocketSession(session *webSocketSession) {
s.webSocketMutex.Lock()
_, loaded := s.webSocketConns[session]
if loaded {
delete(s.webSocketConns, session)
}
s.webSocketMutex.Unlock()
if loaded {
s.webSocketGroup.Done()
}
}
func (s *Service) isShuttingDown() bool {
s.webSocketMutex.Lock()
defer s.webSocketMutex.Unlock()
return s.shuttingDown
}
func (s *Service) startWebSocketShutdown() []*webSocketSession {
s.webSocketMutex.Lock()
defer s.webSocketMutex.Unlock()
s.shuttingDown = true
webSocketSessions := make([]*webSocketSession, 0, len(s.webSocketConns))
for session := range s.webSocketConns {
webSocketSessions = append(webSocketSessions, session)
}
return webSocketSessions
}

View File

@@ -21,6 +21,19 @@ import (
"github.com/openai/openai-go/v3/responses"
)
type webSocketSession struct {
clientConn net.Conn
upstreamConn net.Conn
closeOnce sync.Once
}
func (s *webSocketSession) Close() {
s.closeOnce.Do(func() {
s.clientConn.Close()
s.upstreamConn.Close()
})
}
func buildUpstreamWebSocketURL(baseURL string, proxyPath string) string {
upstreamURL := baseURL
if strings.HasPrefix(upstreamURL, "https://") {
@@ -47,6 +60,22 @@ func isForwardableResponseHeader(key string) bool {
}
}
func isForwardableWebSocketRequestHeader(key string) bool {
if isHopByHopHeader(key) {
return false
}
lowerKey := strings.ToLower(key)
switch {
case lowerKey == "authorization":
return false
case strings.HasPrefix(lowerKey, "sec-websocket-"):
return false
default:
return true
}
}
func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyPath string, username string) {
accessToken, err := s.getAccessToken()
if err != nil {
@@ -61,18 +90,8 @@ func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyP
}
upstreamHeaders := make(http.Header)
forwardHeaders := []string{
"OpenAI-Beta",
"X-Conversation-ID",
}
for _, headerKey := range forwardHeaders {
if value := r.Header.Get(headerKey); value != "" {
upstreamHeaders.Set(headerKey, value)
}
}
for key, values := range r.Header {
lowerKey := strings.ToLower(key)
if strings.HasPrefix(lowerKey, "x-codex-") || strings.HasPrefix(lowerKey, "x-responsesapi-") {
if isForwardableWebSocketRequestHeader(key) {
upstreamHeaders[key] = values
}
}
@@ -87,8 +106,8 @@ func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyP
upstreamResponseHeaders := make(http.Header)
upstreamDialer := ws.Dialer{
NetDial: func(_ context.Context, network, addr string) (net.Conn, error) {
return s.dialer.DialContext(s.ctx, network, M.ParseSocksaddr(addr))
NetDial: func(ctx context.Context, network, addr string) (net.Conn, error) {
return s.dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
},
TLSConfig: &stdTLS.Config{
RootCAs: adapter.RootPoolFromContext(s.ctx),
@@ -120,12 +139,26 @@ func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyP
clientUpgrader := ws.HTTPUpgrader{
Header: clientResponseHeaders,
}
if s.isShuttingDown() {
upstreamConn.Close()
writeJSONError(w, r, http.StatusServiceUnavailable, "api_error", "service is shutting down")
return
}
clientConn, _, _, err := clientUpgrader.Upgrade(r, w)
if err != nil {
s.logger.Error("upgrade client websocket: ", err)
upstreamConn.Close()
return
}
session := &webSocketSession{
clientConn: clientConn,
upstreamConn: upstreamConn,
}
if !s.registerWebSocketSession(session) {
session.Close()
return
}
defer s.unregisterWebSocketSession(session)
var upstreamReadWriter io.ReadWriter
if upstreamBufferedReader != nil {
@@ -139,21 +172,16 @@ func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyP
modelChannel := make(chan string, 1)
var waitGroup sync.WaitGroup
var once sync.Once
closeAll := func() {
clientConn.Close()
upstreamConn.Close()
}
waitGroup.Add(2)
go func() {
defer waitGroup.Done()
defer once.Do(closeAll)
defer session.Close()
s.proxyWebSocketClientToUpstream(clientConn, upstreamConn, modelChannel)
}()
go func() {
defer waitGroup.Done()
defer once.Do(closeAll)
defer session.Close()
s.proxyWebSocketUpstreamToClient(upstreamReadWriter, clientConn, modelChannel, username, weeklyCycleHint)
}()
waitGroup.Wait()