Migrate to gobwas/ws
This commit is contained in:
@@ -2,12 +2,14 @@ package clashapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/common/json"
|
||||
"github.com/sagernet/sing-box/experimental/clashapi/trafficontrol"
|
||||
"github.com/sagernet/websocket"
|
||||
"github.com/sagernet/ws"
|
||||
"github.com/sagernet/ws/wsutil"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/render"
|
||||
@@ -27,16 +29,16 @@ type Memory struct {
|
||||
|
||||
func memory(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var wsConn *websocket.Conn
|
||||
if websocket.IsWebSocketUpgrade(r) {
|
||||
var conn net.Conn
|
||||
if r.Header.Get("Upgrade") == "websocket" {
|
||||
var err error
|
||||
wsConn, err = upgrader.Upgrade(w, r, nil)
|
||||
conn, _, _, err = ws.UpgradeHTTP(r, w)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if wsConn == nil {
|
||||
if conn == nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
render.Status(r, http.StatusOK)
|
||||
}
|
||||
@@ -63,13 +65,12 @@ func memory(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r
|
||||
}); err != nil {
|
||||
break
|
||||
}
|
||||
if wsConn == nil {
|
||||
if conn == nil {
|
||||
_, err = w.Write(buf.Bytes())
|
||||
w.(http.Flusher).Flush()
|
||||
} else {
|
||||
err = wsConn.WriteMessage(websocket.TextMessage, buf.Bytes())
|
||||
err = wsutil.WriteServerText(conn, buf.Bytes())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
@@ -9,7 +9,8 @@ import (
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/json"
|
||||
"github.com/sagernet/sing-box/experimental/clashapi/trafficontrol"
|
||||
"github.com/sagernet/websocket"
|
||||
"github.com/sagernet/ws"
|
||||
"github.com/sagernet/ws/wsutil"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/render"
|
||||
@@ -25,13 +26,13 @@ func connectionRouter(router adapter.Router, trafficManager *trafficontrol.Manag
|
||||
|
||||
func getConnections(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if !websocket.IsWebSocketUpgrade(r) {
|
||||
if r.Header.Get("Upgrade") != "websocket" {
|
||||
snapshot := trafficManager.Snapshot()
|
||||
render.JSON(w, r, snapshot)
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
conn, _, _, err := ws.UpgradeHTTP(r, w)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -56,7 +57,7 @@ func getConnections(trafficManager *trafficontrol.Manager) func(w http.ResponseW
|
||||
if err := json.NewEncoder(buf).Encode(snapshot); err != nil {
|
||||
return err
|
||||
}
|
||||
return conn.WriteMessage(websocket.TextMessage, buf.Bytes())
|
||||
return wsutil.WriteServerText(conn, buf.Bytes())
|
||||
}
|
||||
|
||||
if err = sendSnapshot(); err != nil {
|
||||
|
||||
@@ -25,7 +25,8 @@ import (
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/service"
|
||||
"github.com/sagernet/sing/service/filemanager"
|
||||
"github.com/sagernet/websocket"
|
||||
"github.com/sagernet/ws"
|
||||
"github.com/sagernet/ws/wsutil"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/cors"
|
||||
@@ -314,7 +315,7 @@ func authentication(serverSecret string) func(next http.Handler) http.Handler {
|
||||
}
|
||||
|
||||
// Browser websocket not support custom header
|
||||
if websocket.IsWebSocketUpgrade(r) && r.URL.Query().Get("token") != "" {
|
||||
if r.Header.Get("Upgrade") == "websocket" && r.URL.Query().Get("token") != "" {
|
||||
token := r.URL.Query().Get("token")
|
||||
if token != serverSecret {
|
||||
render.Status(r, http.StatusUnauthorized)
|
||||
@@ -351,12 +352,6 @@ func hello(redirect bool) func(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
type Traffic struct {
|
||||
Up int64 `json:"up"`
|
||||
Down int64 `json:"down"`
|
||||
@@ -364,16 +359,17 @@ type Traffic struct {
|
||||
|
||||
func traffic(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var wsConn *websocket.Conn
|
||||
if websocket.IsWebSocketUpgrade(r) {
|
||||
var conn net.Conn
|
||||
if r.Header.Get("Upgrade") == "websocket" {
|
||||
var err error
|
||||
wsConn, err = upgrader.Upgrade(w, r, nil)
|
||||
conn, _, _, err = ws.UpgradeHTTP(r, w)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
}
|
||||
|
||||
if wsConn == nil {
|
||||
if conn == nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
render.Status(r, http.StatusOK)
|
||||
}
|
||||
@@ -392,11 +388,11 @@ func traffic(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter,
|
||||
break
|
||||
}
|
||||
|
||||
if wsConn == nil {
|
||||
if conn == nil {
|
||||
_, err = w.Write(buf.Bytes())
|
||||
w.(http.Flusher).Flush()
|
||||
} else {
|
||||
err = wsConn.WriteMessage(websocket.TextMessage, buf.Bytes())
|
||||
err = wsutil.WriteServerText(conn, buf.Bytes())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -432,16 +428,16 @@ func getLogs(logFactory log.ObservableFactory) func(w http.ResponseWriter, r *ht
|
||||
}
|
||||
defer logFactory.UnSubscribe(subscription)
|
||||
|
||||
var wsConn *websocket.Conn
|
||||
if websocket.IsWebSocketUpgrade(r) {
|
||||
var err error
|
||||
wsConn, err = upgrader.Upgrade(w, r, nil)
|
||||
var conn net.Conn
|
||||
if r.Header.Get("Upgrade") == "websocket" {
|
||||
conn, _, _, err = ws.UpgradeHTTP(r, w)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
}
|
||||
|
||||
if wsConn == nil {
|
||||
if conn == nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
render.Status(r, http.StatusOK)
|
||||
}
|
||||
@@ -465,11 +461,11 @@ func getLogs(logFactory log.ObservableFactory) func(w http.ResponseWriter, r *ht
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if wsConn == nil {
|
||||
if conn == nil {
|
||||
_, err = w.Write(buf.Bytes())
|
||||
w.(http.Flusher).Flush()
|
||||
} else {
|
||||
err = wsConn.WriteMessage(websocket.TextMessage, buf.Bytes())
|
||||
err = wsutil.WriteServerText(conn, buf.Bytes())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user