From 773432098d0c393d5bf7564fc3066acb0aa3253e Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 16 Apr 2025 16:36:40 +0100 Subject: [PATCH 01/12] WIP add NodeTemplate::Call --- hugr-passes/src/replace_types.rs | 48 ++++++++++++++++++++++++++++---- 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 3ed7337a98..54b3062098 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -15,16 +15,16 @@ use hugr_core::builder::{BuildError, BuildHandle, Dataflow}; use hugr_core::extension::{ExtensionId, OpDef, SignatureError, TypeDef}; use hugr_core::hugr::hugrmut::HugrMut; use hugr_core::ops::constant::{OpaqueValue, Sum}; -use hugr_core::ops::handle::DataflowOpID; +use hugr_core::ops::handle::{DataflowOpID, FuncID}; use hugr_core::ops::{ AliasDefn, Call, CallIndirect, Case, Conditional, Const, DataflowBlock, ExitBlock, ExtensionOp, FuncDecl, FuncDefn, Input, LoadConstant, LoadFunction, OpTrait, OpType, Output, Tag, TailLoop, Value, CFG, DFG, }; use hugr_core::types::{ - ConstTypeError, CustomType, Signature, Transformable, Type, TypeArg, TypeEnum, TypeTransformer, + ConstTypeError, CustomType, Signature, Transformable, Type, TypeArg, TypeEnum, TypeTransformer }; -use hugr_core::{Hugr, HugrView, Node, Wire}; +use hugr_core::{Direction, Hugr, HugrView, Node, PortIndex, Wire}; use crate::validation::{ValidatePassError, ValidationLevel}; @@ -48,18 +48,33 @@ pub enum NodeTemplate { // TODO: store also a vec, and update Hugr::validate to take &[TypeParam]s // (defaulting to empty list) - see https://github.com/CQCL/hugr/issues/709 CompoundOp(Box), - // TODO allow also Call to a Node in the existing Hugr - // (can't see any other way to achieve multiple calls to the same decl. - // So client should add the functions before replacement, then remove unused ones afterwards.) + /// A Call to a function (already) existing in the Hugr. + Call(Node, Vec) } impl NodeTemplate { /// Adds this instance to the specified [HugrMut] as a new node or subtree under a /// given parent, returning the unique new child (of that parent) thus created + /// + /// # Panics + /// + /// * If `parent` is not in the `hugr` + /// * If `self` is a [Self::Call] and the target Node either + /// * is neither a [FuncDefn] nor a [FuncDecl] + /// * has a [`signature`] which the type-args of the [Self::Call] do not match + /// + /// [`signature`]: hugr_core::types::PolyFuncType pub fn add_hugr(self, hugr: &mut impl HugrMut, parent: Node) -> Node { match self { NodeTemplate::SingleOp(op_type) => hugr.add_node_with_parent(parent, op_type), NodeTemplate::CompoundOp(new_h) => hugr.insert_hugr(parent, *new_h).new_root, + NodeTemplate::Call(target, type_args) => { + let c = call(hugr, target, type_args); + let tgt_port = c.called_function_port(); + let n = hugr.add_node_with_parent(parent, c); + hugr.connect(target, 0, n, tgt_port); + n + } } } @@ -72,6 +87,9 @@ impl NodeTemplate { match self { NodeTemplate::SingleOp(opty) => dfb.add_dataflow_op(opty, inputs), NodeTemplate::CompoundOp(h) => dfb.add_hugr_with_wires(*h, inputs), + // Really we should check whether func points at a FuncDecl or FuncDefn and create + // the appropriate variety of FuncID but it doesn't matter for the purpose of making a Call. + NodeTemplate::Call(func, type_args) => dfb.call(&FuncID::::from(func) , &type_args, inputs) } } @@ -88,6 +106,15 @@ impl NodeTemplate { } root_opty } + NodeTemplate::Call(func, type_args) => { + let c = call(hugr, func, type_args); + let static_inport = c.called_function_port(); + // insert an input for the Call static input + hugr.insert_ports(n, Direction::Incoming, static_inport.index(), 1); + // connect the function to (what will be) the call + hugr.connect(func, 0, n, static_inport); + c.into() + } }; *hugr.optype_mut(n) = new_optype; } @@ -101,6 +128,15 @@ impl NodeTemplate { } } +fn call(h: &H, func: H::Node, type_args: Vec) -> Call { + let func_sig = match h.get_optype(func) { + OpType::FuncDecl(fd) => fd.signature.clone(), + OpType::FuncDefn(fd) => fd.signature.clone(), + o => panic!("Node {func}: expected FuncDecl or FuncDefn, got {o:?}") + }; + Call::try_new(func_sig, type_args).unwrap() +} + /// A configuration of what types, ops, and constants should be replaced with what. /// May be applied to a Hugr via [Self::run]. /// From 572f3709f8a6eca82d6beef926e3a68d35237fc4 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 16 Apr 2025 17:46:15 +0100 Subject: [PATCH 02/12] Move sig-checking into NodeTemplate --- hugr-passes/src/replace_types.rs | 33 +++++++++++++++------- hugr-passes/src/replace_types/linearize.rs | 19 +++++-------- 2 files changed, 30 insertions(+), 22 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 54b3062098..8fde640c1a 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -22,7 +22,8 @@ use hugr_core::ops::{ Value, CFG, DFG, }; use hugr_core::types::{ - ConstTypeError, CustomType, Signature, Transformable, Type, TypeArg, TypeEnum, TypeTransformer + ConstTypeError, CustomType, Signature, Transformable, Type, TypeArg, TypeEnum, TypeRow, + TypeTransformer, }; use hugr_core::{Direction, Hugr, HugrView, Node, PortIndex, Wire}; @@ -49,20 +50,20 @@ pub enum NodeTemplate { // (defaulting to empty list) - see https://github.com/CQCL/hugr/issues/709 CompoundOp(Box), /// A Call to a function (already) existing in the Hugr. - Call(Node, Vec) + Call(Node, Vec), } impl NodeTemplate { /// Adds this instance to the specified [HugrMut] as a new node or subtree under a /// given parent, returning the unique new child (of that parent) thus created - /// + /// /// # Panics - /// + /// /// * If `parent` is not in the `hugr` /// * If `self` is a [Self::Call] and the target Node either /// * is neither a [FuncDefn] nor a [FuncDecl] /// * has a [`signature`] which the type-args of the [Self::Call] do not match - /// + /// /// [`signature`]: hugr_core::types::PolyFuncType pub fn add_hugr(self, hugr: &mut impl HugrMut, parent: Node) -> Node { match self { @@ -89,7 +90,9 @@ impl NodeTemplate { NodeTemplate::CompoundOp(h) => dfb.add_hugr_with_wires(*h, inputs), // Really we should check whether func points at a FuncDecl or FuncDefn and create // the appropriate variety of FuncID but it doesn't matter for the purpose of making a Call. - NodeTemplate::Call(func, type_args) => dfb.call(&FuncID::::from(func) , &type_args, inputs) + NodeTemplate::Call(func, type_args) => { + dfb.call(&FuncID::::from(func), &type_args, inputs) + } } } @@ -119,12 +122,22 @@ impl NodeTemplate { *hugr.optype_mut(n) = new_optype; } - fn signature(&self) -> Option> { - match self { + fn check_signature( + &self, + inputs: &TypeRow, + outputs: &TypeRow, + ) -> Result<(), Option> { + let sig = match self { NodeTemplate::SingleOp(op_type) => op_type, NodeTemplate::CompoundOp(hugr) => hugr.root_type(), + NodeTemplate::Call(_, _) => return Ok(()), // no way to tell + } + .dataflow_signature(); + if sig.as_deref().map(Signature::io) == Some((inputs, outputs)) { + Ok(()) + } else { + Err(sig.map(Cow::into_owned)) } - .dataflow_signature() } } @@ -132,7 +145,7 @@ fn call(h: &H, func: H::Node, type_args: Vec) -> Call { let func_sig = match h.get_optype(func) { OpType::FuncDecl(fd) => fd.signature.clone(), OpType::FuncDefn(fd) => fd.signature.clone(), - o => panic!("Node {func}: expected FuncDecl or FuncDefn, got {o:?}") + o => panic!("Node {func}: expected FuncDecl or FuncDefn, got {o:?}"), }; Call::try_new(func_sig, type_args).unwrap() } diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 5b4da71846..a80f30aee2 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -1,4 +1,3 @@ -use std::borrow::Cow; use std::iter::repeat; use std::{collections::HashMap, sync::Arc}; @@ -185,8 +184,10 @@ impl DelegatingLinearizer { /// /// * [LinearizeError::CopyableType] If `typ` is /// [Copyable](hugr_core::types::TypeBound::Copyable) - /// * [LinearizeError::WrongSignature] if `copy` or `discard` do not have the - /// expected inputs or outputs + /// * [LinearizeError::WrongSignature] if `copy` or `discard` do not have the expected + /// inputs or outputs (for [NodeTemplate::SingleOp] and [NodeTemplate::CompoundOp] + /// only: the signature for a [NodeTemplate::Call] cannot be checked until it is used + /// in a Hugr). pub fn register_simple( &mut self, cty: CustomType, @@ -230,18 +231,12 @@ impl DelegatingLinearizer { } fn check_sig(tmpl: &NodeTemplate, typ: &Type, num_outports: usize) -> Result<(), LinearizeError> { - let sig = tmpl.signature(); - if sig.as_ref().is_some_and(|sig| { - sig.io() == (&typ.clone().into(), &vec![typ.clone(); num_outports].into()) - }) { - Ok(()) - } else { - Err(LinearizeError::WrongSignature { + tmpl.check_signature(&typ.clone().into(), &vec![typ.clone(); num_outports].into()) + .map_err(|sig| LinearizeError::WrongSignature { typ: typ.clone(), num_outports, - sig: sig.map(Cow::into_owned), + sig, }) - } } impl Linearizer for DelegatingLinearizer { From 013dbcf9d5dbccef97d9829b1146328af116ec18 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 16 Apr 2025 22:25:27 +0100 Subject: [PATCH 03/12] [refactor test] move lowered_read out, return Hugr --- hugr-passes/src/replace_types.rs | 50 +++++++++++++++++--------------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 8fde640c1a..1c8eb8447c 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -597,6 +597,7 @@ mod test { use hugr_core::hugr::ValidationError; use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeRow}; + use hugr_core::Hugr; use hugr_core::{hugr::IdentList, type_row, Extension, HugrView}; use itertools::Itertools; use rstest::rstest; @@ -664,30 +665,29 @@ mod test { ) } - fn lowerer(ext: &Arc) -> ReplaceTypes { - fn lowered_read(args: &[TypeArg]) -> Option { - let ty = just_elem_type(args); - let mut dfb = DFGBuilder::new(inout_sig( - vec![array_type(64, ty.clone()), i64_t()], - ty.clone(), - )) + fn lowered_read(args: &[TypeArg]) -> Hugr { + let ty = just_elem_type(args); + let mut dfb = DFGBuilder::new(inout_sig( + vec![array_type(64, ty.clone()), i64_t()], + ty.clone(), + )) + .unwrap(); + let [val, idx] = dfb.input_wires_arr(); + let [idx] = dfb + .add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [idx]) + .unwrap() + .outputs_arr(); + let [opt] = dfb + .add_dataflow_op(ArrayOpDef::get.to_concrete(ty.clone(), 64), [val, idx]) + .unwrap() + .outputs_arr(); + let [res] = dfb + .build_unwrap_sum(1, option_type(Type::from(ty.clone())), opt) .unwrap(); - let [val, idx] = dfb.input_wires_arr(); - let [idx] = dfb - .add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [idx]) - .unwrap() - .outputs_arr(); - let [opt] = dfb - .add_dataflow_op(ArrayOpDef::get.to_concrete(ty.clone(), 64), [val, idx]) - .unwrap() - .outputs_arr(); - let [res] = dfb - .build_unwrap_sum(1, option_type(Type::from(ty.clone())), opt) - .unwrap(); - Some(NodeTemplate::CompoundOp(Box::new( - dfb.finish_hugr_with_outputs([res]).unwrap(), - ))) - } + dfb.finish_hugr_with_outputs([res]).unwrap() + } + + fn lowerer(ext: &Arc) -> ReplaceTypes { let pv = ext.get_type(PACKED_VEC).unwrap(); let mut lw = ReplaceTypes::default(); lw.replace_type(pv.instantiate([bool_t().into()]).unwrap(), i64_t()); @@ -703,7 +703,9 @@ mod test { .into(), ), ); - lw.replace_parametrized_op(ext.get_op(READ).unwrap().as_ref(), Box::new(lowered_read)); + lw.replace_parametrized_op(ext.get_op(READ).unwrap().as_ref(), |type_args| { + Some(NodeTemplate::CompoundOp(Box::new(lowered_read(type_args)))) + }); lw } From 633b502efdb3127d4012cbd0fe5a6a8f342518af Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 16 Apr 2025 22:39:16 +0100 Subject: [PATCH 04/12] [refactor test] lowered_read_bool takes Type, parametrized by builder::container --- hugr-passes/src/replace_types.rs | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 1c8eb8447c..1cc69ad665 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -575,7 +575,7 @@ mod test { use std::sync::Arc; use hugr_core::builder::{ - inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, + inout_sig, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, ModuleBuilder, SubContainer, TailLoopBuilder, }; use hugr_core::extension::prelude::{ @@ -595,10 +595,9 @@ mod test { list_type, list_type_def, ListOp, ListValue, }; - use hugr_core::hugr::ValidationError; + use hugr_core::hugr::{IdentList, ValidationError}; use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeRow}; - use hugr_core::Hugr; - use hugr_core::{hugr::IdentList, type_row, Extension, HugrView}; + use hugr_core::{type_row, Extension, HugrView}; use itertools::Itertools; use rstest::rstest; @@ -665,11 +664,13 @@ mod test { ) } - fn lowered_read(args: &[TypeArg]) -> Hugr { - let ty = just_elem_type(args); - let mut dfb = DFGBuilder::new(inout_sig( - vec![array_type(64, ty.clone()), i64_t()], - ty.clone(), + fn lowered_read( + elem_ty: Type, + new: impl Fn(Signature) -> Result, + ) -> T { + let mut dfb = new(inout_sig( + vec![array_type(64, elem_ty.clone()), i64_t()], + elem_ty.clone(), )) .unwrap(); let [val, idx] = dfb.input_wires_arr(); @@ -678,13 +679,14 @@ mod test { .unwrap() .outputs_arr(); let [opt] = dfb - .add_dataflow_op(ArrayOpDef::get.to_concrete(ty.clone(), 64), [val, idx]) + .add_dataflow_op(ArrayOpDef::get.to_concrete(elem_ty.clone(), 64), [val, idx]) .unwrap() .outputs_arr(); let [res] = dfb - .build_unwrap_sum(1, option_type(Type::from(ty.clone())), opt) + .build_unwrap_sum(1, option_type(Type::from(elem_ty)), opt) .unwrap(); - dfb.finish_hugr_with_outputs([res]).unwrap() + dfb.set_outputs([res]).unwrap(); + dfb } fn lowerer(ext: &Arc) -> ReplaceTypes { @@ -704,7 +706,11 @@ mod test { ), ); lw.replace_parametrized_op(ext.get_op(READ).unwrap().as_ref(), |type_args| { - Some(NodeTemplate::CompoundOp(Box::new(lowered_read(type_args)))) + Some(NodeTemplate::CompoundOp(Box::new( + lowered_read(just_elem_type(type_args).clone(), DFGBuilder::new) + .finish_hugr() + .unwrap(), + ))) }); lw } From e3748fb48ba4004897d90e56ddbe5ccb420045cc Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 16 Apr 2025 22:55:22 +0100 Subject: [PATCH 05/12] new test adding a polymorphic Function --- hugr-passes/src/replace_types.rs | 51 +++++++++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 1cc69ad665..9d09384561 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -576,7 +576,7 @@ mod test { use hugr_core::builder::{ inout_sig, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, - HugrBuilder, ModuleBuilder, SubContainer, TailLoopBuilder, + FunctionBuilder, HugrBuilder, ModuleBuilder, SubContainer, TailLoopBuilder, }; use hugr_core::extension::prelude::{ bool_t, option_type, qb_t, usize_t, ConstUsize, UnwrapBuilder, @@ -584,6 +584,7 @@ mod test { use hugr_core::extension::simple_op::MakeExtensionOp; use hugr_core::extension::{TypeDefBound, Version}; + use hugr_core::hugr::hugrmut::HugrMut; use hugr_core::ops::constant::OpaqueValue; use hugr_core::ops::{ExtensionOp, NamedOp, OpTrait, OpType, Tag, Value}; use hugr_core::std_extensions::arithmetic::int_types::ConstInt; @@ -1045,4 +1046,52 @@ mod test { let mut h = backup; repl.run(&mut h).unwrap(); // Includes validation } + + #[test] + fn op_to_call() { + let e = ext(); + let pv = e.get_type(PACKED_VEC).unwrap(); + let inner = pv.instantiate([usize_t().into()]).unwrap(); + let outer = pv + .instantiate([Type::new_extension(inner.clone()).into()]) + .unwrap(); + let mut dfb = DFGBuilder::new(inout_sig(vec![outer.into(), i64_t()], usize_t())).unwrap(); + let [outer, idx] = dfb.input_wires_arr(); + let [inner] = dfb + .add_dataflow_op(read_op(&e, inner.clone().into()), [outer, idx]) + .unwrap() + .outputs_arr(); + let res = dfb + .add_dataflow_op(read_op(&e, usize_t()), [inner, idx]) + .unwrap(); + let mut h = dfb.finish_hugr_with_outputs(res.outputs()).unwrap(); + let read_func = h + .insert_hugr( + h.root(), + lowered_read(Type::new_var_use(0, TypeBound::Copyable), |sig| { + FunctionBuilder::new( + "lowered_read", + PolyFuncType::new([TypeBound::Copyable.into()], sig), + ) + }) + .finish_hugr() + .unwrap(), + ) + .new_root; + + let mut lw = lowerer(&e); + lw.replace_parametrized_op(e.get_op(READ).unwrap().as_ref(), move |args| { + Some(NodeTemplate::Call(read_func, args.to_owned())) + }); + lw.run(&mut h).unwrap(); + + assert_eq!(h.output_neighbours(read_func).count(), 2); + let ext_op_names = h + .nodes() + .filter_map(|n| h.get_optype(n).as_extension_op()) + .map(|e| e.def().name()) + .sorted() + .collect_vec(); + assert_eq!(ext_op_names, ["get", "itousize", "panic",]); + } } From 6911b0b63b89a556b742f0ab7b1006f033e00d0b Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 17 Apr 2025 08:43:09 +0100 Subject: [PATCH 06/12] fix all-features --- hugr-passes/src/replace_types.rs | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 9d09384561..ebee74c538 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -579,18 +579,17 @@ mod test { FunctionBuilder, HugrBuilder, ModuleBuilder, SubContainer, TailLoopBuilder, }; use hugr_core::extension::prelude::{ - bool_t, option_type, qb_t, usize_t, ConstUsize, UnwrapBuilder, + bool_t, option_type, qb_t, usize_t, ConstUsize, UnwrapBuilder, PRELUDE_ID, }; - use hugr_core::extension::simple_op::MakeExtensionOp; - use hugr_core::extension::{TypeDefBound, Version}; + use hugr_core::extension::{simple_op::MakeExtensionOp, ExtensionSet, TypeDefBound, Version}; use hugr_core::hugr::hugrmut::HugrMut; use hugr_core::ops::constant::OpaqueValue; use hugr_core::ops::{ExtensionOp, NamedOp, OpTrait, OpType, Tag, Value}; - use hugr_core::std_extensions::arithmetic::int_types::ConstInt; - use hugr_core::std_extensions::arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES}; + use hugr_core::std_extensions::arithmetic::conversions::{self, ConvertOpDef}; + use hugr_core::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; use hugr_core::std_extensions::collections::array::{ - array_type, array_type_def, ArrayOp, ArrayOpDef, ArrayValue, + self, array_type, array_type_def, ArrayOp, ArrayOpDef, ArrayValue, }; use hugr_core::std_extensions::collections::list::{ list_type, list_type_def, ListOp, ListValue, @@ -669,10 +668,15 @@ mod test { elem_ty: Type, new: impl Fn(Signature) -> Result, ) -> T { - let mut dfb = new(inout_sig( + let mut dfb = new(Signature::new( vec![array_type(64, elem_ty.clone()), i64_t()], elem_ty.clone(), - )) + ) + .with_extension_delta(ExtensionSet::from_iter([ + PRELUDE_ID, + array::EXTENSION_ID, + conversions::EXTENSION_ID, + ]))) .unwrap(); let [val, idx] = dfb.input_wires_arr(); let [idx] = dfb From 5daa87b2c56aa0ac209989717ac2ae8e66da614e Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 21 Apr 2025 12:19:45 +0100 Subject: [PATCH 07/12] Add AddTemplateError, wrap in LinearizeError - can for ReplaceTypesError later --- hugr-passes/src/replace_types.rs | 50 ++++++++++++++++------ hugr-passes/src/replace_types/linearize.rs | 9 +++- 2 files changed, 45 insertions(+), 14 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index ebee74c538..b5dace5bfc 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -65,16 +65,16 @@ impl NodeTemplate { /// * has a [`signature`] which the type-args of the [Self::Call] do not match /// /// [`signature`]: hugr_core::types::PolyFuncType - pub fn add_hugr(self, hugr: &mut impl HugrMut, parent: Node) -> Node { + pub fn add_hugr(self, hugr: &mut impl HugrMut, parent: Node) -> Result { match self { - NodeTemplate::SingleOp(op_type) => hugr.add_node_with_parent(parent, op_type), - NodeTemplate::CompoundOp(new_h) => hugr.insert_hugr(parent, *new_h).new_root, + NodeTemplate::SingleOp(op_type) => Ok(hugr.add_node_with_parent(parent, op_type)), + NodeTemplate::CompoundOp(new_h) => Ok(hugr.insert_hugr(parent, *new_h).new_root), NodeTemplate::Call(target, type_args) => { - let c = call(hugr, target, type_args); + let c = call(hugr, target, type_args)?; let tgt_port = c.called_function_port(); let n = hugr.add_node_with_parent(parent, c); hugr.connect(target, 0, n, tgt_port); - n + Ok(n) } } } @@ -96,7 +96,7 @@ impl NodeTemplate { } } - fn replace(&self, hugr: &mut impl HugrMut, n: Node) { + fn replace(&self, hugr: &mut impl HugrMut, n: Node) -> Result<(), AddTemplateError> { assert_eq!(hugr.children(n).count(), 0); let new_optype = match self.clone() { NodeTemplate::SingleOp(op_type) => op_type, @@ -110,7 +110,7 @@ impl NodeTemplate { root_opty } NodeTemplate::Call(func, type_args) => { - let c = call(hugr, func, type_args); + let c = call(hugr, func, type_args)?; let static_inport = c.called_function_port(); // insert an input for the Call static input hugr.insert_ports(n, Direction::Incoming, static_inport.index(), 1); @@ -120,6 +120,7 @@ impl NodeTemplate { } }; *hugr.optype_mut(n) = new_optype; + Ok(()) } fn check_signature( @@ -141,13 +142,31 @@ impl NodeTemplate { } } -fn call(h: &H, func: H::Node, type_args: Vec) -> Call { +fn call>( + h: &H, + func: Node, + type_args: Vec, +) -> Result { + if !h.contains_node(func) { + return Err(AddTemplateError::NotFunction(func, "absent".to_string())); + } let func_sig = match h.get_optype(func) { OpType::FuncDecl(fd) => fd.signature.clone(), OpType::FuncDefn(fd) => fd.signature.clone(), - o => panic!("Node {func}: expected FuncDecl or FuncDefn, got {o:?}"), + o => return Err(AddTemplateError::NotFunction(func, o.to_string())), }; - Call::try_new(func_sig, type_args).unwrap() + Ok(Call::try_new(func_sig, type_args)?) +} + +#[derive(Clone, Debug, PartialEq, Eq, Error)] +#[non_exhaustive] +#[allow(missing_docs)] +/// An error from [NodeTemplate::add_hugr], currently only from [NodeTemplate::Call]s +pub enum AddTemplateError { + #[error("Target {0} of call was not a function but was {1}")] + NotFunction(Node, String), + #[error(transparent)] + SignatureError(#[from] SignatureError), } /// A configuration of what types, ops, and constants should be replaced with what. @@ -238,6 +257,8 @@ pub enum ReplaceTypesError { ConstError(#[from] ConstTypeError), #[error(transparent)] LinearizeError(#[from] LinearizeError), + #[error("Replacement op for {0} could not be added because {1}")] + AddTemplateError(Node, AddTemplateError), } impl ReplaceTypes { @@ -459,8 +480,11 @@ impl ReplaceTypes { OpType::Const(Const { value, .. }) => self.change_value(value), OpType::ExtensionOp(ext_op) => Ok( + // Copy/discard insertion done by caller if let Some(replacement) = self.op_map.get(&OpHashWrapper::from(&*ext_op)) { - replacement.replace(hugr, n); // Copy/discard insertion done by caller + replacement + .replace(hugr, n) + .map_err(|e| ReplaceTypesError::AddTemplateError(n, e))?; true } else { let def = ext_op.def_arc(); @@ -471,7 +495,9 @@ impl ReplaceTypes { .get(&def.as_ref().into()) .and_then(|rep_fn| rep_fn(&args)) { - replacement.replace(hugr, n); + replacement + .replace(hugr, n) + .map_err(|e| ReplaceTypesError::AddTemplateError(n, e))?; true } else { if ch { diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index a80f30aee2..56bfb8562c 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -11,6 +11,7 @@ use hugr_core::types::{CustomType, Signature, Type, TypeArg, TypeEnum, TypeRow}; use hugr_core::{hugr::hugrmut::HugrMut, ops::Tag, HugrView, IncomingPort, Node, Wire}; use itertools::Itertools; +use super::AddTemplateError; use super::{handlers::linearize_array, NodeTemplate, ParametricType}; /// Trait for things that know how to wire up linear outports to other than one @@ -75,9 +76,11 @@ pub trait Linearizer { tgt_parent, }); } + let typ = typ.clone(); // Stop borrowing hugr in order to add_hugr to it let copy_discard_op = self - .copy_discard_op(typ, targets.len())? - .add_hugr(hugr, src_parent); + .copy_discard_op(&typ, targets.len())? + .add_hugr(hugr, src_parent) + .map_err(|e| LinearizeError::NestedTemplateError(typ, e))?; for (n, (tgt_node, tgt_port)) in targets.iter().enumerate() { hugr.connect(copy_discard_op, n, *tgt_node, *tgt_port); } @@ -162,6 +165,8 @@ pub enum LinearizeError { /// Neither does linearization make sense for copyable types #[error("Type {_0} is copyable")] CopyableType(Type), + #[error("Could not add operation for contained type {0} because {1}")] + NestedTemplateError(Type, AddTemplateError), } impl DelegatingLinearizer { From d1af455483ce3fa3470f592bd7b542188ea7baed Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 21 Apr 2025 12:57:53 +0100 Subject: [PATCH 08/12] New test - but SUT actually breaks on add(Builder).unwrap() --- hugr-passes/src/replace_types/linearize.rs | 66 +++++++++++++++++++++- 1 file changed, 64 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 56bfb8562c..c70bf7ffa1 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -353,7 +353,9 @@ mod test { use std::iter::successors; use std::sync::Arc; - use hugr_core::builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer}; + use hugr_core::builder::{ + inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, + }; use hugr_core::extension::prelude::{option_type, usize_t}; use hugr_core::extension::simple_op::MakeExtensionOp; @@ -376,7 +378,7 @@ mod test { use rstest::rstest; use crate::replace_types::handlers::linearize_array; - use crate::replace_types::{LinearizeError, NodeTemplate, ReplaceTypesError}; + use crate::replace_types::{AddTemplateError, LinearizeError, NodeTemplate, ReplaceTypesError}; use crate::ReplaceTypes; const LIN_T: &str = "Lin"; @@ -768,4 +770,64 @@ mod test { )); assert_eq!(copy_sig.input[2..], copy_sig.output[1..]); } + + #[test] + fn call_ok_except_in_array() { + let (e, _) = ext_lowerer(); + let lin_ct = e.get_type(LIN_T).unwrap().instantiate([]).unwrap(); + let lin_t: Type = lin_ct.clone().into(); + + // A simple Hugr that discards a usize_t, with a "drop" function + let mut dfb = DFGBuilder::new(inout_sig(usize_t(), type_row![])).unwrap(); + let discard_fn = { + let mut fb = dfb + .define_function("drop", inout_sig(lin_t.clone(), type_row![])) + .unwrap(); + let ins = fb.input_wires(); + fb.add_dataflow_op( + ExtensionOp::new(e.get_op("discard").unwrap().clone(), []).unwrap(), + ins, + ) + .unwrap(); + fb.finish_with_outputs([]).unwrap() + } + .node(); + let backup = dfb.finish_hugr().unwrap(); + + let mut lower_discard_to_call = ReplaceTypes::default(); + // The `copy_fn` here will break completely, but we don't use it + lower_discard_to_call + .linearizer() + .register_simple( + lin_ct.clone(), + NodeTemplate::Call(backup.root(), vec![]), + NodeTemplate::Call(discard_fn, vec![]), + ) + .unwrap(); + + // Ok to lower usize_t to lin_t and call that function + { + let mut lowerer = lower_discard_to_call.clone(); + lowerer.replace_type(usize_t().as_extension().unwrap().clone(), lin_t.clone()); + let mut h = backup.clone(); + lowerer.run(&mut h).unwrap(); + assert_eq!(h.output_neighbours(discard_fn).count(), 1); + } + + // But if we lower usize_t to array, the call will fail + lower_discard_to_call.replace_type( + usize_t().as_extension().unwrap().clone(), + array_type(4, lin_ct.into()), + ); + let r = lower_discard_to_call.run(&mut backup.clone()); + assert!(matches!( + r, + Err(ReplaceTypesError::LinearizeError( + LinearizeError::NestedTemplateError( + nested_t, + AddTemplateError::NotFunction(tgt, _) + ) + )) if nested_t == lin_t && tgt == discard_fn + )); + } } From 4f77d2d9bba5690c80ddb1902fe91a7922db541d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 21 Apr 2025 13:09:05 +0100 Subject: [PATCH 09/12] Standardize on BuildError so can plumb add(Builder) errors too, handler does so --- hugr-passes/src/replace_types.rs | 29 ++++++++-------------- hugr-passes/src/replace_types/handlers.rs | 4 +-- hugr-passes/src/replace_types/linearize.rs | 22 ++++++++-------- 3 files changed, 24 insertions(+), 31 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index b5dace5bfc..fe6a4388f5 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -65,7 +65,7 @@ impl NodeTemplate { /// * has a [`signature`] which the type-args of the [Self::Call] do not match /// /// [`signature`]: hugr_core::types::PolyFuncType - pub fn add_hugr(self, hugr: &mut impl HugrMut, parent: Node) -> Result { + pub fn add_hugr(self, hugr: &mut impl HugrMut, parent: Node) -> Result { match self { NodeTemplate::SingleOp(op_type) => Ok(hugr.add_node_with_parent(parent, op_type)), NodeTemplate::CompoundOp(new_h) => Ok(hugr.insert_hugr(parent, *new_h).new_root), @@ -96,7 +96,7 @@ impl NodeTemplate { } } - fn replace(&self, hugr: &mut impl HugrMut, n: Node) -> Result<(), AddTemplateError> { + fn replace(&self, hugr: &mut impl HugrMut, n: Node) -> Result<(), BuildError> { assert_eq!(hugr.children(n).count(), 0); let new_optype = match self.clone() { NodeTemplate::SingleOp(op_type) => op_type, @@ -146,29 +146,20 @@ fn call>( h: &H, func: Node, type_args: Vec, -) -> Result { - if !h.contains_node(func) { - return Err(AddTemplateError::NotFunction(func, "absent".to_string())); - } +) -> Result { let func_sig = match h.get_optype(func) { OpType::FuncDecl(fd) => fd.signature.clone(), OpType::FuncDefn(fd) => fd.signature.clone(), - o => return Err(AddTemplateError::NotFunction(func, o.to_string())), + _ => { + return Err(BuildError::UnexpectedType { + node: func, + op_desc: "func defn/decl", + }) + } }; Ok(Call::try_new(func_sig, type_args)?) } -#[derive(Clone, Debug, PartialEq, Eq, Error)] -#[non_exhaustive] -#[allow(missing_docs)] -/// An error from [NodeTemplate::add_hugr], currently only from [NodeTemplate::Call]s -pub enum AddTemplateError { - #[error("Target {0} of call was not a function but was {1}")] - NotFunction(Node, String), - #[error(transparent)] - SignatureError(#[from] SignatureError), -} - /// A configuration of what types, ops, and constants should be replaced with what. /// May be applied to a Hugr via [Self::run]. /// @@ -258,7 +249,7 @@ pub enum ReplaceTypesError { #[error(transparent)] LinearizeError(#[from] LinearizeError), #[error("Replacement op for {0} could not be added because {1}")] - AddTemplateError(Node, AddTemplateError), + AddTemplateError(Node, BuildError), } impl ReplaceTypes { diff --git a/hugr-passes/src/replace_types/handlers.rs b/hugr-passes/src/replace_types/handlers.rs index e835a2d9b8..b6e6e67809 100644 --- a/hugr-passes/src/replace_types/handlers.rs +++ b/hugr-passes/src/replace_types/handlers.rs @@ -92,7 +92,7 @@ pub fn linearize_array( let [to_discard] = dfb.input_wires_arr(); lin.copy_discard_op(ty, 0)? .add(&mut dfb, [to_discard]) - .unwrap(); + .map_err(|e| LinearizeError::NestedTemplateError(ty.clone(), e))?; let ret = dfb.add_load_value(Value::unary_unit_sum()); dfb.finish_hugr_with_outputs([ret]).unwrap() }; @@ -162,7 +162,7 @@ pub fn linearize_array( let mut copies = lin .copy_discard_op(ty, num_outports)? .add(&mut dfb, [elem]) - .unwrap() + .map_err(|e| LinearizeError::NestedTemplateError(ty.clone(), e))? .outputs(); let copy0 = copies.next().unwrap(); // We'll return this directly diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index c70bf7ffa1..230aaa6684 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -2,8 +2,8 @@ use std::iter::repeat; use std::{collections::HashMap, sync::Arc}; use hugr_core::builder::{ - inout_sig, ConditionalBuilder, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, - HugrBuilder, + inout_sig, BuildError, ConditionalBuilder, DFGBuilder, Dataflow, DataflowHugr, + DataflowSubContainer, HugrBuilder, }; use hugr_core::extension::{SignatureError, TypeDef}; use hugr_core::std_extensions::collections::array::array_type_def; @@ -11,7 +11,6 @@ use hugr_core::types::{CustomType, Signature, Type, TypeArg, TypeEnum, TypeRow}; use hugr_core::{hugr::hugrmut::HugrMut, ops::Tag, HugrView, IncomingPort, Node, Wire}; use itertools::Itertools; -use super::AddTemplateError; use super::{handlers::linearize_array, NodeTemplate, ParametricType}; /// Trait for things that know how to wire up linear outports to other than one @@ -135,7 +134,7 @@ impl Default for DelegatingLinearizer { // rather than passing a &DelegatingLinearizer directly) pub struct CallbackHandler<'a>(#[allow(dead_code)] &'a DelegatingLinearizer); -#[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)] +#[derive(Clone, Debug, thiserror::Error, PartialEq)] #[allow(missing_docs)] #[non_exhaustive] pub enum LinearizeError { @@ -165,8 +164,10 @@ pub enum LinearizeError { /// Neither does linearization make sense for copyable types #[error("Type {_0} is copyable")] CopyableType(Type), - #[error("Could not add operation for contained type {0} because {1}")] - NestedTemplateError(Type, AddTemplateError), + /// Error may be returned by a callback for e.g. a container because it could + /// not generate a [NodeTemplate] because of a problem with an element + #[error("Could not generate NodeTemplate for contained type {0} because {1}")] + NestedTemplateError(Type, BuildError), } impl DelegatingLinearizer { @@ -354,7 +355,8 @@ mod test { use std::sync::Arc; use hugr_core::builder::{ - inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, + inout_sig, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, + HugrBuilder, }; use hugr_core::extension::prelude::{option_type, usize_t}; @@ -378,7 +380,7 @@ mod test { use rstest::rstest; use crate::replace_types::handlers::linearize_array; - use crate::replace_types::{AddTemplateError, LinearizeError, NodeTemplate, ReplaceTypesError}; + use crate::replace_types::{LinearizeError, NodeTemplate, ReplaceTypesError}; use crate::ReplaceTypes; const LIN_T: &str = "Lin"; @@ -825,9 +827,9 @@ mod test { Err(ReplaceTypesError::LinearizeError( LinearizeError::NestedTemplateError( nested_t, - AddTemplateError::NotFunction(tgt, _) + BuildError::UnexpectedType { node, .. } ) - )) if nested_t == lin_t && tgt == discard_fn + )) if nested_t == lin_t && node == discard_fn )); } } From de0e9022ba263013d7587d12326d07387a83aab1 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 21 Apr 2025 13:14:33 +0100 Subject: [PATCH 10/12] all-features --- hugr-passes/src/replace_types/linearize.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 230aaa6684..8a48c3df4e 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -783,7 +783,11 @@ mod test { let mut dfb = DFGBuilder::new(inout_sig(usize_t(), type_row![])).unwrap(); let discard_fn = { let mut fb = dfb - .define_function("drop", inout_sig(lin_t.clone(), type_row![])) + .define_function( + "drop", + Signature::new(lin_t.clone(), type_row![]) + .with_extension_delta(e.name().clone()), + ) .unwrap(); let ins = fb.input_wires(); fb.add_dataflow_op( From ea8fb2cc9630a443df64acf6c31dd7f64d3b5a2b Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 21 Apr 2025 13:19:37 +0100 Subject: [PATCH 11/12] Doc add_hugr returns Errors --- hugr-passes/src/replace_types.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index fe6a4388f5..e0580425ae 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -60,6 +60,9 @@ impl NodeTemplate { /// # Panics /// /// * If `parent` is not in the `hugr` + /// + /// # Errors + /// /// * If `self` is a [Self::Call] and the target Node either /// * is neither a [FuncDefn] nor a [FuncDecl] /// * has a [`signature`] which the type-args of the [Self::Call] do not match From 5b29b4ceb128886b57637376524b0303cd4641cd Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 23 Apr 2025 10:52:41 +0100 Subject: [PATCH 12/12] Comments - remove the TODO that is done --- hugr-passes/src/replace_types.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 143b43419d..df4c140758 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -46,10 +46,8 @@ pub enum NodeTemplate { /// Note this will be of limited use before [monomorphization](super::monomorphize()) /// because the new subtree will not be able to use type variables present in the /// parent Hugr or previous op. - // TODO: store also a vec, and update Hugr::validate to take &[TypeParam]s - // (defaulting to empty list) - see https://github.com/CQCL/hugr/issues/709 CompoundOp(Box), - /// A Call to a function (already) existing in the Hugr. + /// A Call to an existing function. Call(Node, Vec), }