mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-14 04:38:28 +10:00
Validate cloudflare access protected origins
This commit is contained in:
2
go.mod
2
go.mod
@@ -75,6 +75,7 @@ require (
|
||||
github.com/andybalholm/brotli v1.1.0 // indirect
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||
github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6 // indirect
|
||||
github.com/coreos/go-oidc/v3 v3.12.0 // indirect
|
||||
github.com/database64128/netx-go v0.1.1 // indirect
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa // indirect
|
||||
@@ -84,6 +85,7 @@ require (
|
||||
github.com/fsnotify/fsnotify v1.7.0 // indirect
|
||||
github.com/fxamacker/cbor/v2 v2.7.0 // indirect
|
||||
github.com/gaissmai/bart v0.18.0 // indirect
|
||||
github.com/go-jose/go-jose/v4 v4.1.3 // indirect
|
||||
github.com/go-json-experiment/json v0.0.0-20250813024750-ebf49471dced // indirect
|
||||
github.com/go-ole/go-ole v1.3.0 // indirect
|
||||
github.com/gobwas/httphead v0.1.0 // indirect
|
||||
|
||||
2
go.sum
2
go.sum
@@ -28,6 +28,8 @@ github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9
|
||||
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
|
||||
github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6 h1:8h5+bWd7R6AYUslN6c6iuZWTKsKxUFDlpnmilO6R2n0=
|
||||
github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q=
|
||||
github.com/coreos/go-oidc/v3 v3.12.0 h1:sJk+8G2qq94rDI6ehZ71Bol3oUHy63qNYmkiSjrc/Jo=
|
||||
github.com/coreos/go-oidc/v3 v3.12.0/go.mod h1:gE3LgjOgFoHi9a4ce4/tJczr0Ai2/BoDhf0r5lltWI0=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
github.com/cretz/bine v0.2.0 h1:8GiDRGlTgz+o8H9DSnsl+5MeBK4HsExxgl6WgzOCuZo=
|
||||
github.com/cretz/bine v0.2.0/go.mod h1:WU4o9QR9wWp8AVKtTM1XD5vUHkEqnf2vVSo6dBqbetI=
|
||||
|
||||
104
protocol/cloudflare/access.go
Normal file
104
protocol/cloudflare/access.go
Normal file
@@ -0,0 +1,104 @@
|
||||
//go:build with_cloudflare_tunnel
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
const accessJWTAssertionHeader = "Cf-Access-Jwt-Assertion"
|
||||
|
||||
var newAccessValidator = func(access AccessConfig) (accessValidator, error) {
|
||||
issuerURL := accessIssuerURL(access.TeamName, access.Environment)
|
||||
keySet := oidc.NewRemoteKeySet(context.Background(), issuerURL+"/cdn-cgi/access/certs")
|
||||
verifier := oidc.NewVerifier(issuerURL, keySet, &oidc.Config{
|
||||
SkipClientIDCheck: true,
|
||||
})
|
||||
return &oidcAccessValidator{
|
||||
verifier: verifier,
|
||||
audTags: append([]string(nil), access.AudTag...),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type accessValidator interface {
|
||||
Validate(ctx context.Context, request *http.Request) error
|
||||
}
|
||||
|
||||
type oidcAccessValidator struct {
|
||||
verifier *oidc.IDTokenVerifier
|
||||
audTags []string
|
||||
}
|
||||
|
||||
func (v *oidcAccessValidator) Validate(ctx context.Context, request *http.Request) error {
|
||||
accessJWT := request.Header.Get(accessJWTAssertionHeader)
|
||||
if accessJWT == "" {
|
||||
return E.New("missing access jwt assertion")
|
||||
}
|
||||
token, err := v.verifier.Verify(ctx, accessJWT)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(v.audTags) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, jwtAudTag := range token.Audience {
|
||||
for _, acceptedAudTag := range v.audTags {
|
||||
if acceptedAudTag == jwtAudTag {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return E.New("access token audience does not match configured aud_tag")
|
||||
}
|
||||
|
||||
func accessIssuerURL(teamName string, environment string) string {
|
||||
if strings.EqualFold(environment, "fed") || strings.EqualFold(environment, "fips") {
|
||||
return fmt.Sprintf("https://%s.fed.cloudflareaccess.com", teamName)
|
||||
}
|
||||
return fmt.Sprintf("https://%s.cloudflareaccess.com", teamName)
|
||||
}
|
||||
|
||||
func validateAccessConfiguration(access AccessConfig) error {
|
||||
if !access.Required {
|
||||
return nil
|
||||
}
|
||||
if access.TeamName == "" && len(access.AudTag) > 0 {
|
||||
return E.New("access.team_name cannot be blank when access.aud_tag is present")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func accessValidatorKey(access AccessConfig) string {
|
||||
return access.TeamName + "|" + access.Environment + "|" + strings.Join(access.AudTag, ",")
|
||||
}
|
||||
|
||||
type accessValidatorCache struct {
|
||||
access sync.RWMutex
|
||||
values map[string]accessValidator
|
||||
}
|
||||
|
||||
func (c *accessValidatorCache) Get(accessConfig AccessConfig) (accessValidator, error) {
|
||||
key := accessValidatorKey(accessConfig)
|
||||
c.access.RLock()
|
||||
validator, loaded := c.values[key]
|
||||
c.access.RUnlock()
|
||||
if loaded {
|
||||
return validator, nil
|
||||
}
|
||||
|
||||
validator, err := newAccessValidator(accessConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.access.Lock()
|
||||
c.values[key] = validator
|
||||
c.access.Unlock()
|
||||
return validator, nil
|
||||
}
|
||||
92
protocol/cloudflare/access_test.go
Normal file
92
protocol/cloudflare/access_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
//go:build with_cloudflare_tunnel
|
||||
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter/inbound"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
type fakeAccessValidator struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (v *fakeAccessValidator) Validate(ctx context.Context, request *http.Request) error {
|
||||
return v.err
|
||||
}
|
||||
|
||||
func newAccessTestInbound(t *testing.T) *Inbound {
|
||||
t.Helper()
|
||||
logFactory, err := log.New(log.Options{Options: option.LogOptions{Level: "debug"}})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return &Inbound{
|
||||
Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"),
|
||||
logger: logFactory.NewLogger("test"),
|
||||
accessCache: &accessValidatorCache{values: make(map[string]accessValidator)},
|
||||
router: &testRouter{},
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAccessConfiguration(t *testing.T) {
|
||||
err := validateAccessConfiguration(AccessConfig{
|
||||
Required: true,
|
||||
AudTag: []string{"aud"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected access config validation error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundTripHTTPAccessDenied(t *testing.T) {
|
||||
originalFactory := newAccessValidator
|
||||
defer func() {
|
||||
newAccessValidator = originalFactory
|
||||
}()
|
||||
newAccessValidator = func(access AccessConfig) (accessValidator, error) {
|
||||
return &fakeAccessValidator{err: E.New("forbidden")}, nil
|
||||
}
|
||||
|
||||
inboundInstance := newAccessTestInbound(t)
|
||||
service := ResolvedService{
|
||||
Kind: ResolvedServiceHTTP,
|
||||
OriginRequest: OriginRequestConfig{
|
||||
Access: AccessConfig{
|
||||
Required: true,
|
||||
TeamName: "team",
|
||||
},
|
||||
},
|
||||
}
|
||||
serverSide, clientSide := net.Pipe()
|
||||
defer serverSide.Close()
|
||||
defer clientSide.Close()
|
||||
|
||||
respWriter := &fakeConnectResponseWriter{}
|
||||
request := &ConnectRequest{
|
||||
Type: ConnectionTypeHTTP,
|
||||
Dest: "http://127.0.0.1:8083",
|
||||
Metadata: []Metadata{
|
||||
{Key: metadataHTTPMethod, Val: http.MethodGet},
|
||||
{Key: metadataHTTPHost, Val: "example.com"},
|
||||
},
|
||||
}
|
||||
go func() {
|
||||
defer clientSide.Close()
|
||||
_, _ = io.Copy(io.Discard, clientSide)
|
||||
}()
|
||||
|
||||
inboundInstance.roundTripHTTP(context.Background(), serverSide, respWriter, request, service, &http.Transport{})
|
||||
if respWriter.status != http.StatusForbidden {
|
||||
t.Fatalf("expected 403, got %d", respWriter.status)
|
||||
}
|
||||
}
|
||||
@@ -321,6 +321,18 @@ func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser,
|
||||
defer cancel()
|
||||
httpRequest = httpRequest.WithContext(requestCtx)
|
||||
}
|
||||
if service.OriginRequest.Access.Required {
|
||||
validator, err := i.accessCache.Get(service.OriginRequest.Access)
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "create access validator: ", err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
if err := validator.Validate(requestCtx, httpRequest); err != nil {
|
||||
respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusForbidden, http.Header{}))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: transport,
|
||||
|
||||
@@ -45,6 +45,7 @@ type Inbound struct {
|
||||
gracePeriod time.Duration
|
||||
configManager *ConfigManager
|
||||
flowLimiter *FlowLimiter
|
||||
accessCache *accessValidatorCache
|
||||
|
||||
connectionAccess sync.Mutex
|
||||
connections []io.Closer
|
||||
@@ -121,6 +122,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
||||
gracePeriod: gracePeriod,
|
||||
configManager: configManager,
|
||||
flowLimiter: &FlowLimiter{},
|
||||
accessCache: &accessValidatorCache{values: make(map[string]accessValidator)},
|
||||
datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer),
|
||||
datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer),
|
||||
}, nil
|
||||
|
||||
@@ -363,6 +363,9 @@ func compileIngressRules(defaultOriginRequest OriginRequestConfig, rawRules []lo
|
||||
if err := validateHostname(rule.Hostname, index == len(rawRules)-1); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validateAccessConfiguration(rule.OriginRequest.Access); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
service, err := parseResolvedService(rule.Service, rule.OriginRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
Reference in New Issue
Block a user