From 0926405b944a88080726a246c96ab06cf665f63e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 2 Apr 2026 01:43:31 +0800 Subject: [PATCH] dns: hard-fail lookup split rule misuse --- dns/router.go | 88 ++++++++++++++++++++++++++---- dns/router_test.go | 132 ++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 209 insertions(+), 11 deletions(-) diff --git a/dns/router.go b/dns/router.go index a485e599c..d5c80f67d 100644 --- a/dns/router.go +++ b/dns/router.go @@ -380,7 +380,31 @@ type exchangeWithRulesResult struct { const dnsRespondMissingResponseMessage = "respond action requires an evaluated response from a preceding evaluate action" -func (r *Router) exchangeWithRules(ctx context.Context, rules []adapter.DNSRule, message *mDNS.Msg, options adapter.DNSQueryOptions, allowFakeIP bool) exchangeWithRulesResult { +type lookupSplitHardError struct { + cause error +} + +func (e *lookupSplitHardError) Error() string { + return e.cause.Error() +} + +func (e *lookupSplitHardError) Unwrap() error { + return e.cause +} + +func newLookupSplitHardError(err error) error { + if err == nil { + return nil + } + return &lookupSplitHardError{cause: err} +} + +func isLookupSplitHardError(err error) bool { + var target *lookupSplitHardError + return errors.As(err, &target) +} + +func (r *Router) exchangeWithRules(ctx context.Context, rules []adapter.DNSRule, message *mDNS.Msg, options adapter.DNSQueryOptions, allowFakeIP bool, hardFailMissingTransport bool) exchangeWithRulesResult { metadata := adapter.ContextFrom(ctx) if metadata == nil { panic("no context") @@ -404,7 +428,11 @@ func (r *Router) exchangeWithRules(ctx context.Context, rules []adapter.DNSRule, transport, status := r.resolveDNSRoute(action.Server, action.RuleActionDNSRouteOptions, allowFakeIP, &queryOptions) switch status { case dnsRouteStatusMissing: - r.logger.ErrorContext(ctx, "transport not found: ", action.Server) + err := E.New("transport not found: ", action.Server) + if hardFailMissingTransport { + return exchangeWithRulesResult{err: newLookupSplitHardError(err)} + } + r.logger.ErrorContext(ctx, err) evaluatedResponse = nil evaluatedTransport = nil continue @@ -430,7 +458,7 @@ func (r *Router) exchangeWithRules(ctx context.Context, rules []adapter.DNSRule, case *R.RuleActionRespond: if evaluatedResponse == nil { return exchangeWithRulesResult{ - err: E.New(dnsRespondMissingResponseMessage), + err: newLookupSplitHardError(E.New(dnsRespondMissingResponseMessage)), } } return exchangeWithRulesResult{ @@ -442,7 +470,11 @@ func (r *Router) exchangeWithRules(ctx context.Context, rules []adapter.DNSRule, transport, status := r.resolveDNSRoute(action.Server, action.RuleActionDNSRouteOptions, allowFakeIP, &queryOptions) switch status { case dnsRouteStatusMissing: - r.logger.ErrorContext(ctx, "transport not found: ", action.Server) + err := E.New("transport not found: ", action.Server) + if hardFailMissingTransport { + return exchangeWithRulesResult{err: newLookupSplitHardError(err)} + } + r.logger.ErrorContext(ctx, err) continue case dnsRouteStatusSkipped: continue @@ -547,22 +579,58 @@ func (r *Router) lookupWithRules(ctx context.Context, rules []adapter.DNSRule, d return r.lookupWithRulesType(ctx, rules, domain, mDNS.TypeAAAA, lookupOptions) } var ( - response4 []netip.Addr - response6 []netip.Addr + response4 []netip.Addr + response6 []netip.Addr + ordinaryErr4 error + ordinaryErr6 error + hardErr4 error + hardErr6 error ) var group task.Group group.Append("exchange4", func(ctx context.Context) error { result, err := r.lookupWithRulesType(ctx, rules, domain, mDNS.TypeA, lookupOptions) response4 = result - return err + if err == nil { + return nil + } + if E.IsClosedOrCanceled(err) { + return err + } + if isLookupSplitHardError(err) { + hardErr4 = err + return nil + } + ordinaryErr4 = err + return nil }) group.Append("exchange6", func(ctx context.Context) error { result, err := r.lookupWithRulesType(ctx, rules, domain, mDNS.TypeAAAA, lookupOptions) response6 = result - return err + if err == nil { + return nil + } + if E.IsClosedOrCanceled(err) { + return err + } + if isLookupSplitHardError(err) { + hardErr6 = err + return nil + } + ordinaryErr6 = err + return nil }) err := group.Run(ctx) + if err != nil { + return nil, err + } + err = E.Errors(hardErr4, hardErr6) if len(response4) == 0 && len(response6) == 0 { + if err != nil { + return nil, err + } + return nil, E.Errors(ordinaryErr4, ordinaryErr6) + } + if err != nil { return nil, err } return sortAddresses(response4, response6, strategy), nil @@ -579,7 +647,7 @@ func (r *Router) lookupWithRulesType(ctx context.Context, rules []adapter.DNSRul Qclass: mDNS.ClassINET, }}, } - exchangeResult := r.exchangeWithRules(withLookupQueryMetadata(ctx, qType), rules, request, options, false) + exchangeResult := r.exchangeWithRules(withLookupQueryMetadata(ctx, qType), rules, request, options, false, true) if exchangeResult.rejectAction != nil { return nil, exchangeResult.rejectAction.Error(ctx) } @@ -638,7 +706,7 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte } response, err = r.client.Exchange(ctx, transport, message, options, nil) } else if !legacyDNSMode { - exchangeResult := r.exchangeWithRules(ctx, rules, message, options, true) + exchangeResult := r.exchangeWithRules(ctx, rules, message, options, true, false) response, transport, err = exchangeResult.response, exchangeResult.transport, exchangeResult.err } else { var ( diff --git a/dns/router_test.go b/dns/router_test.go index a6f71d877..c5f47a844 100644 --- a/dns/router_test.go +++ b/dns/router_test.go @@ -1775,7 +1775,7 @@ func TestExchangeLegacyDNSModeDisabledRespondWithoutEvaluatedResponseReturnsErro require.ErrorContains(t, err, dnsRespondMissingResponseMessage) } -func TestLookupLegacyDNSModeDisabledAllowsPartialSuccess(t *testing.T) { +func TestLookupLegacyDNSModeDisabledAllowsPartialSuccessForExchangeFailure(t *testing.T) { t.Parallel() defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} @@ -1804,6 +1804,136 @@ func TestLookupLegacyDNSModeDisabledAllowsPartialSuccess(t *testing.T) { require.Equal(t, []netip.Addr{netip.MustParseAddr("1.1.1.1")}, addresses) } +func TestLookupLegacyDNSModeDisabledRespondWithoutEvaluatedResponseIsHardError(t *testing.T) { + t.Parallel() + + defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} + router := newTestRouter(t, []option.DNSRule{ + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + Domain: badoption.Listable[string]{"example.com"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeEvaluate, + RouteOptions: option.DNSRouteActionOptions{Server: "upstream"}, + }, + }, + }, + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + Domain: badoption.Listable[string]{"example.com"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRespond, + }, + }, + }, + }, &fakeDNSTransportManager{ + defaultTransport: defaultTransport, + transports: map[string]adapter.DNSTransport{ + "default": defaultTransport, + "upstream": &fakeDNSTransport{tag: "upstream", transportType: C.DNSTypeUDP}, + }, + }, &fakeDNSClient{ + exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) { + require.Equal(t, "upstream", transport.Tag()) + switch message.Question[0].Qtype { + case mDNS.TypeA: + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("1.1.1.1")}, 60), nil + case mDNS.TypeAAAA: + return nil, E.New("upstream exchange failed") + default: + return nil, E.New("unexpected qtype") + } + }, + }) + router.legacyDNSMode = false + + addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) + require.Nil(t, addresses) + require.ErrorContains(t, err, dnsRespondMissingResponseMessage) +} + +func TestLookupLegacyDNSModeDisabledTransportNotFoundIsHardError(t *testing.T) { + t.Parallel() + + defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} + router := newTestRouter(t, []option.DNSRule{{ + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + Domain: badoption.Listable[string]{"example.com"}, + QueryType: badoption.Listable[option.DNSQueryType]{option.DNSQueryType(mDNS.TypeAAAA)}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "missing"}, + }, + }, + }}, &fakeDNSTransportManager{ + defaultTransport: defaultTransport, + transports: map[string]adapter.DNSTransport{ + "default": defaultTransport, + }, + }, &fakeDNSClient{ + exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) { + require.Equal(t, "default", transport.Tag()) + switch message.Question[0].Qtype { + case mDNS.TypeA: + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("1.1.1.1")}, 60), nil + case mDNS.TypeAAAA: + return FixedResponse(0, message.Question[0], nil, 60), nil + default: + return nil, E.New("unexpected qtype") + } + }, + }) + router.legacyDNSMode = false + + addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) + require.Nil(t, addresses) + require.ErrorContains(t, err, "transport not found: missing") +} + +func TestLookupLegacyDNSModeDisabledAllowsPartialSuccessForRcodeError(t *testing.T) { + t.Parallel() + + defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} + router := newTestRouter(t, nil, &fakeDNSTransportManager{ + defaultTransport: defaultTransport, + transports: map[string]adapter.DNSTransport{ + "default": defaultTransport, + }, + }, &fakeDNSClient{ + exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) { + require.Equal(t, "default", transport.Tag()) + switch message.Question[0].Qtype { + case mDNS.TypeA: + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("1.1.1.1")}, 60), nil + case mDNS.TypeAAAA: + return &mDNS.Msg{ + MsgHdr: mDNS.MsgHdr{ + Response: true, + Rcode: mDNS.RcodeNameError, + }, + Question: []mDNS.Question{message.Question[0]}, + }, nil + default: + return nil, E.New("unexpected qtype") + } + }, + }) + router.legacyDNSMode = false + + addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) + require.NoError(t, err) + require.Equal(t, []netip.Addr{netip.MustParseAddr("1.1.1.1")}, addresses) +} + func TestLookupLegacyDNSModeDisabledSkipsFakeIPRule(t *testing.T) { t.Parallel()