Fix reverse external credential handling

This commit is contained in:
世界
2026-03-13 19:36:51 +08:00
parent 970951f369
commit af94ea9089
6 changed files with 239 additions and 126 deletions

View File

@@ -10,6 +10,7 @@ import (
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
@@ -46,15 +47,59 @@ type externalCredential struct {
requestAccess sync.Mutex
// Reverse proxy fields
reverse bool
reverseSession *yamux.Session
reverseAccess sync.RWMutex
reverseContext context.Context
reverseCancel context.CancelFunc
connectorDialer N.Dialer
connectorURL *url.URL
connectorTLS *stdTLS.Config
reverseService http.Handler
reverse bool
reverseSession *yamux.Session
reverseAccess sync.RWMutex
reverseContext context.Context
reverseCancel context.CancelFunc
connectorDialer N.Dialer
connectorDestination M.Socksaddr
connectorRequestPath string
connectorURL *url.URL
connectorTLS *stdTLS.Config
reverseService http.Handler
}
func externalCredentialURLPort(parsedURL *url.URL) uint16 {
portStr := parsedURL.Port()
if portStr != "" {
port, err := strconv.ParseUint(portStr, 10, 16)
if err == nil {
return uint16(port)
}
}
if parsedURL.Scheme == "https" {
return 443
}
return 80
}
func externalCredentialServerPort(parsedURL *url.URL, configuredPort uint16) uint16 {
if configuredPort != 0 {
return configuredPort
}
return externalCredentialURLPort(parsedURL)
}
func externalCredentialBaseURL(parsedURL *url.URL) string {
baseURL := parsedURL.Scheme + "://" + parsedURL.Host
if parsedURL.Path != "" && parsedURL.Path != "/" {
baseURL += parsedURL.Path
}
if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' {
baseURL = baseURL[:len(baseURL)-1]
}
return baseURL
}
func externalCredentialReversePath(parsedURL *url.URL, endpointPath string) string {
pathPrefix := parsedURL.EscapedPath()
if pathPrefix == "/" {
pathPrefix = ""
} else {
pathPrefix = strings.TrimSuffix(pathPrefix, "/")
}
return pathPrefix + endpointPath
}
func newExternalCredential(ctx context.Context, tag string, options option.CCMExternalCredentialOptions, logger log.ContextLogger) (*externalCredential, error) {
@@ -85,11 +130,7 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx
Transport: &http.Transport{
ForceAttemptHTTP2: false,
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
session := cred.getReverseSession()
if session == nil || session.IsClosed() {
return nil, E.New("reverse connection not established for ", cred.tag)
}
return session.Open()
return cred.openReverseConnection(ctx)
},
},
}
@@ -115,24 +156,7 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx
ForceAttemptHTTP2: true,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
if options.Server != "" {
serverPort := options.ServerPort
if serverPort == 0 {
portStr := parsedURL.Port()
if portStr != "" {
port, parseErr := strconv.ParseUint(portStr, 10, 16)
if parseErr == nil {
serverPort = uint16(port)
}
}
if serverPort == 0 {
if parsedURL.Scheme == "https" {
serverPort = 443
} else {
serverPort = 80
}
}
}
destination := M.ParseSocksaddrHostPort(options.Server, serverPort)
destination := M.ParseSocksaddrHostPort(options.Server, externalCredentialServerPort(parsedURL, options.ServerPort))
return credentialDialer.DialContext(ctx, network, destination)
}
return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
@@ -147,19 +171,17 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx
}
}
baseURL := parsedURL.Scheme + "://" + parsedURL.Host
if parsedURL.Path != "" && parsedURL.Path != "/" {
baseURL += parsedURL.Path
}
if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' {
baseURL = baseURL[:len(baseURL)-1]
}
cred.baseURL = baseURL
cred.baseURL = externalCredentialBaseURL(parsedURL)
if options.Reverse {
// Connector mode: we dial out to serve, not to proxy
cred.connectorDialer = credentialDialer
if options.Server != "" {
cred.connectorDestination = M.ParseSocksaddrHostPort(options.Server, externalCredentialServerPort(parsedURL, options.ServerPort))
} else {
cred.connectorDestination = M.ParseSocksaddrHostPort(parsedURL.Hostname(), externalCredentialURLPort(parsedURL))
}
cred.connectorRequestPath = externalCredentialReversePath(parsedURL, "/ccm/v1/reverse")
cred.connectorURL = parsedURL
if parsedURL.Scheme == "https" {
cred.connectorTLS = &stdTLS.Config{
@@ -208,18 +230,13 @@ func (c *externalCredential) isExternal() bool {
}
func (c *externalCredential) isAvailable() bool {
if c.reverse && c.connectorURL != nil {
return false // connector mode: not for local proxying
}
if c.baseURL == reverseProxyBaseURL {
// receiver mode: only available when reverse connection active
session := c.getReverseSession()
return session != nil && !session.IsClosed()
}
return true
return c.unavailableError() == nil
}
func (c *externalCredential) isUsable() bool {
if !c.isAvailable() {
return false
}
c.stateMutex.RLock()
if c.state.hardRateLimited {
if time.Now().Before(c.state.rateLimitResetAt) {
@@ -279,6 +296,15 @@ func (c *externalCredential) earliestReset() time.Time {
}
func (c *externalCredential) unavailableError() error {
if c.reverse && c.connectorURL != nil {
return E.New("credential ", c.tag, " is unavailable: reverse connector credentials cannot serve local requests")
}
if c.baseURL == reverseProxyBaseURL {
session := c.getReverseSession()
if session == nil || session.IsClosed() {
return E.New("credential ", c.tag, " is unavailable: reverse connection not established")
}
}
return nil
}
@@ -310,6 +336,32 @@ func (c *externalCredential) buildProxyRequest(ctx context.Context, original *ht
return proxyRequest, nil
}
func (c *externalCredential) openReverseConnection(ctx context.Context) (net.Conn, error) {
if ctx == nil {
ctx = context.Background()
}
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
session := c.getReverseSession()
if session == nil || session.IsClosed() {
return nil, E.New("reverse connection not established for ", c.tag)
}
conn, err := session.Open()
if err != nil {
return nil, err
}
select {
case <-ctx.Done():
conn.Close()
return nil, ctx.Err()
default:
}
return conn, nil
}
func (c *externalCredential) updateStateFromHeaders(headers http.Header) {
c.stateMutex.Lock()
isFirstUpdate := c.state.lastUpdated.IsZero()

View File

@@ -160,6 +160,9 @@ func connectorBackoff(failures int) time.Duration {
}
func (c *externalCredential) connectorConnect() error {
if c.reverseService == nil {
return E.New("reverse service not initialized")
}
destination := c.connectorResolveDestination()
conn, err := c.connectorDialer.DialContext(c.reverseContext, "tcp", destination)
if err != nil {
@@ -176,7 +179,7 @@ func (c *externalCredential) connectorConnect() error {
conn = tlsConn
}
upgradeRequest := "GET /ccm/v1/reverse HTTP/1.1\r\n" +
upgradeRequest := "GET " + c.connectorRequestPath + " HTTP/1.1\r\n" +
"Host: " + c.connectorURL.Host + "\r\n" +
"Connection: Upgrade\r\n" +
"Upgrade: reverse-proxy\r\n" +
@@ -231,13 +234,5 @@ func (c *externalCredential) connectorConnect() error {
}
func (c *externalCredential) connectorResolveDestination() M.Socksaddr {
port := c.connectorURL.Port()
if port == "" {
if c.connectorURL.Scheme == "https" {
port = "443"
} else {
port = "80"
}
}
return M.ParseSocksaddr(net.JoinHostPort(c.connectorURL.Hostname(), port))
return c.connectorDestination
}

View File

@@ -254,13 +254,13 @@ func (s *Service) Start(stage adapter.StartStage) error {
s.userManager.UpdateUsers(s.options.Users)
for _, cred := range s.allCredentials {
if extCred, ok := cred.(*externalCredential); ok && extCred.reverse && extCred.connectorURL != nil {
extCred.reverseService = s
}
err := cred.start()
if err != nil {
return err
}
if extCred, ok := cred.(*externalCredential); ok && extCred.reverse && extCred.connectorURL != nil {
extCred.reverseService = s
}
}
router := chi.NewRouter()
@@ -801,6 +801,7 @@ func (s *Service) InterfaceUpdated() {
continue
}
if extCred.reverse && extCred.connectorURL != nil {
extCred.reverseService = s
extCred.reverseCancel()
extCred.reverseContext, extCred.reverseCancel = context.WithCancel(context.Background())
go extCred.connectorLoop()

View File

@@ -9,7 +9,9 @@ import (
"net"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"sync"
"time"
@@ -47,15 +49,74 @@ type externalCredential struct {
requestAccess sync.Mutex
// Reverse proxy fields
reverse bool
reverseSession *yamux.Session
reverseAccess sync.RWMutex
reverseContext context.Context
reverseCancel context.CancelFunc
connectorDialer N.Dialer
connectorURL *url.URL
connectorTLS *stdTLS.Config
reverseService http.Handler
reverse bool
reverseSession *yamux.Session
reverseAccess sync.RWMutex
reverseContext context.Context
reverseCancel context.CancelFunc
connectorDialer N.Dialer
connectorDestination M.Socksaddr
connectorRequestPath string
connectorURL *url.URL
connectorTLS *stdTLS.Config
reverseService http.Handler
}
type reverseSessionDialer struct {
credential *externalCredential
}
func (d reverseSessionDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
if N.NetworkName(network) != N.NetworkTCP {
return nil, os.ErrInvalid
}
return d.credential.openReverseConnection(ctx)
}
func (d reverseSessionDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
return nil, os.ErrInvalid
}
func externalCredentialURLPort(parsedURL *url.URL) uint16 {
portStr := parsedURL.Port()
if portStr != "" {
port, err := strconv.ParseUint(portStr, 10, 16)
if err == nil {
return uint16(port)
}
}
if parsedURL.Scheme == "https" {
return 443
}
return 80
}
func externalCredentialServerPort(parsedURL *url.URL, configuredPort uint16) uint16 {
if configuredPort != 0 {
return configuredPort
}
return externalCredentialURLPort(parsedURL)
}
func externalCredentialBaseURL(parsedURL *url.URL) string {
baseURL := parsedURL.Scheme + "://" + parsedURL.Host
if parsedURL.Path != "" && parsedURL.Path != "/" {
baseURL += parsedURL.Path
}
if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' {
baseURL = baseURL[:len(baseURL)-1]
}
return baseURL
}
func externalCredentialReversePath(parsedURL *url.URL, endpointPath string) string {
pathPrefix := parsedURL.EscapedPath()
if pathPrefix == "/" {
pathPrefix = ""
} else {
pathPrefix = strings.TrimSuffix(pathPrefix, "/")
}
return pathPrefix + endpointPath
}
func newExternalCredential(ctx context.Context, tag string, options option.OCMExternalCredentialOptions, logger log.ContextLogger) (*externalCredential, error) {
@@ -82,15 +143,12 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx
if options.URL == "" {
// Receiver mode: no URL, wait for reverse connection
cred.baseURL = reverseProxyBaseURL
cred.credDialer = reverseSessionDialer{credential: cred}
cred.httpClient = &http.Client{
Transport: &http.Transport{
ForceAttemptHTTP2: false,
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
session := cred.getReverseSession()
if session == nil || session.IsClosed() {
return nil, E.New("reverse connection not established for ", cred.tag)
}
return session.Open()
return cred.openReverseConnection(ctx)
},
},
}
@@ -116,24 +174,7 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx
ForceAttemptHTTP2: true,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
if options.Server != "" {
serverPort := options.ServerPort
if serverPort == 0 {
portStr := parsedURL.Port()
if portStr != "" {
port, parseErr := strconv.ParseUint(portStr, 10, 16)
if parseErr == nil {
serverPort = uint16(port)
}
}
if serverPort == 0 {
if parsedURL.Scheme == "https" {
serverPort = 443
} else {
serverPort = 80
}
}
}
destination := M.ParseSocksaddrHostPort(options.Server, serverPort)
destination := M.ParseSocksaddrHostPort(options.Server, externalCredentialServerPort(parsedURL, options.ServerPort))
return credentialDialer.DialContext(ctx, network, destination)
}
return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
@@ -148,19 +189,17 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx
}
}
baseURL := parsedURL.Scheme + "://" + parsedURL.Host
if parsedURL.Path != "" && parsedURL.Path != "/" {
baseURL += parsedURL.Path
}
if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' {
baseURL = baseURL[:len(baseURL)-1]
}
cred.baseURL = baseURL
cred.baseURL = externalCredentialBaseURL(parsedURL)
if options.Reverse {
// Connector mode: we dial out to serve, not to proxy
cred.connectorDialer = credentialDialer
if options.Server != "" {
cred.connectorDestination = M.ParseSocksaddrHostPort(options.Server, externalCredentialServerPort(parsedURL, options.ServerPort))
} else {
cred.connectorDestination = M.ParseSocksaddrHostPort(parsedURL.Hostname(), externalCredentialURLPort(parsedURL))
}
cred.connectorRequestPath = externalCredentialReversePath(parsedURL, "/ocm/v1/reverse")
cred.connectorURL = parsedURL
if parsedURL.Scheme == "https" {
cred.connectorTLS = &stdTLS.Config{
@@ -214,18 +253,13 @@ func (c *externalCredential) isExternal() bool {
}
func (c *externalCredential) isAvailable() bool {
if c.reverse && c.connectorURL != nil {
return false // connector mode: not for local proxying
}
if c.baseURL == reverseProxyBaseURL {
// receiver mode: only available when reverse connection active
session := c.getReverseSession()
return session != nil && !session.IsClosed()
}
return true
return c.unavailableError() == nil
}
func (c *externalCredential) isUsable() bool {
if !c.isAvailable() {
return false
}
c.stateMutex.RLock()
if c.state.hardRateLimited {
if time.Now().Before(c.state.rateLimitResetAt) {
@@ -284,6 +318,15 @@ func (c *externalCredential) earliestReset() time.Time {
}
func (c *externalCredential) unavailableError() error {
if c.reverse && c.connectorURL != nil {
return E.New("credential ", c.tag, " is unavailable: reverse connector credentials cannot serve local requests")
}
if c.baseURL == reverseProxyBaseURL {
session := c.getReverseSession()
if session == nil || session.IsClosed() {
return E.New("credential ", c.tag, " is unavailable: reverse connection not established")
}
}
return nil
}
@@ -315,6 +358,32 @@ func (c *externalCredential) buildProxyRequest(ctx context.Context, original *ht
return proxyRequest, nil
}
func (c *externalCredential) openReverseConnection(ctx context.Context) (net.Conn, error) {
if ctx == nil {
ctx = context.Background()
}
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
session := c.getReverseSession()
if session == nil || session.IsClosed() {
return nil, E.New("reverse connection not established for ", c.tag)
}
conn, err := session.Open()
if err != nil {
return nil, err
}
select {
case <-ctx.Done():
conn.Close()
return nil, ctx.Err()
default:
}
return conn, nil
}
func (c *externalCredential) updateStateFromHeaders(headers http.Header) {
c.stateMutex.Lock()
isFirstUpdate := c.state.lastUpdated.IsZero()

View File

@@ -160,6 +160,9 @@ func connectorBackoff(failures int) time.Duration {
}
func (c *externalCredential) connectorConnect() error {
if c.reverseService == nil {
return E.New("reverse service not initialized")
}
destination := c.connectorResolveDestination()
conn, err := c.connectorDialer.DialContext(c.reverseContext, "tcp", destination)
if err != nil {
@@ -176,7 +179,7 @@ func (c *externalCredential) connectorConnect() error {
conn = tlsConn
}
upgradeRequest := "GET /ocm/v1/reverse HTTP/1.1\r\n" +
upgradeRequest := "GET " + c.connectorRequestPath + " HTTP/1.1\r\n" +
"Host: " + c.connectorURL.Host + "\r\n" +
"Connection: Upgrade\r\n" +
"Upgrade: reverse-proxy\r\n" +
@@ -231,13 +234,5 @@ func (c *externalCredential) connectorConnect() error {
}
func (c *externalCredential) connectorResolveDestination() M.Socksaddr {
port := c.connectorURL.Port()
if port == "" {
if c.connectorURL.Scheme == "https" {
port = "443"
} else {
port = "80"
}
}
return M.ParseSocksaddr(net.JoinHostPort(c.connectorURL.Hostname(), port))
return c.connectorDestination
}

View File

@@ -297,6 +297,9 @@ func (s *Service) Start(stage adapter.StartStage) error {
s.userManager.UpdateUsers(s.options.Users)
for _, cred := range s.allCredentials {
if extCred, ok := cred.(*externalCredential); ok && extCred.reverse && extCred.connectorURL != nil {
extCred.reverseService = s
}
err := cred.start()
if err != nil {
return err
@@ -305,9 +308,6 @@ func (s *Service) Start(stage adapter.StartStage) error {
cred.setOnBecameUnusable(func() {
s.interruptWebSocketSessionsForCredential(tag)
})
if extCred, ok := cred.(*externalCredential); ok && extCred.reverse && extCred.connectorURL != nil {
extCred.reverseService = s
}
}
if len(s.options.Credentials) > 0 {
err := validateOCMCompositeCredentialModes(s.options, s.providers)
@@ -875,6 +875,7 @@ func (s *Service) InterfaceUpdated() {
continue
}
if extCred.reverse && extCred.connectorURL != nil {
extCred.reverseService = s
extCred.reverseCancel()
extCred.reverseContext, extCred.reverseCancel = context.WithCancel(context.Background())
go extCred.connectorLoop()