Add stream watch endpoint

This commit is contained in:
世界
2026-03-17 16:03:35 +08:00
parent f3c3022094
commit f84832a369
16 changed files with 1225 additions and 24 deletions

View File

@@ -6,6 +6,8 @@ import (
"strconv"
"sync"
"time"
"github.com/sagernet/sing/common/observable"
)
const (
@@ -115,6 +117,7 @@ type Credential interface {
wrapRequestContext(ctx context.Context) *credentialRequestContext
interruptConnections()
setStatusSubscriber(*observable.Subscriber[struct{}])
start() error
pollUsage(ctx context.Context)
lastUpdatedTime() time.Time

View File

@@ -161,4 +161,3 @@ func credentialForUser(
}
return provider, nil
}

View File

@@ -21,6 +21,7 @@ import (
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/ntp"
"github.com/sagernet/sing/common/observable"
)
type defaultCredential struct {
@@ -43,11 +44,13 @@ type defaultCredential struct {
watcher *fswatch.Watcher
watcherRetryAt time.Time
statusSubscriber *observable.Subscriber[struct{}]
// Connection interruption
interrupted bool
requestContext context.Context
cancelRequests context.CancelFunc
requestAccess sync.Mutex
requestContext context.Context
cancelRequests context.CancelFunc
requestAccess sync.Mutex
}
func newDefaultCredential(ctx context.Context, tag string, options option.CCMDefaultCredentialOptions, logger log.ContextLogger) (*defaultCredential, error) {
@@ -139,6 +142,23 @@ func (c *defaultCredential) start() error {
return nil
}
func (c *defaultCredential) setStatusSubscriber(subscriber *observable.Subscriber[struct{}]) {
c.statusSubscriber = subscriber
}
func (c *defaultCredential) emitStatusUpdate() {
if c.statusSubscriber != nil {
c.statusSubscriber.Emit(struct{}{})
}
}
func (c *defaultCredential) statusAggregateStateLocked() (bool, float64) {
if c.state.unavailable {
return false, 0
}
return true, ccmPlanWeight(c.state.accountType, c.state.rateLimitTier)
}
func (c *defaultCredential) getAccessToken() (string, error) {
c.retryCredentialReloadIfNeeded()
@@ -186,13 +206,19 @@ func (c *defaultCredential) getAccessToken() (string, error) {
if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) {
c.credentials = latestCredentials
c.stateAccess.Lock()
wasAvailable, oldWeight := c.statusAggregateStateLocked()
c.state.unavailable = false
c.state.lastCredentialLoadAttempt = time.Now()
c.state.lastCredentialLoadError = ""
c.state.accountType = latestCredentials.SubscriptionType
c.state.rateLimitTier = latestCredentials.RateLimitTier
c.checkTransitionLocked()
isAvailable, newWeight := c.statusAggregateStateLocked()
shouldEmit := wasAvailable != isAvailable || oldWeight != newWeight
c.stateAccess.Unlock()
if shouldEmit {
c.emitStatusUpdate()
}
if !latestCredentials.needsRefresh() {
return latestCredentials.AccessToken, nil
}
@@ -201,13 +227,19 @@ func (c *defaultCredential) getAccessToken() (string, error) {
c.credentials = newCredentials
c.stateAccess.Lock()
wasAvailable, oldWeight := c.statusAggregateStateLocked()
c.state.unavailable = false
c.state.lastCredentialLoadAttempt = time.Now()
c.state.lastCredentialLoadError = ""
c.state.accountType = newCredentials.SubscriptionType
c.state.rateLimitTier = newCredentials.RateLimitTier
c.checkTransitionLocked()
isAvailable, newWeight := c.statusAggregateStateLocked()
shouldEmit := wasAvailable != isAvailable || oldWeight != newWeight
c.stateAccess.Unlock()
if shouldEmit {
c.emitStatusUpdate()
}
err = platformWriteCredentials(newCredentials, c.credentialPath)
if err != nil {
@@ -277,6 +309,9 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) {
if shouldInterrupt {
c.interruptConnections()
}
if hadData {
c.emitStatusUpdate()
}
}
func (c *defaultCredential) markRateLimited(resetAt time.Time) {
@@ -289,6 +324,7 @@ func (c *defaultCredential) markRateLimited(resetAt time.Time) {
if shouldInterrupt {
c.interruptConnections()
}
c.emitStatusUpdate()
}
func (c *defaultCredential) isUsable() bool {
@@ -584,6 +620,7 @@ func (c *defaultCredential) pollUsage(ctx context.Context) {
if shouldInterrupt {
c.interruptConnections()
}
c.emitStatusUpdate()
if needsProfileFetch {
c.fetchProfile(ctx, httpClient, accessToken)
@@ -636,6 +673,7 @@ func (c *defaultCredential) fetchProfile(ctx context.Context, httpClient *http.C
rateLimitTier := profileResponse.Organization.RateLimitTier
c.stateAccess.Lock()
wasAvailable, oldWeight := c.statusAggregateStateLocked()
if accountType != "" && c.state.accountType == "" {
c.state.accountType = accountType
}
@@ -643,7 +681,12 @@ func (c *defaultCredential) fetchProfile(ctx context.Context, httpClient *http.C
c.state.rateLimitTier = rateLimitTier
}
resolvedAccountType := c.state.accountType
isAvailable, newWeight := c.statusAggregateStateLocked()
shouldEmit := wasAvailable != isAvailable || oldWeight != newWeight
c.stateAccess.Unlock()
if shouldEmit {
c.emitStatusUpdate()
}
c.logger.Info("fetched profile for ", c.tag, ": type=", resolvedAccountType, ", tier=", rateLimitTier, ", weight=", ccmPlanWeight(resolvedAccountType, rateLimitTier))
}

View File

@@ -22,11 +22,15 @@ import (
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp"
"github.com/sagernet/sing/common/observable"
"github.com/hashicorp/yamux"
)
const reverseProxyBaseURL = "http://reverse-proxy"
const (
reverseProxyBaseURL = "http://reverse-proxy"
statusStreamHeader = "X-CCM-Status-Stream"
)
type externalCredential struct {
tag string
@@ -40,10 +44,12 @@ type externalCredential struct {
usageTracker *AggregatedUsage
logger log.ContextLogger
statusSubscriber *observable.Subscriber[struct{}]
interrupted bool
requestContext context.Context
cancelRequests context.CancelFunc
requestAccess sync.Mutex
requestContext context.Context
cancelRequests context.CancelFunc
requestAccess sync.Mutex
// Reverse proxy fields
reverse bool
@@ -61,6 +67,12 @@ type externalCredential struct {
reverseService http.Handler
}
type statusStreamResult struct {
duration time.Duration
frames int
oneShot bool
}
func externalCredentialURLPort(parsedURL *url.URL) uint16 {
portString := parsedURL.Port()
if portString != "" {
@@ -218,6 +230,16 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx
return credential, nil
}
func (c *externalCredential) setStatusSubscriber(subscriber *observable.Subscriber[struct{}]) {
c.statusSubscriber = subscriber
}
func (c *externalCredential) emitStatusUpdate() {
if c.statusSubscriber != nil {
c.statusSubscriber.Emit(struct{}{})
}
}
func (c *externalCredential) start() error {
if c.usageTracker != nil {
err := c.usageTracker.Load()
@@ -227,6 +249,8 @@ func (c *externalCredential) start() error {
}
if c.reverse && c.connectorURL != nil {
go c.connectorLoop()
} else {
go c.statusStreamLoop()
}
return nil
}
@@ -317,6 +341,7 @@ func (c *externalCredential) markRateLimited(resetAt time.Time) {
if shouldInterrupt {
c.interruptConnections()
}
c.emitStatusUpdate()
}
func (c *externalCredential) earliestReset() time.Time {
@@ -458,6 +483,9 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) {
if shouldInterrupt {
c.interruptConnections()
}
if hadData {
c.emitStatusUpdate()
}
}
func (c *externalCredential) checkTransitionLocked() bool {
@@ -595,6 +623,142 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
if shouldInterrupt {
c.interruptConnections()
}
c.emitStatusUpdate()
}
func (c *externalCredential) statusStreamLoop() {
var consecutiveFailures int
ctx := c.getReverseContext()
for {
select {
case <-ctx.Done():
return
default:
}
result, err := c.connectStatusStream(ctx)
if ctx.Err() != nil {
return
}
var backoff time.Duration
var oneShot bool
consecutiveFailures, backoff, oneShot = c.nextStatusStreamBackoff(result, consecutiveFailures)
if oneShot {
c.logger.Debug("status stream for ", c.tag, " returned a single-frame response, retrying in ", backoff)
} else {
c.logger.Debug("status stream for ", c.tag, " disconnected: ", err, ", reconnecting in ", backoff)
}
timer := time.NewTimer(backoff)
select {
case <-timer.C:
case <-ctx.Done():
timer.Stop()
return
}
}
}
func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStreamResult, error) {
startTime := time.Now()
result := statusStreamResult{}
response, err := c.doStreamStatusRequest(ctx)
if err != nil {
result.duration = time.Since(startTime)
return result, err
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
body, _ := io.ReadAll(response.Body)
result.duration = time.Since(startTime)
return result, E.New("status ", response.StatusCode, " ", string(body))
}
decoder := json.NewDecoder(response.Body)
isStatusStream := response.Header.Get(statusStreamHeader) == "true"
previousLastUpdated := c.lastUpdatedTime()
var firstFrameUpdatedAt time.Time
for {
var statusResponse struct {
FiveHourUtilization float64 `json:"five_hour_utilization"`
WeeklyUtilization float64 `json:"weekly_utilization"`
PlanWeight float64 `json:"plan_weight"`
}
err = decoder.Decode(&statusResponse)
if err != nil {
result.duration = time.Since(startTime)
if result.frames == 1 && err == io.EOF && !isStatusStream {
result.oneShot = true
c.restoreLastUpdatedIfUnchanged(firstFrameUpdatedAt, previousLastUpdated)
}
return result, err
}
c.stateAccess.Lock()
c.state.consecutivePollFailures = 0
c.state.fiveHourUtilization = statusResponse.FiveHourUtilization
c.state.weeklyUtilization = statusResponse.WeeklyUtilization
if statusResponse.PlanWeight > 0 {
c.state.remotePlanWeight = statusResponse.PlanWeight
}
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
c.state.hardRateLimited = false
}
shouldInterrupt := c.checkTransitionLocked()
c.stateAccess.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
result.frames++
updatedAt := c.markUsageStreamUpdated()
if result.frames == 1 {
firstFrameUpdatedAt = updatedAt
}
c.emitStatusUpdate()
}
}
func (c *externalCredential) nextStatusStreamBackoff(result statusStreamResult, consecutiveFailures int) (int, time.Duration, bool) {
if result.oneShot {
return 0, c.pollInterval, true
}
if result.duration >= connectorBackoffResetThreshold {
consecutiveFailures = 0
}
consecutiveFailures++
return consecutiveFailures, connectorBackoff(consecutiveFailures), false
}
func (c *externalCredential) doStreamStatusRequest(ctx context.Context) (*http.Response, error) {
buildRequest := func(baseURL string) (*http.Request, error) {
request, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/ccm/v1/status?watch=true", nil)
if err != nil {
return nil, err
}
request.Header.Set("Authorization", "Bearer "+c.token)
return request, nil
}
if c.reverseHTTPClient != nil {
session := c.getReverseSession()
if session != nil && !session.IsClosed() {
request, err := buildRequest(reverseProxyBaseURL)
if err != nil {
return nil, err
}
response, err := c.reverseHTTPClient.Do(request)
if err == nil {
return response, nil
}
}
}
if c.forwardHTTPClient != nil {
request, err := buildRequest(c.baseURL)
if err != nil {
return nil, err
}
return c.forwardHTTPClient.Do(request)
}
return nil, E.New("no transport available")
}
func (c *externalCredential) lastUpdatedTime() time.Time {
@@ -603,6 +767,25 @@ func (c *externalCredential) lastUpdatedTime() time.Time {
return c.state.lastUpdated
}
func (c *externalCredential) markUsageStreamUpdated() time.Time {
c.stateAccess.Lock()
defer c.stateAccess.Unlock()
now := time.Now()
c.state.lastUpdated = now
return now
}
func (c *externalCredential) restoreLastUpdatedIfUnchanged(expectedCurrent time.Time, previous time.Time) {
if expectedCurrent.IsZero() {
return
}
c.stateAccess.Lock()
defer c.stateAccess.Unlock()
if c.state.lastUpdated.Equal(expectedCurrent) {
c.state.lastUpdated = previous
}
}
func (c *externalCredential) markUsagePollAttempted() {
c.stateAccess.Lock()
defer c.stateAccess.Unlock()
@@ -665,26 +848,40 @@ func (c *externalCredential) getReverseSession() *yamux.Session {
}
func (c *externalCredential) setReverseSession(session *yamux.Session) bool {
var emitStatus bool
c.reverseAccess.Lock()
if c.closed {
c.reverseAccess.Unlock()
return false
}
wasAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed()
old := c.reverseSession
c.reverseSession = session
isAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed()
emitStatus = wasAvailable != isAvailable
c.reverseAccess.Unlock()
if old != nil {
old.Close()
}
if emitStatus {
c.emitStatusUpdate()
}
return true
}
func (c *externalCredential) clearReverseSession(session *yamux.Session) {
var emitStatus bool
c.reverseAccess.Lock()
wasAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed()
if c.reverseSession == session {
c.reverseSession = nil
}
isAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed()
emitStatus = wasAvailable != isAvailable
c.reverseAccess.Unlock()
if emitStatus {
c.emitStatusUpdate()
}
}
func (c *externalCredential) getReverseContext() context.Context {

View File

@@ -111,12 +111,18 @@ func (c *defaultCredential) reloadCredentials(force bool) error {
c.access.Unlock()
c.stateAccess.Lock()
wasAvailable, oldWeight := c.statusAggregateStateLocked()
c.state.unavailable = false
c.state.lastCredentialLoadError = ""
c.state.accountType = credentials.SubscriptionType
c.state.rateLimitTier = credentials.RateLimitTier
c.checkTransitionLocked()
isAvailable, newWeight := c.statusAggregateStateLocked()
shouldEmit := wasAvailable != isAvailable || oldWeight != newWeight
c.stateAccess.Unlock()
if shouldEmit {
c.emitStatusUpdate()
}
return nil
}
@@ -128,16 +134,22 @@ func (c *defaultCredential) markCredentialsUnavailable(err error) error {
c.access.Unlock()
c.stateAccess.Lock()
wasAvailable, oldWeight := c.statusAggregateStateLocked()
c.state.unavailable = true
c.state.lastCredentialLoadError = err.Error()
c.state.accountType = ""
c.state.rateLimitTier = ""
shouldInterrupt := c.checkTransitionLocked()
isAvailable, newWeight := c.statusAggregateStateLocked()
shouldEmit := wasAvailable != isAvailable || oldWeight != newWeight
c.stateAccess.Unlock()
if shouldInterrupt && hadCredentials {
c.interruptConnections()
}
if shouldEmit {
c.emitStatusUpdate()
}
return err
}

View File

@@ -0,0 +1,280 @@
package ccm
import (
"context"
"errors"
"io"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common/observable"
"github.com/hashicorp/yamux"
)
type roundTripperFunc func(*http.Request) (*http.Response, error)
func (f roundTripperFunc) RoundTrip(request *http.Request) (*http.Response, error) {
return f(request)
}
func drainStatusEvents(subscription observable.Subscription[struct{}]) int {
var count int
for {
select {
case <-subscription:
count++
default:
return count
}
}
}
func newTestLogger() log.ContextLogger {
return log.NewNOPFactory().Logger()
}
func newTestCCMExternalCredential(t *testing.T, body string, headers http.Header) (*externalCredential, observable.Subscription[struct{}]) {
t.Helper()
subscriber := observable.NewSubscriber[struct{}](8)
subscription, _ := subscriber.Subscription()
credential := &externalCredential{
tag: "test",
baseURL: "http://example.com",
token: "token",
pollInterval: 25 * time.Millisecond,
forwardHTTPClient: &http.Client{Transport: roundTripperFunc(func(request *http.Request) (*http.Response, error) {
if request.URL.String() != "http://example.com/ccm/v1/status?watch=true" {
t.Fatalf("unexpected request URL: %s", request.URL.String())
}
return &http.Response{
StatusCode: http.StatusOK,
Header: headers.Clone(),
Body: io.NopCloser(strings.NewReader(body)),
}, nil
})},
logger: newTestLogger(),
statusSubscriber: subscriber,
}
return credential, subscription
}
func newTestYamuxSessionPair(t *testing.T) (*yamux.Session, *yamux.Session) {
t.Helper()
clientConn, serverConn := net.Pipe()
clientSession, err := yamux.Client(clientConn, defaultYamuxConfig)
if err != nil {
t.Fatalf("create yamux client: %v", err)
}
serverSession, err := yamux.Server(serverConn, defaultYamuxConfig)
if err != nil {
clientSession.Close()
t.Fatalf("create yamux server: %v", err)
}
t.Cleanup(func() {
clientSession.Close()
serverSession.Close()
})
return clientSession, serverSession
}
func TestExternalCredentialConnectStatusStreamOneShotRestoresLastUpdated(t *testing.T) {
credential, subscription := newTestCCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", nil)
oldTime := time.Unix(123, 0)
credential.stateAccess.Lock()
credential.state.lastUpdated = oldTime
credential.stateAccess.Unlock()
result, err := credential.connectStatusStream(context.Background())
if err != io.EOF {
t.Fatalf("expected EOF, got %v", err)
}
if !result.oneShot {
t.Fatal("expected one-shot result")
}
if result.frames != 1 {
t.Fatalf("expected 1 frame, got %d", result.frames)
}
if !credential.lastUpdatedTime().Equal(oldTime) {
t.Fatalf("expected lastUpdated restored to %v, got %v", oldTime, credential.lastUpdatedTime())
}
if credential.fiveHourUtilization() != 12 || credential.weeklyUtilization() != 34 {
t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization())
}
if count := drainStatusEvents(subscription); count != 1 {
t.Fatalf("expected 1 status event, got %d", count)
}
failures, backoff, oneShot := credential.nextStatusStreamBackoff(result, 3)
if !oneShot {
t.Fatal("expected one-shot backoff branch")
}
if failures != 0 {
t.Fatalf("expected failures reset, got %d", failures)
}
if backoff != credential.pollInterval {
t.Fatalf("expected poll interval backoff %v, got %v", credential.pollInterval, backoff)
}
}
func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *testing.T) {
headers := make(http.Header)
headers.Set(statusStreamHeader, "true")
credential, subscription := newTestCCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", headers)
oldTime := time.Unix(123, 0)
credential.stateAccess.Lock()
credential.state.lastUpdated = oldTime
credential.stateAccess.Unlock()
result, err := credential.connectStatusStream(context.Background())
if err != io.EOF {
t.Fatalf("expected EOF, got %v", err)
}
if result.oneShot {
t.Fatal("did not expect one-shot result")
}
if result.frames != 1 {
t.Fatalf("expected 1 frame, got %d", result.frames)
}
if credential.lastUpdatedTime().Equal(oldTime) {
t.Fatal("expected lastUpdated to remain refreshed")
}
if credential.fiveHourUtilization() != 12 || credential.weeklyUtilization() != 34 {
t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization())
}
if count := drainStatusEvents(subscription); count != 1 {
t.Fatalf("expected 1 status event, got %d", count)
}
failures, backoff, oneShot := credential.nextStatusStreamBackoff(result, 3)
if oneShot {
t.Fatal("did not expect one-shot backoff branch")
}
if failures != 4 {
t.Fatalf("expected failures incremented to 4, got %d", failures)
}
if backoff < 16*time.Second || backoff >= 24*time.Second {
t.Fatalf("expected connector backoff in [16s, 24s), got %v", backoff)
}
}
func TestExternalCredentialConnectStatusStreamMultiFrameKeepsLastUpdated(t *testing.T) {
credential, subscription := newTestCCMExternalCredential(t, strings.Join([]string{
"{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}",
"{\"five_hour_utilization\":13,\"weekly_utilization\":35,\"plan_weight\":3}",
}, "\n"), nil)
oldTime := time.Unix(123, 0)
credential.stateAccess.Lock()
credential.state.lastUpdated = oldTime
credential.stateAccess.Unlock()
result, err := credential.connectStatusStream(context.Background())
if err != io.EOF {
t.Fatalf("expected EOF, got %v", err)
}
if result.oneShot {
t.Fatal("did not expect one-shot result")
}
if result.frames != 2 {
t.Fatalf("expected 2 frames, got %d", result.frames)
}
if credential.lastUpdatedTime().Equal(oldTime) {
t.Fatal("expected lastUpdated to remain refreshed")
}
if credential.fiveHourUtilization() != 13 || credential.weeklyUtilization() != 35 {
t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization())
}
if count := drainStatusEvents(subscription); count != 2 {
t.Fatalf("expected 2 status events, got %d", count)
}
}
func TestDefaultCredentialStatusChangesEmitStatus(t *testing.T) {
credentialPath := filepath.Join(t.TempDir(), "credentials.json")
err := os.WriteFile(credentialPath, []byte("{\"claudeAiOauth\":{\"accessToken\":\"token\",\"refreshToken\":\"\",\"expiresAt\":0,\"subscriptionType\":\"max\"}}\n"), 0o600)
if err != nil {
t.Fatalf("write credential file: %v", err)
}
subscriber := observable.NewSubscriber[struct{}](8)
subscription, _ := subscriber.Subscription()
credential := &defaultCredential{
tag: "test",
credentialPath: credentialPath,
logger: newTestLogger(),
statusSubscriber: subscriber,
}
err = credential.markCredentialsUnavailable(errors.New("boom"))
if err == nil {
t.Fatal("expected error from markCredentialsUnavailable")
}
if count := drainStatusEvents(subscription); count != 1 {
t.Fatalf("expected 1 status event after unavailable transition, got %d", count)
}
err = credential.reloadCredentials(true)
if err != nil {
t.Fatalf("reload credentials: %v", err)
}
if count := drainStatusEvents(subscription); count != 1 {
t.Fatalf("expected 1 status event after recovery, got %d", count)
}
if weight := credential.planWeight(); weight != 5 {
t.Fatalf("expected initial max weight 5, got %v", weight)
}
profileClient := &http.Client{Transport: roundTripperFunc(func(request *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(
"{\"organization\":{\"organization_type\":\"claude_max\",\"rate_limit_tier\":\"default_claude_max_20x\"}}",
)),
}, nil
})}
credential.fetchProfile(context.Background(), profileClient, "token")
if count := drainStatusEvents(subscription); count != 1 {
t.Fatalf("expected 1 status event after weight change, got %d", count)
}
if weight := credential.planWeight(); weight != 10 {
t.Fatalf("expected upgraded max weight 10, got %v", weight)
}
}
func TestExternalCredentialReverseSessionChangesEmitStatus(t *testing.T) {
subscriber := observable.NewSubscriber[struct{}](8)
subscription, _ := subscriber.Subscription()
credential := &externalCredential{
tag: "receiver",
baseURL: reverseProxyBaseURL,
pollInterval: time.Minute,
logger: newTestLogger(),
statusSubscriber: subscriber,
}
clientSession, _ := newTestYamuxSessionPair(t)
if !credential.setReverseSession(clientSession) {
t.Fatal("expected reverse session to be accepted")
}
if count := drainStatusEvents(subscription); count != 1 {
t.Fatalf("expected 1 status event after reverse session up, got %d", count)
}
if !credential.isAvailable() {
t.Fatal("expected receiver credential to become available")
}
credential.clearReverseSession(clientSession)
if count := drainStatusEvents(subscription); count != 1 {
t.Fatalf("expected 1 status event after reverse session down, got %d", count)
}
if credential.isAvailable() {
t.Fatal("expected receiver credential to become unavailable")
}
}

View File

@@ -6,7 +6,6 @@ import (
"net/http"
"strings"
"github.com/sagernet/sing-box/adapter"
boxService "github.com/sagernet/sing-box/adapter/service"
"github.com/sagernet/sing-box/common/listener"
@@ -17,6 +16,7 @@ import (
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/observable"
aTLS "github.com/sagernet/sing/common/tls"
"github.com/go-chi/chi/v5"
@@ -150,18 +150,21 @@ func isAPIKeyHeader(header string) bool {
type Service struct {
boxService.Adapter
ctx context.Context
logger log.ContextLogger
options option.CCMServiceOptions
httpHeaders http.Header
listener *listener.Listener
tlsConfig tls.ServerConfig
httpServer *http.Server
ctx context.Context
logger log.ContextLogger
options option.CCMServiceOptions
httpHeaders http.Header
listener *listener.Listener
tlsConfig tls.ServerConfig
httpServer *http.Server
userManager *UserManager
providers map[string]credentialProvider
allCredentials []Credential
userConfigMap map[string]*option.CCMUser
statusSubscriber *observable.Subscriber[struct{}]
statusObserver *observable.Observer[struct{}]
}
func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.CCMServiceOptions) (adapter.Service, error) {
@@ -195,6 +198,7 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
tokenMap: make(map[string]string),
}
statusSubscriber := observable.NewSubscriber[struct{}](16)
service := &Service{
Adapter: boxService.NewAdapter(C.TypeCCM, tag),
ctx: ctx,
@@ -207,7 +211,9 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
Network: []string{N.NetworkTCP},
Listen: options.ListenOptions,
}),
userManager: userManager,
userManager: userManager,
statusSubscriber: statusSubscriber,
statusObserver: observable.NewObserver[struct{}](statusSubscriber, 8),
}
providers, allCredentials, err := buildCredentialProviders(ctx, options, logger)
@@ -242,6 +248,7 @@ func (s *Service) Start(stage adapter.StartStage) error {
s.userManager.UpdateUsers(s.options.Users)
for _, credential := range s.allCredentials {
credential.setStatusSubscriber(s.statusSubscriber)
if external, ok := credential.(*externalCredential); ok && external.reverse && external.connectorURL != nil {
external.reverseService = s
}
@@ -300,6 +307,7 @@ func (s *Service) InterfaceUpdated() {
}
func (s *Service) Close() error {
s.statusObserver.Close()
err := common.Close(
common.PtrOrNil(s.httpServer),
common.PtrOrNil(s.listener),

View File

@@ -1,6 +1,7 @@
package ccm
import (
"bytes"
"encoding/json"
"net/http"
"strconv"
@@ -55,6 +56,11 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
return
}
if r.URL.Query().Get("watch") == "true" {
s.handleStatusStream(w, r, provider, userConfig)
return
}
provider.pollIfStale(r.Context())
avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
@@ -67,6 +73,76 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
})
}
func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, provider credentialProvider, userConfig *option.CCMUser) {
flusher, ok := w.(http.Flusher)
if !ok {
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "streaming not supported")
return
}
subscription, done, err := s.statusObserver.Subscribe()
if err != nil {
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "service closing")
return
}
defer s.statusObserver.UnSubscribe(subscription)
provider.pollIfStale(r.Context())
w.Header().Set("Content-Type", "application/json")
w.Header().Set(statusStreamHeader, "true")
w.WriteHeader(http.StatusOK)
lastFiveHour, lastWeekly, lastWeight := s.computeAggregatedUtilization(provider, userConfig)
buf := &bytes.Buffer{}
json.NewEncoder(buf).Encode(map[string]float64{
"five_hour_utilization": lastFiveHour,
"weekly_utilization": lastWeekly,
"plan_weight": lastWeight,
})
_, writeErr := w.Write(buf.Bytes())
if writeErr != nil {
return
}
flusher.Flush()
for {
select {
case <-r.Context().Done():
return
case <-done:
return
case <-subscription:
for {
select {
case <-subscription:
default:
goto drained
}
}
drained:
fiveHour, weekly, weight := s.computeAggregatedUtilization(provider, userConfig)
if fiveHour == lastFiveHour && weekly == lastWeekly && weight == lastWeight {
continue
}
lastFiveHour = fiveHour
lastWeekly = weekly
lastWeight = weight
buf.Reset()
json.NewEncoder(buf).Encode(map[string]float64{
"five_hour_utilization": fiveHour,
"weekly_utilization": weekly,
"plan_weight": weight,
})
_, writeErr = w.Write(buf.Bytes())
if writeErr != nil {
return
}
flusher.Flush()
}
}
}
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) (float64, float64, float64) {
var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64
for _, credential := range provider.allCredentials() {

View File

@@ -9,6 +9,7 @@ import (
"time"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/observable"
)
const (
@@ -118,6 +119,7 @@ type Credential interface {
interruptConnections()
setOnBecameUnusable(fn func())
setStatusSubscriber(*observable.Subscriber[struct{}])
start() error
pollUsage(ctx context.Context)
lastUpdatedTime() time.Time

View File

@@ -192,4 +192,3 @@ func credentialForUser(
}
return provider, nil
}

View File

@@ -22,6 +22,7 @@ import (
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp"
"github.com/sagernet/sing/common/observable"
)
type defaultCredential struct {
@@ -45,6 +46,8 @@ type defaultCredential struct {
watcher *fswatch.Watcher
watcherRetryAt time.Time
statusSubscriber *observable.Subscriber[struct{}]
// Connection interruption
onBecameUnusable func()
interrupted bool
@@ -147,6 +150,16 @@ func (c *defaultCredential) setOnBecameUnusable(fn func()) {
c.onBecameUnusable = fn
}
func (c *defaultCredential) setStatusSubscriber(subscriber *observable.Subscriber[struct{}]) {
c.statusSubscriber = subscriber
}
func (c *defaultCredential) emitStatusUpdate() {
if c.statusSubscriber != nil {
c.statusSubscriber.Emit(struct{}{})
}
}
func (c *defaultCredential) tagName() string {
return c.tag
}
@@ -202,11 +215,16 @@ func (c *defaultCredential) getAccessToken() (string, error) {
if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) {
c.credentials = latestCredentials
c.stateAccess.Lock()
wasAvailable := !c.state.unavailable
c.state.unavailable = false
c.state.lastCredentialLoadAttempt = time.Now()
c.state.lastCredentialLoadError = ""
c.checkTransitionLocked()
shouldEmit := wasAvailable != !c.state.unavailable
c.stateAccess.Unlock()
if shouldEmit {
c.emitStatusUpdate()
}
if !latestCredentials.needsRefresh() {
return latestCredentials.getAccessToken(), nil
}
@@ -215,11 +233,16 @@ func (c *defaultCredential) getAccessToken() (string, error) {
c.credentials = newCredentials
c.stateAccess.Lock()
wasAvailable := !c.state.unavailable
c.state.unavailable = false
c.state.lastCredentialLoadAttempt = time.Now()
c.state.lastCredentialLoadError = ""
c.checkTransitionLocked()
shouldEmit := wasAvailable != !c.state.unavailable
c.stateAccess.Unlock()
if shouldEmit {
c.emitStatusUpdate()
}
err = platformWriteCredentials(newCredentials, c.credentialPath)
if err != nil {
@@ -329,6 +352,9 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) {
if shouldInterrupt {
c.interruptConnections()
}
if hadData {
c.emitStatusUpdate()
}
}
func (c *defaultCredential) markRateLimited(resetAt time.Time) {
@@ -341,6 +367,7 @@ func (c *defaultCredential) markRateLimited(resetAt time.Time) {
if shouldInterrupt {
c.interruptConnections()
}
c.emitStatusUpdate()
}
func (c *defaultCredential) isUsable() bool {
@@ -692,6 +719,7 @@ func (c *defaultCredential) pollUsage(ctx context.Context) {
if shouldInterrupt {
c.interruptConnections()
}
c.emitStatusUpdate()
}
func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) {

View File

@@ -23,11 +23,15 @@ import (
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp"
"github.com/sagernet/sing/common/observable"
"github.com/hashicorp/yamux"
)
const reverseProxyBaseURL = "http://reverse-proxy"
const (
reverseProxyBaseURL = "http://reverse-proxy"
statusStreamHeader = "X-OCM-Status-Stream"
)
type externalCredential struct {
tag string
@@ -42,6 +46,7 @@ type externalCredential struct {
usageTracker *AggregatedUsage
logger log.ContextLogger
statusSubscriber *observable.Subscriber[struct{}]
onBecameUnusable func()
interrupted bool
requestContext context.Context
@@ -69,6 +74,12 @@ type reverseSessionDialer struct {
credential *externalCredential
}
type statusStreamResult struct {
duration time.Duration
frames int
oneShot bool
}
func (d reverseSessionDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
if N.NetworkName(network) != N.NetworkTCP {
return nil, os.ErrInvalid
@@ -249,6 +260,8 @@ func (c *externalCredential) start() error {
}
if c.reverse && c.connectorURL != nil {
go c.connectorLoop()
} else {
go c.statusStreamLoop()
}
return nil
}
@@ -257,6 +270,16 @@ func (c *externalCredential) setOnBecameUnusable(fn func()) {
c.onBecameUnusable = fn
}
func (c *externalCredential) setStatusSubscriber(subscriber *observable.Subscriber[struct{}]) {
c.statusSubscriber = subscriber
}
func (c *externalCredential) emitStatusUpdate() {
if c.statusSubscriber != nil {
c.statusSubscriber.Emit(struct{}{})
}
}
func (c *externalCredential) tagName() string {
return c.tag
}
@@ -342,6 +365,7 @@ func (c *externalCredential) markRateLimited(resetAt time.Time) {
if shouldInterrupt {
c.interruptConnections()
}
c.emitStatusUpdate()
}
func (c *externalCredential) earliestReset() time.Time {
@@ -498,6 +522,9 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) {
if shouldInterrupt {
c.interruptConnections()
}
if hadData {
c.emitStatusUpdate()
}
}
func (c *externalCredential) checkTransitionLocked() bool {
@@ -638,6 +665,142 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
if shouldInterrupt {
c.interruptConnections()
}
c.emitStatusUpdate()
}
func (c *externalCredential) statusStreamLoop() {
var consecutiveFailures int
ctx := c.getReverseContext()
for {
select {
case <-ctx.Done():
return
default:
}
result, err := c.connectStatusStream(ctx)
if ctx.Err() != nil {
return
}
var backoff time.Duration
var oneShot bool
consecutiveFailures, backoff, oneShot = c.nextStatusStreamBackoff(result, consecutiveFailures)
if oneShot {
c.logger.Debug("status stream for ", c.tag, " returned a single-frame response, retrying in ", backoff)
} else {
c.logger.Debug("status stream for ", c.tag, " disconnected: ", err, ", reconnecting in ", backoff)
}
timer := time.NewTimer(backoff)
select {
case <-timer.C:
case <-ctx.Done():
timer.Stop()
return
}
}
}
func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStreamResult, error) {
startTime := time.Now()
result := statusStreamResult{}
response, err := c.doStreamStatusRequest(ctx)
if err != nil {
result.duration = time.Since(startTime)
return result, err
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
body, _ := io.ReadAll(response.Body)
result.duration = time.Since(startTime)
return result, E.New("status ", response.StatusCode, " ", string(body))
}
decoder := json.NewDecoder(response.Body)
isStatusStream := response.Header.Get(statusStreamHeader) == "true"
previousLastUpdated := c.lastUpdatedTime()
var firstFrameUpdatedAt time.Time
for {
var statusResponse struct {
FiveHourUtilization float64 `json:"five_hour_utilization"`
WeeklyUtilization float64 `json:"weekly_utilization"`
PlanWeight float64 `json:"plan_weight"`
}
err = decoder.Decode(&statusResponse)
if err != nil {
result.duration = time.Since(startTime)
if result.frames == 1 && err == io.EOF && !isStatusStream {
result.oneShot = true
c.restoreLastUpdatedIfUnchanged(firstFrameUpdatedAt, previousLastUpdated)
}
return result, err
}
c.stateAccess.Lock()
c.state.consecutivePollFailures = 0
c.state.fiveHourUtilization = statusResponse.FiveHourUtilization
c.state.weeklyUtilization = statusResponse.WeeklyUtilization
if statusResponse.PlanWeight > 0 {
c.state.remotePlanWeight = statusResponse.PlanWeight
}
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
c.state.hardRateLimited = false
}
shouldInterrupt := c.checkTransitionLocked()
c.stateAccess.Unlock()
if shouldInterrupt {
c.interruptConnections()
}
result.frames++
updatedAt := c.markUsageStreamUpdated()
if result.frames == 1 {
firstFrameUpdatedAt = updatedAt
}
c.emitStatusUpdate()
}
}
func (c *externalCredential) nextStatusStreamBackoff(result statusStreamResult, consecutiveFailures int) (int, time.Duration, bool) {
if result.oneShot {
return 0, c.pollInterval, true
}
if result.duration >= connectorBackoffResetThreshold {
consecutiveFailures = 0
}
consecutiveFailures++
return consecutiveFailures, connectorBackoff(consecutiveFailures), false
}
func (c *externalCredential) doStreamStatusRequest(ctx context.Context) (*http.Response, error) {
buildRequest := func(baseURL string) (*http.Request, error) {
request, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/ocm/v1/status?watch=true", nil)
if err != nil {
return nil, err
}
request.Header.Set("Authorization", "Bearer "+c.token)
return request, nil
}
if c.reverseHTTPClient != nil {
session := c.getReverseSession()
if session != nil && !session.IsClosed() {
request, err := buildRequest(reverseProxyBaseURL)
if err != nil {
return nil, err
}
response, err := c.reverseHTTPClient.Do(request)
if err == nil {
return response, nil
}
}
}
if c.forwardHTTPClient != nil {
request, err := buildRequest(c.baseURL)
if err != nil {
return nil, err
}
return c.forwardHTTPClient.Do(request)
}
return nil, E.New("no transport available")
}
func (c *externalCredential) lastUpdatedTime() time.Time {
@@ -646,6 +809,25 @@ func (c *externalCredential) lastUpdatedTime() time.Time {
return c.state.lastUpdated
}
func (c *externalCredential) markUsageStreamUpdated() time.Time {
c.stateAccess.Lock()
defer c.stateAccess.Unlock()
now := time.Now()
c.state.lastUpdated = now
return now
}
func (c *externalCredential) restoreLastUpdatedIfUnchanged(expectedCurrent time.Time, previous time.Time) {
if expectedCurrent.IsZero() {
return
}
c.stateAccess.Lock()
defer c.stateAccess.Unlock()
if c.state.lastUpdated.Equal(expectedCurrent) {
c.state.lastUpdated = previous
}
}
func (c *externalCredential) markUsagePollAttempted() {
c.stateAccess.Lock()
defer c.stateAccess.Unlock()
@@ -736,26 +918,40 @@ func (c *externalCredential) getReverseSession() *yamux.Session {
}
func (c *externalCredential) setReverseSession(session *yamux.Session) bool {
var emitStatus bool
c.reverseAccess.Lock()
if c.closed {
c.reverseAccess.Unlock()
return false
}
wasAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed()
old := c.reverseSession
c.reverseSession = session
isAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed()
emitStatus = wasAvailable != isAvailable
c.reverseAccess.Unlock()
if old != nil {
old.Close()
}
if emitStatus {
c.emitStatusUpdate()
}
return true
}
func (c *externalCredential) clearReverseSession(session *yamux.Session) {
var emitStatus bool
c.reverseAccess.Lock()
wasAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed()
if c.reverseSession == session {
c.reverseSession = nil
}
isAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed()
emitStatus = wasAvailable != isAvailable
c.reverseAccess.Unlock()
if emitStatus {
c.emitStatusUpdate()
}
}
func (c *externalCredential) getReverseContext() context.Context {

View File

@@ -111,10 +111,15 @@ func (c *defaultCredential) reloadCredentials(force bool) error {
c.access.Unlock()
c.stateAccess.Lock()
wasAvailable := !c.state.unavailable
c.state.unavailable = false
c.state.lastCredentialLoadError = ""
c.checkTransitionLocked()
shouldEmit := wasAvailable != !c.state.unavailable
c.stateAccess.Unlock()
if shouldEmit {
c.emitStatusUpdate()
}
return nil
}
@@ -126,14 +131,19 @@ func (c *defaultCredential) markCredentialsUnavailable(err error) error {
c.access.Unlock()
c.stateAccess.Lock()
wasAvailable := !c.state.unavailable
c.state.unavailable = true
c.state.lastCredentialLoadError = err.Error()
shouldInterrupt := c.checkTransitionLocked()
shouldEmit := wasAvailable != !c.state.unavailable
c.stateAccess.Unlock()
if shouldInterrupt && hadCredentials {
c.interruptConnections()
}
if shouldEmit {
c.emitStatusUpdate()
}
return err
}

View File

@@ -0,0 +1,263 @@
package ocm
import (
"context"
"errors"
"io"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common/observable"
"github.com/hashicorp/yamux"
)
type roundTripperFunc func(*http.Request) (*http.Response, error)
func (f roundTripperFunc) RoundTrip(request *http.Request) (*http.Response, error) {
return f(request)
}
func drainStatusEvents(subscription observable.Subscription[struct{}]) int {
var count int
for {
select {
case <-subscription:
count++
default:
return count
}
}
}
func newTestLogger() log.ContextLogger {
return log.NewNOPFactory().Logger()
}
func newTestOCMExternalCredential(t *testing.T, body string, headers http.Header) (*externalCredential, observable.Subscription[struct{}]) {
t.Helper()
subscriber := observable.NewSubscriber[struct{}](8)
subscription, _ := subscriber.Subscription()
credential := &externalCredential{
tag: "test",
baseURL: "http://example.com",
token: "token",
pollInterval: 25 * time.Millisecond,
forwardHTTPClient: &http.Client{Transport: roundTripperFunc(func(request *http.Request) (*http.Response, error) {
if request.URL.String() != "http://example.com/ocm/v1/status?watch=true" {
t.Fatalf("unexpected request URL: %s", request.URL.String())
}
return &http.Response{
StatusCode: http.StatusOK,
Header: headers.Clone(),
Body: io.NopCloser(strings.NewReader(body)),
}, nil
})},
logger: newTestLogger(),
statusSubscriber: subscriber,
}
return credential, subscription
}
func newTestYamuxSessionPair(t *testing.T) (*yamux.Session, *yamux.Session) {
t.Helper()
clientConn, serverConn := net.Pipe()
clientSession, err := yamux.Client(clientConn, defaultYamuxConfig)
if err != nil {
t.Fatalf("create yamux client: %v", err)
}
serverSession, err := yamux.Server(serverConn, defaultYamuxConfig)
if err != nil {
clientSession.Close()
t.Fatalf("create yamux server: %v", err)
}
t.Cleanup(func() {
clientSession.Close()
serverSession.Close()
})
return clientSession, serverSession
}
func TestExternalCredentialConnectStatusStreamOneShotRestoresLastUpdated(t *testing.T) {
credential, subscription := newTestOCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", nil)
oldTime := time.Unix(123, 0)
credential.stateAccess.Lock()
credential.state.lastUpdated = oldTime
credential.stateAccess.Unlock()
result, err := credential.connectStatusStream(context.Background())
if err != io.EOF {
t.Fatalf("expected EOF, got %v", err)
}
if !result.oneShot {
t.Fatal("expected one-shot result")
}
if result.frames != 1 {
t.Fatalf("expected 1 frame, got %d", result.frames)
}
if !credential.lastUpdatedTime().Equal(oldTime) {
t.Fatalf("expected lastUpdated restored to %v, got %v", oldTime, credential.lastUpdatedTime())
}
if credential.fiveHourUtilization() != 12 || credential.weeklyUtilization() != 34 {
t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization())
}
if count := drainStatusEvents(subscription); count != 1 {
t.Fatalf("expected 1 status event, got %d", count)
}
failures, backoff, oneShot := credential.nextStatusStreamBackoff(result, 3)
if !oneShot {
t.Fatal("expected one-shot backoff branch")
}
if failures != 0 {
t.Fatalf("expected failures reset, got %d", failures)
}
if backoff != credential.pollInterval {
t.Fatalf("expected poll interval backoff %v, got %v", credential.pollInterval, backoff)
}
}
func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *testing.T) {
headers := make(http.Header)
headers.Set(statusStreamHeader, "true")
credential, subscription := newTestOCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", headers)
oldTime := time.Unix(123, 0)
credential.stateAccess.Lock()
credential.state.lastUpdated = oldTime
credential.stateAccess.Unlock()
result, err := credential.connectStatusStream(context.Background())
if err != io.EOF {
t.Fatalf("expected EOF, got %v", err)
}
if result.oneShot {
t.Fatal("did not expect one-shot result")
}
if result.frames != 1 {
t.Fatalf("expected 1 frame, got %d", result.frames)
}
if credential.lastUpdatedTime().Equal(oldTime) {
t.Fatal("expected lastUpdated to remain refreshed")
}
if credential.fiveHourUtilization() != 12 || credential.weeklyUtilization() != 34 {
t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization())
}
if count := drainStatusEvents(subscription); count != 1 {
t.Fatalf("expected 1 status event, got %d", count)
}
failures, backoff, oneShot := credential.nextStatusStreamBackoff(result, 3)
if oneShot {
t.Fatal("did not expect one-shot backoff branch")
}
if failures != 4 {
t.Fatalf("expected failures incremented to 4, got %d", failures)
}
if backoff < 16*time.Second || backoff >= 24*time.Second {
t.Fatalf("expected connector backoff in [16s, 24s), got %v", backoff)
}
}
func TestExternalCredentialConnectStatusStreamMultiFrameKeepsLastUpdated(t *testing.T) {
credential, subscription := newTestOCMExternalCredential(t, strings.Join([]string{
"{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}",
"{\"five_hour_utilization\":13,\"weekly_utilization\":35,\"plan_weight\":3}",
}, "\n"), nil)
oldTime := time.Unix(123, 0)
credential.stateAccess.Lock()
credential.state.lastUpdated = oldTime
credential.stateAccess.Unlock()
result, err := credential.connectStatusStream(context.Background())
if err != io.EOF {
t.Fatalf("expected EOF, got %v", err)
}
if result.oneShot {
t.Fatal("did not expect one-shot result")
}
if result.frames != 2 {
t.Fatalf("expected 2 frames, got %d", result.frames)
}
if credential.lastUpdatedTime().Equal(oldTime) {
t.Fatal("expected lastUpdated to remain refreshed")
}
if credential.fiveHourUtilization() != 13 || credential.weeklyUtilization() != 35 {
t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization())
}
if count := drainStatusEvents(subscription); count != 2 {
t.Fatalf("expected 2 status events, got %d", count)
}
}
func TestDefaultCredentialAvailabilityChangesEmitStatus(t *testing.T) {
credentialPath := filepath.Join(t.TempDir(), "auth.json")
err := os.WriteFile(credentialPath, []byte("{\"OPENAI_API_KEY\":\"sk-test\"}\n"), 0o600)
if err != nil {
t.Fatalf("write credential file: %v", err)
}
subscriber := observable.NewSubscriber[struct{}](8)
subscription, _ := subscriber.Subscription()
credential := &defaultCredential{
tag: "test",
credentialPath: credentialPath,
logger: newTestLogger(),
statusSubscriber: subscriber,
}
err = credential.markCredentialsUnavailable(errors.New("boom"))
if err == nil {
t.Fatal("expected error from markCredentialsUnavailable")
}
if count := drainStatusEvents(subscription); count != 1 {
t.Fatalf("expected 1 status event after unavailable transition, got %d", count)
}
err = credential.reloadCredentials(true)
if err != nil {
t.Fatalf("reload credentials: %v", err)
}
if count := drainStatusEvents(subscription); count != 1 {
t.Fatalf("expected 1 status event after recovery, got %d", count)
}
if !credential.isAvailable() {
t.Fatal("expected credential to become available")
}
}
func TestExternalCredentialReverseSessionChangesEmitStatus(t *testing.T) {
subscriber := observable.NewSubscriber[struct{}](8)
subscription, _ := subscriber.Subscription()
credential := &externalCredential{
tag: "receiver",
baseURL: reverseProxyBaseURL,
pollInterval: time.Minute,
logger: newTestLogger(),
statusSubscriber: subscriber,
}
clientSession, _ := newTestYamuxSessionPair(t)
if !credential.setReverseSession(clientSession) {
t.Fatal("expected reverse session to be accepted")
}
if count := drainStatusEvents(subscription); count != 1 {
t.Fatalf("expected 1 status event after reverse session up, got %d", count)
}
if !credential.isAvailable() {
t.Fatal("expected receiver credential to become available")
}
credential.clearReverseSession(clientSession)
if count := drainStatusEvents(subscription); count != 1 {
t.Fatalf("expected 1 status event after reverse session down, got %d", count)
}
if credential.isAvailable() {
t.Fatal("expected receiver credential to become unavailable")
}
}

View File

@@ -18,6 +18,7 @@ import (
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/observable"
aTLS "github.com/sagernet/sing/common/tls"
"github.com/go-chi/chi/v5"
@@ -176,9 +177,11 @@ type Service struct {
webSocketConns map[*webSocketSession]struct{}
shuttingDown bool
providers map[string]credentialProvider
allCredentials []Credential
userConfigMap map[string]*option.OCMUser
providers map[string]credentialProvider
allCredentials []Credential
userConfigMap map[string]*option.OCMUser
statusSubscriber *observable.Subscriber[struct{}]
statusObserver *observable.Observer[struct{}]
}
func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.OCMServiceOptions) (adapter.Service, error) {
@@ -210,6 +213,8 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
tokenMap: make(map[string]string),
}
statusSubscriber := observable.NewSubscriber[struct{}](16)
service := &Service{
Adapter: boxService.NewAdapter(C.TypeOCM, tag),
ctx: ctx,
@@ -222,8 +227,10 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
Network: []string{N.NetworkTCP},
Listen: options.ListenOptions,
}),
userManager: userManager,
webSocketConns: make(map[*webSocketSession]struct{}),
userManager: userManager,
statusSubscriber: statusSubscriber,
statusObserver: observable.NewObserver[struct{}](statusSubscriber, 8),
webSocketConns: make(map[*webSocketSession]struct{}),
}
providers, allCredentials, err := buildOCMCredentialProviders(ctx, options, logger)
@@ -258,6 +265,7 @@ func (s *Service) Start(stage adapter.StartStage) error {
s.userManager.UpdateUsers(s.options.Users)
for _, credential := range s.allCredentials {
credential.setStatusSubscriber(s.statusSubscriber)
if external, ok := credential.(*externalCredential); ok && external.reverse && external.connectorURL != nil {
external.reverseService = s
}
@@ -324,6 +332,7 @@ func (s *Service) InterfaceUpdated() {
}
func (s *Service) Close() error {
s.statusObserver.Close()
webSocketSessions := s.startWebSocketShutdown()
err := common.Close(

View File

@@ -1,6 +1,7 @@
package ocm
import (
"bytes"
"encoding/json"
"net/http"
"strconv"
@@ -55,6 +56,11 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
return
}
if r.URL.Query().Get("watch") == "true" {
s.handleStatusStream(w, r, provider, userConfig)
return
}
provider.pollIfStale(r.Context())
avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
@@ -67,6 +73,76 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
})
}
func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, provider credentialProvider, userConfig *option.OCMUser) {
flusher, ok := w.(http.Flusher)
if !ok {
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "streaming not supported")
return
}
subscription, done, err := s.statusObserver.Subscribe()
if err != nil {
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "service closing")
return
}
defer s.statusObserver.UnSubscribe(subscription)
provider.pollIfStale(r.Context())
w.Header().Set("Content-Type", "application/json")
w.Header().Set(statusStreamHeader, "true")
w.WriteHeader(http.StatusOK)
lastFiveHour, lastWeekly, lastWeight := s.computeAggregatedUtilization(provider, userConfig)
buf := &bytes.Buffer{}
json.NewEncoder(buf).Encode(map[string]float64{
"five_hour_utilization": lastFiveHour,
"weekly_utilization": lastWeekly,
"plan_weight": lastWeight,
})
_, writeErr := w.Write(buf.Bytes())
if writeErr != nil {
return
}
flusher.Flush()
for {
select {
case <-r.Context().Done():
return
case <-done:
return
case <-subscription:
for {
select {
case <-subscription:
default:
goto drained
}
}
drained:
fiveHour, weekly, weight := s.computeAggregatedUtilization(provider, userConfig)
if fiveHour == lastFiveHour && weekly == lastWeekly && weight == lastWeight {
continue
}
lastFiveHour = fiveHour
lastWeekly = weekly
lastWeight = weight
buf.Reset()
json.NewEncoder(buf).Encode(map[string]float64{
"five_hour_utilization": fiveHour,
"weekly_utilization": weekly,
"plan_weight": weight,
})
_, writeErr = w.Write(buf.Bytes())
if writeErr != nil {
return
}
flusher.Flush()
}
}
}
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) (float64, float64, float64) {
var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64
for _, credential := range provider.allCredentials() {