mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-13 20:28:32 +10:00
ccm,ocm: fix data race on reverseContext/reverseCancel
InterfaceUpdated() writes reverseContext and reverseCancel without synchronization while connectorLoop/connectorConnect goroutines read them concurrently. close() also accesses reverseCancel without a lock. Fix by extending reverseAccess mutex to protect these fields: - Add getReverseContext()/resetReverseContext() methods - Pass context as parameter to connectorConnect - Merge close() into a single lock acquisition - Use resetReverseContext() in InterfaceUpdated()
This commit is contained in:
@@ -542,10 +542,10 @@ func (c *externalCredential) httpTransport() *http.Client {
|
||||
}
|
||||
|
||||
func (c *externalCredential) close() {
|
||||
c.reverseAccess.Lock()
|
||||
if c.reverseCancel != nil {
|
||||
c.reverseCancel()
|
||||
}
|
||||
c.reverseAccess.Lock()
|
||||
session := c.reverseSession
|
||||
c.reverseSession = nil
|
||||
c.reverseAccess.Unlock()
|
||||
@@ -584,3 +584,16 @@ func (c *externalCredential) clearReverseSession(session *yamux.Session) {
|
||||
}
|
||||
c.reverseAccess.Unlock()
|
||||
}
|
||||
|
||||
func (c *externalCredential) getReverseContext() context.Context {
|
||||
c.reverseAccess.RLock()
|
||||
defer c.reverseAccess.RUnlock()
|
||||
return c.reverseContext
|
||||
}
|
||||
|
||||
func (c *externalCredential) resetReverseContext() {
|
||||
c.reverseAccess.Lock()
|
||||
c.reverseCancel()
|
||||
c.reverseContext, c.reverseCancel = context.WithCancel(context.Background())
|
||||
c.reverseAccess.Unlock()
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package ccm
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
stdTLS "crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
@@ -125,15 +126,16 @@ func (s *Service) findReceiverCredential(token string) *externalCredential {
|
||||
|
||||
func (c *externalCredential) connectorLoop() {
|
||||
var consecutiveFailures int
|
||||
ctx := c.getReverseContext()
|
||||
for {
|
||||
select {
|
||||
case <-c.reverseContext.Done():
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
sessionLifetime, err := c.connectorConnect()
|
||||
if c.reverseContext.Err() != nil {
|
||||
sessionLifetime, err := c.connectorConnect(ctx)
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if sessionLifetime >= connectorBackoffResetThreshold {
|
||||
@@ -144,7 +146,7 @@ func (c *externalCredential) connectorLoop() {
|
||||
c.logger.Warn("reverse connection for ", c.tag, " lost: ", err, ", reconnecting in ", backoff)
|
||||
select {
|
||||
case <-time.After(backoff):
|
||||
case <-c.reverseContext.Done():
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -164,19 +166,19 @@ func connectorBackoff(failures int) time.Duration {
|
||||
return base + jitter
|
||||
}
|
||||
|
||||
func (c *externalCredential) connectorConnect() (time.Duration, error) {
|
||||
func (c *externalCredential) connectorConnect(ctx context.Context) (time.Duration, error) {
|
||||
if c.reverseService == nil {
|
||||
return 0, E.New("reverse service not initialized")
|
||||
}
|
||||
destination := c.connectorResolveDestination()
|
||||
conn, err := c.connectorDialer.DialContext(c.reverseContext, "tcp", destination)
|
||||
conn, err := c.connectorDialer.DialContext(ctx, "tcp", destination)
|
||||
if err != nil {
|
||||
return 0, E.Cause(err, "dial")
|
||||
}
|
||||
|
||||
if c.connectorTLS != nil {
|
||||
tlsConn := stdTLS.Client(conn, c.connectorTLS.Clone())
|
||||
err = tlsConn.HandshakeContext(c.reverseContext)
|
||||
err = tlsConn.HandshakeContext(ctx)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return 0, E.Cause(err, "tls handshake")
|
||||
@@ -234,7 +236,7 @@ func (c *externalCredential) connectorConnect() (time.Duration, error) {
|
||||
}
|
||||
err = httpServer.Serve(&yamuxNetListener{session: session})
|
||||
sessionLifetime := time.Since(serveStart)
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) && c.reverseContext.Err() == nil {
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) && ctx.Err() == nil {
|
||||
return sessionLifetime, E.Cause(err, "serve")
|
||||
}
|
||||
return sessionLifetime, E.New("connection closed")
|
||||
|
||||
@@ -802,8 +802,7 @@ func (s *Service) InterfaceUpdated() {
|
||||
}
|
||||
if extCred.reverse && extCred.connectorURL != nil {
|
||||
extCred.reverseService = s
|
||||
extCred.reverseCancel()
|
||||
extCred.reverseContext, extCred.reverseCancel = context.WithCancel(context.Background())
|
||||
extCred.resetReverseContext()
|
||||
go extCred.connectorLoop()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -594,10 +594,10 @@ func (c *externalCredential) ocmGetBaseURL() string {
|
||||
}
|
||||
|
||||
func (c *externalCredential) close() {
|
||||
c.reverseAccess.Lock()
|
||||
if c.reverseCancel != nil {
|
||||
c.reverseCancel()
|
||||
}
|
||||
c.reverseAccess.Lock()
|
||||
session := c.reverseSession
|
||||
c.reverseSession = nil
|
||||
c.reverseAccess.Unlock()
|
||||
@@ -636,3 +636,16 @@ func (c *externalCredential) clearReverseSession(session *yamux.Session) {
|
||||
}
|
||||
c.reverseAccess.Unlock()
|
||||
}
|
||||
|
||||
func (c *externalCredential) getReverseContext() context.Context {
|
||||
c.reverseAccess.RLock()
|
||||
defer c.reverseAccess.RUnlock()
|
||||
return c.reverseContext
|
||||
}
|
||||
|
||||
func (c *externalCredential) resetReverseContext() {
|
||||
c.reverseAccess.Lock()
|
||||
c.reverseCancel()
|
||||
c.reverseContext, c.reverseCancel = context.WithCancel(context.Background())
|
||||
c.reverseAccess.Unlock()
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package ocm
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
stdTLS "crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
@@ -125,15 +126,16 @@ func (s *Service) findReceiverCredential(token string) *externalCredential {
|
||||
|
||||
func (c *externalCredential) connectorLoop() {
|
||||
var consecutiveFailures int
|
||||
ctx := c.getReverseContext()
|
||||
for {
|
||||
select {
|
||||
case <-c.reverseContext.Done():
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
sessionLifetime, err := c.connectorConnect()
|
||||
if c.reverseContext.Err() != nil {
|
||||
sessionLifetime, err := c.connectorConnect(ctx)
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if sessionLifetime >= connectorBackoffResetThreshold {
|
||||
@@ -144,7 +146,7 @@ func (c *externalCredential) connectorLoop() {
|
||||
c.logger.Warn("reverse connection for ", c.tag, " lost: ", err, ", reconnecting in ", backoff)
|
||||
select {
|
||||
case <-time.After(backoff):
|
||||
case <-c.reverseContext.Done():
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -164,19 +166,19 @@ func connectorBackoff(failures int) time.Duration {
|
||||
return base + jitter
|
||||
}
|
||||
|
||||
func (c *externalCredential) connectorConnect() (time.Duration, error) {
|
||||
func (c *externalCredential) connectorConnect(ctx context.Context) (time.Duration, error) {
|
||||
if c.reverseService == nil {
|
||||
return 0, E.New("reverse service not initialized")
|
||||
}
|
||||
destination := c.connectorResolveDestination()
|
||||
conn, err := c.connectorDialer.DialContext(c.reverseContext, "tcp", destination)
|
||||
conn, err := c.connectorDialer.DialContext(ctx, "tcp", destination)
|
||||
if err != nil {
|
||||
return 0, E.Cause(err, "dial")
|
||||
}
|
||||
|
||||
if c.connectorTLS != nil {
|
||||
tlsConn := stdTLS.Client(conn, c.connectorTLS.Clone())
|
||||
err = tlsConn.HandshakeContext(c.reverseContext)
|
||||
err = tlsConn.HandshakeContext(ctx)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return 0, E.Cause(err, "tls handshake")
|
||||
@@ -234,7 +236,7 @@ func (c *externalCredential) connectorConnect() (time.Duration, error) {
|
||||
}
|
||||
err = httpServer.Serve(&yamuxNetListener{session: session})
|
||||
sessionLifetime := time.Since(serveStart)
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) && c.reverseContext.Err() == nil {
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) && ctx.Err() == nil {
|
||||
return sessionLifetime, E.Cause(err, "serve")
|
||||
}
|
||||
return sessionLifetime, E.New("connection closed")
|
||||
|
||||
@@ -876,8 +876,7 @@ func (s *Service) InterfaceUpdated() {
|
||||
}
|
||||
if extCred.reverse && extCred.connectorURL != nil {
|
||||
extCred.reverseService = s
|
||||
extCred.reverseCancel()
|
||||
extCred.reverseContext, extCred.reverseCancel = context.WithCancel(context.Background())
|
||||
extCred.resetReverseContext()
|
||||
go extCred.connectorLoop()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user