Fix DNS match_response response address handling

This commit is contained in:
世界
2026-03-24 20:10:49 +08:00
parent 9b3415f7fc
commit 85fa414474
7 changed files with 266 additions and 26 deletions

View File

@@ -81,15 +81,16 @@ type InboundContext struct {
FallbackNetworkType []C.InterfaceType
FallbackDelay time.Duration
DestinationAddresses []netip.Addr
DNSResponse *dns.Msg
SourceGeoIPCode string
GeoIPCode string
ProcessInfo *ConnectionOwner
SourceMACAddress net.HardwareAddr
SourceHostname string
QueryType uint16
FakeIP bool
DestinationAddresses []netip.Addr
DNSResponse *dns.Msg
DestinationAddressMatchFromResponse bool
SourceGeoIPCode string
GeoIPCode string
ProcessInfo *ConnectionOwner
SourceMACAddress net.HardwareAddr
SourceHostname string
QueryType uint16
FakeIP bool
// rule cache
@@ -118,6 +119,29 @@ func (c *InboundContext) ResetRuleMatchCache() {
c.DidMatch = false
}
func (c *InboundContext) DestinationAddressesForMatch() []netip.Addr {
if c.DestinationAddressMatchFromResponse {
return DNSResponseAddresses(c.DNSResponse)
}
return c.DestinationAddresses
}
func DNSResponseAddresses(response *dns.Msg) []netip.Addr {
if response == nil || response.Rcode != dns.RcodeSuccess {
return nil
}
var addresses []netip.Addr
for _, rawRecord := range response.Answer {
switch record := rawRecord.(type) {
case *dns.A:
addresses = append(addresses, M.AddrFromIP(record.A))
case *dns.AAAA:
addresses = append(addresses, M.AddrFromIP(record.AAAA))
}
}
return addresses
}
type inboundContextKey struct{}
func WithContext(ctx context.Context, inboundContext *InboundContext) context.Context {

View File

@@ -285,7 +285,6 @@ func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, optio
for currentRuleIndex, currentRule := range r.rules {
metadata.ResetRuleCache()
metadata.DNSResponse = savedResponse
metadata.DestinationAddresses = MessageToAddresses(savedResponse)
if !currentRule.Match(metadata) {
continue
}
@@ -303,6 +302,7 @@ func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, optio
if transport == nil {
r.logger.ErrorContext(ctx, "transport not found: ", action.Server)
}
savedResponse = nil
continue
}
if queryOptions.Strategy == C.DomainStrategyAsIS {

View File

@@ -62,7 +62,8 @@ func (m *fakeDNSTransportManager) Create(context.Context, log.ContextLogger, str
}
type fakeDNSClient struct {
exchange func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error)
beforeExchange func(ctx context.Context, transport adapter.DNSTransport, message *mDNS.Msg)
exchange func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error)
}
type fakeDeprecatedManager struct {
@@ -75,7 +76,10 @@ func (m *fakeDeprecatedManager) ReportDeprecated(feature deprecated.Note) {
func (c *fakeDNSClient) Start() {}
func (c *fakeDNSClient) Exchange(_ context.Context, transport adapter.DNSTransport, message *mDNS.Msg, _ adapter.DNSQueryOptions, _ func([]netip.Addr) bool) (*mDNS.Msg, error) {
func (c *fakeDNSClient) Exchange(ctx context.Context, transport adapter.DNSTransport, message *mDNS.Msg, _ adapter.DNSQueryOptions, _ func([]netip.Addr) bool) (*mDNS.Msg, error) {
if c.beforeExchange != nil {
c.beforeExchange(ctx, transport, message)
}
return c.exchange(transport, message)
}
@@ -255,6 +259,150 @@ func TestExchangeNewModeEvaluateMatchResponseRouteIgnoresTTL(t *testing.T) {
require.Equal(t, []netip.Addr{netip.MustParseAddr("8.8.8.8")}, MessageToAddresses(response))
}
func TestExchangeNewModeEvaluateDoesNotLeakAddressesToNextQuery(t *testing.T) {
t.Parallel()
transportManager := &fakeDNSTransportManager{
defaultTransport: &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP},
transports: map[string]adapter.DNSTransport{
"upstream": &fakeDNSTransport{tag: "upstream", transportType: C.DNSTypeUDP},
"selected": &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP},
"default": &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP},
},
}
var inspectedSelected bool
client := &fakeDNSClient{
beforeExchange: func(ctx context.Context, transport adapter.DNSTransport, message *mDNS.Msg) {
if transport.Tag() != "selected" {
return
}
inspectedSelected = true
metadata := adapter.ContextFrom(ctx)
require.NotNil(t, metadata)
require.Empty(t, metadata.DestinationAddresses)
require.NotNil(t, metadata.DNSResponse)
},
exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) {
switch transport.Tag() {
case "upstream":
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("1.1.1.1")}, 60), nil
case "selected":
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60), nil
default:
return nil, errors.New("unexpected transport")
}
},
}
rules := []option.DNSRule{
{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
RawDefaultDNSRule: option.RawDefaultDNSRule{
Domain: badoption.Listable[string]{"example.com"},
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeEvaluate,
RouteOptions: option.DNSRouteActionOptions{Server: "upstream"},
},
},
},
{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
RawDefaultDNSRule: option.RawDefaultDNSRule{
MatchResponse: true,
ResponseAnswer: badoption.Listable[option.DNSRecordOptions]{mustRecord(t, "example.com. IN A 1.1.1.1")},
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeRoute,
RouteOptions: option.DNSRouteActionOptions{Server: "selected"},
},
},
},
}
router := newTestRouter(t, rules, transportManager, client)
response, err := router.Exchange(context.Background(), &mDNS.Msg{
Question: []mDNS.Question{fixedQuestion("example.com", mDNS.TypeA)},
}, adapter.DNSQueryOptions{})
require.NoError(t, err)
require.True(t, inspectedSelected)
require.Equal(t, []netip.Addr{netip.MustParseAddr("8.8.8.8")}, MessageToAddresses(response))
}
func TestExchangeNewModeEvaluateRouteResolutionFailureClearsResponse(t *testing.T) {
t.Parallel()
transportManager := &fakeDNSTransportManager{
defaultTransport: &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP},
transports: map[string]adapter.DNSTransport{
"upstream": &fakeDNSTransport{tag: "upstream", transportType: C.DNSTypeUDP},
"selected": &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP},
"default": &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP},
},
}
client := &fakeDNSClient{
exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) {
switch transport.Tag() {
case "upstream":
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("1.1.1.1")}, 60), nil
case "selected":
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60), nil
case "default":
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("4.4.4.4")}, 60), nil
default:
return nil, errors.New("unexpected transport")
}
},
}
rules := []option.DNSRule{
{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
RawDefaultDNSRule: option.RawDefaultDNSRule{
Domain: badoption.Listable[string]{"example.com"},
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeEvaluate,
RouteOptions: option.DNSRouteActionOptions{Server: "upstream"},
},
},
},
{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
RawDefaultDNSRule: option.RawDefaultDNSRule{
Domain: badoption.Listable[string]{"example.com"},
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeEvaluate,
RouteOptions: option.DNSRouteActionOptions{Server: "missing"},
},
},
},
{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
RawDefaultDNSRule: option.RawDefaultDNSRule{
MatchResponse: true,
ResponseAnswer: badoption.Listable[option.DNSRecordOptions]{mustRecord(t, "example.com. IN A 1.1.1.1")},
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeRoute,
RouteOptions: option.DNSRouteActionOptions{Server: "selected"},
},
},
},
}
router := newTestRouter(t, rules, transportManager, client)
response, err := router.Exchange(context.Background(), &mDNS.Msg{
Question: []mDNS.Question{fixedQuestion("example.com", mDNS.TypeA)},
}, adapter.DNSQueryOptions{})
require.NoError(t, err)
require.Equal(t, []netip.Addr{netip.MustParseAddr("4.4.4.4")}, MessageToAddresses(response))
}
func TestLookupNewModeAllowsPartialSuccess(t *testing.T) {
t.Parallel()

View File

@@ -358,7 +358,9 @@ func (r *DefaultDNSRule) matchStatesForMatch(metadata *adapter.InboundContext) r
if metadata.DNSResponse == nil {
return 0
}
return r.abstractDefaultRule.matchStates(metadata)
matchMetadata := *metadata
matchMetadata.DestinationAddressMatchFromResponse = true
return r.abstractDefaultRule.matchStates(&matchMetadata)
}
matchMetadata := *metadata
matchMetadata.IgnoreDestinationIPCIDRMatch = true

View File

@@ -76,11 +76,20 @@ func (r *IPCIDRItem) Match(metadata *adapter.InboundContext) bool {
if r.isSource || metadata.IPCIDRMatchSource {
return r.ipSet.Contains(metadata.Source.Addr)
}
if metadata.DestinationAddressMatchFromResponse {
for _, address := range metadata.DestinationAddressesForMatch() {
if r.ipSet.Contains(address) {
return true
}
}
return metadata.IPCIDRAcceptEmpty
}
if metadata.Destination.IsIP() {
return r.ipSet.Contains(metadata.Destination.Addr)
}
if len(metadata.DestinationAddresses) > 0 {
for _, address := range metadata.DestinationAddresses {
addresses := metadata.DestinationAddressesForMatch()
if len(addresses) > 0 {
for _, address := range addresses {
if r.ipSet.Contains(address) {
return true
}

View File

@@ -1,8 +1,6 @@
package rule
import (
"net/netip"
"github.com/sagernet/sing-box/adapter"
N "github.com/sagernet/sing/common/network"
)
@@ -18,21 +16,24 @@ func NewIPIsPrivateItem(isSource bool) *IPIsPrivateItem {
}
func (r *IPIsPrivateItem) Match(metadata *adapter.InboundContext) bool {
var destination netip.Addr
if r.isSource {
destination = metadata.Source.Addr
} else {
destination = metadata.Destination.Addr
return !N.IsPublicAddr(metadata.Source.Addr)
}
if destination.IsValid() {
return !N.IsPublicAddr(destination)
}
if !r.isSource {
for _, destinationAddress := range metadata.DestinationAddresses {
if metadata.DestinationAddressMatchFromResponse {
for _, destinationAddress := range metadata.DestinationAddressesForMatch() {
if !N.IsPublicAddr(destinationAddress) {
return true
}
}
return false
}
if metadata.Destination.Addr.IsValid() {
return !N.IsPublicAddr(metadata.Destination.Addr)
}
for _, destinationAddress := range metadata.DestinationAddressesForMatch() {
if !N.IsPublicAddr(destinationAddress) {
return true
}
}
return false
}

View File

@@ -2,6 +2,7 @@ package rule
import (
"context"
"net"
"net/netip"
"strings"
"testing"
@@ -14,6 +15,7 @@ import (
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
mDNS "github.com/miekg/dns"
"github.com/stretchr/testify/require"
)
@@ -616,6 +618,27 @@ func TestDNSRuleSetSemantics(t *testing.T) {
})
}
func TestDNSMatchResponseRuleSetDestinationCIDRUsesDNSResponse(t *testing.T) {
t.Parallel()
ruleSet := newLocalRuleSetForTest("dns-response-ipcidr", headlessDefaultRule(t, func(rule *abstractDefaultRule) {
addDestinationIPCIDRItem(t, rule, []string{"203.0.113.0/24"})
}))
rule := dnsRuleForTest(func(rule *abstractDefaultRule) {
addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{ruleSet}})
})
rule.matchResponse = true
matchedMetadata := testMetadata("lookup.example")
matchedMetadata.DNSResponse = dnsResponseForTest(netip.MustParseAddr("203.0.113.1"))
require.True(t, rule.Match(&matchedMetadata))
require.Empty(t, matchedMetadata.DestinationAddresses)
unmatchedMetadata := testMetadata("lookup.example")
unmatchedMetadata.DNSResponse = dnsResponseForTest(netip.MustParseAddr("8.8.8.8"))
require.False(t, rule.Match(&unmatchedMetadata))
}
func TestDNSInvertAddressLimitPreLookupRegression(t *testing.T) {
t.Parallel()
testCases := []struct {
@@ -763,6 +786,39 @@ func testMetadata(domain string) adapter.InboundContext {
}
}
func dnsResponseForTest(addresses ...netip.Addr) *mDNS.Msg {
response := &mDNS.Msg{
MsgHdr: mDNS.MsgHdr{
Response: true,
Rcode: mDNS.RcodeSuccess,
},
}
for _, address := range addresses {
if address.Is4() {
response.Answer = append(response.Answer, &mDNS.A{
Hdr: mDNS.RR_Header{
Name: mDNS.Fqdn("lookup.example"),
Rrtype: mDNS.TypeA,
Class: mDNS.ClassINET,
Ttl: 60,
},
A: net.IP(append([]byte(nil), address.AsSlice()...)),
})
} else {
response.Answer = append(response.Answer, &mDNS.AAAA{
Hdr: mDNS.RR_Header{
Name: mDNS.Fqdn("lookup.example"),
Rrtype: mDNS.TypeAAAA,
Class: mDNS.ClassINET,
Ttl: 60,
},
AAAA: net.IP(append([]byte(nil), address.AsSlice()...)),
})
}
}
return response
}
func addRuleSetItem(rule *abstractDefaultRule, item *RuleSetItem) {
rule.ruleSetItem = item
rule.allItems = append(rule.allItems, item)