diff --git a/hugr-core/src/hugr/validate.rs b/hugr-core/src/hugr/validate.rs index 41fe7ba45b..25316db7f7 100644 --- a/hugr-core/src/hugr/validate.rs +++ b/hugr-core/src/hugr/validate.rs @@ -57,8 +57,8 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { self.validate_node(node)?; } - // Hierarchy and children. No type variables declared outside the root. - self.validate_subtree(self.hugr.entrypoint(), &[])?; + // Hierarchy and children. No type variables declared by the module root. + self.validate_subtree(self.hugr.module_root(), &[])?; self.validate_linkage()?; // In tests we take the opportunity to verify that the hugr @@ -600,13 +600,9 @@ impl<'a, H: HugrView> ValidationContext<'a, H> { } // Check port connections. - // - // Root nodes are ignored, as they cannot have connected edges. - if node != self.hugr.entrypoint() { - for dir in Direction::BOTH { - for port in self.hugr.node_ports(node, dir) { - self.validate_port(node, port, op_type, var_decls)?; - } + for dir in Direction::BOTH { + for port in self.hugr.node_ports(node, dir) { + self.validate_port(node, port, op_type, var_decls)?; } } diff --git a/hugr-passes/src/linearize_array.rs b/hugr-passes/src/linearize_array.rs index 07fbc6e958..dbec6b8c03 100644 --- a/hugr-passes/src/linearize_array.rs +++ b/hugr-passes/src/linearize_array.rs @@ -141,7 +141,7 @@ impl LinearizeArrayPass { #[cfg(test)] mod test { - use hugr_core::builder::ModuleBuilder; + use hugr_core::builder::{FunctionBuilder, ModuleBuilder}; use hugr_core::extension::prelude::{ConstUsize, Noop}; use hugr_core::ops::handle::NodeHandle; use hugr_core::ops::{Const, OpType}; @@ -287,7 +287,7 @@ mod test { ), }; let sig = Signature::new(src, tgt); - let mut builder = DFGBuilder::new(sig).unwrap(); + let mut builder = FunctionBuilder::new("main", sig).unwrap(); let [arr] = builder.input_wires_arr(); let op: OpType = match dir { INTO => VArrayToArray::new(elem_ty.clone(), size).into(), @@ -313,7 +313,7 @@ mod test { #[case(value_array_type(2, Type::new_tuple(vec![usize_t(), value_array_type(4, usize_t())])))] fn implicit_clone(#[case] array_ty: Type) { let sig = Signature::new(array_ty.clone(), vec![array_ty; 2]); - let mut builder = DFGBuilder::new(sig).unwrap(); + let mut builder = FunctionBuilder::new("main", sig).unwrap(); let [arr] = builder.input_wires_arr(); builder.set_outputs(vec![arr, arr]).unwrap(); @@ -329,7 +329,7 @@ mod test { #[case(value_array_type(2, Type::new_tuple(vec![usize_t(), value_array_type(4, usize_t())])))] fn implicit_discard(#[case] array_ty: Type) { let sig = Signature::new(array_ty, Type::EMPTY_TYPEROW); - let mut builder = DFGBuilder::new(sig).unwrap(); + let mut builder = FunctionBuilder::new("main", sig).unwrap(); builder.set_outputs(vec![]).unwrap(); let mut hugr = builder.finish_hugr().unwrap(); diff --git a/hugr-passes/src/normalize_cfgs.rs b/hugr-passes/src/normalize_cfgs.rs index 7599f6928b..d26c969b81 100644 --- a/hugr-passes/src/normalize_cfgs.rs +++ b/hugr-passes/src/normalize_cfgs.rs @@ -168,6 +168,16 @@ pub fn normalize_cfg( _ => unreachable!(), // Checked at entry to normalize_cfg } } + let ancestor_block = |h: &H, mut n: H::Node| { + while let Some(p) = h.get_parent(n) { + if p == cfg_node { + return Some(n); + } + n = p; + } + None + }; + // Further normalizations with effects outside the CFG let [entry, exit] = h.children(cfg_node).take(2).collect_array().unwrap(); let entry_blk = h.get_optype(entry).as_dataflow_block().unwrap(); @@ -209,49 +219,60 @@ pub fn normalize_cfg( } // 1b. Move entry block outside/before the CFG into a DFG; its successor becomes the entry block. let new_cfg_inputs = entry_blk.successor_input(0).unwrap(); - let dfg = h.add_node_with_parent( - cfg_parent, - DFG { - signature: Signature::new(entry_blk.inputs.clone(), new_cfg_inputs.clone()), - }, - ); - let [_, entry_output] = h.get_io(entry).unwrap(); - while let Some(n) = h.first_child(entry) { - h.set_parent(n, dfg); - } - h.move_before_sibling(succ, entry); - h.remove_node(entry); + // Look for nonlocal `Dom` edges from the entry block. (Ignore `Ext` edges.) + let dests = h.children(entry).flat_map(|n| h.output_neighbours(n)); + let has_dom_outs = dests.dedup().any(|succ| { + ancestor_block(h, succ).expect("Dom edges within entry, Ext within CFG") != entry + }); + if !has_dom_outs { + // Move entry block contents into DFG. + let dfg = h.add_node_with_parent( + cfg_parent, + DFG { + signature: Signature::new(entry_blk.inputs.clone(), new_cfg_inputs.clone()), + }, + ); + let [_, entry_output] = h.get_io(entry).unwrap(); + while let Some(n) = h.first_child(entry) { + h.set_parent(n, dfg); + } + h.move_before_sibling(succ, entry); + h.remove_node(entry); - unpack_before_output(h, entry_output, new_cfg_inputs.clone()); + unpack_before_output(h, entry_output, new_cfg_inputs.clone()); - // Inputs to CFG go directly to DFG - for inp in h.node_inputs(cfg_node).collect::>() { - for src in h.linked_outputs(cfg_node, inp).collect::>() { - h.connect(src.0, src.1, dfg, inp.index()); + // Inputs to CFG go directly to DFG + for inp in h.node_inputs(cfg_node).collect::>() { + for src in h.linked_outputs(cfg_node, inp).collect::>() { + h.connect(src.0, src.1, dfg, inp.index()); + } + h.disconnect(cfg_node, inp); } - h.disconnect(cfg_node, inp); - } - // Update input ports - let cfg_ty = cfg_ty_mut(h, cfg_node); - let inputs_to_add = new_cfg_inputs.len() as isize - cfg_ty.signature.input.len() as isize; - cfg_ty.signature.input = new_cfg_inputs; - h.add_ports(cfg_node, Direction::Incoming, inputs_to_add); + // Update input ports + let cfg_ty = cfg_ty_mut(h, cfg_node); + let inputs_to_add = + new_cfg_inputs.len() as isize - cfg_ty.signature.input.len() as isize; + cfg_ty.signature.input = new_cfg_inputs; + h.add_ports(cfg_node, Direction::Incoming, inputs_to_add); - // Wire outputs of DFG directly to CFG - for src in h.node_outputs(dfg).collect::>() { - h.connect(dfg, src, cfg_node, src.index()); + // Wire outputs of DFG directly to CFG + for src in h.node_outputs(dfg).collect::>() { + h.connect(dfg, src, cfg_node, src.index()); + } + entry_dfg = Some(dfg); } - entry_dfg = Some(dfg); } // 2. If the exit node has a single predecessor and that predecessor has no other successors... let mut exit_dfg = None; - if let Some(pred) = h - .input_neighbours(exit) - .exactly_one() - .ok() - .filter(|pred| h.output_neighbours(*pred).count() == 1) - { + if let Some(pred) = h.input_neighbours(exit).exactly_one().ok().filter(|pred| { + // Allow only if there are no `Dom` edges into `pred`. (Ignore `Ext` edges.) + let src_nodes = h.children(*pred).flat_map(|ch| h.input_neighbours(ch)); + h.output_neighbours(*pred).count() == 1 + && src_nodes.dedup().all(|src| { + ancestor_block(h, src).is_none_or(|src| src == *pred) // Nones are `Ext` edges. + }) + }) { // Code in that predecessor can be moved outside (into a new DFG after the CFG), // and the predecessor deleted let [_, output] = h.get_io(pred).unwrap(); @@ -866,17 +887,19 @@ mod test { Ok(()) } - #[test] - fn nested_cfgs_pass() { + #[rstest] + fn nested_cfgs_pass(#[values(true, false)] nonlocal: bool) { // --> Entry --> Loop --> Tail --> EXIT // | / \ // (E->X) \<-/ let e = extension(); let tst_op = e.instantiate_extension_op("Test", []).unwrap(); - let qqu = vec![qb_t(), qb_t(), usize_t()]; + let qqu = TypeRow::from(vec![qb_t(), qb_t(), usize_t()]); let qq = TypeRow::from(vec![qb_t(); 2]); let mut outer = CFGBuilder::new(inout_sig(qqu.clone(), vec![usize_t(), qb_t()])).unwrap(); - let mut entry = outer.entry_builder(vec![qq.clone()], type_row![]).unwrap(); + let mut entry = outer + .entry_builder(vec![qq.clone()], usize_t().into()) + .unwrap(); let [q1, q2, u] = entry.input_wires_arr(); let (inner, inner_pred) = { let mut inner = entry @@ -897,31 +920,33 @@ mod test { .add_dataflow_op(Tag::new(0, vec![qq.clone()]), [q1, q2]) .unwrap() .outputs_arr(); - let entry = entry.finish_with_outputs(entry_pred, []).unwrap(); + let entry = entry.finish_with_outputs(entry_pred, [u]).unwrap(); let loop_b = { + let qu = [qb_t(), usize_t()]; let mut loop_b = outer - .block_builder(qq.clone(), [qb_t().into(), usize_t().into()], qb_t().into()) + .block_builder(qqu, qu.clone().map(TypeRow::from), Vec::from(qu).into()) .unwrap(); - let [q1, q2] = loop_b.input_wires_arr(); + let [q1, q2, u_local] = loop_b.input_wires_arr(); // u here is `dom` edge from entry block let [pred] = loop_b - .add_dataflow_op(tst_op, [q1, u]) + .add_dataflow_op(tst_op, [q1, if nonlocal { u } else { u_local }]) .unwrap() .outputs_arr(); - loop_b.finish_with_outputs(pred, [q2]).unwrap() + loop_b.finish_with_outputs(pred, [q2, u_local]).unwrap() }; outer.branch(&entry, 0, &loop_b).unwrap(); outer.branch(&loop_b, 0, &loop_b).unwrap(); let (tail_b, tail_pred) = { let uq = TypeRow::from(vec![usize_t(), qb_t()]); + let uqu = vec![usize_t(), qb_t(), usize_t()].into(); let mut tail_b = outer - .block_builder(uq.clone(), vec![uq.clone()], type_row![]) + .block_builder(uqu, vec![uq.clone()], type_row![]) .unwrap(); - let [u, q] = tail_b.input_wires_arr(); + let [u, q, _] = tail_b.input_wires_arr(); let [br] = tail_b - .add_dataflow_op(Tag::new(0, vec![uq.clone()]), [u, q]) + .add_dataflow_op(Tag::new(0, vec![uq]), [u, q]) .unwrap() .outputs_arr(); (tail_b.finish_with_outputs(br, []).unwrap(), br.node()) @@ -929,6 +954,7 @@ mod test { outer.branch(&loop_b, 1, &tail_b).unwrap(); outer.branch(&tail_b, 0, &outer.exit_block()).unwrap(); let mut h = outer.finish_hugr().unwrap(); + // Sanity checks: assert_eq!( h.get_parent(h.get_parent(inner_pred).unwrap()), Some(inner.node()) @@ -943,36 +969,59 @@ mod test { Some(NormalizeCFGResult::CFGToDFG) ); let Some(NormalizeCFGResult::CFGPreserved { - entry_dfg: Some(entry_dfg), + entry_dfg, exit_dfg: Some(tail_dfg), num_merged: 0, }) = res.remove(&h.entrypoint()) else { panic!("Unexpected result") }; + assert!(res.is_empty()); - // Now contains only one CFG with one BB (self-loop) + assert_eq!( h.nodes() .filter(|n| h.get_optype(*n).is_cfg()) - .exactly_one() - .ok(), - Some(h.entrypoint()) + .collect_vec(), + vec![h.entrypoint()] ); - let [entry, exit] = h.children(h.entrypoint()).collect_array().unwrap(); - assert_eq!(h.output_neighbours(entry).collect_vec(), [entry, exit]); + let [loop_, exit] = if nonlocal { + let [entry, exit, loop_] = h.children(h.entrypoint()).collect_array().unwrap(); + assert_eq!(h.get_parent(entry_pred.node()), Some(entry)); + [loop_, exit] + } else { + h.children(h.entrypoint()).collect_array().unwrap() + }; + + assert_eq!(h.output_neighbours(loop_).collect_vec(), [loop_, exit]); + // Inner CFG is now a DFG (and still sibling of entry_pred)... assert_eq!(h.get_parent(inner_pred), Some(inner.node())); assert_eq!(h.get_optype(inner.node()).tag(), OpTag::Dfg); assert_eq!(h.get_parent(inner.node()), h.get_parent(entry_pred.node())); // Predicates lifted appropriately... - for (n, parent) in [(entry_pred.node(), entry_dfg), (tail_pred.node(), tail_dfg)] { - assert_eq!(h.get_parent(n), Some(parent)); - assert_eq!(h.get_optype(parent).tag(), OpTag::Dfg); - assert_eq!(h.get_parent(parent), h.get_parent(h.entrypoint())); - } + let func = h.get_parent(h.entrypoint()).unwrap(); + + assert_eq!(h.get_parent(tail_pred.node()), Some(tail_dfg)); + assert_eq!(h.get_optype(tail_dfg).tag(), OpTag::Dfg); + assert_eq!(h.get_parent(tail_dfg), Some(func)); + let lifted_preds = if nonlocal { + assert!(entry_dfg.is_none()); + // entry_pred not lifted, still connected to output + let [output] = h + .output_neighbours(entry_pred.node()) + .collect_array() + .unwrap(); + assert_eq!(h.get_optype(output).tag(), OpTag::Output); + vec![inner_pred.node(), tail_pred.node()] + } else { + assert_eq!(h.get_parent(entry_dfg.unwrap()), Some(func)); + assert_eq!(h.get_parent(entry_pred.node()), entry_dfg); + vec![inner_pred.node(), entry_pred.node(), tail_pred.node()] + }; + // ...and followed by UnpackTuple's - for n in [inner_pred, entry_pred.node(), tail_pred.node()] { + for n in lifted_preds { let [unpack] = h.output_neighbours(n).collect_array().unwrap(); assert!( h.get_optype(unpack) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index ee03830bd4..e6c8d63a5e 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -1117,8 +1117,8 @@ mod test { where GenericArrayValue: CustomConst, { - let mut dfb = - DFGBuilder::new(inout_sig(type_row![], AK::ty(vals.len() as _, usize_t()))).unwrap(); + let sig = inout_sig(type_row![], AK::ty(vals.len() as _, usize_t())); + let mut dfb = FunctionBuilder::new("main", sig).unwrap(); let c = dfb.add_load_value(GenericArrayValue::::new( usize_t(), vals.iter().map(|u| ConstUsize::new(*u).into()), diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index edc6128813..6bf78f0383 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -375,7 +375,7 @@ mod test { use hugr_core::builder::{ BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, - HugrBuilder, inout_sig, + FunctionBuilder, HugrBuilder, inout_sig, }; use hugr_core::extension::prelude::{option_type, qb_t, usize_t}; @@ -912,7 +912,7 @@ mod test { ); let build_hugr = |ty: Type| { - let mut dfb = DFGBuilder::new(Signature::new(ty.clone(), vec![])).unwrap(); + let mut dfb = FunctionBuilder::new("main", Signature::new(ty.clone(), vec![])).unwrap(); let [inp] = dfb.input_wires_arr(); let drop_op = drop_ext .instantiate_extension_op("drop", [ty.into()])