diff --git a/common/dialer/dialer.go b/common/dialer/dialer.go index bfa8af215..2ba559f9e 100644 --- a/common/dialer/dialer.go +++ b/common/dialer/dialer.go @@ -145,3 +145,7 @@ type ParallelNetworkDialer interface { DialParallelNetwork(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) ListenSerialNetworkPacket(ctx context.Context, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, netip.Addr, error) } + +type PacketDialerWithDestination interface { + ListenPacketWithDestination(ctx context.Context, destination M.Socksaddr) (net.PacketConn, netip.Addr, error) +} diff --git a/protocol/tailscale/endpoint.go b/protocol/tailscale/endpoint.go index 659277d96..ff82ef86e 100644 --- a/protocol/tailscale/endpoint.go +++ b/protocol/tailscale/endpoint.go @@ -63,6 +63,7 @@ import ( var ( _ adapter.OutboundWithPreferredRoutes = (*Endpoint)(nil) _ adapter.DirectRouteOutbound = (*Endpoint)(nil) + _ dialer.PacketDialerWithDestination = (*Endpoint)(nil) ) func init() { @@ -518,19 +519,7 @@ func (t *Endpoint) DialContext(ctx context.Context, network string, destination } } -func (t *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { - t.logger.InfoContext(ctx, "outbound packet connection to ", destination) - if destination.IsFqdn() { - destinationAddresses, err := t.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{}) - if err != nil { - return nil, err - } - packetConn, _, err := N.ListenSerial(ctx, t, destination, destinationAddresses) - if err != nil { - return nil, err - } - return packetConn, err - } +func (t *Endpoint) listenPacketWithAddress(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { addr4, addr6 := t.server.TailscaleIPs() bind := tcpip.FullAddress{ NIC: 1, @@ -556,6 +545,44 @@ func (t *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (n return udpConn, nil } +func (t *Endpoint) ListenPacketWithDestination(ctx context.Context, destination M.Socksaddr) (net.PacketConn, netip.Addr, error) { + t.logger.InfoContext(ctx, "outbound packet connection to ", destination) + if destination.IsFqdn() { + destinationAddresses, err := t.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{}) + if err != nil { + return nil, netip.Addr{}, err + } + var errors []error + for _, address := range destinationAddresses { + packetConn, packetErr := t.listenPacketWithAddress(ctx, M.SocksaddrFrom(address, destination.Port)) + if packetErr == nil { + return packetConn, address, nil + } + errors = append(errors, packetErr) + } + return nil, netip.Addr{}, E.Errors(errors...) + } + packetConn, err := t.listenPacketWithAddress(ctx, destination) + if err != nil { + return nil, netip.Addr{}, err + } + if destination.IsIP() { + return packetConn, destination.Addr, nil + } + return packetConn, netip.Addr{}, nil +} + +func (t *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + packetConn, destinationAddress, err := t.ListenPacketWithDestination(ctx, destination) + if err != nil { + return nil, err + } + if destinationAddress.IsValid() && destination != M.SocksaddrFrom(destinationAddress, destination.Port) { + return bufio.NewNATPacketConn(bufio.NewPacketConn(packetConn), M.SocksaddrFrom(destinationAddress, destination.Port), destination), nil + } + return packetConn, nil +} + func (t *Endpoint) PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) { tsFilter := t.filter.Load() if tsFilter != nil { diff --git a/protocol/wireguard/endpoint.go b/protocol/wireguard/endpoint.go index 35ffd19e3..bcf2078ee 100644 --- a/protocol/wireguard/endpoint.go +++ b/protocol/wireguard/endpoint.go @@ -24,7 +24,10 @@ import ( "github.com/sagernet/sing/service" ) -var _ adapter.OutboundWithPreferredRoutes = (*Endpoint)(nil) +var ( + _ adapter.OutboundWithPreferredRoutes = (*Endpoint)(nil) + _ dialer.PacketDialerWithDestination = (*Endpoint)(nil) +) func RegisterEndpoint(registry *endpoint.Registry) { endpoint.Register[option.WireGuardEndpointOptions](registry, C.TypeWireGuard, NewEndpoint) @@ -219,20 +222,34 @@ func (w *Endpoint) DialContext(ctx context.Context, network string, destination return w.endpoint.DialContext(ctx, network, destination) } -func (w *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { +func (w *Endpoint) ListenPacketWithDestination(ctx context.Context, destination M.Socksaddr) (net.PacketConn, netip.Addr, error) { w.logger.InfoContext(ctx, "outbound packet connection to ", destination) if destination.IsFqdn() { destinationAddresses, err := w.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{}) if err != nil { - return nil, err + return nil, netip.Addr{}, err } - packetConn, _, err := N.ListenSerial(ctx, w.endpoint, destination, destinationAddresses) - if err != nil { - return nil, err - } - return packetConn, err + return N.ListenSerial(ctx, w.endpoint, destination, destinationAddresses) } - return w.endpoint.ListenPacket(ctx, destination) + packetConn, err := w.endpoint.ListenPacket(ctx, destination) + if err != nil { + return nil, netip.Addr{}, err + } + if destination.IsIP() { + return packetConn, destination.Addr, nil + } + return packetConn, netip.Addr{}, nil +} + +func (w *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + packetConn, destinationAddress, err := w.ListenPacketWithDestination(ctx, destination) + if err != nil { + return nil, err + } + if destinationAddress.IsValid() && destination != M.SocksaddrFrom(destinationAddress, destination.Port) { + return bufio.NewNATPacketConn(bufio.NewPacketConn(packetConn), M.SocksaddrFrom(destinationAddress, destination.Port), destination), nil + } + return packetConn, nil } func (w *Endpoint) PreferredDomain(domain string) bool { diff --git a/route/conn.go b/route/conn.go index 899d29391..59afe5394 100644 --- a/route/conn.go +++ b/route/conn.go @@ -188,6 +188,8 @@ func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dial } else { if len(metadata.DestinationAddresses) > 0 { remotePacketConn, destinationAddress, err = dialer.ListenSerialNetworkPacket(ctx, this, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay) + } else if packetDialer, withDestination := this.(dialer.PacketDialerWithDestination); withDestination { + remotePacketConn, destinationAddress, err = packetDialer.ListenPacketWithDestination(ctx, metadata.Destination) } else { remotePacketConn, err = this.ListenPacket(ctx, metadata.Destination) } @@ -218,11 +220,16 @@ func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dial } if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded { natConn.UpdateDestination(destinationAddress) - } else if metadata.Destination != M.SocksaddrFrom(destinationAddress, metadata.Destination.Port) { - if metadata.UDPDisableDomainUnmapping { - remotePacketConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(remotePacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination) - } else { - remotePacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(remotePacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination) + } else { + destination := M.SocksaddrFrom(destinationAddress, metadata.Destination.Port) + if metadata.Destination != destination { + if metadata.UDPDisableDomainUnmapping { + remotePacketConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(remotePacketConn), destination, originDestination) + } else { + remotePacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(remotePacketConn), destination, originDestination) + } + } else if metadata.RouteOriginalDestination.IsValid() && metadata.RouteOriginalDestination != metadata.Destination { + remotePacketConn = bufio.NewDestinationNATPacketConn(bufio.NewPacketConn(remotePacketConn), metadata.Destination, metadata.RouteOriginalDestination) } } } else if metadata.RouteOriginalDestination.IsValid() && metadata.RouteOriginalDestination != metadata.Destination {