diff --git a/CHANGELOG.md b/CHANGELOG.md index 94d8d6f..25edb8e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,8 @@ ## Upcoming +-[[#245](https://github.com/rust-vmm/vmm-sys-util/pull/245)]: Make sock_ctrl_msg support unix. + ## v0.14.0 ### Changed diff --git a/Cargo.toml b/Cargo.toml index 9d5c6d0..f267d85 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ rustdoc-args = ["--cfg", "docsrs"] with-serde = ["serde", "serde_derive"] [dependencies] -libc = "0.2.39" +libc = "0.2.127" serde = { version = "1.0.27", optional = true } serde_derive = { version = "1.0.27", optional = true } diff --git a/src/linux/mod.rs b/src/linux/mod.rs index daf86f5..01aa301 100644 --- a/src/linux/mod.rs +++ b/src/linux/mod.rs @@ -10,6 +10,5 @@ pub mod fallocate; pub mod poll; pub mod seek_hole; pub mod signal; -pub mod sock_ctrl_msg; pub mod timerfd; pub mod write_zeroes; diff --git a/src/unix/mod.rs b/src/unix/mod.rs index 5c26a9c..99e43be 100644 --- a/src/unix/mod.rs +++ b/src/unix/mod.rs @@ -1,5 +1,6 @@ // Copyright 2022 rust-vmm Authors or its affiliates. All Rights Reserved. // SPDX-License-Identifier: BSD-3-Clause pub mod file_traits; +pub mod sock_ctrl_msg; pub mod tempdir; pub mod terminal; diff --git a/src/linux/sock_ctrl_msg.rs b/src/unix/sock_ctrl_msg.rs similarity index 77% rename from src/linux/sock_ctrl_msg.rs rename to src/unix/sock_ctrl_msg.rs index 80af2e8..0d769f8 100644 --- a/src/linux/sock_ctrl_msg.rs +++ b/src/unix/sock_ctrl_msg.rs @@ -16,58 +16,24 @@ use std::ptr::{copy_nonoverlapping, null_mut, write_unaligned}; use crate::errno::{Error, Result}; use libc::{ - c_long, c_void, cmsghdr, iovec, msghdr, recvmsg, sendmsg, MSG_NOSIGNAL, SCM_RIGHTS, SOL_SOCKET, + c_uint, c_void, cmsghdr, iovec, msghdr, recvmsg, sendmsg, CMSG_DATA, CMSG_LEN, CMSG_NXTHDR, + CMSG_SPACE, MSG_NOSIGNAL, SCM_RIGHTS, SOL_SOCKET, }; use std::os::raw::c_int; -// Each of the following macros performs the same function as their C counterparts. They are each -// macros because they are used to size statically allocated arrays. - -macro_rules! CMSG_ALIGN { - ($len:expr) => { - (($len) as usize + size_of::() - 1) & !(size_of::() - 1) - }; -} - -macro_rules! CMSG_SPACE { - ($len:expr) => { - size_of::() + CMSG_ALIGN!($len) - }; -} - -// This function (macro in the C version) is not used in any compile time constant slots, so is just -// an ordinary function. The returned pointer is hard coded to be RawFd because that's all that this -// module supports. -#[allow(non_snake_case)] -#[inline(always)] -fn CMSG_DATA(cmsg_buffer: *mut cmsghdr) -> *mut RawFd { - // Essentially returns a pointer to just past the header. - cmsg_buffer.wrapping_offset(1) as *mut RawFd -} - -#[cfg(not(target_env = "musl"))] -macro_rules! CMSG_LEN { - ($len:expr) => { - size_of::() + ($len) - }; -} - -#[cfg(target_env = "musl")] -macro_rules! CMSG_LEN { - ($len:expr) => {{ - let sz = size_of::() + ($len); - assert!(sz <= (std::u32::MAX as usize)); - sz as u32 - }}; -} - #[cfg(not(target_env = "musl"))] fn new_msghdr(iovecs: &mut [iovec]) -> msghdr { msghdr { msg_name: null_mut(), msg_namelen: 0, msg_iov: iovecs.as_mut_ptr(), + #[cfg(any(target_os = "linux", target_os = "android"))] msg_iovlen: iovecs.len(), + #[cfg(not(any(target_os = "linux", target_os = "android")))] + msg_iovlen: iovecs + .len() + .try_into() + .expect("iovecs.len() exceeds i32 range"), msg_control: null_mut(), msg_controllen: 0, msg_flags: 0, @@ -85,34 +51,26 @@ fn new_msghdr(iovecs: &mut [iovec]) -> msghdr { msg } -#[cfg(not(target_env = "musl"))] +#[cfg(all( + not(target_env = "musl"), + any(target_os = "linux", target_os = "android") +))] fn set_msg_controllen(msg: &mut msghdr, cmsg_capacity: usize) { msg.msg_controllen = cmsg_capacity; } -#[cfg(target_env = "musl")] +#[cfg(any( + target_env = "musl", + not(any(target_os = "linux", target_os = "android")) +))] fn set_msg_controllen(msg: &mut msghdr, cmsg_capacity: usize) { assert!(cmsg_capacity <= (std::u32::MAX as usize)); msg.msg_controllen = cmsg_capacity as u32; } -// This function is like CMSG_NEXT, but safer because it reads only from references, although it -// does some pointer arithmetic on cmsg_ptr. -#[allow(clippy::cast_ptr_alignment, clippy::unnecessary_cast)] -fn get_next_cmsg(msghdr: &msghdr, cmsg: &cmsghdr, cmsg_ptr: *mut cmsghdr) -> *mut cmsghdr { - let next_cmsg = (cmsg_ptr as *mut u8).wrapping_add(CMSG_ALIGN!(cmsg.cmsg_len)) as *mut cmsghdr; - if next_cmsg - .wrapping_offset(1) - .wrapping_sub(msghdr.msg_control as usize) as usize - > msghdr.msg_controllen as usize - { - null_mut() - } else { - next_cmsg - } -} - -const CMSG_BUFFER_INLINE_CAPACITY: usize = CMSG_SPACE!(size_of::() * 32); +// SAFETY: CMSG_SPACE is a pure calculation. The input value will not exceed the range of c_uint +const CMSG_BUFFER_INLINE_CAPACITY: usize = + unsafe { CMSG_SPACE(size_of::() as u32 * 32) as usize }; enum CmsgBuffer { Inline([u64; CMSG_BUFFER_INLINE_CAPACITY.div_ceil(8)]), @@ -151,7 +109,14 @@ impl CmsgBuffer { } fn raw_sendmsg(fd: RawFd, out_data: &[D], out_fds: &[RawFd]) -> Result { - let cmsg_capacity = CMSG_SPACE!(std::mem::size_of_val(out_fds)); + // SAFETY: CMSG_SPACE is a pure calculation. We ensure that the input value does not exceed the range of c_uint + let cmsg_capacity = unsafe { + CMSG_SPACE( + size_of_val(out_fds) + .try_into() + .map_err(|_| Error::new(libc::E2BIG))?, + ) as usize + }; let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity); let mut iovecs = Vec::with_capacity(out_data.len()); @@ -165,8 +130,25 @@ fn raw_sendmsg(fd: RawFd, out_data: &[D], out_fds: &[RawFd]) -> Re let mut msg = new_msghdr(&mut iovecs); if !out_fds.is_empty() { + // SAFETY: We ensure that the input value does not exceed the range of c_uint. c_uint to usize is a safe conversion + let cmsg_len = unsafe { + CMSG_LEN( + size_of_val(out_fds) + .try_into() + .map_err(|_| Error::new(libc::E2BIG))?, + ) + }; let cmsg = cmsghdr { - cmsg_len: CMSG_LEN!(std::mem::size_of_val(out_fds)), + #[cfg(all( + any(target_os = "linux", target_os = "android"), + not(target_env = "musl") + ))] + cmsg_len: cmsg_len as usize, + #[cfg(any( + not(any(target_os = "linux", target_os = "android")), + target_env = "musl" + ))] + cmsg_len: cmsg_len, cmsg_level: SOL_SOCKET, cmsg_type: SCM_RIGHTS, #[cfg(all(target_env = "musl", target_pointer_width = "64"))] @@ -180,7 +162,7 @@ fn raw_sendmsg(fd: RawFd, out_data: &[D], out_fds: &[RawFd]) -> Re // file descriptors. copy_nonoverlapping( out_fds.as_ptr(), - CMSG_DATA(cmsg_buffer.as_mut_ptr()), + CMSG_DATA(cmsg_buffer.as_mut_ptr()).cast(), out_fds.len(), ); } @@ -206,7 +188,7 @@ unsafe fn raw_recvmsg( iovecs: &mut [iovec], in_fds: &mut [RawFd], ) -> Result<(usize, usize)> { - let cmsg_capacity = CMSG_SPACE!(std::mem::size_of_val(in_fds)); + let cmsg_capacity = CMSG_SPACE(size_of_val(in_fds) as c_uint) as usize; let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity); let mut msg = new_msghdr(iovecs); @@ -242,7 +224,8 @@ unsafe fn raw_recvmsg( // read. let cmsg = (cmsg_ptr as *mut cmsghdr).read_unaligned(); if cmsg.cmsg_level == SOL_SOCKET && cmsg.cmsg_type == SCM_RIGHTS { - let fds_count: usize = ((cmsg.cmsg_len - CMSG_LEN!(0)) as usize) / size_of::(); + let cmsg_len: usize = std::cmp::min(cmsg.cmsg_len as usize, cmsg_capacity); + let fds_count: usize = (cmsg_len - CMSG_LEN(0) as usize) / size_of::(); // The sender can transmit more data than we can buffer. If a message is too long to // fit in the supplied buffer, excess bytes may be discarded depending on the type of // socket the message is received from. @@ -254,7 +237,7 @@ unsafe fn raw_recvmsg( // data must be dropped to insufficient buffer space for returning them to outer // scope. This might be a sign of incorrect protocol communication. for fd_offset in 0..fds_count { - let raw_fds_ptr = CMSG_DATA(cmsg_ptr); + let raw_fds_ptr: *mut RawFd = CMSG_DATA(cmsg_ptr).cast(); // The cmsg_ptr is valid here because is checked at the beginning of the // loop and it is assured to have `fds_count` fds available. let raw_fd = *(raw_fds_ptr.wrapping_add(fd_offset)) as c_int; @@ -264,7 +247,7 @@ unsafe fn raw_recvmsg( // Safe because `cmsg_ptr` is checked against null and we copy from `cmesg_buffer` to // `in_fds` according to their current capacity. copy_nonoverlapping( - CMSG_DATA(cmsg_ptr), + CMSG_DATA(cmsg_ptr).cast(), in_fds[copied_fds_count..(copied_fds_count + fds_to_be_copied_count)] .as_mut_ptr(), fds_to_be_copied_count, @@ -285,7 +268,7 @@ unsafe fn raw_recvmsg( return Err(Error::new(libc::ENOBUFS)); } - cmsg_ptr = get_next_cmsg(&msg, &cmsg, cmsg_ptr); + cmsg_ptr = CMSG_NXTHDR(&msg, &cmsg); } Ok((total_read as usize, copied_fds_count)) @@ -459,52 +442,12 @@ unsafe impl IntoIovec for &[u8] { mod tests { #![allow(clippy::undocumented_unsafe_blocks)] use super::*; - use crate::eventfd::EventFd; + use std::io::{pipe, Read}; use std::io::Write; - use std::mem::size_of; - use std::os::raw::c_long; use std::os::unix::net::UnixDatagram; use std::slice::from_raw_parts; - use libc::cmsghdr; - - #[test] - fn buffer_len() { - assert_eq!(CMSG_SPACE!(0), size_of::()); - assert_eq!( - CMSG_SPACE!(size_of::()), - size_of::() + size_of::() - ); - if size_of::() == 4 { - assert_eq!( - CMSG_SPACE!(2 * size_of::()), - size_of::() + size_of::() - ); - assert_eq!( - CMSG_SPACE!(3 * size_of::()), - size_of::() + size_of::() * 2 - ); - assert_eq!( - CMSG_SPACE!(4 * size_of::()), - size_of::() + size_of::() * 2 - ); - } else if size_of::() == 8 { - assert_eq!( - CMSG_SPACE!(2 * size_of::()), - size_of::() + size_of::() * 2 - ); - assert_eq!( - CMSG_SPACE!(3 * size_of::()), - size_of::() + size_of::() * 3 - ); - assert_eq!( - CMSG_SPACE!(4 * size_of::()), - size_of::() + size_of::() * 4 - ); - } - } - #[test] fn send_recv_no_fd() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); @@ -535,9 +478,9 @@ mod tests { fn send_recv_only_fd() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let evt = EventFd::new(0).expect("failed to create eventfd"); + let (mut evt_consumer, evt_notifier) = pipe().expect("failed to create pipe"); let write_count = s1 - .send_with_fd([].as_ref(), evt.as_raw_fd()) + .send_with_fd([].as_ref(), evt_notifier.as_raw_fd()) .expect("failed to send fd"); assert_eq!(write_count, 0); @@ -550,21 +493,25 @@ mod tests { assert!(file.as_raw_fd() >= 0); assert_ne!(file.as_raw_fd(), s1.as_raw_fd()); assert_ne!(file.as_raw_fd(), s2.as_raw_fd()); - assert_ne!(file.as_raw_fd(), evt.as_raw_fd()); + assert_ne!(file.as_raw_fd(), evt_notifier.as_raw_fd()); file.write_all(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) }) .expect("failed to write to sent fd"); - assert_eq!(evt.read().expect("failed to read from eventfd"), 1203); + let mut buf = [0u8; std::mem::size_of::()]; + evt_consumer + .read_exact(buf.as_mut_slice()) + .expect("Failed to read from PipeReader"); + assert_eq!(u64::from_ne_bytes(buf), 1203); } #[test] fn send_recv_with_fd() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let evt = EventFd::new(0).expect("failed to create eventfd"); + let (mut evt_consumer, evt_notifier) = pipe().expect("failed to create pipe"); let write_count = s1 - .send_with_fds(&[[237].as_ref()], &[evt.as_raw_fd()]) + .send_with_fds(&[[237].as_ref()], &[evt_notifier.as_raw_fd()]) .expect("failed to send fd"); assert_eq!(write_count, 1); @@ -586,14 +533,18 @@ mod tests { assert!(files[0] >= 0); assert_ne!(files[0], s1.as_raw_fd()); assert_ne!(files[0], s2.as_raw_fd()); - assert_ne!(files[0], evt.as_raw_fd()); + assert_ne!(files[0], evt_notifier.as_raw_fd()); let mut file = unsafe { File::from_raw_fd(files[0]) }; file.write_all(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) }) .expect("failed to write to sent fd"); - assert_eq!(evt.read().expect("failed to read from eventfd"), 1203); + let mut buf = [0u8; std::mem::size_of::()]; + evt_consumer + .read_exact(buf.as_mut_slice()) + .expect("Failed to read from PipeReader"); + assert_eq!(u64::from_ne_bytes(buf), 1203); } #[test] @@ -602,18 +553,18 @@ mod tests { fn send_more_recv_less1() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let evt1 = EventFd::new(0).expect("failed to create eventfd"); - let evt2 = EventFd::new(0).expect("failed to create eventfd"); - let evt3 = EventFd::new(0).expect("failed to create eventfd"); - let evt4 = EventFd::new(0).expect("failed to create eventfd"); + let (_, evt_notifier1) = pipe().expect("failed to create pipe"); + let (_, evt_notifier2) = pipe().expect("failed to create pipe"); + let (_, evt_notifier3) = pipe().expect("failed to create pipe"); + let (_, evt_notifier4) = pipe().expect("failed to create pipe"); let write_count = s1 .send_with_fds( &[[237].as_ref()], &[ - evt1.as_raw_fd(), - evt2.as_raw_fd(), - evt3.as_raw_fd(), - evt4.as_raw_fd(), + evt_notifier1.as_raw_fd(), + evt_notifier2.as_raw_fd(), + evt_notifier3.as_raw_fd(), + evt_notifier4.as_raw_fd(), ], ) .expect("failed to send fd"); @@ -635,18 +586,18 @@ mod tests { fn send_more_recv_less2() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let evt1 = EventFd::new(0).expect("failed to create eventfd"); - let evt2 = EventFd::new(0).expect("failed to create eventfd"); - let evt3 = EventFd::new(0).expect("failed to create eventfd"); - let evt4 = EventFd::new(0).expect("failed to create eventfd"); + let (_, evt_notifier1) = pipe().expect("failed to create pipe"); + let (_, evt_notifier2) = pipe().expect("failed to create pipe"); + let (_, evt_notifier3) = pipe().expect("failed to create pipe"); + let (_, evt_notifier4) = pipe().expect("failed to create pipe"); let write_count = s1 .send_with_fds( &[[237].as_ref()], &[ - evt1.as_raw_fd(), - evt2.as_raw_fd(), - evt3.as_raw_fd(), - evt4.as_raw_fd(), + evt_notifier1.as_raw_fd(), + evt_notifier2.as_raw_fd(), + evt_notifier3.as_raw_fd(), + evt_notifier4.as_raw_fd(), ], ) .expect("failed to send fd");