From 499b4f6695e83e49e521657bb50ebaa16b964feb Mon Sep 17 00:00:00 2001 From: Ian Campbell <52475242+xv-ian-c@users.noreply.github.com> Date: Thu, 5 Sep 2024 11:01:24 +0100 Subject: [PATCH 1/2] lightway-core: io: Add a CoW abstraction for sending buffer Depending on the specific implementation of the trait they may or may not need an owned version of the buffer. Since we have one already in the core in some paths we can expose that through the stack and avoid some extra allocations. In the UDP send path we can easily and cheaply freeze the `BytesMut` buffer to get an owned `Bytes` (with a little more overhead while `aggressive_send` is enabled). However in the TCP path the use of `SendBuffer` makes this harder (if not impossible) to achieve (cheaply at least). So add `CowBytes` which can contain either a `Bytes` or a `&[u8]` slice. Note that right now every consumer still uses the byte slice. --- lightway-client/src/io/outside/tcp.rs | 6 ++--- lightway-client/src/io/outside/udp.rs | 6 ++--- lightway-core/src/connection/io_adapter.rs | 16 ++++++----- lightway-core/src/io.rs | 31 ++++++++++++++++++++-- lightway-core/src/lib.rs | 3 ++- lightway-core/tests/connection.rs | 7 ++--- lightway-server/src/io/outside/tcp.rs | 6 ++--- lightway-server/src/io/outside/udp.rs | 8 +++--- 8 files changed, 58 insertions(+), 25 deletions(-) diff --git a/lightway-client/src/io/outside/tcp.rs b/lightway-client/src/io/outside/tcp.rs index 3ab8133e..204de44b 100644 --- a/lightway-client/src/io/outside/tcp.rs +++ b/lightway-client/src/io/outside/tcp.rs @@ -4,7 +4,7 @@ use std::{net::SocketAddr, sync::Arc}; use tokio::net::TcpStream; use super::OutsideIO; -use lightway_core::{IOCallbackResult, OutsideIOSendCallback, OutsideIOSendCallbackArg}; +use lightway_core::{CowBytes, IOCallbackResult, OutsideIOSendCallback, OutsideIOSendCallbackArg}; pub struct Tcp(tokio::net::TcpStream, SocketAddr); @@ -58,8 +58,8 @@ impl OutsideIO for Tcp { } impl OutsideIOSendCallback for Tcp { - fn send(&self, buf: 
&[u8]) -> IOCallbackResult { - match self.0.try_write(buf) { + fn send(&self, buf: CowBytes) -> IOCallbackResult { + match self.0.try_write(buf.as_bytes()) { Ok(nr) => IOCallbackResult::Ok(nr), Err(err) if matches!(err.kind(), std::io::ErrorKind::WouldBlock) => { IOCallbackResult::WouldBlock diff --git a/lightway-client/src/io/outside/udp.rs b/lightway-client/src/io/outside/udp.rs index 112059d9..69e34636 100644 --- a/lightway-client/src/io/outside/udp.rs +++ b/lightway-client/src/io/outside/udp.rs @@ -5,7 +5,7 @@ use tokio::net::UdpSocket; use super::OutsideIO; use lightway_app_utils::sockopt; -use lightway_core::{IOCallbackResult, OutsideIOSendCallback, OutsideIOSendCallbackArg}; +use lightway_core::{CowBytes, IOCallbackResult, OutsideIOSendCallback, OutsideIOSendCallbackArg}; pub struct Udp { sock: tokio::net::UdpSocket, @@ -67,8 +67,8 @@ impl OutsideIO for Udp { } impl OutsideIOSendCallback for Udp { - fn send(&self, buf: &[u8]) -> IOCallbackResult { - match self.sock.try_send_to(buf, self.peer_addr) { + fn send(&self, buf: CowBytes) -> IOCallbackResult { + match self.sock.try_send_to(buf.as_bytes(), self.peer_addr) { Ok(nr) => IOCallbackResult::Ok(nr), Err(err) if matches!(err.kind(), std::io::ErrorKind::WouldBlock) => { IOCallbackResult::WouldBlock diff --git a/lightway-core/src/connection/io_adapter.rs b/lightway-core/src/connection/io_adapter.rs index 7b2e22ca..03fc4b32 100644 --- a/lightway-core/src/connection/io_adapter.rs +++ b/lightway-core/src/connection/io_adapter.rs @@ -5,7 +5,8 @@ use more_asserts::*; use wolfssl::IOCallbackResult; use crate::{ - plugin::PluginList, wire, ConnectionType, OutsideIOSendCallbackArg, PluginResult, Version, + plugin::PluginList, wire, ConnectionType, CowBytes, OutsideIOSendCallbackArg, PluginResult, + Version, }; pub(crate) struct SendBuffer { @@ -164,26 +165,28 @@ impl WolfSSLIOAdapter { } } + let b = b.freeze(); + // Send header + buf. If we are in aggressive mode we send it // a total of three times. 
On any send error we return // immediately without the remaining tries, otherwise we // return the result of the final attempt. if self.aggressive_send { - match self.io.send(&b[..]) { + match self.io.send(CowBytes::Owned(b.clone())) { IOCallbackResult::Ok(_) => {} wb @ IOCallbackResult::WouldBlock => return wb, err @ IOCallbackResult::Err(_) => return err, } - match self.io.send(&b[..]) { + match self.io.send(CowBytes::Owned(b.clone())) { IOCallbackResult::Ok(_) => {} wb @ IOCallbackResult::WouldBlock => return wb, err @ IOCallbackResult::Err(_) => return err, } } - match self.io.send(&b[..]) { + match self.io.send(CowBytes::Owned(b)) { IOCallbackResult::Ok(n) => { // We've sent `n` bytes successfully out of // `wire::Header::WIRE_SIZE` + `b.len()` that we @@ -250,7 +253,7 @@ impl WolfSSLIOAdapter { debug_assert_le!(send_buffer.original_len(), buf.len()); } - match self.io.send(send_buffer.as_bytes()) { + match self.io.send(CowBytes::Borrowed(send_buffer.as_bytes())) { IOCallbackResult::Ok(n) if n == send_buffer.actual_len() => { // We've now sent everything we were originally // asked to, so signal completion of that original @@ -335,7 +338,8 @@ mod tests { } impl OutsideIOSendCallback for FakeOutsideIOSend { - fn send(&self, buf: &[u8]) -> IOCallbackResult { + fn send(&self, buf: CowBytes) -> IOCallbackResult { + let buf = buf.as_bytes(); let (fakes, sent) = &mut *self.0.lock().unwrap(); match fakes.pop_front() { Some(IOCallbackResult::Ok(n)) => { diff --git a/lightway-core/src/io.rs b/lightway-core/src/io.rs index 747a4162..42d5a335 100644 --- a/lightway-core/src/io.rs +++ b/lightway-core/src/io.rs @@ -1,6 +1,6 @@ use std::{net::SocketAddr, sync::Arc}; -use bytes::BytesMut; +use bytes::{Bytes, BytesMut}; use wolfssl::IOCallbackResult; /// Application provided callback used to send inside data. 
@@ -20,6 +20,33 @@ pub trait InsideIOSendCallback { /// Convenience type to use as function arguments pub type InsideIOSendCallbackArg = Arc + Send + Sync>; +/// A byte buffer to be sent, may be owned or borrowed. +pub enum CowBytes<'a> { + /// An owned buffer + Owned(Bytes), + /// A borrowed buffer + Borrowed(&'a [u8]), +} + +impl CowBytes<'_> { + /// Convert this buffer into an owned `Bytes`. Cheap if this + /// instance is `::Owned`, but copied if not. + pub fn into_owned(self) -> Bytes { + match self { + CowBytes::Owned(b) => b, + CowBytes::Borrowed(b) => Bytes::copy_from_slice(b), + } + } + + /// Gain access to the underlying byte buffer. + pub fn as_bytes(&self) -> &[u8] { + match self { + CowBytes::Owned(b) => b.as_ref(), + CowBytes::Borrowed(b) => b, + } + } +} + /// Application provided callback used to send outside data. pub trait OutsideIOSendCallback { /// Called when Lightway wishes to send some outside data @@ -30,7 +57,7 @@ pub trait OutsideIOSendCallback { /// [`IOCallbackResult::WouldBlock`]. /// /// This is the same method as [`wolfssl::IOCallbacks::send`].
- fn send(&self, buf: &[u8]) -> IOCallbackResult; + fn send(&self, buf: CowBytes) -> IOCallbackResult; /// Get the peer's [`SocketAddr`] fn peer_addr(&self) -> SocketAddr; diff --git a/lightway-core/src/lib.rs b/lightway-core/src/lib.rs index 3930ef20..d89dfa69 100644 --- a/lightway-core/src/lib.rs +++ b/lightway-core/src/lib.rs @@ -38,7 +38,8 @@ pub use context::{ ServerAuthArg, ServerAuthHandle, ServerAuthResult, ServerContext, ServerContextBuilder, }; pub use io::{ - InsideIOSendCallback, InsideIOSendCallbackArg, OutsideIOSendCallback, OutsideIOSendCallbackArg, + CowBytes, InsideIOSendCallback, InsideIOSendCallbackArg, OutsideIOSendCallback, + OutsideIOSendCallbackArg, }; pub use packet::OutsidePacket; pub use plugin::{ diff --git a/lightway-core/tests/connection.rs b/lightway-core/tests/connection.rs index 50ee400c..aca1ab4e 100644 --- a/lightway-core/tests/connection.rs +++ b/lightway-core/tests/connection.rs @@ -106,7 +106,8 @@ impl TestSock for TestDatagramSock { } impl OutsideIOSendCallback for TestDatagramSock { - fn send(&self, buf: &[u8]) -> IOCallbackResult { + fn send(&self, buf: CowBytes) -> IOCallbackResult { + let buf = buf.as_bytes(); match self.0.try_send(buf) { Ok(nr) => IOCallbackResult::Ok(nr), Err(err) if matches!(err.kind(), std::io::ErrorKind::WouldBlock) => { @@ -156,8 +157,8 @@ impl TestSock for TestStreamSock { } impl OutsideIOSendCallback for TestStreamSock { - fn send(&self, buf: &[u8]) -> IOCallbackResult { - match self.0.try_write(buf) { + fn send(&self, buf: CowBytes) -> IOCallbackResult { + match self.0.try_write(buf.as_bytes()) { Ok(nr) => IOCallbackResult::Ok(nr), Err(err) if matches!(err.kind(), std::io::ErrorKind::WouldBlock) => { IOCallbackResult::WouldBlock diff --git a/lightway-server/src/io/outside/tcp.rs b/lightway-server/src/io/outside/tcp.rs index 9d6fdb5f..d51bb338 100644 --- a/lightway-server/src/io/outside/tcp.rs +++ b/lightway-server/src/io/outside/tcp.rs @@ -4,7 +4,7 @@ use anyhow::{anyhow, Result}; use 
async_trait::async_trait; use bytes::BytesMut; use lightway_core::{ - ConnectionType, IOCallbackResult, OutsideIOSendCallback, OutsidePacket, Version, + ConnectionType, CowBytes, IOCallbackResult, OutsideIOSendCallback, OutsidePacket, Version, MAX_OUTSIDE_MTU, }; use socket2::SockRef; @@ -21,8 +21,8 @@ struct TcpStream { } impl OutsideIOSendCallback for TcpStream { - fn send(&self, buf: &[u8]) -> IOCallbackResult { - match self.sock.try_write(buf) { + fn send(&self, buf: CowBytes) -> IOCallbackResult { + match self.sock.try_write(buf.as_bytes()) { Ok(nr) => IOCallbackResult::Ok(nr), Err(err) if matches!(err.kind(), std::io::ErrorKind::WouldBlock) => { IOCallbackResult::WouldBlock diff --git a/lightway-server/src/io/outside/udp.rs b/lightway-server/src/io/outside/udp.rs index 2554108f..d58f2f20 100644 --- a/lightway-server/src/io/outside/udp.rs +++ b/lightway-server/src/io/outside/udp.rs @@ -11,8 +11,8 @@ use bytes::BytesMut; use bytesize::ByteSize; use lightway_app_utils::sockopt::socket_enable_pktinfo; use lightway_core::{ - ConnectionType, Header, IOCallbackResult, OutsideIOSendCallback, OutsidePacket, SessionId, - Version, MAX_OUTSIDE_MTU, + ConnectionType, CowBytes, Header, IOCallbackResult, OutsideIOSendCallback, OutsidePacket, + SessionId, Version, MAX_OUTSIDE_MTU, }; use socket2::{MaybeUninitSlice, MsgHdr, MsgHdrMut, SockAddr, SockRef}; use tokio::io::Interest; @@ -87,9 +87,9 @@ struct UdpSocket { } impl OutsideIOSendCallback for UdpSocket { - fn send(&self, buf: &[u8]) -> IOCallbackResult { + fn send(&self, buf: CowBytes) -> IOCallbackResult { let peer_addr = self.peer_addr.read().unwrap(); - send_to_socket(&self.sock, buf, &peer_addr.1, self.reply_pktinfo) + send_to_socket(&self.sock, buf.as_bytes(), &peer_addr.1, self.reply_pktinfo) } fn peer_addr(&self) -> SocketAddr { From 2cb8c0ef6efd7e063a1268a6a16de995b33b8665 Mon Sep 17 00:00:00 2001 From: Ian Campbell <52475242+xv-ian-c@users.noreply.github.com> Date: Tue, 15 Oct 2024 16:20:18 +0100 Subject: [PATCH 
2/2] lightway-server: Use i/o uring for all i/o, not just tun. This does not consistently improve performance but reduces CPU overheads (by around 50%-100% i.e. half to one core) under heavy traffic, while adding perhaps a few hundred Mbps to a speedtest.net download test and making negligible difference to the upload test. It also removes about 1ms from the latency in the same tests. Finally the STDEV across multiple test runs appears to be lower. This appears to be due to a combination of avoiding async runtime overheads, as well as removing various channels/queues in favour of a more direct model of interaction between the ring and the connections. As well as those benefits we are now able to reach the same level of performance with far fewer slots used for the TUN rx path, here we use 64 slots (by default) and reach the same performance as using 1024 previously. The way uring handles blocking vs async for tun devices seems to be non-optimal. In blocking mode things are very slow. In async mode more and more time is spent on bookkeeping and polling, as the number of slots is increased, plus a high level of EAGAIN results (due to a request timing out after multiple failed polls[^0]) which waste time requeueing. This is related to https://github.com/axboe/liburing/issues/886 and https://github.com/axboe/liburing/issues/239. For UDP/TCP sockets io uring behaves well with the socket in blocking mode which avoids processing lots of EAGAIN results. Tuning the slots for each I/O path is a bit of an art (more is definitely not always better) and the sweet spot varies depending on the I/O device, so provide various tunables instead of just splitting the ring evenly. With this there's no real reason to have a very large ring, it's the number of inflight requests which matters.
This is specific to the server since it relies on kernel features and correctness(/lack of bugs) which may not be upheld on an arbitrary client system (while it is assumed that server operators have more control over what they run). It is also not portable to non-Linux systems. It is known to work with Linux 6.1 (as found in Debian 12 AKA bookworm). Note that this kernel version contains a bug which causes the `iou-sqp-*` kernel thread to get stuck (unkillable) if the tun is in blocking mode, therefore an option is provided. Enabling that option on a kernel which contains [the fix][] allows equivalent performance with fewer slots on the ring. [^0]: When data becomes available _all_ requests are woken but only one will find data, the rest will see EAGAIN and after a certain number of such events I/O uring will propagate this back to userspace. [the fix]: https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/commit/?id=438b406055cd21105aad77db7938ee4720b09bee --- Cargo.lock | 2 + Cargo.toml | 2 + README.md | 2 +- lightway-app-utils/Cargo.toml | 5 +- lightway-app-utils/src/lib.rs | 3 + lightway-app-utils/src/net.rs | 179 ++++++ lightway-server/Cargo.toml | 5 +- lightway-server/src/args.rs | 60 +- lightway-server/src/io.rs | 278 ++++++++ lightway-server/src/io/ffi.rs | 53 ++ lightway-server/src/io/inside.rs | 12 +- lightway-server/src/io/inside/tun.rs | 251 +++++++- lightway-server/src/io/outside.rs | 8 +- lightway-server/src/io/outside/tcp.rs | 703 ++++++++++++++++----- lightway-server/src/io/outside/udp.rs | 388 ++++++++---- lightway-server/src/io/outside/udp/cmsg.rs | 8 +- lightway-server/src/io/tx.rs | 170 +++++ lightway-server/src/lib.rs | 207 ++++-- lightway-server/src/main.rs | 6 +- tests/Earthfile | 10 +- 20 files changed, 1906 insertions(+), 446 deletions(-) create mode 100644 lightway-app-utils/src/net.rs create mode 100644 lightway-server/src/io/ffi.rs create mode 100644 lightway-server/src/io/tx.rs diff --git a/Cargo.lock b/Cargo.lock index 
ac98d773..7d973fc3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1330,6 +1330,7 @@ dependencies = [ "ctrlc", "delegate", "educe", + "io-uring", "ipnet", "jsonwebtoken", "libc", @@ -1354,6 +1355,7 @@ dependencies = [ "tracing", "tracing-log", "tracing-subscriber", + "tun", "twelf", ] diff --git a/Cargo.toml b/Cargo.toml index 6afc7e2d..38760725 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ clap = { version = "4.4.7", features = ["derive"] } ctrlc = { version = "3.4.2", features = ["termination"] } delegate = "0.12.0" educe = { version = "0.6.0", default-features = false, features = ["Debug"] } +io-uring = "0.7.0" ipnet = { version = "2.8.0", features = ["serde"]} libc = "0.2.152" lightway-app-utils = { path = "./lightway-app-utils" } @@ -52,3 +53,4 @@ tokio-util = "0.7.10" tracing = "0.1.37" tracing-subscriber = "0.3.17" twelf = { version = "0.15.0", default-features = false, features = ["env", "clap", "yaml"]} +tun = { version = "0.7.1" } diff --git a/README.md b/README.md index 8c43cb02..84f7c081 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ Protocol and design documentation can be found in the Lightway rust implementation currently supports Linux OS. Both x86_64 and arm64 platforms are supported and built as part of CI. -Support for other platforms will be added soon. +Support for other client platforms will be added soon. 
## Development steps diff --git a/lightway-app-utils/Cargo.toml b/lightway-app-utils/Cargo.toml index 3a73a13d..c152beef 100644 --- a/lightway-app-utils/Cargo.toml +++ b/lightway-app-utils/Cargo.toml @@ -23,7 +23,7 @@ bytes.workspace = true clap.workspace = true fs-mistrust = { version = "0.8.0", default-features = false } humantime = "2.1.0" -io-uring = { version = "0.7.0", optional = true } +io-uring = { workspace = true, optional = true } ipnet.workspace = true libc.workspace = true lightway-core.workspace = true @@ -38,11 +38,12 @@ tokio-stream = { workspace = true, optional = true } tokio-util.workspace = true tracing.workspace = true tracing-subscriber = { workspace = true, features = ["json"] } -tun = { version = "0.7", features = ["async"] } +tun = { workspace = true, features = ["async"] } [[example]] name = "udprelay" path = "examples/udprelay.rs" +required-features = ["io-uring"] [dev-dependencies] async-trait.workspace = true diff --git a/lightway-app-utils/src/lib.rs b/lightway-app-utils/src/lib.rs index 4e48b6c1..5e5b0215 100644 --- a/lightway-app-utils/src/lib.rs +++ b/lightway-app-utils/src/lib.rs @@ -14,6 +14,9 @@ mod event_stream; mod iouring; mod tun; +mod net; +pub use net::{sockaddr_from_socket_addr, socket_addr_from_sockaddr}; + #[cfg(feature = "tokio")] pub use connection_ticker::{ connection_ticker_cb, ConnectionTicker, ConnectionTickerState, ConnectionTickerTask, Tickable, diff --git a/lightway-app-utils/src/net.rs b/lightway-app-utils/src/net.rs new file mode 100644 index 00000000..b306c6df --- /dev/null +++ b/lightway-app-utils/src/net.rs @@ -0,0 +1,179 @@ +use std::{io, net::SocketAddr}; + +/// Convert from `libc::sockaddr_storage` to `std::net::SocketAddr` +#[allow(unsafe_code)] +pub fn socket_addr_from_sockaddr( + storage: &libc::sockaddr_storage, + len: libc::socklen_t, +) -> io::Result { + match storage.ss_family as libc::c_int { + libc::AF_INET => { + if (len as usize) < std::mem::size_of::() { + return Err(io::Error::new( + 
io::ErrorKind::InvalidInput, + "invalid argument (inet len)", + )); + } + + // SAFETY: Casting from sockaddr_storage to sockaddr_in is safe since we have validated the len. + let addr = + unsafe { &*(storage as *const libc::sockaddr_storage as *const libc::sockaddr_in) }; + + let ip = u32::from_be(addr.sin_addr.s_addr); + let ip = std::net::Ipv4Addr::from_bits(ip); + let port = u16::from_be(addr.sin_port); + + Ok((ip, port).into()) + } + libc::AF_INET6 => { + if (len as usize) < std::mem::size_of::() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid argument (inet6 len)", + )); + } + // SAFETY: Casting from sockaddr_storage to sockaddr_in6 is safe since we have validated the len. + let addr = unsafe { + &*(storage as *const libc::sockaddr_storage as *const libc::sockaddr_in6) + }; + + let ip = u128::from_be_bytes(addr.sin6_addr.s6_addr); + let ip = std::net::Ipv6Addr::from_bits(ip); + let port = u16::from_be(addr.sin6_port); + + Ok((ip, port).into()) + } + _ => Err(io::Error::new( + std::io::ErrorKind::InvalidInput, + "invalid argument (ss_family)", + )), + } +} + +/// Convert from `std::net::SocketAddr` to `libc::sockaddr_storage`+`libc::socklen_t` +#[allow(unsafe_code)] +pub fn sockaddr_from_socket_addr(addr: SocketAddr) -> (libc::sockaddr_storage, libc::socklen_t) { + // SAFETY: All zeroes is a valid sockaddr_storage + let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() }; + + let len = match addr { + SocketAddr::V4(v4) => { + let p = &mut storage as *mut libc::sockaddr_storage as *mut libc::sockaddr_in; + // SAFETY: sockaddr_storage is defined to be big enough for any sockaddr_*. 
+ unsafe { + p.write(libc::sockaddr_in { + sin_family: libc::AF_INET as _, + sin_port: v4.port().to_be(), + sin_addr: libc::in_addr { + s_addr: v4.ip().to_bits().to_be(), + }, + sin_zero: Default::default(), + }) + }; + std::mem::size_of::() as libc::socklen_t + } + SocketAddr::V6(v6) => { + let p = &mut storage as *mut libc::sockaddr_storage as *mut libc::sockaddr_in6; + // SAFETY: sockaddr_storage is defined to be big enough for any sockaddr_*. + unsafe { + p.write(libc::sockaddr_in6 { + sin6_family: libc::AF_INET6 as _, + sin6_port: v6.port().to_be(), + sin6_flowinfo: v6.flowinfo().to_be(), + sin6_addr: libc::in6_addr { + s6_addr: v6.ip().to_bits().to_be_bytes(), + }, + sin6_scope_id: v6.scope_id().to_be(), + }) + }; + std::mem::size_of::() as libc::socklen_t + } + }; + + (storage, len) +} + +#[cfg(test)] +mod tests { + #![allow(unsafe_code, clippy::undocumented_unsafe_blocks)] + + use std::{ + net::{IpAddr, Ipv4Addr, Ipv6Addr}, + str::FromStr as _, + }; + + use super::*; + + use test_case::test_case; + + #[test] + fn socket_addr_from_sockaddr_unknown_af() { + // Test assumes these don't match the zero initialized + // libc::sockaddr_storage::ss_family. 
+ assert_ne!(libc::AF_INET, 0); + assert_ne!(libc::AF_INET6, 0); + + let storage = unsafe { std::mem::zeroed() }; + let err = + socket_addr_from_sockaddr(&storage, std::mem::size_of::() as _) + .unwrap_err(); + + assert!(matches!(err.kind(), std::io::ErrorKind::InvalidInput)); + assert!(err.to_string().contains("invalid argument (ss_family)")); + } + + #[test] + fn socket_addr_from_sockaddr_unknown_af_inet_short() { + let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() }; + storage.ss_family = libc::AF_INET as libc::sa_family_t; + + let err = socket_addr_from_sockaddr( + &storage, + (std::mem::size_of::() - 1) as _, + ) + .unwrap_err(); + + assert!(matches!(err.kind(), std::io::ErrorKind::InvalidInput)); + assert!(err.to_string().contains("invalid argument (inet len)")); + } + + #[test] + fn socket_addr_from_sockaddr_unknown_af_inet6_short() { + let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() }; + storage.ss_family = libc::AF_INET6 as libc::sa_family_t; + + let err = socket_addr_from_sockaddr( + &storage, + (std::mem::size_of::() - 1) as _, + ) + .unwrap_err(); + + assert!(matches!(err.kind(), std::io::ErrorKind::InvalidInput)); + assert!(err.to_string().contains("invalid argument (inet6 len)")); + } + + #[test] + fn sockaddr_from_socket_addr_inet() { + let socket_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080); + let (storage, len) = sockaddr_from_socket_addr(socket_addr); + assert_eq!(storage.ss_family, libc::AF_INET as libc::sa_family_t); + assert_eq!(len as usize, std::mem::size_of::()); + } + + #[test] + fn sockaddr_from_socket_addr_inet6() { + let socket_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), 8080); + let (storage, len) = sockaddr_from_socket_addr(socket_addr); + assert_eq!(storage.ss_family, libc::AF_INET6 as libc::sa_family_t); + assert_eq!(len as usize, std::mem::size_of::()); + } + + #[test_case("127.0.0.1:443")] + #[test_case("[::1]:8888")] + fn 
round_trip(addr: &str) { + let orig = SocketAddr::from_str(addr).unwrap(); + let (storage, len) = sockaddr_from_socket_addr(orig); + let round_tripped = socket_addr_from_sockaddr(&storage, len).unwrap(); + assert_eq!(orig, round_tripped) + } +} diff --git a/lightway-server/Cargo.toml b/lightway-server/Cargo.toml index 02274171..18756df8 100644 --- a/lightway-server/Cargo.toml +++ b/lightway-server/Cargo.toml @@ -9,9 +9,8 @@ license = "AGPL-3.0-only" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -default = ["io-uring"] +default = [] debug = ["lightway-core/debug"] -io-uring = ["lightway-app-utils/io-uring"] [lints] workspace = true @@ -26,6 +25,7 @@ clap.workspace = true ctrlc.workspace = true delegate.workspace = true educe.workspace = true +io-uring.workspace = true ipnet.workspace = true jsonwebtoken = "9.3.0" libc.workspace = true @@ -48,6 +48,7 @@ tokio-stream = { workspace = true, features = ["time"] } tracing.workspace = true tracing-log = "0.2.0" tracing-subscriber = { workspace = true, features = ["json"] } +tun.workspace = true twelf.workspace = true [dev-dependencies] diff --git a/lightway-server/src/args.rs b/lightway-server/src/args.rs index f4b87ad7..23786993 100644 --- a/lightway-server/src/args.rs +++ b/lightway-server/src/args.rs @@ -71,13 +71,25 @@ pub struct Config { #[clap(long, default_value_t)] pub enable_pqc: bool, - /// Enable IO-uring interface for Tunnel - #[clap(long, default_value_t)] - pub enable_tun_iouring: bool, - - /// IO-uring submission queue count. Only applicable when - /// `enable_tun_iouring` is `true` - // Any value more than 1024 negatively impact the throughput + /// Total IO-uring submission queue count. 
+ /// + /// Must be larger than the total of: + /// + /// UDP: + /// + /// iouring_tun_rx_count + iouring_udp_rx_count + + /// iouring_tx_count + 1 (cancellation request) + /// + /// TCP: + /// + /// iouring_tun_rx_count + iouring_tx_count + 1 (cancellation + /// request) + 2 * maximum number of connections. + /// + /// Each connection actually uses up to 3 slots, a persistent + /// recv request and on demand slots for TX and cancellation + /// (teardown). + /// + /// There is no downside to setting this much larger. #[clap(long, default_value_t = 1024)] pub iouring_entry_count: usize, @@ -87,6 +99,36 @@ pub struct Config { #[clap(long, default_value = "100ms")] pub iouring_sqpoll_idle_time: Duration, + /// Number of concurrent TUN device read requests to issue to + /// IO-uring. Setting this too large may negatively impact + /// performance. + #[clap(long, default_value_t = 64)] + pub iouring_tun_rx_count: u32, + + /// Configure TUN device in blocking mode. This can allow + /// equivalent performance with fewer `iouring-tun-rx-count` + /// entries but can significantly harm performance on some kernels + /// where the kernel does not indicate that the tun device handles + /// `FMODE_NOWAIT`. + /// + /// If blocking mode is enabled then `iouring_tun_rx_count` may be + /// set much lower. + /// + /// This was fixed by <https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/commit/?id=438b406055cd21105aad77db7938ee4720b09bee> + /// which was part of v6.4-rc1. + #[clap(long, default_value_t = false)] + pub iouring_tun_blocking: bool, + + /// Number of concurrent UDP socket recvmsg requests to issue to + /// IO-uring. + #[clap(long, default_value_t = 32)] + pub iouring_udp_rx_count: u32, + + /// Maximum number of concurrent UDP + TUN sendmsg/write requests + /// to issue to IO-uring.
+ #[clap(long, default_value_t = 512)] + pub iouring_tx_count: u32, + /// Log format #[clap(long, value_enum, default_value_t = LogFormat::Full)] pub log_format: LogFormat, @@ -111,6 +153,10 @@ pub struct Config { #[clap(long, default_value_t = ByteSize::mib(15))] pub udp_buffer_size: ByteSize, + /// Set TCP buffer size. Default value is 256 KiB. + #[clap(long, default_value_t = ByteSize::kib(256))] + pub tcp_buffer_size: ByteSize, + /// Enable WolfSSL debug logging #[cfg(feature = "debug")] #[clap(long)] diff --git a/lightway-server/src/io.rs b/lightway-server/src/io.rs index c32ee10e..36f6d332 100644 --- a/lightway-server/src/io.rs +++ b/lightway-server/src/io.rs @@ -1,2 +1,280 @@ pub(crate) mod inside; pub(crate) mod outside; + +mod ffi; +mod tx; + +use std::{ + os::fd::{AsRawFd, OwnedFd, RawFd}, + sync::{Arc, Mutex}, + time::Duration, +}; + +use anyhow::{anyhow, Context as _, Result}; +use io_uring::{ + cqueue::Entry as CEntry, + opcode, + squeue::Entry as SEntry, + types::{Fd, Fixed}, + Builder, IoUring, SubmissionQueue, Submitter, +}; + +use ffi::{iovec, msghdr}; +pub use tx::TxQueue; + +/// Convenience function to handle errors in a uring result codes +/// (which are negative errno codes). +fn io_uring_res(res: i32) -> std::io::Result { + if res < 0 { + Err(std::io::Error::from_raw_os_error(-res)) + } else { + Ok(res) + } +} + +/// An I/O source pushing requests to a uring instance +pub(crate) trait UringIoSource: Send { + /// Return the raw file descriptor. This will be registered as an + /// fd with the ring, allowing the use of io_uring::types::Fixed. + fn as_raw_fd(&self) -> RawFd; + + /// Push the initial set of requests to `sq`.
+ fn push_initial_ops(&mut self, sq: &mut io_uring::SubmissionQueue) -> Result<()>; + + /// Complete an rx request + fn complete_rx( + &mut self, + sq: &mut io_uring::SubmissionQueue, + cqe: io_uring::cqueue::Entry, + idx: u32, + ) -> Result<()>; + + /// Complete a tx request + fn complete_tx( + &mut self, + sq: &mut io_uring::SubmissionQueue, + cqe: io_uring::cqueue::Entry, + idx: u32, + ) -> Result<()>; +} + +pub(crate) enum OutsideIoSource { + Udp(outside::udp::UdpServer), + Tcp(outside::tcp::TcpServer), +} + +// Avoiding `dyn`amic dispatch is a small performance win. +impl UringIoSource for OutsideIoSource { + fn as_raw_fd(&self) -> RawFd { + match self { + OutsideIoSource::Udp(udp) => udp.as_raw_fd(), + OutsideIoSource::Tcp(tcp) => tcp.as_raw_fd(), + } + } + + fn push_initial_ops(&mut self, sq: &mut io_uring::SubmissionQueue) -> Result<()> { + match self { + OutsideIoSource::Udp(udp) => udp.push_initial_ops(sq), + OutsideIoSource::Tcp(tcp) => tcp.push_initial_ops(sq), + } + } + + fn complete_rx( + &mut self, + sq: &mut io_uring::SubmissionQueue, + cqe: io_uring::cqueue::Entry, + idx: u32, + ) -> Result<()> { + match self { + OutsideIoSource::Udp(udp) => udp.complete_rx(sq, cqe, idx), + OutsideIoSource::Tcp(tcp) => tcp.complete_rx(sq, cqe, idx), + } + } + + fn complete_tx( + &mut self, + sq: &mut io_uring::SubmissionQueue, + cqe: io_uring::cqueue::Entry, + idx: u32, + ) -> Result<()> { + match self { + OutsideIoSource::Udp(udp) => udp.complete_tx(sq, cqe, idx), + OutsideIoSource::Tcp(tcp) => tcp.complete_tx(sq, cqe, idx), + } + } +} + +pub(crate) struct Loop { + ring: IoUring, + + tx: Arc>, + + cancel_buf: u8, + + outside: OutsideIoSource, + inside: inside::tun::Tun, +} + +impl Loop { + /// Use for outside IO requests, `self.outside.as_raw_fd` will be registered in this slot. + const FIXED_OUTSIDE_FD: Fixed = Fixed(0); + /// Use for inside IO requests, `self.inside.as_raw_fd` will be registered in this slot. 
+ const FIXED_INSIDE_FD: Fixed = Fixed(1); + + /// Masks the bits used by `*_USER_DATA_BASE` + const USER_DATA_TYPE_MASK: u64 = 0xe000_0000_0000_0000; + + /// Indexes in this range will result in a call to `self.outside.complete_rx` + const OUTSIDE_RX_USER_DATA_BASE: u64 = 0xc000_0000_0000_0000; + /// Indexes in this range will result in a call to `self.outside.complete_tx` + const OUTSIDE_TX_USER_DATA_BASE: u64 = 0x8000_0000_0000_0000; + + /// Indexes in this range will result in a call to `self.inside.complete_rx` + const INSIDE_RX_USER_DATA_BASE: u64 = 0x4000_0000_0000_0000; + /// Indexes in this range will result in a call to `self.inside.complete_tx` + const INSIDE_TX_USER_DATA_BASE: u64 = 0x2000_0000_0000_0000; + + /// Indexes in this range are used by `Loop` itself. + const CONTROL_USER_DATA_BASE: u64 = 0x0000_0000_0000_0000; + + /// A read request on the cancellation fd (used to exit the io loop) + const CANCEL_USER_DATA: u64 = Self::CONTROL_USER_DATA_BASE + 1; + + /// Return user data for a particular outside rx index. + fn outside_rx_user_data(idx: u32) -> u64 { + Self::OUTSIDE_RX_USER_DATA_BASE + (idx as u64) + } + + /// Return user data for a particular inside rx index. + fn inside_rx_user_data(idx: u32) -> u64 { + Self::INSIDE_RX_USER_DATA_BASE + (idx as u64) + } + + /// Return user data for a particular inside tx index. + fn inside_tx_user_data(idx: u32) -> u64 { + Self::INSIDE_TX_USER_DATA_BASE + (idx as u64) + } + + /// Return user data for a particular outside tx index. 
+ fn outside_tx_user_data(idx: u32) -> u64 { + Self::OUTSIDE_TX_USER_DATA_BASE + (idx as u64) + } + + pub(crate) fn new( + ring_size: usize, + sqpoll_idle_time: Duration, + tx: Arc>, + outside: OutsideIoSource, + inside: inside::tun::Tun, + ) -> Result { + tracing::info!(ring_size, "creating IoUring"); + let mut builder: Builder = IoUring::builder(); + + builder.dontfork(); + + if sqpoll_idle_time.as_millis() > 0 { + let idle_time: u32 = sqpoll_idle_time + .as_millis() + .try_into() + .with_context(|| "invalid sqpoll idle time")?; + // This setting makes CPU go 100% when there is continuous traffic + builder.setup_sqpoll(idle_time); // Needs 5.13 + } + + let ring = builder + .build(ring_size as u32) + .inspect_err(|e| tracing::error!("iouring setup failed: {e}"))?; + + Ok(Self { + ring, + tx, + cancel_buf: 0, + outside, + inside, + }) + } + + pub(crate) fn run(mut self, cancel: OwnedFd) -> Result<()> { + let (submitter, mut sq, mut cq) = self.ring.split(); + + submitter.register_files(&[self.outside.as_raw_fd(), self.inside.as_raw_fd()])?; + + let sqe = opcode::Read::new( + Fd(cancel.as_raw_fd()), + &mut self.cancel_buf as *mut _, + std::mem::size_of_val(&self.cancel_buf) as _, + ) + .build() + .user_data(Self::CANCEL_USER_DATA); + + #[allow(unsafe_code)] + // SAFETY: The buffer is owned by `self.cancel_buf` and `self` is owned + unsafe { + sq.push(&sqe)? 
+ }; + + self.outside.push_initial_ops(&mut sq)?; + self.inside.push_initial_ops(&mut sq)?; + sq.sync(); + + loop { + let _ = submitter.submit_and_wait(1)?; + + cq.sync(); + + for cqe in &mut cq { + let user_data = cqe.user_data(); + + match user_data & Self::USER_DATA_TYPE_MASK { + Self::CONTROL_USER_DATA_BASE => { + match user_data - Self::CONTROL_USER_DATA_BASE { + Self::CANCEL_USER_DATA => { + let res = cqe.result(); + tracing::debug!(?res, "Uring cancelled"); + return Ok(()); + } + idx => { + return Err(anyhow!( + "Unknown control data {user_data:016x} => {idx:016x}" + )) + } + } + } + Self::OUTSIDE_RX_USER_DATA_BASE => { + self.outside.complete_rx( + &mut sq, + cqe, + (user_data - Self::OUTSIDE_RX_USER_DATA_BASE) as u32, + )?; + } + Self::OUTSIDE_TX_USER_DATA_BASE => { + self.outside.complete_tx( + &mut sq, + cqe, + (user_data - Self::OUTSIDE_TX_USER_DATA_BASE) as u32, + )?; + } + + Self::INSIDE_RX_USER_DATA_BASE => { + self.inside.complete_rx( + &mut sq, + cqe, + (user_data - Self::INSIDE_RX_USER_DATA_BASE) as u32, + )?; + } + Self::INSIDE_TX_USER_DATA_BASE => { + self.inside.complete_tx( + &mut sq, + cqe, + (user_data - Self::INSIDE_TX_USER_DATA_BASE) as u32, + )?; + } + + _ => unreachable!(), + } + + self.tx.lock().unwrap().drain(&submitter, &mut sq)?; + } + } + } +} diff --git a/lightway-server/src/io/ffi.rs b/lightway-server/src/io/ffi.rs new file mode 100644 index 00000000..ddd0ad42 --- /dev/null +++ b/lightway-server/src/io/ffi.rs @@ -0,0 +1,53 @@ +#![allow(unsafe_code)] +#![allow(non_camel_case_types, reason = "Using POSIX/libc naming")] + +/// Marker for types which are usable with syscalls +/// +/// # Safety +/// +/// Implement only for types containing raw pointers which are +/// passed to syscalls where the concept of Sync/Send is orthogonal to +/// Rust's model. 
+pub(super) unsafe trait IsSyscallSafe {} + +// SAFETY: iovec is used with syscalls +unsafe impl IsSyscallSafe for libc::iovec {} +// SAFETY: msghdr is used with syscalls +unsafe impl IsSyscallSafe for libc::msghdr {} + +pub(super) struct SyscallSafe(T); + +impl SyscallSafe { + pub fn new(t: T) -> Self { + Self(t) + } + + pub fn as_mut_ptr(&mut self) -> *mut T { + &mut self.0 as *mut T + } +} + +impl std::ops::Deref for SyscallSafe { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for SyscallSafe { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +// SAFETY: T must be e.g. a libc type which contains raw pointers for syscall use. +// The `pub` aliases below all satisfy this. +unsafe impl Send for SyscallSafe {} + +// SAFETY: T must be e.g. a libc type which contains raw pointers for syscall use. +// The `pub` aliases below all satisfy this. +unsafe impl Sync for SyscallSafe {} + +pub type iovec = SyscallSafe; +pub type msghdr = SyscallSafe; diff --git a/lightway-server/src/io/inside.rs b/lightway-server/src/io/inside.rs index decf5c57..430cbcab 100644 --- a/lightway-server/src/io/inside.rs +++ b/lightway-server/src/io/inside.rs @@ -2,14 +2,4 @@ pub(crate) mod tun; pub(crate) use tun::Tun; -use crate::connection::ConnectionState; -use async_trait::async_trait; -use lightway_core::{IOCallbackResult, InsideIOSendCallbackArg}; -use std::sync::Arc; - -#[async_trait] -pub(crate) trait InsideIO: Sync + Send { - async fn recv_buf(&self) -> IOCallbackResult; - - fn into_io_send_callback(self: Arc) -> InsideIOSendCallbackArg; -} +use super::{io_uring_res, Loop, TxQueue, UringIoSource}; diff --git a/lightway-server/src/io/inside/tun.rs b/lightway-server/src/io/inside/tun.rs index 909b8355..ef96a7da 100644 --- a/lightway-server/src/io/inside/tun.rs +++ b/lightway-server/src/io/inside/tun.rs @@ -1,56 +1,213 @@ -use crate::{io::inside::InsideIO, metrics}; +//! Tun UringIoSource +//! +//! 
Uses uring indexes: +//! +//! Loop::inside_rx_user_data: +//! - 0..Tun::rx.len(): A set of recv requests +//! +//! Loop::inside_tx_user_data: +//! - Managed by TxQueue + +use crate::ip_manager::IpManager; +use crate::metrics; + +use super::{io_uring_res, Loop, TxQueue, UringIoSource}; use crate::connection::ConnectionState; -use anyhow::Result; -use async_trait::async_trait; + +use anyhow::{Context as _, Result}; use bytes::BytesMut; -use lightway_app_utils::{Tun as AppUtilsTun, TunConfig}; +use io_uring::opcode; use lightway_core::{ - ipv4_update_source, IOCallbackResult, InsideIOSendCallback, InsideIOSendCallbackArg, + ipv4_update_destination, ipv4_update_source, ConnectionError, IOCallbackResult, + InsideIOSendCallback, InsideIOSendCallbackArg, }; -use std::os::fd::{AsRawFd, RawFd}; -use std::sync::Arc; -use std::time::Duration; +use pnet::packet::ipv4::Ipv4Packet; +use std::net::Ipv4Addr; +use std::os::fd::{AsRawFd as _, RawFd}; +use std::sync::{Arc, Mutex}; +use tun::{AbstractDevice as _, Configuration as TunConfig, Device as TunDevice}; + +pub(crate) struct Tun { + tun: TunDevice, + lightway_client_ip: Ipv4Addr, + ip_manager: Arc, -pub(crate) struct Tun(AppUtilsTun); + tx_queue: Arc>, + + mtu: usize, + + rx: Vec, +} impl Tun { - pub async fn new(tun: TunConfig, iouring: Option<(usize, Duration)>) -> Result { - let tun = match iouring { - Some((ring_size, sqpoll_idle_time)) => { - AppUtilsTun::iouring(tun, ring_size, sqpoll_idle_time).await? - } - None => AppUtilsTun::direct(tun).await?, - }; - Ok(Tun(tun)) + pub fn new( + nr_slots: u32, + blocking: bool, + mut tun: TunConfig, + lightway_client_ip: Ipv4Addr, + ip_manager: Arc, + tx_queue: Arc>, + ) -> Result { + tracing::info!("Tun with {nr_slots} slots (blocking: {blocking})"); + + tun.platform_config(|cfg| { + cfg.napi(true); + }); + + let tun = tun::create(&tun)?; + if !blocking { + tun.set_nonblock()?; + } + + let mtu = tun.mtu()? 
as usize; + + let rx = (0..nr_slots).map(|_| BytesMut::new()).collect(); + + Ok(Tun { + tun, + lightway_client_ip, + ip_manager, + tx_queue, + mtu, + rx, + }) + } + + pub fn inside_io_sender(&self) -> InsideIOSendCallbackArg { + Arc::new(TunInsideIO::new(self.tx_queue.clone(), self)) + } + + fn push_rx(&mut self, sq: &mut io_uring::SubmissionQueue, idx: u32) -> Result<()> { + let buf = &mut self.rx[idx as usize]; + + // Recover full capacity + buf.clear(); + buf.reserve(self.mtu); + + let sqe = opcode::Read::new( + Loop::FIXED_INSIDE_FD, + buf.as_mut_ptr() as *mut _, + buf.capacity() as _, + ) + .build() + .user_data(Loop::inside_rx_user_data(idx)); + + #[allow(unsafe_code)] + // SAFETY: The buffer is owned by `self.rx` and `self` is owned by the `io::Loop` + unsafe { + sq.push(&sqe)?; + } + + sq.sync(); + + Ok(()) } } -impl AsRawFd for Tun { +impl UringIoSource for Tun { fn as_raw_fd(&self) -> RawFd { - self.0.as_raw_fd() + self.tun.as_raw_fd() } -} -#[async_trait] -impl InsideIO for Tun { - async fn recv_buf(&self) -> IOCallbackResult { - match self.0.recv_buf().await { - IOCallbackResult::Ok(buf) => { - metrics::tun_to_client(buf.len()); - IOCallbackResult::Ok(buf) + fn push_initial_ops(&mut self, sq: &mut io_uring::SubmissionQueue) -> Result<()> { + for idx in 0..self.rx.len() as u32 { + self.push_rx(sq, idx)? 
+ } + Ok(()) + } + + fn complete_rx( + &mut self, + sq: &mut io_uring::SubmissionQueue, + cqe: io_uring::cqueue::Entry, + idx: u32, + ) -> Result<()> { + let res = match io_uring_res(cqe.result()) { + Ok(res) => res, + Err(err) if matches!(err.kind(), std::io::ErrorKind::WouldBlock) => { + self.push_rx(sq, idx)?; + return Ok(()); } - e => e, + Err(err) => return Err(err).with_context(|| "inside read completion"), + }; + + let buf = &mut self.rx[idx as usize]; + + metrics::tun_to_client(res as usize); + + #[allow(unsafe_code)] + // SAFETY: We rely on recv_from giving us the correct size + unsafe { + buf.set_len(res as usize); } + + // Find connection based on client ip (dest ip) and forward packet + let packet = Ipv4Packet::new(buf.as_ref()); + let Some(packet) = packet else { + eprintln!("Invalid inside packet size (less than Ipv4 header)!"); + // Queue another recv + self.push_rx(sq, idx)?; + return Ok(()); + }; + let conn = self.ip_manager.find_connection(packet.get_destination()); + + // Update destination IP address to client's ip + ipv4_update_destination(buf.as_mut(), self.lightway_client_ip); + + if let Some(conn) = conn { + match conn.inside_data_received(buf) { + Ok(()) => {} + Err(ConnectionError::InvalidState) => { + // Skip forwarding packet when offline + metrics::tun_rejected_packet_invalid_state(); + } + Err(ConnectionError::InvalidInsidePacket(_)) => { + // Skip processing invalid packet + metrics::tun_rejected_packet_invalid_inside_packet(); + } + Err(err) => { + let fatal = err.is_fatal(conn.connection_type()); + metrics::tun_rejected_packet_invalid_other(fatal); + if fatal { + conn.handle_end_of_stream(); + return Ok(()); + } + } + } + } else { + metrics::tun_rejected_packet_no_connection(); + }; + + // Queue another recv + self.push_rx(sq, idx)?; + + Ok(()) } - fn into_io_send_callback(self: Arc) -> InsideIOSendCallbackArg { - self + fn complete_tx( + &mut self, + _sq: &mut io_uring::SubmissionQueue, + cqe: io_uring::cqueue::Entry, + idx: u32, + 
) -> Result<()> { + let _ = self.tx_queue.lock().unwrap().complete(cqe, idx); + Ok(()) } } -impl InsideIOSendCallback for Tun { +pub(crate) struct TunInsideIO(Arc>, usize); + +impl TunInsideIO { + pub(crate) fn new(queue: Arc>, tun: &Tun) -> Self { + Self(queue, tun.mtu) + } +} + +impl InsideIOSendCallback for TunInsideIO { fn send(&self, mut buf: BytesMut, state: &mut ConnectionState) -> IOCallbackResult { + let len = buf.len(); + let Some(client_ip) = state.internal_ip else { metrics::tun_rejected_packet_no_client_ip(); // Ip address not found, dropping the packet @@ -58,11 +215,37 @@ impl InsideIOSendCallback for Tun { }; ipv4_update_source(buf.as_mut(), client_ip); - metrics::tun_from_client(buf.len()); - self.0.try_send(buf) + metrics::tun_from_client(len); + + let buf = buf.freeze(); + + let mut tx_queue = self.0.lock().unwrap(); + + let Some((slot, state)) = tx_queue.take_slot() else { + return IOCallbackResult::WouldBlock; + }; + + let sqe = opcode::Write::new( + Loop::FIXED_INSIDE_FD, + buf.as_ptr() as *mut _, + buf.len() as _, + ) + .build(); + + state.buf = Some(buf); + + #[allow(unsafe_code)] + // SAFETY: + // - slot was obtained from take_slot above + // - The buffer is owned by `state`, which is owned by the `TxRing` + unsafe { + tx_queue.push_inside_slot(slot, sqe) + }; + + IOCallbackResult::Ok(len) } fn mtu(&self) -> usize { - self.0.mtu() + self.1 } } diff --git a/lightway-server/src/io/outside.rs b/lightway-server/src/io/outside.rs index e233a80f..f4168d82 100644 --- a/lightway-server/src/io/outside.rs +++ b/lightway-server/src/io/outside.rs @@ -4,10 +4,4 @@ pub(crate) mod udp; pub(crate) use tcp::TcpServer; pub(crate) use udp::UdpServer; -use anyhow::Result; -use async_trait::async_trait; - -#[async_trait] -pub(crate) trait Server { - async fn run(&mut self) -> Result<()>; -} +use super::{io_uring_res, iovec, msghdr, Loop, TxQueue, UringIoSource}; diff --git a/lightway-server/src/io/outside/tcp.rs b/lightway-server/src/io/outside/tcp.rs index 
d51bb338..0994fea8 100644 --- a/lightway-server/src/io/outside/tcp.rs +++ b/lightway-server/src/io/outside/tcp.rs @@ -1,231 +1,596 @@ -use std::{net::SocketAddr, sync::Arc}; +//! TcpServer UringIoSource +//! +//! Uses uring indexes: +//! +//! Loop::outside_rx_user_data: +//! - TcpServer::ACCEPT_IDX: +//! The accept request. +//! - The fd for a connection (positive i32): +//! The RX request for that connection. +//! - The fd for a connection (positive i32) + TcpServer::RX_CANCEL_IDX_BIT: +//! A cancellation request for that connection +//! +//! Loop::outside_tx_user_data: +//! - The fd for a connection (positive i32): +//! The TX request for that connection. -use anyhow::{anyhow, Result}; -use async_trait::async_trait; +use std::{ + collections::HashMap, + net::{SocketAddr, TcpStream}, + os::fd::{AsRawFd, FromRawFd as _, RawFd}, + sync::{Arc, Mutex}, +}; + +use anyhow::{anyhow, Context as _, Result}; use bytes::BytesMut; +use bytesize::ByteSize; +use io_uring::{ + opcode, + types::{CancelBuilder, Fd}, +}; +use lightway_app_utils::socket_addr_from_sockaddr; use lightway_core::{ ConnectionType, CowBytes, IOCallbackResult, OutsideIOSendCallback, OutsidePacket, Version, - MAX_OUTSIDE_MTU, }; -use socket2::SockRef; -use tokio::io::AsyncReadExt as _; -use tracing::{debug, info, instrument, warn}; +use tracing::{debug, info, warn}; -use crate::{connection_manager::ConnectionManager, metrics}; +use crate::{connection::Connection, connection_manager::ConnectionManager, metrics}; -use super::Server; +use super::{io_uring_res, Loop, TxQueue, UringIoSource}; -struct TcpStream { - sock: Arc, - peer_addr: SocketAddr, +enum ConnectionPhase { + ProxyInitial { + local_addr: SocketAddr, + }, + Proxy { + local_addr: SocketAddr, + rest: usize, + }, + Connected { + conn: Arc, + buffer: Arc>, + }, } -impl OutsideIOSendCallback for TcpStream { - fn send(&self, buf: CowBytes) -> IOCallbackResult { - match self.sock.try_write(buf.as_bytes()) { - Ok(nr) => IOCallbackResult::Ok(nr), - 
Err(err) if matches!(err.kind(), std::io::ErrorKind::WouldBlock) => { - IOCallbackResult::WouldBlock +struct ConnectionState { + sock: TcpStream, + rx_buf: BytesMut, + tx_buffer_size: usize, + phase: ConnectionPhase, +} + +impl ConnectionState { + const RX_BUFFER_SIZE: usize = 15 * 1024 * 1024; // 15M + + fn push_rx(&mut self, sq: &mut io_uring::SubmissionQueue) -> Result<()> { + use ConnectionPhase::*; + let (buf, len) = match &mut self.phase { + ProxyInitial { .. } => (self.rx_buf.as_mut_ptr(), 16), + Proxy { rest, .. } => (self.rx_buf[16..].as_mut_ptr(), *rest), + Connected { .. } => { + // Recover full capacity + self.rx_buf.clear(); + self.rx_buf.reserve(Self::RX_BUFFER_SIZE); + (self.rx_buf.as_mut_ptr(), self.rx_buf.capacity()) } - Err(err) => IOCallbackResult::Err(err), - } - } + }; + let fd = self.sock.as_raw_fd(); - fn peer_addr(&self) -> SocketAddr { - self.peer_addr + let sqe = opcode::Recv::new(Fd(fd), buf, len as _) + .build() + .user_data(Loop::outside_rx_user_data(fd as u32)); + + #[allow(unsafe_code)] + // SAFETY: The buffer is owned by `self` and `self` is owned by `TcpServer::fd_map` + unsafe { + sq.push(&sqe)? + }; + + sq.sync(); + + Ok(()) } -} -async fn handle_proxy_protocol(sock: &mut tokio::net::TcpStream) -> Result { - use ppp::v2::{Header, ParseError}; + fn push_cancel(&mut self, sq: &mut io_uring::SubmissionQueue) -> Result<()> { + let fd = self.sock.as_raw_fd(); + info!(fd, "Cancelling"); + let builder = CancelBuilder::fd(Fd(fd)).all(); + let sqe = opcode::AsyncCancel2::new(builder) + .build() + .user_data(Loop::outside_rx_user_data( + fd as u32 + TcpServer::RX_CANCEL_IDX_BIT, + )); - // https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt §2.2 - const MINIMUM_LENGTH: usize = 16; + #[allow(unsafe_code)] + // SAFETY: The cancel sqe is well formed above + unsafe { + sq.push(&sqe)? 
+ }; - let mut header: Vec = [0; MINIMUM_LENGTH].into(); - if let Err(err) = sock.read_exact(&mut header[..MINIMUM_LENGTH]).await { - return Err(anyhow!(err).context("Failed to read initial PROXY header")); - }; - let rest = match Header::try_from(&header[..]) { - // Failure tells us exactly how many more bytes are required. - Err(ParseError::Partial(_, rest)) => rest, + sq.sync(); - Ok(_) => { - // The initial 16 bytes is never enough to actually succeed. - return Err(anyhow!("Unexpectedly parsed initial PROXY header")); - } - Err(err) => { - return Err(anyhow!(err).context("Failed to parse initial PROXY header")); + Ok(()) + } + + fn complete_tx( + &mut self, + sq: &mut io_uring::SubmissionQueue, + cqe: io_uring::cqueue::Entry, + ) -> Result<()> { + match &mut self.phase { + // Nothing to do for either of these cases. + ConnectionPhase::ProxyInitial { .. } | ConnectionPhase::Proxy { .. } => Ok(()), + ConnectionPhase::Connected { + conn: _conn, + buffer, + } => buffer.lock().unwrap().complete_tx(sq, cqe), } - }; + } - header.resize(MINIMUM_LENGTH + rest, 0); + fn complete_rx( + &mut self, + sq: &mut io_uring::SubmissionQueue, + cqe: io_uring::cqueue::Entry, + tx_queue: &Arc>, + conn_manager: &Arc, + ) -> Result<()> { + use ppp::v2::{Header, ParseError}; + use ConnectionPhase::*; - if let Err(err) = sock.read_exact(&mut header[MINIMUM_LENGTH..]).await { - return Err(anyhow!(err).context("Failed to read remainder of PROXY header")); - }; + let res = io_uring_res(cqe.result()).with_context(|| "outside recv completion")?; - let header = match Header::try_from(&header[..]) { - Ok(h) => h, - Err(err) => { - return Err(anyhow!(err).context("Failed to parse complete PROXY header")); - } - }; + match &mut self.phase { + ProxyInitial { local_addr } => { + assert!(16 == res); + #[allow(unsafe_code)] + // SAFETY: We rely on recv_from giving us the correct size + unsafe { + self.rx_buf.set_len(res as usize); + } - let addr = match header.addresses { - 
ppp::v2::Addresses::Unspecified => { - return Err(anyhow!("Unspecified PROXY connection")); - } - ppp::v2::Addresses::IPv4(addr) => { - SocketAddr::new(addr.source_address.into(), addr.source_port) - } - ppp::v2::Addresses::IPv6(_) => { - return Err(anyhow!("IPv6 PROXY connection")); - } - ppp::v2::Addresses::Unix(_) => { - return Err(anyhow!("Unix PROXY connection")); - } - }; - Ok(addr) -} + let rest = match Header::try_from(&self.rx_buf[..]) { + // Failure tells us exactly how many more bytes are required. + Err(ParseError::Partial(_, rest)) => rest, -#[instrument(level = "trace", skip_all)] -async fn handle_connection( - mut sock: tokio::net::TcpStream, - mut peer_addr: SocketAddr, - local_addr: SocketAddr, - conn_manager: Arc, - proxy_protocol: bool, -) { - if proxy_protocol { - peer_addr = match handle_proxy_protocol(&mut sock).await { - Ok(real_addr) => real_addr, - Err(err) => { - debug!(?err, "Failed to process PROXY header"); - metrics::connection_accept_proxy_header_failed(); - return; + Ok(_) => { + // The initial 16 bytes is never enough to actually succeed. + return Err(anyhow!("Unexpectedly parsed initial PROXY header")); + } + Err(err) => { + return Err(anyhow!(err).context("Failed to parse initial PROXY header")); + } + }; + + self.phase = Proxy { + local_addr: *local_addr, + rest, + } } - }; - } + Proxy { local_addr, rest } => { + assert!(*rest == res as usize); + #[allow(unsafe_code)] + // SAFETY: We rely on recv_from giving us the correct size + // We read 16 bytes in state ProxyInitial + unsafe { + self.rx_buf.set_len((res + 16) as usize); + } + let header = match Header::try_from(&self.rx_buf[..]) { + Ok(h) => h, + Err(err) => { + return Err(anyhow!(err).context("Failed to parse complete PROXY header")); + } + }; - let sock = Arc::new(sock); - - let outside_io = Arc::new(TcpStream { - sock: sock.clone(), - peer_addr, - }); - // TCP has no version indication, default to the minimum - // supported version. 
- let Ok(conn) = - conn_manager.create_streaming_connection(Version::MINIMUM, local_addr, outside_io) - else { - return; - }; - - // We no longer need to hold this reference. - drop(conn_manager); - - let mut buf = BytesMut::with_capacity(MAX_OUTSIDE_MTU); - let err: anyhow::Error = loop { - // Recover full capacity - buf.clear(); - buf.reserve(MAX_OUTSIDE_MTU); - if let Err(e) = sock.readable().await { - break anyhow!(e).context("Sock readable error"); - } + let peer_addr = match header.addresses { + ppp::v2::Addresses::Unspecified => { + return Err(anyhow!("Unspecified PROXY connection")); + } + ppp::v2::Addresses::IPv4(addr) => { + SocketAddr::new(addr.source_address.into(), addr.source_port) + } + ppp::v2::Addresses::IPv6(_) => { + return Err(anyhow!("IPv6 PROXY connection")); + } + ppp::v2::Addresses::Unix(_) => { + return Err(anyhow!("Unix PROXY connection")); + } + }; - match sock.try_read_buf(&mut buf) { - Ok(0) => { - // EOF - break anyhow!("End of stream"); + let buffer = + TcpSocketBuffer::new(tx_queue.clone(), self.tx_buffer_size, &self.sock); + let outside_io = Arc::new(TcpSocket { + buffer: buffer.clone(), + peer_addr, + }); + let conn = conn_manager.create_streaming_connection( + Version::MINIMUM, + *local_addr, + outside_io, + )?; + self.phase = ConnectionPhase::Connected { conn, buffer } } - Ok(_nr) => {} - Err(err) if matches!(err.kind(), std::io::ErrorKind::WouldBlock) => { - // Spuriously failed to read, keep waiting - continue; + Connected { + conn, + buffer: _buffer, + } => { + if res == 0 { + // EOF + conn.handle_end_of_stream(); + return Err(anyhow!("End of stream")); + } + + #[allow(unsafe_code)] + // SAFETY: We rely on recv_from giving us the correct size + unsafe { + self.rx_buf.set_len(res as usize); + } + let pkt = OutsidePacket::Wire(&mut self.rx_buf, ConnectionType::Stream); + if let Err(err) = conn.outside_data_received(pkt) { + warn!("Failed to process outside data: {err}"); + if conn.handle_outside_data_error(&err).is_break() { + 
return Err(anyhow!(err).context("Outside data fatal error")); + } + } } - Err(err) => break anyhow!(err).context("TCP read error"), }; + self.push_rx(sq)?; + Ok(()) + } +} + +pub(in super::super) struct TcpSocketBuffer { + tx_queue: Arc>, + fd: Fd, + // We double buffer the tx. + tx_in_flight: BytesMut, + tx_buffer: BytesMut, + tx_buffer_size: usize, +} + +impl TcpSocketBuffer { + fn new( + tx_queue: Arc>, + tx_buffer_size: usize, + sock: &impl AsRawFd, + ) -> Arc> { + Arc::new(Mutex::new(TcpSocketBuffer { + tx_queue, + fd: Fd(sock.as_raw_fd()), + tx_in_flight: BytesMut::new(), + tx_buffer: BytesMut::new(), + tx_buffer_size, + })) + } + + fn push_tx(&mut self) { + let mut tx_queue = self.tx_queue.lock().unwrap(); + let len = self.tx_in_flight.len(); + + let sqe = opcode::Send::new(self.fd, self.tx_in_flight.as_ptr() as *const _, len as _) + .flags(libc::MSG_WAITALL) + .build() + .user_data(Loop::outside_tx_user_data(self.fd.0 as u32)); + + #[allow(unsafe_code)] + // SAFETY: + // - The buffer is owned by `self` and which is owned by the connection and ultimately by `TcpServer::fd_map` + unsafe { + tx_queue.push(sqe) + }; + } - let pkt = OutsidePacket::Wire(&mut buf, ConnectionType::Stream); - if let Err(err) = conn.outside_data_received(pkt) { - warn!("Failed to process outside data: {err}"); - if conn.handle_outside_data_error(&err).is_break() { - break anyhow!(err).context("Outside data fatal error"); + fn send(&mut self, buf: CowBytes) -> IOCallbackResult { + let bytes = buf.as_bytes(); + + if !self.tx_in_flight.is_empty() { + // tx_buffer_size is not a strict limit, but once we have + // exceeded it we stop adding more. 
+ if self.tx_buffer.len() > self.tx_buffer_size { + return IOCallbackResult::WouldBlock; } + + self.tx_buffer.extend_from_slice(bytes); + return IOCallbackResult::Ok(bytes.len()); } - }; - conn.handle_end_of_stream(); + self.tx_in_flight.extend_from_slice(bytes); + self.push_tx(); + + IOCallbackResult::Ok(bytes.len()) + } + + pub fn complete_tx( + &mut self, + _sq: &mut io_uring::SubmissionQueue, + cqe: io_uring::cqueue::Entry, + ) -> Result<()> { + let res = io_uring_res(cqe.result()).with_context(|| "outside send completion")? as usize; + + // We use MSG_WAITALL so this should not happen + assert!(res == self.tx_in_flight.len(), "Unexpected short send"); + + self.tx_in_flight.clear(); + + std::mem::swap(&mut self.tx_buffer, &mut self.tx_in_flight); + + if !self.tx_in_flight.is_empty() { + self.push_tx(); + } + + Ok(()) + } +} + +struct TcpSocket { + buffer: Arc>, + peer_addr: SocketAddr, +} + +impl OutsideIOSendCallback for TcpSocket { + fn send(&self, buf: CowBytes) -> IOCallbackResult { + self.buffer.lock().unwrap().send(buf) + } - info!("Connection closed: {:?}", err); + fn peer_addr(&self) -> SocketAddr { + self.peer_addr + } } pub(crate) struct TcpServer { conn_manager: Arc, - sock: Arc, + sock: Arc, + tx_queue: Arc>, + tx_buffer_size: usize, proxy_protocol: bool, + + // Buffers passed to opcode::Accept + accept_addr: Box<(libc::sockaddr_storage, libc::socklen_t)>, + // Map from accepted fds to connections + fd_map: HashMap, } impl TcpServer { + // idx reserved for the accept request. Cannot clash with indexes + // for connections since those are fd numbers which are positive + // i32 values. + const ACCEPT_IDX: u32 = 0x8000_0000; + + // Signals a cancelation request for a connection when added to + // the idx for a rx request (which is an fd number). Since fd is + // never 0 (that is stdin) cannot clash with ACCEPT_IDX. 
+ // + // We must cancel any in flight requests before destroying the + // connection state since they may be reading from owned data or, + // worse, writing to it! + const RX_CANCEL_IDX_BIT: u32 = 0x8000_0000; + pub(crate) async fn new( conn_manager: Arc, + tx_queue: Arc>, bind_address: SocketAddr, proxy_protocol: bool, + tcp_buffer_size: ByteSize, ) -> Result { - let sock = Arc::new(tokio::net::TcpListener::bind(bind_address).await?); + eprintln!("Binding to {bind_address}"); + let sock = tokio::net::TcpListener::bind(bind_address).await?; + eprintln!("Bound to {bind_address}"); + + let sock = sock.into_std()?; + sock.set_nonblocking(false)?; + let sock = Arc::new(sock); + + let tx_buffer_size = tcp_buffer_size.as_u64().try_into()?; Ok(Self { conn_manager, sock, + tx_queue, + tx_buffer_size, proxy_protocol, + + #[allow(unsafe_code)] + // SAFETY: All zeroes is a valid sockaddr_storage + accept_addr: Box::new((unsafe { std::mem::zeroed() }, 0)), + + fd_map: Default::default(), }) } -} -#[async_trait] -impl Server for TcpServer { - async fn run(&mut self) -> Result<()> { + fn push_accept(&mut self, sq: &mut io_uring::SubmissionQueue) -> Result<()> { info!("Accepting traffic on {}", self.sock.local_addr()?); - loop { - let (sock, peer_addr) = match self.sock.accept().await { - Ok(r) => r, - Err(err) => { - // Some of the errors which accept(2) can return - // - // while never a good thing needn't necessarily be - // fatal to the entire server and prevent us from - // servicing existing connections or potentially - // new connections in the future. - warn!(?err, "Failed to accept a new connection"); - metrics::connection_accept_failed(); - continue; - } - }; - - sock.set_nodelay(true)?; - let local_addr = match SockRef::from(&sock).local_addr() { - Ok(local_addr) => local_addr, - Err(err) => { - // Since we have a bound socket this shouldn't happen. 
- debug!(?err, "Failed to get local addr"); - return Err(err.into()); - } - }; - let Some(local_addr) = local_addr.as_socket() else { - // Since we only bind to IP sockets this shouldn't happen. - debug!("Failed to convert local addr to socketaddr"); - return Err(anyhow!("Failed to convert local addr to socketaddr")); - }; - - tokio::spawn(handle_connection( - sock, + let (addr, len) = &mut *self.accept_addr; + *len = std::mem::size_of_val(addr) as _; + + let sqe = opcode::Accept::new( + Loop::FIXED_OUTSIDE_FD, + addr as *mut libc::sockaddr_storage as *mut _, + len as *mut libc::socklen_t as *mut _, + ) + .build() + .user_data(Loop::outside_rx_user_data(Self::ACCEPT_IDX)); + + #[allow(unsafe_code)] + // SAFETY: The address buffers are owned by `self` and `self` is owned by the `io::Loop` + unsafe { + sq.push(&sqe)? + }; + + sq.sync(); + + Ok(()) + } + + fn complete_accept( + &mut self, + sq: &mut io_uring::SubmissionQueue, + cqe: io_uring::cqueue::Entry, + ) -> Result<()> { + let res = io_uring_res(cqe.result()).with_context(|| "outside accept")?; + // Should be impossible since as a two's complement i32 it would be negative. + assert!(res as u32 != Self::ACCEPT_IDX); + + let peer_addr = socket_addr_from_sockaddr(&self.accept_addr.0, self.accept_addr.1)?; + + #[allow(unsafe_code)] + // SAFETY: We trust that on success `accept(2)` returns a + // valid socket fd. + let sock = unsafe { TcpStream::from_raw_fd(res) }; + sock.set_nodelay(true)?; + + let local_addr = match sock.local_addr() { + Ok(local_addr) => local_addr, + Err(err) => { + // Since we have a bound socket this shouldn't happen. 
+ debug!(?err, "Failed to get local addr"); + return Err(err.into()); + } + }; + + let rx_buf = BytesMut::with_capacity(ConnectionState::RX_BUFFER_SIZE); + + let phase = if self.proxy_protocol { + ConnectionPhase::ProxyInitial { local_addr } + } else { + let buffer = TcpSocketBuffer::new(self.tx_queue.clone(), self.tx_buffer_size, &sock); + let outside_io = Arc::new(TcpSocket { + buffer: buffer.clone(), peer_addr, + }); + let conn = self.conn_manager.create_streaming_connection( + Version::MINIMUM, local_addr, - self.conn_manager.clone(), - self.proxy_protocol, - )); + outside_io, + )?; + ConnectionPhase::Connected { conn, buffer } + }; + + let mut state = ConnectionState { + sock, + rx_buf, + phase, + tx_buffer_size: self.tx_buffer_size, + }; + + // Before we add to the hash, due to insert taking ownership + // of state, but we cannot complete anything until we return + // so that's ok. + state.push_rx(sq)?; + + self.fd_map.insert(res as u32, state); + + Ok(()) + } +} + +impl UringIoSource for TcpServer { + fn as_raw_fd(&self) -> RawFd { + self.sock.as_raw_fd() + } + + fn push_initial_ops(&mut self, sq: &mut io_uring::SubmissionQueue) -> Result<()> { + self.push_accept(sq) + } + + fn complete_rx( + &mut self, + sq: &mut io_uring::SubmissionQueue, + cqe: io_uring::cqueue::Entry, + idx: u32, + ) -> Result<()> { + if idx == Self::ACCEPT_IDX { + if let Err(err) = self.complete_accept(sq, cqe) { + // Some of the errors which accept(2) can return + // + // while never a good thing needn't necessarily be + // fatal to the entire server and prevent us from + // servicing existing connections or potentially + // new connections in the future. 
+ warn!(?err, "Failed to accept a new connection"); + metrics::connection_accept_failed(); + } + self.push_accept(sq)?; + return Ok(()); + } + + let (idx, cancelling) = if (idx & Self::RX_CANCEL_IDX_BIT) != 0 { + (idx - Self::RX_CANCEL_IDX_BIT, true) + } else { + (idx, false) + }; + + use std::collections::hash_map::Entry; + + match self.fd_map.entry(idx) { + Entry::Occupied(entry) if cancelling => { + let nr = io_uring_res(cqe.result()).with_context(|| "Cancelling")?; + info!(fd = idx, nr, "Cancelled"); + entry.remove_entry(); + Ok(()) + } + + Entry::Occupied(mut entry) => { + let state = entry.get_mut(); + match state.complete_rx(sq, cqe, &self.tx_queue, &self.conn_manager) { + Ok(()) => Ok(()), + Err(err) => { + if matches!( + state.phase, + ConnectionPhase::ProxyInitial { .. } | ConnectionPhase::Proxy { .. } + ) { + metrics::connection_accept_proxy_header_failed(); + } + info!("Connection closed: {:?}", err); + state.push_cancel(sq)?; + + if let ConnectionPhase::Connected { conn, .. 
} = &state.phase { + conn.handle_end_of_stream(); + } + + Ok(()) // Error is for the connection, not the process + } + } + } + + // Likely we raced with a cancellation request + Entry::Vacant(_) => { + match io_uring_res(cqe.result()) { + Err(err) => info!("complete unknown tcp rx {idx} with {err}"), + Ok(res) => info!("complete unknown tcp rx {idx} with {res}"), + }; + Ok(()) + } + } + } + + fn complete_tx( + &mut self, + sq: &mut io_uring::SubmissionQueue, + cqe: io_uring::cqueue::Entry, + idx: u32, + ) -> Result<()> { + use std::collections::hash_map::Entry; + match self.fd_map.entry(idx) { + Entry::Occupied(mut entry) => { + let state = entry.get_mut(); + match state.complete_tx(sq, cqe) { + Ok(()) => Ok(()), + Err(err) => { + info!("Connection closed: {:?}", err); + state.push_cancel(sq)?; + Ok(()) // Error is for the connection, not the process + } + } + } + + // Likely we raced with a cancellation request + Entry::Vacant(_) => { + match io_uring_res(cqe.result()) { + Err(err) => info!("complete unknown tcp tx {idx} with {err}"), + Ok(res) => info!("complete unknown tcp tx {idx} with {res}"), + }; + Ok(()) + } } } } diff --git a/lightway-server/src/io/outside/udp.rs b/lightway-server/src/io/outside/udp.rs index d58f2f20..436c3d8d 100644 --- a/lightway-server/src/io/outside/udp.rs +++ b/lightway-server/src/io/outside/udp.rs @@ -1,26 +1,37 @@ -mod cmsg; +//! UdpServer UringIoSource +//! +//! Uses uring indexes: +//! +//! Loop::outside_rx_user_data: +//! - 0..UdpServer::rx.len(): A set of recv requests +//! +//! Loop::outside_tx_user_data: +//! 
- Managed by TxQueue + +pub(crate) mod cmsg; use std::{ net::{IpAddr, Ipv4Addr, SocketAddr}, - sync::{Arc, RwLock}, + os::fd::{AsRawFd as _, RawFd}, + sync::{Arc, Mutex, MutexGuard, RwLock}, }; -use anyhow::Result; -use async_trait::async_trait; -use bytes::BytesMut; +use anyhow::{Context as _, Result}; +use bytes::{Bytes, BytesMut}; use bytesize::ByteSize; -use lightway_app_utils::sockopt::socket_enable_pktinfo; +use io_uring::opcode; +use lightway_app_utils::{ + sockaddr_from_socket_addr, socket_addr_from_sockaddr, sockopt::socket_enable_pktinfo, +}; use lightway_core::{ ConnectionType, CowBytes, Header, IOCallbackResult, OutsideIOSendCallback, OutsidePacket, SessionId, Version, MAX_OUTSIDE_MTU, }; -use socket2::{MaybeUninitSlice, MsgHdr, MsgHdrMut, SockAddr, SockRef}; -use tokio::io::Interest; -use tracing::{info, warn}; +use tracing::warn; use crate::{connection_manager::ConnectionManager, metrics}; -use super::Server; +use super::{io_uring_res, iovec, msghdr, Loop, TxQueue, UringIoSource}; enum BindMode { UnspecifiedAddress { local_port: u16 }, @@ -44,52 +55,70 @@ impl std::fmt::Display for BindMode { } } -fn send_to_socket( - sock: &Arc, - buf: &[u8], - peer_addr: &SockAddr, +fn queue_tx( + mut tx_queue: MutexGuard, + buf: Bytes, + peer_addr: libc::sockaddr_storage, + peer_addr_len: libc::socklen_t, pktinfo: Option, ) -> IOCallbackResult { - let res = sock.try_io(Interest::WRITABLE, || { - let sock = SockRef::from(sock.as_ref()); - let bufs = [std::io::IoSlice::new(buf)]; - - let msghdr = MsgHdr::new().with_addr(peer_addr).with_buffers(&bufs); + let len = buf.len(); - const CMSG_SIZE: usize = cmsg::Message::space::(); - let mut cmsg = cmsg::BufferMut::::zeroed(); + let Some((slot, state)) = tx_queue.take_slot() else { + return IOCallbackResult::WouldBlock; + }; - let msghdr = if let Some(pktinfo) = pktinfo { - let mut builder = cmsg.builder(); - builder.fill_next(libc::SOL_IP, libc::IP_PKTINFO, pktinfo)?; + state.iov[0].iov_base = buf.as_ptr() as *mut _; + 
state.iov[0].iov_len = buf.len(); + state.addr = peer_addr; + state.addr_len = peer_addr_len; - msghdr.with_control(cmsg.as_ref()) - } else { - msghdr - }; + state.buf = Some(buf); - sock.sendmsg(&msghdr, 0) - }); + state.msghdr.msg_name = &mut state.addr as *mut libc::sockaddr_storage as *mut _; + state.msghdr.msg_namelen = state.addr_len; + state.msghdr.msg_iov = state.iov.as_mut_ptr() as *mut _; + state.msghdr.msg_iovlen = state.iov.len(); - match res { - Ok(nr) => IOCallbackResult::Ok(nr), - Err(err) if matches!(err.kind(), std::io::ErrorKind::WouldBlock) => { - IOCallbackResult::WouldBlock + if let Some(pktinfo) = pktinfo { + let mut builder = state.control.builder(); + if let Err(err) = builder.fill_next(libc::SOL_IP, libc::IP_PKTINFO, pktinfo) { + return IOCallbackResult::Err(err); } - Err(err) => IOCallbackResult::Err(err), + state.msghdr.msg_control = state.control.as_mut_ptr() as *mut _; + // Get from builder? + state.msghdr.msg_controllen = std::mem::size_of_val(&state.control) as _; + } else { + state.msghdr.msg_control = std::ptr::null_mut(); + state.msghdr.msg_controllen = 0; } + + let sqe = opcode::SendMsg::new(Loop::FIXED_OUTSIDE_FD, state.msghdr.as_mut_ptr()).build(); + + #[allow(unsafe_code)] + // SAFETY: + // - slot was obtained from take_slot above + // - The buffer is owned by `state`, which is owned by the `TxQueue` + unsafe { + tx_queue.push_outside_slot(slot, sqe) + }; + + IOCallbackResult::Ok(len) } struct UdpSocket { - sock: Arc, - peer_addr: RwLock<(SocketAddr, SockAddr)>, + tx_queue: Arc>, + peer_addr: RwLock<(SocketAddr, libc::sockaddr_storage, libc::socklen_t)>, reply_pktinfo: Option, } impl OutsideIOSendCallback for UdpSocket { fn send(&self, buf: CowBytes) -> IOCallbackResult { + let buf = buf.into_owned(); let peer_addr = self.peer_addr.read().unwrap(); - send_to_socket(&self.sock, buf.as_bytes(), &peer_addr.1, self.reply_pktinfo) + let tx_queue = self.tx_queue.lock().unwrap(); + + queue_tx(tx_queue, buf, peer_addr.1, peer_addr.2, 
self.reply_pktinfo) } fn peer_addr(&self) -> SocketAddr { @@ -99,24 +128,67 @@ impl OutsideIOSendCallback for UdpSocket { fn set_peer_addr(&self, addr: SocketAddr) -> SocketAddr { let mut peer_addr = self.peer_addr.write().unwrap(); let old_addr = peer_addr.0; - *peer_addr = (addr, addr.into()); + + let (raw_addr, raw_addr_len) = sockaddr_from_socket_addr(addr); + + *peer_addr = (addr, raw_addr, raw_addr_len); old_addr } } +struct RxState { + buf: BytesMut, + addr: libc::sockaddr_storage, + control: cmsg::Buffer<{ Self::CONTROL_SIZE }>, + iov: [iovec; 1], + msghdr: msghdr, +} + +impl RxState { + const CONTROL_SIZE: usize = cmsg::Message::space::(); + + fn new() -> Self { + let mut buf = BytesMut::with_capacity(MAX_OUTSIDE_MTU); + let iov = iovec::new(libc::iovec { + iov_base: buf.as_mut_ptr() as *mut _, + iov_len: buf.capacity(), + }); + #[allow(unsafe_code)] + Self { + buf, + // SAFETY: All zeroes is a valid sockaddr + addr: unsafe { std::mem::zeroed() }, + control: cmsg::Buffer::new(), + iov: [iov], + // SAFETY: All zeroes is a valid msghdr + msghdr: unsafe { std::mem::zeroed() }, + } + } +} pub(crate) struct UdpServer { conn_manager: Arc, - sock: Arc, + sock: Arc, bind_mode: BindMode, + tx_queue: Arc>, + // The contents are used for I/O syscalls, ensure they stay put. 
+ rx: Vec, } impl UdpServer { pub(crate) async fn new( + nr_slots: u32, conn_manager: Arc, + tx_queue: Arc>, bind_address: SocketAddr, udp_buffer_size: ByteSize, ) -> Result { - let sock = Arc::new(tokio::net::UdpSocket::bind(bind_address).await?); + tracing::info!("UdpServer with {nr_slots} slots"); + + let sock = tokio::net::UdpSocket::bind(bind_address).await?; + + let sock = sock.into_std()?; + sock.set_nonblocking(false)?; + let sock = Arc::new(sock); let bind_mode = if bind_address.ip().is_unspecified() { BindMode::UnspecifiedAddress { @@ -137,20 +209,31 @@ impl UdpServer { socket_enable_pktinfo(&sock)?; } + let rx = (0..nr_slots).map(|_| RxState::new()).collect(); + + #[allow(unsafe_code)] Ok(Self { conn_manager, sock, bind_mode, + tx_queue, + rx, }) } - async fn data_received( + fn data_received( &mut self, peer_addr: SocketAddr, + raw_peer_addr: libc::sockaddr_storage, + raw_peer_addr_len: libc::socklen_t, local_addr: SocketAddr, reply_pktinfo: Option, - buf: &mut BytesMut, + idx: u32, ) { + #[allow(unsafe_code)] + // SAFETY: The caller must already have validated this. 
+ let buf = &mut unsafe { self.rx.get_unchecked_mut(idx as usize) }.buf; + let pkt = OutsidePacket::Wire(buf, ConnectionType::Datagram); let pkt = match self.conn_manager.parse_raw_outside_packet(pkt) { Ok(hdr) => hdr, @@ -184,8 +267,8 @@ impl UdpServer { local_addr, || { Arc::new(UdpSocket { - sock: self.sock.clone(), - peer_addr: RwLock::new((peer_addr, peer_addr.into())), + tx_queue: self.tx_queue.clone(), + peer_addr: RwLock::new((peer_addr, raw_peer_addr, raw_peer_addr_len)), reply_pktinfo, }) }, @@ -194,7 +277,7 @@ impl UdpServer { match conn_result { Ok(conn) => conn, Err(_e) => { - self.send_reject(peer_addr.into(), reply_pktinfo).await; + self.send_reject(raw_peer_addr, raw_peer_addr_len, reply_pktinfo); return; } } @@ -233,7 +316,12 @@ impl UdpServer { } } - async fn send_reject(&self, peer_addr: SockAddr, reply_pktinfo: Option) { + fn send_reject( + &self, + peer_addr: libc::sockaddr_storage, + peer_addr_len: libc::socklen_t, + pktinfo: Option, + ) { metrics::udp_rejected_session(); let msg = Header { version: Version::MINIMUM, @@ -244,92 +332,98 @@ impl UdpServer { let mut buf = BytesMut::with_capacity(Header::WIRE_SIZE); msg.append_to_wire(&mut buf); + let tx_queue = self.tx_queue.lock().unwrap(); + // Ignore failure to send. 
- let _ = send_to_socket(&self.sock, &buf, &peer_addr, reply_pktinfo); + + let _ = queue_tx(tx_queue, buf.freeze(), peer_addr, peer_addr_len, pktinfo); + } + + fn push_rx(&mut self, sq: &mut io_uring::SubmissionQueue, idx: u32) -> Result<()> { + let rx = &mut self.rx[idx as usize]; + + // Recover full capacity in case this is a resubmit + rx.buf.clear(); + rx.buf.reserve(MAX_OUTSIDE_MTU); + + rx.msghdr = msghdr::new(libc::msghdr { + msg_name: &mut rx.addr as *mut libc::sockaddr_storage as *mut _, + msg_namelen: std::mem::size_of::() as _, + msg_iov: rx.iov.as_mut_ptr() as *mut libc::msghdr as *mut _, + msg_iovlen: rx.iov.len(), + msg_control: rx.control.as_mut_ptr() as *mut _, + msg_controllen: RxState::CONTROL_SIZE, + msg_flags: 0, + }); + let sqe = opcode::RecvMsg::new(Loop::FIXED_OUTSIDE_FD, rx.msghdr.as_mut_ptr()) + .build() + .user_data(Loop::outside_rx_user_data(idx)); + + #[allow(unsafe_code)] + // SAFETY: The buffer is owned by `self.rx` and `self` is owned by the `io::Loop` + unsafe { + sq.push(&sqe)? + }; + + sq.sync(); + + Ok(()) } } -#[async_trait] -impl Server for UdpServer { - async fn run(&mut self) -> Result<()> { - info!("Accepting traffic on {}", self.bind_mode); - let mut buf = BytesMut::with_capacity(MAX_OUTSIDE_MTU); - loop { - // Recover full capacity - buf.clear(); - buf.reserve(MAX_OUTSIDE_MTU); - - let (peer_addr, local_addr, reply_pktinfo) = self - .sock - .async_io(Interest::READABLE, || { - let sock = SockRef::from(self.sock.as_ref()); - let mut raw_buf = [MaybeUninitSlice::new(buf.spare_capacity_mut())]; - - #[allow(unsafe_code)] - let mut peer_sock_addr = { - // SAFETY: sockaddr_storage is defined - // () - // as being a suitable size and alignment for - // "all supported protocol-specific address - // structures" in the underlying OS APIs. - // - // All zeros is a valid representation, - // corresponding to the `ss_family` having a - // value of `AF_UNSPEC`. 
- let addr_storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() }; - let len = std::mem::size_of_val(&addr_storage) as libc::socklen_t; - // SAFETY: We initialized above as `AF_UNSPEC` - // so the storage is correct from that - // angle. The `recvmsg` call will change this - // which should be ok since `sockaddr_storage` - // is big enough. - unsafe { SockAddr::new(addr_storage, len) } - }; - - // We only need this control buffer if - // `self.bind_mode.needs_pktinfo()`. However the hit - // on reserving a fairly small on stack buffer - // should be small compared with the conditional - // logic and dynamically sized buffer needed to - // allow omitting it. - const SIZE: usize = cmsg::Message::space::(); - let mut control = cmsg::Buffer::::new(); - - let mut msg = MsgHdrMut::new() - .with_addr(&mut peer_sock_addr) - .with_buffers(&mut raw_buf) - .with_control(control.as_mut()); - - let len = sock.recvmsg(&mut msg, 0)?; - - if msg.flags().is_truncated() { - metrics::udp_recv_truncated(); - } +impl UringIoSource for UdpServer { + fn as_raw_fd(&self) -> RawFd { + self.sock.as_raw_fd() + } + + fn push_initial_ops(&mut self, sq: &mut io_uring::SubmissionQueue) -> Result<()> { + for idx in 0..self.rx.len() as u32 { + self.push_rx(sq, idx)? + } + Ok(()) + } + + fn complete_rx( + &mut self, + sq: &mut io_uring::SubmissionQueue, + cqe: io_uring::cqueue::Entry, + idx: u32, + ) -> Result<()> { + let res = { + let res = io_uring_res(cqe.result()).with_context(|| "outside recvmsg completion")?; + + let rx = &mut self.rx[idx as usize]; + + #[allow(unsafe_code)] + // SAFETY: We rely on recv_from giving us the correct size + unsafe { + rx.buf.set_len(res as usize); + } - let control_len = msg.control_len(); - - // SAFETY: We rely on recv_from giving us the correct size - #[allow(unsafe_code)] - unsafe { - buf.set_len(len) - }; - - let Some(peer_addr) = peer_sock_addr.as_socket() else { - // Since we only bind to IP sockets this shouldn't happen. 
- metrics::udp_recv_invalid_addr(); - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidInput, - "failed to convert local addr to socketaddr", - )); - }; - - #[allow(unsafe_code)] - let (local_addr, reply_pktinfo) = match self.bind_mode { - BindMode::UnspecifiedAddress { local_port } => { - let Some((local_addr, reply_pktinfo)) = + let raw_peer_addr = rx.addr; + let raw_peer_addr_len = rx.msghdr.msg_namelen; + + let peer_addr = match socket_addr_from_sockaddr(&raw_peer_addr, raw_peer_addr_len) { + Ok(a) => a, + Err(err) => { + metrics::udp_recv_invalid_addr(); + return Err(err.into()); + } + }; + + if (rx.msghdr.msg_flags & libc::MSG_TRUNC) != 0 { + metrics::udp_recv_truncated(); + } + + let control_len = rx.msghdr.msg_controllen; + + #[allow(unsafe_code)] + let (local_addr, reply_pktinfo) = match self.bind_mode { + BindMode::UnspecifiedAddress { local_port } => { + let Some((local_addr, reply_pktinfo)) = // SAFETY: The call to `recvmsg` above updated // the control buffer length field. - unsafe { control.iter(control_len) }.find_map(|cmsg| { + unsafe { rx.control.iter(control_len) }.find_map(|cmsg| { match cmsg { cmsg::Message::IpPktinfo(pi) => { // From https://pubs.opengroup.org/onlinepubs/009695399/basedefs/netinet/in.h.html @@ -355,22 +449,42 @@ impl Server for UdpServer { // and we have set IP_PKTINFO // sockopt this shouldn't happen. 
metrics::udp_recv_missing_pktinfo(); + println!("outside user data {:016x}, idx {:x} had no PKTINFO", cqe.user_data(),idx); return Err(std::io::Error::new( std::io::ErrorKind::Other, "recvmsg did not return IP_PKTINFO", - )); + ).into()); }; - (local_addr, Some(reply_pktinfo)) - } - BindMode::SpecificAddress { local_addr } => (local_addr, None), - }; + (local_addr, Some(reply_pktinfo)) + } + BindMode::SpecificAddress { local_addr } => (local_addr, None), + }; + + self.data_received( + peer_addr, + raw_peer_addr, + raw_peer_addr_len, + local_addr, + reply_pktinfo, + idx, + ); + + Ok(()) + }; - Ok((peer_addr, local_addr, reply_pktinfo)) - }) - .await?; + // Queue another recv + self.push_rx(sq, idx)?; - self.data_received(peer_addr, local_addr, reply_pktinfo, &mut buf) - .await; - } + res + } + + fn complete_tx( + &mut self, + _sq: &mut io_uring::SubmissionQueue, + cqe: io_uring::cqueue::Entry, + idx: u32, + ) -> Result<()> { + let _ = self.tx_queue.lock().unwrap().complete(cqe, idx); + Ok(()) } } diff --git a/lightway-server/src/io/outside/udp/cmsg.rs b/lightway-server/src/io/outside/udp/cmsg.rs index 0aa94c8c..9e553185 100644 --- a/lightway-server/src/io/outside/udp/cmsg.rs +++ b/lightway-server/src/io/outside/udp/cmsg.rs @@ -8,8 +8,8 @@ impl Buffer { Self([std::mem::MaybeUninit::::uninit(); N]) } - pub(crate) fn as_mut(&mut self) -> &mut [std::mem::MaybeUninit] { - &mut self.0 + pub(crate) fn as_mut_ptr(&mut self) -> *mut u8 { + self.0.as_mut_ptr() as *mut _ } /// # Safety @@ -137,6 +137,10 @@ impl BufferMut { _phantom: std::marker::PhantomData, } } + + pub(crate) fn as_mut_ptr(&mut self) -> *mut u8 { + self.0.as_mut_ptr() as *mut _ + } } impl AsRef<[u8]> for BufferMut { diff --git a/lightway-server/src/io/tx.rs b/lightway-server/src/io/tx.rs new file mode 100644 index 00000000..fa027313 --- /dev/null +++ b/lightway-server/src/io/tx.rs @@ -0,0 +1,170 @@ +//! TxQueue, helper/queue for UringIoSource tx implementations +//! +//! Uses uring indexes: +//! +//! 
Loop::outside_rx_user_data: +//! - None +//! +//! Loop::outside_tx_user_data: +//! - 0..TxQueue::state.len() + +use std::collections::VecDeque; + +use anyhow::{Context as _, Result}; +use bytes::Bytes; +use io_uring::squeue::Entry as SEntry; + +use super::{ + ffi::{iovec, msghdr}, + io_uring_res, + outside::udp::cmsg, + Loop, SubmissionQueue, Submitter, +}; + +pub(super) struct TxState { + pub buf: Option, + pub addr: libc::sockaddr_storage, + pub addr_len: libc::socklen_t, + pub control: cmsg::BufferMut<{ Self::CONTROL_SIZE }>, + pub iov: [iovec; 1], + pub msghdr: msghdr, +} + +impl TxState { + const CONTROL_SIZE: usize = cmsg::Message::space::(); + fn new() -> Self { + #[allow(unsafe_code)] + Self { + buf: None, + // SAFETY: All zeroes is a valid sockaddr + addr: unsafe { std::mem::zeroed() }, + addr_len: 0, + control: cmsg::BufferMut::zeroed(), + + // SAFETY: All zeroes is a valid iov + iov: [unsafe { std::mem::zeroed() }], + // SAFETY: All zeroes is a valid msghdr + msghdr: unsafe { std::mem::zeroed() }, + } + } +} + +pub struct TxQueue { + sqe_ring: VecDeque, + slots: Vec, + state: Vec, +} + +impl TxQueue { + pub fn new(nr_slots: u32) -> Self { + tracing::info!("TxQueue with {nr_slots} slots"); + let sqe_ring = VecDeque::with_capacity(nr_slots as usize); + let (slots, state) = (0..nr_slots).map(|nr| (nr, TxState::new())).unzip(); + + Self { + sqe_ring, + slots, + state, + } + } + + /// Reserve a slot, the returned value should be passed to + /// `push_*_slot` after setting up the state and constructing an + /// sqe. + pub(super) fn take_slot(&mut self) -> Option<(u32, &mut TxState)> { + let slot = self.slots.pop()?; + let state = &mut self.state[slot as usize]; + Some((slot, state)) + } + + #[allow(unsafe_code)] + /// Push an inside request entry to the tx queue. + /// + /// Callers are responsible for calling `::complete` when the + /// request completes to free the slot. 
+ /// + /// # Safety: + /// + /// - idx must have been previously obtained from `take_slot` + /// - sqe must meet the safety requirements + /// + /// Any sqe userdata will be overwritten + pub(super) unsafe fn push_inside_slot(&mut self, idx: u32, sqe: SEntry) { + let sqe = sqe.user_data(Loop::inside_tx_user_data(idx)); + self.sqe_ring.push_back(sqe); + } + + #[allow(unsafe_code)] + /// Push an outside request entry to the tx queue. + /// + /// Callers are responsible for calling `::complete` when the + /// request completes to free the slot. + /// + /// # Safety: + /// + /// - idx must have been previously obtained from `take_slot` + /// - sqe must meet the safety requirements + /// + /// Any sqe userdata will be overwritten + pub(super) unsafe fn push_outside_slot(&mut self, idx: u32, sqe: SEntry) { + let sqe = sqe.user_data(Loop::outside_tx_user_data(idx)); + self.sqe_ring.push_back(sqe); + } + + #[allow(unsafe_code)] + /// Push an arbitrary entry to the tx queue. Does not consume a slot. + /// + /// Callers are responsible for completion and should not call + /// `::complete`. + /// + /// Use this for SQEs which do not require an entry in `::state` + /// to keep buffers live and/or for which the calling code wants + /// to manage the idx space itself. + /// + /// # Safety: + /// + /// - sqe must meet the safety requirements + pub(super) unsafe fn push(&mut self, sqe: SEntry) { + self.sqe_ring.push_back(sqe); + } + + /// Push all entries (added by `push_*_slot` or `push`) to the uring. + pub(super) fn drain(&mut self, submitter: &Submitter, sq: &mut SubmissionQueue) -> Result<()> { + while let Some(sqe) = self.sqe_ring.pop_front() { + if sq.is_full() { + match submitter.submit() { + Ok(_) => (), + Err(ref err) if err.raw_os_error() == Some(libc::EBUSY) => break, + Err(err) => return Err(err.into()), + } + sq.sync(); + } + + #[allow(unsafe_code)] + // SAFETY: Safe according to the safety requirements of `push_*_slot` or `push` + unsafe { + sq.push(&sqe)? 
+ }; + + sq.sync() + } + + Ok(()) + } + + /// Complete an entry added with `push_*_slot`, intended to be + /// called from the UringIoSource's `complete_*` method. The slot is + /// always released, even when the request failed. Users of plain + /// `push` are responsible for their own completion. + pub(super) fn complete(&mut self, cqe: io_uring::cqueue::Entry, idx: u32) -> Result<()> { + // A completion was delivered for this slot: free it before `?` can return early. + let slot = &mut self.state[idx as usize]; + slot.buf = None; + + self.slots.push(idx); + + let _res = io_uring_res(cqe.result()).with_context(|| "tx completion")?; + + Ok(()) + } +} diff --git a/lightway-server/src/lib.rs b/lightway-server/src/lib.rs index 0bcd6a3e..9eaeda8c 100644 --- a/lightway-server/src/lib.rs +++ b/lightway-server/src/lib.rs @@ -17,26 +17,19 @@ pub use lightway_core::{ use anyhow::{anyhow, Context, Result}; use ipnet::Ipv4Net; use lightway_app_utils::{connection_ticker_cb, TunConfig}; -use lightway_core::{ - ipv4_update_destination, AuthMethod, BuilderPredicates, ConnectionError, IOCallbackResult, - InsideIpConfig, Secret, ServerContextBuilder, -}; -use pnet::packet::ipv4::Ipv4Packet; +use lightway_core::{AuthMethod, BuilderPredicates, InsideIpConfig, Secret, ServerContextBuilder}; use std::{ collections::HashMap, net::{IpAddr, Ipv4Addr, SocketAddr}, path::PathBuf, - sync::Arc, + sync::{Arc, Mutex}, time::Duration, }; -use tokio::task::JoinHandle; use tracing::{info, warn}; -use crate::io::inside::InsideIO; use crate::ip_manager::IpManager; use connection_manager::ConnectionManager; -use io::outside::Server; fn debug_fmt_plugin_list( list: &PluginFactoryList, @@ -112,15 +105,28 @@ pub struct ServerConfig ServerAuth>> { /// Enable Post Quantum Crypto pub enable_pqc: bool, - /// Enable IO-uring interface for Tunnel - pub enable_tun_iouring: bool, - /// IO-uring submission queue count pub iouring_entry_count: usize, /// IO-uring sqpoll idle time. pub iouring_sqpoll_idle_time: Duration, + /// Number of concurrent TUN device read requests to issue to + /// IO-uring. 
Setting this too large may negatively impact + /// performance. + pub iouring_tun_rx_count: u32, + + /// Configure TUN in blocking mode. + pub iouring_tun_blocking: bool, + + /// Number of concurrent UDP socket recvmsg requests to issue to + /// IO-uring. + pub iouring_udp_rx_count: u32, + + /// Maximum number of concurrent UDP + TUN sendmsg/write requests + /// to issue to IO-uring. + pub iouring_tx_count: u32, + /// The key update interval for DTLS/TLS 1.3 connections pub key_update_interval: Duration, @@ -140,11 +146,40 @@ pub struct ServerConfig ServerAuth>> { /// UDP Buffer size for the server pub udp_buffer_size: ByteSize, + + /// TCP Buffer size for the server + pub tcp_buffer_size: ByteSize, +} + +impl ServerAuth> + Sync + Send + 'static> ServerConfig { + fn validate(&self) -> Result<()> { + let mut required_uring_slots = + self.iouring_tun_rx_count as usize + self.iouring_tx_count as usize + 1; // cancellation request + + required_uring_slots += match self.connection_type { + // this should be 2 * max connections, but max connections + // is unknown, assume at least 1. 
+ ConnectionType::Stream => 2, + ConnectionType::Datagram => self.iouring_udp_rx_count as usize, + }; + + if self.iouring_entry_count < required_uring_slots { + return Err(anyhow!( + "iouring_entry_count too small {} < {}", + self.iouring_entry_count, + required_uring_slots + )); + } + + Ok(()) + } } pub async fn server ServerAuth> + Sync + Send + 'static>( config: ServerConfig, ) -> Result<()> { + config.validate()?; + let server_key = Secret::PemFile(&config.server_key); let server_cert = Secret::PemFile(&config.server_cert); @@ -175,12 +210,16 @@ pub async fn server ServerAuth> + Sync + Send + 'stati let connection_type = config.connection_type; let auth = Arc::new(AuthAdapter(config.auth)); - let iouring = if config.enable_tun_iouring { - Some((config.iouring_entry_count, config.iouring_sqpoll_idle_time)) - } else { - None - }; - let inside_io = Arc::new(io::inside::Tun::new(config.tun_config, iouring).await?); + let tx_queue = Arc::new(Mutex::new(io::TxQueue::new(config.iouring_tx_count))); + + let tun = io::inside::Tun::new( + config.iouring_tun_rx_count, + config.iouring_tun_blocking, + config.tun_config, + config.lightway_client_ip, + ip_manager.clone(), + tx_queue.clone(), + )?; let ctx = ServerContextBuilder::new( connection_type, @@ -188,7 +227,7 @@ pub async fn server ServerAuth> + Sync + Send + 'stati server_key, auth, ip_manager.clone(), - inside_io.clone().into_io_send_callback(), + tun.inside_io_sender(), )? 
.with_schedule_tick_cb(connection_ticker_cb) .with_key_update_interval(config.key_update_interval) @@ -201,69 +240,43 @@ pub async fn server ServerAuth> + Sync + Send + 'stati tokio::spawn(statistics::run(conn_manager.clone(), ip_manager.clone())); - let mut server: Box = match connection_type { - ConnectionType::Datagram => Box::new( + let server = match connection_type { + ConnectionType::Datagram => io::OutsideIoSource::Udp( io::outside::UdpServer::new( + config.iouring_udp_rx_count, conn_manager.clone(), + tx_queue.clone(), config.bind_address, config.udp_buffer_size, ) .await?, ), - ConnectionType::Stream => Box::new( + ConnectionType::Stream => io::OutsideIoSource::Tcp( io::outside::TcpServer::new( conn_manager.clone(), + tx_queue.clone(), config.bind_address, config.proxy_protocol, + config.tcp_buffer_size, ) .await?, ), }; - let inside_io_loop: JoinHandle> = tokio::spawn(async move { - loop { - let mut buf = match inside_io.recv_buf().await { - IOCallbackResult::Ok(buf) => buf, - IOCallbackResult::WouldBlock => continue, // Spuriously failed to read, keep waiting - IOCallbackResult::Err(err) => { - break Err(anyhow!(err).context("InsideIO recv buf error")); - } - }; - - // Find connection based on client ip (dest ip) and forward packet - let packet = Ipv4Packet::new(buf.as_ref()); - let Some(packet) = packet else { - eprintln!("Invalid inside packet size (less than Ipv4 header)!"); - continue; - }; - let conn = ip_manager.find_connection(packet.get_destination()); - - // Update destination IP address to client's ip - ipv4_update_destination(buf.as_mut(), config.lightway_client_ip); - - if let Some(conn) = conn { - match conn.inside_data_received(&mut buf) { - Ok(()) => {} - Err(ConnectionError::InvalidState) => { - // Skip forwarding packet when offline - metrics::tun_rejected_packet_invalid_state(); - } - Err(ConnectionError::InvalidInsidePacket(_)) => { - // Skip processing invalid packet - metrics::tun_rejected_packet_invalid_inside_packet(); - } - 
Err(err) => { - let fatal = err.is_fatal(conn.connection_type()); - metrics::tun_rejected_packet_invalid_other(fatal); - if fatal { - conn.handle_end_of_stream(); - } - } - } - } else { - metrics::tun_rejected_packet_no_connection(); - } - } + // On exit dropping _io_handle will cause EPIPE to be delivered to + // io_cancel. This causes the corresponding read request on the + // ring to complete and signal the loop should exit. + let (_io_handle, io_cancel) = tokio::net::unix::pipe::pipe()?; + let io_cancel = io_cancel.into_blocking_fd()?; + let io_task = tokio::task::spawn_blocking(move || { + let io_loop = io::Loop::new( + config.iouring_entry_count, + config.iouring_sqpoll_idle_time, + tx_queue, + server, + tun, + )?; + io_loop.run(io_cancel) }); let (ctrlc_tx, ctrlc_rx) = tokio::sync::oneshot::channel(); @@ -275,8 +288,7 @@ pub async fn server ServerAuth> + Sync + Send + 'stati })?; tokio::select! { - err = server.run() => err.context("Outside IO loop exited"), - io = inside_io_loop => io.map_err(|e| anyhow!(e).context("Inside IO loop panicked"))?.context("Inside IO loop exited"), + r = io_task => r?.context("IO task exited"), _ = ctrlc_rx => { info!("Sigterm or Sigint received"); conn_manager.close_all_connections(); @@ -284,3 +296,62 @@ pub async fn server ServerAuth> + Sync + Send + 'stati } } } + +#[cfg(test)] +mod tests { + use super::*; + + use test_case::test_case; + + struct Auth; + + impl ServerAuth> for Auth {} + + #[test_case(ConnectionType::Stream, 0, 0, 0, 0 => panics "iouring_entry_count too small")] + #[test_case(ConnectionType::Stream, 3, 0, 0, 0 => ())] + #[test_case(ConnectionType::Stream, 20, 5, 0, 13 => panics "iouring_entry_count too small")] + #[test_case(ConnectionType::Stream, 21, 5, 0, 13 => ())] + #[test_case(ConnectionType::Stream, 22, 5, 0, 13 => ())] + #[test_case(ConnectionType::Stream, 7, 1, 10_000, 3 => ())] // udp rx count irrelevant for stream + #[test_case(ConnectionType::Datagram, 0, 0, 0, 0 => panics "iouring_entry_count too 
small")] + #[test_case(ConnectionType::Datagram, 1, 0, 0, 0 => ())] + #[test_case(ConnectionType::Datagram, 25, 5, 7, 13 => panics "iouring_entry_count too small")] + #[test_case(ConnectionType::Datagram, 26, 5, 7, 13 => ())] + #[test_case(ConnectionType::Datagram, 27, 5, 7, 13 => ())] + fn validate_iouring_entry_count( + connection_type: ConnectionType, + iouring_entry_count: usize, + iouring_tun_rx_count: u32, + iouring_udp_rx_count: u32, + iouring_tx_count: u32, + ) { + let config = ServerConfig { + connection_type, + auth: Auth, + server_cert: "".into(), + server_key: "".into(), + tun_config: Default::default(), + ip_pool: "10.0.0.0/8".parse().unwrap(), + ip_map: Default::default(), + tun_ip: None, + lightway_server_ip: "1.1.1.1".parse().unwrap(), + lightway_client_ip: "2.2.2.2".parse().unwrap(), + lightway_dns_ip: "3.3.3.3".parse().unwrap(), + enable_pqc: false, + iouring_entry_count, + iouring_sqpoll_idle_time: Default::default(), + iouring_tun_rx_count, + iouring_tun_blocking: false, + iouring_udp_rx_count, + iouring_tx_count, + key_update_interval: Default::default(), + inside_plugins: Default::default(), + outside_plugins: Default::default(), + bind_address: "0.0.0.0:0".parse().unwrap(), + proxy_protocol: false, + udp_buffer_size: Default::default(), + tcp_buffer_size: Default::default(), + }; + config.validate().unwrap(); + } +} diff --git a/lightway-server/src/main.rs b/lightway-server/src/main.rs index 649f2ed3..7fde294c 100644 --- a/lightway-server/src/main.rs +++ b/lightway-server/src/main.rs @@ -130,15 +130,19 @@ async fn main() -> Result<()> { lightway_client_ip: config.lightway_client_ip, lightway_dns_ip: config.lightway_dns_ip, enable_pqc: config.enable_pqc, - enable_tun_iouring: config.enable_tun_iouring, iouring_entry_count: config.iouring_entry_count, iouring_sqpoll_idle_time: config.iouring_sqpoll_idle_time.into(), + iouring_tun_rx_count: config.iouring_tun_rx_count, + iouring_tun_blocking: config.iouring_tun_blocking, + iouring_udp_rx_count: 
config.iouring_udp_rx_count, + iouring_tx_count: config.iouring_tx_count, key_update_interval: config.key_update_interval.into(), inside_plugins: Default::default(), outside_plugins: Default::default(), bind_address: config.bind_address, proxy_protocol: config.proxy_protocol, udp_buffer_size: config.udp_buffer_size, + tcp_buffer_size: config.tcp_buffer_size, }; server(config).await diff --git a/tests/Earthfile b/tests/Earthfile index 19669363..fb01446d 100644 --- a/tests/Earthfile +++ b/tests/Earthfile @@ -83,13 +83,13 @@ run-udp-floating-ip-test: run-udp-pmtud-test: DO +TEST --MODE=udp --SERVER_PORT=27690 --CLIENT_EXTRA_ARGS="--enable-pmtud" -# run-udp-iouring-test runs e2e test using UDP and default cipher with io-uring enabled +# run-udp-iouring-test runs e2e test using UDP and default cipher with client io-uring enabled run-udp-iouring-test: - DO +TEST --MODE=udp --SERVER_PORT=27690 --SERVER_EXTRA_ARGS="--enable-tun-iouring" --CLIENT_EXTRA_ARGS="--enable-tun-iouring" + DO +TEST --MODE=udp --SERVER_PORT=27690 --CLIENT_EXTRA_ARGS="--enable-tun-iouring" -# run-tcp-iouring-test runs e2e test using TCP and default cipher with io-uring enabled +# run-tcp-iouring-test runs e2e test using TCP and default cipher with client io-uring enabled run-tcp-iouring-test: - DO +TEST --MODE=tcp --SERVER_PORT=27690 --SERVER_EXTRA_ARGS="--enable-tun-iouring" --CLIENT_EXTRA_ARGS="--enable-tun-iouring" + DO +TEST --MODE=tcp --SERVER_PORT=27690 --CLIENT_EXTRA_ARGS="--enable-tun-iouring" # run-udp-min-inside-mtu-test runs e2e test of UDP with client using smallest valid inside MTU run-udp-min-inside-mtu-test: @@ -121,7 +121,7 @@ run-tcp-keepalive-test: # run-udp-single-threaded-test runs e2e test of UDP with server and client using a single Tokio worker thread run-udp-single-threaded-test: - DO +TEST --MODE=udp --SERVER_PORT=27690 --SERVER_TOKIO_WORKER_THREADS=1 --SERVER_EXTRA_ARGS="--enable-tun-iouring" --CLIENT_TOKIO_WORKER_THREADS=1 --CLIENT_EXTRA_ARGS="--keepalive-interval=2s 
--keepalive-timeout=6s --enable-tun-iouring --enable-pmtud" + DO +TEST --MODE=udp --SERVER_PORT=27690 --SERVER_TOKIO_WORKER_THREADS=1 --CLIENT_TOKIO_WORKER_THREADS=1 --CLIENT_EXTRA_ARGS="--keepalive-interval=2s --keepalive-timeout=6s --enable-tun-iouring --enable-pmtud" # run-tcp-single-threaded-test runs e2e test of TCP with server and client using a single Tokio worker thread run-tcp-single-threaded-test: