|
|
|
|
@ -2,11 +2,13 @@ use crate::constants::*;
|
|
|
|
|
use crate::handshake; |
|
|
|
|
use crate::router; |
|
|
|
|
use crate::timers::{Events, Timers}; |
|
|
|
|
use crate::types::{Bind, Endpoint, Tun}; |
|
|
|
|
|
|
|
|
|
use crate::types::Endpoint; |
|
|
|
|
use crate::types::tun::{Tun, Reader, MTU}; |
|
|
|
|
use crate::types::bind::{Bind, Writer}; |
|
|
|
|
|
|
|
|
|
use hjul::Runner; |
|
|
|
|
|
|
|
|
|
use std::cmp; |
|
|
|
|
use std::ops::Deref; |
|
|
|
|
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; |
|
|
|
|
use std::sync::Arc; |
|
|
|
|
@ -27,12 +29,20 @@ const SIZE_HANDSHAKE_QUEUE: usize = 128;
|
|
|
|
|
const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4; |
|
|
|
|
const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000); |
|
|
|
|
|
|
|
|
|
#[derive(Clone)] |
|
|
|
|
pub struct Peer<T: Tun, B: Bind> { |
|
|
|
|
pub router: Arc<router::Peer<Events<T, B>, T, B>>, |
|
|
|
|
pub router: Arc<router::Peer<B::Endpoint, Events<T, B>, T::Writer, B::Writer>>, |
|
|
|
|
pub state: Arc<PeerInner<B>>, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
impl <T : Tun, B : Bind> Clone for Peer<T, B > { |
|
|
|
|
fn clone(&self) -> Peer<T, B> { |
|
|
|
|
Peer{ |
|
|
|
|
router: self.router.clone(), |
|
|
|
|
state: self.state.clone() |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
pub struct PeerInner<B: Bind> { |
|
|
|
|
pub keepalive: AtomicUsize, // keepalive interval
|
|
|
|
|
pub rx_bytes: AtomicU64, |
|
|
|
|
@ -66,20 +76,22 @@ pub enum HandshakeJob<E> {
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
struct WireguardInner<T: Tun, B: Bind> { |
|
|
|
|
// provides access to the MTU value of the tun device
|
|
|
|
|
// (otherwise owned solely by the router and a dedicated read IO thread)
|
|
|
|
|
mtu: T::MTU, |
|
|
|
|
send: RwLock<Option<B::Writer>>, |
|
|
|
|
|
|
|
|
|
// identify and configuration map
|
|
|
|
|
peers: RwLock<HashMap<[u8; 32], Peer<T, B>>>, |
|
|
|
|
|
|
|
|
|
// cryptkey router
|
|
|
|
|
router: router::Device<Events<T, B>, T, B>, |
|
|
|
|
router: router::Device<B::Endpoint, Events<T, B>, T::Writer, B::Writer>, |
|
|
|
|
|
|
|
|
|
// handshake related state
|
|
|
|
|
handshake: RwLock<Handshake>, |
|
|
|
|
under_load: AtomicBool, |
|
|
|
|
pending: AtomicUsize, // num of pending handshake packets in queue
|
|
|
|
|
queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>, |
|
|
|
|
|
|
|
|
|
// IO
|
|
|
|
|
bind: B, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
pub struct Wireguard<T: Tun, B: Bind> { |
|
|
|
|
@ -87,6 +99,17 @@ pub struct Wireguard<T: Tun, B: Bind> {
|
|
|
|
|
state: Arc<WireguardInner<T, B>>, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/* Returns the padded length of a message:
|
|
|
|
|
* |
|
|
|
|
* # Arguments |
|
|
|
|
* |
|
|
|
|
* - `size` : Size of unpadded message |
|
|
|
|
* - `mtu` : Maximum transmission unit of the device |
|
|
|
|
* |
|
|
|
|
* # Returns |
|
|
|
|
* |
|
|
|
|
* The padded length (always less than or equal to the MTU) |
|
|
|
|
*/ |
|
|
|
|
#[inline(always)] |
|
|
|
|
const fn padding(size: usize, mtu: usize) -> usize { |
|
|
|
|
#[inline(always)] |
|
|
|
|
@ -114,6 +137,15 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
pub fn get_sk(&self) -> Option<StaticSecret> { |
|
|
|
|
let mut handshake = self.state.handshake.read(); |
|
|
|
|
if handshake.active { |
|
|
|
|
Some(handshake.device.get_sk()) |
|
|
|
|
} else { |
|
|
|
|
None |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
pub fn new_peer(&self, pk: PublicKey) -> Peer<T, B> { |
|
|
|
|
let state = Arc::new(PeerInner { |
|
|
|
|
pk, |
|
|
|
|
@ -137,20 +169,92 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|
|
|
|
peer |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
pub fn new(tun: T, bind: B) -> Wireguard<T, B> { |
|
|
|
|
pub fn new_bind( |
|
|
|
|
reader: B::Reader, |
|
|
|
|
writer: B::Writer, |
|
|
|
|
closer: B::Closer |
|
|
|
|
) { |
|
|
|
|
|
|
|
|
|
// drop existing closer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// swap IO thread for new reader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// start UDP read IO thread
|
|
|
|
|
|
|
|
|
|
/* |
|
|
|
|
{ |
|
|
|
|
let wg = wg.clone(); |
|
|
|
|
let mtu = mtu.clone(); |
|
|
|
|
thread::spawn(move || { |
|
|
|
|
let mut last_under_load = |
|
|
|
|
Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000); |
|
|
|
|
|
|
|
|
|
loop { |
|
|
|
|
// create vector big enough for any message given current MTU
|
|
|
|
|
let size = mtu.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE; |
|
|
|
|
let mut msg: Vec<u8> = Vec::with_capacity(size); |
|
|
|
|
msg.resize(size, 0); |
|
|
|
|
|
|
|
|
|
// read UDP packet into vector
|
|
|
|
|
let (size, src) = reader.read(&mut msg).unwrap(); // TODO handle error
|
|
|
|
|
msg.truncate(size); |
|
|
|
|
|
|
|
|
|
// message type de-multiplexer
|
|
|
|
|
if msg.len() < std::mem::size_of::<u32>() { |
|
|
|
|
continue; |
|
|
|
|
} |
|
|
|
|
match LittleEndian::read_u32(&msg[..]) { |
|
|
|
|
handshake::TYPE_COOKIE_REPLY |
|
|
|
|
| handshake::TYPE_INITIATION |
|
|
|
|
| handshake::TYPE_RESPONSE => { |
|
|
|
|
// update under_load flag
|
|
|
|
|
if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD { |
|
|
|
|
last_under_load = Instant::now(); |
|
|
|
|
wg.under_load.store(true, Ordering::SeqCst); |
|
|
|
|
} else if last_under_load.elapsed() > DURATION_UNDER_LOAD { |
|
|
|
|
wg.under_load.store(false, Ordering::SeqCst); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
wg.queue |
|
|
|
|
.lock() |
|
|
|
|
.send(HandshakeJob::Message(msg, src)) |
|
|
|
|
.unwrap(); |
|
|
|
|
} |
|
|
|
|
router::TYPE_TRANSPORT => { |
|
|
|
|
// transport message
|
|
|
|
|
let _ = wg.router.recv(src, msg); |
|
|
|
|
} |
|
|
|
|
_ => (), |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
}); |
|
|
|
|
} |
|
|
|
|
*/ |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
pub fn new( |
|
|
|
|
reader: T::Reader,
|
|
|
|
|
writer: T::Writer,
|
|
|
|
|
mtu: T::MTU, |
|
|
|
|
) -> Wireguard<T, B> { |
|
|
|
|
// create device state
|
|
|
|
|
let mut rng = OsRng::new().unwrap(); |
|
|
|
|
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE); |
|
|
|
|
let wg = Arc::new(WireguardInner { |
|
|
|
|
mtu: mtu.clone(), |
|
|
|
|
peers: RwLock::new(HashMap::new()), |
|
|
|
|
router: router::Device::new(num_cpus::get(), tun.clone(), bind.clone()), |
|
|
|
|
send: RwLock::new(None), |
|
|
|
|
router: router::Device::new(num_cpus::get(), writer), // router owns the writing half
|
|
|
|
|
pending: AtomicUsize::new(0), |
|
|
|
|
handshake: RwLock::new(Handshake { |
|
|
|
|
device: handshake::Device::new(StaticSecret::new(&mut rng)), |
|
|
|
|
active: false, |
|
|
|
|
}), |
|
|
|
|
under_load: AtomicBool::new(false), |
|
|
|
|
bind: bind.clone(), |
|
|
|
|
queue: Mutex::new(tx), |
|
|
|
|
}); |
|
|
|
|
|
|
|
|
|
@ -158,7 +262,6 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|
|
|
|
for _ in 0..num_cpus::get() { |
|
|
|
|
let wg = wg.clone(); |
|
|
|
|
let rx = rx.clone(); |
|
|
|
|
let bind = bind.clone(); |
|
|
|
|
thread::spawn(move || { |
|
|
|
|
// prepare OsRng instance for this thread
|
|
|
|
|
let mut rng = OsRng::new().unwrap(); |
|
|
|
|
@ -189,19 +292,22 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|
|
|
|
Ok((pk, msg, keypair)) => { |
|
|
|
|
// send response
|
|
|
|
|
if let Some(msg) = msg { |
|
|
|
|
let _ = bind.send(&msg[..], &src).map_err(|e| { |
|
|
|
|
debug!( |
|
|
|
|
"handshake worker, failed to send response, error = {:?}", |
|
|
|
|
e |
|
|
|
|
) |
|
|
|
|
}); |
|
|
|
|
let send : &Option<B::Writer> = &*wg.send.read(); |
|
|
|
|
if let Some(writer) = send.as_ref() { |
|
|
|
|
let _ = writer.write(&msg[..], &src).map_err(|e| { |
|
|
|
|
debug!( |
|
|
|
|
"handshake worker, failed to send response, error = {:?}", |
|
|
|
|
e |
|
|
|
|
) |
|
|
|
|
}); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// update timers
|
|
|
|
|
if let Some(pk) = pk { |
|
|
|
|
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { |
|
|
|
|
// update endpoint (DISCUSS: right semantics?)
|
|
|
|
|
peer.router.set_endpoint(src_validate); |
|
|
|
|
// update endpoint
|
|
|
|
|
peer.router.set_endpoint(src); |
|
|
|
|
|
|
|
|
|
// add keypair to peer and free any unused ids
|
|
|
|
|
if let Some(keypair) = keypair { |
|
|
|
|
@ -227,68 +333,18 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|
|
|
|
}); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// start UDP read IO thread
|
|
|
|
|
{ |
|
|
|
|
let wg = wg.clone(); |
|
|
|
|
let tun = tun.clone(); |
|
|
|
|
let bind = bind.clone(); |
|
|
|
|
thread::spawn(move || { |
|
|
|
|
let mut last_under_load = |
|
|
|
|
Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000); |
|
|
|
|
|
|
|
|
|
loop { |
|
|
|
|
// create vector big enough for any message given current MTU
|
|
|
|
|
let size = tun.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE; |
|
|
|
|
let mut msg: Vec<u8> = Vec::with_capacity(size); |
|
|
|
|
msg.resize(size, 0); |
|
|
|
|
|
|
|
|
|
// read UDP packet into vector
|
|
|
|
|
let (size, src) = bind.recv(&mut msg).unwrap(); // TODO handle error
|
|
|
|
|
msg.truncate(size); |
|
|
|
|
|
|
|
|
|
// message type de-multiplexer
|
|
|
|
|
if msg.len() < std::mem::size_of::<u32>() { |
|
|
|
|
continue; |
|
|
|
|
} |
|
|
|
|
match LittleEndian::read_u32(&msg[..]) { |
|
|
|
|
handshake::TYPE_COOKIE_REPLY |
|
|
|
|
| handshake::TYPE_INITIATION |
|
|
|
|
| handshake::TYPE_RESPONSE => { |
|
|
|
|
// update under_load flag
|
|
|
|
|
if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD { |
|
|
|
|
last_under_load = Instant::now(); |
|
|
|
|
wg.under_load.store(true, Ordering::SeqCst); |
|
|
|
|
} else if last_under_load.elapsed() > DURATION_UNDER_LOAD { |
|
|
|
|
wg.under_load.store(false, Ordering::SeqCst); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
wg.queue |
|
|
|
|
.lock() |
|
|
|
|
.send(HandshakeJob::Message(msg, src)) |
|
|
|
|
.unwrap(); |
|
|
|
|
} |
|
|
|
|
router::TYPE_TRANSPORT => { |
|
|
|
|
// transport message
|
|
|
|
|
let _ = wg.router.recv(src, msg); |
|
|
|
|
} |
|
|
|
|
_ => (), |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
}); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// start TUN read IO thread
|
|
|
|
|
{ |
|
|
|
|
let wg = wg.clone(); |
|
|
|
|
thread::spawn(move || loop { |
|
|
|
|
// create vector big enough for any transport message (based on MTU)
|
|
|
|
|
let mtu = tun.mtu(); |
|
|
|
|
let mtu = mtu.mtu(); |
|
|
|
|
let size = mtu + router::SIZE_MESSAGE_PREFIX; |
|
|
|
|
let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX); |
|
|
|
|
msg.resize(size, 0); |
|
|
|
|
|
|
|
|
|
// read a new IP packet
|
|
|
|
|
let payload = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap(); |
|
|
|
|
let payload = reader.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap(); |
|
|
|
|
debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu); |
|
|
|
|
|
|
|
|
|
// truncate padding
|
|
|
|
|
|