mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-11 17:47:20 +10:00
300 lines
7.4 KiB
Go
300 lines
7.4 KiB
Go
package cachefile
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"time"
|
|
|
|
"github.com/sagernet/bbolt"
|
|
"github.com/sagernet/sing/common/buf"
|
|
"github.com/sagernet/sing/common/logger"
|
|
)
|
|
|
|
var (
	// bucketDNSCache is the name of the bbolt bucket holding persisted DNS
	// responses. Inside it, each transport name gets its own nested bucket.
	bucketDNSCache = []byte("dns_cache")
)
|
|
|
|
func (c *CacheFile) StoreDNS() bool {
|
|
return c.storeDNS
|
|
}
|
|
|
|
func (c *CacheFile) LoadDNSCache(transportName string, qName string, qType uint16) (rawMessage []byte, expireAt time.Time, loaded bool) {
|
|
c.saveDNSCacheAccess.RLock()
|
|
entry, cached := c.saveDNSCache[saveCacheKey{transportName, qName, qType}]
|
|
c.saveDNSCacheAccess.RUnlock()
|
|
if cached {
|
|
return entry.rawMessage, entry.expireAt, true
|
|
}
|
|
key := buf.Get(2 + len(qName))
|
|
binary.BigEndian.PutUint16(key, qType)
|
|
copy(key[2:], qName)
|
|
defer buf.Put(key)
|
|
err := c.view(func(tx *bbolt.Tx) error {
|
|
bucket := c.bucket(tx, bucketDNSCache)
|
|
if bucket == nil {
|
|
return nil
|
|
}
|
|
bucket = bucket.Bucket([]byte(transportName))
|
|
if bucket == nil {
|
|
return nil
|
|
}
|
|
content := bucket.Get(key)
|
|
if len(content) < 8 {
|
|
return nil
|
|
}
|
|
expireAt = time.Unix(int64(binary.BigEndian.Uint64(content[:8])), 0)
|
|
rawMessage = make([]byte, len(content)-8)
|
|
copy(rawMessage, content[8:])
|
|
loaded = true
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return nil, time.Time{}, false
|
|
}
|
|
return
|
|
}
|
|
|
|
func (c *CacheFile) SaveDNSCache(transportName string, qName string, qType uint16, rawMessage []byte, expireAt time.Time) error {
|
|
return c.batch(func(tx *bbolt.Tx) error {
|
|
bucket, err := c.createBucket(tx, bucketDNSCache)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
bucket, err = bucket.CreateBucketIfNotExists([]byte(transportName))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
key := buf.Get(2 + len(qName))
|
|
binary.BigEndian.PutUint16(key, qType)
|
|
copy(key[2:], qName)
|
|
defer buf.Put(key)
|
|
value := buf.Get(8 + len(rawMessage))
|
|
defer buf.Put(value)
|
|
binary.BigEndian.PutUint64(value[:8], uint64(expireAt.Unix()))
|
|
copy(value[8:], rawMessage)
|
|
return bucket.Put(key, value)
|
|
})
|
|
}
|
|
|
|
func (c *CacheFile) SaveDNSCacheAsync(transportName string, qName string, qType uint16, rawMessage []byte, expireAt time.Time, logger logger.Logger) {
|
|
saveKey := saveCacheKey{transportName, qName, qType}
|
|
if !c.queueDNSCacheSave(saveKey, rawMessage, expireAt) {
|
|
return
|
|
}
|
|
go c.flushPendingDNSCache(saveKey, logger)
|
|
}
|
|
|
|
func (c *CacheFile) queueDNSCacheSave(saveKey saveCacheKey, rawMessage []byte, expireAt time.Time) bool {
|
|
c.saveDNSCacheAccess.Lock()
|
|
defer c.saveDNSCacheAccess.Unlock()
|
|
entry := c.saveDNSCache[saveKey]
|
|
entry.rawMessage = append([]byte(nil), rawMessage...)
|
|
entry.expireAt = expireAt
|
|
entry.sequence++
|
|
startFlush := !entry.saving
|
|
entry.saving = true
|
|
c.saveDNSCache[saveKey] = entry
|
|
return startFlush
|
|
}
|
|
|
|
func (c *CacheFile) flushPendingDNSCache(saveKey saveCacheKey, logger logger.Logger) {
|
|
c.flushPendingDNSCacheWith(saveKey, logger, func(entry saveDNSCacheEntry) error {
|
|
return c.SaveDNSCache(saveKey.TransportName, saveKey.QuestionName, saveKey.QType, entry.rawMessage, entry.expireAt)
|
|
})
|
|
}
|
|
|
|
func (c *CacheFile) flushPendingDNSCacheWith(saveKey saveCacheKey, logger logger.Logger, save func(saveDNSCacheEntry) error) {
|
|
for {
|
|
c.saveDNSCacheAccess.RLock()
|
|
entry, loaded := c.saveDNSCache[saveKey]
|
|
c.saveDNSCacheAccess.RUnlock()
|
|
if !loaded {
|
|
return
|
|
}
|
|
err := save(entry)
|
|
if err != nil {
|
|
logger.Warn("save DNS cache: ", err)
|
|
}
|
|
c.saveDNSCacheAccess.Lock()
|
|
currentEntry, loaded := c.saveDNSCache[saveKey]
|
|
if !loaded {
|
|
c.saveDNSCacheAccess.Unlock()
|
|
return
|
|
}
|
|
if currentEntry.sequence != entry.sequence {
|
|
c.saveDNSCacheAccess.Unlock()
|
|
continue
|
|
}
|
|
delete(c.saveDNSCache, saveKey)
|
|
c.saveDNSCacheAccess.Unlock()
|
|
return
|
|
}
|
|
}
|
|
|
|
func (c *CacheFile) ClearDNSCache() error {
|
|
c.saveDNSCacheAccess.Lock()
|
|
clear(c.saveDNSCache)
|
|
c.saveDNSCacheAccess.Unlock()
|
|
return c.batch(func(tx *bbolt.Tx) error {
|
|
if c.cacheID == nil {
|
|
bucket := tx.Bucket(bucketDNSCache)
|
|
if bucket == nil {
|
|
return nil
|
|
}
|
|
return tx.DeleteBucket(bucketDNSCache)
|
|
}
|
|
bucket := tx.Bucket(c.cacheID)
|
|
if bucket == nil || bucket.Bucket(bucketDNSCache) == nil {
|
|
return nil
|
|
}
|
|
return bucket.DeleteBucket(bucketDNSCache)
|
|
})
|
|
}
|
|
|
|
func (c *CacheFile) loopCacheCleanup(interval time.Duration, cleanupFunc func()) {
|
|
ticker := time.NewTicker(interval)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-c.ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
cleanupFunc()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *CacheFile) cleanupDNSCache() {
|
|
now := time.Now()
|
|
err := c.batch(func(tx *bbolt.Tx) error {
|
|
bucket := c.bucket(tx, bucketDNSCache)
|
|
if bucket == nil {
|
|
return nil
|
|
}
|
|
var emptyTransports [][]byte
|
|
err := bucket.ForEachBucket(func(transportName []byte) error {
|
|
transportBucket := bucket.Bucket(transportName)
|
|
if transportBucket == nil {
|
|
return nil
|
|
}
|
|
var expiredKeys [][]byte
|
|
err := transportBucket.ForEach(func(key, value []byte) error {
|
|
if len(value) < 8 {
|
|
expiredKeys = append(expiredKeys, append([]byte(nil), key...))
|
|
return nil
|
|
}
|
|
if c.disableExpire {
|
|
return nil
|
|
}
|
|
expireAt := time.Unix(int64(binary.BigEndian.Uint64(value[:8])), 0)
|
|
if now.After(expireAt.Add(c.optimisticTimeout)) {
|
|
expiredKeys = append(expiredKeys, append([]byte(nil), key...))
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, key := range expiredKeys {
|
|
err = transportBucket.Delete(key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
first, _ := transportBucket.Cursor().First()
|
|
if first == nil {
|
|
emptyTransports = append(emptyTransports, append([]byte(nil), transportName...))
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, name := range emptyTransports {
|
|
err = bucket.DeleteBucket(name)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
c.logger.Warn("cleanup DNS cache: ", err)
|
|
}
|
|
}
|
|
|
|
func (c *CacheFile) clearRDRC() {
|
|
c.saveRDRCAccess.Lock()
|
|
clear(c.saveRDRC)
|
|
c.saveRDRCAccess.Unlock()
|
|
err := c.batch(func(tx *bbolt.Tx) error {
|
|
if c.cacheID == nil {
|
|
if tx.Bucket(bucketRDRC) == nil {
|
|
return nil
|
|
}
|
|
return tx.DeleteBucket(bucketRDRC)
|
|
}
|
|
bucket := tx.Bucket(c.cacheID)
|
|
if bucket == nil || bucket.Bucket(bucketRDRC) == nil {
|
|
return nil
|
|
}
|
|
return bucket.DeleteBucket(bucketRDRC)
|
|
})
|
|
if err != nil {
|
|
c.logger.Warn("clear RDRC: ", err)
|
|
}
|
|
}
|
|
|
|
func (c *CacheFile) cleanupRDRC() {
|
|
now := time.Now()
|
|
err := c.batch(func(tx *bbolt.Tx) error {
|
|
bucket := c.bucket(tx, bucketRDRC)
|
|
if bucket == nil {
|
|
return nil
|
|
}
|
|
var emptyTransports [][]byte
|
|
err := bucket.ForEachBucket(func(transportName []byte) error {
|
|
transportBucket := bucket.Bucket(transportName)
|
|
if transportBucket == nil {
|
|
return nil
|
|
}
|
|
var expiredKeys [][]byte
|
|
err := transportBucket.ForEach(func(key, value []byte) error {
|
|
if len(value) < 8 {
|
|
expiredKeys = append(expiredKeys, append([]byte(nil), key...))
|
|
return nil
|
|
}
|
|
expiresAt := time.Unix(int64(binary.BigEndian.Uint64(value)), 0)
|
|
if now.After(expiresAt) {
|
|
expiredKeys = append(expiredKeys, append([]byte(nil), key...))
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, key := range expiredKeys {
|
|
err = transportBucket.Delete(key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
first, _ := transportBucket.Cursor().First()
|
|
if first == nil {
|
|
emptyTransports = append(emptyTransports, append([]byte(nil), transportName...))
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, name := range emptyTransports {
|
|
err = bucket.DeleteBucket(name)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
c.logger.Warn("cleanup RDRC: ", err)
|
|
}
|
|
}
|