diff --git a/adapter/inbound.go b/adapter/inbound.go index 903ab9f9b..c8f9224ee 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -81,15 +81,16 @@ type InboundContext struct { FallbackNetworkType []C.InterfaceType FallbackDelay time.Duration - DestinationAddresses []netip.Addr - DNSResponse *dns.Msg - SourceGeoIPCode string - GeoIPCode string - ProcessInfo *ConnectionOwner - SourceMACAddress net.HardwareAddr - SourceHostname string - QueryType uint16 - FakeIP bool + DestinationAddresses []netip.Addr + DNSResponse *dns.Msg + DestinationAddressMatchFromResponse bool + SourceGeoIPCode string + GeoIPCode string + ProcessInfo *ConnectionOwner + SourceMACAddress net.HardwareAddr + SourceHostname string + QueryType uint16 + FakeIP bool // rule cache @@ -118,6 +119,29 @@ func (c *InboundContext) ResetRuleMatchCache() { c.DidMatch = false } +func (c *InboundContext) DestinationAddressesForMatch() []netip.Addr { + if c.DestinationAddressMatchFromResponse { + return DNSResponseAddresses(c.DNSResponse) + } + return c.DestinationAddresses +} + +func DNSResponseAddresses(response *dns.Msg) []netip.Addr { + if response == nil || response.Rcode != dns.RcodeSuccess { + return nil + } + var addresses []netip.Addr + for _, rawRecord := range response.Answer { + switch record := rawRecord.(type) { + case *dns.A: + addresses = append(addresses, M.AddrFromIP(record.A)) + case *dns.AAAA: + addresses = append(addresses, M.AddrFromIP(record.AAAA)) + } + } + return addresses +} + type inboundContextKey struct{} func WithContext(ctx context.Context, inboundContext *InboundContext) context.Context { diff --git a/dns/router.go b/dns/router.go index 19232f288..f07d897c2 100644 --- a/dns/router.go +++ b/dns/router.go @@ -285,7 +285,6 @@ func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, optio for currentRuleIndex, currentRule := range r.rules { metadata.ResetRuleCache() metadata.DNSResponse = savedResponse - metadata.DestinationAddresses = MessageToAddresses(savedResponse) if !currentRule.Match(metadata) { continue } @@ -303,6 +302,7 @@ func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, optio if transport == nil { r.logger.ErrorContext(ctx, "transport not found: ", action.Server) } + savedResponse = nil continue } if queryOptions.Strategy == C.DomainStrategyAsIS { diff --git a/dns/router_test.go b/dns/router_test.go index c0def7bf2..1fbb4c977 100644 --- a/dns/router_test.go +++ b/dns/router_test.go @@ -62,7 +62,8 @@ func (m *fakeDNSTransportManager) Create(context.Context, log.ContextLogger, str } type fakeDNSClient struct { - exchange func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) + beforeExchange func(ctx context.Context, transport adapter.DNSTransport, message *mDNS.Msg) + exchange func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) } type fakeDeprecatedManager struct { @@ -75,7 +76,10 @@ func (m *fakeDeprecatedManager) ReportDeprecated(feature deprecated.Note) { func (c *fakeDNSClient) Start() {} -func (c *fakeDNSClient) Exchange(_ context.Context, transport adapter.DNSTransport, message *mDNS.Msg, _ adapter.DNSQueryOptions, _ func([]netip.Addr) bool) (*mDNS.Msg, error) { +func (c *fakeDNSClient) Exchange(ctx context.Context, transport adapter.DNSTransport, message *mDNS.Msg, _ adapter.DNSQueryOptions, _ func([]netip.Addr) bool) (*mDNS.Msg, error) { + if c.beforeExchange != nil { + c.beforeExchange(ctx, transport, message) + } return c.exchange(transport, message) } @@ -255,6 +259,150 @@ func TestExchangeNewModeEvaluateMatchResponseRouteIgnoresTTL(t *testing.T) { require.Equal(t, []netip.Addr{netip.MustParseAddr("8.8.8.8")}, MessageToAddresses(response)) } +func TestExchangeNewModeEvaluateDoesNotLeakAddressesToNextQuery(t *testing.T) { + t.Parallel() + + transportManager := &fakeDNSTransportManager{ + defaultTransport: &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}, + transports: map[string]adapter.DNSTransport{ + "upstream": &fakeDNSTransport{tag: "upstream", transportType: C.DNSTypeUDP}, + "selected": &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP}, + "default": &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}, + }, + } + var inspectedSelected bool + client := &fakeDNSClient{ + beforeExchange: func(ctx context.Context, transport adapter.DNSTransport, message *mDNS.Msg) { + if transport.Tag() != "selected" { + return + } + inspectedSelected = true + metadata := adapter.ContextFrom(ctx) + require.NotNil(t, metadata) + require.Empty(t, metadata.DestinationAddresses) + require.NotNil(t, metadata.DNSResponse) + }, + exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) { + switch transport.Tag() { + case "upstream": + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("1.1.1.1")}, 60), nil + case "selected": + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60), nil + default: + return nil, errors.New("unexpected transport") + } + }, + } + rules := []option.DNSRule{ + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + Domain: badoption.Listable[string]{"example.com"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeEvaluate, + RouteOptions: option.DNSRouteActionOptions{Server: "upstream"}, + }, + }, + }, + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + MatchResponse: true, + ResponseAnswer: badoption.Listable[option.DNSRecordOptions]{mustRecord(t, "example.com. IN A 1.1.1.1")}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "selected"}, + }, + }, + }, + } + router := newTestRouter(t, rules, transportManager, client) + + response, err := router.Exchange(context.Background(), &mDNS.Msg{ + Question: []mDNS.Question{fixedQuestion("example.com", mDNS.TypeA)}, + }, adapter.DNSQueryOptions{}) + require.NoError(t, err) + require.True(t, inspectedSelected) + require.Equal(t, []netip.Addr{netip.MustParseAddr("8.8.8.8")}, MessageToAddresses(response)) +} + +func TestExchangeNewModeEvaluateRouteResolutionFailureClearsResponse(t *testing.T) { + t.Parallel() + + transportManager := &fakeDNSTransportManager{ + defaultTransport: &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}, + transports: map[string]adapter.DNSTransport{ + "upstream": &fakeDNSTransport{tag: "upstream", transportType: C.DNSTypeUDP}, + "selected": &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP}, + "default": &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}, + }, + } + client := &fakeDNSClient{ + exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) { + switch transport.Tag() { + case "upstream": + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("1.1.1.1")}, 60), nil + case "selected": + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60), nil + case "default": + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("4.4.4.4")}, 60), nil + default: + return nil, errors.New("unexpected transport") + } + }, + } + rules := []option.DNSRule{ + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + Domain: badoption.Listable[string]{"example.com"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeEvaluate, + RouteOptions: option.DNSRouteActionOptions{Server: "upstream"}, + }, + }, + }, + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + Domain: badoption.Listable[string]{"example.com"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeEvaluate, + RouteOptions: option.DNSRouteActionOptions{Server: "missing"}, + }, + }, + }, + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + MatchResponse: true, + ResponseAnswer: badoption.Listable[option.DNSRecordOptions]{mustRecord(t, "example.com. IN A 1.1.1.1")}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "selected"}, + }, + }, + }, + } + router := newTestRouter(t, rules, transportManager, client) + + response, err := router.Exchange(context.Background(), &mDNS.Msg{ + Question: []mDNS.Question{fixedQuestion("example.com", mDNS.TypeA)}, + }, adapter.DNSQueryOptions{}) + require.NoError(t, err) + require.Equal(t, []netip.Addr{netip.MustParseAddr("4.4.4.4")}, MessageToAddresses(response)) +} + func TestLookupNewModeAllowsPartialSuccess(t *testing.T) { t.Parallel() diff --git a/route/rule/rule_dns.go b/route/rule/rule_dns.go index 4d6636dc5..f5914f89e 100644 --- a/route/rule/rule_dns.go +++ b/route/rule/rule_dns.go @@ -358,7 +358,9 @@ func (r *DefaultDNSRule) matchStatesForMatch(metadata *adapter.InboundContext) r if metadata.DNSResponse == nil { return 0 } - return r.abstractDefaultRule.matchStates(metadata) + matchMetadata := *metadata + matchMetadata.DestinationAddressMatchFromResponse = true + return r.abstractDefaultRule.matchStates(&matchMetadata) } matchMetadata := *metadata matchMetadata.IgnoreDestinationIPCIDRMatch = true diff --git a/route/rule/rule_item_cidr.go b/route/rule/rule_item_cidr.go index c823dcf30..321b6e5d2 100644 --- a/route/rule/rule_item_cidr.go +++ b/route/rule/rule_item_cidr.go @@ -76,11 +76,20 @@ func (r *IPCIDRItem) Match(metadata *adapter.InboundContext) bool { if r.isSource || metadata.IPCIDRMatchSource { return r.ipSet.Contains(metadata.Source.Addr) } + if metadata.DestinationAddressMatchFromResponse { + for _, address := range metadata.DestinationAddressesForMatch() { + if r.ipSet.Contains(address) { + return true + } + } + return metadata.IPCIDRAcceptEmpty + } if metadata.Destination.IsIP() { return r.ipSet.Contains(metadata.Destination.Addr) } - if len(metadata.DestinationAddresses) > 0 { - for _, address := range metadata.DestinationAddresses { + addresses := metadata.DestinationAddressesForMatch() + if len(addresses) > 0 { + for _, address := range addresses { if r.ipSet.Contains(address) { return true } diff --git a/route/rule/rule_item_ip_is_private.go b/route/rule/rule_item_ip_is_private.go index e185db1db..b71255bae 100644 --- a/route/rule/rule_item_ip_is_private.go +++ b/route/rule/rule_item_ip_is_private.go @@ -1,8 +1,6 @@ package rule import ( - "net/netip" - "github.com/sagernet/sing-box/adapter" N "github.com/sagernet/sing/common/network" ) @@ -18,21 +16,24 @@ func NewIPIsPrivateItem(isSource bool) *IPIsPrivateItem { } func (r *IPIsPrivateItem) Match(metadata *adapter.InboundContext) bool { - var destination netip.Addr if r.isSource { - destination = metadata.Source.Addr - } else { - destination = metadata.Destination.Addr + return !N.IsPublicAddr(metadata.Source.Addr) } - if destination.IsValid() { - return !N.IsPublicAddr(destination) - } - if !r.isSource { - for _, destinationAddress := range metadata.DestinationAddresses { + if metadata.DestinationAddressMatchFromResponse { + for _, destinationAddress := range metadata.DestinationAddressesForMatch() { if !N.IsPublicAddr(destinationAddress) { return true } } + return false + } + if metadata.Destination.Addr.IsValid() { + return !N.IsPublicAddr(metadata.Destination.Addr) + } + for _, destinationAddress := range metadata.DestinationAddressesForMatch() { + if !N.IsPublicAddr(destinationAddress) { + return true + } } return false } diff --git a/route/rule/rule_set_semantics_test.go b/route/rule/rule_set_semantics_test.go index a01defe6e..17a7da9d5 100644 --- a/route/rule/rule_set_semantics_test.go +++ b/route/rule/rule_set_semantics_test.go @@ -2,6 +2,7 @@ package rule import ( "context" + "net" "net/netip" "strings" "testing" @@ -14,6 +15,7 @@ import ( M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" + mDNS "github.com/miekg/dns" "github.com/stretchr/testify/require" ) @@ -616,6 +618,27 @@ func TestDNSRuleSetSemantics(t *testing.T) { }) } +func TestDNSMatchResponseRuleSetDestinationCIDRUsesDNSResponse(t *testing.T) { + t.Parallel() + + ruleSet := newLocalRuleSetForTest("dns-response-ipcidr", headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addDestinationIPCIDRItem(t, rule, []string{"203.0.113.0/24"}) + })) + rule := dnsRuleForTest(func(rule *abstractDefaultRule) { + addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{ruleSet}}) + }) + rule.matchResponse = true + + matchedMetadata := testMetadata("lookup.example") + matchedMetadata.DNSResponse = dnsResponseForTest(netip.MustParseAddr("203.0.113.1")) + require.True(t, rule.Match(&matchedMetadata)) + require.Empty(t, matchedMetadata.DestinationAddresses) + + unmatchedMetadata := testMetadata("lookup.example") + unmatchedMetadata.DNSResponse = dnsResponseForTest(netip.MustParseAddr("8.8.8.8")) + require.False(t, rule.Match(&unmatchedMetadata)) +} + func TestDNSInvertAddressLimitPreLookupRegression(t *testing.T) { t.Parallel() testCases := []struct { @@ -763,6 +786,39 @@ func testMetadata(domain string) adapter.InboundContext { } } +func dnsResponseForTest(addresses ...netip.Addr) *mDNS.Msg { + response := &mDNS.Msg{ + MsgHdr: mDNS.MsgHdr{ + Response: true, + Rcode: mDNS.RcodeSuccess, + }, + } + for _, address := range addresses { + if address.Is4() { + response.Answer = append(response.Answer, &mDNS.A{ + Hdr: mDNS.RR_Header{ + Name: mDNS.Fqdn("lookup.example"), + Rrtype: mDNS.TypeA, + Class: mDNS.ClassINET, + Ttl: 60, + }, + A: net.IP(append([]byte(nil), address.AsSlice()...)), + }) + } else { + response.Answer = append(response.Answer, &mDNS.AAAA{ + Hdr: mDNS.RR_Header{ + Name: mDNS.Fqdn("lookup.example"), + Rrtype: mDNS.TypeAAAA, + Class: mDNS.ClassINET, + Ttl: 60, + }, + AAAA: net.IP(append([]byte(nil), address.AsSlice()...)), + }) + } + } + return response +} + func addRuleSetItem(rule *abstractDefaultRule, item *RuleSetItem) { rule.ruleSetItem = item rule.allItems = append(rule.allItems, item)