mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-14 04:38:28 +10:00
Enforce cloudflare access on all ingress services
This commit is contained in:
@@ -4,16 +4,16 @@ package cloudflare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/inbound"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
type fakeAccessValidator struct {
|
||||
@@ -58,34 +58,94 @@ func TestRoundTripHTTPAccessDenied(t *testing.T) {
|
||||
}
|
||||
|
||||
inboundInstance := newAccessTestInbound(t)
|
||||
service := ResolvedService{
|
||||
Kind: ResolvedServiceHTTP,
|
||||
respWriter := &fakeConnectResponseWriter{}
|
||||
request := &ConnectRequest{
|
||||
Type: ConnectionTypeHTTP,
|
||||
Dest: "http://127.0.0.1:8083/test",
|
||||
Metadata: []Metadata{
|
||||
{Key: metadataHTTPMethod, Val: http.MethodGet},
|
||||
{Key: metadataHTTPHost, Val: "example.com"},
|
||||
},
|
||||
}
|
||||
inboundInstance.handleHTTPService(context.Background(), nil, respWriter, request, adapter.InboundContext{}, ResolvedService{
|
||||
Kind: ResolvedServiceHTTP,
|
||||
Destination: M.ParseSocksaddr("127.0.0.1:8083"),
|
||||
OriginRequest: OriginRequestConfig{
|
||||
Access: AccessConfig{
|
||||
Required: true,
|
||||
TeamName: "team",
|
||||
},
|
||||
},
|
||||
})
|
||||
if respWriter.status != http.StatusForbidden {
|
||||
t.Fatalf("expected 403, got %d", respWriter.status)
|
||||
}
|
||||
serverSide, clientSide := net.Pipe()
|
||||
defer serverSide.Close()
|
||||
defer clientSide.Close()
|
||||
}
|
||||
|
||||
func TestHandleHTTPServiceStatusAccessDenied(t *testing.T) {
|
||||
originalFactory := newAccessValidator
|
||||
defer func() {
|
||||
newAccessValidator = originalFactory
|
||||
}()
|
||||
newAccessValidator = func(access AccessConfig) (accessValidator, error) {
|
||||
return &fakeAccessValidator{err: E.New("forbidden")}, nil
|
||||
}
|
||||
|
||||
inboundInstance := newAccessTestInbound(t)
|
||||
respWriter := &fakeConnectResponseWriter{}
|
||||
request := &ConnectRequest{
|
||||
Type: ConnectionTypeHTTP,
|
||||
Dest: "http://127.0.0.1:8083",
|
||||
Dest: "https://example.com/status",
|
||||
Metadata: []Metadata{
|
||||
{Key: metadataHTTPMethod, Val: http.MethodGet},
|
||||
{Key: metadataHTTPHost, Val: "example.com"},
|
||||
},
|
||||
}
|
||||
go func() {
|
||||
defer clientSide.Close()
|
||||
_, _ = io.Copy(io.Discard, clientSide)
|
||||
}()
|
||||
|
||||
inboundInstance.roundTripHTTP(context.Background(), serverSide, respWriter, request, service, &http.Transport{})
|
||||
inboundInstance.handleHTTPService(context.Background(), nil, respWriter, request, adapter.InboundContext{}, ResolvedService{
|
||||
Kind: ResolvedServiceStatus,
|
||||
OriginRequest: OriginRequestConfig{
|
||||
Access: AccessConfig{
|
||||
Required: true,
|
||||
TeamName: "team",
|
||||
},
|
||||
},
|
||||
StatusCode: 404,
|
||||
})
|
||||
if respWriter.status != http.StatusForbidden {
|
||||
t.Fatalf("expected 403, got %d", respWriter.status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleHTTPServiceStreamAccessDenied(t *testing.T) {
|
||||
originalFactory := newAccessValidator
|
||||
defer func() {
|
||||
newAccessValidator = originalFactory
|
||||
}()
|
||||
newAccessValidator = func(access AccessConfig) (accessValidator, error) {
|
||||
return &fakeAccessValidator{err: E.New("forbidden")}, nil
|
||||
}
|
||||
|
||||
inboundInstance := newAccessTestInbound(t)
|
||||
respWriter := &fakeConnectResponseWriter{}
|
||||
request := &ConnectRequest{
|
||||
Type: ConnectionTypeWebsocket,
|
||||
Dest: "https://example.com/ws",
|
||||
Metadata: []Metadata{
|
||||
{Key: metadataHTTPMethod, Val: http.MethodGet},
|
||||
{Key: metadataHTTPHost, Val: "example.com"},
|
||||
{Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="},
|
||||
},
|
||||
}
|
||||
inboundInstance.handleHTTPService(context.Background(), nil, respWriter, request, adapter.InboundContext{}, ResolvedService{
|
||||
Kind: ResolvedServiceStream,
|
||||
Destination: M.ParseSocksaddr("127.0.0.1:8080"),
|
||||
OriginRequest: OriginRequestConfig{
|
||||
Access: AccessConfig{
|
||||
Required: true,
|
||||
TeamName: "team",
|
||||
},
|
||||
},
|
||||
})
|
||||
if respWriter.status != http.StatusForbidden {
|
||||
t.Fatalf("expected 403, got %d", respWriter.status)
|
||||
}
|
||||
|
||||
@@ -208,9 +208,29 @@ func (i *Inbound) handleTCPStream(ctx context.Context, stream io.ReadWriteCloser
|
||||
}
|
||||
|
||||
func (i *Inbound) handleHTTPService(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
|
||||
validationRequest, err := buildMetadataOnlyHTTPRequest(ctx, request)
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "build request for access validation: ", err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
validationRequest = applyOriginRequest(validationRequest, service.OriginRequest)
|
||||
if service.OriginRequest.Access.Required {
|
||||
validator, err := i.accessCache.Get(service.OriginRequest.Access)
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "create access validator: ", err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
if err := validator.Validate(validationRequest.Context(), validationRequest); err != nil {
|
||||
respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusForbidden, http.Header{}))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
switch service.Kind {
|
||||
case ResolvedServiceStatus:
|
||||
err := respWriter.WriteResponse(nil, encodeResponseHeaders(service.StatusCode, http.Header{}))
|
||||
err = respWriter.WriteResponse(nil, encodeResponseHeaders(service.StatusCode, http.Header{}))
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "write status service response: ", err)
|
||||
}
|
||||
@@ -321,18 +341,6 @@ func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser,
|
||||
defer cancel()
|
||||
httpRequest = httpRequest.WithContext(requestCtx)
|
||||
}
|
||||
if service.OriginRequest.Access.Required {
|
||||
validator, err := i.accessCache.Get(service.OriginRequest.Access)
|
||||
if err != nil {
|
||||
i.logger.ErrorContext(ctx, "create access validator: ", err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
if err := validator.Validate(requestCtx, httpRequest); err != nil {
|
||||
respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusForbidden, http.Header{}))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: transport,
|
||||
@@ -498,6 +506,14 @@ func applyOriginRequest(request *http.Request, originRequest OriginRequestConfig
|
||||
return request
|
||||
}
|
||||
|
||||
func buildMetadataOnlyHTTPRequest(ctx context.Context, connectRequest *ConnectRequest) (*http.Request, error) {
|
||||
return buildHTTPRequestFromMetadata(ctx, &ConnectRequest{
|
||||
Dest: connectRequest.Dest,
|
||||
Type: connectRequest.Type,
|
||||
Metadata: append([]Metadata(nil), connectRequest.Metadata...),
|
||||
}, http.NoBody)
|
||||
}
|
||||
|
||||
func bidirectionalCopy(left, right io.ReadWriteCloser) {
|
||||
var closeOnce sync.Once
|
||||
closeBoth := func() {
|
||||
|
||||
Reference in New Issue
Block a user