Add hysteria and acme TLS certificate issuer (#18)
* Add hysteria client/server * Add acme TLS certificate issuer
This commit is contained in:
149
transport/hysteria/brutal.go
Normal file
149
transport/hysteria/brutal.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package hysteria
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/quic-go/congestion"
|
||||
)
|
||||
|
||||
const (
|
||||
initMaxDatagramSize = 1252
|
||||
|
||||
pktInfoSlotCount = 4
|
||||
minSampleCount = 50
|
||||
minAckRate = 0.8
|
||||
)
|
||||
|
||||
type BrutalSender struct {
|
||||
rttStats congestion.RTTStatsProvider
|
||||
bps congestion.ByteCount
|
||||
maxDatagramSize congestion.ByteCount
|
||||
pacer *pacer
|
||||
|
||||
pktInfoSlots [pktInfoSlotCount]pktInfo
|
||||
ackRate float64
|
||||
}
|
||||
|
||||
type pktInfo struct {
|
||||
Timestamp int64
|
||||
AckCount uint64
|
||||
LossCount uint64
|
||||
}
|
||||
|
||||
func NewBrutalSender(bps congestion.ByteCount) *BrutalSender {
|
||||
bs := &BrutalSender{
|
||||
bps: bps,
|
||||
maxDatagramSize: initMaxDatagramSize,
|
||||
ackRate: 1,
|
||||
}
|
||||
bs.pacer = newPacer(func() congestion.ByteCount {
|
||||
return congestion.ByteCount(float64(bs.bps) / bs.ackRate)
|
||||
})
|
||||
return bs
|
||||
}
|
||||
|
||||
func (b *BrutalSender) SetRTTStatsProvider(rttStats congestion.RTTStatsProvider) {
|
||||
b.rttStats = rttStats
|
||||
}
|
||||
|
||||
func (b *BrutalSender) TimeUntilSend(bytesInFlight congestion.ByteCount) time.Time {
|
||||
return b.pacer.TimeUntilSend()
|
||||
}
|
||||
|
||||
func (b *BrutalSender) HasPacingBudget() bool {
|
||||
return b.pacer.Budget(time.Now()) >= b.maxDatagramSize
|
||||
}
|
||||
|
||||
func (b *BrutalSender) CanSend(bytesInFlight congestion.ByteCount) bool {
|
||||
return bytesInFlight < b.GetCongestionWindow()
|
||||
}
|
||||
|
||||
func (b *BrutalSender) GetCongestionWindow() congestion.ByteCount {
|
||||
rtt := maxDuration(b.rttStats.LatestRTT(), b.rttStats.SmoothedRTT())
|
||||
if rtt <= 0 {
|
||||
return 10240
|
||||
}
|
||||
return congestion.ByteCount(float64(b.bps) * rtt.Seconds() * 1.5 / b.ackRate)
|
||||
}
|
||||
|
||||
func (b *BrutalSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion.ByteCount,
|
||||
packetNumber congestion.PacketNumber, bytes congestion.ByteCount, isRetransmittable bool,
|
||||
) {
|
||||
b.pacer.SentPacket(sentTime, bytes)
|
||||
}
|
||||
|
||||
func (b *BrutalSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes congestion.ByteCount,
|
||||
priorInFlight congestion.ByteCount, eventTime time.Time,
|
||||
) {
|
||||
currentTimestamp := eventTime.Unix()
|
||||
slot := currentTimestamp % pktInfoSlotCount
|
||||
if b.pktInfoSlots[slot].Timestamp == currentTimestamp {
|
||||
b.pktInfoSlots[slot].AckCount++
|
||||
} else {
|
||||
// uninitialized slot or too old, reset
|
||||
b.pktInfoSlots[slot].Timestamp = currentTimestamp
|
||||
b.pktInfoSlots[slot].AckCount = 1
|
||||
b.pktInfoSlots[slot].LossCount = 0
|
||||
}
|
||||
b.updateAckRate(currentTimestamp)
|
||||
}
|
||||
|
||||
func (b *BrutalSender) OnPacketLost(number congestion.PacketNumber, lostBytes congestion.ByteCount,
|
||||
priorInFlight congestion.ByteCount,
|
||||
) {
|
||||
currentTimestamp := time.Now().Unix()
|
||||
slot := currentTimestamp % pktInfoSlotCount
|
||||
if b.pktInfoSlots[slot].Timestamp == currentTimestamp {
|
||||
b.pktInfoSlots[slot].LossCount++
|
||||
} else {
|
||||
// uninitialized slot or too old, reset
|
||||
b.pktInfoSlots[slot].Timestamp = currentTimestamp
|
||||
b.pktInfoSlots[slot].AckCount = 0
|
||||
b.pktInfoSlots[slot].LossCount = 1
|
||||
}
|
||||
b.updateAckRate(currentTimestamp)
|
||||
}
|
||||
|
||||
func (b *BrutalSender) SetMaxDatagramSize(size congestion.ByteCount) {
|
||||
b.maxDatagramSize = size
|
||||
b.pacer.SetMaxDatagramSize(size)
|
||||
}
|
||||
|
||||
func (b *BrutalSender) updateAckRate(currentTimestamp int64) {
|
||||
minTimestamp := currentTimestamp - pktInfoSlotCount
|
||||
var ackCount, lossCount uint64
|
||||
for _, info := range b.pktInfoSlots {
|
||||
if info.Timestamp < minTimestamp {
|
||||
continue
|
||||
}
|
||||
ackCount += info.AckCount
|
||||
lossCount += info.LossCount
|
||||
}
|
||||
if ackCount+lossCount < minSampleCount {
|
||||
b.ackRate = 1
|
||||
}
|
||||
rate := float64(ackCount) / float64(ackCount+lossCount)
|
||||
if rate < minAckRate {
|
||||
b.ackRate = minAckRate
|
||||
}
|
||||
b.ackRate = rate
|
||||
}
|
||||
|
||||
func (b *BrutalSender) InSlowStart() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (b *BrutalSender) InRecovery() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (b *BrutalSender) MaybeExitSlowStart() {}
|
||||
|
||||
func (b *BrutalSender) OnRetransmissionTimeout(packetsRetransmitted bool) {}
|
||||
|
||||
func maxDuration(a, b time.Duration) time.Duration {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
65
transport/hysteria/frag.go
Normal file
65
transport/hysteria/frag.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package hysteria
|
||||
|
||||
func FragUDPMessage(m UDPMessage, maxSize int) []UDPMessage {
|
||||
if m.Size() <= maxSize {
|
||||
return []UDPMessage{m}
|
||||
}
|
||||
fullPayload := m.Data
|
||||
maxPayloadSize := maxSize - m.HeaderSize()
|
||||
off := 0
|
||||
fragID := uint8(0)
|
||||
fragCount := uint8((len(fullPayload) + maxPayloadSize - 1) / maxPayloadSize) // round up
|
||||
var frags []UDPMessage
|
||||
for off < len(fullPayload) {
|
||||
payloadSize := len(fullPayload) - off
|
||||
if payloadSize > maxPayloadSize {
|
||||
payloadSize = maxPayloadSize
|
||||
}
|
||||
frag := m
|
||||
frag.FragID = fragID
|
||||
frag.FragCount = fragCount
|
||||
frag.Data = fullPayload[off : off+payloadSize]
|
||||
frags = append(frags, frag)
|
||||
off += payloadSize
|
||||
fragID++
|
||||
}
|
||||
return frags
|
||||
}
|
||||
|
||||
type Defragger struct {
|
||||
msgID uint16
|
||||
frags []*UDPMessage
|
||||
count uint8
|
||||
}
|
||||
|
||||
func (d *Defragger) Feed(m UDPMessage) *UDPMessage {
|
||||
if m.FragCount <= 1 {
|
||||
return &m
|
||||
}
|
||||
if m.FragID >= m.FragCount {
|
||||
// wtf is this?
|
||||
return nil
|
||||
}
|
||||
if m.MsgID != d.msgID {
|
||||
// new message, clear previous state
|
||||
d.msgID = m.MsgID
|
||||
d.frags = make([]*UDPMessage, m.FragCount)
|
||||
d.count = 1
|
||||
d.frags[m.FragID] = &m
|
||||
} else if d.frags[m.FragID] == nil {
|
||||
d.frags[m.FragID] = &m
|
||||
d.count++
|
||||
if int(d.count) == len(d.frags) {
|
||||
// all fragments received, assemble
|
||||
var data []byte
|
||||
for _, frag := range d.frags {
|
||||
data = append(data, frag.Data...)
|
||||
}
|
||||
m.Data = data
|
||||
m.FragID = 0
|
||||
m.FragCount = 1
|
||||
return &m
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
86
transport/hysteria/pacer.go
Normal file
86
transport/hysteria/pacer.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package hysteria
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/quic-go/congestion"
|
||||
)
|
||||
|
||||
const (
|
||||
maxBurstPackets = 10
|
||||
minPacingDelay = time.Millisecond
|
||||
)
|
||||
|
||||
// The pacer implements a token bucket pacing algorithm.
|
||||
type pacer struct {
|
||||
budgetAtLastSent congestion.ByteCount
|
||||
maxDatagramSize congestion.ByteCount
|
||||
lastSentTime time.Time
|
||||
getBandwidth func() congestion.ByteCount // in bytes/s
|
||||
}
|
||||
|
||||
func newPacer(getBandwidth func() congestion.ByteCount) *pacer {
|
||||
p := &pacer{
|
||||
budgetAtLastSent: maxBurstPackets * initMaxDatagramSize,
|
||||
maxDatagramSize: initMaxDatagramSize,
|
||||
getBandwidth: getBandwidth,
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *pacer) SentPacket(sendTime time.Time, size congestion.ByteCount) {
|
||||
budget := p.Budget(sendTime)
|
||||
if size > budget {
|
||||
p.budgetAtLastSent = 0
|
||||
} else {
|
||||
p.budgetAtLastSent = budget - size
|
||||
}
|
||||
p.lastSentTime = sendTime
|
||||
}
|
||||
|
||||
func (p *pacer) Budget(now time.Time) congestion.ByteCount {
|
||||
if p.lastSentTime.IsZero() {
|
||||
return p.maxBurstSize()
|
||||
}
|
||||
budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9
|
||||
return minByteCount(p.maxBurstSize(), budget)
|
||||
}
|
||||
|
||||
func (p *pacer) maxBurstSize() congestion.ByteCount {
|
||||
return maxByteCount(
|
||||
congestion.ByteCount((minPacingDelay+time.Millisecond).Nanoseconds())*p.getBandwidth()/1e9,
|
||||
maxBurstPackets*p.maxDatagramSize,
|
||||
)
|
||||
}
|
||||
|
||||
// TimeUntilSend returns when the next packet should be sent.
|
||||
// It returns the zero value of time.Time if a packet can be sent immediately.
|
||||
func (p *pacer) TimeUntilSend() time.Time {
|
||||
if p.budgetAtLastSent >= p.maxDatagramSize {
|
||||
return time.Time{}
|
||||
}
|
||||
return p.lastSentTime.Add(maxDuration(
|
||||
minPacingDelay,
|
||||
time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/
|
||||
float64(p.getBandwidth())))*time.Nanosecond,
|
||||
))
|
||||
}
|
||||
|
||||
func (p *pacer) SetMaxDatagramSize(s congestion.ByteCount) {
|
||||
p.maxDatagramSize = s
|
||||
}
|
||||
|
||||
func maxByteCount(a, b congestion.ByteCount) congestion.ByteCount {
|
||||
if a < b {
|
||||
return b
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func minByteCount(a, b congestion.ByteCount) congestion.ByteCount {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
510
transport/hysteria/protocol.go
Normal file
510
transport/hysteria/protocol.go
Normal file
@@ -0,0 +1,510 @@
|
||||
package hysteria
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/quic-go"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
const (
|
||||
MbpsToBps = 125000
|
||||
MinSpeedBPS = 16384
|
||||
DefaultStreamReceiveWindow = 15728640 // 15 MB/s
|
||||
DefaultConnectionReceiveWindow = 67108864 // 64 MB/s
|
||||
DefaultMaxIncomingStreams = 1024
|
||||
DefaultALPN = "hysteria"
|
||||
KeepAlivePeriod = 10 * time.Second
|
||||
)
|
||||
|
||||
const Version = 3
|
||||
|
||||
type ClientHello struct {
|
||||
SendBPS uint64
|
||||
RecvBPS uint64
|
||||
Auth []byte
|
||||
}
|
||||
|
||||
func WriteClientHello(stream io.Writer, hello ClientHello) error {
|
||||
var requestLen int
|
||||
requestLen += 1 // version
|
||||
requestLen += 8 // sendBPS
|
||||
requestLen += 8 // recvBPS
|
||||
requestLen += 2 // auth len
|
||||
requestLen += len(hello.Auth)
|
||||
_request := buf.StackNewSize(requestLen)
|
||||
defer common.KeepAlive(_request)
|
||||
request := common.Dup(_request)
|
||||
defer request.Release()
|
||||
common.Must(
|
||||
request.WriteByte(Version),
|
||||
binary.Write(request, binary.BigEndian, hello.SendBPS),
|
||||
binary.Write(request, binary.BigEndian, hello.RecvBPS),
|
||||
binary.Write(request, binary.BigEndian, uint16(len(hello.Auth))),
|
||||
common.Error(request.Write(hello.Auth)),
|
||||
)
|
||||
return common.Error(stream.Write(request.Bytes()))
|
||||
}
|
||||
|
||||
func ReadClientHello(reader io.Reader) (*ClientHello, error) {
|
||||
var version uint8
|
||||
err := binary.Read(reader, binary.BigEndian, &version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if version != Version {
|
||||
return nil, E.New("unsupported client version: ", version)
|
||||
}
|
||||
var clientHello ClientHello
|
||||
err = binary.Read(reader, binary.BigEndian, &clientHello.SendBPS)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = binary.Read(reader, binary.BigEndian, &clientHello.RecvBPS)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var authLen uint16
|
||||
err = binary.Read(reader, binary.BigEndian, &authLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientHello.Auth = make([]byte, authLen)
|
||||
_, err = io.ReadFull(reader, clientHello.Auth)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &clientHello, nil
|
||||
}
|
||||
|
||||
type ServerHello struct {
|
||||
OK bool
|
||||
SendBPS uint64
|
||||
RecvBPS uint64
|
||||
Message string
|
||||
}
|
||||
|
||||
func ReadServerHello(stream io.Reader) (*ServerHello, error) {
|
||||
var responseLen int
|
||||
responseLen += 1 // ok
|
||||
responseLen += 8 // sendBPS
|
||||
responseLen += 8 // recvBPS
|
||||
responseLen += 2 // message len
|
||||
_response := buf.StackNewSize(responseLen)
|
||||
defer common.KeepAlive(_response)
|
||||
response := common.Dup(_response)
|
||||
defer response.Release()
|
||||
_, err := response.ReadFullFrom(stream, responseLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var serverHello ServerHello
|
||||
serverHello.OK = response.Byte(0) == 1
|
||||
serverHello.SendBPS = binary.BigEndian.Uint64(response.Range(1, 9))
|
||||
serverHello.RecvBPS = binary.BigEndian.Uint64(response.Range(9, 17))
|
||||
messageLen := binary.BigEndian.Uint16(response.Range(17, 19))
|
||||
if messageLen == 0 {
|
||||
return &serverHello, nil
|
||||
}
|
||||
message := make([]byte, messageLen)
|
||||
_, err = io.ReadFull(stream, message)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
serverHello.Message = string(message)
|
||||
return &serverHello, nil
|
||||
}
|
||||
|
||||
func WriteServerHello(stream io.Writer, hello ServerHello) error {
|
||||
var responseLen int
|
||||
responseLen += 1 // ok
|
||||
responseLen += 8 // sendBPS
|
||||
responseLen += 8 // recvBPS
|
||||
responseLen += 2 // message len
|
||||
responseLen += len(hello.Message)
|
||||
_response := buf.StackNewSize(responseLen)
|
||||
defer common.KeepAlive(_response)
|
||||
response := common.Dup(_response)
|
||||
defer response.Release()
|
||||
if hello.OK {
|
||||
common.Must(response.WriteByte(1))
|
||||
} else {
|
||||
common.Must(response.WriteByte(0))
|
||||
}
|
||||
common.Must(
|
||||
binary.Write(response, binary.BigEndian, hello.SendBPS),
|
||||
binary.Write(response, binary.BigEndian, hello.RecvBPS),
|
||||
binary.Write(response, binary.BigEndian, uint16(len(hello.Message))),
|
||||
common.Error(response.WriteString(hello.Message)),
|
||||
)
|
||||
return common.Error(stream.Write(response.Bytes()))
|
||||
}
|
||||
|
||||
type ClientRequest struct {
|
||||
UDP bool
|
||||
Host string
|
||||
Port uint16
|
||||
}
|
||||
|
||||
func ReadClientRequest(stream io.Reader) (*ClientRequest, error) {
|
||||
var clientRequest ClientRequest
|
||||
err := binary.Read(stream, binary.BigEndian, &clientRequest.UDP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var hostLen uint16
|
||||
err = binary.Read(stream, binary.BigEndian, &hostLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
host := make([]byte, hostLen)
|
||||
_, err = io.ReadFull(stream, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientRequest.Host = string(host)
|
||||
err = binary.Read(stream, binary.BigEndian, &clientRequest.Port)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &clientRequest, nil
|
||||
}
|
||||
|
||||
func WriteClientRequest(stream io.Writer, request ClientRequest) error {
|
||||
var requestLen int
|
||||
requestLen += 1 // udp
|
||||
requestLen += 2 // host len
|
||||
requestLen += len(request.Host)
|
||||
requestLen += 2 // port
|
||||
_buffer := buf.StackNewSize(requestLen)
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
if request.UDP {
|
||||
common.Must(buffer.WriteByte(1))
|
||||
} else {
|
||||
common.Must(buffer.WriteByte(0))
|
||||
}
|
||||
common.Must(
|
||||
binary.Write(buffer, binary.BigEndian, uint16(len(request.Host))),
|
||||
common.Error(buffer.WriteString(request.Host)),
|
||||
binary.Write(buffer, binary.BigEndian, request.Port),
|
||||
)
|
||||
return common.Error(stream.Write(buffer.Bytes()))
|
||||
}
|
||||
|
||||
type ServerResponse struct {
|
||||
OK bool
|
||||
UDPSessionID uint32
|
||||
Message string
|
||||
}
|
||||
|
||||
func ReadServerResponse(stream io.Reader) (*ServerResponse, error) {
|
||||
var responseLen int
|
||||
responseLen += 1 // ok
|
||||
responseLen += 4 // udp session id
|
||||
responseLen += 2 // message len
|
||||
_response := buf.StackNewSize(responseLen)
|
||||
defer common.KeepAlive(_response)
|
||||
response := common.Dup(_response)
|
||||
defer response.Release()
|
||||
_, err := response.ReadFullFrom(stream, responseLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var serverResponse ServerResponse
|
||||
serverResponse.OK = response.Byte(0) == 1
|
||||
serverResponse.UDPSessionID = binary.BigEndian.Uint32(response.Range(1, 5))
|
||||
messageLen := binary.BigEndian.Uint16(response.Range(5, 7))
|
||||
if messageLen == 0 {
|
||||
return &serverResponse, nil
|
||||
}
|
||||
message := make([]byte, messageLen)
|
||||
_, err = io.ReadFull(stream, message)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
serverResponse.Message = string(message)
|
||||
return &serverResponse, nil
|
||||
}
|
||||
|
||||
func WriteServerResponse(stream io.Writer, response ServerResponse) error {
|
||||
var responseLen int
|
||||
responseLen += 1 // ok
|
||||
responseLen += 4 // udp session id
|
||||
responseLen += 2 // message len
|
||||
responseLen += len(response.Message)
|
||||
_buffer := buf.StackNewSize(responseLen)
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
if response.OK {
|
||||
common.Must(buffer.WriteByte(1))
|
||||
} else {
|
||||
common.Must(buffer.WriteByte(0))
|
||||
}
|
||||
common.Must(
|
||||
binary.Write(buffer, binary.BigEndian, response.UDPSessionID),
|
||||
binary.Write(buffer, binary.BigEndian, uint16(len(response.Message))),
|
||||
common.Error(buffer.WriteString(response.Message)),
|
||||
)
|
||||
return common.Error(stream.Write(buffer.Bytes()))
|
||||
}
|
||||
|
||||
type UDPMessage struct {
|
||||
SessionID uint32
|
||||
Host string
|
||||
Port uint16
|
||||
MsgID uint16 // doesn't matter when not fragmented, but must not be 0 when fragmented
|
||||
FragID uint8 // doesn't matter when not fragmented, starts at 0 when fragmented
|
||||
FragCount uint8 // must be 1 when not fragmented
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func (m UDPMessage) HeaderSize() int {
|
||||
return 4 + 2 + len(m.Host) + 2 + 2 + 1 + 1 + 2
|
||||
}
|
||||
|
||||
func (m UDPMessage) Size() int {
|
||||
return m.HeaderSize() + len(m.Data)
|
||||
}
|
||||
|
||||
func ParseUDPMessage(packet []byte) (message UDPMessage, err error) {
|
||||
reader := bytes.NewReader(packet)
|
||||
err = binary.Read(reader, binary.BigEndian, &message.SessionID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var hostLen uint16
|
||||
err = binary.Read(reader, binary.BigEndian, &hostLen)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = reader.Seek(int64(hostLen), io.SeekCurrent)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
message.Host = string(packet[6 : 6+hostLen])
|
||||
err = binary.Read(reader, binary.BigEndian, &message.Port)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = binary.Read(reader, binary.BigEndian, &message.MsgID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = binary.Read(reader, binary.BigEndian, &message.FragID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = binary.Read(reader, binary.BigEndian, &message.FragCount)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var dataLen uint16
|
||||
err = binary.Read(reader, binary.BigEndian, &dataLen)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if reader.Len() != int(dataLen) {
|
||||
err = E.New("invalid data length")
|
||||
}
|
||||
dataOffset := int(reader.Size()) - reader.Len()
|
||||
message.Data = packet[dataOffset:]
|
||||
return
|
||||
}
|
||||
|
||||
func WriteUDPMessage(conn quic.Connection, message UDPMessage) error {
|
||||
var messageLen int
|
||||
messageLen += 4 // session id
|
||||
messageLen += 2 // host len
|
||||
messageLen += len(message.Host)
|
||||
messageLen += 2 // port
|
||||
messageLen += 2 // msg id
|
||||
messageLen += 1 // frag id
|
||||
messageLen += 1 // frag count
|
||||
messageLen += 2 // data len
|
||||
messageLen += len(message.Data)
|
||||
_buffer := buf.StackNewSize(messageLen)
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
err := writeUDPMessage(conn, message, buffer)
|
||||
if errSize, ok := err.(quic.ErrMessageToLarge); ok {
|
||||
// need to frag
|
||||
message.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1
|
||||
fragMsgs := FragUDPMessage(message, int(errSize))
|
||||
for _, fragMsg := range fragMsgs {
|
||||
buffer.FullReset()
|
||||
err = writeUDPMessage(conn, fragMsg, buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func writeUDPMessage(conn quic.Connection, message UDPMessage, buffer *buf.Buffer) error {
|
||||
common.Must(
|
||||
binary.Write(buffer, binary.BigEndian, message.SessionID),
|
||||
binary.Write(buffer, binary.BigEndian, uint16(len(message.Host))),
|
||||
common.Error(buffer.WriteString(message.Host)),
|
||||
binary.Write(buffer, binary.BigEndian, message.Port),
|
||||
binary.Write(buffer, binary.BigEndian, message.MsgID),
|
||||
binary.Write(buffer, binary.BigEndian, message.FragID),
|
||||
binary.Write(buffer, binary.BigEndian, message.FragCount),
|
||||
binary.Write(buffer, binary.BigEndian, uint16(len(message.Data))),
|
||||
common.Error(buffer.Write(message.Data)),
|
||||
)
|
||||
return conn.SendMessage(buffer.Bytes())
|
||||
}
|
||||
|
||||
var _ net.Conn = (*Conn)(nil)
|
||||
|
||||
type Conn struct {
|
||||
quic.Stream
|
||||
destination M.Socksaddr
|
||||
responseWritten bool
|
||||
}
|
||||
|
||||
func NewConn(stream quic.Stream, destination M.Socksaddr) *Conn {
|
||||
return &Conn{
|
||||
Stream: stream,
|
||||
destination: destination,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.destination.TCPAddr()
|
||||
}
|
||||
|
||||
func (c *Conn) ReaderReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Conn) WriterReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Conn) Upstream() any {
|
||||
return c.Stream
|
||||
}
|
||||
|
||||
type PacketConn struct {
|
||||
session quic.Connection
|
||||
stream quic.Stream
|
||||
sessionId uint32
|
||||
destination M.Socksaddr
|
||||
msgCh <-chan *UDPMessage
|
||||
closer io.Closer
|
||||
}
|
||||
|
||||
func NewPacketConn(session quic.Connection, stream quic.Stream, sessionId uint32, destination M.Socksaddr, msgCh <-chan *UDPMessage, closer io.Closer) *PacketConn {
|
||||
return &PacketConn{
|
||||
session: session,
|
||||
stream: stream,
|
||||
sessionId: sessionId,
|
||||
destination: destination,
|
||||
msgCh: msgCh,
|
||||
closer: closer,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *PacketConn) Hold() {
|
||||
// Hold the stream until it's closed
|
||||
buf := make([]byte, 1024)
|
||||
for {
|
||||
_, err := c.stream.Read(buf)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
_ = c.Close()
|
||||
}
|
||||
|
||||
func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
msg := <-c.msgCh
|
||||
if msg == nil {
|
||||
err = net.ErrClosed
|
||||
return
|
||||
}
|
||||
err = common.Error(buffer.Write(msg.Data))
|
||||
destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *PacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
msg := <-c.msgCh
|
||||
if msg == nil {
|
||||
err = net.ErrClosed
|
||||
return
|
||||
}
|
||||
buffer = buf.As(msg.Data)
|
||||
destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
return WriteUDPMessage(c.session, UDPMessage{
|
||||
SessionID: c.sessionId,
|
||||
Host: destination.Unwrap().AddrString(),
|
||||
Port: destination.Port,
|
||||
FragCount: 1,
|
||||
Data: buffer.Bytes(),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *PacketConn) LocalAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *PacketConn) RemoteAddr() net.Addr {
|
||||
return c.destination.UDPAddr()
|
||||
}
|
||||
|
||||
func (c *PacketConn) SetDeadline(t time.Time) error {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
func (c *PacketConn) SetReadDeadline(t time.Time) error {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
func (c *PacketConn) SetWriteDeadline(t time.Time) error {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
return 0, nil, os.ErrInvalid
|
||||
}
|
||||
|
||||
func (c *PacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
return 0, os.ErrInvalid
|
||||
}
|
||||
|
||||
func (c *PacketConn) Read(b []byte) (n int, err error) {
|
||||
return 0, os.ErrInvalid
|
||||
}
|
||||
|
||||
func (c *PacketConn) Write(b []byte) (n int, err error) {
|
||||
return 0, os.ErrInvalid
|
||||
}
|
||||
|
||||
func (c *PacketConn) Close() error {
|
||||
return common.Close(c.stream, c.closer)
|
||||
}
|
||||
36
transport/hysteria/speed.go
Normal file
36
transport/hysteria/speed.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package hysteria
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func StringToBps(s string) uint64 {
|
||||
if s == "" {
|
||||
return 0
|
||||
}
|
||||
m := regexp.MustCompile(`^(\d+)\s*([KMGT]?)([Bb])ps$`).FindStringSubmatch(s)
|
||||
if m == nil {
|
||||
return 0
|
||||
}
|
||||
var n uint64
|
||||
switch m[2] {
|
||||
case "K":
|
||||
n = 1 << 10
|
||||
case "M":
|
||||
n = 1 << 20
|
||||
case "G":
|
||||
n = 1 << 30
|
||||
case "T":
|
||||
n = 1 << 40
|
||||
default:
|
||||
n = 1
|
||||
}
|
||||
v, _ := strconv.ParseUint(m[1], 10, 64)
|
||||
n = v * n
|
||||
if m[3] == "b" {
|
||||
// Bits, need to convert to bytes
|
||||
n = n >> 3
|
||||
}
|
||||
return n
|
||||
}
|
||||
56
transport/hysteria/wrap.go
Normal file
56
transport/hysteria/wrap.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package hysteria
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/quic-go"
|
||||
"github.com/sagernet/sing/common"
|
||||
)
|
||||
|
||||
type PacketConnWrapper struct {
|
||||
net.PacketConn
|
||||
}
|
||||
|
||||
func (c *PacketConnWrapper) SetReadBuffer(bytes int) error {
|
||||
return common.MustCast[*net.UDPConn](c.PacketConn).SetReadBuffer(bytes)
|
||||
}
|
||||
|
||||
func (c *PacketConnWrapper) SetWriteBuffer(bytes int) error {
|
||||
return common.MustCast[*net.UDPConn](c.PacketConn).SetWriteBuffer(bytes)
|
||||
}
|
||||
|
||||
func (c *PacketConnWrapper) SyscallConn() (syscall.RawConn, error) {
|
||||
return common.MustCast[*net.UDPConn](c.PacketConn).SyscallConn()
|
||||
}
|
||||
|
||||
func (c *PacketConnWrapper) File() (f *os.File, err error) {
|
||||
return common.MustCast[*net.UDPConn](c.PacketConn).File()
|
||||
}
|
||||
|
||||
func (c *PacketConnWrapper) Upstream() any {
|
||||
return c.PacketConn
|
||||
}
|
||||
|
||||
type StreamWrapper struct {
|
||||
quic.Stream
|
||||
}
|
||||
|
||||
func (s *StreamWrapper) Upstream() any {
|
||||
return s.Stream
|
||||
}
|
||||
|
||||
func (s *StreamWrapper) ReaderReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *StreamWrapper) WriterReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *StreamWrapper) Close() error {
|
||||
s.CancelRead(0)
|
||||
s.Stream.Close()
|
||||
return nil
|
||||
}
|
||||
119
transport/hysteria/xplus.go
Normal file
119
transport/hysteria/xplus.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package hysteria
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
const xplusSaltLen = 16
|
||||
|
||||
var errInalidPacket = E.New("invalid packet")
|
||||
|
||||
func NewXPlusPacketConn(conn net.PacketConn, key []byte) net.PacketConn {
|
||||
vectorisedWriter, isVectorised := bufio.CreateVectorisedPacketWriter(conn)
|
||||
if isVectorised {
|
||||
return &VectorisedXPlusConn{
|
||||
XPlusPacketConn: XPlusPacketConn{
|
||||
PacketConn: conn,
|
||||
key: key,
|
||||
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
},
|
||||
writer: vectorisedWriter,
|
||||
}
|
||||
} else {
|
||||
return &XPlusPacketConn{
|
||||
PacketConn: conn,
|
||||
key: key,
|
||||
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type XPlusPacketConn struct {
|
||||
net.PacketConn
|
||||
key []byte
|
||||
randAccess sync.Mutex
|
||||
rand *rand.Rand
|
||||
}
|
||||
|
||||
func (c *XPlusPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
n, addr, err = c.PacketConn.ReadFrom(p)
|
||||
if err != nil {
|
||||
return
|
||||
} else if n < xplusSaltLen {
|
||||
return 0, nil, errInalidPacket
|
||||
}
|
||||
key := sha256.Sum256(append(c.key, p[:xplusSaltLen]...))
|
||||
for i := range p[xplusSaltLen:] {
|
||||
p[i] = p[xplusSaltLen+i] ^ key[i%sha256.Size]
|
||||
}
|
||||
n -= xplusSaltLen
|
||||
return
|
||||
}
|
||||
|
||||
func (c *XPlusPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
// can't use unsafe buffer on WriteTo
|
||||
buffer := buf.NewSize(len(p) + xplusSaltLen)
|
||||
defer buffer.Release()
|
||||
salt := buffer.Extend(xplusSaltLen)
|
||||
c.randAccess.Lock()
|
||||
_, _ = c.rand.Read(salt)
|
||||
c.randAccess.Unlock()
|
||||
key := sha256.Sum256(append(c.key, salt...))
|
||||
for i := range p {
|
||||
common.Must(buffer.WriteByte(p[i] ^ key[i%sha256.Size]))
|
||||
}
|
||||
return c.PacketConn.WriteTo(buffer.Bytes(), addr)
|
||||
}
|
||||
|
||||
func (c *XPlusPacketConn) Upstream() any {
|
||||
return c.PacketConn
|
||||
}
|
||||
|
||||
type VectorisedXPlusConn struct {
|
||||
XPlusPacketConn
|
||||
writer N.VectorisedPacketWriter
|
||||
}
|
||||
|
||||
func (c *VectorisedXPlusConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
header := buf.NewSize(xplusSaltLen)
|
||||
defer header.Release()
|
||||
salt := header.Extend(xplusSaltLen)
|
||||
c.randAccess.Lock()
|
||||
_, _ = c.rand.Read(salt)
|
||||
c.randAccess.Unlock()
|
||||
key := sha256.Sum256(append(c.key, salt...))
|
||||
for i := range p {
|
||||
p[i] ^= key[i%sha256.Size]
|
||||
}
|
||||
return bufio.WriteVectorisedPacket(c.writer, [][]byte{header.Bytes(), p}, M.SocksaddrFromNet(addr))
|
||||
}
|
||||
|
||||
func (c *VectorisedXPlusConn) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error {
|
||||
header := buf.NewSize(xplusSaltLen)
|
||||
salt := header.Extend(xplusSaltLen)
|
||||
c.randAccess.Lock()
|
||||
_, _ = c.rand.Read(salt)
|
||||
c.randAccess.Unlock()
|
||||
key := sha256.Sum256(append(c.key, salt...))
|
||||
var index int
|
||||
for _, buffer := range buffers {
|
||||
data := buffer.Bytes()
|
||||
for i := range data {
|
||||
data[i] ^= key[index%sha256.Size]
|
||||
index++
|
||||
}
|
||||
}
|
||||
buffers = append([]*buf.Buffer{header}, buffers...)
|
||||
return c.writer.WriteVectorisedPacket(buffers, destination)
|
||||
}
|
||||
Reference in New Issue
Block a user