From ed6be9b0785ec9f8554b9e0582b369c0dcbb23c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 12:50:45 +0800 Subject: [PATCH] Validate cloudflare access protected origins --- go.mod | 2 + go.sum | 2 + protocol/cloudflare/access.go | 104 ++++++++++++++++++++++++++ protocol/cloudflare/access_test.go | 92 +++++++++++++++++++++++ protocol/cloudflare/dispatch.go | 12 +++ protocol/cloudflare/inbound.go | 2 + protocol/cloudflare/runtime_config.go | 3 + 7 files changed, 217 insertions(+) create mode 100644 protocol/cloudflare/access.go create mode 100644 protocol/cloudflare/access_test.go diff --git a/go.mod b/go.mod index 9709176e4..fa945dfce 100644 --- a/go.mod +++ b/go.mod @@ -75,6 +75,7 @@ require ( github.com/andybalholm/brotli v1.1.0 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6 // indirect + github.com/coreos/go-oidc/v3 v3.12.0 // indirect github.com/database64128/netx-go v0.1.1 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa // indirect @@ -84,6 +85,7 @@ require ( github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect github.com/gaissmai/bart v0.18.0 // indirect + github.com/go-jose/go-jose/v4 v4.1.3 // indirect github.com/go-json-experiment/json v0.0.0-20250813024750-ebf49471dced // indirect github.com/go-ole/go-ole v1.3.0 // indirect github.com/gobwas/httphead v0.1.0 // indirect diff --git a/go.sum b/go.sum index c5d7315e8..fe44e4b2b 100644 --- a/go.sum +++ b/go.sum @@ -28,6 +28,8 @@ github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9 github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6 h1:8h5+bWd7R6AYUslN6c6iuZWTKsKxUFDlpnmilO6R2n0= github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= +github.com/coreos/go-oidc/v3 v3.12.0 h1:sJk+8G2qq94rDI6ehZ71Bol3oUHy63qNYmkiSjrc/Jo= +github.com/coreos/go-oidc/v3 v3.12.0/go.mod h1:gE3LgjOgFoHi9a4ce4/tJczr0Ai2/BoDhf0r5lltWI0= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/cretz/bine v0.2.0 h1:8GiDRGlTgz+o8H9DSnsl+5MeBK4HsExxgl6WgzOCuZo= github.com/cretz/bine v0.2.0/go.mod h1:WU4o9QR9wWp8AVKtTM1XD5vUHkEqnf2vVSo6dBqbetI= diff --git a/protocol/cloudflare/access.go b/protocol/cloudflare/access.go new file mode 100644 index 000000000..9407d0312 --- /dev/null +++ b/protocol/cloudflare/access.go @@ -0,0 +1,104 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "context" + "fmt" + "net/http" + "strings" + "sync" + + "github.com/coreos/go-oidc/v3/oidc" + E "github.com/sagernet/sing/common/exceptions" +) + +const accessJWTAssertionHeader = "Cf-Access-Jwt-Assertion" + +var newAccessValidator = func(access AccessConfig) (accessValidator, error) { + issuerURL := accessIssuerURL(access.TeamName, access.Environment) + keySet := oidc.NewRemoteKeySet(context.Background(), issuerURL+"/cdn-cgi/access/certs") + verifier := oidc.NewVerifier(issuerURL, keySet, &oidc.Config{ + SkipClientIDCheck: true, + }) + return &oidcAccessValidator{ + verifier: verifier, + audTags: append([]string(nil), access.AudTag...), + }, nil +} + +type accessValidator interface { + Validate(ctx context.Context, request *http.Request) error +} + +type oidcAccessValidator struct { + verifier *oidc.IDTokenVerifier + audTags []string +} + +func (v *oidcAccessValidator) Validate(ctx context.Context, request *http.Request) error { + accessJWT := request.Header.Get(accessJWTAssertionHeader) + if accessJWT == "" { + return E.New("missing access jwt assertion") + } + token, err := v.verifier.Verify(ctx, accessJWT) + if err != nil { + return err + } + if len(v.audTags) == 0 { + return nil + } + for _, jwtAudTag := range token.Audience { + for _, acceptedAudTag := range v.audTags { + if acceptedAudTag == jwtAudTag { + return nil + } + } + } + return E.New("access token audience does not match configured aud_tag") +} + +func accessIssuerURL(teamName string, environment string) string { + if strings.EqualFold(environment, "fed") || strings.EqualFold(environment, "fips") { + return fmt.Sprintf("https://%s.fed.cloudflareaccess.com", teamName) + } + return fmt.Sprintf("https://%s.cloudflareaccess.com", teamName) +} + +func validateAccessConfiguration(access AccessConfig) error { + if !access.Required { + return nil + } + if access.TeamName == "" && len(access.AudTag) > 0 { + return E.New("access.team_name cannot be blank when access.aud_tag is present") + } + return nil +} + +func accessValidatorKey(access AccessConfig) string { + return access.TeamName + "|" + access.Environment + "|" + strings.Join(access.AudTag, ",") +} + +type accessValidatorCache struct { + access sync.RWMutex + values map[string]accessValidator +} + +func (c *accessValidatorCache) Get(accessConfig AccessConfig) (accessValidator, error) { + key := accessValidatorKey(accessConfig) + c.access.RLock() + validator, loaded := c.values[key] + c.access.RUnlock() + if loaded { + return validator, nil + } + + validator, err := newAccessValidator(accessConfig) + if err != nil { + return nil, err + } + c.access.Lock() + c.values[key] = validator + c.access.Unlock() + return validator, nil +} diff --git a/protocol/cloudflare/access_test.go b/protocol/cloudflare/access_test.go new file mode 100644 index 000000000..594f94d77 --- /dev/null +++ b/protocol/cloudflare/access_test.go @@ -0,0 +1,92 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "context" + "io" + "net" + "net/http" + "testing" + + "github.com/sagernet/sing-box/adapter/inbound" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" +) + +type fakeAccessValidator struct { + err error +} + +func (v *fakeAccessValidator) Validate(ctx context.Context, request *http.Request) error { + return v.err +} + +func newAccessTestInbound(t *testing.T) *Inbound { + t.Helper() + logFactory, err := log.New(log.Options{Options: option.LogOptions{Level: "debug"}}) + if err != nil { + t.Fatal(err) + } + return &Inbound{ + Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"), + logger: logFactory.NewLogger("test"), + accessCache: &accessValidatorCache{values: make(map[string]accessValidator)}, + router: &testRouter{}, + } +} + +func TestValidateAccessConfiguration(t *testing.T) { + err := validateAccessConfiguration(AccessConfig{ + Required: true, + AudTag: []string{"aud"}, + }) + if err == nil { + t.Fatal("expected access config validation error") + } +} + +func TestRoundTripHTTPAccessDenied(t *testing.T) { + originalFactory := newAccessValidator + defer func() { + newAccessValidator = originalFactory + }() + newAccessValidator = func(access AccessConfig) (accessValidator, error) { + return &fakeAccessValidator{err: E.New("forbidden")}, nil + } + + inboundInstance := newAccessTestInbound(t) + service := ResolvedService{ + Kind: ResolvedServiceHTTP, + OriginRequest: OriginRequestConfig{ + Access: AccessConfig{ + Required: true, + TeamName: "team", + }, + }, + } + serverSide, clientSide := net.Pipe() + defer serverSide.Close() + defer clientSide.Close() + + respWriter := &fakeConnectResponseWriter{} + request := &ConnectRequest{ + Type: ConnectionTypeHTTP, + Dest: "http://127.0.0.1:8083", + Metadata: []Metadata{ + {Key: metadataHTTPMethod, Val: http.MethodGet}, + {Key: metadataHTTPHost, Val: "example.com"}, + }, + } + go func() { + defer clientSide.Close() + _, _ = io.Copy(io.Discard, clientSide) + }() + + inboundInstance.roundTripHTTP(context.Background(), serverSide, respWriter, request, service, &http.Transport{}) + if respWriter.status != http.StatusForbidden { + t.Fatalf("expected 403, got %d", respWriter.status) + } +} diff --git a/protocol/cloudflare/dispatch.go b/protocol/cloudflare/dispatch.go index 2dc2c73e7..7cf664aca 100644 --- a/protocol/cloudflare/dispatch.go +++ b/protocol/cloudflare/dispatch.go @@ -321,6 +321,18 @@ func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser, defer cancel() httpRequest = httpRequest.WithContext(requestCtx) } + if service.OriginRequest.Access.Required { + validator, err := i.accessCache.Get(service.OriginRequest.Access) + if err != nil { + i.logger.ErrorContext(ctx, "create access validator: ", err) + respWriter.WriteResponse(err, nil) + return + } + if err := validator.Validate(requestCtx, httpRequest); err != nil { + respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusForbidden, http.Header{})) + return + } + } httpClient := &http.Client{ Transport: transport, diff --git a/protocol/cloudflare/inbound.go b/protocol/cloudflare/inbound.go index 6d31eaa05..e7cbdb4b6 100644 --- a/protocol/cloudflare/inbound.go +++ b/protocol/cloudflare/inbound.go @@ -45,6 +45,7 @@ type Inbound struct { gracePeriod time.Duration configManager *ConfigManager flowLimiter *FlowLimiter + accessCache *accessValidatorCache connectionAccess sync.Mutex connections []io.Closer @@ -121,6 +122,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo gracePeriod: gracePeriod, configManager: configManager, flowLimiter: &FlowLimiter{}, + accessCache: &accessValidatorCache{values: make(map[string]accessValidator)}, datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer), datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer), }, nil diff --git a/protocol/cloudflare/runtime_config.go b/protocol/cloudflare/runtime_config.go index 276e99d41..c35c5505c 100644 --- a/protocol/cloudflare/runtime_config.go +++ b/protocol/cloudflare/runtime_config.go @@ -363,6 +363,9 @@ func compileIngressRules(defaultOriginRequest OriginRequestConfig, rawRules []lo if err := validateHostname(rule.Hostname, index == len(rawRules)-1); err != nil { return nil, err } + if err := validateAccessConfiguration(rule.OriginRequest.Access); err != nil { + return nil, err + } service, err := parseResolvedService(rule.Service, rule.OriginRequest) if err != nil { return nil, err