diff --git a/hugr-core/src/builder.rs b/hugr-core/src/builder.rs index aa2d949056..10527a29a4 100644 --- a/hugr-core/src/builder.rs +++ b/hugr-core/src/builder.rs @@ -261,8 +261,8 @@ pub(crate) mod test { use super::handle::BuildHandle; use super::{ - BuildError, CFGBuilder, Container, DFGBuilder, Dataflow, DataflowHugr, FuncID, - FunctionBuilder, ModuleBuilder, + BuildError, CFGBuilder, DFGBuilder, Dataflow, DataflowHugr, FuncID, FunctionBuilder, + ModuleBuilder, }; use super::{DataflowSubContainer, HugrBuilder}; diff --git a/hugr-core/src/builder/build_traits.rs b/hugr-core/src/builder/build_traits.rs index 03731bb7da..05a81bb198 100644 --- a/hugr-core/src/builder/build_traits.rs +++ b/hugr-core/src/builder/build_traits.rs @@ -9,7 +9,7 @@ use crate::{Extension, IncomingPort, Node, OutgoingPort}; use std::iter; use std::sync::Arc; -use super::{BuilderWiringError, FunctionBuilder}; +use super::{BuilderWiringError, ModuleBuilder}; use super::{ CircuitBuilder, handle::{BuildHandle, Outputs}, @@ -21,7 +21,7 @@ use crate::{ }; use crate::extension::ExtensionRegistry; -use crate::types::{PolyFuncType, Signature, Type, TypeArg, TypeRow}; +use crate::types::{Signature, Type, TypeArg, TypeRow}; use itertools::Itertools; @@ -82,33 +82,6 @@ pub trait Container { self.add_child_node(constant.into()).into() } - /// Add a [`ops::FuncDefn`] node and returns a builder to define the function - /// body graph. - /// - /// # Errors - /// - /// This function will return an error if there is an error in adding the - /// [`ops::FuncDefn`] node. - fn define_function( - &mut self, - name: impl Into, - signature: impl Into, - ) -> Result, BuildError> { - let signature: PolyFuncType = signature.into(); - let body = signature.body().clone(); - let f_node = self.add_child_node(ops::FuncDefn::new(name, signature)); - - // Add the extensions used by the function types. - self.use_extensions( - body.used_extensions().unwrap_or_else(|e| { - panic!("Build-time signatures should have valid extensions. {e}") - }), - ); - - let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body)?; - Ok(FunctionBuilder::from_dfg_builder(db)) - } - /// Insert a HUGR as a child of the container. fn add_hugr(&mut self, child: Hugr) -> InsertionResult { let parent = self.container_node(); @@ -155,8 +128,19 @@ pub trait Container { } /// Types implementing this trait can be used to build complete HUGRs -/// (with varying root node types) +/// (with varying entrypoint node types) pub trait HugrBuilder: Container { + /// Allows adding definitions to the module root of which + /// this builder is building a part + fn module_root_builder(&mut self) -> ModuleBuilder<&mut Hugr> { + debug_assert!( + self.hugr() + .get_optype(self.hugr().module_root()) + .is_module() + ); + ModuleBuilder(self.hugr_mut()) + } + /// Finish building the HUGR, perform any validation checks and return it. fn finish_hugr(self) -> Result>; } diff --git a/hugr-core/src/builder/circuit.rs b/hugr-core/src/builder/circuit.rs index 58f388f439..28c5c336f7 100644 --- a/hugr-core/src/builder/circuit.rs +++ b/hugr-core/src/builder/circuit.rs @@ -245,7 +245,7 @@ mod test { use cool_asserts::assert_matches; use crate::Extension; - use crate::builder::{Container, HugrBuilder, ModuleBuilder}; + use crate::builder::{HugrBuilder, ModuleBuilder}; use crate::extension::ExtensionId; use crate::extension::prelude::{qb_t, usize_t}; use crate::std_extensions::arithmetic::float_types::ConstF64; diff --git a/hugr-core/src/builder/module.rs b/hugr-core/src/builder/module.rs index 47bba818d2..b57d25c144 100644 --- a/hugr-core/src/builder/module.rs +++ b/hugr-core/src/builder/module.rs @@ -16,13 +16,13 @@ use crate::{Hugr, Node}; use smol_str::SmolStr; /// Builder for a HUGR module. -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Default, Clone, PartialEq)] pub struct ModuleBuilder(pub(super) T); impl + AsRef> Container for ModuleBuilder { #[inline] fn container_node(&self) -> Node { - self.0.as_ref().entrypoint() + self.0.as_ref().module_root() } #[inline] @@ -39,13 +39,7 @@ impl ModuleBuilder { /// Begin building a new module. #[must_use] pub fn new() -> Self { - Self(Default::default()) - } -} - -impl Default for ModuleBuilder { - fn default() -> Self { - Self::new() + Self::default() } } @@ -112,6 +106,33 @@ impl + AsRef> ModuleBuilder { Ok(declare_n.into()) } + /// Add a [`ops::FuncDefn`] node and returns a builder to define the function + /// body graph. + /// + /// # Errors + /// + /// This function will return an error if there is an error in adding the + /// [`ops::FuncDefn`] node. + pub fn define_function( + &mut self, + name: impl Into, + signature: impl Into, + ) -> Result, BuildError> { + let signature: PolyFuncType = signature.into(); + let body = signature.body().clone(); + let f_node = self.add_child_node(ops::FuncDefn::new(name, signature)); + + // Add the extensions used by the function types. + self.use_extensions( + body.used_extensions().unwrap_or_else(|e| { + panic!("Build-time signatures should have valid extensions. {e}") + }), + ); + + let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body)?; + Ok(FunctionBuilder::from_dfg_builder(db)) + } + /// Add a [`crate::ops::OpType::AliasDefn`] node and return a handle to the Alias. /// /// # Errors @@ -208,29 +229,4 @@ mod test { assert_matches!(build_result, Ok(_)); Ok(()) } - - #[test] - fn local_def() -> Result<(), BuildError> { - let build_result = { - let mut module_builder = ModuleBuilder::new(); - - let mut f_build = module_builder.define_function( - "main", - Signature::new(vec![usize_t()], vec![usize_t(), usize_t()]), - )?; - let local_build = f_build.define_function( - "local", - Signature::new(vec![usize_t()], vec![usize_t(), usize_t()]), - )?; - let [wire] = local_build.input_wires_arr(); - let f_id = local_build.finish_with_outputs([wire, wire])?; - - let call = f_build.call(f_id.handle(), &[], f_build.input_wires())?; - - f_build.finish_with_outputs(call.outputs())?; - module_builder.finish_hugr() - }; - assert_matches!(build_result, Ok(_)); - Ok(()) - } } diff --git a/hugr-core/src/hugr/patch/inline_call.rs b/hugr-core/src/hugr/patch/inline_call.rs index b6beb8d459..058b2c2b5a 100644 --- a/hugr-core/src/hugr/patch/inline_call.rs +++ b/hugr-core/src/hugr/patch/inline_call.rs @@ -291,29 +291,32 @@ mod test { fn test_polymorphic() -> Result<(), Box> { let tuple_ty = Type::new_tuple(vec![usize_t(); 2]); let mut fb = FunctionBuilder::new("mkpair", Signature::new(usize_t(), tuple_ty.clone()))?; - let inner = fb.define_function( - "id", - PolyFuncType::new( - [TypeBound::Copyable.into()], - Signature::new_endo(Type::new_var_use(0, TypeBound::Copyable)), - ), - )?; - let inps = inner.input_wires(); - let inner = inner.finish_with_outputs(inps)?; - let call1 = fb.call(inner.handle(), &[usize_t().into()], fb.input_wires())?; + let helper = { + let mut mb = fb.module_root_builder(); + let fb2 = mb.define_function( + "id", + PolyFuncType::new( + [TypeBound::Copyable.into()], + Signature::new_endo(Type::new_var_use(0, TypeBound::Copyable)), + ), + )?; + let inps = fb2.input_wires(); + fb2.finish_with_outputs(inps)? + }; + let call1 = fb.call(helper.handle(), &[usize_t().into()], fb.input_wires())?; let [call1_out] = call1.outputs_arr(); let tup = fb.make_tuple([call1_out, call1_out])?; - let call2 = fb.call(inner.handle(), &[tuple_ty.into()], [tup])?; + let call2 = fb.call(helper.handle(), &[tuple_ty.into()], [tup])?; let mut hugr = fb.finish_hugr_with_outputs(call2.outputs()).unwrap(); assert_eq!( - hugr.output_neighbours(inner.node()).collect::>(), + hugr.output_neighbours(helper.node()).collect::>(), [call1.node(), call2.node()] ); hugr.apply_patch(InlineCall::new(call1.node()))?; assert_eq!( - hugr.output_neighbours(inner.node()).collect::>(), + hugr.output_neighbours(helper.node()).collect::>(), [call2.node()] ); assert!(hugr.get_optype(call1.node()).is_dfg()); diff --git a/hugr-core/src/hugr/patch/simple_replace.rs b/hugr-core/src/hugr/patch/simple_replace.rs index ee47e95f8b..309a60dbf7 100644 --- a/hugr-core/src/hugr/patch/simple_replace.rs +++ b/hugr-core/src/hugr/patch/simple_replace.rs @@ -641,8 +641,8 @@ pub(in crate::hugr::patch) mod test { use crate::builder::test::n_identity; use crate::builder::{ - BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, - HugrBuilder, ModuleBuilder, endo_sig, inout_sig, + BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, + ModuleBuilder, endo_sig, inout_sig, }; use crate::extension::prelude::{bool_t, qb_t}; use crate::hugr::patch::simple_replace::Outcome; diff --git a/hugr-core/src/hugr/validate.rs b/hugr-core/src/hugr/validate.rs index e07df4ed12..7348055a7f 100644 --- a/hugr-core/src/hugr/validate.rs +++ b/hugr-core/src/hugr/validate.rs @@ -443,28 +443,12 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { // // This search could be sped-up with a pre-computed LCA structure, but // for valid Hugrs this search should be very short. - // - // For Value edges only, we record any FuncDefn we went through; if there is - // any such, then that is an error, but we report that only if the dom/ext - // relation was otherwise ok (an error about an edge "entering" some ancestor - // node could be misleading if the source isn't where it's expected) - let mut err_entered_func = None; let from_parent_parent = self.hugr.get_parent(from_parent); for (ancestor, ancestor_parent) in iter::successors(to_parent, |&p| self.hugr.get_parent(p)).tuple_windows() { - if !is_static && self.hugr.get_optype(ancestor).is_func_defn() { - err_entered_func.get_or_insert(InterGraphEdgeError::ValueEdgeIntoFunc { - to, - to_offset, - from, - from_offset, - func: ancestor, - }); - } if ancestor_parent == from_parent { // External edge. - err_entered_func.map_or(Ok(()), Err)?; if !is_static { // Must have an order edge. self.hugr @@ -491,7 +475,7 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { ancestor_parent_op: ancestor_parent_op.clone(), }); } - err_entered_func.map_or(Ok(()), Err)?; + // Check domination let (dominator_tree, node_map) = if let Some(tree) = self.dominators.get(&ancestor_parent) { @@ -758,17 +742,6 @@ pub enum InterGraphEdgeError { to_offset: Port, ty: EdgeKind, }, - /// Inter-Graph edges may not enter into `FuncDefns` unless they are static - #[error( - "Inter-graph Value edges cannot enter into FuncDefns. Inter-graph edge from {from} ({from_offset}) to {to} ({to_offset} enters FuncDefn {func}" - )] - ValueEdgeIntoFunc { - from: N, - from_offset: Port, - to: N, - to_offset: Port, - func: N, - }, /// The grandparent of a dominator inter-graph edge must be a CFG container. #[error( "The grandparent of a dominator inter-graph edge must be a CFG container. Found operation {ancestor_parent_op}. In a dominator inter-graph edge from {from} ({from_offset}) to {to} ({to_offset})." diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index 8ee95cde61..e3f837caca 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -8,7 +8,7 @@ use super::*; use crate::builder::test::closed_dfg_root_hugr; use crate::builder::{ BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, - FunctionBuilder, HugrBuilder, ModuleBuilder, SubContainer, inout_sig, + FunctionBuilder, HugrBuilder, ModuleBuilder, inout_sig, }; use crate::extension::prelude::Noop; use crate::extension::prelude::{bool_t, qb_t, usize_t}; @@ -225,35 +225,6 @@ fn test_ext_edge() { h.validate().unwrap(); } -#[test] -fn no_ext_edge_into_func() -> Result<(), Box> { - let b2b = Signature::new_endo(bool_t()); - let mut h = DFGBuilder::new(Signature::new(bool_t(), Type::new_function(b2b.clone())))?; - let [input] = h.input_wires_arr(); - - let mut dfg = h.dfg_builder(Signature::new(vec![], Type::new_function(b2b.clone())), [])?; - let mut func = dfg.define_function("AndWithOuter", b2b.clone())?; - let [fn_input] = func.input_wires_arr(); - let and_op = func.add_dataflow_op(and_op(), [fn_input, input])?; // 'ext' edge - let func = func.finish_with_outputs(and_op.outputs())?; - let loadfn = dfg.load_func(func.handle(), &[])?; - let dfg = dfg.finish_with_outputs([loadfn])?; - let res = h.finish_hugr_with_outputs(dfg.outputs()); - assert_eq!( - res, - Err(BuildError::InvalidHUGR( - ValidationError::InterGraphEdgeError(InterGraphEdgeError::ValueEdgeIntoFunc { - from: input.node(), - from_offset: input.source().into(), - to: and_op.node(), - to_offset: IncomingPort::from(1).into(), - func: func.node() - }) - )) - ); - Ok(()) -} - #[test] fn test_local_const() { let mut h = closed_dfg_root_hugr(Signature::new_endo(bool_t())); @@ -456,42 +427,31 @@ fn typevars_declared() -> Result<(), Box> { Ok(()) } -/// Test that nested `FuncDefns` cannot use Type Variables declared by enclosing `FuncDefns` +/// Test that `FuncDefns` cannot be nested. #[test] -fn nested_typevars() -> Result<(), Box> { - const OUTER_BOUND: TypeBound = TypeBound::Any; - const INNER_BOUND: TypeBound = TypeBound::Copyable; - fn build(t: Type) -> Result { - let mut outer = FunctionBuilder::new( - "outer", - PolyFuncType::new( - [OUTER_BOUND.into()], - Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]), - ), - )?; - let inner = outer.define_function( - "inner", - PolyFuncType::new([INNER_BOUND.into()], Signature::new_endo(vec![t])), - )?; - let [w] = inner.input_wires_arr(); - inner.finish_with_outputs([w])?; - let [w] = outer.input_wires_arr(); - outer.finish_hugr_with_outputs([w]) - } - assert!(build(Type::new_var_use(0, INNER_BOUND)).is_ok()); - assert_matches!( - build(Type::new_var_use(1, OUTER_BOUND)).unwrap_err(), - BuildError::InvalidHUGR(ValidationError::SignatureError { - cause: SignatureError::FreeTypeVar { - idx: 1, - num_decls: 1 - }, - .. +fn no_nested_funcdefns() -> Result<(), Box> { + let mut outer = FunctionBuilder::new("outer", Signature::new_endo(usize_t()))?; + let inner = outer + .add_hugr({ + let inner = FunctionBuilder::new("inner", Signature::new_endo(bool_t()))?; + let [w] = inner.input_wires_arr(); + inner.finish_hugr_with_outputs([w])? }) + .inserted_entrypoint; + let [w] = outer.input_wires_arr(); + let outer_node = outer.container_node(); + let hugr = outer.finish_hugr_with_outputs([w]); + assert_matches!( + hugr.unwrap_err(), + BuildError::InvalidHUGR(ValidationError::InvalidParentOp { + child_optype: OpType::FuncDefn(_), + allowed_children: OpTag::DataflowChild, + parent_optype: OpType::FuncDefn(_), + child, parent + }) => {assert_eq!(child, inner); + assert_eq!(parent, outer_node); + } ); - assert_matches!(build(Type::new_var_use(0, OUTER_BOUND)).unwrap_err(), - BuildError::InvalidHUGR(ValidationError::SignatureError { cause: SignatureError::TypeVarDoesNotMatchDeclaration { actual, cached }, .. }) => - {assert_eq!(actual, INNER_BOUND.into()); assert_eq!(cached, OUTER_BOUND.into())}); Ok(()) } @@ -610,10 +570,11 @@ fn row_variables() -> Result<(), Box> { // All the wires here are carrying higher-order Function values let [func_arg] = fb.input_wires_arr(); let id_usz = { - let bldr = fb.define_function("id_usz", Signature::new_endo(usize_t()))?; + let mut mb = fb.module_root_builder(); + let bldr = mb.define_function("id_usz", Signature::new_endo(usize_t()))?; let vals = bldr.input_wires(); - let inner_def = bldr.finish_with_outputs(vals)?; - fb.load_func(inner_def.handle(), &[])? + let helper_def = bldr.finish_with_outputs(vals)?; + fb.load_func(helper_def.handle(), &[])? }; let par = e.instantiate_extension_op( "parallel", @@ -624,80 +585,6 @@ fn row_variables() -> Result<(), Box> { Ok(()) } -#[test] -fn test_polymorphic_call() -> Result<(), Box> { - // TODO: This tests a function call that is polymorphic in an extension set. - // Should this be rewritten to be polymorphic in something else or removed? - - let e = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| { - let params: Vec = vec![TypeBound::Any.into(), TypeBound::Any.into()]; - let evaled_fn = Type::new_function(Signature::new( - Type::new_var_use(0, TypeBound::Any), - Type::new_var_use(1, TypeBound::Any), - )); - // Single-input/output version of the higher-order "eval" operation, with extension param. - // Note the extension-delta of the eval node includes that of the input function. - ext.add_op( - "eval".into(), - String::new(), - PolyFuncTypeRV::new( - params.clone(), - Signature::new( - vec![evaled_fn, Type::new_var_use(0, TypeBound::Any)], - Type::new_var_use(1, TypeBound::Any), - ), - ), - extension_ref, - )?; - - Ok(()) - })?; - - fn utou() -> Type { - Type::new_function(Signature::new_endo(usize_t())) - } - - let int_pair = Type::new_tuple(vec![usize_t(); 2]); - // Root DFG: applies a function int-->int to each element of a pair of two ints - let mut d = DFGBuilder::new(inout_sig( - vec![utou(), int_pair.clone()], - vec![int_pair.clone()], - ))?; - // ....by calling a function (int-->int, int_pair) -> int_pair - let f = { - let mut f = d.define_function( - "two_ints", - PolyFuncType::new( - vec![], - Signature::new(vec![utou(), int_pair.clone()], int_pair.clone()), - ), - )?; - let [func, tup] = f.input_wires_arr(); - let mut c = f.conditional_builder( - (vec![vec![usize_t(); 2].into()], tup), - vec![], - vec![usize_t(); 2].into(), - )?; - let mut cc = c.case_builder(0)?; - let [i1, i2] = cc.input_wires_arr(); - let op = e.instantiate_extension_op("eval", vec![usize_t().into(), usize_t().into()])?; - let [f1] = cc.add_dataflow_op(op.clone(), [func, i1])?.outputs_arr(); - let [f2] = cc.add_dataflow_op(op, [func, i2])?.outputs_arr(); - cc.finish_with_outputs([f1, f2])?; - let res = c.finish_sub_container()?.outputs(); - let tup = f.make_tuple(res)?; - f.finish_with_outputs([tup])? - }; - - let [func, tup] = d.input_wires_arr(); - let call = d.call(f.handle(), &[], [func, tup])?; - let h = d.finish_hugr_with_outputs(call.outputs())?; - let call_ty = h.get_optype(call.node()).dataflow_signature().unwrap(); - let exp_fun_ty = Signature::new(vec![utou(), int_pair.clone()], int_pair); - assert_eq!(call_ty.as_ref(), &exp_fun_ty); - Ok(()) -} - #[test] fn test_polymorphic_load() -> Result<(), Box> { let mut m = ModuleBuilder::new(); diff --git a/hugr-core/src/hugr/views/root_checked/dfg.rs b/hugr-core/src/hugr/views/root_checked/dfg.rs index f9681c9ca7..fbff077266 100644 --- a/hugr-core/src/hugr/views/root_checked/dfg.rs +++ b/hugr-core/src/hugr/views/root_checked/dfg.rs @@ -268,7 +268,7 @@ mod test { use super::*; use crate::builder::{ - Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, endo_sig, + DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, endo_sig, }; use crate::extension::prelude::{bool_t, qb_t}; use crate::hugr::views::root_checked::RootChecked; diff --git a/hugr-core/src/hugr/views/tests.rs b/hugr-core/src/hugr/views/tests.rs index 3a83b5b9d8..28d304d2ee 100644 --- a/hugr-core/src/hugr/views/tests.rs +++ b/hugr-core/src/hugr/views/tests.rs @@ -4,7 +4,8 @@ use rstest::{fixture, rstest}; use crate::{ Hugr, HugrView, builder::{ - BuildError, BuildHandle, Container, DFGBuilder, Dataflow, DataflowHugr, endo_sig, inout_sig, + BuildError, BuildHandle, Container, DFGBuilder, Dataflow, DataflowHugr, HugrBuilder, + endo_sig, inout_sig, }, extension::prelude::qb_t, ops::{ @@ -183,8 +184,9 @@ fn test_dataflow_ports_only() { let mut dfg = DFGBuilder::new(endo_sig(bool_t())).unwrap(); let local_and = { - let local_and = dfg - .define_function("and", Signature::new(vec![bool_t(); 2], vec![bool_t()])) + let mut mb = dfg.module_root_builder(); + let local_and = mb + .define_function("and", Signature::new(vec![bool_t(); 2], bool_t())) .unwrap(); let first_input = local_and.input().out_wire(0); local_and.finish_with_outputs([first_input]).unwrap() diff --git a/hugr-core/src/ops/tag.rs b/hugr-core/src/ops/tag.rs index bed7e47370..2834cd94eb 100644 --- a/hugr-core/src/ops/tag.rs +++ b/hugr-core/src/ops/tag.rs @@ -57,6 +57,8 @@ pub enum OpTag { /// A function load operation. LoadFunc, /// A definition that could be at module level or inside a DSG. + /// Note that this means only Constants, as all other defn/decls + /// must be at Module level. ScopedDefn, /// A tail-recursive loop. TailLoop, @@ -112,8 +114,8 @@ impl OpTag { OpTag::Input => &[OpTag::DataflowChild], OpTag::Output => &[OpTag::DataflowChild], OpTag::Function => &[OpTag::ModuleOp, OpTag::StaticOutput], - OpTag::Alias => &[OpTag::ScopedDefn], - OpTag::FuncDefn => &[OpTag::Function, OpTag::ScopedDefn, OpTag::DataflowParent], + OpTag::Alias => &[OpTag::ModuleOp], + OpTag::FuncDefn => &[OpTag::Function, OpTag::DataflowParent], OpTag::DataflowBlock => &[OpTag::ControlFlowChild, OpTag::DataflowParent], OpTag::BasicBlockExit => &[OpTag::ControlFlowChild], OpTag::Case => &[OpTag::Any, OpTag::DataflowParent], diff --git a/hugr-llvm/src/emit/ops/cfg.rs b/hugr-llvm/src/emit/ops/cfg.rs index 744427a1c6..02f42e3d74 100644 --- a/hugr-llvm/src/emit/ops/cfg.rs +++ b/hugr-llvm/src/emit/ops/cfg.rs @@ -217,7 +217,7 @@ impl<'c, 'hugr, H: HugrView> CfgEmitter<'c, 'hugr, H> { #[cfg(test)] mod test { - use hugr_core::builder::{Dataflow, DataflowSubContainer, SubContainer}; + use hugr_core::builder::{Dataflow, DataflowHugr, SubContainer}; use hugr_core::extension::ExtensionRegistry; use hugr_core::extension::prelude::{self, bool_t}; use hugr_core::ops::Value; @@ -279,7 +279,7 @@ mod test { cfg_builder.branch(&b1, 1, &exit_block).unwrap(); let cfg = cfg_builder.finish_sub_container().unwrap(); let [cfg_out] = cfg.outputs_arr(); - builder.finish_with_outputs([cfg_out]).unwrap() + builder.finish_hugr_with_outputs([cfg_out]).unwrap() }); llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions); check_emission!(hugr, llvm_ctx); @@ -395,7 +395,7 @@ mod test { .unwrap() .outputs_arr() }; - builder.finish_with_outputs([outer_cfg_out]).unwrap() + builder.finish_hugr_with_outputs([outer_cfg_out]).unwrap() }); check_emission!(hugr, llvm_ctx); } diff --git a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_cfg_children@llvm14.snap b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_cfg_children@llvm14.snap index 9ea0d09e8d..d673a4b73e 100644 --- a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_cfg_children@llvm14.snap +++ b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_cfg_children@llvm14.snap @@ -13,21 +13,12 @@ entry_block: ; preds = %alloca_block br label %0 0: ; preds = %entry_block - %1 = call i1 @_hl.scoped_func.7() - switch i1 false, label %2 [ + switch i1 false, label %1 [ ] -2: ; preds = %0 - br label %3 +1: ; preds = %0 + br label %2 -3: ; preds = %2 - ret i1 %1 -} - -define i1 @_hl.scoped_func.7() { -alloca_block: - br label %entry_block - -entry_block: ; preds = %alloca_block +2: ; preds = %1 ret i1 false } diff --git a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_cfg_children@pre-mem2reg@llvm14.snap b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_cfg_children@pre-mem2reg@llvm14.snap index c38ac33f4d..025b85a9ac 100644 --- a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_cfg_children@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_cfg_children@pre-mem2reg@llvm14.snap @@ -10,31 +10,30 @@ alloca_block: %"0" = alloca i1, align 1 %"4_0" = alloca i1, align 1 %"01" = alloca i1, align 1 - %"15_0" = alloca {}, align 8 - %"16_0" = alloca i1, align 1 + %"11_0" = alloca {}, align 8 + %"12_0" = alloca i1, align 1 br label %entry_block entry_block: ; preds = %alloca_block br label %0 0: ; preds = %entry_block - %1 = call i1 @_hl.scoped_func.7() - store i1 %1, i1* %"16_0", align 1 - store {} undef, {}* %"15_0", align 1 - %"15_02" = load {}, {}* %"15_0", align 1 - %"16_03" = load i1, i1* %"16_0", align 1 - store {} %"15_02", {}* %"15_0", align 1 - store i1 %"16_03", i1* %"16_0", align 1 - %"15_04" = load {}, {}* %"15_0", align 1 - %"16_05" = load i1, i1* %"16_0", align 1 - switch i1 false, label %2 [ + store i1 false, i1* %"12_0", align 1 + store {} undef, {}* %"11_0", align 1 + %"11_02" = load {}, {}* %"11_0", align 1 + %"12_03" = load i1, i1* %"12_0", align 1 + store {} %"11_02", {}* %"11_0", align 1 + store i1 %"12_03", i1* %"12_0", align 1 + %"11_04" = load {}, {}* %"11_0", align 1 + %"12_05" = load i1, i1* %"12_0", align 1 + switch i1 false, label %1 [ ] -2: ; preds = %0 - store i1 %"16_05", i1* %"01", align 1 - br label %3 +1: ; preds = %0 + store i1 %"12_05", i1* %"01", align 1 + br label %2 -3: ; preds = %2 +2: ; preds = %1 %"06" = load i1, i1* %"01", align 1 store i1 %"06", i1* %"4_0", align 1 %"4_07" = load i1, i1* %"4_0", align 1 @@ -42,17 +41,3 @@ entry_block: ; preds = %alloca_block %"08" = load i1, i1* %"0", align 1 ret i1 %"08" } - -define i1 @_hl.scoped_func.7() { -alloca_block: - %"0" = alloca i1, align 1 - %"10_0" = alloca i1, align 1 - br label %entry_block - -entry_block: ; preds = %alloca_block - store i1 false, i1* %"10_0", align 1 - %"10_01" = load i1, i1* %"10_0", align 1 - store i1 %"10_01", i1* %"0", align 1 - %"02" = load i1, i1* %"0", align 1 - ret i1 %"02" -} diff --git a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_dfg_children@llvm14.snap b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_dfg_children@llvm14.snap deleted file mode 100644 index ea9074b87b..0000000000 --- a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_dfg_children@llvm14.snap +++ /dev/null @@ -1,23 +0,0 @@ ---- -source: hugr-llvm/src/emit/test.rs -expression: mod_str ---- -; ModuleID = 'test_context' -source_filename = "test_context" - -define i1 @_hl.main.1() { -alloca_block: - br label %entry_block - -entry_block: ; preds = %alloca_block - %0 = call i1 @_hl.scoped_func.8() - ret i1 %0 -} - -define i1 @_hl.scoped_func.8() { -alloca_block: - br label %entry_block - -entry_block: ; preds = %alloca_block - ret i1 false -} diff --git a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_dfg_children@pre-mem2reg@llvm14.snap b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_dfg_children@pre-mem2reg@llvm14.snap deleted file mode 100644 index f990db641b..0000000000 --- a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__diverse_dfg_children@pre-mem2reg@llvm14.snap +++ /dev/null @@ -1,38 +0,0 @@ ---- -source: hugr-llvm/src/emit/test.rs -expression: mod_str ---- -; ModuleID = 'test_context' -source_filename = "test_context" - -define i1 @_hl.main.1() { -alloca_block: - %"0" = alloca i1, align 1 - %"4_0" = alloca i1, align 1 - %"12_0" = alloca i1, align 1 - br label %entry_block - -entry_block: ; preds = %alloca_block - %0 = call i1 @_hl.scoped_func.8() - store i1 %0, i1* %"12_0", align 1 - %"12_01" = load i1, i1* %"12_0", align 1 - store i1 %"12_01", i1* %"4_0", align 1 - %"4_02" = load i1, i1* %"4_0", align 1 - store i1 %"4_02", i1* %"0", align 1 - %"03" = load i1, i1* %"0", align 1 - ret i1 %"03" -} - -define i1 @_hl.scoped_func.8() { -alloca_block: - %"0" = alloca i1, align 1 - %"11_0" = alloca i1, align 1 - br label %entry_block - -entry_block: ; preds = %alloca_block - store i1 false, i1* %"11_0", align 1 - %"11_01" = load i1, i1* %"11_0", align 1 - store i1 %"11_01", i1* %"0", align 1 - %"02" = load i1, i1* %"0", align 1 - ret i1 %"02" -} diff --git a/hugr-llvm/src/emit/test.rs b/hugr-llvm/src/emit/test.rs index d5194bf47b..d79ac361cd 100644 --- a/hugr-llvm/src/emit/test.rs +++ b/hugr-llvm/src/emit/test.rs @@ -1,9 +1,7 @@ use crate::types::HugrFuncType; use crate::utils::fat::FatNode; use anyhow::{Result, anyhow}; -use hugr_core::builder::{ - BuildHandle, Container, DFGWrapper, HugrBuilder, ModuleBuilder, SubContainer, -}; +use hugr_core::builder::{BuildHandle, DFGWrapper, FunctionBuilder}; use hugr_core::extension::ExtensionRegistry; use hugr_core::ops::handle::FuncID; use hugr_core::types::TypeRow; @@ -15,7 +13,7 @@ use inkwell::values::GenericValue; use super::EmitHugr; #[allow(clippy::upper_case_acronyms)] -pub type DFGW<'a> = DFGWrapper<&'a mut Hugr, BuildHandle>>; +pub type DFGW = DFGWrapper>>; pub struct SimpleHugrConfig { ins: TypeRow, @@ -131,31 +129,13 @@ impl SimpleHugrConfig { self } - pub fn finish( - self, - make: impl for<'a> FnOnce(DFGW<'a>) -> as SubContainer>::ContainerHandle, - ) -> Hugr { + pub fn finish(self, make: impl FnOnce(DFGW) -> Hugr) -> Hugr { self.finish_with_exts(|builder, _| make(builder)) } - pub fn finish_with_exts( - self, - make: impl for<'a> FnOnce( - DFGW<'a>, - &ExtensionRegistry, - ) -> as SubContainer>::ContainerHandle, - ) -> Hugr { - let mut mod_b = ModuleBuilder::new(); - let func_b = mod_b - .define_function("main", HugrFuncType::new(self.ins, self.outs)) - .unwrap(); - make(func_b, &self.extensions); - - // Intentionally left as a debugging aid. If the HUGR you construct - // fails validation, uncomment the following line to print it out - // unvalidated. - // println!("{}", mod_b.hugr().mermaid_string()); - mod_b.finish_hugr().unwrap_or_else(|e| panic!("{e}")) + pub fn finish_with_exts(self, make: impl FnOnce(DFGW, &ExtensionRegistry) -> Hugr) -> Hugr { + let func_b = FunctionBuilder::new("main", HugrFuncType::new(self.ins, self.outs)).unwrap(); + make(func_b, &self.extensions) } } @@ -187,11 +167,7 @@ pub use insta; macro_rules! check_emission { // Call the macro with a snapshot name. ($snapshot_name:expr, $hugr: ident, $test_ctx:ident) => {{ - let root = - $crate::utils::fat::FatExt::fat_root::<$crate::emit::test::hugr_core::ops::Module>( - &$hugr, - ) - .unwrap(); + let root = $crate::utils::fat::FatExt::fat_root(&$hugr).unwrap(); let emission = $crate::emit::test::Emission::emit_hugr(root, $test_ctx.get_emit_hugr()).unwrap(); @@ -237,8 +213,8 @@ mod test_fns { use crate::custom::CodegenExtsBuilder; use crate::types::{HugrFuncType, HugrSumType}; - use hugr_core::builder::DataflowSubContainer; use hugr_core::builder::{Container, Dataflow, HugrBuilder, ModuleBuilder, SubContainer}; + use hugr_core::builder::{DataflowHugr, DataflowSubContainer}; use hugr_core::extension::PRELUDE_REGISTRY; use hugr_core::extension::prelude::{ConstUsize, bool_t, usize_t}; use hugr_core::ops::constant::CustomConst; @@ -266,7 +242,7 @@ mod test_fns { builder.input_wires(), ) .unwrap(); - builder.finish_with_outputs(tag.outputs()).unwrap() + builder.finish_hugr_with_outputs(tag.outputs()).unwrap() }); let _ = check_emission!(hugr, llvm_ctx); } @@ -284,7 +260,7 @@ mod test_fns { let w = b.input_wires(); b.finish_with_outputs(w).unwrap() }; - builder.finish_with_outputs(dfg.outputs()).unwrap() + builder.finish_hugr_with_outputs(dfg.outputs()).unwrap() }); check_emission!(hugr, llvm_ctx); } @@ -329,7 +305,7 @@ mod test_fns { cond_b.finish_sub_container().unwrap() }; let [o1, o2] = cond.outputs_arr(); - builder.finish_with_outputs([o1, o2]).unwrap() + builder.finish_hugr_with_outputs([o1, o2]).unwrap() }) }; check_emission!(hugr, llvm_ctx); @@ -349,7 +325,7 @@ mod test_fns { .with_extensions(STD_REG.to_owned()) .finish(|mut builder: DFGW| { let konst = builder.add_load_value(v); - builder.finish_with_outputs([konst]).unwrap() + builder.finish_hugr_with_outputs([konst]).unwrap() }); check_emission!(hugr, llvm_ctx); } @@ -411,7 +387,7 @@ mod test_fns { .instantiate_extension_op("iadd", [4.into()]) .unwrap(); let add = builder.add_dataflow_op(ext_op, [k1, k2]).unwrap(); - builder.finish_with_outputs(add.outputs()).unwrap() + builder.finish_hugr_with_outputs(add.outputs()).unwrap() }); check_emission!(hugr, llvm_ctx); } @@ -453,34 +429,6 @@ mod test_fns { check_emission!(hugr, llvm_ctx); } - #[rstest] - fn diverse_dfg_children(llvm_ctx: TestContext) { - let hugr = SimpleHugrConfig::new() - .with_outs(bool_t()) - .finish(|mut builder: DFGW| { - let [r] = { - let mut builder = builder - .dfg_builder(HugrFuncType::new(type_row![], bool_t()), []) - .unwrap(); - let konst = builder.add_constant(Value::false_val()); - let func = { - let mut builder = builder - .define_function( - "scoped_func", - HugrFuncType::new(type_row![], bool_t()), - ) - .unwrap(); - let w = builder.load_const(&konst); - builder.finish_with_outputs([w]).unwrap() - }; - let [r] = builder.call(func.handle(), &[], []).unwrap().outputs_arr(); - builder.finish_with_outputs([r]).unwrap().outputs_arr() - }; - builder.finish_with_outputs([r]).unwrap() - }); - check_emission!(hugr, llvm_ctx); - } - #[rstest] fn diverse_cfg_children(llvm_ctx: TestContext) { let hugr = SimpleHugrConfig::new() @@ -489,29 +437,19 @@ mod test_fns { let [r] = { let mut builder = builder.cfg_builder([], vec![bool_t()].into()).unwrap(); let konst = builder.add_constant(Value::false_val()); - let func = { - let mut builder = builder - .define_function( - "scoped_func", - HugrFuncType::new(type_row![], bool_t()), - ) - .unwrap(); - let w = builder.load_const(&konst); - builder.finish_with_outputs([w]).unwrap() - }; let entry = { let mut builder = builder .entry_builder([type_row![]], vec![bool_t()].into()) .unwrap(); let control = builder.add_load_value(Value::unary_unit_sum()); - let [r] = builder.call(func.handle(), &[], []).unwrap().outputs_arr(); + let r = builder.load_const(&konst); builder.finish_with_outputs(control, [r]).unwrap() }; let exit = builder.exit_block(); builder.branch(&entry, 0, &exit).unwrap(); builder.finish_sub_container().unwrap().outputs_arr() }; - builder.finish_with_outputs([r]).unwrap() + builder.finish_hugr_with_outputs([r]).unwrap() }); check_emission!(hugr, llvm_ctx); } @@ -575,7 +513,7 @@ mod test_fns { .finish_with_outputs(sum_inp_w, []) .unwrap() .outputs_arr(); - builder.finish_with_outputs(outs).unwrap() + builder.finish_hugr_with_outputs(outs).unwrap() }) }; llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions); @@ -696,7 +634,7 @@ mod test_fns { }; let [out_int] = tail_l.outputs_arr(); builder - .finish_with_outputs([out_int]) + .finish_hugr_with_outputs([out_int]) .unwrap_or_else(|e| panic!("{e}")) }) } @@ -731,7 +669,7 @@ mod test_fns { .with_extensions(PRELUDE_REGISTRY.to_owned()) .finish(|mut builder: DFGW| { let konst = builder.add_load_value(ConstUsize::new(42)); - builder.finish_with_outputs([konst]).unwrap() + builder.finish_hugr_with_outputs([konst]).unwrap() }); exec_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions); assert_eq!(42, exec_ctx.exec_hugr_u64(hugr, "main")); diff --git a/hugr-llvm/src/extension/collections/array.rs b/hugr-llvm/src/extension/collections/array.rs index 2a602e291c..5f1d2fcd56 100644 --- a/hugr-llvm/src/extension/collections/array.rs +++ b/hugr-llvm/src/extension/collections/array.rs @@ -890,7 +890,7 @@ pub fn emit_scan_op<'c, H: HugrView>( #[cfg(test)] mod test { - use hugr_core::builder::Container as _; + use hugr_core::builder::{DataflowHugr, HugrBuilder}; use hugr_core::extension::prelude::either_type; use hugr_core::ops::Tag; use hugr_core::std_extensions::STD_REG; @@ -934,7 +934,7 @@ mod test { build_all_array_ops(builder.dfg_builder_endo([]).unwrap()) .finish_sub_container() .unwrap(); - builder.finish_sub_container().unwrap() + builder.finish_hugr().unwrap() }); llvm_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -953,7 +953,7 @@ mod test { let arr = builder.add_new_array(usize_t(), [us1, us2]).unwrap(); let (_, arr) = builder.add_array_get(usize_t(), 2, arr, us1).unwrap(); builder.add_array_discard(usize_t(), 2, arr).unwrap(); - builder.finish_with_outputs([]).unwrap() + builder.finish_hugr_with_outputs([]).unwrap() }); llvm_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -973,7 +973,7 @@ mod test { let (arr1, arr2) = builder.add_array_clone(usize_t(), 2, arr).unwrap(); builder.add_array_discard(usize_t(), 2, arr1).unwrap(); builder.add_array_discard(usize_t(), 2, arr2).unwrap(); - builder.finish_with_outputs([]).unwrap() + builder.finish_hugr_with_outputs([]).unwrap() }); llvm_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -990,7 +990,7 @@ mod test { .finish(|mut builder| { let vs = vec![ConstUsize::new(1).into(), ConstUsize::new(2).into()]; let arr = builder.add_load_value(array::ArrayValue::new(usize_t(), vs)); - builder.finish_with_outputs([arr]).unwrap() + builder.finish_hugr_with_outputs([arr]).unwrap() }); llvm_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1049,7 +1049,7 @@ mod test { } builder.finish_sub_container().unwrap().out_wire(0) }; - builder.finish_with_outputs([r]).unwrap() + builder.finish_hugr_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1154,7 +1154,7 @@ mod test { .unwrap(); conditional.finish_sub_container().unwrap().out_wire(0) }; - builder.finish_with_outputs([r]).unwrap() + builder.finish_hugr_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1261,7 +1261,7 @@ mod test { conditional.finish_sub_container().unwrap().out_wire(0) }; builder.add_array_discard(int_ty.clone(), 2, arr).unwrap(); - builder.finish_with_outputs([r]).unwrap() + builder.finish_hugr_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1318,7 +1318,7 @@ mod test { builder .add_array_discard(int_ty.clone(), 2, arr_clone) .unwrap(); - builder.finish_with_outputs([elem]).unwrap() + builder.finish_hugr_with_outputs([elem]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1388,7 +1388,7 @@ mod test { arr, ) .unwrap(); - builder.finish_with_outputs([r]).unwrap() + builder.finish_hugr_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1420,7 +1420,8 @@ mod test { .with_outs(int_ty.clone()) .with_extensions(exec_registry()) .finish(|mut builder| { - let mut func = builder + let mut mb = builder.module_root_builder(); + let mut func = mb .define_function("foo", Signature::new(vec![], vec![int_ty.clone()])) .unwrap(); let v = func.add_load_value(ConstInt::new_u(6, value).unwrap()); @@ -1441,7 +1442,7 @@ mod test { builder .add_array_discard(int_ty.clone(), size, arr) .unwrap(); - builder.finish_with_outputs([elem]).unwrap() + builder.finish_hugr_with_outputs([elem]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1474,7 +1475,8 @@ mod test { .add_new_array(int_ty.clone(), new_array_args) .unwrap(); - let mut func = builder + let mut mb = builder.module_root_builder(); + let mut func = mb .define_function( "foo", Signature::new(vec![int_ty.clone()], vec![int_ty.clone()]), @@ -1512,7 +1514,7 @@ mod test { builder .add_array_discard_empty(int_ty.clone(), arr) .unwrap(); - builder.finish_with_outputs([r]).unwrap() + builder.finish_hugr_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1544,7 +1546,8 @@ mod test { .add_new_array(int_ty.clone(), new_array_args) .unwrap(); - let mut func = builder + let mut mb = builder.module_root_builder(); + let mut func = mb .define_function( "foo", Signature::new( @@ -1568,7 +1571,7 @@ mod test { .unwrap() .outputs_arr(); builder.add_array_discard(Type::UNIT, size, arr).unwrap(); - builder.finish_with_outputs([sum]).unwrap() + builder.finish_hugr_with_outputs([sum]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() diff --git a/hugr-llvm/src/extension/collections/list.rs b/hugr-llvm/src/extension/collections/list.rs index e1ff76e2a8..2bddff687e 100644 --- a/hugr-llvm/src/extension/collections/list.rs +++ b/hugr-llvm/src/extension/collections/list.rs @@ -366,7 +366,7 @@ fn build_load_i8_ptr<'c, H: HugrView>( #[cfg(test)] mod test { use hugr_core::{ - builder::{Dataflow, DataflowSubContainer}, + builder::{Dataflow, DataflowHugr}, extension::{ ExtensionRegistry, prelude::{self, ConstUsize, qb_t, usize_t}, @@ -407,7 +407,7 @@ mod test { .add_dataflow_op(ext_op, hugr_builder.input_wires()) .unwrap() .outputs(); - hugr_builder.finish_with_outputs(outputs).unwrap() + hugr_builder.finish_hugr_with_outputs(outputs).unwrap() }); llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions); llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_list_extensions); @@ -427,7 +427,7 @@ mod test { .with_extensions(es) .finish(|mut hugr_builder| { let list = hugr_builder.add_load_value(ListValue::new(elem_ty, contents)); - hugr_builder.finish_with_outputs(vec![list]).unwrap() + hugr_builder.finish_hugr_with_outputs(vec![list]).unwrap() }); llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions); diff --git a/hugr-llvm/src/extension/collections/stack_array.rs b/hugr-llvm/src/extension/collections/stack_array.rs index 50410d9630..0d6abbdbae 100644 --- a/hugr-llvm/src/extension/collections/stack_array.rs +++ b/hugr-llvm/src/extension/collections/stack_array.rs @@ -714,7 +714,7 @@ fn emit_scan_op<'c, H: HugrView>( #[cfg(test)] mod test { - use hugr_core::builder::Container as _; + use hugr_core::builder::{DataflowHugr as _, HugrBuilder}; use hugr_core::extension::prelude::either_type; use hugr_core::ops::Tag; use hugr_core::std_extensions::STD_REG; @@ -758,7 +758,7 @@ mod test { build_all_array_ops(builder.dfg_builder_endo([]).unwrap()) .finish_sub_container() .unwrap(); - builder.finish_sub_container().unwrap() + builder.finish_hugr().unwrap() }); llvm_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -777,7 +777,7 @@ mod test { let arr = builder.add_new_array(usize_t(), [us1, us2]).unwrap(); let (_, arr) = builder.add_array_get(usize_t(), 2, arr, us1).unwrap(); builder.add_array_discard(usize_t(), 2, arr).unwrap(); - builder.finish_with_outputs([]).unwrap() + builder.finish_hugr_with_outputs([]).unwrap() }); llvm_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -797,7 +797,7 @@ mod test { let (arr1, arr2) = builder.add_array_clone(usize_t(), 2, arr).unwrap(); builder.add_array_discard(usize_t(), 2, arr1).unwrap(); builder.add_array_discard(usize_t(), 2, arr2).unwrap(); - builder.finish_with_outputs([]).unwrap() + builder.finish_hugr_with_outputs([]).unwrap() }); llvm_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -814,7 +814,7 @@ mod test { .finish(|mut builder| { let vs = vec![ConstUsize::new(1).into(), ConstUsize::new(2).into()]; let arr = builder.add_load_value(array::ArrayValue::new(usize_t(), vs)); - builder.finish_with_outputs([arr]).unwrap() + builder.finish_hugr_with_outputs([arr]).unwrap() }); llvm_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -873,7 +873,7 @@ mod test { } builder.finish_sub_container().unwrap().out_wire(0) }; - builder.finish_with_outputs([r]).unwrap() + builder.finish_hugr_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -978,7 +978,7 @@ mod test { .unwrap(); conditional.finish_sub_container().unwrap().out_wire(0) }; - builder.finish_with_outputs([r]).unwrap() + builder.finish_hugr_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1085,7 +1085,7 @@ mod test { conditional.finish_sub_container().unwrap().out_wire(0) }; builder.add_array_discard(int_ty.clone(), 2, arr).unwrap(); - builder.finish_with_outputs([r]).unwrap() + builder.finish_hugr_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1142,7 +1142,7 @@ mod test { builder .add_array_discard(int_ty.clone(), 2, arr_clone) .unwrap(); - builder.finish_with_outputs([elem]).unwrap() + builder.finish_hugr_with_outputs([elem]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1212,7 +1212,7 @@ mod test { arr, ) .unwrap(); - builder.finish_with_outputs([r]).unwrap() + builder.finish_hugr_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1244,7 +1244,8 @@ mod test { .with_outs(int_ty.clone()) .with_extensions(exec_registry()) .finish(|mut builder| { - let mut func = builder + let mut mb = builder.module_root_builder(); + let mut func = mb .define_function("foo", Signature::new(vec![], vec![int_ty.clone()])) .unwrap(); let v = func.add_load_value(ConstInt::new_u(6, value).unwrap()); @@ -1265,7 +1266,7 @@ mod test { builder .add_array_discard(int_ty.clone(), size, arr) .unwrap(); - builder.finish_with_outputs([elem]).unwrap() + builder.finish_hugr_with_outputs([elem]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1298,7 +1299,8 @@ mod test { .add_new_array(int_ty.clone(), new_array_args) .unwrap(); - let mut func = builder + let mut mb = builder.module_root_builder(); + let mut func = mb .define_function( "foo", Signature::new(vec![int_ty.clone()], vec![int_ty.clone()]), @@ -1336,7 +1338,7 @@ mod test { builder .add_array_discard_empty(int_ty.clone(), arr) .unwrap(); - builder.finish_with_outputs([r]).unwrap() + builder.finish_hugr_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() @@ -1355,7 +1357,6 @@ mod test { // We build a HUGR that: // - Creates an array [1, 2, 3, ..., size] // - Sums up the elements of the array using a scan and returns that sum - let int_ty = int_type(6); let hugr = SimpleHugrConfig::new() .with_outs(int_ty.clone()) @@ -1368,7 +1369,8 @@ mod test { .add_new_array(int_ty.clone(), new_array_args) .unwrap(); - let mut func = builder + let mut mb = builder.module_root_builder(); + let mut func = mb .define_function( "foo", Signature::new( @@ -1392,7 +1394,7 @@ mod test { .unwrap() .outputs_arr(); builder.add_array_discard(Type::UNIT, size, arr).unwrap(); - builder.finish_with_outputs([sum]).unwrap() + builder.finish_hugr_with_outputs([sum]).unwrap() }); exec_ctx.add_extensions(|cge| { cge.add_default_prelude_extensions() diff --git a/hugr-llvm/src/extension/collections/static_array.rs b/hugr-llvm/src/extension/collections/static_array.rs index e9df520ee3..1d9bfd8147 100644 --- a/hugr-llvm/src/extension/collections/static_array.rs +++ b/hugr-llvm/src/extension/collections/static_array.rs @@ -394,6 +394,7 @@ impl CodegenExtension for StaticArrayCodegenE mod test { use super::*; use float_types::float64_type; + use hugr_core::builder::DataflowHugr; use hugr_core::extension::prelude::ConstUsize; use hugr_core::ops::OpType; use hugr_core::ops::Value; @@ -459,7 +460,7 @@ mod test { ])) .finish(|mut builder| { let a = builder.add_load_value(value); - builder.finish_with_outputs([a]).unwrap() + builder.finish_hugr_with_outputs([a]).unwrap() }); check_emission!(hugr, llvm_ctx); } @@ -512,7 +513,7 @@ mod test { } cond.finish_sub_container().unwrap().outputs_arr() }; - builder.finish_with_outputs([out]).unwrap() + builder.finish_hugr_with_outputs([out]).unwrap() }); exec_ctx.add_extensions(|ceb| { @@ -534,7 +535,7 @@ mod test { let arr = builder .add_load_value(StaticArrayValue::try_new("empty", usize_t(), vec![]).unwrap()); let len = builder.add_static_array_len(usize_t(), arr).unwrap(); - builder.finish_with_outputs([len]).unwrap() + builder.finish_hugr_with_outputs([len]).unwrap() }); exec_ctx.add_extensions(|ceb| { @@ -574,7 +575,7 @@ mod test { let len = builder .add_static_array_len(inner_arr_ty, outer_arr) .unwrap(); - builder.finish_with_outputs([len]).unwrap() + builder.finish_hugr_with_outputs([len]).unwrap() }); check_emission!(hugr, llvm_ctx); } diff --git a/hugr-llvm/src/extension/conversions.rs b/hugr-llvm/src/extension/conversions.rs index cbc036719b..0ed8ec88c2 100644 --- a/hugr-llvm/src/extension/conversions.rs +++ b/hugr-llvm/src/extension/conversions.rs @@ -275,7 +275,7 @@ mod test { use crate::check_emission; use crate::emit::test::{DFGW, SimpleHugrConfig}; use crate::test::{TestContext, exec_ctx, llvm_ctx}; - use hugr_core::builder::SubContainer; + use hugr_core::builder::{DataflowHugr, SubContainer}; use hugr_core::std_extensions::STD_REG; use hugr_core::std_extensions::arithmetic::float_types::ConstF64; use hugr_core::std_extensions::arithmetic::int_types::ConstInt; @@ -311,7 +311,7 @@ mod test { .add_dataflow_op(ext_op, [in1]) .unwrap() .outputs(); - hugr_builder.finish_with_outputs(outputs).unwrap() + hugr_builder.finish_hugr_with_outputs(outputs).unwrap() }) } @@ -381,7 +381,7 @@ mod test { .add_dataflow_op(ext_op, [in1]) .unwrap() .outputs_arr(); - hugr_builder.finish_with_outputs([out1]).unwrap() + hugr_builder.finish_hugr_with_outputs([out1]).unwrap() }); check_emission!(op_name, hugr, llvm_ctx); } @@ -393,7 +393,7 @@ mod test { .with_extensions(PRELUDE_REGISTRY.to_owned()) .finish(|mut builder: DFGW| { let konst = builder.add_load_value(ConstUsize::new(42)); - builder.finish_with_outputs([konst]).unwrap() + builder.finish_hugr_with_outputs([konst]).unwrap() }); exec_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions); assert_eq!(42, exec_ctx.exec_hugr_u64(hugr, "main")); @@ -417,7 +417,7 @@ mod test { .add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [int]) .unwrap() .outputs_arr(); - builder.finish_with_outputs([usize_]).unwrap() + builder.finish_hugr_with_outputs([usize_]).unwrap() }); exec_ctx.add_extensions(|builder| { builder @@ -481,7 +481,7 @@ mod test { .add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [cond_result]) .unwrap() .outputs_arr(); - builder.finish_with_outputs([usize_]).unwrap() + builder.finish_hugr_with_outputs([usize_]).unwrap() }) } @@ -613,7 +613,7 @@ mod test { let true_result = case_true.add_load_value(ConstUsize::new(6)); case_true.finish_with_outputs([true_result]).unwrap(); let res = cond.finish_sub_container().unwrap(); - builder.finish_with_outputs(res.outputs()).unwrap() + builder.finish_hugr_with_outputs(res.outputs()).unwrap() }); exec_ctx.add_extensions(|builder| { builder @@ -635,7 +635,7 @@ mod test { let [b] = builder.add_dataflow_op(i2b, [i]).unwrap().outputs_arr(); let b2i = EXTENSION.instantiate_extension_op("ifrombool", []).unwrap(); let [i] = builder.add_dataflow_op(b2i, [b]).unwrap().outputs_arr(); - builder.finish_with_outputs([i]).unwrap() + builder.finish_hugr_with_outputs([i]).unwrap() }); exec_ctx.add_extensions(|builder| { builder @@ -663,7 +663,7 @@ mod test { .instantiate_extension_op("bytecast_int64_to_float64", []) .unwrap(); let [f] = builder.add_dataflow_op(i2f, [i]).unwrap().outputs_arr(); - builder.finish_with_outputs([f]).unwrap() + builder.finish_hugr_with_outputs([f]).unwrap() }); exec_ctx.add_extensions(|builder| { builder @@ -690,7 +690,7 @@ mod test { .instantiate_extension_op("bytecast_float64_to_int64", []) .unwrap(); let [i] = builder.add_dataflow_op(f2i, [f]).unwrap().outputs_arr(); - builder.finish_with_outputs([i]).unwrap() + builder.finish_hugr_with_outputs([i]).unwrap() }); exec_ctx.add_extensions(|builder| { builder diff --git a/hugr-llvm/src/extension/float.rs b/hugr-llvm/src/extension/float.rs index b95a698b18..968ae3f585 100644 --- a/hugr-llvm/src/extension/float.rs +++ b/hugr-llvm/src/extension/float.rs @@ -149,13 +149,14 @@ impl<'a, H: HugrView + 'a> CodegenExtsBuilder<'a, H> { #[cfg(test)] mod test { use hugr_core::Hugr; + use hugr_core::builder::DataflowHugr; use hugr_core::extension::SignatureFunc; use hugr_core::extension::simple_op::MakeOpDef; use hugr_core::std_extensions::STD_REG; use hugr_core::std_extensions::arithmetic::float_ops::FloatOps; use hugr_core::types::TypeRow; use hugr_core::{ - builder::{Dataflow, DataflowSubContainer}, + builder::Dataflow, std_extensions::arithmetic::float_types::{ConstF64, float64_type}, }; use rstest::rstest; @@ -184,7 +185,7 @@ mod test { .add_dataflow_op(op, builder.input_wires()) .unwrap() .outputs(); - builder.finish_with_outputs(outputs).unwrap() + builder.finish_hugr_with_outputs(outputs).unwrap() }) } @@ -196,7 +197,7 @@ mod test { .with_extensions(STD_REG.to_owned()) .finish(|mut builder| { let c = builder.add_load_value(ConstF64::new(3.12)); - builder.finish_with_outputs([c]).unwrap() + builder.finish_hugr_with_outputs([c]).unwrap() }); check_emission!(hugr, llvm_ctx); } diff --git a/hugr-llvm/src/extension/int.rs b/hugr-llvm/src/extension/int.rs index 315c7c7296..7f8932f00d 100644 --- a/hugr-llvm/src/extension/int.rs +++ b/hugr-llvm/src/extension/int.rs @@ -1141,6 +1141,7 @@ impl<'a, H: HugrView + 'a> CodegenExtsBuilder<'a, H> { #[cfg(test)] mod test { use anyhow::Result; + use hugr_core::builder::DataflowHugr; use hugr_core::extension::prelude::{ConstError, UnwrapBuilder, error_type}; use hugr_core::std_extensions::STD_REG; use hugr_core::{ @@ -1242,7 +1243,9 @@ mod test { .unwrap() .outputs(); let processed_outputs = process(&mut hugr_builder, outputs).unwrap(); - hugr_builder.finish_with_outputs(processed_outputs).unwrap() + hugr_builder + .finish_hugr_with_outputs(processed_outputs) + .unwrap() }) } @@ -1578,7 +1581,7 @@ mod test { .add_dataflow_op(iu_to_s, [unsigned]) .unwrap() .outputs_arr(); - hugr_builder.finish_with_outputs([signed]).unwrap() + hugr_builder.finish_hugr_with_outputs([signed]).unwrap() }); let act = int_exec_ctx.exec_hugr_i64(hugr, "main"); assert_eq!(act, val as i64); @@ -1605,7 +1608,7 @@ mod test { .add_dataflow_op(make_int_op("iadd", log_width), [unsigned, num]) .unwrap() .outputs_arr(); - hugr_builder.finish_with_outputs([res]).unwrap() + hugr_builder.finish_hugr_with_outputs([res]).unwrap() }); let act = int_exec_ctx.exec_hugr_u64(hugr, "main"); assert_eq!(act, (val as u64) + 42); diff --git a/hugr-llvm/src/extension/logic.rs b/hugr-llvm/src/extension/logic.rs index 50dd2bd17c..b382a21408 100644 --- a/hugr-llvm/src/extension/logic.rs +++ b/hugr-llvm/src/extension/logic.rs @@ -76,7 +76,7 @@ impl<'a, H: HugrView + 'a> CodegenExtsBuilder<'a, H> { mod test { use hugr_core::{ Hugr, - builder::{Dataflow, DataflowSubContainer}, + builder::{Dataflow, DataflowHugr}, extension::{ExtensionRegistry, prelude::bool_t}, std_extensions::logic::{self, LogicOp}, }; @@ -99,7 +99,7 @@ mod test { .add_dataflow_op(op, builder.input_wires()) .unwrap() .outputs(); - builder.finish_with_outputs(outputs).unwrap() + builder.finish_hugr_with_outputs(outputs).unwrap() }) } diff --git a/hugr-llvm/src/extension/prelude.rs b/hugr-llvm/src/extension/prelude.rs index 62a00527c8..973f192f5e 100644 --- a/hugr-llvm/src/extension/prelude.rs +++ b/hugr-llvm/src/extension/prelude.rs @@ -405,7 +405,7 @@ pub fn add_prelude_extensions<'a, H: HugrView + 'a>( #[cfg(test)] mod test { - use hugr_core::builder::{Dataflow, DataflowSubContainer}; + use hugr_core::builder::{Dataflow, DataflowHugr}; use hugr_core::extension::PRELUDE; use hugr_core::extension::prelude::{EXIT_OP_ID, Noop}; use hugr_core::types::{Type, TypeArg}; @@ -479,7 +479,7 @@ mod test { .with_extensions(prelude::PRELUDE_REGISTRY.to_owned()) .finish(|mut builder| { let k = builder.add_load_value(ConstUsize::new(17)); - builder.finish_with_outputs([k]).unwrap() + builder.finish_hugr_with_outputs([k]).unwrap() }); check_emission!(hugr, prelude_llvm_ctx); } @@ -502,7 +502,7 @@ mod test { .finish(|mut builder| { let k1 = builder.add_load_value(konst1); let k2 = builder.add_load_value(konst2); - builder.finish_with_outputs([k1, k2]).unwrap() + builder.finish_hugr_with_outputs([k1, k2]).unwrap() }); check_emission!(hugr, prelude_llvm_ctx); } @@ -519,7 +519,7 @@ mod test { .add_dataflow_op(Noop::new(usize_t()), in_wires) .unwrap() .outputs(); - builder.finish_with_outputs(r).unwrap() + builder.finish_hugr_with_outputs(r).unwrap() }); check_emission!(hugr, prelude_llvm_ctx); } @@ -533,7 +533,7 @@ mod test { .finish(|mut builder| { let in_wires = builder.input_wires(); let r = builder.make_tuple(in_wires).unwrap(); - builder.finish_with_outputs([r]).unwrap() + builder.finish_hugr_with_outputs([r]).unwrap() }); check_emission!(hugr, prelude_llvm_ctx); } @@ -551,7 +551,7 @@ mod test { builder.input_wires(), ) .unwrap(); - builder.finish_with_outputs(unpack.outputs()).unwrap() + builder.finish_hugr_with_outputs(unpack.outputs()).unwrap() }); check_emission!(hugr, prelude_llvm_ctx); } @@ -578,7 +578,7 @@ mod test { .add_dataflow_op(panic_op, [err, q0, q1]) .unwrap() .outputs_arr(); - builder.finish_with_outputs([q0, q1]).unwrap() + builder.finish_hugr_with_outputs([q0, q1]).unwrap() }); check_emission!(hugr, prelude_llvm_ctx); @@ -606,7 +606,7 @@ mod test { .add_dataflow_op(exit_op, [err, q0, q1]) .unwrap() .outputs_arr(); - builder.finish_with_outputs([q0, q1]).unwrap() + builder.finish_hugr_with_outputs([q0, q1]).unwrap() }); check_emission!(hugr, prelude_llvm_ctx); @@ -622,7 +622,7 @@ mod test { .finish(|mut builder| { let greeting_out = builder.add_load_value(greeting); builder.add_dataflow_op(print_op, [greeting_out]).unwrap(); - builder.finish_with_outputs([]).unwrap() + builder.finish_hugr_with_outputs([]).unwrap() }); check_emission!(hugr, prelude_llvm_ctx); @@ -638,7 +638,7 @@ mod test { .add_dataflow_op(LoadNat::new(TypeArg::BoundedNat { n: 42 }), vec![]) .unwrap() .out_wire(0); - builder.finish_with_outputs([v]).unwrap() + builder.finish_hugr_with_outputs([v]).unwrap() }); check_emission!(hugr, prelude_llvm_ctx); } @@ -651,7 +651,7 @@ mod test { .finish(|mut builder| { let i = builder.add_load_value(ConstUsize::new(42)); let [w1, _w2] = builder.add_barrier([i, i]).unwrap().outputs_arr(); - builder.finish_with_outputs([w1]).unwrap() + builder.finish_hugr_with_outputs([w1]).unwrap() }) } diff --git a/hugr-llvm/src/test.rs b/hugr-llvm/src/test.rs index 9864ae12e1..59919baad4 100644 --- a/hugr-llvm/src/test.rs +++ b/hugr-llvm/src/test.rs @@ -2,7 +2,7 @@ use std::rc::Rc; use hugr_core::{ Hugr, - builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}, + builder::{Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}, ops::{OpTrait, OpType}, types::PolyFuncType, }; diff --git a/hugr-llvm/src/utils/fat.rs b/hugr-llvm/src/utils/fat.rs index 1b046ddf02..1476bcb484 100644 --- a/hugr-llvm/src/utils/fat.rs +++ b/hugr-llvm/src/utils/fat.rs @@ -8,7 +8,7 @@ use hugr_core::hugr::views::Rerooted; use hugr_core::{ Hugr, HugrView, IncomingPort, Node, NodeIndex, OutgoingPort, core::HugrNode, - ops::{CFG, DataflowBlock, ExitBlock, Input, OpType, Output}, + ops::{CFG, DataflowBlock, ExitBlock, Input, Module, OpType, Output}, types::Type, }; use itertools::Itertools as _; @@ -373,7 +373,12 @@ pub trait FatExt: HugrView { } /// Try to create a specific [`FatNode`] for the root of a [`HugrView`]. - fn fat_root(&self) -> Option> + fn fat_root(&self) -> Option> { + self.try_fat(self.module_root()) + } + + /// Try to create a specific [`FatNode`] for the entrypoint of a [`HugrView`]. + fn fat_entrypoint(&self) -> Option> where for<'a> &'a OpType: TryInto<&'a OT>, { diff --git a/hugr-passes/src/composable.rs b/hugr-passes/src/composable.rs index d7f44fcebb..cd6591b0e8 100644 --- a/hugr-passes/src/composable.rs +++ b/hugr-passes/src/composable.rs @@ -223,8 +223,7 @@ mod test { use std::convert::Infallible; use hugr_core::builder::{ - Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, - ModuleBuilder, + Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, ModuleBuilder, }; use hugr_core::extension::prelude::{ConstUsize, MakeTuple, UnpackTuple, bool_t, usize_t}; use hugr_core::hugr::hugrmut::HugrMut; diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 205d9ba4fa..915f6f3425 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -2,7 +2,7 @@ use std::convert::Infallible; use ascent::{Lattice, lattice::BoundedLattice}; -use hugr_core::builder::{CFGBuilder, Container, DataflowHugr, ModuleBuilder, inout_sig}; +use hugr_core::builder::{CFGBuilder, DataflowHugr, ModuleBuilder, inout_sig}; use hugr_core::ops::{CallIndirect, TailLoop}; use hugr_core::types::{ConstTypeError, TypeRow}; use hugr_core::{Hugr, Node, Wire}; @@ -409,11 +409,14 @@ fn test_call( #[case] out: PartialValue, ) { let mut builder = DFGBuilder::new(Signature::new_endo(vec![bool_t(); 2])).unwrap(); - let func_bldr = builder - .define_function("id", Signature::new_endo(bool_t())) - .unwrap(); - let [v] = func_bldr.input_wires_arr(); - let func_defn = func_bldr.finish_with_outputs([v]).unwrap(); + let func_defn = { + let mut mb = builder.module_root_builder(); + let func_bldr = mb + .define_function("id", Signature::new_endo(bool_t())) + .unwrap(); + let [v] = func_bldr.input_wires_arr(); + func_bldr.finish_with_outputs([v]).unwrap() + }; let [a, b] = builder.input_wires_arr(); let [a2] = builder .call(func_defn.handle(), &[], [a]) @@ -554,7 +557,8 @@ fn call_indirect(#[case] inp1: PartialValue, #[case] inp2: PartialValue Node { - let for_func = cache.entry(poly_func).or_insert_with(|| { - // First time we've instantiated poly_func. Lift any nested FuncDefn's out to the same level. - let outer_name = h - .get_optype(poly_func) - .as_func_defn() - .unwrap() - .func_name() - .clone(); - let mut to_scan = Vec::from_iter(h.children(poly_func)); - while let Some(n) = to_scan.pop() { - if let OpType::FuncDefn(fd) = h.optype_mut(n) { - *fd.func_name_mut() = mangle_inner_func(&outer_name, fd.func_name()); - h.move_after_sibling(n, poly_func); - } else { - to_scan.extend(h.children(n)); - } - } - HashMap::new() - }); + let for_func = cache.entry(poly_func).or_default(); let ve = match for_func.entry(type_args.clone()) { Entry::Occupied(n) => return *n.get(), @@ -278,16 +260,13 @@ pub fn mangle_name(name: &str, type_args: impl AsRef<[TypeArg]>) -> String { format!("${name}${}", TypeArgsList(type_args.as_ref())) } -fn mangle_inner_func(outer_name: &str, inner_name: &str) -> String { - format!("${outer_name}${inner_name}") -} - #[cfg(test)] mod test { use std::collections::HashMap; use std::iter; use hugr_core::extension::simple_op::MakeRegisteredOp as _; + use hugr_core::hugr::hugrmut::HugrMut; use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES; use hugr_core::std_extensions::collections; use hugr_core::std_extensions::collections::array::ArrayKind; @@ -308,7 +287,7 @@ mod test { use crate::{monomorphize, remove_dead_funcs}; - use super::{is_polymorphic, mangle_inner_func, mangle_name}; + use super::{is_polymorphic, mangle_name}; fn pair_type(ty: Type) -> Type { Type::new_tuple(vec![ty.clone(), ty]) @@ -425,13 +404,12 @@ mod test { } #[test] - fn test_flattening_multiargs_nats() { + fn test_multiargs_nats() { //pf1 contains pf2 contains mono_func -> pf1 and pf1 share pf2's and they share mono_func let tv = |i| Type::new_var_use(i, TypeBound::Copyable); let sv = |i| TypeArg::new_var_use(i, TypeParam::max_nat()); let sa = |n| TypeArg::BoundedNat { n }; - let n: u64 = 5; let mut outer = FunctionBuilder::new( "mainish", @@ -447,32 +425,23 @@ mod test { .unwrap(); let arr2u = || ValueArray::ty_parametric(sa(2), usize_t()).unwrap(); - let pf1t = PolyFuncType::new( - [TypeParam::max_nat()], - Signature::new( - ValueArray::ty_parametric(sv(0), arr2u()).unwrap(), - usize_t(), - ), - ); - let mut pf1 = outer.define_function("pf1", pf1t).unwrap(); - let pf2t = PolyFuncType::new( - [TypeParam::max_nat(), TypeBound::Copyable.into()], - Signature::new( - vec![ValueArray::ty_parametric(sv(0), tv(1)).unwrap()], - tv(1), - ), - ); - let mut pf2 = pf1.define_function("pf2", pf2t).unwrap(); + let mut mb = outer.module_root_builder(); let mono_func = { - let mut fb = pf2 + let mut fb = mb .define_function("get_usz", Signature::new(vec![], usize_t())) .unwrap(); let cst0 = fb.add_load_value(ConstUsize::new(1)); fb.finish_with_outputs([cst0]).unwrap() }; + let pf2 = { + let pf2t = PolyFuncType::new( + [TypeParam::max_nat(), TypeBound::Copyable.into()], + Signature::new(ValueArray::ty_parametric(sv(0), tv(1)).unwrap(), tv(1)), + ); + let mut pf2 = mb.define_function("pf2", pf2t).unwrap(); let [inw] = pf2.input_wires_arr(); let [idx] = pf2.call(mono_func.handle(), &[], []).unwrap().outputs_arr(); let op_def = collections::value_array::EXTENSION.get_op("get").unwrap(); @@ -484,6 +453,16 @@ mod test { .unwrap(); pf2.finish_with_outputs([got]).unwrap() }; + + let pf1t = PolyFuncType::new( + [TypeParam::max_nat()], + Signature::new( + ValueArray::ty_parametric(sv(0), arr2u()).unwrap(), + usize_t(), + ), + ); + let mut pf1 = mb.define_function("pf1", pf1t).unwrap(); + // pf1: Two calls to pf2, one depending on pf1's TypeArg, the other not let inner = pf1 .call(pf2.handle(), &[sv(0), arr2u().into()], pf1.input_wires()) @@ -496,6 +475,7 @@ mod test { ) .unwrap(); let pf1 = pf1.finish_with_outputs(elem.outputs()).unwrap(); + // Outer: two calls to pf1 with different TypeArgs let [e1] = outer .call(pf1.handle(), &[sa(n)], outer.input_wires()) @@ -516,23 +496,24 @@ mod test { .call(pf1.handle(), &[sa(n - 1)], [ar2_unwrapped]) .unwrap() .outputs_arr(); + let outer_func = outer.container_node(); let mut hugr = outer.finish_hugr_with_outputs([e1, e2]).unwrap(); + hugr.set_entrypoint(hugr.module_root()); // We want to act on everything, not just `main` monomorphize(&mut hugr).unwrap(); let mono_hugr = hugr; mono_hugr.validate().unwrap(); let funcs = list_funcs(&mono_hugr); - let pf2_name = mangle_inner_func("pf1", "pf2"); assert_eq!( funcs.keys().copied().sorted().collect_vec(), vec![ &mangle_name("pf1", &[TypeArg::BoundedNat { n: 5 }]), &mangle_name("pf1", &[TypeArg::BoundedNat { n: 4 }]), - &mangle_name(&pf2_name, &[TypeArg::BoundedNat { n: 5 }, arr2u().into()]), // from pf1<5> - &mangle_name(&pf2_name, &[TypeArg::BoundedNat { n: 4 }, arr2u().into()]), // from pf1<4> - &mangle_name(&pf2_name, &[TypeArg::BoundedNat { n: 2 }, usize_t().into()]), // from both pf1<4> and <5> - &mangle_inner_func(&pf2_name, "get_usz"), - &pf2_name, + &mangle_name("pf2", &[TypeArg::BoundedNat { n: 5 }, arr2u().into()]), // from pf1<5> + &mangle_name("pf2", &[TypeArg::BoundedNat { n: 4 }, arr2u().into()]), // from pf1<4> + &mangle_name("pf2", &[TypeArg::BoundedNat { n: 2 }, usize_t().into()]), // from both pf1<4> and <5> + "get_usz", + "pf2", "mainish", "pf1" ] @@ -540,13 +521,10 @@ mod test { .sorted() .collect_vec() ); - for (n, fd) in funcs.into_values() { - if n == mono_hugr.entrypoint() { - assert_eq!(fd.func_name(), "mainish"); - } else { - assert_ne!(fd.func_name(), "mainish"); - } - } + #[allow(clippy::unnecessary_to_owned)] // it is necessary + let (n, fd) = *funcs.get(&"mainish".to_string()).unwrap(); + assert_eq!(n, outer_func); + assert_eq!(fd.func_name(), "mainish"); // just a sanity check on list_funcs } fn list_funcs(h: &Hugr) -> HashMap<&String, (Node, &FuncDefn)> { @@ -559,50 +537,6 @@ mod test { .collect::>() } - #[test] - fn test_no_flatten_out_of_mono_func() -> Result<(), Box> { - let ity = || INT_TYPES[4].clone(); - let sig = Signature::new_endo(vec![usize_t(), ity()]); - let mut dfg = DFGBuilder::new(sig.clone()).unwrap(); - let mut mono = dfg.define_function("id2", sig).unwrap(); - let pf = mono - .define_function( - "id", - PolyFuncType::new( - [TypeBound::Any.into()], - Signature::new_endo(Type::new_var_use(0, TypeBound::Any)), - ), - ) - .unwrap(); - let outs = pf.input_wires(); - let pf = pf.finish_with_outputs(outs).unwrap(); - let [a, b] = mono.input_wires_arr(); - let [a] = mono - .call(pf.handle(), &[usize_t().into()], [a]) - .unwrap() - .outputs_arr(); - let [b] = mono - .call(pf.handle(), &[ity().into()], [b]) - .unwrap() - .outputs_arr(); - let mono = mono.finish_with_outputs([a, b]).unwrap(); - let c = dfg.call(mono.handle(), &[], dfg.input_wires()).unwrap(); - let mut hugr = dfg.finish_hugr_with_outputs(c.outputs()).unwrap(); - monomorphize(&mut hugr)?; - let mono_hugr = hugr; - - let mut funcs = list_funcs(&mono_hugr); - #[allow(clippy::unnecessary_to_owned)] // It is necessary - let (m, _) = funcs.remove(&"id2".to_string()).unwrap(); - assert_eq!(m, mono.handle().node()); - assert_eq!(mono_hugr.get_parent(m), Some(mono_hugr.entrypoint())); - for t in [usize_t(), ity()] { - let (n, _) = funcs.remove(&mangle_name("id", &[t.into()])).unwrap(); - assert_eq!(mono_hugr.get_parent(n), Some(m)); // Not lifted to top - } - Ok(()) - } - #[test] fn load_function() { let mut hugr = { diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index bc12e730bd..a682d754bb 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -367,8 +367,8 @@ mod test { use std::sync::Arc; use hugr_core::builder::{ - BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, - HugrBuilder, inout_sig, + BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, + inout_sig, }; use hugr_core::extension::prelude::{option_type, usize_t}; @@ -800,7 +800,8 @@ mod test { // A simple Hugr that discards a usize_t, with a "drop" function let mut dfb = DFGBuilder::new(inout_sig(usize_t(), type_row![])).unwrap(); let discard_fn = { - let mut fb = dfb + let mut mb = dfb.module_root_builder(); + let mut fb = mb .define_function("drop", Signature::new(lin_t.clone(), type_row![])) .unwrap(); let ins = fb.input_wires(); @@ -815,12 +816,11 @@ mod test { let backup = dfb.finish_hugr().unwrap(); let mut lower_discard_to_call = ReplaceTypes::default(); - // The `copy_fn` here will break completely, but we don't use it lower_discard_to_call .linearizer() .register_simple( lin_ct.clone(), - NodeTemplate::Call(backup.entrypoint(), vec![]), + NodeTemplate::Call(backup.entrypoint(), vec![]), // Arbitrary, unused NodeTemplate::Call(discard_fn, vec![]), ) .unwrap(); @@ -834,18 +834,20 @@ mod test { assert_eq!(h.output_neighbours(discard_fn).count(), 1); } - // But if we lower usize_t to array, the call will fail + // But if we lower usize_t to array, the call will fail. lower_discard_to_call.replace_type( usize_t().as_extension().unwrap().clone(), value_array_type(4, lin_ct.into()), ); let r = lower_discard_to_call.run(&mut backup.clone()); + // Note the error (or success) can be quite fragile, according to what the `discard_fn` + // Node points at in the (hidden here) inner Hugr built by the array linearization helper. assert!(matches!( r, Err(ReplaceTypesError::LinearizeError( LinearizeError::NestedTemplateError( nested_t, - BuildError::NodeNotFound { node } + BuildError::NodeNotFound { node } // Note `..` would be somewhat less fragile ) )) if nested_t == lin_t && node == discard_fn )); diff --git a/hugr-py/src/hugr/build/dfg.py b/hugr-py/src/hugr/build/dfg.py index 5e5bd2e0cb..440d16f456 100644 --- a/hugr-py/src/hugr/build/dfg.py +++ b/hugr-py/src/hugr/build/dfg.py @@ -21,8 +21,9 @@ if TYPE_CHECKING: from collections.abc import Iterable, Sequence + from hugr.build.function import Module from hugr.hugr.node_port import Node, OutPort, PortOffset, ToNode, Wire - from hugr.tys import Type, TypeParam, TypeRow + from hugr.tys import TypeParam, TypeRow from .cfg import Cfg from .cond_loop import Conditional, If, TailLoop @@ -36,40 +37,21 @@ class DataflowError(Exception): @dataclass() class DefinitionBuilder(Generic[OpVar]): - """Base class for builders that can define functions, constants, and aliases. + """Base class for builders that can define constants, and allow access + to the `Module` for declaring/defining functions and aliases. As this class may be a root node, it does not extend `ParentBuilder`. """ hugr: Hugr[OpVar] - def define_function( - self, - name: str, - input_types: TypeRow, - output_types: TypeRow | None = None, - type_params: list[TypeParam] | None = None, - parent: ToNode | None = None, - ) -> Function: - """Start building a function definition in the graph. - - Args: - name: The name of the function. - input_types: The input types for the function. - output_types: The output types for the function. - If not provided, it will be inferred after the function is built. - type_params: The type parameters for the function, if polymorphic. - parent: The parent node of the constant. Defaults to the entrypoint node. - - Returns: - The new function builder. + def module_root_builder(self) -> Module: + """Allows access to the `Module` at the root of the Hugr + (outside the scope of this builder, perhaps outside the entrypoint). """ - parent_node = parent or self.hugr.entrypoint - parent_op = ops.FuncDefn(name, input_types, type_params or []) - func = Function.new_nested(parent_op, self.hugr, parent_node) - if output_types is not None: - func.declare_outputs(output_types) - return func + from hugr.build.function import Module # Avoid circular import + + return Module(self.hugr) def add_const(self, value: val.Value, parent: ToNode | None = None) -> Node: """Add a static constant to the graph. @@ -90,11 +72,6 @@ def add_const(self, value: val.Value, parent: ToNode | None = None) -> Node: parent_node = parent or self.hugr.entrypoint return self.hugr.add_node(ops.Const(value), parent_node) - def add_alias_defn(self, name: str, ty: Type, parent: ToNode | None = None) -> Node: - """Add a type alias definition.""" - parent_node = parent or self.hugr.entrypoint - return self.hugr.add_node(ops.AliasDefn(name, ty), parent_node) - DP = TypeVar("DP", bound=ops.DfParentOp) diff --git a/hugr-py/src/hugr/build/function.py b/hugr-py/src/hugr/build/function.py index b5d8b8c1ff..ecf0fd18ef 100644 --- a/hugr-py/src/hugr/build/function.py +++ b/hugr-py/src/hugr/build/function.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: from hugr.hugr.node_port import Node - from hugr.tys import PolyFuncType, TypeBound, TypeRow + from hugr.tys import PolyFuncType, Type, TypeBound, TypeParam, TypeRow __all__ = ["Function", "Module"] @@ -28,13 +28,39 @@ class Module(DefinitionBuilder[ops.Module]): hugr: Hugr[ops.Module] - def __init__(self) -> None: - self.hugr = Hugr(ops.Module()) + def __init__(self, hugr: Hugr | None = None) -> None: + self.hugr = Hugr(ops.Module()) if hugr is None else hugr def define_main(self, input_types: TypeRow) -> Function: """Define the 'main' function in the module. See :meth:`define_function`.""" return self.define_function("main", input_types) + def define_function( + self, + name: str, + input_types: TypeRow, + output_types: TypeRow | None = None, + type_params: list[TypeParam] | None = None, + ) -> Function: + """Start building a function definition in the graph. + + Args: + name: The name of the function. + input_types: The input types for the function. + output_types: The output types for the function. + If not provided, it will be inferred after the function is built. + type_params: The type parameters for the function, if polymorphic. + parent: The parent node of the constant. Defaults to the entrypoint node. + + Returns: + The new function builder. + """ + parent_op = ops.FuncDefn(name, input_types, type_params or []) + func = Function.new_nested(parent_op, self.hugr, self.hugr.module_root) + if output_types is not None: + func.declare_outputs(output_types) + return func + def declare_function(self, name: str, signature: PolyFuncType) -> Node: """Add a function declaration to the module. @@ -54,9 +80,13 @@ def declare_function(self, name: str, signature: PolyFuncType) -> Node: """ return self.hugr.add_node(ops.FuncDecl(name, signature), self.hugr.entrypoint) + def add_alias_defn(self, name: str, ty: Type) -> Node: + """Add a type alias definition.""" + return self.hugr.add_node(ops.AliasDefn(name, ty), self.hugr.module_root) + def add_alias_decl(self, name: str, bound: TypeBound) -> Node: """Add a type alias declaration.""" - return self.hugr.add_node(ops.AliasDecl(name, bound), self.hugr.entrypoint) + return self.hugr.add_node(ops.AliasDecl(name, bound), self.hugr.module_root) @property def metadata(self) -> dict[str, object]: diff --git a/hugr-py/src/hugr/hugr/base.py b/hugr-py/src/hugr/hugr/base.py index 75fc8d8d93..3d7d657105 100644 --- a/hugr-py/src/hugr/hugr/base.py +++ b/hugr-py/src/hugr/hugr/base.py @@ -98,7 +98,8 @@ class Hugr(Mapping[Node, NodeData], Generic[OpVarCov]): """The core HUGR datastructure. Args: - root_op: The operation for the root node. Defaults to a Module. + entrypoint_op: The operation for the entrypoint node. Defaults to a Module + (which will then also be the root). Examples: >>> h = Hugr() diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 74c8018f9e..9452f3672a 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -278,6 +278,19 @@ def test_mono_function(direct_call: bool) -> None: validate(mod.hugr) +def test_function_dfg() -> None: + d = Dfg(tys.Qubit) + + f_id = d.module_root_builder().define_function("id", [tys.Qubit]) + f_id.set_outputs(f_id.input_node[0]) + + (q,) = d.inputs() + call = d.call(f_id, q) + d.set_outputs(call) + + validate(d.hugr) + + def test_recursive_function(snapshot) -> None: mod = Module() diff --git a/hugr/benches/benchmarks/hugr/examples.rs b/hugr/benches/benchmarks/hugr/examples.rs index c74b66e53c..8f0b673156 100644 --- a/hugr/benches/benchmarks/hugr/examples.rs +++ b/hugr/benches/benchmarks/hugr/examples.rs @@ -3,8 +3,8 @@ use std::sync::Arc; use hugr::builder::{ - BuildError, CFGBuilder, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, - HugrBuilder, ModuleBuilder, + BuildError, CFGBuilder, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, + ModuleBuilder, }; use hugr::extension::prelude::{bool_t, qb_t, usize_t}; use hugr::ops::OpName; diff --git a/specification/hugr.md b/specification/hugr.md index 3bd22b8efd..01c21bbe78 100644 --- a/specification/hugr.md +++ b/specification/hugr.md @@ -248,6 +248,11 @@ edges. The following operations are *only* valid as immediate children of a - `AliasDecl`: an external type alias declaration. At link time this can be replaced with the definition. An alias declared with `AliasDecl` is equivalent to a named opaque type. +- `FuncDefn` : a function definition. Like `FuncDecl` but with a function body. + The function body is defined by the sibling graph formed by its children. + At link time `FuncDecl` nodes are replaced by `FuncDefn`. +- `AliasDefn`: type alias definition. At link time `AliasDecl` can be replaced with + `AliasDefn`. There may also be other [scoped definitions](#scoped-definitions). @@ -258,11 +263,6 @@ regions and control-flow regions: - `Const` : a static constant value of type T stored in the node weight. Like `FuncDecl` and `FuncDefn` this has one `Const` out-edge per use. -- `FuncDefn` : a function definition. Like `FuncDecl` but with a function body. - The function body is defined by the sibling graph formed by its children. - At link time `FuncDecl` nodes are replaced by `FuncDefn`. -- `AliasDefn`: type alias definition. At link time `AliasDecl` can be replaced with - `AliasDefn`. A **loadable HUGR** is a module HUGR where all input ports are connected and there are no `FuncDecl/AliasDecl` nodes. @@ -552,11 +552,8 @@ parent(n2) when the edge's locality is: Each of these localities have additional constraints as follows: 1. For Ext edges, we require parent(n1) == - parenti(n2) for some i\>1, *and* for Value edges only: - * there must be a order edge from n1 to - parenti-1(n2). - * None of the parentj(n2), for i\>j\>=1, - may be a FuncDefn node + parenti(n2) for some i\>1, *and* for Value edges only there must be a order edge from n1 to + parenti-1(n2). The order edge records the ordering requirement that results, i.e. it must be possible to @@ -569,9 +566,6 @@ Each of these localities have additional constraints as follows: For Static edges this order edge is not required since the source is guaranteed to causally precede the target. - The FuncDefn restriction means that FuncDefn really are static, - and do not capture runtime values from their environment. - 2. For Dom edges, we must have that parent2(n1) == parenti(n2) is a CFG-node, for some i\>1, **and** parent(n1) strictly dominates @@ -580,8 +574,6 @@ Each of these localities have additional constraints as follows: i\>1 allows the node to target an arbitrarily-deep descendant of the dominated block, similar to an Ext edge.) - The same FuncDefn restriction also applies here, on the parent(j)(n2) for i\>j\>=1 (of course j=i is the CFG and j=i-1 is the basic block). - Specifically, these rules allow for edges where in a given execution of the HUGR the source of the edge executes once, but the target may execute \>=0 times.