mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-11 17:47:20 +10:00
Compare commits
95 Commits
e6427e8244
...
ccm-ocm-im
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bf9e390cf4 | ||
|
|
471c9c3b47 | ||
|
|
e7478ce947 | ||
|
|
cf11e0e74a | ||
|
|
a87a2b0e2b | ||
|
|
d9c298af1e | ||
|
|
cd5007ffbb | ||
|
|
e49d0685ad | ||
|
|
e1c9667319 | ||
|
|
6f45ea9c27 | ||
|
|
ca60f93184 | ||
|
|
1774d98793 | ||
|
|
6721dff48a | ||
|
|
4592164a7a | ||
|
|
92c8f4c5c8 | ||
|
|
441c98890d | ||
|
|
bb2169bc17 | ||
|
|
d996b60f44 | ||
|
|
084a6f1302 | ||
|
|
0950783479 | ||
|
|
f172a575b7 | ||
|
|
29b901a8b3 | ||
|
|
53f832330d | ||
|
|
99d9e06dd0 | ||
|
|
608b7e7fa2 | ||
|
|
7acba74755 | ||
|
|
2fe1e37b17 | ||
|
|
3bcfdd5455 | ||
|
|
b119d08764 | ||
|
|
6b8838d323 | ||
|
|
b3429ef1f3 | ||
|
|
a2d6cf9715 | ||
|
|
99e19e7033 | ||
|
|
969defeef0 | ||
|
|
f57eff33bb | ||
|
|
0a054b9aa4 | ||
|
|
7d15d9d282 | ||
|
|
cf2d677043 | ||
|
|
4a6a211775 | ||
|
|
f84832a369 | ||
|
|
f3c3022094 | ||
|
|
2dd093a32e | ||
|
|
14ade76956 | ||
|
|
9e3ec30d72 | ||
|
|
763e0af010 | ||
|
|
656b09d1be | ||
|
|
8e9c61e624 | ||
|
|
bc6e72408d | ||
|
|
56af7313b2 | ||
|
|
6878ad0d35 | ||
|
|
04bd63b455 | ||
|
|
51d564c9ff | ||
|
|
4d8baf7175 | ||
|
|
d1e5426bc8 | ||
|
|
4d907bc49d | ||
|
|
2c907bef2c | ||
|
|
d2300353fd | ||
|
|
f871113832 | ||
|
|
b97b9d9cfd | ||
|
|
badeeb91fe | ||
|
|
f4aaf33bf2 | ||
|
|
8fe8e238b3 | ||
|
|
6f433937ba | ||
|
|
80d5432654 | ||
|
|
8984b45ded | ||
|
|
25a9e4ce59 | ||
|
|
615a7e05b4 | ||
|
|
1628272507 | ||
|
|
ee65b375cb | ||
|
|
a09174a9a2 | ||
|
|
ce543a935f | ||
|
|
7f93c76b1a | ||
|
|
df6e47f5f1 | ||
|
|
1993da3735 | ||
|
|
22376472d0 | ||
|
|
74bf20d349 | ||
|
|
ff8585f7c6 | ||
|
|
4d5108fe7f | ||
|
|
3b177df05e | ||
|
|
1824881719 | ||
|
|
02a1409e9a | ||
|
|
af94ea9089 | ||
|
|
970951f369 | ||
|
|
15f3619995 | ||
|
|
b96ab4fef9 | ||
|
|
6829f91a06 | ||
|
|
8e5811a8c7 | ||
|
|
da8ff6f578 | ||
|
|
2801bce815 | ||
|
|
a11cd1e0c6 | ||
|
|
bd0fb83d2d | ||
|
|
9462b1deeb | ||
|
|
44d1c86b1b | ||
|
|
f802668915 | ||
|
|
4d217b7481 |
2
.github/CRONET_GO_VERSION
vendored
2
.github/CRONET_GO_VERSION
vendored
@@ -1 +1 @@
|
||||
2fef65f9dba90ddb89a87d00a6eb6165487c10c1
|
||||
ea7cd33752aed62603775af3df946c1b83f4b0b3
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -18,6 +18,6 @@
|
||||
.DS_Store
|
||||
/config.d/
|
||||
/venv/
|
||||
CLAUDE.md
|
||||
AGENTS.md
|
||||
/CLAUDE.md
|
||||
/AGENTS.md
|
||||
/.claude/
|
||||
|
||||
@@ -2,6 +2,7 @@ package adapter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
@@ -82,6 +83,8 @@ type InboundContext struct {
|
||||
SourceGeoIPCode string
|
||||
GeoIPCode string
|
||||
ProcessInfo *ConnectionOwner
|
||||
SourceMACAddress net.HardwareAddr
|
||||
SourceHostname string
|
||||
QueryType uint16
|
||||
FakeIP bool
|
||||
|
||||
|
||||
23
adapter/neighbor.go
Normal file
23
adapter/neighbor.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package adapter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type NeighborEntry struct {
|
||||
Address netip.Addr
|
||||
MACAddress net.HardwareAddr
|
||||
Hostname string
|
||||
}
|
||||
|
||||
type NeighborResolver interface {
|
||||
LookupMAC(address netip.Addr) (net.HardwareAddr, bool)
|
||||
LookupHostname(address netip.Addr) (string, bool)
|
||||
Start() error
|
||||
Close() error
|
||||
}
|
||||
|
||||
type NeighborUpdateListener interface {
|
||||
UpdateNeighborTable(entries []NeighborEntry)
|
||||
}
|
||||
@@ -36,6 +36,10 @@ type PlatformInterface interface {
|
||||
|
||||
UsePlatformNotification() bool
|
||||
SendNotification(notification *Notification) error
|
||||
|
||||
UsePlatformNeighborResolver() bool
|
||||
StartNeighborMonitor(listener NeighborUpdateListener) error
|
||||
CloseNeighborMonitor(listener NeighborUpdateListener) error
|
||||
}
|
||||
|
||||
type FindConnectionOwnerRequest struct {
|
||||
|
||||
@@ -26,6 +26,8 @@ type Router interface {
|
||||
RuleSet(tag string) (RuleSet, bool)
|
||||
Rules() []Rule
|
||||
NeedFindProcess() bool
|
||||
NeedFindNeighbor() bool
|
||||
NeighborResolver() NeighborResolver
|
||||
AppendTracker(tracker ConnectionTracker)
|
||||
ResetNetwork()
|
||||
}
|
||||
|
||||
@@ -38,6 +38,13 @@ const (
|
||||
TypeURLTest = "urltest"
|
||||
)
|
||||
|
||||
const (
|
||||
BalancerStrategyLeastUsed = "least-used"
|
||||
BalancerStrategyRoundRobin = "round-robin"
|
||||
BalancerStrategyRandom = "random"
|
||||
BalancerStrategyFallback = "fallback"
|
||||
)
|
||||
|
||||
func ProxyDisplayName(proxyType string) string {
|
||||
switch proxyType {
|
||||
case TypeTun:
|
||||
|
||||
@@ -2,6 +2,10 @@
|
||||
icon: material/alert-decagram
|
||||
---
|
||||
|
||||
#### 1.14.0-alpha.3
|
||||
|
||||
* Fixes and improvements
|
||||
|
||||
#### 1.13.3
|
||||
|
||||
* Add OpenWrt and Alpine APK packages to release **1**
|
||||
@@ -26,6 +30,59 @@ from [SagerNet/go](https://github.com/SagerNet/go).
|
||||
|
||||
See [OCM](/configuration/service/ocm).
|
||||
|
||||
#### 1.12.24
|
||||
|
||||
* Fixes and improvements
|
||||
|
||||
#### 1.14.0-alpha.2
|
||||
|
||||
* Add OpenWrt and Alpine APK packages to release **1**
|
||||
* Backport to macOS 10.13 High Sierra **2**
|
||||
* OCM service: Add WebSocket support for Responses API **3**
|
||||
* Fixes and improvements
|
||||
|
||||
**1**:
|
||||
|
||||
Alpine APK files use `linux` in the filename to distinguish from OpenWrt APKs which use the `openwrt` prefix:
|
||||
|
||||
- OpenWrt: `sing-box_{version}_openwrt_{architecture}.apk`
|
||||
- Alpine: `sing-box_{version}_linux_{architecture}.apk`
|
||||
|
||||
**2**:
|
||||
|
||||
Legacy macOS binaries (with `-legacy-macos-10.13` suffix) now support
|
||||
macOS 10.13 High Sierra, built using Go 1.25 with patches
|
||||
from [SagerNet/go](https://github.com/SagerNet/go).
|
||||
|
||||
**3**:
|
||||
|
||||
See [OCM](/configuration/service/ocm).
|
||||
|
||||
#### 1.14.0-alpha.1
|
||||
|
||||
* Add `source_mac_address` and `source_hostname` rule items **1**
|
||||
* Add `include_mac_address` and `exclude_mac_address` TUN options **2**
|
||||
* Update NaiveProxy to 145.0.7632.159 **3**
|
||||
* Fixes and improvements
|
||||
|
||||
**1**:
|
||||
|
||||
New rule items for matching LAN devices by MAC address and hostname via neighbor resolution.
|
||||
Supported on Linux, macOS, or in graphical clients on Android and macOS.
|
||||
|
||||
See [Route Rule](/configuration/route/rule/#source_mac_address), [DNS Rule](/configuration/dns/rule/#source_mac_address) and [Neighbor Resolution](/configuration/shared/neighbor/).
|
||||
|
||||
**2**:
|
||||
|
||||
Limit or exclude devices from TUN routing by MAC address.
|
||||
Only supported on Linux with `auto_route` and `auto_redirect` enabled.
|
||||
|
||||
See [TUN](/configuration/inbound/tun/#include_mac_address).
|
||||
|
||||
**3**:
|
||||
|
||||
This is not an official update from NaiveProxy. Instead, it's a Chromium codebase update maintained by Project S.
|
||||
|
||||
#### 1.13.2
|
||||
|
||||
* Fixes and improvements
|
||||
|
||||
@@ -2,6 +2,11 @@
|
||||
icon: material/alert-decagram
|
||||
---
|
||||
|
||||
!!! quote "Changes in sing-box 1.14.0"
|
||||
|
||||
:material-plus: [source_mac_address](#source_mac_address)
|
||||
:material-plus: [source_hostname](#source_hostname)
|
||||
|
||||
!!! quote "Changes in sing-box 1.13.0"
|
||||
|
||||
:material-plus: [interface_address](#interface_address)
|
||||
@@ -149,6 +154,12 @@ icon: material/alert-decagram
|
||||
"default_interface_address": [
|
||||
"2000::/3"
|
||||
],
|
||||
"source_mac_address": [
|
||||
"00:11:22:33:44:55"
|
||||
],
|
||||
"source_hostname": [
|
||||
"my-device"
|
||||
],
|
||||
"wifi_ssid": [
|
||||
"My WIFI"
|
||||
],
|
||||
@@ -408,6 +419,26 @@ Matches network interface (same values as `network_type`) address.
|
||||
|
||||
Match default interface address.
|
||||
|
||||
#### source_mac_address
|
||||
|
||||
!!! question "Since sing-box 1.14.0"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
Only supported on Linux, macOS, or in graphical clients on Android and macOS. See [Neighbor Resolution](/configuration/shared/neighbor/) for setup.
|
||||
|
||||
Match source device MAC address.
|
||||
|
||||
#### source_hostname
|
||||
|
||||
!!! question "Since sing-box 1.14.0"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
Only supported on Linux, macOS, or in graphical clients on Android and macOS. See [Neighbor Resolution](/configuration/shared/neighbor/) for setup.
|
||||
|
||||
Match source device hostname from DHCP leases.
|
||||
|
||||
#### wifi_ssid
|
||||
|
||||
!!! quote ""
|
||||
|
||||
@@ -2,6 +2,11 @@
|
||||
icon: material/alert-decagram
|
||||
---
|
||||
|
||||
!!! quote "sing-box 1.14.0 中的更改"
|
||||
|
||||
:material-plus: [source_mac_address](#source_mac_address)
|
||||
:material-plus: [source_hostname](#source_hostname)
|
||||
|
||||
!!! quote "sing-box 1.13.0 中的更改"
|
||||
|
||||
:material-plus: [interface_address](#interface_address)
|
||||
@@ -149,6 +154,12 @@ icon: material/alert-decagram
|
||||
"default_interface_address": [
|
||||
"2000::/3"
|
||||
],
|
||||
"source_mac_address": [
|
||||
"00:11:22:33:44:55"
|
||||
],
|
||||
"source_hostname": [
|
||||
"my-device"
|
||||
],
|
||||
"wifi_ssid": [
|
||||
"My WIFI"
|
||||
],
|
||||
@@ -407,6 +418,26 @@ Available values: `wifi`, `cellular`, `ethernet` and `other`.
|
||||
|
||||
匹配默认接口地址。
|
||||
|
||||
#### source_mac_address
|
||||
|
||||
!!! question "自 sing-box 1.14.0 起"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
仅支持 Linux、macOS,或在 Android 和 macOS 图形客户端中支持。参阅 [邻居解析](/configuration/shared/neighbor/) 了解设置方法。
|
||||
|
||||
匹配源设备 MAC 地址。
|
||||
|
||||
#### source_hostname
|
||||
|
||||
!!! question "自 sing-box 1.14.0 起"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
仅支持 Linux、macOS,或在 Android 和 macOS 图形客户端中支持。参阅 [邻居解析](/configuration/shared/neighbor/) 了解设置方法。
|
||||
|
||||
匹配源设备从 DHCP 租约获取的主机名。
|
||||
|
||||
#### wifi_ssid
|
||||
|
||||
!!! quote ""
|
||||
|
||||
@@ -2,6 +2,11 @@
|
||||
icon: material/new-box
|
||||
---
|
||||
|
||||
!!! quote "Changes in sing-box 1.14.0"
|
||||
|
||||
:material-plus: [include_mac_address](#include_mac_address)
|
||||
:material-plus: [exclude_mac_address](#exclude_mac_address)
|
||||
|
||||
!!! quote "Changes in sing-box 1.13.3"
|
||||
|
||||
:material-alert: [strict_route](#strict_route)
|
||||
@@ -129,6 +134,12 @@ icon: material/new-box
|
||||
"exclude_package": [
|
||||
"com.android.captiveportallogin"
|
||||
],
|
||||
"include_mac_address": [
|
||||
"00:11:22:33:44:55"
|
||||
],
|
||||
"exclude_mac_address": [
|
||||
"66:77:88:99:aa:bb"
|
||||
],
|
||||
"platform": {
|
||||
"http_proxy": {
|
||||
"enabled": false,
|
||||
@@ -555,6 +566,30 @@ Limit android packages in route.
|
||||
|
||||
Exclude android packages in route.
|
||||
|
||||
#### include_mac_address
|
||||
|
||||
!!! question "Since sing-box 1.14.0"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
Only supported on Linux with `auto_route` and `auto_redirect` enabled.
|
||||
|
||||
Limit MAC addresses in route. Not limited by default.
|
||||
|
||||
Conflict with `exclude_mac_address`.
|
||||
|
||||
#### exclude_mac_address
|
||||
|
||||
!!! question "Since sing-box 1.14.0"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
Only supported on Linux with `auto_route` and `auto_redirect` enabled.
|
||||
|
||||
Exclude MAC addresses in route.
|
||||
|
||||
Conflict with `include_mac_address`.
|
||||
|
||||
#### platform
|
||||
|
||||
Platform-specific settings, provided by client applications.
|
||||
|
||||
@@ -2,6 +2,11 @@
|
||||
icon: material/new-box
|
||||
---
|
||||
|
||||
!!! quote "sing-box 1.14.0 中的更改"
|
||||
|
||||
:material-plus: [include_mac_address](#include_mac_address)
|
||||
:material-plus: [exclude_mac_address](#exclude_mac_address)
|
||||
|
||||
!!! quote "sing-box 1.13.3 中的更改"
|
||||
|
||||
:material-alert: [strict_route](#strict_route)
|
||||
@@ -130,6 +135,12 @@ icon: material/new-box
|
||||
"exclude_package": [
|
||||
"com.android.captiveportallogin"
|
||||
],
|
||||
"include_mac_address": [
|
||||
"00:11:22:33:44:55"
|
||||
],
|
||||
"exclude_mac_address": [
|
||||
"66:77:88:99:aa:bb"
|
||||
],
|
||||
"platform": {
|
||||
"http_proxy": {
|
||||
"enabled": false,
|
||||
@@ -543,6 +554,30 @@ TCP/IP 栈。
|
||||
|
||||
排除路由的 Android 应用包名。
|
||||
|
||||
#### include_mac_address
|
||||
|
||||
!!! question "自 sing-box 1.14.0 起"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
仅支持 Linux,且需要 `auto_route` 和 `auto_redirect` 已启用。
|
||||
|
||||
限制被路由的 MAC 地址。默认不限制。
|
||||
|
||||
与 `exclude_mac_address` 冲突。
|
||||
|
||||
#### exclude_mac_address
|
||||
|
||||
!!! question "自 sing-box 1.14.0 起"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
仅支持 Linux,且需要 `auto_route` 和 `auto_redirect` 已启用。
|
||||
|
||||
排除路由的 MAC 地址。
|
||||
|
||||
与 `include_mac_address` 冲突。
|
||||
|
||||
#### platform
|
||||
|
||||
平台特定的设置,由客户端应用提供。
|
||||
|
||||
@@ -4,6 +4,11 @@ icon: material/alert-decagram
|
||||
|
||||
# Route
|
||||
|
||||
!!! quote "Changes in sing-box 1.14.0"
|
||||
|
||||
:material-plus: [find_neighbor](#find_neighbor)
|
||||
:material-plus: [dhcp_lease_files](#dhcp_lease_files)
|
||||
|
||||
!!! quote "Changes in sing-box 1.12.0"
|
||||
|
||||
:material-plus: [default_domain_resolver](#default_domain_resolver)
|
||||
@@ -35,6 +40,9 @@ icon: material/alert-decagram
|
||||
"override_android_vpn": false,
|
||||
"default_interface": "",
|
||||
"default_mark": 0,
|
||||
"find_process": false,
|
||||
"find_neighbor": false,
|
||||
"dhcp_lease_files": [],
|
||||
"default_domain_resolver": "", // or {}
|
||||
"default_network_strategy": "",
|
||||
"default_network_type": [],
|
||||
@@ -107,6 +115,38 @@ Set routing mark by default.
|
||||
|
||||
Takes no effect if `outbound.routing_mark` is set.
|
||||
|
||||
#### find_process
|
||||
|
||||
!!! quote ""
|
||||
|
||||
Only supported on Linux, Windows, and macOS.
|
||||
|
||||
Enable process search for logging when no `process_name`, `process_path`, `package_name`, `user` or `user_id` rules exist.
|
||||
|
||||
#### find_neighbor
|
||||
|
||||
!!! question "Since sing-box 1.14.0"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
Only supported on Linux and macOS.
|
||||
|
||||
Enable neighbor resolution for logging when no `source_mac_address` or `source_hostname` rules exist.
|
||||
|
||||
See [Neighbor Resolution](/configuration/shared/neighbor/) for setup.
|
||||
|
||||
#### dhcp_lease_files
|
||||
|
||||
!!! question "Since sing-box 1.14.0"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
Only supported on Linux and macOS.
|
||||
|
||||
Custom DHCP lease file paths for hostname and MAC address resolution.
|
||||
|
||||
Automatically detected from common DHCP servers (dnsmasq, odhcpd, ISC dhcpd, Kea) if empty.
|
||||
|
||||
#### default_domain_resolver
|
||||
|
||||
!!! question "Since sing-box 1.12.0"
|
||||
|
||||
@@ -4,6 +4,11 @@ icon: material/alert-decagram
|
||||
|
||||
# 路由
|
||||
|
||||
!!! quote "sing-box 1.14.0 中的更改"
|
||||
|
||||
:material-plus: [find_neighbor](#find_neighbor)
|
||||
:material-plus: [dhcp_lease_files](#dhcp_lease_files)
|
||||
|
||||
!!! quote "sing-box 1.12.0 中的更改"
|
||||
|
||||
:material-plus: [default_domain_resolver](#default_domain_resolver)
|
||||
@@ -37,6 +42,9 @@ icon: material/alert-decagram
|
||||
"override_android_vpn": false,
|
||||
"default_interface": "",
|
||||
"default_mark": 0,
|
||||
"find_process": false,
|
||||
"find_neighbor": false,
|
||||
"dhcp_lease_files": [],
|
||||
"default_network_strategy": "",
|
||||
"default_fallback_delay": ""
|
||||
}
|
||||
@@ -106,6 +114,38 @@ icon: material/alert-decagram
|
||||
|
||||
如果设置了 `outbound.routing_mark` 设置,则不生效。
|
||||
|
||||
#### find_process
|
||||
|
||||
!!! quote ""
|
||||
|
||||
仅支持 Linux、Windows 和 macOS。
|
||||
|
||||
在没有 `process_name`、`process_path`、`package_name`、`user` 或 `user_id` 规则时启用进程搜索以输出日志。
|
||||
|
||||
#### find_neighbor
|
||||
|
||||
!!! question "自 sing-box 1.14.0 起"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
仅支持 Linux 和 macOS。
|
||||
|
||||
在没有 `source_mac_address` 或 `source_hostname` 规则时启用邻居解析以输出日志。
|
||||
|
||||
参阅 [邻居解析](/configuration/shared/neighbor/) 了解设置方法。
|
||||
|
||||
#### dhcp_lease_files
|
||||
|
||||
!!! question "自 sing-box 1.14.0 起"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
仅支持 Linux 和 macOS。
|
||||
|
||||
用于主机名和 MAC 地址解析的自定义 DHCP 租约文件路径。
|
||||
|
||||
为空时自动从常见 DHCP 服务器(dnsmasq、odhcpd、ISC dhcpd、Kea)检测。
|
||||
|
||||
#### default_domain_resolver
|
||||
|
||||
!!! question "自 sing-box 1.12.0 起"
|
||||
|
||||
@@ -2,6 +2,11 @@
|
||||
icon: material/new-box
|
||||
---
|
||||
|
||||
!!! quote "Changes in sing-box 1.14.0"
|
||||
|
||||
:material-plus: [source_mac_address](#source_mac_address)
|
||||
:material-plus: [source_hostname](#source_hostname)
|
||||
|
||||
!!! quote "Changes in sing-box 1.13.0"
|
||||
|
||||
:material-plus: [interface_address](#interface_address)
|
||||
@@ -159,6 +164,12 @@ icon: material/new-box
|
||||
"tailscale",
|
||||
"wireguard"
|
||||
],
|
||||
"source_mac_address": [
|
||||
"00:11:22:33:44:55"
|
||||
],
|
||||
"source_hostname": [
|
||||
"my-device"
|
||||
],
|
||||
"rule_set": [
|
||||
"geoip-cn",
|
||||
"geosite-cn"
|
||||
@@ -449,6 +460,26 @@ Match specified outbounds' preferred routes.
|
||||
| `tailscale` | Match MagicDNS domains and peers' allowed IPs |
|
||||
| `wireguard` | Match peers's allowed IPs |
|
||||
|
||||
#### source_mac_address
|
||||
|
||||
!!! question "Since sing-box 1.14.0"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
Only supported on Linux, macOS, or in graphical clients on Android and macOS. See [Neighbor Resolution](/configuration/shared/neighbor/) for setup.
|
||||
|
||||
Match source device MAC address.
|
||||
|
||||
#### source_hostname
|
||||
|
||||
!!! question "Since sing-box 1.14.0"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
Only supported on Linux, macOS, or in graphical clients on Android and macOS. See [Neighbor Resolution](/configuration/shared/neighbor/) for setup.
|
||||
|
||||
Match source device hostname from DHCP leases.
|
||||
|
||||
#### rule_set
|
||||
|
||||
!!! question "Since sing-box 1.8.0"
|
||||
|
||||
@@ -2,6 +2,11 @@
|
||||
icon: material/new-box
|
||||
---
|
||||
|
||||
!!! quote "sing-box 1.14.0 中的更改"
|
||||
|
||||
:material-plus: [source_mac_address](#source_mac_address)
|
||||
:material-plus: [source_hostname](#source_hostname)
|
||||
|
||||
!!! quote "sing-box 1.13.0 中的更改"
|
||||
|
||||
:material-plus: [interface_address](#interface_address)
|
||||
@@ -156,6 +161,12 @@ icon: material/new-box
|
||||
"tailscale",
|
||||
"wireguard"
|
||||
],
|
||||
"source_mac_address": [
|
||||
"00:11:22:33:44:55"
|
||||
],
|
||||
"source_hostname": [
|
||||
"my-device"
|
||||
],
|
||||
"rule_set": [
|
||||
"geoip-cn",
|
||||
"geosite-cn"
|
||||
@@ -446,6 +457,26 @@ icon: material/new-box
|
||||
| `tailscale` | 匹配 MagicDNS 域名和对端的 allowed IPs |
|
||||
| `wireguard` | 匹配对端的 allowed IPs |
|
||||
|
||||
#### source_mac_address
|
||||
|
||||
!!! question "自 sing-box 1.14.0 起"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
仅支持 Linux、macOS,或在 Android 和 macOS 图形客户端中支持。参阅 [邻居解析](/configuration/shared/neighbor/) 了解设置方法。
|
||||
|
||||
匹配源设备 MAC 地址。
|
||||
|
||||
#### source_hostname
|
||||
|
||||
!!! question "自 sing-box 1.14.0 起"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
仅支持 Linux、macOS,或在 Android 和 macOS 图形客户端中支持。参阅 [邻居解析](/configuration/shared/neighbor/) 了解设置方法。
|
||||
|
||||
匹配源设备从 DHCP 租约获取的主机名。
|
||||
|
||||
#### rule_set
|
||||
|
||||
!!! question "自 sing-box 1.8.0 起"
|
||||
|
||||
@@ -10,6 +10,14 @@ CCM (Claude Code Multiplexer) service is a multiplexing service that allows you
|
||||
|
||||
It handles OAuth authentication with Claude's API on your local machine while allowing remote Claude Code to authenticate using Auth Tokens via the `ANTHROPIC_AUTH_TOKEN` environment variable.
|
||||
|
||||
!!! quote "Changes in sing-box 1.14.0"
|
||||
|
||||
:material-plus: [credentials](#credentials)
|
||||
:material-alert: [credential_path](#credential_path)
|
||||
:material-alert: [usages_path](#usages_path)
|
||||
:material-alert: [users](#users)
|
||||
:material-alert: [detour](#detour)
|
||||
|
||||
### Structure
|
||||
|
||||
```json
|
||||
@@ -19,6 +27,7 @@ It handles OAuth authentication with Claude's API on your local machine while al
|
||||
... // Listen Fields
|
||||
|
||||
"credential_path": "",
|
||||
"credentials": [],
|
||||
"usages_path": "",
|
||||
"users": [],
|
||||
"headers": {},
|
||||
@@ -45,6 +54,106 @@ On macOS, credentials are read from the system keychain first, then fall back to
|
||||
|
||||
Refreshed tokens are automatically written back to the same location.
|
||||
|
||||
!!! question "Since sing-box 1.14.0"
|
||||
|
||||
When `credential_path` points to a file, the service can start before the file exists. The credential becomes available automatically after the file is created or updated, and becomes unavailable immediately if the file is later removed or becomes invalid.
|
||||
|
||||
On macOS without an explicit `credential_path`, keychain changes are not watched. Automatic reload only applies to the credential file path.
|
||||
|
||||
Conflict with `credentials`.
|
||||
|
||||
#### credentials
|
||||
|
||||
!!! question "Since sing-box 1.14.0"
|
||||
|
||||
List of credential configurations for multi-credential mode.
|
||||
|
||||
When set, top-level `credential_path`, `usages_path`, and `detour` are forbidden. Each user must specify a `credential` tag.
|
||||
|
||||
Each credential has a `type` field (`default`, `external`, or `balancer`) and a required `tag` field.
|
||||
|
||||
##### Default Credential
|
||||
|
||||
```json
|
||||
{
|
||||
"tag": "a",
|
||||
"credential_path": "/path/to/.credentials.json",
|
||||
"usages_path": "/path/to/usages.json",
|
||||
"detour": "",
|
||||
"reserve_5h": 20,
|
||||
"reserve_weekly": 20,
|
||||
"limit_5h": 0,
|
||||
"limit_weekly": 0
|
||||
}
|
||||
```
|
||||
|
||||
A single OAuth credential file. The `type` field can be omitted (defaults to `default`). The service can start before the file exists, and reloads file updates automatically.
|
||||
|
||||
- `credential_path`: Path to the credentials file. Same defaults as top-level `credential_path`.
|
||||
- `usages_path`: Optional usage tracking file for this credential.
|
||||
- `detour`: Outbound tag for connecting to the Claude API with this credential.
|
||||
- `reserve_5h`: Reserve threshold (1-99) for 5-hour window. Credential pauses at (100-N)% utilization. Conflict with `limit_5h`.
|
||||
- `reserve_weekly`: Reserve threshold (1-99) for weekly window. Credential pauses at (100-N)% utilization. Conflict with `limit_weekly`.
|
||||
- `limit_5h`: Explicit utilization cap (0-100) for 5-hour window. `0` means unset. Credential pauses when utilization reaches this value. Conflict with `reserve_5h`.
|
||||
- `limit_weekly`: Explicit utilization cap (0-100) for weekly window. `0` means unset. Credential pauses when utilization reaches this value. Conflict with `reserve_weekly`.
|
||||
|
||||
##### Balancer Credential
|
||||
|
||||
```json
|
||||
{
|
||||
"tag": "pool",
|
||||
"type": "balancer",
|
||||
"strategy": "",
|
||||
"credentials": ["a", "b"]
|
||||
}
|
||||
```
|
||||
|
||||
Assigns sessions to default credentials based on the selected strategy. Sessions are sticky until the assigned credential hits a rate limit.
|
||||
|
||||
- `strategy`: Selection strategy. One of `least_used` `round_robin` `random` `fallback`. `least_used` will be used by default.
|
||||
- `credentials`: ==Required== List of default credential tags.
|
||||
|
||||
##### Fallback Strategy
|
||||
|
||||
```json
|
||||
{
|
||||
"tag": "backup",
|
||||
"type": "balancer",
|
||||
"strategy": "fallback",
|
||||
"credentials": ["a", "b"]
|
||||
}
|
||||
```
|
||||
|
||||
A balancer with `strategy: "fallback"` uses credentials in order. It falls through to the next when the current one is exhausted.
|
||||
|
||||
- `credentials`: ==Required== Ordered list of default credential tags.
|
||||
|
||||
##### External Credential
|
||||
|
||||
```json
|
||||
{
|
||||
"tag": "remote",
|
||||
"type": "external",
|
||||
"url": "",
|
||||
"server": "",
|
||||
"server_port": 0,
|
||||
"token": "",
|
||||
"reverse": false,
|
||||
"detour": "",
|
||||
"usages_path": ""
|
||||
}
|
||||
```
|
||||
|
||||
Proxies requests through a remote CCM instance instead of using a local OAuth credential.
|
||||
|
||||
- `url`: URL of the remote CCM instance. Omit to create a receiver that only waits for inbound reverse connections.
|
||||
- `server`: Override server address for dialing, separate from URL hostname.
|
||||
- `server_port`: Override server port for dialing.
|
||||
- `token`: ==Required== Authentication token for the remote instance.
|
||||
- `reverse`: Enable connector mode. Requires `url`. A connector dials out to `/ccm/v1/reverse` on the remote instance and cannot serve local requests directly. When `url` is set without `reverse`, the credential proxies requests through the remote instance normally and prefers an established reverse connection when one is available.
|
||||
- `detour`: Outbound tag for connecting to the remote instance.
|
||||
- `usages_path`: Optional usage tracking file.
|
||||
|
||||
#### usages_path
|
||||
|
||||
Path to the file for storing aggregated API usage statistics.
|
||||
@@ -60,6 +169,10 @@ Statistics are organized by model, context window (200k standard vs 1M premium),
|
||||
|
||||
The statistics file is automatically saved every minute and upon service shutdown.
|
||||
|
||||
!!! question "Since sing-box 1.14.0"
|
||||
|
||||
Conflict with `credentials`. In multi-credential mode, use `usages_path` on individual default credentials.
|
||||
|
||||
#### users
|
||||
|
||||
List of authorized users for token authentication.
|
||||
@@ -71,7 +184,10 @@ Object format:
|
||||
```json
|
||||
{
|
||||
"name": "",
|
||||
"token": ""
|
||||
"token": "",
|
||||
"credential": "",
|
||||
"external_credential": "",
|
||||
"allow_external_usage": false
|
||||
}
|
||||
```
|
||||
|
||||
@@ -80,6 +196,12 @@ Object fields:
|
||||
- `name`: Username identifier for tracking purposes.
|
||||
- `token`: Bearer token for authentication. Claude Code authenticates by setting the `ANTHROPIC_AUTH_TOKEN` environment variable to their token value.
|
||||
|
||||
!!! question "Since sing-box 1.14.0"
|
||||
|
||||
- `credential`: Credential tag to use for this user. ==Required== when `credentials` is set.
|
||||
- `external_credential`: Tag of an external credential used only to rewrite response rate-limit headers with aggregated utilization from this user's other available credentials. It does not control request routing; request selection still comes from `credential` and `allow_external_usage`.
|
||||
- `allow_external_usage`: Allow this user to use external credentials. `false` by default.
|
||||
|
||||
#### headers
|
||||
|
||||
Custom HTTP headers to send to the Claude API.
|
||||
@@ -90,6 +212,10 @@ These headers will override any existing headers with the same name.
|
||||
|
||||
Outbound tag for connecting to the Claude API.
|
||||
|
||||
!!! question "Since sing-box 1.14.0"
|
||||
|
||||
Conflict with `credentials`. In multi-credential mode, use `detour` on individual default credentials.
|
||||
|
||||
#### tls
|
||||
|
||||
TLS configuration, see [TLS](/configuration/shared/tls/#inbound).
|
||||
@@ -129,3 +255,51 @@ export ANTHROPIC_AUTH_TOKEN="ak-ccm-hello-world"
|
||||
|
||||
claude
|
||||
```
|
||||
|
||||
### Example with Multiple Credentials
|
||||
|
||||
#### Server
|
||||
|
||||
```json
|
||||
{
|
||||
"services": [
|
||||
{
|
||||
"type": "ccm",
|
||||
"listen": "0.0.0.0",
|
||||
"listen_port": 8080,
|
||||
"credentials": [
|
||||
{
|
||||
"tag": "a",
|
||||
"credential_path": "/home/user/.claude-a/.credentials.json",
|
||||
"usages_path": "/data/usages-a.json",
|
||||
"reserve_5h": 20,
|
||||
"reserve_weekly": 20
|
||||
},
|
||||
{
|
||||
"tag": "b",
|
||||
"credential_path": "/home/user/.claude-b/.credentials.json",
|
||||
"reserve_5h": 10,
|
||||
"reserve_weekly": 10
|
||||
},
|
||||
{
|
||||
"tag": "pool",
|
||||
"type": "balancer",
|
||||
"credentials": ["a", "b"]
|
||||
}
|
||||
],
|
||||
"users": [
|
||||
{
|
||||
"name": "alice",
|
||||
"token": "ak-ccm-hello-world",
|
||||
"credential": "pool"
|
||||
},
|
||||
{
|
||||
"name": "bob",
|
||||
"token": "ak-ccm-hello-bob",
|
||||
"credential": "a"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
@@ -10,6 +10,14 @@ CCM(Claude Code 多路复用器)服务是一个多路复用服务,允许
|
||||
|
||||
它在本地机器上处理与 Claude API 的 OAuth 身份验证,同时允许远程 Claude Code 通过 `ANTHROPIC_AUTH_TOKEN` 环境变量使用认证令牌进行身份验证。
|
||||
|
||||
!!! quote "sing-box 1.14.0 中的更改"
|
||||
|
||||
:material-plus: [credentials](#credentials)
|
||||
:material-alert: [credential_path](#credential_path)
|
||||
:material-alert: [usages_path](#usages_path)
|
||||
:material-alert: [users](#users)
|
||||
:material-alert: [detour](#detour)
|
||||
|
||||
### 结构
|
||||
|
||||
```json
|
||||
@@ -19,6 +27,7 @@ CCM(Claude Code 多路复用器)服务是一个多路复用服务,允许
|
||||
... // 监听字段
|
||||
|
||||
"credential_path": "",
|
||||
"credentials": [],
|
||||
"usages_path": "",
|
||||
"users": [],
|
||||
"headers": {},
|
||||
@@ -45,6 +54,106 @@ Claude Code OAuth 凭据文件的路径。
|
||||
|
||||
刷新的令牌会自动写回相同位置。
|
||||
|
||||
!!! question "自 sing-box 1.14.0 起"
|
||||
|
||||
当 `credential_path` 指向文件时,即使文件尚不存在,服务也可以启动。文件被创建或更新后,凭据会自动变为可用;如果文件之后被删除或变为无效,该凭据会立即变为不可用。
|
||||
|
||||
在 macOS 上如果未显式设置 `credential_path`,不会监听钥匙串变化。自动重载只作用于凭据文件路径。
|
||||
|
||||
与 `credentials` 冲突。
|
||||
|
||||
#### credentials
|
||||
|
||||
!!! question "自 sing-box 1.14.0 起"
|
||||
|
||||
多凭据模式的凭据配置列表。
|
||||
|
||||
设置后,顶层 `credential_path`、`usages_path` 和 `detour` 被禁止。每个用户必须指定 `credential` 标签。
|
||||
|
||||
每个凭据有一个 `type` 字段(`default`、`external` 或 `balancer`)和一个必填的 `tag` 字段。
|
||||
|
||||
##### 默认凭据
|
||||
|
||||
```json
|
||||
{
|
||||
"tag": "a",
|
||||
"credential_path": "/path/to/.credentials.json",
|
||||
"usages_path": "/path/to/usages.json",
|
||||
"detour": "",
|
||||
"reserve_5h": 20,
|
||||
"reserve_weekly": 20,
|
||||
"limit_5h": 0,
|
||||
"limit_weekly": 0
|
||||
}
|
||||
```
|
||||
|
||||
单个 OAuth 凭据文件。`type` 字段可以省略(默认为 `default`)。即使文件尚不存在,服务也可以启动,并会自动重载文件更新。
|
||||
|
||||
- `credential_path`:凭据文件的路径。默认值与顶层 `credential_path` 相同。
|
||||
- `usages_path`:此凭据的可选使用跟踪文件。
|
||||
- `detour`:此凭据用于连接 Claude API 的出站标签。
|
||||
- `reserve_5h`:5 小时窗口的保留阈值(1-99)。凭据在利用率达到 (100-N)% 时暂停。与 `limit_5h` 冲突。
|
||||
- `reserve_weekly`:每周窗口的保留阈值(1-99)。凭据在利用率达到 (100-N)% 时暂停。与 `limit_weekly` 冲突。
|
||||
- `limit_5h`:5 小时窗口的显式利用率上限(0-100)。`0` 表示未设置显式上限。凭据在利用率达到此值时暂停。与 `reserve_5h` 冲突。
|
||||
- `limit_weekly`:每周窗口的显式利用率上限(0-100)。`0` 表示未设置显式上限。凭据在利用率达到此值时暂停。与 `reserve_weekly` 冲突。
|
||||
|
||||
##### 均衡凭据
|
||||
|
||||
```json
|
||||
{
|
||||
"tag": "pool",
|
||||
"type": "balancer",
|
||||
"strategy": "",
|
||||
"credentials": ["a", "b"]
|
||||
}
|
||||
```
|
||||
|
||||
根据选择的策略将会话分配给默认凭据。会话保持粘性,直到分配的凭据触发速率限制。
|
||||
|
||||
- `strategy`:选择策略。可选值:`least_used` `round_robin` `random` `fallback`。默认使用 `least_used`。
|
||||
- `credentials`:==必填== 默认凭据标签列表。
|
||||
|
||||
##### 回退策略
|
||||
|
||||
```json
|
||||
{
|
||||
"tag": "backup",
|
||||
"type": "balancer",
|
||||
"strategy": "fallback",
|
||||
"credentials": ["a", "b"]
|
||||
}
|
||||
```
|
||||
|
||||
将 `strategy` 设为 `fallback` 的均衡凭据会按顺序使用凭据。当前凭据耗尽后切换到下一个。
|
||||
|
||||
- `credentials`:==必填== 有序的默认凭据标签列表。
|
||||
|
||||
##### 外部凭据
|
||||
|
||||
```json
|
||||
{
|
||||
"tag": "remote",
|
||||
"type": "external",
|
||||
"url": "",
|
||||
"server": "",
|
||||
"server_port": 0,
|
||||
"token": "",
|
||||
"reverse": false,
|
||||
"detour": "",
|
||||
"usages_path": ""
|
||||
}
|
||||
```
|
||||
|
||||
通过远程 CCM 实例代理请求,而非使用本地 OAuth 凭据。
|
||||
|
||||
- `url`:远程 CCM 实例的 URL。省略时,此凭据作为仅等待入站反向连接的接收器。
|
||||
- `server`:覆盖拨号的服务器地址,与 URL 主机名分开。
|
||||
- `server_port`:覆盖拨号的服务器端口。
|
||||
- `token`:==必填== 远程实例的身份验证令牌。
|
||||
- `reverse`:启用连接器模式。要求设置 `url`。启用后,此凭据会主动拨出到远程实例的 `/ccm/v1/reverse`,且不能直接为本地请求提供服务。当设置了 `url` 但未启用 `reverse` 时,此凭据会正常通过远程实例转发请求,并在反向连接建立后优先使用该反向连接。
|
||||
- `detour`:用于连接远程实例的出站标签。
|
||||
- `usages_path`:可选的使用跟踪文件。
|
||||
|
||||
#### usages_path
|
||||
|
||||
用于存储聚合 API 使用统计信息的文件路径。
|
||||
@@ -60,6 +169,10 @@ Claude Code OAuth 凭据文件的路径。
|
||||
|
||||
统计文件每分钟自动保存一次,并在服务关闭时保存。
|
||||
|
||||
!!! question "自 sing-box 1.14.0 起"
|
||||
|
||||
与 `credentials` 冲突。在多凭据模式下,在各个默认凭据上使用 `usages_path`。
|
||||
|
||||
#### users
|
||||
|
||||
用于令牌身份验证的授权用户列表。
|
||||
@@ -71,7 +184,10 @@ Claude Code OAuth 凭据文件的路径。
|
||||
```json
|
||||
{
|
||||
"name": "",
|
||||
"token": ""
|
||||
"token": "",
|
||||
"credential": "",
|
||||
"external_credential": "",
|
||||
"allow_external_usage": false
|
||||
}
|
||||
```
|
||||
|
||||
@@ -80,6 +196,12 @@ Claude Code OAuth 凭据文件的路径。
|
||||
- `name`:用于跟踪的用户名标识符。
|
||||
- `token`:用于身份验证的 Bearer 令牌。Claude Code 通过设置 `ANTHROPIC_AUTH_TOKEN` 环境变量为其令牌值进行身份验证。
|
||||
|
||||
!!! question "自 sing-box 1.14.0 起"
|
||||
|
||||
- `credential`:此用户使用的凭据标签。设置 `credentials` 时==必填==。
|
||||
- `external_credential`:仅用于用此用户其他可用凭据的聚合利用率重写响应速率限制头的外部凭据标签。它不参与请求路由;请求选择仍由 `credential` 和 `allow_external_usage` 决定。
|
||||
- `allow_external_usage`:允许此用户使用外部凭据。默认为 `false`。
|
||||
|
||||
#### headers
|
||||
|
||||
发送到 Claude API 的自定义 HTTP 头。
|
||||
@@ -90,6 +212,10 @@ Claude Code OAuth 凭据文件的路径。
|
||||
|
||||
用于连接 Claude API 的出站标签。
|
||||
|
||||
!!! question "自 sing-box 1.14.0 起"
|
||||
|
||||
与 `credentials` 冲突。在多凭据模式下,在各个默认凭据上使用 `detour`。
|
||||
|
||||
#### tls
|
||||
|
||||
TLS 配置,参阅 [TLS](/zh/configuration/shared/tls/#inbound)。
|
||||
@@ -129,3 +255,51 @@ export ANTHROPIC_AUTH_TOKEN="ak-ccm-hello-world"
|
||||
|
||||
claude
|
||||
```
|
||||
|
||||
### 多凭据示例
|
||||
|
||||
#### 服务端
|
||||
|
||||
```json
|
||||
{
|
||||
"services": [
|
||||
{
|
||||
"type": "ccm",
|
||||
"listen": "0.0.0.0",
|
||||
"listen_port": 8080,
|
||||
"credentials": [
|
||||
{
|
||||
"tag": "a",
|
||||
"credential_path": "/home/user/.claude-a/.credentials.json",
|
||||
"usages_path": "/data/usages-a.json",
|
||||
"reserve_5h": 20,
|
||||
"reserve_weekly": 20
|
||||
},
|
||||
{
|
||||
"tag": "b",
|
||||
"credential_path": "/home/user/.claude-b/.credentials.json",
|
||||
"reserve_5h": 10,
|
||||
"reserve_weekly": 10
|
||||
},
|
||||
{
|
||||
"tag": "pool",
|
||||
"type": "balancer",
|
||||
"credentials": ["a", "b"]
|
||||
}
|
||||
],
|
||||
"users": [
|
||||
{
|
||||
"name": "alice",
|
||||
"token": "ak-ccm-hello-world",
|
||||
"credential": "pool"
|
||||
},
|
||||
{
|
||||
"name": "bob",
|
||||
"token": "ak-ccm-hello-bob",
|
||||
"credential": "a"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
@@ -10,6 +10,14 @@ OCM (OpenAI Codex Multiplexer) service is a multiplexing service that allows you
|
||||
|
||||
It handles OAuth authentication with OpenAI's API on your local machine while allowing remote clients to authenticate using custom tokens.
|
||||
|
||||
!!! quote "Changes in sing-box 1.14.0"
|
||||
|
||||
:material-plus: [credentials](#credentials)
|
||||
:material-alert: [credential_path](#credential_path)
|
||||
:material-alert: [usages_path](#usages_path)
|
||||
:material-alert: [users](#users)
|
||||
:material-alert: [detour](#detour)
|
||||
|
||||
### Structure
|
||||
|
||||
```json
|
||||
@@ -19,6 +27,7 @@ It handles OAuth authentication with OpenAI's API on your local machine while al
|
||||
... // Listen Fields
|
||||
|
||||
"credential_path": "",
|
||||
"credentials": [],
|
||||
"usages_path": "",
|
||||
"users": [],
|
||||
"headers": {},
|
||||
@@ -43,6 +52,104 @@ If not specified, defaults to:
|
||||
|
||||
Refreshed tokens are automatically written back to the same location.
|
||||
|
||||
!!! question "Since sing-box 1.14.0"
|
||||
|
||||
When `credential_path` points to a file, the service can start before the file exists. The credential becomes available automatically after the file is created or updated, and becomes unavailable immediately if the file is later removed or becomes invalid.
|
||||
|
||||
Conflict with `credentials`.
|
||||
|
||||
#### credentials
|
||||
|
||||
!!! question "Since sing-box 1.14.0"
|
||||
|
||||
List of credential configurations for multi-credential mode.
|
||||
|
||||
When set, top-level `credential_path`, `usages_path`, and `detour` are forbidden. Each user must specify a `credential` tag.
|
||||
|
||||
Each credential has a `type` field (`default`, `external`, or `balancer`) and a required `tag` field.
|
||||
|
||||
##### Default Credential
|
||||
|
||||
```json
|
||||
{
|
||||
"tag": "a",
|
||||
"credential_path": "/path/to/auth.json",
|
||||
"usages_path": "/path/to/usages.json",
|
||||
"detour": "",
|
||||
"reserve_5h": 20,
|
||||
"reserve_weekly": 20,
|
||||
"limit_5h": 0,
|
||||
"limit_weekly": 0
|
||||
}
|
||||
```
|
||||
|
||||
A single OAuth credential file. The `type` field can be omitted (defaults to `default`). The service can start before the file exists, and reloads file updates automatically.
|
||||
|
||||
- `credential_path`: Path to the credentials file. Same defaults as top-level `credential_path`.
|
||||
- `usages_path`: Optional usage tracking file for this credential.
|
||||
- `detour`: Outbound tag for connecting to the OpenAI API with this credential.
|
||||
- `reserve_5h`: Reserve threshold (1-99) for primary rate limit window. Credential pauses at (100-N)% utilization. Conflict with `limit_5h`.
|
||||
- `reserve_weekly`: Reserve threshold (1-99) for secondary (weekly) rate limit window. Credential pauses at (100-N)% utilization. Conflict with `limit_weekly`.
|
||||
- `limit_5h`: Explicit utilization cap (0-100) for primary rate limit window. `0` means unset. Credential pauses when utilization reaches this value. Conflict with `reserve_5h`.
|
||||
- `limit_weekly`: Explicit utilization cap (0-100) for secondary (weekly) rate limit window. `0` means unset. Credential pauses when utilization reaches this value. Conflict with `reserve_weekly`.
|
||||
|
||||
##### Balancer Credential
|
||||
|
||||
```json
|
||||
{
|
||||
"tag": "pool",
|
||||
"type": "balancer",
|
||||
"strategy": "",
|
||||
"credentials": ["a", "b"]
|
||||
}
|
||||
```
|
||||
|
||||
Assigns sessions to default credentials based on the selected strategy. Sessions are sticky until the assigned credential hits a rate limit.
|
||||
|
||||
- `strategy`: Selection strategy. One of `least_used` `round_robin` `random` `fallback`. `least_used` will be used by default.
|
||||
- `credentials`: ==Required== List of default credential tags.
|
||||
|
||||
##### Fallback Strategy
|
||||
|
||||
```json
|
||||
{
|
||||
"tag": "backup",
|
||||
"type": "balancer",
|
||||
"strategy": "fallback",
|
||||
"credentials": ["a", "b"]
|
||||
}
|
||||
```
|
||||
|
||||
A balancer with `strategy: "fallback"` uses credentials in order. It falls through to the next when the current one is exhausted.
|
||||
|
||||
- `credentials`: ==Required== Ordered list of default credential tags.
|
||||
|
||||
##### External Credential
|
||||
|
||||
```json
|
||||
{
|
||||
"tag": "remote",
|
||||
"type": "external",
|
||||
"url": "",
|
||||
"server": "",
|
||||
"server_port": 0,
|
||||
"token": "",
|
||||
"reverse": false,
|
||||
"detour": "",
|
||||
"usages_path": ""
|
||||
}
|
||||
```
|
||||
|
||||
Proxies requests through a remote OCM instance instead of using a local OAuth credential.
|
||||
|
||||
- `url`: URL of the remote OCM instance. Omit to create a receiver that only waits for inbound reverse connections.
|
||||
- `server`: Override server address for dialing, separate from URL hostname.
|
||||
- `server_port`: Override server port for dialing.
|
||||
- `token`: ==Required== Authentication token for the remote instance.
|
||||
- `reverse`: Enable connector mode. Requires `url`. A connector dials out to `/ocm/v1/reverse` on the remote instance and cannot serve local requests directly. When `url` is set without `reverse`, the credential proxies requests through the remote instance normally and prefers an established reverse connection when one is available.
|
||||
- `detour`: Outbound tag for connecting to the remote instance.
|
||||
- `usages_path`: Optional usage tracking file.
|
||||
|
||||
#### usages_path
|
||||
|
||||
Path to the file for storing aggregated API usage statistics.
|
||||
@@ -58,6 +165,10 @@ Statistics are organized by model and optionally by user when authentication is
|
||||
|
||||
The statistics file is automatically saved every minute and upon service shutdown.
|
||||
|
||||
!!! question "Since sing-box 1.14.0"
|
||||
|
||||
Conflict with `credentials`. In multi-credential mode, use `usages_path` on individual default credentials.
|
||||
|
||||
#### users
|
||||
|
||||
List of authorized users for token authentication.
|
||||
@@ -69,7 +180,10 @@ Object format:
|
||||
```json
|
||||
{
|
||||
"name": "",
|
||||
"token": ""
|
||||
"token": "",
|
||||
"credential": "",
|
||||
"external_credential": "",
|
||||
"allow_external_usage": false
|
||||
}
|
||||
```
|
||||
|
||||
@@ -78,6 +192,12 @@ Object fields:
|
||||
- `name`: Username identifier for tracking purposes.
|
||||
- `token`: Bearer token for authentication. Clients authenticate by setting the `Authorization: Bearer <token>` header.
|
||||
|
||||
!!! question "Since sing-box 1.14.0"
|
||||
|
||||
- `credential`: Credential tag to use for this user. ==Required== when `credentials` is set.
|
||||
- `external_credential`: Tag of an external credential used only to rewrite response rate-limit headers with aggregated utilization from this user's other available credentials. It does not control request routing; request selection still comes from `credential` and `allow_external_usage`.
|
||||
- `allow_external_usage`: Allow this user to use external credentials. `false` by default.
|
||||
|
||||
#### headers
|
||||
|
||||
Custom HTTP headers to send to the OpenAI API.
|
||||
@@ -88,6 +208,10 @@ These headers will override any existing headers with the same name.
|
||||
|
||||
Outbound tag for connecting to the OpenAI API.
|
||||
|
||||
!!! question "Since sing-box 1.14.0"
|
||||
|
||||
Conflict with `credentials`. In multi-credential mode, use `detour` on individual default credentials.
|
||||
|
||||
#### tls
|
||||
|
||||
TLS configuration, see [TLS](/configuration/shared/tls/#inbound).
|
||||
@@ -183,3 +307,51 @@ Then run:
|
||||
```bash
|
||||
codex --profile ocm
|
||||
```
|
||||
|
||||
### Example with Multiple Credentials
|
||||
|
||||
#### Server
|
||||
|
||||
```json
|
||||
{
|
||||
"services": [
|
||||
{
|
||||
"type": "ocm",
|
||||
"listen": "0.0.0.0",
|
||||
"listen_port": 8080,
|
||||
"credentials": [
|
||||
{
|
||||
"tag": "a",
|
||||
"credential_path": "/home/user/.codex-a/auth.json",
|
||||
"usages_path": "/data/usages-a.json",
|
||||
"reserve_5h": 20,
|
||||
"reserve_weekly": 20
|
||||
},
|
||||
{
|
||||
"tag": "b",
|
||||
"credential_path": "/home/user/.codex-b/auth.json",
|
||||
"reserve_5h": 10,
|
||||
"reserve_weekly": 10
|
||||
},
|
||||
{
|
||||
"tag": "pool",
|
||||
"type": "balancer",
|
||||
"credentials": ["a", "b"]
|
||||
}
|
||||
],
|
||||
"users": [
|
||||
{
|
||||
"name": "alice",
|
||||
"token": "sk-ocm-hello-world",
|
||||
"credential": "pool"
|
||||
},
|
||||
{
|
||||
"name": "bob",
|
||||
"token": "sk-ocm-hello-bob",
|
||||
"credential": "a"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
@@ -10,6 +10,14 @@ OCM(OpenAI Codex 多路复用器)服务是一个多路复用服务,允许
|
||||
|
||||
它在本地机器上处理与 OpenAI API 的 OAuth 身份验证,同时允许远程客户端使用自定义令牌进行身份验证。
|
||||
|
||||
!!! quote "sing-box 1.14.0 中的更改"
|
||||
|
||||
:material-plus: [credentials](#credentials)
|
||||
:material-alert: [credential_path](#credential_path)
|
||||
:material-alert: [usages_path](#usages_path)
|
||||
:material-alert: [users](#users)
|
||||
:material-alert: [detour](#detour)
|
||||
|
||||
### 结构
|
||||
|
||||
```json
|
||||
@@ -19,6 +27,7 @@ OCM(OpenAI Codex 多路复用器)服务是一个多路复用服务,允许
|
||||
... // 监听字段
|
||||
|
||||
"credential_path": "",
|
||||
"credentials": [],
|
||||
"usages_path": "",
|
||||
"users": [],
|
||||
"headers": {},
|
||||
@@ -43,6 +52,104 @@ OpenAI OAuth 凭据文件的路径。
|
||||
|
||||
刷新的令牌会自动写回相同位置。
|
||||
|
||||
!!! question "自 sing-box 1.14.0 起"
|
||||
|
||||
当 `credential_path` 指向文件时,即使文件尚不存在,服务也可以启动。文件被创建或更新后,凭据会自动变为可用;如果文件之后被删除或变为无效,该凭据会立即变为不可用。
|
||||
|
||||
与 `credentials` 冲突。
|
||||
|
||||
#### credentials
|
||||
|
||||
!!! question "自 sing-box 1.14.0 起"
|
||||
|
||||
多凭据模式的凭据配置列表。
|
||||
|
||||
设置后,顶层 `credential_path`、`usages_path` 和 `detour` 被禁止。每个用户必须指定 `credential` 标签。
|
||||
|
||||
每个凭据有一个 `type` 字段(`default`、`external` 或 `balancer`)和一个必填的 `tag` 字段。
|
||||
|
||||
##### 默认凭据
|
||||
|
||||
```json
|
||||
{
|
||||
"tag": "a",
|
||||
"credential_path": "/path/to/auth.json",
|
||||
"usages_path": "/path/to/usages.json",
|
||||
"detour": "",
|
||||
"reserve_5h": 20,
|
||||
"reserve_weekly": 20,
|
||||
"limit_5h": 0,
|
||||
"limit_weekly": 0
|
||||
}
|
||||
```
|
||||
|
||||
单个 OAuth 凭据文件。`type` 字段可以省略(默认为 `default`)。即使文件尚不存在,服务也可以启动,并会自动重载文件更新。
|
||||
|
||||
- `credential_path`:凭据文件的路径。默认值与顶层 `credential_path` 相同。
|
||||
- `usages_path`:此凭据的可选使用跟踪文件。
|
||||
- `detour`:此凭据用于连接 OpenAI API 的出站标签。
|
||||
- `reserve_5h`:主要速率限制窗口的保留阈值(1-99)。凭据在利用率达到 (100-N)% 时暂停。与 `limit_5h` 冲突。
|
||||
- `reserve_weekly`:次要(每周)速率限制窗口的保留阈值(1-99)。凭据在利用率达到 (100-N)% 时暂停。与 `limit_weekly` 冲突。
|
||||
- `limit_5h`:主要速率限制窗口的显式利用率上限(0-100)。`0` 表示未设置显式上限。凭据在利用率达到此值时暂停。与 `reserve_5h` 冲突。
|
||||
- `limit_weekly`:次要(每周)速率限制窗口的显式利用率上限(0-100)。`0` 表示未设置显式上限。凭据在利用率达到此值时暂停。与 `reserve_weekly` 冲突。
|
||||
|
||||
##### 均衡凭据
|
||||
|
||||
```json
|
||||
{
|
||||
"tag": "pool",
|
||||
"type": "balancer",
|
||||
"strategy": "",
|
||||
"credentials": ["a", "b"]
|
||||
}
|
||||
```
|
||||
|
||||
根据选择的策略将会话分配给默认凭据。会话保持粘性,直到分配的凭据触发速率限制。
|
||||
|
||||
- `strategy`:选择策略。可选值:`least_used` `round_robin` `random` `fallback`。默认使用 `least_used`。
|
||||
- `credentials`:==必填== 默认凭据标签列表。
|
||||
|
||||
##### 回退策略
|
||||
|
||||
```json
|
||||
{
|
||||
"tag": "backup",
|
||||
"type": "balancer",
|
||||
"strategy": "fallback",
|
||||
"credentials": ["a", "b"]
|
||||
}
|
||||
```
|
||||
|
||||
将 `strategy` 设为 `fallback` 的均衡凭据会按顺序使用凭据。当前凭据耗尽后切换到下一个。
|
||||
|
||||
- `credentials`:==必填== 有序的默认凭据标签列表。
|
||||
|
||||
##### 外部凭据
|
||||
|
||||
```json
|
||||
{
|
||||
"tag": "remote",
|
||||
"type": "external",
|
||||
"url": "",
|
||||
"server": "",
|
||||
"server_port": 0,
|
||||
"token": "",
|
||||
"reverse": false,
|
||||
"detour": "",
|
||||
"usages_path": ""
|
||||
}
|
||||
```
|
||||
|
||||
通过远程 OCM 实例代理请求,而非使用本地 OAuth 凭据。
|
||||
|
||||
- `url`:远程 OCM 实例的 URL。省略时,此凭据作为仅等待入站反向连接的接收器。
|
||||
- `server`:覆盖拨号的服务器地址,与 URL 主机名分开。
|
||||
- `server_port`:覆盖拨号的服务器端口。
|
||||
- `token`:==必填== 远程实例的身份验证令牌。
|
||||
- `reverse`:启用连接器模式。要求设置 `url`。启用后,此凭据会主动拨出到远程实例的 `/ocm/v1/reverse`,且不能直接为本地请求提供服务。当设置了 `url` 但未启用 `reverse` 时,此凭据会正常通过远程实例转发请求,并在反向连接建立后优先使用该反向连接。
|
||||
- `detour`:用于连接远程实例的出站标签。
|
||||
- `usages_path`:可选的使用跟踪文件。
|
||||
|
||||
#### usages_path
|
||||
|
||||
用于存储聚合 API 使用统计信息的文件路径。
|
||||
@@ -58,6 +165,10 @@ OpenAI OAuth 凭据文件的路径。
|
||||
|
||||
统计文件每分钟自动保存一次,并在服务关闭时保存。
|
||||
|
||||
!!! question "自 sing-box 1.14.0 起"
|
||||
|
||||
与 `credentials` 冲突。在多凭据模式下,在各个默认凭据上使用 `usages_path`。
|
||||
|
||||
#### users
|
||||
|
||||
用于令牌身份验证的授权用户列表。
|
||||
@@ -69,7 +180,10 @@ OpenAI OAuth 凭据文件的路径。
|
||||
```json
|
||||
{
|
||||
"name": "",
|
||||
"token": ""
|
||||
"token": "",
|
||||
"credential": "",
|
||||
"external_credential": "",
|
||||
"allow_external_usage": false
|
||||
}
|
||||
```
|
||||
|
||||
@@ -78,6 +192,12 @@ OpenAI OAuth 凭据文件的路径。
|
||||
- `name`:用于跟踪的用户名标识符。
|
||||
- `token`:用于身份验证的 Bearer 令牌。客户端通过设置 `Authorization: Bearer <token>` 头进行身份验证。
|
||||
|
||||
!!! question "自 sing-box 1.14.0 起"
|
||||
|
||||
- `credential`:此用户使用的凭据标签。设置 `credentials` 时==必填==。
|
||||
- `external_credential`:仅用于用此用户其他可用凭据的聚合利用率重写响应速率限制头的外部凭据标签。它不参与请求路由;请求选择仍由 `credential` 和 `allow_external_usage` 决定。
|
||||
- `allow_external_usage`:允许此用户使用外部凭据。默认为 `false`。
|
||||
|
||||
#### headers
|
||||
|
||||
发送到 OpenAI API 的自定义 HTTP 头。
|
||||
@@ -88,6 +208,10 @@ OpenAI OAuth 凭据文件的路径。
|
||||
|
||||
用于连接 OpenAI API 的出站标签。
|
||||
|
||||
!!! question "自 sing-box 1.14.0 起"
|
||||
|
||||
与 `credentials` 冲突。在多凭据模式下,在各个默认凭据上使用 `detour`。
|
||||
|
||||
#### tls
|
||||
|
||||
TLS 配置,参阅 [TLS](/zh/configuration/shared/tls/#inbound)。
|
||||
@@ -184,3 +308,51 @@ model_provider = "ocm"
|
||||
```bash
|
||||
codex --profile ocm
|
||||
```
|
||||
|
||||
### 多凭据示例
|
||||
|
||||
#### 服务端
|
||||
|
||||
```json
|
||||
{
|
||||
"services": [
|
||||
{
|
||||
"type": "ocm",
|
||||
"listen": "0.0.0.0",
|
||||
"listen_port": 8080,
|
||||
"credentials": [
|
||||
{
|
||||
"tag": "a",
|
||||
"credential_path": "/home/user/.codex-a/auth.json",
|
||||
"usages_path": "/data/usages-a.json",
|
||||
"reserve_5h": 20,
|
||||
"reserve_weekly": 20
|
||||
},
|
||||
{
|
||||
"tag": "b",
|
||||
"credential_path": "/home/user/.codex-b/auth.json",
|
||||
"reserve_5h": 10,
|
||||
"reserve_weekly": 10
|
||||
},
|
||||
{
|
||||
"tag": "pool",
|
||||
"type": "balancer",
|
||||
"credentials": ["a", "b"]
|
||||
}
|
||||
],
|
||||
"users": [
|
||||
{
|
||||
"name": "alice",
|
||||
"token": "sk-ocm-hello-world",
|
||||
"credential": "pool"
|
||||
},
|
||||
{
|
||||
"name": "bob",
|
||||
"token": "sk-ocm-hello-bob",
|
||||
"credential": "a"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
49
docs/configuration/shared/neighbor.md
Normal file
49
docs/configuration/shared/neighbor.md
Normal file
@@ -0,0 +1,49 @@
|
||||
---
|
||||
icon: material/lan
|
||||
---
|
||||
|
||||
# Neighbor Resolution
|
||||
|
||||
Match LAN devices by MAC address and hostname using
|
||||
[`source_mac_address`](/configuration/route/rule/#source_mac_address) and
|
||||
[`source_hostname`](/configuration/route/rule/#source_hostname) rule items.
|
||||
|
||||
Neighbor resolution is automatically enabled when these rule items exist.
|
||||
Use [`route.find_neighbor`](/configuration/route/#find_neighbor) to force enable it for logging without rules.
|
||||
|
||||
## Linux
|
||||
|
||||
Works natively. No special setup required.
|
||||
|
||||
Hostname resolution requires DHCP lease files,
|
||||
automatically detected from common DHCP servers (dnsmasq, odhcpd, ISC dhcpd, Kea).
|
||||
Custom paths can be set via [`route.dhcp_lease_files`](/configuration/route/#dhcp_lease_files).
|
||||
|
||||
## Android
|
||||
|
||||
!!! quote ""
|
||||
|
||||
Only supported in graphical clients.
|
||||
|
||||
Requires Android 11 or above and ROOT.
|
||||
|
||||
Must use [VPNHotspot](https://github.com/Mygod/VPNHotspot) to share the VPN connection.
|
||||
ROM built-in features like "Use VPN for connected devices" can share VPN
|
||||
but cannot provide MAC address or hostname information.
|
||||
|
||||
Set **IP Masquerade Mode** to **None** in VPNHotspot settings.
|
||||
|
||||
Only route/DNS rules are supported. TUN include/exclude routes are not supported.
|
||||
|
||||
### Hostname Visibility
|
||||
|
||||
Hostname is only visible in sing-box if it is visible in VPNHotspot.
|
||||
For Apple devices, change **Private Wi-Fi Address** from **Rotating** to **Fixed** in the Wi-Fi settings
|
||||
of the connected network. Non-Apple devices are always visible.
|
||||
|
||||
## macOS
|
||||
|
||||
Requires the standalone version (macOS system extension).
|
||||
The App Store version can share the VPN as a hotspot but does not support MAC address or hostname reading.
|
||||
|
||||
See [VPN Hotspot](/manual/misc/vpn-hotspot/#macos) for Internet Sharing setup.
|
||||
49
docs/configuration/shared/neighbor.zh.md
Normal file
49
docs/configuration/shared/neighbor.zh.md
Normal file
@@ -0,0 +1,49 @@
|
||||
---
|
||||
icon: material/lan
|
||||
---
|
||||
|
||||
# 邻居解析
|
||||
|
||||
通过
|
||||
[`source_mac_address`](/configuration/route/rule/#source_mac_address) 和
|
||||
[`source_hostname`](/configuration/route/rule/#source_hostname) 规则项匹配局域网设备的 MAC 地址和主机名。
|
||||
|
||||
当这些规则项存在时,邻居解析自动启用。
|
||||
使用 [`route.find_neighbor`](/configuration/route/#find_neighbor) 可在没有规则时强制启用以输出日志。
|
||||
|
||||
## Linux
|
||||
|
||||
原生支持,无需特殊设置。
|
||||
|
||||
主机名解析需要 DHCP 租约文件,
|
||||
自动从常见 DHCP 服务器(dnsmasq、odhcpd、ISC dhcpd、Kea)检测。
|
||||
可通过 [`route.dhcp_lease_files`](/configuration/route/#dhcp_lease_files) 设置自定义路径。
|
||||
|
||||
## Android
|
||||
|
||||
!!! quote ""
|
||||
|
||||
仅在图形客户端中支持。
|
||||
|
||||
需要 Android 11 或以上版本和 ROOT。
|
||||
|
||||
必须使用 [VPNHotspot](https://github.com/Mygod/VPNHotspot) 共享 VPN 连接。
|
||||
ROM 自带的「通过 VPN 共享连接」等功能可以共享 VPN,
|
||||
但无法提供 MAC 地址或主机名信息。
|
||||
|
||||
在 VPNHotspot 设置中将 **IP 遮掩模式** 设为 **无**。
|
||||
|
||||
仅支持路由/DNS 规则。不支持 TUN 的 include/exclude 路由。
|
||||
|
||||
### 设备可见性
|
||||
|
||||
MAC 地址和主机名仅在 VPNHotspot 中可见时 sing-box 才能读取。
|
||||
对于 Apple 设备,需要在所连接网络的 Wi-Fi 设置中将**私有无线局域网地址**从**轮替**改为**固定**。
|
||||
非 Apple 设备始终可见。
|
||||
|
||||
## macOS
|
||||
|
||||
需要独立版本(macOS 系统扩展)。
|
||||
App Store 版本可以共享 VPN 热点但不支持 MAC 地址或主机名读取。
|
||||
|
||||
参阅 [VPN 热点](/manual/misc/vpn-hotspot/#macos) 了解互联网共享设置。
|
||||
@@ -144,6 +144,18 @@ func (s *platformInterfaceStub) SendNotification(notification *adapter.Notificat
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *platformInterfaceStub) UsePlatformNeighborResolver() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *platformInterfaceStub) StartNeighborMonitor(listener adapter.NeighborUpdateListener) error {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
func (s *platformInterfaceStub) CloseNeighborMonitor(listener adapter.NeighborUpdateListener) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *platformInterfaceStub) UsePlatformLocalDNSTransport() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
53
experimental/libbox/neighbor.go
Normal file
53
experimental/libbox/neighbor.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package libbox
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type NeighborEntry struct {
|
||||
Address string
|
||||
MacAddress string
|
||||
Hostname string
|
||||
}
|
||||
|
||||
type NeighborEntryIterator interface {
|
||||
Next() *NeighborEntry
|
||||
HasNext() bool
|
||||
}
|
||||
|
||||
type NeighborSubscription struct {
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func (s *NeighborSubscription) Close() {
|
||||
close(s.done)
|
||||
}
|
||||
|
||||
func tableToIterator(table map[netip.Addr]net.HardwareAddr) NeighborEntryIterator {
|
||||
entries := make([]*NeighborEntry, 0, len(table))
|
||||
for address, mac := range table {
|
||||
entries = append(entries, &NeighborEntry{
|
||||
Address: address.String(),
|
||||
MacAddress: mac.String(),
|
||||
})
|
||||
}
|
||||
return &neighborEntryIterator{entries}
|
||||
}
|
||||
|
||||
type neighborEntryIterator struct {
|
||||
entries []*NeighborEntry
|
||||
}
|
||||
|
||||
func (i *neighborEntryIterator) HasNext() bool {
|
||||
return len(i.entries) > 0
|
||||
}
|
||||
|
||||
func (i *neighborEntryIterator) Next() *NeighborEntry {
|
||||
if len(i.entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
entry := i.entries[0]
|
||||
i.entries = i.entries[1:]
|
||||
return entry
|
||||
}
|
||||
123
experimental/libbox/neighbor_darwin.go
Normal file
123
experimental/libbox/neighbor_darwin.go
Normal file
@@ -0,0 +1,123 @@
|
||||
//go:build darwin
|
||||
|
||||
package libbox
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/route"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
|
||||
xroute "golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func SubscribeNeighborTable(listener NeighborUpdateListener) (*NeighborSubscription, error) {
|
||||
entries, err := route.ReadNeighborEntries()
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "initial neighbor dump")
|
||||
}
|
||||
table := make(map[netip.Addr]net.HardwareAddr)
|
||||
for _, entry := range entries {
|
||||
table[entry.Address] = entry.MACAddress
|
||||
}
|
||||
listener.UpdateNeighborTable(tableToIterator(table))
|
||||
routeSocket, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, 0)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "open route socket")
|
||||
}
|
||||
err = unix.SetNonblock(routeSocket, true)
|
||||
if err != nil {
|
||||
unix.Close(routeSocket)
|
||||
return nil, E.Cause(err, "set route socket nonblock")
|
||||
}
|
||||
subscription := &NeighborSubscription{
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go subscription.loop(listener, routeSocket, table)
|
||||
return subscription, nil
|
||||
}
|
||||
|
||||
func (s *NeighborSubscription) loop(listener NeighborUpdateListener, routeSocket int, table map[netip.Addr]net.HardwareAddr) {
|
||||
routeSocketFile := os.NewFile(uintptr(routeSocket), "route")
|
||||
defer routeSocketFile.Close()
|
||||
buffer := buf.NewPacket()
|
||||
defer buffer.Release()
|
||||
for {
|
||||
select {
|
||||
case <-s.done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
tv := unix.NsecToTimeval(int64(3 * time.Second))
|
||||
_ = unix.SetsockoptTimeval(routeSocket, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv)
|
||||
n, err := routeSocketFile.Read(buffer.FreeBytes())
|
||||
if err != nil {
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case <-s.done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
continue
|
||||
}
|
||||
messages, err := xroute.ParseRIB(xroute.RIBTypeRoute, buffer.FreeBytes()[:n])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
changed := false
|
||||
for _, message := range messages {
|
||||
routeMessage, isRouteMessage := message.(*xroute.RouteMessage)
|
||||
if !isRouteMessage {
|
||||
continue
|
||||
}
|
||||
if routeMessage.Flags&unix.RTF_LLINFO == 0 {
|
||||
continue
|
||||
}
|
||||
address, mac, isDelete, ok := route.ParseRouteNeighborMessage(routeMessage)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if isDelete {
|
||||
if _, exists := table[address]; exists {
|
||||
delete(table, address)
|
||||
changed = true
|
||||
}
|
||||
} else {
|
||||
existing, exists := table[address]
|
||||
if !exists || !slices.Equal(existing, mac) {
|
||||
table[address] = mac
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if changed {
|
||||
listener.UpdateNeighborTable(tableToIterator(table))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ReadBootpdLeases() NeighborEntryIterator {
|
||||
leaseIPToMAC, ipToHostname, macToHostname := route.ReloadLeaseFiles([]string{"/var/db/dhcpd_leases"})
|
||||
entries := make([]*NeighborEntry, 0, len(leaseIPToMAC))
|
||||
for address, mac := range leaseIPToMAC {
|
||||
entry := &NeighborEntry{
|
||||
Address: address.String(),
|
||||
MacAddress: mac.String(),
|
||||
}
|
||||
hostname, found := ipToHostname[address]
|
||||
if !found {
|
||||
hostname = macToHostname[mac.String()]
|
||||
}
|
||||
entry.Hostname = hostname
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
return &neighborEntryIterator{entries}
|
||||
}
|
||||
88
experimental/libbox/neighbor_linux.go
Normal file
88
experimental/libbox/neighbor_linux.go
Normal file
@@ -0,0 +1,88 @@
|
||||
//go:build linux
|
||||
|
||||
package libbox
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/route"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
|
||||
"github.com/mdlayher/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func SubscribeNeighborTable(listener NeighborUpdateListener) (*NeighborSubscription, error) {
|
||||
entries, err := route.ReadNeighborEntries()
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "initial neighbor dump")
|
||||
}
|
||||
table := make(map[netip.Addr]net.HardwareAddr)
|
||||
for _, entry := range entries {
|
||||
table[entry.Address] = entry.MACAddress
|
||||
}
|
||||
listener.UpdateNeighborTable(tableToIterator(table))
|
||||
connection, err := netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{
|
||||
Groups: 1 << (unix.RTNLGRP_NEIGH - 1),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "subscribe neighbor updates")
|
||||
}
|
||||
subscription := &NeighborSubscription{
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go subscription.loop(listener, connection, table)
|
||||
return subscription, nil
|
||||
}
|
||||
|
||||
func (s *NeighborSubscription) loop(listener NeighborUpdateListener, connection *netlink.Conn, table map[netip.Addr]net.HardwareAddr) {
|
||||
defer connection.Close()
|
||||
for {
|
||||
select {
|
||||
case <-s.done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
err := connection.SetReadDeadline(time.Now().Add(3 * time.Second))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
messages, err := connection.Receive()
|
||||
if err != nil {
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case <-s.done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
continue
|
||||
}
|
||||
changed := false
|
||||
for _, message := range messages {
|
||||
address, mac, isDelete, ok := route.ParseNeighborMessage(message)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if isDelete {
|
||||
if _, exists := table[address]; exists {
|
||||
delete(table, address)
|
||||
changed = true
|
||||
}
|
||||
} else {
|
||||
existing, exists := table[address]
|
||||
if !exists || !slices.Equal(existing, mac) {
|
||||
table[address] = mac
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if changed {
|
||||
listener.UpdateNeighborTable(tableToIterator(table))
|
||||
}
|
||||
}
|
||||
}
|
||||
9
experimental/libbox/neighbor_stub.go
Normal file
9
experimental/libbox/neighbor_stub.go
Normal file
@@ -0,0 +1,9 @@
|
||||
//go:build !linux && !darwin
|
||||
|
||||
package libbox
|
||||
|
||||
import "os"
|
||||
|
||||
func SubscribeNeighborTable(_ NeighborUpdateListener) (*NeighborSubscription, error) {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
@@ -21,6 +21,13 @@ type PlatformInterface interface {
|
||||
SystemCertificates() StringIterator
|
||||
ClearDNSCache()
|
||||
SendNotification(notification *Notification) error
|
||||
StartNeighborMonitor(listener NeighborUpdateListener) error
|
||||
CloseNeighborMonitor(listener NeighborUpdateListener) error
|
||||
RegisterMyInterface(name string)
|
||||
}
|
||||
|
||||
type NeighborUpdateListener interface {
|
||||
UpdateNeighborTable(entries NeighborEntryIterator)
|
||||
}
|
||||
|
||||
type ConnectionOwner struct {
|
||||
|
||||
@@ -78,6 +78,7 @@ func (w *platformInterfaceWrapper) OpenInterface(options *tun.Options, platformO
|
||||
}
|
||||
options.FileDescriptor = dupFd
|
||||
w.myTunName = options.Name
|
||||
w.iif.RegisterMyInterface(options.Name)
|
||||
return tun.New(*options)
|
||||
}
|
||||
|
||||
@@ -220,6 +221,46 @@ func (w *platformInterfaceWrapper) SendNotification(notification *adapter.Notifi
|
||||
return w.iif.SendNotification((*Notification)(notification))
|
||||
}
|
||||
|
||||
func (w *platformInterfaceWrapper) UsePlatformNeighborResolver() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (w *platformInterfaceWrapper) StartNeighborMonitor(listener adapter.NeighborUpdateListener) error {
|
||||
return w.iif.StartNeighborMonitor(&neighborUpdateListenerWrapper{listener: listener})
|
||||
}
|
||||
|
||||
func (w *platformInterfaceWrapper) CloseNeighborMonitor(listener adapter.NeighborUpdateListener) error {
|
||||
return w.iif.CloseNeighborMonitor(nil)
|
||||
}
|
||||
|
||||
type neighborUpdateListenerWrapper struct {
|
||||
listener adapter.NeighborUpdateListener
|
||||
}
|
||||
|
||||
func (w *neighborUpdateListenerWrapper) UpdateNeighborTable(entries NeighborEntryIterator) {
|
||||
var result []adapter.NeighborEntry
|
||||
for entries.HasNext() {
|
||||
entry := entries.Next()
|
||||
if entry == nil {
|
||||
continue
|
||||
}
|
||||
address, err := netip.ParseAddr(entry.Address)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
macAddress, err := net.ParseMAC(entry.MacAddress)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
result = append(result, adapter.NeighborEntry{
|
||||
Address: address,
|
||||
MACAddress: macAddress,
|
||||
Hostname: entry.Hostname,
|
||||
})
|
||||
}
|
||||
w.listener.UpdateNeighborTable(result)
|
||||
}
|
||||
|
||||
func AvailablePort(startPort int32) (int32, error) {
|
||||
for port := int(startPort); ; port++ {
|
||||
if port > 65535 {
|
||||
|
||||
12
go.mod
12
go.mod
@@ -14,11 +14,13 @@ require (
|
||||
github.com/godbus/dbus/v5 v5.2.2
|
||||
github.com/gofrs/uuid/v5 v5.4.0
|
||||
github.com/insomniacslk/dhcp v0.0.0-20260220084031-5adc3eb26f91
|
||||
github.com/jsimonetti/rtnetlink v1.4.0
|
||||
github.com/keybase/go-keychain v0.0.1
|
||||
github.com/libdns/acmedns v0.5.0
|
||||
github.com/libdns/alidns v1.0.6
|
||||
github.com/libdns/cloudflare v0.2.2
|
||||
github.com/logrusorgru/aurora v2.0.3+incompatible
|
||||
github.com/mdlayher/netlink v1.9.0
|
||||
github.com/metacubex/utls v1.8.4
|
||||
github.com/mholt/acmez/v3 v3.1.6
|
||||
github.com/miekg/dns v1.1.72
|
||||
@@ -27,19 +29,19 @@ require (
|
||||
github.com/sagernet/asc-go v0.0.0-20241217030726-d563060fe4e1
|
||||
github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a
|
||||
github.com/sagernet/cors v1.2.1
|
||||
github.com/sagernet/cronet-go v0.0.0-20260309102448-2fef65f9dba9
|
||||
github.com/sagernet/cronet-go/all v0.0.0-20260309102448-2fef65f9dba9
|
||||
github.com/sagernet/cronet-go v0.0.0-20260309100020-c128886ff3fc
|
||||
github.com/sagernet/cronet-go/all v0.0.0-20260309100020-c128886ff3fc
|
||||
github.com/sagernet/fswatch v0.1.1
|
||||
github.com/sagernet/gomobile v0.1.12
|
||||
github.com/sagernet/gvisor v0.0.0-20250811.0-sing-box-mod.1
|
||||
github.com/sagernet/quic-go v0.59.0-sing-box-mod.4
|
||||
github.com/sagernet/sing v0.8.2
|
||||
github.com/sagernet/sing v0.8.3-0.20260311155444-d39eb42a9f69
|
||||
github.com/sagernet/sing-mux v0.3.4
|
||||
github.com/sagernet/sing-quic v0.6.0
|
||||
github.com/sagernet/sing-shadowsocks v0.2.8
|
||||
github.com/sagernet/sing-shadowsocks2 v0.2.1
|
||||
github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11
|
||||
github.com/sagernet/sing-tun v0.8.3
|
||||
github.com/sagernet/sing-tun v0.8.4-0.20260315091454-bbe21100c226
|
||||
github.com/sagernet/sing-vmess v0.2.8-0.20250909125414-3aed155119a1
|
||||
github.com/sagernet/smux v1.5.50-sing-box-mod.1
|
||||
github.com/sagernet/tailscale v1.92.4-sing-box-1.13-mod.6.0.20260311131347-f88b27eeb76e
|
||||
@@ -92,11 +94,9 @@ require (
|
||||
github.com/hashicorp/yamux v0.1.2 // indirect
|
||||
github.com/hdevalence/ed25519consensus v0.2.0 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/jsimonetti/rtnetlink v1.4.0 // indirect
|
||||
github.com/klauspost/compress v1.18.0 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/libdns/libdns v1.1.1 // indirect
|
||||
github.com/mdlayher/netlink v1.9.0 // indirect
|
||||
github.com/mdlayher/socket v0.5.1 // indirect
|
||||
github.com/mitchellh/go-ps v1.0.0 // indirect
|
||||
github.com/pierrec/lz4/v4 v4.1.21 // indirect
|
||||
|
||||
16
go.sum
16
go.sum
@@ -162,10 +162,10 @@ github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a h1:+NkI2670SQpQWvkk
|
||||
github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a/go.mod h1:63s7jpZqcDAIpj8oI/1v4Izok+npJOHACFCU6+huCkM=
|
||||
github.com/sagernet/cors v1.2.1 h1:Cv5Z8y9YSD6Gm+qSpNrL3LO4lD3eQVvbFYJSG7JCMHQ=
|
||||
github.com/sagernet/cors v1.2.1/go.mod h1:O64VyOjjhrkLmQIjF4KGRrJO/5dVXFdpEmCW/eISRAI=
|
||||
github.com/sagernet/cronet-go v0.0.0-20260309102448-2fef65f9dba9 h1:xq5Yr10jXEppD3cnGjE3WENaB6D0YsZu6KptZ8d3054=
|
||||
github.com/sagernet/cronet-go v0.0.0-20260309102448-2fef65f9dba9/go.mod h1:hwFHBEjjthyEquDULbr4c4ucMedp8Drb6Jvm2kt/0Bw=
|
||||
github.com/sagernet/cronet-go/all v0.0.0-20260309102448-2fef65f9dba9 h1:uxQyy6Y/boOuecVA66tf79JgtoRGfeDJcfYZZLKVA5E=
|
||||
github.com/sagernet/cronet-go/all v0.0.0-20260309102448-2fef65f9dba9/go.mod h1:Xm6cCvs0/twozC1JYNq0sVlOVmcSGzV7YON1XGcD97w=
|
||||
github.com/sagernet/cronet-go v0.0.0-20260309100020-c128886ff3fc h1:YK7PwJT0irRAEui9ASdXSxcE2BOVQipWMF/A1Ogt+7c=
|
||||
github.com/sagernet/cronet-go v0.0.0-20260309100020-c128886ff3fc/go.mod h1:hwFHBEjjthyEquDULbr4c4ucMedp8Drb6Jvm2kt/0Bw=
|
||||
github.com/sagernet/cronet-go/all v0.0.0-20260309100020-c128886ff3fc h1:EJPHOqk23IuBsTjXK9OXqkNxPbKOBWKRmviQoCcriAs=
|
||||
github.com/sagernet/cronet-go/all v0.0.0-20260309100020-c128886ff3fc/go.mod h1:8aty0RW96DrJSMWXO6bRPMBJEjuqq5JWiOIi4bCRzFA=
|
||||
github.com/sagernet/cronet-go/lib/android_386 v0.0.0-20260309101654-0cbdcfddded9 h1:Qi0IKBpoPP3qZqIXuOKMsT2dv+l/MLWMyBHDMLRw2EA=
|
||||
github.com/sagernet/cronet-go/lib/android_386 v0.0.0-20260309101654-0cbdcfddded9/go.mod h1:XXDwdjX/T8xftoeJxQmbBoYXZp8MAPFR2CwbFuTpEtw=
|
||||
github.com/sagernet/cronet-go/lib/android_amd64 v0.0.0-20260309101654-0cbdcfddded9 h1:p+wCMjOhj46SpSD/AJeTGgkCcbyA76FyH631XZatyU8=
|
||||
@@ -236,8 +236,8 @@ github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNen
|
||||
github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8=
|
||||
github.com/sagernet/quic-go v0.59.0-sing-box-mod.4 h1:6qvrUW79S+CrPwWz6cMePXohgjHoKxLo3c+MDhNwc3o=
|
||||
github.com/sagernet/quic-go v0.59.0-sing-box-mod.4/go.mod h1:OqILvS182CyOol5zNNo6bguvOGgXzV459+chpRaUC+4=
|
||||
github.com/sagernet/sing v0.8.2 h1:kX1IH9SWJv4S0T9M8O+HNahWgbOuY1VauxbF7NU5lOg=
|
||||
github.com/sagernet/sing v0.8.2/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
|
||||
github.com/sagernet/sing v0.8.3-0.20260311155444-d39eb42a9f69 h1:h6UF2emeydBQMAso99Nr3APV6YustOs+JszVuCkcFy0=
|
||||
github.com/sagernet/sing v0.8.3-0.20260311155444-d39eb42a9f69/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
|
||||
github.com/sagernet/sing-mux v0.3.4 h1:ZQplKl8MNXutjzbMVtWvWG31fohhgOfCuUZR4dVQ8+s=
|
||||
github.com/sagernet/sing-mux v0.3.4/go.mod h1:QvlKMyNBNrQoyX4x+gq028uPbLM2XeRpWtDsWBJbFSk=
|
||||
github.com/sagernet/sing-quic v0.6.0 h1:dhrFnP45wgVKEOT1EvtsToxdzRnHIDIAgj6WHV9pLyM=
|
||||
@@ -248,8 +248,8 @@ github.com/sagernet/sing-shadowsocks2 v0.2.1 h1:dWV9OXCeFPuYGHb6IRqlSptVnSzOelnq
|
||||
github.com/sagernet/sing-shadowsocks2 v0.2.1/go.mod h1:RnXS0lExcDAovvDeniJ4IKa2IuChrdipolPYWBv9hWQ=
|
||||
github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11 h1:tK+75l64tm9WvEFrYRE1t0YxoFdWQqw/h7Uhzj0vJ+w=
|
||||
github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11/go.mod h1:sWqKnGlMipCHaGsw1sTTlimyUpgzP4WP3pjhCsYt9oA=
|
||||
github.com/sagernet/sing-tun v0.8.3 h1:mozxmuIoRhFdVHnheenLpBaammVj7bZPcnkApaYKDPY=
|
||||
github.com/sagernet/sing-tun v0.8.3/go.mod h1:pLCo4o+LacXEzz0bhwhJkKBjLlKOGPBNOAZ97ZVZWzs=
|
||||
github.com/sagernet/sing-tun v0.8.4-0.20260315091454-bbe21100c226 h1:Shy/fsm+pqVq6OkBAWPaOmOiPT/AwoRxQLiV1357Y0Y=
|
||||
github.com/sagernet/sing-tun v0.8.4-0.20260315091454-bbe21100c226/go.mod h1:pLCo4o+LacXEzz0bhwhJkKBjLlKOGPBNOAZ97ZVZWzs=
|
||||
github.com/sagernet/sing-vmess v0.2.8-0.20250909125414-3aed155119a1 h1:aSwUNYUkVyVvdmBSufR8/nRFonwJeKSIROxHcm5br9o=
|
||||
github.com/sagernet/sing-vmess v0.2.8-0.20250909125414-3aed155119a1/go.mod h1:P11scgTxMxVVQ8dlM27yNm3Cro40mD0+gHbnqrNGDuY=
|
||||
github.com/sagernet/smux v1.5.50-sing-box-mod.1 h1:XkJcivBC9V4wBjiGXIXZ229aZCU1hzcbp6kSkkyQ478=
|
||||
|
||||
@@ -168,7 +168,11 @@ func FormatDuration(duration time.Duration) string {
|
||||
return F.ToString(duration.Milliseconds(), "ms")
|
||||
} else if duration < time.Minute {
|
||||
return F.ToString(int64(duration.Seconds()), ".", int64(duration.Seconds()*100)%100, "s")
|
||||
} else {
|
||||
} else if duration < time.Hour {
|
||||
return F.ToString(int64(duration.Minutes()), "m", int64(duration.Seconds())%60, "s")
|
||||
} else if duration < 24*time.Hour {
|
||||
return F.ToString(int64(duration.Hours()), "h", int64(duration.Minutes())%60, "m")
|
||||
} else {
|
||||
return F.ToString(int64(duration.Hours())/24, "d", int64(duration.Hours())%24, "h")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -129,6 +129,7 @@ nav:
|
||||
- UDP over TCP: configuration/shared/udp-over-tcp.md
|
||||
- TCP Brutal: configuration/shared/tcp-brutal.md
|
||||
- Wi-Fi State: configuration/shared/wifi-state.md
|
||||
- Neighbor Resolution: configuration/shared/neighbor.md
|
||||
- Endpoint:
|
||||
- configuration/endpoint/index.md
|
||||
- WireGuard: configuration/endpoint/wireguard.md
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package option
|
||||
|
||||
import (
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
"github.com/sagernet/sing/common/json/badjson"
|
||||
"github.com/sagernet/sing/common/json/badoption"
|
||||
)
|
||||
|
||||
@@ -8,6 +11,7 @@ type CCMServiceOptions struct {
|
||||
ListenOptions
|
||||
InboundTLSOptionsContainer
|
||||
CredentialPath string `json:"credential_path,omitempty"`
|
||||
Credentials []CCMCredential `json:"credentials,omitempty"`
|
||||
Users []CCMUser `json:"users,omitempty"`
|
||||
Headers badoption.HTTPHeader `json:"headers,omitempty"`
|
||||
Detour string `json:"detour,omitempty"`
|
||||
@@ -15,6 +19,84 @@ type CCMServiceOptions struct {
|
||||
}
|
||||
|
||||
type CCMUser struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
Credential string `json:"credential,omitempty"`
|
||||
ExternalCredential string `json:"external_credential,omitempty"`
|
||||
AllowExternalUsage bool `json:"allow_external_usage,omitempty"`
|
||||
}
|
||||
|
||||
type _CCMCredential struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Tag string `json:"tag"`
|
||||
DefaultOptions CCMDefaultCredentialOptions `json:"-"`
|
||||
ExternalOptions CCMExternalCredentialOptions `json:"-"`
|
||||
BalancerOptions CCMBalancerCredentialOptions `json:"-"`
|
||||
}
|
||||
|
||||
type CCMCredential _CCMCredential
|
||||
|
||||
func (c CCMCredential) MarshalJSON() ([]byte, error) {
|
||||
var v any
|
||||
switch c.Type {
|
||||
case "", "default":
|
||||
c.Type = ""
|
||||
v = c.DefaultOptions
|
||||
case "external":
|
||||
v = c.ExternalOptions
|
||||
case "balancer":
|
||||
v = c.BalancerOptions
|
||||
default:
|
||||
return nil, E.New("unknown credential type: ", c.Type)
|
||||
}
|
||||
return badjson.MarshallObjects((_CCMCredential)(c), v)
|
||||
}
|
||||
|
||||
func (c *CCMCredential) UnmarshalJSON(bytes []byte) error {
|
||||
err := json.Unmarshal(bytes, (*_CCMCredential)(c))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c.Tag == "" {
|
||||
return E.New("missing credential tag")
|
||||
}
|
||||
var v any
|
||||
switch c.Type {
|
||||
case "", "default":
|
||||
c.Type = "default"
|
||||
v = &c.DefaultOptions
|
||||
case "external":
|
||||
v = &c.ExternalOptions
|
||||
case "balancer":
|
||||
v = &c.BalancerOptions
|
||||
default:
|
||||
return E.New("unknown credential type: ", c.Type)
|
||||
}
|
||||
return badjson.UnmarshallExcluded(bytes, (*_CCMCredential)(c), v)
|
||||
}
|
||||
|
||||
type CCMDefaultCredentialOptions struct {
|
||||
CredentialPath string `json:"credential_path,omitempty"`
|
||||
ClaudeDirectory string `json:"claude_directory,omitempty"`
|
||||
UsagesPath string `json:"usages_path,omitempty"`
|
||||
Detour string `json:"detour,omitempty"`
|
||||
Reserve5h uint8 `json:"reserve_5h"`
|
||||
ReserveWeekly uint8 `json:"reserve_weekly"`
|
||||
Limit5h uint8 `json:"limit_5h,omitempty"`
|
||||
LimitWeekly uint8 `json:"limit_weekly,omitempty"`
|
||||
}
|
||||
|
||||
type CCMBalancerCredentialOptions struct {
|
||||
Strategy string `json:"strategy,omitempty"`
|
||||
Credentials badoption.Listable[string] `json:"credentials"`
|
||||
RebalanceThreshold float64 `json:"rebalance_threshold,omitempty"`
|
||||
}
|
||||
|
||||
type CCMExternalCredentialOptions struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
ServerOptions
|
||||
Token string `json:"token"`
|
||||
Reverse bool `json:"reverse,omitempty"`
|
||||
Detour string `json:"detour,omitempty"`
|
||||
UsagesPath string `json:"usages_path,omitempty"`
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package option
|
||||
|
||||
import (
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
"github.com/sagernet/sing/common/json/badjson"
|
||||
"github.com/sagernet/sing/common/json/badoption"
|
||||
)
|
||||
|
||||
@@ -8,6 +11,7 @@ type OCMServiceOptions struct {
|
||||
ListenOptions
|
||||
InboundTLSOptionsContainer
|
||||
CredentialPath string `json:"credential_path,omitempty"`
|
||||
Credentials []OCMCredential `json:"credentials,omitempty"`
|
||||
Users []OCMUser `json:"users,omitempty"`
|
||||
Headers badoption.HTTPHeader `json:"headers,omitempty"`
|
||||
Detour string `json:"detour,omitempty"`
|
||||
@@ -15,6 +19,83 @@ type OCMServiceOptions struct {
|
||||
}
|
||||
|
||||
type OCMUser struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
Credential string `json:"credential,omitempty"`
|
||||
ExternalCredential string `json:"external_credential,omitempty"`
|
||||
AllowExternalUsage bool `json:"allow_external_usage,omitempty"`
|
||||
}
|
||||
|
||||
type _OCMCredential struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Tag string `json:"tag"`
|
||||
DefaultOptions OCMDefaultCredentialOptions `json:"-"`
|
||||
ExternalOptions OCMExternalCredentialOptions `json:"-"`
|
||||
BalancerOptions OCMBalancerCredentialOptions `json:"-"`
|
||||
}
|
||||
|
||||
type OCMCredential _OCMCredential
|
||||
|
||||
func (c OCMCredential) MarshalJSON() ([]byte, error) {
|
||||
var v any
|
||||
switch c.Type {
|
||||
case "", "default":
|
||||
c.Type = ""
|
||||
v = c.DefaultOptions
|
||||
case "external":
|
||||
v = c.ExternalOptions
|
||||
case "balancer":
|
||||
v = c.BalancerOptions
|
||||
default:
|
||||
return nil, E.New("unknown credential type: ", c.Type)
|
||||
}
|
||||
return badjson.MarshallObjects((_OCMCredential)(c), v)
|
||||
}
|
||||
|
||||
func (c *OCMCredential) UnmarshalJSON(bytes []byte) error {
|
||||
err := json.Unmarshal(bytes, (*_OCMCredential)(c))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c.Tag == "" {
|
||||
return E.New("missing credential tag")
|
||||
}
|
||||
var v any
|
||||
switch c.Type {
|
||||
case "", "default":
|
||||
c.Type = "default"
|
||||
v = &c.DefaultOptions
|
||||
case "external":
|
||||
v = &c.ExternalOptions
|
||||
case "balancer":
|
||||
v = &c.BalancerOptions
|
||||
default:
|
||||
return E.New("unknown credential type: ", c.Type)
|
||||
}
|
||||
return badjson.UnmarshallExcluded(bytes, (*_OCMCredential)(c), v)
|
||||
}
|
||||
|
||||
type OCMDefaultCredentialOptions struct {
|
||||
CredentialPath string `json:"credential_path,omitempty"`
|
||||
UsagesPath string `json:"usages_path,omitempty"`
|
||||
Detour string `json:"detour,omitempty"`
|
||||
Reserve5h uint8 `json:"reserve_5h"`
|
||||
ReserveWeekly uint8 `json:"reserve_weekly"`
|
||||
Limit5h uint8 `json:"limit_5h,omitempty"`
|
||||
LimitWeekly uint8 `json:"limit_weekly,omitempty"`
|
||||
}
|
||||
|
||||
type OCMBalancerCredentialOptions struct {
|
||||
Strategy string `json:"strategy,omitempty"`
|
||||
Credentials badoption.Listable[string] `json:"credentials"`
|
||||
RebalanceThreshold float64 `json:"rebalance_threshold,omitempty"`
|
||||
}
|
||||
|
||||
type OCMExternalCredentialOptions struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
ServerOptions
|
||||
Token string `json:"token"`
|
||||
Reverse bool `json:"reverse,omitempty"`
|
||||
Detour string `json:"detour,omitempty"`
|
||||
UsagesPath string `json:"usages_path,omitempty"`
|
||||
}
|
||||
|
||||
@@ -9,6 +9,8 @@ type RouteOptions struct {
|
||||
RuleSet []RuleSet `json:"rule_set,omitempty"`
|
||||
Final string `json:"final,omitempty"`
|
||||
FindProcess bool `json:"find_process,omitempty"`
|
||||
FindNeighbor bool `json:"find_neighbor,omitempty"`
|
||||
DHCPLeaseFiles badoption.Listable[string] `json:"dhcp_lease_files,omitempty"`
|
||||
AutoDetectInterface bool `json:"auto_detect_interface,omitempty"`
|
||||
OverrideAndroidVPN bool `json:"override_android_vpn,omitempty"`
|
||||
DefaultInterface string `json:"default_interface,omitempty"`
|
||||
|
||||
@@ -103,6 +103,8 @@ type RawDefaultRule struct {
|
||||
InterfaceAddress *badjson.TypedMap[string, badoption.Listable[*badoption.Prefixable]] `json:"interface_address,omitempty"`
|
||||
NetworkInterfaceAddress *badjson.TypedMap[InterfaceType, badoption.Listable[*badoption.Prefixable]] `json:"network_interface_address,omitempty"`
|
||||
DefaultInterfaceAddress badoption.Listable[*badoption.Prefixable] `json:"default_interface_address,omitempty"`
|
||||
SourceMACAddress badoption.Listable[string] `json:"source_mac_address,omitempty"`
|
||||
SourceHostname badoption.Listable[string] `json:"source_hostname,omitempty"`
|
||||
PreferredBy badoption.Listable[string] `json:"preferred_by,omitempty"`
|
||||
RuleSet badoption.Listable[string] `json:"rule_set,omitempty"`
|
||||
RuleSetIPCIDRMatchSource bool `json:"rule_set_ip_cidr_match_source,omitempty"`
|
||||
|
||||
@@ -106,6 +106,8 @@ type RawDefaultDNSRule struct {
|
||||
InterfaceAddress *badjson.TypedMap[string, badoption.Listable[*badoption.Prefixable]] `json:"interface_address,omitempty"`
|
||||
NetworkInterfaceAddress *badjson.TypedMap[InterfaceType, badoption.Listable[*badoption.Prefixable]] `json:"network_interface_address,omitempty"`
|
||||
DefaultInterfaceAddress badoption.Listable[*badoption.Prefixable] `json:"default_interface_address,omitempty"`
|
||||
SourceMACAddress badoption.Listable[string] `json:"source_mac_address,omitempty"`
|
||||
SourceHostname badoption.Listable[string] `json:"source_hostname,omitempty"`
|
||||
RuleSet badoption.Listable[string] `json:"rule_set,omitempty"`
|
||||
RuleSetIPCIDRMatchSource bool `json:"rule_set_ip_cidr_match_source,omitempty"`
|
||||
RuleSetIPCIDRAcceptEmpty bool `json:"rule_set_ip_cidr_accept_empty,omitempty"`
|
||||
|
||||
@@ -39,6 +39,8 @@ type TunInboundOptions struct {
|
||||
IncludeAndroidUser badoption.Listable[int] `json:"include_android_user,omitempty"`
|
||||
IncludePackage badoption.Listable[string] `json:"include_package,omitempty"`
|
||||
ExcludePackage badoption.Listable[string] `json:"exclude_package,omitempty"`
|
||||
IncludeMACAddress badoption.Listable[string] `json:"include_mac_address,omitempty"`
|
||||
ExcludeMACAddress badoption.Listable[string] `json:"exclude_mac_address,omitempty"`
|
||||
UDPTimeout UDPTimeoutCompat `json:"udp_timeout,omitempty"`
|
||||
Stack string `json:"stack,omitempty"`
|
||||
Platform *TunPlatformOptions `json:"platform,omitempty"`
|
||||
|
||||
@@ -156,6 +156,22 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
||||
if nfQueue == 0 {
|
||||
nfQueue = tun.DefaultAutoRedirectNFQueue
|
||||
}
|
||||
var includeMACAddress []net.HardwareAddr
|
||||
for i, macString := range options.IncludeMACAddress {
|
||||
mac, macErr := net.ParseMAC(macString)
|
||||
if macErr != nil {
|
||||
return nil, E.Cause(macErr, "parse include_mac_address[", i, "]")
|
||||
}
|
||||
includeMACAddress = append(includeMACAddress, mac)
|
||||
}
|
||||
var excludeMACAddress []net.HardwareAddr
|
||||
for i, macString := range options.ExcludeMACAddress {
|
||||
mac, macErr := net.ParseMAC(macString)
|
||||
if macErr != nil {
|
||||
return nil, E.Cause(macErr, "parse exclude_mac_address[", i, "]")
|
||||
}
|
||||
excludeMACAddress = append(excludeMACAddress, mac)
|
||||
}
|
||||
networkManager := service.FromContext[adapter.NetworkManager](ctx)
|
||||
multiPendingPackets := C.IsDarwin && ((options.Stack == "gvisor" && tunMTU < 32768) || (options.Stack != "gvisor" && options.MTU <= 9000))
|
||||
inbound := &Inbound{
|
||||
@@ -193,6 +209,8 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
||||
IncludeAndroidUser: options.IncludeAndroidUser,
|
||||
IncludePackage: options.IncludePackage,
|
||||
ExcludePackage: options.ExcludePackage,
|
||||
IncludeMACAddress: includeMACAddress,
|
||||
ExcludeMACAddress: excludeMACAddress,
|
||||
InterfaceMonitor: networkManager.InterfaceMonitor(),
|
||||
EXP_MultiPendingPackets: multiPendingPackets,
|
||||
},
|
||||
|
||||
@@ -2,8 +2,16 @@
|
||||
|
||||
set -e -o pipefail
|
||||
|
||||
go_version=$(curl -s https://raw.githubusercontent.com/actions/go-versions/main/versions-manifest.json | grep -oE '"version": "[0-9]{1}.[0-9]{1,}(.[0-9]{1,})?"' | head -1 | cut -d':' -f2 | sed 's/ //g; s/"//g')
|
||||
curl -Lo go.tar.gz "https://go.dev/dl/go$go_version.linux-amd64.tar.gz"
|
||||
manifest=$(curl -fS 'https://go.dev/VERSION?m=text')
|
||||
go_version=$(echo "$manifest" | head -1 | sed 's/^go//')
|
||||
os=$(uname -s | tr '[:upper:]' '[:lower:]')
|
||||
arch=$(uname -m)
|
||||
case "$arch" in
|
||||
x86_64) arch="amd64" ;;
|
||||
aarch64|arm64) arch="arm64" ;;
|
||||
esac
|
||||
curl -Lo go.tar.gz "https://go.dev/dl/go$go_version.$os-$arch.tar.gz"
|
||||
sudo rm -rf /usr/local/go
|
||||
sudo tar -C /usr/local -xzf go.tar.gz
|
||||
rm go.tar.gz
|
||||
echo "Installed Go $go_version"
|
||||
|
||||
239
route/neighbor_resolver_darwin.go
Normal file
239
route/neighbor_resolver_darwin.go
Normal file
@@ -0,0 +1,239 @@
|
||||
//go:build darwin
|
||||
|
||||
package route
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/fswatch"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
|
||||
"golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var defaultLeaseFiles = []string{
|
||||
"/var/db/dhcpd_leases",
|
||||
"/tmp/dhcp.leases",
|
||||
}
|
||||
|
||||
type neighborResolver struct {
|
||||
logger logger.ContextLogger
|
||||
leaseFiles []string
|
||||
access sync.RWMutex
|
||||
neighborIPToMAC map[netip.Addr]net.HardwareAddr
|
||||
leaseIPToMAC map[netip.Addr]net.HardwareAddr
|
||||
ipToHostname map[netip.Addr]string
|
||||
macToHostname map[string]string
|
||||
watcher *fswatch.Watcher
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func newNeighborResolver(resolverLogger logger.ContextLogger, leaseFiles []string) (adapter.NeighborResolver, error) {
|
||||
if len(leaseFiles) == 0 {
|
||||
for _, path := range defaultLeaseFiles {
|
||||
info, err := os.Stat(path)
|
||||
if err == nil && info.Size() > 0 {
|
||||
leaseFiles = append(leaseFiles, path)
|
||||
}
|
||||
}
|
||||
}
|
||||
return &neighborResolver{
|
||||
logger: resolverLogger,
|
||||
leaseFiles: leaseFiles,
|
||||
neighborIPToMAC: make(map[netip.Addr]net.HardwareAddr),
|
||||
leaseIPToMAC: make(map[netip.Addr]net.HardwareAddr),
|
||||
ipToHostname: make(map[netip.Addr]string),
|
||||
macToHostname: make(map[string]string),
|
||||
done: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *neighborResolver) Start() error {
|
||||
err := r.loadNeighborTable()
|
||||
if err != nil {
|
||||
r.logger.Warn(E.Cause(err, "load neighbor table"))
|
||||
}
|
||||
r.doReloadLeaseFiles()
|
||||
go r.subscribeNeighborUpdates()
|
||||
if len(r.leaseFiles) > 0 {
|
||||
watcher, err := fswatch.NewWatcher(fswatch.Options{
|
||||
Path: r.leaseFiles,
|
||||
Logger: r.logger,
|
||||
Callback: func(_ string) {
|
||||
r.doReloadLeaseFiles()
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
r.logger.Warn(E.Cause(err, "create lease file watcher"))
|
||||
} else {
|
||||
r.watcher = watcher
|
||||
err = watcher.Start()
|
||||
if err != nil {
|
||||
r.logger.Warn(E.Cause(err, "start lease file watcher"))
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *neighborResolver) Close() error {
|
||||
close(r.done)
|
||||
if r.watcher != nil {
|
||||
return r.watcher.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *neighborResolver) LookupMAC(address netip.Addr) (net.HardwareAddr, bool) {
|
||||
r.access.RLock()
|
||||
defer r.access.RUnlock()
|
||||
mac, found := r.neighborIPToMAC[address]
|
||||
if found {
|
||||
return mac, true
|
||||
}
|
||||
mac, found = r.leaseIPToMAC[address]
|
||||
if found {
|
||||
return mac, true
|
||||
}
|
||||
mac, found = extractMACFromEUI64(address)
|
||||
if found {
|
||||
return mac, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (r *neighborResolver) LookupHostname(address netip.Addr) (string, bool) {
|
||||
r.access.RLock()
|
||||
defer r.access.RUnlock()
|
||||
hostname, found := r.ipToHostname[address]
|
||||
if found {
|
||||
return hostname, true
|
||||
}
|
||||
mac, macFound := r.neighborIPToMAC[address]
|
||||
if !macFound {
|
||||
mac, macFound = r.leaseIPToMAC[address]
|
||||
}
|
||||
if !macFound {
|
||||
mac, macFound = extractMACFromEUI64(address)
|
||||
}
|
||||
if macFound {
|
||||
hostname, found = r.macToHostname[mac.String()]
|
||||
if found {
|
||||
return hostname, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func (r *neighborResolver) loadNeighborTable() error {
|
||||
entries, err := ReadNeighborEntries()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.access.Lock()
|
||||
defer r.access.Unlock()
|
||||
for _, entry := range entries {
|
||||
r.neighborIPToMAC[entry.Address] = entry.MACAddress
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *neighborResolver) subscribeNeighborUpdates() {
|
||||
routeSocket, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, 0)
|
||||
if err != nil {
|
||||
r.logger.Warn(E.Cause(err, "subscribe neighbor updates"))
|
||||
return
|
||||
}
|
||||
err = unix.SetNonblock(routeSocket, true)
|
||||
if err != nil {
|
||||
unix.Close(routeSocket)
|
||||
r.logger.Warn(E.Cause(err, "set route socket nonblock"))
|
||||
return
|
||||
}
|
||||
routeSocketFile := os.NewFile(uintptr(routeSocket), "route")
|
||||
defer routeSocketFile.Close()
|
||||
buffer := buf.NewPacket()
|
||||
defer buffer.Release()
|
||||
for {
|
||||
select {
|
||||
case <-r.done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
err = setReadDeadline(routeSocketFile, 3*time.Second)
|
||||
if err != nil {
|
||||
r.logger.Warn(E.Cause(err, "set route socket read deadline"))
|
||||
return
|
||||
}
|
||||
n, err := routeSocketFile.Read(buffer.FreeBytes())
|
||||
if err != nil {
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case <-r.done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
r.logger.Warn(E.Cause(err, "receive neighbor update"))
|
||||
continue
|
||||
}
|
||||
messages, err := route.ParseRIB(route.RIBTypeRoute, buffer.FreeBytes()[:n])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, message := range messages {
|
||||
routeMessage, isRouteMessage := message.(*route.RouteMessage)
|
||||
if !isRouteMessage {
|
||||
continue
|
||||
}
|
||||
if routeMessage.Flags&unix.RTF_LLINFO == 0 {
|
||||
continue
|
||||
}
|
||||
address, mac, isDelete, ok := ParseRouteNeighborMessage(routeMessage)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
r.access.Lock()
|
||||
if isDelete {
|
||||
delete(r.neighborIPToMAC, address)
|
||||
} else {
|
||||
r.neighborIPToMAC[address] = mac
|
||||
}
|
||||
r.access.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *neighborResolver) doReloadLeaseFiles() {
|
||||
leaseIPToMAC, ipToHostname, macToHostname := ReloadLeaseFiles(r.leaseFiles)
|
||||
r.access.Lock()
|
||||
r.leaseIPToMAC = leaseIPToMAC
|
||||
r.ipToHostname = ipToHostname
|
||||
r.macToHostname = macToHostname
|
||||
r.access.Unlock()
|
||||
}
|
||||
|
||||
func setReadDeadline(file *os.File, timeout time.Duration) error {
|
||||
rawConn, err := file.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var controlErr error
|
||||
err = rawConn.Control(func(fd uintptr) {
|
||||
tv := unix.NsecToTimeval(int64(timeout))
|
||||
controlErr = unix.SetsockoptTimeval(int(fd), unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return controlErr
|
||||
}
|
||||
386
route/neighbor_resolver_lease.go
Normal file
386
route/neighbor_resolver_lease.go
Normal file
@@ -0,0 +1,386 @@
|
||||
package route
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/hex"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func parseLeaseFile(path string, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
if strings.HasSuffix(path, "dhcpd_leases") {
|
||||
parseBootpdLeases(file, ipToMAC, ipToHostname, macToHostname)
|
||||
return
|
||||
}
|
||||
if strings.HasSuffix(path, "kea-leases4.csv") {
|
||||
parseKeaCSV4(file, ipToMAC, ipToHostname, macToHostname)
|
||||
return
|
||||
}
|
||||
if strings.HasSuffix(path, "kea-leases6.csv") {
|
||||
parseKeaCSV6(file, ipToMAC, ipToHostname, macToHostname)
|
||||
return
|
||||
}
|
||||
if strings.HasSuffix(path, "dhcpd.leases") {
|
||||
parseISCDhcpd(file, ipToMAC, ipToHostname, macToHostname)
|
||||
return
|
||||
}
|
||||
parseDnsmasqOdhcpd(file, ipToMAC, ipToHostname, macToHostname)
|
||||
}
|
||||
|
||||
func ReloadLeaseFiles(leaseFiles []string) (leaseIPToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) {
|
||||
leaseIPToMAC = make(map[netip.Addr]net.HardwareAddr)
|
||||
ipToHostname = make(map[netip.Addr]string)
|
||||
macToHostname = make(map[string]string)
|
||||
for _, path := range leaseFiles {
|
||||
parseLeaseFile(path, leaseIPToMAC, ipToHostname, macToHostname)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func parseDnsmasqOdhcpd(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) {
|
||||
now := time.Now().Unix()
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.HasPrefix(line, "duid ") {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "# ") {
|
||||
parseOdhcpdLine(line[2:], ipToMAC, ipToHostname, macToHostname)
|
||||
continue
|
||||
}
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 4 {
|
||||
continue
|
||||
}
|
||||
expiry, err := strconv.ParseInt(fields[0], 10, 64)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if expiry != 0 && expiry < now {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(fields[1], ":") {
|
||||
mac, macErr := net.ParseMAC(fields[1])
|
||||
if macErr != nil {
|
||||
continue
|
||||
}
|
||||
address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[2]))
|
||||
if !addrOK {
|
||||
continue
|
||||
}
|
||||
address = address.Unmap()
|
||||
ipToMAC[address] = mac
|
||||
hostname := fields[3]
|
||||
if hostname != "*" {
|
||||
ipToHostname[address] = hostname
|
||||
macToHostname[mac.String()] = hostname
|
||||
}
|
||||
} else {
|
||||
var mac net.HardwareAddr
|
||||
if len(fields) >= 5 {
|
||||
duid, duidErr := parseDUID(fields[4])
|
||||
if duidErr == nil {
|
||||
mac, _ = extractMACFromDUID(duid)
|
||||
}
|
||||
}
|
||||
address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[2]))
|
||||
if !addrOK {
|
||||
continue
|
||||
}
|
||||
address = address.Unmap()
|
||||
if mac != nil {
|
||||
ipToMAC[address] = mac
|
||||
}
|
||||
hostname := fields[3]
|
||||
if hostname != "*" {
|
||||
ipToHostname[address] = hostname
|
||||
if mac != nil {
|
||||
macToHostname[mac.String()] = hostname
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseOdhcpdLine(line string, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) {
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 5 {
|
||||
return
|
||||
}
|
||||
validTime, err := strconv.ParseInt(fields[4], 10, 64)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if validTime == 0 {
|
||||
return
|
||||
}
|
||||
if validTime > 0 && validTime < time.Now().Unix() {
|
||||
return
|
||||
}
|
||||
hostname := fields[3]
|
||||
if hostname == "-" || strings.HasPrefix(hostname, `broken\x20`) {
|
||||
hostname = ""
|
||||
}
|
||||
if len(fields) >= 8 && fields[2] == "ipv4" {
|
||||
mac, macErr := net.ParseMAC(fields[1])
|
||||
if macErr != nil {
|
||||
return
|
||||
}
|
||||
addressField := fields[7]
|
||||
slashIndex := strings.IndexByte(addressField, '/')
|
||||
if slashIndex >= 0 {
|
||||
addressField = addressField[:slashIndex]
|
||||
}
|
||||
address, addrOK := netip.AddrFromSlice(net.ParseIP(addressField))
|
||||
if !addrOK {
|
||||
return
|
||||
}
|
||||
address = address.Unmap()
|
||||
ipToMAC[address] = mac
|
||||
if hostname != "" {
|
||||
ipToHostname[address] = hostname
|
||||
macToHostname[mac.String()] = hostname
|
||||
}
|
||||
return
|
||||
}
|
||||
var mac net.HardwareAddr
|
||||
duidHex := fields[1]
|
||||
duidBytes, hexErr := hex.DecodeString(duidHex)
|
||||
if hexErr == nil {
|
||||
mac, _ = extractMACFromDUID(duidBytes)
|
||||
}
|
||||
for i := 7; i < len(fields); i++ {
|
||||
addressField := fields[i]
|
||||
slashIndex := strings.IndexByte(addressField, '/')
|
||||
if slashIndex >= 0 {
|
||||
addressField = addressField[:slashIndex]
|
||||
}
|
||||
address, addrOK := netip.AddrFromSlice(net.ParseIP(addressField))
|
||||
if !addrOK {
|
||||
continue
|
||||
}
|
||||
address = address.Unmap()
|
||||
if mac != nil {
|
||||
ipToMAC[address] = mac
|
||||
}
|
||||
if hostname != "" {
|
||||
ipToHostname[address] = hostname
|
||||
if mac != nil {
|
||||
macToHostname[mac.String()] = hostname
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseISCDhcpd(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) {
|
||||
scanner := bufio.NewScanner(file)
|
||||
var currentIP netip.Addr
|
||||
var currentMAC net.HardwareAddr
|
||||
var currentHostname string
|
||||
var currentActive bool
|
||||
var inLease bool
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if strings.HasPrefix(line, "lease ") && strings.HasSuffix(line, "{") {
|
||||
ipString := strings.TrimSuffix(strings.TrimPrefix(line, "lease "), " {")
|
||||
parsed, addrOK := netip.AddrFromSlice(net.ParseIP(ipString))
|
||||
if addrOK {
|
||||
currentIP = parsed.Unmap()
|
||||
inLease = true
|
||||
currentMAC = nil
|
||||
currentHostname = ""
|
||||
currentActive = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
if line == "}" && inLease {
|
||||
if currentActive && currentMAC != nil {
|
||||
ipToMAC[currentIP] = currentMAC
|
||||
if currentHostname != "" {
|
||||
ipToHostname[currentIP] = currentHostname
|
||||
macToHostname[currentMAC.String()] = currentHostname
|
||||
}
|
||||
} else {
|
||||
delete(ipToMAC, currentIP)
|
||||
delete(ipToHostname, currentIP)
|
||||
}
|
||||
inLease = false
|
||||
continue
|
||||
}
|
||||
if !inLease {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "hardware ethernet ") {
|
||||
macString := strings.TrimSuffix(strings.TrimPrefix(line, "hardware ethernet "), ";")
|
||||
parsed, macErr := net.ParseMAC(macString)
|
||||
if macErr == nil {
|
||||
currentMAC = parsed
|
||||
}
|
||||
} else if strings.HasPrefix(line, "client-hostname ") {
|
||||
hostname := strings.TrimSuffix(strings.TrimPrefix(line, "client-hostname "), ";")
|
||||
hostname = strings.Trim(hostname, "\"")
|
||||
if hostname != "" {
|
||||
currentHostname = hostname
|
||||
}
|
||||
} else if strings.HasPrefix(line, "binding state ") {
|
||||
state := strings.TrimSuffix(strings.TrimPrefix(line, "binding state "), ";")
|
||||
currentActive = state == "active"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseKeaCSV4(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) {
|
||||
scanner := bufio.NewScanner(file)
|
||||
firstLine := true
|
||||
for scanner.Scan() {
|
||||
if firstLine {
|
||||
firstLine = false
|
||||
continue
|
||||
}
|
||||
fields := strings.Split(scanner.Text(), ",")
|
||||
if len(fields) < 10 {
|
||||
continue
|
||||
}
|
||||
if fields[9] != "0" {
|
||||
continue
|
||||
}
|
||||
address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[0]))
|
||||
if !addrOK {
|
||||
continue
|
||||
}
|
||||
address = address.Unmap()
|
||||
mac, macErr := net.ParseMAC(fields[1])
|
||||
if macErr != nil {
|
||||
continue
|
||||
}
|
||||
ipToMAC[address] = mac
|
||||
hostname := ""
|
||||
if len(fields) > 8 {
|
||||
hostname = fields[8]
|
||||
}
|
||||
if hostname != "" {
|
||||
ipToHostname[address] = hostname
|
||||
macToHostname[mac.String()] = hostname
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseKeaCSV6(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) {
|
||||
scanner := bufio.NewScanner(file)
|
||||
firstLine := true
|
||||
for scanner.Scan() {
|
||||
if firstLine {
|
||||
firstLine = false
|
||||
continue
|
||||
}
|
||||
fields := strings.Split(scanner.Text(), ",")
|
||||
if len(fields) < 14 {
|
||||
continue
|
||||
}
|
||||
if fields[13] != "0" {
|
||||
continue
|
||||
}
|
||||
address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[0]))
|
||||
if !addrOK {
|
||||
continue
|
||||
}
|
||||
address = address.Unmap()
|
||||
var mac net.HardwareAddr
|
||||
if fields[12] != "" {
|
||||
mac, _ = net.ParseMAC(fields[12])
|
||||
}
|
||||
if mac == nil {
|
||||
duid, duidErr := hex.DecodeString(strings.ReplaceAll(fields[1], ":", ""))
|
||||
if duidErr == nil {
|
||||
mac, _ = extractMACFromDUID(duid)
|
||||
}
|
||||
}
|
||||
hostname := ""
|
||||
if len(fields) > 11 {
|
||||
hostname = fields[11]
|
||||
}
|
||||
if mac != nil {
|
||||
ipToMAC[address] = mac
|
||||
}
|
||||
if hostname != "" {
|
||||
ipToHostname[address] = hostname
|
||||
if mac != nil {
|
||||
macToHostname[mac.String()] = hostname
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseBootpdLeases(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) {
|
||||
now := time.Now().Unix()
|
||||
scanner := bufio.NewScanner(file)
|
||||
var currentName string
|
||||
var currentIP netip.Addr
|
||||
var currentMAC net.HardwareAddr
|
||||
var currentLease int64
|
||||
var inBlock bool
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "{" {
|
||||
inBlock = true
|
||||
currentName = ""
|
||||
currentIP = netip.Addr{}
|
||||
currentMAC = nil
|
||||
currentLease = 0
|
||||
continue
|
||||
}
|
||||
if line == "}" && inBlock {
|
||||
if currentMAC != nil && currentIP.IsValid() {
|
||||
if currentLease == 0 || currentLease >= now {
|
||||
ipToMAC[currentIP] = currentMAC
|
||||
if currentName != "" {
|
||||
ipToHostname[currentIP] = currentName
|
||||
macToHostname[currentMAC.String()] = currentName
|
||||
}
|
||||
}
|
||||
}
|
||||
inBlock = false
|
||||
continue
|
||||
}
|
||||
if !inBlock {
|
||||
continue
|
||||
}
|
||||
key, value, found := strings.Cut(line, "=")
|
||||
if !found {
|
||||
continue
|
||||
}
|
||||
switch key {
|
||||
case "name":
|
||||
currentName = value
|
||||
case "ip_address":
|
||||
parsed, addrOK := netip.AddrFromSlice(net.ParseIP(value))
|
||||
if addrOK {
|
||||
currentIP = parsed.Unmap()
|
||||
}
|
||||
case "hw_address":
|
||||
typeAndMAC, hasSep := strings.CutPrefix(value, "1,")
|
||||
if hasSep {
|
||||
mac, macErr := net.ParseMAC(typeAndMAC)
|
||||
if macErr == nil {
|
||||
currentMAC = mac
|
||||
}
|
||||
}
|
||||
case "lease":
|
||||
leaseHex := strings.TrimPrefix(value, "0x")
|
||||
parsed, parseErr := strconv.ParseInt(leaseHex, 16, 64)
|
||||
if parseErr == nil {
|
||||
currentLease = parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
224
route/neighbor_resolver_linux.go
Normal file
224
route/neighbor_resolver_linux.go
Normal file
@@ -0,0 +1,224 @@
|
||||
//go:build linux
|
||||
|
||||
package route
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/fswatch"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
|
||||
"github.com/jsimonetti/rtnetlink"
|
||||
"github.com/mdlayher/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var defaultLeaseFiles = []string{
|
||||
"/tmp/dhcp.leases",
|
||||
"/var/lib/dhcp/dhcpd.leases",
|
||||
"/var/lib/dhcpd/dhcpd.leases",
|
||||
"/var/lib/kea/kea-leases4.csv",
|
||||
"/var/lib/kea/kea-leases6.csv",
|
||||
}
|
||||
|
||||
type neighborResolver struct {
|
||||
logger logger.ContextLogger
|
||||
leaseFiles []string
|
||||
access sync.RWMutex
|
||||
neighborIPToMAC map[netip.Addr]net.HardwareAddr
|
||||
leaseIPToMAC map[netip.Addr]net.HardwareAddr
|
||||
ipToHostname map[netip.Addr]string
|
||||
macToHostname map[string]string
|
||||
watcher *fswatch.Watcher
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func newNeighborResolver(resolverLogger logger.ContextLogger, leaseFiles []string) (adapter.NeighborResolver, error) {
|
||||
if len(leaseFiles) == 0 {
|
||||
for _, path := range defaultLeaseFiles {
|
||||
info, err := os.Stat(path)
|
||||
if err == nil && info.Size() > 0 {
|
||||
leaseFiles = append(leaseFiles, path)
|
||||
}
|
||||
}
|
||||
}
|
||||
return &neighborResolver{
|
||||
logger: resolverLogger,
|
||||
leaseFiles: leaseFiles,
|
||||
neighborIPToMAC: make(map[netip.Addr]net.HardwareAddr),
|
||||
leaseIPToMAC: make(map[netip.Addr]net.HardwareAddr),
|
||||
ipToHostname: make(map[netip.Addr]string),
|
||||
macToHostname: make(map[string]string),
|
||||
done: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *neighborResolver) Start() error {
|
||||
err := r.loadNeighborTable()
|
||||
if err != nil {
|
||||
r.logger.Warn(E.Cause(err, "load neighbor table"))
|
||||
}
|
||||
r.doReloadLeaseFiles()
|
||||
go r.subscribeNeighborUpdates()
|
||||
if len(r.leaseFiles) > 0 {
|
||||
watcher, err := fswatch.NewWatcher(fswatch.Options{
|
||||
Path: r.leaseFiles,
|
||||
Logger: r.logger,
|
||||
Callback: func(_ string) {
|
||||
r.doReloadLeaseFiles()
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
r.logger.Warn(E.Cause(err, "create lease file watcher"))
|
||||
} else {
|
||||
r.watcher = watcher
|
||||
err = watcher.Start()
|
||||
if err != nil {
|
||||
r.logger.Warn(E.Cause(err, "start lease file watcher"))
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *neighborResolver) Close() error {
|
||||
close(r.done)
|
||||
if r.watcher != nil {
|
||||
return r.watcher.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *neighborResolver) LookupMAC(address netip.Addr) (net.HardwareAddr, bool) {
|
||||
r.access.RLock()
|
||||
defer r.access.RUnlock()
|
||||
mac, found := r.neighborIPToMAC[address]
|
||||
if found {
|
||||
return mac, true
|
||||
}
|
||||
mac, found = r.leaseIPToMAC[address]
|
||||
if found {
|
||||
return mac, true
|
||||
}
|
||||
mac, found = extractMACFromEUI64(address)
|
||||
if found {
|
||||
return mac, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (r *neighborResolver) LookupHostname(address netip.Addr) (string, bool) {
|
||||
r.access.RLock()
|
||||
defer r.access.RUnlock()
|
||||
hostname, found := r.ipToHostname[address]
|
||||
if found {
|
||||
return hostname, true
|
||||
}
|
||||
mac, macFound := r.neighborIPToMAC[address]
|
||||
if !macFound {
|
||||
mac, macFound = r.leaseIPToMAC[address]
|
||||
}
|
||||
if !macFound {
|
||||
mac, macFound = extractMACFromEUI64(address)
|
||||
}
|
||||
if macFound {
|
||||
hostname, found = r.macToHostname[mac.String()]
|
||||
if found {
|
||||
return hostname, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func (r *neighborResolver) loadNeighborTable() error {
|
||||
connection, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return E.Cause(err, "dial rtnetlink")
|
||||
}
|
||||
defer connection.Close()
|
||||
neighbors, err := connection.Neigh.List()
|
||||
if err != nil {
|
||||
return E.Cause(err, "list neighbors")
|
||||
}
|
||||
r.access.Lock()
|
||||
defer r.access.Unlock()
|
||||
for _, neigh := range neighbors {
|
||||
if neigh.Attributes == nil {
|
||||
continue
|
||||
}
|
||||
if neigh.Attributes.LLAddress == nil || len(neigh.Attributes.Address) == 0 {
|
||||
continue
|
||||
}
|
||||
address, ok := netip.AddrFromSlice(neigh.Attributes.Address)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
r.neighborIPToMAC[address] = slices.Clone(neigh.Attributes.LLAddress)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *neighborResolver) subscribeNeighborUpdates() {
|
||||
connection, err := netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{
|
||||
Groups: 1 << (unix.RTNLGRP_NEIGH - 1),
|
||||
})
|
||||
if err != nil {
|
||||
r.logger.Warn(E.Cause(err, "subscribe neighbor updates"))
|
||||
return
|
||||
}
|
||||
defer connection.Close()
|
||||
for {
|
||||
select {
|
||||
case <-r.done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
err = connection.SetReadDeadline(time.Now().Add(3 * time.Second))
|
||||
if err != nil {
|
||||
r.logger.Warn(E.Cause(err, "set netlink read deadline"))
|
||||
return
|
||||
}
|
||||
messages, err := connection.Receive()
|
||||
if err != nil {
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case <-r.done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
r.logger.Warn(E.Cause(err, "receive neighbor update"))
|
||||
continue
|
||||
}
|
||||
for _, message := range messages {
|
||||
address, mac, isDelete, ok := ParseNeighborMessage(message)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
r.access.Lock()
|
||||
if isDelete {
|
||||
delete(r.neighborIPToMAC, address)
|
||||
} else {
|
||||
r.neighborIPToMAC[address] = mac
|
||||
}
|
||||
r.access.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *neighborResolver) doReloadLeaseFiles() {
|
||||
leaseIPToMAC, ipToHostname, macToHostname := ReloadLeaseFiles(r.leaseFiles)
|
||||
r.access.Lock()
|
||||
r.leaseIPToMAC = leaseIPToMAC
|
||||
r.ipToHostname = ipToHostname
|
||||
r.macToHostname = macToHostname
|
||||
r.access.Unlock()
|
||||
}
|
||||
50
route/neighbor_resolver_parse.go
Normal file
50
route/neighbor_resolver_parse.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package route
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func extractMACFromDUID(duid []byte) (net.HardwareAddr, bool) {
|
||||
if len(duid) < 4 {
|
||||
return nil, false
|
||||
}
|
||||
duidType := binary.BigEndian.Uint16(duid[0:2])
|
||||
hwType := binary.BigEndian.Uint16(duid[2:4])
|
||||
if hwType != 1 {
|
||||
return nil, false
|
||||
}
|
||||
switch duidType {
|
||||
case 1:
|
||||
if len(duid) < 14 {
|
||||
return nil, false
|
||||
}
|
||||
return net.HardwareAddr(slices.Clone(duid[8:14])), true
|
||||
case 3:
|
||||
if len(duid) < 10 {
|
||||
return nil, false
|
||||
}
|
||||
return net.HardwareAddr(slices.Clone(duid[4:10])), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func extractMACFromEUI64(address netip.Addr) (net.HardwareAddr, bool) {
|
||||
if !address.Is6() {
|
||||
return nil, false
|
||||
}
|
||||
b := address.As16()
|
||||
if b[11] != 0xff || b[12] != 0xfe {
|
||||
return nil, false
|
||||
}
|
||||
return net.HardwareAddr{b[8] ^ 0x02, b[9], b[10], b[13], b[14], b[15]}, true
|
||||
}
|
||||
|
||||
func parseDUID(s string) ([]byte, error) {
|
||||
cleaned := strings.ReplaceAll(s, ":", "")
|
||||
return hex.DecodeString(cleaned)
|
||||
}
|
||||
84
route/neighbor_resolver_platform.go
Normal file
84
route/neighbor_resolver_platform.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package route
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
)
|
||||
|
||||
type platformNeighborResolver struct {
|
||||
logger logger.ContextLogger
|
||||
platform adapter.PlatformInterface
|
||||
access sync.RWMutex
|
||||
ipToMAC map[netip.Addr]net.HardwareAddr
|
||||
ipToHostname map[netip.Addr]string
|
||||
macToHostname map[string]string
|
||||
}
|
||||
|
||||
func newPlatformNeighborResolver(resolverLogger logger.ContextLogger, platform adapter.PlatformInterface) adapter.NeighborResolver {
|
||||
return &platformNeighborResolver{
|
||||
logger: resolverLogger,
|
||||
platform: platform,
|
||||
ipToMAC: make(map[netip.Addr]net.HardwareAddr),
|
||||
ipToHostname: make(map[netip.Addr]string),
|
||||
macToHostname: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *platformNeighborResolver) Start() error {
|
||||
return r.platform.StartNeighborMonitor(r)
|
||||
}
|
||||
|
||||
func (r *platformNeighborResolver) Close() error {
|
||||
return r.platform.CloseNeighborMonitor(r)
|
||||
}
|
||||
|
||||
func (r *platformNeighborResolver) LookupMAC(address netip.Addr) (net.HardwareAddr, bool) {
|
||||
r.access.RLock()
|
||||
defer r.access.RUnlock()
|
||||
mac, found := r.ipToMAC[address]
|
||||
if found {
|
||||
return mac, true
|
||||
}
|
||||
return extractMACFromEUI64(address)
|
||||
}
|
||||
|
||||
func (r *platformNeighborResolver) LookupHostname(address netip.Addr) (string, bool) {
|
||||
r.access.RLock()
|
||||
defer r.access.RUnlock()
|
||||
hostname, found := r.ipToHostname[address]
|
||||
if found {
|
||||
return hostname, true
|
||||
}
|
||||
mac, found := r.ipToMAC[address]
|
||||
if !found {
|
||||
mac, found = extractMACFromEUI64(address)
|
||||
}
|
||||
if !found {
|
||||
return "", false
|
||||
}
|
||||
hostname, found = r.macToHostname[mac.String()]
|
||||
return hostname, found
|
||||
}
|
||||
|
||||
func (r *platformNeighborResolver) UpdateNeighborTable(entries []adapter.NeighborEntry) {
|
||||
ipToMAC := make(map[netip.Addr]net.HardwareAddr)
|
||||
ipToHostname := make(map[netip.Addr]string)
|
||||
macToHostname := make(map[string]string)
|
||||
for _, entry := range entries {
|
||||
ipToMAC[entry.Address] = entry.MACAddress
|
||||
if entry.Hostname != "" {
|
||||
ipToHostname[entry.Address] = entry.Hostname
|
||||
macToHostname[entry.MACAddress.String()] = entry.Hostname
|
||||
}
|
||||
}
|
||||
r.access.Lock()
|
||||
r.ipToMAC = ipToMAC
|
||||
r.ipToHostname = ipToHostname
|
||||
r.macToHostname = macToHostname
|
||||
r.access.Unlock()
|
||||
r.logger.Info("updated neighbor table: ", len(entries), " entries")
|
||||
}
|
||||
14
route/neighbor_resolver_stub.go
Normal file
14
route/neighbor_resolver_stub.go
Normal file
@@ -0,0 +1,14 @@
|
||||
//go:build !linux && !darwin
|
||||
|
||||
package route
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
)
|
||||
|
||||
func newNeighborResolver(_ logger.ContextLogger, _ []string) (adapter.NeighborResolver, error) {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
104
route/neighbor_table_darwin.go
Normal file
104
route/neighbor_table_darwin.go
Normal file
@@ -0,0 +1,104 @@
|
||||
//go:build darwin
|
||||
|
||||
package route
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
|
||||
"golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func ReadNeighborEntries() ([]adapter.NeighborEntry, error) {
|
||||
var entries []adapter.NeighborEntry
|
||||
ipv4Entries, err := readNeighborEntriesAF(syscall.AF_INET)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "read IPv4 neighbors")
|
||||
}
|
||||
entries = append(entries, ipv4Entries...)
|
||||
ipv6Entries, err := readNeighborEntriesAF(syscall.AF_INET6)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "read IPv6 neighbors")
|
||||
}
|
||||
entries = append(entries, ipv6Entries...)
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func readNeighborEntriesAF(addressFamily int) ([]adapter.NeighborEntry, error) {
|
||||
rib, err := route.FetchRIB(addressFamily, route.RIBType(syscall.NET_RT_FLAGS), syscall.RTF_LLINFO)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
messages, err := route.ParseRIB(route.RIBType(syscall.NET_RT_FLAGS), rib)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var entries []adapter.NeighborEntry
|
||||
for _, message := range messages {
|
||||
routeMessage, isRouteMessage := message.(*route.RouteMessage)
|
||||
if !isRouteMessage {
|
||||
continue
|
||||
}
|
||||
address, macAddress, ok := parseRouteNeighborEntry(routeMessage)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
entries = append(entries, adapter.NeighborEntry{
|
||||
Address: address,
|
||||
MACAddress: macAddress,
|
||||
})
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func parseRouteNeighborEntry(message *route.RouteMessage) (address netip.Addr, macAddress net.HardwareAddr, ok bool) {
|
||||
if len(message.Addrs) <= unix.RTAX_GATEWAY {
|
||||
return
|
||||
}
|
||||
gateway, isLinkAddr := message.Addrs[unix.RTAX_GATEWAY].(*route.LinkAddr)
|
||||
if !isLinkAddr || len(gateway.Addr) < 6 {
|
||||
return
|
||||
}
|
||||
switch destination := message.Addrs[unix.RTAX_DST].(type) {
|
||||
case *route.Inet4Addr:
|
||||
address = netip.AddrFrom4(destination.IP)
|
||||
case *route.Inet6Addr:
|
||||
address = netip.AddrFrom16(destination.IP)
|
||||
default:
|
||||
return
|
||||
}
|
||||
macAddress = net.HardwareAddr(make([]byte, len(gateway.Addr)))
|
||||
copy(macAddress, gateway.Addr)
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
|
||||
func ParseRouteNeighborMessage(message *route.RouteMessage) (address netip.Addr, macAddress net.HardwareAddr, isDelete bool, ok bool) {
|
||||
isDelete = message.Type == unix.RTM_DELETE
|
||||
if len(message.Addrs) <= unix.RTAX_GATEWAY {
|
||||
return
|
||||
}
|
||||
switch destination := message.Addrs[unix.RTAX_DST].(type) {
|
||||
case *route.Inet4Addr:
|
||||
address = netip.AddrFrom4(destination.IP)
|
||||
case *route.Inet6Addr:
|
||||
address = netip.AddrFrom16(destination.IP)
|
||||
default:
|
||||
return
|
||||
}
|
||||
if !isDelete {
|
||||
gateway, isLinkAddr := message.Addrs[unix.RTAX_GATEWAY].(*route.LinkAddr)
|
||||
if !isLinkAddr || len(gateway.Addr) < 6 {
|
||||
return
|
||||
}
|
||||
macAddress = net.HardwareAddr(make([]byte, len(gateway.Addr)))
|
||||
copy(macAddress, gateway.Addr)
|
||||
}
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
68
route/neighbor_table_linux.go
Normal file
68
route/neighbor_table_linux.go
Normal file
@@ -0,0 +1,68 @@
|
||||
//go:build linux
|
||||
|
||||
package route
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
|
||||
"github.com/jsimonetti/rtnetlink"
|
||||
"github.com/mdlayher/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func ReadNeighborEntries() ([]adapter.NeighborEntry, error) {
|
||||
connection, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "dial rtnetlink")
|
||||
}
|
||||
defer connection.Close()
|
||||
neighbors, err := connection.Neigh.List()
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "list neighbors")
|
||||
}
|
||||
var entries []adapter.NeighborEntry
|
||||
for _, neighbor := range neighbors {
|
||||
if neighbor.Attributes == nil {
|
||||
continue
|
||||
}
|
||||
if neighbor.Attributes.LLAddress == nil || len(neighbor.Attributes.Address) == 0 {
|
||||
continue
|
||||
}
|
||||
address, ok := netip.AddrFromSlice(neighbor.Attributes.Address)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
entries = append(entries, adapter.NeighborEntry{
|
||||
Address: address,
|
||||
MACAddress: slices.Clone(neighbor.Attributes.LLAddress),
|
||||
})
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func ParseNeighborMessage(message netlink.Message) (address netip.Addr, macAddress net.HardwareAddr, isDelete bool, ok bool) {
|
||||
var neighMessage rtnetlink.NeighMessage
|
||||
err := neighMessage.UnmarshalBinary(message.Data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if neighMessage.Attributes == nil || len(neighMessage.Attributes.Address) == 0 {
|
||||
return
|
||||
}
|
||||
address, ok = netip.AddrFromSlice(neighMessage.Attributes.Address)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
isDelete = message.Header.Type == unix.RTM_DELNEIGH
|
||||
if !isDelete && neighMessage.Attributes.LLAddress == nil {
|
||||
ok = false
|
||||
return
|
||||
}
|
||||
macAddress = slices.Clone(neighMessage.Attributes.LLAddress)
|
||||
return
|
||||
}
|
||||
@@ -51,6 +51,7 @@ type NetworkManager struct {
|
||||
endpoint adapter.EndpointManager
|
||||
inbound adapter.InboundManager
|
||||
outbound adapter.OutboundManager
|
||||
serviceManager adapter.ServiceManager
|
||||
needWIFIState bool
|
||||
wifiMonitor settings.WIFIMonitor
|
||||
wifiState adapter.WIFIState
|
||||
@@ -94,6 +95,7 @@ func NewNetworkManager(ctx context.Context, logger logger.ContextLogger, options
|
||||
endpoint: service.FromContext[adapter.EndpointManager](ctx),
|
||||
inbound: service.FromContext[adapter.InboundManager](ctx),
|
||||
outbound: service.FromContext[adapter.OutboundManager](ctx),
|
||||
serviceManager: service.FromContext[adapter.ServiceManager](ctx),
|
||||
needWIFIState: hasRule(options.Rules, isWIFIRule) || hasDNSRule(dnsOptions.Rules, isWIFIDNSRule),
|
||||
}
|
||||
if options.DefaultNetworkStrategy != nil {
|
||||
@@ -475,6 +477,15 @@ func (r *NetworkManager) ResetNetwork() {
|
||||
listener.InterfaceUpdated()
|
||||
}
|
||||
}
|
||||
|
||||
if r.serviceManager != nil {
|
||||
for _, svc := range r.serviceManager.Services() {
|
||||
listener, isListener := svc.(adapter.InterfaceUpdateListener)
|
||||
if isListener {
|
||||
listener.InterfaceUpdated()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *NetworkManager) notifyInterfaceUpdate(defaultInterface *control.Interface, flags int) {
|
||||
|
||||
@@ -439,6 +439,23 @@ func (r *Router) matchRule(
|
||||
metadata.ProcessInfo = processInfo
|
||||
}
|
||||
}
|
||||
if r.neighborResolver != nil && metadata.SourceMACAddress == nil && metadata.Source.Addr.IsValid() {
|
||||
mac, macFound := r.neighborResolver.LookupMAC(metadata.Source.Addr)
|
||||
if macFound {
|
||||
metadata.SourceMACAddress = mac
|
||||
}
|
||||
hostname, hostnameFound := r.neighborResolver.LookupHostname(metadata.Source.Addr)
|
||||
if hostnameFound {
|
||||
metadata.SourceHostname = hostname
|
||||
if macFound {
|
||||
r.logger.InfoContext(ctx, "found neighbor: ", mac, ", hostname: ", hostname)
|
||||
} else {
|
||||
r.logger.InfoContext(ctx, "found neighbor hostname: ", hostname)
|
||||
}
|
||||
} else if macFound {
|
||||
r.logger.InfoContext(ctx, "found neighbor: ", mac)
|
||||
}
|
||||
}
|
||||
if metadata.Destination.Addr.IsValid() && r.dnsTransport.FakeIP() != nil && r.dnsTransport.FakeIP().Store().Contains(metadata.Destination.Addr) {
|
||||
domain, loaded := r.dnsTransport.FakeIP().Store().Lookup(metadata.Destination.Addr)
|
||||
if !loaded {
|
||||
|
||||
@@ -31,9 +31,12 @@ type Router struct {
|
||||
network adapter.NetworkManager
|
||||
rules []adapter.Rule
|
||||
needFindProcess bool
|
||||
needFindNeighbor bool
|
||||
leaseFiles []string
|
||||
ruleSets []adapter.RuleSet
|
||||
ruleSetMap map[string]adapter.RuleSet
|
||||
processSearcher process.Searcher
|
||||
neighborResolver adapter.NeighborResolver
|
||||
pauseManager pause.Manager
|
||||
trackers []adapter.ConnectionTracker
|
||||
platformInterface adapter.PlatformInterface
|
||||
@@ -53,6 +56,8 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.Route
|
||||
rules: make([]adapter.Rule, 0, len(options.Rules)),
|
||||
ruleSetMap: make(map[string]adapter.RuleSet),
|
||||
needFindProcess: hasRule(options.Rules, isProcessRule) || hasDNSRule(dnsOptions.Rules, isProcessDNSRule) || options.FindProcess,
|
||||
needFindNeighbor: hasRule(options.Rules, isNeighborRule) || hasDNSRule(dnsOptions.Rules, isNeighborDNSRule) || options.FindNeighbor,
|
||||
leaseFiles: options.DHCPLeaseFiles,
|
||||
pauseManager: service.FromContext[pause.Manager](ctx),
|
||||
platformInterface: service.FromContext[adapter.PlatformInterface](ctx),
|
||||
}
|
||||
@@ -112,6 +117,7 @@ func (r *Router) Start(stage adapter.StartStage) error {
|
||||
}
|
||||
r.network.Initialize(r.ruleSets)
|
||||
needFindProcess := r.needFindProcess
|
||||
needFindNeighbor := r.needFindNeighbor
|
||||
for _, ruleSet := range r.ruleSets {
|
||||
metadata := ruleSet.Metadata()
|
||||
if metadata.ContainsProcessRule {
|
||||
@@ -141,6 +147,36 @@ func (r *Router) Start(stage adapter.StartStage) error {
|
||||
}
|
||||
}
|
||||
}
|
||||
r.needFindNeighbor = needFindNeighbor
|
||||
if needFindNeighbor {
|
||||
if r.platformInterface != nil && r.platformInterface.UsePlatformNeighborResolver() {
|
||||
monitor.Start("initialize neighbor resolver")
|
||||
resolver := newPlatformNeighborResolver(r.logger, r.platformInterface)
|
||||
err := resolver.Start()
|
||||
monitor.Finish()
|
||||
if err != nil {
|
||||
r.logger.Error(E.Cause(err, "start neighbor resolver"))
|
||||
} else {
|
||||
r.neighborResolver = resolver
|
||||
}
|
||||
} else {
|
||||
monitor.Start("initialize neighbor resolver")
|
||||
resolver, err := newNeighborResolver(r.logger, r.leaseFiles)
|
||||
monitor.Finish()
|
||||
if err != nil {
|
||||
if err != os.ErrInvalid {
|
||||
r.logger.Error(E.Cause(err, "create neighbor resolver"))
|
||||
}
|
||||
} else {
|
||||
err = resolver.Start()
|
||||
if err != nil {
|
||||
r.logger.Error(E.Cause(err, "start neighbor resolver"))
|
||||
} else {
|
||||
r.neighborResolver = resolver
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case adapter.StartStatePostStart:
|
||||
for i, rule := range r.rules {
|
||||
monitor.Start("initialize rule[", i, "]")
|
||||
@@ -172,6 +208,13 @@ func (r *Router) Start(stage adapter.StartStage) error {
|
||||
func (r *Router) Close() error {
|
||||
monitor := taskmonitor.New(r.logger, C.StopTimeout)
|
||||
var err error
|
||||
if r.neighborResolver != nil {
|
||||
monitor.Start("close neighbor resolver")
|
||||
err = E.Append(err, r.neighborResolver.Close(), func(closeErr error) error {
|
||||
return E.Cause(closeErr, "close neighbor resolver")
|
||||
})
|
||||
monitor.Finish()
|
||||
}
|
||||
for i, rule := range r.rules {
|
||||
monitor.Start("close rule[", i, "]")
|
||||
err = E.Append(err, rule.Close(), func(err error) error {
|
||||
@@ -206,6 +249,14 @@ func (r *Router) NeedFindProcess() bool {
|
||||
return r.needFindProcess
|
||||
}
|
||||
|
||||
func (r *Router) NeedFindNeighbor() bool {
|
||||
return r.needFindNeighbor
|
||||
}
|
||||
|
||||
func (r *Router) NeighborResolver() adapter.NeighborResolver {
|
||||
return r.neighborResolver
|
||||
}
|
||||
|
||||
func (r *Router) ResetNetwork() {
|
||||
r.network.ResetNetwork()
|
||||
r.dns.ResetNetwork()
|
||||
|
||||
@@ -260,6 +260,16 @@ func NewDefaultRule(ctx context.Context, logger log.ContextLogger, options optio
|
||||
rule.items = append(rule.items, item)
|
||||
rule.allItems = append(rule.allItems, item)
|
||||
}
|
||||
if len(options.SourceMACAddress) > 0 {
|
||||
item := NewSourceMACAddressItem(options.SourceMACAddress)
|
||||
rule.items = append(rule.items, item)
|
||||
rule.allItems = append(rule.allItems, item)
|
||||
}
|
||||
if len(options.SourceHostname) > 0 {
|
||||
item := NewSourceHostnameItem(options.SourceHostname)
|
||||
rule.items = append(rule.items, item)
|
||||
rule.allItems = append(rule.allItems, item)
|
||||
}
|
||||
if len(options.PreferredBy) > 0 {
|
||||
item := NewPreferredByItem(ctx, options.PreferredBy)
|
||||
rule.items = append(rule.items, item)
|
||||
|
||||
@@ -261,6 +261,16 @@ func NewDefaultDNSRule(ctx context.Context, logger log.ContextLogger, options op
|
||||
rule.items = append(rule.items, item)
|
||||
rule.allItems = append(rule.allItems, item)
|
||||
}
|
||||
if len(options.SourceMACAddress) > 0 {
|
||||
item := NewSourceMACAddressItem(options.SourceMACAddress)
|
||||
rule.items = append(rule.items, item)
|
||||
rule.allItems = append(rule.allItems, item)
|
||||
}
|
||||
if len(options.SourceHostname) > 0 {
|
||||
item := NewSourceHostnameItem(options.SourceHostname)
|
||||
rule.items = append(rule.items, item)
|
||||
rule.allItems = append(rule.allItems, item)
|
||||
}
|
||||
if len(options.RuleSet) > 0 {
|
||||
//nolint:staticcheck
|
||||
if options.Deprecated_RulesetIPCIDRMatchSource {
|
||||
|
||||
42
route/rule/rule_item_source_hostname.go
Normal file
42
route/rule/rule_item_source_hostname.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package rule
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
)
|
||||
|
||||
var _ RuleItem = (*SourceHostnameItem)(nil)
|
||||
|
||||
type SourceHostnameItem struct {
|
||||
hostnames []string
|
||||
hostnameMap map[string]bool
|
||||
}
|
||||
|
||||
func NewSourceHostnameItem(hostnameList []string) *SourceHostnameItem {
|
||||
rule := &SourceHostnameItem{
|
||||
hostnames: hostnameList,
|
||||
hostnameMap: make(map[string]bool),
|
||||
}
|
||||
for _, hostname := range hostnameList {
|
||||
rule.hostnameMap[hostname] = true
|
||||
}
|
||||
return rule
|
||||
}
|
||||
|
||||
func (r *SourceHostnameItem) Match(metadata *adapter.InboundContext) bool {
|
||||
if metadata.SourceHostname == "" {
|
||||
return false
|
||||
}
|
||||
return r.hostnameMap[metadata.SourceHostname]
|
||||
}
|
||||
|
||||
func (r *SourceHostnameItem) String() string {
|
||||
var description string
|
||||
if len(r.hostnames) == 1 {
|
||||
description = "source_hostname=" + r.hostnames[0]
|
||||
} else {
|
||||
description = "source_hostname=[" + strings.Join(r.hostnames, " ") + "]"
|
||||
}
|
||||
return description
|
||||
}
|
||||
48
route/rule/rule_item_source_mac_address.go
Normal file
48
route/rule/rule_item_source_mac_address.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package rule
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
)
|
||||
|
||||
var _ RuleItem = (*SourceMACAddressItem)(nil)
|
||||
|
||||
type SourceMACAddressItem struct {
|
||||
addresses []string
|
||||
addressMap map[string]bool
|
||||
}
|
||||
|
||||
func NewSourceMACAddressItem(addressList []string) *SourceMACAddressItem {
|
||||
rule := &SourceMACAddressItem{
|
||||
addresses: addressList,
|
||||
addressMap: make(map[string]bool),
|
||||
}
|
||||
for _, address := range addressList {
|
||||
parsed, err := net.ParseMAC(address)
|
||||
if err == nil {
|
||||
rule.addressMap[parsed.String()] = true
|
||||
} else {
|
||||
rule.addressMap[address] = true
|
||||
}
|
||||
}
|
||||
return rule
|
||||
}
|
||||
|
||||
func (r *SourceMACAddressItem) Match(metadata *adapter.InboundContext) bool {
|
||||
if metadata.SourceMACAddress == nil {
|
||||
return false
|
||||
}
|
||||
return r.addressMap[metadata.SourceMACAddress.String()]
|
||||
}
|
||||
|
||||
func (r *SourceMACAddressItem) String() string {
|
||||
var description string
|
||||
if len(r.addresses) == 1 {
|
||||
description = "source_mac_address=" + r.addresses[0]
|
||||
} else {
|
||||
description = "source_mac_address=[" + strings.Join(r.addresses, " ") + "]"
|
||||
}
|
||||
return description
|
||||
}
|
||||
@@ -45,6 +45,14 @@ func isProcessDNSRule(rule option.DefaultDNSRule) bool {
|
||||
return len(rule.ProcessName) > 0 || len(rule.ProcessPath) > 0 || len(rule.ProcessPathRegex) > 0 || len(rule.PackageName) > 0 || len(rule.User) > 0 || len(rule.UserID) > 0
|
||||
}
|
||||
|
||||
func isNeighborRule(rule option.DefaultRule) bool {
|
||||
return len(rule.SourceMACAddress) > 0 || len(rule.SourceHostname) > 0
|
||||
}
|
||||
|
||||
func isNeighborDNSRule(rule option.DefaultDNSRule) bool {
|
||||
return len(rule.SourceMACAddress) > 0 || len(rule.SourceHostname) > 0
|
||||
}
|
||||
|
||||
func isWIFIRule(rule option.DefaultRule) bool {
|
||||
return len(rule.WIFISSID) > 0 || len(rule.WIFIBSSID) > 0
|
||||
}
|
||||
|
||||
13
service/ccm/CLAUDE.md
Normal file
13
service/ccm/CLAUDE.md
Normal file
@@ -0,0 +1,13 @@
|
||||
# Claude Code Multiplexer
|
||||
|
||||
### Reverse Claude Code
|
||||
|
||||
Claude distributes a huge binary by default in a Bun, which is difficult to reverse engineer (and is very likely the one the user have installed now).
|
||||
|
||||
You must obtain the npm version of the Claude Code js source code:
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
cd /tmp && npm pack @anthropic-ai/claude-code && tar xzf anthropic-ai-claude-code-*.tgz && npx prettier --write package/cli.js
|
||||
```
|
||||
@@ -1,139 +1,258 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"context"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/observable"
|
||||
)
|
||||
|
||||
const (
|
||||
oauth2ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
oauth2TokenURL = "https://console.anthropic.com/v1/oauth/token"
|
||||
claudeAPIBaseURL = "https://api.anthropic.com"
|
||||
tokenRefreshBufferMs = 60000
|
||||
anthropicBetaOAuthValue = "oauth-2025-04-20"
|
||||
defaultPollInterval = 60 * time.Minute
|
||||
failedPollRetryInterval = time.Minute
|
||||
httpRetryMaxBackoff = 5 * time.Minute
|
||||
)
|
||||
|
||||
func getRealUser() (*user.User, error) {
|
||||
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
|
||||
sudoUserInfo, err := user.Lookup(sudoUser)
|
||||
const (
|
||||
httpRetryMaxAttempts = 3
|
||||
httpRetryInitialDelay = 200 * time.Millisecond
|
||||
)
|
||||
|
||||
const sessionExpiry = 24 * time.Hour
|
||||
|
||||
func doHTTPWithRetry(ctx context.Context, client *http.Client, buildRequest func() (*http.Request, error)) (*http.Response, error) {
|
||||
var lastError error
|
||||
for attempt := range httpRetryMaxAttempts {
|
||||
if attempt > 0 {
|
||||
delay := httpRetryInitialDelay * time.Duration(1<<(attempt-1))
|
||||
timer := time.NewTimer(delay)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return nil, lastError
|
||||
case <-timer.C:
|
||||
}
|
||||
}
|
||||
request, err := buildRequest()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
response, err := client.Do(request)
|
||||
if err == nil {
|
||||
return sudoUserInfo, nil
|
||||
return response, nil
|
||||
}
|
||||
lastError = err
|
||||
if ctx.Err() != nil {
|
||||
return nil, lastError
|
||||
}
|
||||
}
|
||||
return user.Current()
|
||||
return nil, lastError
|
||||
}
|
||||
|
||||
func getDefaultCredentialsPath() (string, error) {
|
||||
if configDir := os.Getenv("CLAUDE_CONFIG_DIR"); configDir != "" {
|
||||
return filepath.Join(configDir, ".credentials.json"), nil
|
||||
}
|
||||
userInfo, err := getRealUser()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(userInfo.HomeDir, ".claude", ".credentials.json"), nil
|
||||
type credentialState struct {
|
||||
fiveHourUtilization float64
|
||||
fiveHourReset time.Time
|
||||
weeklyUtilization float64
|
||||
weeklyReset time.Time
|
||||
hardRateLimited bool
|
||||
rateLimitResetAt time.Time
|
||||
availabilityState availabilityState
|
||||
availabilityReason availabilityReason
|
||||
availabilityResetAt time.Time
|
||||
lastKnownDataAt time.Time
|
||||
accountUUID string
|
||||
accountType string
|
||||
rateLimitTier string
|
||||
oauthAccount *claudeOAuthAccount
|
||||
remotePlanWeight float64
|
||||
lastUpdated time.Time
|
||||
consecutivePollFailures int
|
||||
usageAPIRetryDelay time.Duration
|
||||
unavailable bool
|
||||
upstreamRejectedUntil time.Time
|
||||
lastCredentialLoadAttempt time.Time
|
||||
lastCredentialLoadError string
|
||||
}
|
||||
|
||||
func readCredentialsFromFile(path string) (*oauthCredentials, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var credentialsContainer struct {
|
||||
ClaudeAIAuth *oauthCredentials `json:"claudeAiOauth,omitempty"`
|
||||
}
|
||||
err = json.Unmarshal(data, &credentialsContainer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if credentialsContainer.ClaudeAIAuth == nil {
|
||||
return nil, E.New("claudeAiOauth field not found in credentials")
|
||||
}
|
||||
return credentialsContainer.ClaudeAIAuth, nil
|
||||
type credentialRequestContext struct {
|
||||
context.Context
|
||||
releaseOnce sync.Once
|
||||
cancelOnce sync.Once
|
||||
releaseFuncs []func() bool
|
||||
cancelFunc context.CancelFunc
|
||||
}
|
||||
|
||||
func writeCredentialsToFile(oauthCredentials *oauthCredentials, path string) error {
|
||||
data, err := json.MarshalIndent(map[string]any{
|
||||
"claudeAiOauth": oauthCredentials,
|
||||
}, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(path, data, 0o600)
|
||||
func (c *credentialRequestContext) addInterruptLink(stop func() bool) {
|
||||
c.releaseFuncs = append(c.releaseFuncs, stop)
|
||||
}
|
||||
|
||||
type oauthCredentials struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
ExpiresAt int64 `json:"expiresAt"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
SubscriptionType string `json:"subscriptionType,omitempty"`
|
||||
IsMax bool `json:"isMax,omitempty"`
|
||||
}
|
||||
|
||||
func (c *oauthCredentials) needsRefresh() bool {
|
||||
if c.ExpiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().UnixMilli() >= c.ExpiresAt-tokenRefreshBufferMs
|
||||
}
|
||||
|
||||
func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) {
|
||||
if credentials.RefreshToken == "" {
|
||||
return nil, E.New("refresh token is empty")
|
||||
}
|
||||
|
||||
requestBody, err := json.Marshal(map[string]string{
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": credentials.RefreshToken,
|
||||
"client_id": oauth2ClientID,
|
||||
func (c *credentialRequestContext) releaseCredentialInterrupt() {
|
||||
c.releaseOnce.Do(func() {
|
||||
for _, f := range c.releaseFuncs {
|
||||
f()
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "marshal request")
|
||||
}
|
||||
|
||||
request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request.Header.Set("Accept", "application/json")
|
||||
|
||||
response, err := httpClient.Do(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return nil, E.New("refresh failed: ", response.Status, " ", string(body))
|
||||
}
|
||||
|
||||
var tokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
err = json.NewDecoder(response.Body).Decode(&tokenResponse)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode response")
|
||||
}
|
||||
|
||||
newCredentials := *credentials
|
||||
newCredentials.AccessToken = tokenResponse.AccessToken
|
||||
if tokenResponse.RefreshToken != "" {
|
||||
newCredentials.RefreshToken = tokenResponse.RefreshToken
|
||||
}
|
||||
newCredentials.ExpiresAt = time.Now().UnixMilli() + int64(tokenResponse.ExpiresIn)*1000
|
||||
|
||||
return &newCredentials, nil
|
||||
}
|
||||
|
||||
func (c *credentialRequestContext) cancelRequest() {
|
||||
c.releaseCredentialInterrupt()
|
||||
c.cancelOnce.Do(c.cancelFunc)
|
||||
}
|
||||
|
||||
type Credential interface {
|
||||
tagName() string
|
||||
isAvailable() bool
|
||||
isUsable() bool
|
||||
isExternal() bool
|
||||
hasSnapshotData() bool
|
||||
fiveHourUtilization() float64
|
||||
weeklyUtilization() float64
|
||||
fiveHourCap() float64
|
||||
weeklyCap() float64
|
||||
planWeight() float64
|
||||
fiveHourResetTime() time.Time
|
||||
weeklyResetTime() time.Time
|
||||
markRateLimited(resetAt time.Time)
|
||||
markUpstreamRejected()
|
||||
availabilityStatus() availabilityStatus
|
||||
earliestReset() time.Time
|
||||
unavailableError() error
|
||||
|
||||
getAccessToken() (string, error)
|
||||
buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error)
|
||||
updateStateFromHeaders(header http.Header)
|
||||
|
||||
wrapRequestContext(ctx context.Context) *credentialRequestContext
|
||||
interruptConnections()
|
||||
|
||||
setStatusSubscriber(*observable.Subscriber[struct{}])
|
||||
start() error
|
||||
pollUsage()
|
||||
lastUpdatedTime() time.Time
|
||||
pollBackoff(base time.Duration) time.Duration
|
||||
usageTrackerOrNil() *AggregatedUsage
|
||||
httpClient() *http.Client
|
||||
close()
|
||||
}
|
||||
|
||||
type credentialSelectionScope string
|
||||
|
||||
const (
|
||||
credentialSelectionScopeAll credentialSelectionScope = "all"
|
||||
credentialSelectionScopeNonExternal credentialSelectionScope = "non_external"
|
||||
)
|
||||
|
||||
type credentialSelection struct {
|
||||
scope credentialSelectionScope
|
||||
filter func(Credential) bool
|
||||
}
|
||||
|
||||
func (s credentialSelection) allows(credential Credential) bool {
|
||||
return s.filter == nil || s.filter(credential)
|
||||
}
|
||||
|
||||
func (s credentialSelection) scopeOrDefault() credentialSelectionScope {
|
||||
if s.scope == "" {
|
||||
return credentialSelectionScopeAll
|
||||
}
|
||||
return s.scope
|
||||
}
|
||||
|
||||
// Claude Code's unified rate-limit handling parses these reset headers with
|
||||
// Number(...), compares them against Date.now()/1000, and renders them via
|
||||
// new Date(seconds*1000), so keep the wire format pinned to Unix epoch seconds.
|
||||
func parseAnthropicResetHeaderValue(headerName string, headerValue string) time.Time {
|
||||
unixEpoch, err := strconv.ParseInt(headerValue, 10, 64)
|
||||
if err != nil {
|
||||
panic("invalid " + headerName + " header: expected Unix epoch seconds, got " + strconv.Quote(headerValue))
|
||||
}
|
||||
if unixEpoch <= 0 {
|
||||
panic("invalid " + headerName + " header: expected positive Unix epoch seconds, got " + strconv.Quote(headerValue))
|
||||
}
|
||||
return time.Unix(unixEpoch, 0)
|
||||
}
|
||||
|
||||
func parseOptionalAnthropicResetHeader(headers http.Header, headerName string) (time.Time, bool) {
|
||||
headerValue := headers.Get(headerName)
|
||||
if headerValue == "" {
|
||||
return time.Time{}, false
|
||||
}
|
||||
return parseAnthropicResetHeaderValue(headerName, headerValue), true
|
||||
}
|
||||
|
||||
func parseRequiredAnthropicResetHeader(headers http.Header, headerName string) time.Time {
|
||||
headerValue := headers.Get(headerName)
|
||||
if headerValue == "" {
|
||||
panic("missing required " + headerName + " header")
|
||||
}
|
||||
return parseAnthropicResetHeaderValue(headerName, headerValue)
|
||||
}
|
||||
|
||||
func (s *credentialState) noteSnapshotData() {
|
||||
s.lastKnownDataAt = time.Now()
|
||||
}
|
||||
|
||||
func (s credentialState) hasSnapshotData() bool {
|
||||
return !s.lastKnownDataAt.IsZero() ||
|
||||
s.fiveHourUtilization > 0 ||
|
||||
s.weeklyUtilization > 0 ||
|
||||
!s.fiveHourReset.IsZero() ||
|
||||
!s.weeklyReset.IsZero()
|
||||
}
|
||||
|
||||
func (s *credentialState) setAvailability(state availabilityState, reason availabilityReason, resetAt time.Time) {
|
||||
s.availabilityState = state
|
||||
s.availabilityReason = reason
|
||||
s.availabilityResetAt = resetAt
|
||||
}
|
||||
|
||||
func (s credentialState) currentAvailability() availabilityStatus {
|
||||
now := time.Now()
|
||||
switch {
|
||||
case s.unavailable:
|
||||
return availabilityStatus{
|
||||
State: availabilityStateUnavailable,
|
||||
Reason: availabilityReasonUnknown,
|
||||
ResetAt: s.availabilityResetAt,
|
||||
}
|
||||
case s.hardRateLimited && (s.rateLimitResetAt.IsZero() || now.Before(s.rateLimitResetAt)):
|
||||
reason := s.availabilityReason
|
||||
if reason == "" {
|
||||
reason = availabilityReasonHardRateLimit
|
||||
}
|
||||
return availabilityStatus{
|
||||
State: availabilityStateRateLimited,
|
||||
Reason: reason,
|
||||
ResetAt: s.rateLimitResetAt,
|
||||
}
|
||||
case !s.upstreamRejectedUntil.IsZero() && now.Before(s.upstreamRejectedUntil):
|
||||
return availabilityStatus{
|
||||
State: availabilityStateTemporarilyBlocked,
|
||||
Reason: availabilityReasonUpstreamRejected,
|
||||
ResetAt: s.upstreamRejectedUntil,
|
||||
}
|
||||
case s.consecutivePollFailures > 0:
|
||||
return availabilityStatus{
|
||||
State: availabilityStateTemporarilyBlocked,
|
||||
Reason: availabilityReasonPollFailed,
|
||||
}
|
||||
default:
|
||||
return availabilityStatus{State: availabilityStateUsable}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func parseRateLimitResetFromHeaders(headers http.Header) time.Time {
|
||||
claim := headers.Get("anthropic-ratelimit-unified-representative-claim")
|
||||
switch claim {
|
||||
case "5h":
|
||||
return parseRequiredAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset")
|
||||
case "7d":
|
||||
return parseRequiredAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset")
|
||||
default:
|
||||
panic("invalid anthropic-ratelimit-unified-representative-claim header: " + strconv.Quote(claim))
|
||||
}
|
||||
}
|
||||
|
||||
162
service/ccm/credential_builder.go
Normal file
162
service/ccm/credential_builder.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
func buildCredentialProviders(
|
||||
ctx context.Context,
|
||||
options option.CCMServiceOptions,
|
||||
logger log.ContextLogger,
|
||||
) (map[string]credentialProvider, []Credential, error) {
|
||||
allCredentialMap := make(map[string]Credential)
|
||||
var allCredentials []Credential
|
||||
providers := make(map[string]credentialProvider)
|
||||
|
||||
// Pass 1: create default and external credentials
|
||||
for _, credentialOption := range options.Credentials {
|
||||
switch credentialOption.Type {
|
||||
case "default":
|
||||
credential, err := newDefaultCredential(ctx, credentialOption.Tag, credentialOption.DefaultOptions, logger)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
allCredentialMap[credentialOption.Tag] = credential
|
||||
allCredentials = append(allCredentials, credential)
|
||||
providers[credentialOption.Tag] = &singleCredentialProvider{credential: credential}
|
||||
case "external":
|
||||
credential, err := newExternalCredential(ctx, credentialOption.Tag, credentialOption.ExternalOptions, logger)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
allCredentialMap[credentialOption.Tag] = credential
|
||||
allCredentials = append(allCredentials, credential)
|
||||
providers[credentialOption.Tag] = &singleCredentialProvider{credential: credential}
|
||||
}
|
||||
}
|
||||
|
||||
// Pass 2: create balancer providers
|
||||
for _, credentialOption := range options.Credentials {
|
||||
if credentialOption.Type == "balancer" {
|
||||
subCredentials, err := resolveCredentialTags(credentialOption.BalancerOptions.Credentials, allCredentialMap, credentialOption.Tag)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
providers[credentialOption.Tag] = newBalancerProvider(subCredentials, credentialOption.BalancerOptions.Strategy, credentialOption.BalancerOptions.RebalanceThreshold, logger)
|
||||
}
|
||||
}
|
||||
|
||||
return providers, allCredentials, nil
|
||||
}
|
||||
|
||||
func resolveCredentialTags(tags []string, allCredentials map[string]Credential, parentTag string) ([]Credential, error) {
|
||||
credentials := make([]Credential, 0, len(tags))
|
||||
for _, tag := range tags {
|
||||
credential, exists := allCredentials[tag]
|
||||
if !exists {
|
||||
return nil, E.New("credential ", parentTag, " references unknown credential: ", tag)
|
||||
}
|
||||
credentials = append(credentials, credential)
|
||||
}
|
||||
if len(credentials) == 0 {
|
||||
return nil, E.New("credential ", parentTag, " has no sub-credentials")
|
||||
}
|
||||
return credentials, nil
|
||||
}
|
||||
|
||||
func validateCCMOptions(options option.CCMServiceOptions) error {
|
||||
tags := make(map[string]bool)
|
||||
credentialTypes := make(map[string]string)
|
||||
for _, credential := range options.Credentials {
|
||||
if tags[credential.Tag] {
|
||||
return E.New("duplicate credential tag: ", credential.Tag)
|
||||
}
|
||||
tags[credential.Tag] = true
|
||||
credentialTypes[credential.Tag] = credential.Type
|
||||
if credential.Type == "default" || credential.Type == "" {
|
||||
if credential.DefaultOptions.Reserve5h > 99 {
|
||||
return E.New("credential ", credential.Tag, ": reserve_5h must be at most 99")
|
||||
}
|
||||
if credential.DefaultOptions.ReserveWeekly > 99 {
|
||||
return E.New("credential ", credential.Tag, ": reserve_weekly must be at most 99")
|
||||
}
|
||||
if credential.DefaultOptions.Limit5h > 100 {
|
||||
return E.New("credential ", credential.Tag, ": limit_5h must be at most 100")
|
||||
}
|
||||
if credential.DefaultOptions.LimitWeekly > 100 {
|
||||
return E.New("credential ", credential.Tag, ": limit_weekly must be at most 100")
|
||||
}
|
||||
if credential.DefaultOptions.Reserve5h > 0 && credential.DefaultOptions.Limit5h > 0 {
|
||||
return E.New("credential ", credential.Tag, ": reserve_5h and limit_5h are mutually exclusive")
|
||||
}
|
||||
if credential.DefaultOptions.ReserveWeekly > 0 && credential.DefaultOptions.LimitWeekly > 0 {
|
||||
return E.New("credential ", credential.Tag, ": reserve_weekly and limit_weekly are mutually exclusive")
|
||||
}
|
||||
}
|
||||
if credential.Type == "external" {
|
||||
if credential.ExternalOptions.Token == "" {
|
||||
return E.New("credential ", credential.Tag, ": external credential requires token")
|
||||
}
|
||||
if credential.ExternalOptions.Reverse && credential.ExternalOptions.URL == "" {
|
||||
return E.New("credential ", credential.Tag, ": reverse external credential requires url")
|
||||
}
|
||||
}
|
||||
if credential.Type == "balancer" {
|
||||
switch credential.BalancerOptions.Strategy {
|
||||
case "", C.BalancerStrategyLeastUsed, C.BalancerStrategyRoundRobin, C.BalancerStrategyRandom, C.BalancerStrategyFallback:
|
||||
default:
|
||||
return E.New("credential ", credential.Tag, ": unknown balancer strategy: ", credential.BalancerOptions.Strategy)
|
||||
}
|
||||
if credential.BalancerOptions.RebalanceThreshold < 0 {
|
||||
return E.New("credential ", credential.Tag, ": rebalance_threshold must not be negative")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
singleCredential := len(options.Credentials) == 1
|
||||
for _, user := range options.Users {
|
||||
if user.Credential == "" && !singleCredential {
|
||||
return E.New("user ", user.Name, " must specify credential in multi-credential mode")
|
||||
}
|
||||
if user.Credential != "" && !tags[user.Credential] {
|
||||
return E.New("user ", user.Name, " references unknown credential: ", user.Credential)
|
||||
}
|
||||
if user.ExternalCredential != "" {
|
||||
if !tags[user.ExternalCredential] {
|
||||
return E.New("user ", user.Name, " references unknown external_credential: ", user.ExternalCredential)
|
||||
}
|
||||
if credentialTypes[user.ExternalCredential] != "external" {
|
||||
return E.New("user ", user.Name, ": external_credential must reference an external type credential")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func credentialForUser(
|
||||
userConfigMap map[string]*option.CCMUser,
|
||||
providers map[string]credentialProvider,
|
||||
username string,
|
||||
) (credentialProvider, error) {
|
||||
userConfig, exists := userConfigMap[username]
|
||||
if !exists {
|
||||
return nil, E.New("no credential mapping for user: ", username)
|
||||
}
|
||||
if userConfig.Credential == "" {
|
||||
for _, provider := range providers {
|
||||
return provider, nil
|
||||
}
|
||||
return nil, E.New("no credential available")
|
||||
}
|
||||
provider, exists := providers[userConfig.Credential]
|
||||
if !exists {
|
||||
return nil, E.New("unknown credential: ", userConfig.Credential)
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
147
service/ccm/credential_config_file.go
Normal file
147
service/ccm/credential_config_file.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// claudeCodeConfig represents the persisted config written by Claude Code.
|
||||
//
|
||||
// ref (@anthropic-ai/claude-code @2.1.81):
|
||||
//
|
||||
// ref: cli.js P8() (line 174997) — reads config
|
||||
// ref: cli.js c8() (line 174919) — writes config
|
||||
// ref: cli.js _D() (line 39158-39163) — config file path resolution
|
||||
type claudeCodeConfig struct {
|
||||
UserID string `json:"userID"` // ref: cli.js XL() (line 175325) — random 32-byte hex, generated once
|
||||
OAuthAccount *claudeOAuthAccount `json:"oauthAccount"` // ref: cli.js fP6() / storeOAuthAccountInfo — from /api/oauth/profile
|
||||
}
|
||||
|
||||
type claudeOAuthAccount struct {
|
||||
AccountUUID string `json:"accountUuid,omitempty"`
|
||||
EmailAddress string `json:"emailAddress,omitempty"`
|
||||
OrganizationUUID string `json:"organizationUuid,omitempty"`
|
||||
DisplayName *string `json:"displayName,omitempty"`
|
||||
HasExtraUsageEnabled *bool `json:"hasExtraUsageEnabled,omitempty"`
|
||||
BillingType *string `json:"billingType,omitempty"`
|
||||
AccountCreatedAt *string `json:"accountCreatedAt,omitempty"`
|
||||
SubscriptionCreatedAt *string `json:"subscriptionCreatedAt,omitempty"`
|
||||
}
|
||||
|
||||
// resolveClaudeConfigFile finds the Claude Code config file within the given directory.
|
||||
//
|
||||
// Config file path resolution mirrors cli.js _D() (line 39158-39163):
|
||||
// 1. claudeDirectory/.config.json — newer format, checked first
|
||||
// 2. claudeDirectory/.claude.json — used when CLAUDE_CONFIG_DIR is set
|
||||
// 3. filepath.Dir(claudeDirectory)/.claude.json — default ~/.claude case → ~/.claude.json
|
||||
//
|
||||
// Returns the first path that exists, or "" if none found.
|
||||
func resolveClaudeConfigFile(claudeDirectory string) string {
|
||||
candidates := []string{
|
||||
filepath.Join(claudeDirectory, ".config.json"),
|
||||
filepath.Join(claudeDirectory, claudeCodeLegacyConfigFileName()),
|
||||
filepath.Join(filepath.Dir(claudeDirectory), claudeCodeLegacyConfigFileName()),
|
||||
}
|
||||
for _, candidate := range candidates {
|
||||
_, err := os.Stat(candidate)
|
||||
if err == nil {
|
||||
return candidate
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func readClaudeCodeConfig(path string) (*claudeCodeConfig, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var config claudeCodeConfig
|
||||
err = json.Unmarshal(data, &config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
func resolveClaudeConfigWritePath(claudeDirectory string) string {
|
||||
if claudeDirectory == "" {
|
||||
return ""
|
||||
}
|
||||
existingPath := resolveClaudeConfigFile(claudeDirectory)
|
||||
if existingPath != "" {
|
||||
return existingPath
|
||||
}
|
||||
if os.Getenv("CLAUDE_CONFIG_DIR") != "" {
|
||||
return filepath.Join(claudeDirectory, claudeCodeLegacyConfigFileName())
|
||||
}
|
||||
defaultClaudeDirectory := filepath.Join(filepath.Dir(claudeDirectory), ".claude")
|
||||
if claudeDirectory != defaultClaudeDirectory {
|
||||
return filepath.Join(claudeDirectory, claudeCodeLegacyConfigFileName())
|
||||
}
|
||||
return filepath.Join(filepath.Dir(claudeDirectory), claudeCodeLegacyConfigFileName())
|
||||
}
|
||||
|
||||
func writeClaudeCodeOAuthAccount(path string, account *claudeOAuthAccount) error {
|
||||
if path == "" || account == nil {
|
||||
return nil
|
||||
}
|
||||
storage := jsonFileStorage{path: path}
|
||||
return writeStorageValue(storage, "oauthAccount", account)
|
||||
}
|
||||
|
||||
func claudeCodeLegacyConfigFileName() string {
|
||||
if os.Getenv("CLAUDE_CODE_CUSTOM_OAUTH_URL") != "" {
|
||||
return ".claude-custom-oauth.json"
|
||||
}
|
||||
return ".claude.json"
|
||||
}
|
||||
|
||||
func cloneClaudeOAuthAccount(account *claudeOAuthAccount) *claudeOAuthAccount {
|
||||
if account == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *account
|
||||
cloned.DisplayName = cloneStringPointer(account.DisplayName)
|
||||
cloned.HasExtraUsageEnabled = cloneBoolPointer(account.HasExtraUsageEnabled)
|
||||
cloned.BillingType = cloneStringPointer(account.BillingType)
|
||||
cloned.AccountCreatedAt = cloneStringPointer(account.AccountCreatedAt)
|
||||
cloned.SubscriptionCreatedAt = cloneStringPointer(account.SubscriptionCreatedAt)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func mergeClaudeOAuthAccount(base *claudeOAuthAccount, update *claudeOAuthAccount) *claudeOAuthAccount {
|
||||
if update == nil {
|
||||
return cloneClaudeOAuthAccount(base)
|
||||
}
|
||||
if base == nil {
|
||||
return cloneClaudeOAuthAccount(update)
|
||||
}
|
||||
merged := cloneClaudeOAuthAccount(base)
|
||||
if update.AccountUUID != "" {
|
||||
merged.AccountUUID = update.AccountUUID
|
||||
}
|
||||
if update.EmailAddress != "" {
|
||||
merged.EmailAddress = update.EmailAddress
|
||||
}
|
||||
if update.OrganizationUUID != "" {
|
||||
merged.OrganizationUUID = update.OrganizationUUID
|
||||
}
|
||||
if update.DisplayName != nil {
|
||||
merged.DisplayName = cloneStringPointer(update.DisplayName)
|
||||
}
|
||||
if update.HasExtraUsageEnabled != nil {
|
||||
merged.HasExtraUsageEnabled = cloneBoolPointer(update.HasExtraUsageEnabled)
|
||||
}
|
||||
if update.BillingType != nil {
|
||||
merged.BillingType = cloneStringPointer(update.BillingType)
|
||||
}
|
||||
if update.AccountCreatedAt != nil {
|
||||
merged.AccountCreatedAt = cloneStringPointer(update.AccountCreatedAt)
|
||||
}
|
||||
if update.SubscriptionCreatedAt != nil {
|
||||
merged.SubscriptionCreatedAt = cloneStringPointer(update.SubscriptionCreatedAt)
|
||||
}
|
||||
return merged
|
||||
}
|
||||
@@ -14,6 +14,11 @@ import (
|
||||
"github.com/keybase/go-keychain"
|
||||
)
|
||||
|
||||
type keychainStorage struct {
|
||||
service string
|
||||
account string
|
||||
}
|
||||
|
||||
func getKeychainServiceName() string {
|
||||
configDirectory := os.Getenv("CLAUDE_CONFIG_DIR")
|
||||
if configDirectory == "" {
|
||||
@@ -69,48 +74,97 @@ func platformReadCredentials(customPath string) (*oauthCredentials, error) {
|
||||
return readCredentialsFromFile(defaultPath)
|
||||
}
|
||||
|
||||
func platformWriteCredentials(oauthCredentials *oauthCredentials, customPath string) error {
|
||||
if customPath != "" {
|
||||
return writeCredentialsToFile(oauthCredentials, customPath)
|
||||
func platformCanWriteCredentials(customPath string) error {
|
||||
if customPath == "" {
|
||||
return nil
|
||||
}
|
||||
return checkCredentialFileWritable(customPath)
|
||||
}
|
||||
|
||||
userInfo, err := getRealUser()
|
||||
if err == nil {
|
||||
data, err := json.Marshal(map[string]any{"claudeAiOauth": oauthCredentials})
|
||||
if err == nil {
|
||||
serviceName := getKeychainServiceName()
|
||||
item := keychain.NewItem()
|
||||
item.SetSecClass(keychain.SecClassGenericPassword)
|
||||
item.SetService(serviceName)
|
||||
item.SetAccount(userInfo.Username)
|
||||
item.SetData(data)
|
||||
item.SetAccessible(keychain.AccessibleWhenUnlocked)
|
||||
|
||||
err = keychain.AddItem(item)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err == keychain.ErrorDuplicateItem {
|
||||
query := keychain.NewItem()
|
||||
query.SetSecClass(keychain.SecClassGenericPassword)
|
||||
query.SetService(serviceName)
|
||||
query.SetAccount(userInfo.Username)
|
||||
|
||||
updateItem := keychain.NewItem()
|
||||
updateItem.SetData(data)
|
||||
|
||||
updateErr := keychain.UpdateItem(query, updateItem)
|
||||
if updateErr == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
func platformWriteCredentials(credentials *oauthCredentials, customPath string) error {
|
||||
if customPath != "" {
|
||||
return writeCredentialsToFile(credentials, customPath)
|
||||
}
|
||||
|
||||
defaultPath, err := getDefaultCredentialsPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeCredentialsToFile(oauthCredentials, defaultPath)
|
||||
fileStorage := jsonFileStorage{path: defaultPath}
|
||||
|
||||
userInfo, err := getRealUser()
|
||||
if err != nil {
|
||||
return writeCredentialsToFile(credentials, defaultPath)
|
||||
}
|
||||
return persistStorageValue(keychainStorage{
|
||||
service: getKeychainServiceName(),
|
||||
account: userInfo.Username,
|
||||
}, fileStorage, "claudeAiOauth", credentials)
|
||||
}
|
||||
|
||||
func (s keychainStorage) readContainer() (map[string]json.RawMessage, bool, error) {
|
||||
query := keychain.NewItem()
|
||||
query.SetSecClass(keychain.SecClassGenericPassword)
|
||||
query.SetService(s.service)
|
||||
query.SetAccount(s.account)
|
||||
query.SetMatchLimit(keychain.MatchLimitOne)
|
||||
query.SetReturnData(true)
|
||||
|
||||
results, err := keychain.QueryItem(query)
|
||||
if err != nil {
|
||||
if err == keychain.ErrorItemNotFound {
|
||||
return make(map[string]json.RawMessage), false, nil
|
||||
}
|
||||
return nil, false, E.Cause(err, "query keychain")
|
||||
}
|
||||
if len(results) != 1 {
|
||||
return make(map[string]json.RawMessage), false, nil
|
||||
}
|
||||
|
||||
container := make(map[string]json.RawMessage)
|
||||
if len(results[0].Data) == 0 {
|
||||
return container, true, nil
|
||||
}
|
||||
if err := json.Unmarshal(results[0].Data, &container); err != nil {
|
||||
return nil, true, err
|
||||
}
|
||||
return container, true, nil
|
||||
}
|
||||
|
||||
func (s keychainStorage) writeContainer(container map[string]json.RawMessage) error {
|
||||
data, err := json.Marshal(container)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
item := keychain.NewItem()
|
||||
item.SetSecClass(keychain.SecClassGenericPassword)
|
||||
item.SetService(s.service)
|
||||
item.SetAccount(s.account)
|
||||
item.SetData(data)
|
||||
item.SetAccessible(keychain.AccessibleWhenUnlocked)
|
||||
err = keychain.AddItem(item)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if err != keychain.ErrorDuplicateItem {
|
||||
return err
|
||||
}
|
||||
|
||||
updateQuery := keychain.NewItem()
|
||||
updateQuery.SetSecClass(keychain.SecClassGenericPassword)
|
||||
updateQuery.SetService(s.service)
|
||||
updateQuery.SetAccount(s.account)
|
||||
|
||||
updateItem := keychain.NewItem()
|
||||
updateItem.SetData(data)
|
||||
return keychain.UpdateItem(updateQuery, updateItem)
|
||||
}
|
||||
|
||||
func (s keychainStorage) delete() error {
|
||||
err := keychain.DeleteGenericPasswordItem(s.service, s.account)
|
||||
if err != nil && err != keychain.ErrorItemNotFound {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
1307
service/ccm/credential_default.go
Normal file
1307
service/ccm/credential_default.go
Normal file
File diff suppressed because it is too large
Load Diff
245
service/ccm/credential_default_test.go
Normal file
245
service/ccm/credential_default_test.go
Normal file
@@ -0,0 +1,245 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGetAccessTokenMarksUnavailableWhenLockFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
directory := t.TempDir()
|
||||
credentialPath := filepath.Join(directory, ".credentials.json")
|
||||
credentials := &oauthCredentials{
|
||||
AccessToken: "old-token",
|
||||
RefreshToken: "refresh-token",
|
||||
ExpiresAt: time.Now().Add(time.Hour).UnixMilli(),
|
||||
Scopes: []string{"user:profile", "user:inference"},
|
||||
SubscriptionType: optionalStringPointer("max"),
|
||||
RateLimitTier: optionalStringPointer("default_claude_max_20x"),
|
||||
}
|
||||
writeTestCredentials(t, credentialPath, credentials)
|
||||
|
||||
credential := newTestDefaultCredential(t, credentialPath, roundTripFunc(func(request *http.Request) (*http.Response, error) {
|
||||
t.Fatal("refresh should not be attempted when lock acquisition fails")
|
||||
return nil, nil
|
||||
}))
|
||||
if err := credential.reloadCredentials(true); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expiredCredentials := cloneCredentials(credentials)
|
||||
expiredCredentials.ExpiresAt = time.Now().Add(-time.Minute).UnixMilli()
|
||||
writeTestCredentials(t, credentialPath, expiredCredentials)
|
||||
credential.absorbCredentials(expiredCredentials)
|
||||
|
||||
credential.acquireLock = func(string) (func(), error) {
|
||||
return nil, errors.New("permission denied")
|
||||
}
|
||||
|
||||
_, err := credential.getAccessToken()
|
||||
if err == nil {
|
||||
t.Fatal("expected error when lock acquisition fails, got nil")
|
||||
}
|
||||
if credential.isUsable() {
|
||||
t.Fatal("credential should be marked unavailable after lock failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccessTokenMarksUnavailableOnUnwritableFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
directory := t.TempDir()
|
||||
credentialPath := filepath.Join(directory, ".credentials.json")
|
||||
credentials := &oauthCredentials{
|
||||
AccessToken: "old-token",
|
||||
RefreshToken: "refresh-token",
|
||||
ExpiresAt: time.Now().Add(time.Hour).UnixMilli(),
|
||||
Scopes: []string{"user:profile", "user:inference"},
|
||||
}
|
||||
writeTestCredentials(t, credentialPath, credentials)
|
||||
|
||||
credential := newTestDefaultCredential(t, credentialPath, roundTripFunc(func(request *http.Request) (*http.Response, error) {
|
||||
t.Fatal("refresh should not be attempted when file is not writable")
|
||||
return nil, nil
|
||||
}))
|
||||
if err := credential.reloadCredentials(true); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expiredCredentials := cloneCredentials(credentials)
|
||||
expiredCredentials.ExpiresAt = time.Now().Add(-time.Minute).UnixMilli()
|
||||
writeTestCredentials(t, credentialPath, expiredCredentials)
|
||||
credential.absorbCredentials(expiredCredentials)
|
||||
|
||||
os.Chmod(credentialPath, 0o444)
|
||||
t.Cleanup(func() { os.Chmod(credentialPath, 0o644) })
|
||||
|
||||
_, err := credential.getAccessToken()
|
||||
if err == nil {
|
||||
t.Fatal("expected error when credential file is not writable, got nil")
|
||||
}
|
||||
if credential.isUsable() {
|
||||
t.Fatal("credential should be marked unavailable after write permission failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccessTokenAbsorbsRefreshDoneByAnotherProcess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
directory := t.TempDir()
|
||||
credentialPath := filepath.Join(directory, ".credentials.json")
|
||||
oldCredentials := &oauthCredentials{
|
||||
AccessToken: "old-token",
|
||||
RefreshToken: "refresh-token",
|
||||
ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(),
|
||||
Scopes: []string{"user:profile", "user:inference"},
|
||||
SubscriptionType: optionalStringPointer("max"),
|
||||
RateLimitTier: optionalStringPointer("default_claude_max_20x"),
|
||||
}
|
||||
writeTestCredentials(t, credentialPath, oldCredentials)
|
||||
|
||||
newCredentials := cloneCredentials(oldCredentials)
|
||||
newCredentials.AccessToken = "new-token"
|
||||
newCredentials.ExpiresAt = time.Now().Add(time.Hour).UnixMilli()
|
||||
transport := roundTripFunc(func(request *http.Request) (*http.Response, error) {
|
||||
if request.URL.Path == "/v1/oauth/token" {
|
||||
writeTestCredentials(t, credentialPath, newCredentials)
|
||||
return newJSONResponse(http.StatusInternalServerError, `{"error":"boom"}`), nil
|
||||
}
|
||||
t.Fatalf("unexpected path %s", request.URL.Path)
|
||||
return nil, nil
|
||||
})
|
||||
|
||||
credential := newTestDefaultCredential(t, credentialPath, transport)
|
||||
if err := credential.reloadCredentials(true); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
token, err := credential.getAccessToken()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if token != "new-token" {
|
||||
t.Fatalf("expected refreshed token from disk, got %q", token)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomCredentialPathDoesNotEnableClaudeConfigSync(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
directory := t.TempDir()
|
||||
credentialPath := filepath.Join(directory, ".credentials.json")
|
||||
writeTestCredentials(t, credentialPath, &oauthCredentials{
|
||||
AccessToken: "token",
|
||||
ExpiresAt: time.Now().Add(time.Hour).UnixMilli(),
|
||||
Scopes: []string{"user:profile"},
|
||||
})
|
||||
|
||||
credential := newTestDefaultCredential(t, credentialPath, roundTripFunc(func(request *http.Request) (*http.Response, error) {
|
||||
t.Fatalf("unexpected request to %s", request.URL.Path)
|
||||
return nil, nil
|
||||
}))
|
||||
if err := credential.reloadCredentials(true); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
token, err := credential.getAccessToken()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if token != "token" {
|
||||
t.Fatalf("expected token, got %q", token)
|
||||
}
|
||||
if credential.shouldUseClaudeConfig() {
|
||||
t.Fatal("custom credential path should not enable Claude config sync")
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(directory, ".claude.json")); !os.IsNotExist(err) {
|
||||
t.Fatalf("did not expect config file to be created, stat err=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultCredentialHydratesProfileAndWritesConfig(t *testing.T) {
|
||||
configDir := t.TempDir()
|
||||
credentialPath := filepath.Join(configDir, ".credentials.json")
|
||||
|
||||
writeTestCredentials(t, credentialPath, &oauthCredentials{
|
||||
AccessToken: "old-token",
|
||||
RefreshToken: "refresh-token",
|
||||
ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(),
|
||||
Scopes: []string{"user:profile", "user:inference"},
|
||||
})
|
||||
|
||||
transport := roundTripFunc(func(request *http.Request) (*http.Response, error) {
|
||||
switch request.URL.Path {
|
||||
case "/v1/oauth/token":
|
||||
return newJSONResponse(http.StatusOK, `{
|
||||
"access_token":"new-token",
|
||||
"refresh_token":"new-refresh",
|
||||
"expires_in":3600,
|
||||
"account":{"uuid":"account","email_address":"user@example.com"},
|
||||
"organization":{"uuid":"org"}
|
||||
}`), nil
|
||||
case "/api/oauth/profile":
|
||||
return newJSONResponse(http.StatusOK, `{
|
||||
"account":{
|
||||
"uuid":"account",
|
||||
"email":"user@example.com",
|
||||
"display_name":"User",
|
||||
"created_at":"2024-01-01T00:00:00Z"
|
||||
},
|
||||
"organization":{
|
||||
"uuid":"org",
|
||||
"organization_type":"claude_max",
|
||||
"rate_limit_tier":"default_claude_max_20x",
|
||||
"has_extra_usage_enabled":true,
|
||||
"billing_type":"individual",
|
||||
"subscription_created_at":"2024-01-02T00:00:00Z"
|
||||
}
|
||||
}`), nil
|
||||
default:
|
||||
t.Fatalf("unexpected path %s", request.URL.Path)
|
||||
return nil, nil
|
||||
}
|
||||
})
|
||||
|
||||
credential := newTestDefaultCredential(t, credentialPath, transport)
|
||||
credential.syncClaudeConfig = true
|
||||
credential.claudeDirectory = configDir
|
||||
credential.claudeConfigPath = resolveClaudeConfigWritePath(configDir)
|
||||
if err := credential.reloadCredentials(true); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
token, err := credential.getAccessToken()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if token != "new-token" {
|
||||
t.Fatalf("expected refreshed token, got %q", token)
|
||||
}
|
||||
|
||||
updatedCredentials := readTestCredentials(t, credentialPath)
|
||||
if updatedCredentials.SubscriptionType == nil || *updatedCredentials.SubscriptionType != "max" {
|
||||
t.Fatalf("expected subscription type to be persisted, got %#v", updatedCredentials.SubscriptionType)
|
||||
}
|
||||
if updatedCredentials.RateLimitTier == nil || *updatedCredentials.RateLimitTier != "default_claude_max_20x" {
|
||||
t.Fatalf("expected rate limit tier to be persisted, got %#v", updatedCredentials.RateLimitTier)
|
||||
}
|
||||
|
||||
configPath := tempConfigPath(t, configDir)
|
||||
config, err := readClaudeCodeConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if config.OAuthAccount == nil || config.OAuthAccount.AccountUUID != "account" || config.OAuthAccount.EmailAddress != "user@example.com" {
|
||||
t.Fatalf("unexpected oauth account: %#v", config.OAuthAccount)
|
||||
}
|
||||
if config.OAuthAccount.BillingType == nil || *config.OAuthAccount.BillingType != "individual" {
|
||||
t.Fatalf("expected billing type to be hydrated, got %#v", config.OAuthAccount.BillingType)
|
||||
}
|
||||
}
|
||||
988
service/ccm/credential_external.go
Normal file
988
service/ccm/credential_external.go
Normal file
@@ -0,0 +1,988 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
stdTLS "crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
"github.com/sagernet/sing/common/observable"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
)
|
||||
|
||||
const reverseProxyBaseURL = "http://reverse-proxy"
|
||||
|
||||
type externalCredential struct {
|
||||
tag string
|
||||
baseURL string
|
||||
token string
|
||||
forwardHTTPClient *http.Client
|
||||
state credentialState
|
||||
stateAccess sync.RWMutex
|
||||
pollAccess sync.Mutex
|
||||
usageTracker *AggregatedUsage
|
||||
logger log.ContextLogger
|
||||
|
||||
statusSubscriber *observable.Subscriber[struct{}]
|
||||
|
||||
interrupted bool
|
||||
requestContext context.Context
|
||||
cancelRequests context.CancelFunc
|
||||
requestAccess sync.Mutex
|
||||
|
||||
// Reverse proxy fields
|
||||
reverse bool
|
||||
reverseHTTPClient *http.Client
|
||||
reverseSession *yamux.Session
|
||||
reverseAccess sync.RWMutex
|
||||
closed bool
|
||||
reverseContext context.Context
|
||||
reverseCancel context.CancelFunc
|
||||
connectorDialer N.Dialer
|
||||
connectorDestination M.Socksaddr
|
||||
connectorRequestPath string
|
||||
connectorURL *url.URL
|
||||
connectorTLS *stdTLS.Config
|
||||
reverseService http.Handler
|
||||
}
|
||||
|
||||
type statusStreamResult struct {
|
||||
duration time.Duration
|
||||
frames int
|
||||
}
|
||||
|
||||
func externalCredentialURLPort(parsedURL *url.URL) uint16 {
|
||||
portString := parsedURL.Port()
|
||||
if portString != "" {
|
||||
port, err := strconv.ParseUint(portString, 10, 16)
|
||||
if err == nil {
|
||||
return uint16(port)
|
||||
}
|
||||
}
|
||||
if parsedURL.Scheme == "https" {
|
||||
return 443
|
||||
}
|
||||
return 80
|
||||
}
|
||||
|
||||
func externalCredentialServerPort(parsedURL *url.URL, configuredPort uint16) uint16 {
|
||||
if configuredPort != 0 {
|
||||
return configuredPort
|
||||
}
|
||||
return externalCredentialURLPort(parsedURL)
|
||||
}
|
||||
|
||||
func externalCredentialBaseURL(parsedURL *url.URL) string {
|
||||
baseURL := parsedURL.Scheme + "://" + parsedURL.Host
|
||||
if parsedURL.Path != "" && parsedURL.Path != "/" {
|
||||
baseURL += parsedURL.Path
|
||||
}
|
||||
if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' {
|
||||
baseURL = baseURL[:len(baseURL)-1]
|
||||
}
|
||||
return baseURL
|
||||
}
|
||||
|
||||
func externalCredentialReversePath(parsedURL *url.URL, endpointPath string) string {
|
||||
pathPrefix := parsedURL.EscapedPath()
|
||||
if pathPrefix == "/" {
|
||||
pathPrefix = ""
|
||||
} else {
|
||||
pathPrefix = strings.TrimSuffix(pathPrefix, "/")
|
||||
}
|
||||
return pathPrefix + endpointPath
|
||||
}
|
||||
|
||||
func newExternalCredential(ctx context.Context, tag string, options option.CCMExternalCredentialOptions, logger log.ContextLogger) (*externalCredential, error) {
|
||||
requestContext, cancelRequests := context.WithCancel(context.Background())
|
||||
reverseContext, reverseCancel := context.WithCancel(context.Background())
|
||||
|
||||
credential := &externalCredential{
|
||||
tag: tag,
|
||||
token: options.Token,
|
||||
logger: logger,
|
||||
requestContext: requestContext,
|
||||
cancelRequests: cancelRequests,
|
||||
reverse: options.Reverse,
|
||||
reverseContext: reverseContext,
|
||||
reverseCancel: reverseCancel,
|
||||
}
|
||||
|
||||
if options.URL == "" {
|
||||
// Receiver mode: no URL, wait for reverse connection
|
||||
credential.baseURL = reverseProxyBaseURL
|
||||
credential.forwardHTTPClient = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
ForceAttemptHTTP2: false,
|
||||
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return credential.openReverseConnection(ctx)
|
||||
},
|
||||
},
|
||||
}
|
||||
} else {
|
||||
// Normal or connector mode: has URL
|
||||
parsedURL, err := url.Parse(options.URL)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "parse url for credential ", tag)
|
||||
}
|
||||
|
||||
credentialDialer, err := dialer.NewWithOptions(dialer.Options{
|
||||
Context: ctx,
|
||||
Options: option.DialerOptions{
|
||||
Detour: options.Detour,
|
||||
},
|
||||
RemoteIsDomain: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "create dialer for credential ", tag)
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
ForceAttemptHTTP2: true,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
if options.Server != "" {
|
||||
destination := M.ParseSocksaddrHostPort(options.Server, externalCredentialServerPort(parsedURL, options.ServerPort))
|
||||
return credentialDialer.DialContext(ctx, network, destination)
|
||||
}
|
||||
return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
},
|
||||
}
|
||||
|
||||
if parsedURL.Scheme == "https" {
|
||||
transport.TLSClientConfig = &stdTLS.Config{
|
||||
ServerName: parsedURL.Hostname(),
|
||||
RootCAs: adapter.RootPoolFromContext(ctx),
|
||||
Time: ntp.TimeFuncFromContext(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
credential.baseURL = externalCredentialBaseURL(parsedURL)
|
||||
|
||||
if options.Reverse {
|
||||
// Connector mode: we dial out to serve, not to proxy
|
||||
credential.connectorDialer = credentialDialer
|
||||
if options.Server != "" {
|
||||
credential.connectorDestination = M.ParseSocksaddrHostPort(options.Server, externalCredentialServerPort(parsedURL, options.ServerPort))
|
||||
} else {
|
||||
credential.connectorDestination = M.ParseSocksaddrHostPort(parsedURL.Hostname(), externalCredentialURLPort(parsedURL))
|
||||
}
|
||||
credential.connectorRequestPath = externalCredentialReversePath(parsedURL, "/ccm/v1/reverse")
|
||||
credential.connectorURL = parsedURL
|
||||
if parsedURL.Scheme == "https" {
|
||||
credential.connectorTLS = &stdTLS.Config{
|
||||
ServerName: parsedURL.Hostname(),
|
||||
RootCAs: adapter.RootPoolFromContext(ctx),
|
||||
Time: ntp.TimeFuncFromContext(ctx),
|
||||
}
|
||||
}
|
||||
credential.forwardHTTPClient = &http.Client{Transport: transport}
|
||||
} else {
|
||||
// Normal mode: standard HTTP client for proxying
|
||||
credential.forwardHTTPClient = &http.Client{Transport: transport}
|
||||
credential.reverseHTTPClient = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
ForceAttemptHTTP2: false,
|
||||
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return credential.openReverseConnection(ctx)
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if options.UsagesPath != "" {
|
||||
credential.usageTracker = &AggregatedUsage{
|
||||
LastUpdated: time.Now(),
|
||||
Combinations: make([]CostCombination, 0),
|
||||
filePath: options.UsagesPath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
return credential, nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) setStatusSubscriber(subscriber *observable.Subscriber[struct{}]) {
|
||||
c.statusSubscriber = subscriber
|
||||
}
|
||||
|
||||
func (c *externalCredential) emitStatusUpdate() {
|
||||
if c.statusSubscriber != nil {
|
||||
c.statusSubscriber.Emit(struct{}{})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) start() error {
|
||||
if c.usageTracker != nil {
|
||||
err := c.usageTracker.Load()
|
||||
if err != nil {
|
||||
c.logger.Warn("load usage statistics for ", c.tag, ": ", err)
|
||||
}
|
||||
}
|
||||
if c.reverse && c.connectorURL != nil {
|
||||
go c.connectorLoop()
|
||||
} else {
|
||||
go c.statusStreamLoop()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) tagName() string {
|
||||
return c.tag
|
||||
}
|
||||
|
||||
func (c *externalCredential) isExternal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *externalCredential) isAvailable() bool {
|
||||
return c.unavailableError() == nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) isUsable() bool {
|
||||
if !c.isAvailable() {
|
||||
return false
|
||||
}
|
||||
c.stateAccess.RLock()
|
||||
if c.state.consecutivePollFailures > 0 {
|
||||
c.stateAccess.RUnlock()
|
||||
return false
|
||||
}
|
||||
if !c.state.upstreamRejectedUntil.IsZero() && time.Now().Before(c.state.upstreamRejectedUntil) {
|
||||
c.stateAccess.RUnlock()
|
||||
return false
|
||||
}
|
||||
if c.state.hardRateLimited {
|
||||
if time.Now().Before(c.state.rateLimitResetAt) {
|
||||
c.stateAccess.RUnlock()
|
||||
return false
|
||||
}
|
||||
c.stateAccess.RUnlock()
|
||||
c.stateAccess.Lock()
|
||||
if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) {
|
||||
c.state.hardRateLimited = false
|
||||
}
|
||||
// No reserve for external: only 100% is unusable
|
||||
usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100
|
||||
c.stateAccess.Unlock()
|
||||
return usable
|
||||
}
|
||||
usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100
|
||||
c.stateAccess.RUnlock()
|
||||
return usable
|
||||
}
|
||||
|
||||
func (c *externalCredential) fiveHourUtilization() float64 {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.fiveHourUtilization
|
||||
}
|
||||
|
||||
func (c *externalCredential) weeklyUtilization() float64 {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.weeklyUtilization
|
||||
}
|
||||
|
||||
func (c *externalCredential) fiveHourCap() float64 {
|
||||
return 100
|
||||
}
|
||||
|
||||
func (c *externalCredential) weeklyCap() float64 {
|
||||
return 100
|
||||
}
|
||||
|
||||
func (c *externalCredential) planWeight() float64 {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
if c.state.remotePlanWeight > 0 {
|
||||
return c.state.remotePlanWeight
|
||||
}
|
||||
return 10
|
||||
}
|
||||
|
||||
func (c *externalCredential) fiveHourResetTime() time.Time {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.fiveHourReset
|
||||
}
|
||||
|
||||
func (c *externalCredential) weeklyResetTime() time.Time {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.weeklyReset
|
||||
}
|
||||
|
||||
func (c *externalCredential) markRateLimited(resetAt time.Time) {
|
||||
c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt)))
|
||||
c.stateAccess.Lock()
|
||||
c.state.hardRateLimited = true
|
||||
c.state.rateLimitResetAt = resetAt
|
||||
c.state.setAvailability(availabilityStateRateLimited, availabilityReasonHardRateLimit, resetAt)
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
|
||||
func (c *externalCredential) markUpstreamRejected() {
|
||||
c.logger.Warn("upstream rejected credential ", c.tag, ", marking unavailable for ", log.FormatDuration(defaultPollInterval))
|
||||
c.stateAccess.Lock()
|
||||
c.state.upstreamRejectedUntil = time.Now().Add(defaultPollInterval)
|
||||
c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReasonUpstreamRejected, c.state.upstreamRejectedUntil)
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
|
||||
func (c *externalCredential) earliestReset() time.Time {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
if c.state.hardRateLimited {
|
||||
return c.state.rateLimitResetAt
|
||||
}
|
||||
earliest := c.state.fiveHourReset
|
||||
if !c.state.weeklyReset.IsZero() && (earliest.IsZero() || c.state.weeklyReset.Before(earliest)) {
|
||||
earliest = c.state.weeklyReset
|
||||
}
|
||||
return earliest
|
||||
}
|
||||
|
||||
func (c *externalCredential) unavailableError() error {
|
||||
if c.baseURL == reverseProxyBaseURL {
|
||||
session := c.getReverseSession()
|
||||
if session == nil || session.IsClosed() {
|
||||
return E.New("credential ", c.tag, " is unavailable: reverse connection not established")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) getAccessToken() (string, error) {
|
||||
return c.token, nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, _ http.Header) (*http.Request, error) {
|
||||
baseURL := c.baseURL
|
||||
if c.reverseHTTPClient != nil {
|
||||
session := c.getReverseSession()
|
||||
if session != nil && !session.IsClosed() {
|
||||
baseURL = reverseProxyBaseURL
|
||||
}
|
||||
}
|
||||
proxyURL := baseURL + original.URL.RequestURI()
|
||||
var body io.Reader
|
||||
if bodyBytes != nil {
|
||||
body = bytes.NewReader(bodyBytes)
|
||||
} else {
|
||||
body = original.Body
|
||||
}
|
||||
proxyRequest, err := http.NewRequestWithContext(ctx, original.Method, proxyURL, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for key, values := range original.Header {
|
||||
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && !isAPIKeyHeader(key) && key != "Authorization" {
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
}
|
||||
|
||||
proxyRequest.Header.Set("Authorization", "Bearer "+c.token)
|
||||
|
||||
return proxyRequest, nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) openReverseConnection(ctx context.Context) (net.Conn, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
session := c.getReverseSession()
|
||||
if session == nil || session.IsClosed() {
|
||||
return nil, E.New("reverse connection not established for ", c.tag)
|
||||
}
|
||||
conn, err := session.Open()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
conn.Close()
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) updateStateFromHeaders(headers http.Header) {
|
||||
c.stateAccess.Lock()
|
||||
isFirstUpdate := c.state.lastUpdated.IsZero()
|
||||
oldFiveHour := c.state.fiveHourUtilization
|
||||
oldWeekly := c.state.weeklyUtilization
|
||||
oldPlanWeight := c.state.remotePlanWeight
|
||||
oldFiveHourReset := c.state.fiveHourReset
|
||||
oldWeeklyReset := c.state.weeklyReset
|
||||
hadData := false
|
||||
|
||||
if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset"); exists {
|
||||
hadData = true
|
||||
c.state.fiveHourReset = value
|
||||
}
|
||||
if utilization := headers.Get("anthropic-ratelimit-unified-5h-utilization"); utilization != "" {
|
||||
value, err := strconv.ParseFloat(utilization, 64)
|
||||
if err == nil {
|
||||
hadData = true
|
||||
c.state.fiveHourUtilization = value * 100
|
||||
}
|
||||
}
|
||||
|
||||
if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset"); exists {
|
||||
hadData = true
|
||||
c.state.weeklyReset = value
|
||||
}
|
||||
if utilization := headers.Get("anthropic-ratelimit-unified-7d-utilization"); utilization != "" {
|
||||
value, err := strconv.ParseFloat(utilization, 64)
|
||||
if err == nil {
|
||||
hadData = true
|
||||
c.state.weeklyUtilization = value * 100
|
||||
}
|
||||
}
|
||||
if planWeight := headers.Get("X-CCM-Plan-Weight"); planWeight != "" {
|
||||
value, err := strconv.ParseFloat(planWeight, 64)
|
||||
if err == nil && value > 0 {
|
||||
c.state.remotePlanWeight = value
|
||||
}
|
||||
}
|
||||
if hadData {
|
||||
c.state.consecutivePollFailures = 0
|
||||
c.state.upstreamRejectedUntil = time.Time{}
|
||||
c.state.lastUpdated = time.Now()
|
||||
c.state.noteSnapshotData()
|
||||
}
|
||||
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
|
||||
resetSuffix := ""
|
||||
if !c.state.weeklyReset.IsZero() {
|
||||
resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset))
|
||||
}
|
||||
c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
|
||||
}
|
||||
utilizationChanged := c.state.fiveHourUtilization != oldFiveHour || c.state.weeklyUtilization != oldWeekly
|
||||
planWeightChanged := c.state.remotePlanWeight != oldPlanWeight
|
||||
resetChanged := c.state.fiveHourReset != oldFiveHourReset || c.state.weeklyReset != oldWeeklyReset
|
||||
shouldEmit := (hadData && (utilizationChanged || resetChanged)) || planWeightChanged
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
if shouldEmit {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) checkTransitionLocked() bool {
|
||||
upstreamRejected := !c.state.upstreamRejectedUntil.IsZero() && time.Now().Before(c.state.upstreamRejectedUntil)
|
||||
unusable := c.state.hardRateLimited || c.state.fiveHourUtilization >= 100 || c.state.weeklyUtilization >= 100 || c.state.consecutivePollFailures > 0 || upstreamRejected
|
||||
if unusable && !c.interrupted {
|
||||
c.interrupted = true
|
||||
return true
|
||||
}
|
||||
if !unusable && c.interrupted {
|
||||
c.interrupted = false
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *externalCredential) wrapRequestContext(parent context.Context) *credentialRequestContext {
|
||||
c.requestAccess.Lock()
|
||||
credentialContext := c.requestContext
|
||||
c.requestAccess.Unlock()
|
||||
derived, cancel := context.WithCancel(parent)
|
||||
stop := context.AfterFunc(credentialContext, func() {
|
||||
cancel()
|
||||
})
|
||||
return &credentialRequestContext{
|
||||
Context: derived,
|
||||
releaseFuncs: []func() bool{stop},
|
||||
cancelFunc: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) interruptConnections() {
|
||||
c.logger.Warn("interrupting connections for ", c.tag)
|
||||
c.requestAccess.Lock()
|
||||
c.cancelRequests()
|
||||
c.requestContext, c.cancelRequests = context.WithCancel(context.Background())
|
||||
c.requestAccess.Unlock()
|
||||
}
|
||||
|
||||
func (c *externalCredential) doPollUsageRequest(ctx context.Context) (*http.Response, error) {
|
||||
buildRequest := func(baseURL string) func() (*http.Request, error) {
|
||||
return func() (*http.Request, error) {
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/ccm/v1/status", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header.Set("Authorization", "Bearer "+c.token)
|
||||
return request, nil
|
||||
}
|
||||
}
|
||||
// Try reverse transport first (single attempt, no retry)
|
||||
if c.reverseHTTPClient != nil {
|
||||
session := c.getReverseSession()
|
||||
if session != nil && !session.IsClosed() {
|
||||
request, err := buildRequest(reverseProxyBaseURL)()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reverseClient := &http.Client{
|
||||
Transport: c.reverseHTTPClient.Transport,
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
response, err := reverseClient.Do(request)
|
||||
if err == nil {
|
||||
return response, nil
|
||||
}
|
||||
// Reverse failed, fall through to forward if available
|
||||
}
|
||||
}
|
||||
// Forward transport with retries
|
||||
if c.forwardHTTPClient != nil {
|
||||
forwardClient := &http.Client{
|
||||
Transport: c.forwardHTTPClient.Transport,
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
return doHTTPWithRetry(ctx, forwardClient, buildRequest(c.baseURL))
|
||||
}
|
||||
return nil, E.New("no transport available")
|
||||
}
|
||||
|
||||
func (c *externalCredential) pollUsage() {
|
||||
if !c.pollAccess.TryLock() {
|
||||
return
|
||||
}
|
||||
defer c.pollAccess.Unlock()
|
||||
defer c.markUsagePollAttempted()
|
||||
|
||||
ctx := c.getReverseContext()
|
||||
response, err := c.doPollUsageRequest(ctx)
|
||||
if err != nil {
|
||||
c.logger.Error("poll usage for ", c.tag, ": ", err)
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
c.logger.Error("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
c.logger.Error("poll usage for ", c.tag, ": read body: ", err)
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
var rawFields map[string]json.RawMessage
|
||||
err = json.Unmarshal(body, &rawFields)
|
||||
if err != nil {
|
||||
c.logger.Error("poll usage for ", c.tag, ": decode: ", err)
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
if rawFields["five_hour_utilization"] == nil || rawFields["five_hour_reset"] == nil ||
|
||||
rawFields["weekly_utilization"] == nil || rawFields["weekly_reset"] == nil ||
|
||||
rawFields["plan_weight"] == nil {
|
||||
c.logger.Error("poll usage for ", c.tag, ": invalid response")
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
var statusResponse statusPayload
|
||||
err = json.Unmarshal(body, &statusResponse)
|
||||
if err != nil {
|
||||
c.logger.Error("poll usage for ", c.tag, ": decode: ", err)
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
|
||||
c.stateAccess.Lock()
|
||||
isFirstUpdate := c.state.lastUpdated.IsZero()
|
||||
oldFiveHour := c.state.fiveHourUtilization
|
||||
oldWeekly := c.state.weeklyUtilization
|
||||
c.state.consecutivePollFailures = 0
|
||||
c.state.upstreamRejectedUntil = time.Time{}
|
||||
c.state.fiveHourUtilization = statusResponse.FiveHourUtilization
|
||||
c.state.weeklyUtilization = statusResponse.WeeklyUtilization
|
||||
if statusResponse.PlanWeight > 0 {
|
||||
c.state.remotePlanWeight = statusResponse.PlanWeight
|
||||
}
|
||||
if statusResponse.FiveHourReset > 0 {
|
||||
c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0)
|
||||
}
|
||||
if statusResponse.WeeklyReset > 0 {
|
||||
c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0)
|
||||
}
|
||||
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
|
||||
c.state.hardRateLimited = false
|
||||
}
|
||||
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
|
||||
resetSuffix := ""
|
||||
if !c.state.weeklyReset.IsZero() {
|
||||
resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset))
|
||||
}
|
||||
c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
|
||||
}
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
|
||||
func (c *externalCredential) statusStreamLoop() {
|
||||
var consecutiveFailures int
|
||||
ctx := c.getReverseContext()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
result, err := c.connectStatusStream(ctx)
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if !shouldRetryStatusStreamError(err) {
|
||||
c.logger.Warn("status stream for ", c.tag, " disconnected: ", err, ", not retrying")
|
||||
return
|
||||
}
|
||||
var backoff time.Duration
|
||||
consecutiveFailures, backoff = c.nextStatusStreamBackoff(result, consecutiveFailures)
|
||||
c.logger.Debug("status stream for ", c.tag, " disconnected: ", err, ", reconnecting in ", backoff)
|
||||
timer := time.NewTimer(backoff)
|
||||
select {
|
||||
case <-timer.C:
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) connectStatusStream(ctx context.Context) (statusStreamResult, error) {
|
||||
startTime := time.Now()
|
||||
result := statusStreamResult{}
|
||||
response, err := c.doStreamStatusRequest(ctx)
|
||||
if err != nil {
|
||||
result.duration = time.Since(startTime)
|
||||
return result, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
result.duration = time.Since(startTime)
|
||||
return result, E.New("status ", response.StatusCode, " ", string(body))
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(response.Body)
|
||||
for {
|
||||
var rawMessage json.RawMessage
|
||||
err = decoder.Decode(&rawMessage)
|
||||
if err != nil {
|
||||
result.duration = time.Since(startTime)
|
||||
return result, err
|
||||
}
|
||||
var rawFields map[string]json.RawMessage
|
||||
err = json.Unmarshal(rawMessage, &rawFields)
|
||||
if err != nil {
|
||||
result.duration = time.Since(startTime)
|
||||
return result, E.Cause(err, "decode status frame")
|
||||
}
|
||||
if rawFields["five_hour_utilization"] == nil || rawFields["five_hour_reset"] == nil ||
|
||||
rawFields["weekly_utilization"] == nil || rawFields["weekly_reset"] == nil ||
|
||||
rawFields["plan_weight"] == nil {
|
||||
result.duration = time.Since(startTime)
|
||||
return result, E.New("invalid response")
|
||||
}
|
||||
var statusResponse statusPayload
|
||||
err = json.Unmarshal(rawMessage, &statusResponse)
|
||||
if err != nil {
|
||||
result.duration = time.Since(startTime)
|
||||
return result, E.Cause(err, "decode status frame")
|
||||
}
|
||||
|
||||
c.stateAccess.Lock()
|
||||
isFirstUpdate := c.state.lastUpdated.IsZero()
|
||||
oldFiveHour := c.state.fiveHourUtilization
|
||||
oldWeekly := c.state.weeklyUtilization
|
||||
c.state.consecutivePollFailures = 0
|
||||
c.state.upstreamRejectedUntil = time.Time{}
|
||||
c.state.fiveHourUtilization = statusResponse.FiveHourUtilization
|
||||
c.state.weeklyUtilization = statusResponse.WeeklyUtilization
|
||||
if statusResponse.PlanWeight > 0 {
|
||||
c.state.remotePlanWeight = statusResponse.PlanWeight
|
||||
}
|
||||
if statusResponse.FiveHourReset > 0 {
|
||||
c.state.fiveHourReset = time.Unix(statusResponse.FiveHourReset, 0)
|
||||
}
|
||||
if statusResponse.WeeklyReset > 0 {
|
||||
c.state.weeklyReset = time.Unix(statusResponse.WeeklyReset, 0)
|
||||
}
|
||||
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
|
||||
c.state.hardRateLimited = false
|
||||
}
|
||||
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
|
||||
resetSuffix := ""
|
||||
if !c.state.weeklyReset.IsZero() {
|
||||
resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset))
|
||||
}
|
||||
c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
|
||||
}
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
result.frames++
|
||||
c.markUsageStreamUpdated()
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
func shouldRetryStatusStreamError(err error) bool {
|
||||
return errors.Is(err, io.ErrUnexpectedEOF) || E.IsClosedOrCanceled(err)
|
||||
}
|
||||
|
||||
func (c *externalCredential) nextStatusStreamBackoff(result statusStreamResult, consecutiveFailures int) (int, time.Duration) {
|
||||
if result.duration >= connectorBackoffResetThreshold {
|
||||
consecutiveFailures = 0
|
||||
}
|
||||
consecutiveFailures++
|
||||
return consecutiveFailures, connectorBackoff(consecutiveFailures)
|
||||
}
|
||||
|
||||
func (c *externalCredential) doStreamStatusRequest(ctx context.Context) (*http.Response, error) {
|
||||
buildRequest := func(baseURL string) (*http.Request, error) {
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/ccm/v1/status?watch=true", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header.Set("Authorization", "Bearer "+c.token)
|
||||
return request, nil
|
||||
}
|
||||
if c.reverseHTTPClient != nil {
|
||||
session := c.getReverseSession()
|
||||
if session != nil && !session.IsClosed() {
|
||||
request, err := buildRequest(reverseProxyBaseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
response, err := c.reverseHTTPClient.Do(request)
|
||||
if err == nil {
|
||||
return response, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
if c.forwardHTTPClient != nil {
|
||||
request, err := buildRequest(c.baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.forwardHTTPClient.Do(request)
|
||||
}
|
||||
return nil, E.New("no transport available")
|
||||
}
|
||||
|
||||
func (c *externalCredential) lastUpdatedTime() time.Time {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.lastUpdated
|
||||
}
|
||||
|
||||
func (c *externalCredential) hasSnapshotData() bool {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.hasSnapshotData()
|
||||
}
|
||||
|
||||
func (c *externalCredential) availabilityStatus() availabilityStatus {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.currentAvailability()
|
||||
}
|
||||
|
||||
|
||||
func (c *externalCredential) markUsageStreamUpdated() {
|
||||
c.stateAccess.Lock()
|
||||
defer c.stateAccess.Unlock()
|
||||
c.state.lastUpdated = time.Now()
|
||||
}
|
||||
|
||||
func (c *externalCredential) markUsagePollAttempted() {
|
||||
c.stateAccess.Lock()
|
||||
defer c.stateAccess.Unlock()
|
||||
c.state.lastUpdated = time.Now()
|
||||
}
|
||||
|
||||
func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Duration {
|
||||
return baseInterval
|
||||
}
|
||||
|
||||
func (c *externalCredential) incrementPollFailures() {
|
||||
c.stateAccess.Lock()
|
||||
c.state.consecutivePollFailures++
|
||||
c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReasonPollFailed, time.Time{})
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
|
||||
func (c *externalCredential) usageTrackerOrNil() *AggregatedUsage {
|
||||
return c.usageTracker
|
||||
}
|
||||
|
||||
func (c *externalCredential) httpClient() *http.Client {
|
||||
if c.reverseHTTPClient != nil {
|
||||
session := c.getReverseSession()
|
||||
if session != nil && !session.IsClosed() {
|
||||
return c.reverseHTTPClient
|
||||
}
|
||||
}
|
||||
return c.forwardHTTPClient
|
||||
}
|
||||
|
||||
func (c *externalCredential) close() {
|
||||
var session *yamux.Session
|
||||
c.reverseAccess.Lock()
|
||||
if !c.closed {
|
||||
c.closed = true
|
||||
if c.reverseCancel != nil {
|
||||
c.reverseCancel()
|
||||
}
|
||||
session = c.reverseSession
|
||||
c.reverseSession = nil
|
||||
}
|
||||
c.reverseAccess.Unlock()
|
||||
if session != nil {
|
||||
session.Close()
|
||||
}
|
||||
if c.usageTracker != nil {
|
||||
c.usageTracker.cancelPendingSave()
|
||||
err := c.usageTracker.Save()
|
||||
if err != nil {
|
||||
c.logger.Error("save usage statistics for ", c.tag, ": ", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) getReverseSession() *yamux.Session {
|
||||
c.reverseAccess.RLock()
|
||||
defer c.reverseAccess.RUnlock()
|
||||
return c.reverseSession
|
||||
}
|
||||
|
||||
func (c *externalCredential) setReverseSession(session *yamux.Session) bool {
|
||||
var emitStatus bool
|
||||
var restartStatusStream bool
|
||||
var triggerUsageRefresh bool
|
||||
c.reverseAccess.Lock()
|
||||
if c.closed {
|
||||
c.reverseAccess.Unlock()
|
||||
return false
|
||||
}
|
||||
wasAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed()
|
||||
old := c.reverseSession
|
||||
c.reverseSession = session
|
||||
isAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed()
|
||||
emitStatus = wasAvailable != isAvailable
|
||||
if isAvailable && !wasAvailable {
|
||||
c.reverseCancel()
|
||||
c.reverseContext, c.reverseCancel = context.WithCancel(context.Background())
|
||||
restartStatusStream = true
|
||||
triggerUsageRefresh = true
|
||||
}
|
||||
c.reverseAccess.Unlock()
|
||||
if old != nil {
|
||||
old.Close()
|
||||
}
|
||||
if restartStatusStream {
|
||||
c.logger.Debug("poll usage for ", c.tag, ": reverse session ready, restarting status stream")
|
||||
go c.statusStreamLoop()
|
||||
}
|
||||
if triggerUsageRefresh {
|
||||
go c.pollUsage()
|
||||
}
|
||||
if emitStatus {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *externalCredential) clearReverseSession(session *yamux.Session) {
|
||||
var emitStatus bool
|
||||
c.reverseAccess.Lock()
|
||||
wasAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed()
|
||||
if c.reverseSession == session {
|
||||
c.reverseSession = nil
|
||||
}
|
||||
isAvailable := c.baseURL == reverseProxyBaseURL && c.reverseSession != nil && !c.reverseSession.IsClosed()
|
||||
emitStatus = wasAvailable != isAvailable
|
||||
c.reverseAccess.Unlock()
|
||||
if emitStatus {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) getReverseContext() context.Context {
|
||||
c.reverseAccess.RLock()
|
||||
defer c.reverseAccess.RUnlock()
|
||||
return c.reverseContext
|
||||
}
|
||||
|
||||
func (c *externalCredential) resetReverseContext() {
|
||||
c.reverseAccess.Lock()
|
||||
if c.closed {
|
||||
c.reverseAccess.Unlock()
|
||||
return
|
||||
}
|
||||
c.reverseCancel()
|
||||
c.reverseContext, c.reverseCancel = context.WithCancel(context.Background())
|
||||
c.reverseAccess.Unlock()
|
||||
}
|
||||
137
service/ccm/credential_file.go
Normal file
137
service/ccm/credential_file.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/fswatch"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
const credentialReloadRetryInterval = 2 * time.Second
|
||||
|
||||
func resolveCredentialFilePath(customPath string) (string, error) {
|
||||
if customPath == "" {
|
||||
var err error
|
||||
customPath, err = getDefaultCredentialsPath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
if filepath.IsAbs(customPath) {
|
||||
return customPath, nil
|
||||
}
|
||||
return filepath.Abs(customPath)
|
||||
}
|
||||
|
||||
func (c *defaultCredential) ensureCredentialWatcher() error {
|
||||
c.watcherAccess.Lock()
|
||||
defer c.watcherAccess.Unlock()
|
||||
|
||||
if c.watcher != nil || c.credentialFilePath == "" {
|
||||
return nil
|
||||
}
|
||||
if !c.watcherRetryAt.IsZero() && time.Now().Before(c.watcherRetryAt) {
|
||||
return nil
|
||||
}
|
||||
|
||||
watcher, err := fswatch.NewWatcher(fswatch.Options{
|
||||
Path: []string{c.credentialFilePath},
|
||||
Logger: c.logger,
|
||||
Callback: func(string) {
|
||||
err := c.reloadCredentials(true)
|
||||
if err != nil {
|
||||
c.logger.Warn("reload credentials for ", c.tag, ": ", err)
|
||||
}
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
c.watcherRetryAt = time.Now().Add(credentialReloadRetryInterval)
|
||||
return err
|
||||
}
|
||||
|
||||
err = watcher.Start()
|
||||
if err != nil {
|
||||
c.watcherRetryAt = time.Now().Add(credentialReloadRetryInterval)
|
||||
return err
|
||||
}
|
||||
|
||||
c.watcher = watcher
|
||||
c.watcherRetryAt = time.Time{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *defaultCredential) retryCredentialReloadIfNeeded() {
|
||||
c.stateAccess.RLock()
|
||||
unavailable := c.state.unavailable
|
||||
lastAttempt := c.state.lastCredentialLoadAttempt
|
||||
c.stateAccess.RUnlock()
|
||||
if !unavailable {
|
||||
return
|
||||
}
|
||||
if !lastAttempt.IsZero() && time.Since(lastAttempt) < credentialReloadRetryInterval {
|
||||
return
|
||||
}
|
||||
|
||||
err := c.ensureCredentialWatcher()
|
||||
if err != nil {
|
||||
c.logger.Error("start credential watcher for ", c.tag, ": ", err)
|
||||
}
|
||||
_ = c.reloadCredentials(false)
|
||||
}
|
||||
|
||||
func (c *defaultCredential) reloadCredentials(force bool) error {
|
||||
c.reloadAccess.Lock()
|
||||
defer c.reloadAccess.Unlock()
|
||||
|
||||
c.stateAccess.RLock()
|
||||
unavailable := c.state.unavailable
|
||||
lastAttempt := c.state.lastCredentialLoadAttempt
|
||||
c.stateAccess.RUnlock()
|
||||
if !force {
|
||||
if !unavailable {
|
||||
return nil
|
||||
}
|
||||
if !lastAttempt.IsZero() && time.Since(lastAttempt) < credentialReloadRetryInterval {
|
||||
return c.unavailableError()
|
||||
}
|
||||
}
|
||||
|
||||
c.stateAccess.Lock()
|
||||
c.state.lastCredentialLoadAttempt = time.Now()
|
||||
c.stateAccess.Unlock()
|
||||
|
||||
credentials, err := platformReadCredentials(c.credentialPath)
|
||||
if err != nil {
|
||||
return c.markCredentialsUnavailable(E.Cause(err, "read credentials"))
|
||||
}
|
||||
|
||||
c.absorbCredentials(credentials)
|
||||
return c.refreshCredentialsIfNeeded(false)
|
||||
}
|
||||
|
||||
func (c *defaultCredential) markCredentialsUnavailable(err error) error {
|
||||
c.access.Lock()
|
||||
hadCredentials := c.credentials != nil
|
||||
c.credentials = nil
|
||||
c.access.Unlock()
|
||||
|
||||
c.stateAccess.Lock()
|
||||
before := c.statusSnapshotLocked()
|
||||
c.state.unavailable = true
|
||||
c.state.lastCredentialLoadError = err.Error()
|
||||
c.state.accountType = ""
|
||||
c.state.rateLimitTier = ""
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
shouldEmit := before != c.statusSnapshotLocked()
|
||||
c.stateAccess.Unlock()
|
||||
|
||||
if shouldInterrupt && hadCredentials {
|
||||
c.interruptConnections()
|
||||
}
|
||||
if shouldEmit {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
84
service/ccm/credential_lock.go
Normal file
84
service/ccm/credential_lock.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"math/rand/v2"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
// acquireCredentialLock acquires a cross-process lock compatible with Claude Code's
|
||||
// proper-lockfile protocol. The lock is a directory created via mkdir (atomic on
|
||||
// POSIX filesystems).
|
||||
//
|
||||
// ref (@anthropic-ai/claude-code @2.1.81): cli.js _P1 (line 179530-179577)
|
||||
// ref: proper-lockfile mkdir protocol (cli.js:43570)
|
||||
// ref: proper-lockfile default options — stale=10s, update=stale/2=5s, realpath=true (cli.js:43661-43664)
|
||||
//
|
||||
// Claude Code locks d1() (= ~/.claude config dir). The lock directory is
|
||||
// <realpath(configDir)>.lock (proper-lockfile default: <path>.lock).
|
||||
// Manual retry: initial + 5 retries = 6 total, delay 1+rand(1s) per retry.
|
||||
func acquireCredentialLock(configDir string) (func(), error) {
|
||||
// ref: cli.js _P1 line 179531 — mkdir -p configDir before locking
|
||||
os.MkdirAll(configDir, 0o700)
|
||||
// ref: proper-lockfile realpath:true (cli.js:43664) — resolve symlinks before appending .lock
|
||||
resolved, err := filepath.EvalSymlinks(configDir)
|
||||
if err != nil {
|
||||
resolved = filepath.Clean(configDir)
|
||||
}
|
||||
lockPath := resolved + ".lock"
|
||||
// ref: cli.js _P1 line 179539-179543 — initial + 5 retries = 6 total attempts
|
||||
for attempt := 0; attempt < 6; attempt++ {
|
||||
if attempt > 0 {
|
||||
// ref: cli.js _P1 line 179542 — 1000 + Math.random() * 1000
|
||||
delay := time.Second + time.Duration(rand.IntN(1000))*time.Millisecond
|
||||
time.Sleep(delay)
|
||||
}
|
||||
err = os.Mkdir(lockPath, 0o755)
|
||||
if err == nil {
|
||||
return startLockHeartbeat(lockPath), nil
|
||||
}
|
||||
if !os.IsExist(err) {
|
||||
return nil, E.Cause(err, "create lock directory")
|
||||
}
|
||||
// ref: proper-lockfile stale check (cli.js:43603-43604)
|
||||
// stale threshold = 10s (cli.js:43662)
|
||||
info, statErr := os.Stat(lockPath)
|
||||
if statErr != nil {
|
||||
continue
|
||||
}
|
||||
if time.Since(info.ModTime()) > 10*time.Second {
|
||||
os.Remove(lockPath)
|
||||
}
|
||||
}
|
||||
return nil, E.New("credential lock timeout")
|
||||
}
|
||||
|
||||
// startLockHeartbeat spawns a goroutine that touches the lock directory's mtime
|
||||
// every 5 seconds to prevent stale detection by other processes.
|
||||
//
|
||||
// ref: proper-lockfile update interval = stale/2 = 5s (cli.js:43662-43663)
|
||||
//
|
||||
// Returns a release function that stops the heartbeat and removes the lock directory.
|
||||
func startLockHeartbeat(lockPath string) func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
os.Chtimes(lockPath, now, now)
|
||||
case <-done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
return func() {
|
||||
close(done)
|
||||
os.Remove(lockPath)
|
||||
}
|
||||
}
|
||||
327
service/ccm/credential_oauth.go
Normal file
327
service/ccm/credential_oauth.go
Normal file
@@ -0,0 +1,327 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
const (
|
||||
oauth2ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
oauth2TokenURL = "https://platform.claude.com/v1/oauth/token"
|
||||
claudeAPIBaseURL = "https://api.anthropic.com"
|
||||
anthropicBetaOAuthValue = "oauth-2025-04-20"
|
||||
|
||||
// ref (@anthropic-ai/claude-code @2.1.81): cli.js vB (line 172879)
|
||||
tokenRefreshBufferMs = 300000
|
||||
)
|
||||
|
||||
// ref (@anthropic-ai/claude-code @2.1.81): cli.js q78 (line 33167)
|
||||
// These scopes may change across Claude Code versions.
|
||||
var defaultOAuthScopes = []string{
|
||||
"user:profile", "user:inference", "user:sessions:claude_code",
|
||||
"user:mcp_servers", "user:file_upload",
|
||||
}
|
||||
|
||||
// resolveRefreshScopes determines which scopes to send in the token refresh request.
|
||||
//
|
||||
// ref (@anthropic-ai/claude-code @2.1.81): cli.js NR() (line 172693) + mB6 scope logic (line 172761)
|
||||
//
|
||||
// Claude Code behavior: if stored scopes include "user:inference", send default
|
||||
// scopes; otherwise send the stored scopes verbatim.
|
||||
func resolveRefreshScopes(stored []string) string {
|
||||
if len(stored) == 0 || slices.Contains(stored, "user:inference") {
|
||||
return strings.Join(defaultOAuthScopes, " ")
|
||||
}
|
||||
return strings.Join(stored, " ")
|
||||
}
|
||||
|
||||
const (
|
||||
ccmRefreshUserAgent = "axios/1.13.6"
|
||||
ccmUserAgentFallback = "claude-code/2.1.85"
|
||||
)
|
||||
|
||||
var (
|
||||
ccmUserAgentOnce sync.Once
|
||||
ccmUserAgentValue string
|
||||
)
|
||||
|
||||
func initCCMUserAgent(logger log.ContextLogger) {
|
||||
ccmUserAgentOnce.Do(func() {
|
||||
version, err := detectClaudeCodeVersion()
|
||||
if err != nil {
|
||||
logger.Error("detect Claude Code version: ", err)
|
||||
ccmUserAgentValue = ccmUserAgentFallback
|
||||
return
|
||||
}
|
||||
logger.Debug("detected Claude Code version: ", version)
|
||||
ccmUserAgentValue = "claude-code/" + version
|
||||
})
|
||||
}
|
||||
|
||||
func detectClaudeCodeVersion() (string, error) {
|
||||
userInfo, err := getRealUser()
|
||||
if err != nil {
|
||||
return "", E.Cause(err, "get user")
|
||||
}
|
||||
binaryName := "claude"
|
||||
if runtime.GOOS == "windows" {
|
||||
binaryName = "claude.exe"
|
||||
}
|
||||
linkPath := filepath.Join(userInfo.HomeDir, ".local", "bin", binaryName)
|
||||
target, err := os.Readlink(linkPath)
|
||||
if err != nil {
|
||||
return "", E.Cause(err, "readlink ", linkPath)
|
||||
}
|
||||
if !filepath.IsAbs(target) {
|
||||
target = filepath.Join(filepath.Dir(linkPath), target)
|
||||
}
|
||||
parent := filepath.Base(filepath.Dir(target))
|
||||
if parent != "versions" {
|
||||
return "", E.New("unexpected symlink target: ", target)
|
||||
}
|
||||
return filepath.Base(target), nil
|
||||
}
|
||||
|
||||
// resolveConfigDir returns the Claude config directory for lock coordination.
|
||||
//
|
||||
// ref (@anthropic-ai/claude-code @2.1.81): cli.js d1() (line 2983) — config dir used for locking
|
||||
func resolveConfigDir(credentialPath string, credentialFilePath string) string {
|
||||
if credentialPath == "" {
|
||||
if configDir := os.Getenv("CLAUDE_CONFIG_DIR"); configDir != "" {
|
||||
return configDir
|
||||
}
|
||||
userInfo, err := getRealUser()
|
||||
if err == nil {
|
||||
return filepath.Join(userInfo.HomeDir, ".claude")
|
||||
}
|
||||
}
|
||||
return filepath.Dir(credentialFilePath)
|
||||
}
|
||||
|
||||
func getRealUser() (*user.User, error) {
|
||||
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
|
||||
sudoUserInfo, err := user.Lookup(sudoUser)
|
||||
if err == nil {
|
||||
return sudoUserInfo, nil
|
||||
}
|
||||
}
|
||||
return user.Current()
|
||||
}
|
||||
|
||||
func getDefaultCredentialsPath() (string, error) {
|
||||
if configDir := os.Getenv("CLAUDE_CONFIG_DIR"); configDir != "" {
|
||||
return filepath.Join(configDir, ".credentials.json"), nil
|
||||
}
|
||||
userInfo, err := getRealUser()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(userInfo.HomeDir, ".claude", ".credentials.json"), nil
|
||||
}
|
||||
|
||||
func readCredentialsFromFile(path string) (*oauthCredentials, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var credentialsContainer struct {
|
||||
ClaudeAIAuth *oauthCredentials `json:"claudeAiOauth,omitempty"`
|
||||
}
|
||||
err = json.Unmarshal(data, &credentialsContainer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if credentialsContainer.ClaudeAIAuth == nil {
|
||||
return nil, E.New("claudeAiOauth field not found in credentials")
|
||||
}
|
||||
return credentialsContainer.ClaudeAIAuth, nil
|
||||
}
|
||||
|
||||
func checkCredentialFileWritable(path string) error {
|
||||
file, err := os.OpenFile(path, os.O_WRONLY, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return file.Close()
|
||||
}
|
||||
|
||||
// writeCredentialsToFile performs a read-modify-write: reads the existing JSON,
|
||||
// replaces only the claudeAiOauth key, and writes back. This preserves any
|
||||
// other top-level keys in the credential file.
|
||||
//
|
||||
// ref (@anthropic-ai/claude-code @2.1.81): cli.js BP6 (line 179444-179454) — read-modify-write
|
||||
// ref: cli.js qD1.update (line 176156) — writeFileSync + chmod 0o600
|
||||
func writeCredentialsToFile(credentials *oauthCredentials, path string) error {
|
||||
return writeStorageValue(jsonFileStorage{path: path}, "claudeAiOauth", credentials)
|
||||
}
|
||||
|
||||
// oauthCredentials mirrors the claudeAiOauth object in Claude Code's
|
||||
// credential file ($CLAUDE_CONFIG_DIR/.credentials.json).
|
||||
//
|
||||
// ref (@anthropic-ai/claude-code @2.1.81): cli.js BP6 (line 179446-179452)
|
||||
type oauthCredentials struct {
|
||||
AccessToken string `json:"accessToken"` // ref: cli.js line 179447
|
||||
RefreshToken string `json:"refreshToken"` // ref: cli.js line 179448
|
||||
ExpiresAt int64 `json:"expiresAt"` // ref: cli.js line 179449 (epoch ms)
|
||||
Scopes []string `json:"scopes"` // ref: cli.js line 179450
|
||||
SubscriptionType *string `json:"subscriptionType"` // ref: cli.js line 179451 (?? null)
|
||||
RateLimitTier *string `json:"rateLimitTier"` // ref: cli.js line 179452 (?? null)
|
||||
}
|
||||
|
||||
type oauthRefreshResult struct {
|
||||
Credentials *oauthCredentials
|
||||
TokenAccount *claudeOAuthAccount
|
||||
Profile *claudeProfileSnapshot
|
||||
}
|
||||
|
||||
func (c *oauthCredentials) needsRefresh() bool {
|
||||
if c.ExpiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().UnixMilli() >= c.ExpiresAt-tokenRefreshBufferMs
|
||||
}
|
||||
|
||||
func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthRefreshResult, time.Duration, error) {
|
||||
if credentials.RefreshToken == "" {
|
||||
return nil, 0, E.New("refresh token is empty")
|
||||
}
|
||||
|
||||
// ref (@anthropic-ai/claude-code @2.1.81): cli.js mB6 (line 172757-172761)
|
||||
requestBody, err := json.Marshal(map[string]string{
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": credentials.RefreshToken,
|
||||
"client_id": oauth2ClientID,
|
||||
"scope": resolveRefreshScopes(credentials.Scopes),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, 0, E.Cause(err, "marshal request")
|
||||
}
|
||||
|
||||
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
|
||||
request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request.Header.Set("User-Agent", ccmRefreshUserAgent)
|
||||
return request, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode == http.StatusTooManyRequests {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
retryDelay := time.Duration(-1)
|
||||
if retryAfter := response.Header.Get("Retry-After"); retryAfter != "" {
|
||||
seconds, parseErr := strconv.ParseInt(retryAfter, 10, 64)
|
||||
if parseErr == nil && seconds > 0 {
|
||||
retryDelay = time.Duration(seconds) * time.Second
|
||||
}
|
||||
}
|
||||
return nil, retryDelay, E.New("refresh rate limited: ", response.Status, " ", string(body))
|
||||
}
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return nil, 0, E.New("refresh failed: ", response.Status, " ", string(body))
|
||||
}
|
||||
|
||||
// ref (@anthropic-ai/claude-code @2.1.81): cli.js mB6 response (line 172769-172772)
|
||||
var tokenResponse struct {
|
||||
AccessToken string `json:"access_token"` // ref: cli.js line 172770 z
|
||||
RefreshToken string `json:"refresh_token"` // ref: cli.js line 172770 w (defaults to input)
|
||||
ExpiresIn int `json:"expires_in"` // ref: cli.js line 172770 O
|
||||
Scope *string `json:"scope"` // ref: cli.js line 172772 uB6(Y.scope)
|
||||
Account *struct {
|
||||
UUID string `json:"uuid"`
|
||||
EmailAddress string `json:"email_address"`
|
||||
} `json:"account"`
|
||||
Organization *struct {
|
||||
UUID string `json:"uuid"`
|
||||
} `json:"organization"`
|
||||
}
|
||||
err = json.NewDecoder(response.Body).Decode(&tokenResponse)
|
||||
if err != nil {
|
||||
return nil, 0, E.Cause(err, "decode response")
|
||||
}
|
||||
|
||||
newCredentials := *credentials
|
||||
newCredentials.AccessToken = tokenResponse.AccessToken
|
||||
if tokenResponse.RefreshToken != "" {
|
||||
newCredentials.RefreshToken = tokenResponse.RefreshToken
|
||||
}
|
||||
newCredentials.ExpiresAt = time.Now().UnixMilli() + int64(tokenResponse.ExpiresIn)*1000
|
||||
// ref: cli.js uB6 (line 172696-172697): A?.split(" ").filter(Boolean)
|
||||
// strings.Fields matches .filter(Boolean): splits on whitespace runs, removes empty strings
|
||||
if tokenResponse.Scope != nil {
|
||||
newCredentials.Scopes = strings.Fields(*tokenResponse.Scope)
|
||||
}
|
||||
|
||||
return &oauthRefreshResult{
|
||||
Credentials: &newCredentials,
|
||||
TokenAccount: extractTokenAccount(tokenResponse.Account, tokenResponse.Organization),
|
||||
}, 0, nil
|
||||
}
|
||||
|
||||
func cloneCredentials(credentials *oauthCredentials) *oauthCredentials {
|
||||
if credentials == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *credentials
|
||||
cloned.Scopes = append([]string(nil), credentials.Scopes...)
|
||||
cloned.SubscriptionType = cloneStringPointer(credentials.SubscriptionType)
|
||||
cloned.RateLimitTier = cloneStringPointer(credentials.RateLimitTier)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool {
|
||||
if left == nil || right == nil {
|
||||
return left == right
|
||||
}
|
||||
return left.AccessToken == right.AccessToken &&
|
||||
left.RefreshToken == right.RefreshToken &&
|
||||
left.ExpiresAt == right.ExpiresAt &&
|
||||
slices.Equal(left.Scopes, right.Scopes) &&
|
||||
equalStringPointer(left.SubscriptionType, right.SubscriptionType) &&
|
||||
equalStringPointer(left.RateLimitTier, right.RateLimitTier)
|
||||
}
|
||||
|
||||
func extractTokenAccount(account *struct {
|
||||
UUID string `json:"uuid"`
|
||||
EmailAddress string `json:"email_address"`
|
||||
}, organization *struct {
|
||||
UUID string `json:"uuid"`
|
||||
},
|
||||
) *claudeOAuthAccount {
|
||||
if account == nil && organization == nil {
|
||||
return nil
|
||||
}
|
||||
tokenAccount := &claudeOAuthAccount{}
|
||||
if account != nil {
|
||||
tokenAccount.AccountUUID = account.UUID
|
||||
tokenAccount.EmailAddress = account.EmailAddress
|
||||
}
|
||||
if organization != nil {
|
||||
tokenAccount.OrganizationUUID = organization.UUID
|
||||
}
|
||||
if tokenAccount.AccountUUID == "" && tokenAccount.EmailAddress == "" && tokenAccount.OrganizationUUID == "" {
|
||||
return nil
|
||||
}
|
||||
return tokenAccount
|
||||
}
|
||||
141
service/ccm/credential_oauth_test.go
Normal file
141
service/ccm/credential_oauth_test.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRefreshTokenScopeParsing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
storedScopes []string
|
||||
responseBody string
|
||||
expectedScope string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "missing scope preserves stored scopes",
|
||||
storedScopes: []string{"user:profile", "user:inference"},
|
||||
responseBody: `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`,
|
||||
expectedScope: strings.Join(defaultOAuthScopes, " "),
|
||||
expected: []string{"user:profile", "user:inference"},
|
||||
},
|
||||
{
|
||||
name: "empty scope clears stored scopes",
|
||||
storedScopes: []string{"user:profile", "user:inference"},
|
||||
responseBody: `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600,"scope":""}`,
|
||||
expectedScope: strings.Join(defaultOAuthScopes, " "),
|
||||
expected: []string{},
|
||||
},
|
||||
{
|
||||
name: "stored non inference scopes are sent verbatim",
|
||||
storedScopes: []string{"user:profile"},
|
||||
responseBody: `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600,"scope":"user:profile user:file_upload"}`,
|
||||
expectedScope: "user:profile",
|
||||
expected: []string{"user:profile", "user:file_upload"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var seenScope string
|
||||
client := &http.Client{Transport: roundTripFunc(func(request *http.Request) (*http.Response, error) {
|
||||
body, err := io.ReadAll(request.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var payload map[string]string
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
seenScope = payload["scope"]
|
||||
return newJSONResponse(http.StatusOK, testCase.responseBody), nil
|
||||
})}
|
||||
|
||||
result, _, err := refreshToken(context.Background(), client, &oauthCredentials{
|
||||
AccessToken: "old-token",
|
||||
RefreshToken: "refresh-token",
|
||||
ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(),
|
||||
Scopes: testCase.storedScopes,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if seenScope != testCase.expectedScope {
|
||||
t.Fatalf("expected request scope %q, got %q", testCase.expectedScope, seenScope)
|
||||
}
|
||||
if result == nil || result.Credentials == nil {
|
||||
t.Fatal("expected refresh result credentials")
|
||||
}
|
||||
if !slices.Equal(result.Credentials.Scopes, testCase.expected) {
|
||||
t.Fatalf("expected scopes %v, got %v", testCase.expected, result.Credentials.Scopes)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshTokenExtractsTokenAccount(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := &http.Client{Transport: roundTripFunc(func(request *http.Request) (*http.Response, error) {
|
||||
return newJSONResponse(http.StatusOK, `{
|
||||
"access_token":"new-token",
|
||||
"refresh_token":"new-refresh",
|
||||
"expires_in":3600,
|
||||
"account":{"uuid":"account","email_address":"user@example.com"},
|
||||
"organization":{"uuid":"org"}
|
||||
}`), nil
|
||||
})}
|
||||
|
||||
result, _, err := refreshToken(context.Background(), client, &oauthCredentials{
|
||||
AccessToken: "old-token",
|
||||
RefreshToken: "refresh-token",
|
||||
ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(),
|
||||
Scopes: []string{"user:profile", "user:inference"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result == nil || result.TokenAccount == nil {
|
||||
t.Fatal("expected token account")
|
||||
}
|
||||
if result.TokenAccount.AccountUUID != "account" || result.TokenAccount.EmailAddress != "user@example.com" || result.TokenAccount.OrganizationUUID != "org" {
|
||||
t.Fatalf("unexpected token account: %#v", result.TokenAccount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCredentialsEqualIncludesProfileFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
subscriptionType := "max"
|
||||
rateLimitTier := "default_claude_max_20x"
|
||||
left := &oauthCredentials{
|
||||
AccessToken: "token",
|
||||
RefreshToken: "refresh",
|
||||
ExpiresAt: 123,
|
||||
Scopes: []string{"user:inference"},
|
||||
SubscriptionType: &subscriptionType,
|
||||
RateLimitTier: &rateLimitTier,
|
||||
}
|
||||
right := cloneCredentials(left)
|
||||
if !credentialsEqual(left, right) {
|
||||
t.Fatal("expected cloned credentials to be equal")
|
||||
}
|
||||
|
||||
otherTier := "default_claude_max_5x"
|
||||
right.RateLimitTier = &otherTier
|
||||
if credentialsEqual(left, right) {
|
||||
t.Fatal("expected different rate limit tier to break equality")
|
||||
}
|
||||
}
|
||||
@@ -13,6 +13,17 @@ func platformReadCredentials(customPath string) (*oauthCredentials, error) {
|
||||
return readCredentialsFromFile(customPath)
|
||||
}
|
||||
|
||||
func platformCanWriteCredentials(customPath string) error {
|
||||
if customPath == "" {
|
||||
var err error
|
||||
customPath, err = getDefaultCredentialsPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return checkCredentialFileWritable(customPath)
|
||||
}
|
||||
|
||||
func platformWriteCredentials(oauthCredentials *oauthCredentials, customPath string) error {
|
||||
if customPath == "" {
|
||||
var err error
|
||||
|
||||
440
service/ccm/credential_provider.go
Normal file
440
service/ccm/credential_provider.go
Normal file
@@ -0,0 +1,440 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand/v2"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
type credentialProvider interface {
|
||||
selectCredential(sessionID string, selection credentialSelection) (Credential, bool, error)
|
||||
onRateLimited(sessionID string, credential Credential, resetAt time.Time, selection credentialSelection) Credential
|
||||
linkProviderInterrupt(credential Credential, selection credentialSelection, onInterrupt func()) func() bool
|
||||
pollIfStale()
|
||||
pollCredentialIfStale(credential Credential)
|
||||
allCredentials() []Credential
|
||||
close()
|
||||
}
|
||||
|
||||
type singleCredentialProvider struct {
|
||||
credential Credential
|
||||
sessionAccess sync.RWMutex
|
||||
sessions map[string]time.Time
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) selectCredential(sessionID string, selection credentialSelection) (Credential, bool, error) {
|
||||
if !selection.allows(p.credential) {
|
||||
return nil, false, E.New("credential ", p.credential.tagName(), " is filtered out")
|
||||
}
|
||||
if !p.credential.isAvailable() {
|
||||
return nil, false, p.credential.unavailableError()
|
||||
}
|
||||
if !p.credential.isUsable() {
|
||||
return nil, false, E.New("credential ", p.credential.tagName(), " is rate-limited")
|
||||
}
|
||||
var isNew bool
|
||||
if sessionID != "" {
|
||||
p.sessionAccess.Lock()
|
||||
if p.sessions == nil {
|
||||
p.sessions = make(map[string]time.Time)
|
||||
}
|
||||
_, exists := p.sessions[sessionID]
|
||||
if !exists {
|
||||
p.sessions[sessionID] = time.Now()
|
||||
isNew = true
|
||||
}
|
||||
p.sessionAccess.Unlock()
|
||||
}
|
||||
return p.credential, isNew, nil
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) onRateLimited(_ string, credential Credential, resetAt time.Time, _ credentialSelection) Credential {
|
||||
credential.markRateLimited(resetAt)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) pollIfStale() {
|
||||
now := time.Now()
|
||||
p.sessionAccess.Lock()
|
||||
for id, createdAt := range p.sessions {
|
||||
if now.Sub(createdAt) > sessionExpiry {
|
||||
delete(p.sessions, id)
|
||||
}
|
||||
}
|
||||
p.sessionAccess.Unlock()
|
||||
|
||||
if time.Since(p.credential.lastUpdatedTime()) > p.credential.pollBackoff(defaultPollInterval) {
|
||||
p.credential.pollUsage()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) allCredentials() []Credential {
|
||||
return []Credential{p.credential}
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) linkProviderInterrupt(_ Credential, _ credentialSelection, _ func()) func() bool {
|
||||
return func() bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) pollCredentialIfStale(credential Credential) {
|
||||
if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(defaultPollInterval) {
|
||||
credential.pollUsage()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) close() {}
|
||||
|
||||
type sessionEntry struct {
|
||||
tag string
|
||||
selectionScope credentialSelectionScope
|
||||
createdAt time.Time
|
||||
}
|
||||
|
||||
type credentialInterruptKey struct {
|
||||
tag string
|
||||
selectionScope credentialSelectionScope
|
||||
}
|
||||
|
||||
type credentialInterruptEntry struct {
|
||||
context context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
type balancerProvider struct {
|
||||
credentials []Credential
|
||||
strategy string
|
||||
roundRobinIndex atomic.Uint64
|
||||
rebalanceThreshold float64
|
||||
sessionAccess sync.RWMutex
|
||||
sessions map[string]sessionEntry
|
||||
interruptAccess sync.Mutex
|
||||
credentialInterrupts map[credentialInterruptKey]credentialInterruptEntry
|
||||
logger log.ContextLogger
|
||||
}
|
||||
|
||||
func newBalancerProvider(credentials []Credential, strategy string, rebalanceThreshold float64, logger log.ContextLogger) *balancerProvider {
|
||||
return &balancerProvider{
|
||||
credentials: credentials,
|
||||
strategy: strategy,
|
||||
rebalanceThreshold: rebalanceThreshold,
|
||||
sessions: make(map[string]sessionEntry),
|
||||
credentialInterrupts: make(map[credentialInterruptKey]credentialInterruptEntry),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (Credential, bool, error) {
|
||||
selectionScope := selection.scopeOrDefault()
|
||||
for {
|
||||
if p.strategy == C.BalancerStrategyFallback {
|
||||
best := p.pickCredential(selection.filter)
|
||||
if best == nil {
|
||||
return nil, false, allCredentialsUnavailableError(p.credentials)
|
||||
}
|
||||
return best, p.storeSessionIfAbsent(sessionID, sessionEntry{createdAt: time.Now()}), nil
|
||||
}
|
||||
|
||||
if sessionID != "" {
|
||||
p.sessionAccess.RLock()
|
||||
entry, exists := p.sessions[sessionID]
|
||||
p.sessionAccess.RUnlock()
|
||||
if exists {
|
||||
if entry.selectionScope == selectionScope {
|
||||
for _, credential := range p.credentials {
|
||||
if credential.tagName() == entry.tag && selection.allows(credential) && credential.isUsable() {
|
||||
if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == C.BalancerStrategyLeastUsed) {
|
||||
better := p.pickLeastUsed(selection.filter)
|
||||
if better != nil && better.tagName() != credential.tagName() {
|
||||
effectiveThreshold := p.rebalanceThreshold / credential.planWeight()
|
||||
delta := credential.weeklyUtilization() - better.weeklyUtilization()
|
||||
if delta > effectiveThreshold {
|
||||
p.logger.Info("rebalancing away from ", credential.tagName(),
|
||||
": utilization delta ", delta, "% exceeds effective threshold ",
|
||||
effectiveThreshold, "% (weight ", credential.planWeight(), ")")
|
||||
p.rebalanceCredential(credential.tagName(), selectionScope)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return credential, false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
p.sessionAccess.Lock()
|
||||
currentEntry, stillExists := p.sessions[sessionID]
|
||||
if stillExists && currentEntry == entry {
|
||||
delete(p.sessions, sessionID)
|
||||
p.sessionAccess.Unlock()
|
||||
} else {
|
||||
p.sessionAccess.Unlock()
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
best := p.pickCredential(selection.filter)
|
||||
if best == nil {
|
||||
return nil, false, allCredentialsUnavailableError(p.credentials)
|
||||
}
|
||||
if p.storeSessionIfAbsent(sessionID, sessionEntry{
|
||||
tag: best.tagName(),
|
||||
selectionScope: selectionScope,
|
||||
createdAt: time.Now(),
|
||||
}) {
|
||||
return best, true, nil
|
||||
}
|
||||
if sessionID == "" {
|
||||
return best, false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *balancerProvider) storeSessionIfAbsent(sessionID string, entry sessionEntry) bool {
|
||||
if sessionID == "" {
|
||||
return false
|
||||
}
|
||||
p.sessionAccess.Lock()
|
||||
defer p.sessionAccess.Unlock()
|
||||
if _, exists := p.sessions[sessionID]; exists {
|
||||
return false
|
||||
}
|
||||
p.sessions[sessionID] = entry
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *balancerProvider) rebalanceCredential(tag string, selectionScope credentialSelectionScope) {
|
||||
key := credentialInterruptKey{tag: tag, selectionScope: selectionScope}
|
||||
p.interruptAccess.Lock()
|
||||
if entry, loaded := p.credentialInterrupts[key]; loaded {
|
||||
entry.cancel()
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
p.credentialInterrupts[key] = credentialInterruptEntry{context: ctx, cancel: cancel}
|
||||
p.interruptAccess.Unlock()
|
||||
|
||||
p.sessionAccess.Lock()
|
||||
for id, entry := range p.sessions {
|
||||
if entry.tag == tag && entry.selectionScope == selectionScope {
|
||||
delete(p.sessions, id)
|
||||
}
|
||||
}
|
||||
p.sessionAccess.Unlock()
|
||||
}
|
||||
|
||||
func (p *balancerProvider) linkProviderInterrupt(credential Credential, selection credentialSelection, onInterrupt func()) func() bool {
|
||||
if p.strategy == C.BalancerStrategyFallback {
|
||||
return func() bool { return false }
|
||||
}
|
||||
key := credentialInterruptKey{
|
||||
tag: credential.tagName(),
|
||||
selectionScope: selection.scopeOrDefault(),
|
||||
}
|
||||
p.interruptAccess.Lock()
|
||||
entry, loaded := p.credentialInterrupts[key]
|
||||
if !loaded {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
entry = credentialInterruptEntry{context: ctx, cancel: cancel}
|
||||
p.credentialInterrupts[key] = entry
|
||||
}
|
||||
p.interruptAccess.Unlock()
|
||||
return context.AfterFunc(entry.context, onInterrupt)
|
||||
}
|
||||
|
||||
func (p *balancerProvider) onRateLimited(sessionID string, credential Credential, resetAt time.Time, selection credentialSelection) Credential {
|
||||
credential.markRateLimited(resetAt)
|
||||
if p.strategy == C.BalancerStrategyFallback {
|
||||
return p.pickCredential(selection.filter)
|
||||
}
|
||||
if sessionID != "" {
|
||||
p.sessionAccess.Lock()
|
||||
delete(p.sessions, sessionID)
|
||||
p.sessionAccess.Unlock()
|
||||
}
|
||||
|
||||
best := p.pickCredential(selection.filter)
|
||||
if best != nil && sessionID != "" {
|
||||
p.sessionAccess.Lock()
|
||||
p.sessions[sessionID] = sessionEntry{
|
||||
tag: best.tagName(),
|
||||
selectionScope: selection.scopeOrDefault(),
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
p.sessionAccess.Unlock()
|
||||
}
|
||||
return best
|
||||
}
|
||||
|
||||
func (p *balancerProvider) pickCredential(filter func(Credential) bool) Credential {
|
||||
switch p.strategy {
|
||||
case C.BalancerStrategyRoundRobin:
|
||||
return p.pickRoundRobin(filter)
|
||||
case C.BalancerStrategyRandom:
|
||||
return p.pickRandom(filter)
|
||||
case C.BalancerStrategyFallback:
|
||||
return p.pickFallback(filter)
|
||||
default:
|
||||
return p.pickLeastUsed(filter)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *balancerProvider) pickFallback(filter func(Credential) bool) Credential {
|
||||
for _, credential := range p.credentials {
|
||||
if filter != nil && !filter(credential) {
|
||||
continue
|
||||
}
|
||||
if credential.isUsable() {
|
||||
return credential
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const weeklyWindowHours = 7 * 24
|
||||
|
||||
func (p *balancerProvider) pickLeastUsed(filter func(Credential) bool) Credential {
|
||||
var best Credential
|
||||
bestScore := float64(-1)
|
||||
now := time.Now()
|
||||
for _, credential := range p.credentials {
|
||||
if filter != nil && !filter(credential) {
|
||||
continue
|
||||
}
|
||||
if !credential.isUsable() {
|
||||
continue
|
||||
}
|
||||
remaining := credential.weeklyCap() - credential.weeklyUtilization()
|
||||
score := remaining * credential.planWeight()
|
||||
resetTime := credential.weeklyResetTime()
|
||||
if !resetTime.IsZero() {
|
||||
timeUntilReset := resetTime.Sub(now)
|
||||
if timeUntilReset < time.Hour {
|
||||
timeUntilReset = time.Hour
|
||||
}
|
||||
score *= weeklyWindowHours / timeUntilReset.Hours()
|
||||
}
|
||||
if score > bestScore {
|
||||
bestScore = score
|
||||
best = credential
|
||||
}
|
||||
}
|
||||
return best
|
||||
}
|
||||
|
||||
func (p *balancerProvider) pickRoundRobin(filter func(Credential) bool) Credential {
|
||||
start := int(p.roundRobinIndex.Add(1) - 1)
|
||||
count := len(p.credentials)
|
||||
for offset := range count {
|
||||
candidate := p.credentials[(start+offset)%count]
|
||||
if filter != nil && !filter(candidate) {
|
||||
continue
|
||||
}
|
||||
if candidate.isUsable() {
|
||||
return candidate
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *balancerProvider) pickRandom(filter func(Credential) bool) Credential {
|
||||
var usable []Credential
|
||||
for _, candidate := range p.credentials {
|
||||
if filter != nil && !filter(candidate) {
|
||||
continue
|
||||
}
|
||||
if candidate.isUsable() {
|
||||
usable = append(usable, candidate)
|
||||
}
|
||||
}
|
||||
if len(usable) == 0 {
|
||||
return nil
|
||||
}
|
||||
return usable[rand.IntN(len(usable))]
|
||||
}
|
||||
|
||||
func (p *balancerProvider) pollIfStale() {
|
||||
now := time.Now()
|
||||
p.sessionAccess.Lock()
|
||||
for id, entry := range p.sessions {
|
||||
if now.Sub(entry.createdAt) > sessionExpiry {
|
||||
delete(p.sessions, id)
|
||||
}
|
||||
}
|
||||
p.sessionAccess.Unlock()
|
||||
|
||||
p.interruptAccess.Lock()
|
||||
for key, entry := range p.credentialInterrupts {
|
||||
if entry.context.Err() != nil {
|
||||
delete(p.credentialInterrupts, key)
|
||||
}
|
||||
}
|
||||
p.interruptAccess.Unlock()
|
||||
|
||||
for _, credential := range p.credentials {
|
||||
if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(defaultPollInterval) {
|
||||
credential.pollUsage()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *balancerProvider) pollCredentialIfStale(credential Credential) {
|
||||
if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(defaultPollInterval) {
|
||||
credential.pollUsage()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *balancerProvider) allCredentials() []Credential {
|
||||
return p.credentials
|
||||
}
|
||||
|
||||
func (p *balancerProvider) close() {}
|
||||
|
||||
func ccmPlanWeight(accountType string, rateLimitTier string) float64 {
|
||||
switch accountType {
|
||||
case "max":
|
||||
switch rateLimitTier {
|
||||
case "default_claude_max_20x":
|
||||
return 10
|
||||
case "default_claude_max_5x":
|
||||
return 5
|
||||
default:
|
||||
return 5
|
||||
}
|
||||
case "team":
|
||||
if rateLimitTier == "default_claude_max_5x" {
|
||||
return 5
|
||||
}
|
||||
return 1
|
||||
default:
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
func allCredentialsUnavailableError(credentials []Credential) error {
|
||||
var hasUnavailable bool
|
||||
var earliest time.Time
|
||||
for _, credential := range credentials {
|
||||
if credential.unavailableError() != nil {
|
||||
hasUnavailable = true
|
||||
continue
|
||||
}
|
||||
resetAt := credential.earliestReset()
|
||||
if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) {
|
||||
earliest = resetAt
|
||||
}
|
||||
}
|
||||
if hasUnavailable {
|
||||
return E.New("all credentials unavailable")
|
||||
}
|
||||
if earliest.IsZero() {
|
||||
return E.New("all credentials rate-limited")
|
||||
}
|
||||
return E.New("all credentials rate-limited, earliest reset in ", log.FormatDuration(time.Until(earliest)))
|
||||
}
|
||||
124
service/ccm/credential_storage.go
Normal file
124
service/ccm/credential_storage.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
type jsonContainerStorage interface {
|
||||
readContainer() (map[string]json.RawMessage, bool, error)
|
||||
writeContainer(map[string]json.RawMessage) error
|
||||
delete() error
|
||||
}
|
||||
|
||||
type jsonFileStorage struct {
|
||||
path string
|
||||
}
|
||||
|
||||
func (s jsonFileStorage) readContainer() (map[string]json.RawMessage, bool, error) {
|
||||
data, err := os.ReadFile(s.path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return make(map[string]json.RawMessage), false, nil
|
||||
}
|
||||
return nil, false, err
|
||||
}
|
||||
container := make(map[string]json.RawMessage)
|
||||
if len(data) == 0 {
|
||||
return container, true, nil
|
||||
}
|
||||
if err := json.Unmarshal(data, &container); err != nil {
|
||||
return nil, true, err
|
||||
}
|
||||
return container, true, nil
|
||||
}
|
||||
|
||||
func (s jsonFileStorage) writeContainer(container map[string]json.RawMessage) error {
|
||||
if err := os.MkdirAll(filepath.Dir(s.path), 0o700); err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := json.MarshalIndent(container, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(s.path, data, 0o600)
|
||||
}
|
||||
|
||||
func (s jsonFileStorage) delete() error {
|
||||
err := os.Remove(s.path)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeStorageValue(storage jsonContainerStorage, key string, value any) error {
|
||||
container, _, err := storage.readContainer()
|
||||
if err != nil {
|
||||
var syntaxError *json.SyntaxError
|
||||
var typeError *json.UnmarshalTypeError
|
||||
if !errors.As(err, &syntaxError) && !errors.As(err, &typeError) {
|
||||
return err
|
||||
}
|
||||
container = make(map[string]json.RawMessage)
|
||||
}
|
||||
if container == nil {
|
||||
container = make(map[string]json.RawMessage)
|
||||
}
|
||||
encodedValue, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
container[key] = encodedValue
|
||||
return storage.writeContainer(container)
|
||||
}
|
||||
|
||||
func persistStorageValue(primary jsonContainerStorage, fallback jsonContainerStorage, key string, value any) error {
|
||||
primaryErr := writeStorageValue(primary, key, value)
|
||||
if primaryErr == nil {
|
||||
if fallback != nil {
|
||||
_ = fallback.delete()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if fallback == nil {
|
||||
return primaryErr
|
||||
}
|
||||
if err := writeStorageValue(fallback, key, value); err != nil {
|
||||
return err
|
||||
}
|
||||
_ = primary.delete()
|
||||
return nil
|
||||
}
|
||||
|
||||
func cloneStringPointer(value *string) *string {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *value
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func cloneBoolPointer(value *bool) *bool {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *value
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func equalStringPointer(left *string, right *string) bool {
|
||||
if left == nil || right == nil {
|
||||
return left == right
|
||||
}
|
||||
return *left == *right
|
||||
}
|
||||
|
||||
func equalBoolPointer(left *bool, right *bool) bool {
|
||||
if left == nil || right == nil {
|
||||
return left == right
|
||||
}
|
||||
return *left == *right
|
||||
}
|
||||
125
service/ccm/credential_storage_test.go
Normal file
125
service/ccm/credential_storage_test.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type fakeJSONStorage struct {
|
||||
container map[string]json.RawMessage
|
||||
writeErr error
|
||||
deleted bool
|
||||
}
|
||||
|
||||
func (s *fakeJSONStorage) readContainer() (map[string]json.RawMessage, bool, error) {
|
||||
if s.container == nil {
|
||||
return make(map[string]json.RawMessage), false, nil
|
||||
}
|
||||
cloned := make(map[string]json.RawMessage, len(s.container))
|
||||
for key, value := range s.container {
|
||||
cloned[key] = value
|
||||
}
|
||||
return cloned, true, nil
|
||||
}
|
||||
|
||||
func (s *fakeJSONStorage) writeContainer(container map[string]json.RawMessage) error {
|
||||
if s.writeErr != nil {
|
||||
return s.writeErr
|
||||
}
|
||||
s.container = make(map[string]json.RawMessage, len(container))
|
||||
for key, value := range container {
|
||||
s.container[key] = value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fakeJSONStorage) delete() error {
|
||||
s.deleted = true
|
||||
s.container = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestPersistStorageValueDeletesFallbackOnPrimarySuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
primary := &fakeJSONStorage{}
|
||||
fallback := &fakeJSONStorage{container: map[string]json.RawMessage{"stale": json.RawMessage(`true`)}}
|
||||
if err := persistStorageValue(primary, fallback, "claudeAiOauth", &oauthCredentials{AccessToken: "token"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !fallback.deleted {
|
||||
t.Fatal("expected fallback storage to be deleted after primary write")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPersistStorageValueDeletesPrimaryAfterFallbackSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
primary := &fakeJSONStorage{
|
||||
container: map[string]json.RawMessage{"claudeAiOauth": json.RawMessage(`{"accessToken":"old"}`)},
|
||||
writeErr: os.ErrPermission,
|
||||
}
|
||||
fallback := &fakeJSONStorage{}
|
||||
if err := persistStorageValue(primary, fallback, "claudeAiOauth", &oauthCredentials{AccessToken: "new"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !primary.deleted {
|
||||
t.Fatal("expected primary storage to be deleted after fallback write")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteCredentialsToFilePreservesTopLevelKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
directory := t.TempDir()
|
||||
path := filepath.Join(directory, ".credentials.json")
|
||||
initial := []byte(`{"keep":{"nested":true},"claudeAiOauth":{"accessToken":"old"}}`)
|
||||
if err := os.WriteFile(path, initial, 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := writeCredentialsToFile(&oauthCredentials{AccessToken: "new"}, path); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var container map[string]json.RawMessage
|
||||
if err := json.Unmarshal(data, &container); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, exists := container["keep"]; !exists {
|
||||
t.Fatal("expected unknown top-level key to be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteClaudeCodeOAuthAccountPreservesTopLevelKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
directory := t.TempDir()
|
||||
path := filepath.Join(directory, ".claude.json")
|
||||
initial := []byte(`{"keep":{"nested":true},"oauthAccount":{"accountUuid":"old"}}`)
|
||||
if err := os.WriteFile(path, initial, 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := writeClaudeCodeOAuthAccount(path, &claudeOAuthAccount{AccountUUID: "new"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var container map[string]json.RawMessage
|
||||
if err := json.Unmarshal(data, &container); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, exists := container["keep"]; !exists {
|
||||
t.Fatal("expected unknown config key to be preserved")
|
||||
}
|
||||
}
|
||||
76
service/ccm/rate_limit_state.go
Normal file
76
service/ccm/rate_limit_state.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package ccm
|
||||
|
||||
import "time"
|
||||
|
||||
type availabilityState string
|
||||
|
||||
const (
|
||||
availabilityStateUsable availabilityState = "usable"
|
||||
availabilityStateRateLimited availabilityState = "rate_limited"
|
||||
availabilityStateTemporarilyBlocked availabilityState = "temporarily_blocked"
|
||||
availabilityStateUnavailable availabilityState = "unavailable"
|
||||
availabilityStateUnknown availabilityState = "unknown"
|
||||
)
|
||||
|
||||
type availabilityReason string
|
||||
|
||||
const (
|
||||
availabilityReasonHardRateLimit availabilityReason = "hard_rate_limit"
|
||||
availabilityReasonConnectionLimit availabilityReason = "connection_limit"
|
||||
availabilityReasonPollFailed availabilityReason = "poll_failed"
|
||||
availabilityReasonUpstreamRejected availabilityReason = "upstream_rejected"
|
||||
availabilityReasonNoCredentials availabilityReason = "no_credentials"
|
||||
availabilityReasonUnknown availabilityReason = "unknown"
|
||||
)
|
||||
|
||||
type availabilityStatus struct {
|
||||
State availabilityState
|
||||
Reason availabilityReason
|
||||
ResetAt time.Time
|
||||
}
|
||||
|
||||
func (s availabilityStatus) normalized() availabilityStatus {
|
||||
if s.State == "" {
|
||||
s.State = availabilityStateUnknown
|
||||
}
|
||||
if s.Reason == "" && s.State != availabilityStateUsable {
|
||||
s.Reason = availabilityReasonUnknown
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func claudeWindowProgress(resetAt time.Time, windowSeconds float64, now time.Time) float64 {
|
||||
if resetAt.IsZero() || windowSeconds <= 0 {
|
||||
return 0
|
||||
}
|
||||
windowStart := resetAt.Add(-time.Duration(windowSeconds * float64(time.Second)))
|
||||
if now.Before(windowStart) {
|
||||
return 0
|
||||
}
|
||||
progress := now.Sub(windowStart).Seconds() / windowSeconds
|
||||
if progress < 0 {
|
||||
return 0
|
||||
}
|
||||
if progress > 1 {
|
||||
return 1
|
||||
}
|
||||
return progress
|
||||
}
|
||||
|
||||
func claudeFiveHourWarning(utilizationPercent float64, resetAt time.Time, now time.Time) bool {
|
||||
return utilizationPercent >= 90 && claudeWindowProgress(resetAt, 5*60*60, now) >= 0.72
|
||||
}
|
||||
|
||||
func claudeWeeklyWarning(utilizationPercent float64, resetAt time.Time, now time.Time) bool {
|
||||
progress := claudeWindowProgress(resetAt, 7*24*60*60, now)
|
||||
switch {
|
||||
case utilizationPercent >= 75:
|
||||
return progress >= 0.60
|
||||
case utilizationPercent >= 50:
|
||||
return progress >= 0.35
|
||||
case utilizationPercent >= 25:
|
||||
return progress >= 0.15
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
266
service/ccm/reverse.go
Normal file
266
service/ccm/reverse.go
Normal file
@@ -0,0 +1,266 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
stdTLS "crypto/tls"
|
||||
"io"
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
)
|
||||
|
||||
var defaultYamuxConfig = func() *yamux.Config {
|
||||
config := yamux.DefaultConfig()
|
||||
config.KeepAliveInterval = 15 * time.Second
|
||||
config.ConnectionWriteTimeout = 10 * time.Second
|
||||
config.MaxStreamWindowSize = 512 * 1024
|
||||
config.LogOutput = io.Discard
|
||||
return config
|
||||
}()
|
||||
|
||||
type bufferedConn struct {
|
||||
reader *bufio.Reader
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c *bufferedConn) Read(p []byte) (int, error) {
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
type yamuxNetListener struct {
|
||||
session *yamux.Session
|
||||
}
|
||||
|
||||
func (l *yamuxNetListener) Accept() (net.Conn, error) {
|
||||
return l.session.Accept()
|
||||
}
|
||||
|
||||
func (l *yamuxNetListener) Close() error {
|
||||
return l.session.Close()
|
||||
}
|
||||
|
||||
func (l *yamuxNetListener) Addr() net.Addr {
|
||||
return l.session.Addr()
|
||||
}
|
||||
|
||||
func (s *Service) handleReverseConnect(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Upgrade") != "reverse-proxy" {
|
||||
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", "missing Upgrade header")
|
||||
return
|
||||
}
|
||||
|
||||
if r.Header.Get("X-Api-Key") != "" || r.Header.Get("Api-Key") != "" {
|
||||
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
|
||||
"API key authentication is not supported; use Authorization: Bearer with a CCM user token")
|
||||
return
|
||||
}
|
||||
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
||||
return
|
||||
}
|
||||
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if clientToken == authHeader {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
||||
return
|
||||
}
|
||||
|
||||
receiverCredential := s.findReceiverCredential(clientToken)
|
||||
if receiverCredential == nil {
|
||||
s.logger.WarnContext(ctx, "reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid reverse token")
|
||||
return
|
||||
}
|
||||
|
||||
hijacker, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
s.logger.ErrorContext(ctx, "reverse connect: hijack not supported")
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "hijack not supported")
|
||||
return
|
||||
}
|
||||
|
||||
conn, bufferedReadWriter, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "reverse connect: hijack: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
response := "HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: reverse-proxy\r\n\r\n"
|
||||
_, err = bufferedReadWriter.WriteString(response)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
s.logger.ErrorContext(ctx, "reverse connect: write upgrade response: ", err)
|
||||
return
|
||||
}
|
||||
err = bufferedReadWriter.Flush()
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
s.logger.ErrorContext(ctx, "reverse connect: flush upgrade response: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
session, err := yamux.Client(conn, defaultYamuxConfig)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
s.logger.ErrorContext(ctx, "reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !receiverCredential.setReverseSession(session) {
|
||||
session.Close()
|
||||
return
|
||||
}
|
||||
s.logger.InfoContext(ctx, "reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr)
|
||||
|
||||
go func() {
|
||||
<-session.CloseChan()
|
||||
receiverCredential.clearReverseSession(session)
|
||||
s.logger.WarnContext(ctx, "reverse connection lost for ", receiverCredential.tagName())
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Service) findReceiverCredential(token string) *externalCredential {
|
||||
for _, credential := range s.allCredentials {
|
||||
external, ok := credential.(*externalCredential)
|
||||
if !ok || external.connectorURL != nil {
|
||||
continue
|
||||
}
|
||||
if external.token == token {
|
||||
return external
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) connectorLoop() {
|
||||
var consecutiveFailures int
|
||||
ctx := c.getReverseContext()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
sessionLifetime, err := c.connectorConnect(ctx)
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if sessionLifetime >= connectorBackoffResetThreshold {
|
||||
consecutiveFailures = 0
|
||||
}
|
||||
consecutiveFailures++
|
||||
backoff := connectorBackoff(consecutiveFailures)
|
||||
c.logger.Warn("reverse connection for ", c.tag, " lost: ", err, ", reconnecting in ", backoff)
|
||||
timer := time.NewTimer(backoff)
|
||||
select {
|
||||
case <-timer.C:
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const connectorBackoffResetThreshold = time.Minute
|
||||
|
||||
func connectorBackoff(failures int) time.Duration {
|
||||
if failures > 5 {
|
||||
failures = 5
|
||||
}
|
||||
base := time.Second * time.Duration(1<<failures)
|
||||
if base > 30*time.Second {
|
||||
base = 30 * time.Second
|
||||
}
|
||||
jitter := time.Duration(rand.Int64N(int64(base) / 2))
|
||||
return base + jitter
|
||||
}
|
||||
|
||||
func (c *externalCredential) connectorConnect(ctx context.Context) (time.Duration, error) {
|
||||
if c.reverseService == nil {
|
||||
return 0, E.New("reverse service not initialized")
|
||||
}
|
||||
destination := c.connectorResolveDestination()
|
||||
conn, err := c.connectorDialer.DialContext(ctx, "tcp", destination)
|
||||
if err != nil {
|
||||
return 0, E.Cause(err, "dial")
|
||||
}
|
||||
|
||||
if c.connectorTLS != nil {
|
||||
tlsConn := stdTLS.Client(conn, c.connectorTLS.Clone())
|
||||
err = tlsConn.HandshakeContext(ctx)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return 0, E.Cause(err, "tls handshake")
|
||||
}
|
||||
conn = tlsConn
|
||||
}
|
||||
|
||||
upgradeRequest := "GET " + c.connectorRequestPath + " HTTP/1.1\r\n" +
|
||||
"Host: " + c.connectorURL.Host + "\r\n" +
|
||||
"Connection: Upgrade\r\n" +
|
||||
"Upgrade: reverse-proxy\r\n" +
|
||||
"Authorization: Bearer " + c.token + "\r\n" +
|
||||
"\r\n"
|
||||
_, err = io.WriteString(conn, upgradeRequest)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return 0, E.Cause(err, "write upgrade request")
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
statusLine, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return 0, E.Cause(err, "read upgrade response")
|
||||
}
|
||||
if !strings.HasPrefix(statusLine, "HTTP/1.1 101") {
|
||||
conn.Close()
|
||||
return 0, E.New("unexpected upgrade response: ", strings.TrimSpace(statusLine))
|
||||
}
|
||||
for {
|
||||
line, readErr := reader.ReadString('\n')
|
||||
if readErr != nil {
|
||||
conn.Close()
|
||||
return 0, E.Cause(readErr, "read upgrade headers")
|
||||
}
|
||||
if strings.TrimSpace(line) == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
session, err := yamux.Server(&bufferedConn{reader: reader, Conn: conn}, defaultYamuxConfig)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return 0, E.Cause(err, "create yamux server")
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
c.logger.Info("reverse connection established for ", c.tag)
|
||||
|
||||
serveStart := time.Now()
|
||||
httpServer := &http.Server{
|
||||
Handler: c.reverseService,
|
||||
ReadTimeout: 0,
|
||||
IdleTimeout: 120 * time.Second,
|
||||
}
|
||||
err = httpServer.Serve(&yamuxNetListener{session: session})
|
||||
sessionLifetime := time.Since(serveStart)
|
||||
if err != nil && !E.IsClosed(err) && ctx.Err() == nil {
|
||||
return sessionLifetime, E.Cause(err, "serve")
|
||||
}
|
||||
return sessionLifetime, E.New("connection closed")
|
||||
}
|
||||
|
||||
func (c *externalCredential) connectorResolveDestination() M.Socksaddr {
|
||||
return c.connectorDestination
|
||||
}
|
||||
@@ -1,69 +1,45 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
stdTLS "crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"mime"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
boxService "github.com/sagernet/sing-box/adapter/service"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
"github.com/sagernet/sing-box/common/listener"
|
||||
"github.com/sagernet/sing-box/common/tls"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
"github.com/sagernet/sing/common/observable"
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
anthropicconstant "github.com/anthropics/anthropic-sdk-go/shared/constant"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
)
|
||||
|
||||
const (
|
||||
contextWindowStandard = 200000
|
||||
contextWindowPremium = 1000000
|
||||
premiumContextThreshold = 200000
|
||||
)
|
||||
const retryableUsageMessage = "current credential reached its usage limit; retry the request to use another credential"
|
||||
|
||||
func RegisterService(registry *boxService.Registry) {
|
||||
boxService.Register[option.CCMServiceOptions](registry, C.TypeCCM, NewService)
|
||||
}
|
||||
|
||||
type errorResponse struct {
|
||||
Type string `json:"type"`
|
||||
Error errorDetails `json:"error"`
|
||||
RequestID string `json:"request_id,omitempty"`
|
||||
}
|
||||
|
||||
type errorDetails struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, errorType string, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
json.NewEncoder(w).Encode(errorResponse{
|
||||
Type: "error",
|
||||
Error: errorDetails{
|
||||
json.NewEncoder(w).Encode(anthropic.ErrorResponse{
|
||||
Type: anthropicconstant.Error("").Default(),
|
||||
Error: anthropic.ErrorObjectUnion{
|
||||
Type: errorType,
|
||||
Message: message,
|
||||
},
|
||||
@@ -71,6 +47,73 @@ func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, erro
|
||||
})
|
||||
}
|
||||
|
||||
func hasAlternativeCredential(provider credentialProvider, currentCredential Credential, selection credentialSelection) bool {
|
||||
if provider == nil || currentCredential == nil {
|
||||
return false
|
||||
}
|
||||
for _, credential := range provider.allCredentials() {
|
||||
if credential == currentCredential {
|
||||
continue
|
||||
}
|
||||
if !selection.allows(credential) {
|
||||
continue
|
||||
}
|
||||
if credential.isUsable() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func unavailableCredentialMessage(provider credentialProvider, fallback string) string {
|
||||
if provider == nil {
|
||||
return fallback
|
||||
}
|
||||
message := allCredentialsUnavailableError(provider.allCredentials()).Error()
|
||||
if message == "all credentials unavailable" && fallback != "" {
|
||||
return fallback
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
||||
func writeRetryableUsageError(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSONError(w, r, http.StatusTooManyRequests, "rate_limit_error", retryableUsageMessage)
|
||||
}
|
||||
|
||||
func writeNonRetryableCredentialError(w http.ResponseWriter, r *http.Request, message string) {
|
||||
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", message)
|
||||
}
|
||||
|
||||
func writeCredentialUnavailableError(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
provider credentialProvider,
|
||||
currentCredential Credential,
|
||||
selection credentialSelection,
|
||||
fallback string,
|
||||
) {
|
||||
if hasAlternativeCredential(provider, currentCredential, selection) {
|
||||
writeRetryableUsageError(w, r)
|
||||
return
|
||||
}
|
||||
if provider != nil && strings.HasPrefix(allCredentialsUnavailableError(provider.allCredentials()).Error(), "all credentials rate-limited") {
|
||||
writeRetryableUsageError(w, r)
|
||||
return
|
||||
}
|
||||
writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, fallback))
|
||||
}
|
||||
|
||||
func credentialSelectionForUser(userConfig *option.CCMUser) credentialSelection {
|
||||
selection := credentialSelection{scope: credentialSelectionScopeAll}
|
||||
if userConfig != nil && !userConfig.AllowExternalUsage {
|
||||
selection.scope = credentialSelectionScopeNonExternal
|
||||
selection.filter = func(credential Credential) bool {
|
||||
return !credential.isExternal()
|
||||
}
|
||||
}
|
||||
return selection
|
||||
}
|
||||
|
||||
func isHopByHopHeader(header string) bool {
|
||||
switch strings.ToLower(header) {
|
||||
case "connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade", "host":
|
||||
@@ -80,111 +123,129 @@ func isHopByHopHeader(header string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
weeklyWindowSeconds = 604800
|
||||
weeklyWindowMinutes = weeklyWindowSeconds / 60
|
||||
)
|
||||
|
||||
func parseInt64Header(headers http.Header, headerName string) (int64, bool) {
|
||||
headerValue := strings.TrimSpace(headers.Get(headerName))
|
||||
if headerValue == "" {
|
||||
return 0, false
|
||||
func isReverseProxyHeader(header string) bool {
|
||||
lowerHeader := strings.ToLower(header)
|
||||
if strings.HasPrefix(lowerHeader, "cf-") {
|
||||
return true
|
||||
}
|
||||
parsedValue, parseError := strconv.ParseInt(headerValue, 10, 64)
|
||||
if parseError != nil {
|
||||
return 0, false
|
||||
switch lowerHeader {
|
||||
case "cdn-loop", "true-client-ip", "x-forwarded-for", "x-forwarded-proto", "x-real-ip":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return parsedValue, true
|
||||
}
|
||||
|
||||
func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint {
|
||||
resetAtUnix, hasResetAt := parseInt64Header(headers, "anthropic-ratelimit-unified-7d-reset")
|
||||
if !hasResetAt || resetAtUnix <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &WeeklyCycleHint{
|
||||
WindowMinutes: weeklyWindowMinutes,
|
||||
ResetAt: time.Unix(resetAtUnix, 0).UTC(),
|
||||
func isAPIKeyHeader(header string) bool {
|
||||
switch strings.ToLower(header) {
|
||||
case "x-api-key", "api-key":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
boxService.Adapter
|
||||
ctx context.Context
|
||||
logger log.ContextLogger
|
||||
credentialPath string
|
||||
credentials *oauthCredentials
|
||||
users []option.CCMUser
|
||||
httpClient *http.Client
|
||||
httpHeaders http.Header
|
||||
listener *listener.Listener
|
||||
tlsConfig tls.ServerConfig
|
||||
httpServer *http.Server
|
||||
userManager *UserManager
|
||||
accessMutex sync.RWMutex
|
||||
usageTracker *AggregatedUsage
|
||||
trackingGroup sync.WaitGroup
|
||||
shuttingDown bool
|
||||
ctx context.Context
|
||||
logger log.ContextLogger
|
||||
options option.CCMServiceOptions
|
||||
httpHeaders http.Header
|
||||
listener *listener.Listener
|
||||
tlsConfig tls.ServerConfig
|
||||
httpServer *http.Server
|
||||
userManager *UserManager
|
||||
|
||||
providers map[string]credentialProvider
|
||||
allCredentials []Credential
|
||||
userConfigMap map[string]*option.CCMUser
|
||||
|
||||
sessionModelAccess sync.Mutex
|
||||
sessionModels map[sessionModelKey]time.Time
|
||||
|
||||
statusSubscriber *observable.Subscriber[struct{}]
|
||||
statusObserver *observable.Observer[struct{}]
|
||||
}
|
||||
|
||||
type sessionModelKey struct {
|
||||
sessionID string
|
||||
model string
|
||||
}
|
||||
|
||||
func (s *Service) cleanSessionModels() {
|
||||
now := time.Now()
|
||||
s.sessionModelAccess.Lock()
|
||||
for key, createdAt := range s.sessionModels {
|
||||
if now.Sub(createdAt) > sessionExpiry {
|
||||
delete(s.sessionModels, key)
|
||||
}
|
||||
}
|
||||
s.sessionModelAccess.Unlock()
|
||||
}
|
||||
|
||||
func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.CCMServiceOptions) (adapter.Service, error) {
|
||||
serviceDialer, err := dialer.NewWithOptions(dialer.Options{
|
||||
Context: ctx,
|
||||
Options: option.DialerOptions{
|
||||
Detour: options.Detour,
|
||||
},
|
||||
RemoteIsDomain: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "create dialer")
|
||||
initCCMUserAgent(logger)
|
||||
|
||||
hasLegacy := options.CredentialPath != "" || options.UsagesPath != "" || options.Detour != ""
|
||||
if hasLegacy && len(options.Credentials) > 0 {
|
||||
return nil, E.New("credential_path/usages_path/detour and credentials are mutually exclusive")
|
||||
}
|
||||
if len(options.Credentials) == 0 {
|
||||
options.Credentials = []option.CCMCredential{{
|
||||
Type: "default",
|
||||
Tag: "default",
|
||||
DefaultOptions: option.CCMDefaultCredentialOptions{
|
||||
CredentialPath: options.CredentialPath,
|
||||
UsagesPath: options.UsagesPath,
|
||||
Detour: options.Detour,
|
||||
},
|
||||
}}
|
||||
options.CredentialPath = ""
|
||||
options.UsagesPath = ""
|
||||
options.Detour = ""
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
ForceAttemptHTTP2: true,
|
||||
TLSClientConfig: &stdTLS.Config{
|
||||
RootCAs: adapter.RootPoolFromContext(ctx),
|
||||
Time: ntp.TimeFuncFromContext(ctx),
|
||||
},
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return serviceDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
},
|
||||
},
|
||||
err := validateCCMOptions(options)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "validate options")
|
||||
}
|
||||
|
||||
userManager := &UserManager{
|
||||
tokenMap: make(map[string]string),
|
||||
}
|
||||
|
||||
var usageTracker *AggregatedUsage
|
||||
if options.UsagesPath != "" {
|
||||
usageTracker = &AggregatedUsage{
|
||||
LastUpdated: time.Now(),
|
||||
Combinations: make([]CostCombination, 0),
|
||||
filePath: options.UsagesPath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
statusSubscriber := observable.NewSubscriber[struct{}](16)
|
||||
service := &Service{
|
||||
Adapter: boxService.NewAdapter(C.TypeCCM, tag),
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
credentialPath: options.CredentialPath,
|
||||
users: options.Users,
|
||||
httpClient: httpClient,
|
||||
httpHeaders: options.Headers.Build(),
|
||||
Adapter: boxService.NewAdapter(C.TypeCCM, tag),
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
options: options,
|
||||
httpHeaders: options.Headers.Build(),
|
||||
listener: listener.New(listener.Options{
|
||||
Context: ctx,
|
||||
Logger: logger,
|
||||
Network: []string{N.NetworkTCP},
|
||||
Listen: options.ListenOptions,
|
||||
}),
|
||||
userManager: userManager,
|
||||
usageTracker: usageTracker,
|
||||
userManager: userManager,
|
||||
sessionModels: make(map[sessionModelKey]time.Time),
|
||||
statusSubscriber: statusSubscriber,
|
||||
statusObserver: observable.NewObserver[struct{}](statusSubscriber, 8),
|
||||
}
|
||||
|
||||
providers, allCredentials, err := buildCredentialProviders(ctx, options, logger)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "build credential providers")
|
||||
}
|
||||
service.providers = providers
|
||||
service.allCredentials = allCredentials
|
||||
|
||||
userConfigMap := make(map[string]*option.CCMUser)
|
||||
for i := range options.Users {
|
||||
userConfigMap[options.Users[i].Name] = &options.Users[i]
|
||||
}
|
||||
service.userConfigMap = userConfigMap
|
||||
|
||||
if options.TLS != nil {
|
||||
tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
|
||||
if err != nil {
|
||||
@@ -201,28 +262,26 @@ func (s *Service) Start(stage adapter.StartStage) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.userManager.UpdateUsers(s.users)
|
||||
s.userManager.UpdateUsers(s.options.Users)
|
||||
|
||||
credentials, err := platformReadCredentials(s.credentialPath)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read credentials")
|
||||
}
|
||||
s.credentials = credentials
|
||||
|
||||
if s.usageTracker != nil {
|
||||
err = s.usageTracker.Load()
|
||||
for _, credential := range s.allCredentials {
|
||||
credential.setStatusSubscriber(s.statusSubscriber)
|
||||
if external, ok := credential.(*externalCredential); ok && external.reverse && external.connectorURL != nil {
|
||||
external.reverseService = s
|
||||
}
|
||||
err := credential.start()
|
||||
if err != nil {
|
||||
s.logger.Warn("load usage statistics: ", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Mount("/", s)
|
||||
|
||||
s.httpServer = &http.Server{Handler: router}
|
||||
s.httpServer = &http.Server{Handler: h2c.NewHandler(router, &http2.Server{})}
|
||||
|
||||
if s.tlsConfig != nil {
|
||||
err = s.tlsConfig.Start()
|
||||
err := s.tlsConfig.Start()
|
||||
if err != nil {
|
||||
return E.Cause(err, "create TLS config")
|
||||
}
|
||||
@@ -242,7 +301,7 @@ func (s *Service) Start(stage adapter.StartStage) error {
|
||||
|
||||
go func() {
|
||||
serveErr := s.httpServer.Serve(tcpListener)
|
||||
if serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
|
||||
if serveErr != nil && !E.IsClosed(serveErr) {
|
||||
s.logger.Error("serve error: ", serveErr)
|
||||
}
|
||||
}()
|
||||
@@ -250,347 +309,30 @@ func (s *Service) Start(stage adapter.StartStage) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) getAccessToken() (string, error) {
|
||||
s.accessMutex.RLock()
|
||||
if !s.credentials.needsRefresh() {
|
||||
token := s.credentials.AccessToken
|
||||
s.accessMutex.RUnlock()
|
||||
return token, nil
|
||||
}
|
||||
s.accessMutex.RUnlock()
|
||||
|
||||
s.accessMutex.Lock()
|
||||
defer s.accessMutex.Unlock()
|
||||
|
||||
if !s.credentials.needsRefresh() {
|
||||
return s.credentials.AccessToken, nil
|
||||
}
|
||||
|
||||
newCredentials, err := refreshToken(s.httpClient, s.credentials)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
s.credentials = newCredentials
|
||||
|
||||
err = platformWriteCredentials(newCredentials, s.credentialPath)
|
||||
if err != nil {
|
||||
s.logger.Warn("persist refreshed token: ", err)
|
||||
}
|
||||
|
||||
return newCredentials.AccessToken, nil
|
||||
}
|
||||
|
||||
func detectContextWindow(betaHeader string, totalInputTokens int64) int {
|
||||
if totalInputTokens > premiumContextThreshold {
|
||||
features := strings.Split(betaHeader, ",")
|
||||
for _, feature := range features {
|
||||
if strings.HasPrefix(strings.TrimSpace(feature), "context-1m") {
|
||||
return contextWindowPremium
|
||||
}
|
||||
}
|
||||
}
|
||||
return contextWindowStandard
|
||||
}
|
||||
|
||||
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.HasPrefix(r.URL.Path, "/v1/") {
|
||||
writeJSONError(w, r, http.StatusNotFound, "not_found_error", "Not found")
|
||||
return
|
||||
}
|
||||
|
||||
var username string
|
||||
if len(s.users) > 0 {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
||||
return
|
||||
}
|
||||
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if clientToken == authHeader {
|
||||
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
||||
return
|
||||
}
|
||||
var ok bool
|
||||
username, ok = s.userManager.Authenticate(clientToken)
|
||||
func (s *Service) InterfaceUpdated() {
|
||||
for _, credential := range s.allCredentials {
|
||||
external, ok := credential.(*externalCredential)
|
||||
if !ok {
|
||||
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
|
||||
return
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
var requestModel string
|
||||
var messagesCount int
|
||||
|
||||
if s.usageTracker != nil && r.Body != nil {
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
if err == nil {
|
||||
var request struct {
|
||||
Model string `json:"model"`
|
||||
Messages []anthropic.MessageParam `json:"messages"`
|
||||
}
|
||||
err := json.Unmarshal(bodyBytes, &request)
|
||||
if err == nil {
|
||||
requestModel = request.Model
|
||||
messagesCount = len(request.Messages)
|
||||
}
|
||||
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
}
|
||||
}
|
||||
|
||||
accessToken, err := s.getAccessToken()
|
||||
if err != nil {
|
||||
s.logger.Error("get access token: ", err)
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "Authentication failed")
|
||||
return
|
||||
}
|
||||
|
||||
proxyURL := claudeAPIBaseURL + r.URL.RequestURI()
|
||||
proxyRequest, err := http.NewRequestWithContext(r.Context(), r.Method, proxyURL, r.Body)
|
||||
if err != nil {
|
||||
s.logger.Error("create proxy request: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
|
||||
return
|
||||
}
|
||||
|
||||
for key, values := range r.Header {
|
||||
if !isHopByHopHeader(key) && key != "Authorization" {
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
}
|
||||
|
||||
serviceOverridesAcceptEncoding := len(s.httpHeaders.Values("Accept-Encoding")) > 0
|
||||
if s.usageTracker != nil && !serviceOverridesAcceptEncoding {
|
||||
// Strip Accept-Encoding so Go Transport adds it automatically
|
||||
// and transparently decompresses the response for correct usage counting.
|
||||
proxyRequest.Header.Del("Accept-Encoding")
|
||||
}
|
||||
|
||||
anthropicBetaHeader := proxyRequest.Header.Get("anthropic-beta")
|
||||
if anthropicBetaHeader != "" {
|
||||
proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue+","+anthropicBetaHeader)
|
||||
} else {
|
||||
proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue)
|
||||
}
|
||||
|
||||
for key, values := range s.httpHeaders {
|
||||
proxyRequest.Header.Del(key)
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
|
||||
proxyRequest.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
response, err := s.httpClient.Do(proxyRequest)
|
||||
if err != nil {
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
for key, values := range response.Header {
|
||||
if !isHopByHopHeader(key) {
|
||||
w.Header()[key] = values
|
||||
}
|
||||
}
|
||||
w.WriteHeader(response.StatusCode)
|
||||
|
||||
if s.usageTracker != nil && response.StatusCode == http.StatusOK {
|
||||
s.handleResponseWithTracking(w, response, requestModel, anthropicBetaHeader, messagesCount, username)
|
||||
} else {
|
||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||
if err == nil && mediaType != "text/event-stream" {
|
||||
_, _ = io.Copy(w, response.Body)
|
||||
return
|
||||
}
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
s.logger.Error("streaming not supported")
|
||||
return
|
||||
}
|
||||
buffer := make([]byte, buf.BufferSize)
|
||||
for {
|
||||
n, err := response.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
_, writeError := w.Write(buffer[:n])
|
||||
if writeError != nil {
|
||||
s.logger.Error("write streaming response: ", writeError)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, requestModel string, anthropicBetaHeader string, messagesCount int, username string) {
|
||||
weeklyCycleHint := extractWeeklyCycleHint(response.Header)
|
||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||
isStreaming := err == nil && mediaType == "text/event-stream"
|
||||
|
||||
if !isStreaming {
|
||||
bodyBytes, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
s.logger.Error("read response body: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
var message anthropic.Message
|
||||
var usage anthropic.Usage
|
||||
var responseModel string
|
||||
err = json.Unmarshal(bodyBytes, &message)
|
||||
if err == nil {
|
||||
responseModel = string(message.Model)
|
||||
usage = message.Usage
|
||||
}
|
||||
if responseModel == "" {
|
||||
responseModel = requestModel
|
||||
}
|
||||
|
||||
if usage.InputTokens > 0 || usage.OutputTokens > 0 {
|
||||
if responseModel != "" {
|
||||
totalInputTokens := usage.InputTokens + usage.CacheCreationInputTokens + usage.CacheReadInputTokens
|
||||
contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens)
|
||||
s.usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
contextWindow,
|
||||
messagesCount,
|
||||
usage.InputTokens,
|
||||
usage.OutputTokens,
|
||||
usage.CacheReadInputTokens,
|
||||
usage.CacheCreationInputTokens,
|
||||
usage.CacheCreation.Ephemeral5mInputTokens,
|
||||
usage.CacheCreation.Ephemeral1hInputTokens,
|
||||
username,
|
||||
time.Now(),
|
||||
weeklyCycleHint,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
_, _ = writer.Write(bodyBytes)
|
||||
return
|
||||
}
|
||||
|
||||
flusher, ok := writer.(http.Flusher)
|
||||
if !ok {
|
||||
s.logger.Error("streaming not supported")
|
||||
return
|
||||
}
|
||||
|
||||
var accumulatedUsage anthropic.Usage
|
||||
var responseModel string
|
||||
buffer := make([]byte, buf.BufferSize)
|
||||
var leftover []byte
|
||||
|
||||
for {
|
||||
n, err := response.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
data := append(leftover, buffer[:n]...)
|
||||
lines := bytes.Split(data, []byte("\n"))
|
||||
|
||||
if err == nil {
|
||||
leftover = lines[len(lines)-1]
|
||||
lines = lines[:len(lines)-1]
|
||||
} else {
|
||||
leftover = nil
|
||||
}
|
||||
|
||||
for _, line := range lines {
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if bytes.HasPrefix(line, []byte("data: ")) {
|
||||
eventData := bytes.TrimPrefix(line, []byte("data: "))
|
||||
if bytes.Equal(eventData, []byte("[DONE]")) {
|
||||
continue
|
||||
}
|
||||
|
||||
var event anthropic.MessageStreamEventUnion
|
||||
err := json.Unmarshal(eventData, &event)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
switch event.Type {
|
||||
case "message_start":
|
||||
messageStart := event.AsMessageStart()
|
||||
if messageStart.Message.Model != "" {
|
||||
responseModel = string(messageStart.Message.Model)
|
||||
}
|
||||
if messageStart.Message.Usage.InputTokens > 0 {
|
||||
accumulatedUsage.InputTokens = messageStart.Message.Usage.InputTokens
|
||||
accumulatedUsage.CacheReadInputTokens = messageStart.Message.Usage.CacheReadInputTokens
|
||||
accumulatedUsage.CacheCreationInputTokens = messageStart.Message.Usage.CacheCreationInputTokens
|
||||
accumulatedUsage.CacheCreation.Ephemeral5mInputTokens = messageStart.Message.Usage.CacheCreation.Ephemeral5mInputTokens
|
||||
accumulatedUsage.CacheCreation.Ephemeral1hInputTokens = messageStart.Message.Usage.CacheCreation.Ephemeral1hInputTokens
|
||||
}
|
||||
case "message_delta":
|
||||
messageDelta := event.AsMessageDelta()
|
||||
if messageDelta.Usage.OutputTokens > 0 {
|
||||
accumulatedUsage.OutputTokens = messageDelta.Usage.OutputTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_, writeError := writer.Write(buffer[:n])
|
||||
if writeError != nil {
|
||||
s.logger.Error("write streaming response: ", writeError)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if responseModel == "" {
|
||||
responseModel = requestModel
|
||||
}
|
||||
|
||||
if accumulatedUsage.InputTokens > 0 || accumulatedUsage.OutputTokens > 0 {
|
||||
if responseModel != "" {
|
||||
totalInputTokens := accumulatedUsage.InputTokens + accumulatedUsage.CacheCreationInputTokens + accumulatedUsage.CacheReadInputTokens
|
||||
contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens)
|
||||
s.usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
contextWindow,
|
||||
messagesCount,
|
||||
accumulatedUsage.InputTokens,
|
||||
accumulatedUsage.OutputTokens,
|
||||
accumulatedUsage.CacheReadInputTokens,
|
||||
accumulatedUsage.CacheCreationInputTokens,
|
||||
accumulatedUsage.CacheCreation.Ephemeral5mInputTokens,
|
||||
accumulatedUsage.CacheCreation.Ephemeral1hInputTokens,
|
||||
username,
|
||||
time.Now(),
|
||||
weeklyCycleHint,
|
||||
)
|
||||
}
|
||||
}
|
||||
return
|
||||
if external.reverse && external.connectorURL != nil {
|
||||
external.reverseService = s
|
||||
external.resetReverseContext()
|
||||
go external.connectorLoop()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Close() error {
|
||||
s.statusObserver.Close()
|
||||
err := common.Close(
|
||||
common.PtrOrNil(s.httpServer),
|
||||
common.PtrOrNil(s.listener),
|
||||
s.tlsConfig,
|
||||
)
|
||||
|
||||
if s.usageTracker != nil {
|
||||
s.usageTracker.cancelPendingSave()
|
||||
saveErr := s.usageTracker.Save()
|
||||
if saveErr != nil {
|
||||
s.logger.Error("save usage statistics: ", saveErr)
|
||||
}
|
||||
for _, credential := range s.allCredentials {
|
||||
credential.close()
|
||||
}
|
||||
|
||||
return err
|
||||
|
||||
667
service/ccm/service_handler.go
Normal file
667
service/ccm/service_handler.go
Normal file
@@ -0,0 +1,667 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
)
|
||||
|
||||
const (
|
||||
contextWindowStandard = 200000
|
||||
contextWindowPremium = 1000000
|
||||
premiumContextThreshold = 200000
|
||||
)
|
||||
|
||||
const (
|
||||
weeklyWindowSeconds = 604800
|
||||
weeklyWindowMinutes = weeklyWindowSeconds / 60
|
||||
)
|
||||
|
||||
type ccmRequestMetadata struct {
|
||||
Model string
|
||||
MessagesCount int
|
||||
SessionID string
|
||||
}
|
||||
|
||||
func isExtendedContextRequest(betaHeader string) bool {
|
||||
for _, feature := range strings.Split(betaHeader, ",") {
|
||||
if strings.HasPrefix(strings.TrimSpace(feature), "context-1m") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isFastModeRequest(betaHeader string) bool {
|
||||
for _, feature := range strings.Split(betaHeader, ",") {
|
||||
if strings.HasPrefix(strings.TrimSpace(feature), "fast-mode") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func detectContextWindow(betaHeader string, totalInputTokens int64) int {
|
||||
if totalInputTokens > premiumContextThreshold {
|
||||
if isExtendedContextRequest(betaHeader) {
|
||||
return contextWindowPremium
|
||||
}
|
||||
}
|
||||
return contextWindowStandard
|
||||
}
|
||||
|
||||
func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint {
|
||||
resetAt, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset")
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &WeeklyCycleHint{
|
||||
WindowMinutes: weeklyWindowMinutes,
|
||||
ResetAt: resetAt.UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
// extractCCMSessionID extracts the session ID from the metadata.user_id field.
|
||||
//
|
||||
// Claude Code >= 2.1.78 (@anthropic-ai/claude-code) encodes user_id as:
|
||||
//
|
||||
// JSON.stringify({device_id, account_uuid, session_id, ...extras})
|
||||
//
|
||||
// ref: cli.js L66() — metadata constructor
|
||||
//
|
||||
// Claude Code < 2.1.78 used a template literal:
|
||||
//
|
||||
// `user_${deviceId}_account_${accountUuid}_session_${sessionId}`
|
||||
//
|
||||
// ref: cli.js qs() — old metadata constructor
|
||||
//
|
||||
// Returns ("", nil) when userID is empty.
|
||||
// Returns error when user_id is present but in an unrecognized format.
|
||||
func extractCCMSessionID(userID string) (string, error) {
|
||||
if userID == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// v2.1.78+ JSON object format
|
||||
var userIDObject struct {
|
||||
SessionID string `json:"session_id"`
|
||||
}
|
||||
if json.Unmarshal([]byte(userID), &userIDObject) == nil && userIDObject.SessionID != "" {
|
||||
return userIDObject.SessionID, nil
|
||||
}
|
||||
|
||||
// legacy template literal format
|
||||
sessionIndex := strings.LastIndex(userID, "_session_")
|
||||
if sessionIndex >= 0 {
|
||||
return userID[sessionIndex+len("_session_"):], nil
|
||||
}
|
||||
|
||||
return "", E.New("unrecognized metadata.user_id format: ", userID)
|
||||
}
|
||||
|
||||
func extractCCMRequestMetadata(path string, bodyBytes []byte) (ccmRequestMetadata, error) {
|
||||
switch path {
|
||||
case "/v1/messages":
|
||||
var request anthropic.MessageNewParams
|
||||
if json.Unmarshal(bodyBytes, &request) != nil {
|
||||
return ccmRequestMetadata{}, nil
|
||||
}
|
||||
|
||||
metadata := ccmRequestMetadata{
|
||||
Model: string(request.Model),
|
||||
MessagesCount: len(request.Messages),
|
||||
}
|
||||
if request.Metadata.UserID.Valid() {
|
||||
sessionID, err := extractCCMSessionID(request.Metadata.UserID.Value)
|
||||
if err != nil {
|
||||
return ccmRequestMetadata{}, err
|
||||
}
|
||||
metadata.SessionID = sessionID
|
||||
}
|
||||
return metadata, nil
|
||||
case "/v1/messages/count_tokens":
|
||||
var request anthropic.MessageCountTokensParams
|
||||
if json.Unmarshal(bodyBytes, &request) != nil {
|
||||
return ccmRequestMetadata{}, nil
|
||||
}
|
||||
return ccmRequestMetadata{
|
||||
Model: string(request.Model),
|
||||
MessagesCount: len(request.Messages),
|
||||
}, nil
|
||||
default:
|
||||
return ccmRequestMetadata{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := log.ContextWithNewID(r.Context())
|
||||
if r.URL.Path == "/ccm/v1/status" {
|
||||
s.handleStatusEndpoint(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Path == "/ccm/v1/reverse" {
|
||||
s.handleReverseConnect(ctx, w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(r.URL.Path, "/v1/") {
|
||||
writeJSONError(w, r, http.StatusNotFound, "not_found_error", "Not found")
|
||||
return
|
||||
}
|
||||
|
||||
if r.Header.Get("X-Api-Key") != "" || r.Header.Get("Api-Key") != "" {
|
||||
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
|
||||
"API key authentication is not supported; use Authorization: Bearer with a CCM user token")
|
||||
return
|
||||
}
|
||||
|
||||
var username string
|
||||
if len(s.options.Users) > 0 {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
||||
return
|
||||
}
|
||||
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if clientToken == authHeader {
|
||||
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
||||
return
|
||||
}
|
||||
var ok bool
|
||||
username, ok = s.userManager.Authenticate(clientToken)
|
||||
if !ok {
|
||||
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Always read body to extract model and session ID
|
||||
var bodyBytes []byte
|
||||
var requestModel string
|
||||
var messagesCount int
|
||||
var sessionID string
|
||||
|
||||
if r.Body != nil {
|
||||
var err error
|
||||
bodyBytes, err = io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "read request body: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body")
|
||||
return
|
||||
}
|
||||
|
||||
requestMetadata, err := extractCCMRequestMetadata(r.URL.Path, bodyBytes)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "invalid metadata format: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "invalid metadata format")
|
||||
return
|
||||
}
|
||||
requestModel = requestMetadata.Model
|
||||
messagesCount = requestMetadata.MessagesCount
|
||||
sessionID = requestMetadata.SessionID
|
||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
}
|
||||
|
||||
// Resolve credential provider and user config
|
||||
var provider credentialProvider
|
||||
var userConfig *option.CCMUser
|
||||
if len(s.options.Users) > 0 {
|
||||
userConfig = s.userConfigMap[username]
|
||||
var err error
|
||||
provider, err = credentialForUser(s.userConfigMap, s.providers, username)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "resolve credential: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
} else {
|
||||
provider = s.providers[s.options.Credentials[0].Tag]
|
||||
}
|
||||
if provider == nil {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available")
|
||||
return
|
||||
}
|
||||
|
||||
provider.pollIfStale()
|
||||
if userConfig != nil && userConfig.ExternalCredential != "" {
|
||||
for _, credential := range s.allCredentials {
|
||||
if credential.tagName() == userConfig.ExternalCredential && !credential.isUsable() {
|
||||
credential.pollUsage()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
s.cleanSessionModels()
|
||||
|
||||
anthropicBetaHeader := r.Header.Get("anthropic-beta")
|
||||
if isFastModeRequest(anthropicBetaHeader) {
|
||||
if _, isSingle := provider.(*singleCredentialProvider); !isSingle {
|
||||
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
|
||||
"fast mode requests will consume Extra usage, please use a default credential directly")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
selection := credentialSelectionForUser(userConfig)
|
||||
|
||||
selectedCredential, isNew, err := provider.selectCredential(sessionID, selection)
|
||||
if err != nil {
|
||||
writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, err.Error()))
|
||||
return
|
||||
}
|
||||
modelDisplay := requestModel
|
||||
if requestModel != "" && isExtendedContextRequest(anthropicBetaHeader) {
|
||||
modelDisplay += "[1m]"
|
||||
}
|
||||
isNewModel := false
|
||||
if sessionID != "" && modelDisplay != "" {
|
||||
key := sessionModelKey{sessionID, modelDisplay}
|
||||
s.sessionModelAccess.Lock()
|
||||
_, exists := s.sessionModels[key]
|
||||
if !exists {
|
||||
s.sessionModels[key] = time.Now()
|
||||
isNewModel = true
|
||||
}
|
||||
s.sessionModelAccess.Unlock()
|
||||
}
|
||||
if isNew || isNewModel {
|
||||
logParts := []any{"assigned credential ", selectedCredential.tagName()}
|
||||
if sessionID != "" {
|
||||
logParts = append(logParts, " for session ", sessionID)
|
||||
}
|
||||
if username != "" {
|
||||
logParts = append(logParts, " by user ", username)
|
||||
}
|
||||
if modelDisplay != "" {
|
||||
logParts = append(logParts, ", model=", modelDisplay)
|
||||
}
|
||||
s.logger.DebugContext(ctx, logParts...)
|
||||
}
|
||||
|
||||
if isFastModeRequest(anthropicBetaHeader) && selectedCredential.isExternal() {
|
||||
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
|
||||
"fast mode requests cannot be proxied through external credentials")
|
||||
return
|
||||
}
|
||||
|
||||
requestContext := selectedCredential.wrapRequestContext(ctx)
|
||||
{
|
||||
currentRequestContext := requestContext
|
||||
requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() {
|
||||
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
|
||||
}))
|
||||
}
|
||||
defer func() {
|
||||
requestContext.cancelRequest()
|
||||
}()
|
||||
proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "create proxy request: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
|
||||
return
|
||||
}
|
||||
|
||||
response, err := selectedCredential.httpClient().Do(proxyRequest)
|
||||
if err != nil {
|
||||
if r.Context().Err() != nil {
|
||||
return
|
||||
}
|
||||
if requestContext.Err() != nil {
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable while processing the request")
|
||||
return
|
||||
}
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
requestContext.releaseCredentialInterrupt()
|
||||
|
||||
// Transparent 429 retry
|
||||
for response.StatusCode == http.StatusTooManyRequests {
|
||||
resetAt := parseRateLimitResetFromHeaders(response.Header)
|
||||
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection)
|
||||
selectedCredential.updateStateFromHeaders(response.Header)
|
||||
if bodyBytes == nil || nextCredential == nil {
|
||||
response.Body.Close()
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited")
|
||||
return
|
||||
}
|
||||
response.Body.Close()
|
||||
s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
|
||||
requestContext.cancelRequest()
|
||||
requestContext = nextCredential.wrapRequestContext(ctx)
|
||||
{
|
||||
currentRequestContext := requestContext
|
||||
requestContext.addInterruptLink(provider.linkProviderInterrupt(nextCredential, selection, func() {
|
||||
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
|
||||
}))
|
||||
}
|
||||
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||
if buildErr != nil {
|
||||
s.logger.ErrorContext(ctx, "retry request: ", buildErr)
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
|
||||
return
|
||||
}
|
||||
retryResponse, retryErr := nextCredential.httpClient().Do(retryRequest)
|
||||
if retryErr != nil {
|
||||
if r.Context().Err() != nil {
|
||||
return
|
||||
}
|
||||
if requestContext.Err() != nil {
|
||||
writeCredentialUnavailableError(w, r, provider, nextCredential, selection, "credential became unavailable while retrying the request")
|
||||
return
|
||||
}
|
||||
s.logger.ErrorContext(ctx, "retry request: ", retryErr)
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error())
|
||||
return
|
||||
}
|
||||
requestContext.releaseCredentialInterrupt()
|
||||
response = retryResponse
|
||||
selectedCredential = nextCredential
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
selectedCredential.updateStateFromHeaders(response.Header)
|
||||
|
||||
if response.StatusCode == 529 {
|
||||
s.logger.WarnContext(ctx, "upstream overloaded from ", selectedCredential.tagName())
|
||||
for key, values := range response.Header {
|
||||
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) {
|
||||
w.Header()[key] = values
|
||||
}
|
||||
}
|
||||
w.WriteHeader(response.StatusCode)
|
||||
io.Copy(w, response.Body)
|
||||
return
|
||||
}
|
||||
|
||||
if response.StatusCode == http.StatusBadRequest {
|
||||
if selectedCredential.isExternal() {
|
||||
selectedCredential.markUpstreamRejected()
|
||||
} else {
|
||||
provider.pollCredentialIfStale(selectedCredential)
|
||||
}
|
||||
s.logger.ErrorContext(ctx, "upstream rejected from ", selectedCredential.tagName(), ": status ", response.StatusCode)
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "upstream rejected credential")
|
||||
return
|
||||
}
|
||||
|
||||
// ref (@anthropic-ai/claude-code @2.1.81): cli.js NA9 (line 179488-179494) — 401 recovery
|
||||
// ref: cli.js CR1 (line 314268-314273) — 403 "OAuth token has been revoked" recovery
|
||||
if !selectedCredential.isExternal() && bodyBytes != nil &&
|
||||
(response.StatusCode == http.StatusUnauthorized || response.StatusCode == http.StatusForbidden) {
|
||||
shouldRetry := response.StatusCode == http.StatusUnauthorized
|
||||
var peekBody []byte
|
||||
if response.StatusCode == http.StatusForbidden {
|
||||
peekBody, _ = io.ReadAll(response.Body)
|
||||
shouldRetry = strings.Contains(string(peekBody), "OAuth token has been revoked")
|
||||
if !shouldRetry {
|
||||
response.Body.Close()
|
||||
s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(peekBody))
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error",
|
||||
"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(peekBody))
|
||||
return
|
||||
}
|
||||
}
|
||||
if shouldRetry {
|
||||
recovered := false
|
||||
var recoverErr error
|
||||
if defaultCred, ok := selectedCredential.(*defaultCredential); ok {
|
||||
failedAccessToken := ""
|
||||
currentCredentials := defaultCred.currentCredentials()
|
||||
if currentCredentials != nil {
|
||||
failedAccessToken = currentCredentials.AccessToken
|
||||
}
|
||||
s.logger.WarnContext(ctx, "upstream auth failure from ", selectedCredential.tagName(), ", reloading credentials and retrying")
|
||||
recovered, recoverErr = defaultCred.recoverAuthFailure(failedAccessToken)
|
||||
}
|
||||
if recoverErr != nil {
|
||||
response.Body.Close()
|
||||
if isHardRefreshFailure(recoverErr) || selectedCredential.unavailableError() != nil {
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable during auth recovery")
|
||||
return
|
||||
}
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", E.Cause(recoverErr, "auth recovery").Error())
|
||||
return
|
||||
}
|
||||
if recovered {
|
||||
response.Body.Close()
|
||||
retryRequest, buildErr := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||
if buildErr != nil {
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", E.Cause(buildErr, "rebuild request after auth recovery").Error())
|
||||
return
|
||||
}
|
||||
retryResponse, retryErr := selectedCredential.httpClient().Do(retryRequest)
|
||||
if retryErr != nil {
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", E.Cause(retryErr, "retry request after auth recovery").Error())
|
||||
return
|
||||
}
|
||||
response = retryResponse
|
||||
defer retryResponse.Body.Close()
|
||||
} else if response.StatusCode == http.StatusForbidden {
|
||||
response.Body = io.NopCloser(bytes.NewReader(peekBody))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error",
|
||||
"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
|
||||
return
|
||||
}
|
||||
|
||||
s.rewriteResponseHeaders(response.Header, provider, userConfig)
|
||||
|
||||
for key, values := range response.Header {
|
||||
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) {
|
||||
w.Header()[key] = values
|
||||
}
|
||||
}
|
||||
w.WriteHeader(response.StatusCode)
|
||||
|
||||
usageTracker := selectedCredential.usageTrackerOrNil()
|
||||
if usageTracker != nil && response.StatusCode == http.StatusOK {
|
||||
s.handleResponseWithTracking(ctx, w, response, usageTracker, requestModel, anthropicBetaHeader, messagesCount, username)
|
||||
} else {
|
||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||
if err == nil && mediaType != "text/event-stream" {
|
||||
_, _ = io.Copy(w, response.Body)
|
||||
return
|
||||
}
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
s.logger.ErrorContext(ctx, "streaming not supported")
|
||||
return
|
||||
}
|
||||
buffer := make([]byte, buf.BufferSize)
|
||||
for {
|
||||
n, err := response.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
_, writeError := w.Write(buffer[:n])
|
||||
if writeError != nil {
|
||||
if E.IsClosedOrCanceled(writeError) {
|
||||
return
|
||||
}
|
||||
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, requestModel string, anthropicBetaHeader string, messagesCount int, username string) {
|
||||
weeklyCycleHint := extractWeeklyCycleHint(response.Header)
|
||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||
isStreaming := err == nil && mediaType == "text/event-stream"
|
||||
|
||||
if !isStreaming {
|
||||
bodyBytes, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "read response body: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
var message anthropic.Message
|
||||
var usage anthropic.Usage
|
||||
var responseModel string
|
||||
err = json.Unmarshal(bodyBytes, &message)
|
||||
if err == nil {
|
||||
responseModel = string(message.Model)
|
||||
usage = message.Usage
|
||||
}
|
||||
if responseModel == "" {
|
||||
responseModel = requestModel
|
||||
}
|
||||
|
||||
if usage.InputTokens > 0 || usage.OutputTokens > 0 {
|
||||
if responseModel != "" {
|
||||
totalInputTokens := usage.InputTokens + usage.CacheCreationInputTokens + usage.CacheReadInputTokens
|
||||
contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens)
|
||||
usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
contextWindow,
|
||||
messagesCount,
|
||||
usage.InputTokens,
|
||||
usage.OutputTokens,
|
||||
usage.CacheReadInputTokens,
|
||||
usage.CacheCreationInputTokens,
|
||||
usage.CacheCreation.Ephemeral5mInputTokens,
|
||||
usage.CacheCreation.Ephemeral1hInputTokens,
|
||||
username,
|
||||
time.Now(),
|
||||
weeklyCycleHint,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
_, _ = writer.Write(bodyBytes)
|
||||
return
|
||||
}
|
||||
|
||||
flusher, ok := writer.(http.Flusher)
|
||||
if !ok {
|
||||
s.logger.ErrorContext(ctx, "streaming not supported")
|
||||
return
|
||||
}
|
||||
|
||||
var accumulatedUsage anthropic.Usage
|
||||
var responseModel string
|
||||
buffer := make([]byte, buf.BufferSize)
|
||||
var leftover []byte
|
||||
|
||||
for {
|
||||
n, err := response.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
data := append(leftover, buffer[:n]...)
|
||||
lines := bytes.Split(data, []byte("\n"))
|
||||
|
||||
if err == nil {
|
||||
leftover = lines[len(lines)-1]
|
||||
lines = lines[:len(lines)-1]
|
||||
} else {
|
||||
leftover = nil
|
||||
}
|
||||
|
||||
for _, line := range lines {
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if bytes.HasPrefix(line, []byte("data: ")) {
|
||||
eventData := bytes.TrimPrefix(line, []byte("data: "))
|
||||
if bytes.Equal(eventData, []byte("[DONE]")) {
|
||||
continue
|
||||
}
|
||||
|
||||
var event anthropic.MessageStreamEventUnion
|
||||
err := json.Unmarshal(eventData, &event)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
switch event.Type {
|
||||
case "message_start":
|
||||
messageStart := event.AsMessageStart()
|
||||
if messageStart.Message.Model != "" {
|
||||
responseModel = string(messageStart.Message.Model)
|
||||
}
|
||||
if messageStart.Message.Usage.InputTokens > 0 {
|
||||
accumulatedUsage.InputTokens = messageStart.Message.Usage.InputTokens
|
||||
accumulatedUsage.CacheReadInputTokens = messageStart.Message.Usage.CacheReadInputTokens
|
||||
accumulatedUsage.CacheCreationInputTokens = messageStart.Message.Usage.CacheCreationInputTokens
|
||||
accumulatedUsage.CacheCreation.Ephemeral5mInputTokens = messageStart.Message.Usage.CacheCreation.Ephemeral5mInputTokens
|
||||
accumulatedUsage.CacheCreation.Ephemeral1hInputTokens = messageStart.Message.Usage.CacheCreation.Ephemeral1hInputTokens
|
||||
}
|
||||
case "message_delta":
|
||||
messageDelta := event.AsMessageDelta()
|
||||
if messageDelta.Usage.OutputTokens > 0 {
|
||||
accumulatedUsage.OutputTokens = messageDelta.Usage.OutputTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_, writeError := writer.Write(buffer[:n])
|
||||
if writeError != nil {
|
||||
if E.IsClosedOrCanceled(writeError) {
|
||||
return
|
||||
}
|
||||
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if responseModel == "" {
|
||||
responseModel = requestModel
|
||||
}
|
||||
|
||||
if accumulatedUsage.InputTokens > 0 || accumulatedUsage.OutputTokens > 0 {
|
||||
if responseModel != "" {
|
||||
totalInputTokens := accumulatedUsage.InputTokens + accumulatedUsage.CacheCreationInputTokens + accumulatedUsage.CacheReadInputTokens
|
||||
contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens)
|
||||
usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
contextWindow,
|
||||
messagesCount,
|
||||
accumulatedUsage.InputTokens,
|
||||
accumulatedUsage.OutputTokens,
|
||||
accumulatedUsage.CacheReadInputTokens,
|
||||
accumulatedUsage.CacheCreationInputTokens,
|
||||
accumulatedUsage.CacheCreation.Ephemeral5mInputTokens,
|
||||
accumulatedUsage.CacheCreation.Ephemeral1hInputTokens,
|
||||
username,
|
||||
time.Now(),
|
||||
weeklyCycleHint,
|
||||
)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
221
service/ccm/service_handler_test.go
Normal file
221
service/ccm/service_handler_test.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func newHandlerCredential(t *testing.T, transport http.RoundTripper) (*defaultCredential, string) {
|
||||
t.Helper()
|
||||
directory := t.TempDir()
|
||||
credentialPath := filepath.Join(directory, ".credentials.json")
|
||||
writeTestCredentials(t, credentialPath, &oauthCredentials{
|
||||
AccessToken: "old-token",
|
||||
RefreshToken: "refresh-token",
|
||||
ExpiresAt: time.Now().Add(time.Hour).UnixMilli(),
|
||||
Scopes: []string{"user:profile", "user:inference"},
|
||||
SubscriptionType: optionalStringPointer("max"),
|
||||
RateLimitTier: optionalStringPointer("default_claude_max_20x"),
|
||||
})
|
||||
credential := newTestDefaultCredential(t, credentialPath, transport)
|
||||
if err := credential.reloadCredentials(true); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
seedTestCredentialState(credential)
|
||||
return credential, credentialPath
|
||||
}
|
||||
|
||||
func TestServiceHandlerRecoversFrom401(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var messageRequests atomic.Int32
|
||||
var refreshRequests atomic.Int32
|
||||
credential, _ := newHandlerCredential(t, roundTripFunc(func(request *http.Request) (*http.Response, error) {
|
||||
switch request.URL.Path {
|
||||
case "/v1/messages":
|
||||
call := messageRequests.Add(1)
|
||||
switch request.Header.Get("Authorization") {
|
||||
case "Bearer old-token":
|
||||
if call != 1 {
|
||||
t.Fatalf("unexpected old-token call count %d", call)
|
||||
}
|
||||
return newTextResponse(http.StatusUnauthorized, "unauthorized"), nil
|
||||
case "Bearer new-token":
|
||||
return newJSONResponse(http.StatusOK, `{}`), nil
|
||||
default:
|
||||
t.Fatalf("unexpected authorization header %q", request.Header.Get("Authorization"))
|
||||
}
|
||||
case "/v1/oauth/token":
|
||||
refreshRequests.Add(1)
|
||||
return newJSONResponse(http.StatusOK, `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`), nil
|
||||
default:
|
||||
t.Fatalf("unexpected path %s", request.URL.Path)
|
||||
}
|
||||
return nil, nil
|
||||
}))
|
||||
|
||||
service := newTestService(credential)
|
||||
recorder := httptest.NewRecorder()
|
||||
service.ServeHTTP(recorder, newMessageRequest(`{"model":"claude","messages":[],"metadata":{"user_id":"{\"session_id\":\"session\"}"}}`))
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", recorder.Code, recorder.Body.String())
|
||||
}
|
||||
if messageRequests.Load() != 2 {
|
||||
t.Fatalf("expected two upstream message requests, got %d", messageRequests.Load())
|
||||
}
|
||||
if refreshRequests.Load() != 1 {
|
||||
t.Fatalf("expected one refresh request, got %d", refreshRequests.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceHandlerRecoversFromRevoked403(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var messageRequests atomic.Int32
|
||||
var refreshRequests atomic.Int32
|
||||
credential, _ := newHandlerCredential(t, roundTripFunc(func(request *http.Request) (*http.Response, error) {
|
||||
switch request.URL.Path {
|
||||
case "/v1/messages":
|
||||
messageRequests.Add(1)
|
||||
if request.Header.Get("Authorization") == "Bearer old-token" {
|
||||
return newTextResponse(http.StatusForbidden, "OAuth token has been revoked"), nil
|
||||
}
|
||||
return newJSONResponse(http.StatusOK, `{}`), nil
|
||||
case "/v1/oauth/token":
|
||||
refreshRequests.Add(1)
|
||||
return newJSONResponse(http.StatusOK, `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`), nil
|
||||
default:
|
||||
t.Fatalf("unexpected path %s", request.URL.Path)
|
||||
}
|
||||
return nil, nil
|
||||
}))
|
||||
|
||||
service := newTestService(credential)
|
||||
recorder := httptest.NewRecorder()
|
||||
service.ServeHTTP(recorder, newMessageRequest(`{"model":"claude","messages":[],"metadata":{"user_id":"{\"session_id\":\"session\"}"}}`))
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", recorder.Code, recorder.Body.String())
|
||||
}
|
||||
if refreshRequests.Load() != 1 {
|
||||
t.Fatalf("expected one refresh request, got %d", refreshRequests.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceHandlerDoesNotRecoverFromOrdinary403(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var refreshRequests atomic.Int32
|
||||
credential, _ := newHandlerCredential(t, roundTripFunc(func(request *http.Request) (*http.Response, error) {
|
||||
switch request.URL.Path {
|
||||
case "/v1/messages":
|
||||
return newTextResponse(http.StatusForbidden, "forbidden"), nil
|
||||
case "/v1/oauth/token":
|
||||
refreshRequests.Add(1)
|
||||
return newJSONResponse(http.StatusOK, `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`), nil
|
||||
default:
|
||||
t.Fatalf("unexpected path %s", request.URL.Path)
|
||||
}
|
||||
return nil, nil
|
||||
}))
|
||||
|
||||
service := newTestService(credential)
|
||||
recorder := httptest.NewRecorder()
|
||||
service.ServeHTTP(recorder, newMessageRequest(`{"model":"claude","messages":[],"metadata":{"user_id":"{\"session_id\":\"session\"}"}}`))
|
||||
|
||||
if recorder.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("expected 500, got %d", recorder.Code)
|
||||
}
|
||||
if refreshRequests.Load() != 0 {
|
||||
t.Fatalf("expected no refresh request, got %d", refreshRequests.Load())
|
||||
}
|
||||
if !strings.Contains(recorder.Body.String(), "forbidden") {
|
||||
t.Fatalf("expected forbidden body, got %s", recorder.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceHandlerUsesReloadedTokenBeforeRefreshing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var messageRequests atomic.Int32
|
||||
var refreshRequests atomic.Int32
|
||||
var credentialPath string
|
||||
var credential *defaultCredential
|
||||
credential, credentialPath = newHandlerCredential(t, roundTripFunc(func(request *http.Request) (*http.Response, error) {
|
||||
switch request.URL.Path {
|
||||
case "/v1/messages":
|
||||
call := messageRequests.Add(1)
|
||||
if request.Header.Get("Authorization") == "Bearer old-token" {
|
||||
updatedCredentials := readTestCredentials(t, credentialPath)
|
||||
updatedCredentials.AccessToken = "disk-token"
|
||||
updatedCredentials.ExpiresAt = time.Now().Add(time.Hour).UnixMilli()
|
||||
writeTestCredentials(t, credentialPath, updatedCredentials)
|
||||
if call != 1 {
|
||||
t.Fatalf("unexpected old-token call count %d", call)
|
||||
}
|
||||
return newTextResponse(http.StatusUnauthorized, "unauthorized"), nil
|
||||
}
|
||||
if request.Header.Get("Authorization") != "Bearer disk-token" {
|
||||
t.Fatalf("expected disk token retry, got %q", request.Header.Get("Authorization"))
|
||||
}
|
||||
return newJSONResponse(http.StatusOK, `{}`), nil
|
||||
case "/v1/oauth/token":
|
||||
refreshRequests.Add(1)
|
||||
return newJSONResponse(http.StatusOK, `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`), nil
|
||||
default:
|
||||
t.Fatalf("unexpected path %s", request.URL.Path)
|
||||
}
|
||||
return nil, nil
|
||||
}))
|
||||
|
||||
service := newTestService(credential)
|
||||
recorder := httptest.NewRecorder()
|
||||
service.ServeHTTP(recorder, newMessageRequest(`{"model":"claude","messages":[],"metadata":{"user_id":"{\"session_id\":\"session\"}"}}`))
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", recorder.Code, recorder.Body.String())
|
||||
}
|
||||
if refreshRequests.Load() != 0 {
|
||||
t.Fatalf("expected zero refresh requests, got %d", refreshRequests.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceHandlerRetriesAuthRecoveryOnlyOnce(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var messageRequests atomic.Int32
|
||||
var refreshRequests atomic.Int32
|
||||
credential, _ := newHandlerCredential(t, roundTripFunc(func(request *http.Request) (*http.Response, error) {
|
||||
switch request.URL.Path {
|
||||
case "/v1/messages":
|
||||
messageRequests.Add(1)
|
||||
return newTextResponse(http.StatusUnauthorized, "still unauthorized"), nil
|
||||
case "/v1/oauth/token":
|
||||
refreshRequests.Add(1)
|
||||
return newJSONResponse(http.StatusOK, `{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`), nil
|
||||
default:
|
||||
t.Fatalf("unexpected path %s", request.URL.Path)
|
||||
}
|
||||
return nil, nil
|
||||
}))
|
||||
|
||||
service := newTestService(credential)
|
||||
recorder := httptest.NewRecorder()
|
||||
service.ServeHTTP(recorder, newMessageRequest(`{"model":"claude","messages":[],"metadata":{"user_id":"{\"session_id\":\"session\"}"}}`))
|
||||
|
||||
if recorder.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("expected 500, got %d", recorder.Code)
|
||||
}
|
||||
if messageRequests.Load() != 2 {
|
||||
t.Fatalf("expected exactly two upstream attempts, got %d", messageRequests.Load())
|
||||
}
|
||||
if refreshRequests.Load() != 1 {
|
||||
t.Fatalf("expected exactly one refresh request, got %d", refreshRequests.Load())
|
||||
}
|
||||
}
|
||||
115
service/ccm/service_json_test.go
Normal file
115
service/ccm/service_json_test.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
)
|
||||
|
||||
func TestWriteJSONErrorUsesAnthropicShape(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest(http.MethodGet, "/v1/messages", nil)
|
||||
request.Header.Set("Request-Id", "req_123")
|
||||
|
||||
writeJSONError(recorder, request, http.StatusBadRequest, "invalid_request_error", "broken")
|
||||
|
||||
if recorder.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", recorder.Code)
|
||||
}
|
||||
|
||||
var body anthropic.ErrorResponse
|
||||
if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if string(body.Type) != "error" {
|
||||
t.Fatalf("expected error type, got %q", body.Type)
|
||||
}
|
||||
if body.RequestID != "req_123" {
|
||||
t.Fatalf("expected req_123 request ID, got %q", body.RequestID)
|
||||
}
|
||||
if body.Error.Type != "invalid_request_error" {
|
||||
t.Fatalf("expected invalid_request_error, got %q", body.Error.Type)
|
||||
}
|
||||
if body.Error.Message != "broken" {
|
||||
t.Fatalf("expected broken message, got %q", body.Error.Message)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractCCMRequestMetadataFromMessagesJSONSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
metadata, err := extractCCMRequestMetadata("/v1/messages", []byte(`{
|
||||
"model":"claude-sonnet-4-5",
|
||||
"max_tokens":1,
|
||||
"messages":[{"role":"user","content":"hello"}],
|
||||
"metadata":{"user_id":"{\"session_id\":\"session-1\"}"}
|
||||
}`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if metadata.Model != "claude-sonnet-4-5" {
|
||||
t.Fatalf("expected model, got %#v", metadata)
|
||||
}
|
||||
if metadata.MessagesCount != 1 {
|
||||
t.Fatalf("expected one message, got %#v", metadata)
|
||||
}
|
||||
if metadata.SessionID != "session-1" {
|
||||
t.Fatalf("expected session-1, got %#v", metadata)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractCCMRequestMetadataFromMessagesLegacySession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
metadata, err := extractCCMRequestMetadata("/v1/messages", []byte(`{
|
||||
"model":"claude-sonnet-4-5",
|
||||
"max_tokens":1,
|
||||
"messages":[{"role":"user","content":"hello"}],
|
||||
"metadata":{"user_id":"user_device_account_account_session_session-legacy"}
|
||||
}`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if metadata.SessionID != "session-legacy" {
|
||||
t.Fatalf("expected session-legacy, got %#v", metadata)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractCCMRequestMetadataFromCountTokens(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
metadata, err := extractCCMRequestMetadata("/v1/messages/count_tokens", []byte(`{
|
||||
"model":"claude-sonnet-4-5",
|
||||
"messages":[{"role":"user","content":"hello"}]
|
||||
}`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if metadata.Model != "claude-sonnet-4-5" {
|
||||
t.Fatalf("expected model, got %#v", metadata)
|
||||
}
|
||||
if metadata.MessagesCount != 1 {
|
||||
t.Fatalf("expected one message, got %#v", metadata)
|
||||
}
|
||||
if metadata.SessionID != "" {
|
||||
t.Fatalf("expected empty session ID, got %#v", metadata)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractCCMRequestMetadataIgnoresUnsupportedPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
metadata, err := extractCCMRequestMetadata("/v1/models", []byte(`{"model":"claude"}`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if metadata != (ccmRequestMetadata{}) {
|
||||
t.Fatalf("expected zero metadata, got %#v", metadata)
|
||||
}
|
||||
}
|
||||
398
service/ccm/service_status.go
Normal file
398
service/ccm/service_status.go
Normal file
@@ -0,0 +1,398 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/option"
|
||||
)
|
||||
|
||||
type statusPayload struct {
|
||||
FiveHourUtilization float64 `json:"five_hour_utilization"`
|
||||
FiveHourReset int64 `json:"five_hour_reset"`
|
||||
WeeklyUtilization float64 `json:"weekly_utilization"`
|
||||
WeeklyReset int64 `json:"weekly_reset"`
|
||||
PlanWeight float64 `json:"plan_weight"`
|
||||
}
|
||||
|
||||
type aggregatedStatus struct {
|
||||
fiveHourUtilization float64
|
||||
weeklyUtilization float64
|
||||
totalWeight float64
|
||||
fiveHourReset time.Time
|
||||
weeklyReset time.Time
|
||||
availability availabilityStatus
|
||||
}
|
||||
|
||||
func resetToEpoch(t time.Time) int64 {
|
||||
if t.IsZero() {
|
||||
return 0
|
||||
}
|
||||
return t.Unix()
|
||||
}
|
||||
|
||||
func (s aggregatedStatus) equal(other aggregatedStatus) bool {
|
||||
return reflect.DeepEqual(s.toPayload(), other.toPayload())
|
||||
}
|
||||
|
||||
func (s aggregatedStatus) toPayload() statusPayload {
|
||||
return statusPayload{
|
||||
FiveHourUtilization: s.fiveHourUtilization,
|
||||
FiveHourReset: resetToEpoch(s.fiveHourReset),
|
||||
WeeklyUtilization: s.weeklyUtilization,
|
||||
WeeklyReset: resetToEpoch(s.weeklyReset),
|
||||
PlanWeight: s.totalWeight,
|
||||
}
|
||||
}
|
||||
|
||||
type aggregateInput struct {
|
||||
availability availabilityStatus
|
||||
}
|
||||
|
||||
func aggregateAvailability(inputs []aggregateInput) availabilityStatus {
|
||||
if len(inputs) == 0 {
|
||||
return availabilityStatus{
|
||||
State: availabilityStateUnavailable,
|
||||
Reason: availabilityReasonNoCredentials,
|
||||
}
|
||||
}
|
||||
var earliestRateLimit time.Time
|
||||
var hasRateLimited bool
|
||||
var blocked availabilityStatus
|
||||
var hasBlocked bool
|
||||
var hasUnavailable bool
|
||||
for _, input := range inputs {
|
||||
availability := input.availability.normalized()
|
||||
switch availability.State {
|
||||
case availabilityStateUsable:
|
||||
return availabilityStatus{State: availabilityStateUsable}
|
||||
case availabilityStateRateLimited:
|
||||
hasRateLimited = true
|
||||
if !availability.ResetAt.IsZero() && (earliestRateLimit.IsZero() || availability.ResetAt.Before(earliestRateLimit)) {
|
||||
earliestRateLimit = availability.ResetAt
|
||||
}
|
||||
if blocked.State == "" {
|
||||
blocked = availabilityStatus{
|
||||
State: availabilityStateRateLimited,
|
||||
Reason: availabilityReasonHardRateLimit,
|
||||
ResetAt: earliestRateLimit,
|
||||
}
|
||||
}
|
||||
case availabilityStateTemporarilyBlocked:
|
||||
if !hasBlocked {
|
||||
blocked = availability
|
||||
hasBlocked = true
|
||||
}
|
||||
if !availability.ResetAt.IsZero() && (blocked.ResetAt.IsZero() || availability.ResetAt.Before(blocked.ResetAt)) {
|
||||
blocked.ResetAt = availability.ResetAt
|
||||
}
|
||||
case availabilityStateUnavailable:
|
||||
hasUnavailable = true
|
||||
}
|
||||
}
|
||||
if hasRateLimited {
|
||||
blocked.ResetAt = earliestRateLimit
|
||||
return blocked
|
||||
}
|
||||
if hasBlocked {
|
||||
return blocked
|
||||
}
|
||||
if hasUnavailable {
|
||||
return availabilityStatus{
|
||||
State: availabilityStateUnavailable,
|
||||
Reason: availabilityReasonUnknown,
|
||||
}
|
||||
}
|
||||
return availabilityStatus{
|
||||
State: availabilityStateUnknown,
|
||||
Reason: availabilityReasonUnknown,
|
||||
}
|
||||
}
|
||||
|
||||
func chooseRepresentativeClaim(fiveHourUtilization float64, fiveHourReset time.Time, weeklyUtilization float64, weeklyReset time.Time, now time.Time) string {
|
||||
fiveHourWarning := claudeFiveHourWarning(fiveHourUtilization, fiveHourReset, now)
|
||||
weeklyWarning := claudeWeeklyWarning(weeklyUtilization, weeklyReset, now)
|
||||
type claimCandidate struct {
|
||||
name string
|
||||
priority int
|
||||
utilization float64
|
||||
}
|
||||
candidateFor := func(name string, utilization float64, warning bool) claimCandidate {
|
||||
priority := 0
|
||||
switch {
|
||||
case utilization >= 100:
|
||||
priority = 2
|
||||
case warning:
|
||||
priority = 1
|
||||
}
|
||||
return claimCandidate{name: name, priority: priority, utilization: utilization}
|
||||
}
|
||||
five := candidateFor("5h", fiveHourUtilization, fiveHourWarning)
|
||||
weekly := candidateFor("7d", weeklyUtilization, weeklyWarning)
|
||||
switch {
|
||||
case five.priority > weekly.priority:
|
||||
return five.name
|
||||
case weekly.priority > five.priority:
|
||||
return weekly.name
|
||||
case five.utilization > weekly.utilization:
|
||||
return five.name
|
||||
case weekly.utilization > five.utilization:
|
||||
return weekly.name
|
||||
case !fiveHourReset.IsZero():
|
||||
return five.name
|
||||
case !weeklyReset.IsZero():
|
||||
return weekly.name
|
||||
default:
|
||||
return "5h"
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
var provider credentialProvider
|
||||
var userConfig *option.CCMUser
|
||||
if len(s.options.Users) > 0 {
|
||||
if r.Header.Get("X-Api-Key") != "" || r.Header.Get("Api-Key") != "" {
|
||||
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
|
||||
"API key authentication is not supported; use Authorization: Bearer with a CCM user token")
|
||||
return
|
||||
}
|
||||
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
||||
return
|
||||
}
|
||||
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if clientToken == authHeader {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
||||
return
|
||||
}
|
||||
username, ok := s.userManager.Authenticate(clientToken)
|
||||
if !ok {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
|
||||
return
|
||||
}
|
||||
|
||||
userConfig = s.userConfigMap[username]
|
||||
var err error
|
||||
provider, err = credentialForUser(s.userConfigMap, s.providers, username)
|
||||
if err != nil {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
} else {
|
||||
provider = s.providers[s.options.Credentials[0].Tag]
|
||||
}
|
||||
if provider == nil {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available")
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Query().Get("watch") == "true" {
|
||||
s.handleStatusStream(w, r, provider, userConfig)
|
||||
return
|
||||
}
|
||||
|
||||
provider.pollIfStale()
|
||||
status := s.computeAggregatedUtilization(provider, userConfig)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(status.toPayload())
|
||||
}
|
||||
|
||||
func (s *Service) handleStatusStream(w http.ResponseWriter, r *http.Request, provider credentialProvider, userConfig *option.CCMUser) {
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "streaming not supported")
|
||||
return
|
||||
}
|
||||
|
||||
subscription, done, err := s.statusObserver.Subscribe()
|
||||
if err != nil {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "service closing")
|
||||
return
|
||||
}
|
||||
defer s.statusObserver.UnSubscribe(subscription)
|
||||
|
||||
provider.pollIfStale()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
last := s.computeAggregatedUtilization(provider, userConfig)
|
||||
buf := &bytes.Buffer{}
|
||||
json.NewEncoder(buf).Encode(last.toPayload())
|
||||
_, writeErr := w.Write(buf.Bytes())
|
||||
if writeErr != nil {
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
case <-done:
|
||||
return
|
||||
case <-subscription:
|
||||
for {
|
||||
select {
|
||||
case <-subscription:
|
||||
default:
|
||||
goto drained
|
||||
}
|
||||
}
|
||||
drained:
|
||||
current := s.computeAggregatedUtilization(provider, userConfig)
|
||||
if current.equal(last) {
|
||||
continue
|
||||
}
|
||||
last = current
|
||||
buf.Reset()
|
||||
json.NewEncoder(buf).Encode(current.toPayload())
|
||||
_, writeErr = w.Write(buf.Bytes())
|
||||
if writeErr != nil {
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) aggregatedStatus {
|
||||
visibleInputs := make([]aggregateInput, 0, len(provider.allCredentials()))
|
||||
var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64
|
||||
now := time.Now()
|
||||
var totalWeightedHoursUntil5hReset, total5hResetWeight float64
|
||||
var totalWeightedHoursUntilWeeklyReset, totalWeeklyResetWeight float64
|
||||
var hasSnapshotData bool
|
||||
for _, credential := range provider.allCredentials() {
|
||||
if userConfig != nil && userConfig.ExternalCredential != "" && credential.tagName() == userConfig.ExternalCredential {
|
||||
continue
|
||||
}
|
||||
if userConfig != nil && !userConfig.AllowExternalUsage && credential.isExternal() {
|
||||
continue
|
||||
}
|
||||
visibleInputs = append(visibleInputs, aggregateInput{
|
||||
availability: credential.availabilityStatus(),
|
||||
})
|
||||
if !credential.hasSnapshotData() {
|
||||
continue
|
||||
}
|
||||
hasSnapshotData = true
|
||||
weight := credential.planWeight()
|
||||
remaining5h := credential.fiveHourCap() - credential.fiveHourUtilization()
|
||||
if remaining5h < 0 {
|
||||
remaining5h = 0
|
||||
}
|
||||
remainingWeekly := credential.weeklyCap() - credential.weeklyUtilization()
|
||||
if remainingWeekly < 0 {
|
||||
remainingWeekly = 0
|
||||
}
|
||||
totalWeightedRemaining5h += remaining5h * weight
|
||||
totalWeightedRemainingWeekly += remainingWeekly * weight
|
||||
totalWeight += weight
|
||||
|
||||
fiveHourReset := credential.fiveHourResetTime()
|
||||
if !fiveHourReset.IsZero() {
|
||||
hours := fiveHourReset.Sub(now).Hours()
|
||||
if hours > 0 {
|
||||
totalWeightedHoursUntil5hReset += hours * weight
|
||||
total5hResetWeight += weight
|
||||
}
|
||||
}
|
||||
weeklyReset := credential.weeklyResetTime()
|
||||
if !weeklyReset.IsZero() {
|
||||
hours := weeklyReset.Sub(now).Hours()
|
||||
if hours > 0 {
|
||||
totalWeightedHoursUntilWeeklyReset += hours * weight
|
||||
totalWeeklyResetWeight += weight
|
||||
}
|
||||
}
|
||||
}
|
||||
availability := aggregateAvailability(visibleInputs)
|
||||
if totalWeight == 0 {
|
||||
result := aggregatedStatus{availability: availability}
|
||||
if !hasSnapshotData {
|
||||
result.fiveHourUtilization = 100
|
||||
result.weeklyUtilization = 100
|
||||
}
|
||||
return result
|
||||
}
|
||||
result := aggregatedStatus{
|
||||
fiveHourUtilization: 100 - totalWeightedRemaining5h/totalWeight,
|
||||
weeklyUtilization: 100 - totalWeightedRemainingWeekly/totalWeight,
|
||||
totalWeight: totalWeight,
|
||||
availability: availability,
|
||||
}
|
||||
if total5hResetWeight > 0 {
|
||||
avgHours := totalWeightedHoursUntil5hReset / total5hResetWeight
|
||||
result.fiveHourReset = now.Add(time.Duration(avgHours * float64(time.Hour)))
|
||||
}
|
||||
if totalWeeklyResetWeight > 0 {
|
||||
avgHours := totalWeightedHoursUntilWeeklyReset / totalWeeklyResetWeight
|
||||
result.weeklyReset = now.Add(time.Duration(avgHours * float64(time.Hour)))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (s *Service) rewriteResponseHeaders(headers http.Header, provider credentialProvider, userConfig *option.CCMUser) {
|
||||
for key := range headers {
|
||||
if strings.HasPrefix(strings.ToLower(key), "anthropic-ratelimit-unified-") {
|
||||
headers.Del(key)
|
||||
}
|
||||
}
|
||||
status := s.computeAggregatedUtilization(provider, userConfig)
|
||||
now := time.Now()
|
||||
headers.Set("anthropic-ratelimit-unified-5h-utilization", strconv.FormatFloat(status.fiveHourUtilization/100, 'f', 6, 64))
|
||||
headers.Set("anthropic-ratelimit-unified-7d-utilization", strconv.FormatFloat(status.weeklyUtilization/100, 'f', 6, 64))
|
||||
if !status.fiveHourReset.IsZero() {
|
||||
headers.Set("anthropic-ratelimit-unified-5h-reset", strconv.FormatInt(status.fiveHourReset.Unix(), 10))
|
||||
}
|
||||
if !status.weeklyReset.IsZero() {
|
||||
headers.Set("anthropic-ratelimit-unified-7d-reset", strconv.FormatInt(status.weeklyReset.Unix(), 10))
|
||||
}
|
||||
if status.totalWeight > 0 {
|
||||
headers.Set("X-CCM-Plan-Weight", strconv.FormatFloat(status.totalWeight, 'f', -1, 64))
|
||||
}
|
||||
fiveHourWarning := claudeFiveHourWarning(status.fiveHourUtilization, status.fiveHourReset, now)
|
||||
weeklyWarning := claudeWeeklyWarning(status.weeklyUtilization, status.weeklyReset, now)
|
||||
switch {
|
||||
case status.fiveHourUtilization >= 100 || status.weeklyUtilization >= 100 ||
|
||||
status.availability.State == availabilityStateRateLimited:
|
||||
headers.Set("anthropic-ratelimit-unified-status", "rejected")
|
||||
case fiveHourWarning || weeklyWarning:
|
||||
headers.Set("anthropic-ratelimit-unified-status", "allowed_warning")
|
||||
default:
|
||||
headers.Set("anthropic-ratelimit-unified-status", "allowed")
|
||||
}
|
||||
claim := chooseRepresentativeClaim(status.fiveHourUtilization, status.fiveHourReset, status.weeklyUtilization, status.weeklyReset, now)
|
||||
headers.Set("anthropic-ratelimit-unified-representative-claim", claim)
|
||||
switch claim {
|
||||
case "7d":
|
||||
if !status.weeklyReset.IsZero() {
|
||||
headers.Set("anthropic-ratelimit-unified-reset", strconv.FormatInt(status.weeklyReset.Unix(), 10))
|
||||
}
|
||||
default:
|
||||
if !status.fiveHourReset.IsZero() {
|
||||
headers.Set("anthropic-ratelimit-unified-reset", strconv.FormatInt(status.fiveHourReset.Unix(), 10))
|
||||
}
|
||||
}
|
||||
if fiveHourWarning || status.fiveHourUtilization >= 100 {
|
||||
headers.Set("anthropic-ratelimit-unified-5h-surpassed-threshold", "true")
|
||||
}
|
||||
if weeklyWarning || status.weeklyUtilization >= 100 {
|
||||
headers.Set("anthropic-ratelimit-unified-7d-surpassed-threshold", "true")
|
||||
}
|
||||
}
|
||||
234
service/ccm/service_status_test.go
Normal file
234
service/ccm/service_status_test.go
Normal file
@@ -0,0 +1,234 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/observable"
|
||||
)
|
||||
|
||||
type testCredential struct {
|
||||
tag string
|
||||
external bool
|
||||
available bool
|
||||
usable bool
|
||||
hasData bool
|
||||
fiveHour float64
|
||||
weekly float64
|
||||
fiveHourCapV float64
|
||||
weeklyCapV float64
|
||||
weight float64
|
||||
fiveReset time.Time
|
||||
weeklyReset time.Time
|
||||
availability availabilityStatus
|
||||
}
|
||||
|
||||
func (c *testCredential) tagName() string { return c.tag }
|
||||
func (c *testCredential) isAvailable() bool { return c.available }
|
||||
func (c *testCredential) isUsable() bool { return c.usable }
|
||||
func (c *testCredential) isExternal() bool { return c.external }
|
||||
func (c *testCredential) hasSnapshotData() bool { return c.hasData }
|
||||
func (c *testCredential) fiveHourUtilization() float64 { return c.fiveHour }
|
||||
func (c *testCredential) weeklyUtilization() float64 { return c.weekly }
|
||||
func (c *testCredential) fiveHourCap() float64 { return c.fiveHourCapV }
|
||||
func (c *testCredential) weeklyCap() float64 { return c.weeklyCapV }
|
||||
func (c *testCredential) planWeight() float64 { return c.weight }
|
||||
func (c *testCredential) fiveHourResetTime() time.Time { return c.fiveReset }
|
||||
func (c *testCredential) weeklyResetTime() time.Time { return c.weeklyReset }
|
||||
func (c *testCredential) markRateLimited(time.Time) {}
|
||||
func (c *testCredential) markUpstreamRejected() {}
|
||||
func (c *testCredential) availabilityStatus() availabilityStatus { return c.availability }
|
||||
func (c *testCredential) earliestReset() time.Time { return c.fiveReset }
|
||||
func (c *testCredential) unavailableError() error { return nil }
|
||||
func (c *testCredential) getAccessToken() (string, error) { return "", nil }
|
||||
func (c *testCredential) buildProxyRequest(context.Context, *http.Request, []byte, http.Header) (*http.Request, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (c *testCredential) updateStateFromHeaders(http.Header) {}
|
||||
func (c *testCredential) wrapRequestContext(context.Context) *credentialRequestContext { return nil }
|
||||
func (c *testCredential) interruptConnections() {}
|
||||
func (c *testCredential) setStatusSubscriber(*observable.Subscriber[struct{}]) {}
|
||||
func (c *testCredential) start() error { return nil }
|
||||
func (c *testCredential) pollUsage() {}
|
||||
func (c *testCredential) lastUpdatedTime() time.Time { return time.Now() }
|
||||
func (c *testCredential) pollBackoff(time.Duration) time.Duration { return 0 }
|
||||
func (c *testCredential) usageTrackerOrNil() *AggregatedUsage { return nil }
|
||||
func (c *testCredential) httpClient() *http.Client { return nil }
|
||||
func (c *testCredential) close() {}
|
||||
|
||||
type testProvider struct {
|
||||
credentials []Credential
|
||||
}
|
||||
|
||||
func (p *testProvider) selectCredential(string, credentialSelection) (Credential, bool, error) {
|
||||
return nil, false, nil
|
||||
}
|
||||
func (p *testProvider) onRateLimited(string, Credential, time.Time, credentialSelection) Credential {
|
||||
return nil
|
||||
}
|
||||
func (p *testProvider) linkProviderInterrupt(Credential, credentialSelection, func()) func() bool {
|
||||
return func() bool { return true }
|
||||
}
|
||||
func (p *testProvider) pollIfStale() {}
|
||||
func (p *testProvider) pollCredentialIfStale(Credential) {}
|
||||
func (p *testProvider) allCredentials() []Credential { return p.credentials }
|
||||
func (p *testProvider) close() {}
|
||||
|
||||
func TestComputeAggregatedUtilizationPreservesSnapshotForRateLimitedCredential(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
reset := time.Now().Add(15 * time.Minute)
|
||||
service := &Service{}
|
||||
status := service.computeAggregatedUtilization(&testProvider{credentials: []Credential{
|
||||
&testCredential{
|
||||
tag: "a",
|
||||
available: true,
|
||||
usable: false,
|
||||
hasData: true,
|
||||
fiveHour: 42,
|
||||
weekly: 18,
|
||||
fiveHourCapV: 100,
|
||||
weeklyCapV: 100,
|
||||
weight: 1,
|
||||
fiveReset: reset,
|
||||
weeklyReset: reset.Add(2 * time.Hour),
|
||||
availability: availabilityStatus{State: availabilityStateRateLimited, Reason: availabilityReasonHardRateLimit, ResetAt: reset},
|
||||
},
|
||||
}}, nil)
|
||||
|
||||
if status.fiveHourUtilization != 42 || status.weeklyUtilization != 18 {
|
||||
t.Fatalf("expected preserved utilization, got 5h=%v weekly=%v", status.fiveHourUtilization, status.weeklyUtilization)
|
||||
}
|
||||
if status.availability.State != availabilityStateRateLimited {
|
||||
t.Fatalf("expected rate-limited availability, got %#v", status.availability)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteResponseHeadersComputesUnifiedStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
reset := time.Now().Add(80 * time.Minute)
|
||||
service := &Service{}
|
||||
headers := make(http.Header)
|
||||
service.rewriteResponseHeaders(headers, &testProvider{credentials: []Credential{
|
||||
&testCredential{
|
||||
tag: "a",
|
||||
available: true,
|
||||
usable: true,
|
||||
hasData: true,
|
||||
fiveHour: 92,
|
||||
weekly: 30,
|
||||
fiveHourCapV: 100,
|
||||
weeklyCapV: 100,
|
||||
weight: 1,
|
||||
fiveReset: reset,
|
||||
weeklyReset: time.Now().Add(4 * 24 * time.Hour),
|
||||
availability: availabilityStatus{State: availabilityStateUsable},
|
||||
},
|
||||
}}, nil)
|
||||
|
||||
if headers.Get("anthropic-ratelimit-unified-status") != "allowed_warning" {
|
||||
t.Fatalf("expected allowed_warning, got %q", headers.Get("anthropic-ratelimit-unified-status"))
|
||||
}
|
||||
if headers.Get("anthropic-ratelimit-unified-representative-claim") != "5h" {
|
||||
t.Fatalf("expected 5h representative claim, got %q", headers.Get("anthropic-ratelimit-unified-representative-claim"))
|
||||
}
|
||||
if headers.Get("anthropic-ratelimit-unified-5h-surpassed-threshold") != "true" {
|
||||
t.Fatalf("expected 5h threshold header")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteResponseHeadersStripsUpstreamHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
service := &Service{}
|
||||
headers := make(http.Header)
|
||||
headers.Set("anthropic-ratelimit-unified-overage-status", "rejected")
|
||||
headers.Set("anthropic-ratelimit-unified-overage-disabled-reason", "org_level_disabled")
|
||||
headers.Set("anthropic-ratelimit-unified-fallback", "available")
|
||||
service.rewriteResponseHeaders(headers, &testProvider{credentials: []Credential{
|
||||
&testCredential{
|
||||
tag: "a",
|
||||
available: true,
|
||||
usable: true,
|
||||
hasData: true,
|
||||
fiveHour: 10,
|
||||
weekly: 5,
|
||||
fiveHourCapV: 100,
|
||||
weeklyCapV: 100,
|
||||
weight: 1,
|
||||
fiveReset: time.Now().Add(3 * time.Hour),
|
||||
weeklyReset: time.Now().Add(5 * 24 * time.Hour),
|
||||
availability: availabilityStatus{State: availabilityStateUsable},
|
||||
},
|
||||
}}, nil)
|
||||
|
||||
if headers.Get("anthropic-ratelimit-unified-overage-status") != "" {
|
||||
t.Fatalf("expected overage-status stripped, got %q", headers.Get("anthropic-ratelimit-unified-overage-status"))
|
||||
}
|
||||
if headers.Get("anthropic-ratelimit-unified-overage-disabled-reason") != "" {
|
||||
t.Fatalf("expected overage-disabled-reason stripped, got %q", headers.Get("anthropic-ratelimit-unified-overage-disabled-reason"))
|
||||
}
|
||||
if headers.Get("anthropic-ratelimit-unified-fallback") != "" {
|
||||
t.Fatalf("expected fallback stripped, got %q", headers.Get("anthropic-ratelimit-unified-fallback"))
|
||||
}
|
||||
if headers.Get("anthropic-ratelimit-unified-status") != "allowed" {
|
||||
t.Fatalf("expected allowed status, got %q", headers.Get("anthropic-ratelimit-unified-status"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteResponseHeadersRejectedOnHardRateLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
reset := time.Now().Add(10 * time.Minute)
|
||||
service := &Service{}
|
||||
headers := make(http.Header)
|
||||
service.rewriteResponseHeaders(headers, &testProvider{credentials: []Credential{
|
||||
&testCredential{
|
||||
tag: "a",
|
||||
available: true,
|
||||
usable: false,
|
||||
hasData: true,
|
||||
fiveHour: 50,
|
||||
weekly: 20,
|
||||
fiveHourCapV: 100,
|
||||
weeklyCapV: 100,
|
||||
weight: 1,
|
||||
fiveReset: reset,
|
||||
weeklyReset: time.Now().Add(5 * 24 * time.Hour),
|
||||
availability: availabilityStatus{State: availabilityStateRateLimited, Reason: availabilityReasonHardRateLimit, ResetAt: reset},
|
||||
},
|
||||
}}, nil)
|
||||
|
||||
if headers.Get("anthropic-ratelimit-unified-status") != "rejected" {
|
||||
t.Fatalf("expected rejected (hard rate limited), got %q", headers.Get("anthropic-ratelimit-unified-status"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteCredentialUnavailableErrorReturns429ForRateLimitedCredentials(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest(http.MethodGet, "/v1/messages", nil)
|
||||
provider := &testProvider{credentials: []Credential{
|
||||
&testCredential{
|
||||
tag: "a",
|
||||
available: true,
|
||||
usable: false,
|
||||
hasData: true,
|
||||
fiveHourCapV: 100,
|
||||
weeklyCapV: 100,
|
||||
weight: 1,
|
||||
availability: availabilityStatus{State: availabilityStateRateLimited, Reason: availabilityReasonHardRateLimit, ResetAt: time.Now().Add(time.Minute)},
|
||||
},
|
||||
}}
|
||||
|
||||
writeCredentialUnavailableError(recorder, request, provider, provider.credentials[0], credentialSelection{}, "all credentials rate-limited")
|
||||
|
||||
if recorder.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected 429, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
@@ -35,13 +35,13 @@ type CostCombination struct {
|
||||
type AggregatedUsage struct {
|
||||
LastUpdated time.Time `json:"last_updated"`
|
||||
Combinations []CostCombination `json:"combinations"`
|
||||
mutex sync.Mutex
|
||||
access sync.Mutex
|
||||
filePath string
|
||||
logger log.ContextLogger
|
||||
lastSaveTime time.Time
|
||||
pendingSave bool
|
||||
saveTimer *time.Timer
|
||||
saveMutex sync.Mutex
|
||||
saveAccess sync.Mutex
|
||||
}
|
||||
|
||||
type UsageStatsJSON struct {
|
||||
@@ -527,8 +527,8 @@ func deriveWeekStartUnix(cycleHint *WeeklyCycleHint) int64 {
|
||||
}
|
||||
|
||||
func (u *AggregatedUsage) ToJSON() *AggregatedUsageJSON {
|
||||
u.mutex.Lock()
|
||||
defer u.mutex.Unlock()
|
||||
u.access.Lock()
|
||||
defer u.access.Unlock()
|
||||
|
||||
result := &AggregatedUsageJSON{
|
||||
LastUpdated: u.LastUpdated,
|
||||
@@ -561,8 +561,8 @@ func (u *AggregatedUsage) ToJSON() *AggregatedUsageJSON {
|
||||
}
|
||||
|
||||
func (u *AggregatedUsage) Load() error {
|
||||
u.mutex.Lock()
|
||||
defer u.mutex.Unlock()
|
||||
u.access.Lock()
|
||||
defer u.access.Unlock()
|
||||
|
||||
u.LastUpdated = time.Time{}
|
||||
u.Combinations = nil
|
||||
@@ -608,9 +608,9 @@ func (u *AggregatedUsage) Save() error {
|
||||
defer os.Remove(tmpFile)
|
||||
err = os.Rename(tmpFile, u.filePath)
|
||||
if err == nil {
|
||||
u.saveMutex.Lock()
|
||||
u.saveAccess.Lock()
|
||||
u.lastSaveTime = time.Now()
|
||||
u.saveMutex.Unlock()
|
||||
u.saveAccess.Unlock()
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -644,15 +644,15 @@ func (u *AggregatedUsage) AddUsageWithCycleHint(
|
||||
observedAt = time.Now()
|
||||
}
|
||||
|
||||
u.mutex.Lock()
|
||||
defer u.mutex.Unlock()
|
||||
u.access.Lock()
|
||||
defer u.access.Unlock()
|
||||
|
||||
u.LastUpdated = observedAt
|
||||
weekStartUnix := deriveWeekStartUnix(cycleHint)
|
||||
|
||||
addUsageToCombinations(&u.Combinations, model, contextWindow, weekStartUnix, messagesCount, inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens, cacheCreation5MinuteTokens, cacheCreation1HourTokens, user)
|
||||
|
||||
go u.scheduleSave()
|
||||
u.scheduleSave()
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -660,8 +660,8 @@ func (u *AggregatedUsage) AddUsageWithCycleHint(
|
||||
func (u *AggregatedUsage) scheduleSave() {
|
||||
const saveInterval = time.Minute
|
||||
|
||||
u.saveMutex.Lock()
|
||||
defer u.saveMutex.Unlock()
|
||||
u.saveAccess.Lock()
|
||||
defer u.saveAccess.Unlock()
|
||||
|
||||
timeSinceLastSave := time.Since(u.lastSaveTime)
|
||||
|
||||
@@ -678,9 +678,9 @@ func (u *AggregatedUsage) scheduleSave() {
|
||||
remainingTime := saveInterval - timeSinceLastSave
|
||||
|
||||
u.saveTimer = time.AfterFunc(remainingTime, func() {
|
||||
u.saveMutex.Lock()
|
||||
u.saveAccess.Lock()
|
||||
u.pendingSave = false
|
||||
u.saveMutex.Unlock()
|
||||
u.saveAccess.Unlock()
|
||||
u.saveAsync()
|
||||
})
|
||||
}
|
||||
@@ -695,8 +695,8 @@ func (u *AggregatedUsage) saveAsync() {
|
||||
}
|
||||
|
||||
func (u *AggregatedUsage) cancelPendingSave() {
|
||||
u.saveMutex.Lock()
|
||||
defer u.saveMutex.Unlock()
|
||||
u.saveAccess.Lock()
|
||||
defer u.saveAccess.Unlock()
|
||||
|
||||
if u.saveTimer != nil {
|
||||
u.saveTimer.Stop()
|
||||
|
||||
@@ -7,13 +7,13 @@ import (
|
||||
)
|
||||
|
||||
type UserManager struct {
|
||||
accessMutex sync.RWMutex
|
||||
tokenMap map[string]string
|
||||
access sync.RWMutex
|
||||
tokenMap map[string]string
|
||||
}
|
||||
|
||||
func (m *UserManager) UpdateUsers(users []option.CCMUser) {
|
||||
m.accessMutex.Lock()
|
||||
defer m.accessMutex.Unlock()
|
||||
m.access.Lock()
|
||||
defer m.access.Unlock()
|
||||
tokenMap := make(map[string]string, len(users))
|
||||
for _, user := range users {
|
||||
tokenMap[user.Token] = user.Name
|
||||
@@ -22,8 +22,8 @@ func (m *UserManager) UpdateUsers(users []option.CCMUser) {
|
||||
}
|
||||
|
||||
func (m *UserManager) Authenticate(token string) (string, bool) {
|
||||
m.accessMutex.RLock()
|
||||
m.access.RLock()
|
||||
username, found := m.tokenMap[token]
|
||||
m.accessMutex.RUnlock()
|
||||
m.access.RUnlock()
|
||||
return username, found
|
||||
}
|
||||
|
||||
139
service/ccm/test_helpers_test.go
Normal file
139
service/ccm/test_helpers_test.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
)
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||
return f(request)
|
||||
}
|
||||
|
||||
func newJSONResponse(statusCode int, body string) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: statusCode,
|
||||
Status: http.StatusText(statusCode),
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
}
|
||||
}
|
||||
|
||||
func newTextResponse(statusCode int, body string) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: statusCode,
|
||||
Status: http.StatusText(statusCode),
|
||||
Header: http.Header{"Content-Type": []string{"text/plain"}},
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
}
|
||||
}
|
||||
|
||||
func writeTestCredentials(t *testing.T, path string, credentials *oauthCredentials) {
|
||||
t.Helper()
|
||||
if path == "" {
|
||||
var err error
|
||||
path, err = getDefaultCredentialsPath()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
if err := writeCredentialsToFile(credentials, path); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func readTestCredentials(t *testing.T, path string) *oauthCredentials {
|
||||
t.Helper()
|
||||
if path == "" {
|
||||
var err error
|
||||
path, err = getDefaultCredentialsPath()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
credentials, err := readCredentialsFromFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return credentials
|
||||
}
|
||||
|
||||
func newTestDefaultCredential(t *testing.T, credentialPath string, transport http.RoundTripper) *defaultCredential {
|
||||
t.Helper()
|
||||
credentialFilePath, err := resolveCredentialFilePath(credentialPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
requestContext, cancelRequests := context.WithCancel(context.Background())
|
||||
credential := &defaultCredential{
|
||||
tag: "test",
|
||||
serviceContext: context.Background(),
|
||||
credentialPath: credentialPath,
|
||||
credentialFilePath: credentialFilePath,
|
||||
configDir: resolveConfigDir(credentialPath, credentialFilePath),
|
||||
syncClaudeConfig: credentialPath == "",
|
||||
cap5h: 99,
|
||||
capWeekly: 99,
|
||||
forwardHTTPClient: &http.Client{Transport: transport},
|
||||
acquireLock: acquireCredentialLock,
|
||||
logger: log.NewNOPFactory().Logger(),
|
||||
requestContext: requestContext,
|
||||
cancelRequests: cancelRequests,
|
||||
}
|
||||
if credential.syncClaudeConfig {
|
||||
credential.claudeDirectory = credential.configDir
|
||||
credential.claudeConfigPath = resolveClaudeConfigWritePath(credential.claudeDirectory)
|
||||
}
|
||||
credential.state.lastUpdated = time.Now()
|
||||
return credential
|
||||
}
|
||||
|
||||
func seedTestCredentialState(credential *defaultCredential) {
|
||||
billingType := "individual"
|
||||
accountCreatedAt := "2024-01-01T00:00:00Z"
|
||||
subscriptionCreatedAt := "2024-01-02T00:00:00Z"
|
||||
credential.stateAccess.Lock()
|
||||
credential.state.accountUUID = "account"
|
||||
credential.state.accountType = "max"
|
||||
credential.state.rateLimitTier = "default_claude_max_20x"
|
||||
credential.state.oauthAccount = &claudeOAuthAccount{
|
||||
AccountUUID: "account",
|
||||
EmailAddress: "user@example.com",
|
||||
OrganizationUUID: "org",
|
||||
BillingType: &billingType,
|
||||
AccountCreatedAt: &accountCreatedAt,
|
||||
SubscriptionCreatedAt: &subscriptionCreatedAt,
|
||||
}
|
||||
credential.stateAccess.Unlock()
|
||||
}
|
||||
|
||||
func newTestService(credential *defaultCredential) *Service {
|
||||
return &Service{
|
||||
logger: log.NewNOPFactory().Logger(),
|
||||
options: option.CCMServiceOptions{Credentials: []option.CCMCredential{{Tag: "default"}}},
|
||||
httpHeaders: make(http.Header),
|
||||
providers: map[string]credentialProvider{"default": &singleCredentialProvider{credential: credential}},
|
||||
sessionModels: make(map[sessionModelKey]time.Time),
|
||||
}
|
||||
}
|
||||
|
||||
func newMessageRequest(body string) *http.Request {
|
||||
request := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
return request
|
||||
}
|
||||
|
||||
func tempConfigPath(t *testing.T, dir string) string {
|
||||
t.Helper()
|
||||
return filepath.Join(dir, claudeCodeLegacyConfigFileName())
|
||||
}
|
||||
7
service/ocm/CLAUDE.md
Normal file
7
service/ocm/CLAUDE.md
Normal file
@@ -0,0 +1,7 @@
|
||||
# OpenAI Codex Multiplexer
|
||||
|
||||
### Reverse Codex
|
||||
|
||||
Oh, Codex is just open source.
|
||||
|
||||
Clone it and study its code: https://github.com/openai/codex
|
||||
@@ -1,173 +1,275 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"context"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/observable"
|
||||
)
|
||||
|
||||
const (
|
||||
oauth2ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
oauth2TokenURL = "https://auth.openai.com/oauth/token"
|
||||
openaiAPIBaseURL = "https://api.openai.com"
|
||||
chatGPTBackendURL = "https://chatgpt.com/backend-api/codex"
|
||||
tokenRefreshIntervalDays = 8
|
||||
defaultPollInterval = 60 * time.Minute
|
||||
failedPollRetryInterval = time.Minute
|
||||
httpRetryMaxBackoff = 5 * time.Minute
|
||||
)
|
||||
|
||||
func getRealUser() (*user.User, error) {
|
||||
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
|
||||
sudoUserInfo, err := user.Lookup(sudoUser)
|
||||
const (
|
||||
httpRetryMaxAttempts = 3
|
||||
httpRetryInitialDelay = 200 * time.Millisecond
|
||||
)
|
||||
|
||||
const sessionExpiry = 24 * time.Hour
|
||||
|
||||
func doHTTPWithRetry(ctx context.Context, client *http.Client, buildRequest func() (*http.Request, error)) (*http.Response, error) {
|
||||
var lastError error
|
||||
for attempt := range httpRetryMaxAttempts {
|
||||
if attempt > 0 {
|
||||
delay := httpRetryInitialDelay * time.Duration(1<<(attempt-1))
|
||||
timer := time.NewTimer(delay)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return nil, lastError
|
||||
case <-timer.C:
|
||||
}
|
||||
}
|
||||
request, err := buildRequest()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
response, err := client.Do(request)
|
||||
if err == nil {
|
||||
return sudoUserInfo, nil
|
||||
return response, nil
|
||||
}
|
||||
lastError = err
|
||||
if ctx.Err() != nil {
|
||||
return nil, lastError
|
||||
}
|
||||
}
|
||||
return user.Current()
|
||||
return nil, lastError
|
||||
}
|
||||
|
||||
func getDefaultCredentialsPath() (string, error) {
|
||||
if codexHome := os.Getenv("CODEX_HOME"); codexHome != "" {
|
||||
return filepath.Join(codexHome, "auth.json"), nil
|
||||
}
|
||||
userInfo, err := getRealUser()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(userInfo.HomeDir, ".codex", "auth.json"), nil
|
||||
type credentialState struct {
|
||||
fiveHourUtilization float64
|
||||
fiveHourReset time.Time
|
||||
weeklyUtilization float64
|
||||
weeklyReset time.Time
|
||||
hardRateLimited bool
|
||||
rateLimitResetAt time.Time
|
||||
availabilityState availabilityState
|
||||
availabilityReason availabilityReason
|
||||
availabilityResetAt time.Time
|
||||
lastKnownDataAt time.Time
|
||||
accountType string
|
||||
remotePlanWeight float64
|
||||
activeLimitID string
|
||||
rateLimitSnapshots map[string]rateLimitSnapshot
|
||||
lastUpdated time.Time
|
||||
consecutivePollFailures int
|
||||
usageAPIRetryDelay time.Duration
|
||||
unavailable bool
|
||||
upstreamRejectedUntil time.Time
|
||||
lastCredentialLoadAttempt time.Time
|
||||
lastCredentialLoadError string
|
||||
}
|
||||
|
||||
func readCredentialsFromFile(path string) (*oauthCredentials, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var credentials oauthCredentials
|
||||
err = json.Unmarshal(data, &credentials)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &credentials, nil
|
||||
type credentialRequestContext struct {
|
||||
context.Context
|
||||
releaseOnce sync.Once
|
||||
cancelOnce sync.Once
|
||||
releaseFuncs []func() bool
|
||||
cancelFunc context.CancelFunc
|
||||
}
|
||||
|
||||
func writeCredentialsToFile(credentials *oauthCredentials, path string) error {
|
||||
data, err := json.MarshalIndent(credentials, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(path, data, 0o600)
|
||||
func (c *credentialRequestContext) addInterruptLink(stop func() bool) {
|
||||
c.releaseFuncs = append(c.releaseFuncs, stop)
|
||||
}
|
||||
|
||||
type oauthCredentials struct {
|
||||
APIKey string `json:"OPENAI_API_KEY,omitempty"`
|
||||
Tokens *tokenData `json:"tokens,omitempty"`
|
||||
LastRefresh *time.Time `json:"last_refresh,omitempty"`
|
||||
}
|
||||
|
||||
type tokenData struct {
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
AccountID string `json:"account_id,omitempty"`
|
||||
}
|
||||
|
||||
func (c *oauthCredentials) isAPIKeyMode() bool {
|
||||
return c.APIKey != ""
|
||||
}
|
||||
|
||||
func (c *oauthCredentials) getAccessToken() string {
|
||||
if c.APIKey != "" {
|
||||
return c.APIKey
|
||||
}
|
||||
if c.Tokens != nil {
|
||||
return c.Tokens.AccessToken
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *oauthCredentials) getAccountID() string {
|
||||
if c.Tokens != nil {
|
||||
return c.Tokens.AccountID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *oauthCredentials) needsRefresh() bool {
|
||||
if c.APIKey != "" {
|
||||
return false
|
||||
}
|
||||
if c.Tokens == nil || c.Tokens.RefreshToken == "" {
|
||||
return false
|
||||
}
|
||||
if c.LastRefresh == nil {
|
||||
return true
|
||||
}
|
||||
return time.Since(*c.LastRefresh) >= time.Duration(tokenRefreshIntervalDays)*24*time.Hour
|
||||
}
|
||||
|
||||
func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) {
|
||||
if credentials.Tokens == nil || credentials.Tokens.RefreshToken == "" {
|
||||
return nil, E.New("refresh token is empty")
|
||||
}
|
||||
|
||||
requestBody, err := json.Marshal(map[string]string{
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": credentials.Tokens.RefreshToken,
|
||||
"client_id": oauth2ClientID,
|
||||
"scope": "openid profile email",
|
||||
func (c *credentialRequestContext) releaseCredentialInterrupt() {
|
||||
c.releaseOnce.Do(func() {
|
||||
for _, f := range c.releaseFuncs {
|
||||
f()
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "marshal request")
|
||||
}
|
||||
|
||||
request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request.Header.Set("Accept", "application/json")
|
||||
|
||||
response, err := httpClient.Do(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return nil, E.New("refresh failed: ", response.Status, " ", string(body))
|
||||
}
|
||||
|
||||
var tokenResponse struct {
|
||||
IDToken string `json:"id_token"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
err = json.NewDecoder(response.Body).Decode(&tokenResponse)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode response")
|
||||
}
|
||||
|
||||
newCredentials := *credentials
|
||||
if newCredentials.Tokens == nil {
|
||||
newCredentials.Tokens = &tokenData{}
|
||||
}
|
||||
if tokenResponse.IDToken != "" {
|
||||
newCredentials.Tokens.IDToken = tokenResponse.IDToken
|
||||
}
|
||||
if tokenResponse.AccessToken != "" {
|
||||
newCredentials.Tokens.AccessToken = tokenResponse.AccessToken
|
||||
}
|
||||
if tokenResponse.RefreshToken != "" {
|
||||
newCredentials.Tokens.RefreshToken = tokenResponse.RefreshToken
|
||||
}
|
||||
now := time.Now()
|
||||
newCredentials.LastRefresh = &now
|
||||
|
||||
return &newCredentials, nil
|
||||
}
|
||||
|
||||
func (c *credentialRequestContext) cancelRequest() {
|
||||
c.releaseCredentialInterrupt()
|
||||
c.cancelOnce.Do(c.cancelFunc)
|
||||
}
|
||||
|
||||
type Credential interface {
|
||||
tagName() string
|
||||
isAvailable() bool
|
||||
isUsable() bool
|
||||
isExternal() bool
|
||||
hasSnapshotData() bool
|
||||
fiveHourUtilization() float64
|
||||
weeklyUtilization() float64
|
||||
fiveHourCap() float64
|
||||
weeklyCap() float64
|
||||
planWeight() float64
|
||||
weeklyResetTime() time.Time
|
||||
fiveHourResetTime() time.Time
|
||||
markRateLimited(resetAt time.Time)
|
||||
markUpstreamRejected()
|
||||
markTemporarilyBlocked(reason availabilityReason, resetAt time.Time)
|
||||
availabilityStatus() availabilityStatus
|
||||
earliestReset() time.Time
|
||||
unavailableError() error
|
||||
|
||||
getAccessToken() (string, error)
|
||||
buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error)
|
||||
updateStateFromHeaders(header http.Header)
|
||||
|
||||
wrapRequestContext(ctx context.Context) *credentialRequestContext
|
||||
interruptConnections()
|
||||
|
||||
setOnBecameUnusable(fn func())
|
||||
setStatusSubscriber(*observable.Subscriber[struct{}])
|
||||
start() error
|
||||
pollUsage()
|
||||
lastUpdatedTime() time.Time
|
||||
pollBackoff(base time.Duration) time.Duration
|
||||
usageTrackerOrNil() *AggregatedUsage
|
||||
httpClient() *http.Client
|
||||
close()
|
||||
|
||||
// OCM-specific
|
||||
ocmDialer() N.Dialer
|
||||
ocmIsAPIKeyMode() bool
|
||||
ocmGetAccountID() string
|
||||
ocmGetBaseURL() string
|
||||
}
|
||||
|
||||
type credentialSelectionScope string
|
||||
|
||||
const (
|
||||
credentialSelectionScopeAll credentialSelectionScope = "all"
|
||||
credentialSelectionScopeNonExternal credentialSelectionScope = "non_external"
|
||||
)
|
||||
|
||||
type credentialSelection struct {
|
||||
scope credentialSelectionScope
|
||||
filter func(Credential) bool
|
||||
}
|
||||
|
||||
func (s credentialSelection) allows(credential Credential) bool {
|
||||
return s.filter == nil || s.filter(credential)
|
||||
}
|
||||
|
||||
func (s credentialSelection) scopeOrDefault() credentialSelectionScope {
|
||||
if s.scope == "" {
|
||||
return credentialSelectionScopeAll
|
||||
}
|
||||
return s.scope
|
||||
}
|
||||
|
||||
func normalizeRateLimitIdentifier(limitIdentifier string) string {
|
||||
trimmedIdentifier := strings.TrimSpace(strings.ToLower(limitIdentifier))
|
||||
if trimmedIdentifier == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.ReplaceAll(trimmedIdentifier, "_", "-")
|
||||
}
|
||||
|
||||
func parseInt64Header(headers http.Header, headerName string) (int64, bool) {
|
||||
headerValue := strings.TrimSpace(headers.Get(headerName))
|
||||
if headerValue == "" {
|
||||
return 0, false
|
||||
}
|
||||
parsedValue, parseError := strconv.ParseInt(headerValue, 10, 64)
|
||||
if parseError != nil {
|
||||
return 0, false
|
||||
}
|
||||
return parsedValue, true
|
||||
}
|
||||
|
||||
func parseOCMRateLimitResetFromHeaders(headers http.Header) time.Time {
|
||||
activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
|
||||
if activeLimitIdentifier != "" {
|
||||
resetHeader := "x-" + activeLimitIdentifier + "-primary-reset-at"
|
||||
if resetStr := headers.Get(resetHeader); resetStr != "" {
|
||||
value, err := strconv.ParseInt(resetStr, 10, 64)
|
||||
if err == nil {
|
||||
return time.Unix(value, 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
if retryAfter := headers.Get("Retry-After"); retryAfter != "" {
|
||||
seconds, err := strconv.ParseInt(retryAfter, 10, 64)
|
||||
if err == nil {
|
||||
return time.Now().Add(time.Duration(seconds) * time.Second)
|
||||
}
|
||||
}
|
||||
return time.Now().Add(5 * time.Minute)
|
||||
}
|
||||
|
||||
func (s *credentialState) noteSnapshotData() {
|
||||
s.lastKnownDataAt = time.Now()
|
||||
}
|
||||
|
||||
func (s credentialState) hasSnapshotData() bool {
|
||||
return !s.lastKnownDataAt.IsZero() ||
|
||||
s.fiveHourUtilization > 0 ||
|
||||
s.weeklyUtilization > 0 ||
|
||||
!s.fiveHourReset.IsZero() ||
|
||||
!s.weeklyReset.IsZero() ||
|
||||
len(s.rateLimitSnapshots) > 0
|
||||
}
|
||||
|
||||
func (s *credentialState) setAvailability(state availabilityState, reason availabilityReason, resetAt time.Time) {
|
||||
s.availabilityState = state
|
||||
s.availabilityReason = reason
|
||||
s.availabilityResetAt = resetAt
|
||||
}
|
||||
|
||||
func (s credentialState) currentAvailability() availabilityStatus {
|
||||
now := time.Now()
|
||||
switch {
|
||||
case s.unavailable:
|
||||
return availabilityStatus{
|
||||
State: availabilityStateUnavailable,
|
||||
Reason: availabilityReasonUnknown,
|
||||
}
|
||||
case s.availabilityState == availabilityStateTemporarilyBlocked &&
|
||||
(s.availabilityResetAt.IsZero() || now.Before(s.availabilityResetAt)):
|
||||
reason := s.availabilityReason
|
||||
if reason == "" {
|
||||
reason = availabilityReasonUnknown
|
||||
}
|
||||
return availabilityStatus{
|
||||
State: availabilityStateTemporarilyBlocked,
|
||||
Reason: reason,
|
||||
ResetAt: s.availabilityResetAt,
|
||||
}
|
||||
case s.hardRateLimited && (s.rateLimitResetAt.IsZero() || now.Before(s.rateLimitResetAt)):
|
||||
reason := s.availabilityReason
|
||||
if reason == "" {
|
||||
reason = availabilityReasonHardRateLimit
|
||||
}
|
||||
return availabilityStatus{
|
||||
State: availabilityStateRateLimited,
|
||||
Reason: reason,
|
||||
ResetAt: s.rateLimitResetAt,
|
||||
}
|
||||
case !s.upstreamRejectedUntil.IsZero() && now.Before(s.upstreamRejectedUntil):
|
||||
return availabilityStatus{
|
||||
State: availabilityStateTemporarilyBlocked,
|
||||
Reason: availabilityReasonUpstreamRejected,
|
||||
ResetAt: s.upstreamRejectedUntil,
|
||||
}
|
||||
case s.consecutivePollFailures > 0:
|
||||
return availabilityStatus{
|
||||
State: availabilityStateTemporarilyBlocked,
|
||||
Reason: availabilityReasonPollFailed,
|
||||
}
|
||||
default:
|
||||
return availabilityStatus{State: availabilityStateUsable}
|
||||
}
|
||||
}
|
||||
|
||||
193
service/ocm/credential_builder.go
Normal file
193
service/ocm/credential_builder.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
func buildOCMCredentialProviders(
|
||||
ctx context.Context,
|
||||
options option.OCMServiceOptions,
|
||||
logger log.ContextLogger,
|
||||
) (map[string]credentialProvider, []Credential, error) {
|
||||
allCredentialMap := make(map[string]Credential)
|
||||
var allCredentials []Credential
|
||||
providers := make(map[string]credentialProvider)
|
||||
|
||||
// Pass 1: create default and external credentials
|
||||
for _, credentialOption := range options.Credentials {
|
||||
switch credentialOption.Type {
|
||||
case "default":
|
||||
credential, err := newDefaultCredential(ctx, credentialOption.Tag, credentialOption.DefaultOptions, logger)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
allCredentialMap[credentialOption.Tag] = credential
|
||||
allCredentials = append(allCredentials, credential)
|
||||
providers[credentialOption.Tag] = &singleCredentialProvider{credential: credential}
|
||||
case "external":
|
||||
credential, err := newExternalCredential(ctx, credentialOption.Tag, credentialOption.ExternalOptions, logger)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
allCredentialMap[credentialOption.Tag] = credential
|
||||
allCredentials = append(allCredentials, credential)
|
||||
providers[credentialOption.Tag] = &singleCredentialProvider{credential: credential}
|
||||
}
|
||||
}
|
||||
|
||||
// Pass 2: create balancer providers
|
||||
for _, credentialOption := range options.Credentials {
|
||||
if credentialOption.Type == "balancer" {
|
||||
subCredentials, err := resolveCredentialTags(credentialOption.BalancerOptions.Credentials, allCredentialMap, credentialOption.Tag)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
providers[credentialOption.Tag] = newBalancerProvider(subCredentials, credentialOption.BalancerOptions.Strategy, credentialOption.BalancerOptions.RebalanceThreshold, logger)
|
||||
}
|
||||
}
|
||||
|
||||
return providers, allCredentials, nil
|
||||
}
|
||||
|
||||
func resolveCredentialTags(tags []string, allCredentials map[string]Credential, parentTag string) ([]Credential, error) {
|
||||
credentials := make([]Credential, 0, len(tags))
|
||||
for _, tag := range tags {
|
||||
credential, exists := allCredentials[tag]
|
||||
if !exists {
|
||||
return nil, E.New("credential ", parentTag, " references unknown credential: ", tag)
|
||||
}
|
||||
credentials = append(credentials, credential)
|
||||
}
|
||||
if len(credentials) == 0 {
|
||||
return nil, E.New("credential ", parentTag, " has no sub-credentials")
|
||||
}
|
||||
return credentials, nil
|
||||
}
|
||||
|
||||
func validateOCMOptions(options option.OCMServiceOptions) error {
|
||||
tags := make(map[string]bool)
|
||||
credentialTypes := make(map[string]string)
|
||||
for _, credential := range options.Credentials {
|
||||
if tags[credential.Tag] {
|
||||
return E.New("duplicate credential tag: ", credential.Tag)
|
||||
}
|
||||
tags[credential.Tag] = true
|
||||
credentialTypes[credential.Tag] = credential.Type
|
||||
if credential.Type == "default" || credential.Type == "" {
|
||||
if credential.DefaultOptions.Reserve5h > 99 {
|
||||
return E.New("credential ", credential.Tag, ": reserve_5h must be at most 99")
|
||||
}
|
||||
if credential.DefaultOptions.ReserveWeekly > 99 {
|
||||
return E.New("credential ", credential.Tag, ": reserve_weekly must be at most 99")
|
||||
}
|
||||
if credential.DefaultOptions.Limit5h > 100 {
|
||||
return E.New("credential ", credential.Tag, ": limit_5h must be at most 100")
|
||||
}
|
||||
if credential.DefaultOptions.LimitWeekly > 100 {
|
||||
return E.New("credential ", credential.Tag, ": limit_weekly must be at most 100")
|
||||
}
|
||||
if credential.DefaultOptions.Reserve5h > 0 && credential.DefaultOptions.Limit5h > 0 {
|
||||
return E.New("credential ", credential.Tag, ": reserve_5h and limit_5h are mutually exclusive")
|
||||
}
|
||||
if credential.DefaultOptions.ReserveWeekly > 0 && credential.DefaultOptions.LimitWeekly > 0 {
|
||||
return E.New("credential ", credential.Tag, ": reserve_weekly and limit_weekly are mutually exclusive")
|
||||
}
|
||||
}
|
||||
if credential.Type == "external" {
|
||||
if credential.ExternalOptions.Token == "" {
|
||||
return E.New("credential ", credential.Tag, ": external credential requires token")
|
||||
}
|
||||
if credential.ExternalOptions.Reverse && credential.ExternalOptions.URL == "" {
|
||||
return E.New("credential ", credential.Tag, ": reverse external credential requires url")
|
||||
}
|
||||
}
|
||||
if credential.Type == "balancer" {
|
||||
switch credential.BalancerOptions.Strategy {
|
||||
case "", C.BalancerStrategyLeastUsed, C.BalancerStrategyRoundRobin, C.BalancerStrategyRandom, C.BalancerStrategyFallback:
|
||||
default:
|
||||
return E.New("credential ", credential.Tag, ": unknown balancer strategy: ", credential.BalancerOptions.Strategy)
|
||||
}
|
||||
if credential.BalancerOptions.RebalanceThreshold < 0 {
|
||||
return E.New("credential ", credential.Tag, ": rebalance_threshold must not be negative")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
singleCredential := len(options.Credentials) == 1
|
||||
for _, user := range options.Users {
|
||||
if user.Credential == "" && !singleCredential {
|
||||
return E.New("user ", user.Name, " must specify credential in multi-credential mode")
|
||||
}
|
||||
if user.Credential != "" && !tags[user.Credential] {
|
||||
return E.New("user ", user.Name, " references unknown credential: ", user.Credential)
|
||||
}
|
||||
if user.ExternalCredential != "" {
|
||||
if !tags[user.ExternalCredential] {
|
||||
return E.New("user ", user.Name, " references unknown external_credential: ", user.ExternalCredential)
|
||||
}
|
||||
if credentialTypes[user.ExternalCredential] != "external" {
|
||||
return E.New("user ", user.Name, ": external_credential must reference an external type credential")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateOCMCompositeCredentialModes(
|
||||
options option.OCMServiceOptions,
|
||||
providers map[string]credentialProvider,
|
||||
) error {
|
||||
for _, credentialOption := range options.Credentials {
|
||||
if credentialOption.Type != "balancer" {
|
||||
continue
|
||||
}
|
||||
|
||||
provider, exists := providers[credentialOption.Tag]
|
||||
if !exists {
|
||||
return E.New("unknown credential: ", credentialOption.Tag)
|
||||
}
|
||||
|
||||
for _, subCred := range provider.allCredentials() {
|
||||
if !subCred.isAvailable() {
|
||||
continue
|
||||
}
|
||||
if subCred.ocmIsAPIKeyMode() {
|
||||
return E.New(
|
||||
"credential ", credentialOption.Tag,
|
||||
" references API key default credential ", subCred.tagName(),
|
||||
"; balancer and fallback only support OAuth default credentials",
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func credentialForUser(
|
||||
userConfigMap map[string]*option.OCMUser,
|
||||
providers map[string]credentialProvider,
|
||||
username string,
|
||||
) (credentialProvider, error) {
|
||||
userConfig, exists := userConfigMap[username]
|
||||
if !exists {
|
||||
return nil, E.New("no credential mapping for user: ", username)
|
||||
}
|
||||
if userConfig.Credential == "" {
|
||||
for _, provider := range providers {
|
||||
return provider, nil
|
||||
}
|
||||
return nil, E.New("no credential available")
|
||||
}
|
||||
provider, exists := providers[userConfig.Credential]
|
||||
if !exists {
|
||||
return nil, E.New("unknown credential: ", userConfig.Credential)
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
@@ -13,6 +13,17 @@ func platformReadCredentials(customPath string) (*oauthCredentials, error) {
|
||||
return readCredentialsFromFile(customPath)
|
||||
}
|
||||
|
||||
func platformCanWriteCredentials(customPath string) error {
|
||||
if customPath == "" {
|
||||
var err error
|
||||
customPath, err = getDefaultCredentialsPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return checkCredentialFileWritable(customPath)
|
||||
}
|
||||
|
||||
func platformWriteCredentials(credentials *oauthCredentials, customPath string) error {
|
||||
if customPath == "" {
|
||||
var err error
|
||||
|
||||
826
service/ocm/credential_default.go
Normal file
826
service/ocm/credential_default.go
Normal file
@@ -0,0 +1,826 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
stdTLS "crypto/tls"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/fswatch"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
"github.com/sagernet/sing/common/observable"
|
||||
)
|
||||
|
||||
type defaultCredential struct {
|
||||
tag string
|
||||
serviceContext context.Context
|
||||
credentialPath string
|
||||
credentialFilePath string
|
||||
credentials *oauthCredentials
|
||||
access sync.RWMutex
|
||||
state credentialState
|
||||
stateAccess sync.RWMutex
|
||||
pollAccess sync.Mutex
|
||||
reloadAccess sync.Mutex
|
||||
watcherAccess sync.Mutex
|
||||
cap5h float64
|
||||
capWeekly float64
|
||||
usageTracker *AggregatedUsage
|
||||
dialer N.Dialer
|
||||
forwardHTTPClient *http.Client
|
||||
logger log.ContextLogger
|
||||
watcher *fswatch.Watcher
|
||||
watcherRetryAt time.Time
|
||||
|
||||
statusSubscriber *observable.Subscriber[struct{}]
|
||||
|
||||
// Refresh rate-limit cooldown (protected by access mutex)
|
||||
refreshRetryAt time.Time
|
||||
refreshRetryError error
|
||||
refreshBlocked bool
|
||||
|
||||
// Connection interruption
|
||||
onBecameUnusable func()
|
||||
interrupted bool
|
||||
requestContext context.Context
|
||||
cancelRequests context.CancelFunc
|
||||
requestAccess sync.Mutex
|
||||
}
|
||||
|
||||
func newDefaultCredential(ctx context.Context, tag string, options option.OCMDefaultCredentialOptions, logger log.ContextLogger) (*defaultCredential, error) {
|
||||
credentialDialer, err := dialer.NewWithOptions(dialer.Options{
|
||||
Context: ctx,
|
||||
Options: option.DialerOptions{
|
||||
Detour: options.Detour,
|
||||
},
|
||||
RemoteIsDomain: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "create dialer for credential ", tag)
|
||||
}
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
ForceAttemptHTTP2: true,
|
||||
TLSClientConfig: &stdTLS.Config{
|
||||
RootCAs: adapter.RootPoolFromContext(ctx),
|
||||
Time: ntp.TimeFuncFromContext(ctx),
|
||||
},
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
},
|
||||
},
|
||||
}
|
||||
reserve5h := options.Reserve5h
|
||||
if reserve5h == 0 {
|
||||
reserve5h = 1
|
||||
}
|
||||
reserveWeekly := options.ReserveWeekly
|
||||
if reserveWeekly == 0 {
|
||||
reserveWeekly = 1
|
||||
}
|
||||
var cap5h float64
|
||||
if options.Limit5h > 0 {
|
||||
cap5h = float64(options.Limit5h)
|
||||
} else {
|
||||
cap5h = float64(100 - reserve5h)
|
||||
}
|
||||
var capWeekly float64
|
||||
if options.LimitWeekly > 0 {
|
||||
capWeekly = float64(options.LimitWeekly)
|
||||
} else {
|
||||
capWeekly = float64(100 - reserveWeekly)
|
||||
}
|
||||
requestContext, cancelRequests := context.WithCancel(context.Background())
|
||||
credential := &defaultCredential{
|
||||
tag: tag,
|
||||
serviceContext: ctx,
|
||||
credentialPath: options.CredentialPath,
|
||||
cap5h: cap5h,
|
||||
capWeekly: capWeekly,
|
||||
dialer: credentialDialer,
|
||||
forwardHTTPClient: httpClient,
|
||||
logger: logger,
|
||||
requestContext: requestContext,
|
||||
cancelRequests: cancelRequests,
|
||||
}
|
||||
if options.UsagesPath != "" {
|
||||
credential.usageTracker = &AggregatedUsage{
|
||||
LastUpdated: time.Now(),
|
||||
Combinations: make([]CostCombination, 0),
|
||||
filePath: options.UsagesPath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
return credential, nil
|
||||
}
|
||||
|
||||
func (c *defaultCredential) start() error {
|
||||
credentialFilePath, err := resolveCredentialFilePath(c.credentialPath)
|
||||
if err != nil {
|
||||
return E.Cause(err, "resolve credential path for ", c.tag)
|
||||
}
|
||||
c.credentialFilePath = credentialFilePath
|
||||
err = c.ensureCredentialWatcher()
|
||||
if err != nil {
|
||||
c.logger.Debug("start credential watcher for ", c.tag, ": ", err)
|
||||
}
|
||||
err = c.reloadCredentials(true)
|
||||
if err != nil {
|
||||
c.logger.Warn("initial credential load for ", c.tag, ": ", err)
|
||||
}
|
||||
if c.usageTracker != nil {
|
||||
err = c.usageTracker.Load()
|
||||
if err != nil {
|
||||
c.logger.Warn("load usage statistics for ", c.tag, ": ", err)
|
||||
}
|
||||
}
|
||||
go c.pollUsage()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *defaultCredential) setOnBecameUnusable(fn func()) {
|
||||
c.onBecameUnusable = fn
|
||||
}
|
||||
|
||||
func (c *defaultCredential) setStatusSubscriber(subscriber *observable.Subscriber[struct{}]) {
|
||||
c.statusSubscriber = subscriber
|
||||
}
|
||||
|
||||
func (c *defaultCredential) emitStatusUpdate() {
|
||||
if c.statusSubscriber != nil {
|
||||
c.statusSubscriber.Emit(struct{}{})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *defaultCredential) tagName() string {
|
||||
return c.tag
|
||||
}
|
||||
|
||||
func (c *defaultCredential) isExternal() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *defaultCredential) getAccessToken() (string, error) {
|
||||
c.retryCredentialReloadIfNeeded()
|
||||
|
||||
c.access.RLock()
|
||||
if c.credentials != nil && !c.credentials.needsRefresh() {
|
||||
token := c.credentials.getAccessToken()
|
||||
c.access.RUnlock()
|
||||
return token, nil
|
||||
}
|
||||
c.access.RUnlock()
|
||||
|
||||
err := c.reloadCredentials(true)
|
||||
if err == nil {
|
||||
c.access.RLock()
|
||||
if c.credentials != nil && !c.credentials.needsRefresh() {
|
||||
token := c.credentials.getAccessToken()
|
||||
c.access.RUnlock()
|
||||
return token, nil
|
||||
}
|
||||
c.access.RUnlock()
|
||||
}
|
||||
|
||||
c.access.Lock()
|
||||
defer c.access.Unlock()
|
||||
|
||||
if c.credentials == nil {
|
||||
return "", c.unavailableError()
|
||||
}
|
||||
if !c.credentials.needsRefresh() {
|
||||
return c.credentials.getAccessToken(), nil
|
||||
}
|
||||
|
||||
if c.refreshBlocked {
|
||||
return "", c.refreshRetryError
|
||||
}
|
||||
if !c.refreshRetryAt.IsZero() && time.Now().Before(c.refreshRetryAt) {
|
||||
return "", c.refreshRetryError
|
||||
}
|
||||
|
||||
err = platformCanWriteCredentials(c.credentialPath)
|
||||
if err != nil {
|
||||
return "", E.Cause(err, "credential file not writable, refusing refresh to avoid invalidation")
|
||||
}
|
||||
|
||||
baseCredentials := cloneCredentials(c.credentials)
|
||||
newCredentials, retryDelay, err := refreshToken(c.serviceContext, c.forwardHTTPClient, c.credentials)
|
||||
if err != nil {
|
||||
if retryDelay < 0 {
|
||||
c.refreshBlocked = true
|
||||
c.refreshRetryError = err
|
||||
} else if retryDelay > 0 {
|
||||
c.refreshRetryAt = time.Now().Add(retryDelay)
|
||||
c.refreshRetryError = err
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
c.refreshRetryAt = time.Time{}
|
||||
c.refreshRetryError = nil
|
||||
c.refreshBlocked = false
|
||||
|
||||
latestCredentials, latestErr := platformReadCredentials(c.credentialPath)
|
||||
if latestErr == nil && !credentialsEqual(latestCredentials, baseCredentials) {
|
||||
c.credentials = latestCredentials
|
||||
c.stateAccess.Lock()
|
||||
wasAvailable := !c.state.unavailable
|
||||
c.state.unavailable = false
|
||||
c.state.lastCredentialLoadAttempt = time.Now()
|
||||
c.state.lastCredentialLoadError = ""
|
||||
c.checkTransitionLocked()
|
||||
shouldEmit := wasAvailable != !c.state.unavailable
|
||||
c.stateAccess.Unlock()
|
||||
if shouldEmit {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
if !latestCredentials.needsRefresh() {
|
||||
return latestCredentials.getAccessToken(), nil
|
||||
}
|
||||
return "", E.New("credential ", c.tag, " changed while refreshing")
|
||||
}
|
||||
|
||||
c.credentials = newCredentials
|
||||
c.stateAccess.Lock()
|
||||
wasAvailable := !c.state.unavailable
|
||||
c.state.unavailable = false
|
||||
c.state.lastCredentialLoadAttempt = time.Now()
|
||||
c.state.lastCredentialLoadError = ""
|
||||
c.checkTransitionLocked()
|
||||
shouldEmit := wasAvailable != !c.state.unavailable
|
||||
c.stateAccess.Unlock()
|
||||
if shouldEmit {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
|
||||
err = platformWriteCredentials(newCredentials, c.credentialPath)
|
||||
if err != nil {
|
||||
c.logger.Error("persist refreshed token for ", c.tag, ": ", err)
|
||||
}
|
||||
|
||||
return newCredentials.getAccessToken(), nil
|
||||
}
|
||||
|
||||
func (c *defaultCredential) getAccountID() string {
|
||||
c.access.RLock()
|
||||
defer c.access.RUnlock()
|
||||
if c.credentials == nil {
|
||||
return ""
|
||||
}
|
||||
return c.credentials.getAccountID()
|
||||
}
|
||||
|
||||
func (c *defaultCredential) isAPIKeyMode() bool {
|
||||
c.access.RLock()
|
||||
defer c.access.RUnlock()
|
||||
if c.credentials == nil {
|
||||
return false
|
||||
}
|
||||
return c.credentials.isAPIKeyMode()
|
||||
}
|
||||
|
||||
func (c *defaultCredential) getBaseURL() string {
|
||||
if c.isAPIKeyMode() {
|
||||
return openaiAPIBaseURL
|
||||
}
|
||||
return chatGPTBackendURL
|
||||
}
|
||||
|
||||
func (c *defaultCredential) updateStateFromHeaders(headers http.Header) {
|
||||
c.stateAccess.Lock()
|
||||
isFirstUpdate := c.state.lastUpdated.IsZero()
|
||||
oldFiveHour := c.state.fiveHourUtilization
|
||||
oldWeekly := c.state.weeklyUtilization
|
||||
hadData := false
|
||||
|
||||
activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
|
||||
if activeLimitIdentifier == "" {
|
||||
activeLimitIdentifier = "codex"
|
||||
}
|
||||
|
||||
fiveHourResetChanged := false
|
||||
fiveHourResetAt := headers.Get("x-" + activeLimitIdentifier + "-primary-reset-at")
|
||||
if fiveHourResetAt != "" {
|
||||
value, err := strconv.ParseInt(fiveHourResetAt, 10, 64)
|
||||
if err == nil {
|
||||
hadData = true
|
||||
newReset := time.Unix(value, 0)
|
||||
if newReset.After(c.state.fiveHourReset) {
|
||||
fiveHourResetChanged = true
|
||||
c.state.fiveHourReset = newReset
|
||||
}
|
||||
}
|
||||
}
|
||||
fiveHourPercent := headers.Get("x-" + activeLimitIdentifier + "-primary-used-percent")
|
||||
if fiveHourPercent != "" {
|
||||
value, err := strconv.ParseFloat(fiveHourPercent, 64)
|
||||
if err == nil {
|
||||
hadData = true
|
||||
if value >= c.state.fiveHourUtilization || fiveHourResetChanged {
|
||||
c.state.fiveHourUtilization = value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
weeklyResetChanged := false
|
||||
weeklyResetAt := headers.Get("x-" + activeLimitIdentifier + "-secondary-reset-at")
|
||||
if weeklyResetAt != "" {
|
||||
value, err := strconv.ParseInt(weeklyResetAt, 10, 64)
|
||||
if err == nil {
|
||||
hadData = true
|
||||
newReset := time.Unix(value, 0)
|
||||
if newReset.After(c.state.weeklyReset) {
|
||||
weeklyResetChanged = true
|
||||
c.state.weeklyReset = newReset
|
||||
}
|
||||
}
|
||||
}
|
||||
weeklyPercent := headers.Get("x-" + activeLimitIdentifier + "-secondary-used-percent")
|
||||
if weeklyPercent != "" {
|
||||
value, err := strconv.ParseFloat(weeklyPercent, 64)
|
||||
if err == nil {
|
||||
hadData = true
|
||||
if value >= c.state.weeklyUtilization || weeklyResetChanged {
|
||||
c.state.weeklyUtilization = value
|
||||
}
|
||||
}
|
||||
}
|
||||
if snapshots := parseRateLimitSnapshotsFromHeaders(headers); len(snapshots) > 0 {
|
||||
hadData = true
|
||||
applyRateLimitSnapshotsLocked(&c.state, snapshots, headers.Get("x-codex-active-limit"), c.state.remotePlanWeight, c.state.accountType)
|
||||
}
|
||||
if hadData {
|
||||
c.state.consecutivePollFailures = 0
|
||||
c.state.lastUpdated = time.Now()
|
||||
c.state.noteSnapshotData()
|
||||
}
|
||||
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
|
||||
resetSuffix := ""
|
||||
if !c.state.weeklyReset.IsZero() {
|
||||
resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset))
|
||||
}
|
||||
c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
|
||||
}
|
||||
shouldEmit := hadData && (c.state.fiveHourUtilization != oldFiveHour || c.state.weeklyUtilization != oldWeekly || fiveHourResetChanged || weeklyResetChanged)
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
if shouldEmit {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *defaultCredential) markRateLimited(resetAt time.Time) {
|
||||
c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt)))
|
||||
c.stateAccess.Lock()
|
||||
c.state.hardRateLimited = true
|
||||
c.state.rateLimitResetAt = resetAt
|
||||
c.state.setAvailability(availabilityStateRateLimited, availabilityReasonHardRateLimit, resetAt)
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
|
||||
func (c *defaultCredential) markUpstreamRejected() {}
|
||||
|
||||
func (c *defaultCredential) markTemporarilyBlocked(reason availabilityReason, resetAt time.Time) {
|
||||
c.stateAccess.Lock()
|
||||
c.state.setAvailability(availabilityStateTemporarilyBlocked, reason, resetAt)
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
|
||||
func (c *defaultCredential) isUsable() bool {
|
||||
c.retryCredentialReloadIfNeeded()
|
||||
|
||||
c.stateAccess.RLock()
|
||||
if c.state.unavailable {
|
||||
c.stateAccess.RUnlock()
|
||||
return false
|
||||
}
|
||||
if c.state.consecutivePollFailures > 0 {
|
||||
c.stateAccess.RUnlock()
|
||||
return false
|
||||
}
|
||||
if c.state.hardRateLimited {
|
||||
if time.Now().Before(c.state.rateLimitResetAt) {
|
||||
c.stateAccess.RUnlock()
|
||||
return false
|
||||
}
|
||||
c.stateAccess.RUnlock()
|
||||
c.stateAccess.Lock()
|
||||
if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) {
|
||||
c.state.hardRateLimited = false
|
||||
}
|
||||
usable := c.checkReservesLocked()
|
||||
c.stateAccess.Unlock()
|
||||
return usable
|
||||
}
|
||||
usable := c.checkReservesLocked()
|
||||
c.stateAccess.RUnlock()
|
||||
return usable
|
||||
}
|
||||
|
||||
func (c *defaultCredential) checkReservesLocked() bool {
|
||||
if c.state.fiveHourUtilization >= c.cap5h {
|
||||
return false
|
||||
}
|
||||
if c.state.weeklyUtilization >= c.capWeekly {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// checkTransitionLocked detects usable->unusable transition.
|
||||
// Must be called with stateAccess write lock held.
|
||||
func (c *defaultCredential) checkTransitionLocked() bool {
|
||||
unusable := c.state.unavailable || c.state.hardRateLimited || !c.checkReservesLocked() || c.state.consecutivePollFailures > 0
|
||||
if unusable && !c.interrupted {
|
||||
c.interrupted = true
|
||||
return true
|
||||
}
|
||||
if !unusable && c.interrupted {
|
||||
c.interrupted = false
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *defaultCredential) interruptConnections() {
|
||||
c.logger.Warn("interrupting connections for ", c.tag)
|
||||
c.requestAccess.Lock()
|
||||
c.cancelRequests()
|
||||
c.requestContext, c.cancelRequests = context.WithCancel(context.Background())
|
||||
c.requestAccess.Unlock()
|
||||
if c.onBecameUnusable != nil {
|
||||
c.onBecameUnusable()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *defaultCredential) wrapRequestContext(parent context.Context) *credentialRequestContext {
|
||||
c.requestAccess.Lock()
|
||||
credentialContext := c.requestContext
|
||||
c.requestAccess.Unlock()
|
||||
derived, cancel := context.WithCancel(parent)
|
||||
stop := context.AfterFunc(credentialContext, func() {
|
||||
cancel()
|
||||
})
|
||||
return &credentialRequestContext{
|
||||
Context: derived,
|
||||
releaseFuncs: []func() bool{stop},
|
||||
cancelFunc: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *defaultCredential) fiveHourUtilization() float64 {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.fiveHourUtilization
|
||||
}
|
||||
|
||||
func (c *defaultCredential) hasSnapshotData() bool {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.hasSnapshotData()
|
||||
}
|
||||
|
||||
func (c *defaultCredential) weeklyUtilization() float64 {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.weeklyUtilization
|
||||
}
|
||||
|
||||
func (c *defaultCredential) planWeight() float64 {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return ocmPlanWeight(c.state.accountType)
|
||||
}
|
||||
|
||||
func (c *defaultCredential) weeklyResetTime() time.Time {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.weeklyReset
|
||||
}
|
||||
|
||||
func (c *defaultCredential) fiveHourResetTime() time.Time {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.fiveHourReset
|
||||
}
|
||||
|
||||
func (c *defaultCredential) isAvailable() bool {
|
||||
c.retryCredentialReloadIfNeeded()
|
||||
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return !c.state.unavailable
|
||||
}
|
||||
|
||||
func (c *defaultCredential) availabilityStatus() availabilityStatus {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.currentAvailability()
|
||||
}
|
||||
|
||||
func (c *defaultCredential) unavailableError() error {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
if !c.state.unavailable {
|
||||
return nil
|
||||
}
|
||||
if c.state.lastCredentialLoadError == "" {
|
||||
return E.New("credential ", c.tag, " is unavailable")
|
||||
}
|
||||
return E.New("credential ", c.tag, " is unavailable: ", c.state.lastCredentialLoadError)
|
||||
}
|
||||
|
||||
func (c *defaultCredential) lastUpdatedTime() time.Time {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
return c.state.lastUpdated
|
||||
}
|
||||
|
||||
func (c *defaultCredential) markUsagePollAttempted() {
|
||||
c.stateAccess.Lock()
|
||||
defer c.stateAccess.Unlock()
|
||||
c.state.lastUpdated = time.Now()
|
||||
}
|
||||
|
||||
func (c *defaultCredential) incrementPollFailures() {
|
||||
c.stateAccess.Lock()
|
||||
c.state.consecutivePollFailures++
|
||||
c.state.setAvailability(availabilityStateTemporarilyBlocked, availabilityReasonPollFailed, time.Time{})
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *defaultCredential) pollBackoff(baseInterval time.Duration) time.Duration {
|
||||
c.stateAccess.RLock()
|
||||
failures := c.state.consecutivePollFailures
|
||||
retryDelay := c.state.usageAPIRetryDelay
|
||||
c.stateAccess.RUnlock()
|
||||
if failures <= 0 {
|
||||
if retryDelay > 0 {
|
||||
return retryDelay
|
||||
}
|
||||
return baseInterval
|
||||
}
|
||||
backoff := failedPollRetryInterval * time.Duration(1<<(failures-1))
|
||||
if backoff > httpRetryMaxBackoff {
|
||||
return httpRetryMaxBackoff
|
||||
}
|
||||
return backoff
|
||||
}
|
||||
|
||||
func (c *defaultCredential) isPollBackoffAtCap() bool {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
failures := c.state.consecutivePollFailures
|
||||
return failures > 0 && failedPollRetryInterval*time.Duration(1<<(failures-1)) >= httpRetryMaxBackoff
|
||||
}
|
||||
|
||||
func (c *defaultCredential) earliestReset() time.Time {
|
||||
c.stateAccess.RLock()
|
||||
defer c.stateAccess.RUnlock()
|
||||
if c.state.unavailable {
|
||||
return time.Time{}
|
||||
}
|
||||
if c.state.hardRateLimited {
|
||||
return c.state.rateLimitResetAt
|
||||
}
|
||||
earliest := c.state.fiveHourReset
|
||||
if !c.state.weeklyReset.IsZero() && (earliest.IsZero() || c.state.weeklyReset.Before(earliest)) {
|
||||
earliest = c.state.weeklyReset
|
||||
}
|
||||
return earliest
|
||||
}
|
||||
|
||||
func (c *defaultCredential) fiveHourCap() float64 {
|
||||
return c.cap5h
|
||||
}
|
||||
|
||||
func (c *defaultCredential) weeklyCap() float64 {
|
||||
return c.capWeekly
|
||||
}
|
||||
|
||||
func (c *defaultCredential) usageTrackerOrNil() *AggregatedUsage {
|
||||
return c.usageTracker
|
||||
}
|
||||
|
||||
func (c *defaultCredential) httpClient() *http.Client {
|
||||
return c.forwardHTTPClient
|
||||
}
|
||||
|
||||
func (c *defaultCredential) ocmDialer() N.Dialer {
|
||||
return c.dialer
|
||||
}
|
||||
|
||||
func (c *defaultCredential) ocmIsAPIKeyMode() bool {
|
||||
return c.isAPIKeyMode()
|
||||
}
|
||||
|
||||
func (c *defaultCredential) ocmGetAccountID() string {
|
||||
return c.getAccountID()
|
||||
}
|
||||
|
||||
func (c *defaultCredential) ocmGetBaseURL() string {
|
||||
return c.getBaseURL()
|
||||
}
|
||||
|
||||
func (c *defaultCredential) pollUsage() {
|
||||
if !c.pollAccess.TryLock() {
|
||||
return
|
||||
}
|
||||
defer c.pollAccess.Unlock()
|
||||
defer c.markUsagePollAttempted()
|
||||
|
||||
c.retryCredentialReloadIfNeeded()
|
||||
if !c.isAvailable() {
|
||||
return
|
||||
}
|
||||
if c.isAPIKeyMode() {
|
||||
return
|
||||
}
|
||||
|
||||
accessToken, err := c.getAccessToken()
|
||||
if err != nil {
|
||||
if !c.isPollBackoffAtCap() {
|
||||
c.logger.Error("poll usage for ", c.tag, ": get token: ", err)
|
||||
}
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.serviceContext
|
||||
usageURL := strings.TrimSuffix(chatGPTBackendURL, "/codex") + "/wham/usage"
|
||||
|
||||
accountID := c.getAccountID()
|
||||
pollClient := &http.Client{
|
||||
Transport: c.forwardHTTPClient.Transport,
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
response, err := doHTTPWithRetry(ctx, pollClient, func() (*http.Request, error) {
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, usageURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
if accountID != "" {
|
||||
request.Header.Set("ChatGPT-Account-Id", accountID)
|
||||
}
|
||||
return request, nil
|
||||
})
|
||||
if err != nil {
|
||||
if !c.isPollBackoffAtCap() {
|
||||
c.logger.Error("poll usage for ", c.tag, ": ", err)
|
||||
}
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
if response.StatusCode == http.StatusTooManyRequests {
|
||||
retryDelay := time.Minute
|
||||
if retryAfter := response.Header.Get("Retry-After"); retryAfter != "" {
|
||||
seconds, err := strconv.ParseInt(retryAfter, 10, 64)
|
||||
if err == nil && seconds > 0 {
|
||||
retryDelay = time.Duration(seconds) * time.Second
|
||||
}
|
||||
}
|
||||
c.logger.Warn("poll usage for ", c.tag, ": usage API rate limited, retry in ", log.FormatDuration(retryDelay))
|
||||
c.stateAccess.Lock()
|
||||
c.state.usageAPIRetryDelay = retryDelay
|
||||
c.stateAccess.Unlock()
|
||||
return
|
||||
}
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
|
||||
var usageResponse usageRateLimitStatusPayload
|
||||
err = json.NewDecoder(response.Body).Decode(&usageResponse)
|
||||
if err != nil {
|
||||
c.logger.Debug("poll usage for ", c.tag, ": decode: ", err)
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
|
||||
c.stateAccess.Lock()
|
||||
isFirstUpdate := c.state.lastUpdated.IsZero()
|
||||
oldFiveHour := c.state.fiveHourUtilization
|
||||
oldWeekly := c.state.weeklyUtilization
|
||||
c.state.consecutivePollFailures = 0
|
||||
c.state.usageAPIRetryDelay = 0
|
||||
applyRateLimitSnapshotsLocked(&c.state, snapshotsFromUsagePayload(usageResponse), c.state.activeLimitID, c.state.remotePlanWeight, usageResponse.PlanType)
|
||||
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
|
||||
c.state.hardRateLimited = false
|
||||
}
|
||||
c.state.noteSnapshotData()
|
||||
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
|
||||
resetSuffix := ""
|
||||
if !c.state.weeklyReset.IsZero() {
|
||||
resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset))
|
||||
}
|
||||
c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
|
||||
}
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateAccess.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
|
||||
func (c *defaultCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, serviceHeaders http.Header) (*http.Request, error) {
|
||||
accessToken, err := c.getAccessToken()
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "get access token for ", c.tag)
|
||||
}
|
||||
|
||||
path := original.URL.Path
|
||||
var proxyPath string
|
||||
if c.isAPIKeyMode() {
|
||||
proxyPath = path
|
||||
} else {
|
||||
proxyPath = strings.TrimPrefix(path, "/v1")
|
||||
}
|
||||
|
||||
proxyURL := c.getBaseURL() + proxyPath
|
||||
if original.URL.RawQuery != "" {
|
||||
proxyURL += "?" + original.URL.RawQuery
|
||||
}
|
||||
|
||||
var body io.Reader
|
||||
if bodyBytes != nil {
|
||||
body = bytes.NewReader(bodyBytes)
|
||||
} else {
|
||||
body = original.Body
|
||||
}
|
||||
proxyRequest, err := http.NewRequestWithContext(ctx, original.Method, proxyURL, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for key, values := range original.Header {
|
||||
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && !isAPIKeyHeader(key) && key != "Authorization" {
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
}
|
||||
|
||||
for key, values := range serviceHeaders {
|
||||
proxyRequest.Header.Del(key)
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
proxyRequest.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
if accountID := c.getAccountID(); accountID != "" {
|
||||
proxyRequest.Header.Set("ChatGPT-Account-Id", accountID)
|
||||
}
|
||||
|
||||
return proxyRequest, nil
|
||||
}
|
||||
|
||||
func (c *defaultCredential) close() {
|
||||
if c.watcher != nil {
|
||||
err := c.watcher.Close()
|
||||
if err != nil {
|
||||
c.logger.Error("close credential watcher for ", c.tag, ": ", err)
|
||||
}
|
||||
}
|
||||
if c.usageTracker != nil {
|
||||
c.usageTracker.cancelPendingSave()
|
||||
err := c.usageTracker.Save()
|
||||
if err != nil {
|
||||
c.logger.Error("save usage statistics for ", c.tag, ": ", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
1074
service/ocm/credential_external.go
Normal file
1074
service/ocm/credential_external.go
Normal file
File diff suppressed because it is too large
Load Diff
152
service/ocm/credential_file.go
Normal file
152
service/ocm/credential_file.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/fswatch"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
const credentialReloadRetryInterval = 2 * time.Second
|
||||
|
||||
func resolveCredentialFilePath(customPath string) (string, error) {
|
||||
if customPath == "" {
|
||||
var err error
|
||||
customPath, err = getDefaultCredentialsPath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
if filepath.IsAbs(customPath) {
|
||||
return customPath, nil
|
||||
}
|
||||
return filepath.Abs(customPath)
|
||||
}
|
||||
|
||||
func (c *defaultCredential) ensureCredentialWatcher() error {
|
||||
c.watcherAccess.Lock()
|
||||
defer c.watcherAccess.Unlock()
|
||||
|
||||
if c.watcher != nil || c.credentialFilePath == "" {
|
||||
return nil
|
||||
}
|
||||
if !c.watcherRetryAt.IsZero() && time.Now().Before(c.watcherRetryAt) {
|
||||
return nil
|
||||
}
|
||||
|
||||
watcher, err := fswatch.NewWatcher(fswatch.Options{
|
||||
Path: []string{c.credentialFilePath},
|
||||
Logger: c.logger,
|
||||
Callback: func(string) {
|
||||
err := c.reloadCredentials(true)
|
||||
if err != nil {
|
||||
c.logger.Warn("reload credentials for ", c.tag, ": ", err)
|
||||
}
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
c.watcherRetryAt = time.Now().Add(credentialReloadRetryInterval)
|
||||
return err
|
||||
}
|
||||
|
||||
err = watcher.Start()
|
||||
if err != nil {
|
||||
c.watcherRetryAt = time.Now().Add(credentialReloadRetryInterval)
|
||||
return err
|
||||
}
|
||||
|
||||
c.watcher = watcher
|
||||
c.watcherRetryAt = time.Time{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *defaultCredential) retryCredentialReloadIfNeeded() {
|
||||
c.stateAccess.RLock()
|
||||
unavailable := c.state.unavailable
|
||||
lastAttempt := c.state.lastCredentialLoadAttempt
|
||||
c.stateAccess.RUnlock()
|
||||
if !unavailable {
|
||||
return
|
||||
}
|
||||
if !lastAttempt.IsZero() && time.Since(lastAttempt) < credentialReloadRetryInterval {
|
||||
return
|
||||
}
|
||||
|
||||
err := c.ensureCredentialWatcher()
|
||||
if err != nil {
|
||||
c.logger.Debug("start credential watcher for ", c.tag, ": ", err)
|
||||
}
|
||||
_ = c.reloadCredentials(false)
|
||||
}
|
||||
|
||||
func (c *defaultCredential) reloadCredentials(force bool) error {
|
||||
c.reloadAccess.Lock()
|
||||
defer c.reloadAccess.Unlock()
|
||||
|
||||
c.stateAccess.RLock()
|
||||
unavailable := c.state.unavailable
|
||||
lastAttempt := c.state.lastCredentialLoadAttempt
|
||||
c.stateAccess.RUnlock()
|
||||
if !force {
|
||||
if !unavailable {
|
||||
return nil
|
||||
}
|
||||
if !lastAttempt.IsZero() && time.Since(lastAttempt) < credentialReloadRetryInterval {
|
||||
return c.unavailableError()
|
||||
}
|
||||
}
|
||||
|
||||
c.stateAccess.Lock()
|
||||
c.state.lastCredentialLoadAttempt = time.Now()
|
||||
c.stateAccess.Unlock()
|
||||
|
||||
credentials, err := platformReadCredentials(c.credentialPath)
|
||||
if err != nil {
|
||||
return c.markCredentialsUnavailable(E.Cause(err, "read credentials"))
|
||||
}
|
||||
|
||||
c.access.Lock()
|
||||
c.credentials = credentials
|
||||
c.refreshRetryAt = time.Time{}
|
||||
c.refreshRetryError = nil
|
||||
c.refreshBlocked = false
|
||||
c.access.Unlock()
|
||||
|
||||
c.stateAccess.Lock()
|
||||
wasAvailable := !c.state.unavailable
|
||||
c.state.unavailable = false
|
||||
c.state.lastCredentialLoadError = ""
|
||||
c.checkTransitionLocked()
|
||||
shouldEmit := wasAvailable != !c.state.unavailable
|
||||
c.stateAccess.Unlock()
|
||||
if shouldEmit {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *defaultCredential) markCredentialsUnavailable(err error) error {
|
||||
c.access.Lock()
|
||||
hadCredentials := c.credentials != nil
|
||||
c.credentials = nil
|
||||
c.access.Unlock()
|
||||
|
||||
c.stateAccess.Lock()
|
||||
wasAvailable := !c.state.unavailable
|
||||
c.state.unavailable = true
|
||||
c.state.lastCredentialLoadError = err.Error()
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
shouldEmit := wasAvailable != !c.state.unavailable
|
||||
c.stateAccess.Unlock()
|
||||
|
||||
if shouldInterrupt && hadCredentials {
|
||||
c.interruptConnections()
|
||||
}
|
||||
if shouldEmit {
|
||||
c.emitStatusUpdate()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
233
service/ocm/credential_oauth.go
Normal file
233
service/ocm/credential_oauth.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
const (
|
||||
oauth2ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
oauth2TokenURL = "https://auth.openai.com/oauth/token"
|
||||
openaiAPIBaseURL = "https://api.openai.com"
|
||||
chatGPTBackendURL = "https://chatgpt.com/backend-api/codex"
|
||||
tokenRefreshIntervalDays = 8
|
||||
)
|
||||
|
||||
func getRealUser() (*user.User, error) {
|
||||
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
|
||||
sudoUserInfo, err := user.Lookup(sudoUser)
|
||||
if err == nil {
|
||||
return sudoUserInfo, nil
|
||||
}
|
||||
}
|
||||
return user.Current()
|
||||
}
|
||||
|
||||
func getDefaultCredentialsPath() (string, error) {
|
||||
if codexHome := os.Getenv("CODEX_HOME"); codexHome != "" {
|
||||
return filepath.Join(codexHome, "auth.json"), nil
|
||||
}
|
||||
userInfo, err := getRealUser()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(userInfo.HomeDir, ".codex", "auth.json"), nil
|
||||
}
|
||||
|
||||
func readCredentialsFromFile(path string) (*oauthCredentials, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var credentials oauthCredentials
|
||||
err = json.Unmarshal(data, &credentials)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &credentials, nil
|
||||
}
|
||||
|
||||
func checkCredentialFileWritable(path string) error {
|
||||
file, err := os.OpenFile(path, os.O_WRONLY, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return file.Close()
|
||||
}
|
||||
|
||||
func writeCredentialsToFile(credentials *oauthCredentials, path string) error {
|
||||
data, err := json.MarshalIndent(credentials, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(path, data, 0o600)
|
||||
}
|
||||
|
||||
type oauthCredentials struct {
|
||||
APIKey string `json:"OPENAI_API_KEY,omitempty"`
|
||||
Tokens *tokenData `json:"tokens,omitempty"`
|
||||
LastRefresh *time.Time `json:"last_refresh,omitempty"`
|
||||
}
|
||||
|
||||
type tokenData struct {
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
AccountID string `json:"account_id,omitempty"`
|
||||
}
|
||||
|
||||
func (c *oauthCredentials) isAPIKeyMode() bool {
|
||||
return c.APIKey != ""
|
||||
}
|
||||
|
||||
func (c *oauthCredentials) getAccessToken() string {
|
||||
if c.APIKey != "" {
|
||||
return c.APIKey
|
||||
}
|
||||
if c.Tokens != nil {
|
||||
return c.Tokens.AccessToken
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *oauthCredentials) getAccountID() string {
|
||||
if c.Tokens != nil {
|
||||
return c.Tokens.AccountID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *oauthCredentials) needsRefresh() bool {
|
||||
if c.APIKey != "" {
|
||||
return false
|
||||
}
|
||||
if c.Tokens == nil || c.Tokens.RefreshToken == "" {
|
||||
return false
|
||||
}
|
||||
if c.LastRefresh == nil {
|
||||
return true
|
||||
}
|
||||
return time.Since(*c.LastRefresh) >= time.Duration(tokenRefreshIntervalDays)*24*time.Hour
|
||||
}
|
||||
|
||||
func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, time.Duration, error) {
|
||||
if credentials.Tokens == nil || credentials.Tokens.RefreshToken == "" {
|
||||
return nil, 0, E.New("refresh token is empty")
|
||||
}
|
||||
|
||||
requestBody, err := json.Marshal(map[string]string{
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": credentials.Tokens.RefreshToken,
|
||||
"client_id": oauth2ClientID,
|
||||
"scope": "openid profile email",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, 0, E.Cause(err, "marshal request")
|
||||
}
|
||||
|
||||
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
|
||||
request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request.Header.Set("Accept", "application/json")
|
||||
return request, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode == http.StatusTooManyRequests {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
retryDelay := time.Duration(-1)
|
||||
if retryAfter := response.Header.Get("Retry-After"); retryAfter != "" {
|
||||
seconds, parseErr := strconv.ParseInt(retryAfter, 10, 64)
|
||||
if parseErr == nil && seconds > 0 {
|
||||
retryDelay = time.Duration(seconds) * time.Second
|
||||
}
|
||||
}
|
||||
return nil, retryDelay, E.New("refresh rate limited: ", response.Status, " ", string(body))
|
||||
}
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return nil, 0, E.New("refresh failed: ", response.Status, " ", string(body))
|
||||
}
|
||||
|
||||
var tokenResponse struct {
|
||||
IDToken string `json:"id_token"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
err = json.NewDecoder(response.Body).Decode(&tokenResponse)
|
||||
if err != nil {
|
||||
return nil, 0, E.Cause(err, "decode response")
|
||||
}
|
||||
|
||||
newCredentials := *credentials
|
||||
if newCredentials.Tokens == nil {
|
||||
newCredentials.Tokens = &tokenData{}
|
||||
}
|
||||
if tokenResponse.IDToken != "" {
|
||||
newCredentials.Tokens.IDToken = tokenResponse.IDToken
|
||||
}
|
||||
if tokenResponse.AccessToken != "" {
|
||||
newCredentials.Tokens.AccessToken = tokenResponse.AccessToken
|
||||
}
|
||||
if tokenResponse.RefreshToken != "" {
|
||||
newCredentials.Tokens.RefreshToken = tokenResponse.RefreshToken
|
||||
}
|
||||
now := time.Now()
|
||||
newCredentials.LastRefresh = &now
|
||||
|
||||
return &newCredentials, 0, nil
|
||||
}
|
||||
|
||||
func cloneCredentials(credentials *oauthCredentials) *oauthCredentials {
|
||||
if credentials == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *credentials
|
||||
if credentials.Tokens != nil {
|
||||
clonedTokens := *credentials.Tokens
|
||||
cloned.Tokens = &clonedTokens
|
||||
}
|
||||
if credentials.LastRefresh != nil {
|
||||
lastRefresh := *credentials.LastRefresh
|
||||
cloned.LastRefresh = &lastRefresh
|
||||
}
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool {
|
||||
if left == nil || right == nil {
|
||||
return left == right
|
||||
}
|
||||
if left.APIKey != right.APIKey {
|
||||
return false
|
||||
}
|
||||
if (left.Tokens == nil) != (right.Tokens == nil) {
|
||||
return false
|
||||
}
|
||||
if left.Tokens != nil && *left.Tokens != *right.Tokens {
|
||||
return false
|
||||
}
|
||||
if (left.LastRefresh == nil) != (right.LastRefresh == nil) {
|
||||
return false
|
||||
}
|
||||
if left.LastRefresh != nil && !left.LastRefresh.Equal(*right.LastRefresh) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -13,6 +13,17 @@ func platformReadCredentials(customPath string) (*oauthCredentials, error) {
|
||||
return readCredentialsFromFile(customPath)
|
||||
}
|
||||
|
||||
func platformCanWriteCredentials(customPath string) error {
|
||||
if customPath == "" {
|
||||
var err error
|
||||
customPath, err = getDefaultCredentialsPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return checkCredentialFileWritable(customPath)
|
||||
}
|
||||
|
||||
func platformWriteCredentials(credentials *oauthCredentials, customPath string) error {
|
||||
if customPath == "" {
|
||||
var err error
|
||||
|
||||
446
service/ocm/credential_provider.go
Normal file
446
service/ocm/credential_provider.go
Normal file
@@ -0,0 +1,446 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand/v2"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
type credentialProvider interface {
|
||||
selectCredential(sessionID string, selection credentialSelection) (Credential, bool, error)
|
||||
onRateLimited(sessionID string, credential Credential, resetAt time.Time, selection credentialSelection) Credential
|
||||
linkProviderInterrupt(credential Credential, selection credentialSelection, onInterrupt func()) func() bool
|
||||
pollIfStale()
|
||||
pollCredentialIfStale(credential Credential)
|
||||
allCredentials() []Credential
|
||||
close()
|
||||
}
|
||||
|
||||
type singleCredentialProvider struct {
|
||||
credential Credential
|
||||
sessionAccess sync.RWMutex
|
||||
sessions map[string]time.Time
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) selectCredential(sessionID string, selection credentialSelection) (Credential, bool, error) {
|
||||
if !selection.allows(p.credential) {
|
||||
return nil, false, E.New("credential ", p.credential.tagName(), " is filtered out")
|
||||
}
|
||||
if !p.credential.isAvailable() {
|
||||
return nil, false, p.credential.unavailableError()
|
||||
}
|
||||
if !p.credential.isUsable() {
|
||||
return nil, false, E.New("credential ", p.credential.tagName(), " is rate-limited")
|
||||
}
|
||||
var isNew bool
|
||||
if sessionID != "" {
|
||||
p.sessionAccess.Lock()
|
||||
if p.sessions == nil {
|
||||
p.sessions = make(map[string]time.Time)
|
||||
}
|
||||
_, exists := p.sessions[sessionID]
|
||||
if !exists {
|
||||
p.sessions[sessionID] = time.Now()
|
||||
isNew = true
|
||||
}
|
||||
p.sessionAccess.Unlock()
|
||||
}
|
||||
return p.credential, isNew, nil
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) onRateLimited(_ string, credential Credential, resetAt time.Time, _ credentialSelection) Credential {
|
||||
credential.markRateLimited(resetAt)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) pollIfStale() {
|
||||
now := time.Now()
|
||||
p.sessionAccess.Lock()
|
||||
for id, createdAt := range p.sessions {
|
||||
if now.Sub(createdAt) > sessionExpiry {
|
||||
delete(p.sessions, id)
|
||||
}
|
||||
}
|
||||
p.sessionAccess.Unlock()
|
||||
|
||||
if time.Since(p.credential.lastUpdatedTime()) > p.credential.pollBackoff(defaultPollInterval) {
|
||||
p.credential.pollUsage()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) allCredentials() []Credential {
|
||||
return []Credential{p.credential}
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) linkProviderInterrupt(_ Credential, _ credentialSelection, _ func()) func() bool {
|
||||
return func() bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) pollCredentialIfStale(credential Credential) {
|
||||
if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(defaultPollInterval) {
|
||||
credential.pollUsage()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *singleCredentialProvider) close() {}
|
||||
|
||||
type sessionEntry struct {
|
||||
tag string
|
||||
selectionScope credentialSelectionScope
|
||||
createdAt time.Time
|
||||
}
|
||||
|
||||
type credentialInterruptKey struct {
|
||||
tag string
|
||||
selectionScope credentialSelectionScope
|
||||
}
|
||||
|
||||
type credentialInterruptEntry struct {
|
||||
context context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
type balancerProvider struct {
|
||||
credentials []Credential
|
||||
strategy string
|
||||
roundRobinIndex atomic.Uint64
|
||||
rebalanceThreshold float64
|
||||
sessionAccess sync.RWMutex
|
||||
sessions map[string]sessionEntry
|
||||
interruptAccess sync.Mutex
|
||||
credentialInterrupts map[credentialInterruptKey]credentialInterruptEntry
|
||||
logger log.ContextLogger
|
||||
}
|
||||
|
||||
func compositeCredentialSelectable(credential Credential) bool {
|
||||
return !credential.ocmIsAPIKeyMode()
|
||||
}
|
||||
|
||||
func newBalancerProvider(credentials []Credential, strategy string, rebalanceThreshold float64, logger log.ContextLogger) *balancerProvider {
|
||||
return &balancerProvider{
|
||||
credentials: credentials,
|
||||
strategy: strategy,
|
||||
rebalanceThreshold: rebalanceThreshold,
|
||||
sessions: make(map[string]sessionEntry),
|
||||
credentialInterrupts: make(map[credentialInterruptKey]credentialInterruptEntry),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *balancerProvider) selectCredential(sessionID string, selection credentialSelection) (Credential, bool, error) {
|
||||
selectionScope := selection.scopeOrDefault()
|
||||
for {
|
||||
if p.strategy == C.BalancerStrategyFallback {
|
||||
best := p.pickCredential(selection.filter)
|
||||
if best == nil {
|
||||
return nil, false, allRateLimitedError(p.credentials)
|
||||
}
|
||||
return best, p.storeSessionIfAbsent(sessionID, sessionEntry{createdAt: time.Now()}), nil
|
||||
}
|
||||
|
||||
if sessionID != "" {
|
||||
p.sessionAccess.RLock()
|
||||
entry, exists := p.sessions[sessionID]
|
||||
p.sessionAccess.RUnlock()
|
||||
if exists {
|
||||
if entry.selectionScope == selectionScope {
|
||||
for _, credential := range p.credentials {
|
||||
if credential.tagName() == entry.tag && compositeCredentialSelectable(credential) && selection.allows(credential) && credential.isUsable() {
|
||||
if p.rebalanceThreshold > 0 && (p.strategy == "" || p.strategy == C.BalancerStrategyLeastUsed) {
|
||||
better := p.pickLeastUsed(selection.filter)
|
||||
if better != nil && better.tagName() != credential.tagName() {
|
||||
effectiveThreshold := p.rebalanceThreshold / credential.planWeight()
|
||||
delta := credential.weeklyUtilization() - better.weeklyUtilization()
|
||||
if delta > effectiveThreshold {
|
||||
p.logger.Info("rebalancing away from ", credential.tagName(),
|
||||
": utilization delta ", delta, "% exceeds effective threshold ",
|
||||
effectiveThreshold, "% (weight ", credential.planWeight(), ")")
|
||||
p.rebalanceCredential(credential.tagName(), selectionScope)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return credential, false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
p.sessionAccess.Lock()
|
||||
currentEntry, stillExists := p.sessions[sessionID]
|
||||
if stillExists && currentEntry == entry {
|
||||
delete(p.sessions, sessionID)
|
||||
p.sessionAccess.Unlock()
|
||||
} else {
|
||||
p.sessionAccess.Unlock()
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
best := p.pickCredential(selection.filter)
|
||||
if best == nil {
|
||||
return nil, false, allRateLimitedError(p.credentials)
|
||||
}
|
||||
if p.storeSessionIfAbsent(sessionID, sessionEntry{
|
||||
tag: best.tagName(),
|
||||
selectionScope: selectionScope,
|
||||
createdAt: time.Now(),
|
||||
}) {
|
||||
return best, true, nil
|
||||
}
|
||||
if sessionID == "" {
|
||||
return best, false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *balancerProvider) storeSessionIfAbsent(sessionID string, entry sessionEntry) bool {
|
||||
if sessionID == "" {
|
||||
return false
|
||||
}
|
||||
p.sessionAccess.Lock()
|
||||
defer p.sessionAccess.Unlock()
|
||||
if _, exists := p.sessions[sessionID]; exists {
|
||||
return false
|
||||
}
|
||||
p.sessions[sessionID] = entry
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *balancerProvider) rebalanceCredential(tag string, selectionScope credentialSelectionScope) {
|
||||
key := credentialInterruptKey{tag: tag, selectionScope: selectionScope}
|
||||
p.interruptAccess.Lock()
|
||||
if entry, loaded := p.credentialInterrupts[key]; loaded {
|
||||
entry.cancel()
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
p.credentialInterrupts[key] = credentialInterruptEntry{context: ctx, cancel: cancel}
|
||||
p.interruptAccess.Unlock()
|
||||
|
||||
p.sessionAccess.Lock()
|
||||
for id, entry := range p.sessions {
|
||||
if entry.tag == tag && entry.selectionScope == selectionScope {
|
||||
delete(p.sessions, id)
|
||||
}
|
||||
}
|
||||
p.sessionAccess.Unlock()
|
||||
}
|
||||
|
||||
func (p *balancerProvider) linkProviderInterrupt(credential Credential, selection credentialSelection, onInterrupt func()) func() bool {
|
||||
if p.strategy == C.BalancerStrategyFallback {
|
||||
return func() bool { return false }
|
||||
}
|
||||
key := credentialInterruptKey{
|
||||
tag: credential.tagName(),
|
||||
selectionScope: selection.scopeOrDefault(),
|
||||
}
|
||||
p.interruptAccess.Lock()
|
||||
entry, loaded := p.credentialInterrupts[key]
|
||||
if !loaded {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
entry = credentialInterruptEntry{context: ctx, cancel: cancel}
|
||||
p.credentialInterrupts[key] = entry
|
||||
}
|
||||
p.interruptAccess.Unlock()
|
||||
return context.AfterFunc(entry.context, onInterrupt)
|
||||
}
|
||||
|
||||
func (p *balancerProvider) onRateLimited(sessionID string, credential Credential, resetAt time.Time, selection credentialSelection) Credential {
|
||||
credential.markRateLimited(resetAt)
|
||||
if p.strategy == C.BalancerStrategyFallback {
|
||||
return p.pickCredential(selection.filter)
|
||||
}
|
||||
if sessionID != "" {
|
||||
p.sessionAccess.Lock()
|
||||
delete(p.sessions, sessionID)
|
||||
p.sessionAccess.Unlock()
|
||||
}
|
||||
|
||||
best := p.pickCredential(selection.filter)
|
||||
if best != nil && sessionID != "" {
|
||||
p.sessionAccess.Lock()
|
||||
p.sessions[sessionID] = sessionEntry{
|
||||
tag: best.tagName(),
|
||||
selectionScope: selection.scopeOrDefault(),
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
p.sessionAccess.Unlock()
|
||||
}
|
||||
return best
|
||||
}
|
||||
|
||||
func (p *balancerProvider) pickCredential(filter func(Credential) bool) Credential {
|
||||
switch p.strategy {
|
||||
case C.BalancerStrategyRoundRobin:
|
||||
return p.pickRoundRobin(filter)
|
||||
case C.BalancerStrategyRandom:
|
||||
return p.pickRandom(filter)
|
||||
case C.BalancerStrategyFallback:
|
||||
return p.pickFallback(filter)
|
||||
default:
|
||||
return p.pickLeastUsed(filter)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *balancerProvider) pickFallback(filter func(Credential) bool) Credential {
|
||||
for _, credential := range p.credentials {
|
||||
if filter != nil && !filter(credential) {
|
||||
continue
|
||||
}
|
||||
if !compositeCredentialSelectable(credential) {
|
||||
continue
|
||||
}
|
||||
if credential.isUsable() {
|
||||
return credential
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const weeklyWindowHours = 7 * 24
|
||||
|
||||
func (p *balancerProvider) pickLeastUsed(filter func(Credential) bool) Credential {
|
||||
var best Credential
|
||||
bestScore := float64(-1)
|
||||
now := time.Now()
|
||||
for _, credential := range p.credentials {
|
||||
if filter != nil && !filter(credential) {
|
||||
continue
|
||||
}
|
||||
if !compositeCredentialSelectable(credential) {
|
||||
continue
|
||||
}
|
||||
if !credential.isUsable() {
|
||||
continue
|
||||
}
|
||||
remaining := credential.weeklyCap() - credential.weeklyUtilization()
|
||||
score := remaining * credential.planWeight()
|
||||
resetTime := credential.weeklyResetTime()
|
||||
if !resetTime.IsZero() {
|
||||
timeUntilReset := resetTime.Sub(now)
|
||||
if timeUntilReset < time.Hour {
|
||||
timeUntilReset = time.Hour
|
||||
}
|
||||
score *= weeklyWindowHours / timeUntilReset.Hours()
|
||||
}
|
||||
if score > bestScore {
|
||||
bestScore = score
|
||||
best = credential
|
||||
}
|
||||
}
|
||||
return best
|
||||
}
|
||||
|
||||
func ocmPlanWeight(accountType string) float64 {
|
||||
switch accountType {
|
||||
case "pro":
|
||||
return 10
|
||||
case "plus":
|
||||
return 1
|
||||
default:
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
func (p *balancerProvider) pickRoundRobin(filter func(Credential) bool) Credential {
|
||||
start := int(p.roundRobinIndex.Add(1) - 1)
|
||||
count := len(p.credentials)
|
||||
for offset := range count {
|
||||
candidate := p.credentials[(start+offset)%count]
|
||||
if filter != nil && !filter(candidate) {
|
||||
continue
|
||||
}
|
||||
if !compositeCredentialSelectable(candidate) {
|
||||
continue
|
||||
}
|
||||
if candidate.isUsable() {
|
||||
return candidate
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *balancerProvider) pickRandom(filter func(Credential) bool) Credential {
|
||||
var usable []Credential
|
||||
for _, candidate := range p.credentials {
|
||||
if filter != nil && !filter(candidate) {
|
||||
continue
|
||||
}
|
||||
if !compositeCredentialSelectable(candidate) {
|
||||
continue
|
||||
}
|
||||
if candidate.isUsable() {
|
||||
usable = append(usable, candidate)
|
||||
}
|
||||
}
|
||||
if len(usable) == 0 {
|
||||
return nil
|
||||
}
|
||||
return usable[rand.IntN(len(usable))]
|
||||
}
|
||||
|
||||
func (p *balancerProvider) pollIfStale() {
|
||||
now := time.Now()
|
||||
p.sessionAccess.Lock()
|
||||
for id, entry := range p.sessions {
|
||||
if now.Sub(entry.createdAt) > sessionExpiry {
|
||||
delete(p.sessions, id)
|
||||
}
|
||||
}
|
||||
p.sessionAccess.Unlock()
|
||||
|
||||
p.interruptAccess.Lock()
|
||||
for key, entry := range p.credentialInterrupts {
|
||||
if entry.context.Err() != nil {
|
||||
delete(p.credentialInterrupts, key)
|
||||
}
|
||||
}
|
||||
p.interruptAccess.Unlock()
|
||||
|
||||
for _, credential := range p.credentials {
|
||||
if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(defaultPollInterval) {
|
||||
credential.pollUsage()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *balancerProvider) pollCredentialIfStale(credential Credential) {
|
||||
if time.Since(credential.lastUpdatedTime()) > credential.pollBackoff(defaultPollInterval) {
|
||||
credential.pollUsage()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *balancerProvider) allCredentials() []Credential {
|
||||
return p.credentials
|
||||
}
|
||||
|
||||
func (p *balancerProvider) close() {}
|
||||
|
||||
func allRateLimitedError(credentials []Credential) error {
|
||||
var hasUnavailable bool
|
||||
var earliest time.Time
|
||||
for _, credential := range credentials {
|
||||
if credential.unavailableError() != nil {
|
||||
hasUnavailable = true
|
||||
continue
|
||||
}
|
||||
resetAt := credential.earliestReset()
|
||||
if !resetAt.IsZero() && (earliest.IsZero() || resetAt.Before(earliest)) {
|
||||
earliest = resetAt
|
||||
}
|
||||
}
|
||||
if hasUnavailable {
|
||||
return E.New("all credentials unavailable")
|
||||
}
|
||||
if earliest.IsZero() {
|
||||
return E.New("all credentials rate-limited")
|
||||
}
|
||||
return E.New("all credentials rate-limited, earliest reset in ", log.FormatDuration(time.Until(earliest)))
|
||||
}
|
||||
364
service/ocm/rate_limit_state.go
Normal file
364
service/ocm/rate_limit_state.go
Normal file
@@ -0,0 +1,364 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type availabilityState string
|
||||
|
||||
const (
|
||||
availabilityStateUsable availabilityState = "usable"
|
||||
availabilityStateRateLimited availabilityState = "rate_limited"
|
||||
availabilityStateTemporarilyBlocked availabilityState = "temporarily_blocked"
|
||||
availabilityStateUnavailable availabilityState = "unavailable"
|
||||
availabilityStateUnknown availabilityState = "unknown"
|
||||
)
|
||||
|
||||
type availabilityReason string
|
||||
|
||||
const (
|
||||
availabilityReasonHardRateLimit availabilityReason = "hard_rate_limit"
|
||||
availabilityReasonConnectionLimit availabilityReason = "connection_limit"
|
||||
availabilityReasonPollFailed availabilityReason = "poll_failed"
|
||||
availabilityReasonUpstreamRejected availabilityReason = "upstream_rejected"
|
||||
availabilityReasonNoCredentials availabilityReason = "no_credentials"
|
||||
availabilityReasonUnknown availabilityReason = "unknown"
|
||||
)
|
||||
|
||||
type availabilityStatus struct {
|
||||
State availabilityState
|
||||
Reason availabilityReason
|
||||
ResetAt time.Time
|
||||
}
|
||||
|
||||
func (s availabilityStatus) normalized() availabilityStatus {
|
||||
if s.State == "" {
|
||||
s.State = availabilityStateUnknown
|
||||
}
|
||||
if s.Reason == "" && s.State != availabilityStateUsable {
|
||||
s.Reason = availabilityReasonUnknown
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
type creditsSnapshot struct {
|
||||
HasCredits bool `json:"has_credits"`
|
||||
Unlimited bool `json:"unlimited"`
|
||||
Balance string `json:"balance,omitempty"`
|
||||
}
|
||||
|
||||
type rateLimitWindow struct {
|
||||
UsedPercent float64 `json:"used_percent"`
|
||||
WindowMinutes int64 `json:"window_minutes,omitempty"`
|
||||
ResetAt int64 `json:"reset_at,omitempty"`
|
||||
}
|
||||
|
||||
type rateLimitSnapshot struct {
|
||||
LimitID string `json:"limit_id,omitempty"`
|
||||
LimitName string `json:"limit_name,omitempty"`
|
||||
Primary *rateLimitWindow `json:"primary,omitempty"`
|
||||
Secondary *rateLimitWindow `json:"secondary,omitempty"`
|
||||
Credits *creditsSnapshot `json:"credits,omitempty"`
|
||||
PlanType string `json:"plan_type,omitempty"`
|
||||
}
|
||||
|
||||
func normalizeStoredLimitID(limitID string) string {
|
||||
normalized := normalizeRateLimitIdentifier(limitID)
|
||||
if normalized == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.ReplaceAll(normalized, "-", "_")
|
||||
}
|
||||
|
||||
func headerLimitID(limitID string) string {
|
||||
if limitID == "" {
|
||||
return "codex"
|
||||
}
|
||||
return strings.ReplaceAll(normalizeStoredLimitID(limitID), "_", "-")
|
||||
}
|
||||
|
||||
func defaultRateLimitSnapshot(limitID string) rateLimitSnapshot {
|
||||
if limitID == "" {
|
||||
limitID = "codex"
|
||||
}
|
||||
return rateLimitSnapshot{LimitID: normalizeStoredLimitID(limitID)}
|
||||
}
|
||||
|
||||
func cloneCreditsSnapshot(snapshot *creditsSnapshot) *creditsSnapshot {
|
||||
if snapshot == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *snapshot
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func cloneRateLimitWindow(window *rateLimitWindow) *rateLimitWindow {
|
||||
if window == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *window
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func cloneRateLimitSnapshot(snapshot rateLimitSnapshot) rateLimitSnapshot {
|
||||
snapshot.Primary = cloneRateLimitWindow(snapshot.Primary)
|
||||
snapshot.Secondary = cloneRateLimitWindow(snapshot.Secondary)
|
||||
snapshot.Credits = cloneCreditsSnapshot(snapshot.Credits)
|
||||
return snapshot
|
||||
}
|
||||
|
||||
func sortRateLimitSnapshots(snapshots []rateLimitSnapshot) {
|
||||
slices.SortFunc(snapshots, func(a, b rateLimitSnapshot) int {
|
||||
return strings.Compare(a.LimitID, b.LimitID)
|
||||
})
|
||||
}
|
||||
|
||||
func parseHeaderFloat(headers http.Header, name string) (float64, bool) {
|
||||
value := strings.TrimSpace(headers.Get(name))
|
||||
if value == "" {
|
||||
return 0, false
|
||||
}
|
||||
parsed, err := strconv.ParseFloat(value, 64)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
if !isFinite(parsed) {
|
||||
return 0, false
|
||||
}
|
||||
return parsed, true
|
||||
}
|
||||
|
||||
func isFinite(value float64) bool {
|
||||
return !((value != value) || value > 1e308 || value < -1e308)
|
||||
}
|
||||
|
||||
func parseCreditsSnapshotFromHeaders(headers http.Header) *creditsSnapshot {
|
||||
hasCreditsValue := strings.TrimSpace(headers.Get("x-codex-credits-has-credits"))
|
||||
unlimitedValue := strings.TrimSpace(headers.Get("x-codex-credits-unlimited"))
|
||||
if hasCreditsValue == "" || unlimitedValue == "" {
|
||||
return nil
|
||||
}
|
||||
hasCredits := strings.EqualFold(hasCreditsValue, "true") || hasCreditsValue == "1"
|
||||
unlimited := strings.EqualFold(unlimitedValue, "true") || unlimitedValue == "1"
|
||||
return &creditsSnapshot{
|
||||
HasCredits: hasCredits,
|
||||
Unlimited: unlimited,
|
||||
Balance: strings.TrimSpace(headers.Get("x-codex-credits-balance")),
|
||||
}
|
||||
}
|
||||
|
||||
func parseRateLimitWindowFromHeaders(headers http.Header, prefix string, windowName string) *rateLimitWindow {
|
||||
usedPercent, hasPercent := parseHeaderFloat(headers, prefix+"-"+windowName+"-used-percent")
|
||||
windowMinutes, hasWindow := parseInt64Header(headers, prefix+"-"+windowName+"-window-minutes")
|
||||
resetAt, hasReset := parseInt64Header(headers, prefix+"-"+windowName+"-reset-at")
|
||||
if !hasPercent && !hasWindow && !hasReset {
|
||||
return nil
|
||||
}
|
||||
window := &rateLimitWindow{}
|
||||
if hasPercent {
|
||||
window.UsedPercent = usedPercent
|
||||
}
|
||||
if hasWindow {
|
||||
window.WindowMinutes = windowMinutes
|
||||
}
|
||||
if hasReset {
|
||||
window.ResetAt = resetAt
|
||||
}
|
||||
return window
|
||||
}
|
||||
|
||||
func parseRateLimitSnapshotsFromHeaders(headers http.Header) []rateLimitSnapshot {
|
||||
limitIDs := map[string]struct{}{}
|
||||
for key := range headers {
|
||||
lowerKey := strings.ToLower(key)
|
||||
if strings.HasPrefix(lowerKey, "x-") && strings.Contains(lowerKey, "-primary-") {
|
||||
limitID := strings.TrimPrefix(lowerKey, "x-")
|
||||
if suffix := strings.Index(limitID, "-primary-"); suffix > 0 {
|
||||
limitIDs[normalizeStoredLimitID(limitID[:suffix])] = struct{}{}
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(lowerKey, "x-") && strings.Contains(lowerKey, "-secondary-") {
|
||||
limitID := strings.TrimPrefix(lowerKey, "x-")
|
||||
if suffix := strings.Index(limitID, "-secondary-"); suffix > 0 {
|
||||
limitIDs[normalizeStoredLimitID(limitID[:suffix])] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
if activeLimit := normalizeStoredLimitID(headers.Get("x-codex-active-limit")); activeLimit != "" {
|
||||
limitIDs[activeLimit] = struct{}{}
|
||||
}
|
||||
if credits := parseCreditsSnapshotFromHeaders(headers); credits != nil {
|
||||
_ = credits
|
||||
limitIDs["codex"] = struct{}{}
|
||||
}
|
||||
if len(limitIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
snapshots := make([]rateLimitSnapshot, 0, len(limitIDs))
|
||||
for limitID := range limitIDs {
|
||||
prefix := "x-" + headerLimitID(limitID)
|
||||
snapshot := defaultRateLimitSnapshot(limitID)
|
||||
snapshot.LimitName = strings.TrimSpace(headers.Get(prefix + "-limit-name"))
|
||||
snapshot.Primary = parseRateLimitWindowFromHeaders(headers, prefix, "primary")
|
||||
snapshot.Secondary = parseRateLimitWindowFromHeaders(headers, prefix, "secondary")
|
||||
if limitID == "codex" {
|
||||
snapshot.Credits = parseCreditsSnapshotFromHeaders(headers)
|
||||
}
|
||||
if snapshot.Primary == nil && snapshot.Secondary == nil && snapshot.Credits == nil {
|
||||
continue
|
||||
}
|
||||
snapshots = append(snapshots, snapshot)
|
||||
}
|
||||
sortRateLimitSnapshots(snapshots)
|
||||
return snapshots
|
||||
}
|
||||
|
||||
type usageRateLimitWindowPayload struct {
|
||||
UsedPercent float64 `json:"used_percent"`
|
||||
LimitWindowSeconds int64 `json:"limit_window_seconds"`
|
||||
ResetAt int64 `json:"reset_at"`
|
||||
}
|
||||
|
||||
type usageRateLimitDetailsPayload struct {
|
||||
PrimaryWindow *usageRateLimitWindowPayload `json:"primary_window"`
|
||||
SecondaryWindow *usageRateLimitWindowPayload `json:"secondary_window"`
|
||||
}
|
||||
|
||||
type usageCreditsPayload struct {
|
||||
HasCredits bool `json:"has_credits"`
|
||||
Unlimited bool `json:"unlimited"`
|
||||
Balance *string `json:"balance"`
|
||||
}
|
||||
|
||||
type additionalRateLimitPayload struct {
|
||||
LimitName string `json:"limit_name"`
|
||||
MeteredFeature string `json:"metered_feature"`
|
||||
RateLimit *usageRateLimitDetailsPayload `json:"rate_limit"`
|
||||
}
|
||||
|
||||
type usageRateLimitStatusPayload struct {
|
||||
PlanType string `json:"plan_type"`
|
||||
RateLimit *usageRateLimitDetailsPayload `json:"rate_limit"`
|
||||
Credits *usageCreditsPayload `json:"credits"`
|
||||
AdditionalRateLimits []additionalRateLimitPayload `json:"additional_rate_limits"`
|
||||
}
|
||||
|
||||
func windowFromUsagePayload(window *usageRateLimitWindowPayload) *rateLimitWindow {
|
||||
if window == nil {
|
||||
return nil
|
||||
}
|
||||
result := &rateLimitWindow{
|
||||
UsedPercent: window.UsedPercent,
|
||||
}
|
||||
if window.LimitWindowSeconds > 0 {
|
||||
result.WindowMinutes = (window.LimitWindowSeconds + 59) / 60
|
||||
}
|
||||
if window.ResetAt > 0 {
|
||||
result.ResetAt = window.ResetAt
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func snapshotsFromUsagePayload(payload usageRateLimitStatusPayload) []rateLimitSnapshot {
|
||||
snapshots := make([]rateLimitSnapshot, 0, 1+len(payload.AdditionalRateLimits))
|
||||
codex := defaultRateLimitSnapshot("codex")
|
||||
codex.PlanType = payload.PlanType
|
||||
if payload.RateLimit != nil {
|
||||
codex.Primary = windowFromUsagePayload(payload.RateLimit.PrimaryWindow)
|
||||
codex.Secondary = windowFromUsagePayload(payload.RateLimit.SecondaryWindow)
|
||||
}
|
||||
if payload.Credits != nil {
|
||||
codex.Credits = &creditsSnapshot{
|
||||
HasCredits: payload.Credits.HasCredits,
|
||||
Unlimited: payload.Credits.Unlimited,
|
||||
}
|
||||
if payload.Credits.Balance != nil {
|
||||
codex.Credits.Balance = *payload.Credits.Balance
|
||||
}
|
||||
}
|
||||
if codex.Primary != nil || codex.Secondary != nil || codex.Credits != nil || codex.PlanType != "" {
|
||||
snapshots = append(snapshots, codex)
|
||||
}
|
||||
for _, additional := range payload.AdditionalRateLimits {
|
||||
snapshot := defaultRateLimitSnapshot(additional.MeteredFeature)
|
||||
snapshot.LimitName = additional.LimitName
|
||||
snapshot.PlanType = payload.PlanType
|
||||
if additional.RateLimit != nil {
|
||||
snapshot.Primary = windowFromUsagePayload(additional.RateLimit.PrimaryWindow)
|
||||
snapshot.Secondary = windowFromUsagePayload(additional.RateLimit.SecondaryWindow)
|
||||
}
|
||||
if snapshot.Primary == nil && snapshot.Secondary == nil {
|
||||
continue
|
||||
}
|
||||
snapshots = append(snapshots, snapshot)
|
||||
}
|
||||
sortRateLimitSnapshots(snapshots)
|
||||
return snapshots
|
||||
}
|
||||
|
||||
func applyRateLimitSnapshotsLocked(state *credentialState, snapshots []rateLimitSnapshot, activeLimitID string, planWeight float64, planType string) {
|
||||
if len(snapshots) == 0 {
|
||||
return
|
||||
}
|
||||
if state.rateLimitSnapshots == nil {
|
||||
state.rateLimitSnapshots = make(map[string]rateLimitSnapshot, len(snapshots))
|
||||
} else {
|
||||
clear(state.rateLimitSnapshots)
|
||||
}
|
||||
for _, snapshot := range snapshots {
|
||||
snapshot = cloneRateLimitSnapshot(snapshot)
|
||||
if snapshot.LimitID == "" {
|
||||
snapshot.LimitID = "codex"
|
||||
}
|
||||
if snapshot.LimitName == "" && snapshot.LimitID != "codex" {
|
||||
snapshot.LimitName = strings.ReplaceAll(snapshot.LimitID, "_", "-")
|
||||
}
|
||||
if snapshot.PlanType == "" {
|
||||
snapshot.PlanType = planType
|
||||
}
|
||||
state.rateLimitSnapshots[snapshot.LimitID] = snapshot
|
||||
}
|
||||
if planWeight > 0 {
|
||||
state.remotePlanWeight = planWeight
|
||||
}
|
||||
if planType != "" {
|
||||
state.accountType = planType
|
||||
}
|
||||
if normalizedActive := normalizeStoredLimitID(activeLimitID); normalizedActive != "" {
|
||||
state.activeLimitID = normalizedActive
|
||||
} else if state.activeLimitID == "" {
|
||||
if _, exists := state.rateLimitSnapshots["codex"]; exists {
|
||||
state.activeLimitID = "codex"
|
||||
} else {
|
||||
for limitID := range state.rateLimitSnapshots {
|
||||
state.activeLimitID = limitID
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
legacy := state.rateLimitSnapshots["codex"]
|
||||
if legacy.LimitID == "" && state.activeLimitID != "" {
|
||||
legacy = state.rateLimitSnapshots[state.activeLimitID]
|
||||
}
|
||||
state.fiveHourUtilization = 0
|
||||
state.fiveHourReset = time.Time{}
|
||||
state.weeklyUtilization = 0
|
||||
state.weeklyReset = time.Time{}
|
||||
if legacy.Primary != nil {
|
||||
state.fiveHourUtilization = legacy.Primary.UsedPercent
|
||||
if legacy.Primary.ResetAt > 0 {
|
||||
state.fiveHourReset = time.Unix(legacy.Primary.ResetAt, 0)
|
||||
}
|
||||
}
|
||||
if legacy.Secondary != nil {
|
||||
state.weeklyUtilization = legacy.Secondary.UsedPercent
|
||||
if legacy.Secondary.ResetAt > 0 {
|
||||
state.weeklyReset = time.Unix(legacy.Secondary.ResetAt, 0)
|
||||
}
|
||||
}
|
||||
state.noteSnapshotData()
|
||||
}
|
||||
88
service/ocm/request_log.go
Normal file
88
service/ocm/request_log.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/openai/openai-go/v3"
|
||||
"github.com/openai/openai-go/v3/responses"
|
||||
)
|
||||
|
||||
type requestLogMetadata struct {
|
||||
Model string
|
||||
ServiceTier string
|
||||
ReasoningEffort string
|
||||
}
|
||||
|
||||
type legacyReasoningEffortPayload struct {
|
||||
ReasoningEffort string `json:"reasoning_effort"`
|
||||
}
|
||||
|
||||
func requestLogMetadataFromChatCompletionRequest(request openai.ChatCompletionNewParams) requestLogMetadata {
|
||||
return requestLogMetadata{
|
||||
Model: string(request.Model),
|
||||
ServiceTier: string(request.ServiceTier),
|
||||
ReasoningEffort: string(request.ReasoningEffort),
|
||||
}
|
||||
}
|
||||
|
||||
func requestLogMetadataFromResponsesRequest(request responses.ResponseNewParams, legacyReasoningEffort string) requestLogMetadata {
|
||||
metadata := requestLogMetadata{
|
||||
Model: string(request.Model),
|
||||
ServiceTier: string(request.ServiceTier),
|
||||
}
|
||||
if request.Reasoning.Effort != "" {
|
||||
metadata.ReasoningEffort = string(request.Reasoning.Effort)
|
||||
}
|
||||
if metadata.ReasoningEffort == "" {
|
||||
metadata.ReasoningEffort = legacyReasoningEffort
|
||||
}
|
||||
return metadata
|
||||
}
|
||||
|
||||
func parseLegacyReasoningEffort(data []byte) string {
|
||||
var legacy legacyReasoningEffortPayload
|
||||
if json.Unmarshal(data, &legacy) != nil {
|
||||
return ""
|
||||
}
|
||||
return legacy.ReasoningEffort
|
||||
}
|
||||
|
||||
func parseRequestLogMetadata(path string, data []byte) requestLogMetadata {
|
||||
switch {
|
||||
case path == "/v1/chat/completions":
|
||||
var request openai.ChatCompletionNewParams
|
||||
if json.Unmarshal(data, &request) != nil {
|
||||
return requestLogMetadata{}
|
||||
}
|
||||
return requestLogMetadataFromChatCompletionRequest(request)
|
||||
case strings.HasPrefix(path, "/v1/responses"):
|
||||
var request responses.ResponseNewParams
|
||||
if json.Unmarshal(data, &request) != nil {
|
||||
return requestLogMetadata{}
|
||||
}
|
||||
return requestLogMetadataFromResponsesRequest(request, parseLegacyReasoningEffort(data))
|
||||
default:
|
||||
return requestLogMetadata{}
|
||||
}
|
||||
}
|
||||
|
||||
func buildAssignedCredentialLogParts(credentialTag string, sessionID string, username string, metadata requestLogMetadata) []any {
|
||||
logParts := []any{"assigned credential ", credentialTag}
|
||||
if sessionID != "" {
|
||||
logParts = append(logParts, " for session ", sessionID)
|
||||
}
|
||||
if username != "" {
|
||||
logParts = append(logParts, " by user ", username)
|
||||
}
|
||||
if metadata.Model != "" {
|
||||
logParts = append(logParts, ", model=", metadata.Model)
|
||||
}
|
||||
if metadata.ReasoningEffort != "" {
|
||||
logParts = append(logParts, ", think=", metadata.ReasoningEffort)
|
||||
}
|
||||
if metadata.ServiceTier == "priority" {
|
||||
logParts = append(logParts, ", fast")
|
||||
}
|
||||
return logParts
|
||||
}
|
||||
126
service/ocm/request_log_test.go
Normal file
126
service/ocm/request_log_test.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
F "github.com/sagernet/sing/common/format"
|
||||
)
|
||||
|
||||
func TestParseRequestLogMetadata(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
metadata := parseRequestLogMetadata("/v1/responses", []byte(`{
|
||||
"model":"gpt-5.4",
|
||||
"service_tier":"priority",
|
||||
"reasoning":{"effort":"xhigh"}
|
||||
}`))
|
||||
|
||||
if metadata.Model != "gpt-5.4" {
|
||||
t.Fatalf("expected model gpt-5.4, got %q", metadata.Model)
|
||||
}
|
||||
if metadata.ServiceTier != "priority" {
|
||||
t.Fatalf("expected priority service tier, got %q", metadata.ServiceTier)
|
||||
}
|
||||
if metadata.ReasoningEffort != "xhigh" {
|
||||
t.Fatalf("expected xhigh reasoning effort, got %q", metadata.ReasoningEffort)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRequestLogMetadataFallsBackToTopLevelReasoningEffort(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
metadata := parseRequestLogMetadata("/v1/responses", []byte(`{
|
||||
"model":"gpt-5.4",
|
||||
"reasoning_effort":"high"
|
||||
}`))
|
||||
|
||||
if metadata.ReasoningEffort != "high" {
|
||||
t.Fatalf("expected high reasoning effort, got %q", metadata.ReasoningEffort)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRequestLogMetadataFromChatCompletions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
metadata := parseRequestLogMetadata("/v1/chat/completions", []byte(`{
|
||||
"model":"gpt-5.4",
|
||||
"service_tier":"priority",
|
||||
"reasoning_effort":"xhigh",
|
||||
"messages":[{"role":"user","content":"hi"}]
|
||||
}`))
|
||||
|
||||
if metadata.Model != "gpt-5.4" {
|
||||
t.Fatalf("expected model gpt-5.4, got %q", metadata.Model)
|
||||
}
|
||||
if metadata.ServiceTier != "priority" {
|
||||
t.Fatalf("expected priority service tier, got %q", metadata.ServiceTier)
|
||||
}
|
||||
if metadata.ReasoningEffort != "xhigh" {
|
||||
t.Fatalf("expected xhigh reasoning effort, got %q", metadata.ReasoningEffort)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRequestLogMetadataIgnoresUnsupportedPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
metadata := parseRequestLogMetadata("/v1/files", []byte(`{"model":"gpt-5.4"}`))
|
||||
if metadata != (requestLogMetadata{}) {
|
||||
t.Fatalf("expected zero metadata, got %#v", metadata)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAssignedCredentialLogPartsIncludesThinkLevel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
message := F.ToString(buildAssignedCredentialLogParts("a", "session-1", "alice", requestLogMetadata{
|
||||
Model: "gpt-5.4",
|
||||
ServiceTier: "priority",
|
||||
ReasoningEffort: "xhigh",
|
||||
})...)
|
||||
|
||||
for _, fragment := range []string{
|
||||
"assigned credential a",
|
||||
"for session session-1",
|
||||
"by user alice",
|
||||
"model=gpt-5.4",
|
||||
"think=xhigh",
|
||||
"fast",
|
||||
} {
|
||||
if !strings.Contains(message, fragment) {
|
||||
t.Fatalf("expected %q in %q", fragment, message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseWebSocketResponseCreateRequestIncludesThinkLevel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
request, ok := parseWebSocketResponseCreateRequest([]byte(`{
|
||||
"type":"response.create",
|
||||
"model":"gpt-5.4",
|
||||
"reasoning":{"effort":"xhigh"}
|
||||
}`))
|
||||
if !ok {
|
||||
t.Fatal("expected websocket response.create request to parse")
|
||||
}
|
||||
if request.metadata().ReasoningEffort != "xhigh" {
|
||||
t.Fatalf("expected xhigh reasoning effort, got %q", request.metadata().ReasoningEffort)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseWebSocketResponseCreateRequestFallsBackToLegacyReasoningEffort(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
request, ok := parseWebSocketResponseCreateRequest([]byte(`{
|
||||
"type":"response.create",
|
||||
"model":"gpt-5.4",
|
||||
"reasoning_effort":"high"
|
||||
}`))
|
||||
if !ok {
|
||||
t.Fatal("expected websocket response.create request to parse")
|
||||
}
|
||||
if request.metadata().ReasoningEffort != "high" {
|
||||
t.Fatalf("expected high reasoning effort, got %q", request.metadata().ReasoningEffort)
|
||||
}
|
||||
}
|
||||
266
service/ocm/reverse.go
Normal file
266
service/ocm/reverse.go
Normal file
@@ -0,0 +1,266 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
stdTLS "crypto/tls"
|
||||
"io"
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
)
|
||||
|
||||
var defaultYamuxConfig = func() *yamux.Config {
|
||||
config := yamux.DefaultConfig()
|
||||
config.KeepAliveInterval = 15 * time.Second
|
||||
config.ConnectionWriteTimeout = 10 * time.Second
|
||||
config.MaxStreamWindowSize = 512 * 1024
|
||||
config.LogOutput = io.Discard
|
||||
return config
|
||||
}()
|
||||
|
||||
type bufferedConn struct {
|
||||
reader *bufio.Reader
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c *bufferedConn) Read(p []byte) (int, error) {
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
type yamuxNetListener struct {
|
||||
session *yamux.Session
|
||||
}
|
||||
|
||||
func (l *yamuxNetListener) Accept() (net.Conn, error) {
|
||||
return l.session.Accept()
|
||||
}
|
||||
|
||||
func (l *yamuxNetListener) Close() error {
|
||||
return l.session.Close()
|
||||
}
|
||||
|
||||
func (l *yamuxNetListener) Addr() net.Addr {
|
||||
return l.session.Addr()
|
||||
}
|
||||
|
||||
func (s *Service) handleReverseConnect(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Upgrade") != "reverse-proxy" {
|
||||
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", "missing Upgrade header")
|
||||
return
|
||||
}
|
||||
|
||||
if r.Header.Get("X-Api-Key") != "" || r.Header.Get("Api-Key") != "" {
|
||||
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
|
||||
"API key authentication is not supported; use Authorization: Bearer with an OCM user token")
|
||||
return
|
||||
}
|
||||
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
||||
return
|
||||
}
|
||||
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if clientToken == authHeader {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
||||
return
|
||||
}
|
||||
|
||||
receiverCredential := s.findReceiverCredential(clientToken)
|
||||
if receiverCredential == nil {
|
||||
s.logger.WarnContext(ctx, "reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid reverse token")
|
||||
return
|
||||
}
|
||||
|
||||
hijacker, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
s.logger.ErrorContext(ctx, "reverse connect: hijack not supported")
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "hijack not supported")
|
||||
return
|
||||
}
|
||||
|
||||
conn, bufferedReadWriter, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "reverse connect: hijack: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
response := "HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: reverse-proxy\r\n\r\n"
|
||||
_, err = bufferedReadWriter.WriteString(response)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
s.logger.ErrorContext(ctx, "reverse connect: write upgrade response: ", err)
|
||||
return
|
||||
}
|
||||
err = bufferedReadWriter.Flush()
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
s.logger.ErrorContext(ctx, "reverse connect: flush upgrade response: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
session, err := yamux.Client(conn, defaultYamuxConfig)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
s.logger.ErrorContext(ctx, "reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !receiverCredential.setReverseSession(session) {
|
||||
session.Close()
|
||||
return
|
||||
}
|
||||
s.logger.InfoContext(ctx, "reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr)
|
||||
|
||||
go func() {
|
||||
<-session.CloseChan()
|
||||
receiverCredential.clearReverseSession(session)
|
||||
s.logger.WarnContext(ctx, "reverse connection lost for ", receiverCredential.tagName())
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Service) findReceiverCredential(token string) *externalCredential {
|
||||
for _, credential := range s.allCredentials {
|
||||
external, ok := credential.(*externalCredential)
|
||||
if !ok || external.connectorURL != nil {
|
||||
continue
|
||||
}
|
||||
if external.token == token {
|
||||
return external
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) connectorLoop() {
|
||||
var consecutiveFailures int
|
||||
ctx := c.getReverseContext()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
sessionLifetime, err := c.connectorConnect(ctx)
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if sessionLifetime >= connectorBackoffResetThreshold {
|
||||
consecutiveFailures = 0
|
||||
}
|
||||
consecutiveFailures++
|
||||
backoff := connectorBackoff(consecutiveFailures)
|
||||
c.logger.Warn("reverse connection for ", c.tag, " lost: ", err, ", reconnecting in ", backoff)
|
||||
timer := time.NewTimer(backoff)
|
||||
select {
|
||||
case <-timer.C:
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const connectorBackoffResetThreshold = time.Minute
|
||||
|
||||
func connectorBackoff(failures int) time.Duration {
|
||||
if failures > 5 {
|
||||
failures = 5
|
||||
}
|
||||
base := time.Second * time.Duration(1<<failures)
|
||||
if base > 30*time.Second {
|
||||
base = 30 * time.Second
|
||||
}
|
||||
jitter := time.Duration(rand.Int64N(int64(base) / 2))
|
||||
return base + jitter
|
||||
}
|
||||
|
||||
func (c *externalCredential) connectorConnect(ctx context.Context) (time.Duration, error) {
|
||||
if c.reverseService == nil {
|
||||
return 0, E.New("reverse service not initialized")
|
||||
}
|
||||
destination := c.connectorResolveDestination()
|
||||
conn, err := c.connectorDialer.DialContext(ctx, "tcp", destination)
|
||||
if err != nil {
|
||||
return 0, E.Cause(err, "dial")
|
||||
}
|
||||
|
||||
if c.connectorTLS != nil {
|
||||
tlsConn := stdTLS.Client(conn, c.connectorTLS.Clone())
|
||||
err = tlsConn.HandshakeContext(ctx)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return 0, E.Cause(err, "tls handshake")
|
||||
}
|
||||
conn = tlsConn
|
||||
}
|
||||
|
||||
upgradeRequest := "GET " + c.connectorRequestPath + " HTTP/1.1\r\n" +
|
||||
"Host: " + c.connectorURL.Host + "\r\n" +
|
||||
"Connection: Upgrade\r\n" +
|
||||
"Upgrade: reverse-proxy\r\n" +
|
||||
"Authorization: Bearer " + c.token + "\r\n" +
|
||||
"\r\n"
|
||||
_, err = io.WriteString(conn, upgradeRequest)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return 0, E.Cause(err, "write upgrade request")
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
statusLine, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return 0, E.Cause(err, "read upgrade response")
|
||||
}
|
||||
if !strings.HasPrefix(statusLine, "HTTP/1.1 101") {
|
||||
conn.Close()
|
||||
return 0, E.New("unexpected upgrade response: ", strings.TrimSpace(statusLine))
|
||||
}
|
||||
for {
|
||||
line, readErr := reader.ReadString('\n')
|
||||
if readErr != nil {
|
||||
conn.Close()
|
||||
return 0, E.Cause(readErr, "read upgrade headers")
|
||||
}
|
||||
if strings.TrimSpace(line) == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
session, err := yamux.Server(&bufferedConn{reader: reader, Conn: conn}, defaultYamuxConfig)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return 0, E.Cause(err, "create yamux server")
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
c.logger.Info("reverse connection established for ", c.tag)
|
||||
|
||||
serveStart := time.Now()
|
||||
httpServer := &http.Server{
|
||||
Handler: c.reverseService,
|
||||
ReadTimeout: 0,
|
||||
IdleTimeout: 120 * time.Second,
|
||||
}
|
||||
err = httpServer.Serve(&yamuxNetListener{session: session})
|
||||
sessionLifetime := time.Since(serveStart)
|
||||
if err != nil && !E.IsClosed(err) && ctx.Err() == nil {
|
||||
return sessionLifetime, E.Cause(err, "serve")
|
||||
}
|
||||
return sessionLifetime, E.New("connection closed")
|
||||
}
|
||||
|
||||
func (c *externalCredential) connectorResolveDestination() M.Socksaddr {
|
||||
return c.connectorDestination
|
||||
}
|
||||
@@ -1,40 +1,30 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
stdTLS "crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"mime"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
boxService "github.com/sagernet/sing-box/adapter/service"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
"github.com/sagernet/sing-box/common/listener"
|
||||
"github.com/sagernet/sing-box/common/tls"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
"github.com/sagernet/sing/common/observable"
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/openai/openai-go/v3"
|
||||
"github.com/openai/openai-go/v3/responses"
|
||||
openaishared "github.com/openai/openai-go/v3/shared"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
)
|
||||
|
||||
func RegisterService(registry *boxService.Registry) {
|
||||
@@ -42,27 +32,104 @@ func RegisterService(registry *boxService.Registry) {
|
||||
}
|
||||
|
||||
type errorResponse struct {
|
||||
Error errorDetails `json:"error"`
|
||||
}
|
||||
|
||||
type errorDetails struct {
|
||||
Type string `json:"type"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Message string `json:"message"`
|
||||
Error openaishared.ErrorObject `json:"error"`
|
||||
}
|
||||
|
||||
func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, errorType string, message string) {
|
||||
writeJSONErrorWithCode(w, r, statusCode, errorType, "", message)
|
||||
}
|
||||
|
||||
func writeJSONErrorWithCode(w http.ResponseWriter, r *http.Request, statusCode int, errorType string, errorCode string, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
json.NewEncoder(w).Encode(errorResponse{
|
||||
Error: errorDetails{
|
||||
Error: openaishared.ErrorObject{
|
||||
Type: errorType,
|
||||
Code: errorCode,
|
||||
Message: message,
|
||||
Param: "",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func writePlainTextError(w http.ResponseWriter, statusCode int, message string) {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.WriteHeader(statusCode)
|
||||
_, _ = io.WriteString(w, message)
|
||||
}
|
||||
|
||||
const (
|
||||
retryableUsageMessage = "current credential reached its usage limit; retry the request to use another credential"
|
||||
)
|
||||
|
||||
func hasAlternativeCredential(provider credentialProvider, currentCredential Credential, selection credentialSelection) bool {
|
||||
if provider == nil || currentCredential == nil {
|
||||
return false
|
||||
}
|
||||
for _, credential := range provider.allCredentials() {
|
||||
if credential == currentCredential {
|
||||
continue
|
||||
}
|
||||
if !selection.allows(credential) {
|
||||
continue
|
||||
}
|
||||
if credential.isUsable() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func unavailableCredentialMessage(provider credentialProvider, fallback string) string {
|
||||
if provider == nil {
|
||||
return fallback
|
||||
}
|
||||
message := allRateLimitedError(provider.allCredentials()).Error()
|
||||
if message == "all credentials unavailable" && fallback != "" {
|
||||
return fallback
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
||||
func writeRetryableUsageError(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSONErrorWithCode(w, r, http.StatusTooManyRequests, "usage_limit_reached", "", retryableUsageMessage)
|
||||
}
|
||||
|
||||
func writeNonRetryableCredentialError(w http.ResponseWriter, message string) {
|
||||
writePlainTextError(w, http.StatusBadRequest, message)
|
||||
}
|
||||
|
||||
func writeCredentialUnavailableError(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
provider credentialProvider,
|
||||
currentCredential Credential,
|
||||
selection credentialSelection,
|
||||
fallback string,
|
||||
) {
|
||||
if hasAlternativeCredential(provider, currentCredential, selection) {
|
||||
writeRetryableUsageError(w, r)
|
||||
return
|
||||
}
|
||||
if provider != nil && strings.HasPrefix(allRateLimitedError(provider.allCredentials()).Error(), "all credentials rate-limited") {
|
||||
writeRetryableUsageError(w, r)
|
||||
return
|
||||
}
|
||||
writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, fallback))
|
||||
}
|
||||
|
||||
func credentialSelectionForUser(userConfig *option.OCMUser) credentialSelection {
|
||||
selection := credentialSelection{scope: credentialSelectionScopeAll}
|
||||
if userConfig != nil && !userConfig.AllowExternalUsage {
|
||||
selection.scope = credentialSelectionScopeNonExternal
|
||||
selection.filter = func(credential Credential) bool {
|
||||
return !credential.isExternal()
|
||||
}
|
||||
}
|
||||
return selection
|
||||
}
|
||||
|
||||
func isHopByHopHeader(header string) bool {
|
||||
switch strings.ToLower(header) {
|
||||
case "connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade", "host":
|
||||
@@ -72,138 +139,112 @@ func isHopByHopHeader(header string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeRateLimitIdentifier(limitIdentifier string) string {
|
||||
trimmedIdentifier := strings.TrimSpace(strings.ToLower(limitIdentifier))
|
||||
if trimmedIdentifier == "" {
|
||||
return ""
|
||||
func isReverseProxyHeader(header string) bool {
|
||||
lowerHeader := strings.ToLower(header)
|
||||
if strings.HasPrefix(lowerHeader, "cf-") {
|
||||
return true
|
||||
}
|
||||
return strings.ReplaceAll(trimmedIdentifier, "_", "-")
|
||||
}
|
||||
|
||||
func parseInt64Header(headers http.Header, headerName string) (int64, bool) {
|
||||
headerValue := strings.TrimSpace(headers.Get(headerName))
|
||||
if headerValue == "" {
|
||||
return 0, false
|
||||
}
|
||||
parsedValue, parseError := strconv.ParseInt(headerValue, 10, 64)
|
||||
if parseError != nil {
|
||||
return 0, false
|
||||
}
|
||||
return parsedValue, true
|
||||
}
|
||||
|
||||
func weeklyCycleHintForLimit(headers http.Header, limitIdentifier string) *WeeklyCycleHint {
|
||||
normalizedLimitIdentifier := normalizeRateLimitIdentifier(limitIdentifier)
|
||||
if normalizedLimitIdentifier == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
windowHeader := "x-" + normalizedLimitIdentifier + "-secondary-window-minutes"
|
||||
resetHeader := "x-" + normalizedLimitIdentifier + "-secondary-reset-at"
|
||||
|
||||
windowMinutes, hasWindowMinutes := parseInt64Header(headers, windowHeader)
|
||||
resetAtUnix, hasResetAt := parseInt64Header(headers, resetHeader)
|
||||
if !hasWindowMinutes || !hasResetAt || windowMinutes <= 0 || resetAtUnix <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &WeeklyCycleHint{
|
||||
WindowMinutes: windowMinutes,
|
||||
ResetAt: time.Unix(resetAtUnix, 0).UTC(),
|
||||
switch lowerHeader {
|
||||
case "cdn-loop", "true-client-ip", "x-forwarded-for", "x-forwarded-proto", "x-real-ip":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint {
|
||||
activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
|
||||
if activeLimitIdentifier != "" {
|
||||
if activeHint := weeklyCycleHintForLimit(headers, activeLimitIdentifier); activeHint != nil {
|
||||
return activeHint
|
||||
}
|
||||
func isAPIKeyHeader(header string) bool {
|
||||
switch strings.ToLower(header) {
|
||||
case "x-api-key", "api-key":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return weeklyCycleHintForLimit(headers, "codex")
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
boxService.Adapter
|
||||
ctx context.Context
|
||||
logger log.ContextLogger
|
||||
credentialPath string
|
||||
credentials *oauthCredentials
|
||||
users []option.OCMUser
|
||||
dialer N.Dialer
|
||||
httpClient *http.Client
|
||||
httpHeaders http.Header
|
||||
listener *listener.Listener
|
||||
tlsConfig tls.ServerConfig
|
||||
httpServer *http.Server
|
||||
userManager *UserManager
|
||||
accessMutex sync.RWMutex
|
||||
usageTracker *AggregatedUsage
|
||||
webSocketMutex sync.Mutex
|
||||
webSocketGroup sync.WaitGroup
|
||||
webSocketConns map[*webSocketSession]struct{}
|
||||
shuttingDown bool
|
||||
ctx context.Context
|
||||
logger log.ContextLogger
|
||||
options option.OCMServiceOptions
|
||||
httpHeaders http.Header
|
||||
listener *listener.Listener
|
||||
tlsConfig tls.ServerConfig
|
||||
httpServer *http.Server
|
||||
userManager *UserManager
|
||||
webSocketAccess sync.Mutex
|
||||
webSocketGroup sync.WaitGroup
|
||||
webSocketConns map[*webSocketSession]struct{}
|
||||
shuttingDown bool
|
||||
|
||||
providers map[string]credentialProvider
|
||||
allCredentials []Credential
|
||||
userConfigMap map[string]*option.OCMUser
|
||||
statusSubscriber *observable.Subscriber[struct{}]
|
||||
statusObserver *observable.Observer[struct{}]
|
||||
}
|
||||
|
||||
func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.OCMServiceOptions) (adapter.Service, error) {
|
||||
serviceDialer, err := dialer.NewWithOptions(dialer.Options{
|
||||
Context: ctx,
|
||||
Options: option.DialerOptions{
|
||||
Detour: options.Detour,
|
||||
},
|
||||
RemoteIsDomain: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "create dialer")
|
||||
hasLegacy := options.CredentialPath != "" || options.UsagesPath != "" || options.Detour != ""
|
||||
if hasLegacy && len(options.Credentials) > 0 {
|
||||
return nil, E.New("credential_path/usages_path/detour and credentials are mutually exclusive")
|
||||
}
|
||||
if len(options.Credentials) == 0 {
|
||||
options.Credentials = []option.OCMCredential{{
|
||||
Type: "default",
|
||||
Tag: "default",
|
||||
DefaultOptions: option.OCMDefaultCredentialOptions{
|
||||
CredentialPath: options.CredentialPath,
|
||||
UsagesPath: options.UsagesPath,
|
||||
Detour: options.Detour,
|
||||
},
|
||||
}}
|
||||
options.CredentialPath = ""
|
||||
options.UsagesPath = ""
|
||||
options.Detour = ""
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
ForceAttemptHTTP2: true,
|
||||
TLSClientConfig: &stdTLS.Config{
|
||||
RootCAs: adapter.RootPoolFromContext(ctx),
|
||||
Time: ntp.TimeFuncFromContext(ctx),
|
||||
},
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return serviceDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
},
|
||||
},
|
||||
err := validateOCMOptions(options)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "validate options")
|
||||
}
|
||||
|
||||
userManager := &UserManager{
|
||||
tokenMap: make(map[string]string),
|
||||
}
|
||||
|
||||
var usageTracker *AggregatedUsage
|
||||
if options.UsagesPath != "" {
|
||||
usageTracker = &AggregatedUsage{
|
||||
LastUpdated: time.Now(),
|
||||
Combinations: make([]CostCombination, 0),
|
||||
filePath: options.UsagesPath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
statusSubscriber := observable.NewSubscriber[struct{}](16)
|
||||
|
||||
service := &Service{
|
||||
Adapter: boxService.NewAdapter(C.TypeOCM, tag),
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
credentialPath: options.CredentialPath,
|
||||
users: options.Users,
|
||||
dialer: serviceDialer,
|
||||
httpClient: httpClient,
|
||||
httpHeaders: options.Headers.Build(),
|
||||
Adapter: boxService.NewAdapter(C.TypeOCM, tag),
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
options: options,
|
||||
httpHeaders: options.Headers.Build(),
|
||||
listener: listener.New(listener.Options{
|
||||
Context: ctx,
|
||||
Logger: logger,
|
||||
Network: []string{N.NetworkTCP},
|
||||
Listen: options.ListenOptions,
|
||||
}),
|
||||
userManager: userManager,
|
||||
usageTracker: usageTracker,
|
||||
webSocketConns: make(map[*webSocketSession]struct{}),
|
||||
userManager: userManager,
|
||||
statusSubscriber: statusSubscriber,
|
||||
statusObserver: observable.NewObserver[struct{}](statusSubscriber, 8),
|
||||
webSocketConns: make(map[*webSocketSession]struct{}),
|
||||
}
|
||||
|
||||
providers, allCredentials, err := buildOCMCredentialProviders(ctx, options, logger)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "build credential providers")
|
||||
}
|
||||
service.providers = providers
|
||||
service.allCredentials = allCredentials
|
||||
|
||||
userConfigMap := make(map[string]*option.OCMUser)
|
||||
for i := range options.Users {
|
||||
userConfigMap[options.Users[i].Name] = &options.Users[i]
|
||||
}
|
||||
service.userConfigMap = userConfigMap
|
||||
|
||||
if options.TLS != nil {
|
||||
tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
|
||||
if err != nil {
|
||||
@@ -220,28 +261,34 @@ func (s *Service) Start(stage adapter.StartStage) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.userManager.UpdateUsers(s.users)
|
||||
s.userManager.UpdateUsers(s.options.Users)
|
||||
|
||||
credentials, err := platformReadCredentials(s.credentialPath)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read credentials")
|
||||
}
|
||||
s.credentials = credentials
|
||||
|
||||
if s.usageTracker != nil {
|
||||
err = s.usageTracker.Load()
|
||||
if err != nil {
|
||||
s.logger.Warn("load usage statistics: ", err)
|
||||
for _, credential := range s.allCredentials {
|
||||
credential.setStatusSubscriber(s.statusSubscriber)
|
||||
if external, ok := credential.(*externalCredential); ok && external.reverse && external.connectorURL != nil {
|
||||
external.reverseService = s
|
||||
}
|
||||
err := credential.start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tag := credential.tagName()
|
||||
credential.setOnBecameUnusable(func() {
|
||||
s.interruptWebSocketSessionsForCredential(tag)
|
||||
})
|
||||
}
|
||||
err := validateOCMCompositeCredentialModes(s.options, s.providers)
|
||||
if err != nil {
|
||||
return E.Cause(err, "validate loaded credentials")
|
||||
}
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Mount("/", s)
|
||||
|
||||
s.httpServer = &http.Server{Handler: router}
|
||||
s.httpServer = &http.Server{Handler: h2c.NewHandler(router, &http2.Server{})}
|
||||
|
||||
if s.tlsConfig != nil {
|
||||
err = s.tlsConfig.Start()
|
||||
err := s.tlsConfig.Start()
|
||||
if err != nil {
|
||||
return E.Cause(err, "create TLS config")
|
||||
}
|
||||
@@ -261,7 +308,7 @@ func (s *Service) Start(stage adapter.StartStage) error {
|
||||
|
||||
go func() {
|
||||
serveErr := s.httpServer.Serve(tcpListener)
|
||||
if serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
|
||||
if serveErr != nil && !E.IsClosed(serveErr) {
|
||||
s.logger.Error("serve error: ", serveErr)
|
||||
}
|
||||
}()
|
||||
@@ -269,375 +316,22 @@ func (s *Service) Start(stage adapter.StartStage) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) getAccessToken() (string, error) {
|
||||
s.accessMutex.RLock()
|
||||
if !s.credentials.needsRefresh() {
|
||||
token := s.credentials.getAccessToken()
|
||||
s.accessMutex.RUnlock()
|
||||
return token, nil
|
||||
}
|
||||
s.accessMutex.RUnlock()
|
||||
|
||||
s.accessMutex.Lock()
|
||||
defer s.accessMutex.Unlock()
|
||||
|
||||
if !s.credentials.needsRefresh() {
|
||||
return s.credentials.getAccessToken(), nil
|
||||
}
|
||||
|
||||
newCredentials, err := refreshToken(s.httpClient, s.credentials)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
s.credentials = newCredentials
|
||||
|
||||
err = platformWriteCredentials(newCredentials, s.credentialPath)
|
||||
if err != nil {
|
||||
s.logger.Warn("persist refreshed token: ", err)
|
||||
}
|
||||
|
||||
return newCredentials.getAccessToken(), nil
|
||||
}
|
||||
|
||||
func (s *Service) getAccountID() string {
|
||||
s.accessMutex.RLock()
|
||||
defer s.accessMutex.RUnlock()
|
||||
return s.credentials.getAccountID()
|
||||
}
|
||||
|
||||
func (s *Service) isAPIKeyMode() bool {
|
||||
s.accessMutex.RLock()
|
||||
defer s.accessMutex.RUnlock()
|
||||
return s.credentials.isAPIKeyMode()
|
||||
}
|
||||
|
||||
func (s *Service) getBaseURL() string {
|
||||
if s.isAPIKeyMode() {
|
||||
return openaiAPIBaseURL
|
||||
}
|
||||
return chatGPTBackendURL
|
||||
}
|
||||
|
||||
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
path := r.URL.Path
|
||||
if !strings.HasPrefix(path, "/v1/") {
|
||||
writeJSONError(w, r, http.StatusNotFound, "invalid_request_error", "path must start with /v1/")
|
||||
return
|
||||
}
|
||||
|
||||
var proxyPath string
|
||||
if s.isAPIKeyMode() {
|
||||
proxyPath = path
|
||||
} else {
|
||||
if path == "/v1/chat/completions" {
|
||||
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
|
||||
"chat completions endpoint is only available in API key mode")
|
||||
return
|
||||
}
|
||||
proxyPath = strings.TrimPrefix(path, "/v1")
|
||||
}
|
||||
|
||||
var username string
|
||||
if len(s.users) > 0 {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
||||
return
|
||||
}
|
||||
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if clientToken == authHeader {
|
||||
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
||||
return
|
||||
}
|
||||
var ok bool
|
||||
username, ok = s.userManager.Authenticate(clientToken)
|
||||
func (s *Service) InterfaceUpdated() {
|
||||
for _, credential := range s.allCredentials {
|
||||
external, ok := credential.(*externalCredential)
|
||||
if !ok {
|
||||
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
|
||||
return
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") {
|
||||
s.handleWebSocket(w, r, proxyPath, username)
|
||||
return
|
||||
}
|
||||
|
||||
var requestModel string
|
||||
|
||||
if s.usageTracker != nil && r.Body != nil {
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
if err == nil {
|
||||
var request struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
err := json.Unmarshal(bodyBytes, &request)
|
||||
if err == nil {
|
||||
requestModel = request.Model
|
||||
}
|
||||
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
}
|
||||
}
|
||||
|
||||
accessToken, err := s.getAccessToken()
|
||||
if err != nil {
|
||||
s.logger.Error("get access token: ", err)
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "Authentication failed")
|
||||
return
|
||||
}
|
||||
|
||||
proxyURL := s.getBaseURL() + proxyPath
|
||||
if r.URL.RawQuery != "" {
|
||||
proxyURL += "?" + r.URL.RawQuery
|
||||
}
|
||||
proxyRequest, err := http.NewRequestWithContext(r.Context(), r.Method, proxyURL, r.Body)
|
||||
if err != nil {
|
||||
s.logger.Error("create proxy request: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
|
||||
return
|
||||
}
|
||||
|
||||
for key, values := range r.Header {
|
||||
if !isHopByHopHeader(key) && key != "Authorization" {
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
}
|
||||
|
||||
for key, values := range s.httpHeaders {
|
||||
proxyRequest.Header.Del(key)
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
|
||||
proxyRequest.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
if accountID := s.getAccountID(); accountID != "" {
|
||||
proxyRequest.Header.Set("ChatGPT-Account-Id", accountID)
|
||||
}
|
||||
|
||||
response, err := s.httpClient.Do(proxyRequest)
|
||||
if err != nil {
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
for key, values := range response.Header {
|
||||
if !isHopByHopHeader(key) {
|
||||
w.Header()[key] = values
|
||||
}
|
||||
}
|
||||
w.WriteHeader(response.StatusCode)
|
||||
|
||||
trackUsage := s.usageTracker != nil && response.StatusCode == http.StatusOK &&
|
||||
(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses"))
|
||||
if trackUsage {
|
||||
s.handleResponseWithTracking(w, response, path, requestModel, username)
|
||||
} else {
|
||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||
if err == nil && mediaType != "text/event-stream" {
|
||||
_, _ = io.Copy(w, response.Body)
|
||||
return
|
||||
}
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
s.logger.Error("streaming not supported")
|
||||
return
|
||||
}
|
||||
buffer := make([]byte, buf.BufferSize)
|
||||
for {
|
||||
n, err := response.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
_, writeError := w.Write(buffer[:n])
|
||||
if writeError != nil {
|
||||
s.logger.Error("write streaming response: ", writeError)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, path string, requestModel string, username string) {
|
||||
isChatCompletions := path == "/v1/chat/completions"
|
||||
weeklyCycleHint := extractWeeklyCycleHint(response.Header)
|
||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||
isStreaming := err == nil && mediaType == "text/event-stream"
|
||||
if !isStreaming && !isChatCompletions && response.Header.Get("Content-Type") == "" {
|
||||
isStreaming = true
|
||||
}
|
||||
if !isStreaming {
|
||||
bodyBytes, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
s.logger.Error("read response body: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
var responseModel, serviceTier string
|
||||
var inputTokens, outputTokens, cachedTokens int64
|
||||
|
||||
if isChatCompletions {
|
||||
var chatCompletion openai.ChatCompletion
|
||||
if json.Unmarshal(bodyBytes, &chatCompletion) == nil {
|
||||
responseModel = chatCompletion.Model
|
||||
serviceTier = string(chatCompletion.ServiceTier)
|
||||
inputTokens = chatCompletion.Usage.PromptTokens
|
||||
outputTokens = chatCompletion.Usage.CompletionTokens
|
||||
cachedTokens = chatCompletion.Usage.PromptTokensDetails.CachedTokens
|
||||
}
|
||||
} else {
|
||||
var responsesResponse responses.Response
|
||||
if json.Unmarshal(bodyBytes, &responsesResponse) == nil {
|
||||
responseModel = string(responsesResponse.Model)
|
||||
serviceTier = string(responsesResponse.ServiceTier)
|
||||
inputTokens = responsesResponse.Usage.InputTokens
|
||||
outputTokens = responsesResponse.Usage.OutputTokens
|
||||
cachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
|
||||
if inputTokens > 0 || outputTokens > 0 {
|
||||
if responseModel == "" {
|
||||
responseModel = requestModel
|
||||
}
|
||||
if responseModel != "" {
|
||||
contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens)
|
||||
s.usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
contextWindow,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
cachedTokens,
|
||||
serviceTier,
|
||||
username,
|
||||
time.Now(),
|
||||
weeklyCycleHint,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
_, _ = writer.Write(bodyBytes)
|
||||
return
|
||||
}
|
||||
|
||||
flusher, ok := writer.(http.Flusher)
|
||||
if !ok {
|
||||
s.logger.Error("streaming not supported")
|
||||
return
|
||||
}
|
||||
|
||||
var inputTokens, outputTokens, cachedTokens int64
|
||||
var responseModel, serviceTier string
|
||||
buffer := make([]byte, buf.BufferSize)
|
||||
var leftover []byte
|
||||
|
||||
for {
|
||||
n, err := response.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
data := append(leftover, buffer[:n]...)
|
||||
lines := bytes.Split(data, []byte("\n"))
|
||||
|
||||
if err == nil {
|
||||
leftover = lines[len(lines)-1]
|
||||
lines = lines[:len(lines)-1]
|
||||
} else {
|
||||
leftover = nil
|
||||
}
|
||||
|
||||
for _, line := range lines {
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if bytes.HasPrefix(line, []byte("data: ")) {
|
||||
eventData := bytes.TrimPrefix(line, []byte("data: "))
|
||||
if bytes.Equal(eventData, []byte("[DONE]")) {
|
||||
continue
|
||||
}
|
||||
|
||||
if isChatCompletions {
|
||||
var chatChunk openai.ChatCompletionChunk
|
||||
if json.Unmarshal(eventData, &chatChunk) == nil {
|
||||
if chatChunk.Model != "" {
|
||||
responseModel = chatChunk.Model
|
||||
}
|
||||
if chatChunk.ServiceTier != "" {
|
||||
serviceTier = string(chatChunk.ServiceTier)
|
||||
}
|
||||
if chatChunk.Usage.PromptTokens > 0 {
|
||||
inputTokens = chatChunk.Usage.PromptTokens
|
||||
cachedTokens = chatChunk.Usage.PromptTokensDetails.CachedTokens
|
||||
}
|
||||
if chatChunk.Usage.CompletionTokens > 0 {
|
||||
outputTokens = chatChunk.Usage.CompletionTokens
|
||||
}
|
||||
}
|
||||
} else {
|
||||
var streamEvent responses.ResponseStreamEventUnion
|
||||
if json.Unmarshal(eventData, &streamEvent) == nil {
|
||||
if streamEvent.Type == "response.completed" {
|
||||
completedEvent := streamEvent.AsResponseCompleted()
|
||||
if string(completedEvent.Response.Model) != "" {
|
||||
responseModel = string(completedEvent.Response.Model)
|
||||
}
|
||||
if completedEvent.Response.ServiceTier != "" {
|
||||
serviceTier = string(completedEvent.Response.ServiceTier)
|
||||
}
|
||||
if completedEvent.Response.Usage.InputTokens > 0 {
|
||||
inputTokens = completedEvent.Response.Usage.InputTokens
|
||||
cachedTokens = completedEvent.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
if completedEvent.Response.Usage.OutputTokens > 0 {
|
||||
outputTokens = completedEvent.Response.Usage.OutputTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_, writeError := writer.Write(buffer[:n])
|
||||
if writeError != nil {
|
||||
s.logger.Error("write streaming response: ", writeError)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if responseModel == "" {
|
||||
responseModel = requestModel
|
||||
}
|
||||
|
||||
if inputTokens > 0 || outputTokens > 0 {
|
||||
if responseModel != "" {
|
||||
contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens)
|
||||
s.usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
contextWindow,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
cachedTokens,
|
||||
serviceTier,
|
||||
username,
|
||||
time.Now(),
|
||||
weeklyCycleHint,
|
||||
)
|
||||
}
|
||||
}
|
||||
return
|
||||
if external.reverse && external.connectorURL != nil {
|
||||
external.reverseService = s
|
||||
external.resetReverseContext()
|
||||
go external.connectorLoop()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Close() error {
|
||||
s.statusObserver.Close()
|
||||
webSocketSessions := s.startWebSocketShutdown()
|
||||
|
||||
err := common.Close(
|
||||
@@ -650,20 +344,16 @@ func (s *Service) Close() error {
|
||||
}
|
||||
s.webSocketGroup.Wait()
|
||||
|
||||
if s.usageTracker != nil {
|
||||
s.usageTracker.cancelPendingSave()
|
||||
saveErr := s.usageTracker.Save()
|
||||
if saveErr != nil {
|
||||
s.logger.Error("save usage statistics: ", saveErr)
|
||||
}
|
||||
for _, credential := range s.allCredentials {
|
||||
credential.close()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Service) registerWebSocketSession(session *webSocketSession) bool {
|
||||
s.webSocketMutex.Lock()
|
||||
defer s.webSocketMutex.Unlock()
|
||||
s.webSocketAccess.Lock()
|
||||
defer s.webSocketAccess.Unlock()
|
||||
|
||||
if s.shuttingDown {
|
||||
return false
|
||||
@@ -675,12 +365,12 @@ func (s *Service) registerWebSocketSession(session *webSocketSession) bool {
|
||||
}
|
||||
|
||||
func (s *Service) unregisterWebSocketSession(session *webSocketSession) {
|
||||
s.webSocketMutex.Lock()
|
||||
s.webSocketAccess.Lock()
|
||||
_, loaded := s.webSocketConns[session]
|
||||
if loaded {
|
||||
delete(s.webSocketConns, session)
|
||||
}
|
||||
s.webSocketMutex.Unlock()
|
||||
s.webSocketAccess.Unlock()
|
||||
|
||||
if loaded {
|
||||
s.webSocketGroup.Done()
|
||||
@@ -688,14 +378,28 @@ func (s *Service) unregisterWebSocketSession(session *webSocketSession) {
|
||||
}
|
||||
|
||||
func (s *Service) isShuttingDown() bool {
|
||||
s.webSocketMutex.Lock()
|
||||
defer s.webSocketMutex.Unlock()
|
||||
s.webSocketAccess.Lock()
|
||||
defer s.webSocketAccess.Unlock()
|
||||
return s.shuttingDown
|
||||
}
|
||||
|
||||
func (s *Service) interruptWebSocketSessionsForCredential(tag string) {
|
||||
s.webSocketAccess.Lock()
|
||||
var toClose []*webSocketSession
|
||||
for session := range s.webSocketConns {
|
||||
if session.credentialTag == tag {
|
||||
toClose = append(toClose, session)
|
||||
}
|
||||
}
|
||||
s.webSocketAccess.Unlock()
|
||||
for _, session := range toClose {
|
||||
session.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) startWebSocketShutdown() []*webSocketSession {
|
||||
s.webSocketMutex.Lock()
|
||||
defer s.webSocketMutex.Unlock()
|
||||
s.webSocketAccess.Lock()
|
||||
defer s.webSocketAccess.Unlock()
|
||||
|
||||
s.shuttingDown = true
|
||||
|
||||
|
||||
512
service/ocm/service_handler.go
Normal file
512
service/ocm/service_handler.go
Normal file
@@ -0,0 +1,512 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
|
||||
"github.com/openai/openai-go/v3"
|
||||
"github.com/openai/openai-go/v3/responses"
|
||||
)
|
||||
|
||||
func weeklyCycleHintForLimit(headers http.Header, limitIdentifier string) *WeeklyCycleHint {
|
||||
normalizedLimitIdentifier := normalizeRateLimitIdentifier(limitIdentifier)
|
||||
if normalizedLimitIdentifier == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
windowHeader := "x-" + normalizedLimitIdentifier + "-secondary-window-minutes"
|
||||
resetHeader := "x-" + normalizedLimitIdentifier + "-secondary-reset-at"
|
||||
|
||||
windowMinutes, hasWindowMinutes := parseInt64Header(headers, windowHeader)
|
||||
resetAtUnix, hasResetAt := parseInt64Header(headers, resetHeader)
|
||||
if !hasWindowMinutes || !hasResetAt || windowMinutes <= 0 || resetAtUnix <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &WeeklyCycleHint{
|
||||
WindowMinutes: windowMinutes,
|
||||
ResetAt: time.Unix(resetAtUnix, 0).UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint {
|
||||
activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
|
||||
if activeLimitIdentifier != "" {
|
||||
if activeHint := weeklyCycleHintForLimit(headers, activeLimitIdentifier); activeHint != nil {
|
||||
return activeHint
|
||||
}
|
||||
}
|
||||
return weeklyCycleHintForLimit(headers, "codex")
|
||||
}
|
||||
|
||||
func (s *Service) resolveCredentialProvider(username string) (credentialProvider, error) {
|
||||
if len(s.options.Users) > 0 {
|
||||
return credentialForUser(s.userConfigMap, s.providers, username)
|
||||
}
|
||||
provider := s.providers[s.options.Credentials[0].Tag]
|
||||
if provider == nil {
|
||||
return nil, E.New("no credential available")
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := log.ContextWithNewID(r.Context())
|
||||
if r.URL.Path == "/ocm/v1/status" {
|
||||
s.handleStatusEndpoint(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Path == "/ocm/v1/reverse" {
|
||||
s.handleReverseConnect(ctx, w, r)
|
||||
return
|
||||
}
|
||||
|
||||
path := r.URL.Path
|
||||
if !strings.HasPrefix(path, "/v1/") {
|
||||
writeJSONError(w, r, http.StatusNotFound, "invalid_request_error", "path must start with /v1/")
|
||||
return
|
||||
}
|
||||
|
||||
if r.Header.Get("X-Api-Key") != "" || r.Header.Get("Api-Key") != "" {
|
||||
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
|
||||
"API key authentication is not supported; use Authorization: Bearer with an OCM user token")
|
||||
return
|
||||
}
|
||||
|
||||
var username string
|
||||
if len(s.options.Users) > 0 {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
||||
return
|
||||
}
|
||||
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if clientToken == authHeader {
|
||||
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
||||
return
|
||||
}
|
||||
var ok bool
|
||||
username, ok = s.userManager.Authenticate(clientToken)
|
||||
if !ok {
|
||||
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
sessionID := r.Header.Get("session_id")
|
||||
|
||||
// Resolve credential provider and user config
|
||||
var provider credentialProvider
|
||||
var userConfig *option.OCMUser
|
||||
if len(s.options.Users) > 0 {
|
||||
userConfig = s.userConfigMap[username]
|
||||
var err error
|
||||
provider, err = credentialForUser(s.userConfigMap, s.providers, username)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "resolve credential: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
} else {
|
||||
provider = s.providers[s.options.Credentials[0].Tag]
|
||||
}
|
||||
if provider == nil {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available")
|
||||
return
|
||||
}
|
||||
|
||||
provider.pollIfStale()
|
||||
if userConfig != nil && userConfig.ExternalCredential != "" {
|
||||
for _, credential := range s.allCredentials {
|
||||
if credential.tagName() == userConfig.ExternalCredential && !credential.isUsable() {
|
||||
credential.pollUsage()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
selection := credentialSelectionForUser(userConfig)
|
||||
|
||||
selectedCredential, isNew, err := provider.selectCredential(sessionID, selection)
|
||||
if err != nil {
|
||||
writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") {
|
||||
s.handleWebSocket(ctx, w, r, path, username, sessionID, userConfig, provider, selectedCredential, selection, isNew)
|
||||
return
|
||||
}
|
||||
|
||||
if !selectedCredential.isExternal() && selectedCredential.ocmIsAPIKeyMode() {
|
||||
// API key mode path handling
|
||||
} else if !selectedCredential.isExternal() {
|
||||
if path == "/v1/chat/completions" {
|
||||
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
|
||||
"chat completions endpoint is only available in API key mode")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
shouldTrackUsage := selectedCredential.usageTrackerOrNil() != nil &&
|
||||
(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses"))
|
||||
canRetryRequest := len(provider.allCredentials()) > 1
|
||||
|
||||
// Read body for model extraction and retry buffer when JSON replay is useful.
|
||||
var bodyBytes []byte
|
||||
var requestMetadata requestLogMetadata
|
||||
var requestModel string
|
||||
if r.Body != nil && (isNew || shouldTrackUsage || canRetryRequest) {
|
||||
mediaType, _, parseErr := mime.ParseMediaType(r.Header.Get("Content-Type"))
|
||||
isJSONRequest := parseErr == nil && (mediaType == "application/json" || strings.HasSuffix(mediaType, "+json"))
|
||||
if isJSONRequest {
|
||||
bodyBytes, err = io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "read request body: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body")
|
||||
return
|
||||
}
|
||||
requestMetadata = parseRequestLogMetadata(path, bodyBytes)
|
||||
requestModel = requestMetadata.Model
|
||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
}
|
||||
}
|
||||
|
||||
if isNew {
|
||||
s.logger.DebugContext(ctx, buildAssignedCredentialLogParts(selectedCredential.tagName(), sessionID, username, requestMetadata)...)
|
||||
}
|
||||
|
||||
requestContext := selectedCredential.wrapRequestContext(ctx)
|
||||
{
|
||||
currentRequestContext := requestContext
|
||||
requestContext.addInterruptLink(provider.linkProviderInterrupt(selectedCredential, selection, func() {
|
||||
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
|
||||
}))
|
||||
}
|
||||
defer func() {
|
||||
requestContext.cancelRequest()
|
||||
}()
|
||||
proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "create proxy request: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
|
||||
return
|
||||
}
|
||||
|
||||
response, err := selectedCredential.httpClient().Do(proxyRequest)
|
||||
if err != nil {
|
||||
if r.Context().Err() != nil {
|
||||
return
|
||||
}
|
||||
if requestContext.Err() != nil {
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "credential became unavailable while processing the request")
|
||||
return
|
||||
}
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
requestContext.releaseCredentialInterrupt()
|
||||
|
||||
// Transparent 429 retry
|
||||
for response.StatusCode == http.StatusTooManyRequests {
|
||||
resetAt := parseOCMRateLimitResetFromHeaders(response.Header)
|
||||
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, selection)
|
||||
needsBodyReplay := r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodDelete
|
||||
selectedCredential.updateStateFromHeaders(response.Header)
|
||||
if (needsBodyReplay && bodyBytes == nil) || nextCredential == nil {
|
||||
response.Body.Close()
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "all credentials rate-limited")
|
||||
return
|
||||
}
|
||||
response.Body.Close()
|
||||
s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
|
||||
requestContext.cancelRequest()
|
||||
requestContext = nextCredential.wrapRequestContext(ctx)
|
||||
{
|
||||
currentRequestContext := requestContext
|
||||
requestContext.addInterruptLink(provider.linkProviderInterrupt(nextCredential, selection, func() {
|
||||
currentRequestContext.cancelOnce.Do(currentRequestContext.cancelFunc)
|
||||
}))
|
||||
}
|
||||
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||
if buildErr != nil {
|
||||
s.logger.ErrorContext(ctx, "retry request: ", buildErr)
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
|
||||
return
|
||||
}
|
||||
retryResponse, retryErr := nextCredential.httpClient().Do(retryRequest)
|
||||
if retryErr != nil {
|
||||
if r.Context().Err() != nil {
|
||||
return
|
||||
}
|
||||
if requestContext.Err() != nil {
|
||||
writeCredentialUnavailableError(w, r, provider, nextCredential, selection, "credential became unavailable while retrying the request")
|
||||
return
|
||||
}
|
||||
s.logger.ErrorContext(ctx, "retry request: ", retryErr)
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error())
|
||||
return
|
||||
}
|
||||
requestContext.releaseCredentialInterrupt()
|
||||
response = retryResponse
|
||||
selectedCredential = nextCredential
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
selectedCredential.updateStateFromHeaders(response.Header)
|
||||
|
||||
if response.StatusCode == http.StatusBadRequest {
|
||||
if selectedCredential.isExternal() {
|
||||
selectedCredential.markUpstreamRejected()
|
||||
} else {
|
||||
provider.pollCredentialIfStale(selectedCredential)
|
||||
}
|
||||
s.logger.ErrorContext(ctx, "upstream rejected from ", selectedCredential.tagName(), ": status ", response.StatusCode)
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, selection, "upstream rejected credential")
|
||||
return
|
||||
}
|
||||
|
||||
if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error",
|
||||
"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
|
||||
return
|
||||
}
|
||||
|
||||
s.rewriteResponseHeaders(response.Header, provider, userConfig)
|
||||
|
||||
for key, values := range response.Header {
|
||||
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) {
|
||||
w.Header()[key] = values
|
||||
}
|
||||
}
|
||||
w.WriteHeader(response.StatusCode)
|
||||
|
||||
usageTracker := selectedCredential.usageTrackerOrNil()
|
||||
if usageTracker != nil && response.StatusCode == http.StatusOK &&
|
||||
(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) {
|
||||
s.handleResponseWithTracking(ctx, w, response, usageTracker, path, requestModel, username)
|
||||
} else {
|
||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||
if err == nil && mediaType != "text/event-stream" {
|
||||
_, _ = io.Copy(w, response.Body)
|
||||
return
|
||||
}
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
s.logger.ErrorContext(ctx, "streaming not supported")
|
||||
return
|
||||
}
|
||||
buffer := make([]byte, buf.BufferSize)
|
||||
for {
|
||||
n, err := response.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
_, writeError := w.Write(buffer[:n])
|
||||
if writeError != nil {
|
||||
if E.IsClosedOrCanceled(writeError) {
|
||||
return
|
||||
}
|
||||
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, path string, requestModel string, username string) {
|
||||
isChatCompletions := path == "/v1/chat/completions"
|
||||
weeklyCycleHint := extractWeeklyCycleHint(response.Header)
|
||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||
isStreaming := err == nil && mediaType == "text/event-stream"
|
||||
if !isStreaming && !isChatCompletions && response.Header.Get("Content-Type") == "" {
|
||||
isStreaming = true
|
||||
}
|
||||
if !isStreaming {
|
||||
bodyBytes, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "read response body: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
var responseModel, serviceTier string
|
||||
var inputTokens, outputTokens, cachedTokens int64
|
||||
|
||||
if isChatCompletions {
|
||||
var chatCompletion openai.ChatCompletion
|
||||
if json.Unmarshal(bodyBytes, &chatCompletion) == nil {
|
||||
responseModel = chatCompletion.Model
|
||||
serviceTier = string(chatCompletion.ServiceTier)
|
||||
inputTokens = chatCompletion.Usage.PromptTokens
|
||||
outputTokens = chatCompletion.Usage.CompletionTokens
|
||||
cachedTokens = chatCompletion.Usage.PromptTokensDetails.CachedTokens
|
||||
}
|
||||
} else {
|
||||
var responsesResponse responses.Response
|
||||
if json.Unmarshal(bodyBytes, &responsesResponse) == nil {
|
||||
responseModel = string(responsesResponse.Model)
|
||||
serviceTier = string(responsesResponse.ServiceTier)
|
||||
inputTokens = responsesResponse.Usage.InputTokens
|
||||
outputTokens = responsesResponse.Usage.OutputTokens
|
||||
cachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
|
||||
if inputTokens > 0 || outputTokens > 0 {
|
||||
if responseModel == "" {
|
||||
responseModel = requestModel
|
||||
}
|
||||
if responseModel != "" {
|
||||
contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens)
|
||||
usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
contextWindow,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
cachedTokens,
|
||||
serviceTier,
|
||||
username,
|
||||
time.Now(),
|
||||
weeklyCycleHint,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
_, _ = writer.Write(bodyBytes)
|
||||
return
|
||||
}
|
||||
|
||||
flusher, ok := writer.(http.Flusher)
|
||||
if !ok {
|
||||
s.logger.ErrorContext(ctx, "streaming not supported")
|
||||
return
|
||||
}
|
||||
|
||||
var inputTokens, outputTokens, cachedTokens int64
|
||||
var responseModel, serviceTier string
|
||||
buffer := make([]byte, buf.BufferSize)
|
||||
var leftover []byte
|
||||
|
||||
for {
|
||||
n, err := response.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
data := append(leftover, buffer[:n]...)
|
||||
lines := bytes.Split(data, []byte("\n"))
|
||||
|
||||
if err == nil {
|
||||
leftover = lines[len(lines)-1]
|
||||
lines = lines[:len(lines)-1]
|
||||
} else {
|
||||
leftover = nil
|
||||
}
|
||||
|
||||
for _, line := range lines {
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if bytes.HasPrefix(line, []byte("data: ")) {
|
||||
eventData := bytes.TrimPrefix(line, []byte("data: "))
|
||||
if bytes.Equal(eventData, []byte("[DONE]")) {
|
||||
continue
|
||||
}
|
||||
|
||||
if isChatCompletions {
|
||||
var chatChunk openai.ChatCompletionChunk
|
||||
if json.Unmarshal(eventData, &chatChunk) == nil {
|
||||
if chatChunk.Model != "" {
|
||||
responseModel = chatChunk.Model
|
||||
}
|
||||
if chatChunk.ServiceTier != "" {
|
||||
serviceTier = string(chatChunk.ServiceTier)
|
||||
}
|
||||
if chatChunk.Usage.PromptTokens > 0 {
|
||||
inputTokens = chatChunk.Usage.PromptTokens
|
||||
cachedTokens = chatChunk.Usage.PromptTokensDetails.CachedTokens
|
||||
}
|
||||
if chatChunk.Usage.CompletionTokens > 0 {
|
||||
outputTokens = chatChunk.Usage.CompletionTokens
|
||||
}
|
||||
}
|
||||
} else {
|
||||
var streamEvent responses.ResponseStreamEventUnion
|
||||
if json.Unmarshal(eventData, &streamEvent) == nil {
|
||||
if streamEvent.Type == "response.completed" {
|
||||
completedEvent := streamEvent.AsResponseCompleted()
|
||||
if string(completedEvent.Response.Model) != "" {
|
||||
responseModel = string(completedEvent.Response.Model)
|
||||
}
|
||||
if completedEvent.Response.ServiceTier != "" {
|
||||
serviceTier = string(completedEvent.Response.ServiceTier)
|
||||
}
|
||||
if completedEvent.Response.Usage.InputTokens > 0 {
|
||||
inputTokens = completedEvent.Response.Usage.InputTokens
|
||||
cachedTokens = completedEvent.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
if completedEvent.Response.Usage.OutputTokens > 0 {
|
||||
outputTokens = completedEvent.Response.Usage.OutputTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_, writeError := writer.Write(buffer[:n])
|
||||
if writeError != nil {
|
||||
if E.IsClosedOrCanceled(writeError) {
|
||||
return
|
||||
}
|
||||
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if responseModel == "" {
|
||||
responseModel = requestModel
|
||||
}
|
||||
|
||||
if inputTokens > 0 || outputTokens > 0 {
|
||||
if responseModel != "" {
|
||||
contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens)
|
||||
usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
contextWindow,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
cachedTokens,
|
||||
serviceTier,
|
||||
username,
|
||||
time.Now(),
|
||||
weeklyCycleHint,
|
||||
)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user