mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-11 17:47:20 +10:00
Compare commits
40 Commits
fafe3847ec
...
unstable
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5b29fd3be4 | ||
|
|
016e5e1b12 | ||
|
|
92b3bde862 | ||
|
|
64f7349fca | ||
|
|
527372ba74 | ||
|
|
c6c07cb52f | ||
|
|
913f033d1a | ||
|
|
688f8cc4ef | ||
|
|
de51879ae9 | ||
|
|
2943e8e5f0 | ||
|
|
e2b2af8322 | ||
|
|
c8d8d0a3e7 | ||
|
|
8cd7713ca9 | ||
|
|
566abb00cd | ||
|
|
ae7550b465 | ||
|
|
63d4cdffef | ||
|
|
5516d7b045 | ||
|
|
c639c27cdb | ||
|
|
f0022f59a2 | ||
|
|
9e7d863ee7 | ||
|
|
d5c6c6aed2 | ||
|
|
4d89d732e2 | ||
|
|
f6821be8a3 | ||
|
|
03b01efe49 | ||
|
|
16aeba8ec0 | ||
|
|
283a5aacee | ||
|
|
8d852bba9b | ||
|
|
29c8794f45 | ||
|
|
c8d593503f | ||
|
|
a8934be7cd | ||
|
|
7aef716ebc | ||
|
|
7df171ff20 | ||
|
|
46eda3e96f | ||
|
|
727a9d18d6 | ||
|
|
20f60b8c7b | ||
|
|
84b0ddff7f | ||
|
|
811ea13b73 | ||
|
|
bdb90f0a01 | ||
|
|
c9ab6458fa | ||
|
|
16a249f672 |
2
.github/CRONET_GO_VERSION
vendored
2
.github/CRONET_GO_VERSION
vendored
@@ -1 +1 @@
|
||||
2fef65f9dba90ddb89a87d00a6eb6165487c10c1
|
||||
ea7cd33752aed62603775af3df946c1b83f4b0b3
|
||||
|
||||
@@ -2,6 +2,7 @@ package adapter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
@@ -82,6 +83,8 @@ type InboundContext struct {
|
||||
SourceGeoIPCode string
|
||||
GeoIPCode string
|
||||
ProcessInfo *ConnectionOwner
|
||||
SourceMACAddress net.HardwareAddr
|
||||
SourceHostname string
|
||||
QueryType uint16
|
||||
FakeIP bool
|
||||
|
||||
|
||||
23
adapter/neighbor.go
Normal file
23
adapter/neighbor.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package adapter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type NeighborEntry struct {
|
||||
Address netip.Addr
|
||||
MACAddress net.HardwareAddr
|
||||
Hostname string
|
||||
}
|
||||
|
||||
type NeighborResolver interface {
|
||||
LookupMAC(address netip.Addr) (net.HardwareAddr, bool)
|
||||
LookupHostname(address netip.Addr) (string, bool)
|
||||
Start() error
|
||||
Close() error
|
||||
}
|
||||
|
||||
type NeighborUpdateListener interface {
|
||||
UpdateNeighborTable(entries []NeighborEntry)
|
||||
}
|
||||
@@ -36,6 +36,10 @@ type PlatformInterface interface {
|
||||
|
||||
UsePlatformNotification() bool
|
||||
SendNotification(notification *Notification) error
|
||||
|
||||
UsePlatformNeighborResolver() bool
|
||||
StartNeighborMonitor(listener NeighborUpdateListener) error
|
||||
CloseNeighborMonitor(listener NeighborUpdateListener) error
|
||||
}
|
||||
|
||||
type FindConnectionOwnerRequest struct {
|
||||
|
||||
@@ -26,6 +26,8 @@ type Router interface {
|
||||
RuleSet(tag string) (RuleSet, bool)
|
||||
Rules() []Rule
|
||||
NeedFindProcess() bool
|
||||
NeedFindNeighbor() bool
|
||||
NeighborResolver() NeighborResolver
|
||||
AppendTracker(tracker ConnectionTracker)
|
||||
ResetNetwork()
|
||||
}
|
||||
|
||||
Submodule clients/android updated: 7777469b5d...0d31ac467f
Submodule clients/apple updated: c19945f65b...22dcf646ce
@@ -2,6 +2,79 @@
|
||||
icon: material/alert-decagram
|
||||
---
|
||||
|
||||
#### 1.14.0-alpha.2
|
||||
|
||||
* Add OpenWrt and Alpine APK packages to release **1**
|
||||
* Backport to macOS 10.13 High Sierra **2**
|
||||
* OCM service: Add WebSocket support for Responses API **3**
|
||||
* Fixes and improvements
|
||||
|
||||
**1**:
|
||||
|
||||
Alpine APK files use `linux` in the filename to distinguish from OpenWrt APKs which use the `openwrt` prefix:
|
||||
|
||||
- OpenWrt: `sing-box_{version}_openwrt_{architecture}.apk`
|
||||
- Alpine: `sing-box_{version}_linux_{architecture}.apk`
|
||||
|
||||
**2**:
|
||||
|
||||
Legacy macOS binaries (with `-legacy-macos-10.13` suffix) now support
|
||||
macOS 10.13 High Sierra, built using Go 1.25 with patches
|
||||
from [SagerNet/go](https://github.com/SagerNet/go).
|
||||
|
||||
**3**:
|
||||
|
||||
See [OCM](/configuration/service/ocm).
|
||||
|
||||
#### 1.13.3-beta.1
|
||||
|
||||
* Add OpenWrt and Alpine APK packages to release **1**
|
||||
* Backport to macOS 10.13 High Sierra **2**
|
||||
* OCM service: Add WebSocket support for Responses API **3**
|
||||
* Fixes and improvements
|
||||
|
||||
**1**:
|
||||
|
||||
Alpine APK files use `linux` in the filename to distinguish from OpenWrt APKs which use the `openwrt` prefix:
|
||||
|
||||
- OpenWrt: `sing-box_{version}_openwrt_{architecture}.apk`
|
||||
- Alpine: `sing-box_{version}_linux_{architecture}.apk`
|
||||
|
||||
**2**:
|
||||
|
||||
Legacy macOS binaries (with `-legacy-macos-10.13` suffix) now support
|
||||
macOS 10.13 High Sierra, built using Go 1.25 with patches
|
||||
from [SagerNet/go](https://github.com/SagerNet/go).
|
||||
|
||||
**3**:
|
||||
|
||||
See [OCM](/configuration/service/ocm).
|
||||
|
||||
#### 1.14.0-alpha.1
|
||||
|
||||
* Add `source_mac_address` and `source_hostname` rule items **1**
|
||||
* Add `include_mac_address` and `exclude_mac_address` TUN options **2**
|
||||
* Update NaiveProxy to 145.0.7632.159 **3**
|
||||
* Fixes and improvements
|
||||
|
||||
**1**:
|
||||
|
||||
New rule items for matching LAN devices by MAC address and hostname via neighbor resolution.
|
||||
Supported on Linux, macOS, or in graphical clients on Android and macOS.
|
||||
|
||||
See [Route Rule](/configuration/route/rule/#source_mac_address), [DNS Rule](/configuration/dns/rule/#source_mac_address) and [Neighbor Resolution](/configuration/shared/neighbor/).
|
||||
|
||||
**2**:
|
||||
|
||||
Limit or exclude devices from TUN routing by MAC address.
|
||||
Only supported on Linux with `auto_route` and `auto_redirect` enabled.
|
||||
|
||||
See [TUN](/configuration/inbound/tun/#include_mac_address).
|
||||
|
||||
**3**:
|
||||
|
||||
This is not an official update from NaiveProxy. Instead, it's a Chromium codebase update maintained by Project S.
|
||||
|
||||
#### 1.13.2
|
||||
|
||||
* Fixes and improvements
|
||||
|
||||
@@ -2,6 +2,11 @@
|
||||
icon: material/alert-decagram
|
||||
---
|
||||
|
||||
!!! quote "Changes in sing-box 1.14.0"
|
||||
|
||||
:material-plus: [source_mac_address](#source_mac_address)
|
||||
:material-plus: [source_hostname](#source_hostname)
|
||||
|
||||
!!! quote "Changes in sing-box 1.13.0"
|
||||
|
||||
:material-plus: [interface_address](#interface_address)
|
||||
@@ -149,6 +154,12 @@ icon: material/alert-decagram
|
||||
"default_interface_address": [
|
||||
"2000::/3"
|
||||
],
|
||||
"source_mac_address": [
|
||||
"00:11:22:33:44:55"
|
||||
],
|
||||
"source_hostname": [
|
||||
"my-device"
|
||||
],
|
||||
"wifi_ssid": [
|
||||
"My WIFI"
|
||||
],
|
||||
@@ -408,6 +419,26 @@ Matches network interface (same values as `network_type`) address.
|
||||
|
||||
Match default interface address.
|
||||
|
||||
#### source_mac_address
|
||||
|
||||
!!! question "Since sing-box 1.14.0"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
Only supported on Linux, macOS, or in graphical clients on Android and macOS. See [Neighbor Resolution](/configuration/shared/neighbor/) for setup.
|
||||
|
||||
Match source device MAC address.
|
||||
|
||||
#### source_hostname
|
||||
|
||||
!!! question "Since sing-box 1.14.0"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
Only supported on Linux, macOS, or in graphical clients on Android and macOS. See [Neighbor Resolution](/configuration/shared/neighbor/) for setup.
|
||||
|
||||
Match source device hostname from DHCP leases.
|
||||
|
||||
#### wifi_ssid
|
||||
|
||||
!!! quote ""
|
||||
|
||||
@@ -2,6 +2,11 @@
|
||||
icon: material/alert-decagram
|
||||
---
|
||||
|
||||
!!! quote "sing-box 1.14.0 中的更改"
|
||||
|
||||
:material-plus: [source_mac_address](#source_mac_address)
|
||||
:material-plus: [source_hostname](#source_hostname)
|
||||
|
||||
!!! quote "sing-box 1.13.0 中的更改"
|
||||
|
||||
:material-plus: [interface_address](#interface_address)
|
||||
@@ -149,6 +154,12 @@ icon: material/alert-decagram
|
||||
"default_interface_address": [
|
||||
"2000::/3"
|
||||
],
|
||||
"source_mac_address": [
|
||||
"00:11:22:33:44:55"
|
||||
],
|
||||
"source_hostname": [
|
||||
"my-device"
|
||||
],
|
||||
"wifi_ssid": [
|
||||
"My WIFI"
|
||||
],
|
||||
@@ -407,6 +418,26 @@ Available values: `wifi`, `cellular`, `ethernet` and `other`.
|
||||
|
||||
匹配默认接口地址。
|
||||
|
||||
#### source_mac_address
|
||||
|
||||
!!! question "自 sing-box 1.14.0 起"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
仅支持 Linux、macOS,或在 Android 和 macOS 图形客户端中支持。参阅 [邻居解析](/configuration/shared/neighbor/) 了解设置方法。
|
||||
|
||||
匹配源设备 MAC 地址。
|
||||
|
||||
#### source_hostname
|
||||
|
||||
!!! question "自 sing-box 1.14.0 起"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
仅支持 Linux、macOS,或在 Android 和 macOS 图形客户端中支持。参阅 [邻居解析](/configuration/shared/neighbor/) 了解设置方法。
|
||||
|
||||
匹配源设备从 DHCP 租约获取的主机名。
|
||||
|
||||
#### wifi_ssid
|
||||
|
||||
!!! quote ""
|
||||
|
||||
@@ -2,6 +2,15 @@
|
||||
icon: material/new-box
|
||||
---
|
||||
|
||||
!!! quote "Changes in sing-box 1.14.0"
|
||||
|
||||
:material-plus: [include_mac_address](#include_mac_address)
|
||||
:material-plus: [exclude_mac_address](#exclude_mac_address)
|
||||
|
||||
!!! quote "Changes in sing-box 1.13.3"
|
||||
|
||||
:material-alert: [strict_route](#strict_route)
|
||||
|
||||
!!! quote "Changes in sing-box 1.13.0"
|
||||
|
||||
:material-plus: [auto_redirect_reset_mark](#auto_redirect_reset_mark)
|
||||
@@ -125,6 +134,12 @@ icon: material/new-box
|
||||
"exclude_package": [
|
||||
"com.android.captiveportallogin"
|
||||
],
|
||||
"include_mac_address": [
|
||||
"00:11:22:33:44:55"
|
||||
],
|
||||
"exclude_mac_address": [
|
||||
"66:77:88:99:aa:bb"
|
||||
],
|
||||
"platform": {
|
||||
"http_proxy": {
|
||||
"enabled": false,
|
||||
@@ -348,6 +363,9 @@ Enforce strict routing rules when `auto_route` is enabled:
|
||||
|
||||
* Let unsupported network unreachable
|
||||
* For legacy reasons, when neither `strict_route` nor `auto_redirect` are enabled, all ICMP traffic will not go through TUN.
|
||||
* When `auto_redirect` is enabled, `strict_route` also affects `SO_BINDTODEVICE` traffic:
|
||||
* Enabled: `SO_BINDTODEVICE` traffic is redirected through sing-box.
|
||||
* Disabled: `SO_BINDTODEVICE` traffic bypasses sing-box.
|
||||
|
||||
*In Windows*:
|
||||
|
||||
@@ -548,6 +566,30 @@ Limit android packages in route.
|
||||
|
||||
Exclude android packages in route.
|
||||
|
||||
#### include_mac_address
|
||||
|
||||
!!! question "Since sing-box 1.14.0"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
Only supported on Linux with `auto_route` and `auto_redirect` enabled.
|
||||
|
||||
Limit MAC addresses in route. Not limited by default.
|
||||
|
||||
Conflict with `exclude_mac_address`.
|
||||
|
||||
#### exclude_mac_address
|
||||
|
||||
!!! question "Since sing-box 1.14.0"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
Only supported on Linux with `auto_route` and `auto_redirect` enabled.
|
||||
|
||||
Exclude MAC addresses in route.
|
||||
|
||||
Conflict with `include_mac_address`.
|
||||
|
||||
#### platform
|
||||
|
||||
Platform-specific settings, provided by client applications.
|
||||
|
||||
@@ -2,6 +2,15 @@
|
||||
icon: material/new-box
|
||||
---
|
||||
|
||||
!!! quote "sing-box 1.14.0 中的更改"
|
||||
|
||||
:material-plus: [include_mac_address](#include_mac_address)
|
||||
:material-plus: [exclude_mac_address](#exclude_mac_address)
|
||||
|
||||
!!! quote "sing-box 1.13.3 中的更改"
|
||||
|
||||
:material-alert: [strict_route](#strict_route)
|
||||
|
||||
!!! quote "sing-box 1.13.0 中的更改"
|
||||
|
||||
:material-plus: [auto_redirect_reset_mark](#auto_redirect_reset_mark)
|
||||
@@ -126,6 +135,12 @@ icon: material/new-box
|
||||
"exclude_package": [
|
||||
"com.android.captiveportallogin"
|
||||
],
|
||||
"include_mac_address": [
|
||||
"00:11:22:33:44:55"
|
||||
],
|
||||
"exclude_mac_address": [
|
||||
"66:77:88:99:aa:bb"
|
||||
],
|
||||
"platform": {
|
||||
"http_proxy": {
|
||||
"enabled": false,
|
||||
@@ -347,6 +362,9 @@ tun 接口的 IPv6 前缀。
|
||||
|
||||
* 使不支持的网络不可达。
|
||||
* 出于历史遗留原因,当未启用 `strict_route` 或 `auto_redirect` 时,所有 ICMP 流量将不会通过 TUN。
|
||||
* 当启用 `auto_redirect` 时,`strict_route` 也影响 `SO_BINDTODEVICE` 流量:
|
||||
* 启用:`SO_BINDTODEVICE` 流量被重定向通过 sing-box。
|
||||
* 禁用:`SO_BINDTODEVICE` 流量绕过 sing-box。
|
||||
|
||||
*在 Windows 中*:
|
||||
|
||||
@@ -536,6 +554,30 @@ TCP/IP 栈。
|
||||
|
||||
排除路由的 Android 应用包名。
|
||||
|
||||
#### include_mac_address
|
||||
|
||||
!!! question "自 sing-box 1.14.0 起"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
仅支持 Linux,且需要 `auto_route` 和 `auto_redirect` 已启用。
|
||||
|
||||
限制被路由的 MAC 地址。默认不限制。
|
||||
|
||||
与 `exclude_mac_address` 冲突。
|
||||
|
||||
#### exclude_mac_address
|
||||
|
||||
!!! question "自 sing-box 1.14.0 起"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
仅支持 Linux,且需要 `auto_route` 和 `auto_redirect` 已启用。
|
||||
|
||||
排除路由的 MAC 地址。
|
||||
|
||||
与 `include_mac_address` 冲突。
|
||||
|
||||
#### platform
|
||||
|
||||
平台特定的设置,由客户端应用提供。
|
||||
|
||||
@@ -4,6 +4,11 @@ icon: material/alert-decagram
|
||||
|
||||
# Route
|
||||
|
||||
!!! quote "Changes in sing-box 1.14.0"
|
||||
|
||||
:material-plus: [find_neighbor](#find_neighbor)
|
||||
:material-plus: [dhcp_lease_files](#dhcp_lease_files)
|
||||
|
||||
!!! quote "Changes in sing-box 1.12.0"
|
||||
|
||||
:material-plus: [default_domain_resolver](#default_domain_resolver)
|
||||
@@ -35,6 +40,9 @@ icon: material/alert-decagram
|
||||
"override_android_vpn": false,
|
||||
"default_interface": "",
|
||||
"default_mark": 0,
|
||||
"find_process": false,
|
||||
"find_neighbor": false,
|
||||
"dhcp_lease_files": [],
|
||||
"default_domain_resolver": "", // or {}
|
||||
"default_network_strategy": "",
|
||||
"default_network_type": [],
|
||||
@@ -107,6 +115,38 @@ Set routing mark by default.
|
||||
|
||||
Takes no effect if `outbound.routing_mark` is set.
|
||||
|
||||
#### find_process
|
||||
|
||||
!!! quote ""
|
||||
|
||||
Only supported on Linux, Windows, and macOS.
|
||||
|
||||
Enable process search for logging when no `process_name`, `process_path`, `package_name`, `user` or `user_id` rules exist.
|
||||
|
||||
#### find_neighbor
|
||||
|
||||
!!! question "Since sing-box 1.14.0"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
Only supported on Linux and macOS.
|
||||
|
||||
Enable neighbor resolution for logging when no `source_mac_address` or `source_hostname` rules exist.
|
||||
|
||||
See [Neighbor Resolution](/configuration/shared/neighbor/) for setup.
|
||||
|
||||
#### dhcp_lease_files
|
||||
|
||||
!!! question "Since sing-box 1.14.0"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
Only supported on Linux and macOS.
|
||||
|
||||
Custom DHCP lease file paths for hostname and MAC address resolution.
|
||||
|
||||
Automatically detected from common DHCP servers (dnsmasq, odhcpd, ISC dhcpd, Kea) if empty.
|
||||
|
||||
#### default_domain_resolver
|
||||
|
||||
!!! question "Since sing-box 1.12.0"
|
||||
|
||||
@@ -4,6 +4,11 @@ icon: material/alert-decagram
|
||||
|
||||
# 路由
|
||||
|
||||
!!! quote "sing-box 1.14.0 中的更改"
|
||||
|
||||
:material-plus: [find_neighbor](#find_neighbor)
|
||||
:material-plus: [dhcp_lease_files](#dhcp_lease_files)
|
||||
|
||||
!!! quote "sing-box 1.12.0 中的更改"
|
||||
|
||||
:material-plus: [default_domain_resolver](#default_domain_resolver)
|
||||
@@ -37,6 +42,9 @@ icon: material/alert-decagram
|
||||
"override_android_vpn": false,
|
||||
"default_interface": "",
|
||||
"default_mark": 0,
|
||||
"find_process": false,
|
||||
"find_neighbor": false,
|
||||
"dhcp_lease_files": [],
|
||||
"default_network_strategy": "",
|
||||
"default_fallback_delay": ""
|
||||
}
|
||||
@@ -106,6 +114,38 @@ icon: material/alert-decagram
|
||||
|
||||
如果设置了 `outbound.routing_mark` 设置,则不生效。
|
||||
|
||||
#### find_process
|
||||
|
||||
!!! quote ""
|
||||
|
||||
仅支持 Linux、Windows 和 macOS。
|
||||
|
||||
在没有 `process_name`、`process_path`、`package_name`、`user` 或 `user_id` 规则时启用进程搜索以输出日志。
|
||||
|
||||
#### find_neighbor
|
||||
|
||||
!!! question "自 sing-box 1.14.0 起"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
仅支持 Linux 和 macOS。
|
||||
|
||||
在没有 `source_mac_address` 或 `source_hostname` 规则时启用邻居解析以输出日志。
|
||||
|
||||
参阅 [邻居解析](/configuration/shared/neighbor/) 了解设置方法。
|
||||
|
||||
#### dhcp_lease_files
|
||||
|
||||
!!! question "自 sing-box 1.14.0 起"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
仅支持 Linux 和 macOS。
|
||||
|
||||
用于主机名和 MAC 地址解析的自定义 DHCP 租约文件路径。
|
||||
|
||||
为空时自动从常见 DHCP 服务器(dnsmasq、odhcpd、ISC dhcpd、Kea)检测。
|
||||
|
||||
#### default_domain_resolver
|
||||
|
||||
!!! question "自 sing-box 1.12.0 起"
|
||||
|
||||
@@ -2,6 +2,11 @@
|
||||
icon: material/new-box
|
||||
---
|
||||
|
||||
!!! quote "Changes in sing-box 1.14.0"
|
||||
|
||||
:material-plus: [source_mac_address](#source_mac_address)
|
||||
:material-plus: [source_hostname](#source_hostname)
|
||||
|
||||
!!! quote "Changes in sing-box 1.13.0"
|
||||
|
||||
:material-plus: [interface_address](#interface_address)
|
||||
@@ -159,6 +164,12 @@ icon: material/new-box
|
||||
"tailscale",
|
||||
"wireguard"
|
||||
],
|
||||
"source_mac_address": [
|
||||
"00:11:22:33:44:55"
|
||||
],
|
||||
"source_hostname": [
|
||||
"my-device"
|
||||
],
|
||||
"rule_set": [
|
||||
"geoip-cn",
|
||||
"geosite-cn"
|
||||
@@ -449,6 +460,26 @@ Match specified outbounds' preferred routes.
|
||||
| `tailscale` | Match MagicDNS domains and peers' allowed IPs |
|
||||
| `wireguard` | Match peers's allowed IPs |
|
||||
|
||||
#### source_mac_address
|
||||
|
||||
!!! question "Since sing-box 1.14.0"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
Only supported on Linux, macOS, or in graphical clients on Android and macOS. See [Neighbor Resolution](/configuration/shared/neighbor/) for setup.
|
||||
|
||||
Match source device MAC address.
|
||||
|
||||
#### source_hostname
|
||||
|
||||
!!! question "Since sing-box 1.14.0"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
Only supported on Linux, macOS, or in graphical clients on Android and macOS. See [Neighbor Resolution](/configuration/shared/neighbor/) for setup.
|
||||
|
||||
Match source device hostname from DHCP leases.
|
||||
|
||||
#### rule_set
|
||||
|
||||
!!! question "Since sing-box 1.8.0"
|
||||
|
||||
@@ -2,6 +2,11 @@
|
||||
icon: material/new-box
|
||||
---
|
||||
|
||||
!!! quote "sing-box 1.14.0 中的更改"
|
||||
|
||||
:material-plus: [source_mac_address](#source_mac_address)
|
||||
:material-plus: [source_hostname](#source_hostname)
|
||||
|
||||
!!! quote "sing-box 1.13.0 中的更改"
|
||||
|
||||
:material-plus: [interface_address](#interface_address)
|
||||
@@ -156,6 +161,12 @@ icon: material/new-box
|
||||
"tailscale",
|
||||
"wireguard"
|
||||
],
|
||||
"source_mac_address": [
|
||||
"00:11:22:33:44:55"
|
||||
],
|
||||
"source_hostname": [
|
||||
"my-device"
|
||||
],
|
||||
"rule_set": [
|
||||
"geoip-cn",
|
||||
"geosite-cn"
|
||||
@@ -446,6 +457,26 @@ icon: material/new-box
|
||||
| `tailscale` | 匹配 MagicDNS 域名和对端的 allowed IPs |
|
||||
| `wireguard` | 匹配对端的 allowed IPs |
|
||||
|
||||
#### source_mac_address
|
||||
|
||||
!!! question "自 sing-box 1.14.0 起"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
仅支持 Linux、macOS,或在 Android 和 macOS 图形客户端中支持。参阅 [邻居解析](/configuration/shared/neighbor/) 了解设置方法。
|
||||
|
||||
匹配源设备 MAC 地址。
|
||||
|
||||
#### source_hostname
|
||||
|
||||
!!! question "自 sing-box 1.14.0 起"
|
||||
|
||||
!!! quote ""
|
||||
|
||||
仅支持 Linux、macOS,或在 Android 和 macOS 图形客户端中支持。参阅 [邻居解析](/configuration/shared/neighbor/) 了解设置方法。
|
||||
|
||||
匹配源设备从 DHCP 租约获取的主机名。
|
||||
|
||||
#### rule_set
|
||||
|
||||
!!! question "自 sing-box 1.8.0 起"
|
||||
|
||||
@@ -10,6 +10,11 @@ CCM (Claude Code Multiplexer) service is a multiplexing service that allows you
|
||||
|
||||
It handles OAuth authentication with Claude's API on your local machine while allowing remote Claude Code to authenticate using Auth Tokens via the `ANTHROPIC_AUTH_TOKEN` environment variable.
|
||||
|
||||
!!! quote "Changes in sing-box 1.14.0"
|
||||
|
||||
:material-plus: [credentials](#credentials)
|
||||
:material-alert: [users](#users)
|
||||
|
||||
### Structure
|
||||
|
||||
```json
|
||||
@@ -19,6 +24,7 @@ It handles OAuth authentication with Claude's API on your local machine while al
|
||||
... // Listen Fields
|
||||
|
||||
"credential_path": "",
|
||||
"credentials": [],
|
||||
"usages_path": "",
|
||||
"users": [],
|
||||
"headers": {},
|
||||
@@ -45,6 +51,77 @@ On macOS, credentials are read from the system keychain first, then fall back to
|
||||
|
||||
Refreshed tokens are automatically written back to the same location.
|
||||
|
||||
When `credential_path` points to a file, the service can start before the file exists. The credential becomes available automatically after the file is created or updated, and becomes unavailable immediately if the file is later removed or becomes invalid.
|
||||
|
||||
On macOS without an explicit `credential_path`, keychain changes are not watched. Automatic reload only applies to the credential file path.
|
||||
|
||||
Conflict with `credentials`.
|
||||
|
||||
#### credentials
|
||||
|
||||
!!! question "Since sing-box 1.14.0"
|
||||
|
||||
List of credential configurations for multi-credential mode.
|
||||
|
||||
When set, top-level `credential_path`, `usages_path`, and `detour` are forbidden. Each user must specify a `credential` tag.
|
||||
|
||||
Each credential has a `type` field (`default`, `balancer`, or `fallback`) and a required `tag` field.
|
||||
|
||||
##### Default Credential
|
||||
|
||||
```json
|
||||
{
|
||||
"tag": "a",
|
||||
"credential_path": "/path/to/.credentials.json",
|
||||
"usages_path": "/path/to/usages.json",
|
||||
"detour": "",
|
||||
"reserve_5h": 20,
|
||||
"reserve_weekly": 20
|
||||
}
|
||||
```
|
||||
|
||||
A single OAuth credential file. The `type` field can be omitted (defaults to `default`). The service can start before the file exists, and reloads file updates automatically.
|
||||
|
||||
- `credential_path`: Path to the credentials file. Same defaults as top-level `credential_path`.
|
||||
- `usages_path`: Optional usage tracking file for this credential.
|
||||
- `detour`: Outbound tag for connecting to the Claude API with this credential.
|
||||
- `reserve_5h`: Reserve threshold (1-99) for 5-hour window. Credential pauses at (100-N)% utilization.
|
||||
- `reserve_weekly`: Reserve threshold (1-99) for weekly window. Credential pauses at (100-N)% utilization.
|
||||
|
||||
##### Balancer Credential
|
||||
|
||||
```json
|
||||
{
|
||||
"tag": "pool",
|
||||
"type": "balancer",
|
||||
"strategy": "",
|
||||
"credentials": ["a", "b"],
|
||||
"poll_interval": "60s"
|
||||
}
|
||||
```
|
||||
|
||||
Assigns sessions to default credentials based on the selected strategy. Sessions are sticky until the assigned credential hits a rate limit.
|
||||
|
||||
- `strategy`: Selection strategy. One of `least_used` `round_robin` `random`. `least_used` will be used by default.
|
||||
- `credentials`: ==Required== List of default credential tags.
|
||||
- `poll_interval`: How often to poll upstream usage API. Default `60s`.
|
||||
|
||||
##### Fallback Credential
|
||||
|
||||
```json
|
||||
{
|
||||
"tag": "backup",
|
||||
"type": "fallback",
|
||||
"credentials": ["a", "b"],
|
||||
"poll_interval": "30s"
|
||||
}
|
||||
```
|
||||
|
||||
Uses credentials in order. Falls through to the next when the current one is exhausted.
|
||||
|
||||
- `credentials`: ==Required== Ordered list of default credential tags.
|
||||
- `poll_interval`: How often to poll upstream usage API. Default `60s`.
|
||||
|
||||
#### usages_path
|
||||
|
||||
Path to the file for storing aggregated API usage statistics.
|
||||
@@ -60,6 +137,8 @@ Statistics are organized by model, context window (200k standard vs 1M premium),
|
||||
|
||||
The statistics file is automatically saved every minute and upon service shutdown.
|
||||
|
||||
Conflict with `credentials`. In multi-credential mode, use `usages_path` on individual default credentials.
|
||||
|
||||
#### users
|
||||
|
||||
List of authorized users for token authentication.
|
||||
@@ -71,7 +150,8 @@ Object format:
|
||||
```json
|
||||
{
|
||||
"name": "",
|
||||
"token": ""
|
||||
"token": "",
|
||||
"credential": ""
|
||||
}
|
||||
```
|
||||
|
||||
@@ -79,6 +159,7 @@ Object fields:
|
||||
|
||||
- `name`: Username identifier for tracking purposes.
|
||||
- `token`: Bearer token for authentication. Claude Code authenticates by setting the `ANTHROPIC_AUTH_TOKEN` environment variable to their token value.
|
||||
- `credential`: Credential tag to use for this user. ==Required== when `credentials` is set.
|
||||
|
||||
#### headers
|
||||
|
||||
@@ -90,6 +171,8 @@ These headers will override any existing headers with the same name.
|
||||
|
||||
Outbound tag for connecting to the Claude API.
|
||||
|
||||
Conflict with `credentials`. In multi-credential mode, use `detour` on individual default credentials.
|
||||
|
||||
#### tls
|
||||
|
||||
TLS configuration, see [TLS](/configuration/shared/tls/#inbound).
|
||||
@@ -129,3 +212,52 @@ export ANTHROPIC_AUTH_TOKEN="ak-ccm-hello-world"
|
||||
|
||||
claude
|
||||
```
|
||||
|
||||
### Example with Multiple Credentials
|
||||
|
||||
#### Server
|
||||
|
||||
```json
|
||||
{
|
||||
"services": [
|
||||
{
|
||||
"type": "ccm",
|
||||
"listen": "0.0.0.0",
|
||||
"listen_port": 8080,
|
||||
"credentials": [
|
||||
{
|
||||
"tag": "a",
|
||||
"credential_path": "/home/user/.claude-a/.credentials.json",
|
||||
"usages_path": "/data/usages-a.json",
|
||||
"reserve_5h": 20,
|
||||
"reserve_weekly": 20
|
||||
},
|
||||
{
|
||||
"tag": "b",
|
||||
"credential_path": "/home/user/.claude-b/.credentials.json",
|
||||
"reserve_5h": 10,
|
||||
"reserve_weekly": 10
|
||||
},
|
||||
{
|
||||
"tag": "pool",
|
||||
"type": "balancer",
|
||||
"poll_interval": "60s",
|
||||
"credentials": ["a", "b"]
|
||||
}
|
||||
],
|
||||
"users": [
|
||||
{
|
||||
"name": "alice",
|
||||
"token": "ak-ccm-hello-world",
|
||||
"credential": "pool"
|
||||
},
|
||||
{
|
||||
"name": "bob",
|
||||
"token": "ak-ccm-hello-bob",
|
||||
"credential": "a"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
@@ -10,6 +10,11 @@ CCM(Claude Code 多路复用器)服务是一个多路复用服务,允许
|
||||
|
||||
它在本地机器上处理与 Claude API 的 OAuth 身份验证,同时允许远程 Claude Code 通过 `ANTHROPIC_AUTH_TOKEN` 环境变量使用认证令牌进行身份验证。
|
||||
|
||||
!!! quote "sing-box 1.14.0 中的更改"
|
||||
|
||||
:material-plus: [credentials](#credentials)
|
||||
:material-alert: [users](#users)
|
||||
|
||||
### 结构
|
||||
|
||||
```json
|
||||
@@ -19,6 +24,7 @@ CCM(Claude Code 多路复用器)服务是一个多路复用服务,允许
|
||||
... // 监听字段
|
||||
|
||||
"credential_path": "",
|
||||
"credentials": [],
|
||||
"usages_path": "",
|
||||
"users": [],
|
||||
"headers": {},
|
||||
@@ -45,6 +51,77 @@ Claude Code OAuth 凭据文件的路径。
|
||||
|
||||
刷新的令牌会自动写回相同位置。
|
||||
|
||||
当 `credential_path` 指向文件时,即使文件尚不存在,服务也可以启动。文件被创建或更新后,凭据会自动变为可用;如果文件之后被删除或变为无效,该凭据会立即变为不可用。
|
||||
|
||||
在 macOS 上如果未显式设置 `credential_path`,不会监听钥匙串变化。自动重载只作用于凭据文件路径。
|
||||
|
||||
与 `credentials` 冲突。
|
||||
|
||||
#### credentials
|
||||
|
||||
!!! question "自 sing-box 1.14.0 起"
|
||||
|
||||
多凭据模式的凭据配置列表。
|
||||
|
||||
设置后,顶层 `credential_path`、`usages_path` 和 `detour` 被禁止。每个用户必须指定 `credential` 标签。
|
||||
|
||||
每个凭据有一个 `type` 字段(`default`、`balancer` 或 `fallback`)和一个必填的 `tag` 字段。
|
||||
|
||||
##### 默认凭据
|
||||
|
||||
```json
|
||||
{
|
||||
"tag": "a",
|
||||
"credential_path": "/path/to/.credentials.json",
|
||||
"usages_path": "/path/to/usages.json",
|
||||
"detour": "",
|
||||
"reserve_5h": 20,
|
||||
"reserve_weekly": 20
|
||||
}
|
||||
```
|
||||
|
||||
单个 OAuth 凭据文件。`type` 字段可以省略(默认为 `default`)。即使文件尚不存在,服务也可以启动,并会自动重载文件更新。
|
||||
|
||||
- `credential_path`:凭据文件的路径。默认值与顶层 `credential_path` 相同。
|
||||
- `usages_path`:此凭据的可选使用跟踪文件。
|
||||
- `detour`:此凭据用于连接 Claude API 的出站标签。
|
||||
- `reserve_5h`:5 小时窗口的保留阈值(1-99)。凭据在利用率达到 (100-N)% 时暂停。
|
||||
- `reserve_weekly`:每周窗口的保留阈值(1-99)。凭据在利用率达到 (100-N)% 时暂停。
|
||||
|
||||
##### 均衡凭据
|
||||
|
||||
```json
|
||||
{
|
||||
"tag": "pool",
|
||||
"type": "balancer",
|
||||
"strategy": "",
|
||||
"credentials": ["a", "b"],
|
||||
"poll_interval": "60s"
|
||||
}
|
||||
```
|
||||
|
||||
根据选择的策略将会话分配给默认凭据。会话保持粘性,直到分配的凭据触发速率限制。
|
||||
|
||||
- `strategy`:选择策略。可选值:`least_used` `round_robin` `random`。默认使用 `least_used`。
|
||||
- `credentials`:==必填== 默认凭据标签列表。
|
||||
- `poll_interval`:轮询上游使用 API 的间隔。默认 `60s`。
|
||||
|
||||
##### 回退凭据
|
||||
|
||||
```json
|
||||
{
|
||||
"tag": "backup",
|
||||
"type": "fallback",
|
||||
"credentials": ["a", "b"],
|
||||
"poll_interval": "30s"
|
||||
}
|
||||
```
|
||||
|
||||
按顺序使用凭据。当前凭据耗尽后切换到下一个。
|
||||
|
||||
- `credentials`:==必填== 有序的默认凭据标签列表。
|
||||
- `poll_interval`:轮询上游使用 API 的间隔。默认 `60s`。
|
||||
|
||||
#### usages_path
|
||||
|
||||
用于存储聚合 API 使用统计信息的文件路径。
|
||||
@@ -60,6 +137,8 @@ Claude Code OAuth 凭据文件的路径。
|
||||
|
||||
统计文件每分钟自动保存一次,并在服务关闭时保存。
|
||||
|
||||
与 `credentials` 冲突。在多凭据模式下,在各个默认凭据上使用 `usages_path`。
|
||||
|
||||
#### users
|
||||
|
||||
用于令牌身份验证的授权用户列表。
|
||||
@@ -71,7 +150,8 @@ Claude Code OAuth 凭据文件的路径。
|
||||
```json
|
||||
{
|
||||
"name": "",
|
||||
"token": ""
|
||||
"token": "",
|
||||
"credential": ""
|
||||
}
|
||||
```
|
||||
|
||||
@@ -79,6 +159,7 @@ Claude Code OAuth 凭据文件的路径。
|
||||
|
||||
- `name`:用于跟踪的用户名标识符。
|
||||
- `token`:用于身份验证的 Bearer 令牌。Claude Code 通过设置 `ANTHROPIC_AUTH_TOKEN` 环境变量为其令牌值进行身份验证。
|
||||
- `credential`:此用户使用的凭据标签。设置 `credentials` 时==必填==。
|
||||
|
||||
#### headers
|
||||
|
||||
@@ -90,6 +171,8 @@ Claude Code OAuth 凭据文件的路径。
|
||||
|
||||
用于连接 Claude API 的出站标签。
|
||||
|
||||
与 `credentials` 冲突。在多凭据模式下,在各个默认凭据上使用 `detour`。
|
||||
|
||||
#### tls
|
||||
|
||||
TLS 配置,参阅 [TLS](/zh/configuration/shared/tls/#inbound)。
|
||||
@@ -129,3 +212,52 @@ export ANTHROPIC_AUTH_TOKEN="ak-ccm-hello-world"
|
||||
|
||||
claude
|
||||
```
|
||||
|
||||
### 多凭据示例
|
||||
|
||||
#### 服务端
|
||||
|
||||
```json
|
||||
{
|
||||
"services": [
|
||||
{
|
||||
"type": "ccm",
|
||||
"listen": "0.0.0.0",
|
||||
"listen_port": 8080,
|
||||
"credentials": [
|
||||
{
|
||||
"tag": "a",
|
||||
"credential_path": "/home/user/.claude-a/.credentials.json",
|
||||
"usages_path": "/data/usages-a.json",
|
||||
"reserve_5h": 20,
|
||||
"reserve_weekly": 20
|
||||
},
|
||||
{
|
||||
"tag": "b",
|
||||
"credential_path": "/home/user/.claude-b/.credentials.json",
|
||||
"reserve_5h": 10,
|
||||
"reserve_weekly": 10
|
||||
},
|
||||
{
|
||||
"tag": "pool",
|
||||
"type": "balancer",
|
||||
"poll_interval": "60s",
|
||||
"credentials": ["a", "b"]
|
||||
}
|
||||
],
|
||||
"users": [
|
||||
{
|
||||
"name": "alice",
|
||||
"token": "ak-ccm-hello-world",
|
||||
"credential": "pool"
|
||||
},
|
||||
{
|
||||
"name": "bob",
|
||||
"token": "ak-ccm-hello-bob",
|
||||
"credential": "a"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
@@ -10,6 +10,11 @@ OCM (OpenAI Codex Multiplexer) service is a multiplexing service that allows you
|
||||
|
||||
It handles OAuth authentication with OpenAI's API on your local machine while allowing remote clients to authenticate using custom tokens.
|
||||
|
||||
!!! quote "Changes in sing-box 1.14.0"
|
||||
|
||||
:material-plus: [credentials](#credentials)
|
||||
:material-alert: [users](#users)
|
||||
|
||||
### Structure
|
||||
|
||||
```json
|
||||
@@ -19,6 +24,7 @@ It handles OAuth authentication with OpenAI's API on your local machine while al
|
||||
... // Listen Fields
|
||||
|
||||
"credential_path": "",
|
||||
"credentials": [],
|
||||
"usages_path": "",
|
||||
"users": [],
|
||||
"headers": {},
|
||||
@@ -43,6 +49,75 @@ If not specified, defaults to:
|
||||
|
||||
Refreshed tokens are automatically written back to the same location.
|
||||
|
||||
When `credential_path` points to a file, the service can start before the file exists. The credential becomes available automatically after the file is created or updated, and becomes unavailable immediately if the file is later removed or becomes invalid.
|
||||
|
||||
Conflict with `credentials`.
|
||||
|
||||
#### credentials
|
||||
|
||||
!!! question "Since sing-box 1.14.0"
|
||||
|
||||
List of credential configurations for multi-credential mode.
|
||||
|
||||
When set, top-level `credential_path`, `usages_path`, and `detour` are forbidden. Each user must specify a `credential` tag.
|
||||
|
||||
Each credential has a `type` field (`default`, `balancer`, or `fallback`) and a required `tag` field.
|
||||
|
||||
##### Default Credential
|
||||
|
||||
```json
|
||||
{
|
||||
"tag": "a",
|
||||
"credential_path": "/path/to/auth.json",
|
||||
"usages_path": "/path/to/usages.json",
|
||||
"detour": "",
|
||||
"reserve_5h": 20,
|
||||
"reserve_weekly": 20
|
||||
}
|
||||
```
|
||||
|
||||
A single OAuth credential file. The `type` field can be omitted (defaults to `default`). The service can start before the file exists, and reloads file updates automatically.
|
||||
|
||||
- `credential_path`: Path to the credentials file. Same defaults as top-level `credential_path`.
|
||||
- `usages_path`: Optional usage tracking file for this credential.
|
||||
- `detour`: Outbound tag for connecting to the OpenAI API with this credential.
|
||||
- `reserve_5h`: Reserve threshold (1-99) for primary rate limit window. Credential pauses at (100-N)% utilization.
|
||||
- `reserve_weekly`: Reserve threshold (1-99) for secondary (weekly) rate limit window. Credential pauses at (100-N)% utilization.
|
||||
|
||||
##### Balancer Credential
|
||||
|
||||
```json
|
||||
{
|
||||
"tag": "pool",
|
||||
"type": "balancer",
|
||||
"strategy": "",
|
||||
"credentials": ["a", "b"],
|
||||
"poll_interval": "60s"
|
||||
}
|
||||
```
|
||||
|
||||
Assigns sessions to default credentials based on the selected strategy. Sessions are sticky until the assigned credential hits a rate limit.
|
||||
|
||||
- `strategy`: Selection strategy. One of `least_used` `round_robin` `random`. `least_used` will be used by default.
|
||||
- `credentials`: ==Required== List of default credential tags.
|
||||
- `poll_interval`: How often to poll upstream usage API. Default `60s`.
|
||||
|
||||
##### Fallback Credential
|
||||
|
||||
```json
|
||||
{
|
||||
"tag": "backup",
|
||||
"type": "fallback",
|
||||
"credentials": ["a", "b"],
|
||||
"poll_interval": "30s"
|
||||
}
|
||||
```
|
||||
|
||||
Uses credentials in order. Falls through to the next when the current one is exhausted.
|
||||
|
||||
- `credentials`: ==Required== Ordered list of default credential tags.
|
||||
- `poll_interval`: How often to poll upstream usage API. Default `60s`.
|
||||
|
||||
#### usages_path
|
||||
|
||||
Path to the file for storing aggregated API usage statistics.
|
||||
@@ -58,6 +133,8 @@ Statistics are organized by model and optionally by user when authentication is
|
||||
|
||||
The statistics file is automatically saved every minute and upon service shutdown.
|
||||
|
||||
Conflict with `credentials`. In multi-credential mode, use `usages_path` on individual default credentials.
|
||||
|
||||
#### users
|
||||
|
||||
List of authorized users for token authentication.
|
||||
@@ -69,7 +146,8 @@ Object format:
|
||||
```json
|
||||
{
|
||||
"name": "",
|
||||
"token": ""
|
||||
"token": "",
|
||||
"credential": ""
|
||||
}
|
||||
```
|
||||
|
||||
@@ -77,6 +155,7 @@ Object fields:
|
||||
|
||||
- `name`: Username identifier for tracking purposes.
|
||||
- `token`: Bearer token for authentication. Clients authenticate by setting the `Authorization: Bearer <token>` header.
|
||||
- `credential`: Credential tag to use for this user. ==Required== when `credentials` is set.
|
||||
|
||||
#### headers
|
||||
|
||||
@@ -88,6 +167,8 @@ These headers will override any existing headers with the same name.
|
||||
|
||||
Outbound tag for connecting to the OpenAI API.
|
||||
|
||||
Conflict with `credentials`. In multi-credential mode, use `detour` on individual default credentials.
|
||||
|
||||
#### tls
|
||||
|
||||
TLS configuration, see [TLS](/configuration/shared/tls/#inbound).
|
||||
@@ -183,3 +264,52 @@ Then run:
|
||||
```bash
|
||||
codex --profile ocm
|
||||
```
|
||||
|
||||
### Example with Multiple Credentials
|
||||
|
||||
#### Server
|
||||
|
||||
```json
|
||||
{
|
||||
"services": [
|
||||
{
|
||||
"type": "ocm",
|
||||
"listen": "0.0.0.0",
|
||||
"listen_port": 8080,
|
||||
"credentials": [
|
||||
{
|
||||
"tag": "a",
|
||||
"credential_path": "/home/user/.codex-a/auth.json",
|
||||
"usages_path": "/data/usages-a.json",
|
||||
"reserve_5h": 20,
|
||||
"reserve_weekly": 20
|
||||
},
|
||||
{
|
||||
"tag": "b",
|
||||
"credential_path": "/home/user/.codex-b/auth.json",
|
||||
"reserve_5h": 10,
|
||||
"reserve_weekly": 10
|
||||
},
|
||||
{
|
||||
"tag": "pool",
|
||||
"type": "balancer",
|
||||
"poll_interval": "60s",
|
||||
"credentials": ["a", "b"]
|
||||
}
|
||||
],
|
||||
"users": [
|
||||
{
|
||||
"name": "alice",
|
||||
"token": "sk-ocm-hello-world",
|
||||
"credential": "pool"
|
||||
},
|
||||
{
|
||||
"name": "bob",
|
||||
"token": "sk-ocm-hello-bob",
|
||||
"credential": "a"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
@@ -10,6 +10,11 @@ OCM(OpenAI Codex 多路复用器)服务是一个多路复用服务,允许
|
||||
|
||||
它在本地机器上处理与 OpenAI API 的 OAuth 身份验证,同时允许远程客户端使用自定义令牌进行身份验证。
|
||||
|
||||
!!! quote "sing-box 1.14.0 中的更改"
|
||||
|
||||
:material-plus: [credentials](#credentials)
|
||||
:material-alert: [users](#users)
|
||||
|
||||
### 结构
|
||||
|
||||
```json
|
||||
@@ -19,6 +24,7 @@ OCM(OpenAI Codex 多路复用器)服务是一个多路复用服务,允许
|
||||
... // 监听字段
|
||||
|
||||
"credential_path": "",
|
||||
"credentials": [],
|
||||
"usages_path": "",
|
||||
"users": [],
|
||||
"headers": {},
|
||||
@@ -43,6 +49,75 @@ OpenAI OAuth 凭据文件的路径。
|
||||
|
||||
刷新的令牌会自动写回相同位置。
|
||||
|
||||
当 `credential_path` 指向文件时,即使文件尚不存在,服务也可以启动。文件被创建或更新后,凭据会自动变为可用;如果文件之后被删除或变为无效,该凭据会立即变为不可用。
|
||||
|
||||
与 `credentials` 冲突。
|
||||
|
||||
#### credentials
|
||||
|
||||
!!! question "自 sing-box 1.14.0 起"
|
||||
|
||||
多凭据模式的凭据配置列表。
|
||||
|
||||
设置后,顶层 `credential_path`、`usages_path` 和 `detour` 被禁止。每个用户必须指定 `credential` 标签。
|
||||
|
||||
每个凭据有一个 `type` 字段(`default`、`balancer` 或 `fallback`)和一个必填的 `tag` 字段。
|
||||
|
||||
##### 默认凭据
|
||||
|
||||
```json
|
||||
{
|
||||
"tag": "a",
|
||||
"credential_path": "/path/to/auth.json",
|
||||
"usages_path": "/path/to/usages.json",
|
||||
"detour": "",
|
||||
"reserve_5h": 20,
|
||||
"reserve_weekly": 20
|
||||
}
|
||||
```
|
||||
|
||||
单个 OAuth 凭据文件。`type` 字段可以省略(默认为 `default`)。即使文件尚不存在,服务也可以启动,并会自动重载文件更新。
|
||||
|
||||
- `credential_path`:凭据文件的路径。默认值与顶层 `credential_path` 相同。
|
||||
- `usages_path`:此凭据的可选使用跟踪文件。
|
||||
- `detour`:此凭据用于连接 OpenAI API 的出站标签。
|
||||
- `reserve_5h`:主要速率限制窗口的保留阈值(1-99)。凭据在利用率达到 (100-N)% 时暂停。
|
||||
- `reserve_weekly`:次要(每周)速率限制窗口的保留阈值(1-99)。凭据在利用率达到 (100-N)% 时暂停。
|
||||
|
||||
##### 均衡凭据
|
||||
|
||||
```json
|
||||
{
|
||||
"tag": "pool",
|
||||
"type": "balancer",
|
||||
"strategy": "",
|
||||
"credentials": ["a", "b"],
|
||||
"poll_interval": "60s"
|
||||
}
|
||||
```
|
||||
|
||||
根据选择的策略将会话分配给默认凭据。会话保持粘性,直到分配的凭据触发速率限制。
|
||||
|
||||
- `strategy`:选择策略。可选值:`least_used` `round_robin` `random`。默认使用 `least_used`。
|
||||
- `credentials`:==必填== 默认凭据标签列表。
|
||||
- `poll_interval`:轮询上游使用 API 的间隔。默认 `60s`。
|
||||
|
||||
##### 回退凭据
|
||||
|
||||
```json
|
||||
{
|
||||
"tag": "backup",
|
||||
"type": "fallback",
|
||||
"credentials": ["a", "b"],
|
||||
"poll_interval": "30s"
|
||||
}
|
||||
```
|
||||
|
||||
按顺序使用凭据。当前凭据耗尽后切换到下一个。
|
||||
|
||||
- `credentials`:==必填== 有序的默认凭据标签列表。
|
||||
- `poll_interval`:轮询上游使用 API 的间隔。默认 `60s`。
|
||||
|
||||
#### usages_path
|
||||
|
||||
用于存储聚合 API 使用统计信息的文件路径。
|
||||
@@ -58,6 +133,8 @@ OpenAI OAuth 凭据文件的路径。
|
||||
|
||||
统计文件每分钟自动保存一次,并在服务关闭时保存。
|
||||
|
||||
与 `credentials` 冲突。在多凭据模式下,在各个默认凭据上使用 `usages_path`。
|
||||
|
||||
#### users
|
||||
|
||||
用于令牌身份验证的授权用户列表。
|
||||
@@ -69,7 +146,8 @@ OpenAI OAuth 凭据文件的路径。
|
||||
```json
|
||||
{
|
||||
"name": "",
|
||||
"token": ""
|
||||
"token": "",
|
||||
"credential": ""
|
||||
}
|
||||
```
|
||||
|
||||
@@ -77,6 +155,7 @@ OpenAI OAuth 凭据文件的路径。
|
||||
|
||||
- `name`:用于跟踪的用户名标识符。
|
||||
- `token`:用于身份验证的 Bearer 令牌。客户端通过设置 `Authorization: Bearer <token>` 头进行身份验证。
|
||||
- `credential`:此用户使用的凭据标签。设置 `credentials` 时==必填==。
|
||||
|
||||
#### headers
|
||||
|
||||
@@ -88,6 +167,8 @@ OpenAI OAuth 凭据文件的路径。
|
||||
|
||||
用于连接 OpenAI API 的出站标签。
|
||||
|
||||
与 `credentials` 冲突。在多凭据模式下,在各个默认凭据上使用 `detour`。
|
||||
|
||||
#### tls
|
||||
|
||||
TLS 配置,参阅 [TLS](/zh/configuration/shared/tls/#inbound)。
|
||||
@@ -184,3 +265,52 @@ model_provider = "ocm"
|
||||
```bash
|
||||
codex --profile ocm
|
||||
```
|
||||
|
||||
### 多凭据示例
|
||||
|
||||
#### 服务端
|
||||
|
||||
```json
|
||||
{
|
||||
"services": [
|
||||
{
|
||||
"type": "ocm",
|
||||
"listen": "0.0.0.0",
|
||||
"listen_port": 8080,
|
||||
"credentials": [
|
||||
{
|
||||
"tag": "a",
|
||||
"credential_path": "/home/user/.codex-a/auth.json",
|
||||
"usages_path": "/data/usages-a.json",
|
||||
"reserve_5h": 20,
|
||||
"reserve_weekly": 20
|
||||
},
|
||||
{
|
||||
"tag": "b",
|
||||
"credential_path": "/home/user/.codex-b/auth.json",
|
||||
"reserve_5h": 10,
|
||||
"reserve_weekly": 10
|
||||
},
|
||||
{
|
||||
"tag": "pool",
|
||||
"type": "balancer",
|
||||
"poll_interval": "60s",
|
||||
"credentials": ["a", "b"]
|
||||
}
|
||||
],
|
||||
"users": [
|
||||
{
|
||||
"name": "alice",
|
||||
"token": "sk-ocm-hello-world",
|
||||
"credential": "pool"
|
||||
},
|
||||
{
|
||||
"name": "bob",
|
||||
"token": "sk-ocm-hello-bob",
|
||||
"credential": "a"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
49
docs/configuration/shared/neighbor.md
Normal file
49
docs/configuration/shared/neighbor.md
Normal file
@@ -0,0 +1,49 @@
|
||||
---
|
||||
icon: material/lan
|
||||
---
|
||||
|
||||
# Neighbor Resolution
|
||||
|
||||
Match LAN devices by MAC address and hostname using
|
||||
[`source_mac_address`](/configuration/route/rule/#source_mac_address) and
|
||||
[`source_hostname`](/configuration/route/rule/#source_hostname) rule items.
|
||||
|
||||
Neighbor resolution is automatically enabled when these rule items exist.
|
||||
Use [`route.find_neighbor`](/configuration/route/#find_neighbor) to force enable it for logging without rules.
|
||||
|
||||
## Linux
|
||||
|
||||
Works natively. No special setup required.
|
||||
|
||||
Hostname resolution requires DHCP lease files,
|
||||
automatically detected from common DHCP servers (dnsmasq, odhcpd, ISC dhcpd, Kea).
|
||||
Custom paths can be set via [`route.dhcp_lease_files`](/configuration/route/#dhcp_lease_files).
|
||||
|
||||
## Android
|
||||
|
||||
!!! quote ""
|
||||
|
||||
Only supported in graphical clients.
|
||||
|
||||
Requires Android 11 or above and ROOT.
|
||||
|
||||
Must use [VPNHotspot](https://github.com/Mygod/VPNHotspot) to share the VPN connection.
|
||||
ROM built-in features like "Use VPN for connected devices" can share VPN
|
||||
but cannot provide MAC address or hostname information.
|
||||
|
||||
Set **IP Masquerade Mode** to **None** in VPNHotspot settings.
|
||||
|
||||
Only route/DNS rules are supported. TUN include/exclude routes are not supported.
|
||||
|
||||
### Hostname Visibility
|
||||
|
||||
Hostname is only visible in sing-box if it is visible in VPNHotspot.
|
||||
For Apple devices, change **Private Wi-Fi Address** from **Rotating** to **Fixed** in the Wi-Fi settings
|
||||
of the connected network. Non-Apple devices are always visible.
|
||||
|
||||
## macOS
|
||||
|
||||
Requires the standalone version (macOS system extension).
|
||||
The App Store version can share the VPN as a hotspot but does not support MAC address or hostname reading.
|
||||
|
||||
See [VPN Hotspot](/manual/misc/vpn-hotspot/#macos) for Internet Sharing setup.
|
||||
49
docs/configuration/shared/neighbor.zh.md
Normal file
49
docs/configuration/shared/neighbor.zh.md
Normal file
@@ -0,0 +1,49 @@
|
||||
---
|
||||
icon: material/lan
|
||||
---
|
||||
|
||||
# 邻居解析
|
||||
|
||||
通过
|
||||
[`source_mac_address`](/configuration/route/rule/#source_mac_address) 和
|
||||
[`source_hostname`](/configuration/route/rule/#source_hostname) 规则项匹配局域网设备的 MAC 地址和主机名。
|
||||
|
||||
当这些规则项存在时,邻居解析自动启用。
|
||||
使用 [`route.find_neighbor`](/configuration/route/#find_neighbor) 可在没有规则时强制启用以输出日志。
|
||||
|
||||
## Linux
|
||||
|
||||
原生支持,无需特殊设置。
|
||||
|
||||
主机名解析需要 DHCP 租约文件,
|
||||
自动从常见 DHCP 服务器(dnsmasq、odhcpd、ISC dhcpd、Kea)检测。
|
||||
可通过 [`route.dhcp_lease_files`](/configuration/route/#dhcp_lease_files) 设置自定义路径。
|
||||
|
||||
## Android
|
||||
|
||||
!!! quote ""
|
||||
|
||||
仅在图形客户端中支持。
|
||||
|
||||
需要 Android 11 或以上版本和 ROOT。
|
||||
|
||||
必须使用 [VPNHotspot](https://github.com/Mygod/VPNHotspot) 共享 VPN 连接。
|
||||
ROM 自带的「通过 VPN 共享连接」等功能可以共享 VPN,
|
||||
但无法提供 MAC 地址或主机名信息。
|
||||
|
||||
在 VPNHotspot 设置中将 **IP 遮掩模式** 设为 **无**。
|
||||
|
||||
仅支持路由/DNS 规则。不支持 TUN 的 include/exclude 路由。
|
||||
|
||||
### 设备可见性
|
||||
|
||||
MAC 地址和主机名仅在 VPNHotspot 中可见时 sing-box 才能读取。
|
||||
对于 Apple 设备,需要在所连接网络的 Wi-Fi 设置中将**私有无线局域网地址**从**轮替**改为**固定**。
|
||||
非 Apple 设备始终可见。
|
||||
|
||||
## macOS
|
||||
|
||||
需要独立版本(macOS 系统扩展)。
|
||||
App Store 版本可以共享 VPN 热点但不支持 MAC 地址或主机名读取。
|
||||
|
||||
参阅 [VPN 热点](/manual/misc/vpn-hotspot/#macos) 了解互联网共享设置。
|
||||
@@ -144,6 +144,18 @@ func (s *platformInterfaceStub) SendNotification(notification *adapter.Notificat
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *platformInterfaceStub) UsePlatformNeighborResolver() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *platformInterfaceStub) StartNeighborMonitor(listener adapter.NeighborUpdateListener) error {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
func (s *platformInterfaceStub) CloseNeighborMonitor(listener adapter.NeighborUpdateListener) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *platformInterfaceStub) UsePlatformLocalDNSTransport() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
53
experimental/libbox/neighbor.go
Normal file
53
experimental/libbox/neighbor.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package libbox
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type NeighborEntry struct {
|
||||
Address string
|
||||
MacAddress string
|
||||
Hostname string
|
||||
}
|
||||
|
||||
type NeighborEntryIterator interface {
|
||||
Next() *NeighborEntry
|
||||
HasNext() bool
|
||||
}
|
||||
|
||||
type NeighborSubscription struct {
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func (s *NeighborSubscription) Close() {
|
||||
close(s.done)
|
||||
}
|
||||
|
||||
func tableToIterator(table map[netip.Addr]net.HardwareAddr) NeighborEntryIterator {
|
||||
entries := make([]*NeighborEntry, 0, len(table))
|
||||
for address, mac := range table {
|
||||
entries = append(entries, &NeighborEntry{
|
||||
Address: address.String(),
|
||||
MacAddress: mac.String(),
|
||||
})
|
||||
}
|
||||
return &neighborEntryIterator{entries}
|
||||
}
|
||||
|
||||
type neighborEntryIterator struct {
|
||||
entries []*NeighborEntry
|
||||
}
|
||||
|
||||
func (i *neighborEntryIterator) HasNext() bool {
|
||||
return len(i.entries) > 0
|
||||
}
|
||||
|
||||
func (i *neighborEntryIterator) Next() *NeighborEntry {
|
||||
if len(i.entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
entry := i.entries[0]
|
||||
i.entries = i.entries[1:]
|
||||
return entry
|
||||
}
|
||||
123
experimental/libbox/neighbor_darwin.go
Normal file
123
experimental/libbox/neighbor_darwin.go
Normal file
@@ -0,0 +1,123 @@
|
||||
//go:build darwin
|
||||
|
||||
package libbox
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/route"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
|
||||
xroute "golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func SubscribeNeighborTable(listener NeighborUpdateListener) (*NeighborSubscription, error) {
|
||||
entries, err := route.ReadNeighborEntries()
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "initial neighbor dump")
|
||||
}
|
||||
table := make(map[netip.Addr]net.HardwareAddr)
|
||||
for _, entry := range entries {
|
||||
table[entry.Address] = entry.MACAddress
|
||||
}
|
||||
listener.UpdateNeighborTable(tableToIterator(table))
|
||||
routeSocket, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, 0)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "open route socket")
|
||||
}
|
||||
err = unix.SetNonblock(routeSocket, true)
|
||||
if err != nil {
|
||||
unix.Close(routeSocket)
|
||||
return nil, E.Cause(err, "set route socket nonblock")
|
||||
}
|
||||
subscription := &NeighborSubscription{
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go subscription.loop(listener, routeSocket, table)
|
||||
return subscription, nil
|
||||
}
|
||||
|
||||
func (s *NeighborSubscription) loop(listener NeighborUpdateListener, routeSocket int, table map[netip.Addr]net.HardwareAddr) {
|
||||
routeSocketFile := os.NewFile(uintptr(routeSocket), "route")
|
||||
defer routeSocketFile.Close()
|
||||
buffer := buf.NewPacket()
|
||||
defer buffer.Release()
|
||||
for {
|
||||
select {
|
||||
case <-s.done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
tv := unix.NsecToTimeval(int64(3 * time.Second))
|
||||
_ = unix.SetsockoptTimeval(routeSocket, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv)
|
||||
n, err := routeSocketFile.Read(buffer.FreeBytes())
|
||||
if err != nil {
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case <-s.done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
continue
|
||||
}
|
||||
messages, err := xroute.ParseRIB(xroute.RIBTypeRoute, buffer.FreeBytes()[:n])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
changed := false
|
||||
for _, message := range messages {
|
||||
routeMessage, isRouteMessage := message.(*xroute.RouteMessage)
|
||||
if !isRouteMessage {
|
||||
continue
|
||||
}
|
||||
if routeMessage.Flags&unix.RTF_LLINFO == 0 {
|
||||
continue
|
||||
}
|
||||
address, mac, isDelete, ok := route.ParseRouteNeighborMessage(routeMessage)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if isDelete {
|
||||
if _, exists := table[address]; exists {
|
||||
delete(table, address)
|
||||
changed = true
|
||||
}
|
||||
} else {
|
||||
existing, exists := table[address]
|
||||
if !exists || !slices.Equal(existing, mac) {
|
||||
table[address] = mac
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if changed {
|
||||
listener.UpdateNeighborTable(tableToIterator(table))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ReadBootpdLeases() NeighborEntryIterator {
|
||||
leaseIPToMAC, ipToHostname, macToHostname := route.ReloadLeaseFiles([]string{"/var/db/dhcpd_leases"})
|
||||
entries := make([]*NeighborEntry, 0, len(leaseIPToMAC))
|
||||
for address, mac := range leaseIPToMAC {
|
||||
entry := &NeighborEntry{
|
||||
Address: address.String(),
|
||||
MacAddress: mac.String(),
|
||||
}
|
||||
hostname, found := ipToHostname[address]
|
||||
if !found {
|
||||
hostname = macToHostname[mac.String()]
|
||||
}
|
||||
entry.Hostname = hostname
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
return &neighborEntryIterator{entries}
|
||||
}
|
||||
88
experimental/libbox/neighbor_linux.go
Normal file
88
experimental/libbox/neighbor_linux.go
Normal file
@@ -0,0 +1,88 @@
|
||||
//go:build linux
|
||||
|
||||
package libbox
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/route"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
|
||||
"github.com/mdlayher/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func SubscribeNeighborTable(listener NeighborUpdateListener) (*NeighborSubscription, error) {
|
||||
entries, err := route.ReadNeighborEntries()
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "initial neighbor dump")
|
||||
}
|
||||
table := make(map[netip.Addr]net.HardwareAddr)
|
||||
for _, entry := range entries {
|
||||
table[entry.Address] = entry.MACAddress
|
||||
}
|
||||
listener.UpdateNeighborTable(tableToIterator(table))
|
||||
connection, err := netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{
|
||||
Groups: 1 << (unix.RTNLGRP_NEIGH - 1),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "subscribe neighbor updates")
|
||||
}
|
||||
subscription := &NeighborSubscription{
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go subscription.loop(listener, connection, table)
|
||||
return subscription, nil
|
||||
}
|
||||
|
||||
func (s *NeighborSubscription) loop(listener NeighborUpdateListener, connection *netlink.Conn, table map[netip.Addr]net.HardwareAddr) {
|
||||
defer connection.Close()
|
||||
for {
|
||||
select {
|
||||
case <-s.done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
err := connection.SetReadDeadline(time.Now().Add(3 * time.Second))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
messages, err := connection.Receive()
|
||||
if err != nil {
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case <-s.done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
continue
|
||||
}
|
||||
changed := false
|
||||
for _, message := range messages {
|
||||
address, mac, isDelete, ok := route.ParseNeighborMessage(message)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if isDelete {
|
||||
if _, exists := table[address]; exists {
|
||||
delete(table, address)
|
||||
changed = true
|
||||
}
|
||||
} else {
|
||||
existing, exists := table[address]
|
||||
if !exists || !slices.Equal(existing, mac) {
|
||||
table[address] = mac
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if changed {
|
||||
listener.UpdateNeighborTable(tableToIterator(table))
|
||||
}
|
||||
}
|
||||
}
|
||||
9
experimental/libbox/neighbor_stub.go
Normal file
9
experimental/libbox/neighbor_stub.go
Normal file
@@ -0,0 +1,9 @@
|
||||
//go:build !linux && !darwin
|
||||
|
||||
package libbox
|
||||
|
||||
import "os"
|
||||
|
||||
func SubscribeNeighborTable(_ NeighborUpdateListener) (*NeighborSubscription, error) {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
@@ -21,6 +21,13 @@ type PlatformInterface interface {
|
||||
SystemCertificates() StringIterator
|
||||
ClearDNSCache()
|
||||
SendNotification(notification *Notification) error
|
||||
StartNeighborMonitor(listener NeighborUpdateListener) error
|
||||
CloseNeighborMonitor(listener NeighborUpdateListener) error
|
||||
RegisterMyInterface(name string)
|
||||
}
|
||||
|
||||
type NeighborUpdateListener interface {
|
||||
UpdateNeighborTable(entries NeighborEntryIterator)
|
||||
}
|
||||
|
||||
type ConnectionOwner struct {
|
||||
|
||||
@@ -78,6 +78,7 @@ func (w *platformInterfaceWrapper) OpenInterface(options *tun.Options, platformO
|
||||
}
|
||||
options.FileDescriptor = dupFd
|
||||
w.myTunName = options.Name
|
||||
w.iif.RegisterMyInterface(options.Name)
|
||||
return tun.New(*options)
|
||||
}
|
||||
|
||||
@@ -220,6 +221,46 @@ func (w *platformInterfaceWrapper) SendNotification(notification *adapter.Notifi
|
||||
return w.iif.SendNotification((*Notification)(notification))
|
||||
}
|
||||
|
||||
func (w *platformInterfaceWrapper) UsePlatformNeighborResolver() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (w *platformInterfaceWrapper) StartNeighborMonitor(listener adapter.NeighborUpdateListener) error {
|
||||
return w.iif.StartNeighborMonitor(&neighborUpdateListenerWrapper{listener: listener})
|
||||
}
|
||||
|
||||
func (w *platformInterfaceWrapper) CloseNeighborMonitor(listener adapter.NeighborUpdateListener) error {
|
||||
return w.iif.CloseNeighborMonitor(nil)
|
||||
}
|
||||
|
||||
type neighborUpdateListenerWrapper struct {
|
||||
listener adapter.NeighborUpdateListener
|
||||
}
|
||||
|
||||
func (w *neighborUpdateListenerWrapper) UpdateNeighborTable(entries NeighborEntryIterator) {
|
||||
var result []adapter.NeighborEntry
|
||||
for entries.HasNext() {
|
||||
entry := entries.Next()
|
||||
if entry == nil {
|
||||
continue
|
||||
}
|
||||
address, err := netip.ParseAddr(entry.Address)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
macAddress, err := net.ParseMAC(entry.MacAddress)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
result = append(result, adapter.NeighborEntry{
|
||||
Address: address,
|
||||
MACAddress: macAddress,
|
||||
Hostname: entry.Hostname,
|
||||
})
|
||||
}
|
||||
w.listener.UpdateNeighborTable(result)
|
||||
}
|
||||
|
||||
func AvailablePort(startPort int32) (int32, error) {
|
||||
for port := int(startPort); ; port++ {
|
||||
if port > 65535 {
|
||||
|
||||
14
go.mod
14
go.mod
@@ -14,11 +14,13 @@ require (
|
||||
github.com/godbus/dbus/v5 v5.2.2
|
||||
github.com/gofrs/uuid/v5 v5.4.0
|
||||
github.com/insomniacslk/dhcp v0.0.0-20260220084031-5adc3eb26f91
|
||||
github.com/jsimonetti/rtnetlink v1.4.0
|
||||
github.com/keybase/go-keychain v0.0.1
|
||||
github.com/libdns/acmedns v0.5.0
|
||||
github.com/libdns/alidns v1.0.6
|
||||
github.com/libdns/cloudflare v0.2.2
|
||||
github.com/logrusorgru/aurora v2.0.3+incompatible
|
||||
github.com/mdlayher/netlink v1.9.0
|
||||
github.com/metacubex/utls v1.8.4
|
||||
github.com/mholt/acmez/v3 v3.1.6
|
||||
github.com/miekg/dns v1.1.72
|
||||
@@ -27,22 +29,22 @@ require (
|
||||
github.com/sagernet/asc-go v0.0.0-20241217030726-d563060fe4e1
|
||||
github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a
|
||||
github.com/sagernet/cors v1.2.1
|
||||
github.com/sagernet/cronet-go v0.0.0-20260309102448-2fef65f9dba9
|
||||
github.com/sagernet/cronet-go/all v0.0.0-20260309102448-2fef65f9dba9
|
||||
github.com/sagernet/cronet-go v0.0.0-20260309100020-c128886ff3fc
|
||||
github.com/sagernet/cronet-go/all v0.0.0-20260309100020-c128886ff3fc
|
||||
github.com/sagernet/fswatch v0.1.1
|
||||
github.com/sagernet/gomobile v0.1.12
|
||||
github.com/sagernet/gvisor v0.0.0-20250811.0-sing-box-mod.1
|
||||
github.com/sagernet/quic-go v0.59.0-sing-box-mod.4
|
||||
github.com/sagernet/sing v0.8.2
|
||||
github.com/sagernet/sing v0.8.3-0.20260311155444-d39eb42a9f69
|
||||
github.com/sagernet/sing-mux v0.3.4
|
||||
github.com/sagernet/sing-quic v0.6.0
|
||||
github.com/sagernet/sing-shadowsocks v0.2.8
|
||||
github.com/sagernet/sing-shadowsocks2 v0.2.1
|
||||
github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11
|
||||
github.com/sagernet/sing-tun v0.8.2
|
||||
github.com/sagernet/sing-tun v0.8.3-0.20260311132553-5485872f601f
|
||||
github.com/sagernet/sing-vmess v0.2.8-0.20250909125414-3aed155119a1
|
||||
github.com/sagernet/smux v1.5.50-sing-box-mod.1
|
||||
github.com/sagernet/tailscale v1.92.4-sing-box-1.13-mod.6.0.20260310162543-0c2de366d4de
|
||||
github.com/sagernet/tailscale v1.92.4-sing-box-1.13-mod.6.0.20260311131347-f88b27eeb76e
|
||||
github.com/sagernet/wireguard-go v0.0.2-beta.1.0.20260224074747-506b7631853c
|
||||
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854
|
||||
github.com/spf13/cobra v1.10.2
|
||||
@@ -92,11 +94,9 @@ require (
|
||||
github.com/hashicorp/yamux v0.1.2 // indirect
|
||||
github.com/hdevalence/ed25519consensus v0.2.0 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/jsimonetti/rtnetlink v1.4.0 // indirect
|
||||
github.com/klauspost/compress v1.18.0 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/libdns/libdns v1.1.1 // indirect
|
||||
github.com/mdlayher/netlink v1.9.0 // indirect
|
||||
github.com/mdlayher/socket v0.5.1 // indirect
|
||||
github.com/mitchellh/go-ps v1.0.0 // indirect
|
||||
github.com/pierrec/lz4/v4 v4.1.21 // indirect
|
||||
|
||||
20
go.sum
20
go.sum
@@ -162,10 +162,10 @@ github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a h1:+NkI2670SQpQWvkk
|
||||
github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a/go.mod h1:63s7jpZqcDAIpj8oI/1v4Izok+npJOHACFCU6+huCkM=
|
||||
github.com/sagernet/cors v1.2.1 h1:Cv5Z8y9YSD6Gm+qSpNrL3LO4lD3eQVvbFYJSG7JCMHQ=
|
||||
github.com/sagernet/cors v1.2.1/go.mod h1:O64VyOjjhrkLmQIjF4KGRrJO/5dVXFdpEmCW/eISRAI=
|
||||
github.com/sagernet/cronet-go v0.0.0-20260309102448-2fef65f9dba9 h1:xq5Yr10jXEppD3cnGjE3WENaB6D0YsZu6KptZ8d3054=
|
||||
github.com/sagernet/cronet-go v0.0.0-20260309102448-2fef65f9dba9/go.mod h1:hwFHBEjjthyEquDULbr4c4ucMedp8Drb6Jvm2kt/0Bw=
|
||||
github.com/sagernet/cronet-go/all v0.0.0-20260309102448-2fef65f9dba9 h1:uxQyy6Y/boOuecVA66tf79JgtoRGfeDJcfYZZLKVA5E=
|
||||
github.com/sagernet/cronet-go/all v0.0.0-20260309102448-2fef65f9dba9/go.mod h1:Xm6cCvs0/twozC1JYNq0sVlOVmcSGzV7YON1XGcD97w=
|
||||
github.com/sagernet/cronet-go v0.0.0-20260309100020-c128886ff3fc h1:YK7PwJT0irRAEui9ASdXSxcE2BOVQipWMF/A1Ogt+7c=
|
||||
github.com/sagernet/cronet-go v0.0.0-20260309100020-c128886ff3fc/go.mod h1:hwFHBEjjthyEquDULbr4c4ucMedp8Drb6Jvm2kt/0Bw=
|
||||
github.com/sagernet/cronet-go/all v0.0.0-20260309100020-c128886ff3fc h1:EJPHOqk23IuBsTjXK9OXqkNxPbKOBWKRmviQoCcriAs=
|
||||
github.com/sagernet/cronet-go/all v0.0.0-20260309100020-c128886ff3fc/go.mod h1:8aty0RW96DrJSMWXO6bRPMBJEjuqq5JWiOIi4bCRzFA=
|
||||
github.com/sagernet/cronet-go/lib/android_386 v0.0.0-20260309101654-0cbdcfddded9 h1:Qi0IKBpoPP3qZqIXuOKMsT2dv+l/MLWMyBHDMLRw2EA=
|
||||
github.com/sagernet/cronet-go/lib/android_386 v0.0.0-20260309101654-0cbdcfddded9/go.mod h1:XXDwdjX/T8xftoeJxQmbBoYXZp8MAPFR2CwbFuTpEtw=
|
||||
github.com/sagernet/cronet-go/lib/android_amd64 v0.0.0-20260309101654-0cbdcfddded9 h1:p+wCMjOhj46SpSD/AJeTGgkCcbyA76FyH631XZatyU8=
|
||||
@@ -236,8 +236,8 @@ github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNen
|
||||
github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8=
|
||||
github.com/sagernet/quic-go v0.59.0-sing-box-mod.4 h1:6qvrUW79S+CrPwWz6cMePXohgjHoKxLo3c+MDhNwc3o=
|
||||
github.com/sagernet/quic-go v0.59.0-sing-box-mod.4/go.mod h1:OqILvS182CyOol5zNNo6bguvOGgXzV459+chpRaUC+4=
|
||||
github.com/sagernet/sing v0.8.2 h1:kX1IH9SWJv4S0T9M8O+HNahWgbOuY1VauxbF7NU5lOg=
|
||||
github.com/sagernet/sing v0.8.2/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
|
||||
github.com/sagernet/sing v0.8.3-0.20260311155444-d39eb42a9f69 h1:h6UF2emeydBQMAso99Nr3APV6YustOs+JszVuCkcFy0=
|
||||
github.com/sagernet/sing v0.8.3-0.20260311155444-d39eb42a9f69/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
|
||||
github.com/sagernet/sing-mux v0.3.4 h1:ZQplKl8MNXutjzbMVtWvWG31fohhgOfCuUZR4dVQ8+s=
|
||||
github.com/sagernet/sing-mux v0.3.4/go.mod h1:QvlKMyNBNrQoyX4x+gq028uPbLM2XeRpWtDsWBJbFSk=
|
||||
github.com/sagernet/sing-quic v0.6.0 h1:dhrFnP45wgVKEOT1EvtsToxdzRnHIDIAgj6WHV9pLyM=
|
||||
@@ -248,14 +248,14 @@ github.com/sagernet/sing-shadowsocks2 v0.2.1 h1:dWV9OXCeFPuYGHb6IRqlSptVnSzOelnq
|
||||
github.com/sagernet/sing-shadowsocks2 v0.2.1/go.mod h1:RnXS0lExcDAovvDeniJ4IKa2IuChrdipolPYWBv9hWQ=
|
||||
github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11 h1:tK+75l64tm9WvEFrYRE1t0YxoFdWQqw/h7Uhzj0vJ+w=
|
||||
github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11/go.mod h1:sWqKnGlMipCHaGsw1sTTlimyUpgzP4WP3pjhCsYt9oA=
|
||||
github.com/sagernet/sing-tun v0.8.2 h1:rQr/x3eQCHh3oleIaoJdPdJwqzZp4+QWcJLT0Wz2xKY=
|
||||
github.com/sagernet/sing-tun v0.8.2/go.mod h1:pLCo4o+LacXEzz0bhwhJkKBjLlKOGPBNOAZ97ZVZWzs=
|
||||
github.com/sagernet/sing-tun v0.8.3-0.20260311132553-5485872f601f h1:uj3rzedphq1AiL0PpuVoob5RtKsPBcMRd8aqo+q0rqA=
|
||||
github.com/sagernet/sing-tun v0.8.3-0.20260311132553-5485872f601f/go.mod h1:pLCo4o+LacXEzz0bhwhJkKBjLlKOGPBNOAZ97ZVZWzs=
|
||||
github.com/sagernet/sing-vmess v0.2.8-0.20250909125414-3aed155119a1 h1:aSwUNYUkVyVvdmBSufR8/nRFonwJeKSIROxHcm5br9o=
|
||||
github.com/sagernet/sing-vmess v0.2.8-0.20250909125414-3aed155119a1/go.mod h1:P11scgTxMxVVQ8dlM27yNm3Cro40mD0+gHbnqrNGDuY=
|
||||
github.com/sagernet/smux v1.5.50-sing-box-mod.1 h1:XkJcivBC9V4wBjiGXIXZ229aZCU1hzcbp6kSkkyQ478=
|
||||
github.com/sagernet/smux v1.5.50-sing-box-mod.1/go.mod h1:NjhsCEWedJm7eFLyhuBgIEzwfhRmytrUoiLluxs5Sk8=
|
||||
github.com/sagernet/tailscale v1.92.4-sing-box-1.13-mod.6.0.20260310162543-0c2de366d4de h1:wsJ0COxUOIvBE+hUho0C/DbMeUe9jtwfh6dECAiTk94=
|
||||
github.com/sagernet/tailscale v1.92.4-sing-box-1.13-mod.6.0.20260310162543-0c2de366d4de/go.mod h1:m87GAn4UcesHQF3leaPFEINZETO5za1LGn1GJdNDgNc=
|
||||
github.com/sagernet/tailscale v1.92.4-sing-box-1.13-mod.6.0.20260311131347-f88b27eeb76e h1:Sv1qUhJIidjSTc24XEknovDZnbmVSlAXj8wNVgIfgGo=
|
||||
github.com/sagernet/tailscale v1.92.4-sing-box-1.13-mod.6.0.20260311131347-f88b27eeb76e/go.mod h1:m87GAn4UcesHQF3leaPFEINZETO5za1LGn1GJdNDgNc=
|
||||
github.com/sagernet/wireguard-go v0.0.2-beta.1.0.20260224074747-506b7631853c h1:f9cXNB+IOOPnR8DOLMTpr42jf7naxh5Un5Y09BBf5Cg=
|
||||
github.com/sagernet/wireguard-go v0.0.2-beta.1.0.20260224074747-506b7631853c/go.mod h1:WUxgxUDZoCF2sxVmW+STSxatP02Qn3FcafTiI2BLtE0=
|
||||
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854 h1:6uUiZcDRnZSAegryaUGwPC/Fj13JSHwiTftrXhMmYOc=
|
||||
|
||||
@@ -168,7 +168,11 @@ func FormatDuration(duration time.Duration) string {
|
||||
return F.ToString(duration.Milliseconds(), "ms")
|
||||
} else if duration < time.Minute {
|
||||
return F.ToString(int64(duration.Seconds()), ".", int64(duration.Seconds()*100)%100, "s")
|
||||
} else {
|
||||
} else if duration < time.Hour {
|
||||
return F.ToString(int64(duration.Minutes()), "m", int64(duration.Seconds())%60, "s")
|
||||
} else if duration < 24*time.Hour {
|
||||
return F.ToString(int64(duration.Hours()), "h", int64(duration.Minutes())%60, "m")
|
||||
} else {
|
||||
return F.ToString(int64(duration.Hours())/24, "d", int64(duration.Hours())%24, "h")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -129,6 +129,7 @@ nav:
|
||||
- UDP over TCP: configuration/shared/udp-over-tcp.md
|
||||
- TCP Brutal: configuration/shared/tcp-brutal.md
|
||||
- Wi-Fi State: configuration/shared/wifi-state.md
|
||||
- Neighbor Resolution: configuration/shared/neighbor.md
|
||||
- Endpoint:
|
||||
- configuration/endpoint/index.md
|
||||
- WireGuard: configuration/endpoint/wireguard.md
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package option
|
||||
|
||||
import (
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
"github.com/sagernet/sing/common/json/badjson"
|
||||
"github.com/sagernet/sing/common/json/badoption"
|
||||
)
|
||||
|
||||
@@ -8,6 +11,7 @@ type CCMServiceOptions struct {
|
||||
ListenOptions
|
||||
InboundTLSOptionsContainer
|
||||
CredentialPath string `json:"credential_path,omitempty"`
|
||||
Credentials []CCMCredential `json:"credentials,omitempty"`
|
||||
Users []CCMUser `json:"users,omitempty"`
|
||||
Headers badoption.HTTPHeader `json:"headers,omitempty"`
|
||||
Detour string `json:"detour,omitempty"`
|
||||
@@ -15,6 +19,94 @@ type CCMServiceOptions struct {
|
||||
}
|
||||
|
||||
type CCMUser struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
Credential string `json:"credential,omitempty"`
|
||||
ExternalCredential string `json:"external_credential,omitempty"`
|
||||
AllowExternalUsage bool `json:"allow_external_usage,omitempty"`
|
||||
}
|
||||
|
||||
type _CCMCredential struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Tag string `json:"tag"`
|
||||
DefaultOptions CCMDefaultCredentialOptions `json:"-"`
|
||||
ExternalOptions CCMExternalCredentialOptions `json:"-"`
|
||||
BalancerOptions CCMBalancerCredentialOptions `json:"-"`
|
||||
FallbackOptions CCMFallbackCredentialOptions `json:"-"`
|
||||
}
|
||||
|
||||
type CCMCredential _CCMCredential
|
||||
|
||||
func (c CCMCredential) MarshalJSON() ([]byte, error) {
|
||||
var v any
|
||||
switch c.Type {
|
||||
case "", "default":
|
||||
c.Type = ""
|
||||
v = c.DefaultOptions
|
||||
case "external":
|
||||
v = c.ExternalOptions
|
||||
case "balancer":
|
||||
v = c.BalancerOptions
|
||||
case "fallback":
|
||||
v = c.FallbackOptions
|
||||
default:
|
||||
return nil, E.New("unknown credential type: ", c.Type)
|
||||
}
|
||||
return badjson.MarshallObjects((_CCMCredential)(c), v)
|
||||
}
|
||||
|
||||
func (c *CCMCredential) UnmarshalJSON(bytes []byte) error {
|
||||
err := json.Unmarshal(bytes, (*_CCMCredential)(c))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c.Tag == "" {
|
||||
return E.New("missing credential tag")
|
||||
}
|
||||
var v any
|
||||
switch c.Type {
|
||||
case "", "default":
|
||||
c.Type = "default"
|
||||
v = &c.DefaultOptions
|
||||
case "external":
|
||||
v = &c.ExternalOptions
|
||||
case "balancer":
|
||||
v = &c.BalancerOptions
|
||||
case "fallback":
|
||||
v = &c.FallbackOptions
|
||||
default:
|
||||
return E.New("unknown credential type: ", c.Type)
|
||||
}
|
||||
return badjson.UnmarshallExcluded(bytes, (*_CCMCredential)(c), v)
|
||||
}
|
||||
|
||||
type CCMDefaultCredentialOptions struct {
|
||||
CredentialPath string `json:"credential_path,omitempty"`
|
||||
UsagesPath string `json:"usages_path,omitempty"`
|
||||
Detour string `json:"detour,omitempty"`
|
||||
Reserve5h uint8 `json:"reserve_5h"`
|
||||
ReserveWeekly uint8 `json:"reserve_weekly"`
|
||||
Limit5h uint8 `json:"limit_5h,omitempty"`
|
||||
LimitWeekly uint8 `json:"limit_weekly,omitempty"`
|
||||
}
|
||||
|
||||
type CCMBalancerCredentialOptions struct {
|
||||
Strategy string `json:"strategy,omitempty"`
|
||||
Credentials badoption.Listable[string] `json:"credentials"`
|
||||
PollInterval badoption.Duration `json:"poll_interval,omitempty"`
|
||||
}
|
||||
|
||||
type CCMExternalCredentialOptions struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
ServerOptions
|
||||
Token string `json:"token"`
|
||||
Reverse bool `json:"reverse,omitempty"`
|
||||
Detour string `json:"detour,omitempty"`
|
||||
UsagesPath string `json:"usages_path,omitempty"`
|
||||
PollInterval badoption.Duration `json:"poll_interval,omitempty"`
|
||||
}
|
||||
|
||||
type CCMFallbackCredentialOptions struct {
|
||||
Credentials badoption.Listable[string] `json:"credentials"`
|
||||
PollInterval badoption.Duration `json:"poll_interval,omitempty"`
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package option
|
||||
|
||||
import (
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
"github.com/sagernet/sing/common/json/badjson"
|
||||
"github.com/sagernet/sing/common/json/badoption"
|
||||
)
|
||||
|
||||
@@ -8,6 +11,7 @@ type OCMServiceOptions struct {
|
||||
ListenOptions
|
||||
InboundTLSOptionsContainer
|
||||
CredentialPath string `json:"credential_path,omitempty"`
|
||||
Credentials []OCMCredential `json:"credentials,omitempty"`
|
||||
Users []OCMUser `json:"users,omitempty"`
|
||||
Headers badoption.HTTPHeader `json:"headers,omitempty"`
|
||||
Detour string `json:"detour,omitempty"`
|
||||
@@ -15,6 +19,94 @@ type OCMServiceOptions struct {
|
||||
}
|
||||
|
||||
type OCMUser struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
Credential string `json:"credential,omitempty"`
|
||||
ExternalCredential string `json:"external_credential,omitempty"`
|
||||
AllowExternalUsage bool `json:"allow_external_usage,omitempty"`
|
||||
}
|
||||
|
||||
type _OCMCredential struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Tag string `json:"tag"`
|
||||
DefaultOptions OCMDefaultCredentialOptions `json:"-"`
|
||||
ExternalOptions OCMExternalCredentialOptions `json:"-"`
|
||||
BalancerOptions OCMBalancerCredentialOptions `json:"-"`
|
||||
FallbackOptions OCMFallbackCredentialOptions `json:"-"`
|
||||
}
|
||||
|
||||
type OCMCredential _OCMCredential
|
||||
|
||||
func (c OCMCredential) MarshalJSON() ([]byte, error) {
|
||||
var v any
|
||||
switch c.Type {
|
||||
case "", "default":
|
||||
c.Type = ""
|
||||
v = c.DefaultOptions
|
||||
case "external":
|
||||
v = c.ExternalOptions
|
||||
case "balancer":
|
||||
v = c.BalancerOptions
|
||||
case "fallback":
|
||||
v = c.FallbackOptions
|
||||
default:
|
||||
return nil, E.New("unknown credential type: ", c.Type)
|
||||
}
|
||||
return badjson.MarshallObjects((_OCMCredential)(c), v)
|
||||
}
|
||||
|
||||
func (c *OCMCredential) UnmarshalJSON(bytes []byte) error {
|
||||
err := json.Unmarshal(bytes, (*_OCMCredential)(c))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c.Tag == "" {
|
||||
return E.New("missing credential tag")
|
||||
}
|
||||
var v any
|
||||
switch c.Type {
|
||||
case "", "default":
|
||||
c.Type = "default"
|
||||
v = &c.DefaultOptions
|
||||
case "external":
|
||||
v = &c.ExternalOptions
|
||||
case "balancer":
|
||||
v = &c.BalancerOptions
|
||||
case "fallback":
|
||||
v = &c.FallbackOptions
|
||||
default:
|
||||
return E.New("unknown credential type: ", c.Type)
|
||||
}
|
||||
return badjson.UnmarshallExcluded(bytes, (*_OCMCredential)(c), v)
|
||||
}
|
||||
|
||||
type OCMDefaultCredentialOptions struct {
|
||||
CredentialPath string `json:"credential_path,omitempty"`
|
||||
UsagesPath string `json:"usages_path,omitempty"`
|
||||
Detour string `json:"detour,omitempty"`
|
||||
Reserve5h uint8 `json:"reserve_5h"`
|
||||
ReserveWeekly uint8 `json:"reserve_weekly"`
|
||||
Limit5h uint8 `json:"limit_5h,omitempty"`
|
||||
LimitWeekly uint8 `json:"limit_weekly,omitempty"`
|
||||
}
|
||||
|
||||
type OCMBalancerCredentialOptions struct {
|
||||
Strategy string `json:"strategy,omitempty"`
|
||||
Credentials badoption.Listable[string] `json:"credentials"`
|
||||
PollInterval badoption.Duration `json:"poll_interval,omitempty"`
|
||||
}
|
||||
|
||||
type OCMExternalCredentialOptions struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
ServerOptions
|
||||
Token string `json:"token"`
|
||||
Reverse bool `json:"reverse,omitempty"`
|
||||
Detour string `json:"detour,omitempty"`
|
||||
UsagesPath string `json:"usages_path,omitempty"`
|
||||
PollInterval badoption.Duration `json:"poll_interval,omitempty"`
|
||||
}
|
||||
|
||||
type OCMFallbackCredentialOptions struct {
|
||||
Credentials badoption.Listable[string] `json:"credentials"`
|
||||
PollInterval badoption.Duration `json:"poll_interval,omitempty"`
|
||||
}
|
||||
|
||||
@@ -9,6 +9,8 @@ type RouteOptions struct {
|
||||
RuleSet []RuleSet `json:"rule_set,omitempty"`
|
||||
Final string `json:"final,omitempty"`
|
||||
FindProcess bool `json:"find_process,omitempty"`
|
||||
FindNeighbor bool `json:"find_neighbor,omitempty"`
|
||||
DHCPLeaseFiles badoption.Listable[string] `json:"dhcp_lease_files,omitempty"`
|
||||
AutoDetectInterface bool `json:"auto_detect_interface,omitempty"`
|
||||
OverrideAndroidVPN bool `json:"override_android_vpn,omitempty"`
|
||||
DefaultInterface string `json:"default_interface,omitempty"`
|
||||
|
||||
@@ -103,6 +103,8 @@ type RawDefaultRule struct {
|
||||
InterfaceAddress *badjson.TypedMap[string, badoption.Listable[*badoption.Prefixable]] `json:"interface_address,omitempty"`
|
||||
NetworkInterfaceAddress *badjson.TypedMap[InterfaceType, badoption.Listable[*badoption.Prefixable]] `json:"network_interface_address,omitempty"`
|
||||
DefaultInterfaceAddress badoption.Listable[*badoption.Prefixable] `json:"default_interface_address,omitempty"`
|
||||
SourceMACAddress badoption.Listable[string] `json:"source_mac_address,omitempty"`
|
||||
SourceHostname badoption.Listable[string] `json:"source_hostname,omitempty"`
|
||||
PreferredBy badoption.Listable[string] `json:"preferred_by,omitempty"`
|
||||
RuleSet badoption.Listable[string] `json:"rule_set,omitempty"`
|
||||
RuleSetIPCIDRMatchSource bool `json:"rule_set_ip_cidr_match_source,omitempty"`
|
||||
|
||||
@@ -106,6 +106,8 @@ type RawDefaultDNSRule struct {
|
||||
InterfaceAddress *badjson.TypedMap[string, badoption.Listable[*badoption.Prefixable]] `json:"interface_address,omitempty"`
|
||||
NetworkInterfaceAddress *badjson.TypedMap[InterfaceType, badoption.Listable[*badoption.Prefixable]] `json:"network_interface_address,omitempty"`
|
||||
DefaultInterfaceAddress badoption.Listable[*badoption.Prefixable] `json:"default_interface_address,omitempty"`
|
||||
SourceMACAddress badoption.Listable[string] `json:"source_mac_address,omitempty"`
|
||||
SourceHostname badoption.Listable[string] `json:"source_hostname,omitempty"`
|
||||
RuleSet badoption.Listable[string] `json:"rule_set,omitempty"`
|
||||
RuleSetIPCIDRMatchSource bool `json:"rule_set_ip_cidr_match_source,omitempty"`
|
||||
RuleSetIPCIDRAcceptEmpty bool `json:"rule_set_ip_cidr_accept_empty,omitempty"`
|
||||
|
||||
@@ -39,6 +39,8 @@ type TunInboundOptions struct {
|
||||
IncludeAndroidUser badoption.Listable[int] `json:"include_android_user,omitempty"`
|
||||
IncludePackage badoption.Listable[string] `json:"include_package,omitempty"`
|
||||
ExcludePackage badoption.Listable[string] `json:"exclude_package,omitempty"`
|
||||
IncludeMACAddress badoption.Listable[string] `json:"include_mac_address,omitempty"`
|
||||
ExcludeMACAddress badoption.Listable[string] `json:"exclude_mac_address,omitempty"`
|
||||
UDPTimeout UDPTimeoutCompat `json:"udp_timeout,omitempty"`
|
||||
Stack string `json:"stack,omitempty"`
|
||||
Platform *TunPlatformOptions `json:"platform,omitempty"`
|
||||
|
||||
@@ -333,9 +333,6 @@ func (t *Endpoint) Start(stage adapter.StartStage) error {
|
||||
t.systemTun = systemTun
|
||||
t.systemDialer = systemDialer
|
||||
t.server.TunDevice = wgTunDevice
|
||||
t.server.RouterWrapper = func(inner router.Router) router.Router {
|
||||
return &addressOnlyRouter{Router: inner}
|
||||
}
|
||||
}
|
||||
if mark := t.network.AutoRedirectOutputMark(); mark > 0 {
|
||||
controlFunc := t.network.AutoRedirectOutputMarkFunc()
|
||||
@@ -480,11 +477,12 @@ func (t *Endpoint) Close() error {
|
||||
t.fallbackTCPCloser()
|
||||
t.fallbackTCPCloser = nil
|
||||
}
|
||||
err := common.Close(common.PtrOrNil(t.server))
|
||||
if t.systemTun != nil {
|
||||
_ = t.systemTun.Close()
|
||||
t.systemTun.Close()
|
||||
t.systemTun = nil
|
||||
}
|
||||
return common.Close(common.PtrOrNil(t.server))
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *Endpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
@@ -849,16 +847,3 @@ func (c *dnsConfigurtor) GetBaseConfig() (tsDNS.OSConfig, error) {
|
||||
func (c *dnsConfigurtor) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type addressOnlyRouter struct {
|
||||
router.Router
|
||||
}
|
||||
|
||||
func (r *addressOnlyRouter) Set(config *router.Config) error {
|
||||
if config != nil {
|
||||
config = &router.Config{
|
||||
LocalAddrs: config.LocalAddrs,
|
||||
}
|
||||
}
|
||||
return r.Router.Set(config)
|
||||
}
|
||||
|
||||
@@ -156,6 +156,22 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
||||
if nfQueue == 0 {
|
||||
nfQueue = tun.DefaultAutoRedirectNFQueue
|
||||
}
|
||||
var includeMACAddress []net.HardwareAddr
|
||||
for i, macString := range options.IncludeMACAddress {
|
||||
mac, macErr := net.ParseMAC(macString)
|
||||
if macErr != nil {
|
||||
return nil, E.Cause(macErr, "parse include_mac_address[", i, "]")
|
||||
}
|
||||
includeMACAddress = append(includeMACAddress, mac)
|
||||
}
|
||||
var excludeMACAddress []net.HardwareAddr
|
||||
for i, macString := range options.ExcludeMACAddress {
|
||||
mac, macErr := net.ParseMAC(macString)
|
||||
if macErr != nil {
|
||||
return nil, E.Cause(macErr, "parse exclude_mac_address[", i, "]")
|
||||
}
|
||||
excludeMACAddress = append(excludeMACAddress, mac)
|
||||
}
|
||||
networkManager := service.FromContext[adapter.NetworkManager](ctx)
|
||||
multiPendingPackets := C.IsDarwin && ((options.Stack == "gvisor" && tunMTU < 32768) || (options.Stack != "gvisor" && options.MTU <= 9000))
|
||||
inbound := &Inbound{
|
||||
@@ -193,6 +209,8 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
||||
IncludeAndroidUser: options.IncludeAndroidUser,
|
||||
IncludePackage: options.IncludePackage,
|
||||
ExcludePackage: options.ExcludePackage,
|
||||
IncludeMACAddress: includeMACAddress,
|
||||
ExcludeMACAddress: excludeMACAddress,
|
||||
InterfaceMonitor: networkManager.InterfaceMonitor(),
|
||||
EXP_MultiPendingPackets: multiPendingPackets,
|
||||
},
|
||||
|
||||
239
route/neighbor_resolver_darwin.go
Normal file
239
route/neighbor_resolver_darwin.go
Normal file
@@ -0,0 +1,239 @@
|
||||
//go:build darwin
|
||||
|
||||
package route
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/fswatch"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
|
||||
"golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var defaultLeaseFiles = []string{
|
||||
"/var/db/dhcpd_leases",
|
||||
"/tmp/dhcp.leases",
|
||||
}
|
||||
|
||||
type neighborResolver struct {
|
||||
logger logger.ContextLogger
|
||||
leaseFiles []string
|
||||
access sync.RWMutex
|
||||
neighborIPToMAC map[netip.Addr]net.HardwareAddr
|
||||
leaseIPToMAC map[netip.Addr]net.HardwareAddr
|
||||
ipToHostname map[netip.Addr]string
|
||||
macToHostname map[string]string
|
||||
watcher *fswatch.Watcher
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func newNeighborResolver(resolverLogger logger.ContextLogger, leaseFiles []string) (adapter.NeighborResolver, error) {
|
||||
if len(leaseFiles) == 0 {
|
||||
for _, path := range defaultLeaseFiles {
|
||||
info, err := os.Stat(path)
|
||||
if err == nil && info.Size() > 0 {
|
||||
leaseFiles = append(leaseFiles, path)
|
||||
}
|
||||
}
|
||||
}
|
||||
return &neighborResolver{
|
||||
logger: resolverLogger,
|
||||
leaseFiles: leaseFiles,
|
||||
neighborIPToMAC: make(map[netip.Addr]net.HardwareAddr),
|
||||
leaseIPToMAC: make(map[netip.Addr]net.HardwareAddr),
|
||||
ipToHostname: make(map[netip.Addr]string),
|
||||
macToHostname: make(map[string]string),
|
||||
done: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *neighborResolver) Start() error {
|
||||
err := r.loadNeighborTable()
|
||||
if err != nil {
|
||||
r.logger.Warn(E.Cause(err, "load neighbor table"))
|
||||
}
|
||||
r.doReloadLeaseFiles()
|
||||
go r.subscribeNeighborUpdates()
|
||||
if len(r.leaseFiles) > 0 {
|
||||
watcher, err := fswatch.NewWatcher(fswatch.Options{
|
||||
Path: r.leaseFiles,
|
||||
Logger: r.logger,
|
||||
Callback: func(_ string) {
|
||||
r.doReloadLeaseFiles()
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
r.logger.Warn(E.Cause(err, "create lease file watcher"))
|
||||
} else {
|
||||
r.watcher = watcher
|
||||
err = watcher.Start()
|
||||
if err != nil {
|
||||
r.logger.Warn(E.Cause(err, "start lease file watcher"))
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *neighborResolver) Close() error {
|
||||
close(r.done)
|
||||
if r.watcher != nil {
|
||||
return r.watcher.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *neighborResolver) LookupMAC(address netip.Addr) (net.HardwareAddr, bool) {
|
||||
r.access.RLock()
|
||||
defer r.access.RUnlock()
|
||||
mac, found := r.neighborIPToMAC[address]
|
||||
if found {
|
||||
return mac, true
|
||||
}
|
||||
mac, found = r.leaseIPToMAC[address]
|
||||
if found {
|
||||
return mac, true
|
||||
}
|
||||
mac, found = extractMACFromEUI64(address)
|
||||
if found {
|
||||
return mac, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (r *neighborResolver) LookupHostname(address netip.Addr) (string, bool) {
|
||||
r.access.RLock()
|
||||
defer r.access.RUnlock()
|
||||
hostname, found := r.ipToHostname[address]
|
||||
if found {
|
||||
return hostname, true
|
||||
}
|
||||
mac, macFound := r.neighborIPToMAC[address]
|
||||
if !macFound {
|
||||
mac, macFound = r.leaseIPToMAC[address]
|
||||
}
|
||||
if !macFound {
|
||||
mac, macFound = extractMACFromEUI64(address)
|
||||
}
|
||||
if macFound {
|
||||
hostname, found = r.macToHostname[mac.String()]
|
||||
if found {
|
||||
return hostname, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func (r *neighborResolver) loadNeighborTable() error {
|
||||
entries, err := ReadNeighborEntries()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.access.Lock()
|
||||
defer r.access.Unlock()
|
||||
for _, entry := range entries {
|
||||
r.neighborIPToMAC[entry.Address] = entry.MACAddress
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *neighborResolver) subscribeNeighborUpdates() {
|
||||
routeSocket, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, 0)
|
||||
if err != nil {
|
||||
r.logger.Warn(E.Cause(err, "subscribe neighbor updates"))
|
||||
return
|
||||
}
|
||||
err = unix.SetNonblock(routeSocket, true)
|
||||
if err != nil {
|
||||
unix.Close(routeSocket)
|
||||
r.logger.Warn(E.Cause(err, "set route socket nonblock"))
|
||||
return
|
||||
}
|
||||
routeSocketFile := os.NewFile(uintptr(routeSocket), "route")
|
||||
defer routeSocketFile.Close()
|
||||
buffer := buf.NewPacket()
|
||||
defer buffer.Release()
|
||||
for {
|
||||
select {
|
||||
case <-r.done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
err = setReadDeadline(routeSocketFile, 3*time.Second)
|
||||
if err != nil {
|
||||
r.logger.Warn(E.Cause(err, "set route socket read deadline"))
|
||||
return
|
||||
}
|
||||
n, err := routeSocketFile.Read(buffer.FreeBytes())
|
||||
if err != nil {
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case <-r.done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
r.logger.Warn(E.Cause(err, "receive neighbor update"))
|
||||
continue
|
||||
}
|
||||
messages, err := route.ParseRIB(route.RIBTypeRoute, buffer.FreeBytes()[:n])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, message := range messages {
|
||||
routeMessage, isRouteMessage := message.(*route.RouteMessage)
|
||||
if !isRouteMessage {
|
||||
continue
|
||||
}
|
||||
if routeMessage.Flags&unix.RTF_LLINFO == 0 {
|
||||
continue
|
||||
}
|
||||
address, mac, isDelete, ok := ParseRouteNeighborMessage(routeMessage)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
r.access.Lock()
|
||||
if isDelete {
|
||||
delete(r.neighborIPToMAC, address)
|
||||
} else {
|
||||
r.neighborIPToMAC[address] = mac
|
||||
}
|
||||
r.access.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *neighborResolver) doReloadLeaseFiles() {
|
||||
leaseIPToMAC, ipToHostname, macToHostname := ReloadLeaseFiles(r.leaseFiles)
|
||||
r.access.Lock()
|
||||
r.leaseIPToMAC = leaseIPToMAC
|
||||
r.ipToHostname = ipToHostname
|
||||
r.macToHostname = macToHostname
|
||||
r.access.Unlock()
|
||||
}
|
||||
|
||||
func setReadDeadline(file *os.File, timeout time.Duration) error {
|
||||
rawConn, err := file.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var controlErr error
|
||||
err = rawConn.Control(func(fd uintptr) {
|
||||
tv := unix.NsecToTimeval(int64(timeout))
|
||||
controlErr = unix.SetsockoptTimeval(int(fd), unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return controlErr
|
||||
}
|
||||
386
route/neighbor_resolver_lease.go
Normal file
386
route/neighbor_resolver_lease.go
Normal file
@@ -0,0 +1,386 @@
|
||||
package route
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/hex"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func parseLeaseFile(path string, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
if strings.HasSuffix(path, "dhcpd_leases") {
|
||||
parseBootpdLeases(file, ipToMAC, ipToHostname, macToHostname)
|
||||
return
|
||||
}
|
||||
if strings.HasSuffix(path, "kea-leases4.csv") {
|
||||
parseKeaCSV4(file, ipToMAC, ipToHostname, macToHostname)
|
||||
return
|
||||
}
|
||||
if strings.HasSuffix(path, "kea-leases6.csv") {
|
||||
parseKeaCSV6(file, ipToMAC, ipToHostname, macToHostname)
|
||||
return
|
||||
}
|
||||
if strings.HasSuffix(path, "dhcpd.leases") {
|
||||
parseISCDhcpd(file, ipToMAC, ipToHostname, macToHostname)
|
||||
return
|
||||
}
|
||||
parseDnsmasqOdhcpd(file, ipToMAC, ipToHostname, macToHostname)
|
||||
}
|
||||
|
||||
func ReloadLeaseFiles(leaseFiles []string) (leaseIPToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) {
|
||||
leaseIPToMAC = make(map[netip.Addr]net.HardwareAddr)
|
||||
ipToHostname = make(map[netip.Addr]string)
|
||||
macToHostname = make(map[string]string)
|
||||
for _, path := range leaseFiles {
|
||||
parseLeaseFile(path, leaseIPToMAC, ipToHostname, macToHostname)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func parseDnsmasqOdhcpd(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) {
|
||||
now := time.Now().Unix()
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.HasPrefix(line, "duid ") {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "# ") {
|
||||
parseOdhcpdLine(line[2:], ipToMAC, ipToHostname, macToHostname)
|
||||
continue
|
||||
}
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 4 {
|
||||
continue
|
||||
}
|
||||
expiry, err := strconv.ParseInt(fields[0], 10, 64)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if expiry != 0 && expiry < now {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(fields[1], ":") {
|
||||
mac, macErr := net.ParseMAC(fields[1])
|
||||
if macErr != nil {
|
||||
continue
|
||||
}
|
||||
address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[2]))
|
||||
if !addrOK {
|
||||
continue
|
||||
}
|
||||
address = address.Unmap()
|
||||
ipToMAC[address] = mac
|
||||
hostname := fields[3]
|
||||
if hostname != "*" {
|
||||
ipToHostname[address] = hostname
|
||||
macToHostname[mac.String()] = hostname
|
||||
}
|
||||
} else {
|
||||
var mac net.HardwareAddr
|
||||
if len(fields) >= 5 {
|
||||
duid, duidErr := parseDUID(fields[4])
|
||||
if duidErr == nil {
|
||||
mac, _ = extractMACFromDUID(duid)
|
||||
}
|
||||
}
|
||||
address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[2]))
|
||||
if !addrOK {
|
||||
continue
|
||||
}
|
||||
address = address.Unmap()
|
||||
if mac != nil {
|
||||
ipToMAC[address] = mac
|
||||
}
|
||||
hostname := fields[3]
|
||||
if hostname != "*" {
|
||||
ipToHostname[address] = hostname
|
||||
if mac != nil {
|
||||
macToHostname[mac.String()] = hostname
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseOdhcpdLine(line string, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) {
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 5 {
|
||||
return
|
||||
}
|
||||
validTime, err := strconv.ParseInt(fields[4], 10, 64)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if validTime == 0 {
|
||||
return
|
||||
}
|
||||
if validTime > 0 && validTime < time.Now().Unix() {
|
||||
return
|
||||
}
|
||||
hostname := fields[3]
|
||||
if hostname == "-" || strings.HasPrefix(hostname, `broken\x20`) {
|
||||
hostname = ""
|
||||
}
|
||||
if len(fields) >= 8 && fields[2] == "ipv4" {
|
||||
mac, macErr := net.ParseMAC(fields[1])
|
||||
if macErr != nil {
|
||||
return
|
||||
}
|
||||
addressField := fields[7]
|
||||
slashIndex := strings.IndexByte(addressField, '/')
|
||||
if slashIndex >= 0 {
|
||||
addressField = addressField[:slashIndex]
|
||||
}
|
||||
address, addrOK := netip.AddrFromSlice(net.ParseIP(addressField))
|
||||
if !addrOK {
|
||||
return
|
||||
}
|
||||
address = address.Unmap()
|
||||
ipToMAC[address] = mac
|
||||
if hostname != "" {
|
||||
ipToHostname[address] = hostname
|
||||
macToHostname[mac.String()] = hostname
|
||||
}
|
||||
return
|
||||
}
|
||||
var mac net.HardwareAddr
|
||||
duidHex := fields[1]
|
||||
duidBytes, hexErr := hex.DecodeString(duidHex)
|
||||
if hexErr == nil {
|
||||
mac, _ = extractMACFromDUID(duidBytes)
|
||||
}
|
||||
for i := 7; i < len(fields); i++ {
|
||||
addressField := fields[i]
|
||||
slashIndex := strings.IndexByte(addressField, '/')
|
||||
if slashIndex >= 0 {
|
||||
addressField = addressField[:slashIndex]
|
||||
}
|
||||
address, addrOK := netip.AddrFromSlice(net.ParseIP(addressField))
|
||||
if !addrOK {
|
||||
continue
|
||||
}
|
||||
address = address.Unmap()
|
||||
if mac != nil {
|
||||
ipToMAC[address] = mac
|
||||
}
|
||||
if hostname != "" {
|
||||
ipToHostname[address] = hostname
|
||||
if mac != nil {
|
||||
macToHostname[mac.String()] = hostname
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseISCDhcpd(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) {
|
||||
scanner := bufio.NewScanner(file)
|
||||
var currentIP netip.Addr
|
||||
var currentMAC net.HardwareAddr
|
||||
var currentHostname string
|
||||
var currentActive bool
|
||||
var inLease bool
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if strings.HasPrefix(line, "lease ") && strings.HasSuffix(line, "{") {
|
||||
ipString := strings.TrimSuffix(strings.TrimPrefix(line, "lease "), " {")
|
||||
parsed, addrOK := netip.AddrFromSlice(net.ParseIP(ipString))
|
||||
if addrOK {
|
||||
currentIP = parsed.Unmap()
|
||||
inLease = true
|
||||
currentMAC = nil
|
||||
currentHostname = ""
|
||||
currentActive = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
if line == "}" && inLease {
|
||||
if currentActive && currentMAC != nil {
|
||||
ipToMAC[currentIP] = currentMAC
|
||||
if currentHostname != "" {
|
||||
ipToHostname[currentIP] = currentHostname
|
||||
macToHostname[currentMAC.String()] = currentHostname
|
||||
}
|
||||
} else {
|
||||
delete(ipToMAC, currentIP)
|
||||
delete(ipToHostname, currentIP)
|
||||
}
|
||||
inLease = false
|
||||
continue
|
||||
}
|
||||
if !inLease {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "hardware ethernet ") {
|
||||
macString := strings.TrimSuffix(strings.TrimPrefix(line, "hardware ethernet "), ";")
|
||||
parsed, macErr := net.ParseMAC(macString)
|
||||
if macErr == nil {
|
||||
currentMAC = parsed
|
||||
}
|
||||
} else if strings.HasPrefix(line, "client-hostname ") {
|
||||
hostname := strings.TrimSuffix(strings.TrimPrefix(line, "client-hostname "), ";")
|
||||
hostname = strings.Trim(hostname, "\"")
|
||||
if hostname != "" {
|
||||
currentHostname = hostname
|
||||
}
|
||||
} else if strings.HasPrefix(line, "binding state ") {
|
||||
state := strings.TrimSuffix(strings.TrimPrefix(line, "binding state "), ";")
|
||||
currentActive = state == "active"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseKeaCSV4(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) {
|
||||
scanner := bufio.NewScanner(file)
|
||||
firstLine := true
|
||||
for scanner.Scan() {
|
||||
if firstLine {
|
||||
firstLine = false
|
||||
continue
|
||||
}
|
||||
fields := strings.Split(scanner.Text(), ",")
|
||||
if len(fields) < 10 {
|
||||
continue
|
||||
}
|
||||
if fields[9] != "0" {
|
||||
continue
|
||||
}
|
||||
address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[0]))
|
||||
if !addrOK {
|
||||
continue
|
||||
}
|
||||
address = address.Unmap()
|
||||
mac, macErr := net.ParseMAC(fields[1])
|
||||
if macErr != nil {
|
||||
continue
|
||||
}
|
||||
ipToMAC[address] = mac
|
||||
hostname := ""
|
||||
if len(fields) > 8 {
|
||||
hostname = fields[8]
|
||||
}
|
||||
if hostname != "" {
|
||||
ipToHostname[address] = hostname
|
||||
macToHostname[mac.String()] = hostname
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseKeaCSV6(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) {
|
||||
scanner := bufio.NewScanner(file)
|
||||
firstLine := true
|
||||
for scanner.Scan() {
|
||||
if firstLine {
|
||||
firstLine = false
|
||||
continue
|
||||
}
|
||||
fields := strings.Split(scanner.Text(), ",")
|
||||
if len(fields) < 14 {
|
||||
continue
|
||||
}
|
||||
if fields[13] != "0" {
|
||||
continue
|
||||
}
|
||||
address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[0]))
|
||||
if !addrOK {
|
||||
continue
|
||||
}
|
||||
address = address.Unmap()
|
||||
var mac net.HardwareAddr
|
||||
if fields[12] != "" {
|
||||
mac, _ = net.ParseMAC(fields[12])
|
||||
}
|
||||
if mac == nil {
|
||||
duid, duidErr := hex.DecodeString(strings.ReplaceAll(fields[1], ":", ""))
|
||||
if duidErr == nil {
|
||||
mac, _ = extractMACFromDUID(duid)
|
||||
}
|
||||
}
|
||||
hostname := ""
|
||||
if len(fields) > 11 {
|
||||
hostname = fields[11]
|
||||
}
|
||||
if mac != nil {
|
||||
ipToMAC[address] = mac
|
||||
}
|
||||
if hostname != "" {
|
||||
ipToHostname[address] = hostname
|
||||
if mac != nil {
|
||||
macToHostname[mac.String()] = hostname
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseBootpdLeases(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) {
|
||||
now := time.Now().Unix()
|
||||
scanner := bufio.NewScanner(file)
|
||||
var currentName string
|
||||
var currentIP netip.Addr
|
||||
var currentMAC net.HardwareAddr
|
||||
var currentLease int64
|
||||
var inBlock bool
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "{" {
|
||||
inBlock = true
|
||||
currentName = ""
|
||||
currentIP = netip.Addr{}
|
||||
currentMAC = nil
|
||||
currentLease = 0
|
||||
continue
|
||||
}
|
||||
if line == "}" && inBlock {
|
||||
if currentMAC != nil && currentIP.IsValid() {
|
||||
if currentLease == 0 || currentLease >= now {
|
||||
ipToMAC[currentIP] = currentMAC
|
||||
if currentName != "" {
|
||||
ipToHostname[currentIP] = currentName
|
||||
macToHostname[currentMAC.String()] = currentName
|
||||
}
|
||||
}
|
||||
}
|
||||
inBlock = false
|
||||
continue
|
||||
}
|
||||
if !inBlock {
|
||||
continue
|
||||
}
|
||||
key, value, found := strings.Cut(line, "=")
|
||||
if !found {
|
||||
continue
|
||||
}
|
||||
switch key {
|
||||
case "name":
|
||||
currentName = value
|
||||
case "ip_address":
|
||||
parsed, addrOK := netip.AddrFromSlice(net.ParseIP(value))
|
||||
if addrOK {
|
||||
currentIP = parsed.Unmap()
|
||||
}
|
||||
case "hw_address":
|
||||
typeAndMAC, hasSep := strings.CutPrefix(value, "1,")
|
||||
if hasSep {
|
||||
mac, macErr := net.ParseMAC(typeAndMAC)
|
||||
if macErr == nil {
|
||||
currentMAC = mac
|
||||
}
|
||||
}
|
||||
case "lease":
|
||||
leaseHex := strings.TrimPrefix(value, "0x")
|
||||
parsed, parseErr := strconv.ParseInt(leaseHex, 16, 64)
|
||||
if parseErr == nil {
|
||||
currentLease = parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
224
route/neighbor_resolver_linux.go
Normal file
224
route/neighbor_resolver_linux.go
Normal file
@@ -0,0 +1,224 @@
|
||||
//go:build linux
|
||||
|
||||
package route
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/fswatch"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
|
||||
"github.com/jsimonetti/rtnetlink"
|
||||
"github.com/mdlayher/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var defaultLeaseFiles = []string{
|
||||
"/tmp/dhcp.leases",
|
||||
"/var/lib/dhcp/dhcpd.leases",
|
||||
"/var/lib/dhcpd/dhcpd.leases",
|
||||
"/var/lib/kea/kea-leases4.csv",
|
||||
"/var/lib/kea/kea-leases6.csv",
|
||||
}
|
||||
|
||||
type neighborResolver struct {
|
||||
logger logger.ContextLogger
|
||||
leaseFiles []string
|
||||
access sync.RWMutex
|
||||
neighborIPToMAC map[netip.Addr]net.HardwareAddr
|
||||
leaseIPToMAC map[netip.Addr]net.HardwareAddr
|
||||
ipToHostname map[netip.Addr]string
|
||||
macToHostname map[string]string
|
||||
watcher *fswatch.Watcher
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func newNeighborResolver(resolverLogger logger.ContextLogger, leaseFiles []string) (adapter.NeighborResolver, error) {
|
||||
if len(leaseFiles) == 0 {
|
||||
for _, path := range defaultLeaseFiles {
|
||||
info, err := os.Stat(path)
|
||||
if err == nil && info.Size() > 0 {
|
||||
leaseFiles = append(leaseFiles, path)
|
||||
}
|
||||
}
|
||||
}
|
||||
return &neighborResolver{
|
||||
logger: resolverLogger,
|
||||
leaseFiles: leaseFiles,
|
||||
neighborIPToMAC: make(map[netip.Addr]net.HardwareAddr),
|
||||
leaseIPToMAC: make(map[netip.Addr]net.HardwareAddr),
|
||||
ipToHostname: make(map[netip.Addr]string),
|
||||
macToHostname: make(map[string]string),
|
||||
done: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *neighborResolver) Start() error {
|
||||
err := r.loadNeighborTable()
|
||||
if err != nil {
|
||||
r.logger.Warn(E.Cause(err, "load neighbor table"))
|
||||
}
|
||||
r.doReloadLeaseFiles()
|
||||
go r.subscribeNeighborUpdates()
|
||||
if len(r.leaseFiles) > 0 {
|
||||
watcher, err := fswatch.NewWatcher(fswatch.Options{
|
||||
Path: r.leaseFiles,
|
||||
Logger: r.logger,
|
||||
Callback: func(_ string) {
|
||||
r.doReloadLeaseFiles()
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
r.logger.Warn(E.Cause(err, "create lease file watcher"))
|
||||
} else {
|
||||
r.watcher = watcher
|
||||
err = watcher.Start()
|
||||
if err != nil {
|
||||
r.logger.Warn(E.Cause(err, "start lease file watcher"))
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *neighborResolver) Close() error {
|
||||
close(r.done)
|
||||
if r.watcher != nil {
|
||||
return r.watcher.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *neighborResolver) LookupMAC(address netip.Addr) (net.HardwareAddr, bool) {
|
||||
r.access.RLock()
|
||||
defer r.access.RUnlock()
|
||||
mac, found := r.neighborIPToMAC[address]
|
||||
if found {
|
||||
return mac, true
|
||||
}
|
||||
mac, found = r.leaseIPToMAC[address]
|
||||
if found {
|
||||
return mac, true
|
||||
}
|
||||
mac, found = extractMACFromEUI64(address)
|
||||
if found {
|
||||
return mac, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (r *neighborResolver) LookupHostname(address netip.Addr) (string, bool) {
|
||||
r.access.RLock()
|
||||
defer r.access.RUnlock()
|
||||
hostname, found := r.ipToHostname[address]
|
||||
if found {
|
||||
return hostname, true
|
||||
}
|
||||
mac, macFound := r.neighborIPToMAC[address]
|
||||
if !macFound {
|
||||
mac, macFound = r.leaseIPToMAC[address]
|
||||
}
|
||||
if !macFound {
|
||||
mac, macFound = extractMACFromEUI64(address)
|
||||
}
|
||||
if macFound {
|
||||
hostname, found = r.macToHostname[mac.String()]
|
||||
if found {
|
||||
return hostname, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func (r *neighborResolver) loadNeighborTable() error {
|
||||
connection, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return E.Cause(err, "dial rtnetlink")
|
||||
}
|
||||
defer connection.Close()
|
||||
neighbors, err := connection.Neigh.List()
|
||||
if err != nil {
|
||||
return E.Cause(err, "list neighbors")
|
||||
}
|
||||
r.access.Lock()
|
||||
defer r.access.Unlock()
|
||||
for _, neigh := range neighbors {
|
||||
if neigh.Attributes == nil {
|
||||
continue
|
||||
}
|
||||
if neigh.Attributes.LLAddress == nil || len(neigh.Attributes.Address) == 0 {
|
||||
continue
|
||||
}
|
||||
address, ok := netip.AddrFromSlice(neigh.Attributes.Address)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
r.neighborIPToMAC[address] = slices.Clone(neigh.Attributes.LLAddress)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *neighborResolver) subscribeNeighborUpdates() {
|
||||
connection, err := netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{
|
||||
Groups: 1 << (unix.RTNLGRP_NEIGH - 1),
|
||||
})
|
||||
if err != nil {
|
||||
r.logger.Warn(E.Cause(err, "subscribe neighbor updates"))
|
||||
return
|
||||
}
|
||||
defer connection.Close()
|
||||
for {
|
||||
select {
|
||||
case <-r.done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
err = connection.SetReadDeadline(time.Now().Add(3 * time.Second))
|
||||
if err != nil {
|
||||
r.logger.Warn(E.Cause(err, "set netlink read deadline"))
|
||||
return
|
||||
}
|
||||
messages, err := connection.Receive()
|
||||
if err != nil {
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case <-r.done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
r.logger.Warn(E.Cause(err, "receive neighbor update"))
|
||||
continue
|
||||
}
|
||||
for _, message := range messages {
|
||||
address, mac, isDelete, ok := ParseNeighborMessage(message)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
r.access.Lock()
|
||||
if isDelete {
|
||||
delete(r.neighborIPToMAC, address)
|
||||
} else {
|
||||
r.neighborIPToMAC[address] = mac
|
||||
}
|
||||
r.access.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *neighborResolver) doReloadLeaseFiles() {
|
||||
leaseIPToMAC, ipToHostname, macToHostname := ReloadLeaseFiles(r.leaseFiles)
|
||||
r.access.Lock()
|
||||
r.leaseIPToMAC = leaseIPToMAC
|
||||
r.ipToHostname = ipToHostname
|
||||
r.macToHostname = macToHostname
|
||||
r.access.Unlock()
|
||||
}
|
||||
50
route/neighbor_resolver_parse.go
Normal file
50
route/neighbor_resolver_parse.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package route
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func extractMACFromDUID(duid []byte) (net.HardwareAddr, bool) {
|
||||
if len(duid) < 4 {
|
||||
return nil, false
|
||||
}
|
||||
duidType := binary.BigEndian.Uint16(duid[0:2])
|
||||
hwType := binary.BigEndian.Uint16(duid[2:4])
|
||||
if hwType != 1 {
|
||||
return nil, false
|
||||
}
|
||||
switch duidType {
|
||||
case 1:
|
||||
if len(duid) < 14 {
|
||||
return nil, false
|
||||
}
|
||||
return net.HardwareAddr(slices.Clone(duid[8:14])), true
|
||||
case 3:
|
||||
if len(duid) < 10 {
|
||||
return nil, false
|
||||
}
|
||||
return net.HardwareAddr(slices.Clone(duid[4:10])), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func extractMACFromEUI64(address netip.Addr) (net.HardwareAddr, bool) {
|
||||
if !address.Is6() {
|
||||
return nil, false
|
||||
}
|
||||
b := address.As16()
|
||||
if b[11] != 0xff || b[12] != 0xfe {
|
||||
return nil, false
|
||||
}
|
||||
return net.HardwareAddr{b[8] ^ 0x02, b[9], b[10], b[13], b[14], b[15]}, true
|
||||
}
|
||||
|
||||
func parseDUID(s string) ([]byte, error) {
|
||||
cleaned := strings.ReplaceAll(s, ":", "")
|
||||
return hex.DecodeString(cleaned)
|
||||
}
|
||||
84
route/neighbor_resolver_platform.go
Normal file
84
route/neighbor_resolver_platform.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package route
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
)
|
||||
|
||||
type platformNeighborResolver struct {
|
||||
logger logger.ContextLogger
|
||||
platform adapter.PlatformInterface
|
||||
access sync.RWMutex
|
||||
ipToMAC map[netip.Addr]net.HardwareAddr
|
||||
ipToHostname map[netip.Addr]string
|
||||
macToHostname map[string]string
|
||||
}
|
||||
|
||||
func newPlatformNeighborResolver(resolverLogger logger.ContextLogger, platform adapter.PlatformInterface) adapter.NeighborResolver {
|
||||
return &platformNeighborResolver{
|
||||
logger: resolverLogger,
|
||||
platform: platform,
|
||||
ipToMAC: make(map[netip.Addr]net.HardwareAddr),
|
||||
ipToHostname: make(map[netip.Addr]string),
|
||||
macToHostname: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *platformNeighborResolver) Start() error {
|
||||
return r.platform.StartNeighborMonitor(r)
|
||||
}
|
||||
|
||||
func (r *platformNeighborResolver) Close() error {
|
||||
return r.platform.CloseNeighborMonitor(r)
|
||||
}
|
||||
|
||||
func (r *platformNeighborResolver) LookupMAC(address netip.Addr) (net.HardwareAddr, bool) {
|
||||
r.access.RLock()
|
||||
defer r.access.RUnlock()
|
||||
mac, found := r.ipToMAC[address]
|
||||
if found {
|
||||
return mac, true
|
||||
}
|
||||
return extractMACFromEUI64(address)
|
||||
}
|
||||
|
||||
func (r *platformNeighborResolver) LookupHostname(address netip.Addr) (string, bool) {
|
||||
r.access.RLock()
|
||||
defer r.access.RUnlock()
|
||||
hostname, found := r.ipToHostname[address]
|
||||
if found {
|
||||
return hostname, true
|
||||
}
|
||||
mac, found := r.ipToMAC[address]
|
||||
if !found {
|
||||
mac, found = extractMACFromEUI64(address)
|
||||
}
|
||||
if !found {
|
||||
return "", false
|
||||
}
|
||||
hostname, found = r.macToHostname[mac.String()]
|
||||
return hostname, found
|
||||
}
|
||||
|
||||
func (r *platformNeighborResolver) UpdateNeighborTable(entries []adapter.NeighborEntry) {
|
||||
ipToMAC := make(map[netip.Addr]net.HardwareAddr)
|
||||
ipToHostname := make(map[netip.Addr]string)
|
||||
macToHostname := make(map[string]string)
|
||||
for _, entry := range entries {
|
||||
ipToMAC[entry.Address] = entry.MACAddress
|
||||
if entry.Hostname != "" {
|
||||
ipToHostname[entry.Address] = entry.Hostname
|
||||
macToHostname[entry.MACAddress.String()] = entry.Hostname
|
||||
}
|
||||
}
|
||||
r.access.Lock()
|
||||
r.ipToMAC = ipToMAC
|
||||
r.ipToHostname = ipToHostname
|
||||
r.macToHostname = macToHostname
|
||||
r.access.Unlock()
|
||||
r.logger.Info("updated neighbor table: ", len(entries), " entries")
|
||||
}
|
||||
14
route/neighbor_resolver_stub.go
Normal file
14
route/neighbor_resolver_stub.go
Normal file
@@ -0,0 +1,14 @@
|
||||
//go:build !linux && !darwin
|
||||
|
||||
package route
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
)
|
||||
|
||||
func newNeighborResolver(_ logger.ContextLogger, _ []string) (adapter.NeighborResolver, error) {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
104
route/neighbor_table_darwin.go
Normal file
104
route/neighbor_table_darwin.go
Normal file
@@ -0,0 +1,104 @@
|
||||
//go:build darwin
|
||||
|
||||
package route
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
|
||||
"golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func ReadNeighborEntries() ([]adapter.NeighborEntry, error) {
|
||||
var entries []adapter.NeighborEntry
|
||||
ipv4Entries, err := readNeighborEntriesAF(syscall.AF_INET)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "read IPv4 neighbors")
|
||||
}
|
||||
entries = append(entries, ipv4Entries...)
|
||||
ipv6Entries, err := readNeighborEntriesAF(syscall.AF_INET6)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "read IPv6 neighbors")
|
||||
}
|
||||
entries = append(entries, ipv6Entries...)
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func readNeighborEntriesAF(addressFamily int) ([]adapter.NeighborEntry, error) {
|
||||
rib, err := route.FetchRIB(addressFamily, route.RIBType(syscall.NET_RT_FLAGS), syscall.RTF_LLINFO)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
messages, err := route.ParseRIB(route.RIBType(syscall.NET_RT_FLAGS), rib)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var entries []adapter.NeighborEntry
|
||||
for _, message := range messages {
|
||||
routeMessage, isRouteMessage := message.(*route.RouteMessage)
|
||||
if !isRouteMessage {
|
||||
continue
|
||||
}
|
||||
address, macAddress, ok := parseRouteNeighborEntry(routeMessage)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
entries = append(entries, adapter.NeighborEntry{
|
||||
Address: address,
|
||||
MACAddress: macAddress,
|
||||
})
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func parseRouteNeighborEntry(message *route.RouteMessage) (address netip.Addr, macAddress net.HardwareAddr, ok bool) {
|
||||
if len(message.Addrs) <= unix.RTAX_GATEWAY {
|
||||
return
|
||||
}
|
||||
gateway, isLinkAddr := message.Addrs[unix.RTAX_GATEWAY].(*route.LinkAddr)
|
||||
if !isLinkAddr || len(gateway.Addr) < 6 {
|
||||
return
|
||||
}
|
||||
switch destination := message.Addrs[unix.RTAX_DST].(type) {
|
||||
case *route.Inet4Addr:
|
||||
address = netip.AddrFrom4(destination.IP)
|
||||
case *route.Inet6Addr:
|
||||
address = netip.AddrFrom16(destination.IP)
|
||||
default:
|
||||
return
|
||||
}
|
||||
macAddress = net.HardwareAddr(make([]byte, len(gateway.Addr)))
|
||||
copy(macAddress, gateway.Addr)
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
|
||||
func ParseRouteNeighborMessage(message *route.RouteMessage) (address netip.Addr, macAddress net.HardwareAddr, isDelete bool, ok bool) {
|
||||
isDelete = message.Type == unix.RTM_DELETE
|
||||
if len(message.Addrs) <= unix.RTAX_GATEWAY {
|
||||
return
|
||||
}
|
||||
switch destination := message.Addrs[unix.RTAX_DST].(type) {
|
||||
case *route.Inet4Addr:
|
||||
address = netip.AddrFrom4(destination.IP)
|
||||
case *route.Inet6Addr:
|
||||
address = netip.AddrFrom16(destination.IP)
|
||||
default:
|
||||
return
|
||||
}
|
||||
if !isDelete {
|
||||
gateway, isLinkAddr := message.Addrs[unix.RTAX_GATEWAY].(*route.LinkAddr)
|
||||
if !isLinkAddr || len(gateway.Addr) < 6 {
|
||||
return
|
||||
}
|
||||
macAddress = net.HardwareAddr(make([]byte, len(gateway.Addr)))
|
||||
copy(macAddress, gateway.Addr)
|
||||
}
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
68
route/neighbor_table_linux.go
Normal file
68
route/neighbor_table_linux.go
Normal file
@@ -0,0 +1,68 @@
|
||||
//go:build linux
|
||||
|
||||
package route
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
|
||||
"github.com/jsimonetti/rtnetlink"
|
||||
"github.com/mdlayher/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func ReadNeighborEntries() ([]adapter.NeighborEntry, error) {
|
||||
connection, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "dial rtnetlink")
|
||||
}
|
||||
defer connection.Close()
|
||||
neighbors, err := connection.Neigh.List()
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "list neighbors")
|
||||
}
|
||||
var entries []adapter.NeighborEntry
|
||||
for _, neighbor := range neighbors {
|
||||
if neighbor.Attributes == nil {
|
||||
continue
|
||||
}
|
||||
if neighbor.Attributes.LLAddress == nil || len(neighbor.Attributes.Address) == 0 {
|
||||
continue
|
||||
}
|
||||
address, ok := netip.AddrFromSlice(neighbor.Attributes.Address)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
entries = append(entries, adapter.NeighborEntry{
|
||||
Address: address,
|
||||
MACAddress: slices.Clone(neighbor.Attributes.LLAddress),
|
||||
})
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func ParseNeighborMessage(message netlink.Message) (address netip.Addr, macAddress net.HardwareAddr, isDelete bool, ok bool) {
|
||||
var neighMessage rtnetlink.NeighMessage
|
||||
err := neighMessage.UnmarshalBinary(message.Data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if neighMessage.Attributes == nil || len(neighMessage.Attributes.Address) == 0 {
|
||||
return
|
||||
}
|
||||
address, ok = netip.AddrFromSlice(neighMessage.Attributes.Address)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
isDelete = message.Header.Type == unix.RTM_DELNEIGH
|
||||
if !isDelete && neighMessage.Attributes.LLAddress == nil {
|
||||
ok = false
|
||||
return
|
||||
}
|
||||
macAddress = slices.Clone(neighMessage.Attributes.LLAddress)
|
||||
return
|
||||
}
|
||||
@@ -51,6 +51,7 @@ type NetworkManager struct {
|
||||
endpoint adapter.EndpointManager
|
||||
inbound adapter.InboundManager
|
||||
outbound adapter.OutboundManager
|
||||
serviceManager adapter.ServiceManager
|
||||
needWIFIState bool
|
||||
wifiMonitor settings.WIFIMonitor
|
||||
wifiState adapter.WIFIState
|
||||
@@ -94,6 +95,7 @@ func NewNetworkManager(ctx context.Context, logger logger.ContextLogger, options
|
||||
endpoint: service.FromContext[adapter.EndpointManager](ctx),
|
||||
inbound: service.FromContext[adapter.InboundManager](ctx),
|
||||
outbound: service.FromContext[adapter.OutboundManager](ctx),
|
||||
serviceManager: service.FromContext[adapter.ServiceManager](ctx),
|
||||
needWIFIState: hasRule(options.Rules, isWIFIRule) || hasDNSRule(dnsOptions.Rules, isWIFIDNSRule),
|
||||
}
|
||||
if options.DefaultNetworkStrategy != nil {
|
||||
@@ -475,6 +477,15 @@ func (r *NetworkManager) ResetNetwork() {
|
||||
listener.InterfaceUpdated()
|
||||
}
|
||||
}
|
||||
|
||||
if r.serviceManager != nil {
|
||||
for _, svc := range r.serviceManager.Services() {
|
||||
listener, isListener := svc.(adapter.InterfaceUpdateListener)
|
||||
if isListener {
|
||||
listener.InterfaceUpdated()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *NetworkManager) notifyInterfaceUpdate(defaultInterface *control.Interface, flags int) {
|
||||
|
||||
@@ -439,6 +439,23 @@ func (r *Router) matchRule(
|
||||
metadata.ProcessInfo = processInfo
|
||||
}
|
||||
}
|
||||
if r.neighborResolver != nil && metadata.SourceMACAddress == nil && metadata.Source.Addr.IsValid() {
|
||||
mac, macFound := r.neighborResolver.LookupMAC(metadata.Source.Addr)
|
||||
if macFound {
|
||||
metadata.SourceMACAddress = mac
|
||||
}
|
||||
hostname, hostnameFound := r.neighborResolver.LookupHostname(metadata.Source.Addr)
|
||||
if hostnameFound {
|
||||
metadata.SourceHostname = hostname
|
||||
if macFound {
|
||||
r.logger.InfoContext(ctx, "found neighbor: ", mac, ", hostname: ", hostname)
|
||||
} else {
|
||||
r.logger.InfoContext(ctx, "found neighbor hostname: ", hostname)
|
||||
}
|
||||
} else if macFound {
|
||||
r.logger.InfoContext(ctx, "found neighbor: ", mac)
|
||||
}
|
||||
}
|
||||
if metadata.Destination.Addr.IsValid() && r.dnsTransport.FakeIP() != nil && r.dnsTransport.FakeIP().Store().Contains(metadata.Destination.Addr) {
|
||||
domain, loaded := r.dnsTransport.FakeIP().Store().Lookup(metadata.Destination.Addr)
|
||||
if !loaded {
|
||||
|
||||
@@ -31,9 +31,12 @@ type Router struct {
|
||||
network adapter.NetworkManager
|
||||
rules []adapter.Rule
|
||||
needFindProcess bool
|
||||
needFindNeighbor bool
|
||||
leaseFiles []string
|
||||
ruleSets []adapter.RuleSet
|
||||
ruleSetMap map[string]adapter.RuleSet
|
||||
processSearcher process.Searcher
|
||||
neighborResolver adapter.NeighborResolver
|
||||
pauseManager pause.Manager
|
||||
trackers []adapter.ConnectionTracker
|
||||
platformInterface adapter.PlatformInterface
|
||||
@@ -53,6 +56,8 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.Route
|
||||
rules: make([]adapter.Rule, 0, len(options.Rules)),
|
||||
ruleSetMap: make(map[string]adapter.RuleSet),
|
||||
needFindProcess: hasRule(options.Rules, isProcessRule) || hasDNSRule(dnsOptions.Rules, isProcessDNSRule) || options.FindProcess,
|
||||
needFindNeighbor: hasRule(options.Rules, isNeighborRule) || hasDNSRule(dnsOptions.Rules, isNeighborDNSRule) || options.FindNeighbor,
|
||||
leaseFiles: options.DHCPLeaseFiles,
|
||||
pauseManager: service.FromContext[pause.Manager](ctx),
|
||||
platformInterface: service.FromContext[adapter.PlatformInterface](ctx),
|
||||
}
|
||||
@@ -112,6 +117,7 @@ func (r *Router) Start(stage adapter.StartStage) error {
|
||||
}
|
||||
r.network.Initialize(r.ruleSets)
|
||||
needFindProcess := r.needFindProcess
|
||||
needFindNeighbor := r.needFindNeighbor
|
||||
for _, ruleSet := range r.ruleSets {
|
||||
metadata := ruleSet.Metadata()
|
||||
if metadata.ContainsProcessRule {
|
||||
@@ -141,6 +147,36 @@ func (r *Router) Start(stage adapter.StartStage) error {
|
||||
}
|
||||
}
|
||||
}
|
||||
r.needFindNeighbor = needFindNeighbor
|
||||
if needFindNeighbor {
|
||||
if r.platformInterface != nil && r.platformInterface.UsePlatformNeighborResolver() {
|
||||
monitor.Start("initialize neighbor resolver")
|
||||
resolver := newPlatformNeighborResolver(r.logger, r.platformInterface)
|
||||
err := resolver.Start()
|
||||
monitor.Finish()
|
||||
if err != nil {
|
||||
r.logger.Error(E.Cause(err, "start neighbor resolver"))
|
||||
} else {
|
||||
r.neighborResolver = resolver
|
||||
}
|
||||
} else {
|
||||
monitor.Start("initialize neighbor resolver")
|
||||
resolver, err := newNeighborResolver(r.logger, r.leaseFiles)
|
||||
monitor.Finish()
|
||||
if err != nil {
|
||||
if err != os.ErrInvalid {
|
||||
r.logger.Error(E.Cause(err, "create neighbor resolver"))
|
||||
}
|
||||
} else {
|
||||
err = resolver.Start()
|
||||
if err != nil {
|
||||
r.logger.Error(E.Cause(err, "start neighbor resolver"))
|
||||
} else {
|
||||
r.neighborResolver = resolver
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case adapter.StartStatePostStart:
|
||||
for i, rule := range r.rules {
|
||||
monitor.Start("initialize rule[", i, "]")
|
||||
@@ -172,6 +208,13 @@ func (r *Router) Start(stage adapter.StartStage) error {
|
||||
func (r *Router) Close() error {
|
||||
monitor := taskmonitor.New(r.logger, C.StopTimeout)
|
||||
var err error
|
||||
if r.neighborResolver != nil {
|
||||
monitor.Start("close neighbor resolver")
|
||||
err = E.Append(err, r.neighborResolver.Close(), func(closeErr error) error {
|
||||
return E.Cause(closeErr, "close neighbor resolver")
|
||||
})
|
||||
monitor.Finish()
|
||||
}
|
||||
for i, rule := range r.rules {
|
||||
monitor.Start("close rule[", i, "]")
|
||||
err = E.Append(err, rule.Close(), func(err error) error {
|
||||
@@ -206,6 +249,14 @@ func (r *Router) NeedFindProcess() bool {
|
||||
return r.needFindProcess
|
||||
}
|
||||
|
||||
func (r *Router) NeedFindNeighbor() bool {
|
||||
return r.needFindNeighbor
|
||||
}
|
||||
|
||||
func (r *Router) NeighborResolver() adapter.NeighborResolver {
|
||||
return r.neighborResolver
|
||||
}
|
||||
|
||||
func (r *Router) ResetNetwork() {
|
||||
r.network.ResetNetwork()
|
||||
r.dns.ResetNetwork()
|
||||
|
||||
@@ -260,6 +260,16 @@ func NewDefaultRule(ctx context.Context, logger log.ContextLogger, options optio
|
||||
rule.items = append(rule.items, item)
|
||||
rule.allItems = append(rule.allItems, item)
|
||||
}
|
||||
if len(options.SourceMACAddress) > 0 {
|
||||
item := NewSourceMACAddressItem(options.SourceMACAddress)
|
||||
rule.items = append(rule.items, item)
|
||||
rule.allItems = append(rule.allItems, item)
|
||||
}
|
||||
if len(options.SourceHostname) > 0 {
|
||||
item := NewSourceHostnameItem(options.SourceHostname)
|
||||
rule.items = append(rule.items, item)
|
||||
rule.allItems = append(rule.allItems, item)
|
||||
}
|
||||
if len(options.PreferredBy) > 0 {
|
||||
item := NewPreferredByItem(ctx, options.PreferredBy)
|
||||
rule.items = append(rule.items, item)
|
||||
|
||||
@@ -261,6 +261,16 @@ func NewDefaultDNSRule(ctx context.Context, logger log.ContextLogger, options op
|
||||
rule.items = append(rule.items, item)
|
||||
rule.allItems = append(rule.allItems, item)
|
||||
}
|
||||
if len(options.SourceMACAddress) > 0 {
|
||||
item := NewSourceMACAddressItem(options.SourceMACAddress)
|
||||
rule.items = append(rule.items, item)
|
||||
rule.allItems = append(rule.allItems, item)
|
||||
}
|
||||
if len(options.SourceHostname) > 0 {
|
||||
item := NewSourceHostnameItem(options.SourceHostname)
|
||||
rule.items = append(rule.items, item)
|
||||
rule.allItems = append(rule.allItems, item)
|
||||
}
|
||||
if len(options.RuleSet) > 0 {
|
||||
//nolint:staticcheck
|
||||
if options.Deprecated_RulesetIPCIDRMatchSource {
|
||||
|
||||
42
route/rule/rule_item_source_hostname.go
Normal file
42
route/rule/rule_item_source_hostname.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package rule
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
)
|
||||
|
||||
var _ RuleItem = (*SourceHostnameItem)(nil)
|
||||
|
||||
type SourceHostnameItem struct {
|
||||
hostnames []string
|
||||
hostnameMap map[string]bool
|
||||
}
|
||||
|
||||
func NewSourceHostnameItem(hostnameList []string) *SourceHostnameItem {
|
||||
rule := &SourceHostnameItem{
|
||||
hostnames: hostnameList,
|
||||
hostnameMap: make(map[string]bool),
|
||||
}
|
||||
for _, hostname := range hostnameList {
|
||||
rule.hostnameMap[hostname] = true
|
||||
}
|
||||
return rule
|
||||
}
|
||||
|
||||
func (r *SourceHostnameItem) Match(metadata *adapter.InboundContext) bool {
|
||||
if metadata.SourceHostname == "" {
|
||||
return false
|
||||
}
|
||||
return r.hostnameMap[metadata.SourceHostname]
|
||||
}
|
||||
|
||||
func (r *SourceHostnameItem) String() string {
|
||||
var description string
|
||||
if len(r.hostnames) == 1 {
|
||||
description = "source_hostname=" + r.hostnames[0]
|
||||
} else {
|
||||
description = "source_hostname=[" + strings.Join(r.hostnames, " ") + "]"
|
||||
}
|
||||
return description
|
||||
}
|
||||
48
route/rule/rule_item_source_mac_address.go
Normal file
48
route/rule/rule_item_source_mac_address.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package rule
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
)
|
||||
|
||||
var _ RuleItem = (*SourceMACAddressItem)(nil)
|
||||
|
||||
type SourceMACAddressItem struct {
|
||||
addresses []string
|
||||
addressMap map[string]bool
|
||||
}
|
||||
|
||||
func NewSourceMACAddressItem(addressList []string) *SourceMACAddressItem {
|
||||
rule := &SourceMACAddressItem{
|
||||
addresses: addressList,
|
||||
addressMap: make(map[string]bool),
|
||||
}
|
||||
for _, address := range addressList {
|
||||
parsed, err := net.ParseMAC(address)
|
||||
if err == nil {
|
||||
rule.addressMap[parsed.String()] = true
|
||||
} else {
|
||||
rule.addressMap[address] = true
|
||||
}
|
||||
}
|
||||
return rule
|
||||
}
|
||||
|
||||
func (r *SourceMACAddressItem) Match(metadata *adapter.InboundContext) bool {
|
||||
if metadata.SourceMACAddress == nil {
|
||||
return false
|
||||
}
|
||||
return r.addressMap[metadata.SourceMACAddress.String()]
|
||||
}
|
||||
|
||||
func (r *SourceMACAddressItem) String() string {
|
||||
var description string
|
||||
if len(r.addresses) == 1 {
|
||||
description = "source_mac_address=" + r.addresses[0]
|
||||
} else {
|
||||
description = "source_mac_address=[" + strings.Join(r.addresses, " ") + "]"
|
||||
}
|
||||
return description
|
||||
}
|
||||
@@ -45,6 +45,14 @@ func isProcessDNSRule(rule option.DefaultDNSRule) bool {
|
||||
return len(rule.ProcessName) > 0 || len(rule.ProcessPath) > 0 || len(rule.ProcessPathRegex) > 0 || len(rule.PackageName) > 0 || len(rule.User) > 0 || len(rule.UserID) > 0
|
||||
}
|
||||
|
||||
func isNeighborRule(rule option.DefaultRule) bool {
|
||||
return len(rule.SourceMACAddress) > 0 || len(rule.SourceHostname) > 0
|
||||
}
|
||||
|
||||
func isNeighborDNSRule(rule option.DefaultDNSRule) bool {
|
||||
return len(rule.SourceMACAddress) > 0 || len(rule.SourceHostname) > 0
|
||||
}
|
||||
|
||||
func isWIFIRule(rule option.DefaultRule) bool {
|
||||
return len(rule.WIFISSID) > 0 || len(rule.WIFIBSSID) > 0
|
||||
}
|
||||
|
||||
@@ -2,25 +2,74 @@ package ccm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
const (
|
||||
oauth2ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
oauth2TokenURL = "https://console.anthropic.com/v1/oauth/token"
|
||||
oauth2TokenURL = "https://platform.claude.com/v1/oauth/token"
|
||||
claudeAPIBaseURL = "https://api.anthropic.com"
|
||||
tokenRefreshBufferMs = 60000
|
||||
anthropicBetaOAuthValue = "oauth-2025-04-20"
|
||||
)
|
||||
|
||||
const ccmUserAgentFallback = "claude-code/2.1.72"
|
||||
|
||||
var (
|
||||
ccmUserAgentOnce sync.Once
|
||||
ccmUserAgentValue string
|
||||
)
|
||||
|
||||
func initCCMUserAgent(logger log.ContextLogger) {
|
||||
ccmUserAgentOnce.Do(func() {
|
||||
version, err := detectClaudeCodeVersion()
|
||||
if err != nil {
|
||||
logger.Error("detect Claude Code version: ", err)
|
||||
ccmUserAgentValue = ccmUserAgentFallback
|
||||
return
|
||||
}
|
||||
logger.Debug("detected Claude Code version: ", version)
|
||||
ccmUserAgentValue = "claude-code/" + version
|
||||
})
|
||||
}
|
||||
|
||||
func detectClaudeCodeVersion() (string, error) {
|
||||
userInfo, err := getRealUser()
|
||||
if err != nil {
|
||||
return "", E.Cause(err, "get user")
|
||||
}
|
||||
binaryName := "claude"
|
||||
if runtime.GOOS == "windows" {
|
||||
binaryName = "claude.exe"
|
||||
}
|
||||
linkPath := filepath.Join(userInfo.HomeDir, ".local", "bin", binaryName)
|
||||
target, err := os.Readlink(linkPath)
|
||||
if err != nil {
|
||||
return "", E.Cause(err, "readlink ", linkPath)
|
||||
}
|
||||
if !filepath.IsAbs(target) {
|
||||
target = filepath.Join(filepath.Dir(linkPath), target)
|
||||
}
|
||||
parent := filepath.Base(filepath.Dir(target))
|
||||
if parent != "versions" {
|
||||
return "", E.New("unexpected symlink target: ", target)
|
||||
}
|
||||
return filepath.Base(target), nil
|
||||
}
|
||||
|
||||
func getRealUser() (*user.User, error) {
|
||||
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
|
||||
sudoUserInfo, err := user.Lookup(sudoUser)
|
||||
@@ -60,6 +109,14 @@ func readCredentialsFromFile(path string) (*oauthCredentials, error) {
|
||||
return credentialsContainer.ClaudeAIAuth, nil
|
||||
}
|
||||
|
||||
func checkCredentialFileWritable(path string) error {
|
||||
file, err := os.OpenFile(path, os.O_WRONLY, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return file.Close()
|
||||
}
|
||||
|
||||
func writeCredentialsToFile(oauthCredentials *oauthCredentials, path string) error {
|
||||
data, err := json.MarshalIndent(map[string]any{
|
||||
"claudeAiOauth": oauthCredentials,
|
||||
@@ -76,6 +133,7 @@ type oauthCredentials struct {
|
||||
ExpiresAt int64 `json:"expiresAt"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
SubscriptionType string `json:"subscriptionType,omitempty"`
|
||||
RateLimitTier string `json:"rateLimitTier,omitempty"`
|
||||
IsMax bool `json:"isMax,omitempty"`
|
||||
}
|
||||
|
||||
@@ -86,7 +144,7 @@ func (c *oauthCredentials) needsRefresh() bool {
|
||||
return time.Now().UnixMilli() >= c.ExpiresAt-tokenRefreshBufferMs
|
||||
}
|
||||
|
||||
func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) {
|
||||
func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) {
|
||||
if credentials.RefreshToken == "" {
|
||||
return nil, E.New("refresh token is empty")
|
||||
}
|
||||
@@ -100,19 +158,24 @@ func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oaut
|
||||
return nil, E.Cause(err, "marshal request")
|
||||
}
|
||||
|
||||
request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request.Header.Set("Accept", "application/json")
|
||||
|
||||
response, err := httpClient.Do(request)
|
||||
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
|
||||
request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request.Header.Set("User-Agent", ccmUserAgentValue)
|
||||
return request, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode == http.StatusTooManyRequests {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return nil, E.New("refresh rate limited: ", response.Status, " ", string(body))
|
||||
}
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return nil, E.New("refresh failed: ", response.Status, " ", string(body))
|
||||
@@ -137,3 +200,25 @@ func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oaut
|
||||
|
||||
return &newCredentials, nil
|
||||
}
|
||||
|
||||
func cloneCredentials(credentials *oauthCredentials) *oauthCredentials {
|
||||
if credentials == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *credentials
|
||||
cloned.Scopes = append([]string(nil), credentials.Scopes...)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool {
|
||||
if left == nil || right == nil {
|
||||
return left == right
|
||||
}
|
||||
return left.AccessToken == right.AccessToken &&
|
||||
left.RefreshToken == right.RefreshToken &&
|
||||
left.ExpiresAt == right.ExpiresAt &&
|
||||
slices.Equal(left.Scopes, right.Scopes) &&
|
||||
left.SubscriptionType == right.SubscriptionType &&
|
||||
left.RateLimitTier == right.RateLimitTier &&
|
||||
left.IsMax == right.IsMax
|
||||
}
|
||||
|
||||
@@ -69,6 +69,13 @@ func platformReadCredentials(customPath string) (*oauthCredentials, error) {
|
||||
return readCredentialsFromFile(defaultPath)
|
||||
}
|
||||
|
||||
func platformCanWriteCredentials(customPath string) error {
|
||||
if customPath == "" {
|
||||
return nil
|
||||
}
|
||||
return checkCredentialFileWritable(customPath)
|
||||
}
|
||||
|
||||
func platformWriteCredentials(oauthCredentials *oauthCredentials, customPath string) error {
|
||||
if customPath != "" {
|
||||
return writeCredentialsToFile(oauthCredentials, customPath)
|
||||
|
||||
676
service/ccm/credential_external.go
Normal file
676
service/ccm/credential_external.go
Normal file
@@ -0,0 +1,676 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
stdTLS "crypto/tls"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
)
|
||||
|
||||
const reverseProxyBaseURL = "http://reverse-proxy"
|
||||
|
||||
type externalCredential struct {
|
||||
tag string
|
||||
baseURL string
|
||||
token string
|
||||
httpClient *http.Client
|
||||
state credentialState
|
||||
stateMutex sync.RWMutex
|
||||
pollAccess sync.Mutex
|
||||
pollInterval time.Duration
|
||||
usageTracker *AggregatedUsage
|
||||
logger log.ContextLogger
|
||||
|
||||
onBecameUnusable func()
|
||||
interrupted bool
|
||||
requestContext context.Context
|
||||
cancelRequests context.CancelFunc
|
||||
requestAccess sync.Mutex
|
||||
|
||||
// Reverse proxy fields
|
||||
reverse bool
|
||||
reverseSession *yamux.Session
|
||||
reverseAccess sync.RWMutex
|
||||
closed bool
|
||||
reverseContext context.Context
|
||||
reverseCancel context.CancelFunc
|
||||
connectorDialer N.Dialer
|
||||
connectorDestination M.Socksaddr
|
||||
connectorRequestPath string
|
||||
connectorURL *url.URL
|
||||
connectorTLS *stdTLS.Config
|
||||
reverseService http.Handler
|
||||
}
|
||||
|
||||
func externalCredentialURLPort(parsedURL *url.URL) uint16 {
|
||||
portStr := parsedURL.Port()
|
||||
if portStr != "" {
|
||||
port, err := strconv.ParseUint(portStr, 10, 16)
|
||||
if err == nil {
|
||||
return uint16(port)
|
||||
}
|
||||
}
|
||||
if parsedURL.Scheme == "https" {
|
||||
return 443
|
||||
}
|
||||
return 80
|
||||
}
|
||||
|
||||
func externalCredentialServerPort(parsedURL *url.URL, configuredPort uint16) uint16 {
|
||||
if configuredPort != 0 {
|
||||
return configuredPort
|
||||
}
|
||||
return externalCredentialURLPort(parsedURL)
|
||||
}
|
||||
|
||||
func externalCredentialBaseURL(parsedURL *url.URL) string {
|
||||
baseURL := parsedURL.Scheme + "://" + parsedURL.Host
|
||||
if parsedURL.Path != "" && parsedURL.Path != "/" {
|
||||
baseURL += parsedURL.Path
|
||||
}
|
||||
if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' {
|
||||
baseURL = baseURL[:len(baseURL)-1]
|
||||
}
|
||||
return baseURL
|
||||
}
|
||||
|
||||
func externalCredentialReversePath(parsedURL *url.URL, endpointPath string) string {
|
||||
pathPrefix := parsedURL.EscapedPath()
|
||||
if pathPrefix == "/" {
|
||||
pathPrefix = ""
|
||||
} else {
|
||||
pathPrefix = strings.TrimSuffix(pathPrefix, "/")
|
||||
}
|
||||
return pathPrefix + endpointPath
|
||||
}
|
||||
|
||||
func newExternalCredential(ctx context.Context, tag string, options option.CCMExternalCredentialOptions, logger log.ContextLogger) (*externalCredential, error) {
|
||||
pollInterval := time.Duration(options.PollInterval)
|
||||
if pollInterval <= 0 {
|
||||
pollInterval = 30 * time.Minute
|
||||
}
|
||||
|
||||
requestContext, cancelRequests := context.WithCancel(context.Background())
|
||||
reverseContext, reverseCancel := context.WithCancel(context.Background())
|
||||
|
||||
cred := &externalCredential{
|
||||
tag: tag,
|
||||
token: options.Token,
|
||||
pollInterval: pollInterval,
|
||||
logger: logger,
|
||||
requestContext: requestContext,
|
||||
cancelRequests: cancelRequests,
|
||||
reverse: options.Reverse,
|
||||
reverseContext: reverseContext,
|
||||
reverseCancel: reverseCancel,
|
||||
}
|
||||
|
||||
if options.URL == "" {
|
||||
// Receiver mode: no URL, wait for reverse connection
|
||||
cred.baseURL = reverseProxyBaseURL
|
||||
cred.httpClient = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
ForceAttemptHTTP2: false,
|
||||
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return cred.openReverseConnection(ctx)
|
||||
},
|
||||
},
|
||||
}
|
||||
} else {
|
||||
// Normal or connector mode: has URL
|
||||
parsedURL, err := url.Parse(options.URL)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "parse url for credential ", tag)
|
||||
}
|
||||
|
||||
credentialDialer, err := dialer.NewWithOptions(dialer.Options{
|
||||
Context: ctx,
|
||||
Options: option.DialerOptions{
|
||||
Detour: options.Detour,
|
||||
},
|
||||
RemoteIsDomain: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "create dialer for credential ", tag)
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
ForceAttemptHTTP2: true,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
if options.Server != "" {
|
||||
destination := M.ParseSocksaddrHostPort(options.Server, externalCredentialServerPort(parsedURL, options.ServerPort))
|
||||
return credentialDialer.DialContext(ctx, network, destination)
|
||||
}
|
||||
return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
},
|
||||
}
|
||||
|
||||
if parsedURL.Scheme == "https" {
|
||||
transport.TLSClientConfig = &stdTLS.Config{
|
||||
ServerName: parsedURL.Hostname(),
|
||||
RootCAs: adapter.RootPoolFromContext(ctx),
|
||||
Time: ntp.TimeFuncFromContext(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
cred.baseURL = externalCredentialBaseURL(parsedURL)
|
||||
|
||||
if options.Reverse {
|
||||
// Connector mode: we dial out to serve, not to proxy
|
||||
cred.connectorDialer = credentialDialer
|
||||
if options.Server != "" {
|
||||
cred.connectorDestination = M.ParseSocksaddrHostPort(options.Server, externalCredentialServerPort(parsedURL, options.ServerPort))
|
||||
} else {
|
||||
cred.connectorDestination = M.ParseSocksaddrHostPort(parsedURL.Hostname(), externalCredentialURLPort(parsedURL))
|
||||
}
|
||||
cred.connectorRequestPath = externalCredentialReversePath(parsedURL, "/ccm/v1/reverse")
|
||||
cred.connectorURL = parsedURL
|
||||
if parsedURL.Scheme == "https" {
|
||||
cred.connectorTLS = &stdTLS.Config{
|
||||
ServerName: parsedURL.Hostname(),
|
||||
RootCAs: adapter.RootPoolFromContext(ctx),
|
||||
Time: ntp.TimeFuncFromContext(ctx),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Normal mode: standard HTTP client for proxying
|
||||
cred.httpClient = &http.Client{Transport: transport}
|
||||
}
|
||||
}
|
||||
|
||||
if options.UsagesPath != "" {
|
||||
cred.usageTracker = &AggregatedUsage{
|
||||
LastUpdated: time.Now(),
|
||||
Combinations: make([]CostCombination, 0),
|
||||
filePath: options.UsagesPath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
return cred, nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) start() error {
|
||||
if c.usageTracker != nil {
|
||||
err := c.usageTracker.Load()
|
||||
if err != nil {
|
||||
c.logger.Warn("load usage statistics for ", c.tag, ": ", err)
|
||||
}
|
||||
}
|
||||
if c.reverse && c.connectorURL != nil {
|
||||
go c.connectorLoop()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) tagName() string {
|
||||
return c.tag
|
||||
}
|
||||
|
||||
func (c *externalCredential) isExternal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *externalCredential) isAvailable() bool {
|
||||
return c.unavailableError() == nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) isUsable() bool {
|
||||
if !c.isAvailable() {
|
||||
return false
|
||||
}
|
||||
c.stateMutex.RLock()
|
||||
if c.state.consecutivePollFailures > 0 {
|
||||
c.stateMutex.RUnlock()
|
||||
return false
|
||||
}
|
||||
if c.state.hardRateLimited {
|
||||
if time.Now().Before(c.state.rateLimitResetAt) {
|
||||
c.stateMutex.RUnlock()
|
||||
return false
|
||||
}
|
||||
c.stateMutex.RUnlock()
|
||||
c.stateMutex.Lock()
|
||||
if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) {
|
||||
c.state.hardRateLimited = false
|
||||
}
|
||||
// No reserve for external: only 100% is unusable
|
||||
usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100
|
||||
c.stateMutex.Unlock()
|
||||
return usable
|
||||
}
|
||||
usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100
|
||||
c.stateMutex.RUnlock()
|
||||
return usable
|
||||
}
|
||||
|
||||
func (c *externalCredential) fiveHourUtilization() float64 {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
return c.state.fiveHourUtilization
|
||||
}
|
||||
|
||||
func (c *externalCredential) weeklyUtilization() float64 {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
return c.state.weeklyUtilization
|
||||
}
|
||||
|
||||
func (c *externalCredential) fiveHourCap() float64 {
|
||||
return 100
|
||||
}
|
||||
|
||||
func (c *externalCredential) weeklyCap() float64 {
|
||||
return 100
|
||||
}
|
||||
|
||||
func (c *externalCredential) planWeight() float64 {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
if c.state.remotePlanWeight > 0 {
|
||||
return c.state.remotePlanWeight
|
||||
}
|
||||
return 10
|
||||
}
|
||||
|
||||
func (c *externalCredential) weeklyResetTime() time.Time {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
return c.state.weeklyReset
|
||||
}
|
||||
|
||||
func (c *externalCredential) markRateLimited(resetAt time.Time) {
|
||||
c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt)))
|
||||
c.stateMutex.Lock()
|
||||
c.state.hardRateLimited = true
|
||||
c.state.rateLimitResetAt = resetAt
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) earliestReset() time.Time {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
if c.state.hardRateLimited {
|
||||
return c.state.rateLimitResetAt
|
||||
}
|
||||
earliest := c.state.fiveHourReset
|
||||
if !c.state.weeklyReset.IsZero() && (earliest.IsZero() || c.state.weeklyReset.Before(earliest)) {
|
||||
earliest = c.state.weeklyReset
|
||||
}
|
||||
return earliest
|
||||
}
|
||||
|
||||
func (c *externalCredential) unavailableError() error {
|
||||
if c.reverse && c.connectorURL != nil {
|
||||
return E.New("credential ", c.tag, " is unavailable: reverse connector credentials cannot serve local requests")
|
||||
}
|
||||
if c.baseURL == reverseProxyBaseURL {
|
||||
session := c.getReverseSession()
|
||||
if session == nil || session.IsClosed() {
|
||||
return E.New("credential ", c.tag, " is unavailable: reverse connection not established")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) getAccessToken() (string, error) {
|
||||
return c.token, nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, _ http.Header) (*http.Request, error) {
|
||||
proxyURL := c.baseURL + original.URL.RequestURI()
|
||||
var body io.Reader
|
||||
if bodyBytes != nil {
|
||||
body = bytes.NewReader(bodyBytes)
|
||||
} else {
|
||||
body = original.Body
|
||||
}
|
||||
proxyRequest, err := http.NewRequestWithContext(ctx, original.Method, proxyURL, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for key, values := range original.Header {
|
||||
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && key != "Authorization" {
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
}
|
||||
|
||||
proxyRequest.Header.Set("Authorization", "Bearer "+c.token)
|
||||
|
||||
return proxyRequest, nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) openReverseConnection(ctx context.Context) (net.Conn, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
session := c.getReverseSession()
|
||||
if session == nil || session.IsClosed() {
|
||||
return nil, E.New("reverse connection not established for ", c.tag)
|
||||
}
|
||||
conn, err := session.Open()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
conn.Close()
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) updateStateFromHeaders(headers http.Header) {
|
||||
c.stateMutex.Lock()
|
||||
isFirstUpdate := c.state.lastUpdated.IsZero()
|
||||
oldFiveHour := c.state.fiveHourUtilization
|
||||
oldWeekly := c.state.weeklyUtilization
|
||||
hadData := false
|
||||
|
||||
if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-5h-reset"); exists {
|
||||
hadData = true
|
||||
c.state.fiveHourReset = value
|
||||
}
|
||||
if utilization := headers.Get("anthropic-ratelimit-unified-5h-utilization"); utilization != "" {
|
||||
value, err := strconv.ParseFloat(utilization, 64)
|
||||
if err == nil {
|
||||
hadData = true
|
||||
c.state.fiveHourUtilization = value * 100
|
||||
}
|
||||
}
|
||||
|
||||
if value, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset"); exists {
|
||||
hadData = true
|
||||
c.state.weeklyReset = value
|
||||
}
|
||||
if utilization := headers.Get("anthropic-ratelimit-unified-7d-utilization"); utilization != "" {
|
||||
value, err := strconv.ParseFloat(utilization, 64)
|
||||
if err == nil {
|
||||
hadData = true
|
||||
c.state.weeklyUtilization = value * 100
|
||||
}
|
||||
}
|
||||
if planWeight := headers.Get("X-CCM-Plan-Weight"); planWeight != "" {
|
||||
value, err := strconv.ParseFloat(planWeight, 64)
|
||||
if err == nil && value > 0 {
|
||||
c.state.remotePlanWeight = value
|
||||
}
|
||||
}
|
||||
if hadData {
|
||||
c.state.consecutivePollFailures = 0
|
||||
c.state.lastUpdated = time.Now()
|
||||
}
|
||||
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
|
||||
resetSuffix := ""
|
||||
if !c.state.weeklyReset.IsZero() {
|
||||
resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset))
|
||||
}
|
||||
c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
|
||||
}
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) checkTransitionLocked() bool {
|
||||
unusable := c.state.hardRateLimited || c.state.fiveHourUtilization >= 100 || c.state.weeklyUtilization >= 100 || c.state.consecutivePollFailures > 0
|
||||
if unusable && !c.interrupted {
|
||||
c.interrupted = true
|
||||
return true
|
||||
}
|
||||
if !unusable && c.interrupted {
|
||||
c.interrupted = false
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *externalCredential) wrapRequestContext(parent context.Context) *credentialRequestContext {
|
||||
c.requestAccess.Lock()
|
||||
credentialContext := c.requestContext
|
||||
c.requestAccess.Unlock()
|
||||
derived, cancel := context.WithCancel(parent)
|
||||
stop := context.AfterFunc(credentialContext, func() {
|
||||
cancel()
|
||||
})
|
||||
return &credentialRequestContext{
|
||||
Context: derived,
|
||||
releaseFunc: stop,
|
||||
cancelFunc: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) interruptConnections() {
|
||||
c.logger.Warn("interrupting connections for ", c.tag)
|
||||
c.requestAccess.Lock()
|
||||
c.cancelRequests()
|
||||
c.requestContext, c.cancelRequests = context.WithCancel(context.Background())
|
||||
c.requestAccess.Unlock()
|
||||
if c.onBecameUnusable != nil {
|
||||
c.onBecameUnusable()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) pollUsage(ctx context.Context) {
|
||||
if !c.pollAccess.TryLock() {
|
||||
return
|
||||
}
|
||||
defer c.pollAccess.Unlock()
|
||||
defer c.markUsagePollAttempted()
|
||||
|
||||
statusURL := c.baseURL + "/ccm/v1/status"
|
||||
httpClient := &http.Client{
|
||||
Transport: c.httpClient.Transport,
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, statusURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header.Set("Authorization", "Bearer "+c.token)
|
||||
return request, nil
|
||||
})
|
||||
if err != nil {
|
||||
c.logger.Error("poll usage for ", c.tag, ": ", err)
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
|
||||
// 404 means the remote does not have a status endpoint yet;
|
||||
// usage will be updated passively from response headers.
|
||||
if response.StatusCode == http.StatusNotFound {
|
||||
c.stateMutex.Lock()
|
||||
c.state.consecutivePollFailures = 0
|
||||
c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
} else {
|
||||
c.incrementPollFailures()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var statusResponse struct {
|
||||
FiveHourUtilization float64 `json:"five_hour_utilization"`
|
||||
WeeklyUtilization float64 `json:"weekly_utilization"`
|
||||
PlanWeight float64 `json:"plan_weight"`
|
||||
}
|
||||
err = json.NewDecoder(response.Body).Decode(&statusResponse)
|
||||
if err != nil {
|
||||
c.logger.Debug("poll usage for ", c.tag, ": decode: ", err)
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
|
||||
c.stateMutex.Lock()
|
||||
isFirstUpdate := c.state.lastUpdated.IsZero()
|
||||
oldFiveHour := c.state.fiveHourUtilization
|
||||
oldWeekly := c.state.weeklyUtilization
|
||||
c.state.consecutivePollFailures = 0
|
||||
c.state.fiveHourUtilization = statusResponse.FiveHourUtilization
|
||||
c.state.weeklyUtilization = statusResponse.WeeklyUtilization
|
||||
if statusResponse.PlanWeight > 0 {
|
||||
c.state.remotePlanWeight = statusResponse.PlanWeight
|
||||
}
|
||||
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
|
||||
c.state.hardRateLimited = false
|
||||
}
|
||||
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
|
||||
resetSuffix := ""
|
||||
if !c.state.weeklyReset.IsZero() {
|
||||
resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset))
|
||||
}
|
||||
c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
|
||||
}
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) lastUpdatedTime() time.Time {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
return c.state.lastUpdated
|
||||
}
|
||||
|
||||
func (c *externalCredential) markUsagePollAttempted() {
|
||||
c.stateMutex.Lock()
|
||||
defer c.stateMutex.Unlock()
|
||||
c.state.lastUpdated = time.Now()
|
||||
}
|
||||
|
||||
func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Duration {
|
||||
c.stateMutex.RLock()
|
||||
failures := c.state.consecutivePollFailures
|
||||
c.stateMutex.RUnlock()
|
||||
if failures <= 0 {
|
||||
return baseInterval
|
||||
}
|
||||
return failedPollRetryInterval
|
||||
}
|
||||
|
||||
func (c *externalCredential) incrementPollFailures() {
|
||||
c.stateMutex.Lock()
|
||||
c.state.consecutivePollFailures++
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) usageTrackerOrNil() *AggregatedUsage {
|
||||
return c.usageTracker
|
||||
}
|
||||
|
||||
func (c *externalCredential) httpTransport() *http.Client {
|
||||
return c.httpClient
|
||||
}
|
||||
|
||||
func (c *externalCredential) close() {
|
||||
var session *yamux.Session
|
||||
c.reverseAccess.Lock()
|
||||
if !c.closed {
|
||||
c.closed = true
|
||||
if c.reverseCancel != nil {
|
||||
c.reverseCancel()
|
||||
}
|
||||
session = c.reverseSession
|
||||
c.reverseSession = nil
|
||||
}
|
||||
c.reverseAccess.Unlock()
|
||||
if session != nil {
|
||||
session.Close()
|
||||
}
|
||||
if c.usageTracker != nil {
|
||||
c.usageTracker.cancelPendingSave()
|
||||
err := c.usageTracker.Save()
|
||||
if err != nil {
|
||||
c.logger.Error("save usage statistics for ", c.tag, ": ", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) getReverseSession() *yamux.Session {
|
||||
c.reverseAccess.RLock()
|
||||
defer c.reverseAccess.RUnlock()
|
||||
return c.reverseSession
|
||||
}
|
||||
|
||||
func (c *externalCredential) setReverseSession(session *yamux.Session) bool {
|
||||
c.reverseAccess.Lock()
|
||||
if c.closed {
|
||||
c.reverseAccess.Unlock()
|
||||
return false
|
||||
}
|
||||
old := c.reverseSession
|
||||
c.reverseSession = session
|
||||
c.reverseAccess.Unlock()
|
||||
if old != nil {
|
||||
old.Close()
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *externalCredential) clearReverseSession(session *yamux.Session) {
|
||||
c.reverseAccess.Lock()
|
||||
if c.reverseSession == session {
|
||||
c.reverseSession = nil
|
||||
}
|
||||
c.reverseAccess.Unlock()
|
||||
}
|
||||
|
||||
func (c *externalCredential) getReverseContext() context.Context {
|
||||
c.reverseAccess.RLock()
|
||||
defer c.reverseAccess.RUnlock()
|
||||
return c.reverseContext
|
||||
}
|
||||
|
||||
func (c *externalCredential) resetReverseContext() {
|
||||
c.reverseAccess.Lock()
|
||||
if c.closed {
|
||||
c.reverseAccess.Unlock()
|
||||
return
|
||||
}
|
||||
c.reverseCancel()
|
||||
c.reverseContext, c.reverseCancel = context.WithCancel(context.Background())
|
||||
c.reverseAccess.Unlock()
|
||||
}
|
||||
143
service/ccm/credential_file.go
Normal file
143
service/ccm/credential_file.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/fswatch"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
const credentialReloadRetryInterval = 2 * time.Second
|
||||
|
||||
func resolveCredentialFilePath(customPath string) (string, error) {
|
||||
if customPath == "" {
|
||||
var err error
|
||||
customPath, err = getDefaultCredentialsPath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
if filepath.IsAbs(customPath) {
|
||||
return customPath, nil
|
||||
}
|
||||
return filepath.Abs(customPath)
|
||||
}
|
||||
|
||||
func (c *defaultCredential) ensureCredentialWatcher() error {
|
||||
c.watcherAccess.Lock()
|
||||
defer c.watcherAccess.Unlock()
|
||||
|
||||
if c.watcher != nil || c.credentialFilePath == "" {
|
||||
return nil
|
||||
}
|
||||
if !c.watcherRetryAt.IsZero() && time.Now().Before(c.watcherRetryAt) {
|
||||
return nil
|
||||
}
|
||||
|
||||
watcher, err := fswatch.NewWatcher(fswatch.Options{
|
||||
Path: []string{c.credentialFilePath},
|
||||
Logger: c.logger,
|
||||
Callback: func(string) {
|
||||
err := c.reloadCredentials(true)
|
||||
if err != nil {
|
||||
c.logger.Warn("reload credentials for ", c.tag, ": ", err)
|
||||
}
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
c.watcherRetryAt = time.Now().Add(credentialReloadRetryInterval)
|
||||
return err
|
||||
}
|
||||
|
||||
err = watcher.Start()
|
||||
if err != nil {
|
||||
c.watcherRetryAt = time.Now().Add(credentialReloadRetryInterval)
|
||||
return err
|
||||
}
|
||||
|
||||
c.watcher = watcher
|
||||
c.watcherRetryAt = time.Time{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *defaultCredential) retryCredentialReloadIfNeeded() {
|
||||
c.stateMutex.RLock()
|
||||
unavailable := c.state.unavailable
|
||||
lastAttempt := c.state.lastCredentialLoadAttempt
|
||||
c.stateMutex.RUnlock()
|
||||
if !unavailable {
|
||||
return
|
||||
}
|
||||
if !lastAttempt.IsZero() && time.Since(lastAttempt) < credentialReloadRetryInterval {
|
||||
return
|
||||
}
|
||||
|
||||
err := c.ensureCredentialWatcher()
|
||||
if err != nil {
|
||||
c.logger.Debug("start credential watcher for ", c.tag, ": ", err)
|
||||
}
|
||||
_ = c.reloadCredentials(false)
|
||||
}
|
||||
|
||||
func (c *defaultCredential) reloadCredentials(force bool) error {
|
||||
c.reloadAccess.Lock()
|
||||
defer c.reloadAccess.Unlock()
|
||||
|
||||
c.stateMutex.RLock()
|
||||
unavailable := c.state.unavailable
|
||||
lastAttempt := c.state.lastCredentialLoadAttempt
|
||||
c.stateMutex.RUnlock()
|
||||
if !force {
|
||||
if !unavailable {
|
||||
return nil
|
||||
}
|
||||
if !lastAttempt.IsZero() && time.Since(lastAttempt) < credentialReloadRetryInterval {
|
||||
return c.unavailableError()
|
||||
}
|
||||
}
|
||||
|
||||
c.stateMutex.Lock()
|
||||
c.state.lastCredentialLoadAttempt = time.Now()
|
||||
c.stateMutex.Unlock()
|
||||
|
||||
credentials, err := platformReadCredentials(c.credentialPath)
|
||||
if err != nil {
|
||||
return c.markCredentialsUnavailable(E.Cause(err, "read credentials"))
|
||||
}
|
||||
|
||||
c.accessMutex.Lock()
|
||||
c.credentials = credentials
|
||||
c.accessMutex.Unlock()
|
||||
|
||||
c.stateMutex.Lock()
|
||||
c.state.unavailable = false
|
||||
c.state.lastCredentialLoadError = ""
|
||||
c.state.accountType = credentials.SubscriptionType
|
||||
c.state.rateLimitTier = credentials.RateLimitTier
|
||||
c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *defaultCredential) markCredentialsUnavailable(err error) error {
|
||||
c.accessMutex.Lock()
|
||||
hadCredentials := c.credentials != nil
|
||||
c.credentials = nil
|
||||
c.accessMutex.Unlock()
|
||||
|
||||
c.stateMutex.Lock()
|
||||
c.state.unavailable = true
|
||||
c.state.lastCredentialLoadError = err.Error()
|
||||
c.state.accountType = ""
|
||||
c.state.rateLimitTier = ""
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
|
||||
if shouldInterrupt && hadCredentials {
|
||||
c.interruptConnections()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -13,6 +13,17 @@ func platformReadCredentials(customPath string) (*oauthCredentials, error) {
|
||||
return readCredentialsFromFile(customPath)
|
||||
}
|
||||
|
||||
func platformCanWriteCredentials(customPath string) error {
|
||||
if customPath == "" {
|
||||
var err error
|
||||
customPath, err = getDefaultCredentialsPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return checkCredentialFileWritable(customPath)
|
||||
}
|
||||
|
||||
func platformWriteCredentials(oauthCredentials *oauthCredentials, customPath string) error {
|
||||
if customPath == "" {
|
||||
var err error
|
||||
|
||||
1410
service/ccm/credential_state.go
Normal file
1410
service/ccm/credential_state.go
Normal file
File diff suppressed because it is too large
Load Diff
259
service/ccm/reverse.go
Normal file
259
service/ccm/reverse.go
Normal file
@@ -0,0 +1,259 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
stdTLS "crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
)
|
||||
|
||||
func reverseYamuxConfig() *yamux.Config {
|
||||
config := yamux.DefaultConfig()
|
||||
config.KeepAliveInterval = 15 * time.Second
|
||||
config.ConnectionWriteTimeout = 10 * time.Second
|
||||
config.MaxStreamWindowSize = 512 * 1024
|
||||
config.LogOutput = io.Discard
|
||||
return config
|
||||
}
|
||||
|
||||
type bufferedConn struct {
|
||||
reader *bufio.Reader
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c *bufferedConn) Read(p []byte) (int, error) {
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
type yamuxNetListener struct {
|
||||
session *yamux.Session
|
||||
}
|
||||
|
||||
func (l *yamuxNetListener) Accept() (net.Conn, error) {
|
||||
return l.session.Accept()
|
||||
}
|
||||
|
||||
func (l *yamuxNetListener) Close() error {
|
||||
return l.session.Close()
|
||||
}
|
||||
|
||||
func (l *yamuxNetListener) Addr() net.Addr {
|
||||
return l.session.Addr()
|
||||
}
|
||||
|
||||
func (s *Service) handleReverseConnect(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Upgrade") != "reverse-proxy" {
|
||||
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", "missing Upgrade header")
|
||||
return
|
||||
}
|
||||
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
||||
return
|
||||
}
|
||||
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if clientToken == authHeader {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
||||
return
|
||||
}
|
||||
|
||||
receiverCredential := s.findReceiverCredential(clientToken)
|
||||
if receiverCredential == nil {
|
||||
s.logger.WarnContext(ctx, "reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid reverse token")
|
||||
return
|
||||
}
|
||||
|
||||
hijacker, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
s.logger.ErrorContext(ctx, "reverse connect: hijack not supported")
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "hijack not supported")
|
||||
return
|
||||
}
|
||||
|
||||
conn, bufferedReadWriter, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "reverse connect: hijack: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
response := "HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: reverse-proxy\r\n\r\n"
|
||||
_, err = bufferedReadWriter.WriteString(response)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
s.logger.ErrorContext(ctx, "reverse connect: write upgrade response: ", err)
|
||||
return
|
||||
}
|
||||
err = bufferedReadWriter.Flush()
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
s.logger.ErrorContext(ctx, "reverse connect: flush upgrade response: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
session, err := yamux.Client(conn, reverseYamuxConfig())
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
s.logger.ErrorContext(ctx, "reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !receiverCredential.setReverseSession(session) {
|
||||
session.Close()
|
||||
return
|
||||
}
|
||||
s.logger.InfoContext(ctx, "reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr)
|
||||
|
||||
go func() {
|
||||
<-session.CloseChan()
|
||||
receiverCredential.clearReverseSession(session)
|
||||
s.logger.WarnContext(ctx, "reverse connection lost for ", receiverCredential.tagName())
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Service) findReceiverCredential(token string) *externalCredential {
|
||||
for _, cred := range s.allCredentials {
|
||||
extCred, ok := cred.(*externalCredential)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if extCred.baseURL == reverseProxyBaseURL && extCred.token == token {
|
||||
return extCred
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) connectorLoop() {
|
||||
var consecutiveFailures int
|
||||
ctx := c.getReverseContext()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
sessionLifetime, err := c.connectorConnect(ctx)
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if sessionLifetime >= connectorBackoffResetThreshold {
|
||||
consecutiveFailures = 0
|
||||
}
|
||||
consecutiveFailures++
|
||||
backoff := connectorBackoff(consecutiveFailures)
|
||||
c.logger.Warn("reverse connection for ", c.tag, " lost: ", err, ", reconnecting in ", backoff)
|
||||
select {
|
||||
case <-time.After(backoff):
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const connectorBackoffResetThreshold = time.Minute
|
||||
|
||||
func connectorBackoff(failures int) time.Duration {
|
||||
if failures > 5 {
|
||||
failures = 5
|
||||
}
|
||||
base := time.Second * time.Duration(1<<failures)
|
||||
if base > 30*time.Second {
|
||||
base = 30 * time.Second
|
||||
}
|
||||
jitter := time.Duration(rand.Int64N(int64(base) / 2))
|
||||
return base + jitter
|
||||
}
|
||||
|
||||
func (c *externalCredential) connectorConnect(ctx context.Context) (time.Duration, error) {
|
||||
if c.reverseService == nil {
|
||||
return 0, E.New("reverse service not initialized")
|
||||
}
|
||||
destination := c.connectorResolveDestination()
|
||||
conn, err := c.connectorDialer.DialContext(ctx, "tcp", destination)
|
||||
if err != nil {
|
||||
return 0, E.Cause(err, "dial")
|
||||
}
|
||||
|
||||
if c.connectorTLS != nil {
|
||||
tlsConn := stdTLS.Client(conn, c.connectorTLS.Clone())
|
||||
err = tlsConn.HandshakeContext(ctx)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return 0, E.Cause(err, "tls handshake")
|
||||
}
|
||||
conn = tlsConn
|
||||
}
|
||||
|
||||
upgradeRequest := "GET " + c.connectorRequestPath + " HTTP/1.1\r\n" +
|
||||
"Host: " + c.connectorURL.Host + "\r\n" +
|
||||
"Connection: Upgrade\r\n" +
|
||||
"Upgrade: reverse-proxy\r\n" +
|
||||
"Authorization: Bearer " + c.token + "\r\n" +
|
||||
"\r\n"
|
||||
_, err = io.WriteString(conn, upgradeRequest)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return 0, E.Cause(err, "write upgrade request")
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
statusLine, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return 0, E.Cause(err, "read upgrade response")
|
||||
}
|
||||
if !strings.HasPrefix(statusLine, "HTTP/1.1 101") {
|
||||
conn.Close()
|
||||
return 0, E.New("unexpected upgrade response: ", strings.TrimSpace(statusLine))
|
||||
}
|
||||
for {
|
||||
line, readErr := reader.ReadString('\n')
|
||||
if readErr != nil {
|
||||
conn.Close()
|
||||
return 0, E.Cause(readErr, "read upgrade headers")
|
||||
}
|
||||
if strings.TrimSpace(line) == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
session, err := yamux.Server(&bufferedConn{reader: reader, Conn: conn}, reverseYamuxConfig())
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return 0, E.Cause(err, "create yamux server")
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
c.logger.Info("reverse connection established for ", c.tag)
|
||||
|
||||
serveStart := time.Now()
|
||||
httpServer := &http.Server{
|
||||
Handler: c.reverseService,
|
||||
ReadTimeout: 0,
|
||||
IdleTimeout: 120 * time.Second,
|
||||
}
|
||||
err = httpServer.Serve(&yamuxNetListener{session: session})
|
||||
sessionLifetime := time.Since(serveStart)
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) && ctx.Err() == nil {
|
||||
return sessionLifetime, E.Cause(err, "serve")
|
||||
}
|
||||
return sessionLifetime, E.New("connection closed")
|
||||
}
|
||||
|
||||
func (c *externalCredential) connectorResolveDestination() M.Socksaddr {
|
||||
return c.connectorDestination
|
||||
}
|
||||
@@ -3,12 +3,10 @@ package ccm
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
stdTLS "crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"mime"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -17,7 +15,6 @@ import (
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
boxService "github.com/sagernet/sing-box/adapter/service"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
"github.com/sagernet/sing-box/common/listener"
|
||||
"github.com/sagernet/sing-box/common/tls"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
@@ -26,20 +23,20 @@ import (
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
)
|
||||
|
||||
const (
|
||||
contextWindowStandard = 200000
|
||||
contextWindowPremium = 1000000
|
||||
premiumContextThreshold = 200000
|
||||
retryableUsageMessage = "current credential reached its usage limit; retry the request to use another credential"
|
||||
)
|
||||
|
||||
func RegisterService(registry *boxService.Registry) {
|
||||
@@ -60,7 +57,6 @@ type errorDetails struct {
|
||||
func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, errorType string, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
json.NewEncoder(w).Encode(errorResponse{
|
||||
Type: "error",
|
||||
Error: errorDetails{
|
||||
@@ -71,6 +67,58 @@ func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, erro
|
||||
})
|
||||
}
|
||||
|
||||
func hasAlternativeCredential(provider credentialProvider, currentCredential credential, filter func(credential) bool) bool {
|
||||
if provider == nil || currentCredential == nil {
|
||||
return false
|
||||
}
|
||||
for _, cred := range provider.allCredentials() {
|
||||
if cred == currentCredential {
|
||||
continue
|
||||
}
|
||||
if filter != nil && !filter(cred) {
|
||||
continue
|
||||
}
|
||||
if cred.isUsable() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func unavailableCredentialMessage(provider credentialProvider, fallback string) string {
|
||||
if provider == nil {
|
||||
return fallback
|
||||
}
|
||||
message := allCredentialsUnavailableError(provider.allCredentials()).Error()
|
||||
if message == "all credentials unavailable" && fallback != "" {
|
||||
return fallback
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
||||
func writeRetryableUsageError(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSONError(w, r, http.StatusTooManyRequests, "rate_limit_error", retryableUsageMessage)
|
||||
}
|
||||
|
||||
func writeNonRetryableCredentialError(w http.ResponseWriter, r *http.Request, message string) {
|
||||
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", message)
|
||||
}
|
||||
|
||||
func writeCredentialUnavailableError(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
provider credentialProvider,
|
||||
currentCredential credential,
|
||||
filter func(credential) bool,
|
||||
fallback string,
|
||||
) {
|
||||
if hasAlternativeCredential(provider, currentCredential, filter) {
|
||||
writeRetryableUsageError(w, r)
|
||||
return
|
||||
}
|
||||
writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, fallback))
|
||||
}
|
||||
|
||||
func isHopByHopHeader(header string) bool {
|
||||
switch strings.ToLower(header) {
|
||||
case "connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade", "host":
|
||||
@@ -80,109 +128,111 @@ func isHopByHopHeader(header string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func isReverseProxyHeader(header string) bool {
|
||||
lowerHeader := strings.ToLower(header)
|
||||
if strings.HasPrefix(lowerHeader, "cf-") {
|
||||
return true
|
||||
}
|
||||
switch lowerHeader {
|
||||
case "cdn-loop", "true-client-ip", "x-forwarded-for", "x-forwarded-proto", "x-real-ip":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
weeklyWindowSeconds = 604800
|
||||
weeklyWindowMinutes = weeklyWindowSeconds / 60
|
||||
)
|
||||
|
||||
func parseInt64Header(headers http.Header, headerName string) (int64, bool) {
|
||||
headerValue := strings.TrimSpace(headers.Get(headerName))
|
||||
if headerValue == "" {
|
||||
return 0, false
|
||||
}
|
||||
parsedValue, parseError := strconv.ParseInt(headerValue, 10, 64)
|
||||
if parseError != nil {
|
||||
return 0, false
|
||||
}
|
||||
return parsedValue, true
|
||||
}
|
||||
|
||||
func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint {
|
||||
resetAtUnix, hasResetAt := parseInt64Header(headers, "anthropic-ratelimit-unified-7d-reset")
|
||||
if !hasResetAt || resetAtUnix <= 0 {
|
||||
resetAt, exists := parseOptionalAnthropicResetHeader(headers, "anthropic-ratelimit-unified-7d-reset")
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &WeeklyCycleHint{
|
||||
WindowMinutes: weeklyWindowMinutes,
|
||||
ResetAt: time.Unix(resetAtUnix, 0).UTC(),
|
||||
ResetAt: resetAt.UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
boxService.Adapter
|
||||
ctx context.Context
|
||||
logger log.ContextLogger
|
||||
credentialPath string
|
||||
credentials *oauthCredentials
|
||||
users []option.CCMUser
|
||||
httpClient *http.Client
|
||||
httpHeaders http.Header
|
||||
listener *listener.Listener
|
||||
tlsConfig tls.ServerConfig
|
||||
httpServer *http.Server
|
||||
userManager *UserManager
|
||||
accessMutex sync.RWMutex
|
||||
usageTracker *AggregatedUsage
|
||||
trackingGroup sync.WaitGroup
|
||||
shuttingDown bool
|
||||
ctx context.Context
|
||||
logger log.ContextLogger
|
||||
options option.CCMServiceOptions
|
||||
httpHeaders http.Header
|
||||
listener *listener.Listener
|
||||
tlsConfig tls.ServerConfig
|
||||
httpServer *http.Server
|
||||
userManager *UserManager
|
||||
trackingGroup sync.WaitGroup
|
||||
shuttingDown bool
|
||||
|
||||
// Legacy mode (single credential)
|
||||
legacyCredential *defaultCredential
|
||||
legacyProvider credentialProvider
|
||||
|
||||
// Multi-credential mode
|
||||
providers map[string]credentialProvider
|
||||
allCredentials []credential
|
||||
userConfigMap map[string]*option.CCMUser
|
||||
}
|
||||
|
||||
func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.CCMServiceOptions) (adapter.Service, error) {
|
||||
serviceDialer, err := dialer.NewWithOptions(dialer.Options{
|
||||
Context: ctx,
|
||||
Options: option.DialerOptions{
|
||||
Detour: options.Detour,
|
||||
},
|
||||
RemoteIsDomain: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "create dialer")
|
||||
}
|
||||
initCCMUserAgent(logger)
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
ForceAttemptHTTP2: true,
|
||||
TLSClientConfig: &stdTLS.Config{
|
||||
RootCAs: adapter.RootPoolFromContext(ctx),
|
||||
Time: ntp.TimeFuncFromContext(ctx),
|
||||
},
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return serviceDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
},
|
||||
},
|
||||
err := validateCCMOptions(options)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "validate options")
|
||||
}
|
||||
|
||||
userManager := &UserManager{
|
||||
tokenMap: make(map[string]string),
|
||||
}
|
||||
|
||||
var usageTracker *AggregatedUsage
|
||||
if options.UsagesPath != "" {
|
||||
usageTracker = &AggregatedUsage{
|
||||
LastUpdated: time.Now(),
|
||||
Combinations: make([]CostCombination, 0),
|
||||
filePath: options.UsagesPath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
service := &Service{
|
||||
Adapter: boxService.NewAdapter(C.TypeCCM, tag),
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
credentialPath: options.CredentialPath,
|
||||
users: options.Users,
|
||||
httpClient: httpClient,
|
||||
httpHeaders: options.Headers.Build(),
|
||||
Adapter: boxService.NewAdapter(C.TypeCCM, tag),
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
options: options,
|
||||
httpHeaders: options.Headers.Build(),
|
||||
listener: listener.New(listener.Options{
|
||||
Context: ctx,
|
||||
Logger: logger,
|
||||
Network: []string{N.NetworkTCP},
|
||||
Listen: options.ListenOptions,
|
||||
}),
|
||||
userManager: userManager,
|
||||
usageTracker: usageTracker,
|
||||
userManager: userManager,
|
||||
}
|
||||
|
||||
if len(options.Credentials) > 0 {
|
||||
providers, allCredentials, err := buildCredentialProviders(ctx, options, logger)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "build credential providers")
|
||||
}
|
||||
service.providers = providers
|
||||
service.allCredentials = allCredentials
|
||||
|
||||
userConfigMap := make(map[string]*option.CCMUser)
|
||||
for i := range options.Users {
|
||||
userConfigMap[options.Users[i].Name] = &options.Users[i]
|
||||
}
|
||||
service.userConfigMap = userConfigMap
|
||||
} else {
|
||||
cred, err := newDefaultCredential(ctx, "default", option.CCMDefaultCredentialOptions{
|
||||
CredentialPath: options.CredentialPath,
|
||||
UsagesPath: options.UsagesPath,
|
||||
Detour: options.Detour,
|
||||
}, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
service.legacyCredential = cred
|
||||
service.legacyProvider = &singleCredentialProvider{cred: cred}
|
||||
service.allCredentials = []credential{cred}
|
||||
}
|
||||
|
||||
if options.TLS != nil {
|
||||
@@ -201,28 +251,25 @@ func (s *Service) Start(stage adapter.StartStage) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.userManager.UpdateUsers(s.users)
|
||||
s.userManager.UpdateUsers(s.options.Users)
|
||||
|
||||
credentials, err := platformReadCredentials(s.credentialPath)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read credentials")
|
||||
}
|
||||
s.credentials = credentials
|
||||
|
||||
if s.usageTracker != nil {
|
||||
err = s.usageTracker.Load()
|
||||
for _, cred := range s.allCredentials {
|
||||
if extCred, ok := cred.(*externalCredential); ok && extCred.reverse && extCred.connectorURL != nil {
|
||||
extCred.reverseService = s
|
||||
}
|
||||
err := cred.start()
|
||||
if err != nil {
|
||||
s.logger.Warn("load usage statistics: ", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Mount("/", s)
|
||||
|
||||
s.httpServer = &http.Server{Handler: router}
|
||||
s.httpServer = &http.Server{Handler: h2c.NewHandler(router, &http2.Server{})}
|
||||
|
||||
if s.tlsConfig != nil {
|
||||
err = s.tlsConfig.Start()
|
||||
err := s.tlsConfig.Start()
|
||||
if err != nil {
|
||||
return E.Cause(err, "create TLS config")
|
||||
}
|
||||
@@ -250,155 +297,257 @@ func (s *Service) Start(stage adapter.StartStage) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) getAccessToken() (string, error) {
|
||||
s.accessMutex.RLock()
|
||||
if !s.credentials.needsRefresh() {
|
||||
token := s.credentials.AccessToken
|
||||
s.accessMutex.RUnlock()
|
||||
return token, nil
|
||||
func isExtendedContextRequest(betaHeader string) bool {
|
||||
for _, feature := range strings.Split(betaHeader, ",") {
|
||||
if strings.HasPrefix(strings.TrimSpace(feature), "context-1m") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
s.accessMutex.RUnlock()
|
||||
return false
|
||||
}
|
||||
|
||||
s.accessMutex.Lock()
|
||||
defer s.accessMutex.Unlock()
|
||||
|
||||
if !s.credentials.needsRefresh() {
|
||||
return s.credentials.AccessToken, nil
|
||||
func isFastModeRequest(betaHeader string) bool {
|
||||
for _, feature := range strings.Split(betaHeader, ",") {
|
||||
if strings.HasPrefix(strings.TrimSpace(feature), "fast-mode") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
newCredentials, err := refreshToken(s.httpClient, s.credentials)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
s.credentials = newCredentials
|
||||
|
||||
err = platformWriteCredentials(newCredentials, s.credentialPath)
|
||||
if err != nil {
|
||||
s.logger.Warn("persist refreshed token: ", err)
|
||||
}
|
||||
|
||||
return newCredentials.AccessToken, nil
|
||||
return false
|
||||
}
|
||||
|
||||
func detectContextWindow(betaHeader string, totalInputTokens int64) int {
|
||||
if totalInputTokens > premiumContextThreshold {
|
||||
features := strings.Split(betaHeader, ",")
|
||||
for _, feature := range features {
|
||||
if strings.HasPrefix(strings.TrimSpace(feature), "context-1m") {
|
||||
return contextWindowPremium
|
||||
}
|
||||
if isExtendedContextRequest(betaHeader) {
|
||||
return contextWindowPremium
|
||||
}
|
||||
}
|
||||
return contextWindowStandard
|
||||
}
|
||||
|
||||
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := log.ContextWithNewID(r.Context())
|
||||
if r.URL.Path == "/ccm/v1/status" {
|
||||
s.handleStatusEndpoint(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Path == "/ccm/v1/reverse" {
|
||||
s.handleReverseConnect(ctx, w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(r.URL.Path, "/v1/") {
|
||||
writeJSONError(w, r, http.StatusNotFound, "not_found_error", "Not found")
|
||||
return
|
||||
}
|
||||
|
||||
var username string
|
||||
if len(s.users) > 0 {
|
||||
if len(s.options.Users) > 0 {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
|
||||
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
||||
return
|
||||
}
|
||||
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if clientToken == authHeader {
|
||||
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
|
||||
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
||||
return
|
||||
}
|
||||
var ok bool
|
||||
username, ok = s.userManager.Authenticate(clientToken)
|
||||
if !ok {
|
||||
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
|
||||
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Always read body to extract model and session ID
|
||||
var bodyBytes []byte
|
||||
var requestModel string
|
||||
var messagesCount int
|
||||
var sessionID string
|
||||
|
||||
if s.usageTracker != nil && r.Body != nil {
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
if err == nil {
|
||||
var request struct {
|
||||
Model string `json:"model"`
|
||||
Messages []anthropic.MessageParam `json:"messages"`
|
||||
}
|
||||
err := json.Unmarshal(bodyBytes, &request)
|
||||
if err == nil {
|
||||
requestModel = request.Model
|
||||
messagesCount = len(request.Messages)
|
||||
}
|
||||
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
if r.Body != nil {
|
||||
var err error
|
||||
bodyBytes, err = io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "read request body: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body")
|
||||
return
|
||||
}
|
||||
|
||||
var request struct {
|
||||
Model string `json:"model"`
|
||||
Messages []anthropic.MessageParam `json:"messages"`
|
||||
}
|
||||
err = json.Unmarshal(bodyBytes, &request)
|
||||
if err == nil {
|
||||
requestModel = request.Model
|
||||
messagesCount = len(request.Messages)
|
||||
}
|
||||
|
||||
sessionID = extractCCMSessionID(bodyBytes)
|
||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
}
|
||||
|
||||
accessToken, err := s.getAccessToken()
|
||||
if err != nil {
|
||||
s.logger.Error("get access token: ", err)
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "Authentication failed")
|
||||
// Resolve credential provider and user config
|
||||
var provider credentialProvider
|
||||
var userConfig *option.CCMUser
|
||||
if len(s.options.Users) > 0 {
|
||||
userConfig = s.userConfigMap[username]
|
||||
var err error
|
||||
provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "resolve credential: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
} else {
|
||||
provider = noUserCredentialProvider(s.providers, s.legacyProvider, s.options)
|
||||
}
|
||||
if provider == nil {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available")
|
||||
return
|
||||
}
|
||||
|
||||
proxyURL := claudeAPIBaseURL + r.URL.RequestURI()
|
||||
proxyRequest, err := http.NewRequestWithContext(r.Context(), r.Method, proxyURL, r.Body)
|
||||
provider.pollIfStale(s.ctx)
|
||||
|
||||
anthropicBetaHeader := r.Header.Get("anthropic-beta")
|
||||
if isFastModeRequest(anthropicBetaHeader) {
|
||||
if _, isSingle := provider.(*singleCredentialProvider); !isSingle {
|
||||
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
|
||||
"fast mode requests will consume Extra usage, please use a default credential directly")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var credentialFilter func(credential) bool
|
||||
if userConfig != nil && !userConfig.AllowExternalUsage {
|
||||
credentialFilter = func(c credential) bool { return !c.isExternal() }
|
||||
}
|
||||
|
||||
selectedCredential, isNew, err := provider.selectCredential(sessionID, credentialFilter)
|
||||
if err != nil {
|
||||
s.logger.Error("create proxy request: ", err)
|
||||
writeNonRetryableCredentialError(w, r, unavailableCredentialMessage(provider, err.Error()))
|
||||
return
|
||||
}
|
||||
if isNew {
|
||||
logParts := []any{"assigned credential ", selectedCredential.tagName()}
|
||||
if sessionID != "" {
|
||||
logParts = append(logParts, " for session ", sessionID)
|
||||
}
|
||||
if username != "" {
|
||||
logParts = append(logParts, " by user ", username)
|
||||
}
|
||||
if requestModel != "" {
|
||||
modelDisplay := requestModel
|
||||
if isExtendedContextRequest(anthropicBetaHeader) {
|
||||
modelDisplay += "[1m]"
|
||||
}
|
||||
logParts = append(logParts, ", model=", modelDisplay)
|
||||
}
|
||||
s.logger.DebugContext(ctx, logParts...)
|
||||
}
|
||||
|
||||
if isFastModeRequest(anthropicBetaHeader) && selectedCredential.isExternal() {
|
||||
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
|
||||
"fast mode requests cannot be proxied through external credentials")
|
||||
return
|
||||
}
|
||||
|
||||
requestContext := selectedCredential.wrapRequestContext(r.Context())
|
||||
defer func() {
|
||||
requestContext.cancelRequest()
|
||||
}()
|
||||
proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "create proxy request: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
|
||||
return
|
||||
}
|
||||
|
||||
for key, values := range r.Header {
|
||||
if !isHopByHopHeader(key) && key != "Authorization" {
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
}
|
||||
|
||||
serviceOverridesAcceptEncoding := len(s.httpHeaders.Values("Accept-Encoding")) > 0
|
||||
if s.usageTracker != nil && !serviceOverridesAcceptEncoding {
|
||||
// Strip Accept-Encoding so Go Transport adds it automatically
|
||||
// and transparently decompresses the response for correct usage counting.
|
||||
proxyRequest.Header.Del("Accept-Encoding")
|
||||
}
|
||||
|
||||
anthropicBetaHeader := proxyRequest.Header.Get("anthropic-beta")
|
||||
if anthropicBetaHeader != "" {
|
||||
proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue+","+anthropicBetaHeader)
|
||||
} else {
|
||||
proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue)
|
||||
}
|
||||
|
||||
for key, values := range s.httpHeaders {
|
||||
proxyRequest.Header.Del(key)
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
|
||||
proxyRequest.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
response, err := s.httpClient.Do(proxyRequest)
|
||||
response, err := selectedCredential.httpTransport().Do(proxyRequest)
|
||||
if err != nil {
|
||||
if r.Context().Err() != nil {
|
||||
return
|
||||
}
|
||||
if requestContext.Err() != nil {
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "credential became unavailable while processing the request")
|
||||
return
|
||||
}
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
requestContext.releaseCredentialInterrupt()
|
||||
|
||||
// Transparent 429 retry
|
||||
for response.StatusCode == http.StatusTooManyRequests {
|
||||
resetAt := parseRateLimitResetFromHeaders(response.Header)
|
||||
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, credentialFilter)
|
||||
selectedCredential.updateStateFromHeaders(response.Header)
|
||||
if bodyBytes == nil || nextCredential == nil {
|
||||
response.Body.Close()
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited")
|
||||
return
|
||||
}
|
||||
response.Body.Close()
|
||||
s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
|
||||
requestContext.cancelRequest()
|
||||
requestContext = nextCredential.wrapRequestContext(r.Context())
|
||||
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||
if buildErr != nil {
|
||||
s.logger.ErrorContext(ctx, "retry request: ", buildErr)
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
|
||||
return
|
||||
}
|
||||
retryResponse, retryErr := nextCredential.httpTransport().Do(retryRequest)
|
||||
if retryErr != nil {
|
||||
if r.Context().Err() != nil {
|
||||
return
|
||||
}
|
||||
if requestContext.Err() != nil {
|
||||
writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request")
|
||||
return
|
||||
}
|
||||
s.logger.ErrorContext(ctx, "retry request: ", retryErr)
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error())
|
||||
return
|
||||
}
|
||||
requestContext.releaseCredentialInterrupt()
|
||||
response = retryResponse
|
||||
selectedCredential = nextCredential
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
selectedCredential.updateStateFromHeaders(response.Header)
|
||||
|
||||
if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
|
||||
go selectedCredential.pollUsage(s.ctx)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error",
|
||||
"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
|
||||
return
|
||||
}
|
||||
|
||||
// Rewrite response headers for external users
|
||||
if userConfig != nil && userConfig.ExternalCredential != "" {
|
||||
s.rewriteResponseHeadersForExternalUser(response.Header, userConfig)
|
||||
}
|
||||
|
||||
for key, values := range response.Header {
|
||||
if !isHopByHopHeader(key) {
|
||||
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) {
|
||||
w.Header()[key] = values
|
||||
}
|
||||
}
|
||||
w.WriteHeader(response.StatusCode)
|
||||
|
||||
if s.usageTracker != nil && response.StatusCode == http.StatusOK {
|
||||
s.handleResponseWithTracking(w, response, requestModel, anthropicBetaHeader, messagesCount, username)
|
||||
usageTracker := selectedCredential.usageTrackerOrNil()
|
||||
if usageTracker != nil && response.StatusCode == http.StatusOK {
|
||||
s.handleResponseWithTracking(ctx, w, response, usageTracker, requestModel, anthropicBetaHeader, messagesCount, username)
|
||||
} else {
|
||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||
if err == nil && mediaType != "text/event-stream" {
|
||||
@@ -407,7 +556,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
s.logger.Error("streaming not supported")
|
||||
s.logger.ErrorContext(ctx, "streaming not supported")
|
||||
return
|
||||
}
|
||||
buffer := make([]byte, buf.BufferSize)
|
||||
@@ -416,7 +565,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if n > 0 {
|
||||
_, writeError := w.Write(buffer[:n])
|
||||
if writeError != nil {
|
||||
s.logger.Error("write streaming response: ", writeError)
|
||||
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
@@ -428,7 +577,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, requestModel string, anthropicBetaHeader string, messagesCount int, username string) {
|
||||
func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, requestModel string, anthropicBetaHeader string, messagesCount int, username string) {
|
||||
weeklyCycleHint := extractWeeklyCycleHint(response.Header)
|
||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||
isStreaming := err == nil && mediaType == "text/event-stream"
|
||||
@@ -436,7 +585,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
||||
if !isStreaming {
|
||||
bodyBytes, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
s.logger.Error("read response body: ", err)
|
||||
s.logger.ErrorContext(ctx, "read response body: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -456,7 +605,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
||||
if responseModel != "" {
|
||||
totalInputTokens := usage.InputTokens + usage.CacheCreationInputTokens + usage.CacheReadInputTokens
|
||||
contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens)
|
||||
s.usageTracker.AddUsageWithCycleHint(
|
||||
usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
contextWindow,
|
||||
messagesCount,
|
||||
@@ -479,7 +628,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
||||
|
||||
flusher, ok := writer.(http.Flusher)
|
||||
if !ok {
|
||||
s.logger.Error("streaming not supported")
|
||||
s.logger.ErrorContext(ctx, "streaming not supported")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -542,7 +691,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
||||
|
||||
_, writeError := writer.Write(buffer[:n])
|
||||
if writeError != nil {
|
||||
s.logger.Error("write streaming response: ", writeError)
|
||||
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
@@ -557,7 +706,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
||||
if responseModel != "" {
|
||||
totalInputTokens := accumulatedUsage.InputTokens + accumulatedUsage.CacheCreationInputTokens + accumulatedUsage.CacheReadInputTokens
|
||||
contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens)
|
||||
s.usageTracker.AddUsageWithCycleHint(
|
||||
usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
contextWindow,
|
||||
messagesCount,
|
||||
@@ -578,6 +727,120 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
if len(s.options.Users) == 0 {
|
||||
writeJSONError(w, r, http.StatusForbidden, "authentication_error", "status endpoint requires user authentication")
|
||||
return
|
||||
}
|
||||
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
||||
return
|
||||
}
|
||||
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if clientToken == authHeader {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
||||
return
|
||||
}
|
||||
username, ok := s.userManager.Authenticate(clientToken)
|
||||
if !ok {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
|
||||
return
|
||||
}
|
||||
|
||||
userConfig := s.userConfigMap[username]
|
||||
if userConfig == nil {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "user config not found")
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
|
||||
if err != nil {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
provider.pollIfStale(r.Context())
|
||||
avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]float64{
|
||||
"five_hour_utilization": avgFiveHour,
|
||||
"weekly_utilization": avgWeekly,
|
||||
"plan_weight": totalWeight,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) (float64, float64, float64) {
|
||||
var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64
|
||||
for _, cred := range provider.allCredentials() {
|
||||
if !cred.isAvailable() {
|
||||
continue
|
||||
}
|
||||
if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential {
|
||||
continue
|
||||
}
|
||||
if !userConfig.AllowExternalUsage && cred.isExternal() {
|
||||
continue
|
||||
}
|
||||
weight := cred.planWeight()
|
||||
remaining5h := cred.fiveHourCap() - cred.fiveHourUtilization()
|
||||
if remaining5h < 0 {
|
||||
remaining5h = 0
|
||||
}
|
||||
remainingWeekly := cred.weeklyCap() - cred.weeklyUtilization()
|
||||
if remainingWeekly < 0 {
|
||||
remainingWeekly = 0
|
||||
}
|
||||
totalWeightedRemaining5h += remaining5h * weight
|
||||
totalWeightedRemainingWeekly += remainingWeekly * weight
|
||||
totalWeight += weight
|
||||
}
|
||||
if totalWeight == 0 {
|
||||
return 100, 100, 0
|
||||
}
|
||||
return 100 - totalWeightedRemaining5h/totalWeight,
|
||||
100 - totalWeightedRemainingWeekly/totalWeight,
|
||||
totalWeight
|
||||
}
|
||||
|
||||
func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.CCMUser) {
|
||||
provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, userConfig.Name)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
|
||||
|
||||
// Rewrite utilization headers to aggregated average (convert back to 0.0-1.0 range)
|
||||
headers.Set("anthropic-ratelimit-unified-5h-utilization", strconv.FormatFloat(avgFiveHour/100, 'f', 6, 64))
|
||||
headers.Set("anthropic-ratelimit-unified-7d-utilization", strconv.FormatFloat(avgWeekly/100, 'f', 6, 64))
|
||||
if totalWeight > 0 {
|
||||
headers.Set("X-CCM-Plan-Weight", strconv.FormatFloat(totalWeight, 'f', -1, 64))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) InterfaceUpdated() {
|
||||
for _, cred := range s.allCredentials {
|
||||
extCred, ok := cred.(*externalCredential)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if extCred.reverse && extCred.connectorURL != nil {
|
||||
extCred.reverseService = s
|
||||
extCred.resetReverseContext()
|
||||
go extCred.connectorLoop()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Close() error {
|
||||
err := common.Close(
|
||||
common.PtrOrNil(s.httpServer),
|
||||
@@ -585,12 +848,8 @@ func (s *Service) Close() error {
|
||||
s.tlsConfig,
|
||||
)
|
||||
|
||||
if s.usageTracker != nil {
|
||||
s.usageTracker.cancelPendingSave()
|
||||
saveErr := s.usageTracker.Save()
|
||||
if saveErr != nil {
|
||||
s.logger.Error("save usage statistics: ", saveErr)
|
||||
}
|
||||
for _, cred := range s.allCredentials {
|
||||
cred.close()
|
||||
}
|
||||
|
||||
return err
|
||||
|
||||
@@ -2,6 +2,7 @@ package ocm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -55,6 +56,14 @@ func readCredentialsFromFile(path string) (*oauthCredentials, error) {
|
||||
return &credentials, nil
|
||||
}
|
||||
|
||||
func checkCredentialFileWritable(path string) error {
|
||||
file, err := os.OpenFile(path, os.O_WRONLY, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return file.Close()
|
||||
}
|
||||
|
||||
func writeCredentialsToFile(credentials *oauthCredentials, path string) error {
|
||||
data, err := json.MarshalIndent(credentials, "", " ")
|
||||
if err != nil {
|
||||
@@ -110,7 +119,7 @@ func (c *oauthCredentials) needsRefresh() bool {
|
||||
return time.Since(*c.LastRefresh) >= time.Duration(tokenRefreshIntervalDays)*24*time.Hour
|
||||
}
|
||||
|
||||
func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) {
|
||||
func refreshToken(ctx context.Context, httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) {
|
||||
if credentials.Tokens == nil || credentials.Tokens.RefreshToken == "" {
|
||||
return nil, E.New("refresh token is empty")
|
||||
}
|
||||
@@ -125,19 +134,24 @@ func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oaut
|
||||
return nil, E.Cause(err, "marshal request")
|
||||
}
|
||||
|
||||
request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request.Header.Set("Accept", "application/json")
|
||||
|
||||
response, err := httpClient.Do(request)
|
||||
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
|
||||
request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request.Header.Set("Accept", "application/json")
|
||||
return request, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode == http.StatusTooManyRequests {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return nil, E.New("refresh rate limited: ", response.Status, " ", string(body))
|
||||
}
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return nil, E.New("refresh failed: ", response.Status, " ", string(body))
|
||||
@@ -171,3 +185,41 @@ func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oaut
|
||||
|
||||
return &newCredentials, nil
|
||||
}
|
||||
|
||||
func cloneCredentials(credentials *oauthCredentials) *oauthCredentials {
|
||||
if credentials == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *credentials
|
||||
if credentials.Tokens != nil {
|
||||
clonedTokens := *credentials.Tokens
|
||||
cloned.Tokens = &clonedTokens
|
||||
}
|
||||
if credentials.LastRefresh != nil {
|
||||
lastRefresh := *credentials.LastRefresh
|
||||
cloned.LastRefresh = &lastRefresh
|
||||
}
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func credentialsEqual(left *oauthCredentials, right *oauthCredentials) bool {
|
||||
if left == nil || right == nil {
|
||||
return left == right
|
||||
}
|
||||
if left.APIKey != right.APIKey {
|
||||
return false
|
||||
}
|
||||
if (left.Tokens == nil) != (right.Tokens == nil) {
|
||||
return false
|
||||
}
|
||||
if left.Tokens != nil && *left.Tokens != *right.Tokens {
|
||||
return false
|
||||
}
|
||||
if (left.LastRefresh == nil) != (right.LastRefresh == nil) {
|
||||
return false
|
||||
}
|
||||
if left.LastRefresh != nil && !left.LastRefresh.Equal(*right.LastRefresh) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -13,6 +13,17 @@ func platformReadCredentials(customPath string) (*oauthCredentials, error) {
|
||||
return readCredentialsFromFile(customPath)
|
||||
}
|
||||
|
||||
func platformCanWriteCredentials(customPath string) error {
|
||||
if customPath == "" {
|
||||
var err error
|
||||
customPath, err = getDefaultCredentialsPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return checkCredentialFileWritable(customPath)
|
||||
}
|
||||
|
||||
func platformWriteCredentials(credentials *oauthCredentials, customPath string) error {
|
||||
if customPath == "" {
|
||||
var err error
|
||||
|
||||
729
service/ocm/credential_external.go
Normal file
729
service/ocm/credential_external.go
Normal file
@@ -0,0 +1,729 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
stdTLS "crypto/tls"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
)
|
||||
|
||||
const reverseProxyBaseURL = "http://reverse-proxy"
|
||||
|
||||
type externalCredential struct {
|
||||
tag string
|
||||
baseURL string
|
||||
token string
|
||||
credDialer N.Dialer
|
||||
httpClient *http.Client
|
||||
state credentialState
|
||||
stateMutex sync.RWMutex
|
||||
pollAccess sync.Mutex
|
||||
pollInterval time.Duration
|
||||
usageTracker *AggregatedUsage
|
||||
logger log.ContextLogger
|
||||
|
||||
onBecameUnusable func()
|
||||
interrupted bool
|
||||
requestContext context.Context
|
||||
cancelRequests context.CancelFunc
|
||||
requestAccess sync.Mutex
|
||||
|
||||
// Reverse proxy fields
|
||||
reverse bool
|
||||
reverseSession *yamux.Session
|
||||
reverseAccess sync.RWMutex
|
||||
closed bool
|
||||
reverseContext context.Context
|
||||
reverseCancel context.CancelFunc
|
||||
connectorDialer N.Dialer
|
||||
connectorDestination M.Socksaddr
|
||||
connectorRequestPath string
|
||||
connectorURL *url.URL
|
||||
connectorTLS *stdTLS.Config
|
||||
reverseService http.Handler
|
||||
}
|
||||
|
||||
type reverseSessionDialer struct {
|
||||
credential *externalCredential
|
||||
}
|
||||
|
||||
func (d reverseSessionDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
if N.NetworkName(network) != N.NetworkTCP {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
return d.credential.openReverseConnection(ctx)
|
||||
}
|
||||
|
||||
func (d reverseSessionDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
|
||||
func externalCredentialURLPort(parsedURL *url.URL) uint16 {
|
||||
portStr := parsedURL.Port()
|
||||
if portStr != "" {
|
||||
port, err := strconv.ParseUint(portStr, 10, 16)
|
||||
if err == nil {
|
||||
return uint16(port)
|
||||
}
|
||||
}
|
||||
if parsedURL.Scheme == "https" {
|
||||
return 443
|
||||
}
|
||||
return 80
|
||||
}
|
||||
|
||||
func externalCredentialServerPort(parsedURL *url.URL, configuredPort uint16) uint16 {
|
||||
if configuredPort != 0 {
|
||||
return configuredPort
|
||||
}
|
||||
return externalCredentialURLPort(parsedURL)
|
||||
}
|
||||
|
||||
func externalCredentialBaseURL(parsedURL *url.URL) string {
|
||||
baseURL := parsedURL.Scheme + "://" + parsedURL.Host
|
||||
if parsedURL.Path != "" && parsedURL.Path != "/" {
|
||||
baseURL += parsedURL.Path
|
||||
}
|
||||
if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' {
|
||||
baseURL = baseURL[:len(baseURL)-1]
|
||||
}
|
||||
return baseURL
|
||||
}
|
||||
|
||||
func externalCredentialReversePath(parsedURL *url.URL, endpointPath string) string {
|
||||
pathPrefix := parsedURL.EscapedPath()
|
||||
if pathPrefix == "/" {
|
||||
pathPrefix = ""
|
||||
} else {
|
||||
pathPrefix = strings.TrimSuffix(pathPrefix, "/")
|
||||
}
|
||||
return pathPrefix + endpointPath
|
||||
}
|
||||
|
||||
func newExternalCredential(ctx context.Context, tag string, options option.OCMExternalCredentialOptions, logger log.ContextLogger) (*externalCredential, error) {
|
||||
pollInterval := time.Duration(options.PollInterval)
|
||||
if pollInterval <= 0 {
|
||||
pollInterval = 30 * time.Minute
|
||||
}
|
||||
|
||||
requestContext, cancelRequests := context.WithCancel(context.Background())
|
||||
reverseContext, reverseCancel := context.WithCancel(context.Background())
|
||||
|
||||
cred := &externalCredential{
|
||||
tag: tag,
|
||||
token: options.Token,
|
||||
pollInterval: pollInterval,
|
||||
logger: logger,
|
||||
requestContext: requestContext,
|
||||
cancelRequests: cancelRequests,
|
||||
reverse: options.Reverse,
|
||||
reverseContext: reverseContext,
|
||||
reverseCancel: reverseCancel,
|
||||
}
|
||||
|
||||
if options.URL == "" {
|
||||
// Receiver mode: no URL, wait for reverse connection
|
||||
cred.baseURL = reverseProxyBaseURL
|
||||
cred.credDialer = reverseSessionDialer{credential: cred}
|
||||
cred.httpClient = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
ForceAttemptHTTP2: false,
|
||||
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return cred.openReverseConnection(ctx)
|
||||
},
|
||||
},
|
||||
}
|
||||
} else {
|
||||
// Normal or connector mode: has URL
|
||||
parsedURL, err := url.Parse(options.URL)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "parse url for credential ", tag)
|
||||
}
|
||||
|
||||
credentialDialer, err := dialer.NewWithOptions(dialer.Options{
|
||||
Context: ctx,
|
||||
Options: option.DialerOptions{
|
||||
Detour: options.Detour,
|
||||
},
|
||||
RemoteIsDomain: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "create dialer for credential ", tag)
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
ForceAttemptHTTP2: true,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
if options.Server != "" {
|
||||
destination := M.ParseSocksaddrHostPort(options.Server, externalCredentialServerPort(parsedURL, options.ServerPort))
|
||||
return credentialDialer.DialContext(ctx, network, destination)
|
||||
}
|
||||
return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
},
|
||||
}
|
||||
|
||||
if parsedURL.Scheme == "https" {
|
||||
transport.TLSClientConfig = &stdTLS.Config{
|
||||
ServerName: parsedURL.Hostname(),
|
||||
RootCAs: adapter.RootPoolFromContext(ctx),
|
||||
Time: ntp.TimeFuncFromContext(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
cred.baseURL = externalCredentialBaseURL(parsedURL)
|
||||
|
||||
if options.Reverse {
|
||||
// Connector mode: we dial out to serve, not to proxy
|
||||
cred.connectorDialer = credentialDialer
|
||||
if options.Server != "" {
|
||||
cred.connectorDestination = M.ParseSocksaddrHostPort(options.Server, externalCredentialServerPort(parsedURL, options.ServerPort))
|
||||
} else {
|
||||
cred.connectorDestination = M.ParseSocksaddrHostPort(parsedURL.Hostname(), externalCredentialURLPort(parsedURL))
|
||||
}
|
||||
cred.connectorRequestPath = externalCredentialReversePath(parsedURL, "/ocm/v1/reverse")
|
||||
cred.connectorURL = parsedURL
|
||||
if parsedURL.Scheme == "https" {
|
||||
cred.connectorTLS = &stdTLS.Config{
|
||||
ServerName: parsedURL.Hostname(),
|
||||
RootCAs: adapter.RootPoolFromContext(ctx),
|
||||
Time: ntp.TimeFuncFromContext(ctx),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Normal mode: standard HTTP client for proxying
|
||||
cred.credDialer = credentialDialer
|
||||
cred.httpClient = &http.Client{Transport: transport}
|
||||
}
|
||||
}
|
||||
|
||||
if options.UsagesPath != "" {
|
||||
cred.usageTracker = &AggregatedUsage{
|
||||
LastUpdated: time.Now(),
|
||||
Combinations: make([]CostCombination, 0),
|
||||
filePath: options.UsagesPath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
return cred, nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) start() error {
|
||||
if c.usageTracker != nil {
|
||||
err := c.usageTracker.Load()
|
||||
if err != nil {
|
||||
c.logger.Warn("load usage statistics for ", c.tag, ": ", err)
|
||||
}
|
||||
}
|
||||
if c.reverse && c.connectorURL != nil {
|
||||
go c.connectorLoop()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) setOnBecameUnusable(fn func()) {
|
||||
c.onBecameUnusable = fn
|
||||
}
|
||||
|
||||
func (c *externalCredential) tagName() string {
|
||||
return c.tag
|
||||
}
|
||||
|
||||
func (c *externalCredential) isExternal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *externalCredential) isAvailable() bool {
|
||||
return c.unavailableError() == nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) isUsable() bool {
|
||||
if !c.isAvailable() {
|
||||
return false
|
||||
}
|
||||
c.stateMutex.RLock()
|
||||
if c.state.consecutivePollFailures > 0 {
|
||||
c.stateMutex.RUnlock()
|
||||
return false
|
||||
}
|
||||
if c.state.hardRateLimited {
|
||||
if time.Now().Before(c.state.rateLimitResetAt) {
|
||||
c.stateMutex.RUnlock()
|
||||
return false
|
||||
}
|
||||
c.stateMutex.RUnlock()
|
||||
c.stateMutex.Lock()
|
||||
if c.state.hardRateLimited && !time.Now().Before(c.state.rateLimitResetAt) {
|
||||
c.state.hardRateLimited = false
|
||||
}
|
||||
usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100
|
||||
c.stateMutex.Unlock()
|
||||
return usable
|
||||
}
|
||||
usable := c.state.fiveHourUtilization < 100 && c.state.weeklyUtilization < 100
|
||||
c.stateMutex.RUnlock()
|
||||
return usable
|
||||
}
|
||||
|
||||
func (c *externalCredential) fiveHourUtilization() float64 {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
return c.state.fiveHourUtilization
|
||||
}
|
||||
|
||||
func (c *externalCredential) weeklyUtilization() float64 {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
return c.state.weeklyUtilization
|
||||
}
|
||||
|
||||
func (c *externalCredential) fiveHourCap() float64 {
|
||||
return 100
|
||||
}
|
||||
|
||||
func (c *externalCredential) weeklyCap() float64 {
|
||||
return 100
|
||||
}
|
||||
|
||||
func (c *externalCredential) planWeight() float64 {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
if c.state.remotePlanWeight > 0 {
|
||||
return c.state.remotePlanWeight
|
||||
}
|
||||
return 10
|
||||
}
|
||||
|
||||
func (c *externalCredential) weeklyResetTime() time.Time {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
return c.state.weeklyReset
|
||||
}
|
||||
|
||||
func (c *externalCredential) markRateLimited(resetAt time.Time) {
|
||||
c.logger.Warn("rate limited for ", c.tag, ", reset in ", log.FormatDuration(time.Until(resetAt)))
|
||||
c.stateMutex.Lock()
|
||||
c.state.hardRateLimited = true
|
||||
c.state.rateLimitResetAt = resetAt
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) earliestReset() time.Time {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
if c.state.hardRateLimited {
|
||||
return c.state.rateLimitResetAt
|
||||
}
|
||||
earliest := c.state.fiveHourReset
|
||||
if !c.state.weeklyReset.IsZero() && (earliest.IsZero() || c.state.weeklyReset.Before(earliest)) {
|
||||
earliest = c.state.weeklyReset
|
||||
}
|
||||
return earliest
|
||||
}
|
||||
|
||||
func (c *externalCredential) unavailableError() error {
|
||||
if c.reverse && c.connectorURL != nil {
|
||||
return E.New("credential ", c.tag, " is unavailable: reverse connector credentials cannot serve local requests")
|
||||
}
|
||||
if c.baseURL == reverseProxyBaseURL {
|
||||
session := c.getReverseSession()
|
||||
if session == nil || session.IsClosed() {
|
||||
return E.New("credential ", c.tag, " is unavailable: reverse connection not established")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) getAccessToken() (string, error) {
|
||||
return c.token, nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) buildProxyRequest(ctx context.Context, original *http.Request, bodyBytes []byte, _ http.Header) (*http.Request, error) {
|
||||
proxyURL := c.baseURL + original.URL.RequestURI()
|
||||
var body io.Reader
|
||||
if bodyBytes != nil {
|
||||
body = bytes.NewReader(bodyBytes)
|
||||
} else {
|
||||
body = original.Body
|
||||
}
|
||||
proxyRequest, err := http.NewRequestWithContext(ctx, original.Method, proxyURL, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for key, values := range original.Header {
|
||||
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) && key != "Authorization" {
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
}
|
||||
|
||||
proxyRequest.Header.Set("Authorization", "Bearer "+c.token)
|
||||
|
||||
return proxyRequest, nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) openReverseConnection(ctx context.Context) (net.Conn, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
session := c.getReverseSession()
|
||||
if session == nil || session.IsClosed() {
|
||||
return nil, E.New("reverse connection not established for ", c.tag)
|
||||
}
|
||||
conn, err := session.Open()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
conn.Close()
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) updateStateFromHeaders(headers http.Header) {
|
||||
c.stateMutex.Lock()
|
||||
isFirstUpdate := c.state.lastUpdated.IsZero()
|
||||
oldFiveHour := c.state.fiveHourUtilization
|
||||
oldWeekly := c.state.weeklyUtilization
|
||||
hadData := false
|
||||
|
||||
activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
|
||||
if activeLimitIdentifier == "" {
|
||||
activeLimitIdentifier = "codex"
|
||||
}
|
||||
|
||||
fiveHourResetAt := headers.Get("x-" + activeLimitIdentifier + "-primary-reset-at")
|
||||
if fiveHourResetAt != "" {
|
||||
value, err := strconv.ParseInt(fiveHourResetAt, 10, 64)
|
||||
if err == nil {
|
||||
hadData = true
|
||||
c.state.fiveHourReset = time.Unix(value, 0)
|
||||
}
|
||||
}
|
||||
fiveHourPercent := headers.Get("x-" + activeLimitIdentifier + "-primary-used-percent")
|
||||
if fiveHourPercent != "" {
|
||||
value, err := strconv.ParseFloat(fiveHourPercent, 64)
|
||||
if err == nil {
|
||||
hadData = true
|
||||
c.state.fiveHourUtilization = value
|
||||
}
|
||||
}
|
||||
|
||||
weeklyResetAt := headers.Get("x-" + activeLimitIdentifier + "-secondary-reset-at")
|
||||
if weeklyResetAt != "" {
|
||||
value, err := strconv.ParseInt(weeklyResetAt, 10, 64)
|
||||
if err == nil {
|
||||
hadData = true
|
||||
c.state.weeklyReset = time.Unix(value, 0)
|
||||
}
|
||||
}
|
||||
weeklyPercent := headers.Get("x-" + activeLimitIdentifier + "-secondary-used-percent")
|
||||
if weeklyPercent != "" {
|
||||
value, err := strconv.ParseFloat(weeklyPercent, 64)
|
||||
if err == nil {
|
||||
hadData = true
|
||||
c.state.weeklyUtilization = value
|
||||
}
|
||||
}
|
||||
if planWeight := headers.Get("X-OCM-Plan-Weight"); planWeight != "" {
|
||||
value, err := strconv.ParseFloat(planWeight, 64)
|
||||
if err == nil && value > 0 {
|
||||
c.state.remotePlanWeight = value
|
||||
}
|
||||
}
|
||||
if hadData {
|
||||
c.state.consecutivePollFailures = 0
|
||||
c.state.lastUpdated = time.Now()
|
||||
}
|
||||
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
|
||||
resetSuffix := ""
|
||||
if !c.state.weeklyReset.IsZero() {
|
||||
resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset))
|
||||
}
|
||||
c.logger.Debug("usage update for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
|
||||
}
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) checkTransitionLocked() bool {
|
||||
unusable := c.state.hardRateLimited || c.state.fiveHourUtilization >= 100 || c.state.weeklyUtilization >= 100 || c.state.consecutivePollFailures > 0
|
||||
if unusable && !c.interrupted {
|
||||
c.interrupted = true
|
||||
return true
|
||||
}
|
||||
if !unusable && c.interrupted {
|
||||
c.interrupted = false
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *externalCredential) wrapRequestContext(parent context.Context) *credentialRequestContext {
|
||||
c.requestAccess.Lock()
|
||||
credentialContext := c.requestContext
|
||||
c.requestAccess.Unlock()
|
||||
derived, cancel := context.WithCancel(parent)
|
||||
stop := context.AfterFunc(credentialContext, func() {
|
||||
cancel()
|
||||
})
|
||||
return &credentialRequestContext{
|
||||
Context: derived,
|
||||
releaseFunc: stop,
|
||||
cancelFunc: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) interruptConnections() {
|
||||
c.logger.Warn("interrupting connections for ", c.tag)
|
||||
c.requestAccess.Lock()
|
||||
c.cancelRequests()
|
||||
c.requestContext, c.cancelRequests = context.WithCancel(context.Background())
|
||||
c.requestAccess.Unlock()
|
||||
if c.onBecameUnusable != nil {
|
||||
c.onBecameUnusable()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) pollUsage(ctx context.Context) {
|
||||
if !c.pollAccess.TryLock() {
|
||||
return
|
||||
}
|
||||
defer c.pollAccess.Unlock()
|
||||
defer c.markUsagePollAttempted()
|
||||
|
||||
statusURL := c.baseURL + "/ocm/v1/status"
|
||||
httpClient := &http.Client{
|
||||
Transport: c.httpClient.Transport,
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
response, err := doHTTPWithRetry(ctx, httpClient, func() (*http.Request, error) {
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, statusURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header.Set("Authorization", "Bearer "+c.token)
|
||||
return request, nil
|
||||
})
|
||||
if err != nil {
|
||||
c.logger.Error("poll usage for ", c.tag, ": ", err)
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
|
||||
// 404 means the remote does not have a status endpoint yet;
|
||||
// usage will be updated passively from response headers.
|
||||
if response.StatusCode == http.StatusNotFound {
|
||||
c.stateMutex.Lock()
|
||||
c.state.consecutivePollFailures = 0
|
||||
c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
} else {
|
||||
c.incrementPollFailures()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var statusResponse struct {
|
||||
FiveHourUtilization float64 `json:"five_hour_utilization"`
|
||||
WeeklyUtilization float64 `json:"weekly_utilization"`
|
||||
PlanWeight float64 `json:"plan_weight"`
|
||||
}
|
||||
err = json.NewDecoder(response.Body).Decode(&statusResponse)
|
||||
if err != nil {
|
||||
c.logger.Debug("poll usage for ", c.tag, ": decode: ", err)
|
||||
c.incrementPollFailures()
|
||||
return
|
||||
}
|
||||
|
||||
c.stateMutex.Lock()
|
||||
isFirstUpdate := c.state.lastUpdated.IsZero()
|
||||
oldFiveHour := c.state.fiveHourUtilization
|
||||
oldWeekly := c.state.weeklyUtilization
|
||||
c.state.consecutivePollFailures = 0
|
||||
c.state.fiveHourUtilization = statusResponse.FiveHourUtilization
|
||||
c.state.weeklyUtilization = statusResponse.WeeklyUtilization
|
||||
if statusResponse.PlanWeight > 0 {
|
||||
c.state.remotePlanWeight = statusResponse.PlanWeight
|
||||
}
|
||||
if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) {
|
||||
c.state.hardRateLimited = false
|
||||
}
|
||||
if isFirstUpdate || int(c.state.fiveHourUtilization*100) != int(oldFiveHour*100) || int(c.state.weeklyUtilization*100) != int(oldWeekly*100) {
|
||||
resetSuffix := ""
|
||||
if !c.state.weeklyReset.IsZero() {
|
||||
resetSuffix = ", resets=" + log.FormatDuration(time.Until(c.state.weeklyReset))
|
||||
}
|
||||
c.logger.Debug("poll usage for ", c.tag, ": 5h=", c.state.fiveHourUtilization, "%, weekly=", c.state.weeklyUtilization, "%", resetSuffix)
|
||||
}
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) lastUpdatedTime() time.Time {
|
||||
c.stateMutex.RLock()
|
||||
defer c.stateMutex.RUnlock()
|
||||
return c.state.lastUpdated
|
||||
}
|
||||
|
||||
func (c *externalCredential) markUsagePollAttempted() {
|
||||
c.stateMutex.Lock()
|
||||
defer c.stateMutex.Unlock()
|
||||
c.state.lastUpdated = time.Now()
|
||||
}
|
||||
|
||||
func (c *externalCredential) pollBackoff(baseInterval time.Duration) time.Duration {
|
||||
c.stateMutex.RLock()
|
||||
failures := c.state.consecutivePollFailures
|
||||
c.stateMutex.RUnlock()
|
||||
if failures <= 0 {
|
||||
return baseInterval
|
||||
}
|
||||
return failedPollRetryInterval
|
||||
}
|
||||
|
||||
func (c *externalCredential) incrementPollFailures() {
|
||||
c.stateMutex.Lock()
|
||||
c.state.consecutivePollFailures++
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
if shouldInterrupt {
|
||||
c.interruptConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) usageTrackerOrNil() *AggregatedUsage {
|
||||
return c.usageTracker
|
||||
}
|
||||
|
||||
func (c *externalCredential) httpTransport() *http.Client {
|
||||
return c.httpClient
|
||||
}
|
||||
|
||||
func (c *externalCredential) ocmDialer() N.Dialer {
|
||||
return c.credDialer
|
||||
}
|
||||
|
||||
func (c *externalCredential) ocmIsAPIKeyMode() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *externalCredential) ocmGetAccountID() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *externalCredential) ocmGetBaseURL() string {
|
||||
return c.baseURL
|
||||
}
|
||||
|
||||
func (c *externalCredential) close() {
|
||||
var session *yamux.Session
|
||||
c.reverseAccess.Lock()
|
||||
if !c.closed {
|
||||
c.closed = true
|
||||
if c.reverseCancel != nil {
|
||||
c.reverseCancel()
|
||||
}
|
||||
session = c.reverseSession
|
||||
c.reverseSession = nil
|
||||
}
|
||||
c.reverseAccess.Unlock()
|
||||
if session != nil {
|
||||
session.Close()
|
||||
}
|
||||
if c.usageTracker != nil {
|
||||
c.usageTracker.cancelPendingSave()
|
||||
err := c.usageTracker.Save()
|
||||
if err != nil {
|
||||
c.logger.Error("save usage statistics for ", c.tag, ": ", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalCredential) getReverseSession() *yamux.Session {
|
||||
c.reverseAccess.RLock()
|
||||
defer c.reverseAccess.RUnlock()
|
||||
return c.reverseSession
|
||||
}
|
||||
|
||||
func (c *externalCredential) setReverseSession(session *yamux.Session) bool {
|
||||
c.reverseAccess.Lock()
|
||||
if c.closed {
|
||||
c.reverseAccess.Unlock()
|
||||
return false
|
||||
}
|
||||
old := c.reverseSession
|
||||
c.reverseSession = session
|
||||
c.reverseAccess.Unlock()
|
||||
if old != nil {
|
||||
old.Close()
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *externalCredential) clearReverseSession(session *yamux.Session) {
|
||||
c.reverseAccess.Lock()
|
||||
if c.reverseSession == session {
|
||||
c.reverseSession = nil
|
||||
}
|
||||
c.reverseAccess.Unlock()
|
||||
}
|
||||
|
||||
func (c *externalCredential) getReverseContext() context.Context {
|
||||
c.reverseAccess.RLock()
|
||||
defer c.reverseAccess.RUnlock()
|
||||
return c.reverseContext
|
||||
}
|
||||
|
||||
func (c *externalCredential) resetReverseContext() {
|
||||
c.reverseAccess.Lock()
|
||||
if c.closed {
|
||||
c.reverseAccess.Unlock()
|
||||
return
|
||||
}
|
||||
c.reverseCancel()
|
||||
c.reverseContext, c.reverseCancel = context.WithCancel(context.Background())
|
||||
c.reverseAccess.Unlock()
|
||||
}
|
||||
139
service/ocm/credential_file.go
Normal file
139
service/ocm/credential_file.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/fswatch"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
const credentialReloadRetryInterval = 2 * time.Second
|
||||
|
||||
func resolveCredentialFilePath(customPath string) (string, error) {
|
||||
if customPath == "" {
|
||||
var err error
|
||||
customPath, err = getDefaultCredentialsPath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
if filepath.IsAbs(customPath) {
|
||||
return customPath, nil
|
||||
}
|
||||
return filepath.Abs(customPath)
|
||||
}
|
||||
|
||||
func (c *defaultCredential) ensureCredentialWatcher() error {
|
||||
c.watcherAccess.Lock()
|
||||
defer c.watcherAccess.Unlock()
|
||||
|
||||
if c.watcher != nil || c.credentialFilePath == "" {
|
||||
return nil
|
||||
}
|
||||
if !c.watcherRetryAt.IsZero() && time.Now().Before(c.watcherRetryAt) {
|
||||
return nil
|
||||
}
|
||||
|
||||
watcher, err := fswatch.NewWatcher(fswatch.Options{
|
||||
Path: []string{c.credentialFilePath},
|
||||
Logger: c.logger,
|
||||
Callback: func(string) {
|
||||
err := c.reloadCredentials(true)
|
||||
if err != nil {
|
||||
c.logger.Warn("reload credentials for ", c.tag, ": ", err)
|
||||
}
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
c.watcherRetryAt = time.Now().Add(credentialReloadRetryInterval)
|
||||
return err
|
||||
}
|
||||
|
||||
err = watcher.Start()
|
||||
if err != nil {
|
||||
c.watcherRetryAt = time.Now().Add(credentialReloadRetryInterval)
|
||||
return err
|
||||
}
|
||||
|
||||
c.watcher = watcher
|
||||
c.watcherRetryAt = time.Time{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *defaultCredential) retryCredentialReloadIfNeeded() {
|
||||
c.stateMutex.RLock()
|
||||
unavailable := c.state.unavailable
|
||||
lastAttempt := c.state.lastCredentialLoadAttempt
|
||||
c.stateMutex.RUnlock()
|
||||
if !unavailable {
|
||||
return
|
||||
}
|
||||
if !lastAttempt.IsZero() && time.Since(lastAttempt) < credentialReloadRetryInterval {
|
||||
return
|
||||
}
|
||||
|
||||
err := c.ensureCredentialWatcher()
|
||||
if err != nil {
|
||||
c.logger.Debug("start credential watcher for ", c.tag, ": ", err)
|
||||
}
|
||||
_ = c.reloadCredentials(false)
|
||||
}
|
||||
|
||||
func (c *defaultCredential) reloadCredentials(force bool) error {
|
||||
c.reloadAccess.Lock()
|
||||
defer c.reloadAccess.Unlock()
|
||||
|
||||
c.stateMutex.RLock()
|
||||
unavailable := c.state.unavailable
|
||||
lastAttempt := c.state.lastCredentialLoadAttempt
|
||||
c.stateMutex.RUnlock()
|
||||
if !force {
|
||||
if !unavailable {
|
||||
return nil
|
||||
}
|
||||
if !lastAttempt.IsZero() && time.Since(lastAttempt) < credentialReloadRetryInterval {
|
||||
return c.unavailableError()
|
||||
}
|
||||
}
|
||||
|
||||
c.stateMutex.Lock()
|
||||
c.state.lastCredentialLoadAttempt = time.Now()
|
||||
c.stateMutex.Unlock()
|
||||
|
||||
credentials, err := platformReadCredentials(c.credentialPath)
|
||||
if err != nil {
|
||||
return c.markCredentialsUnavailable(E.Cause(err, "read credentials"))
|
||||
}
|
||||
|
||||
c.accessMutex.Lock()
|
||||
c.credentials = credentials
|
||||
c.accessMutex.Unlock()
|
||||
|
||||
c.stateMutex.Lock()
|
||||
c.state.unavailable = false
|
||||
c.state.lastCredentialLoadError = ""
|
||||
c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *defaultCredential) markCredentialsUnavailable(err error) error {
|
||||
c.accessMutex.Lock()
|
||||
hadCredentials := c.credentials != nil
|
||||
c.credentials = nil
|
||||
c.accessMutex.Unlock()
|
||||
|
||||
c.stateMutex.Lock()
|
||||
c.state.unavailable = true
|
||||
c.state.lastCredentialLoadError = err.Error()
|
||||
shouldInterrupt := c.checkTransitionLocked()
|
||||
c.stateMutex.Unlock()
|
||||
|
||||
if shouldInterrupt && hadCredentials {
|
||||
c.interruptConnections()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -13,6 +13,17 @@ func platformReadCredentials(customPath string) (*oauthCredentials, error) {
|
||||
return readCredentialsFromFile(customPath)
|
||||
}
|
||||
|
||||
func platformCanWriteCredentials(customPath string) error {
|
||||
if customPath == "" {
|
||||
var err error
|
||||
customPath, err = getDefaultCredentialsPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return checkCredentialFileWritable(customPath)
|
||||
}
|
||||
|
||||
func platformWriteCredentials(credentials *oauthCredentials, customPath string) error {
|
||||
if customPath == "" {
|
||||
var err error
|
||||
|
||||
1430
service/ocm/credential_state.go
Normal file
1430
service/ocm/credential_state.go
Normal file
File diff suppressed because it is too large
Load Diff
259
service/ocm/reverse.go
Normal file
259
service/ocm/reverse.go
Normal file
@@ -0,0 +1,259 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
stdTLS "crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
)
|
||||
|
||||
func reverseYamuxConfig() *yamux.Config {
|
||||
config := yamux.DefaultConfig()
|
||||
config.KeepAliveInterval = 15 * time.Second
|
||||
config.ConnectionWriteTimeout = 10 * time.Second
|
||||
config.MaxStreamWindowSize = 512 * 1024
|
||||
config.LogOutput = io.Discard
|
||||
return config
|
||||
}
|
||||
|
||||
type bufferedConn struct {
|
||||
reader *bufio.Reader
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c *bufferedConn) Read(p []byte) (int, error) {
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
type yamuxNetListener struct {
|
||||
session *yamux.Session
|
||||
}
|
||||
|
||||
func (l *yamuxNetListener) Accept() (net.Conn, error) {
|
||||
return l.session.Accept()
|
||||
}
|
||||
|
||||
func (l *yamuxNetListener) Close() error {
|
||||
return l.session.Close()
|
||||
}
|
||||
|
||||
func (l *yamuxNetListener) Addr() net.Addr {
|
||||
return l.session.Addr()
|
||||
}
|
||||
|
||||
func (s *Service) handleReverseConnect(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Upgrade") != "reverse-proxy" {
|
||||
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", "missing Upgrade header")
|
||||
return
|
||||
}
|
||||
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
||||
return
|
||||
}
|
||||
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if clientToken == authHeader {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
||||
return
|
||||
}
|
||||
|
||||
receiverCredential := s.findReceiverCredential(clientToken)
|
||||
if receiverCredential == nil {
|
||||
s.logger.WarnContext(ctx, "reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid reverse token")
|
||||
return
|
||||
}
|
||||
|
||||
hijacker, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
s.logger.ErrorContext(ctx, "reverse connect: hijack not supported")
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "hijack not supported")
|
||||
return
|
||||
}
|
||||
|
||||
conn, bufferedReadWriter, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "reverse connect: hijack: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
response := "HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: reverse-proxy\r\n\r\n"
|
||||
_, err = bufferedReadWriter.WriteString(response)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
s.logger.ErrorContext(ctx, "reverse connect: write upgrade response: ", err)
|
||||
return
|
||||
}
|
||||
err = bufferedReadWriter.Flush()
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
s.logger.ErrorContext(ctx, "reverse connect: flush upgrade response: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
session, err := yamux.Client(conn, reverseYamuxConfig())
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
s.logger.ErrorContext(ctx, "reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !receiverCredential.setReverseSession(session) {
|
||||
session.Close()
|
||||
return
|
||||
}
|
||||
s.logger.InfoContext(ctx, "reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr)
|
||||
|
||||
go func() {
|
||||
<-session.CloseChan()
|
||||
receiverCredential.clearReverseSession(session)
|
||||
s.logger.WarnContext(ctx, "reverse connection lost for ", receiverCredential.tagName())
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Service) findReceiverCredential(token string) *externalCredential {
|
||||
for _, cred := range s.allCredentials {
|
||||
extCred, ok := cred.(*externalCredential)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if extCred.baseURL == reverseProxyBaseURL && extCred.token == token {
|
||||
return extCred
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *externalCredential) connectorLoop() {
|
||||
var consecutiveFailures int
|
||||
ctx := c.getReverseContext()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
sessionLifetime, err := c.connectorConnect(ctx)
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if sessionLifetime >= connectorBackoffResetThreshold {
|
||||
consecutiveFailures = 0
|
||||
}
|
||||
consecutiveFailures++
|
||||
backoff := connectorBackoff(consecutiveFailures)
|
||||
c.logger.Warn("reverse connection for ", c.tag, " lost: ", err, ", reconnecting in ", backoff)
|
||||
select {
|
||||
case <-time.After(backoff):
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const connectorBackoffResetThreshold = time.Minute
|
||||
|
||||
func connectorBackoff(failures int) time.Duration {
|
||||
if failures > 5 {
|
||||
failures = 5
|
||||
}
|
||||
base := time.Second * time.Duration(1<<failures)
|
||||
if base > 30*time.Second {
|
||||
base = 30 * time.Second
|
||||
}
|
||||
jitter := time.Duration(rand.Int64N(int64(base) / 2))
|
||||
return base + jitter
|
||||
}
|
||||
|
||||
func (c *externalCredential) connectorConnect(ctx context.Context) (time.Duration, error) {
|
||||
if c.reverseService == nil {
|
||||
return 0, E.New("reverse service not initialized")
|
||||
}
|
||||
destination := c.connectorResolveDestination()
|
||||
conn, err := c.connectorDialer.DialContext(ctx, "tcp", destination)
|
||||
if err != nil {
|
||||
return 0, E.Cause(err, "dial")
|
||||
}
|
||||
|
||||
if c.connectorTLS != nil {
|
||||
tlsConn := stdTLS.Client(conn, c.connectorTLS.Clone())
|
||||
err = tlsConn.HandshakeContext(ctx)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return 0, E.Cause(err, "tls handshake")
|
||||
}
|
||||
conn = tlsConn
|
||||
}
|
||||
|
||||
upgradeRequest := "GET " + c.connectorRequestPath + " HTTP/1.1\r\n" +
|
||||
"Host: " + c.connectorURL.Host + "\r\n" +
|
||||
"Connection: Upgrade\r\n" +
|
||||
"Upgrade: reverse-proxy\r\n" +
|
||||
"Authorization: Bearer " + c.token + "\r\n" +
|
||||
"\r\n"
|
||||
_, err = io.WriteString(conn, upgradeRequest)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return 0, E.Cause(err, "write upgrade request")
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
statusLine, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return 0, E.Cause(err, "read upgrade response")
|
||||
}
|
||||
if !strings.HasPrefix(statusLine, "HTTP/1.1 101") {
|
||||
conn.Close()
|
||||
return 0, E.New("unexpected upgrade response: ", strings.TrimSpace(statusLine))
|
||||
}
|
||||
for {
|
||||
line, readErr := reader.ReadString('\n')
|
||||
if readErr != nil {
|
||||
conn.Close()
|
||||
return 0, E.Cause(readErr, "read upgrade headers")
|
||||
}
|
||||
if strings.TrimSpace(line) == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
session, err := yamux.Server(&bufferedConn{reader: reader, Conn: conn}, reverseYamuxConfig())
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return 0, E.Cause(err, "create yamux server")
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
c.logger.Info("reverse connection established for ", c.tag)
|
||||
|
||||
serveStart := time.Now()
|
||||
httpServer := &http.Server{
|
||||
Handler: c.reverseService,
|
||||
ReadTimeout: 0,
|
||||
IdleTimeout: 120 * time.Second,
|
||||
}
|
||||
err = httpServer.Serve(&yamuxNetListener{session: session})
|
||||
sessionLifetime := time.Since(serveStart)
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) && ctx.Err() == nil {
|
||||
return sessionLifetime, E.Cause(err, "serve")
|
||||
}
|
||||
return sessionLifetime, E.New("connection closed")
|
||||
}
|
||||
|
||||
func (c *externalCredential) connectorResolveDestination() M.Socksaddr {
|
||||
return c.connectorDestination
|
||||
}
|
||||
@@ -3,12 +3,10 @@ package ocm
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
stdTLS "crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"mime"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -17,7 +15,6 @@ import (
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
boxService "github.com/sagernet/sing-box/adapter/service"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
"github.com/sagernet/sing-box/common/listener"
|
||||
"github.com/sagernet/sing-box/common/tls"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
@@ -26,15 +23,14 @@ import (
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/openai/openai-go/v3"
|
||||
"github.com/openai/openai-go/v3/responses"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
)
|
||||
|
||||
func RegisterService(registry *boxService.Registry) {
|
||||
@@ -52,17 +48,85 @@ type errorDetails struct {
|
||||
}
|
||||
|
||||
func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, errorType string, message string) {
|
||||
writeJSONErrorWithCode(w, r, statusCode, errorType, "", message)
|
||||
}
|
||||
|
||||
func writeJSONErrorWithCode(w http.ResponseWriter, r *http.Request, statusCode int, errorType string, errorCode string, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
json.NewEncoder(w).Encode(errorResponse{
|
||||
Error: errorDetails{
|
||||
Type: errorType,
|
||||
Code: errorCode,
|
||||
Message: message,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func writePlainTextError(w http.ResponseWriter, statusCode int, message string) {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.WriteHeader(statusCode)
|
||||
_, _ = io.WriteString(w, message)
|
||||
}
|
||||
|
||||
const (
|
||||
retryableUsageMessage = "current credential reached its usage limit; retry the request to use another credential"
|
||||
retryableUsageCode = "credential_usage_exhausted"
|
||||
)
|
||||
|
||||
func hasAlternativeCredential(provider credentialProvider, currentCredential credential, filter func(credential) bool) bool {
|
||||
if provider == nil || currentCredential == nil {
|
||||
return false
|
||||
}
|
||||
for _, cred := range provider.allCredentials() {
|
||||
if cred == currentCredential {
|
||||
continue
|
||||
}
|
||||
if filter != nil && !filter(cred) {
|
||||
continue
|
||||
}
|
||||
if cred.isUsable() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func unavailableCredentialMessage(provider credentialProvider, fallback string) string {
|
||||
if provider == nil {
|
||||
return fallback
|
||||
}
|
||||
message := allRateLimitedError(provider.allCredentials()).Error()
|
||||
if message == "all credentials unavailable" && fallback != "" {
|
||||
return fallback
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
||||
func writeRetryableUsageError(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSONErrorWithCode(w, r, http.StatusServiceUnavailable, "server_error", retryableUsageCode, retryableUsageMessage)
|
||||
}
|
||||
|
||||
func writeNonRetryableCredentialError(w http.ResponseWriter, message string) {
|
||||
writePlainTextError(w, http.StatusBadRequest, message)
|
||||
}
|
||||
|
||||
func writeCredentialUnavailableError(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
provider credentialProvider,
|
||||
currentCredential credential,
|
||||
filter func(credential) bool,
|
||||
fallback string,
|
||||
) {
|
||||
if hasAlternativeCredential(provider, currentCredential, filter) {
|
||||
writeRetryableUsageError(w, r)
|
||||
return
|
||||
}
|
||||
writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, fallback))
|
||||
}
|
||||
|
||||
func isHopByHopHeader(header string) bool {
|
||||
switch strings.ToLower(header) {
|
||||
case "connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade", "host":
|
||||
@@ -72,6 +136,19 @@ func isHopByHopHeader(header string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func isReverseProxyHeader(header string) bool {
|
||||
lowerHeader := strings.ToLower(header)
|
||||
if strings.HasPrefix(lowerHeader, "cf-") {
|
||||
return true
|
||||
}
|
||||
switch lowerHeader {
|
||||
case "cdn-loop", "true-client-ip", "x-forwarded-for", "x-forwarded-proto", "x-real-ip":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeRateLimitIdentifier(limitIdentifier string) string {
|
||||
trimmedIdentifier := strings.TrimSpace(strings.ToLower(limitIdentifier))
|
||||
if trimmedIdentifier == "" {
|
||||
@@ -127,72 +204,43 @@ type Service struct {
|
||||
boxService.Adapter
|
||||
ctx context.Context
|
||||
logger log.ContextLogger
|
||||
credentialPath string
|
||||
credentials *oauthCredentials
|
||||
users []option.OCMUser
|
||||
dialer N.Dialer
|
||||
httpClient *http.Client
|
||||
options option.OCMServiceOptions
|
||||
httpHeaders http.Header
|
||||
listener *listener.Listener
|
||||
tlsConfig tls.ServerConfig
|
||||
httpServer *http.Server
|
||||
userManager *UserManager
|
||||
accessMutex sync.RWMutex
|
||||
usageTracker *AggregatedUsage
|
||||
webSocketMutex sync.Mutex
|
||||
webSocketGroup sync.WaitGroup
|
||||
webSocketConns map[*webSocketSession]struct{}
|
||||
shuttingDown bool
|
||||
|
||||
// Legacy mode
|
||||
legacyCredential *defaultCredential
|
||||
legacyProvider credentialProvider
|
||||
|
||||
// Multi-credential mode
|
||||
providers map[string]credentialProvider
|
||||
allCredentials []credential
|
||||
userConfigMap map[string]*option.OCMUser
|
||||
}
|
||||
|
||||
func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.OCMServiceOptions) (adapter.Service, error) {
|
||||
serviceDialer, err := dialer.NewWithOptions(dialer.Options{
|
||||
Context: ctx,
|
||||
Options: option.DialerOptions{
|
||||
Detour: options.Detour,
|
||||
},
|
||||
RemoteIsDomain: true,
|
||||
})
|
||||
err := validateOCMOptions(options)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "create dialer")
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
ForceAttemptHTTP2: true,
|
||||
TLSClientConfig: &stdTLS.Config{
|
||||
RootCAs: adapter.RootPoolFromContext(ctx),
|
||||
Time: ntp.TimeFuncFromContext(ctx),
|
||||
},
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return serviceDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
},
|
||||
},
|
||||
return nil, E.Cause(err, "validate options")
|
||||
}
|
||||
|
||||
userManager := &UserManager{
|
||||
tokenMap: make(map[string]string),
|
||||
}
|
||||
|
||||
var usageTracker *AggregatedUsage
|
||||
if options.UsagesPath != "" {
|
||||
usageTracker = &AggregatedUsage{
|
||||
LastUpdated: time.Now(),
|
||||
Combinations: make([]CostCombination, 0),
|
||||
filePath: options.UsagesPath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
service := &Service{
|
||||
Adapter: boxService.NewAdapter(C.TypeOCM, tag),
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
credentialPath: options.CredentialPath,
|
||||
users: options.Users,
|
||||
dialer: serviceDialer,
|
||||
httpClient: httpClient,
|
||||
httpHeaders: options.Headers.Build(),
|
||||
Adapter: boxService.NewAdapter(C.TypeOCM, tag),
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
options: options,
|
||||
httpHeaders: options.Headers.Build(),
|
||||
listener: listener.New(listener.Options{
|
||||
Context: ctx,
|
||||
Logger: logger,
|
||||
@@ -200,10 +248,36 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
|
||||
Listen: options.ListenOptions,
|
||||
}),
|
||||
userManager: userManager,
|
||||
usageTracker: usageTracker,
|
||||
webSocketConns: make(map[*webSocketSession]struct{}),
|
||||
}
|
||||
|
||||
if len(options.Credentials) > 0 {
|
||||
providers, allCredentials, err := buildOCMCredentialProviders(ctx, options, logger)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "build credential providers")
|
||||
}
|
||||
service.providers = providers
|
||||
service.allCredentials = allCredentials
|
||||
|
||||
userConfigMap := make(map[string]*option.OCMUser)
|
||||
for i := range options.Users {
|
||||
userConfigMap[options.Users[i].Name] = &options.Users[i]
|
||||
}
|
||||
service.userConfigMap = userConfigMap
|
||||
} else {
|
||||
cred, err := newDefaultCredential(ctx, "default", option.OCMDefaultCredentialOptions{
|
||||
CredentialPath: options.CredentialPath,
|
||||
UsagesPath: options.UsagesPath,
|
||||
Detour: options.Detour,
|
||||
}, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
service.legacyCredential = cred
|
||||
service.legacyProvider = &singleCredentialProvider{cred: cred}
|
||||
service.allCredentials = []credential{cred}
|
||||
}
|
||||
|
||||
if options.TLS != nil {
|
||||
tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
|
||||
if err != nil {
|
||||
@@ -220,28 +294,35 @@ func (s *Service) Start(stage adapter.StartStage) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.userManager.UpdateUsers(s.users)
|
||||
s.userManager.UpdateUsers(s.options.Users)
|
||||
|
||||
credentials, err := platformReadCredentials(s.credentialPath)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read credentials")
|
||||
}
|
||||
s.credentials = credentials
|
||||
|
||||
if s.usageTracker != nil {
|
||||
err = s.usageTracker.Load()
|
||||
for _, cred := range s.allCredentials {
|
||||
if extCred, ok := cred.(*externalCredential); ok && extCred.reverse && extCred.connectorURL != nil {
|
||||
extCred.reverseService = s
|
||||
}
|
||||
err := cred.start()
|
||||
if err != nil {
|
||||
s.logger.Warn("load usage statistics: ", err)
|
||||
return err
|
||||
}
|
||||
tag := cred.tagName()
|
||||
cred.setOnBecameUnusable(func() {
|
||||
s.interruptWebSocketSessionsForCredential(tag)
|
||||
})
|
||||
}
|
||||
if len(s.options.Credentials) > 0 {
|
||||
err := validateOCMCompositeCredentialModes(s.options, s.providers)
|
||||
if err != nil {
|
||||
return E.Cause(err, "validate loaded credentials")
|
||||
}
|
||||
}
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Mount("/", s)
|
||||
|
||||
s.httpServer = &http.Server{Handler: router}
|
||||
s.httpServer = &http.Server{Handler: h2c.NewHandler(router, &http2.Server{})}
|
||||
|
||||
if s.tlsConfig != nil {
|
||||
err = s.tlsConfig.Start()
|
||||
err := s.tlsConfig.Start()
|
||||
if err != nil {
|
||||
return E.Cause(err, "create TLS config")
|
||||
}
|
||||
@@ -269,172 +350,247 @@ func (s *Service) Start(stage adapter.StartStage) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) getAccessToken() (string, error) {
|
||||
s.accessMutex.RLock()
|
||||
if !s.credentials.needsRefresh() {
|
||||
token := s.credentials.getAccessToken()
|
||||
s.accessMutex.RUnlock()
|
||||
return token, nil
|
||||
func (s *Service) resolveCredentialProvider(username string) (credentialProvider, error) {
|
||||
if len(s.options.Users) > 0 {
|
||||
return credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
|
||||
}
|
||||
s.accessMutex.RUnlock()
|
||||
|
||||
s.accessMutex.Lock()
|
||||
defer s.accessMutex.Unlock()
|
||||
|
||||
if !s.credentials.needsRefresh() {
|
||||
return s.credentials.getAccessToken(), nil
|
||||
provider := noUserCredentialProvider(s.providers, s.legacyProvider, s.options)
|
||||
if provider == nil {
|
||||
return nil, E.New("no credential available")
|
||||
}
|
||||
|
||||
newCredentials, err := refreshToken(s.httpClient, s.credentials)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
s.credentials = newCredentials
|
||||
|
||||
err = platformWriteCredentials(newCredentials, s.credentialPath)
|
||||
if err != nil {
|
||||
s.logger.Warn("persist refreshed token: ", err)
|
||||
}
|
||||
|
||||
return newCredentials.getAccessToken(), nil
|
||||
}
|
||||
|
||||
func (s *Service) getAccountID() string {
|
||||
s.accessMutex.RLock()
|
||||
defer s.accessMutex.RUnlock()
|
||||
return s.credentials.getAccountID()
|
||||
}
|
||||
|
||||
func (s *Service) isAPIKeyMode() bool {
|
||||
s.accessMutex.RLock()
|
||||
defer s.accessMutex.RUnlock()
|
||||
return s.credentials.isAPIKeyMode()
|
||||
}
|
||||
|
||||
func (s *Service) getBaseURL() string {
|
||||
if s.isAPIKeyMode() {
|
||||
return openaiAPIBaseURL
|
||||
}
|
||||
return chatGPTBackendURL
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := log.ContextWithNewID(r.Context())
|
||||
if r.URL.Path == "/ocm/v1/status" {
|
||||
s.handleStatusEndpoint(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Path == "/ocm/v1/reverse" {
|
||||
s.handleReverseConnect(ctx, w, r)
|
||||
return
|
||||
}
|
||||
|
||||
path := r.URL.Path
|
||||
if !strings.HasPrefix(path, "/v1/") {
|
||||
writeJSONError(w, r, http.StatusNotFound, "invalid_request_error", "path must start with /v1/")
|
||||
return
|
||||
}
|
||||
|
||||
var proxyPath string
|
||||
if s.isAPIKeyMode() {
|
||||
proxyPath = path
|
||||
} else {
|
||||
if path == "/v1/chat/completions" {
|
||||
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
|
||||
"chat completions endpoint is only available in API key mode")
|
||||
return
|
||||
}
|
||||
proxyPath = strings.TrimPrefix(path, "/v1")
|
||||
}
|
||||
|
||||
var username string
|
||||
if len(s.users) > 0 {
|
||||
if len(s.options.Users) > 0 {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
|
||||
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
||||
return
|
||||
}
|
||||
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if clientToken == authHeader {
|
||||
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
|
||||
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
||||
return
|
||||
}
|
||||
var ok bool
|
||||
username, ok = s.userManager.Authenticate(clientToken)
|
||||
if !ok {
|
||||
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
|
||||
s.logger.WarnContext(ctx, "authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") {
|
||||
s.handleWebSocket(w, r, proxyPath, username)
|
||||
sessionID := r.Header.Get("session_id")
|
||||
|
||||
// Resolve credential provider and user config
|
||||
var provider credentialProvider
|
||||
var userConfig *option.OCMUser
|
||||
if len(s.options.Users) > 0 {
|
||||
userConfig = s.userConfigMap[username]
|
||||
var err error
|
||||
provider, err = credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "resolve credential: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
} else {
|
||||
provider = noUserCredentialProvider(s.providers, s.legacyProvider, s.options)
|
||||
}
|
||||
if provider == nil {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "no credential available")
|
||||
return
|
||||
}
|
||||
|
||||
var requestModel string
|
||||
provider.pollIfStale(s.ctx)
|
||||
|
||||
if s.usageTracker != nil && r.Body != nil {
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
if err == nil {
|
||||
var request struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
err := json.Unmarshal(bodyBytes, &request)
|
||||
if err == nil {
|
||||
requestModel = request.Model
|
||||
}
|
||||
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
var credentialFilter func(credential) bool
|
||||
if userConfig != nil && !userConfig.AllowExternalUsage {
|
||||
credentialFilter = func(c credential) bool { return !c.isExternal() }
|
||||
}
|
||||
|
||||
selectedCredential, isNew, err := provider.selectCredential(sessionID, credentialFilter)
|
||||
if err != nil {
|
||||
writeNonRetryableCredentialError(w, unavailableCredentialMessage(provider, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") {
|
||||
s.handleWebSocket(ctx, w, r, path, username, sessionID, userConfig, provider, selectedCredential, credentialFilter, isNew)
|
||||
return
|
||||
}
|
||||
|
||||
if !selectedCredential.isExternal() && selectedCredential.ocmIsAPIKeyMode() {
|
||||
// API key mode path handling
|
||||
} else if !selectedCredential.isExternal() {
|
||||
if path == "/v1/chat/completions" {
|
||||
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
|
||||
"chat completions endpoint is only available in API key mode")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
accessToken, err := s.getAccessToken()
|
||||
if err != nil {
|
||||
s.logger.Error("get access token: ", err)
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "Authentication failed")
|
||||
return
|
||||
shouldTrackUsage := selectedCredential.usageTrackerOrNil() != nil &&
|
||||
(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses"))
|
||||
canRetryRequest := len(provider.allCredentials()) > 1
|
||||
|
||||
// Read body for model extraction and retry buffer when JSON replay is useful.
|
||||
var bodyBytes []byte
|
||||
var requestModel string
|
||||
var requestServiceTier string
|
||||
if r.Body != nil && (shouldTrackUsage || canRetryRequest) {
|
||||
mediaType, _, parseErr := mime.ParseMediaType(r.Header.Get("Content-Type"))
|
||||
isJSONRequest := parseErr == nil && (mediaType == "application/json" || strings.HasSuffix(mediaType, "+json"))
|
||||
if isJSONRequest {
|
||||
bodyBytes, err = io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "read request body: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "failed to read request body")
|
||||
return
|
||||
}
|
||||
var request struct {
|
||||
Model string `json:"model"`
|
||||
ServiceTier string `json:"service_tier"`
|
||||
}
|
||||
if json.Unmarshal(bodyBytes, &request) == nil {
|
||||
requestModel = request.Model
|
||||
requestServiceTier = request.ServiceTier
|
||||
}
|
||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
}
|
||||
}
|
||||
|
||||
proxyURL := s.getBaseURL() + proxyPath
|
||||
if r.URL.RawQuery != "" {
|
||||
proxyURL += "?" + r.URL.RawQuery
|
||||
if isNew {
|
||||
logParts := []any{"assigned credential ", selectedCredential.tagName()}
|
||||
if sessionID != "" {
|
||||
logParts = append(logParts, " for session ", sessionID)
|
||||
}
|
||||
if username != "" {
|
||||
logParts = append(logParts, " by user ", username)
|
||||
}
|
||||
if requestModel != "" {
|
||||
logParts = append(logParts, ", model=", requestModel)
|
||||
}
|
||||
if requestServiceTier == "priority" {
|
||||
logParts = append(logParts, ", fast")
|
||||
}
|
||||
s.logger.DebugContext(ctx, logParts...)
|
||||
}
|
||||
proxyRequest, err := http.NewRequestWithContext(r.Context(), r.Method, proxyURL, r.Body)
|
||||
|
||||
requestContext := selectedCredential.wrapRequestContext(r.Context())
|
||||
defer func() {
|
||||
requestContext.cancelRequest()
|
||||
}()
|
||||
proxyRequest, err := selectedCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||
if err != nil {
|
||||
s.logger.Error("create proxy request: ", err)
|
||||
s.logger.ErrorContext(ctx, "create proxy request: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
|
||||
return
|
||||
}
|
||||
|
||||
for key, values := range r.Header {
|
||||
if !isHopByHopHeader(key) && key != "Authorization" {
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
}
|
||||
|
||||
for key, values := range s.httpHeaders {
|
||||
proxyRequest.Header.Del(key)
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
|
||||
proxyRequest.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
if accountID := s.getAccountID(); accountID != "" {
|
||||
proxyRequest.Header.Set("ChatGPT-Account-Id", accountID)
|
||||
}
|
||||
|
||||
response, err := s.httpClient.Do(proxyRequest)
|
||||
response, err := selectedCredential.httpTransport().Do(proxyRequest)
|
||||
if err != nil {
|
||||
if r.Context().Err() != nil {
|
||||
return
|
||||
}
|
||||
if requestContext.Err() != nil {
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "credential became unavailable while processing the request")
|
||||
return
|
||||
}
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
requestContext.releaseCredentialInterrupt()
|
||||
|
||||
// Transparent 429 retry
|
||||
for response.StatusCode == http.StatusTooManyRequests {
|
||||
resetAt := parseOCMRateLimitResetFromHeaders(response.Header)
|
||||
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, credentialFilter)
|
||||
needsBodyReplay := r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodDelete
|
||||
selectedCredential.updateStateFromHeaders(response.Header)
|
||||
if (needsBodyReplay && bodyBytes == nil) || nextCredential == nil {
|
||||
response.Body.Close()
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited")
|
||||
return
|
||||
}
|
||||
response.Body.Close()
|
||||
s.logger.InfoContext(ctx, "retrying with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
|
||||
requestContext.cancelRequest()
|
||||
requestContext = nextCredential.wrapRequestContext(r.Context())
|
||||
retryRequest, buildErr := nextCredential.buildProxyRequest(requestContext, r, bodyBytes, s.httpHeaders)
|
||||
if buildErr != nil {
|
||||
s.logger.ErrorContext(ctx, "retry request: ", buildErr)
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", buildErr.Error())
|
||||
return
|
||||
}
|
||||
retryResponse, retryErr := nextCredential.httpTransport().Do(retryRequest)
|
||||
if retryErr != nil {
|
||||
if r.Context().Err() != nil {
|
||||
return
|
||||
}
|
||||
if requestContext.Err() != nil {
|
||||
writeCredentialUnavailableError(w, r, provider, nextCredential, credentialFilter, "credential became unavailable while retrying the request")
|
||||
return
|
||||
}
|
||||
s.logger.ErrorContext(ctx, "retry request: ", retryErr)
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", retryErr.Error())
|
||||
return
|
||||
}
|
||||
requestContext.releaseCredentialInterrupt()
|
||||
response = retryResponse
|
||||
selectedCredential = nextCredential
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
selectedCredential.updateStateFromHeaders(response.Header)
|
||||
|
||||
if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusTooManyRequests {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
s.logger.ErrorContext(ctx, "upstream error from ", selectedCredential.tagName(), ": status ", response.StatusCode, " ", string(body))
|
||||
go selectedCredential.pollUsage(s.ctx)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error",
|
||||
"proxy request (status "+strconv.Itoa(response.StatusCode)+"): "+string(body))
|
||||
return
|
||||
}
|
||||
|
||||
// Rewrite response headers for external users
|
||||
if userConfig != nil && userConfig.ExternalCredential != "" {
|
||||
s.rewriteResponseHeadersForExternalUser(response.Header, userConfig)
|
||||
}
|
||||
|
||||
for key, values := range response.Header {
|
||||
if !isHopByHopHeader(key) {
|
||||
if !isHopByHopHeader(key) && !isReverseProxyHeader(key) {
|
||||
w.Header()[key] = values
|
||||
}
|
||||
}
|
||||
w.WriteHeader(response.StatusCode)
|
||||
|
||||
trackUsage := s.usageTracker != nil && response.StatusCode == http.StatusOK &&
|
||||
(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses"))
|
||||
if trackUsage {
|
||||
s.handleResponseWithTracking(w, response, path, requestModel, username)
|
||||
usageTracker := selectedCredential.usageTrackerOrNil()
|
||||
if usageTracker != nil && response.StatusCode == http.StatusOK &&
|
||||
(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses")) {
|
||||
s.handleResponseWithTracking(ctx, w, response, usageTracker, path, requestModel, username)
|
||||
} else {
|
||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||
if err == nil && mediaType != "text/event-stream" {
|
||||
@@ -443,7 +599,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
s.logger.Error("streaming not supported")
|
||||
s.logger.ErrorContext(ctx, "streaming not supported")
|
||||
return
|
||||
}
|
||||
buffer := make([]byte, buf.BufferSize)
|
||||
@@ -452,7 +608,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if n > 0 {
|
||||
_, writeError := w.Write(buffer[:n])
|
||||
if writeError != nil {
|
||||
s.logger.Error("write streaming response: ", writeError)
|
||||
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
@@ -464,7 +620,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, path string, requestModel string, username string) {
|
||||
func (s *Service) handleResponseWithTracking(ctx context.Context, writer http.ResponseWriter, response *http.Response, usageTracker *AggregatedUsage, path string, requestModel string, username string) {
|
||||
isChatCompletions := path == "/v1/chat/completions"
|
||||
weeklyCycleHint := extractWeeklyCycleHint(response.Header)
|
||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||
@@ -475,7 +631,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
||||
if !isStreaming {
|
||||
bodyBytes, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
s.logger.Error("read response body: ", err)
|
||||
s.logger.ErrorContext(ctx, "read response body: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -508,7 +664,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
||||
}
|
||||
if responseModel != "" {
|
||||
contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens)
|
||||
s.usageTracker.AddUsageWithCycleHint(
|
||||
usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
contextWindow,
|
||||
inputTokens,
|
||||
@@ -528,7 +684,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
||||
|
||||
flusher, ok := writer.(http.Flusher)
|
||||
if !ok {
|
||||
s.logger.Error("streaming not supported")
|
||||
s.logger.ErrorContext(ctx, "streaming not supported")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -605,7 +761,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
||||
|
||||
_, writeError := writer.Write(buffer[:n])
|
||||
if writeError != nil {
|
||||
s.logger.Error("write streaming response: ", writeError)
|
||||
s.logger.ErrorContext(ctx, "write streaming response: ", writeError)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
@@ -619,7 +775,7 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
||||
if inputTokens > 0 || outputTokens > 0 {
|
||||
if responseModel != "" {
|
||||
contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens)
|
||||
s.usageTracker.AddUsageWithCycleHint(
|
||||
usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
contextWindow,
|
||||
inputTokens,
|
||||
@@ -637,6 +793,124 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
writeJSONError(w, r, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
if len(s.options.Users) == 0 {
|
||||
writeJSONError(w, r, http.StatusForbidden, "authentication_error", "status endpoint requires user authentication")
|
||||
return
|
||||
}
|
||||
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
||||
return
|
||||
}
|
||||
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if clientToken == authHeader {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
||||
return
|
||||
}
|
||||
username, ok := s.userManager.Authenticate(clientToken)
|
||||
if !ok {
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
|
||||
return
|
||||
}
|
||||
|
||||
userConfig := s.userConfigMap[username]
|
||||
if userConfig == nil {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "user config not found")
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, username)
|
||||
if err != nil {
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
provider.pollIfStale(r.Context())
|
||||
avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]float64{
|
||||
"five_hour_utilization": avgFiveHour,
|
||||
"weekly_utilization": avgWeekly,
|
||||
"plan_weight": totalWeight,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) (float64, float64, float64) {
|
||||
var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64
|
||||
for _, cred := range provider.allCredentials() {
|
||||
if !cred.isAvailable() {
|
||||
continue
|
||||
}
|
||||
if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential {
|
||||
continue
|
||||
}
|
||||
if !userConfig.AllowExternalUsage && cred.isExternal() {
|
||||
continue
|
||||
}
|
||||
weight := cred.planWeight()
|
||||
remaining5h := cred.fiveHourCap() - cred.fiveHourUtilization()
|
||||
if remaining5h < 0 {
|
||||
remaining5h = 0
|
||||
}
|
||||
remainingWeekly := cred.weeklyCap() - cred.weeklyUtilization()
|
||||
if remainingWeekly < 0 {
|
||||
remainingWeekly = 0
|
||||
}
|
||||
totalWeightedRemaining5h += remaining5h * weight
|
||||
totalWeightedRemainingWeekly += remainingWeekly * weight
|
||||
totalWeight += weight
|
||||
}
|
||||
if totalWeight == 0 {
|
||||
return 100, 100, 0
|
||||
}
|
||||
return 100 - totalWeightedRemaining5h/totalWeight,
|
||||
100 - totalWeightedRemainingWeekly/totalWeight,
|
||||
totalWeight
|
||||
}
|
||||
|
||||
func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.OCMUser) {
|
||||
provider, err := credentialForUser(s.userConfigMap, s.providers, s.legacyProvider, userConfig.Name)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
|
||||
|
||||
activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
|
||||
if activeLimitIdentifier == "" {
|
||||
activeLimitIdentifier = "codex"
|
||||
}
|
||||
|
||||
headers.Set("x-"+activeLimitIdentifier+"-primary-used-percent", strconv.FormatFloat(avgFiveHour, 'f', 2, 64))
|
||||
headers.Set("x-"+activeLimitIdentifier+"-secondary-used-percent", strconv.FormatFloat(avgWeekly, 'f', 2, 64))
|
||||
if totalWeight > 0 {
|
||||
headers.Set("X-OCM-Plan-Weight", strconv.FormatFloat(totalWeight, 'f', -1, 64))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) InterfaceUpdated() {
|
||||
for _, cred := range s.allCredentials {
|
||||
extCred, ok := cred.(*externalCredential)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if extCred.reverse && extCred.connectorURL != nil {
|
||||
extCred.reverseService = s
|
||||
extCred.resetReverseContext()
|
||||
go extCred.connectorLoop()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Close() error {
|
||||
webSocketSessions := s.startWebSocketShutdown()
|
||||
|
||||
@@ -650,12 +924,8 @@ func (s *Service) Close() error {
|
||||
}
|
||||
s.webSocketGroup.Wait()
|
||||
|
||||
if s.usageTracker != nil {
|
||||
s.usageTracker.cancelPendingSave()
|
||||
saveErr := s.usageTracker.Save()
|
||||
if saveErr != nil {
|
||||
s.logger.Error("save usage statistics: ", saveErr)
|
||||
}
|
||||
for _, cred := range s.allCredentials {
|
||||
cred.close()
|
||||
}
|
||||
|
||||
return err
|
||||
@@ -693,6 +963,20 @@ func (s *Service) isShuttingDown() bool {
|
||||
return s.shuttingDown
|
||||
}
|
||||
|
||||
func (s *Service) interruptWebSocketSessionsForCredential(tag string) {
|
||||
s.webSocketMutex.Lock()
|
||||
var toClose []*webSocketSession
|
||||
for session := range s.webSocketConns {
|
||||
if session.credentialTag == tag {
|
||||
toClose = append(toClose, session)
|
||||
}
|
||||
}
|
||||
s.webSocketMutex.Unlock()
|
||||
for _, session := range toClose {
|
||||
session.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) startWebSocketShutdown() []*webSocketSession {
|
||||
s.webSocketMutex.Lock()
|
||||
defer s.webSocketMutex.Unlock()
|
||||
|
||||
@@ -1,17 +1,21 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
stdTLS "crypto/tls"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
@@ -22,9 +26,10 @@ import (
|
||||
)
|
||||
|
||||
type webSocketSession struct {
|
||||
clientConn net.Conn
|
||||
upstreamConn net.Conn
|
||||
closeOnce sync.Once
|
||||
clientConn net.Conn
|
||||
upstreamConn net.Conn
|
||||
credentialTag string
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func (s *webSocketSession) Close() {
|
||||
@@ -61,7 +66,7 @@ func isForwardableResponseHeader(key string) bool {
|
||||
}
|
||||
|
||||
func isForwardableWebSocketRequestHeader(key string) bool {
|
||||
if isHopByHopHeader(key) {
|
||||
if isHopByHopHeader(key) || isReverseProxyHeader(key) {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -76,65 +81,141 @@ func isForwardableWebSocketRequestHeader(key string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyPath string, username string) {
|
||||
accessToken, err := s.getAccessToken()
|
||||
if err != nil {
|
||||
s.logger.Error("get access token for websocket: ", err)
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "authentication failed")
|
||||
return
|
||||
}
|
||||
func (s *Service) handleWebSocket(
|
||||
ctx context.Context,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
path string,
|
||||
username string,
|
||||
sessionID string,
|
||||
userConfig *option.OCMUser,
|
||||
provider credentialProvider,
|
||||
selectedCredential credential,
|
||||
credentialFilter func(credential) bool,
|
||||
isNew bool,
|
||||
) {
|
||||
var (
|
||||
err error
|
||||
upstreamConn net.Conn
|
||||
upstreamBufferedReader *bufio.Reader
|
||||
upstreamResponseHeaders http.Header
|
||||
statusCode int
|
||||
statusResponseBody string
|
||||
)
|
||||
|
||||
upstreamURL := buildUpstreamWebSocketURL(s.getBaseURL(), proxyPath)
|
||||
if r.URL.RawQuery != "" {
|
||||
upstreamURL += "?" + r.URL.RawQuery
|
||||
}
|
||||
for {
|
||||
accessToken, accessErr := selectedCredential.getAccessToken()
|
||||
if accessErr != nil {
|
||||
s.logger.ErrorContext(ctx, "get access token for websocket: ", accessErr)
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "authentication failed")
|
||||
return
|
||||
}
|
||||
|
||||
upstreamHeaders := make(http.Header)
|
||||
for key, values := range r.Header {
|
||||
if isForwardableWebSocketRequestHeader(key) {
|
||||
var proxyPath string
|
||||
if selectedCredential.ocmIsAPIKeyMode() || selectedCredential.isExternal() {
|
||||
proxyPath = path
|
||||
} else {
|
||||
proxyPath = strings.TrimPrefix(path, "/v1")
|
||||
}
|
||||
|
||||
upstreamURL := buildUpstreamWebSocketURL(selectedCredential.ocmGetBaseURL(), proxyPath)
|
||||
if r.URL.RawQuery != "" {
|
||||
upstreamURL += "?" + r.URL.RawQuery
|
||||
}
|
||||
|
||||
upstreamHeaders := make(http.Header)
|
||||
for key, values := range r.Header {
|
||||
if isForwardableWebSocketRequestHeader(key) {
|
||||
upstreamHeaders[key] = values
|
||||
}
|
||||
}
|
||||
for key, values := range s.httpHeaders {
|
||||
upstreamHeaders.Del(key)
|
||||
upstreamHeaders[key] = values
|
||||
}
|
||||
}
|
||||
for key, values := range s.httpHeaders {
|
||||
upstreamHeaders.Del(key)
|
||||
upstreamHeaders[key] = values
|
||||
}
|
||||
upstreamHeaders.Set("Authorization", "Bearer "+accessToken)
|
||||
if accountID := s.getAccountID(); accountID != "" {
|
||||
upstreamHeaders.Set("ChatGPT-Account-Id", accountID)
|
||||
}
|
||||
upstreamHeaders.Set("Authorization", "Bearer "+accessToken)
|
||||
if accountID := selectedCredential.ocmGetAccountID(); accountID != "" {
|
||||
upstreamHeaders.Set("ChatGPT-Account-Id", accountID)
|
||||
}
|
||||
if upstreamHeaders.Get("OpenAI-Beta") == "" {
|
||||
upstreamHeaders.Set("OpenAI-Beta", "responses_websockets=2026-02-06")
|
||||
}
|
||||
|
||||
upstreamResponseHeaders := make(http.Header)
|
||||
upstreamDialer := ws.Dialer{
|
||||
NetDial: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return s.dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
},
|
||||
TLSConfig: &stdTLS.Config{
|
||||
RootCAs: adapter.RootPoolFromContext(s.ctx),
|
||||
Time: ntp.TimeFuncFromContext(s.ctx),
|
||||
},
|
||||
Header: ws.HandshakeHeaderHTTP(upstreamHeaders),
|
||||
OnHeader: func(key, value []byte) error {
|
||||
upstreamResponseHeaders.Add(string(key), string(value))
|
||||
return nil
|
||||
},
|
||||
}
|
||||
upstreamResponseHeaders = make(http.Header)
|
||||
statusCode = 0
|
||||
statusResponseBody = ""
|
||||
upstreamDialer := ws.Dialer{
|
||||
NetDial: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return selectedCredential.ocmDialer().DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
},
|
||||
TLSConfig: &stdTLS.Config{
|
||||
RootCAs: adapter.RootPoolFromContext(s.ctx),
|
||||
Time: ntp.TimeFuncFromContext(s.ctx),
|
||||
},
|
||||
Header: ws.HandshakeHeaderHTTP(upstreamHeaders),
|
||||
// gobwas/ws@v1.4.0: the response io.Reader is
|
||||
// MultiReader(statusLine_without_CRLF, "\r\n", bufferedConn).
|
||||
// ReadString('\n') consumes the status line, then ReadMIMEHeader
|
||||
// parses the remaining headers.
|
||||
OnStatusError: func(status int, reason []byte, response io.Reader) {
|
||||
statusCode = status
|
||||
bufferedResponse := bufio.NewReader(response)
|
||||
_, readErr := bufferedResponse.ReadString('\n')
|
||||
if readErr != nil {
|
||||
return
|
||||
}
|
||||
mimeHeader, readErr := textproto.NewReader(bufferedResponse).ReadMIMEHeader()
|
||||
if readErr == nil {
|
||||
upstreamResponseHeaders = http.Header(mimeHeader)
|
||||
}
|
||||
body, readErr := io.ReadAll(io.LimitReader(bufferedResponse, 4096))
|
||||
if readErr == nil && len(body) > 0 {
|
||||
statusResponseBody = string(body)
|
||||
}
|
||||
},
|
||||
OnHeader: func(key, value []byte) error {
|
||||
upstreamResponseHeaders.Add(string(key), string(value))
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
upstreamConn, upstreamBufferedReader, _, err := upstreamDialer.Dial(r.Context(), upstreamURL)
|
||||
if err != nil {
|
||||
s.logger.Error("dial upstream websocket: ", err)
|
||||
upstreamConn, upstreamBufferedReader, _, err = upstreamDialer.Dial(s.ctx, upstreamURL)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if statusCode == http.StatusTooManyRequests {
|
||||
resetAt := parseOCMRateLimitResetFromHeaders(upstreamResponseHeaders)
|
||||
nextCredential := provider.onRateLimited(sessionID, selectedCredential, resetAt, credentialFilter)
|
||||
selectedCredential.updateStateFromHeaders(upstreamResponseHeaders)
|
||||
if nextCredential == nil {
|
||||
writeCredentialUnavailableError(w, r, provider, selectedCredential, credentialFilter, "all credentials rate-limited")
|
||||
return
|
||||
}
|
||||
s.logger.InfoContext(ctx, "retrying websocket with credential ", nextCredential.tagName(), " after 429 from ", selectedCredential.tagName())
|
||||
selectedCredential = nextCredential
|
||||
continue
|
||||
}
|
||||
if statusCode > 0 && statusResponseBody != "" {
|
||||
s.logger.ErrorContext(ctx, "dial upstream websocket: status ", statusCode, " body: ", statusResponseBody)
|
||||
} else {
|
||||
s.logger.ErrorContext(ctx, "dial upstream websocket: ", err)
|
||||
}
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", "upstream websocket connection failed")
|
||||
return
|
||||
}
|
||||
|
||||
selectedCredential.updateStateFromHeaders(upstreamResponseHeaders)
|
||||
weeklyCycleHint := extractWeeklyCycleHint(upstreamResponseHeaders)
|
||||
|
||||
clientResponseHeaders := make(http.Header)
|
||||
for key, values := range upstreamResponseHeaders {
|
||||
if isForwardableResponseHeader(key) {
|
||||
clientResponseHeaders[key] = values
|
||||
clientResponseHeaders[key] = append([]string(nil), values...)
|
||||
}
|
||||
}
|
||||
if userConfig != nil && userConfig.ExternalCredential != "" {
|
||||
s.rewriteResponseHeadersForExternalUser(clientResponseHeaders, userConfig)
|
||||
}
|
||||
|
||||
clientUpgrader := ws.HTTPUpgrader{
|
||||
Header: clientResponseHeaders,
|
||||
@@ -146,13 +227,14 @@ func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyP
|
||||
}
|
||||
clientConn, _, _, err := clientUpgrader.Upgrade(r, w)
|
||||
if err != nil {
|
||||
s.logger.Error("upgrade client websocket: ", err)
|
||||
s.logger.ErrorContext(ctx, "upgrade client websocket: ", err)
|
||||
upstreamConn.Close()
|
||||
return
|
||||
}
|
||||
session := &webSocketSession{
|
||||
clientConn: clientConn,
|
||||
upstreamConn: upstreamConn,
|
||||
clientConn: clientConn,
|
||||
upstreamConn: upstreamConn,
|
||||
credentialTag: selectedCredential.tagName(),
|
||||
}
|
||||
if !s.registerWebSocketSession(session) {
|
||||
session.Close()
|
||||
@@ -177,35 +259,54 @@ func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyP
|
||||
go func() {
|
||||
defer waitGroup.Done()
|
||||
defer session.Close()
|
||||
s.proxyWebSocketClientToUpstream(clientConn, upstreamConn, modelChannel)
|
||||
s.proxyWebSocketClientToUpstream(ctx, clientConn, upstreamConn, selectedCredential, modelChannel, isNew, username, sessionID)
|
||||
}()
|
||||
go func() {
|
||||
defer waitGroup.Done()
|
||||
defer session.Close()
|
||||
s.proxyWebSocketUpstreamToClient(upstreamReadWriter, clientConn, modelChannel, username, weeklyCycleHint)
|
||||
s.proxyWebSocketUpstreamToClient(ctx, upstreamReadWriter, clientConn, selectedCredential, userConfig, provider, modelChannel, username, weeklyCycleHint)
|
||||
}()
|
||||
waitGroup.Wait()
|
||||
}
|
||||
|
||||
func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamConn net.Conn, modelChannel chan<- string) {
|
||||
func (s *Service) proxyWebSocketClientToUpstream(ctx context.Context, clientConn net.Conn, upstreamConn net.Conn, selectedCredential credential, modelChannel chan<- string, isNew bool, username string, sessionID string) {
|
||||
logged := false
|
||||
for {
|
||||
data, opCode, err := wsutil.ReadClientData(clientConn)
|
||||
if err != nil {
|
||||
if !E.IsClosedOrCanceled(err) {
|
||||
s.logger.Debug("read client websocket: ", err)
|
||||
s.logger.DebugContext(ctx, "read client websocket: ", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if opCode == ws.OpText && s.usageTracker != nil {
|
||||
if opCode == ws.OpText {
|
||||
var request struct {
|
||||
Type string `json:"type"`
|
||||
Model string `json:"model"`
|
||||
Type string `json:"type"`
|
||||
Model string `json:"model"`
|
||||
ServiceTier string `json:"service_tier"`
|
||||
}
|
||||
if json.Unmarshal(data, &request) == nil && request.Type == "response.create" && request.Model != "" {
|
||||
select {
|
||||
case modelChannel <- request.Model:
|
||||
default:
|
||||
if isNew && !logged {
|
||||
logged = true
|
||||
logParts := []any{"assigned credential ", selectedCredential.tagName()}
|
||||
if sessionID != "" {
|
||||
logParts = append(logParts, " for session ", sessionID)
|
||||
}
|
||||
if username != "" {
|
||||
logParts = append(logParts, " by user ", username)
|
||||
}
|
||||
logParts = append(logParts, ", model=", request.Model)
|
||||
if request.ServiceTier == "priority" {
|
||||
logParts = append(logParts, ", fast")
|
||||
}
|
||||
s.logger.DebugContext(ctx, logParts...)
|
||||
}
|
||||
if selectedCredential.usageTrackerOrNil() != nil {
|
||||
select {
|
||||
case modelChannel <- request.Model:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -213,62 +314,52 @@ func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamCo
|
||||
err = wsutil.WriteClientMessage(upstreamConn, opCode, data)
|
||||
if err != nil {
|
||||
if !E.IsClosedOrCanceled(err) {
|
||||
s.logger.Debug("write upstream websocket: ", err)
|
||||
s.logger.DebugContext(ctx, "write upstream websocket: ", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWriter, clientConn net.Conn, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) {
|
||||
func (s *Service) proxyWebSocketUpstreamToClient(ctx context.Context, upstreamReadWriter io.ReadWriter, clientConn net.Conn, selectedCredential credential, userConfig *option.OCMUser, provider credentialProvider, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) {
|
||||
usageTracker := selectedCredential.usageTrackerOrNil()
|
||||
var requestModel string
|
||||
for {
|
||||
data, opCode, err := wsutil.ReadServerData(upstreamReadWriter)
|
||||
if err != nil {
|
||||
if !E.IsClosedOrCanceled(err) {
|
||||
s.logger.Debug("read upstream websocket: ", err)
|
||||
s.logger.DebugContext(ctx, "read upstream websocket: ", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if opCode == ws.OpText && s.usageTracker != nil {
|
||||
select {
|
||||
case model := <-modelChannel:
|
||||
requestModel = model
|
||||
default:
|
||||
}
|
||||
|
||||
if opCode == ws.OpText {
|
||||
var event struct {
|
||||
Type string `json:"type"`
|
||||
Type string `json:"type"`
|
||||
StatusCode int `json:"status_code"`
|
||||
}
|
||||
if json.Unmarshal(data, &event) == nil && event.Type == "response.completed" {
|
||||
var streamEvent responses.ResponseStreamEventUnion
|
||||
if json.Unmarshal(data, &streamEvent) == nil {
|
||||
completedEvent := streamEvent.AsResponseCompleted()
|
||||
responseModel := string(completedEvent.Response.Model)
|
||||
serviceTier := string(completedEvent.Response.ServiceTier)
|
||||
inputTokens := completedEvent.Response.Usage.InputTokens
|
||||
outputTokens := completedEvent.Response.Usage.OutputTokens
|
||||
cachedTokens := completedEvent.Response.Usage.InputTokensDetails.CachedTokens
|
||||
|
||||
if inputTokens > 0 || outputTokens > 0 {
|
||||
if responseModel == "" {
|
||||
responseModel = requestModel
|
||||
if json.Unmarshal(data, &event) == nil {
|
||||
switch event.Type {
|
||||
case "codex.rate_limits":
|
||||
s.handleWebSocketRateLimitsEvent(data, selectedCredential)
|
||||
if userConfig != nil && userConfig.ExternalCredential != "" {
|
||||
rewritten, rewriteErr := s.rewriteWebSocketRateLimitsForExternalUser(data, provider, userConfig)
|
||||
if rewriteErr == nil {
|
||||
data = rewritten
|
||||
}
|
||||
if responseModel != "" {
|
||||
contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens)
|
||||
s.usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
contextWindow,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
cachedTokens,
|
||||
serviceTier,
|
||||
username,
|
||||
time.Now(),
|
||||
weeklyCycleHint,
|
||||
)
|
||||
}
|
||||
case "error":
|
||||
if event.StatusCode == http.StatusTooManyRequests {
|
||||
s.handleWebSocketErrorRateLimited(data, selectedCredential)
|
||||
}
|
||||
case "response.completed":
|
||||
if usageTracker != nil {
|
||||
select {
|
||||
case model := <-modelChannel:
|
||||
requestModel = model
|
||||
default:
|
||||
}
|
||||
s.handleWebSocketResponseCompleted(data, usageTracker, requestModel, username, weeklyCycleHint)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -277,9 +368,175 @@ func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWrite
|
||||
err = wsutil.WriteServerMessage(clientConn, opCode, data)
|
||||
if err != nil {
|
||||
if !E.IsClosedOrCanceled(err) {
|
||||
s.logger.Debug("write client websocket: ", err)
|
||||
s.logger.DebugContext(ctx, "write client websocket: ", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleWebSocketRateLimitsEvent(data []byte, selectedCredential credential) {
|
||||
var rateLimitsEvent struct {
|
||||
RateLimits struct {
|
||||
Primary *struct {
|
||||
UsedPercent float64 `json:"used_percent"`
|
||||
ResetAt int64 `json:"reset_at"`
|
||||
} `json:"primary"`
|
||||
Secondary *struct {
|
||||
UsedPercent float64 `json:"used_percent"`
|
||||
ResetAt int64 `json:"reset_at"`
|
||||
} `json:"secondary"`
|
||||
} `json:"rate_limits"`
|
||||
LimitName string `json:"limit_name"`
|
||||
MeteredLimitName string `json:"metered_limit_name"`
|
||||
PlanWeight float64 `json:"plan_weight"`
|
||||
}
|
||||
err := json.Unmarshal(data, &rateLimitsEvent)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
identifier := rateLimitsEvent.MeteredLimitName
|
||||
if identifier == "" {
|
||||
identifier = rateLimitsEvent.LimitName
|
||||
}
|
||||
if identifier == "" {
|
||||
identifier = "codex"
|
||||
}
|
||||
identifier = normalizeRateLimitIdentifier(identifier)
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("x-codex-active-limit", identifier)
|
||||
if w := rateLimitsEvent.RateLimits.Primary; w != nil {
|
||||
headers.Set("x-"+identifier+"-primary-used-percent", strconv.FormatFloat(w.UsedPercent, 'f', -1, 64))
|
||||
if w.ResetAt > 0 {
|
||||
headers.Set("x-"+identifier+"-primary-reset-at", strconv.FormatInt(w.ResetAt, 10))
|
||||
}
|
||||
}
|
||||
if w := rateLimitsEvent.RateLimits.Secondary; w != nil {
|
||||
headers.Set("x-"+identifier+"-secondary-used-percent", strconv.FormatFloat(w.UsedPercent, 'f', -1, 64))
|
||||
if w.ResetAt > 0 {
|
||||
headers.Set("x-"+identifier+"-secondary-reset-at", strconv.FormatInt(w.ResetAt, 10))
|
||||
}
|
||||
}
|
||||
if rateLimitsEvent.PlanWeight > 0 {
|
||||
headers.Set("X-OCM-Plan-Weight", strconv.FormatFloat(rateLimitsEvent.PlanWeight, 'f', -1, 64))
|
||||
}
|
||||
selectedCredential.updateStateFromHeaders(headers)
|
||||
}
|
||||
|
||||
func (s *Service) handleWebSocketErrorRateLimited(data []byte, selectedCredential credential) {
|
||||
var errorEvent struct {
|
||||
Headers map[string]string `json:"headers"`
|
||||
}
|
||||
err := json.Unmarshal(data, &errorEvent)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
headers := make(http.Header)
|
||||
for key, value := range errorEvent.Headers {
|
||||
headers.Set(key, value)
|
||||
}
|
||||
selectedCredential.updateStateFromHeaders(headers)
|
||||
resetAt := parseOCMRateLimitResetFromHeaders(headers)
|
||||
selectedCredential.markRateLimited(resetAt)
|
||||
}
|
||||
|
||||
func (s *Service) rewriteWebSocketRateLimitsForExternalUser(data []byte, provider credentialProvider, userConfig *option.OCMUser) ([]byte, error) {
|
||||
var event map[string]json.RawMessage
|
||||
err := json.Unmarshal(data, &event)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rateLimitsData, exists := event["rate_limits"]
|
||||
if !exists || len(rateLimitsData) == 0 || string(rateLimitsData) == "null" {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
var rateLimits map[string]json.RawMessage
|
||||
err = json.Unmarshal(rateLimitsData, &rateLimits)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
averageFiveHour, averageWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig)
|
||||
|
||||
if totalWeight > 0 {
|
||||
event["plan_weight"], _ = json.Marshal(totalWeight)
|
||||
}
|
||||
|
||||
primaryData, err := rewriteWebSocketRateLimitWindow(rateLimits["primary"], averageFiveHour)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if primaryData != nil {
|
||||
rateLimits["primary"] = primaryData
|
||||
}
|
||||
|
||||
secondaryData, err := rewriteWebSocketRateLimitWindow(rateLimits["secondary"], averageWeekly)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if secondaryData != nil {
|
||||
rateLimits["secondary"] = secondaryData
|
||||
}
|
||||
|
||||
event["rate_limits"], err = json.Marshal(rateLimits)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return json.Marshal(event)
|
||||
}
|
||||
|
||||
func rewriteWebSocketRateLimitWindow(data json.RawMessage, usedPercent float64) (json.RawMessage, error) {
|
||||
if len(data) == 0 || string(data) == "null" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var window map[string]json.RawMessage
|
||||
err := json.Unmarshal(data, &window)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
window["used_percent"], err = json.Marshal(usedPercent)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return json.Marshal(window)
|
||||
}
|
||||
|
||||
func (s *Service) handleWebSocketResponseCompleted(data []byte, usageTracker *AggregatedUsage, requestModel string, username string, weeklyCycleHint *WeeklyCycleHint) {
|
||||
var streamEvent responses.ResponseStreamEventUnion
|
||||
if json.Unmarshal(data, &streamEvent) != nil {
|
||||
return
|
||||
}
|
||||
completedEvent := streamEvent.AsResponseCompleted()
|
||||
responseModel := string(completedEvent.Response.Model)
|
||||
serviceTier := string(completedEvent.Response.ServiceTier)
|
||||
inputTokens := completedEvent.Response.Usage.InputTokens
|
||||
outputTokens := completedEvent.Response.Usage.OutputTokens
|
||||
cachedTokens := completedEvent.Response.Usage.InputTokensDetails.CachedTokens
|
||||
|
||||
if inputTokens > 0 || outputTokens > 0 {
|
||||
if responseModel == "" {
|
||||
responseModel = requestModel
|
||||
}
|
||||
if responseModel != "" {
|
||||
contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens)
|
||||
usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
contextWindow,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
cachedTokens,
|
||||
serviceTier,
|
||||
username,
|
||||
time.Now(),
|
||||
weeklyCycleHint,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
)
|
||||
|
||||
func RegisterService(registry *boxService.Registry) {
|
||||
@@ -59,7 +60,7 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
|
||||
Listen: options.ListenOptions,
|
||||
}),
|
||||
httpServer: &http.Server{
|
||||
Handler: chiRouter,
|
||||
Handler: h2c.NewHandler(chiRouter, &http2.Server{}),
|
||||
},
|
||||
traffics: make(map[string]*TrafficManager),
|
||||
users: make(map[string]*UserManager),
|
||||
|
||||
Reference in New Issue
Block a user