diff --git a/dns/router.go b/dns/router.go index a3a874a81..ba29281f8 100644 --- a/dns/router.go +++ b/dns/router.go @@ -1097,7 +1097,6 @@ func referencedDNSRuleSetTags(rules []option.DNSRule) []string { } func validateNonLegacyAddressFilterRules(rules []option.DNSRule) error { - var seenEvaluate bool for i, rule := range rules { consumesResponse, err := validateNonLegacyAddressFilterRuleTree(rule) if err != nil { @@ -1107,12 +1106,6 @@ func validateNonLegacyAddressFilterRules(rules []option.DNSRule) error { if action == C.RuleActionTypeEvaluate && consumesResponse { return E.New("dns rule[", i, "]: evaluate rule cannot consume response state") } - if consumesResponse && !seenEvaluate { - return E.New("dns rule[", i, "]: response matching requires a preceding top-level evaluate rule") - } - if action == C.RuleActionTypeEvaluate { - seenEvaluate = true - } } return nil } diff --git a/dns/router_test.go b/dns/router_test.go index d28718c23..fb6dad0e4 100644 --- a/dns/router_test.go +++ b/dns/router_test.go @@ -118,16 +118,32 @@ type fakeRuleSet struct { access sync.Mutex metadata adapter.RuleSetMetadata callbacks list.List[adapter.RuleSetUpdateCallback] + refs int } func (s *fakeRuleSet) Name() string { return "fake-rule-set" } func (s *fakeRuleSet) StartContext(context.Context, *adapter.HTTPStartContext) error { return nil } func (s *fakeRuleSet) PostStart() error { return nil } -func (s *fakeRuleSet) Metadata() adapter.RuleSetMetadata { return s.metadata } -func (s *fakeRuleSet) ExtractIPSet() []*netipx.IPSet { return nil } -func (s *fakeRuleSet) IncRef() {} -func (s *fakeRuleSet) DecRef() {} -func (s *fakeRuleSet) Cleanup() {} +func (s *fakeRuleSet) Metadata() adapter.RuleSetMetadata { + s.access.Lock() + defer s.access.Unlock() + return s.metadata +} +func (s *fakeRuleSet) ExtractIPSet() []*netipx.IPSet { return nil } +func (s *fakeRuleSet) IncRef() { + s.access.Lock() + defer s.access.Unlock() + s.refs++ +} +func (s *fakeRuleSet) DecRef() { + s.access.Lock() + defer s.access.Unlock() + s.refs-- + if s.refs < 0 { + panic("rule-set: negative refs") + } +} +func (s *fakeRuleSet) Cleanup() {} func (s *fakeRuleSet) RegisterCallback(callback adapter.RuleSetUpdateCallback) *list.Element[adapter.RuleSetUpdateCallback] { s.access.Lock() defer s.access.Unlock() @@ -157,6 +173,12 @@ func (s *fakeRuleSet) snapshotCallbacks() []adapter.RuleSetUpdateCallback { return s.callbacks.Array() } +func (s *fakeRuleSet) refCount() int { + s.access.Lock() + defer s.access.Unlock() + return s.refs +} + func (m *fakeDeprecatedManager) ReportDeprecated(feature deprecated.Note) { m.features = append(m.features, feature) } @@ -294,6 +316,21 @@ func TestValidateNewDNSRules_RequireMatchResponseForDirectIPCIDR(t *testing.T) { require.ErrorContains(t, err, "ip_cidr and ip_is_private require match_response") } +func TestValidateNewDNSRules_AllowMatchResponseWithoutEvaluate(t *testing.T) { + t.Parallel() + + err := validateNonLegacyAddressFilterRules([]option.DNSRule{{ + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + MatchResponse: true, + IPCIDR: badoption.Listable[string]{"1.1.1.0/24"}, + }, + }, + }}) + require.NoError(t, err) +} + func TestInitializeRejectsInvalidDNSRuleParseError(t *testing.T) { t.Parallel() @@ -442,6 +479,46 @@ func TestLookupLegacyModeDefersRuleSetDestinationIPMatch(t *testing.T) { require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses) } +func TestRuleSetUpdateReleasesOldRuleSetRefs(t *testing.T) { + t.Parallel() + + fakeSet := &fakeRuleSet{} + ctx := service.ContextWith[adapter.Router](context.Background(), &fakeRouter{ + ruleSets: map[string]adapter.RuleSet{ + "dynamic-set": fakeSet, + }, + }) + defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} + router := newTestRouterWithContext(t, ctx, []option.DNSRule{{ + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + RuleSet: badoption.Listable[string]{"dynamic-set"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "default"}, + }, + }, + }}, &fakeDNSTransportManager{ + defaultTransport: defaultTransport, + transports: map[string]adapter.DNSTransport{ + "default": defaultTransport, + }, + }, &fakeDNSClient{}) + + require.Equal(t, 1, fakeSet.refCount()) + + fakeSet.updateMetadata(adapter.RuleSetMetadata{}) + require.Equal(t, 1, fakeSet.refCount()) + + fakeSet.updateMetadata(adapter.RuleSetMetadata{}) + require.Equal(t, 1, fakeSet.refCount()) + + require.NoError(t, router.Close()) + require.Zero(t, fakeSet.refCount()) +} + func TestRuleSetUpdateSetsRuntimeErrorWhenRebuildFails(t *testing.T) { t.Parallel() diff --git a/route/rule/rule_item_rule_set.go b/route/rule/rule_item_rule_set.go index 3467843ba..013649435 100644 --- a/route/rule/rule_item_rule_set.go +++ b/route/rule/rule_item_rule_set.go @@ -29,9 +29,11 @@ func NewRuleSetItem(router adapter.Router, tagList []string, ipCIDRMatchSource b } func (r *RuleSetItem) Start() error { + _ = r.Close() for _, tag := range r.tagList { ruleSet, loaded := r.router.RuleSet(tag) if !loaded { + _ = r.Close() return E.New("rule-set not found: ", tag) } ruleSet.IncRef() @@ -40,6 +42,15 @@ func (r *RuleSetItem) Start() error { return nil } +func (r *RuleSetItem) Close() error { + for _, ruleSet := range r.setList { + ruleSet.DecRef() + } + clear(r.setList) + r.setList = nil + return nil +} + func (r *RuleSetItem) Match(metadata *adapter.InboundContext) bool { return !r.matchStates(metadata).isEmpty() } diff --git a/route/rule/rule_item_rule_set_test.go b/route/rule/rule_item_rule_set_test.go new file mode 100644 index 000000000..3d5959a39 --- /dev/null +++ b/route/rule/rule_item_rule_set_test.go @@ -0,0 +1,134 @@ +package rule + +import ( + "context" + "net" + "sync/atomic" + "testing" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-tun" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/x/list" + + "github.com/stretchr/testify/require" + "go4.org/netipx" +) + +type ruleSetItemTestRouter struct { + ruleSets map[string]adapter.RuleSet +} + +func (r *ruleSetItemTestRouter) Start(adapter.StartStage) error { return nil } +func (r *ruleSetItemTestRouter) Close() error { return nil } +func (r *ruleSetItemTestRouter) PreMatch(adapter.InboundContext, tun.DirectRouteContext, time.Duration, bool) (tun.DirectRouteDestination, error) { + return nil, nil +} +func (r *ruleSetItemTestRouter) RouteConnection(context.Context, net.Conn, adapter.InboundContext) error { + return nil +} +func (r *ruleSetItemTestRouter) RoutePacketConnection(context.Context, N.PacketConn, adapter.InboundContext) error { + return nil +} +func (r *ruleSetItemTestRouter) RouteConnectionEx(context.Context, net.Conn, adapter.InboundContext, N.CloseHandlerFunc) { +} +func (r *ruleSetItemTestRouter) RoutePacketConnectionEx(context.Context, N.PacketConn, adapter.InboundContext, N.CloseHandlerFunc) { +} +func (r *ruleSetItemTestRouter) RuleSet(tag string) (adapter.RuleSet, bool) { + ruleSet, loaded := r.ruleSets[tag] + return ruleSet, loaded +} +func (r *ruleSetItemTestRouter) Rules() []adapter.Rule { return nil } +func (r *ruleSetItemTestRouter) NeedFindProcess() bool { return false } +func (r *ruleSetItemTestRouter) NeedFindNeighbor() bool { return false } +func (r *ruleSetItemTestRouter) NeighborResolver() adapter.NeighborResolver { return nil } +func (r *ruleSetItemTestRouter) AppendTracker(adapter.ConnectionTracker) {} +func (r *ruleSetItemTestRouter) ResetNetwork() {} + +type countingRuleSet struct { + name string + refs atomic.Int32 +} + +func (s *countingRuleSet) Name() string { return s.name } +func (s *countingRuleSet) StartContext(context.Context, *adapter.HTTPStartContext) error { return nil } +func (s *countingRuleSet) PostStart() error { return nil } +func (s *countingRuleSet) Metadata() adapter.RuleSetMetadata { return adapter.RuleSetMetadata{} } +func (s *countingRuleSet) ExtractIPSet() []*netipx.IPSet { return nil } +func (s *countingRuleSet) IncRef() { s.refs.Add(1) } +func (s *countingRuleSet) DecRef() { + if s.refs.Add(-1) < 0 { + panic("rule-set: negative refs") + } +} +func (s *countingRuleSet) Cleanup() {} +func (s *countingRuleSet) RegisterCallback(adapter.RuleSetUpdateCallback) *list.Element[adapter.RuleSetUpdateCallback] { + return nil +} +func (s *countingRuleSet) UnregisterCallback(*list.Element[adapter.RuleSetUpdateCallback]) {} +func (s *countingRuleSet) Close() error { return nil } +func (s *countingRuleSet) Match(*adapter.InboundContext) bool { return true } +func (s *countingRuleSet) String() string { return s.name } +func (s *countingRuleSet) RefCount() int32 { return s.refs.Load() } + +func TestRuleSetItemCloseReleasesRefs(t *testing.T) { + t.Parallel() + + firstSet := &countingRuleSet{name: "first"} + secondSet := &countingRuleSet{name: "second"} + item := NewRuleSetItem(&ruleSetItemTestRouter{ + ruleSets: map[string]adapter.RuleSet{ + "first": firstSet, + "second": secondSet, + }, + }, []string{"first", "second"}, false, false) + + require.NoError(t, item.Start()) + require.EqualValues(t, 1, firstSet.RefCount()) + require.EqualValues(t, 1, secondSet.RefCount()) + + require.NoError(t, item.Close()) + require.Zero(t, firstSet.RefCount()) + require.Zero(t, secondSet.RefCount()) + + require.NoError(t, item.Close()) + require.Zero(t, firstSet.RefCount()) + require.Zero(t, secondSet.RefCount()) +} + +func TestRuleSetItemStartRollbackOnFailure(t *testing.T) { + t.Parallel() + + firstSet := &countingRuleSet{name: "first"} + item := NewRuleSetItem(&ruleSetItemTestRouter{ + ruleSets: map[string]adapter.RuleSet{ + "first": firstSet, + }, + }, []string{"first", "missing"}, false, false) + + err := item.Start() + require.ErrorContains(t, err, "rule-set not found: missing") + require.Zero(t, firstSet.RefCount()) + require.Empty(t, item.setList) +} + +func TestRuleSetItemRestartKeepsBalancedRefs(t *testing.T) { + t.Parallel() + + firstSet := &countingRuleSet{name: "first"} + item := NewRuleSetItem(&ruleSetItemTestRouter{ + ruleSets: map[string]adapter.RuleSet{ + "first": firstSet, + }, + }, []string{"first"}, false, false) + + require.NoError(t, item.Start()) + require.EqualValues(t, 1, firstSet.RefCount()) + + require.NoError(t, item.Start()) + require.EqualValues(t, 1, firstSet.RefCount()) + + require.NoError(t, item.Close()) + require.Zero(t, firstSet.RefCount()) +}