Fix websocket connection and goroutine leaks in Clash API

Co-authored-by: traitman <112139837+traitman@users.noreply.github.com>
This commit is contained in:
世界
2026-03-09 15:38:25 +08:00
parent e1477bd065
commit e21a72fcd1
3 changed files with 15 additions and 5 deletions

View File

@@ -2,6 +2,7 @@ package clashapi
import (
"bytes"
"context"
"net"
"net/http"
"runtime/debug"
@@ -27,7 +28,7 @@ func (s *Server) setupMetaAPI(r chi.Router) {
})
r.Mount("/", middleware.Profiler())
}
r.Get("/memory", memory(s.trafficManager))
r.Get("/memory", memory(s.ctx, s.trafficManager))
r.Mount("/group", groupRouter(s))
r.Mount("/upgrade", upgradeRouter(s))
}
@@ -37,7 +38,7 @@ type Memory struct {
OSLimit uint64 `json:"oslimit"` // maybe we need it in the future
}
func memory(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r *http.Request) {
func memory(ctx context.Context, trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var conn net.Conn
if r.Header.Get("Upgrade") == "websocket" {
@@ -46,6 +47,7 @@ func memory(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r
if err != nil {
return
}
defer conn.Close()
}
if conn == nil {
@@ -58,7 +60,12 @@ func memory(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r
buf := &bytes.Buffer{}
var err error
first := true
for range tick.C {
for {
select {
case <-ctx.Done():
return
case <-tick.C:
}
buf.Reset()
inuse := trafficManager.Snapshot().Memory

View File

@@ -38,6 +38,7 @@ func getConnections(ctx context.Context, trafficManager *trafficontrol.Manager)
if err != nil {
return
}
defer conn.Close()
intervalStr := r.URL.Query().Get("interval")
interval := 1000

View File

@@ -115,7 +115,7 @@ func NewServer(ctx context.Context, logFactory log.ObservableFactory, options op
chiRouter.Group(func(r chi.Router) {
r.Use(authentication(options.Secret))
r.Get("/", hello(options.ExternalUI != ""))
r.Get("/logs", getLogs(logFactory))
r.Get("/logs", getLogs(s.ctx, logFactory))
r.Get("/traffic", traffic(s.ctx, trafficManager))
r.Get("/version", version)
r.Mount("/configs", configRouter(s, logFactory))
@@ -360,7 +360,7 @@ type Log struct {
Payload string `json:"payload"`
}
func getLogs(logFactory log.ObservableFactory) func(w http.ResponseWriter, r *http.Request) {
func getLogs(ctx context.Context, logFactory log.ObservableFactory) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
levelText := r.URL.Query().Get("level")
if levelText == "" {
@@ -399,6 +399,8 @@ func getLogs(logFactory log.ObservableFactory) func(w http.ResponseWriter, r *ht
var logEntry log.Entry
for {
select {
case <-ctx.Done():
return
case <-done:
return
case logEntry = <-subscription: