From 670f32baee1168ca35c854f8ecbb1d56925b35bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 12 Dec 2025 21:18:59 +0800 Subject: [PATCH] Fix naive inbound --- protocol/naive/inbound.go | 63 ++--- protocol/naive/inbound_conn.go | 469 +++++++++++---------------------- 2 files changed, 176 insertions(+), 356 deletions(-) diff --git a/protocol/naive/inbound.go b/protocol/naive/inbound.go index f6e456f52..6354f011c 100644 --- a/protocol/naive/inbound.go +++ b/protocol/naive/inbound.go @@ -2,8 +2,8 @@ package naive import ( "context" + "errors" "io" - "math/rand" "net" "net/http" @@ -22,7 +22,11 @@ import ( "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" + aTLS "github.com/sagernet/sing/common/tls" sHttp "github.com/sagernet/sing/protocol/http" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" ) var ConfigureHTTP3ListenerFunc func(listener *listener.Listener, handler http.Handler, tlsConfig tls.ServerConfig, logger logger.Logger) (io.Closer, error) @@ -82,16 +86,11 @@ func (n *Inbound) Start(stage adapter.StartStage) error { if stage != adapter.StartStateStart { return nil } - var tlsConfig *tls.STDConfig if n.tlsConfig != nil { err := n.tlsConfig.Start() if err != nil { return E.Cause(err, "create TLS config") } - tlsConfig, err = n.tlsConfig.Config() - if err != nil { - return err - } } if common.Contains(n.network, N.NetworkTCP) { tcpListener, err := n.listener.ListenTCP() @@ -99,20 +98,23 @@ func (n *Inbound) Start(stage adapter.StartStage) error { return err } n.httpServer = &http.Server{ - Handler: n, - TLSConfig: tlsConfig, + Handler: h2c.NewHandler(n, &http2.Server{}), BaseContext: func(listener net.Listener) context.Context { return n.ctx }, } go func() { - var sErr error - if tlsConfig != nil { - sErr = n.httpServer.ServeTLS(tcpListener, "", "") - } else { - sErr = n.httpServer.Serve(tcpListener) + listener := net.Listener(tcpListener) + if n.tlsConfig != nil { + if len(n.tlsConfig.NextProtos()) == 0 { + n.tlsConfig.SetNextProtos([]string{http2.NextProtoTLS, "http/1.1"}) + } else if !common.Contains(n.tlsConfig.NextProtos(), http2.NextProtoTLS) { + n.tlsConfig.SetNextProtos(append([]string{http2.NextProtoTLS}, n.tlsConfig.NextProtos()...)) + } + listener = aTLS.NewListener(tcpListener, n.tlsConfig) } - if sErr != nil && !E.IsClosedOrCanceled(sErr) { + sErr := n.httpServer.Serve(listener) + if sErr != nil && !errors.Is(sErr, http.ErrServerClosed) { n.logger.Error("http server serve error: ", sErr) } }() @@ -161,13 +163,16 @@ func (n *Inbound) ServeHTTP(writer http.ResponseWriter, request *http.Request) { n.badRequest(ctx, request, E.New("authorization failed")) return } - writer.Header().Set("Padding", generateNaivePaddingHeader()) + writer.Header().Set("Padding", generatePaddingHeader()) writer.WriteHeader(http.StatusOK) writer.(http.Flusher).Flush() - hostPort := request.URL.Host + hostPort := request.Header.Get("-connect-authority") if hostPort == "" { - hostPort = request.Host + hostPort = request.URL.Host + if hostPort == "" { + hostPort = request.Host + } } source := sHttp.SourceAddress(request) destination := M.ParseSocksaddr(hostPort).Unwrap() @@ -178,9 +183,14 @@ func (n *Inbound) ServeHTTP(writer http.ResponseWriter, request *http.Request) { n.badRequest(ctx, request, E.New("hijack failed")) return } - n.newConnection(ctx, false, &naiveH1Conn{Conn: conn}, userName, source, destination) + n.newConnection(ctx, false, &naiveConn{Conn: conn}, userName, source, destination) } else { - n.newConnection(ctx, true, &naiveH2Conn{reader: request.Body, writer: writer, flusher: writer.(http.Flusher)}, userName, source, destination) + n.newConnection(ctx, true, &naiveH2Conn{ + reader: request.Body, + writer: writer, + flusher: writer.(http.Flusher), + remoteAddress: source, + }, userName, source, destination) } } @@ -236,18 +246,3 @@ func rejectHTTP(writer http.ResponseWriter, statusCode int) { } conn.Close() } - -func generateNaivePaddingHeader() string { - paddingLen := rand.Intn(32) + 30 - padding := make([]byte, paddingLen) - bits := rand.Uint64() - for i := 0; i < 16; i++ { - // Codes that won't be Huffman coded. - padding[i] = "!#$()+<>?@[]^`{}"[bits&15] - bits >>= 4 - } - for i := 16; i < paddingLen; i++ { - padding[i] = '~' - } - return string(padding) -} diff --git a/protocol/naive/inbound_conn.go b/protocol/naive/inbound_conn.go index 16944cbaf..1dbb2bd82 100644 --- a/protocol/naive/inbound_conn.go +++ b/protocol/naive/inbound_conn.go @@ -7,417 +7,242 @@ import ( "net" "net/http" "os" - "strings" "time" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/baderror" "github.com/sagernet/sing/common/buf" M "github.com/sagernet/sing/common/metadata" "github.com/sagernet/sing/common/rw" ) -const kFirstPaddings = 8 +const paddingCount = 8 -type naiveH1Conn struct { - net.Conn +func generatePaddingHeader() string { + paddingLen := rand.Intn(32) + 30 + padding := make([]byte, paddingLen) + bits := rand.Uint64() + for i := 0; i < 16; i++ { + padding[i] = "!#$()+<>?@[]^`{}"[bits&15] + bits >>= 4 + } + for i := 16; i < paddingLen; i++ { + padding[i] = '~' + } + return string(padding) +} + +type paddingConn struct { readPadding int writePadding int readRemaining int paddingRemaining int } -func (c *naiveH1Conn) Read(p []byte) (n int, err error) { - n, err = c.read(p) - return n, wrapHttpError(err) -} - -func (c *naiveH1Conn) read(p []byte) (n int, err error) { - if c.readRemaining > 0 { - if len(p) > c.readRemaining { - p = p[:c.readRemaining] +func (p *paddingConn) readWithPadding(reader io.Reader, buffer []byte) (n int, err error) { + if p.readRemaining > 0 { + if len(buffer) > p.readRemaining { + buffer = buffer[:p.readRemaining] } - n, err = c.Conn.Read(p) + n, err = reader.Read(buffer) if err != nil { return } - c.readRemaining -= n + p.readRemaining -= n return } - if c.paddingRemaining > 0 { - err = rw.SkipN(c.Conn, c.paddingRemaining) + if p.paddingRemaining > 0 { + err = rw.SkipN(reader, p.paddingRemaining) if err != nil { return } - c.paddingRemaining = 0 + p.paddingRemaining = 0 } - if c.readPadding < kFirstPaddings { - var paddingHdr []byte - if len(p) >= 3 { - paddingHdr = p[:3] + if p.readPadding < paddingCount { + var paddingHeader []byte + if len(buffer) >= 3 { + paddingHeader = buffer[:3] } else { - paddingHdr = make([]byte, 3) + paddingHeader = make([]byte, 3) } - _, err = io.ReadFull(c.Conn, paddingHdr) + _, err = io.ReadFull(reader, paddingHeader) if err != nil { return } - originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2])) - paddingSize := int(paddingHdr[2]) - if len(p) > originalDataSize { - p = p[:originalDataSize] + originalDataSize := int(binary.BigEndian.Uint16(paddingHeader[:2])) + paddingSize := int(paddingHeader[2]) + if len(buffer) > originalDataSize { + buffer = buffer[:originalDataSize] } - n, err = c.Conn.Read(p) + n, err = reader.Read(buffer) if err != nil { return } - c.readPadding++ - c.readRemaining = originalDataSize - n - c.paddingRemaining = paddingSize + p.readPadding++ + p.readRemaining = originalDataSize - n + p.paddingRemaining = paddingSize return } - return c.Conn.Read(p) + return reader.Read(buffer) } -func (c *naiveH1Conn) Write(p []byte) (n int, err error) { - for pLen := len(p); pLen > 0; { - var data []byte - if pLen > 65535 { - data = p[:65535] - p = p[65535:] - pLen -= 65535 - } else { - data = p - pLen = 0 - } - var writeN int - writeN, err = c.write(data) - n += writeN - if err != nil { - break - } - } - return n, wrapHttpError(err) -} - -func (c *naiveH1Conn) write(p []byte) (n int, err error) { - if c.writePadding < kFirstPaddings { +func (p *paddingConn) writeWithPadding(writer io.Writer, data []byte) (n int, err error) { + if p.writePadding < paddingCount { paddingSize := rand.Intn(256) - - buffer := buf.NewSize(3 + len(p) + paddingSize) + buffer := buf.NewSize(3 + len(data) + paddingSize) defer buffer.Release() header := buffer.Extend(3) - binary.BigEndian.PutUint16(header, uint16(len(p))) + binary.BigEndian.PutUint16(header, uint16(len(data))) header[2] = byte(paddingSize) - - common.Must1(buffer.Write(p)) - _, err = c.Conn.Write(buffer.Bytes()) + common.Must1(buffer.Write(data)) + _, err = writer.Write(buffer.Bytes()) if err == nil { - n = len(p) + n = len(data) } - c.writePadding++ + p.writePadding++ return } - return c.Conn.Write(p) + return writer.Write(data) } -func (c *naiveH1Conn) FrontHeadroom() int { - if c.writePadding < kFirstPaddings { - return 3 - } - return 0 -} - -func (c *naiveH1Conn) RearHeadroom() int { - if c.writePadding < kFirstPaddings { - return 255 - } - return 0 -} - -func (c *naiveH1Conn) WriterMTU() int { - if c.writePadding < kFirstPaddings { - return 65535 - } - return 0 -} - -func (c *naiveH1Conn) WriteBuffer(buffer *buf.Buffer) error { - defer buffer.Release() - if c.writePadding < kFirstPaddings { +func (p *paddingConn) writeBufferWithPadding(writer io.Writer, buffer *buf.Buffer) error { + if p.writePadding < paddingCount { bufferLen := buffer.Len() if bufferLen > 65535 { - return common.Error(c.Write(buffer.Bytes())) + _, err := p.writeChunked(writer, buffer.Bytes()) + return err } paddingSize := rand.Intn(256) header := buffer.ExtendHeader(3) binary.BigEndian.PutUint16(header, uint16(bufferLen)) header[2] = byte(paddingSize) buffer.Extend(paddingSize) - c.writePadding++ + p.writePadding++ } - return wrapHttpError(common.Error(c.Conn.Write(buffer.Bytes()))) + return common.Error(writer.Write(buffer.Bytes())) } -// FIXME -/*func (c *naiveH1Conn) WriteTo(w io.Writer) (n int64, err error) { - if c.readPadding < kFirstPaddings { - n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding) - } else { - n, err = bufio.Copy(w, c.Conn) - } - return n, wrapHttpError(err) -} - -func (c *naiveH1Conn) ReadFrom(r io.Reader) (n int64, err error) { - if c.writePadding < kFirstPaddings { - n, err = bufio.ReadFromN(c, r, kFirstPaddings-c.writePadding) - } else { - n, err = bufio.Copy(c.Conn, r) - } - return n, wrapHttpError(err) -} -*/ - -func (c *naiveH1Conn) Upstream() any { - return c.Conn -} - -func (c *naiveH1Conn) ReaderReplaceable() bool { - return c.readPadding == kFirstPaddings -} - -func (c *naiveH1Conn) WriterReplaceable() bool { - return c.writePadding == kFirstPaddings -} - -type naiveH2Conn struct { - reader io.Reader - writer io.Writer - flusher http.Flusher - rAddr net.Addr - readPadding int - writePadding int - readRemaining int - paddingRemaining int -} - -func (c *naiveH2Conn) Read(p []byte) (n int, err error) { - n, err = c.read(p) - return n, wrapHttpError(err) -} - -func (c *naiveH2Conn) read(p []byte) (n int, err error) { - if c.readRemaining > 0 { - if len(p) > c.readRemaining { - p = p[:c.readRemaining] - } - n, err = c.reader.Read(p) - if err != nil { - return - } - c.readRemaining -= n - return - } - if c.paddingRemaining > 0 { - err = rw.SkipN(c.reader, c.paddingRemaining) - if err != nil { - return - } - c.paddingRemaining = 0 - } - if c.readPadding < kFirstPaddings { - var paddingHdr []byte - if len(p) >= 3 { - paddingHdr = p[:3] +func (p *paddingConn) writeChunked(writer io.Writer, data []byte) (n int, err error) { + for len(data) > 0 { + var chunk []byte + if len(data) > 65535 { + chunk = data[:65535] + data = data[65535:] } else { - paddingHdr = make([]byte, 3) + chunk = data + data = nil } - _, err = io.ReadFull(c.reader, paddingHdr) + var written int + written, err = p.writeWithPadding(writer, chunk) + n += written if err != nil { return } - originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2])) - paddingSize := int(paddingHdr[2]) - if len(p) > originalDataSize { - p = p[:originalDataSize] - } - n, err = c.reader.Read(p) - if err != nil { - return - } - c.readPadding++ - c.readRemaining = originalDataSize - n - c.paddingRemaining = paddingSize - return } - return c.reader.Read(p) + return } -func (c *naiveH2Conn) Write(p []byte) (n int, err error) { - for pLen := len(p); pLen > 0; { - var data []byte - if pLen > 65535 { - data = p[:65535] - p = p[65535:] - pLen -= 65535 - } else { - data = p - pLen = 0 - } - var writeN int - writeN, err = c.write(data) - n += writeN - if err != nil { - break - } - } - if err == nil { - c.flusher.Flush() - } - return n, wrapHttpError(err) -} - -func (c *naiveH2Conn) write(p []byte) (n int, err error) { - if c.writePadding < kFirstPaddings { - paddingSize := rand.Intn(256) - - buffer := buf.NewSize(3 + len(p) + paddingSize) - defer buffer.Release() - header := buffer.Extend(3) - binary.BigEndian.PutUint16(header, uint16(len(p))) - header[2] = byte(paddingSize) - - common.Must1(buffer.Write(p)) - _, err = c.writer.Write(buffer.Bytes()) - if err == nil { - n = len(p) - } - c.writePadding++ - return - } - return c.writer.Write(p) -} - -func (c *naiveH2Conn) FrontHeadroom() int { - if c.writePadding < kFirstPaddings { +func (p *paddingConn) frontHeadroom() int { + if p.writePadding < paddingCount { return 3 } return 0 } -func (c *naiveH2Conn) RearHeadroom() int { - if c.writePadding < kFirstPaddings { +func (p *paddingConn) rearHeadroom() int { + if p.writePadding < paddingCount { return 255 } return 0 } -func (c *naiveH2Conn) WriterMTU() int { - if c.writePadding < kFirstPaddings { +func (p *paddingConn) writerMTU() int { + if p.writePadding < paddingCount { return 65535 } return 0 } +func (p *paddingConn) readerReplaceable() bool { + return p.readPadding == paddingCount +} + +func (p *paddingConn) writerReplaceable() bool { + return p.writePadding == paddingCount +} + +type naiveConn struct { + net.Conn + paddingConn +} + +func (c *naiveConn) Read(p []byte) (n int, err error) { + n, err = c.readWithPadding(c.Conn, p) + return n, baderror.WrapH2(err) +} + +func (c *naiveConn) Write(p []byte) (n int, err error) { + n, err = c.writeChunked(c.Conn, p) + return n, baderror.WrapH2(err) +} + +func (c *naiveConn) WriteBuffer(buffer *buf.Buffer) error { + defer buffer.Release() + err := c.writeBufferWithPadding(c.Conn, buffer) + return baderror.WrapH2(err) +} + +func (c *naiveConn) FrontHeadroom() int { return c.frontHeadroom() } +func (c *naiveConn) RearHeadroom() int { return c.rearHeadroom() } +func (c *naiveConn) WriterMTU() int { return c.writerMTU() } +func (c *naiveConn) Upstream() any { return c.Conn } +func (c *naiveConn) ReaderReplaceable() bool { return c.readerReplaceable() } +func (c *naiveConn) WriterReplaceable() bool { return c.writerReplaceable() } + +type naiveH2Conn struct { + reader io.Reader + writer io.Writer + flusher http.Flusher + remoteAddress net.Addr + paddingConn +} + +func (c *naiveH2Conn) Read(p []byte) (n int, err error) { + n, err = c.readWithPadding(c.reader, p) + return n, baderror.WrapH2(err) +} + +func (c *naiveH2Conn) Write(p []byte) (n int, err error) { + n, err = c.writeChunked(c.writer, p) + if err == nil { + c.flusher.Flush() + } + return n, baderror.WrapH2(err) +} + func (c *naiveH2Conn) WriteBuffer(buffer *buf.Buffer) error { defer buffer.Release() - if c.writePadding < kFirstPaddings { - bufferLen := buffer.Len() - if bufferLen > 65535 { - return common.Error(c.Write(buffer.Bytes())) - } - paddingSize := rand.Intn(256) - header := buffer.ExtendHeader(3) - binary.BigEndian.PutUint16(header, uint16(bufferLen)) - header[2] = byte(paddingSize) - buffer.Extend(paddingSize) - c.writePadding++ - } - err := common.Error(c.writer.Write(buffer.Bytes())) + err := c.writeBufferWithPadding(c.writer, buffer) if err == nil { c.flusher.Flush() } - return wrapHttpError(err) + return baderror.WrapH2(err) } -// FIXME -/*func (c *naiveH2Conn) WriteTo(w io.Writer) (n int64, err error) { - if c.readPadding < kFirstPaddings { - n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding) - } else { - n, err = bufio.Copy(w, c.reader) - } - return n, wrapHttpError(err) -} - -func (c *naiveH2Conn) ReadFrom(r io.Reader) (n int64, err error) { - if c.writePadding < kFirstPaddings { - n, err = bufio.ReadFromN(c, r, kFirstPaddings-c.writePadding) - } else { - n, err = bufio.Copy(c.writer, r) - } - return n, wrapHttpError(err) -}*/ - func (c *naiveH2Conn) Close() error { - return common.Close( - c.reader, - c.writer, - ) + return common.Close(c.reader, c.writer) } -func (c *naiveH2Conn) LocalAddr() net.Addr { - return M.Socksaddr{} -} - -func (c *naiveH2Conn) RemoteAddr() net.Addr { - return c.rAddr -} - -func (c *naiveH2Conn) SetDeadline(t time.Time) error { - return os.ErrInvalid -} - -func (c *naiveH2Conn) SetReadDeadline(t time.Time) error { - return os.ErrInvalid -} - -func (c *naiveH2Conn) SetWriteDeadline(t time.Time) error { - return os.ErrInvalid -} - -func (c *naiveH2Conn) NeedAdditionalReadDeadline() bool { - return true -} - -func (c *naiveH2Conn) UpstreamReader() any { - return c.reader -} - -func (c *naiveH2Conn) UpstreamWriter() any { - return c.writer -} - -func (c *naiveH2Conn) ReaderReplaceable() bool { - return c.readPadding == kFirstPaddings -} - -func (c *naiveH2Conn) WriterReplaceable() bool { - return c.writePadding == kFirstPaddings -} - -func wrapHttpError(err error) error { - if err == nil { - return err - } - if strings.Contains(err.Error(), "client disconnected") { - return net.ErrClosed - } - if strings.Contains(err.Error(), "body closed by handler") { - return net.ErrClosed - } - if strings.Contains(err.Error(), "canceled with error code 268") { - return io.EOF - } - return err -} +func (c *naiveH2Conn) LocalAddr() net.Addr { return M.Socksaddr{} } +func (c *naiveH2Conn) RemoteAddr() net.Addr { return c.remoteAddress } +func (c *naiveH2Conn) SetDeadline(t time.Time) error { return os.ErrInvalid } +func (c *naiveH2Conn) SetReadDeadline(t time.Time) error { return os.ErrInvalid } +func (c *naiveH2Conn) SetWriteDeadline(t time.Time) error { return os.ErrInvalid } +func (c *naiveH2Conn) NeedAdditionalReadDeadline() bool { return true } +func (c *naiveH2Conn) UpstreamReader() any { return c.reader } +func (c *naiveH2Conn) UpstreamWriter() any { return c.writer } +func (c *naiveH2Conn) FrontHeadroom() int { return c.frontHeadroom() } +func (c *naiveH2Conn) RearHeadroom() int { return c.rearHeadroom() } +func (c *naiveH2Conn) WriterMTU() int { return c.writerMTU() } +func (c *naiveH2Conn) ReaderReplaceable() bool { return c.readerReplaceable() } +func (c *naiveH2Conn) WriterReplaceable() bool { return c.writerReplaceable() }