dns: make rule path selection rule-set aware

This commit is contained in:
世界
2026-03-26 11:30:17 +08:00
parent 259e67fca3
commit 40b9c64a0d
8 changed files with 556 additions and 61 deletions

View File

@@ -67,9 +67,10 @@ type RuleSet interface {
type RuleSetUpdateCallback func(it RuleSet)
type RuleSetMetadata struct {
ContainsProcessRule bool
ContainsWIFIRule bool
ContainsIPCIDRRule bool
ContainsProcessRule bool
ContainsWIFIRule bool
ContainsIPCIDRRule bool
ContainsDNSQueryTypeRule bool
}
type HTTPStartContext struct {
ctx context.Context

2
box.go
View File

@@ -486,7 +486,7 @@ func (s *Box) preStart() error {
if err != nil {
return err
}
err = adapter.Start(s.logger, adapter.StartStateStart, s.outbound, s.dnsTransport, s.dnsRouter, s.network, s.connection, s.router)
err = adapter.Start(s.logger, adapter.StartStateStart, s.outbound, s.dnsTransport, s.network, s.connection, s.router, s.dnsRouter)
if err != nil {
return err
}

View File

@@ -5,6 +5,7 @@ import (
"errors"
"net/netip"
"strings"
"sync"
"time"
"github.com/sagernet/sing-box/adapter"
@@ -21,6 +22,7 @@ import (
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/task"
"github.com/sagernet/sing/common/x/list"
"github.com/sagernet/sing/contrab/freelru"
"github.com/sagernet/sing/contrab/maphash"
"github.com/sagernet/sing/service"
@@ -30,17 +32,27 @@ import (
var _ adapter.DNSRouter = (*Router)(nil)
type dnsRuleSetCallback struct {
ruleSet adapter.RuleSet
element *list.Element[adapter.RuleSetUpdateCallback]
}
type Router struct {
ctx context.Context
logger logger.ContextLogger
transport adapter.DNSTransportManager
outbound adapter.OutboundManager
client adapter.DNSClient
rawRules []option.DNSRule
rules []adapter.DNSRule
defaultDomainStrategy C.DomainStrategy
dnsReverseMapping freelru.Cache[netip.Addr, string]
platformInterface adapter.PlatformInterface
legacyAddressFilterMode bool
rulesAccess sync.RWMutex
ruleSetCallbacks []dnsRuleSetCallback
runtimeRuleError error
deprecatedReported bool
}
func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOptions) *Router {
@@ -49,6 +61,7 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOp
logger: logFactory.NewLogger("dns"),
transport: service.FromContext[adapter.DNSTransportManager](ctx),
outbound: service.FromContext[adapter.OutboundManager](ctx),
rawRules: make([]option.DNSRule, 0, len(options.Rules)),
rules: make([]adapter.DNSRule, 0, len(options.Rules)),
defaultDomainStrategy: C.DomainStrategy(options.Strategy),
}
@@ -77,20 +90,7 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOp
}
func (r *Router) Initialize(rules []option.DNSRule) error {
r.legacyAddressFilterMode = hasLegacyAddressFilterItems(rules)
if !r.legacyAddressFilterMode {
err := validateNonLegacyAddressFilterRules(rules)
if err != nil {
return err
}
}
for i, ruleOptions := range rules {
dnsRule, err := R.NewDNSRule(r.ctx, r.logger, ruleOptions, true, r.legacyAddressFilterMode)
if err != nil {
return E.Cause(err, "parse dns rule[", i, "]")
}
r.rules = append(r.rules, dnsRule)
}
r.rawRules = append(r.rawRules[:0], rules...)
return nil
}
@@ -102,16 +102,17 @@ func (r *Router) Start(stage adapter.StartStage) error {
r.client.Start()
monitor.Finish()
for i, rule := range r.rules {
monitor.Start("initialize DNS rule[", i, "]")
err := rule.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "initialize DNS rule[", i, "]")
}
monitor.Start("initialize DNS rules")
err := r.rebuildRules(true)
monitor.Finish()
if err != nil {
return err
}
if r.legacyAddressFilterMode && common.Any(r.rules, func(rule adapter.DNSRule) bool { return rule.WithAddressLimit() }) {
deprecated.Report(r.ctx, deprecated.OptionLegacyDNSAddressFilter)
monitor.Start("register DNS rule-set callbacks")
err = r.registerRuleSetCallbacks()
monitor.Finish()
if err != nil {
return err
}
}
return nil
@@ -119,8 +120,18 @@ func (r *Router) Start(stage adapter.StartStage) error {
func (r *Router) Close() error {
monitor := taskmonitor.New(r.logger, C.StopTimeout)
r.rulesAccess.Lock()
callbacks := r.ruleSetCallbacks
r.ruleSetCallbacks = nil
runtimeRules := r.rules
r.rules = nil
r.runtimeRuleError = nil
r.rulesAccess.Unlock()
for _, callback := range callbacks {
callback.ruleSet.UnregisterCallback(callback.element)
}
var err error
for i, rule := range r.rules {
for i, rule := range runtimeRules {
monitor.Start("close dns rule[", i, "]")
err = E.Append(err, rule.Close(), func(err error) error {
return E.Cause(err, "close dns rule[", i, "]")
@@ -130,6 +141,111 @@ func (r *Router) Close() error {
return err
}
func (r *Router) rebuildRules(startRules bool) error {
router := service.FromContext[adapter.Router](r.ctx)
legacyAddressFilterMode, err := resolveLegacyAddressFilterMode(router, r.rawRules)
if err != nil {
return err
}
if !legacyAddressFilterMode {
err = validateNonLegacyAddressFilterRules(r.rawRules)
if err != nil {
return err
}
}
newRules := make([]adapter.DNSRule, 0, len(r.rawRules))
for i, ruleOptions := range r.rawRules {
dnsRule, err := R.NewDNSRule(r.ctx, r.logger, ruleOptions, true, legacyAddressFilterMode)
if err != nil {
closeRules(newRules)
return E.Cause(err, "parse dns rule[", i, "]")
}
newRules = append(newRules, dnsRule)
}
if startRules {
for i, rule := range newRules {
err := rule.Start()
if err != nil {
closeRules(newRules)
return E.Cause(err, "initialize DNS rule[", i, "]")
}
}
}
r.rulesAccess.Lock()
oldRules := r.rules
r.rules = newRules
r.legacyAddressFilterMode = legacyAddressFilterMode
r.runtimeRuleError = nil
shouldReportDeprecated := legacyAddressFilterMode &&
!r.deprecatedReported &&
common.Any(newRules, func(rule adapter.DNSRule) bool { return rule.WithAddressLimit() })
if shouldReportDeprecated {
r.deprecatedReported = true
}
r.rulesAccess.Unlock()
closeRules(oldRules)
if shouldReportDeprecated {
deprecated.Report(r.ctx, deprecated.OptionLegacyDNSAddressFilter)
}
return nil
}
func closeRules(rules []adapter.DNSRule) {
for _, rule := range rules {
_ = rule.Close()
}
}
func (r *Router) registerRuleSetCallbacks() error {
tags := referencedDNSRuleSetTags(r.rawRules)
if len(tags) == 0 {
return nil
}
r.rulesAccess.RLock()
if len(r.ruleSetCallbacks) > 0 {
r.rulesAccess.RUnlock()
return nil
}
r.rulesAccess.RUnlock()
router := service.FromContext[adapter.Router](r.ctx)
if router == nil {
return E.New("router service not found")
}
callbacks := make([]dnsRuleSetCallback, 0, len(tags))
for _, tag := range tags {
ruleSet, loaded := router.RuleSet(tag)
if !loaded {
for _, callback := range callbacks {
callback.ruleSet.UnregisterCallback(callback.element)
}
return E.New("rule-set not found: ", tag)
}
element := ruleSet.RegisterCallback(func(adapter.RuleSet) {
err := r.rebuildRules(true)
if err != nil {
r.rulesAccess.Lock()
r.runtimeRuleError = err
r.rulesAccess.Unlock()
r.logger.Error(E.Cause(err, "rebuild DNS rules after rule-set update"))
}
})
callbacks = append(callbacks, dnsRuleSetCallback{
ruleSet: ruleSet,
element: element,
})
}
r.rulesAccess.Lock()
if len(r.ruleSetCallbacks) == 0 {
r.ruleSetCallbacks = callbacks
callbacks = nil
}
r.rulesAccess.Unlock()
for _, callback := range callbacks {
callback.ruleSet.UnregisterCallback(callback.element)
}
return nil
}
func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int, isAddressQuery bool, options *adapter.DNSQueryOptions) (adapter.DNSTransport, adapter.DNSRule, int) {
metadata := adapter.ContextFrom(ctx)
if metadata == nil {
@@ -538,6 +654,11 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte
}
return &responseMessage, nil
}
r.rulesAccess.RLock()
defer r.rulesAccess.RUnlock()
if r.runtimeRuleError != nil {
return nil, r.runtimeRuleError
}
r.logger.DebugContext(ctx, "exchange ", FormatQuestion(message.Question[0].String()))
var (
response *mDNS.Msg
@@ -639,6 +760,11 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte
}
func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, error) {
r.rulesAccess.RLock()
defer r.rulesAccess.RUnlock()
if r.runtimeRuleError != nil {
return nil, r.runtimeRuleError
}
var (
responseAddrs []netip.Addr
err error
@@ -769,28 +895,124 @@ func (r *Router) ResetNetwork() {
}
}
func hasLegacyAddressFilterItems(rules []option.DNSRule) bool {
return common.Any(rules, hasLegacyAddressFilterItemsInRule)
}
func hasLegacyAddressFilterItemsInRule(rule option.DNSRule) bool {
switch rule.Type {
case "", C.RuleTypeDefault:
return hasLegacyAddressFilterItemsInDefaultRule(rule.DefaultOptions)
case C.RuleTypeLogical:
return common.Any(rule.LogicalOptions.Rules, hasLegacyAddressFilterItemsInRule)
default:
return false
}
}
func hasLegacyAddressFilterItemsInDefaultRule(rule option.DefaultDNSRule) bool {
func hasDirectLegacyAddressFilterItemsInDefaultRule(rule option.DefaultDNSRule) bool {
if rule.IPAcceptAny || rule.RuleSetIPCIDRAcceptEmpty {
return true
}
return !rule.MatchResponse && (len(rule.IPCIDR) > 0 || rule.IPIsPrivate)
}
func hasResponseMatchFields(rule option.DefaultDNSRule) bool {
return rule.ResponseRcode != nil ||
len(rule.ResponseAnswer) > 0 ||
len(rule.ResponseNs) > 0 ||
len(rule.ResponseExtra) > 0
}
func defaultRuleForcesNewDNSPath(rule option.DefaultDNSRule) bool {
return rule.MatchResponse ||
hasResponseMatchFields(rule) ||
rule.Action == C.RuleActionTypeEvaluate ||
rule.IPVersion > 0 ||
len(rule.QueryType) > 0
}
func resolveLegacyAddressFilterMode(router adapter.Router, rules []option.DNSRule) (bool, error) {
forceNew, needsLegacy, err := dnsRuleModeRequirements(router, rules)
if err != nil {
return false, err
}
if forceNew {
return false, nil
}
return needsLegacy, nil
}
func dnsRuleModeRequirements(router adapter.Router, rules []option.DNSRule) (bool, bool, error) {
var forceNew bool
var needsLegacy bool
for i, rule := range rules {
ruleForceNew, ruleNeedsLegacy, err := dnsRuleModeRequirementsInRule(router, rule)
if err != nil {
return false, false, E.Cause(err, "dns rule[", i, "]")
}
forceNew = forceNew || ruleForceNew
needsLegacy = needsLegacy || ruleNeedsLegacy
}
return forceNew, needsLegacy, nil
}
func dnsRuleModeRequirementsInRule(router adapter.Router, rule option.DNSRule) (bool, bool, error) {
switch rule.Type {
case "", C.RuleTypeDefault:
return dnsRuleModeRequirementsInDefaultRule(router, rule.DefaultOptions)
case C.RuleTypeLogical:
forceNew := dnsRuleActionType(rule) == C.RuleActionTypeEvaluate
var needsLegacy bool
for i, subRule := range rule.LogicalOptions.Rules {
subForceNew, subNeedsLegacy, err := dnsRuleModeRequirementsInRule(router, subRule)
if err != nil {
return false, false, E.Cause(err, "sub rule[", i, "]")
}
forceNew = forceNew || subForceNew
needsLegacy = needsLegacy || subNeedsLegacy
}
return forceNew, needsLegacy, nil
default:
return false, false, nil
}
}
func dnsRuleModeRequirementsInDefaultRule(router adapter.Router, rule option.DefaultDNSRule) (bool, bool, error) {
forceNew := defaultRuleForcesNewDNSPath(rule)
needsLegacy := hasDirectLegacyAddressFilterItemsInDefaultRule(rule)
if len(rule.RuleSet) == 0 {
return forceNew, needsLegacy, nil
}
if router == nil {
return false, false, E.New("router service not found")
}
for _, tag := range rule.RuleSet {
ruleSet, loaded := router.RuleSet(tag)
if !loaded {
return false, false, E.New("rule-set not found: ", tag)
}
metadata := ruleSet.Metadata()
forceNew = forceNew || metadata.ContainsDNSQueryTypeRule
if !rule.RuleSetIPCIDRMatchSource && metadata.ContainsIPCIDRRule {
needsLegacy = true
}
}
return forceNew, needsLegacy, nil
}
func referencedDNSRuleSetTags(rules []option.DNSRule) []string {
tagMap := make(map[string]bool)
var walkRule func(rule option.DNSRule)
walkRule = func(rule option.DNSRule) {
switch rule.Type {
case "", C.RuleTypeDefault:
for _, tag := range rule.DefaultOptions.RuleSet {
tagMap[tag] = true
}
case C.RuleTypeLogical:
for _, subRule := range rule.LogicalOptions.Rules {
walkRule(subRule)
}
}
}
for _, rule := range rules {
walkRule(rule)
}
tags := make([]string, 0, len(tagMap))
for tag := range tagMap {
if tag != "" {
tags = append(tags, tag)
}
}
return tags
}
func validateNonLegacyAddressFilterRules(rules []option.DNSRule) error {
var seenEvaluate bool
for i, rule := range rules {
@@ -832,10 +1054,7 @@ func validateNonLegacyAddressFilterRuleTree(rule option.DNSRule) (bool, error) {
}
func validateNonLegacyAddressFilterDefaultRule(rule option.DefaultDNSRule) (bool, error) {
hasResponseRecords := rule.ResponseRcode != nil ||
len(rule.ResponseAnswer) > 0 ||
len(rule.ResponseNs) > 0 ||
len(rule.ResponseExtra) > 0
hasResponseRecords := hasResponseMatchFields(rule)
if hasResponseRecords && !rule.MatchResponse {
return false, E.New("response_* items require match_response")
}

View File

@@ -6,17 +6,23 @@ import (
"net"
"net/netip"
"testing"
"time"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/experimental/deprecated"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
rulepkg "github.com/sagernet/sing-box/route/rule"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/json/badoption"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list"
"github.com/sagernet/sing/service"
mDNS "github.com/miekg/dns"
"github.com/stretchr/testify/require"
"go4.org/netipx"
)
type fakeDNSTransport struct {
@@ -72,6 +78,69 @@ type fakeDeprecatedManager struct {
features []deprecated.Note
}
type fakeRouter struct {
ruleSets map[string]adapter.RuleSet
}
func (r *fakeRouter) Start(adapter.StartStage) error { return nil }
func (r *fakeRouter) Close() error { return nil }
func (r *fakeRouter) PreMatch(metadata adapter.InboundContext, _ tun.DirectRouteContext, _ time.Duration, _ bool) (tun.DirectRouteDestination, error) {
return nil, nil
}
func (r *fakeRouter) RouteConnection(context.Context, net.Conn, adapter.InboundContext) error {
return nil
}
func (r *fakeRouter) RoutePacketConnection(context.Context, N.PacketConn, adapter.InboundContext) error {
return nil
}
func (r *fakeRouter) RouteConnectionEx(context.Context, net.Conn, adapter.InboundContext, N.CloseHandlerFunc) {
}
func (r *fakeRouter) RoutePacketConnectionEx(context.Context, N.PacketConn, adapter.InboundContext, N.CloseHandlerFunc) {
}
func (r *fakeRouter) RuleSet(tag string) (adapter.RuleSet, bool) {
ruleSet, loaded := r.ruleSets[tag]
return ruleSet, loaded
}
func (r *fakeRouter) Rules() []adapter.Rule { return nil }
func (r *fakeRouter) NeedFindProcess() bool { return false }
func (r *fakeRouter) NeedFindNeighbor() bool { return false }
func (r *fakeRouter) NeighborResolver() adapter.NeighborResolver { return nil }
func (r *fakeRouter) AppendTracker(adapter.ConnectionTracker) {}
func (r *fakeRouter) ResetNetwork() {}
type fakeRuleSet struct {
metadata adapter.RuleSetMetadata
callbacks []adapter.RuleSetUpdateCallback
}
func (s *fakeRuleSet) Name() string { return "fake-rule-set" }
func (s *fakeRuleSet) StartContext(context.Context, *adapter.HTTPStartContext) error { return nil }
func (s *fakeRuleSet) PostStart() error { return nil }
func (s *fakeRuleSet) Metadata() adapter.RuleSetMetadata { return s.metadata }
func (s *fakeRuleSet) ExtractIPSet() []*netipx.IPSet { return nil }
func (s *fakeRuleSet) IncRef() {}
func (s *fakeRuleSet) DecRef() {}
func (s *fakeRuleSet) Cleanup() {}
func (s *fakeRuleSet) RegisterCallback(callback adapter.RuleSetUpdateCallback) *list.Element[adapter.RuleSetUpdateCallback] {
s.callbacks = append(s.callbacks, callback)
return nil
}
func (s *fakeRuleSet) UnregisterCallback(*list.Element[adapter.RuleSetUpdateCallback]) {}
func (s *fakeRuleSet) Close() error { return nil }
func (s *fakeRuleSet) Match(*adapter.InboundContext) bool { return true }
func (s *fakeRuleSet) String() string { return "fake-rule-set" }
func (s *fakeRuleSet) updateMetadata(metadata adapter.RuleSetMetadata) {
s.metadata = metadata
for _, callback := range s.callbacks {
callback(s)
}
}
func (m *fakeDeprecatedManager) ReportDeprecated(feature deprecated.Note) {
m.features = append(m.features, feature)
}
@@ -108,18 +177,25 @@ func (c *fakeDNSClient) Lookup(_ context.Context, transport adapter.DNSTransport
func (c *fakeDNSClient) ClearCache() {}
func newTestRouter(t *testing.T, rules []option.DNSRule, transportManager *fakeDNSTransportManager, client *fakeDNSClient) *Router {
return newTestRouterWithContext(t, context.Background(), rules, transportManager, client)
}
func newTestRouterWithContext(t *testing.T, ctx context.Context, rules []option.DNSRule, transportManager *fakeDNSTransportManager, client *fakeDNSClient) *Router {
t.Helper()
router := &Router{
ctx: context.Background(),
ctx: ctx,
logger: log.NewNOPFactory().NewLogger("dns"),
transport: transportManager,
client: client,
rawRules: make([]option.DNSRule, 0, len(rules)),
rules: make([]adapter.DNSRule, 0, len(rules)),
defaultDomainStrategy: C.DomainStrategyAsIS,
}
if rules != nil {
err := router.Initialize(rules)
require.NoError(t, err)
err = router.Start(adapter.StartStateStart)
require.NoError(t, err)
}
return router
}
@@ -202,6 +278,187 @@ func TestValidateNewDNSRules_RequireMatchResponseForDirectIPCIDR(t *testing.T) {
require.ErrorContains(t, err, "ip_cidr and ip_is_private require match_response")
}
func TestStartNewModeRejectsDirectLegacyRuleWhenRuleSetForcesNew(t *testing.T) {
t.Parallel()
ctx := context.Background()
ruleSet, err := rulepkg.NewRuleSet(ctx, log.NewNOPFactory().NewLogger("router"), option.RuleSet{
Type: C.RuleSetTypeInline,
Tag: "query-set",
InlineOptions: option.PlainRuleSet{
Rules: []option.HeadlessRule{{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultHeadlessRule{
QueryType: badoption.Listable[option.DNSQueryType]{option.DNSQueryType(mDNS.TypeA)},
},
}},
},
})
require.NoError(t, err)
ctx = service.ContextWith[adapter.Router](ctx, &fakeRouter{
ruleSets: map[string]adapter.RuleSet{
"query-set": ruleSet,
},
})
router := &Router{
ctx: ctx,
logger: log.NewNOPFactory().NewLogger("dns"),
transport: &fakeDNSTransportManager{},
client: &fakeDNSClient{},
rawRules: make([]option.DNSRule, 0, 2),
rules: make([]adapter.DNSRule, 0, 2),
defaultDomainStrategy: C.DomainStrategyAsIS,
}
err = router.Initialize([]option.DNSRule{
{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
RawDefaultDNSRule: option.RawDefaultDNSRule{
RuleSet: badoption.Listable[string]{"query-set"},
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeRoute,
RouteOptions: option.DNSRouteActionOptions{Server: "default"},
},
},
},
{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
RawDefaultDNSRule: option.RawDefaultDNSRule{
IPIsPrivate: true,
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeRoute,
RouteOptions: option.DNSRouteActionOptions{Server: "private"},
},
},
},
})
require.NoError(t, err)
err = router.Start(adapter.StartStateStart)
require.ErrorContains(t, err, "ip_cidr and ip_is_private require match_response")
}
func TestLookupLegacyModeDefersRuleSetDestinationIPMatch(t *testing.T) {
t.Parallel()
ctx := context.Background()
ruleSet, err := rulepkg.NewRuleSet(ctx, log.NewNOPFactory().NewLogger("router"), option.RuleSet{
Type: C.RuleSetTypeInline,
Tag: "legacy-ipcidr-set",
InlineOptions: option.PlainRuleSet{
Rules: []option.HeadlessRule{{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultHeadlessRule{
IPCIDR: badoption.Listable[string]{"10.0.0.0/8"},
},
}},
},
})
require.NoError(t, err)
ctx = service.ContextWith[adapter.Router](ctx, &fakeRouter{
ruleSets: map[string]adapter.RuleSet{
"legacy-ipcidr-set": ruleSet,
},
})
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
privateTransport := &fakeDNSTransport{tag: "private", transportType: C.DNSTypeUDP}
router := newTestRouterWithContext(t, ctx, []option.DNSRule{{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
RawDefaultDNSRule: option.RawDefaultDNSRule{
RuleSet: badoption.Listable[string]{"legacy-ipcidr-set"},
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeRoute,
RouteOptions: option.DNSRouteActionOptions{Server: "private"},
},
},
}}, &fakeDNSTransportManager{
defaultTransport: defaultTransport,
transports: map[string]adapter.DNSTransport{
"default": defaultTransport,
"private": privateTransport,
},
}, &fakeDNSClient{
lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) {
require.Equal(t, "example.com", domain)
require.Equal(t, "private", transport.Tag())
response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60)
return MessageToAddresses(response), response, nil
},
})
require.True(t, router.legacyAddressFilterMode)
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{
LookupStrategy: C.DomainStrategyIPv4Only,
})
require.NoError(t, err)
require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses)
}
func TestRuleSetUpdateSetsRuntimeErrorWhenRebuildFails(t *testing.T) {
t.Parallel()
fakeSet := &fakeRuleSet{}
ctx := service.ContextWith[adapter.Router](context.Background(), &fakeRouter{
ruleSets: map[string]adapter.RuleSet{
"dynamic-set": fakeSet,
},
})
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
router := newTestRouterWithContext(t, ctx, []option.DNSRule{
{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
RawDefaultDNSRule: option.RawDefaultDNSRule{
RuleSet: badoption.Listable[string]{"dynamic-set"},
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeRoute,
RouteOptions: option.DNSRouteActionOptions{Server: "default"},
},
},
},
{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
RawDefaultDNSRule: option.RawDefaultDNSRule{
IPIsPrivate: true,
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeRoute,
RouteOptions: option.DNSRouteActionOptions{Server: "default"},
},
},
},
}, &fakeDNSTransportManager{
defaultTransport: defaultTransport,
transports: map[string]adapter.DNSTransport{
"default": defaultTransport,
},
}, &fakeDNSClient{
lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) {
response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60)
return MessageToAddresses(response), response, nil
},
})
require.True(t, router.legacyAddressFilterMode)
fakeSet.updateMetadata(adapter.RuleSetMetadata{
ContainsDNSQueryTypeRule: true,
})
_, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
require.ErrorContains(t, err, "ip_cidr and ip_is_private require match_response")
}
func TestLookupLegacyModeDefersDirectDestinationIPMatch(t *testing.T) {
t.Parallel()

View File

@@ -47,7 +47,12 @@ type legacyResponseConstraint struct {
forbiddenSet *netipx.IPSet
}
type legacyRuleMatchStateSet [16]legacyResponseFormula
const (
legacyRuleMatchDeferredDestinationAddress ruleMatchState = 1 << 4
legacyRuleMatchStateCount = 32
)
type legacyRuleMatchStateSet [legacyRuleMatchStateCount]legacyResponseFormula
var (
legacyAllIPSet = func() *netipx.IPSet {
@@ -350,7 +355,7 @@ func (s legacyRuleMatchStateSet) isEmpty() bool {
func (s legacyRuleMatchStateSet) merge(other legacyRuleMatchStateSet) legacyRuleMatchStateSet {
var merged legacyRuleMatchStateSet
for state := ruleMatchState(0); state < 16; state++ {
for state := ruleMatchState(0); state < legacyRuleMatchStateCount; state++ {
merged[state] = s[state].or(other[state])
}
return merged
@@ -361,11 +366,11 @@ func (s legacyRuleMatchStateSet) combine(other legacyRuleMatchStateSet) legacyRu
return legacyRuleMatchStateSet{}
}
var combined legacyRuleMatchStateSet
for left := ruleMatchState(0); left < 16; left++ {
for left := ruleMatchState(0); left < legacyRuleMatchStateCount; left++ {
if s[left].isFalse() {
continue
}
for right := ruleMatchState(0); right < 16; right++ {
for right := ruleMatchState(0); right < legacyRuleMatchStateCount; right++ {
if other[right].isFalse() {
continue
}
@@ -380,7 +385,7 @@ func (s legacyRuleMatchStateSet) withBase(base ruleMatchState) legacyRuleMatchSt
return legacyRuleMatchStateSet{}
}
var withBase legacyRuleMatchStateSet
for state := ruleMatchState(0); state < 16; state++ {
for state := ruleMatchState(0); state < legacyRuleMatchStateCount; state++ {
if s[state].isFalse() {
continue
}
@@ -391,7 +396,7 @@ func (s legacyRuleMatchStateSet) withBase(base ruleMatchState) legacyRuleMatchSt
func (s legacyRuleMatchStateSet) filter(allowed func(ruleMatchState) bool) legacyRuleMatchStateSet {
var filtered legacyRuleMatchStateSet
for state := ruleMatchState(0); state < 16; state++ {
for state := ruleMatchState(0); state < legacyRuleMatchStateCount; state++ {
if s[state].isFalse() {
continue
}
@@ -404,7 +409,7 @@ func (s legacyRuleMatchStateSet) filter(allowed func(ruleMatchState) bool) legac
func (s legacyRuleMatchStateSet) addBit(bit ruleMatchState) legacyRuleMatchStateSet {
var withBit legacyRuleMatchStateSet
for state := ruleMatchState(0); state < 16; state++ {
for state := ruleMatchState(0); state < legacyRuleMatchStateCount; state++ {
if s[state].isFalse() {
continue
}
@@ -422,7 +427,7 @@ func (s legacyRuleMatchStateSet) branchOnBit(bit ruleMatchState, condition legac
}
var branched legacyRuleMatchStateSet
conditionFalse := condition.not()
for state := ruleMatchState(0); state < 16; state++ {
for state := ruleMatchState(0); state < legacyRuleMatchStateCount; state++ {
if s[state].isFalse() {
continue
}
@@ -444,7 +449,7 @@ func (s legacyRuleMatchStateSet) andFormula(formula legacyResponseFormula) legac
return s
}
var result legacyRuleMatchStateSet
for state := ruleMatchState(0); state < 16; state++ {
for state := ruleMatchState(0); state < legacyRuleMatchStateCount; state++ {
if s[state].isFalse() {
continue
}
@@ -588,7 +593,7 @@ func (r *abstractDefaultRule) legacyMatchStatesWithBase(metadata *adapter.Inboun
}
if r.legacyDestinationIPCIDRMatchesDestination(metadata) {
metadata.DidMatch = true
stateSet = stateSet.branchOnBit(ruleMatchDestinationAddress, legacyDestinationIPFormula(r.destinationIPCIDRItems, metadata))
stateSet = stateSet.branchOnBit(legacyRuleMatchDeferredDestinationAddress, legacyDestinationIPFormula(r.destinationIPCIDRItems, metadata))
}
if len(r.destinationPortItems) > 0 {
metadata.DidMatch = true
@@ -608,7 +613,7 @@ func (r *abstractDefaultRule) legacyMatchStatesWithBase(metadata *adapter.Inboun
if r.ruleSetItem != nil {
metadata.DidMatch = true
var merged legacyRuleMatchStateSet
for state := ruleMatchState(0); state < 16; state++ {
for state := ruleMatchState(0); state < legacyRuleMatchStateCount; state++ {
if stateSet[state].isFalse() {
continue
}
@@ -627,6 +632,9 @@ func (r *abstractDefaultRule) legacyMatchStatesWithBase(metadata *adapter.Inboun
if r.legacyRequiresDestinationAddressMatch(metadata) && !state.has(ruleMatchDestinationAddress) {
return false
}
if r.legacyRequiresDeferredDestinationAddressMatch(metadata) && !state.has(legacyRuleMatchDeferredDestinationAddress) {
return false
}
if len(r.destinationPortItems) > 0 && !state.has(ruleMatchDestinationPort) {
return false
}
@@ -647,7 +655,11 @@ func (r *abstractDefaultRule) legacyDestinationIPCIDRMatchesDestination(metadata
}
func (r *abstractDefaultRule) legacyRequiresDestinationAddressMatch(metadata *adapter.InboundContext) bool {
return len(r.destinationAddressItems) > 0 || r.legacyDestinationIPCIDRMatchesDestination(metadata)
return len(r.destinationAddressItems) > 0
}
func (r *abstractDefaultRule) legacyRequiresDeferredDestinationAddressMatch(metadata *adapter.InboundContext) bool {
return r.legacyDestinationIPCIDRMatchesDestination(metadata)
}
func (r *abstractLogicalRule) legacyMatchStates(metadata *adapter.InboundContext) legacyRuleMatchStateSet {

View File

@@ -69,3 +69,7 @@ func isWIFIHeadlessRule(rule option.DefaultHeadlessRule) bool {
func isIPCIDRHeadlessRule(rule option.DefaultHeadlessRule) bool {
return len(rule.IPCIDR) > 0 || rule.IPSet != nil
}
func isDNSQueryTypeHeadlessRule(rule option.DefaultHeadlessRule) bool {
return len(rule.QueryType) > 0
}

View File

@@ -141,6 +141,7 @@ func (s *LocalRuleSet) reloadRules(headlessRules []option.HeadlessRule) error {
metadata.ContainsProcessRule = HasHeadlessRule(headlessRules, isProcessHeadlessRule)
metadata.ContainsWIFIRule = HasHeadlessRule(headlessRules, isWIFIHeadlessRule)
metadata.ContainsIPCIDRRule = HasHeadlessRule(headlessRules, isIPCIDRHeadlessRule)
metadata.ContainsDNSQueryTypeRule = HasHeadlessRule(headlessRules, isDNSQueryTypeHeadlessRule)
s.access.Lock()
s.rules = rules
s.metadata = metadata

View File

@@ -193,6 +193,7 @@ func (s *RemoteRuleSet) loadBytes(content []byte) error {
s.metadata.ContainsProcessRule = HasHeadlessRule(plainRuleSet.Rules, isProcessHeadlessRule)
s.metadata.ContainsWIFIRule = HasHeadlessRule(plainRuleSet.Rules, isWIFIHeadlessRule)
s.metadata.ContainsIPCIDRRule = HasHeadlessRule(plainRuleSet.Rules, isIPCIDRHeadlessRule)
s.metadata.ContainsDNSQueryTypeRule = HasHeadlessRule(plainRuleSet.Rules, isDNSQueryTypeHeadlessRule)
s.rules = rules
callbacks := s.callbacks.Array()
s.access.Unlock()