Browse Source

Initial version of sticky sockets for Linux

sticky-sockets
Mathias Hall-Andersen 6 years ago
parent
commit
1e26a0bef4
  1. 192
      src/platform/linux/udp.rs

192
src/platform/linux/udp.rs

@ -21,16 +21,16 @@ fn errno() -> libc::c_int {
}
}
#[repr(C)]
#[repr(C, align(1))]
struct ControlHeaderV4 {
hdr: libc::cmsghdr,
info: libc::in_pktinfo,
}
#[repr(C)]
#[repr(C, align(1))]
struct ControlHeaderV6 {
hdr: libc::cmsghdr,
body: libc::in6_pktinfo,
info: libc::in6_pktinfo,
}
pub struct EndpointV4 {
@ -159,6 +159,7 @@ fn setsockopt<V: Sized>(
}
}
#[inline(always)]
fn setsockopt_int(
fd: RawFd,
level: libc::c_int,
@ -168,6 +169,21 @@ fn setsockopt_int(
setsockopt(fd, level, name, &value)
}
#[allow(non_snake_case)]
const fn CMSG_ALIGN(len: usize) -> usize {
(((len) + mem::size_of::<u32>() - 1) & !(mem::size_of::<u32>() - 1))
}
#[allow(non_snake_case)]
const fn CMSG_LEN(len: usize) -> usize {
CMSG_ALIGN(len + mem::size_of::<libc::cmsghdr>())
}
#[inline(always)]
fn safe_cast<T, D>(v: &mut T) -> *mut D {
(v as *mut T) as *mut D
}
impl LinuxUDPReader {
fn read6(fd: RawFd, buf: &mut [u8]) -> Result<(usize, LinuxEndpoint), io::Error> {
log::trace!(
@ -176,43 +192,41 @@ impl LinuxUDPReader {
buf.len()
);
// this memory is mutated by the recvmsg call
#[allow(unused_mut)]
let mut control: ControlHeaderV6 = unsafe { mem::MaybeUninit::uninit().assume_init() };
let iovs: [libc::iovec; 1] = [libc::iovec {
let mut iovs: [libc::iovec; 1] = [libc::iovec {
iov_base: buf.as_mut_ptr() as *mut core::ffi::c_void,
iov_len: buf.len(),
}];
let src: libc::sockaddr_in6 = unsafe { mem::MaybeUninit::uninit().assume_init() };
let mut hdr = unsafe {
libc::msghdr {
msg_name: mem::transmute(&src),
msg_namelen: mem::size_of_val(&src).try_into().unwrap(),
msg_iov: mem::transmute(&iovs[0]),
msg_iovlen: iovs.len(),
msg_control: mem::transmute(&control),
msg_controllen: mem::size_of_val(&control),
msg_flags: 0, // ignored
}
let mut src: libc::sockaddr_in6 = unsafe { mem::MaybeUninit::uninit().assume_init() };
let mut control: ControlHeaderV6 = unsafe { mem::MaybeUninit::uninit().assume_init() };
let mut hdr = libc::msghdr {
msg_name: safe_cast(&mut src),
msg_namelen: mem::size_of::<libc::sockaddr_in6> as u32,
msg_iov: iovs.as_mut_ptr(),
msg_iovlen: iovs.len(),
msg_control: safe_cast(&mut control),
msg_controllen: mem::size_of::<ControlHeaderV6>(),
msg_flags: 0,
};
debug_assert!(
hdr.msg_controllen
>= mem::size_of::<libc::cmsghdr>() + mem::size_of::<libc::in6_pktinfo>(),
);
let len = unsafe { libc::recvmsg(fd, &mut hdr as *mut libc::msghdr, 0) };
if len < 0 {
log::trace!("failed to receive IPv6 packet (errno = {})", errno());
return Err(io::Error::new(
io::ErrorKind::NotConnected,
"failed to receive",
));
}
log::trace!("received IPv6 packet ({} fd, {} bytes)", fd, len);
Ok((
len.try_into().unwrap(),
LinuxEndpoint::V6(EndpointV6 {
info: control.body,
dst: src,
info: control.info, // save pktinfo (sticky source)
dst: src, // our future destination is the source address
}),
))
}
@ -224,52 +238,40 @@ impl LinuxUDPReader {
buf.len()
);
let iovs: [libc::iovec; 1] = [libc::iovec {
let mut iovs: [libc::iovec; 1] = [libc::iovec {
iov_base: buf.as_mut_ptr() as *mut core::ffi::c_void,
iov_len: buf.len(),
}];
let src: libc::sockaddr_in = unsafe { mem::MaybeUninit::uninit().assume_init() };
// this memory is mutated by the recvmsg call
#[allow(unused_mut)]
let mut src: libc::sockaddr_in = unsafe { mem::MaybeUninit::uninit().assume_init() };
let mut control: ControlHeaderV4 = unsafe { mem::MaybeUninit::uninit().assume_init() };
let mut hdr = unsafe {
libc::msghdr {
msg_name: mem::transmute(&src),
msg_namelen: mem::size_of_val(&src).try_into().unwrap(), // constant
msg_iov: mem::transmute(&iovs[0]),
msg_iovlen: iovs.len(), // constant
msg_control: mem::transmute(&control),
msg_controllen: mem::size_of_val(&control), // constant
msg_flags: 0, // ignored
}
let mut hdr = libc::msghdr {
msg_name: safe_cast(&mut src),
msg_namelen: mem::size_of::<libc::sockaddr_in> as u32,
msg_iov: iovs.as_mut_ptr(),
msg_iovlen: iovs.len(),
msg_control: safe_cast(&mut control),
msg_controllen: mem::size_of::<ControlHeaderV4>(),
msg_flags: 0,
};
debug_assert!(
hdr.msg_controllen
>= mem::size_of::<libc::cmsghdr>() + mem::size_of::<libc::in_pktinfo>(),
);
let len = unsafe { libc::recvmsg(fd, &mut hdr as *mut libc::msghdr, 0) };
if len < 0 {
log::trace!("failed to receive IPv4 packet (errno = {})", errno());
return Err(io::Error::new(
io::ErrorKind::NotConnected,
"failed to receive",
));
}
log::trace!("read4, len: {}", len);
log::trace!(
"control: {{ hdr : {{ cmsg_level: {}, cmsg_type: {}, cmsg_len: {} }} }}",
control.hdr.cmsg_level,
control.hdr.cmsg_type,
control.hdr.cmsg_len
);
log::trace!("received IPv4 packet ({} fd, {} bytes)", fd, len);
Ok((
len.try_into().unwrap(),
LinuxEndpoint::V4(EndpointV4 {
info: control.info, // save pkinfo (sticky source)
info: control.info, // save pktinfo (sticky source)
dst: src, // our future destination is the source address
}),
))
@ -288,23 +290,82 @@ impl Reader<LinuxEndpoint> for LinuxUDPReader {
}
impl LinuxUDPWriter {
fn write6(fd: RawFd, buf: &[u8], dst: &EndpointV6) -> Result<(), io::Error> {
fn write6(fd: RawFd, buf: &[u8], dst: &mut EndpointV6) -> Result<(), io::Error> {
log::debug!("sending IPv6 packet ({} fd, {} bytes)", fd, buf.len());
unimplemented!()
let mut iovs: [libc::iovec; 1] = [libc::iovec {
iov_base: buf.as_ptr() as *mut core::ffi::c_void,
iov_len: buf.len(),
}];
let mut control = ControlHeaderV6 {
hdr: libc::cmsghdr {
cmsg_len: CMSG_LEN(mem::size_of::<libc::in6_pktinfo>()),
cmsg_level: libc::IPPROTO_IPV6,
cmsg_type: libc::IPV6_PKTINFO,
},
info: dst.info,
};
debug_assert_eq!(
control.hdr.cmsg_len % mem::size_of::<u32>(),
0,
"cmsg_len must be aligned to a long"
);
debug_assert_eq!(
dst.dst.sin6_family,
libc::AF_INET6 as libc::sa_family_t,
"this method only handles IPv6 destinations"
);
let mut hdr = libc::msghdr {
msg_name: safe_cast(&mut dst.dst),
msg_namelen: mem::size_of_val(&dst.dst).try_into().unwrap(),
msg_iov: iovs.as_mut_ptr(),
msg_iovlen: iovs.len(),
msg_control: safe_cast(&mut control),
msg_controllen: mem::size_of_val(&control),
msg_flags: 0,
};
let ret = unsafe { libc::sendmsg(fd, &hdr, 0) };
if ret < 0 {
if errno() == libc::EINVAL {
log::trace!("clear source and retry");
hdr.msg_control = ptr::null_mut();
hdr.msg_controllen = 0;
dst.info = unsafe { mem::zeroed() };
if unsafe { libc::sendmsg(fd, &hdr, 0) } < 0 {
return Err(io::Error::new(
io::ErrorKind::NotConnected,
"failed to send IPv6 packet",
));
} else {
return Ok(());
}
}
return Err(io::Error::new(
io::ErrorKind::NotConnected,
"failed to send IPv6 packet",
));
}
Ok(())
}
fn write4(fd: RawFd, buf: &[u8], dst: &mut EndpointV4) -> Result<(), io::Error> {
log::debug!("sending IPv4 packet ({} fd, {} bytes)", fd, buf.len());
let iovs: [libc::iovec; 1] = [libc::iovec {
let mut iovs: [libc::iovec; 1] = [libc::iovec {
iov_base: buf.as_ptr() as *mut core::ffi::c_void,
iov_len: buf.len(),
}];
let mut control = ControlHeaderV4 {
hdr: libc::cmsghdr {
cmsg_len: mem::size_of::<ControlHeaderV4>(),
cmsg_len: CMSG_LEN(mem::size_of::<libc::in_pktinfo>()),
cmsg_level: libc::IPPROTO_IP,
cmsg_type: libc::IP_PKTINFO,
},
@ -312,18 +373,23 @@ impl LinuxUDPWriter {
};
debug_assert_eq!(
control.hdr.cmsg_len % mem::size_of::<usize>(),
control.hdr.cmsg_len % mem::size_of::<u32>(),
0,
"cmsg_len must be aligned to a word"
"cmsg_len must be aligned to a long"
);
debug_assert_eq!(
dst.dst.sin_family,
libc::AF_INET as libc::sa_family_t,
"this method only handles IPv4 destinations"
);
debug_assert_eq!(dst.dst.sin_family, libc::AF_INET as libc::sa_family_t);
let mut hdr = libc::msghdr {
msg_name: unsafe { mem::transmute(&dst.dst as *const libc::sockaddr_in) },
msg_name: safe_cast(&mut dst.dst),
msg_namelen: mem::size_of_val(&dst.dst).try_into().unwrap(),
msg_iov: iovs.as_ptr() as *mut libc::iovec,
msg_iov: iovs.as_mut_ptr(),
msg_iovlen: iovs.len(),
msg_control: unsafe { mem::transmute(&control as *const ControlHeaderV4) },
msg_control: safe_cast(&mut control),
msg_controllen: mem::size_of_val(&control),
msg_flags: 0,
};
@ -361,7 +427,7 @@ impl Writer<LinuxEndpoint> for LinuxUDPWriter {
fn write(&self, buf: &[u8], dst: &mut LinuxEndpoint) -> Result<(), Self::Error> {
match dst {
LinuxEndpoint::V4(ref mut end) => Self::write4(self.sock4, buf, end),
LinuxEndpoint::V6(ref end) => Self::write6(self.sock6, buf, end),
LinuxEndpoint::V6(ref mut end) => Self::write6(self.sock6, buf, end),
}
}
}

Loading…
Cancel
Save