mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-13 20:28:32 +10:00
Fix DNS rule-set ref handling
This commit is contained in:
@@ -1097,7 +1097,6 @@ func referencedDNSRuleSetTags(rules []option.DNSRule) []string {
|
||||
}
|
||||
|
||||
func validateNonLegacyAddressFilterRules(rules []option.DNSRule) error {
|
||||
var seenEvaluate bool
|
||||
for i, rule := range rules {
|
||||
consumesResponse, err := validateNonLegacyAddressFilterRuleTree(rule)
|
||||
if err != nil {
|
||||
@@ -1107,12 +1106,6 @@ func validateNonLegacyAddressFilterRules(rules []option.DNSRule) error {
|
||||
if action == C.RuleActionTypeEvaluate && consumesResponse {
|
||||
return E.New("dns rule[", i, "]: evaluate rule cannot consume response state")
|
||||
}
|
||||
if consumesResponse && !seenEvaluate {
|
||||
return E.New("dns rule[", i, "]: response matching requires a preceding top-level evaluate rule")
|
||||
}
|
||||
if action == C.RuleActionTypeEvaluate {
|
||||
seenEvaluate = true
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -118,16 +118,32 @@ type fakeRuleSet struct {
|
||||
access sync.Mutex
|
||||
metadata adapter.RuleSetMetadata
|
||||
callbacks list.List[adapter.RuleSetUpdateCallback]
|
||||
refs int
|
||||
}
|
||||
|
||||
func (s *fakeRuleSet) Name() string { return "fake-rule-set" }
|
||||
func (s *fakeRuleSet) StartContext(context.Context, *adapter.HTTPStartContext) error { return nil }
|
||||
func (s *fakeRuleSet) PostStart() error { return nil }
|
||||
func (s *fakeRuleSet) Metadata() adapter.RuleSetMetadata { return s.metadata }
|
||||
func (s *fakeRuleSet) ExtractIPSet() []*netipx.IPSet { return nil }
|
||||
func (s *fakeRuleSet) IncRef() {}
|
||||
func (s *fakeRuleSet) DecRef() {}
|
||||
func (s *fakeRuleSet) Cleanup() {}
|
||||
func (s *fakeRuleSet) Metadata() adapter.RuleSetMetadata {
|
||||
s.access.Lock()
|
||||
defer s.access.Unlock()
|
||||
return s.metadata
|
||||
}
|
||||
func (s *fakeRuleSet) ExtractIPSet() []*netipx.IPSet { return nil }
|
||||
func (s *fakeRuleSet) IncRef() {
|
||||
s.access.Lock()
|
||||
defer s.access.Unlock()
|
||||
s.refs++
|
||||
}
|
||||
func (s *fakeRuleSet) DecRef() {
|
||||
s.access.Lock()
|
||||
defer s.access.Unlock()
|
||||
s.refs--
|
||||
if s.refs < 0 {
|
||||
panic("rule-set: negative refs")
|
||||
}
|
||||
}
|
||||
func (s *fakeRuleSet) Cleanup() {}
|
||||
func (s *fakeRuleSet) RegisterCallback(callback adapter.RuleSetUpdateCallback) *list.Element[adapter.RuleSetUpdateCallback] {
|
||||
s.access.Lock()
|
||||
defer s.access.Unlock()
|
||||
@@ -157,6 +173,12 @@ func (s *fakeRuleSet) snapshotCallbacks() []adapter.RuleSetUpdateCallback {
|
||||
return s.callbacks.Array()
|
||||
}
|
||||
|
||||
func (s *fakeRuleSet) refCount() int {
|
||||
s.access.Lock()
|
||||
defer s.access.Unlock()
|
||||
return s.refs
|
||||
}
|
||||
|
||||
func (m *fakeDeprecatedManager) ReportDeprecated(feature deprecated.Note) {
|
||||
m.features = append(m.features, feature)
|
||||
}
|
||||
@@ -294,6 +316,21 @@ func TestValidateNewDNSRules_RequireMatchResponseForDirectIPCIDR(t *testing.T) {
|
||||
require.ErrorContains(t, err, "ip_cidr and ip_is_private require match_response")
|
||||
}
|
||||
|
||||
func TestValidateNewDNSRules_AllowMatchResponseWithoutEvaluate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := validateNonLegacyAddressFilterRules([]option.DNSRule{{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||
MatchResponse: true,
|
||||
IPCIDR: badoption.Listable[string]{"1.1.1.0/24"},
|
||||
},
|
||||
},
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestInitializeRejectsInvalidDNSRuleParseError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -442,6 +479,46 @@ func TestLookupLegacyModeDefersRuleSetDestinationIPMatch(t *testing.T) {
|
||||
require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses)
|
||||
}
|
||||
|
||||
func TestRuleSetUpdateReleasesOldRuleSetRefs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fakeSet := &fakeRuleSet{}
|
||||
ctx := service.ContextWith[adapter.Router](context.Background(), &fakeRouter{
|
||||
ruleSets: map[string]adapter.RuleSet{
|
||||
"dynamic-set": fakeSet,
|
||||
},
|
||||
})
|
||||
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
|
||||
router := newTestRouterWithContext(t, ctx, []option.DNSRule{{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||
RuleSet: badoption.Listable[string]{"dynamic-set"},
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeRoute,
|
||||
RouteOptions: option.DNSRouteActionOptions{Server: "default"},
|
||||
},
|
||||
},
|
||||
}}, &fakeDNSTransportManager{
|
||||
defaultTransport: defaultTransport,
|
||||
transports: map[string]adapter.DNSTransport{
|
||||
"default": defaultTransport,
|
||||
},
|
||||
}, &fakeDNSClient{})
|
||||
|
||||
require.Equal(t, 1, fakeSet.refCount())
|
||||
|
||||
fakeSet.updateMetadata(adapter.RuleSetMetadata{})
|
||||
require.Equal(t, 1, fakeSet.refCount())
|
||||
|
||||
fakeSet.updateMetadata(adapter.RuleSetMetadata{})
|
||||
require.Equal(t, 1, fakeSet.refCount())
|
||||
|
||||
require.NoError(t, router.Close())
|
||||
require.Zero(t, fakeSet.refCount())
|
||||
}
|
||||
|
||||
func TestRuleSetUpdateSetsRuntimeErrorWhenRebuildFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -29,9 +29,11 @@ func NewRuleSetItem(router adapter.Router, tagList []string, ipCIDRMatchSource b
|
||||
}
|
||||
|
||||
func (r *RuleSetItem) Start() error {
|
||||
_ = r.Close()
|
||||
for _, tag := range r.tagList {
|
||||
ruleSet, loaded := r.router.RuleSet(tag)
|
||||
if !loaded {
|
||||
_ = r.Close()
|
||||
return E.New("rule-set not found: ", tag)
|
||||
}
|
||||
ruleSet.IncRef()
|
||||
@@ -40,6 +42,15 @@ func (r *RuleSetItem) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RuleSetItem) Close() error {
|
||||
for _, ruleSet := range r.setList {
|
||||
ruleSet.DecRef()
|
||||
}
|
||||
clear(r.setList)
|
||||
r.setList = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RuleSetItem) Match(metadata *adapter.InboundContext) bool {
|
||||
return !r.matchStates(metadata).isEmpty()
|
||||
}
|
||||
|
||||
134
route/rule/rule_item_rule_set_test.go
Normal file
134
route/rule/rule_item_rule_set_test.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package rule
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-tun"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go4.org/netipx"
|
||||
)
|
||||
|
||||
type ruleSetItemTestRouter struct {
|
||||
ruleSets map[string]adapter.RuleSet
|
||||
}
|
||||
|
||||
func (r *ruleSetItemTestRouter) Start(adapter.StartStage) error { return nil }
|
||||
func (r *ruleSetItemTestRouter) Close() error { return nil }
|
||||
func (r *ruleSetItemTestRouter) PreMatch(adapter.InboundContext, tun.DirectRouteContext, time.Duration, bool) (tun.DirectRouteDestination, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *ruleSetItemTestRouter) RouteConnection(context.Context, net.Conn, adapter.InboundContext) error {
|
||||
return nil
|
||||
}
|
||||
func (r *ruleSetItemTestRouter) RoutePacketConnection(context.Context, N.PacketConn, adapter.InboundContext) error {
|
||||
return nil
|
||||
}
|
||||
func (r *ruleSetItemTestRouter) RouteConnectionEx(context.Context, net.Conn, adapter.InboundContext, N.CloseHandlerFunc) {
|
||||
}
|
||||
func (r *ruleSetItemTestRouter) RoutePacketConnectionEx(context.Context, N.PacketConn, adapter.InboundContext, N.CloseHandlerFunc) {
|
||||
}
|
||||
func (r *ruleSetItemTestRouter) RuleSet(tag string) (adapter.RuleSet, bool) {
|
||||
ruleSet, loaded := r.ruleSets[tag]
|
||||
return ruleSet, loaded
|
||||
}
|
||||
func (r *ruleSetItemTestRouter) Rules() []adapter.Rule { return nil }
|
||||
func (r *ruleSetItemTestRouter) NeedFindProcess() bool { return false }
|
||||
func (r *ruleSetItemTestRouter) NeedFindNeighbor() bool { return false }
|
||||
func (r *ruleSetItemTestRouter) NeighborResolver() adapter.NeighborResolver { return nil }
|
||||
func (r *ruleSetItemTestRouter) AppendTracker(adapter.ConnectionTracker) {}
|
||||
func (r *ruleSetItemTestRouter) ResetNetwork() {}
|
||||
|
||||
type countingRuleSet struct {
|
||||
name string
|
||||
refs atomic.Int32
|
||||
}
|
||||
|
||||
func (s *countingRuleSet) Name() string { return s.name }
|
||||
func (s *countingRuleSet) StartContext(context.Context, *adapter.HTTPStartContext) error { return nil }
|
||||
func (s *countingRuleSet) PostStart() error { return nil }
|
||||
func (s *countingRuleSet) Metadata() adapter.RuleSetMetadata { return adapter.RuleSetMetadata{} }
|
||||
func (s *countingRuleSet) ExtractIPSet() []*netipx.IPSet { return nil }
|
||||
func (s *countingRuleSet) IncRef() { s.refs.Add(1) }
|
||||
func (s *countingRuleSet) DecRef() {
|
||||
if s.refs.Add(-1) < 0 {
|
||||
panic("rule-set: negative refs")
|
||||
}
|
||||
}
|
||||
func (s *countingRuleSet) Cleanup() {}
|
||||
func (s *countingRuleSet) RegisterCallback(adapter.RuleSetUpdateCallback) *list.Element[adapter.RuleSetUpdateCallback] {
|
||||
return nil
|
||||
}
|
||||
func (s *countingRuleSet) UnregisterCallback(*list.Element[adapter.RuleSetUpdateCallback]) {}
|
||||
func (s *countingRuleSet) Close() error { return nil }
|
||||
func (s *countingRuleSet) Match(*adapter.InboundContext) bool { return true }
|
||||
func (s *countingRuleSet) String() string { return s.name }
|
||||
func (s *countingRuleSet) RefCount() int32 { return s.refs.Load() }
|
||||
|
||||
func TestRuleSetItemCloseReleasesRefs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
firstSet := &countingRuleSet{name: "first"}
|
||||
secondSet := &countingRuleSet{name: "second"}
|
||||
item := NewRuleSetItem(&ruleSetItemTestRouter{
|
||||
ruleSets: map[string]adapter.RuleSet{
|
||||
"first": firstSet,
|
||||
"second": secondSet,
|
||||
},
|
||||
}, []string{"first", "second"}, false, false)
|
||||
|
||||
require.NoError(t, item.Start())
|
||||
require.EqualValues(t, 1, firstSet.RefCount())
|
||||
require.EqualValues(t, 1, secondSet.RefCount())
|
||||
|
||||
require.NoError(t, item.Close())
|
||||
require.Zero(t, firstSet.RefCount())
|
||||
require.Zero(t, secondSet.RefCount())
|
||||
|
||||
require.NoError(t, item.Close())
|
||||
require.Zero(t, firstSet.RefCount())
|
||||
require.Zero(t, secondSet.RefCount())
|
||||
}
|
||||
|
||||
func TestRuleSetItemStartRollbackOnFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
firstSet := &countingRuleSet{name: "first"}
|
||||
item := NewRuleSetItem(&ruleSetItemTestRouter{
|
||||
ruleSets: map[string]adapter.RuleSet{
|
||||
"first": firstSet,
|
||||
},
|
||||
}, []string{"first", "missing"}, false, false)
|
||||
|
||||
err := item.Start()
|
||||
require.ErrorContains(t, err, "rule-set not found: missing")
|
||||
require.Zero(t, firstSet.RefCount())
|
||||
require.Empty(t, item.setList)
|
||||
}
|
||||
|
||||
func TestRuleSetItemRestartKeepsBalancedRefs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
firstSet := &countingRuleSet{name: "first"}
|
||||
item := NewRuleSetItem(&ruleSetItemTestRouter{
|
||||
ruleSets: map[string]adapter.RuleSet{
|
||||
"first": firstSet,
|
||||
},
|
||||
}, []string{"first"}, false, false)
|
||||
|
||||
require.NoError(t, item.Start())
|
||||
require.EqualValues(t, 1, firstSet.RefCount())
|
||||
|
||||
require.NoError(t, item.Start())
|
||||
require.EqualValues(t, 1, firstSet.RefCount())
|
||||
|
||||
require.NoError(t, item.Close())
|
||||
require.Zero(t, firstSet.RefCount())
|
||||
}
|
||||
Reference in New Issue
Block a user