mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-14 20:58:33 +10:00
206 lines
6.5 KiB
Go
206 lines
6.5 KiB
Go
package ccm
|
|
|
|
import (
|
|
"errors"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestGetAccessTokenReturnsExistingTokenWhenLockFails(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
directory := t.TempDir()
|
|
credentialPath := filepath.Join(directory, ".credentials.json")
|
|
writeTestCredentials(t, credentialPath, &oauthCredentials{
|
|
AccessToken: "old-token",
|
|
RefreshToken: "refresh-token",
|
|
ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(),
|
|
Scopes: []string{"user:profile", "user:inference"},
|
|
SubscriptionType: optionalStringPointer("max"),
|
|
RateLimitTier: optionalStringPointer("default_claude_max_20x"),
|
|
})
|
|
|
|
credential := newTestDefaultCredential(t, credentialPath, roundTripFunc(func(request *http.Request) (*http.Response, error) {
|
|
t.Fatal("refresh should not be attempted when lock acquisition fails")
|
|
return nil, nil
|
|
}))
|
|
if err := credential.reloadCredentials(true); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
originalLockFunc := acquireCredentialLockFunc
|
|
acquireCredentialLockFunc = func(string) (func(), error) {
|
|
return nil, errors.New("locked")
|
|
}
|
|
t.Cleanup(func() {
|
|
acquireCredentialLockFunc = originalLockFunc
|
|
})
|
|
|
|
token, err := credential.getAccessToken()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if token != "old-token" {
|
|
t.Fatalf("expected old token, got %q", token)
|
|
}
|
|
}
|
|
|
|
func TestGetAccessTokenAbsorbsRefreshDoneByAnotherProcess(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
directory := t.TempDir()
|
|
credentialPath := filepath.Join(directory, ".credentials.json")
|
|
oldCredentials := &oauthCredentials{
|
|
AccessToken: "old-token",
|
|
RefreshToken: "refresh-token",
|
|
ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(),
|
|
Scopes: []string{"user:profile", "user:inference"},
|
|
SubscriptionType: optionalStringPointer("max"),
|
|
RateLimitTier: optionalStringPointer("default_claude_max_20x"),
|
|
}
|
|
writeTestCredentials(t, credentialPath, oldCredentials)
|
|
|
|
newCredentials := cloneCredentials(oldCredentials)
|
|
newCredentials.AccessToken = "new-token"
|
|
newCredentials.ExpiresAt = time.Now().Add(time.Hour).UnixMilli()
|
|
transport := roundTripFunc(func(request *http.Request) (*http.Response, error) {
|
|
if request.URL.Path == "/v1/oauth/token" {
|
|
writeTestCredentials(t, credentialPath, newCredentials)
|
|
return newJSONResponse(http.StatusInternalServerError, `{"error":"boom"}`), nil
|
|
}
|
|
t.Fatalf("unexpected path %s", request.URL.Path)
|
|
return nil, nil
|
|
})
|
|
|
|
credential := newTestDefaultCredential(t, credentialPath, transport)
|
|
if err := credential.reloadCredentials(true); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
token, err := credential.getAccessToken()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if token != "new-token" {
|
|
t.Fatalf("expected refreshed token from disk, got %q", token)
|
|
}
|
|
}
|
|
|
|
func TestCustomCredentialPathDoesNotEnableClaudeConfigSync(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
directory := t.TempDir()
|
|
credentialPath := filepath.Join(directory, ".credentials.json")
|
|
writeTestCredentials(t, credentialPath, &oauthCredentials{
|
|
AccessToken: "token",
|
|
ExpiresAt: time.Now().Add(time.Hour).UnixMilli(),
|
|
Scopes: []string{"user:profile"},
|
|
})
|
|
|
|
credential := newTestDefaultCredential(t, credentialPath, roundTripFunc(func(request *http.Request) (*http.Response, error) {
|
|
t.Fatalf("unexpected request to %s", request.URL.Path)
|
|
return nil, nil
|
|
}))
|
|
if err := credential.reloadCredentials(true); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
token, err := credential.getAccessToken()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if token != "token" {
|
|
t.Fatalf("expected token, got %q", token)
|
|
}
|
|
if credential.shouldUseClaudeConfig() {
|
|
t.Fatal("custom credential path should not enable Claude config sync")
|
|
}
|
|
if _, err := os.Stat(filepath.Join(directory, ".claude.json")); !os.IsNotExist(err) {
|
|
t.Fatalf("did not expect config file to be created, stat err=%v", err)
|
|
}
|
|
}
|
|
|
|
func TestDefaultCredentialHydratesProfileAndWritesConfig(t *testing.T) {
|
|
configDir := t.TempDir()
|
|
credentialPath := filepath.Join(configDir, ".credentials.json")
|
|
|
|
writeTestCredentials(t, credentialPath, &oauthCredentials{
|
|
AccessToken: "old-token",
|
|
RefreshToken: "refresh-token",
|
|
ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(),
|
|
Scopes: []string{"user:profile", "user:inference"},
|
|
})
|
|
|
|
transport := roundTripFunc(func(request *http.Request) (*http.Response, error) {
|
|
switch request.URL.Path {
|
|
case "/v1/oauth/token":
|
|
return newJSONResponse(http.StatusOK, `{
|
|
"access_token":"new-token",
|
|
"refresh_token":"new-refresh",
|
|
"expires_in":3600,
|
|
"account":{"uuid":"account","email_address":"user@example.com"},
|
|
"organization":{"uuid":"org"}
|
|
}`), nil
|
|
case "/api/oauth/profile":
|
|
return newJSONResponse(http.StatusOK, `{
|
|
"account":{
|
|
"uuid":"account",
|
|
"email":"user@example.com",
|
|
"display_name":"User",
|
|
"created_at":"2024-01-01T00:00:00Z"
|
|
},
|
|
"organization":{
|
|
"uuid":"org",
|
|
"organization_type":"claude_max",
|
|
"rate_limit_tier":"default_claude_max_20x",
|
|
"has_extra_usage_enabled":true,
|
|
"billing_type":"individual",
|
|
"subscription_created_at":"2024-01-02T00:00:00Z"
|
|
}
|
|
}`), nil
|
|
default:
|
|
t.Fatalf("unexpected path %s", request.URL.Path)
|
|
return nil, nil
|
|
}
|
|
})
|
|
|
|
credential := newTestDefaultCredential(t, credentialPath, transport)
|
|
credential.syncClaudeConfig = true
|
|
credential.claudeDirectory = configDir
|
|
credential.claudeConfigPath = resolveClaudeConfigWritePath(configDir)
|
|
if err := credential.reloadCredentials(true); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
token, err := credential.getAccessToken()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if token != "new-token" {
|
|
t.Fatalf("expected refreshed token, got %q", token)
|
|
}
|
|
|
|
updatedCredentials := readTestCredentials(t, credentialPath)
|
|
if updatedCredentials.SubscriptionType == nil || *updatedCredentials.SubscriptionType != "max" {
|
|
t.Fatalf("expected subscription type to be persisted, got %#v", updatedCredentials.SubscriptionType)
|
|
}
|
|
if updatedCredentials.RateLimitTier == nil || *updatedCredentials.RateLimitTier != "default_claude_max_20x" {
|
|
t.Fatalf("expected rate limit tier to be persisted, got %#v", updatedCredentials.RateLimitTier)
|
|
}
|
|
|
|
configPath := tempConfigPath(t, configDir)
|
|
config, err := readClaudeCodeConfig(configPath)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if config.OAuthAccount == nil || config.OAuthAccount.AccountUUID != "account" || config.OAuthAccount.EmailAddress != "user@example.com" {
|
|
t.Fatalf("unexpected oauth account: %#v", config.OAuthAccount)
|
|
}
|
|
if config.OAuthAccount.BillingType == nil || *config.OAuthAccount.BillingType != "individual" {
|
|
t.Fatalf("expected billing type to be hydrated, got %#v", config.OAuthAccount.BillingType)
|
|
}
|
|
}
|