diff --git a/dns/router_test.go b/dns/router_test.go index 81ab358b2..566cef455 100644 --- a/dns/router_test.go +++ b/dns/router_test.go @@ -78,6 +78,11 @@ type fakeDNSClient struct { lookup func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) } +type recordingExchangeDNSClient struct { + beforeExchange func(ctx context.Context, transport adapter.DNSTransport, message *mDNS.Msg, options adapter.DNSQueryOptions) + exchange func(transport adapter.DNSTransport, message *mDNS.Msg, options adapter.DNSQueryOptions) (*mDNS.Msg, error) +} + type fakeDeprecatedManager struct { features []deprecated.Note } @@ -256,10 +261,54 @@ func (c *fakeDNSClient) Lookup(_ context.Context, transport adapter.DNSTransport return MessageToAddresses(response), nil } +func (c *recordingExchangeDNSClient) Start() {} + +func (c *recordingExchangeDNSClient) Exchange(ctx context.Context, transport adapter.DNSTransport, message *mDNS.Msg, options adapter.DNSQueryOptions, _ func(*mDNS.Msg) bool) (*mDNS.Msg, error) { + if c.beforeExchange != nil { + c.beforeExchange(ctx, transport, message, options) + } + if c.exchange == nil { + return nil, E.New("unused client exchange") + } + return c.exchange(transport, message, options) +} + +func (c *recordingExchangeDNSClient) Lookup(context.Context, adapter.DNSTransport, string, adapter.DNSQueryOptions, func(*mDNS.Msg) bool) ([]netip.Addr, error) { + return nil, E.New("unused client lookup") +} + func (c *fakeDNSClient) ClearCache() {} +func (c *recordingExchangeDNSClient) ClearCache() {} + func newTestRouter(t *testing.T, rules []option.DNSRule, transportManager *fakeDNSTransportManager, client *fakeDNSClient) *Router { - return newTestRouterWithContext(t, context.Background(), rules, transportManager, client) + router := newTestRouterWithContext(t, context.Background(), rules, transportManager, client) + t.Cleanup(func() { + router.Close() + }) + return router +} + +func newTestRouterWithDNSClient(t *testing.T, rules []option.DNSRule, transportManager *fakeDNSTransportManager, client adapter.DNSClient) *Router { + router := &Router{ + ctx: context.Background(), + logger: log.NewNOPFactory().NewLogger("dns"), + transport: transportManager, + client: client, + rawRules: make([]option.DNSRule, 0, len(rules)), + defaultDomainStrategy: C.DomainStrategyAsIS, + } + router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, len(rules)), false)) + if rules != nil { + err := router.Initialize(rules) + require.NoError(t, err) + err = router.Start(adapter.StartStateStart) + require.NoError(t, err) + } + t.Cleanup(func() { + router.Close() + }) + return router } func newTestRouterWithContext(t *testing.T, ctx context.Context, rules []option.DNSRule, transportManager *fakeDNSTransportManager, client *fakeDNSClient) *Router { @@ -1903,6 +1952,98 @@ func TestExchangeLegacyDNSModeDisabledEvaluateRouteResolutionFailureClearsRespon require.Equal(t, []netip.Addr{netip.MustParseAddr("4.4.4.4")}, MessageToAddresses(response)) } +func TestExchangeLegacyDNSModeDisabledSecondEvaluateOverwritesFirstResponse(t *testing.T) { + t.Parallel() + + transportManager := &fakeDNSTransportManager{ + defaultTransport: &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}, + transports: map[string]adapter.DNSTransport{ + "default": &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}, + "first-upstream": &fakeDNSTransport{tag: "first-upstream", transportType: C.DNSTypeUDP}, + "second-upstream": &fakeDNSTransport{tag: "second-upstream", transportType: C.DNSTypeUDP}, + "first-match": &fakeDNSTransport{tag: "first-match", transportType: C.DNSTypeUDP}, + "second-match": &fakeDNSTransport{tag: "second-match", transportType: C.DNSTypeUDP}, + }, + } + client := &fakeDNSClient{ + exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) { + switch transport.Tag() { + case "first-upstream": + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("1.1.1.1")}, 60), nil + case "second-upstream": + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2.2.2.2")}, 60), nil + case "first-match": + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("7.7.7.7")}, 60), nil + case "second-match": + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60), nil + case "default": + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("4.4.4.4")}, 60), nil + default: + return nil, E.New("unexpected transport") + } + }, + } + rules := []option.DNSRule{ + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + Domain: badoption.Listable[string]{"example.com"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeEvaluate, + RouteOptions: option.DNSRouteActionOptions{Server: "first-upstream"}, + }, + }, + }, + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + Domain: badoption.Listable[string]{"example.com"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeEvaluate, + RouteOptions: option.DNSRouteActionOptions{Server: "second-upstream"}, + }, + }, + }, + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + MatchResponse: true, + ResponseAnswer: badoption.Listable[option.DNSRecordOptions]{mustRecord(t, "example.com. IN A 1.1.1.1")}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "first-match"}, + }, + }, + }, + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + MatchResponse: true, + ResponseAnswer: badoption.Listable[option.DNSRecordOptions]{mustRecord(t, "example.com. IN A 2.2.2.2")}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "second-match"}, + }, + }, + }, + } + router := newTestRouter(t, rules, transportManager, client) + + response, err := router.Exchange(context.Background(), &mDNS.Msg{ + Question: []mDNS.Question{fixedQuestion("example.com", mDNS.TypeA)}, + }, adapter.DNSQueryOptions{}) + require.NoError(t, err) + require.Equal(t, []netip.Addr{netip.MustParseAddr("8.8.8.8")}, MessageToAddresses(response)) +} + func TestExchangeLegacyDNSModeDisabledEvaluateExchangeFailureUsesMatchResponseBooleanSemantics(t *testing.T) { t.Parallel() @@ -1922,7 +2063,6 @@ func TestExchangeLegacyDNSModeDisabledEvaluateExchangeFailureUsesMatchResponseBo }, } for _, testCase := range testCases { - testCase := testCase t.Run(testCase.name, func(t *testing.T) { t.Parallel() @@ -2349,6 +2489,68 @@ func TestInitializeRejectsDNSRuleStrategyWhenLegacyDNSModeIsDisabledByMatchRespo require.ErrorContains(t, err, "legacyDNSMode") } +func TestExchangeLegacyDNSModeDisabledRouteOptionsApplyQueryOptions(t *testing.T) { + t.Parallel() + + defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} + rewriteTTL := uint32(30) + var capturedOptions adapter.DNSQueryOptions + router := newTestRouterWithDNSClient(t, []option.DNSRule{ + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + QueryType: badoption.Listable[option.DNSQueryType]{option.DNSQueryType(mDNS.TypeA)}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRouteOptions, + RouteOptionsOptions: option.DNSRouteOptionsActionOptions{ + DisableCache: true, + RewriteTTL: &rewriteTTL, + }, + }, + }, + }, + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + Domain: badoption.Listable[string]{"example.com"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "selected"}, + }, + }, + }, + }, &fakeDNSTransportManager{ + defaultTransport: defaultTransport, + transports: map[string]adapter.DNSTransport{ + "default": defaultTransport, + "selected": &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP}, + }, + }, &recordingExchangeDNSClient{ + beforeExchange: func(ctx context.Context, transport adapter.DNSTransport, message *mDNS.Msg, options adapter.DNSQueryOptions) { + require.Equal(t, "selected", transport.Tag()) + require.Equal(t, []mDNS.Question{fixedQuestion("example.com", mDNS.TypeA)}, message.Question) + capturedOptions = options + }, + exchange: func(transport adapter.DNSTransport, message *mDNS.Msg, options adapter.DNSQueryOptions) (*mDNS.Msg, error) { + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2.2.2.2")}, 60), nil + }, + }) + require.False(t, router.currentRules.Load().legacyDNSMode) + + response, err := router.Exchange(context.Background(), &mDNS.Msg{ + Question: []mDNS.Question{fixedQuestion("example.com", mDNS.TypeA)}, + }, adapter.DNSQueryOptions{}) + require.NoError(t, err) + require.Equal(t, []netip.Addr{netip.MustParseAddr("2.2.2.2")}, MessageToAddresses(response)) + require.True(t, capturedOptions.DisableCache) + require.NotNil(t, capturedOptions.RewriteTTL) + require.Equal(t, rewriteTTL, *capturedOptions.RewriteTTL) +} + func TestLookupLegacyDNSModeUsesRouteStrategy(t *testing.T) { t.Parallel() @@ -2457,6 +2659,40 @@ func TestExchangeLegacyDNSModeDisabledReturnsRefusedResponseForRejectAction(t *t require.Equal(t, []mDNS.Question{fixedQuestion("example.com", mDNS.TypeA)}, response.Question) } +func TestExchangeLegacyDNSModeDisabledReturnsDropErrorForRejectDropAction(t *testing.T) { + t.Parallel() + + defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} + router := newTestRouter(t, []option.DNSRule{ + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + Domain: badoption.Listable[string]{"example.com"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeReject, + RejectOptions: option.RejectActionOptions{ + Method: C.RuleActionRejectMethodDrop, + }, + }, + }, + }, + }, &fakeDNSTransportManager{ + defaultTransport: defaultTransport, + transports: map[string]adapter.DNSTransport{ + "default": defaultTransport, + }, + }, &fakeDNSClient{}) + require.False(t, router.currentRules.Load().legacyDNSMode) + + response, err := router.Exchange(context.Background(), &mDNS.Msg{ + Question: []mDNS.Question{fixedQuestion("example.com", mDNS.TypeA)}, + }, adapter.DNSQueryOptions{}) + require.Nil(t, response) + require.ErrorIs(t, err, tun.ErrDrop) +} + func TestLookupLegacyDNSModeDisabledFiltersPerQueryTypeAddressesBeforeMerging(t *testing.T) { t.Parallel()