Files
sing-box/protocol/cloudflare/access_test.go
2026-03-31 15:32:56 +08:00

93 lines
2.3 KiB
Go

//go:build with_cloudflare_tunnel
package cloudflare
import (
"context"
"io"
"net"
"net/http"
"testing"
"github.com/sagernet/sing-box/adapter/inbound"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
)
// fakeAccessValidator is a test double for accessValidator. Its Validate
// ignores the context and request entirely and always reports err.
type fakeAccessValidator struct {
	// err is returned verbatim by every Validate call; a nil err makes
	// Validate succeed.
	err error
}

// Validate returns the preconfigured error without inspecting its arguments.
func (f *fakeAccessValidator) Validate(_ context.Context, _ *http.Request) error {
	return f.err
}
// newAccessTestInbound builds a minimal Inbound for access-control tests.
// The Inbound has a debug-level logger, an empty access validator cache and
// a testRouter. If the log factory cannot be created, the test is failed
// immediately.
func newAccessTestInbound(t *testing.T) *Inbound {
	t.Helper()
	factory, err := log.New(log.Options{
		Options: option.LogOptions{Level: "debug"},
	})
	if err != nil {
		t.Fatal(err)
	}
	// Start with no cached validators so each test controls creation
	// through newAccessValidator.
	cache := &accessValidatorCache{
		values: make(map[string]accessValidator),
	}
	return &Inbound{
		Adapter:     inbound.NewAdapter(C.TypeCloudflareTunnel, "test"),
		logger:      factory.NewLogger("test"),
		accessCache: cache,
		router:      &testRouter{},
	}
}
func TestValidateAccessConfiguration(t *testing.T) {
err := validateAccessConfiguration(AccessConfig{
Required: true,
AudTag: []string{"aud"},
})
if err == nil {
t.Fatal("expected access config validation error")
}
}
// TestRoundTripHTTPAccessDenied verifies that roundTripHTTP answers with
// 403 Forbidden when the service requires access validation and the
// validator rejects the request.
//
// newAccessValidator is swapped for a factory whose validator always fails,
// so the real validator is never exercised. The factory also counts its
// invocations. This ensures a 403 produced by some unrelated code path
// cannot make the test pass by accident.
func TestRoundTripHTTPAccessDenied(t *testing.T) {
	originalFactory := newAccessValidator
	defer func() {
		newAccessValidator = originalFactory
	}()
	factoryCalls := 0
	newAccessValidator = func(AccessConfig) (accessValidator, error) {
		factoryCalls++
		return &fakeAccessValidator{err: E.New("forbidden")}, nil
	}

	inboundInstance := newAccessTestInbound(t)
	service := ResolvedService{
		Kind: ResolvedServiceHTTP,
		OriginRequest: OriginRequestConfig{
			Access: AccessConfig{
				Required: true,
				TeamName: "team",
			},
		},
	}

	// net.Pipe is synchronous: a write on serverSide blocks until the peer
	// reads it. The client end is therefore drained in the background.
	serverSide, clientSide := net.Pipe()
	drained := make(chan struct{})
	go func() {
		defer close(drained)
		defer clientSide.Close()
		_, _ = io.Copy(io.Discard, clientSide)
	}()
	// Closing both ends unblocks the drain. Waiting on it afterwards keeps
	// the goroutine from outliving the test.
	defer func() {
		_ = serverSide.Close()
		_ = clientSide.Close()
		<-drained
	}()

	respWriter := &fakeConnectResponseWriter{}
	request := &ConnectRequest{
		Type: ConnectionTypeHTTP,
		Dest: "http://127.0.0.1:8083",
		Metadata: []Metadata{
			{Key: metadataHTTPMethod, Val: http.MethodGet},
			{Key: metadataHTTPHost, Val: "example.com"},
		},
	}
	transport := &http.Transport{}
	defer transport.CloseIdleConnections()

	inboundInstance.roundTripHTTP(context.Background(), serverSide, respWriter, request, service, transport)

	if respWriter.status != http.StatusForbidden {
		t.Fatalf("expected 403, got %d", respWriter.status)
	}
	if factoryCalls == 0 {
		t.Fatal("expected the injected access validator factory to be used")
	}
}