339 lines
12 KiB
Go
339 lines
12 KiB
Go
package dns
|
|
|
|
import (
	"context"
	"encoding/json"
	"errors"
	"net/netip"
	"testing"

	"github.com/sagernet/sing-box/adapter"
	C "github.com/sagernet/sing-box/constant"
	"github.com/sagernet/sing-box/experimental/deprecated"
	"github.com/sagernet/sing-box/log"
	"github.com/sagernet/sing-box/option"
	"github.com/sagernet/sing/common/json/badoption"
	"github.com/sagernet/sing/service"

	mDNS "github.com/miekg/dns"
	"github.com/stretchr/testify/require"
)
|
|
|
|
type fakeDNSTransport struct {
|
|
tag string
|
|
transportType string
|
|
}
|
|
|
|
func (t *fakeDNSTransport) Start(adapter.StartStage) error { return nil }
|
|
func (t *fakeDNSTransport) Close() error { return nil }
|
|
func (t *fakeDNSTransport) Type() string { return t.transportType }
|
|
func (t *fakeDNSTransport) Tag() string { return t.tag }
|
|
func (t *fakeDNSTransport) Dependencies() []string { return nil }
|
|
func (t *fakeDNSTransport) Reset() {}
|
|
func (t *fakeDNSTransport) Exchange(context.Context, *mDNS.Msg) (*mDNS.Msg, error) {
|
|
return nil, errors.New("unused transport exchange")
|
|
}
|
|
|
|
type fakeDNSTransportManager struct {
|
|
defaultTransport adapter.DNSTransport
|
|
transports map[string]adapter.DNSTransport
|
|
}
|
|
|
|
func (m *fakeDNSTransportManager) Start(adapter.StartStage) error { return nil }
|
|
func (m *fakeDNSTransportManager) Close() error { return nil }
|
|
func (m *fakeDNSTransportManager) Transports() []adapter.DNSTransport {
|
|
transports := make([]adapter.DNSTransport, 0, len(m.transports))
|
|
for _, transport := range m.transports {
|
|
transports = append(transports, transport)
|
|
}
|
|
return transports
|
|
}
|
|
|
|
func (m *fakeDNSTransportManager) Transport(tag string) (adapter.DNSTransport, bool) {
|
|
transport, loaded := m.transports[tag]
|
|
return transport, loaded
|
|
}
|
|
func (m *fakeDNSTransportManager) Default() adapter.DNSTransport { return m.defaultTransport }
|
|
func (m *fakeDNSTransportManager) FakeIP() adapter.FakeIPTransport {
|
|
return nil
|
|
}
|
|
func (m *fakeDNSTransportManager) Remove(string) error { return nil }
|
|
func (m *fakeDNSTransportManager) Create(context.Context, log.ContextLogger, string, string, any) error {
|
|
return errors.New("unsupported")
|
|
}
|
|
|
|
type fakeDNSClient struct {
|
|
exchange func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error)
|
|
}
|
|
|
|
type fakeDeprecatedManager struct {
|
|
features []deprecated.Note
|
|
}
|
|
|
|
func (m *fakeDeprecatedManager) ReportDeprecated(feature deprecated.Note) {
|
|
m.features = append(m.features, feature)
|
|
}
|
|
|
|
func (c *fakeDNSClient) Start() {}
|
|
|
|
func (c *fakeDNSClient) Exchange(_ context.Context, transport adapter.DNSTransport, message *mDNS.Msg, _ adapter.DNSQueryOptions, _ func([]netip.Addr) bool) (*mDNS.Msg, error) {
|
|
return c.exchange(transport, message)
|
|
}
|
|
|
|
func (c *fakeDNSClient) Lookup(context.Context, adapter.DNSTransport, string, adapter.DNSQueryOptions, func([]netip.Addr) bool) ([]netip.Addr, error) {
|
|
return nil, errors.New("unused client lookup")
|
|
}
|
|
|
|
func (c *fakeDNSClient) ClearCache() {}
|
|
|
|
func newTestRouter(t *testing.T, rules []option.DNSRule, transportManager *fakeDNSTransportManager, client *fakeDNSClient) *Router {
|
|
t.Helper()
|
|
router := &Router{
|
|
ctx: context.Background(),
|
|
logger: log.NewNOPFactory().NewLogger("dns"),
|
|
transport: transportManager,
|
|
client: client,
|
|
rules: make([]adapter.DNSRule, 0, len(rules)),
|
|
defaultDomainStrategy: C.DomainStrategyAsIS,
|
|
}
|
|
if rules != nil {
|
|
err := router.Initialize(rules)
|
|
require.NoError(t, err)
|
|
}
|
|
return router
|
|
}
|
|
|
|
func fixedQuestion(name string, qType uint16) mDNS.Question {
|
|
return mDNS.Question{
|
|
Name: mDNS.Fqdn(name),
|
|
Qtype: qType,
|
|
Qclass: mDNS.ClassINET,
|
|
}
|
|
}
|
|
|
|
func mustRecord(t *testing.T, record string) option.DNSRecordOptions {
|
|
t.Helper()
|
|
var value option.DNSRecordOptions
|
|
require.NoError(t, value.UnmarshalJSON([]byte(`"`+record+`"`)))
|
|
return value
|
|
}
|
|
|
|
func TestValidateNewDNSRules_RequireMatchResponseForDirectIPCIDR(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
err := validateNonLegacyAddressFilterRules([]option.DNSRule{{
|
|
Type: C.RuleTypeDefault,
|
|
DefaultOptions: option.DefaultDNSRule{
|
|
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
|
IPCIDR: badoption.Listable[string]{"1.1.1.0/24"},
|
|
},
|
|
DNSRuleAction: option.DNSRuleAction{
|
|
Action: C.RuleActionTypeRoute,
|
|
RouteOptions: option.DNSRouteActionOptions{
|
|
Server: "default",
|
|
},
|
|
},
|
|
},
|
|
}})
|
|
require.ErrorContains(t, err, "ip_cidr and ip_is_private require match_response")
|
|
}
|
|
|
|
func TestExchangeNewModeEvaluateMatchResponseRoute(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
transportManager := &fakeDNSTransportManager{
|
|
defaultTransport: &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP},
|
|
transports: map[string]adapter.DNSTransport{
|
|
"upstream": &fakeDNSTransport{tag: "upstream", transportType: C.DNSTypeUDP},
|
|
"selected": &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP},
|
|
"default": &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP},
|
|
},
|
|
}
|
|
client := &fakeDNSClient{
|
|
exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) {
|
|
switch transport.Tag() {
|
|
case "upstream":
|
|
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("1.1.1.1")}, 60), nil
|
|
case "selected":
|
|
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60), nil
|
|
default:
|
|
return nil, errors.New("unexpected transport")
|
|
}
|
|
},
|
|
}
|
|
rules := []option.DNSRule{
|
|
{
|
|
Type: C.RuleTypeDefault,
|
|
DefaultOptions: option.DefaultDNSRule{
|
|
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
|
Domain: badoption.Listable[string]{"example.com"},
|
|
},
|
|
DNSRuleAction: option.DNSRuleAction{
|
|
Action: C.RuleActionTypeEvaluate,
|
|
RouteOptions: option.DNSRouteActionOptions{Server: "upstream"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Type: C.RuleTypeDefault,
|
|
DefaultOptions: option.DefaultDNSRule{
|
|
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
|
MatchResponse: true,
|
|
ResponseAnswer: badoption.Listable[option.DNSRecordOptions]{mustRecord(t, "example.com. IN A 1.1.1.1")},
|
|
},
|
|
DNSRuleAction: option.DNSRuleAction{
|
|
Action: C.RuleActionTypeRoute,
|
|
RouteOptions: option.DNSRouteActionOptions{Server: "selected"},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
router := newTestRouter(t, rules, transportManager, client)
|
|
|
|
response, err := router.Exchange(context.Background(), &mDNS.Msg{
|
|
Question: []mDNS.Question{fixedQuestion("example.com", mDNS.TypeA)},
|
|
}, adapter.DNSQueryOptions{})
|
|
require.NoError(t, err)
|
|
require.Equal(t, []netip.Addr{netip.MustParseAddr("8.8.8.8")}, MessageToAddresses(response))
|
|
}
|
|
|
|
func TestLookupNewModeAllowsPartialSuccess(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
|
|
router := newTestRouter(t, nil, &fakeDNSTransportManager{
|
|
defaultTransport: defaultTransport,
|
|
transports: map[string]adapter.DNSTransport{
|
|
"default": defaultTransport,
|
|
},
|
|
}, &fakeDNSClient{
|
|
exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) {
|
|
require.Equal(t, "default", transport.Tag())
|
|
switch message.Question[0].Qtype {
|
|
case mDNS.TypeA:
|
|
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("1.1.1.1")}, 60), nil
|
|
case mDNS.TypeAAAA:
|
|
return nil, errors.New("ipv6 failed")
|
|
default:
|
|
return nil, errors.New("unexpected qtype")
|
|
}
|
|
},
|
|
})
|
|
router.legacyAddressFilterMode = false
|
|
|
|
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
|
|
require.NoError(t, err)
|
|
require.Equal(t, []netip.Addr{netip.MustParseAddr("1.1.1.1")}, addresses)
|
|
}
|
|
|
|
func TestLookupNewModeSkipsFakeIPRule(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
|
|
router := newTestRouter(t, []option.DNSRule{{
|
|
Type: C.RuleTypeDefault,
|
|
DefaultOptions: option.DefaultDNSRule{
|
|
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
|
Domain: badoption.Listable[string]{"example.com"},
|
|
},
|
|
DNSRuleAction: option.DNSRuleAction{
|
|
Action: C.RuleActionTypeRoute,
|
|
RouteOptions: option.DNSRouteActionOptions{Server: "fake"},
|
|
},
|
|
},
|
|
}}, &fakeDNSTransportManager{
|
|
defaultTransport: defaultTransport,
|
|
transports: map[string]adapter.DNSTransport{
|
|
"default": defaultTransport,
|
|
"fake": &fakeDNSTransport{tag: "fake", transportType: C.DNSTypeFakeIP},
|
|
},
|
|
}, &fakeDNSClient{
|
|
exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) {
|
|
require.Equal(t, "default", transport.Tag())
|
|
if message.Question[0].Qtype == mDNS.TypeA {
|
|
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2.2.2.2")}, 60), nil
|
|
}
|
|
return FixedResponse(0, message.Question[0], nil, 60), nil
|
|
},
|
|
})
|
|
router.legacyAddressFilterMode = false
|
|
|
|
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
|
|
require.NoError(t, err)
|
|
require.Equal(t, []netip.Addr{netip.MustParseAddr("2.2.2.2")}, addresses)
|
|
}
|
|
|
|
func TestLookupNewModeDoesNotUseQueryTypeRule(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
|
|
router := newTestRouter(t, []option.DNSRule{{
|
|
Type: C.RuleTypeDefault,
|
|
DefaultOptions: option.DefaultDNSRule{
|
|
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
|
QueryType: badoption.Listable[option.DNSQueryType]{option.DNSQueryType(mDNS.TypeA)},
|
|
},
|
|
DNSRuleAction: option.DNSRuleAction{
|
|
Action: C.RuleActionTypeRoute,
|
|
RouteOptions: option.DNSRouteActionOptions{Server: "only-a"},
|
|
},
|
|
},
|
|
}}, &fakeDNSTransportManager{
|
|
defaultTransport: defaultTransport,
|
|
transports: map[string]adapter.DNSTransport{
|
|
"default": defaultTransport,
|
|
"only-a": &fakeDNSTransport{tag: "only-a", transportType: C.DNSTypeUDP},
|
|
},
|
|
}, &fakeDNSClient{
|
|
exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) {
|
|
switch transport.Tag() {
|
|
case "default":
|
|
if message.Question[0].Qtype == mDNS.TypeA {
|
|
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("3.3.3.3")}, 60), nil
|
|
}
|
|
return FixedResponse(0, message.Question[0], nil, 60), nil
|
|
case "only-a":
|
|
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("9.9.9.9")}, 60), nil
|
|
default:
|
|
return nil, errors.New("unexpected transport")
|
|
}
|
|
},
|
|
})
|
|
router.legacyAddressFilterMode = false
|
|
|
|
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
|
|
require.NoError(t, err)
|
|
require.Equal(t, []netip.Addr{netip.MustParseAddr("3.3.3.3")}, addresses)
|
|
}
|
|
|
|
func TestOldModeReportsLegacyAddressFilterDeprecation(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
manager := &fakeDeprecatedManager{}
|
|
ctx := service.ContextWith[deprecated.Manager](context.Background(), manager)
|
|
router := &Router{
|
|
ctx: ctx,
|
|
logger: log.NewNOPFactory().NewLogger("dns"),
|
|
client: &fakeDNSClient{},
|
|
rules: make([]adapter.DNSRule, 0, 1),
|
|
defaultDomainStrategy: C.DomainStrategyAsIS,
|
|
}
|
|
err := router.Initialize([]option.DNSRule{{
|
|
Type: C.RuleTypeDefault,
|
|
DefaultOptions: option.DefaultDNSRule{
|
|
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
|
IPCIDR: badoption.Listable[string]{"1.1.1.0/24"},
|
|
},
|
|
DNSRuleAction: option.DNSRuleAction{
|
|
Action: C.RuleActionTypeRoute,
|
|
RouteOptions: option.DNSRouteActionOptions{Server: "default"},
|
|
},
|
|
},
|
|
}})
|
|
require.NoError(t, err)
|
|
|
|
err = router.Start(adapter.StartStateStart)
|
|
require.NoError(t, err)
|
|
require.Len(t, manager.features, 1)
|
|
require.Equal(t, deprecated.OptionLegacyDNSAddressFilter.Name, manager.features[0].Name)
|
|
}
|