From bdfb3449554f7e015fd7de1b5bf2ffd73b0903d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 2 Apr 2026 00:24:16 +0800 Subject: [PATCH] dns: validate rule-set updates before commit --- adapter/router.go | 4 + box.go | 1 + dns/router.go | 309 +++---- dns/router_test.go | 770 +++++------------- route/rule/rule_set.go | 18 + route/rule/rule_set_local.go | 10 +- route/rule/rule_set_remote.go | 10 +- route/rule/rule_set_update_validation_test.go | 110 +++ 8 files changed, 470 insertions(+), 762 deletions(-) create mode 100644 route/rule/rule_set_update_validation_test.go diff --git a/adapter/router.go b/adapter/router.go index 550aa6629..f1e3da9a0 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -66,6 +66,10 @@ type RuleSet interface { type RuleSetUpdateCallback func(it RuleSet) +type DNSRuleSetUpdateValidator interface { + ValidateRuleSetMetadataUpdate(tag string, metadata RuleSetMetadata) error +} + // ip_version is not a headless-rule item, so ContainsIPVersionRule is intentionally absent. type RuleSetMetadata struct { ContainsProcessRule bool diff --git a/box.go b/box.go index 82403a29c..04faabbb2 100644 --- a/box.go +++ b/box.go @@ -199,6 +199,7 @@ func New(options Options) (*Box, error) { service.MustRegister[adapter.CertificateProviderManager](ctx, certificateProviderManager) dnsRouter := dns.NewRouter(ctx, logFactory, dnsOptions) service.MustRegister[adapter.DNSRouter](ctx, dnsRouter) + service.MustRegister[adapter.DNSRuleSetUpdateValidator](ctx, dnsRouter) networkManager, err := route.NewNetworkManager(ctx, logFactory.NewLogger("network"), routeOptions, dnsOptions) if err != nil { return nil, E.Cause(err, "initialize network manager") diff --git a/dns/router.go b/dns/router.go index 0ad1b0c30..a485e599c 100644 --- a/dns/router.go +++ b/dns/router.go @@ -6,7 +6,6 @@ import ( "net/netip" "strings" "sync" - "sync/atomic" "time" "github.com/sagernet/sing-box/adapter" @@ -23,7 +22,6 @@ import ( "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" "github.com/sagernet/sing/common/task" - "github.com/sagernet/sing/common/x/list" "github.com/sagernet/sing/contrab/freelru" "github.com/sagernet/sing/contrab/maphash" "github.com/sagernet/sing/service" @@ -32,54 +30,7 @@ import ( ) var _ adapter.DNSRouter = (*Router)(nil) - -type dnsRuleSetCallback struct { - ruleSet adapter.RuleSet - element *list.Element[adapter.RuleSetUpdateCallback] -} - -type rulesSnapshot struct { - rules []adapter.DNSRule - legacyDNSMode bool - references atomic.Int64 -} - -func newRulesSnapshot(rules []adapter.DNSRule, legacyDNSMode bool) *rulesSnapshot { - snapshot := &rulesSnapshot{ - rules: rules, - legacyDNSMode: legacyDNSMode, - } - snapshot.references.Store(1) - return snapshot -} - -func (s *rulesSnapshot) retain() { - if s == nil { - return - } - s.references.Add(1) -} - -func (s *rulesSnapshot) rulesAndMode() ([]adapter.DNSRule, bool) { - if s == nil { - return nil, false - } - return s.rules, s.legacyDNSMode -} - -func (s *rulesSnapshot) release() { - if s == nil { - return - } - references := s.references.Add(-1) - switch { - case references > 0: - case references == 0: - closeRules(s.rules) - default: - panic("dns: negative rules snapshot references") - } -} +var _ adapter.DNSRuleSetUpdateValidator = (*Router)(nil) type Router struct { ctx context.Context @@ -88,14 +39,14 @@ type Router struct { outbound adapter.OutboundManager client adapter.DNSClient rawRules []option.DNSRule - currentRules atomic.Pointer[rulesSnapshot] + rules []adapter.DNSRule defaultDomainStrategy C.DomainStrategy dnsReverseMapping freelru.Cache[netip.Addr, string] platformInterface adapter.PlatformInterface - rebuildAccess sync.Mutex - stateAccess sync.Mutex + legacyDNSMode bool + rulesAccess sync.RWMutex + started bool closing bool - ruleSetCallbacks []dnsRuleSetCallback addressFilterDeprecatedReported bool ruleStrategyDeprecatedReported bool } @@ -107,9 +58,9 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOp transport: service.FromContext[adapter.DNSTransportManager](ctx), outbound: service.FromContext[adapter.OutboundManager](ctx), rawRules: make([]option.DNSRule, 0, len(options.Rules)), + rules: make([]adapter.DNSRule, 0, len(options.Rules)), defaultDomainStrategy: C.DomainStrategy(options.Strategy), } - router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, len(options.Rules)), false)) router.client = NewClient(ClientOptions{ DisableCache: options.DNSClientOptions.DisableCache, DisableExpire: options.DNSClientOptions.DisableExpire, @@ -153,107 +104,57 @@ func (r *Router) Start(stage adapter.StartStage) error { monitor.Finish() monitor.Start("initialize DNS rules") - err := r.rebuildRules(true) + newRules, legacyDNSMode, modeFlags, err := r.buildRules(true) monitor.Finish() if err != nil { return err } - monitor.Start("register DNS rule-set callbacks") - needsRulesRefresh, err := r.registerRuleSetCallbacks() - monitor.Finish() - if err != nil { - return err + shouldReportAddressFilterDeprecated := legacyDNSMode && + !r.addressFilterDeprecatedReported && + common.Any(newRules, func(rule adapter.DNSRule) bool { return rule.WithAddressLimit() }) + shouldReportRuleStrategyDeprecated := legacyDNSMode && + !r.ruleStrategyDeprecatedReported && + modeFlags.neededFromStrategy + r.rulesAccess.Lock() + if r.closing { + r.rulesAccess.Unlock() + closeRules(newRules) + return nil } - if needsRulesRefresh { - monitor.Start("refresh DNS rules after callback registration") - err = r.rebuildRules(true) - monitor.Finish() - if err != nil { - r.logger.Error(E.Cause(err, "refresh DNS rules after callback registration")) - } + r.rules = newRules + r.legacyDNSMode = legacyDNSMode + r.started = true + if shouldReportAddressFilterDeprecated { + r.addressFilterDeprecatedReported = true + } + if shouldReportRuleStrategyDeprecated { + r.ruleStrategyDeprecatedReported = true + } + r.rulesAccess.Unlock() + if shouldReportAddressFilterDeprecated { + deprecated.Report(r.ctx, deprecated.OptionLegacyDNSAddressFilter) + } + if shouldReportRuleStrategyDeprecated { + deprecated.Report(r.ctx, deprecated.OptionLegacyDNSRuleStrategy) } } return nil } func (r *Router) Close() error { - r.stateAccess.Lock() + r.rulesAccess.Lock() if r.closing { - r.stateAccess.Unlock() + r.rulesAccess.Unlock() return nil } r.closing = true - callbacks := r.ruleSetCallbacks - r.ruleSetCallbacks = nil - oldSnapshot := r.currentRules.Swap(nil) - for _, callback := range callbacks { - callback.ruleSet.UnregisterCallback(callback.element) - } - r.stateAccess.Unlock() - oldSnapshot.release() + runtimeRules := r.rules + r.rules = nil + r.rulesAccess.Unlock() + closeRules(runtimeRules) return nil } -func (r *Router) rebuildRules(startRules bool) error { - r.rebuildAccess.Lock() - defer r.rebuildAccess.Unlock() - if r.isClosing() { - return nil - } - newRules, legacyDNSMode, modeFlags, err := r.buildRules(startRules) - if err != nil { - if r.isClosing() { - return nil - } - return err - } - shouldReportAddressFilterDeprecated := startRules && - legacyDNSMode && - !r.addressFilterDeprecatedReported && - common.Any(newRules, func(rule adapter.DNSRule) bool { return rule.WithAddressLimit() }) - shouldReportRuleStrategyDeprecated := startRules && - legacyDNSMode && - !r.ruleStrategyDeprecatedReported && - modeFlags.neededFromStrategy - newSnapshot := newRulesSnapshot(newRules, legacyDNSMode) - r.stateAccess.Lock() - if r.closing { - r.stateAccess.Unlock() - newSnapshot.release() - return nil - } - if shouldReportAddressFilterDeprecated { - r.addressFilterDeprecatedReported = true - } - if shouldReportRuleStrategyDeprecated { - r.ruleStrategyDeprecatedReported = true - } - oldSnapshot := r.currentRules.Swap(newSnapshot) - r.stateAccess.Unlock() - oldSnapshot.release() - if shouldReportAddressFilterDeprecated { - deprecated.Report(r.ctx, deprecated.OptionLegacyDNSAddressFilter) - } - if shouldReportRuleStrategyDeprecated { - deprecated.Report(r.ctx, deprecated.OptionLegacyDNSRuleStrategy) - } - return nil -} - -func (r *Router) isClosing() bool { - r.stateAccess.Lock() - defer r.stateAccess.Unlock() - return r.closing -} - -func (r *Router) acquireRulesSnapshot() *rulesSnapshot { - r.stateAccess.Lock() - defer r.stateAccess.Unlock() - snapshot := r.currentRules.Load() - snapshot.retain() - return snapshot -} - func (r *Router) buildRules(startRules bool) ([]adapter.DNSRule, bool, dnsRuleModeFlags, error) { for i, ruleOptions := range r.rawRules { err := R.ValidateNoNestedDNSRuleActions(ruleOptions) @@ -262,7 +163,7 @@ func (r *Router) buildRules(startRules bool) ([]adapter.DNSRule, bool, dnsRuleMo } } router := service.FromContext[adapter.Router](r.ctx) - legacyDNSMode, modeFlags, err := resolveLegacyDNSMode(router, r.rawRules) + legacyDNSMode, modeFlags, err := resolveLegacyDNSMode(router, r.rawRules, nil) if err != nil { return nil, false, dnsRuleModeFlags{}, err } @@ -304,51 +205,53 @@ func closeRules(rules []adapter.DNSRule) { } } -func (r *Router) registerRuleSetCallbacks() (bool, error) { - tags := referencedDNSRuleSetTags(r.rawRules) - if len(tags) == 0 { - return false, nil +func (r *Router) ValidateRuleSetMetadataUpdate(tag string, metadata adapter.RuleSetMetadata) error { + if len(r.rawRules) == 0 { + return nil } - r.stateAccess.Lock() - if len(r.ruleSetCallbacks) > 0 { - r.stateAccess.Unlock() - return true, nil - } - r.stateAccess.Unlock() router := service.FromContext[adapter.Router](r.ctx) if router == nil { - return false, E.New("router service not found") + return E.New("router service not found") } - callbacks := make([]dnsRuleSetCallback, 0, len(tags)) - for _, tag := range tags { - ruleSet, loaded := router.RuleSet(tag) - if !loaded { - for _, callback := range callbacks { - callback.ruleSet.UnregisterCallback(callback.element) - } - return false, E.New("rule-set not found: ", tag) + overrides := map[string]adapter.RuleSetMetadata{ + tag: metadata, + } + r.rulesAccess.RLock() + started := r.started + legacyDNSMode := r.legacyDNSMode + closing := r.closing + r.rulesAccess.RUnlock() + if closing { + return nil + } + if !started { + candidateLegacyDNSMode, _, err := resolveLegacyDNSMode(router, r.rawRules, overrides) + if err != nil { + return err } - element := ruleSet.RegisterCallback(func(adapter.RuleSet) { - err := r.rebuildRules(true) + if !candidateLegacyDNSMode { + return validateLegacyDNSModeDisabledRules(r.rawRules) + } + return nil + } + _, flags, err := resolveLegacyDNSMode(router, r.rawRules, overrides) + if err != nil { + return err + } + if legacyDNSMode { + if flags.disabled { + err := validateLegacyDNSModeDisabledRules(r.rawRules) if err != nil { - r.logger.Error(E.Cause(err, "rebuild DNS rules after rule-set update")) + return err } - }) - callbacks = append(callbacks, dnsRuleSetCallback{ - ruleSet: ruleSet, - element: element, - }) + return E.New(deprecated.OptionLegacyDNSAddressFilter.MessageWithLink()) + } + return nil } - r.stateAccess.Lock() - if !r.closing && len(r.ruleSetCallbacks) == 0 { - r.ruleSetCallbacks = callbacks - callbacks = nil + if flags.needed { + return E.New(deprecated.OptionLegacyDNSAddressFilter.MessageWithLink()) } - r.stateAccess.Unlock() - for _, callback := range callbacks { - callback.ruleSet.UnregisterCallback(callback.element) - } - return true, nil + return nil } func (r *Router) matchDNS(ctx context.Context, rules []adapter.DNSRule, allowFakeIP bool, ruleIndex int, isAddressQuery bool, options *adapter.DNSQueryOptions) (adapter.DNSTransport, adapter.DNSRule, int) { @@ -702,9 +605,13 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte } return &responseMessage, nil } - snapshot := r.acquireRulesSnapshot() - defer snapshot.release() - rules, legacyDNSMode := snapshot.rulesAndMode() + r.rulesAccess.RLock() + defer r.rulesAccess.RUnlock() + if r.closing { + return nil, E.New("dns router closed") + } + rules := r.rules + legacyDNSMode := r.legacyDNSMode r.logger.DebugContext(ctx, "exchange ", FormatQuestion(message.Question[0].String())) var ( response *mDNS.Msg @@ -810,9 +717,13 @@ done: } func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, error) { - snapshot := r.acquireRulesSnapshot() - defer snapshot.release() - rules, legacyDNSMode := snapshot.rulesAndMode() + r.rulesAccess.RLock() + defer r.rulesAccess.RUnlock() + if r.closing { + return nil, E.New("dns router closed") + } + rules := r.rules + legacyDNSMode := r.legacyDNSMode var ( responseAddrs []netip.Addr err error @@ -979,8 +890,8 @@ func (f *dnsRuleModeFlags) merge(other dnsRuleModeFlags) { f.neededFromStrategy = f.neededFromStrategy || other.neededFromStrategy } -func resolveLegacyDNSMode(router adapter.Router, rules []option.DNSRule) (bool, dnsRuleModeFlags, error) { - flags, err := dnsRuleModeRequirements(router, rules) +func resolveLegacyDNSMode(router adapter.Router, rules []option.DNSRule, metadataOverrides map[string]adapter.RuleSetMetadata) (bool, dnsRuleModeFlags, error) { + flags, err := dnsRuleModeRequirements(router, rules, metadataOverrides) if err != nil { return false, flags, err } @@ -993,10 +904,10 @@ func resolveLegacyDNSMode(router adapter.Router, rules []option.DNSRule) (bool, return flags.needed, flags, nil } -func dnsRuleModeRequirements(router adapter.Router, rules []option.DNSRule) (dnsRuleModeFlags, error) { +func dnsRuleModeRequirements(router adapter.Router, rules []option.DNSRule, metadataOverrides map[string]adapter.RuleSetMetadata) (dnsRuleModeFlags, error) { var flags dnsRuleModeFlags for i, rule := range rules { - ruleFlags, err := dnsRuleModeRequirementsInRule(router, rule) + ruleFlags, err := dnsRuleModeRequirementsInRule(router, rule, metadataOverrides) if err != nil { return dnsRuleModeFlags{}, E.Cause(err, "dns rule[", i, "]") } @@ -1005,10 +916,10 @@ func dnsRuleModeRequirements(router adapter.Router, rules []option.DNSRule) (dns return flags, nil } -func dnsRuleModeRequirementsInRule(router adapter.Router, rule option.DNSRule) (dnsRuleModeFlags, error) { +func dnsRuleModeRequirementsInRule(router adapter.Router, rule option.DNSRule, metadataOverrides map[string]adapter.RuleSetMetadata) (dnsRuleModeFlags, error) { switch rule.Type { case "", C.RuleTypeDefault: - return dnsRuleModeRequirementsInDefaultRule(router, rule.DefaultOptions) + return dnsRuleModeRequirementsInDefaultRule(router, rule.DefaultOptions, metadataOverrides) case C.RuleTypeLogical: flags := dnsRuleModeFlags{ disabled: dnsRuleActionType(rule) == C.RuleActionTypeEvaluate || dnsRuleActionType(rule) == C.RuleActionTypeRespond, @@ -1016,7 +927,7 @@ func dnsRuleModeRequirementsInRule(router adapter.Router, rule option.DNSRule) ( } flags.needed = flags.neededFromStrategy for i, subRule := range rule.LogicalOptions.Rules { - subFlags, err := dnsRuleModeRequirementsInRule(router, subRule) + subFlags, err := dnsRuleModeRequirementsInRule(router, subRule, metadataOverrides) if err != nil { return dnsRuleModeFlags{}, E.Cause(err, "sub rule[", i, "]") } @@ -1028,7 +939,7 @@ func dnsRuleModeRequirementsInRule(router adapter.Router, rule option.DNSRule) ( } } -func dnsRuleModeRequirementsInDefaultRule(router adapter.Router, rule option.DefaultDNSRule) (dnsRuleModeFlags, error) { +func dnsRuleModeRequirementsInDefaultRule(router adapter.Router, rule option.DefaultDNSRule, metadataOverrides map[string]adapter.RuleSetMetadata) (dnsRuleModeFlags, error) { flags := dnsRuleModeFlags{ disabled: defaultRuleDisablesLegacyDNSMode(rule), neededFromStrategy: dnsRuleActionHasStrategy(rule.DNSRuleAction), @@ -1041,11 +952,10 @@ func dnsRuleModeRequirementsInDefaultRule(router adapter.Router, rule option.Def return dnsRuleModeFlags{}, E.New("router service not found") } for _, tag := range rule.RuleSet { - ruleSet, loaded := router.RuleSet(tag) - if !loaded { - return dnsRuleModeFlags{}, E.New("rule-set not found: ", tag) + metadata, err := lookupDNSRuleSetMetadata(router, tag, metadataOverrides) + if err != nil { + return dnsRuleModeFlags{}, err } - metadata := ruleSet.Metadata() // ip_version is not a headless-rule item, so ContainsIPVersionRule is intentionally absent. flags.disabled = flags.disabled || metadata.ContainsDNSQueryTypeRule if !rule.RuleSetIPCIDRMatchSource && metadata.ContainsIPCIDRRule { @@ -1055,6 +965,19 @@ func dnsRuleModeRequirementsInDefaultRule(router adapter.Router, rule option.Def return flags, nil } +func lookupDNSRuleSetMetadata(router adapter.Router, tag string, metadataOverrides map[string]adapter.RuleSetMetadata) (adapter.RuleSetMetadata, error) { + if metadataOverrides != nil { + if metadata, loaded := metadataOverrides[tag]; loaded { + return metadata, nil + } + } + ruleSet, loaded := router.RuleSet(tag) + if !loaded { + return adapter.RuleSetMetadata{}, E.New("rule-set not found: ", tag) + } + return ruleSet.Metadata(), nil +} + func referencedDNSRuleSetTags(rules []option.DNSRule) []string { tagMap := make(map[string]bool) var walkRule func(rule option.DNSRule) diff --git a/dns/router_test.go b/dns/router_test.go index aff6a318d..a6f71d877 100644 --- a/dns/router_test.go +++ b/dns/router_test.go @@ -2,7 +2,6 @@ package dns import ( "context" - "io" "net" "net/netip" "strings" @@ -17,7 +16,6 @@ import ( "github.com/sagernet/sing-box/option" rulepkg "github.com/sagernet/sing-box/route/rule" "github.com/sagernet/sing-tun" - "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/json/badoption" N "github.com/sagernet/sing/common/network" @@ -75,6 +73,7 @@ func (m *fakeDNSTransportManager) Create(context.Context, log.ContextLogger, str type fakeDNSClient struct { beforeExchange func(ctx context.Context, transport adapter.DNSTransport, message *mDNS.Msg) exchange func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) + lookupWithCtx func(ctx context.Context, transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) lookup func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) } @@ -233,14 +232,47 @@ func (c *fakeDNSClient) Exchange(ctx context.Context, transport adapter.DNSTrans if c.beforeExchange != nil { c.beforeExchange(ctx, transport, message) } + if c.exchange == nil { + if len(message.Question) != 1 { + return nil, E.New("unused client exchange") + } + var ( + addresses []netip.Addr + response *mDNS.Msg + err error + ) + if c.lookupWithCtx != nil { + addresses, response, err = c.lookupWithCtx(ctx, transport, FqdnToDomain(message.Question[0].Name), adapter.DNSQueryOptions{}) + } else if c.lookup != nil { + addresses, response, err = c.lookup(transport, FqdnToDomain(message.Question[0].Name), adapter.DNSQueryOptions{}) + } else { + return nil, E.New("unused client exchange") + } + if err != nil { + return nil, err + } + if response != nil { + return response, nil + } + return FixedResponse(0, message.Question[0], addresses, 60), nil + } return c.exchange(transport, message) } -func (c *fakeDNSClient) Lookup(_ context.Context, transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions, responseChecker func(*mDNS.Msg) bool) ([]netip.Addr, error) { - if c.lookup == nil { +func (c *fakeDNSClient) Lookup(ctx context.Context, transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions, responseChecker func(*mDNS.Msg) bool) ([]netip.Addr, error) { + if c.lookup == nil && c.lookupWithCtx == nil { return nil, E.New("unused client lookup") } - addresses, response, err := c.lookup(transport, domain, options) + var ( + addresses []netip.Addr + response *mDNS.Msg + err error + ) + if c.lookupWithCtx != nil { + addresses, response, err = c.lookupWithCtx(ctx, transport, domain, options) + } else { + addresses, response, err = c.lookup(transport, domain, options) + } if err != nil { return nil, err } @@ -278,9 +310,9 @@ func newTestRouterWithContextAndLogger(t *testing.T, ctx context.Context, rules transport: transportManager, client: client, rawRules: make([]option.DNSRule, 0, len(rules)), + rules: make([]adapter.DNSRule, 0, len(rules)), defaultDomainStrategy: C.DomainStrategyAsIS, } - router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, len(rules)), false)) if rules != nil { err := router.Initialize(rules) require.NoError(t, err) @@ -356,7 +388,6 @@ func TestInitializeRejectsDirectLegacyRuleWhenRuleSetForcesNew(t *testing.T) { rawRules: make([]option.DNSRule, 0, 2), defaultDomainStrategy: C.DomainStrategyAsIS, } - router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, 2), false)) err = router.Initialize([]option.DNSRule{ { Type: C.RuleTypeDefault, @@ -438,7 +469,7 @@ func TestLookupLegacyDNSModeDefersRuleSetDestinationIPMatch(t *testing.T) { }, }) - require.True(t, router.currentRules.Load().legacyDNSMode) + require.True(t, router.legacyDNSMode) addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{ LookupStrategy: C.DomainStrategyIPv4Only, @@ -487,45 +518,21 @@ func TestRuleSetUpdateReleasesOldRuleSetRefs(t *testing.T) { require.Zero(t, fakeSet.refCount()) } -func TestRuleSetUpdateKeepsLastSuccessfullyCompiledRuleGraphWhenRebuildFails(t *testing.T) { +func TestValidateRuleSetMetadataUpdateRejectsRuleSetThatWouldDisableLegacyDNSMode(t *testing.T) { t.Parallel() - callbackRuleSet := &fakeRuleSet{ - match: func(*adapter.InboundContext) bool { - return false + fakeSet := &fakeRuleSet{ + metadata: adapter.RuleSetMetadata{ + ContainsIPCIDRRule: true, }, } routerService := &fakeRouter{ ruleSets: map[string]adapter.RuleSet{ - "dynamic-set": callbackRuleSet, + "dynamic-set": fakeSet, }, } ctx := service.ContextWith[adapter.Router](context.Background(), routerService) - defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} - preservedTransport := &fakeDNSTransport{tag: "preserved", transportType: C.DNSTypeUDP} - wouldBeNewTransport := &fakeDNSTransport{tag: "would-be-new", transportType: C.DNSTypeUDP} - loggerFactory := log.NewDefaultFactory( - context.Background(), - log.Formatter{ - BaseTime: time.Now(), - DisableColors: true, - DisableTimestamp: true, - }, - io.Discard, - "", - nil, - true, - ) - loggerFactory.SetLevel(log.LevelError) - logEntries, logDone, err := loggerFactory.Subscribe() - require.NoError(t, err) - t.Cleanup(func() { - loggerFactory.UnSubscribe(logEntries) - closeErr := loggerFactory.Close() - require.NoError(t, closeErr) - }) - var lastUsedTransport common.TypedValue[string] - router := newTestRouterWithContextAndLogger(t, ctx, []option.DNSRule{ + router := newTestRouterWithContext(t, ctx, []option.DNSRule{ { Type: C.RuleTypeDefault, DefaultOptions: option.DefaultDNSRule{ @@ -534,19 +541,7 @@ func TestRuleSetUpdateKeepsLastSuccessfullyCompiledRuleGraphWhenRebuildFails(t * }, DNSRuleAction: option.DNSRuleAction{ Action: C.RuleActionTypeRoute, - RouteOptions: option.DNSRouteActionOptions{Server: "would-be-new"}, - }, - }, - }, - { - Type: C.RuleTypeDefault, - DefaultOptions: option.DefaultDNSRule{ - RawDefaultDNSRule: option.RawDefaultDNSRule{ - Domain: badoption.Listable[string]{"example.com"}, - }, - DNSRuleAction: option.DNSRuleAction{ - Action: C.RuleActionTypeRoute, - RouteOptions: option.DNSRouteActionOptions{Server: "preserved"}, + RouteOptions: option.DNSRouteActionOptions{Server: "selected"}, }, }, }, @@ -558,275 +553,30 @@ func TestRuleSetUpdateKeepsLastSuccessfullyCompiledRuleGraphWhenRebuildFails(t * }, DNSRuleAction: option.DNSRuleAction{ Action: C.RuleActionTypeRoute, - RouteOptions: option.DNSRouteActionOptions{Server: "preserved"}, + RouteOptions: option.DNSRouteActionOptions{Server: "selected"}, }, }, }, }, &fakeDNSTransportManager{ - defaultTransport: defaultTransport, + defaultTransport: &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}, transports: map[string]adapter.DNSTransport{ - "default": defaultTransport, - "preserved": preservedTransport, - "would-be-new": wouldBeNewTransport, + "default": &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}, + "selected": &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP}, }, }, &fakeDNSClient{ - lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) { - lastUsedTransport.Store(transport.Tag()) - response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60) - return MessageToAddresses(response), response, nil + lookup: func(adapter.DNSTransport, string, adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) { + return []netip.Addr{netip.MustParseAddr("10.0.0.1")}, nil, nil }, - }, loggerFactory.NewLogger("dns")) - t.Cleanup(func() { - closeErr := router.Close() - require.NoError(t, closeErr) }) + require.True(t, router.legacyDNSMode) - require.True(t, router.currentRules.Load().legacyDNSMode) - require.Equal(t, 1, callbackRuleSet.refCount()) - - addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) - require.NoError(t, err) - require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses) - require.Equal(t, "preserved", lastUsedTransport.Load()) - - rebuildTargetRuleSet := &fakeRuleSet{ - metadata: adapter.RuleSetMetadata{ - ContainsDNSQueryTypeRule: true, - }, - match: func(*adapter.InboundContext) bool { - return true - }, - } - routerService.setRuleSet("dynamic-set", rebuildTargetRuleSet) - - callbackRuleSet.updateMetadata(adapter.RuleSetMetadata{ + err := router.ValidateRuleSetMetadataUpdate("dynamic-set", adapter.RuleSetMetadata{ ContainsDNSQueryTypeRule: true, }) - rebuildErrorEntry := waitForLogMessageContaining(t, logEntries, logDone, "rebuild DNS rules after rule-set update") - require.Contains(t, rebuildErrorEntry.Message, "Address Filter Fields") - require.True(t, router.currentRules.Load().legacyDNSMode) - require.Equal(t, 1, callbackRuleSet.refCount()) - require.Zero(t, rebuildTargetRuleSet.refCount()) - - lastUsedTransport.Store("") - addresses, err = router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) - require.NoError(t, err) - require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses) - require.Equal(t, "preserved", lastUsedTransport.Load()) - require.NotEqual(t, "would-be-new", lastUsedTransport.Load()) + require.ErrorContains(t, err, "Address Filter Fields") } -func TestRuleSetUpdateSerializesConcurrentRebuilds(t *testing.T) { - t.Parallel() - - callbackRuleSet := &fakeRuleSet{ - match: func(*adapter.InboundContext) bool { - return false - }, - } - routerService := &fakeRouter{ - ruleSets: map[string]adapter.RuleSet{ - "dynamic-set": callbackRuleSet, - }, - } - ctx := service.ContextWith[adapter.Router](context.Background(), routerService) - defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} - firstTransport := &fakeDNSTransport{tag: "first", transportType: C.DNSTypeUDP} - secondTransport := &fakeDNSTransport{tag: "second", transportType: C.DNSTypeUDP} - var lastUsedTransport common.TypedValue[string] - router := newTestRouterWithContext(t, ctx, []option.DNSRule{ - { - Type: C.RuleTypeDefault, - DefaultOptions: option.DefaultDNSRule{ - RawDefaultDNSRule: option.RawDefaultDNSRule{ - RuleSet: badoption.Listable[string]{"dynamic-set"}, - }, - DNSRuleAction: option.DNSRuleAction{ - Action: C.RuleActionTypeRoute, - RouteOptions: option.DNSRouteActionOptions{Server: "first"}, - }, - }, - }, - { - Type: C.RuleTypeDefault, - DefaultOptions: option.DefaultDNSRule{ - RawDefaultDNSRule: option.RawDefaultDNSRule{ - Domain: badoption.Listable[string]{"example.com"}, - }, - DNSRuleAction: option.DNSRuleAction{ - Action: C.RuleActionTypeRoute, - RouteOptions: option.DNSRouteActionOptions{Server: "second"}, - }, - }, - }, - }, &fakeDNSTransportManager{ - defaultTransport: defaultTransport, - transports: map[string]adapter.DNSTransport{ - "default": defaultTransport, - "first": firstTransport, - "second": secondTransport, - }, - }, &fakeDNSClient{ - exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) { - lastUsedTransport.Store(transport.Tag()) - return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60), nil - }, - }) - - addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) - require.NoError(t, err) - require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses) - require.Equal(t, "second", lastUsedTransport.Load()) - - callbacks := callbackRuleSet.snapshotCallbacks() - require.Len(t, callbacks, 1) - - firstMetadataEntered := make(chan struct{}) - releaseFirstMetadata := make(chan struct{}) - firstRuleSetStarted := make(chan struct{}) - releaseFirstRuleSetStart := make(chan struct{}) - secondMetadataEntered := make(chan struct{}) - releaseSecondMetadata := make(chan struct{}) - - var metadataAccess sync.Mutex - var metadataCallCount int - var concurrentMetadataCalls int - var maximumConcurrentMetadataCalls int - - recordMetadataEntry := func() func() { - metadataAccess.Lock() - metadataCallCount++ - concurrentMetadataCalls++ - if concurrentMetadataCalls > maximumConcurrentMetadataCalls { - maximumConcurrentMetadataCalls = concurrentMetadataCalls - } - metadataAccess.Unlock() - return func() { - metadataAccess.Lock() - concurrentMetadataCalls-- - metadataAccess.Unlock() - } - } - - firstBuildRuleSet := &fakeRuleSet{ - match: func(*adapter.InboundContext) bool { - return true - }, - metadataRead: func(metadata adapter.RuleSetMetadata) adapter.RuleSetMetadata { - metadataDone := recordMetadataEntry() - close(firstMetadataEntered) - <-releaseFirstMetadata - metadataDone() - return metadata - }, - afterIncrementReference: func() { - close(firstRuleSetStarted) - <-releaseFirstRuleSetStart - }, - } - secondBuildRuleSet := &fakeRuleSet{ - match: func(*adapter.InboundContext) bool { - return false - }, - metadataRead: func(metadata adapter.RuleSetMetadata) adapter.RuleSetMetadata { - metadataDone := recordMetadataEntry() - close(secondMetadataEntered) - <-releaseSecondMetadata - metadataDone() - return metadata - }, - } - - routerService.setRuleSet("dynamic-set", firstBuildRuleSet) - - firstCallbackFinished := make(chan struct{}) - go func() { - callbacks[0](callbackRuleSet) - close(firstCallbackFinished) - }() - - select { - case <-firstMetadataEntered: - case <-time.After(time.Second): - t.Fatal("first rebuild did not reach rule-set metadata") - } - - close(releaseFirstMetadata) - - select { - case <-firstRuleSetStarted: - case <-time.After(time.Second): - t.Fatal("first rebuild did not reach rule-set start") - } - - routerService.setRuleSet("dynamic-set", secondBuildRuleSet) - - secondCallbackStarted := make(chan struct{}) - secondCallbackFinished := make(chan struct{}) - go func() { - close(secondCallbackStarted) - callbacks[0](callbackRuleSet) - close(secondCallbackFinished) - }() - - select { - case <-secondCallbackStarted: - case <-time.After(time.Second): - t.Fatal("second rebuild did not start") - } - - select { - case <-secondMetadataEntered: - t.Fatal("second rebuild entered rule-set metadata before the first rebuild completed") - default: - } - - close(releaseFirstRuleSetStart) - - select { - case <-firstCallbackFinished: - case <-time.After(time.Second): - t.Fatal("first rebuild callback did not finish") - } - - select { - case <-secondMetadataEntered: - case <-time.After(time.Second): - t.Fatal("second rebuild did not enter rule-set metadata after the first rebuild finished") - } - - addresses, err = router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) - require.NoError(t, err) - require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses) - require.Equal(t, "first", lastUsedTransport.Load()) - - close(releaseSecondMetadata) - - select { - case <-secondCallbackFinished: - case <-time.After(time.Second): - t.Fatal("second rebuild callback did not finish") - } - - metadataAccess.Lock() - require.Equal(t, 2, metadataCallCount) - require.Equal(t, 1, maximumConcurrentMetadataCalls) - metadataAccess.Unlock() - - lastUsedTransport.Store("") - addresses, err = router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) - require.NoError(t, err) - require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses) - require.Equal(t, "second", lastUsedTransport.Load()) - - err = router.Close() - require.NoError(t, err) - require.Zero(t, callbackRuleSet.refCount()) - require.Zero(t, firstBuildRuleSet.refCount()) - require.Zero(t, secondBuildRuleSet.refCount()) -} - -func TestCloseDuringRebuildDiscardsResult(t *testing.T) { +func TestValidateRuleSetMetadataUpdateRejectsRuleSetOnlyLegacyModeSwitchToNew(t *testing.T) { t.Parallel() fakeSet := &fakeRuleSet{ @@ -834,96 +584,63 @@ func TestCloseDuringRebuildDiscardsResult(t *testing.T) { ContainsIPCIDRRule: true, }, } - ctx := service.ContextWith[adapter.Router](context.Background(), &fakeRouter{ + routerService := &fakeRouter{ ruleSets: map[string]adapter.RuleSet{ "dynamic-set": fakeSet, }, - }) - defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} - router := newTestRouterWithContext(t, ctx, []option.DNSRule{ - { - Type: C.RuleTypeDefault, - DefaultOptions: option.DefaultDNSRule{ - RawDefaultDNSRule: option.RawDefaultDNSRule{ - RuleSet: badoption.Listable[string]{"dynamic-set"}, - }, - DNSRuleAction: option.DNSRuleAction{ - Action: C.RuleActionTypeRoute, - RouteOptions: option.DNSRouteActionOptions{Server: "installed"}, - }, + } + ctx := service.ContextWith[adapter.Router](context.Background(), routerService) + router := newTestRouterWithContext(t, ctx, []option.DNSRule{{ + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + RuleSet: badoption.Listable[string]{"dynamic-set"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "selected"}, }, }, - }, &fakeDNSTransportManager{ - defaultTransport: defaultTransport, + }}, &fakeDNSTransportManager{ + defaultTransport: &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}, transports: map[string]adapter.DNSTransport{ - "default": defaultTransport, - "discarded": &fakeDNSTransport{tag: "discarded", transportType: C.DNSTypeUDP}, - "installed": &fakeDNSTransport{tag: "installed", transportType: C.DNSTypeUDP}, + "default": &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}, + "selected": &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP}, }, }, &fakeDNSClient{ - exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) { - switch transport.Tag() { - case "discarded", "installed", "default": - return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60), nil - default: - return nil, E.New("unexpected transport: ", transport.Tag()) - } + lookup: func(adapter.DNSTransport, string, adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) { + return []netip.Addr{netip.MustParseAddr("10.0.0.1")}, nil, nil }, }) - require.True(t, router.currentRules.Load().legacyDNSMode) - require.Equal(t, 1, fakeSet.refCount()) + require.True(t, router.legacyDNSMode) - callbacks := fakeSet.snapshotCallbacks() - require.Len(t, callbacks, 1) - - firstMetadataEntered := make(chan struct{}) - releaseFirstMetadata := make(chan struct{}) - callbackFinished := make(chan struct{}) - fakeSet.metadataRead = func(metadata adapter.RuleSetMetadata) adapter.RuleSetMetadata { - router.rawRules[0].DefaultOptions.RouteOptions.Server = "discarded" - close(firstMetadataEntered) - <-releaseFirstMetadata - return adapter.RuleSetMetadata{} - } - - go func() { - callbacks[0](fakeSet) - close(callbackFinished) - }() - - select { - case <-firstMetadataEntered: - case <-time.After(time.Second): - t.Fatal("rebuild did not reach rule-set metadata") - } - - err := router.Close() - require.NoError(t, err) - close(releaseFirstMetadata) - - select { - case <-callbackFinished: - case <-time.After(time.Second): - t.Fatal("rebuild callback did not finish after close") - } - - fakeSet.metadataRead = nil - - require.Nil(t, router.currentRules.Load()) - require.Zero(t, fakeSet.refCount()) + err := router.ValidateRuleSetMetadataUpdate("dynamic-set", adapter.RuleSetMetadata{ + ContainsIPCIDRRule: true, + ContainsDNSQueryTypeRule: true, + }) + require.ErrorContains(t, err, "Address Filter Fields") } -func TestCloseIgnoresSnapshottedRuleSetCallback(t *testing.T) { +func TestValidateRuleSetMetadataUpdateBeforeStartUsesStartupValidation(t *testing.T) { t.Parallel() fakeSet := &fakeRuleSet{} - ctx := service.ContextWith[adapter.Router](context.Background(), &fakeRouter{ + routerService := &fakeRouter{ ruleSets: map[string]adapter.RuleSet{ "dynamic-set": fakeSet, }, - }) - defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} - router := newTestRouterWithContext(t, ctx, []option.DNSRule{ + } + ctx := service.ContextWith[adapter.Router](context.Background(), routerService) + router := &Router{ + ctx: ctx, + logger: log.NewNOPFactory().NewLogger("dns"), + transport: &fakeDNSTransportManager{}, + client: &fakeDNSClient{}, + rawRules: make([]option.DNSRule, 0, 2), + rules: make([]adapter.DNSRule, 0, 2), + defaultDomainStrategy: C.DomainStrategyAsIS, + } + err := router.Initialize([]option.DNSRule{ { Type: C.RuleTypeDefault, DefaultOptions: option.DefaultDNSRule{ @@ -932,7 +649,7 @@ func TestCloseIgnoresSnapshottedRuleSetCallback(t *testing.T) { }, DNSRuleAction: option.DNSRuleAction{ Action: C.RuleActionTypeRoute, - RouteOptions: option.DNSRouteActionOptions{Server: "default"}, + RouteOptions: option.DNSRouteActionOptions{Server: "selected"}, }, }, }, @@ -944,143 +661,61 @@ func TestCloseIgnoresSnapshottedRuleSetCallback(t *testing.T) { }, DNSRuleAction: option.DNSRuleAction{ Action: C.RuleActionTypeRoute, - RouteOptions: option.DNSRouteActionOptions{Server: "default"}, + RouteOptions: option.DNSRouteActionOptions{Server: "selected"}, }, }, }, - }, &fakeDNSTransportManager{ - defaultTransport: defaultTransport, - transports: map[string]adapter.DNSTransport{ - "default": defaultTransport, - }, - }, &fakeDNSClient{ - lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) { - response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60) - return MessageToAddresses(response), response, nil - }, }) - - callbacks := fakeSet.snapshotCallbacks() - require.Len(t, callbacks, 1) - - require.NoError(t, router.Close()) - require.Empty(t, fakeSet.snapshotCallbacks()) - - fakeSet.metadata = adapter.RuleSetMetadata{ - ContainsDNSQueryTypeRule: true, - } - callbacks[0](fakeSet) -} - -func TestRuleSetUpdateDoesNotBlockOnInFlightLookup(t *testing.T) { - t.Parallel() - - fakeSet := &fakeRuleSet{ - metadata: adapter.RuleSetMetadata{ - ContainsIPCIDRRule: true, - }, - } - ctx := service.ContextWith[adapter.Router](context.Background(), &fakeRouter{ - ruleSets: map[string]adapter.RuleSet{ - "dynamic-set": fakeSet, - }, - }) - defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} - selectedTransport := &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP} - lookupStarted := make(chan struct{}) - releaseLookup := make(chan struct{}) - router := newTestRouterWithContext(t, ctx, []option.DNSRule{{ - Type: C.RuleTypeDefault, - DefaultOptions: option.DefaultDNSRule{ - RawDefaultDNSRule: option.RawDefaultDNSRule{ - RuleSet: badoption.Listable[string]{"dynamic-set"}, - }, - DNSRuleAction: option.DNSRuleAction{ - Action: C.RuleActionTypeRoute, - RouteOptions: option.DNSRouteActionOptions{Server: "selected"}, - }, - }, - }}, &fakeDNSTransportManager{ - defaultTransport: defaultTransport, - transports: map[string]adapter.DNSTransport{ - "default": defaultTransport, - "selected": selectedTransport, - }, - }, &fakeDNSClient{ - lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) { - require.Equal(t, "selected", transport.Tag()) - require.Equal(t, "example.com", domain) - require.Equal(t, C.DomainStrategyIPv4Only, options.LookupStrategy) - close(lookupStarted) - <-releaseLookup - response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60) - return MessageToAddresses(response), response, nil - }, - }) - t.Cleanup(func() { - closeErr := router.Close() - require.NoError(t, closeErr) - }) - - require.True(t, router.currentRules.Load().legacyDNSMode) - require.Equal(t, 1, fakeSet.refCount()) - - var ( - addresses []netip.Addr - err error - ) - lookupDone := make(chan struct{}) - go func() { - addresses, err = router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{ - LookupStrategy: C.DomainStrategyIPv4Only, - }) - close(lookupDone) - }() - - select { - case <-lookupStarted: - case <-time.After(time.Second): - t.Fatal("lookup did not reach DNS client") - } - - rebuildDone := make(chan struct{}) - go func() { - fakeSet.updateMetadata(adapter.RuleSetMetadata{ - ContainsIPCIDRRule: true, - }) - close(rebuildDone) - }() - - select { - case <-rebuildDone: - case <-time.After(time.Second): - t.Fatal("rebuild blocked on in-flight lookup") - } - - require.Equal(t, 2, fakeSet.refCount()) - - select { - case <-lookupDone: - t.Fatal("lookup finished before release") - default: - } - - close(releaseLookup) - - select { - case <-lookupDone: - case <-time.After(time.Second): - t.Fatal("lookup did not finish after release") - } - require.NoError(t, err) - require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses) - require.Eventually(t, func() bool { - return fakeSet.refCount() == 1 - }, time.Second, 10*time.Millisecond) + require.False(t, router.started) + + err = router.ValidateRuleSetMetadataUpdate("dynamic-set", adapter.RuleSetMetadata{ + ContainsDNSQueryTypeRule: true, + }) + require.ErrorContains(t, err, "Address Filter Fields") } -func TestCloseReleasesSnapshottedRulesAfterInFlightLookup(t *testing.T) { +func TestValidateRuleSetMetadataUpdateRejectsRuleSetThatWouldRequireLegacyDNSMode(t *testing.T) { + t.Parallel() + + fakeSet := &fakeRuleSet{} + routerService := &fakeRouter{ + ruleSets: map[string]adapter.RuleSet{ + "dynamic-set": fakeSet, + }, + } + ctx := service.ContextWith[adapter.Router](context.Background(), routerService) + router := newTestRouterWithContext(t, ctx, []option.DNSRule{{ + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + RuleSet: badoption.Listable[string]{"dynamic-set"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "selected"}, + }, + }, + }}, &fakeDNSTransportManager{ + defaultTransport: &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}, + transports: map[string]adapter.DNSTransport{ + "default": &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}, + "selected": &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP}, + }, + }, &fakeDNSClient{ + lookup: func(adapter.DNSTransport, string, adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) { + return []netip.Addr{netip.MustParseAddr("1.1.1.1")}, nil, nil + }, + }) + require.False(t, router.legacyDNSMode) + + err := router.ValidateRuleSetMetadataUpdate("dynamic-set", adapter.RuleSetMetadata{ + ContainsIPCIDRRule: true, + }) + require.ErrorContains(t, err, "Address Filter Fields") +} + +func TestValidateRuleSetMetadataUpdateAllowsRelaxingLegacyRequirement(t *testing.T) { t.Parallel() fakeSet := &fakeRuleSet{ @@ -1088,15 +723,12 @@ func TestCloseReleasesSnapshottedRulesAfterInFlightLookup(t *testing.T) { ContainsIPCIDRRule: true, }, } - ctx := service.ContextWith[adapter.Router](context.Background(), &fakeRouter{ + routerService := &fakeRouter{ ruleSets: map[string]adapter.RuleSet{ "dynamic-set": fakeSet, }, - }) - defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} - selectedTransport := &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP} - lookupStarted := make(chan struct{}) - releaseLookup := make(chan struct{}) + } + ctx := service.ContextWith[adapter.Router](context.Background(), routerService) router := newTestRouterWithContext(t, ctx, []option.DNSRule{{ Type: C.RuleTypeDefault, DefaultOptions: option.DefaultDNSRule{ @@ -1108,6 +740,41 @@ func TestCloseReleasesSnapshottedRulesAfterInFlightLookup(t *testing.T) { RouteOptions: option.DNSRouteActionOptions{Server: "selected"}, }, }, + }}, &fakeDNSTransportManager{ + defaultTransport: &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}, + transports: map[string]adapter.DNSTransport{ + "default": &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}, + "selected": &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP}, + }, + }, &fakeDNSClient{ + lookup: func(adapter.DNSTransport, string, adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) { + return []netip.Addr{netip.MustParseAddr("10.0.0.1")}, nil, nil + }, + }) + require.True(t, router.legacyDNSMode) + + err := router.ValidateRuleSetMetadataUpdate("dynamic-set", adapter.RuleSetMetadata{}) + require.NoError(t, err) +} + +func TestCloseWaitsForInFlightLookupUntilContextCancellation(t *testing.T) { + t.Parallel() + + defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} + selectedTransport := &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP} + lookupStarted := make(chan struct{}) + var lookupStartedOnce sync.Once + router := newTestRouter(t, []option.DNSRule{{ + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + Domain: badoption.Listable[string]{"example.com"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "selected"}, + }, + }, }}, &fakeDNSTransportManager{ defaultTransport: defaultTransport, transports: map[string]adapter.DNSTransport{ @@ -1115,30 +782,26 @@ func TestCloseReleasesSnapshottedRulesAfterInFlightLookup(t *testing.T) { "selected": selectedTransport, }, }, &fakeDNSClient{ - lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) { + lookupWithCtx: func(ctx context.Context, transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) { require.Equal(t, "selected", transport.Tag()) require.Equal(t, "example.com", domain) - require.Equal(t, C.DomainStrategyIPv4Only, options.LookupStrategy) - close(lookupStarted) - <-releaseLookup - response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60) - return MessageToAddresses(response), response, nil + lookupStartedOnce.Do(func() { + close(lookupStarted) + }) + <-ctx.Done() + return nil, nil, ctx.Err() }, }) - require.True(t, router.currentRules.Load().legacyDNSMode) - require.Equal(t, 1, fakeSet.refCount()) - + lookupCtx, cancelLookup := context.WithCancel(context.Background()) + defer cancelLookup() var ( - addresses []netip.Addr lookupErr error closeErr error ) lookupDone := make(chan struct{}) go func() { - addresses, lookupErr = router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{ - LookupStrategy: C.DomainStrategyIPv4Only, - }) + _, lookupErr = router.Lookup(lookupCtx, "example.com", adapter.DNSQueryOptions{}) close(lookupDone) }() @@ -1154,29 +817,27 @@ func TestCloseReleasesSnapshottedRulesAfterInFlightLookup(t *testing.T) { close(closeDone) }() - require.Eventually(t, func() bool { - return router.currentRules.Load() == nil && fakeSet.refCount() == 1 - }, time.Second, 10*time.Millisecond) + select { + case <-closeDone: + t.Fatal("close finished before lookup context cancellation") + default: + } - close(releaseLookup) + cancelLookup() select { case <-lookupDone: case <-time.After(time.Second): - t.Fatal("lookup did not finish after release") + t.Fatal("lookup did not finish after cancellation") } select { case <-closeDone: case <-time.After(time.Second): - t.Fatal("close did not finish") + t.Fatal("close did not finish after lookup cancellation") } - require.NoError(t, lookupErr) + require.ErrorIs(t, lookupErr, context.Canceled) require.NoError(t, closeErr) - require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses) - require.Eventually(t, func() bool { - return fakeSet.refCount() == 0 - }, time.Second, 10*time.Millisecond) } func TestLookupLegacyDNSModeDefersDirectDestinationIPMatch(t *testing.T) { @@ -1217,7 +878,7 @@ func TestLookupLegacyDNSModeDefersDirectDestinationIPMatch(t *testing.T) { }, }, client) - require.True(t, router.currentRules.Load().legacyDNSMode) + require.True(t, router.legacyDNSMode) addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{ LookupStrategy: C.DomainStrategyIPv4Only, @@ -1369,7 +1030,7 @@ func TestLookupLegacyDNSModeRuleSetAcceptEmptyDoesNotTreatMismatchAsEmpty(t *tes }, }) - require.True(t, router.currentRules.Load().legacyDNSMode) + require.True(t, router.legacyDNSMode) addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{ LookupStrategy: C.DomainStrategyIPv4Only, @@ -1998,7 +1659,7 @@ func TestExchangeLegacyDNSModeDisabledRespondReturnsEvaluatedResponse(t *testing return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("1.1.1.1")}, 60), nil }, }) - require.False(t, router.currentRules.Load().legacyDNSMode) + require.False(t, router.legacyDNSMode) response, err := router.Exchange(context.Background(), &mDNS.Msg{ Question: []mDNS.Question{fixedQuestion("example.com", mDNS.TypeA)}, @@ -2055,7 +1716,7 @@ func TestLookupLegacyDNSModeDisabledRespondReturnsEvaluatedResponse(t *testing.T } }, }) - require.False(t, router.currentRules.Load().legacyDNSMode) + require.False(t, router.legacyDNSMode) addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) require.NoError(t, err) @@ -2105,7 +1766,7 @@ func TestExchangeLegacyDNSModeDisabledRespondWithoutEvaluatedResponseReturnsErro return nil, E.New("upstream exchange failed") }, }) - require.False(t, router.currentRules.Load().legacyDNSMode) + require.False(t, router.legacyDNSMode) response, err := router.Exchange(context.Background(), &mDNS.Msg{ Question: []mDNS.Question{fixedQuestion("example.com", mDNS.TypeA)}, @@ -2136,7 +1797,7 @@ func TestLookupLegacyDNSModeDisabledAllowsPartialSuccess(t *testing.T) { } }, }) - router.currentRules.Load().legacyDNSMode = false + router.legacyDNSMode = false addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) require.NoError(t, err) @@ -2173,7 +1834,7 @@ func TestLookupLegacyDNSModeDisabledSkipsFakeIPRule(t *testing.T) { return FixedResponse(0, message.Question[0], nil, 60), nil }, }) - router.currentRules.Load().legacyDNSMode = false + router.legacyDNSMode = false addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) require.NoError(t, err) @@ -2226,7 +1887,6 @@ func TestInitializeRejectsDNSRuleStrategyWhenLegacyDNSModeIsDisabledByEvaluate(t rawRules: make([]option.DNSRule, 0, 1), defaultDomainStrategy: C.DomainStrategyAsIS, } - router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, 1), false)) err := router.Initialize([]option.DNSRule{{ Type: C.RuleTypeDefault, DefaultOptions: option.DefaultDNSRule{ @@ -2257,7 +1917,6 @@ func TestInitializeRejectsEvaluateFakeIPServerInDefaultRule(t *testing.T) { rawRules: make([]option.DNSRule, 0, 1), defaultDomainStrategy: C.DomainStrategyAsIS, } - router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, 1), false)) err := router.Initialize([]option.DNSRule{{ Type: C.RuleTypeDefault, DefaultOptions: option.DefaultDNSRule{ @@ -2285,7 +1944,6 @@ func TestInitializeRejectsEvaluateFakeIPServerInLogicalRule(t *testing.T) { rawRules: make([]option.DNSRule, 0, 1), defaultDomainStrategy: C.DomainStrategyAsIS, } - router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, 1), false)) err := router.Initialize([]option.DNSRule{{ Type: C.RuleTypeLogical, LogicalOptions: option.LogicalDNSRule{ @@ -2321,7 +1979,6 @@ func TestInitializeRejectsDNSRuleStrategyWhenLegacyDNSModeIsDisabledByMatchRespo rawRules: make([]option.DNSRule, 0, 1), defaultDomainStrategy: C.DomainStrategyAsIS, } - router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, 1), false)) err := router.Initialize([]option.DNSRule{{ Type: C.RuleTypeDefault, DefaultOptions: option.DefaultDNSRule{ @@ -2351,7 +2008,6 @@ func TestInitializeRejectsDNSMatchResponseWithoutPrecedingEvaluate(t *testing.T) rawRules: make([]option.DNSRule, 0, 1), defaultDomainStrategy: C.DomainStrategyAsIS, } - router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, 1), false)) err := router.Initialize([]option.DNSRule{{ Type: C.RuleTypeDefault, DefaultOptions: option.DefaultDNSRule{ @@ -2379,7 +2035,6 @@ func TestInitializeRejectsDNSRespondWithoutPrecedingEvaluate(t *testing.T) { rawRules: make([]option.DNSRule, 0, 1), defaultDomainStrategy: C.DomainStrategyAsIS, } - router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, 1), false)) err := router.Initialize([]option.DNSRule{{ Type: C.RuleTypeDefault, DefaultOptions: option.DefaultDNSRule{ @@ -2405,7 +2060,6 @@ func TestInitializeRejectsLogicalDNSRespondWithoutPrecedingEvaluate(t *testing.T rawRules: make([]option.DNSRule, 0, 1), defaultDomainStrategy: C.DomainStrategyAsIS, } - router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, 1), false)) err := router.Initialize([]option.DNSRule{{ Type: C.RuleTypeLogical, LogicalOptions: option.LogicalDNSRule{ @@ -2439,7 +2093,6 @@ func TestInitializeRejectsEvaluateRuleWithResponseMatchWithoutPrecedingEvaluate( rawRules: make([]option.DNSRule, 0, 1), defaultDomainStrategy: C.DomainStrategyAsIS, } - router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, 1), false)) err := router.Initialize([]option.DNSRule{{ Type: C.RuleTypeLogical, LogicalOptions: option.LogicalDNSRule{ @@ -2485,7 +2138,6 @@ func TestInitializeAllowsEvaluateRuleWithResponseMatchAfterPrecedingEvaluate(t * rawRules: make([]option.DNSRule, 0, 2), defaultDomainStrategy: C.DomainStrategyAsIS, } - router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, 2), false)) err := router.Initialize([]option.DNSRule{ { Type: C.RuleTypeDefault, @@ -2559,7 +2211,7 @@ func TestLookupLegacyDNSModeDisabledReturnsRejectedErrorForRejectAction(t *testi "default": defaultTransport, }, }, &fakeDNSClient{}) - require.False(t, router.currentRules.Load().legacyDNSMode) + require.False(t, router.legacyDNSMode) addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) require.Nil(t, addresses) @@ -2592,7 +2244,7 @@ func TestExchangeLegacyDNSModeDisabledReturnsRefusedResponseForRejectAction(t *t "default": defaultTransport, }, }, &fakeDNSClient{}) - require.False(t, router.currentRules.Load().legacyDNSMode) + require.False(t, router.legacyDNSMode) response, err := router.Exchange(context.Background(), &mDNS.Msg{ Question: []mDNS.Question{fixedQuestion("example.com", mDNS.TypeA)}, @@ -2627,7 +2279,7 @@ func TestExchangeLegacyDNSModeDisabledReturnsDropErrorForRejectDropAction(t *tes "default": defaultTransport, }, }, &fakeDNSClient{}) - require.False(t, router.currentRules.Load().legacyDNSMode) + require.False(t, router.legacyDNSMode) response, err := router.Exchange(context.Background(), &mDNS.Msg{ Question: []mDNS.Question{fixedQuestion("example.com", mDNS.TypeA)}, @@ -2664,7 +2316,7 @@ func TestLookupLegacyDNSModeDisabledFiltersPerQueryTypeAddressesBeforeMerging(t "default": defaultTransport, }, }, &fakeDNSClient{}) - require.False(t, router.currentRules.Load().legacyDNSMode) + require.False(t, router.legacyDNSMode) addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) require.NoError(t, err) @@ -2754,7 +2406,6 @@ func TestLegacyDNSModeReportsLegacyAddressFilterDeprecation(t *testing.T) { client: &fakeDNSClient{}, defaultDomainStrategy: C.DomainStrategyAsIS, } - router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, 1), false)) err := router.Initialize([]option.DNSRule{{ Type: C.RuleTypeDefault, DefaultOptions: option.DefaultDNSRule{ @@ -2786,7 +2437,6 @@ func TestLegacyDNSModeReportsDNSRuleStrategyDeprecation(t *testing.T) { client: &fakeDNSClient{}, defaultDomainStrategy: C.DomainStrategyAsIS, } - router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, 1), false)) err := router.Initialize([]option.DNSRule{{ Type: C.RuleTypeDefault, DefaultOptions: option.DefaultDNSRule{ diff --git a/route/rule/rule_set.go b/route/rule/rule_set.go index 9bffa8fcb..d286a7941 100644 --- a/route/rule/rule_set.go +++ b/route/rule/rule_set.go @@ -9,6 +9,7 @@ import ( "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" + "github.com/sagernet/sing/service" "go4.org/netipx" ) @@ -73,3 +74,20 @@ func isIPCIDRHeadlessRule(rule option.DefaultHeadlessRule) bool { func isDNSQueryTypeHeadlessRule(rule option.DefaultHeadlessRule) bool { return len(rule.QueryType) > 0 } + +func buildRuleSetMetadata(headlessRules []option.HeadlessRule) adapter.RuleSetMetadata { + return adapter.RuleSetMetadata{ + ContainsProcessRule: HasHeadlessRule(headlessRules, isProcessHeadlessRule), + ContainsWIFIRule: HasHeadlessRule(headlessRules, isWIFIHeadlessRule), + ContainsIPCIDRRule: HasHeadlessRule(headlessRules, isIPCIDRHeadlessRule), + ContainsDNSQueryTypeRule: HasHeadlessRule(headlessRules, isDNSQueryTypeHeadlessRule), + } +} + +func validateRuleSetMetadataUpdate(ctx context.Context, tag string, metadata adapter.RuleSetMetadata) error { + validator := service.FromContext[adapter.DNSRuleSetUpdateValidator](ctx) + if validator == nil { + return nil + } + return validator.ValidateRuleSetMetadataUpdate(tag, metadata) +} diff --git a/route/rule/rule_set_local.go b/route/rule/rule_set_local.go index 51e8f2723..5408615fc 100644 --- a/route/rule/rule_set_local.go +++ b/route/rule/rule_set_local.go @@ -137,11 +137,11 @@ func (s *LocalRuleSet) reloadRules(headlessRules []option.HeadlessRule) error { return E.Cause(err, "parse rule_set.rules.[", i, "]") } } - var metadata adapter.RuleSetMetadata - metadata.ContainsProcessRule = HasHeadlessRule(headlessRules, isProcessHeadlessRule) - metadata.ContainsWIFIRule = HasHeadlessRule(headlessRules, isWIFIHeadlessRule) - metadata.ContainsIPCIDRRule = HasHeadlessRule(headlessRules, isIPCIDRHeadlessRule) - metadata.ContainsDNSQueryTypeRule = HasHeadlessRule(headlessRules, isDNSQueryTypeHeadlessRule) + metadata := buildRuleSetMetadata(headlessRules) + err = validateRuleSetMetadataUpdate(s.ctx, s.tag, metadata) + if err != nil { + return err + } s.access.Lock() s.rules = rules s.metadata = metadata diff --git a/route/rule/rule_set_remote.go b/route/rule/rule_set_remote.go index 4d2691450..53d353b3c 100644 --- a/route/rule/rule_set_remote.go +++ b/route/rule/rule_set_remote.go @@ -189,11 +189,13 @@ func (s *RemoteRuleSet) loadBytes(content []byte) error { return E.Cause(err, "parse rule_set.rules.[", i, "]") } } + metadata := buildRuleSetMetadata(plainRuleSet.Rules) + err = validateRuleSetMetadataUpdate(s.ctx, s.options.Tag, metadata) + if err != nil { + return err + } s.access.Lock() - s.metadata.ContainsProcessRule = HasHeadlessRule(plainRuleSet.Rules, isProcessHeadlessRule) - s.metadata.ContainsWIFIRule = HasHeadlessRule(plainRuleSet.Rules, isWIFIHeadlessRule) - s.metadata.ContainsIPCIDRRule = HasHeadlessRule(plainRuleSet.Rules, isIPCIDRHeadlessRule) - s.metadata.ContainsDNSQueryTypeRule = HasHeadlessRule(plainRuleSet.Rules, isDNSQueryTypeHeadlessRule) + s.metadata = metadata s.rules = rules callbacks := s.callbacks.Array() s.access.Unlock() diff --git a/route/rule/rule_set_update_validation_test.go b/route/rule/rule_set_update_validation_test.go new file mode 100644 index 000000000..2f29e551a --- /dev/null +++ b/route/rule/rule_set_update_validation_test.go @@ -0,0 +1,110 @@ +package rule + +import ( + "context" + "sync/atomic" + "testing" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/json/badoption" + "github.com/sagernet/sing/common/x/list" + "github.com/sagernet/sing/service" + "github.com/stretchr/testify/require" +) + +type fakeDNSRuleSetUpdateValidator struct { + validate func(tag string, metadata adapter.RuleSetMetadata) error +} + +func (v *fakeDNSRuleSetUpdateValidator) ValidateRuleSetMetadataUpdate(tag string, metadata adapter.RuleSetMetadata) error { + if v.validate == nil { + return nil + } + return v.validate(tag, metadata) +} + +func TestLocalRuleSetReloadRulesRejectsInvalidUpdateBeforeCommit(t *testing.T) { + t.Parallel() + + var callbackCount atomic.Int32 + ctx := service.ContextWith[adapter.DNSRuleSetUpdateValidator](context.Background(), &fakeDNSRuleSetUpdateValidator{ + validate: func(tag string, metadata adapter.RuleSetMetadata) error { + require.Equal(t, "dynamic-set", tag) + if metadata.ContainsDNSQueryTypeRule { + return E.New("dns conflict") + } + return nil + }, + }) + ruleSet := &LocalRuleSet{ + ctx: ctx, + tag: "dynamic-set", + fileFormat: C.RuleSetFormatSource, + } + _ = ruleSet.callbacks.PushBack(func(adapter.RuleSet) { + callbackCount.Add(1) + }) + + err := ruleSet.reloadRules([]option.HeadlessRule{{ + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultHeadlessRule{ + Domain: badoption.Listable[string]{"example.com"}, + }, + }}) + require.NoError(t, err) + require.Equal(t, int32(1), callbackCount.Load()) + require.False(t, ruleSet.metadata.ContainsDNSQueryTypeRule) + require.True(t, ruleSet.Match(&adapter.InboundContext{Domain: "example.com"})) + + err = ruleSet.reloadRules([]option.HeadlessRule{{ + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultHeadlessRule{ + QueryType: badoption.Listable[option.DNSQueryType]{option.DNSQueryType(1)}, + }, + }}) + require.ErrorContains(t, err, "dns conflict") + require.Equal(t, int32(1), callbackCount.Load()) + require.False(t, ruleSet.metadata.ContainsDNSQueryTypeRule) + require.True(t, ruleSet.Match(&adapter.InboundContext{Domain: "example.com"})) +} + +func TestRemoteRuleSetLoadBytesRejectsInvalidUpdateBeforeCommit(t *testing.T) { + t.Parallel() + + var callbackCount atomic.Int32 + ctx := service.ContextWith[adapter.DNSRuleSetUpdateValidator](context.Background(), &fakeDNSRuleSetUpdateValidator{ + validate: func(tag string, metadata adapter.RuleSetMetadata) error { + require.Equal(t, "dynamic-set", tag) + if metadata.ContainsDNSQueryTypeRule { + return E.New("dns conflict") + } + return nil + }, + }) + ruleSet := &RemoteRuleSet{ + ctx: ctx, + options: option.RuleSet{ + Tag: "dynamic-set", + Format: C.RuleSetFormatSource, + }, + callbacks: list.List[adapter.RuleSetUpdateCallback]{}, + } + _ = ruleSet.callbacks.PushBack(func(adapter.RuleSet) { + callbackCount.Add(1) + }) + + err := ruleSet.loadBytes([]byte(`{"version":4,"rules":[{"domain":["example.com"]}]}`)) + require.NoError(t, err) + require.Equal(t, int32(1), callbackCount.Load()) + require.False(t, ruleSet.metadata.ContainsDNSQueryTypeRule) + require.True(t, ruleSet.Match(&adapter.InboundContext{Domain: "example.com"})) + + err = ruleSet.loadBytes([]byte(`{"version":4,"rules":[{"query_type":["A"]}]}`)) + require.ErrorContains(t, err, "dns conflict") + require.Equal(t, int32(1), callbackCount.Load()) + require.False(t, ruleSet.metadata.ContainsDNSQueryTypeRule) + require.True(t, ruleSet.Match(&adapter.InboundContext{Domain: "example.com"})) +}