mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-13 20:28:32 +10:00
Fix scoped rebalance interrupts
This commit is contained in:
@@ -855,14 +855,37 @@ func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *htt
|
||||
|
||||
// credentialProvider is the interface for all credential types.
|
||||
type credentialProvider interface {
|
||||
selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error)
|
||||
onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential
|
||||
wrapProviderInterrupt(cred credential, requestContext *credentialRequestContext)
|
||||
selectCredential(sessionID string, selection credentialSelection) (credential, bool, error)
|
||||
onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential
|
||||
linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool
|
||||
pollIfStale(ctx context.Context)
|
||||
allCredentials() []credential
|
||||
close()
|
||||
}
|
||||
|
||||
type credentialSelectionScope string
|
||||
|
||||
const (
|
||||
credentialSelectionScopeAll credentialSelectionScope = "all"
|
||||
credentialSelectionScopeNonExternal credentialSelectionScope = "non_external"
|
||||
)
|
||||
|
||||
type credentialSelection struct {
|
||||
scope credentialSelectionScope
|
||||
filter func(credential) bool
|
||||
}
|
||||
|
||||
func (s credentialSelection) allows(cred credential) bool {
|
||||
return s.filter == nil || s.filter(cred)
|
||||
}
|
||||
|
||||
func (s credentialSelection) scopeOrDefault() credentialSelectionScope {
|
||||
if s.scope == "" {
|
||||
return credentialSelectionScopeAll
|
||||
}
|
||||
return s.scope
|
||||
}
|
||||
|
||||
// singleCredentialProvider wraps a single credential (legacy or single default).
|
||||
type singleCredentialProvider struct {
|
||||
cred credential
|
||||
@@ -870,8 +893,8 @@ type singleCredentialProvider struct {
|
||||
sessions map[string]time.Time
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) {
|
||||
if filter != nil && !filter(p.cred) {
|
||||
func (p *singleCredentialProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) {
|
||||
if !selection.allows(p.cred) {
|
||||
return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out")
|
||||
}
|
||||
if !p.cred.isAvailable() {
|
||||
@@ -896,7 +919,7 @@ func (p *singleCredentialProvider) selectCredential(sessionID string, filter fun
|
||||
return p.cred, isNew, nil
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ func(credential) bool) credential {
|
||||
func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ credentialSelection) credential {
|
||||
cred.markRateLimited(resetAt)
|
||||
return nil
|
||||
}
|
||||
@@ -920,15 +943,25 @@ func (p *singleCredentialProvider) allCredentials() []credential {
|
||||
return []credential{p.cred}
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) wrapProviderInterrupt(_ credential, _ *credentialRequestContext) {}
|
||||
func (p *singleCredentialProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool {
|
||||
return func() bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) close() {}
|
||||
|
||||
const sessionExpiry = 24 * time.Hour
|
||||
|
||||
type sessionEntry struct {
|
||||
tag string
|
||||
createdAt time.Time
|
||||
tag string
|
||||
selectionScope credentialSelectionScope
|
||||
createdAt time.Time
|
||||
}
|
||||
|
||||
type credentialInterruptKey struct {
|
||||
tag string
|
||||
selectionScope credentialSelectionScope
|
||||
}
|
||||
|
||||
type credentialInterruptEntry struct {
|
||||
@@ -946,7 +979,7 @@ type balancerProvider struct {
|
||||
sessionMutex sync.RWMutex
|
||||
sessions map[string]sessionEntry
|
||||
interruptAccess sync.Mutex
|
||||
credentialInterrupts map[string]credentialInterruptEntry
|
||||
credentialInterrupts map[credentialInterruptKey]credentialInterruptEntry
|
||||
logger log.ContextLogger
|
||||
}
|
||||
|
||||
@@ -960,34 +993,37 @@ func newBalancerProvider(credentials []credential, strategy string, pollInterval
|
||||
pollInterval: pollInterval,
|
||||
rebalanceThreshold: rebalanceThreshold,
|
||||
sessions: make(map[string]sessionEntry),
|
||||
credentialInterrupts: make(map[string]credentialInterruptEntry),
|
||||
credentialInterrupts: make(map[credentialInterruptKey]credentialInterruptEntry),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *balancerProvider) selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) {
|
||||
func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) {
|
||||
selectionScope := selection.scopeOrDefault()
|
||||
if sessionID != "" {
|
||||
p.sessionMutex.RLock()
|
||||
entry, exists := p.sessions[sessionID]
|
||||
p.sessionMutex.RUnlock()
|
||||
if exists {
|
||||
for _, cred := range p.credentials {
|
||||
if cred.tagName() == entry.tag && (filter == nil || filter(cred)) && cred.isUsable() {
|
||||
if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == "least_used") {
|
||||
better := p.pickLeastUsed(filter)
|
||||
if better != nil && better.tagName() != cred.tagName() {
|
||||
effectiveThreshold := p.rebalanceThreshold / cred.planWeight()
|
||||
delta := cred.weeklyUtilization() - better.weeklyUtilization()
|
||||
if delta > effectiveThreshold {
|
||||
p.logger.Info("rebalancing away from ", cred.tagName(),
|
||||
": utilization delta ", delta, "% exceeds effective threshold ",
|
||||
effectiveThreshold, "% (weight ", cred.planWeight(), ")")
|
||||
p.rebalanceCredential(cred.tagName())
|
||||
break
|
||||
if entry.selectionScope == selectionScope {
|
||||
for _, cred := range p.credentials {
|
||||
if cred.tagName() == entry.tag && selection.allows(cred) && cred.isUsable() {
|
||||
if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == "least_used") {
|
||||
better := p.pickLeastUsed(selection.filter)
|
||||
if better != nil && better.tagName() != cred.tagName() {
|
||||
effectiveThreshold := p.rebalanceThreshold / cred.planWeight()
|
||||
delta := cred.weeklyUtilization() - better.weeklyUtilization()
|
||||
if delta > effectiveThreshold {
|
||||
p.logger.Info("rebalancing away from ", cred.tagName(),
|
||||
": utilization delta ", delta, "% exceeds effective threshold ",
|
||||
effectiveThreshold, "% (weight ", cred.planWeight(), ")")
|
||||
p.rebalanceCredential(cred.tagName(), selectionScope)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return cred, false, nil
|
||||
}
|
||||
return cred, false, nil
|
||||
}
|
||||
}
|
||||
p.sessionMutex.Lock()
|
||||
@@ -996,7 +1032,7 @@ func (p *balancerProvider) selectCredential(sessionID string, filter func(creden
|
||||
}
|
||||
}
|
||||
|
||||
best := p.pickCredential(filter)
|
||||
best := p.pickCredential(selection.filter)
|
||||
if best == nil {
|
||||
return nil, false, allCredentialsUnavailableError(p.credentials)
|
||||
}
|
||||
@@ -1004,47 +1040,52 @@ func (p *balancerProvider) selectCredential(sessionID string, filter func(creden
|
||||
isNew := sessionID != ""
|
||||
if isNew {
|
||||
p.sessionMutex.Lock()
|
||||
p.sessions[sessionID] = sessionEntry{tag: best.tagName(), createdAt: time.Now()}
|
||||
p.sessions[sessionID] = sessionEntry{
|
||||
tag: best.tagName(),
|
||||
selectionScope: selectionScope,
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
p.sessionMutex.Unlock()
|
||||
}
|
||||
return best, isNew, nil
|
||||
}
|
||||
|
||||
func (p *balancerProvider) rebalanceCredential(tag string) {
|
||||
func (p *balancerProvider) rebalanceCredential(tag string, selectionScope credentialSelectionScope) {
|
||||
key := credentialInterruptKey{tag: tag, selectionScope: selectionScope}
|
||||
p.interruptAccess.Lock()
|
||||
if entry, loaded := p.credentialInterrupts[tag]; loaded {
|
||||
if entry, loaded := p.credentialInterrupts[key]; loaded {
|
||||
entry.cancel()
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
p.credentialInterrupts[tag] = credentialInterruptEntry{context: ctx, cancel: cancel}
|
||||
p.credentialInterrupts[key] = credentialInterruptEntry{context: ctx, cancel: cancel}
|
||||
p.interruptAccess.Unlock()
|
||||
|
||||
p.sessionMutex.Lock()
|
||||
for id, entry := range p.sessions {
|
||||
if entry.tag == tag {
|
||||
if entry.tag == tag && entry.selectionScope == selectionScope {
|
||||
delete(p.sessions, id)
|
||||
}
|
||||
}
|
||||
p.sessionMutex.Unlock()
|
||||
}
|
||||
|
||||
func (p *balancerProvider) wrapProviderInterrupt(cred credential, requestContext *credentialRequestContext) {
|
||||
tag := cred.tagName()
|
||||
func (p *balancerProvider) linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool {
|
||||
key := credentialInterruptKey{
|
||||
tag: cred.tagName(),
|
||||
selectionScope: selection.scopeOrDefault(),
|
||||
}
|
||||
p.interruptAccess.Lock()
|
||||
entry, loaded := p.credentialInterrupts[tag]
|
||||
entry, loaded := p.credentialInterrupts[key]
|
||||
if !loaded {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
entry = credentialInterruptEntry{context: ctx, cancel: cancel}
|
||||
p.credentialInterrupts[tag] = entry
|
||||
p.credentialInterrupts[key] = entry
|
||||
}
|
||||
p.interruptAccess.Unlock()
|
||||
stop := context.AfterFunc(entry.context, func() {
|
||||
requestContext.cancelOnce.Do(requestContext.cancelFunc)
|
||||
})
|
||||
requestContext.addInterruptLink(stop)
|
||||
return context.AfterFunc(entry.context, onInterrupt)
|
||||
}
|
||||
|
||||
func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential {
|
||||
func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential {
|
||||
cred.markRateLimited(resetAt)
|
||||
if sessionID != "" {
|
||||
p.sessionMutex.Lock()
|
||||
@@ -1052,10 +1093,14 @@ func (p *balancerProvider) onRateLimited(sessionID string, cred credential, rese
|
||||
p.sessionMutex.Unlock()
|
||||
}
|
||||
|
||||
best := p.pickCredential(filter)
|
||||
best := p.pickCredential(selection.filter)
|
||||
if best != nil && sessionID != "" {
|
||||
p.sessionMutex.Lock()
|
||||
p.sessions[sessionID] = sessionEntry{tag: best.tagName(), createdAt: time.Now()}
|
||||
p.sessions[sessionID] = sessionEntry{
|
||||
tag: best.tagName(),
|
||||
selectionScope: selection.scopeOrDefault(),
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
p.sessionMutex.Unlock()
|
||||
}
|
||||
return best
|
||||
@@ -1196,9 +1241,9 @@ func newFallbackProvider(credentials []credential, pollInterval time.Duration, l
|
||||
}
|
||||
}
|
||||
|
||||
func (p *fallbackProvider) selectCredential(_ string, filter func(credential) bool) (credential, bool, error) {
|
||||
func (p *fallbackProvider) selectCredential(_ string, selection credentialSelection) (credential, bool, error) {
|
||||
for _, cred := range p.credentials {
|
||||
if filter != nil && !filter(cred) {
|
||||
if !selection.allows(cred) {
|
||||
continue
|
||||
}
|
||||
if cred.isUsable() {
|
||||
@@ -1208,10 +1253,10 @@ func (p *fallbackProvider) selectCredential(_ string, filter func(credential) bo
|
||||
return nil, false, allCredentialsUnavailableError(p.credentials)
|
||||
}
|
||||
|
||||
func (p *fallbackProvider) onRateLimited(_ string, cred credential, resetAt time.Time, filter func(credential) bool) credential {
|
||||
func (p *fallbackProvider) onRateLimited(_ string, cred credential, resetAt time.Time, selection credentialSelection) credential {
|
||||
cred.markRateLimited(resetAt)
|
||||
for _, candidate := range p.credentials {
|
||||
if filter != nil && !filter(candidate) {
|
||||
if !selection.allows(candidate) {
|
||||
continue
|
||||
}
|
||||
if candidate.isUsable() {
|
||||
@@ -1233,7 +1278,11 @@ func (p *fallbackProvider) allCredentials() []credential {
|
||||
return p.credentials
|
||||
}
|
||||
|
||||
func (p *fallbackProvider) wrapProviderInterrupt(_ credential, _ *credentialRequestContext) {}
|
||||
func (p *fallbackProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool {
|
||||
return func() bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (p *fallbackProvider) close() {}
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, erro
|
||||
})
|
||||
}
|
||||
|
||||
func hasAlternativeCredential(provider credentialProvider, currentCredential credential, filter func(credential) bool) bool {
|
||||
func hasAlternativeCredential(provider credentialProvider, currentCredential credential, selection credentialSelection) bool {
|
||||
if provider == nil || currentCredential == nil {
|
||||
return false
|
||||
}
|
||||
@@ -75,7 +75,7 @@ func hasAlternativeCredential(provider credentialProvider, currentCredential cre
|
||||
if cred == currentCredential {
|
||||
continue
|
||||
}
|
||||
if filter != nil && !filter(cred) {
|
||||
if !selection.allows(cred) {
|
||||
continue
|
||||
}
|
||||
if cred.isUsable() {
|
||||
@@ -109,16 +109,27 @@ func writeCredentialUnavailableError(
|
||||
r *http.Request,
|
||||
provider credentialProvider,
|
||||
currentCredential credential,
|
||||
filter func(credential) bool,
|
||||
selection credentialSelection,
|
||||
fallback string,
|
||||
) {
|
||||
if hasAlternativeCredential(provider, currentCredential, filter) {
|
||||
if hasAlternativeCredential(provider, currentCredential, selection) {
|
||||
writeRetryableUsageError(w, r)
|
||||
return
|
||||
}
|
||||
writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, fallback))
|
||||
}
|
||||
|
||||
func credentialSelectionForUser(userConfig *option.CCMUser) credentialSelection {
|
||||
selection := credentialSelection{scope: credentialSelectionScopeAll}
|
||||
if userConfig != nil && !userConfig.AllowExternalUsage {
|
||||
selection.scope = credentialSelectionScopeNonExternal
|
||||
selection.filter = func(cred credential) bool {
|
||||
return !cred.isExternal()
|
||||
}
|
||||
}
|
||||
return selection
|
||||
}
|
||||
|
||||
func isHopByHopHeader(header string) bool {
|
||||
switch strings.ToLower(header) {
|
||||
case "connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade", "host":
|
||||
@@ -424,12 +435,9 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
var credentialFilter func(credential) bool
|
||||
if userConfig != nil && !userConfig.AllowExternalUsage {
|
||||
credentialFilter = func(c credential) bool { return !c.isExternal() }
|
||||
}
|
||||
selection := credentialSelectionForUser(userConfig)
|
||||
|
||||
selectedCredential, isNew, err := provider.selectCredential(sessionID, credentialFilter)
|
||||
selectedCredential, isNew, err := provider.selectCredential(sessionID, selection)
|
||||
if err != nil {
|
||||
writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, err.Error()))
|
||||
return
|
||||
@@ -459,7 +467,12 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
requestContext := selectedCredential.wrapRequestContext(ctx)
|
||||
provider.wrapProviderInterrupt(selectedCredential, requestContext)
|
||||
{
|
||||
currentRequestContext := requestContext
|
||||
requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() {
|
||||
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
|
||||
}))
|
||||
}
|
||||
defer func() {
|
||||
requestContext.cancelRequest()
|
||||
}()
|
||||
@@ -476,7 +489,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
if requestContext.Err() != nil {
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "credential became unavailable while processing the request")
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable while processing the request")
|
||||
return
|
||||
}
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
|
||||
@@ -487,18 +500,23 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Transparent 429 retry
|
||||
for response.StatusCode == http.StatusTooManyRequests {
|
||||
resetAt := parseRateLimitResetFromHeaders(response.Header)
|
||||
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, credentialFilter)
|
||||
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection)
|
||||
selectedCredential.updateStateFromHeaders(response.Header)
|
||||
if bodyBytes == nil || nextCredential == nil {
|
||||
response.Body.Close()
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited")
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited")
|
||||
return
|
||||
}
|
||||
response.Body.Close()
|
||||
s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
|
||||
requestContext.cancelRequest()
|
||||
requestContext = nextCredential.wrapRequestContext(ctx)
|
||||
provider.wrapProviderInterrupt(nextCredential, requestContext)
|
||||
{
|
||||
currentRequestContext := requestContext
|
||||
requestContext.addInterruptLink(provider.linkProviderInterrupt(nextCredential, selection, func() {
|
||||
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
|
||||
}))
|
||||
}
|
||||
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||
if buildErr != nil {
|
||||
s.logger.ErrorContext(ctx, "retry request: ", buildErr)
|
||||
@@ -511,7 +529,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
if requestContext.Err() != nil {
|
||||
writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request")
|
||||
writeCredentialUnavailableError(w, r, provider, nextCredential, selection, "credential became unavailable while retrying the request")
|
||||
return
|
||||
}
|
||||
s.logger.ErrorContext(ctx, "retry request: ", retryErr)
|
||||
|
||||
@@ -852,22 +852,45 @@ func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *htt
|
||||
}
|
||||
|
||||
type credentialProvider interface {
|
||||
selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error)
|
||||
onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential
|
||||
wrapProviderInterrupt(cred credential, requestContext *credentialRequestContext)
|
||||
selectCredential(sessionID string, selection credentialSelection) (credential, bool, error)
|
||||
onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential
|
||||
linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool
|
||||
pollIfStale(ctx context.Context)
|
||||
allCredentials() []credential
|
||||
close()
|
||||
}
|
||||
|
||||
type credentialSelectionScope string
|
||||
|
||||
const (
|
||||
credentialSelectionScopeAll credentialSelectionScope = "all"
|
||||
credentialSelectionScopeNonExternal credentialSelectionScope = "non_external"
|
||||
)
|
||||
|
||||
type credentialSelection struct {
|
||||
scope credentialSelectionScope
|
||||
filter func(credential) bool
|
||||
}
|
||||
|
||||
func (s credentialSelection) allows(cred credential) bool {
|
||||
return s.filter == nil || s.filter(cred)
|
||||
}
|
||||
|
||||
func (s credentialSelection) scopeOrDefault() credentialSelectionScope {
|
||||
if s.scope == "" {
|
||||
return credentialSelectionScopeAll
|
||||
}
|
||||
return s.scope
|
||||
}
|
||||
|
||||
type singleCredentialProvider struct {
|
||||
cred credential
|
||||
sessionAccess sync.RWMutex
|
||||
sessions map[string]time.Time
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) {
|
||||
if filter != nil && !filter(p.cred) {
|
||||
func (p *singleCredentialProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) {
|
||||
if !selection.allows(p.cred) {
|
||||
return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out")
|
||||
}
|
||||
if !p.cred.isAvailable() {
|
||||
@@ -892,7 +915,7 @@ func (p *singleCredentialProvider) selectCredential(sessionID string, filter fun
|
||||
return p.cred, isNew, nil
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ func(credential) bool) credential {
|
||||
func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ credentialSelection) credential {
|
||||
cred.markRateLimited(resetAt)
|
||||
return nil
|
||||
}
|
||||
@@ -916,15 +939,25 @@ func (p *singleCredentialProvider) allCredentials() []credential {
|
||||
return []credential{p.cred}
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) wrapProviderInterrupt(_ credential, _ *credentialRequestContext) {}
|
||||
func (p *singleCredentialProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool {
|
||||
return func() bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) close() {}
|
||||
|
||||
const sessionExpiry = 24 * time.Hour
|
||||
|
||||
type sessionEntry struct {
|
||||
tag string
|
||||
createdAt time.Time
|
||||
tag string
|
||||
selectionScope credentialSelectionScope
|
||||
createdAt time.Time
|
||||
}
|
||||
|
||||
type credentialInterruptKey struct {
|
||||
tag string
|
||||
selectionScope credentialSelectionScope
|
||||
}
|
||||
|
||||
type credentialInterruptEntry struct {
|
||||
@@ -941,7 +974,7 @@ type balancerProvider struct {
|
||||
sessionMutex sync.RWMutex
|
||||
sessions map[string]sessionEntry
|
||||
interruptAccess sync.Mutex
|
||||
credentialInterrupts map[string]credentialInterruptEntry
|
||||
credentialInterrupts map[credentialInterruptKey]credentialInterruptEntry
|
||||
logger log.ContextLogger
|
||||
}
|
||||
|
||||
@@ -959,34 +992,37 @@ func newBalancerProvider(credentials []credential, strategy string, pollInterval
|
||||
pollInterval: pollInterval,
|
||||
rebalanceThreshold: rebalanceThreshold,
|
||||
sessions: make(map[string]sessionEntry),
|
||||
credentialInterrupts: make(map[string]credentialInterruptEntry),
|
||||
credentialInterrupts: make(map[credentialInterruptKey]credentialInterruptEntry),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *balancerProvider) selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) {
|
||||
func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) {
|
||||
selectionScope := selection.scopeOrDefault()
|
||||
if sessionID != "" {
|
||||
p.sessionMutex.RLock()
|
||||
entry, exists := p.sessions[sessionID]
|
||||
p.sessionMutex.RUnlock()
|
||||
if exists {
|
||||
for _, cred := range p.credentials {
|
||||
if cred.tagName() == entry.tag && compositeCredentialSelectable(cred) && (filter == nil || filter(cred)) && cred.isUsable() {
|
||||
if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == "least_used") {
|
||||
better := p.pickLeastUsed(filter)
|
||||
if better != nil && better.tagName() != cred.tagName() {
|
||||
effectiveThreshold := p.rebalanceThreshold / cred.planWeight()
|
||||
delta := cred.weeklyUtilization() - better.weeklyUtilization()
|
||||
if delta > effectiveThreshold {
|
||||
p.logger.Info("rebalancing away from ", cred.tagName(),
|
||||
": utilization delta ", delta, "% exceeds effective threshold ",
|
||||
effectiveThreshold, "% (weight ", cred.planWeight(), ")")
|
||||
p.rebalanceCredential(cred.tagName())
|
||||
break
|
||||
if entry.selectionScope == selectionScope {
|
||||
for _, cred := range p.credentials {
|
||||
if cred.tagName() == entry.tag && compositeCredentialSelectable(cred) && selection.allows(cred) && cred.isUsable() {
|
||||
if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == "least_used") {
|
||||
better := p.pickLeastUsed(selection.filter)
|
||||
if better != nil && better.tagName() != cred.tagName() {
|
||||
effectiveThreshold := p.rebalanceThreshold / cred.planWeight()
|
||||
delta := cred.weeklyUtilization() - better.weeklyUtilization()
|
||||
if delta > effectiveThreshold {
|
||||
p.logger.Info("rebalancing away from ", cred.tagName(),
|
||||
": utilization delta ", delta, "% exceeds effective threshold ",
|
||||
effectiveThreshold, "% (weight ", cred.planWeight(), ")")
|
||||
p.rebalanceCredential(cred.tagName(), selectionScope)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return cred, false, nil
|
||||
}
|
||||
return cred, false, nil
|
||||
}
|
||||
}
|
||||
p.sessionMutex.Lock()
|
||||
@@ -995,7 +1031,7 @@ func (p *balancerProvider) selectCredential(sessionID string, filter func(creden
|
||||
}
|
||||
}
|
||||
|
||||
best := p.pickCredential(filter)
|
||||
best := p.pickCredential(selection.filter)
|
||||
if best == nil {
|
||||
return nil, false, allRateLimitedError(p.credentials)
|
||||
}
|
||||
@@ -1003,47 +1039,52 @@ func (p *balancerProvider) selectCredential(sessionID string, filter func(creden
|
||||
isNew := sessionID != ""
|
||||
if isNew {
|
||||
p.sessionMutex.Lock()
|
||||
p.sessions[sessionID] = sessionEntry{tag: best.tagName(), createdAt: time.Now()}
|
||||
p.sessions[sessionID] = sessionEntry{
|
||||
tag: best.tagName(),
|
||||
selectionScope: selectionScope,
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
p.sessionMutex.Unlock()
|
||||
}
|
||||
return best, isNew, nil
|
||||
}
|
||||
|
||||
func (p *balancerProvider) rebalanceCredential(tag string) {
|
||||
func (p *balancerProvider) rebalanceCredential(tag string, selectionScope credentialSelectionScope) {
|
||||
key := credentialInterruptKey{tag: tag, selectionScope: selectionScope}
|
||||
p.interruptAccess.Lock()
|
||||
if entry, loaded := p.credentialInterrupts[tag]; loaded {
|
||||
if entry, loaded := p.credentialInterrupts[key]; loaded {
|
||||
entry.cancel()
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
p.credentialInterrupts[tag] = credentialInterruptEntry{context: ctx, cancel: cancel}
|
||||
p.credentialInterrupts[key] = credentialInterruptEntry{context: ctx, cancel: cancel}
|
||||
p.interruptAccess.Unlock()
|
||||
|
||||
p.sessionMutex.Lock()
|
||||
for id, entry := range p.sessions {
|
||||
if entry.tag == tag {
|
||||
if entry.tag == tag && entry.selectionScope == selectionScope {
|
||||
delete(p.sessions, id)
|
||||
}
|
||||
}
|
||||
p.sessionMutex.Unlock()
|
||||
}
|
||||
|
||||
func (p *balancerProvider) wrapProviderInterrupt(cred credential, requestContext *credentialRequestContext) {
|
||||
tag := cred.tagName()
|
||||
func (p *balancerProvider) linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool {
|
||||
key := credentialInterruptKey{
|
||||
tag: cred.tagName(),
|
||||
selectionScope: selection.scopeOrDefault(),
|
||||
}
|
||||
p.interruptAccess.Lock()
|
||||
entry, loaded := p.credentialInterrupts[tag]
|
||||
entry, loaded := p.credentialInterrupts[key]
|
||||
if !loaded {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
entry = credentialInterruptEntry{context: ctx, cancel: cancel}
|
||||
p.credentialInterrupts[tag] = entry
|
||||
p.credentialInterrupts[key] = entry
|
||||
}
|
||||
p.interruptAccess.Unlock()
|
||||
stop := context.AfterFunc(entry.context, func() {
|
||||
requestContext.cancelOnce.Do(requestContext.cancelFunc)
|
||||
})
|
||||
requestContext.addInterruptLink(stop)
|
||||
return context.AfterFunc(entry.context, onInterrupt)
|
||||
}
|
||||
|
||||
func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential {
|
||||
func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential {
|
||||
cred.markRateLimited(resetAt)
|
||||
if sessionID != "" {
|
||||
p.sessionMutex.Lock()
|
||||
@@ -1051,10 +1092,14 @@ func (p *balancerProvider) onRateLimited(sessionID string, cred credential, rese
|
||||
p.sessionMutex.Unlock()
|
||||
}
|
||||
|
||||
best := p.pickCredential(filter)
|
||||
best := p.pickCredential(selection.filter)
|
||||
if best != nil && sessionID != "" {
|
||||
p.sessionMutex.Lock()
|
||||
p.sessions[sessionID] = sessionEntry{tag: best.tagName(), createdAt: time.Now()}
|
||||
p.sessions[sessionID] = sessionEntry{
|
||||
tag: best.tagName(),
|
||||
selectionScope: selection.scopeOrDefault(),
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
p.sessionMutex.Unlock()
|
||||
}
|
||||
return best
|
||||
@@ -1193,9 +1238,9 @@ func newFallbackProvider(credentials []credential, pollInterval time.Duration, l
|
||||
}
|
||||
}
|
||||
|
||||
func (p *fallbackProvider) selectCredential(_ string, filter func(credential) bool) (credential, bool, error) {
|
||||
func (p *fallbackProvider) selectCredential(_ string, selection credentialSelection) (credential, bool, error) {
|
||||
for _, cred := range p.credentials {
|
||||
if filter != nil && !filter(cred) {
|
||||
if !selection.allows(cred) {
|
||||
continue
|
||||
}
|
||||
if !compositeCredentialSelectable(cred) {
|
||||
@@ -1208,10 +1253,10 @@ func (p *fallbackProvider) selectCredential(_ string, filter func(credential) bo
|
||||
return nil, false, allRateLimitedError(p.credentials)
|
||||
}
|
||||
|
||||
func (p *fallbackProvider) onRateLimited(_ string, cred credential, resetAt time.Time, filter func(credential) bool) credential {
|
||||
func (p *fallbackProvider) onRateLimited(_ string, cred credential, resetAt time.Time, selection credentialSelection) credential {
|
||||
cred.markRateLimited(resetAt)
|
||||
for _, candidate := range p.credentials {
|
||||
if filter != nil && !filter(candidate) {
|
||||
if !selection.allows(candidate) {
|
||||
continue
|
||||
}
|
||||
if !compositeCredentialSelectable(candidate) {
|
||||
@@ -1236,7 +1281,11 @@ func (p *fallbackProvider) allCredentials() []credential {
|
||||
return p.credentials
|
||||
}
|
||||
|
||||
func (p *fallbackProvider) wrapProviderInterrupt(_ credential, _ *credentialRequestContext) {}
|
||||
func (p *fallbackProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool {
|
||||
return func() bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (p *fallbackProvider) close() {}
|
||||
|
||||
|
||||
@@ -75,7 +75,7 @@ const (
|
||||
retryableUsageCode = "credential_usage_exhausted"
|
||||
)
|
||||
|
||||
func hasAlternativeCredential(provider credentialProvider, currentCredential credential, filter func(credential) bool) bool {
|
||||
func hasAlternativeCredential(provider credentialProvider, currentCredential credential, selection credentialSelection) bool {
|
||||
if provider == nil || currentCredential == nil {
|
||||
return false
|
||||
}
|
||||
@@ -83,7 +83,7 @@ func hasAlternativeCredential(provider credentialProvider, currentCredential cre
|
||||
if cred == currentCredential {
|
||||
continue
|
||||
}
|
||||
if filter != nil && !filter(cred) {
|
||||
if !selection.allows(cred) {
|
||||
continue
|
||||
}
|
||||
if cred.isUsable() {
|
||||
@@ -117,16 +117,27 @@ func writeCredentialUnavailableError(
|
||||
r *http.Request,
|
||||
provider credentialProvider,
|
||||
currentCredential credential,
|
||||
filter func(credential) bool,
|
||||
selection credentialSelection,
|
||||
fallback string,
|
||||
) {
|
||||
if hasAlternativeCredential(provider, currentCredential, filter) {
|
||||
if hasAlternativeCredential(provider, currentCredential, selection) {
|
||||
writeRetryableUsageError(w, r)
|
||||
return
|
||||
}
|
||||
writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, fallback))
|
||||
}
|
||||
|
||||
func credentialSelectionForUser(userConfig *option.OCMUser) credentialSelection {
|
||||
selection := credentialSelection{scope: credentialSelectionScopeAll}
|
||||
if userConfig != nil && !userConfig.AllowExternalUsage {
|
||||
selection.scope = credentialSelectionScopeNonExternal
|
||||
selection.filter = func(cred credential) bool {
|
||||
return !cred.isExternal()
|
||||
}
|
||||
}
|
||||
return selection
|
||||
}
|
||||
|
||||
func isHopByHopHeader(header string) bool {
|
||||
switch strings.ToLower(header) {
|
||||
case "connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade", "host":
|
||||
@@ -426,19 +437,16 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
provider.pollIfStale(s.ctx)
|
||||
|
||||
var credentialFilter func(credential) bool
|
||||
if userConfig != nil && !userConfig.AllowExternalUsage {
|
||||
credentialFilter = func(c credential) bool { return !c.isExternal() }
|
||||
}
|
||||
selection := credentialSelectionForUser(userConfig)
|
||||
|
||||
selectedCredential, isNew, err := provider.selectCredential(sessionID, credentialFilter)
|
||||
selectedCredential, isNew, err := provider.selectCredential(sessionID, selection)
|
||||
if err != nil {
|
||||
writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") {
|
||||
s.handleWebSocket(ctx, w, r, path, username, sessionID, userConfig, provider, selectedCredential, credentialFilter, isNew)
|
||||
s.handleWebSocket(ctx, w, r, path, username, sessionID, userConfig, provider, selectedCredential, selection, isNew)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -500,7 +508,12 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
requestContext := selectedCredential.wrapRequestContext(ctx)
|
||||
provider.wrapProviderInterrupt(selectedCredential, requestContext)
|
||||
{
|
||||
currentRequestContext := requestContext
|
||||
requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() {
|
||||
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
|
||||
}))
|
||||
}
|
||||
defer func() {
|
||||
requestContext.cancelRequest()
|
||||
}()
|
||||
@@ -517,7 +530,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
if requestContext.Err() != nil {
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "credential became unavailable while processing the request")
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable while processing the request")
|
||||
return
|
||||
}
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
|
||||
@@ -528,19 +541,24 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Transparent 429 retry
|
||||
for response.StatusCode == http.StatusTooManyRequests {
|
||||
resetAt := parseOCMRateLimitResetFromHeaders(response.Header)
|
||||
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, credentialFilter)
|
||||
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection)
|
||||
needsBodyReplay := r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodDelete
|
||||
selectedCredential.updateStateFromHeaders(response.Header)
|
||||
if (needsBodyReplay && bodyBytes == nil) || nextCredential == nil {
|
||||
response.Body.Close()
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited")
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited")
|
||||
return
|
||||
}
|
||||
response.Body.Close()
|
||||
s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
|
||||
requestContext.cancelRequest()
|
||||
requestContext = nextCredential.wrapRequestContext(ctx)
|
||||
provider.wrapProviderInterrupt(nextCredential, requestContext)
|
||||
{
|
||||
currentRequestContext := requestContext
|
||||
requestContext.addInterruptLink(provider.linkProviderInterrupt(nextCredential, selection, func() {
|
||||
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
|
||||
}))
|
||||
}
|
||||
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||
if buildErr != nil {
|
||||
s.logger.ErrorContext(ctx, "retry request: ", buildErr)
|
||||
@@ -553,7 +571,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
if requestContext.Err() != nil {
|
||||
writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request")
|
||||
writeCredentialUnavailableError(w, r, provider, nextCredential, selection, "credential became unavailable while retrying the request")
|
||||
return
|
||||
}
|
||||
s.logger.ErrorContext(ctx, "retry request: ", retryErr)
|
||||
|
||||
@@ -26,16 +26,24 @@ import (
|
||||
)
|
||||
|
||||
type webSocketSession struct {
|
||||
clientConn net.Conn
|
||||
upstreamConn net.Conn
|
||||
credentialTag string
|
||||
closeOnce sync.Once
|
||||
clientConn net.Conn
|
||||
upstreamConn net.Conn
|
||||
credentialTag string
|
||||
releaseProviderInterrupt func()
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func (s *webSocketSession) Close() {
|
||||
s.closeOnce.Do(func() {
|
||||
s.clientConn.Close()
|
||||
s.upstreamConn.Close()
|
||||
if s.releaseProviderInterrupt != nil {
|
||||
s.releaseProviderInterrupt()
|
||||
}
|
||||
if s.clientConn != nil {
|
||||
s.clientConn.Close()
|
||||
}
|
||||
if s.upstreamConn != nil {
|
||||
s.upstreamConn.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -91,12 +99,14 @@ func (s *Service) handleWebSocket(
|
||||
userConfig *option.OCMUser,
|
||||
provider credentialProvider,
|
||||
selectedCredential credential,
|
||||
credentialFilter func(credential) bool,
|
||||
selection credentialSelection,
|
||||
isNew bool,
|
||||
) {
|
||||
var (
|
||||
err error
|
||||
requestContext *credentialRequestContext
|
||||
clientConn net.Conn
|
||||
session *webSocketSession
|
||||
upstreamConn net.Conn
|
||||
upstreamBufferedReader *bufio.Reader
|
||||
upstreamResponseHeaders http.Header
|
||||
@@ -186,20 +196,36 @@ func (s *Service) handleWebSocket(
|
||||
}
|
||||
|
||||
requestContext = selectedCredential.wrapRequestContext(ctx)
|
||||
provider.wrapProviderInterrupt(selectedCredential, requestContext)
|
||||
{
|
||||
currentRequestContext := requestContext
|
||||
requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() {
|
||||
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
|
||||
if session != nil {
|
||||
session.Close()
|
||||
return
|
||||
}
|
||||
if clientConn != nil {
|
||||
clientConn.Close()
|
||||
}
|
||||
if upstreamConn != nil {
|
||||
upstreamConn.Close()
|
||||
}
|
||||
}))
|
||||
}
|
||||
upstreamConn, upstreamBufferedReader, _, err = upstreamDialer.Dial(requestContext, upstreamURL)
|
||||
if err == nil {
|
||||
requestContext.releaseCredentialInterrupt()
|
||||
break
|
||||
}
|
||||
requestContext.cancelRequest()
|
||||
requestContext = nil
|
||||
upstreamConn = nil
|
||||
clientConn = nil
|
||||
if statusCode == http.StatusTooManyRequests {
|
||||
resetAt := parseOCMRateLimitResetFromHeaders(upstreamResponseHeaders)
|
||||
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, credentialFilter)
|
||||
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection)
|
||||
selectedCredential.updateStateFromHeaders(upstreamResponseHeaders)
|
||||
if nextCredential == nil {
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited")
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited")
|
||||
return
|
||||
}
|
||||
s.logger.InfoContext(ctx, "retrying websocket with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
|
||||
@@ -236,16 +262,17 @@ func (s *Service) handleWebSocket(
|
||||
writeJSONError(w, r, http.StatusServiceUnavailable, "api_error", "service is shutting down")
|
||||
return
|
||||
}
|
||||
clientConn, _, _, err := clientUpgrader.Upgrade(r, w)
|
||||
clientConn, _, _, err = clientUpgrader.Upgrade(r, w)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "upgrade client websocket: ", err)
|
||||
upstreamConn.Close()
|
||||
return
|
||||
}
|
||||
session := &webSocketSession{
|
||||
clientConn: clientConn,
|
||||
upstreamConn: upstreamConn,
|
||||
credentialTag: selectedCredential.tagName(),
|
||||
session = &webSocketSession{
|
||||
clientConn: clientConn,
|
||||
upstreamConn: upstreamConn,
|
||||
credentialTag: selectedCredential.tagName(),
|
||||
releaseProviderInterrupt: requestContext.releaseCredentialInterrupt,
|
||||
}
|
||||
if !s.registerWebSocketSession(session) {
|
||||
session.Close()
|
||||
|
||||
Reference in New Issue
Block a user