Fix DNS record parsing and shutdown race

This commit is contained in:
世界
2026-03-26 13:01:13 +08:00
parent 5fd49b3752
commit 9a4e9c0379
4 changed files with 143 additions and 14 deletions

View File

@@ -50,6 +50,7 @@ type Router struct {
platformInterface adapter.PlatformInterface
legacyAddressFilterMode bool
rulesAccess sync.RWMutex
closing bool
ruleSetCallbacks []dnsRuleSetCallback
runtimeRuleError error
deprecatedReported bool
@@ -126,15 +127,16 @@ func (r *Router) Start(stage adapter.StartStage) error {
func (r *Router) Close() error {
monitor := taskmonitor.New(r.logger, C.StopTimeout)
r.rulesAccess.Lock()
r.closing = true
callbacks := r.ruleSetCallbacks
r.ruleSetCallbacks = nil
runtimeRules := r.rules
r.rules = nil
r.runtimeRuleError = nil
r.rulesAccess.Unlock()
for _, callback := range callbacks {
callback.ruleSet.UnregisterCallback(callback.element)
}
r.rulesAccess.Unlock()
var err error
for i, rule := range runtimeRules {
monitor.Start("close dns rule[", i, "]")
@@ -147,8 +149,14 @@ func (r *Router) Close() error {
}
func (r *Router) rebuildRules(startRules bool) error {
if r.isClosing() {
return nil
}
newRules, legacyAddressFilterMode, err := r.buildRules(startRules)
if err != nil {
if r.isClosing() {
return nil
}
return err
}
shouldReportDeprecated := startRules &&
@@ -156,6 +164,11 @@ func (r *Router) rebuildRules(startRules bool) error {
!r.deprecatedReported &&
common.Any(newRules, func(rule adapter.DNSRule) bool { return rule.WithAddressLimit() })
r.rulesAccess.Lock()
if r.closing {
r.rulesAccess.Unlock()
closeRules(newRules)
return nil
}
oldRules := r.rules
r.rules = newRules
r.legacyAddressFilterMode = legacyAddressFilterMode
@@ -171,6 +184,12 @@ func (r *Router) rebuildRules(startRules bool) error {
return nil
}
func (r *Router) isClosing() bool {
r.rulesAccess.RLock()
defer r.rulesAccess.RUnlock()
return r.closing
}
func (r *Router) buildRules(startRules bool) ([]adapter.DNSRule, bool, error) {
router := service.FromContext[adapter.Router](r.ctx)
legacyAddressFilterMode, err := resolveLegacyAddressFilterMode(router, r.rawRules)

View File

@@ -5,6 +5,7 @@ import (
"errors"
"net"
"net/netip"
"sync"
"testing"
"time"
@@ -114,8 +115,9 @@ func (r *fakeRouter) AppendTracker(adapter.ConnectionTracker) {}
func (r *fakeRouter) ResetNetwork() {}
type fakeRuleSet struct {
access sync.Mutex
metadata adapter.RuleSetMetadata
callbacks []adapter.RuleSetUpdateCallback
callbacks list.List[adapter.RuleSetUpdateCallback]
}
func (s *fakeRuleSet) Name() string { return "fake-rule-set" }
@@ -127,20 +129,34 @@ func (s *fakeRuleSet) IncRef()
func (s *fakeRuleSet) DecRef() {}
func (s *fakeRuleSet) Cleanup() {}
func (s *fakeRuleSet) RegisterCallback(callback adapter.RuleSetUpdateCallback) *list.Element[adapter.RuleSetUpdateCallback] {
s.callbacks = append(s.callbacks, callback)
return nil
s.access.Lock()
defer s.access.Unlock()
return s.callbacks.PushBack(callback)
}
func (s *fakeRuleSet) UnregisterCallback(*list.Element[adapter.RuleSetUpdateCallback]) {}
func (s *fakeRuleSet) Close() error { return nil }
func (s *fakeRuleSet) Match(*adapter.InboundContext) bool { return true }
func (s *fakeRuleSet) String() string { return "fake-rule-set" }
func (s *fakeRuleSet) UnregisterCallback(element *list.Element[adapter.RuleSetUpdateCallback]) {
s.access.Lock()
defer s.access.Unlock()
s.callbacks.Remove(element)
}
func (s *fakeRuleSet) Close() error { return nil }
func (s *fakeRuleSet) Match(*adapter.InboundContext) bool { return true }
func (s *fakeRuleSet) String() string { return "fake-rule-set" }
func (s *fakeRuleSet) updateMetadata(metadata adapter.RuleSetMetadata) {
s.access.Lock()
s.metadata = metadata
for _, callback := range s.callbacks {
callbacks := s.callbacks.Array()
s.access.Unlock()
for _, callback := range callbacks {
callback(s)
}
}
func (s *fakeRuleSet) snapshotCallbacks() []adapter.RuleSetUpdateCallback {
s.access.Lock()
defer s.access.Unlock()
return s.callbacks.Array()
}
func (m *fakeDeprecatedManager) ReportDeprecated(feature deprecated.Note) {
m.features = append(m.features, feature)
}
@@ -483,6 +499,72 @@ func TestRuleSetUpdateSetsRuntimeErrorWhenRebuildFails(t *testing.T) {
require.ErrorContains(t, err, "ip_cidr and ip_is_private require match_response")
}
func TestCloseIgnoresSnapshottedRuleSetCallback(t *testing.T) {
t.Parallel()
fakeSet := &fakeRuleSet{}
ctx := service.ContextWith[adapter.Router](context.Background(), &fakeRouter{
ruleSets: map[string]adapter.RuleSet{
"dynamic-set": fakeSet,
},
})
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
router := newTestRouterWithContext(t, ctx, []option.DNSRule{
{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
RawDefaultDNSRule: option.RawDefaultDNSRule{
RuleSet: badoption.Listable[string]{"dynamic-set"},
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeRoute,
RouteOptions: option.DNSRouteActionOptions{Server: "default"},
},
},
},
{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
RawDefaultDNSRule: option.RawDefaultDNSRule{
IPIsPrivate: true,
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeRoute,
RouteOptions: option.DNSRouteActionOptions{Server: "default"},
},
},
},
}, &fakeDNSTransportManager{
defaultTransport: defaultTransport,
transports: map[string]adapter.DNSTransport{
"default": defaultTransport,
},
}, &fakeDNSClient{
lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) {
response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60)
return MessageToAddresses(response), response, nil
},
})
callbacks := fakeSet.snapshotCallbacks()
require.Len(t, callbacks, 1)
require.NoError(t, router.Close())
require.Empty(t, fakeSet.snapshotCallbacks())
fakeSet.metadata = adapter.RuleSetMetadata{
ContainsDNSQueryTypeRule: true,
}
callbacks[0](fakeSet)
router.rulesAccess.RLock()
defer router.rulesAccess.RUnlock()
require.True(t, router.closing)
require.Nil(t, router.rules)
require.Empty(t, router.ruleSetCallbacks)
require.NoError(t, router.runtimeRuleError)
}
func TestLookupLegacyModeDefersDirectDestinationIPMatch(t *testing.T) {
t.Parallel()

View File

@@ -2,6 +2,7 @@ package option
import (
"encoding/base64"
"strings"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
@@ -11,6 +12,8 @@ import (
"github.com/miekg/dns"
)
const defaultDNSRecordTTL uint32 = 3600
type DNSRCode int
func (r DNSRCode) MarshalJSON() ([]byte, error) {
@@ -76,7 +79,7 @@ func (o *DNSRecordOptions) UnmarshalJSON(data []byte) error {
if err == nil {
return o.unmarshalBase64(binary)
}
record, err := dns.NewRR(stringValue)
record, err := parseDNSRecord(stringValue)
if err != nil {
return err
}
@@ -90,6 +93,17 @@ func (o *DNSRecordOptions) UnmarshalJSON(data []byte) error {
return nil
}
func parseDNSRecord(stringValue string) (dns.RR, error) {
if len(stringValue) > 0 && stringValue[len(stringValue)-1] != '\n' {
stringValue += "\n"
}
parser := dns.NewZoneParser(strings.NewReader(stringValue), "", "")
parser.SetDefaultTTL(defaultDNSRecordTTL)
parser.SetIncludeAllowed(true)
record, _ := parser.Next()
return record, parser.Err()
}
func (o *DNSRecordOptions) unmarshalBase64(binary []byte) error {
record, _, err := dns.UnpackRR(binary, 0)
if err != nil {

View File

@@ -14,19 +14,33 @@ func mustRecordOptions(t *testing.T, record string) DNSRecordOptions {
return value
}
func TestDNSRecordOptionsUnmarshalJSONAcceptsRelativeOwnerNames(t *testing.T) {
func TestDNSRecordOptionsUnmarshalJSONAcceptsFullyQualifiedNames(t *testing.T) {
t.Parallel()
for _, record := range []string{
"example.com A 1.1.1.1",
"@ IN A 1.1.1.1",
"www IN CNAME @",
"example.com. A 1.1.1.1",
"www.example.com. IN CNAME example.com.",
} {
value := mustRecordOptions(t, record)
require.NotNil(t, value.RR)
}
}
func TestDNSRecordOptionsUnmarshalJSONRejectsRelativeNames(t *testing.T) {
t.Parallel()
for _, record := range []string{
"@ IN A 1.1.1.1",
"www IN CNAME example.com.",
"example.com. IN CNAME @",
"example.com. IN CNAME www",
} {
var value DNSRecordOptions
err := value.UnmarshalJSON([]byte(`"` + record + `"`))
require.Error(t, err)
}
}
func TestDNSRecordOptionsMatchIgnoresTTL(t *testing.T) {
t.Parallel()