Files
sing-box/protocol/naive/inbound.go
2026-03-01 12:30:43 +08:00

249 lines
7.4 KiB
Go

package naive
import (
"context"
"errors"
"io"
"net"
"net/http"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/common/listener"
"github.com/sagernet/sing-box/common/tls"
"github.com/sagernet/sing-box/common/uot"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/v2rayhttp"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/auth"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
aTLS "github.com/sagernet/sing/common/tls"
sHttp "github.com/sagernet/sing/protocol/http"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
)
var ConfigureHTTP3ListenerFunc func(ctx context.Context, logger logger.Logger, listener *listener.Listener, handler http.Handler, tlsConfig tls.ServerConfig, options option.NaiveInboundOptions) (io.Closer, error)
func RegisterInbound(registry *inbound.Registry) {
inbound.Register[option.NaiveInboundOptions](registry, C.TypeNaive, NewInbound)
}
type Inbound struct {
inbound.Adapter
ctx context.Context
router adapter.ConnectionRouterEx
logger logger.ContextLogger
options option.NaiveInboundOptions
listener *listener.Listener
network []string
networkIsDefault bool
authenticator *auth.Authenticator
tlsConfig tls.ServerConfig
httpServer *http.Server
h3Server io.Closer
}
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.NaiveInboundOptions) (adapter.Inbound, error) {
inbound := &Inbound{
Adapter: inbound.NewAdapter(C.TypeNaive, tag),
ctx: ctx,
router: uot.NewRouter(router, logger),
logger: logger,
listener: listener.New(listener.Options{
Context: ctx,
Logger: logger,
Listen: options.ListenOptions,
}),
networkIsDefault: options.Network == "",
network: options.Network.Build(),
authenticator: auth.NewAuthenticator(options.Users),
}
if common.Contains(inbound.network, N.NetworkUDP) {
if options.TLS == nil || !options.TLS.Enabled {
return nil, E.New("TLS is required for QUIC server")
}
}
if len(options.Users) == 0 {
return nil, E.New("missing users")
}
if options.TLS != nil {
tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
if err != nil {
return nil, err
}
inbound.tlsConfig = tlsConfig
}
return inbound, nil
}
func (n *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if n.tlsConfig != nil {
err := n.tlsConfig.Start()
if err != nil {
return E.Cause(err, "create TLS config")
}
}
if common.Contains(n.network, N.NetworkTCP) {
tcpListener, err := n.listener.ListenTCP()
if err != nil {
return err
}
n.httpServer = &http.Server{
Handler: h2c.NewHandler(n, &http2.Server{}),
BaseContext: func(listener net.Listener) context.Context {
return n.ctx
},
}
go func() {
listener := net.Listener(tcpListener)
if n.tlsConfig != nil {
if len(n.tlsConfig.NextProtos()) == 0 {
n.tlsConfig.SetNextProtos([]string{http2.NextProtoTLS, "http/1.1"})
} else if !common.Contains(n.tlsConfig.NextProtos(), http2.NextProtoTLS) {
n.tlsConfig.SetNextProtos(append([]string{http2.NextProtoTLS}, n.tlsConfig.NextProtos()...))
}
listener = aTLS.NewListener(tcpListener, n.tlsConfig)
}
sErr := n.httpServer.Serve(listener)
if sErr != nil && !errors.Is(sErr, http.ErrServerClosed) {
n.logger.Error("http server serve error: ", sErr)
}
}()
}
if common.Contains(n.network, N.NetworkUDP) {
http3Server, err := ConfigureHTTP3ListenerFunc(n.ctx, n.logger, n.listener, n, n.tlsConfig, n.options)
if err == nil {
n.h3Server = http3Server
} else if len(n.network) > 1 {
n.logger.Warn(E.Cause(err, "naive http3 disabled"))
} else {
return err
}
}
return nil
}
func (n *Inbound) Close() error {
return common.Close(
&n.listener,
common.PtrOrNil(n.httpServer),
n.h3Server,
n.tlsConfig,
)
}
func (n *Inbound) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
ctx := log.ContextWithNewID(request.Context())
if request.Method != "CONNECT" {
rejectHTTP(writer, http.StatusBadRequest)
n.badRequest(ctx, request, E.New("not CONNECT request"))
return
} else if request.Header.Get("Padding") == "" {
rejectHTTP(writer, http.StatusBadRequest)
n.badRequest(ctx, request, E.New("missing naive padding"))
return
}
userName, password, authOk := sHttp.ParseBasicAuth(request.Header.Get("Proxy-Authorization"))
if authOk {
authOk = n.authenticator.Verify(userName, password)
}
if !authOk {
rejectHTTP(writer, http.StatusProxyAuthRequired)
n.badRequest(ctx, request, E.New("authorization failed"))
return
}
writer.Header().Set("Padding", generatePaddingHeader())
writer.WriteHeader(http.StatusOK)
writer.(http.Flusher).Flush()
hostPort := request.Header.Get("-connect-authority")
if hostPort == "" {
hostPort = request.URL.Host
if hostPort == "" {
hostPort = request.Host
}
}
source := sHttp.SourceAddress(request)
destination := M.ParseSocksaddr(hostPort).Unwrap()
if hijacker, isHijacker := writer.(http.Hijacker); isHijacker {
conn, _, err := hijacker.Hijack()
if err != nil {
n.badRequest(ctx, request, E.New("hijack failed"))
return
}
n.newConnection(ctx, false, &naiveConn{Conn: conn}, userName, source, destination)
} else {
n.newConnection(ctx, true, &naiveH2Conn{
reader: request.Body,
writer: writer,
flusher: writer.(http.Flusher),
remoteAddress: source,
}, userName, source, destination)
}
}
func (n *Inbound) newConnection(ctx context.Context, waitForClose bool, conn net.Conn, userName string, source M.Socksaddr, destination M.Socksaddr) {
if userName != "" {
n.logger.InfoContext(ctx, "[", userName, "] inbound connection from ", source)
n.logger.InfoContext(ctx, "[", userName, "] inbound connection to ", destination)
} else {
n.logger.InfoContext(ctx, "inbound connection from ", source)
n.logger.InfoContext(ctx, "inbound connection to ", destination)
}
var metadata adapter.InboundContext
metadata.Inbound = n.Tag()
metadata.InboundType = n.Type()
//nolint:staticcheck
metadata.InboundDetour = n.listener.ListenOptions().Detour
//nolint:staticcheck
metadata.Source = source
metadata.Destination = destination
metadata.OriginDestination = M.SocksaddrFromNet(conn.LocalAddr()).Unwrap()
metadata.User = userName
if !waitForClose {
n.router.RouteConnectionEx(ctx, conn, metadata, nil)
} else {
done := make(chan struct{})
wrapper := v2rayhttp.NewHTTP2Wrapper(conn)
n.router.RouteConnectionEx(ctx, conn, metadata, N.OnceClose(func(it error) {
close(done)
}))
<-done
wrapper.CloseWrapper()
}
}
func (n *Inbound) badRequest(ctx context.Context, request *http.Request, err error) {
n.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", request.RemoteAddr))
}
func rejectHTTP(writer http.ResponseWriter, statusCode int) {
hijacker, ok := writer.(http.Hijacker)
if !ok {
writer.WriteHeader(statusCode)
return
}
conn, _, err := hijacker.Hijack()
if err != nil {
writer.WriteHeader(statusCode)
return
}
if tcpConn, isTCP := common.Cast[*net.TCPConn](conn); isTCP {
tcpConn.SetLinger(0)
}
conn.Close()
}