Previously, pollUsage(ctx) accepted the caller's context, and service_status.go passed r.Context(), which is canceled when the client disconnects or the service shuts down. As a result, transient cancellations triggered incrementPollFailures → interruptConnections. Each implementation now uses its own persistent context: defaultCredential uses serviceContext, and externalCredential uses getReverseContext().
433 lines
12 KiB
Go
433 lines
12 KiB
Go
package ccm
|
|
|
|
import (
|
|
"context"
|
|
"math/rand/v2"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
C "github.com/sagernet/sing-box/constant"
|
|
"github.com/sagernet/sing-box/log"
|
|
E "github.com/sagernet/sing/common/exceptions"
|
|
)
|
|
|
|
type credentialProvider interface {
|
|
selectCredential(sessionID string, selection credentialSelection) (Credential, bool, error)
|
|
onRateLimited(sessionID string, credential Credential, resetAt time.Time, selection credentialSelection) Credential
|
|
linkProviderInterrupt(credential Credential, selection credentialSelection, onInterrupt func()) func() bool
|
|
pollIfStale()
|
|
allCredentials() []Credential
|
|
close()
|
|
}
|
|
|
|
type singleCredentialProvider struct {
|
|
credential Credential
|
|
sessionAccess sync.RWMutex
|
|
sessions map[string]time.Time
|
|
}
|
|
|
|
func (p *singleCredentialProvider) selectCredential(sessionID string, selection credentialSelection) (Credential, bool, error) {
|
|
if !selection.allows(p.credential) {
|
|
return nil, false, E.New("credential ", p.credential.tagName(), " is filtered out")
|
|
}
|
|
if !p.credential.isAvailable() {
|
|
return nil, false, p.credential.unavailableError()
|
|
}
|
|
if !p.credential.isUsable() {
|
|
return nil, false, E.New("credential ", p.credential.tagName(), " is rate-limited")
|
|
}
|
|
var isNew bool
|
|
if sessionID != "" {
|
|
p.sessionAccess.Lock()
|
|
if p.sessions == nil {
|
|
p.sessions = make(map[string]time.Time)
|
|
}
|
|
_, exists := p.sessions[sessionID]
|
|
if !exists {
|
|
p.sessions[sessionID] = time.Now()
|
|
isNew = true
|
|
}
|
|
p.sessionAccess.Unlock()
|
|
}
|
|
return p.credential, isNew, nil
|
|
}
|
|
|
|
func (p *singleCredentialProvider) onRateLimited(_ string, credential Credential, resetAt time.Time, _ credentialSelection) Credential {
|
|
credential.markRateLimited(resetAt)
|
|
return nil
|
|
}
|
|
|
|
func (p *singleCredentialProvider) pollIfStale() {
|
|
now := time.Now()
|
|
p.sessionAccess.Lock()
|
|
for id, createdAt := range p.sessions {
|
|
if now.Sub(createdAt) > sessionExpiry {
|
|
delete(p.sessions, id)
|
|
}
|
|
}
|
|
p.sessionAccess.Unlock()
|
|
|
|
if time.Since(p.credential.lastUpdatedTime()) > p.credential.pollBackoff(defaultPollInterval) {
|
|
p.credential.pollUsage()
|
|
}
|
|
}
|
|
|
|
func (p *singleCredentialProvider) allCredentials() []Credential {
|
|
return []Credential{p.credential}
|
|
}
|
|
|
|
func (p *singleCredentialProvider) linkProviderInterrupt(_ Credential, _ credentialSelection, _ func()) func() bool {
|
|
return func() bool {
|
|
return false
|
|
}
|
|
}
|
|
|
|
// close is a no-op for the single-credential provider.
func (p *singleCredentialProvider) close() {}
|
|
|
|
type sessionEntry struct {
|
|
tag string
|
|
selectionScope credentialSelectionScope
|
|
createdAt time.Time
|
|
}
|
|
|
|
type credentialInterruptKey struct {
|
|
tag string
|
|
selectionScope credentialSelectionScope
|
|
}
|
|
|
|
// credentialInterruptEntry pairs a cancellable context with its cancel
// function. Callbacks are attached to context with context.AfterFunc, and
// calling cancel fires them.
type credentialInterruptEntry struct {
	// context is the context that interrupt callbacks are attached to.
	context context.Context
	// cancel cancels context.
	cancel context.CancelFunc
}
|
|
|
|
type balancerProvider struct {
|
|
credentials []Credential
|
|
strategy string
|
|
roundRobinIndex atomic.Uint64
|
|
pollInterval time.Duration
|
|
rebalanceThreshold float64
|
|
sessionAccess sync.RWMutex
|
|
sessions map[string]sessionEntry
|
|
interruptAccess sync.Mutex
|
|
credentialInterrupts map[credentialInterruptKey]credentialInterruptEntry
|
|
logger log.ContextLogger
|
|
}
|
|
|
|
func newBalancerProvider(credentials []Credential, strategy string, pollInterval time.Duration, rebalanceThreshold float64, logger log.ContextLogger) *balancerProvider {
|
|
if pollInterval <= 0 {
|
|
pollInterval = defaultPollInterval
|
|
}
|
|
return &balancerProvider{
|
|
credentials: credentials,
|
|
strategy: strategy,
|
|
pollInterval: pollInterval,
|
|
rebalanceThreshold: rebalanceThreshold,
|
|
sessions: make(map[string]sessionEntry),
|
|
credentialInterrupts: make(map[credentialInterruptKey]credentialInterruptEntry),
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (Credential, bool, error) {
|
|
selectionScope := selection.scopeOrDefault()
|
|
for {
|
|
if p.strategy == C.BalancerStrategyFallback {
|
|
best := p.pickCredential(selection.filter)
|
|
if best == nil {
|
|
return nil, false, allCredentialsUnavailableError(p.credentials)
|
|
}
|
|
return best, p.storeSessionIfAbsent(sessionID, sessionEntry{createdAt: time.Now()}), nil
|
|
}
|
|
|
|
if sessionID != "" {
|
|
p.sessionAccess.RLock()
|
|
entry, exists := p.sessions[sessionID]
|
|
p.sessionAccess.RUnlock()
|
|
if exists {
|
|
if entry.selectionScope == selectionScope {
|
|
for _, credential := range p.credentials {
|
|
if credential.tagName() == entry.tag && selection.allows(credential) && credential.isUsable() {
|
|
if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == C.BalancerStrategyLeastUsed) {
|
|
better := p.pickLeastUsed(selection.filter)
|
|
if better != nil && better.tagName() != credential.tagName() {
|
|
effectiveThreshold := p.rebalanceThreshold / credential.planWeight()
|
|
delta := credential.weeklyUtilization() - better.weeklyUtilization()
|
|
if delta > effectiveThreshold {
|
|
p.logger.Info("rebalancing away from ", credential.tagName(),
|
|
": utilization delta ", delta, "% exceeds effective threshold ",
|
|
effectiveThreshold, "% (weight ", credential.planWeight(), ")")
|
|
p.rebalanceCredential(credential.tagName(), selectionScope)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
return credential, false, nil
|
|
}
|
|
}
|
|
}
|
|
p.sessionAccess.Lock()
|
|
currentEntry, stillExists := p.sessions[sessionID]
|
|
if stillExists && currentEntry == entry {
|
|
delete(p.sessions, sessionID)
|
|
p.sessionAccess.Unlock()
|
|
} else {
|
|
p.sessionAccess.Unlock()
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
|
|
best := p.pickCredential(selection.filter)
|
|
if best == nil {
|
|
return nil, false, allCredentialsUnavailableError(p.credentials)
|
|
}
|
|
if p.storeSessionIfAbsent(sessionID, sessionEntry{
|
|
tag: best.tagName(),
|
|
selectionScope: selectionScope,
|
|
createdAt: time.Now(),
|
|
}) {
|
|
return best, true, nil
|
|
}
|
|
if sessionID == "" {
|
|
return best, false, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
func (p *balancerProvider) storeSessionIfAbsent(sessionID string, entry sessionEntry) bool {
|
|
if sessionID == "" {
|
|
return false
|
|
}
|
|
p.sessionAccess.Lock()
|
|
defer p.sessionAccess.Unlock()
|
|
if _, exists := p.sessions[sessionID]; exists {
|
|
return false
|
|
}
|
|
p.sessions[sessionID] = entry
|
|
return true
|
|
}
|
|
|
|
func (p *balancerProvider) rebalanceCredential(tag string, selectionScope credentialSelectionScope) {
|
|
key := credentialInterruptKey{tag: tag, selectionScope: selectionScope}
|
|
p.interruptAccess.Lock()
|
|
if entry, loaded := p.credentialInterrupts[key]; loaded {
|
|
entry.cancel()
|
|
}
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
p.credentialInterrupts[key] = credentialInterruptEntry{context: ctx, cancel: cancel}
|
|
p.interruptAccess.Unlock()
|
|
|
|
p.sessionAccess.Lock()
|
|
for id, entry := range p.sessions {
|
|
if entry.tag == tag && entry.selectionScope == selectionScope {
|
|
delete(p.sessions, id)
|
|
}
|
|
}
|
|
p.sessionAccess.Unlock()
|
|
}
|
|
|
|
func (p *balancerProvider) linkProviderInterrupt(credential Credential, selection credentialSelection, onInterrupt func()) func() bool {
|
|
if p.strategy == C.BalancerStrategyFallback {
|
|
return func() bool { return false }
|
|
}
|
|
key := credentialInterruptKey{
|
|
tag: credential.tagName(),
|
|
selectionScope: selection.scopeOrDefault(),
|
|
}
|
|
p.interruptAccess.Lock()
|
|
entry, loaded := p.credentialInterrupts[key]
|
|
if !loaded {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
entry = credentialInterruptEntry{context: ctx, cancel: cancel}
|
|
p.credentialInterrupts[key] = entry
|
|
}
|
|
p.interruptAccess.Unlock()
|
|
return context.AfterFunc(entry.context, onInterrupt)
|
|
}
|
|
|
|
func (p *balancerProvider) onRateLimited(sessionID string, credential Credential, resetAt time.Time, selection credentialSelection) Credential {
|
|
credential.markRateLimited(resetAt)
|
|
if p.strategy == C.BalancerStrategyFallback {
|
|
return p.pickCredential(selection.filter)
|
|
}
|
|
if sessionID != "" {
|
|
p.sessionAccess.Lock()
|
|
delete(p.sessions, sessionID)
|
|
p.sessionAccess.Unlock()
|
|
}
|
|
|
|
best := p.pickCredential(selection.filter)
|
|
if best != nil && sessionID != "" {
|
|
p.sessionAccess.Lock()
|
|
p.sessions[sessionID] = sessionEntry{
|
|
tag: best.tagName(),
|
|
selectionScope: selection.scopeOrDefault(),
|
|
createdAt: time.Now(),
|
|
}
|
|
p.sessionAccess.Unlock()
|
|
}
|
|
return best
|
|
}
|
|
|
|
func (p *balancerProvider) pickCredential(filter func(Credential) bool) Credential {
|
|
switch p.strategy {
|
|
case C.BalancerStrategyRoundRobin:
|
|
return p.pickRoundRobin(filter)
|
|
case C.BalancerStrategyRandom:
|
|
return p.pickRandom(filter)
|
|
case C.BalancerStrategyFallback:
|
|
return p.pickFallback(filter)
|
|
default:
|
|
return p.pickLeastUsed(filter)
|
|
}
|
|
}
|
|
|
|
func (p *balancerProvider) pickFallback(filter func(Credential) bool) Credential {
|
|
for _, credential := range p.credentials {
|
|
if filter != nil && !filter(credential) {
|
|
continue
|
|
}
|
|
if credential.isUsable() {
|
|
return credential
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
const weeklyWindowHours = 7 * 24
|
|
|
|
func (p *balancerProvider) pickLeastUsed(filter func(Credential) bool) Credential {
|
|
var best Credential
|
|
bestScore := float64(-1)
|
|
now := time.Now()
|
|
for _, credential := range p.credentials {
|
|
if filter != nil && !filter(credential) {
|
|
continue
|
|
}
|
|
if !credential.isUsable() {
|
|
continue
|
|
}
|
|
remaining := credential.weeklyCap() - credential.weeklyUtilization()
|
|
score := remaining * credential.planWeight()
|
|
resetTime := credential.weeklyResetTime()
|
|
if !resetTime.IsZero() {
|
|
timeUntilReset := resetTime.Sub(now)
|
|
if timeUntilReset < time.Hour {
|
|
timeUntilReset = time.Hour
|
|
}
|
|
score *= weeklyWindowHours / timeUntilReset.Hours()
|
|
}
|
|
if score > bestScore {
|
|
bestScore = score
|
|
best = credential
|
|
}
|
|
}
|
|
return best
|
|
}
|
|
|
|
func (p *balancerProvider) pickRoundRobin(filter func(Credential) bool) Credential {
|
|
start := int(p.roundRobinIndex.Add(1) - 1)
|
|
count := len(p.credentials)
|
|
for offset := range count {
|
|
candidate := p.credentials[(start+offset)%count]
|
|
if filter != nil && !filter(candidate) {
|
|
continue
|
|
}
|
|
if candidate.isUsable() {
|
|
return candidate
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (p *balancerProvider) pickRandom(filter func(Credential) bool) Credential {
|
|
var usable []Credential
|
|
for _, candidate := range p.credentials {
|
|
if filter != nil && !filter(candidate) {
|
|
continue
|
|
}
|
|
if candidate.isUsable() {
|
|
usable = append(usable, candidate)
|
|
}
|
|
}
|
|
if len(usable) == 0 {
|
|
return nil
|
|
}
|
|
return usable[rand.IntN(len(usable))]
|
|
}
|
|
|
|
func (p *balancerProvider) pollIfStale() {
|
|
now := time.Now()
|
|
p.sessionAccess.Lock()
|
|
for id, entry := range p.sessions {
|
|
if now.Sub(entry.createdAt) > sessionExpiry {
|
|
delete(p.sessions, id)
|
|
}
|
|
}
|
|
p.sessionAccess.Unlock()
|
|
|
|
p.interruptAccess.Lock()
|
|
for key, entry := range p.credentialInterrupts {
|
|
if entry.context.Err() != nil {
|
|
delete(p.credentialInterrupts, key)
|
|
}
|
|
}
|
|
p.interruptAccess.Unlock()
|
|
|
|
for _, credential := range p.credentials {
|
|
if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) {
|
|
credential.pollUsage()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (p *balancerProvider) allCredentials() []Credential {
|
|
return p.credentials
|
|
}
|
|
|
|
// close is a no-op for the balancer provider.
func (p *balancerProvider) close() {}
|
|
|
|
// ccmPlanWeight maps an account type and rate-limit tier to a relative
// capacity weight:
//
//	"max"  + "default_claude_max_20x" -> 10
//	"max"  + any other tier           -> 5
//	"team" + "default_claude_max_5x"  -> 5
//	everything else                   -> 1
func ccmPlanWeight(accountType string, rateLimitTier string) float64 {
	const (
		tierMax5x  = "default_claude_max_5x"
		tierMax20x = "default_claude_max_20x"
	)
	switch {
	case accountType == "max" && rateLimitTier == tierMax20x:
		return 10
	case accountType == "max":
		return 5
	case accountType == "team" && rateLimitTier == tierMax5x:
		return 5
	}
	return 1
}
|
|
|
|
func allCredentialsUnavailableError(credentials []Credential) error {
|
|
var hasUnavailable bool
|
|
var earliest time.Time
|
|
for _, credential := range credentials {
|
|
if credential.unavailableError() != nil {
|
|
hasUnavailable = true
|
|
continue
|
|
}
|
|
resetAt := credential.earliestReset()
|
|
if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) {
|
|
earliest = resetAt
|
|
}
|
|
}
|
|
if hasUnavailable {
|
|
return E.New("all credentials unavailable")
|
|
}
|
|
if earliest.IsZero() {
|
|
return E.New("all credentials rate-limited")
|
|
}
|
|
return E.New("all credentials rate-limited, earliest reset in ", log.FormatDuration(time.Until(earliest)))
|
|
}
|