diff --git a/tokio/src/process/unix/pidfd_reaper.rs b/tokio/src/process/unix/pidfd_reaper.rs index b425438f190..5eba0b0eb68 100644 --- a/tokio/src/process/unix/pidfd_reaper.rs +++ b/tokio/src/process/unix/pidfd_reaper.rs @@ -4,7 +4,6 @@ use crate::{ imp::{orphan::Wait, OrphanQueue}, kill::Kill, }, - util::error::RUNTIME_SHUTTING_DOWN_ERROR, }; use libc::{syscall, SYS_pidfd_open, ENOSYS, PIDFD_NONBLOCK}; @@ -95,45 +94,6 @@ where pidfd: PollEvented, } -fn display_eq(d: impl std::fmt::Display, s: &str) -> bool { - use std::fmt::Write; - - struct FormatEq<'r> { - remainder: &'r str, - unequal: bool, - } - - impl<'r> Write for FormatEq<'r> { - fn write_str(&mut self, s: &str) -> std::fmt::Result { - if !self.unequal { - if let Some(new_remainder) = self.remainder.strip_prefix(s) { - self.remainder = new_remainder; - } else { - self.unequal = true; - } - } - Ok(()) - } - } - - let mut fmt_eq = FormatEq { - remainder: s, - unequal: false, - }; - let _ = write!(fmt_eq, "{d}"); - fmt_eq.remainder.is_empty() && !fmt_eq.unequal -} - -fn is_rt_shutdown_err(err: &io::Error) -> bool { - if let Some(inner) = err.get_ref() { - err.kind() == io::ErrorKind::Other - && inner.source().is_none() - && display_eq(inner, RUNTIME_SHUTTING_DOWN_ERROR) - } else { - false - } -} - impl Future for PidfdReaperInner where W: Wait + Unpin, @@ -150,7 +110,7 @@ where } this.pidfd.registration().clear_readiness(evt); } - Poll::Ready(Err(err)) if is_rt_shutdown_err(&err) => {} + Poll::Ready(Err(err)) if crate::runtime::is_rt_shutdown_err(&err) => {} Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), Poll::Pending => return Poll::Pending, }; diff --git a/tokio/src/runtime/mod.rs b/tokio/src/runtime/mod.rs index 1a341b7b98a..e806c026eb3 100644 --- a/tokio/src/runtime/mod.rs +++ b/tokio/src/runtime/mod.rs @@ -569,7 +569,7 @@ cfg_rt! { pub use handle::{EnterGuard, Handle, TryCurrentError}; mod runtime; - pub use runtime::{Runtime, RuntimeFlavor}; + pub use runtime::{Runtime, RuntimeFlavor, is_rt_shutdown_err}; mod id; pub use id::Id; diff --git a/tokio/src/runtime/runtime.rs b/tokio/src/runtime/runtime.rs index af1711fad38..e967a7e99ab 100644 --- a/tokio/src/runtime/runtime.rs +++ b/tokio/src/runtime/runtime.rs @@ -3,9 +3,11 @@ use crate::runtime::blocking::BlockingPool; use crate::runtime::scheduler::CurrentThread; use crate::runtime::{context, EnterGuard, Handle}; use crate::task::JoinHandle; +use crate::util::error::RUNTIME_SHUTTING_DOWN_ERROR; use crate::util::trace::SpawnMeta; use std::future::Future; +use std::io; use std::mem; use std::time::Duration; @@ -513,3 +515,70 @@ impl Drop for Runtime { impl std::panic::UnwindSafe for Runtime {} impl std::panic::RefUnwindSafe for Runtime {} + +fn display_eq(d: impl std::fmt::Display, s: &str) -> bool { + use std::fmt::Write; + + struct FormatEq<'r> { + remainder: &'r str, + unequal: bool, + } + + impl<'r> Write for FormatEq<'r> { + fn write_str(&mut self, s: &str) -> std::fmt::Result { + if !self.unequal { + if let Some(new_remainder) = self.remainder.strip_prefix(s) { + self.remainder = new_remainder; + } else { + self.unequal = true; + } + } + Ok(()) + } + } + + let mut fmt_eq = FormatEq { + remainder: s, + unequal: false, + }; + let _ = write!(fmt_eq, "{d}"); + fmt_eq.remainder.is_empty() && !fmt_eq.unequal +} + +/// Checks whether the given error was emitted by Tokio when shutting down its runtime. +/// +/// # Examples +/// +/// ``` +/// # #[cfg(not(target_family = "wasm"))] +/// # { +/// use tokio::runtime::Runtime; +/// use tokio::net::TcpListener; +/// +/// fn main() { +/// let rt1 = Runtime::new().unwrap(); +/// let rt2 = Runtime::new().unwrap(); +/// +/// let listener = rt1.block_on(async { +/// TcpListener::bind("127.0.0.1:0").await.unwrap() +/// }); +/// +/// drop(rt1); +/// +/// rt2.block_on(async { +/// let res = listener.accept().await; +/// assert!(res.is_err()); +/// assert!(tokio::runtime::is_rt_shutdown_err(res.as_ref().unwrap_err())); +/// }); +/// } +/// # } +/// ``` +pub fn is_rt_shutdown_err(err: &io::Error) -> bool { + if let Some(inner) = err.get_ref() { + err.kind() == io::ErrorKind::Other + && inner.source().is_none() + && display_eq(inner, RUNTIME_SHUTTING_DOWN_ERROR) + } else { + false + } +} diff --git a/tokio/tests/rt_shutdown_err.rs b/tokio/tests/rt_shutdown_err.rs new file mode 100644 index 00000000000..b92d8eb24f1 --- /dev/null +++ b/tokio/tests/rt_shutdown_err.rs @@ -0,0 +1,82 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] +#![cfg(not(miri))] // No socket in miri. + +use std::io; +use tokio::net::TcpListener; +use tokio::runtime::Builder; + +fn rt() -> tokio::runtime::Runtime { + Builder::new_current_thread().enable_all().build().unwrap() +} + +#[test] +fn test_is_rt_shutdown_err() { + let rt1 = rt(); + let rt2 = rt(); + + let listener = rt1.block_on(async { TcpListener::bind("127.0.0.1:0").await.unwrap() }); + + drop(rt1); + + rt2.block_on(async { + let res = listener.accept().await; + assert!(res.is_err()); + let err = res.as_ref().unwrap_err(); + assert!(tokio::runtime::is_rt_shutdown_err(err)); + }); +} + +#[test] +fn test_is_not_rt_shutdown_err() { + let err = io::Error::new(io::ErrorKind::Other, "some other error"); + assert!(!tokio::runtime::is_rt_shutdown_err(&err)); + + let err = io::Error::new(io::ErrorKind::NotFound, "not found"); + assert!(!tokio::runtime::is_rt_shutdown_err(&err)); +} + +#[test] +#[cfg_attr(panic = "abort", ignore)] +fn test_join_error_panic() { + let rt = rt(); + let handle = rt.spawn(async { + panic!("oops"); + }); + + let join_err = rt.block_on(handle).unwrap_err(); + let io_err: io::Error = join_err.into(); + assert!(!tokio::runtime::is_rt_shutdown_err(&io_err)); +} + +#[test] +fn test_join_error_cancelled() { + let rt = rt(); + let handle = rt.spawn(async { + std::future::pending::<()>().await; + }); + handle.abort(); + let join_err = rt.block_on(handle).unwrap_err(); + let io_err: io::Error = join_err.into(); + assert!(!tokio::runtime::is_rt_shutdown_err(&io_err)); +} + +#[test] +fn test_other_error_kinds_and_strings() { + // TimedOut + let err = io::Error::new(io::ErrorKind::TimedOut, "timed out"); + assert!(!tokio::runtime::is_rt_shutdown_err(&err)); + + // Interrupted + let err = io::Error::from(io::ErrorKind::Interrupted); + assert!(!tokio::runtime::is_rt_shutdown_err(&err)); + + // String that contains the shutdown message but has a prefix/suffix + let msg = "A Tokio 1.x context was found, but it is being shutdown. (extra info)"; + let err = io::Error::new(io::ErrorKind::Other, msg); + assert!(!tokio::runtime::is_rt_shutdown_err(&err)); + + let msg = "Error: A Tokio 1.x context was found, but it is being shutdown."; + let err = io::Error::new(io::ErrorKind::Other, msg); + assert!(!tokio::runtime::is_rt_shutdown_err(&err)); +}