diff --git a/protocol/tailscale/endpoint.go b/protocol/tailscale/endpoint.go index 5ec89feeb..c25588364 100644 --- a/protocol/tailscale/endpoint.go +++ b/protocol/tailscale/endpoint.go @@ -341,26 +341,42 @@ func (t *Endpoint) DialContext(ctx context.Context, network string, destination } return N.DialSerial(ctx, t, network, destination, destinationAddresses) } - addr := tcpip.FullAddress{ + addr4, addr6 := t.server.TailscaleIPs() + remoteAddr := tcpip.FullAddress{ NIC: 1, Port: destination.Port, Addr: addressFromAddr(destination.Addr), } + var localAddr tcpip.FullAddress var networkProtocol tcpip.NetworkProtocolNumber if destination.IsIPv4() { + if !addr4.IsValid() { + return nil, E.New("missing Tailscale IPv4 address") + } networkProtocol = header.IPv4ProtocolNumber + localAddr = tcpip.FullAddress{ + NIC: 1, + Addr: addressFromAddr(addr4), + } } else { + if !addr6.IsValid() { + return nil, E.New("missing Tailscale IPv6 address") + } networkProtocol = header.IPv6ProtocolNumber + localAddr = tcpip.FullAddress{ + NIC: 1, + Addr: addressFromAddr(addr6), + } } switch N.NetworkName(network) { case N.NetworkTCP: - tcpConn, err := gonet.DialContextTCP(ctx, t.stack, addr, networkProtocol) + tcpConn, err := gonet.DialTCPWithBind(ctx, t.stack, localAddr, remoteAddr, networkProtocol) if err != nil { return nil, err } return tcpConn, nil case N.NetworkUDP: - udpConn, err := gonet.DialUDP(t.stack, nil, &addr, networkProtocol) + udpConn, err := gonet.DialUDP(t.stack, &localAddr, &remoteAddr, networkProtocol) if err != nil { return nil, err }