From 2d08d34e0bf0bf0f9101d0f4e2388115b5b865a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 15:23:50 +0800 Subject: [PATCH] Add evaluate DNS rule action and related rule items --- adapter/inbound.go | 3 + constant/rule.go | 1 + dns/router.go | 361 ++++++++++++++++++++++-- dns/router_test.go | 338 ++++++++++++++++++++++ experimental/deprecated/constants.go | 27 ++ option/dns_record.go | 25 +- option/rule_action.go | 4 + option/rule_dns.go | 5 + route/rule/rule_action.go | 41 ++- route/rule/rule_dns.go | 63 ++++- route/rule/rule_item_response_rcode.go | 24 ++ route/rule/rule_item_response_record.go | 63 +++++ 12 files changed, 913 insertions(+), 42 deletions(-) create mode 100644 dns/router_test.go create mode 100644 route/rule/rule_item_response_rcode.go create mode 100644 route/rule/rule_item_response_record.go diff --git a/adapter/inbound.go b/adapter/inbound.go index 52af336e5..903ab9f9b 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -10,6 +10,8 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" M "github.com/sagernet/sing/common/metadata" + + "github.com/miekg/dns" ) type Inbound interface { @@ -80,6 +82,7 @@ type InboundContext struct { FallbackDelay time.Duration DestinationAddresses []netip.Addr + DNSResponse *dns.Msg SourceGeoIPCode string GeoIPCode string ProcessInfo *ConnectionOwner diff --git a/constant/rule.go b/constant/rule.go index 55cad2e13..2a5aaefda 100644 --- a/constant/rule.go +++ b/constant/rule.go @@ -29,6 +29,7 @@ const ( const ( RuleActionTypeRoute = "route" RuleActionTypeRouteOptions = "route-options" + RuleActionTypeEvaluate = "evaluate" RuleActionTypeDirect = "direct" RuleActionTypeBypass = "bypass" RuleActionTypeReject = "reject" diff --git a/dns/router.go b/dns/router.go index 4f18959b7..8e552bd52 100644 --- a/dns/router.go +++ b/dns/router.go @@ -10,6 +10,7 @@ import ( "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/taskmonitor" C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/experimental/deprecated" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" R "github.com/sagernet/sing-box/route/rule" @@ -19,6 +20,7 @@ import ( F "github.com/sagernet/sing/common/format" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/task" "github.com/sagernet/sing/contrab/freelru" "github.com/sagernet/sing/contrab/maphash" "github.com/sagernet/sing/service" @@ -29,15 +31,16 @@ import ( var _ adapter.DNSRouter = (*Router)(nil) type Router struct { - ctx context.Context - logger logger.ContextLogger - transport adapter.DNSTransportManager - outbound adapter.OutboundManager - client adapter.DNSClient - rules []adapter.DNSRule - defaultDomainStrategy C.DomainStrategy - dnsReverseMapping freelru.Cache[netip.Addr, string] - platformInterface adapter.PlatformInterface + ctx context.Context + logger logger.ContextLogger + transport adapter.DNSTransportManager + outbound adapter.OutboundManager + client adapter.DNSClient + rules []adapter.DNSRule + defaultDomainStrategy C.DomainStrategy + dnsReverseMapping freelru.Cache[netip.Addr, string] + platformInterface adapter.PlatformInterface + legacyAddressFilterMode bool } func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOptions) *Router { @@ -74,8 +77,15 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOp } func (r *Router) Initialize(rules []option.DNSRule) error { + r.legacyAddressFilterMode = !hasNonLegacyAddressFilterItems(rules) + if !r.legacyAddressFilterMode { + err := validateNonLegacyAddressFilterRules(rules) + if err != nil { + return err + } + } for i, ruleOptions := range rules { - dnsRule, err := R.NewDNSRule(r.ctx, r.logger, ruleOptions, true) + dnsRule, err := R.NewDNSRule(r.ctx, r.logger, ruleOptions, true, r.legacyAddressFilterMode) if err != nil { return E.Cause(err, "parse dns rule[", i, "]") } @@ -100,6 +110,9 @@ func (r *Router) Start(stage adapter.StartStage) error { return E.Cause(err, "initialize DNS rule[", i, "]") } } + if r.legacyAddressFilterMode && common.Any(r.rules, func(rule adapter.DNSRule) bool { return rule.WithAddressLimit() }) { + deprecated.Report(r.ctx, deprecated.OptionLegacyDNSAddressFilter) + } } return nil } @@ -207,6 +220,209 @@ func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int, return transport, nil, -1 } +func (r *Router) applyTransportDefaults(transport adapter.DNSTransport, options *adapter.DNSQueryOptions) { + if legacyTransport, isLegacy := transport.(adapter.LegacyDNSTransport); isLegacy { + if options.Strategy == C.DomainStrategyAsIS { + options.Strategy = legacyTransport.LegacyStrategy() + } + if !options.ClientSubnet.IsValid() { + options.ClientSubnet = legacyTransport.LegacyClientSubnet() + } + } +} + +func (r *Router) applyDNSRouteOptions(options *adapter.DNSQueryOptions, routeOptions R.RuleActionDNSRouteOptions) { + if routeOptions.Strategy != C.DomainStrategyAsIS { + options.Strategy = routeOptions.Strategy + } + if routeOptions.DisableCache { + options.DisableCache = true + } + if routeOptions.RewriteTTL != nil { + options.RewriteTTL = routeOptions.RewriteTTL + } + if routeOptions.ClientSubnet.IsValid() { + options.ClientSubnet = routeOptions.ClientSubnet + } +} + +func (r *Router) resolveDNSRoute(action *R.RuleActionDNSRoute, allowFakeIP bool, options *adapter.DNSQueryOptions) (adapter.DNSTransport, bool) { + transport, loaded := r.transport.Transport(action.Server) + if !loaded { + return nil, false + } + isFakeIP := transport.Type() == C.DNSTypeFakeIP + if isFakeIP && !allowFakeIP { + return transport, false + } + r.applyDNSRouteOptions(options, action.RuleActionDNSRouteOptions) + if isFakeIP { + options.DisableCache = true + } + r.applyTransportDefaults(transport, options) + return transport, true +} + +func (r *Router) logRuleMatch(ctx context.Context, ruleIndex int, currentRule adapter.DNSRule) { + displayRuleIndex := ruleIndex + if displayRuleIndex != -1 { + displayRuleIndex += displayRuleIndex + 1 + } + if ruleDescription := currentRule.String(); ruleDescription != "" { + r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] ", currentRule, " => ", currentRule.Action()) + } else { + r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] => ", currentRule.Action()) + } +} + +func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, options adapter.DNSQueryOptions, allowFakeIP bool) (*mDNS.Msg, adapter.DNSTransport, error) { + metadata := adapter.ContextFrom(ctx) + if metadata == nil { + panic("no context") + } + effectiveOptions := options + var savedResponse *mDNS.Msg + for currentRuleIndex, currentRule := range r.rules { + metadata.ResetRuleCache() + metadata.DNSResponse = savedResponse + metadata.DestinationAddresses = MessageToAddresses(savedResponse) + if !currentRule.Match(metadata) { + continue + } + r.logRuleMatch(ctx, currentRuleIndex, currentRule) + switch action := currentRule.Action().(type) { + case *R.RuleActionDNSRouteOptions: + r.applyDNSRouteOptions(&effectiveOptions, *action) + case *R.RuleActionEvaluate: + queryOptions := effectiveOptions + transport, loaded := r.resolveDNSRoute(&R.RuleActionDNSRoute{ + Server: action.Server, + RuleActionDNSRouteOptions: action.RuleActionDNSRouteOptions, + }, allowFakeIP, &queryOptions) + if !loaded { + if transport == nil { + r.logger.ErrorContext(ctx, "transport not found: ", action.Server) + } + continue + } + if queryOptions.Strategy == C.DomainStrategyAsIS { + queryOptions.Strategy = r.defaultDomainStrategy + } + response, err := r.client.Exchange(adapter.OverrideContext(ctx), transport, message, queryOptions, nil) + if err != nil { + r.logger.ErrorContext(ctx, E.Cause(err, "exchange failed for ", FormatQuestion(message.Question[0].String()))) + savedResponse = nil + continue + } + savedResponse = response + case *R.RuleActionDNSRoute: + queryOptions := effectiveOptions + transport, loaded := r.resolveDNSRoute(action, allowFakeIP, &queryOptions) + if !loaded { + if transport == nil { + r.logger.ErrorContext(ctx, "transport not found: ", action.Server) + } + continue + } + if queryOptions.Strategy == C.DomainStrategyAsIS { + queryOptions.Strategy = r.defaultDomainStrategy + } + response, err := r.client.Exchange(adapter.OverrideContext(ctx), transport, message, queryOptions, nil) + return response, transport, err + case *R.RuleActionReject: + switch action.Method { + case C.RuleActionRejectMethodDefault: + return &mDNS.Msg{ + MsgHdr: mDNS.MsgHdr{ + Id: message.Id, + Rcode: mDNS.RcodeRefused, + Response: true, + }, + Question: []mDNS.Question{message.Question[0]}, + }, nil, nil + case C.RuleActionRejectMethodDrop: + return nil, nil, tun.ErrDrop + } + case *R.RuleActionPredefined: + return action.Response(message), nil, nil + } + } + queryOptions := effectiveOptions + transport := r.transport.Default() + r.applyTransportDefaults(transport, &queryOptions) + if queryOptions.Strategy == C.DomainStrategyAsIS { + queryOptions.Strategy = r.defaultDomainStrategy + } + response, err := r.client.Exchange(adapter.OverrideContext(ctx), transport, message, queryOptions, nil) + return response, transport, err +} + +func (r *Router) lookupWithRules(ctx context.Context, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, error) { + var strategy C.DomainStrategy + if options.LookupStrategy != C.DomainStrategyAsIS { + strategy = options.LookupStrategy + } else { + strategy = options.Strategy + } + lookupOptions := options + if options.LookupStrategy != C.DomainStrategyAsIS { + lookupOptions.Strategy = strategy + } + if strategy == C.DomainStrategyIPv4Only { + return r.lookupWithRulesType(ctx, domain, mDNS.TypeA, lookupOptions) + } + if strategy == C.DomainStrategyIPv6Only { + return r.lookupWithRulesType(ctx, domain, mDNS.TypeAAAA, lookupOptions) + } + var ( + response4 []netip.Addr + response6 []netip.Addr + ) + var group task.Group + group.Append("exchange4", func(ctx context.Context) error { + response, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeA, lookupOptions) + if err != nil { + return err + } + response4 = response + return nil + }) + group.Append("exchange6", func(ctx context.Context) error { + response, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeAAAA, lookupOptions) + if err != nil { + return err + } + response6 = response + return nil + }) + err := group.Run(ctx) + if len(response4) == 0 && len(response6) == 0 { + return nil, err + } + return sortAddresses(response4, response6, strategy), nil +} + +func (r *Router) lookupWithRulesType(ctx context.Context, domain string, qType uint16, options adapter.DNSQueryOptions) ([]netip.Addr, error) { + request := &mDNS.Msg{ + MsgHdr: mDNS.MsgHdr{ + RecursionDesired: true, + }, + Question: []mDNS.Question{{ + Name: mDNS.Fqdn(FqdnToDomain(domain)), + Qtype: qType, + Qclass: mDNS.ClassINET, + }}, + } + response, _, err := r.exchangeWithRules(adapter.OverrideContext(ctx), request, options, false) + if err != nil { + return nil, err + } + if response.Rcode != mDNS.RcodeSuccess { + return nil, RcodeError(response.Rcode) + } + return MessageToAddresses(response), nil +} + func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapter.DNSQueryOptions) (*mDNS.Msg, error) { if len(message.Question) != 1 { r.logger.WarnContext(ctx, "bad question size: ", len(message.Question)) @@ -239,18 +455,13 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte metadata.Domain = FqdnToDomain(message.Question[0].Name) if options.Transport != nil { transport = options.Transport - if legacyTransport, isLegacy := transport.(adapter.LegacyDNSTransport); isLegacy { - if options.Strategy == C.DomainStrategyAsIS { - options.Strategy = legacyTransport.LegacyStrategy() - } - if !options.ClientSubnet.IsValid() { - options.ClientSubnet = legacyTransport.LegacyClientSubnet() - } - } + r.applyTransportDefaults(transport, &options) if options.Strategy == C.DomainStrategyAsIS { options.Strategy = r.defaultDomainStrategy } response, err = r.client.Exchange(ctx, transport, message, options, nil) + } else if !r.legacyAddressFilterMode { + response, transport, err = r.exchangeWithRules(ctx, message, options, true) } else { var ( rule adapter.DNSRule @@ -352,18 +563,13 @@ func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQ metadata.Domain = FqdnToDomain(domain) if options.Transport != nil { transport := options.Transport - if legacyTransport, isLegacy := transport.(adapter.LegacyDNSTransport); isLegacy { - if options.Strategy == C.DomainStrategyAsIS { - options.Strategy = legacyTransport.LegacyStrategy() - } - if !options.ClientSubnet.IsValid() { - options.ClientSubnet = legacyTransport.LegacyClientSubnet() - } - } + r.applyTransportDefaults(transport, &options) if options.Strategy == C.DomainStrategyAsIS { options.Strategy = r.defaultDomainStrategy } responseAddrs, err = r.client.Lookup(ctx, transport, domain, options, nil) + } else if !r.legacyAddressFilterMode { + responseAddrs, err = r.lookupWithRules(ctx, domain, options) } else { var ( transport adapter.DNSTransport @@ -458,3 +664,106 @@ func (r *Router) ResetNetwork() { transport.Reset() } } + +func hasNonLegacyAddressFilterItems(rules []option.DNSRule) bool { + return common.Any(rules, hasNonLegacyAddressFilterItemsInRule) +} + +func hasNonLegacyAddressFilterItemsInRule(rule option.DNSRule) bool { + switch rule.Type { + case "", C.RuleTypeDefault: + return hasNonLegacyAddressFilterItemsInDefaultRule(rule.DefaultOptions) + case C.RuleTypeLogical: + action := rule.LogicalOptions.Action + return action == C.RuleActionTypeEvaluate || common.Any(rule.LogicalOptions.Rules, hasNonLegacyAddressFilterItemsInRule) + default: + return false + } +} + +func hasNonLegacyAddressFilterItemsInDefaultRule(rule option.DefaultDNSRule) bool { + action := rule.Action + return action == C.RuleActionTypeEvaluate || + rule.MatchResponse || + rule.ResponseRcode != nil || + len(rule.ResponseAnswer) > 0 || + len(rule.ResponseNs) > 0 || + len(rule.ResponseExtra) > 0 +} + +func validateNonLegacyAddressFilterRules(rules []option.DNSRule) error { + var seenEvaluate bool + for i, rule := range rules { + consumesResponse, err := validateNonLegacyAddressFilterRuleTree(rule) + if err != nil { + return E.Cause(err, "validate dns rule[", i, "]") + } + action := dnsRuleActionType(rule) + if action == C.RuleActionTypeEvaluate && consumesResponse { + return E.New("dns rule[", i, "]: evaluate rule cannot consume response state") + } + if consumesResponse && !seenEvaluate { + return E.New("dns rule[", i, "]: response matching requires a preceding top-level evaluate rule") + } + if action == C.RuleActionTypeEvaluate { + seenEvaluate = true + } + } + return nil +} + +func validateNonLegacyAddressFilterRuleTree(rule option.DNSRule) (bool, error) { + switch rule.Type { + case "", C.RuleTypeDefault: + return validateNonLegacyAddressFilterDefaultRule(rule.DefaultOptions) + case C.RuleTypeLogical: + var consumesResponse bool + for i, subRule := range rule.LogicalOptions.Rules { + subConsumesResponse, err := validateNonLegacyAddressFilterRuleTree(subRule) + if err != nil { + return false, E.Cause(err, "sub rule[", i, "]") + } + consumesResponse = consumesResponse || subConsumesResponse + } + return consumesResponse, nil + default: + return false, nil + } +} + +func validateNonLegacyAddressFilterDefaultRule(rule option.DefaultDNSRule) (bool, error) { + hasResponseRecords := rule.ResponseRcode != nil || + len(rule.ResponseAnswer) > 0 || + len(rule.ResponseNs) > 0 || + len(rule.ResponseExtra) > 0 + if hasResponseRecords && !rule.MatchResponse { + return false, E.New("response_* items require match_response") + } + if (len(rule.IPCIDR) > 0 || rule.IPIsPrivate) && !rule.MatchResponse { + return false, E.New("ip_cidr and ip_is_private require match_response in DNS evaluate mode") + } + if rule.IPAcceptAny { + return false, E.New("ip_accept_any is removed in DNS evaluate mode, use ip_cidr with match_response") + } + if rule.RuleSetIPCIDRAcceptEmpty { + return false, E.New("rule_set_ip_cidr_accept_empty is removed in DNS evaluate mode") + } + return rule.MatchResponse, nil +} + +func dnsRuleActionType(rule option.DNSRule) string { + switch rule.Type { + case "", C.RuleTypeDefault: + if rule.DefaultOptions.Action == "" { + return C.RuleActionTypeRoute + } + return rule.DefaultOptions.Action + case C.RuleTypeLogical: + if rule.LogicalOptions.Action == "" { + return C.RuleActionTypeRoute + } + return rule.LogicalOptions.Action + default: + return "" + } +} diff --git a/dns/router_test.go b/dns/router_test.go new file mode 100644 index 000000000..d7deb848f --- /dev/null +++ b/dns/router_test.go @@ -0,0 +1,338 @@ +package dns + +import ( + "context" + "errors" + "net/netip" + "testing" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/experimental/deprecated" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common/json/badoption" + "github.com/sagernet/sing/service" + + mDNS "github.com/miekg/dns" + "github.com/stretchr/testify/require" +) + +type fakeDNSTransport struct { + tag string + transportType string +} + +func (t *fakeDNSTransport) Start(adapter.StartStage) error { return nil } +func (t *fakeDNSTransport) Close() error { return nil } +func (t *fakeDNSTransport) Type() string { return t.transportType } +func (t *fakeDNSTransport) Tag() string { return t.tag } +func (t *fakeDNSTransport) Dependencies() []string { return nil } +func (t *fakeDNSTransport) Reset() {} +func (t *fakeDNSTransport) Exchange(context.Context, *mDNS.Msg) (*mDNS.Msg, error) { + return nil, errors.New("unused transport exchange") +} + +type fakeDNSTransportManager struct { + defaultTransport adapter.DNSTransport + transports map[string]adapter.DNSTransport +} + +func (m *fakeDNSTransportManager) Start(adapter.StartStage) error { return nil } +func (m *fakeDNSTransportManager) Close() error { return nil } +func (m *fakeDNSTransportManager) Transports() []adapter.DNSTransport { + transports := make([]adapter.DNSTransport, 0, len(m.transports)) + for _, transport := range m.transports { + transports = append(transports, transport) + } + return transports +} + +func (m *fakeDNSTransportManager) Transport(tag string) (adapter.DNSTransport, bool) { + transport, loaded := m.transports[tag] + return transport, loaded +} +func (m *fakeDNSTransportManager) Default() adapter.DNSTransport { return m.defaultTransport } +func (m *fakeDNSTransportManager) FakeIP() adapter.FakeIPTransport { + return nil +} +func (m *fakeDNSTransportManager) Remove(string) error { return nil } +func (m *fakeDNSTransportManager) Create(context.Context, log.ContextLogger, string, string, any) error { + return errors.New("unsupported") +} + +type fakeDNSClient struct { + exchange func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) +} + +type fakeDeprecatedManager struct { + features []deprecated.Note +} + +func (m *fakeDeprecatedManager) ReportDeprecated(feature deprecated.Note) { + m.features = append(m.features, feature) +} + +func (c *fakeDNSClient) Start() {} + +func (c *fakeDNSClient) Exchange(_ context.Context, transport adapter.DNSTransport, message *mDNS.Msg, _ adapter.DNSQueryOptions, _ func([]netip.Addr) bool) (*mDNS.Msg, error) { + return c.exchange(transport, message) +} + +func (c *fakeDNSClient) Lookup(context.Context, adapter.DNSTransport, string, adapter.DNSQueryOptions, func([]netip.Addr) bool) ([]netip.Addr, error) { + return nil, errors.New("unused client lookup") +} + +func (c *fakeDNSClient) ClearCache() {} + +func newTestRouter(t *testing.T, rules []option.DNSRule, transportManager *fakeDNSTransportManager, client *fakeDNSClient) *Router { + t.Helper() + router := &Router{ + ctx: context.Background(), + logger: log.NewNOPFactory().NewLogger("dns"), + transport: transportManager, + client: client, + rules: make([]adapter.DNSRule, 0, len(rules)), + defaultDomainStrategy: C.DomainStrategyAsIS, + } + if rules != nil { + err := router.Initialize(rules) + require.NoError(t, err) + } + return router +} + +func fixedQuestion(name string, qType uint16) mDNS.Question { + return mDNS.Question{ + Name: mDNS.Fqdn(name), + Qtype: qType, + Qclass: mDNS.ClassINET, + } +} + +func mustRecord(t *testing.T, record string) option.DNSRecordOptions { + t.Helper() + var value option.DNSRecordOptions + require.NoError(t, value.UnmarshalJSON([]byte(`"`+record+`"`))) + return value +} + +func TestValidateNewDNSRules_RequireMatchResponseForDirectIPCIDR(t *testing.T) { + t.Parallel() + + err := validateNonLegacyAddressFilterRules([]option.DNSRule{{ + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + IPCIDR: badoption.Listable[string]{"1.1.1.0/24"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{ + Server: "default", + }, + }, + }, + }}) + require.ErrorContains(t, err, "ip_cidr and ip_is_private require match_response") +} + +func TestExchangeNewModeEvaluateMatchResponseRoute(t *testing.T) { + t.Parallel() + + transportManager := &fakeDNSTransportManager{ + defaultTransport: &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}, + transports: map[string]adapter.DNSTransport{ + "upstream": &fakeDNSTransport{tag: "upstream", transportType: C.DNSTypeUDP}, + "selected": &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP}, + "default": &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}, + }, + } + client := &fakeDNSClient{ + exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) { + switch transport.Tag() { + case "upstream": + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("1.1.1.1")}, 60), nil + case "selected": + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60), nil + default: + return nil, errors.New("unexpected transport") + } + }, + } + rules := []option.DNSRule{ + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + Domain: badoption.Listable[string]{"example.com"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeEvaluate, + RouteOptions: option.DNSRouteActionOptions{Server: "upstream"}, + }, + }, + }, + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + MatchResponse: true, + ResponseAnswer: badoption.Listable[option.DNSRecordOptions]{mustRecord(t, "example.com. IN A 1.1.1.1")}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "selected"}, + }, + }, + }, + } + router := newTestRouter(t, rules, transportManager, client) + + response, err := router.Exchange(context.Background(), &mDNS.Msg{ + Question: []mDNS.Question{fixedQuestion("example.com", mDNS.TypeA)}, + }, adapter.DNSQueryOptions{}) + require.NoError(t, err) + require.Equal(t, []netip.Addr{netip.MustParseAddr("8.8.8.8")}, MessageToAddresses(response)) +} + +func TestLookupNewModeAllowsPartialSuccess(t *testing.T) { + t.Parallel() + + defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} + router := newTestRouter(t, nil, &fakeDNSTransportManager{ + defaultTransport: defaultTransport, + transports: map[string]adapter.DNSTransport{ + "default": defaultTransport, + }, + }, &fakeDNSClient{ + exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) { + require.Equal(t, "default", transport.Tag()) + switch message.Question[0].Qtype { + case mDNS.TypeA: + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("1.1.1.1")}, 60), nil + case mDNS.TypeAAAA: + return nil, errors.New("ipv6 failed") + default: + return nil, errors.New("unexpected qtype") + } + }, + }) + router.legacyAddressFilterMode = false + + addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) + require.NoError(t, err) + require.Equal(t, []netip.Addr{netip.MustParseAddr("1.1.1.1")}, addresses) +} + +func TestLookupNewModeSkipsFakeIPRule(t *testing.T) { + t.Parallel() + + defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} + router := newTestRouter(t, []option.DNSRule{{ + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + Domain: badoption.Listable[string]{"example.com"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "fake"}, + }, + }, + }}, &fakeDNSTransportManager{ + defaultTransport: defaultTransport, + transports: map[string]adapter.DNSTransport{ + "default": defaultTransport, + "fake": &fakeDNSTransport{tag: "fake", transportType: C.DNSTypeFakeIP}, + }, + }, &fakeDNSClient{ + exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) { + require.Equal(t, "default", transport.Tag()) + if message.Question[0].Qtype == mDNS.TypeA { + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2.2.2.2")}, 60), nil + } + return FixedResponse(0, message.Question[0], nil, 60), nil + }, + }) + router.legacyAddressFilterMode = false + + addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) + require.NoError(t, err) + require.Equal(t, []netip.Addr{netip.MustParseAddr("2.2.2.2")}, addresses) +} + +func TestLookupNewModeDoesNotUseQueryTypeRule(t *testing.T) { + t.Parallel() + + defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} + router := newTestRouter(t, []option.DNSRule{{ + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + QueryType: badoption.Listable[option.DNSQueryType]{option.DNSQueryType(mDNS.TypeA)}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "only-a"}, + }, + }, + }}, &fakeDNSTransportManager{ + defaultTransport: defaultTransport, + transports: map[string]adapter.DNSTransport{ + "default": defaultTransport, + "only-a": &fakeDNSTransport{tag: "only-a", transportType: C.DNSTypeUDP}, + }, + }, &fakeDNSClient{ + exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) { + switch transport.Tag() { + case "default": + if message.Question[0].Qtype == mDNS.TypeA { + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("3.3.3.3")}, 60), nil + } + return FixedResponse(0, message.Question[0], nil, 60), nil + case "only-a": + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("9.9.9.9")}, 60), nil + default: + return nil, errors.New("unexpected transport") + } + }, + }) + router.legacyAddressFilterMode = false + + addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) + require.NoError(t, err) + require.Equal(t, []netip.Addr{netip.MustParseAddr("3.3.3.3")}, addresses) +} + +func TestOldModeReportsLegacyAddressFilterDeprecation(t *testing.T) { + t.Parallel() + + manager := &fakeDeprecatedManager{} + ctx := service.ContextWith[deprecated.Manager](context.Background(), manager) + router := &Router{ + ctx: ctx, + logger: log.NewNOPFactory().NewLogger("dns"), + client: &fakeDNSClient{}, + rules: make([]adapter.DNSRule, 0, 1), + defaultDomainStrategy: C.DomainStrategyAsIS, + } + err := router.Initialize([]option.DNSRule{{ + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + IPCIDR: badoption.Listable[string]{"1.1.1.0/24"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "default"}, + }, + }, + }}) + require.NoError(t, err) + + err = router.Start(adapter.StartStateStart) + require.NoError(t, err) + require.Len(t, manager.features, 1) + require.Equal(t, deprecated.OptionLegacyDNSAddressFilter.Name, manager.features[0].Name) +} diff --git a/experimental/deprecated/constants.go b/experimental/deprecated/constants.go index 3526cda83..d4d12cc72 100644 --- a/experimental/deprecated/constants.go +++ b/experimental/deprecated/constants.go @@ -111,6 +111,30 @@ var OptionInlineACME = Note{ MigrationLink: "https://sing-box.sagernet.org/migration/#migrate-inline-acme-to-certificate-provider", } +var OptionIPAcceptAny = Note{ + Name: "dns-rule-ip-accept-any", + Description: "`ip_accept_any` in DNS rules", + DeprecatedVersion: "1.14.0", + ScheduledVersion: "1.16.0", + MigrationLink: "https://sing-box.sagernet.org/configuration/dns/rule/", +} + +var OptionRuleSetIPCIDRAcceptEmpty = Note{ + Name: "dns-rule-rule-set-ip-cidr-accept-empty", + Description: "`rule_set_ip_cidr_accept_empty` in DNS rules", + DeprecatedVersion: "1.14.0", + ScheduledVersion: "1.16.0", + MigrationLink: "https://sing-box.sagernet.org/configuration/dns/rule/", +} + +var OptionLegacyDNSAddressFilter = Note{ + Name: "legacy-dns-address-filter", + Description: "legacy address filter DNS rule items", + DeprecatedVersion: "1.14.0", + ScheduledVersion: "1.16.0", + MigrationLink: "https://sing-box.sagernet.org/configuration/dns/rule/", +} + var Options = []Note{ OptionLegacyDNSTransport, OptionLegacyDNSFakeIPOptions, @@ -118,4 +142,7 @@ var Options = []Note{ OptionMissingDomainResolver, OptionLegacyDomainStrategyOptions, OptionInlineACME, + OptionIPAcceptAny, + OptionRuleSetIPCIDRAcceptEmpty, + OptionLegacyDNSAddressFilter, } diff --git a/option/dns_record.go b/option/dns_record.go index fa72b61b7..c51341f1d 100644 --- a/option/dns_record.go +++ b/option/dns_record.go @@ -2,6 +2,7 @@ package option import ( "encoding/base64" + "strings" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" @@ -51,6 +52,7 @@ func (r *DNSRCode) Build() int { type DNSRecordOptions struct { dns.RR fromBase64 bool + hasTTL bool } func (o DNSRecordOptions) MarshalJSON() ([]byte, error) { @@ -76,7 +78,16 @@ func (o *DNSRecordOptions) UnmarshalJSON(data []byte) error { if err == nil { return o.unmarshalBase64(binary) } - record, err := dns.NewRR(stringValue) + parser := dns.NewZoneParser(strings.NewReader(stringValue+"\n"), "", "") + record, ok := parser.Next() + if !ok { + err = parser.Err() + if err == nil { + err = E.New("empty DNS record") + } + return err + } + err = parser.Err() if err != nil { return err } @@ -84,6 +95,7 @@ func (o *DNSRecordOptions) UnmarshalJSON(data []byte) error { a.A = M.AddrFromIP(a.A).Unmap().AsSlice() } o.RR = record + o.hasTTL = record.Header().Ttl != 0 return nil } @@ -94,9 +106,20 @@ func (o *DNSRecordOptions) unmarshalBase64(binary []byte) error { } o.RR = record o.fromBase64 = true + o.hasTTL = true return nil } func (o DNSRecordOptions) Build() dns.RR { return o.RR } + +func (o DNSRecordOptions) Match(record dns.RR) bool { + if o.RR == nil || record == nil { + return false + } + if o.hasTTL { + return o.RR.String() == record.String() + } + return dns.IsDuplicate(o.RR, record) +} diff --git a/option/rule_action.go b/option/rule_action.go index 431082552..027e80076 100644 --- a/option/rule_action.go +++ b/option/rule_action.go @@ -115,6 +115,8 @@ func (r DNSRuleAction) MarshalJSON() ([]byte, error) { case C.RuleActionTypeRoute: r.Action = "" v = r.RouteOptions + case C.RuleActionTypeEvaluate: + v = r.RouteOptions case C.RuleActionTypeRouteOptions: v = r.RouteOptionsOptions case C.RuleActionTypeReject: @@ -137,6 +139,8 @@ func (r *DNSRuleAction) UnmarshalJSONContext(ctx context.Context, data []byte) e case "", C.RuleActionTypeRoute: r.Action = C.RuleActionTypeRoute v = &r.RouteOptions + case C.RuleActionTypeEvaluate: + v = &r.RouteOptions case C.RuleActionTypeRouteOptions: v = &r.RouteOptionsOptions case C.RuleActionTypeReject: diff --git a/option/rule_dns.go b/option/rule_dns.go index 880b96ac5..bba6a0db6 100644 --- a/option/rule_dns.go +++ b/option/rule_dns.go @@ -84,6 +84,11 @@ type RawDefaultDNSRule struct { IPCIDR badoption.Listable[string] `json:"ip_cidr,omitempty"` IPIsPrivate bool `json:"ip_is_private,omitempty"` IPAcceptAny bool `json:"ip_accept_any,omitempty"` + MatchResponse bool `json:"match_response,omitempty"` + ResponseRcode *DNSRCode `json:"response_rcode,omitempty"` + ResponseAnswer badoption.Listable[DNSRecordOptions] `json:"response_answer,omitempty"` + ResponseNs badoption.Listable[DNSRecordOptions] `json:"response_ns,omitempty"` + ResponseExtra badoption.Listable[DNSRecordOptions] `json:"response_extra,omitempty"` SourceIPCIDR badoption.Listable[string] `json:"source_ip_cidr,omitempty"` SourceIPIsPrivate bool `json:"source_ip_is_private,omitempty"` SourcePort badoption.Listable[uint16] `json:"source_port,omitempty"` diff --git a/route/rule/rule_action.go b/route/rule/rule_action.go index cac814e76..9f8ef945e 100644 --- a/route/rule/rule_action.go +++ b/route/rule/rule_action.go @@ -132,6 +132,16 @@ func NewDNSRuleAction(logger logger.ContextLogger, action option.DNSRuleAction) ClientSubnet: netip.Prefix(common.PtrValueOrDefault(action.RouteOptions.ClientSubnet)), }, } + case C.RuleActionTypeEvaluate: + return &RuleActionEvaluate{ + Server: action.RouteOptions.Server, + RuleActionDNSRouteOptions: RuleActionDNSRouteOptions{ + Strategy: C.DomainStrategy(action.RouteOptions.Strategy), + DisableCache: action.RouteOptions.DisableCache, + RewriteTTL: action.RouteOptions.RewriteTTL, + ClientSubnet: netip.Prefix(common.PtrValueOrDefault(action.RouteOptions.ClientSubnet)), + }, + } case C.RuleActionTypeRouteOptions: return &RuleActionDNSRouteOptions{ Strategy: C.DomainStrategy(action.RouteOptionsOptions.Strategy), @@ -266,18 +276,35 @@ func (r *RuleActionDNSRoute) Type() string { } func (r *RuleActionDNSRoute) String() string { + return formatDNSRouteAction("route", r.Server, r.RuleActionDNSRouteOptions) +} + +type RuleActionEvaluate struct { + Server string + RuleActionDNSRouteOptions +} + +func (r *RuleActionEvaluate) Type() string { + return C.RuleActionTypeEvaluate +} + +func (r *RuleActionEvaluate) String() string { + return formatDNSRouteAction("evaluate", r.Server, r.RuleActionDNSRouteOptions) +} + +func formatDNSRouteAction(action string, server string, options RuleActionDNSRouteOptions) string { var descriptions []string - descriptions = append(descriptions, r.Server) - if r.DisableCache { + descriptions = append(descriptions, server) + if options.DisableCache { descriptions = append(descriptions, "disable-cache") } - if r.RewriteTTL != nil { - descriptions = append(descriptions, F.ToString("rewrite-ttl=", *r.RewriteTTL)) + if options.RewriteTTL != nil { + descriptions = append(descriptions, F.ToString("rewrite-ttl=", *options.RewriteTTL)) } - if r.ClientSubnet.IsValid() { - descriptions = append(descriptions, F.ToString("client-subnet=", r.ClientSubnet)) + if options.ClientSubnet.IsValid() { + descriptions = append(descriptions, F.ToString("client-subnet=", options.ClientSubnet)) } - return F.ToString("route(", strings.Join(descriptions, ","), ")") + return F.ToString(action, "(", strings.Join(descriptions, ","), ")") } type RuleActionDNSRouteOptions struct { diff --git a/route/rule/rule_dns.go b/route/rule/rule_dns.go index f33d6096a..1bb42cbaf 100644 --- a/route/rule/rule_dns.go +++ b/route/rule/rule_dns.go @@ -5,6 +5,7 @@ import ( "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/experimental/deprecated" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common" @@ -12,30 +13,36 @@ import ( "github.com/sagernet/sing/service" ) -func NewDNSRule(ctx context.Context, logger log.ContextLogger, options option.DNSRule, checkServer bool) (adapter.DNSRule, error) { +func NewDNSRule(ctx context.Context, logger log.ContextLogger, options option.DNSRule, checkServer bool, legacyAddressFilter bool) (adapter.DNSRule, error) { switch options.Type { case "", C.RuleTypeDefault: if !options.DefaultOptions.IsValid() { return nil, E.New("missing conditions") } + if !checkServer && options.DefaultOptions.Action == C.RuleActionTypeEvaluate { + return nil, E.New(options.DefaultOptions.Action, " is only allowed on top-level DNS rules") + } switch options.DefaultOptions.Action { - case "", C.RuleActionTypeRoute: + case "", C.RuleActionTypeRoute, C.RuleActionTypeEvaluate: if options.DefaultOptions.RouteOptions.Server == "" && checkServer { return nil, E.New("missing server field") } } - return NewDefaultDNSRule(ctx, logger, options.DefaultOptions) + return NewDefaultDNSRule(ctx, logger, options.DefaultOptions, legacyAddressFilter) case C.RuleTypeLogical: if !options.LogicalOptions.IsValid() { return nil, E.New("missing conditions") } + if !checkServer && options.LogicalOptions.Action == C.RuleActionTypeEvaluate { + return nil, E.New(options.LogicalOptions.Action, " is only allowed on top-level DNS rules") + } switch options.LogicalOptions.Action { - case "", C.RuleActionTypeRoute: + case "", C.RuleActionTypeRoute, C.RuleActionTypeEvaluate: if options.LogicalOptions.RouteOptions.Server == "" && checkServer { return nil, E.New("missing server field") } } - return NewLogicalDNSRule(ctx, logger, options.LogicalOptions) + return NewLogicalDNSRule(ctx, logger, options.LogicalOptions, legacyAddressFilter) default: return nil, E.New("unknown rule type: ", options.Type) } @@ -45,18 +52,20 @@ var _ adapter.DNSRule = (*DefaultDNSRule)(nil) type DefaultDNSRule struct { abstractDefaultRule + matchResponse bool } func (r *DefaultDNSRule) matchStates(metadata *adapter.InboundContext) ruleMatchStateSet { return r.abstractDefaultRule.matchStates(metadata) } -func NewDefaultDNSRule(ctx context.Context, logger log.ContextLogger, options option.DefaultDNSRule) (*DefaultDNSRule, error) { +func NewDefaultDNSRule(ctx context.Context, logger log.ContextLogger, options option.DefaultDNSRule, legacyAddressFilter bool) (*DefaultDNSRule, error) { rule := &DefaultDNSRule{ abstractDefaultRule: abstractDefaultRule{ invert: options.Invert, action: NewDNSRuleAction(logger, options.DNSRuleAction), }, + matchResponse: options.MatchResponse, } if len(options.Inbound) > 0 { item := NewInboundRule(options.Inbound) @@ -152,10 +161,35 @@ func NewDefaultDNSRule(ctx context.Context, logger log.ContextLogger, options op rule.allItems = append(rule.allItems, item) } if options.IPAcceptAny { + if legacyAddressFilter { + deprecated.Report(ctx, deprecated.OptionIPAcceptAny) + } else { + return nil, E.New("ip_accept_any is removed in DNS evaluate mode, use ip_cidr with match_response") + } item := NewIPAcceptAnyItem() rule.destinationIPCIDRItems = append(rule.destinationIPCIDRItems, item) rule.allItems = append(rule.allItems, item) } + if options.ResponseRcode != nil { + item := NewDNSResponseRCodeItem(int(*options.ResponseRcode)) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.ResponseAnswer) > 0 { + item := NewDNSResponseRecordItem("response_answer", options.ResponseAnswer, dnsResponseAnswers) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.ResponseNs) > 0 { + item := NewDNSResponseRecordItem("response_ns", options.ResponseNs, dnsResponseNS) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.ResponseExtra) > 0 { + item := NewDNSResponseRecordItem("response_extra", options.ResponseExtra, dnsResponseExtra) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } if len(options.SourcePort) > 0 { item := NewPortItem(true, options.SourcePort) rule.sourcePortItems = append(rule.sourcePortItems, item) @@ -284,6 +318,13 @@ func NewDefaultDNSRule(ctx context.Context, logger log.ContextLogger, options op if options.RuleSetIPCIDRMatchSource { matchSource = true } + if options.RuleSetIPCIDRAcceptEmpty { + if legacyAddressFilter { + deprecated.Report(ctx, deprecated.OptionRuleSetIPCIDRAcceptEmpty) + } else { + return nil, E.New("rule_set_ip_cidr_accept_empty is removed in DNS evaluate mode") + } + } item := NewRuleSetItem(router, options.RuleSet, matchSource, options.RuleSetIPCIDRAcceptEmpty) rule.ruleSetItem = item rule.allItems = append(rule.allItems, item) @@ -309,6 +350,12 @@ func (r *DefaultDNSRule) WithAddressLimit() bool { } func (r *DefaultDNSRule) Match(metadata *adapter.InboundContext) bool { + if r.matchResponse { + if metadata.DNSResponse == nil { + return false + } + return r.abstractDefaultRule.Match(metadata) + } metadata.IgnoreDestinationIPCIDRMatch = true defer func() { metadata.IgnoreDestinationIPCIDRMatch = false @@ -330,7 +377,7 @@ func (r *LogicalDNSRule) matchStates(metadata *adapter.InboundContext) ruleMatch return r.abstractLogicalRule.matchStates(metadata) } -func NewLogicalDNSRule(ctx context.Context, logger log.ContextLogger, options option.LogicalDNSRule) (*LogicalDNSRule, error) { +func NewLogicalDNSRule(ctx context.Context, logger log.ContextLogger, options option.LogicalDNSRule, legacyAddressFilter bool) (*LogicalDNSRule, error) { r := &LogicalDNSRule{ abstractLogicalRule: abstractLogicalRule{ rules: make([]adapter.HeadlessRule, len(options.Rules)), @@ -347,7 +394,7 @@ func NewLogicalDNSRule(ctx context.Context, logger log.ContextLogger, options op return nil, E.New("unknown logical mode: ", options.Mode) } for i, subRule := range options.Rules { - rule, err := NewDNSRule(ctx, logger, subRule, false) + rule, err := NewDNSRule(ctx, logger, subRule, false, legacyAddressFilter) if err != nil { return nil, E.Cause(err, "sub rule[", i, "]") } diff --git a/route/rule/rule_item_response_rcode.go b/route/rule/rule_item_response_rcode.go new file mode 100644 index 000000000..ae2c62248 --- /dev/null +++ b/route/rule/rule_item_response_rcode.go @@ -0,0 +1,24 @@ +package rule + +import ( + "github.com/sagernet/sing-box/adapter" + F "github.com/sagernet/sing/common/format" +) + +var _ RuleItem = (*DNSResponseRCodeItem)(nil) + +type DNSResponseRCodeItem struct { + rcode int +} + +func NewDNSResponseRCodeItem(rcode int) *DNSResponseRCodeItem { + return &DNSResponseRCodeItem{rcode: rcode} +} + +func (r *DNSResponseRCodeItem) Match(metadata *adapter.InboundContext) bool { + return metadata.DNSResponse != nil && metadata.DNSResponse.Rcode == r.rcode +} + +func (r *DNSResponseRCodeItem) String() string { + return F.ToString("response_rcode=", r.rcode) +} diff --git a/route/rule/rule_item_response_record.go b/route/rule/rule_item_response_record.go new file mode 100644 index 000000000..3a2c889be --- /dev/null +++ b/route/rule/rule_item_response_record.go @@ -0,0 +1,63 @@ +package rule + +import ( + "strings" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/option" + + "github.com/miekg/dns" +) + +var _ RuleItem = (*DNSResponseRecordItem)(nil) + +type DNSResponseRecordItem struct { + field string + records []option.DNSRecordOptions + selector func(*dns.Msg) []dns.RR +} + +func NewDNSResponseRecordItem(field string, records []option.DNSRecordOptions, selector func(*dns.Msg) []dns.RR) *DNSResponseRecordItem { + return &DNSResponseRecordItem{ + field: field, + records: records, + selector: selector, + } +} + +func (r *DNSResponseRecordItem) Match(metadata *adapter.InboundContext) bool { + if metadata.DNSResponse == nil { + return false + } + records := r.selector(metadata.DNSResponse) + for _, expected := range r.records { + for _, record := range records { + if expected.Match(record) { + return true + } + } + } + return false +} + +func (r *DNSResponseRecordItem) String() string { + descriptions := make([]string, 0, len(r.records)) + for _, record := range r.records { + if record.RR != nil { + descriptions = append(descriptions, record.RR.String()) + } + } + return r.field + "=[" + strings.Join(descriptions, " ") + "]" +} + +func dnsResponseAnswers(message *dns.Msg) []dns.RR { + return message.Answer +} + +func dnsResponseNS(message *dns.Msg) []dns.RR { + return message.Ns +} + +func dnsResponseExtra(message *dns.Msg) []dns.RR { + return message.Extra +}