diff --git a/dns/router_test.go b/dns/router_test.go index fb6dad0e4..cfd8dea3d 100644 --- a/dns/router_test.go +++ b/dns/router_test.go @@ -1152,6 +1152,89 @@ func TestExchangeNewModeEvaluateRouteResolutionFailureClearsResponse(t *testing. require.Equal(t, []netip.Addr{netip.MustParseAddr("4.4.4.4")}, MessageToAddresses(response)) } +func TestExchangeNewModeEvaluateExchangeFailureUsesMatchResponseBooleanSemantics(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + invert bool + expectedAddr netip.Addr + }{ + { + name: "plain match_response rule stays false", + expectedAddr: netip.MustParseAddr("4.4.4.4"), + }, + { + name: "invert match_response rule becomes true", + invert: true, + expectedAddr: netip.MustParseAddr("8.8.8.8"), + }, + } + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + transportManager := &fakeDNSTransportManager{ + defaultTransport: &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}, + transports: map[string]adapter.DNSTransport{ + "upstream": &fakeDNSTransport{tag: "upstream", transportType: C.DNSTypeUDP}, + "selected": &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP}, + "default": &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}, + }, + } + client := &fakeDNSClient{ + exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) { + switch transport.Tag() { + case "upstream": + return nil, errors.New("upstream exchange failed") + case "selected": + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60), nil + case "default": + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("4.4.4.4")}, 60), nil + default: + return nil, errors.New("unexpected transport") + } + }, + } + rules := []option.DNSRule{ + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + Domain: badoption.Listable[string]{"example.com"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeEvaluate, + RouteOptions: option.DNSRouteActionOptions{Server: "upstream"}, + }, + }, + }, + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + MatchResponse: true, + Invert: testCase.invert, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "selected"}, + }, + }, + }, + } + router := newTestRouter(t, rules, transportManager, client) + + response, err := router.Exchange(context.Background(), &mDNS.Msg{ + Question: []mDNS.Question{fixedQuestion("example.com", mDNS.TypeA)}, + }, adapter.DNSQueryOptions{}) + require.NoError(t, err) + require.Equal(t, []netip.Addr{testCase.expectedAddr}, MessageToAddresses(response)) + }) + } +} + func TestLookupNewModeAllowsPartialSuccess(t *testing.T) { t.Parallel() diff --git a/route/rule/rule_dns.go b/route/rule/rule_dns.go index cceaec8e9..6a84a4e77 100644 --- a/route/rule/rule_dns.go +++ b/route/rule/rule_dns.go @@ -357,14 +357,25 @@ func (r *DefaultDNSRule) Match(metadata *adapter.InboundContext) bool { func (r *DefaultDNSRule) LegacyPreMatch(metadata *adapter.InboundContext) bool { if r.matchResponse { - return !r.matchStatesForMatch(metadata).isEmpty() + return !r.legacyMatchStatesForMatch(metadata).isEmpty() } return !r.abstractDefaultRule.legacyMatchStates(metadata).isEmpty() } func (r *DefaultDNSRule) matchStatesForMatch(metadata *adapter.InboundContext) ruleMatchStateSet { + return r.matchStatesForMatchWithMissingResponse(metadata, true) +} + +func (r *DefaultDNSRule) legacyMatchStatesForMatch(metadata *adapter.InboundContext) ruleMatchStateSet { + return r.matchStatesForMatchWithMissingResponse(metadata, false) +} + +func (r *DefaultDNSRule) matchStatesForMatchWithMissingResponse(metadata *adapter.InboundContext, ordinaryFailure bool) ruleMatchStateSet { if r.matchResponse { if metadata.DNSResponse == nil { + if ordinaryFailure { + return r.abstractDefaultRule.invertedFailure(0) + } return 0 } matchMetadata := *metadata diff --git a/route/rule/rule_set_semantics_test.go b/route/rule/rule_set_semantics_test.go index 8c695ecc7..bb26cc169 100644 --- a/route/rule/rule_set_semantics_test.go +++ b/route/rule/rule_set_semantics_test.go @@ -688,6 +688,63 @@ func TestDNSMatchResponseRuleSetDestinationCIDRUsesDNSResponse(t *testing.T) { require.False(t, rule.Match(&unmatchedMetadata)) } +func TestDNSMatchResponseMissingResponseUsesBooleanSemantics(t *testing.T) { + t.Parallel() + + t.Run("plain rule remains false", func(t *testing.T) { + t.Parallel() + + rule := dnsRuleForTest(func(rule *abstractDefaultRule) {}) + rule.matchResponse = true + + metadata := testMetadata("lookup.example") + require.False(t, rule.Match(&metadata)) + }) + + t.Run("invert rule becomes true", func(t *testing.T) { + t.Parallel() + + rule := dnsRuleForTest(func(rule *abstractDefaultRule) { + rule.invert = true + }) + rule.matchResponse = true + + metadata := testMetadata("lookup.example") + require.True(t, rule.Match(&metadata)) + }) + + t.Run("logical wrapper respects inverted child", func(t *testing.T) { + t.Parallel() + + nestedRule := dnsRuleForTest(func(rule *abstractDefaultRule) { + rule.invert = true + }) + nestedRule.matchResponse = true + + logicalRule := &LogicalDNSRule{ + abstractLogicalRule: abstractLogicalRule{ + rules: []adapter.HeadlessRule{nestedRule}, + mode: C.LogicalTypeAnd, + }, + } + + metadata := testMetadata("lookup.example") + require.True(t, logicalRule.Match(&metadata)) + }) +} + +func TestDNSLegacyMatchResponseMissingResponseStillFailsClosed(t *testing.T) { + t.Parallel() + + rule := dnsRuleForTest(func(rule *abstractDefaultRule) { + rule.invert = true + }) + rule.matchResponse = true + + metadata := testMetadata("lookup.example") + require.False(t, rule.LegacyPreMatch(&metadata)) +} + func TestDNSAddressLimitIgnoresDestinationAddresses(t *testing.T) { t.Parallel()