mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-14 04:38:28 +10:00
ccm/ocm: Add external credential support for cross-instance usage sharing
Extract credential interface from *defaultCredential to support both default (OAuth) and external (remote proxy) credential types. External credentials proxy requests to a remote ccm/ocm instance with bearer token auth, poll a /status endpoint for utilization, and parse aggregated rate limit headers from responses. Add allow_external_usage user flag to control whether balancer/fallback providers may select external credentials. Add status endpoint (/ccm/v1/status, /ocm/v1/status) returning averaged utilization across eligible credentials. Rewrite response rate limit headers for external users with aggregated values.
This commit is contained in:
@@ -19,15 +19,18 @@ type CCMServiceOptions struct {
|
||||
}
|
||||
|
||||
type CCMUser struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
Credential string `json:"credential,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
Credential string `json:"credential,omitempty"`
|
||||
ExternalCredential string `json:"external_credential,omitempty"`
|
||||
AllowExternalUsage bool `json:"allow_external_usage,omitempty"`
|
||||
}
|
||||
|
||||
type _CCMCredential struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Tag string `json:"tag"`
|
||||
DefaultOptions CCMDefaultCredentialOptions `json:"-"`
|
||||
ExternalOptions CCMExternalCredentialOptions `json:"-"`
|
||||
BalancerOptions CCMBalancerCredentialOptions `json:"-"`
|
||||
FallbackOptions CCMFallbackCredentialOptions `json:"-"`
|
||||
}
|
||||
@@ -40,6 +43,8 @@ func (c CCMCredential) MarshalJSON() ([]byte, error) {
|
||||
case "", "default":
|
||||
c.Type = ""
|
||||
v = c.DefaultOptions
|
||||
case "external":
|
||||
v = c.ExternalOptions
|
||||
case "balancer":
|
||||
v = c.BalancerOptions
|
||||
case "fallback":
|
||||
@@ -63,6 +68,8 @@ func (c *CCMCredential) UnmarshalJSON(bytes []byte) error {
|
||||
case "", "default":
|
||||
c.Type = "default"
|
||||
v = &c.DefaultOptions
|
||||
case "external":
|
||||
v = &c.ExternalOptions
|
||||
case "balancer":
|
||||
v = &c.BalancerOptions
|
||||
case "fallback":
|
||||
@@ -87,6 +94,15 @@ type CCMBalancerCredentialOptions struct {
|
||||
PollInterval badoption.Duration `json:"poll_interval,omitempty"`
|
||||
}
|
||||
|
||||
type CCMExternalCredentialOptions struct {
|
||||
URL string `json:"url"`
|
||||
ServerOptions
|
||||
Token string `json:"token"`
|
||||
Detour string `json:"detour,omitempty"`
|
||||
UsagesPath string `json:"usages_path,omitempty"`
|
||||
PollInterval badoption.Duration `json:"poll_interval,omitempty"`
|
||||
}
|
||||
|
||||
type CCMFallbackCredentialOptions struct {
|
||||
Credentials badoption.Listable[string] `json:"credentials"`
|
||||
PollInterval badoption.Duration `json:"poll_interval,omitempty"`
|
||||
|
||||
@@ -19,15 +19,18 @@ type OCMServiceOptions struct {
|
||||
}
|
||||
|
||||
type OCMUser struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
Credential string `json:"credential,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
Credential string `json:"credential,omitempty"`
|
||||
ExternalCredential string `json:"external_credential,omitempty"`
|
||||
AllowExternalUsage bool `json:"allow_external_usage,omitempty"`
|
||||
}
|
||||
|
||||
type _OCMCredential struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Tag string `json:"tag"`
|
||||
DefaultOptions OCMDefaultCredentialOptions `json:"-"`
|
||||
ExternalOptions OCMExternalCredentialOptions `json:"-"`
|
||||
BalancerOptions OCMBalancerCredentialOptions `json:"-"`
|
||||
FallbackOptions OCMFallbackCredentialOptions `json:"-"`
|
||||
}
|
||||
@@ -40,6 +43,8 @@ func (c OCMCredential) MarshalJSON() ([]byte, error) {
|
||||
case "", "default":
|
||||
c.Type = ""
|
||||
v = c.DefaultOptions
|
||||
case "external":
|
||||
v = c.ExternalOptions
|
||||
case "balancer":
|
||||
v = c.BalancerOptions
|
||||
case "fallback":
|
||||
@@ -63,6 +68,8 @@ func (c *OCMCredential) UnmarshalJSON(bytes []byte) error {
|
||||
case "", "default":
|
||||
c.Type = "default"
|
||||
v = &c.DefaultOptions
|
||||
case "external":
|
||||
v = &c.ExternalOptions
|
||||
case "balancer":
|
||||
v = &c.BalancerOptions
|
||||
case "fallback":
|
||||
@@ -87,6 +94,15 @@ type OCMBalancerCredentialOptions struct {
|
||||
PollInterval badoption.Duration `json:"poll_interval,omitempty"`
|
||||
}
|
||||
|
||||
type OCMExternalCredentialOptions struct {
|
||||
URL string `json:"url"`
|
||||
ServerOptions
|
||||
Token string `json:"token"`
|
||||
Detour string `json:"detour,omitempty"`
|
||||
UsagesPath string `json:"usages_path,omitempty"`
|
||||
PollInterval badoption.Duration `json:"poll_interval,omitempty"`
|
||||
}
|
||||
|
||||
type OCMFallbackCredentialOptions struct {
|
||||
Credentials badoption.Listable[string] `json:"credentials"`
|
||||
PollInterval badoption.Duration `json:"poll_interval,omitempty"`
|
||||
|
||||
428
service/ccm/credential_external.go
Normal file
428
service/ccm/credential_external.go
Normal file
@@ -0,0 +1,428 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
stdTLS "crypto/tls"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
)
|
||||
|
||||
type externalCredential struct {
|
||||
tag string
|
||||
baseURL string
|
||||
token string
|
||||
httpClient *http.Client
|
||||
state credentialState
|
||||
stateMutex sync.RWMutex
|
||||
pollAccess sync.Mutex
|
||||
pollInterval time.Duration
|
||||
usageTracker *AggregatedUsage
|
||||
logger log.ContextLogger
|
||||
|
||||
onBecameUnusable func()
|
||||
interrupted bool
|
||||
requestContext context.Context
|
||||
cancelRequests context.CancelFunc
|
||||
requestAccess sync.Mutex
|
||||
}
|
||||
|
||||
func newExternalCredential(ctx context.Context, tag string, options option.CCMExternalCredentialOptions, logger log.ContextLogger) (*externalCredential, error) {
|
||||
parsedURL, err := url.Parse(options.URL)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "parse url for credential ", tag)
|
||||
}
|
||||
|
||||
credentialDialer, err := dialer.NewWithOptions(dialer.Options{
|
||||
Context: ctx,
|
||||
Options: option.DialerOptions{
|
||||
Detour: options.Detour,
|
||||
},
|
||||
RemoteIsDomain: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "create dialer for credential ", tag)
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
ForceAttemptHTTP2: true,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
if options.Server != "" {
|
||||
serverPort := options.ServerPort
|
||||
if serverPort == 0 {
|
||||
portStr := parsedURL.Port()
|
||||
if portStr != "" {
|
||||
port, parseErr := strconv.ParseUint(portStr, 10, 16)
|
||||
if parseErr == nil {
|
||||
serverPort = uint16(port)
|
||||
}
|
||||
}
|
||||
if serverPort == 0 {
|
||||
if parsedURL.Scheme == "https" {
|
||||
serverPort = 443
|
||||
} else {
|
||||
serverPort = 80
|
||||
}
|
||||
}
|
||||
}
|
||||
destination := M.ParseSocksaddrHostPort(options.Server, serverPort)
|
||||
return credentialDialer.DialContext(ctx, network, destination)
|
||||
}
|
||||
return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
},
|
||||
}
|
||||
|
||||
if parsedURL.Scheme == "https" {
|
||||
transport.TLSClientConfig = &stdTLS.Config{
|
||||
ServerName: parsedURL.Hostname(),
|
||||
RootCAs: adapter.RootPoolFromContext(ctx),
|
||||
Time: ntp.TimeFuncFromContext(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
baseURL := parsedURL.Scheme + "://" + parsedURL.Host
|
||||
if parsedURL.Path != "" && parsedURL.Path != "/" {
|
||||
baseURL += parsedURL.Path
|
||||
}
|
||||
// Strip trailing slash
|
||||
if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' {
|
||||
baseURL = baseURL[:len(baseURL)-1]
|
||||
}
|
||||
|
||||
pollInterval := time.Duration(options.PollInterval)
|
||||
if pollInterval <= 0 {
|
||||
pollInterval = 30 * time.Minute
|
||||
}
|
||||
|
||||
requestContext, cancelRequests := context.WithCancel(context.Background())
|
||||
|
||||
cred := &externalCredential{
|
||||
tag: tag,
|
||||
baseURL: baseURL,
|
||||
token: options.Token,
|
||||
httpClient: &http.Client{Transport: transport},
|
||||
pollInterval: pollInterval,
|
||||
logger: logger,
|
||||
requestContext: requestContext,
|
||||
cancelRequests: cancelRequests,
|
||||
}
|
||||
|
||||
if options.UsagesPath != "" {
|
||||
cred.usageTracker = &AggregatedUsage{
|
||||
LastUpdated: time.Now(),
|
||||
Combinations: make([]CostCombination, 0),
|
||||
filePath: options.UsagesPath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
return cred, nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) start() error {
|
||||
if c.usageTracker != nil {
|
||||
err := c.usageTracker.Load()
|
||||
if err != nil {
|
||||
c.logger.Warn("load usage statistics for ", c.tag, ": ", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) tagName() string {
|
||||
return c.tag
|
||||
}
|
||||
|
||||
func (c *externalCredential) isExternal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *externalCredential) isUsable() bool {
|
||||
c.stateMutex.RLock()
|
||||
if c.state.hardRateLimited {
|
||||
if time.Now().Before(c.state.rateLimitResetAt) {
|
||||
c.stateMutex.RUnlock()
|
||||
return false
|
||||
}
|
||||
c.stateMutex.RUnlock()
|
||||
c.stateMutex.Lock()
|
||||
if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) {
|
||||
c.state.hardRateLimited = false
|
||||
}
|
||||
// No reserve for external: only 100% is unusable
|
||||
usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100
|
||||
c.stateMutex.Unlock()
|
||||
return usable
|
||||
}
|
||||
usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100
|
||||
c.stateMutex.RUnlock()
|
||||
return usable
|
||||
}
|
||||
|
||||
func (c *externalCredential) fiveHourUtilization() float64 {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
return c.state.fiveHourUtilization
|
||||
}
|
||||
|
||||
func (c *externalCredential) weeklyUtilization() float64 {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
return c.state.weeklyUtilization
|
||||
}
|
||||
|
||||
func (c *externalCredential) markRateLimited(resetAt time.Time) {
|
||||
c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt)))
|
||||
c.stateMutex.Lock()
|
||||
c.state.hardRateLimited = true
|
||||
c.state.rateLimitResetAt = resetAt
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) earliestReset() time.Time {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
if c.state.hardRateLimited {
|
||||
return c.state.rateLimitResetAt
|
||||
}
|
||||
earliest := c.state.fiveHourReset
|
||||
if !c.state.weeklyReset.IsZero() && (earliest.IsZero() || c.state.weeklyReset.Before(earliest)) {
|
||||
earliest = c.state.weeklyReset
|
||||
}
|
||||
return earliest
|
||||
}
|
||||
|
||||
func (c *externalCredential) getAccessToken() (string, error) {
|
||||
return c.token, nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, _ http.Header) (*http.Request, error) {
|
||||
proxyURL := c.baseURL + original.URL.RequestURI()
|
||||
var body io.Reader
|
||||
if bodyBytes != nil {
|
||||
body = bytes.NewReader(bodyBytes)
|
||||
} else {
|
||||
body = original.Body
|
||||
}
|
||||
proxyRequest, err := http.NewRequestWithContext(ctx, original.Method, proxyURL, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for key, values := range original.Header {
|
||||
if !isHopByHopHeader(key) && key != "Authorization" {
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
}
|
||||
|
||||
proxyRequest.Header.Set("Authorization", "Bearer "+c.token)
|
||||
|
||||
return proxyRequest, nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) updateStateFromHeaders(headers http.Header) {
|
||||
c.stateMutex.Lock()
|
||||
isFirstUpdate := c.state.lastUpdated.IsZero()
|
||||
oldFiveHour := c.state.fiveHourUtilization
|
||||
oldWeekly := c.state.weeklyUtilization
|
||||
|
||||
if utilization := headers.Get("anthropic-ratelimit-unified-5h-utilization"); utilization != "" {
|
||||
value, err := strconv.ParseFloat(utilization, 64)
|
||||
if err == nil {
|
||||
// Remote CCM writes aggregated utilization as 0.0-1.0; convert to percentage
|
||||
c.state.fiveHourUtilization = value * 100
|
||||
}
|
||||
}
|
||||
if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset"); exists {
|
||||
c.state.fiveHourReset = value
|
||||
}
|
||||
if utilization := headers.Get("anthropic-ratelimit-unified-7d-utilization"); utilization != "" {
|
||||
value, err := strconv.ParseFloat(utilization, 64)
|
||||
if err == nil {
|
||||
c.state.weeklyUtilization = value * 100
|
||||
}
|
||||
}
|
||||
if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset"); exists {
|
||||
c.state.weeklyReset = value
|
||||
}
|
||||
c.state.lastUpdated = time.Now()
|
||||
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
|
||||
c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%")
|
||||
}
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) checkTransitionLocked() bool {
|
||||
unusable := c.state.hardRateLimited || c.state.fiveHourUtilization >= 100 || c.state.weeklyUtilization >= 100
|
||||
if unusable && !c.interrupted {
|
||||
c.interrupted = true
|
||||
return true
|
||||
}
|
||||
if !unusable && c.interrupted {
|
||||
c.interrupted = false
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *externalCredential) wrapRequestContext(parent context.Context) *credentialRequestContext {
|
||||
c.requestAccess.Lock()
|
||||
credentialContext := c.requestContext
|
||||
c.requestAccess.Unlock()
|
||||
derived, cancel := context.WithCancel(parent)
|
||||
stop := context.AfterFunc(credentialContext, func() {
|
||||
cancel()
|
||||
})
|
||||
return &credentialRequestContext{
|
||||
Context: derived,
|
||||
releaseFunc: stop,
|
||||
cancelFunc: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) interruptConnections() {
|
||||
c.logger.Warn("interrupting connections for ", c.tag)
|
||||
c.requestAccess.Lock()
|
||||
c.cancelRequests()
|
||||
c.requestContext, c.cancelRequests = context.WithCancel(context.Background())
|
||||
c.requestAccess.Unlock()
|
||||
if c.onBecameUnusable != nil {
|
||||
c.onBecameUnusable()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) pollUsage(ctx context.Context) {
|
||||
if !c.pollAccess.TryLock() {
|
||||
return
|
||||
}
|
||||
defer c.pollAccess.Unlock()
|
||||
defer c.markUsagePollAttempted()
|
||||
|
||||
statusURL := c.baseURL + "/ccm/v1/status"
|
||||
httpClient := &http.Client{
|
||||
Transport: c.httpClient.Transport,
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, statusURL, nil)
|
||||
if err != nil {
|
||||
c.logger.Error("poll usage for ", c.tag, ": create request: ", err)
|
||||
return
|
||||
}
|
||||
request.Header.Set("Authorization", "Bearer "+c.token)
|
||||
|
||||
response, err := httpClient.Do(request)
|
||||
if err != nil {
|
||||
c.logger.Error("poll usage for ", c.tag, ": ", err)
|
||||
c.stateMutex.Lock()
|
||||
c.state.consecutivePollFailures++
|
||||
c.stateMutex.Unlock()
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
c.stateMutex.Lock()
|
||||
c.state.consecutivePollFailures++
|
||||
c.stateMutex.Unlock()
|
||||
c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
|
||||
return
|
||||
}
|
||||
|
||||
var statusResponse struct {
|
||||
FiveHourUtilization float64 `json:"five_hour_utilization"`
|
||||
WeeklyUtilization float64 `json:"weekly_utilization"`
|
||||
}
|
||||
err = json.NewDecoder(response.Body).Decode(&statusResponse)
|
||||
if err != nil {
|
||||
c.stateMutex.Lock()
|
||||
c.state.consecutivePollFailures++
|
||||
c.stateMutex.Unlock()
|
||||
c.logger.Debug("poll usage for ", c.tag, ": decode: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
c.stateMutex.Lock()
|
||||
isFirstUpdate := c.state.lastUpdated.IsZero()
|
||||
oldFiveHour := c.state.fiveHourUtilization
|
||||
oldWeekly := c.state.weeklyUtilization
|
||||
c.state.consecutivePollFailures = 0
|
||||
c.state.fiveHourUtilization = statusResponse.FiveHourUtilization
|
||||
c.state.weeklyUtilization = statusResponse.WeeklyUtilization
|
||||
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
|
||||
c.state.hardRateLimited = false
|
||||
}
|
||||
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
|
||||
c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%")
|
||||
}
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) lastUpdatedTime() time.Time {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
return c.state.lastUpdated
|
||||
}
|
||||
|
||||
func (c *externalCredential) markUsagePollAttempted() {
|
||||
c.stateMutex.Lock()
|
||||
defer c.stateMutex.Unlock()
|
||||
c.state.lastUpdated = time.Now()
|
||||
}
|
||||
|
||||
func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Duration {
|
||||
c.stateMutex.RLock()
|
||||
failures := c.state.consecutivePollFailures
|
||||
c.stateMutex.RUnlock()
|
||||
if failures <= 0 {
|
||||
return baseInterval
|
||||
}
|
||||
if failures > 4 {
|
||||
failures = 4
|
||||
}
|
||||
return baseInterval * time.Duration(1<<failures)
|
||||
}
|
||||
|
||||
func (c *externalCredential) usageTrackerOrNil() *AggregatedUsage {
|
||||
return c.usageTracker
|
||||
}
|
||||
|
||||
func (c *externalCredential) httpTransport() *http.Client {
|
||||
return c.httpClient
|
||||
}
|
||||
|
||||
func (c *externalCredential) close() {
|
||||
if c.usageTracker != nil {
|
||||
c.usageTracker.cancelPendingSave()
|
||||
err := c.usageTracker.Save()
|
||||
if err != nil {
|
||||
c.logger.Error("save usage statistics for ", c.tag, ": ", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -81,6 +81,31 @@ func (c *credentialRequestContext) cancelRequest() {
|
||||
c.cancelOnce.Do(c.cancelFunc)
|
||||
}
|
||||
|
||||
type credential interface {
|
||||
tagName() string
|
||||
isUsable() bool
|
||||
isExternal() bool
|
||||
fiveHourUtilization() float64
|
||||
weeklyUtilization() float64
|
||||
markRateLimited(resetAt time.Time)
|
||||
earliestReset() time.Time
|
||||
|
||||
getAccessToken() (string, error)
|
||||
buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error)
|
||||
updateStateFromHeaders(header http.Header)
|
||||
|
||||
wrapRequestContext(ctx context.Context) *credentialRequestContext
|
||||
interruptConnections()
|
||||
|
||||
start() error
|
||||
pollUsage(ctx context.Context)
|
||||
lastUpdatedTime() time.Time
|
||||
pollBackoff(base time.Duration) time.Duration
|
||||
usageTrackerOrNil() *AggregatedUsage
|
||||
httpTransport() *http.Client
|
||||
close()
|
||||
}
|
||||
|
||||
func newDefaultCredential(ctx context.Context, tag string, options option.CCMDefaultCredentialOptions, logger log.ContextLogger) (*defaultCredential, error) {
|
||||
credentialDialer, err := dialer.NewWithOptions(dialer.Options{
|
||||
Context: ctx,
|
||||
@@ -188,15 +213,34 @@ func (c *defaultCredential) getAccessToken() (string, error) {
|
||||
return newCredentials.AccessToken, nil
|
||||
}
|
||||
|
||||
func parseResetTimestamp(value string) (time.Time, error) {
|
||||
if value == "" {
|
||||
return time.Time{}, nil
|
||||
// Claude Code's unified rate-limit handling parses these reset headers with
|
||||
// Number(...), compares them against Date.now()/1000, and renders them via
|
||||
// new Date(seconds*1000), so keep the wire format pinned to Unix epoch seconds.
|
||||
func parseAnthropicResetHeaderValue(headerName string, headerValue string) time.Time {
|
||||
unixEpoch, err := strconv.ParseInt(headerValue, 10, 64)
|
||||
if err != nil {
|
||||
panic("invalid " + headerName + " header: expected Unix epoch seconds, got " + strconv.Quote(headerValue))
|
||||
}
|
||||
unixEpoch, err := strconv.ParseInt(value, 10, 64)
|
||||
if err == nil {
|
||||
return time.Unix(unixEpoch, 0), nil
|
||||
if unixEpoch <= 0 {
|
||||
panic("invalid " + headerName + " header: expected positive Unix epoch seconds, got " + strconv.Quote(headerValue))
|
||||
}
|
||||
return time.Parse(time.RFC3339Nano, value)
|
||||
return time.Unix(unixEpoch, 0)
|
||||
}
|
||||
|
||||
func parseOptionalAnthropicResetHeader(headers http.Header, headerName string) (time.Time, bool) {
|
||||
headerValue := headers.Get(headerName)
|
||||
if headerValue == "" {
|
||||
return time.Time{}, false
|
||||
}
|
||||
return parseAnthropicResetHeaderValue(headerName, headerValue), true
|
||||
}
|
||||
|
||||
func parseRequiredAnthropicResetHeader(headers http.Header, headerName string) time.Time {
|
||||
headerValue := headers.Get(headerName)
|
||||
if headerValue == "" {
|
||||
panic("missing required " + headerName + " header")
|
||||
}
|
||||
return parseAnthropicResetHeaderValue(headerName, headerValue)
|
||||
}
|
||||
|
||||
func (c *defaultCredential) updateStateFromHeaders(headers http.Header) {
|
||||
@@ -215,11 +259,8 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) {
|
||||
c.state.fiveHourUtilization = newValue
|
||||
}
|
||||
}
|
||||
if resetAt := headers.Get("anthropic-ratelimit-unified-5h-reset"); resetAt != "" {
|
||||
value, err := parseResetTimestamp(resetAt)
|
||||
if err == nil {
|
||||
c.state.fiveHourReset = value
|
||||
}
|
||||
if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset"); exists {
|
||||
c.state.fiveHourReset = value
|
||||
}
|
||||
if utilization := headers.Get("anthropic-ratelimit-unified-7d-utilization"); utilization != "" {
|
||||
value, err := strconv.ParseFloat(utilization, 64)
|
||||
@@ -231,11 +272,8 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) {
|
||||
c.state.weeklyUtilization = newValue
|
||||
}
|
||||
}
|
||||
if resetAt := headers.Get("anthropic-ratelimit-unified-7d-reset"); resetAt != "" {
|
||||
value, err := parseResetTimestamp(resetAt)
|
||||
if err == nil {
|
||||
c.state.weeklyReset = value
|
||||
}
|
||||
if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset"); exists {
|
||||
c.state.weeklyReset = value
|
||||
}
|
||||
c.state.lastUpdated = time.Now()
|
||||
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
|
||||
@@ -499,40 +537,110 @@ func (c *defaultCredential) close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *defaultCredential) tagName() string {
|
||||
return c.tag
|
||||
}
|
||||
|
||||
func (c *defaultCredential) isExternal() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *defaultCredential) fiveHourUtilization() float64 {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
return c.state.fiveHourUtilization
|
||||
}
|
||||
|
||||
func (c *defaultCredential) usageTrackerOrNil() *AggregatedUsage {
|
||||
return c.usageTracker
|
||||
}
|
||||
|
||||
func (c *defaultCredential) httpTransport() *http.Client {
|
||||
return c.httpClient
|
||||
}
|
||||
|
||||
func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) {
|
||||
accessToken, err := c.getAccessToken()
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "get access token for ", c.tag)
|
||||
}
|
||||
|
||||
proxyURL := claudeAPIBaseURL + original.URL.RequestURI()
|
||||
var body io.Reader
|
||||
if bodyBytes != nil {
|
||||
body = bytes.NewReader(bodyBytes)
|
||||
} else {
|
||||
body = original.Body
|
||||
}
|
||||
proxyRequest, err := http.NewRequestWithContext(ctx, original.Method, proxyURL, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for key, values := range original.Header {
|
||||
if !isHopByHopHeader(key) && key != "Authorization" {
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
}
|
||||
|
||||
serviceOverridesAcceptEncoding := len(serviceHeaders.Values("Accept-Encoding")) > 0
|
||||
if c.usageTracker != nil && !serviceOverridesAcceptEncoding {
|
||||
proxyRequest.Header.Del("Accept-Encoding")
|
||||
}
|
||||
|
||||
anthropicBetaHeader := proxyRequest.Header.Get("anthropic-beta")
|
||||
if anthropicBetaHeader != "" {
|
||||
proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue+","+anthropicBetaHeader)
|
||||
} else {
|
||||
proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue)
|
||||
}
|
||||
|
||||
for key, values := range serviceHeaders {
|
||||
proxyRequest.Header.Del(key)
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
proxyRequest.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
return proxyRequest, nil
|
||||
}
|
||||
|
||||
// credentialProvider is the interface for all credential types.
|
||||
type credentialProvider interface {
|
||||
selectCredential(sessionID string) (*defaultCredential, bool, error)
|
||||
onRateLimited(sessionID string, credential *defaultCredential, resetAt time.Time) *defaultCredential
|
||||
selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error)
|
||||
onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential
|
||||
pollIfStale(ctx context.Context)
|
||||
allDefaults() []*defaultCredential
|
||||
allCredentials() []credential
|
||||
close()
|
||||
}
|
||||
|
||||
// singleCredentialProvider wraps a single default credential (legacy or single default).
|
||||
// singleCredentialProvider wraps a single credential (legacy or single default).
|
||||
type singleCredentialProvider struct {
|
||||
credential *defaultCredential
|
||||
cred credential
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) selectCredential(_ string) (*defaultCredential, bool, error) {
|
||||
if !p.credential.isUsable() {
|
||||
return nil, false, E.New("credential ", p.credential.tag, " is rate-limited")
|
||||
func (p *singleCredentialProvider) selectCredential(_ string, filter func(credential) bool) (credential, bool, error) {
|
||||
if filter != nil && !filter(p.cred) {
|
||||
return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out")
|
||||
}
|
||||
return p.credential, false, nil
|
||||
if !p.cred.isUsable() {
|
||||
return nil, false, E.New("credential ", p.cred.tagName(), " is rate-limited")
|
||||
}
|
||||
return p.cred, false, nil
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) onRateLimited(_ string, credential *defaultCredential, resetAt time.Time) *defaultCredential {
|
||||
credential.markRateLimited(resetAt)
|
||||
func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ func(credential) bool) credential {
|
||||
cred.markRateLimited(resetAt)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) pollIfStale(ctx context.Context) {
|
||||
if time.Since(p.credential.lastUpdatedTime()) > p.credential.pollBackoff(defaultPollInterval) {
|
||||
p.credential.pollUsage(ctx)
|
||||
if time.Since(p.cred.lastUpdatedTime()) > p.cred.pollBackoff(defaultPollInterval) {
|
||||
p.cred.pollUsage(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) allDefaults() []*defaultCredential {
|
||||
return []*defaultCredential{p.credential}
|
||||
func (p *singleCredentialProvider) allCredentials() []credential {
|
||||
return []credential{p.cred}
|
||||
}
|
||||
|
||||
// close is a no-op for singleCredentialProvider.
func (p *singleCredentialProvider) close() {
}
|
||||
@@ -546,7 +654,7 @@ type sessionEntry struct {
|
||||
|
||||
// balancerProvider assigns sessions to credentials based on a configurable strategy.
|
||||
type balancerProvider struct {
|
||||
credentials []*defaultCredential
|
||||
credentials []credential
|
||||
strategy string
|
||||
roundRobinIndex atomic.Uint64
|
||||
pollInterval time.Duration
|
||||
@@ -555,7 +663,7 @@ type balancerProvider struct {
|
||||
logger log.ContextLogger
|
||||
}
|
||||
|
||||
func newBalancerProvider(credentials []*defaultCredential, strategy string, pollInterval time.Duration, logger log.ContextLogger) *balancerProvider {
|
||||
func newBalancerProvider(credentials []credential, strategy string, pollInterval time.Duration, logger log.ContextLogger) *balancerProvider {
|
||||
if pollInterval <= 0 {
|
||||
pollInterval = defaultPollInterval
|
||||
}
|
||||
@@ -568,15 +676,15 @@ func newBalancerProvider(credentials []*defaultCredential, strategy string, poll
|
||||
}
|
||||
}
|
||||
|
||||
func (p *balancerProvider) selectCredential(sessionID string) (*defaultCredential, bool, error) {
|
||||
func (p *balancerProvider) selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) {
|
||||
if sessionID != "" {
|
||||
p.sessionMutex.RLock()
|
||||
entry, exists := p.sessions[sessionID]
|
||||
p.sessionMutex.RUnlock()
|
||||
if exists {
|
||||
for _, credential := range p.credentials {
|
||||
if credential.tag == entry.tag && credential.isUsable() {
|
||||
return credential, false, nil
|
||||
for _, cred := range p.credentials {
|
||||
if cred.tagName() == entry.tag && (filter == nil || filter(cred)) && cred.isUsable() {
|
||||
return cred, false, nil
|
||||
}
|
||||
}
|
||||
p.sessionMutex.Lock()
|
||||
@@ -585,7 +693,7 @@ func (p *balancerProvider) selectCredential(sessionID string) (*defaultCredentia
|
||||
}
|
||||
}
|
||||
|
||||
best := p.pickCredential()
|
||||
best := p.pickCredential(filter)
|
||||
if best == nil {
|
||||
return nil, false, allCredentialsUnavailableError(p.credentials)
|
||||
}
|
||||
@@ -593,61 +701,67 @@ func (p *balancerProvider) selectCredential(sessionID string) (*defaultCredentia
|
||||
isNew := sessionID != ""
|
||||
if isNew {
|
||||
p.sessionMutex.Lock()
|
||||
p.sessions[sessionID] = sessionEntry{tag: best.tag, createdAt: time.Now()}
|
||||
p.sessions[sessionID] = sessionEntry{tag: best.tagName(), createdAt: time.Now()}
|
||||
p.sessionMutex.Unlock()
|
||||
}
|
||||
return best, isNew, nil
|
||||
}
|
||||
|
||||
func (p *balancerProvider) onRateLimited(sessionID string, credential *defaultCredential, resetAt time.Time) *defaultCredential {
|
||||
credential.markRateLimited(resetAt)
|
||||
func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential {
|
||||
cred.markRateLimited(resetAt)
|
||||
if sessionID != "" {
|
||||
p.sessionMutex.Lock()
|
||||
delete(p.sessions, sessionID)
|
||||
p.sessionMutex.Unlock()
|
||||
}
|
||||
|
||||
best := p.pickCredential()
|
||||
best := p.pickCredential(filter)
|
||||
if best != nil && sessionID != "" {
|
||||
p.sessionMutex.Lock()
|
||||
p.sessions[sessionID] = sessionEntry{tag: best.tag, createdAt: time.Now()}
|
||||
p.sessions[sessionID] = sessionEntry{tag: best.tagName(), createdAt: time.Now()}
|
||||
p.sessionMutex.Unlock()
|
||||
}
|
||||
return best
|
||||
}
|
||||
|
||||
func (p *balancerProvider) pickCredential() *defaultCredential {
|
||||
func (p *balancerProvider) pickCredential(filter func(credential) bool) credential {
|
||||
switch p.strategy {
|
||||
case "round_robin":
|
||||
return p.pickRoundRobin()
|
||||
return p.pickRoundRobin(filter)
|
||||
case "random":
|
||||
return p.pickRandom()
|
||||
return p.pickRandom(filter)
|
||||
default:
|
||||
return p.pickLeastUsed()
|
||||
return p.pickLeastUsed(filter)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *balancerProvider) pickLeastUsed() *defaultCredential {
|
||||
var best *defaultCredential
|
||||
func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credential {
|
||||
var best credential
|
||||
bestUtilization := float64(101)
|
||||
for _, credential := range p.credentials {
|
||||
if !credential.isUsable() {
|
||||
for _, cred := range p.credentials {
|
||||
if filter != nil && !filter(cred) {
|
||||
continue
|
||||
}
|
||||
utilization := credential.weeklyUtilization()
|
||||
if !cred.isUsable() {
|
||||
continue
|
||||
}
|
||||
utilization := cred.weeklyUtilization()
|
||||
if utilization < bestUtilization {
|
||||
bestUtilization = utilization
|
||||
best = credential
|
||||
best = cred
|
||||
}
|
||||
}
|
||||
return best
|
||||
}
|
||||
|
||||
func (p *balancerProvider) pickRoundRobin() *defaultCredential {
|
||||
func (p *balancerProvider) pickRoundRobin(filter func(credential) bool) credential {
|
||||
start := int(p.roundRobinIndex.Add(1) - 1)
|
||||
count := len(p.credentials)
|
||||
for offset := range count {
|
||||
candidate := p.credentials[(start+offset)%count]
|
||||
if filter != nil && !filter(candidate) {
|
||||
continue
|
||||
}
|
||||
if candidate.isUsable() {
|
||||
return candidate
|
||||
}
|
||||
@@ -655,9 +769,12 @@ func (p *balancerProvider) pickRoundRobin() *defaultCredential {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *balancerProvider) pickRandom() *defaultCredential {
|
||||
var usable []*defaultCredential
|
||||
func (p *balancerProvider) pickRandom(filter func(credential) bool) credential {
|
||||
var usable []credential
|
||||
for _, candidate := range p.credentials {
|
||||
if filter != nil && !filter(candidate) {
|
||||
continue
|
||||
}
|
||||
if candidate.isUsable() {
|
||||
usable = append(usable, candidate)
|
||||
}
|
||||
@@ -678,14 +795,14 @@ func (p *balancerProvider) pollIfStale(ctx context.Context) {
|
||||
}
|
||||
p.sessionMutex.Unlock()
|
||||
|
||||
for _, credential := range p.credentials {
|
||||
if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) {
|
||||
credential.pollUsage(ctx)
|
||||
for _, cred := range p.credentials {
|
||||
if time.Since(cred.lastUpdatedTime()) > cred.pollBackoff(p.pollInterval) {
|
||||
cred.pollUsage(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *balancerProvider) allDefaults() []*defaultCredential {
|
||||
func (p *balancerProvider) allCredentials() []credential {
|
||||
return p.credentials
|
||||
}
|
||||
|
||||
@@ -693,12 +810,12 @@ func (p *balancerProvider) close() {}
|
||||
|
||||
// fallbackProvider tries credentials in order.
|
||||
type fallbackProvider struct {
|
||||
credentials []*defaultCredential
|
||||
credentials []credential
|
||||
pollInterval time.Duration
|
||||
logger log.ContextLogger
|
||||
}
|
||||
|
||||
func newFallbackProvider(credentials []*defaultCredential, pollInterval time.Duration, logger log.ContextLogger) *fallbackProvider {
|
||||
func newFallbackProvider(credentials []credential, pollInterval time.Duration, logger log.ContextLogger) *fallbackProvider {
|
||||
if pollInterval <= 0 {
|
||||
pollInterval = defaultPollInterval
|
||||
}
|
||||
@@ -709,18 +826,24 @@ func newFallbackProvider(credentials []*defaultCredential, pollInterval time.Dur
|
||||
}
|
||||
}
|
||||
|
||||
func (p *fallbackProvider) selectCredential(_ string) (*defaultCredential, bool, error) {
|
||||
for _, credential := range p.credentials {
|
||||
if credential.isUsable() {
|
||||
return credential, false, nil
|
||||
func (p *fallbackProvider) selectCredential(_ string, filter func(credential) bool) (credential, bool, error) {
|
||||
for _, cred := range p.credentials {
|
||||
if filter != nil && !filter(cred) {
|
||||
continue
|
||||
}
|
||||
if cred.isUsable() {
|
||||
return cred, false, nil
|
||||
}
|
||||
}
|
||||
return nil, false, allCredentialsUnavailableError(p.credentials)
|
||||
}
|
||||
|
||||
func (p *fallbackProvider) onRateLimited(_ string, credential *defaultCredential, resetAt time.Time) *defaultCredential {
|
||||
credential.markRateLimited(resetAt)
|
||||
func (p *fallbackProvider) onRateLimited(_ string, cred credential, resetAt time.Time, filter func(credential) bool) credential {
|
||||
cred.markRateLimited(resetAt)
|
||||
for _, candidate := range p.credentials {
|
||||
if filter != nil && !filter(candidate) {
|
||||
continue
|
||||
}
|
||||
if candidate.isUsable() {
|
||||
return candidate
|
||||
}
|
||||
@@ -729,23 +852,23 @@ func (p *fallbackProvider) onRateLimited(_ string, credential *defaultCredential
|
||||
}
|
||||
|
||||
func (p *fallbackProvider) pollIfStale(ctx context.Context) {
|
||||
for _, credential := range p.credentials {
|
||||
if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) {
|
||||
credential.pollUsage(ctx)
|
||||
for _, cred := range p.credentials {
|
||||
if time.Since(cred.lastUpdatedTime()) > cred.pollBackoff(p.pollInterval) {
|
||||
cred.pollUsage(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *fallbackProvider) allDefaults() []*defaultCredential {
|
||||
func (p *fallbackProvider) allCredentials() []credential {
|
||||
return p.credentials
|
||||
}
|
||||
|
||||
func (p *fallbackProvider) close() {}
|
||||
|
||||
func allCredentialsUnavailableError(credentials []*defaultCredential) error {
|
||||
func allCredentialsUnavailableError(credentials []credential) error {
|
||||
var earliest time.Time
|
||||
for _, credential := range credentials {
|
||||
resetAt := credential.earliestReset()
|
||||
for _, cred := range credentials {
|
||||
resetAt := cred.earliestReset()
|
||||
if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) {
|
||||
earliest = resetAt
|
||||
}
|
||||
@@ -778,34 +901,44 @@ func buildCredentialProviders(
|
||||
ctx context.Context,
|
||||
options option.CCMServiceOptions,
|
||||
logger log.ContextLogger,
|
||||
) (map[string]credentialProvider, []*defaultCredential, error) {
|
||||
defaultCredentials := make(map[string]*defaultCredential)
|
||||
var allDefaults []*defaultCredential
|
||||
) (map[string]credentialProvider, []credential, error) {
|
||||
allCredentialMap := make(map[string]credential)
|
||||
var allCreds []credential
|
||||
providers := make(map[string]credentialProvider)
|
||||
|
||||
// Pass 1: create default and external credentials
|
||||
for _, credOpt := range options.Credentials {
|
||||
switch credOpt.Type {
|
||||
case "default":
|
||||
credential, err := newDefaultCredential(ctx, credOpt.Tag, credOpt.DefaultOptions, logger)
|
||||
cred, err := newDefaultCredential(ctx, credOpt.Tag, credOpt.DefaultOptions, logger)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defaultCredentials[credOpt.Tag] = credential
|
||||
allDefaults = append(allDefaults, credential)
|
||||
providers[credOpt.Tag] = &singleCredentialProvider{credential: credential}
|
||||
allCredentialMap[credOpt.Tag] = cred
|
||||
allCreds = append(allCreds, cred)
|
||||
providers[credOpt.Tag] = &singleCredentialProvider{cred: cred}
|
||||
case "external":
|
||||
cred, err := newExternalCredential(ctx, credOpt.Tag, credOpt.ExternalOptions, logger)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
allCredentialMap[credOpt.Tag] = cred
|
||||
allCreds = append(allCreds, cred)
|
||||
providers[credOpt.Tag] = &singleCredentialProvider{cred: cred}
|
||||
}
|
||||
}
|
||||
|
||||
// Pass 2: create balancer and fallback providers
|
||||
for _, credOpt := range options.Credentials {
|
||||
switch credOpt.Type {
|
||||
case "balancer":
|
||||
subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, defaultCredentials, credOpt.Tag)
|
||||
subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, allCredentialMap, credOpt.Tag)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
providers[credOpt.Tag] = newBalancerProvider(subCredentials, credOpt.BalancerOptions.Strategy, time.Duration(credOpt.BalancerOptions.PollInterval), logger)
|
||||
case "fallback":
|
||||
subCredentials, err := resolveCredentialTags(credOpt.FallbackOptions.Credentials, defaultCredentials, credOpt.Tag)
|
||||
subCredentials, err := resolveCredentialTags(credOpt.FallbackOptions.Credentials, allCredentialMap, credOpt.Tag)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -813,17 +946,17 @@ func buildCredentialProviders(
|
||||
}
|
||||
}
|
||||
|
||||
return providers, allDefaults, nil
|
||||
return providers, allCreds, nil
|
||||
}
|
||||
|
||||
func resolveCredentialTags(tags []string, defaults map[string]*defaultCredential, parentTag string) ([]*defaultCredential, error) {
|
||||
credentials := make([]*defaultCredential, 0, len(tags))
|
||||
func resolveCredentialTags(tags []string, allCredentials map[string]credential, parentTag string) ([]credential, error) {
|
||||
credentials := make([]credential, 0, len(tags))
|
||||
for _, tag := range tags {
|
||||
credential, exists := defaults[tag]
|
||||
cred, exists := allCredentials[tag]
|
||||
if !exists {
|
||||
return nil, E.New("credential ", parentTag, " references unknown default credential: ", tag)
|
||||
return nil, E.New("credential ", parentTag, " references unknown credential: ", tag)
|
||||
}
|
||||
credentials = append(credentials, credential)
|
||||
credentials = append(credentials, cred)
|
||||
}
|
||||
if len(credentials) == 0 {
|
||||
return nil, E.New("credential ", parentTag, " has no sub-credentials")
|
||||
@@ -835,27 +968,12 @@ func parseRateLimitResetFromHeaders(headers http.Header) time.Time {
|
||||
claim := headers.Get("anthropic-ratelimit-unified-representative-claim")
|
||||
switch claim {
|
||||
case "5h":
|
||||
if resetStr := headers.Get("anthropic-ratelimit-unified-5h-reset"); resetStr != "" {
|
||||
value, err := strconv.ParseInt(resetStr, 10, 64)
|
||||
if err == nil {
|
||||
return time.Unix(value, 0)
|
||||
}
|
||||
}
|
||||
return parseRequiredAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset")
|
||||
case "7d":
|
||||
if resetStr := headers.Get("anthropic-ratelimit-unified-7d-reset"); resetStr != "" {
|
||||
value, err := strconv.ParseInt(resetStr, 10, 64)
|
||||
if err == nil {
|
||||
return time.Unix(value, 0)
|
||||
}
|
||||
}
|
||||
return parseRequiredAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset")
|
||||
default:
|
||||
panic("invalid anthropic-ratelimit-unified-representative-claim header: " + strconv.Quote(claim))
|
||||
}
|
||||
if retryAfter := headers.Get("Retry-After"); retryAfter != "" {
|
||||
seconds, err := strconv.ParseInt(retryAfter, 10, 64)
|
||||
if err == nil {
|
||||
return time.Now().Add(time.Duration(seconds) * time.Second)
|
||||
}
|
||||
}
|
||||
return time.Now().Add(5 * time.Minute)
|
||||
}
|
||||
|
||||
func validateCCMOptions(options option.CCMServiceOptions) error {
|
||||
@@ -876,24 +994,34 @@ func validateCCMOptions(options option.CCMServiceOptions) error {
|
||||
|
||||
if hasCredentials {
|
||||
tags := make(map[string]bool)
|
||||
for _, credential := range options.Credentials {
|
||||
if tags[credential.Tag] {
|
||||
return E.New("duplicate credential tag: ", credential.Tag)
|
||||
credentialTypes := make(map[string]string)
|
||||
for _, cred := range options.Credentials {
|
||||
if tags[cred.Tag] {
|
||||
return E.New("duplicate credential tag: ", cred.Tag)
|
||||
}
|
||||
tags[credential.Tag] = true
|
||||
if credential.Type == "default" || credential.Type == "" {
|
||||
if credential.DefaultOptions.Reserve5h > 99 {
|
||||
return E.New("credential ", credential.Tag, ": reserve_5h must be at most 99")
|
||||
tags[cred.Tag] = true
|
||||
credentialTypes[cred.Tag] = cred.Type
|
||||
if cred.Type == "default" || cred.Type == "" {
|
||||
if cred.DefaultOptions.Reserve5h > 99 {
|
||||
return E.New("credential ", cred.Tag, ": reserve_5h must be at most 99")
|
||||
}
|
||||
if credential.DefaultOptions.ReserveWeekly > 99 {
|
||||
return E.New("credential ", credential.Tag, ": reserve_weekly must be at most 99")
|
||||
if cred.DefaultOptions.ReserveWeekly > 99 {
|
||||
return E.New("credential ", cred.Tag, ": reserve_weekly must be at most 99")
|
||||
}
|
||||
}
|
||||
if credential.Type == "balancer" {
|
||||
switch credential.BalancerOptions.Strategy {
|
||||
if cred.Type == "external" {
|
||||
if cred.ExternalOptions.URL == "" {
|
||||
return E.New("credential ", cred.Tag, ": external credential requires url")
|
||||
}
|
||||
if cred.ExternalOptions.Token == "" {
|
||||
return E.New("credential ", cred.Tag, ": external credential requires token")
|
||||
}
|
||||
}
|
||||
if cred.Type == "balancer" {
|
||||
switch cred.BalancerOptions.Strategy {
|
||||
case "", "least_used", "round_robin", "random":
|
||||
default:
|
||||
return E.New("credential ", credential.Tag, ": unknown balancer strategy: ", credential.BalancerOptions.Strategy)
|
||||
return E.New("credential ", cred.Tag, ": unknown balancer strategy: ", cred.BalancerOptions.Strategy)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -905,63 +1033,25 @@ func validateCCMOptions(options option.CCMServiceOptions) error {
|
||||
if !tags[user.Credential] {
|
||||
return E.New("user ", user.Name, " references unknown credential: ", user.Credential)
|
||||
}
|
||||
if user.ExternalCredential != "" {
|
||||
if !tags[user.ExternalCredential] {
|
||||
return E.New("user ", user.Name, " references unknown external_credential: ", user.ExternalCredential)
|
||||
}
|
||||
if credentialTypes[user.ExternalCredential] != "external" {
|
||||
return E.New("user ", user.Name, ": external_credential must reference an external type credential")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// retryRequestWithBody re-sends a buffered request body using a different credential.
|
||||
func retryRequestWithBody(
|
||||
ctx context.Context,
|
||||
originalRequest *http.Request,
|
||||
bodyBytes []byte,
|
||||
credential *defaultCredential,
|
||||
httpHeaders http.Header,
|
||||
) (*http.Response, error) {
|
||||
accessToken, err := credential.getAccessToken()
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "get access token for ", credential.tag)
|
||||
}
|
||||
|
||||
proxyURL := claudeAPIBaseURL + originalRequest.URL.RequestURI()
|
||||
retryRequest, err := http.NewRequestWithContext(ctx, originalRequest.Method, proxyURL, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for key, values := range originalRequest.Header {
|
||||
if !isHopByHopHeader(key) && key != "Authorization" {
|
||||
retryRequest.Header[key] = values
|
||||
}
|
||||
}
|
||||
|
||||
serviceOverridesAcceptEncoding := len(httpHeaders.Values("Accept-Encoding")) > 0
|
||||
if credential.usageTracker != nil && !serviceOverridesAcceptEncoding {
|
||||
retryRequest.Header.Del("Accept-Encoding")
|
||||
}
|
||||
|
||||
anthropicBetaHeader := retryRequest.Header.Get("anthropic-beta")
|
||||
if anthropicBetaHeader != "" {
|
||||
retryRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue+","+anthropicBetaHeader)
|
||||
} else {
|
||||
retryRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue)
|
||||
}
|
||||
|
||||
for key, values := range httpHeaders {
|
||||
retryRequest.Header.Del(key)
|
||||
retryRequest.Header[key] = values
|
||||
}
|
||||
retryRequest.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
return credential.httpClient.Do(retryRequest)
|
||||
}
|
||||
|
||||
// credentialForUser finds the credential provider for a user.
|
||||
// In legacy mode, returns the single provider.
|
||||
// In multi-credential mode, returns the provider mapped to the user's credential tag.
|
||||
func credentialForUser(
|
||||
userCredentialMap map[string]string,
|
||||
userConfigMap map[string]*option.CCMUser,
|
||||
providers map[string]credentialProvider,
|
||||
legacyProvider credentialProvider,
|
||||
username string,
|
||||
@@ -969,13 +1059,13 @@ func credentialForUser(
|
||||
if legacyProvider != nil {
|
||||
return legacyProvider, nil
|
||||
}
|
||||
tag, exists := userCredentialMap[username]
|
||||
userConfig, exists := userConfigMap[username]
|
||||
if !exists {
|
||||
return nil, E.New("no credential mapping for user: ", username)
|
||||
}
|
||||
provider, exists := providers[tag]
|
||||
provider, exists := providers[userConfig.Credential]
|
||||
if !exists {
|
||||
return nil, E.New("unknown credential: ", tag)
|
||||
return nil, E.New("unknown credential: ", userConfig.Credential)
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
@@ -66,15 +66,18 @@ func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, erro
|
||||
})
|
||||
}
|
||||
|
||||
func hasAlternativeCredential(provider credentialProvider, currentCredential *defaultCredential) bool {
|
||||
func hasAlternativeCredential(provider credentialProvider, currentCredential credential, filter func(credential) bool) bool {
|
||||
if provider == nil || currentCredential == nil {
|
||||
return false
|
||||
}
|
||||
for _, credential := range provider.allDefaults() {
|
||||
if credential == currentCredential {
|
||||
for _, cred := range provider.allCredentials() {
|
||||
if cred == currentCredential {
|
||||
continue
|
||||
}
|
||||
if credential.isUsable() {
|
||||
if filter != nil && !filter(cred) {
|
||||
continue
|
||||
}
|
||||
if cred.isUsable() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -85,7 +88,7 @@ func unavailableCredentialMessage(provider credentialProvider, fallback string)
|
||||
if provider == nil {
|
||||
return fallback
|
||||
}
|
||||
return allCredentialsUnavailableError(provider.allDefaults()).Error()
|
||||
return allCredentialsUnavailableError(provider.allCredentials()).Error()
|
||||
}
|
||||
|
||||
func writeRetryableUsageError(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -100,10 +103,11 @@ func writeCredentialUnavailableError(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
provider credentialProvider,
|
||||
currentCredential *defaultCredential,
|
||||
currentCredential credential,
|
||||
filter func(credential) bool,
|
||||
fallback string,
|
||||
) {
|
||||
if hasAlternativeCredential(provider, currentCredential) {
|
||||
if hasAlternativeCredential(provider, currentCredential, filter) {
|
||||
writeRetryableUsageError(w, r)
|
||||
return
|
||||
}
|
||||
@@ -124,27 +128,15 @@ const (
|
||||
weeklyWindowMinutes = weeklyWindowSeconds / 60
|
||||
)
|
||||
|
||||
func parseInt64Header(headers http.Header, headerName string) (int64, bool) {
|
||||
headerValue := strings.TrimSpace(headers.Get(headerName))
|
||||
if headerValue == "" {
|
||||
return 0, false
|
||||
}
|
||||
parsedValue, parseError := strconv.ParseInt(headerValue, 10, 64)
|
||||
if parseError != nil {
|
||||
return 0, false
|
||||
}
|
||||
return parsedValue, true
|
||||
}
|
||||
|
||||
func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint {
|
||||
resetAtUnix, hasResetAt := parseInt64Header(headers, "anthropic-ratelimit-unified-7d-reset")
|
||||
if !hasResetAt || resetAtUnix <= 0 {
|
||||
resetAt, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset")
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &WeeklyCycleHint{
|
||||
WindowMinutes: weeklyWindowMinutes,
|
||||
ResetAt: time.Unix(resetAtUnix, 0).UTC(),
|
||||
ResetAt: resetAt.UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -166,9 +158,9 @@ type Service struct {
|
||||
legacyProvider credentialProvider
|
||||
|
||||
// Multi-credential mode
|
||||
providers map[string]credentialProvider
|
||||
allDefaults []*defaultCredential
|
||||
userCredentialMap map[string]string
|
||||
providers map[string]credentialProvider
|
||||
allCredentials []credential
|
||||
userConfigMap map[string]*option.CCMUser
|
||||
}
|
||||
|
||||
func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.CCMServiceOptions) (adapter.Service, error) {
|
||||
@@ -199,20 +191,20 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
|
||||
}
|
||||
|
||||
if len(options.Credentials) > 0 {
|
||||
providers, allDefaults, err := buildCredentialProviders(ctx, options, logger)
|
||||
providers, allCredentials, err := buildCredentialProviders(ctx, options, logger)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "build credential providers")
|
||||
}
|
||||
service.providers = providers
|
||||
service.allDefaults = allDefaults
|
||||
service.allCredentials = allCredentials
|
||||
|
||||
userCredentialMap := make(map[string]string)
|
||||
for _, user := range options.Users {
|
||||
userCredentialMap[user.Name] = user.Credential
|
||||
userConfigMap := make(map[string]*option.CCMUser)
|
||||
for i := range options.Users {
|
||||
userConfigMap[options.Users[i].Name] = &options.Users[i]
|
||||
}
|
||||
service.userCredentialMap = userCredentialMap
|
||||
service.userConfigMap = userConfigMap
|
||||
} else {
|
||||
credential, err := newDefaultCredential(ctx, "default", option.CCMDefaultCredentialOptions{
|
||||
cred, err := newDefaultCredential(ctx, "default", option.CCMDefaultCredentialOptions{
|
||||
CredentialPath: options.CredentialPath,
|
||||
UsagesPath: options.UsagesPath,
|
||||
Detour: options.Detour,
|
||||
@@ -220,9 +212,9 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
service.legacyCredential = credential
|
||||
service.legacyProvider = &singleCredentialProvider{credential: credential}
|
||||
service.allDefaults = []*defaultCredential{credential}
|
||||
service.legacyCredential = cred
|
||||
service.legacyProvider = &singleCredentialProvider{cred: cred}
|
||||
service.allCredentials = []credential{cred}
|
||||
}
|
||||
|
||||
if options.TLS != nil {
|
||||
@@ -243,8 +235,8 @@ func (s *Service) Start(stage adapter.StartStage) error {
|
||||
|
||||
s.userManager.UpdateUsers(s.options.Users)
|
||||
|
||||
for _, credential := range s.allDefaults {
|
||||
err := credential.start()
|
||||
for _, cred := range s.allCredentials {
|
||||
err := cred.start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -303,6 +295,11 @@ func detectContextWindow(betaHeader string, totalInputTokens int64) int {
|
||||
}
|
||||
|
||||
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/ccm/v1/status" {
|
||||
s.handleStatusEndpoint(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(r.URL.Path, "/v1/") {
|
||||
writeJSONError(w, r, http.StatusNotFound, "not_found_error", "Not found")
|
||||
return
|
||||
@@ -360,11 +357,13 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
}
|
||||
|
||||
// Resolve credential provider
|
||||
// Resolve credential provider and user config
|
||||
var provider credentialProvider
|
||||
var userConfig *option.CCMUser
|
||||
if len(s.options.Users) > 0 {
|
||||
userConfig = s.userConfigMap[username]
|
||||
var err error
|
||||
provider, err = credentialForUser(s.userCredentialMap, s.providers, s.legacyProvider, username)
|
||||
provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
|
||||
if err != nil {
|
||||
s.logger.Error("resolve credential: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
|
||||
@@ -389,70 +388,48 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
credential, isNew, err := provider.selectCredential(sessionID)
|
||||
var credentialFilter func(credential) bool
|
||||
if userConfig != nil && !userConfig.AllowExternalUsage {
|
||||
credentialFilter = func(c credential) bool { return !c.isExternal() }
|
||||
}
|
||||
|
||||
selectedCredential, isNew, err := provider.selectCredential(sessionID, credentialFilter)
|
||||
if err != nil {
|
||||
writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, err.Error()))
|
||||
return
|
||||
}
|
||||
if isNew {
|
||||
if username != "" {
|
||||
s.logger.Debug("assigned credential ", credential.tag, " for session ", sessionID, " by user ", username)
|
||||
s.logger.Debug("assigned credential ", selectedCredential.tagName(), " for session ", sessionID, " by user ", username)
|
||||
} else {
|
||||
s.logger.Debug("assigned credential ", credential.tag, " for session ", sessionID)
|
||||
s.logger.Debug("assigned credential ", selectedCredential.tagName(), " for session ", sessionID)
|
||||
}
|
||||
}
|
||||
|
||||
accessToken, err := credential.getAccessToken()
|
||||
if err != nil {
|
||||
s.logger.Error("get access token: ", err)
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "Authentication failed")
|
||||
if isExtendedContextRequest(anthropicBetaHeader) && selectedCredential.isExternal() {
|
||||
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
|
||||
"extended context (1m) requests cannot be proxied through external credentials")
|
||||
return
|
||||
}
|
||||
|
||||
proxyURL := claudeAPIBaseURL + r.URL.RequestURI()
|
||||
requestContext := credential.wrapRequestContext(r.Context())
|
||||
requestContext := selectedCredential.wrapRequestContext(r.Context())
|
||||
defer func() {
|
||||
requestContext.cancelRequest()
|
||||
}()
|
||||
proxyRequest, err := http.NewRequestWithContext(requestContext, r.Method, proxyURL, r.Body)
|
||||
proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||
if err != nil {
|
||||
s.logger.Error("create proxy request: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
|
||||
return
|
||||
}
|
||||
|
||||
for key, values := range r.Header {
|
||||
if !isHopByHopHeader(key) && key != "Authorization" {
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
}
|
||||
|
||||
hasUsageTracker := credential.usageTracker != nil
|
||||
serviceOverridesAcceptEncoding := len(s.httpHeaders.Values("Accept-Encoding")) > 0
|
||||
if hasUsageTracker && !serviceOverridesAcceptEncoding {
|
||||
proxyRequest.Header.Del("Accept-Encoding")
|
||||
}
|
||||
|
||||
if anthropicBetaHeader != "" {
|
||||
proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue+","+anthropicBetaHeader)
|
||||
} else {
|
||||
proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue)
|
||||
}
|
||||
|
||||
for key, values := range s.httpHeaders {
|
||||
proxyRequest.Header.Del(key)
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
|
||||
proxyRequest.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
response, err := credential.httpClient.Do(proxyRequest)
|
||||
response, err := selectedCredential.httpTransport().Do(proxyRequest)
|
||||
if err != nil {
|
||||
if r.Context().Err() != nil {
|
||||
return
|
||||
}
|
||||
if requestContext.Err() != nil {
|
||||
writeCredentialUnavailableError(w, r, provider, credential, "credential became unavailable while processing the request")
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "credential became unavailable while processing the request")
|
||||
return
|
||||
}
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
|
||||
@@ -463,24 +440,30 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Transparent 429 retry
|
||||
for response.StatusCode == http.StatusTooManyRequests {
|
||||
resetAt := parseRateLimitResetFromHeaders(response.Header)
|
||||
nextCredential := provider.onRateLimited(sessionID, credential, resetAt)
|
||||
credential.updateStateFromHeaders(response.Header)
|
||||
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, credentialFilter)
|
||||
selectedCredential.updateStateFromHeaders(response.Header)
|
||||
if bodyBytes == nil || nextCredential == nil {
|
||||
response.Body.Close()
|
||||
writeCredentialUnavailableError(w, r, provider, credential, "all credentials rate-limited")
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited")
|
||||
return
|
||||
}
|
||||
response.Body.Close()
|
||||
s.logger.Info("retrying with credential ", nextCredential.tag, " after 429 from ", credential.tag)
|
||||
s.logger.Info("retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
|
||||
requestContext.cancelRequest()
|
||||
requestContext = nextCredential.wrapRequestContext(r.Context())
|
||||
retryResponse, retryErr := retryRequestWithBody(requestContext, r, bodyBytes, nextCredential, s.httpHeaders)
|
||||
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||
if buildErr != nil {
|
||||
s.logger.Error("retry request: ", buildErr)
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
|
||||
return
|
||||
}
|
||||
retryResponse, retryErr := nextCredential.httpTransport().Do(retryRequest)
|
||||
if retryErr != nil {
|
||||
if r.Context().Err() != nil {
|
||||
return
|
||||
}
|
||||
if requestContext.Err() != nil {
|
||||
writeCredentialUnavailableError(w, r, provider, nextCredential, "credential became unavailable while retrying the request")
|
||||
writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request")
|
||||
return
|
||||
}
|
||||
s.logger.Error("retry request: ", retryErr)
|
||||
@@ -489,21 +472,24 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
requestContext.releaseCredentialInterrupt()
|
||||
response = retryResponse
|
||||
credential = nextCredential
|
||||
selectedCredential = nextCredential
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
credential.updateStateFromHeaders(response.Header)
|
||||
selectedCredential.updateStateFromHeaders(response.Header)
|
||||
|
||||
if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
s.logger.Error("upstream error from ", credential.tag, ": status ", response.StatusCode, " ", string(body))
|
||||
s.logger.Error("upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error",
|
||||
"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
|
||||
return
|
||||
}
|
||||
|
||||
hasUsageTracker = credential.usageTracker != nil
|
||||
// Rewrite response headers for external users
|
||||
if userConfig != nil && userConfig.ExternalCredential != "" {
|
||||
s.rewriteResponseHeadersForExternalUser(response.Header, userConfig)
|
||||
}
|
||||
|
||||
for key, values := range response.Header {
|
||||
if !isHopByHopHeader(key) {
|
||||
@@ -512,8 +498,9 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
w.WriteHeader(response.StatusCode)
|
||||
|
||||
if hasUsageTracker && response.StatusCode == http.StatusOK {
|
||||
s.handleResponseWithTracking(w, response, credential.usageTracker, requestModel, anthropicBetaHeader, messagesCount, username)
|
||||
usageTracker := selectedCredential.usageTrackerOrNil()
|
||||
if usageTracker != nil && response.StatusCode == http.StatusOK {
|
||||
s.handleResponseWithTracking(w, response, usageTracker, requestModel, anthropicBetaHeader, messagesCount, username)
|
||||
} else {
|
||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||
if err == nil && mediaType != "text/event-stream" {
|
||||
@@ -693,6 +680,91 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
if len(s.options.Users) == 0 {
|
||||
writeJSONError(w, r, http.StatusForbidden, "authentication_error", "status endpoint requires user authentication")
|
||||
return
|
||||
}
|
||||
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
||||
return
|
||||
}
|
||||
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if clientToken == authHeader {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
||||
return
|
||||
}
|
||||
username, ok := s.userManager.Authenticate(clientToken)
|
||||
if !ok {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
|
||||
return
|
||||
}
|
||||
|
||||
userConfig := s.userConfigMap[username]
|
||||
if userConfig == nil {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "user config not found")
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
|
||||
if err != nil {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
provider.pollIfStale(r.Context())
|
||||
avgFiveHour, avgWeekly := s.computeAggregatedUtilization(provider, userConfig)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]float64{
|
||||
"five_hour_utilization": avgFiveHour,
|
||||
"weekly_utilization": avgWeekly,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) (float64, float64) {
|
||||
var totalFiveHour, totalWeekly float64
|
||||
var count int
|
||||
for _, cred := range provider.allCredentials() {
|
||||
// Exclude the user's own external_credential (their contribution to us)
|
||||
if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential {
|
||||
continue
|
||||
}
|
||||
// If user doesn't allow external usage, exclude all external credentials
|
||||
if !userConfig.AllowExternalUsage && cred.isExternal() {
|
||||
continue
|
||||
}
|
||||
totalFiveHour += cred.fiveHourUtilization()
|
||||
totalWeekly += cred.weeklyUtilization()
|
||||
count++
|
||||
}
|
||||
if count == 0 {
|
||||
return 100, 100
|
||||
}
|
||||
return totalFiveHour / float64(count), totalWeekly / float64(count)
|
||||
}
|
||||
|
||||
func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.CCMUser) {
|
||||
provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, userConfig.Name)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
avgFiveHour, avgWeekly := s.computeAggregatedUtilization(provider, userConfig)
|
||||
|
||||
// Rewrite utilization headers to aggregated average (convert back to 0.0-1.0 range)
|
||||
headers.Set("anthropic-ratelimit-unified-5h-utilization", strconv.FormatFloat(avgFiveHour/100, 'f', 6, 64))
|
||||
headers.Set("anthropic-ratelimit-unified-7d-utilization", strconv.FormatFloat(avgWeekly/100, 'f', 6, 64))
|
||||
}
|
||||
|
||||
func (s *Service) Close() error {
|
||||
err := common.Close(
|
||||
common.PtrOrNil(s.httpServer),
|
||||
@@ -700,8 +772,8 @@ func (s *Service) Close() error {
|
||||
s.tlsConfig,
|
||||
)
|
||||
|
||||
for _, credential := range s.allDefaults {
|
||||
credential.close()
|
||||
for _, cred := range s.allCredentials {
|
||||
cred.close()
|
||||
}
|
||||
|
||||
return err
|
||||
|
||||
463
service/ocm/credential_external.go
Normal file
463
service/ocm/credential_external.go
Normal file
@@ -0,0 +1,463 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
stdTLS "crypto/tls"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
)
|
||||
|
||||
type externalCredential struct {
|
||||
tag string
|
||||
baseURL string
|
||||
token string
|
||||
credDialer N.Dialer
|
||||
httpClient *http.Client
|
||||
state credentialState
|
||||
stateMutex sync.RWMutex
|
||||
pollAccess sync.Mutex
|
||||
pollInterval time.Duration
|
||||
usageTracker *AggregatedUsage
|
||||
logger log.ContextLogger
|
||||
|
||||
onBecameUnusable func()
|
||||
interrupted bool
|
||||
requestContext context.Context
|
||||
cancelRequests context.CancelFunc
|
||||
requestAccess sync.Mutex
|
||||
}
|
||||
|
||||
func newExternalCredential(ctx context.Context, tag string, options option.OCMExternalCredentialOptions, logger log.ContextLogger) (*externalCredential, error) {
|
||||
parsedURL, err := url.Parse(options.URL)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "parse url for credential ", tag)
|
||||
}
|
||||
|
||||
credentialDialer, err := dialer.NewWithOptions(dialer.Options{
|
||||
Context: ctx,
|
||||
Options: option.DialerOptions{
|
||||
Detour: options.Detour,
|
||||
},
|
||||
RemoteIsDomain: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "create dialer for credential ", tag)
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
ForceAttemptHTTP2: true,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
if options.Server != "" {
|
||||
serverPort := options.ServerPort
|
||||
if serverPort == 0 {
|
||||
portStr := parsedURL.Port()
|
||||
if portStr != "" {
|
||||
port, parseErr := strconv.ParseUint(portStr, 10, 16)
|
||||
if parseErr == nil {
|
||||
serverPort = uint16(port)
|
||||
}
|
||||
}
|
||||
if serverPort == 0 {
|
||||
if parsedURL.Scheme == "https" {
|
||||
serverPort = 443
|
||||
} else {
|
||||
serverPort = 80
|
||||
}
|
||||
}
|
||||
}
|
||||
destination := M.ParseSocksaddrHostPort(options.Server, serverPort)
|
||||
return credentialDialer.DialContext(ctx, network, destination)
|
||||
}
|
||||
return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
},
|
||||
}
|
||||
|
||||
if parsedURL.Scheme == "https" {
|
||||
transport.TLSClientConfig = &stdTLS.Config{
|
||||
ServerName: parsedURL.Hostname(),
|
||||
RootCAs: adapter.RootPoolFromContext(ctx),
|
||||
Time: ntp.TimeFuncFromContext(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
baseURL := parsedURL.Scheme + "://" + parsedURL.Host
|
||||
if parsedURL.Path != "" && parsedURL.Path != "/" {
|
||||
baseURL += parsedURL.Path
|
||||
}
|
||||
if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' {
|
||||
baseURL = baseURL[:len(baseURL)-1]
|
||||
}
|
||||
|
||||
pollInterval := time.Duration(options.PollInterval)
|
||||
if pollInterval <= 0 {
|
||||
pollInterval = 30 * time.Minute
|
||||
}
|
||||
|
||||
requestContext, cancelRequests := context.WithCancel(context.Background())
|
||||
|
||||
cred := &externalCredential{
|
||||
tag: tag,
|
||||
baseURL: baseURL,
|
||||
token: options.Token,
|
||||
credDialer: credentialDialer,
|
||||
httpClient: &http.Client{Transport: transport},
|
||||
pollInterval: pollInterval,
|
||||
logger: logger,
|
||||
requestContext: requestContext,
|
||||
cancelRequests: cancelRequests,
|
||||
}
|
||||
|
||||
if options.UsagesPath != "" {
|
||||
cred.usageTracker = &AggregatedUsage{
|
||||
LastUpdated: time.Now(),
|
||||
Combinations: make([]CostCombination, 0),
|
||||
filePath: options.UsagesPath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
return cred, nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) start() error {
|
||||
if c.usageTracker != nil {
|
||||
err := c.usageTracker.Load()
|
||||
if err != nil {
|
||||
c.logger.Warn("load usage statistics for ", c.tag, ": ", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) setOnBecameUnusable(fn func()) {
|
||||
c.onBecameUnusable = fn
|
||||
}
|
||||
|
||||
func (c *externalCredential) tagName() string {
|
||||
return c.tag
|
||||
}
|
||||
|
||||
func (c *externalCredential) isExternal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *externalCredential) isUsable() bool {
|
||||
c.stateMutex.RLock()
|
||||
if c.state.hardRateLimited {
|
||||
if time.Now().Before(c.state.rateLimitResetAt) {
|
||||
c.stateMutex.RUnlock()
|
||||
return false
|
||||
}
|
||||
c.stateMutex.RUnlock()
|
||||
c.stateMutex.Lock()
|
||||
if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) {
|
||||
c.state.hardRateLimited = false
|
||||
}
|
||||
usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100
|
||||
c.stateMutex.Unlock()
|
||||
return usable
|
||||
}
|
||||
usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100
|
||||
c.stateMutex.RUnlock()
|
||||
return usable
|
||||
}
|
||||
|
||||
func (c *externalCredential) fiveHourUtilization() float64 {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
return c.state.fiveHourUtilization
|
||||
}
|
||||
|
||||
func (c *externalCredential) weeklyUtilization() float64 {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
return c.state.weeklyUtilization
|
||||
}
|
||||
|
||||
func (c *externalCredential) markRateLimited(resetAt time.Time) {
|
||||
c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt)))
|
||||
c.stateMutex.Lock()
|
||||
c.state.hardRateLimited = true
|
||||
c.state.rateLimitResetAt = resetAt
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) earliestReset() time.Time {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
if c.state.hardRateLimited {
|
||||
return c.state.rateLimitResetAt
|
||||
}
|
||||
earliest := c.state.fiveHourReset
|
||||
if !c.state.weeklyReset.IsZero() && (earliest.IsZero() || c.state.weeklyReset.Before(earliest)) {
|
||||
earliest = c.state.weeklyReset
|
||||
}
|
||||
return earliest
|
||||
}
|
||||
|
||||
func (c *externalCredential) getAccessToken() (string, error) {
|
||||
return c.token, nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, _ http.Header) (*http.Request, error) {
|
||||
proxyURL := c.baseURL + original.URL.RequestURI()
|
||||
var body io.Reader
|
||||
if bodyBytes != nil {
|
||||
body = bytes.NewReader(bodyBytes)
|
||||
} else {
|
||||
body = original.Body
|
||||
}
|
||||
proxyRequest, err := http.NewRequestWithContext(ctx, original.Method, proxyURL, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for key, values := range original.Header {
|
||||
if !isHopByHopHeader(key) && key != "Authorization" {
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
}
|
||||
|
||||
proxyRequest.Header.Set("Authorization", "Bearer "+c.token)
|
||||
|
||||
return proxyRequest, nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) updateStateFromHeaders(headers http.Header) {
|
||||
c.stateMutex.Lock()
|
||||
isFirstUpdate := c.state.lastUpdated.IsZero()
|
||||
oldFiveHour := c.state.fiveHourUtilization
|
||||
oldWeekly := c.state.weeklyUtilization
|
||||
|
||||
activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
|
||||
if activeLimitIdentifier == "" {
|
||||
activeLimitIdentifier = "codex"
|
||||
}
|
||||
|
||||
fiveHourPercent := headers.Get("x-" + activeLimitIdentifier + "-primary-used-percent")
|
||||
if fiveHourPercent != "" {
|
||||
value, err := strconv.ParseFloat(fiveHourPercent, 64)
|
||||
if err == nil {
|
||||
c.state.fiveHourUtilization = value
|
||||
}
|
||||
}
|
||||
fiveHourResetAt := headers.Get("x-" + activeLimitIdentifier + "-primary-reset-at")
|
||||
if fiveHourResetAt != "" {
|
||||
value, err := strconv.ParseInt(fiveHourResetAt, 10, 64)
|
||||
if err == nil {
|
||||
c.state.fiveHourReset = time.Unix(value, 0)
|
||||
}
|
||||
}
|
||||
weeklyPercent := headers.Get("x-" + activeLimitIdentifier + "-secondary-used-percent")
|
||||
if weeklyPercent != "" {
|
||||
value, err := strconv.ParseFloat(weeklyPercent, 64)
|
||||
if err == nil {
|
||||
c.state.weeklyUtilization = value
|
||||
}
|
||||
}
|
||||
weeklyResetAt := headers.Get("x-" + activeLimitIdentifier + "-secondary-reset-at")
|
||||
if weeklyResetAt != "" {
|
||||
value, err := strconv.ParseInt(weeklyResetAt, 10, 64)
|
||||
if err == nil {
|
||||
c.state.weeklyReset = time.Unix(value, 0)
|
||||
}
|
||||
}
|
||||
c.state.lastUpdated = time.Now()
|
||||
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
|
||||
c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%")
|
||||
}
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) checkTransitionLocked() bool {
|
||||
unusable := c.state.hardRateLimited || c.state.fiveHourUtilization >= 100 || c.state.weeklyUtilization >= 100
|
||||
if unusable && !c.interrupted {
|
||||
c.interrupted = true
|
||||
return true
|
||||
}
|
||||
if !unusable && c.interrupted {
|
||||
c.interrupted = false
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *externalCredential) wrapRequestContext(parent context.Context) *credentialRequestContext {
|
||||
c.requestAccess.Lock()
|
||||
credentialContext := c.requestContext
|
||||
c.requestAccess.Unlock()
|
||||
derived, cancel := context.WithCancel(parent)
|
||||
stop := context.AfterFunc(credentialContext, func() {
|
||||
cancel()
|
||||
})
|
||||
return &credentialRequestContext{
|
||||
Context: derived,
|
||||
releaseFunc: stop,
|
||||
cancelFunc: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) interruptConnections() {
|
||||
c.logger.Warn("interrupting connections for ", c.tag)
|
||||
c.requestAccess.Lock()
|
||||
c.cancelRequests()
|
||||
c.requestContext, c.cancelRequests = context.WithCancel(context.Background())
|
||||
c.requestAccess.Unlock()
|
||||
if c.onBecameUnusable != nil {
|
||||
c.onBecameUnusable()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) pollUsage(ctx context.Context) {
|
||||
if !c.pollAccess.TryLock() {
|
||||
return
|
||||
}
|
||||
defer c.pollAccess.Unlock()
|
||||
defer c.markUsagePollAttempted()
|
||||
|
||||
statusURL := c.baseURL + "/ocm/v1/status"
|
||||
httpClient := &http.Client{
|
||||
Transport: c.httpClient.Transport,
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, statusURL, nil)
|
||||
if err != nil {
|
||||
c.logger.Error("poll usage for ", c.tag, ": create request: ", err)
|
||||
return
|
||||
}
|
||||
request.Header.Set("Authorization", "Bearer "+c.token)
|
||||
|
||||
response, err := httpClient.Do(request)
|
||||
if err != nil {
|
||||
c.logger.Error("poll usage for ", c.tag, ": ", err)
|
||||
c.stateMutex.Lock()
|
||||
c.state.consecutivePollFailures++
|
||||
c.stateMutex.Unlock()
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
c.stateMutex.Lock()
|
||||
c.state.consecutivePollFailures++
|
||||
c.stateMutex.Unlock()
|
||||
c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
|
||||
return
|
||||
}
|
||||
|
||||
var statusResponse struct {
|
||||
FiveHourUtilization float64 `json:"five_hour_utilization"`
|
||||
WeeklyUtilization float64 `json:"weekly_utilization"`
|
||||
}
|
||||
err = json.NewDecoder(response.Body).Decode(&statusResponse)
|
||||
if err != nil {
|
||||
c.stateMutex.Lock()
|
||||
c.state.consecutivePollFailures++
|
||||
c.stateMutex.Unlock()
|
||||
c.logger.Debug("poll usage for ", c.tag, ": decode: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
c.stateMutex.Lock()
|
||||
isFirstUpdate := c.state.lastUpdated.IsZero()
|
||||
oldFiveHour := c.state.fiveHourUtilization
|
||||
oldWeekly := c.state.weeklyUtilization
|
||||
c.state.consecutivePollFailures = 0
|
||||
c.state.fiveHourUtilization = statusResponse.FiveHourUtilization
|
||||
c.state.weeklyUtilization = statusResponse.WeeklyUtilization
|
||||
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
|
||||
c.state.hardRateLimited = false
|
||||
}
|
||||
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
|
||||
c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%")
|
||||
}
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) lastUpdatedTime() time.Time {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
return c.state.lastUpdated
|
||||
}
|
||||
|
||||
func (c *externalCredential) markUsagePollAttempted() {
|
||||
c.stateMutex.Lock()
|
||||
defer c.stateMutex.Unlock()
|
||||
c.state.lastUpdated = time.Now()
|
||||
}
|
||||
|
||||
func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Duration {
|
||||
c.stateMutex.RLock()
|
||||
failures := c.state.consecutivePollFailures
|
||||
c.stateMutex.RUnlock()
|
||||
if failures <= 0 {
|
||||
return baseInterval
|
||||
}
|
||||
if failures > 4 {
|
||||
failures = 4
|
||||
}
|
||||
return baseInterval * time.Duration(1<<failures)
|
||||
}
|
||||
|
||||
func (c *externalCredential) usageTrackerOrNil() *AggregatedUsage {
|
||||
return c.usageTracker
|
||||
}
|
||||
|
||||
func (c *externalCredential) httpTransport() *http.Client {
|
||||
return c.httpClient
|
||||
}
|
||||
|
||||
func (c *externalCredential) ocmDialer() N.Dialer {
|
||||
return c.credDialer
|
||||
}
|
||||
|
||||
func (c *externalCredential) ocmIsAPIKeyMode() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *externalCredential) ocmGetAccountID() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *externalCredential) ocmGetBaseURL() string {
|
||||
return c.baseURL
|
||||
}
|
||||
|
||||
func (c *externalCredential) close() {
|
||||
if c.usageTracker != nil {
|
||||
c.usageTracker.cancelPendingSave()
|
||||
err := c.usageTracker.Save()
|
||||
if err != nil {
|
||||
c.logger.Error("save usage statistics for ", c.tag, ": ", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -82,6 +82,38 @@ func (c *credentialRequestContext) cancelRequest() {
|
||||
c.cancelOnce.Do(c.cancelFunc)
|
||||
}
|
||||
|
||||
type credential interface {
|
||||
tagName() string
|
||||
isUsable() bool
|
||||
isExternal() bool
|
||||
fiveHourUtilization() float64
|
||||
weeklyUtilization() float64
|
||||
markRateLimited(resetAt time.Time)
|
||||
earliestReset() time.Time
|
||||
|
||||
getAccessToken() (string, error)
|
||||
buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error)
|
||||
updateStateFromHeaders(header http.Header)
|
||||
|
||||
wrapRequestContext(ctx context.Context) *credentialRequestContext
|
||||
interruptConnections()
|
||||
|
||||
setOnBecameUnusable(fn func())
|
||||
start() error
|
||||
pollUsage(ctx context.Context)
|
||||
lastUpdatedTime() time.Time
|
||||
pollBackoff(base time.Duration) time.Duration
|
||||
usageTrackerOrNil() *AggregatedUsage
|
||||
httpTransport() *http.Client
|
||||
close()
|
||||
|
||||
// OCM-specific
|
||||
ocmDialer() N.Dialer
|
||||
ocmIsAPIKeyMode() bool
|
||||
ocmGetAccountID() string
|
||||
ocmGetBaseURL() string
|
||||
}
|
||||
|
||||
func newDefaultCredential(ctx context.Context, tag string, options option.OCMDefaultCredentialOptions, logger log.ContextLogger) (*defaultCredential, error) {
|
||||
credentialDialer, err := dialer.NewWithOptions(dialer.Options{
|
||||
Context: ctx,
|
||||
@@ -523,38 +555,132 @@ func (c *defaultCredential) close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *defaultCredential) setOnBecameUnusable(fn func()) {
|
||||
c.onBecameUnusable = fn
|
||||
}
|
||||
|
||||
func (c *defaultCredential) tagName() string {
|
||||
return c.tag
|
||||
}
|
||||
|
||||
func (c *defaultCredential) isExternal() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *defaultCredential) fiveHourUtilization() float64 {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
return c.state.fiveHourUtilization
|
||||
}
|
||||
|
||||
func (c *defaultCredential) usageTrackerOrNil() *AggregatedUsage {
|
||||
return c.usageTracker
|
||||
}
|
||||
|
||||
func (c *defaultCredential) httpTransport() *http.Client {
|
||||
return c.httpClient
|
||||
}
|
||||
|
||||
func (c *defaultCredential) ocmDialer() N.Dialer {
|
||||
return c.dialer
|
||||
}
|
||||
|
||||
func (c *defaultCredential) ocmIsAPIKeyMode() bool {
|
||||
return c.isAPIKeyMode()
|
||||
}
|
||||
|
||||
func (c *defaultCredential) ocmGetAccountID() string {
|
||||
return c.getAccountID()
|
||||
}
|
||||
|
||||
func (c *defaultCredential) ocmGetBaseURL() string {
|
||||
return c.getBaseURL()
|
||||
}
|
||||
|
||||
func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) {
|
||||
accessToken, err := c.getAccessToken()
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "get access token for ", c.tag)
|
||||
}
|
||||
|
||||
path := original.URL.Path
|
||||
var proxyPath string
|
||||
if c.isAPIKeyMode() {
|
||||
proxyPath = path
|
||||
} else {
|
||||
proxyPath = strings.TrimPrefix(path, "/v1")
|
||||
}
|
||||
|
||||
proxyURL := c.getBaseURL() + proxyPath
|
||||
if original.URL.RawQuery != "" {
|
||||
proxyURL += "?" + original.URL.RawQuery
|
||||
}
|
||||
|
||||
var body io.Reader
|
||||
if bodyBytes != nil {
|
||||
body = bytes.NewReader(bodyBytes)
|
||||
} else {
|
||||
body = original.Body
|
||||
}
|
||||
proxyRequest, err := http.NewRequestWithContext(ctx, original.Method, proxyURL, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for key, values := range original.Header {
|
||||
if !isHopByHopHeader(key) && key != "Authorization" {
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
}
|
||||
|
||||
for key, values := range serviceHeaders {
|
||||
proxyRequest.Header.Del(key)
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
proxyRequest.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
if accountID := c.getAccountID(); accountID != "" {
|
||||
proxyRequest.Header.Set("ChatGPT-Account-Id", accountID)
|
||||
}
|
||||
|
||||
return proxyRequest, nil
|
||||
}
|
||||
|
||||
type credentialProvider interface {
|
||||
selectCredential(sessionID string) (*defaultCredential, bool, error)
|
||||
onRateLimited(sessionID string, credential *defaultCredential, resetAt time.Time) *defaultCredential
|
||||
selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error)
|
||||
onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential
|
||||
pollIfStale(ctx context.Context)
|
||||
allDefaults() []*defaultCredential
|
||||
allCredentials() []credential
|
||||
close()
|
||||
}
|
||||
|
||||
type singleCredentialProvider struct {
|
||||
credential *defaultCredential
|
||||
cred credential
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) selectCredential(_ string) (*defaultCredential, bool, error) {
|
||||
if !p.credential.isUsable() {
|
||||
return nil, false, E.New("credential ", p.credential.tag, " is rate-limited")
|
||||
func (p *singleCredentialProvider) selectCredential(_ string, filter func(credential) bool) (credential, bool, error) {
|
||||
if filter != nil && !filter(p.cred) {
|
||||
return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out")
|
||||
}
|
||||
return p.credential, false, nil
|
||||
if !p.cred.isUsable() {
|
||||
return nil, false, E.New("credential ", p.cred.tagName(), " is rate-limited")
|
||||
}
|
||||
return p.cred, false, nil
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) onRateLimited(_ string, credential *defaultCredential, resetAt time.Time) *defaultCredential {
|
||||
credential.markRateLimited(resetAt)
|
||||
func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ func(credential) bool) credential {
|
||||
cred.markRateLimited(resetAt)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) pollIfStale(ctx context.Context) {
|
||||
if time.Since(p.credential.lastUpdatedTime()) > p.credential.pollBackoff(defaultPollInterval) {
|
||||
p.credential.pollUsage(ctx)
|
||||
if time.Since(p.cred.lastUpdatedTime()) > p.cred.pollBackoff(defaultPollInterval) {
|
||||
p.cred.pollUsage(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) allDefaults() []*defaultCredential {
|
||||
return []*defaultCredential{p.credential}
|
||||
func (p *singleCredentialProvider) allCredentials() []credential {
|
||||
return []credential{p.cred}
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) close() {}
|
||||
@@ -567,7 +693,7 @@ type sessionEntry struct {
|
||||
}
|
||||
|
||||
type balancerProvider struct {
|
||||
credentials []*defaultCredential
|
||||
credentials []credential
|
||||
strategy string
|
||||
roundRobinIndex atomic.Uint64
|
||||
pollInterval time.Duration
|
||||
@@ -576,7 +702,7 @@ type balancerProvider struct {
|
||||
logger log.ContextLogger
|
||||
}
|
||||
|
||||
func newBalancerProvider(credentials []*defaultCredential, strategy string, pollInterval time.Duration, logger log.ContextLogger) *balancerProvider {
|
||||
func newBalancerProvider(credentials []credential, strategy string, pollInterval time.Duration, logger log.ContextLogger) *balancerProvider {
|
||||
if pollInterval <= 0 {
|
||||
pollInterval = defaultPollInterval
|
||||
}
|
||||
@@ -589,15 +715,15 @@ func newBalancerProvider(credentials []*defaultCredential, strategy string, poll
|
||||
}
|
||||
}
|
||||
|
||||
func (p *balancerProvider) selectCredential(sessionID string) (*defaultCredential, bool, error) {
|
||||
func (p *balancerProvider) selectCredential(sessionID string, filter func(credential) bool) (credential, bool, error) {
|
||||
if sessionID != "" {
|
||||
p.sessionMutex.RLock()
|
||||
entry, exists := p.sessions[sessionID]
|
||||
p.sessionMutex.RUnlock()
|
||||
if exists {
|
||||
for _, credential := range p.credentials {
|
||||
if credential.tag == entry.tag && credential.isUsable() {
|
||||
return credential, false, nil
|
||||
for _, cred := range p.credentials {
|
||||
if cred.tagName() == entry.tag && (filter == nil || filter(cred)) && cred.isUsable() {
|
||||
return cred, false, nil
|
||||
}
|
||||
}
|
||||
p.sessionMutex.Lock()
|
||||
@@ -606,7 +732,7 @@ func (p *balancerProvider) selectCredential(sessionID string) (*defaultCredentia
|
||||
}
|
||||
}
|
||||
|
||||
best := p.pickCredential()
|
||||
best := p.pickCredential(filter)
|
||||
if best == nil {
|
||||
return nil, false, allRateLimitedError(p.credentials)
|
||||
}
|
||||
@@ -614,61 +740,67 @@ func (p *balancerProvider) selectCredential(sessionID string) (*defaultCredentia
|
||||
isNew := sessionID != ""
|
||||
if isNew {
|
||||
p.sessionMutex.Lock()
|
||||
p.sessions[sessionID] = sessionEntry{tag: best.tag, createdAt: time.Now()}
|
||||
p.sessions[sessionID] = sessionEntry{tag: best.tagName(), createdAt: time.Now()}
|
||||
p.sessionMutex.Unlock()
|
||||
}
|
||||
return best, isNew, nil
|
||||
}
|
||||
|
||||
func (p *balancerProvider) onRateLimited(sessionID string, credential *defaultCredential, resetAt time.Time) *defaultCredential {
|
||||
credential.markRateLimited(resetAt)
|
||||
func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, filter func(credential) bool) credential {
|
||||
cred.markRateLimited(resetAt)
|
||||
if sessionID != "" {
|
||||
p.sessionMutex.Lock()
|
||||
delete(p.sessions, sessionID)
|
||||
p.sessionMutex.Unlock()
|
||||
}
|
||||
|
||||
best := p.pickCredential()
|
||||
best := p.pickCredential(filter)
|
||||
if best != nil && sessionID != "" {
|
||||
p.sessionMutex.Lock()
|
||||
p.sessions[sessionID] = sessionEntry{tag: best.tag, createdAt: time.Now()}
|
||||
p.sessions[sessionID] = sessionEntry{tag: best.tagName(), createdAt: time.Now()}
|
||||
p.sessionMutex.Unlock()
|
||||
}
|
||||
return best
|
||||
}
|
||||
|
||||
func (p *balancerProvider) pickCredential() *defaultCredential {
|
||||
func (p *balancerProvider) pickCredential(filter func(credential) bool) credential {
|
||||
switch p.strategy {
|
||||
case "round_robin":
|
||||
return p.pickRoundRobin()
|
||||
return p.pickRoundRobin(filter)
|
||||
case "random":
|
||||
return p.pickRandom()
|
||||
return p.pickRandom(filter)
|
||||
default:
|
||||
return p.pickLeastUsed()
|
||||
return p.pickLeastUsed(filter)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *balancerProvider) pickLeastUsed() *defaultCredential {
|
||||
var best *defaultCredential
|
||||
func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credential {
|
||||
var best credential
|
||||
bestUtilization := float64(101)
|
||||
for _, credential := range p.credentials {
|
||||
if !credential.isUsable() {
|
||||
for _, cred := range p.credentials {
|
||||
if filter != nil && !filter(cred) {
|
||||
continue
|
||||
}
|
||||
utilization := credential.weeklyUtilization()
|
||||
if !cred.isUsable() {
|
||||
continue
|
||||
}
|
||||
utilization := cred.weeklyUtilization()
|
||||
if utilization < bestUtilization {
|
||||
bestUtilization = utilization
|
||||
best = credential
|
||||
best = cred
|
||||
}
|
||||
}
|
||||
return best
|
||||
}
|
||||
|
||||
func (p *balancerProvider) pickRoundRobin() *defaultCredential {
|
||||
func (p *balancerProvider) pickRoundRobin(filter func(credential) bool) credential {
|
||||
start := int(p.roundRobinIndex.Add(1) - 1)
|
||||
count := len(p.credentials)
|
||||
for offset := range count {
|
||||
candidate := p.credentials[(start+offset)%count]
|
||||
if filter != nil && !filter(candidate) {
|
||||
continue
|
||||
}
|
||||
if candidate.isUsable() {
|
||||
return candidate
|
||||
}
|
||||
@@ -676,9 +808,12 @@ func (p *balancerProvider) pickRoundRobin() *defaultCredential {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *balancerProvider) pickRandom() *defaultCredential {
|
||||
var usable []*defaultCredential
|
||||
func (p *balancerProvider) pickRandom(filter func(credential) bool) credential {
|
||||
var usable []credential
|
||||
for _, candidate := range p.credentials {
|
||||
if filter != nil && !filter(candidate) {
|
||||
continue
|
||||
}
|
||||
if candidate.isUsable() {
|
||||
usable = append(usable, candidate)
|
||||
}
|
||||
@@ -699,26 +834,26 @@ func (p *balancerProvider) pollIfStale(ctx context.Context) {
|
||||
}
|
||||
p.sessionMutex.Unlock()
|
||||
|
||||
for _, credential := range p.credentials {
|
||||
if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) {
|
||||
credential.pollUsage(ctx)
|
||||
for _, cred := range p.credentials {
|
||||
if time.Since(cred.lastUpdatedTime()) > cred.pollBackoff(p.pollInterval) {
|
||||
cred.pollUsage(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *balancerProvider) allDefaults() []*defaultCredential {
|
||||
func (p *balancerProvider) allCredentials() []credential {
|
||||
return p.credentials
|
||||
}
|
||||
|
||||
// close does nothing for balancerProvider; the method body is empty.
func (p *balancerProvider) close() {}
|
||||
|
||||
type fallbackProvider struct {
|
||||
credentials []*defaultCredential
|
||||
credentials []credential
|
||||
pollInterval time.Duration
|
||||
logger log.ContextLogger
|
||||
}
|
||||
|
||||
func newFallbackProvider(credentials []*defaultCredential, pollInterval time.Duration, logger log.ContextLogger) *fallbackProvider {
|
||||
func newFallbackProvider(credentials []credential, pollInterval time.Duration, logger log.ContextLogger) *fallbackProvider {
|
||||
if pollInterval <= 0 {
|
||||
pollInterval = defaultPollInterval
|
||||
}
|
||||
@@ -729,18 +864,24 @@ func newFallbackProvider(credentials []*defaultCredential, pollInterval time.Dur
|
||||
}
|
||||
}
|
||||
|
||||
func (p *fallbackProvider) selectCredential(_ string) (*defaultCredential, bool, error) {
|
||||
for _, credential := range p.credentials {
|
||||
if credential.isUsable() {
|
||||
return credential, false, nil
|
||||
func (p *fallbackProvider) selectCredential(_ string, filter func(credential) bool) (credential, bool, error) {
|
||||
for _, cred := range p.credentials {
|
||||
if filter != nil && !filter(cred) {
|
||||
continue
|
||||
}
|
||||
if cred.isUsable() {
|
||||
return cred, false, nil
|
||||
}
|
||||
}
|
||||
return nil, false, allRateLimitedError(p.credentials)
|
||||
}
|
||||
|
||||
func (p *fallbackProvider) onRateLimited(_ string, credential *defaultCredential, resetAt time.Time) *defaultCredential {
|
||||
credential.markRateLimited(resetAt)
|
||||
func (p *fallbackProvider) onRateLimited(_ string, cred credential, resetAt time.Time, filter func(credential) bool) credential {
|
||||
cred.markRateLimited(resetAt)
|
||||
for _, candidate := range p.credentials {
|
||||
if filter != nil && !filter(candidate) {
|
||||
continue
|
||||
}
|
||||
if candidate.isUsable() {
|
||||
return candidate
|
||||
}
|
||||
@@ -749,23 +890,23 @@ func (p *fallbackProvider) onRateLimited(_ string, credential *defaultCredential
|
||||
}
|
||||
|
||||
func (p *fallbackProvider) pollIfStale(ctx context.Context) {
|
||||
for _, credential := range p.credentials {
|
||||
if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(p.pollInterval) {
|
||||
credential.pollUsage(ctx)
|
||||
for _, cred := range p.credentials {
|
||||
if time.Since(cred.lastUpdatedTime()) > cred.pollBackoff(p.pollInterval) {
|
||||
cred.pollUsage(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *fallbackProvider) allDefaults() []*defaultCredential {
|
||||
func (p *fallbackProvider) allCredentials() []credential {
|
||||
return p.credentials
|
||||
}
|
||||
|
||||
// close does nothing for fallbackProvider; the method body is empty.
func (p *fallbackProvider) close() {}
|
||||
|
||||
func allRateLimitedError(credentials []*defaultCredential) error {
|
||||
func allRateLimitedError(credentials []credential) error {
|
||||
var earliest time.Time
|
||||
for _, credential := range credentials {
|
||||
resetAt := credential.earliestReset()
|
||||
for _, cred := range credentials {
|
||||
resetAt := cred.earliestReset()
|
||||
if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) {
|
||||
earliest = resetAt
|
||||
}
|
||||
@@ -780,34 +921,44 @@ func buildOCMCredentialProviders(
|
||||
ctx context.Context,
|
||||
options option.OCMServiceOptions,
|
||||
logger log.ContextLogger,
|
||||
) (map[string]credentialProvider, []*defaultCredential, error) {
|
||||
defaultCredentials := make(map[string]*defaultCredential)
|
||||
var allDefaults []*defaultCredential
|
||||
) (map[string]credentialProvider, []credential, error) {
|
||||
allCredentialMap := make(map[string]credential)
|
||||
var allCreds []credential
|
||||
providers := make(map[string]credentialProvider)
|
||||
|
||||
// Pass 1: create default and external credentials
|
||||
for _, credOpt := range options.Credentials {
|
||||
switch credOpt.Type {
|
||||
case "default":
|
||||
credential, err := newDefaultCredential(ctx, credOpt.Tag, credOpt.DefaultOptions, logger)
|
||||
cred, err := newDefaultCredential(ctx, credOpt.Tag, credOpt.DefaultOptions, logger)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defaultCredentials[credOpt.Tag] = credential
|
||||
allDefaults = append(allDefaults, credential)
|
||||
providers[credOpt.Tag] = &singleCredentialProvider{credential: credential}
|
||||
allCredentialMap[credOpt.Tag] = cred
|
||||
allCreds = append(allCreds, cred)
|
||||
providers[credOpt.Tag] = &singleCredentialProvider{cred: cred}
|
||||
case "external":
|
||||
cred, err := newExternalCredential(ctx, credOpt.Tag, credOpt.ExternalOptions, logger)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
allCredentialMap[credOpt.Tag] = cred
|
||||
allCreds = append(allCreds, cred)
|
||||
providers[credOpt.Tag] = &singleCredentialProvider{cred: cred}
|
||||
}
|
||||
}
|
||||
|
||||
// Pass 2: create balancer and fallback providers
|
||||
for _, credOpt := range options.Credentials {
|
||||
switch credOpt.Type {
|
||||
case "balancer":
|
||||
subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, defaultCredentials, credOpt.Tag)
|
||||
subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, allCredentialMap, credOpt.Tag)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
providers[credOpt.Tag] = newBalancerProvider(subCredentials, credOpt.BalancerOptions.Strategy, time.Duration(credOpt.BalancerOptions.PollInterval), logger)
|
||||
case "fallback":
|
||||
subCredentials, err := resolveCredentialTags(credOpt.FallbackOptions.Credentials, defaultCredentials, credOpt.Tag)
|
||||
subCredentials, err := resolveCredentialTags(credOpt.FallbackOptions.Credentials, allCredentialMap, credOpt.Tag)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -815,17 +966,17 @@ func buildOCMCredentialProviders(
|
||||
}
|
||||
}
|
||||
|
||||
return providers, allDefaults, nil
|
||||
return providers, allCreds, nil
|
||||
}
|
||||
|
||||
func resolveCredentialTags(tags []string, defaults map[string]*defaultCredential, parentTag string) ([]*defaultCredential, error) {
|
||||
credentials := make([]*defaultCredential, 0, len(tags))
|
||||
func resolveCredentialTags(tags []string, allCredentials map[string]credential, parentTag string) ([]credential, error) {
|
||||
credentials := make([]credential, 0, len(tags))
|
||||
for _, tag := range tags {
|
||||
credential, exists := defaults[tag]
|
||||
cred, exists := allCredentials[tag]
|
||||
if !exists {
|
||||
return nil, E.New("credential ", parentTag, " references unknown default credential: ", tag)
|
||||
return nil, E.New("credential ", parentTag, " references unknown credential: ", tag)
|
||||
}
|
||||
credentials = append(credentials, credential)
|
||||
credentials = append(credentials, cred)
|
||||
}
|
||||
if len(credentials) == 0 {
|
||||
return nil, E.New("credential ", parentTag, " has no sub-credentials")
|
||||
@@ -871,24 +1022,34 @@ func validateOCMOptions(options option.OCMServiceOptions) error {
|
||||
|
||||
if hasCredentials {
|
||||
tags := make(map[string]bool)
|
||||
for _, credential := range options.Credentials {
|
||||
if tags[credential.Tag] {
|
||||
return E.New("duplicate credential tag: ", credential.Tag)
|
||||
credentialTypes := make(map[string]string)
|
||||
for _, cred := range options.Credentials {
|
||||
if tags[cred.Tag] {
|
||||
return E.New("duplicate credential tag: ", cred.Tag)
|
||||
}
|
||||
tags[credential.Tag] = true
|
||||
if credential.Type == "default" || credential.Type == "" {
|
||||
if credential.DefaultOptions.Reserve5h > 99 {
|
||||
return E.New("credential ", credential.Tag, ": reserve_5h must be at most 99")
|
||||
tags[cred.Tag] = true
|
||||
credentialTypes[cred.Tag] = cred.Type
|
||||
if cred.Type == "default" || cred.Type == "" {
|
||||
if cred.DefaultOptions.Reserve5h > 99 {
|
||||
return E.New("credential ", cred.Tag, ": reserve_5h must be at most 99")
|
||||
}
|
||||
if credential.DefaultOptions.ReserveWeekly > 99 {
|
||||
return E.New("credential ", credential.Tag, ": reserve_weekly must be at most 99")
|
||||
if cred.DefaultOptions.ReserveWeekly > 99 {
|
||||
return E.New("credential ", cred.Tag, ": reserve_weekly must be at most 99")
|
||||
}
|
||||
}
|
||||
if credential.Type == "balancer" {
|
||||
switch credential.BalancerOptions.Strategy {
|
||||
if cred.Type == "external" {
|
||||
if cred.ExternalOptions.URL == "" {
|
||||
return E.New("credential ", cred.Tag, ": external credential requires url")
|
||||
}
|
||||
if cred.ExternalOptions.Token == "" {
|
||||
return E.New("credential ", cred.Tag, ": external credential requires token")
|
||||
}
|
||||
}
|
||||
if cred.Type == "balancer" {
|
||||
switch cred.BalancerOptions.Strategy {
|
||||
case "", "least_used", "round_robin", "random":
|
||||
default:
|
||||
return E.New("credential ", credential.Tag, ": unknown balancer strategy: ", credential.BalancerOptions.Strategy)
|
||||
return E.New("credential ", cred.Tag, ": unknown balancer strategy: ", cred.BalancerOptions.Strategy)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -900,6 +1061,14 @@ func validateOCMOptions(options option.OCMServiceOptions) error {
|
||||
if !tags[user.Credential] {
|
||||
return E.New("user ", user.Name, " references unknown credential: ", user.Credential)
|
||||
}
|
||||
if user.ExternalCredential != "" {
|
||||
if !tags[user.ExternalCredential] {
|
||||
return E.New("user ", user.Name, " references unknown external_credential: ", user.ExternalCredential)
|
||||
}
|
||||
if credentialTypes[user.ExternalCredential] != "external" {
|
||||
return E.New("user ", user.Name, ": external_credential must reference an external type credential")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -910,21 +1079,21 @@ func validateOCMCompositeCredentialModes(
|
||||
options option.OCMServiceOptions,
|
||||
providers map[string]credentialProvider,
|
||||
) error {
|
||||
for _, credential := range options.Credentials {
|
||||
if credential.Type != "balancer" && credential.Type != "fallback" {
|
||||
for _, credOpt := range options.Credentials {
|
||||
if credOpt.Type != "balancer" && credOpt.Type != "fallback" {
|
||||
continue
|
||||
}
|
||||
|
||||
provider, exists := providers[credential.Tag]
|
||||
provider, exists := providers[credOpt.Tag]
|
||||
if !exists {
|
||||
return E.New("unknown credential: ", credential.Tag)
|
||||
return E.New("unknown credential: ", credOpt.Tag)
|
||||
}
|
||||
|
||||
for _, subCredential := range provider.allDefaults() {
|
||||
if subCredential.isAPIKeyMode() {
|
||||
for _, subCred := range provider.allCredentials() {
|
||||
if subCred.ocmIsAPIKeyMode() {
|
||||
return E.New(
|
||||
"credential ", credential.Tag,
|
||||
" references API key default credential ", subCredential.tag,
|
||||
"credential ", credOpt.Tag,
|
||||
" references API key default credential ", subCred.tagName(),
|
||||
"; balancer and fallback only support OAuth default credentials",
|
||||
)
|
||||
}
|
||||
@@ -934,60 +1103,8 @@ func validateOCMCompositeCredentialModes(
|
||||
return nil
|
||||
}
|
||||
|
||||
func retryOCMRequestWithBody(
|
||||
ctx context.Context,
|
||||
originalRequest *http.Request,
|
||||
bodyBytes []byte,
|
||||
credential *defaultCredential,
|
||||
httpHeaders http.Header,
|
||||
) (*http.Response, error) {
|
||||
accessToken, err := credential.getAccessToken()
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "get access token for ", credential.tag)
|
||||
}
|
||||
|
||||
baseURL := credential.getBaseURL()
|
||||
path := originalRequest.URL.Path
|
||||
var proxyPath string
|
||||
if credential.isAPIKeyMode() {
|
||||
proxyPath = path
|
||||
} else {
|
||||
proxyPath = strings.TrimPrefix(path, "/v1")
|
||||
}
|
||||
|
||||
proxyURL := baseURL + proxyPath
|
||||
if originalRequest.URL.RawQuery != "" {
|
||||
proxyURL += "?" + originalRequest.URL.RawQuery
|
||||
}
|
||||
|
||||
var body io.Reader
|
||||
if bodyBytes != nil {
|
||||
body = bytes.NewReader(bodyBytes)
|
||||
}
|
||||
retryRequest, err := http.NewRequestWithContext(ctx, originalRequest.Method, proxyURL, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for key, values := range originalRequest.Header {
|
||||
if !isHopByHopHeader(key) && key != "Authorization" {
|
||||
retryRequest.Header[key] = values
|
||||
}
|
||||
}
|
||||
for key, values := range httpHeaders {
|
||||
retryRequest.Header.Del(key)
|
||||
retryRequest.Header[key] = values
|
||||
}
|
||||
retryRequest.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
if accountID := credential.getAccountID(); accountID != "" {
|
||||
retryRequest.Header.Set("ChatGPT-Account-Id", accountID)
|
||||
}
|
||||
|
||||
return credential.httpClient.Do(retryRequest)
|
||||
}
|
||||
|
||||
func credentialForUser(
|
||||
userCredentialMap map[string]string,
|
||||
userConfigMap map[string]*option.OCMUser,
|
||||
providers map[string]credentialProvider,
|
||||
legacyProvider credentialProvider,
|
||||
username string,
|
||||
@@ -995,13 +1112,13 @@ func credentialForUser(
|
||||
if legacyProvider != nil {
|
||||
return legacyProvider, nil
|
||||
}
|
||||
tag, exists := userCredentialMap[username]
|
||||
userConfig, exists := userConfigMap[username]
|
||||
if !exists {
|
||||
return nil, E.New("no credential mapping for user: ", username)
|
||||
}
|
||||
provider, exists := providers[tag]
|
||||
provider, exists := providers[userConfig.Credential]
|
||||
if !exists {
|
||||
return nil, E.New("unknown credential: ", tag)
|
||||
return nil, E.New("unknown credential: ", userConfig.Credential)
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
@@ -74,15 +74,18 @@ const (
|
||||
retryableUsageCode = "credential_usage_exhausted"
|
||||
)
|
||||
|
||||
func hasAlternativeCredential(provider credentialProvider, currentCredential *defaultCredential) bool {
|
||||
func hasAlternativeCredential(provider credentialProvider, currentCredential credential, filter func(credential) bool) bool {
|
||||
if provider == nil || currentCredential == nil {
|
||||
return false
|
||||
}
|
||||
for _, credential := range provider.allDefaults() {
|
||||
if credential == currentCredential {
|
||||
for _, cred := range provider.allCredentials() {
|
||||
if cred == currentCredential {
|
||||
continue
|
||||
}
|
||||
if credential.isUsable() {
|
||||
if filter != nil && !filter(cred) {
|
||||
continue
|
||||
}
|
||||
if cred.isUsable() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -93,7 +96,7 @@ func unavailableCredentialMessage(provider credentialProvider, fallback string)
|
||||
if provider == nil {
|
||||
return fallback
|
||||
}
|
||||
return allRateLimitedError(provider.allDefaults()).Error()
|
||||
return allRateLimitedError(provider.allCredentials()).Error()
|
||||
}
|
||||
|
||||
func writeRetryableUsageError(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -108,10 +111,11 @@ func writeCredentialUnavailableError(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
provider credentialProvider,
|
||||
currentCredential *defaultCredential,
|
||||
currentCredential credential,
|
||||
filter func(credential) bool,
|
||||
fallback string,
|
||||
) {
|
||||
if hasAlternativeCredential(provider, currentCredential) {
|
||||
if hasAlternativeCredential(provider, currentCredential, filter) {
|
||||
writeRetryableUsageError(w, r)
|
||||
return
|
||||
}
|
||||
@@ -198,9 +202,9 @@ type Service struct {
|
||||
legacyProvider credentialProvider
|
||||
|
||||
// Multi-credential mode
|
||||
providers map[string]credentialProvider
|
||||
allDefaults []*defaultCredential
|
||||
userCredentialMap map[string]string
|
||||
providers map[string]credentialProvider
|
||||
allCredentials []credential
|
||||
userConfigMap map[string]*option.OCMUser
|
||||
}
|
||||
|
||||
func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.OCMServiceOptions) (adapter.Service, error) {
|
||||
@@ -230,20 +234,20 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
|
||||
}
|
||||
|
||||
if len(options.Credentials) > 0 {
|
||||
providers, allDefaults, err := buildOCMCredentialProviders(ctx, options, logger)
|
||||
providers, allCredentials, err := buildOCMCredentialProviders(ctx, options, logger)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "build credential providers")
|
||||
}
|
||||
service.providers = providers
|
||||
service.allDefaults = allDefaults
|
||||
service.allCredentials = allCredentials
|
||||
|
||||
userCredentialMap := make(map[string]string)
|
||||
for _, user := range options.Users {
|
||||
userCredentialMap[user.Name] = user.Credential
|
||||
userConfigMap := make(map[string]*option.OCMUser)
|
||||
for i := range options.Users {
|
||||
userConfigMap[options.Users[i].Name] = &options.Users[i]
|
||||
}
|
||||
service.userCredentialMap = userCredentialMap
|
||||
service.userConfigMap = userConfigMap
|
||||
} else {
|
||||
credential, err := newDefaultCredential(ctx, "default", option.OCMDefaultCredentialOptions{
|
||||
cred, err := newDefaultCredential(ctx, "default", option.OCMDefaultCredentialOptions{
|
||||
CredentialPath: options.CredentialPath,
|
||||
UsagesPath: options.UsagesPath,
|
||||
Detour: options.Detour,
|
||||
@@ -251,9 +255,9 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
service.legacyCredential = credential
|
||||
service.legacyProvider = &singleCredentialProvider{credential: credential}
|
||||
service.allDefaults = []*defaultCredential{credential}
|
||||
service.legacyCredential = cred
|
||||
service.legacyProvider = &singleCredentialProvider{cred: cred}
|
||||
service.allCredentials = []credential{cred}
|
||||
}
|
||||
|
||||
if options.TLS != nil {
|
||||
@@ -274,15 +278,15 @@ func (s *Service) Start(stage adapter.StartStage) error {
|
||||
|
||||
s.userManager.UpdateUsers(s.options.Users)
|
||||
|
||||
for _, credential := range s.allDefaults {
|
||||
err := credential.start()
|
||||
for _, cred := range s.allCredentials {
|
||||
err := cred.start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tag := credential.tag
|
||||
credential.onBecameUnusable = func() {
|
||||
tag := cred.tagName()
|
||||
cred.setOnBecameUnusable(func() {
|
||||
s.interruptWebSocketSessionsForCredential(tag)
|
||||
}
|
||||
})
|
||||
}
|
||||
if len(s.options.Credentials) > 0 {
|
||||
err := validateOCMCompositeCredentialModes(s.options, s.providers)
|
||||
@@ -327,7 +331,7 @@ func (s *Service) Start(stage adapter.StartStage) error {
|
||||
|
||||
func (s *Service) resolveCredentialProvider(username string) (credentialProvider, error) {
|
||||
if len(s.options.Users) > 0 {
|
||||
return credentialForUser(s.userCredentialMap, s.providers, s.legacyProvider, username)
|
||||
return credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
|
||||
}
|
||||
provider := noUserCredentialProvider(s.providers, s.legacyProvider, s.options)
|
||||
if provider == nil {
|
||||
@@ -337,6 +341,11 @@ func (s *Service) resolveCredentialProvider(username string) (credentialProvider
|
||||
}
|
||||
|
||||
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/ocm/v1/status" {
|
||||
s.handleStatusEndpoint(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
path := r.URL.Path
|
||||
if !strings.HasPrefix(path, "/v1/") {
|
||||
writeJSONError(w, r, http.StatusNotFound, "invalid_request_error", "path must start with /v1/")
|
||||
@@ -368,49 +377,64 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
sessionID := r.Header.Get("session_id")
|
||||
|
||||
// Resolve credential provider
|
||||
provider, err := s.resolveCredentialProvider(username)
|
||||
if err != nil {
|
||||
s.logger.Error("resolve credential: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
|
||||
// Resolve credential provider and user config
|
||||
var provider credentialProvider
|
||||
var userConfig *option.OCMUser
|
||||
if len(s.options.Users) > 0 {
|
||||
userConfig = s.userConfigMap[username]
|
||||
var err error
|
||||
provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
|
||||
if err != nil {
|
||||
s.logger.Error("resolve credential: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
} else {
|
||||
provider = noUserCredentialProvider(s.providers, s.legacyProvider, s.options)
|
||||
}
|
||||
if provider == nil {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available")
|
||||
return
|
||||
}
|
||||
|
||||
provider.pollIfStale(s.ctx)
|
||||
|
||||
credential, isNew, err := provider.selectCredential(sessionID)
|
||||
var credentialFilter func(credential) bool
|
||||
if userConfig != nil && !userConfig.AllowExternalUsage {
|
||||
credentialFilter = func(c credential) bool { return !c.isExternal() }
|
||||
}
|
||||
|
||||
selectedCredential, isNew, err := provider.selectCredential(sessionID, credentialFilter)
|
||||
if err != nil {
|
||||
writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, err.Error()))
|
||||
return
|
||||
}
|
||||
if isNew {
|
||||
if username != "" {
|
||||
s.logger.Debug("assigned credential ", credential.tag, " for session ", sessionID, " by user ", username)
|
||||
s.logger.Debug("assigned credential ", selectedCredential.tagName(), " for session ", sessionID, " by user ", username)
|
||||
} else {
|
||||
s.logger.Debug("assigned credential ", credential.tag, " for session ", sessionID)
|
||||
s.logger.Debug("assigned credential ", selectedCredential.tagName(), " for session ", sessionID)
|
||||
}
|
||||
}
|
||||
|
||||
if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") {
|
||||
s.handleWebSocket(w, r, path, username, sessionID, provider, credential)
|
||||
s.handleWebSocket(w, r, path, username, sessionID, userConfig, provider, selectedCredential, credentialFilter)
|
||||
return
|
||||
}
|
||||
|
||||
var proxyPath string
|
||||
if credential.isAPIKeyMode() {
|
||||
proxyPath = path
|
||||
} else {
|
||||
if !selectedCredential.isExternal() && selectedCredential.ocmIsAPIKeyMode() {
|
||||
// API key mode path handling
|
||||
} else if !selectedCredential.isExternal() {
|
||||
if path == "/v1/chat/completions" {
|
||||
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
|
||||
"chat completions endpoint is only available in API key mode")
|
||||
return
|
||||
}
|
||||
proxyPath = strings.TrimPrefix(path, "/v1")
|
||||
}
|
||||
|
||||
shouldTrackUsage := credential.usageTracker != nil &&
|
||||
shouldTrackUsage := selectedCredential.usageTrackerOrNil() != nil &&
|
||||
(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses"))
|
||||
canRetryRequest := len(provider.allDefaults()) > 1
|
||||
canRetryRequest := len(provider.allCredentials()) > 1
|
||||
|
||||
// Read body for model extraction and retry buffer when JSON replay is useful.
|
||||
var bodyBytes []byte
|
||||
@@ -435,52 +459,24 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
accessToken, err := credential.getAccessToken()
|
||||
if err != nil {
|
||||
s.logger.Error("get access token: ", err)
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "Authentication failed")
|
||||
return
|
||||
}
|
||||
|
||||
proxyURL := credential.getBaseURL() + proxyPath
|
||||
if r.URL.RawQuery != "" {
|
||||
proxyURL += "?" + r.URL.RawQuery
|
||||
}
|
||||
requestContext := credential.wrapRequestContext(r.Context())
|
||||
requestContext := selectedCredential.wrapRequestContext(r.Context())
|
||||
defer func() {
|
||||
requestContext.cancelRequest()
|
||||
}()
|
||||
proxyRequest, err := http.NewRequestWithContext(requestContext, r.Method, proxyURL, r.Body)
|
||||
proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||
if err != nil {
|
||||
s.logger.Error("create proxy request: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
|
||||
return
|
||||
}
|
||||
|
||||
for key, values := range r.Header {
|
||||
if !isHopByHopHeader(key) && key != "Authorization" {
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
}
|
||||
|
||||
for key, values := range s.httpHeaders {
|
||||
proxyRequest.Header.Del(key)
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
|
||||
proxyRequest.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
if accountID := credential.getAccountID(); accountID != "" {
|
||||
proxyRequest.Header.Set("ChatGPT-Account-Id", accountID)
|
||||
}
|
||||
|
||||
response, err := credential.httpClient.Do(proxyRequest)
|
||||
response, err := selectedCredential.httpTransport().Do(proxyRequest)
|
||||
if err != nil {
|
||||
if r.Context().Err() != nil {
|
||||
return
|
||||
}
|
||||
if requestContext.Err() != nil {
|
||||
writeCredentialUnavailableError(w, r, provider, credential, "credential became unavailable while processing the request")
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "credential became unavailable while processing the request")
|
||||
return
|
||||
}
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
|
||||
@@ -491,25 +487,31 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Transparent 429 retry
|
||||
for response.StatusCode == http.StatusTooManyRequests {
|
||||
resetAt := parseOCMRateLimitResetFromHeaders(response.Header)
|
||||
nextCredential := provider.onRateLimited(sessionID, credential, resetAt)
|
||||
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, credentialFilter)
|
||||
needsBodyReplay := r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodDelete
|
||||
credential.updateStateFromHeaders(response.Header)
|
||||
selectedCredential.updateStateFromHeaders(response.Header)
|
||||
if (needsBodyReplay && bodyBytes == nil) || nextCredential == nil {
|
||||
response.Body.Close()
|
||||
writeCredentialUnavailableError(w, r, provider, credential, "all credentials rate-limited")
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited")
|
||||
return
|
||||
}
|
||||
response.Body.Close()
|
||||
s.logger.Info("retrying with credential ", nextCredential.tag, " after 429 from ", credential.tag)
|
||||
s.logger.Info("retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
|
||||
requestContext.cancelRequest()
|
||||
requestContext = nextCredential.wrapRequestContext(r.Context())
|
||||
retryResponse, retryErr := retryOCMRequestWithBody(requestContext, r, bodyBytes, nextCredential, s.httpHeaders)
|
||||
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||
if buildErr != nil {
|
||||
s.logger.Error("retry request: ", buildErr)
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
|
||||
return
|
||||
}
|
||||
retryResponse, retryErr := nextCredential.httpTransport().Do(retryRequest)
|
||||
if retryErr != nil {
|
||||
if r.Context().Err() != nil {
|
||||
return
|
||||
}
|
||||
if requestContext.Err() != nil {
|
||||
writeCredentialUnavailableError(w, r, provider, nextCredential, "credential became unavailable while retrying the request")
|
||||
writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request")
|
||||
return
|
||||
}
|
||||
s.logger.Error("retry request: ", retryErr)
|
||||
@@ -518,20 +520,25 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
requestContext.releaseCredentialInterrupt()
|
||||
response = retryResponse
|
||||
credential = nextCredential
|
||||
selectedCredential = nextCredential
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
credential.updateStateFromHeaders(response.Header)
|
||||
selectedCredential.updateStateFromHeaders(response.Header)
|
||||
|
||||
if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
s.logger.Error("upstream error from ", credential.tag, ": status ", response.StatusCode, " ", string(body))
|
||||
s.logger.Error("upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error",
|
||||
"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
|
||||
return
|
||||
}
|
||||
|
||||
// Rewrite response headers for external users
|
||||
if userConfig != nil && userConfig.ExternalCredential != "" {
|
||||
s.rewriteResponseHeadersForExternalUser(response.Header, userConfig)
|
||||
}
|
||||
|
||||
for key, values := range response.Header {
|
||||
if !isHopByHopHeader(key) {
|
||||
w.Header()[key] = values
|
||||
@@ -539,10 +546,10 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
w.WriteHeader(response.StatusCode)
|
||||
|
||||
hasUsageTracker := credential.usageTracker != nil
|
||||
if hasUsageTracker && response.StatusCode == http.StatusOK &&
|
||||
usageTracker := selectedCredential.usageTrackerOrNil()
|
||||
if usageTracker != nil && response.StatusCode == http.StatusOK &&
|
||||
(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) {
|
||||
s.handleResponseWithTracking(w, response, credential.usageTracker, path, requestModel, username)
|
||||
s.handleResponseWithTracking(w, response, usageTracker, path, requestModel, username)
|
||||
} else {
|
||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||
if err == nil && mediaType != "text/event-stream" {
|
||||
@@ -745,6 +752,93 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
if len(s.options.Users) == 0 {
|
||||
writeJSONError(w, r, http.StatusForbidden, "authentication_error", "status endpoint requires user authentication")
|
||||
return
|
||||
}
|
||||
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
||||
return
|
||||
}
|
||||
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if clientToken == authHeader {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
||||
return
|
||||
}
|
||||
username, ok := s.userManager.Authenticate(clientToken)
|
||||
if !ok {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
|
||||
return
|
||||
}
|
||||
|
||||
userConfig := s.userConfigMap[username]
|
||||
if userConfig == nil {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "user config not found")
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
|
||||
if err != nil {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
provider.pollIfStale(r.Context())
|
||||
avgFiveHour, avgWeekly := s.computeAggregatedUtilization(provider, userConfig)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]float64{
|
||||
"five_hour_utilization": avgFiveHour,
|
||||
"weekly_utilization": avgWeekly,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) (float64, float64) {
|
||||
var totalFiveHour, totalWeekly float64
|
||||
var count int
|
||||
for _, cred := range provider.allCredentials() {
|
||||
if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential {
|
||||
continue
|
||||
}
|
||||
if !userConfig.AllowExternalUsage && cred.isExternal() {
|
||||
continue
|
||||
}
|
||||
totalFiveHour += cred.fiveHourUtilization()
|
||||
totalWeekly += cred.weeklyUtilization()
|
||||
count++
|
||||
}
|
||||
if count == 0 {
|
||||
return 100, 100
|
||||
}
|
||||
return totalFiveHour / float64(count), totalWeekly / float64(count)
|
||||
}
|
||||
|
||||
func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.OCMUser) {
|
||||
provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, userConfig.Name)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
avgFiveHour, avgWeekly := s.computeAggregatedUtilization(provider, userConfig)
|
||||
|
||||
activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
|
||||
if activeLimitIdentifier == "" {
|
||||
activeLimitIdentifier = "codex"
|
||||
}
|
||||
|
||||
headers.Set("x-"+activeLimitIdentifier+"-primary-used-percent", strconv.FormatFloat(avgFiveHour, 'f', 2, 64))
|
||||
headers.Set("x-"+activeLimitIdentifier+"-secondary-used-percent", strconv.FormatFloat(avgWeekly, 'f', 2, 64))
|
||||
}
|
||||
|
||||
func (s *Service) Close() error {
|
||||
webSocketSessions := s.startWebSocketShutdown()
|
||||
|
||||
@@ -758,8 +852,8 @@ func (s *Service) Close() error {
|
||||
}
|
||||
s.webSocketGroup.Wait()
|
||||
|
||||
for _, credential := range s.allDefaults {
|
||||
credential.close()
|
||||
for _, cred := range s.allCredentials {
|
||||
cred.close()
|
||||
}
|
||||
|
||||
return err
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
@@ -85,8 +86,10 @@ func (s *Service) handleWebSocket(
|
||||
path string,
|
||||
username string,
|
||||
sessionID string,
|
||||
userConfig *option.OCMUser,
|
||||
provider credentialProvider,
|
||||
credential *defaultCredential,
|
||||
selectedCredential credential,
|
||||
credentialFilter func(credential) bool,
|
||||
) {
|
||||
var (
|
||||
err error
|
||||
@@ -97,7 +100,7 @@ func (s *Service) handleWebSocket(
|
||||
)
|
||||
|
||||
for {
|
||||
accessToken, accessErr := credential.getAccessToken()
|
||||
accessToken, accessErr := selectedCredential.getAccessToken()
|
||||
if accessErr != nil {
|
||||
s.logger.Error("get access token for websocket: ", accessErr)
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "authentication failed")
|
||||
@@ -105,13 +108,13 @@ func (s *Service) handleWebSocket(
|
||||
}
|
||||
|
||||
var proxyPath string
|
||||
if credential.isAPIKeyMode() {
|
||||
if selectedCredential.ocmIsAPIKeyMode() || selectedCredential.isExternal() {
|
||||
proxyPath = path
|
||||
} else {
|
||||
proxyPath = strings.TrimPrefix(path, "/v1")
|
||||
}
|
||||
|
||||
upstreamURL := buildUpstreamWebSocketURL(credential.getBaseURL(), proxyPath)
|
||||
upstreamURL := buildUpstreamWebSocketURL(selectedCredential.ocmGetBaseURL(), proxyPath)
|
||||
if r.URL.RawQuery != "" {
|
||||
upstreamURL += "?" + r.URL.RawQuery
|
||||
}
|
||||
@@ -127,7 +130,7 @@ func (s *Service) handleWebSocket(
|
||||
upstreamHeaders[key] = values
|
||||
}
|
||||
upstreamHeaders.Set("Authorization", "Bearer "+accessToken)
|
||||
if accountID := credential.getAccountID(); accountID != "" {
|
||||
if accountID := selectedCredential.ocmGetAccountID(); accountID != "" {
|
||||
upstreamHeaders.Set("ChatGPT-Account-Id", accountID)
|
||||
}
|
||||
|
||||
@@ -135,7 +138,7 @@ func (s *Service) handleWebSocket(
|
||||
statusCode = 0
|
||||
upstreamDialer := ws.Dialer{
|
||||
NetDial: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return credential.dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
return selectedCredential.ocmDialer().DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
},
|
||||
TLSConfig: &stdTLS.Config{
|
||||
RootCAs: adapter.RootPoolFromContext(s.ctx),
|
||||
@@ -170,14 +173,14 @@ func (s *Service) handleWebSocket(
|
||||
}
|
||||
if statusCode == http.StatusTooManyRequests {
|
||||
resetAt := parseOCMRateLimitResetFromHeaders(upstreamResponseHeaders)
|
||||
nextCredential := provider.onRateLimited(sessionID, credential, resetAt)
|
||||
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, credentialFilter)
|
||||
if nextCredential == nil {
|
||||
credential.updateStateFromHeaders(upstreamResponseHeaders)
|
||||
writeCredentialUnavailableError(w, r, provider, credential, "all credentials rate-limited")
|
||||
selectedCredential.updateStateFromHeaders(upstreamResponseHeaders)
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited")
|
||||
return
|
||||
}
|
||||
s.logger.Info("retrying websocket with credential ", nextCredential.tag, " after 429 from ", credential.tag)
|
||||
credential = nextCredential
|
||||
s.logger.Info("retrying websocket with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
|
||||
selectedCredential = nextCredential
|
||||
continue
|
||||
}
|
||||
s.logger.Error("dial upstream websocket: ", err)
|
||||
@@ -185,15 +188,18 @@ func (s *Service) handleWebSocket(
|
||||
return
|
||||
}
|
||||
|
||||
credential.updateStateFromHeaders(upstreamResponseHeaders)
|
||||
selectedCredential.updateStateFromHeaders(upstreamResponseHeaders)
|
||||
weeklyCycleHint := extractWeeklyCycleHint(upstreamResponseHeaders)
|
||||
|
||||
clientResponseHeaders := make(http.Header)
|
||||
for key, values := range upstreamResponseHeaders {
|
||||
if isForwardableResponseHeader(key) {
|
||||
clientResponseHeaders[key] = values
|
||||
clientResponseHeaders[key] = append([]string(nil), values...)
|
||||
}
|
||||
}
|
||||
if userConfig != nil && userConfig.ExternalCredential != "" {
|
||||
s.rewriteResponseHeadersForExternalUser(clientResponseHeaders, userConfig)
|
||||
}
|
||||
|
||||
clientUpgrader := ws.HTTPUpgrader{
|
||||
Header: clientResponseHeaders,
|
||||
@@ -212,7 +218,7 @@ func (s *Service) handleWebSocket(
|
||||
session := &webSocketSession{
|
||||
clientConn: clientConn,
|
||||
upstreamConn: upstreamConn,
|
||||
credentialTag: credential.tag,
|
||||
credentialTag: selectedCredential.tagName(),
|
||||
}
|
||||
if !s.registerWebSocketSession(session) {
|
||||
session.Close()
|
||||
@@ -237,17 +243,17 @@ func (s *Service) handleWebSocket(
|
||||
go func() {
|
||||
defer waitGroup.Done()
|
||||
defer session.Close()
|
||||
s.proxyWebSocketClientToUpstream(clientConn, upstreamConn, credential, modelChannel)
|
||||
s.proxyWebSocketClientToUpstream(clientConn, upstreamConn, selectedCredential, modelChannel)
|
||||
}()
|
||||
go func() {
|
||||
defer waitGroup.Done()
|
||||
defer session.Close()
|
||||
s.proxyWebSocketUpstreamToClient(upstreamReadWriter, clientConn, credential, modelChannel, username, weeklyCycleHint)
|
||||
s.proxyWebSocketUpstreamToClient(upstreamReadWriter, clientConn, selectedCredential, modelChannel, username, weeklyCycleHint)
|
||||
}()
|
||||
waitGroup.Wait()
|
||||
}
|
||||
|
||||
func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamConn net.Conn, credential *defaultCredential, modelChannel chan<- string) {
|
||||
func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamConn net.Conn, selectedCredential credential, modelChannel chan<- string) {
|
||||
for {
|
||||
data, opCode, err := wsutil.ReadClientData(clientConn)
|
||||
if err != nil {
|
||||
@@ -257,7 +263,7 @@ func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamCo
|
||||
return
|
||||
}
|
||||
|
||||
if opCode == ws.OpText && credential.usageTracker != nil {
|
||||
if opCode == ws.OpText && selectedCredential.usageTrackerOrNil() != nil {
|
||||
var request struct {
|
||||
Type string `json:"type"`
|
||||
Model string `json:"model"`
|
||||
@@ -280,7 +286,8 @@ func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamCo
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWriter, clientConn net.Conn, credential *defaultCredential, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) {
|
||||
func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWriter, clientConn net.Conn, selectedCredential credential, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) {
|
||||
usageTracker := selectedCredential.usageTrackerOrNil()
|
||||
var requestModel string
|
||||
for {
|
||||
data, opCode, err := wsutil.ReadServerData(upstreamReadWriter)
|
||||
@@ -291,7 +298,7 @@ func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWrite
|
||||
return
|
||||
}
|
||||
|
||||
if opCode == ws.OpText && credential.usageTracker != nil {
|
||||
if opCode == ws.OpText && usageTracker != nil {
|
||||
select {
|
||||
case model := <-modelChannel:
|
||||
requestModel = model
|
||||
@@ -317,7 +324,7 @@ func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWrite
|
||||
}
|
||||
if responseModel != "" {
|
||||
contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens)
|
||||
credential.usageTracker.AddUsageWithCycleHint(
|
||||
usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
contextWindow,
|
||||
inputTokens,
|
||||
|
||||
Reference in New Issue
Block a user