From 2c907bef2ca46a0dc0059f6e67204d4a5a5718c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 14 Mar 2026 17:38:40 +0800 Subject: [PATCH] Fix scoped rebalance interrupts --- service/ccm/credential_state.go | 145 +++++++++++++++++++++---------- service/ccm/service.go | 48 ++++++---- service/ocm/credential_state.go | 145 +++++++++++++++++++++---------- service/ocm/service.go | 50 +++++++---- service/ocm/service_websocket.go | 59 +++++++++---- 5 files changed, 304 insertions(+), 143 deletions(-) diff --git a/service/ccm/credential_state.go b/service/ccm/credential_state.go index 490c6148f..6b1a766f2 100644 --- a/service/ccm/credential_state.go +++ b/service/ccm/credential_state.go @@ -855,14 +855,37 @@ func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *htt // credentialProvider is the interface for all credential types. type credentialProvider interface { - selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) - onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential - wrapProviderInterrupt(cred credential, requestContext *credentialRequestContext) + selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) + onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential + linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool pollIfStale(ctx context.Context) allCredentials() []credential close() } +type credentialSelectionScope string + +const ( + credentialSelectionScopeAll credentialSelectionScope = "all" + credentialSelectionScopeNonExternal credentialSelectionScope = "non_external" +) + +type credentialSelection struct { + scope credentialSelectionScope + filter func(credential) bool +} + +func (s credentialSelection) allows(cred credential) bool { + return s.filter == nil || s.filter(cred) +} + +func (s credentialSelection) scopeOrDefault() credentialSelectionScope { + if s.scope == "" { + return credentialSelectionScopeAll + } + return s.scope +} + // singleCredentialProvider wraps a single credential (legacy or single default). type singleCredentialProvider struct { cred credential @@ -870,8 +893,8 @@ type singleCredentialProvider struct { sessions map[string]time.Time } -func (p *singleCredentialProvider) selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) { - if filter != nil && !filter(p.cred) { +func (p *singleCredentialProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) { + if !selection.allows(p.cred) { return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out") } if !p.cred.isAvailable() { @@ -896,7 +919,7 @@ func (p *singleCredentialProvider) selectCredential(sessionID string, filter fun return p.cred, isNew, nil } -func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ func(credential) bool) credential { +func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ credentialSelection) credential { cred.markRateLimited(resetAt) return nil } @@ -920,15 +943,25 @@ func (p *singleCredentialProvider) allCredentials() []credential { return []credential{p.cred} } -func (p *singleCredentialProvider) wrapProviderInterrupt(_ credential, _ *credentialRequestContext) {} +func (p *singleCredentialProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool { + return func() bool { + return false + } +} func (p *singleCredentialProvider) close() {} const sessionExpiry = 24 * time.Hour type sessionEntry struct { - tag string - createdAt time.Time + tag string + selectionScope credentialSelectionScope + createdAt time.Time +} + +type credentialInterruptKey struct { + tag string + selectionScope credentialSelectionScope } type credentialInterruptEntry struct { @@ -946,7 +979,7 @@ type balancerProvider struct { sessionMutex sync.RWMutex sessions map[string]sessionEntry interruptAccess sync.Mutex - credentialInterrupts map[string]credentialInterruptEntry + credentialInterrupts map[credentialInterruptKey]credentialInterruptEntry logger log.ContextLogger } @@ -960,34 +993,37 @@ func newBalancerProvider(credentials []credential, strategy string, pollInterval pollInterval: pollInterval, rebalanceThreshold: rebalanceThreshold, sessions: make(map[string]sessionEntry), - credentialInterrupts: make(map[string]credentialInterruptEntry), + credentialInterrupts: make(map[credentialInterruptKey]credentialInterruptEntry), logger: logger, } } -func (p *balancerProvider) selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) { +func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) { + selectionScope := selection.scopeOrDefault() if sessionID != "" { p.sessionMutex.RLock() entry, exists := p.sessions[sessionID] p.sessionMutex.RUnlock() if exists { - for _, cred := range p.credentials { - if cred.tagName() == entry.tag && (filter == nil || filter(cred)) && cred.isUsable() { - if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == "least_used") { - better := p.pickLeastUsed(filter) - if better != nil && better.tagName() != cred.tagName() { - effectiveThreshold := p.rebalanceThreshold / cred.planWeight() - delta := cred.weeklyUtilization() - better.weeklyUtilization() - if delta > effectiveThreshold { - p.logger.Info("rebalancing away from ", cred.tagName(), - ": utilization delta ", delta, "% exceeds effective threshold ", - effectiveThreshold, "% (weight ", cred.planWeight(), ")") - p.rebalanceCredential(cred.tagName()) - break + if entry.selectionScope == selectionScope { + for _, cred := range p.credentials { + if cred.tagName() == entry.tag && selection.allows(cred) && cred.isUsable() { + if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == "least_used") { + better := p.pickLeastUsed(selection.filter) + if better != nil && better.tagName() != cred.tagName() { + effectiveThreshold := p.rebalanceThreshold / cred.planWeight() + delta := cred.weeklyUtilization() - better.weeklyUtilization() + if delta > effectiveThreshold { + p.logger.Info("rebalancing away from ", cred.tagName(), + ": utilization delta ", delta, "% exceeds effective threshold ", + effectiveThreshold, "% (weight ", cred.planWeight(), ")") + p.rebalanceCredential(cred.tagName(), selectionScope) + break + } } } + return cred, false, nil } - return cred, false, nil } } p.sessionMutex.Lock() @@ -996,7 +1032,7 @@ func (p *balancerProvider) selectCredential(sessionID string, filter func(creden } } - best := p.pickCredential(filter) + best := p.pickCredential(selection.filter) if best == nil { return nil, false, allCredentialsUnavailableError(p.credentials) } @@ -1004,47 +1040,52 @@ func (p *balancerProvider) selectCredential(sessionID string, filter func(creden isNew := sessionID != "" if isNew { p.sessionMutex.Lock() - p.sessions[sessionID] = sessionEntry{tag: best.tagName(), createdAt: time.Now()} + p.sessions[sessionID] = sessionEntry{ + tag: best.tagName(), + selectionScope: selectionScope, + createdAt: time.Now(), + } p.sessionMutex.Unlock() } return best, isNew, nil } -func (p *balancerProvider) rebalanceCredential(tag string) { +func (p *balancerProvider) rebalanceCredential(tag string, selectionScope credentialSelectionScope) { + key := credentialInterruptKey{tag: tag, selectionScope: selectionScope} p.interruptAccess.Lock() - if entry, loaded := p.credentialInterrupts[tag]; loaded { + if entry, loaded := p.credentialInterrupts[key]; loaded { entry.cancel() } ctx, cancel := context.WithCancel(context.Background()) - p.credentialInterrupts[tag] = credentialInterruptEntry{context: ctx, cancel: cancel} + p.credentialInterrupts[key] = credentialInterruptEntry{context: ctx, cancel: cancel} p.interruptAccess.Unlock() p.sessionMutex.Lock() for id, entry := range p.sessions { - if entry.tag == tag { + if entry.tag == tag && entry.selectionScope == selectionScope { delete(p.sessions, id) } } p.sessionMutex.Unlock() } -func (p *balancerProvider) wrapProviderInterrupt(cred credential, requestContext *credentialRequestContext) { - tag := cred.tagName() +func (p *balancerProvider) linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool { + key := credentialInterruptKey{ + tag: cred.tagName(), + selectionScope: selection.scopeOrDefault(), + } p.interruptAccess.Lock() - entry, loaded := p.credentialInterrupts[tag] + entry, loaded := p.credentialInterrupts[key] if !loaded { ctx, cancel := context.WithCancel(context.Background()) entry = credentialInterruptEntry{context: ctx, cancel: cancel} - p.credentialInterrupts[tag] = entry + p.credentialInterrupts[key] = entry } p.interruptAccess.Unlock() - stop := context.AfterFunc(entry.context, func() { - requestContext.cancelOnce.Do(requestContext.cancelFunc) - }) - requestContext.addInterruptLink(stop) + return context.AfterFunc(entry.context, onInterrupt) } -func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential { +func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential { cred.markRateLimited(resetAt) if sessionID != "" { p.sessionMutex.Lock() @@ -1052,10 +1093,14 @@ func (p *balancerProvider) onRateLimited(sessionID string, cred credential, rese p.sessionMutex.Unlock() } - best := p.pickCredential(filter) + best := p.pickCredential(selection.filter) if best != nil && sessionID != "" { p.sessionMutex.Lock() - p.sessions[sessionID] = sessionEntry{tag: best.tagName(), createdAt: time.Now()} + p.sessions[sessionID] = sessionEntry{ + tag: best.tagName(), + selectionScope: selection.scopeOrDefault(), + createdAt: time.Now(), + } p.sessionMutex.Unlock() } return best @@ -1196,9 +1241,9 @@ func newFallbackProvider(credentials []credential, pollInterval time.Duration, l } } -func (p *fallbackProvider) selectCredential(_ string, filter func(credential) bool) (credential, bool, error) { +func (p *fallbackProvider) selectCredential(_ string, selection credentialSelection) (credential, bool, error) { for _, cred := range p.credentials { - if filter != nil && !filter(cred) { + if !selection.allows(cred) { continue } if cred.isUsable() { @@ -1208,10 +1253,10 @@ func (p *fallbackProvider) selectCredential(_ string, filter func(credential) bo return nil, false, allCredentialsUnavailableError(p.credentials) } -func (p *fallbackProvider) onRateLimited(_ string, cred credential, resetAt time.Time, filter func(credential) bool) credential { +func (p *fallbackProvider) onRateLimited(_ string, cred credential, resetAt time.Time, selection credentialSelection) credential { cred.markRateLimited(resetAt) for _, candidate := range p.credentials { - if filter != nil && !filter(candidate) { + if !selection.allows(candidate) { continue } if candidate.isUsable() { @@ -1233,7 +1278,11 @@ func (p *fallbackProvider) allCredentials() []credential { return p.credentials } -func (p *fallbackProvider) wrapProviderInterrupt(_ credential, _ *credentialRequestContext) {} +func (p *fallbackProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool { + return func() bool { + return false + } +} func (p *fallbackProvider) close() {} diff --git a/service/ccm/service.go b/service/ccm/service.go index 58bdd3787..6a2aa2b74 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -67,7 +67,7 @@ func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, erro }) } -func hasAlternativeCredential(provider credentialProvider, currentCredential credential, filter func(credential) bool) bool { +func hasAlternativeCredential(provider credentialProvider, currentCredential credential, selection credentialSelection) bool { if provider == nil || currentCredential == nil { return false } @@ -75,7 +75,7 @@ func hasAlternativeCredential(provider credentialProvider, currentCredential cre if cred == currentCredential { continue } - if filter != nil && !filter(cred) { + if !selection.allows(cred) { continue } if cred.isUsable() { @@ -109,16 +109,27 @@ func writeCredentialUnavailableError( r *http.Request, provider credentialProvider, currentCredential credential, - filter func(credential) bool, + selection credentialSelection, fallback string, ) { - if hasAlternativeCredential(provider, currentCredential, filter) { + if hasAlternativeCredential(provider, currentCredential, selection) { writeRetryableUsageError(w, r) return } writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, fallback)) } +func credentialSelectionForUser(userConfig *option.CCMUser) credentialSelection { + selection := credentialSelection{scope: credentialSelectionScopeAll} + if userConfig != nil && !userConfig.AllowExternalUsage { + selection.scope = credentialSelectionScopeNonExternal + selection.filter = func(cred credential) bool { + return !cred.isExternal() + } + } + return selection +} + func isHopByHopHeader(header string) bool { switch strings.ToLower(header) { case "connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade", "host": @@ -424,12 +435,9 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } - var credentialFilter func(credential) bool - if userConfig != nil && !userConfig.AllowExternalUsage { - credentialFilter = func(c credential) bool { return !c.isExternal() } - } + selection := credentialSelectionForUser(userConfig) - selectedCredential, isNew, err := provider.selectCredential(sessionID, credentialFilter) + selectedCredential, isNew, err := provider.selectCredential(sessionID, selection) if err != nil { writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, err.Error())) return @@ -459,7 +467,12 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } requestContext := selectedCredential.wrapRequestContext(ctx) - provider.wrapProviderInterrupt(selectedCredential, requestContext) + { + currentRequestContext := requestContext + requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() { + currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc) + })) + } defer func() { requestContext.cancelRequest() }() @@ -476,7 +489,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } if requestContext.Err() != nil { - writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "credential became unavailable while processing the request") + writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable while processing the request") return } writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error()) @@ -487,18 +500,23 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Transparent 429 retry for response.StatusCode == http.StatusTooManyRequests { resetAt := parseRateLimitResetFromHeaders(response.Header) - nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, credentialFilter) + nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection) selectedCredential.updateStateFromHeaders(response.Header) if bodyBytes == nil || nextCredential == nil { response.Body.Close() - writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited") + writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited") return } response.Body.Close() s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) requestContext.cancelRequest() requestContext = nextCredential.wrapRequestContext(ctx) - provider.wrapProviderInterrupt(nextCredential, requestContext) + { + currentRequestContext := requestContext + requestContext.addInterruptLink(provider.linkProviderInterrupt(nextCredential, selection, func() { + currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc) + })) + } retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) if buildErr != nil { s.logger.ErrorContext(ctx, "retry request: ", buildErr) @@ -511,7 +529,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } if requestContext.Err() != nil { - writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request") + writeCredentialUnavailableError(w, r, provider, nextCredential, selection, "credential became unavailable while retrying the request") return } s.logger.ErrorContext(ctx, "retry request: ", retryErr) diff --git a/service/ocm/credential_state.go b/service/ocm/credential_state.go index 7a6c8ef5e..d8f2e826a 100644 --- a/service/ocm/credential_state.go +++ b/service/ocm/credential_state.go @@ -852,22 +852,45 @@ func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *htt } type credentialProvider interface { - selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) - onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential - wrapProviderInterrupt(cred credential, requestContext *credentialRequestContext) + selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) + onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential + linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool pollIfStale(ctx context.Context) allCredentials() []credential close() } +type credentialSelectionScope string + +const ( + credentialSelectionScopeAll credentialSelectionScope = "all" + credentialSelectionScopeNonExternal credentialSelectionScope = "non_external" +) + +type credentialSelection struct { + scope credentialSelectionScope + filter func(credential) bool +} + +func (s credentialSelection) allows(cred credential) bool { + return s.filter == nil || s.filter(cred) +} + +func (s credentialSelection) scopeOrDefault() credentialSelectionScope { + if s.scope == "" { + return credentialSelectionScopeAll + } + return s.scope +} + type singleCredentialProvider struct { cred credential sessionAccess sync.RWMutex sessions map[string]time.Time } -func (p *singleCredentialProvider) selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) { - if filter != nil && !filter(p.cred) { +func (p *singleCredentialProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) { + if !selection.allows(p.cred) { return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out") } if !p.cred.isAvailable() { @@ -892,7 +915,7 @@ func (p *singleCredentialProvider) selectCredential(sessionID string, filter fun return p.cred, isNew, nil } -func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ func(credential) bool) credential { +func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ credentialSelection) credential { cred.markRateLimited(resetAt) return nil } @@ -916,15 +939,25 @@ func (p *singleCredentialProvider) allCredentials() []credential { return []credential{p.cred} } -func (p *singleCredentialProvider) wrapProviderInterrupt(_ credential, _ *credentialRequestContext) {} +func (p *singleCredentialProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool { + return func() bool { + return false + } +} func (p *singleCredentialProvider) close() {} const sessionExpiry = 24 * time.Hour type sessionEntry struct { - tag string - createdAt time.Time + tag string + selectionScope credentialSelectionScope + createdAt time.Time +} + +type credentialInterruptKey struct { + tag string + selectionScope credentialSelectionScope } type credentialInterruptEntry struct { @@ -941,7 +974,7 @@ type balancerProvider struct { sessionMutex sync.RWMutex sessions map[string]sessionEntry interruptAccess sync.Mutex - credentialInterrupts map[string]credentialInterruptEntry + credentialInterrupts map[credentialInterruptKey]credentialInterruptEntry logger log.ContextLogger } @@ -959,34 +992,37 @@ func newBalancerProvider(credentials []credential, strategy string, pollInterval pollInterval: pollInterval, rebalanceThreshold: rebalanceThreshold, sessions: make(map[string]sessionEntry), - credentialInterrupts: make(map[string]credentialInterruptEntry), + credentialInterrupts: make(map[credentialInterruptKey]credentialInterruptEntry), logger: logger, } } -func (p *balancerProvider) selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) { +func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) { + selectionScope := selection.scopeOrDefault() if sessionID != "" { p.sessionMutex.RLock() entry, exists := p.sessions[sessionID] p.sessionMutex.RUnlock() if exists { - for _, cred := range p.credentials { - if cred.tagName() == entry.tag && compositeCredentialSelectable(cred) && (filter == nil || filter(cred)) && cred.isUsable() { - if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == "least_used") { - better := p.pickLeastUsed(filter) - if better != nil && better.tagName() != cred.tagName() { - effectiveThreshold := p.rebalanceThreshold / cred.planWeight() - delta := cred.weeklyUtilization() - better.weeklyUtilization() - if delta > effectiveThreshold { - p.logger.Info("rebalancing away from ", cred.tagName(), - ": utilization delta ", delta, "% exceeds effective threshold ", - effectiveThreshold, "% (weight ", cred.planWeight(), ")") - p.rebalanceCredential(cred.tagName()) - break + if entry.selectionScope == selectionScope { + for _, cred := range p.credentials { + if cred.tagName() == entry.tag && compositeCredentialSelectable(cred) && selection.allows(cred) && cred.isUsable() { + if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == "least_used") { + better := p.pickLeastUsed(selection.filter) + if better != nil && better.tagName() != cred.tagName() { + effectiveThreshold := p.rebalanceThreshold / cred.planWeight() + delta := cred.weeklyUtilization() - better.weeklyUtilization() + if delta > effectiveThreshold { + p.logger.Info("rebalancing away from ", cred.tagName(), + ": utilization delta ", delta, "% exceeds effective threshold ", + effectiveThreshold, "% (weight ", cred.planWeight(), ")") + p.rebalanceCredential(cred.tagName(), selectionScope) + break + } } } + return cred, false, nil } - return cred, false, nil } } p.sessionMutex.Lock() @@ -995,7 +1031,7 @@ func (p *balancerProvider) selectCredential(sessionID string, filter func(creden } } - best := p.pickCredential(filter) + best := p.pickCredential(selection.filter) if best == nil { return nil, false, allRateLimitedError(p.credentials) } @@ -1003,47 +1039,52 @@ func (p *balancerProvider) selectCredential(sessionID string, filter func(creden isNew := sessionID != "" if isNew { p.sessionMutex.Lock() - p.sessions[sessionID] = sessionEntry{tag: best.tagName(), createdAt: time.Now()} + p.sessions[sessionID] = sessionEntry{ + tag: best.tagName(), + selectionScope: selectionScope, + createdAt: time.Now(), + } p.sessionMutex.Unlock() } return best, isNew, nil } -func (p *balancerProvider) rebalanceCredential(tag string) { +func (p *balancerProvider) rebalanceCredential(tag string, selectionScope credentialSelectionScope) { + key := credentialInterruptKey{tag: tag, selectionScope: selectionScope} p.interruptAccess.Lock() - if entry, loaded := p.credentialInterrupts[tag]; loaded { + if entry, loaded := p.credentialInterrupts[key]; loaded { entry.cancel() } ctx, cancel := context.WithCancel(context.Background()) - p.credentialInterrupts[tag] = credentialInterruptEntry{context: ctx, cancel: cancel} + p.credentialInterrupts[key] = credentialInterruptEntry{context: ctx, cancel: cancel} p.interruptAccess.Unlock() p.sessionMutex.Lock() for id, entry := range p.sessions { - if entry.tag == tag { + if entry.tag == tag && entry.selectionScope == selectionScope { delete(p.sessions, id) } } p.sessionMutex.Unlock() } -func (p *balancerProvider) wrapProviderInterrupt(cred credential, requestContext *credentialRequestContext) { - tag := cred.tagName() +func (p *balancerProvider) linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool { + key := credentialInterruptKey{ + tag: cred.tagName(), + selectionScope: selection.scopeOrDefault(), + } p.interruptAccess.Lock() - entry, loaded := p.credentialInterrupts[tag] + entry, loaded := p.credentialInterrupts[key] if !loaded { ctx, cancel := context.WithCancel(context.Background()) entry = credentialInterruptEntry{context: ctx, cancel: cancel} - p.credentialInterrupts[tag] = entry + p.credentialInterrupts[key] = entry } p.interruptAccess.Unlock() - stop := context.AfterFunc(entry.context, func() { - requestContext.cancelOnce.Do(requestContext.cancelFunc) - }) - requestContext.addInterruptLink(stop) + return context.AfterFunc(entry.context, onInterrupt) } -func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential { +func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential { cred.markRateLimited(resetAt) if sessionID != "" { p.sessionMutex.Lock() @@ -1051,10 +1092,14 @@ func (p *balancerProvider) onRateLimited(sessionID string, cred credential, rese p.sessionMutex.Unlock() } - best := p.pickCredential(filter) + best := p.pickCredential(selection.filter) if best != nil && sessionID != "" { p.sessionMutex.Lock() - p.sessions[sessionID] = sessionEntry{tag: best.tagName(), createdAt: time.Now()} + p.sessions[sessionID] = sessionEntry{ + tag: best.tagName(), + selectionScope: selection.scopeOrDefault(), + createdAt: time.Now(), + } p.sessionMutex.Unlock() } return best @@ -1193,9 +1238,9 @@ func newFallbackProvider(credentials []credential, pollInterval time.Duration, l } } -func (p *fallbackProvider) selectCredential(_ string, filter func(credential) bool) (credential, bool, error) { +func (p *fallbackProvider) selectCredential(_ string, selection credentialSelection) (credential, bool, error) { for _, cred := range p.credentials { - if filter != nil && !filter(cred) { + if !selection.allows(cred) { continue } if !compositeCredentialSelectable(cred) { @@ -1208,10 +1253,10 @@ func (p *fallbackProvider) selectCredential(_ string, filter func(credential) bo return nil, false, allRateLimitedError(p.credentials) } -func (p *fallbackProvider) onRateLimited(_ string, cred credential, resetAt time.Time, filter func(credential) bool) credential { +func (p *fallbackProvider) onRateLimited(_ string, cred credential, resetAt time.Time, selection credentialSelection) credential { cred.markRateLimited(resetAt) for _, candidate := range p.credentials { - if filter != nil && !filter(candidate) { + if !selection.allows(candidate) { continue } if !compositeCredentialSelectable(candidate) { @@ -1236,7 +1281,11 @@ func (p *fallbackProvider) allCredentials() []credential { return p.credentials } -func (p *fallbackProvider) wrapProviderInterrupt(_ credential, _ *credentialRequestContext) {} +func (p *fallbackProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool { + return func() bool { + return false + } +} func (p *fallbackProvider) close() {} diff --git a/service/ocm/service.go b/service/ocm/service.go index cd7909dd4..071cec8cc 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -75,7 +75,7 @@ const ( retryableUsageCode = "credential_usage_exhausted" ) -func hasAlternativeCredential(provider credentialProvider, currentCredential credential, filter func(credential) bool) bool { +func hasAlternativeCredential(provider credentialProvider, currentCredential credential, selection credentialSelection) bool { if provider == nil || currentCredential == nil { return false } @@ -83,7 +83,7 @@ func hasAlternativeCredential(provider credentialProvider, currentCredential cre if cred == currentCredential { continue } - if filter != nil && !filter(cred) { + if !selection.allows(cred) { continue } if cred.isUsable() { @@ -117,16 +117,27 @@ func writeCredentialUnavailableError( r *http.Request, provider credentialProvider, currentCredential credential, - filter func(credential) bool, + selection credentialSelection, fallback string, ) { - if hasAlternativeCredential(provider, currentCredential, filter) { + if hasAlternativeCredential(provider, currentCredential, selection) { writeRetryableUsageError(w, r) return } writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, fallback)) } +func credentialSelectionForUser(userConfig *option.OCMUser) credentialSelection { + selection := credentialSelection{scope: credentialSelectionScopeAll} + if userConfig != nil && !userConfig.AllowExternalUsage { + selection.scope = credentialSelectionScopeNonExternal + selection.filter = func(cred credential) bool { + return !cred.isExternal() + } + } + return selection +} + func isHopByHopHeader(header string) bool { switch strings.ToLower(header) { case "connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade", "host": @@ -426,19 +437,16 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { provider.pollIfStale(s.ctx) - var credentialFilter func(credential) bool - if userConfig != nil && !userConfig.AllowExternalUsage { - credentialFilter = func(c credential) bool { return !c.isExternal() } - } + selection := credentialSelectionForUser(userConfig) - selectedCredential, isNew, err := provider.selectCredential(sessionID, credentialFilter) + selectedCredential, isNew, err := provider.selectCredential(sessionID, selection) if err != nil { writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, err.Error())) return } if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") { - s.handleWebSocket(ctx, w, r, path, username, sessionID, userConfig, provider, selectedCredential, credentialFilter, isNew) + s.handleWebSocket(ctx, w, r, path, username, sessionID, userConfig, provider, selectedCredential, selection, isNew) return } @@ -500,7 +508,12 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } requestContext := selectedCredential.wrapRequestContext(ctx) - provider.wrapProviderInterrupt(selectedCredential, requestContext) + { + currentRequestContext := requestContext + requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() { + currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc) + })) + } defer func() { requestContext.cancelRequest() }() @@ -517,7 +530,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } if requestContext.Err() != nil { - writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "credential became unavailable while processing the request") + writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable while processing the request") return } writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error()) @@ -528,19 +541,24 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Transparent 429 retry for response.StatusCode == http.StatusTooManyRequests { resetAt := parseOCMRateLimitResetFromHeaders(response.Header) - nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, credentialFilter) + nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection) needsBodyReplay := r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodDelete selectedCredential.updateStateFromHeaders(response.Header) if (needsBodyReplay && bodyBytes == nil) || nextCredential == nil { response.Body.Close() - writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited") + writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited") return } response.Body.Close() s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) requestContext.cancelRequest() requestContext = nextCredential.wrapRequestContext(ctx) - provider.wrapProviderInterrupt(nextCredential, requestContext) + { + currentRequestContext := requestContext + requestContext.addInterruptLink(provider.linkProviderInterrupt(nextCredential, selection, func() { + currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc) + })) + } retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders) if buildErr != nil { s.logger.ErrorContext(ctx, "retry request: ", buildErr) @@ -553,7 +571,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } if requestContext.Err() != nil { - writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request") + writeCredentialUnavailableError(w, r, provider, nextCredential, selection, "credential became unavailable while retrying the request") return } s.logger.ErrorContext(ctx, "retry request: ", retryErr) diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index fcffaae96..4b640d9c5 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -26,16 +26,24 @@ import ( ) type webSocketSession struct { - clientConn net.Conn - upstreamConn net.Conn - credentialTag string - closeOnce sync.Once + clientConn net.Conn + upstreamConn net.Conn + credentialTag string + releaseProviderInterrupt func() + closeOnce sync.Once } func (s *webSocketSession) Close() { s.closeOnce.Do(func() { - s.clientConn.Close() - s.upstreamConn.Close() + if s.releaseProviderInterrupt != nil { + s.releaseProviderInterrupt() + } + if s.clientConn != nil { + s.clientConn.Close() + } + if s.upstreamConn != nil { + s.upstreamConn.Close() + } }) } @@ -91,12 +99,14 @@ func (s *Service) handleWebSocket( userConfig *option.OCMUser, provider credentialProvider, selectedCredential credential, - credentialFilter func(credential) bool, + selection credentialSelection, isNew bool, ) { var ( err error requestContext *credentialRequestContext + clientConn net.Conn + session *webSocketSession upstreamConn net.Conn upstreamBufferedReader *bufio.Reader upstreamResponseHeaders http.Header @@ -186,20 +196,36 @@ func (s *Service) handleWebSocket( } requestContext = selectedCredential.wrapRequestContext(ctx) - provider.wrapProviderInterrupt(selectedCredential, requestContext) + { + currentRequestContext := requestContext + requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() { + currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc) + if session != nil { + session.Close() + return + } + if clientConn != nil { + clientConn.Close() + } + if upstreamConn != nil { + upstreamConn.Close() + } + })) + } upstreamConn, upstreamBufferedReader, _, err = upstreamDialer.Dial(requestContext, upstreamURL) if err == nil { - requestContext.releaseCredentialInterrupt() break } requestContext.cancelRequest() requestContext = nil + upstreamConn = nil + clientConn = nil if statusCode == http.StatusTooManyRequests { resetAt := parseOCMRateLimitResetFromHeaders(upstreamResponseHeaders) - nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, credentialFilter) + nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection) selectedCredential.updateStateFromHeaders(upstreamResponseHeaders) if nextCredential == nil { - writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited") + writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited") return } s.logger.InfoContext(ctx, "retrying websocket with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName()) @@ -236,16 +262,17 @@ func (s *Service) handleWebSocket( writeJSONError(w, r, http.StatusServiceUnavailable, "api_error", "service is shutting down") return } - clientConn, _, _, err := clientUpgrader.Upgrade(r, w) + clientConn, _, _, err = clientUpgrader.Upgrade(r, w) if err != nil { s.logger.ErrorContext(ctx, "upgrade client websocket: ", err) upstreamConn.Close() return } - session := &webSocketSession{ - clientConn: clientConn, - upstreamConn: upstreamConn, - credentialTag: selectedCredential.tagName(), + session = &webSocketSession{ + clientConn: clientConn, + upstreamConn: upstreamConn, + credentialTag: selectedCredential.tagName(), + releaseProviderInterrupt: requestContext.releaseCredentialInterrupt, } if !s.registerWebSocketSession(session) { session.Close()