Skip to content

Commit cad2f55

Browse files
authored
feat(hugr-passes)!: normalize_cfgs inlines entry DFG (#2649)
Breaking followup to #2633. * normalize_cfgs moves the entry block outside the CFG *even if there are outgoing `Dom` edges* by inlining the DFG (so the entry block children become siblings of the CFG), as this makes said edges into valid `Ext` edges. Sadly, I don't see a good/corresponding treatment for the exit dfg; it doesn't help that the source block of the `Dom` edges must dominate the exit (and hence we could break the CFG just before that source block and start another CFG), because even then there could be backedges to the source (new CFG entry) block, so we can't necessarily lift it outside. Options thus seem to be: 1. thread the values from the source block through the whole CFG 2. add extra outputs to the CFG, and thus to each predecessor of the new exit block - but that requires either finding the corresponding Tag, or also adding new inputs to any other successors of those predecessors (and any other predecessors of those, and so on) ...and I've not tried either of those here. BREAKING CHANGE: NormalizeCFGResult specifies entry_nodes_moved not entry_dfg (as no DFG is inserted).
1 parent 164fef0 commit cad2f55

File tree

1 file changed

+67
-86
lines changed

1 file changed

+67
-86
lines changed

hugr-passes/src/normalize_cfgs.rs

Lines changed: 67 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,9 @@ pub enum NormalizeCFGResult<N = Node> {
8484
CFGToDFG,
8585
/// The CFG was preserved, but the entry or exit blocks may have changed.
8686
CFGPreserved {
87-
/// If `Some`, the new [DFG] containing what was previously in the entry block
88-
entry_dfg: Option<N>,
87+
/// Nodes that were in the entry block but have been moved to be siblings of the CFG.
88+
/// (Either empty, or all the nodes in the entry block except [Input]/[Output].)
89+
entry_nodes_moved: Vec<N>,
8990
/// If `Some`, the new [DFG] of what was previously in the last block before the exit
9091
exit_dfg: Option<N>,
9192
/// The number of basic blocks merged together.
@@ -186,7 +187,7 @@ pub fn normalize_cfg<H: HugrMut>(
186187
// However, we only do this if the Entry block has just one successor (i.e. we can remove
187188
// the entry block altogether) - an extension would be to do this in other cases, preserving
188189
// the Entry block as an empty branch.
189-
let mut entry_dfg = None;
190+
let mut entry_nodes_moved = Vec::new();
190191
if let Some(succ) = h
191192
.output_neighbours(entry)
192193
.exactly_one()
@@ -217,51 +218,59 @@ pub fn normalize_cfg<H: HugrMut>(
217218
unpack_before_output(h, h.get_io(cfg_node).unwrap()[1], result_tys);
218219
return Ok(NormalizeCFGResult::CFGToDFG);
219220
}
220-
// 1b. Move entry block outside/before the CFG into a DFG; its successor becomes the entry block.
221+
// 1b. Move entry block outside/before the CFG; its successor becomes the entry block.
221222
let new_cfg_inputs = entry_blk.successor_input(0).unwrap();
222223
// Look for nonlocal `Dom` edges from the entry block. (Ignore `Ext` edges.)
223-
let dests = h.children(entry).flat_map(|n| h.output_neighbours(n));
224-
let has_dom_outs = dests.dedup().any(|succ| {
225-
ancestor_block(h, succ).expect("Dom edges within entry, Ext within CFG") != entry
226-
});
227-
if !has_dom_outs {
228-
// Move entry block contents into DFG.
229-
let dfg = h.add_node_with_parent(
230-
cfg_parent,
231-
DFG {
232-
signature: Signature::new(entry_blk.inputs.clone(), new_cfg_inputs.clone()),
233-
},
234-
);
235-
let [_, entry_output] = h.get_io(entry).unwrap();
236-
while let Some(n) = h.first_child(entry) {
237-
h.set_parent(n, dfg);
238-
}
239-
h.move_before_sibling(succ, entry);
240-
h.remove_node(entry);
224+
let nonlocal_srcs = h
225+
.children(entry)
226+
.filter(|n| {
227+
h.output_neighbours(*n).any(|succ| {
228+
ancestor_block(h, succ).expect("Dom edges within entry, Ext within CFG")
229+
!= entry
230+
})
231+
})
232+
.collect::<Vec<_>>();
233+
// Move entry block contents into DFG.
234+
let dfg = h.add_node_with_parent(
235+
cfg_parent,
236+
DFG {
237+
signature: Signature::new(entry_blk.inputs.clone(), new_cfg_inputs.clone()),
238+
},
239+
);
240+
let [_, entry_output] = h.get_io(entry).unwrap();
241+
while let Some(n) = h.first_child(entry) {
242+
h.set_parent(n, dfg);
243+
}
244+
h.move_before_sibling(succ, entry);
245+
h.remove_node(entry);
241246

242-
unpack_before_output(h, entry_output, new_cfg_inputs.clone());
247+
unpack_before_output(h, entry_output, new_cfg_inputs.clone());
243248

244-
// Inputs to CFG go directly to DFG
245-
for inp in h.node_inputs(cfg_node).collect::<Vec<_>>() {
246-
for src in h.linked_outputs(cfg_node, inp).collect::<Vec<_>>() {
247-
h.connect(src.0, src.1, dfg, inp.index());
248-
}
249-
h.disconnect(cfg_node, inp);
249+
// Inputs to CFG go directly to DFG
250+
for inp in h.node_inputs(cfg_node).collect::<Vec<_>>() {
251+
for src in h.linked_outputs(cfg_node, inp).collect::<Vec<_>>() {
252+
h.connect(src.0, src.1, dfg, inp.index());
250253
}
254+
h.disconnect(cfg_node, inp);
255+
}
251256

252-
// Update input ports
253-
let cfg_ty = cfg_ty_mut(h, cfg_node);
254-
let inputs_to_add =
255-
new_cfg_inputs.len() as isize - cfg_ty.signature.input.len() as isize;
256-
cfg_ty.signature.input = new_cfg_inputs;
257-
h.add_ports(cfg_node, Direction::Incoming, inputs_to_add);
257+
// Update input ports
258+
let cfg_ty = cfg_ty_mut(h, cfg_node);
259+
let inputs_to_add = new_cfg_inputs.len() as isize - cfg_ty.signature.input.len() as isize;
260+
cfg_ty.signature.input = new_cfg_inputs;
261+
h.add_ports(cfg_node, Direction::Incoming, inputs_to_add);
258262

259-
// Wire outputs of DFG directly to CFG
260-
for src in h.node_outputs(dfg).collect::<Vec<_>>() {
261-
h.connect(dfg, src, cfg_node, src.index());
262-
}
263-
entry_dfg = Some(dfg);
263+
// Wire outputs of DFG directly to CFG
264+
for src in h.node_outputs(dfg).collect::<Vec<_>>() {
265+
h.connect(dfg, src, cfg_node, src.index());
266+
}
267+
// Inline DFG to ensure that any nonlocal (`Dom`) edges from it, become valid `Ext` edges
268+
for n in nonlocal_srcs {
269+
// With required Order edge. (Do this before inlining, in case n is Input.)
270+
h.add_other_edge(n, cfg_node);
264271
}
272+
entry_nodes_moved.extend(h.children(dfg).skip(2)); // Skip Input/Output nodes
273+
h.apply_patch(InlineDFG(dfg.into())).unwrap();
265274
}
266275
// 2. If the exit node has a single predecessor and that predecessor has no other successors...
267276
let mut exit_dfg = None;
@@ -323,7 +332,7 @@ pub fn normalize_cfg<H: HugrMut>(
323332
exit_dfg = Some(dfg);
324333
}
325334
Ok(NormalizeCFGResult::CFGPreserved {
326-
entry_dfg,
335+
entry_nodes_moved,
327336
exit_dfg,
328337
num_merged,
329338
})
@@ -776,13 +785,14 @@ mod test {
776785
let res = normalize_cfg(&mut h).unwrap();
777786
h.validate().unwrap();
778787
let NormalizeCFGResult::CFGPreserved {
779-
entry_dfg: Some(dfg),
788+
entry_nodes_moved,
780789
exit_dfg: None,
781790
num_merged: 0,
782791
} = res
783792
else {
784793
panic!("Unexpected result");
785794
};
795+
assert_eq!(entry_nodes_moved.len(), 4); // Noop, Const, LoadConstant, UnpackTuple
786796
assert_eq!(
787797
h.children(h.entrypoint())
788798
.map(|n| h.get_optype(n).tag())
@@ -793,20 +803,8 @@ mod test {
793803
let func_children = child_tags_ext_ids(&h, func);
794804
assert_eq!(
795805
func_children.into_iter().sorted().collect_vec(),
796-
["Cfg", "Dfg", "Input", "Output",]
797-
);
798-
assert_eq!(
799-
h.children(func)
800-
.filter(|n| h.get_optype(*n).is_dfg())
801-
.collect_vec(),
802-
[dfg]
803-
);
804-
assert_eq!(
805-
child_tags_ext_ids(&h, dfg)
806-
.into_iter()
807-
.sorted()
808-
.collect_vec(),
809806
[
807+
"Cfg",
810808
"Const",
811809
"Input",
812810
"LoadConst",
@@ -849,13 +847,14 @@ mod test {
849847
let res = normalize_cfg(&mut h).unwrap();
850848
h.validate().unwrap();
851849
let NormalizeCFGResult::CFGPreserved {
852-
entry_dfg: None,
850+
entry_nodes_moved,
853851
exit_dfg: Some(dfg),
854852
num_merged: 0,
855853
} = res
856854
else {
857855
panic!("Unexpected result");
858856
};
857+
assert_eq!(entry_nodes_moved, []);
859858
assert_eq!(
860859
h.children(h.entrypoint())
861860
.map(|n| h.get_optype(n).tag())
@@ -963,65 +962,47 @@ mod test {
963962
assert_eq!(h.get_parent(tail_pred.node()), Some(tail_b.node()));
964963

965964
let mut res = NormalizeCFGPass::default().run(&mut h).unwrap();
965+
966966
h.validate().unwrap();
967967
assert_eq!(
968968
res.remove(&inner.node()),
969969
Some(NormalizeCFGResult::CFGToDFG)
970970
);
971971
let Some(NormalizeCFGResult::CFGPreserved {
972-
entry_dfg,
972+
entry_nodes_moved,
973973
exit_dfg: Some(tail_dfg),
974974
num_merged: 0,
975975
}) = res.remove(&h.entrypoint())
976976
else {
977977
panic!("Unexpected result")
978978
};
979-
980979
assert!(res.is_empty());
981-
980+
assert_eq!(entry_nodes_moved.len(), 3);
981+
// Now contains only one CFG with one BB (self-loop)
982982
assert_eq!(
983983
h.nodes()
984984
.filter(|n| h.get_optype(*n).is_cfg())
985-
.collect_vec(),
986-
vec![h.entrypoint()]
985+
.exactly_one()
986+
.ok(),
987+
Some(h.entrypoint())
987988
);
988-
let [loop_, exit] = if nonlocal {
989-
let [entry, exit, loop_] = h.children(h.entrypoint()).collect_array().unwrap();
990-
assert_eq!(h.get_parent(entry_pred.node()), Some(entry));
991-
[loop_, exit]
992-
} else {
993-
h.children(h.entrypoint()).collect_array().unwrap()
994-
};
995-
996-
assert_eq!(h.output_neighbours(loop_).collect_vec(), [loop_, exit]);
997-
989+
let [entry, exit] = h.children(h.entrypoint()).collect_array().unwrap();
990+
assert_eq!(h.output_neighbours(entry).collect_vec(), [entry, exit]);
998991
// Inner CFG is now a DFG (and still sibling of entry_pred)...
999992
assert_eq!(h.get_parent(inner_pred), Some(inner.node()));
1000993
assert_eq!(h.get_optype(inner.node()).tag(), OpTag::Dfg);
1001994
assert_eq!(h.get_parent(inner.node()), h.get_parent(entry_pred.node()));
995+
1002996
// Predicates lifted appropriately...
1003997
let func = h.get_parent(h.entrypoint()).unwrap();
998+
assert_eq!(h.get_parent(entry_pred.node()), Some(func));
1004999

10051000
assert_eq!(h.get_parent(tail_pred.node()), Some(tail_dfg));
10061001
assert_eq!(h.get_optype(tail_dfg).tag(), OpTag::Dfg);
10071002
assert_eq!(h.get_parent(tail_dfg), Some(func));
1008-
let lifted_preds = if nonlocal {
1009-
assert!(entry_dfg.is_none());
1010-
// entry_pred not lifted, still connected to output
1011-
let [output] = h
1012-
.output_neighbours(entry_pred.node())
1013-
.collect_array()
1014-
.unwrap();
1015-
assert_eq!(h.get_optype(output).tag(), OpTag::Output);
1016-
vec![inner_pred.node(), tail_pred.node()]
1017-
} else {
1018-
assert_eq!(h.get_parent(entry_dfg.unwrap()), Some(func));
1019-
assert_eq!(h.get_parent(entry_pred.node()), entry_dfg);
1020-
vec![inner_pred.node(), entry_pred.node(), tail_pred.node()]
1021-
};
10221003

10231004
// ...and followed by UnpackTuple's
1024-
for n in lifted_preds {
1005+
for n in [inner_pred, entry_pred.node(), tail_pred.node()] {
10251006
let [unpack] = h.output_neighbours(n).collect_array().unwrap();
10261007
assert!(
10271008
h.get_optype(unpack)

0 commit comments

Comments
 (0)