mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-11 17:47:20 +10:00
Refactor ACME support to certificate provider
This commit is contained in:
411
service/acme/service.go
Normal file
411
service/acme/service.go
Normal file
@@ -0,0 +1,411 @@
|
||||
//go:build with_acme
|
||||
|
||||
package acme
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/certificate"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
boxtls "github.com/sagernet/sing-box/common/tls"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
|
||||
"github.com/caddyserver/certmagic"
|
||||
"github.com/caddyserver/zerossl"
|
||||
"github.com/libdns/alidns"
|
||||
"github.com/libdns/cloudflare"
|
||||
"github.com/libdns/libdns"
|
||||
"github.com/mholt/acmez/v3/acme"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
func RegisterCertificateProvider(registry *certificate.Registry) {
|
||||
certificate.Register[option.ACMECertificateProviderOptions](registry, C.TypeACME, NewCertificateProvider)
|
||||
}
|
||||
|
||||
var (
|
||||
_ adapter.CertificateProviderService = (*Service)(nil)
|
||||
_ adapter.ACMECertificateProvider = (*Service)(nil)
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
certificate.Adapter
|
||||
ctx context.Context
|
||||
config *certmagic.Config
|
||||
cache *certmagic.Cache
|
||||
domain []string
|
||||
nextProtos []string
|
||||
}
|
||||
|
||||
func NewCertificateProvider(ctx context.Context, logger log.ContextLogger, tag string, options option.ACMECertificateProviderOptions) (adapter.CertificateProviderService, error) {
|
||||
if len(options.Domain) == 0 {
|
||||
return nil, E.New("missing domain")
|
||||
}
|
||||
var acmeServer string
|
||||
switch options.Provider {
|
||||
case "", "letsencrypt":
|
||||
acmeServer = certmagic.LetsEncryptProductionCA
|
||||
case "zerossl":
|
||||
acmeServer = certmagic.ZeroSSLProductionCA
|
||||
default:
|
||||
if !strings.HasPrefix(options.Provider, "https://") {
|
||||
return nil, E.New("unsupported ACME provider: ", options.Provider)
|
||||
}
|
||||
acmeServer = options.Provider
|
||||
}
|
||||
if acmeServer == certmagic.ZeroSSLProductionCA &&
|
||||
(options.ExternalAccount == nil || options.ExternalAccount.KeyID == "") &&
|
||||
strings.TrimSpace(options.Email) == "" &&
|
||||
strings.TrimSpace(options.AccountKey) == "" {
|
||||
return nil, E.New("email is required to use the ZeroSSL ACME endpoint without external_account or account_key")
|
||||
}
|
||||
|
||||
var storage certmagic.Storage
|
||||
if options.DataDirectory != "" {
|
||||
storage = &certmagic.FileStorage{Path: options.DataDirectory}
|
||||
} else {
|
||||
storage = certmagic.Default.Storage
|
||||
}
|
||||
|
||||
zapLogger := zap.New(zapcore.NewCore(
|
||||
zapcore.NewConsoleEncoder(boxtls.ACMEEncoderConfig()),
|
||||
&boxtls.ACMELogWriter{Logger: logger},
|
||||
zap.DebugLevel,
|
||||
))
|
||||
|
||||
config := &certmagic.Config{
|
||||
DefaultServerName: options.DefaultServerName,
|
||||
Storage: storage,
|
||||
Logger: zapLogger,
|
||||
}
|
||||
if options.KeyType != "" {
|
||||
var keyType certmagic.KeyType
|
||||
switch options.KeyType {
|
||||
case option.ACMEKeyTypeED25519:
|
||||
keyType = certmagic.ED25519
|
||||
case option.ACMEKeyTypeP256:
|
||||
keyType = certmagic.P256
|
||||
case option.ACMEKeyTypeP384:
|
||||
keyType = certmagic.P384
|
||||
case option.ACMEKeyTypeRSA2048:
|
||||
keyType = certmagic.RSA2048
|
||||
case option.ACMEKeyTypeRSA4096:
|
||||
keyType = certmagic.RSA4096
|
||||
default:
|
||||
return nil, E.New("unsupported ACME key type: ", options.KeyType)
|
||||
}
|
||||
config.KeySource = certmagic.StandardKeyGenerator{KeyType: keyType}
|
||||
}
|
||||
|
||||
acmeIssuer := certmagic.ACMEIssuer{
|
||||
CA: acmeServer,
|
||||
Email: options.Email,
|
||||
AccountKeyPEM: options.AccountKey,
|
||||
Agreed: true,
|
||||
DisableHTTPChallenge: options.DisableHTTPChallenge,
|
||||
DisableTLSALPNChallenge: options.DisableTLSALPNChallenge,
|
||||
AltHTTPPort: int(options.AlternativeHTTPPort),
|
||||
AltTLSALPNPort: int(options.AlternativeTLSPort),
|
||||
Logger: zapLogger,
|
||||
}
|
||||
acmeHTTPClient, err := newACMEHTTPClient(ctx, options.Detour)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dnsSolver, err := newDNSSolver(options.DNS01Challenge, zapLogger, acmeHTTPClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if dnsSolver != nil {
|
||||
acmeIssuer.DNS01Solver = dnsSolver
|
||||
}
|
||||
if options.ExternalAccount != nil && options.ExternalAccount.KeyID != "" {
|
||||
acmeIssuer.ExternalAccount = (*acme.EAB)(options.ExternalAccount)
|
||||
}
|
||||
if acmeServer == certmagic.ZeroSSLProductionCA {
|
||||
acmeIssuer.NewAccountFunc = func(ctx context.Context, acmeIssuer *certmagic.ACMEIssuer, account acme.Account) (acme.Account, error) {
|
||||
if acmeIssuer.ExternalAccount != nil {
|
||||
return account, nil
|
||||
}
|
||||
var err error
|
||||
acmeIssuer.ExternalAccount, account, err = createZeroSSLExternalAccountBinding(ctx, acmeIssuer, account, acmeHTTPClient)
|
||||
return account, err
|
||||
}
|
||||
}
|
||||
|
||||
certmagicIssuer := certmagic.NewACMEIssuer(config, acmeIssuer)
|
||||
httpClientField := reflect.ValueOf(certmagicIssuer).Elem().FieldByName("httpClient")
|
||||
if !httpClientField.IsValid() || !httpClientField.CanAddr() {
|
||||
return nil, E.New("certmagic ACME issuer HTTP client field is unavailable")
|
||||
}
|
||||
reflect.NewAt(httpClientField.Type(), unsafe.Pointer(httpClientField.UnsafeAddr())).Elem().Set(reflect.ValueOf(acmeHTTPClient))
|
||||
config.Issuers = []certmagic.Issuer{certmagicIssuer}
|
||||
cache := certmagic.NewCache(certmagic.CacheOptions{
|
||||
GetConfigForCert: func(certificate certmagic.Certificate) (*certmagic.Config, error) {
|
||||
return config, nil
|
||||
},
|
||||
Logger: zapLogger,
|
||||
})
|
||||
config = certmagic.New(cache, *config)
|
||||
|
||||
var nextProtos []string
|
||||
if !acmeIssuer.DisableTLSALPNChallenge && acmeIssuer.DNS01Solver == nil {
|
||||
nextProtos = []string{C.ACMETLS1Protocol}
|
||||
}
|
||||
return &Service{
|
||||
Adapter: certificate.NewAdapter(C.TypeACME, tag),
|
||||
ctx: ctx,
|
||||
config: config,
|
||||
cache: cache,
|
||||
domain: options.Domain,
|
||||
nextProtos: nextProtos,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStateStart {
|
||||
return nil
|
||||
}
|
||||
return s.config.ManageAsync(s.ctx, s.domain)
|
||||
}
|
||||
|
||||
func (s *Service) Close() error {
|
||||
if s.cache != nil {
|
||||
s.cache.Stop()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return s.config.GetCertificate(hello)
|
||||
}
|
||||
|
||||
func (s *Service) GetACMENextProtos() []string {
|
||||
return s.nextProtos
|
||||
}
|
||||
|
||||
func newDNSSolver(dnsOptions *option.ACMEProviderDNS01ChallengeOptions, logger *zap.Logger, httpClient *http.Client) (*certmagic.DNS01Solver, error) {
|
||||
if dnsOptions == nil || dnsOptions.Provider == "" {
|
||||
return nil, nil
|
||||
}
|
||||
if dnsOptions.TTL < 0 {
|
||||
return nil, E.New("invalid ACME DNS01 ttl: ", dnsOptions.TTL)
|
||||
}
|
||||
if dnsOptions.PropagationDelay < 0 {
|
||||
return nil, E.New("invalid ACME DNS01 propagation_delay: ", dnsOptions.PropagationDelay)
|
||||
}
|
||||
if dnsOptions.PropagationTimeout < -1 {
|
||||
return nil, E.New("invalid ACME DNS01 propagation_timeout: ", dnsOptions.PropagationTimeout)
|
||||
}
|
||||
solver := &certmagic.DNS01Solver{
|
||||
DNSManager: certmagic.DNSManager{
|
||||
TTL: time.Duration(dnsOptions.TTL),
|
||||
PropagationDelay: time.Duration(dnsOptions.PropagationDelay),
|
||||
PropagationTimeout: time.Duration(dnsOptions.PropagationTimeout),
|
||||
Resolvers: dnsOptions.Resolvers,
|
||||
OverrideDomain: dnsOptions.OverrideDomain,
|
||||
Logger: logger.Named("dns_manager"),
|
||||
},
|
||||
}
|
||||
switch dnsOptions.Provider {
|
||||
case C.DNSProviderAliDNS:
|
||||
solver.DNSProvider = &alidns.Provider{
|
||||
CredentialInfo: alidns.CredentialInfo{
|
||||
AccessKeyID: dnsOptions.AliDNSOptions.AccessKeyID,
|
||||
AccessKeySecret: dnsOptions.AliDNSOptions.AccessKeySecret,
|
||||
RegionID: dnsOptions.AliDNSOptions.RegionID,
|
||||
SecurityToken: dnsOptions.AliDNSOptions.SecurityToken,
|
||||
},
|
||||
}
|
||||
case C.DNSProviderCloudflare:
|
||||
solver.DNSProvider = &cloudflare.Provider{
|
||||
APIToken: dnsOptions.CloudflareOptions.APIToken,
|
||||
ZoneToken: dnsOptions.CloudflareOptions.ZoneToken,
|
||||
HTTPClient: httpClient,
|
||||
}
|
||||
case C.DNSProviderACMEDNS:
|
||||
solver.DNSProvider = &acmeDNSProvider{
|
||||
username: dnsOptions.ACMEDNSOptions.Username,
|
||||
password: dnsOptions.ACMEDNSOptions.Password,
|
||||
subdomain: dnsOptions.ACMEDNSOptions.Subdomain,
|
||||
serverURL: dnsOptions.ACMEDNSOptions.ServerURL,
|
||||
httpClient: httpClient,
|
||||
}
|
||||
default:
|
||||
return nil, E.New("unsupported ACME DNS01 provider type: ", dnsOptions.Provider)
|
||||
}
|
||||
return solver, nil
|
||||
}
|
||||
|
||||
func createZeroSSLExternalAccountBinding(ctx context.Context, acmeIssuer *certmagic.ACMEIssuer, account acme.Account, httpClient *http.Client) (*acme.EAB, acme.Account, error) {
|
||||
email := strings.TrimSpace(acmeIssuer.Email)
|
||||
if email == "" {
|
||||
return nil, acme.Account{}, E.New("email is required to use the ZeroSSL ACME endpoint without external_account")
|
||||
}
|
||||
if len(account.Contact) == 0 {
|
||||
account.Contact = []string{"mailto:" + email}
|
||||
}
|
||||
if acmeIssuer.CertObtainTimeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, acmeIssuer.CertObtainTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
form := url.Values{"email": []string{email}}
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodPost, zerossl.BaseURL+"/acme/eab-credentials-email", strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, account, E.Cause(err, "create ZeroSSL EAB request")
|
||||
}
|
||||
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
request.Header.Set("User-Agent", certmagic.UserAgent)
|
||||
|
||||
response, err := httpClient.Do(request)
|
||||
if err != nil {
|
||||
return nil, account, E.Cause(err, "request ZeroSSL EAB")
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
var result struct {
|
||||
Success bool `json:"success"`
|
||||
Error struct {
|
||||
Code int `json:"code"`
|
||||
Type string `json:"type"`
|
||||
} `json:"error"`
|
||||
EABKID string `json:"eab_kid"`
|
||||
EABHMACKey string `json:"eab_hmac_key"`
|
||||
}
|
||||
err = json.NewDecoder(response.Body).Decode(&result)
|
||||
if err != nil {
|
||||
return nil, account, E.Cause(err, "decode ZeroSSL EAB response")
|
||||
}
|
||||
if response.StatusCode != http.StatusOK {
|
||||
return nil, account, E.New("failed getting ZeroSSL EAB credentials: HTTP ", response.StatusCode)
|
||||
}
|
||||
if result.Error.Code != 0 {
|
||||
return nil, account, E.New("failed getting ZeroSSL EAB credentials: ", result.Error.Type, " (code ", result.Error.Code, ")")
|
||||
}
|
||||
|
||||
acmeIssuer.Logger.Info("generated ZeroSSL EAB credentials", zap.String("key_id", result.EABKID))
|
||||
|
||||
return &acme.EAB{
|
||||
KeyID: result.EABKID,
|
||||
MACKey: result.EABHMACKey,
|
||||
}, account, nil
|
||||
}
|
||||
|
||||
func newACMEHTTPClient(ctx context.Context, detour string) (*http.Client, error) {
|
||||
outboundDialer, err := dialer.NewWithOptions(dialer.Options{
|
||||
Context: ctx,
|
||||
Options: option.DialerOptions{
|
||||
Detour: detour,
|
||||
},
|
||||
RemoteIsDomain: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "create ACME provider dialer")
|
||||
}
|
||||
return &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return outboundDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
},
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: adapter.RootPoolFromContext(ctx),
|
||||
Time: ntp.TimeFuncFromContext(ctx),
|
||||
},
|
||||
// from certmagic defaults (acmeissuer.go)
|
||||
TLSHandshakeTimeout: 30 * time.Second,
|
||||
ResponseHeaderTimeout: 30 * time.Second,
|
||||
ExpectContinueTimeout: 2 * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
},
|
||||
Timeout: certmagic.HTTPTimeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type acmeDNSProvider struct {
|
||||
username string
|
||||
password string
|
||||
subdomain string
|
||||
serverURL string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
type acmeDNSRecord struct {
|
||||
resourceRecord libdns.RR
|
||||
}
|
||||
|
||||
func (r acmeDNSRecord) RR() libdns.RR {
|
||||
return r.resourceRecord
|
||||
}
|
||||
|
||||
func (p *acmeDNSProvider) AppendRecords(ctx context.Context, _ string, records []libdns.Record) ([]libdns.Record, error) {
|
||||
if p.username == "" {
|
||||
return nil, E.New("ACME-DNS username cannot be empty")
|
||||
}
|
||||
if p.password == "" {
|
||||
return nil, E.New("ACME-DNS password cannot be empty")
|
||||
}
|
||||
if p.subdomain == "" {
|
||||
return nil, E.New("ACME-DNS subdomain cannot be empty")
|
||||
}
|
||||
if p.serverURL == "" {
|
||||
return nil, E.New("ACME-DNS server_url cannot be empty")
|
||||
}
|
||||
appendedRecords := make([]libdns.Record, 0, len(records))
|
||||
for _, record := range records {
|
||||
resourceRecord := record.RR()
|
||||
if resourceRecord.Type != "TXT" {
|
||||
return appendedRecords, E.New("ACME-DNS only supports adding TXT records")
|
||||
}
|
||||
requestBody, err := json.Marshal(map[string]string{
|
||||
"subdomain": p.subdomain,
|
||||
"txt": resourceRecord.Data,
|
||||
})
|
||||
if err != nil {
|
||||
return appendedRecords, E.Cause(err, "marshal ACME-DNS update request")
|
||||
}
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodPost, p.serverURL+"/update", bytes.NewReader(requestBody))
|
||||
if err != nil {
|
||||
return appendedRecords, E.Cause(err, "create ACME-DNS update request")
|
||||
}
|
||||
request.Header.Set("X-Api-User", p.username)
|
||||
request.Header.Set("X-Api-Key", p.password)
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
response, err := p.httpClient.Do(request)
|
||||
if err != nil {
|
||||
return appendedRecords, E.Cause(err, "update ACME-DNS record")
|
||||
}
|
||||
_ = response.Body.Close()
|
||||
if response.StatusCode != http.StatusOK {
|
||||
return appendedRecords, E.New("update ACME-DNS record: HTTP ", response.StatusCode)
|
||||
}
|
||||
appendedRecords = append(appendedRecords, acmeDNSRecord{resourceRecord: libdns.RR{
|
||||
Type: "TXT",
|
||||
Name: resourceRecord.Name,
|
||||
Data: resourceRecord.Data,
|
||||
}})
|
||||
}
|
||||
return appendedRecords, nil
|
||||
}
|
||||
|
||||
func (p *acmeDNSProvider) DeleteRecords(context.Context, string, []libdns.Record) ([]libdns.Record, error) {
|
||||
return nil, nil
|
||||
}
|
||||
3
service/acme/stub.go
Normal file
3
service/acme/stub.go
Normal file
@@ -0,0 +1,3 @@
|
||||
//go:build !with_acme
|
||||
|
||||
package acme
|
||||
618
service/origin_ca/service.go
Normal file
618
service/origin_ca/service.go
Normal file
@@ -0,0 +1,618 @@
|
||||
package originca
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/certificate"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
|
||||
"github.com/caddyserver/certmagic"
|
||||
)
|
||||
|
||||
const (
|
||||
cloudflareOriginCAEndpoint = "https://api.cloudflare.com/client/v4/certificates"
|
||||
defaultRequestedValidity = option.CloudflareOriginCARequestValidity5475
|
||||
// min of 30 days and certmagic's 1/3 lifetime ratio (maintain.go)
|
||||
defaultRenewBefore = 30 * 24 * time.Hour
|
||||
// from certmagic retry backoff range (async.go)
|
||||
minimumRenewRetryDelay = time.Minute
|
||||
maximumRenewRetryDelay = time.Hour
|
||||
storageLockPrefix = "cloudflare-origin-ca"
|
||||
)
|
||||
|
||||
func RegisterCertificateProvider(registry *certificate.Registry) {
|
||||
certificate.Register[option.CloudflareOriginCACertificateProviderOptions](registry, C.TypeCloudflareOriginCA, NewCertificateProvider)
|
||||
}
|
||||
|
||||
var _ adapter.CertificateProviderService = (*Service)(nil)
|
||||
|
||||
type Service struct {
|
||||
certificate.Adapter
|
||||
logger log.ContextLogger
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
done chan struct{}
|
||||
timeFunc func() time.Time
|
||||
httpClient *http.Client
|
||||
storage certmagic.Storage
|
||||
storageIssuerKey string
|
||||
storageNamesKey string
|
||||
storageLockKey string
|
||||
apiToken string
|
||||
originCAKey string
|
||||
domain []string
|
||||
requestType option.CloudflareOriginCARequestType
|
||||
requestedValidity option.CloudflareOriginCARequestValidity
|
||||
|
||||
access sync.RWMutex
|
||||
currentCertificate *tls.Certificate
|
||||
currentLeaf *x509.Certificate
|
||||
}
|
||||
|
||||
func NewCertificateProvider(ctx context.Context, logger log.ContextLogger, tag string, options option.CloudflareOriginCACertificateProviderOptions) (adapter.CertificateProviderService, error) {
|
||||
domain, err := normalizeHostnames(options.Domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(domain) == 0 {
|
||||
return nil, E.New("missing domain")
|
||||
}
|
||||
apiToken := strings.TrimSpace(options.APIToken)
|
||||
originCAKey := strings.TrimSpace(options.OriginCAKey)
|
||||
switch {
|
||||
case apiToken == "" && originCAKey == "":
|
||||
return nil, E.New("api_token or origin_ca_key is required")
|
||||
case apiToken != "" && originCAKey != "":
|
||||
return nil, E.New("api_token and origin_ca_key are mutually exclusive")
|
||||
}
|
||||
requestType := options.RequestType
|
||||
if requestType == "" {
|
||||
requestType = option.CloudflareOriginCARequestTypeOriginRSA
|
||||
}
|
||||
requestedValidity := options.RequestedValidity
|
||||
if requestedValidity == 0 {
|
||||
requestedValidity = defaultRequestedValidity
|
||||
}
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
serviceDialer, err := dialer.NewWithOptions(dialer.Options{
|
||||
Context: ctx,
|
||||
Options: option.DialerOptions{
|
||||
Detour: options.Detour,
|
||||
},
|
||||
RemoteIsDomain: true,
|
||||
})
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, E.Cause(err, "create Cloudflare Origin CA dialer")
|
||||
}
|
||||
var storage certmagic.Storage
|
||||
if options.DataDirectory != "" {
|
||||
storage = &certmagic.FileStorage{Path: options.DataDirectory}
|
||||
} else {
|
||||
storage = certmagic.Default.Storage
|
||||
}
|
||||
timeFunc := ntp.TimeFuncFromContext(ctx)
|
||||
if timeFunc == nil {
|
||||
timeFunc = time.Now
|
||||
}
|
||||
storageIssuerKey := C.TypeCloudflareOriginCA + "-" + string(requestType)
|
||||
storageNamesKey := (&certmagic.CertificateResource{SANs: slices.Clone(domain)}).NamesKey()
|
||||
storageLockKey := strings.Join([]string{
|
||||
storageLockPrefix,
|
||||
certmagic.StorageKeys.Safe(storageIssuerKey),
|
||||
certmagic.StorageKeys.Safe(storageNamesKey),
|
||||
}, "/")
|
||||
return &Service{
|
||||
Adapter: certificate.NewAdapter(C.TypeCloudflareOriginCA, tag),
|
||||
logger: logger,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
timeFunc: timeFunc,
|
||||
httpClient: &http.Client{Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return serviceDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
},
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: adapter.RootPoolFromContext(ctx),
|
||||
Time: timeFunc,
|
||||
},
|
||||
ForceAttemptHTTP2: true,
|
||||
}},
|
||||
storage: storage,
|
||||
storageIssuerKey: storageIssuerKey,
|
||||
storageNamesKey: storageNamesKey,
|
||||
storageLockKey: storageLockKey,
|
||||
apiToken: apiToken,
|
||||
originCAKey: originCAKey,
|
||||
domain: domain,
|
||||
requestType: requestType,
|
||||
requestedValidity: requestedValidity,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStateStart {
|
||||
return nil
|
||||
}
|
||||
cachedCertificate, cachedLeaf, err := s.loadCachedCertificate()
|
||||
if err != nil {
|
||||
s.logger.Warn(E.Cause(err, "load cached Cloudflare Origin CA certificate"))
|
||||
} else if cachedCertificate != nil {
|
||||
s.setCurrentCertificate(cachedCertificate, cachedLeaf)
|
||||
}
|
||||
if cachedCertificate == nil {
|
||||
err = s.issueAndStoreCertificate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if s.shouldRenew(cachedLeaf, s.timeFunc()) {
|
||||
err = s.issueAndStoreCertificate()
|
||||
if err != nil {
|
||||
s.logger.Warn(E.Cause(err, "renew cached Cloudflare Origin CA certificate"))
|
||||
}
|
||||
}
|
||||
s.done = make(chan struct{})
|
||||
go s.refreshLoop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) Close() error {
|
||||
s.cancel()
|
||||
if done := s.done; done != nil {
|
||||
<-done
|
||||
}
|
||||
if transport, loaded := s.httpClient.Transport.(*http.Transport); loaded {
|
||||
transport.CloseIdleConnections()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
s.access.RLock()
|
||||
certificate := s.currentCertificate
|
||||
s.access.RUnlock()
|
||||
if certificate == nil {
|
||||
return nil, E.New("Cloudflare Origin CA certificate is unavailable")
|
||||
}
|
||||
return certificate, nil
|
||||
}
|
||||
|
||||
func (s *Service) refreshLoop() {
|
||||
defer close(s.done)
|
||||
var retryDelay time.Duration
|
||||
for {
|
||||
waitDuration := retryDelay
|
||||
if waitDuration == 0 {
|
||||
s.access.RLock()
|
||||
leaf := s.currentLeaf
|
||||
s.access.RUnlock()
|
||||
if leaf == nil {
|
||||
waitDuration = minimumRenewRetryDelay
|
||||
} else {
|
||||
refreshAt := leaf.NotAfter.Add(-s.effectiveRenewBefore(leaf))
|
||||
waitDuration = refreshAt.Sub(s.timeFunc())
|
||||
if waitDuration < minimumRenewRetryDelay {
|
||||
waitDuration = minimumRenewRetryDelay
|
||||
}
|
||||
}
|
||||
}
|
||||
timer := time.NewTimer(waitDuration)
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
return
|
||||
case <-timer.C:
|
||||
}
|
||||
err := s.issueAndStoreCertificate()
|
||||
if err != nil {
|
||||
s.logger.Error(E.Cause(err, "renew Cloudflare Origin CA certificate"))
|
||||
s.access.RLock()
|
||||
leaf := s.currentLeaf
|
||||
s.access.RUnlock()
|
||||
if leaf == nil {
|
||||
retryDelay = minimumRenewRetryDelay
|
||||
} else {
|
||||
remaining := leaf.NotAfter.Sub(s.timeFunc())
|
||||
switch {
|
||||
case remaining <= minimumRenewRetryDelay:
|
||||
retryDelay = minimumRenewRetryDelay
|
||||
case remaining < maximumRenewRetryDelay:
|
||||
retryDelay = max(remaining/2, minimumRenewRetryDelay)
|
||||
default:
|
||||
retryDelay = maximumRenewRetryDelay
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
retryDelay = 0
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) shouldRenew(leaf *x509.Certificate, now time.Time) bool {
|
||||
return !now.Before(leaf.NotAfter.Add(-s.effectiveRenewBefore(leaf)))
|
||||
}
|
||||
|
||||
func (s *Service) effectiveRenewBefore(leaf *x509.Certificate) time.Duration {
|
||||
lifetime := leaf.NotAfter.Sub(leaf.NotBefore)
|
||||
if lifetime <= 0 {
|
||||
return 0
|
||||
}
|
||||
return min(lifetime/3, defaultRenewBefore)
|
||||
}
|
||||
|
||||
func (s *Service) issueAndStoreCertificate() error {
|
||||
err := s.storage.Lock(s.ctx, s.storageLockKey)
|
||||
if err != nil {
|
||||
return E.Cause(err, "lock Cloudflare Origin CA certificate storage")
|
||||
}
|
||||
defer func() {
|
||||
err = s.storage.Unlock(context.WithoutCancel(s.ctx), s.storageLockKey)
|
||||
if err != nil {
|
||||
s.logger.Warn(E.Cause(err, "unlock Cloudflare Origin CA certificate storage"))
|
||||
}
|
||||
}()
|
||||
cachedCertificate, cachedLeaf, err := s.loadCachedCertificate()
|
||||
if err != nil {
|
||||
s.logger.Warn(E.Cause(err, "load cached Cloudflare Origin CA certificate"))
|
||||
} else if cachedCertificate != nil && !s.shouldRenew(cachedLeaf, s.timeFunc()) {
|
||||
s.setCurrentCertificate(cachedCertificate, cachedLeaf)
|
||||
return nil
|
||||
}
|
||||
certificatePEM, privateKeyPEM, tlsCertificate, leaf, err := s.requestCertificate(s.ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
issuerData, err := json.Marshal(originCAIssuerData{
|
||||
RequestType: s.requestType,
|
||||
RequestedValidity: s.requestedValidity,
|
||||
})
|
||||
if err != nil {
|
||||
return E.Cause(err, "encode Cloudflare Origin CA certificate metadata")
|
||||
}
|
||||
err = storeCertificateResource(s.ctx, s.storage, s.storageIssuerKey, certmagic.CertificateResource{
|
||||
SANs: slices.Clone(s.domain),
|
||||
CertificatePEM: certificatePEM,
|
||||
PrivateKeyPEM: privateKeyPEM,
|
||||
IssuerData: issuerData,
|
||||
})
|
||||
if err != nil {
|
||||
return E.Cause(err, "store Cloudflare Origin CA certificate")
|
||||
}
|
||||
s.setCurrentCertificate(tlsCertificate, leaf)
|
||||
s.logger.Info("updated Cloudflare Origin CA certificate, expires at ", leaf.NotAfter.Format(time.RFC3339))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) requestCertificate(ctx context.Context) ([]byte, []byte, *tls.Certificate, *x509.Certificate, error) {
|
||||
var privateKey crypto.Signer
|
||||
switch s.requestType {
|
||||
case option.CloudflareOriginCARequestTypeOriginRSA:
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
privateKey = rsaKey
|
||||
case option.CloudflareOriginCARequestTypeOriginECC:
|
||||
ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
privateKey = ecKey
|
||||
default:
|
||||
return nil, nil, nil, nil, E.New("unsupported Cloudflare Origin CA request type: ", s.requestType)
|
||||
}
|
||||
privateKeyDER, err := x509.MarshalPKCS8PrivateKey(privateKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, E.Cause(err, "encode private key")
|
||||
}
|
||||
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: privateKeyDER,
|
||||
})
|
||||
certificateRequestDER, err := x509.CreateCertificateRequest(rand.Reader, &x509.CertificateRequest{
|
||||
Subject: pkix.Name{CommonName: s.domain[0]},
|
||||
DNSNames: s.domain,
|
||||
}, privateKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, E.Cause(err, "create certificate request")
|
||||
}
|
||||
certificateRequestPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE REQUEST",
|
||||
Bytes: certificateRequestDER,
|
||||
})
|
||||
requestBody, err := json.Marshal(originCARequest{
|
||||
CSR: string(certificateRequestPEM),
|
||||
Hostnames: s.domain,
|
||||
RequestType: string(s.requestType),
|
||||
RequestedValidity: uint16(s.requestedValidity),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, E.Cause(err, "marshal request")
|
||||
}
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodPost, cloudflareOriginCAEndpoint, bytes.NewReader(requestBody))
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, E.Cause(err, "create request")
|
||||
}
|
||||
request.Header.Set("Accept", "application/json")
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request.Header.Set("User-Agent", "sing-box/"+C.Version)
|
||||
if s.apiToken != "" {
|
||||
request.Header.Set("Authorization", "Bearer "+s.apiToken)
|
||||
} else {
|
||||
request.Header.Set("X-Auth-User-Service-Key", s.originCAKey)
|
||||
}
|
||||
response, err := s.httpClient.Do(request)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, E.Cause(err, "request certificate from Cloudflare")
|
||||
}
|
||||
defer response.Body.Close()
|
||||
responseBody, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, E.Cause(err, "read Cloudflare response")
|
||||
}
|
||||
var responseEnvelope originCAResponse
|
||||
err = json.Unmarshal(responseBody, &responseEnvelope)
|
||||
if err != nil && response.StatusCode >= http.StatusOK && response.StatusCode < http.StatusMultipleChoices {
|
||||
return nil, nil, nil, nil, E.Cause(err, "decode Cloudflare response")
|
||||
}
|
||||
if response.StatusCode < http.StatusOK || response.StatusCode >= http.StatusMultipleChoices {
|
||||
return nil, nil, nil, nil, buildOriginCAError(response.StatusCode, responseEnvelope.Errors, responseBody)
|
||||
}
|
||||
if !responseEnvelope.Success {
|
||||
return nil, nil, nil, nil, buildOriginCAError(response.StatusCode, responseEnvelope.Errors, responseBody)
|
||||
}
|
||||
if responseEnvelope.Result.Certificate == "" {
|
||||
return nil, nil, nil, nil, E.New("Cloudflare Origin CA response is missing certificate data")
|
||||
}
|
||||
certificatePEM := []byte(responseEnvelope.Result.Certificate)
|
||||
tlsCertificate, leaf, err := parseKeyPair(certificatePEM, privateKeyPEM)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, E.Cause(err, "parse issued certificate")
|
||||
}
|
||||
if !s.matchesCertificate(leaf) {
|
||||
return nil, nil, nil, nil, E.New("issued Cloudflare Origin CA certificate does not match requested hostnames or key type")
|
||||
}
|
||||
return certificatePEM, privateKeyPEM, tlsCertificate, leaf, nil
|
||||
}
|
||||
|
||||
func (s *Service) loadCachedCertificate() (*tls.Certificate, *x509.Certificate, error) {
|
||||
certificateResource, err := loadCertificateResource(s.ctx, s.storage, s.storageIssuerKey, s.storageNamesKey)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
return nil, nil, err
|
||||
}
|
||||
tlsCertificate, leaf, err := parseKeyPair(certificateResource.CertificatePEM, certificateResource.PrivateKeyPEM)
|
||||
if err != nil {
|
||||
return nil, nil, E.Cause(err, "parse cached key pair")
|
||||
}
|
||||
if s.timeFunc().After(leaf.NotAfter) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
if !s.matchesCertificate(leaf) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
return tlsCertificate, leaf, nil
|
||||
}
|
||||
|
||||
func (s *Service) matchesCertificate(leaf *x509.Certificate) bool {
|
||||
if leaf == nil {
|
||||
return false
|
||||
}
|
||||
leafHostnames := leaf.DNSNames
|
||||
if len(leafHostnames) == 0 && leaf.Subject.CommonName != "" {
|
||||
leafHostnames = []string{leaf.Subject.CommonName}
|
||||
}
|
||||
normalizedLeafHostnames, err := normalizeHostnames(leafHostnames)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if !slices.Equal(normalizedLeafHostnames, s.domain) {
|
||||
return false
|
||||
}
|
||||
switch s.requestType {
|
||||
case option.CloudflareOriginCARequestTypeOriginRSA:
|
||||
return leaf.PublicKeyAlgorithm == x509.RSA
|
||||
case option.CloudflareOriginCARequestTypeOriginECC:
|
||||
return leaf.PublicKeyAlgorithm == x509.ECDSA
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) setCurrentCertificate(certificate *tls.Certificate, leaf *x509.Certificate) {
|
||||
s.access.Lock()
|
||||
s.currentCertificate = certificate
|
||||
s.currentLeaf = leaf
|
||||
s.access.Unlock()
|
||||
}
|
||||
|
||||
func normalizeHostnames(hostnames []string) ([]string, error) {
|
||||
normalizedHostnames := make([]string, 0, len(hostnames))
|
||||
seen := make(map[string]struct{}, len(hostnames))
|
||||
for _, hostname := range hostnames {
|
||||
normalizedHostname := strings.ToLower(strings.TrimSpace(strings.TrimSuffix(hostname, ".")))
|
||||
if normalizedHostname == "" {
|
||||
return nil, E.New("hostname is empty")
|
||||
}
|
||||
if net.ParseIP(normalizedHostname) != nil {
|
||||
return nil, E.New("hostname cannot be an IP address: ", normalizedHostname)
|
||||
}
|
||||
if strings.Contains(normalizedHostname, "*") {
|
||||
if !strings.HasPrefix(normalizedHostname, "*.") || strings.Count(normalizedHostname, "*") != 1 {
|
||||
return nil, E.New("invalid wildcard hostname: ", normalizedHostname)
|
||||
}
|
||||
suffix := strings.TrimPrefix(normalizedHostname, "*.")
|
||||
if strings.Count(suffix, ".") == 0 {
|
||||
return nil, E.New("wildcard hostname must cover a multi-label domain: ", normalizedHostname)
|
||||
}
|
||||
normalizedHostname = "*." + suffix
|
||||
}
|
||||
if _, loaded := seen[normalizedHostname]; loaded {
|
||||
continue
|
||||
}
|
||||
seen[normalizedHostname] = struct{}{}
|
||||
normalizedHostnames = append(normalizedHostnames, normalizedHostname)
|
||||
}
|
||||
slices.Sort(normalizedHostnames)
|
||||
return normalizedHostnames, nil
|
||||
}
|
||||
|
||||
func parseKeyPair(certificatePEM []byte, privateKeyPEM []byte) (*tls.Certificate, *x509.Certificate, error) {
|
||||
keyPair, err := tls.X509KeyPair(certificatePEM, privateKeyPEM)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if len(keyPair.Certificate) == 0 {
|
||||
return nil, nil, E.New("certificate chain is empty")
|
||||
}
|
||||
leaf, err := x509.ParseCertificate(keyPair.Certificate[0])
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
keyPair.Leaf = leaf
|
||||
return &keyPair, leaf, nil
|
||||
}
|
||||
|
||||
func storeCertificateResource(ctx context.Context, storage certmagic.Storage, issuerKey string, certificateResource certmagic.CertificateResource) error {
|
||||
metaBytes, err := json.MarshalIndent(certificateResource, "", "\t")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
namesKey := certificateResource.NamesKey()
|
||||
keyValueList := []struct {
|
||||
key string
|
||||
value []byte
|
||||
}{
|
||||
{
|
||||
key: certmagic.StorageKeys.SitePrivateKey(issuerKey, namesKey),
|
||||
value: certificateResource.PrivateKeyPEM,
|
||||
},
|
||||
{
|
||||
key: certmagic.StorageKeys.SiteCert(issuerKey, namesKey),
|
||||
value: certificateResource.CertificatePEM,
|
||||
},
|
||||
{
|
||||
key: certmagic.StorageKeys.SiteMeta(issuerKey, namesKey),
|
||||
value: metaBytes,
|
||||
},
|
||||
}
|
||||
for i, item := range keyValueList {
|
||||
err = storage.Store(ctx, item.key, item.value)
|
||||
if err != nil {
|
||||
for j := i - 1; j >= 0; j-- {
|
||||
storage.Delete(ctx, keyValueList[j].key)
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadCertificateResource(ctx context.Context, storage certmagic.Storage, issuerKey string, namesKey string) (certmagic.CertificateResource, error) {
|
||||
privateKeyPEM, err := storage.Load(ctx, certmagic.StorageKeys.SitePrivateKey(issuerKey, namesKey))
|
||||
if err != nil {
|
||||
return certmagic.CertificateResource{}, err
|
||||
}
|
||||
certificatePEM, err := storage.Load(ctx, certmagic.StorageKeys.SiteCert(issuerKey, namesKey))
|
||||
if err != nil {
|
||||
return certmagic.CertificateResource{}, err
|
||||
}
|
||||
metaBytes, err := storage.Load(ctx, certmagic.StorageKeys.SiteMeta(issuerKey, namesKey))
|
||||
if err != nil {
|
||||
return certmagic.CertificateResource{}, err
|
||||
}
|
||||
var certificateResource certmagic.CertificateResource
|
||||
err = json.Unmarshal(metaBytes, &certificateResource)
|
||||
if err != nil {
|
||||
return certmagic.CertificateResource{}, E.Cause(err, "decode Cloudflare Origin CA certificate metadata")
|
||||
}
|
||||
certificateResource.PrivateKeyPEM = privateKeyPEM
|
||||
certificateResource.CertificatePEM = certificatePEM
|
||||
return certificateResource, nil
|
||||
}
|
||||
|
||||
func buildOriginCAError(statusCode int, responseErrors []originCAResponseError, responseBody []byte) error {
|
||||
if len(responseErrors) > 0 {
|
||||
messageList := make([]string, 0, len(responseErrors))
|
||||
for _, responseError := range responseErrors {
|
||||
if responseError.Message == "" {
|
||||
continue
|
||||
}
|
||||
if responseError.Code != 0 {
|
||||
messageList = append(messageList, responseError.Message+" (code "+strconv.Itoa(responseError.Code)+")")
|
||||
} else {
|
||||
messageList = append(messageList, responseError.Message)
|
||||
}
|
||||
}
|
||||
if len(messageList) > 0 {
|
||||
return E.New("Cloudflare Origin CA request failed: HTTP ", statusCode, " ", strings.Join(messageList, ", "))
|
||||
}
|
||||
}
|
||||
responseText := strings.TrimSpace(string(responseBody))
|
||||
if responseText == "" {
|
||||
return E.New("Cloudflare Origin CA request failed: HTTP ", statusCode)
|
||||
}
|
||||
return E.New("Cloudflare Origin CA request failed: HTTP ", statusCode, " ", responseText)
|
||||
}
|
||||
|
||||
type originCARequest struct {
|
||||
CSR string `json:"csr"`
|
||||
Hostnames []string `json:"hostnames"`
|
||||
RequestType string `json:"request_type"`
|
||||
RequestedValidity uint16 `json:"requested_validity"`
|
||||
}
|
||||
|
||||
type originCAResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Errors []originCAResponseError `json:"errors"`
|
||||
Result originCAResponseResult `json:"result"`
|
||||
}
|
||||
|
||||
type originCAResponseError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type originCAResponseResult struct {
|
||||
Certificate string `json:"certificate"`
|
||||
}
|
||||
|
||||
type originCAIssuerData struct {
|
||||
RequestType option.CloudflareOriginCARequestType `json:"request_type,omitempty"`
|
||||
RequestedValidity option.CloudflareOriginCARequestValidity `json:"requested_validity,omitempty"`
|
||||
}
|
||||
Reference in New Issue
Block a user