Add stream watch endpoint
This commit is contained in:
@@ -6,6 +6,8 @@ import (
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/observable"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -115,6 +117,7 @@ type Credential interface {
|
||||
wrapRequestContext(ctx context.Context) *credentialRequestContext
|
||||
interruptConnections()
|
||||
|
||||
setStatusSubscriber(*observable.Subscriber[struct{}])
|
||||
start() error
|
||||
pollUsage(ctx context.Context)
|
||||
lastUpdatedTime() time.Time
|
||||
|
||||
@@ -161,4 +161,3 @@ func credentialForUser(
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
"github.com/sagernet/sing/common/observable"
|
||||
)
|
||||
|
||||
type defaultCredential struct {
|
||||
@@ -43,11 +44,13 @@ type defaultCredential struct {
|
||||
watcher *fswatch.Watcher
|
||||
watcherRetryAt time.Time
|
||||
|
||||
statusSubscriber *observable.Subscriber[struct{}]
|
||||
|
||||
// Connection interruption
|
||||
interrupted bool
|
||||
requestContext context.Context
|
||||
cancelRequests context.CancelFunc
|
||||
requestAccess sync.Mutex
|
||||
requestContext context.Context
|
||||
cancelRequests context.CancelFunc
|
||||
requestAccess sync.Mutex
|
||||
}
|
||||
|
||||
func newDefaultCredential(ctx context.Context, tag string, options option.CCMDefaultCredentialOptions, logger log.ContextLogger) (*defaultCredential, error) {
|
||||
@@ -139,6 +142,23 @@ func (c *defaultCredential) start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *defaultCredential) setStatusSubscriber(subscriber *observable.Subscriber[struct{}]) {
|
||||
c.statusSubscriber = subscriber
|
||||
}
|
||||
|
||||
func (c *defaultCredential) emitStatusUpdate() {
|
||||
if c.statusSubscriber != nil {
|
||||
c.statusSubscriber.Emit(struct{}{})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *defaultCredential) statusAggregateStateLocked() (bool, float64) {
|
||||
if c.state.unavailable {
|
||||
return false, 0
|
||||
}
|
||||
return true, ccmPlanWeight(c.state.accountType, c.state.rateLimitTier)
|
||||
}
|
||||
|
||||
func (c *defaultCredential) getAccessToken() (string, error) {
|
||||
c.retryCredentialReloadIfNeeded()
|
||||
|
||||
@@ -186,13 +206,19 @@ func (c *defaultCredential) getAccessToken() (string, error) {
|
||||
if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) {
|
||||
c.credentials = latestCredentials
|
||||
c.stateAccess.Lock()
|
||||
wasAvailable, oldWeight := c.statusAggregateStateLocked()
|
||||
c.state.unavailable = false
|
||||
c.state.lastCredentialLoadAttempt = time.Now()
|
||||
c.state.lastCredentialLoadError = ""
|
||||
c.state.accountType = latestCredentials.SubscriptionType
|
||||
c.state.rateLimitTier = latestCredentials.RateLimitTier
|
||||
c.checkTransitionLocked()
|
||||
isAvailable, newWeight := c.statusAggregateStateLocked()
|
||||
shouldEmit := wasAvailable != isAvailable || oldWeight != newWeight
|
||||
c.stateAccess.Unlock()
|
||||
if shouldEmit {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
if !latestCredentials.needsRefresh() {
|
||||
return latestCredentials.AccessToken, nil
|
||||
}
|
||||
@@ -201,13 +227,19 @@ func (c *defaultCredential) getAccessToken() (string, error) {
|
||||
|
||||
c.credentials = newCredentials
|
||||
c.stateAccess.Lock()
|
||||
wasAvailable, oldWeight := c.statusAggregateStateLocked()
|
||||
c.state.unavailable = false
|
||||
c.state.lastCredentialLoadAttempt = time.Now()
|
||||
c.state.lastCredentialLoadError = ""
|
||||
c.state.accountType = newCredentials.SubscriptionType
|
||||
c.state.rateLimitTier = newCredentials.RateLimitTier
|
||||
c.checkTransitionLocked()
|
||||
isAvailable, newWeight := c.statusAggregateStateLocked()
|
||||
shouldEmit := wasAvailable != isAvailable || oldWeight != newWeight
|
||||
c.stateAccess.Unlock()
|
||||
if shouldEmit {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
|
||||
err = platformWriteCredentials(newCredentials, c.credentialPath)
|
||||
if err != nil {
|
||||
@@ -277,6 +309,9 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) {
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
if hadData {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *defaultCredential) markRateLimited(resetAt time.Time) {
|
||||
@@ -289,6 +324,7 @@ func (c *defaultCredential) markRateLimited(resetAt time.Time) {
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
|
||||
func (c *defaultCredential) isUsable() bool {
|
||||
@@ -584,6 +620,7 @@ func (c *defaultCredential) pollUsage(ctx context.Context) {
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
c.emitStatusUpdate()
|
||||
|
||||
if needsProfileFetch {
|
||||
c.fetchProfile(ctx, httpClient, accessToken)
|
||||
@@ -636,6 +673,7 @@ func (c *defaultCredential) fetchProfile(ctx context.Context, httpClient *http.C
|
||||
rateLimitTier := profileResponse.Organization.RateLimitTier
|
||||
|
||||
c.stateAccess.Lock()
|
||||
wasAvailable, oldWeight := c.statusAggregateStateLocked()
|
||||
if accountType != "" && c.state.accountType == "" {
|
||||
c.state.accountType = accountType
|
||||
}
|
||||
@@ -643,7 +681,12 @@ func (c *defaultCredential) fetchProfile(ctx context.Context, httpClient *http.C
|
||||
c.state.rateLimitTier = rateLimitTier
|
||||
}
|
||||
resolvedAccountType := c.state.accountType
|
||||
isAvailable, newWeight := c.statusAggregateStateLocked()
|
||||
shouldEmit := wasAvailable != isAvailable || oldWeight != newWeight
|
||||
c.stateAccess.Unlock()
|
||||
if shouldEmit {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
c.logger.Info("fetched profile for ", c.tag, ": type=", resolvedAccountType, ", tier=", rateLimitTier, ", weight=", ccmPlanWeight(resolvedAccountType, rateLimitTier))
|
||||
}
|
||||
|
||||
|
||||
@@ -22,11 +22,15 @@ import (
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
"github.com/sagernet/sing/common/observable"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
)
|
||||
|
||||
const reverseProxyBaseURL = "http://reverse-proxy"
|
||||
const (
|
||||
reverseProxyBaseURL = "http://reverse-proxy"
|
||||
statusStreamHeader = "X-CCM-Status-Stream"
|
||||
)
|
||||
|
||||
type externalCredential struct {
|
||||
tag string
|
||||
@@ -40,10 +44,12 @@ type externalCredential struct {
|
||||
usageTracker *AggregatedUsage
|
||||
logger log.ContextLogger
|
||||
|
||||
statusSubscriber *observable.Subscriber[struct{}]
|
||||
|
||||
interrupted bool
|
||||
requestContext context.Context
|
||||
cancelRequests context.CancelFunc
|
||||
requestAccess sync.Mutex
|
||||
requestContext context.Context
|
||||
cancelRequests context.CancelFunc
|
||||
requestAccess sync.Mutex
|
||||
|
||||
// Reverse proxy fields
|
||||
reverse bool
|
||||
@@ -61,6 +67,12 @@ type externalCredential struct {
|
||||
reverseService http.Handler
|
||||
}
|
||||
|
||||
type statusStreamResult struct {
|
||||
duration time.Duration
|
||||
frames int
|
||||
oneShot bool
|
||||
}
|
||||
|
||||
func externalCredentialURLPort(parsedURL *url.URL) uint16 {
|
||||
portString := parsedURL.Port()
|
||||
if portString != "" {
|
||||
@@ -218,6 +230,16 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx
|
||||
return credential, nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) setStatusSubscriber(subscriber *observable.Subscriber[struct{}]) {
|
||||
c.statusSubscriber = subscriber
|
||||
}
|
||||
|
||||
func (c *externalCredential) emitStatusUpdate() {
|
||||
if c.statusSubscriber != nil {
|
||||
c.statusSubscriber.Emit(struct{}{})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) start() error {
|
||||
if c.usageTracker != nil {
|
||||
err := c.usageTracker.Load()
|
||||
@@ -227,6 +249,8 @@ func (c *externalCredential) start() error {
|
||||
}
|
||||
if c.reverse && c.connectorURL != nil {
|
||||
go c.connectorLoop()
|
||||
} else {
|
||||
go c.statusStreamLoop()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -317,6 +341,7 @@ func (c *externalCredential) markRateLimited(resetAt time.Time) {
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
|
||||
func (c *externalCredential) earliestReset() time.Time {
|
||||
@@ -458,6 +483,9 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) {
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
if hadData {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) checkTransitionLocked() bool {
|
||||
@@ -595,6 +623,142 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
|
||||
func (c *externalCredential) statusStreamLoop() {
|
||||
var consecutiveFailures int
|
||||
ctx := c.getReverseContext()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
result, err := c.connectStatusStream(ctx)
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
var backoff time.Duration
|
||||
var oneShot bool
|
||||
consecutiveFailures, backoff, oneShot = c.nextStatusStreamBackoff(result, consecutiveFailures)
|
||||
if oneShot {
|
||||
c.logger.Debug("status stream for ", c.tag, " returned a single-frame response, retrying in ", backoff)
|
||||
} else {
|
||||
c.logger.Debug("status stream for ", c.tag, " disconnected: ", err, ", reconnecting in ", backoff)
|
||||
}
|
||||
timer := time.NewTimer(backoff)
|
||||
select {
|
||||
case <-timer.C:
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStreamResult, error) {
|
||||
startTime := time.Now()
|
||||
result := statusStreamResult{}
|
||||
response, err := c.doStreamStatusRequest(ctx)
|
||||
if err != nil {
|
||||
result.duration = time.Since(startTime)
|
||||
return result, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
result.duration = time.Since(startTime)
|
||||
return result, E.New("status ", response.StatusCode, " ", string(body))
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(response.Body)
|
||||
isStatusStream := response.Header.Get(statusStreamHeader) == "true"
|
||||
previousLastUpdated := c.lastUpdatedTime()
|
||||
var firstFrameUpdatedAt time.Time
|
||||
for {
|
||||
var statusResponse struct {
|
||||
FiveHourUtilization float64 `json:"five_hour_utilization"`
|
||||
WeeklyUtilization float64 `json:"weekly_utilization"`
|
||||
PlanWeight float64 `json:"plan_weight"`
|
||||
}
|
||||
err = decoder.Decode(&statusResponse)
|
||||
if err != nil {
|
||||
result.duration = time.Since(startTime)
|
||||
if result.frames == 1 && err == io.EOF && !isStatusStream {
|
||||
result.oneShot = true
|
||||
c.restoreLastUpdatedIfUnchanged(firstFrameUpdatedAt, previousLastUpdated)
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
c.stateAccess.Lock()
|
||||
c.state.consecutivePollFailures = 0
|
||||
c.state.fiveHourUtilization = statusResponse.FiveHourUtilization
|
||||
c.state.weeklyUtilization = statusResponse.WeeklyUtilization
|
||||
if statusResponse.PlanWeight > 0 {
|
||||
c.state.remotePlanWeight = statusResponse.PlanWeight
|
||||
}
|
||||
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
|
||||
c.state.hardRateLimited = false
|
||||
}
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
result.frames++
|
||||
updatedAt := c.markUsageStreamUpdated()
|
||||
if result.frames == 1 {
|
||||
firstFrameUpdatedAt = updatedAt
|
||||
}
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) nextStatusStreamBackoff(result statusStreamResult, consecutiveFailures int) (int, time.Duration, bool) {
|
||||
if result.oneShot {
|
||||
return 0, c.pollInterval, true
|
||||
}
|
||||
if result.duration >= connectorBackoffResetThreshold {
|
||||
consecutiveFailures = 0
|
||||
}
|
||||
consecutiveFailures++
|
||||
return consecutiveFailures, connectorBackoff(consecutiveFailures), false
|
||||
}
|
||||
|
||||
func (c *externalCredential) doStreamStatusRequest(ctx context.Context) (*http.Response, error) {
|
||||
buildRequest := func(baseURL string) (*http.Request, error) {
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/ccm/v1/status?watch=true", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header.Set("Authorization", "Bearer "+c.token)
|
||||
return request, nil
|
||||
}
|
||||
if c.reverseHTTPClient != nil {
|
||||
session := c.getReverseSession()
|
||||
if session != nil && !session.IsClosed() {
|
||||
request, err := buildRequest(reverseProxyBaseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
response, err := c.reverseHTTPClient.Do(request)
|
||||
if err == nil {
|
||||
return response, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
if c.forwardHTTPClient != nil {
|
||||
request, err := buildRequest(c.baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.forwardHTTPClient.Do(request)
|
||||
}
|
||||
return nil, E.New("no transport available")
|
||||
}
|
||||
|
||||
func (c *externalCredential) lastUpdatedTime() time.Time {
|
||||
@@ -603,6 +767,25 @@ func (c *externalCredential) lastUpdatedTime() time.Time {
|
||||
return c.state.lastUpdated
|
||||
}
|
||||
|
||||
func (c *externalCredential) markUsageStreamUpdated() time.Time {
|
||||
c.stateAccess.Lock()
|
||||
defer c.stateAccess.Unlock()
|
||||
now := time.Now()
|
||||
c.state.lastUpdated = now
|
||||
return now
|
||||
}
|
||||
|
||||
func (c *externalCredential) restoreLastUpdatedIfUnchanged(expectedCurrent time.Time, previous time.Time) {
|
||||
if expectedCurrent.IsZero() {
|
||||
return
|
||||
}
|
||||
c.stateAccess.Lock()
|
||||
defer c.stateAccess.Unlock()
|
||||
if c.state.lastUpdated.Equal(expectedCurrent) {
|
||||
c.state.lastUpdated = previous
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) markUsagePollAttempted() {
|
||||
c.stateAccess.Lock()
|
||||
defer c.stateAccess.Unlock()
|
||||
@@ -665,26 +848,40 @@ func (c *externalCredential) getReverseSession() *yamux.Session {
|
||||
}
|
||||
|
||||
func (c *externalCredential) setReverseSession(session *yamux.Session) bool {
|
||||
var emitStatus bool
|
||||
c.reverseAccess.Lock()
|
||||
if c.closed {
|
||||
c.reverseAccess.Unlock()
|
||||
return false
|
||||
}
|
||||
wasAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed()
|
||||
old := c.reverseSession
|
||||
c.reverseSession = session
|
||||
isAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed()
|
||||
emitStatus = wasAvailable != isAvailable
|
||||
c.reverseAccess.Unlock()
|
||||
if old != nil {
|
||||
old.Close()
|
||||
}
|
||||
if emitStatus {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *externalCredential) clearReverseSession(session *yamux.Session) {
|
||||
var emitStatus bool
|
||||
c.reverseAccess.Lock()
|
||||
wasAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed()
|
||||
if c.reverseSession == session {
|
||||
c.reverseSession = nil
|
||||
}
|
||||
isAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed()
|
||||
emitStatus = wasAvailable != isAvailable
|
||||
c.reverseAccess.Unlock()
|
||||
if emitStatus {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) getReverseContext() context.Context {
|
||||
|
||||
@@ -111,12 +111,18 @@ func (c *defaultCredential) reloadCredentials(force bool) error {
|
||||
c.access.Unlock()
|
||||
|
||||
c.stateAccess.Lock()
|
||||
wasAvailable, oldWeight := c.statusAggregateStateLocked()
|
||||
c.state.unavailable = false
|
||||
c.state.lastCredentialLoadError = ""
|
||||
c.state.accountType = credentials.SubscriptionType
|
||||
c.state.rateLimitTier = credentials.RateLimitTier
|
||||
c.checkTransitionLocked()
|
||||
isAvailable, newWeight := c.statusAggregateStateLocked()
|
||||
shouldEmit := wasAvailable != isAvailable || oldWeight != newWeight
|
||||
c.stateAccess.Unlock()
|
||||
if shouldEmit {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -128,16 +134,22 @@ func (c *defaultCredential) markCredentialsUnavailable(err error) error {
|
||||
c.access.Unlock()
|
||||
|
||||
c.stateAccess.Lock()
|
||||
wasAvailable, oldWeight := c.statusAggregateStateLocked()
|
||||
c.state.unavailable = true
|
||||
c.state.lastCredentialLoadError = err.Error()
|
||||
c.state.accountType = ""
|
||||
c.state.rateLimitTier = ""
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
isAvailable, newWeight := c.statusAggregateStateLocked()
|
||||
shouldEmit := wasAvailable != isAvailable || oldWeight != newWeight
|
||||
c.stateAccess.Unlock()
|
||||
|
||||
if shouldInterrupt && hadCredentials {
|
||||
c.interruptConnections()
|
||||
}
|
||||
if shouldEmit {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
280
service/ccm/credential_status_test.go
Normal file
280
service/ccm/credential_status_test.go
Normal file
@@ -0,0 +1,280 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing/common/observable"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
)
|
||||
|
||||
type roundTripperFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripperFunc) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||
return f(request)
|
||||
}
|
||||
|
||||
func drainStatusEvents(subscription observable.Subscription[struct{}]) int {
|
||||
var count int
|
||||
for {
|
||||
select {
|
||||
case <-subscription:
|
||||
count++
|
||||
default:
|
||||
return count
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newTestLogger() log.ContextLogger {
|
||||
return log.NewNOPFactory().Logger()
|
||||
}
|
||||
|
||||
func newTestCCMExternalCredential(t *testing.T, body string, headers http.Header) (*externalCredential, observable.Subscription[struct{}]) {
|
||||
t.Helper()
|
||||
subscriber := observable.NewSubscriber[struct{}](8)
|
||||
subscription, _ := subscriber.Subscription()
|
||||
credential := &externalCredential{
|
||||
tag: "test",
|
||||
baseURL: "http://example.com",
|
||||
token: "token",
|
||||
pollInterval: 25 * time.Millisecond,
|
||||
forwardHTTPClient: &http.Client{Transport: roundTripperFunc(func(request *http.Request) (*http.Response, error) {
|
||||
if request.URL.String() != "http://example.com/ccm/v1/status?watch=true" {
|
||||
t.Fatalf("unexpected request URL: %s", request.URL.String())
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: headers.Clone(),
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
}, nil
|
||||
})},
|
||||
logger: newTestLogger(),
|
||||
statusSubscriber: subscriber,
|
||||
}
|
||||
return credential, subscription
|
||||
}
|
||||
|
||||
func newTestYamuxSessionPair(t *testing.T) (*yamux.Session, *yamux.Session) {
|
||||
t.Helper()
|
||||
clientConn, serverConn := net.Pipe()
|
||||
clientSession, err := yamux.Client(clientConn, defaultYamuxConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("create yamux client: %v", err)
|
||||
}
|
||||
serverSession, err := yamux.Server(serverConn, defaultYamuxConfig)
|
||||
if err != nil {
|
||||
clientSession.Close()
|
||||
t.Fatalf("create yamux server: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
clientSession.Close()
|
||||
serverSession.Close()
|
||||
})
|
||||
return clientSession, serverSession
|
||||
}
|
||||
|
||||
func TestExternalCredentialConnectStatusStreamOneShotRestoresLastUpdated(t *testing.T) {
|
||||
credential, subscription := newTestCCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", nil)
|
||||
oldTime := time.Unix(123, 0)
|
||||
credential.stateAccess.Lock()
|
||||
credential.state.lastUpdated = oldTime
|
||||
credential.stateAccess.Unlock()
|
||||
|
||||
result, err := credential.connectStatusStream(context.Background())
|
||||
if err != io.EOF {
|
||||
t.Fatalf("expected EOF, got %v", err)
|
||||
}
|
||||
if !result.oneShot {
|
||||
t.Fatal("expected one-shot result")
|
||||
}
|
||||
if result.frames != 1 {
|
||||
t.Fatalf("expected 1 frame, got %d", result.frames)
|
||||
}
|
||||
if !credential.lastUpdatedTime().Equal(oldTime) {
|
||||
t.Fatalf("expected lastUpdated restored to %v, got %v", oldTime, credential.lastUpdatedTime())
|
||||
}
|
||||
if credential.fiveHourUtilization() != 12 || credential.weeklyUtilization() != 34 {
|
||||
t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization())
|
||||
}
|
||||
if count := drainStatusEvents(subscription); count != 1 {
|
||||
t.Fatalf("expected 1 status event, got %d", count)
|
||||
}
|
||||
|
||||
failures, backoff, oneShot := credential.nextStatusStreamBackoff(result, 3)
|
||||
if !oneShot {
|
||||
t.Fatal("expected one-shot backoff branch")
|
||||
}
|
||||
if failures != 0 {
|
||||
t.Fatalf("expected failures reset, got %d", failures)
|
||||
}
|
||||
if backoff != credential.pollInterval {
|
||||
t.Fatalf("expected poll interval backoff %v, got %v", credential.pollInterval, backoff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *testing.T) {
|
||||
headers := make(http.Header)
|
||||
headers.Set(statusStreamHeader, "true")
|
||||
credential, subscription := newTestCCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", headers)
|
||||
oldTime := time.Unix(123, 0)
|
||||
credential.stateAccess.Lock()
|
||||
credential.state.lastUpdated = oldTime
|
||||
credential.stateAccess.Unlock()
|
||||
|
||||
result, err := credential.connectStatusStream(context.Background())
|
||||
if err != io.EOF {
|
||||
t.Fatalf("expected EOF, got %v", err)
|
||||
}
|
||||
if result.oneShot {
|
||||
t.Fatal("did not expect one-shot result")
|
||||
}
|
||||
if result.frames != 1 {
|
||||
t.Fatalf("expected 1 frame, got %d", result.frames)
|
||||
}
|
||||
if credential.lastUpdatedTime().Equal(oldTime) {
|
||||
t.Fatal("expected lastUpdated to remain refreshed")
|
||||
}
|
||||
if credential.fiveHourUtilization() != 12 || credential.weeklyUtilization() != 34 {
|
||||
t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization())
|
||||
}
|
||||
if count := drainStatusEvents(subscription); count != 1 {
|
||||
t.Fatalf("expected 1 status event, got %d", count)
|
||||
}
|
||||
|
||||
failures, backoff, oneShot := credential.nextStatusStreamBackoff(result, 3)
|
||||
if oneShot {
|
||||
t.Fatal("did not expect one-shot backoff branch")
|
||||
}
|
||||
if failures != 4 {
|
||||
t.Fatalf("expected failures incremented to 4, got %d", failures)
|
||||
}
|
||||
if backoff < 16*time.Second || backoff >= 24*time.Second {
|
||||
t.Fatalf("expected connector backoff in [16s, 24s), got %v", backoff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalCredentialConnectStatusStreamMultiFrameKeepsLastUpdated(t *testing.T) {
|
||||
credential, subscription := newTestCCMExternalCredential(t, strings.Join([]string{
|
||||
"{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}",
|
||||
"{\"five_hour_utilization\":13,\"weekly_utilization\":35,\"plan_weight\":3}",
|
||||
}, "\n"), nil)
|
||||
oldTime := time.Unix(123, 0)
|
||||
credential.stateAccess.Lock()
|
||||
credential.state.lastUpdated = oldTime
|
||||
credential.stateAccess.Unlock()
|
||||
|
||||
result, err := credential.connectStatusStream(context.Background())
|
||||
if err != io.EOF {
|
||||
t.Fatalf("expected EOF, got %v", err)
|
||||
}
|
||||
if result.oneShot {
|
||||
t.Fatal("did not expect one-shot result")
|
||||
}
|
||||
if result.frames != 2 {
|
||||
t.Fatalf("expected 2 frames, got %d", result.frames)
|
||||
}
|
||||
if credential.lastUpdatedTime().Equal(oldTime) {
|
||||
t.Fatal("expected lastUpdated to remain refreshed")
|
||||
}
|
||||
if credential.fiveHourUtilization() != 13 || credential.weeklyUtilization() != 35 {
|
||||
t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization())
|
||||
}
|
||||
if count := drainStatusEvents(subscription); count != 2 {
|
||||
t.Fatalf("expected 2 status events, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultCredentialStatusChangesEmitStatus(t *testing.T) {
|
||||
credentialPath := filepath.Join(t.TempDir(), "credentials.json")
|
||||
err := os.WriteFile(credentialPath, []byte("{\"claudeAiOauth\":{\"accessToken\":\"token\",\"refreshToken\":\"\",\"expiresAt\":0,\"subscriptionType\":\"max\"}}\n"), 0o600)
|
||||
if err != nil {
|
||||
t.Fatalf("write credential file: %v", err)
|
||||
}
|
||||
|
||||
subscriber := observable.NewSubscriber[struct{}](8)
|
||||
subscription, _ := subscriber.Subscription()
|
||||
credential := &defaultCredential{
|
||||
tag: "test",
|
||||
credentialPath: credentialPath,
|
||||
logger: newTestLogger(),
|
||||
statusSubscriber: subscriber,
|
||||
}
|
||||
|
||||
err = credential.markCredentialsUnavailable(errors.New("boom"))
|
||||
if err == nil {
|
||||
t.Fatal("expected error from markCredentialsUnavailable")
|
||||
}
|
||||
if count := drainStatusEvents(subscription); count != 1 {
|
||||
t.Fatalf("expected 1 status event after unavailable transition, got %d", count)
|
||||
}
|
||||
|
||||
err = credential.reloadCredentials(true)
|
||||
if err != nil {
|
||||
t.Fatalf("reload credentials: %v", err)
|
||||
}
|
||||
if count := drainStatusEvents(subscription); count != 1 {
|
||||
t.Fatalf("expected 1 status event after recovery, got %d", count)
|
||||
}
|
||||
if weight := credential.planWeight(); weight != 5 {
|
||||
t.Fatalf("expected initial max weight 5, got %v", weight)
|
||||
}
|
||||
|
||||
profileClient := &http.Client{Transport: roundTripperFunc(func(request *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
"{\"organization\":{\"organization_type\":\"claude_max\",\"rate_limit_tier\":\"default_claude_max_20x\"}}",
|
||||
)),
|
||||
}, nil
|
||||
})}
|
||||
credential.fetchProfile(context.Background(), profileClient, "token")
|
||||
if count := drainStatusEvents(subscription); count != 1 {
|
||||
t.Fatalf("expected 1 status event after weight change, got %d", count)
|
||||
}
|
||||
if weight := credential.planWeight(); weight != 10 {
|
||||
t.Fatalf("expected upgraded max weight 10, got %v", weight)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalCredentialReverseSessionChangesEmitStatus(t *testing.T) {
|
||||
subscriber := observable.NewSubscriber[struct{}](8)
|
||||
subscription, _ := subscriber.Subscription()
|
||||
credential := &externalCredential{
|
||||
tag: "receiver",
|
||||
baseURL: reverseProxyBaseURL,
|
||||
pollInterval: time.Minute,
|
||||
logger: newTestLogger(),
|
||||
statusSubscriber: subscriber,
|
||||
}
|
||||
|
||||
clientSession, _ := newTestYamuxSessionPair(t)
|
||||
if !credential.setReverseSession(clientSession) {
|
||||
t.Fatal("expected reverse session to be accepted")
|
||||
}
|
||||
if count := drainStatusEvents(subscription); count != 1 {
|
||||
t.Fatalf("expected 1 status event after reverse session up, got %d", count)
|
||||
}
|
||||
if !credential.isAvailable() {
|
||||
t.Fatal("expected receiver credential to become available")
|
||||
}
|
||||
|
||||
credential.clearReverseSession(clientSession)
|
||||
if count := drainStatusEvents(subscription); count != 1 {
|
||||
t.Fatalf("expected 1 status event after reverse session down, got %d", count)
|
||||
}
|
||||
if credential.isAvailable() {
|
||||
t.Fatal("expected receiver credential to become unavailable")
|
||||
}
|
||||
}
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
boxService "github.com/sagernet/sing-box/adapter/service"
|
||||
"github.com/sagernet/sing-box/common/listener"
|
||||
@@ -17,6 +16,7 @@ import (
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/observable"
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
@@ -150,18 +150,21 @@ func isAPIKeyHeader(header string) bool {
|
||||
|
||||
type Service struct {
|
||||
boxService.Adapter
|
||||
ctx context.Context
|
||||
logger log.ContextLogger
|
||||
options option.CCMServiceOptions
|
||||
httpHeaders http.Header
|
||||
listener *listener.Listener
|
||||
tlsConfig tls.ServerConfig
|
||||
httpServer *http.Server
|
||||
ctx context.Context
|
||||
logger log.ContextLogger
|
||||
options option.CCMServiceOptions
|
||||
httpHeaders http.Header
|
||||
listener *listener.Listener
|
||||
tlsConfig tls.ServerConfig
|
||||
httpServer *http.Server
|
||||
userManager *UserManager
|
||||
|
||||
providers map[string]credentialProvider
|
||||
allCredentials []Credential
|
||||
userConfigMap map[string]*option.CCMUser
|
||||
|
||||
statusSubscriber *observable.Subscriber[struct{}]
|
||||
statusObserver *observable.Observer[struct{}]
|
||||
}
|
||||
|
||||
func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.CCMServiceOptions) (adapter.Service, error) {
|
||||
@@ -195,6 +198,7 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
|
||||
tokenMap: make(map[string]string),
|
||||
}
|
||||
|
||||
statusSubscriber := observable.NewSubscriber[struct{}](16)
|
||||
service := &Service{
|
||||
Adapter: boxService.NewAdapter(C.TypeCCM, tag),
|
||||
ctx: ctx,
|
||||
@@ -207,7 +211,9 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
|
||||
Network: []string{N.NetworkTCP},
|
||||
Listen: options.ListenOptions,
|
||||
}),
|
||||
userManager: userManager,
|
||||
userManager: userManager,
|
||||
statusSubscriber: statusSubscriber,
|
||||
statusObserver: observable.NewObserver[struct{}](statusSubscriber, 8),
|
||||
}
|
||||
|
||||
providers, allCredentials, err := buildCredentialProviders(ctx, options, logger)
|
||||
@@ -242,6 +248,7 @@ func (s *Service) Start(stage adapter.StartStage) error {
|
||||
s.userManager.UpdateUsers(s.options.Users)
|
||||
|
||||
for _, credential := range s.allCredentials {
|
||||
credential.setStatusSubscriber(s.statusSubscriber)
|
||||
if external, ok := credential.(*externalCredential); ok && external.reverse && external.connectorURL != nil {
|
||||
external.reverseService = s
|
||||
}
|
||||
@@ -300,6 +307,7 @@ func (s *Service) InterfaceUpdated() {
|
||||
}
|
||||
|
||||
func (s *Service) Close() error {
|
||||
s.statusObserver.Close()
|
||||
err := common.Close(
|
||||
common.PtrOrNil(s.httpServer),
|
||||
common.PtrOrNil(s.listener),
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@@ -55,6 +56,11 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Query().Get("watch") == "true" {
|
||||
s.handleStatusStream(w, r, provider, userConfig)
|
||||
return
|
||||
}
|
||||
|
||||
provider.pollIfStale(r.Context())
|
||||
avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
|
||||
|
||||
@@ -67,6 +73,76 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, provider credentialProvider, userConfig *option.CCMUser) {
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "streaming not supported")
|
||||
return
|
||||
}
|
||||
|
||||
subscription, done, err := s.statusObserver.Subscribe()
|
||||
if err != nil {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "service closing")
|
||||
return
|
||||
}
|
||||
defer s.statusObserver.UnSubscribe(subscription)
|
||||
|
||||
provider.pollIfStale(r.Context())
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set(statusStreamHeader, "true")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
lastFiveHour, lastWeekly, lastWeight := s.computeAggregatedUtilization(provider, userConfig)
|
||||
buf := &bytes.Buffer{}
|
||||
json.NewEncoder(buf).Encode(map[string]float64{
|
||||
"five_hour_utilization": lastFiveHour,
|
||||
"weekly_utilization": lastWeekly,
|
||||
"plan_weight": lastWeight,
|
||||
})
|
||||
_, writeErr := w.Write(buf.Bytes())
|
||||
if writeErr != nil {
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
case <-done:
|
||||
return
|
||||
case <-subscription:
|
||||
for {
|
||||
select {
|
||||
case <-subscription:
|
||||
default:
|
||||
goto drained
|
||||
}
|
||||
}
|
||||
drained:
|
||||
fiveHour, weekly, weight := s.computeAggregatedUtilization(provider, userConfig)
|
||||
if fiveHour == lastFiveHour && weekly == lastWeekly && weight == lastWeight {
|
||||
continue
|
||||
}
|
||||
lastFiveHour = fiveHour
|
||||
lastWeekly = weekly
|
||||
lastWeight = weight
|
||||
buf.Reset()
|
||||
json.NewEncoder(buf).Encode(map[string]float64{
|
||||
"five_hour_utilization": fiveHour,
|
||||
"weekly_utilization": weekly,
|
||||
"plan_weight": weight,
|
||||
})
|
||||
_, writeErr = w.Write(buf.Bytes())
|
||||
if writeErr != nil {
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) (float64, float64, float64) {
|
||||
var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64
|
||||
for _, credential := range provider.allCredentials() {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/observable"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -118,6 +119,7 @@ type Credential interface {
|
||||
interruptConnections()
|
||||
|
||||
setOnBecameUnusable(fn func())
|
||||
setStatusSubscriber(*observable.Subscriber[struct{}])
|
||||
start() error
|
||||
pollUsage(ctx context.Context)
|
||||
lastUpdatedTime() time.Time
|
||||
|
||||
@@ -192,4 +192,3 @@ func credentialForUser(
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
"github.com/sagernet/sing/common/observable"
|
||||
)
|
||||
|
||||
type defaultCredential struct {
|
||||
@@ -45,6 +46,8 @@ type defaultCredential struct {
|
||||
watcher *fswatch.Watcher
|
||||
watcherRetryAt time.Time
|
||||
|
||||
statusSubscriber *observable.Subscriber[struct{}]
|
||||
|
||||
// Connection interruption
|
||||
onBecameUnusable func()
|
||||
interrupted bool
|
||||
@@ -147,6 +150,16 @@ func (c *defaultCredential) setOnBecameUnusable(fn func()) {
|
||||
c.onBecameUnusable = fn
|
||||
}
|
||||
|
||||
func (c *defaultCredential) setStatusSubscriber(subscriber *observable.Subscriber[struct{}]) {
|
||||
c.statusSubscriber = subscriber
|
||||
}
|
||||
|
||||
func (c *defaultCredential) emitStatusUpdate() {
|
||||
if c.statusSubscriber != nil {
|
||||
c.statusSubscriber.Emit(struct{}{})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *defaultCredential) tagName() string {
|
||||
return c.tag
|
||||
}
|
||||
@@ -202,11 +215,16 @@ func (c *defaultCredential) getAccessToken() (string, error) {
|
||||
if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) {
|
||||
c.credentials = latestCredentials
|
||||
c.stateAccess.Lock()
|
||||
wasAvailable := !c.state.unavailable
|
||||
c.state.unavailable = false
|
||||
c.state.lastCredentialLoadAttempt = time.Now()
|
||||
c.state.lastCredentialLoadError = ""
|
||||
c.checkTransitionLocked()
|
||||
shouldEmit := wasAvailable != !c.state.unavailable
|
||||
c.stateAccess.Unlock()
|
||||
if shouldEmit {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
if !latestCredentials.needsRefresh() {
|
||||
return latestCredentials.getAccessToken(), nil
|
||||
}
|
||||
@@ -215,11 +233,16 @@ func (c *defaultCredential) getAccessToken() (string, error) {
|
||||
|
||||
c.credentials = newCredentials
|
||||
c.stateAccess.Lock()
|
||||
wasAvailable := !c.state.unavailable
|
||||
c.state.unavailable = false
|
||||
c.state.lastCredentialLoadAttempt = time.Now()
|
||||
c.state.lastCredentialLoadError = ""
|
||||
c.checkTransitionLocked()
|
||||
shouldEmit := wasAvailable != !c.state.unavailable
|
||||
c.stateAccess.Unlock()
|
||||
if shouldEmit {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
|
||||
err = platformWriteCredentials(newCredentials, c.credentialPath)
|
||||
if err != nil {
|
||||
@@ -329,6 +352,9 @@ func (c *defaultCredential) updateStateFromHeaders(headers http.Header) {
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
if hadData {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *defaultCredential) markRateLimited(resetAt time.Time) {
|
||||
@@ -341,6 +367,7 @@ func (c *defaultCredential) markRateLimited(resetAt time.Time) {
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
|
||||
func (c *defaultCredential) isUsable() bool {
|
||||
@@ -692,6 +719,7 @@ func (c *defaultCredential) pollUsage(ctx context.Context) {
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
|
||||
func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) {
|
||||
|
||||
@@ -23,11 +23,15 @@ import (
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
"github.com/sagernet/sing/common/observable"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
)
|
||||
|
||||
const reverseProxyBaseURL = "http://reverse-proxy"
|
||||
const (
|
||||
reverseProxyBaseURL = "http://reverse-proxy"
|
||||
statusStreamHeader = "X-OCM-Status-Stream"
|
||||
)
|
||||
|
||||
type externalCredential struct {
|
||||
tag string
|
||||
@@ -42,6 +46,7 @@ type externalCredential struct {
|
||||
usageTracker *AggregatedUsage
|
||||
logger log.ContextLogger
|
||||
|
||||
statusSubscriber *observable.Subscriber[struct{}]
|
||||
onBecameUnusable func()
|
||||
interrupted bool
|
||||
requestContext context.Context
|
||||
@@ -69,6 +74,12 @@ type reverseSessionDialer struct {
|
||||
credential *externalCredential
|
||||
}
|
||||
|
||||
type statusStreamResult struct {
|
||||
duration time.Duration
|
||||
frames int
|
||||
oneShot bool
|
||||
}
|
||||
|
||||
func (d reverseSessionDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
if N.NetworkName(network) != N.NetworkTCP {
|
||||
return nil, os.ErrInvalid
|
||||
@@ -249,6 +260,8 @@ func (c *externalCredential) start() error {
|
||||
}
|
||||
if c.reverse && c.connectorURL != nil {
|
||||
go c.connectorLoop()
|
||||
} else {
|
||||
go c.statusStreamLoop()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -257,6 +270,16 @@ func (c *externalCredential) setOnBecameUnusable(fn func()) {
|
||||
c.onBecameUnusable = fn
|
||||
}
|
||||
|
||||
func (c *externalCredential) setStatusSubscriber(subscriber *observable.Subscriber[struct{}]) {
|
||||
c.statusSubscriber = subscriber
|
||||
}
|
||||
|
||||
func (c *externalCredential) emitStatusUpdate() {
|
||||
if c.statusSubscriber != nil {
|
||||
c.statusSubscriber.Emit(struct{}{})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) tagName() string {
|
||||
return c.tag
|
||||
}
|
||||
@@ -342,6 +365,7 @@ func (c *externalCredential) markRateLimited(resetAt time.Time) {
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
|
||||
func (c *externalCredential) earliestReset() time.Time {
|
||||
@@ -498,6 +522,9 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) {
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
if hadData {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) checkTransitionLocked() bool {
|
||||
@@ -638,6 +665,142 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
|
||||
func (c *externalCredential) statusStreamLoop() {
|
||||
var consecutiveFailures int
|
||||
ctx := c.getReverseContext()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
result, err := c.connectStatusStream(ctx)
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
var backoff time.Duration
|
||||
var oneShot bool
|
||||
consecutiveFailures, backoff, oneShot = c.nextStatusStreamBackoff(result, consecutiveFailures)
|
||||
if oneShot {
|
||||
c.logger.Debug("status stream for ", c.tag, " returned a single-frame response, retrying in ", backoff)
|
||||
} else {
|
||||
c.logger.Debug("status stream for ", c.tag, " disconnected: ", err, ", reconnecting in ", backoff)
|
||||
}
|
||||
timer := time.NewTimer(backoff)
|
||||
select {
|
||||
case <-timer.C:
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStreamResult, error) {
|
||||
startTime := time.Now()
|
||||
result := statusStreamResult{}
|
||||
response, err := c.doStreamStatusRequest(ctx)
|
||||
if err != nil {
|
||||
result.duration = time.Since(startTime)
|
||||
return result, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
result.duration = time.Since(startTime)
|
||||
return result, E.New("status ", response.StatusCode, " ", string(body))
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(response.Body)
|
||||
isStatusStream := response.Header.Get(statusStreamHeader) == "true"
|
||||
previousLastUpdated := c.lastUpdatedTime()
|
||||
var firstFrameUpdatedAt time.Time
|
||||
for {
|
||||
var statusResponse struct {
|
||||
FiveHourUtilization float64 `json:"five_hour_utilization"`
|
||||
WeeklyUtilization float64 `json:"weekly_utilization"`
|
||||
PlanWeight float64 `json:"plan_weight"`
|
||||
}
|
||||
err = decoder.Decode(&statusResponse)
|
||||
if err != nil {
|
||||
result.duration = time.Since(startTime)
|
||||
if result.frames == 1 && err == io.EOF && !isStatusStream {
|
||||
result.oneShot = true
|
||||
c.restoreLastUpdatedIfUnchanged(firstFrameUpdatedAt, previousLastUpdated)
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
c.stateAccess.Lock()
|
||||
c.state.consecutivePollFailures = 0
|
||||
c.state.fiveHourUtilization = statusResponse.FiveHourUtilization
|
||||
c.state.weeklyUtilization = statusResponse.WeeklyUtilization
|
||||
if statusResponse.PlanWeight > 0 {
|
||||
c.state.remotePlanWeight = statusResponse.PlanWeight
|
||||
}
|
||||
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
|
||||
c.state.hardRateLimited = false
|
||||
}
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
result.frames++
|
||||
updatedAt := c.markUsageStreamUpdated()
|
||||
if result.frames == 1 {
|
||||
firstFrameUpdatedAt = updatedAt
|
||||
}
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) nextStatusStreamBackoff(result statusStreamResult, consecutiveFailures int) (int, time.Duration, bool) {
|
||||
if result.oneShot {
|
||||
return 0, c.pollInterval, true
|
||||
}
|
||||
if result.duration >= connectorBackoffResetThreshold {
|
||||
consecutiveFailures = 0
|
||||
}
|
||||
consecutiveFailures++
|
||||
return consecutiveFailures, connectorBackoff(consecutiveFailures), false
|
||||
}
|
||||
|
||||
func (c *externalCredential) doStreamStatusRequest(ctx context.Context) (*http.Response, error) {
|
||||
buildRequest := func(baseURL string) (*http.Request, error) {
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/ocm/v1/status?watch=true", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header.Set("Authorization", "Bearer "+c.token)
|
||||
return request, nil
|
||||
}
|
||||
if c.reverseHTTPClient != nil {
|
||||
session := c.getReverseSession()
|
||||
if session != nil && !session.IsClosed() {
|
||||
request, err := buildRequest(reverseProxyBaseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
response, err := c.reverseHTTPClient.Do(request)
|
||||
if err == nil {
|
||||
return response, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
if c.forwardHTTPClient != nil {
|
||||
request, err := buildRequest(c.baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.forwardHTTPClient.Do(request)
|
||||
}
|
||||
return nil, E.New("no transport available")
|
||||
}
|
||||
|
||||
func (c *externalCredential) lastUpdatedTime() time.Time {
|
||||
@@ -646,6 +809,25 @@ func (c *externalCredential) lastUpdatedTime() time.Time {
|
||||
return c.state.lastUpdated
|
||||
}
|
||||
|
||||
func (c *externalCredential) markUsageStreamUpdated() time.Time {
|
||||
c.stateAccess.Lock()
|
||||
defer c.stateAccess.Unlock()
|
||||
now := time.Now()
|
||||
c.state.lastUpdated = now
|
||||
return now
|
||||
}
|
||||
|
||||
func (c *externalCredential) restoreLastUpdatedIfUnchanged(expectedCurrent time.Time, previous time.Time) {
|
||||
if expectedCurrent.IsZero() {
|
||||
return
|
||||
}
|
||||
c.stateAccess.Lock()
|
||||
defer c.stateAccess.Unlock()
|
||||
if c.state.lastUpdated.Equal(expectedCurrent) {
|
||||
c.state.lastUpdated = previous
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) markUsagePollAttempted() {
|
||||
c.stateAccess.Lock()
|
||||
defer c.stateAccess.Unlock()
|
||||
@@ -736,26 +918,40 @@ func (c *externalCredential) getReverseSession() *yamux.Session {
|
||||
}
|
||||
|
||||
func (c *externalCredential) setReverseSession(session *yamux.Session) bool {
|
||||
var emitStatus bool
|
||||
c.reverseAccess.Lock()
|
||||
if c.closed {
|
||||
c.reverseAccess.Unlock()
|
||||
return false
|
||||
}
|
||||
wasAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed()
|
||||
old := c.reverseSession
|
||||
c.reverseSession = session
|
||||
isAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed()
|
||||
emitStatus = wasAvailable != isAvailable
|
||||
c.reverseAccess.Unlock()
|
||||
if old != nil {
|
||||
old.Close()
|
||||
}
|
||||
if emitStatus {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *externalCredential) clearReverseSession(session *yamux.Session) {
|
||||
var emitStatus bool
|
||||
c.reverseAccess.Lock()
|
||||
wasAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed()
|
||||
if c.reverseSession == session {
|
||||
c.reverseSession = nil
|
||||
}
|
||||
isAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed()
|
||||
emitStatus = wasAvailable != isAvailable
|
||||
c.reverseAccess.Unlock()
|
||||
if emitStatus {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) getReverseContext() context.Context {
|
||||
|
||||
@@ -111,10 +111,15 @@ func (c *defaultCredential) reloadCredentials(force bool) error {
|
||||
c.access.Unlock()
|
||||
|
||||
c.stateAccess.Lock()
|
||||
wasAvailable := !c.state.unavailable
|
||||
c.state.unavailable = false
|
||||
c.state.lastCredentialLoadError = ""
|
||||
c.checkTransitionLocked()
|
||||
shouldEmit := wasAvailable != !c.state.unavailable
|
||||
c.stateAccess.Unlock()
|
||||
if shouldEmit {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -126,14 +131,19 @@ func (c *defaultCredential) markCredentialsUnavailable(err error) error {
|
||||
c.access.Unlock()
|
||||
|
||||
c.stateAccess.Lock()
|
||||
wasAvailable := !c.state.unavailable
|
||||
c.state.unavailable = true
|
||||
c.state.lastCredentialLoadError = err.Error()
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
shouldEmit := wasAvailable != !c.state.unavailable
|
||||
c.stateAccess.Unlock()
|
||||
|
||||
if shouldInterrupt && hadCredentials {
|
||||
c.interruptConnections()
|
||||
}
|
||||
if shouldEmit {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
263
service/ocm/credential_status_test.go
Normal file
263
service/ocm/credential_status_test.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing/common/observable"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
)
|
||||
|
||||
type roundTripperFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripperFunc) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||
return f(request)
|
||||
}
|
||||
|
||||
func drainStatusEvents(subscription observable.Subscription[struct{}]) int {
|
||||
var count int
|
||||
for {
|
||||
select {
|
||||
case <-subscription:
|
||||
count++
|
||||
default:
|
||||
return count
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newTestLogger() log.ContextLogger {
|
||||
return log.NewNOPFactory().Logger()
|
||||
}
|
||||
|
||||
func newTestOCMExternalCredential(t *testing.T, body string, headers http.Header) (*externalCredential, observable.Subscription[struct{}]) {
|
||||
t.Helper()
|
||||
subscriber := observable.NewSubscriber[struct{}](8)
|
||||
subscription, _ := subscriber.Subscription()
|
||||
credential := &externalCredential{
|
||||
tag: "test",
|
||||
baseURL: "http://example.com",
|
||||
token: "token",
|
||||
pollInterval: 25 * time.Millisecond,
|
||||
forwardHTTPClient: &http.Client{Transport: roundTripperFunc(func(request *http.Request) (*http.Response, error) {
|
||||
if request.URL.String() != "http://example.com/ocm/v1/status?watch=true" {
|
||||
t.Fatalf("unexpected request URL: %s", request.URL.String())
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: headers.Clone(),
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
}, nil
|
||||
})},
|
||||
logger: newTestLogger(),
|
||||
statusSubscriber: subscriber,
|
||||
}
|
||||
return credential, subscription
|
||||
}
|
||||
|
||||
func newTestYamuxSessionPair(t *testing.T) (*yamux.Session, *yamux.Session) {
|
||||
t.Helper()
|
||||
clientConn, serverConn := net.Pipe()
|
||||
clientSession, err := yamux.Client(clientConn, defaultYamuxConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("create yamux client: %v", err)
|
||||
}
|
||||
serverSession, err := yamux.Server(serverConn, defaultYamuxConfig)
|
||||
if err != nil {
|
||||
clientSession.Close()
|
||||
t.Fatalf("create yamux server: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
clientSession.Close()
|
||||
serverSession.Close()
|
||||
})
|
||||
return clientSession, serverSession
|
||||
}
|
||||
|
||||
func TestExternalCredentialConnectStatusStreamOneShotRestoresLastUpdated(t *testing.T) {
|
||||
credential, subscription := newTestOCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", nil)
|
||||
oldTime := time.Unix(123, 0)
|
||||
credential.stateAccess.Lock()
|
||||
credential.state.lastUpdated = oldTime
|
||||
credential.stateAccess.Unlock()
|
||||
|
||||
result, err := credential.connectStatusStream(context.Background())
|
||||
if err != io.EOF {
|
||||
t.Fatalf("expected EOF, got %v", err)
|
||||
}
|
||||
if !result.oneShot {
|
||||
t.Fatal("expected one-shot result")
|
||||
}
|
||||
if result.frames != 1 {
|
||||
t.Fatalf("expected 1 frame, got %d", result.frames)
|
||||
}
|
||||
if !credential.lastUpdatedTime().Equal(oldTime) {
|
||||
t.Fatalf("expected lastUpdated restored to %v, got %v", oldTime, credential.lastUpdatedTime())
|
||||
}
|
||||
if credential.fiveHourUtilization() != 12 || credential.weeklyUtilization() != 34 {
|
||||
t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization())
|
||||
}
|
||||
if count := drainStatusEvents(subscription); count != 1 {
|
||||
t.Fatalf("expected 1 status event, got %d", count)
|
||||
}
|
||||
|
||||
failures, backoff, oneShot := credential.nextStatusStreamBackoff(result, 3)
|
||||
if !oneShot {
|
||||
t.Fatal("expected one-shot backoff branch")
|
||||
}
|
||||
if failures != 0 {
|
||||
t.Fatalf("expected failures reset, got %d", failures)
|
||||
}
|
||||
if backoff != credential.pollInterval {
|
||||
t.Fatalf("expected poll interval backoff %v, got %v", credential.pollInterval, backoff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalCredentialConnectStatusStreamSingleFrameStreamReconnects(t *testing.T) {
|
||||
headers := make(http.Header)
|
||||
headers.Set(statusStreamHeader, "true")
|
||||
credential, subscription := newTestOCMExternalCredential(t, "{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}\n", headers)
|
||||
oldTime := time.Unix(123, 0)
|
||||
credential.stateAccess.Lock()
|
||||
credential.state.lastUpdated = oldTime
|
||||
credential.stateAccess.Unlock()
|
||||
|
||||
result, err := credential.connectStatusStream(context.Background())
|
||||
if err != io.EOF {
|
||||
t.Fatalf("expected EOF, got %v", err)
|
||||
}
|
||||
if result.oneShot {
|
||||
t.Fatal("did not expect one-shot result")
|
||||
}
|
||||
if result.frames != 1 {
|
||||
t.Fatalf("expected 1 frame, got %d", result.frames)
|
||||
}
|
||||
if credential.lastUpdatedTime().Equal(oldTime) {
|
||||
t.Fatal("expected lastUpdated to remain refreshed")
|
||||
}
|
||||
if credential.fiveHourUtilization() != 12 || credential.weeklyUtilization() != 34 {
|
||||
t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization())
|
||||
}
|
||||
if count := drainStatusEvents(subscription); count != 1 {
|
||||
t.Fatalf("expected 1 status event, got %d", count)
|
||||
}
|
||||
|
||||
failures, backoff, oneShot := credential.nextStatusStreamBackoff(result, 3)
|
||||
if oneShot {
|
||||
t.Fatal("did not expect one-shot backoff branch")
|
||||
}
|
||||
if failures != 4 {
|
||||
t.Fatalf("expected failures incremented to 4, got %d", failures)
|
||||
}
|
||||
if backoff < 16*time.Second || backoff >= 24*time.Second {
|
||||
t.Fatalf("expected connector backoff in [16s, 24s), got %v", backoff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalCredentialConnectStatusStreamMultiFrameKeepsLastUpdated(t *testing.T) {
|
||||
credential, subscription := newTestOCMExternalCredential(t, strings.Join([]string{
|
||||
"{\"five_hour_utilization\":12,\"weekly_utilization\":34,\"plan_weight\":2}",
|
||||
"{\"five_hour_utilization\":13,\"weekly_utilization\":35,\"plan_weight\":3}",
|
||||
}, "\n"), nil)
|
||||
oldTime := time.Unix(123, 0)
|
||||
credential.stateAccess.Lock()
|
||||
credential.state.lastUpdated = oldTime
|
||||
credential.stateAccess.Unlock()
|
||||
|
||||
result, err := credential.connectStatusStream(context.Background())
|
||||
if err != io.EOF {
|
||||
t.Fatalf("expected EOF, got %v", err)
|
||||
}
|
||||
if result.oneShot {
|
||||
t.Fatal("did not expect one-shot result")
|
||||
}
|
||||
if result.frames != 2 {
|
||||
t.Fatalf("expected 2 frames, got %d", result.frames)
|
||||
}
|
||||
if credential.lastUpdatedTime().Equal(oldTime) {
|
||||
t.Fatal("expected lastUpdated to remain refreshed")
|
||||
}
|
||||
if credential.fiveHourUtilization() != 13 || credential.weeklyUtilization() != 35 {
|
||||
t.Fatalf("unexpected utilizations: 5h=%v weekly=%v", credential.fiveHourUtilization(), credential.weeklyUtilization())
|
||||
}
|
||||
if count := drainStatusEvents(subscription); count != 2 {
|
||||
t.Fatalf("expected 2 status events, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultCredentialAvailabilityChangesEmitStatus(t *testing.T) {
|
||||
credentialPath := filepath.Join(t.TempDir(), "auth.json")
|
||||
err := os.WriteFile(credentialPath, []byte("{\"OPENAI_API_KEY\":\"sk-test\"}\n"), 0o600)
|
||||
if err != nil {
|
||||
t.Fatalf("write credential file: %v", err)
|
||||
}
|
||||
|
||||
subscriber := observable.NewSubscriber[struct{}](8)
|
||||
subscription, _ := subscriber.Subscription()
|
||||
credential := &defaultCredential{
|
||||
tag: "test",
|
||||
credentialPath: credentialPath,
|
||||
logger: newTestLogger(),
|
||||
statusSubscriber: subscriber,
|
||||
}
|
||||
|
||||
err = credential.markCredentialsUnavailable(errors.New("boom"))
|
||||
if err == nil {
|
||||
t.Fatal("expected error from markCredentialsUnavailable")
|
||||
}
|
||||
if count := drainStatusEvents(subscription); count != 1 {
|
||||
t.Fatalf("expected 1 status event after unavailable transition, got %d", count)
|
||||
}
|
||||
|
||||
err = credential.reloadCredentials(true)
|
||||
if err != nil {
|
||||
t.Fatalf("reload credentials: %v", err)
|
||||
}
|
||||
if count := drainStatusEvents(subscription); count != 1 {
|
||||
t.Fatalf("expected 1 status event after recovery, got %d", count)
|
||||
}
|
||||
if !credential.isAvailable() {
|
||||
t.Fatal("expected credential to become available")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalCredentialReverseSessionChangesEmitStatus(t *testing.T) {
|
||||
subscriber := observable.NewSubscriber[struct{}](8)
|
||||
subscription, _ := subscriber.Subscription()
|
||||
credential := &externalCredential{
|
||||
tag: "receiver",
|
||||
baseURL: reverseProxyBaseURL,
|
||||
pollInterval: time.Minute,
|
||||
logger: newTestLogger(),
|
||||
statusSubscriber: subscriber,
|
||||
}
|
||||
|
||||
clientSession, _ := newTestYamuxSessionPair(t)
|
||||
if !credential.setReverseSession(clientSession) {
|
||||
t.Fatal("expected reverse session to be accepted")
|
||||
}
|
||||
if count := drainStatusEvents(subscription); count != 1 {
|
||||
t.Fatalf("expected 1 status event after reverse session up, got %d", count)
|
||||
}
|
||||
if !credential.isAvailable() {
|
||||
t.Fatal("expected receiver credential to become available")
|
||||
}
|
||||
|
||||
credential.clearReverseSession(clientSession)
|
||||
if count := drainStatusEvents(subscription); count != 1 {
|
||||
t.Fatalf("expected 1 status event after reverse session down, got %d", count)
|
||||
}
|
||||
if credential.isAvailable() {
|
||||
t.Fatal("expected receiver credential to become unavailable")
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/observable"
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
@@ -176,9 +177,11 @@ type Service struct {
|
||||
webSocketConns map[*webSocketSession]struct{}
|
||||
shuttingDown bool
|
||||
|
||||
providers map[string]credentialProvider
|
||||
allCredentials []Credential
|
||||
userConfigMap map[string]*option.OCMUser
|
||||
providers map[string]credentialProvider
|
||||
allCredentials []Credential
|
||||
userConfigMap map[string]*option.OCMUser
|
||||
statusSubscriber *observable.Subscriber[struct{}]
|
||||
statusObserver *observable.Observer[struct{}]
|
||||
}
|
||||
|
||||
func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.OCMServiceOptions) (adapter.Service, error) {
|
||||
@@ -210,6 +213,8 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
|
||||
tokenMap: make(map[string]string),
|
||||
}
|
||||
|
||||
statusSubscriber := observable.NewSubscriber[struct{}](16)
|
||||
|
||||
service := &Service{
|
||||
Adapter: boxService.NewAdapter(C.TypeOCM, tag),
|
||||
ctx: ctx,
|
||||
@@ -222,8 +227,10 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
|
||||
Network: []string{N.NetworkTCP},
|
||||
Listen: options.ListenOptions,
|
||||
}),
|
||||
userManager: userManager,
|
||||
webSocketConns: make(map[*webSocketSession]struct{}),
|
||||
userManager: userManager,
|
||||
statusSubscriber: statusSubscriber,
|
||||
statusObserver: observable.NewObserver[struct{}](statusSubscriber, 8),
|
||||
webSocketConns: make(map[*webSocketSession]struct{}),
|
||||
}
|
||||
|
||||
providers, allCredentials, err := buildOCMCredentialProviders(ctx, options, logger)
|
||||
@@ -258,6 +265,7 @@ func (s *Service) Start(stage adapter.StartStage) error {
|
||||
s.userManager.UpdateUsers(s.options.Users)
|
||||
|
||||
for _, credential := range s.allCredentials {
|
||||
credential.setStatusSubscriber(s.statusSubscriber)
|
||||
if external, ok := credential.(*externalCredential); ok && external.reverse && external.connectorURL != nil {
|
||||
external.reverseService = s
|
||||
}
|
||||
@@ -324,6 +332,7 @@ func (s *Service) InterfaceUpdated() {
|
||||
}
|
||||
|
||||
func (s *Service) Close() error {
|
||||
s.statusObserver.Close()
|
||||
webSocketSessions := s.startWebSocketShutdown()
|
||||
|
||||
err := common.Close(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@@ -55,6 +56,11 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Query().Get("watch") == "true" {
|
||||
s.handleStatusStream(w, r, provider, userConfig)
|
||||
return
|
||||
}
|
||||
|
||||
provider.pollIfStale(r.Context())
|
||||
avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
|
||||
|
||||
@@ -67,6 +73,76 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, provider credentialProvider, userConfig *option.OCMUser) {
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "streaming not supported")
|
||||
return
|
||||
}
|
||||
|
||||
subscription, done, err := s.statusObserver.Subscribe()
|
||||
if err != nil {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "service closing")
|
||||
return
|
||||
}
|
||||
defer s.statusObserver.UnSubscribe(subscription)
|
||||
|
||||
provider.pollIfStale(r.Context())
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set(statusStreamHeader, "true")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
lastFiveHour, lastWeekly, lastWeight := s.computeAggregatedUtilization(provider, userConfig)
|
||||
buf := &bytes.Buffer{}
|
||||
json.NewEncoder(buf).Encode(map[string]float64{
|
||||
"five_hour_utilization": lastFiveHour,
|
||||
"weekly_utilization": lastWeekly,
|
||||
"plan_weight": lastWeight,
|
||||
})
|
||||
_, writeErr := w.Write(buf.Bytes())
|
||||
if writeErr != nil {
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
case <-done:
|
||||
return
|
||||
case <-subscription:
|
||||
for {
|
||||
select {
|
||||
case <-subscription:
|
||||
default:
|
||||
goto drained
|
||||
}
|
||||
}
|
||||
drained:
|
||||
fiveHour, weekly, weight := s.computeAggregatedUtilization(provider, userConfig)
|
||||
if fiveHour == lastFiveHour && weekly == lastWeekly && weight == lastWeight {
|
||||
continue
|
||||
}
|
||||
lastFiveHour = fiveHour
|
||||
lastWeekly = weekly
|
||||
lastWeight = weight
|
||||
buf.Reset()
|
||||
json.NewEncoder(buf).Encode(map[string]float64{
|
||||
"five_hour_utilization": fiveHour,
|
||||
"weekly_utilization": weekly,
|
||||
"plan_weight": weight,
|
||||
})
|
||||
_, writeErr = w.Write(buf.Bytes())
|
||||
if writeErr != nil {
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) (float64, float64, float64) {
|
||||
var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64
|
||||
for _, credential := range provider.allCredentials() {
|
||||
|
||||
Reference in New Issue
Block a user