mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-13 20:28:32 +10:00
225 lines
6.0 KiB
Go
225 lines
6.0 KiB
Go
package ccm
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"os/user"
|
|
"path/filepath"
|
|
"runtime"
|
|
"slices"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/sagernet/sing-box/log"
|
|
E "github.com/sagernet/sing/common/exceptions"
|
|
)
|
|
|
|
// OAuth / API parameters used by this package when acting on behalf of a
// local Claude Code installation.
const (
	// oauth2ClientID is sent as "client_id" in the refresh_token grant.
	oauth2ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
	// oauth2TokenURL is the endpoint refreshToken POSTs the grant to.
	oauth2TokenURL = "https://platform.claude.com/v1/oauth/token"
	// claudeAPIBaseURL is presumably the upstream base for proxied API calls
	// (not referenced in this chunk — confirm against callers).
	claudeAPIBaseURL = "https://api.anthropic.com"
	// tokenRefreshBufferMs is how long (60 s, in milliseconds) before ExpiresAt
	// needsRefresh starts reporting a token as due. Kept untyped so it mixes
	// with the int64 millisecond timestamps.
	tokenRefreshBufferMs = 60_000
	// anthropicBetaOAuthValue is an anthropic-beta header value; its use is
	// outside this chunk.
	anthropicBetaOAuthValue = "oauth-2025-04-20"

	// ccmUserAgentFallback is the User-Agent used when the installed Claude
	// Code version cannot be detected.
	ccmUserAgentFallback = "claude-code/2.1.72"
)

// ccmUserAgentOnce guards the one-time computation of ccmUserAgentValue.
var ccmUserAgentOnce sync.Once

// ccmUserAgentValue is the User-Agent sent on outgoing OAuth requests; it is
// populated by initCCMUserAgent and is empty until then.
var ccmUserAgentValue string
|
|
|
|
func initCCMUserAgent(logger log.ContextLogger) {
|
|
ccmUserAgentOnce.Do(func() {
|
|
version, err := detectClaudeCodeVersion()
|
|
if err != nil {
|
|
logger.Error("detect Claude Code version: ", err)
|
|
ccmUserAgentValue = ccmUserAgentFallback
|
|
return
|
|
}
|
|
logger.Debug("detected Claude Code version: ", version)
|
|
ccmUserAgentValue = "claude-code/" + version
|
|
})
|
|
}
|
|
|
|
func detectClaudeCodeVersion() (string, error) {
|
|
userInfo, err := getRealUser()
|
|
if err != nil {
|
|
return "", E.Cause(err, "get user")
|
|
}
|
|
binaryName := "claude"
|
|
if runtime.GOOS == "windows" {
|
|
binaryName = "claude.exe"
|
|
}
|
|
linkPath := filepath.Join(userInfo.HomeDir, ".local", "bin", binaryName)
|
|
target, err := os.Readlink(linkPath)
|
|
if err != nil {
|
|
return "", E.Cause(err, "readlink ", linkPath)
|
|
}
|
|
if !filepath.IsAbs(target) {
|
|
target = filepath.Join(filepath.Dir(linkPath), target)
|
|
}
|
|
parent := filepath.Base(filepath.Dir(target))
|
|
if parent != "versions" {
|
|
return "", E.New("unexpected symlink target: ", target)
|
|
}
|
|
return filepath.Base(target), nil
|
|
}
|
|
|
|
// getRealUser returns the account that should own per-user paths.
//
// When SUDO_USER is set and names a resolvable account, that account is
// returned, so running under sudo still targets the invoking user's home. In
// every other case, including a failed lookup, the result of user.Current
// is returned.
func getRealUser() (*user.User, error) {
	sudoUser := os.Getenv("SUDO_USER")
	if sudoUser == "" {
		return user.Current()
	}
	if invoker, lookupErr := user.Lookup(sudoUser); lookupErr == nil {
		return invoker, nil
	}
	return user.Current()
}
|
|
|
|
func getDefaultCredentialsPath() (string, error) {
|
|
if configDir := os.Getenv("CLAUDE_CONFIG_DIR"); configDir != "" {
|
|
return filepath.Join(configDir, ".credentials.json"), nil
|
|
}
|
|
userInfo, err := getRealUser()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return filepath.Join(userInfo.HomeDir, ".claude", ".credentials.json"), nil
|
|
}
|
|
|
|
func readCredentialsFromFile(path string) (*oauthCredentials, error) {
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var credentialsContainer struct {
|
|
ClaudeAIAuth *oauthCredentials `json:"claudeAiOauth,omitempty"`
|
|
}
|
|
err = json.Unmarshal(data, &credentialsContainer)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if credentialsContainer.ClaudeAIAuth == nil {
|
|
return nil, E.New("claudeAiOauth field not found in credentials")
|
|
}
|
|
return credentialsContainer.ClaudeAIAuth, nil
|
|
}
|
|
|
|
// checkCredentialFileWritable reports whether path can be opened for writing.
// The file is opened write-only, without O_CREATE or O_TRUNC, and closed
// immediately. A missing file is therefore an error and existing contents
// are never modified. The open error, or the close error, is returned.
func checkCredentialFileWritable(path string) error {
	handle, openErr := os.OpenFile(path, os.O_WRONLY, 0)
	if openErr == nil {
		return handle.Close()
	}
	return openErr
}
|
|
|
|
func writeCredentialsToFile(oauthCredentials *oauthCredentials, path string) error {
|
|
data, err := json.MarshalIndent(map[string]any{
|
|
"claudeAiOauth": oauthCredentials,
|
|
}, "", " ")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return os.WriteFile(path, data, 0o600)
|
|
}
|
|
|
|
// oauthCredentials mirrors the "claudeAiOauth" object of the Claude
// credentials file (see readCredentialsFromFile / writeCredentialsToFile).
// Field order determines the key order written back to that file.
type oauthCredentials struct {
	// AccessToken is replaced with the server's access_token by refreshToken.
	AccessToken string `json:"accessToken"`
	// RefreshToken is sent as refresh_token in the refresh grant; refreshToken
	// replaces it only when the server returns a new one.
	RefreshToken string `json:"refreshToken"`
	// ExpiresAt is the access-token expiry as Unix milliseconds (compared with
	// time.Now().UnixMilli() in needsRefresh). Zero means needsRefresh never
	// reports the token as due.
	ExpiresAt int64 `json:"expiresAt"`
	// The remaining fields are not interpreted in this file; they are carried
	// through refreshes and compared by credentialsEqual.
	Scopes           []string `json:"scopes,omitempty"`
	SubscriptionType string   `json:"subscriptionType,omitempty"`
	RateLimitTier    string   `json:"rateLimitTier,omitempty"`
	IsMax            bool     `json:"isMax,omitempty"`
}
|
|
|
|
func (c *oauthCredentials) needsRefresh() bool {
|
|
if c.ExpiresAt == 0 {
|
|
return false
|
|
}
|
|
return time.Now().UnixMilli() >= c.ExpiresAt-tokenRefreshBufferMs
|
|
}
|
|
|
|
func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) {
|
|
if credentials.RefreshToken == "" {
|
|
return nil, E.New("refresh token is empty")
|
|
}
|
|
|
|
requestBody, err := json.Marshal(map[string]string{
|
|
"grant_type": "refresh_token",
|
|
"refresh_token": credentials.RefreshToken,
|
|
"client_id": oauth2ClientID,
|
|
})
|
|
if err != nil {
|
|
return nil, E.Cause(err, "marshal request")
|
|
}
|
|
|
|
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
|
|
request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
request.Header.Set("Content-Type", "application/json")
|
|
request.Header.Set("User-Agent", ccmUserAgentValue)
|
|
return request, nil
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer response.Body.Close()
|
|
|
|
if response.StatusCode == http.StatusTooManyRequests {
|
|
body, _ := io.ReadAll(response.Body)
|
|
return nil, E.New("refresh rate limited: ", response.Status, " ", string(body))
|
|
}
|
|
if response.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(response.Body)
|
|
return nil, E.New("refresh failed: ", response.Status, " ", string(body))
|
|
}
|
|
|
|
var tokenResponse struct {
|
|
AccessToken string `json:"access_token"`
|
|
RefreshToken string `json:"refresh_token"`
|
|
ExpiresIn int `json:"expires_in"`
|
|
}
|
|
err = json.NewDecoder(response.Body).Decode(&tokenResponse)
|
|
if err != nil {
|
|
return nil, E.Cause(err, "decode response")
|
|
}
|
|
|
|
newCredentials := *credentials
|
|
newCredentials.AccessToken = tokenResponse.AccessToken
|
|
if tokenResponse.RefreshToken != "" {
|
|
newCredentials.RefreshToken = tokenResponse.RefreshToken
|
|
}
|
|
newCredentials.ExpiresAt = time.Now().UnixMilli() + int64(tokenResponse.ExpiresIn)*1000
|
|
|
|
return &newCredentials, nil
|
|
}
|
|
|
|
func cloneCredentials(credentials *oauthCredentials) *oauthCredentials {
|
|
if credentials == nil {
|
|
return nil
|
|
}
|
|
cloned := *credentials
|
|
cloned.Scopes = append([]string(nil), credentials.Scopes...)
|
|
return &cloned
|
|
}
|
|
|
|
func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool {
|
|
if left == nil || right == nil {
|
|
return left == right
|
|
}
|
|
return left.AccessToken == right.AccessToken &&
|
|
left.RefreshToken == right.RefreshToken &&
|
|
left.ExpiresAt == right.ExpiresAt &&
|
|
slices.Equal(left.Scopes, right.Scopes) &&
|
|
left.SubscriptionType == right.SubscriptionType &&
|
|
left.RateLimitTier == right.RateLimitTier &&
|
|
left.IsMax == right.IsMax
|
|
}
|