diff --git a/.cargo/config.toml b/.cargo/config.toml index bda566bd..04cacc9b 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,3 +1,3 @@ [target.aarch64-unknown-linux-gnu] linker = "aarch64-linux-gnu-gcc" -runner = ["qemu-aarch64-static"] # use qemu user emulation for cargo run and test +runner = ["qemu-aarch64-static"] # use qemu user emulation for cargo run and test \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 7584deae..1a4db3be 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1218,6 +1218,7 @@ dependencies = [ "libc", "lightway-core", "metrics", + "parking_lot", "pnet", "serde", "serde_with", @@ -1276,6 +1277,7 @@ dependencies = [ "more-asserts", "num_enum", "once_cell", + "parking_lot", "pnet", "rand", "rand_core", @@ -1309,6 +1311,7 @@ dependencies = [ "ctrlc", "delegate", "educe", + "hashbrown 0.15.2", "ipnet", "jsonwebtoken", "libc", @@ -1317,6 +1320,7 @@ dependencies = [ "metrics", "metrics-util", "more-asserts", + "parking_lot", "pnet", "ppp", "pwhash", @@ -1342,6 +1346,16 @@ version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" +[[package]] +name = "lock_api" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +dependencies = [ + "autocfg", + "scopeguard", +] + [[package]] name = "log" version = "0.4.25" @@ -1579,6 +1593,29 @@ version = "2.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" +[[package]] +name = "parking_lot" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets", +] + [[package]] name = "paste" version = "1.0.15" @@ -1867,6 +1904,15 @@ dependencies = [ "bitflags", ] +[[package]] +name = "redox_syscall" +version = "0.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834" +dependencies = [ + "bitflags", +] + [[package]] name = "regex" version = "1.11.1" @@ -1935,6 +1981,12 @@ version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + [[package]] name = "serde" version = "1.0.217" diff --git a/lightway-app-utils/Cargo.toml b/lightway-app-utils/Cargo.toml index df8da5a2..9d677679 100644 --- a/lightway-app-utils/Cargo.toml +++ b/lightway-app-utils/Cargo.toml @@ -11,7 +11,7 @@ readme = "README.md" [features] default = [ "tokio" ] -io-uring = [ "dep:io-uring", "dep:tokio", "dep:tokio-eventfd" ] +io-uring = [ "dep:io-uring", "dep:tokio", 
"dep:tokio-eventfd", "dep:parking_lot" ] tokio = [ "dep:tokio", "dep:tokio-stream" ] [lints] @@ -28,6 +28,7 @@ ipnet.workspace = true libc.workspace = true lightway-core.workspace = true metrics.workspace = true +parking_lot = { version = "0.12.3", optional = true } serde.workspace = true serde_with = "3.4.0" serde_yaml = "0.9.34" diff --git a/lightway-app-utils/examples/udprelay.rs b/lightway-app-utils/examples/udprelay.rs index bd89bed8..957acbdc 100644 --- a/lightway-app-utils/examples/udprelay.rs +++ b/lightway-app-utils/examples/udprelay.rs @@ -206,14 +206,18 @@ struct TunIOUring { impl TunIOUring { async fn new(tun: Tun, ring_size: usize, channel_size: usize) -> Result { - let tun_iouring = IOUring::new( + let tun_iouring = match IOUring::new( Arc::new(WrappedTun(tun)), ring_size, channel_size, TUN_MTU, Duration::from_millis(100), ) - .await?; + .await + { + Ok(it) => it, + Err(err) => return Err(err), + }; Ok(Self { tun_iouring }) } @@ -231,7 +235,13 @@ impl TunAdapter for TunIOUring { } async fn recv_from_tun(&self) -> Result { - self.tun_iouring.recv().await.map_err(anyhow::Error::msg) + match self.tun_iouring.recv().await { + IOCallbackResult::Ok(pkt) => Ok(pkt), + IOCallbackResult::WouldBlock => { + Err(std::io::Error::from(std::io::ErrorKind::WouldBlock).into()) + } + IOCallbackResult::Err(err) => Err(err.into()), + } } } diff --git a/lightway-app-utils/src/iouring.rs b/lightway-app-utils/src/iouring.rs index d96108d0..a5a92ebc 100644 --- a/lightway-app-utils/src/iouring.rs +++ b/lightway-app-utils/src/iouring.rs @@ -1,464 +1,760 @@ -use anyhow::{Context, Result, anyhow}; -use bytes::{BufMut, Bytes, BytesMut}; -use lightway_core::IOCallbackResult; -use thiserror::Error; - use crate::metrics; -use io_uring::{ - Builder, IoUring, SubmissionQueue, Submitter, cqueue::Entry as CEntry, opcode, - squeue::Entry as SEntry, types::Fixed, -}; +use anyhow::{Context, Result}; +use bytes::BytesMut; +use io_uring::{IoUring, opcode, squeue::PushError, types}; +use libc::iovec; +use lightway_core::IOCallbackResult; +use parking_lot::Mutex; use std::{ - os::fd::{AsRawFd, RawFd}, - sync::Arc, + alloc::{Layout, alloc_zeroed, dealloc}, + os::fd::AsRawFd, + sync::{ + Arc, + atomic::{AtomicBool, AtomicUsize, Ordering}, + }, thread, time::Duration, }; -use tokio::{ - io::AsyncReadExt, - sync::{Mutex, mpsc}, -}; -use tokio_eventfd::EventFd; +use tokio::sync::Notify; + +// ------------------------------------------------------------- +// - IMPLEMENT read-multishot and RUNTIME variations - +// ------------------------------------------------------------- + +// NOTE: temp until this is merged: https://github.com/tokio-rs/io-uring/pull/317 + +use io_uring::squeue::Entry; + +pub const IORING_OP_READ_MULTISHOT: u8 = 49; + +#[repr(C)] +pub struct CustomSQE { + pub opcode: u8, + pub flags: u8, + pub ioprio: u16, + pub fd: i32, + pub off_or_addr2: Union1, + pub addr_or_splice_off_in: Union2, + pub len: u32, + pub msg_flags: Union3, + pub user_data: u64, + pub buf_index: PackedU16, // Note: this is packed! 
+ pub personality: u16, + pub splice_fd: Union5, + pub __pad2: [u64; 2], // The final union covers 16 bytes +} -const REGISTERED_FD_INDEX: u32 = 0; +#[repr(C)] +pub union Union1 { + pub off: u64, + pub addr2: u64, + pub cmd_op: std::mem::ManuallyDrop, +} -/// IO-uring Struct -pub struct IOUring { - /// Any struct corresponds to a file descriptor - owned_fd: Arc, +#[repr(C)] +pub struct CmdOp { + pub cmd_op: u32, + pub __pad1: u32, +} - tx_queue: mpsc::Sender, - rx_queue: Mutex>, +#[repr(C)] +pub union Union2 { + pub addr: u64, + pub splice_off_in: u64, + pub level_optname: std::mem::ManuallyDrop, } -/// An error from read/write operation -#[derive(Debug, Error)] -pub enum IOUringError { - /// A recv error occurred - #[error("Recv Error")] - RecvError, +#[repr(C)] +pub struct SockLevel { + pub level: u32, + pub optname: u32, +} - /// A send error occurred - #[error("Send Error")] - SendError, +#[repr(C)] +pub union Union3 { + pub rw_flags: i32, + pub fsync_flags: u32, + pub poll_events: u16, + pub poll32_events: u32, + pub sync_range_flags: u32, + pub msg_flags: u32, + pub timeout_flags: u32, + pub accept_flags: u32, + pub cancel_flags: u32, + pub open_flags: u32, + pub statx_flags: u32, + pub fadvise_advice: u32, + pub splice_flags: u32, + pub rename_flags: u32, + pub unlink_flags: u32, + pub hardlink_flags: u32, + pub xattr_flags: u32, + pub msg_ring_flags: u32, + pub uring_cmd_flags: u32, + pub waitid_flags: u32, + pub futex_flags: u32, + pub install_fd_flags: u32, + pub nop_flags: u32, } -pub type IOUringResult = std::result::Result; +#[repr(C, packed)] +pub struct PackedU16 { + pub buf_index: u16, +} -impl IOUring { - /// Create `IOUring` struct - pub async fn new( - owned_fd: Arc, - ring_size: usize, - channel_size: usize, - mtu: usize, - sqpoll_idle_time: Duration, - ) -> Result { - let fd = owned_fd.as_raw_fd(); +#[repr(C)] +pub union Union5 { + pub splice_fd_in: i32, + pub file_index: u32, + pub optlen: u32, + pub addr_len_stuff: std::mem::ManuallyDrop, +} - let (tx_queue_sender, tx_queue_receiver) = mpsc::channel(channel_size); - let (rx_queue_sender, rx_queue_receiver) = mpsc::channel(channel_size); - thread::Builder::new() - .name("io_uring-main".to_string()) - .spawn(move || { - tokio::runtime::Builder::new_current_thread() - .enable_io() - .build() - .expect("Failed building Tokio Runtime") - .block_on(iouring_task( - fd, - ring_size, - mtu, - sqpoll_idle_time, - tx_queue_receiver, - rx_queue_sender, - )) - .inspect_err(|err| { - tracing::error!("i/o uring task stopped: {:?}", err); - }) - })?; +#[repr(C)] +pub struct AddrLenPad { + pub addr_len: u16, + pub __pad3: [u16; 1], +} - Ok(Self { - owned_fd, - tx_queue: tx_queue_sender, - rx_queue: Mutex::new(rx_queue_receiver), - }) +impl Default for CustomSQE { + fn default() -> Self { + // Safety: memzero is ok + #[allow(unsafe_code)] + unsafe { + std::mem::zeroed() + } } +} - /// Retrieve a reference to the underlying device - pub fn owned_fd(&self) -> &T { - &self.owned_fd +pub struct ReadMulti { + fd: i32, + buf_group: u16, + flags: i32, +} + +impl ReadMulti { + #[inline] + pub fn new(fd: i32, buf_group: u16) -> Self { + ReadMulti { + fd, + buf_group, + flags: 0, + } } - /// Receive packet from Tun device - pub async fn recv(&self) -> IOUringResult { - self.rx_queue - .lock() - .await - .recv() - .await - .ok_or(IOUringError::RecvError) + #[inline] + pub fn build(self) -> Entry { + let sqe = CustomSQE { + opcode: IORING_OP_READ_MULTISHOT as _, + flags: io_uring::squeue::Flags::BUFFER_SELECT.bits(), + fd: self.fd, + buf_index: 
PackedU16 { + buf_index: self.buf_group, + }, + msg_flags: Union3 { + msg_flags: self.flags as _, + }, + ..Default::default() + }; + + // Safety: CustomSQE has identical memory layout to io_uring_sqe + #[allow(unsafe_code)] + unsafe { + std::mem::transmute(sqe) + } } +} - /// Try Send packet to Tun device - pub fn try_send(&self, buf: BytesMut) -> IOCallbackResult { - let buf_len = buf.len(); - let try_send_res = self.tx_queue.try_send(buf.freeze()); - match try_send_res { - Ok(()) => IOCallbackResult::Ok(buf_len), - Err(mpsc::error::TrySendError::Full(_)) => IOCallbackResult::WouldBlock, - Err(_) => { - use std::io::{Error, ErrorKind}; - IOCallbackResult::Err(Error::new(ErrorKind::Other, IOUringError::SendError)) +// Static for one-time initialization +static INITIALIZED: AtomicBool = AtomicBool::new(false); +static SUPPORTED: AtomicBool = AtomicBool::new(false); + +#[cold] +fn initialize_kernel_check() -> bool { + let supported = std::fs::read_to_string("/proc/sys/kernel/osrelease") + .ok() + .and_then(|v| { + let version_numbers = v.split('-').next()?; + let parts: Vec<_> = version_numbers.split('.').collect(); + if parts.len() >= 2 { + Some((parts[0].parse::().ok()?, parts[1].parse::().ok()?)) + } else { + None } + }) + .is_some_and(|(major, minor)| major > 6 || (major == 6 && minor >= 7)); + + SUPPORTED.store(supported, Ordering::Release); + INITIALIZED.store(true, Ordering::Release); + supported +} + +#[inline(always)] +pub fn kernel_supports_multishot() -> bool { + // Fast path - just load if initialized + if INITIALIZED.load(Ordering::Acquire) { + SUPPORTED.load(Ordering::Acquire) + } else { + // Slow path - do initialization + initialize_kernel_check() + } +} + +// Safety: SQE operations are always unsafe +/// Inline operation to ensure we queue reads without impacting runtime (multi-kernel) +#[inline(always)] +#[allow(unsafe_code)] +pub unsafe fn queue_reads( + sq: &mut io_uring::SubmissionQueue<'_>, + fd: i32, + n_entries: usize, + buf_group: u16, + user_data: u64, +) -> Result<(), PushError> { + if kernel_supports_multishot() { + tracing::debug!("Kernel supports - adding MULTISHOT_READ"); + // Safety: Ring is initialized and file descriptor is valid + unsafe { + let op = ReadMulti::new(fd, buf_group).build().user_data(user_data); + sq.push(&op) + } + } else { + tracing::debug!("NO Kernel support - adding {} READ", n_entries); + let mut ops = Vec::with_capacity(n_entries); + for _ in 0..n_entries { + let op = opcode::Read::new(types::Fd(fd), std::ptr::null_mut(), 0) + .buf_group(buf_group) + .build() + .flags(io_uring::squeue::Flags::BUFFER_SELECT) + .user_data(user_data); + ops.push(op); } + // Safety: Ring is initialized and file descriptor is valid + unsafe { sq.push_multiple(&ops) } } } -#[derive(Debug)] -enum SlotIdx { - Tx(isize), - Rx(isize), +// ------------------------------------------------------------- + +#[repr(u64)] +enum IOUringActionID { + RecycleBuffers = 0x10001000, + ReceivedBuffer = 0xfeedfeed, } +const RX_BUFFER_GROUP: u16 = 0xdead; + +// Required 32MB for io-uring to function properly +const REQUIRED_RLIMIT_MEMLOCK_MAX: u64 = 32 * 1024 * 1024; -impl SlotIdx { - fn from_user_data(u: u64) -> Self { - let u = u as isize; - if u < 0 { Self::Rx(!u) } else { Self::Tx(u) } +/// A wrapper around a raw pointer that guarantees thread safety through Arc ownership +struct BufferPtr(*mut u8); + +#[allow(unsafe_code)] +// Safety: The pointer is owned by Arc which ensures exclusive access +unsafe impl Send for BufferPtr {} +#[allow(unsafe_code)] +// Safety: The pointer is 
owned by Arc which ensures synchronized access
+unsafe impl Sync for BufferPtr {}
+
+impl BufferPtr {
+    fn as_ptr(&self) -> *mut u8 {
+        self.0
+    }
+}

-    fn idx(&self) -> usize {
-        match *self {
-            SlotIdx::Tx(idx) => idx as usize,
-            SlotIdx::Rx(idx) => idx as usize,
-        }
-    }
-
-    fn user_data(&self) -> u64 {
-        match *self {
-            SlotIdx::Tx(idx) => idx as u64,
-            SlotIdx::Rx(idx) => (!idx) as u64,
-        }
-    }
-}

+struct PageAlignedBuffer {
+    ptr: *mut u8,
+    layout: Layout,
+    entry_size: usize,
+    num_entries: usize,
+}
+
+#[allow(unsafe_code)]
+// Safety: The pointer is owned by Arc which ensures exclusive access
+unsafe impl Send for PageAlignedBuffer {}
+#[allow(unsafe_code)]
+// Safety: The pointer is owned by Arc which ensures synchronized access
+unsafe impl Sync for PageAlignedBuffer {}
+
+impl PageAlignedBuffer {
+    fn new(entry_size: usize, num_entries: usize) -> Self {
+        #[allow(unsafe_code)]
+        // Safety: sysconf(_SC_PAGESIZE) has no preconditions
+        let page_size = unsafe { libc::sysconf(libc::_SC_PAGESIZE) } as usize;
+
+        // Round up entry_size to 16-byte alignment first
+        let aligned_entry_size = (entry_size + 15) & !15;
+
+        // Calculate how many entries fit in one page
+        let entries_per_page = page_size / aligned_entry_size;
+
+        // Calculate total pages needed
+        let pages_needed = num_entries.div_ceil(entries_per_page);
+        let total_size = pages_needed * page_size;
+
+        let layout = Layout::from_size_align(total_size, page_size).expect("Invalid layout");
+
+        // Safety: allocate per the layout selected (Rust has no stable aligned-allocation API)
+        #[allow(unsafe_code)]
+        let ptr = unsafe { alloc_zeroed(layout) };
+        if ptr.is_null() {
+            std::alloc::handle_alloc_error(layout);
+        }
+
+        Self {
+            ptr,
+            layout,
+            entry_size: aligned_entry_size,
+            num_entries,
+        }
+    }
+
+    fn get_ptr(&self, idx: usize) -> *mut u8 {
+        assert!(idx < self.num_entries);
+        // Safety: the index was asserted to be within the boundary above
+        #[allow(unsafe_code)]
+        unsafe {
+            self.ptr.add(idx * self.entry_size)
+        }
+    }
+
+    fn as_ptr(&self) -> *mut u8 {
+        self.ptr
+    }
+}
+
+impl Drop for PageAlignedBuffer {
+    fn drop(&mut self) {
+        // Safety: we deallocate with exactly the layout we allocated (saved above)
+        #[allow(unsafe_code)]
+        unsafe {
+            dealloc(self.ptr, self.layout);
+        }
+    }
+}
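+// Sizing example (illustrative numbers, assuming 4 KiB pages and mtu = 1500):
+// aligned_entry_size = (1500 + 15) & !15 = 1504 bytes, so
+// entries_per_page = 4096 / 1504 = 2, and a 1024-entry pool needs
+// 1024.div_ceil(2) = 512 pages, i.e. 2 MiB of page-locked memory.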
-struct RxState {
-    sender: Option<mpsc::OwnedPermit<BytesMut>>,
-    buf: BytesMut,
-}
-
-fn push_one_tx_event_to(
-    buf: Bytes,
-    sq: &mut SubmissionQueue,
-    bufs: &mut [Option<Bytes>],
-    slot: SlotIdx,
-) -> std::result::Result<(), SlotIdx> {
-    let sqe = opcode::Write::new(Fixed(REGISTERED_FD_INDEX), buf.as_ptr(), buf.len() as _)
-        .build()
-        .user_data(slot.user_data());
-
-    #[allow(unsafe_code)]
-    // SAFETY: sqe points to a buffer on the heap, owned
-    // by a `Bytes` in `bufs[slot]`, we will not reuse
-    // `bufs[slot]` until `slot` is returned to the slots vector.
-    if unsafe { sq.push(&sqe) }.is_err() {
-        return Err(slot);
-    }
-    // SAFETY: By construction instances of SlotIdx are always in bounds.
-    #[allow(unsafe_code)]
-    unsafe {
-        *bufs.get_unchecked_mut(slot.idx()) = Some(buf)
-    };
-
-    Ok(())
-}
-
-fn push_tx_events_to(
-    sbmt: &Submitter,
-    sq: &mut SubmissionQueue,
-    txq: &mut mpsc::Receiver<Bytes>,
-    slots: &mut Vec<SlotIdx>,
-    bufs: &mut [Option<Bytes>],
-) -> Result<()> {
-    while !slots.is_empty() {
-        if sq.is_full() {
-            match sbmt.submit() {
-                Ok(_) => (),
-                Err(ref err) if err.raw_os_error() == Some(libc::EBUSY) => break,
-                Err(err) => {
-                    return Err(anyhow!(err)).context("Push TX events failed for sq submit");
-                }
-            }
-        }
-        sq.sync();
-
-        match txq.try_recv() {
-            Ok(buf) => {
-                let slot = slots.pop().expect("no tx slots left"); // we are inside `!slots.is_empty()`.
-                if let Err(slot) = push_one_tx_event_to(buf, sq, bufs, slot) {
-                    slots.push(slot);
-                    break;
-                }
-            }
-            Err(mpsc::error::TryRecvError::Empty) => {
-                break;
-            }
-            Err(err) => {
-                return Err(anyhow!(err)).context("Push TX events failed for try_recv");
-            }
-        }
-    }
-    Ok(())
-}

+/// A pool of buffers with an underlying contiguous memory block
+struct BufferPool {
+    data: PageAlignedBuffer,
+    lengths: Vec<AtomicUsize>,
+    states: Vec<AtomicBool>, // false = free, true = in-use
+    usage_idx: AtomicUsize,
+}
+
+impl BufferPool {
+    fn new(entry_size: usize, pool_size: usize) -> Self {
+        Self {
+            data: PageAlignedBuffer::new(entry_size, pool_size),
+            lengths: (0..pool_size).map(|_| AtomicUsize::new(0)).collect(),
+            states: (0..pool_size).map(|_| AtomicBool::new(false)).collect(),
+            usage_idx: AtomicUsize::new(0),
+        }
+    }
+
+    fn get_buffer(&self, idx: usize) -> (BufferPtr, &AtomicUsize, &AtomicBool) {
+        (
+            BufferPtr(self.data.get_ptr(idx)),
+            &self.lengths[idx],
+            &self.states[idx],
+        )
+    }
+}
+
+/// IO-uring Struct
+pub struct IOUring<T: AsRawFd> {
+    owned_fd: Arc<T>,
+    rx_pool: Arc<BufferPool>,
+    tx_pool: Arc<BufferPool>,
+    rx_notify: Arc<Notify>,
+    ring: Arc<IoUring>,
+    submission_lock: Arc<Mutex<()>>,
+}
+
+// Safety: the IOUring implementation does direct memory manipulation for performance benefits
+#[allow(unsafe_code)]
+impl<T: AsRawFd> IOUring<T> {
+    /// Create `IOUring` struct
+    pub async fn new(
+        owned_fd: Arc<T>,
+        ring_size: usize,
+        _channel_size: usize,
+        mtu: usize,
+        sqpoll_idle_time: Duration,
+    ) -> Result<Self> {
+        // NOTE: for now it is probably a good idea to allocate the rx pool, tx pool
+        // and ring at the same size, because the VPN use-case usually has
+        // MTU-sized buffers going in and out
+
+        tracing::debug!(
+            "INIT io-uring, estimated memory (user | kernel): {}MiB | {}MiB",
+            (2 * (size_of::<BufferPool>() + (mtu * ring_size))) / 1024 / 1024,
+            (ring_size * 2 * (16 + (2 * 64)) + 8192) / 1024 / 1024,
+        );
+
+        let rx_pool = Arc::new(BufferPool::new(mtu, ring_size));
+        let tx_pool = Arc::new(BufferPool::new(mtu, ring_size));
+
+        let ring = Arc::new(
+            IoUring::builder()
+                .setup_sqpoll(sqpoll_idle_time.as_millis() as u32)
+                .build((ring_size * 2) as u32)?,
+        );
+
+        let rx_notify = Arc::new(Notify::new());
+
+        // NOTE: for now this ensures we only create 1 kthread per tunnel, not 2
+        // (rx/tx); we can opt to change this going forward, or redo the
+        // structure so it does not need a lock
+        let submission_lock = Arc::new(Mutex::new(()));
+
+        // We can provide the buffers without a lock, as we still haven't shared the ownership
+        let fd = owned_fd.as_raw_fd();
+
+        // Scope submission-queue operations to avoid borrowing ring
+        {
+            // Safety: ring submission can be used without locks at this point
+            let mut sq = unsafe { ring.submission_shared() };
+
+            tracing::debug!("Sending PROVIDE_BUFFERS");
+            // Safety: buffer memory is owned by rx_pool and outlives the usage.
+            // Note the per-buffer length is entry_size (the pool stride), so the
+            // kernel's buffer offsets match get_ptr().
+            unsafe {
+                sq.push(
+                    &opcode::ProvideBuffers::new(
+                        rx_pool.data.as_ptr(),
+                        rx_pool.data.entry_size as i32,
+                        ring_size as u16,
+                        RX_BUFFER_GROUP,
+                        0,
+                    )
+                    .build()
+                    .user_data(IOUringActionID::RecycleBuffers as u64),
+                )?
+            };
+
+            // Safety: ring is initialized and the file descriptor is valid
+            unsafe {
+                queue_reads(
+                    &mut sq,
+                    fd,
+                    ring_size,
+                    RX_BUFFER_GROUP,
+                    IOUringActionID::ReceivedBuffer as _,
+                )?
+            };
+        }
+
+        let tx_iovecs: Vec<_> = (0..ring_size)
+            .map(|idx| {
+                let (ptr, _, _) = tx_pool.get_buffer(idx);
+                iovec {
+                    iov_base: ptr.as_ptr() as *mut libc::c_void,
+                    iov_len: mtu,
+                }
+            })
+            .collect();
+
+        // Safety: plain zeroed memory for the libc call
+        let mut rlim: libc::rlimit = unsafe { std::mem::zeroed() };
+        // Safety: fetch the currently configured memory limits
+        unsafe {
+            libc::getrlimit(libc::RLIMIT_MEMLOCK, &mut rlim);
+        }
+
+        // Check the soft limit, which is the one that is actually enforced
+        if rlim.rlim_cur < REQUIRED_RLIMIT_MEMLOCK_MAX {
+            tracing::info!("RLIMIT_MEMLOCK too low ({}), adjusting", rlim.rlim_cur);
+            rlim.rlim_cur = REQUIRED_RLIMIT_MEMLOCK_MAX;
+            if rlim.rlim_max < rlim.rlim_cur {
+                rlim.rlim_max = rlim.rlim_cur;
+            }
+            // Safety: the rlimit API requires an unsafe block
+            if unsafe { libc::setrlimit(libc::RLIMIT_MEMLOCK, &rlim) } != 0 {
+                tracing::warn!(
+                    "Failed to set RLIMIT_MEMLOCK: {}",
+                    std::io::Error::last_os_error()
+                );
+            }
+        }
+
+        // Safety: tx_iovecs point to valid memory owned by tx_pool
+        unsafe { ring.submitter().register_buffers(&tx_iovecs)? };
+        ring.submitter()
+            .register_files(&[fd])
+            .expect("io-uring support");
+
+        let config = IOUringTaskConfig {
+            rx_pool: rx_pool.clone(),
+            tx_pool: tx_pool.clone(),
+            rx_notify: rx_notify.clone(),
+            ring: ring.clone(),
+        };
+
+        // NOTE: we don't implement Drop for this type yet; until we do, closing
+        // the fd held by owned_fd remains the caller's concern
+        thread::Builder::new()
+            .name("io_uring-main".to_string())
+            .spawn(move || {
+                tokio::runtime::Builder::new_current_thread()
+                    .enable_io()
+                    .build()
+                    .expect("Failed building Tokio Runtime")
+                    .block_on(iouring_task(config))
+            })
+            .context("io_uring-task")?;
+
+        Ok(Self {
+            owned_fd,
+            rx_pool,
+            tx_pool,
+            rx_notify,
+            ring,
+            submission_lock,
+        })
+    }
+
+    /// Retrieve a reference to the underlying device
+    pub fn owned_fd(&self) -> &T {
+        &self.owned_fd
+    }
-fn push_rx_events_to(
-    sbmt: &Submitter,
-    sq: &mut SubmissionQueue,
-    slots: &mut Vec<SlotIdx>,
-    state: &mut [RxState],
-) -> Result<()> {
-    loop {
-        if sq.is_full() {
-            match sbmt.submit() {
-                Ok(_) => (),
-                Err(ref err) if err.raw_os_error() == Some(libc::EBUSY) => break,
-                Err(err) => {
-                    return Err(anyhow!(err)).context("Push RX events failed for sq submit");
-                }
-            }
-        }
-        sq.sync();
-
-        match slots.pop() {
-            Some(slot) => {
-                // SAFETY: By construction instances of SlotIdx are always in bounds.
-                #[allow(unsafe_code)]
-                let state = unsafe { state.get_unchecked_mut(slot.idx()) };
-
-                // queue a new rx
-                let sqe = opcode::Read::new(
-                    Fixed(REGISTERED_FD_INDEX),
-                    state.buf.as_mut_ptr(),
-                    state.buf.capacity() as _,
-                )
-                .build()
-                .user_data(slot.user_data());
-                #[allow(unsafe_code)]
-                // SAFETY: sqe points to a buffer on the heap, owned
-                // by a `BytesMut` in `rx_bufs[slot]`, we will not reuse
-                // `rx_bufs[slot]` until `slot` is returned to the slots vector.
-                if unsafe { sq.push(&sqe) }.is_err() {
-                    slots.push(slot);
-                    break;
-                }
-            }
-            None => break,
-        }
-    }
-    Ok(())
-}

+    /// Send packet on Tun device (push to the ring and submit)
+    pub fn try_send(&self, buf: BytesMut) -> IOCallbackResult<usize> {
+        tracing::debug!("try_send {} bytes", buf.len());
+        // For the ordering semantics, see the recv() function below
+        let idx = self
+            .tx_pool
+            .usage_idx
+            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |idx| {
+                Some((idx + 1) % self.tx_pool.states.len())
+            })
+            .unwrap();
+        let (buffer, length, state) = self.tx_pool.get_buffer(idx);
+
+        let len = buf.len();
+        if len > self.tx_pool.data.entry_size {
+            tracing::warn!(
+                "We don't support buffer-splitting for now (max: {}, got: {})",
+                self.tx_pool.data.entry_size,
+                len
+            );
+            return IOCallbackResult::WouldBlock;
+        }
+
+        // Check if the buffer is free (state == false)
+        if state
+            .compare_exchange(
+                false,
+                true, // free -> in-use
+                Ordering::AcqRel,
+                Ordering::Relaxed,
+            )
+            .is_err()
+        {
+            // Out of buffers: the kernel is not completing sends fast enough.
+            // Consider a bigger queue if this counter keeps growing.
+            metrics::tun_iouring_tx_err();
+            return IOCallbackResult::WouldBlock;
+        }
+
+        // Safety: the buffer is allocated with sufficient size and ownership is checked via state
+        unsafe { std::slice::from_raw_parts_mut(buffer.as_ptr(), len).copy_from_slice(&buf) };
+        length.store(len, Ordering::Release);
+
+        // NOTE: IOUringActionID values have to be bigger than the ring size,
+        // because plain buffer indices are used as the user_data of WriteFixed operations
+        let write_op =
+            opcode::WriteFixed::new(types::Fixed(0), buffer.as_ptr(), len as _, idx as _)
+                .build()
+                .user_data(idx as u64);
+
+        tracing::debug!("queuing WRITE_FIXED on buf-id {}", idx);
+
+        // Safely queue the submission
+        {
+            let _guard = self.submission_lock.lock();
+            // Safety: protected by the lock above
+            let mut sq = unsafe { self.ring.submission_shared() };
+            // Safety: the entry uses buffers from tx_pool, which outlives the task using them
+            unsafe {
+                match sq.push(&write_op) {
+                    Ok(_) => IOCallbackResult::Ok(len),
+                    Err(_) => {
+                        tracing::warn!("Failed to queue send");
+                        metrics::tun_iouring_tx_err();
+                        IOCallbackResult::WouldBlock
+                    }
+                }
+            }
+        }
+    }
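+    // Illustrative sketch (hypothetical helper, not part of the original patch)
+    // of the slot-claiming pattern used by try_send()/recv(): `fetch_update`
+    // re-runs its closure with a freshly loaded value whenever the CAS fails,
+    // and returns the *previous* value on success, so concurrent callers
+    // always claim distinct indices:
+    //
+    //     use std::sync::atomic::{AtomicUsize, Ordering};
+    //
+    //     fn claim_slot(usage_idx: &AtomicUsize, pool_len: usize) -> usize {
+    //         usage_idx
+    //             .fetch_update(Ordering::AcqRel, Ordering::Acquire, |idx| {
+    //                 Some((idx + 1) % pool_len)
+    //             })
+    //             .expect("closure always returns Some")
+    //     }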
+    /// Receive packet from Tun device
+    pub async fn recv(&self) -> IOCallbackResult<BytesMut> {
+        // NOTE: why these orderings were chosen.
+        // `fetch_update` flow:
+        // 1. The current value is loaded
+        // 2. Our closure is called with that value
+        // 3. A compare-and-swap (CAS) attempts to store the closure's result
+        //
+        // The (idx + 1) % len calculation happens INSIDE the closure, after the
+        // load but before the CAS. So if multiple threads run concurrently:
+        // - Thread A loads value X
+        // - Thread B loads value X (before A's CAS completes)
+        // - Both calculate (X + 1) % len
+        // - We need AcqRel so the threads don't store values on top of each other
+        // - The first thread's CAS succeeds, as the value has not changed
+        // - The second thread's CAS fails because the value changed
+        // - The second thread retries, so we need Acquire on the reload to see
+        //   the first thread's value
+        let idx = self
+            .rx_pool
+            .usage_idx
+            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |idx| {
+                Some((idx + 1) % self.rx_pool.states.len())
+            })
+            .unwrap();
+        let (buffer, length, state) = self.rx_pool.get_buffer(idx);
+
+        tracing::debug!("recv blocking until buf-id {} is available", idx);
+        loop {
+            // NOTE: unlike the case above, the failure ordering here can be
+            // Relaxed for better performance.
+            // This is because we don't use the loaded value in a closure, so we
+            // don't need to observe its current value on failure
+            if state
+                .compare_exchange(true, false, Ordering::AcqRel, Ordering::Relaxed)
+                .is_ok()
+            {
+                // Last buffer - the whole group needs to be re-provided
+                // NOTE: this is why io_uring is not really practical in a lot of use-cases...
+                if idx + 1 == self.rx_pool.data.num_entries {
+                    let _guard = self.submission_lock.lock();
+                    // Safety: protected by the lock above
+                    let mut sq = unsafe { self.ring.submission_shared() };
+                    let rx_ring_size = self.rx_pool.states.len();
+                    // Safety: the buffers are mapped from rx_pool, which outlives this task.
+                    // Note the second argument is the per-buffer length (the pool
+                    // stride), not the ring size.
+                    unsafe {
+                        sq.push(
+                            &opcode::ProvideBuffers::new(
+                                self.rx_pool.data.as_ptr(),
+                                self.rx_pool.data.entry_size as i32,
+                                rx_ring_size as u16,
+                                RX_BUFFER_GROUP,
+                                0,
+                            )
+                            .build()
+                            .user_data(IOUringActionID::RecycleBuffers as u64),
+                        )
+                        .expect("iouring queue should work")
+                    };
+                    // Safety: the buffer-group originates from rx_pool, which outlives this task
+                    unsafe {
+                        queue_reads(
+                            &mut sq,
+                            self.owned_fd.as_raw_fd(),
+                            rx_ring_size,
+                            RX_BUFFER_GROUP,
+                            IOUringActionID::ReceivedBuffer as _,
+                        )
+                        .expect("iouring queue should work")
+                    };
+                }
+
+                let len = length.load(Ordering::Acquire);
+                let mut new_buf = BytesMut::with_capacity(len);
+
+                tracing::debug!("recv, got {} bytes", len);
+
+                // Safety: the buffer is allocated with sufficient size and ownership is checked via state
+                unsafe {
+                    new_buf.extend_from_slice(std::slice::from_raw_parts(buffer.as_ptr(), len))
+                };
+                return IOCallbackResult::Ok(new_buf);
+            }
+            // IO-bound wait for available buffers
+            self.rx_notify.notified().await;
+        }
+    }
+}

-async fn iouring_task(
-    fd: RawFd,
-    ring_size: usize,
-    mtu: usize,
-    sqpoll_idle_time: Duration,
-    mut tx_queue: mpsc::Receiver<Bytes>,
-    rx_queue: mpsc::Sender<BytesMut>,
-) -> Result<()> {
-    let mut event_fd: EventFd = EventFd::new(0, false)?;
-    let mut builder: Builder = IoUring::builder();
-
-    if sqpoll_idle_time.as_millis() > 0 {
-        let idle_time: u32 = sqpoll_idle_time
-            .as_millis()
-            .try_into()
-            .with_context(|| "invalid sqpoll idle time")?;
-        // This setting makes CPU go 100% when there is continuous traffic
-        builder.setup_sqpoll(idle_time); // Needs 5.13
-    }
-    let mut ring = builder
-        .build(ring_size as u32)
-        .inspect_err(|e| tracing::error!("iouring setup failed: {e}"))?;
-
-    let (sbmt, mut sq, mut cq) = ring.split();
-
-    // Register event-fd to cqe entries
-    sbmt.register_eventfd(event_fd.as_raw_fd())?;
-    sbmt.register_files(&[fd])?;
-
-    // Using half of total io-uring size for rx and half for tx
-    let nr_tx_rx_slots = (ring_size / 2) as isize;
-    tracing::info!(
-        ring_size,
-        nr_tx_rx_slots,
-        ?sqpoll_idle_time,
-        "uring main task"
-    );
-
-    let mut rx_slots: Vec<_> = (0..nr_tx_rx_slots).map(SlotIdx::Rx).collect();
-    let mut rx_state: Vec<_> = rx_slots
-        .iter()
-        .map(|_| RxState {
-            sender: None,
-            buf: BytesMut::with_capacity(mtu),
-        })
-        .collect();
-    for state in rx_state.iter_mut() {
-        state.sender = Some(rx_queue.clone().reserve_owned().await?)
-    }
-
-    let rx_sq_entries: Vec<_> = rx_slots
-        .drain(..)
-        .map(|slot| {
-            let state = &mut rx_state[slot.idx()];
-            opcode::Read::new(
-                Fixed(REGISTERED_FD_INDEX),
-                state.buf.as_mut_ptr(),
-                state.buf.capacity() as _,
-            )
-            .build()
-            .user_data(slot.user_data())
-        })
-        .collect();
-
-    // SAFETY: sqe points to a buffer on the heap, owned
-    // by a `BytesMut` in `rx_bufs[slot]`, we will not reuse
-    // `rx_bufs[slot]` until `slot` is returned to the slots vector.
-    #[allow(unsafe_code)]
-    unsafe {
-        let entries = rx_sq_entries;
-        // This call should not fail since the SubmissionQueue should be empty now
-        sq.push_multiple(&entries)?
-    };
-    sq.sync();
-
-    let mut tx_slots: Vec<_> = (0..nr_tx_rx_slots).map(SlotIdx::Tx).collect();
-    let mut tx_bufs = vec![None; tx_slots.len()];
+/// Task variables
+struct IOUringTaskConfig {
+    rx_pool: Arc<BufferPool>,
+    tx_pool: Arc<BufferPool>,
+    rx_notify: Arc<Notify>,
+    ring: Arc<IoUring>,
+}

-    tracing::info!("Entering i/o uring loop");
+// Safety: managing ring completions and results efficiently requires direct memory manipulation
+#[allow(unsafe_code)]
+async fn iouring_task(config: IOUringTaskConfig) -> Result<()> {
+    tracing::debug!("Started iouring_task");
     loop {
-        let _ = sbmt.submit()?;
+        // Work once we have at least 1 task to perform
+        config.ring.submit_and_wait(1)?;

-        cq.sync();
+        tracing::debug!("iotask woke up");

-        if cq.is_empty() && tx_queue.is_empty() {
-            let mut completed_number: [u8; 8] = [0; 8];
-            tokio::select! {
-                // There is no "wait until the queue contains
-                // something" method so we have to actually receive
-                // and treat that as a special case.
-                Some(buf) = tx_queue.recv(), if !tx_slots.is_empty() && !sq.is_full() => {
-                    let slot = tx_slots.pop().expect("no tx slots left"); // we are inside `!slots.is_empty()` guard.
-                    if let Err(slot) = push_one_tx_event_to(buf, &mut sq, &mut tx_bufs, slot) {
-                        tx_slots.push(slot);
-                    }
-                }
-                Ok(a) = event_fd.read(&mut completed_number) => {
-                    assert_eq!(a, 8);
-                },
-            };
-            cq.sync();
-        }
+        // Safety: only this task is using the completion-queue (concept should not change)
+        for cqe in unsafe { config.ring.completion_shared() } {
+            match cqe.user_data() {
+                x if x == IOUringActionID::RecycleBuffers as u64 => {
+                    // Buffer provision completed
+                    tracing::debug!("Buffer provision completed");
+                }

-        // fill tx slots
-        push_tx_events_to(&sbmt, &mut sq, &mut tx_queue, &mut tx_slots, &mut tx_bufs)?;
-
-        // refill rx slots
-        push_rx_events_to(&sbmt, &mut sq, &mut rx_slots, &mut rx_state)?;
-
-        sq.sync();
-
-        for cqe in &mut cq {
-            let res = cqe.result();
-            let slot = SlotIdx::from_user_data(cqe.user_data());
-
-            match slot {
-                SlotIdx::Rx(_) => {
-                    if res > 0 {
+                x if x == IOUringActionID::ReceivedBuffer as u64 => {
+                    let result = cqe.result();
+                    if result < 0 {
+                        tracing::error!(
+                            "Receive failed: {}",
+                            std::io::Error::from_raw_os_error(-result)
+                        );
+                        metrics::tun_iouring_rx_err();
+                        continue;
+                    }

-                        // SAFETY: By construction instances of SlotIdx are always in bounds.
-                        #[allow(unsafe_code)]
-                        let RxState {
-                            sender: maybe_sender,
-                            buf,
-                        } = unsafe { rx_state.get_unchecked_mut(slot.idx()) };
+                    let buf_id = io_uring::cqueue::buffer_select(cqe.flags()).unwrap();
+                    let (_, length, state) = config.rx_pool.get_buffer(buf_id as _);

-                        let mut buf = std::mem::replace(buf, BytesMut::with_capacity(mtu));
+                    tracing::debug!("recv {} bytes, saving to buf-id {}", result, buf_id);

-                        // SAFETY: We trust that the read operation
-                        // returns the correct number of bytes received.
- #[allow(unsafe_code)] - unsafe { - buf.advance_mut(res as _); - } - - if let Some(sender) = maybe_sender.take() { - let sender = sender.send(buf); - maybe_sender.replace(sender.reserve_owned().await?); - } else { - panic!("inflight rx state with no sender!"); - }; - } else if res != -libc::EAGAIN { - metrics::tun_iouring_rx_err(); - }; + length.store(result as usize, Ordering::Release); + state.store(true, Ordering::Release); // Mark as ready-for-user + config.rx_notify.notify_waiters(); - rx_slots.push(slot); + // TODO: consider below implementation in the future + // issue with this is that we have to gurentee no in-flight buffers ! + // see the comment under `recv` function, we can consider a buffer migration. + // NOTE: Here if we use new kernels we can auto-opt for multishot via: + // if !io_uring::cqueue::more(cqe.flags()) { + // let opt = ReadMulti::new(fd, buf_group).build().user_data(IOUringActionID::ReceivedBuffer); + // unsafe { sq.push(&opt) }; + // } } - SlotIdx::Tx(_) => { - if res <= 0 { - tracing::info!("rx slot {slot:?} completed with {res}"); + + idx => { + // TX completion + let result = cqe.result(); + if result < 0 { + tracing::error!( + "Send failed: {}", + std::io::Error::from_raw_os_error(-result) + ); + metrics::tun_iouring_tx_err(); } - // handle tx complete, we just need to drop the buffer - // SAFETY: By construction instances of SlotIdx are always in bounds. - #[allow(unsafe_code)] - unsafe { - *tx_bufs.get_unchecked_mut(slot.idx()) = None - }; - tx_slots.push(slot); + tracing::debug!("sent {} bytes from buf-id {}", result, idx); + let (_, _, state) = config.tx_pool.get_buffer(idx as _); + state.store(false, Ordering::Release); // mark as available for send } } } } } - -#[cfg(test)] -mod tests { - use super::*; - use test_case::test_case; - - #[test_case(SlotIdx::Tx(0) => 0x0000_0000_0000_0000)] - #[test_case(SlotIdx::Tx(10) => 0x0000_0000_0000_000a)] - #[test_case(SlotIdx::Tx(isize::MAX) => 0x7fff_ffff_ffff_ffff)] - #[test_case(SlotIdx::Rx(0) => 0x0000_0000_0000_0000)] - #[test_case(SlotIdx::Rx(10) => 0x0000_0000_0000_000a)] - #[test_case(SlotIdx::Rx(isize::MAX) => 0x7fff_ffff_ffff_ffff)] - fn slotid_idx(id: SlotIdx) -> usize { - id.idx() - } - - #[test_case(SlotIdx::Tx(0) => 0x0000_0000_0000_0000)] - #[test_case(SlotIdx::Tx(10) => 0x0000_0000_0000_000a)] - #[test_case(SlotIdx::Tx(isize::MAX) => 0x7fff_ffff_ffff_ffff)] - #[test_case(SlotIdx::Rx(0) => 0xffff_ffff_ffff_ffff)] - #[test_case(SlotIdx::Rx(10) => 0xffff_ffff_ffff_fff5)] - #[test_case(SlotIdx::Rx(isize::MAX) => 0x8000_0000_0000_0000)] - fn slotid_user_data(id: SlotIdx) -> u64 { - id.user_data() - } - - #[test_case(0x0000_0000_0000_0000 => matches SlotIdx::Tx(0))] - #[test_case(0x0000_0000_0000_000a => matches SlotIdx::Tx(10))] - #[test_case(0x7fff_ffff_ffff_ffff => matches SlotIdx::Tx(isize::MAX))] - #[test_case(0xffff_ffff_ffff_ffff => matches SlotIdx::Rx(0))] - #[test_case(0xffff_ffff_ffff_fff5 => matches SlotIdx::Rx(10))] - #[test_case(0x8000_0000_0000_0000 => matches SlotIdx::Rx(isize::MAX))] - fn slotid_from(u: u64) -> SlotIdx { - SlotIdx::from_user_data(u) - } -} diff --git a/lightway-app-utils/src/metrics.rs b/lightway-app-utils/src/metrics.rs index ce336daa..299d6f9c 100644 --- a/lightway-app-utils/src/metrics.rs +++ b/lightway-app-utils/src/metrics.rs @@ -1,9 +1,16 @@ use metrics::{Counter, counter}; use std::sync::LazyLock; +static METRIC_TUN_IOURING_TX_ERR: LazyLock = + LazyLock::new(|| counter!("tun_iouring_tx_err")); static METRIC_TUN_IOURING_RX_ERR: LazyLock = LazyLock::new(|| 
counter!("tun_iouring_rx_err")); +/// Count iouring TX entries which complete with an error +pub(crate) fn tun_iouring_tx_err() { + METRIC_TUN_IOURING_TX_ERR.increment(1) +} + /// Count iouring RX entries which complete with an error pub(crate) fn tun_iouring_rx_err() { METRIC_TUN_IOURING_RX_ERR.increment(1) diff --git a/lightway-app-utils/src/tun.rs b/lightway-app-utils/src/tun.rs index 2b57db8c..df54ce9f 100644 --- a/lightway-app-utils/src/tun.rs +++ b/lightway-app-utils/src/tun.rs @@ -168,8 +168,9 @@ impl TunIoUring { /// Recv from Tun pub async fn recv_buf(&self) -> IOCallbackResult { match self.tun_io_uring.recv().await { - Ok(pkt) => IOCallbackResult::Ok(pkt), - Err(e) => { + IOCallbackResult::Ok(pkt) => IOCallbackResult::Ok(pkt), + IOCallbackResult::WouldBlock => IOCallbackResult::WouldBlock, + IOCallbackResult::Err(e) => { use std::io::{Error, ErrorKind}; IOCallbackResult::Err(Error::new(ErrorKind::Other, e)) } diff --git a/lightway-core/Cargo.toml b/lightway-core/Cargo.toml index d4037be0..4c8cebee 100644 --- a/lightway-core/Cargo.toml +++ b/lightway-core/Cargo.toml @@ -29,6 +29,7 @@ metrics.workspace = true more-asserts.workspace = true num_enum = "0.7.0" once_cell = "1.19.0" +parking_lot = "0.12" pnet.workspace = true rand.workspace = true rand_core = "0.6.4" diff --git a/lightway-core/src/connection.rs b/lightway-core/src/connection.rs index e8c1d558..d84f2968 100644 --- a/lightway-core/src/connection.rs +++ b/lightway-core/src/connection.rs @@ -6,13 +6,14 @@ mod io_adapter; mod key_update; use bytes::{Bytes, BytesMut}; +use parking_lot::Mutex; use rand::Rng; use std::borrow::Cow; use std::net::AddrParseError; use std::num::{NonZeroU16, Wrapping}; use std::{ net::SocketAddr, - sync::{Arc, Mutex}, + sync::Arc, time::{Duration, Instant}, }; use thiserror::Error; @@ -936,7 +937,7 @@ impl Connection { ref mut pending_session_id, .. 
} => { - let new_session_id = rng.lock().unwrap().r#gen(); + let new_session_id = rng.lock().r#gen(); self.session.io_cb_mut().set_session_id(new_session_id); diff --git a/lightway-core/src/connection/builders.rs b/lightway-core/src/connection/builders.rs index 7a55338d..9f44d11d 100644 --- a/lightway-core/src/connection/builders.rs +++ b/lightway-core/src/connection/builders.rs @@ -282,7 +282,7 @@ impl<'a, AppState: Send + 'static> ServerConnectionBuilder<'a, AppState> { let auth = ctx.auth.clone(); let ip_pool = ctx.ip_pool.clone(); - let session_id = ctx.rng.lock().unwrap().r#gen(); + let session_id = ctx.rng.lock().r#gen(); let outside_mtu = MAX_OUTSIDE_MTU; let outside_plugins = ctx.outside_plugins.build()?; diff --git a/lightway-core/src/context.rs b/lightway-core/src/context.rs index 6dc21580..9b648227 100644 --- a/lightway-core/src/context.rs +++ b/lightway-core/src/context.rs @@ -1,8 +1,9 @@ pub mod ip_pool; mod server_auth; +use parking_lot::Mutex; use rand::SeedableRng; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use thiserror::Error; use crate::{ diff --git a/lightway-core/src/utils.rs b/lightway-core/src/utils.rs index 3215572f..7b4ada34 100644 --- a/lightway-core/src/utils.rs +++ b/lightway-core/src/utils.rs @@ -9,17 +9,58 @@ use std::net::Ipv4Addr; use std::ops; use tracing::warn; +// #[cfg(target_arch = "x86_64")] +// use std::arch::x86_64::*; + +// // Check if AVX2 is available on the current CPU (unused until we support IPv6) +// #[inline(always)] +// fn has_avx2() -> bool { +// #[cfg(target_arch = "x86_64")] +// { +// is_x86_feature_detected!("avx2") +// } +// #[cfg(not(target_arch = "x86_64"))] +// { +// false +// } +// } + +// HOT/COLD path implementation until RUST adds +// https://github.com/rust-lang/rust/issues/26179 + +#[inline] +#[cold] +fn cold() {} + +#[inline] +pub(crate) fn likely(b: bool) -> bool { + if !b { + cold() + } + b +} + +#[inline] +pub(crate) fn unlikely(b: bool) -> bool { + if b { + cold() + } + b +} + +/// Validate if a buffer contains a valid IPv4 packet pub(crate) fn ipv4_is_valid_packet(buf: &[u8]) -> bool { - if buf.is_empty() { + if buf.len() < 20 { + // IPv4 header is at least 20 bytes return false; } let first_byte = buf[0]; let ip_version = first_byte >> 4; - ip_version == 4 } -// Structure to calculate incremental checksum +/// Structure to calculate incremental checksum +#[derive(Clone, Copy)] struct Checksum(u16); impl ops::Deref for Checksum { @@ -33,122 +74,220 @@ impl ops::Sub for Checksum { type Output = Checksum; fn sub(self, rhs: u16) -> Checksum { let (n, of) = self.0.overflowing_sub(rhs); - Checksum(match of { - true => n - 1, - false => n, - }) + Checksum(if of { n.wrapping_sub(1) } else { n }) } } +/// Structure to handle checksum updates when modifying IP addresses +struct ChecksumUpdate(Vec<(u16, u16)>); + impl Checksum { - // Based on RFC-1624 [Eqn. 4] + /// Update checksum when replacing one word with another + /// Based on RFC-1624 [Eqn. 
4] fn update_word(self, old_word: u16, new_word: u16) -> Self { self - !old_word - new_word } + /// Apply multiple checksum updates fn update(self, updates: &ChecksumUpdate) -> Self { - updates.0.iter().fold(self, |c, x| c.update_word(x.0, x.1)) + updates + .0 + .iter() + .fold(self, |c, &(old, new)| c.update_word(old, new)) } -} -struct ChecksumUpdate(Vec<(u16, u16)>); + // AVX2-accelerated checksum update (unused until we support IPv6) + // #[allow(unsafe_code)] + // #[cfg(target_arch = "x86_64")] + // #[target_feature(enable = "avx2")] + // unsafe fn update_avx2(self, updates: &ChecksumUpdate) -> Self { + // let mut sum = u32::from(self.0); + + // // Process 8 words at a time using AVX2 + // for chunk in updates.0.chunks(8) { + // // Pre-allocate with known size + // let mut old_words = Vec::with_capacity(8); + // let mut new_words = Vec::with_capacity(8); + + // // Fill vectors with data or zeros + // for i in 0..8 { + // if let Some(&(old, new)) = chunk.get(i) { + // old_words.push(i32::from(old)); + // new_words.push(i32::from(new)); + // } else { + // old_words.push(0); + // new_words.push(0); + // } + // } + + // // SAFETY: Vectors are guaranteed to have exactly 8 elements + // unsafe { + // // Load data into AVX2 registers + // let old_vec = _mm256_set_epi32( + // old_words[7], + // old_words[6], + // old_words[5], + // old_words[4], + // old_words[3], + // old_words[2], + // old_words[1], + // old_words[0], + // ); + // let new_vec = _mm256_set_epi32( + // new_words[7], + // new_words[6], + // new_words[5], + // new_words[4], + // new_words[3], + // new_words[2], + // new_words[1], + // new_words[0], + // ); + + // // Compute NOT(old) + new using AVX2 + // let not_old = _mm256_xor_si256(old_vec, _mm256_set1_epi32(-1)); + // let sum_vec = _mm256_add_epi32(not_old, new_vec); + + // // Horizontal sum + // let hadd = _mm256_hadd_epi32(sum_vec, sum_vec); + // let hadd = _mm256_hadd_epi32(hadd, hadd); + + // sum = sum.wrapping_add(_mm256_extract_epi32(hadd, 0) as u32); + // } + // } + + // // Fold 32-bit sum to 16 bits + // while sum > 0xFFFF { + // sum = (sum & 0xFFFF) + (sum >> 16); + // } + + // Checksum(sum as u16) + // } +} impl ChecksumUpdate { + /// Create checksum update data from IP address change fn from_ipv4_address(old: Ipv4Addr, new: Ipv4Addr) -> Self { - let mut result = vec![]; - let old: [u8; 4] = old.octets(); - let new: [u8; 4] = new.octets(); - for i in 0..2 { - let old_word = u16::from_be_bytes([old[i * 2], old[i * 2 + 1]]); - let new_word = u16::from_be_bytes([new[i * 2], new[i * 2 + 1]]); - result.push((old_word, new_word)); - } - Self(result) + let old_bytes = old.octets(); + let new_bytes = new.octets(); + + // Convert to u16 pairs for checksum calculation + let old_words = [ + u16::from_be_bytes([old_bytes[0], old_bytes[1]]), + u16::from_be_bytes([old_bytes[2], old_bytes[3]]), + ]; + let new_words = [ + u16::from_be_bytes([new_bytes[0], new_bytes[1]]), + u16::from_be_bytes([new_bytes[2], new_bytes[3]]), + ]; + + Self(vec![ + (old_words[0], new_words[0]), + (old_words[1], new_words[1]), + ]) } } -fn tcp_adjust_packet_checksum(mut packet: MutableIpv4Packet, updates: ChecksumUpdate) { - let packet = MutableTcpPacket::new(packet.payload_mut()); - let Some(mut packet) = packet else { - warn!("Invalid packet size (less than Tcp header)!"); - return; - }; - - let checksum = Checksum(packet.get_checksum()); - let checksum = checksum.update(&updates); - packet.set_checksum(*checksum); -} - -fn udp_adjust_packet_checksum(mut packet: MutableIpv4Packet, updates: 
ChecksumUpdate) {
-    let packet = MutableUdpPacket::new(packet.payload_mut());
-    let Some(mut packet) = packet else {
-        warn!("Invalid packet size (less than Udp header)!");
-        return;
-    };
-
-    let checksum = Checksum(packet.get_checksum());
-
-    // UDP checksums are optional, and we should respect that when doing NAT
-    if *checksum != 0 {
-        let checksum = checksum.update(&updates);
-        packet.set_checksum(checksum.0);
-    }
-}
-
-fn ipv4_adjust_packet_checksum(mut packet: MutableIpv4Packet, updates: ChecksumUpdate) {
-    let checksum = Checksum(packet.get_checksum());
-    let checksum = checksum.update(&updates);
-    packet.set_checksum(*checksum);
-
-    // In case of fragmented packets, TCP/UDP header will be present only in the first fragment.
-    // So skip updating the checksum, if it is not the first fragment (i.e frag_offset != 0)
-    if 0 != packet.get_fragment_offset() {
-        return;
-    }
-
-    let transport_protocol = packet.get_next_level_protocol();
-    match transport_protocol {
-        IpNextHeaderProtocols::Tcp => tcp_adjust_packet_checksum(packet, updates),
-        IpNextHeaderProtocols::Udp => udp_adjust_packet_checksum(packet, updates),
-        IpNextHeaderProtocols::Icmp => {}
-        protocol => {
-            warn!(protocol = ?protocol, "Unknown protocol, skipping checksum adjust")
-        }
-    }
-}

+/// Update transport protocol checksums after IP address changes
+fn update_transport_checksums(packet: &mut MutableIpv4Packet, updates: ChecksumUpdate) {
+    // In fragmented packets the TCP/UDP header is only present in the first
+    // fragment, so skip if this is not the first fragment
+    if packet.get_fragment_offset() != 0 {
+        return;
+    }
+
+    match packet.get_next_level_protocol() {
+        IpNextHeaderProtocols::Tcp => update_tcp_checksum(packet, updates),
+        IpNextHeaderProtocols::Udp => update_udp_checksum(packet, updates),
+        IpNextHeaderProtocols::Icmp => {} // the ICMP checksum does not cover the IP addresses
+        protocol => {
+            if unlikely(true) {
+                warn!(protocol = ?protocol, "Unknown protocol, skipping checksum update");
+            }
+        }
+    }
+}
+
+fn update_tcp_checksum(packet: &mut MutableIpv4Packet, updates: ChecksumUpdate) {
+    let tcp_packet = MutableTcpPacket::new(packet.payload_mut());
+    if likely(tcp_packet.is_some()) {
+        let mut tcp_packet = tcp_packet.unwrap();
+        // TCP checksums are mandatory, so always apply the update
+        let checksum = Checksum(tcp_packet.get_checksum()).update(&updates);
+        tcp_packet.set_checksum(*checksum);
+    } else {
+        warn!("Invalid packet size (less than TCP header)!");
+    }
+}
+
+fn update_udp_checksum(packet: &mut MutableIpv4Packet, updates: ChecksumUpdate) {
+    let udp_packet = MutableUdpPacket::new(packet.payload_mut());
+    if likely(udp_packet.is_some()) {
+        let mut udp_packet = udp_packet.unwrap();
+        let checksum = udp_packet.get_checksum();
+        // UDP checksums are optional (zero means "no checksum"), and we should
+        // respect that when doing NAT
+        if checksum != 0 {
+            let checksum = Checksum(checksum).update(&updates);
+            udp_packet.set_checksum(*checksum);
+        }
+    } else {
+        warn!("Invalid packet size (less than UDP header)!");
+    }
+}
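+// Worked example for the incremental update above (illustrative numbers): with
+// a stored checksum HC = 0x1234 and one word changing from m = 0x5678 to
+// m' = 0x567A, RFC 1624 [Eqn. 4] computes HC' = HC - ~m - m' in one's
+// complement arithmetic: 0x1234 - 0xA987 wraps to 0x68AC, and
+// 0x68AC - 0x567A = 0x1232; raising a data word by 2 lowers the stored
+// checksum by 2.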
+ warn!("Failed to create IPv4 packet!"); + } return; }; - let old = packet.get_source(); - // Set new source only after getting old source ip address - packet.set_source(ip); - - ipv4_adjust_packet_checksum(packet, ChecksumUpdate::from_ipv4_address(old, ip)); -} + // Get old IP before updating + let old_ip = match field { + IpField::Source => packet.get_source(), + IpField::Destination => packet.get_destination(), + }; -/// Utility function to update destination ip address in ipv4 packet buffer -/// Nop if buf is not a valid IPv4 packet -pub fn ipv4_update_destination(buf: &mut [u8], ip: Ipv4Addr) { - let packet = MutableIpv4Packet::new(buf); - let Some(mut packet) = packet else { - warn!("Invalid packet size (less than Ipv4 header)!"); - return; + // Update IP field + match field { + IpField::Source => packet.set_source(new_ip), + IpField::Destination => packet.set_destination(new_ip), }; - let old = packet.get_destination(); - // Set new destination only after getting old destination ip address - packet.set_destination(ip); + // Update checksums + let updates = ChecksumUpdate::from_ipv4_address(old_ip, new_ip); + let checksum = packet.get_checksum(); + if checksum != 0 { + let checksum = Checksum(checksum).update(&updates); + packet.set_checksum(*checksum); + } + + // Update transport protocol checksums + update_transport_checksums(&mut packet, updates); +} + +/// Update source IP address in an IPv4 packet +#[inline] +pub fn ipv4_update_source(buf: &mut [u8], new_ip: Ipv4Addr) { + ipv4_update_field(buf, new_ip, IpField::Source) +} - ipv4_adjust_packet_checksum(packet, ChecksumUpdate::from_ipv4_address(old, ip)); +/// Update destination IP address in an IPv4 packet +#[inline] +pub fn ipv4_update_destination(buf: &mut [u8], new_ip: Ipv4Addr) { + ipv4_update_field(buf, new_ip, IpField::Destination) } +/// Clamp TCP MSS option if present in a TCP SYN packet pub fn tcp_clamp_mss(pkt: &mut [u8], mss: u16) -> Option { let mut ipv4_packet = MutableIpv4Packet::new(pkt)?; @@ -177,7 +316,7 @@ pub fn tcp_clamp_mss(pkt: &mut [u8], mss: u16) -> Option { } [bytes[0], bytes[1]] = mss.to_be_bytes(); - tcp_adjust_packet_checksum(ipv4_packet, ChecksumUpdate(vec![(existing_mss, mss)])); + update_tcp_checksum(&mut ipv4_packet, ChecksumUpdate(vec![(existing_mss, mss)])); return Some(existing_mss); } let start = std::cmp::min(option.packet_size(), option_raw.len()); @@ -251,8 +390,9 @@ mod tests { ]; #[test_case(&[] => false; "empty")] - #[test_case(&[0x40] => true; "v4")] - #[test_case(&[0x60] => false; "v6")] + #[test_case(&[0x40; 19] => false; "buffer too small")] + #[test_case(&[0x45; 20] => true; "minimum valid v4")] + #[test_case(&[0x60; 20] => false; "v6 header")] #[test_case(SOURCE_1_DEST_1 => true; "SOURCE_1_TO_DEST_1")] #[test_case(SOURCE_1_DEST_2 => true; "SOURCE_1_TO_DEST_2")] #[test_case(SOURCE_2_DEST_1 => true; "SOURCE_2_TO_DEST_1")] diff --git a/lightway-core/tests/connection.rs b/lightway-core/tests/connection.rs index d67845ab..54616674 100644 --- a/lightway-core/tests/connection.rs +++ b/lightway-core/tests/connection.rs @@ -377,7 +377,7 @@ async fn client( assert!(matches!(client.state(), State::Online)); assert!(message_sent); - assert_eq!(&buf[..], b"\x40Hello World!"); + assert_eq!(&buf[..], b"\x40Hello World!But Bigger"); let curve = client.current_curve().unwrap(); assert_eq!(curve, pqc.expected_curve()); @@ -429,7 +429,9 @@ async fn client( // work as the first byte too, but be more // explicit to avoid a confusing surprise for some // future developer). 
-                    let mut buf: BytesMut = BytesMut::from(&b"\x40Hello World!"[..]);
+                    // Hi, future developer here: the old 13-byte message is
+                    // smaller than a minimal IPv4 header, so the payload is
+                    // padded past the 20-byte minimum to keep passing the new
+                    // size validation.
+                    let mut buf: BytesMut = BytesMut::from(&b"\x40Hello World!But Bigger"[..]);
                     eprintln!("Sending message: {buf:?}");
                     client.inside_data_received(&mut buf).expect("Send my message");
                     message_sent = true;
diff --git a/lightway-server/Cargo.toml b/lightway-server/Cargo.toml
index 8b5bec3f..5b545d39 100644
--- a/lightway-server/Cargo.toml
+++ b/lightway-server/Cargo.toml
@@ -26,6 +26,7 @@ clap.workspace = true
 ctrlc.workspace = true
 delegate.workspace = true
 educe.workspace = true
+hashbrown = "0.15.2"
 ipnet.workspace = true
 jsonwebtoken = "9.3.0"
 libc.workspace = true
@@ -33,6 +34,7 @@ lightway-app-utils.workspace = true
 lightway-core = { workspace = true, features = ["postquantum"] }
 metrics.workspace = true
 metrics-util = "0.18.0"
+parking_lot = "0.12.3"
 pnet.workspace = true
 ppp = "2.2.0"
 pwhash = "1.0.0"
diff --git a/lightway-server/src/connection.rs b/lightway-server/src/connection.rs
index efe3b430..302dc9ee 100644
--- a/lightway-server/src/connection.rs
+++ b/lightway-server/src/connection.rs
@@ -1,8 +1,9 @@
 use bytes::BytesMut;
 use delegate::delegate;
+use parking_lot::Mutex;
 use std::{
     net::{Ipv4Addr, SocketAddr},
-    sync::{Arc, Mutex, Weak},
+    sync::{Arc, Weak},
 };
 use tracing::{trace, warn};
@@ -80,7 +81,6 @@
         conn.lw_conn
             .lock()
-            .unwrap()
             .app_state_mut()
             .conn
             .set(Arc::downgrade(&conn))
@@ -93,7 +93,7 @@
     }

     delegate! {
-        to self.lw_conn.lock().unwrap() {
+        to self.lw_conn.lock() {
             pub fn tls_protocol_version(&self) -> ProtocolVersion;
             pub fn connection_type(&self) -> ConnectionType;
             pub fn session_id(&self) -> SessionId;
@@ -145,7 +145,7 @@
     }

     pub fn begin_session_id_rotation(self: &Arc<Self>) {
-        let mut conn = self.lw_conn.lock().unwrap();
+        let mut conn = self.lw_conn.lock();

         // A rotation is already in flight, nothing to be done this
         // time.
@@ -171,13 +171,13 @@
     // Use this only during shutdown, after clearing all connections from
     // connection_manager
     pub fn lw_disconnect(self: Arc<Self>) -> ConnectionResult<()> {
-        self.lw_conn.lock().unwrap().disconnect()
+        self.lw_conn.lock().disconnect()
     }

     pub fn disconnect(&self) -> ConnectionResult<()> {
         metrics::connection_closed();
         self.manager.remove_connection(self);
-        self.lw_conn.lock().unwrap().disconnect()
+        self.lw_conn.lock().disconnect()
     }
 }
diff --git a/lightway-server/src/connection_manager.rs b/lightway-server/src/connection_manager.rs
index 9682b77c..30203baf 100644
--- a/lightway-server/src/connection_manager.rs
+++ b/lightway-server/src/connection_manager.rs
@@ -1,11 +1,12 @@
 mod connection_map;

 use delegate::delegate;
+use hashbrown::HashMap;
+use parking_lot::Mutex;
 use std::{
-    collections::HashMap,
     net::SocketAddr,
     sync::{
-        Arc, Mutex, Weak,
+        Arc, Weak,
         atomic::{AtomicUsize, Ordering},
     },
 };
@@ -252,7 +253,7 @@
     }

     pub(crate) fn pending_session_id_rotations_count(&self) -> usize {
-        self.pending_session_id_rotations.lock().unwrap().len()
+        self.pending_session_id_rotations.lock().len()
     }

     pub(crate) fn create_streaming_connection(
@@ -269,7 +270,7 @@
             outside_io,
         )?;
         // TODO: what if addr was already present?
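+        // NOTE: parking_lot's lock() cannot be poisoned and returns the guard
+        // directly, which is why the unwrap() below goes away.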
- self.connections.lock().unwrap().insert(&conn)?; + self.connections.lock().insert(&conn)?; Ok(conn) } @@ -302,7 +303,7 @@ impl ConnectionManager { where F: FnOnce() -> OutsideIOSendCallbackArg, { - match self.connections.lock().unwrap().lookup(addr, session_id) { + match self.connections.lock().lookup(addr, session_id) { connection_map::Entry::Occupied(c) => { if session_id == SessionId::EMPTY || c.session_id() == session_id { let update_peer_address = addr != c.peer_addr(); @@ -328,12 +329,7 @@ impl ConnectionManager { } connection_map::Entry::Vacant(_e) => { // Maybe this is a pending session rotation - if let Some(c) = self - .pending_session_id_rotations - .lock() - .unwrap() - .get(&session_id) - { + if let Some(c) = self.pending_session_id_rotations.lock().get(&session_id) { let update_peer_address = addr != c.peer_addr(); return Ok((c.clone(), update_peer_address)); @@ -349,19 +345,18 @@ impl ConnectionManager { self: &Arc, addr: SocketAddr, ) -> Option> { - self.connections.lock().unwrap().find_by(addr) + self.connections.lock().find_by(addr) } pub(crate) fn set_peer_addr(&self, conn: &Arc, new_addr: SocketAddr) { let old_addr = conn.set_peer_addr(new_addr); self.connections .lock() - .unwrap() .update_socketaddr_for_connection(old_addr, new_addr); } pub(crate) fn remove_connection(&self, conn: &Connection) { - self.connections.lock().unwrap().remove(conn) + self.connections.lock().remove(conn) } pub(crate) fn begin_session_id_rotation( @@ -371,7 +366,6 @@ impl ConnectionManager { ) { self.pending_session_id_rotations .lock() - .unwrap() .insert(new_session_id, conn.clone()); metrics::udp_session_rotation_begin(); @@ -383,13 +377,9 @@ impl ConnectionManager { old: SessionId, new: SessionId, ) { - self.pending_session_id_rotations - .lock() - .unwrap() - .remove(&new); + self.pending_session_id_rotations.lock().remove(&new); self.connections .lock() - .unwrap() .update_session_id_for_connection(old, new); metrics::udp_session_rotation_finalized(); @@ -398,7 +388,6 @@ impl ConnectionManager { pub(crate) fn online_connection_activity(&self) -> Vec { self.connections .lock() - .unwrap() .iter_connections() .filter_map(|c| match c.state() { State::Online => Some(c.activity()), @@ -411,7 +400,7 @@ impl ConnectionManager { fn evict_idle_connections(&self) { tracing::trace!("Aging connections"); - for conn in self.connections.lock().unwrap().iter_connections() { + for conn in self.connections.lock().iter_connections() { let age = conn.activity().last_outside_data_received.elapsed(); if age > CONNECTION_MAX_IDLE_AGE { tracing::info!(session = ?conn.session_id(), age = ?age, "Disconnecting idle connection"); @@ -431,7 +420,7 @@ impl ConnectionManager { fn evict_expired_connections(&self) { tracing::trace!("Expiring connections"); - for conn in self.connections.lock().unwrap().iter_connections() { + for conn in self.connections.lock().iter_connections() { let Ok(expired) = conn.authentication_expired() else { continue; }; @@ -449,7 +438,7 @@ impl ConnectionManager { } pub(crate) fn close_all_connections(&self) { - let connections = self.connections.lock().unwrap().remove_connections(); + let connections = self.connections.lock().remove_connections(); for conn in connections { let _ = conn.lw_disconnect(); } diff --git a/tests/e2e/docker-compose.yml b/tests/e2e/docker-compose.yml index 453a8349..7ad94165 100644 --- a/tests/e2e/docker-compose.yml +++ b/tests/e2e/docker-compose.yml @@ -24,6 +24,10 @@ services: stop_grace_period: 10s cap_add: - NET_ADMIN + ulimits: + memlock: + soft: -1 + hard: 
-1 devices: - "/dev/net/tun:/dev/net/tun" networks: @@ -54,6 +58,10 @@ services: - net.ipv4.conf.all.promote_secondaries=1 cap_add: - NET_ADMIN + ulimits: + memlock: + soft: -1 + hard: -1 devices: - "/dev/net/tun:/dev/net/tun" depends_on: