dns: validate rule-set updates before commit

This commit is contained in:
世界
2026-04-02 00:24:16 +08:00
parent ca43d71152
commit bdfb344955
8 changed files with 470 additions and 762 deletions

View File

@@ -9,6 +9,7 @@ import (
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/service"
"go4.org/netipx"
)
@@ -73,3 +74,20 @@ func isIPCIDRHeadlessRule(rule option.DefaultHeadlessRule) bool {
func isDNSQueryTypeHeadlessRule(rule option.DefaultHeadlessRule) bool {
return len(rule.QueryType) > 0
}
func buildRuleSetMetadata(headlessRules []option.HeadlessRule) adapter.RuleSetMetadata {
return adapter.RuleSetMetadata{
ContainsProcessRule: HasHeadlessRule(headlessRules, isProcessHeadlessRule),
ContainsWIFIRule: HasHeadlessRule(headlessRules, isWIFIHeadlessRule),
ContainsIPCIDRRule: HasHeadlessRule(headlessRules, isIPCIDRHeadlessRule),
ContainsDNSQueryTypeRule: HasHeadlessRule(headlessRules, isDNSQueryTypeHeadlessRule),
}
}
func validateRuleSetMetadataUpdate(ctx context.Context, tag string, metadata adapter.RuleSetMetadata) error {
validator := service.FromContext[adapter.DNSRuleSetUpdateValidator](ctx)
if validator == nil {
return nil
}
return validator.ValidateRuleSetMetadataUpdate(tag, metadata)
}

View File

@@ -137,11 +137,11 @@ func (s *LocalRuleSet) reloadRules(headlessRules []option.HeadlessRule) error {
return E.Cause(err, "parse rule_set.rules.[", i, "]")
}
}
var metadata adapter.RuleSetMetadata
metadata.ContainsProcessRule = HasHeadlessRule(headlessRules, isProcessHeadlessRule)
metadata.ContainsWIFIRule = HasHeadlessRule(headlessRules, isWIFIHeadlessRule)
metadata.ContainsIPCIDRRule = HasHeadlessRule(headlessRules, isIPCIDRHeadlessRule)
metadata.ContainsDNSQueryTypeRule = HasHeadlessRule(headlessRules, isDNSQueryTypeHeadlessRule)
metadata := buildRuleSetMetadata(headlessRules)
err = validateRuleSetMetadataUpdate(s.ctx, s.tag, metadata)
if err != nil {
return err
}
s.access.Lock()
s.rules = rules
s.metadata = metadata

View File

@@ -189,11 +189,13 @@ func (s *RemoteRuleSet) loadBytes(content []byte) error {
return E.Cause(err, "parse rule_set.rules.[", i, "]")
}
}
metadata := buildRuleSetMetadata(plainRuleSet.Rules)
err = validateRuleSetMetadataUpdate(s.ctx, s.options.Tag, metadata)
if err != nil {
return err
}
s.access.Lock()
s.metadata.ContainsProcessRule = HasHeadlessRule(plainRuleSet.Rules, isProcessHeadlessRule)
s.metadata.ContainsWIFIRule = HasHeadlessRule(plainRuleSet.Rules, isWIFIHeadlessRule)
s.metadata.ContainsIPCIDRRule = HasHeadlessRule(plainRuleSet.Rules, isIPCIDRHeadlessRule)
s.metadata.ContainsDNSQueryTypeRule = HasHeadlessRule(plainRuleSet.Rules, isDNSQueryTypeHeadlessRule)
s.metadata = metadata
s.rules = rules
callbacks := s.callbacks.Array()
s.access.Unlock()

View File

@@ -0,0 +1,110 @@
package rule
import (
"context"
"sync/atomic"
"testing"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json/badoption"
"github.com/sagernet/sing/common/x/list"
"github.com/sagernet/sing/service"
"github.com/stretchr/testify/require"
)
type fakeDNSRuleSetUpdateValidator struct {
validate func(tag string, metadata adapter.RuleSetMetadata) error
}
func (v *fakeDNSRuleSetUpdateValidator) ValidateRuleSetMetadataUpdate(tag string, metadata adapter.RuleSetMetadata) error {
if v.validate == nil {
return nil
}
return v.validate(tag, metadata)
}
func TestLocalRuleSetReloadRulesRejectsInvalidUpdateBeforeCommit(t *testing.T) {
t.Parallel()
var callbackCount atomic.Int32
ctx := service.ContextWith[adapter.DNSRuleSetUpdateValidator](context.Background(), &fakeDNSRuleSetUpdateValidator{
validate: func(tag string, metadata adapter.RuleSetMetadata) error {
require.Equal(t, "dynamic-set", tag)
if metadata.ContainsDNSQueryTypeRule {
return E.New("dns conflict")
}
return nil
},
})
ruleSet := &LocalRuleSet{
ctx: ctx,
tag: "dynamic-set",
fileFormat: C.RuleSetFormatSource,
}
_ = ruleSet.callbacks.PushBack(func(adapter.RuleSet) {
callbackCount.Add(1)
})
err := ruleSet.reloadRules([]option.HeadlessRule{{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultHeadlessRule{
Domain: badoption.Listable[string]{"example.com"},
},
}})
require.NoError(t, err)
require.Equal(t, int32(1), callbackCount.Load())
require.False(t, ruleSet.metadata.ContainsDNSQueryTypeRule)
require.True(t, ruleSet.Match(&adapter.InboundContext{Domain: "example.com"}))
err = ruleSet.reloadRules([]option.HeadlessRule{{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultHeadlessRule{
QueryType: badoption.Listable[option.DNSQueryType]{option.DNSQueryType(1)},
},
}})
require.ErrorContains(t, err, "dns conflict")
require.Equal(t, int32(1), callbackCount.Load())
require.False(t, ruleSet.metadata.ContainsDNSQueryTypeRule)
require.True(t, ruleSet.Match(&adapter.InboundContext{Domain: "example.com"}))
}
func TestRemoteRuleSetLoadBytesRejectsInvalidUpdateBeforeCommit(t *testing.T) {
t.Parallel()
var callbackCount atomic.Int32
ctx := service.ContextWith[adapter.DNSRuleSetUpdateValidator](context.Background(), &fakeDNSRuleSetUpdateValidator{
validate: func(tag string, metadata adapter.RuleSetMetadata) error {
require.Equal(t, "dynamic-set", tag)
if metadata.ContainsDNSQueryTypeRule {
return E.New("dns conflict")
}
return nil
},
})
ruleSet := &RemoteRuleSet{
ctx: ctx,
options: option.RuleSet{
Tag: "dynamic-set",
Format: C.RuleSetFormatSource,
},
callbacks: list.List[adapter.RuleSetUpdateCallback]{},
}
_ = ruleSet.callbacks.PushBack(func(adapter.RuleSet) {
callbackCount.Add(1)
})
err := ruleSet.loadBytes([]byte(`{"version":4,"rules":[{"domain":["example.com"]}]}`))
require.NoError(t, err)
require.Equal(t, int32(1), callbackCount.Load())
require.False(t, ruleSet.metadata.ContainsDNSQueryTypeRule)
require.True(t, ruleSet.Match(&adapter.InboundContext{Domain: "example.com"}))
err = ruleSet.loadBytes([]byte(`{"version":4,"rules":[{"query_type":["A"]}]}`))
require.ErrorContains(t, err, "dns conflict")
require.Equal(t, int32(1), callbackCount.Load())
require.False(t, ruleSet.metadata.ContainsDNSQueryTypeRule)
require.True(t, ruleSet.Match(&adapter.InboundContext{Domain: "example.com"}))
}