diff --git a/dns/router.go b/dns/router.go index fe8783766..870498a84 100644 --- a/dns/router.go +++ b/dns/router.go @@ -50,6 +50,7 @@ type Router struct { platformInterface adapter.PlatformInterface legacyAddressFilterMode bool rulesAccess sync.RWMutex + closing bool ruleSetCallbacks []dnsRuleSetCallback runtimeRuleError error deprecatedReported bool @@ -126,15 +127,16 @@ func (r *Router) Start(stage adapter.StartStage) error { func (r *Router) Close() error { monitor := taskmonitor.New(r.logger, C.StopTimeout) r.rulesAccess.Lock() + r.closing = true callbacks := r.ruleSetCallbacks r.ruleSetCallbacks = nil runtimeRules := r.rules r.rules = nil r.runtimeRuleError = nil - r.rulesAccess.Unlock() for _, callback := range callbacks { callback.ruleSet.UnregisterCallback(callback.element) } + r.rulesAccess.Unlock() var err error for i, rule := range runtimeRules { monitor.Start("close dns rule[", i, "]") @@ -147,8 +149,14 @@ func (r *Router) Close() error { } func (r *Router) rebuildRules(startRules bool) error { + if r.isClosing() { + return nil + } newRules, legacyAddressFilterMode, err := r.buildRules(startRules) if err != nil { + if r.isClosing() { + return nil + } return err } shouldReportDeprecated := startRules && @@ -156,6 +164,11 @@ func (r *Router) rebuildRules(startRules bool) error { !r.deprecatedReported && common.Any(newRules, func(rule adapter.DNSRule) bool { return rule.WithAddressLimit() }) r.rulesAccess.Lock() + if r.closing { + r.rulesAccess.Unlock() + closeRules(newRules) + return nil + } oldRules := r.rules r.rules = newRules r.legacyAddressFilterMode = legacyAddressFilterMode @@ -171,6 +184,12 @@ func (r *Router) rebuildRules(startRules bool) error { return nil } +func (r *Router) isClosing() bool { + r.rulesAccess.RLock() + defer r.rulesAccess.RUnlock() + return r.closing +} + func (r *Router) buildRules(startRules bool) ([]adapter.DNSRule, bool, error) { router := service.FromContext[adapter.Router](r.ctx) legacyAddressFilterMode, err := resolveLegacyAddressFilterMode(router, r.rawRules) diff --git a/dns/router_test.go b/dns/router_test.go index 1f66f1757..ccc377ea8 100644 --- a/dns/router_test.go +++ b/dns/router_test.go @@ -5,6 +5,7 @@ import ( "errors" "net" "net/netip" + "sync" "testing" "time" @@ -114,8 +115,9 @@ func (r *fakeRouter) AppendTracker(adapter.ConnectionTracker) {} func (r *fakeRouter) ResetNetwork() {} type fakeRuleSet struct { + access sync.Mutex metadata adapter.RuleSetMetadata - callbacks []adapter.RuleSetUpdateCallback + callbacks list.List[adapter.RuleSetUpdateCallback] } func (s *fakeRuleSet) Name() string { return "fake-rule-set" } @@ -127,20 +129,34 @@ func (s *fakeRuleSet) IncRef() func (s *fakeRuleSet) DecRef() {} func (s *fakeRuleSet) Cleanup() {} func (s *fakeRuleSet) RegisterCallback(callback adapter.RuleSetUpdateCallback) *list.Element[adapter.RuleSetUpdateCallback] { - s.callbacks = append(s.callbacks, callback) - return nil + s.access.Lock() + defer s.access.Unlock() + return s.callbacks.PushBack(callback) } -func (s *fakeRuleSet) UnregisterCallback(*list.Element[adapter.RuleSetUpdateCallback]) {} -func (s *fakeRuleSet) Close() error { return nil } -func (s *fakeRuleSet) Match(*adapter.InboundContext) bool { return true } -func (s *fakeRuleSet) String() string { return "fake-rule-set" } +func (s *fakeRuleSet) UnregisterCallback(element *list.Element[adapter.RuleSetUpdateCallback]) { + s.access.Lock() + defer s.access.Unlock() + s.callbacks.Remove(element) +} +func (s *fakeRuleSet) Close() error { return nil } +func (s *fakeRuleSet) Match(*adapter.InboundContext) bool { return true } +func (s *fakeRuleSet) String() string { return "fake-rule-set" } func (s *fakeRuleSet) updateMetadata(metadata adapter.RuleSetMetadata) { + s.access.Lock() s.metadata = metadata - for _, callback := range s.callbacks { + callbacks := s.callbacks.Array() + s.access.Unlock() + for _, callback := range callbacks { callback(s) } } +func (s *fakeRuleSet) snapshotCallbacks() []adapter.RuleSetUpdateCallback { + s.access.Lock() + defer s.access.Unlock() + return s.callbacks.Array() +} + func (m *fakeDeprecatedManager) ReportDeprecated(feature deprecated.Note) { m.features = append(m.features, feature) } @@ -483,6 +499,72 @@ func TestRuleSetUpdateSetsRuntimeErrorWhenRebuildFails(t *testing.T) { require.ErrorContains(t, err, "ip_cidr and ip_is_private require match_response") } +func TestCloseIgnoresSnapshottedRuleSetCallback(t *testing.T) { + t.Parallel() + + fakeSet := &fakeRuleSet{} + ctx := service.ContextWith[adapter.Router](context.Background(), &fakeRouter{ + ruleSets: map[string]adapter.RuleSet{ + "dynamic-set": fakeSet, + }, + }) + defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} + router := newTestRouterWithContext(t, ctx, []option.DNSRule{ + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + RuleSet: badoption.Listable[string]{"dynamic-set"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "default"}, + }, + }, + }, + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + IPIsPrivate: true, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "default"}, + }, + }, + }, + }, &fakeDNSTransportManager{ + defaultTransport: defaultTransport, + transports: map[string]adapter.DNSTransport{ + "default": defaultTransport, + }, + }, &fakeDNSClient{ + lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) { + response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60) + return MessageToAddresses(response), response, nil + }, + }) + + callbacks := fakeSet.snapshotCallbacks() + require.Len(t, callbacks, 1) + + require.NoError(t, router.Close()) + require.Empty(t, fakeSet.snapshotCallbacks()) + + fakeSet.metadata = adapter.RuleSetMetadata{ + ContainsDNSQueryTypeRule: true, + } + callbacks[0](fakeSet) + + router.rulesAccess.RLock() + defer router.rulesAccess.RUnlock() + require.True(t, router.closing) + require.Nil(t, router.rules) + require.Empty(t, router.ruleSetCallbacks) + require.NoError(t, router.runtimeRuleError) +} + func TestLookupLegacyModeDefersDirectDestinationIPMatch(t *testing.T) { t.Parallel() diff --git a/option/dns_record.go b/option/dns_record.go index b2d73fa00..2d4fb7888 100644 --- a/option/dns_record.go +++ b/option/dns_record.go @@ -2,6 +2,7 @@ package option import ( "encoding/base64" + "strings" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" @@ -11,6 +12,8 @@ import ( "github.com/miekg/dns" ) +const defaultDNSRecordTTL uint32 = 3600 + type DNSRCode int func (r DNSRCode) MarshalJSON() ([]byte, error) { @@ -76,7 +79,7 @@ func (o *DNSRecordOptions) UnmarshalJSON(data []byte) error { if err == nil { return o.unmarshalBase64(binary) } - record, err := dns.NewRR(stringValue) + record, err := parseDNSRecord(stringValue) if err != nil { return err } @@ -90,6 +93,17 @@ func (o *DNSRecordOptions) UnmarshalJSON(data []byte) error { return nil } +func parseDNSRecord(stringValue string) (dns.RR, error) { + if len(stringValue) > 0 && stringValue[len(stringValue)-1] != '\n' { + stringValue += "\n" + } + parser := dns.NewZoneParser(strings.NewReader(stringValue), "", "") + parser.SetDefaultTTL(defaultDNSRecordTTL) + parser.SetIncludeAllowed(true) + record, _ := parser.Next() + return record, parser.Err() +} + func (o *DNSRecordOptions) unmarshalBase64(binary []byte) error { record, _, err := dns.UnpackRR(binary, 0) if err != nil { diff --git a/option/dns_record_test.go b/option/dns_record_test.go index f30f6a682..cb26f9b01 100644 --- a/option/dns_record_test.go +++ b/option/dns_record_test.go @@ -14,19 +14,33 @@ func mustRecordOptions(t *testing.T, record string) DNSRecordOptions { return value } -func TestDNSRecordOptionsUnmarshalJSONAcceptsRelativeOwnerNames(t *testing.T) { +func TestDNSRecordOptionsUnmarshalJSONAcceptsFullyQualifiedNames(t *testing.T) { t.Parallel() for _, record := range []string{ - "example.com A 1.1.1.1", - "@ IN A 1.1.1.1", - "www IN CNAME @", + "example.com. A 1.1.1.1", + "www.example.com. IN CNAME example.com.", } { value := mustRecordOptions(t, record) require.NotNil(t, value.RR) } } +func TestDNSRecordOptionsUnmarshalJSONRejectsRelativeNames(t *testing.T) { + t.Parallel() + + for _, record := range []string{ + "@ IN A 1.1.1.1", + "www IN CNAME example.com.", + "example.com. IN CNAME @", + "example.com. IN CNAME www", + } { + var value DNSRecordOptions + err := value.UnmarshalJSON([]byte(`"` + record + `"`)) + require.Error(t, err) + } +} + func TestDNSRecordOptionsMatchIgnoresTTL(t *testing.T) { t.Parallel()