mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-13 20:28:32 +10:00
Fix DNS record parsing and shutdown race
This commit is contained in:
@@ -50,6 +50,7 @@ type Router struct {
|
||||
platformInterface adapter.PlatformInterface
|
||||
legacyAddressFilterMode bool
|
||||
rulesAccess sync.RWMutex
|
||||
closing bool
|
||||
ruleSetCallbacks []dnsRuleSetCallback
|
||||
runtimeRuleError error
|
||||
deprecatedReported bool
|
||||
@@ -126,15 +127,16 @@ func (r *Router) Start(stage adapter.StartStage) error {
|
||||
func (r *Router) Close() error {
|
||||
monitor := taskmonitor.New(r.logger, C.StopTimeout)
|
||||
r.rulesAccess.Lock()
|
||||
r.closing = true
|
||||
callbacks := r.ruleSetCallbacks
|
||||
r.ruleSetCallbacks = nil
|
||||
runtimeRules := r.rules
|
||||
r.rules = nil
|
||||
r.runtimeRuleError = nil
|
||||
r.rulesAccess.Unlock()
|
||||
for _, callback := range callbacks {
|
||||
callback.ruleSet.UnregisterCallback(callback.element)
|
||||
}
|
||||
r.rulesAccess.Unlock()
|
||||
var err error
|
||||
for i, rule := range runtimeRules {
|
||||
monitor.Start("close dns rule[", i, "]")
|
||||
@@ -147,8 +149,14 @@ func (r *Router) Close() error {
|
||||
}
|
||||
|
||||
func (r *Router) rebuildRules(startRules bool) error {
|
||||
if r.isClosing() {
|
||||
return nil
|
||||
}
|
||||
newRules, legacyAddressFilterMode, err := r.buildRules(startRules)
|
||||
if err != nil {
|
||||
if r.isClosing() {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
shouldReportDeprecated := startRules &&
|
||||
@@ -156,6 +164,11 @@ func (r *Router) rebuildRules(startRules bool) error {
|
||||
!r.deprecatedReported &&
|
||||
common.Any(newRules, func(rule adapter.DNSRule) bool { return rule.WithAddressLimit() })
|
||||
r.rulesAccess.Lock()
|
||||
if r.closing {
|
||||
r.rulesAccess.Unlock()
|
||||
closeRules(newRules)
|
||||
return nil
|
||||
}
|
||||
oldRules := r.rules
|
||||
r.rules = newRules
|
||||
r.legacyAddressFilterMode = legacyAddressFilterMode
|
||||
@@ -171,6 +184,12 @@ func (r *Router) rebuildRules(startRules bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Router) isClosing() bool {
|
||||
r.rulesAccess.RLock()
|
||||
defer r.rulesAccess.RUnlock()
|
||||
return r.closing
|
||||
}
|
||||
|
||||
func (r *Router) buildRules(startRules bool) ([]adapter.DNSRule, bool, error) {
|
||||
router := service.FromContext[adapter.Router](r.ctx)
|
||||
legacyAddressFilterMode, err := resolveLegacyAddressFilterMode(router, r.rawRules)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -114,8 +115,9 @@ func (r *fakeRouter) AppendTracker(adapter.ConnectionTracker) {}
|
||||
func (r *fakeRouter) ResetNetwork() {}
|
||||
|
||||
type fakeRuleSet struct {
|
||||
access sync.Mutex
|
||||
metadata adapter.RuleSetMetadata
|
||||
callbacks []adapter.RuleSetUpdateCallback
|
||||
callbacks list.List[adapter.RuleSetUpdateCallback]
|
||||
}
|
||||
|
||||
func (s *fakeRuleSet) Name() string { return "fake-rule-set" }
|
||||
@@ -127,20 +129,34 @@ func (s *fakeRuleSet) IncRef()
|
||||
func (s *fakeRuleSet) DecRef() {}
|
||||
func (s *fakeRuleSet) Cleanup() {}
|
||||
func (s *fakeRuleSet) RegisterCallback(callback adapter.RuleSetUpdateCallback) *list.Element[adapter.RuleSetUpdateCallback] {
|
||||
s.callbacks = append(s.callbacks, callback)
|
||||
return nil
|
||||
s.access.Lock()
|
||||
defer s.access.Unlock()
|
||||
return s.callbacks.PushBack(callback)
|
||||
}
|
||||
func (s *fakeRuleSet) UnregisterCallback(*list.Element[adapter.RuleSetUpdateCallback]) {}
|
||||
func (s *fakeRuleSet) Close() error { return nil }
|
||||
func (s *fakeRuleSet) Match(*adapter.InboundContext) bool { return true }
|
||||
func (s *fakeRuleSet) String() string { return "fake-rule-set" }
|
||||
func (s *fakeRuleSet) UnregisterCallback(element *list.Element[adapter.RuleSetUpdateCallback]) {
|
||||
s.access.Lock()
|
||||
defer s.access.Unlock()
|
||||
s.callbacks.Remove(element)
|
||||
}
|
||||
func (s *fakeRuleSet) Close() error { return nil }
|
||||
func (s *fakeRuleSet) Match(*adapter.InboundContext) bool { return true }
|
||||
func (s *fakeRuleSet) String() string { return "fake-rule-set" }
|
||||
func (s *fakeRuleSet) updateMetadata(metadata adapter.RuleSetMetadata) {
|
||||
s.access.Lock()
|
||||
s.metadata = metadata
|
||||
for _, callback := range s.callbacks {
|
||||
callbacks := s.callbacks.Array()
|
||||
s.access.Unlock()
|
||||
for _, callback := range callbacks {
|
||||
callback(s)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *fakeRuleSet) snapshotCallbacks() []adapter.RuleSetUpdateCallback {
|
||||
s.access.Lock()
|
||||
defer s.access.Unlock()
|
||||
return s.callbacks.Array()
|
||||
}
|
||||
|
||||
func (m *fakeDeprecatedManager) ReportDeprecated(feature deprecated.Note) {
|
||||
m.features = append(m.features, feature)
|
||||
}
|
||||
@@ -483,6 +499,72 @@ func TestRuleSetUpdateSetsRuntimeErrorWhenRebuildFails(t *testing.T) {
|
||||
require.ErrorContains(t, err, "ip_cidr and ip_is_private require match_response")
|
||||
}
|
||||
|
||||
func TestCloseIgnoresSnapshottedRuleSetCallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fakeSet := &fakeRuleSet{}
|
||||
ctx := service.ContextWith[adapter.Router](context.Background(), &fakeRouter{
|
||||
ruleSets: map[string]adapter.RuleSet{
|
||||
"dynamic-set": fakeSet,
|
||||
},
|
||||
})
|
||||
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
|
||||
router := newTestRouterWithContext(t, ctx, []option.DNSRule{
|
||||
{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||
RuleSet: badoption.Listable[string]{"dynamic-set"},
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeRoute,
|
||||
RouteOptions: option.DNSRouteActionOptions{Server: "default"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||
IPIsPrivate: true,
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeRoute,
|
||||
RouteOptions: option.DNSRouteActionOptions{Server: "default"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, &fakeDNSTransportManager{
|
||||
defaultTransport: defaultTransport,
|
||||
transports: map[string]adapter.DNSTransport{
|
||||
"default": defaultTransport,
|
||||
},
|
||||
}, &fakeDNSClient{
|
||||
lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) {
|
||||
response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60)
|
||||
return MessageToAddresses(response), response, nil
|
||||
},
|
||||
})
|
||||
|
||||
callbacks := fakeSet.snapshotCallbacks()
|
||||
require.Len(t, callbacks, 1)
|
||||
|
||||
require.NoError(t, router.Close())
|
||||
require.Empty(t, fakeSet.snapshotCallbacks())
|
||||
|
||||
fakeSet.metadata = adapter.RuleSetMetadata{
|
||||
ContainsDNSQueryTypeRule: true,
|
||||
}
|
||||
callbacks[0](fakeSet)
|
||||
|
||||
router.rulesAccess.RLock()
|
||||
defer router.rulesAccess.RUnlock()
|
||||
require.True(t, router.closing)
|
||||
require.Nil(t, router.rules)
|
||||
require.Empty(t, router.ruleSetCallbacks)
|
||||
require.NoError(t, router.runtimeRuleError)
|
||||
}
|
||||
|
||||
func TestLookupLegacyModeDefersDirectDestinationIPMatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package option
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
@@ -11,6 +12,8 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
const defaultDNSRecordTTL uint32 = 3600
|
||||
|
||||
type DNSRCode int
|
||||
|
||||
func (r DNSRCode) MarshalJSON() ([]byte, error) {
|
||||
@@ -76,7 +79,7 @@ func (o *DNSRecordOptions) UnmarshalJSON(data []byte) error {
|
||||
if err == nil {
|
||||
return o.unmarshalBase64(binary)
|
||||
}
|
||||
record, err := dns.NewRR(stringValue)
|
||||
record, err := parseDNSRecord(stringValue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -90,6 +93,17 @@ func (o *DNSRecordOptions) UnmarshalJSON(data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseDNSRecord(stringValue string) (dns.RR, error) {
|
||||
if len(stringValue) > 0 && stringValue[len(stringValue)-1] != '\n' {
|
||||
stringValue += "\n"
|
||||
}
|
||||
parser := dns.NewZoneParser(strings.NewReader(stringValue), "", "")
|
||||
parser.SetDefaultTTL(defaultDNSRecordTTL)
|
||||
parser.SetIncludeAllowed(true)
|
||||
record, _ := parser.Next()
|
||||
return record, parser.Err()
|
||||
}
|
||||
|
||||
func (o *DNSRecordOptions) unmarshalBase64(binary []byte) error {
|
||||
record, _, err := dns.UnpackRR(binary, 0)
|
||||
if err != nil {
|
||||
|
||||
@@ -14,19 +14,33 @@ func mustRecordOptions(t *testing.T, record string) DNSRecordOptions {
|
||||
return value
|
||||
}
|
||||
|
||||
func TestDNSRecordOptionsUnmarshalJSONAcceptsRelativeOwnerNames(t *testing.T) {
|
||||
func TestDNSRecordOptionsUnmarshalJSONAcceptsFullyQualifiedNames(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, record := range []string{
|
||||
"example.com A 1.1.1.1",
|
||||
"@ IN A 1.1.1.1",
|
||||
"www IN CNAME @",
|
||||
"example.com. A 1.1.1.1",
|
||||
"www.example.com. IN CNAME example.com.",
|
||||
} {
|
||||
value := mustRecordOptions(t, record)
|
||||
require.NotNil(t, value.RR)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSRecordOptionsUnmarshalJSONRejectsRelativeNames(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, record := range []string{
|
||||
"@ IN A 1.1.1.1",
|
||||
"www IN CNAME example.com.",
|
||||
"example.com. IN CNAME @",
|
||||
"example.com. IN CNAME www",
|
||||
} {
|
||||
var value DNSRecordOptions
|
||||
err := value.UnmarshalJSON([]byte(`"` + record + `"`))
|
||||
require.Error(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSRecordOptionsMatchIgnoresTTL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user