Add custom tls client support for std grpc
This commit is contained in:
@@ -24,11 +24,11 @@ func NewDialerFromOptions(router adapter.Router, dialer N.Dialer, serverAddress
|
||||
|
||||
func NewClient(router adapter.Router, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||
if options.ECH != nil && options.ECH.Enabled {
|
||||
return newECHClient(router, serverAddress, options)
|
||||
return NewECHClient(router, serverAddress, options)
|
||||
} else if options.UTLS != nil && options.UTLS.Enabled {
|
||||
return newUTLSClient(router, serverAddress, options)
|
||||
return NewUTLSClient(router, serverAddress, options)
|
||||
} else {
|
||||
return newStdClient(serverAddress, options)
|
||||
return NewSTDClient(serverAddress, options)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -15,10 +15,13 @@ type (
|
||||
)
|
||||
|
||||
type Config interface {
|
||||
ServerName() string
|
||||
SetServerName(serverName string)
|
||||
NextProtos() []string
|
||||
SetNextProtos(nextProto []string)
|
||||
Config() (*STDConfig, error)
|
||||
Client(conn net.Conn) Conn
|
||||
Clone() Config
|
||||
}
|
||||
|
||||
type ServerConfig interface {
|
||||
|
||||
@@ -20,26 +20,40 @@ import (
|
||||
mDNS "github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type echClientConfig struct {
|
||||
type ECHClientConfig struct {
|
||||
config *cftls.Config
|
||||
}
|
||||
|
||||
func (e *echClientConfig) NextProtos() []string {
|
||||
func (e *ECHClientConfig) ServerName() string {
|
||||
return e.config.ServerName
|
||||
}
|
||||
|
||||
func (e *ECHClientConfig) SetServerName(serverName string) {
|
||||
e.config.ServerName = serverName
|
||||
}
|
||||
|
||||
func (e *ECHClientConfig) NextProtos() []string {
|
||||
return e.config.NextProtos
|
||||
}
|
||||
|
||||
func (e *echClientConfig) SetNextProtos(nextProto []string) {
|
||||
func (e *ECHClientConfig) SetNextProtos(nextProto []string) {
|
||||
e.config.NextProtos = nextProto
|
||||
}
|
||||
|
||||
func (e *echClientConfig) Config() (*STDConfig, error) {
|
||||
func (e *ECHClientConfig) Config() (*STDConfig, error) {
|
||||
return nil, E.New("unsupported usage for ECH")
|
||||
}
|
||||
|
||||
func (e *echClientConfig) Client(conn net.Conn) Conn {
|
||||
func (e *ECHClientConfig) Client(conn net.Conn) Conn {
|
||||
return &echConnWrapper{cftls.Client(conn, e.config)}
|
||||
}
|
||||
|
||||
func (e *ECHClientConfig) Clone() Config {
|
||||
return &ECHClientConfig{
|
||||
config: e.config.Clone(),
|
||||
}
|
||||
}
|
||||
|
||||
type echConnWrapper struct {
|
||||
*cftls.Conn
|
||||
}
|
||||
@@ -62,7 +76,7 @@ func (c *echConnWrapper) ConnectionState() tls.ConnectionState {
|
||||
}
|
||||
}
|
||||
|
||||
func newECHClient(router adapter.Router, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||
func NewECHClient(router adapter.Router, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||
var serverName string
|
||||
if options.ServerName != "" {
|
||||
serverName = options.ServerName
|
||||
@@ -162,7 +176,7 @@ func newECHClient(router adapter.Router, serverAddress string, options option.Ou
|
||||
} else {
|
||||
tlsConfig.GetClientECHConfigs = fetchECHClientConfig(router)
|
||||
}
|
||||
return &echClientConfig{&tlsConfig}, nil
|
||||
return &ECHClientConfig{&tlsConfig}, nil
|
||||
}
|
||||
|
||||
func fetchECHClientConfig(router adapter.Router) func(ctx context.Context, serverName string) ([]cftls.ECHConfig, error) {
|
||||
|
||||
@@ -8,6 +8,6 @@ import (
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
func newECHClient(router adapter.Router, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||
func NewECHClient(router adapter.Router, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||
return nil, E.New(`ECH is not included in this build, rebuild with -tags with_ech`)
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
)
|
||||
|
||||
func NewServer(ctx context.Context, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) {
|
||||
return newSTDServer(ctx, logger, options)
|
||||
return NewSTDServer(ctx, logger, options)
|
||||
}
|
||||
|
||||
func ServerHandshake(ctx context.Context, conn net.Conn, config ServerConfig) (Conn, error) {
|
||||
|
||||
@@ -11,11 +11,39 @@ import (
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
type stdClientConfig struct {
|
||||
type STDClientConfig struct {
|
||||
config *tls.Config
|
||||
}
|
||||
|
||||
func newStdClient(serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||
func (s *STDClientConfig) ServerName() string {
|
||||
return s.config.ServerName
|
||||
}
|
||||
|
||||
func (s *STDClientConfig) SetServerName(serverName string) {
|
||||
s.config.ServerName = serverName
|
||||
}
|
||||
|
||||
func (s *STDClientConfig) NextProtos() []string {
|
||||
return s.config.NextProtos
|
||||
}
|
||||
|
||||
func (s *STDClientConfig) SetNextProtos(nextProto []string) {
|
||||
s.config.NextProtos = nextProto
|
||||
}
|
||||
|
||||
func (s *STDClientConfig) Config() (*STDConfig, error) {
|
||||
return s.config, nil
|
||||
}
|
||||
|
||||
func (s *STDClientConfig) Client(conn net.Conn) Conn {
|
||||
return tls.Client(conn, s.config)
|
||||
}
|
||||
|
||||
func (s *STDClientConfig) Clone() Config {
|
||||
return &STDClientConfig{s.config.Clone()}
|
||||
}
|
||||
|
||||
func NewSTDClient(serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||
var serverName string
|
||||
if options.ServerName != "" {
|
||||
serverName = options.ServerName
|
||||
@@ -96,21 +124,5 @@ func newStdClient(serverAddress string, options option.OutboundTLSOptions) (Conf
|
||||
}
|
||||
tlsConfig.RootCAs = certPool
|
||||
}
|
||||
return &stdClientConfig{&tlsConfig}, nil
|
||||
}
|
||||
|
||||
func (s *stdClientConfig) NextProtos() []string {
|
||||
return s.config.NextProtos
|
||||
}
|
||||
|
||||
func (s *stdClientConfig) SetNextProtos(nextProto []string) {
|
||||
s.config.NextProtos = nextProto
|
||||
}
|
||||
|
||||
func (s *stdClientConfig) Config() (*STDConfig, error) {
|
||||
return s.config, nil
|
||||
}
|
||||
|
||||
func (s *stdClientConfig) Client(conn net.Conn) Conn {
|
||||
return tls.Client(conn, s.config)
|
||||
return &STDClientConfig{&tlsConfig}, nil
|
||||
}
|
||||
|
||||
@@ -15,6 +15,8 @@ import (
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
var errInsecureUnused = E.New("tls: insecure unused")
|
||||
|
||||
type STDServerConfig struct {
|
||||
config *tls.Config
|
||||
logger log.Logger
|
||||
@@ -26,6 +28,14 @@ type STDServerConfig struct {
|
||||
watcher *fsnotify.Watcher
|
||||
}
|
||||
|
||||
func (c *STDServerConfig) ServerName() string {
|
||||
return c.config.ServerName
|
||||
}
|
||||
|
||||
func (c *STDServerConfig) SetServerName(serverName string) {
|
||||
c.config.ServerName = serverName
|
||||
}
|
||||
|
||||
func (c *STDServerConfig) NextProtos() []string {
|
||||
return c.config.NextProtos
|
||||
}
|
||||
@@ -34,9 +44,119 @@ func (c *STDServerConfig) SetNextProtos(nextProto []string) {
|
||||
c.config.NextProtos = nextProto
|
||||
}
|
||||
|
||||
var errInsecureUnused = E.New("tls: insecure unused")
|
||||
func (c *STDServerConfig) Config() (*STDConfig, error) {
|
||||
return c.config, nil
|
||||
}
|
||||
|
||||
func newSTDServer(ctx context.Context, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) {
|
||||
func (c *STDServerConfig) Client(conn net.Conn) Conn {
|
||||
return tls.Client(conn, c.config)
|
||||
}
|
||||
|
||||
func (c *STDServerConfig) Server(conn net.Conn) Conn {
|
||||
return tls.Server(conn, c.config)
|
||||
}
|
||||
|
||||
func (c *STDServerConfig) Clone() Config {
|
||||
return &STDServerConfig{
|
||||
config: c.config.Clone(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *STDServerConfig) Start() error {
|
||||
if c.acmeService != nil {
|
||||
return c.acmeService.Start()
|
||||
} else {
|
||||
if c.certificatePath == "" && c.keyPath == "" {
|
||||
return nil
|
||||
}
|
||||
err := c.startWatcher()
|
||||
if err != nil {
|
||||
c.logger.Warn("create fsnotify watcher: ", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *STDServerConfig) startWatcher() error {
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c.certificatePath != "" {
|
||||
err = watcher.Add(c.certificatePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if c.keyPath != "" {
|
||||
err = watcher.Add(c.keyPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
c.watcher = watcher
|
||||
go c.loopUpdate()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *STDServerConfig) loopUpdate() {
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-c.watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if event.Op&fsnotify.Write != fsnotify.Write {
|
||||
continue
|
||||
}
|
||||
err := c.reloadKeyPair()
|
||||
if err != nil {
|
||||
c.logger.Error(E.Cause(err, "reload TLS key pair"))
|
||||
}
|
||||
case err, ok := <-c.watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
c.logger.Error(E.Cause(err, "fsnotify error"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *STDServerConfig) reloadKeyPair() error {
|
||||
if c.certificatePath != "" {
|
||||
certificate, err := os.ReadFile(c.certificatePath)
|
||||
if err != nil {
|
||||
return E.Cause(err, "reload certificate from ", c.certificatePath)
|
||||
}
|
||||
c.certificate = certificate
|
||||
}
|
||||
if c.keyPath != "" {
|
||||
key, err := os.ReadFile(c.keyPath)
|
||||
if err != nil {
|
||||
return E.Cause(err, "reload key from ", c.keyPath)
|
||||
}
|
||||
c.key = key
|
||||
}
|
||||
keyPair, err := tls.X509KeyPair(c.certificate, c.key)
|
||||
if err != nil {
|
||||
return E.Cause(err, "reload key pair")
|
||||
}
|
||||
c.config.Certificates = []tls.Certificate{keyPair}
|
||||
c.logger.Info("reloaded TLS certificate")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *STDServerConfig) Close() error {
|
||||
if c.acmeService != nil {
|
||||
return c.acmeService.Close()
|
||||
}
|
||||
if c.watcher != nil {
|
||||
return c.watcher.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewSTDServer(ctx context.Context, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) {
|
||||
if !options.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -136,109 +256,3 @@ func newSTDServer(ctx context.Context, logger log.Logger, options option.Inbound
|
||||
keyPath: options.KeyPath,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *STDServerConfig) Config() (*STDConfig, error) {
|
||||
return c.config, nil
|
||||
}
|
||||
|
||||
func (c *STDServerConfig) Client(conn net.Conn) Conn {
|
||||
return tls.Client(conn, c.config)
|
||||
}
|
||||
|
||||
func (c *STDServerConfig) Server(conn net.Conn) Conn {
|
||||
return tls.Server(conn, c.config)
|
||||
}
|
||||
|
||||
func (c *STDServerConfig) Start() error {
|
||||
if c.acmeService != nil {
|
||||
return c.acmeService.Start()
|
||||
} else {
|
||||
if c.certificatePath == "" && c.keyPath == "" {
|
||||
return nil
|
||||
}
|
||||
err := c.startWatcher()
|
||||
if err != nil {
|
||||
c.logger.Warn("create fsnotify watcher: ", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *STDServerConfig) startWatcher() error {
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c.certificatePath != "" {
|
||||
err = watcher.Add(c.certificatePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if c.keyPath != "" {
|
||||
err = watcher.Add(c.keyPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
c.watcher = watcher
|
||||
go c.loopUpdate()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *STDServerConfig) loopUpdate() {
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-c.watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if event.Op&fsnotify.Write != fsnotify.Write {
|
||||
continue
|
||||
}
|
||||
err := c.reloadKeyPair()
|
||||
if err != nil {
|
||||
c.logger.Error(E.Cause(err, "reload TLS key pair"))
|
||||
}
|
||||
case err, ok := <-c.watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
c.logger.Error(E.Cause(err, "fsnotify error"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *STDServerConfig) reloadKeyPair() error {
|
||||
if c.certificatePath != "" {
|
||||
certificate, err := os.ReadFile(c.certificatePath)
|
||||
if err != nil {
|
||||
return E.Cause(err, "reload certificate from ", c.certificatePath)
|
||||
}
|
||||
c.certificate = certificate
|
||||
}
|
||||
if c.keyPath != "" {
|
||||
key, err := os.ReadFile(c.keyPath)
|
||||
if err != nil {
|
||||
return E.Cause(err, "reload key from ", c.keyPath)
|
||||
}
|
||||
c.key = key
|
||||
}
|
||||
keyPair, err := tls.X509KeyPair(c.certificate, c.key)
|
||||
if err != nil {
|
||||
return E.Cause(err, "reload key pair")
|
||||
}
|
||||
c.config.Certificates = []tls.Certificate{keyPair}
|
||||
c.logger.Info("reloaded TLS certificate")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *STDServerConfig) Close() error {
|
||||
if c.acmeService != nil {
|
||||
return c.acmeService.Close()
|
||||
}
|
||||
if c.watcher != nil {
|
||||
return c.watcher.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -16,24 +16,32 @@ import (
|
||||
utls "github.com/refraction-networking/utls"
|
||||
)
|
||||
|
||||
type utlsClientConfig struct {
|
||||
type UTLSClientConfig struct {
|
||||
config *utls.Config
|
||||
id utls.ClientHelloID
|
||||
}
|
||||
|
||||
func (e *utlsClientConfig) NextProtos() []string {
|
||||
func (e *UTLSClientConfig) ServerName() string {
|
||||
return e.config.ServerName
|
||||
}
|
||||
|
||||
func (e *UTLSClientConfig) SetServerName(serverName string) {
|
||||
e.config.ServerName = serverName
|
||||
}
|
||||
|
||||
func (e *UTLSClientConfig) NextProtos() []string {
|
||||
return e.config.NextProtos
|
||||
}
|
||||
|
||||
func (e *utlsClientConfig) SetNextProtos(nextProto []string) {
|
||||
func (e *UTLSClientConfig) SetNextProtos(nextProto []string) {
|
||||
e.config.NextProtos = nextProto
|
||||
}
|
||||
|
||||
func (e *utlsClientConfig) Config() (*STDConfig, error) {
|
||||
func (e *UTLSClientConfig) Config() (*STDConfig, error) {
|
||||
return nil, E.New("unsupported usage for uTLS")
|
||||
}
|
||||
|
||||
func (e *utlsClientConfig) Client(conn net.Conn) Conn {
|
||||
func (e *UTLSClientConfig) Client(conn net.Conn) Conn {
|
||||
return &utlsConnWrapper{utls.UClient(conn, e.config.Clone(), e.id)}
|
||||
}
|
||||
|
||||
@@ -59,7 +67,14 @@ func (c *utlsConnWrapper) ConnectionState() tls.ConnectionState {
|
||||
}
|
||||
}
|
||||
|
||||
func newUTLSClient(router adapter.Router, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||
func (e *UTLSClientConfig) Clone() Config {
|
||||
return &UTLSClientConfig{
|
||||
config: e.config.Clone(),
|
||||
id: e.id,
|
||||
}
|
||||
}
|
||||
|
||||
func NewUTLSClient(router adapter.Router, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||
var serverName string
|
||||
if options.ServerName != "" {
|
||||
serverName = options.ServerName
|
||||
@@ -152,5 +167,5 @@ func newUTLSClient(router adapter.Router, serverAddress string, options option.O
|
||||
default:
|
||||
return nil, E.New("unknown uTLS fingerprint: ", options.UTLS.Fingerprint)
|
||||
}
|
||||
return &utlsClientConfig{&tlsConfig, id}, nil
|
||||
return &UTLSClientConfig{&tlsConfig, id}, nil
|
||||
}
|
||||
|
||||
@@ -8,6 +8,6 @@ import (
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
func newUTLSClient(router adapter.Router, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||
func NewUTLSClient(router adapter.Router, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||
return nil, E.New(`uTLS is not included in this build, rebuild with -tags with_utls`)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user