Validate cloudflare access protected origins

This commit is contained in:
世界
2026-03-24 12:50:45 +08:00
parent 854718992f
commit ed6be9b078
7 changed files with 217 additions and 0 deletions

2
go.mod
View File

@@ -75,6 +75,7 @@ require (
github.com/andybalholm/brotli v1.1.0 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6 // indirect
github.com/coreos/go-oidc/v3 v3.12.0 // indirect
github.com/database64128/netx-go v0.1.1 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa // indirect
@@ -84,6 +85,7 @@ require (
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/fxamacker/cbor/v2 v2.7.0 // indirect
github.com/gaissmai/bart v0.18.0 // indirect
github.com/go-jose/go-jose/v4 v4.1.3 // indirect
github.com/go-json-experiment/json v0.0.0-20250813024750-ebf49471dced // indirect
github.com/go-ole/go-ole v1.3.0 // indirect
github.com/gobwas/httphead v0.1.0 // indirect

2
go.sum
View File

@@ -28,6 +28,8 @@ github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6 h1:8h5+bWd7R6AYUslN6c6iuZWTKsKxUFDlpnmilO6R2n0=
github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q=
github.com/coreos/go-oidc/v3 v3.12.0 h1:sJk+8G2qq94rDI6ehZ71Bol3oUHy63qNYmkiSjrc/Jo=
github.com/coreos/go-oidc/v3 v3.12.0/go.mod h1:gE3LgjOgFoHi9a4ce4/tJczr0Ai2/BoDhf0r5lltWI0=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/cretz/bine v0.2.0 h1:8GiDRGlTgz+o8H9DSnsl+5MeBK4HsExxgl6WgzOCuZo=
github.com/cretz/bine v0.2.0/go.mod h1:WU4o9QR9wWp8AVKtTM1XD5vUHkEqnf2vVSo6dBqbetI=

View File

@@ -0,0 +1,104 @@
//go:build with_cloudflare_tunnel
package cloudflare
import (
"context"
"fmt"
"net/http"
"strings"
"sync"
"github.com/coreos/go-oidc/v3/oidc"
E "github.com/sagernet/sing/common/exceptions"
)
const accessJWTAssertionHeader = "Cf-Access-Jwt-Assertion"
var newAccessValidator = func(access AccessConfig) (accessValidator, error) {
issuerURL := accessIssuerURL(access.TeamName, access.Environment)
keySet := oidc.NewRemoteKeySet(context.Background(), issuerURL+"/cdn-cgi/access/certs")
verifier := oidc.NewVerifier(issuerURL, keySet, &oidc.Config{
SkipClientIDCheck: true,
})
return &oidcAccessValidator{
verifier: verifier,
audTags: append([]string(nil), access.AudTag...),
}, nil
}
type accessValidator interface {
Validate(ctx context.Context, request *http.Request) error
}
type oidcAccessValidator struct {
verifier *oidc.IDTokenVerifier
audTags []string
}
func (v *oidcAccessValidator) Validate(ctx context.Context, request *http.Request) error {
accessJWT := request.Header.Get(accessJWTAssertionHeader)
if accessJWT == "" {
return E.New("missing access jwt assertion")
}
token, err := v.verifier.Verify(ctx, accessJWT)
if err != nil {
return err
}
if len(v.audTags) == 0 {
return nil
}
for _, jwtAudTag := range token.Audience {
for _, acceptedAudTag := range v.audTags {
if acceptedAudTag == jwtAudTag {
return nil
}
}
}
return E.New("access token audience does not match configured aud_tag")
}
func accessIssuerURL(teamName string, environment string) string {
if strings.EqualFold(environment, "fed") || strings.EqualFold(environment, "fips") {
return fmt.Sprintf("https://%s.fed.cloudflareaccess.com", teamName)
}
return fmt.Sprintf("https://%s.cloudflareaccess.com", teamName)
}
func validateAccessConfiguration(access AccessConfig) error {
if !access.Required {
return nil
}
if access.TeamName == "" && len(access.AudTag) > 0 {
return E.New("access.team_name cannot be blank when access.aud_tag is present")
}
return nil
}
func accessValidatorKey(access AccessConfig) string {
return access.TeamName + "|" + access.Environment + "|" + strings.Join(access.AudTag, ",")
}
type accessValidatorCache struct {
access sync.RWMutex
values map[string]accessValidator
}
func (c *accessValidatorCache) Get(accessConfig AccessConfig) (accessValidator, error) {
key := accessValidatorKey(accessConfig)
c.access.RLock()
validator, loaded := c.values[key]
c.access.RUnlock()
if loaded {
return validator, nil
}
validator, err := newAccessValidator(accessConfig)
if err != nil {
return nil, err
}
c.access.Lock()
c.values[key] = validator
c.access.Unlock()
return validator, nil
}

View File

@@ -0,0 +1,92 @@
//go:build with_cloudflare_tunnel
package cloudflare
import (
"context"
"io"
"net"
"net/http"
"testing"
"github.com/sagernet/sing-box/adapter/inbound"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
)
type fakeAccessValidator struct {
err error
}
func (v *fakeAccessValidator) Validate(ctx context.Context, request *http.Request) error {
return v.err
}
func newAccessTestInbound(t *testing.T) *Inbound {
t.Helper()
logFactory, err := log.New(log.Options{Options: option.LogOptions{Level: "debug"}})
if err != nil {
t.Fatal(err)
}
return &Inbound{
Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"),
logger: logFactory.NewLogger("test"),
accessCache: &accessValidatorCache{values: make(map[string]accessValidator)},
router: &testRouter{},
}
}
func TestValidateAccessConfiguration(t *testing.T) {
err := validateAccessConfiguration(AccessConfig{
Required: true,
AudTag: []string{"aud"},
})
if err == nil {
t.Fatal("expected access config validation error")
}
}
func TestRoundTripHTTPAccessDenied(t *testing.T) {
originalFactory := newAccessValidator
defer func() {
newAccessValidator = originalFactory
}()
newAccessValidator = func(access AccessConfig) (accessValidator, error) {
return &fakeAccessValidator{err: E.New("forbidden")}, nil
}
inboundInstance := newAccessTestInbound(t)
service := ResolvedService{
Kind: ResolvedServiceHTTP,
OriginRequest: OriginRequestConfig{
Access: AccessConfig{
Required: true,
TeamName: "team",
},
},
}
serverSide, clientSide := net.Pipe()
defer serverSide.Close()
defer clientSide.Close()
respWriter := &fakeConnectResponseWriter{}
request := &ConnectRequest{
Type: ConnectionTypeHTTP,
Dest: "http://127.0.0.1:8083",
Metadata: []Metadata{
{Key: metadataHTTPMethod, Val: http.MethodGet},
{Key: metadataHTTPHost, Val: "example.com"},
},
}
go func() {
defer clientSide.Close()
_, _ = io.Copy(io.Discard, clientSide)
}()
inboundInstance.roundTripHTTP(context.Background(), serverSide, respWriter, request, service, &http.Transport{})
if respWriter.status != http.StatusForbidden {
t.Fatalf("expected 403, got %d", respWriter.status)
}
}

View File

@@ -321,6 +321,18 @@ func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser,
defer cancel()
httpRequest = httpRequest.WithContext(requestCtx)
}
if service.OriginRequest.Access.Required {
validator, err := i.accessCache.Get(service.OriginRequest.Access)
if err != nil {
i.logger.ErrorContext(ctx, "create access validator: ", err)
respWriter.WriteResponse(err, nil)
return
}
if err := validator.Validate(requestCtx, httpRequest); err != nil {
respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusForbidden, http.Header{}))
return
}
}
httpClient := &http.Client{
Transport: transport,

View File

@@ -45,6 +45,7 @@ type Inbound struct {
gracePeriod time.Duration
configManager *ConfigManager
flowLimiter *FlowLimiter
accessCache *accessValidatorCache
connectionAccess sync.Mutex
connections []io.Closer
@@ -121,6 +122,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
gracePeriod: gracePeriod,
configManager: configManager,
flowLimiter: &FlowLimiter{},
accessCache: &accessValidatorCache{values: make(map[string]accessValidator)},
datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer),
datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer),
}, nil

View File

@@ -363,6 +363,9 @@ func compileIngressRules(defaultOriginRequest OriginRequestConfig, rawRules []lo
if err := validateHostname(rule.Hostname, index == len(rawRules)-1); err != nil {
return nil, err
}
if err := validateAccessConfiguration(rule.OriginRequest.Access); err != nil {
return nil, err
}
service, err := parseResolvedService(rule.Service, rule.OriginRequest)
if err != nil {
return nil, err