mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-14 04:38:28 +10:00
dns: validate rule-set updates before commit
This commit is contained in:
@@ -66,6 +66,10 @@ type RuleSet interface {
|
||||
|
||||
type RuleSetUpdateCallback func(it RuleSet)
|
||||
|
||||
type DNSRuleSetUpdateValidator interface {
|
||||
ValidateRuleSetMetadataUpdate(tag string, metadata RuleSetMetadata) error
|
||||
}
|
||||
|
||||
// ip_version is not a headless-rule item, so ContainsIPVersionRule is intentionally absent.
|
||||
type RuleSetMetadata struct {
|
||||
ContainsProcessRule bool
|
||||
|
||||
1
box.go
1
box.go
@@ -199,6 +199,7 @@ func New(options Options) (*Box, error) {
|
||||
service.MustRegister[adapter.CertificateProviderManager](ctx, certificateProviderManager)
|
||||
dnsRouter := dns.NewRouter(ctx, logFactory, dnsOptions)
|
||||
service.MustRegister[adapter.DNSRouter](ctx, dnsRouter)
|
||||
service.MustRegister[adapter.DNSRuleSetUpdateValidator](ctx, dnsRouter)
|
||||
networkManager, err := route.NewNetworkManager(ctx, logFactory.NewLogger("network"), routeOptions, dnsOptions)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "initialize network manager")
|
||||
|
||||
309
dns/router.go
309
dns/router.go
@@ -6,7 +6,6 @@ import (
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
@@ -23,7 +22,6 @@ import (
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/task"
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
"github.com/sagernet/sing/contrab/freelru"
|
||||
"github.com/sagernet/sing/contrab/maphash"
|
||||
"github.com/sagernet/sing/service"
|
||||
@@ -32,54 +30,7 @@ import (
|
||||
)
|
||||
|
||||
var _ adapter.DNSRouter = (*Router)(nil)
|
||||
|
||||
type dnsRuleSetCallback struct {
|
||||
ruleSet adapter.RuleSet
|
||||
element *list.Element[adapter.RuleSetUpdateCallback]
|
||||
}
|
||||
|
||||
type rulesSnapshot struct {
|
||||
rules []adapter.DNSRule
|
||||
legacyDNSMode bool
|
||||
references atomic.Int64
|
||||
}
|
||||
|
||||
func newRulesSnapshot(rules []adapter.DNSRule, legacyDNSMode bool) *rulesSnapshot {
|
||||
snapshot := &rulesSnapshot{
|
||||
rules: rules,
|
||||
legacyDNSMode: legacyDNSMode,
|
||||
}
|
||||
snapshot.references.Store(1)
|
||||
return snapshot
|
||||
}
|
||||
|
||||
func (s *rulesSnapshot) retain() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.references.Add(1)
|
||||
}
|
||||
|
||||
func (s *rulesSnapshot) rulesAndMode() ([]adapter.DNSRule, bool) {
|
||||
if s == nil {
|
||||
return nil, false
|
||||
}
|
||||
return s.rules, s.legacyDNSMode
|
||||
}
|
||||
|
||||
func (s *rulesSnapshot) release() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
references := s.references.Add(-1)
|
||||
switch {
|
||||
case references > 0:
|
||||
case references == 0:
|
||||
closeRules(s.rules)
|
||||
default:
|
||||
panic("dns: negative rules snapshot references")
|
||||
}
|
||||
}
|
||||
var _ adapter.DNSRuleSetUpdateValidator = (*Router)(nil)
|
||||
|
||||
type Router struct {
|
||||
ctx context.Context
|
||||
@@ -88,14 +39,14 @@ type Router struct {
|
||||
outbound adapter.OutboundManager
|
||||
client adapter.DNSClient
|
||||
rawRules []option.DNSRule
|
||||
currentRules atomic.Pointer[rulesSnapshot]
|
||||
rules []adapter.DNSRule
|
||||
defaultDomainStrategy C.DomainStrategy
|
||||
dnsReverseMapping freelru.Cache[netip.Addr, string]
|
||||
platformInterface adapter.PlatformInterface
|
||||
rebuildAccess sync.Mutex
|
||||
stateAccess sync.Mutex
|
||||
legacyDNSMode bool
|
||||
rulesAccess sync.RWMutex
|
||||
started bool
|
||||
closing bool
|
||||
ruleSetCallbacks []dnsRuleSetCallback
|
||||
addressFilterDeprecatedReported bool
|
||||
ruleStrategyDeprecatedReported bool
|
||||
}
|
||||
@@ -107,9 +58,9 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOp
|
||||
transport: service.FromContext[adapter.DNSTransportManager](ctx),
|
||||
outbound: service.FromContext[adapter.OutboundManager](ctx),
|
||||
rawRules: make([]option.DNSRule, 0, len(options.Rules)),
|
||||
rules: make([]adapter.DNSRule, 0, len(options.Rules)),
|
||||
defaultDomainStrategy: C.DomainStrategy(options.Strategy),
|
||||
}
|
||||
router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, len(options.Rules)), false))
|
||||
router.client = NewClient(ClientOptions{
|
||||
DisableCache: options.DNSClientOptions.DisableCache,
|
||||
DisableExpire: options.DNSClientOptions.DisableExpire,
|
||||
@@ -153,107 +104,57 @@ func (r *Router) Start(stage adapter.StartStage) error {
|
||||
monitor.Finish()
|
||||
|
||||
monitor.Start("initialize DNS rules")
|
||||
err := r.rebuildRules(true)
|
||||
newRules, legacyDNSMode, modeFlags, err := r.buildRules(true)
|
||||
monitor.Finish()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
monitor.Start("register DNS rule-set callbacks")
|
||||
needsRulesRefresh, err := r.registerRuleSetCallbacks()
|
||||
monitor.Finish()
|
||||
if err != nil {
|
||||
return err
|
||||
shouldReportAddressFilterDeprecated := legacyDNSMode &&
|
||||
!r.addressFilterDeprecatedReported &&
|
||||
common.Any(newRules, func(rule adapter.DNSRule) bool { return rule.WithAddressLimit() })
|
||||
shouldReportRuleStrategyDeprecated := legacyDNSMode &&
|
||||
!r.ruleStrategyDeprecatedReported &&
|
||||
modeFlags.neededFromStrategy
|
||||
r.rulesAccess.Lock()
|
||||
if r.closing {
|
||||
r.rulesAccess.Unlock()
|
||||
closeRules(newRules)
|
||||
return nil
|
||||
}
|
||||
if needsRulesRefresh {
|
||||
monitor.Start("refresh DNS rules after callback registration")
|
||||
err = r.rebuildRules(true)
|
||||
monitor.Finish()
|
||||
if err != nil {
|
||||
r.logger.Error(E.Cause(err, "refresh DNS rules after callback registration"))
|
||||
}
|
||||
r.rules = newRules
|
||||
r.legacyDNSMode = legacyDNSMode
|
||||
r.started = true
|
||||
if shouldReportAddressFilterDeprecated {
|
||||
r.addressFilterDeprecatedReported = true
|
||||
}
|
||||
if shouldReportRuleStrategyDeprecated {
|
||||
r.ruleStrategyDeprecatedReported = true
|
||||
}
|
||||
r.rulesAccess.Unlock()
|
||||
if shouldReportAddressFilterDeprecated {
|
||||
deprecated.Report(r.ctx, deprecated.OptionLegacyDNSAddressFilter)
|
||||
}
|
||||
if shouldReportRuleStrategyDeprecated {
|
||||
deprecated.Report(r.ctx, deprecated.OptionLegacyDNSRuleStrategy)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Router) Close() error {
|
||||
r.stateAccess.Lock()
|
||||
r.rulesAccess.Lock()
|
||||
if r.closing {
|
||||
r.stateAccess.Unlock()
|
||||
r.rulesAccess.Unlock()
|
||||
return nil
|
||||
}
|
||||
r.closing = true
|
||||
callbacks := r.ruleSetCallbacks
|
||||
r.ruleSetCallbacks = nil
|
||||
oldSnapshot := r.currentRules.Swap(nil)
|
||||
for _, callback := range callbacks {
|
||||
callback.ruleSet.UnregisterCallback(callback.element)
|
||||
}
|
||||
r.stateAccess.Unlock()
|
||||
oldSnapshot.release()
|
||||
runtimeRules := r.rules
|
||||
r.rules = nil
|
||||
r.rulesAccess.Unlock()
|
||||
closeRules(runtimeRules)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Router) rebuildRules(startRules bool) error {
|
||||
r.rebuildAccess.Lock()
|
||||
defer r.rebuildAccess.Unlock()
|
||||
if r.isClosing() {
|
||||
return nil
|
||||
}
|
||||
newRules, legacyDNSMode, modeFlags, err := r.buildRules(startRules)
|
||||
if err != nil {
|
||||
if r.isClosing() {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
shouldReportAddressFilterDeprecated := startRules &&
|
||||
legacyDNSMode &&
|
||||
!r.addressFilterDeprecatedReported &&
|
||||
common.Any(newRules, func(rule adapter.DNSRule) bool { return rule.WithAddressLimit() })
|
||||
shouldReportRuleStrategyDeprecated := startRules &&
|
||||
legacyDNSMode &&
|
||||
!r.ruleStrategyDeprecatedReported &&
|
||||
modeFlags.neededFromStrategy
|
||||
newSnapshot := newRulesSnapshot(newRules, legacyDNSMode)
|
||||
r.stateAccess.Lock()
|
||||
if r.closing {
|
||||
r.stateAccess.Unlock()
|
||||
newSnapshot.release()
|
||||
return nil
|
||||
}
|
||||
if shouldReportAddressFilterDeprecated {
|
||||
r.addressFilterDeprecatedReported = true
|
||||
}
|
||||
if shouldReportRuleStrategyDeprecated {
|
||||
r.ruleStrategyDeprecatedReported = true
|
||||
}
|
||||
oldSnapshot := r.currentRules.Swap(newSnapshot)
|
||||
r.stateAccess.Unlock()
|
||||
oldSnapshot.release()
|
||||
if shouldReportAddressFilterDeprecated {
|
||||
deprecated.Report(r.ctx, deprecated.OptionLegacyDNSAddressFilter)
|
||||
}
|
||||
if shouldReportRuleStrategyDeprecated {
|
||||
deprecated.Report(r.ctx, deprecated.OptionLegacyDNSRuleStrategy)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Router) isClosing() bool {
|
||||
r.stateAccess.Lock()
|
||||
defer r.stateAccess.Unlock()
|
||||
return r.closing
|
||||
}
|
||||
|
||||
func (r *Router) acquireRulesSnapshot() *rulesSnapshot {
|
||||
r.stateAccess.Lock()
|
||||
defer r.stateAccess.Unlock()
|
||||
snapshot := r.currentRules.Load()
|
||||
snapshot.retain()
|
||||
return snapshot
|
||||
}
|
||||
|
||||
func (r *Router) buildRules(startRules bool) ([]adapter.DNSRule, bool, dnsRuleModeFlags, error) {
|
||||
for i, ruleOptions := range r.rawRules {
|
||||
err := R.ValidateNoNestedDNSRuleActions(ruleOptions)
|
||||
@@ -262,7 +163,7 @@ func (r *Router) buildRules(startRules bool) ([]adapter.DNSRule, bool, dnsRuleMo
|
||||
}
|
||||
}
|
||||
router := service.FromContext[adapter.Router](r.ctx)
|
||||
legacyDNSMode, modeFlags, err := resolveLegacyDNSMode(router, r.rawRules)
|
||||
legacyDNSMode, modeFlags, err := resolveLegacyDNSMode(router, r.rawRules, nil)
|
||||
if err != nil {
|
||||
return nil, false, dnsRuleModeFlags{}, err
|
||||
}
|
||||
@@ -304,51 +205,53 @@ func closeRules(rules []adapter.DNSRule) {
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Router) registerRuleSetCallbacks() (bool, error) {
|
||||
tags := referencedDNSRuleSetTags(r.rawRules)
|
||||
if len(tags) == 0 {
|
||||
return false, nil
|
||||
func (r *Router) ValidateRuleSetMetadataUpdate(tag string, metadata adapter.RuleSetMetadata) error {
|
||||
if len(r.rawRules) == 0 {
|
||||
return nil
|
||||
}
|
||||
r.stateAccess.Lock()
|
||||
if len(r.ruleSetCallbacks) > 0 {
|
||||
r.stateAccess.Unlock()
|
||||
return true, nil
|
||||
}
|
||||
r.stateAccess.Unlock()
|
||||
router := service.FromContext[adapter.Router](r.ctx)
|
||||
if router == nil {
|
||||
return false, E.New("router service not found")
|
||||
return E.New("router service not found")
|
||||
}
|
||||
callbacks := make([]dnsRuleSetCallback, 0, len(tags))
|
||||
for _, tag := range tags {
|
||||
ruleSet, loaded := router.RuleSet(tag)
|
||||
if !loaded {
|
||||
for _, callback := range callbacks {
|
||||
callback.ruleSet.UnregisterCallback(callback.element)
|
||||
}
|
||||
return false, E.New("rule-set not found: ", tag)
|
||||
overrides := map[string]adapter.RuleSetMetadata{
|
||||
tag: metadata,
|
||||
}
|
||||
r.rulesAccess.RLock()
|
||||
started := r.started
|
||||
legacyDNSMode := r.legacyDNSMode
|
||||
closing := r.closing
|
||||
r.rulesAccess.RUnlock()
|
||||
if closing {
|
||||
return nil
|
||||
}
|
||||
if !started {
|
||||
candidateLegacyDNSMode, _, err := resolveLegacyDNSMode(router, r.rawRules, overrides)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
element := ruleSet.RegisterCallback(func(adapter.RuleSet) {
|
||||
err := r.rebuildRules(true)
|
||||
if !candidateLegacyDNSMode {
|
||||
return validateLegacyDNSModeDisabledRules(r.rawRules)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
_, flags, err := resolveLegacyDNSMode(router, r.rawRules, overrides)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if legacyDNSMode {
|
||||
if flags.disabled {
|
||||
err := validateLegacyDNSModeDisabledRules(r.rawRules)
|
||||
if err != nil {
|
||||
r.logger.Error(E.Cause(err, "rebuild DNS rules after rule-set update"))
|
||||
return err
|
||||
}
|
||||
})
|
||||
callbacks = append(callbacks, dnsRuleSetCallback{
|
||||
ruleSet: ruleSet,
|
||||
element: element,
|
||||
})
|
||||
return E.New(deprecated.OptionLegacyDNSAddressFilter.MessageWithLink())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
r.stateAccess.Lock()
|
||||
if !r.closing && len(r.ruleSetCallbacks) == 0 {
|
||||
r.ruleSetCallbacks = callbacks
|
||||
callbacks = nil
|
||||
if flags.needed {
|
||||
return E.New(deprecated.OptionLegacyDNSAddressFilter.MessageWithLink())
|
||||
}
|
||||
r.stateAccess.Unlock()
|
||||
for _, callback := range callbacks {
|
||||
callback.ruleSet.UnregisterCallback(callback.element)
|
||||
}
|
||||
return true, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Router) matchDNS(ctx context.Context, rules []adapter.DNSRule, allowFakeIP bool, ruleIndex int, isAddressQuery bool, options *adapter.DNSQueryOptions) (adapter.DNSTransport, adapter.DNSRule, int) {
|
||||
@@ -702,9 +605,13 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte
|
||||
}
|
||||
return &responseMessage, nil
|
||||
}
|
||||
snapshot := r.acquireRulesSnapshot()
|
||||
defer snapshot.release()
|
||||
rules, legacyDNSMode := snapshot.rulesAndMode()
|
||||
r.rulesAccess.RLock()
|
||||
defer r.rulesAccess.RUnlock()
|
||||
if r.closing {
|
||||
return nil, E.New("dns router closed")
|
||||
}
|
||||
rules := r.rules
|
||||
legacyDNSMode := r.legacyDNSMode
|
||||
r.logger.DebugContext(ctx, "exchange ", FormatQuestion(message.Question[0].String()))
|
||||
var (
|
||||
response *mDNS.Msg
|
||||
@@ -810,9 +717,13 @@ done:
|
||||
}
|
||||
|
||||
func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, error) {
|
||||
snapshot := r.acquireRulesSnapshot()
|
||||
defer snapshot.release()
|
||||
rules, legacyDNSMode := snapshot.rulesAndMode()
|
||||
r.rulesAccess.RLock()
|
||||
defer r.rulesAccess.RUnlock()
|
||||
if r.closing {
|
||||
return nil, E.New("dns router closed")
|
||||
}
|
||||
rules := r.rules
|
||||
legacyDNSMode := r.legacyDNSMode
|
||||
var (
|
||||
responseAddrs []netip.Addr
|
||||
err error
|
||||
@@ -979,8 +890,8 @@ func (f *dnsRuleModeFlags) merge(other dnsRuleModeFlags) {
|
||||
f.neededFromStrategy = f.neededFromStrategy || other.neededFromStrategy
|
||||
}
|
||||
|
||||
func resolveLegacyDNSMode(router adapter.Router, rules []option.DNSRule) (bool, dnsRuleModeFlags, error) {
|
||||
flags, err := dnsRuleModeRequirements(router, rules)
|
||||
func resolveLegacyDNSMode(router adapter.Router, rules []option.DNSRule, metadataOverrides map[string]adapter.RuleSetMetadata) (bool, dnsRuleModeFlags, error) {
|
||||
flags, err := dnsRuleModeRequirements(router, rules, metadataOverrides)
|
||||
if err != nil {
|
||||
return false, flags, err
|
||||
}
|
||||
@@ -993,10 +904,10 @@ func resolveLegacyDNSMode(router adapter.Router, rules []option.DNSRule) (bool,
|
||||
return flags.needed, flags, nil
|
||||
}
|
||||
|
||||
func dnsRuleModeRequirements(router adapter.Router, rules []option.DNSRule) (dnsRuleModeFlags, error) {
|
||||
func dnsRuleModeRequirements(router adapter.Router, rules []option.DNSRule, metadataOverrides map[string]adapter.RuleSetMetadata) (dnsRuleModeFlags, error) {
|
||||
var flags dnsRuleModeFlags
|
||||
for i, rule := range rules {
|
||||
ruleFlags, err := dnsRuleModeRequirementsInRule(router, rule)
|
||||
ruleFlags, err := dnsRuleModeRequirementsInRule(router, rule, metadataOverrides)
|
||||
if err != nil {
|
||||
return dnsRuleModeFlags{}, E.Cause(err, "dns rule[", i, "]")
|
||||
}
|
||||
@@ -1005,10 +916,10 @@ func dnsRuleModeRequirements(router adapter.Router, rules []option.DNSRule) (dns
|
||||
return flags, nil
|
||||
}
|
||||
|
||||
func dnsRuleModeRequirementsInRule(router adapter.Router, rule option.DNSRule) (dnsRuleModeFlags, error) {
|
||||
func dnsRuleModeRequirementsInRule(router adapter.Router, rule option.DNSRule, metadataOverrides map[string]adapter.RuleSetMetadata) (dnsRuleModeFlags, error) {
|
||||
switch rule.Type {
|
||||
case "", C.RuleTypeDefault:
|
||||
return dnsRuleModeRequirementsInDefaultRule(router, rule.DefaultOptions)
|
||||
return dnsRuleModeRequirementsInDefaultRule(router, rule.DefaultOptions, metadataOverrides)
|
||||
case C.RuleTypeLogical:
|
||||
flags := dnsRuleModeFlags{
|
||||
disabled: dnsRuleActionType(rule) == C.RuleActionTypeEvaluate || dnsRuleActionType(rule) == C.RuleActionTypeRespond,
|
||||
@@ -1016,7 +927,7 @@ func dnsRuleModeRequirementsInRule(router adapter.Router, rule option.DNSRule) (
|
||||
}
|
||||
flags.needed = flags.neededFromStrategy
|
||||
for i, subRule := range rule.LogicalOptions.Rules {
|
||||
subFlags, err := dnsRuleModeRequirementsInRule(router, subRule)
|
||||
subFlags, err := dnsRuleModeRequirementsInRule(router, subRule, metadataOverrides)
|
||||
if err != nil {
|
||||
return dnsRuleModeFlags{}, E.Cause(err, "sub rule[", i, "]")
|
||||
}
|
||||
@@ -1028,7 +939,7 @@ func dnsRuleModeRequirementsInRule(router adapter.Router, rule option.DNSRule) (
|
||||
}
|
||||
}
|
||||
|
||||
func dnsRuleModeRequirementsInDefaultRule(router adapter.Router, rule option.DefaultDNSRule) (dnsRuleModeFlags, error) {
|
||||
func dnsRuleModeRequirementsInDefaultRule(router adapter.Router, rule option.DefaultDNSRule, metadataOverrides map[string]adapter.RuleSetMetadata) (dnsRuleModeFlags, error) {
|
||||
flags := dnsRuleModeFlags{
|
||||
disabled: defaultRuleDisablesLegacyDNSMode(rule),
|
||||
neededFromStrategy: dnsRuleActionHasStrategy(rule.DNSRuleAction),
|
||||
@@ -1041,11 +952,10 @@ func dnsRuleModeRequirementsInDefaultRule(router adapter.Router, rule option.Def
|
||||
return dnsRuleModeFlags{}, E.New("router service not found")
|
||||
}
|
||||
for _, tag := range rule.RuleSet {
|
||||
ruleSet, loaded := router.RuleSet(tag)
|
||||
if !loaded {
|
||||
return dnsRuleModeFlags{}, E.New("rule-set not found: ", tag)
|
||||
metadata, err := lookupDNSRuleSetMetadata(router, tag, metadataOverrides)
|
||||
if err != nil {
|
||||
return dnsRuleModeFlags{}, err
|
||||
}
|
||||
metadata := ruleSet.Metadata()
|
||||
// ip_version is not a headless-rule item, so ContainsIPVersionRule is intentionally absent.
|
||||
flags.disabled = flags.disabled || metadata.ContainsDNSQueryTypeRule
|
||||
if !rule.RuleSetIPCIDRMatchSource && metadata.ContainsIPCIDRRule {
|
||||
@@ -1055,6 +965,19 @@ func dnsRuleModeRequirementsInDefaultRule(router adapter.Router, rule option.Def
|
||||
return flags, nil
|
||||
}
|
||||
|
||||
func lookupDNSRuleSetMetadata(router adapter.Router, tag string, metadataOverrides map[string]adapter.RuleSetMetadata) (adapter.RuleSetMetadata, error) {
|
||||
if metadataOverrides != nil {
|
||||
if metadata, loaded := metadataOverrides[tag]; loaded {
|
||||
return metadata, nil
|
||||
}
|
||||
}
|
||||
ruleSet, loaded := router.RuleSet(tag)
|
||||
if !loaded {
|
||||
return adapter.RuleSetMetadata{}, E.New("rule-set not found: ", tag)
|
||||
}
|
||||
return ruleSet.Metadata(), nil
|
||||
}
|
||||
|
||||
func referencedDNSRuleSetTags(rules []option.DNSRule) []string {
|
||||
tagMap := make(map[string]bool)
|
||||
var walkRule func(rule option.DNSRule)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
"github.com/sagernet/sing/service"
|
||||
|
||||
"go4.org/netipx"
|
||||
)
|
||||
@@ -73,3 +74,20 @@ func isIPCIDRHeadlessRule(rule option.DefaultHeadlessRule) bool {
|
||||
func isDNSQueryTypeHeadlessRule(rule option.DefaultHeadlessRule) bool {
|
||||
return len(rule.QueryType) > 0
|
||||
}
|
||||
|
||||
func buildRuleSetMetadata(headlessRules []option.HeadlessRule) adapter.RuleSetMetadata {
|
||||
return adapter.RuleSetMetadata{
|
||||
ContainsProcessRule: HasHeadlessRule(headlessRules, isProcessHeadlessRule),
|
||||
ContainsWIFIRule: HasHeadlessRule(headlessRules, isWIFIHeadlessRule),
|
||||
ContainsIPCIDRRule: HasHeadlessRule(headlessRules, isIPCIDRHeadlessRule),
|
||||
ContainsDNSQueryTypeRule: HasHeadlessRule(headlessRules, isDNSQueryTypeHeadlessRule),
|
||||
}
|
||||
}
|
||||
|
||||
func validateRuleSetMetadataUpdate(ctx context.Context, tag string, metadata adapter.RuleSetMetadata) error {
|
||||
validator := service.FromContext[adapter.DNSRuleSetUpdateValidator](ctx)
|
||||
if validator == nil {
|
||||
return nil
|
||||
}
|
||||
return validator.ValidateRuleSetMetadataUpdate(tag, metadata)
|
||||
}
|
||||
|
||||
@@ -137,11 +137,11 @@ func (s *LocalRuleSet) reloadRules(headlessRules []option.HeadlessRule) error {
|
||||
return E.Cause(err, "parse rule_set.rules.[", i, "]")
|
||||
}
|
||||
}
|
||||
var metadata adapter.RuleSetMetadata
|
||||
metadata.ContainsProcessRule = HasHeadlessRule(headlessRules, isProcessHeadlessRule)
|
||||
metadata.ContainsWIFIRule = HasHeadlessRule(headlessRules, isWIFIHeadlessRule)
|
||||
metadata.ContainsIPCIDRRule = HasHeadlessRule(headlessRules, isIPCIDRHeadlessRule)
|
||||
metadata.ContainsDNSQueryTypeRule = HasHeadlessRule(headlessRules, isDNSQueryTypeHeadlessRule)
|
||||
metadata := buildRuleSetMetadata(headlessRules)
|
||||
err = validateRuleSetMetadataUpdate(s.ctx, s.tag, metadata)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.access.Lock()
|
||||
s.rules = rules
|
||||
s.metadata = metadata
|
||||
|
||||
@@ -189,11 +189,13 @@ func (s *RemoteRuleSet) loadBytes(content []byte) error {
|
||||
return E.Cause(err, "parse rule_set.rules.[", i, "]")
|
||||
}
|
||||
}
|
||||
metadata := buildRuleSetMetadata(plainRuleSet.Rules)
|
||||
err = validateRuleSetMetadataUpdate(s.ctx, s.options.Tag, metadata)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.access.Lock()
|
||||
s.metadata.ContainsProcessRule = HasHeadlessRule(plainRuleSet.Rules, isProcessHeadlessRule)
|
||||
s.metadata.ContainsWIFIRule = HasHeadlessRule(plainRuleSet.Rules, isWIFIHeadlessRule)
|
||||
s.metadata.ContainsIPCIDRRule = HasHeadlessRule(plainRuleSet.Rules, isIPCIDRHeadlessRule)
|
||||
s.metadata.ContainsDNSQueryTypeRule = HasHeadlessRule(plainRuleSet.Rules, isDNSQueryTypeHeadlessRule)
|
||||
s.metadata = metadata
|
||||
s.rules = rules
|
||||
callbacks := s.callbacks.Array()
|
||||
s.access.Unlock()
|
||||
|
||||
110
route/rule/rule_set_update_validation_test.go
Normal file
110
route/rule/rule_set_update_validation_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package rule
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json/badoption"
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
"github.com/sagernet/sing/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fakeDNSRuleSetUpdateValidator struct {
|
||||
validate func(tag string, metadata adapter.RuleSetMetadata) error
|
||||
}
|
||||
|
||||
func (v *fakeDNSRuleSetUpdateValidator) ValidateRuleSetMetadataUpdate(tag string, metadata adapter.RuleSetMetadata) error {
|
||||
if v.validate == nil {
|
||||
return nil
|
||||
}
|
||||
return v.validate(tag, metadata)
|
||||
}
|
||||
|
||||
func TestLocalRuleSetReloadRulesRejectsInvalidUpdateBeforeCommit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var callbackCount atomic.Int32
|
||||
ctx := service.ContextWith[adapter.DNSRuleSetUpdateValidator](context.Background(), &fakeDNSRuleSetUpdateValidator{
|
||||
validate: func(tag string, metadata adapter.RuleSetMetadata) error {
|
||||
require.Equal(t, "dynamic-set", tag)
|
||||
if metadata.ContainsDNSQueryTypeRule {
|
||||
return E.New("dns conflict")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
})
|
||||
ruleSet := &LocalRuleSet{
|
||||
ctx: ctx,
|
||||
tag: "dynamic-set",
|
||||
fileFormat: C.RuleSetFormatSource,
|
||||
}
|
||||
_ = ruleSet.callbacks.PushBack(func(adapter.RuleSet) {
|
||||
callbackCount.Add(1)
|
||||
})
|
||||
|
||||
err := ruleSet.reloadRules([]option.HeadlessRule{{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultHeadlessRule{
|
||||
Domain: badoption.Listable[string]{"example.com"},
|
||||
},
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int32(1), callbackCount.Load())
|
||||
require.False(t, ruleSet.metadata.ContainsDNSQueryTypeRule)
|
||||
require.True(t, ruleSet.Match(&adapter.InboundContext{Domain: "example.com"}))
|
||||
|
||||
err = ruleSet.reloadRules([]option.HeadlessRule{{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultHeadlessRule{
|
||||
QueryType: badoption.Listable[option.DNSQueryType]{option.DNSQueryType(1)},
|
||||
},
|
||||
}})
|
||||
require.ErrorContains(t, err, "dns conflict")
|
||||
require.Equal(t, int32(1), callbackCount.Load())
|
||||
require.False(t, ruleSet.metadata.ContainsDNSQueryTypeRule)
|
||||
require.True(t, ruleSet.Match(&adapter.InboundContext{Domain: "example.com"}))
|
||||
}
|
||||
|
||||
func TestRemoteRuleSetLoadBytesRejectsInvalidUpdateBeforeCommit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var callbackCount atomic.Int32
|
||||
ctx := service.ContextWith[adapter.DNSRuleSetUpdateValidator](context.Background(), &fakeDNSRuleSetUpdateValidator{
|
||||
validate: func(tag string, metadata adapter.RuleSetMetadata) error {
|
||||
require.Equal(t, "dynamic-set", tag)
|
||||
if metadata.ContainsDNSQueryTypeRule {
|
||||
return E.New("dns conflict")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
})
|
||||
ruleSet := &RemoteRuleSet{
|
||||
ctx: ctx,
|
||||
options: option.RuleSet{
|
||||
Tag: "dynamic-set",
|
||||
Format: C.RuleSetFormatSource,
|
||||
},
|
||||
callbacks: list.List[adapter.RuleSetUpdateCallback]{},
|
||||
}
|
||||
_ = ruleSet.callbacks.PushBack(func(adapter.RuleSet) {
|
||||
callbackCount.Add(1)
|
||||
})
|
||||
|
||||
err := ruleSet.loadBytes([]byte(`{"version":4,"rules":[{"domain":["example.com"]}]}`))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int32(1), callbackCount.Load())
|
||||
require.False(t, ruleSet.metadata.ContainsDNSQueryTypeRule)
|
||||
require.True(t, ruleSet.Match(&adapter.InboundContext{Domain: "example.com"}))
|
||||
|
||||
err = ruleSet.loadBytes([]byte(`{"version":4,"rules":[{"query_type":["A"]}]}`))
|
||||
require.ErrorContains(t, err, "dns conflict")
|
||||
require.Equal(t, int32(1), callbackCount.Load())
|
||||
require.False(t, ruleSet.metadata.ContainsDNSQueryTypeRule)
|
||||
require.True(t, ruleSet.Match(&adapter.InboundContext{Domain: "example.com"}))
|
||||
}
|
||||
Reference in New Issue
Block a user