ccm/ocm: Add external credential support for cross-instance usage sharing

Extract credential interface from *defaultCredential to support both
default (OAuth) and external (remote proxy) credential types. External
credentials proxy requests to a remote ccm/ocm instance with bearer
token auth, poll a /status endpoint for utilization, and parse
aggregated rate limit headers from responses.

Add allow_external_usage user flag to control whether balancer/fallback
providers may select external credentials. Add status endpoint
(/ccm/v1/status, /ocm/v1/status) returning averaged utilization across
eligible credentials. Rewrite response rate limit headers for external
users with aggregated values.
This commit is contained in:
世界
2026-03-12 23:17:47 +08:00
parent 2801bce815
commit da8ff6f578
9 changed files with 1829 additions and 526 deletions

View File

@@ -19,15 +19,18 @@ type CCMServiceOptions struct {
}
type CCMUser struct {
Name string `json:"name,omitempty"`
Token string `json:"token,omitempty"`
Credential string `json:"credential,omitempty"`
Name string `json:"name,omitempty"`
Token string `json:"token,omitempty"`
Credential string `json:"credential,omitempty"`
ExternalCredential string `json:"external_credential,omitempty"`
AllowExternalUsage bool `json:"allow_external_usage,omitempty"`
}
type _CCMCredential struct {
Type string `json:"type,omitempty"`
Tag string `json:"tag"`
DefaultOptions CCMDefaultCredentialOptions `json:"-"`
ExternalOptions CCMExternalCredentialOptions `json:"-"`
BalancerOptions CCMBalancerCredentialOptions `json:"-"`
FallbackOptions CCMFallbackCredentialOptions `json:"-"`
}
@@ -40,6 +43,8 @@ func (c CCMCredential) MarshalJSON() ([]byte, error) {
case "", "default":
c.Type = ""
v = c.DefaultOptions
case "external":
v = c.ExternalOptions
case "balancer":
v = c.BalancerOptions
case "fallback":
@@ -63,6 +68,8 @@ func (c *CCMCredential) UnmarshalJSON(bytes []byte) error {
case "", "default":
c.Type = "default"
v = &c.DefaultOptions
case "external":
v = &c.ExternalOptions
case "balancer":
v = &c.BalancerOptions
case "fallback":
@@ -87,6 +94,15 @@ type CCMBalancerCredentialOptions struct {
PollInterval badoption.Duration `json:"poll_interval,omitempty"`
}
type CCMExternalCredentialOptions struct {
URL string `json:"url"`
ServerOptions
Token string `json:"token"`
Detour string `json:"detour,omitempty"`
UsagesPath string `json:"usages_path,omitempty"`
PollInterval badoption.Duration `json:"poll_interval,omitempty"`
}
type CCMFallbackCredentialOptions struct {
Credentials badoption.Listable[string] `json:"credentials"`
PollInterval badoption.Duration `json:"poll_interval,omitempty"`

View File

@@ -19,15 +19,18 @@ type OCMServiceOptions struct {
}
type OCMUser struct {
Name string `json:"name,omitempty"`
Token string `json:"token,omitempty"`
Credential string `json:"credential,omitempty"`
Name string `json:"name,omitempty"`
Token string `json:"token,omitempty"`
Credential string `json:"credential,omitempty"`
ExternalCredential string `json:"external_credential,omitempty"`
AllowExternalUsage bool `json:"allow_external_usage,omitempty"`
}
type _OCMCredential struct {
Type string `json:"type,omitempty"`
Tag string `json:"tag"`
DefaultOptions OCMDefaultCredentialOptions `json:"-"`
ExternalOptions OCMExternalCredentialOptions `json:"-"`
BalancerOptions OCMBalancerCredentialOptions `json:"-"`
FallbackOptions OCMFallbackCredentialOptions `json:"-"`
}
@@ -40,6 +43,8 @@ func (c OCMCredential) MarshalJSON() ([]byte, error) {
case "", "default":
c.Type = ""
v = c.DefaultOptions
case "external":
v = c.ExternalOptions
case "balancer":
v = c.BalancerOptions
case "fallback":
@@ -63,6 +68,8 @@ func (c *OCMCredential) UnmarshalJSON(bytes []byte) error {
case "", "default":
c.Type = "default"
v = &c.DefaultOptions
case "external":
v = &c.ExternalOptions
case "balancer":
v = &c.BalancerOptions
case "fallback":
@@ -87,6 +94,15 @@ type OCMBalancerCredentialOptions struct {
PollInterval badoption.Duration `json:"poll_interval,omitempty"`
}
type OCMExternalCredentialOptions struct {
URL string `json:"url"`
ServerOptions
Token string `json:"token"`
Detour string `json:"detour,omitempty"`
UsagesPath string `json:"usages_path,omitempty"`
PollInterval badoption.Duration `json:"poll_interval,omitempty"`
}
type OCMFallbackCredentialOptions struct {
Credentials badoption.Listable[string] `json:"credentials"`
PollInterval badoption.Duration `json:"poll_interval,omitempty"`

View File

@@ -0,0 +1,428 @@
package ccm
import (
"bytes"
"context"
stdTLS "crypto/tls"
"encoding/json"
"io"
"net"
"net/http"
"net/url"
"strconv"
"sync"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/ntp"
)
type externalCredential struct {
tag string
baseURL string
token string
httpClient *http.Client
state credentialState
stateMutex sync.RWMutex
pollAccess sync.Mutex
pollInterval time.Duration
usageTracker *AggregatedUsage
logger log.ContextLogger
onBecameUnusable func()
interrupted bool
requestContext context.Context
cancelRequests context.CancelFunc
requestAccess sync.Mutex
}
func newExternalCredential(ctx context.Context, tag string, options option.CCMExternalCredentialOptions, logger log.ContextLogger) (*externalCredential, error) {
parsedURL, err := url.Parse(options.URL)
if err != nil {
return nil, E.Cause(err, "parse url for credential ", tag)
}
credentialDialer, err := dialer.NewWithOptions(dialer.Options{
Context: ctx,
Options: option.DialerOptions{
Detour: options.Detour,
},
RemoteIsDomain: true,
})
if err != nil {
return nil, E.Cause(err, "create dialer for credential ", tag)
}
transport := &http.Transport{
ForceAttemptHTTP2: true,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
if options.Server != "" {
serverPort := options.ServerPort
if serverPort == 0 {
portStr := parsedURL.Port()
if portStr != "" {
port, parseErr := strconv.ParseUint(portStr, 10, 16)
if parseErr == nil {
serverPort = uint16(port)
}
}
if serverPort == 0 {
if parsedURL.Scheme == "https" {
serverPort = 443
} else {
serverPort = 80
}
}
}
destination := M.ParseSocksaddrHostPort(options.Server, serverPort)
return credentialDialer.DialContext(ctx, network, destination)
}
return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
},
}
if parsedURL.Scheme == "https" {
transport.TLSClientConfig = &stdTLS.Config{
ServerName: parsedURL.Hostname(),
RootCAs: adapter.RootPoolFromContext(ctx),
Time: ntp.TimeFuncFromContext(ctx),
}
}
baseURL := parsedURL.Scheme + "://" + parsedURL.Host
if parsedURL.Path != "" && parsedURL.Path != "/" {
baseURL += parsedURL.Path
}
// Strip trailing slash
if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' {
baseURL = baseURL[:len(baseURL)-1]
}
pollInterval := time.Duration(options.PollInterval)
if pollInterval <= 0 {
pollInterval = 30 * time.Minute
}
requestContext, cancelRequests := context.WithCancel(context.Background())
cred := &externalCredential{
tag: tag,
baseURL: baseURL,
token: options.Token,
httpClient: &http.Client{Transport: transport},
pollInterval: pollInterval,
logger: logger,
requestContext: requestContext,
cancelRequests: cancelRequests,
}
if options.UsagesPath != "" {
cred.usageTracker = &AggregatedUsage{
LastUpdated: time.Now(),
Combinations: make([]CostCombination, 0),
filePath: options.UsagesPath,
logger: logger,
}
}
return cred, nil
}
func (c *externalCredential) start() error {
if c.usageTracker != nil {
err := c.usageTracker.Load()
if err != nil {
c.logger.Warn("load usage statistics for ", c.tag, ": ", err)
}
}
return nil
}
func (c *externalCredential) tagName() string {
return c.tag
}
func (c *externalCredential) isExternal() bool {
return true
}
func (c *externalCredential) isUsable() bool {
c.stateMutex.RLock()
if c.state.hardRateLimited {
if time.Now().Before(c.state.rateLimitResetAt) {
c.stateMutex.RUnlock()
return false
}
c.stateMutex.RUnlock()
c.stateMutex.Lock()
if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) {
c.state.hardRateLimited = false
}
// No reserve for external: only 100% is unusable
usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100
c.stateMutex.Unlock()
return usable
}
usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100
c.stateMutex.RUnlock()
return usable
}
func (c *externalCredential) fiveHourUtilization() float64 {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
return c.state.fiveHourUtilization
}
func (c *externalCredential) weeklyUtilization() float64 {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
return c.state.weeklyUtilization
}
func (c *externalCredential) markRateLimited(resetAt time.Time) {
c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt)))
c.stateMutex.Lock()
c.state.hardRateLimited = true
c.state.rateLimitResetAt = resetAt
shouldInterrupt := c.checkTransitionLocked()
c.stateMutex.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
}
func (c *externalCredential) earliestReset() time.Time {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
if c.state.hardRateLimited {
return c.state.rateLimitResetAt
}
earliest := c.state.fiveHourReset
if !c.state.weeklyReset.IsZero() && (earliest.IsZero() || c.state.weeklyReset.Before(earliest)) {
earliest = c.state.weeklyReset
}
return earliest
}
func (c *externalCredential) getAccessToken() (string, error) {
return c.token, nil
}
func (c *externalCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, _ http.Header) (*http.Request, error) {
proxyURL := c.baseURL + original.URL.RequestURI()
var body io.Reader
if bodyBytes != nil {
body = bytes.NewReader(bodyBytes)
} else {
body = original.Body
}
proxyRequest, err := http.NewRequestWithContext(ctx, original.Method, proxyURL, body)
if err != nil {
return nil, err
}
for key, values := range original.Header {
if !isHopByHopHeader(key) && key != "Authorization" {
proxyRequest.Header[key] = values
}
}
proxyRequest.Header.Set("Authorization", "Bearer "+c.token)
return proxyRequest, nil
}
func (c *externalCredential) updateStateFromHeaders(headers http.Header) {
c.stateMutex.Lock()
isFirstUpdate := c.state.lastUpdated.IsZero()
oldFiveHour := c.state.fiveHourUtilization
oldWeekly := c.state.weeklyUtilization
if utilization := headers.Get("anthropic-ratelimit-unified-5h-utilization"); utilization != "" {
value, err := strconv.ParseFloat(utilization, 64)
if err == nil {
// Remote CCM writes aggregated utilization as 0.0-1.0; convert to percentage
c.state.fiveHourUtilization = value * 100
}
}
if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset"); exists {
c.state.fiveHourReset = value
}
if utilization := headers.Get("anthropic-ratelimit-unified-7d-utilization"); utilization != "" {
value, err := strconv.ParseFloat(utilization, 64)
if err == nil {
c.state.weeklyUtilization = value * 100
}
}
if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset"); exists {
c.state.weeklyReset = value
}
c.state.lastUpdated = time.Now()
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%")
}
shouldInterrupt := c.checkTransitionLocked()
c.stateMutex.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
}
func (c *externalCredential) checkTransitionLocked() bool {
unusable := c.state.hardRateLimited || c.state.fiveHourUtilization >= 100 || c.state.weeklyUtilization >= 100
if unusable && !c.interrupted {
c.interrupted = true
return true
}
if !unusable && c.interrupted {
c.interrupted = false
}
return false
}
func (c *externalCredential) wrapRequestContext(parent context.Context) *credentialRequestContext {
c.requestAccess.Lock()
credentialContext := c.requestContext
c.requestAccess.Unlock()
derived, cancel := context.WithCancel(parent)
stop := context.AfterFunc(credentialContext, func() {
cancel()
})
return &credentialRequestContext{
Context: derived,
releaseFunc: stop,
cancelFunc: cancel,
}
}
func (c *externalCredential) interruptConnections() {
c.logger.Warn("interrupting connections for ", c.tag)
c.requestAccess.Lock()
c.cancelRequests()
c.requestContext, c.cancelRequests = context.WithCancel(context.Background())
c.requestAccess.Unlock()
if c.onBecameUnusable != nil {
c.onBecameUnusable()
}
}
func (c *externalCredential) pollUsage(ctx context.Context) {
if !c.pollAccess.TryLock() {
return
}
defer c.pollAccess.Unlock()
defer c.markUsagePollAttempted()
statusURL := c.baseURL + "/ccm/v1/status"
httpClient := &http.Client{
Transport: c.httpClient.Transport,
Timeout: 5 * time.Second,
}
request, err := http.NewRequestWithContext(ctx, http.MethodGet, statusURL, nil)
if err != nil {
c.logger.Error("poll usage for ", c.tag, ": create request: ", err)
return
}
request.Header.Set("Authorization", "Bearer "+c.token)
response, err := httpClient.Do(request)
if err != nil {
c.logger.Error("poll usage for ", c.tag, ": ", err)
c.stateMutex.Lock()
c.state.consecutivePollFailures++
c.stateMutex.Unlock()
return
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
body, _ := io.ReadAll(response.Body)
c.stateMutex.Lock()
c.state.consecutivePollFailures++
c.stateMutex.Unlock()
c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
return
}
var statusResponse struct {
FiveHourUtilization float64 `json:"five_hour_utilization"`
WeeklyUtilization float64 `json:"weekly_utilization"`
}
err = json.NewDecoder(response.Body).Decode(&statusResponse)
if err != nil {
c.stateMutex.Lock()
c.state.consecutivePollFailures++
c.stateMutex.Unlock()
c.logger.Debug("poll usage for ", c.tag, ": decode: ", err)
return
}
c.stateMutex.Lock()
isFirstUpdate := c.state.lastUpdated.IsZero()
oldFiveHour := c.state.fiveHourUtilization
oldWeekly := c.state.weeklyUtilization
c.state.consecutivePollFailures = 0
c.state.fiveHourUtilization = statusResponse.FiveHourUtilization
c.state.weeklyUtilization = statusResponse.WeeklyUtilization
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
c.state.hardRateLimited = false
}
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%")
}
shouldInterrupt := c.checkTransitionLocked()
c.stateMutex.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
}
func (c *externalCredential) lastUpdatedTime() time.Time {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
return c.state.lastUpdated
}
func (c *externalCredential) markUsagePollAttempted() {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()
c.state.lastUpdated = time.Now()
}
func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Duration {
c.stateMutex.RLock()
failures := c.state.consecutivePollFailures
c.stateMutex.RUnlock()
if failures <= 0 {
return baseInterval
}
if failures > 4 {
failures = 4
}
return baseInterval * time.Duration(1<<failures)
}
func (c *externalCredential) usageTrackerOrNil() *AggregatedUsage {
return c.usageTracker
}
func (c *externalCredential) httpTransport() *http.Client {
return c.httpClient
}
func (c *externalCredential) close() {
if c.usageTracker != nil {
c.usageTracker.cancelPendingSave()
err := c.usageTracker.Save()
if err != nil {
c.logger.Error("save usage statistics for ", c.tag, ": ", err)
}
}
}

View File

@@ -81,6 +81,31 @@ func (c *credentialRequestContext) cancelRequest() {
c.cancelOnce.Do(c.cancelFunc)
}
type credential interface {
tagName() string
isUsable() bool
isExternal() bool
fiveHourUtilization() float64
weeklyUtilization() float64
markRateLimited(resetAt time.Time)
earliestReset() time.Time
getAccessToken() (string, error)
buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error)
updateStateFromHeaders(header http.Header)
wrapRequestContext(ctx context.Context) *credentialRequestContext
interruptConnections()
start() error
pollUsage(ctx context.Context)
lastUpdatedTime() time.Time
pollBackoff(base time.Duration) time.Duration
usageTrackerOrNil() *AggregatedUsage
httpTransport() *http.Client
close()
}
func newDefaultCredential(ctx context.Context, tag string, options option.CCMDefaultCredentialOptions, logger log.ContextLogger) (*defaultCredential, error) {
credentialDialer, err := dialer.NewWithOptions(dialer.Options{
Context: ctx,
@@ -188,15 +213,34 @@ func (c *defaultCredential) getAccessToken() (string, error) {
return newCredentials.AccessToken, nil
}
func parseResetTimestamp(value string) (time.Time, error) {
if value == "" {
return time.Time{}, nil
// Claude Code's unified rate-limit handling parses these reset headers with
// Number(...), compares them against Date.now()/1000, and renders them via
// new Date(seconds*1000), so keep the wire format pinned to Unix epoch seconds.
func parseAnthropicResetHeaderValue(headerName string, headerValue string) time.Time {
unixEpoch, err := strconv.ParseInt(headerValue, 10, 64)
if err != nil {
panic("invalid " + headerName + " header: expected Unix epoch seconds, got " + strconv.Quote(headerValue))
}
unixEpoch, err := strconv.ParseInt(value, 10, 64)
if err == nil {
return time.Unix(unixEpoch, 0), nil
if unixEpoch <= 0 {
panic("invalid " + headerName + " header: expected positive Unix epoch seconds, got " + strconv.Quote(headerValue))
}
return time.Parse(time.RFC3339Nano, value)
return time.Unix(unixEpoch, 0)
}
func parseOptionalAnthropicResetHeader(headers http.Header, headerName string) (time.Time, bool) {
headerValue := headers.Get(headerName)
if headerValue == "" {
return time.Time{}, false
}
return parseAnthropicResetHeaderValue(headerName, headerValue), true
}
func parseRequiredAnthropicResetHeader(headers http.Header, headerName string) time.Time {
headerValue := headers.Get(headerName)
if headerValue == "" {
panic("missing required " + headerName + " header")
}
return parseAnthropicResetHeaderValue(headerName, headerValue)
}
func (c *defaultCredential) updateStateFromHeaders(headers http.Header) {
@@ -215,11 +259,8 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) {
c.state.fiveHourUtilization = newValue
}
}
if resetAt := headers.Get("anthropic-ratelimit-unified-5h-reset"); resetAt != "" {
value, err := parseResetTimestamp(resetAt)
if err == nil {
c.state.fiveHourReset = value
}
if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset"); exists {
c.state.fiveHourReset = value
}
if utilization := headers.Get("anthropic-ratelimit-unified-7d-utilization"); utilization != "" {
value, err := strconv.ParseFloat(utilization, 64)
@@ -231,11 +272,8 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) {
c.state.weeklyUtilization = newValue
}
}
if resetAt := headers.Get("anthropic-ratelimit-unified-7d-reset"); resetAt != "" {
value, err := parseResetTimestamp(resetAt)
if err == nil {
c.state.weeklyReset = value
}
if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset"); exists {
c.state.weeklyReset = value
}
c.state.lastUpdated = time.Now()
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
@@ -499,40 +537,110 @@ func (c *defaultCredential) close() {
}
}
func (c *defaultCredential) tagName() string {
return c.tag
}
func (c *defaultCredential) isExternal() bool {
return false
}
func (c *defaultCredential) fiveHourUtilization() float64 {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
return c.state.fiveHourUtilization
}
func (c *defaultCredential) usageTrackerOrNil() *AggregatedUsage {
return c.usageTracker
}
func (c *defaultCredential) httpTransport() *http.Client {
return c.httpClient
}
func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) {
accessToken, err := c.getAccessToken()
if err != nil {
return nil, E.Cause(err, "get access token for ", c.tag)
}
proxyURL := claudeAPIBaseURL + original.URL.RequestURI()
var body io.Reader
if bodyBytes != nil {
body = bytes.NewReader(bodyBytes)
} else {
body = original.Body
}
proxyRequest, err := http.NewRequestWithContext(ctx, original.Method, proxyURL, body)
if err != nil {
return nil, err
}
for key, values := range original.Header {
if !isHopByHopHeader(key) && key != "Authorization" {
proxyRequest.Header[key] = values
}
}
serviceOverridesAcceptEncoding := len(serviceHeaders.Values("Accept-Encoding")) > 0
if c.usageTracker != nil && !serviceOverridesAcceptEncoding {
proxyRequest.Header.Del("Accept-Encoding")
}
anthropicBetaHeader := proxyRequest.Header.Get("anthropic-beta")
if anthropicBetaHeader != "" {
proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue+","+anthropicBetaHeader)
} else {
proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue)
}
for key, values := range serviceHeaders {
proxyRequest.Header.Del(key)
proxyRequest.Header[key] = values
}
proxyRequest.Header.Set("Authorization", "Bearer "+accessToken)
return proxyRequest, nil
}
// credentialProvider is the interface for all credential types.
type credentialProvider interface {
selectCredential(sessionID string) (*defaultCredential, bool, error)
onRateLimited(sessionID string, credential *defaultCredential, resetAt time.Time) *defaultCredential
selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error)
onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential
pollIfStale(ctx context.Context)
allDefaults() []*defaultCredential
allCredentials() []credential
close()
}
// singleCredentialProvider wraps a single default credential (legacy or single default).
// singleCredentialProvider wraps a single credential (legacy or single default).
type singleCredentialProvider struct {
credential *defaultCredential
cred credential
}
func (p *singleCredentialProvider) selectCredential(_ string) (*defaultCredential, bool, error) {
if !p.credential.isUsable() {
return nil, false, E.New("credential ", p.credential.tag, " is rate-limited")
func (p *singleCredentialProvider) selectCredential(_ string, filter func(credential) bool) (credential, bool, error) {
if filter != nil && !filter(p.cred) {
return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out")
}
return p.credential, false, nil
if !p.cred.isUsable() {
return nil, false, E.New("credential ", p.cred.tagName(), " is rate-limited")
}
return p.cred, false, nil
}
func (p *singleCredentialProvider) onRateLimited(_ string, credential *defaultCredential, resetAt time.Time) *defaultCredential {
credential.markRateLimited(resetAt)
func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ func(credential) bool) credential {
cred.markRateLimited(resetAt)
return nil
}
func (p *singleCredentialProvider) pollIfStale(ctx context.Context) {
if time.Since(p.credential.lastUpdatedTime()) > p.credential.pollBackoff(defaultPollInterval) {
p.credential.pollUsage(ctx)
if time.Since(p.cred.lastUpdatedTime()) > p.cred.pollBackoff(defaultPollInterval) {
p.cred.pollUsage(ctx)
}
}
func (p *singleCredentialProvider) allDefaults() []*defaultCredential {
return []*defaultCredential{p.credential}
func (p *singleCredentialProvider) allCredentials() []credential {
return []credential{p.cred}
}
func (p *singleCredentialProvider) close() {}
@@ -546,7 +654,7 @@ type sessionEntry struct {
// balancerProvider assigns sessions to credentials based on a configurable strategy.
type balancerProvider struct {
credentials []*defaultCredential
credentials []credential
strategy string
roundRobinIndex atomic.Uint64
pollInterval time.Duration
@@ -555,7 +663,7 @@ type balancerProvider struct {
logger log.ContextLogger
}
func newBalancerProvider(credentials []*defaultCredential, strategy string, pollInterval time.Duration, logger log.ContextLogger) *balancerProvider {
func newBalancerProvider(credentials []credential, strategy string, pollInterval time.Duration, logger log.ContextLogger) *balancerProvider {
if pollInterval <= 0 {
pollInterval = defaultPollInterval
}
@@ -568,15 +676,15 @@ func newBalancerProvider(credentials []*defaultCredential, strategy string, poll
}
}
func (p *balancerProvider) selectCredential(sessionID string) (*defaultCredential, bool, error) {
func (p *balancerProvider) selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) {
if sessionID != "" {
p.sessionMutex.RLock()
entry, exists := p.sessions[sessionID]
p.sessionMutex.RUnlock()
if exists {
for _, credential := range p.credentials {
if credential.tag == entry.tag && credential.isUsable() {
return credential, false, nil
for _, cred := range p.credentials {
if cred.tagName() == entry.tag && (filter == nil || filter(cred)) && cred.isUsable() {
return cred, false, nil
}
}
p.sessionMutex.Lock()
@@ -585,7 +693,7 @@ func (p *balancerProvider) selectCredential(sessionID string) (*defaultCredentia
}
}
best := p.pickCredential()
best := p.pickCredential(filter)
if best == nil {
return nil, false, allCredentialsUnavailableError(p.credentials)
}
@@ -593,61 +701,67 @@ func (p *balancerProvider) selectCredential(sessionID string) (*defaultCredentia
isNew := sessionID != ""
if isNew {
p.sessionMutex.Lock()
p.sessions[sessionID] = sessionEntry{tag: best.tag, createdAt: time.Now()}
p.sessions[sessionID] = sessionEntry{tag: best.tagName(), createdAt: time.Now()}
p.sessionMutex.Unlock()
}
return best, isNew, nil
}
func (p *balancerProvider) onRateLimited(sessionID string, credential *defaultCredential, resetAt time.Time) *defaultCredential {
credential.markRateLimited(resetAt)
func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential {
cred.markRateLimited(resetAt)
if sessionID != "" {
p.sessionMutex.Lock()
delete(p.sessions, sessionID)
p.sessionMutex.Unlock()
}
best := p.pickCredential()
best := p.pickCredential(filter)
if best != nil && sessionID != "" {
p.sessionMutex.Lock()
p.sessions[sessionID] = sessionEntry{tag: best.tag, createdAt: time.Now()}
p.sessions[sessionID] = sessionEntry{tag: best.tagName(), createdAt: time.Now()}
p.sessionMutex.Unlock()
}
return best
}
func (p *balancerProvider) pickCredential() *defaultCredential {
func (p *balancerProvider) pickCredential(filter func(credential) bool) credential {
switch p.strategy {
case "round_robin":
return p.pickRoundRobin()
return p.pickRoundRobin(filter)
case "random":
return p.pickRandom()
return p.pickRandom(filter)
default:
return p.pickLeastUsed()
return p.pickLeastUsed(filter)
}
}
func (p *balancerProvider) pickLeastUsed() *defaultCredential {
var best *defaultCredential
func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credential {
var best credential
bestUtilization := float64(101)
for _, credential := range p.credentials {
if !credential.isUsable() {
for _, cred := range p.credentials {
if filter != nil && !filter(cred) {
continue
}
utilization := credential.weeklyUtilization()
if !cred.isUsable() {
continue
}
utilization := cred.weeklyUtilization()
if utilization < bestUtilization {
bestUtilization = utilization
best = credential
best = cred
}
}
return best
}
func (p *balancerProvider) pickRoundRobin() *defaultCredential {
func (p *balancerProvider) pickRoundRobin(filter func(credential) bool) credential {
start := int(p.roundRobinIndex.Add(1) - 1)
count := len(p.credentials)
for offset := range count {
candidate := p.credentials[(start+offset)%count]
if filter != nil && !filter(candidate) {
continue
}
if candidate.isUsable() {
return candidate
}
@@ -655,9 +769,12 @@ func (p *balancerProvider) pickRoundRobin() *defaultCredential {
return nil
}
func (p *balancerProvider) pickRandom() *defaultCredential {
var usable []*defaultCredential
func (p *balancerProvider) pickRandom(filter func(credential) bool) credential {
var usable []credential
for _, candidate := range p.credentials {
if filter != nil && !filter(candidate) {
continue
}
if candidate.isUsable() {
usable = append(usable, candidate)
}
@@ -678,14 +795,14 @@ func (p *balancerProvider) pollIfStale(ctx context.Context) {
}
p.sessionMutex.Unlock()
for _, credential := range p.credentials {
if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) {
credential.pollUsage(ctx)
for _, cred := range p.credentials {
if time.Since(cred.lastUpdatedTime()) > cred.pollBackoff(p.pollInterval) {
cred.pollUsage(ctx)
}
}
}
func (p *balancerProvider) allDefaults() []*defaultCredential {
func (p *balancerProvider) allCredentials() []credential {
return p.credentials
}
@@ -693,12 +810,12 @@ func (p *balancerProvider) close() {}
// fallbackProvider tries credentials in order.
type fallbackProvider struct {
credentials []*defaultCredential
credentials []credential
pollInterval time.Duration
logger log.ContextLogger
}
func newFallbackProvider(credentials []*defaultCredential, pollInterval time.Duration, logger log.ContextLogger) *fallbackProvider {
func newFallbackProvider(credentials []credential, pollInterval time.Duration, logger log.ContextLogger) *fallbackProvider {
if pollInterval <= 0 {
pollInterval = defaultPollInterval
}
@@ -709,18 +826,24 @@ func newFallbackProvider(credentials []*defaultCredential, pollInterval time.Dur
}
}
func (p *fallbackProvider) selectCredential(_ string) (*defaultCredential, bool, error) {
for _, credential := range p.credentials {
if credential.isUsable() {
return credential, false, nil
func (p *fallbackProvider) selectCredential(_ string, filter func(credential) bool) (credential, bool, error) {
for _, cred := range p.credentials {
if filter != nil && !filter(cred) {
continue
}
if cred.isUsable() {
return cred, false, nil
}
}
return nil, false, allCredentialsUnavailableError(p.credentials)
}
func (p *fallbackProvider) onRateLimited(_ string, credential *defaultCredential, resetAt time.Time) *defaultCredential {
credential.markRateLimited(resetAt)
func (p *fallbackProvider) onRateLimited(_ string, cred credential, resetAt time.Time, filter func(credential) bool) credential {
cred.markRateLimited(resetAt)
for _, candidate := range p.credentials {
if filter != nil && !filter(candidate) {
continue
}
if candidate.isUsable() {
return candidate
}
@@ -729,23 +852,23 @@ func (p *fallbackProvider) onRateLimited(_ string, credential *defaultCredential
}
func (p *fallbackProvider) pollIfStale(ctx context.Context) {
for _, credential := range p.credentials {
if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) {
credential.pollUsage(ctx)
for _, cred := range p.credentials {
if time.Since(cred.lastUpdatedTime()) > cred.pollBackoff(p.pollInterval) {
cred.pollUsage(ctx)
}
}
}
func (p *fallbackProvider) allDefaults() []*defaultCredential {
func (p *fallbackProvider) allCredentials() []credential {
return p.credentials
}
func (p *fallbackProvider) close() {}
func allCredentialsUnavailableError(credentials []*defaultCredential) error {
func allCredentialsUnavailableError(credentials []credential) error {
var earliest time.Time
for _, credential := range credentials {
resetAt := credential.earliestReset()
for _, cred := range credentials {
resetAt := cred.earliestReset()
if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) {
earliest = resetAt
}
@@ -778,34 +901,44 @@ func buildCredentialProviders(
ctx context.Context,
options option.CCMServiceOptions,
logger log.ContextLogger,
) (map[string]credentialProvider, []*defaultCredential, error) {
defaultCredentials := make(map[string]*defaultCredential)
var allDefaults []*defaultCredential
) (map[string]credentialProvider, []credential, error) {
allCredentialMap := make(map[string]credential)
var allCreds []credential
providers := make(map[string]credentialProvider)
// Pass 1: create default and external credentials
for _, credOpt := range options.Credentials {
switch credOpt.Type {
case "default":
credential, err := newDefaultCredential(ctx, credOpt.Tag, credOpt.DefaultOptions, logger)
cred, err := newDefaultCredential(ctx, credOpt.Tag, credOpt.DefaultOptions, logger)
if err != nil {
return nil, nil, err
}
defaultCredentials[credOpt.Tag] = credential
allDefaults = append(allDefaults, credential)
providers[credOpt.Tag] = &singleCredentialProvider{credential: credential}
allCredentialMap[credOpt.Tag] = cred
allCreds = append(allCreds, cred)
providers[credOpt.Tag] = &singleCredentialProvider{cred: cred}
case "external":
cred, err := newExternalCredential(ctx, credOpt.Tag, credOpt.ExternalOptions, logger)
if err != nil {
return nil, nil, err
}
allCredentialMap[credOpt.Tag] = cred
allCreds = append(allCreds, cred)
providers[credOpt.Tag] = &singleCredentialProvider{cred: cred}
}
}
// Pass 2: create balancer and fallback providers
for _, credOpt := range options.Credentials {
switch credOpt.Type {
case "balancer":
subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, defaultCredentials, credOpt.Tag)
subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, allCredentialMap, credOpt.Tag)
if err != nil {
return nil, nil, err
}
providers[credOpt.Tag] = newBalancerProvider(subCredentials, credOpt.BalancerOptions.Strategy, time.Duration(credOpt.BalancerOptions.PollInterval), logger)
case "fallback":
subCredentials, err := resolveCredentialTags(credOpt.FallbackOptions.Credentials, defaultCredentials, credOpt.Tag)
subCredentials, err := resolveCredentialTags(credOpt.FallbackOptions.Credentials, allCredentialMap, credOpt.Tag)
if err != nil {
return nil, nil, err
}
@@ -813,17 +946,17 @@ func buildCredentialProviders(
}
}
return providers, allDefaults, nil
return providers, allCreds, nil
}
func resolveCredentialTags(tags []string, defaults map[string]*defaultCredential, parentTag string) ([]*defaultCredential, error) {
credentials := make([]*defaultCredential, 0, len(tags))
func resolveCredentialTags(tags []string, allCredentials map[string]credential, parentTag string) ([]credential, error) {
credentials := make([]credential, 0, len(tags))
for _, tag := range tags {
credential, exists := defaults[tag]
cred, exists := allCredentials[tag]
if !exists {
return nil, E.New("credential ", parentTag, " references unknown default credential: ", tag)
return nil, E.New("credential ", parentTag, " references unknown credential: ", tag)
}
credentials = append(credentials, credential)
credentials = append(credentials, cred)
}
if len(credentials) == 0 {
return nil, E.New("credential ", parentTag, " has no sub-credentials")
@@ -835,27 +968,12 @@ func parseRateLimitResetFromHeaders(headers http.Header) time.Time {
claim := headers.Get("anthropic-ratelimit-unified-representative-claim")
switch claim {
case "5h":
if resetStr := headers.Get("anthropic-ratelimit-unified-5h-reset"); resetStr != "" {
value, err := strconv.ParseInt(resetStr, 10, 64)
if err == nil {
return time.Unix(value, 0)
}
}
return parseRequiredAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset")
case "7d":
if resetStr := headers.Get("anthropic-ratelimit-unified-7d-reset"); resetStr != "" {
value, err := strconv.ParseInt(resetStr, 10, 64)
if err == nil {
return time.Unix(value, 0)
}
}
return parseRequiredAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset")
default:
panic("invalid anthropic-ratelimit-unified-representative-claim header: " + strconv.Quote(claim))
}
if retryAfter := headers.Get("Retry-After"); retryAfter != "" {
seconds, err := strconv.ParseInt(retryAfter, 10, 64)
if err == nil {
return time.Now().Add(time.Duration(seconds) * time.Second)
}
}
return time.Now().Add(5 * time.Minute)
}
func validateCCMOptions(options option.CCMServiceOptions) error {
@@ -876,24 +994,34 @@ func validateCCMOptions(options option.CCMServiceOptions) error {
if hasCredentials {
tags := make(map[string]bool)
for _, credential := range options.Credentials {
if tags[credential.Tag] {
return E.New("duplicate credential tag: ", credential.Tag)
credentialTypes := make(map[string]string)
for _, cred := range options.Credentials {
if tags[cred.Tag] {
return E.New("duplicate credential tag: ", cred.Tag)
}
tags[credential.Tag] = true
if credential.Type == "default" || credential.Type == "" {
if credential.DefaultOptions.Reserve5h > 99 {
return E.New("credential ", credential.Tag, ": reserve_5h must be at most 99")
tags[cred.Tag] = true
credentialTypes[cred.Tag] = cred.Type
if cred.Type == "default" || cred.Type == "" {
if cred.DefaultOptions.Reserve5h > 99 {
return E.New("credential ", cred.Tag, ": reserve_5h must be at most 99")
}
if credential.DefaultOptions.ReserveWeekly > 99 {
return E.New("credential ", credential.Tag, ": reserve_weekly must be at most 99")
if cred.DefaultOptions.ReserveWeekly > 99 {
return E.New("credential ", cred.Tag, ": reserve_weekly must be at most 99")
}
}
if credential.Type == "balancer" {
switch credential.BalancerOptions.Strategy {
if cred.Type == "external" {
if cred.ExternalOptions.URL == "" {
return E.New("credential ", cred.Tag, ": external credential requires url")
}
if cred.ExternalOptions.Token == "" {
return E.New("credential ", cred.Tag, ": external credential requires token")
}
}
if cred.Type == "balancer" {
switch cred.BalancerOptions.Strategy {
case "", "least_used", "round_robin", "random":
default:
return E.New("credential ", credential.Tag, ": unknown balancer strategy: ", credential.BalancerOptions.Strategy)
return E.New("credential ", cred.Tag, ": unknown balancer strategy: ", cred.BalancerOptions.Strategy)
}
}
}
@@ -905,63 +1033,25 @@ func validateCCMOptions(options option.CCMServiceOptions) error {
if !tags[user.Credential] {
return E.New("user ", user.Name, " references unknown credential: ", user.Credential)
}
if user.ExternalCredential != "" {
if !tags[user.ExternalCredential] {
return E.New("user ", user.Name, " references unknown external_credential: ", user.ExternalCredential)
}
if credentialTypes[user.ExternalCredential] != "external" {
return E.New("user ", user.Name, ": external_credential must reference an external type credential")
}
}
}
}
return nil
}
// retryRequestWithBody re-sends a buffered request body using a different credential.
func retryRequestWithBody(
ctx context.Context,
originalRequest *http.Request,
bodyBytes []byte,
credential *defaultCredential,
httpHeaders http.Header,
) (*http.Response, error) {
accessToken, err := credential.getAccessToken()
if err != nil {
return nil, E.Cause(err, "get access token for ", credential.tag)
}
proxyURL := claudeAPIBaseURL + originalRequest.URL.RequestURI()
retryRequest, err := http.NewRequestWithContext(ctx, originalRequest.Method, proxyURL, bytes.NewReader(bodyBytes))
if err != nil {
return nil, err
}
for key, values := range originalRequest.Header {
if !isHopByHopHeader(key) && key != "Authorization" {
retryRequest.Header[key] = values
}
}
serviceOverridesAcceptEncoding := len(httpHeaders.Values("Accept-Encoding")) > 0
if credential.usageTracker != nil && !serviceOverridesAcceptEncoding {
retryRequest.Header.Del("Accept-Encoding")
}
anthropicBetaHeader := retryRequest.Header.Get("anthropic-beta")
if anthropicBetaHeader != "" {
retryRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue+","+anthropicBetaHeader)
} else {
retryRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue)
}
for key, values := range httpHeaders {
retryRequest.Header.Del(key)
retryRequest.Header[key] = values
}
retryRequest.Header.Set("Authorization", "Bearer "+accessToken)
return credential.httpClient.Do(retryRequest)
}
// credentialForUser finds the credential provider for a user.
// In legacy mode, returns the single provider.
// In multi-credential mode, returns the provider mapped to the user's credential tag.
func credentialForUser(
userCredentialMap map[string]string,
userConfigMap map[string]*option.CCMUser,
providers map[string]credentialProvider,
legacyProvider credentialProvider,
username string,
@@ -969,13 +1059,13 @@ func credentialForUser(
if legacyProvider != nil {
return legacyProvider, nil
}
tag, exists := userCredentialMap[username]
userConfig, exists := userConfigMap[username]
if !exists {
return nil, E.New("no credential mapping for user: ", username)
}
provider, exists := providers[tag]
provider, exists := providers[userConfig.Credential]
if !exists {
return nil, E.New("unknown credential: ", tag)
return nil, E.New("unknown credential: ", userConfig.Credential)
}
return provider, nil
}

View File

@@ -66,15 +66,18 @@ func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, erro
})
}
func hasAlternativeCredential(provider credentialProvider, currentCredential *defaultCredential) bool {
func hasAlternativeCredential(provider credentialProvider, currentCredential credential, filter func(credential) bool) bool {
if provider == nil || currentCredential == nil {
return false
}
for _, credential := range provider.allDefaults() {
if credential == currentCredential {
for _, cred := range provider.allCredentials() {
if cred == currentCredential {
continue
}
if credential.isUsable() {
if filter != nil && !filter(cred) {
continue
}
if cred.isUsable() {
return true
}
}
@@ -85,7 +88,7 @@ func unavailableCredentialMessage(provider credentialProvider, fallback string)
if provider == nil {
return fallback
}
return allCredentialsUnavailableError(provider.allDefaults()).Error()
return allCredentialsUnavailableError(provider.allCredentials()).Error()
}
func writeRetryableUsageError(w http.ResponseWriter, r *http.Request) {
@@ -100,10 +103,11 @@ func writeCredentialUnavailableError(
w http.ResponseWriter,
r *http.Request,
provider credentialProvider,
currentCredential *defaultCredential,
currentCredential credential,
filter func(credential) bool,
fallback string,
) {
if hasAlternativeCredential(provider, currentCredential) {
if hasAlternativeCredential(provider, currentCredential, filter) {
writeRetryableUsageError(w, r)
return
}
@@ -124,27 +128,15 @@ const (
weeklyWindowMinutes = weeklyWindowSeconds / 60
)
func parseInt64Header(headers http.Header, headerName string) (int64, bool) {
headerValue := strings.TrimSpace(headers.Get(headerName))
if headerValue == "" {
return 0, false
}
parsedValue, parseError := strconv.ParseInt(headerValue, 10, 64)
if parseError != nil {
return 0, false
}
return parsedValue, true
}
func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint {
resetAtUnix, hasResetAt := parseInt64Header(headers, "anthropic-ratelimit-unified-7d-reset")
if !hasResetAt || resetAtUnix <= 0 {
resetAt, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset")
if !exists {
return nil
}
return &WeeklyCycleHint{
WindowMinutes: weeklyWindowMinutes,
ResetAt: time.Unix(resetAtUnix, 0).UTC(),
ResetAt: resetAt.UTC(),
}
}
@@ -166,9 +158,9 @@ type Service struct {
legacyProvider credentialProvider
// Multi-credential mode
providers map[string]credentialProvider
allDefaults []*defaultCredential
userCredentialMap map[string]string
providers map[string]credentialProvider
allCredentials []credential
userConfigMap map[string]*option.CCMUser
}
func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.CCMServiceOptions) (adapter.Service, error) {
@@ -199,20 +191,20 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
}
if len(options.Credentials) > 0 {
providers, allDefaults, err := buildCredentialProviders(ctx, options, logger)
providers, allCredentials, err := buildCredentialProviders(ctx, options, logger)
if err != nil {
return nil, E.Cause(err, "build credential providers")
}
service.providers = providers
service.allDefaults = allDefaults
service.allCredentials = allCredentials
userCredentialMap := make(map[string]string)
for _, user := range options.Users {
userCredentialMap[user.Name] = user.Credential
userConfigMap := make(map[string]*option.CCMUser)
for i := range options.Users {
userConfigMap[options.Users[i].Name] = &options.Users[i]
}
service.userCredentialMap = userCredentialMap
service.userConfigMap = userConfigMap
} else {
credential, err := newDefaultCredential(ctx, "default", option.CCMDefaultCredentialOptions{
cred, err := newDefaultCredential(ctx, "default", option.CCMDefaultCredentialOptions{
CredentialPath: options.CredentialPath,
UsagesPath: options.UsagesPath,
Detour: options.Detour,
@@ -220,9 +212,9 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
if err != nil {
return nil, err
}
service.legacyCredential = credential
service.legacyProvider = &singleCredentialProvider{credential: credential}
service.allDefaults = []*defaultCredential{credential}
service.legacyCredential = cred
service.legacyProvider = &singleCredentialProvider{cred: cred}
service.allCredentials = []credential{cred}
}
if options.TLS != nil {
@@ -243,8 +235,8 @@ func (s *Service) Start(stage adapter.StartStage) error {
s.userManager.UpdateUsers(s.options.Users)
for _, credential := range s.allDefaults {
err := credential.start()
for _, cred := range s.allCredentials {
err := cred.start()
if err != nil {
return err
}
@@ -303,6 +295,11 @@ func detectContextWindow(betaHeader string, totalInputTokens int64) int {
}
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/ccm/v1/status" {
s.handleStatusEndpoint(w, r)
return
}
if !strings.HasPrefix(r.URL.Path, "/v1/") {
writeJSONError(w, r, http.StatusNotFound, "not_found_error", "Not found")
return
@@ -360,11 +357,13 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
}
// Resolve credential provider
// Resolve credential provider and user config
var provider credentialProvider
var userConfig *option.CCMUser
if len(s.options.Users) > 0 {
userConfig = s.userConfigMap[username]
var err error
provider, err = credentialForUser(s.userCredentialMap, s.providers, s.legacyProvider, username)
provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
if err != nil {
s.logger.Error("resolve credential: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
@@ -389,70 +388,48 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
credential, isNew, err := provider.selectCredential(sessionID)
var credentialFilter func(credential) bool
if userConfig != nil && !userConfig.AllowExternalUsage {
credentialFilter = func(c credential) bool { return !c.isExternal() }
}
selectedCredential, isNew, err := provider.selectCredential(sessionID, credentialFilter)
if err != nil {
writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, err.Error()))
return
}
if isNew {
if username != "" {
s.logger.Debug("assigned credential ", credential.tag, " for session ", sessionID, " by user ", username)
s.logger.Debug("assigned credential ", selectedCredential.tagName(), " for session ", sessionID, " by user ", username)
} else {
s.logger.Debug("assigned credential ", credential.tag, " for session ", sessionID)
s.logger.Debug("assigned credential ", selectedCredential.tagName(), " for session ", sessionID)
}
}
accessToken, err := credential.getAccessToken()
if err != nil {
s.logger.Error("get access token: ", err)
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "Authentication failed")
if isExtendedContextRequest(anthropicBetaHeader) && selectedCredential.isExternal() {
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
"extended context (1m) requests cannot be proxied through external credentials")
return
}
proxyURL := claudeAPIBaseURL + r.URL.RequestURI()
requestContext := credential.wrapRequestContext(r.Context())
requestContext := selectedCredential.wrapRequestContext(r.Context())
defer func() {
requestContext.cancelRequest()
}()
proxyRequest, err := http.NewRequestWithContext(requestContext, r.Method, proxyURL, r.Body)
proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if err != nil {
s.logger.Error("create proxy request: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
return
}
for key, values := range r.Header {
if !isHopByHopHeader(key) && key != "Authorization" {
proxyRequest.Header[key] = values
}
}
hasUsageTracker := credential.usageTracker != nil
serviceOverridesAcceptEncoding := len(s.httpHeaders.Values("Accept-Encoding")) > 0
if hasUsageTracker && !serviceOverridesAcceptEncoding {
proxyRequest.Header.Del("Accept-Encoding")
}
if anthropicBetaHeader != "" {
proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue+","+anthropicBetaHeader)
} else {
proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue)
}
for key, values := range s.httpHeaders {
proxyRequest.Header.Del(key)
proxyRequest.Header[key] = values
}
proxyRequest.Header.Set("Authorization", "Bearer "+accessToken)
response, err := credential.httpClient.Do(proxyRequest)
response, err := selectedCredential.httpTransport().Do(proxyRequest)
if err != nil {
if r.Context().Err() != nil {
return
}
if requestContext.Err() != nil {
writeCredentialUnavailableError(w, r, provider, credential, "credential became unavailable while processing the request")
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "credential became unavailable while processing the request")
return
}
writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
@@ -463,24 +440,30 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Transparent 429 retry
for response.StatusCode == http.StatusTooManyRequests {
resetAt := parseRateLimitResetFromHeaders(response.Header)
nextCredential := provider.onRateLimited(sessionID, credential, resetAt)
credential.updateStateFromHeaders(response.Header)
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, credentialFilter)
selectedCredential.updateStateFromHeaders(response.Header)
if bodyBytes == nil || nextCredential == nil {
response.Body.Close()
writeCredentialUnavailableError(w, r, provider, credential, "all credentials rate-limited")
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited")
return
}
response.Body.Close()
s.logger.Info("retrying with credential ", nextCredential.tag, " after 429 from ", credential.tag)
s.logger.Info("retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
requestContext.cancelRequest()
requestContext = nextCredential.wrapRequestContext(r.Context())
retryResponse, retryErr := retryRequestWithBody(requestContext, r, bodyBytes, nextCredential, s.httpHeaders)
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if buildErr != nil {
s.logger.Error("retry request: ", buildErr)
writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
return
}
retryResponse, retryErr := nextCredential.httpTransport().Do(retryRequest)
if retryErr != nil {
if r.Context().Err() != nil {
return
}
if requestContext.Err() != nil {
writeCredentialUnavailableError(w, r, provider, nextCredential, "credential became unavailable while retrying the request")
writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request")
return
}
s.logger.Error("retry request: ", retryErr)
@@ -489,21 +472,24 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
requestContext.releaseCredentialInterrupt()
response = retryResponse
credential = nextCredential
selectedCredential = nextCredential
}
defer response.Body.Close()
credential.updateStateFromHeaders(response.Header)
selectedCredential.updateStateFromHeaders(response.Header)
if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
body, _ := io.ReadAll(response.Body)
s.logger.Error("upstream error from ", credential.tag, ": status ", response.StatusCode, " ", string(body))
s.logger.Error("upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
writeJSONError(w, r, http.StatusInternalServerError, "api_error",
"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
return
}
hasUsageTracker = credential.usageTracker != nil
// Rewrite response headers for external users
if userConfig != nil && userConfig.ExternalCredential != "" {
s.rewriteResponseHeadersForExternalUser(response.Header, userConfig)
}
for key, values := range response.Header {
if !isHopByHopHeader(key) {
@@ -512,8 +498,9 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
w.WriteHeader(response.StatusCode)
if hasUsageTracker && response.StatusCode == http.StatusOK {
s.handleResponseWithTracking(w, response, credential.usageTracker, requestModel, anthropicBetaHeader, messagesCount, username)
usageTracker := selectedCredential.usageTrackerOrNil()
if usageTracker != nil && response.StatusCode == http.StatusOK {
s.handleResponseWithTracking(w, response, usageTracker, requestModel, anthropicBetaHeader, messagesCount, username)
} else {
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
if err == nil && mediaType != "text/event-stream" {
@@ -693,6 +680,91 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
}
}
func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed")
return
}
if len(s.options.Users) == 0 {
writeJSONError(w, r, http.StatusForbidden, "authentication_error", "status endpoint requires user authentication")
return
}
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
return
}
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
if clientToken == authHeader {
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
return
}
username, ok := s.userManager.Authenticate(clientToken)
if !ok {
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
return
}
userConfig := s.userConfigMap[username]
if userConfig == nil {
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "user config not found")
return
}
provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
if err != nil {
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
return
}
provider.pollIfStale(r.Context())
avgFiveHour, avgWeekly := s.computeAggregatedUtilization(provider, userConfig)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]float64{
"five_hour_utilization": avgFiveHour,
"weekly_utilization": avgWeekly,
})
}
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) (float64, float64) {
var totalFiveHour, totalWeekly float64
var count int
for _, cred := range provider.allCredentials() {
// Exclude the user's own external_credential (their contribution to us)
if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential {
continue
}
// If user doesn't allow external usage, exclude all external credentials
if !userConfig.AllowExternalUsage && cred.isExternal() {
continue
}
totalFiveHour += cred.fiveHourUtilization()
totalWeekly += cred.weeklyUtilization()
count++
}
if count == 0 {
return 100, 100
}
return totalFiveHour / float64(count), totalWeekly / float64(count)
}
func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.CCMUser) {
provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, userConfig.Name)
if err != nil {
return
}
avgFiveHour, avgWeekly := s.computeAggregatedUtilization(provider, userConfig)
// Rewrite utilization headers to aggregated average (convert back to 0.0-1.0 range)
headers.Set("anthropic-ratelimit-unified-5h-utilization", strconv.FormatFloat(avgFiveHour/100, 'f', 6, 64))
headers.Set("anthropic-ratelimit-unified-7d-utilization", strconv.FormatFloat(avgWeekly/100, 'f', 6, 64))
}
func (s *Service) Close() error {
err := common.Close(
common.PtrOrNil(s.httpServer),
@@ -700,8 +772,8 @@ func (s *Service) Close() error {
s.tlsConfig,
)
for _, credential := range s.allDefaults {
credential.close()
for _, cred := range s.allCredentials {
cred.close()
}
return err

View File

@@ -0,0 +1,463 @@
package ocm
import (
"bytes"
"context"
stdTLS "crypto/tls"
"encoding/json"
"io"
"net"
"net/http"
"net/url"
"strconv"
"sync"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp"
)
type externalCredential struct {
tag string
baseURL string
token string
credDialer N.Dialer
httpClient *http.Client
state credentialState
stateMutex sync.RWMutex
pollAccess sync.Mutex
pollInterval time.Duration
usageTracker *AggregatedUsage
logger log.ContextLogger
onBecameUnusable func()
interrupted bool
requestContext context.Context
cancelRequests context.CancelFunc
requestAccess sync.Mutex
}
func newExternalCredential(ctx context.Context, tag string, options option.OCMExternalCredentialOptions, logger log.ContextLogger) (*externalCredential, error) {
parsedURL, err := url.Parse(options.URL)
if err != nil {
return nil, E.Cause(err, "parse url for credential ", tag)
}
credentialDialer, err := dialer.NewWithOptions(dialer.Options{
Context: ctx,
Options: option.DialerOptions{
Detour: options.Detour,
},
RemoteIsDomain: true,
})
if err != nil {
return nil, E.Cause(err, "create dialer for credential ", tag)
}
transport := &http.Transport{
ForceAttemptHTTP2: true,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
if options.Server != "" {
serverPort := options.ServerPort
if serverPort == 0 {
portStr := parsedURL.Port()
if portStr != "" {
port, parseErr := strconv.ParseUint(portStr, 10, 16)
if parseErr == nil {
serverPort = uint16(port)
}
}
if serverPort == 0 {
if parsedURL.Scheme == "https" {
serverPort = 443
} else {
serverPort = 80
}
}
}
destination := M.ParseSocksaddrHostPort(options.Server, serverPort)
return credentialDialer.DialContext(ctx, network, destination)
}
return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
},
}
if parsedURL.Scheme == "https" {
transport.TLSClientConfig = &stdTLS.Config{
ServerName: parsedURL.Hostname(),
RootCAs: adapter.RootPoolFromContext(ctx),
Time: ntp.TimeFuncFromContext(ctx),
}
}
baseURL := parsedURL.Scheme + "://" + parsedURL.Host
if parsedURL.Path != "" && parsedURL.Path != "/" {
baseURL += parsedURL.Path
}
if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' {
baseURL = baseURL[:len(baseURL)-1]
}
pollInterval := time.Duration(options.PollInterval)
if pollInterval <= 0 {
pollInterval = 30 * time.Minute
}
requestContext, cancelRequests := context.WithCancel(context.Background())
cred := &externalCredential{
tag: tag,
baseURL: baseURL,
token: options.Token,
credDialer: credentialDialer,
httpClient: &http.Client{Transport: transport},
pollInterval: pollInterval,
logger: logger,
requestContext: requestContext,
cancelRequests: cancelRequests,
}
if options.UsagesPath != "" {
cred.usageTracker = &AggregatedUsage{
LastUpdated: time.Now(),
Combinations: make([]CostCombination, 0),
filePath: options.UsagesPath,
logger: logger,
}
}
return cred, nil
}
func (c *externalCredential) start() error {
if c.usageTracker != nil {
err := c.usageTracker.Load()
if err != nil {
c.logger.Warn("load usage statistics for ", c.tag, ": ", err)
}
}
return nil
}
func (c *externalCredential) setOnBecameUnusable(fn func()) {
c.onBecameUnusable = fn
}
func (c *externalCredential) tagName() string {
return c.tag
}
func (c *externalCredential) isExternal() bool {
return true
}
func (c *externalCredential) isUsable() bool {
c.stateMutex.RLock()
if c.state.hardRateLimited {
if time.Now().Before(c.state.rateLimitResetAt) {
c.stateMutex.RUnlock()
return false
}
c.stateMutex.RUnlock()
c.stateMutex.Lock()
if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) {
c.state.hardRateLimited = false
}
usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100
c.stateMutex.Unlock()
return usable
}
usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100
c.stateMutex.RUnlock()
return usable
}
func (c *externalCredential) fiveHourUtilization() float64 {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
return c.state.fiveHourUtilization
}
func (c *externalCredential) weeklyUtilization() float64 {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
return c.state.weeklyUtilization
}
func (c *externalCredential) markRateLimited(resetAt time.Time) {
c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt)))
c.stateMutex.Lock()
c.state.hardRateLimited = true
c.state.rateLimitResetAt = resetAt
shouldInterrupt := c.checkTransitionLocked()
c.stateMutex.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
}
func (c *externalCredential) earliestReset() time.Time {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
if c.state.hardRateLimited {
return c.state.rateLimitResetAt
}
earliest := c.state.fiveHourReset
if !c.state.weeklyReset.IsZero() && (earliest.IsZero() || c.state.weeklyReset.Before(earliest)) {
earliest = c.state.weeklyReset
}
return earliest
}
func (c *externalCredential) getAccessToken() (string, error) {
return c.token, nil
}
func (c *externalCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, _ http.Header) (*http.Request, error) {
proxyURL := c.baseURL + original.URL.RequestURI()
var body io.Reader
if bodyBytes != nil {
body = bytes.NewReader(bodyBytes)
} else {
body = original.Body
}
proxyRequest, err := http.NewRequestWithContext(ctx, original.Method, proxyURL, body)
if err != nil {
return nil, err
}
for key, values := range original.Header {
if !isHopByHopHeader(key) && key != "Authorization" {
proxyRequest.Header[key] = values
}
}
proxyRequest.Header.Set("Authorization", "Bearer "+c.token)
return proxyRequest, nil
}
func (c *externalCredential) updateStateFromHeaders(headers http.Header) {
c.stateMutex.Lock()
isFirstUpdate := c.state.lastUpdated.IsZero()
oldFiveHour := c.state.fiveHourUtilization
oldWeekly := c.state.weeklyUtilization
activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
if activeLimitIdentifier == "" {
activeLimitIdentifier = "codex"
}
fiveHourPercent := headers.Get("x-" + activeLimitIdentifier + "-primary-used-percent")
if fiveHourPercent != "" {
value, err := strconv.ParseFloat(fiveHourPercent, 64)
if err == nil {
c.state.fiveHourUtilization = value
}
}
fiveHourResetAt := headers.Get("x-" + activeLimitIdentifier + "-primary-reset-at")
if fiveHourResetAt != "" {
value, err := strconv.ParseInt(fiveHourResetAt, 10, 64)
if err == nil {
c.state.fiveHourReset = time.Unix(value, 0)
}
}
weeklyPercent := headers.Get("x-" + activeLimitIdentifier + "-secondary-used-percent")
if weeklyPercent != "" {
value, err := strconv.ParseFloat(weeklyPercent, 64)
if err == nil {
c.state.weeklyUtilization = value
}
}
weeklyResetAt := headers.Get("x-" + activeLimitIdentifier + "-secondary-reset-at")
if weeklyResetAt != "" {
value, err := strconv.ParseInt(weeklyResetAt, 10, 64)
if err == nil {
c.state.weeklyReset = time.Unix(value, 0)
}
}
c.state.lastUpdated = time.Now()
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%")
}
shouldInterrupt := c.checkTransitionLocked()
c.stateMutex.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
}
func (c *externalCredential) checkTransitionLocked() bool {
unusable := c.state.hardRateLimited || c.state.fiveHourUtilization >= 100 || c.state.weeklyUtilization >= 100
if unusable && !c.interrupted {
c.interrupted = true
return true
}
if !unusable && c.interrupted {
c.interrupted = false
}
return false
}
func (c *externalCredential) wrapRequestContext(parent context.Context) *credentialRequestContext {
c.requestAccess.Lock()
credentialContext := c.requestContext
c.requestAccess.Unlock()
derived, cancel := context.WithCancel(parent)
stop := context.AfterFunc(credentialContext, func() {
cancel()
})
return &credentialRequestContext{
Context: derived,
releaseFunc: stop,
cancelFunc: cancel,
}
}
func (c *externalCredential) interruptConnections() {
c.logger.Warn("interrupting connections for ", c.tag)
c.requestAccess.Lock()
c.cancelRequests()
c.requestContext, c.cancelRequests = context.WithCancel(context.Background())
c.requestAccess.Unlock()
if c.onBecameUnusable != nil {
c.onBecameUnusable()
}
}
func (c *externalCredential) pollUsage(ctx context.Context) {
if !c.pollAccess.TryLock() {
return
}
defer c.pollAccess.Unlock()
defer c.markUsagePollAttempted()
statusURL := c.baseURL + "/ocm/v1/status"
httpClient := &http.Client{
Transport: c.httpClient.Transport,
Timeout: 5 * time.Second,
}
request, err := http.NewRequestWithContext(ctx, http.MethodGet, statusURL, nil)
if err != nil {
c.logger.Error("poll usage for ", c.tag, ": create request: ", err)
return
}
request.Header.Set("Authorization", "Bearer "+c.token)
response, err := httpClient.Do(request)
if err != nil {
c.logger.Error("poll usage for ", c.tag, ": ", err)
c.stateMutex.Lock()
c.state.consecutivePollFailures++
c.stateMutex.Unlock()
return
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
body, _ := io.ReadAll(response.Body)
c.stateMutex.Lock()
c.state.consecutivePollFailures++
c.stateMutex.Unlock()
c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
return
}
var statusResponse struct {
FiveHourUtilization float64 `json:"five_hour_utilization"`
WeeklyUtilization float64 `json:"weekly_utilization"`
}
err = json.NewDecoder(response.Body).Decode(&statusResponse)
if err != nil {
c.stateMutex.Lock()
c.state.consecutivePollFailures++
c.stateMutex.Unlock()
c.logger.Debug("poll usage for ", c.tag, ": decode: ", err)
return
}
c.stateMutex.Lock()
isFirstUpdate := c.state.lastUpdated.IsZero()
oldFiveHour := c.state.fiveHourUtilization
oldWeekly := c.state.weeklyUtilization
c.state.consecutivePollFailures = 0
c.state.fiveHourUtilization = statusResponse.FiveHourUtilization
c.state.weeklyUtilization = statusResponse.WeeklyUtilization
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
c.state.hardRateLimited = false
}
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%")
}
shouldInterrupt := c.checkTransitionLocked()
c.stateMutex.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
}
func (c *externalCredential) lastUpdatedTime() time.Time {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
return c.state.lastUpdated
}
func (c *externalCredential) markUsagePollAttempted() {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()
c.state.lastUpdated = time.Now()
}
func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Duration {
c.stateMutex.RLock()
failures := c.state.consecutivePollFailures
c.stateMutex.RUnlock()
if failures <= 0 {
return baseInterval
}
if failures > 4 {
failures = 4
}
return baseInterval * time.Duration(1<<failures)
}
func (c *externalCredential) usageTrackerOrNil() *AggregatedUsage {
return c.usageTracker
}
func (c *externalCredential) httpTransport() *http.Client {
return c.httpClient
}
func (c *externalCredential) ocmDialer() N.Dialer {
return c.credDialer
}
func (c *externalCredential) ocmIsAPIKeyMode() bool {
return false
}
func (c *externalCredential) ocmGetAccountID() string {
return ""
}
func (c *externalCredential) ocmGetBaseURL() string {
return c.baseURL
}
func (c *externalCredential) close() {
if c.usageTracker != nil {
c.usageTracker.cancelPendingSave()
err := c.usageTracker.Save()
if err != nil {
c.logger.Error("save usage statistics for ", c.tag, ": ", err)
}
}
}

View File

@@ -82,6 +82,38 @@ func (c *credentialRequestContext) cancelRequest() {
c.cancelOnce.Do(c.cancelFunc)
}
type credential interface {
tagName() string
isUsable() bool
isExternal() bool
fiveHourUtilization() float64
weeklyUtilization() float64
markRateLimited(resetAt time.Time)
earliestReset() time.Time
getAccessToken() (string, error)
buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error)
updateStateFromHeaders(header http.Header)
wrapRequestContext(ctx context.Context) *credentialRequestContext
interruptConnections()
setOnBecameUnusable(fn func())
start() error
pollUsage(ctx context.Context)
lastUpdatedTime() time.Time
pollBackoff(base time.Duration) time.Duration
usageTrackerOrNil() *AggregatedUsage
httpTransport() *http.Client
close()
// OCM-specific
ocmDialer() N.Dialer
ocmIsAPIKeyMode() bool
ocmGetAccountID() string
ocmGetBaseURL() string
}
func newDefaultCredential(ctx context.Context, tag string, options option.OCMDefaultCredentialOptions, logger log.ContextLogger) (*defaultCredential, error) {
credentialDialer, err := dialer.NewWithOptions(dialer.Options{
Context: ctx,
@@ -523,38 +555,132 @@ func (c *defaultCredential) close() {
}
}
func (c *defaultCredential) setOnBecameUnusable(fn func()) {
c.onBecameUnusable = fn
}
func (c *defaultCredential) tagName() string {
return c.tag
}
func (c *defaultCredential) isExternal() bool {
return false
}
func (c *defaultCredential) fiveHourUtilization() float64 {
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()
return c.state.fiveHourUtilization
}
func (c *defaultCredential) usageTrackerOrNil() *AggregatedUsage {
return c.usageTracker
}
func (c *defaultCredential) httpTransport() *http.Client {
return c.httpClient
}
func (c *defaultCredential) ocmDialer() N.Dialer {
return c.dialer
}
func (c *defaultCredential) ocmIsAPIKeyMode() bool {
return c.isAPIKeyMode()
}
func (c *defaultCredential) ocmGetAccountID() string {
return c.getAccountID()
}
func (c *defaultCredential) ocmGetBaseURL() string {
return c.getBaseURL()
}
func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) {
accessToken, err := c.getAccessToken()
if err != nil {
return nil, E.Cause(err, "get access token for ", c.tag)
}
path := original.URL.Path
var proxyPath string
if c.isAPIKeyMode() {
proxyPath = path
} else {
proxyPath = strings.TrimPrefix(path, "/v1")
}
proxyURL := c.getBaseURL() + proxyPath
if original.URL.RawQuery != "" {
proxyURL += "?" + original.URL.RawQuery
}
var body io.Reader
if bodyBytes != nil {
body = bytes.NewReader(bodyBytes)
} else {
body = original.Body
}
proxyRequest, err := http.NewRequestWithContext(ctx, original.Method, proxyURL, body)
if err != nil {
return nil, err
}
for key, values := range original.Header {
if !isHopByHopHeader(key) && key != "Authorization" {
proxyRequest.Header[key] = values
}
}
for key, values := range serviceHeaders {
proxyRequest.Header.Del(key)
proxyRequest.Header[key] = values
}
proxyRequest.Header.Set("Authorization", "Bearer "+accessToken)
if accountID := c.getAccountID(); accountID != "" {
proxyRequest.Header.Set("ChatGPT-Account-Id", accountID)
}
return proxyRequest, nil
}
type credentialProvider interface {
selectCredential(sessionID string) (*defaultCredential, bool, error)
onRateLimited(sessionID string, credential *defaultCredential, resetAt time.Time) *defaultCredential
selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error)
onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential
pollIfStale(ctx context.Context)
allDefaults() []*defaultCredential
allCredentials() []credential
close()
}
type singleCredentialProvider struct {
credential *defaultCredential
cred credential
}
func (p *singleCredentialProvider) selectCredential(_ string) (*defaultCredential, bool, error) {
if !p.credential.isUsable() {
return nil, false, E.New("credential ", p.credential.tag, " is rate-limited")
func (p *singleCredentialProvider) selectCredential(_ string, filter func(credential) bool) (credential, bool, error) {
if filter != nil && !filter(p.cred) {
return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out")
}
return p.credential, false, nil
if !p.cred.isUsable() {
return nil, false, E.New("credential ", p.cred.tagName(), " is rate-limited")
}
return p.cred, false, nil
}
func (p *singleCredentialProvider) onRateLimited(_ string, credential *defaultCredential, resetAt time.Time) *defaultCredential {
credential.markRateLimited(resetAt)
func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ func(credential) bool) credential {
cred.markRateLimited(resetAt)
return nil
}
func (p *singleCredentialProvider) pollIfStale(ctx context.Context) {
if time.Since(p.credential.lastUpdatedTime()) > p.credential.pollBackoff(defaultPollInterval) {
p.credential.pollUsage(ctx)
if time.Since(p.cred.lastUpdatedTime()) > p.cred.pollBackoff(defaultPollInterval) {
p.cred.pollUsage(ctx)
}
}
func (p *singleCredentialProvider) allDefaults() []*defaultCredential {
return []*defaultCredential{p.credential}
func (p *singleCredentialProvider) allCredentials() []credential {
return []credential{p.cred}
}
func (p *singleCredentialProvider) close() {}
@@ -567,7 +693,7 @@ type sessionEntry struct {
}
type balancerProvider struct {
credentials []*defaultCredential
credentials []credential
strategy string
roundRobinIndex atomic.Uint64
pollInterval time.Duration
@@ -576,7 +702,7 @@ type balancerProvider struct {
logger log.ContextLogger
}
func newBalancerProvider(credentials []*defaultCredential, strategy string, pollInterval time.Duration, logger log.ContextLogger) *balancerProvider {
func newBalancerProvider(credentials []credential, strategy string, pollInterval time.Duration, logger log.ContextLogger) *balancerProvider {
if pollInterval <= 0 {
pollInterval = defaultPollInterval
}
@@ -589,15 +715,15 @@ func newBalancerProvider(credentials []*defaultCredential, strategy string, poll
}
}
func (p *balancerProvider) selectCredential(sessionID string) (*defaultCredential, bool, error) {
func (p *balancerProvider) selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) {
if sessionID != "" {
p.sessionMutex.RLock()
entry, exists := p.sessions[sessionID]
p.sessionMutex.RUnlock()
if exists {
for _, credential := range p.credentials {
if credential.tag == entry.tag && credential.isUsable() {
return credential, false, nil
for _, cred := range p.credentials {
if cred.tagName() == entry.tag && (filter == nil || filter(cred)) && cred.isUsable() {
return cred, false, nil
}
}
p.sessionMutex.Lock()
@@ -606,7 +732,7 @@ func (p *balancerProvider) selectCredential(sessionID string) (*defaultCredentia
}
}
best := p.pickCredential()
best := p.pickCredential(filter)
if best == nil {
return nil, false, allRateLimitedError(p.credentials)
}
@@ -614,61 +740,67 @@ func (p *balancerProvider) selectCredential(sessionID string) (*defaultCredentia
isNew := sessionID != ""
if isNew {
p.sessionMutex.Lock()
p.sessions[sessionID] = sessionEntry{tag: best.tag, createdAt: time.Now()}
p.sessions[sessionID] = sessionEntry{tag: best.tagName(), createdAt: time.Now()}
p.sessionMutex.Unlock()
}
return best, isNew, nil
}
func (p *balancerProvider) onRateLimited(sessionID string, credential *defaultCredential, resetAt time.Time) *defaultCredential {
credential.markRateLimited(resetAt)
func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential {
cred.markRateLimited(resetAt)
if sessionID != "" {
p.sessionMutex.Lock()
delete(p.sessions, sessionID)
p.sessionMutex.Unlock()
}
best := p.pickCredential()
best := p.pickCredential(filter)
if best != nil && sessionID != "" {
p.sessionMutex.Lock()
p.sessions[sessionID] = sessionEntry{tag: best.tag, createdAt: time.Now()}
p.sessions[sessionID] = sessionEntry{tag: best.tagName(), createdAt: time.Now()}
p.sessionMutex.Unlock()
}
return best
}
func (p *balancerProvider) pickCredential() *defaultCredential {
func (p *balancerProvider) pickCredential(filter func(credential) bool) credential {
switch p.strategy {
case "round_robin":
return p.pickRoundRobin()
return p.pickRoundRobin(filter)
case "random":
return p.pickRandom()
return p.pickRandom(filter)
default:
return p.pickLeastUsed()
return p.pickLeastUsed(filter)
}
}
func (p *balancerProvider) pickLeastUsed() *defaultCredential {
var best *defaultCredential
func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credential {
var best credential
bestUtilization := float64(101)
for _, credential := range p.credentials {
if !credential.isUsable() {
for _, cred := range p.credentials {
if filter != nil && !filter(cred) {
continue
}
utilization := credential.weeklyUtilization()
if !cred.isUsable() {
continue
}
utilization := cred.weeklyUtilization()
if utilization < bestUtilization {
bestUtilization = utilization
best = credential
best = cred
}
}
return best
}
func (p *balancerProvider) pickRoundRobin() *defaultCredential {
func (p *balancerProvider) pickRoundRobin(filter func(credential) bool) credential {
start := int(p.roundRobinIndex.Add(1) - 1)
count := len(p.credentials)
for offset := range count {
candidate := p.credentials[(start+offset)%count]
if filter != nil && !filter(candidate) {
continue
}
if candidate.isUsable() {
return candidate
}
@@ -676,9 +808,12 @@ func (p *balancerProvider) pickRoundRobin() *defaultCredential {
return nil
}
func (p *balancerProvider) pickRandom() *defaultCredential {
var usable []*defaultCredential
func (p *balancerProvider) pickRandom(filter func(credential) bool) credential {
var usable []credential
for _, candidate := range p.credentials {
if filter != nil && !filter(candidate) {
continue
}
if candidate.isUsable() {
usable = append(usable, candidate)
}
@@ -699,26 +834,26 @@ func (p *balancerProvider) pollIfStale(ctx context.Context) {
}
p.sessionMutex.Unlock()
for _, credential := range p.credentials {
if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) {
credential.pollUsage(ctx)
for _, cred := range p.credentials {
if time.Since(cred.lastUpdatedTime()) > cred.pollBackoff(p.pollInterval) {
cred.pollUsage(ctx)
}
}
}
func (p *balancerProvider) allDefaults() []*defaultCredential {
func (p *balancerProvider) allCredentials() []credential {
return p.credentials
}
func (p *balancerProvider) close() {}
type fallbackProvider struct {
credentials []*defaultCredential
credentials []credential
pollInterval time.Duration
logger log.ContextLogger
}
func newFallbackProvider(credentials []*defaultCredential, pollInterval time.Duration, logger log.ContextLogger) *fallbackProvider {
func newFallbackProvider(credentials []credential, pollInterval time.Duration, logger log.ContextLogger) *fallbackProvider {
if pollInterval <= 0 {
pollInterval = defaultPollInterval
}
@@ -729,18 +864,24 @@ func newFallbackProvider(credentials []*defaultCredential, pollInterval time.Dur
}
}
func (p *fallbackProvider) selectCredential(_ string) (*defaultCredential, bool, error) {
for _, credential := range p.credentials {
if credential.isUsable() {
return credential, false, nil
func (p *fallbackProvider) selectCredential(_ string, filter func(credential) bool) (credential, bool, error) {
for _, cred := range p.credentials {
if filter != nil && !filter(cred) {
continue
}
if cred.isUsable() {
return cred, false, nil
}
}
return nil, false, allRateLimitedError(p.credentials)
}
func (p *fallbackProvider) onRateLimited(_ string, credential *defaultCredential, resetAt time.Time) *defaultCredential {
credential.markRateLimited(resetAt)
func (p *fallbackProvider) onRateLimited(_ string, cred credential, resetAt time.Time, filter func(credential) bool) credential {
cred.markRateLimited(resetAt)
for _, candidate := range p.credentials {
if filter != nil && !filter(candidate) {
continue
}
if candidate.isUsable() {
return candidate
}
@@ -749,23 +890,23 @@ func (p *fallbackProvider) onRateLimited(_ string, credential *defaultCredential
}
func (p *fallbackProvider) pollIfStale(ctx context.Context) {
for _, credential := range p.credentials {
if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) {
credential.pollUsage(ctx)
for _, cred := range p.credentials {
if time.Since(cred.lastUpdatedTime()) > cred.pollBackoff(p.pollInterval) {
cred.pollUsage(ctx)
}
}
}
func (p *fallbackProvider) allDefaults() []*defaultCredential {
func (p *fallbackProvider) allCredentials() []credential {
return p.credentials
}
func (p *fallbackProvider) close() {}
func allRateLimitedError(credentials []*defaultCredential) error {
func allRateLimitedError(credentials []credential) error {
var earliest time.Time
for _, credential := range credentials {
resetAt := credential.earliestReset()
for _, cred := range credentials {
resetAt := cred.earliestReset()
if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) {
earliest = resetAt
}
@@ -780,34 +921,44 @@ func buildOCMCredentialProviders(
ctx context.Context,
options option.OCMServiceOptions,
logger log.ContextLogger,
) (map[string]credentialProvider, []*defaultCredential, error) {
defaultCredentials := make(map[string]*defaultCredential)
var allDefaults []*defaultCredential
) (map[string]credentialProvider, []credential, error) {
allCredentialMap := make(map[string]credential)
var allCreds []credential
providers := make(map[string]credentialProvider)
// Pass 1: create default and external credentials
for _, credOpt := range options.Credentials {
switch credOpt.Type {
case "default":
credential, err := newDefaultCredential(ctx, credOpt.Tag, credOpt.DefaultOptions, logger)
cred, err := newDefaultCredential(ctx, credOpt.Tag, credOpt.DefaultOptions, logger)
if err != nil {
return nil, nil, err
}
defaultCredentials[credOpt.Tag] = credential
allDefaults = append(allDefaults, credential)
providers[credOpt.Tag] = &singleCredentialProvider{credential: credential}
allCredentialMap[credOpt.Tag] = cred
allCreds = append(allCreds, cred)
providers[credOpt.Tag] = &singleCredentialProvider{cred: cred}
case "external":
cred, err := newExternalCredential(ctx, credOpt.Tag, credOpt.ExternalOptions, logger)
if err != nil {
return nil, nil, err
}
allCredentialMap[credOpt.Tag] = cred
allCreds = append(allCreds, cred)
providers[credOpt.Tag] = &singleCredentialProvider{cred: cred}
}
}
// Pass 2: create balancer and fallback providers
for _, credOpt := range options.Credentials {
switch credOpt.Type {
case "balancer":
subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, defaultCredentials, credOpt.Tag)
subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, allCredentialMap, credOpt.Tag)
if err != nil {
return nil, nil, err
}
providers[credOpt.Tag] = newBalancerProvider(subCredentials, credOpt.BalancerOptions.Strategy, time.Duration(credOpt.BalancerOptions.PollInterval), logger)
case "fallback":
subCredentials, err := resolveCredentialTags(credOpt.FallbackOptions.Credentials, defaultCredentials, credOpt.Tag)
subCredentials, err := resolveCredentialTags(credOpt.FallbackOptions.Credentials, allCredentialMap, credOpt.Tag)
if err != nil {
return nil, nil, err
}
@@ -815,17 +966,17 @@ func buildOCMCredentialProviders(
}
}
return providers, allDefaults, nil
return providers, allCreds, nil
}
func resolveCredentialTags(tags []string, defaults map[string]*defaultCredential, parentTag string) ([]*defaultCredential, error) {
credentials := make([]*defaultCredential, 0, len(tags))
func resolveCredentialTags(tags []string, allCredentials map[string]credential, parentTag string) ([]credential, error) {
credentials := make([]credential, 0, len(tags))
for _, tag := range tags {
credential, exists := defaults[tag]
cred, exists := allCredentials[tag]
if !exists {
return nil, E.New("credential ", parentTag, " references unknown default credential: ", tag)
return nil, E.New("credential ", parentTag, " references unknown credential: ", tag)
}
credentials = append(credentials, credential)
credentials = append(credentials, cred)
}
if len(credentials) == 0 {
return nil, E.New("credential ", parentTag, " has no sub-credentials")
@@ -871,24 +1022,34 @@ func validateOCMOptions(options option.OCMServiceOptions) error {
if hasCredentials {
tags := make(map[string]bool)
for _, credential := range options.Credentials {
if tags[credential.Tag] {
return E.New("duplicate credential tag: ", credential.Tag)
credentialTypes := make(map[string]string)
for _, cred := range options.Credentials {
if tags[cred.Tag] {
return E.New("duplicate credential tag: ", cred.Tag)
}
tags[credential.Tag] = true
if credential.Type == "default" || credential.Type == "" {
if credential.DefaultOptions.Reserve5h > 99 {
return E.New("credential ", credential.Tag, ": reserve_5h must be at most 99")
tags[cred.Tag] = true
credentialTypes[cred.Tag] = cred.Type
if cred.Type == "default" || cred.Type == "" {
if cred.DefaultOptions.Reserve5h > 99 {
return E.New("credential ", cred.Tag, ": reserve_5h must be at most 99")
}
if credential.DefaultOptions.ReserveWeekly > 99 {
return E.New("credential ", credential.Tag, ": reserve_weekly must be at most 99")
if cred.DefaultOptions.ReserveWeekly > 99 {
return E.New("credential ", cred.Tag, ": reserve_weekly must be at most 99")
}
}
if credential.Type == "balancer" {
switch credential.BalancerOptions.Strategy {
if cred.Type == "external" {
if cred.ExternalOptions.URL == "" {
return E.New("credential ", cred.Tag, ": external credential requires url")
}
if cred.ExternalOptions.Token == "" {
return E.New("credential ", cred.Tag, ": external credential requires token")
}
}
if cred.Type == "balancer" {
switch cred.BalancerOptions.Strategy {
case "", "least_used", "round_robin", "random":
default:
return E.New("credential ", credential.Tag, ": unknown balancer strategy: ", credential.BalancerOptions.Strategy)
return E.New("credential ", cred.Tag, ": unknown balancer strategy: ", cred.BalancerOptions.Strategy)
}
}
}
@@ -900,6 +1061,14 @@ func validateOCMOptions(options option.OCMServiceOptions) error {
if !tags[user.Credential] {
return E.New("user ", user.Name, " references unknown credential: ", user.Credential)
}
if user.ExternalCredential != "" {
if !tags[user.ExternalCredential] {
return E.New("user ", user.Name, " references unknown external_credential: ", user.ExternalCredential)
}
if credentialTypes[user.ExternalCredential] != "external" {
return E.New("user ", user.Name, ": external_credential must reference an external type credential")
}
}
}
}
@@ -910,21 +1079,21 @@ func validateOCMCompositeCredentialModes(
options option.OCMServiceOptions,
providers map[string]credentialProvider,
) error {
for _, credential := range options.Credentials {
if credential.Type != "balancer" && credential.Type != "fallback" {
for _, credOpt := range options.Credentials {
if credOpt.Type != "balancer" && credOpt.Type != "fallback" {
continue
}
provider, exists := providers[credential.Tag]
provider, exists := providers[credOpt.Tag]
if !exists {
return E.New("unknown credential: ", credential.Tag)
return E.New("unknown credential: ", credOpt.Tag)
}
for _, subCredential := range provider.allDefaults() {
if subCredential.isAPIKeyMode() {
for _, subCred := range provider.allCredentials() {
if subCred.ocmIsAPIKeyMode() {
return E.New(
"credential ", credential.Tag,
" references API key default credential ", subCredential.tag,
"credential ", credOpt.Tag,
" references API key default credential ", subCred.tagName(),
"; balancer and fallback only support OAuth default credentials",
)
}
@@ -934,60 +1103,8 @@ func validateOCMCompositeCredentialModes(
return nil
}
func retryOCMRequestWithBody(
ctx context.Context,
originalRequest *http.Request,
bodyBytes []byte,
credential *defaultCredential,
httpHeaders http.Header,
) (*http.Response, error) {
accessToken, err := credential.getAccessToken()
if err != nil {
return nil, E.Cause(err, "get access token for ", credential.tag)
}
baseURL := credential.getBaseURL()
path := originalRequest.URL.Path
var proxyPath string
if credential.isAPIKeyMode() {
proxyPath = path
} else {
proxyPath = strings.TrimPrefix(path, "/v1")
}
proxyURL := baseURL + proxyPath
if originalRequest.URL.RawQuery != "" {
proxyURL += "?" + originalRequest.URL.RawQuery
}
var body io.Reader
if bodyBytes != nil {
body = bytes.NewReader(bodyBytes)
}
retryRequest, err := http.NewRequestWithContext(ctx, originalRequest.Method, proxyURL, body)
if err != nil {
return nil, err
}
for key, values := range originalRequest.Header {
if !isHopByHopHeader(key) && key != "Authorization" {
retryRequest.Header[key] = values
}
}
for key, values := range httpHeaders {
retryRequest.Header.Del(key)
retryRequest.Header[key] = values
}
retryRequest.Header.Set("Authorization", "Bearer "+accessToken)
if accountID := credential.getAccountID(); accountID != "" {
retryRequest.Header.Set("ChatGPT-Account-Id", accountID)
}
return credential.httpClient.Do(retryRequest)
}
func credentialForUser(
userCredentialMap map[string]string,
userConfigMap map[string]*option.OCMUser,
providers map[string]credentialProvider,
legacyProvider credentialProvider,
username string,
@@ -995,13 +1112,13 @@ func credentialForUser(
if legacyProvider != nil {
return legacyProvider, nil
}
tag, exists := userCredentialMap[username]
userConfig, exists := userConfigMap[username]
if !exists {
return nil, E.New("no credential mapping for user: ", username)
}
provider, exists := providers[tag]
provider, exists := providers[userConfig.Credential]
if !exists {
return nil, E.New("unknown credential: ", tag)
return nil, E.New("unknown credential: ", userConfig.Credential)
}
return provider, nil
}

View File

@@ -74,15 +74,18 @@ const (
retryableUsageCode = "credential_usage_exhausted"
)
func hasAlternativeCredential(provider credentialProvider, currentCredential *defaultCredential) bool {
func hasAlternativeCredential(provider credentialProvider, currentCredential credential, filter func(credential) bool) bool {
if provider == nil || currentCredential == nil {
return false
}
for _, credential := range provider.allDefaults() {
if credential == currentCredential {
for _, cred := range provider.allCredentials() {
if cred == currentCredential {
continue
}
if credential.isUsable() {
if filter != nil && !filter(cred) {
continue
}
if cred.isUsable() {
return true
}
}
@@ -93,7 +96,7 @@ func unavailableCredentialMessage(provider credentialProvider, fallback string)
if provider == nil {
return fallback
}
return allRateLimitedError(provider.allDefaults()).Error()
return allRateLimitedError(provider.allCredentials()).Error()
}
func writeRetryableUsageError(w http.ResponseWriter, r *http.Request) {
@@ -108,10 +111,11 @@ func writeCredentialUnavailableError(
w http.ResponseWriter,
r *http.Request,
provider credentialProvider,
currentCredential *defaultCredential,
currentCredential credential,
filter func(credential) bool,
fallback string,
) {
if hasAlternativeCredential(provider, currentCredential) {
if hasAlternativeCredential(provider, currentCredential, filter) {
writeRetryableUsageError(w, r)
return
}
@@ -198,9 +202,9 @@ type Service struct {
legacyProvider credentialProvider
// Multi-credential mode
providers map[string]credentialProvider
allDefaults []*defaultCredential
userCredentialMap map[string]string
providers map[string]credentialProvider
allCredentials []credential
userConfigMap map[string]*option.OCMUser
}
func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.OCMServiceOptions) (adapter.Service, error) {
@@ -230,20 +234,20 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
}
if len(options.Credentials) > 0 {
providers, allDefaults, err := buildOCMCredentialProviders(ctx, options, logger)
providers, allCredentials, err := buildOCMCredentialProviders(ctx, options, logger)
if err != nil {
return nil, E.Cause(err, "build credential providers")
}
service.providers = providers
service.allDefaults = allDefaults
service.allCredentials = allCredentials
userCredentialMap := make(map[string]string)
for _, user := range options.Users {
userCredentialMap[user.Name] = user.Credential
userConfigMap := make(map[string]*option.OCMUser)
for i := range options.Users {
userConfigMap[options.Users[i].Name] = &options.Users[i]
}
service.userCredentialMap = userCredentialMap
service.userConfigMap = userConfigMap
} else {
credential, err := newDefaultCredential(ctx, "default", option.OCMDefaultCredentialOptions{
cred, err := newDefaultCredential(ctx, "default", option.OCMDefaultCredentialOptions{
CredentialPath: options.CredentialPath,
UsagesPath: options.UsagesPath,
Detour: options.Detour,
@@ -251,9 +255,9 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
if err != nil {
return nil, err
}
service.legacyCredential = credential
service.legacyProvider = &singleCredentialProvider{credential: credential}
service.allDefaults = []*defaultCredential{credential}
service.legacyCredential = cred
service.legacyProvider = &singleCredentialProvider{cred: cred}
service.allCredentials = []credential{cred}
}
if options.TLS != nil {
@@ -274,15 +278,15 @@ func (s *Service) Start(stage adapter.StartStage) error {
s.userManager.UpdateUsers(s.options.Users)
for _, credential := range s.allDefaults {
err := credential.start()
for _, cred := range s.allCredentials {
err := cred.start()
if err != nil {
return err
}
tag := credential.tag
credential.onBecameUnusable = func() {
tag := cred.tagName()
cred.setOnBecameUnusable(func() {
s.interruptWebSocketSessionsForCredential(tag)
}
})
}
if len(s.options.Credentials) > 0 {
err := validateOCMCompositeCredentialModes(s.options, s.providers)
@@ -327,7 +331,7 @@ func (s *Service) Start(stage adapter.StartStage) error {
func (s *Service) resolveCredentialProvider(username string) (credentialProvider, error) {
if len(s.options.Users) > 0 {
return credentialForUser(s.userCredentialMap, s.providers, s.legacyProvider, username)
return credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
}
provider := noUserCredentialProvider(s.providers, s.legacyProvider, s.options)
if provider == nil {
@@ -337,6 +341,11 @@ func (s *Service) resolveCredentialProvider(username string) (credentialProvider
}
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/ocm/v1/status" {
s.handleStatusEndpoint(w, r)
return
}
path := r.URL.Path
if !strings.HasPrefix(path, "/v1/") {
writeJSONError(w, r, http.StatusNotFound, "invalid_request_error", "path must start with /v1/")
@@ -368,49 +377,64 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
sessionID := r.Header.Get("session_id")
// Resolve credential provider
provider, err := s.resolveCredentialProvider(username)
if err != nil {
s.logger.Error("resolve credential: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
// Resolve credential provider and user config
var provider credentialProvider
var userConfig *option.OCMUser
if len(s.options.Users) > 0 {
userConfig = s.userConfigMap[username]
var err error
provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
if err != nil {
s.logger.Error("resolve credential: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
return
}
} else {
provider = noUserCredentialProvider(s.providers, s.legacyProvider, s.options)
}
if provider == nil {
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available")
return
}
provider.pollIfStale(s.ctx)
credential, isNew, err := provider.selectCredential(sessionID)
var credentialFilter func(credential) bool
if userConfig != nil && !userConfig.AllowExternalUsage {
credentialFilter = func(c credential) bool { return !c.isExternal() }
}
selectedCredential, isNew, err := provider.selectCredential(sessionID, credentialFilter)
if err != nil {
writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, err.Error()))
return
}
if isNew {
if username != "" {
s.logger.Debug("assigned credential ", credential.tag, " for session ", sessionID, " by user ", username)
s.logger.Debug("assigned credential ", selectedCredential.tagName(), " for session ", sessionID, " by user ", username)
} else {
s.logger.Debug("assigned credential ", credential.tag, " for session ", sessionID)
s.logger.Debug("assigned credential ", selectedCredential.tagName(), " for session ", sessionID)
}
}
if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") {
s.handleWebSocket(w, r, path, username, sessionID, provider, credential)
s.handleWebSocket(w, r, path, username, sessionID, userConfig, provider, selectedCredential, credentialFilter)
return
}
var proxyPath string
if credential.isAPIKeyMode() {
proxyPath = path
} else {
if !selectedCredential.isExternal() && selectedCredential.ocmIsAPIKeyMode() {
// API key mode path handling
} else if !selectedCredential.isExternal() {
if path == "/v1/chat/completions" {
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
"chat completions endpoint is only available in API key mode")
return
}
proxyPath = strings.TrimPrefix(path, "/v1")
}
shouldTrackUsage := credential.usageTracker != nil &&
shouldTrackUsage := selectedCredential.usageTrackerOrNil() != nil &&
(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses"))
canRetryRequest := len(provider.allDefaults()) > 1
canRetryRequest := len(provider.allCredentials()) > 1
// Read body for model extraction and retry buffer when JSON replay is useful.
var bodyBytes []byte
@@ -435,52 +459,24 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
accessToken, err := credential.getAccessToken()
if err != nil {
s.logger.Error("get access token: ", err)
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "Authentication failed")
return
}
proxyURL := credential.getBaseURL() + proxyPath
if r.URL.RawQuery != "" {
proxyURL += "?" + r.URL.RawQuery
}
requestContext := credential.wrapRequestContext(r.Context())
requestContext := selectedCredential.wrapRequestContext(r.Context())
defer func() {
requestContext.cancelRequest()
}()
proxyRequest, err := http.NewRequestWithContext(requestContext, r.Method, proxyURL, r.Body)
proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if err != nil {
s.logger.Error("create proxy request: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
return
}
for key, values := range r.Header {
if !isHopByHopHeader(key) && key != "Authorization" {
proxyRequest.Header[key] = values
}
}
for key, values := range s.httpHeaders {
proxyRequest.Header.Del(key)
proxyRequest.Header[key] = values
}
proxyRequest.Header.Set("Authorization", "Bearer "+accessToken)
if accountID := credential.getAccountID(); accountID != "" {
proxyRequest.Header.Set("ChatGPT-Account-Id", accountID)
}
response, err := credential.httpClient.Do(proxyRequest)
response, err := selectedCredential.httpTransport().Do(proxyRequest)
if err != nil {
if r.Context().Err() != nil {
return
}
if requestContext.Err() != nil {
writeCredentialUnavailableError(w, r, provider, credential, "credential became unavailable while processing the request")
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "credential became unavailable while processing the request")
return
}
writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
@@ -491,25 +487,31 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Transparent 429 retry
for response.StatusCode == http.StatusTooManyRequests {
resetAt := parseOCMRateLimitResetFromHeaders(response.Header)
nextCredential := provider.onRateLimited(sessionID, credential, resetAt)
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, credentialFilter)
needsBodyReplay := r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodDelete
credential.updateStateFromHeaders(response.Header)
selectedCredential.updateStateFromHeaders(response.Header)
if (needsBodyReplay && bodyBytes == nil) || nextCredential == nil {
response.Body.Close()
writeCredentialUnavailableError(w, r, provider, credential, "all credentials rate-limited")
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited")
return
}
response.Body.Close()
s.logger.Info("retrying with credential ", nextCredential.tag, " after 429 from ", credential.tag)
s.logger.Info("retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
requestContext.cancelRequest()
requestContext = nextCredential.wrapRequestContext(r.Context())
retryResponse, retryErr := retryOCMRequestWithBody(requestContext, r, bodyBytes, nextCredential, s.httpHeaders)
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
if buildErr != nil {
s.logger.Error("retry request: ", buildErr)
writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
return
}
retryResponse, retryErr := nextCredential.httpTransport().Do(retryRequest)
if retryErr != nil {
if r.Context().Err() != nil {
return
}
if requestContext.Err() != nil {
writeCredentialUnavailableError(w, r, provider, nextCredential, "credential became unavailable while retrying the request")
writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request")
return
}
s.logger.Error("retry request: ", retryErr)
@@ -518,20 +520,25 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
requestContext.releaseCredentialInterrupt()
response = retryResponse
credential = nextCredential
selectedCredential = nextCredential
}
defer response.Body.Close()
credential.updateStateFromHeaders(response.Header)
selectedCredential.updateStateFromHeaders(response.Header)
if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
body, _ := io.ReadAll(response.Body)
s.logger.Error("upstream error from ", credential.tag, ": status ", response.StatusCode, " ", string(body))
s.logger.Error("upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
writeJSONError(w, r, http.StatusInternalServerError, "api_error",
"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
return
}
// Rewrite response headers for external users
if userConfig != nil && userConfig.ExternalCredential != "" {
s.rewriteResponseHeadersForExternalUser(response.Header, userConfig)
}
for key, values := range response.Header {
if !isHopByHopHeader(key) {
w.Header()[key] = values
@@ -539,10 +546,10 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
w.WriteHeader(response.StatusCode)
hasUsageTracker := credential.usageTracker != nil
if hasUsageTracker && response.StatusCode == http.StatusOK &&
usageTracker := selectedCredential.usageTrackerOrNil()
if usageTracker != nil && response.StatusCode == http.StatusOK &&
(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) {
s.handleResponseWithTracking(w, response, credential.usageTracker, path, requestModel, username)
s.handleResponseWithTracking(w, response, usageTracker, path, requestModel, username)
} else {
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
if err == nil && mediaType != "text/event-stream" {
@@ -745,6 +752,93 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
}
}
func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed")
return
}
if len(s.options.Users) == 0 {
writeJSONError(w, r, http.StatusForbidden, "authentication_error", "status endpoint requires user authentication")
return
}
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
return
}
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
if clientToken == authHeader {
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
return
}
username, ok := s.userManager.Authenticate(clientToken)
if !ok {
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
return
}
userConfig := s.userConfigMap[username]
if userConfig == nil {
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "user config not found")
return
}
provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
if err != nil {
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
return
}
provider.pollIfStale(r.Context())
avgFiveHour, avgWeekly := s.computeAggregatedUtilization(provider, userConfig)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]float64{
"five_hour_utilization": avgFiveHour,
"weekly_utilization": avgWeekly,
})
}
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) (float64, float64) {
var totalFiveHour, totalWeekly float64
var count int
for _, cred := range provider.allCredentials() {
if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential {
continue
}
if !userConfig.AllowExternalUsage && cred.isExternal() {
continue
}
totalFiveHour += cred.fiveHourUtilization()
totalWeekly += cred.weeklyUtilization()
count++
}
if count == 0 {
return 100, 100
}
return totalFiveHour / float64(count), totalWeekly / float64(count)
}
func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.OCMUser) {
provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, userConfig.Name)
if err != nil {
return
}
avgFiveHour, avgWeekly := s.computeAggregatedUtilization(provider, userConfig)
activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
if activeLimitIdentifier == "" {
activeLimitIdentifier = "codex"
}
headers.Set("x-"+activeLimitIdentifier+"-primary-used-percent", strconv.FormatFloat(avgFiveHour, 'f', 2, 64))
headers.Set("x-"+activeLimitIdentifier+"-secondary-used-percent", strconv.FormatFloat(avgWeekly, 'f', 2, 64))
}
func (s *Service) Close() error {
webSocketSessions := s.startWebSocketShutdown()
@@ -758,8 +852,8 @@ func (s *Service) Close() error {
}
s.webSocketGroup.Wait()
for _, credential := range s.allDefaults {
credential.close()
for _, cred := range s.allCredentials {
cred.close()
}
return err

View File

@@ -14,6 +14,7 @@ import (
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/ntp"
@@ -85,8 +86,10 @@ func (s *Service) handleWebSocket(
path string,
username string,
sessionID string,
userConfig *option.OCMUser,
provider credentialProvider,
credential *defaultCredential,
selectedCredential credential,
credentialFilter func(credential) bool,
) {
var (
err error
@@ -97,7 +100,7 @@ func (s *Service) handleWebSocket(
)
for {
accessToken, accessErr := credential.getAccessToken()
accessToken, accessErr := selectedCredential.getAccessToken()
if accessErr != nil {
s.logger.Error("get access token for websocket: ", accessErr)
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "authentication failed")
@@ -105,13 +108,13 @@ func (s *Service) handleWebSocket(
}
var proxyPath string
if credential.isAPIKeyMode() {
if selectedCredential.ocmIsAPIKeyMode() || selectedCredential.isExternal() {
proxyPath = path
} else {
proxyPath = strings.TrimPrefix(path, "/v1")
}
upstreamURL := buildUpstreamWebSocketURL(credential.getBaseURL(), proxyPath)
upstreamURL := buildUpstreamWebSocketURL(selectedCredential.ocmGetBaseURL(), proxyPath)
if r.URL.RawQuery != "" {
upstreamURL += "?" + r.URL.RawQuery
}
@@ -127,7 +130,7 @@ func (s *Service) handleWebSocket(
upstreamHeaders[key] = values
}
upstreamHeaders.Set("Authorization", "Bearer "+accessToken)
if accountID := credential.getAccountID(); accountID != "" {
if accountID := selectedCredential.ocmGetAccountID(); accountID != "" {
upstreamHeaders.Set("ChatGPT-Account-Id", accountID)
}
@@ -135,7 +138,7 @@ func (s *Service) handleWebSocket(
statusCode = 0
upstreamDialer := ws.Dialer{
NetDial: func(ctx context.Context, network, addr string) (net.Conn, error) {
return credential.dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
return selectedCredential.ocmDialer().DialContext(ctx, network, M.ParseSocksaddr(addr))
},
TLSConfig: &stdTLS.Config{
RootCAs: adapter.RootPoolFromContext(s.ctx),
@@ -170,14 +173,14 @@ func (s *Service) handleWebSocket(
}
if statusCode == http.StatusTooManyRequests {
resetAt := parseOCMRateLimitResetFromHeaders(upstreamResponseHeaders)
nextCredential := provider.onRateLimited(sessionID, credential, resetAt)
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, credentialFilter)
if nextCredential == nil {
credential.updateStateFromHeaders(upstreamResponseHeaders)
writeCredentialUnavailableError(w, r, provider, credential, "all credentials rate-limited")
selectedCredential.updateStateFromHeaders(upstreamResponseHeaders)
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited")
return
}
s.logger.Info("retrying websocket with credential ", nextCredential.tag, " after 429 from ", credential.tag)
credential = nextCredential
s.logger.Info("retrying websocket with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
selectedCredential = nextCredential
continue
}
s.logger.Error("dial upstream websocket: ", err)
@@ -185,15 +188,18 @@ func (s *Service) handleWebSocket(
return
}
credential.updateStateFromHeaders(upstreamResponseHeaders)
selectedCredential.updateStateFromHeaders(upstreamResponseHeaders)
weeklyCycleHint := extractWeeklyCycleHint(upstreamResponseHeaders)
clientResponseHeaders := make(http.Header)
for key, values := range upstreamResponseHeaders {
if isForwardableResponseHeader(key) {
clientResponseHeaders[key] = values
clientResponseHeaders[key] = append([]string(nil), values...)
}
}
if userConfig != nil && userConfig.ExternalCredential != "" {
s.rewriteResponseHeadersForExternalUser(clientResponseHeaders, userConfig)
}
clientUpgrader := ws.HTTPUpgrader{
Header: clientResponseHeaders,
@@ -212,7 +218,7 @@ func (s *Service) handleWebSocket(
session := &webSocketSession{
clientConn: clientConn,
upstreamConn: upstreamConn,
credentialTag: credential.tag,
credentialTag: selectedCredential.tagName(),
}
if !s.registerWebSocketSession(session) {
session.Close()
@@ -237,17 +243,17 @@ func (s *Service) handleWebSocket(
go func() {
defer waitGroup.Done()
defer session.Close()
s.proxyWebSocketClientToUpstream(clientConn, upstreamConn, credential, modelChannel)
s.proxyWebSocketClientToUpstream(clientConn, upstreamConn, selectedCredential, modelChannel)
}()
go func() {
defer waitGroup.Done()
defer session.Close()
s.proxyWebSocketUpstreamToClient(upstreamReadWriter, clientConn, credential, modelChannel, username, weeklyCycleHint)
s.proxyWebSocketUpstreamToClient(upstreamReadWriter, clientConn, selectedCredential, modelChannel, username, weeklyCycleHint)
}()
waitGroup.Wait()
}
func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamConn net.Conn, credential *defaultCredential, modelChannel chan<- string) {
func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamConn net.Conn, selectedCredential credential, modelChannel chan<- string) {
for {
data, opCode, err := wsutil.ReadClientData(clientConn)
if err != nil {
@@ -257,7 +263,7 @@ func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamCo
return
}
if opCode == ws.OpText && credential.usageTracker != nil {
if opCode == ws.OpText && selectedCredential.usageTrackerOrNil() != nil {
var request struct {
Type string `json:"type"`
Model string `json:"model"`
@@ -280,7 +286,8 @@ func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamCo
}
}
func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWriter, clientConn net.Conn, credential *defaultCredential, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) {
func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWriter, clientConn net.Conn, selectedCredential credential, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) {
usageTracker := selectedCredential.usageTrackerOrNil()
var requestModel string
for {
data, opCode, err := wsutil.ReadServerData(upstreamReadWriter)
@@ -291,7 +298,7 @@ func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWrite
return
}
if opCode == ws.OpText && credential.usageTracker != nil {
if opCode == ws.OpText && usageTracker != nil {
select {
case model := <-modelChannel:
requestModel = model
@@ -317,7 +324,7 @@ func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWrite
}
if responseModel != "" {
contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens)
credential.usageTracker.AddUsageWithCycleHint(
usageTracker.AddUsageWithCycleHint(
responseModel,
contextWindow,
inputTokens,