dns: validate rule-set updates before commit
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
"github.com/sagernet/sing/service"
|
||||
|
||||
"go4.org/netipx"
|
||||
)
|
||||
@@ -73,3 +74,20 @@ func isIPCIDRHeadlessRule(rule option.DefaultHeadlessRule) bool {
|
||||
func isDNSQueryTypeHeadlessRule(rule option.DefaultHeadlessRule) bool {
|
||||
return len(rule.QueryType) > 0
|
||||
}
|
||||
|
||||
func buildRuleSetMetadata(headlessRules []option.HeadlessRule) adapter.RuleSetMetadata {
|
||||
return adapter.RuleSetMetadata{
|
||||
ContainsProcessRule: HasHeadlessRule(headlessRules, isProcessHeadlessRule),
|
||||
ContainsWIFIRule: HasHeadlessRule(headlessRules, isWIFIHeadlessRule),
|
||||
ContainsIPCIDRRule: HasHeadlessRule(headlessRules, isIPCIDRHeadlessRule),
|
||||
ContainsDNSQueryTypeRule: HasHeadlessRule(headlessRules, isDNSQueryTypeHeadlessRule),
|
||||
}
|
||||
}
|
||||
|
||||
func validateRuleSetMetadataUpdate(ctx context.Context, tag string, metadata adapter.RuleSetMetadata) error {
|
||||
validator := service.FromContext[adapter.DNSRuleSetUpdateValidator](ctx)
|
||||
if validator == nil {
|
||||
return nil
|
||||
}
|
||||
return validator.ValidateRuleSetMetadataUpdate(tag, metadata)
|
||||
}
|
||||
|
||||
@@ -137,11 +137,11 @@ func (s *LocalRuleSet) reloadRules(headlessRules []option.HeadlessRule) error {
|
||||
return E.Cause(err, "parse rule_set.rules.[", i, "]")
|
||||
}
|
||||
}
|
||||
var metadata adapter.RuleSetMetadata
|
||||
metadata.ContainsProcessRule = HasHeadlessRule(headlessRules, isProcessHeadlessRule)
|
||||
metadata.ContainsWIFIRule = HasHeadlessRule(headlessRules, isWIFIHeadlessRule)
|
||||
metadata.ContainsIPCIDRRule = HasHeadlessRule(headlessRules, isIPCIDRHeadlessRule)
|
||||
metadata.ContainsDNSQueryTypeRule = HasHeadlessRule(headlessRules, isDNSQueryTypeHeadlessRule)
|
||||
metadata := buildRuleSetMetadata(headlessRules)
|
||||
err = validateRuleSetMetadataUpdate(s.ctx, s.tag, metadata)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.access.Lock()
|
||||
s.rules = rules
|
||||
s.metadata = metadata
|
||||
|
||||
@@ -189,11 +189,13 @@ func (s *RemoteRuleSet) loadBytes(content []byte) error {
|
||||
return E.Cause(err, "parse rule_set.rules.[", i, "]")
|
||||
}
|
||||
}
|
||||
metadata := buildRuleSetMetadata(plainRuleSet.Rules)
|
||||
err = validateRuleSetMetadataUpdate(s.ctx, s.options.Tag, metadata)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.access.Lock()
|
||||
s.metadata.ContainsProcessRule = HasHeadlessRule(plainRuleSet.Rules, isProcessHeadlessRule)
|
||||
s.metadata.ContainsWIFIRule = HasHeadlessRule(plainRuleSet.Rules, isWIFIHeadlessRule)
|
||||
s.metadata.ContainsIPCIDRRule = HasHeadlessRule(plainRuleSet.Rules, isIPCIDRHeadlessRule)
|
||||
s.metadata.ContainsDNSQueryTypeRule = HasHeadlessRule(plainRuleSet.Rules, isDNSQueryTypeHeadlessRule)
|
||||
s.metadata = metadata
|
||||
s.rules = rules
|
||||
callbacks := s.callbacks.Array()
|
||||
s.access.Unlock()
|
||||
|
||||
110
route/rule/rule_set_update_validation_test.go
Normal file
110
route/rule/rule_set_update_validation_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package rule
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json/badoption"
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
"github.com/sagernet/sing/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fakeDNSRuleSetUpdateValidator struct {
|
||||
validate func(tag string, metadata adapter.RuleSetMetadata) error
|
||||
}
|
||||
|
||||
func (v *fakeDNSRuleSetUpdateValidator) ValidateRuleSetMetadataUpdate(tag string, metadata adapter.RuleSetMetadata) error {
|
||||
if v.validate == nil {
|
||||
return nil
|
||||
}
|
||||
return v.validate(tag, metadata)
|
||||
}
|
||||
|
||||
func TestLocalRuleSetReloadRulesRejectsInvalidUpdateBeforeCommit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var callbackCount atomic.Int32
|
||||
ctx := service.ContextWith[adapter.DNSRuleSetUpdateValidator](context.Background(), &fakeDNSRuleSetUpdateValidator{
|
||||
validate: func(tag string, metadata adapter.RuleSetMetadata) error {
|
||||
require.Equal(t, "dynamic-set", tag)
|
||||
if metadata.ContainsDNSQueryTypeRule {
|
||||
return E.New("dns conflict")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
})
|
||||
ruleSet := &LocalRuleSet{
|
||||
ctx: ctx,
|
||||
tag: "dynamic-set",
|
||||
fileFormat: C.RuleSetFormatSource,
|
||||
}
|
||||
_ = ruleSet.callbacks.PushBack(func(adapter.RuleSet) {
|
||||
callbackCount.Add(1)
|
||||
})
|
||||
|
||||
err := ruleSet.reloadRules([]option.HeadlessRule{{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultHeadlessRule{
|
||||
Domain: badoption.Listable[string]{"example.com"},
|
||||
},
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int32(1), callbackCount.Load())
|
||||
require.False(t, ruleSet.metadata.ContainsDNSQueryTypeRule)
|
||||
require.True(t, ruleSet.Match(&adapter.InboundContext{Domain: "example.com"}))
|
||||
|
||||
err = ruleSet.reloadRules([]option.HeadlessRule{{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultHeadlessRule{
|
||||
QueryType: badoption.Listable[option.DNSQueryType]{option.DNSQueryType(1)},
|
||||
},
|
||||
}})
|
||||
require.ErrorContains(t, err, "dns conflict")
|
||||
require.Equal(t, int32(1), callbackCount.Load())
|
||||
require.False(t, ruleSet.metadata.ContainsDNSQueryTypeRule)
|
||||
require.True(t, ruleSet.Match(&adapter.InboundContext{Domain: "example.com"}))
|
||||
}
|
||||
|
||||
func TestRemoteRuleSetLoadBytesRejectsInvalidUpdateBeforeCommit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var callbackCount atomic.Int32
|
||||
ctx := service.ContextWith[adapter.DNSRuleSetUpdateValidator](context.Background(), &fakeDNSRuleSetUpdateValidator{
|
||||
validate: func(tag string, metadata adapter.RuleSetMetadata) error {
|
||||
require.Equal(t, "dynamic-set", tag)
|
||||
if metadata.ContainsDNSQueryTypeRule {
|
||||
return E.New("dns conflict")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
})
|
||||
ruleSet := &RemoteRuleSet{
|
||||
ctx: ctx,
|
||||
options: option.RuleSet{
|
||||
Tag: "dynamic-set",
|
||||
Format: C.RuleSetFormatSource,
|
||||
},
|
||||
callbacks: list.List[adapter.RuleSetUpdateCallback]{},
|
||||
}
|
||||
_ = ruleSet.callbacks.PushBack(func(adapter.RuleSet) {
|
||||
callbackCount.Add(1)
|
||||
})
|
||||
|
||||
err := ruleSet.loadBytes([]byte(`{"version":4,"rules":[{"domain":["example.com"]}]}`))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int32(1), callbackCount.Load())
|
||||
require.False(t, ruleSet.metadata.ContainsDNSQueryTypeRule)
|
||||
require.True(t, ruleSet.Match(&adapter.InboundContext{Domain: "example.com"}))
|
||||
|
||||
err = ruleSet.loadBytes([]byte(`{"version":4,"rules":[{"query_type":["A"]}]}`))
|
||||
require.ErrorContains(t, err, "dns conflict")
|
||||
require.Equal(t, int32(1), callbackCount.Load())
|
||||
require.False(t, ruleSet.metadata.ContainsDNSQueryTypeRule)
|
||||
require.True(t, ruleSet.Match(&adapter.InboundContext{Domain: "example.com"}))
|
||||
}
|
||||
Reference in New Issue
Block a user