mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-13 20:28:32 +10:00
Fix DNS match_response response address handling
This commit is contained in:
@@ -81,15 +81,16 @@ type InboundContext struct {
|
||||
FallbackNetworkType []C.InterfaceType
|
||||
FallbackDelay time.Duration
|
||||
|
||||
DestinationAddresses []netip.Addr
|
||||
DNSResponse *dns.Msg
|
||||
SourceGeoIPCode string
|
||||
GeoIPCode string
|
||||
ProcessInfo *ConnectionOwner
|
||||
SourceMACAddress net.HardwareAddr
|
||||
SourceHostname string
|
||||
QueryType uint16
|
||||
FakeIP bool
|
||||
DestinationAddresses []netip.Addr
|
||||
DNSResponse *dns.Msg
|
||||
DestinationAddressMatchFromResponse bool
|
||||
SourceGeoIPCode string
|
||||
GeoIPCode string
|
||||
ProcessInfo *ConnectionOwner
|
||||
SourceMACAddress net.HardwareAddr
|
||||
SourceHostname string
|
||||
QueryType uint16
|
||||
FakeIP bool
|
||||
|
||||
// rule cache
|
||||
|
||||
@@ -118,6 +119,29 @@ func (c *InboundContext) ResetRuleMatchCache() {
|
||||
c.DidMatch = false
|
||||
}
|
||||
|
||||
func (c *InboundContext) DestinationAddressesForMatch() []netip.Addr {
|
||||
if c.DestinationAddressMatchFromResponse {
|
||||
return DNSResponseAddresses(c.DNSResponse)
|
||||
}
|
||||
return c.DestinationAddresses
|
||||
}
|
||||
|
||||
func DNSResponseAddresses(response *dns.Msg) []netip.Addr {
|
||||
if response == nil || response.Rcode != dns.RcodeSuccess {
|
||||
return nil
|
||||
}
|
||||
var addresses []netip.Addr
|
||||
for _, rawRecord := range response.Answer {
|
||||
switch record := rawRecord.(type) {
|
||||
case *dns.A:
|
||||
addresses = append(addresses, M.AddrFromIP(record.A))
|
||||
case *dns.AAAA:
|
||||
addresses = append(addresses, M.AddrFromIP(record.AAAA))
|
||||
}
|
||||
}
|
||||
return addresses
|
||||
}
|
||||
|
||||
type inboundContextKey struct{}
|
||||
|
||||
func WithContext(ctx context.Context, inboundContext *InboundContext) context.Context {
|
||||
|
||||
@@ -285,7 +285,6 @@ func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, optio
|
||||
for currentRuleIndex, currentRule := range r.rules {
|
||||
metadata.ResetRuleCache()
|
||||
metadata.DNSResponse = savedResponse
|
||||
metadata.DestinationAddresses = MessageToAddresses(savedResponse)
|
||||
if !currentRule.Match(metadata) {
|
||||
continue
|
||||
}
|
||||
@@ -303,6 +302,7 @@ func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, optio
|
||||
if transport == nil {
|
||||
r.logger.ErrorContext(ctx, "transport not found: ", action.Server)
|
||||
}
|
||||
savedResponse = nil
|
||||
continue
|
||||
}
|
||||
if queryOptions.Strategy == C.DomainStrategyAsIS {
|
||||
|
||||
@@ -62,7 +62,8 @@ func (m *fakeDNSTransportManager) Create(context.Context, log.ContextLogger, str
|
||||
}
|
||||
|
||||
type fakeDNSClient struct {
|
||||
exchange func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error)
|
||||
beforeExchange func(ctx context.Context, transport adapter.DNSTransport, message *mDNS.Msg)
|
||||
exchange func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error)
|
||||
}
|
||||
|
||||
type fakeDeprecatedManager struct {
|
||||
@@ -75,7 +76,10 @@ func (m *fakeDeprecatedManager) ReportDeprecated(feature deprecated.Note) {
|
||||
|
||||
func (c *fakeDNSClient) Start() {}
|
||||
|
||||
func (c *fakeDNSClient) Exchange(_ context.Context, transport adapter.DNSTransport, message *mDNS.Msg, _ adapter.DNSQueryOptions, _ func([]netip.Addr) bool) (*mDNS.Msg, error) {
|
||||
func (c *fakeDNSClient) Exchange(ctx context.Context, transport adapter.DNSTransport, message *mDNS.Msg, _ adapter.DNSQueryOptions, _ func([]netip.Addr) bool) (*mDNS.Msg, error) {
|
||||
if c.beforeExchange != nil {
|
||||
c.beforeExchange(ctx, transport, message)
|
||||
}
|
||||
return c.exchange(transport, message)
|
||||
}
|
||||
|
||||
@@ -255,6 +259,150 @@ func TestExchangeNewModeEvaluateMatchResponseRouteIgnoresTTL(t *testing.T) {
|
||||
require.Equal(t, []netip.Addr{netip.MustParseAddr("8.8.8.8")}, MessageToAddresses(response))
|
||||
}
|
||||
|
||||
func TestExchangeNewModeEvaluateDoesNotLeakAddressesToNextQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
transportManager := &fakeDNSTransportManager{
|
||||
defaultTransport: &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP},
|
||||
transports: map[string]adapter.DNSTransport{
|
||||
"upstream": &fakeDNSTransport{tag: "upstream", transportType: C.DNSTypeUDP},
|
||||
"selected": &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP},
|
||||
"default": &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP},
|
||||
},
|
||||
}
|
||||
var inspectedSelected bool
|
||||
client := &fakeDNSClient{
|
||||
beforeExchange: func(ctx context.Context, transport adapter.DNSTransport, message *mDNS.Msg) {
|
||||
if transport.Tag() != "selected" {
|
||||
return
|
||||
}
|
||||
inspectedSelected = true
|
||||
metadata := adapter.ContextFrom(ctx)
|
||||
require.NotNil(t, metadata)
|
||||
require.Empty(t, metadata.DestinationAddresses)
|
||||
require.NotNil(t, metadata.DNSResponse)
|
||||
},
|
||||
exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||||
switch transport.Tag() {
|
||||
case "upstream":
|
||||
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("1.1.1.1")}, 60), nil
|
||||
case "selected":
|
||||
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60), nil
|
||||
default:
|
||||
return nil, errors.New("unexpected transport")
|
||||
}
|
||||
},
|
||||
}
|
||||
rules := []option.DNSRule{
|
||||
{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||
Domain: badoption.Listable[string]{"example.com"},
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeEvaluate,
|
||||
RouteOptions: option.DNSRouteActionOptions{Server: "upstream"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||
MatchResponse: true,
|
||||
ResponseAnswer: badoption.Listable[option.DNSRecordOptions]{mustRecord(t, "example.com. IN A 1.1.1.1")},
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeRoute,
|
||||
RouteOptions: option.DNSRouteActionOptions{Server: "selected"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
router := newTestRouter(t, rules, transportManager, client)
|
||||
|
||||
response, err := router.Exchange(context.Background(), &mDNS.Msg{
|
||||
Question: []mDNS.Question{fixedQuestion("example.com", mDNS.TypeA)},
|
||||
}, adapter.DNSQueryOptions{})
|
||||
require.NoError(t, err)
|
||||
require.True(t, inspectedSelected)
|
||||
require.Equal(t, []netip.Addr{netip.MustParseAddr("8.8.8.8")}, MessageToAddresses(response))
|
||||
}
|
||||
|
||||
func TestExchangeNewModeEvaluateRouteResolutionFailureClearsResponse(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
transportManager := &fakeDNSTransportManager{
|
||||
defaultTransport: &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP},
|
||||
transports: map[string]adapter.DNSTransport{
|
||||
"upstream": &fakeDNSTransport{tag: "upstream", transportType: C.DNSTypeUDP},
|
||||
"selected": &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP},
|
||||
"default": &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP},
|
||||
},
|
||||
}
|
||||
client := &fakeDNSClient{
|
||||
exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||||
switch transport.Tag() {
|
||||
case "upstream":
|
||||
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("1.1.1.1")}, 60), nil
|
||||
case "selected":
|
||||
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60), nil
|
||||
case "default":
|
||||
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("4.4.4.4")}, 60), nil
|
||||
default:
|
||||
return nil, errors.New("unexpected transport")
|
||||
}
|
||||
},
|
||||
}
|
||||
rules := []option.DNSRule{
|
||||
{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||
Domain: badoption.Listable[string]{"example.com"},
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeEvaluate,
|
||||
RouteOptions: option.DNSRouteActionOptions{Server: "upstream"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||
Domain: badoption.Listable[string]{"example.com"},
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeEvaluate,
|
||||
RouteOptions: option.DNSRouteActionOptions{Server: "missing"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||
MatchResponse: true,
|
||||
ResponseAnswer: badoption.Listable[option.DNSRecordOptions]{mustRecord(t, "example.com. IN A 1.1.1.1")},
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeRoute,
|
||||
RouteOptions: option.DNSRouteActionOptions{Server: "selected"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
router := newTestRouter(t, rules, transportManager, client)
|
||||
|
||||
response, err := router.Exchange(context.Background(), &mDNS.Msg{
|
||||
Question: []mDNS.Question{fixedQuestion("example.com", mDNS.TypeA)},
|
||||
}, adapter.DNSQueryOptions{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []netip.Addr{netip.MustParseAddr("4.4.4.4")}, MessageToAddresses(response))
|
||||
}
|
||||
|
||||
func TestLookupNewModeAllowsPartialSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -358,7 +358,9 @@ func (r *DefaultDNSRule) matchStatesForMatch(metadata *adapter.InboundContext) r
|
||||
if metadata.DNSResponse == nil {
|
||||
return 0
|
||||
}
|
||||
return r.abstractDefaultRule.matchStates(metadata)
|
||||
matchMetadata := *metadata
|
||||
matchMetadata.DestinationAddressMatchFromResponse = true
|
||||
return r.abstractDefaultRule.matchStates(&matchMetadata)
|
||||
}
|
||||
matchMetadata := *metadata
|
||||
matchMetadata.IgnoreDestinationIPCIDRMatch = true
|
||||
|
||||
@@ -76,11 +76,20 @@ func (r *IPCIDRItem) Match(metadata *adapter.InboundContext) bool {
|
||||
if r.isSource || metadata.IPCIDRMatchSource {
|
||||
return r.ipSet.Contains(metadata.Source.Addr)
|
||||
}
|
||||
if metadata.DestinationAddressMatchFromResponse {
|
||||
for _, address := range metadata.DestinationAddressesForMatch() {
|
||||
if r.ipSet.Contains(address) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return metadata.IPCIDRAcceptEmpty
|
||||
}
|
||||
if metadata.Destination.IsIP() {
|
||||
return r.ipSet.Contains(metadata.Destination.Addr)
|
||||
}
|
||||
if len(metadata.DestinationAddresses) > 0 {
|
||||
for _, address := range metadata.DestinationAddresses {
|
||||
addresses := metadata.DestinationAddressesForMatch()
|
||||
if len(addresses) > 0 {
|
||||
for _, address := range addresses {
|
||||
if r.ipSet.Contains(address) {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package rule
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
@@ -18,21 +16,24 @@ func NewIPIsPrivateItem(isSource bool) *IPIsPrivateItem {
|
||||
}
|
||||
|
||||
func (r *IPIsPrivateItem) Match(metadata *adapter.InboundContext) bool {
|
||||
var destination netip.Addr
|
||||
if r.isSource {
|
||||
destination = metadata.Source.Addr
|
||||
} else {
|
||||
destination = metadata.Destination.Addr
|
||||
return !N.IsPublicAddr(metadata.Source.Addr)
|
||||
}
|
||||
if destination.IsValid() {
|
||||
return !N.IsPublicAddr(destination)
|
||||
}
|
||||
if !r.isSource {
|
||||
for _, destinationAddress := range metadata.DestinationAddresses {
|
||||
if metadata.DestinationAddressMatchFromResponse {
|
||||
for _, destinationAddress := range metadata.DestinationAddressesForMatch() {
|
||||
if !N.IsPublicAddr(destinationAddress) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
if metadata.Destination.Addr.IsValid() {
|
||||
return !N.IsPublicAddr(metadata.Destination.Addr)
|
||||
}
|
||||
for _, destinationAddress := range metadata.DestinationAddressesForMatch() {
|
||||
if !N.IsPublicAddr(destinationAddress) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package rule
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -14,6 +15,7 @@ import (
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
mDNS "github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -616,6 +618,27 @@ func TestDNSRuleSetSemantics(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSMatchResponseRuleSetDestinationCIDRUsesDNSResponse(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ruleSet := newLocalRuleSetForTest("dns-response-ipcidr", headlessDefaultRule(t, func(rule *abstractDefaultRule) {
|
||||
addDestinationIPCIDRItem(t, rule, []string{"203.0.113.0/24"})
|
||||
}))
|
||||
rule := dnsRuleForTest(func(rule *abstractDefaultRule) {
|
||||
addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{ruleSet}})
|
||||
})
|
||||
rule.matchResponse = true
|
||||
|
||||
matchedMetadata := testMetadata("lookup.example")
|
||||
matchedMetadata.DNSResponse = dnsResponseForTest(netip.MustParseAddr("203.0.113.1"))
|
||||
require.True(t, rule.Match(&matchedMetadata))
|
||||
require.Empty(t, matchedMetadata.DestinationAddresses)
|
||||
|
||||
unmatchedMetadata := testMetadata("lookup.example")
|
||||
unmatchedMetadata.DNSResponse = dnsResponseForTest(netip.MustParseAddr("8.8.8.8"))
|
||||
require.False(t, rule.Match(&unmatchedMetadata))
|
||||
}
|
||||
|
||||
func TestDNSInvertAddressLimitPreLookupRegression(t *testing.T) {
|
||||
t.Parallel()
|
||||
testCases := []struct {
|
||||
@@ -763,6 +786,39 @@ func testMetadata(domain string) adapter.InboundContext {
|
||||
}
|
||||
}
|
||||
|
||||
func dnsResponseForTest(addresses ...netip.Addr) *mDNS.Msg {
|
||||
response := &mDNS.Msg{
|
||||
MsgHdr: mDNS.MsgHdr{
|
||||
Response: true,
|
||||
Rcode: mDNS.RcodeSuccess,
|
||||
},
|
||||
}
|
||||
for _, address := range addresses {
|
||||
if address.Is4() {
|
||||
response.Answer = append(response.Answer, &mDNS.A{
|
||||
Hdr: mDNS.RR_Header{
|
||||
Name: mDNS.Fqdn("lookup.example"),
|
||||
Rrtype: mDNS.TypeA,
|
||||
Class: mDNS.ClassINET,
|
||||
Ttl: 60,
|
||||
},
|
||||
A: net.IP(append([]byte(nil), address.AsSlice()...)),
|
||||
})
|
||||
} else {
|
||||
response.Answer = append(response.Answer, &mDNS.AAAA{
|
||||
Hdr: mDNS.RR_Header{
|
||||
Name: mDNS.Fqdn("lookup.example"),
|
||||
Rrtype: mDNS.TypeAAAA,
|
||||
Class: mDNS.ClassINET,
|
||||
Ttl: 60,
|
||||
},
|
||||
AAAA: net.IP(append([]byte(nil), address.AsSlice()...)),
|
||||
})
|
||||
}
|
||||
}
|
||||
return response
|
||||
}
|
||||
|
||||
func addRuleSetItem(rule *abstractDefaultRule, item *RuleSetItem) {
|
||||
rule.ruleSetItem = item
|
||||
rule.allItems = append(rule.allItems, item)
|
||||
|
||||
Reference in New Issue
Block a user