diff --git a/tokio/src/runtime/handle.rs b/tokio/src/runtime/handle.rs index 5030adc6d92..9b3f46f172b 100644 --- a/tokio/src/runtime/handle.rs +++ b/tokio/src/runtime/handle.rs @@ -488,6 +488,24 @@ impl Handle { } } +/// TEST PURPOSE RELATED TO PR #7773 +#[cfg(feature = "full")] +impl Handle { + /// Returns the number of pending registrations (test-only, not part of public API) + #[doc(hidden)] + #[allow(unreachable_pub)] + pub fn io_pending_registration_count(&self) -> usize { + self.inner.driver().io().pending_registration_count() + } + + /// Returns the total number of registrations in the main list (test-only, not part of public API) + #[doc(hidden)] + #[allow(unreachable_pub)] + pub fn io_total_registration_count(&self) -> usize { + self.inner.driver().io().total_registration_count() + } +} + impl std::panic::UnwindSafe for Handle {} impl std::panic::RefUnwindSafe for Handle {} diff --git a/tokio/src/runtime/io/driver.rs b/tokio/src/runtime/io/driver.rs index 04540cf2b13..f16164381c4 100644 --- a/tokio/src/runtime/io/driver.rs +++ b/tokio/src/runtime/io/driver.rs @@ -296,7 +296,8 @@ impl Handle { source: &mut impl Source, ) -> io::Result<()> { // Deregister the source with the OS poller **first** - self.registry.deregister(source)?; + // Cleanup ALWAYS happens + let os_result = self.registry.deregister(source); if self .registrations @@ -307,6 +308,8 @@ impl Handle { self.metrics.dec_fd_count(); + os_result?; // Return error after cleanup + Ok(()) } @@ -317,6 +320,24 @@ impl Handle { } } +/// TEST PURPOSE RELATED TO PR #7773 +#[cfg(feature = "full")] +impl Handle { + /// Returns the number of pending registrations (test-only, not part of public API) + #[doc(hidden)] + #[allow(unreachable_pub)] + pub fn pending_registration_count(&self) -> usize { + self.registrations.pending_release_count() + } + /// Returns the total number of registrations in the main list (test-only) + #[doc(hidden)] + #[allow(unreachable_pub)] + pub fn total_registration_count(&self) -> usize { + self.registrations + .total_registration_count(&mut self.synced.lock()) + } +} + impl fmt::Debug for Handle { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "Handle") diff --git a/tokio/src/runtime/io/registration_set.rs b/tokio/src/runtime/io/registration_set.rs index 2796796de93..a44efd226ec 100644 --- a/tokio/src/runtime/io/registration_set.rs +++ b/tokio/src/runtime/io/registration_set.rs @@ -53,6 +53,33 @@ impl RegistrationSet { self.num_pending_release.load(Acquire) != 0 } + /// TEST PURPOSE RELATED TO PR #7773 + #[cfg(feature = "full")] + pub(super) fn pending_release_count(&self) -> usize { + self.num_pending_release.load(Acquire) + } + /// TEST PURPOSE RELATED TO PR #7773 + #[cfg(feature = "full")] + pub(super) fn total_registration_count(&self, synced: &mut Synced) -> usize { + // Count by temporarily draining the list, then restoring it + // This is safe for test purposes + let mut items = Vec::new(); + + // Drain all items + while let Some(item) = synced.registrations.pop_back() { + items.push(item); + } + + let count = items.len(); + + // Restore items in reverse order (since we popped from back) + for item in items.into_iter().rev() { + synced.registrations.push_front(item); + } + + count + } + pub(super) fn allocate(&self, synced: &mut Synced) -> io::Result> { if synced.is_shutdown { return Err(io::Error::new( diff --git a/tokio/tests/io_async_fd.rs b/tokio/tests/io_async_fd.rs index c9a7302d3fc..46f42424eff 100644 --- a/tokio/tests/io_async_fd.rs +++ b/tokio/tests/io_async_fd.rs @@ -947,12 +947,76 @@ async fn try_new() { assert!(Arc::ptr_eq(&original, &returned)); } +/// Regression test for issue #7563 +/// +/// Reproduces the bug where closing fd before dropping AsyncFd causes +/// OS deregister to fail, preventing cleanup and leaking ScheduledIo objects. #[tokio::test] -async fn try_with_interest() { - let original = Arc::new(InvalidSource); +async fn memory_leak_when_fd_closed_before_drop() { + use std::os::unix::io::{AsRawFd, RawFd}; + use std::sync::Arc; + use tokio::io::unix::AsyncFd; + use tokio::runtime::Handle; - let error = AsyncFd::try_with_interest(original.clone(), Interest::READABLE).unwrap_err(); - let (returned, _cause) = error.into_parts(); + use nix::sys::socket::{self, AddressFamily, SockFlag, SockType}; - assert!(Arc::ptr_eq(&original, &returned)); + struct RawFdWrapper { + fd: RawFd, + } + + impl AsRawFd for RawFdWrapper { + fn as_raw_fd(&self) -> RawFd { + self.fd + } + } + + let rt_handle = Handle::current(); + tokio::task::yield_now().await; + let initial_count = rt_handle.io_total_registration_count(); + + const ITERATIONS: usize = 30; + let mut max_count_seen = initial_count; + + for _ in 0..ITERATIONS { + let (fd_a, _fd_b) = socket::socketpair( + AddressFamily::Unix, + SockType::Stream, + None, + SockFlag::empty(), + ) + .expect("socketpair"); + let raw_fd = fd_a.as_raw_fd(); + set_nonblocking(raw_fd); + std::mem::forget(fd_a); + + let afd = Arc::new(RawFdWrapper { fd: raw_fd }); + let async_fd = AsyncFd::new(ArcFd(afd.clone())).unwrap(); + + unsafe { + libc::close(raw_fd); + } + + drop(async_fd); + tokio::task::yield_now().await; + + let current_count = rt_handle.io_total_registration_count(); + max_count_seen = max_count_seen.max(current_count); + } + + tokio::task::yield_now().await; + tokio::time::sleep(Duration::from_millis(100)).await; + + let final_count = rt_handle.io_total_registration_count(); + max_count_seen = max_count_seen.max(final_count); + + assert!( + final_count <= initial_count + 2 && max_count_seen <= initial_count + 2, + "Memory leak detected: final count {} (initial: {}), max seen: {}. \ + With bug, count would be ~{} ({} leaked objects).", + final_count, + initial_count, + max_count_seen, + initial_count + ITERATIONS, + max_count_seen.saturating_sub(initial_count) + ); }