ccm,ocm: fix data race on reverseContext/reverseCancel

InterfaceUpdated() writes reverseContext and reverseCancel without
synchronization while connectorLoop/connectorConnect goroutines
read them concurrently. close() also accesses reverseCancel without
a lock.

Fix by extending reverseAccess mutex to protect these fields:
- Add getReverseContext()/resetReverseContext() methods
- Pass context as parameter to connectorConnect
- Merge close() into a single lock acquisition
- Use resetReverseContext() in InterfaceUpdated()
This commit is contained in:
世界
2026-03-13 20:10:31 +08:00
parent 1824881719
commit 3b177df05e
6 changed files with 50 additions and 22 deletions

View File

@@ -542,10 +542,10 @@ func (c *externalCredential) httpTransport() *http.Client {
}
func (c *externalCredential) close() {
c.reverseAccess.Lock()
if c.reverseCancel != nil {
c.reverseCancel()
}
c.reverseAccess.Lock()
session := c.reverseSession
c.reverseSession = nil
c.reverseAccess.Unlock()
@@ -584,3 +584,16 @@ func (c *externalCredential) clearReverseSession(session *yamux.Session) {
}
c.reverseAccess.Unlock()
}
func (c *externalCredential) getReverseContext() context.Context {
c.reverseAccess.RLock()
defer c.reverseAccess.RUnlock()
return c.reverseContext
}
func (c *externalCredential) resetReverseContext() {
c.reverseAccess.Lock()
c.reverseCancel()
c.reverseContext, c.reverseCancel = context.WithCancel(context.Background())
c.reverseAccess.Unlock()
}

View File

@@ -2,6 +2,7 @@ package ccm
import (
"bufio"
"context"
stdTLS "crypto/tls"
"errors"
"io"
@@ -125,15 +126,16 @@ func (s *Service) findReceiverCredential(token string) *externalCredential {
func (c *externalCredential) connectorLoop() {
var consecutiveFailures int
ctx := c.getReverseContext()
for {
select {
case <-c.reverseContext.Done():
case <-ctx.Done():
return
default:
}
sessionLifetime, err := c.connectorConnect()
if c.reverseContext.Err() != nil {
sessionLifetime, err := c.connectorConnect(ctx)
if ctx.Err() != nil {
return
}
if sessionLifetime >= connectorBackoffResetThreshold {
@@ -144,7 +146,7 @@ func (c *externalCredential) connectorLoop() {
c.logger.Warn("reverse connection for ", c.tag, " lost: ", err, ", reconnecting in ", backoff)
select {
case <-time.After(backoff):
case <-c.reverseContext.Done():
case <-ctx.Done():
return
}
}
@@ -164,19 +166,19 @@ func connectorBackoff(failures int) time.Duration {
return base + jitter
}
func (c *externalCredential) connectorConnect() (time.Duration, error) {
func (c *externalCredential) connectorConnect(ctx context.Context) (time.Duration, error) {
if c.reverseService == nil {
return 0, E.New("reverse service not initialized")
}
destination := c.connectorResolveDestination()
conn, err := c.connectorDialer.DialContext(c.reverseContext, "tcp", destination)
conn, err := c.connectorDialer.DialContext(ctx, "tcp", destination)
if err != nil {
return 0, E.Cause(err, "dial")
}
if c.connectorTLS != nil {
tlsConn := stdTLS.Client(conn, c.connectorTLS.Clone())
err = tlsConn.HandshakeContext(c.reverseContext)
err = tlsConn.HandshakeContext(ctx)
if err != nil {
conn.Close()
return 0, E.Cause(err, "tls handshake")
@@ -234,7 +236,7 @@ func (c *externalCredential) connectorConnect() (time.Duration, error) {
}
err = httpServer.Serve(&yamuxNetListener{session: session})
sessionLifetime := time.Since(serveStart)
if err != nil && !errors.Is(err, http.ErrServerClosed) && c.reverseContext.Err() == nil {
if err != nil && !errors.Is(err, http.ErrServerClosed) && ctx.Err() == nil {
return sessionLifetime, E.Cause(err, "serve")
}
return sessionLifetime, E.New("connection closed")

View File

@@ -802,8 +802,7 @@ func (s *Service) InterfaceUpdated() {
}
if extCred.reverse && extCred.connectorURL != nil {
extCred.reverseService = s
extCred.reverseCancel()
extCred.reverseContext, extCred.reverseCancel = context.WithCancel(context.Background())
extCred.resetReverseContext()
go extCred.connectorLoop()
}
}

View File

@@ -594,10 +594,10 @@ func (c *externalCredential) ocmGetBaseURL() string {
}
func (c *externalCredential) close() {
c.reverseAccess.Lock()
if c.reverseCancel != nil {
c.reverseCancel()
}
c.reverseAccess.Lock()
session := c.reverseSession
c.reverseSession = nil
c.reverseAccess.Unlock()
@@ -636,3 +636,16 @@ func (c *externalCredential) clearReverseSession(session *yamux.Session) {
}
c.reverseAccess.Unlock()
}
func (c *externalCredential) getReverseContext() context.Context {
c.reverseAccess.RLock()
defer c.reverseAccess.RUnlock()
return c.reverseContext
}
func (c *externalCredential) resetReverseContext() {
c.reverseAccess.Lock()
c.reverseCancel()
c.reverseContext, c.reverseCancel = context.WithCancel(context.Background())
c.reverseAccess.Unlock()
}

View File

@@ -2,6 +2,7 @@ package ocm
import (
"bufio"
"context"
stdTLS "crypto/tls"
"errors"
"io"
@@ -125,15 +126,16 @@ func (s *Service) findReceiverCredential(token string) *externalCredential {
func (c *externalCredential) connectorLoop() {
var consecutiveFailures int
ctx := c.getReverseContext()
for {
select {
case <-c.reverseContext.Done():
case <-ctx.Done():
return
default:
}
sessionLifetime, err := c.connectorConnect()
if c.reverseContext.Err() != nil {
sessionLifetime, err := c.connectorConnect(ctx)
if ctx.Err() != nil {
return
}
if sessionLifetime >= connectorBackoffResetThreshold {
@@ -144,7 +146,7 @@ func (c *externalCredential) connectorLoop() {
c.logger.Warn("reverse connection for ", c.tag, " lost: ", err, ", reconnecting in ", backoff)
select {
case <-time.After(backoff):
case <-c.reverseContext.Done():
case <-ctx.Done():
return
}
}
@@ -164,19 +166,19 @@ func connectorBackoff(failures int) time.Duration {
return base + jitter
}
func (c *externalCredential) connectorConnect() (time.Duration, error) {
func (c *externalCredential) connectorConnect(ctx context.Context) (time.Duration, error) {
if c.reverseService == nil {
return 0, E.New("reverse service not initialized")
}
destination := c.connectorResolveDestination()
conn, err := c.connectorDialer.DialContext(c.reverseContext, "tcp", destination)
conn, err := c.connectorDialer.DialContext(ctx, "tcp", destination)
if err != nil {
return 0, E.Cause(err, "dial")
}
if c.connectorTLS != nil {
tlsConn := stdTLS.Client(conn, c.connectorTLS.Clone())
err = tlsConn.HandshakeContext(c.reverseContext)
err = tlsConn.HandshakeContext(ctx)
if err != nil {
conn.Close()
return 0, E.Cause(err, "tls handshake")
@@ -234,7 +236,7 @@ func (c *externalCredential) connectorConnect() (time.Duration, error) {
}
err = httpServer.Serve(&yamuxNetListener{session: session})
sessionLifetime := time.Since(serveStart)
if err != nil && !errors.Is(err, http.ErrServerClosed) && c.reverseContext.Err() == nil {
if err != nil && !errors.Is(err, http.ErrServerClosed) && ctx.Err() == nil {
return sessionLifetime, E.Cause(err, "serve")
}
return sessionLifetime, E.New("connection closed")

View File

@@ -876,8 +876,7 @@ func (s *Service) InterfaceUpdated() {
}
if extCred.reverse && extCred.connectorURL != nil {
extCred.reverseService = s
extCred.reverseCancel()
extCred.reverseContext, extCred.reverseCancel = context.WithCancel(context.Background())
extCred.resetReverseContext()
go extCred.connectorLoop()
}
}