diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index be1bb7f5b..11ddc8dad 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -614,7 +614,7 @@ func (c *externalCredential) pollUsage() { response, err := c.doPollUsageRequest(ctx) if err != nil { c.logger.Debug("poll usage for ", c.tag, ": ", err) - c.clearPollFailures() + c.incrementPollFailures() return } defer response.Body.Close() @@ -622,35 +622,35 @@ func (c *externalCredential) pollUsage() { if response.StatusCode != http.StatusOK { body, _ := io.ReadAll(response.Body) c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) - c.clearPollFailures() + c.incrementPollFailures() return } body, err := io.ReadAll(response.Body) if err != nil { c.logger.Debug("poll usage for ", c.tag, ": read body: ", err) - c.clearPollFailures() + c.incrementPollFailures() return } var rawFields map[string]json.RawMessage err = json.Unmarshal(body, &rawFields) if err != nil { c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) - c.clearPollFailures() + c.incrementPollFailures() return } if rawFields["five_hour_utilization"] == nil || rawFields["five_hour_reset"] == nil || rawFields["weekly_utilization"] == nil || rawFields["weekly_reset"] == nil || rawFields["plan_weight"] == nil { c.logger.Error("poll usage for ", c.tag, ": invalid response") - c.clearPollFailures() + c.incrementPollFailures() return } var statusResponse statusPayload err = json.Unmarshal(body, &statusResponse) if err != nil { c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) - c.clearPollFailures() + c.incrementPollFailures() return } @@ -943,11 +943,16 @@ func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Durati return baseInterval } -func (c *externalCredential) clearPollFailures() { +func (c *externalCredential) incrementPollFailures() { c.stateAccess.Lock() - c.state.consecutivePollFailures = 0 - c.checkTransitionLocked() + c.state.consecutivePollFailures++ + c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReasonPollFailed, time.Time{}) + shouldInterrupt := c.checkTransitionLocked() c.stateAccess.Unlock() + if shouldInterrupt { + c.interruptConnections() + } + c.emitStatusUpdate() } func (c *externalCredential) usageTrackerOrNil() *AggregatedUsage { diff --git a/service/ccm/service_handler.go b/service/ccm/service_handler.go index ad5dfcaee..dd8c71d2d 100644 --- a/service/ccm/service_handler.go +++ b/service/ccm/service_handler.go @@ -218,6 +218,14 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } provider.pollIfStale() + if userConfig != nil && userConfig.ExternalCredential != "" { + for _, credential := range s.allCredentials { + if credential.tagName() == userConfig.ExternalCredential && !credential.isUsable() { + credential.pollUsage() + break + } + } + } s.cleanSessionModels() anthropicBetaHeader := r.Header.Get("anthropic-beta") diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 765c3a27b..67b9b2a1b 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -656,7 +656,7 @@ func (c *externalCredential) pollUsage() { response, err := c.doPollUsageRequest(ctx) if err != nil { c.logger.Debug("poll usage for ", c.tag, ": ", err) - c.clearPollFailures() + c.incrementPollFailures() return } defer response.Body.Close() @@ -664,35 +664,35 @@ func (c *externalCredential) pollUsage() { if response.StatusCode != http.StatusOK { body, _ := io.ReadAll(response.Body) c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body)) - c.clearPollFailures() + c.incrementPollFailures() return } body, err := io.ReadAll(response.Body) if err != nil { c.logger.Debug("poll usage for ", c.tag, ": read body: ", err) - c.clearPollFailures() + c.incrementPollFailures() return } var rawFields map[string]json.RawMessage err = json.Unmarshal(body, &rawFields) if err != nil { c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) - c.clearPollFailures() + c.incrementPollFailures() return } if rawFields["limits"] == nil && (rawFields["five_hour_utilization"] == nil || rawFields["five_hour_reset"] == nil || rawFields["weekly_utilization"] == nil || rawFields["weekly_reset"] == nil || rawFields["plan_weight"] == nil) { c.logger.Error("poll usage for ", c.tag, ": invalid response") - c.clearPollFailures() + c.incrementPollFailures() return } var statusResponse statusPayload err = json.Unmarshal(body, &statusResponse) if err != nil { c.logger.Debug("poll usage for ", c.tag, ": decode: ", err) - c.clearPollFailures() + c.incrementPollFailures() return } @@ -985,11 +985,16 @@ func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Durati return baseInterval } -func (c *externalCredential) clearPollFailures() { +func (c *externalCredential) incrementPollFailures() { c.stateAccess.Lock() - c.state.consecutivePollFailures = 0 - c.checkTransitionLocked() + c.state.consecutivePollFailures++ + c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReasonPollFailed, time.Time{}) + shouldInterrupt := c.checkTransitionLocked() c.stateAccess.Unlock() + if shouldInterrupt { + c.interruptConnections() + } + c.emitStatusUpdate() } func (c *externalCredential) usageTrackerOrNil() *AggregatedUsage { diff --git a/service/ocm/service_handler.go b/service/ocm/service_handler.go index c2e90a582..cfb34e15b 100644 --- a/service/ocm/service_handler.go +++ b/service/ocm/service_handler.go @@ -132,6 +132,14 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { } provider.pollIfStale() + if userConfig != nil && userConfig.ExternalCredential != "" { + for _, credential := range s.allCredentials { + if credential.tagName() == userConfig.ExternalCredential && !credential.isUsable() { + credential.pollUsage() + break + } + } + } selection := credentialSelectionForUser(userConfig)