diff --git a/hugr-passes/src/normalize_cfgs.rs b/hugr-passes/src/normalize_cfgs.rs index d26c969b8..20d2494c0 100644 --- a/hugr-passes/src/normalize_cfgs.rs +++ b/hugr-passes/src/normalize_cfgs.rs @@ -84,8 +84,9 @@ pub enum NormalizeCFGResult { CFGToDFG, /// The CFG was preserved, but the entry or exit blocks may have changed. CFGPreserved { - /// If `Some`, the new [DFG] containing what was previously in the entry block - entry_dfg: Option, + /// Nodes that were in the entry block but have been moved to be siblings of the CFG. + /// (Either empty, or all the nodes in the entry block except [Input]/[Output].) + entry_nodes_moved: Vec, /// If `Some`, the new [DFG] of what was previously in the last block before the exit exit_dfg: Option, /// The number of basic blocks merged together. @@ -186,7 +187,7 @@ pub fn normalize_cfg( // However, we only do this if the Entry block has just one successor (i.e. we can remove // the entry block altogether) - an extension would be to do this in other cases, preserving // the Entry block as an empty branch. - let mut entry_dfg = None; + let mut entry_nodes_moved = Vec::new(); if let Some(succ) = h .output_neighbours(entry) .exactly_one() @@ -217,51 +218,59 @@ pub fn normalize_cfg( unpack_before_output(h, h.get_io(cfg_node).unwrap()[1], result_tys); return Ok(NormalizeCFGResult::CFGToDFG); } - // 1b. Move entry block outside/before the CFG into a DFG; its successor becomes the entry block. + // 1b. Move entry block outside/before the CFG; its successor becomes the entry block. let new_cfg_inputs = entry_blk.successor_input(0).unwrap(); // Look for nonlocal `Dom` edges from the entry block. (Ignore `Ext` edges.) - let dests = h.children(entry).flat_map(|n| h.output_neighbours(n)); - let has_dom_outs = dests.dedup().any(|succ| { - ancestor_block(h, succ).expect("Dom edges within entry, Ext within CFG") != entry - }); - if !has_dom_outs { - // Move entry block contents into DFG. - let dfg = h.add_node_with_parent( - cfg_parent, - DFG { - signature: Signature::new(entry_blk.inputs.clone(), new_cfg_inputs.clone()), - }, - ); - let [_, entry_output] = h.get_io(entry).unwrap(); - while let Some(n) = h.first_child(entry) { - h.set_parent(n, dfg); - } - h.move_before_sibling(succ, entry); - h.remove_node(entry); + let nonlocal_srcs = h + .children(entry) + .filter(|n| { + h.output_neighbours(*n).any(|succ| { + ancestor_block(h, succ).expect("Dom edges within entry, Ext within CFG") + != entry + }) + }) + .collect::>(); + // Move entry block contents into DFG. + let dfg = h.add_node_with_parent( + cfg_parent, + DFG { + signature: Signature::new(entry_blk.inputs.clone(), new_cfg_inputs.clone()), + }, + ); + let [_, entry_output] = h.get_io(entry).unwrap(); + while let Some(n) = h.first_child(entry) { + h.set_parent(n, dfg); + } + h.move_before_sibling(succ, entry); + h.remove_node(entry); - unpack_before_output(h, entry_output, new_cfg_inputs.clone()); + unpack_before_output(h, entry_output, new_cfg_inputs.clone()); - // Inputs to CFG go directly to DFG - for inp in h.node_inputs(cfg_node).collect::>() { - for src in h.linked_outputs(cfg_node, inp).collect::>() { - h.connect(src.0, src.1, dfg, inp.index()); - } - h.disconnect(cfg_node, inp); + // Inputs to CFG go directly to DFG + for inp in h.node_inputs(cfg_node).collect::>() { + for src in h.linked_outputs(cfg_node, inp).collect::>() { + h.connect(src.0, src.1, dfg, inp.index()); } + h.disconnect(cfg_node, inp); + } - // Update input ports - let cfg_ty = cfg_ty_mut(h, cfg_node); - let inputs_to_add = - new_cfg_inputs.len() as isize - cfg_ty.signature.input.len() as isize; - cfg_ty.signature.input = new_cfg_inputs; - h.add_ports(cfg_node, Direction::Incoming, inputs_to_add); + // Update input ports + let cfg_ty = cfg_ty_mut(h, cfg_node); + let inputs_to_add = new_cfg_inputs.len() as isize - cfg_ty.signature.input.len() as isize; + cfg_ty.signature.input = new_cfg_inputs; + h.add_ports(cfg_node, Direction::Incoming, inputs_to_add); - // Wire outputs of DFG directly to CFG - for src in h.node_outputs(dfg).collect::>() { - h.connect(dfg, src, cfg_node, src.index()); - } - entry_dfg = Some(dfg); + // Wire outputs of DFG directly to CFG + for src in h.node_outputs(dfg).collect::>() { + h.connect(dfg, src, cfg_node, src.index()); + } + // Inline DFG to ensure that any nonlocal (`Dom`) edges from it, become valid `Ext` edges + for n in nonlocal_srcs { + // With required Order edge. (Do this before inlining, in case n is Input.) + h.add_other_edge(n, cfg_node); } + entry_nodes_moved.extend(h.children(dfg).skip(2)); // Skip Input/Output nodes + h.apply_patch(InlineDFG(dfg.into())).unwrap(); } // 2. If the exit node has a single predecessor and that predecessor has no other successors... let mut exit_dfg = None; @@ -323,7 +332,7 @@ pub fn normalize_cfg( exit_dfg = Some(dfg); } Ok(NormalizeCFGResult::CFGPreserved { - entry_dfg, + entry_nodes_moved, exit_dfg, num_merged, }) @@ -776,13 +785,14 @@ mod test { let res = normalize_cfg(&mut h).unwrap(); h.validate().unwrap(); let NormalizeCFGResult::CFGPreserved { - entry_dfg: Some(dfg), + entry_nodes_moved, exit_dfg: None, num_merged: 0, } = res else { panic!("Unexpected result"); }; + assert_eq!(entry_nodes_moved.len(), 4); // Noop, Const, LoadConstant, UnpackTuple assert_eq!( h.children(h.entrypoint()) .map(|n| h.get_optype(n).tag()) @@ -793,20 +803,8 @@ mod test { let func_children = child_tags_ext_ids(&h, func); assert_eq!( func_children.into_iter().sorted().collect_vec(), - ["Cfg", "Dfg", "Input", "Output",] - ); - assert_eq!( - h.children(func) - .filter(|n| h.get_optype(*n).is_dfg()) - .collect_vec(), - [dfg] - ); - assert_eq!( - child_tags_ext_ids(&h, dfg) - .into_iter() - .sorted() - .collect_vec(), [ + "Cfg", "Const", "Input", "LoadConst", @@ -849,13 +847,14 @@ mod test { let res = normalize_cfg(&mut h).unwrap(); h.validate().unwrap(); let NormalizeCFGResult::CFGPreserved { - entry_dfg: None, + entry_nodes_moved, exit_dfg: Some(dfg), num_merged: 0, } = res else { panic!("Unexpected result"); }; + assert_eq!(entry_nodes_moved, []); assert_eq!( h.children(h.entrypoint()) .map(|n| h.get_optype(n).tag()) @@ -963,65 +962,47 @@ mod test { assert_eq!(h.get_parent(tail_pred.node()), Some(tail_b.node())); let mut res = NormalizeCFGPass::default().run(&mut h).unwrap(); + h.validate().unwrap(); assert_eq!( res.remove(&inner.node()), Some(NormalizeCFGResult::CFGToDFG) ); let Some(NormalizeCFGResult::CFGPreserved { - entry_dfg, + entry_nodes_moved, exit_dfg: Some(tail_dfg), num_merged: 0, }) = res.remove(&h.entrypoint()) else { panic!("Unexpected result") }; - assert!(res.is_empty()); - + assert_eq!(entry_nodes_moved.len(), 3); + // Now contains only one CFG with one BB (self-loop) assert_eq!( h.nodes() .filter(|n| h.get_optype(*n).is_cfg()) - .collect_vec(), - vec![h.entrypoint()] + .exactly_one() + .ok(), + Some(h.entrypoint()) ); - let [loop_, exit] = if nonlocal { - let [entry, exit, loop_] = h.children(h.entrypoint()).collect_array().unwrap(); - assert_eq!(h.get_parent(entry_pred.node()), Some(entry)); - [loop_, exit] - } else { - h.children(h.entrypoint()).collect_array().unwrap() - }; - - assert_eq!(h.output_neighbours(loop_).collect_vec(), [loop_, exit]); - + let [entry, exit] = h.children(h.entrypoint()).collect_array().unwrap(); + assert_eq!(h.output_neighbours(entry).collect_vec(), [entry, exit]); // Inner CFG is now a DFG (and still sibling of entry_pred)... assert_eq!(h.get_parent(inner_pred), Some(inner.node())); assert_eq!(h.get_optype(inner.node()).tag(), OpTag::Dfg); assert_eq!(h.get_parent(inner.node()), h.get_parent(entry_pred.node())); + // Predicates lifted appropriately... let func = h.get_parent(h.entrypoint()).unwrap(); + assert_eq!(h.get_parent(entry_pred.node()), Some(func)); assert_eq!(h.get_parent(tail_pred.node()), Some(tail_dfg)); assert_eq!(h.get_optype(tail_dfg).tag(), OpTag::Dfg); assert_eq!(h.get_parent(tail_dfg), Some(func)); - let lifted_preds = if nonlocal { - assert!(entry_dfg.is_none()); - // entry_pred not lifted, still connected to output - let [output] = h - .output_neighbours(entry_pred.node()) - .collect_array() - .unwrap(); - assert_eq!(h.get_optype(output).tag(), OpTag::Output); - vec![inner_pred.node(), tail_pred.node()] - } else { - assert_eq!(h.get_parent(entry_dfg.unwrap()), Some(func)); - assert_eq!(h.get_parent(entry_pred.node()), entry_dfg); - vec![inner_pred.node(), entry_pred.node(), tail_pred.node()] - }; // ...and followed by UnpackTuple's - for n in lifted_preds { + for n in [inner_pred, entry_pred.node(), tail_pred.node()] { let [unpack] = h.output_neighbours(n).collect_array().unwrap(); assert!( h.get_optype(unpack)