dns: validate rule-set updates before commit

This commit is contained in:
世界
2026-04-02 00:24:16 +08:00
parent ca43d71152
commit bdfb344955
8 changed files with 470 additions and 762 deletions

View File

@@ -66,6 +66,10 @@ type RuleSet interface {
type RuleSetUpdateCallback func(it RuleSet)
type DNSRuleSetUpdateValidator interface {
ValidateRuleSetMetadataUpdate(tag string, metadata RuleSetMetadata) error
}
// ip_version is not a headless-rule item, so ContainsIPVersionRule is intentionally absent.
type RuleSetMetadata struct {
ContainsProcessRule bool

1
box.go
View File

@@ -199,6 +199,7 @@ func New(options Options) (*Box, error) {
service.MustRegister[adapter.CertificateProviderManager](ctx, certificateProviderManager)
dnsRouter := dns.NewRouter(ctx, logFactory, dnsOptions)
service.MustRegister[adapter.DNSRouter](ctx, dnsRouter)
service.MustRegister[adapter.DNSRuleSetUpdateValidator](ctx, dnsRouter)
networkManager, err := route.NewNetworkManager(ctx, logFactory.NewLogger("network"), routeOptions, dnsOptions)
if err != nil {
return nil, E.Cause(err, "initialize network manager")

View File

@@ -6,7 +6,6 @@ import (
"net/netip"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/sagernet/sing-box/adapter"
@@ -23,7 +22,6 @@ import (
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/task"
"github.com/sagernet/sing/common/x/list"
"github.com/sagernet/sing/contrab/freelru"
"github.com/sagernet/sing/contrab/maphash"
"github.com/sagernet/sing/service"
@@ -32,54 +30,7 @@ import (
)
var _ adapter.DNSRouter = (*Router)(nil)
type dnsRuleSetCallback struct {
ruleSet adapter.RuleSet
element *list.Element[adapter.RuleSetUpdateCallback]
}
type rulesSnapshot struct {
rules []adapter.DNSRule
legacyDNSMode bool
references atomic.Int64
}
func newRulesSnapshot(rules []adapter.DNSRule, legacyDNSMode bool) *rulesSnapshot {
snapshot := &rulesSnapshot{
rules: rules,
legacyDNSMode: legacyDNSMode,
}
snapshot.references.Store(1)
return snapshot
}
func (s *rulesSnapshot) retain() {
if s == nil {
return
}
s.references.Add(1)
}
func (s *rulesSnapshot) rulesAndMode() ([]adapter.DNSRule, bool) {
if s == nil {
return nil, false
}
return s.rules, s.legacyDNSMode
}
func (s *rulesSnapshot) release() {
if s == nil {
return
}
references := s.references.Add(-1)
switch {
case references > 0:
case references == 0:
closeRules(s.rules)
default:
panic("dns: negative rules snapshot references")
}
}
var _ adapter.DNSRuleSetUpdateValidator = (*Router)(nil)
type Router struct {
ctx context.Context
@@ -88,14 +39,14 @@ type Router struct {
outbound adapter.OutboundManager
client adapter.DNSClient
rawRules []option.DNSRule
currentRules atomic.Pointer[rulesSnapshot]
rules []adapter.DNSRule
defaultDomainStrategy C.DomainStrategy
dnsReverseMapping freelru.Cache[netip.Addr, string]
platformInterface adapter.PlatformInterface
rebuildAccess sync.Mutex
stateAccess sync.Mutex
legacyDNSMode bool
rulesAccess sync.RWMutex
started bool
closing bool
ruleSetCallbacks []dnsRuleSetCallback
addressFilterDeprecatedReported bool
ruleStrategyDeprecatedReported bool
}
@@ -107,9 +58,9 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOp
transport: service.FromContext[adapter.DNSTransportManager](ctx),
outbound: service.FromContext[adapter.OutboundManager](ctx),
rawRules: make([]option.DNSRule, 0, len(options.Rules)),
rules: make([]adapter.DNSRule, 0, len(options.Rules)),
defaultDomainStrategy: C.DomainStrategy(options.Strategy),
}
router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, len(options.Rules)), false))
router.client = NewClient(ClientOptions{
DisableCache: options.DNSClientOptions.DisableCache,
DisableExpire: options.DNSClientOptions.DisableExpire,
@@ -153,107 +104,57 @@ func (r *Router) Start(stage adapter.StartStage) error {
monitor.Finish()
monitor.Start("initialize DNS rules")
err := r.rebuildRules(true)
newRules, legacyDNSMode, modeFlags, err := r.buildRules(true)
monitor.Finish()
if err != nil {
return err
}
monitor.Start("register DNS rule-set callbacks")
needsRulesRefresh, err := r.registerRuleSetCallbacks()
monitor.Finish()
if err != nil {
return err
shouldReportAddressFilterDeprecated := legacyDNSMode &&
!r.addressFilterDeprecatedReported &&
common.Any(newRules, func(rule adapter.DNSRule) bool { return rule.WithAddressLimit() })
shouldReportRuleStrategyDeprecated := legacyDNSMode &&
!r.ruleStrategyDeprecatedReported &&
modeFlags.neededFromStrategy
r.rulesAccess.Lock()
if r.closing {
r.rulesAccess.Unlock()
closeRules(newRules)
return nil
}
if needsRulesRefresh {
monitor.Start("refresh DNS rules after callback registration")
err = r.rebuildRules(true)
monitor.Finish()
if err != nil {
r.logger.Error(E.Cause(err, "refresh DNS rules after callback registration"))
}
r.rules = newRules
r.legacyDNSMode = legacyDNSMode
r.started = true
if shouldReportAddressFilterDeprecated {
r.addressFilterDeprecatedReported = true
}
if shouldReportRuleStrategyDeprecated {
r.ruleStrategyDeprecatedReported = true
}
r.rulesAccess.Unlock()
if shouldReportAddressFilterDeprecated {
deprecated.Report(r.ctx, deprecated.OptionLegacyDNSAddressFilter)
}
if shouldReportRuleStrategyDeprecated {
deprecated.Report(r.ctx, deprecated.OptionLegacyDNSRuleStrategy)
}
}
return nil
}
func (r *Router) Close() error {
r.stateAccess.Lock()
r.rulesAccess.Lock()
if r.closing {
r.stateAccess.Unlock()
r.rulesAccess.Unlock()
return nil
}
r.closing = true
callbacks := r.ruleSetCallbacks
r.ruleSetCallbacks = nil
oldSnapshot := r.currentRules.Swap(nil)
for _, callback := range callbacks {
callback.ruleSet.UnregisterCallback(callback.element)
}
r.stateAccess.Unlock()
oldSnapshot.release()
runtimeRules := r.rules
r.rules = nil
r.rulesAccess.Unlock()
closeRules(runtimeRules)
return nil
}
func (r *Router) rebuildRules(startRules bool) error {
r.rebuildAccess.Lock()
defer r.rebuildAccess.Unlock()
if r.isClosing() {
return nil
}
newRules, legacyDNSMode, modeFlags, err := r.buildRules(startRules)
if err != nil {
if r.isClosing() {
return nil
}
return err
}
shouldReportAddressFilterDeprecated := startRules &&
legacyDNSMode &&
!r.addressFilterDeprecatedReported &&
common.Any(newRules, func(rule adapter.DNSRule) bool { return rule.WithAddressLimit() })
shouldReportRuleStrategyDeprecated := startRules &&
legacyDNSMode &&
!r.ruleStrategyDeprecatedReported &&
modeFlags.neededFromStrategy
newSnapshot := newRulesSnapshot(newRules, legacyDNSMode)
r.stateAccess.Lock()
if r.closing {
r.stateAccess.Unlock()
newSnapshot.release()
return nil
}
if shouldReportAddressFilterDeprecated {
r.addressFilterDeprecatedReported = true
}
if shouldReportRuleStrategyDeprecated {
r.ruleStrategyDeprecatedReported = true
}
oldSnapshot := r.currentRules.Swap(newSnapshot)
r.stateAccess.Unlock()
oldSnapshot.release()
if shouldReportAddressFilterDeprecated {
deprecated.Report(r.ctx, deprecated.OptionLegacyDNSAddressFilter)
}
if shouldReportRuleStrategyDeprecated {
deprecated.Report(r.ctx, deprecated.OptionLegacyDNSRuleStrategy)
}
return nil
}
func (r *Router) isClosing() bool {
r.stateAccess.Lock()
defer r.stateAccess.Unlock()
return r.closing
}
func (r *Router) acquireRulesSnapshot() *rulesSnapshot {
r.stateAccess.Lock()
defer r.stateAccess.Unlock()
snapshot := r.currentRules.Load()
snapshot.retain()
return snapshot
}
func (r *Router) buildRules(startRules bool) ([]adapter.DNSRule, bool, dnsRuleModeFlags, error) {
for i, ruleOptions := range r.rawRules {
err := R.ValidateNoNestedDNSRuleActions(ruleOptions)
@@ -262,7 +163,7 @@ func (r *Router) buildRules(startRules bool) ([]adapter.DNSRule, bool, dnsRuleMo
}
}
router := service.FromContext[adapter.Router](r.ctx)
legacyDNSMode, modeFlags, err := resolveLegacyDNSMode(router, r.rawRules)
legacyDNSMode, modeFlags, err := resolveLegacyDNSMode(router, r.rawRules, nil)
if err != nil {
return nil, false, dnsRuleModeFlags{}, err
}
@@ -304,51 +205,53 @@ func closeRules(rules []adapter.DNSRule) {
}
}
func (r *Router) registerRuleSetCallbacks() (bool, error) {
tags := referencedDNSRuleSetTags(r.rawRules)
if len(tags) == 0 {
return false, nil
func (r *Router) ValidateRuleSetMetadataUpdate(tag string, metadata adapter.RuleSetMetadata) error {
if len(r.rawRules) == 0 {
return nil
}
r.stateAccess.Lock()
if len(r.ruleSetCallbacks) > 0 {
r.stateAccess.Unlock()
return true, nil
}
r.stateAccess.Unlock()
router := service.FromContext[adapter.Router](r.ctx)
if router == nil {
return false, E.New("router service not found")
return E.New("router service not found")
}
callbacks := make([]dnsRuleSetCallback, 0, len(tags))
for _, tag := range tags {
ruleSet, loaded := router.RuleSet(tag)
if !loaded {
for _, callback := range callbacks {
callback.ruleSet.UnregisterCallback(callback.element)
}
return false, E.New("rule-set not found: ", tag)
overrides := map[string]adapter.RuleSetMetadata{
tag: metadata,
}
r.rulesAccess.RLock()
started := r.started
legacyDNSMode := r.legacyDNSMode
closing := r.closing
r.rulesAccess.RUnlock()
if closing {
return nil
}
if !started {
candidateLegacyDNSMode, _, err := resolveLegacyDNSMode(router, r.rawRules, overrides)
if err != nil {
return err
}
element := ruleSet.RegisterCallback(func(adapter.RuleSet) {
err := r.rebuildRules(true)
if !candidateLegacyDNSMode {
return validateLegacyDNSModeDisabledRules(r.rawRules)
}
return nil
}
_, flags, err := resolveLegacyDNSMode(router, r.rawRules, overrides)
if err != nil {
return err
}
if legacyDNSMode {
if flags.disabled {
err := validateLegacyDNSModeDisabledRules(r.rawRules)
if err != nil {
r.logger.Error(E.Cause(err, "rebuild DNS rules after rule-set update"))
return err
}
})
callbacks = append(callbacks, dnsRuleSetCallback{
ruleSet: ruleSet,
element: element,
})
return E.New(deprecated.OptionLegacyDNSAddressFilter.MessageWithLink())
}
return nil
}
r.stateAccess.Lock()
if !r.closing && len(r.ruleSetCallbacks) == 0 {
r.ruleSetCallbacks = callbacks
callbacks = nil
if flags.needed {
return E.New(deprecated.OptionLegacyDNSAddressFilter.MessageWithLink())
}
r.stateAccess.Unlock()
for _, callback := range callbacks {
callback.ruleSet.UnregisterCallback(callback.element)
}
return true, nil
return nil
}
func (r *Router) matchDNS(ctx context.Context, rules []adapter.DNSRule, allowFakeIP bool, ruleIndex int, isAddressQuery bool, options *adapter.DNSQueryOptions) (adapter.DNSTransport, adapter.DNSRule, int) {
@@ -702,9 +605,13 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte
}
return &responseMessage, nil
}
snapshot := r.acquireRulesSnapshot()
defer snapshot.release()
rules, legacyDNSMode := snapshot.rulesAndMode()
r.rulesAccess.RLock()
defer r.rulesAccess.RUnlock()
if r.closing {
return nil, E.New("dns router closed")
}
rules := r.rules
legacyDNSMode := r.legacyDNSMode
r.logger.DebugContext(ctx, "exchange ", FormatQuestion(message.Question[0].String()))
var (
response *mDNS.Msg
@@ -810,9 +717,13 @@ done:
}
func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, error) {
snapshot := r.acquireRulesSnapshot()
defer snapshot.release()
rules, legacyDNSMode := snapshot.rulesAndMode()
r.rulesAccess.RLock()
defer r.rulesAccess.RUnlock()
if r.closing {
return nil, E.New("dns router closed")
}
rules := r.rules
legacyDNSMode := r.legacyDNSMode
var (
responseAddrs []netip.Addr
err error
@@ -979,8 +890,8 @@ func (f *dnsRuleModeFlags) merge(other dnsRuleModeFlags) {
f.neededFromStrategy = f.neededFromStrategy || other.neededFromStrategy
}
func resolveLegacyDNSMode(router adapter.Router, rules []option.DNSRule) (bool, dnsRuleModeFlags, error) {
flags, err := dnsRuleModeRequirements(router, rules)
func resolveLegacyDNSMode(router adapter.Router, rules []option.DNSRule, metadataOverrides map[string]adapter.RuleSetMetadata) (bool, dnsRuleModeFlags, error) {
flags, err := dnsRuleModeRequirements(router, rules, metadataOverrides)
if err != nil {
return false, flags, err
}
@@ -993,10 +904,10 @@ func resolveLegacyDNSMode(router adapter.Router, rules []option.DNSRule) (bool,
return flags.needed, flags, nil
}
func dnsRuleModeRequirements(router adapter.Router, rules []option.DNSRule) (dnsRuleModeFlags, error) {
func dnsRuleModeRequirements(router adapter.Router, rules []option.DNSRule, metadataOverrides map[string]adapter.RuleSetMetadata) (dnsRuleModeFlags, error) {
var flags dnsRuleModeFlags
for i, rule := range rules {
ruleFlags, err := dnsRuleModeRequirementsInRule(router, rule)
ruleFlags, err := dnsRuleModeRequirementsInRule(router, rule, metadataOverrides)
if err != nil {
return dnsRuleModeFlags{}, E.Cause(err, "dns rule[", i, "]")
}
@@ -1005,10 +916,10 @@ func dnsRuleModeRequirements(router adapter.Router, rules []option.DNSRule) (dns
return flags, nil
}
func dnsRuleModeRequirementsInRule(router adapter.Router, rule option.DNSRule) (dnsRuleModeFlags, error) {
func dnsRuleModeRequirementsInRule(router adapter.Router, rule option.DNSRule, metadataOverrides map[string]adapter.RuleSetMetadata) (dnsRuleModeFlags, error) {
switch rule.Type {
case "", C.RuleTypeDefault:
return dnsRuleModeRequirementsInDefaultRule(router, rule.DefaultOptions)
return dnsRuleModeRequirementsInDefaultRule(router, rule.DefaultOptions, metadataOverrides)
case C.RuleTypeLogical:
flags := dnsRuleModeFlags{
disabled: dnsRuleActionType(rule) == C.RuleActionTypeEvaluate || dnsRuleActionType(rule) == C.RuleActionTypeRespond,
@@ -1016,7 +927,7 @@ func dnsRuleModeRequirementsInRule(router adapter.Router, rule option.DNSRule) (
}
flags.needed = flags.neededFromStrategy
for i, subRule := range rule.LogicalOptions.Rules {
subFlags, err := dnsRuleModeRequirementsInRule(router, subRule)
subFlags, err := dnsRuleModeRequirementsInRule(router, subRule, metadataOverrides)
if err != nil {
return dnsRuleModeFlags{}, E.Cause(err, "sub rule[", i, "]")
}
@@ -1028,7 +939,7 @@ func dnsRuleModeRequirementsInRule(router adapter.Router, rule option.DNSRule) (
}
}
func dnsRuleModeRequirementsInDefaultRule(router adapter.Router, rule option.DefaultDNSRule) (dnsRuleModeFlags, error) {
func dnsRuleModeRequirementsInDefaultRule(router adapter.Router, rule option.DefaultDNSRule, metadataOverrides map[string]adapter.RuleSetMetadata) (dnsRuleModeFlags, error) {
flags := dnsRuleModeFlags{
disabled: defaultRuleDisablesLegacyDNSMode(rule),
neededFromStrategy: dnsRuleActionHasStrategy(rule.DNSRuleAction),
@@ -1041,11 +952,10 @@ func dnsRuleModeRequirementsInDefaultRule(router adapter.Router, rule option.Def
return dnsRuleModeFlags{}, E.New("router service not found")
}
for _, tag := range rule.RuleSet {
ruleSet, loaded := router.RuleSet(tag)
if !loaded {
return dnsRuleModeFlags{}, E.New("rule-set not found: ", tag)
metadata, err := lookupDNSRuleSetMetadata(router, tag, metadataOverrides)
if err != nil {
return dnsRuleModeFlags{}, err
}
metadata := ruleSet.Metadata()
// ip_version is not a headless-rule item, so ContainsIPVersionRule is intentionally absent.
flags.disabled = flags.disabled || metadata.ContainsDNSQueryTypeRule
if !rule.RuleSetIPCIDRMatchSource && metadata.ContainsIPCIDRRule {
@@ -1055,6 +965,19 @@ func dnsRuleModeRequirementsInDefaultRule(router adapter.Router, rule option.Def
return flags, nil
}
func lookupDNSRuleSetMetadata(router adapter.Router, tag string, metadataOverrides map[string]adapter.RuleSetMetadata) (adapter.RuleSetMetadata, error) {
if metadataOverrides != nil {
if metadata, loaded := metadataOverrides[tag]; loaded {
return metadata, nil
}
}
ruleSet, loaded := router.RuleSet(tag)
if !loaded {
return adapter.RuleSetMetadata{}, E.New("rule-set not found: ", tag)
}
return ruleSet.Metadata(), nil
}
func referencedDNSRuleSetTags(rules []option.DNSRule) []string {
tagMap := make(map[string]bool)
var walkRule func(rule option.DNSRule)

File diff suppressed because it is too large Load Diff

View File

@@ -9,6 +9,7 @@ import (
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/service"
"go4.org/netipx"
)
@@ -73,3 +74,20 @@ func isIPCIDRHeadlessRule(rule option.DefaultHeadlessRule) bool {
func isDNSQueryTypeHeadlessRule(rule option.DefaultHeadlessRule) bool {
return len(rule.QueryType) > 0
}
func buildRuleSetMetadata(headlessRules []option.HeadlessRule) adapter.RuleSetMetadata {
return adapter.RuleSetMetadata{
ContainsProcessRule: HasHeadlessRule(headlessRules, isProcessHeadlessRule),
ContainsWIFIRule: HasHeadlessRule(headlessRules, isWIFIHeadlessRule),
ContainsIPCIDRRule: HasHeadlessRule(headlessRules, isIPCIDRHeadlessRule),
ContainsDNSQueryTypeRule: HasHeadlessRule(headlessRules, isDNSQueryTypeHeadlessRule),
}
}
func validateRuleSetMetadataUpdate(ctx context.Context, tag string, metadata adapter.RuleSetMetadata) error {
validator := service.FromContext[adapter.DNSRuleSetUpdateValidator](ctx)
if validator == nil {
return nil
}
return validator.ValidateRuleSetMetadataUpdate(tag, metadata)
}

View File

@@ -137,11 +137,11 @@ func (s *LocalRuleSet) reloadRules(headlessRules []option.HeadlessRule) error {
return E.Cause(err, "parse rule_set.rules.[", i, "]")
}
}
var metadata adapter.RuleSetMetadata
metadata.ContainsProcessRule = HasHeadlessRule(headlessRules, isProcessHeadlessRule)
metadata.ContainsWIFIRule = HasHeadlessRule(headlessRules, isWIFIHeadlessRule)
metadata.ContainsIPCIDRRule = HasHeadlessRule(headlessRules, isIPCIDRHeadlessRule)
metadata.ContainsDNSQueryTypeRule = HasHeadlessRule(headlessRules, isDNSQueryTypeHeadlessRule)
metadata := buildRuleSetMetadata(headlessRules)
err = validateRuleSetMetadataUpdate(s.ctx, s.tag, metadata)
if err != nil {
return err
}
s.access.Lock()
s.rules = rules
s.metadata = metadata

View File

@@ -189,11 +189,13 @@ func (s *RemoteRuleSet) loadBytes(content []byte) error {
return E.Cause(err, "parse rule_set.rules.[", i, "]")
}
}
metadata := buildRuleSetMetadata(plainRuleSet.Rules)
err = validateRuleSetMetadataUpdate(s.ctx, s.options.Tag, metadata)
if err != nil {
return err
}
s.access.Lock()
s.metadata.ContainsProcessRule = HasHeadlessRule(plainRuleSet.Rules, isProcessHeadlessRule)
s.metadata.ContainsWIFIRule = HasHeadlessRule(plainRuleSet.Rules, isWIFIHeadlessRule)
s.metadata.ContainsIPCIDRRule = HasHeadlessRule(plainRuleSet.Rules, isIPCIDRHeadlessRule)
s.metadata.ContainsDNSQueryTypeRule = HasHeadlessRule(plainRuleSet.Rules, isDNSQueryTypeHeadlessRule)
s.metadata = metadata
s.rules = rules
callbacks := s.callbacks.Array()
s.access.Unlock()

View File

@@ -0,0 +1,110 @@
package rule
import (
"context"
"sync/atomic"
"testing"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json/badoption"
"github.com/sagernet/sing/common/x/list"
"github.com/sagernet/sing/service"
"github.com/stretchr/testify/require"
)
type fakeDNSRuleSetUpdateValidator struct {
validate func(tag string, metadata adapter.RuleSetMetadata) error
}
func (v *fakeDNSRuleSetUpdateValidator) ValidateRuleSetMetadataUpdate(tag string, metadata adapter.RuleSetMetadata) error {
if v.validate == nil {
return nil
}
return v.validate(tag, metadata)
}
func TestLocalRuleSetReloadRulesRejectsInvalidUpdateBeforeCommit(t *testing.T) {
t.Parallel()
var callbackCount atomic.Int32
ctx := service.ContextWith[adapter.DNSRuleSetUpdateValidator](context.Background(), &fakeDNSRuleSetUpdateValidator{
validate: func(tag string, metadata adapter.RuleSetMetadata) error {
require.Equal(t, "dynamic-set", tag)
if metadata.ContainsDNSQueryTypeRule {
return E.New("dns conflict")
}
return nil
},
})
ruleSet := &LocalRuleSet{
ctx: ctx,
tag: "dynamic-set",
fileFormat: C.RuleSetFormatSource,
}
_ = ruleSet.callbacks.PushBack(func(adapter.RuleSet) {
callbackCount.Add(1)
})
err := ruleSet.reloadRules([]option.HeadlessRule{{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultHeadlessRule{
Domain: badoption.Listable[string]{"example.com"},
},
}})
require.NoError(t, err)
require.Equal(t, int32(1), callbackCount.Load())
require.False(t, ruleSet.metadata.ContainsDNSQueryTypeRule)
require.True(t, ruleSet.Match(&adapter.InboundContext{Domain: "example.com"}))
err = ruleSet.reloadRules([]option.HeadlessRule{{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultHeadlessRule{
QueryType: badoption.Listable[option.DNSQueryType]{option.DNSQueryType(1)},
},
}})
require.ErrorContains(t, err, "dns conflict")
require.Equal(t, int32(1), callbackCount.Load())
require.False(t, ruleSet.metadata.ContainsDNSQueryTypeRule)
require.True(t, ruleSet.Match(&adapter.InboundContext{Domain: "example.com"}))
}
func TestRemoteRuleSetLoadBytesRejectsInvalidUpdateBeforeCommit(t *testing.T) {
t.Parallel()
var callbackCount atomic.Int32
ctx := service.ContextWith[adapter.DNSRuleSetUpdateValidator](context.Background(), &fakeDNSRuleSetUpdateValidator{
validate: func(tag string, metadata adapter.RuleSetMetadata) error {
require.Equal(t, "dynamic-set", tag)
if metadata.ContainsDNSQueryTypeRule {
return E.New("dns conflict")
}
return nil
},
})
ruleSet := &RemoteRuleSet{
ctx: ctx,
options: option.RuleSet{
Tag: "dynamic-set",
Format: C.RuleSetFormatSource,
},
callbacks: list.List[adapter.RuleSetUpdateCallback]{},
}
_ = ruleSet.callbacks.PushBack(func(adapter.RuleSet) {
callbackCount.Add(1)
})
err := ruleSet.loadBytes([]byte(`{"version":4,"rules":[{"domain":["example.com"]}]}`))
require.NoError(t, err)
require.Equal(t, int32(1), callbackCount.Load())
require.False(t, ruleSet.metadata.ContainsDNSQueryTypeRule)
require.True(t, ruleSet.Match(&adapter.InboundContext{Domain: "example.com"}))
err = ruleSet.loadBytes([]byte(`{"version":4,"rules":[{"query_type":["A"]}]}`))
require.ErrorContains(t, err, "dns conflict")
require.Equal(t, int32(1), callbackCount.Load())
require.False(t, ruleSet.metadata.ContainsDNSQueryTypeRule)
require.True(t, ruleSet.Match(&adapter.InboundContext{Domain: "example.com"}))
}