Files
sing-box/protocol/cloudflare/access_test.go

155 lines
4.5 KiB
Go

//go:build with_cloudflare_tunnel
package cloudflare
import (
"context"
"net/http"
"testing"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
// fakeAccessValidator is a test double whose Validate result is fixed up
// front through the err field.
type fakeAccessValidator struct {
	// err is handed back unchanged by every Validate call; nil means "accept".
	err error
}

// Validate ignores both the context and the request and reports the
// preconfigured error.
func (f *fakeAccessValidator) Validate(_ context.Context, _ *http.Request) error {
	return f.err
}
// newAccessTestInbound assembles an Inbound for the access tests in this file.
//
// The instance gets a debug-level logger, an empty access validator cache, a
// testRouter, and N.SystemDialer for both the cache and the control dialer.
// The test is aborted via t.Fatal when the log factory cannot be created.
func newAccessTestInbound(t *testing.T) *Inbound {
	t.Helper()

	logOptions := log.Options{
		Options: option.LogOptions{Level: "debug"},
	}
	factory, err := log.New(logOptions)
	if err != nil {
		t.Fatal(err)
	}

	// Start from an empty cache so no validator is preloaded for any key.
	cache := &accessValidatorCache{
		values: make(map[string]accessValidator),
		dialer: N.SystemDialer,
	}

	return &Inbound{
		Adapter:       inbound.NewAdapter(C.TypeCloudflareTunnel, "test"),
		logger:        factory.NewLogger("test"),
		accessCache:   cache,
		router:        &testRouter{},
		controlDialer: N.SystemDialer,
	}
}
func TestValidateAccessConfiguration(t *testing.T) {
err := validateAccessConfiguration(AccessConfig{
Required: true,
AudTag: []string{"aud"},
})
if err == nil {
t.Fatal("expected access config validation error")
}
}
// TestRoundTripHTTPAccessDenied drives handleHTTPService with an HTTP-kind
// service that requires access, while the validator factory is replaced by one
// producing validators that always fail, and expects a 403 on the response
// writer.
func TestRoundTripHTTPAccessDenied(t *testing.T) {
	// Replace the package-level factory for the duration of this test only.
	savedFactory := newAccessValidator
	defer func() { newAccessValidator = savedFactory }()
	newAccessValidator = func(AccessConfig, N.Dialer) (accessValidator, error) {
		return &fakeAccessValidator{err: E.New("forbidden")}, nil
	}

	instance := newAccessTestInbound(t)
	writer := &fakeConnectResponseWriter{}

	headers := []Metadata{
		{Key: metadataHTTPMethod, Val: http.MethodGet},
		{Key: metadataHTTPHost, Val: "example.com"},
	}
	request := &ConnectRequest{
		Type:     ConnectionTypeHTTP,
		Dest:     "http://127.0.0.1:8083/test",
		Metadata: headers,
	}

	access := AccessConfig{
		Required: true,
		TeamName: "team",
	}
	service := ResolvedService{
		Kind:          ResolvedServiceHTTP,
		Destination:   M.ParseSocksaddr("127.0.0.1:8083"),
		OriginRequest: OriginRequestConfig{Access: access},
	}

	instance.handleHTTPService(context.Background(), nil, writer, request, adapter.InboundContext{}, service)

	if got := writer.status; got != http.StatusForbidden {
		t.Fatalf("expected 403, got %d", got)
	}
}
// TestHandleHTTPServiceStatusAccessDenied drives handleHTTPService with a
// status-kind service (configured StatusCode 404) that requires access, while
// the validator factory yields validators that always fail, and expects the
// response writer to record 403 rather than the configured status.
func TestHandleHTTPServiceStatusAccessDenied(t *testing.T) {
	// Replace the package-level factory for the duration of this test only.
	savedFactory := newAccessValidator
	defer func() { newAccessValidator = savedFactory }()
	newAccessValidator = func(AccessConfig, N.Dialer) (accessValidator, error) {
		return &fakeAccessValidator{err: E.New("forbidden")}, nil
	}

	instance := newAccessTestInbound(t)
	writer := &fakeConnectResponseWriter{}

	headers := []Metadata{
		{Key: metadataHTTPMethod, Val: http.MethodGet},
		{Key: metadataHTTPHost, Val: "example.com"},
	}
	request := &ConnectRequest{
		Type:     ConnectionTypeHTTP,
		Dest:     "https://example.com/status",
		Metadata: headers,
	}

	access := AccessConfig{
		Required: true,
		TeamName: "team",
	}
	service := ResolvedService{
		Kind:          ResolvedServiceStatus,
		OriginRequest: OriginRequestConfig{Access: access},
		StatusCode:    404,
	}

	instance.handleHTTPService(context.Background(), nil, writer, request, adapter.InboundContext{}, service)

	if got := writer.status; got != http.StatusForbidden {
		t.Fatalf("expected 403, got %d", got)
	}
}
// TestHandleHTTPServiceStreamAccessDenied drives handleHTTPService with a
// websocket-type request against a stream-kind service that requires access,
// while the validator factory yields validators that always fail, and expects
// a 403 on the response writer.
func TestHandleHTTPServiceStreamAccessDenied(t *testing.T) {
	// Replace the package-level factory for the duration of this test only.
	savedFactory := newAccessValidator
	defer func() { newAccessValidator = savedFactory }()
	newAccessValidator = func(AccessConfig, N.Dialer) (accessValidator, error) {
		return &fakeAccessValidator{err: E.New("forbidden")}, nil
	}

	instance := newAccessTestInbound(t)
	writer := &fakeConnectResponseWriter{}

	// The request carries a Sec-WebSocket-Key header alongside method and host.
	headers := []Metadata{
		{Key: metadataHTTPMethod, Val: http.MethodGet},
		{Key: metadataHTTPHost, Val: "example.com"},
		{Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="},
	}
	request := &ConnectRequest{
		Type:     ConnectionTypeWebsocket,
		Dest:     "https://example.com/ws",
		Metadata: headers,
	}

	access := AccessConfig{
		Required: true,
		TeamName: "team",
	}
	service := ResolvedService{
		Kind:          ResolvedServiceStream,
		Destination:   M.ParseSocksaddr("127.0.0.1:8080"),
		OriginRequest: OriginRequestConfig{Access: access},
	}

	instance.handleHTTPService(context.Background(), nil, writer, request, adapter.InboundContext{}, service)

	if got := writer.status; got != http.StatusForbidden {
		t.Fatalf("expected 403, got %d", got)
	}
}