diff --git a/protocol/cloudflare/dispatch.go b/protocol/cloudflare/dispatch.go index 761e4a978..29671f0db 100644 --- a/protocol/cloudflare/dispatch.go +++ b/protocol/cloudflare/dispatch.go @@ -220,6 +220,22 @@ func (i *Inbound) handleHTTPService(ctx context.Context, stream io.ReadWriteClos } else { i.handleDirectWebSocketStream(ctx, stream, respWriter, request, metadata, service) } + case ResolvedServiceBastion: + if request.Type != ConnectionTypeWebsocket { + err := E.New("bastion service requires websocket request type") + i.logger.ErrorContext(ctx, err) + respWriter.WriteResponse(err, nil) + return + } + i.handleBastionStream(ctx, stream, respWriter, request, metadata) + case ResolvedServiceSocksProxy: + if request.Type != ConnectionTypeWebsocket { + err := E.New("socks-proxy service requires websocket request type") + i.logger.ErrorContext(ctx, err) + respWriter.WriteResponse(err, nil) + return + } + i.handleSocksProxyStream(ctx, stream, respWriter, request, metadata) default: err := E.New("unsupported service kind for HTTP/WebSocket request") i.logger.ErrorContext(ctx, err) diff --git a/protocol/cloudflare/runtime_config.go b/protocol/cloudflare/runtime_config.go index 99fe73c81..276e99d41 100644 --- a/protocol/cloudflare/runtime_config.go +++ b/protocol/cloudflare/runtime_config.go @@ -39,6 +39,8 @@ const ( ResolvedServiceHelloWorld ResolvedServiceUnix ResolvedServiceUnixTLS + ResolvedServiceBastion + ResolvedServiceSocksProxy ) type ResolvedService struct { @@ -392,6 +394,13 @@ func compileIngressRules(defaultOriginRequest OriginRequestConfig, rawRules []lo func parseResolvedService(rawService string, originRequest OriginRequestConfig) (ResolvedService, error) { switch { case rawService == "": + if originRequest.BastionMode { + return ResolvedService{ + Kind: ResolvedServiceBastion, + Service: "bastion", + OriginRequest: originRequest, + }, nil + } return ResolvedService{}, E.New("missing ingress service") case strings.HasPrefix(rawService, "http_status:"): statusCode, err := strconv.Atoi(strings.TrimPrefix(rawService, "http_status:")) @@ -413,6 +422,18 @@ func parseResolvedService(rawService string, originRequest OriginRequestConfig) Service: rawService, OriginRequest: originRequest, }, nil + case rawService == "bastion": + return ResolvedService{ + Kind: ResolvedServiceBastion, + Service: rawService, + OriginRequest: originRequest, + }, nil + case rawService == "socks-proxy": + return ResolvedService{ + Kind: ResolvedServiceSocksProxy, + Service: rawService, + OriginRequest: originRequest, + }, nil case strings.HasPrefix(rawService, "unix:"): return ResolvedService{ Kind: ResolvedServiceUnix, diff --git a/protocol/cloudflare/special_service.go b/protocol/cloudflare/special_service.go new file mode 100644 index 000000000..61f6cb574 --- /dev/null +++ b/protocol/cloudflare/special_service.go @@ -0,0 +1,233 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "context" + "crypto/sha1" + "encoding/base64" + "io" + "net" + "net/http" + "net/netip" + "net/url" + "strconv" + "strings" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/transport/v2raywebsocket" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/pipe" + "github.com/sagernet/ws" +) + +var wsAcceptGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + +func (i *Inbound) handleBastionStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext) { + destination, err := resolveBastionDestination(request) + if err != nil { + respWriter.WriteResponse(err, nil) + return + } + + targetConn, cleanup, err := i.dialRouterTCP(ctx, M.ParseSocksaddr(destination)) + if err != nil { + respWriter.WriteResponse(err, nil) + return + } + defer cleanup() + + err = respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusSwitchingProtocols, websocketResponseHeaders(request))) + if err != nil { + i.logger.ErrorContext(ctx, "write bastion websocket response: ", err) + return + } + + wsConn := v2raywebsocket.NewConn(newStreamConn(stream), nil, ws.StateServerSide) + defer wsConn.Close() + _ = bufio.CopyConn(ctx, wsConn, targetConn) +} + +func (i *Inbound) handleSocksProxyStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext) { + err := respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusSwitchingProtocols, websocketResponseHeaders(request))) + if err != nil { + i.logger.ErrorContext(ctx, "write socks-proxy websocket response: ", err) + return + } + + wsConn := v2raywebsocket.NewConn(newStreamConn(stream), nil, ws.StateServerSide) + defer wsConn.Close() + if err := i.serveSocksProxy(ctx, wsConn); err != nil && !E.IsClosedOrCanceled(err) { + i.logger.DebugContext(ctx, "socks-proxy stream closed: ", err) + } +} + +func resolveBastionDestination(request *ConnectRequest) (string, error) { + headerValue := requestHeaderValue(request, "Cf-Access-Jump-Destination") + if headerValue == "" { + return "", E.New("missing Cf-Access-Jump-Destination header") + } + if parsed, err := url.Parse(headerValue); err == nil && parsed.Host != "" { + headerValue = parsed.Host + } + return strings.SplitN(headerValue, "/", 2)[0], nil +} + +func websocketResponseHeaders(request *ConnectRequest) http.Header { + header := http.Header{} + header.Set("Connection", "Upgrade") + header.Set("Upgrade", "websocket") + secKey := requestHeaderValue(request, "Sec-WebSocket-Key") + if secKey != "" { + sum := sha1.Sum(append([]byte(secKey), wsAcceptGUID...)) + header.Set("Sec-WebSocket-Accept", base64.StdEncoding.EncodeToString(sum[:])) + } + return header +} + +func requestHeaderValue(request *ConnectRequest, headerName string) string { + for _, entry := range request.Metadata { + if !strings.HasPrefix(entry.Key, metadataHTTPHeader+":") { + continue + } + name := strings.TrimPrefix(entry.Key, metadataHTTPHeader+":") + if strings.EqualFold(name, headerName) { + return entry.Val + } + } + return "" +} + +func (i *Inbound) dialRouterTCP(ctx context.Context, destination M.Socksaddr) (net.Conn, func(), error) { + input, output := pipe.Pipe() + done := make(chan struct{}) + metadata := adapter.InboundContext{ + Inbound: i.Tag(), + InboundType: i.Type(), + Network: N.NetworkTCP, + Destination: destination, + } + go i.router.RouteConnectionEx(ctx, output, metadata, N.OnceClose(func(it error) { + common.Close(input, output) + close(done) + })) + return input, func() { + common.Close(input, output) + select { + case <-done: + case <-time.After(time.Second): + } + }, nil +} + +func (i *Inbound) serveSocksProxy(ctx context.Context, conn net.Conn) error { + version := make([]byte, 1) + if _, err := io.ReadFull(conn, version); err != nil { + return err + } + if version[0] != 5 { + return E.New("unsupported SOCKS version: ", version[0]) + } + + methodCount := make([]byte, 1) + if _, err := io.ReadFull(conn, methodCount); err != nil { + return err + } + methods := make([]byte, int(methodCount[0])) + if _, err := io.ReadFull(conn, methods); err != nil { + return err + } + if _, err := conn.Write([]byte{5, 0}); err != nil { + return err + } + + requestHeader := make([]byte, 4) + if _, err := io.ReadFull(conn, requestHeader); err != nil { + return err + } + if requestHeader[0] != 5 { + return E.New("unsupported SOCKS request version: ", requestHeader[0]) + } + if requestHeader[1] != 1 { + _, _ = conn.Write([]byte{5, 7, 0, 1, 0, 0, 0, 0, 0, 0}) + return E.New("unsupported SOCKS command: ", requestHeader[1]) + } + + destination, err := readSocksDestination(conn, requestHeader[3]) + if err != nil { + return err + } + targetConn, cleanup, err := i.dialRouterTCP(ctx, destination) + if err != nil { + _, _ = conn.Write([]byte{5, 4, 0, 1, 0, 0, 0, 0, 0, 0}) + return err + } + defer cleanup() + + if _, err := conn.Write([]byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}); err != nil { + return err + } + return bufio.CopyConn(ctx, conn, targetConn) +} + +func readSocksDestination(conn net.Conn, addressType byte) (M.Socksaddr, error) { + switch addressType { + case 1: + addr := make([]byte, 4) + if _, err := io.ReadFull(conn, addr); err != nil { + return M.Socksaddr{}, err + } + port, err := readSocksPort(conn) + if err != nil { + return M.Socksaddr{}, err + } + ipAddr, ok := netip.AddrFromSlice(addr) + if !ok { + return M.Socksaddr{}, E.New("invalid IPv4 SOCKS destination") + } + return M.SocksaddrFrom(ipAddr, port), nil + case 3: + length := make([]byte, 1) + if _, err := io.ReadFull(conn, length); err != nil { + return M.Socksaddr{}, err + } + host := make([]byte, int(length[0])) + if _, err := io.ReadFull(conn, host); err != nil { + return M.Socksaddr{}, err + } + port, err := readSocksPort(conn) + if err != nil { + return M.Socksaddr{}, err + } + return M.ParseSocksaddr(net.JoinHostPort(string(host), strconv.Itoa(int(port)))), nil + case 4: + addr := make([]byte, 16) + if _, err := io.ReadFull(conn, addr); err != nil { + return M.Socksaddr{}, err + } + port, err := readSocksPort(conn) + if err != nil { + return M.Socksaddr{}, err + } + ipAddr, ok := netip.AddrFromSlice(addr) + if !ok { + return M.Socksaddr{}, E.New("invalid IPv6 SOCKS destination") + } + return M.SocksaddrFrom(ipAddr, port), nil + default: + return M.Socksaddr{}, E.New("unsupported SOCKS address type: ", addressType) + } +} + +func readSocksPort(conn net.Conn) (uint16, error) { + port := make([]byte, 2) + if _, err := io.ReadFull(conn, port); err != nil { + return 0, err + } + return uint16(port[0])<<8 | uint16(port[1]), nil +} diff --git a/protocol/cloudflare/special_service_test.go b/protocol/cloudflare/special_service_test.go new file mode 100644 index 000000000..5c29d40fa --- /dev/null +++ b/protocol/cloudflare/special_service_test.go @@ -0,0 +1,235 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "context" + "io" + "net" + "net/http" + "strconv" + "testing" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/inbound" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/ws" + "github.com/sagernet/ws/wsutil" +) + +type fakeConnectResponseWriter struct { + status int + headers http.Header + err error + done chan struct{} +} + +func (w *fakeConnectResponseWriter) WriteResponse(responseError error, metadata []Metadata) error { + w.err = responseError + w.headers = make(http.Header) + for _, entry := range metadata { + switch { + case entry.Key == metadataHTTPStatus: + status, _ := strconv.Atoi(entry.Val) + w.status = status + case len(entry.Key) > len(metadataHTTPHeader)+1 && entry.Key[:len(metadataHTTPHeader)+1] == metadataHTTPHeader+":": + w.headers.Add(entry.Key[len(metadataHTTPHeader)+1:], entry.Val) + } + } + if w.done != nil { + close(w.done) + w.done = nil + } + return nil +} + +func newSpecialServiceInbound(t *testing.T) *Inbound { + t.Helper() + logFactory, err := log.New(log.Options{Options: option.LogOptions{Level: "debug"}}) + if err != nil { + t.Fatal(err) + } + configManager, err := NewConfigManager(option.CloudflareTunnelInboundOptions{}) + if err != nil { + t.Fatal(err) + } + return &Inbound{ + Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"), + router: &testRouter{}, + logger: logFactory.NewLogger("test"), + configManager: configManager, + } +} + +func TestHandleBastionStream(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + + go func() { + for { + conn, err := listener.Accept() + if err != nil { + return + } + go func(conn net.Conn) { + defer conn.Close() + _, _ = io.Copy(conn, conn) + }(conn) + } + }() + + serverSide, clientSide := net.Pipe() + defer clientSide.Close() + + inboundInstance := newSpecialServiceInbound(t) + request := &ConnectRequest{ + Type: ConnectionTypeWebsocket, + Metadata: []Metadata{ + {Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="}, + {Key: metadataHTTPHeader + ":Cf-Access-Jump-Destination", Val: listener.Addr().String()}, + }, + } + respWriter := &fakeConnectResponseWriter{done: make(chan struct{})} + + done := make(chan struct{}) + go func() { + defer close(done) + inboundInstance.handleBastionStream(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}) + }() + + select { + case <-respWriter.done: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for bastion connect response") + } + if respWriter.err != nil { + t.Fatal(respWriter.err) + } + if respWriter.status != http.StatusSwitchingProtocols { + t.Fatalf("expected 101 response, got %d", respWriter.status) + } + if respWriter.headers.Get("Sec-WebSocket-Accept") == "" { + t.Fatal("expected websocket accept header") + } + + if err := wsutil.WriteClientMessage(clientSide, ws.OpBinary, []byte("hello")); err != nil { + t.Fatal(err) + } + data, opCode, err := wsutil.ReadServerData(clientSide) + if err != nil { + t.Fatal(err) + } + if opCode != ws.OpBinary { + t.Fatalf("expected binary frame, got %v", opCode) + } + if string(data) != "hello" { + t.Fatalf("expected echoed payload, got %q", string(data)) + } + _ = clientSide.Close() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("bastion stream did not exit") + } +} + +func TestHandleSocksProxyStream(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + + go func() { + for { + conn, err := listener.Accept() + if err != nil { + return + } + go func(conn net.Conn) { + defer conn.Close() + _, _ = io.Copy(conn, conn) + }(conn) + } + }() + + serverSide, clientSide := net.Pipe() + defer clientSide.Close() + + inboundInstance := newSpecialServiceInbound(t) + request := &ConnectRequest{ + Type: ConnectionTypeWebsocket, + Metadata: []Metadata{ + {Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="}, + }, + } + respWriter := &fakeConnectResponseWriter{done: make(chan struct{})} + + done := make(chan struct{}) + go func() { + defer close(done) + inboundInstance.handleSocksProxyStream(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}) + }() + + select { + case <-respWriter.done: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for socks-proxy connect response") + } + if respWriter.err != nil { + t.Fatal(respWriter.err) + } + if respWriter.status != http.StatusSwitchingProtocols { + t.Fatalf("expected 101 response, got %d", respWriter.status) + } + + if err := wsutil.WriteClientMessage(clientSide, ws.OpBinary, []byte{5, 1, 0}); err != nil { + t.Fatal(err) + } + data, _, err := wsutil.ReadServerData(clientSide) + if err != nil { + t.Fatal(err) + } + if string(data) != string([]byte{5, 0}) { + t.Fatalf("unexpected auth response: %v", data) + } + + host, portText, _ := net.SplitHostPort(listener.Addr().String()) + port, _ := strconv.Atoi(portText) + requestBytes := []byte{5, 1, 0, 1} + requestBytes = append(requestBytes, net.ParseIP(host).To4()...) + requestBytes = append(requestBytes, byte(port>>8), byte(port)) + if err := wsutil.WriteClientMessage(clientSide, ws.OpBinary, requestBytes); err != nil { + t.Fatal(err) + } + data, _, err = wsutil.ReadServerData(clientSide) + if err != nil { + t.Fatal(err) + } + if len(data) != 10 || data[1] != 0 { + t.Fatalf("unexpected connect response: %v", data) + } + + if err := wsutil.WriteClientMessage(clientSide, ws.OpBinary, []byte("hello")); err != nil { + t.Fatal(err) + } + data, _, err = wsutil.ReadServerData(clientSide) + if err != nil { + t.Fatal(err) + } + if string(data) != "hello" { + t.Fatalf("expected echoed payload, got %q", string(data)) + } + _ = clientSide.Close() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("socks-proxy stream did not exit") + } +}