From 02a1409e9addf4c29be152683325fd3e75bd78db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 13 Mar 2026 20:05:54 +0800 Subject: [PATCH] ccm,ocm: unify HTTP request retry with fast retry and exponential backoff --- service/ccm/credential.go | 22 +++++----- service/ccm/credential_external.go | 16 +++---- service/ccm/credential_state.go | 69 +++++++++++++++++------------- service/ocm/credential.go | 20 +++++---- service/ocm/credential_external.go | 16 +++---- service/ocm/credential_state.go | 69 +++++++++++++++++------------- 6 files changed, 119 insertions(+), 93 deletions(-) diff --git a/service/ccm/credential.go b/service/ccm/credential.go index 6b3000861..75ae62f97 100644 --- a/service/ccm/credential.go +++ b/service/ccm/credential.go @@ -2,6 +2,7 @@ package ccm import ( "bytes" + "context" "encoding/json" "io" "net/http" @@ -142,7 +143,7 @@ func (c *oauthCredentials) needsRefresh() bool { return time.Now().UnixMilli() >= c.ExpiresAt-tokenRefreshBufferMs } -func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) { +func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) { if credentials.RefreshToken == "" { return nil, E.New("refresh token is empty") } @@ -156,15 +157,16 @@ func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oaut return nil, E.Cause(err, "marshal request") } - request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody)) - if err != nil { - return nil, err - } - request.Header.Set("Content-Type", "application/json") - request.Header.Set("Accept", "application/json") - request.Header.Set("User-Agent", ccmUserAgentValue) - - response, err := httpClient.Do(request) + response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { + request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody)) + if err != nil { + return nil, err + } + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Accept", "application/json") + request.Header.Set("User-Agent", ccmUserAgentValue) + return request, nil + }) if err != nil { return nil, err } diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index 0bcf15a77..8a0ffda86 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -449,14 +449,14 @@ func (c *externalCredential) pollUsage(ctx context.Context) { Timeout: 5 * time.Second, } - request, err := http.NewRequestWithContext(ctx, http.MethodGet, statusURL, nil) - if err != nil { - c.logger.Error("poll usage for ", c.tag, ": create request: ", err) - return - } - request.Header.Set("Authorization", "Bearer "+c.token) - - response, err := httpClient.Do(request) + response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, statusURL, nil) + if err != nil { + return nil, err + } + request.Header.Set("Authorization", "Bearer "+c.token) + return request, nil + }) if err != nil { c.logger.Error("poll usage for ", c.tag, ": ", err) c.stateMutex.Lock() diff --git a/service/ccm/credential_state.go b/service/ccm/credential_state.go index 673af5c2e..6ecdd50a8 100644 --- a/service/ccm/credential_state.go +++ b/service/ccm/credential_state.go @@ -5,7 +5,6 @@ import ( "context" stdTLS "crypto/tls" "encoding/json" - "errors" "io" "math" "math/rand/v2" @@ -29,6 +28,38 @@ import ( const defaultPollInterval = 60 * time.Minute +const ( + httpRetryMaxAttempts = 3 + httpRetryInitialDelay = 200 * time.Millisecond +) + +func doHTTPWithRetry(ctx context.Context, client *http.Client, buildRequest func() (*http.Request, error)) (*http.Response, error) { + var lastError error + for attempt := range httpRetryMaxAttempts { + if attempt > 0 { + delay := httpRetryInitialDelay * time.Duration(1<<(attempt-1)) + select { + case <-ctx.Done(): + return nil, lastError + case <-time.After(delay): + } + } + request, err := buildRequest() + if err != nil { + return nil, err + } + response, err := client.Do(request) + if err == nil { + return response, nil + } + lastError = err + if ctx.Err() != nil { + return nil, lastError + } + } + return nil, lastError +} + type credentialState struct { fiveHourUtilization float64 fiveHourReset time.Time @@ -46,6 +77,7 @@ type credentialState struct { type defaultCredential struct { tag string + serviceContext context.Context credentialPath string credentialFilePath string credentials *oauthCredentials @@ -151,6 +183,7 @@ func newDefaultCredential(ctx context.Context, tag string, options option.CCMDef requestContext, cancelRequests := context.WithCancel(context.Background()) credential := &defaultCredential{ tag: tag, + serviceContext: ctx, credentialPath: options.CredentialPath, reserve5h: reserve5h, reserveWeekly: reserveWeekly, @@ -231,7 +264,7 @@ func (c *defaultCredential) getAccessToken() (string, error) { } baseCredentials := cloneCredentials(c.credentials) - newCredentials, err := refreshToken(c.httpClient, c.credentials) + newCredentials, err := refreshToken(c.serviceContext, c.httpClient, c.credentials) if err != nil { return "", err } @@ -498,16 +531,6 @@ func (c *defaultCredential) earliestReset() time.Time { return earliest } -const pollUsageMaxRetries = 3 - -func isTimeoutError(err error) bool { - var netErr net.Error - if errors.As(err, &netErr) { - return netErr.Timeout() - } - return false -} - func (c *defaultCredential) pollUsage(ctx context.Context) { if !c.pollAccess.TryLock() { return @@ -531,30 +554,18 @@ func (c *defaultCredential) pollUsage(ctx context.Context) { Timeout: 5 * time.Second, } - var response *http.Response - for attempt := range pollUsageMaxRetries { + response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/usage", nil) if err != nil { - c.logger.Error("poll usage for ", c.tag, ": create request: ", err) - return + return nil, err } request.Header.Set("Authorization", "Bearer "+accessToken) request.Header.Set("Content-Type", "application/json") request.Header.Set("User-Agent", ccmUserAgentValue) request.Header.Set("anthropic-beta", anthropicBetaOAuthValue) - - response, err = httpClient.Do(request) - if err == nil { - break - } - if !isTimeoutError(err) { - c.logger.Error("poll usage for ", c.tag, ": ", err) - return - } - if attempt < pollUsageMaxRetries-1 { - c.logger.Warn("poll usage for ", c.tag, ": timeout, retrying (", attempt+1, "/", pollUsageMaxRetries, ")") - continue - } + return request, nil + }) + if err != nil { c.logger.Error("poll usage for ", c.tag, ": ", err) return } diff --git a/service/ocm/credential.go b/service/ocm/credential.go index c143f868a..bb240b5ab 100644 --- a/service/ocm/credential.go +++ b/service/ocm/credential.go @@ -2,6 +2,7 @@ package ocm import ( "bytes" + "context" "encoding/json" "io" "net/http" @@ -118,7 +119,7 @@ func (c *oauthCredentials) needsRefresh() bool { return time.Since(*c.LastRefresh) >= time.Duration(tokenRefreshIntervalDays)*24*time.Hour } -func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) { +func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) { if credentials.Tokens == nil || credentials.Tokens.RefreshToken == "" { return nil, E.New("refresh token is empty") } @@ -133,14 +134,15 @@ func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oaut return nil, E.Cause(err, "marshal request") } - request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody)) - if err != nil { - return nil, err - } - request.Header.Set("Content-Type", "application/json") - request.Header.Set("Accept", "application/json") - - response, err := httpClient.Do(request) + response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { + request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody)) + if err != nil { + return nil, err + } + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Accept", "application/json") + return request, nil + }) if err != nil { return nil, err } diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 83d37f385..0d19ea557 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -485,14 +485,14 @@ func (c *externalCredential) pollUsage(ctx context.Context) { Timeout: 5 * time.Second, } - request, err := http.NewRequestWithContext(ctx, http.MethodGet, statusURL, nil) - if err != nil { - c.logger.Error("poll usage for ", c.tag, ": create request: ", err) - return - } - request.Header.Set("Authorization", "Bearer "+c.token) - - response, err := httpClient.Do(request) + response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, statusURL, nil) + if err != nil { + return nil, err + } + request.Header.Set("Authorization", "Bearer "+c.token) + return request, nil + }) if err != nil { c.logger.Error("poll usage for ", c.tag, ": ", err) c.stateMutex.Lock() diff --git a/service/ocm/credential_state.go b/service/ocm/credential_state.go index b663632af..821183da2 100644 --- a/service/ocm/credential_state.go +++ b/service/ocm/credential_state.go @@ -5,7 +5,6 @@ import ( "context" stdTLS "crypto/tls" "encoding/json" - "errors" "io" "math/rand/v2" "net" @@ -29,6 +28,38 @@ import ( const defaultPollInterval = 60 * time.Minute +const ( + httpRetryMaxAttempts = 3 + httpRetryInitialDelay = 200 * time.Millisecond +) + +func doHTTPWithRetry(ctx context.Context, client *http.Client, buildRequest func() (*http.Request, error)) (*http.Response, error) { + var lastError error + for attempt := range httpRetryMaxAttempts { + if attempt > 0 { + delay := httpRetryInitialDelay * time.Duration(1<<(attempt-1)) + select { + case <-ctx.Done(): + return nil, lastError + case <-time.After(delay): + } + } + request, err := buildRequest() + if err != nil { + return nil, err + } + response, err := client.Do(request) + if err == nil { + return response, nil + } + lastError = err + if ctx.Err() != nil { + return nil, lastError + } + } + return nil, lastError +} + type credentialState struct { fiveHourUtilization float64 fiveHourReset time.Time @@ -46,6 +77,7 @@ type credentialState struct { type defaultCredential struct { tag string + serviceContext context.Context credentialPath string credentialFilePath string credentials *oauthCredentials @@ -159,6 +191,7 @@ func newDefaultCredential(ctx context.Context, tag string, options option.OCMDef requestContext, cancelRequests := context.WithCancel(context.Background()) credential := &defaultCredential{ tag: tag, + serviceContext: ctx, credentialPath: options.CredentialPath, reserve5h: reserve5h, reserveWeekly: reserveWeekly, @@ -240,7 +273,7 @@ func (c *defaultCredential) getAccessToken() (string, error) { } baseCredentials := cloneCredentials(c.credentials) - newCredentials, err := refreshToken(c.httpClient, c.credentials) + newCredentials, err := refreshToken(c.serviceContext, c.httpClient, c.credentials) if err != nil { return "", err } @@ -507,16 +540,6 @@ func (c *defaultCredential) earliestReset() time.Time { return earliest } -const pollUsageMaxRetries = 3 - -func isTimeoutError(err error) bool { - var netErr net.Error - if errors.As(err, &netErr) { - return netErr.Timeout() - } - return false -} - func (c *defaultCredential) pollUsage(ctx context.Context) { if !c.pollAccess.TryLock() { return @@ -551,30 +574,18 @@ func (c *defaultCredential) pollUsage(ctx context.Context) { Timeout: 5 * time.Second, } - var response *http.Response - for attempt := range pollUsageMaxRetries { + response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) { request, err := http.NewRequestWithContext(ctx, http.MethodGet, usageURL, nil) if err != nil { - c.logger.Error("poll usage for ", c.tag, ": create request: ", err) - return + return nil, err } request.Header.Set("Authorization", "Bearer "+accessToken) if accountID != "" { request.Header.Set("ChatGPT-Account-Id", accountID) } - - response, err = httpClient.Do(request) - if err == nil { - break - } - if !isTimeoutError(err) { - c.logger.Error("poll usage for ", c.tag, ": ", err) - return - } - if attempt < pollUsageMaxRetries-1 { - c.logger.Warn("poll usage for ", c.tag, ": timeout, retrying (", attempt+1, "/", pollUsageMaxRetries, ")") - continue - } + return request, nil + }) + if err != nil { c.logger.Error("poll usage for ", c.tag, ": ", err) return }