mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-13 20:28:32 +10:00
dns: make rule path selection rule-set aware
This commit is contained in:
@@ -67,9 +67,10 @@ type RuleSet interface {
|
||||
type RuleSetUpdateCallback func(it RuleSet)
|
||||
|
||||
type RuleSetMetadata struct {
|
||||
ContainsProcessRule bool
|
||||
ContainsWIFIRule bool
|
||||
ContainsIPCIDRRule bool
|
||||
ContainsProcessRule bool
|
||||
ContainsWIFIRule bool
|
||||
ContainsIPCIDRRule bool
|
||||
ContainsDNSQueryTypeRule bool
|
||||
}
|
||||
type HTTPStartContext struct {
|
||||
ctx context.Context
|
||||
|
||||
2
box.go
2
box.go
@@ -486,7 +486,7 @@ func (s *Box) preStart() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = adapter.Start(s.logger, adapter.StartStateStart, s.outbound, s.dnsTransport, s.dnsRouter, s.network, s.connection, s.router)
|
||||
err = adapter.Start(s.logger, adapter.StartStateStart, s.outbound, s.dnsTransport, s.network, s.connection, s.router, s.dnsRouter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
307
dns/router.go
307
dns/router.go
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
@@ -21,6 +22,7 @@ import (
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/task"
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
"github.com/sagernet/sing/contrab/freelru"
|
||||
"github.com/sagernet/sing/contrab/maphash"
|
||||
"github.com/sagernet/sing/service"
|
||||
@@ -30,17 +32,27 @@ import (
|
||||
|
||||
var _ adapter.DNSRouter = (*Router)(nil)
|
||||
|
||||
type dnsRuleSetCallback struct {
|
||||
ruleSet adapter.RuleSet
|
||||
element *list.Element[adapter.RuleSetUpdateCallback]
|
||||
}
|
||||
|
||||
type Router struct {
|
||||
ctx context.Context
|
||||
logger logger.ContextLogger
|
||||
transport adapter.DNSTransportManager
|
||||
outbound adapter.OutboundManager
|
||||
client adapter.DNSClient
|
||||
rawRules []option.DNSRule
|
||||
rules []adapter.DNSRule
|
||||
defaultDomainStrategy C.DomainStrategy
|
||||
dnsReverseMapping freelru.Cache[netip.Addr, string]
|
||||
platformInterface adapter.PlatformInterface
|
||||
legacyAddressFilterMode bool
|
||||
rulesAccess sync.RWMutex
|
||||
ruleSetCallbacks []dnsRuleSetCallback
|
||||
runtimeRuleError error
|
||||
deprecatedReported bool
|
||||
}
|
||||
|
||||
func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOptions) *Router {
|
||||
@@ -49,6 +61,7 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOp
|
||||
logger: logFactory.NewLogger("dns"),
|
||||
transport: service.FromContext[adapter.DNSTransportManager](ctx),
|
||||
outbound: service.FromContext[adapter.OutboundManager](ctx),
|
||||
rawRules: make([]option.DNSRule, 0, len(options.Rules)),
|
||||
rules: make([]adapter.DNSRule, 0, len(options.Rules)),
|
||||
defaultDomainStrategy: C.DomainStrategy(options.Strategy),
|
||||
}
|
||||
@@ -77,20 +90,7 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOp
|
||||
}
|
||||
|
||||
func (r *Router) Initialize(rules []option.DNSRule) error {
|
||||
r.legacyAddressFilterMode = hasLegacyAddressFilterItems(rules)
|
||||
if !r.legacyAddressFilterMode {
|
||||
err := validateNonLegacyAddressFilterRules(rules)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for i, ruleOptions := range rules {
|
||||
dnsRule, err := R.NewDNSRule(r.ctx, r.logger, ruleOptions, true, r.legacyAddressFilterMode)
|
||||
if err != nil {
|
||||
return E.Cause(err, "parse dns rule[", i, "]")
|
||||
}
|
||||
r.rules = append(r.rules, dnsRule)
|
||||
}
|
||||
r.rawRules = append(r.rawRules[:0], rules...)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -102,16 +102,17 @@ func (r *Router) Start(stage adapter.StartStage) error {
|
||||
r.client.Start()
|
||||
monitor.Finish()
|
||||
|
||||
for i, rule := range r.rules {
|
||||
monitor.Start("initialize DNS rule[", i, "]")
|
||||
err := rule.Start()
|
||||
monitor.Finish()
|
||||
if err != nil {
|
||||
return E.Cause(err, "initialize DNS rule[", i, "]")
|
||||
}
|
||||
monitor.Start("initialize DNS rules")
|
||||
err := r.rebuildRules(true)
|
||||
monitor.Finish()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r.legacyAddressFilterMode && common.Any(r.rules, func(rule adapter.DNSRule) bool { return rule.WithAddressLimit() }) {
|
||||
deprecated.Report(r.ctx, deprecated.OptionLegacyDNSAddressFilter)
|
||||
monitor.Start("register DNS rule-set callbacks")
|
||||
err = r.registerRuleSetCallbacks()
|
||||
monitor.Finish()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -119,8 +120,18 @@ func (r *Router) Start(stage adapter.StartStage) error {
|
||||
|
||||
func (r *Router) Close() error {
|
||||
monitor := taskmonitor.New(r.logger, C.StopTimeout)
|
||||
r.rulesAccess.Lock()
|
||||
callbacks := r.ruleSetCallbacks
|
||||
r.ruleSetCallbacks = nil
|
||||
runtimeRules := r.rules
|
||||
r.rules = nil
|
||||
r.runtimeRuleError = nil
|
||||
r.rulesAccess.Unlock()
|
||||
for _, callback := range callbacks {
|
||||
callback.ruleSet.UnregisterCallback(callback.element)
|
||||
}
|
||||
var err error
|
||||
for i, rule := range r.rules {
|
||||
for i, rule := range runtimeRules {
|
||||
monitor.Start("close dns rule[", i, "]")
|
||||
err = E.Append(err, rule.Close(), func(err error) error {
|
||||
return E.Cause(err, "close dns rule[", i, "]")
|
||||
@@ -130,6 +141,111 @@ func (r *Router) Close() error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Router) rebuildRules(startRules bool) error {
|
||||
router := service.FromContext[adapter.Router](r.ctx)
|
||||
legacyAddressFilterMode, err := resolveLegacyAddressFilterMode(router, r.rawRules)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !legacyAddressFilterMode {
|
||||
err = validateNonLegacyAddressFilterRules(r.rawRules)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
newRules := make([]adapter.DNSRule, 0, len(r.rawRules))
|
||||
for i, ruleOptions := range r.rawRules {
|
||||
dnsRule, err := R.NewDNSRule(r.ctx, r.logger, ruleOptions, true, legacyAddressFilterMode)
|
||||
if err != nil {
|
||||
closeRules(newRules)
|
||||
return E.Cause(err, "parse dns rule[", i, "]")
|
||||
}
|
||||
newRules = append(newRules, dnsRule)
|
||||
}
|
||||
if startRules {
|
||||
for i, rule := range newRules {
|
||||
err := rule.Start()
|
||||
if err != nil {
|
||||
closeRules(newRules)
|
||||
return E.Cause(err, "initialize DNS rule[", i, "]")
|
||||
}
|
||||
}
|
||||
}
|
||||
r.rulesAccess.Lock()
|
||||
oldRules := r.rules
|
||||
r.rules = newRules
|
||||
r.legacyAddressFilterMode = legacyAddressFilterMode
|
||||
r.runtimeRuleError = nil
|
||||
shouldReportDeprecated := legacyAddressFilterMode &&
|
||||
!r.deprecatedReported &&
|
||||
common.Any(newRules, func(rule adapter.DNSRule) bool { return rule.WithAddressLimit() })
|
||||
if shouldReportDeprecated {
|
||||
r.deprecatedReported = true
|
||||
}
|
||||
r.rulesAccess.Unlock()
|
||||
closeRules(oldRules)
|
||||
if shouldReportDeprecated {
|
||||
deprecated.Report(r.ctx, deprecated.OptionLegacyDNSAddressFilter)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func closeRules(rules []adapter.DNSRule) {
|
||||
for _, rule := range rules {
|
||||
_ = rule.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Router) registerRuleSetCallbacks() error {
|
||||
tags := referencedDNSRuleSetTags(r.rawRules)
|
||||
if len(tags) == 0 {
|
||||
return nil
|
||||
}
|
||||
r.rulesAccess.RLock()
|
||||
if len(r.ruleSetCallbacks) > 0 {
|
||||
r.rulesAccess.RUnlock()
|
||||
return nil
|
||||
}
|
||||
r.rulesAccess.RUnlock()
|
||||
router := service.FromContext[adapter.Router](r.ctx)
|
||||
if router == nil {
|
||||
return E.New("router service not found")
|
||||
}
|
||||
callbacks := make([]dnsRuleSetCallback, 0, len(tags))
|
||||
for _, tag := range tags {
|
||||
ruleSet, loaded := router.RuleSet(tag)
|
||||
if !loaded {
|
||||
for _, callback := range callbacks {
|
||||
callback.ruleSet.UnregisterCallback(callback.element)
|
||||
}
|
||||
return E.New("rule-set not found: ", tag)
|
||||
}
|
||||
element := ruleSet.RegisterCallback(func(adapter.RuleSet) {
|
||||
err := r.rebuildRules(true)
|
||||
if err != nil {
|
||||
r.rulesAccess.Lock()
|
||||
r.runtimeRuleError = err
|
||||
r.rulesAccess.Unlock()
|
||||
r.logger.Error(E.Cause(err, "rebuild DNS rules after rule-set update"))
|
||||
}
|
||||
})
|
||||
callbacks = append(callbacks, dnsRuleSetCallback{
|
||||
ruleSet: ruleSet,
|
||||
element: element,
|
||||
})
|
||||
}
|
||||
r.rulesAccess.Lock()
|
||||
if len(r.ruleSetCallbacks) == 0 {
|
||||
r.ruleSetCallbacks = callbacks
|
||||
callbacks = nil
|
||||
}
|
||||
r.rulesAccess.Unlock()
|
||||
for _, callback := range callbacks {
|
||||
callback.ruleSet.UnregisterCallback(callback.element)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int, isAddressQuery bool, options *adapter.DNSQueryOptions) (adapter.DNSTransport, adapter.DNSRule, int) {
|
||||
metadata := adapter.ContextFrom(ctx)
|
||||
if metadata == nil {
|
||||
@@ -538,6 +654,11 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte
|
||||
}
|
||||
return &responseMessage, nil
|
||||
}
|
||||
r.rulesAccess.RLock()
|
||||
defer r.rulesAccess.RUnlock()
|
||||
if r.runtimeRuleError != nil {
|
||||
return nil, r.runtimeRuleError
|
||||
}
|
||||
r.logger.DebugContext(ctx, "exchange ", FormatQuestion(message.Question[0].String()))
|
||||
var (
|
||||
response *mDNS.Msg
|
||||
@@ -639,6 +760,11 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte
|
||||
}
|
||||
|
||||
func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, error) {
|
||||
r.rulesAccess.RLock()
|
||||
defer r.rulesAccess.RUnlock()
|
||||
if r.runtimeRuleError != nil {
|
||||
return nil, r.runtimeRuleError
|
||||
}
|
||||
var (
|
||||
responseAddrs []netip.Addr
|
||||
err error
|
||||
@@ -769,28 +895,124 @@ func (r *Router) ResetNetwork() {
|
||||
}
|
||||
}
|
||||
|
||||
func hasLegacyAddressFilterItems(rules []option.DNSRule) bool {
|
||||
return common.Any(rules, hasLegacyAddressFilterItemsInRule)
|
||||
}
|
||||
|
||||
func hasLegacyAddressFilterItemsInRule(rule option.DNSRule) bool {
|
||||
switch rule.Type {
|
||||
case "", C.RuleTypeDefault:
|
||||
return hasLegacyAddressFilterItemsInDefaultRule(rule.DefaultOptions)
|
||||
case C.RuleTypeLogical:
|
||||
return common.Any(rule.LogicalOptions.Rules, hasLegacyAddressFilterItemsInRule)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func hasLegacyAddressFilterItemsInDefaultRule(rule option.DefaultDNSRule) bool {
|
||||
func hasDirectLegacyAddressFilterItemsInDefaultRule(rule option.DefaultDNSRule) bool {
|
||||
if rule.IPAcceptAny || rule.RuleSetIPCIDRAcceptEmpty {
|
||||
return true
|
||||
}
|
||||
return !rule.MatchResponse && (len(rule.IPCIDR) > 0 || rule.IPIsPrivate)
|
||||
}
|
||||
|
||||
func hasResponseMatchFields(rule option.DefaultDNSRule) bool {
|
||||
return rule.ResponseRcode != nil ||
|
||||
len(rule.ResponseAnswer) > 0 ||
|
||||
len(rule.ResponseNs) > 0 ||
|
||||
len(rule.ResponseExtra) > 0
|
||||
}
|
||||
|
||||
func defaultRuleForcesNewDNSPath(rule option.DefaultDNSRule) bool {
|
||||
return rule.MatchResponse ||
|
||||
hasResponseMatchFields(rule) ||
|
||||
rule.Action == C.RuleActionTypeEvaluate ||
|
||||
rule.IPVersion > 0 ||
|
||||
len(rule.QueryType) > 0
|
||||
}
|
||||
|
||||
func resolveLegacyAddressFilterMode(router adapter.Router, rules []option.DNSRule) (bool, error) {
|
||||
forceNew, needsLegacy, err := dnsRuleModeRequirements(router, rules)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if forceNew {
|
||||
return false, nil
|
||||
}
|
||||
return needsLegacy, nil
|
||||
}
|
||||
|
||||
func dnsRuleModeRequirements(router adapter.Router, rules []option.DNSRule) (bool, bool, error) {
|
||||
var forceNew bool
|
||||
var needsLegacy bool
|
||||
for i, rule := range rules {
|
||||
ruleForceNew, ruleNeedsLegacy, err := dnsRuleModeRequirementsInRule(router, rule)
|
||||
if err != nil {
|
||||
return false, false, E.Cause(err, "dns rule[", i, "]")
|
||||
}
|
||||
forceNew = forceNew || ruleForceNew
|
||||
needsLegacy = needsLegacy || ruleNeedsLegacy
|
||||
}
|
||||
return forceNew, needsLegacy, nil
|
||||
}
|
||||
|
||||
func dnsRuleModeRequirementsInRule(router adapter.Router, rule option.DNSRule) (bool, bool, error) {
|
||||
switch rule.Type {
|
||||
case "", C.RuleTypeDefault:
|
||||
return dnsRuleModeRequirementsInDefaultRule(router, rule.DefaultOptions)
|
||||
case C.RuleTypeLogical:
|
||||
forceNew := dnsRuleActionType(rule) == C.RuleActionTypeEvaluate
|
||||
var needsLegacy bool
|
||||
for i, subRule := range rule.LogicalOptions.Rules {
|
||||
subForceNew, subNeedsLegacy, err := dnsRuleModeRequirementsInRule(router, subRule)
|
||||
if err != nil {
|
||||
return false, false, E.Cause(err, "sub rule[", i, "]")
|
||||
}
|
||||
forceNew = forceNew || subForceNew
|
||||
needsLegacy = needsLegacy || subNeedsLegacy
|
||||
}
|
||||
return forceNew, needsLegacy, nil
|
||||
default:
|
||||
return false, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
func dnsRuleModeRequirementsInDefaultRule(router adapter.Router, rule option.DefaultDNSRule) (bool, bool, error) {
|
||||
forceNew := defaultRuleForcesNewDNSPath(rule)
|
||||
needsLegacy := hasDirectLegacyAddressFilterItemsInDefaultRule(rule)
|
||||
if len(rule.RuleSet) == 0 {
|
||||
return forceNew, needsLegacy, nil
|
||||
}
|
||||
if router == nil {
|
||||
return false, false, E.New("router service not found")
|
||||
}
|
||||
for _, tag := range rule.RuleSet {
|
||||
ruleSet, loaded := router.RuleSet(tag)
|
||||
if !loaded {
|
||||
return false, false, E.New("rule-set not found: ", tag)
|
||||
}
|
||||
metadata := ruleSet.Metadata()
|
||||
forceNew = forceNew || metadata.ContainsDNSQueryTypeRule
|
||||
if !rule.RuleSetIPCIDRMatchSource && metadata.ContainsIPCIDRRule {
|
||||
needsLegacy = true
|
||||
}
|
||||
}
|
||||
return forceNew, needsLegacy, nil
|
||||
}
|
||||
|
||||
func referencedDNSRuleSetTags(rules []option.DNSRule) []string {
|
||||
tagMap := make(map[string]bool)
|
||||
var walkRule func(rule option.DNSRule)
|
||||
walkRule = func(rule option.DNSRule) {
|
||||
switch rule.Type {
|
||||
case "", C.RuleTypeDefault:
|
||||
for _, tag := range rule.DefaultOptions.RuleSet {
|
||||
tagMap[tag] = true
|
||||
}
|
||||
case C.RuleTypeLogical:
|
||||
for _, subRule := range rule.LogicalOptions.Rules {
|
||||
walkRule(subRule)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, rule := range rules {
|
||||
walkRule(rule)
|
||||
}
|
||||
tags := make([]string, 0, len(tagMap))
|
||||
for tag := range tagMap {
|
||||
if tag != "" {
|
||||
tags = append(tags, tag)
|
||||
}
|
||||
}
|
||||
return tags
|
||||
}
|
||||
|
||||
func validateNonLegacyAddressFilterRules(rules []option.DNSRule) error {
|
||||
var seenEvaluate bool
|
||||
for i, rule := range rules {
|
||||
@@ -832,10 +1054,7 @@ func validateNonLegacyAddressFilterRuleTree(rule option.DNSRule) (bool, error) {
|
||||
}
|
||||
|
||||
func validateNonLegacyAddressFilterDefaultRule(rule option.DefaultDNSRule) (bool, error) {
|
||||
hasResponseRecords := rule.ResponseRcode != nil ||
|
||||
len(rule.ResponseAnswer) > 0 ||
|
||||
len(rule.ResponseNs) > 0 ||
|
||||
len(rule.ResponseExtra) > 0
|
||||
hasResponseRecords := hasResponseMatchFields(rule)
|
||||
if hasResponseRecords && !rule.MatchResponse {
|
||||
return false, E.New("response_* items require match_response")
|
||||
}
|
||||
|
||||
@@ -6,17 +6,23 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/experimental/deprecated"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
rulepkg "github.com/sagernet/sing-box/route/rule"
|
||||
"github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing/common/json/badoption"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
"github.com/sagernet/sing/service"
|
||||
|
||||
mDNS "github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go4.org/netipx"
|
||||
)
|
||||
|
||||
type fakeDNSTransport struct {
|
||||
@@ -72,6 +78,69 @@ type fakeDeprecatedManager struct {
|
||||
features []deprecated.Note
|
||||
}
|
||||
|
||||
type fakeRouter struct {
|
||||
ruleSets map[string]adapter.RuleSet
|
||||
}
|
||||
|
||||
func (r *fakeRouter) Start(adapter.StartStage) error { return nil }
|
||||
func (r *fakeRouter) Close() error { return nil }
|
||||
func (r *fakeRouter) PreMatch(metadata adapter.InboundContext, _ tun.DirectRouteContext, _ time.Duration, _ bool) (tun.DirectRouteDestination, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *fakeRouter) RouteConnection(context.Context, net.Conn, adapter.InboundContext) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *fakeRouter) RoutePacketConnection(context.Context, N.PacketConn, adapter.InboundContext) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *fakeRouter) RouteConnectionEx(context.Context, net.Conn, adapter.InboundContext, N.CloseHandlerFunc) {
|
||||
}
|
||||
|
||||
func (r *fakeRouter) RoutePacketConnectionEx(context.Context, N.PacketConn, adapter.InboundContext, N.CloseHandlerFunc) {
|
||||
}
|
||||
|
||||
func (r *fakeRouter) RuleSet(tag string) (adapter.RuleSet, bool) {
|
||||
ruleSet, loaded := r.ruleSets[tag]
|
||||
return ruleSet, loaded
|
||||
}
|
||||
func (r *fakeRouter) Rules() []adapter.Rule { return nil }
|
||||
func (r *fakeRouter) NeedFindProcess() bool { return false }
|
||||
func (r *fakeRouter) NeedFindNeighbor() bool { return false }
|
||||
func (r *fakeRouter) NeighborResolver() adapter.NeighborResolver { return nil }
|
||||
func (r *fakeRouter) AppendTracker(adapter.ConnectionTracker) {}
|
||||
func (r *fakeRouter) ResetNetwork() {}
|
||||
|
||||
type fakeRuleSet struct {
|
||||
metadata adapter.RuleSetMetadata
|
||||
callbacks []adapter.RuleSetUpdateCallback
|
||||
}
|
||||
|
||||
func (s *fakeRuleSet) Name() string { return "fake-rule-set" }
|
||||
func (s *fakeRuleSet) StartContext(context.Context, *adapter.HTTPStartContext) error { return nil }
|
||||
func (s *fakeRuleSet) PostStart() error { return nil }
|
||||
func (s *fakeRuleSet) Metadata() adapter.RuleSetMetadata { return s.metadata }
|
||||
func (s *fakeRuleSet) ExtractIPSet() []*netipx.IPSet { return nil }
|
||||
func (s *fakeRuleSet) IncRef() {}
|
||||
func (s *fakeRuleSet) DecRef() {}
|
||||
func (s *fakeRuleSet) Cleanup() {}
|
||||
func (s *fakeRuleSet) RegisterCallback(callback adapter.RuleSetUpdateCallback) *list.Element[adapter.RuleSetUpdateCallback] {
|
||||
s.callbacks = append(s.callbacks, callback)
|
||||
return nil
|
||||
}
|
||||
func (s *fakeRuleSet) UnregisterCallback(*list.Element[adapter.RuleSetUpdateCallback]) {}
|
||||
func (s *fakeRuleSet) Close() error { return nil }
|
||||
func (s *fakeRuleSet) Match(*adapter.InboundContext) bool { return true }
|
||||
func (s *fakeRuleSet) String() string { return "fake-rule-set" }
|
||||
func (s *fakeRuleSet) updateMetadata(metadata adapter.RuleSetMetadata) {
|
||||
s.metadata = metadata
|
||||
for _, callback := range s.callbacks {
|
||||
callback(s)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *fakeDeprecatedManager) ReportDeprecated(feature deprecated.Note) {
|
||||
m.features = append(m.features, feature)
|
||||
}
|
||||
@@ -108,18 +177,25 @@ func (c *fakeDNSClient) Lookup(_ context.Context, transport adapter.DNSTransport
|
||||
func (c *fakeDNSClient) ClearCache() {}
|
||||
|
||||
func newTestRouter(t *testing.T, rules []option.DNSRule, transportManager *fakeDNSTransportManager, client *fakeDNSClient) *Router {
|
||||
return newTestRouterWithContext(t, context.Background(), rules, transportManager, client)
|
||||
}
|
||||
|
||||
func newTestRouterWithContext(t *testing.T, ctx context.Context, rules []option.DNSRule, transportManager *fakeDNSTransportManager, client *fakeDNSClient) *Router {
|
||||
t.Helper()
|
||||
router := &Router{
|
||||
ctx: context.Background(),
|
||||
ctx: ctx,
|
||||
logger: log.NewNOPFactory().NewLogger("dns"),
|
||||
transport: transportManager,
|
||||
client: client,
|
||||
rawRules: make([]option.DNSRule, 0, len(rules)),
|
||||
rules: make([]adapter.DNSRule, 0, len(rules)),
|
||||
defaultDomainStrategy: C.DomainStrategyAsIS,
|
||||
}
|
||||
if rules != nil {
|
||||
err := router.Initialize(rules)
|
||||
require.NoError(t, err)
|
||||
err = router.Start(adapter.StartStateStart)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
return router
|
||||
}
|
||||
@@ -202,6 +278,187 @@ func TestValidateNewDNSRules_RequireMatchResponseForDirectIPCIDR(t *testing.T) {
|
||||
require.ErrorContains(t, err, "ip_cidr and ip_is_private require match_response")
|
||||
}
|
||||
|
||||
func TestStartNewModeRejectsDirectLegacyRuleWhenRuleSetForcesNew(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
ruleSet, err := rulepkg.NewRuleSet(ctx, log.NewNOPFactory().NewLogger("router"), option.RuleSet{
|
||||
Type: C.RuleSetTypeInline,
|
||||
Tag: "query-set",
|
||||
InlineOptions: option.PlainRuleSet{
|
||||
Rules: []option.HeadlessRule{{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultHeadlessRule{
|
||||
QueryType: badoption.Listable[option.DNSQueryType]{option.DNSQueryType(mDNS.TypeA)},
|
||||
},
|
||||
}},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
ctx = service.ContextWith[adapter.Router](ctx, &fakeRouter{
|
||||
ruleSets: map[string]adapter.RuleSet{
|
||||
"query-set": ruleSet,
|
||||
},
|
||||
})
|
||||
|
||||
router := &Router{
|
||||
ctx: ctx,
|
||||
logger: log.NewNOPFactory().NewLogger("dns"),
|
||||
transport: &fakeDNSTransportManager{},
|
||||
client: &fakeDNSClient{},
|
||||
rawRules: make([]option.DNSRule, 0, 2),
|
||||
rules: make([]adapter.DNSRule, 0, 2),
|
||||
defaultDomainStrategy: C.DomainStrategyAsIS,
|
||||
}
|
||||
err = router.Initialize([]option.DNSRule{
|
||||
{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||
RuleSet: badoption.Listable[string]{"query-set"},
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeRoute,
|
||||
RouteOptions: option.DNSRouteActionOptions{Server: "default"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||
IPIsPrivate: true,
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeRoute,
|
||||
RouteOptions: option.DNSRouteActionOptions{Server: "private"},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = router.Start(adapter.StartStateStart)
|
||||
require.ErrorContains(t, err, "ip_cidr and ip_is_private require match_response")
|
||||
}
|
||||
|
||||
func TestLookupLegacyModeDefersRuleSetDestinationIPMatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
ruleSet, err := rulepkg.NewRuleSet(ctx, log.NewNOPFactory().NewLogger("router"), option.RuleSet{
|
||||
Type: C.RuleSetTypeInline,
|
||||
Tag: "legacy-ipcidr-set",
|
||||
InlineOptions: option.PlainRuleSet{
|
||||
Rules: []option.HeadlessRule{{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultHeadlessRule{
|
||||
IPCIDR: badoption.Listable[string]{"10.0.0.0/8"},
|
||||
},
|
||||
}},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
ctx = service.ContextWith[adapter.Router](ctx, &fakeRouter{
|
||||
ruleSets: map[string]adapter.RuleSet{
|
||||
"legacy-ipcidr-set": ruleSet,
|
||||
},
|
||||
})
|
||||
|
||||
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
|
||||
privateTransport := &fakeDNSTransport{tag: "private", transportType: C.DNSTypeUDP}
|
||||
router := newTestRouterWithContext(t, ctx, []option.DNSRule{{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||
RuleSet: badoption.Listable[string]{"legacy-ipcidr-set"},
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeRoute,
|
||||
RouteOptions: option.DNSRouteActionOptions{Server: "private"},
|
||||
},
|
||||
},
|
||||
}}, &fakeDNSTransportManager{
|
||||
defaultTransport: defaultTransport,
|
||||
transports: map[string]adapter.DNSTransport{
|
||||
"default": defaultTransport,
|
||||
"private": privateTransport,
|
||||
},
|
||||
}, &fakeDNSClient{
|
||||
lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) {
|
||||
require.Equal(t, "example.com", domain)
|
||||
require.Equal(t, "private", transport.Tag())
|
||||
response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60)
|
||||
return MessageToAddresses(response), response, nil
|
||||
},
|
||||
})
|
||||
|
||||
require.True(t, router.legacyAddressFilterMode)
|
||||
|
||||
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{
|
||||
LookupStrategy: C.DomainStrategyIPv4Only,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses)
|
||||
}
|
||||
|
||||
func TestRuleSetUpdateSetsRuntimeErrorWhenRebuildFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fakeSet := &fakeRuleSet{}
|
||||
ctx := service.ContextWith[adapter.Router](context.Background(), &fakeRouter{
|
||||
ruleSets: map[string]adapter.RuleSet{
|
||||
"dynamic-set": fakeSet,
|
||||
},
|
||||
})
|
||||
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
|
||||
router := newTestRouterWithContext(t, ctx, []option.DNSRule{
|
||||
{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||
RuleSet: badoption.Listable[string]{"dynamic-set"},
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeRoute,
|
||||
RouteOptions: option.DNSRouteActionOptions{Server: "default"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||
IPIsPrivate: true,
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeRoute,
|
||||
RouteOptions: option.DNSRouteActionOptions{Server: "default"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, &fakeDNSTransportManager{
|
||||
defaultTransport: defaultTransport,
|
||||
transports: map[string]adapter.DNSTransport{
|
||||
"default": defaultTransport,
|
||||
},
|
||||
}, &fakeDNSClient{
|
||||
lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) {
|
||||
response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60)
|
||||
return MessageToAddresses(response), response, nil
|
||||
},
|
||||
})
|
||||
|
||||
require.True(t, router.legacyAddressFilterMode)
|
||||
|
||||
fakeSet.updateMetadata(adapter.RuleSetMetadata{
|
||||
ContainsDNSQueryTypeRule: true,
|
||||
})
|
||||
|
||||
_, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
|
||||
require.ErrorContains(t, err, "ip_cidr and ip_is_private require match_response")
|
||||
}
|
||||
|
||||
func TestLookupLegacyModeDefersDirectDestinationIPMatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -47,7 +47,12 @@ type legacyResponseConstraint struct {
|
||||
forbiddenSet *netipx.IPSet
|
||||
}
|
||||
|
||||
type legacyRuleMatchStateSet [16]legacyResponseFormula
|
||||
const (
|
||||
legacyRuleMatchDeferredDestinationAddress ruleMatchState = 1 << 4
|
||||
legacyRuleMatchStateCount = 32
|
||||
)
|
||||
|
||||
type legacyRuleMatchStateSet [legacyRuleMatchStateCount]legacyResponseFormula
|
||||
|
||||
var (
|
||||
legacyAllIPSet = func() *netipx.IPSet {
|
||||
@@ -350,7 +355,7 @@ func (s legacyRuleMatchStateSet) isEmpty() bool {
|
||||
|
||||
func (s legacyRuleMatchStateSet) merge(other legacyRuleMatchStateSet) legacyRuleMatchStateSet {
|
||||
var merged legacyRuleMatchStateSet
|
||||
for state := ruleMatchState(0); state < 16; state++ {
|
||||
for state := ruleMatchState(0); state < legacyRuleMatchStateCount; state++ {
|
||||
merged[state] = s[state].or(other[state])
|
||||
}
|
||||
return merged
|
||||
@@ -361,11 +366,11 @@ func (s legacyRuleMatchStateSet) combine(other legacyRuleMatchStateSet) legacyRu
|
||||
return legacyRuleMatchStateSet{}
|
||||
}
|
||||
var combined legacyRuleMatchStateSet
|
||||
for left := ruleMatchState(0); left < 16; left++ {
|
||||
for left := ruleMatchState(0); left < legacyRuleMatchStateCount; left++ {
|
||||
if s[left].isFalse() {
|
||||
continue
|
||||
}
|
||||
for right := ruleMatchState(0); right < 16; right++ {
|
||||
for right := ruleMatchState(0); right < legacyRuleMatchStateCount; right++ {
|
||||
if other[right].isFalse() {
|
||||
continue
|
||||
}
|
||||
@@ -380,7 +385,7 @@ func (s legacyRuleMatchStateSet) withBase(base ruleMatchState) legacyRuleMatchSt
|
||||
return legacyRuleMatchStateSet{}
|
||||
}
|
||||
var withBase legacyRuleMatchStateSet
|
||||
for state := ruleMatchState(0); state < 16; state++ {
|
||||
for state := ruleMatchState(0); state < legacyRuleMatchStateCount; state++ {
|
||||
if s[state].isFalse() {
|
||||
continue
|
||||
}
|
||||
@@ -391,7 +396,7 @@ func (s legacyRuleMatchStateSet) withBase(base ruleMatchState) legacyRuleMatchSt
|
||||
|
||||
func (s legacyRuleMatchStateSet) filter(allowed func(ruleMatchState) bool) legacyRuleMatchStateSet {
|
||||
var filtered legacyRuleMatchStateSet
|
||||
for state := ruleMatchState(0); state < 16; state++ {
|
||||
for state := ruleMatchState(0); state < legacyRuleMatchStateCount; state++ {
|
||||
if s[state].isFalse() {
|
||||
continue
|
||||
}
|
||||
@@ -404,7 +409,7 @@ func (s legacyRuleMatchStateSet) filter(allowed func(ruleMatchState) bool) legac
|
||||
|
||||
func (s legacyRuleMatchStateSet) addBit(bit ruleMatchState) legacyRuleMatchStateSet {
|
||||
var withBit legacyRuleMatchStateSet
|
||||
for state := ruleMatchState(0); state < 16; state++ {
|
||||
for state := ruleMatchState(0); state < legacyRuleMatchStateCount; state++ {
|
||||
if s[state].isFalse() {
|
||||
continue
|
||||
}
|
||||
@@ -422,7 +427,7 @@ func (s legacyRuleMatchStateSet) branchOnBit(bit ruleMatchState, condition legac
|
||||
}
|
||||
var branched legacyRuleMatchStateSet
|
||||
conditionFalse := condition.not()
|
||||
for state := ruleMatchState(0); state < 16; state++ {
|
||||
for state := ruleMatchState(0); state < legacyRuleMatchStateCount; state++ {
|
||||
if s[state].isFalse() {
|
||||
continue
|
||||
}
|
||||
@@ -444,7 +449,7 @@ func (s legacyRuleMatchStateSet) andFormula(formula legacyResponseFormula) legac
|
||||
return s
|
||||
}
|
||||
var result legacyRuleMatchStateSet
|
||||
for state := ruleMatchState(0); state < 16; state++ {
|
||||
for state := ruleMatchState(0); state < legacyRuleMatchStateCount; state++ {
|
||||
if s[state].isFalse() {
|
||||
continue
|
||||
}
|
||||
@@ -588,7 +593,7 @@ func (r *abstractDefaultRule) legacyMatchStatesWithBase(metadata *adapter.Inboun
|
||||
}
|
||||
if r.legacyDestinationIPCIDRMatchesDestination(metadata) {
|
||||
metadata.DidMatch = true
|
||||
stateSet = stateSet.branchOnBit(ruleMatchDestinationAddress, legacyDestinationIPFormula(r.destinationIPCIDRItems, metadata))
|
||||
stateSet = stateSet.branchOnBit(legacyRuleMatchDeferredDestinationAddress, legacyDestinationIPFormula(r.destinationIPCIDRItems, metadata))
|
||||
}
|
||||
if len(r.destinationPortItems) > 0 {
|
||||
metadata.DidMatch = true
|
||||
@@ -608,7 +613,7 @@ func (r *abstractDefaultRule) legacyMatchStatesWithBase(metadata *adapter.Inboun
|
||||
if r.ruleSetItem != nil {
|
||||
metadata.DidMatch = true
|
||||
var merged legacyRuleMatchStateSet
|
||||
for state := ruleMatchState(0); state < 16; state++ {
|
||||
for state := ruleMatchState(0); state < legacyRuleMatchStateCount; state++ {
|
||||
if stateSet[state].isFalse() {
|
||||
continue
|
||||
}
|
||||
@@ -627,6 +632,9 @@ func (r *abstractDefaultRule) legacyMatchStatesWithBase(metadata *adapter.Inboun
|
||||
if r.legacyRequiresDestinationAddressMatch(metadata) && !state.has(ruleMatchDestinationAddress) {
|
||||
return false
|
||||
}
|
||||
if r.legacyRequiresDeferredDestinationAddressMatch(metadata) && !state.has(legacyRuleMatchDeferredDestinationAddress) {
|
||||
return false
|
||||
}
|
||||
if len(r.destinationPortItems) > 0 && !state.has(ruleMatchDestinationPort) {
|
||||
return false
|
||||
}
|
||||
@@ -647,7 +655,11 @@ func (r *abstractDefaultRule) legacyDestinationIPCIDRMatchesDestination(metadata
|
||||
}
|
||||
|
||||
func (r *abstractDefaultRule) legacyRequiresDestinationAddressMatch(metadata *adapter.InboundContext) bool {
|
||||
return len(r.destinationAddressItems) > 0 || r.legacyDestinationIPCIDRMatchesDestination(metadata)
|
||||
return len(r.destinationAddressItems) > 0
|
||||
}
|
||||
|
||||
func (r *abstractDefaultRule) legacyRequiresDeferredDestinationAddressMatch(metadata *adapter.InboundContext) bool {
|
||||
return r.legacyDestinationIPCIDRMatchesDestination(metadata)
|
||||
}
|
||||
|
||||
func (r *abstractLogicalRule) legacyMatchStates(metadata *adapter.InboundContext) legacyRuleMatchStateSet {
|
||||
|
||||
@@ -69,3 +69,7 @@ func isWIFIHeadlessRule(rule option.DefaultHeadlessRule) bool {
|
||||
func isIPCIDRHeadlessRule(rule option.DefaultHeadlessRule) bool {
|
||||
return len(rule.IPCIDR) > 0 || rule.IPSet != nil
|
||||
}
|
||||
|
||||
func isDNSQueryTypeHeadlessRule(rule option.DefaultHeadlessRule) bool {
|
||||
return len(rule.QueryType) > 0
|
||||
}
|
||||
|
||||
@@ -141,6 +141,7 @@ func (s *LocalRuleSet) reloadRules(headlessRules []option.HeadlessRule) error {
|
||||
metadata.ContainsProcessRule = HasHeadlessRule(headlessRules, isProcessHeadlessRule)
|
||||
metadata.ContainsWIFIRule = HasHeadlessRule(headlessRules, isWIFIHeadlessRule)
|
||||
metadata.ContainsIPCIDRRule = HasHeadlessRule(headlessRules, isIPCIDRHeadlessRule)
|
||||
metadata.ContainsDNSQueryTypeRule = HasHeadlessRule(headlessRules, isDNSQueryTypeHeadlessRule)
|
||||
s.access.Lock()
|
||||
s.rules = rules
|
||||
s.metadata = metadata
|
||||
|
||||
@@ -193,6 +193,7 @@ func (s *RemoteRuleSet) loadBytes(content []byte) error {
|
||||
s.metadata.ContainsProcessRule = HasHeadlessRule(plainRuleSet.Rules, isProcessHeadlessRule)
|
||||
s.metadata.ContainsWIFIRule = HasHeadlessRule(plainRuleSet.Rules, isWIFIHeadlessRule)
|
||||
s.metadata.ContainsIPCIDRRule = HasHeadlessRule(plainRuleSet.Rules, isIPCIDRHeadlessRule)
|
||||
s.metadata.ContainsDNSQueryTypeRule = HasHeadlessRule(plainRuleSet.Rules, isDNSQueryTypeHeadlessRule)
|
||||
s.rules = rules
|
||||
callbacks := s.callbacks.Array()
|
||||
s.access.Unlock()
|
||||
|
||||
Reference in New Issue
Block a user