From 3b177df05e3c5b44b6547463b80286d4f665d4c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 13 Mar 2026 20:10:31 +0800 Subject: [PATCH] ccm,ocm: fix data race on reverseContext/reverseCancel InterfaceUpdated() writes reverseContext and reverseCancel without synchronization while connectorLoop/connectorConnect goroutines read them concurrently. close() also accesses reverseCancel without a lock. Fix by extending reverseAccess mutex to protect these fields: - Add getReverseContext()/resetReverseContext() methods - Pass context as parameter to connectorConnect - Merge close() into a single lock acquisition - Use resetReverseContext() in InterfaceUpdated() --- service/ccm/credential_external.go | 15 ++++++++++++++- service/ccm/reverse.go | 18 ++++++++++-------- service/ccm/service.go | 3 +-- service/ocm/credential_external.go | 15 ++++++++++++++- service/ocm/reverse.go | 18 ++++++++++-------- service/ocm/service.go | 3 +-- 6 files changed, 50 insertions(+), 22 deletions(-) diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index 8a0ffda86..7459a8891 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -542,10 +542,10 @@ func (c *externalCredential) httpTransport() *http.Client { } func (c *externalCredential) close() { + c.reverseAccess.Lock() if c.reverseCancel != nil { c.reverseCancel() } - c.reverseAccess.Lock() session := c.reverseSession c.reverseSession = nil c.reverseAccess.Unlock() @@ -584,3 +584,16 @@ func (c *externalCredential) clearReverseSession(session *yamux.Session) { } c.reverseAccess.Unlock() } + +func (c *externalCredential) getReverseContext() context.Context { + c.reverseAccess.RLock() + defer c.reverseAccess.RUnlock() + return c.reverseContext +} + +func (c *externalCredential) resetReverseContext() { + c.reverseAccess.Lock() + c.reverseCancel() + c.reverseContext, c.reverseCancel = context.WithCancel(context.Background()) + c.reverseAccess.Unlock() +} diff --git a/service/ccm/reverse.go b/service/ccm/reverse.go index 97d71f0ad..e07480a0b 100644 --- a/service/ccm/reverse.go +++ b/service/ccm/reverse.go @@ -2,6 +2,7 @@ package ccm import ( "bufio" + "context" stdTLS "crypto/tls" "errors" "io" @@ -125,15 +126,16 @@ func (s *Service) findReceiverCredential(token string) *externalCredential { func (c *externalCredential) connectorLoop() { var consecutiveFailures int + ctx := c.getReverseContext() for { select { - case <-c.reverseContext.Done(): + case <-ctx.Done(): return default: } - sessionLifetime, err := c.connectorConnect() - if c.reverseContext.Err() != nil { + sessionLifetime, err := c.connectorConnect(ctx) + if ctx.Err() != nil { return } if sessionLifetime >= connectorBackoffResetThreshold { @@ -144,7 +146,7 @@ func (c *externalCredential) connectorLoop() { c.logger.Warn("reverse connection for ", c.tag, " lost: ", err, ", reconnecting in ", backoff) select { case <-time.After(backoff): - case <-c.reverseContext.Done(): + case <-ctx.Done(): return } } @@ -164,19 +166,19 @@ func connectorBackoff(failures int) time.Duration { return base + jitter } -func (c *externalCredential) connectorConnect() (time.Duration, error) { +func (c *externalCredential) connectorConnect(ctx context.Context) (time.Duration, error) { if c.reverseService == nil { return 0, E.New("reverse service not initialized") } destination := c.connectorResolveDestination() - conn, err := c.connectorDialer.DialContext(c.reverseContext, "tcp", destination) + conn, err := c.connectorDialer.DialContext(ctx, "tcp", destination) if err != nil { return 0, E.Cause(err, "dial") } if c.connectorTLS != nil { tlsConn := stdTLS.Client(conn, c.connectorTLS.Clone()) - err = tlsConn.HandshakeContext(c.reverseContext) + err = tlsConn.HandshakeContext(ctx) if err != nil { conn.Close() return 0, E.Cause(err, "tls handshake") @@ -234,7 +236,7 @@ func (c *externalCredential) connectorConnect() (time.Duration, error) { } err = httpServer.Serve(&yamuxNetListener{session: session}) sessionLifetime := time.Since(serveStart) - if err != nil && !errors.Is(err, http.ErrServerClosed) && c.reverseContext.Err() == nil { + if err != nil && !errors.Is(err, http.ErrServerClosed) && ctx.Err() == nil { return sessionLifetime, E.Cause(err, "serve") } return sessionLifetime, E.New("connection closed") diff --git a/service/ccm/service.go b/service/ccm/service.go index 5d3415ea2..2e7685e71 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -802,8 +802,7 @@ func (s *Service) InterfaceUpdated() { } if extCred.reverse && extCred.connectorURL != nil { extCred.reverseService = s - extCred.reverseCancel() - extCred.reverseContext, extCred.reverseCancel = context.WithCancel(context.Background()) + extCred.resetReverseContext() go extCred.connectorLoop() } } diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 0d19ea557..d396705f2 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -594,10 +594,10 @@ func (c *externalCredential) ocmGetBaseURL() string { } func (c *externalCredential) close() { + c.reverseAccess.Lock() if c.reverseCancel != nil { c.reverseCancel() } - c.reverseAccess.Lock() session := c.reverseSession c.reverseSession = nil c.reverseAccess.Unlock() @@ -636,3 +636,16 @@ func (c *externalCredential) clearReverseSession(session *yamux.Session) { } c.reverseAccess.Unlock() } + +func (c *externalCredential) getReverseContext() context.Context { + c.reverseAccess.RLock() + defer c.reverseAccess.RUnlock() + return c.reverseContext +} + +func (c *externalCredential) resetReverseContext() { + c.reverseAccess.Lock() + c.reverseCancel() + c.reverseContext, c.reverseCancel = context.WithCancel(context.Background()) + c.reverseAccess.Unlock() +} diff --git a/service/ocm/reverse.go b/service/ocm/reverse.go index b3c17f45f..e88ccea0a 100644 --- a/service/ocm/reverse.go +++ b/service/ocm/reverse.go @@ -2,6 +2,7 @@ package ocm import ( "bufio" + "context" stdTLS "crypto/tls" "errors" "io" @@ -125,15 +126,16 @@ func (s *Service) findReceiverCredential(token string) *externalCredential { func (c *externalCredential) connectorLoop() { var consecutiveFailures int + ctx := c.getReverseContext() for { select { - case <-c.reverseContext.Done(): + case <-ctx.Done(): return default: } - sessionLifetime, err := c.connectorConnect() - if c.reverseContext.Err() != nil { + sessionLifetime, err := c.connectorConnect(ctx) + if ctx.Err() != nil { return } if sessionLifetime >= connectorBackoffResetThreshold { @@ -144,7 +146,7 @@ func (c *externalCredential) connectorLoop() { c.logger.Warn("reverse connection for ", c.tag, " lost: ", err, ", reconnecting in ", backoff) select { case <-time.After(backoff): - case <-c.reverseContext.Done(): + case <-ctx.Done(): return } } @@ -164,19 +166,19 @@ func connectorBackoff(failures int) time.Duration { return base + jitter } -func (c *externalCredential) connectorConnect() (time.Duration, error) { +func (c *externalCredential) connectorConnect(ctx context.Context) (time.Duration, error) { if c.reverseService == nil { return 0, E.New("reverse service not initialized") } destination := c.connectorResolveDestination() - conn, err := c.connectorDialer.DialContext(c.reverseContext, "tcp", destination) + conn, err := c.connectorDialer.DialContext(ctx, "tcp", destination) if err != nil { return 0, E.Cause(err, "dial") } if c.connectorTLS != nil { tlsConn := stdTLS.Client(conn, c.connectorTLS.Clone()) - err = tlsConn.HandshakeContext(c.reverseContext) + err = tlsConn.HandshakeContext(ctx) if err != nil { conn.Close() return 0, E.Cause(err, "tls handshake") @@ -234,7 +236,7 @@ func (c *externalCredential) connectorConnect() (time.Duration, error) { } err = httpServer.Serve(&yamuxNetListener{session: session}) sessionLifetime := time.Since(serveStart) - if err != nil && !errors.Is(err, http.ErrServerClosed) && c.reverseContext.Err() == nil { + if err != nil && !errors.Is(err, http.ErrServerClosed) && ctx.Err() == nil { return sessionLifetime, E.Cause(err, "serve") } return sessionLifetime, E.New("connection closed") diff --git a/service/ocm/service.go b/service/ocm/service.go index 245f2a444..3868725ff 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -876,8 +876,7 @@ func (s *Service) InterfaceUpdated() { } if extCred.reverse && extCred.connectorURL != nil { extCred.reverseService = s - extCred.reverseCancel() - extCred.reverseContext, extCred.reverseCancel = context.WithCancel(context.Background()) + extCred.resetReverseContext() go extCred.connectorLoop() } }