Enforce cloudflare access on all ingress services

This commit is contained in:
世界
2026-03-24 12:56:05 +08:00
parent 1ea083cd6f
commit 289101fc56
2 changed files with 103 additions and 27 deletions

View File

@@ -4,16 +4,16 @@ package cloudflare
import (
"context"
"io"
"net"
"net/http"
"testing"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
)
type fakeAccessValidator struct {
@@ -58,34 +58,94 @@ func TestRoundTripHTTPAccessDenied(t *testing.T) {
}
inboundInstance := newAccessTestInbound(t)
service := ResolvedService{
Kind: ResolvedServiceHTTP,
respWriter := &fakeConnectResponseWriter{}
request := &ConnectRequest{
Type: ConnectionTypeHTTP,
Dest: "http://127.0.0.1:8083/test",
Metadata: []Metadata{
{Key: metadataHTTPMethod, Val: http.MethodGet},
{Key: metadataHTTPHost, Val: "example.com"},
},
}
inboundInstance.handleHTTPService(context.Background(), nil, respWriter, request, adapter.InboundContext{}, ResolvedService{
Kind: ResolvedServiceHTTP,
Destination: M.ParseSocksaddr("127.0.0.1:8083"),
OriginRequest: OriginRequestConfig{
Access: AccessConfig{
Required: true,
TeamName: "team",
},
},
})
if respWriter.status != http.StatusForbidden {
t.Fatalf("expected 403, got %d", respWriter.status)
}
serverSide, clientSide := net.Pipe()
defer serverSide.Close()
defer clientSide.Close()
}
func TestHandleHTTPServiceStatusAccessDenied(t *testing.T) {
originalFactory := newAccessValidator
defer func() {
newAccessValidator = originalFactory
}()
newAccessValidator = func(access AccessConfig) (accessValidator, error) {
return &fakeAccessValidator{err: E.New("forbidden")}, nil
}
inboundInstance := newAccessTestInbound(t)
respWriter := &fakeConnectResponseWriter{}
request := &ConnectRequest{
Type: ConnectionTypeHTTP,
Dest: "http://127.0.0.1:8083",
Dest: "https://example.com/status",
Metadata: []Metadata{
{Key: metadataHTTPMethod, Val: http.MethodGet},
{Key: metadataHTTPHost, Val: "example.com"},
},
}
go func() {
defer clientSide.Close()
_, _ = io.Copy(io.Discard, clientSide)
}()
inboundInstance.roundTripHTTP(context.Background(), serverSide, respWriter, request, service, &http.Transport{})
inboundInstance.handleHTTPService(context.Background(), nil, respWriter, request, adapter.InboundContext{}, ResolvedService{
Kind: ResolvedServiceStatus,
OriginRequest: OriginRequestConfig{
Access: AccessConfig{
Required: true,
TeamName: "team",
},
},
StatusCode: 404,
})
if respWriter.status != http.StatusForbidden {
t.Fatalf("expected 403, got %d", respWriter.status)
}
}
func TestHandleHTTPServiceStreamAccessDenied(t *testing.T) {
originalFactory := newAccessValidator
defer func() {
newAccessValidator = originalFactory
}()
newAccessValidator = func(access AccessConfig) (accessValidator, error) {
return &fakeAccessValidator{err: E.New("forbidden")}, nil
}
inboundInstance := newAccessTestInbound(t)
respWriter := &fakeConnectResponseWriter{}
request := &ConnectRequest{
Type: ConnectionTypeWebsocket,
Dest: "https://example.com/ws",
Metadata: []Metadata{
{Key: metadataHTTPMethod, Val: http.MethodGet},
{Key: metadataHTTPHost, Val: "example.com"},
{Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="},
},
}
inboundInstance.handleHTTPService(context.Background(), nil, respWriter, request, adapter.InboundContext{}, ResolvedService{
Kind: ResolvedServiceStream,
Destination: M.ParseSocksaddr("127.0.0.1:8080"),
OriginRequest: OriginRequestConfig{
Access: AccessConfig{
Required: true,
TeamName: "team",
},
},
})
if respWriter.status != http.StatusForbidden {
t.Fatalf("expected 403, got %d", respWriter.status)
}

View File

@@ -208,9 +208,29 @@ func (i *Inbound) handleTCPStream(ctx context.Context, stream io.ReadWriteCloser
}
func (i *Inbound) handleHTTPService(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
validationRequest, err := buildMetadataOnlyHTTPRequest(ctx, request)
if err != nil {
i.logger.ErrorContext(ctx, "build request for access validation: ", err)
respWriter.WriteResponse(err, nil)
return
}
validationRequest = applyOriginRequest(validationRequest, service.OriginRequest)
if service.OriginRequest.Access.Required {
validator, err := i.accessCache.Get(service.OriginRequest.Access)
if err != nil {
i.logger.ErrorContext(ctx, "create access validator: ", err)
respWriter.WriteResponse(err, nil)
return
}
if err := validator.Validate(validationRequest.Context(), validationRequest); err != nil {
respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusForbidden, http.Header{}))
return
}
}
switch service.Kind {
case ResolvedServiceStatus:
err := respWriter.WriteResponse(nil, encodeResponseHeaders(service.StatusCode, http.Header{}))
err = respWriter.WriteResponse(nil, encodeResponseHeaders(service.StatusCode, http.Header{}))
if err != nil {
i.logger.ErrorContext(ctx, "write status service response: ", err)
}
@@ -321,18 +341,6 @@ func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser,
defer cancel()
httpRequest = httpRequest.WithContext(requestCtx)
}
if service.OriginRequest.Access.Required {
validator, err := i.accessCache.Get(service.OriginRequest.Access)
if err != nil {
i.logger.ErrorContext(ctx, "create access validator: ", err)
respWriter.WriteResponse(err, nil)
return
}
if err := validator.Validate(requestCtx, httpRequest); err != nil {
respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusForbidden, http.Header{}))
return
}
}
httpClient := &http.Client{
Transport: transport,
@@ -498,6 +506,14 @@ func applyOriginRequest(request *http.Request, originRequest OriginRequestConfig
return request
}
func buildMetadataOnlyHTTPRequest(ctx context.Context, connectRequest *ConnectRequest) (*http.Request, error) {
return buildHTTPRequestFromMetadata(ctx, &ConnectRequest{
Dest: connectRequest.Dest,
Type: connectRequest.Type,
Metadata: append([]Metadata(nil), connectRequest.Metadata...),
}, http.NoBody)
}
func bidirectionalCopy(left, right io.ReadWriteCloser) {
var closeOnce sync.Once
closeBoth := func() {