mirror of https://git.zx2c4.com/wireguard-rs
6 changed files with 395 additions and 185 deletions
@ -0,0 +1,228 @@
|
||||
use std::collections::HashMap; |
||||
|
||||
use std::net::{Ipv4Addr, Ipv6Addr}; |
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; |
||||
use std::sync::mpsc::sync_channel; |
||||
use std::sync::mpsc::SyncSender; |
||||
use std::sync::Arc; |
||||
use std::thread; |
||||
use std::time::Instant; |
||||
|
||||
use log::debug; |
||||
use spin::{Mutex, RwLock}; |
||||
use treebitmap::IpLookupTable; |
||||
use zerocopy::LayoutVerified; |
||||
|
||||
use super::anti_replay::AntiReplay; |
||||
use super::constants::*; |
||||
|
||||
use super::messages::{TransportHeader, TYPE_TRANSPORT}; |
||||
use super::peer::{new_peer, Peer, PeerInner}; |
||||
use super::types::{Callbacks, RouterError}; |
||||
use super::workers::{worker_parallel, JobParallel}; |
||||
use super::SIZE_MESSAGE_PREFIX; |
||||
|
||||
use super::route::get_route; |
||||
|
||||
use super::super::{bind, tun, Endpoint, KeyPair}; |
||||
|
||||
pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { |
||||
// inbound writer (TUN)
|
||||
pub inbound: T, |
||||
|
||||
// outbound writer (Bind)
|
||||
pub outbound: RwLock<(bool, Option<B>)>, |
||||
|
||||
// routing
|
||||
pub recv: RwLock<HashMap<u32, Arc<DecryptionState<E, C, T, B>>>>, // receiver id -> decryption state
|
||||
pub ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<PeerInner<E, C, T, B>>>>, // ipv4 cryptkey routing
|
||||
pub ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<PeerInner<E, C, T, B>>>>, // ipv6 cryptkey routing
|
||||
|
||||
// work queues
|
||||
pub queue_next: AtomicUsize, // next round-robin index
|
||||
pub queues: Mutex<Vec<SyncSender<JobParallel>>>, // work queues (1 per thread)
|
||||
} |
||||
|
||||
pub struct EncryptionState { |
||||
pub keypair: Arc<KeyPair>, // keypair
|
||||
pub nonce: u64, // next available nonce
|
||||
pub death: Instant, // (birth + reject-after-time - keepalive-timeout - rekey-timeout)
|
||||
} |
||||
|
||||
pub struct DecryptionState<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { |
||||
pub keypair: Arc<KeyPair>, |
||||
pub confirmed: AtomicBool, |
||||
pub protector: Mutex<AntiReplay>, |
||||
pub peer: Arc<PeerInner<E, C, T, B>>, |
||||
pub death: Instant, // time when the key can no longer be used for decryption
|
||||
} |
||||
|
||||
pub struct Device<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { |
||||
state: Arc<DeviceInner<E, C, T, B>>, // reference to device state
|
||||
handles: Vec<thread::JoinHandle<()>>, // join handles for workers
|
||||
} |
||||
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Device<E, C, T, B> { |
||||
fn drop(&mut self) { |
||||
debug!("router: dropping device"); |
||||
|
||||
// drop all queues
|
||||
{ |
||||
let mut queues = self.state.queues.lock(); |
||||
while queues.pop().is_some() {} |
||||
} |
||||
|
||||
// join all worker threads
|
||||
while match self.handles.pop() { |
||||
Some(handle) => { |
||||
handle.thread().unpark(); |
||||
handle.join().unwrap(); |
||||
true |
||||
} |
||||
_ => false, |
||||
} {} |
||||
|
||||
debug!("router: device dropped"); |
||||
} |
||||
} |
||||
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C, T, B> { |
||||
pub fn new(num_workers: usize, tun: T) -> Device<E, C, T, B> { |
||||
// allocate shared device state
|
||||
let inner = DeviceInner { |
||||
inbound: tun, |
||||
outbound: RwLock::new((true, None)), |
||||
queues: Mutex::new(Vec::with_capacity(num_workers)), |
||||
queue_next: AtomicUsize::new(0), |
||||
recv: RwLock::new(HashMap::new()), |
||||
ipv4: RwLock::new(IpLookupTable::new()), |
||||
ipv6: RwLock::new(IpLookupTable::new()), |
||||
}; |
||||
|
||||
// start worker threads
|
||||
let mut threads = Vec::with_capacity(num_workers); |
||||
for _ in 0..num_workers { |
||||
let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE); |
||||
inner.queues.lock().push(tx); |
||||
threads.push(thread::spawn(move || worker_parallel(rx))); |
||||
} |
||||
|
||||
// return exported device handle
|
||||
Device { |
||||
state: Arc::new(inner), |
||||
handles: threads, |
||||
} |
||||
} |
||||
|
||||
/// Brings the router down.
|
||||
/// When the router is brought down it:
|
||||
/// - Prevents transmission of outbound messages.
|
||||
pub fn down(&self) { |
||||
self.state.outbound.write().0 = false; |
||||
} |
||||
|
||||
/// Brints the router up
|
||||
/// When the router is brought up it enables the transmission of outbound messages.
|
||||
pub fn up(&self) { |
||||
self.state.outbound.write().0 = true; |
||||
} |
||||
|
||||
/// A new secret key has been set for the device.
|
||||
/// According to WireGuard semantics, this should cause all "sending" keys to be discarded.
|
||||
pub fn new_sk(&self) {} |
||||
|
||||
/// Adds a new peer to the device
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A atomic ref. counted peer (with liftime matching the device)
|
||||
pub fn new_peer(&self, opaque: C::Opaque) -> Peer<E, C, T, B> { |
||||
new_peer(self.state.clone(), opaque) |
||||
} |
||||
|
||||
/// Cryptkey routes and sends a plaintext message (IP packet)
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - msg: IP packet to crypt-key route
|
||||
///
|
||||
pub fn send(&self, msg: Vec<u8>) -> Result<(), RouterError> { |
||||
debug_assert!(msg.len() > SIZE_MESSAGE_PREFIX); |
||||
log::trace!( |
||||
"Router, outbound packet = {}", |
||||
hex::encode(&msg[SIZE_MESSAGE_PREFIX..]) |
||||
); |
||||
|
||||
// ignore header prefix (for in-place transport message construction)
|
||||
let packet = &msg[SIZE_MESSAGE_PREFIX..]; |
||||
|
||||
// lookup peer based on IP packet destination address
|
||||
let peer = get_route(&self.state, packet).ok_or(RouterError::NoCryptoKeyRoute)?; |
||||
|
||||
// schedule for encryption and transmission to peer
|
||||
if let Some(job) = peer.send_job(msg, true) { |
||||
// add job to worker queue
|
||||
let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst); |
||||
let queues = self.state.queues.lock(); |
||||
queues[idx % queues.len()].send(job).unwrap(); |
||||
} |
||||
|
||||
Ok(()) |
||||
} |
||||
|
||||
/// Receive an encrypted transport message
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - src: Source address of the packet
|
||||
/// - msg: Encrypted transport message
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
///
|
||||
pub fn recv(&self, src: E, msg: Vec<u8>) -> Result<(), RouterError> { |
||||
// parse / cast
|
||||
let (header, _) = match LayoutVerified::new_from_prefix(&msg[..]) { |
||||
Some(v) => v, |
||||
None => { |
||||
return Err(RouterError::MalformedTransportMessage); |
||||
} |
||||
}; |
||||
|
||||
let header: LayoutVerified<&[u8], TransportHeader> = header; |
||||
|
||||
debug_assert!( |
||||
header.f_type.get() == TYPE_TRANSPORT as u32, |
||||
"this should be checked by the message type multiplexer" |
||||
); |
||||
|
||||
log::trace!( |
||||
"Router, handle transport message: (receiver = {}, counter = {})", |
||||
header.f_receiver, |
||||
header.f_counter |
||||
); |
||||
|
||||
// lookup peer based on receiver id
|
||||
let dec = self.state.recv.read(); |
||||
let dec = dec |
||||
.get(&header.f_receiver.get()) |
||||
.ok_or(RouterError::UnknownReceiverId)?; |
||||
|
||||
// schedule for decryption and TUN write
|
||||
if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) { |
||||
// add job to worker queue
|
||||
let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst); |
||||
let queues = self.state.queues.lock(); |
||||
queues[idx % queues.len()].send(job).unwrap(); |
||||
} |
||||
|
||||
Ok(()) |
||||
} |
||||
|
||||
/// Set outbound writer
|
||||
///
|
||||
///
|
||||
pub fn set_outbound_writer(&self, new: B) { |
||||
self.state.outbound.write().1 = Some(new); |
||||
} |
||||
} |
||||
@ -1,113 +1,163 @@
|
||||
use super::super::{bind, tun, Endpoint}; |
||||
use super::device::DeviceInner; |
||||
use super::ip::*; |
||||
use super::peer::PeerInner; |
||||
use super::types::Callbacks; |
||||
|
||||
use log::trace; |
||||
use zerocopy::LayoutVerified; |
||||
|
||||
use std::mem; |
||||
use std::net::{Ipv4Addr, Ipv6Addr}; |
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; |
||||
use std::sync::Arc; |
||||
|
||||
#[inline(always)] |
||||
pub fn get_route<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( |
||||
device: &Arc<DeviceInner<E, C, T, B>>, |
||||
packet: &[u8], |
||||
) -> Option<Arc<PeerInner<E, C, T, B>>> { |
||||
match packet.get(0)? >> 4 { |
||||
VERSION_IP4 => { |
||||
// check length and cast to IPv4 header
|
||||
let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) = |
||||
LayoutVerified::new_from_prefix(packet)?; |
||||
|
||||
log::trace!( |
||||
"Router, get route for IPv4 destination: {:?}", |
||||
Ipv4Addr::from(header.f_destination) |
||||
); |
||||
|
||||
// check IPv4 source address
|
||||
device |
||||
.ipv4 |
||||
.read() |
||||
.longest_match(Ipv4Addr::from(header.f_destination)) |
||||
.and_then(|(_, _, p)| Some(p.clone())) |
||||
use spin::RwLock; |
||||
use treebitmap::address::Address; |
||||
use treebitmap::IpLookupTable; |
||||
|
||||
/* Functions for obtaining and validating "cryptokey" routes */ |
||||
|
||||
pub struct RoutingTable<T> { |
||||
ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<T>>>, |
||||
ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<T>>>, |
||||
} |
||||
|
||||
impl<T> RoutingTable<T> { |
||||
pub fn new() -> Self { |
||||
RoutingTable { |
||||
ipv4: RwLock::new(IpLookupTable::new()), |
||||
ipv6: RwLock::new(IpLookupTable::new()), |
||||
} |
||||
VERSION_IP6 => { |
||||
// check length and cast to IPv6 header
|
||||
let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) = |
||||
LayoutVerified::new_from_prefix(packet)?; |
||||
|
||||
log::trace!( |
||||
"Router, get route for IPv6 destination: {:?}", |
||||
Ipv6Addr::from(header.f_destination) |
||||
); |
||||
|
||||
// check IPv6 source address
|
||||
device |
||||
.ipv6 |
||||
.read() |
||||
.longest_match(Ipv6Addr::from(header.f_destination)) |
||||
.and_then(|(_, _, p)| Some(p.clone())) |
||||
} |
||||
|
||||
fn collect<A>(table: &IpLookupTable<A, Arc<T>>, value: &Arc<T>) -> Vec<(A, u32)> |
||||
where |
||||
A: Address, |
||||
{ |
||||
let mut res = Vec::new(); |
||||
for (ip, cidr, v) in table.iter() { |
||||
if Arc::ptr_eq(v, value) { |
||||
res.push((ip, cidr)) |
||||
} |
||||
} |
||||
_ => None, |
||||
res |
||||
} |
||||
} |
||||
|
||||
#[inline(always)] |
||||
pub fn check_route<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( |
||||
device: &Arc<DeviceInner<E, C, T, B>>, |
||||
peer: &Arc<PeerInner<E, C, T, B>>, |
||||
packet: &[u8], |
||||
) -> Option<usize> { |
||||
match packet.get(0)? >> 4 { |
||||
VERSION_IP4 => { |
||||
// check length and cast to IPv4 header
|
||||
let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) = |
||||
LayoutVerified::new_from_prefix(packet)?; |
||||
|
||||
log::trace!( |
||||
"Router, check route for IPv4 source: {:?}", |
||||
Ipv4Addr::from(header.f_source) |
||||
); |
||||
|
||||
// check IPv4 source address
|
||||
device |
||||
.ipv4 |
||||
.read() |
||||
.longest_match(Ipv4Addr::from(header.f_source)) |
||||
.and_then(|(_, _, p)| { |
||||
if Arc::ptr_eq(p, peer) { |
||||
Some(header.f_total_len.get() as usize) |
||||
} else { |
||||
None |
||||
} |
||||
}) |
||||
pub fn list(&self, value: &Arc<T>) -> Vec<(IpAddr, u32)> { |
||||
let mut res = vec![]; |
||||
res.extend( |
||||
Self::collect(&*self.ipv4.read(), value) |
||||
.into_iter() |
||||
.map(|(ip, cidr)| (IpAddr::V4(ip), cidr)), |
||||
); |
||||
res.extend( |
||||
Self::collect(&*self.ipv6.read(), value) |
||||
.into_iter() |
||||
.map(|(ip, cidr)| (IpAddr::V6(ip), cidr)), |
||||
); |
||||
res |
||||
} |
||||
|
||||
pub fn remove(&self, value: &Arc<T>) { |
||||
let mut v4 = self.ipv4.write(); |
||||
let mut v6 = self.ipv6.write(); |
||||
for (ip, cidr) in Self::collect(&*v4, value) { |
||||
v4.remove(ip, cidr); |
||||
} |
||||
VERSION_IP6 => { |
||||
// check length and cast to IPv6 header
|
||||
let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) = |
||||
LayoutVerified::new_from_prefix(packet)?; |
||||
|
||||
log::trace!( |
||||
"Router, check route for IPv6 source: {:?}", |
||||
Ipv6Addr::from(header.f_source) |
||||
); |
||||
|
||||
// check IPv6 source address
|
||||
device |
||||
.ipv6 |
||||
.read() |
||||
.longest_match(Ipv6Addr::from(header.f_source)) |
||||
.and_then(|(_, _, p)| { |
||||
if Arc::ptr_eq(p, peer) { |
||||
Some(header.f_len.get() as usize + mem::size_of::<IPv6Header>()) |
||||
} else { |
||||
None |
||||
} |
||||
}) |
||||
for (ip, cidr) in Self::collect(&*v6, value) { |
||||
v6.remove(ip, cidr); |
||||
} |
||||
_ => None, |
||||
} |
||||
|
||||
#[inline(always)] |
||||
pub fn get_route(&self, packet: &[u8]) -> Option<Arc<T>> { |
||||
match packet.get(0)? >> 4 { |
||||
VERSION_IP4 => { |
||||
// check length and cast to IPv4 header
|
||||
let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) = |
||||
LayoutVerified::new_from_prefix(packet)?; |
||||
|
||||
log::trace!( |
||||
"Router, get route for IPv4 destination: {:?}", |
||||
Ipv4Addr::from(header.f_destination) |
||||
); |
||||
|
||||
// check IPv4 source address
|
||||
self.ipv4 |
||||
.read() |
||||
.longest_match(Ipv4Addr::from(header.f_destination)) |
||||
.and_then(|(_, _, p)| Some(p.clone())) |
||||
} |
||||
VERSION_IP6 => { |
||||
// check length and cast to IPv6 header
|
||||
let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) = |
||||
LayoutVerified::new_from_prefix(packet)?; |
||||
|
||||
log::trace!( |
||||
"Router, get route for IPv6 destination: {:?}", |
||||
Ipv6Addr::from(header.f_destination) |
||||
); |
||||
|
||||
// check IPv6 source address
|
||||
self.ipv6 |
||||
.read() |
||||
.longest_match(Ipv6Addr::from(header.f_destination)) |
||||
.and_then(|(_, _, p)| Some(p.clone())) |
||||
} |
||||
_ => None, |
||||
} |
||||
} |
||||
|
||||
#[inline(always)] |
||||
pub fn check_route(&self, peer: &Arc<T>, packet: &[u8]) -> Option<usize> { |
||||
match packet.get(0)? >> 4 { |
||||
VERSION_IP4 => { |
||||
// check length and cast to IPv4 header
|
||||
let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) = |
||||
LayoutVerified::new_from_prefix(packet)?; |
||||
|
||||
log::trace!( |
||||
"Router, check route for IPv4 source: {:?}", |
||||
Ipv4Addr::from(header.f_source) |
||||
); |
||||
|
||||
// check IPv4 source address
|
||||
self.ipv4 |
||||
.read() |
||||
.longest_match(Ipv4Addr::from(header.f_source)) |
||||
.and_then(|(_, _, p)| { |
||||
if Arc::ptr_eq(p, peer) { |
||||
Some(header.f_total_len.get() as usize) |
||||
} else { |
||||
None |
||||
} |
||||
}) |
||||
} |
||||
VERSION_IP6 => { |
||||
// check length and cast to IPv6 header
|
||||
let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) = |
||||
LayoutVerified::new_from_prefix(packet)?; |
||||
|
||||
log::trace!( |
||||
"Router, check route for IPv6 source: {:?}", |
||||
Ipv6Addr::from(header.f_source) |
||||
); |
||||
|
||||
// check IPv6 source address
|
||||
self.ipv6 |
||||
.read() |
||||
.longest_match(Ipv6Addr::from(header.f_source)) |
||||
.and_then(|(_, _, p)| { |
||||
if Arc::ptr_eq(p, peer) { |
||||
Some(header.f_len.get() as usize + mem::size_of::<IPv6Header>()) |
||||
} else { |
||||
None |
||||
} |
||||
}) |
||||
} |
||||
_ => None, |
||||
} |
||||
} |
||||
|
||||
pub fn insert(&self, ip: IpAddr, cidr: u32, value: Arc<T>) { |
||||
match ip { |
||||
IpAddr::V4(v4) => self.ipv4.write().insert(v4.mask(cidr), cidr, value), |
||||
IpAddr::V6(v6) => self.ipv6.write().insert(v6.mask(cidr), cidr, value), |
||||
}; |
||||
} |
||||
} |
||||
|
||||
Loading…
Reference in new issue