Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ windows = { version = "0.61.1", optional = true }
[dev-dependencies]
remoteprocess = "0.5.0"
tokio = { version = "1.38.2", features = ["io-util", "macros", "process", "rt", "rt-multi-thread", "time"] }
tracing-subscriber = "0.3.19"

[features]
default = ["creation-flags", "job-object", "kill-on-drop", "process-group", "process-session", "tracing"]
Expand Down
107 changes: 64 additions & 43 deletions src/generic_wrap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ macro_rules! Wrap {
#[derive(Debug)]
pub struct $name {
command: $command,
wrappers: ::indexmap::IndexMap<::std::any::TypeId, Box<dyn $wrapper>>,
wrappers: ::indexmap::IndexMap<::std::any::TypeId, ::std::cell::RefCell<::std::option::Option<Box<dyn $wrapper>>>>,
}

impl $name {
Expand Down Expand Up @@ -62,70 +62,78 @@ macro_rules! Wrap {
/// Returns `&mut self` for chaining.
pub fn wrap<W: $wrapper + 'static>(&mut self, wrapper: W) -> &mut Self {
let typeid = ::std::any::TypeId::of::<W>();
let mut wrapper = Some(Box::new(wrapper));
let boxed: Box<(dyn $wrapper + 'static)> = Box::new(wrapper);
let mut wrapper = Some(::std::cell::RefCell::new(Some(boxed)));
let extant = self
.wrappers
.entry(typeid)
.or_insert_with(|| wrapper.take().unwrap());
.or_insert_with(|| {
#[cfg(feature = "tracing")]
::tracing::debug!(id=?typeid, "wrap");
wrapper.take().unwrap()
});
if let Some(wrapper) = wrapper {
extant.extend(wrapper);
#[cfg(feature = "tracing")]
::tracing::debug!(id=?typeid, "wrap extend");
// UNWRAPs: we've just created those so we know they're Somes
extant.get_mut().as_mut().unwrap().extend(wrapper.into_inner().unwrap());
}

self
}

// poor man's try..finally block
#[inline]
fn spawn_inner(
&self,
command: &mut $command,
wrappers: &mut ::indexmap::IndexMap<::std::any::TypeId, Box<dyn $wrapper>>,
) -> ::std::io::Result<Box<dyn $childer>> {
for (id, wrapper) in wrappers.iter_mut() {
/// Spawn the command, returning a `Child` that can be interacted with.
///
/// In order, this runs all the `pre_spawn` hooks, then spawns the command, then runs
/// all the `post_spawn` hooks, then stacks all the `wrap_child`s. As it returns a boxed
/// trait object, only the methods from the trait are available directly; however you
/// may downcast to the concrete type of the last applied wrapper if you need to.
pub fn spawn(&mut self) -> ::std::io::Result<Box<dyn $childer>> {
// for each loop, we extract the active wrapper from its cell
// so we can use it mutably independently from the self borrow
// then we re-insert it; this happens regardless of the result

for (id, cell) in &self.wrappers {
#[cfg(feature = "tracing")]
::tracing::debug!(?id, "pre_spawn");
wrapper.pre_spawn(command, self)?;
if let Some(mut wrapper) = cell.take() {
let mut command = ::std::mem::replace(&mut self.command, <$command>::new(""));
let ret = wrapper.pre_spawn(&mut command, self);
self.command = command;
cell.replace(Some(wrapper));
ret?;
}
}

let mut child = command.spawn()?;
for (id, wrapper) in wrappers.iter_mut() {
let mut child = self.command.spawn()?;
for (id, cell) in &self.wrappers {
#[cfg(feature = "tracing")]
::tracing::debug!(?id, "post_spawn");
wrapper.post_spawn(&mut child, self)?;
if let Some(mut wrapper) = cell.take() {
let ret = wrapper.post_spawn(&mut child, self);
cell.replace(Some(wrapper));
ret?;
}
}

let mut child = Box::new(
#[allow(clippy::redundant_closure_call)]
$first_child_wrapper(child),
) as Box<dyn $childer>;

for (id, wrapper) in wrappers.iter_mut() {
for (id, cell) in &self.wrappers {
#[cfg(feature = "tracing")]
::tracing::debug!(?id, "wrap_child");
child = wrapper.wrap_child(child, self)?;
if let Some(mut wrapper) = cell.take() {
let ret = wrapper.wrap_child(child, self);
cell.replace(Some(wrapper));
child = ret?;
}
}

Ok(child)
}

/// Spawn the command, returning a `Child` that can be interacted with.
///
/// In order, this runs all the `pre_spawn` hooks, then spawns the command, then runs
/// all the `post_spawn` hooks, then stacks all the `wrap_child`s. As it returns a boxed
/// trait object, only the methods from the trait are available directly; however you
/// may downcast to the concrete type of the last applied wrapper if you need to.
pub fn spawn(&mut self) -> ::std::io::Result<Box<dyn $childer>> {
let mut command = ::std::mem::replace(&mut self.command, <$command>::new(""));
let mut wrappers = ::std::mem::take(&mut self.wrappers);

let res = self.spawn_inner(&mut command, &mut wrappers);

self.command = command;
self.wrappers = wrappers;

res
}

/// Check if a wrapper of a given type is present.
pub fn has_wrap<W: $wrapper + 'static>(&self) -> bool {
let typeid = ::std::any::TypeId::of::<W>();
Expand All @@ -139,14 +147,27 @@ macro_rules! Wrap {
///
/// Returns `None` if the wrapper is not present. To merely check if a wrapper is
/// present, use `has_wrap` instead.
pub fn get_wrap<W: $wrapper + 'static>(&self) -> Option<&W> {
///
/// Note that calling `.get_wrap()` to retrieve the current wrapper while within that
/// wrapper's hooks will return `None`. As this is useless (you should trivially have
/// access to the current wrapper), this is not considered a bug.
pub fn get_wrap<W: $wrapper + 'static>(&self) -> Option<::std::cell::Ref<W>> {
let typeid = ::std::any::TypeId::of::<W>();
self.wrappers.get(&typeid).map(|w| {
let w_any = w as &dyn ::std::any::Any;
w_any
.downcast_ref()
.expect("downcasting is guaranteed to succeed due to wrap()'s internals")
})
#[cfg(feature = "tracing")]
::tracing::debug!(id=?typeid, "get wrap");
self.wrappers.get(&typeid)
.and_then(|cell| cell.try_borrow().ok())
.and_then(|borrow| ::std::cell::Ref::filter_map(
borrow, |opt| opt.as_ref().map(|w: &Box<dyn $wrapper>| {
#[cfg(feature = "tracing")]
::tracing::debug!(id=?typeid, "got wrap");
let w_any = w as &dyn ::std::any::Any;
w_any
.downcast_ref()
.expect("downcasting is guaranteed to succeed due to wrap()'s internals")
})
).ok())

}
}

Expand Down
2 changes: 1 addition & 1 deletion src/std/job_object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ impl StdCommandWrapper for JobObject {
fn pre_spawn(&mut self, command: &mut Command, core: &StdCommandWrap) -> Result<()> {
let mut flags = CREATE_SUSPENDED;
#[cfg(feature = "creation-flags")]
if let Some(CreationFlags(user_flags)) = core.get_wrap::<CreationFlags>() {
if let Some(CreationFlags(user_flags)) = core.get_wrap::<CreationFlags>().as_deref() {
flags |= *user_flags;
}

Expand Down
2 changes: 1 addition & 1 deletion src/tokio/job_object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ impl TokioCommandWrapper for JobObject {
fn pre_spawn(&mut self, command: &mut Command, core: &TokioCommandWrap) -> Result<()> {
let mut flags = CREATE_SUSPENDED;
#[cfg(feature = "creation-flags")]
if let Some(CreationFlags(user_flags)) = core.get_wrap::<CreationFlags>() {
if let Some(CreationFlags(user_flags)) = core.get_wrap::<CreationFlags>().as_deref() {
flags |= *user_flags;
}

Expand Down
9 changes: 8 additions & 1 deletion tests/std_windows/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
mod prelude {
pub use std::{
io::{Read, Result, Write},
process::Stdio,
process::{Command, Stdio},
thread::sleep,
time::Duration,
};
Expand All @@ -11,10 +11,17 @@ mod prelude {
pub const DIE_TIME: Duration = Duration::from_millis(1000);
}

fn init() {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.init();
}

mod id_same_as_inner;
mod inner_read_stdout;
mod into_inner_write_stdin;
mod kill_and_try_wait;
mod read_creation_flags;
mod try_wait_after_die;
mod wait_after_die;
mod wait_twice;
Expand Down
41 changes: 41 additions & 0 deletions tests/std_windows/read_creation_flags.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use std::sync::{
atomic::{AtomicU32, Ordering},
Arc,
};

use windows::Win32::System::Threading::CREATE_NO_WINDOW;

use super::prelude::*;

#[derive(Clone, Debug, Default)]
pub struct FlagSpy {
pub flags: Arc<AtomicU32>,
}

impl StdCommandWrapper for FlagSpy {
fn pre_spawn(&mut self, _command: &mut Command, core: &StdCommandWrap) -> Result<()> {
#[cfg(feature = "creation-flags")]
if let Some(CreationFlags(user_flags)) = core.get_wrap::<CreationFlags>().as_deref() {
self.flags.store(user_flags.0, Ordering::Relaxed);
}

Ok(())
}
}

#[test]
fn retrieve_flags() -> Result<()> {
super::init();

let spy = FlagSpy::default();
let _ = StdCommandWrap::with_new("powershell.exe", |command| {
command.arg("/C").arg("echo hello").stdout(Stdio::piped());
})
.wrap(CreationFlags(CREATE_NO_WINDOW))
.wrap(spy.clone())
.spawn()?;

assert_eq!(spy.flags.load(Ordering::Relaxed), CREATE_NO_WINDOW.0);

Ok(())
}
Loading