Migrate to gobwas/ws

This commit is contained in:
世界
2023-10-30 12:00:00 +08:00
parent 56b7f554fd
commit 69645b45bf
11 changed files with 176 additions and 157 deletions

View File

@@ -2,12 +2,14 @@ package clashapi
import (
"bytes"
"net"
"net/http"
"time"
"github.com/sagernet/sing-box/common/json"
"github.com/sagernet/sing-box/experimental/clashapi/trafficontrol"
"github.com/sagernet/websocket"
"github.com/sagernet/ws"
"github.com/sagernet/ws/wsutil"
"github.com/go-chi/chi/v5"
"github.com/go-chi/render"
@@ -27,16 +29,16 @@ type Memory struct {
func memory(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var wsConn *websocket.Conn
if websocket.IsWebSocketUpgrade(r) {
var conn net.Conn
if r.Header.Get("Upgrade") == "websocket" {
var err error
wsConn, err = upgrader.Upgrade(w, r, nil)
conn, _, _, err = ws.UpgradeHTTP(r, w)
if err != nil {
return
}
}
if wsConn == nil {
if conn == nil {
w.Header().Set("Content-Type", "application/json")
render.Status(r, http.StatusOK)
}
@@ -63,13 +65,12 @@ func memory(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r
}); err != nil {
break
}
if wsConn == nil {
if conn == nil {
_, err = w.Write(buf.Bytes())
w.(http.Flusher).Flush()
} else {
err = wsConn.WriteMessage(websocket.TextMessage, buf.Bytes())
err = wsutil.WriteServerText(conn, buf.Bytes())
}
if err != nil {
break
}

View File

@@ -9,7 +9,8 @@ import (
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/json"
"github.com/sagernet/sing-box/experimental/clashapi/trafficontrol"
"github.com/sagernet/websocket"
"github.com/sagernet/ws"
"github.com/sagernet/ws/wsutil"
"github.com/go-chi/chi/v5"
"github.com/go-chi/render"
@@ -25,13 +26,13 @@ func connectionRouter(router adapter.Router, trafficManager *trafficontrol.Manag
func getConnections(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if !websocket.IsWebSocketUpgrade(r) {
if r.Header.Get("Upgrade") != "websocket" {
snapshot := trafficManager.Snapshot()
render.JSON(w, r, snapshot)
return
}
conn, err := upgrader.Upgrade(w, r, nil)
conn, _, _, err := ws.UpgradeHTTP(r, w)
if err != nil {
return
}
@@ -56,7 +57,7 @@ func getConnections(trafficManager *trafficontrol.Manager) func(w http.ResponseW
if err := json.NewEncoder(buf).Encode(snapshot); err != nil {
return err
}
return conn.WriteMessage(websocket.TextMessage, buf.Bytes())
return wsutil.WriteServerText(conn, buf.Bytes())
}
if err = sendSnapshot(); err != nil {

View File

@@ -25,7 +25,8 @@ import (
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
"github.com/sagernet/sing/service/filemanager"
"github.com/sagernet/websocket"
"github.com/sagernet/ws"
"github.com/sagernet/ws/wsutil"
"github.com/go-chi/chi/v5"
"github.com/go-chi/cors"
@@ -314,7 +315,7 @@ func authentication(serverSecret string) func(next http.Handler) http.Handler {
}
// Browser websocket not support custom header
if websocket.IsWebSocketUpgrade(r) && r.URL.Query().Get("token") != "" {
if r.Header.Get("Upgrade") == "websocket" && r.URL.Query().Get("token") != "" {
token := r.URL.Query().Get("token")
if token != serverSecret {
render.Status(r, http.StatusUnauthorized)
@@ -351,12 +352,6 @@ func hello(redirect bool) func(w http.ResponseWriter, r *http.Request) {
}
}
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
},
}
type Traffic struct {
Up int64 `json:"up"`
Down int64 `json:"down"`
@@ -364,16 +359,17 @@ type Traffic struct {
func traffic(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var wsConn *websocket.Conn
if websocket.IsWebSocketUpgrade(r) {
var conn net.Conn
if r.Header.Get("Upgrade") == "websocket" {
var err error
wsConn, err = upgrader.Upgrade(w, r, nil)
conn, _, _, err = ws.UpgradeHTTP(r, w)
if err != nil {
return
}
defer conn.Close()
}
if wsConn == nil {
if conn == nil {
w.Header().Set("Content-Type", "application/json")
render.Status(r, http.StatusOK)
}
@@ -392,11 +388,11 @@ func traffic(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter,
break
}
if wsConn == nil {
if conn == nil {
_, err = w.Write(buf.Bytes())
w.(http.Flusher).Flush()
} else {
err = wsConn.WriteMessage(websocket.TextMessage, buf.Bytes())
err = wsutil.WriteServerText(conn, buf.Bytes())
}
if err != nil {
@@ -432,16 +428,16 @@ func getLogs(logFactory log.ObservableFactory) func(w http.ResponseWriter, r *ht
}
defer logFactory.UnSubscribe(subscription)
var wsConn *websocket.Conn
if websocket.IsWebSocketUpgrade(r) {
var err error
wsConn, err = upgrader.Upgrade(w, r, nil)
var conn net.Conn
if r.Header.Get("Upgrade") == "websocket" {
conn, _, _, err = ws.UpgradeHTTP(r, w)
if err != nil {
return
}
defer conn.Close()
}
if wsConn == nil {
if conn == nil {
w.Header().Set("Content-Type", "application/json")
render.Status(r, http.StatusOK)
}
@@ -465,11 +461,11 @@ func getLogs(logFactory log.ObservableFactory) func(w http.ResponseWriter, r *ht
if err != nil {
break
}
if wsConn == nil {
if conn == nil {
_, err = w.Write(buf.Bytes())
w.(http.Flusher).Flush()
} else {
err = wsConn.WriteMessage(websocket.TextMessage, buf.Bytes())
err = wsutil.WriteServerText(conn, buf.Bytes())
}
if err != nil {