mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-13 20:28:32 +10:00
ccm,ocm: reorganize files and improve naming conventions
Split credential_state.go (1500+ lines) into credential.go, credential_default.go, credential_provider.go, credential_builder.go. Split service.go (900+ lines) into service.go, service_handler.go, service_status.go. Rename credential.go to credential_oauth.go to avoid name conflict with the credential interface. Apply naming fixes: accessMutex→access, stateMutex→stateAccess, sessionMutex→sessionAccess, webSocketMutex→webSocketAccess, httpTransport()→httpClient(), httpClient field→forwardHTTPClient, weeklyWindowDuration→weeklyWindowHours.
This commit is contained in:
@@ -1,224 +1,187 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
const (
|
||||
oauth2ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
oauth2TokenURL = "https://platform.claude.com/v1/oauth/token"
|
||||
claudeAPIBaseURL = "https://api.anthropic.com"
|
||||
tokenRefreshBufferMs = 60000
|
||||
anthropicBetaOAuthValue = "oauth-2025-04-20"
|
||||
defaultPollInterval = 60 * time.Minute
|
||||
failedPollRetryInterval = time.Minute
|
||||
httpRetryMaxBackoff = 5 * time.Minute
|
||||
)
|
||||
|
||||
const ccmUserAgentFallback = "claude-code/2.1.72"
|
||||
|
||||
var (
|
||||
ccmUserAgentOnce sync.Once
|
||||
ccmUserAgentValue string
|
||||
const (
|
||||
httpRetryMaxAttempts = 3
|
||||
httpRetryInitialDelay = 200 * time.Millisecond
|
||||
)
|
||||
|
||||
func initCCMUserAgent(logger log.ContextLogger) {
|
||||
ccmUserAgentOnce.Do(func() {
|
||||
version, err := detectClaudeCodeVersion()
|
||||
if err != nil {
|
||||
logger.Error("detect Claude Code version: ", err)
|
||||
ccmUserAgentValue = ccmUserAgentFallback
|
||||
return
|
||||
const sessionExpiry = 24 * time.Hour
|
||||
|
||||
func doHTTPWithRetry(ctx context.Context, client *http.Client, buildRequest func() (*http.Request, error)) (*http.Response, error) {
|
||||
var lastError error
|
||||
for attempt := range httpRetryMaxAttempts {
|
||||
if attempt > 0 {
|
||||
delay := httpRetryInitialDelay * time.Duration(1<<(attempt-1))
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, lastError
|
||||
case <-time.After(delay):
|
||||
}
|
||||
}
|
||||
logger.Debug("detected Claude Code version: ", version)
|
||||
ccmUserAgentValue = "claude-code/" + version
|
||||
})
|
||||
}
|
||||
|
||||
func detectClaudeCodeVersion() (string, error) {
|
||||
userInfo, err := getRealUser()
|
||||
if err != nil {
|
||||
return "", E.Cause(err, "get user")
|
||||
}
|
||||
binaryName := "claude"
|
||||
if runtime.GOOS == "windows" {
|
||||
binaryName = "claude.exe"
|
||||
}
|
||||
linkPath := filepath.Join(userInfo.HomeDir, ".local", "bin", binaryName)
|
||||
target, err := os.Readlink(linkPath)
|
||||
if err != nil {
|
||||
return "", E.Cause(err, "readlink ", linkPath)
|
||||
}
|
||||
if !filepath.IsAbs(target) {
|
||||
target = filepath.Join(filepath.Dir(linkPath), target)
|
||||
}
|
||||
parent := filepath.Base(filepath.Dir(target))
|
||||
if parent != "versions" {
|
||||
return "", E.New("unexpected symlink target: ", target)
|
||||
}
|
||||
return filepath.Base(target), nil
|
||||
}
|
||||
|
||||
// getRealUser returns the account named by the SUDO_USER environment variable
// when it is set and can be looked up, and otherwise the current user.
func getRealUser() (*user.User, error) {
	name, present := os.LookupEnv("SUDO_USER")
	if present && name != "" {
		// A failed lookup is not fatal; fall through to the current user.
		if invoker, lookupErr := user.Lookup(name); lookupErr == nil {
			return invoker, nil
		}
	}
	return user.Current()
}
|
||||
|
||||
func getDefaultCredentialsPath() (string, error) {
|
||||
if configDir := os.Getenv("CLAUDE_CONFIG_DIR"); configDir != "" {
|
||||
return filepath.Join(configDir, ".credentials.json"), nil
|
||||
}
|
||||
userInfo, err := getRealUser()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(userInfo.HomeDir, ".claude", ".credentials.json"), nil
|
||||
}
|
||||
|
||||
func readCredentialsFromFile(path string) (*oauthCredentials, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var credentialsContainer struct {
|
||||
ClaudeAIAuth *oauthCredentials `json:"claudeAiOauth,omitempty"`
|
||||
}
|
||||
err = json.Unmarshal(data, &credentialsContainer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if credentialsContainer.ClaudeAIAuth == nil {
|
||||
return nil, E.New("claudeAiOauth field not found in credentials")
|
||||
}
|
||||
return credentialsContainer.ClaudeAIAuth, nil
|
||||
}
|
||||
|
||||
// checkCredentialFileWritable reports whether the existing file at path can be
// opened for writing. The file is neither created nor modified; the error from
// opening (or, failing that, closing) the file is returned.
func checkCredentialFileWritable(path string) error {
	handle, err := os.OpenFile(path, os.O_WRONLY, 0)
	if err == nil {
		err = handle.Close()
	}
	return err
}
|
||||
|
||||
func writeCredentialsToFile(oauthCredentials *oauthCredentials, path string) error {
|
||||
data, err := json.MarshalIndent(map[string]any{
|
||||
"claudeAiOauth": oauthCredentials,
|
||||
}, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(path, data, 0o600)
|
||||
}
|
||||
|
||||
// oauthCredentials is the OAuth token set stored under the "claudeAiOauth"
// key of the credentials file.
type oauthCredentials struct {
	// AccessToken is the bearer token handed out to callers.
	AccessToken string `json:"accessToken"`
	// RefreshToken is exchanged for a new access token on refresh.
	RefreshToken string `json:"refreshToken"`
	// ExpiresAt is the expiry instant in Unix milliseconds; zero means the
	// token is never considered due for refresh.
	ExpiresAt int64 `json:"expiresAt"`
	// Scopes lists the granted OAuth scopes, if recorded.
	Scopes []string `json:"scopes,omitempty"`
	// SubscriptionType is copied into the credential state as the account type.
	SubscriptionType string `json:"subscriptionType,omitempty"`
	// RateLimitTier is copied into the credential state as the rate limit tier.
	RateLimitTier string `json:"rateLimitTier,omitempty"`
	// IsMax is persisted as "isMax".
	// NOTE(review): presumably marks a Max subscription; not read in the code shown here.
	IsMax bool `json:"isMax,omitempty"`
}
|
||||
|
||||
func (c *oauthCredentials) needsRefresh() bool {
|
||||
if c.ExpiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().UnixMilli() >= c.ExpiresAt-tokenRefreshBufferMs
|
||||
}
|
||||
|
||||
func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) {
|
||||
if credentials.RefreshToken == "" {
|
||||
return nil, E.New("refresh token is empty")
|
||||
}
|
||||
|
||||
requestBody, err := json.Marshal(map[string]string{
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": credentials.RefreshToken,
|
||||
"client_id": oauth2ClientID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "marshal request")
|
||||
}
|
||||
|
||||
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
|
||||
request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
|
||||
request, err := buildRequest()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request.Header.Set("User-Agent", ccmUserAgentValue)
|
||||
return request, nil
|
||||
response, err := client.Do(request)
|
||||
if err == nil {
|
||||
return response, nil
|
||||
}
|
||||
lastError = err
|
||||
if ctx.Err() != nil {
|
||||
return nil, lastError
|
||||
}
|
||||
}
|
||||
return nil, lastError
|
||||
}
|
||||
|
||||
// credentialState is the mutable usage and availability snapshot kept for a
// credential.
type credentialState struct {
	// Utilization and reset time of the 5-hour and weekly windows.
	// defaultCredential stores utilization as a rounded-up percentage and
	// compares it against its caps.
	fiveHourUtilization float64
	fiveHourReset       time.Time
	weeklyUtilization   float64
	weeklyReset         time.Time

	// hardRateLimited is set when the credential is marked rate limited and
	// stays in effect until rateLimitResetAt.
	hardRateLimited  bool
	rateLimitResetAt time.Time

	// accountType and rateLimitTier mirror the loaded credentials'
	// SubscriptionType and RateLimitTier.
	accountType   string
	rateLimitTier string
	// NOTE(review): remotePlanWeight is not written by the code visible here;
	// confirm its source against the external credential implementation.
	remotePlanWeight float64

	// lastUpdated is the time of the last update that carried usage data.
	lastUpdated time.Time
	// consecutivePollFailures is cleared when usage data arrives; a non-zero
	// value makes a default credential unusable.
	consecutivePollFailures int

	// unavailable marks the credential as not usable; it is cleared when
	// credentials are successfully installed.
	unavailable bool
	// lastCredentialLoadAttempt and lastCredentialLoadError record the most
	// recent credential load.
	lastCredentialLoadAttempt time.Time
	lastCredentialLoadError   string
}
|
||||
|
||||
// credentialRequestContext is a request context that additionally carries the
// hooks linking it to a credential's interruption mechanism, plus its own
// cancel function. Release and cancel each run at most once.
type credentialRequestContext struct {
	context.Context
	releaseOnce  sync.Once
	cancelOnce   sync.Once
	releaseFuncs []func() bool
	cancelFunc   context.CancelFunc
}

// addInterruptLink registers stop to be invoked when the interrupt links are
// released.
func (c *credentialRequestContext) addInterruptLink(stop func() bool) {
	c.releaseFuncs = append(c.releaseFuncs, stop)
}

// releaseCredentialInterrupt invokes every registered stop function, in
// registration order, the first time it is called; later calls do nothing.
func (c *credentialRequestContext) releaseCredentialInterrupt() {
	c.releaseOnce.Do(func() {
		links := c.releaseFuncs
		for index := range links {
			links[index]()
		}
	})
}

// cancelRequest releases the interrupt links and then cancels the request
// context. The cancel function runs at most once.
func (c *credentialRequestContext) cancelRequest() {
	c.releaseCredentialInterrupt()
	c.cancelOnce.Do(c.cancelFunc)
}
|
||||
|
||||
type credential interface {
|
||||
tagName() string
|
||||
isAvailable() bool
|
||||
isUsable() bool
|
||||
isExternal() bool
|
||||
fiveHourUtilization() float64
|
||||
weeklyUtilization() float64
|
||||
fiveHourCap() float64
|
||||
weeklyCap() float64
|
||||
planWeight() float64
|
||||
weeklyResetTime() time.Time
|
||||
markRateLimited(resetAt time.Time)
|
||||
earliestReset() time.Time
|
||||
unavailableError() error
|
||||
|
||||
getAccessToken() (string, error)
|
||||
buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error)
|
||||
updateStateFromHeaders(header http.Header)
|
||||
|
||||
wrapRequestContext(ctx context.Context) *credentialRequestContext
|
||||
interruptConnections()
|
||||
|
||||
start() error
|
||||
pollUsage(ctx context.Context)
|
||||
lastUpdatedTime() time.Time
|
||||
pollBackoff(base time.Duration) time.Duration
|
||||
usageTrackerOrNil() *AggregatedUsage
|
||||
httpClient() *http.Client
|
||||
close()
|
||||
}
|
||||
|
||||
// credentialSelectionScope names the set of credentials a selection may draw
// from.
type credentialSelectionScope string

const (
	// credentialSelectionScopeAll is the scope used when none is specified.
	credentialSelectionScopeAll credentialSelectionScope = "all"
	// credentialSelectionScopeNonExternal is the "non_external" scope.
	credentialSelectionScopeNonExternal credentialSelectionScope = "non_external"
)
|
||||
|
||||
type credentialSelection struct {
|
||||
scope credentialSelectionScope
|
||||
filter func(credential) bool
|
||||
}
|
||||
|
||||
func (s credentialSelection) allows(cred credential) bool {
|
||||
return s.filter == nil || s.filter(cred)
|
||||
}
|
||||
|
||||
func (s credentialSelection) scopeOrDefault() credentialSelectionScope {
|
||||
if s.scope == "" {
|
||||
return credentialSelectionScopeAll
|
||||
}
|
||||
return s.scope
|
||||
}
|
||||
|
||||
// Claude Code's unified rate-limit handling parses these reset headers with
|
||||
// Number(...), compares them against Date.now()/1000, and renders them via
|
||||
// new Date(seconds*1000), so keep the wire format pinned to Unix epoch seconds.
|
||||
func parseAnthropicResetHeaderValue(headerName string, headerValue string) time.Time {
|
||||
unixEpoch, err := strconv.ParseInt(headerValue, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
panic("invalid " + headerName + " header: expected Unix epoch seconds, got " + strconv.Quote(headerValue))
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode == http.StatusTooManyRequests {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return nil, E.New("refresh rate limited: ", response.Status, " ", string(body))
|
||||
if unixEpoch <= 0 {
|
||||
panic("invalid " + headerName + " header: expected positive Unix epoch seconds, got " + strconv.Quote(headerValue))
|
||||
}
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return nil, E.New("refresh failed: ", response.Status, " ", string(body))
|
||||
}
|
||||
|
||||
var tokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
err = json.NewDecoder(response.Body).Decode(&tokenResponse)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode response")
|
||||
}
|
||||
|
||||
newCredentials := *credentials
|
||||
newCredentials.AccessToken = tokenResponse.AccessToken
|
||||
if tokenResponse.RefreshToken != "" {
|
||||
newCredentials.RefreshToken = tokenResponse.RefreshToken
|
||||
}
|
||||
newCredentials.ExpiresAt = time.Now().UnixMilli() + int64(tokenResponse.ExpiresIn)*1000
|
||||
|
||||
return &newCredentials, nil
|
||||
return time.Unix(unixEpoch, 0)
|
||||
}
|
||||
|
||||
func cloneCredentials(credentials *oauthCredentials) *oauthCredentials {
|
||||
if credentials == nil {
|
||||
return nil
|
||||
func parseOptionalAnthropicResetHeader(headers http.Header, headerName string) (time.Time, bool) {
|
||||
headerValue := headers.Get(headerName)
|
||||
if headerValue == "" {
|
||||
return time.Time{}, false
|
||||
}
|
||||
cloned := *credentials
|
||||
cloned.Scopes = append([]string(nil), credentials.Scopes...)
|
||||
return &cloned
|
||||
return parseAnthropicResetHeaderValue(headerName, headerValue), true
|
||||
}
|
||||
|
||||
func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool {
|
||||
if left == nil || right == nil {
|
||||
return left == right
|
||||
func parseRequiredAnthropicResetHeader(headers http.Header, headerName string) time.Time {
|
||||
headerValue := headers.Get(headerName)
|
||||
if headerValue == "" {
|
||||
panic("missing required " + headerName + " header")
|
||||
}
|
||||
return parseAnthropicResetHeaderValue(headerName, headerValue)
|
||||
}
|
||||
|
||||
func parseRateLimitResetFromHeaders(headers http.Header) time.Time {
|
||||
claim := headers.Get("anthropic-ratelimit-unified-representative-claim")
|
||||
switch claim {
|
||||
case "5h":
|
||||
return parseRequiredAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset")
|
||||
case "7d":
|
||||
return parseRequiredAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset")
|
||||
default:
|
||||
panic("invalid anthropic-ratelimit-unified-representative-claim header: " + strconv.Quote(claim))
|
||||
}
|
||||
return left.AccessToken == right.AccessToken &&
|
||||
left.RefreshToken == right.RefreshToken &&
|
||||
left.ExpiresAt == right.ExpiresAt &&
|
||||
slices.Equal(left.Scopes, right.Scopes) &&
|
||||
left.SubscriptionType == right.SubscriptionType &&
|
||||
left.RateLimitTier == right.RateLimitTier &&
|
||||
left.IsMax == right.IsMax
|
||||
}
|
||||
|
||||
192
service/ccm/credential_builder.go
Normal file
192
service/ccm/credential_builder.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
func buildCredentialProviders(
|
||||
ctx context.Context,
|
||||
options option.CCMServiceOptions,
|
||||
logger log.ContextLogger,
|
||||
) (map[string]credentialProvider, []credential, error) {
|
||||
allCredentialMap := make(map[string]credential)
|
||||
var allCreds []credential
|
||||
providers := make(map[string]credentialProvider)
|
||||
|
||||
// Pass 1: create default and external credentials
|
||||
for _, credOpt := range options.Credentials {
|
||||
switch credOpt.Type {
|
||||
case "default":
|
||||
cred, err := newDefaultCredential(ctx, credOpt.Tag, credOpt.DefaultOptions, logger)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
allCredentialMap[credOpt.Tag] = cred
|
||||
allCreds = append(allCreds, cred)
|
||||
providers[credOpt.Tag] = &singleCredentialProvider{cred: cred}
|
||||
case "external":
|
||||
cred, err := newExternalCredential(ctx, credOpt.Tag, credOpt.ExternalOptions, logger)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
allCredentialMap[credOpt.Tag] = cred
|
||||
allCreds = append(allCreds, cred)
|
||||
providers[credOpt.Tag] = &singleCredentialProvider{cred: cred}
|
||||
}
|
||||
}
|
||||
|
||||
// Pass 2: create balancer providers
|
||||
for _, credOpt := range options.Credentials {
|
||||
if credOpt.Type == "balancer" {
|
||||
subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, allCredentialMap, credOpt.Tag)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
providers[credOpt.Tag] = newBalancerProvider(subCredentials, credOpt.BalancerOptions.Strategy, time.Duration(credOpt.BalancerOptions.PollInterval), credOpt.BalancerOptions.RebalanceThreshold, logger)
|
||||
}
|
||||
}
|
||||
|
||||
return providers, allCreds, nil
|
||||
}
|
||||
|
||||
func resolveCredentialTags(tags []string, allCredentials map[string]credential, parentTag string) ([]credential, error) {
|
||||
credentials := make([]credential, 0, len(tags))
|
||||
for _, tag := range tags {
|
||||
cred, exists := allCredentials[tag]
|
||||
if !exists {
|
||||
return nil, E.New("credential ", parentTag, " references unknown credential: ", tag)
|
||||
}
|
||||
credentials = append(credentials, cred)
|
||||
}
|
||||
if len(credentials) == 0 {
|
||||
return nil, E.New("credential ", parentTag, " has no sub-credentials")
|
||||
}
|
||||
return credentials, nil
|
||||
}
|
||||
|
||||
func validateCCMOptions(options option.CCMServiceOptions) error {
|
||||
hasCredentials := len(options.Credentials) > 0
|
||||
hasLegacyPath := options.CredentialPath != ""
|
||||
hasLegacyUsages := options.UsagesPath != ""
|
||||
hasLegacyDetour := options.Detour != ""
|
||||
|
||||
if hasCredentials && hasLegacyPath {
|
||||
return E.New("credential_path and credentials are mutually exclusive")
|
||||
}
|
||||
if hasCredentials && hasLegacyUsages {
|
||||
return E.New("usages_path and credentials are mutually exclusive; use usages_path on individual credentials")
|
||||
}
|
||||
if hasCredentials && hasLegacyDetour {
|
||||
return E.New("detour and credentials are mutually exclusive; use detour on individual credentials")
|
||||
}
|
||||
|
||||
if hasCredentials {
|
||||
tags := make(map[string]bool)
|
||||
credentialTypes := make(map[string]string)
|
||||
for _, cred := range options.Credentials {
|
||||
if tags[cred.Tag] {
|
||||
return E.New("duplicate credential tag: ", cred.Tag)
|
||||
}
|
||||
tags[cred.Tag] = true
|
||||
credentialTypes[cred.Tag] = cred.Type
|
||||
if cred.Type == "default" || cred.Type == "" {
|
||||
if cred.DefaultOptions.Reserve5h > 99 {
|
||||
return E.New("credential ", cred.Tag, ": reserve_5h must be at most 99")
|
||||
}
|
||||
if cred.DefaultOptions.ReserveWeekly > 99 {
|
||||
return E.New("credential ", cred.Tag, ": reserve_weekly must be at most 99")
|
||||
}
|
||||
if cred.DefaultOptions.Limit5h > 100 {
|
||||
return E.New("credential ", cred.Tag, ": limit_5h must be at most 100")
|
||||
}
|
||||
if cred.DefaultOptions.LimitWeekly > 100 {
|
||||
return E.New("credential ", cred.Tag, ": limit_weekly must be at most 100")
|
||||
}
|
||||
if cred.DefaultOptions.Reserve5h > 0 && cred.DefaultOptions.Limit5h > 0 {
|
||||
return E.New("credential ", cred.Tag, ": reserve_5h and limit_5h are mutually exclusive")
|
||||
}
|
||||
if cred.DefaultOptions.ReserveWeekly > 0 && cred.DefaultOptions.LimitWeekly > 0 {
|
||||
return E.New("credential ", cred.Tag, ": reserve_weekly and limit_weekly are mutually exclusive")
|
||||
}
|
||||
}
|
||||
if cred.Type == "external" {
|
||||
if cred.ExternalOptions.Token == "" {
|
||||
return E.New("credential ", cred.Tag, ": external credential requires token")
|
||||
}
|
||||
if cred.ExternalOptions.Reverse && cred.ExternalOptions.URL == "" {
|
||||
return E.New("credential ", cred.Tag, ": reverse external credential requires url")
|
||||
}
|
||||
}
|
||||
if cred.Type == "balancer" {
|
||||
switch cred.BalancerOptions.Strategy {
|
||||
case "", C.BalancerStrategyLeastUsed, C.BalancerStrategyRoundRobin, C.BalancerStrategyRandom, C.BalancerStrategyFallback:
|
||||
default:
|
||||
return E.New("credential ", cred.Tag, ": unknown balancer strategy: ", cred.BalancerOptions.Strategy)
|
||||
}
|
||||
if cred.BalancerOptions.RebalanceThreshold < 0 {
|
||||
return E.New("credential ", cred.Tag, ": rebalance_threshold must not be negative")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, user := range options.Users {
|
||||
if user.Credential == "" {
|
||||
return E.New("user ", user.Name, " must specify credential in multi-credential mode")
|
||||
}
|
||||
if !tags[user.Credential] {
|
||||
return E.New("user ", user.Name, " references unknown credential: ", user.Credential)
|
||||
}
|
||||
if user.ExternalCredential != "" {
|
||||
if !tags[user.ExternalCredential] {
|
||||
return E.New("user ", user.Name, " references unknown external_credential: ", user.ExternalCredential)
|
||||
}
|
||||
if credentialTypes[user.ExternalCredential] != "external" {
|
||||
return E.New("user ", user.Name, ": external_credential must reference an external type credential")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func credentialForUser(
|
||||
userConfigMap map[string]*option.CCMUser,
|
||||
providers map[string]credentialProvider,
|
||||
legacyProvider credentialProvider,
|
||||
username string,
|
||||
) (credentialProvider, error) {
|
||||
if legacyProvider != nil {
|
||||
return legacyProvider, nil
|
||||
}
|
||||
userConfig, exists := userConfigMap[username]
|
||||
if !exists {
|
||||
return nil, E.New("no credential mapping for user: ", username)
|
||||
}
|
||||
provider, exists := providers[userConfig.Credential]
|
||||
if !exists {
|
||||
return nil, E.New("unknown credential: ", userConfig.Credential)
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func noUserCredentialProvider(
|
||||
providers map[string]credentialProvider,
|
||||
legacyProvider credentialProvider,
|
||||
options option.CCMServiceOptions,
|
||||
) credentialProvider {
|
||||
if legacyProvider != nil {
|
||||
return legacyProvider
|
||||
}
|
||||
if len(options.Credentials) > 0 {
|
||||
tag := options.Credentials[0].Tag
|
||||
return providers[tag]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
726
service/ccm/credential_default.go
Normal file
726
service/ccm/credential_default.go
Normal file
@@ -0,0 +1,726 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
stdTLS "crypto/tls"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/fswatch"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
)
|
||||
|
||||
type defaultCredential struct {
|
||||
tag string
|
||||
serviceContext context.Context
|
||||
credentialPath string
|
||||
credentialFilePath string
|
||||
credentials *oauthCredentials
|
||||
access sync.RWMutex
|
||||
state credentialState
|
||||
stateAccess sync.RWMutex
|
||||
pollAccess sync.Mutex
|
||||
reloadAccess sync.Mutex
|
||||
watcherAccess sync.Mutex
|
||||
cap5h float64
|
||||
capWeekly float64
|
||||
usageTracker *AggregatedUsage
|
||||
forwardHTTPClient *http.Client
|
||||
logger log.ContextLogger
|
||||
watcher *fswatch.Watcher
|
||||
watcherRetryAt time.Time
|
||||
|
||||
// Connection interruption
|
||||
onBecameUnusable func()
|
||||
interrupted bool
|
||||
requestContext context.Context
|
||||
cancelRequests context.CancelFunc
|
||||
requestAccess sync.Mutex
|
||||
}
|
||||
|
||||
func newDefaultCredential(ctx context.Context, tag string, options option.CCMDefaultCredentialOptions, logger log.ContextLogger) (*defaultCredential, error) {
|
||||
credentialDialer, err := dialer.NewWithOptions(dialer.Options{
|
||||
Context: ctx,
|
||||
Options: option.DialerOptions{
|
||||
Detour: options.Detour,
|
||||
},
|
||||
RemoteIsDomain: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "create dialer for credential ", tag)
|
||||
}
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
ForceAttemptHTTP2: true,
|
||||
TLSClientConfig: &stdTLS.Config{
|
||||
RootCAs: adapter.RootPoolFromContext(ctx),
|
||||
Time: ntp.TimeFuncFromContext(ctx),
|
||||
},
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
},
|
||||
},
|
||||
}
|
||||
reserve5h := options.Reserve5h
|
||||
if reserve5h == 0 {
|
||||
reserve5h = 1
|
||||
}
|
||||
reserveWeekly := options.ReserveWeekly
|
||||
if reserveWeekly == 0 {
|
||||
reserveWeekly = 1
|
||||
}
|
||||
var cap5h float64
|
||||
if options.Limit5h > 0 {
|
||||
cap5h = float64(options.Limit5h)
|
||||
} else {
|
||||
cap5h = float64(100 - reserve5h)
|
||||
}
|
||||
var capWeekly float64
|
||||
if options.LimitWeekly > 0 {
|
||||
capWeekly = float64(options.LimitWeekly)
|
||||
} else {
|
||||
capWeekly = float64(100 - reserveWeekly)
|
||||
}
|
||||
requestContext, cancelRequests := context.WithCancel(context.Background())
|
||||
credential := &defaultCredential{
|
||||
tag: tag,
|
||||
serviceContext: ctx,
|
||||
credentialPath: options.CredentialPath,
|
||||
cap5h: cap5h,
|
||||
capWeekly: capWeekly,
|
||||
forwardHTTPClient: httpClient,
|
||||
logger: logger,
|
||||
requestContext: requestContext,
|
||||
cancelRequests: cancelRequests,
|
||||
}
|
||||
if options.UsagesPath != "" {
|
||||
credential.usageTracker = &AggregatedUsage{
|
||||
LastUpdated: time.Now(),
|
||||
Combinations: make([]CostCombination, 0),
|
||||
filePath: options.UsagesPath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
return credential, nil
|
||||
}
|
||||
|
||||
func (c *defaultCredential) start() error {
|
||||
credentialFilePath, err := resolveCredentialFilePath(c.credentialPath)
|
||||
if err != nil {
|
||||
return E.Cause(err, "resolve credential path for ", c.tag)
|
||||
}
|
||||
c.credentialFilePath = credentialFilePath
|
||||
err = c.ensureCredentialWatcher()
|
||||
if err != nil {
|
||||
c.logger.Debug("start credential watcher for ", c.tag, ": ", err)
|
||||
}
|
||||
err = c.reloadCredentials(true)
|
||||
if err != nil {
|
||||
c.logger.Warn("initial credential load for ", c.tag, ": ", err)
|
||||
}
|
||||
if c.usageTracker != nil {
|
||||
err = c.usageTracker.Load()
|
||||
if err != nil {
|
||||
c.logger.Warn("load usage statistics for ", c.tag, ": ", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *defaultCredential) getAccessToken() (string, error) {
|
||||
c.retryCredentialReloadIfNeeded()
|
||||
|
||||
c.access.RLock()
|
||||
if c.credentials != nil && !c.credentials.needsRefresh() {
|
||||
token := c.credentials.AccessToken
|
||||
c.access.RUnlock()
|
||||
return token, nil
|
||||
}
|
||||
c.access.RUnlock()
|
||||
|
||||
err := c.reloadCredentials(true)
|
||||
if err == nil {
|
||||
c.access.RLock()
|
||||
if c.credentials != nil && !c.credentials.needsRefresh() {
|
||||
token := c.credentials.AccessToken
|
||||
c.access.RUnlock()
|
||||
return token, nil
|
||||
}
|
||||
c.access.RUnlock()
|
||||
}
|
||||
|
||||
c.access.Lock()
|
||||
defer c.access.Unlock()
|
||||
|
||||
if c.credentials == nil {
|
||||
return "", c.unavailableError()
|
||||
}
|
||||
if !c.credentials.needsRefresh() {
|
||||
return c.credentials.AccessToken, nil
|
||||
}
|
||||
|
||||
err = platformCanWriteCredentials(c.credentialPath)
|
||||
if err != nil {
|
||||
return "", E.Cause(err, "credential file not writable, refusing refresh to avoid invalidation")
|
||||
}
|
||||
|
||||
baseCredentials := cloneCredentials(c.credentials)
|
||||
newCredentials, err := refreshToken(c.serviceContext, c.forwardHTTPClient, c.credentials)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
latestCredentials, latestErr := platformReadCredentials(c.credentialPath)
|
||||
if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) {
|
||||
c.credentials = latestCredentials
|
||||
c.stateAccess.Lock()
|
||||
c.state.unavailable = false
|
||||
c.state.lastCredentialLoadAttempt = time.Now()
|
||||
c.state.lastCredentialLoadError = ""
|
||||
c.state.accountType = latestCredentials.SubscriptionType
|
||||
c.state.rateLimitTier = latestCredentials.RateLimitTier
|
||||
c.checkTransitionLocked()
|
||||
c.stateAccess.Unlock()
|
||||
if !latestCredentials.needsRefresh() {
|
||||
return latestCredentials.AccessToken, nil
|
||||
}
|
||||
return "", E.New("credential ", c.tag, " changed while refreshing")
|
||||
}
|
||||
|
||||
c.credentials = newCredentials
|
||||
c.stateAccess.Lock()
|
||||
c.state.unavailable = false
|
||||
c.state.lastCredentialLoadAttempt = time.Now()
|
||||
c.state.lastCredentialLoadError = ""
|
||||
c.state.accountType = newCredentials.SubscriptionType
|
||||
c.state.rateLimitTier = newCredentials.RateLimitTier
|
||||
c.checkTransitionLocked()
|
||||
c.stateAccess.Unlock()
|
||||
|
||||
err = platformWriteCredentials(newCredentials, c.credentialPath)
|
||||
if err != nil {
|
||||
c.logger.Error("persist refreshed token for ", c.tag, ": ", err)
|
||||
}
|
||||
|
||||
return newCredentials.AccessToken, nil
|
||||
}
|
||||
|
||||
func (c *defaultCredential) updateStateFromHeaders(headers http.Header) {
|
||||
c.stateAccess.Lock()
|
||||
isFirstUpdate := c.state.lastUpdated.IsZero()
|
||||
oldFiveHour := c.state.fiveHourUtilization
|
||||
oldWeekly := c.state.weeklyUtilization
|
||||
hadData := false
|
||||
|
||||
fiveHourResetChanged := false
|
||||
if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset"); exists {
|
||||
hadData = true
|
||||
if value.After(c.state.fiveHourReset) {
|
||||
fiveHourResetChanged = true
|
||||
c.state.fiveHourReset = value
|
||||
}
|
||||
}
|
||||
if utilization := headers.Get("anthropic-ratelimit-unified-5h-utilization"); utilization != "" {
|
||||
value, err := strconv.ParseFloat(utilization, 64)
|
||||
if err == nil {
|
||||
hadData = true
|
||||
newValue := math.Ceil(value * 100)
|
||||
if newValue >= c.state.fiveHourUtilization || fiveHourResetChanged {
|
||||
c.state.fiveHourUtilization = newValue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
weeklyResetChanged := false
|
||||
if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset"); exists {
|
||||
hadData = true
|
||||
if value.After(c.state.weeklyReset) {
|
||||
weeklyResetChanged = true
|
||||
c.state.weeklyReset = value
|
||||
}
|
||||
}
|
||||
if utilization := headers.Get("anthropic-ratelimit-unified-7d-utilization"); utilization != "" {
|
||||
value, err := strconv.ParseFloat(utilization, 64)
|
||||
if err == nil {
|
||||
hadData = true
|
||||
newValue := math.Ceil(value * 100)
|
||||
if newValue >= c.state.weeklyUtilization || weeklyResetChanged {
|
||||
c.state.weeklyUtilization = newValue
|
||||
}
|
||||
}
|
||||
}
|
||||
if hadData {
|
||||
c.state.consecutivePollFailures = 0
|
||||
c.state.lastUpdated = time.Now()
|
||||
}
|
||||
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
|
||||
resetSuffix := ""
|
||||
if !c.state.weeklyReset.IsZero() {
|
||||
resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset))
|
||||
}
|
||||
c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
|
||||
}
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *defaultCredential) markRateLimited(resetAt time.Time) {
|
||||
c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt)))
|
||||
c.stateAccess.Lock()
|
||||
c.state.hardRateLimited = true
|
||||
c.state.rateLimitResetAt = resetAt
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *defaultCredential) isUsable() bool {
|
||||
c.retryCredentialReloadIfNeeded()
|
||||
|
||||
c.stateAccess.RLock()
|
||||
if c.state.unavailable {
|
||||
c.stateAccess.RUnlock()
|
||||
return false
|
||||
}
|
||||
if c.state.consecutivePollFailures > 0 {
|
||||
c.stateAccess.RUnlock()
|
||||
return false
|
||||
}
|
||||
if c.state.hardRateLimited {
|
||||
if time.Now().Before(c.state.rateLimitResetAt) {
|
||||
c.stateAccess.RUnlock()
|
||||
return false
|
||||
}
|
||||
c.stateAccess.RUnlock()
|
||||
c.stateAccess.Lock()
|
||||
if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) {
|
||||
c.state.hardRateLimited = false
|
||||
}
|
||||
usable := c.checkReservesLocked()
|
||||
c.stateAccess.Unlock()
|
||||
return usable
|
||||
}
|
||||
usable := c.checkReservesLocked()
|
||||
c.stateAccess.RUnlock()
|
||||
return usable
|
||||
}
|
||||
|
||||
func (c *defaultCredential) checkReservesLocked() bool {
|
||||
if c.state.fiveHourUtilization >= c.cap5h {
|
||||
return false
|
||||
}
|
||||
if c.state.weeklyUtilization >= c.capWeekly {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// checkTransitionLocked detects usable→unusable transition.
|
||||
// Must be called with stateAccess write lock held.
|
||||
func (c *defaultCredential) checkTransitionLocked() bool {
|
||||
unusable := c.state.unavailable || c.state.hardRateLimited || !c.checkReservesLocked() || c.state.consecutivePollFailures > 0
|
||||
if unusable && !c.interrupted {
|
||||
c.interrupted = true
|
||||
return true
|
||||
}
|
||||
if !unusable && c.interrupted {
|
||||
c.interrupted = false
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *defaultCredential) interruptConnections() {
|
||||
c.logger.Warn("interrupting connections for ", c.tag)
|
||||
c.requestAccess.Lock()
|
||||
c.cancelRequests()
|
||||
c.requestContext, c.cancelRequests = context.WithCancel(context.Background())
|
||||
c.requestAccess.Unlock()
|
||||
if c.onBecameUnusable != nil {
|
||||
c.onBecameUnusable()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *defaultCredential) wrapRequestContext(parent context.Context) *credentialRequestContext {
|
||||
c.requestAccess.Lock()
|
||||
credentialContext := c.requestContext
|
||||
c.requestAccess.Unlock()
|
||||
derived, cancel := context.WithCancel(parent)
|
||||
stop := context.AfterFunc(credentialContext, func() {
|
||||
cancel()
|
||||
})
|
||||
return &credentialRequestContext{
|
||||
Context: derived,
|
||||
releaseFuncs: []func() bool{stop},
|
||||
cancelFunc: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *defaultCredential) weeklyUtilization() float64 {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.weeklyUtilization
|
||||
}
|
||||
|
||||
func (c *defaultCredential) planWeight() float64 {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return ccmPlanWeight(c.state.accountType, c.state.rateLimitTier)
|
||||
}
|
||||
|
||||
func (c *defaultCredential) weeklyResetTime() time.Time {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.weeklyReset
|
||||
}
|
||||
|
||||
func (c *defaultCredential) isAvailable() bool {
|
||||
c.retryCredentialReloadIfNeeded()
|
||||
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return !c.state.unavailable
|
||||
}
|
||||
|
||||
func (c *defaultCredential) unavailableError() error {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
if !c.state.unavailable {
|
||||
return nil
|
||||
}
|
||||
if c.state.lastCredentialLoadError == "" {
|
||||
return E.New("credential ", c.tag, " is unavailable")
|
||||
}
|
||||
return E.New("credential ", c.tag, " is unavailable: ", c.state.lastCredentialLoadError)
|
||||
}
|
||||
|
||||
func (c *defaultCredential) lastUpdatedTime() time.Time {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.lastUpdated
|
||||
}
|
||||
|
||||
func (c *defaultCredential) markUsagePollAttempted() {
|
||||
c.stateAccess.Lock()
|
||||
defer c.stateAccess.Unlock()
|
||||
c.state.lastUpdated = time.Now()
|
||||
}
|
||||
|
||||
func (c *defaultCredential) incrementPollFailures() {
|
||||
c.stateAccess.Lock()
|
||||
c.state.consecutivePollFailures++
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duration {
|
||||
c.stateAccess.RLock()
|
||||
failures := c.state.consecutivePollFailures
|
||||
c.stateAccess.RUnlock()
|
||||
if failures <= 0 {
|
||||
return baseInterval
|
||||
}
|
||||
backoff := failedPollRetryInterval * time.Duration(1<<(failures-1))
|
||||
if backoff > httpRetryMaxBackoff {
|
||||
return httpRetryMaxBackoff
|
||||
}
|
||||
return backoff
|
||||
}
|
||||
|
||||
func (c *defaultCredential) isPollBackoffAtCap() bool {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
failures := c.state.consecutivePollFailures
|
||||
return failures > 0 && failedPollRetryInterval*time.Duration(1<<(failures-1)) >= httpRetryMaxBackoff
|
||||
}
|
||||
|
||||
func (c *defaultCredential) earliestReset() time.Time {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
if c.state.unavailable {
|
||||
return time.Time{}
|
||||
}
|
||||
if c.state.hardRateLimited {
|
||||
return c.state.rateLimitResetAt
|
||||
}
|
||||
earliest := c.state.fiveHourReset
|
||||
if !c.state.weeklyReset.IsZero() && (earliest.IsZero() || c.state.weeklyReset.Before(earliest)) {
|
||||
earliest = c.state.weeklyReset
|
||||
}
|
||||
return earliest
|
||||
}
|
||||
|
||||
func (c *defaultCredential) pollUsage(ctx context.Context) {
|
||||
if !c.pollAccess.TryLock() {
|
||||
return
|
||||
}
|
||||
defer c.pollAccess.Unlock()
|
||||
defer c.markUsagePollAttempted()
|
||||
|
||||
c.retryCredentialReloadIfNeeded()
|
||||
if !c.isAvailable() {
|
||||
return
|
||||
}
|
||||
|
||||
accessToken, err := c.getAccessToken()
|
||||
if err != nil {
|
||||
if !c.isPollBackoffAtCap() {
|
||||
c.logger.Error("poll usage for ", c.tag, ": get token: ", err)
|
||||
}
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: c.forwardHTTPClient.Transport,
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/usage", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request.Header.Set("User-Agent", ccmUserAgentValue)
|
||||
request.Header.Set("anthropic-beta", anthropicBetaOAuthValue)
|
||||
return request, nil
|
||||
})
|
||||
if err != nil {
|
||||
if !c.isPollBackoffAtCap() {
|
||||
c.logger.Error("poll usage for ", c.tag, ": ", err)
|
||||
}
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
if response.StatusCode == http.StatusTooManyRequests {
|
||||
c.logger.Warn("poll usage for ", c.tag, ": rate limited")
|
||||
}
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
|
||||
var usageResponse struct {
|
||||
FiveHour struct {
|
||||
Utilization float64 `json:"utilization"`
|
||||
ResetsAt time.Time `json:"resets_at"`
|
||||
} `json:"five_hour"`
|
||||
SevenDay struct {
|
||||
Utilization float64 `json:"utilization"`
|
||||
ResetsAt time.Time `json:"resets_at"`
|
||||
} `json:"seven_day"`
|
||||
}
|
||||
err = json.NewDecoder(response.Body).Decode(&usageResponse)
|
||||
if err != nil {
|
||||
c.logger.Debug("poll usage for ", c.tag, ": decode: ", err)
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
|
||||
c.stateAccess.Lock()
|
||||
isFirstUpdate := c.state.lastUpdated.IsZero()
|
||||
oldFiveHour := c.state.fiveHourUtilization
|
||||
oldWeekly := c.state.weeklyUtilization
|
||||
c.state.consecutivePollFailures = 0
|
||||
c.state.fiveHourUtilization = usageResponse.FiveHour.Utilization
|
||||
if !usageResponse.FiveHour.ResetsAt.IsZero() {
|
||||
c.state.fiveHourReset = usageResponse.FiveHour.ResetsAt
|
||||
}
|
||||
c.state.weeklyUtilization = usageResponse.SevenDay.Utilization
|
||||
if !usageResponse.SevenDay.ResetsAt.IsZero() {
|
||||
c.state.weeklyReset = usageResponse.SevenDay.ResetsAt
|
||||
}
|
||||
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
|
||||
c.state.hardRateLimited = false
|
||||
}
|
||||
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
|
||||
resetSuffix := ""
|
||||
if !c.state.weeklyReset.IsZero() {
|
||||
resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset))
|
||||
}
|
||||
c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
|
||||
}
|
||||
needsProfileFetch := c.state.rateLimitTier == ""
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
|
||||
if needsProfileFetch {
|
||||
c.fetchProfile(ctx, httpClient, accessToken)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *defaultCredential) fetchProfile(ctx context.Context, httpClient *http.Client, accessToken string) {
|
||||
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/profile", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request.Header.Set("User-Agent", ccmUserAgentValue)
|
||||
return request, nil
|
||||
})
|
||||
if err != nil {
|
||||
c.logger.Debug("fetch profile for ", c.tag, ": ", err)
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
return
|
||||
}
|
||||
|
||||
var profileResponse struct {
|
||||
Organization *struct {
|
||||
OrganizationType string `json:"organization_type"`
|
||||
RateLimitTier string `json:"rate_limit_tier"`
|
||||
} `json:"organization"`
|
||||
}
|
||||
err = json.NewDecoder(response.Body).Decode(&profileResponse)
|
||||
if err != nil || profileResponse.Organization == nil {
|
||||
return
|
||||
}
|
||||
|
||||
accountType := ""
|
||||
switch profileResponse.Organization.OrganizationType {
|
||||
case "claude_pro":
|
||||
accountType = "pro"
|
||||
case "claude_max":
|
||||
accountType = "max"
|
||||
case "claude_team":
|
||||
accountType = "team"
|
||||
case "claude_enterprise":
|
||||
accountType = "enterprise"
|
||||
}
|
||||
rateLimitTier := profileResponse.Organization.RateLimitTier
|
||||
|
||||
c.stateAccess.Lock()
|
||||
if accountType != "" && c.state.accountType == "" {
|
||||
c.state.accountType = accountType
|
||||
}
|
||||
if rateLimitTier != "" {
|
||||
c.state.rateLimitTier = rateLimitTier
|
||||
}
|
||||
c.stateAccess.Unlock()
|
||||
c.logger.Info("fetched profile for ", c.tag, ": type=", c.state.accountType, ", tier=", rateLimitTier, ", weight=", ccmPlanWeight(c.state.accountType, rateLimitTier))
|
||||
}
|
||||
|
||||
func (c *defaultCredential) close() {
|
||||
if c.watcher != nil {
|
||||
err := c.watcher.Close()
|
||||
if err != nil {
|
||||
c.logger.Error("close credential watcher for ", c.tag, ": ", err)
|
||||
}
|
||||
}
|
||||
if c.usageTracker != nil {
|
||||
c.usageTracker.cancelPendingSave()
|
||||
err := c.usageTracker.Save()
|
||||
if err != nil {
|
||||
c.logger.Error("save usage statistics for ", c.tag, ": ", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *defaultCredential) tagName() string {
|
||||
return c.tag
|
||||
}
|
||||
|
||||
func (c *defaultCredential) isExternal() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *defaultCredential) fiveHourUtilization() float64 {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.fiveHourUtilization
|
||||
}
|
||||
|
||||
func (c *defaultCredential) fiveHourCap() float64 {
|
||||
return c.cap5h
|
||||
}
|
||||
|
||||
func (c *defaultCredential) weeklyCap() float64 {
|
||||
return c.capWeekly
|
||||
}
|
||||
|
||||
func (c *defaultCredential) usageTrackerOrNil() *AggregatedUsage {
|
||||
return c.usageTracker
|
||||
}
|
||||
|
||||
func (c *defaultCredential) httpClient() *http.Client {
|
||||
return c.forwardHTTPClient
|
||||
}
|
||||
|
||||
func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) {
|
||||
accessToken, err := c.getAccessToken()
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "get access token for ", c.tag)
|
||||
}
|
||||
|
||||
proxyURL := claudeAPIBaseURL + original.URL.RequestURI()
|
||||
var body io.Reader
|
||||
if bodyBytes != nil {
|
||||
body = bytes.NewReader(bodyBytes)
|
||||
} else {
|
||||
body = original.Body
|
||||
}
|
||||
proxyRequest, err := http.NewRequestWithContext(ctx, original.Method, proxyURL, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for key, values := range original.Header {
|
||||
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && key != "Authorization" {
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
}
|
||||
|
||||
serviceOverridesAcceptEncoding := len(serviceHeaders.Values("Accept-Encoding")) > 0
|
||||
if c.usageTracker != nil && !serviceOverridesAcceptEncoding {
|
||||
proxyRequest.Header.Del("Accept-Encoding")
|
||||
}
|
||||
|
||||
anthropicBetaHeader := proxyRequest.Header.Get("anthropic-beta")
|
||||
if anthropicBetaHeader != "" {
|
||||
proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue+","+anthropicBetaHeader)
|
||||
} else {
|
||||
proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue)
|
||||
}
|
||||
|
||||
for key, values := range serviceHeaders {
|
||||
proxyRequest.Header.Del(key)
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
proxyRequest.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
return proxyRequest, nil
|
||||
}
|
||||
@@ -29,16 +29,16 @@ import (
|
||||
const reverseProxyBaseURL = "http://reverse-proxy"
|
||||
|
||||
type externalCredential struct {
|
||||
tag string
|
||||
baseURL string
|
||||
token string
|
||||
httpClient *http.Client
|
||||
state credentialState
|
||||
stateMutex sync.RWMutex
|
||||
pollAccess sync.Mutex
|
||||
pollInterval time.Duration
|
||||
usageTracker *AggregatedUsage
|
||||
logger log.ContextLogger
|
||||
tag string
|
||||
baseURL string
|
||||
token string
|
||||
forwardHTTPClient *http.Client
|
||||
state credentialState
|
||||
stateAccess sync.RWMutex
|
||||
pollAccess sync.Mutex
|
||||
pollInterval time.Duration
|
||||
usageTracker *AggregatedUsage
|
||||
logger log.ContextLogger
|
||||
|
||||
onBecameUnusable func()
|
||||
interrupted bool
|
||||
@@ -128,7 +128,7 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx
|
||||
if options.URL == "" {
|
||||
// Receiver mode: no URL, wait for reverse connection
|
||||
cred.baseURL = reverseProxyBaseURL
|
||||
cred.httpClient = &http.Client{
|
||||
cred.forwardHTTPClient = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
ForceAttemptHTTP2: false,
|
||||
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
@@ -192,10 +192,10 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx
|
||||
Time: ntp.TimeFuncFromContext(ctx),
|
||||
}
|
||||
}
|
||||
cred.httpClient = &http.Client{Transport: transport}
|
||||
cred.forwardHTTPClient = &http.Client{Transport: transport}
|
||||
} else {
|
||||
// Normal mode: standard HTTP client for proxying
|
||||
cred.httpClient = &http.Client{Transport: transport}
|
||||
cred.forwardHTTPClient = &http.Client{Transport: transport}
|
||||
cred.reverseHttpClient = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
ForceAttemptHTTP2: false,
|
||||
@@ -248,40 +248,40 @@ func (c *externalCredential) isUsable() bool {
|
||||
if !c.isAvailable() {
|
||||
return false
|
||||
}
|
||||
c.stateMutex.RLock()
|
||||
c.stateAccess.RLock()
|
||||
if c.state.consecutivePollFailures > 0 {
|
||||
c.stateMutex.RUnlock()
|
||||
c.stateAccess.RUnlock()
|
||||
return false
|
||||
}
|
||||
if c.state.hardRateLimited {
|
||||
if time.Now().Before(c.state.rateLimitResetAt) {
|
||||
c.stateMutex.RUnlock()
|
||||
c.stateAccess.RUnlock()
|
||||
return false
|
||||
}
|
||||
c.stateMutex.RUnlock()
|
||||
c.stateMutex.Lock()
|
||||
c.stateAccess.RUnlock()
|
||||
c.stateAccess.Lock()
|
||||
if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) {
|
||||
c.state.hardRateLimited = false
|
||||
}
|
||||
// No reserve for external: only 100% is unusable
|
||||
usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100
|
||||
c.stateMutex.Unlock()
|
||||
c.stateAccess.Unlock()
|
||||
return usable
|
||||
}
|
||||
usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100
|
||||
c.stateMutex.RUnlock()
|
||||
c.stateAccess.RUnlock()
|
||||
return usable
|
||||
}
|
||||
|
||||
func (c *externalCredential) fiveHourUtilization() float64 {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.fiveHourUtilization
|
||||
}
|
||||
|
||||
func (c *externalCredential) weeklyUtilization() float64 {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.weeklyUtilization
|
||||
}
|
||||
|
||||
@@ -294,8 +294,8 @@ func (c *externalCredential) weeklyCap() float64 {
|
||||
}
|
||||
|
||||
func (c *externalCredential) planWeight() float64 {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
if c.state.remotePlanWeight > 0 {
|
||||
return c.state.remotePlanWeight
|
||||
}
|
||||
@@ -303,26 +303,26 @@ func (c *externalCredential) planWeight() float64 {
|
||||
}
|
||||
|
||||
func (c *externalCredential) weeklyResetTime() time.Time {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.weeklyReset
|
||||
}
|
||||
|
||||
func (c *externalCredential) markRateLimited(resetAt time.Time) {
|
||||
c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt)))
|
||||
c.stateMutex.Lock()
|
||||
c.stateAccess.Lock()
|
||||
c.state.hardRateLimited = true
|
||||
c.state.rateLimitResetAt = resetAt
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) earliestReset() time.Time {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
if c.state.hardRateLimited {
|
||||
return c.state.rateLimitResetAt
|
||||
}
|
||||
@@ -408,7 +408,7 @@ func (c *externalCredential) openReverseConnection(ctx context.Context) (net.Con
|
||||
}
|
||||
|
||||
func (c *externalCredential) updateStateFromHeaders(headers http.Header) {
|
||||
c.stateMutex.Lock()
|
||||
c.stateAccess.Lock()
|
||||
isFirstUpdate := c.state.lastUpdated.IsZero()
|
||||
oldFiveHour := c.state.fiveHourUtilization
|
||||
oldWeekly := c.state.weeklyUtilization
|
||||
@@ -455,7 +455,7 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) {
|
||||
c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
|
||||
}
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
@@ -530,9 +530,9 @@ func (c *externalCredential) doPollUsageRequest(ctx context.Context) (*http.Resp
|
||||
}
|
||||
}
|
||||
// Forward transport with retries
|
||||
if c.httpClient != nil {
|
||||
if c.forwardHTTPClient != nil {
|
||||
forwardClient := &http.Client{
|
||||
Transport: c.httpClient.Transport,
|
||||
Transport: c.forwardHTTPClient.Transport,
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
return doHTTPWithRetry(ctx, forwardClient, buildRequest(c.baseURL))
|
||||
@@ -563,10 +563,10 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
|
||||
// 404 means the remote does not have a status endpoint yet;
|
||||
// usage will be updated passively from response headers.
|
||||
if response.StatusCode == http.StatusNotFound {
|
||||
c.stateMutex.Lock()
|
||||
c.stateAccess.Lock()
|
||||
c.state.consecutivePollFailures = 0
|
||||
c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
c.stateAccess.Unlock()
|
||||
} else {
|
||||
c.incrementPollFailures()
|
||||
}
|
||||
@@ -585,7 +585,7 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.stateMutex.Lock()
|
||||
c.stateAccess.Lock()
|
||||
isFirstUpdate := c.state.lastUpdated.IsZero()
|
||||
oldFiveHour := c.state.fiveHourUtilization
|
||||
oldWeekly := c.state.weeklyUtilization
|
||||
@@ -606,28 +606,28 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
|
||||
c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
|
||||
}
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) lastUpdatedTime() time.Time {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.lastUpdated
|
||||
}
|
||||
|
||||
func (c *externalCredential) markUsagePollAttempted() {
|
||||
c.stateMutex.Lock()
|
||||
defer c.stateMutex.Unlock()
|
||||
c.stateAccess.Lock()
|
||||
defer c.stateAccess.Unlock()
|
||||
c.state.lastUpdated = time.Now()
|
||||
}
|
||||
|
||||
func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Duration {
|
||||
c.stateMutex.RLock()
|
||||
c.stateAccess.RLock()
|
||||
failures := c.state.consecutivePollFailures
|
||||
c.stateMutex.RUnlock()
|
||||
c.stateAccess.RUnlock()
|
||||
if failures <= 0 {
|
||||
return baseInterval
|
||||
}
|
||||
@@ -639,17 +639,17 @@ func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Durati
|
||||
}
|
||||
|
||||
func (c *externalCredential) isPollBackoffAtCap() bool {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
failures := c.state.consecutivePollFailures
|
||||
return failures > 0 && failedPollRetryInterval*time.Duration(1<<(failures-1)) >= httpRetryMaxBackoff
|
||||
}
|
||||
|
||||
func (c *externalCredential) incrementPollFailures() {
|
||||
c.stateMutex.Lock()
|
||||
c.stateAccess.Lock()
|
||||
c.state.consecutivePollFailures++
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
@@ -659,14 +659,14 @@ func (c *externalCredential) usageTrackerOrNil() *AggregatedUsage {
|
||||
return c.usageTracker
|
||||
}
|
||||
|
||||
func (c *externalCredential) httpTransport() *http.Client {
|
||||
func (c *externalCredential) httpClient() *http.Client {
|
||||
if c.reverseHttpClient != nil {
|
||||
session := c.getReverseSession()
|
||||
if session != nil && !session.IsClosed() {
|
||||
return c.reverseHttpClient
|
||||
}
|
||||
}
|
||||
return c.httpClient
|
||||
return c.forwardHTTPClient
|
||||
}
|
||||
|
||||
func (c *externalCredential) close() {
|
||||
|
||||
@@ -62,10 +62,10 @@ func (c *defaultCredential) ensureCredentialWatcher() error {
|
||||
}
|
||||
|
||||
func (c *defaultCredential) retryCredentialReloadIfNeeded() {
|
||||
c.stateMutex.RLock()
|
||||
c.stateAccess.RLock()
|
||||
unavailable := c.state.unavailable
|
||||
lastAttempt := c.state.lastCredentialLoadAttempt
|
||||
c.stateMutex.RUnlock()
|
||||
c.stateAccess.RUnlock()
|
||||
if !unavailable {
|
||||
return
|
||||
}
|
||||
@@ -84,10 +84,10 @@ func (c *defaultCredential) reloadCredentials(force bool) error {
|
||||
c.reloadAccess.Lock()
|
||||
defer c.reloadAccess.Unlock()
|
||||
|
||||
c.stateMutex.RLock()
|
||||
c.stateAccess.RLock()
|
||||
unavailable := c.state.unavailable
|
||||
lastAttempt := c.state.lastCredentialLoadAttempt
|
||||
c.stateMutex.RUnlock()
|
||||
c.stateAccess.RUnlock()
|
||||
if !force {
|
||||
if !unavailable {
|
||||
return nil
|
||||
@@ -97,43 +97,43 @@ func (c *defaultCredential) reloadCredentials(force bool) error {
|
||||
}
|
||||
}
|
||||
|
||||
c.stateMutex.Lock()
|
||||
c.stateAccess.Lock()
|
||||
c.state.lastCredentialLoadAttempt = time.Now()
|
||||
c.stateMutex.Unlock()
|
||||
c.stateAccess.Unlock()
|
||||
|
||||
credentials, err := platformReadCredentials(c.credentialPath)
|
||||
if err != nil {
|
||||
return c.markCredentialsUnavailable(E.Cause(err, "read credentials"))
|
||||
}
|
||||
|
||||
c.accessMutex.Lock()
|
||||
c.access.Lock()
|
||||
c.credentials = credentials
|
||||
c.accessMutex.Unlock()
|
||||
c.access.Unlock()
|
||||
|
||||
c.stateMutex.Lock()
|
||||
c.stateAccess.Lock()
|
||||
c.state.unavailable = false
|
||||
c.state.lastCredentialLoadError = ""
|
||||
c.state.accountType = credentials.SubscriptionType
|
||||
c.state.rateLimitTier = credentials.RateLimitTier
|
||||
c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
c.stateAccess.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *defaultCredential) markCredentialsUnavailable(err error) error {
|
||||
c.accessMutex.Lock()
|
||||
c.access.Lock()
|
||||
hadCredentials := c.credentials != nil
|
||||
c.credentials = nil
|
||||
c.accessMutex.Unlock()
|
||||
c.access.Unlock()
|
||||
|
||||
c.stateMutex.Lock()
|
||||
c.stateAccess.Lock()
|
||||
c.state.unavailable = true
|
||||
c.state.lastCredentialLoadError = err.Error()
|
||||
c.state.accountType = ""
|
||||
c.state.rateLimitTier = ""
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
c.stateAccess.Unlock()
|
||||
|
||||
if shouldInterrupt && hadCredentials {
|
||||
c.interruptConnections()
|
||||
|
||||
224
service/ccm/credential_oauth.go
Normal file
224
service/ccm/credential_oauth.go
Normal file
@@ -0,0 +1,224 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
const (
	// oauth2ClientID is the client_id sent with token refresh requests.
	oauth2ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
	// oauth2TokenURL is the endpoint token refresh requests are posted to.
	oauth2TokenURL = "https://platform.claude.com/v1/oauth/token"
	// claudeAPIBaseURL is the base URL of the upstream API.
	claudeAPIBaseURL = "https://api.anthropic.com"
	// tokenRefreshBufferMs is how long before expiry, in milliseconds, a
	// token is already considered in need of a refresh.
	tokenRefreshBufferMs = 60000
	// anthropicBetaOAuthValue is the anthropic-beta header value sent with
	// OAuth-authenticated requests.
	anthropicBetaOAuthValue = "oauth-2025-04-20"
)

// ccmUserAgentFallback is the User-Agent used when the installed Claude Code
// version cannot be detected.
const ccmUserAgentFallback = "claude-code/2.1.72"

var (
	// ccmUserAgentOnce guards the one-time initialization of ccmUserAgentValue.
	ccmUserAgentOnce sync.Once
	// ccmUserAgentValue is the User-Agent sent with upstream requests; it is
	// set by initCCMUserAgent.
	ccmUserAgentValue string
)
|
||||
|
||||
func initCCMUserAgent(logger log.ContextLogger) {
|
||||
ccmUserAgentOnce.Do(func() {
|
||||
version, err := detectClaudeCodeVersion()
|
||||
if err != nil {
|
||||
logger.Error("detect Claude Code version: ", err)
|
||||
ccmUserAgentValue = ccmUserAgentFallback
|
||||
return
|
||||
}
|
||||
logger.Debug("detected Claude Code version: ", version)
|
||||
ccmUserAgentValue = "claude-code/" + version
|
||||
})
|
||||
}
|
||||
|
||||
func detectClaudeCodeVersion() (string, error) {
|
||||
userInfo, err := getRealUser()
|
||||
if err != nil {
|
||||
return "", E.Cause(err, "get user")
|
||||
}
|
||||
binaryName := "claude"
|
||||
if runtime.GOOS == "windows" {
|
||||
binaryName = "claude.exe"
|
||||
}
|
||||
linkPath := filepath.Join(userInfo.HomeDir, ".local", "bin", binaryName)
|
||||
target, err := os.Readlink(linkPath)
|
||||
if err != nil {
|
||||
return "", E.Cause(err, "readlink ", linkPath)
|
||||
}
|
||||
if !filepath.IsAbs(target) {
|
||||
target = filepath.Join(filepath.Dir(linkPath), target)
|
||||
}
|
||||
parent := filepath.Base(filepath.Dir(target))
|
||||
if parent != "versions" {
|
||||
return "", E.New("unexpected symlink target: ", target)
|
||||
}
|
||||
return filepath.Base(target), nil
|
||||
}
|
||||
|
||||
// getRealUser returns the user named by the SUDO_USER environment variable
// when it is set and can be looked up, and the current user otherwise.
func getRealUser() (*user.User, error) {
	sudoUser := os.Getenv("SUDO_USER")
	if sudoUser == "" {
		return user.Current()
	}
	if sudoUserInfo, err := user.Lookup(sudoUser); err == nil {
		return sudoUserInfo, nil
	}
	// SUDO_USER could not be resolved; fall back to the current user.
	return user.Current()
}
|
||||
|
||||
func getDefaultCredentialsPath() (string, error) {
|
||||
if configDir := os.Getenv("CLAUDE_CONFIG_DIR"); configDir != "" {
|
||||
return filepath.Join(configDir, ".credentials.json"), nil
|
||||
}
|
||||
userInfo, err := getRealUser()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(userInfo.HomeDir, ".claude", ".credentials.json"), nil
|
||||
}
|
||||
|
||||
func readCredentialsFromFile(path string) (*oauthCredentials, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var credentialsContainer struct {
|
||||
ClaudeAIAuth *oauthCredentials `json:"claudeAiOauth,omitempty"`
|
||||
}
|
||||
err = json.Unmarshal(data, &credentialsContainer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if credentialsContainer.ClaudeAIAuth == nil {
|
||||
return nil, E.New("claudeAiOauth field not found in credentials")
|
||||
}
|
||||
return credentialsContainer.ClaudeAIAuth, nil
|
||||
}
|
||||
|
||||
// checkCredentialFileWritable verifies that the existing file at path can be
// opened for writing. The file is neither created nor truncated.
func checkCredentialFileWritable(path string) error {
	handle, openErr := os.OpenFile(path, os.O_WRONLY, 0)
	if openErr == nil {
		return handle.Close()
	}
	return openErr
}
|
||||
|
||||
func writeCredentialsToFile(oauthCredentials *oauthCredentials, path string) error {
|
||||
data, err := json.MarshalIndent(map[string]any{
|
||||
"claudeAiOauth": oauthCredentials,
|
||||
}, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(path, data, 0o600)
|
||||
}
|
||||
|
||||
// oauthCredentials is the OAuth credential record stored under the
// "claudeAiOauth" key of the credentials file.
type oauthCredentials struct {
	// AccessToken is the bearer token sent upstream.
	AccessToken string `json:"accessToken"`
	// RefreshToken is exchanged for a new access token on refresh.
	RefreshToken string `json:"refreshToken"`
	// ExpiresAt is the access token expiry in Unix milliseconds; zero
	// disables the expiry-based refresh check.
	ExpiresAt int64 `json:"expiresAt"`
	// Scopes, SubscriptionType, RateLimitTier and IsMax are optional
	// metadata carried along with the tokens.
	Scopes           []string `json:"scopes,omitempty"`
	SubscriptionType string   `json:"subscriptionType,omitempty"`
	RateLimitTier    string   `json:"rateLimitTier,omitempty"`
	IsMax            bool     `json:"isMax,omitempty"`
}
|
||||
|
||||
func (c *oauthCredentials) needsRefresh() bool {
|
||||
if c.ExpiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().UnixMilli() >= c.ExpiresAt-tokenRefreshBufferMs
|
||||
}
|
||||
|
||||
func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) {
|
||||
if credentials.RefreshToken == "" {
|
||||
return nil, E.New("refresh token is empty")
|
||||
}
|
||||
|
||||
requestBody, err := json.Marshal(map[string]string{
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": credentials.RefreshToken,
|
||||
"client_id": oauth2ClientID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "marshal request")
|
||||
}
|
||||
|
||||
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
|
||||
request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request.Header.Set("User-Agent", ccmUserAgentValue)
|
||||
return request, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode == http.StatusTooManyRequests {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return nil, E.New("refresh rate limited: ", response.Status, " ", string(body))
|
||||
}
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return nil, E.New("refresh failed: ", response.Status, " ", string(body))
|
||||
}
|
||||
|
||||
var tokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
err = json.NewDecoder(response.Body).Decode(&tokenResponse)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode response")
|
||||
}
|
||||
|
||||
newCredentials := *credentials
|
||||
newCredentials.AccessToken = tokenResponse.AccessToken
|
||||
if tokenResponse.RefreshToken != "" {
|
||||
newCredentials.RefreshToken = tokenResponse.RefreshToken
|
||||
}
|
||||
newCredentials.ExpiresAt = time.Now().UnixMilli() + int64(tokenResponse.ExpiresIn)*1000
|
||||
|
||||
return &newCredentials, nil
|
||||
}
|
||||
|
||||
func cloneCredentials(credentials *oauthCredentials) *oauthCredentials {
|
||||
if credentials == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *credentials
|
||||
cloned.Scopes = append([]string(nil), credentials.Scopes...)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool {
|
||||
if left == nil || right == nil {
|
||||
return left == right
|
||||
}
|
||||
return left.AccessToken == right.AccessToken &&
|
||||
left.RefreshToken == right.RefreshToken &&
|
||||
left.ExpiresAt == right.ExpiresAt &&
|
||||
slices.Equal(left.Scopes, right.Scopes) &&
|
||||
left.SubscriptionType == right.SubscriptionType &&
|
||||
left.RateLimitTier == right.RateLimitTier &&
|
||||
left.IsMax == right.IsMax
|
||||
}
|
||||
405
service/ccm/credential_provider.go
Normal file
405
service/ccm/credential_provider.go
Normal file
@@ -0,0 +1,405 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand/v2"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
// credentialProvider decides which credential serves a request and keeps the
// session bookkeeping around that decision. This file implements it twice:
// singleCredentialProvider and balancerProvider.
type credentialProvider interface {
	// selectCredential returns the credential to use for sessionID under
	// selection. The bool reports whether this call recorded a new session
	// assignment.
	selectCredential(sessionID string, selection credentialSelection) (credential, bool, error)
	// onRateLimited marks cred as rate-limited until resetAt and returns a
	// replacement credential, or nil when there is none.
	onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential
	// linkProviderInterrupt registers onInterrupt with the provider and returns
	// a func() bool. The balancer returns the stop function of
	// context.AfterFunc; the single provider returns a function that always
	// yields false.
	linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool
	// pollIfStale drops expired session entries and polls usage for
	// credentials whose last update is older than their poll backoff.
	pollIfStale(ctx context.Context)
	// allCredentials returns the credentials managed by the provider.
	allCredentials() []credential
	// close is a no-op in both implementations in this file.
	close()
}
|
||||
|
||||
// singleCredentialProvider is the credentialProvider for exactly one
// credential. It only remembers which session IDs it has already seen.
type singleCredentialProvider struct {
	// cred is the only credential this provider ever returns.
	cred credential
	// sessionAccess guards sessions.
	sessionAccess sync.RWMutex
	// sessions maps a session ID to the time it was first seen. It is created
	// lazily by selectCredential and pruned by pollIfStale.
	sessions map[string]time.Time
}
|
||||
|
||||
func (p *singleCredentialProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) {
|
||||
if !selection.allows(p.cred) {
|
||||
return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out")
|
||||
}
|
||||
if !p.cred.isAvailable() {
|
||||
return nil, false, p.cred.unavailableError()
|
||||
}
|
||||
if !p.cred.isUsable() {
|
||||
return nil, false, E.New("credential ", p.cred.tagName(), " is rate-limited")
|
||||
}
|
||||
var isNew bool
|
||||
if sessionID != "" {
|
||||
p.sessionAccess.Lock()
|
||||
if p.sessions == nil {
|
||||
p.sessions = make(map[string]time.Time)
|
||||
}
|
||||
_, exists := p.sessions[sessionID]
|
||||
if !exists {
|
||||
p.sessions[sessionID] = time.Now()
|
||||
isNew = true
|
||||
}
|
||||
p.sessionAccess.Unlock()
|
||||
}
|
||||
return p.cred, isNew, nil
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ credentialSelection) credential {
|
||||
cred.markRateLimited(resetAt)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) pollIfStale(ctx context.Context) {
|
||||
now := time.Now()
|
||||
p.sessionAccess.Lock()
|
||||
for id, createdAt := range p.sessions {
|
||||
if now.Sub(createdAt) > sessionExpiry {
|
||||
delete(p.sessions, id)
|
||||
}
|
||||
}
|
||||
p.sessionAccess.Unlock()
|
||||
|
||||
if time.Since(p.cred.lastUpdatedTime()) > p.cred.pollBackoff(defaultPollInterval) {
|
||||
p.cred.pollUsage(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) allCredentials() []credential {
|
||||
return []credential{p.cred}
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool {
|
||||
return func() bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// close does nothing; it exists to satisfy credentialProvider.
func (p *singleCredentialProvider) close() {}
|
||||
|
||||
// sessionEntry records which credential a session is pinned to in balancerProvider.
type sessionEntry struct {
	// tag is the tagName() of the credential chosen for the session.
	tag string
	// selectionScope is the scope in effect when the session was pinned; a
	// lookup under a different scope does not reuse the entry.
	selectionScope credentialSelectionScope
	// createdAt is when the entry was stored; pollIfStale expires entries by it.
	createdAt time.Time
}
|
||||
|
||||
// credentialInterruptKey identifies one interrupt context in
// balancerProvider.credentialInterrupts: a credential tag paired with a
// selection scope.
type credentialInterruptKey struct {
	tag            string
	selectionScope credentialSelectionScope
}
|
||||
|
||||
// credentialInterruptEntry holds a cancellable context together with its
// cancel function. Callbacks registered through linkProviderInterrupt are
// attached to context; rebalanceCredential calls cancel and installs a fresh entry.
type credentialInterruptEntry struct {
	context context.Context
	cancel  context.CancelFunc
}
|
||||
|
||||
// balancerProvider is the credentialProvider that spreads sessions over
// several credentials according to strategy.
type balancerProvider struct {
	// credentials is the candidate list, iterated in order by the pick* methods.
	credentials []credential
	// strategy selects the pick* method (see pickCredential); values that match
	// no known strategy, including "", use pickLeastUsed.
	strategy string
	// roundRobinIndex is the monotonically increasing ticket counter of pickRoundRobin.
	roundRobinIndex atomic.Uint64
	// pollInterval is the base interval handed to each credential's pollBackoff.
	pollInterval time.Duration
	// rebalanceThreshold, when > 0 under the least-used strategy, lets
	// selectCredential move sessions off a credential whose weekly utilization
	// exceeds the best alternative's by more than threshold / planWeight.
	rebalanceThreshold float64
	// sessionAccess guards sessions.
	sessionAccess sync.RWMutex
	sessions      map[string]sessionEntry
	// interruptAccess guards credentialInterrupts.
	interruptAccess      sync.Mutex
	credentialInterrupts map[credentialInterruptKey]credentialInterruptEntry
	logger               log.ContextLogger
}
|
||||
|
||||
func newBalancerProvider(credentials []credential, strategy string, pollInterval time.Duration, rebalanceThreshold float64, logger log.ContextLogger) *balancerProvider {
|
||||
if pollInterval <= 0 {
|
||||
pollInterval = defaultPollInterval
|
||||
}
|
||||
return &balancerProvider{
|
||||
credentials: credentials,
|
||||
strategy: strategy,
|
||||
pollInterval: pollInterval,
|
||||
rebalanceThreshold: rebalanceThreshold,
|
||||
sessions: make(map[string]sessionEntry),
|
||||
credentialInterrupts: make(map[credentialInterruptKey]credentialInterruptEntry),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) {
|
||||
if p.strategy == C.BalancerStrategyFallback {
|
||||
best := p.pickCredential(selection.filter)
|
||||
if best == nil {
|
||||
return nil, false, allCredentialsUnavailableError(p.credentials)
|
||||
}
|
||||
return best, false, nil
|
||||
}
|
||||
|
||||
selectionScope := selection.scopeOrDefault()
|
||||
if sessionID != "" {
|
||||
p.sessionAccess.RLock()
|
||||
entry, exists := p.sessions[sessionID]
|
||||
p.sessionAccess.RUnlock()
|
||||
if exists {
|
||||
if entry.selectionScope == selectionScope {
|
||||
for _, cred := range p.credentials {
|
||||
if cred.tagName() == entry.tag && selection.allows(cred) && cred.isUsable() {
|
||||
if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == C.BalancerStrategyLeastUsed) {
|
||||
better := p.pickLeastUsed(selection.filter)
|
||||
if better != nil && better.tagName() != cred.tagName() {
|
||||
effectiveThreshold := p.rebalanceThreshold / cred.planWeight()
|
||||
delta := cred.weeklyUtilization() - better.weeklyUtilization()
|
||||
if delta > effectiveThreshold {
|
||||
p.logger.Info("rebalancing away from ", cred.tagName(),
|
||||
": utilization delta ", delta, "% exceeds effective threshold ",
|
||||
effectiveThreshold, "% (weight ", cred.planWeight(), ")")
|
||||
p.rebalanceCredential(cred.tagName(), selectionScope)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return cred, false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
p.sessionAccess.Lock()
|
||||
delete(p.sessions, sessionID)
|
||||
p.sessionAccess.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
best := p.pickCredential(selection.filter)
|
||||
if best == nil {
|
||||
return nil, false, allCredentialsUnavailableError(p.credentials)
|
||||
}
|
||||
|
||||
isNew := sessionID != ""
|
||||
if isNew {
|
||||
p.sessionAccess.Lock()
|
||||
p.sessions[sessionID] = sessionEntry{
|
||||
tag: best.tagName(),
|
||||
selectionScope: selectionScope,
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
p.sessionAccess.Unlock()
|
||||
}
|
||||
return best, isNew, nil
|
||||
}
|
||||
|
||||
func (p *balancerProvider) rebalanceCredential(tag string, selectionScope credentialSelectionScope) {
|
||||
key := credentialInterruptKey{tag: tag, selectionScope: selectionScope}
|
||||
p.interruptAccess.Lock()
|
||||
if entry, loaded := p.credentialInterrupts[key]; loaded {
|
||||
entry.cancel()
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
p.credentialInterrupts[key] = credentialInterruptEntry{context: ctx, cancel: cancel}
|
||||
p.interruptAccess.Unlock()
|
||||
|
||||
p.sessionAccess.Lock()
|
||||
for id, entry := range p.sessions {
|
||||
if entry.tag == tag && entry.selectionScope == selectionScope {
|
||||
delete(p.sessions, id)
|
||||
}
|
||||
}
|
||||
p.sessionAccess.Unlock()
|
||||
}
|
||||
|
||||
func (p *balancerProvider) linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool {
|
||||
if p.strategy == C.BalancerStrategyFallback {
|
||||
return func() bool { return false }
|
||||
}
|
||||
key := credentialInterruptKey{
|
||||
tag: cred.tagName(),
|
||||
selectionScope: selection.scopeOrDefault(),
|
||||
}
|
||||
p.interruptAccess.Lock()
|
||||
entry, loaded := p.credentialInterrupts[key]
|
||||
if !loaded {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
entry = credentialInterruptEntry{context: ctx, cancel: cancel}
|
||||
p.credentialInterrupts[key] = entry
|
||||
}
|
||||
p.interruptAccess.Unlock()
|
||||
return context.AfterFunc(entry.context, onInterrupt)
|
||||
}
|
||||
|
||||
func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential {
|
||||
cred.markRateLimited(resetAt)
|
||||
if p.strategy == C.BalancerStrategyFallback {
|
||||
return p.pickCredential(selection.filter)
|
||||
}
|
||||
if sessionID != "" {
|
||||
p.sessionAccess.Lock()
|
||||
delete(p.sessions, sessionID)
|
||||
p.sessionAccess.Unlock()
|
||||
}
|
||||
|
||||
best := p.pickCredential(selection.filter)
|
||||
if best != nil && sessionID != "" {
|
||||
p.sessionAccess.Lock()
|
||||
p.sessions[sessionID] = sessionEntry{
|
||||
tag: best.tagName(),
|
||||
selectionScope: selection.scopeOrDefault(),
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
p.sessionAccess.Unlock()
|
||||
}
|
||||
return best
|
||||
}
|
||||
|
||||
func (p *balancerProvider) pickCredential(filter func(credential) bool) credential {
|
||||
switch p.strategy {
|
||||
case C.BalancerStrategyRoundRobin:
|
||||
return p.pickRoundRobin(filter)
|
||||
case C.BalancerStrategyRandom:
|
||||
return p.pickRandom(filter)
|
||||
case C.BalancerStrategyFallback:
|
||||
return p.pickFallback(filter)
|
||||
default:
|
||||
return p.pickLeastUsed(filter)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *balancerProvider) pickFallback(filter func(credential) bool) credential {
|
||||
for _, cred := range p.credentials {
|
||||
if filter != nil && !filter(cred) {
|
||||
continue
|
||||
}
|
||||
if cred.isUsable() {
|
||||
return cred
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// weeklyWindowHours is the length of one week in hours. pickLeastUsed divides
// it by the hours remaining until a credential's weekly reset.
const weeklyWindowHours = 7 * 24
|
||||
|
||||
func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credential {
|
||||
var best credential
|
||||
bestScore := float64(-1)
|
||||
now := time.Now()
|
||||
for _, cred := range p.credentials {
|
||||
if filter != nil && !filter(cred) {
|
||||
continue
|
||||
}
|
||||
if !cred.isUsable() {
|
||||
continue
|
||||
}
|
||||
remaining := cred.weeklyCap() - cred.weeklyUtilization()
|
||||
score := remaining * cred.planWeight()
|
||||
resetTime := cred.weeklyResetTime()
|
||||
if !resetTime.IsZero() {
|
||||
timeUntilReset := resetTime.Sub(now)
|
||||
if timeUntilReset < time.Hour {
|
||||
timeUntilReset = time.Hour
|
||||
}
|
||||
score *= weeklyWindowHours / timeUntilReset.Hours()
|
||||
}
|
||||
if score > bestScore {
|
||||
bestScore = score
|
||||
best = cred
|
||||
}
|
||||
}
|
||||
return best
|
||||
}
|
||||
|
||||
func (p *balancerProvider) pickRoundRobin(filter func(credential) bool) credential {
|
||||
start := int(p.roundRobinIndex.Add(1) - 1)
|
||||
count := len(p.credentials)
|
||||
for offset := range count {
|
||||
candidate := p.credentials[(start+offset)%count]
|
||||
if filter != nil && !filter(candidate) {
|
||||
continue
|
||||
}
|
||||
if candidate.isUsable() {
|
||||
return candidate
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *balancerProvider) pickRandom(filter func(credential) bool) credential {
|
||||
var usable []credential
|
||||
for _, candidate := range p.credentials {
|
||||
if filter != nil && !filter(candidate) {
|
||||
continue
|
||||
}
|
||||
if candidate.isUsable() {
|
||||
usable = append(usable, candidate)
|
||||
}
|
||||
}
|
||||
if len(usable) == 0 {
|
||||
return nil
|
||||
}
|
||||
return usable[rand.IntN(len(usable))]
|
||||
}
|
||||
|
||||
func (p *balancerProvider) pollIfStale(ctx context.Context) {
|
||||
now := time.Now()
|
||||
p.sessionAccess.Lock()
|
||||
for id, entry := range p.sessions {
|
||||
if now.Sub(entry.createdAt) > sessionExpiry {
|
||||
delete(p.sessions, id)
|
||||
}
|
||||
}
|
||||
p.sessionAccess.Unlock()
|
||||
|
||||
for _, cred := range p.credentials {
|
||||
if time.Since(cred.lastUpdatedTime()) > cred.pollBackoff(p.pollInterval) {
|
||||
cred.pollUsage(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// allCredentials returns the provider's credential slice itself, not a copy.
func (p *balancerProvider) allCredentials() []credential {
	return p.credentials
}
|
||||
|
||||
// close does nothing; it exists to satisfy credentialProvider.
func (p *balancerProvider) close() {}
|
||||
|
||||
// ccmPlanWeight maps an account type and rate-limit tier to a relative weight:
//
//	"max"  + "default_claude_max_20x" -> 10
//	"max"  + any other tier           -> 5
//	"team" + "default_claude_max_5x"  -> 5
//	anything else                     -> 1
func ccmPlanWeight(accountType string, rateLimitTier string) float64 {
	if accountType == "max" {
		if rateLimitTier == "default_claude_max_20x" {
			return 10
		}
		return 5
	}
	if accountType == "team" && rateLimitTier == "default_claude_max_5x" {
		return 5
	}
	return 1
}
|
||||
|
||||
func allCredentialsUnavailableError(credentials []credential) error {
|
||||
var hasUnavailable bool
|
||||
var earliest time.Time
|
||||
for _, cred := range credentials {
|
||||
if cred.unavailableError() != nil {
|
||||
hasUnavailable = true
|
||||
continue
|
||||
}
|
||||
resetAt := cred.earliestReset()
|
||||
if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) {
|
||||
earliest = resetAt
|
||||
}
|
||||
}
|
||||
if hasUnavailable {
|
||||
return E.New("all credentials unavailable")
|
||||
}
|
||||
if earliest.IsZero() {
|
||||
return E.New("all credentials rate-limited")
|
||||
}
|
||||
return E.New("all credentials rate-limited, earliest reset in ", log.FormatDuration(time.Until(earliest)))
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,17 +1,12 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
boxService "github.com/sagernet/sing-box/adapter/service"
|
||||
@@ -21,23 +16,16 @@ import (
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
)
|
||||
|
||||
const (
	// contextWindowStandard is the window size detectContextWindow reports by default.
	contextWindowStandard = 200000
	// contextWindowPremium is reported for extended-context requests whose
	// total input tokens exceed premiumContextThreshold.
	contextWindowPremium = 1000000
	// premiumContextThreshold is the input-token count above which the premium
	// window can apply.
	premiumContextThreshold = 200000
	// retryableUsageMessage is the text for a usage-limit error that the
	// client can resolve by retrying.
	retryableUsageMessage = "current credential reached its usage limit; retry the request to use another credential"
)
|
||||
// retryableUsageMessage is the text for a usage-limit error that the client
// can resolve by retrying.
const retryableUsageMessage = "current credential reached its usage limit; retry the request to use another credential"
|
||||
|
||||
func RegisterService(registry *boxService.Registry) {
|
||||
boxService.Register[option.CCMServiceOptions](registry, C.TypeCCM, NewService)
|
||||
@@ -152,23 +140,6 @@ func isReverseProxyHeader(header string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
const (
	// weeklyWindowSeconds is seven days expressed in seconds (7 * 24 * 3600).
	weeklyWindowSeconds = 604800
	// weeklyWindowMinutes is the same window in minutes; extractWeeklyCycleHint
	// stores it in WeeklyCycleHint.WindowMinutes.
	weeklyWindowMinutes = weeklyWindowSeconds / 60
)
|
||||
|
||||
func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint {
|
||||
resetAt, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset")
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &WeeklyCycleHint{
|
||||
WindowMinutes: weeklyWindowMinutes,
|
||||
ResetAt: resetAt.UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
boxService.Adapter
|
||||
ctx context.Context
|
||||
@@ -308,545 +279,6 @@ func (s *Service) Start(stage adapter.StartStage) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// isExtendedContextRequest reports whether any comma-separated entry of
// betaHeader, after trimming surrounding whitespace, starts with "context-1m".
func isExtendedContextRequest(betaHeader string) bool {
	remaining := betaHeader
	for {
		feature, rest, more := strings.Cut(remaining, ",")
		if strings.HasPrefix(strings.TrimSpace(feature), "context-1m") {
			return true
		}
		if !more {
			return false
		}
		remaining = rest
	}
}
|
||||
|
||||
// isFastModeRequest reports whether any comma-separated entry of betaHeader,
// after trimming surrounding whitespace, starts with "fast-mode".
func isFastModeRequest(betaHeader string) bool {
	remaining := betaHeader
	for {
		feature, rest, more := strings.Cut(remaining, ",")
		if strings.HasPrefix(strings.TrimSpace(feature), "fast-mode") {
			return true
		}
		if !more {
			return false
		}
		remaining = rest
	}
}
|
||||
|
||||
func detectContextWindow(betaHeader string, totalInputTokens int64) int {
|
||||
if totalInputTokens > premiumContextThreshold {
|
||||
if isExtendedContextRequest(betaHeader) {
|
||||
return contextWindowPremium
|
||||
}
|
||||
}
|
||||
return contextWindowStandard
|
||||
}
|
||||
|
||||
// ServeHTTP is the service's HTTP entry point.
//
// "/ccm/v1/status" and "/ccm/v1/reverse" are handed to their own handlers.
// Any other path must start with "/v1/" or the request is answered with 404.
// For "/v1/" requests the handler:
//
//  1. authenticates the Bearer token when users are configured;
//  2. buffers the request body to read the model, message count and session ID;
//  3. resolves the credential provider and selects a credential for the session;
//  4. forwards the request upstream, transparently retrying with the
//     credential returned by provider.onRateLimited while upstream answers 429;
//  5. relays the response, through handleResponseWithTracking when the
//     credential has a usage tracker and the status is 200, otherwise as a
//     plain copy or a flushed stream.
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	ctx := log.ContextWithNewID(r.Context())
	if r.URL.Path == "/ccm/v1/status" {
		s.handleStatusEndpoint(w, r)
		return
	}

	if r.URL.Path == "/ccm/v1/reverse" {
		s.handleReverseConnect(ctx, w, r)
		return
	}

	if !strings.HasPrefix(r.URL.Path, "/v1/") {
		writeJSONError(w, r, http.StatusNotFound, "not_found_error", "Not found")
		return
	}

	// Authentication is enforced only when at least one user is configured.
	var username string
	if len(s.options.Users) > 0 {
		authHeader := r.Header.Get("Authorization")
		if authHeader == "" {
			s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
			writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
			return
		}
		// TrimPrefix returns its input unchanged when the prefix is absent,
		// which is how a missing "Bearer " scheme is detected.
		clientToken := strings.TrimPrefix(authHeader, "Bearer ")
		if clientToken == authHeader {
			s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
			writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
			return
		}
		var ok bool
		username, ok = s.userManager.Authenticate(clientToken)
		if !ok {
			// NOTE(review): the rejected token is written to the log verbatim — confirm this is acceptable.
			s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
			writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
			return
		}
	}

	// Always read body to extract model and session ID
	var bodyBytes []byte
	var requestModel string
	var messagesCount int
	var sessionID string

	if r.Body != nil {
		var err error
		bodyBytes, err = io.ReadAll(r.Body)
		if err != nil {
			s.logger.ErrorContext(ctx, "read request body: ", err)
			writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body")
			return
		}

		var request struct {
			Model    string                   `json:"model"`
			Messages []anthropic.MessageParam `json:"messages"`
		}
		// A body that does not parse is still forwarded; only the metadata is lost.
		err = json.Unmarshal(bodyBytes, &request)
		if err == nil {
			requestModel = request.Model
			messagesCount = len(request.Messages)
		}

		sessionID = extractCCMSessionID(bodyBytes)
		// Restore the body so it can be read again from the buffered bytes.
		r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
	}

	// Resolve credential provider and user config
	var provider credentialProvider
	var userConfig *option.CCMUser
	if len(s.options.Users) > 0 {
		userConfig = s.userConfigMap[username]
		var err error
		provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
		if err != nil {
			s.logger.ErrorContext(ctx, "resolve credential: ", err)
			writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
			return
		}
	} else {
		provider = noUserCredentialProvider(s.providers, s.legacyProvider, s.options)
	}
	if provider == nil {
		writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available")
		return
	}

	provider.pollIfStale(s.ctx)

	anthropicBetaHeader := r.Header.Get("anthropic-beta")
	// Fast mode is only accepted when the provider is a singleCredentialProvider.
	if isFastModeRequest(anthropicBetaHeader) {
		if _, isSingle := provider.(*singleCredentialProvider); !isSingle {
			writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
				"fast mode requests will consume Extra usage, please use a default credential directly")
			return
		}
	}

	selection := credentialSelectionForUser(userConfig)

	selectedCredential, isNew, err := provider.selectCredential(sessionID, selection)
	if err != nil {
		writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, err.Error()))
		return
	}
	if isNew {
		logParts := []any{"assigned credential ", selectedCredential.tagName()}
		if sessionID != "" {
			logParts = append(logParts, " for session ", sessionID)
		}
		if username != "" {
			logParts = append(logParts, " by user ", username)
		}
		if requestModel != "" {
			modelDisplay := requestModel
			if isExtendedContextRequest(anthropicBetaHeader) {
				modelDisplay += "[1m]"
			}
			logParts = append(logParts, ", model=", modelDisplay)
		}
		s.logger.DebugContext(ctx, logParts...)
	}

	if isFastModeRequest(anthropicBetaHeader) && selectedCredential.isExternal() {
		writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
			"fast mode requests cannot be proxied through external credentials")
		return
	}

	requestContext := selectedCredential.wrapRequestContext(ctx)
	{
		// Capture this specific context: requestContext is reassigned on retry,
		// and the interrupt callback must cancel the one it was linked to.
		currentRequestContext := requestContext
		requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() {
			currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
		}))
	}
	// The closure reads requestContext when it runs, so it cancels whichever
	// context is current at return time (the last retry's, if any).
	defer func() {
		requestContext.cancelRequest()
	}()
	proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
	if err != nil {
		s.logger.ErrorContext(ctx, "create proxy request: ", err)
		writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
		return
	}

	response, err := selectedCredential.httpTransport().Do(proxyRequest)
	if err != nil {
		// The client went away: nothing can be written back.
		if r.Context().Err() != nil {
			return
		}
		// The request context ended while the client is still connected.
		if requestContext.Err() != nil {
			writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable while processing the request")
			return
		}
		writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
		return
	}
	requestContext.releaseCredentialInterrupt()

	// Transparent 429 retry
	for response.StatusCode == http.StatusTooManyRequests {
		resetAt := parseRateLimitResetFromHeaders(response.Header)
		nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection)
		selectedCredential.updateStateFromHeaders(response.Header)
		// Without a buffered body or another credential there is nothing to retry with.
		if bodyBytes == nil || nextCredential == nil {
			response.Body.Close()
			writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited")
			return
		}
		response.Body.Close()
		s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
		requestContext.cancelRequest()
		requestContext = nextCredential.wrapRequestContext(ctx)
		{
			currentRequestContext := requestContext
			requestContext.addInterruptLink(provider.linkProviderInterrupt(nextCredential, selection, func() {
				currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
			}))
		}
		retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
		if buildErr != nil {
			s.logger.ErrorContext(ctx, "retry request: ", buildErr)
			writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
			return
		}
		retryResponse, retryErr := nextCredential.httpTransport().Do(retryRequest)
		if retryErr != nil {
			if r.Context().Err() != nil {
				return
			}
			if requestContext.Err() != nil {
				writeCredentialUnavailableError(w, r, provider, nextCredential, selection, "credential became unavailable while retrying the request")
				return
			}
			s.logger.ErrorContext(ctx, "retry request: ", retryErr)
			writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error())
			return
		}
		requestContext.releaseCredentialInterrupt()
		response = retryResponse
		selectedCredential = nextCredential
	}
	defer response.Body.Close()

	selectedCredential.updateStateFromHeaders(response.Header)

	// Every non-200 upstream status is reported to the client as a 500 that
	// carries the upstream status code and body in its message.
	if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
		body, _ := io.ReadAll(response.Body)
		s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
		go selectedCredential.pollUsage(s.ctx)
		writeJSONError(w, r, http.StatusInternalServerError, "api_error",
			"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
		return
	}

	// Rewrite response headers for external users
	if userConfig != nil && userConfig.ExternalCredential != "" {
		s.rewriteResponseHeadersForExternalUser(response.Header, userConfig)
	}

	for key, values := range response.Header {
		if !isHopByHopHeader(key) && !isReverseProxyHeader(key) {
			w.Header()[key] = values
		}
	}
	w.WriteHeader(response.StatusCode)

	usageTracker := selectedCredential.usageTrackerOrNil()
	if usageTracker != nil && response.StatusCode == http.StatusOK {
		s.handleResponseWithTracking(ctx, w, response, usageTracker, requestModel, anthropicBetaHeader, messagesCount, username)
	} else {
		mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
		// A parseable, non-SSE content type is copied through in one go. An
		// unparseable content type falls through to the flushing loop below.
		if err == nil && mediaType != "text/event-stream" {
			_, _ = io.Copy(w, response.Body)
			return
		}
		flusher, ok := w.(http.Flusher)
		if !ok {
			s.logger.ErrorContext(ctx, "streaming not supported")
			return
		}
		buffer := make([]byte, buf.BufferSize)
		for {
			n, err := response.Body.Read(buffer)
			// Forward any bytes read before looking at err: Read may return
			// data together with an error.
			if n > 0 {
				_, writeError := w.Write(buffer[:n])
				if writeError != nil {
					s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
					return
				}
				flusher.Flush()
			}
			if err != nil {
				return
			}
		}
	}
}
|
||||
|
||||
func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, requestModel string, anthropicBetaHeader string, messagesCount int, username string) {
|
||||
weeklyCycleHint := extractWeeklyCycleHint(response.Header)
|
||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||
isStreaming := err == nil && mediaType == "text/event-stream"
|
||||
|
||||
if !isStreaming {
|
||||
bodyBytes, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "read response body: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
var message anthropic.Message
|
||||
var usage anthropic.Usage
|
||||
var responseModel string
|
||||
err = json.Unmarshal(bodyBytes, &message)
|
||||
if err == nil {
|
||||
responseModel = string(message.Model)
|
||||
usage = message.Usage
|
||||
}
|
||||
if responseModel == "" {
|
||||
responseModel = requestModel
|
||||
}
|
||||
|
||||
if usage.InputTokens > 0 || usage.OutputTokens > 0 {
|
||||
if responseModel != "" {
|
||||
totalInputTokens := usage.InputTokens + usage.CacheCreationInputTokens + usage.CacheReadInputTokens
|
||||
contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens)
|
||||
usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
contextWindow,
|
||||
messagesCount,
|
||||
usage.InputTokens,
|
||||
usage.OutputTokens,
|
||||
usage.CacheReadInputTokens,
|
||||
usage.CacheCreationInputTokens,
|
||||
usage.CacheCreation.Ephemeral5mInputTokens,
|
||||
usage.CacheCreation.Ephemeral1hInputTokens,
|
||||
username,
|
||||
time.Now(),
|
||||
weeklyCycleHint,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
_, _ = writer.Write(bodyBytes)
|
||||
return
|
||||
}
|
||||
|
||||
flusher, ok := writer.(http.Flusher)
|
||||
if !ok {
|
||||
s.logger.ErrorContext(ctx, "streaming not supported")
|
||||
return
|
||||
}
|
||||
|
||||
var accumulatedUsage anthropic.Usage
|
||||
var responseModel string
|
||||
buffer := make([]byte, buf.BufferSize)
|
||||
var leftover []byte
|
||||
|
||||
for {
|
||||
n, err := response.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
data := append(leftover, buffer[:n]...)
|
||||
lines := bytes.Split(data, []byte("\n"))
|
||||
|
||||
if err == nil {
|
||||
leftover = lines[len(lines)-1]
|
||||
lines = lines[:len(lines)-1]
|
||||
} else {
|
||||
leftover = nil
|
||||
}
|
||||
|
||||
for _, line := range lines {
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if bytes.HasPrefix(line, []byte("data: ")) {
|
||||
eventData := bytes.TrimPrefix(line, []byte("data: "))
|
||||
if bytes.Equal(eventData, []byte("[DONE]")) {
|
||||
continue
|
||||
}
|
||||
|
||||
var event anthropic.MessageStreamEventUnion
|
||||
err := json.Unmarshal(eventData, &event)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
switch event.Type {
|
||||
case "message_start":
|
||||
messageStart := event.AsMessageStart()
|
||||
if messageStart.Message.Model != "" {
|
||||
responseModel = string(messageStart.Message.Model)
|
||||
}
|
||||
if messageStart.Message.Usage.InputTokens > 0 {
|
||||
accumulatedUsage.InputTokens = messageStart.Message.Usage.InputTokens
|
||||
accumulatedUsage.CacheReadInputTokens = messageStart.Message.Usage.CacheReadInputTokens
|
||||
accumulatedUsage.CacheCreationInputTokens = messageStart.Message.Usage.CacheCreationInputTokens
|
||||
accumulatedUsage.CacheCreation.Ephemeral5mInputTokens = messageStart.Message.Usage.CacheCreation.Ephemeral5mInputTokens
|
||||
accumulatedUsage.CacheCreation.Ephemeral1hInputTokens = messageStart.Message.Usage.CacheCreation.Ephemeral1hInputTokens
|
||||
}
|
||||
case "message_delta":
|
||||
messageDelta := event.AsMessageDelta()
|
||||
if messageDelta.Usage.OutputTokens > 0 {
|
||||
accumulatedUsage.OutputTokens = messageDelta.Usage.OutputTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_, writeError := writer.Write(buffer[:n])
|
||||
if writeError != nil {
|
||||
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if responseModel == "" {
|
||||
responseModel = requestModel
|
||||
}
|
||||
|
||||
if accumulatedUsage.InputTokens > 0 || accumulatedUsage.OutputTokens > 0 {
|
||||
if responseModel != "" {
|
||||
totalInputTokens := accumulatedUsage.InputTokens + accumulatedUsage.CacheCreationInputTokens + accumulatedUsage.CacheReadInputTokens
|
||||
contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens)
|
||||
usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
contextWindow,
|
||||
messagesCount,
|
||||
accumulatedUsage.InputTokens,
|
||||
accumulatedUsage.OutputTokens,
|
||||
accumulatedUsage.CacheReadInputTokens,
|
||||
accumulatedUsage.CacheCreationInputTokens,
|
||||
accumulatedUsage.CacheCreation.Ephemeral5mInputTokens,
|
||||
accumulatedUsage.CacheCreation.Ephemeral1hInputTokens,
|
||||
username,
|
||||
time.Now(),
|
||||
weeklyCycleHint,
|
||||
)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
if len(s.options.Users) == 0 {
|
||||
writeJSONError(w, r, http.StatusForbidden, "authentication_error", "status endpoint requires user authentication")
|
||||
return
|
||||
}
|
||||
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
||||
return
|
||||
}
|
||||
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if clientToken == authHeader {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
||||
return
|
||||
}
|
||||
username, ok := s.userManager.Authenticate(clientToken)
|
||||
if !ok {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
|
||||
return
|
||||
}
|
||||
|
||||
userConfig := s.userConfigMap[username]
|
||||
if userConfig == nil {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "user config not found")
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
|
||||
if err != nil {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
provider.pollIfStale(r.Context())
|
||||
avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]float64{
|
||||
"five_hour_utilization": avgFiveHour,
|
||||
"weekly_utilization": avgWeekly,
|
||||
"plan_weight": totalWeight,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) (float64, float64, float64) {
|
||||
var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64
|
||||
for _, cred := range provider.allCredentials() {
|
||||
if !cred.isAvailable() {
|
||||
continue
|
||||
}
|
||||
if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential {
|
||||
continue
|
||||
}
|
||||
if !userConfig.AllowExternalUsage && cred.isExternal() {
|
||||
continue
|
||||
}
|
||||
weight := cred.planWeight()
|
||||
remaining5h := cred.fiveHourCap() - cred.fiveHourUtilization()
|
||||
if remaining5h < 0 {
|
||||
remaining5h = 0
|
||||
}
|
||||
remainingWeekly := cred.weeklyCap() - cred.weeklyUtilization()
|
||||
if remainingWeekly < 0 {
|
||||
remainingWeekly = 0
|
||||
}
|
||||
totalWeightedRemaining5h += remaining5h * weight
|
||||
totalWeightedRemainingWeekly += remainingWeekly * weight
|
||||
totalWeight += weight
|
||||
}
|
||||
if totalWeight == 0 {
|
||||
return 100, 100, 0
|
||||
}
|
||||
return 100 - totalWeightedRemaining5h/totalWeight,
|
||||
100 - totalWeightedRemainingWeekly/totalWeight,
|
||||
totalWeight
|
||||
}
|
||||
|
||||
func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.CCMUser) {
|
||||
provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, userConfig.Name)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
|
||||
|
||||
// Rewrite utilization headers to aggregated average (convert back to 0.0-1.0 range)
|
||||
headers.Set("anthropic-ratelimit-unified-5h-utilization", strconv.FormatFloat(avgFiveHour/100, 'f', 6, 64))
|
||||
headers.Set("anthropic-ratelimit-unified-7d-utilization", strconv.FormatFloat(avgWeekly/100, 'f', 6, 64))
|
||||
if totalWeight > 0 {
|
||||
headers.Set("X-CCM-Plan-Weight", strconv.FormatFloat(totalWeight, 'f', -1, 64))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) InterfaceUpdated() {
|
||||
for _, cred := range s.allCredentials {
|
||||
extCred, ok := cred.(*externalCredential)
|
||||
|
||||
499
service/ccm/service_handler.go
Normal file
499
service/ccm/service_handler.go
Normal file
@@ -0,0 +1,499 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
)
|
||||
|
||||
// Context window sizes, in tokens, returned by detectContextWindow, and the
// total input token count above which the premium window may apply.
const (
	contextWindowStandard   = 200_000
	contextWindowPremium    = 1_000_000
	premiumContextThreshold = 200_000
)

// Length of the weekly window: seven days expressed in seconds and minutes.
const (
	weeklyWindowSeconds = 7 * 24 * 60 * 60
	weeklyWindowMinutes = weeklyWindowSeconds / 60
)
|
||||
|
||||
// isExtendedContextRequest reports whether any comma-separated entry of
// betaHeader, after trimming surrounding whitespace, starts with
// "context-1m".
func isExtendedContextRequest(betaHeader string) bool {
	for rest, more := betaHeader, true; more; {
		var feature string
		feature, rest, more = strings.Cut(rest, ",")
		if strings.HasPrefix(strings.TrimSpace(feature), "context-1m") {
			return true
		}
	}
	return false
}
|
||||
|
||||
// isFastModeRequest reports whether any comma-separated entry of betaHeader,
// after trimming surrounding whitespace, starts with "fast-mode".
func isFastModeRequest(betaHeader string) bool {
	for rest, more := betaHeader, true; more; {
		var feature string
		feature, rest, more = strings.Cut(rest, ",")
		if strings.HasPrefix(strings.TrimSpace(feature), "fast-mode") {
			return true
		}
	}
	return false
}
|
||||
|
||||
func detectContextWindow(betaHeader string, totalInputTokens int64) int {
|
||||
if totalInputTokens > premiumContextThreshold {
|
||||
if isExtendedContextRequest(betaHeader) {
|
||||
return contextWindowPremium
|
||||
}
|
||||
}
|
||||
return contextWindowStandard
|
||||
}
|
||||
|
||||
func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint {
|
||||
resetAt, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset")
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &WeeklyCycleHint{
|
||||
WindowMinutes: weeklyWindowMinutes,
|
||||
ResetAt: resetAt.UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
// extractCCMSessionID returns the part of metadata.user_id that follows the
// last "_session_" marker in a JSON request body. It returns the empty string
// when the body is not valid JSON or the marker is absent.
func extractCCMSessionID(bodyBytes []byte) string {
	const sessionMarker = "_session_"

	var payload struct {
		Metadata struct {
			UserID string `json:"user_id"`
		} `json:"metadata"`
	}
	if json.Unmarshal(bodyBytes, &payload) != nil {
		return ""
	}

	markerAt := strings.LastIndex(payload.Metadata.UserID, sessionMarker)
	if markerAt == -1 {
		return ""
	}
	return payload.Metadata.UserID[markerAt+len(sessionMarker):]
}
|
||||
|
||||
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := log.ContextWithNewID(r.Context())
|
||||
if r.URL.Path == "/ccm/v1/status" {
|
||||
s.handleStatusEndpoint(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Path == "/ccm/v1/reverse" {
|
||||
s.handleReverseConnect(ctx, w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(r.URL.Path, "/v1/") {
|
||||
writeJSONError(w, r, http.StatusNotFound, "not_found_error", "Not found")
|
||||
return
|
||||
}
|
||||
|
||||
var username string
|
||||
if len(s.options.Users) > 0 {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
||||
return
|
||||
}
|
||||
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if clientToken == authHeader {
|
||||
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
||||
return
|
||||
}
|
||||
var ok bool
|
||||
username, ok = s.userManager.Authenticate(clientToken)
|
||||
if !ok {
|
||||
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Always read body to extract model and session ID
|
||||
var bodyBytes []byte
|
||||
var requestModel string
|
||||
var messagesCount int
|
||||
var sessionID string
|
||||
|
||||
if r.Body != nil {
|
||||
var err error
|
||||
bodyBytes, err = io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "read request body: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body")
|
||||
return
|
||||
}
|
||||
|
||||
var request struct {
|
||||
Model string `json:"model"`
|
||||
Messages []anthropic.MessageParam `json:"messages"`
|
||||
}
|
||||
err = json.Unmarshal(bodyBytes, &request)
|
||||
if err == nil {
|
||||
requestModel = request.Model
|
||||
messagesCount = len(request.Messages)
|
||||
}
|
||||
|
||||
sessionID = extractCCMSessionID(bodyBytes)
|
||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
}
|
||||
|
||||
// Resolve credential provider and user config
|
||||
var provider credentialProvider
|
||||
var userConfig *option.CCMUser
|
||||
if len(s.options.Users) > 0 {
|
||||
userConfig = s.userConfigMap[username]
|
||||
var err error
|
||||
provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "resolve credential: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
} else {
|
||||
provider = noUserCredentialProvider(s.providers, s.legacyProvider, s.options)
|
||||
}
|
||||
if provider == nil {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available")
|
||||
return
|
||||
}
|
||||
|
||||
provider.pollIfStale(s.ctx)
|
||||
|
||||
anthropicBetaHeader := r.Header.Get("anthropic-beta")
|
||||
if isFastModeRequest(anthropicBetaHeader) {
|
||||
if _, isSingle := provider.(*singleCredentialProvider); !isSingle {
|
||||
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
|
||||
"fast mode requests will consume Extra usage, please use a default credential directly")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
selection := credentialSelectionForUser(userConfig)
|
||||
|
||||
selectedCredential, isNew, err := provider.selectCredential(sessionID, selection)
|
||||
if err != nil {
|
||||
writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, err.Error()))
|
||||
return
|
||||
}
|
||||
if isNew {
|
||||
logParts := []any{"assigned credential ", selectedCredential.tagName()}
|
||||
if sessionID != "" {
|
||||
logParts = append(logParts, " for session ", sessionID)
|
||||
}
|
||||
if username != "" {
|
||||
logParts = append(logParts, " by user ", username)
|
||||
}
|
||||
if requestModel != "" {
|
||||
modelDisplay := requestModel
|
||||
if isExtendedContextRequest(anthropicBetaHeader) {
|
||||
modelDisplay += "[1m]"
|
||||
}
|
||||
logParts = append(logParts, ", model=", modelDisplay)
|
||||
}
|
||||
s.logger.DebugContext(ctx, logParts...)
|
||||
}
|
||||
|
||||
if isFastModeRequest(anthropicBetaHeader) && selectedCredential.isExternal() {
|
||||
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
|
||||
"fast mode requests cannot be proxied through external credentials")
|
||||
return
|
||||
}
|
||||
|
||||
requestContext := selectedCredential.wrapRequestContext(ctx)
|
||||
{
|
||||
currentRequestContext := requestContext
|
||||
requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() {
|
||||
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
|
||||
}))
|
||||
}
|
||||
defer func() {
|
||||
requestContext.cancelRequest()
|
||||
}()
|
||||
proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "create proxy request: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
|
||||
return
|
||||
}
|
||||
|
||||
response, err := selectedCredential.httpClient().Do(proxyRequest)
|
||||
if err != nil {
|
||||
if r.Context().Err() != nil {
|
||||
return
|
||||
}
|
||||
if requestContext.Err() != nil {
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable while processing the request")
|
||||
return
|
||||
}
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
requestContext.releaseCredentialInterrupt()
|
||||
|
||||
// Transparent 429 retry
|
||||
for response.StatusCode == http.StatusTooManyRequests {
|
||||
resetAt := parseRateLimitResetFromHeaders(response.Header)
|
||||
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection)
|
||||
selectedCredential.updateStateFromHeaders(response.Header)
|
||||
if bodyBytes == nil || nextCredential == nil {
|
||||
response.Body.Close()
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited")
|
||||
return
|
||||
}
|
||||
response.Body.Close()
|
||||
s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
|
||||
requestContext.cancelRequest()
|
||||
requestContext = nextCredential.wrapRequestContext(ctx)
|
||||
{
|
||||
currentRequestContext := requestContext
|
||||
requestContext.addInterruptLink(provider.linkProviderInterrupt(nextCredential, selection, func() {
|
||||
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
|
||||
}))
|
||||
}
|
||||
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||
if buildErr != nil {
|
||||
s.logger.ErrorContext(ctx, "retry request: ", buildErr)
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
|
||||
return
|
||||
}
|
||||
retryResponse, retryErr := nextCredential.httpClient().Do(retryRequest)
|
||||
if retryErr != nil {
|
||||
if r.Context().Err() != nil {
|
||||
return
|
||||
}
|
||||
if requestContext.Err() != nil {
|
||||
writeCredentialUnavailableError(w, r, provider, nextCredential, selection, "credential became unavailable while retrying the request")
|
||||
return
|
||||
}
|
||||
s.logger.ErrorContext(ctx, "retry request: ", retryErr)
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error())
|
||||
return
|
||||
}
|
||||
requestContext.releaseCredentialInterrupt()
|
||||
response = retryResponse
|
||||
selectedCredential = nextCredential
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
selectedCredential.updateStateFromHeaders(response.Header)
|
||||
|
||||
if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
|
||||
go selectedCredential.pollUsage(s.ctx)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error",
|
||||
"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
|
||||
return
|
||||
}
|
||||
|
||||
// Rewrite response headers for external users
|
||||
if userConfig != nil && userConfig.ExternalCredential != "" {
|
||||
s.rewriteResponseHeadersForExternalUser(response.Header, userConfig)
|
||||
}
|
||||
|
||||
for key, values := range response.Header {
|
||||
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) {
|
||||
w.Header()[key] = values
|
||||
}
|
||||
}
|
||||
w.WriteHeader(response.StatusCode)
|
||||
|
||||
usageTracker := selectedCredential.usageTrackerOrNil()
|
||||
if usageTracker != nil && response.StatusCode == http.StatusOK {
|
||||
s.handleResponseWithTracking(ctx, w, response, usageTracker, requestModel, anthropicBetaHeader, messagesCount, username)
|
||||
} else {
|
||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||
if err == nil && mediaType != "text/event-stream" {
|
||||
_, _ = io.Copy(w, response.Body)
|
||||
return
|
||||
}
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
s.logger.ErrorContext(ctx, "streaming not supported")
|
||||
return
|
||||
}
|
||||
buffer := make([]byte, buf.BufferSize)
|
||||
for {
|
||||
n, err := response.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
_, writeError := w.Write(buffer[:n])
|
||||
if writeError != nil {
|
||||
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, requestModel string, anthropicBetaHeader string, messagesCount int, username string) {
|
||||
weeklyCycleHint := extractWeeklyCycleHint(response.Header)
|
||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||
isStreaming := err == nil && mediaType == "text/event-stream"
|
||||
|
||||
if !isStreaming {
|
||||
bodyBytes, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "read response body: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
var message anthropic.Message
|
||||
var usage anthropic.Usage
|
||||
var responseModel string
|
||||
err = json.Unmarshal(bodyBytes, &message)
|
||||
if err == nil {
|
||||
responseModel = string(message.Model)
|
||||
usage = message.Usage
|
||||
}
|
||||
if responseModel == "" {
|
||||
responseModel = requestModel
|
||||
}
|
||||
|
||||
if usage.InputTokens > 0 || usage.OutputTokens > 0 {
|
||||
if responseModel != "" {
|
||||
totalInputTokens := usage.InputTokens + usage.CacheCreationInputTokens + usage.CacheReadInputTokens
|
||||
contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens)
|
||||
usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
contextWindow,
|
||||
messagesCount,
|
||||
usage.InputTokens,
|
||||
usage.OutputTokens,
|
||||
usage.CacheReadInputTokens,
|
||||
usage.CacheCreationInputTokens,
|
||||
usage.CacheCreation.Ephemeral5mInputTokens,
|
||||
usage.CacheCreation.Ephemeral1hInputTokens,
|
||||
username,
|
||||
time.Now(),
|
||||
weeklyCycleHint,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
_, _ = writer.Write(bodyBytes)
|
||||
return
|
||||
}
|
||||
|
||||
flusher, ok := writer.(http.Flusher)
|
||||
if !ok {
|
||||
s.logger.ErrorContext(ctx, "streaming not supported")
|
||||
return
|
||||
}
|
||||
|
||||
var accumulatedUsage anthropic.Usage
|
||||
var responseModel string
|
||||
buffer := make([]byte, buf.BufferSize)
|
||||
var leftover []byte
|
||||
|
||||
for {
|
||||
n, err := response.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
data := append(leftover, buffer[:n]...)
|
||||
lines := bytes.Split(data, []byte("\n"))
|
||||
|
||||
if err == nil {
|
||||
leftover = lines[len(lines)-1]
|
||||
lines = lines[:len(lines)-1]
|
||||
} else {
|
||||
leftover = nil
|
||||
}
|
||||
|
||||
for _, line := range lines {
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if bytes.HasPrefix(line, []byte("data: ")) {
|
||||
eventData := bytes.TrimPrefix(line, []byte("data: "))
|
||||
if bytes.Equal(eventData, []byte("[DONE]")) {
|
||||
continue
|
||||
}
|
||||
|
||||
var event anthropic.MessageStreamEventUnion
|
||||
err := json.Unmarshal(eventData, &event)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
switch event.Type {
|
||||
case "message_start":
|
||||
messageStart := event.AsMessageStart()
|
||||
if messageStart.Message.Model != "" {
|
||||
responseModel = string(messageStart.Message.Model)
|
||||
}
|
||||
if messageStart.Message.Usage.InputTokens > 0 {
|
||||
accumulatedUsage.InputTokens = messageStart.Message.Usage.InputTokens
|
||||
accumulatedUsage.CacheReadInputTokens = messageStart.Message.Usage.CacheReadInputTokens
|
||||
accumulatedUsage.CacheCreationInputTokens = messageStart.Message.Usage.CacheCreationInputTokens
|
||||
accumulatedUsage.CacheCreation.Ephemeral5mInputTokens = messageStart.Message.Usage.CacheCreation.Ephemeral5mInputTokens
|
||||
accumulatedUsage.CacheCreation.Ephemeral1hInputTokens = messageStart.Message.Usage.CacheCreation.Ephemeral1hInputTokens
|
||||
}
|
||||
case "message_delta":
|
||||
messageDelta := event.AsMessageDelta()
|
||||
if messageDelta.Usage.OutputTokens > 0 {
|
||||
accumulatedUsage.OutputTokens = messageDelta.Usage.OutputTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_, writeError := writer.Write(buffer[:n])
|
||||
if writeError != nil {
|
||||
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if responseModel == "" {
|
||||
responseModel = requestModel
|
||||
}
|
||||
|
||||
if accumulatedUsage.InputTokens > 0 || accumulatedUsage.OutputTokens > 0 {
|
||||
if responseModel != "" {
|
||||
totalInputTokens := accumulatedUsage.InputTokens + accumulatedUsage.CacheCreationInputTokens + accumulatedUsage.CacheReadInputTokens
|
||||
contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens)
|
||||
usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
contextWindow,
|
||||
messagesCount,
|
||||
accumulatedUsage.InputTokens,
|
||||
accumulatedUsage.OutputTokens,
|
||||
accumulatedUsage.CacheReadInputTokens,
|
||||
accumulatedUsage.CacheCreationInputTokens,
|
||||
accumulatedUsage.CacheCreation.Ephemeral5mInputTokens,
|
||||
accumulatedUsage.CacheCreation.Ephemeral1hInputTokens,
|
||||
username,
|
||||
time.Now(),
|
||||
weeklyCycleHint,
|
||||
)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
109
service/ccm/service_status.go
Normal file
109
service/ccm/service_status.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/sagernet/sing-box/option"
|
||||
)
|
||||
|
||||
func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
if len(s.options.Users) == 0 {
|
||||
writeJSONError(w, r, http.StatusForbidden, "authentication_error", "status endpoint requires user authentication")
|
||||
return
|
||||
}
|
||||
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
||||
return
|
||||
}
|
||||
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if clientToken == authHeader {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
||||
return
|
||||
}
|
||||
username, ok := s.userManager.Authenticate(clientToken)
|
||||
if !ok {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
|
||||
return
|
||||
}
|
||||
|
||||
userConfig := s.userConfigMap[username]
|
||||
if userConfig == nil {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "user config not found")
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
|
||||
if err != nil {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
provider.pollIfStale(r.Context())
|
||||
avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]float64{
|
||||
"five_hour_utilization": avgFiveHour,
|
||||
"weekly_utilization": avgWeekly,
|
||||
"plan_weight": totalWeight,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) (float64, float64, float64) {
|
||||
var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64
|
||||
for _, cred := range provider.allCredentials() {
|
||||
if !cred.isAvailable() {
|
||||
continue
|
||||
}
|
||||
if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential {
|
||||
continue
|
||||
}
|
||||
if !userConfig.AllowExternalUsage && cred.isExternal() {
|
||||
continue
|
||||
}
|
||||
weight := cred.planWeight()
|
||||
remaining5h := cred.fiveHourCap() - cred.fiveHourUtilization()
|
||||
if remaining5h < 0 {
|
||||
remaining5h = 0
|
||||
}
|
||||
remainingWeekly := cred.weeklyCap() - cred.weeklyUtilization()
|
||||
if remainingWeekly < 0 {
|
||||
remainingWeekly = 0
|
||||
}
|
||||
totalWeightedRemaining5h += remaining5h * weight
|
||||
totalWeightedRemainingWeekly += remainingWeekly * weight
|
||||
totalWeight += weight
|
||||
}
|
||||
if totalWeight == 0 {
|
||||
return 100, 100, 0
|
||||
}
|
||||
return 100 - totalWeightedRemaining5h/totalWeight,
|
||||
100 - totalWeightedRemainingWeekly/totalWeight,
|
||||
totalWeight
|
||||
}
|
||||
|
||||
func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.CCMUser) {
|
||||
provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, userConfig.Name)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
|
||||
|
||||
headers.Set("anthropic-ratelimit-unified-5h-utilization", strconv.FormatFloat(avgFiveHour/100, 'f', 6, 64))
|
||||
headers.Set("anthropic-ratelimit-unified-7d-utilization", strconv.FormatFloat(avgWeekly/100, 'f', 6, 64))
|
||||
if totalWeight > 0 {
|
||||
headers.Set("X-CCM-Plan-Weight", strconv.FormatFloat(totalWeight, 'f', -1, 64))
|
||||
}
|
||||
}
|
||||
@@ -7,13 +7,13 @@ import (
|
||||
)
|
||||
|
||||
type UserManager struct {
|
||||
accessMutex sync.RWMutex
|
||||
access sync.RWMutex
|
||||
tokenMap map[string]string
|
||||
}
|
||||
|
||||
func (m *UserManager) UpdateUsers(users []option.CCMUser) {
|
||||
m.accessMutex.Lock()
|
||||
defer m.accessMutex.Unlock()
|
||||
m.access.Lock()
|
||||
defer m.access.Unlock()
|
||||
tokenMap := make(map[string]string, len(users))
|
||||
for _, user := range users {
|
||||
tokenMap[user.Token] = user.Name
|
||||
@@ -22,8 +22,8 @@ func (m *UserManager) UpdateUsers(users []option.CCMUser) {
|
||||
}
|
||||
|
||||
func (m *UserManager) Authenticate(token string) (string, bool) {
|
||||
m.accessMutex.RLock()
|
||||
m.access.RLock()
|
||||
username, found := m.tokenMap[token]
|
||||
m.accessMutex.RUnlock()
|
||||
m.access.RUnlock()
|
||||
return username, found
|
||||
}
|
||||
|
||||
@@ -1,225 +1,194 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
// OAuth and upstream endpoints, followed by the refresh, polling and retry
// timing values of the package.
const (
	oauth2ClientID    = "app_EMoamEEZ73f0CkXaXp7hrann"
	oauth2TokenURL    = "https://auth.openai.com/oauth/token"
	openaiAPIBaseURL  = "https://api.openai.com"
	chatGPTBackendURL = "https://chatgpt.com/backend-api/codex"

	tokenRefreshIntervalDays = 8

	defaultPollInterval     = time.Hour
	failedPollRetryInterval = 1 * time.Minute
	httpRetryMaxBackoff     = 5 * time.Minute
)
|
||||
|
||||
func getRealUser() (*user.User, error) {
|
||||
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
|
||||
sudoUserInfo, err := user.Lookup(sudoUser)
|
||||
if err == nil {
|
||||
return sudoUserInfo, nil
|
||||
const (
|
||||
httpRetryMaxAttempts = 3
|
||||
httpRetryInitialDelay = 200 * time.Millisecond
|
||||
)
|
||||
|
||||
const sessionExpiry = 24 * time.Hour
|
||||
|
||||
func doHTTPWithRetry(ctx context.Context, client *http.Client, buildRequest func() (*http.Request, error)) (*http.Response, error) {
|
||||
var lastError error
|
||||
for attempt := range httpRetryMaxAttempts {
|
||||
if attempt > 0 {
|
||||
delay := httpRetryInitialDelay * time.Duration(1<<(attempt-1))
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, lastError
|
||||
case <-time.After(delay):
|
||||
}
|
||||
}
|
||||
}
|
||||
return user.Current()
|
||||
}
|
||||
|
||||
func getDefaultCredentialsPath() (string, error) {
|
||||
if codexHome := os.Getenv("CODEX_HOME"); codexHome != "" {
|
||||
return filepath.Join(codexHome, "auth.json"), nil
|
||||
}
|
||||
userInfo, err := getRealUser()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(userInfo.HomeDir, ".codex", "auth.json"), nil
|
||||
}
|
||||
|
||||
func readCredentialsFromFile(path string) (*oauthCredentials, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var credentials oauthCredentials
|
||||
err = json.Unmarshal(data, &credentials)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &credentials, nil
|
||||
}
|
||||
|
||||
// checkCredentialFileWritable reports whether the existing file at path can
// be opened for writing. The file is opened write-only (it is never created
// or truncated) and closed again immediately; the open or close error, if
// any, is returned.
func checkCredentialFileWritable(path string) error {
	handle, openErr := os.OpenFile(path, os.O_WRONLY, 0)
	if openErr == nil {
		return handle.Close()
	}
	return openErr
}
|
||||
|
||||
func writeCredentialsToFile(credentials *oauthCredentials, path string) error {
|
||||
data, err := json.MarshalIndent(credentials, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(path, data, 0o600)
|
||||
}
|
||||
|
||||
type oauthCredentials struct {
|
||||
APIKey string `json:"OPENAI_API_KEY,omitempty"`
|
||||
Tokens *tokenData `json:"tokens,omitempty"`
|
||||
LastRefresh *time.Time `json:"last_refresh,omitempty"`
|
||||
}
|
||||
|
||||
// tokenData is the OAuth token set persisted under the "tokens" key of the
// credential file.
type tokenData struct {
	// IDToken and AccountID are optional and omitted from JSON when empty.
	IDToken string `json:"id_token,omitempty"`
	// AccessToken and RefreshToken are always emitted, even when empty.
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`
	AccountID    string `json:"account_id,omitempty"`
}
|
||||
|
||||
func (c *oauthCredentials) isAPIKeyMode() bool {
|
||||
return c.APIKey != ""
|
||||
}
|
||||
|
||||
func (c *oauthCredentials) getAccessToken() string {
|
||||
if c.APIKey != "" {
|
||||
return c.APIKey
|
||||
}
|
||||
if c.Tokens != nil {
|
||||
return c.Tokens.AccessToken
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *oauthCredentials) getAccountID() string {
|
||||
if c.Tokens != nil {
|
||||
return c.Tokens.AccountID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *oauthCredentials) needsRefresh() bool {
|
||||
if c.APIKey != "" {
|
||||
return false
|
||||
}
|
||||
if c.Tokens == nil || c.Tokens.RefreshToken == "" {
|
||||
return false
|
||||
}
|
||||
if c.LastRefresh == nil {
|
||||
return true
|
||||
}
|
||||
return time.Since(*c.LastRefresh) >= time.Duration(tokenRefreshIntervalDays)*24*time.Hour
|
||||
}
|
||||
|
||||
func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) {
|
||||
if credentials.Tokens == nil || credentials.Tokens.RefreshToken == "" {
|
||||
return nil, E.New("refresh token is empty")
|
||||
}
|
||||
|
||||
requestBody, err := json.Marshal(map[string]string{
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": credentials.Tokens.RefreshToken,
|
||||
"client_id": oauth2ClientID,
|
||||
"scope": "openid profile email",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "marshal request")
|
||||
}
|
||||
|
||||
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
|
||||
request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
|
||||
request, err := buildRequest()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request.Header.Set("Accept", "application/json")
|
||||
return request, nil
|
||||
response, err := client.Do(request)
|
||||
if err == nil {
|
||||
return response, nil
|
||||
}
|
||||
lastError = err
|
||||
if ctx.Err() != nil {
|
||||
return nil, lastError
|
||||
}
|
||||
}
|
||||
return nil, lastError
|
||||
}
|
||||
|
||||
// credentialState is the mutable usage and availability snapshot kept for a
// credential.
type credentialState struct {
	// Utilization percentages and reset times of the two usage windows.
	fiveHourUtilization float64
	fiveHourReset       time.Time
	weeklyUtilization   float64
	weeklyReset         time.Time

	// hardRateLimited marks the credential as rate limited until
	// rateLimitResetAt.
	hardRateLimited  bool
	rateLimitResetAt time.Time

	accountType      string
	remotePlanWeight float64

	// lastUpdated is the time usage data was last received.
	lastUpdated             time.Time
	consecutivePollFailures int

	// unavailable and the two fields below track credential loading.
	unavailable               bool
	lastCredentialLoadAttempt time.Time
	lastCredentialLoadError   string
}
|
||||
|
||||
// credentialRequestContext is a request context together with the hooks
// needed to detach it from, or cancel it on behalf of, a credential.
type credentialRequestContext struct {
	context.Context

	// releaseOnce and cancelOnce make release and cancellation run at most
	// once each.
	releaseOnce sync.Once
	cancelOnce  sync.Once

	// releaseFuncs are the stop functions of the registered interrupt links.
	releaseFuncs []func() bool
	// cancelFunc cancels the embedded context.
	cancelFunc context.CancelFunc
}
|
||||
|
||||
func (c *credentialRequestContext) addInterruptLink(stop func() bool) {
|
||||
c.releaseFuncs = append(c.releaseFuncs, stop)
|
||||
}
|
||||
|
||||
func (c *credentialRequestContext) releaseCredentialInterrupt() {
|
||||
c.releaseOnce.Do(func() {
|
||||
for _, f := range c.releaseFuncs {
|
||||
f()
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode == http.StatusTooManyRequests {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return nil, E.New("refresh rate limited: ", response.Status, " ", string(body))
|
||||
}
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return nil, E.New("refresh failed: ", response.Status, " ", string(body))
|
||||
}
|
||||
|
||||
var tokenResponse struct {
|
||||
IDToken string `json:"id_token"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
err = json.NewDecoder(response.Body).Decode(&tokenResponse)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode response")
|
||||
}
|
||||
|
||||
newCredentials := *credentials
|
||||
if newCredentials.Tokens == nil {
|
||||
newCredentials.Tokens = &tokenData{}
|
||||
}
|
||||
if tokenResponse.IDToken != "" {
|
||||
newCredentials.Tokens.IDToken = tokenResponse.IDToken
|
||||
}
|
||||
if tokenResponse.AccessToken != "" {
|
||||
newCredentials.Tokens.AccessToken = tokenResponse.AccessToken
|
||||
}
|
||||
if tokenResponse.RefreshToken != "" {
|
||||
newCredentials.Tokens.RefreshToken = tokenResponse.RefreshToken
|
||||
}
|
||||
now := time.Now()
|
||||
newCredentials.LastRefresh = &now
|
||||
|
||||
return &newCredentials, nil
|
||||
}
|
||||
|
||||
func cloneCredentials(credentials *oauthCredentials) *oauthCredentials {
|
||||
if credentials == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *credentials
|
||||
if credentials.Tokens != nil {
|
||||
clonedTokens := *credentials.Tokens
|
||||
cloned.Tokens = &clonedTokens
|
||||
}
|
||||
if credentials.LastRefresh != nil {
|
||||
lastRefresh := *credentials.LastRefresh
|
||||
cloned.LastRefresh = &lastRefresh
|
||||
}
|
||||
return &cloned
|
||||
func (c *credentialRequestContext) cancelRequest() {
|
||||
c.releaseCredentialInterrupt()
|
||||
c.cancelOnce.Do(c.cancelFunc)
|
||||
}
|
||||
|
||||
func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool {
|
||||
if left == nil || right == nil {
|
||||
return left == right
|
||||
}
|
||||
if left.APIKey != right.APIKey {
|
||||
return false
|
||||
}
|
||||
if (left.Tokens == nil) != (right.Tokens == nil) {
|
||||
return false
|
||||
}
|
||||
if left.Tokens != nil && *left.Tokens != *right.Tokens {
|
||||
return false
|
||||
}
|
||||
if (left.LastRefresh == nil) != (right.LastRefresh == nil) {
|
||||
return false
|
||||
}
|
||||
if left.LastRefresh != nil && !left.LastRefresh.Equal(*right.LastRefresh) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
type credential interface {
|
||||
tagName() string
|
||||
isAvailable() bool
|
||||
isUsable() bool
|
||||
isExternal() bool
|
||||
fiveHourUtilization() float64
|
||||
weeklyUtilization() float64
|
||||
fiveHourCap() float64
|
||||
weeklyCap() float64
|
||||
planWeight() float64
|
||||
weeklyResetTime() time.Time
|
||||
markRateLimited(resetAt time.Time)
|
||||
earliestReset() time.Time
|
||||
unavailableError() error
|
||||
|
||||
getAccessToken() (string, error)
|
||||
buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error)
|
||||
updateStateFromHeaders(header http.Header)
|
||||
|
||||
wrapRequestContext(ctx context.Context) *credentialRequestContext
|
||||
interruptConnections()
|
||||
|
||||
setOnBecameUnusable(fn func())
|
||||
start() error
|
||||
pollUsage(ctx context.Context)
|
||||
lastUpdatedTime() time.Time
|
||||
pollBackoff(base time.Duration) time.Duration
|
||||
usageTrackerOrNil() *AggregatedUsage
|
||||
httpClient() *http.Client
|
||||
close()
|
||||
|
||||
// OCM-specific
|
||||
ocmDialer() N.Dialer
|
||||
ocmIsAPIKeyMode() bool
|
||||
ocmGetAccountID() string
|
||||
ocmGetBaseURL() string
|
||||
}
|
||||
|
||||
// credentialSelectionScope names the set of credentials a selection may
// draw from.
type credentialSelectionScope string

// Known selection scopes.
const (
	credentialSelectionScopeAll         = credentialSelectionScope("all")
	credentialSelectionScopeNonExternal = credentialSelectionScope("non_external")
)
|
||||
|
||||
type credentialSelection struct {
|
||||
scope credentialSelectionScope
|
||||
filter func(credential) bool
|
||||
}
|
||||
|
||||
func (s credentialSelection) allows(cred credential) bool {
|
||||
return s.filter == nil || s.filter(cred)
|
||||
}
|
||||
|
||||
func (s credentialSelection) scopeOrDefault() credentialSelectionScope {
|
||||
if s.scope == "" {
|
||||
return credentialSelectionScopeAll
|
||||
}
|
||||
return s.scope
|
||||
}
|
||||
|
||||
// normalizeRateLimitIdentifier canonicalizes a rate limit identifier: it is
// lower-cased, stripped of surrounding whitespace, and has every underscore
// replaced by a hyphen. A blank identifier yields the empty string.
func normalizeRateLimitIdentifier(limitIdentifier string) string {
	normalized := strings.TrimSpace(strings.ToLower(limitIdentifier))
	if len(normalized) > 0 {
		return strings.ReplaceAll(normalized, "_", "-")
	}
	return ""
}
|
||||
|
||||
// parseInt64Header reads headerName from headers and parses it as a base-10
// int64 after trimming surrounding whitespace. The boolean result is false
// when the header is missing, blank, or not a valid integer.
func parseInt64Header(headers http.Header, headerName string) (int64, bool) {
	raw := strings.TrimSpace(headers.Get(headerName))
	if raw == "" {
		return 0, false
	}
	if number, err := strconv.ParseInt(raw, 10, 64); err == nil {
		return number, true
	}
	return 0, false
}
|
||||
|
||||
func parseOCMRateLimitResetFromHeaders(headers http.Header) time.Time {
|
||||
activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
|
||||
if activeLimitIdentifier != "" {
|
||||
resetHeader := "x-" + activeLimitIdentifier + "-primary-reset-at"
|
||||
if resetStr := headers.Get(resetHeader); resetStr != "" {
|
||||
value, err := strconv.ParseInt(resetStr, 10, 64)
|
||||
if err == nil {
|
||||
return time.Unix(value, 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
if retryAfter := headers.Get("Retry-After"); retryAfter != "" {
|
||||
seconds, err := strconv.ParseInt(retryAfter, 10, 64)
|
||||
if err == nil {
|
||||
return time.Now().Add(time.Duration(seconds) * time.Second)
|
||||
}
|
||||
}
|
||||
return time.Now().Add(5 * time.Minute)
|
||||
}
|
||||
|
||||
223
service/ocm/credential_builder.go
Normal file
223
service/ocm/credential_builder.go
Normal file
@@ -0,0 +1,223 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
func buildOCMCredentialProviders(
|
||||
ctx context.Context,
|
||||
options option.OCMServiceOptions,
|
||||
logger log.ContextLogger,
|
||||
) (map[string]credentialProvider, []credential, error) {
|
||||
allCredentialMap := make(map[string]credential)
|
||||
var allCreds []credential
|
||||
providers := make(map[string]credentialProvider)
|
||||
|
||||
// Pass 1: create default and external credentials
|
||||
for _, credOpt := range options.Credentials {
|
||||
switch credOpt.Type {
|
||||
case "default":
|
||||
cred, err := newDefaultCredential(ctx, credOpt.Tag, credOpt.DefaultOptions, logger)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
allCredentialMap[credOpt.Tag] = cred
|
||||
allCreds = append(allCreds, cred)
|
||||
providers[credOpt.Tag] = &singleCredentialProvider{cred: cred}
|
||||
case "external":
|
||||
cred, err := newExternalCredential(ctx, credOpt.Tag, credOpt.ExternalOptions, logger)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
allCredentialMap[credOpt.Tag] = cred
|
||||
allCreds = append(allCreds, cred)
|
||||
providers[credOpt.Tag] = &singleCredentialProvider{cred: cred}
|
||||
}
|
||||
}
|
||||
|
||||
// Pass 2: create balancer providers
|
||||
for _, credOpt := range options.Credentials {
|
||||
if credOpt.Type == "balancer" {
|
||||
subCredentials, err := resolveCredentialTags(credOpt.BalancerOptions.Credentials, allCredentialMap, credOpt.Tag)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
providers[credOpt.Tag] = newBalancerProvider(subCredentials, credOpt.BalancerOptions.Strategy, time.Duration(credOpt.BalancerOptions.PollInterval), credOpt.BalancerOptions.RebalanceThreshold, logger)
|
||||
}
|
||||
}
|
||||
|
||||
return providers, allCreds, nil
|
||||
}
|
||||
|
||||
func resolveCredentialTags(tags []string, allCredentials map[string]credential, parentTag string) ([]credential, error) {
|
||||
credentials := make([]credential, 0, len(tags))
|
||||
for _, tag := range tags {
|
||||
cred, exists := allCredentials[tag]
|
||||
if !exists {
|
||||
return nil, E.New("credential ", parentTag, " references unknown credential: ", tag)
|
||||
}
|
||||
credentials = append(credentials, cred)
|
||||
}
|
||||
if len(credentials) == 0 {
|
||||
return nil, E.New("credential ", parentTag, " has no sub-credentials")
|
||||
}
|
||||
return credentials, nil
|
||||
}
|
||||
|
||||
func validateOCMOptions(options option.OCMServiceOptions) error {
|
||||
hasCredentials := len(options.Credentials) > 0
|
||||
hasLegacyPath := options.CredentialPath != ""
|
||||
hasLegacyUsages := options.UsagesPath != ""
|
||||
hasLegacyDetour := options.Detour != ""
|
||||
|
||||
if hasCredentials && hasLegacyPath {
|
||||
return E.New("credential_path and credentials are mutually exclusive")
|
||||
}
|
||||
if hasCredentials && hasLegacyUsages {
|
||||
return E.New("usages_path and credentials are mutually exclusive; use usages_path on individual credentials")
|
||||
}
|
||||
if hasCredentials && hasLegacyDetour {
|
||||
return E.New("detour and credentials are mutually exclusive; use detour on individual credentials")
|
||||
}
|
||||
|
||||
if hasCredentials {
|
||||
tags := make(map[string]bool)
|
||||
credentialTypes := make(map[string]string)
|
||||
for _, cred := range options.Credentials {
|
||||
if tags[cred.Tag] {
|
||||
return E.New("duplicate credential tag: ", cred.Tag)
|
||||
}
|
||||
tags[cred.Tag] = true
|
||||
credentialTypes[cred.Tag] = cred.Type
|
||||
if cred.Type == "default" || cred.Type == "" {
|
||||
if cred.DefaultOptions.Reserve5h > 99 {
|
||||
return E.New("credential ", cred.Tag, ": reserve_5h must be at most 99")
|
||||
}
|
||||
if cred.DefaultOptions.ReserveWeekly > 99 {
|
||||
return E.New("credential ", cred.Tag, ": reserve_weekly must be at most 99")
|
||||
}
|
||||
if cred.DefaultOptions.Limit5h > 100 {
|
||||
return E.New("credential ", cred.Tag, ": limit_5h must be at most 100")
|
||||
}
|
||||
if cred.DefaultOptions.LimitWeekly > 100 {
|
||||
return E.New("credential ", cred.Tag, ": limit_weekly must be at most 100")
|
||||
}
|
||||
if cred.DefaultOptions.Reserve5h > 0 && cred.DefaultOptions.Limit5h > 0 {
|
||||
return E.New("credential ", cred.Tag, ": reserve_5h and limit_5h are mutually exclusive")
|
||||
}
|
||||
if cred.DefaultOptions.ReserveWeekly > 0 && cred.DefaultOptions.LimitWeekly > 0 {
|
||||
return E.New("credential ", cred.Tag, ": reserve_weekly and limit_weekly are mutually exclusive")
|
||||
}
|
||||
}
|
||||
if cred.Type == "external" {
|
||||
if cred.ExternalOptions.Token == "" {
|
||||
return E.New("credential ", cred.Tag, ": external credential requires token")
|
||||
}
|
||||
if cred.ExternalOptions.Reverse && cred.ExternalOptions.URL == "" {
|
||||
return E.New("credential ", cred.Tag, ": reverse external credential requires url")
|
||||
}
|
||||
}
|
||||
if cred.Type == "balancer" {
|
||||
switch cred.BalancerOptions.Strategy {
|
||||
case "", C.BalancerStrategyLeastUsed, C.BalancerStrategyRoundRobin, C.BalancerStrategyRandom, C.BalancerStrategyFallback:
|
||||
default:
|
||||
return E.New("credential ", cred.Tag, ": unknown balancer strategy: ", cred.BalancerOptions.Strategy)
|
||||
}
|
||||
if cred.BalancerOptions.RebalanceThreshold < 0 {
|
||||
return E.New("credential ", cred.Tag, ": rebalance_threshold must not be negative")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, user := range options.Users {
|
||||
if user.Credential == "" {
|
||||
return E.New("user ", user.Name, " must specify credential in multi-credential mode")
|
||||
}
|
||||
if !tags[user.Credential] {
|
||||
return E.New("user ", user.Name, " references unknown credential: ", user.Credential)
|
||||
}
|
||||
if user.ExternalCredential != "" {
|
||||
if !tags[user.ExternalCredential] {
|
||||
return E.New("user ", user.Name, " references unknown external_credential: ", user.ExternalCredential)
|
||||
}
|
||||
if credentialTypes[user.ExternalCredential] != "external" {
|
||||
return E.New("user ", user.Name, ": external_credential must reference an external type credential")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateOCMCompositeCredentialModes(
|
||||
options option.OCMServiceOptions,
|
||||
providers map[string]credentialProvider,
|
||||
) error {
|
||||
for _, credOpt := range options.Credentials {
|
||||
if credOpt.Type != "balancer" {
|
||||
continue
|
||||
}
|
||||
|
||||
provider, exists := providers[credOpt.Tag]
|
||||
if !exists {
|
||||
return E.New("unknown credential: ", credOpt.Tag)
|
||||
}
|
||||
|
||||
for _, subCred := range provider.allCredentials() {
|
||||
if !subCred.isAvailable() {
|
||||
continue
|
||||
}
|
||||
if subCred.ocmIsAPIKeyMode() {
|
||||
return E.New(
|
||||
"credential ", credOpt.Tag,
|
||||
" references API key default credential ", subCred.tagName(),
|
||||
"; balancer and fallback only support OAuth default credentials",
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func credentialForUser(
|
||||
userConfigMap map[string]*option.OCMUser,
|
||||
providers map[string]credentialProvider,
|
||||
legacyProvider credentialProvider,
|
||||
username string,
|
||||
) (credentialProvider, error) {
|
||||
if legacyProvider != nil {
|
||||
return legacyProvider, nil
|
||||
}
|
||||
userConfig, exists := userConfigMap[username]
|
||||
if !exists {
|
||||
return nil, E.New("no credential mapping for user: ", username)
|
||||
}
|
||||
provider, exists := providers[userConfig.Credential]
|
||||
if !exists {
|
||||
return nil, E.New("unknown credential: ", userConfig.Credential)
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func noUserCredentialProvider(
|
||||
providers map[string]credentialProvider,
|
||||
legacyProvider credentialProvider,
|
||||
options option.OCMServiceOptions,
|
||||
) credentialProvider {
|
||||
if legacyProvider != nil {
|
||||
return legacyProvider
|
||||
}
|
||||
if len(options.Credentials) > 0 {
|
||||
tag := options.Credentials[0].Tag
|
||||
return providers[tag]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
749
service/ocm/credential_default.go
Normal file
749
service/ocm/credential_default.go
Normal file
@@ -0,0 +1,749 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
stdTLS "crypto/tls"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/fswatch"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
)
|
||||
|
||||
type defaultCredential struct {
|
||||
tag string
|
||||
serviceContext context.Context
|
||||
credentialPath string
|
||||
credentialFilePath string
|
||||
credentials *oauthCredentials
|
||||
access sync.RWMutex
|
||||
state credentialState
|
||||
stateAccess sync.RWMutex
|
||||
pollAccess sync.Mutex
|
||||
reloadAccess sync.Mutex
|
||||
watcherAccess sync.Mutex
|
||||
cap5h float64
|
||||
capWeekly float64
|
||||
usageTracker *AggregatedUsage
|
||||
dialer N.Dialer
|
||||
forwardHTTPClient *http.Client
|
||||
logger log.ContextLogger
|
||||
watcher *fswatch.Watcher
|
||||
watcherRetryAt time.Time
|
||||
|
||||
// Connection interruption
|
||||
onBecameUnusable func()
|
||||
interrupted bool
|
||||
requestContext context.Context
|
||||
cancelRequests context.CancelFunc
|
||||
requestAccess sync.Mutex
|
||||
}
|
||||
|
||||
func newDefaultCredential(ctx context.Context, tag string, options option.OCMDefaultCredentialOptions, logger log.ContextLogger) (*defaultCredential, error) {
|
||||
credentialDialer, err := dialer.NewWithOptions(dialer.Options{
|
||||
Context: ctx,
|
||||
Options: option.DialerOptions{
|
||||
Detour: options.Detour,
|
||||
},
|
||||
RemoteIsDomain: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "create dialer for credential ", tag)
|
||||
}
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
ForceAttemptHTTP2: true,
|
||||
TLSClientConfig: &stdTLS.Config{
|
||||
RootCAs: adapter.RootPoolFromContext(ctx),
|
||||
Time: ntp.TimeFuncFromContext(ctx),
|
||||
},
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
},
|
||||
},
|
||||
}
|
||||
reserve5h := options.Reserve5h
|
||||
if reserve5h == 0 {
|
||||
reserve5h = 1
|
||||
}
|
||||
reserveWeekly := options.ReserveWeekly
|
||||
if reserveWeekly == 0 {
|
||||
reserveWeekly = 1
|
||||
}
|
||||
var cap5h float64
|
||||
if options.Limit5h > 0 {
|
||||
cap5h = float64(options.Limit5h)
|
||||
} else {
|
||||
cap5h = float64(100 - reserve5h)
|
||||
}
|
||||
var capWeekly float64
|
||||
if options.LimitWeekly > 0 {
|
||||
capWeekly = float64(options.LimitWeekly)
|
||||
} else {
|
||||
capWeekly = float64(100 - reserveWeekly)
|
||||
}
|
||||
requestContext, cancelRequests := context.WithCancel(context.Background())
|
||||
credential := &defaultCredential{
|
||||
tag: tag,
|
||||
serviceContext: ctx,
|
||||
credentialPath: options.CredentialPath,
|
||||
cap5h: cap5h,
|
||||
capWeekly: capWeekly,
|
||||
dialer: credentialDialer,
|
||||
forwardHTTPClient: httpClient,
|
||||
logger: logger,
|
||||
requestContext: requestContext,
|
||||
cancelRequests: cancelRequests,
|
||||
}
|
||||
if options.UsagesPath != "" {
|
||||
credential.usageTracker = &AggregatedUsage{
|
||||
LastUpdated: time.Now(),
|
||||
Combinations: make([]CostCombination, 0),
|
||||
filePath: options.UsagesPath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
return credential, nil
|
||||
}
|
||||
|
||||
func (c *defaultCredential) start() error {
|
||||
credentialFilePath, err := resolveCredentialFilePath(c.credentialPath)
|
||||
if err != nil {
|
||||
return E.Cause(err, "resolve credential path for ", c.tag)
|
||||
}
|
||||
c.credentialFilePath = credentialFilePath
|
||||
err = c.ensureCredentialWatcher()
|
||||
if err != nil {
|
||||
c.logger.Debug("start credential watcher for ", c.tag, ": ", err)
|
||||
}
|
||||
err = c.reloadCredentials(true)
|
||||
if err != nil {
|
||||
c.logger.Warn("initial credential load for ", c.tag, ": ", err)
|
||||
}
|
||||
if c.usageTracker != nil {
|
||||
err = c.usageTracker.Load()
|
||||
if err != nil {
|
||||
c.logger.Warn("load usage statistics for ", c.tag, ": ", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *defaultCredential) setOnBecameUnusable(fn func()) {
|
||||
c.onBecameUnusable = fn
|
||||
}
|
||||
|
||||
func (c *defaultCredential) tagName() string {
|
||||
return c.tag
|
||||
}
|
||||
|
||||
func (c *defaultCredential) isExternal() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *defaultCredential) getAccessToken() (string, error) {
|
||||
c.retryCredentialReloadIfNeeded()
|
||||
|
||||
c.access.RLock()
|
||||
if c.credentials != nil && !c.credentials.needsRefresh() {
|
||||
token := c.credentials.getAccessToken()
|
||||
c.access.RUnlock()
|
||||
return token, nil
|
||||
}
|
||||
c.access.RUnlock()
|
||||
|
||||
err := c.reloadCredentials(true)
|
||||
if err == nil {
|
||||
c.access.RLock()
|
||||
if c.credentials != nil && !c.credentials.needsRefresh() {
|
||||
token := c.credentials.getAccessToken()
|
||||
c.access.RUnlock()
|
||||
return token, nil
|
||||
}
|
||||
c.access.RUnlock()
|
||||
}
|
||||
|
||||
c.access.Lock()
|
||||
defer c.access.Unlock()
|
||||
|
||||
if c.credentials == nil {
|
||||
return "", c.unavailableError()
|
||||
}
|
||||
if !c.credentials.needsRefresh() {
|
||||
return c.credentials.getAccessToken(), nil
|
||||
}
|
||||
|
||||
err = platformCanWriteCredentials(c.credentialPath)
|
||||
if err != nil {
|
||||
return "", E.Cause(err, "credential file not writable, refusing refresh to avoid invalidation")
|
||||
}
|
||||
|
||||
baseCredentials := cloneCredentials(c.credentials)
|
||||
newCredentials, err := refreshToken(c.serviceContext, c.forwardHTTPClient, c.credentials)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
latestCredentials, latestErr := platformReadCredentials(c.credentialPath)
|
||||
if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) {
|
||||
c.credentials = latestCredentials
|
||||
c.stateAccess.Lock()
|
||||
c.state.unavailable = false
|
||||
c.state.lastCredentialLoadAttempt = time.Now()
|
||||
c.state.lastCredentialLoadError = ""
|
||||
c.checkTransitionLocked()
|
||||
c.stateAccess.Unlock()
|
||||
if !latestCredentials.needsRefresh() {
|
||||
return latestCredentials.getAccessToken(), nil
|
||||
}
|
||||
return "", E.New("credential ", c.tag, " changed while refreshing")
|
||||
}
|
||||
|
||||
c.credentials = newCredentials
|
||||
c.stateAccess.Lock()
|
||||
c.state.unavailable = false
|
||||
c.state.lastCredentialLoadAttempt = time.Now()
|
||||
c.state.lastCredentialLoadError = ""
|
||||
c.checkTransitionLocked()
|
||||
c.stateAccess.Unlock()
|
||||
|
||||
err = platformWriteCredentials(newCredentials, c.credentialPath)
|
||||
if err != nil {
|
||||
c.logger.Error("persist refreshed token for ", c.tag, ": ", err)
|
||||
}
|
||||
|
||||
return newCredentials.getAccessToken(), nil
|
||||
}
|
||||
|
||||
func (c *defaultCredential) getAccountID() string {
|
||||
c.access.RLock()
|
||||
defer c.access.RUnlock()
|
||||
if c.credentials == nil {
|
||||
return ""
|
||||
}
|
||||
return c.credentials.getAccountID()
|
||||
}
|
||||
|
||||
func (c *defaultCredential) isAPIKeyMode() bool {
|
||||
c.access.RLock()
|
||||
defer c.access.RUnlock()
|
||||
if c.credentials == nil {
|
||||
return false
|
||||
}
|
||||
return c.credentials.isAPIKeyMode()
|
||||
}
|
||||
|
||||
func (c *defaultCredential) getBaseURL() string {
|
||||
if c.isAPIKeyMode() {
|
||||
return openaiAPIBaseURL
|
||||
}
|
||||
return chatGPTBackendURL
|
||||
}
|
||||
|
||||
func (c *defaultCredential) updateStateFromHeaders(headers http.Header) {
|
||||
c.stateAccess.Lock()
|
||||
isFirstUpdate := c.state.lastUpdated.IsZero()
|
||||
oldFiveHour := c.state.fiveHourUtilization
|
||||
oldWeekly := c.state.weeklyUtilization
|
||||
hadData := false
|
||||
|
||||
activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
|
||||
if activeLimitIdentifier == "" {
|
||||
activeLimitIdentifier = "codex"
|
||||
}
|
||||
|
||||
fiveHourResetChanged := false
|
||||
fiveHourResetAt := headers.Get("x-" + activeLimitIdentifier + "-primary-reset-at")
|
||||
if fiveHourResetAt != "" {
|
||||
value, err := strconv.ParseInt(fiveHourResetAt, 10, 64)
|
||||
if err == nil {
|
||||
hadData = true
|
||||
newReset := time.Unix(value, 0)
|
||||
if newReset.After(c.state.fiveHourReset) {
|
||||
fiveHourResetChanged = true
|
||||
c.state.fiveHourReset = newReset
|
||||
}
|
||||
}
|
||||
}
|
||||
fiveHourPercent := headers.Get("x-" + activeLimitIdentifier + "-primary-used-percent")
|
||||
if fiveHourPercent != "" {
|
||||
value, err := strconv.ParseFloat(fiveHourPercent, 64)
|
||||
if err == nil {
|
||||
hadData = true
|
||||
if value >= c.state.fiveHourUtilization || fiveHourResetChanged {
|
||||
c.state.fiveHourUtilization = value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
weeklyResetChanged := false
|
||||
weeklyResetAt := headers.Get("x-" + activeLimitIdentifier + "-secondary-reset-at")
|
||||
if weeklyResetAt != "" {
|
||||
value, err := strconv.ParseInt(weeklyResetAt, 10, 64)
|
||||
if err == nil {
|
||||
hadData = true
|
||||
newReset := time.Unix(value, 0)
|
||||
if newReset.After(c.state.weeklyReset) {
|
||||
weeklyResetChanged = true
|
||||
c.state.weeklyReset = newReset
|
||||
}
|
||||
}
|
||||
}
|
||||
weeklyPercent := headers.Get("x-" + activeLimitIdentifier + "-secondary-used-percent")
|
||||
if weeklyPercent != "" {
|
||||
value, err := strconv.ParseFloat(weeklyPercent, 64)
|
||||
if err == nil {
|
||||
hadData = true
|
||||
if value >= c.state.weeklyUtilization || weeklyResetChanged {
|
||||
c.state.weeklyUtilization = value
|
||||
}
|
||||
}
|
||||
}
|
||||
if hadData {
|
||||
c.state.consecutivePollFailures = 0
|
||||
c.state.lastUpdated = time.Now()
|
||||
}
|
||||
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
|
||||
resetSuffix := ""
|
||||
if !c.state.weeklyReset.IsZero() {
|
||||
resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset))
|
||||
}
|
||||
c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
|
||||
}
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *defaultCredential) markRateLimited(resetAt time.Time) {
|
||||
c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt)))
|
||||
c.stateAccess.Lock()
|
||||
c.state.hardRateLimited = true
|
||||
c.state.rateLimitResetAt = resetAt
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *defaultCredential) isUsable() bool {
|
||||
c.retryCredentialReloadIfNeeded()
|
||||
|
||||
c.stateAccess.RLock()
|
||||
if c.state.unavailable {
|
||||
c.stateAccess.RUnlock()
|
||||
return false
|
||||
}
|
||||
if c.state.consecutivePollFailures > 0 {
|
||||
c.stateAccess.RUnlock()
|
||||
return false
|
||||
}
|
||||
if c.state.hardRateLimited {
|
||||
if time.Now().Before(c.state.rateLimitResetAt) {
|
||||
c.stateAccess.RUnlock()
|
||||
return false
|
||||
}
|
||||
c.stateAccess.RUnlock()
|
||||
c.stateAccess.Lock()
|
||||
if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) {
|
||||
c.state.hardRateLimited = false
|
||||
}
|
||||
usable := c.checkReservesLocked()
|
||||
c.stateAccess.Unlock()
|
||||
return usable
|
||||
}
|
||||
usable := c.checkReservesLocked()
|
||||
c.stateAccess.RUnlock()
|
||||
return usable
|
||||
}
|
||||
|
||||
func (c *defaultCredential) checkReservesLocked() bool {
|
||||
if c.state.fiveHourUtilization >= c.cap5h {
|
||||
return false
|
||||
}
|
||||
if c.state.weeklyUtilization >= c.capWeekly {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// checkTransitionLocked detects usable->unusable transition.
|
||||
// Must be called with stateAccess write lock held.
|
||||
func (c *defaultCredential) checkTransitionLocked() bool {
|
||||
unusable := c.state.unavailable || c.state.hardRateLimited || !c.checkReservesLocked() || c.state.consecutivePollFailures > 0
|
||||
if unusable && !c.interrupted {
|
||||
c.interrupted = true
|
||||
return true
|
||||
}
|
||||
if !unusable && c.interrupted {
|
||||
c.interrupted = false
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *defaultCredential) interruptConnections() {
|
||||
c.logger.Warn("interrupting connections for ", c.tag)
|
||||
c.requestAccess.Lock()
|
||||
c.cancelRequests()
|
||||
c.requestContext, c.cancelRequests = context.WithCancel(context.Background())
|
||||
c.requestAccess.Unlock()
|
||||
if c.onBecameUnusable != nil {
|
||||
c.onBecameUnusable()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *defaultCredential) wrapRequestContext(parent context.Context) *credentialRequestContext {
|
||||
c.requestAccess.Lock()
|
||||
credentialContext := c.requestContext
|
||||
c.requestAccess.Unlock()
|
||||
derived, cancel := context.WithCancel(parent)
|
||||
stop := context.AfterFunc(credentialContext, func() {
|
||||
cancel()
|
||||
})
|
||||
return &credentialRequestContext{
|
||||
Context: derived,
|
||||
releaseFuncs: []func() bool{stop},
|
||||
cancelFunc: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *defaultCredential) fiveHourUtilization() float64 {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.fiveHourUtilization
|
||||
}
|
||||
|
||||
func (c *defaultCredential) weeklyUtilization() float64 {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.weeklyUtilization
|
||||
}
|
||||
|
||||
func (c *defaultCredential) planWeight() float64 {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return ocmPlanWeight(c.state.accountType)
|
||||
}
|
||||
|
||||
func (c *defaultCredential) weeklyResetTime() time.Time {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.weeklyReset
|
||||
}
|
||||
|
||||
func (c *defaultCredential) isAvailable() bool {
|
||||
c.retryCredentialReloadIfNeeded()
|
||||
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return !c.state.unavailable
|
||||
}
|
||||
|
||||
func (c *defaultCredential) unavailableError() error {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
if !c.state.unavailable {
|
||||
return nil
|
||||
}
|
||||
if c.state.lastCredentialLoadError == "" {
|
||||
return E.New("credential ", c.tag, " is unavailable")
|
||||
}
|
||||
return E.New("credential ", c.tag, " is unavailable: ", c.state.lastCredentialLoadError)
|
||||
}
|
||||
|
||||
func (c *defaultCredential) lastUpdatedTime() time.Time {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.lastUpdated
|
||||
}
|
||||
|
||||
func (c *defaultCredential) markUsagePollAttempted() {
|
||||
c.stateAccess.Lock()
|
||||
defer c.stateAccess.Unlock()
|
||||
c.state.lastUpdated = time.Now()
|
||||
}
|
||||
|
||||
func (c *defaultCredential) incrementPollFailures() {
|
||||
c.stateAccess.Lock()
|
||||
c.state.consecutivePollFailures++
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duration {
|
||||
c.stateAccess.RLock()
|
||||
failures := c.state.consecutivePollFailures
|
||||
c.stateAccess.RUnlock()
|
||||
if failures <= 0 {
|
||||
return baseInterval
|
||||
}
|
||||
backoff := failedPollRetryInterval * time.Duration(1<<(failures-1))
|
||||
if backoff > httpRetryMaxBackoff {
|
||||
return httpRetryMaxBackoff
|
||||
}
|
||||
return backoff
|
||||
}
|
||||
|
||||
func (c *defaultCredential) isPollBackoffAtCap() bool {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
failures := c.state.consecutivePollFailures
|
||||
return failures > 0 && failedPollRetryInterval*time.Duration(1<<(failures-1)) >= httpRetryMaxBackoff
|
||||
}
|
||||
|
||||
func (c *defaultCredential) earliestReset() time.Time {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
if c.state.unavailable {
|
||||
return time.Time{}
|
||||
}
|
||||
if c.state.hardRateLimited {
|
||||
return c.state.rateLimitResetAt
|
||||
}
|
||||
earliest := c.state.fiveHourReset
|
||||
if !c.state.weeklyReset.IsZero() && (earliest.IsZero() || c.state.weeklyReset.Before(earliest)) {
|
||||
earliest = c.state.weeklyReset
|
||||
}
|
||||
return earliest
|
||||
}
|
||||
|
||||
func (c *defaultCredential) fiveHourCap() float64 {
|
||||
return c.cap5h
|
||||
}
|
||||
|
||||
func (c *defaultCredential) weeklyCap() float64 {
|
||||
return c.capWeekly
|
||||
}
|
||||
|
||||
func (c *defaultCredential) usageTrackerOrNil() *AggregatedUsage {
|
||||
return c.usageTracker
|
||||
}
|
||||
|
||||
func (c *defaultCredential) httpClient() *http.Client {
|
||||
return c.forwardHTTPClient
|
||||
}
|
||||
|
||||
func (c *defaultCredential) ocmDialer() N.Dialer {
|
||||
return c.dialer
|
||||
}
|
||||
|
||||
func (c *defaultCredential) ocmIsAPIKeyMode() bool {
|
||||
return c.isAPIKeyMode()
|
||||
}
|
||||
|
||||
func (c *defaultCredential) ocmGetAccountID() string {
|
||||
return c.getAccountID()
|
||||
}
|
||||
|
||||
func (c *defaultCredential) ocmGetBaseURL() string {
|
||||
return c.getBaseURL()
|
||||
}
|
||||
|
||||
func (c *defaultCredential) pollUsage(ctx context.Context) {
|
||||
if !c.pollAccess.TryLock() {
|
||||
return
|
||||
}
|
||||
defer c.pollAccess.Unlock()
|
||||
defer c.markUsagePollAttempted()
|
||||
|
||||
c.retryCredentialReloadIfNeeded()
|
||||
if !c.isAvailable() {
|
||||
return
|
||||
}
|
||||
if c.isAPIKeyMode() {
|
||||
return
|
||||
}
|
||||
|
||||
accessToken, err := c.getAccessToken()
|
||||
if err != nil {
|
||||
if !c.isPollBackoffAtCap() {
|
||||
c.logger.Error("poll usage for ", c.tag, ": get token: ", err)
|
||||
}
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
|
||||
var usageURL string
|
||||
if c.isAPIKeyMode() {
|
||||
usageURL = openaiAPIBaseURL + "/api/codex/usage"
|
||||
} else {
|
||||
usageURL = strings.TrimSuffix(chatGPTBackendURL, "/codex") + "/wham/usage"
|
||||
}
|
||||
|
||||
accountID := c.getAccountID()
|
||||
pollClient := &http.Client{
|
||||
Transport: c.forwardHTTPClient.Transport,
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
response, err := doHTTPWithRetry(ctx, pollClient, func() (*http.Request, error) {
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, usageURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
if accountID != "" {
|
||||
request.Header.Set("ChatGPT-Account-Id", accountID)
|
||||
}
|
||||
return request, nil
|
||||
})
|
||||
if err != nil {
|
||||
if !c.isPollBackoffAtCap() {
|
||||
c.logger.Error("poll usage for ", c.tag, ": ", err)
|
||||
}
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
if response.StatusCode == http.StatusTooManyRequests {
|
||||
c.logger.Warn("poll usage for ", c.tag, ": rate limited")
|
||||
}
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
|
||||
type usageWindow struct {
|
||||
UsedPercent float64 `json:"used_percent"`
|
||||
ResetAt int64 `json:"reset_at"`
|
||||
}
|
||||
var usageResponse struct {
|
||||
PlanType string `json:"plan_type"`
|
||||
RateLimit *struct {
|
||||
PrimaryWindow *usageWindow `json:"primary_window"`
|
||||
SecondaryWindow *usageWindow `json:"secondary_window"`
|
||||
} `json:"rate_limit"`
|
||||
}
|
||||
err = json.NewDecoder(response.Body).Decode(&usageResponse)
|
||||
if err != nil {
|
||||
c.logger.Debug("poll usage for ", c.tag, ": decode: ", err)
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
|
||||
c.stateAccess.Lock()
|
||||
isFirstUpdate := c.state.lastUpdated.IsZero()
|
||||
oldFiveHour := c.state.fiveHourUtilization
|
||||
oldWeekly := c.state.weeklyUtilization
|
||||
c.state.consecutivePollFailures = 0
|
||||
if usageResponse.RateLimit != nil {
|
||||
if w := usageResponse.RateLimit.PrimaryWindow; w != nil {
|
||||
c.state.fiveHourUtilization = w.UsedPercent
|
||||
if w.ResetAt > 0 {
|
||||
c.state.fiveHourReset = time.Unix(w.ResetAt, 0)
|
||||
}
|
||||
}
|
||||
if w := usageResponse.RateLimit.SecondaryWindow; w != nil {
|
||||
c.state.weeklyUtilization = w.UsedPercent
|
||||
if w.ResetAt > 0 {
|
||||
c.state.weeklyReset = time.Unix(w.ResetAt, 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
if usageResponse.PlanType != "" {
|
||||
c.state.accountType = usageResponse.PlanType
|
||||
}
|
||||
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
|
||||
c.state.hardRateLimited = false
|
||||
}
|
||||
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
|
||||
resetSuffix := ""
|
||||
if !c.state.weeklyReset.IsZero() {
|
||||
resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset))
|
||||
}
|
||||
c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
|
||||
}
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) {
|
||||
accessToken, err := c.getAccessToken()
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "get access token for ", c.tag)
|
||||
}
|
||||
|
||||
path := original.URL.Path
|
||||
var proxyPath string
|
||||
if c.isAPIKeyMode() {
|
||||
proxyPath = path
|
||||
} else {
|
||||
proxyPath = strings.TrimPrefix(path, "/v1")
|
||||
}
|
||||
|
||||
proxyURL := c.getBaseURL() + proxyPath
|
||||
if original.URL.RawQuery != "" {
|
||||
proxyURL += "?" + original.URL.RawQuery
|
||||
}
|
||||
|
||||
var body io.Reader
|
||||
if bodyBytes != nil {
|
||||
body = bytes.NewReader(bodyBytes)
|
||||
} else {
|
||||
body = original.Body
|
||||
}
|
||||
proxyRequest, err := http.NewRequestWithContext(ctx, original.Method, proxyURL, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for key, values := range original.Header {
|
||||
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && key != "Authorization" {
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
}
|
||||
|
||||
for key, values := range serviceHeaders {
|
||||
proxyRequest.Header.Del(key)
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
proxyRequest.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
if accountID := c.getAccountID(); accountID != "" {
|
||||
proxyRequest.Header.Set("ChatGPT-Account-Id", accountID)
|
||||
}
|
||||
|
||||
return proxyRequest, nil
|
||||
}
|
||||
|
||||
func (c *defaultCredential) close() {
|
||||
if c.watcher != nil {
|
||||
err := c.watcher.Close()
|
||||
if err != nil {
|
||||
c.logger.Error("close credential watcher for ", c.tag, ": ", err)
|
||||
}
|
||||
}
|
||||
if c.usageTracker != nil {
|
||||
c.usageTracker.cancelPendingSave()
|
||||
err := c.usageTracker.Save()
|
||||
if err != nil {
|
||||
c.logger.Error("save usage statistics for ", c.tag, ": ", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -30,17 +30,17 @@ import (
|
||||
const reverseProxyBaseURL = "http://reverse-proxy"
|
||||
|
||||
type externalCredential struct {
|
||||
tag string
|
||||
baseURL string
|
||||
token string
|
||||
credDialer N.Dialer
|
||||
httpClient *http.Client
|
||||
state credentialState
|
||||
stateMutex sync.RWMutex
|
||||
pollAccess sync.Mutex
|
||||
pollInterval time.Duration
|
||||
usageTracker *AggregatedUsage
|
||||
logger log.ContextLogger
|
||||
tag string
|
||||
baseURL string
|
||||
token string
|
||||
credDialer N.Dialer
|
||||
forwardHTTPClient *http.Client
|
||||
state credentialState
|
||||
stateAccess sync.RWMutex
|
||||
pollAccess sync.Mutex
|
||||
pollInterval time.Duration
|
||||
usageTracker *AggregatedUsage
|
||||
logger log.ContextLogger
|
||||
|
||||
onBecameUnusable func()
|
||||
interrupted bool
|
||||
@@ -147,7 +147,7 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx
|
||||
// Receiver mode: no URL, wait for reverse connection
|
||||
cred.baseURL = reverseProxyBaseURL
|
||||
cred.credDialer = reverseSessionDialer{credential: cred}
|
||||
cred.httpClient = &http.Client{
|
||||
cred.forwardHTTPClient = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
ForceAttemptHTTP2: false,
|
||||
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
@@ -211,11 +211,11 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx
|
||||
Time: ntp.TimeFuncFromContext(ctx),
|
||||
}
|
||||
}
|
||||
cred.httpClient = &http.Client{Transport: transport}
|
||||
cred.forwardHTTPClient = &http.Client{Transport: transport}
|
||||
} else {
|
||||
// Normal mode: standard HTTP client for proxying
|
||||
cred.credDialer = credentialDialer
|
||||
cred.httpClient = &http.Client{Transport: transport}
|
||||
cred.forwardHTTPClient = &http.Client{Transport: transport}
|
||||
cred.reverseCredDialer = reverseSessionDialer{credential: cred}
|
||||
cred.reverseHttpClient = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
@@ -273,39 +273,39 @@ func (c *externalCredential) isUsable() bool {
|
||||
if !c.isAvailable() {
|
||||
return false
|
||||
}
|
||||
c.stateMutex.RLock()
|
||||
c.stateAccess.RLock()
|
||||
if c.state.consecutivePollFailures > 0 {
|
||||
c.stateMutex.RUnlock()
|
||||
c.stateAccess.RUnlock()
|
||||
return false
|
||||
}
|
||||
if c.state.hardRateLimited {
|
||||
if time.Now().Before(c.state.rateLimitResetAt) {
|
||||
c.stateMutex.RUnlock()
|
||||
c.stateAccess.RUnlock()
|
||||
return false
|
||||
}
|
||||
c.stateMutex.RUnlock()
|
||||
c.stateMutex.Lock()
|
||||
c.stateAccess.RUnlock()
|
||||
c.stateAccess.Lock()
|
||||
if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) {
|
||||
c.state.hardRateLimited = false
|
||||
}
|
||||
usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100
|
||||
c.stateMutex.Unlock()
|
||||
c.stateAccess.Unlock()
|
||||
return usable
|
||||
}
|
||||
usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100
|
||||
c.stateMutex.RUnlock()
|
||||
c.stateAccess.RUnlock()
|
||||
return usable
|
||||
}
|
||||
|
||||
func (c *externalCredential) fiveHourUtilization() float64 {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.fiveHourUtilization
|
||||
}
|
||||
|
||||
func (c *externalCredential) weeklyUtilization() float64 {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.weeklyUtilization
|
||||
}
|
||||
|
||||
@@ -318,8 +318,8 @@ func (c *externalCredential) weeklyCap() float64 {
|
||||
}
|
||||
|
||||
func (c *externalCredential) planWeight() float64 {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
if c.state.remotePlanWeight > 0 {
|
||||
return c.state.remotePlanWeight
|
||||
}
|
||||
@@ -327,26 +327,26 @@ func (c *externalCredential) planWeight() float64 {
|
||||
}
|
||||
|
||||
func (c *externalCredential) weeklyResetTime() time.Time {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.weeklyReset
|
||||
}
|
||||
|
||||
func (c *externalCredential) markRateLimited(resetAt time.Time) {
|
||||
c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt)))
|
||||
c.stateMutex.Lock()
|
||||
c.stateAccess.Lock()
|
||||
c.state.hardRateLimited = true
|
||||
c.state.rateLimitResetAt = resetAt
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) earliestReset() time.Time {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
if c.state.hardRateLimited {
|
||||
return c.state.rateLimitResetAt
|
||||
}
|
||||
@@ -432,7 +432,7 @@ func (c *externalCredential) openReverseConnection(ctx context.Context) (net.Con
|
||||
}
|
||||
|
||||
func (c *externalCredential) updateStateFromHeaders(headers http.Header) {
|
||||
c.stateMutex.Lock()
|
||||
c.stateAccess.Lock()
|
||||
isFirstUpdate := c.state.lastUpdated.IsZero()
|
||||
oldFiveHour := c.state.fiveHourUtilization
|
||||
oldWeekly := c.state.weeklyUtilization
|
||||
@@ -494,7 +494,7 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) {
|
||||
c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
|
||||
}
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
@@ -569,9 +569,9 @@ func (c *externalCredential) doPollUsageRequest(ctx context.Context) (*http.Resp
|
||||
}
|
||||
}
|
||||
// Forward transport with retries
|
||||
if c.httpClient != nil {
|
||||
if c.forwardHTTPClient != nil {
|
||||
forwardClient := &http.Client{
|
||||
Transport: c.httpClient.Transport,
|
||||
Transport: c.forwardHTTPClient.Transport,
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
return doHTTPWithRetry(ctx, forwardClient, buildRequest(c.baseURL))
|
||||
@@ -602,10 +602,10 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
|
||||
// 404 means the remote does not have a status endpoint yet;
|
||||
// usage will be updated passively from response headers.
|
||||
if response.StatusCode == http.StatusNotFound {
|
||||
c.stateMutex.Lock()
|
||||
c.stateAccess.Lock()
|
||||
c.state.consecutivePollFailures = 0
|
||||
c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
c.stateAccess.Unlock()
|
||||
} else {
|
||||
c.incrementPollFailures()
|
||||
}
|
||||
@@ -624,7 +624,7 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.stateMutex.Lock()
|
||||
c.stateAccess.Lock()
|
||||
isFirstUpdate := c.state.lastUpdated.IsZero()
|
||||
oldFiveHour := c.state.fiveHourUtilization
|
||||
oldWeekly := c.state.weeklyUtilization
|
||||
@@ -645,28 +645,28 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
|
||||
c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
|
||||
}
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) lastUpdatedTime() time.Time {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.lastUpdated
|
||||
}
|
||||
|
||||
func (c *externalCredential) markUsagePollAttempted() {
|
||||
c.stateMutex.Lock()
|
||||
defer c.stateMutex.Unlock()
|
||||
c.stateAccess.Lock()
|
||||
defer c.stateAccess.Unlock()
|
||||
c.state.lastUpdated = time.Now()
|
||||
}
|
||||
|
||||
func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Duration {
|
||||
c.stateMutex.RLock()
|
||||
c.stateAccess.RLock()
|
||||
failures := c.state.consecutivePollFailures
|
||||
c.stateMutex.RUnlock()
|
||||
c.stateAccess.RUnlock()
|
||||
if failures <= 0 {
|
||||
return baseInterval
|
||||
}
|
||||
@@ -678,17 +678,17 @@ func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Durati
|
||||
}
|
||||
|
||||
func (c *externalCredential) isPollBackoffAtCap() bool {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
failures := c.state.consecutivePollFailures
|
||||
return failures > 0 && failedPollRetryInterval*time.Duration(1<<(failures-1)) >= httpRetryMaxBackoff
|
||||
}
|
||||
|
||||
func (c *externalCredential) incrementPollFailures() {
|
||||
c.stateMutex.Lock()
|
||||
c.stateAccess.Lock()
|
||||
c.state.consecutivePollFailures++
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
@@ -698,14 +698,14 @@ func (c *externalCredential) usageTrackerOrNil() *AggregatedUsage {
|
||||
return c.usageTracker
|
||||
}
|
||||
|
||||
func (c *externalCredential) httpTransport() *http.Client {
|
||||
func (c *externalCredential) httpClient() *http.Client {
|
||||
if c.reverseHttpClient != nil {
|
||||
session := c.getReverseSession()
|
||||
if session != nil && !session.IsClosed() {
|
||||
return c.reverseHttpClient
|
||||
}
|
||||
}
|
||||
return c.httpClient
|
||||
return c.forwardHTTPClient
|
||||
}
|
||||
|
||||
func (c *externalCredential) ocmDialer() N.Dialer {
|
||||
|
||||
@@ -62,10 +62,10 @@ func (c *defaultCredential) ensureCredentialWatcher() error {
|
||||
}
|
||||
|
||||
func (c *defaultCredential) retryCredentialReloadIfNeeded() {
|
||||
c.stateMutex.RLock()
|
||||
c.stateAccess.RLock()
|
||||
unavailable := c.state.unavailable
|
||||
lastAttempt := c.state.lastCredentialLoadAttempt
|
||||
c.stateMutex.RUnlock()
|
||||
c.stateAccess.RUnlock()
|
||||
if !unavailable {
|
||||
return
|
||||
}
|
||||
@@ -84,10 +84,10 @@ func (c *defaultCredential) reloadCredentials(force bool) error {
|
||||
c.reloadAccess.Lock()
|
||||
defer c.reloadAccess.Unlock()
|
||||
|
||||
c.stateMutex.RLock()
|
||||
c.stateAccess.RLock()
|
||||
unavailable := c.state.unavailable
|
||||
lastAttempt := c.state.lastCredentialLoadAttempt
|
||||
c.stateMutex.RUnlock()
|
||||
c.stateAccess.RUnlock()
|
||||
if !force {
|
||||
if !unavailable {
|
||||
return nil
|
||||
@@ -97,39 +97,39 @@ func (c *defaultCredential) reloadCredentials(force bool) error {
|
||||
}
|
||||
}
|
||||
|
||||
c.stateMutex.Lock()
|
||||
c.stateAccess.Lock()
|
||||
c.state.lastCredentialLoadAttempt = time.Now()
|
||||
c.stateMutex.Unlock()
|
||||
c.stateAccess.Unlock()
|
||||
|
||||
credentials, err := platformReadCredentials(c.credentialPath)
|
||||
if err != nil {
|
||||
return c.markCredentialsUnavailable(E.Cause(err, "read credentials"))
|
||||
}
|
||||
|
||||
c.accessMutex.Lock()
|
||||
c.access.Lock()
|
||||
c.credentials = credentials
|
||||
c.accessMutex.Unlock()
|
||||
c.access.Unlock()
|
||||
|
||||
c.stateMutex.Lock()
|
||||
c.stateAccess.Lock()
|
||||
c.state.unavailable = false
|
||||
c.state.lastCredentialLoadError = ""
|
||||
c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
c.stateAccess.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *defaultCredential) markCredentialsUnavailable(err error) error {
|
||||
c.accessMutex.Lock()
|
||||
c.access.Lock()
|
||||
hadCredentials := c.credentials != nil
|
||||
c.credentials = nil
|
||||
c.accessMutex.Unlock()
|
||||
c.access.Unlock()
|
||||
|
||||
c.stateMutex.Lock()
|
||||
c.stateAccess.Lock()
|
||||
c.state.unavailable = true
|
||||
c.state.lastCredentialLoadError = err.Error()
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
c.stateAccess.Unlock()
|
||||
|
||||
if shouldInterrupt && hadCredentials {
|
||||
c.interruptConnections()
|
||||
|
||||
225
service/ocm/credential_oauth.go
Normal file
225
service/ocm/credential_oauth.go
Normal file
@@ -0,0 +1,225 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
const (
	// oauth2ClientID is sent as client_id in token refresh requests.
	oauth2ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
	// oauth2TokenURL is the endpoint token refresh requests are POSTed to.
	oauth2TokenURL = "https://auth.openai.com/oauth/token"
	// openaiAPIBaseURL is the base URL of the OpenAI API.
	openaiAPIBaseURL = "https://api.openai.com"
	// chatGPTBackendURL is the base URL of the ChatGPT Codex backend.
	chatGPTBackendURL = "https://chatgpt.com/backend-api/codex"
	// tokenRefreshIntervalDays is the age, in days since the last refresh,
	// at which OAuth tokens are considered due for a refresh.
	tokenRefreshIntervalDays = 8
)
|
||||
|
||||
// getRealUser returns the user named by the SUDO_USER environment variable
// when it is set and can be looked up, and the current user otherwise.
func getRealUser() (*user.User, error) {
	sudoUser := os.Getenv("SUDO_USER")
	if sudoUser == "" {
		return user.Current()
	}
	// A failed lookup is not an error: fall back to the current user.
	if sudoUserInfo, err := user.Lookup(sudoUser); err == nil {
		return sudoUserInfo, nil
	}
	return user.Current()
}
|
||||
|
||||
func getDefaultCredentialsPath() (string, error) {
|
||||
if codexHome := os.Getenv("CODEX_HOME"); codexHome != "" {
|
||||
return filepath.Join(codexHome, "auth.json"), nil
|
||||
}
|
||||
userInfo, err := getRealUser()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(userInfo.HomeDir, ".codex", "auth.json"), nil
|
||||
}
|
||||
|
||||
func readCredentialsFromFile(path string) (*oauthCredentials, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var credentials oauthCredentials
|
||||
err = json.Unmarshal(data, &credentials)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &credentials, nil
|
||||
}
|
||||
|
||||
// checkCredentialFileWritable verifies that the existing file at path can be
// opened for writing. The file is neither created nor truncated; it is
// opened write-only and closed again immediately.
func checkCredentialFileWritable(path string) error {
	handle, openErr := os.OpenFile(path, os.O_WRONLY, 0)
	if openErr == nil {
		return handle.Close()
	}
	return openErr
}
|
||||
|
||||
func writeCredentialsToFile(credentials *oauthCredentials, path string) error {
|
||||
data, err := json.MarshalIndent(credentials, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(path, data, 0o600)
|
||||
}
|
||||
|
||||
// oauthCredentials is the JSON layout of the credentials file. It carries
// either an API key, or OAuth tokens together with the time they were last
// refreshed.
type oauthCredentials struct {
	// APIKey, when non-empty, puts the credentials in API-key mode.
	APIKey string `json:"OPENAI_API_KEY,omitempty"`
	// Tokens holds the OAuth token set; nil when absent from the file.
	Tokens *tokenData `json:"tokens,omitempty"`
	// LastRefresh is when the tokens were last refreshed; nil when unknown.
	LastRefresh *time.Time `json:"last_refresh,omitempty"`
}

// tokenData is the OAuth token set stored under the "tokens" key.
type tokenData struct {
	IDToken      string `json:"id_token,omitempty"`
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`
	AccountID    string `json:"account_id,omitempty"`
}

// isAPIKeyMode reports whether an API key is configured.
func (c *oauthCredentials) isAPIKeyMode() bool {
	return len(c.APIKey) > 0
}

// getAccessToken returns the bearer value to authenticate with: the API key
// when one is set, otherwise the OAuth access token, otherwise "".
func (c *oauthCredentials) getAccessToken() string {
	switch {
	case c.APIKey != "":
		return c.APIKey
	case c.Tokens != nil:
		return c.Tokens.AccessToken
	default:
		return ""
	}
}

// getAccountID returns the account ID stored with the OAuth tokens, or ""
// when there are no tokens.
func (c *oauthCredentials) getAccountID() string {
	if c.Tokens == nil {
		return ""
	}
	return c.Tokens.AccountID
}
|
||||
|
||||
func (c *oauthCredentials) needsRefresh() bool {
|
||||
if c.APIKey != "" {
|
||||
return false
|
||||
}
|
||||
if c.Tokens == nil || c.Tokens.RefreshToken == "" {
|
||||
return false
|
||||
}
|
||||
if c.LastRefresh == nil {
|
||||
return true
|
||||
}
|
||||
return time.Since(*c.LastRefresh) >= time.Duration(tokenRefreshIntervalDays)*24*time.Hour
|
||||
}
|
||||
|
||||
func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) {
|
||||
if credentials.Tokens == nil || credentials.Tokens.RefreshToken == "" {
|
||||
return nil, E.New("refresh token is empty")
|
||||
}
|
||||
|
||||
requestBody, err := json.Marshal(map[string]string{
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": credentials.Tokens.RefreshToken,
|
||||
"client_id": oauth2ClientID,
|
||||
"scope": "openid profile email",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "marshal request")
|
||||
}
|
||||
|
||||
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
|
||||
request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request.Header.Set("Accept", "application/json")
|
||||
return request, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode == http.StatusTooManyRequests {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return nil, E.New("refresh rate limited: ", response.Status, " ", string(body))
|
||||
}
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return nil, E.New("refresh failed: ", response.Status, " ", string(body))
|
||||
}
|
||||
|
||||
var tokenResponse struct {
|
||||
IDToken string `json:"id_token"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
err = json.NewDecoder(response.Body).Decode(&tokenResponse)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode response")
|
||||
}
|
||||
|
||||
newCredentials := *credentials
|
||||
if newCredentials.Tokens == nil {
|
||||
newCredentials.Tokens = &tokenData{}
|
||||
}
|
||||
if tokenResponse.IDToken != "" {
|
||||
newCredentials.Tokens.IDToken = tokenResponse.IDToken
|
||||
}
|
||||
if tokenResponse.AccessToken != "" {
|
||||
newCredentials.Tokens.AccessToken = tokenResponse.AccessToken
|
||||
}
|
||||
if tokenResponse.RefreshToken != "" {
|
||||
newCredentials.Tokens.RefreshToken = tokenResponse.RefreshToken
|
||||
}
|
||||
now := time.Now()
|
||||
newCredentials.LastRefresh = &now
|
||||
|
||||
return &newCredentials, nil
|
||||
}
|
||||
|
||||
func cloneCredentials(credentials *oauthCredentials) *oauthCredentials {
|
||||
if credentials == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *credentials
|
||||
if credentials.Tokens != nil {
|
||||
clonedTokens := *credentials.Tokens
|
||||
cloned.Tokens = &clonedTokens
|
||||
}
|
||||
if credentials.LastRefresh != nil {
|
||||
lastRefresh := *credentials.LastRefresh
|
||||
cloned.LastRefresh = &lastRefresh
|
||||
}
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool {
|
||||
if left == nil || right == nil {
|
||||
return left == right
|
||||
}
|
||||
if left.APIKey != right.APIKey {
|
||||
return false
|
||||
}
|
||||
if (left.Tokens == nil) != (right.Tokens == nil) {
|
||||
return false
|
||||
}
|
||||
if left.Tokens != nil && *left.Tokens != *right.Tokens {
|
||||
return false
|
||||
}
|
||||
if (left.LastRefresh == nil) != (right.LastRefresh == nil) {
|
||||
return false
|
||||
}
|
||||
if left.LastRefresh != nil && !left.LastRefresh.Equal(*right.LastRefresh) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
411
service/ocm/credential_provider.go
Normal file
411
service/ocm/credential_provider.go
Normal file
@@ -0,0 +1,411 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand/v2"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
type credentialProvider interface {
|
||||
selectCredential(sessionID string, selection credentialSelection) (credential, bool, error)
|
||||
onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential
|
||||
linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool
|
||||
pollIfStale(ctx context.Context)
|
||||
allCredentials() []credential
|
||||
close()
|
||||
}
|
||||
|
||||
type singleCredentialProvider struct {
|
||||
cred credential
|
||||
sessionAccess sync.RWMutex
|
||||
sessions map[string]time.Time
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) {
|
||||
if !selection.allows(p.cred) {
|
||||
return nil, false, E.New("credential ", p.cred.tagName(), " is filtered out")
|
||||
}
|
||||
if !p.cred.isAvailable() {
|
||||
return nil, false, p.cred.unavailableError()
|
||||
}
|
||||
if !p.cred.isUsable() {
|
||||
return nil, false, E.New("credential ", p.cred.tagName(), " is rate-limited")
|
||||
}
|
||||
var isNew bool
|
||||
if sessionID != "" {
|
||||
p.sessionAccess.Lock()
|
||||
if p.sessions == nil {
|
||||
p.sessions = make(map[string]time.Time)
|
||||
}
|
||||
_, exists := p.sessions[sessionID]
|
||||
if !exists {
|
||||
p.sessions[sessionID] = time.Now()
|
||||
isNew = true
|
||||
}
|
||||
p.sessionAccess.Unlock()
|
||||
}
|
||||
return p.cred, isNew, nil
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) onRateLimited(_ string, cred credential, resetAt time.Time, _ credentialSelection) credential {
|
||||
cred.markRateLimited(resetAt)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) pollIfStale(ctx context.Context) {
|
||||
now := time.Now()
|
||||
p.sessionAccess.Lock()
|
||||
for id, createdAt := range p.sessions {
|
||||
if now.Sub(createdAt) > sessionExpiry {
|
||||
delete(p.sessions, id)
|
||||
}
|
||||
}
|
||||
p.sessionAccess.Unlock()
|
||||
|
||||
if time.Since(p.cred.lastUpdatedTime()) > p.cred.pollBackoff(defaultPollInterval) {
|
||||
p.cred.pollUsage(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) allCredentials() []credential {
|
||||
return []credential{p.cred}
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) linkProviderInterrupt(_ credential, _ credentialSelection, _ func()) func() bool {
|
||||
return func() bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// close is a no-op.
func (p *singleCredentialProvider) close() {
}
|
||||
|
||||
type sessionEntry struct {
|
||||
tag string
|
||||
selectionScope credentialSelectionScope
|
||||
createdAt time.Time
|
||||
}
|
||||
|
||||
type credentialInterruptKey struct {
|
||||
tag string
|
||||
selectionScope credentialSelectionScope
|
||||
}
|
||||
|
||||
// credentialInterruptEntry pairs a cancellable context with its cancel
// function; cancelling it fires every callback linked to the context.
type credentialInterruptEntry struct {
	// context is the context that linked callbacks wait on.
	context context.Context
	// cancel cancels context.
	cancel context.CancelFunc
}
|
||||
|
||||
type balancerProvider struct {
|
||||
credentials []credential
|
||||
strategy string
|
||||
roundRobinIndex atomic.Uint64
|
||||
pollInterval time.Duration
|
||||
rebalanceThreshold float64
|
||||
sessionAccess sync.RWMutex
|
||||
sessions map[string]sessionEntry
|
||||
interruptAccess sync.Mutex
|
||||
credentialInterrupts map[credentialInterruptKey]credentialInterruptEntry
|
||||
logger log.ContextLogger
|
||||
}
|
||||
|
||||
func compositeCredentialSelectable(cred credential) bool {
|
||||
return !cred.ocmIsAPIKeyMode()
|
||||
}
|
||||
|
||||
func newBalancerProvider(credentials []credential, strategy string, pollInterval time.Duration, rebalanceThreshold float64, logger log.ContextLogger) *balancerProvider {
|
||||
if pollInterval <= 0 {
|
||||
pollInterval = defaultPollInterval
|
||||
}
|
||||
return &balancerProvider{
|
||||
credentials: credentials,
|
||||
strategy: strategy,
|
||||
pollInterval: pollInterval,
|
||||
rebalanceThreshold: rebalanceThreshold,
|
||||
sessions: make(map[string]sessionEntry),
|
||||
credentialInterrupts: make(map[credentialInterruptKey]credentialInterruptEntry),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (credential, bool, error) {
|
||||
if p.strategy == C.BalancerStrategyFallback {
|
||||
best := p.pickCredential(selection.filter)
|
||||
if best == nil {
|
||||
return nil, false, allRateLimitedError(p.credentials)
|
||||
}
|
||||
return best, false, nil
|
||||
}
|
||||
|
||||
selectionScope := selection.scopeOrDefault()
|
||||
if sessionID != "" {
|
||||
p.sessionAccess.RLock()
|
||||
entry, exists := p.sessions[sessionID]
|
||||
p.sessionAccess.RUnlock()
|
||||
if exists {
|
||||
if entry.selectionScope == selectionScope {
|
||||
for _, cred := range p.credentials {
|
||||
if cred.tagName() == entry.tag && compositeCredentialSelectable(cred) && selection.allows(cred) && cred.isUsable() {
|
||||
if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == C.BalancerStrategyLeastUsed) {
|
||||
better := p.pickLeastUsed(selection.filter)
|
||||
if better != nil && better.tagName() != cred.tagName() {
|
||||
effectiveThreshold := p.rebalanceThreshold / cred.planWeight()
|
||||
delta := cred.weeklyUtilization() - better.weeklyUtilization()
|
||||
if delta > effectiveThreshold {
|
||||
p.logger.Info("rebalancing away from ", cred.tagName(),
|
||||
": utilization delta ", delta, "% exceeds effective threshold ",
|
||||
effectiveThreshold, "% (weight ", cred.planWeight(), ")")
|
||||
p.rebalanceCredential(cred.tagName(), selectionScope)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return cred, false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
p.sessionAccess.Lock()
|
||||
delete(p.sessions, sessionID)
|
||||
p.sessionAccess.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
best := p.pickCredential(selection.filter)
|
||||
if best == nil {
|
||||
return nil, false, allRateLimitedError(p.credentials)
|
||||
}
|
||||
|
||||
isNew := sessionID != ""
|
||||
if isNew {
|
||||
p.sessionAccess.Lock()
|
||||
p.sessions[sessionID] = sessionEntry{
|
||||
tag: best.tagName(),
|
||||
selectionScope: selectionScope,
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
p.sessionAccess.Unlock()
|
||||
}
|
||||
return best, isNew, nil
|
||||
}
|
||||
|
||||
func (p *balancerProvider) rebalanceCredential(tag string, selectionScope credentialSelectionScope) {
|
||||
key := credentialInterruptKey{tag: tag, selectionScope: selectionScope}
|
||||
p.interruptAccess.Lock()
|
||||
if entry, loaded := p.credentialInterrupts[key]; loaded {
|
||||
entry.cancel()
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
p.credentialInterrupts[key] = credentialInterruptEntry{context: ctx, cancel: cancel}
|
||||
p.interruptAccess.Unlock()
|
||||
|
||||
p.sessionAccess.Lock()
|
||||
for id, entry := range p.sessions {
|
||||
if entry.tag == tag && entry.selectionScope == selectionScope {
|
||||
delete(p.sessions, id)
|
||||
}
|
||||
}
|
||||
p.sessionAccess.Unlock()
|
||||
}
|
||||
|
||||
func (p *balancerProvider) linkProviderInterrupt(cred credential, selection credentialSelection, onInterrupt func()) func() bool {
|
||||
if p.strategy == C.BalancerStrategyFallback {
|
||||
return func() bool { return false }
|
||||
}
|
||||
key := credentialInterruptKey{
|
||||
tag: cred.tagName(),
|
||||
selectionScope: selection.scopeOrDefault(),
|
||||
}
|
||||
p.interruptAccess.Lock()
|
||||
entry, loaded := p.credentialInterrupts[key]
|
||||
if !loaded {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
entry = credentialInterruptEntry{context: ctx, cancel: cancel}
|
||||
p.credentialInterrupts[key] = entry
|
||||
}
|
||||
p.interruptAccess.Unlock()
|
||||
return context.AfterFunc(entry.context, onInterrupt)
|
||||
}
|
||||
|
||||
func (p *balancerProvider) onRateLimited(sessionID string, cred credential, resetAt time.Time, selection credentialSelection) credential {
|
||||
cred.markRateLimited(resetAt)
|
||||
if p.strategy == C.BalancerStrategyFallback {
|
||||
return p.pickCredential(selection.filter)
|
||||
}
|
||||
if sessionID != "" {
|
||||
p.sessionAccess.Lock()
|
||||
delete(p.sessions, sessionID)
|
||||
p.sessionAccess.Unlock()
|
||||
}
|
||||
|
||||
best := p.pickCredential(selection.filter)
|
||||
if best != nil && sessionID != "" {
|
||||
p.sessionAccess.Lock()
|
||||
p.sessions[sessionID] = sessionEntry{
|
||||
tag: best.tagName(),
|
||||
selectionScope: selection.scopeOrDefault(),
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
p.sessionAccess.Unlock()
|
||||
}
|
||||
return best
|
||||
}
|
||||
|
||||
func (p *balancerProvider) pickCredential(filter func(credential) bool) credential {
|
||||
switch p.strategy {
|
||||
case C.BalancerStrategyRoundRobin:
|
||||
return p.pickRoundRobin(filter)
|
||||
case C.BalancerStrategyRandom:
|
||||
return p.pickRandom(filter)
|
||||
case C.BalancerStrategyFallback:
|
||||
return p.pickFallback(filter)
|
||||
default:
|
||||
return p.pickLeastUsed(filter)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *balancerProvider) pickFallback(filter func(credential) bool) credential {
|
||||
for _, cred := range p.credentials {
|
||||
if filter != nil && !filter(cred) {
|
||||
continue
|
||||
}
|
||||
if !compositeCredentialSelectable(cred) {
|
||||
continue
|
||||
}
|
||||
if cred.isUsable() {
|
||||
return cred
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const weeklyWindowHours = 7 * 24
|
||||
|
||||
func (p *balancerProvider) pickLeastUsed(filter func(credential) bool) credential {
|
||||
var best credential
|
||||
bestScore := float64(-1)
|
||||
now := time.Now()
|
||||
for _, cred := range p.credentials {
|
||||
if filter != nil && !filter(cred) {
|
||||
continue
|
||||
}
|
||||
if !compositeCredentialSelectable(cred) {
|
||||
continue
|
||||
}
|
||||
if !cred.isUsable() {
|
||||
continue
|
||||
}
|
||||
remaining := cred.weeklyCap() - cred.weeklyUtilization()
|
||||
score := remaining * cred.planWeight()
|
||||
resetTime := cred.weeklyResetTime()
|
||||
if !resetTime.IsZero() {
|
||||
timeUntilReset := resetTime.Sub(now)
|
||||
if timeUntilReset < time.Hour {
|
||||
timeUntilReset = time.Hour
|
||||
}
|
||||
score *= weeklyWindowHours / timeUntilReset.Hours()
|
||||
}
|
||||
if score > bestScore {
|
||||
bestScore = score
|
||||
best = cred
|
||||
}
|
||||
}
|
||||
return best
|
||||
}
|
||||
|
||||
// ocmPlanWeight returns the relative capacity weight for an account type:
// 10 for "pro" and 1 for "plus" or any other value.
func ocmPlanWeight(accountType string) float64 {
	if accountType == "pro" {
		return 10
	}
	// "plus" and every unrecognised account type share the base weight.
	return 1
}
|
||||
|
||||
func (p *balancerProvider) pickRoundRobin(filter func(credential) bool) credential {
|
||||
start := int(p.roundRobinIndex.Add(1) - 1)
|
||||
count := len(p.credentials)
|
||||
for offset := range count {
|
||||
candidate := p.credentials[(start+offset)%count]
|
||||
if filter != nil && !filter(candidate) {
|
||||
continue
|
||||
}
|
||||
if !compositeCredentialSelectable(candidate) {
|
||||
continue
|
||||
}
|
||||
if candidate.isUsable() {
|
||||
return candidate
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *balancerProvider) pickRandom(filter func(credential) bool) credential {
|
||||
var usable []credential
|
||||
for _, candidate := range p.credentials {
|
||||
if filter != nil && !filter(candidate) {
|
||||
continue
|
||||
}
|
||||
if !compositeCredentialSelectable(candidate) {
|
||||
continue
|
||||
}
|
||||
if candidate.isUsable() {
|
||||
usable = append(usable, candidate)
|
||||
}
|
||||
}
|
||||
if len(usable) == 0 {
|
||||
return nil
|
||||
}
|
||||
return usable[rand.IntN(len(usable))]
|
||||
}
|
||||
|
||||
func (p *balancerProvider) pollIfStale(ctx context.Context) {
|
||||
now := time.Now()
|
||||
p.sessionAccess.Lock()
|
||||
for id, entry := range p.sessions {
|
||||
if now.Sub(entry.createdAt) > sessionExpiry {
|
||||
delete(p.sessions, id)
|
||||
}
|
||||
}
|
||||
p.sessionAccess.Unlock()
|
||||
|
||||
for _, cred := range p.credentials {
|
||||
if time.Since(cred.lastUpdatedTime()) > cred.pollBackoff(p.pollInterval) {
|
||||
cred.pollUsage(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *balancerProvider) allCredentials() []credential {
|
||||
return p.credentials
|
||||
}
|
||||
|
||||
// close is a no-op.
func (p *balancerProvider) close() {
}
|
||||
|
||||
func allRateLimitedError(credentials []credential) error {
|
||||
var hasUnavailable bool
|
||||
var earliest time.Time
|
||||
for _, cred := range credentials {
|
||||
if cred.unavailableError() != nil {
|
||||
hasUnavailable = true
|
||||
continue
|
||||
}
|
||||
resetAt := cred.earliestReset()
|
||||
if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) {
|
||||
earliest = resetAt
|
||||
}
|
||||
}
|
||||
if hasUnavailable {
|
||||
return E.New("all credentials unavailable")
|
||||
}
|
||||
if earliest.IsZero() {
|
||||
return E.New("all credentials rate-limited")
|
||||
}
|
||||
return E.New("all credentials rate-limited, earliest reset in ", log.FormatDuration(time.Until(earliest)))
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,17 +1,13 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
boxService "github.com/sagernet/sing-box/adapter/service"
|
||||
@@ -21,14 +17,11 @@ import (
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/openai/openai-go/v3"
|
||||
"github.com/openai/openai-go/v3/responses"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
)
|
||||
@@ -160,71 +153,20 @@ func isReverseProxyHeader(header string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeRateLimitIdentifier lower-cases limitIdentifier, trims surrounding
// whitespace and replaces underscores with hyphens. An input that is empty or
// only whitespace yields the empty string.
func normalizeRateLimitIdentifier(limitIdentifier string) string {
	lowered := strings.ToLower(limitIdentifier)
	normalized := strings.TrimSpace(lowered)
	// Replacing inside an empty string is a no-op, so no separate guard is
	// needed for blank input.
	return strings.ReplaceAll(normalized, "_", "-")
}
|
||||
|
||||
// parseInt64Header reads headerName from headers and parses its
// whitespace-trimmed value as a base-10 int64. The boolean is false when the
// header is missing, blank or not a valid integer.
func parseInt64Header(headers http.Header, headerName string) (int64, bool) {
	raw := strings.TrimSpace(headers.Get(headerName))
	if raw == "" {
		return 0, false
	}
	if value, err := strconv.ParseInt(raw, 10, 64); err == nil {
		return value, true
	}
	return 0, false
}
|
||||
|
||||
func weeklyCycleHintForLimit(headers http.Header, limitIdentifier string) *WeeklyCycleHint {
|
||||
normalizedLimitIdentifier := normalizeRateLimitIdentifier(limitIdentifier)
|
||||
if normalizedLimitIdentifier == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
windowHeader := "x-" + normalizedLimitIdentifier + "-secondary-window-minutes"
|
||||
resetHeader := "x-" + normalizedLimitIdentifier + "-secondary-reset-at"
|
||||
|
||||
windowMinutes, hasWindowMinutes := parseInt64Header(headers, windowHeader)
|
||||
resetAtUnix, hasResetAt := parseInt64Header(headers, resetHeader)
|
||||
if !hasWindowMinutes || !hasResetAt || windowMinutes <= 0 || resetAtUnix <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &WeeklyCycleHint{
|
||||
WindowMinutes: windowMinutes,
|
||||
ResetAt: time.Unix(resetAtUnix, 0).UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint {
|
||||
activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
|
||||
if activeLimitIdentifier != "" {
|
||||
if activeHint := weeklyCycleHintForLimit(headers, activeLimitIdentifier); activeHint != nil {
|
||||
return activeHint
|
||||
}
|
||||
}
|
||||
return weeklyCycleHintForLimit(headers, "codex")
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
boxService.Adapter
|
||||
ctx context.Context
|
||||
logger log.ContextLogger
|
||||
options option.OCMServiceOptions
|
||||
httpHeaders http.Header
|
||||
listener *listener.Listener
|
||||
tlsConfig tls.ServerConfig
|
||||
httpServer *http.Server
|
||||
userManager *UserManager
|
||||
webSocketMutex sync.Mutex
|
||||
webSocketGroup sync.WaitGroup
|
||||
webSocketConns map[*webSocketSession]struct{}
|
||||
shuttingDown bool
|
||||
ctx context.Context
|
||||
logger log.ContextLogger
|
||||
options option.OCMServiceOptions
|
||||
httpHeaders http.Header
|
||||
listener *listener.Listener
|
||||
tlsConfig tls.ServerConfig
|
||||
httpServer *http.Server
|
||||
userManager *UserManager
|
||||
webSocketAccess sync.Mutex
|
||||
webSocketGroup sync.WaitGroup
|
||||
webSocketConns map[*webSocketSession]struct{}
|
||||
shuttingDown bool
|
||||
|
||||
// Legacy mode
|
||||
legacyCredential *defaultCredential
|
||||
@@ -361,562 +303,6 @@ func (s *Service) Start(stage adapter.StartStage) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) resolveCredentialProvider(username string) (credentialProvider, error) {
|
||||
if len(s.options.Users) > 0 {
|
||||
return credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
|
||||
}
|
||||
provider := noUserCredentialProvider(s.providers, s.legacyProvider, s.options)
|
||||
if provider == nil {
|
||||
return nil, E.New("no credential available")
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := log.ContextWithNewID(r.Context())
|
||||
if r.URL.Path == "/ocm/v1/status" {
|
||||
s.handleStatusEndpoint(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Path == "/ocm/v1/reverse" {
|
||||
s.handleReverseConnect(ctx, w, r)
|
||||
return
|
||||
}
|
||||
|
||||
path := r.URL.Path
|
||||
if !strings.HasPrefix(path, "/v1/") {
|
||||
writeJSONError(w, r, http.StatusNotFound, "invalid_request_error", "path must start with /v1/")
|
||||
return
|
||||
}
|
||||
|
||||
var username string
|
||||
if len(s.options.Users) > 0 {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
||||
return
|
||||
}
|
||||
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if clientToken == authHeader {
|
||||
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
||||
return
|
||||
}
|
||||
var ok bool
|
||||
username, ok = s.userManager.Authenticate(clientToken)
|
||||
if !ok {
|
||||
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
sessionID := r.Header.Get("session_id")
|
||||
|
||||
// Resolve credential provider and user config
|
||||
var provider credentialProvider
|
||||
var userConfig *option.OCMUser
|
||||
if len(s.options.Users) > 0 {
|
||||
userConfig = s.userConfigMap[username]
|
||||
var err error
|
||||
provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "resolve credential: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
} else {
|
||||
provider = noUserCredentialProvider(s.providers, s.legacyProvider, s.options)
|
||||
}
|
||||
if provider == nil {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available")
|
||||
return
|
||||
}
|
||||
|
||||
provider.pollIfStale(s.ctx)
|
||||
|
||||
selection := credentialSelectionForUser(userConfig)
|
||||
|
||||
selectedCredential, isNew, err := provider.selectCredential(sessionID, selection)
|
||||
if err != nil {
|
||||
writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") {
|
||||
s.handleWebSocket(ctx, w, r, path, username, sessionID, userConfig, provider, selectedCredential, selection, isNew)
|
||||
return
|
||||
}
|
||||
|
||||
if !selectedCredential.isExternal() && selectedCredential.ocmIsAPIKeyMode() {
|
||||
// API key mode path handling
|
||||
} else if !selectedCredential.isExternal() {
|
||||
if path == "/v1/chat/completions" {
|
||||
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
|
||||
"chat completions endpoint is only available in API key mode")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
shouldTrackUsage := selectedCredential.usageTrackerOrNil() != nil &&
|
||||
(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses"))
|
||||
canRetryRequest := len(provider.allCredentials()) > 1
|
||||
|
||||
// Read body for model extraction and retry buffer when JSON replay is useful.
|
||||
var bodyBytes []byte
|
||||
var requestModel string
|
||||
var requestServiceTier string
|
||||
if r.Body != nil && (shouldTrackUsage || canRetryRequest) {
|
||||
mediaType, _, parseErr := mime.ParseMediaType(r.Header.Get("Content-Type"))
|
||||
isJSONRequest := parseErr == nil && (mediaType == "application/json" || strings.HasSuffix(mediaType, "+json"))
|
||||
if isJSONRequest {
|
||||
bodyBytes, err = io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "read request body: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body")
|
||||
return
|
||||
}
|
||||
var request struct {
|
||||
Model string `json:"model"`
|
||||
ServiceTier string `json:"service_tier"`
|
||||
}
|
||||
if json.Unmarshal(bodyBytes, &request) == nil {
|
||||
requestModel = request.Model
|
||||
requestServiceTier = request.ServiceTier
|
||||
}
|
||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
}
|
||||
}
|
||||
|
||||
if isNew {
|
||||
logParts := []any{"assigned credential ", selectedCredential.tagName()}
|
||||
if sessionID != "" {
|
||||
logParts = append(logParts, " for session ", sessionID)
|
||||
}
|
||||
if username != "" {
|
||||
logParts = append(logParts, " by user ", username)
|
||||
}
|
||||
if requestModel != "" {
|
||||
logParts = append(logParts, ", model=", requestModel)
|
||||
}
|
||||
if requestServiceTier == "priority" {
|
||||
logParts = append(logParts, ", fast")
|
||||
}
|
||||
s.logger.DebugContext(ctx, logParts...)
|
||||
}
|
||||
|
||||
requestContext := selectedCredential.wrapRequestContext(ctx)
|
||||
{
|
||||
currentRequestContext := requestContext
|
||||
requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() {
|
||||
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
|
||||
}))
|
||||
}
|
||||
defer func() {
|
||||
requestContext.cancelRequest()
|
||||
}()
|
||||
proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "create proxy request: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
|
||||
return
|
||||
}
|
||||
|
||||
response, err := selectedCredential.httpTransport().Do(proxyRequest)
|
||||
if err != nil {
|
||||
if r.Context().Err() != nil {
|
||||
return
|
||||
}
|
||||
if requestContext.Err() != nil {
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable while processing the request")
|
||||
return
|
||||
}
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
requestContext.releaseCredentialInterrupt()
|
||||
|
||||
// Transparent 429 retry
|
||||
for response.StatusCode == http.StatusTooManyRequests {
|
||||
resetAt := parseOCMRateLimitResetFromHeaders(response.Header)
|
||||
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection)
|
||||
needsBodyReplay := r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodDelete
|
||||
selectedCredential.updateStateFromHeaders(response.Header)
|
||||
if (needsBodyReplay && bodyBytes == nil) || nextCredential == nil {
|
||||
response.Body.Close()
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited")
|
||||
return
|
||||
}
|
||||
response.Body.Close()
|
||||
s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
|
||||
requestContext.cancelRequest()
|
||||
requestContext = nextCredential.wrapRequestContext(ctx)
|
||||
{
|
||||
currentRequestContext := requestContext
|
||||
requestContext.addInterruptLink(provider.linkProviderInterrupt(nextCredential, selection, func() {
|
||||
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
|
||||
}))
|
||||
}
|
||||
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||
if buildErr != nil {
|
||||
s.logger.ErrorContext(ctx, "retry request: ", buildErr)
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
|
||||
return
|
||||
}
|
||||
retryResponse, retryErr := nextCredential.httpTransport().Do(retryRequest)
|
||||
if retryErr != nil {
|
||||
if r.Context().Err() != nil {
|
||||
return
|
||||
}
|
||||
if requestContext.Err() != nil {
|
||||
writeCredentialUnavailableError(w, r, provider, nextCredential, selection, "credential became unavailable while retrying the request")
|
||||
return
|
||||
}
|
||||
s.logger.ErrorContext(ctx, "retry request: ", retryErr)
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error())
|
||||
return
|
||||
}
|
||||
requestContext.releaseCredentialInterrupt()
|
||||
response = retryResponse
|
||||
selectedCredential = nextCredential
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
selectedCredential.updateStateFromHeaders(response.Header)
|
||||
|
||||
if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
|
||||
go selectedCredential.pollUsage(s.ctx)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error",
|
||||
"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
|
||||
return
|
||||
}
|
||||
|
||||
// Rewrite response headers for external users
|
||||
if userConfig != nil && userConfig.ExternalCredential != "" {
|
||||
s.rewriteResponseHeadersForExternalUser(response.Header, userConfig)
|
||||
}
|
||||
|
||||
for key, values := range response.Header {
|
||||
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) {
|
||||
w.Header()[key] = values
|
||||
}
|
||||
}
|
||||
w.WriteHeader(response.StatusCode)
|
||||
|
||||
usageTracker := selectedCredential.usageTrackerOrNil()
|
||||
if usageTracker != nil && response.StatusCode == http.StatusOK &&
|
||||
(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) {
|
||||
s.handleResponseWithTracking(ctx, w, response, usageTracker, path, requestModel, username)
|
||||
} else {
|
||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||
if err == nil && mediaType != "text/event-stream" {
|
||||
_, _ = io.Copy(w, response.Body)
|
||||
return
|
||||
}
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
s.logger.ErrorContext(ctx, "streaming not supported")
|
||||
return
|
||||
}
|
||||
buffer := make([]byte, buf.BufferSize)
|
||||
for {
|
||||
n, err := response.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
_, writeError := w.Write(buffer[:n])
|
||||
if writeError != nil {
|
||||
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, path string, requestModel string, username string) {
|
||||
isChatCompletions := path == "/v1/chat/completions"
|
||||
weeklyCycleHint := extractWeeklyCycleHint(response.Header)
|
||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||
isStreaming := err == nil && mediaType == "text/event-stream"
|
||||
if !isStreaming && !isChatCompletions && response.Header.Get("Content-Type") == "" {
|
||||
isStreaming = true
|
||||
}
|
||||
if !isStreaming {
|
||||
bodyBytes, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "read response body: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
var responseModel, serviceTier string
|
||||
var inputTokens, outputTokens, cachedTokens int64
|
||||
|
||||
if isChatCompletions {
|
||||
var chatCompletion openai.ChatCompletion
|
||||
if json.Unmarshal(bodyBytes, &chatCompletion) == nil {
|
||||
responseModel = chatCompletion.Model
|
||||
serviceTier = string(chatCompletion.ServiceTier)
|
||||
inputTokens = chatCompletion.Usage.PromptTokens
|
||||
outputTokens = chatCompletion.Usage.CompletionTokens
|
||||
cachedTokens = chatCompletion.Usage.PromptTokensDetails.CachedTokens
|
||||
}
|
||||
} else {
|
||||
var responsesResponse responses.Response
|
||||
if json.Unmarshal(bodyBytes, &responsesResponse) == nil {
|
||||
responseModel = string(responsesResponse.Model)
|
||||
serviceTier = string(responsesResponse.ServiceTier)
|
||||
inputTokens = responsesResponse.Usage.InputTokens
|
||||
outputTokens = responsesResponse.Usage.OutputTokens
|
||||
cachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
|
||||
if inputTokens > 0 || outputTokens > 0 {
|
||||
if responseModel == "" {
|
||||
responseModel = requestModel
|
||||
}
|
||||
if responseModel != "" {
|
||||
contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens)
|
||||
usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
contextWindow,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
cachedTokens,
|
||||
serviceTier,
|
||||
username,
|
||||
time.Now(),
|
||||
weeklyCycleHint,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
_, _ = writer.Write(bodyBytes)
|
||||
return
|
||||
}
|
||||
|
||||
flusher, ok := writer.(http.Flusher)
|
||||
if !ok {
|
||||
s.logger.ErrorContext(ctx, "streaming not supported")
|
||||
return
|
||||
}
|
||||
|
||||
var inputTokens, outputTokens, cachedTokens int64
|
||||
var responseModel, serviceTier string
|
||||
buffer := make([]byte, buf.BufferSize)
|
||||
var leftover []byte
|
||||
|
||||
for {
|
||||
n, err := response.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
data := append(leftover, buffer[:n]...)
|
||||
lines := bytes.Split(data, []byte("\n"))
|
||||
|
||||
if err == nil {
|
||||
leftover = lines[len(lines)-1]
|
||||
lines = lines[:len(lines)-1]
|
||||
} else {
|
||||
leftover = nil
|
||||
}
|
||||
|
||||
for _, line := range lines {
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if bytes.HasPrefix(line, []byte("data: ")) {
|
||||
eventData := bytes.TrimPrefix(line, []byte("data: "))
|
||||
if bytes.Equal(eventData, []byte("[DONE]")) {
|
||||
continue
|
||||
}
|
||||
|
||||
if isChatCompletions {
|
||||
var chatChunk openai.ChatCompletionChunk
|
||||
if json.Unmarshal(eventData, &chatChunk) == nil {
|
||||
if chatChunk.Model != "" {
|
||||
responseModel = chatChunk.Model
|
||||
}
|
||||
if chatChunk.ServiceTier != "" {
|
||||
serviceTier = string(chatChunk.ServiceTier)
|
||||
}
|
||||
if chatChunk.Usage.PromptTokens > 0 {
|
||||
inputTokens = chatChunk.Usage.PromptTokens
|
||||
cachedTokens = chatChunk.Usage.PromptTokensDetails.CachedTokens
|
||||
}
|
||||
if chatChunk.Usage.CompletionTokens > 0 {
|
||||
outputTokens = chatChunk.Usage.CompletionTokens
|
||||
}
|
||||
}
|
||||
} else {
|
||||
var streamEvent responses.ResponseStreamEventUnion
|
||||
if json.Unmarshal(eventData, &streamEvent) == nil {
|
||||
if streamEvent.Type == "response.completed" {
|
||||
completedEvent := streamEvent.AsResponseCompleted()
|
||||
if string(completedEvent.Response.Model) != "" {
|
||||
responseModel = string(completedEvent.Response.Model)
|
||||
}
|
||||
if completedEvent.Response.ServiceTier != "" {
|
||||
serviceTier = string(completedEvent.Response.ServiceTier)
|
||||
}
|
||||
if completedEvent.Response.Usage.InputTokens > 0 {
|
||||
inputTokens = completedEvent.Response.Usage.InputTokens
|
||||
cachedTokens = completedEvent.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
if completedEvent.Response.Usage.OutputTokens > 0 {
|
||||
outputTokens = completedEvent.Response.Usage.OutputTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_, writeError := writer.Write(buffer[:n])
|
||||
if writeError != nil {
|
||||
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if responseModel == "" {
|
||||
responseModel = requestModel
|
||||
}
|
||||
|
||||
if inputTokens > 0 || outputTokens > 0 {
|
||||
if responseModel != "" {
|
||||
contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens)
|
||||
usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
contextWindow,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
cachedTokens,
|
||||
serviceTier,
|
||||
username,
|
||||
time.Now(),
|
||||
weeklyCycleHint,
|
||||
)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
if len(s.options.Users) == 0 {
|
||||
writeJSONError(w, r, http.StatusForbidden, "authentication_error", "status endpoint requires user authentication")
|
||||
return
|
||||
}
|
||||
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
||||
return
|
||||
}
|
||||
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if clientToken == authHeader {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
||||
return
|
||||
}
|
||||
username, ok := s.userManager.Authenticate(clientToken)
|
||||
if !ok {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
|
||||
return
|
||||
}
|
||||
|
||||
userConfig := s.userConfigMap[username]
|
||||
if userConfig == nil {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "user config not found")
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
|
||||
if err != nil {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
provider.pollIfStale(r.Context())
|
||||
avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]float64{
|
||||
"five_hour_utilization": avgFiveHour,
|
||||
"weekly_utilization": avgWeekly,
|
||||
"plan_weight": totalWeight,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) (float64, float64, float64) {
|
||||
var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64
|
||||
for _, cred := range provider.allCredentials() {
|
||||
if !cred.isAvailable() {
|
||||
continue
|
||||
}
|
||||
if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential {
|
||||
continue
|
||||
}
|
||||
if !userConfig.AllowExternalUsage && cred.isExternal() {
|
||||
continue
|
||||
}
|
||||
weight := cred.planWeight()
|
||||
remaining5h := cred.fiveHourCap() - cred.fiveHourUtilization()
|
||||
if remaining5h < 0 {
|
||||
remaining5h = 0
|
||||
}
|
||||
remainingWeekly := cred.weeklyCap() - cred.weeklyUtilization()
|
||||
if remainingWeekly < 0 {
|
||||
remainingWeekly = 0
|
||||
}
|
||||
totalWeightedRemaining5h += remaining5h * weight
|
||||
totalWeightedRemainingWeekly += remainingWeekly * weight
|
||||
totalWeight += weight
|
||||
}
|
||||
if totalWeight == 0 {
|
||||
return 100, 100, 0
|
||||
}
|
||||
return 100 - totalWeightedRemaining5h/totalWeight,
|
||||
100 - totalWeightedRemainingWeekly/totalWeight,
|
||||
totalWeight
|
||||
}
|
||||
|
||||
func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.OCMUser) {
|
||||
provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, userConfig.Name)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
|
||||
|
||||
activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
|
||||
if activeLimitIdentifier == "" {
|
||||
activeLimitIdentifier = "codex"
|
||||
}
|
||||
|
||||
headers.Set("x-"+activeLimitIdentifier+"-primary-used-percent", strconv.FormatFloat(avgFiveHour, 'f', 2, 64))
|
||||
headers.Set("x-"+activeLimitIdentifier+"-secondary-used-percent", strconv.FormatFloat(avgWeekly, 'f', 2, 64))
|
||||
if totalWeight > 0 {
|
||||
headers.Set("X-OCM-Plan-Weight", strconv.FormatFloat(totalWeight, 'f', -1, 64))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) InterfaceUpdated() {
|
||||
for _, cred := range s.allCredentials {
|
||||
extCred, ok := cred.(*externalCredential)
|
||||
@@ -952,8 +338,8 @@ func (s *Service) Close() error {
|
||||
}
|
||||
|
||||
func (s *Service) registerWebSocketSession(session *webSocketSession) bool {
|
||||
s.webSocketMutex.Lock()
|
||||
defer s.webSocketMutex.Unlock()
|
||||
s.webSocketAccess.Lock()
|
||||
defer s.webSocketAccess.Unlock()
|
||||
|
||||
if s.shuttingDown {
|
||||
return false
|
||||
@@ -965,12 +351,12 @@ func (s *Service) registerWebSocketSession(session *webSocketSession) bool {
|
||||
}
|
||||
|
||||
func (s *Service) unregisterWebSocketSession(session *webSocketSession) {
|
||||
s.webSocketMutex.Lock()
|
||||
s.webSocketAccess.Lock()
|
||||
_, loaded := s.webSocketConns[session]
|
||||
if loaded {
|
||||
delete(s.webSocketConns, session)
|
||||
}
|
||||
s.webSocketMutex.Unlock()
|
||||
s.webSocketAccess.Unlock()
|
||||
|
||||
if loaded {
|
||||
s.webSocketGroup.Done()
|
||||
@@ -978,28 +364,28 @@ func (s *Service) unregisterWebSocketSession(session *webSocketSession) {
|
||||
}
|
||||
|
||||
func (s *Service) isShuttingDown() bool {
|
||||
s.webSocketMutex.Lock()
|
||||
defer s.webSocketMutex.Unlock()
|
||||
s.webSocketAccess.Lock()
|
||||
defer s.webSocketAccess.Unlock()
|
||||
return s.shuttingDown
|
||||
}
|
||||
|
||||
func (s *Service) interruptWebSocketSessionsForCredential(tag string) {
|
||||
s.webSocketMutex.Lock()
|
||||
s.webSocketAccess.Lock()
|
||||
var toClose []*webSocketSession
|
||||
for session := range s.webSocketConns {
|
||||
if session.credentialTag == tag {
|
||||
toClose = append(toClose, session)
|
||||
}
|
||||
}
|
||||
s.webSocketMutex.Unlock()
|
||||
s.webSocketAccess.Unlock()
|
||||
for _, session := range toClose {
|
||||
session.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) startWebSocketShutdown() []*webSocketSession {
|
||||
s.webSocketMutex.Lock()
|
||||
defer s.webSocketMutex.Unlock()
|
||||
s.webSocketAccess.Lock()
|
||||
defer s.webSocketAccess.Unlock()
|
||||
|
||||
s.shuttingDown = true
|
||||
|
||||
|
||||
service/ocm/service_handler.go
Normal file, 504 lines
@@ -0,0 +1,504 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
|
||||
"github.com/openai/openai-go/v3"
|
||||
"github.com/openai/openai-go/v3/responses"
|
||||
)
|
||||
|
||||
func weeklyCycleHintForLimit(headers http.Header, limitIdentifier string) *WeeklyCycleHint {
|
||||
normalizedLimitIdentifier := normalizeRateLimitIdentifier(limitIdentifier)
|
||||
if normalizedLimitIdentifier == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
windowHeader := "x-" + normalizedLimitIdentifier + "-secondary-window-minutes"
|
||||
resetHeader := "x-" + normalizedLimitIdentifier + "-secondary-reset-at"
|
||||
|
||||
windowMinutes, hasWindowMinutes := parseInt64Header(headers, windowHeader)
|
||||
resetAtUnix, hasResetAt := parseInt64Header(headers, resetHeader)
|
||||
if !hasWindowMinutes || !hasResetAt || windowMinutes <= 0 || resetAtUnix <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &WeeklyCycleHint{
|
||||
WindowMinutes: windowMinutes,
|
||||
ResetAt: time.Unix(resetAtUnix, 0).UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint {
|
||||
activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
|
||||
if activeLimitIdentifier != "" {
|
||||
if activeHint := weeklyCycleHintForLimit(headers, activeLimitIdentifier); activeHint != nil {
|
||||
return activeHint
|
||||
}
|
||||
}
|
||||
return weeklyCycleHintForLimit(headers, "codex")
|
||||
}
|
||||
|
||||
func (s *Service) resolveCredentialProvider(username string) (credentialProvider, error) {
|
||||
if len(s.options.Users) > 0 {
|
||||
return credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
|
||||
}
|
||||
provider := noUserCredentialProvider(s.providers, s.legacyProvider, s.options)
|
||||
if provider == nil {
|
||||
return nil, E.New("no credential available")
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := log.ContextWithNewID(r.Context())
|
||||
if r.URL.Path == "/ocm/v1/status" {
|
||||
s.handleStatusEndpoint(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Path == "/ocm/v1/reverse" {
|
||||
s.handleReverseConnect(ctx, w, r)
|
||||
return
|
||||
}
|
||||
|
||||
path := r.URL.Path
|
||||
if !strings.HasPrefix(path, "/v1/") {
|
||||
writeJSONError(w, r, http.StatusNotFound, "invalid_request_error", "path must start with /v1/")
|
||||
return
|
||||
}
|
||||
|
||||
var username string
|
||||
if len(s.options.Users) > 0 {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
||||
return
|
||||
}
|
||||
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if clientToken == authHeader {
|
||||
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
||||
return
|
||||
}
|
||||
var ok bool
|
||||
username, ok = s.userManager.Authenticate(clientToken)
|
||||
if !ok {
|
||||
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
sessionID := r.Header.Get("session_id")
|
||||
|
||||
// Resolve credential provider and user config
|
||||
var provider credentialProvider
|
||||
var userConfig *option.OCMUser
|
||||
if len(s.options.Users) > 0 {
|
||||
userConfig = s.userConfigMap[username]
|
||||
var err error
|
||||
provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "resolve credential: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
} else {
|
||||
provider = noUserCredentialProvider(s.providers, s.legacyProvider, s.options)
|
||||
}
|
||||
if provider == nil {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available")
|
||||
return
|
||||
}
|
||||
|
||||
provider.pollIfStale(s.ctx)
|
||||
|
||||
selection := credentialSelectionForUser(userConfig)
|
||||
|
||||
selectedCredential, isNew, err := provider.selectCredential(sessionID, selection)
|
||||
if err != nil {
|
||||
writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") {
|
||||
s.handleWebSocket(ctx, w, r, path, username, sessionID, userConfig, provider, selectedCredential, selection, isNew)
|
||||
return
|
||||
}
|
||||
|
||||
if !selectedCredential.isExternal() && selectedCredential.ocmIsAPIKeyMode() {
|
||||
// API key mode path handling
|
||||
} else if !selectedCredential.isExternal() {
|
||||
if path == "/v1/chat/completions" {
|
||||
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
|
||||
"chat completions endpoint is only available in API key mode")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
shouldTrackUsage := selectedCredential.usageTrackerOrNil() != nil &&
|
||||
(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses"))
|
||||
canRetryRequest := len(provider.allCredentials()) > 1
|
||||
|
||||
// Read body for model extraction and retry buffer when JSON replay is useful.
|
||||
var bodyBytes []byte
|
||||
var requestModel string
|
||||
var requestServiceTier string
|
||||
if r.Body != nil && (shouldTrackUsage || canRetryRequest) {
|
||||
mediaType, _, parseErr := mime.ParseMediaType(r.Header.Get("Content-Type"))
|
||||
isJSONRequest := parseErr == nil && (mediaType == "application/json" || strings.HasSuffix(mediaType, "+json"))
|
||||
if isJSONRequest {
|
||||
bodyBytes, err = io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "read request body: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body")
|
||||
return
|
||||
}
|
||||
var request struct {
|
||||
Model string `json:"model"`
|
||||
ServiceTier string `json:"service_tier"`
|
||||
}
|
||||
if json.Unmarshal(bodyBytes, &request) == nil {
|
||||
requestModel = request.Model
|
||||
requestServiceTier = request.ServiceTier
|
||||
}
|
||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
}
|
||||
}
|
||||
|
||||
if isNew {
|
||||
logParts := []any{"assigned credential ", selectedCredential.tagName()}
|
||||
if sessionID != "" {
|
||||
logParts = append(logParts, " for session ", sessionID)
|
||||
}
|
||||
if username != "" {
|
||||
logParts = append(logParts, " by user ", username)
|
||||
}
|
||||
if requestModel != "" {
|
||||
logParts = append(logParts, ", model=", requestModel)
|
||||
}
|
||||
if requestServiceTier == "priority" {
|
||||
logParts = append(logParts, ", fast")
|
||||
}
|
||||
s.logger.DebugContext(ctx, logParts...)
|
||||
}
|
||||
|
||||
requestContext := selectedCredential.wrapRequestContext(ctx)
|
||||
{
|
||||
currentRequestContext := requestContext
|
||||
requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() {
|
||||
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
|
||||
}))
|
||||
}
|
||||
defer func() {
|
||||
requestContext.cancelRequest()
|
||||
}()
|
||||
proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "create proxy request: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
|
||||
return
|
||||
}
|
||||
|
||||
response, err := selectedCredential.httpClient().Do(proxyRequest)
|
||||
if err != nil {
|
||||
if r.Context().Err() != nil {
|
||||
return
|
||||
}
|
||||
if requestContext.Err() != nil {
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable while processing the request")
|
||||
return
|
||||
}
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
requestContext.releaseCredentialInterrupt()
|
||||
|
||||
// Transparent 429 retry
|
||||
for response.StatusCode == http.StatusTooManyRequests {
|
||||
resetAt := parseOCMRateLimitResetFromHeaders(response.Header)
|
||||
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection)
|
||||
needsBodyReplay := r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodDelete
|
||||
selectedCredential.updateStateFromHeaders(response.Header)
|
||||
if (needsBodyReplay && bodyBytes == nil) || nextCredential == nil {
|
||||
response.Body.Close()
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited")
|
||||
return
|
||||
}
|
||||
response.Body.Close()
|
||||
s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
|
||||
requestContext.cancelRequest()
|
||||
requestContext = nextCredential.wrapRequestContext(ctx)
|
||||
{
|
||||
currentRequestContext := requestContext
|
||||
requestContext.addInterruptLink(provider.linkProviderInterrupt(nextCredential, selection, func() {
|
||||
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
|
||||
}))
|
||||
}
|
||||
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||
if buildErr != nil {
|
||||
s.logger.ErrorContext(ctx, "retry request: ", buildErr)
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
|
||||
return
|
||||
}
|
||||
retryResponse, retryErr := nextCredential.httpClient().Do(retryRequest)
|
||||
if retryErr != nil {
|
||||
if r.Context().Err() != nil {
|
||||
return
|
||||
}
|
||||
if requestContext.Err() != nil {
|
||||
writeCredentialUnavailableError(w, r, provider, nextCredential, selection, "credential became unavailable while retrying the request")
|
||||
return
|
||||
}
|
||||
s.logger.ErrorContext(ctx, "retry request: ", retryErr)
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error())
|
||||
return
|
||||
}
|
||||
requestContext.releaseCredentialInterrupt()
|
||||
response = retryResponse
|
||||
selectedCredential = nextCredential
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
selectedCredential.updateStateFromHeaders(response.Header)
|
||||
|
||||
if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
|
||||
go selectedCredential.pollUsage(s.ctx)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error",
|
||||
"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
|
||||
return
|
||||
}
|
||||
|
||||
// Rewrite response headers for external users
|
||||
if userConfig != nil && userConfig.ExternalCredential != "" {
|
||||
s.rewriteResponseHeadersForExternalUser(response.Header, userConfig)
|
||||
}
|
||||
|
||||
for key, values := range response.Header {
|
||||
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) {
|
||||
w.Header()[key] = values
|
||||
}
|
||||
}
|
||||
w.WriteHeader(response.StatusCode)
|
||||
|
||||
usageTracker := selectedCredential.usageTrackerOrNil()
|
||||
if usageTracker != nil && response.StatusCode == http.StatusOK &&
|
||||
(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) {
|
||||
s.handleResponseWithTracking(ctx, w, response, usageTracker, path, requestModel, username)
|
||||
} else {
|
||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||
if err == nil && mediaType != "text/event-stream" {
|
||||
_, _ = io.Copy(w, response.Body)
|
||||
return
|
||||
}
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
s.logger.ErrorContext(ctx, "streaming not supported")
|
||||
return
|
||||
}
|
||||
buffer := make([]byte, buf.BufferSize)
|
||||
for {
|
||||
n, err := response.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
_, writeError := w.Write(buffer[:n])
|
||||
if writeError != nil {
|
||||
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, path string, requestModel string, username string) {
|
||||
isChatCompletions := path == "/v1/chat/completions"
|
||||
weeklyCycleHint := extractWeeklyCycleHint(response.Header)
|
||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||
isStreaming := err == nil && mediaType == "text/event-stream"
|
||||
if !isStreaming && !isChatCompletions && response.Header.Get("Content-Type") == "" {
|
||||
isStreaming = true
|
||||
}
|
||||
if !isStreaming {
|
||||
bodyBytes, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "read response body: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
var responseModel, serviceTier string
|
||||
var inputTokens, outputTokens, cachedTokens int64
|
||||
|
||||
if isChatCompletions {
|
||||
var chatCompletion openai.ChatCompletion
|
||||
if json.Unmarshal(bodyBytes, &chatCompletion) == nil {
|
||||
responseModel = chatCompletion.Model
|
||||
serviceTier = string(chatCompletion.ServiceTier)
|
||||
inputTokens = chatCompletion.Usage.PromptTokens
|
||||
outputTokens = chatCompletion.Usage.CompletionTokens
|
||||
cachedTokens = chatCompletion.Usage.PromptTokensDetails.CachedTokens
|
||||
}
|
||||
} else {
|
||||
var responsesResponse responses.Response
|
||||
if json.Unmarshal(bodyBytes, &responsesResponse) == nil {
|
||||
responseModel = string(responsesResponse.Model)
|
||||
serviceTier = string(responsesResponse.ServiceTier)
|
||||
inputTokens = responsesResponse.Usage.InputTokens
|
||||
outputTokens = responsesResponse.Usage.OutputTokens
|
||||
cachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
|
||||
if inputTokens > 0 || outputTokens > 0 {
|
||||
if responseModel == "" {
|
||||
responseModel = requestModel
|
||||
}
|
||||
if responseModel != "" {
|
||||
contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens)
|
||||
usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
contextWindow,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
cachedTokens,
|
||||
serviceTier,
|
||||
username,
|
||||
time.Now(),
|
||||
weeklyCycleHint,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
_, _ = writer.Write(bodyBytes)
|
||||
return
|
||||
}
|
||||
|
||||
flusher, ok := writer.(http.Flusher)
|
||||
if !ok {
|
||||
s.logger.ErrorContext(ctx, "streaming not supported")
|
||||
return
|
||||
}
|
||||
|
||||
var inputTokens, outputTokens, cachedTokens int64
|
||||
var responseModel, serviceTier string
|
||||
buffer := make([]byte, buf.BufferSize)
|
||||
var leftover []byte
|
||||
|
||||
for {
|
||||
n, err := response.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
data := append(leftover, buffer[:n]...)
|
||||
lines := bytes.Split(data, []byte("\n"))
|
||||
|
||||
if err == nil {
|
||||
leftover = lines[len(lines)-1]
|
||||
lines = lines[:len(lines)-1]
|
||||
} else {
|
||||
leftover = nil
|
||||
}
|
||||
|
||||
for _, line := range lines {
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if bytes.HasPrefix(line, []byte("data: ")) {
|
||||
eventData := bytes.TrimPrefix(line, []byte("data: "))
|
||||
if bytes.Equal(eventData, []byte("[DONE]")) {
|
||||
continue
|
||||
}
|
||||
|
||||
if isChatCompletions {
|
||||
var chatChunk openai.ChatCompletionChunk
|
||||
if json.Unmarshal(eventData, &chatChunk) == nil {
|
||||
if chatChunk.Model != "" {
|
||||
responseModel = chatChunk.Model
|
||||
}
|
||||
if chatChunk.ServiceTier != "" {
|
||||
serviceTier = string(chatChunk.ServiceTier)
|
||||
}
|
||||
if chatChunk.Usage.PromptTokens > 0 {
|
||||
inputTokens = chatChunk.Usage.PromptTokens
|
||||
cachedTokens = chatChunk.Usage.PromptTokensDetails.CachedTokens
|
||||
}
|
||||
if chatChunk.Usage.CompletionTokens > 0 {
|
||||
outputTokens = chatChunk.Usage.CompletionTokens
|
||||
}
|
||||
}
|
||||
} else {
|
||||
var streamEvent responses.ResponseStreamEventUnion
|
||||
if json.Unmarshal(eventData, &streamEvent) == nil {
|
||||
if streamEvent.Type == "response.completed" {
|
||||
completedEvent := streamEvent.AsResponseCompleted()
|
||||
if string(completedEvent.Response.Model) != "" {
|
||||
responseModel = string(completedEvent.Response.Model)
|
||||
}
|
||||
if completedEvent.Response.ServiceTier != "" {
|
||||
serviceTier = string(completedEvent.Response.ServiceTier)
|
||||
}
|
||||
if completedEvent.Response.Usage.InputTokens > 0 {
|
||||
inputTokens = completedEvent.Response.Usage.InputTokens
|
||||
cachedTokens = completedEvent.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
if completedEvent.Response.Usage.OutputTokens > 0 {
|
||||
outputTokens = completedEvent.Response.Usage.OutputTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_, writeError := writer.Write(buffer[:n])
|
||||
if writeError != nil {
|
||||
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if responseModel == "" {
|
||||
responseModel = requestModel
|
||||
}
|
||||
|
||||
if inputTokens > 0 || outputTokens > 0 {
|
||||
if responseModel != "" {
|
||||
contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens)
|
||||
usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
contextWindow,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
cachedTokens,
|
||||
serviceTier,
|
||||
username,
|
||||
time.Now(),
|
||||
weeklyCycleHint,
|
||||
)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
service/ocm/service_status.go
Normal file, 114 lines
@@ -0,0 +1,114 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/sagernet/sing-box/option"
|
||||
)
|
||||
|
||||
func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
if len(s.options.Users) == 0 {
|
||||
writeJSONError(w, r, http.StatusForbidden, "authentication_error", "status endpoint requires user authentication")
|
||||
return
|
||||
}
|
||||
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
||||
return
|
||||
}
|
||||
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if clientToken == authHeader {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
||||
return
|
||||
}
|
||||
username, ok := s.userManager.Authenticate(clientToken)
|
||||
if !ok {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
|
||||
return
|
||||
}
|
||||
|
||||
userConfig := s.userConfigMap[username]
|
||||
if userConfig == nil {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "user config not found")
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
|
||||
if err != nil {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
provider.pollIfStale(r.Context())
|
||||
avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]float64{
|
||||
"five_hour_utilization": avgFiveHour,
|
||||
"weekly_utilization": avgWeekly,
|
||||
"plan_weight": totalWeight,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) (float64, float64, float64) {
|
||||
var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64
|
||||
for _, cred := range provider.allCredentials() {
|
||||
if !cred.isAvailable() {
|
||||
continue
|
||||
}
|
||||
if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential {
|
||||
continue
|
||||
}
|
||||
if !userConfig.AllowExternalUsage && cred.isExternal() {
|
||||
continue
|
||||
}
|
||||
weight := cred.planWeight()
|
||||
remaining5h := cred.fiveHourCap() - cred.fiveHourUtilization()
|
||||
if remaining5h < 0 {
|
||||
remaining5h = 0
|
||||
}
|
||||
remainingWeekly := cred.weeklyCap() - cred.weeklyUtilization()
|
||||
if remainingWeekly < 0 {
|
||||
remainingWeekly = 0
|
||||
}
|
||||
totalWeightedRemaining5h += remaining5h * weight
|
||||
totalWeightedRemainingWeekly += remainingWeekly * weight
|
||||
totalWeight += weight
|
||||
}
|
||||
if totalWeight == 0 {
|
||||
return 100, 100, 0
|
||||
}
|
||||
return 100 - totalWeightedRemaining5h/totalWeight,
|
||||
100 - totalWeightedRemainingWeekly/totalWeight,
|
||||
totalWeight
|
||||
}
|
||||
|
||||
func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.OCMUser) {
|
||||
provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, userConfig.Name)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
|
||||
|
||||
activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
|
||||
if activeLimitIdentifier == "" {
|
||||
activeLimitIdentifier = "codex"
|
||||
}
|
||||
|
||||
headers.Set("x-"+activeLimitIdentifier+"-primary-used-percent", strconv.FormatFloat(avgFiveHour, 'f', 2, 64))
|
||||
headers.Set("x-"+activeLimitIdentifier+"-secondary-used-percent", strconv.FormatFloat(avgWeekly, 'f', 2, 64))
|
||||
if totalWeight > 0 {
|
||||
headers.Set("X-OCM-Plan-Weight", strconv.FormatFloat(totalWeight, 'f', -1, 64))
|
||||
}
|
||||
}
|
||||
@@ -7,13 +7,13 @@ import (
|
||||
)
|
||||
|
||||
type UserManager struct {
|
||||
accessMutex sync.RWMutex
|
||||
access sync.RWMutex
|
||||
tokenMap map[string]string
|
||||
}
|
||||
|
||||
func (m *UserManager) UpdateUsers(users []option.OCMUser) {
|
||||
m.accessMutex.Lock()
|
||||
defer m.accessMutex.Unlock()
|
||||
m.access.Lock()
|
||||
defer m.access.Unlock()
|
||||
tokenMap := make(map[string]string, len(users))
|
||||
for _, user := range users {
|
||||
tokenMap[user.Token] = user.Name
|
||||
@@ -22,8 +22,8 @@ func (m *UserManager) UpdateUsers(users []option.OCMUser) {
|
||||
}
|
||||
|
||||
func (m *UserManager) Authenticate(token string) (string, bool) {
|
||||
m.accessMutex.RLock()
|
||||
m.access.RLock()
|
||||
username, found := m.tokenMap[token]
|
||||
m.accessMutex.RUnlock()
|
||||
m.access.RUnlock()
|
||||
return username, found
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user