diff --git a/protocol/cloudflare/access_test.go b/protocol/cloudflare/access_test.go index 594f94d77..357cd9f43 100644 --- a/protocol/cloudflare/access_test.go +++ b/protocol/cloudflare/access_test.go @@ -4,16 +4,16 @@ package cloudflare import ( "context" - "io" - "net" "net/http" "testing" + "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter/inbound" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" ) type fakeAccessValidator struct { @@ -58,34 +58,94 @@ func TestRoundTripHTTPAccessDenied(t *testing.T) { } inboundInstance := newAccessTestInbound(t) - service := ResolvedService{ - Kind: ResolvedServiceHTTP, + respWriter := &fakeConnectResponseWriter{} + request := &ConnectRequest{ + Type: ConnectionTypeHTTP, + Dest: "http://127.0.0.1:8083/test", + Metadata: []Metadata{ + {Key: metadataHTTPMethod, Val: http.MethodGet}, + {Key: metadataHTTPHost, Val: "example.com"}, + }, + } + inboundInstance.handleHTTPService(context.Background(), nil, respWriter, request, adapter.InboundContext{}, ResolvedService{ + Kind: ResolvedServiceHTTP, + Destination: M.ParseSocksaddr("127.0.0.1:8083"), OriginRequest: OriginRequestConfig{ Access: AccessConfig{ Required: true, TeamName: "team", }, }, + }) + if respWriter.status != http.StatusForbidden { + t.Fatalf("expected 403, got %d", respWriter.status) } - serverSide, clientSide := net.Pipe() - defer serverSide.Close() - defer clientSide.Close() +} +func TestHandleHTTPServiceStatusAccessDenied(t *testing.T) { + originalFactory := newAccessValidator + defer func() { + newAccessValidator = originalFactory + }() + newAccessValidator = func(access AccessConfig) (accessValidator, error) { + return &fakeAccessValidator{err: E.New("forbidden")}, nil + } + + inboundInstance := newAccessTestInbound(t) respWriter := &fakeConnectResponseWriter{} request := &ConnectRequest{ Type: ConnectionTypeHTTP, - Dest: "http://127.0.0.1:8083", + Dest: "https://example.com/status", Metadata: []Metadata{ {Key: metadataHTTPMethod, Val: http.MethodGet}, {Key: metadataHTTPHost, Val: "example.com"}, }, } - go func() { - defer clientSide.Close() - _, _ = io.Copy(io.Discard, clientSide) - }() - - inboundInstance.roundTripHTTP(context.Background(), serverSide, respWriter, request, service, &http.Transport{}) + inboundInstance.handleHTTPService(context.Background(), nil, respWriter, request, adapter.InboundContext{}, ResolvedService{ + Kind: ResolvedServiceStatus, + OriginRequest: OriginRequestConfig{ + Access: AccessConfig{ + Required: true, + TeamName: "team", + }, + }, + StatusCode: 404, + }) + if respWriter.status != http.StatusForbidden { + t.Fatalf("expected 403, got %d", respWriter.status) + } +} + +func TestHandleHTTPServiceStreamAccessDenied(t *testing.T) { + originalFactory := newAccessValidator + defer func() { + newAccessValidator = originalFactory + }() + newAccessValidator = func(access AccessConfig) (accessValidator, error) { + return &fakeAccessValidator{err: E.New("forbidden")}, nil + } + + inboundInstance := newAccessTestInbound(t) + respWriter := &fakeConnectResponseWriter{} + request := &ConnectRequest{ + Type: ConnectionTypeWebsocket, + Dest: "https://example.com/ws", + Metadata: []Metadata{ + {Key: metadataHTTPMethod, Val: http.MethodGet}, + {Key: metadataHTTPHost, Val: "example.com"}, + {Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="}, + }, + } + inboundInstance.handleHTTPService(context.Background(), nil, respWriter, request, adapter.InboundContext{}, ResolvedService{ + Kind: ResolvedServiceStream, + Destination: M.ParseSocksaddr("127.0.0.1:8080"), + OriginRequest: OriginRequestConfig{ + Access: AccessConfig{ + Required: true, + TeamName: "team", + }, + }, + }) if respWriter.status != http.StatusForbidden { t.Fatalf("expected 403, got %d", respWriter.status) } diff --git a/protocol/cloudflare/dispatch.go b/protocol/cloudflare/dispatch.go index 6d049d66c..b33dea97c 100644 --- a/protocol/cloudflare/dispatch.go +++ b/protocol/cloudflare/dispatch.go @@ -208,9 +208,29 @@ func (i *Inbound) handleTCPStream(ctx context.Context, stream io.ReadWriteCloser } func (i *Inbound) handleHTTPService(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) { + validationRequest, err := buildMetadataOnlyHTTPRequest(ctx, request) + if err != nil { + i.logger.ErrorContext(ctx, "build request for access validation: ", err) + respWriter.WriteResponse(err, nil) + return + } + validationRequest = applyOriginRequest(validationRequest, service.OriginRequest) + if service.OriginRequest.Access.Required { + validator, err := i.accessCache.Get(service.OriginRequest.Access) + if err != nil { + i.logger.ErrorContext(ctx, "create access validator: ", err) + respWriter.WriteResponse(err, nil) + return + } + if err := validator.Validate(validationRequest.Context(), validationRequest); err != nil { + respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusForbidden, http.Header{})) + return + } + } + switch service.Kind { case ResolvedServiceStatus: - err := respWriter.WriteResponse(nil, encodeResponseHeaders(service.StatusCode, http.Header{})) + err = respWriter.WriteResponse(nil, encodeResponseHeaders(service.StatusCode, http.Header{})) if err != nil { i.logger.ErrorContext(ctx, "write status service response: ", err) } @@ -321,18 +341,6 @@ func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser, defer cancel() httpRequest = httpRequest.WithContext(requestCtx) } - if service.OriginRequest.Access.Required { - validator, err := i.accessCache.Get(service.OriginRequest.Access) - if err != nil { - i.logger.ErrorContext(ctx, "create access validator: ", err) - respWriter.WriteResponse(err, nil) - return - } - if err := validator.Validate(requestCtx, httpRequest); err != nil { - respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusForbidden, http.Header{})) - return - } - } httpClient := &http.Client{ Transport: transport, @@ -498,6 +506,14 @@ func applyOriginRequest(request *http.Request, originRequest OriginRequestConfig return request } +func buildMetadataOnlyHTTPRequest(ctx context.Context, connectRequest *ConnectRequest) (*http.Request, error) { + return buildHTTPRequestFromMetadata(ctx, &ConnectRequest{ + Dest: connectRequest.Dest, + Type: connectRequest.Type, + Metadata: append([]Metadata(nil), connectRequest.Metadata...), + }, http.NoBody) +} + func bidirectionalCopy(left, right io.ReadWriteCloser) { var closeOnce sync.Once closeBoth := func() {