mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-13 20:28:32 +10:00
option: reject nested rule actions
This commit is contained in:
@@ -202,6 +202,12 @@ func (r *Router) isClosing() bool {
|
||||
}
|
||||
|
||||
func (r *Router) buildRules(startRules bool) ([]adapter.DNSRule, bool, error) {
|
||||
for i, ruleOptions := range r.rawRules {
|
||||
err := R.ValidateNoNestedDNSRuleActions(ruleOptions)
|
||||
if err != nil {
|
||||
return nil, false, E.Cause(err, "parse dns rule[", i, "]")
|
||||
}
|
||||
}
|
||||
router := service.FromContext[adapter.Router](r.ctx)
|
||||
legacyDNSMode, err := resolveLegacyDNSMode(router, r.rawRules)
|
||||
if err != nil {
|
||||
|
||||
@@ -135,6 +135,7 @@ func (s *fakeRuleSet) IncRef() {
|
||||
defer s.access.Unlock()
|
||||
s.refs++
|
||||
}
|
||||
|
||||
func (s *fakeRuleSet) DecRef() {
|
||||
s.access.Lock()
|
||||
defer s.access.Unlock()
|
||||
@@ -149,6 +150,7 @@ func (s *fakeRuleSet) RegisterCallback(callback adapter.RuleSetUpdateCallback) *
|
||||
defer s.access.Unlock()
|
||||
return s.callbacks.PushBack(callback)
|
||||
}
|
||||
|
||||
func (s *fakeRuleSet) UnregisterCallback(element *list.Element[adapter.RuleSetUpdateCallback]) {
|
||||
s.access.Lock()
|
||||
defer s.access.Unlock()
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
"github.com/sagernet/sing/service"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package option
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
@@ -33,26 +34,24 @@ func (r Rule) MarshalJSON() ([]byte, error) {
|
||||
return badjson.MarshallObjects((_Rule)(r), v)
|
||||
}
|
||||
|
||||
func (r *Rule) UnmarshalJSON(bytes []byte) error {
|
||||
err := json.Unmarshal(bytes, (*_Rule)(r))
|
||||
func (r *Rule) UnmarshalJSONContext(ctx context.Context, bytes []byte) error {
|
||||
err := json.UnmarshalContext(ctx, bytes, (*_Rule)(r))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
payload, err := rulePayloadWithoutType(ctx, bytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var v any
|
||||
switch r.Type {
|
||||
case "", C.RuleTypeDefault:
|
||||
r.Type = C.RuleTypeDefault
|
||||
v = &r.DefaultOptions
|
||||
return unmarshalDefaultRuleContext(ctx, payload, &r.DefaultOptions)
|
||||
case C.RuleTypeLogical:
|
||||
v = &r.LogicalOptions
|
||||
return unmarshalLogicalRuleContext(ctx, payload, &r.LogicalOptions)
|
||||
default:
|
||||
return E.New("unknown rule type: " + r.Type)
|
||||
}
|
||||
err = badjson.UnmarshallExcluded(bytes, (*_Rule)(r), v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r Rule) IsValid() bool {
|
||||
@@ -160,6 +159,64 @@ func (r *LogicalRule) UnmarshalJSON(data []byte) error {
|
||||
return badjson.UnmarshallExcluded(data, &r.RawLogicalRule, &r.RuleAction)
|
||||
}
|
||||
|
||||
func rulePayloadWithoutType(ctx context.Context, data []byte) ([]byte, error) {
|
||||
var content badjson.JSONObject
|
||||
err := content.UnmarshalJSONContext(ctx, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
content.Remove("type")
|
||||
return content.MarshalJSONContext(ctx)
|
||||
}
|
||||
|
||||
func unmarshalDefaultRuleContext(ctx context.Context, data []byte, rule *DefaultRule) error {
|
||||
rawAction, routeOptions, err := inspectRouteRuleAction(ctx, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = rejectNestedRouteRuleAction(ctx, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
depth := nestedRuleDepth(ctx)
|
||||
err = json.UnmarshalContext(ctx, data, &rule.RawDefaultRule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = badjson.UnmarshallExcludedContext(ctx, data, &rule.RawDefaultRule, &rule.RuleAction)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if depth > 0 && rawAction == "" && routeOptions == (RouteActionOptions{}) {
|
||||
rule.RuleAction = RuleAction{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func unmarshalLogicalRuleContext(ctx context.Context, data []byte, rule *LogicalRule) error {
|
||||
rawAction, routeOptions, err := inspectRouteRuleAction(ctx, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = rejectNestedRouteRuleAction(ctx, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
depth := nestedRuleDepth(ctx)
|
||||
err = json.UnmarshalContext(nestedRuleChildContext(ctx), data, &rule.RawLogicalRule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = badjson.UnmarshallExcludedContext(ctx, data, &rule.RawLogicalRule, &rule.RuleAction)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if depth > 0 && rawAction == "" && routeOptions == (RouteActionOptions{}) {
|
||||
rule.RuleAction = RuleAction{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *LogicalRule) IsValid() bool {
|
||||
return len(r.Rules) > 0 && common.All(r.Rules, Rule.IsValid)
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ func (r DNSRule) MarshalJSON() ([]byte, error) {
|
||||
}
|
||||
|
||||
func (r *DNSRule) UnmarshalJSONContext(ctx context.Context, bytes []byte) error {
|
||||
err := json.Unmarshal(bytes, (*_DNSRule)(r))
|
||||
err := json.UnmarshalContext(ctx, bytes, (*_DNSRule)(r))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -135,11 +135,27 @@ func (r DefaultDNSRule) MarshalJSON() ([]byte, error) {
|
||||
}
|
||||
|
||||
func (r *DefaultDNSRule) UnmarshalJSONContext(ctx context.Context, data []byte) error {
|
||||
err := json.UnmarshalContext(ctx, data, &r.RawDefaultDNSRule)
|
||||
rawAction, routeOptions, err := inspectDNSRuleAction(ctx, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return badjson.UnmarshallExcludedContext(ctx, data, &r.RawDefaultDNSRule, &r.DNSRuleAction)
|
||||
err = rejectNestedDNSRuleAction(ctx, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
depth := nestedRuleDepth(ctx)
|
||||
err = json.UnmarshalContext(ctx, data, &r.RawDefaultDNSRule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = badjson.UnmarshallExcludedContext(ctx, data, &r.RawDefaultDNSRule, &r.DNSRuleAction)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if depth > 0 && rawAction == "" && routeOptions == (DNSRouteActionOptions{}) {
|
||||
r.DNSRuleAction = DNSRuleAction{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r DefaultDNSRule) IsValid() bool {
|
||||
@@ -164,11 +180,27 @@ func (r LogicalDNSRule) MarshalJSON() ([]byte, error) {
|
||||
}
|
||||
|
||||
func (r *LogicalDNSRule) UnmarshalJSONContext(ctx context.Context, data []byte) error {
|
||||
err := json.Unmarshal(data, &r.RawLogicalDNSRule)
|
||||
rawAction, routeOptions, err := inspectDNSRuleAction(ctx, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return badjson.UnmarshallExcludedContext(ctx, data, &r.RawLogicalDNSRule, &r.DNSRuleAction)
|
||||
err = rejectNestedDNSRuleAction(ctx, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
depth := nestedRuleDepth(ctx)
|
||||
err = json.UnmarshalContext(nestedRuleChildContext(ctx), data, &r.RawLogicalDNSRule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = badjson.UnmarshallExcludedContext(ctx, data, &r.RawLogicalDNSRule, &r.DNSRuleAction)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if depth > 0 && rawAction == "" && routeOptions == (DNSRouteActionOptions{}) {
|
||||
r.DNSRuleAction = DNSRuleAction{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *LogicalDNSRule) IsValid() bool {
|
||||
|
||||
133
option/rule_nested.go
Normal file
133
option/rule_nested.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package option
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
"github.com/sagernet/sing/common/json/badjson"
|
||||
)
|
||||
|
||||
type nestedRuleDepthContextKey struct{}
|
||||
|
||||
const (
|
||||
routeRuleActionNestedUnsupportedMessage = "rule action is not supported in nested rules"
|
||||
dnsRuleActionNestedUnsupportedMessage = "DNS rule action is not supported in nested rules"
|
||||
)
|
||||
|
||||
var (
|
||||
routeRuleActionKeys = jsonFieldNames(reflect.TypeFor[_RuleAction](), reflect.TypeFor[RouteActionOptions]())
|
||||
dnsRuleActionKeys = jsonFieldNames(reflect.TypeFor[_DNSRuleAction](), reflect.TypeFor[DNSRouteActionOptions]())
|
||||
)
|
||||
|
||||
func nestedRuleChildContext(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, nestedRuleDepthContextKey{}, nestedRuleDepth(ctx)+1)
|
||||
}
|
||||
|
||||
func rejectNestedRouteRuleAction(ctx context.Context, content []byte) error {
|
||||
return rejectNestedRuleAction(ctx, content, routeRuleActionKeys, routeRuleActionNestedUnsupportedMessage)
|
||||
}
|
||||
|
||||
func rejectNestedDNSRuleAction(ctx context.Context, content []byte) error {
|
||||
return rejectNestedRuleAction(ctx, content, dnsRuleActionKeys, dnsRuleActionNestedUnsupportedMessage)
|
||||
}
|
||||
|
||||
func nestedRuleDepth(ctx context.Context) int {
|
||||
depth, _ := ctx.Value(nestedRuleDepthContextKey{}).(int)
|
||||
return depth
|
||||
}
|
||||
|
||||
func rejectNestedRuleAction(ctx context.Context, content []byte, keys []string, message string) error {
|
||||
if nestedRuleDepth(ctx) == 0 {
|
||||
return nil
|
||||
}
|
||||
hasActionKey, err := hasAnyJSONKey(ctx, content, keys...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if hasActionKey {
|
||||
return E.New(message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func hasAnyJSONKey(ctx context.Context, content []byte, keys ...string) (bool, error) {
|
||||
var object badjson.JSONObject
|
||||
err := object.UnmarshalJSONContext(ctx, content)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, key := range keys {
|
||||
if object.ContainsKey(key) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func inspectRouteRuleAction(ctx context.Context, content []byte) (string, RouteActionOptions, error) {
|
||||
var rawAction _RuleAction
|
||||
err := json.UnmarshalContext(ctx, content, &rawAction)
|
||||
if err != nil {
|
||||
return "", RouteActionOptions{}, err
|
||||
}
|
||||
var routeOptions RouteActionOptions
|
||||
err = json.UnmarshalContext(ctx, content, &routeOptions)
|
||||
if err != nil {
|
||||
return "", RouteActionOptions{}, err
|
||||
}
|
||||
return rawAction.Action, routeOptions, nil
|
||||
}
|
||||
|
||||
func inspectDNSRuleAction(ctx context.Context, content []byte) (string, DNSRouteActionOptions, error) {
|
||||
var rawAction _DNSRuleAction
|
||||
err := json.UnmarshalContext(ctx, content, &rawAction)
|
||||
if err != nil {
|
||||
return "", DNSRouteActionOptions{}, err
|
||||
}
|
||||
var routeOptions DNSRouteActionOptions
|
||||
err = json.UnmarshalContext(ctx, content, &routeOptions)
|
||||
if err != nil {
|
||||
return "", DNSRouteActionOptions{}, err
|
||||
}
|
||||
return rawAction.Action, routeOptions, nil
|
||||
}
|
||||
|
||||
func jsonFieldNames(types ...reflect.Type) []string {
|
||||
fieldMap := make(map[string]struct{})
|
||||
for _, fieldType := range types {
|
||||
appendJSONFieldNames(fieldMap, fieldType)
|
||||
}
|
||||
fieldNames := make([]string, 0, len(fieldMap))
|
||||
for fieldName := range fieldMap {
|
||||
fieldNames = append(fieldNames, fieldName)
|
||||
}
|
||||
return fieldNames
|
||||
}
|
||||
|
||||
func appendJSONFieldNames(fieldMap map[string]struct{}, fieldType reflect.Type) {
|
||||
for fieldType.Kind() == reflect.Pointer {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
if fieldType.Kind() != reflect.Struct {
|
||||
return
|
||||
}
|
||||
for i := range fieldType.NumField() {
|
||||
field := fieldType.Field(i)
|
||||
tagValue := field.Tag.Get("json")
|
||||
tagName, _, _ := strings.Cut(tagValue, ",")
|
||||
if tagName == "-" {
|
||||
continue
|
||||
}
|
||||
if field.Anonymous && tagName == "" {
|
||||
appendJSONFieldNames(fieldMap, field.Type)
|
||||
continue
|
||||
}
|
||||
if tagName == "" {
|
||||
tagName = field.Name
|
||||
}
|
||||
fieldMap[tagName] = struct{}{}
|
||||
}
|
||||
}
|
||||
271
option/rule_nested_test.go
Normal file
271
option/rule_nested_test.go
Normal file
@@ -0,0 +1,271 @@
|
||||
package option
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRuleRejectsNestedDefaultRuleAction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var rule Rule
|
||||
err := json.UnmarshalContext(context.Background(), []byte(`{
|
||||
"type": "logical",
|
||||
"mode": "and",
|
||||
"rules": [
|
||||
{"domain": "example.com", "outbound": "direct"}
|
||||
]
|
||||
}`), &rule)
|
||||
require.ErrorContains(t, err, routeRuleActionNestedUnsupportedMessage)
|
||||
}
|
||||
|
||||
func TestRuleRejectsNestedLogicalRuleAction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var rule Rule
|
||||
err := json.UnmarshalContext(context.Background(), []byte(`{
|
||||
"type": "logical",
|
||||
"mode": "and",
|
||||
"rules": [
|
||||
{
|
||||
"type": "logical",
|
||||
"mode": "or",
|
||||
"action": "route",
|
||||
"outbound": "direct",
|
||||
"rules": [{"domain": "example.com"}]
|
||||
}
|
||||
]
|
||||
}`), &rule)
|
||||
require.ErrorContains(t, err, routeRuleActionNestedUnsupportedMessage)
|
||||
}
|
||||
|
||||
func TestRuleRejectsNestedDefaultRuleZeroValueOutbound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var rule Rule
|
||||
err := json.UnmarshalContext(context.Background(), []byte(`{
|
||||
"type": "logical",
|
||||
"mode": "and",
|
||||
"rules": [
|
||||
{"domain": "example.com", "outbound": ""}
|
||||
]
|
||||
}`), &rule)
|
||||
require.ErrorContains(t, err, routeRuleActionNestedUnsupportedMessage)
|
||||
}
|
||||
|
||||
func TestRuleRejectsNestedDefaultRuleZeroValueRouteOption(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var rule Rule
|
||||
err := json.UnmarshalContext(context.Background(), []byte(`{
|
||||
"type": "logical",
|
||||
"mode": "and",
|
||||
"rules": [
|
||||
{"domain": "example.com", "udp_connect": false}
|
||||
]
|
||||
}`), &rule)
|
||||
require.ErrorContains(t, err, routeRuleActionNestedUnsupportedMessage)
|
||||
}
|
||||
|
||||
func TestRuleRejectsNestedLogicalRuleZeroValueAction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var rule Rule
|
||||
err := json.UnmarshalContext(context.Background(), []byte(`{
|
||||
"type": "logical",
|
||||
"mode": "and",
|
||||
"rules": [
|
||||
{
|
||||
"type": "logical",
|
||||
"mode": "or",
|
||||
"action": "",
|
||||
"rules": [{"domain": "example.com"}]
|
||||
}
|
||||
]
|
||||
}`), &rule)
|
||||
require.ErrorContains(t, err, routeRuleActionNestedUnsupportedMessage)
|
||||
}
|
||||
|
||||
func TestRuleRejectsNestedLogicalRuleZeroValueRouteOption(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var rule Rule
|
||||
err := json.UnmarshalContext(context.Background(), []byte(`{
|
||||
"type": "logical",
|
||||
"mode": "and",
|
||||
"rules": [
|
||||
{
|
||||
"type": "logical",
|
||||
"mode": "or",
|
||||
"override_port": 0,
|
||||
"rules": [{"domain": "example.com"}]
|
||||
}
|
||||
]
|
||||
}`), &rule)
|
||||
require.ErrorContains(t, err, routeRuleActionNestedUnsupportedMessage)
|
||||
}
|
||||
|
||||
func TestRuleAllowsTopLevelLogicalAction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var rule Rule
|
||||
err := json.UnmarshalContext(context.Background(), []byte(`{
|
||||
"type": "logical",
|
||||
"mode": "and",
|
||||
"outbound": "direct",
|
||||
"rules": [{"domain": "example.com"}]
|
||||
}`), &rule)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, C.RuleActionTypeRoute, rule.LogicalOptions.Action)
|
||||
require.Equal(t, "direct", rule.LogicalOptions.RouteOptions.Outbound)
|
||||
}
|
||||
|
||||
func TestRuleLeavesUnknownNestedKeysToNormalValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var rule Rule
|
||||
err := json.UnmarshalContext(context.Background(), []byte(`{
|
||||
"type": "logical",
|
||||
"mode": "and",
|
||||
"rules": [
|
||||
{"domain": "example.com", "foo": "bar"}
|
||||
]
|
||||
}`), &rule)
|
||||
require.ErrorContains(t, err, "unknown field")
|
||||
require.NotContains(t, err.Error(), routeRuleActionNestedUnsupportedMessage)
|
||||
}
|
||||
|
||||
func TestDNSRuleRejectsNestedDefaultRuleAction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var rule DNSRule
|
||||
err := json.UnmarshalContext(context.Background(), []byte(`{
|
||||
"type": "logical",
|
||||
"mode": "and",
|
||||
"rules": [
|
||||
{"domain": "example.com", "server": "default"}
|
||||
]
|
||||
}`), &rule)
|
||||
require.ErrorContains(t, err, dnsRuleActionNestedUnsupportedMessage)
|
||||
}
|
||||
|
||||
func TestDNSRuleRejectsNestedLogicalRuleAction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var rule DNSRule
|
||||
err := json.UnmarshalContext(context.Background(), []byte(`{
|
||||
"type": "logical",
|
||||
"mode": "and",
|
||||
"rules": [
|
||||
{
|
||||
"type": "logical",
|
||||
"mode": "or",
|
||||
"action": "route",
|
||||
"server": "default",
|
||||
"rules": [{"domain": "example.com"}]
|
||||
}
|
||||
]
|
||||
}`), &rule)
|
||||
require.ErrorContains(t, err, dnsRuleActionNestedUnsupportedMessage)
|
||||
}
|
||||
|
||||
func TestDNSRuleRejectsNestedDefaultRuleZeroValueServer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var rule DNSRule
|
||||
err := json.UnmarshalContext(context.Background(), []byte(`{
|
||||
"type": "logical",
|
||||
"mode": "and",
|
||||
"rules": [
|
||||
{"domain": "example.com", "server": ""}
|
||||
]
|
||||
}`), &rule)
|
||||
require.ErrorContains(t, err, dnsRuleActionNestedUnsupportedMessage)
|
||||
}
|
||||
|
||||
func TestDNSRuleRejectsNestedDefaultRuleZeroValueRouteOption(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var rule DNSRule
|
||||
err := json.UnmarshalContext(context.Background(), []byte(`{
|
||||
"type": "logical",
|
||||
"mode": "and",
|
||||
"rules": [
|
||||
{"domain": "example.com", "disable_cache": false}
|
||||
]
|
||||
}`), &rule)
|
||||
require.ErrorContains(t, err, dnsRuleActionNestedUnsupportedMessage)
|
||||
}
|
||||
|
||||
func TestDNSRuleRejectsNestedLogicalRuleZeroValueAction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var rule DNSRule
|
||||
err := json.UnmarshalContext(context.Background(), []byte(`{
|
||||
"type": "logical",
|
||||
"mode": "and",
|
||||
"rules": [
|
||||
{
|
||||
"type": "logical",
|
||||
"mode": "or",
|
||||
"action": "",
|
||||
"rules": [{"domain": "example.com"}]
|
||||
}
|
||||
]
|
||||
}`), &rule)
|
||||
require.ErrorContains(t, err, dnsRuleActionNestedUnsupportedMessage)
|
||||
}
|
||||
|
||||
func TestDNSRuleRejectsNestedLogicalRuleZeroValueRouteOption(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var rule DNSRule
|
||||
err := json.UnmarshalContext(context.Background(), []byte(`{
|
||||
"type": "logical",
|
||||
"mode": "and",
|
||||
"rules": [
|
||||
{
|
||||
"type": "logical",
|
||||
"mode": "or",
|
||||
"disable_cache": false,
|
||||
"rules": [{"domain": "example.com"}]
|
||||
}
|
||||
]
|
||||
}`), &rule)
|
||||
require.ErrorContains(t, err, dnsRuleActionNestedUnsupportedMessage)
|
||||
}
|
||||
|
||||
func TestDNSRuleAllowsTopLevelLogicalAction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var rule DNSRule
|
||||
err := json.UnmarshalContext(context.Background(), []byte(`{
|
||||
"type": "logical",
|
||||
"mode": "and",
|
||||
"server": "default",
|
||||
"rules": [{"domain": "example.com"}]
|
||||
}`), &rule)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, C.RuleActionTypeRoute, rule.LogicalOptions.Action)
|
||||
require.Equal(t, "default", rule.LogicalOptions.RouteOptions.Server)
|
||||
}
|
||||
|
||||
func TestDNSRuleLeavesUnknownNestedKeysToNormalValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var rule DNSRule
|
||||
err := json.UnmarshalContext(context.Background(), []byte(`{
|
||||
"type": "logical",
|
||||
"mode": "and",
|
||||
"rules": [
|
||||
{"domain": "example.com", "foo": "bar"}
|
||||
]
|
||||
}`), &rule)
|
||||
require.ErrorContains(t, err, "unknown field")
|
||||
require.NotContains(t, err.Error(), dnsRuleActionNestedUnsupportedMessage)
|
||||
}
|
||||
@@ -70,6 +70,10 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.Route
|
||||
|
||||
func (r *Router) Initialize(rules []option.Rule, ruleSets []option.RuleSet) error {
|
||||
for i, options := range rules {
|
||||
err := R.ValidateNoNestedRuleActions(options)
|
||||
if err != nil {
|
||||
return E.Cause(err, "parse rule[", i, "]")
|
||||
}
|
||||
rule, err := R.NewRule(r.ctx, r.logger, options, false)
|
||||
if err != nil {
|
||||
return E.Cause(err, "parse rule[", i, "]")
|
||||
|
||||
76
route/rule/nested_action.go
Normal file
76
route/rule/nested_action.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package rule
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
const (
|
||||
routeRuleActionNestedUnsupportedMessage = "rule action is not supported in nested rules"
|
||||
dnsRuleActionNestedUnsupportedMessage = "DNS rule action is not supported in nested rules"
|
||||
)
|
||||
|
||||
func ValidateNoNestedRuleActions(rule option.Rule) error {
|
||||
return validateNoNestedRuleActions(rule, false)
|
||||
}
|
||||
|
||||
func ValidateNoNestedDNSRuleActions(rule option.DNSRule) error {
|
||||
return validateNoNestedDNSRuleActions(rule, false)
|
||||
}
|
||||
|
||||
func validateNoNestedRuleActions(rule option.Rule, nested bool) error {
|
||||
if nested && ruleHasConfiguredAction(rule) {
|
||||
return E.New(routeRuleActionNestedUnsupportedMessage)
|
||||
}
|
||||
if rule.Type != C.RuleTypeLogical {
|
||||
return nil
|
||||
}
|
||||
for i, subRule := range rule.LogicalOptions.Rules {
|
||||
err := validateNoNestedRuleActions(subRule, true)
|
||||
if err != nil {
|
||||
return E.Cause(err, "sub rule[", i, "]")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateNoNestedDNSRuleActions(rule option.DNSRule, nested bool) error {
|
||||
if nested && dnsRuleHasConfiguredAction(rule) {
|
||||
return E.New(dnsRuleActionNestedUnsupportedMessage)
|
||||
}
|
||||
if rule.Type != C.RuleTypeLogical {
|
||||
return nil
|
||||
}
|
||||
for i, subRule := range rule.LogicalOptions.Rules {
|
||||
err := validateNoNestedDNSRuleActions(subRule, true)
|
||||
if err != nil {
|
||||
return E.Cause(err, "sub rule[", i, "]")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ruleHasConfiguredAction(rule option.Rule) bool {
|
||||
switch rule.Type {
|
||||
case "", C.RuleTypeDefault:
|
||||
return !reflect.DeepEqual(rule.DefaultOptions.RuleAction, option.RuleAction{})
|
||||
case C.RuleTypeLogical:
|
||||
return !reflect.DeepEqual(rule.LogicalOptions.RuleAction, option.RuleAction{})
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func dnsRuleHasConfiguredAction(rule option.DNSRule) bool {
|
||||
switch rule.Type {
|
||||
case "", C.RuleTypeDefault:
|
||||
return !reflect.DeepEqual(rule.DefaultOptions.DNSRuleAction, option.DNSRuleAction{})
|
||||
case C.RuleTypeLogical:
|
||||
return !reflect.DeepEqual(rule.LogicalOptions.DNSRuleAction, option.DNSRuleAction{})
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
137
route/rule/nested_action_test.go
Normal file
137
route/rule/nested_action_test.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package rule
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewRulePreservesImplicitTopLevelDefaultAction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var options option.Rule
|
||||
err := json.UnmarshalContext(context.Background(), []byte(`{
|
||||
"domain": "example.com"
|
||||
}`), &options)
|
||||
require.NoError(t, err)
|
||||
|
||||
rule, err := NewRule(context.Background(), log.NewNOPFactory().NewLogger("router"), options, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule.Action())
|
||||
require.Equal(t, C.RuleActionTypeRoute, rule.Action().Type())
|
||||
}
|
||||
|
||||
func TestNewRuleAllowsNestedRuleWithoutAction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var options option.Rule
|
||||
err := json.UnmarshalContext(context.Background(), []byte(`{
|
||||
"type": "logical",
|
||||
"mode": "and",
|
||||
"rules": [
|
||||
{"domain": "example.com"}
|
||||
]
|
||||
}`), &options)
|
||||
require.NoError(t, err)
|
||||
|
||||
rule, err := NewRule(context.Background(), log.NewNOPFactory().NewLogger("router"), options, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule.Action())
|
||||
require.Equal(t, C.RuleActionTypeRoute, rule.Action().Type())
|
||||
}
|
||||
|
||||
func TestNewRuleRejectsNestedRuleAction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := NewRule(context.Background(), log.NewNOPFactory().NewLogger("router"), option.Rule{
|
||||
Type: C.RuleTypeLogical,
|
||||
LogicalOptions: option.LogicalRule{
|
||||
RawLogicalRule: option.RawLogicalRule{
|
||||
Mode: C.LogicalTypeAnd,
|
||||
Rules: []option.Rule{{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultRule{
|
||||
RuleAction: option.RuleAction{
|
||||
Action: C.RuleActionTypeRoute,
|
||||
RouteOptions: option.RouteActionOptions{
|
||||
Outbound: "direct",
|
||||
},
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
},
|
||||
}, false)
|
||||
require.ErrorContains(t, err, routeRuleActionNestedUnsupportedMessage)
|
||||
}
|
||||
|
||||
func TestNewDNSRulePreservesImplicitTopLevelDefaultAction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var options option.DNSRule
|
||||
err := json.UnmarshalContext(context.Background(), []byte(`{
|
||||
"domain": "example.com"
|
||||
}`), &options)
|
||||
require.NoError(t, err)
|
||||
|
||||
rule, err := NewDNSRule(context.Background(), log.NewNOPFactory().NewLogger("dns"), options, false, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule.Action())
|
||||
require.Equal(t, C.RuleActionTypeRoute, rule.Action().Type())
|
||||
}
|
||||
|
||||
func TestNewDNSRuleAllowsNestedRuleWithoutAction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var options option.DNSRule
|
||||
err := json.UnmarshalContext(context.Background(), []byte(`{
|
||||
"type": "logical",
|
||||
"mode": "and",
|
||||
"rules": [
|
||||
{"domain": "example.com"}
|
||||
]
|
||||
}`), &options)
|
||||
require.NoError(t, err)
|
||||
|
||||
rule, err := NewDNSRule(context.Background(), log.NewNOPFactory().NewLogger("dns"), options, false, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule.Action())
|
||||
require.Equal(t, C.RuleActionTypeRoute, rule.Action().Type())
|
||||
}
|
||||
|
||||
func TestNewDNSRuleRejectsNestedRuleAction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := NewDNSRule(context.Background(), log.NewNOPFactory().NewLogger("dns"), option.DNSRule{
|
||||
Type: C.RuleTypeLogical,
|
||||
LogicalOptions: option.LogicalDNSRule{
|
||||
RawLogicalDNSRule: option.RawLogicalDNSRule{
|
||||
Mode: C.LogicalTypeAnd,
|
||||
Rules: []option.DNSRule{{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeRoute,
|
||||
RouteOptions: option.DNSRouteActionOptions{
|
||||
Server: "default",
|
||||
},
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeRoute,
|
||||
RouteOptions: option.DNSRouteActionOptions{
|
||||
Server: "default",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, true, false)
|
||||
require.ErrorContains(t, err, dnsRuleActionNestedUnsupportedMessage)
|
||||
}
|
||||
@@ -326,6 +326,10 @@ func NewLogicalRule(ctx context.Context, logger log.ContextLogger, options optio
|
||||
return nil, E.New("unknown logical mode: ", options.Mode)
|
||||
}
|
||||
for i, subOptions := range options.Rules {
|
||||
err = validateNoNestedRuleActions(subOptions, true)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "sub rule[", i, "]")
|
||||
}
|
||||
subRule, err := NewRule(ctx, logger, subOptions, false)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "sub rule[", i, "]")
|
||||
|
||||
@@ -467,6 +467,10 @@ func NewLogicalDNSRule(ctx context.Context, logger log.ContextLogger, options op
|
||||
return nil, E.New("unknown logical mode: ", options.Mode)
|
||||
}
|
||||
for i, subRule := range options.Rules {
|
||||
err := validateNoNestedDNSRuleActions(subRule, true)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "sub rule[", i, "]")
|
||||
}
|
||||
rule, err := NewDNSRule(ctx, logger, subRule, false, legacyDNSMode)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "sub rule[", i, "]")
|
||||
|
||||
@@ -25,16 +25,21 @@ func (r *ruleSetItemTestRouter) Close() error { return nil }
|
||||
func (r *ruleSetItemTestRouter) PreMatch(adapter.InboundContext, tun.DirectRouteContext, time.Duration, bool) (tun.DirectRouteDestination, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *ruleSetItemTestRouter) RouteConnection(context.Context, net.Conn, adapter.InboundContext) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *ruleSetItemTestRouter) RoutePacketConnection(context.Context, N.PacketConn, adapter.InboundContext) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *ruleSetItemTestRouter) RouteConnectionEx(context.Context, net.Conn, adapter.InboundContext, N.CloseHandlerFunc) {
|
||||
}
|
||||
|
||||
func (r *ruleSetItemTestRouter) RoutePacketConnectionEx(context.Context, N.PacketConn, adapter.InboundContext, N.CloseHandlerFunc) {
|
||||
}
|
||||
|
||||
func (r *ruleSetItemTestRouter) RuleSet(tag string) (adapter.RuleSet, bool) {
|
||||
ruleSet, loaded := r.ruleSets[tag]
|
||||
return ruleSet, loaded
|
||||
|
||||
Reference in New Issue
Block a user