-
Notifications
You must be signed in to change notification settings - Fork 13
fix: validation outside entrypoint, normalize_cfgs w/ nonlocal edges #2633
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
2c647ca
15b92f7
d087d4d
0d048e2
485adbd
0908342
483b4ba
7cdd73c
51ea7cf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -168,6 +168,16 @@ pub fn normalize_cfg<H: HugrMut>( | |
| _ => unreachable!(), // Checked at entry to normalize_cfg | ||
| } | ||
| } | ||
| let ancestor_block = |h: &H, mut n: H::Node| { | ||
| while let Some(p) = h.get_parent(n) { | ||
| if p == cfg_node { | ||
| return Some(n); | ||
| } | ||
| n = p; | ||
| } | ||
| None | ||
| }; | ||
|
|
||
| // Further normalizations with effects outside the CFG | ||
| let [entry, exit] = h.children(cfg_node).take(2).collect_array().unwrap(); | ||
| let entry_blk = h.get_optype(entry).as_dataflow_block().unwrap(); | ||
|
|
@@ -209,49 +219,58 @@ pub fn normalize_cfg<H: HugrMut>( | |
| } | ||
| // 1b. Move entry block outside/before the CFG into a DFG; its successor becomes the entry block. | ||
| let new_cfg_inputs = entry_blk.successor_input(0).unwrap(); | ||
| let dfg = h.add_node_with_parent( | ||
| cfg_parent, | ||
| DFG { | ||
| signature: Signature::new(entry_blk.inputs.clone(), new_cfg_inputs.clone()), | ||
| }, | ||
| ); | ||
| let [_, entry_output] = h.get_io(entry).unwrap(); | ||
| while let Some(n) = h.first_child(entry) { | ||
| h.set_parent(n, dfg); | ||
| } | ||
| h.move_before_sibling(succ, entry); | ||
| h.remove_node(entry); | ||
| // Look for nonlocal `Dom` edges from the entry block. | ||
| let has_nonlocals = h | ||
| .children(entry) | ||
| .flat_map(|n| h.output_neighbours(n)) | ||
| .any(|succ| ancestor_block(h, succ).unwrap() != entry); | ||
| if !has_nonlocals { | ||
| // Move entry block contents into DFG. | ||
| let dfg = h.add_node_with_parent( | ||
| cfg_parent, | ||
| DFG { | ||
| signature: Signature::new(entry_blk.inputs.clone(), new_cfg_inputs.clone()), | ||
| }, | ||
| ); | ||
| let [_, entry_output] = h.get_io(entry).unwrap(); | ||
| while let Some(n) = h.first_child(entry) { | ||
| h.set_parent(n, dfg); | ||
| } | ||
| h.move_before_sibling(succ, entry); | ||
| h.remove_node(entry); | ||
|
|
||
| unpack_before_output(h, entry_output, new_cfg_inputs.clone()); | ||
| unpack_before_output(h, entry_output, new_cfg_inputs.clone()); | ||
|
|
||
| // Inputs to CFG go directly to DFG | ||
| for inp in h.node_inputs(cfg_node).collect::<Vec<_>>() { | ||
| for src in h.linked_outputs(cfg_node, inp).collect::<Vec<_>>() { | ||
| h.connect(src.0, src.1, dfg, inp.index()); | ||
| // Inputs to CFG go directly to DFG | ||
| for inp in h.node_inputs(cfg_node).collect::<Vec<_>>() { | ||
| for src in h.linked_outputs(cfg_node, inp).collect::<Vec<_>>() { | ||
| h.connect(src.0, src.1, dfg, inp.index()); | ||
| } | ||
| h.disconnect(cfg_node, inp); | ||
| } | ||
| h.disconnect(cfg_node, inp); | ||
| } | ||
|
|
||
| // Update input ports | ||
| let cfg_ty = cfg_ty_mut(h, cfg_node); | ||
| let inputs_to_add = new_cfg_inputs.len() as isize - cfg_ty.signature.input.len() as isize; | ||
| cfg_ty.signature.input = new_cfg_inputs; | ||
| h.add_ports(cfg_node, Direction::Incoming, inputs_to_add); | ||
| // Update input ports | ||
| let cfg_ty = cfg_ty_mut(h, cfg_node); | ||
| let inputs_to_add = | ||
| new_cfg_inputs.len() as isize - cfg_ty.signature.input.len() as isize; | ||
| cfg_ty.signature.input = new_cfg_inputs; | ||
| h.add_ports(cfg_node, Direction::Incoming, inputs_to_add); | ||
|
|
||
| // Wire outputs of DFG directly to CFG | ||
| for src in h.node_outputs(dfg).collect::<Vec<_>>() { | ||
| h.connect(dfg, src, cfg_node, src.index()); | ||
| // Wire outputs of DFG directly to CFG | ||
| for src in h.node_outputs(dfg).collect::<Vec<_>>() { | ||
| h.connect(dfg, src, cfg_node, src.index()); | ||
| } | ||
| entry_dfg = Some(dfg); | ||
| } | ||
| entry_dfg = Some(dfg); | ||
| } | ||
| // 2. If the exit node has a single predecessor and that predecessor has no other successors... | ||
| let mut exit_dfg = None; | ||
| if let Some(pred) = h | ||
| .input_neighbours(exit) | ||
| .exactly_one() | ||
| .ok() | ||
| .filter(|pred| h.output_neighbours(*pred).count() == 1) | ||
| { | ||
| if let Some(pred) = h.input_neighbours(exit).exactly_one().ok().filter(|pred| { | ||
| h.output_neighbours(*pred).count() == 1 | ||
| && // Allow only if no node in `pred` has nonlocal inputs | ||
| h.children(*pred) | ||
| .all(|ch| h.input_neighbours(ch).all(|n| ancestor_block(h, n).is_none_or(|src| src == *pred))) | ||
|
||
| }) { | ||
| // Code in that predecessor can be moved outside (into a new DFG after the CFG), | ||
| // and the predecessor deleted | ||
| let [_, output] = h.get_io(pred).unwrap(); | ||
|
|
@@ -866,17 +885,19 @@ mod test { | |
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn nested_cfgs_pass() { | ||
| #[rstest] | ||
| fn nested_cfgs_pass(#[values(true, false)] nonlocal: bool) { | ||
| // --> Entry --> Loop --> Tail --> EXIT | ||
| // | / \ | ||
| // (E->X) \<-/ | ||
| let e = extension(); | ||
| let tst_op = e.instantiate_extension_op("Test", []).unwrap(); | ||
| let qqu = vec![qb_t(), qb_t(), usize_t()]; | ||
| let qqu = TypeRow::from(vec![qb_t(), qb_t(), usize_t()]); | ||
| let qq = TypeRow::from(vec![qb_t(); 2]); | ||
| let mut outer = CFGBuilder::new(inout_sig(qqu.clone(), vec![usize_t(), qb_t()])).unwrap(); | ||
| let mut entry = outer.entry_builder(vec![qq.clone()], type_row![]).unwrap(); | ||
| let mut entry = outer | ||
| .entry_builder(vec![qq.clone()], usize_t().into()) | ||
| .unwrap(); | ||
| let [q1, q2, u] = entry.input_wires_arr(); | ||
| let (inner, inner_pred) = { | ||
| let mut inner = entry | ||
|
|
@@ -897,38 +918,41 @@ mod test { | |
| .add_dataflow_op(Tag::new(0, vec![qq.clone()]), [q1, q2]) | ||
| .unwrap() | ||
| .outputs_arr(); | ||
| let entry = entry.finish_with_outputs(entry_pred, []).unwrap(); | ||
| let entry = entry.finish_with_outputs(entry_pred, [u]).unwrap(); | ||
|
|
||
| let loop_b = { | ||
| let qu = [qb_t(), usize_t()]; | ||
| let mut loop_b = outer | ||
| .block_builder(qq.clone(), [qb_t().into(), usize_t().into()], qb_t().into()) | ||
| .block_builder(qqu, qu.clone().map(TypeRow::from), Vec::from(qu).into()) | ||
| .unwrap(); | ||
| let [q1, q2] = loop_b.input_wires_arr(); | ||
| let [q1, q2, u_local] = loop_b.input_wires_arr(); | ||
| // u here is `dom` edge from entry block | ||
| let [pred] = loop_b | ||
| .add_dataflow_op(tst_op, [q1, u]) | ||
| .add_dataflow_op(tst_op, [q1, if nonlocal { u } else { u_local }]) | ||
| .unwrap() | ||
| .outputs_arr(); | ||
| loop_b.finish_with_outputs(pred, [q2]).unwrap() | ||
| loop_b.finish_with_outputs(pred, [q2, u_local]).unwrap() | ||
| }; | ||
| outer.branch(&entry, 0, &loop_b).unwrap(); | ||
| outer.branch(&loop_b, 0, &loop_b).unwrap(); | ||
|
|
||
| let (tail_b, tail_pred) = { | ||
| let uq = TypeRow::from(vec![usize_t(), qb_t()]); | ||
| let uqu = vec![usize_t(), qb_t(), usize_t()].into(); | ||
| let mut tail_b = outer | ||
| .block_builder(uq.clone(), vec![uq.clone()], type_row![]) | ||
| .block_builder(uqu, vec![uq.clone()], type_row![]) | ||
| .unwrap(); | ||
| let [u, q] = tail_b.input_wires_arr(); | ||
| let [u, q, _] = tail_b.input_wires_arr(); | ||
| let [br] = tail_b | ||
| .add_dataflow_op(Tag::new(0, vec![uq.clone()]), [u, q]) | ||
| .add_dataflow_op(Tag::new(0, vec![uq]), [u, q]) | ||
| .unwrap() | ||
| .outputs_arr(); | ||
| (tail_b.finish_with_outputs(br, []).unwrap(), br.node()) | ||
| }; | ||
| outer.branch(&loop_b, 1, &tail_b).unwrap(); | ||
| outer.branch(&tail_b, 0, &outer.exit_block()).unwrap(); | ||
| let mut h = outer.finish_hugr().unwrap(); | ||
| // Sanity checks: | ||
| assert_eq!( | ||
| h.get_parent(h.get_parent(inner_pred).unwrap()), | ||
| Some(inner.node()) | ||
|
|
@@ -943,36 +967,59 @@ mod test { | |
| Some(NormalizeCFGResult::CFGToDFG) | ||
| ); | ||
| let Some(NormalizeCFGResult::CFGPreserved { | ||
| entry_dfg: Some(entry_dfg), | ||
| entry_dfg, | ||
| exit_dfg: Some(tail_dfg), | ||
| num_merged: 0, | ||
| }) = res.remove(&h.entrypoint()) | ||
| else { | ||
| panic!("Unexpected result") | ||
| }; | ||
|
|
||
| assert!(res.is_empty()); | ||
| // Now contains only one CFG with one BB (self-loop) | ||
|
|
||
| assert_eq!( | ||
| h.nodes() | ||
| .filter(|n| h.get_optype(*n).is_cfg()) | ||
| .exactly_one() | ||
| .ok(), | ||
| Some(h.entrypoint()) | ||
| .collect_vec(), | ||
| vec![h.entrypoint()] | ||
| ); | ||
| let [entry, exit] = h.children(h.entrypoint()).collect_array().unwrap(); | ||
| assert_eq!(h.output_neighbours(entry).collect_vec(), [entry, exit]); | ||
| let [loop_, exit] = if nonlocal { | ||
| let [entry, exit, loop_] = h.children(h.entrypoint()).collect_array().unwrap(); | ||
| assert_eq!(h.get_parent(entry_pred.node()), Some(entry)); | ||
| [loop_, exit] | ||
| } else { | ||
| h.children(h.entrypoint()).collect_array().unwrap() | ||
| }; | ||
|
|
||
| assert_eq!(h.output_neighbours(loop_).collect_vec(), [loop_, exit]); | ||
|
|
||
| // Inner CFG is now a DFG (and still sibling of entry_pred)... | ||
| assert_eq!(h.get_parent(inner_pred), Some(inner.node())); | ||
| assert_eq!(h.get_optype(inner.node()).tag(), OpTag::Dfg); | ||
| assert_eq!(h.get_parent(inner.node()), h.get_parent(entry_pred.node())); | ||
| // Predicates lifted appropriately... | ||
| for (n, parent) in [(entry_pred.node(), entry_dfg), (tail_pred.node(), tail_dfg)] { | ||
| assert_eq!(h.get_parent(n), Some(parent)); | ||
| assert_eq!(h.get_optype(parent).tag(), OpTag::Dfg); | ||
| assert_eq!(h.get_parent(parent), h.get_parent(h.entrypoint())); | ||
| } | ||
| let func = h.get_parent(h.entrypoint()).unwrap(); | ||
|
|
||
| assert_eq!(h.get_parent(tail_pred.node()), Some(tail_dfg)); | ||
| assert_eq!(h.get_optype(tail_dfg).tag(), OpTag::Dfg); | ||
| assert_eq!(h.get_parent(tail_dfg), Some(func)); | ||
| let lifted_preds = if nonlocal { | ||
| assert!(entry_dfg.is_none()); | ||
| // entry_pred not lifted, still connected to output | ||
| let [output] = h | ||
| .output_neighbours(entry_pred.node()) | ||
| .collect_array() | ||
| .unwrap(); | ||
| assert_eq!(h.get_optype(output).tag(), OpTag::Output); | ||
| vec![inner_pred.node(), tail_pred.node()] | ||
| } else { | ||
| assert_eq!(h.get_parent(entry_dfg.unwrap()), Some(func)); | ||
| assert_eq!(h.get_parent(entry_pred.node()), entry_dfg); | ||
| vec![inner_pred.node(), entry_pred.node(), tail_pred.node()] | ||
| }; | ||
|
|
||
| // ...and followed by UnpackTuple's | ||
| for n in [inner_pred, entry_pred.node(), tail_pred.node()] { | ||
| for n in lifted_preds { | ||
| let [unpack] = h.output_neighbours(n).collect_array().unwrap(); | ||
| assert!( | ||
| h.get_optype(unpack) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.