diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index e8e53c181..0bcf15a77 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -10,6 +10,7 @@ import ( "net/http" "net/url" "strconv" + "strings" "sync" "time" @@ -46,15 +47,59 @@ type externalCredential struct { requestAccess sync.Mutex // Reverse proxy fields - reverse bool - reverseSession *yamux.Session - reverseAccess sync.RWMutex - reverseContext context.Context - reverseCancel context.CancelFunc - connectorDialer N.Dialer - connectorURL *url.URL - connectorTLS *stdTLS.Config - reverseService http.Handler + reverse bool + reverseSession *yamux.Session + reverseAccess sync.RWMutex + reverseContext context.Context + reverseCancel context.CancelFunc + connectorDialer N.Dialer + connectorDestination M.Socksaddr + connectorRequestPath string + connectorURL *url.URL + connectorTLS *stdTLS.Config + reverseService http.Handler +} + +func externalCredentialURLPort(parsedURL *url.URL) uint16 { + portStr := parsedURL.Port() + if portStr != "" { + port, err := strconv.ParseUint(portStr, 10, 16) + if err == nil { + return uint16(port) + } + } + if parsedURL.Scheme == "https" { + return 443 + } + return 80 +} + +func externalCredentialServerPort(parsedURL *url.URL, configuredPort uint16) uint16 { + if configuredPort != 0 { + return configuredPort + } + return externalCredentialURLPort(parsedURL) +} + +func externalCredentialBaseURL(parsedURL *url.URL) string { + baseURL := parsedURL.Scheme + "://" + parsedURL.Host + if parsedURL.Path != "" && parsedURL.Path != "/" { + baseURL += parsedURL.Path + } + if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' { + baseURL = baseURL[:len(baseURL)-1] + } + return baseURL +} + +func externalCredentialReversePath(parsedURL *url.URL, endpointPath string) string { + pathPrefix := parsedURL.EscapedPath() + if pathPrefix == "/" { + pathPrefix = "" + } else { + pathPrefix = strings.TrimSuffix(pathPrefix, "/") + } + return pathPrefix + endpointPath } func newExternalCredential(ctx context.Context, tag string, options option.CCMExternalCredentialOptions, logger log.ContextLogger) (*externalCredential, error) { @@ -85,11 +130,7 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx Transport: &http.Transport{ ForceAttemptHTTP2: false, DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - session := cred.getReverseSession() - if session == nil || session.IsClosed() { - return nil, E.New("reverse connection not established for ", cred.tag) - } - return session.Open() + return cred.openReverseConnection(ctx) }, }, } @@ -115,24 +156,7 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx ForceAttemptHTTP2: true, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { if options.Server != "" { - serverPort := options.ServerPort - if serverPort == 0 { - portStr := parsedURL.Port() - if portStr != "" { - port, parseErr := strconv.ParseUint(portStr, 10, 16) - if parseErr == nil { - serverPort = uint16(port) - } - } - if serverPort == 0 { - if parsedURL.Scheme == "https" { - serverPort = 443 - } else { - serverPort = 80 - } - } - } - destination := M.ParseSocksaddrHostPort(options.Server, serverPort) + destination := M.ParseSocksaddrHostPort(options.Server, externalCredentialServerPort(parsedURL, options.ServerPort)) return credentialDialer.DialContext(ctx, network, destination) } return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) @@ -147,19 +171,17 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx } } - baseURL := parsedURL.Scheme + "://" + parsedURL.Host - if parsedURL.Path != "" && parsedURL.Path != "/" { - baseURL += parsedURL.Path - } - if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' { - baseURL = baseURL[:len(baseURL)-1] - } - - cred.baseURL = baseURL + cred.baseURL = externalCredentialBaseURL(parsedURL) if options.Reverse { // Connector mode: we dial out to serve, not to proxy cred.connectorDialer = credentialDialer + if options.Server != "" { + cred.connectorDestination = M.ParseSocksaddrHostPort(options.Server, externalCredentialServerPort(parsedURL, options.ServerPort)) + } else { + cred.connectorDestination = M.ParseSocksaddrHostPort(parsedURL.Hostname(), externalCredentialURLPort(parsedURL)) + } + cred.connectorRequestPath = externalCredentialReversePath(parsedURL, "/ccm/v1/reverse") cred.connectorURL = parsedURL if parsedURL.Scheme == "https" { cred.connectorTLS = &stdTLS.Config{ @@ -208,18 +230,13 @@ func (c *externalCredential) isExternal() bool { } func (c *externalCredential) isAvailable() bool { - if c.reverse && c.connectorURL != nil { - return false // connector mode: not for local proxying - } - if c.baseURL == reverseProxyBaseURL { - // receiver mode: only available when reverse connection active - session := c.getReverseSession() - return session != nil && !session.IsClosed() - } - return true + return c.unavailableError() == nil } func (c *externalCredential) isUsable() bool { + if !c.isAvailable() { + return false + } c.stateMutex.RLock() if c.state.hardRateLimited { if time.Now().Before(c.state.rateLimitResetAt) { @@ -279,6 +296,15 @@ func (c *externalCredential) earliestReset() time.Time { } func (c *externalCredential) unavailableError() error { + if c.reverse && c.connectorURL != nil { + return E.New("credential ", c.tag, " is unavailable: reverse connector credentials cannot serve local requests") + } + if c.baseURL == reverseProxyBaseURL { + session := c.getReverseSession() + if session == nil || session.IsClosed() { + return E.New("credential ", c.tag, " is unavailable: reverse connection not established") + } + } return nil } @@ -310,6 +336,32 @@ func (c *externalCredential) buildProxyRequest(ctx context.Context, original *ht return proxyRequest, nil } +func (c *externalCredential) openReverseConnection(ctx context.Context) (net.Conn, error) { + if ctx == nil { + ctx = context.Background() + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + session := c.getReverseSession() + if session == nil || session.IsClosed() { + return nil, E.New("reverse connection not established for ", c.tag) + } + conn, err := session.Open() + if err != nil { + return nil, err + } + select { + case <-ctx.Done(): + conn.Close() + return nil, ctx.Err() + default: + } + return conn, nil +} + func (c *externalCredential) updateStateFromHeaders(headers http.Header) { c.stateMutex.Lock() isFirstUpdate := c.state.lastUpdated.IsZero() diff --git a/service/ccm/reverse.go b/service/ccm/reverse.go index 571c8c55a..7e38c9ced 100644 --- a/service/ccm/reverse.go +++ b/service/ccm/reverse.go @@ -160,6 +160,9 @@ func connectorBackoff(failures int) time.Duration { } func (c *externalCredential) connectorConnect() error { + if c.reverseService == nil { + return E.New("reverse service not initialized") + } destination := c.connectorResolveDestination() conn, err := c.connectorDialer.DialContext(c.reverseContext, "tcp", destination) if err != nil { @@ -176,7 +179,7 @@ func (c *externalCredential) connectorConnect() error { conn = tlsConn } - upgradeRequest := "GET /ccm/v1/reverse HTTP/1.1\r\n" + + upgradeRequest := "GET " + c.connectorRequestPath + " HTTP/1.1\r\n" + "Host: " + c.connectorURL.Host + "\r\n" + "Connection: Upgrade\r\n" + "Upgrade: reverse-proxy\r\n" + @@ -231,13 +234,5 @@ func (c *externalCredential) connectorConnect() error { } func (c *externalCredential) connectorResolveDestination() M.Socksaddr { - port := c.connectorURL.Port() - if port == "" { - if c.connectorURL.Scheme == "https" { - port = "443" - } else { - port = "80" - } - } - return M.ParseSocksaddr(net.JoinHostPort(c.connectorURL.Hostname(), port)) + return c.connectorDestination } diff --git a/service/ccm/service.go b/service/ccm/service.go index 69697b5c0..5d3415ea2 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -254,13 +254,13 @@ func (s *Service) Start(stage adapter.StartStage) error { s.userManager.UpdateUsers(s.options.Users) for _, cred := range s.allCredentials { + if extCred, ok := cred.(*externalCredential); ok && extCred.reverse && extCred.connectorURL != nil { + extCred.reverseService = s + } err := cred.start() if err != nil { return err } - if extCred, ok := cred.(*externalCredential); ok && extCred.reverse && extCred.connectorURL != nil { - extCred.reverseService = s - } } router := chi.NewRouter() @@ -801,6 +801,7 @@ func (s *Service) InterfaceUpdated() { continue } if extCred.reverse && extCred.connectorURL != nil { + extCred.reverseService = s extCred.reverseCancel() extCred.reverseContext, extCred.reverseCancel = context.WithCancel(context.Background()) go extCred.connectorLoop() diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 8226d6366..83d37f385 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -9,7 +9,9 @@ import ( "net" "net/http" "net/url" + "os" "strconv" + "strings" "sync" "time" @@ -47,15 +49,74 @@ type externalCredential struct { requestAccess sync.Mutex // Reverse proxy fields - reverse bool - reverseSession *yamux.Session - reverseAccess sync.RWMutex - reverseContext context.Context - reverseCancel context.CancelFunc - connectorDialer N.Dialer - connectorURL *url.URL - connectorTLS *stdTLS.Config - reverseService http.Handler + reverse bool + reverseSession *yamux.Session + reverseAccess sync.RWMutex + reverseContext context.Context + reverseCancel context.CancelFunc + connectorDialer N.Dialer + connectorDestination M.Socksaddr + connectorRequestPath string + connectorURL *url.URL + connectorTLS *stdTLS.Config + reverseService http.Handler +} + +type reverseSessionDialer struct { + credential *externalCredential +} + +func (d reverseSessionDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + if N.NetworkName(network) != N.NetworkTCP { + return nil, os.ErrInvalid + } + return d.credential.openReverseConnection(ctx) +} + +func (d reverseSessionDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + return nil, os.ErrInvalid +} + +func externalCredentialURLPort(parsedURL *url.URL) uint16 { + portStr := parsedURL.Port() + if portStr != "" { + port, err := strconv.ParseUint(portStr, 10, 16) + if err == nil { + return uint16(port) + } + } + if parsedURL.Scheme == "https" { + return 443 + } + return 80 +} + +func externalCredentialServerPort(parsedURL *url.URL, configuredPort uint16) uint16 { + if configuredPort != 0 { + return configuredPort + } + return externalCredentialURLPort(parsedURL) +} + +func externalCredentialBaseURL(parsedURL *url.URL) string { + baseURL := parsedURL.Scheme + "://" + parsedURL.Host + if parsedURL.Path != "" && parsedURL.Path != "/" { + baseURL += parsedURL.Path + } + if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' { + baseURL = baseURL[:len(baseURL)-1] + } + return baseURL +} + +func externalCredentialReversePath(parsedURL *url.URL, endpointPath string) string { + pathPrefix := parsedURL.EscapedPath() + if pathPrefix == "/" { + pathPrefix = "" + } else { + pathPrefix = strings.TrimSuffix(pathPrefix, "/") + } + return pathPrefix + endpointPath } func newExternalCredential(ctx context.Context, tag string, options option.OCMExternalCredentialOptions, logger log.ContextLogger) (*externalCredential, error) { @@ -82,15 +143,12 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx if options.URL == "" { // Receiver mode: no URL, wait for reverse connection cred.baseURL = reverseProxyBaseURL + cred.credDialer = reverseSessionDialer{credential: cred} cred.httpClient = &http.Client{ Transport: &http.Transport{ ForceAttemptHTTP2: false, DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - session := cred.getReverseSession() - if session == nil || session.IsClosed() { - return nil, E.New("reverse connection not established for ", cred.tag) - } - return session.Open() + return cred.openReverseConnection(ctx) }, }, } @@ -116,24 +174,7 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx ForceAttemptHTTP2: true, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { if options.Server != "" { - serverPort := options.ServerPort - if serverPort == 0 { - portStr := parsedURL.Port() - if portStr != "" { - port, parseErr := strconv.ParseUint(portStr, 10, 16) - if parseErr == nil { - serverPort = uint16(port) - } - } - if serverPort == 0 { - if parsedURL.Scheme == "https" { - serverPort = 443 - } else { - serverPort = 80 - } - } - } - destination := M.ParseSocksaddrHostPort(options.Server, serverPort) + destination := M.ParseSocksaddrHostPort(options.Server, externalCredentialServerPort(parsedURL, options.ServerPort)) return credentialDialer.DialContext(ctx, network, destination) } return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) @@ -148,19 +189,17 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx } } - baseURL := parsedURL.Scheme + "://" + parsedURL.Host - if parsedURL.Path != "" && parsedURL.Path != "/" { - baseURL += parsedURL.Path - } - if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' { - baseURL = baseURL[:len(baseURL)-1] - } - - cred.baseURL = baseURL + cred.baseURL = externalCredentialBaseURL(parsedURL) if options.Reverse { // Connector mode: we dial out to serve, not to proxy cred.connectorDialer = credentialDialer + if options.Server != "" { + cred.connectorDestination = M.ParseSocksaddrHostPort(options.Server, externalCredentialServerPort(parsedURL, options.ServerPort)) + } else { + cred.connectorDestination = M.ParseSocksaddrHostPort(parsedURL.Hostname(), externalCredentialURLPort(parsedURL)) + } + cred.connectorRequestPath = externalCredentialReversePath(parsedURL, "/ocm/v1/reverse") cred.connectorURL = parsedURL if parsedURL.Scheme == "https" { cred.connectorTLS = &stdTLS.Config{ @@ -214,18 +253,13 @@ func (c *externalCredential) isExternal() bool { } func (c *externalCredential) isAvailable() bool { - if c.reverse && c.connectorURL != nil { - return false // connector mode: not for local proxying - } - if c.baseURL == reverseProxyBaseURL { - // receiver mode: only available when reverse connection active - session := c.getReverseSession() - return session != nil && !session.IsClosed() - } - return true + return c.unavailableError() == nil } func (c *externalCredential) isUsable() bool { + if !c.isAvailable() { + return false + } c.stateMutex.RLock() if c.state.hardRateLimited { if time.Now().Before(c.state.rateLimitResetAt) { @@ -284,6 +318,15 @@ func (c *externalCredential) earliestReset() time.Time { } func (c *externalCredential) unavailableError() error { + if c.reverse && c.connectorURL != nil { + return E.New("credential ", c.tag, " is unavailable: reverse connector credentials cannot serve local requests") + } + if c.baseURL == reverseProxyBaseURL { + session := c.getReverseSession() + if session == nil || session.IsClosed() { + return E.New("credential ", c.tag, " is unavailable: reverse connection not established") + } + } return nil } @@ -315,6 +358,32 @@ func (c *externalCredential) buildProxyRequest(ctx context.Context, original *ht return proxyRequest, nil } +func (c *externalCredential) openReverseConnection(ctx context.Context) (net.Conn, error) { + if ctx == nil { + ctx = context.Background() + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + session := c.getReverseSession() + if session == nil || session.IsClosed() { + return nil, E.New("reverse connection not established for ", c.tag) + } + conn, err := session.Open() + if err != nil { + return nil, err + } + select { + case <-ctx.Done(): + conn.Close() + return nil, ctx.Err() + default: + } + return conn, nil +} + func (c *externalCredential) updateStateFromHeaders(headers http.Header) { c.stateMutex.Lock() isFirstUpdate := c.state.lastUpdated.IsZero() diff --git a/service/ocm/reverse.go b/service/ocm/reverse.go index b02a20222..23ca1cc47 100644 --- a/service/ocm/reverse.go +++ b/service/ocm/reverse.go @@ -160,6 +160,9 @@ func connectorBackoff(failures int) time.Duration { } func (c *externalCredential) connectorConnect() error { + if c.reverseService == nil { + return E.New("reverse service not initialized") + } destination := c.connectorResolveDestination() conn, err := c.connectorDialer.DialContext(c.reverseContext, "tcp", destination) if err != nil { @@ -176,7 +179,7 @@ func (c *externalCredential) connectorConnect() error { conn = tlsConn } - upgradeRequest := "GET /ocm/v1/reverse HTTP/1.1\r\n" + + upgradeRequest := "GET " + c.connectorRequestPath + " HTTP/1.1\r\n" + "Host: " + c.connectorURL.Host + "\r\n" + "Connection: Upgrade\r\n" + "Upgrade: reverse-proxy\r\n" + @@ -231,13 +234,5 @@ func (c *externalCredential) connectorConnect() error { } func (c *externalCredential) connectorResolveDestination() M.Socksaddr { - port := c.connectorURL.Port() - if port == "" { - if c.connectorURL.Scheme == "https" { - port = "443" - } else { - port = "80" - } - } - return M.ParseSocksaddr(net.JoinHostPort(c.connectorURL.Hostname(), port)) + return c.connectorDestination } diff --git a/service/ocm/service.go b/service/ocm/service.go index 50a44db89..245f2a444 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -297,6 +297,9 @@ func (s *Service) Start(stage adapter.StartStage) error { s.userManager.UpdateUsers(s.options.Users) for _, cred := range s.allCredentials { + if extCred, ok := cred.(*externalCredential); ok && extCred.reverse && extCred.connectorURL != nil { + extCred.reverseService = s + } err := cred.start() if err != nil { return err @@ -305,9 +308,6 @@ func (s *Service) Start(stage adapter.StartStage) error { cred.setOnBecameUnusable(func() { s.interruptWebSocketSessionsForCredential(tag) }) - if extCred, ok := cred.(*externalCredential); ok && extCred.reverse && extCred.connectorURL != nil { - extCred.reverseService = s - } } if len(s.options.Credentials) > 0 { err := validateOCMCompositeCredentialModes(s.options, s.providers) @@ -875,6 +875,7 @@ func (s *Service) InterfaceUpdated() { continue } if extCred.reverse && extCred.connectorURL != nil { + extCred.reverseService = s extCred.reverseCancel() extCred.reverseContext, extCred.reverseCancel = context.WithCancel(context.Background()) go extCred.connectorLoop()