Files
sing-box/service/ccm/credential_status_test.go
2026-03-17 16:03:35 +08:00

281 lines
9.2 KiB
Go

package ccm
import (
"context"
"errors"
"io"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common/observable"
"github.com/hashicorp/yamux"
)
type roundTripperFunc func(*http.Request) (*http.Response, error)
func (f roundTripperFunc) RoundTrip(request *http.Request) (*http.Response, error) {
return f(request)
}
func drainStatusEvents(subscription observable.Subscription[struct{}]) int {
var count int
for {
select {
case <-subscription:
count++
default:
return count
}
}
}
func newTestLogger() log.ContextLogger {
return log.NewNOPFactory().Logger()
}
func newTestCCMExternalCredential(t *testing.T, body string, headers http.Header) (*externalCredential, observable.Subscription[struct{}]) {
t.Helper()
subscriber := observable.NewSubscriber[struct{}](8)
subscription, _ := subscriber.Subscription()
credential := &externalCredential{
tag: "test",
baseURL: "http://example.com",
token: "token",
pollInterval: 25 * time.Millisecond,
forwardHTTPClient: &http.Client{Transport: roundTripperFunc(func(request *http.Request) (*http.Response, error) {
if request.URL.String() != "http://example.com/ccm/v1/status?watch=true" {
t.Fatalf("unexpected request URL: %s", request.URL.String())
}
return &http.Response{
StatusCode: http.StatusOK,
Header: headers.Clone(),
Body: io.NopCloser(strings.NewReader(body)),
}, nil
})},
logger: newTestLogger(),
statusSubscriber: subscriber,
}
return credential, subscription
}
func newTestYamuxSessionPair(t *testing.T) (*yamux.Session, *yamux.Session) {
t.Helper()
clientConn, serverConn := net.Pipe()
clientSession, err := yamux.Client(clientConn, defaultYamuxConfig)
if err != nil {
t.Fatalf("create yamux client: %v", err)
}
serverSession, err := yamux.Server(serverConn, defaultYamuxConfig)
if err != nil {
clientSession.Close()
t.Fatalf("create yamux server: %v", err)
}
t.Cleanup(func() {
clientSession.Close()
serverSession.Close()
})
return clientSession, serverSession
}
func TestExternalCredentialConnectStatusStreamOneShotRestoresLastUpdated(t *testing.T) {
credential, subscription := newTestCCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", nil)
oldTime := time.Unix(123, 0)
credential.stateAccess.Lock()
credential.state.lastUpdated = oldTime
credential.stateAccess.Unlock()
result, err := credential.connectStatusStream(context.Background())
if err != io.EOF {
t.Fatalf("expected EOF, got %v", err)
}
if !result.oneShot {
t.Fatal("expected one-shot result")
}
if result.frames != 1 {
t.Fatalf("expected 1 frame, got %d", result.frames)
}
if !credential.lastUpdatedTime().Equal(oldTime) {
t.Fatalf("expected lastUpdated restored to %v, got %v", oldTime, credential.lastUpdatedTime())
}
if credential.fiveHourUtilization() != 12 || credential.weeklyUtilization() != 34 {
t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization())
}
if count := drainStatusEvents(subscription); count != 1 {
t.Fatalf("expected 1 status event, got %d", count)
}
failures, backoff, oneShot := credential.nextStatusStreamBackoff(result, 3)
if !oneShot {
t.Fatal("expected one-shot backoff branch")
}
if failures != 0 {
t.Fatalf("expected failures reset, got %d", failures)
}
if backoff != credential.pollInterval {
t.Fatalf("expected poll interval backoff %v, got %v", credential.pollInterval, backoff)
}
}
func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *testing.T) {
headers := make(http.Header)
headers.Set(statusStreamHeader, "true")
credential, subscription := newTestCCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", headers)
oldTime := time.Unix(123, 0)
credential.stateAccess.Lock()
credential.state.lastUpdated = oldTime
credential.stateAccess.Unlock()
result, err := credential.connectStatusStream(context.Background())
if err != io.EOF {
t.Fatalf("expected EOF, got %v", err)
}
if result.oneShot {
t.Fatal("did not expect one-shot result")
}
if result.frames != 1 {
t.Fatalf("expected 1 frame, got %d", result.frames)
}
if credential.lastUpdatedTime().Equal(oldTime) {
t.Fatal("expected lastUpdated to remain refreshed")
}
if credential.fiveHourUtilization() != 12 || credential.weeklyUtilization() != 34 {
t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization())
}
if count := drainStatusEvents(subscription); count != 1 {
t.Fatalf("expected 1 status event, got %d", count)
}
failures, backoff, oneShot := credential.nextStatusStreamBackoff(result, 3)
if oneShot {
t.Fatal("did not expect one-shot backoff branch")
}
if failures != 4 {
t.Fatalf("expected failures incremented to 4, got %d", failures)
}
if backoff < 16*time.Second || backoff >= 24*time.Second {
t.Fatalf("expected connector backoff in [16s, 24s), got %v", backoff)
}
}
func TestExternalCredentialConnectStatusStreamMultiFrameKeepsLastUpdated(t *testing.T) {
credential, subscription := newTestCCMExternalCredential(t, strings.Join([]string{
"{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}",
"{\"five_hour_utilization\":13,\"weekly_utilization\":35,\"plan_weight\":3}",
}, "\n"), nil)
oldTime := time.Unix(123, 0)
credential.stateAccess.Lock()
credential.state.lastUpdated = oldTime
credential.stateAccess.Unlock()
result, err := credential.connectStatusStream(context.Background())
if err != io.EOF {
t.Fatalf("expected EOF, got %v", err)
}
if result.oneShot {
t.Fatal("did not expect one-shot result")
}
if result.frames != 2 {
t.Fatalf("expected 2 frames, got %d", result.frames)
}
if credential.lastUpdatedTime().Equal(oldTime) {
t.Fatal("expected lastUpdated to remain refreshed")
}
if credential.fiveHourUtilization() != 13 || credential.weeklyUtilization() != 35 {
t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization())
}
if count := drainStatusEvents(subscription); count != 2 {
t.Fatalf("expected 2 status events, got %d", count)
}
}
func TestDefaultCredentialStatusChangesEmitStatus(t *testing.T) {
credentialPath := filepath.Join(t.TempDir(), "credentials.json")
err := os.WriteFile(credentialPath, []byte("{\"claudeAiOauth\":{\"accessToken\":\"token\",\"refreshToken\":\"\",\"expiresAt\":0,\"subscriptionType\":\"max\"}}\n"), 0o600)
if err != nil {
t.Fatalf("write credential file: %v", err)
}
subscriber := observable.NewSubscriber[struct{}](8)
subscription, _ := subscriber.Subscription()
credential := &defaultCredential{
tag: "test",
credentialPath: credentialPath,
logger: newTestLogger(),
statusSubscriber: subscriber,
}
err = credential.markCredentialsUnavailable(errors.New("boom"))
if err == nil {
t.Fatal("expected error from markCredentialsUnavailable")
}
if count := drainStatusEvents(subscription); count != 1 {
t.Fatalf("expected 1 status event after unavailable transition, got %d", count)
}
err = credential.reloadCredentials(true)
if err != nil {
t.Fatalf("reload credentials: %v", err)
}
if count := drainStatusEvents(subscription); count != 1 {
t.Fatalf("expected 1 status event after recovery, got %d", count)
}
if weight := credential.planWeight(); weight != 5 {
t.Fatalf("expected initial max weight 5, got %v", weight)
}
profileClient := &http.Client{Transport: roundTripperFunc(func(request *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(
"{\"organization\":{\"organization_type\":\"claude_max\",\"rate_limit_tier\":\"default_claude_max_20x\"}}",
)),
}, nil
})}
credential.fetchProfile(context.Background(), profileClient, "token")
if count := drainStatusEvents(subscription); count != 1 {
t.Fatalf("expected 1 status event after weight change, got %d", count)
}
if weight := credential.planWeight(); weight != 10 {
t.Fatalf("expected upgraded max weight 10, got %v", weight)
}
}
func TestExternalCredentialReverseSessionChangesEmitStatus(t *testing.T) {
subscriber := observable.NewSubscriber[struct{}](8)
subscription, _ := subscriber.Subscription()
credential := &externalCredential{
tag: "receiver",
baseURL: reverseProxyBaseURL,
pollInterval: time.Minute,
logger: newTestLogger(),
statusSubscriber: subscriber,
}
clientSession, _ := newTestYamuxSessionPair(t)
if !credential.setReverseSession(clientSession) {
t.Fatal("expected reverse session to be accepted")
}
if count := drainStatusEvents(subscription); count != 1 {
t.Fatalf("expected 1 status event after reverse session up, got %d", count)
}
if !credential.isAvailable() {
t.Fatal("expected receiver credential to become available")
}
credential.clearReverseSession(clientSession)
if count := drainStatusEvents(subscription); count != 1 {
t.Fatalf("expected 1 status event after reverse session down, got %d", count)
}
if credential.isAvailable() {
t.Fatal("expected receiver credential to become unavailable")
}
}