Fix scoped rebalance interrupts

This commit is contained in:
世界
2026-03-14 17:38:40 +08:00
parent d2300353fd
commit 2c907bef2c
5 changed files with 304 additions and 143 deletions

View File

@@ -855,14 +855,37 @@ func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *htt
// credentialProvider is the interface for all credential types.
type credentialProvider interface {
selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error)
onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential
wrapProviderInterrupt(cred credential, requestContext *credentialRequestContext)
selectCredential(sessionID string, selection credentialSelection) (credential, bool, error)
onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential
linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool
pollIfStale(ctx context.Context)
allCredentials() []credential
close()
}
type credentialSelectionScope string
const (
credentialSelectionScopeAll credentialSelectionScope = "all"
credentialSelectionScopeNonExternal credentialSelectionScope = "non_external"
)
type credentialSelection struct {
scope credentialSelectionScope
filter func(credential) bool
}
func (s credentialSelection) allows(cred credential) bool {
return s.filter == nil || s.filter(cred)
}
func (s credentialSelection) scopeOrDefault() credentialSelectionScope {
if s.scope == "" {
return credentialSelectionScopeAll
}
return s.scope
}
// singleCredentialProvider wraps a single credential (legacy or single default).
type singleCredentialProvider struct {
cred credential
@@ -870,8 +893,8 @@ type singleCredentialProvider struct {
sessions map[string]time.Time
}
func (p *singleCredentialProvider) selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) {
if filter != nil && !filter(p.cred) {
func (p *singleCredentialProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) {
if !selection.allows(p.cred) {
return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out")
}
if !p.cred.isAvailable() {
@@ -896,7 +919,7 @@ func (p *singleCredentialProvider) selectCredential(sessionID string, filter fun
return p.cred, isNew, nil
}
func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ func(credential) bool) credential {
func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ credentialSelection) credential {
cred.markRateLimited(resetAt)
return nil
}
@@ -920,15 +943,25 @@ func (p *singleCredentialProvider) allCredentials() []credential {
return []credential{p.cred}
}
func (p *singleCredentialProvider) wrapProviderInterrupt(_ credential, _ *credentialRequestContext) {}
func (p *singleCredentialProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool {
return func() bool {
return false
}
}
func (p *singleCredentialProvider) close() {}
const sessionExpiry = 24 * time.Hour
type sessionEntry struct {
tag string
createdAt time.Time
tag string
selectionScope credentialSelectionScope
createdAt time.Time
}
type credentialInterruptKey struct {
tag string
selectionScope credentialSelectionScope
}
type credentialInterruptEntry struct {
@@ -946,7 +979,7 @@ type balancerProvider struct {
sessionMutex sync.RWMutex
sessions map[string]sessionEntry
interruptAccess sync.Mutex
credentialInterrupts map[string]credentialInterruptEntry
credentialInterrupts map[credentialInterruptKey]credentialInterruptEntry
logger log.ContextLogger
}
@@ -960,34 +993,37 @@ func newBalancerProvider(credentials []credential, strategy string, pollInterval
pollInterval: pollInterval,
rebalanceThreshold: rebalanceThreshold,
sessions: make(map[string]sessionEntry),
credentialInterrupts: make(map[string]credentialInterruptEntry),
credentialInterrupts: make(map[credentialInterruptKey]credentialInterruptEntry),
logger: logger,
}
}
func (p *balancerProvider) selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) {
func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) {
selectionScope := selection.scopeOrDefault()
if sessionID != "" {
p.sessionMutex.RLock()
entry, exists := p.sessions[sessionID]
p.sessionMutex.RUnlock()
if exists {
for _, cred := range p.credentials {
if cred.tagName() == entry.tag && (filter == nil || filter(cred)) && cred.isUsable() {
if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == "least_used") {
better := p.pickLeastUsed(filter)
if better != nil && better.tagName() != cred.tagName() {
effectiveThreshold := p.rebalanceThreshold / cred.planWeight()
delta := cred.weeklyUtilization() - better.weeklyUtilization()
if delta > effectiveThreshold {
p.logger.Info("rebalancing away from ", cred.tagName(),
": utilization delta ", delta, "% exceeds effective threshold ",
effectiveThreshold, "% (weight ", cred.planWeight(), ")")
p.rebalanceCredential(cred.tagName())
break
if entry.selectionScope == selectionScope {
for _, cred := range p.credentials {
if cred.tagName() == entry.tag && selection.allows(cred) && cred.isUsable() {
if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == "least_used") {
better := p.pickLeastUsed(selection.filter)
if better != nil && better.tagName() != cred.tagName() {
effectiveThreshold := p.rebalanceThreshold / cred.planWeight()
delta := cred.weeklyUtilization() - better.weeklyUtilization()
if delta > effectiveThreshold {
p.logger.Info("rebalancing away from ", cred.tagName(),
": utilization delta ", delta, "% exceeds effective threshold ",
effectiveThreshold, "% (weight ", cred.planWeight(), ")")
p.rebalanceCredential(cred.tagName(), selectionScope)
break
}
}
}
return cred, false, nil
}
return cred, false, nil
}
}
p.sessionMutex.Lock()
@@ -996,7 +1032,7 @@ func (p *balancerProvider) selectCredential(sessionID string, filter func(creden
}
}
best := p.pickCredential(filter)
best := p.pickCredential(selection.filter)
if best == nil {
return nil, false, allCredentialsUnavailableError(p.credentials)
}
@@ -1004,47 +1040,52 @@ func (p *balancerProvider) selectCredential(sessionID string, filter func(creden
isNew := sessionID != ""
if isNew {
p.sessionMutex.Lock()
p.sessions[sessionID] = sessionEntry{tag: best.tagName(), createdAt: time.Now()}
p.sessions[sessionID] = sessionEntry{
tag: best.tagName(),
selectionScope: selectionScope,
createdAt: time.Now(),
}
p.sessionMutex.Unlock()
}
return best, isNew, nil
}
func (p *balancerProvider) rebalanceCredential(tag string) {
func (p *balancerProvider) rebalanceCredential(tag string, selectionScope credentialSelectionScope) {
key := credentialInterruptKey{tag: tag, selectionScope: selectionScope}
p.interruptAccess.Lock()
if entry, loaded := p.credentialInterrupts[tag]; loaded {
if entry, loaded := p.credentialInterrupts[key]; loaded {
entry.cancel()
}
ctx, cancel := context.WithCancel(context.Background())
p.credentialInterrupts[tag] = credentialInterruptEntry{context: ctx, cancel: cancel}
p.credentialInterrupts[key] = credentialInterruptEntry{context: ctx, cancel: cancel}
p.interruptAccess.Unlock()
p.sessionMutex.Lock()
for id, entry := range p.sessions {
if entry.tag == tag {
if entry.tag == tag && entry.selectionScope == selectionScope {
delete(p.sessions, id)
}
}
p.sessionMutex.Unlock()
}
func (p *balancerProvider) wrapProviderInterrupt(cred credential, requestContext *credentialRequestContext) {
tag := cred.tagName()
func (p *balancerProvider) linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool {
key := credentialInterruptKey{
tag: cred.tagName(),
selectionScope: selection.scopeOrDefault(),
}
p.interruptAccess.Lock()
entry, loaded := p.credentialInterrupts[tag]
entry, loaded := p.credentialInterrupts[key]
if !loaded {
ctx, cancel := context.WithCancel(context.Background())
entry = credentialInterruptEntry{context: ctx, cancel: cancel}
p.credentialInterrupts[tag] = entry
p.credentialInterrupts[key] = entry
}
p.interruptAccess.Unlock()
stop := context.AfterFunc(entry.context, func() {
requestContext.cancelOnce.Do(requestContext.cancelFunc)
})
requestContext.addInterruptLink(stop)
return context.AfterFunc(entry.context, onInterrupt)
}
func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential {
func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential {
cred.markRateLimited(resetAt)
if sessionID != "" {
p.sessionMutex.Lock()
@@ -1052,10 +1093,14 @@ func (p *balancerProvider) onRateLimited(sessionID string, cred credential, rese
p.sessionMutex.Unlock()
}
best := p.pickCredential(filter)
best := p.pickCredential(selection.filter)
if best != nil && sessionID != "" {
p.sessionMutex.Lock()
p.sessions[sessionID] = sessionEntry{tag: best.tagName(), createdAt: time.Now()}
p.sessions[sessionID] = sessionEntry{
tag: best.tagName(),
selectionScope: selection.scopeOrDefault(),
createdAt: time.Now(),
}
p.sessionMutex.Unlock()
}
return best
@@ -1196,9 +1241,9 @@ func newFallbackProvider(credentials []credential, pollInterval time.Duration, l
}
}
func (p *fallbackProvider) selectCredential(_ string, filter func(credential) bool) (credential, bool, error) {
func (p *fallbackProvider) selectCredential(_ string, selection credentialSelection) (credential, bool, error) {
for _, cred := range p.credentials {
if filter != nil && !filter(cred) {
if !selection.allows(cred) {
continue
}
if cred.isUsable() {
@@ -1208,10 +1253,10 @@ func (p *fallbackProvider) selectCredential(_ string, filter func(credential) bo
return nil, false, allCredentialsUnavailableError(p.credentials)
}
func (p *fallbackProvider) onRateLimited(_ string, cred credential, resetAt time.Time, filter func(credential) bool) credential {
func (p *fallbackProvider) onRateLimited(_ string, cred credential, resetAt time.Time, selection credentialSelection) credential {
cred.markRateLimited(resetAt)
for _, candidate := range p.credentials {
if filter != nil && !filter(candidate) {
if !selection.allows(candidate) {
continue
}
if candidate.isUsable() {
@@ -1233,7 +1278,11 @@ func (p *fallbackProvider) allCredentials() []credential {
return p.credentials
}
func (p *fallbackProvider) wrapProviderInterrupt(_ credential, _ *credentialRequestContext) {}
func (p *fallbackProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool {
return func() bool {
return false
}
}
func (p *fallbackProvider) close() {}

View File

@@ -67,7 +67,7 @@ func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, erro
})
}
func hasAlternativeCredential(provider credentialProvider, currentCredential credential, filter func(credential) bool) bool {
func hasAlternativeCredential(provider credentialProvider, currentCredential credential, selection credentialSelection) bool {
if provider == nil || currentCredential == nil {
return false
}
@@ -75,7 +75,7 @@ func hasAlternativeCredential(provider credentialProvider, currentCredential cre
if cred == currentCredential {
continue
}
if filter != nil && !filter(cred) {
if !selection.allows(cred) {
continue
}
if cred.isUsable() {
@@ -109,16 +109,27 @@ func writeCredentialUnavailableError(
r *http.Request,
provider credentialProvider,
currentCredential credential,
filter func(credential) bool,
selection credentialSelection,
fallback string,
) {
if hasAlternativeCredential(provider, currentCredential, filter) {
if hasAlternativeCredential(provider, currentCredential, selection) {
writeRetryableUsageError(w, r)
return
}
writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, fallback))
}
func credentialSelectionForUser(userConfig *option.CCMUser) credentialSelection {
selection := credentialSelection{scope: credentialSelectionScopeAll}
if userConfig != nil && !userConfig.AllowExternalUsage {
selection.scope = credentialSelectionScopeNonExternal
selection.filter = func(cred credential) bool {
return !cred.isExternal()
}
}
return selection
}
func isHopByHopHeader(header string) bool {
switch strings.ToLower(header) {
case "connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade", "host":
@@ -424,12 +435,9 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
var credentialFilter func(credential) bool
if userConfig != nil && !userConfig.AllowExternalUsage {
credentialFilter = func(c credential) bool { return !c.isExternal() }
}
selection := credentialSelectionForUser(userConfig)
selectedCredential, isNew, err := provider.selectCredential(sessionID, credentialFilter)
selectedCredential, isNew, err := provider.selectCredential(sessionID, selection)
if err != nil {
writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, err.Error()))
return
@@ -459,7 +467,12 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
requestContext := selectedCredential.wrapRequestContext(ctx)
provider.wrapProviderInterrupt(selectedCredential, requestContext)
{
currentRequestContext := requestContext
requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() {
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
}))
}
defer func() {
requestContext.cancelRequest()
}()
@@ -476,7 +489,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
if requestContext.Err() != nil {
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "credential became unavailable while processing the request")
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable while processing the request")
return
}
writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
@@ -487,18 +500,23 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Transparent 429 retry
for response.StatusCode == http.StatusTooManyRequests {
resetAt := parseRateLimitResetFromHeaders(response.Header)
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, credentialFilter)
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection)
selectedCredential.updateStateFromHeaders(response.Header)
if bodyBytes == nil || nextCredential == nil {
response.Body.Close()
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited")
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited")
return
}
response.Body.Close()
s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
requestContext.cancelRequest()
requestContext = nextCredential.wrapRequestContext(ctx)
provider.wrapProviderInterrupt(nextCredential, requestContext)
{
currentRequestContext := requestContext
requestContext.addInterruptLink(provider.linkProviderInterrupt(nextCredential, selection, func() {
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
}))
}
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if buildErr != nil {
s.logger.ErrorContext(ctx, "retry request: ", buildErr)
@@ -511,7 +529,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
if requestContext.Err() != nil {
writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request")
writeCredentialUnavailableError(w, r, provider, nextCredential, selection, "credential became unavailable while retrying the request")
return
}
s.logger.ErrorContext(ctx, "retry request: ", retryErr)

View File

@@ -852,22 +852,45 @@ func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *htt
}
type credentialProvider interface {
selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error)
onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential
wrapProviderInterrupt(cred credential, requestContext *credentialRequestContext)
selectCredential(sessionID string, selection credentialSelection) (credential, bool, error)
onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential
linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool
pollIfStale(ctx context.Context)
allCredentials() []credential
close()
}
type credentialSelectionScope string
const (
credentialSelectionScopeAll credentialSelectionScope = "all"
credentialSelectionScopeNonExternal credentialSelectionScope = "non_external"
)
type credentialSelection struct {
scope credentialSelectionScope
filter func(credential) bool
}
func (s credentialSelection) allows(cred credential) bool {
return s.filter == nil || s.filter(cred)
}
func (s credentialSelection) scopeOrDefault() credentialSelectionScope {
if s.scope == "" {
return credentialSelectionScopeAll
}
return s.scope
}
type singleCredentialProvider struct {
cred credential
sessionAccess sync.RWMutex
sessions map[string]time.Time
}
func (p *singleCredentialProvider) selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) {
if filter != nil && !filter(p.cred) {
func (p *singleCredentialProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) {
if !selection.allows(p.cred) {
return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out")
}
if !p.cred.isAvailable() {
@@ -892,7 +915,7 @@ func (p *singleCredentialProvider) selectCredential(sessionID string, filter fun
return p.cred, isNew, nil
}
func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ func(credential) bool) credential {
func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ credentialSelection) credential {
cred.markRateLimited(resetAt)
return nil
}
@@ -916,15 +939,25 @@ func (p *singleCredentialProvider) allCredentials() []credential {
return []credential{p.cred}
}
func (p *singleCredentialProvider) wrapProviderInterrupt(_ credential, _ *credentialRequestContext) {}
func (p *singleCredentialProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool {
return func() bool {
return false
}
}
func (p *singleCredentialProvider) close() {}
const sessionExpiry = 24 * time.Hour
type sessionEntry struct {
tag string
createdAt time.Time
tag string
selectionScope credentialSelectionScope
createdAt time.Time
}
type credentialInterruptKey struct {
tag string
selectionScope credentialSelectionScope
}
type credentialInterruptEntry struct {
@@ -941,7 +974,7 @@ type balancerProvider struct {
sessionMutex sync.RWMutex
sessions map[string]sessionEntry
interruptAccess sync.Mutex
credentialInterrupts map[string]credentialInterruptEntry
credentialInterrupts map[credentialInterruptKey]credentialInterruptEntry
logger log.ContextLogger
}
@@ -959,34 +992,37 @@ func newBalancerProvider(credentials []credential, strategy string, pollInterval
pollInterval: pollInterval,
rebalanceThreshold: rebalanceThreshold,
sessions: make(map[string]sessionEntry),
credentialInterrupts: make(map[string]credentialInterruptEntry),
credentialInterrupts: make(map[credentialInterruptKey]credentialInterruptEntry),
logger: logger,
}
}
func (p *balancerProvider) selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) {
func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) {
selectionScope := selection.scopeOrDefault()
if sessionID != "" {
p.sessionMutex.RLock()
entry, exists := p.sessions[sessionID]
p.sessionMutex.RUnlock()
if exists {
for _, cred := range p.credentials {
if cred.tagName() == entry.tag && compositeCredentialSelectable(cred) && (filter == nil || filter(cred)) && cred.isUsable() {
if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == "least_used") {
better := p.pickLeastUsed(filter)
if better != nil && better.tagName() != cred.tagName() {
effectiveThreshold := p.rebalanceThreshold / cred.planWeight()
delta := cred.weeklyUtilization() - better.weeklyUtilization()
if delta > effectiveThreshold {
p.logger.Info("rebalancing away from ", cred.tagName(),
": utilization delta ", delta, "% exceeds effective threshold ",
effectiveThreshold, "% (weight ", cred.planWeight(), ")")
p.rebalanceCredential(cred.tagName())
break
if entry.selectionScope == selectionScope {
for _, cred := range p.credentials {
if cred.tagName() == entry.tag && compositeCredentialSelectable(cred) && selection.allows(cred) && cred.isUsable() {
if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == "least_used") {
better := p.pickLeastUsed(selection.filter)
if better != nil && better.tagName() != cred.tagName() {
effectiveThreshold := p.rebalanceThreshold / cred.planWeight()
delta := cred.weeklyUtilization() - better.weeklyUtilization()
if delta > effectiveThreshold {
p.logger.Info("rebalancing away from ", cred.tagName(),
": utilization delta ", delta, "% exceeds effective threshold ",
effectiveThreshold, "% (weight ", cred.planWeight(), ")")
p.rebalanceCredential(cred.tagName(), selectionScope)
break
}
}
}
return cred, false, nil
}
return cred, false, nil
}
}
p.sessionMutex.Lock()
@@ -995,7 +1031,7 @@ func (p *balancerProvider) selectCredential(sessionID string, filter func(creden
}
}
best := p.pickCredential(filter)
best := p.pickCredential(selection.filter)
if best == nil {
return nil, false, allRateLimitedError(p.credentials)
}
@@ -1003,47 +1039,52 @@ func (p *balancerProvider) selectCredential(sessionID string, filter func(creden
isNew := sessionID != ""
if isNew {
p.sessionMutex.Lock()
p.sessions[sessionID] = sessionEntry{tag: best.tagName(), createdAt: time.Now()}
p.sessions[sessionID] = sessionEntry{
tag: best.tagName(),
selectionScope: selectionScope,
createdAt: time.Now(),
}
p.sessionMutex.Unlock()
}
return best, isNew, nil
}
func (p *balancerProvider) rebalanceCredential(tag string) {
func (p *balancerProvider) rebalanceCredential(tag string, selectionScope credentialSelectionScope) {
key := credentialInterruptKey{tag: tag, selectionScope: selectionScope}
p.interruptAccess.Lock()
if entry, loaded := p.credentialInterrupts[tag]; loaded {
if entry, loaded := p.credentialInterrupts[key]; loaded {
entry.cancel()
}
ctx, cancel := context.WithCancel(context.Background())
p.credentialInterrupts[tag] = credentialInterruptEntry{context: ctx, cancel: cancel}
p.credentialInterrupts[key] = credentialInterruptEntry{context: ctx, cancel: cancel}
p.interruptAccess.Unlock()
p.sessionMutex.Lock()
for id, entry := range p.sessions {
if entry.tag == tag {
if entry.tag == tag && entry.selectionScope == selectionScope {
delete(p.sessions, id)
}
}
p.sessionMutex.Unlock()
}
func (p *balancerProvider) wrapProviderInterrupt(cred credential, requestContext *credentialRequestContext) {
tag := cred.tagName()
func (p *balancerProvider) linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool {
key := credentialInterruptKey{
tag: cred.tagName(),
selectionScope: selection.scopeOrDefault(),
}
p.interruptAccess.Lock()
entry, loaded := p.credentialInterrupts[tag]
entry, loaded := p.credentialInterrupts[key]
if !loaded {
ctx, cancel := context.WithCancel(context.Background())
entry = credentialInterruptEntry{context: ctx, cancel: cancel}
p.credentialInterrupts[tag] = entry
p.credentialInterrupts[key] = entry
}
p.interruptAccess.Unlock()
stop := context.AfterFunc(entry.context, func() {
requestContext.cancelOnce.Do(requestContext.cancelFunc)
})
requestContext.addInterruptLink(stop)
return context.AfterFunc(entry.context, onInterrupt)
}
func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential {
func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential {
cred.markRateLimited(resetAt)
if sessionID != "" {
p.sessionMutex.Lock()
@@ -1051,10 +1092,14 @@ func (p *balancerProvider) onRateLimited(sessionID string, cred credential, rese
p.sessionMutex.Unlock()
}
best := p.pickCredential(filter)
best := p.pickCredential(selection.filter)
if best != nil && sessionID != "" {
p.sessionMutex.Lock()
p.sessions[sessionID] = sessionEntry{tag: best.tagName(), createdAt: time.Now()}
p.sessions[sessionID] = sessionEntry{
tag: best.tagName(),
selectionScope: selection.scopeOrDefault(),
createdAt: time.Now(),
}
p.sessionMutex.Unlock()
}
return best
@@ -1193,9 +1238,9 @@ func newFallbackProvider(credentials []credential, pollInterval time.Duration, l
}
}
func (p *fallbackProvider) selectCredential(_ string, filter func(credential) bool) (credential, bool, error) {
func (p *fallbackProvider) selectCredential(_ string, selection credentialSelection) (credential, bool, error) {
for _, cred := range p.credentials {
if filter != nil && !filter(cred) {
if !selection.allows(cred) {
continue
}
if !compositeCredentialSelectable(cred) {
@@ -1208,10 +1253,10 @@ func (p *fallbackProvider) selectCredential(_ string, filter func(credential) bo
return nil, false, allRateLimitedError(p.credentials)
}
func (p *fallbackProvider) onRateLimited(_ string, cred credential, resetAt time.Time, filter func(credential) bool) credential {
func (p *fallbackProvider) onRateLimited(_ string, cred credential, resetAt time.Time, selection credentialSelection) credential {
cred.markRateLimited(resetAt)
for _, candidate := range p.credentials {
if filter != nil && !filter(candidate) {
if !selection.allows(candidate) {
continue
}
if !compositeCredentialSelectable(candidate) {
@@ -1236,7 +1281,11 @@ func (p *fallbackProvider) allCredentials() []credential {
return p.credentials
}
func (p *fallbackProvider) wrapProviderInterrupt(_ credential, _ *credentialRequestContext) {}
func (p *fallbackProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool {
return func() bool {
return false
}
}
func (p *fallbackProvider) close() {}

View File

@@ -75,7 +75,7 @@ const (
retryableUsageCode = "credential_usage_exhausted"
)
func hasAlternativeCredential(provider credentialProvider, currentCredential credential, filter func(credential) bool) bool {
func hasAlternativeCredential(provider credentialProvider, currentCredential credential, selection credentialSelection) bool {
if provider == nil || currentCredential == nil {
return false
}
@@ -83,7 +83,7 @@ func hasAlternativeCredential(provider credentialProvider, currentCredential cre
if cred == currentCredential {
continue
}
if filter != nil && !filter(cred) {
if !selection.allows(cred) {
continue
}
if cred.isUsable() {
@@ -117,16 +117,27 @@ func writeCredentialUnavailableError(
r *http.Request,
provider credentialProvider,
currentCredential credential,
filter func(credential) bool,
selection credentialSelection,
fallback string,
) {
if hasAlternativeCredential(provider, currentCredential, filter) {
if hasAlternativeCredential(provider, currentCredential, selection) {
writeRetryableUsageError(w, r)
return
}
writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, fallback))
}
func credentialSelectionForUser(userConfig *option.OCMUser) credentialSelection {
selection := credentialSelection{scope: credentialSelectionScopeAll}
if userConfig != nil && !userConfig.AllowExternalUsage {
selection.scope = credentialSelectionScopeNonExternal
selection.filter = func(cred credential) bool {
return !cred.isExternal()
}
}
return selection
}
func isHopByHopHeader(header string) bool {
switch strings.ToLower(header) {
case "connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade", "host":
@@ -426,19 +437,16 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
provider.pollIfStale(s.ctx)
var credentialFilter func(credential) bool
if userConfig != nil && !userConfig.AllowExternalUsage {
credentialFilter = func(c credential) bool { return !c.isExternal() }
}
selection := credentialSelectionForUser(userConfig)
selectedCredential, isNew, err := provider.selectCredential(sessionID, credentialFilter)
selectedCredential, isNew, err := provider.selectCredential(sessionID, selection)
if err != nil {
writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, err.Error()))
return
}
if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") {
s.handleWebSocket(ctx, w, r, path, username, sessionID, userConfig, provider, selectedCredential, credentialFilter, isNew)
s.handleWebSocket(ctx, w, r, path, username, sessionID, userConfig, provider, selectedCredential, selection, isNew)
return
}
@@ -500,7 +508,12 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
requestContext := selectedCredential.wrapRequestContext(ctx)
provider.wrapProviderInterrupt(selectedCredential, requestContext)
{
currentRequestContext := requestContext
requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() {
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
}))
}
defer func() {
requestContext.cancelRequest()
}()
@@ -517,7 +530,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
if requestContext.Err() != nil {
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "credential became unavailable while processing the request")
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable while processing the request")
return
}
writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
@@ -528,19 +541,24 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Transparent 429 retry
for response.StatusCode == http.StatusTooManyRequests {
resetAt := parseOCMRateLimitResetFromHeaders(response.Header)
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, credentialFilter)
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection)
needsBodyReplay := r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodDelete
selectedCredential.updateStateFromHeaders(response.Header)
if (needsBodyReplay && bodyBytes == nil) || nextCredential == nil {
response.Body.Close()
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited")
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited")
return
}
response.Body.Close()
s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
requestContext.cancelRequest()
requestContext = nextCredential.wrapRequestContext(ctx)
provider.wrapProviderInterrupt(nextCredential, requestContext)
{
currentRequestContext := requestContext
requestContext.addInterruptLink(provider.linkProviderInterrupt(nextCredential, selection, func() {
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
}))
}
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if buildErr != nil {
s.logger.ErrorContext(ctx, "retry request: ", buildErr)
@@ -553,7 +571,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
if requestContext.Err() != nil {
writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request")
writeCredentialUnavailableError(w, r, provider, nextCredential, selection, "credential became unavailable while retrying the request")
return
}
s.logger.ErrorContext(ctx, "retry request: ", retryErr)

View File

@@ -26,16 +26,24 @@ import (
)
type webSocketSession struct {
clientConn net.Conn
upstreamConn net.Conn
credentialTag string
closeOnce sync.Once
clientConn net.Conn
upstreamConn net.Conn
credentialTag string
releaseProviderInterrupt func()
closeOnce sync.Once
}
func (s *webSocketSession) Close() {
s.closeOnce.Do(func() {
s.clientConn.Close()
s.upstreamConn.Close()
if s.releaseProviderInterrupt != nil {
s.releaseProviderInterrupt()
}
if s.clientConn != nil {
s.clientConn.Close()
}
if s.upstreamConn != nil {
s.upstreamConn.Close()
}
})
}
@@ -91,12 +99,14 @@ func (s *Service) handleWebSocket(
userConfig *option.OCMUser,
provider credentialProvider,
selectedCredential credential,
credentialFilter func(credential) bool,
selection credentialSelection,
isNew bool,
) {
var (
err error
requestContext *credentialRequestContext
clientConn net.Conn
session *webSocketSession
upstreamConn net.Conn
upstreamBufferedReader *bufio.Reader
upstreamResponseHeaders http.Header
@@ -186,20 +196,36 @@ func (s *Service) handleWebSocket(
}
requestContext = selectedCredential.wrapRequestContext(ctx)
provider.wrapProviderInterrupt(selectedCredential, requestContext)
{
currentRequestContext := requestContext
requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() {
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
if session != nil {
session.Close()
return
}
if clientConn != nil {
clientConn.Close()
}
if upstreamConn != nil {
upstreamConn.Close()
}
}))
}
upstreamConn, upstreamBufferedReader, _, err = upstreamDialer.Dial(requestContext, upstreamURL)
if err == nil {
requestContext.releaseCredentialInterrupt()
break
}
requestContext.cancelRequest()
requestContext = nil
upstreamConn = nil
clientConn = nil
if statusCode == http.StatusTooManyRequests {
resetAt := parseOCMRateLimitResetFromHeaders(upstreamResponseHeaders)
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, credentialFilter)
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection)
selectedCredential.updateStateFromHeaders(upstreamResponseHeaders)
if nextCredential == nil {
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited")
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited")
return
}
s.logger.InfoContext(ctx, "retrying websocket with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
@@ -236,16 +262,17 @@ func (s *Service) handleWebSocket(
writeJSONError(w, r, http.StatusServiceUnavailable, "api_error", "service is shutting down")
return
}
clientConn, _, _, err := clientUpgrader.Upgrade(r, w)
clientConn, _, _, err = clientUpgrader.Upgrade(r, w)
if err != nil {
s.logger.ErrorContext(ctx, "upgrade client websocket: ", err)
upstreamConn.Close()
return
}
session := &webSocketSession{
clientConn: clientConn,
upstreamConn: upstreamConn,
credentialTag: selectedCredential.tagName(),
session = &webSocketSession{
clientConn: clientConn,
upstreamConn: upstreamConn,
credentialTag: selectedCredential.tagName(),
releaseProviderInterrupt: requestContext.releaseCredentialInterrupt,
}
if !s.registerWebSocketSession(session) {
session.Close()