ccm,ocm: unify HTTP request retry with fast retry and exponential backoff

This commit is contained in:
世界
2026-03-13 20:05:54 +08:00
parent af94ea9089
commit 02a1409e9a
6 changed files with 119 additions and 93 deletions

View File

@@ -2,6 +2,7 @@ package ccm
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
@@ -142,7 +143,7 @@ func (c *oauthCredentials) needsRefresh() bool {
return time.Now().UnixMilli() >= c.ExpiresAt-tokenRefreshBufferMs
}
func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) {
func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) {
if credentials.RefreshToken == "" {
return nil, E.New("refresh token is empty")
}
@@ -156,15 +157,16 @@ func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oaut
return nil, E.Cause(err, "marshal request")
}
request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
if err != nil {
return nil, err
}
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Accept", "application/json")
request.Header.Set("User-Agent", ccmUserAgentValue)
response, err := httpClient.Do(request)
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
if err != nil {
return nil, err
}
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Accept", "application/json")
request.Header.Set("User-Agent", ccmUserAgentValue)
return request, nil
})
if err != nil {
return nil, err
}

View File

@@ -449,14 +449,14 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
Timeout: 5 * time.Second,
}
request, err := http.NewRequestWithContext(ctx, http.MethodGet, statusURL, nil)
if err != nil {
c.logger.Error("poll usage for ", c.tag, ": create request: ", err)
return
}
request.Header.Set("Authorization", "Bearer "+c.token)
response, err := httpClient.Do(request)
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
request, err := http.NewRequestWithContext(ctx, http.MethodGet, statusURL, nil)
if err != nil {
return nil, err
}
request.Header.Set("Authorization", "Bearer "+c.token)
return request, nil
})
if err != nil {
c.logger.Error("poll usage for ", c.tag, ": ", err)
c.stateMutex.Lock()

View File

@@ -5,7 +5,6 @@ import (
"context"
stdTLS "crypto/tls"
"encoding/json"
"errors"
"io"
"math"
"math/rand/v2"
@@ -29,6 +28,38 @@ import (
const defaultPollInterval = 60 * time.Minute
const (
httpRetryMaxAttempts = 3
httpRetryInitialDelay = 200 * time.Millisecond
)
func doHTTPWithRetry(ctx context.Context, client *http.Client, buildRequest func() (*http.Request, error)) (*http.Response, error) {
var lastError error
for attempt := range httpRetryMaxAttempts {
if attempt > 0 {
delay := httpRetryInitialDelay * time.Duration(1<<(attempt-1))
select {
case <-ctx.Done():
return nil, lastError
case <-time.After(delay):
}
}
request, err := buildRequest()
if err != nil {
return nil, err
}
response, err := client.Do(request)
if err == nil {
return response, nil
}
lastError = err
if ctx.Err() != nil {
return nil, lastError
}
}
return nil, lastError
}
type credentialState struct {
fiveHourUtilization float64
fiveHourReset time.Time
@@ -46,6 +77,7 @@ type credentialState struct {
type defaultCredential struct {
tag string
serviceContext context.Context
credentialPath string
credentialFilePath string
credentials *oauthCredentials
@@ -151,6 +183,7 @@ func newDefaultCredential(ctx context.Context, tag string, options option.CCMDef
requestContext, cancelRequests := context.WithCancel(context.Background())
credential := &defaultCredential{
tag: tag,
serviceContext: ctx,
credentialPath: options.CredentialPath,
reserve5h: reserve5h,
reserveWeekly: reserveWeekly,
@@ -231,7 +264,7 @@ func (c *defaultCredential) getAccessToken() (string, error) {
}
baseCredentials := cloneCredentials(c.credentials)
newCredentials, err := refreshToken(c.httpClient, c.credentials)
newCredentials, err := refreshToken(c.serviceContext, c.httpClient, c.credentials)
if err != nil {
return "", err
}
@@ -498,16 +531,6 @@ func (c *defaultCredential) earliestReset() time.Time {
return earliest
}
const pollUsageMaxRetries = 3
func isTimeoutError(err error) bool {
var netErr net.Error
if errors.As(err, &netErr) {
return netErr.Timeout()
}
return false
}
func (c *defaultCredential) pollUsage(ctx context.Context) {
if !c.pollAccess.TryLock() {
return
@@ -531,30 +554,18 @@ func (c *defaultCredential) pollUsage(ctx context.Context) {
Timeout: 5 * time.Second,
}
var response *http.Response
for attempt := range pollUsageMaxRetries {
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
request, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeAPIBaseURL+"/api/oauth/usage", nil)
if err != nil {
c.logger.Error("poll usage for ", c.tag, ": create request: ", err)
return
return nil, err
}
request.Header.Set("Authorization", "Bearer "+accessToken)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("User-Agent", ccmUserAgentValue)
request.Header.Set("anthropic-beta", anthropicBetaOAuthValue)
response, err = httpClient.Do(request)
if err == nil {
break
}
if !isTimeoutError(err) {
c.logger.Error("poll usage for ", c.tag, ": ", err)
return
}
if attempt < pollUsageMaxRetries-1 {
c.logger.Warn("poll usage for ", c.tag, ": timeout, retrying (", attempt+1, "/", pollUsageMaxRetries, ")")
continue
}
return request, nil
})
if err != nil {
c.logger.Error("poll usage for ", c.tag, ": ", err)
return
}

View File

@@ -2,6 +2,7 @@ package ocm
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
@@ -118,7 +119,7 @@ func (c *oauthCredentials) needsRefresh() bool {
return time.Since(*c.LastRefresh) >= time.Duration(tokenRefreshIntervalDays)*24*time.Hour
}
func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) {
func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) {
if credentials.Tokens == nil || credentials.Tokens.RefreshToken == "" {
return nil, E.New("refresh token is empty")
}
@@ -133,14 +134,15 @@ func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oaut
return nil, E.Cause(err, "marshal request")
}
request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
if err != nil {
return nil, err
}
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Accept", "application/json")
response, err := httpClient.Do(request)
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
if err != nil {
return nil, err
}
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Accept", "application/json")
return request, nil
})
if err != nil {
return nil, err
}

View File

@@ -485,14 +485,14 @@ func (c *externalCredential) pollUsage(ctx context.Context) {
Timeout: 5 * time.Second,
}
request, err := http.NewRequestWithContext(ctx, http.MethodGet, statusURL, nil)
if err != nil {
c.logger.Error("poll usage for ", c.tag, ": create request: ", err)
return
}
request.Header.Set("Authorization", "Bearer "+c.token)
response, err := httpClient.Do(request)
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
request, err := http.NewRequestWithContext(ctx, http.MethodGet, statusURL, nil)
if err != nil {
return nil, err
}
request.Header.Set("Authorization", "Bearer "+c.token)
return request, nil
})
if err != nil {
c.logger.Error("poll usage for ", c.tag, ": ", err)
c.stateMutex.Lock()

View File

@@ -5,7 +5,6 @@ import (
"context"
stdTLS "crypto/tls"
"encoding/json"
"errors"
"io"
"math/rand/v2"
"net"
@@ -29,6 +28,38 @@ import (
const defaultPollInterval = 60 * time.Minute
const (
httpRetryMaxAttempts = 3
httpRetryInitialDelay = 200 * time.Millisecond
)
func doHTTPWithRetry(ctx context.Context, client *http.Client, buildRequest func() (*http.Request, error)) (*http.Response, error) {
var lastError error
for attempt := range httpRetryMaxAttempts {
if attempt > 0 {
delay := httpRetryInitialDelay * time.Duration(1<<(attempt-1))
select {
case <-ctx.Done():
return nil, lastError
case <-time.After(delay):
}
}
request, err := buildRequest()
if err != nil {
return nil, err
}
response, err := client.Do(request)
if err == nil {
return response, nil
}
lastError = err
if ctx.Err() != nil {
return nil, lastError
}
}
return nil, lastError
}
type credentialState struct {
fiveHourUtilization float64
fiveHourReset time.Time
@@ -46,6 +77,7 @@ type credentialState struct {
type defaultCredential struct {
tag string
serviceContext context.Context
credentialPath string
credentialFilePath string
credentials *oauthCredentials
@@ -159,6 +191,7 @@ func newDefaultCredential(ctx context.Context, tag string, options option.OCMDef
requestContext, cancelRequests := context.WithCancel(context.Background())
credential := &defaultCredential{
tag: tag,
serviceContext: ctx,
credentialPath: options.CredentialPath,
reserve5h: reserve5h,
reserveWeekly: reserveWeekly,
@@ -240,7 +273,7 @@ func (c *defaultCredential) getAccessToken() (string, error) {
}
baseCredentials := cloneCredentials(c.credentials)
newCredentials, err := refreshToken(c.httpClient, c.credentials)
newCredentials, err := refreshToken(c.serviceContext, c.httpClient, c.credentials)
if err != nil {
return "", err
}
@@ -507,16 +540,6 @@ func (c *defaultCredential) earliestReset() time.Time {
return earliest
}
const pollUsageMaxRetries = 3
func isTimeoutError(err error) bool {
var netErr net.Error
if errors.As(err, &netErr) {
return netErr.Timeout()
}
return false
}
func (c *defaultCredential) pollUsage(ctx context.Context) {
if !c.pollAccess.TryLock() {
return
@@ -551,30 +574,18 @@ func (c *defaultCredential) pollUsage(ctx context.Context) {
Timeout: 5 * time.Second,
}
var response *http.Response
for attempt := range pollUsageMaxRetries {
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
request, err := http.NewRequestWithContext(ctx, http.MethodGet, usageURL, nil)
if err != nil {
c.logger.Error("poll usage for ", c.tag, ": create request: ", err)
return
return nil, err
}
request.Header.Set("Authorization", "Bearer "+accessToken)
if accountID != "" {
request.Header.Set("ChatGPT-Account-Id", accountID)
}
response, err = httpClient.Do(request)
if err == nil {
break
}
if !isTimeoutError(err) {
c.logger.Error("poll usage for ", c.tag, ": ", err)
return
}
if attempt < pollUsageMaxRetries-1 {
c.logger.Warn("poll usage for ", c.tag, ": timeout, retrying (", attempt+1, "/", pollUsageMaxRetries, ")")
continue
}
return request, nil
})
if err != nil {
c.logger.Error("poll usage for ", c.tag, ": ", err)
return
}