diff --git a/dns/router.go b/dns/router.go index 778ad84c0..15761055b 100644 --- a/dns/router.go +++ b/dns/router.go @@ -399,7 +399,16 @@ func (r *Router) logRuleMatch(ctx context.Context, ruleIndex int, currentRule ad } } -func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, options adapter.DNSQueryOptions, allowFakeIP bool) (*mDNS.Msg, adapter.DNSTransport, adapter.DNSQueryOptions, bool, error) { +type exchangeWithRulesResult struct { + response *mDNS.Msg + transport adapter.DNSTransport + queryOptions adapter.DNSQueryOptions + strategyOverridden bool + rejectAction *R.RuleActionReject + err error +} + +func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, options adapter.DNSQueryOptions, allowFakeIP bool) exchangeWithRulesResult { metadata := adapter.ContextFrom(ctx) if metadata == nil { panic("no context") @@ -458,23 +467,43 @@ func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, optio exchangeOptions.Strategy = r.defaultDomainStrategy } response, err := r.client.Exchange(adapter.OverrideContext(ctx), transport, message, exchangeOptions, nil) - return response, transport, queryOptions, effectiveStrategyOverridden || strategyOverridden, err + return exchangeWithRulesResult{ + response: response, + transport: transport, + queryOptions: queryOptions, + strategyOverridden: effectiveStrategyOverridden || strategyOverridden, + err: err, + } case *R.RuleActionReject: switch action.Method { case C.RuleActionRejectMethodDefault: - return &mDNS.Msg{ - MsgHdr: mDNS.MsgHdr{ - Id: message.Id, - Rcode: mDNS.RcodeRefused, - Response: true, + return exchangeWithRulesResult{ + response: &mDNS.Msg{ + MsgHdr: mDNS.MsgHdr{ + Id: message.Id, + Rcode: mDNS.RcodeRefused, + Response: true, + }, + Question: []mDNS.Question{message.Question[0]}, }, - Question: []mDNS.Question{message.Question[0]}, - }, nil, effectiveOptions, effectiveStrategyOverridden, nil + queryOptions: effectiveOptions, + strategyOverridden: effectiveStrategyOverridden, + rejectAction: action, + } case C.RuleActionRejectMethodDrop: - return nil, nil, effectiveOptions, effectiveStrategyOverridden, tun.ErrDrop + return exchangeWithRulesResult{ + queryOptions: effectiveOptions, + strategyOverridden: effectiveStrategyOverridden, + rejectAction: action, + err: tun.ErrDrop, + } } case *R.RuleActionPredefined: - return action.Response(message), nil, effectiveOptions, effectiveStrategyOverridden, nil + return exchangeWithRulesResult{ + response: action.Response(message), + queryOptions: effectiveOptions, + strategyOverridden: effectiveStrategyOverridden, + } } } queryOptions := effectiveOptions @@ -484,7 +513,13 @@ func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, optio exchangeOptions.Strategy = r.defaultDomainStrategy } response, err := r.client.Exchange(adapter.OverrideContext(ctx), transport, message, exchangeOptions, nil) - return response, transport, queryOptions, effectiveStrategyOverridden, err + return exchangeWithRulesResult{ + response: response, + transport: transport, + queryOptions: queryOptions, + strategyOverridden: effectiveStrategyOverridden, + err: err, + } } type lookupWithRulesResponse struct { @@ -593,6 +628,21 @@ func withLookupQueryMetadata(ctx context.Context, qType uint16) context.Context return ctx } +func filterAddressesByQueryType(addresses []netip.Addr, qType uint16) []netip.Addr { + switch qType { + case mDNS.TypeA: + return common.Filter(addresses, func(address netip.Addr) bool { + return address.Is4() + }) + case mDNS.TypeAAAA: + return common.Filter(addresses, func(address netip.Addr) bool { + return address.Is6() + }) + default: + return addresses + } +} + func (r *Router) lookupWithRules(ctx context.Context, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, error) { lookupOptions := options if options.LookupStrategy != C.DomainStrategyAsIS { @@ -646,22 +696,25 @@ func (r *Router) lookupWithRulesType(ctx context.Context, domain string, qType u Qclass: mDNS.ClassINET, }}, } - response, _, queryOptions, strategyOverridden, err := r.exchangeWithRules(withLookupQueryMetadata(ctx, qType), request, options, false) - explicitStrategy := lookupStrategyOverride(queryOptions, strategyOverridden) + exchangeResult := r.exchangeWithRules(withLookupQueryMetadata(ctx, qType), request, options, false) + explicitStrategy := lookupStrategyOverride(exchangeResult.queryOptions, exchangeResult.strategyOverridden) result := lookupWithRulesResponse{ strategy: r.resolveLookupStrategy(options, explicitStrategy), explicitStrategy: explicitStrategy, } - if err != nil { - return result, err + if exchangeResult.rejectAction != nil { + return result, exchangeResult.rejectAction.Error(ctx) } - if response.Rcode != mDNS.RcodeSuccess { - return result, RcodeError(response.Rcode) + if exchangeResult.err != nil { + return result, exchangeResult.err + } + if exchangeResult.response.Rcode != mDNS.RcodeSuccess { + return result, RcodeError(exchangeResult.response.Rcode) } if !lookupStrategyAllowsQueryType(result.strategy, qType) { return result, nil } - result.addresses = MessageToAddresses(response) + result.addresses = filterAddressesByQueryType(MessageToAddresses(exchangeResult.response), qType) return result, nil } @@ -709,7 +762,8 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte } response, err = r.client.Exchange(ctx, transport, message, options, nil) } else if !r.legacyAddressFilterMode { - response, transport, _, _, err = r.exchangeWithRules(ctx, message, options, true) + exchangeResult := r.exchangeWithRules(ctx, message, options, true) + response, transport, err = exchangeResult.response, exchangeResult.transport, exchangeResult.err } else { var ( rule adapter.DNSRule @@ -802,6 +856,8 @@ func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQ r.logger.DebugContext(ctx, "response rejected for ", domain, " (cached)") } else if errors.Is(err, ErrResponseRejected) { r.logger.DebugContext(ctx, "response rejected for ", domain) + } else if R.IsRejected(err) { + r.logger.DebugContext(ctx, "lookup rejected for ", domain) } else { r.logger.ErrorContext(ctx, E.Cause(err, "lookup failed for ", domain)) } diff --git a/dns/router_test.go b/dns/router_test.go index ccc377ea8..babd6cf83 100644 --- a/dns/router_test.go +++ b/dns/router_test.go @@ -1356,6 +1356,112 @@ func TestLookupNewModeAppliesRouteStrategyAfterEvaluate(t *testing.T) { require.Equal(t, []netip.Addr{netip.MustParseAddr("2.2.2.2")}, addresses) } +func TestLookupNewModeReturnsRejectedErrorForRejectAction(t *testing.T) { + t.Parallel() + + defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} + router := newTestRouter(t, []option.DNSRule{ + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + Domain: badoption.Listable[string]{"example.com"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeReject, + RejectOptions: option.RejectActionOptions{ + Method: C.RuleActionRejectMethodDefault, + }, + }, + }, + }, + }, &fakeDNSTransportManager{ + defaultTransport: defaultTransport, + transports: map[string]adapter.DNSTransport{ + "default": defaultTransport, + }, + }, &fakeDNSClient{}) + require.False(t, router.legacyAddressFilterMode) + + addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) + require.Nil(t, addresses) + require.Error(t, err) + require.True(t, rulepkg.IsRejected(err)) +} + +func TestExchangeNewModeReturnsRefusedResponseForRejectAction(t *testing.T) { + t.Parallel() + + defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} + router := newTestRouter(t, []option.DNSRule{ + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + Domain: badoption.Listable[string]{"example.com"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeReject, + RejectOptions: option.RejectActionOptions{ + Method: C.RuleActionRejectMethodDefault, + }, + }, + }, + }, + }, &fakeDNSTransportManager{ + defaultTransport: defaultTransport, + transports: map[string]adapter.DNSTransport{ + "default": defaultTransport, + }, + }, &fakeDNSClient{}) + require.False(t, router.legacyAddressFilterMode) + + response, err := router.Exchange(context.Background(), &mDNS.Msg{ + Question: []mDNS.Question{fixedQuestion("example.com", mDNS.TypeA)}, + }, adapter.DNSQueryOptions{}) + require.NoError(t, err) + require.Equal(t, mDNS.RcodeRefused, response.Rcode) + require.Equal(t, []mDNS.Question{fixedQuestion("example.com", mDNS.TypeA)}, response.Question) +} + +func TestLookupNewModeFiltersPerQueryTypeAddressesBeforeMerging(t *testing.T) { + t.Parallel() + + defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} + router := newTestRouter(t, []option.DNSRule{ + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + Domain: badoption.Listable[string]{"example.com"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypePredefined, + PredefinedOptions: option.DNSRouteActionPredefined{ + Answer: badoption.Listable[option.DNSRecordOptions]{ + mustRecord(t, "example.com. IN A 1.1.1.1"), + mustRecord(t, "example.com. IN AAAA 2001:db8::1"), + }, + }, + }, + }, + }, + }, &fakeDNSTransportManager{ + defaultTransport: defaultTransport, + transports: map[string]adapter.DNSTransport{ + "default": defaultTransport, + }, + }, &fakeDNSClient{}) + require.False(t, router.legacyAddressFilterMode) + + addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) + require.NoError(t, err) + require.Equal(t, []netip.Addr{ + netip.MustParseAddr("1.1.1.1"), + netip.MustParseAddr("2001:db8::1"), + }, addresses) +} + func TestLookupNewModePrefersExplicitBranchStrategyOverDefault(t *testing.T) { t.Parallel()