Browse Source

Work on router optimizations

router
Mathias Hall-Andersen 6 years ago
parent
commit
106c5e8b5c
  1. 3
      src/wireguard/queue.rs
  2. 4
      src/wireguard/router/constants.rs
  3. 100
      src/wireguard/router/device.rs
  4. 190
      src/wireguard/router/inbound.rs
  5. 5
      src/wireguard/router/mod.rs
  6. 110
      src/wireguard/router/outbound.rs
  7. 192
      src/wireguard/router/peer.rs
  8. 164
      src/wireguard/router/pool.rs
  9. 92
      src/wireguard/router/queue.rs
  10. 184
      src/wireguard/router/receive.rs
  11. 164
      src/wireguard/router/runq.rs
  12. 95
      src/wireguard/router/send.rs
  13. 30
      src/wireguard/router/worker.rs
  14. 257
      src/wireguard/router/workers.rs
  15. 2
      src/wireguard/wireguard.rs

3
src/wireguard/queue.rs

@ -1,6 +1,7 @@
use crossbeam_channel::{bounded, Receiver, Sender};
use std::sync::Mutex;
use crossbeam_channel::{bounded, Receiver, Sender};
pub struct ParallelQueue<T> {
queue: Mutex<Option<Sender<T>>>,
}

4
src/wireguard/router/constants.rs

@ -4,6 +4,6 @@ pub const MAX_QUEUED_PACKETS: usize = 1024;
// performance constants
pub const PARALLEL_QUEUE_SIZE: usize = MAX_QUEUED_PACKETS;
pub const PARALLEL_QUEUE_SIZE: usize = 4 * MAX_QUEUED_PACKETS;
pub const INORDER_QUEUE_SIZE: usize = MAX_QUEUED_PACKETS;
pub const MAX_INORDER_CONSUME: usize = INORDER_QUEUE_SIZE;

100
src/wireguard/router/device.rs

@ -10,19 +10,16 @@ use spin::{Mutex, RwLock};
use zerocopy::LayoutVerified;
use super::anti_replay::AntiReplay;
use super::pool::Job;
use super::constants::PARALLEL_QUEUE_SIZE;
use super::inbound;
use super::outbound;
use super::messages::{TransportHeader, TYPE_TRANSPORT};
use super::peer::{new_peer, Peer, PeerHandle};
use super::types::{Callbacks, RouterError};
use super::SIZE_MESSAGE_PREFIX;
use super::receive::ReceiveJob;
use super::route::RoutingTable;
use super::runq::RunQueue;
use super::worker::{worker, JobUnion};
use super::super::{tun, udp, Endpoint, KeyPair};
use super::ParallelQueue;
@ -38,13 +35,8 @@ pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer
pub recv: RwLock<HashMap<u32, Arc<DecryptionState<E, C, T, B>>>>, // receiver id -> decryption state
pub table: RoutingTable<Peer<E, C, T, B>>,
// work queues
pub queue_outbound: ParallelQueue<Job<Peer<E, C, T, B>, outbound::Outbound>>,
pub queue_inbound: ParallelQueue<Job<Peer<E, C, T, B>, inbound::Inbound<E, C, T, B>>>,
// run queues
pub run_inbound: RunQueue<Peer<E, C, T, B>>,
pub run_outbound: RunQueue<Peer<E, C, T, B>>,
// work queue
pub work: ParallelQueue<JobUnion<E, C, T, B>>,
}
pub struct EncryptionState {
@ -101,13 +93,8 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop
fn drop(&mut self) {
debug!("router: dropping device");
// close worker queues
self.state.queue_outbound.close();
self.state.queue_inbound.close();
// close run queues
self.state.run_outbound.close();
self.state.run_inbound.close();
// close worker queue
self.state.work.close();
// join all worker threads
while match self.handles.pop() {
@ -118,24 +105,17 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop
}
_ => false,
} {}
debug!("router: device dropped");
}
}
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<E, C, T, B> {
pub fn new(num_workers: usize, tun: T) -> DeviceHandle<E, C, T, B> {
// allocate shared device state
let (queue_outbound, mut outrx) = ParallelQueue::new(num_workers, PARALLEL_QUEUE_SIZE);
let (queue_inbound, mut inrx) = ParallelQueue::new(num_workers, PARALLEL_QUEUE_SIZE);
let (work, mut consumers) = ParallelQueue::new(num_workers, PARALLEL_QUEUE_SIZE);
let device = Device {
inner: Arc::new(DeviceInner {
work,
inbound: tun,
queue_inbound,
outbound: RwLock::new((true, None)),
queue_outbound,
run_inbound: RunQueue::new(),
run_outbound: RunQueue::new(),
recv: RwLock::new(HashMap::new()),
table: RoutingTable::new(),
}),
@ -143,52 +123,10 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<
// start worker threads
let mut threads = Vec::with_capacity(num_workers);
// inbound/decryption workers
for _ in 0..num_workers {
// parallel workers (parallel processing)
{
let device = device.clone();
let rx = inrx.pop().unwrap();
threads.push(thread::spawn(move || {
log::debug!("inbound parallel router worker started");
inbound::parallel(device, rx)
}));
}
// sequential workers (in-order processing)
{
let device = device.clone();
threads.push(thread::spawn(move || {
log::debug!("inbound sequential router worker started");
inbound::sequential(device)
}));
}
while let Some(rx) = consumers.pop() {
threads.push(thread::spawn(move || worker(rx)));
}
// outbound/encryption workers
for _ in 0..num_workers {
// parallel workers (parallel processing)
{
let device = device.clone();
let rx = outrx.pop().unwrap();
threads.push(thread::spawn(move || {
log::debug!("outbound parallel router worker started");
outbound::parallel(device, rx)
}));
}
// sequential workers (in-order processing)
{
let device = device.clone();
threads.push(thread::spawn(move || {
log::debug!("outbound sequential router worker started");
outbound::sequential(device)
}));
}
}
debug_assert_eq!(threads.len(), num_workers * 4);
debug_assert_eq!(threads.len(), num_workers);
// return exported device handle
DeviceHandle {
@ -250,10 +188,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<
.ok_or(RouterError::NoCryptoKeyRoute)?;
// schedule for encryption and transmission to peer
if let Some(job) = peer.send_job(msg, true) {
self.state.queue_outbound.send(job);
}
peer.send(msg, true);
Ok(())
}
@ -297,10 +232,13 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<
.get(&header.f_receiver.get())
.ok_or(RouterError::UnknownReceiverId)?;
// schedule for decryption and TUN write
if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) {
log::trace!("schedule decryption of transport message");
self.state.queue_inbound.send(job);
// create inbound job
let job = ReceiveJob::new(msg, dec.clone(), src);
// 1. add to sequential queue (drop if full)
// 2. then add to parallel work queue (wait if full)
if dec.peer.inbound.push(job.clone()) {
self.state.work.send(JobUnion::Inbound(job));
}
Ok(())
}

190
src/wireguard/router/inbound.rs

@ -1,190 +0,0 @@
use std::mem;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use crossbeam_channel::Receiver;
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
use zerocopy::{AsBytes, LayoutVerified};
use super::constants::MAX_INORDER_CONSUME;
use super::device::DecryptionState;
use super::device::Device;
use super::messages::TransportHeader;
use super::peer::Peer;
use super::pool::*;
use super::types::Callbacks;
use super::{tun, udp, Endpoint};
use super::{REJECT_AFTER_MESSAGES, SIZE_TAG};
/// A single inbound (received) transport message carried through the
/// decryption pipeline: decrypted by the parallel stage, then delivered
/// in order by the sequential stage.
pub struct Inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
    // encrypted transport message (header + body + authentication tag)
    msg: Vec<u8>,
    // set by the parallel stage on authentication/routing failure;
    // the sequential stage drops such jobs early
    failed: bool,
    // decryption state for this receiver id (keys, replay protector, peer)
    state: Arc<DecryptionState<E, C, T, B>>,
    // source address of the datagram; taken by the sequential stage to
    // update the peer's endpoint
    endpoint: Option<E>,
}
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Inbound<E, C, T, B> {
    /// Wrap a received (still encrypted) buffer, its decryption state and
    /// the datagram's source endpoint into a job for the worker pipeline.
    /// The job starts out unfailed.
    pub fn new(
        msg: Vec<u8>,
        state: Arc<DecryptionState<E, C, T, B>>,
        endpoint: E,
    ) -> Self {
        let endpoint = Some(endpoint);
        Inbound {
            failed: false,
            endpoint,
            msg,
            state,
        }
    }
}
/// Parallel (CPU-bound) stage of the inbound pipeline.
///
/// Each worker pulls `Inbound` jobs from `receiver`, authenticates and
/// decrypts the payload in place, performs cryptokey routing, and — via
/// `worker_parallel` — inserts the peer into the device run queue so the
/// sequential stage can deliver completed jobs in order. Returns when the
/// channel is closed.
#[inline(always)]
pub fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
    device: Device<E, C, T, B>,
    receiver: Receiver<Job<Peer<E, C, T, B>, Inbound<E, C, T, B>>>,
) {
    // parallel work to apply
    #[inline(always)]
    fn work<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
        peer: &Peer<E, C, T, B>,
        body: &mut Inbound<E, C, T, B>,
    ) {
        log::trace!("worker, parallel section, obtained job");

        // cast to header followed by payload
        // NOTE(review): a parse failure returns without setting `failed`;
        // the sequential stage re-parses the header and drops the message there
        let (header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
            match LayoutVerified::new_from_prefix(&mut body.msg[..]) {
                Some(v) => v,
                None => {
                    log::debug!("inbound worker: failed to parse message");
                    return;
                }
            };

        // authenticate and decrypt payload
        {
            // create nonce object: counter bytes (as stored in the header)
            // occupy the final 8 bytes of the 12-byte ChaCha20-Poly1305 nonce
            let mut nonce = [0u8; 12];
            debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
            nonce[4..].copy_from_slice(header.f_counter.as_bytes());
            let nonce = Nonce::assume_unique_for_key(nonce);

            // do the weird ring AEAD dance
            let key = LessSafeKey::new(
                UnboundKey::new(&CHACHA20_POLY1305, &body.state.keypair.recv.key[..]).unwrap(),
            );

            // attempt to open (and authenticate) the body
            match key.open_in_place(nonce, Aad::empty(), packet) {
                Ok(_) => (),
                Err(_) => {
                    // fault and return early
                    log::trace!("inbound worker: authentication failure");
                    body.failed = true;
                    return;
                }
            }
        }

        // check that counter not after reject
        if header.f_counter.get() >= REJECT_AFTER_MESSAGES {
            body.failed = true;
            return;
        }

        // cryptokey route and strip padding;
        // a zero-length plaintext is a keepalive and is always accepted
        let inner_len = {
            let length = packet.len() - SIZE_TAG;
            if length > 0 {
                peer.device.table.check_route(&peer, &packet[..length])
            } else {
                Some(0)
            }
        };

        // truncate to remove tag (and any padding beyond the inner length)
        match inner_len {
            None => {
                log::trace!("inbound worker: cryptokey routing failed");
                body.failed = true;
            }
            Some(len) => {
                log::trace!(
                    "inbound worker: good route, length = {} {}",
                    len,
                    if len == 0 { "(keepalive)" } else { "" }
                );
                body.msg.truncate(mem::size_of::<TransportHeader>() + len);
            }
        }
    }

    worker_parallel(device, |dev| &dev.run_inbound, receiver, work)
}
/// Sequential (in-order) stage of the inbound pipeline.
///
/// Consumes completed jobs from each scheduled peer's inbound queue in
/// submission order: checks the replay protector, confirms the keypair on
/// first use, updates the peer endpoint, writes the plaintext to the TUN
/// device, and fires the `recv` callback.
#[inline(always)]
pub fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
    device: Device<E, C, T, B>,
) {
    // sequential work to apply
    fn work<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
        peer: &Peer<E, C, T, B>,
        body: &mut Inbound<E, C, T, B>,
    ) {
        log::trace!("worker, sequential section, obtained job");

        // decryption failed, return early
        if body.failed {
            log::trace!("job faulted, remove from queue and ignore");
            return;
        }

        // cast transport header (shared view; the message was decrypted in place)
        let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) =
            match LayoutVerified::new_from_prefix(&body.msg[..]) {
                Some(v) => v,
                None => {
                    log::debug!("inbound worker: failed to parse message");
                    return;
                }
            };

        // check for replay; must happen in order, hence the sequential stage
        if !body.state.protector.lock().update(header.f_counter.get()) {
            log::debug!("inbound worker: replay detected");
            return;
        }

        // check for confirms key: the first authenticated message on a
        // keypair confirms it (swap returns the previous value)
        if !body.state.confirmed.swap(true, Ordering::SeqCst) {
            log::debug!("inbound worker: message confirms key");
            peer.confirm_key(&body.state.keypair);
        }

        // update endpoint (roaming: latest authenticated source address wins)
        *peer.endpoint.lock() = body.endpoint.take();

        // check if should be written to TUN; empty payload = keepalive
        let mut sent = false;
        if packet.len() > 0 {
            sent = match peer.device.inbound.write(&packet[..]) {
                Err(e) => {
                    log::debug!("failed to write inbound packet to TUN: {:?}", e);
                    false
                }
                Ok(_) => true,
            }
        } else {
            log::debug!("inbound worker: received keepalive")
        }

        // trigger callback
        C::recv(&peer.opaque, body.msg.len(), sent, &body.state.keypair);
    }

    // handle message from the peers inbound queue
    device.run_inbound.run(|peer| {
        peer.inbound
            .handle(|body| work(&peer, body), MAX_INORDER_CONSUME)
    });
}

5
src/wireguard/router/mod.rs

@ -1,14 +1,10 @@
mod anti_replay;
mod constants;
mod device;
mod inbound;
mod ip;
mod messages;
mod outbound;
mod peer;
mod pool;
mod route;
mod runq;
mod types;
mod queue;
@ -25,7 +21,6 @@ use std::mem;
use super::constants::REJECT_AFTER_MESSAGES;
use super::queue::ParallelQueue;
use super::types::*;
use super::{tun, udp, Endpoint};
pub const SIZE_TAG: usize = 16;
pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::<TransportHeader>();

110
src/wireguard/router/outbound.rs

@ -1,110 +0,0 @@
use std::sync::Arc;
use crossbeam_channel::Receiver;
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
use zerocopy::{AsBytes, LayoutVerified};
use super::constants::MAX_INORDER_CONSUME;
use super::device::Device;
use super::messages::{TransportHeader, TYPE_TRANSPORT};
use super::peer::Peer;
use super::pool::*;
use super::types::Callbacks;
use super::KeyPair;
use super::{tun, udp, Endpoint};
use super::{REJECT_AFTER_MESSAGES, SIZE_TAG};
/// A single outbound transport message: the plaintext payload together
/// with the keypair and nonce counter assigned when it was enqueued.
pub struct Outbound {
    // padded plaintext with space for the transport header at the front
    msg: Vec<u8>,
    // keypair under which this message will be encrypted
    keypair: Arc<KeyPair>,
    // per-keypair counter; must be < REJECT_AFTER_MESSAGES (asserted in the worker)
    counter: u64,
}
impl Outbound {
    /// Bundle a plaintext message with the keypair and the nonce counter
    /// under which it will be encrypted.
    pub fn new(msg: Vec<u8>, keypair: Arc<KeyPair>, counter: u64) -> Outbound {
        Outbound { counter, keypair, msg }
    }
}
/// Parallel (CPU-bound) stage of the outbound pipeline.
///
/// Each worker pulls `Outbound` jobs from `receiver`, fills in the
/// transport header, encrypts the payload in place with the job's keypair
/// and counter, and appends the authentication tag; `worker_parallel` then
/// schedules the peer for the sequential (transmission) stage.
#[inline(always)]
pub fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
    device: Device<E, C, T, B>,
    receiver: Receiver<Job<Peer<E, C, T, B>, Outbound>>,
) {
    #[inline(always)]
    fn work<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
        _peer: &Peer<E, C, T, B>,
        body: &mut Outbound,
    ) {
        log::trace!("worker, parallel section, obtained job");

        // make space for the tag
        body.msg.extend([0u8; SIZE_TAG].iter());

        // cast to header (should never fail)
        let (mut header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
            LayoutVerified::new_from_prefix(&mut body.msg[..])
                .expect("earlier code should ensure that there is ample space");

        // set header fields
        debug_assert!(
            body.counter < REJECT_AFTER_MESSAGES,
            "should be checked when assigning counters"
        );
        header.f_type.set(TYPE_TRANSPORT);
        header.f_receiver.set(body.keypair.send.id);
        header.f_counter.set(body.counter);

        // create a nonce object: counter bytes (as stored in the header)
        // occupy the final 8 bytes of the 12-byte ChaCha20-Poly1305 nonce
        let mut nonce = [0u8; 12];
        debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
        nonce[4..].copy_from_slice(header.f_counter.as_bytes());
        let nonce = Nonce::assume_unique_for_key(nonce);

        // do the weird ring AEAD dance
        let key = LessSafeKey::new(
            UnboundKey::new(&CHACHA20_POLY1305, &body.keypair.send.key[..]).unwrap(),
        );

        // encrypt content of transport message in-place
        let end = packet.len() - SIZE_TAG;
        let tag = key
            .seal_in_place_separate_tag(nonce, Aad::empty(), &mut packet[..end])
            .unwrap();

        // append tag in the space reserved above
        packet[end..].copy_from_slice(tag.as_ref());
    }

    worker_parallel(device, |dev| &dev.run_outbound, receiver, work);
}
/// Sequential (in-order) stage of the outbound pipeline.
///
/// Consumes completed (encrypted) jobs from each scheduled peer's outbound
/// queue in submission order, transmits them to the peer's endpoint, and
/// fires the `send` callback with the transmission result.
#[inline(always)]
pub fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
    device: Device<E, C, T, B>,
) {
    device.run_outbound.run(|peer| {
        peer.outbound.handle(
            |body| {
                log::trace!("worker, sequential section, obtained job");

                // send to peer (best effort; xmit records whether it succeeded)
                let xmit = peer.send(&body.msg[..]).is_ok();

                // trigger callback
                C::send(
                    &peer.opaque,
                    body.msg.len(),
                    xmit,
                    &body.keypair,
                    body.counter,
                );
            },
            MAX_INORDER_CONSUME,
        )
    });
}

192
src/wireguard/router/peer.rs

@ -1,13 +1,3 @@
use std::mem;
use std::net::{IpAddr, SocketAddr};
use std::ops::Deref;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use arraydeque::{ArrayDeque, Wrapping};
use log::debug;
use spin::Mutex;
use super::super::constants::*;
use super::super::{tun, udp, Endpoint, KeyPair};
@ -15,20 +5,25 @@ use super::anti_replay::AntiReplay;
use super::device::DecryptionState;
use super::device::Device;
use super::device::EncryptionState;
use super::messages::TransportHeader;
use super::constants::*;
use super::runq::ToKey;
use super::types::{Callbacks, RouterError};
use super::SIZE_MESSAGE_PREFIX;
// worker pool related
use super::inbound::Inbound;
use super::outbound::Outbound;
use super::queue::Queue;
use super::send::SendJob;
use super::receive::ReceiveJob;
use super::send::SendJob;
use super::worker::JobUnion;
use std::mem;
use std::net::{IpAddr, SocketAddr};
use std::ops::Deref;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use arraydeque::{ArrayDeque, Wrapping};
use log::debug;
use spin::Mutex;
pub struct KeyWheel {
next: Option<Arc<KeyPair>>, // next key state (unconfirmed)
@ -44,7 +39,7 @@ pub struct PeerInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E
pub inbound: Queue<ReceiveJob<E, C, T, B>>,
pub staged_packets: Mutex<ArrayDeque<[Vec<u8>; MAX_QUEUED_PACKETS], Wrapping>>,
pub keys: Mutex<KeyWheel>,
pub ekey: Mutex<Option<EncryptionState>>,
pub enc_key: Mutex<Option<EncryptionState>>,
pub endpoint: Mutex<Option<E>>,
}
@ -69,13 +64,6 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PartialEq for
}
}
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ToKey for Peer<E, C, T, B> {
type Key = usize;
fn to_key(&self) -> usize {
Arc::downgrade(&self.inner).into_raw() as usize
}
}
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Eq for Peer<E, C, T, B> {}
/* A peer is transparently dereferenced to the inner type
@ -157,7 +145,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for Peer
keys.current = None;
keys.previous = None;
*peer.ekey.lock() = None;
*peer.enc_key.lock() = None;
*peer.endpoint.lock() = None;
debug!("peer dropped & removed from device");
@ -175,9 +163,9 @@ pub fn new_peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
inner: Arc::new(PeerInner {
opaque,
device,
inbound: InorderQueue::new(),
outbound: InorderQueue::new(),
ekey: spin::Mutex::new(None),
inbound: Queue::new(),
outbound: Queue::new(),
enc_key: spin::Mutex::new(None),
endpoint: spin::Mutex::new(None),
keys: spin::Mutex::new(KeyWheel {
next: None,
@ -203,7 +191,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E,
/// # Returns
///
/// Unit if packet was sent, or an error indicating why sending failed
pub fn send(&self, msg: &[u8]) -> Result<(), RouterError> {
pub fn send_raw(&self, msg: &[u8]) -> Result<(), RouterError> {
debug!("peer.send");
// send to endpoint (if known)
@ -226,6 +214,57 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E,
}
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, B> {
    /// Encrypt and send a message to the peer
    ///
    /// Arguments:
    ///
    /// - `msg` : A padded vector holding the message (allows in-place construction of the transport header)
    /// - `stage`: Should the message be staged if no key is available
    ///
    pub(super) fn send(&self, msg: Vec<u8>, stage: bool) {
        // check if key available; the enc_key lock is held while the job is
        // created and pushed so nonce assignment stays consistent with the
        // in-order queue
        let (job, need_key) = {
            let mut enc_key = self.enc_key.lock();
            match enc_key.as_mut() {
                None => {
                    // no key: optionally stage the message for later
                    if stage {
                        self.staged_packets.lock().push_back(msg);
                    };
                    (None, true)
                }
                Some(mut state) => {
                    // avoid integer overflow in nonce: retire the key before
                    // the counter reaches REJECT_AFTER_MESSAGES
                    if state.nonce >= REJECT_AFTER_MESSAGES - 1 {
                        *enc_key = None;
                        if stage {
                            self.staged_packets.lock().push_back(msg);
                        }
                        (None, true)
                    } else {
                        debug!("encryption state available, nonce = {}", state.nonce);
                        let job =
                            SendJob::new(msg, state.nonce, state.keypair.clone(), self.clone());
                        // the nonce is consumed only if the in-order queue
                        // accepted the job; if the queue is full the job is
                        // dropped and the counter is left untouched
                        if self.outbound.push(job.clone()) {
                            state.nonce += 1;
                            (Some(job), false)
                        } else {
                            (None, false)
                        }
                    }
                }
            }
        };
        // request a fresh key from the upper layer (outside the enc_key lock)
        if need_key {
            debug_assert!(job.is_none());
            C::need_key(&self.opaque);
        };
        // hand the job to the parallel worker pool for encryption
        if let Some(job) = job {
            self.device.work.send(JobUnion::Outbound(job))
        }
    }
// Transmit all staged packets
fn send_staged(&self) -> bool {
debug!("peer.send_staged");
@ -235,28 +274,14 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T,
match staged.pop_front() {
Some(msg) => {
sent = true;
self.send_raw(msg);
self.send(msg, false);
}
None => break sent,
}
}
}
// Treat the msg as the payload of a transport message
// Unlike device.send, peer.send_raw does not buffer messages when a key is not available.
fn send_raw(&self, msg: Vec<u8>) -> bool {
log::debug!("peer.send_raw");
match self.send_job(msg, false) {
Some(job) => {
self.device.queue_outbound.send(job);
debug!("send_raw: got obtained send_job");
true
}
None => false,
}
}
pub fn confirm_key(&self, keypair: &Arc<KeyPair>) {
pub(super) fn confirm_key(&self, keypair: &Arc<KeyPair>) {
debug!("peer.confirm_key");
{
// take lock and check keypair = keys.next
@ -284,68 +309,12 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T,
C::key_confirmed(&self.opaque);
// set new key for encryption
*self.ekey.lock() = ekey;
*self.enc_key.lock() = ekey;
}
// start transmission of staged packets
self.send_staged();
}
pub fn send_job(&self, msg: Vec<u8>, stage: bool) -> Option<SendJob<E, C, T, B>> {
debug!("peer.send_job");
debug_assert!(
msg.len() >= mem::size_of::<TransportHeader>(),
"received TUN message with size: {:}",
msg.len()
);
// check if has key
let (keypair, counter) = {
let keypair = {
// TODO: consider using atomic ptr for ekey state
let mut ekey = self.ekey.lock();
match ekey.as_mut() {
None => None,
Some(mut state) => {
// avoid integer overflow in nonce
if state.nonce >= REJECT_AFTER_MESSAGES - 1 {
*ekey = None;
None
} else {
debug!("encryption state available, nonce = {}", state.nonce);
let counter = state.nonce;
state.nonce += 1;
SendJob::new(
msg,
state.nonce,
state.keypair.clone(),
self.clone()
);
Some((state.keypair.clone(), counter))
}
}
}
};
// If not suitable key was found:
// 1. Stage packet for later transmission
// 2. Request new key
if keypair.is_none() && stage {
self.staged_packets.lock().push_back(msg);
C::need_key(&self.opaque);
return None;
};
keypair
}?;
// add job to in-order queue and return sender to device for inclusion in worker pool
let job = Job::new(self.clone(), Outbound::new(msg, keypair, counter));
self.outbound.send(job.clone());
Some(job)
}
}
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E, C, T, B> {
@ -397,7 +366,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E,
}
// clear encryption state
*self.peer.ekey.lock() = None;
*self.peer.enc_key.lock() = None;
}
pub fn down(&self) {
@ -434,7 +403,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E,
// update key-wheel
if new.initiator {
// start using key for encryption
*self.peer.ekey.lock() = Some(EncryptionState::new(&new));
*self.peer.enc_key.lock() = Some(EncryptionState::new(&new));
// move current into previous
keys.previous = keys.current.as_ref().map(|v| v.clone());
@ -468,16 +437,13 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E,
// schedule confirmation
if initiator {
debug_assert!(self.peer.ekey.lock().is_some());
debug_assert!(self.peer.enc_key.lock().is_some());
debug!("peer.add_keypair: is initiator, must confirm the key");
// attempt to confirm using staged packets
if !self.peer.send_staged() {
// fall back to keepalive packet
let ok = self.send_keepalive();
debug!(
"peer.add_keypair: keepalive for confirmation, sent = {}",
ok
);
self.send_keepalive();
debug!("peer.add_keypair: keepalive for confirmation",);
}
debug!("peer.add_keypair: key attempted confirmed");
}
@ -489,9 +455,9 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E,
release
}
pub fn send_keepalive(&self) -> bool {
pub fn send_keepalive(&self) {
debug!("peer.send_keepalive");
self.peer.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX])
self.peer.send(vec![0u8; SIZE_MESSAGE_PREFIX], false)
}
/// Map a subnet to the peer

164
src/wireguard/router/pool.rs

@ -1,164 +0,0 @@
use std::mem;
use std::sync::Arc;
use arraydeque::ArrayDeque;
use crossbeam_channel::Receiver;
use spin::{Mutex, MutexGuard};
use super::constants::INORDER_QUEUE_SIZE;
use super::runq::{RunQueue, ToKey};
/// Shared state of a job: the owning peer and the message body.
///
/// The `peer` field doubles as the completion flag: it is `Some` until the
/// parallel worker takes it (see `worker_parallel`), after which the job
/// counts as complete (see `Job::complete`).
pub struct InnerJob<P, B> {
    // peer (used by worker to schedule/handle inorder queue),
    // when the peer is None, the job is complete
    peer: Option<P>,
    pub body: B,
}
/// Cheaply clonable handle to an `InnerJob`, shared between the parallel
/// worker that processes it and the peer's in-order queue that delivers it.
pub struct Job<P, B> {
    inner: Arc<Mutex<InnerJob<P, B>>>,
}
impl<P, B> Clone for Job<P, B> {
    /// Clone the handle; only the reference count is bumped,
    /// the underlying job state remains shared.
    fn clone(&self) -> Job<P, B> {
        let inner = self.inner.clone();
        Job { inner }
    }
}
impl<P, B> Job<P, B> {
    /// Create a new in-flight job owned by `peer` carrying the message `body`.
    /// The job is incomplete (`peer` is `Some`) until a worker takes the peer.
    pub fn new(peer: P, body: B) -> Job<P, B> {
        let inner = InnerJob {
            peer: Some(peer),
            body,
        };
        Job {
            inner: Arc::new(Mutex::new(inner)),
        }
    }
}
impl<P, B> Job<P, B> {
    /// Returns a mutex guard to the inner job if complete
    ///
    /// A job is complete once the parallel stage has taken the peer out
    /// (`peer` is `None`). `try_lock` is used so that a job still being
    /// processed simply reports "not complete" instead of blocking.
    pub fn complete(&self) -> Option<MutexGuard<InnerJob<P, B>>> {
        match self.inner.try_lock() {
            Some(guard) if guard.peer.is_none() => Some(guard),
            _ => None,
        }
    }
}
/// Bounded FIFO of jobs for one peer, preserving submission order so that
/// messages processed in parallel are still delivered sequentially.
pub struct InorderQueue<P, B> {
    queue: Mutex<ArrayDeque<[Job<P, B>; INORDER_QUEUE_SIZE]>>,
}
impl<P, B> InorderQueue<P, B> {
    /// Create an empty in-order queue.
    pub fn new() -> InorderQueue<P, B> {
        InorderQueue {
            queue: Mutex::new(ArrayDeque::new()),
        }
    }

    /// Add a new job to the in-order queue
    ///
    /// # Arguments
    ///
    /// - `job`: The job added to the back of the queue
    ///
    /// # Returns
    ///
    /// True if the element was added,
    /// false to indicate that the queue is full.
    pub fn send(&self, job: Job<P, B>) -> bool {
        self.queue.lock().push_back(job).is_ok()
    }

    /// Consume completed jobs from the in-order queue
    ///
    /// # Arguments
    ///
    /// - `f`: function to apply to the body of each job.
    /// - `limit`: maximum number of jobs to handle before returning
    ///
    /// # Returns
    ///
    /// A boolean indicating if the limit was reached:
    /// true indicating that the limit was reached,
    /// while false implies that the queue is empty or an uncompleted job was reached.
    #[inline(always)]
    pub fn handle<F: Fn(&mut B)>(&self, f: F, mut limit: usize) -> bool {
        // take the mutex
        let mut queue = self.queue.lock();
        while limit > 0 {
            // attempt to extract front element
            let front = queue.pop_front();
            let elem = match front {
                Some(elem) => elem,
                _ => {
                    // queue empty
                    return false;
                }
            };

            // apply function if job complete;
            // the queue lock is released while running `f` so producers are
            // not blocked during the (possibly slow) sequential work
            let ret = if let Some(mut guard) = elem.complete() {
                mem::drop(queue);
                f(&mut guard.body);
                queue = self.queue.lock();
                false
            } else {
                true
            };

            // job not complete yet, return job to front
            // (in-order delivery: never skip past an unfinished job)
            if ret {
                queue.push_front(elem).unwrap();
                return false;
            }
            limit -= 1;
        }

        // did not complete all jobs
        true
    }
}
/// Allows easy construction of a parallel worker.
/// Applicable for both decryption and encryption workers.
///
/// Each iteration receives one job from the channel, performs the parallel
/// `work` on its body while holding the job's lock, marks the job complete
/// by taking the peer out, and then schedules the peer on the run queue so
/// a sequential worker can deliver it in order. Returns when the channel
/// is closed (recv fails).
#[inline(always)]
pub fn worker_parallel<
    P: ToKey, // represents a peer (atomic reference counted pointer)
    B,        // inner body type (message buffer, key material, ...)
    D,        // device
    W: Fn(&P, &mut B),
    Q: Fn(&D) -> &RunQueue<P>,
>(
    device: D,
    queue: Q,
    receiver: Receiver<Job<P, B>>,
    work: W,
) {
    log::trace!("router worker started");
    loop {
        // handle new job
        let peer = {
            // get next job
            let job = match receiver.recv() {
                Ok(job) => job,
                _ => return,
            };

            // lock the job
            let mut job = job.inner.lock();

            // take the peer from the job;
            // leaving `peer = None` marks the job as complete once the
            // lock is dropped (see `Job::complete`)
            let peer = job.peer.take().unwrap();

            // process job
            work(&peer, &mut job.body);
            peer
        };

        // process inorder jobs for peer
        queue(&device).insert(peer);
    }
}

92
src/wireguard/router/queue.rs

@ -4,29 +4,36 @@ use spin::Mutex;
use std::mem;
use std::sync::atomic::{AtomicUsize, Ordering};
const QUEUE_SIZE: usize = 1024;
pub trait Job: Sized {
fn queue(&self) -> &Queue<Self>;
use super::constants::INORDER_QUEUE_SIZE;
pub trait SequentialJob {
fn is_ready(&self) -> bool;
fn parallel_work(&self);
fn sequential_work(self);
}
pub trait ParallelJob: Sized + SequentialJob {
fn queue(&self) -> &Queue<Self>;
fn parallel_work(&self);
}
pub struct Queue<J: Job> {
pub struct Queue<J: SequentialJob> {
contenders: AtomicUsize,
queue: Mutex<ArrayDeque<[J; QUEUE_SIZE]>>,
queue: Mutex<ArrayDeque<[J; INORDER_QUEUE_SIZE]>>,
#[cfg(debug)]
_flag: Mutex<()>,
}
impl<J: Job> Queue<J> {
impl<J: SequentialJob> Queue<J> {
pub fn new() -> Queue<J> {
Queue {
contenders: AtomicUsize::new(0),
queue: Mutex::new(ArrayDeque::new()),
#[cfg(debug)]
_flag: Mutex::new(()),
}
}
@ -36,14 +43,22 @@ impl<J: Job> Queue<J> {
pub fn consume(&self) {
// check if we are the first contender
let pos = self.contenders.fetch_add(1, Ordering::Acquire);
let pos = self.contenders.fetch_add(1, Ordering::SeqCst);
if pos > 0 {
assert!(pos < usize::max_value(), "contenders overflow");
assert!(usize::max_value() > pos, "contenders overflow");
return;
}
// enter the critical section
let mut contenders = 1; // myself
while contenders > 0 {
// check soundness in debug builds
#[cfg(debug)]
let _flag = self
._flag
.try_lock()
.expect("contenders should ensure mutual exclusion");
// handle every ready element
loop {
let mut queue = self.queue.lock();
@ -69,8 +84,61 @@ impl<J: Job> Queue<J> {
job.sequential_work();
}
#[cfg(debug)]
mem::drop(_flag);
// decrease contenders
contenders = self.contenders.fetch_sub(contenders, Ordering::Acquire) - contenders;
contenders = self.contenders.fetch_sub(contenders, Ordering::SeqCst) - contenders;
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    use rand::thread_rng;
    use rand::Rng;

    // minimal job: always ready and does nothing when executed,
    // so the test exercises only the queue's scheduling logic
    struct TestJob {}

    impl SequentialJob for TestJob {
        fn is_ready(&self) -> bool {
            true
        }

        fn sequential_work(self) {}
    }

    /* Fuzz the Queue */
    #[test]
    fn test_queue() {
        // randomly interleave pushes and consumes from a single thread
        fn hammer(queue: &Arc<Queue<TestJob>>) {
            let mut rng = thread_rng();
            for _ in 0..1_000_000 {
                if rng.gen() {
                    queue.push(TestJob {});
                } else {
                    queue.consume();
                }
            }
        }

        let queue = Arc::new(Queue::new());

        // repeatedly apply operations randomly from concurrent threads
        let other = {
            let queue = queue.clone();
            thread::spawn(move || hammer(&queue))
        };
        hammer(&queue);

        // wait, consume and check empty
        other.join().unwrap();
        queue.consume();
        assert_eq!(queue.queue.lock().len(), 0);
    }
}

184
src/wireguard/router/receive.rs

@ -1,21 +1,18 @@
use super::queue::{Job, Queue};
use super::KeyPair;
use super::device::DecryptionState;
use super::messages::TransportHeader;
use super::queue::{ParallelJob, Queue, SequentialJob};
use super::types::Callbacks;
use super::peer::Peer;
use super::{REJECT_AFTER_MESSAGES, SIZE_TAG};
use super::messages::{TransportHeader, TYPE_TRANSPORT};
use super::device::DecryptionState;
use super::super::{tun, udp, Endpoint};
use std::mem;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::mem;
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
use zerocopy::{AsBytes, LayoutVerified};
use spin::Mutex;
use zerocopy::{AsBytes, LayoutVerified};
struct Inner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
ready: AtomicBool,
@ -23,49 +20,49 @@ struct Inner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
state: Arc<DecryptionState<E, C, T, B>>, // decryption state (keys and replay protector)
}
pub struct ReceiveJob<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
inner: Arc<Inner<E, C, T, B>>,
pub struct ReceiveJob<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
Arc<Inner<E, C, T, B>>,
);
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Clone
for ReceiveJob<E, C, T, B>
{
fn clone(&self) -> ReceiveJob<E, C, T, B> {
ReceiveJob(self.0.clone())
}
}
impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ReceiveJob<E, C, T, B> {
fn new(buffer: Vec<u8>, state: Arc<DecryptionState<E, C, T, B>>, endpoint: E) -> Option<ReceiveJob<E, C, T, B>> {
// create job
let inner = Arc::new(Inner{
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ReceiveJob<E, C, T, B> {
pub fn new(
buffer: Vec<u8>,
state: Arc<DecryptionState<E, C, T, B>>,
endpoint: E,
) -> ReceiveJob<E, C, T, B> {
ReceiveJob(Arc::new(Inner {
ready: AtomicBool::new(false),
buffer: Mutex::new((Some(endpoint), buffer)),
state
});
// attempt to add to queue
if state.peer.inbound.push(ReceiveJob{ inner: inner.clone()}) {
Some(ReceiveJob{inner})
} else {
None
}
state,
}))
}
}
impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Job for ReceiveJob<E, C, T, B> {
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ParallelJob
for ReceiveJob<E, C, T, B>
{
fn queue(&self) -> &Queue<Self> {
&self.inner.state.peer.inbound
}
fn is_ready(&self) -> bool {
self.inner.ready.load(Ordering::Acquire)
&self.0.state.peer.inbound
}
fn parallel_work(&self) {
// TODO: refactor
// decrypt
{
let job = &self.inner;
let job = &self.0;
let peer = &job.state.peer;
let mut msg = job.buffer.lock();
let failed = || {
// cast to header followed by payload
let (header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
// cast to header followed by payload
let (header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
match LayoutVerified::new_from_prefix(&mut msg.1[..]) {
Some(v) => v,
None => {
@ -74,73 +71,81 @@ impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Job for Rece
}
};
// authenticate and decrypt payload
{
// create nonce object
let mut nonce = [0u8; 12];
debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
nonce[4..].copy_from_slice(header.f_counter.as_bytes());
let nonce = Nonce::assume_unique_for_key(nonce);
// do the weird ring AEAD dance
let key = LessSafeKey::new(
UnboundKey::new(&CHACHA20_POLY1305, &job.state.keypair.recv.key[..]).unwrap(),
);
// attempt to open (and authenticate) the body
match key.open_in_place(nonce, Aad::empty(), packet) {
Ok(_) => (),
Err(_) => {
// fault and return early
log::trace!("inbound worker: authentication failure");
msg.1.truncate(0);
return;
}
// authenticate and decrypt payload
{
// create nonce object
let mut nonce = [0u8; 12];
debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
nonce[4..].copy_from_slice(header.f_counter.as_bytes());
let nonce = Nonce::assume_unique_for_key(nonce);
// do the weird ring AEAD dance
let key = LessSafeKey::new(
UnboundKey::new(&CHACHA20_POLY1305, &job.state.keypair.recv.key[..]).unwrap(),
);
// attempt to open (and authenticate) the body
match key.open_in_place(nonce, Aad::empty(), packet) {
Ok(_) => (),
Err(_) => {
// fault and return early
log::trace!("inbound worker: authentication failure");
msg.1.truncate(0);
return;
}
}
}
// check that counter not after reject
if header.f_counter.get() >= REJECT_AFTER_MESSAGES {
msg.1.truncate(0);
return;
}
// cryptokey route and strip padding
let inner_len = {
let length = packet.len() - SIZE_TAG;
if length > 0 {
peer.device.table.check_route(&peer, &packet[..length])
} else {
Some(0)
}
};
// check that counter not after reject
if header.f_counter.get() >= REJECT_AFTER_MESSAGES {
msg.1.truncate(0);
return;
}
// truncate to remove tag
match inner_len {
None => {
log::trace!("inbound worker: cryptokey routing failed");
msg.1.truncate(0);
}
Some(len) => {
log::trace!(
"inbound worker: good route, length = {} {}",
len,
if len == 0 { "(keepalive)" } else { "" }
);
msg.1.truncate(mem::size_of::<TransportHeader>() + len);
}
// cryptokey route and strip padding
let inner_len = {
let length = packet.len() - SIZE_TAG;
if length > 0 {
peer.device.table.check_route(&peer, &packet[..length])
} else {
Some(0)
}
};
// truncate to remove tag
match inner_len {
None => {
log::trace!("inbound worker: cryptokey routing failed");
msg.1.truncate(0);
}
Some(len) => {
log::trace!(
"inbound worker: good route, length = {} {}",
len,
if len == 0 { "(keepalive)" } else { "" }
);
msg.1.truncate(mem::size_of::<TransportHeader>() + len);
}
}
}
// mark ready
self.inner.ready.store(true, Ordering::Release);
self.0.ready.store(true, Ordering::Release);
}
}
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> SequentialJob
for ReceiveJob<E, C, T, B>
{
fn is_ready(&self) -> bool {
self.0.ready.load(Ordering::Acquire)
}
fn sequential_work(self) {
let job = &self.inner;
let job = &self.0;
let peer = &job.state.peer;
let mut msg = job.buffer.lock();
let endpoint = msg.0.take();
// cast transport header
let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) =
@ -165,7 +170,7 @@ impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Job for Rece
}
// update endpoint
*peer.endpoint.lock() = msg.0.take();
*peer.endpoint.lock() = endpoint;
// check if should be written to TUN
let mut sent = false;
@ -184,5 +189,4 @@ impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Job for Rece
// trigger callback
C::recv(&peer.opaque, msg.1.len(), sent, &job.state.keypair);
}
}

164
src/wireguard/router/runq.rs

@ -1,164 +0,0 @@
use std::hash::Hash;
use std::mem;
use std::sync::{Condvar, Mutex};
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::collections::VecDeque;
/// Maps a queue element to a hashable key.
///
/// Elements sharing a key are coalesced by `RunQueue::insert`: inserting an
/// element whose key is already queued only bumps a pending-work counter
/// instead of enqueueing a duplicate.
pub trait ToKey {
    type Key: Hash + Eq;
    fn to_key(&self) -> Self::Key;
}
/// A de-duplicating run queue: at most one queued element per distinct key
/// (see `insert`), with a condition variable used to block consumers in
/// `run` until work arrives or the queue is closed.
pub struct RunQueue<T: ToKey> {
    cvar: Condvar,
    inner: Mutex<Inner<T>>,
}
struct Inner<T: ToKey> {
    stop: bool,                      // set by close(); wakes and terminates all runners
    queue: VecDeque<T>,              // elements awaiting processing, FIFO order
    members: HashMap<T::Key, usize>, // per-key counter of inserts since the element was queued
}
impl<T: ToKey> RunQueue<T> {
    /// Close the queue: all current and future `run` calls observe `stop`
    /// and return.
    pub fn close(&self) {
        let mut inner = self.inner.lock().unwrap();
        inner.stop = true;
        self.cvar.notify_all();
    }

    /// Create a new, empty, open run queue.
    pub fn new() -> RunQueue<T> {
        RunQueue {
            cvar: Condvar::new(),
            inner: Mutex::new(Inner {
                stop: false,
                queue: VecDeque::new(),
                members: HashMap::new(),
            }),
        }
    }

    /// Schedule `v` for processing.
    ///
    /// If an element with the same key is already queued, only its pending
    /// counter is incremented (the element itself is not enqueued again);
    /// otherwise `v` is appended to the queue and one waiting worker is woken.
    pub fn insert(&self, v: T) {
        let key = v.to_key();
        let mut inner = self.inner.lock().unwrap();
        match inner.members.entry(key) {
            Entry::Occupied(mut elem) => {
                *elem.get_mut() += 1;
            }
            Entry::Vacant(spot) => {
                // add entry to back of queue
                spot.insert(0);
                inner.queue.push_back(v);

                // wake a thread
                self.cvar.notify_one();
            }
        }
    }

    /// Run (consume from) the run queue using the provided function.
    /// The function should return whether the given element should be rescheduled.
    ///
    /// # Arguments
    ///
    /// - `f` : function to apply to every element
    ///
    /// # Note
    ///
    /// The function f may be called again even when the element was not inserted back in to the
    /// queue since the last application and no rescheduling was requested.
    ///
    /// This happens when the function handles all work for T,
    /// but T is added to the run queue while the function is running.
    pub fn run<F: Fn(&T) -> bool>(&self, f: F) {
        let mut inner = self.inner.lock().unwrap();
        loop {
            // fetch next element
            let elem = loop {
                // run-queue closed
                if inner.stop {
                    return;
                }

                // try to pop from queue
                match inner.queue.pop_front() {
                    Some(elem) => {
                        break elem;
                    }
                    None => (),
                };

                // wait for an element to be inserted
                inner = self.cvar.wait(inner).unwrap();
            };

            // fetch current request number (snapshot of the pending counter,
            // compared after running `f` to detect concurrent inserts)
            let key = elem.to_key();
            let old_n = *inner.members.get(&key).unwrap();
            mem::drop(inner); // drop guard: `f` runs without holding the lock

            // handle element
            let rerun = f(&elem);

            // if the function requested a re-run add the element to the back of the queue
            inner = self.inner.lock().unwrap();
            if rerun {
                inner.queue.push_back(elem);
                continue;
            }

            // otherwise check if new requests have come in since we ran the function
            match inner.members.entry(key) {
                Entry::Occupied(occ) => {
                    if *occ.get() == old_n {
                        // no new requests since last, remove entry.
                        occ.remove();
                    } else {
                        // new requests, reschedule.
                        inner.queue.push_back(elem);
                    }
                }
                Entry::Vacant(_) => {
                    unreachable!();
                }
            }
        }
    }
}
// NOTE(review): the test below is disabled; its closures do not return the
// reschedule bool required by `run<F: Fn(&T) -> bool>`, nothing is ever
// inserted into the queue, and the spawned threads are never joined.
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /*
    #[test]
    fn test_wait() {
        let queue: Arc<RunQueue<usize>> = Arc::new(RunQueue::new());

        {
            let queue = queue.clone();
            thread::spawn(move || {
                queue.run(|e| {
                    println!("t0 {}", e);
                    thread::sleep(Duration::from_millis(100));
                })
            });
        }

        {
            let queue = queue.clone();
            thread::spawn(move || {
                queue.run(|e| {
                    println!("t1 {}", e);
                    thread::sleep(Duration::from_millis(100));
                })
            });
        }
    }
    */
}

95
src/wireguard/router/send.rs

@ -1,4 +1,4 @@
use super::queue::{Job, Queue};
use super::queue::{SequentialJob, ParallelJob, Queue};
use super::KeyPair;
use super::types::Callbacks;
use super::peer::Peer;
@ -22,8 +22,14 @@ struct Inner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
peer: Peer<E, C, T, B>,
}
pub struct SendJob<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
inner: Arc<Inner<E, C, T, B>>,
pub struct SendJob<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> (
Arc<Inner<E, C, T, B>>
);
impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Clone for SendJob<E, C, T, B> {
fn clone(&self) -> SendJob<E, C, T, B> {
SendJob(self.0.clone())
}
}
impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> SendJob<E, C, T, B> {
@ -32,32 +38,53 @@ impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> SendJob<E, C
counter: u64,
keypair: Arc<KeyPair>,
peer: Peer<E, C, T, B>
) -> Option<SendJob<E, C, T, B>> {
// create job
let inner = Arc::new(Inner{
) -> SendJob<E, C, T, B> {
SendJob(Arc::new(Inner{
buffer: Mutex::new(buffer),
counter,
keypair,
peer,
ready: AtomicBool::new(false)
});
// attempt to add to queue
if peer.outbound.push(SendJob{ inner: inner.clone()}) {
Some(SendJob{inner})
} else {
None
}
}))
}
}
impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Job for SendJob<E, C, T, B> {
fn queue(&self) -> &Queue<Self> {
&self.inner.peer.outbound
impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> SequentialJob for SendJob<E, C, T, B> {
fn is_ready(&self) -> bool {
self.0.ready.load(Ordering::Acquire)
}
fn is_ready(&self) -> bool {
self.inner.ready.load(Ordering::Acquire)
fn sequential_work(self) {
debug_assert_eq!(
self.is_ready(),
true,
"doing sequential work
on an incomplete job"
);
log::trace!("processing sequential send job");
// send to peer
let job = &self.0;
let msg = job.buffer.lock();
let xmit = job.peer.send_raw(&msg[..]).is_ok();
// trigger callback (for timers)
C::send(
&job.peer.opaque,
msg.len(),
xmit,
&job.keypair,
job.counter,
);
}
}
impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ParallelJob for SendJob<E, C, T, B> {
fn queue(&self) -> &Queue<Self> {
&self.0.peer.outbound
}
fn parallel_work(&self) {
@ -71,7 +98,7 @@ impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Job for Send
// encrypt body
{
// make space for the tag
let job = &*self.inner;
let job = &*self.0;
let mut msg = job.buffer.lock();
msg.extend([0u8; SIZE_TAG].iter());
@ -111,30 +138,6 @@ impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Job for Send
}
// mark ready
self.inner.ready.store(true, Ordering::Release);
}
fn sequential_work(self) {
debug_assert_eq!(
self.is_ready(),
true,
"doing sequential work
on an incomplete job"
);
log::trace!("processing sequential send job");
// send to peer
let job = &self.inner;
let msg = job.buffer.lock();
let xmit = job.peer.send(&msg[..]).is_ok();
// trigger callback (for timers)
C::send(
&job.peer.opaque,
msg.len(),
xmit,
&job.keypair,
job.counter,
);
self.0.ready.store(true, Ordering::Release);
}
}
}

30
src/wireguard/router/worker.rs

@ -1,13 +1,31 @@
use super::Device;
use super::super::{tun, udp, Endpoint};
use super::types::Callbacks;
use super::receive::ReceieveJob;
use super::queue::ParallelJob;
use super::receive::ReceiveJob;
use super::send::SendJob;
fn worker<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
device: Device<E, C, T, B>,
use crossbeam_channel::Receiver;
/// Tagged union over the two parallel job types, so a single pool of worker
/// threads can service both outbound (encryption) and inbound (decryption)
/// work.
pub enum JobUnion<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
    Outbound(SendJob<E, C, T, B>),
    Inbound(ReceiveJob<E, C, T, B>),
}
/// Body of a parallel worker thread.
///
/// Drains jobs from the shared channel until every sender has been dropped.
/// For each job it performs the parallelizable (crypto) part, then asks the
/// owning peer's queue to consume any jobs that are now ready in order.
pub fn worker<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
    receiver: Receiver<JobUnion<E, C, T, B>>,
) {
    // recv() only fails once the channel is closed, which ends the thread
    while let Ok(union) = receiver.recv() {
        match union {
            JobUnion::Inbound(job) => {
                job.parallel_work();
                job.queue().consume();
            }
            JobUnion::Outbound(job) => {
                job.parallel_work();
                job.queue().consume();
            }
        }
    }
}

257
src/wireguard/router/workers.rs

@ -1,257 +0,0 @@
use std::sync::Arc;
use log::{debug, trace};
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
use crossbeam_channel::Receiver;
use std::sync::atomic::Ordering;
use zerocopy::{AsBytes, LayoutVerified};
use super::device::{DecryptionState, DeviceInner};
use super::messages::{TransportHeader, TYPE_TRANSPORT};
use super::peer::PeerInner;
use super::types::Callbacks;
use super::REJECT_AFTER_MESSAGES;
use super::super::types::KeyPair;
use super::super::{tun, udp, Endpoint};
/// Size in bytes of the AEAD authentication tag (ChaCha20-Poly1305)
/// appended to every transport message body.
pub const SIZE_TAG: usize = 16;
/// An outbound plaintext packet together with the keypair and nonce counter
/// to use when sealing it.
pub struct JobEncryption {
    pub msg: Vec<u8>,
    pub keypair: Arc<KeyPair>,
    pub counter: u64,
}
/// An inbound ciphertext packet together with the keypair to open it with.
pub struct JobDecryption {
    pub msg: Vec<u8>,
    pub keypair: Arc<KeyPair>,
}
/// Work item for the parallel crypto workers; the oneshot sender hands the
/// finished job back to the sequential stage (decryption yields `None` on
/// authentication failure).
pub enum JobParallel {
    Encryption(oneshot::Sender<JobEncryption>, JobEncryption),
    Decryption(oneshot::Sender<Option<JobDecryption>>, JobDecryption),
}
// Inbound job as seen by the sequential worker: the decryption state and
// source endpoint, plus a receiver completed by the parallel stage with the
// decrypted body (or None on authentication failure).
#[allow(type_alias_bounds)]
pub type JobInbound<E, C, T, B: udp::Writer<E>> = (
    Arc<DecryptionState<E, C, T, B>>,
    E,
    oneshot::Receiver<Option<JobDecryption>>,
);

// Outbound job as seen by the sequential worker: receives the sealed packet.
pub type JobOutbound = oneshot::Receiver<JobEncryption>;
/* TODO: Replace with run-queue
 */

/// Sequential inbound worker for a single peer.
///
/// Pulls jobs in submission order; each job carries a oneshot receiver that
/// the parallel decryption worker completes. This stage then performs the
/// order-sensitive steps: replay protection, key confirmation, endpoint
/// update, cryptokey-route check and writing the plaintext to the TUN device.
pub fn worker_inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
    device: Arc<DeviceInner<E, C, T, B>>, // related device
    peer: Arc<PeerInner<E, C, T, B>>,     // related peer
    receiver: Receiver<JobInbound<E, C, T, B>>,
) {
    loop {
        // fetch job (channel closed => worker exits)
        let (state, endpoint, rx) = match receiver.recv() {
            Ok(v) => v,
            _ => {
                return;
            }
        };
        debug!("inbound worker: obtained job");

        // wait for job to complete (futures 0.1: map + blocking wait)
        let _ = rx
            .map(|buf| {
                debug!("inbound worker: job complete");
                if let Some(buf) = buf {
                    // cast transport header
                    let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) =
                        match LayoutVerified::new_from_prefix(&buf.msg[..]) {
                            Some(v) => v,
                            None => {
                                debug!("inbound worker: failed to parse message");
                                return;
                            }
                        };

                    debug_assert!(
                        packet.len() >= CHACHA20_POLY1305.tag_len(),
                        "this should be checked earlier in the pipeline (decryption should fail)"
                    );

                    // check for replay
                    if !state.protector.lock().update(header.f_counter.get()) {
                        debug!("inbound worker: replay detected");
                        return;
                    }

                    // check for confirms key (first authenticated packet on the keypair)
                    if !state.confirmed.swap(true, Ordering::SeqCst) {
                        debug!("inbound worker: message confirms key");
                        peer.confirm_key(&state.keypair);
                    }

                    // update endpoint (roaming: reply to the last seen address)
                    *peer.endpoint.lock() = Some(endpoint);

                    // calculate length of IP packet + padding
                    let length = packet.len() - SIZE_TAG;
                    debug!("inbound worker: plaintext length = {}", length);

                    // check if should be written to TUN (length == 0 is a keepalive)
                    let mut sent = false;
                    if length > 0 {
                        if let Some(inner_len) = device.table.check_route(&peer, &packet[..length])
                        {
                            // TODO: Consider moving the cryptkey route check to parallel decryption worker
                            debug_assert!(inner_len <= length, "should be validated earlier");
                            if inner_len <= length {
                                sent = match device.inbound.write(&packet[..inner_len]) {
                                    Err(e) => {
                                        debug!("failed to write inbound packet to TUN: {:?}", e);
                                        false
                                    }
                                    Ok(_) => true,
                                }
                            }
                        }
                    } else {
                        debug!("inbound worker: received keepalive")
                    }

                    // trigger callback (drives timer state)
                    C::recv(&peer.opaque, buf.msg.len(), sent, &buf.keypair);
                } else {
                    debug!("inbound worker: authentication failure")
                }
            })
            .wait();
    }
}
/// Sequential outbound worker for a single peer.
///
/// Pulls per-packet oneshot receivers in submission order, waits for the
/// parallel encryption worker to seal each packet, sends it to the peer and
/// fires the send callback (used by the timer machinery).
pub fn worker_outbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
    peer: Arc<PeerInner<E, C, T, B>>,
    receiver: Receiver<JobOutbound>,
) {
    loop {
        // fetch job (channel closed => worker exits)
        let rx = match receiver.recv() {
            Ok(v) => v,
            _ => {
                return;
            }
        };
        debug!("outbound worker: obtained job");

        // wait for job to complete (futures 0.1: map + blocking wait)
        let _ = rx
            .map(|buf| {
                debug!("outbound worker: job complete");

                // send to peer
                let xmit = peer.send(&buf.msg[..]).is_ok();

                // trigger callback
                C::send(&peer.opaque, buf.msg.len(), xmit, &buf.keypair, buf.counter);
            })
            .wait();
    }
}
/// Parallel crypto worker: seals outbound packets and opens inbound ones.
///
/// Runs the CPU-bound AEAD work off the sequential path; results are handed
/// back to the in-order workers through each job's oneshot sender.
pub fn worker_parallel(receiver: Receiver<JobParallel>) {
    loop {
        // fetch next job (channel closed => worker exits)
        let job = match receiver.recv() {
            Err(_) => {
                return;
            }
            Ok(val) => val,
        };
        trace!("parallel worker: obtained job");

        // handle job
        match job {
            JobParallel::Encryption(tx, mut job) => {
                // reserve space for the authentication tag
                job.msg.extend([0u8; SIZE_TAG].iter());

                // cast to header (should never fail)
                let (mut header, body): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
                    LayoutVerified::new_from_prefix(&mut job.msg[..])
                        .expect("earlier code should ensure that there is ample space");

                // set header fields
                debug_assert!(
                    job.counter < REJECT_AFTER_MESSAGES,
                    "should be checked when assigning counters"
                );
                header.f_type.set(TYPE_TRANSPORT);
                header.f_receiver.set(job.keypair.send.id);
                header.f_counter.set(job.counter);

                // create a nonce object (little-endian counter in bytes 4..12)
                let mut nonce = [0u8; 12];
                debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
                nonce[4..].copy_from_slice(header.f_counter.as_bytes());
                let nonce = Nonce::assume_unique_for_key(nonce);

                // do the weird ring AEAD dance
                let key = LessSafeKey::new(
                    UnboundKey::new(&CHACHA20_POLY1305, &job.keypair.send.key[..]).unwrap(),
                );

                // encrypt content of transport message in-place
                let end = body.len() - SIZE_TAG;
                let tag = key
                    .seal_in_place_separate_tag(nonce, Aad::empty(), &mut body[..end])
                    .unwrap();

                // append tag
                body[end..].copy_from_slice(tag.as_ref());

                // pass ownership back to the sequential stage
                let _ = tx.send(job);
            }
            JobParallel::Decryption(tx, mut job) => {
                // cast to header (could fail)
                let layout: Option<(LayoutVerified<&mut [u8], TransportHeader>, &mut [u8])> =
                    LayoutVerified::new_from_prefix(&mut job.msg[..]);

                // send None on any failure (parse, stale counter, bad tag)
                let _ = tx.send(match layout {
                    Some((header, body)) => {
                        debug_assert_eq!(
                            header.f_type.get(),
                            TYPE_TRANSPORT,
                            "type and reserved bits should be checked by message de-multiplexer"
                        );
                        if header.f_counter.get() < REJECT_AFTER_MESSAGES {
                            // create a nonce object
                            let mut nonce = [0u8; 12];
                            debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
                            nonce[4..].copy_from_slice(header.f_counter.as_bytes());
                            let nonce = Nonce::assume_unique_for_key(nonce);

                            // do the weird ring AEAD dance
                            let key = LessSafeKey::new(
                                UnboundKey::new(&CHACHA20_POLY1305, &job.keypair.recv.key[..])
                                    .unwrap(),
                            );

                            // attempt to open (and authenticate) the body
                            match key.open_in_place(nonce, Aad::empty(), body) {
                                Ok(_) => Some(job),
                                Err(_) => None,
                            }
                        } else {
                            None
                        }
                    }
                    None => None,
                });
            }
        }
    }
}

2
src/wireguard/wireguard.rs

@ -603,7 +603,7 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> {
);
let device = wg.handshake.read();
let _ = device.begin(&mut rng, &peer.pk).map(|msg| {
let _ = peer.router.send(&msg[..]).map_err(|e| {
let _ = peer.router.send_raw(&msg[..]).map_err(|e| {
debug!("{} : handshake worker, failed to send handshake initiation, error = {}", wg, e)
});
peer.state.sent_handshake_initiation();

Loading…
Cancel
Save