mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-11 17:47:20 +10:00
Add evaluate DNS rule action and related rule items
This commit is contained in:
272
option/dns.go
272
option/dns.go
@@ -3,19 +3,14 @@ package option
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/experimental/deprecated"
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
"github.com/sagernet/sing/common/json/badjson"
|
||||
"github.com/sagernet/sing/common/json/badoption"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/service"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type RawDNSOptions struct {
|
||||
@@ -26,80 +21,29 @@ type RawDNSOptions struct {
|
||||
DNSClientOptions
|
||||
}
|
||||
|
||||
type LegacyDNSOptions struct {
|
||||
FakeIP *LegacyDNSFakeIPOptions `json:"fakeip,omitempty"`
|
||||
}
|
||||
|
||||
type DNSOptions struct {
|
||||
RawDNSOptions
|
||||
LegacyDNSOptions
|
||||
}
|
||||
|
||||
type contextKeyDontUpgrade struct{}
|
||||
const (
|
||||
legacyDNSFakeIPRemovedMessage = "legacy DNS fakeip options are deprecated in sing-box 1.12.0 and removed in sing-box 1.14.0, checkout migration: https://sing-box.sagernet.org/migration/#migrate-to-new-dns-server-formats"
|
||||
legacyDNSServerRemovedMessage = "legacy DNS server formats are deprecated in sing-box 1.12.0 and removed in sing-box 1.14.0, checkout migration: https://sing-box.sagernet.org/migration/#migrate-to-new-dns-server-formats"
|
||||
)
|
||||
|
||||
func ContextWithDontUpgrade(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, (*contextKeyDontUpgrade)(nil), true)
|
||||
}
|
||||
|
||||
func dontUpgradeFromContext(ctx context.Context) bool {
|
||||
return ctx.Value((*contextKeyDontUpgrade)(nil)) == true
|
||||
type removedLegacyDNSOptions struct {
|
||||
FakeIP json.RawMessage `json:"fakeip,omitempty"`
|
||||
}
|
||||
|
||||
func (o *DNSOptions) UnmarshalJSONContext(ctx context.Context, content []byte) error {
|
||||
err := json.UnmarshalContext(ctx, content, &o.LegacyDNSOptions)
|
||||
var legacyOptions removedLegacyDNSOptions
|
||||
err := json.UnmarshalContext(ctx, content, &legacyOptions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dontUpgrade := dontUpgradeFromContext(ctx)
|
||||
legacyOptions := o.LegacyDNSOptions
|
||||
if !dontUpgrade {
|
||||
if o.FakeIP != nil && o.FakeIP.Enabled {
|
||||
deprecated.Report(ctx, deprecated.OptionLegacyDNSFakeIPOptions)
|
||||
ctx = context.WithValue(ctx, (*LegacyDNSFakeIPOptions)(nil), o.FakeIP)
|
||||
}
|
||||
o.LegacyDNSOptions = LegacyDNSOptions{}
|
||||
if len(legacyOptions.FakeIP) != 0 {
|
||||
return E.New(legacyDNSFakeIPRemovedMessage)
|
||||
}
|
||||
err = badjson.UnmarshallExcludedContext(ctx, content, legacyOptions, &o.RawDNSOptions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !dontUpgrade {
|
||||
rcodeMap := make(map[string]int)
|
||||
o.Servers = common.Filter(o.Servers, func(it DNSServerOptions) bool {
|
||||
if it.Type == C.DNSTypeLegacyRcode {
|
||||
rcodeMap[it.Tag] = it.Options.(int)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
if len(rcodeMap) > 0 {
|
||||
for i := 0; i < len(o.Rules); i++ {
|
||||
rewriteRcode(rcodeMap, &o.Rules[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func rewriteRcode(rcodeMap map[string]int, rule *DNSRule) {
|
||||
switch rule.Type {
|
||||
case C.RuleTypeDefault:
|
||||
rewriteRcodeAction(rcodeMap, &rule.DefaultOptions.DNSRuleAction)
|
||||
case C.RuleTypeLogical:
|
||||
rewriteRcodeAction(rcodeMap, &rule.LogicalOptions.DNSRuleAction)
|
||||
}
|
||||
}
|
||||
|
||||
func rewriteRcodeAction(rcodeMap map[string]int, ruleAction *DNSRuleAction) {
|
||||
if ruleAction.Action != C.RuleActionTypeRoute {
|
||||
return
|
||||
}
|
||||
rcode, loaded := rcodeMap[ruleAction.RouteOptions.Server]
|
||||
if !loaded {
|
||||
return
|
||||
}
|
||||
ruleAction.Action = C.RuleActionTypePredefined
|
||||
ruleAction.PredefinedOptions.Rcode = common.Ptr(DNSRCode(rcode))
|
||||
return badjson.UnmarshallExcludedContext(ctx, content, legacyOptions, &o.RawDNSOptions)
|
||||
}
|
||||
|
||||
type DNSClientOptions struct {
|
||||
@@ -111,12 +55,6 @@ type DNSClientOptions struct {
|
||||
ClientSubnet *badoption.Prefixable `json:"client_subnet,omitempty"`
|
||||
}
|
||||
|
||||
type LegacyDNSFakeIPOptions struct {
|
||||
Enabled bool `json:"enabled,omitempty"`
|
||||
Inet4Range *badoption.Prefix `json:"inet4_range,omitempty"`
|
||||
Inet6Range *badoption.Prefix `json:"inet6_range,omitempty"`
|
||||
}
|
||||
|
||||
type DNSTransportOptionsRegistry interface {
|
||||
CreateOptions(transportType string) (any, bool)
|
||||
}
|
||||
@@ -129,10 +67,6 @@ type _DNSServerOptions struct {
|
||||
type DNSServerOptions _DNSServerOptions
|
||||
|
||||
func (o *DNSServerOptions) MarshalJSONContext(ctx context.Context) ([]byte, error) {
|
||||
switch o.Type {
|
||||
case C.DNSTypeLegacy:
|
||||
o.Type = ""
|
||||
}
|
||||
return badjson.MarshallObjectsContext(ctx, (*_DNSServerOptions)(o), o.Options)
|
||||
}
|
||||
|
||||
@@ -148,9 +82,7 @@ func (o *DNSServerOptions) UnmarshalJSONContext(ctx context.Context, content []b
|
||||
var options any
|
||||
switch o.Type {
|
||||
case "", C.DNSTypeLegacy:
|
||||
o.Type = C.DNSTypeLegacy
|
||||
options = new(LegacyDNSServerOptions)
|
||||
deprecated.Report(ctx, deprecated.OptionLegacyDNSTransport)
|
||||
return E.New(legacyDNSServerRemovedMessage)
|
||||
default:
|
||||
var loaded bool
|
||||
options, loaded = registry.CreateOptions(o.Type)
|
||||
@@ -163,169 +95,6 @@ func (o *DNSServerOptions) UnmarshalJSONContext(ctx context.Context, content []b
|
||||
return err
|
||||
}
|
||||
o.Options = options
|
||||
if o.Type == C.DNSTypeLegacy && !dontUpgradeFromContext(ctx) {
|
||||
err = o.Upgrade(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *DNSServerOptions) Upgrade(ctx context.Context) error {
|
||||
if o.Type != C.DNSTypeLegacy {
|
||||
return nil
|
||||
}
|
||||
options := o.Options.(*LegacyDNSServerOptions)
|
||||
serverURL, _ := url.Parse(options.Address)
|
||||
var serverType string
|
||||
if serverURL != nil && serverURL.Scheme != "" {
|
||||
serverType = serverURL.Scheme
|
||||
} else {
|
||||
switch options.Address {
|
||||
case "local", "fakeip":
|
||||
serverType = options.Address
|
||||
default:
|
||||
serverType = C.DNSTypeUDP
|
||||
}
|
||||
}
|
||||
remoteOptions := RemoteDNSServerOptions{
|
||||
RawLocalDNSServerOptions: RawLocalDNSServerOptions{
|
||||
DialerOptions: DialerOptions{
|
||||
Detour: options.Detour,
|
||||
DomainResolver: &DomainResolveOptions{
|
||||
Server: options.AddressResolver,
|
||||
Strategy: options.AddressStrategy,
|
||||
},
|
||||
FallbackDelay: options.AddressFallbackDelay,
|
||||
},
|
||||
Legacy: true,
|
||||
LegacyStrategy: options.Strategy,
|
||||
LegacyDefaultDialer: options.Detour == "",
|
||||
LegacyClientSubnet: options.ClientSubnet.Build(netip.Prefix{}),
|
||||
},
|
||||
LegacyAddressResolver: options.AddressResolver,
|
||||
LegacyAddressStrategy: options.AddressStrategy,
|
||||
LegacyAddressFallbackDelay: options.AddressFallbackDelay,
|
||||
}
|
||||
switch serverType {
|
||||
case C.DNSTypeLocal:
|
||||
o.Type = C.DNSTypeLocal
|
||||
o.Options = &LocalDNSServerOptions{
|
||||
RawLocalDNSServerOptions: remoteOptions.RawLocalDNSServerOptions,
|
||||
}
|
||||
case C.DNSTypeUDP:
|
||||
o.Type = C.DNSTypeUDP
|
||||
o.Options = &remoteOptions
|
||||
var serverAddr M.Socksaddr
|
||||
if serverURL == nil || serverURL.Scheme == "" {
|
||||
serverAddr = M.ParseSocksaddr(options.Address)
|
||||
} else {
|
||||
serverAddr = M.ParseSocksaddr(serverURL.Host)
|
||||
}
|
||||
if !serverAddr.IsValid() {
|
||||
return E.New("invalid server address")
|
||||
}
|
||||
remoteOptions.Server = serverAddr.AddrString()
|
||||
if serverAddr.Port != 0 && serverAddr.Port != 53 {
|
||||
remoteOptions.ServerPort = serverAddr.Port
|
||||
}
|
||||
case C.DNSTypeTCP:
|
||||
o.Type = C.DNSTypeTCP
|
||||
o.Options = &remoteOptions
|
||||
if serverURL == nil {
|
||||
return E.New("invalid server address")
|
||||
}
|
||||
serverAddr := M.ParseSocksaddr(serverURL.Host)
|
||||
if !serverAddr.IsValid() {
|
||||
return E.New("invalid server address")
|
||||
}
|
||||
remoteOptions.Server = serverAddr.AddrString()
|
||||
if serverAddr.Port != 0 && serverAddr.Port != 53 {
|
||||
remoteOptions.ServerPort = serverAddr.Port
|
||||
}
|
||||
case C.DNSTypeTLS, C.DNSTypeQUIC:
|
||||
o.Type = serverType
|
||||
if serverURL == nil {
|
||||
return E.New("invalid server address")
|
||||
}
|
||||
serverAddr := M.ParseSocksaddr(serverURL.Host)
|
||||
if !serverAddr.IsValid() {
|
||||
return E.New("invalid server address")
|
||||
}
|
||||
remoteOptions.Server = serverAddr.AddrString()
|
||||
if serverAddr.Port != 0 && serverAddr.Port != 853 {
|
||||
remoteOptions.ServerPort = serverAddr.Port
|
||||
}
|
||||
o.Options = &RemoteTLSDNSServerOptions{
|
||||
RemoteDNSServerOptions: remoteOptions,
|
||||
}
|
||||
case C.DNSTypeHTTPS, C.DNSTypeHTTP3:
|
||||
o.Type = serverType
|
||||
httpsOptions := RemoteHTTPSDNSServerOptions{
|
||||
RemoteTLSDNSServerOptions: RemoteTLSDNSServerOptions{
|
||||
RemoteDNSServerOptions: remoteOptions,
|
||||
},
|
||||
}
|
||||
o.Options = &httpsOptions
|
||||
if serverURL == nil {
|
||||
return E.New("invalid server address")
|
||||
}
|
||||
serverAddr := M.ParseSocksaddr(serverURL.Host)
|
||||
if !serverAddr.IsValid() {
|
||||
return E.New("invalid server address")
|
||||
}
|
||||
httpsOptions.Server = serverAddr.AddrString()
|
||||
if serverAddr.Port != 0 && serverAddr.Port != 443 {
|
||||
httpsOptions.ServerPort = serverAddr.Port
|
||||
}
|
||||
if serverURL.Path != "/dns-query" {
|
||||
httpsOptions.Path = serverURL.Path
|
||||
}
|
||||
case "rcode":
|
||||
var rcode int
|
||||
if serverURL == nil {
|
||||
return E.New("invalid server address")
|
||||
}
|
||||
switch serverURL.Host {
|
||||
case "success":
|
||||
rcode = dns.RcodeSuccess
|
||||
case "format_error":
|
||||
rcode = dns.RcodeFormatError
|
||||
case "server_failure":
|
||||
rcode = dns.RcodeServerFailure
|
||||
case "name_error":
|
||||
rcode = dns.RcodeNameError
|
||||
case "not_implemented":
|
||||
rcode = dns.RcodeNotImplemented
|
||||
case "refused":
|
||||
rcode = dns.RcodeRefused
|
||||
default:
|
||||
return E.New("unknown rcode: ", serverURL.Host)
|
||||
}
|
||||
o.Type = C.DNSTypeLegacyRcode
|
||||
o.Options = rcode
|
||||
case C.DNSTypeDHCP:
|
||||
o.Type = C.DNSTypeDHCP
|
||||
dhcpOptions := DHCPDNSServerOptions{}
|
||||
if serverURL == nil {
|
||||
return E.New("invalid server address")
|
||||
}
|
||||
if serverURL.Host != "" && serverURL.Host != "auto" {
|
||||
dhcpOptions.Interface = serverURL.Host
|
||||
}
|
||||
o.Options = &dhcpOptions
|
||||
case C.DNSTypeFakeIP:
|
||||
o.Type = C.DNSTypeFakeIP
|
||||
fakeipOptions := FakeIPDNSServerOptions{}
|
||||
if legacyOptions, loaded := ctx.Value((*LegacyDNSFakeIPOptions)(nil)).(*LegacyDNSFakeIPOptions); loaded {
|
||||
fakeipOptions.Inet4Range = legacyOptions.Inet4Range
|
||||
fakeipOptions.Inet6Range = legacyOptions.Inet6Range
|
||||
}
|
||||
o.Options = &fakeipOptions
|
||||
default:
|
||||
return E.New("unsupported DNS server scheme: ", serverType)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -350,16 +119,6 @@ func (o *DNSServerAddressOptions) ReplaceServerOptions(options ServerOptions) {
|
||||
*o = DNSServerAddressOptions(options)
|
||||
}
|
||||
|
||||
type LegacyDNSServerOptions struct {
|
||||
Address string `json:"address"`
|
||||
AddressResolver string `json:"address_resolver,omitempty"`
|
||||
AddressStrategy DomainStrategy `json:"address_strategy,omitempty"`
|
||||
AddressFallbackDelay badoption.Duration `json:"address_fallback_delay,omitempty"`
|
||||
Strategy DomainStrategy `json:"strategy,omitempty"`
|
||||
Detour string `json:"detour,omitempty"`
|
||||
ClientSubnet *badoption.Prefixable `json:"client_subnet,omitempty"`
|
||||
}
|
||||
|
||||
type HostsDNSServerOptions struct {
|
||||
Path badoption.Listable[string] `json:"path,omitempty"`
|
||||
Predefined *badjson.TypedMap[string, badoption.Listable[netip.Addr]] `json:"predefined,omitempty"`
|
||||
@@ -367,10 +126,6 @@ type HostsDNSServerOptions struct {
|
||||
|
||||
type RawLocalDNSServerOptions struct {
|
||||
DialerOptions
|
||||
Legacy bool `json:"-"`
|
||||
LegacyStrategy DomainStrategy `json:"-"`
|
||||
LegacyDefaultDialer bool `json:"-"`
|
||||
LegacyClientSubnet netip.Prefix `json:"-"`
|
||||
}
|
||||
|
||||
type LocalDNSServerOptions struct {
|
||||
@@ -381,9 +136,6 @@ type LocalDNSServerOptions struct {
|
||||
type RemoteDNSServerOptions struct {
|
||||
RawLocalDNSServerOptions
|
||||
DNSServerAddressOptions
|
||||
LegacyAddressResolver string `json:"-"`
|
||||
LegacyAddressStrategy DomainStrategy `json:"-"`
|
||||
LegacyAddressFallbackDelay badoption.Duration `json:"-"`
|
||||
}
|
||||
|
||||
type RemoteTLSDNSServerOptions struct {
|
||||
|
||||
@@ -2,6 +2,7 @@ package option
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
@@ -11,6 +12,8 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
const defaultDNSRecordTTL uint32 = 3600
|
||||
|
||||
type DNSRCode int
|
||||
|
||||
func (r DNSRCode) MarshalJSON() ([]byte, error) {
|
||||
@@ -76,10 +79,13 @@ func (o *DNSRecordOptions) UnmarshalJSON(data []byte) error {
|
||||
if err == nil {
|
||||
return o.unmarshalBase64(binary)
|
||||
}
|
||||
record, err := dns.NewRR(stringValue)
|
||||
record, err := parseDNSRecord(stringValue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if record == nil {
|
||||
return E.New("empty DNS record")
|
||||
}
|
||||
if a, isA := record.(*dns.A); isA {
|
||||
a.A = M.AddrFromIP(a.A).Unmap().AsSlice()
|
||||
}
|
||||
@@ -87,6 +93,16 @@ func (o *DNSRecordOptions) UnmarshalJSON(data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseDNSRecord(stringValue string) (dns.RR, error) {
|
||||
if len(stringValue) > 0 && stringValue[len(stringValue)-1] != '\n' {
|
||||
stringValue += "\n"
|
||||
}
|
||||
parser := dns.NewZoneParser(strings.NewReader(stringValue), "", "")
|
||||
parser.SetDefaultTTL(defaultDNSRecordTTL)
|
||||
record, _ := parser.Next()
|
||||
return record, parser.Err()
|
||||
}
|
||||
|
||||
func (o *DNSRecordOptions) unmarshalBase64(binary []byte) error {
|
||||
record, _, err := dns.UnpackRR(binary, 0)
|
||||
if err != nil {
|
||||
@@ -100,3 +116,10 @@ func (o *DNSRecordOptions) unmarshalBase64(binary []byte) error {
|
||||
func (o DNSRecordOptions) Build() dns.RR {
|
||||
return o.RR
|
||||
}
|
||||
|
||||
func (o DNSRecordOptions) Match(record dns.RR) bool {
|
||||
if o.RR == nil || record == nil {
|
||||
return false
|
||||
}
|
||||
return dns.IsDuplicate(o.RR, record)
|
||||
}
|
||||
|
||||
40
option/dns_record_test.go
Normal file
40
option/dns_record_test.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package option
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func mustRecordOptions(t *testing.T, record string) DNSRecordOptions {
|
||||
t.Helper()
|
||||
var value DNSRecordOptions
|
||||
require.NoError(t, value.UnmarshalJSON([]byte(`"`+record+`"`)))
|
||||
return value
|
||||
}
|
||||
|
||||
func TestDNSRecordOptionsUnmarshalJSONRejectsRelativeNames(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, record := range []string{
|
||||
"@ IN A 1.1.1.1",
|
||||
"www IN CNAME example.com.",
|
||||
"example.com. IN CNAME @",
|
||||
"example.com. IN CNAME www",
|
||||
} {
|
||||
var value DNSRecordOptions
|
||||
err := value.UnmarshalJSON([]byte(`"` + record + `"`))
|
||||
require.Error(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSRecordOptionsMatchIgnoresTTL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
expected := mustRecordOptions(t, "example.com. 600 IN A 1.1.1.1")
|
||||
record, err := dns.NewRR("example.com. 60 IN A 1.1.1.1")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.True(t, expected.Match(record))
|
||||
}
|
||||
54
option/dns_test.go
Normal file
54
option/dns_test.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package option
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
"github.com/sagernet/sing/service"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type stubDNSTransportOptionsRegistry struct{}
|
||||
|
||||
func (stubDNSTransportOptionsRegistry) CreateOptions(transportType string) (any, bool) {
|
||||
switch transportType {
|
||||
case C.DNSTypeUDP:
|
||||
return new(RemoteDNSServerOptions), true
|
||||
case C.DNSTypeFakeIP:
|
||||
return new(FakeIPDNSServerOptions), true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOptionsRejectsLegacyFakeIPOptions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := service.ContextWith[DNSTransportOptionsRegistry](context.Background(), stubDNSTransportOptionsRegistry{})
|
||||
var options DNSOptions
|
||||
err := json.UnmarshalContext(ctx, []byte(`{
|
||||
"fakeip": {
|
||||
"enabled": true,
|
||||
"inet4_range": "198.18.0.0/15"
|
||||
}
|
||||
}`), &options)
|
||||
require.EqualError(t, err, legacyDNSFakeIPRemovedMessage)
|
||||
}
|
||||
|
||||
func TestDNSServerOptionsRejectsLegacyFormats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := service.ContextWith[DNSTransportOptionsRegistry](context.Background(), stubDNSTransportOptionsRegistry{})
|
||||
testCases := []string{
|
||||
`{"address":"1.1.1.1"}`,
|
||||
`{"type":"legacy","address":"1.1.1.1"}`,
|
||||
}
|
||||
for _, content := range testCases {
|
||||
var options DNSServerOptions
|
||||
err := json.UnmarshalContext(ctx, []byte(content), &options)
|
||||
require.EqualError(t, err, legacyDNSServerRemovedMessage)
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package option
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
@@ -33,26 +34,24 @@ func (r Rule) MarshalJSON() ([]byte, error) {
|
||||
return badjson.MarshallObjects((_Rule)(r), v)
|
||||
}
|
||||
|
||||
func (r *Rule) UnmarshalJSON(bytes []byte) error {
|
||||
err := json.Unmarshal(bytes, (*_Rule)(r))
|
||||
func (r *Rule) UnmarshalJSONContext(ctx context.Context, bytes []byte) error {
|
||||
err := json.UnmarshalContext(ctx, bytes, (*_Rule)(r))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
payload, err := rulePayloadWithoutType(ctx, bytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var v any
|
||||
switch r.Type {
|
||||
case "", C.RuleTypeDefault:
|
||||
r.Type = C.RuleTypeDefault
|
||||
v = &r.DefaultOptions
|
||||
return unmarshalDefaultRuleContext(ctx, payload, &r.DefaultOptions)
|
||||
case C.RuleTypeLogical:
|
||||
v = &r.LogicalOptions
|
||||
return unmarshalLogicalRuleContext(ctx, payload, &r.LogicalOptions)
|
||||
default:
|
||||
return E.New("unknown rule type: " + r.Type)
|
||||
}
|
||||
err = badjson.UnmarshallExcluded(bytes, (*_Rule)(r), v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r Rule) IsValid() bool {
|
||||
@@ -160,6 +159,64 @@ func (r *LogicalRule) UnmarshalJSON(data []byte) error {
|
||||
return badjson.UnmarshallExcluded(data, &r.RawLogicalRule, &r.RuleAction)
|
||||
}
|
||||
|
||||
func rulePayloadWithoutType(ctx context.Context, data []byte) ([]byte, error) {
|
||||
var content badjson.JSONObject
|
||||
err := content.UnmarshalJSONContext(ctx, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
content.Remove("type")
|
||||
return content.MarshalJSONContext(ctx)
|
||||
}
|
||||
|
||||
func unmarshalDefaultRuleContext(ctx context.Context, data []byte, rule *DefaultRule) error {
|
||||
rawAction, routeOptions, err := inspectRouteRuleAction(ctx, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = rejectNestedRouteRuleAction(ctx, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
depth := nestedRuleDepth(ctx)
|
||||
err = json.UnmarshalContext(ctx, data, &rule.RawDefaultRule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = badjson.UnmarshallExcludedContext(ctx, data, &rule.RawDefaultRule, &rule.RuleAction)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if depth > 0 && rawAction == "" && routeOptions == (RouteActionOptions{}) {
|
||||
rule.RuleAction = RuleAction{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func unmarshalLogicalRuleContext(ctx context.Context, data []byte, rule *LogicalRule) error {
|
||||
rawAction, routeOptions, err := inspectRouteRuleAction(ctx, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = rejectNestedRouteRuleAction(ctx, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
depth := nestedRuleDepth(ctx)
|
||||
err = json.UnmarshalContext(nestedRuleChildContext(ctx), data, &rule.RawLogicalRule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = badjson.UnmarshallExcludedContext(ctx, data, &rule.RawLogicalRule, &rule.RuleAction)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if depth > 0 && rawAction == "" && routeOptions == (RouteActionOptions{}) {
|
||||
rule.RuleAction = RuleAction{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *LogicalRule) IsValid() bool {
|
||||
return len(r.Rules) > 0 && common.All(r.Rules, Rule.IsValid)
|
||||
}
|
||||
|
||||
@@ -115,6 +115,10 @@ func (r DNSRuleAction) MarshalJSON() ([]byte, error) {
|
||||
case C.RuleActionTypeRoute:
|
||||
r.Action = ""
|
||||
v = r.RouteOptions
|
||||
case C.RuleActionTypeEvaluate:
|
||||
v = r.RouteOptions
|
||||
case C.RuleActionTypeRespond:
|
||||
v = nil
|
||||
case C.RuleActionTypeRouteOptions:
|
||||
v = r.RouteOptionsOptions
|
||||
case C.RuleActionTypeReject:
|
||||
@@ -124,6 +128,9 @@ func (r DNSRuleAction) MarshalJSON() ([]byte, error) {
|
||||
default:
|
||||
return nil, E.New("unknown DNS rule action: " + r.Action)
|
||||
}
|
||||
if v == nil {
|
||||
return badjson.MarshallObjects((_DNSRuleAction)(r))
|
||||
}
|
||||
return badjson.MarshallObjects((_DNSRuleAction)(r), v)
|
||||
}
|
||||
|
||||
@@ -137,6 +144,10 @@ func (r *DNSRuleAction) UnmarshalJSONContext(ctx context.Context, data []byte) e
|
||||
case "", C.RuleActionTypeRoute:
|
||||
r.Action = C.RuleActionTypeRoute
|
||||
v = &r.RouteOptions
|
||||
case C.RuleActionTypeEvaluate:
|
||||
v = &r.RouteOptions
|
||||
case C.RuleActionTypeRespond:
|
||||
v = nil
|
||||
case C.RuleActionTypeRouteOptions:
|
||||
v = &r.RouteOptionsOptions
|
||||
case C.RuleActionTypeReject:
|
||||
@@ -146,6 +157,9 @@ func (r *DNSRuleAction) UnmarshalJSONContext(ctx context.Context, data []byte) e
|
||||
default:
|
||||
return E.New("unknown DNS rule action: " + r.Action)
|
||||
}
|
||||
if v == nil {
|
||||
return json.UnmarshalDisallowUnknownFields(data, &_DNSRuleAction{})
|
||||
}
|
||||
return badjson.UnmarshallExcludedContext(ctx, data, (*_DNSRuleAction)(r), v)
|
||||
}
|
||||
|
||||
|
||||
29
option/rule_action_test.go
Normal file
29
option/rule_action_test.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package option
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDNSRuleActionRespondUnmarshalJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var action DNSRuleAction
|
||||
err := json.UnmarshalContext(context.Background(), []byte(`{"action":"respond"}`), &action)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, C.RuleActionTypeRespond, action.Action)
|
||||
require.Equal(t, DNSRouteActionOptions{}, action.RouteOptions)
|
||||
}
|
||||
|
||||
func TestDNSRuleActionRespondRejectsUnknownFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var action DNSRuleAction
|
||||
err := json.UnmarshalContext(context.Background(), []byte(`{"action":"respond","disable_cache":true}`), &action)
|
||||
require.ErrorContains(t, err, "unknown field")
|
||||
}
|
||||
@@ -35,7 +35,7 @@ func (r DNSRule) MarshalJSON() ([]byte, error) {
|
||||
}
|
||||
|
||||
func (r *DNSRule) UnmarshalJSONContext(ctx context.Context, bytes []byte) error {
|
||||
err := json.Unmarshal(bytes, (*_DNSRule)(r))
|
||||
err := json.UnmarshalContext(ctx, bytes, (*_DNSRule)(r))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -78,12 +78,6 @@ type RawDefaultDNSRule struct {
|
||||
DomainSuffix badoption.Listable[string] `json:"domain_suffix,omitempty"`
|
||||
DomainKeyword badoption.Listable[string] `json:"domain_keyword,omitempty"`
|
||||
DomainRegex badoption.Listable[string] `json:"domain_regex,omitempty"`
|
||||
Geosite badoption.Listable[string] `json:"geosite,omitempty"`
|
||||
SourceGeoIP badoption.Listable[string] `json:"source_geoip,omitempty"`
|
||||
GeoIP badoption.Listable[string] `json:"geoip,omitempty"`
|
||||
IPCIDR badoption.Listable[string] `json:"ip_cidr,omitempty"`
|
||||
IPIsPrivate bool `json:"ip_is_private,omitempty"`
|
||||
IPAcceptAny bool `json:"ip_accept_any,omitempty"`
|
||||
SourceIPCIDR badoption.Listable[string] `json:"source_ip_cidr,omitempty"`
|
||||
SourceIPIsPrivate bool `json:"source_ip_is_private,omitempty"`
|
||||
SourcePort badoption.Listable[uint16] `json:"source_port,omitempty"`
|
||||
@@ -110,9 +104,23 @@ type RawDefaultDNSRule struct {
|
||||
SourceHostname badoption.Listable[string] `json:"source_hostname,omitempty"`
|
||||
RuleSet badoption.Listable[string] `json:"rule_set,omitempty"`
|
||||
RuleSetIPCIDRMatchSource bool `json:"rule_set_ip_cidr_match_source,omitempty"`
|
||||
RuleSetIPCIDRAcceptEmpty bool `json:"rule_set_ip_cidr_accept_empty,omitempty"`
|
||||
MatchResponse bool `json:"match_response,omitempty"`
|
||||
IPCIDR badoption.Listable[string] `json:"ip_cidr,omitempty"`
|
||||
IPIsPrivate bool `json:"ip_is_private,omitempty"`
|
||||
ResponseRcode *DNSRCode `json:"response_rcode,omitempty"`
|
||||
ResponseAnswer badoption.Listable[DNSRecordOptions] `json:"response_answer,omitempty"`
|
||||
ResponseNs badoption.Listable[DNSRecordOptions] `json:"response_ns,omitempty"`
|
||||
ResponseExtra badoption.Listable[DNSRecordOptions] `json:"response_extra,omitempty"`
|
||||
Invert bool `json:"invert,omitempty"`
|
||||
|
||||
// Deprecated: removed in sing-box 1.12.0
|
||||
Geosite badoption.Listable[string] `json:"geosite,omitempty"`
|
||||
SourceGeoIP badoption.Listable[string] `json:"source_geoip,omitempty"`
|
||||
GeoIP badoption.Listable[string] `json:"geoip,omitempty"`
|
||||
// Deprecated: use match_response with response items
|
||||
IPAcceptAny bool `json:"ip_accept_any,omitempty"`
|
||||
// Deprecated: removed in sing-box 1.11.0
|
||||
RuleSetIPCIDRAcceptEmpty bool `json:"rule_set_ip_cidr_accept_empty,omitempty"`
|
||||
// Deprecated: renamed to rule_set_ip_cidr_match_source
|
||||
Deprecated_RulesetIPCIDRMatchSource bool `json:"rule_set_ipcidr_match_source,omitempty"`
|
||||
}
|
||||
@@ -127,11 +135,27 @@ func (r DefaultDNSRule) MarshalJSON() ([]byte, error) {
|
||||
}
|
||||
|
||||
func (r *DefaultDNSRule) UnmarshalJSONContext(ctx context.Context, data []byte) error {
|
||||
err := json.UnmarshalContext(ctx, data, &r.RawDefaultDNSRule)
|
||||
rawAction, routeOptions, err := inspectDNSRuleAction(ctx, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return badjson.UnmarshallExcludedContext(ctx, data, &r.RawDefaultDNSRule, &r.DNSRuleAction)
|
||||
err = rejectNestedDNSRuleAction(ctx, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
depth := nestedRuleDepth(ctx)
|
||||
err = json.UnmarshalContext(ctx, data, &r.RawDefaultDNSRule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = badjson.UnmarshallExcludedContext(ctx, data, &r.RawDefaultDNSRule, &r.DNSRuleAction)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if depth > 0 && rawAction == "" && routeOptions == (DNSRouteActionOptions{}) {
|
||||
r.DNSRuleAction = DNSRuleAction{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r DefaultDNSRule) IsValid() bool {
|
||||
@@ -156,11 +180,27 @@ func (r LogicalDNSRule) MarshalJSON() ([]byte, error) {
|
||||
}
|
||||
|
||||
func (r *LogicalDNSRule) UnmarshalJSONContext(ctx context.Context, data []byte) error {
|
||||
err := json.Unmarshal(data, &r.RawLogicalDNSRule)
|
||||
rawAction, routeOptions, err := inspectDNSRuleAction(ctx, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return badjson.UnmarshallExcludedContext(ctx, data, &r.RawLogicalDNSRule, &r.DNSRuleAction)
|
||||
err = rejectNestedDNSRuleAction(ctx, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
depth := nestedRuleDepth(ctx)
|
||||
err = json.UnmarshalContext(nestedRuleChildContext(ctx), data, &r.RawLogicalDNSRule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = badjson.UnmarshallExcludedContext(ctx, data, &r.RawLogicalDNSRule, &r.DNSRuleAction)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if depth > 0 && rawAction == "" && routeOptions == (DNSRouteActionOptions{}) {
|
||||
r.DNSRuleAction = DNSRuleAction{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *LogicalDNSRule) IsValid() bool {
|
||||
|
||||
133
option/rule_nested.go
Normal file
133
option/rule_nested.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package option
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
"github.com/sagernet/sing/common/json/badjson"
|
||||
)
|
||||
|
||||
type nestedRuleDepthContextKey struct{}
|
||||
|
||||
const (
|
||||
RouteRuleActionNestedUnsupportedMessage = "rule action is not supported in nested rules"
|
||||
DNSRuleActionNestedUnsupportedMessage = "DNS rule action is not supported in nested rules"
|
||||
)
|
||||
|
||||
var (
|
||||
routeRuleActionKeys = jsonFieldNames(reflect.TypeFor[_RuleAction](), reflect.TypeFor[RouteActionOptions]())
|
||||
dnsRuleActionKeys = jsonFieldNames(reflect.TypeFor[_DNSRuleAction](), reflect.TypeFor[DNSRouteActionOptions]())
|
||||
)
|
||||
|
||||
func nestedRuleChildContext(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, nestedRuleDepthContextKey{}, nestedRuleDepth(ctx)+1)
|
||||
}
|
||||
|
||||
func rejectNestedRouteRuleAction(ctx context.Context, content []byte) error {
|
||||
return rejectNestedRuleAction(ctx, content, routeRuleActionKeys, RouteRuleActionNestedUnsupportedMessage)
|
||||
}
|
||||
|
||||
func rejectNestedDNSRuleAction(ctx context.Context, content []byte) error {
|
||||
return rejectNestedRuleAction(ctx, content, dnsRuleActionKeys, DNSRuleActionNestedUnsupportedMessage)
|
||||
}
|
||||
|
||||
func nestedRuleDepth(ctx context.Context) int {
|
||||
depth, _ := ctx.Value(nestedRuleDepthContextKey{}).(int)
|
||||
return depth
|
||||
}
|
||||
|
||||
func rejectNestedRuleAction(ctx context.Context, content []byte, keys []string, message string) error {
|
||||
if nestedRuleDepth(ctx) == 0 {
|
||||
return nil
|
||||
}
|
||||
hasActionKey, err := hasAnyJSONKey(ctx, content, keys...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if hasActionKey {
|
||||
return E.New(message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func hasAnyJSONKey(ctx context.Context, content []byte, keys ...string) (bool, error) {
|
||||
var object badjson.JSONObject
|
||||
err := object.UnmarshalJSONContext(ctx, content)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, key := range keys {
|
||||
if object.ContainsKey(key) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func inspectRouteRuleAction(ctx context.Context, content []byte) (string, RouteActionOptions, error) {
|
||||
var rawAction _RuleAction
|
||||
err := json.UnmarshalContext(ctx, content, &rawAction)
|
||||
if err != nil {
|
||||
return "", RouteActionOptions{}, err
|
||||
}
|
||||
var routeOptions RouteActionOptions
|
||||
err = json.UnmarshalContext(ctx, content, &routeOptions)
|
||||
if err != nil {
|
||||
return "", RouteActionOptions{}, err
|
||||
}
|
||||
return rawAction.Action, routeOptions, nil
|
||||
}
|
||||
|
||||
func inspectDNSRuleAction(ctx context.Context, content []byte) (string, DNSRouteActionOptions, error) {
|
||||
var rawAction _DNSRuleAction
|
||||
err := json.UnmarshalContext(ctx, content, &rawAction)
|
||||
if err != nil {
|
||||
return "", DNSRouteActionOptions{}, err
|
||||
}
|
||||
var routeOptions DNSRouteActionOptions
|
||||
err = json.UnmarshalContext(ctx, content, &routeOptions)
|
||||
if err != nil {
|
||||
return "", DNSRouteActionOptions{}, err
|
||||
}
|
||||
return rawAction.Action, routeOptions, nil
|
||||
}
|
||||
|
||||
func jsonFieldNames(types ...reflect.Type) []string {
|
||||
fieldMap := make(map[string]struct{})
|
||||
for _, fieldType := range types {
|
||||
appendJSONFieldNames(fieldMap, fieldType)
|
||||
}
|
||||
fieldNames := make([]string, 0, len(fieldMap))
|
||||
for fieldName := range fieldMap {
|
||||
fieldNames = append(fieldNames, fieldName)
|
||||
}
|
||||
return fieldNames
|
||||
}
|
||||
|
||||
func appendJSONFieldNames(fieldMap map[string]struct{}, fieldType reflect.Type) {
|
||||
for fieldType.Kind() == reflect.Pointer {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
if fieldType.Kind() != reflect.Struct {
|
||||
return
|
||||
}
|
||||
for i := range fieldType.NumField() {
|
||||
field := fieldType.Field(i)
|
||||
tagValue := field.Tag.Get("json")
|
||||
tagName, _, _ := strings.Cut(tagValue, ",")
|
||||
if tagName == "-" {
|
||||
continue
|
||||
}
|
||||
if field.Anonymous && tagName == "" {
|
||||
appendJSONFieldNames(fieldMap, field.Type)
|
||||
continue
|
||||
}
|
||||
if tagName == "" {
|
||||
tagName = field.Name
|
||||
}
|
||||
fieldMap[tagName] = struct{}{}
|
||||
}
|
||||
}
|
||||
68
option/rule_nested_test.go
Normal file
68
option/rule_nested_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package option
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing/common/json"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRuleRejectsNestedDefaultRuleAction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var rule Rule
|
||||
err := json.UnmarshalContext(context.Background(), []byte(`{
|
||||
"type": "logical",
|
||||
"mode": "and",
|
||||
"rules": [
|
||||
{"domain": "example.com", "outbound": "direct"}
|
||||
]
|
||||
}`), &rule)
|
||||
require.ErrorContains(t, err, RouteRuleActionNestedUnsupportedMessage)
|
||||
}
|
||||
|
||||
func TestRuleLeavesUnknownNestedKeysToNormalValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var rule Rule
|
||||
err := json.UnmarshalContext(context.Background(), []byte(`{
|
||||
"type": "logical",
|
||||
"mode": "and",
|
||||
"rules": [
|
||||
{"domain": "example.com", "foo": "bar"}
|
||||
]
|
||||
}`), &rule)
|
||||
require.ErrorContains(t, err, "unknown field")
|
||||
require.NotContains(t, err.Error(), RouteRuleActionNestedUnsupportedMessage)
|
||||
}
|
||||
|
||||
func TestDNSRuleRejectsNestedDefaultRuleAction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var rule DNSRule
|
||||
err := json.UnmarshalContext(context.Background(), []byte(`{
|
||||
"type": "logical",
|
||||
"mode": "and",
|
||||
"rules": [
|
||||
{"domain": "example.com", "server": "default"}
|
||||
]
|
||||
}`), &rule)
|
||||
require.ErrorContains(t, err, DNSRuleActionNestedUnsupportedMessage)
|
||||
}
|
||||
|
||||
func TestDNSRuleLeavesUnknownNestedKeysToNormalValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var rule DNSRule
|
||||
err := json.UnmarshalContext(context.Background(), []byte(`{
|
||||
"type": "logical",
|
||||
"mode": "and",
|
||||
"rules": [
|
||||
{"domain": "example.com", "foo": "bar"}
|
||||
]
|
||||
}`), &rule)
|
||||
require.ErrorContains(t, err, "unknown field")
|
||||
require.NotContains(t, err.Error(), DNSRuleActionNestedUnsupportedMessage)
|
||||
}
|
||||
Reference in New Issue
Block a user