option: reject nested rule actions

This commit is contained in:
世界
2026-03-29 01:34:27 +08:00
parent 385cd703d4
commit 4cfc1c6fbf
13 changed files with 747 additions and 15 deletions

View File

@@ -202,6 +202,12 @@ func (r *Router) isClosing() bool {
}
func (r *Router) buildRules(startRules bool) ([]adapter.DNSRule, bool, error) {
for i, ruleOptions := range r.rawRules {
err := R.ValidateNoNestedDNSRuleActions(ruleOptions)
if err != nil {
return nil, false, E.Cause(err, "parse dns rule[", i, "]")
}
}
router := service.FromContext[adapter.Router](r.ctx)
legacyDNSMode, err := resolveLegacyDNSMode(router, r.rawRules)
if err != nil {

View File

@@ -135,6 +135,7 @@ func (s *fakeRuleSet) IncRef() {
defer s.access.Unlock()
s.refs++
}
func (s *fakeRuleSet) DecRef() {
s.access.Lock()
defer s.access.Unlock()
@@ -149,6 +150,7 @@ func (s *fakeRuleSet) RegisterCallback(callback adapter.RuleSetUpdateCallback) *
defer s.access.Unlock()
return s.callbacks.PushBack(callback)
}
func (s *fakeRuleSet) UnregisterCallback(element *list.Element[adapter.RuleSetUpdateCallback]) {
s.access.Lock()
defer s.access.Unlock()

View File

@@ -7,6 +7,7 @@ import (
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing/common/json"
"github.com/sagernet/sing/service"
"github.com/stretchr/testify/require"
)

View File

@@ -1,6 +1,7 @@
package option
import (
"context"
"reflect"
C "github.com/sagernet/sing-box/constant"
@@ -33,26 +34,24 @@ func (r Rule) MarshalJSON() ([]byte, error) {
return badjson.MarshallObjects((_Rule)(r), v)
}
func (r *Rule) UnmarshalJSON(bytes []byte) error {
err := json.Unmarshal(bytes, (*_Rule)(r))
func (r *Rule) UnmarshalJSONContext(ctx context.Context, bytes []byte) error {
err := json.UnmarshalContext(ctx, bytes, (*_Rule)(r))
if err != nil {
return err
}
payload, err := rulePayloadWithoutType(ctx, bytes)
if err != nil {
return err
}
var v any
switch r.Type {
case "", C.RuleTypeDefault:
r.Type = C.RuleTypeDefault
v = &r.DefaultOptions
return unmarshalDefaultRuleContext(ctx, payload, &r.DefaultOptions)
case C.RuleTypeLogical:
v = &r.LogicalOptions
return unmarshalLogicalRuleContext(ctx, payload, &r.LogicalOptions)
default:
return E.New("unknown rule type: " + r.Type)
}
err = badjson.UnmarshallExcluded(bytes, (*_Rule)(r), v)
if err != nil {
return err
}
return nil
}
func (r Rule) IsValid() bool {
@@ -160,6 +159,64 @@ func (r *LogicalRule) UnmarshalJSON(data []byte) error {
return badjson.UnmarshallExcluded(data, &r.RawLogicalRule, &r.RuleAction)
}
func rulePayloadWithoutType(ctx context.Context, data []byte) ([]byte, error) {
var content badjson.JSONObject
err := content.UnmarshalJSONContext(ctx, data)
if err != nil {
return nil, err
}
content.Remove("type")
return content.MarshalJSONContext(ctx)
}
func unmarshalDefaultRuleContext(ctx context.Context, data []byte, rule *DefaultRule) error {
rawAction, routeOptions, err := inspectRouteRuleAction(ctx, data)
if err != nil {
return err
}
err = rejectNestedRouteRuleAction(ctx, data)
if err != nil {
return err
}
depth := nestedRuleDepth(ctx)
err = json.UnmarshalContext(ctx, data, &rule.RawDefaultRule)
if err != nil {
return err
}
err = badjson.UnmarshallExcludedContext(ctx, data, &rule.RawDefaultRule, &rule.RuleAction)
if err != nil {
return err
}
if depth > 0 && rawAction == "" && routeOptions == (RouteActionOptions{}) {
rule.RuleAction = RuleAction{}
}
return nil
}
func unmarshalLogicalRuleContext(ctx context.Context, data []byte, rule *LogicalRule) error {
rawAction, routeOptions, err := inspectRouteRuleAction(ctx, data)
if err != nil {
return err
}
err = rejectNestedRouteRuleAction(ctx, data)
if err != nil {
return err
}
depth := nestedRuleDepth(ctx)
err = json.UnmarshalContext(nestedRuleChildContext(ctx), data, &rule.RawLogicalRule)
if err != nil {
return err
}
err = badjson.UnmarshallExcludedContext(ctx, data, &rule.RawLogicalRule, &rule.RuleAction)
if err != nil {
return err
}
if depth > 0 && rawAction == "" && routeOptions == (RouteActionOptions{}) {
rule.RuleAction = RuleAction{}
}
return nil
}
func (r *LogicalRule) IsValid() bool {
return len(r.Rules) > 0 && common.All(r.Rules, Rule.IsValid)
}

View File

@@ -35,7 +35,7 @@ func (r DNSRule) MarshalJSON() ([]byte, error) {
}
func (r *DNSRule) UnmarshalJSONContext(ctx context.Context, bytes []byte) error {
err := json.Unmarshal(bytes, (*_DNSRule)(r))
err := json.UnmarshalContext(ctx, bytes, (*_DNSRule)(r))
if err != nil {
return err
}
@@ -135,11 +135,27 @@ func (r DefaultDNSRule) MarshalJSON() ([]byte, error) {
}
func (r *DefaultDNSRule) UnmarshalJSONContext(ctx context.Context, data []byte) error {
err := json.UnmarshalContext(ctx, data, &r.RawDefaultDNSRule)
rawAction, routeOptions, err := inspectDNSRuleAction(ctx, data)
if err != nil {
return err
}
return badjson.UnmarshallExcludedContext(ctx, data, &r.RawDefaultDNSRule, &r.DNSRuleAction)
err = rejectNestedDNSRuleAction(ctx, data)
if err != nil {
return err
}
depth := nestedRuleDepth(ctx)
err = json.UnmarshalContext(ctx, data, &r.RawDefaultDNSRule)
if err != nil {
return err
}
err = badjson.UnmarshallExcludedContext(ctx, data, &r.RawDefaultDNSRule, &r.DNSRuleAction)
if err != nil {
return err
}
if depth > 0 && rawAction == "" && routeOptions == (DNSRouteActionOptions{}) {
r.DNSRuleAction = DNSRuleAction{}
}
return nil
}
func (r DefaultDNSRule) IsValid() bool {
@@ -164,11 +180,27 @@ func (r LogicalDNSRule) MarshalJSON() ([]byte, error) {
}
func (r *LogicalDNSRule) UnmarshalJSONContext(ctx context.Context, data []byte) error {
err := json.Unmarshal(data, &r.RawLogicalDNSRule)
rawAction, routeOptions, err := inspectDNSRuleAction(ctx, data)
if err != nil {
return err
}
return badjson.UnmarshallExcludedContext(ctx, data, &r.RawLogicalDNSRule, &r.DNSRuleAction)
err = rejectNestedDNSRuleAction(ctx, data)
if err != nil {
return err
}
depth := nestedRuleDepth(ctx)
err = json.UnmarshalContext(nestedRuleChildContext(ctx), data, &r.RawLogicalDNSRule)
if err != nil {
return err
}
err = badjson.UnmarshallExcludedContext(ctx, data, &r.RawLogicalDNSRule, &r.DNSRuleAction)
if err != nil {
return err
}
if depth > 0 && rawAction == "" && routeOptions == (DNSRouteActionOptions{}) {
r.DNSRuleAction = DNSRuleAction{}
}
return nil
}
func (r *LogicalDNSRule) IsValid() bool {

133
option/rule_nested.go Normal file
View File

@@ -0,0 +1,133 @@
package option
import (
"context"
"reflect"
"strings"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
"github.com/sagernet/sing/common/json/badjson"
)
type nestedRuleDepthContextKey struct{}
const (
routeRuleActionNestedUnsupportedMessage = "rule action is not supported in nested rules"
dnsRuleActionNestedUnsupportedMessage = "DNS rule action is not supported in nested rules"
)
var (
routeRuleActionKeys = jsonFieldNames(reflect.TypeFor[_RuleAction](), reflect.TypeFor[RouteActionOptions]())
dnsRuleActionKeys = jsonFieldNames(reflect.TypeFor[_DNSRuleAction](), reflect.TypeFor[DNSRouteActionOptions]())
)
func nestedRuleChildContext(ctx context.Context) context.Context {
return context.WithValue(ctx, nestedRuleDepthContextKey{}, nestedRuleDepth(ctx)+1)
}
func rejectNestedRouteRuleAction(ctx context.Context, content []byte) error {
return rejectNestedRuleAction(ctx, content, routeRuleActionKeys, routeRuleActionNestedUnsupportedMessage)
}
func rejectNestedDNSRuleAction(ctx context.Context, content []byte) error {
return rejectNestedRuleAction(ctx, content, dnsRuleActionKeys, dnsRuleActionNestedUnsupportedMessage)
}
func nestedRuleDepth(ctx context.Context) int {
depth, _ := ctx.Value(nestedRuleDepthContextKey{}).(int)
return depth
}
func rejectNestedRuleAction(ctx context.Context, content []byte, keys []string, message string) error {
if nestedRuleDepth(ctx) == 0 {
return nil
}
hasActionKey, err := hasAnyJSONKey(ctx, content, keys...)
if err != nil {
return err
}
if hasActionKey {
return E.New(message)
}
return nil
}
func hasAnyJSONKey(ctx context.Context, content []byte, keys ...string) (bool, error) {
var object badjson.JSONObject
err := object.UnmarshalJSONContext(ctx, content)
if err != nil {
return false, err
}
for _, key := range keys {
if object.ContainsKey(key) {
return true, nil
}
}
return false, nil
}
func inspectRouteRuleAction(ctx context.Context, content []byte) (string, RouteActionOptions, error) {
var rawAction _RuleAction
err := json.UnmarshalContext(ctx, content, &rawAction)
if err != nil {
return "", RouteActionOptions{}, err
}
var routeOptions RouteActionOptions
err = json.UnmarshalContext(ctx, content, &routeOptions)
if err != nil {
return "", RouteActionOptions{}, err
}
return rawAction.Action, routeOptions, nil
}
func inspectDNSRuleAction(ctx context.Context, content []byte) (string, DNSRouteActionOptions, error) {
var rawAction _DNSRuleAction
err := json.UnmarshalContext(ctx, content, &rawAction)
if err != nil {
return "", DNSRouteActionOptions{}, err
}
var routeOptions DNSRouteActionOptions
err = json.UnmarshalContext(ctx, content, &routeOptions)
if err != nil {
return "", DNSRouteActionOptions{}, err
}
return rawAction.Action, routeOptions, nil
}
func jsonFieldNames(types ...reflect.Type) []string {
fieldMap := make(map[string]struct{})
for _, fieldType := range types {
appendJSONFieldNames(fieldMap, fieldType)
}
fieldNames := make([]string, 0, len(fieldMap))
for fieldName := range fieldMap {
fieldNames = append(fieldNames, fieldName)
}
return fieldNames
}
func appendJSONFieldNames(fieldMap map[string]struct{}, fieldType reflect.Type) {
for fieldType.Kind() == reflect.Pointer {
fieldType = fieldType.Elem()
}
if fieldType.Kind() != reflect.Struct {
return
}
for i := range fieldType.NumField() {
field := fieldType.Field(i)
tagValue := field.Tag.Get("json")
tagName, _, _ := strings.Cut(tagValue, ",")
if tagName == "-" {
continue
}
if field.Anonymous && tagName == "" {
appendJSONFieldNames(fieldMap, field.Type)
continue
}
if tagName == "" {
tagName = field.Name
}
fieldMap[tagName] = struct{}{}
}
}

271
option/rule_nested_test.go Normal file
View File

@@ -0,0 +1,271 @@
package option
import (
"context"
"testing"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing/common/json"
"github.com/stretchr/testify/require"
)
func TestRuleRejectsNestedDefaultRuleAction(t *testing.T) {
t.Parallel()
var rule Rule
err := json.UnmarshalContext(context.Background(), []byte(`{
"type": "logical",
"mode": "and",
"rules": [
{"domain": "example.com", "outbound": "direct"}
]
}`), &rule)
require.ErrorContains(t, err, routeRuleActionNestedUnsupportedMessage)
}
func TestRuleRejectsNestedLogicalRuleAction(t *testing.T) {
t.Parallel()
var rule Rule
err := json.UnmarshalContext(context.Background(), []byte(`{
"type": "logical",
"mode": "and",
"rules": [
{
"type": "logical",
"mode": "or",
"action": "route",
"outbound": "direct",
"rules": [{"domain": "example.com"}]
}
]
}`), &rule)
require.ErrorContains(t, err, routeRuleActionNestedUnsupportedMessage)
}
func TestRuleRejectsNestedDefaultRuleZeroValueOutbound(t *testing.T) {
t.Parallel()
var rule Rule
err := json.UnmarshalContext(context.Background(), []byte(`{
"type": "logical",
"mode": "and",
"rules": [
{"domain": "example.com", "outbound": ""}
]
}`), &rule)
require.ErrorContains(t, err, routeRuleActionNestedUnsupportedMessage)
}
func TestRuleRejectsNestedDefaultRuleZeroValueRouteOption(t *testing.T) {
t.Parallel()
var rule Rule
err := json.UnmarshalContext(context.Background(), []byte(`{
"type": "logical",
"mode": "and",
"rules": [
{"domain": "example.com", "udp_connect": false}
]
}`), &rule)
require.ErrorContains(t, err, routeRuleActionNestedUnsupportedMessage)
}
func TestRuleRejectsNestedLogicalRuleZeroValueAction(t *testing.T) {
t.Parallel()
var rule Rule
err := json.UnmarshalContext(context.Background(), []byte(`{
"type": "logical",
"mode": "and",
"rules": [
{
"type": "logical",
"mode": "or",
"action": "",
"rules": [{"domain": "example.com"}]
}
]
}`), &rule)
require.ErrorContains(t, err, routeRuleActionNestedUnsupportedMessage)
}
func TestRuleRejectsNestedLogicalRuleZeroValueRouteOption(t *testing.T) {
t.Parallel()
var rule Rule
err := json.UnmarshalContext(context.Background(), []byte(`{
"type": "logical",
"mode": "and",
"rules": [
{
"type": "logical",
"mode": "or",
"override_port": 0,
"rules": [{"domain": "example.com"}]
}
]
}`), &rule)
require.ErrorContains(t, err, routeRuleActionNestedUnsupportedMessage)
}
func TestRuleAllowsTopLevelLogicalAction(t *testing.T) {
t.Parallel()
var rule Rule
err := json.UnmarshalContext(context.Background(), []byte(`{
"type": "logical",
"mode": "and",
"outbound": "direct",
"rules": [{"domain": "example.com"}]
}`), &rule)
require.NoError(t, err)
require.Equal(t, C.RuleActionTypeRoute, rule.LogicalOptions.Action)
require.Equal(t, "direct", rule.LogicalOptions.RouteOptions.Outbound)
}
func TestRuleLeavesUnknownNestedKeysToNormalValidation(t *testing.T) {
t.Parallel()
var rule Rule
err := json.UnmarshalContext(context.Background(), []byte(`{
"type": "logical",
"mode": "and",
"rules": [
{"domain": "example.com", "foo": "bar"}
]
}`), &rule)
require.ErrorContains(t, err, "unknown field")
require.NotContains(t, err.Error(), routeRuleActionNestedUnsupportedMessage)
}
func TestDNSRuleRejectsNestedDefaultRuleAction(t *testing.T) {
t.Parallel()
var rule DNSRule
err := json.UnmarshalContext(context.Background(), []byte(`{
"type": "logical",
"mode": "and",
"rules": [
{"domain": "example.com", "server": "default"}
]
}`), &rule)
require.ErrorContains(t, err, dnsRuleActionNestedUnsupportedMessage)
}
func TestDNSRuleRejectsNestedLogicalRuleAction(t *testing.T) {
t.Parallel()
var rule DNSRule
err := json.UnmarshalContext(context.Background(), []byte(`{
"type": "logical",
"mode": "and",
"rules": [
{
"type": "logical",
"mode": "or",
"action": "route",
"server": "default",
"rules": [{"domain": "example.com"}]
}
]
}`), &rule)
require.ErrorContains(t, err, dnsRuleActionNestedUnsupportedMessage)
}
func TestDNSRuleRejectsNestedDefaultRuleZeroValueServer(t *testing.T) {
t.Parallel()
var rule DNSRule
err := json.UnmarshalContext(context.Background(), []byte(`{
"type": "logical",
"mode": "and",
"rules": [
{"domain": "example.com", "server": ""}
]
}`), &rule)
require.ErrorContains(t, err, dnsRuleActionNestedUnsupportedMessage)
}
func TestDNSRuleRejectsNestedDefaultRuleZeroValueRouteOption(t *testing.T) {
t.Parallel()
var rule DNSRule
err := json.UnmarshalContext(context.Background(), []byte(`{
"type": "logical",
"mode": "and",
"rules": [
{"domain": "example.com", "disable_cache": false}
]
}`), &rule)
require.ErrorContains(t, err, dnsRuleActionNestedUnsupportedMessage)
}
func TestDNSRuleRejectsNestedLogicalRuleZeroValueAction(t *testing.T) {
t.Parallel()
var rule DNSRule
err := json.UnmarshalContext(context.Background(), []byte(`{
"type": "logical",
"mode": "and",
"rules": [
{
"type": "logical",
"mode": "or",
"action": "",
"rules": [{"domain": "example.com"}]
}
]
}`), &rule)
require.ErrorContains(t, err, dnsRuleActionNestedUnsupportedMessage)
}
func TestDNSRuleRejectsNestedLogicalRuleZeroValueRouteOption(t *testing.T) {
t.Parallel()
var rule DNSRule
err := json.UnmarshalContext(context.Background(), []byte(`{
"type": "logical",
"mode": "and",
"rules": [
{
"type": "logical",
"mode": "or",
"disable_cache": false,
"rules": [{"domain": "example.com"}]
}
]
}`), &rule)
require.ErrorContains(t, err, dnsRuleActionNestedUnsupportedMessage)
}
func TestDNSRuleAllowsTopLevelLogicalAction(t *testing.T) {
t.Parallel()
var rule DNSRule
err := json.UnmarshalContext(context.Background(), []byte(`{
"type": "logical",
"mode": "and",
"server": "default",
"rules": [{"domain": "example.com"}]
}`), &rule)
require.NoError(t, err)
require.Equal(t, C.RuleActionTypeRoute, rule.LogicalOptions.Action)
require.Equal(t, "default", rule.LogicalOptions.RouteOptions.Server)
}
func TestDNSRuleLeavesUnknownNestedKeysToNormalValidation(t *testing.T) {
t.Parallel()
var rule DNSRule
err := json.UnmarshalContext(context.Background(), []byte(`{
"type": "logical",
"mode": "and",
"rules": [
{"domain": "example.com", "foo": "bar"}
]
}`), &rule)
require.ErrorContains(t, err, "unknown field")
require.NotContains(t, err.Error(), dnsRuleActionNestedUnsupportedMessage)
}

View File

@@ -70,6 +70,10 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.Route
func (r *Router) Initialize(rules []option.Rule, ruleSets []option.RuleSet) error {
for i, options := range rules {
err := R.ValidateNoNestedRuleActions(options)
if err != nil {
return E.Cause(err, "parse rule[", i, "]")
}
rule, err := R.NewRule(r.ctx, r.logger, options, false)
if err != nil {
return E.Cause(err, "parse rule[", i, "]")

View File

@@ -0,0 +1,76 @@
package rule
import (
"reflect"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
)
const (
routeRuleActionNestedUnsupportedMessage = "rule action is not supported in nested rules"
dnsRuleActionNestedUnsupportedMessage = "DNS rule action is not supported in nested rules"
)
func ValidateNoNestedRuleActions(rule option.Rule) error {
return validateNoNestedRuleActions(rule, false)
}
func ValidateNoNestedDNSRuleActions(rule option.DNSRule) error {
return validateNoNestedDNSRuleActions(rule, false)
}
func validateNoNestedRuleActions(rule option.Rule, nested bool) error {
if nested && ruleHasConfiguredAction(rule) {
return E.New(routeRuleActionNestedUnsupportedMessage)
}
if rule.Type != C.RuleTypeLogical {
return nil
}
for i, subRule := range rule.LogicalOptions.Rules {
err := validateNoNestedRuleActions(subRule, true)
if err != nil {
return E.Cause(err, "sub rule[", i, "]")
}
}
return nil
}
func validateNoNestedDNSRuleActions(rule option.DNSRule, nested bool) error {
if nested && dnsRuleHasConfiguredAction(rule) {
return E.New(dnsRuleActionNestedUnsupportedMessage)
}
if rule.Type != C.RuleTypeLogical {
return nil
}
for i, subRule := range rule.LogicalOptions.Rules {
err := validateNoNestedDNSRuleActions(subRule, true)
if err != nil {
return E.Cause(err, "sub rule[", i, "]")
}
}
return nil
}
func ruleHasConfiguredAction(rule option.Rule) bool {
switch rule.Type {
case "", C.RuleTypeDefault:
return !reflect.DeepEqual(rule.DefaultOptions.RuleAction, option.RuleAction{})
case C.RuleTypeLogical:
return !reflect.DeepEqual(rule.LogicalOptions.RuleAction, option.RuleAction{})
default:
return false
}
}
func dnsRuleHasConfiguredAction(rule option.DNSRule) bool {
switch rule.Type {
case "", C.RuleTypeDefault:
return !reflect.DeepEqual(rule.DefaultOptions.DNSRuleAction, option.DNSRuleAction{})
case C.RuleTypeLogical:
return !reflect.DeepEqual(rule.LogicalOptions.DNSRuleAction, option.DNSRuleAction{})
default:
return false
}
}

View File

@@ -0,0 +1,137 @@
package rule
import (
"context"
"testing"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common/json"
"github.com/stretchr/testify/require"
)
func TestNewRulePreservesImplicitTopLevelDefaultAction(t *testing.T) {
t.Parallel()
var options option.Rule
err := json.UnmarshalContext(context.Background(), []byte(`{
"domain": "example.com"
}`), &options)
require.NoError(t, err)
rule, err := NewRule(context.Background(), log.NewNOPFactory().NewLogger("router"), options, false)
require.NoError(t, err)
require.NotNil(t, rule.Action())
require.Equal(t, C.RuleActionTypeRoute, rule.Action().Type())
}
func TestNewRuleAllowsNestedRuleWithoutAction(t *testing.T) {
t.Parallel()
var options option.Rule
err := json.UnmarshalContext(context.Background(), []byte(`{
"type": "logical",
"mode": "and",
"rules": [
{"domain": "example.com"}
]
}`), &options)
require.NoError(t, err)
rule, err := NewRule(context.Background(), log.NewNOPFactory().NewLogger("router"), options, false)
require.NoError(t, err)
require.NotNil(t, rule.Action())
require.Equal(t, C.RuleActionTypeRoute, rule.Action().Type())
}
func TestNewRuleRejectsNestedRuleAction(t *testing.T) {
t.Parallel()
_, err := NewRule(context.Background(), log.NewNOPFactory().NewLogger("router"), option.Rule{
Type: C.RuleTypeLogical,
LogicalOptions: option.LogicalRule{
RawLogicalRule: option.RawLogicalRule{
Mode: C.LogicalTypeAnd,
Rules: []option.Rule{{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultRule{
RuleAction: option.RuleAction{
Action: C.RuleActionTypeRoute,
RouteOptions: option.RouteActionOptions{
Outbound: "direct",
},
},
},
}},
},
},
}, false)
require.ErrorContains(t, err, routeRuleActionNestedUnsupportedMessage)
}
func TestNewDNSRulePreservesImplicitTopLevelDefaultAction(t *testing.T) {
t.Parallel()
var options option.DNSRule
err := json.UnmarshalContext(context.Background(), []byte(`{
"domain": "example.com"
}`), &options)
require.NoError(t, err)
rule, err := NewDNSRule(context.Background(), log.NewNOPFactory().NewLogger("dns"), options, false, false)
require.NoError(t, err)
require.NotNil(t, rule.Action())
require.Equal(t, C.RuleActionTypeRoute, rule.Action().Type())
}
func TestNewDNSRuleAllowsNestedRuleWithoutAction(t *testing.T) {
t.Parallel()
var options option.DNSRule
err := json.UnmarshalContext(context.Background(), []byte(`{
"type": "logical",
"mode": "and",
"rules": [
{"domain": "example.com"}
]
}`), &options)
require.NoError(t, err)
rule, err := NewDNSRule(context.Background(), log.NewNOPFactory().NewLogger("dns"), options, false, false)
require.NoError(t, err)
require.NotNil(t, rule.Action())
require.Equal(t, C.RuleActionTypeRoute, rule.Action().Type())
}
func TestNewDNSRuleRejectsNestedRuleAction(t *testing.T) {
t.Parallel()
_, err := NewDNSRule(context.Background(), log.NewNOPFactory().NewLogger("dns"), option.DNSRule{
Type: C.RuleTypeLogical,
LogicalOptions: option.LogicalDNSRule{
RawLogicalDNSRule: option.RawLogicalDNSRule{
Mode: C.LogicalTypeAnd,
Rules: []option.DNSRule{{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeRoute,
RouteOptions: option.DNSRouteActionOptions{
Server: "default",
},
},
},
}},
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeRoute,
RouteOptions: option.DNSRouteActionOptions{
Server: "default",
},
},
},
}, true, false)
require.ErrorContains(t, err, dnsRuleActionNestedUnsupportedMessage)
}

View File

@@ -326,6 +326,10 @@ func NewLogicalRule(ctx context.Context, logger log.ContextLogger, options optio
return nil, E.New("unknown logical mode: ", options.Mode)
}
for i, subOptions := range options.Rules {
err = validateNoNestedRuleActions(subOptions, true)
if err != nil {
return nil, E.Cause(err, "sub rule[", i, "]")
}
subRule, err := NewRule(ctx, logger, subOptions, false)
if err != nil {
return nil, E.Cause(err, "sub rule[", i, "]")

View File

@@ -467,6 +467,10 @@ func NewLogicalDNSRule(ctx context.Context, logger log.ContextLogger, options op
return nil, E.New("unknown logical mode: ", options.Mode)
}
for i, subRule := range options.Rules {
err := validateNoNestedDNSRuleActions(subRule, true)
if err != nil {
return nil, E.Cause(err, "sub rule[", i, "]")
}
rule, err := NewDNSRule(ctx, logger, subRule, false, legacyDNSMode)
if err != nil {
return nil, E.Cause(err, "sub rule[", i, "]")

View File

@@ -25,16 +25,21 @@ func (r *ruleSetItemTestRouter) Close() error { return nil }
func (r *ruleSetItemTestRouter) PreMatch(adapter.InboundContext, tun.DirectRouteContext, time.Duration, bool) (tun.DirectRouteDestination, error) {
return nil, nil
}
func (r *ruleSetItemTestRouter) RouteConnection(context.Context, net.Conn, adapter.InboundContext) error {
return nil
}
func (r *ruleSetItemTestRouter) RoutePacketConnection(context.Context, N.PacketConn, adapter.InboundContext) error {
return nil
}
func (r *ruleSetItemTestRouter) RouteConnectionEx(context.Context, net.Conn, adapter.InboundContext, N.CloseHandlerFunc) {
}
func (r *ruleSetItemTestRouter) RoutePacketConnectionEx(context.Context, N.PacketConn, adapter.InboundContext, N.CloseHandlerFunc) {
}
func (r *ruleSetItemTestRouter) RuleSet(tag string) (adapter.RuleSet, bool) {
ruleSet, loaded := r.ruleSets[tag]
return ruleSet, loaded