Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions hugr-passes/src/composable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,8 @@ mod test {

use crate::const_fold::{ConstFoldError, ConstantFoldPass};
use crate::dead_code::DeadCodeElimError;
use crate::untuple::{UntupleRecursive, UntupleResult};
use crate::{DeadCodeElimPass, ReplaceTypes, UntuplePass};
use crate::untuple::UntupleResult;
use crate::{DeadCodeElimPass, PassScope, ReplaceTypes, UntuplePass};

use super::{ComposablePass, IfThen, ValidatePassError, ValidatingPass, validate_if_test};

Expand Down Expand Up @@ -416,7 +416,7 @@ mod test {
fb.finish_hugr_with_outputs(untup.outputs()).unwrap()
};

let untup = UntuplePass::new(UntupleRecursive::Recursive);
let untup = UntuplePass::new_scoped(PassScope::EntrypointRecursive);
{
// Change usize_t to INT_TYPES[6], and if that did anything (it will!), then Untuple
let mut repl = ReplaceTypes::default();
Expand Down
199 changes: 158 additions & 41 deletions hugr-passes/src/dead_funcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
use std::collections::HashSet;

use hugr_core::{
HugrView, Node,
HugrView, Node, Visibility,
hugr::hugrmut::HugrMut,
module_graph::{ModuleGraph, StaticNode},
ops::{OpTag, OpTrait},
};
use itertools::Either;
use petgraph::visit::{Dfs, Walker};

use crate::{
ComposablePass,
ComposablePass, PassScope,
composable::{ValidatePassError, validate_if_test},
};

Expand Down Expand Up @@ -48,41 +49,93 @@ fn reachable_funcs<'a, H: HugrView>(
})
}

#[derive(Debug, Clone, Default)]
#[derive(Debug, Clone)]
/// A configuration for the Dead Function Removal pass.
pub struct RemoveDeadFuncsPass {
entry_points: Vec<Node>,
entry_points: Either<Vec<Node>, PassScope>,
}

impl Default for RemoveDeadFuncsPass {
fn default() -> Self {
Self {
entry_points: Either::Left(Vec::new()),
}
}
}

impl RemoveDeadFuncsPass {
#[deprecated(note = "Use RemoveDeadFuncsPass::with_scope")]
/// Adds new entry points - these must be [`FuncDefn`] nodes
/// that are children of the [`Module`] at the root of the Hugr.
///
/// Overrides any [PassScope] set by a call to [Self::with_scope].
///
/// [`FuncDefn`]: hugr_core::ops::OpType::FuncDefn
/// [`Module`]: hugr_core::ops::OpType::Module
pub fn with_module_entry_points(
mut self,
entry_points: impl IntoIterator<Item = Node>,
) -> Self {
self.entry_points.extend(entry_points);
let v = match self.entry_points {
Either::Left(ref mut v) => v,
Either::Right(_) => {
self.entry_points = Either::Left(Vec::new());
self.entry_points.as_mut().unwrap_left()
}
};
v.extend(entry_points);
self
}
}

impl<H: HugrMut<Node = Node>> ComposablePass<H> for RemoveDeadFuncsPass {
type Error = RemoveDeadFuncsError;
type Result = ();

/// Overrides any entrypoints set by a call to [Self::with_module_entry_points].
fn with_scope(mut self, scope: &PassScope) -> Self {
self.entry_points = Either::Right(scope.clone());
self
}

fn run(&self, hugr: &mut H) -> Result<(), RemoveDeadFuncsError> {
let mut entry_points = Vec::new();
for &n in self.entry_points.iter() {
if !hugr.get_optype(n).is_func_defn() {
return Err(RemoveDeadFuncsError::InvalidEntryPoint { node: n });
match &self.entry_points {
Either::Left(ep) => {
for &n in ep {
if !hugr.get_optype(n).is_func_defn() {
return Err(RemoveDeadFuncsError::InvalidEntryPoint { node: n });
}
debug_assert_eq!(hugr.get_parent(n), Some(hugr.module_root()));
entry_points.push(n);
}
if hugr.entrypoint() != hugr.module_root() {
entry_points.push(hugr.entrypoint())
}
}
Either::Right(PassScope::EntrypointFlat | PassScope::EntrypointRecursive) => {
// If the entrypoint is the module root, not allowed to touch anything.
// Otherwise, we must keep the entrypoint (and can touch only inside it).
return Ok(());
}
Either::Right(PassScope::PreserveAll) => return Ok(()), // Optimize whole Hugr but keep all functions
Either::Right(PassScope::PreservePublic) => {
for n in hugr.children(hugr.module_root()) {
if let Some(fd) = hugr.get_optype(n).as_func_defn()
&& fd.visibility() == &Visibility::Public {
entry_points.push(n);
}
}
if hugr.entrypoint() != hugr.module_root() {
entry_points.push(hugr.entrypoint());
}
}
Either::Right(PassScope::PreserveEntrypoint) => {
if hugr.entrypoint() == hugr.module_root() {
return Ok(());
};
entry_points.push(hugr.entrypoint())
}
debug_assert_eq!(hugr.get_parent(n), Some(hugr.module_root()));
entry_points.push(n);
}
if hugr.entrypoint() != hugr.module_root() {
entry_points.push(hugr.entrypoint())
}

let mut reachable =
Expand Down Expand Up @@ -124,6 +177,8 @@ impl<H: HugrMut<Node = Node>> ComposablePass<H> for RemoveDeadFuncsPass {
/// [`FuncDefn`]: hugr_core::ops::OpType::FuncDefn
/// [`LoadFunction`]: hugr_core::ops::OpType::LoadFunction
/// [`Module`]: hugr_core::ops::OpType::Module
#[deprecated]
#[expect(deprecated)]
pub fn remove_dead_funcs(
h: &mut impl HugrMut<Node = Node>,
entry_points: impl IntoIterator<Item = Node>,
Expand All @@ -134,56 +189,90 @@ pub fn remove_dead_funcs(
)
}

/// Deletes from the Hugr any functions that are not used by either [`Call`] or
/// [`LoadFunction`] nodes in reachable parts, according to the specified [PassScope].
// TODO: after removing the deprecated `remove_dead_funcs`, rename this over it
pub fn remove_dead_funcs_scoped<H: HugrMut<Node = Node>>(
h: &mut H,
scope: &PassScope,
) -> Result<(), ValidatePassError<Node, RemoveDeadFuncsError>> {
validate_if_test(
<RemoveDeadFuncsPass as ComposablePass<H>>::with_scope(
RemoveDeadFuncsPass::default(),
scope,
),
h,
)
}

#[cfg(test)]
mod test {
use std::collections::HashMap;

use hugr_core::ops::handle::NodeHandle;
use hugr_core::{Hugr, Visibility};
use itertools::Itertools;
use rstest::rstest;

use hugr_core::builder::{Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder};
use hugr_core::hugr::hugrmut::HugrMut;
use hugr_core::{HugrView, extension::prelude::usize_t, types::Signature};

use super::remove_dead_funcs;
use crate::PassScope;
use crate::dead_funcs::remove_dead_funcs_scoped;

fn hugr(use_entrypoint: bool) -> Hugr {
let mut hb = ModuleBuilder::new();
let o2 = hb
.define_function("from_pub", Signature::new_endo(usize_t()))
.unwrap();
let o2inp = o2.input_wires();
let o2 = o2.finish_with_outputs(o2inp).unwrap();
let mut o1 = hb
.define_function_vis(
"pubfunc",
Signature::new_endo(usize_t()),
Visibility::Public,
)
.unwrap();

let o1c = o1.call(o2.handle(), &[], o1.input_wires()).unwrap();
o1.finish_with_outputs(o1c.outputs()).unwrap();

let fm = hb
.define_function("from_main", Signature::new_endo(usize_t()))
.unwrap();
let f_inp = fm.input_wires();
let fm = fm.finish_with_outputs(f_inp).unwrap();
let mut m = hb
.define_function("main", Signature::new_endo(usize_t()))
.unwrap();
let m_in = m.input_wires();
let mut dfb = m.dfg_builder(Signature::new_endo(usize_t()), m_in).unwrap();
let c = dfb.call(fm.handle(), &[], dfb.input_wires()).unwrap();
let dfg = dfb.finish_with_outputs(c.outputs()).unwrap();
m.finish_with_outputs(dfg.outputs()).unwrap();
let mut h = hb.finish_hugr().unwrap();
if use_entrypoint {
h.set_entrypoint(dfg.node());
}
h
}

#[rstest]
#[case(false, [], vec![])] // No entry_points removes everything!
#[case(true, [], vec!["from_main", "main"])]
#[case(false, ["main"], vec!["from_main", "main"])]
#[case(false, ["from_main"], vec!["from_main"])]
#[case(false, ["other1"], vec!["other1", "other2"])]
#[case(true, ["other2"], vec!["from_main", "main", "other2"])]
#[case(false, ["other1", "other2"], vec!["other1", "other2"])]
#[case(false, ["pubfunc"], vec!["from_pub", "pubfunc"])]
#[case(true, ["from_pub"], vec!["from_main", "from_pub", "main"])]
#[case(false, ["from_pub", "pubfunc"], vec!["from_pub", "pubfunc"])]
fn remove_dead_funcs_entry_points(
#[case] use_hugr_entrypoint: bool,
#[case] entry_points: impl IntoIterator<Item = &'static str>,
#[case] retained_funcs: Vec<&'static str>,
) -> Result<(), Box<dyn std::error::Error>> {
let mut hb = ModuleBuilder::new();
let o2 = hb.define_function("other2", Signature::new_endo(usize_t()))?;
let o2inp = o2.input_wires();
let o2 = o2.finish_with_outputs(o2inp)?;
let mut o1 = hb.define_function("other1", Signature::new_endo(usize_t()))?;

let o1c = o1.call(o2.handle(), &[], o1.input_wires())?;
o1.finish_with_outputs(o1c.outputs())?;

let fm = hb.define_function("from_main", Signature::new_endo(usize_t()))?;
let f_inp = fm.input_wires();
let fm = fm.finish_with_outputs(f_inp)?;
let mut m = hb.define_function("main", Signature::new_endo(usize_t()))?;
let m_in = m.input_wires();
let mut dfg = m.dfg_builder(Signature::new_endo(usize_t()), m_in)?;
let c = dfg.call(fm.handle(), &[], dfg.input_wires())?;
let dfg = dfg.finish_with_outputs(c.outputs()).unwrap();
m.finish_with_outputs(dfg.outputs())?;

let mut hugr = hb.finish_hugr()?;
if use_hugr_entrypoint {
hugr.set_entrypoint(dfg.node());
}
let mut hugr = hugr(use_hugr_entrypoint);

let avail_funcs = hugr
.children(hugr.module_root())
Expand All @@ -194,7 +283,8 @@ mod test {
})
.collect::<HashMap<_, _>>();

remove_dead_funcs(
#[expect(deprecated)]
super::remove_dead_funcs(
&mut hugr,
entry_points
.into_iter()
Expand All @@ -215,4 +305,31 @@ mod test {
assert_eq!(remaining_funcs, retained_funcs);
Ok(())
}

#[rstest]
#[case(PassScope::PreserveAll, false, vec!["from_main", "from_pub", "main", "pubfunc"])]
#[case(PassScope::EntrypointFlat, true, vec!["from_main", "from_pub", "main", "pubfunc"])]
#[case(PassScope::EntrypointRecursive, false, vec!["from_main", "from_pub", "main", "pubfunc"])]
#[case(PassScope::PreservePublic, true, vec!["from_main", "from_pub", "main", "pubfunc"])]
#[case(PassScope::PreservePublic, false, vec!["from_pub", "pubfunc"])]
#[case(PassScope::PreserveEntrypoint, true, vec!["from_main", "main"])]
fn remove_dead_funcs_scope(
#[case] scope: PassScope,
#[case] use_entrypoint: bool,
#[case] retained_funcs: Vec<&'static str>,
) {
let mut hugr = hugr(use_entrypoint);
remove_dead_funcs_scoped(&mut hugr, &scope).unwrap();

let remaining_funcs = hugr
.nodes()
.filter_map(|n| {
hugr.get_optype(n)
.as_func_defn()
.map(|fd| fd.func_name().as_str())
})
.sorted()
.collect_vec();
assert_eq!(remaining_funcs, retained_funcs);
}
}
5 changes: 4 additions & 1 deletion hugr-passes/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ pub use composable::{ComposablePass, InScope, PassScope};

// Pass re-exports
pub use dead_code::DeadCodeElimPass;
pub use dead_funcs::{RemoveDeadFuncsError, RemoveDeadFuncsPass, remove_dead_funcs};
#[deprecated(note = "Use RemoveDeadFuncsPass instead")]
#[expect(deprecated)] // Remove together
pub use dead_funcs::remove_dead_funcs;
pub use dead_funcs::{RemoveDeadFuncsError, RemoveDeadFuncsPass, remove_dead_funcs_scoped};
pub use force_order::{force_order, force_order_by_key};
pub use inline_funcs::inline_acyclic;
pub use lower::{lower_ops, replace_many_ops};
Expand Down
21 changes: 13 additions & 8 deletions hugr-passes/src/monomorphize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,13 +281,14 @@ mod test {
HugrBuilder, ModuleBuilder,
};
use hugr_core::extension::prelude::{ConstUsize, UnpackTuple, UnwrapBuilder, usize_t};
use hugr_core::ops::handle::{FuncID, NodeHandle};
use hugr_core::ops::handle::FuncID;
use hugr_core::ops::{CallIndirect, DataflowOpTrait as _, FuncDefn, Tag};
use hugr_core::types::{PolyFuncType, Signature, Type, TypeArg, TypeBound, TypeEnum};
use hugr_core::{Hugr, HugrView, Node};
use hugr_core::{Hugr, HugrView, Node, Visibility};
use rstest::rstest;

use crate::{monomorphize, remove_dead_funcs};
use crate::dead_funcs::remove_dead_funcs_scoped;
use crate::{PassScope, monomorphize};

use super::{is_polymorphic, mangle_name};

Expand Down Expand Up @@ -349,17 +350,21 @@ mod test {
let trip = fb.add_dataflow_op(tag, [elem1, elem2, elem])?;
fb.finish_with_outputs(trip.outputs())?
};
let mn = {
{
let outs = vec![triple_type(usize_t()), triple_type(pair_type(usize_t()))];
let mut fb = mb.define_function("main", Signature::new(usize_t(), outs))?;
let mut fb = mb.define_function_vis(
"main",
Signature::new(usize_t(), outs),
Visibility::Public,
)?;
let [elem] = fb.input_wires_arr();
let [res1] = fb
.call(tr.handle(), &[usize_t().into()], [elem])?
.outputs_arr();
let pair = fb.call(db.handle(), &[usize_t().into()], [elem])?;
let pty = pair_type(usize_t()).into();
let [res2] = fb.call(tr.handle(), &[pty], pair.outputs())?.outputs_arr();
fb.finish_with_outputs([res1, res2])?
fb.finish_with_outputs([res1, res2])?;
};
let mut hugr = mb.finish_hugr()?;
assert_eq!(
Expand Down Expand Up @@ -394,7 +399,7 @@ mod test {
assert_eq!(mono2, mono); // Idempotent

let mut nopoly = mono;
remove_dead_funcs(&mut nopoly, [mn.node()])?;
remove_dead_funcs_scoped(&mut nopoly, &PassScope::PreservePublic)?;
let mut funcs = list_funcs(&nopoly);

assert!(funcs.values().all(|(_, fd)| !is_polymorphic(fd)));
Expand Down Expand Up @@ -621,7 +626,7 @@ mod test {
};

monomorphize(&mut hugr).unwrap();
remove_dead_funcs(&mut hugr, []).unwrap();
remove_dead_funcs_scoped(&mut hugr, &PassScope::PreservePublic).unwrap();

let funcs = list_funcs(&hugr);
assert!(funcs.values().all(|(_, fd)| !is_polymorphic(fd)));
Expand Down
Loading
Loading