From baf1da892b904b69c6b7a422a8ee5f75f446af5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 29 Mar 2026 01:34:27 +0800 Subject: [PATCH] option: reject nested rule actions --- dns/router.go | 6 + dns/router_test.go | 2 + option/dns_test.go | 1 + option/rule.go | 77 +++++++- option/rule_dns.go | 42 +++- option/rule_nested.go | 133 +++++++++++++ option/rule_nested_test.go | 271 ++++++++++++++++++++++++++ route/router.go | 4 + route/rule/nested_action.go | 76 ++++++++ route/rule/nested_action_test.go | 137 +++++++++++++ route/rule/rule_default.go | 4 + route/rule/rule_dns.go | 4 + route/rule/rule_item_rule_set_test.go | 5 + 13 files changed, 747 insertions(+), 15 deletions(-) create mode 100644 option/rule_nested.go create mode 100644 option/rule_nested_test.go create mode 100644 route/rule/nested_action.go create mode 100644 route/rule/nested_action_test.go diff --git a/dns/router.go b/dns/router.go index 13ffd0a0a..ae8aed987 100644 --- a/dns/router.go +++ b/dns/router.go @@ -202,6 +202,12 @@ func (r *Router) isClosing() bool { } func (r *Router) buildRules(startRules bool) ([]adapter.DNSRule, bool, error) { + for i, ruleOptions := range r.rawRules { + err := R.ValidateNoNestedDNSRuleActions(ruleOptions) + if err != nil { + return nil, false, E.Cause(err, "parse dns rule[", i, "]") + } + } router := service.FromContext[adapter.Router](r.ctx) legacyDNSMode, err := resolveLegacyDNSMode(router, r.rawRules) if err != nil { diff --git a/dns/router_test.go b/dns/router_test.go index 71cbfc960..7c3c4b5fb 100644 --- a/dns/router_test.go +++ b/dns/router_test.go @@ -135,6 +135,7 @@ func (s *fakeRuleSet) IncRef() { defer s.access.Unlock() s.refs++ } + func (s *fakeRuleSet) DecRef() { s.access.Lock() defer s.access.Unlock() @@ -149,6 +150,7 @@ func (s *fakeRuleSet) RegisterCallback(callback adapter.RuleSetUpdateCallback) * defer s.access.Unlock() return s.callbacks.PushBack(callback) } + func (s *fakeRuleSet) UnregisterCallback(element *list.Element[adapter.RuleSetUpdateCallback]) { s.access.Lock() defer s.access.Unlock() diff --git a/option/dns_test.go b/option/dns_test.go index 7448a9e2b..12ee0bca3 100644 --- a/option/dns_test.go +++ b/option/dns_test.go @@ -7,6 +7,7 @@ import ( C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing/common/json" "github.com/sagernet/sing/service" + "github.com/stretchr/testify/require" ) diff --git a/option/rule.go b/option/rule.go index b792ccf4b..9fd943797 100644 --- a/option/rule.go +++ b/option/rule.go @@ -1,6 +1,7 @@ package option import ( + "context" "reflect" C "github.com/sagernet/sing-box/constant" @@ -33,26 +34,24 @@ func (r Rule) MarshalJSON() ([]byte, error) { return badjson.MarshallObjects((_Rule)(r), v) } -func (r *Rule) UnmarshalJSON(bytes []byte) error { - err := json.Unmarshal(bytes, (*_Rule)(r)) +func (r *Rule) UnmarshalJSONContext(ctx context.Context, bytes []byte) error { + err := json.UnmarshalContext(ctx, bytes, (*_Rule)(r)) + if err != nil { + return err + } + payload, err := rulePayloadWithoutType(ctx, bytes) if err != nil { return err } - var v any switch r.Type { case "", C.RuleTypeDefault: r.Type = C.RuleTypeDefault - v = &r.DefaultOptions + return unmarshalDefaultRuleContext(ctx, payload, &r.DefaultOptions) case C.RuleTypeLogical: - v = &r.LogicalOptions + return unmarshalLogicalRuleContext(ctx, payload, &r.LogicalOptions) default: return E.New("unknown rule type: " + r.Type) } - err = badjson.UnmarshallExcluded(bytes, (*_Rule)(r), v) - if err != nil { - return err - } - return nil } func (r Rule) IsValid() bool { @@ -160,6 +159,64 @@ func (r *LogicalRule) UnmarshalJSON(data []byte) error { return badjson.UnmarshallExcluded(data, &r.RawLogicalRule, &r.RuleAction) } +func rulePayloadWithoutType(ctx context.Context, data []byte) ([]byte, error) { + var content badjson.JSONObject + err := content.UnmarshalJSONContext(ctx, data) + if err != nil { + return nil, err + } + content.Remove("type") + return content.MarshalJSONContext(ctx) +} + +func unmarshalDefaultRuleContext(ctx context.Context, data []byte, rule *DefaultRule) error { + rawAction, routeOptions, err := inspectRouteRuleAction(ctx, data) + if err != nil { + return err + } + err = rejectNestedRouteRuleAction(ctx, data) + if err != nil { + return err + } + depth := nestedRuleDepth(ctx) + err = json.UnmarshalContext(ctx, data, &rule.RawDefaultRule) + if err != nil { + return err + } + err = badjson.UnmarshallExcludedContext(ctx, data, &rule.RawDefaultRule, &rule.RuleAction) + if err != nil { + return err + } + if depth > 0 && rawAction == "" && routeOptions == (RouteActionOptions{}) { + rule.RuleAction = RuleAction{} + } + return nil +} + +func unmarshalLogicalRuleContext(ctx context.Context, data []byte, rule *LogicalRule) error { + rawAction, routeOptions, err := inspectRouteRuleAction(ctx, data) + if err != nil { + return err + } + err = rejectNestedRouteRuleAction(ctx, data) + if err != nil { + return err + } + depth := nestedRuleDepth(ctx) + err = json.UnmarshalContext(nestedRuleChildContext(ctx), data, &rule.RawLogicalRule) + if err != nil { + return err + } + err = badjson.UnmarshallExcludedContext(ctx, data, &rule.RawLogicalRule, &rule.RuleAction) + if err != nil { + return err + } + if depth > 0 && rawAction == "" && routeOptions == (RouteActionOptions{}) { + rule.RuleAction = RuleAction{} + } + return nil +} + func (r *LogicalRule) IsValid() bool { return len(r.Rules) > 0 && common.All(r.Rules, Rule.IsValid) } diff --git a/option/rule_dns.go b/option/rule_dns.go index 5a73d69fd..d1298635b 100644 --- a/option/rule_dns.go +++ b/option/rule_dns.go @@ -35,7 +35,7 @@ func (r DNSRule) MarshalJSON() ([]byte, error) { } func (r *DNSRule) UnmarshalJSONContext(ctx context.Context, bytes []byte) error { - err := json.Unmarshal(bytes, (*_DNSRule)(r)) + err := json.UnmarshalContext(ctx, bytes, (*_DNSRule)(r)) if err != nil { return err } @@ -135,11 +135,27 @@ func (r DefaultDNSRule) MarshalJSON() ([]byte, error) { } func (r *DefaultDNSRule) UnmarshalJSONContext(ctx context.Context, data []byte) error { - err := json.UnmarshalContext(ctx, data, &r.RawDefaultDNSRule) + rawAction, routeOptions, err := inspectDNSRuleAction(ctx, data) if err != nil { return err } - return badjson.UnmarshallExcludedContext(ctx, data, &r.RawDefaultDNSRule, &r.DNSRuleAction) + err = rejectNestedDNSRuleAction(ctx, data) + if err != nil { + return err + } + depth := nestedRuleDepth(ctx) + err = json.UnmarshalContext(ctx, data, &r.RawDefaultDNSRule) + if err != nil { + return err + } + err = badjson.UnmarshallExcludedContext(ctx, data, &r.RawDefaultDNSRule, &r.DNSRuleAction) + if err != nil { + return err + } + if depth > 0 && rawAction == "" && routeOptions == (DNSRouteActionOptions{}) { + r.DNSRuleAction = DNSRuleAction{} + } + return nil } func (r DefaultDNSRule) IsValid() bool { @@ -164,11 +180,27 @@ func (r LogicalDNSRule) MarshalJSON() ([]byte, error) { } func (r *LogicalDNSRule) UnmarshalJSONContext(ctx context.Context, data []byte) error { - err := json.Unmarshal(data, &r.RawLogicalDNSRule) + rawAction, routeOptions, err := inspectDNSRuleAction(ctx, data) if err != nil { return err } - return badjson.UnmarshallExcludedContext(ctx, data, &r.RawLogicalDNSRule, &r.DNSRuleAction) + err = rejectNestedDNSRuleAction(ctx, data) + if err != nil { + return err + } + depth := nestedRuleDepth(ctx) + err = json.UnmarshalContext(nestedRuleChildContext(ctx), data, &r.RawLogicalDNSRule) + if err != nil { + return err + } + err = badjson.UnmarshallExcludedContext(ctx, data, &r.RawLogicalDNSRule, &r.DNSRuleAction) + if err != nil { + return err + } + if depth > 0 && rawAction == "" && routeOptions == (DNSRouteActionOptions{}) { + r.DNSRuleAction = DNSRuleAction{} + } + return nil } func (r *LogicalDNSRule) IsValid() bool { diff --git a/option/rule_nested.go b/option/rule_nested.go new file mode 100644 index 000000000..c6038aad2 --- /dev/null +++ b/option/rule_nested.go @@ -0,0 +1,133 @@ +package option + +import ( + "context" + "reflect" + "strings" + + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/json" + "github.com/sagernet/sing/common/json/badjson" +) + +type nestedRuleDepthContextKey struct{} + +const ( + routeRuleActionNestedUnsupportedMessage = "rule action is not supported in nested rules" + dnsRuleActionNestedUnsupportedMessage = "DNS rule action is not supported in nested rules" +) + +var ( + routeRuleActionKeys = jsonFieldNames(reflect.TypeFor[_RuleAction](), reflect.TypeFor[RouteActionOptions]()) + dnsRuleActionKeys = jsonFieldNames(reflect.TypeFor[_DNSRuleAction](), reflect.TypeFor[DNSRouteActionOptions]()) +) + +func nestedRuleChildContext(ctx context.Context) context.Context { + return context.WithValue(ctx, nestedRuleDepthContextKey{}, nestedRuleDepth(ctx)+1) +} + +func rejectNestedRouteRuleAction(ctx context.Context, content []byte) error { + return rejectNestedRuleAction(ctx, content, routeRuleActionKeys, routeRuleActionNestedUnsupportedMessage) +} + +func rejectNestedDNSRuleAction(ctx context.Context, content []byte) error { + return rejectNestedRuleAction(ctx, content, dnsRuleActionKeys, dnsRuleActionNestedUnsupportedMessage) +} + +func nestedRuleDepth(ctx context.Context) int { + depth, _ := ctx.Value(nestedRuleDepthContextKey{}).(int) + return depth +} + +func rejectNestedRuleAction(ctx context.Context, content []byte, keys []string, message string) error { + if nestedRuleDepth(ctx) == 0 { + return nil + } + hasActionKey, err := hasAnyJSONKey(ctx, content, keys...) + if err != nil { + return err + } + if hasActionKey { + return E.New(message) + } + return nil +} + +func hasAnyJSONKey(ctx context.Context, content []byte, keys ...string) (bool, error) { + var object badjson.JSONObject + err := object.UnmarshalJSONContext(ctx, content) + if err != nil { + return false, err + } + for _, key := range keys { + if object.ContainsKey(key) { + return true, nil + } + } + return false, nil +} + +func inspectRouteRuleAction(ctx context.Context, content []byte) (string, RouteActionOptions, error) { + var rawAction _RuleAction + err := json.UnmarshalContext(ctx, content, &rawAction) + if err != nil { + return "", RouteActionOptions{}, err + } + var routeOptions RouteActionOptions + err = json.UnmarshalContext(ctx, content, &routeOptions) + if err != nil { + return "", RouteActionOptions{}, err + } + return rawAction.Action, routeOptions, nil +} + +func inspectDNSRuleAction(ctx context.Context, content []byte) (string, DNSRouteActionOptions, error) { + var rawAction _DNSRuleAction + err := json.UnmarshalContext(ctx, content, &rawAction) + if err != nil { + return "", DNSRouteActionOptions{}, err + } + var routeOptions DNSRouteActionOptions + err = json.UnmarshalContext(ctx, content, &routeOptions) + if err != nil { + return "", DNSRouteActionOptions{}, err + } + return rawAction.Action, routeOptions, nil +} + +func jsonFieldNames(types ...reflect.Type) []string { + fieldMap := make(map[string]struct{}) + for _, fieldType := range types { + appendJSONFieldNames(fieldMap, fieldType) + } + fieldNames := make([]string, 0, len(fieldMap)) + for fieldName := range fieldMap { + fieldNames = append(fieldNames, fieldName) + } + return fieldNames +} + +func appendJSONFieldNames(fieldMap map[string]struct{}, fieldType reflect.Type) { + for fieldType.Kind() == reflect.Pointer { + fieldType = fieldType.Elem() + } + if fieldType.Kind() != reflect.Struct { + return + } + for i := range fieldType.NumField() { + field := fieldType.Field(i) + tagValue := field.Tag.Get("json") + tagName, _, _ := strings.Cut(tagValue, ",") + if tagName == "-" { + continue + } + if field.Anonymous && tagName == "" { + appendJSONFieldNames(fieldMap, field.Type) + continue + } + if tagName == "" { + tagName = field.Name + } + fieldMap[tagName] = struct{}{} + } +} diff --git a/option/rule_nested_test.go b/option/rule_nested_test.go new file mode 100644 index 000000000..ca0ff32df --- /dev/null +++ b/option/rule_nested_test.go @@ -0,0 +1,271 @@ +package option + +import ( + "context" + "testing" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing/common/json" + + "github.com/stretchr/testify/require" +) + +func TestRuleRejectsNestedDefaultRuleAction(t *testing.T) { + t.Parallel() + + var rule Rule + err := json.UnmarshalContext(context.Background(), []byte(`{ + "type": "logical", + "mode": "and", + "rules": [ + {"domain": "example.com", "outbound": "direct"} + ] + }`), &rule) + require.ErrorContains(t, err, routeRuleActionNestedUnsupportedMessage) +} + +func TestRuleRejectsNestedLogicalRuleAction(t *testing.T) { + t.Parallel() + + var rule Rule + err := json.UnmarshalContext(context.Background(), []byte(`{ + "type": "logical", + "mode": "and", + "rules": [ + { + "type": "logical", + "mode": "or", + "action": "route", + "outbound": "direct", + "rules": [{"domain": "example.com"}] + } + ] + }`), &rule) + require.ErrorContains(t, err, routeRuleActionNestedUnsupportedMessage) +} + +func TestRuleRejectsNestedDefaultRuleZeroValueOutbound(t *testing.T) { + t.Parallel() + + var rule Rule + err := json.UnmarshalContext(context.Background(), []byte(`{ + "type": "logical", + "mode": "and", + "rules": [ + {"domain": "example.com", "outbound": ""} + ] + }`), &rule) + require.ErrorContains(t, err, routeRuleActionNestedUnsupportedMessage) +} + +func TestRuleRejectsNestedDefaultRuleZeroValueRouteOption(t *testing.T) { + t.Parallel() + + var rule Rule + err := json.UnmarshalContext(context.Background(), []byte(`{ + "type": "logical", + "mode": "and", + "rules": [ + {"domain": "example.com", "udp_connect": false} + ] + }`), &rule) + require.ErrorContains(t, err, routeRuleActionNestedUnsupportedMessage) +} + +func TestRuleRejectsNestedLogicalRuleZeroValueAction(t *testing.T) { + t.Parallel() + + var rule Rule + err := json.UnmarshalContext(context.Background(), []byte(`{ + "type": "logical", + "mode": "and", + "rules": [ + { + "type": "logical", + "mode": "or", + "action": "", + "rules": [{"domain": "example.com"}] + } + ] + }`), &rule) + require.ErrorContains(t, err, routeRuleActionNestedUnsupportedMessage) +} + +func TestRuleRejectsNestedLogicalRuleZeroValueRouteOption(t *testing.T) { + t.Parallel() + + var rule Rule + err := json.UnmarshalContext(context.Background(), []byte(`{ + "type": "logical", + "mode": "and", + "rules": [ + { + "type": "logical", + "mode": "or", + "override_port": 0, + "rules": [{"domain": "example.com"}] + } + ] + }`), &rule) + require.ErrorContains(t, err, routeRuleActionNestedUnsupportedMessage) +} + +func TestRuleAllowsTopLevelLogicalAction(t *testing.T) { + t.Parallel() + + var rule Rule + err := json.UnmarshalContext(context.Background(), []byte(`{ + "type": "logical", + "mode": "and", + "outbound": "direct", + "rules": [{"domain": "example.com"}] + }`), &rule) + require.NoError(t, err) + require.Equal(t, C.RuleActionTypeRoute, rule.LogicalOptions.Action) + require.Equal(t, "direct", rule.LogicalOptions.RouteOptions.Outbound) +} + +func TestRuleLeavesUnknownNestedKeysToNormalValidation(t *testing.T) { + t.Parallel() + + var rule Rule + err := json.UnmarshalContext(context.Background(), []byte(`{ + "type": "logical", + "mode": "and", + "rules": [ + {"domain": "example.com", "foo": "bar"} + ] + }`), &rule) + require.ErrorContains(t, err, "unknown field") + require.NotContains(t, err.Error(), routeRuleActionNestedUnsupportedMessage) +} + +func TestDNSRuleRejectsNestedDefaultRuleAction(t *testing.T) { + t.Parallel() + + var rule DNSRule + err := json.UnmarshalContext(context.Background(), []byte(`{ + "type": "logical", + "mode": "and", + "rules": [ + {"domain": "example.com", "server": "default"} + ] + }`), &rule) + require.ErrorContains(t, err, dnsRuleActionNestedUnsupportedMessage) +} + +func TestDNSRuleRejectsNestedLogicalRuleAction(t *testing.T) { + t.Parallel() + + var rule DNSRule + err := json.UnmarshalContext(context.Background(), []byte(`{ + "type": "logical", + "mode": "and", + "rules": [ + { + "type": "logical", + "mode": "or", + "action": "route", + "server": "default", + "rules": [{"domain": "example.com"}] + } + ] + }`), &rule) + require.ErrorContains(t, err, dnsRuleActionNestedUnsupportedMessage) +} + +func TestDNSRuleRejectsNestedDefaultRuleZeroValueServer(t *testing.T) { + t.Parallel() + + var rule DNSRule + err := json.UnmarshalContext(context.Background(), []byte(`{ + "type": "logical", + "mode": "and", + "rules": [ + {"domain": "example.com", "server": ""} + ] + }`), &rule) + require.ErrorContains(t, err, dnsRuleActionNestedUnsupportedMessage) +} + +func TestDNSRuleRejectsNestedDefaultRuleZeroValueRouteOption(t *testing.T) { + t.Parallel() + + var rule DNSRule + err := json.UnmarshalContext(context.Background(), []byte(`{ + "type": "logical", + "mode": "and", + "rules": [ + {"domain": "example.com", "disable_cache": false} + ] + }`), &rule) + require.ErrorContains(t, err, dnsRuleActionNestedUnsupportedMessage) +} + +func TestDNSRuleRejectsNestedLogicalRuleZeroValueAction(t *testing.T) { + t.Parallel() + + var rule DNSRule + err := json.UnmarshalContext(context.Background(), []byte(`{ + "type": "logical", + "mode": "and", + "rules": [ + { + "type": "logical", + "mode": "or", + "action": "", + "rules": [{"domain": "example.com"}] + } + ] + }`), &rule) + require.ErrorContains(t, err, dnsRuleActionNestedUnsupportedMessage) +} + +func TestDNSRuleRejectsNestedLogicalRuleZeroValueRouteOption(t *testing.T) { + t.Parallel() + + var rule DNSRule + err := json.UnmarshalContext(context.Background(), []byte(`{ + "type": "logical", + "mode": "and", + "rules": [ + { + "type": "logical", + "mode": "or", + "disable_cache": false, + "rules": [{"domain": "example.com"}] + } + ] + }`), &rule) + require.ErrorContains(t, err, dnsRuleActionNestedUnsupportedMessage) +} + +func TestDNSRuleAllowsTopLevelLogicalAction(t *testing.T) { + t.Parallel() + + var rule DNSRule + err := json.UnmarshalContext(context.Background(), []byte(`{ + "type": "logical", + "mode": "and", + "server": "default", + "rules": [{"domain": "example.com"}] + }`), &rule) + require.NoError(t, err) + require.Equal(t, C.RuleActionTypeRoute, rule.LogicalOptions.Action) + require.Equal(t, "default", rule.LogicalOptions.RouteOptions.Server) +} + +func TestDNSRuleLeavesUnknownNestedKeysToNormalValidation(t *testing.T) { + t.Parallel() + + var rule DNSRule + err := json.UnmarshalContext(context.Background(), []byte(`{ + "type": "logical", + "mode": "and", + "rules": [ + {"domain": "example.com", "foo": "bar"} + ] + }`), &rule) + require.ErrorContains(t, err, "unknown field") + require.NotContains(t, err.Error(), dnsRuleActionNestedUnsupportedMessage) +} diff --git a/route/router.go b/route/router.go index 2815d5095..03546b2a7 100644 --- a/route/router.go +++ b/route/router.go @@ -70,6 +70,10 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.Route func (r *Router) Initialize(rules []option.Rule, ruleSets []option.RuleSet) error { for i, options := range rules { + err := R.ValidateNoNestedRuleActions(options) + if err != nil { + return E.Cause(err, "parse rule[", i, "]") + } rule, err := R.NewRule(r.ctx, r.logger, options, false) if err != nil { return E.Cause(err, "parse rule[", i, "]") diff --git a/route/rule/nested_action.go b/route/rule/nested_action.go new file mode 100644 index 000000000..95bb57215 --- /dev/null +++ b/route/rule/nested_action.go @@ -0,0 +1,76 @@ +package rule + +import ( + "reflect" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" +) + +const ( + routeRuleActionNestedUnsupportedMessage = "rule action is not supported in nested rules" + dnsRuleActionNestedUnsupportedMessage = "DNS rule action is not supported in nested rules" +) + +func ValidateNoNestedRuleActions(rule option.Rule) error { + return validateNoNestedRuleActions(rule, false) +} + +func ValidateNoNestedDNSRuleActions(rule option.DNSRule) error { + return validateNoNestedDNSRuleActions(rule, false) +} + +func validateNoNestedRuleActions(rule option.Rule, nested bool) error { + if nested && ruleHasConfiguredAction(rule) { + return E.New(routeRuleActionNestedUnsupportedMessage) + } + if rule.Type != C.RuleTypeLogical { + return nil + } + for i, subRule := range rule.LogicalOptions.Rules { + err := validateNoNestedRuleActions(subRule, true) + if err != nil { + return E.Cause(err, "sub rule[", i, "]") + } + } + return nil +} + +func validateNoNestedDNSRuleActions(rule option.DNSRule, nested bool) error { + if nested && dnsRuleHasConfiguredAction(rule) { + return E.New(dnsRuleActionNestedUnsupportedMessage) + } + if rule.Type != C.RuleTypeLogical { + return nil + } + for i, subRule := range rule.LogicalOptions.Rules { + err := validateNoNestedDNSRuleActions(subRule, true) + if err != nil { + return E.Cause(err, "sub rule[", i, "]") + } + } + return nil +} + +func ruleHasConfiguredAction(rule option.Rule) bool { + switch rule.Type { + case "", C.RuleTypeDefault: + return !reflect.DeepEqual(rule.DefaultOptions.RuleAction, option.RuleAction{}) + case C.RuleTypeLogical: + return !reflect.DeepEqual(rule.LogicalOptions.RuleAction, option.RuleAction{}) + default: + return false + } +} + +func dnsRuleHasConfiguredAction(rule option.DNSRule) bool { + switch rule.Type { + case "", C.RuleTypeDefault: + return !reflect.DeepEqual(rule.DefaultOptions.DNSRuleAction, option.DNSRuleAction{}) + case C.RuleTypeLogical: + return !reflect.DeepEqual(rule.LogicalOptions.DNSRuleAction, option.DNSRuleAction{}) + default: + return false + } +} diff --git a/route/rule/nested_action_test.go b/route/rule/nested_action_test.go new file mode 100644 index 000000000..b21828c04 --- /dev/null +++ b/route/rule/nested_action_test.go @@ -0,0 +1,137 @@ +package rule + +import ( + "context" + "testing" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common/json" + + "github.com/stretchr/testify/require" +) + +func TestNewRulePreservesImplicitTopLevelDefaultAction(t *testing.T) { + t.Parallel() + + var options option.Rule + err := json.UnmarshalContext(context.Background(), []byte(`{ + "domain": "example.com" + }`), &options) + require.NoError(t, err) + + rule, err := NewRule(context.Background(), log.NewNOPFactory().NewLogger("router"), options, false) + require.NoError(t, err) + require.NotNil(t, rule.Action()) + require.Equal(t, C.RuleActionTypeRoute, rule.Action().Type()) +} + +func TestNewRuleAllowsNestedRuleWithoutAction(t *testing.T) { + t.Parallel() + + var options option.Rule + err := json.UnmarshalContext(context.Background(), []byte(`{ + "type": "logical", + "mode": "and", + "rules": [ + {"domain": "example.com"} + ] + }`), &options) + require.NoError(t, err) + + rule, err := NewRule(context.Background(), log.NewNOPFactory().NewLogger("router"), options, false) + require.NoError(t, err) + require.NotNil(t, rule.Action()) + require.Equal(t, C.RuleActionTypeRoute, rule.Action().Type()) +} + +func TestNewRuleRejectsNestedRuleAction(t *testing.T) { + t.Parallel() + + _, err := NewRule(context.Background(), log.NewNOPFactory().NewLogger("router"), option.Rule{ + Type: C.RuleTypeLogical, + LogicalOptions: option.LogicalRule{ + RawLogicalRule: option.RawLogicalRule{ + Mode: C.LogicalTypeAnd, + Rules: []option.Rule{{ + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultRule{ + RuleAction: option.RuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.RouteActionOptions{ + Outbound: "direct", + }, + }, + }, + }}, + }, + }, + }, false) + require.ErrorContains(t, err, routeRuleActionNestedUnsupportedMessage) +} + +func TestNewDNSRulePreservesImplicitTopLevelDefaultAction(t *testing.T) { + t.Parallel() + + var options option.DNSRule + err := json.UnmarshalContext(context.Background(), []byte(`{ + "domain": "example.com" + }`), &options) + require.NoError(t, err) + + rule, err := NewDNSRule(context.Background(), log.NewNOPFactory().NewLogger("dns"), options, false, false) + require.NoError(t, err) + require.NotNil(t, rule.Action()) + require.Equal(t, C.RuleActionTypeRoute, rule.Action().Type()) +} + +func TestNewDNSRuleAllowsNestedRuleWithoutAction(t *testing.T) { + t.Parallel() + + var options option.DNSRule + err := json.UnmarshalContext(context.Background(), []byte(`{ + "type": "logical", + "mode": "and", + "rules": [ + {"domain": "example.com"} + ] + }`), &options) + require.NoError(t, err) + + rule, err := NewDNSRule(context.Background(), log.NewNOPFactory().NewLogger("dns"), options, false, false) + require.NoError(t, err) + require.NotNil(t, rule.Action()) + require.Equal(t, C.RuleActionTypeRoute, rule.Action().Type()) +} + +func TestNewDNSRuleRejectsNestedRuleAction(t *testing.T) { + t.Parallel() + + _, err := NewDNSRule(context.Background(), log.NewNOPFactory().NewLogger("dns"), option.DNSRule{ + Type: C.RuleTypeLogical, + LogicalOptions: option.LogicalDNSRule{ + RawLogicalDNSRule: option.RawLogicalDNSRule{ + Mode: C.LogicalTypeAnd, + Rules: []option.DNSRule{{ + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{ + Server: "default", + }, + }, + }, + }}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{ + Server: "default", + }, + }, + }, + }, true, false) + require.ErrorContains(t, err, dnsRuleActionNestedUnsupportedMessage) +} diff --git a/route/rule/rule_default.go b/route/rule/rule_default.go index 5ce1f87d4..d4de6ff7a 100644 --- a/route/rule/rule_default.go +++ b/route/rule/rule_default.go @@ -326,6 +326,10 @@ func NewLogicalRule(ctx context.Context, logger log.ContextLogger, options optio return nil, E.New("unknown logical mode: ", options.Mode) } for i, subOptions := range options.Rules { + err = validateNoNestedRuleActions(subOptions, true) + if err != nil { + return nil, E.Cause(err, "sub rule[", i, "]") + } subRule, err := NewRule(ctx, logger, subOptions, false) if err != nil { return nil, E.Cause(err, "sub rule[", i, "]") diff --git a/route/rule/rule_dns.go b/route/rule/rule_dns.go index 0cff49221..ebc556c0e 100644 --- a/route/rule/rule_dns.go +++ b/route/rule/rule_dns.go @@ -467,6 +467,10 @@ func NewLogicalDNSRule(ctx context.Context, logger log.ContextLogger, options op return nil, E.New("unknown logical mode: ", options.Mode) } for i, subRule := range options.Rules { + err := validateNoNestedDNSRuleActions(subRule, true) + if err != nil { + return nil, E.Cause(err, "sub rule[", i, "]") + } rule, err := NewDNSRule(ctx, logger, subRule, false, legacyDNSMode) if err != nil { return nil, E.Cause(err, "sub rule[", i, "]") diff --git a/route/rule/rule_item_rule_set_test.go b/route/rule/rule_item_rule_set_test.go index 3d5959a39..a73ed91f5 100644 --- a/route/rule/rule_item_rule_set_test.go +++ b/route/rule/rule_item_rule_set_test.go @@ -25,16 +25,21 @@ func (r *ruleSetItemTestRouter) Close() error { return nil } func (r *ruleSetItemTestRouter) PreMatch(adapter.InboundContext, tun.DirectRouteContext, time.Duration, bool) (tun.DirectRouteDestination, error) { return nil, nil } + func (r *ruleSetItemTestRouter) RouteConnection(context.Context, net.Conn, adapter.InboundContext) error { return nil } + func (r *ruleSetItemTestRouter) RoutePacketConnection(context.Context, N.PacketConn, adapter.InboundContext) error { return nil } + func (r *ruleSetItemTestRouter) RouteConnectionEx(context.Context, net.Conn, adapter.InboundContext, N.CloseHandlerFunc) { } + func (r *ruleSetItemTestRouter) RoutePacketConnectionEx(context.Context, N.PacketConn, adapter.InboundContext, N.CloseHandlerFunc) { } + func (r *ruleSetItemTestRouter) RuleSet(tag string) (adapter.RuleSet, bool) { ruleSet, loaded := r.ruleSets[tag] return ruleSet, loaded