Skip to content
14 changes: 5 additions & 9 deletions hugr-core/src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ impl<'a, H: HugrView> ValidationContext<'a, H> {
self.validate_node(node)?;
}

// Hierarchy and children. No type variables declared outside the root.
self.validate_subtree(self.hugr.entrypoint(), &[])?;
// Hierarchy and children. No type variables declared by the module root.
self.validate_subtree(self.hugr.module_root(), &[])?;

self.validate_linkage()?;
// In tests we take the opportunity to verify that the hugr
Expand Down Expand Up @@ -600,13 +600,9 @@ impl<'a, H: HugrView> ValidationContext<'a, H> {
}

// Check port connections.
//
// Root nodes are ignored, as they cannot have connected edges.
if node != self.hugr.entrypoint() {
for dir in Direction::BOTH {
for port in self.hugr.node_ports(node, dir) {
self.validate_port(node, port, op_type, var_decls)?;
}
for dir in Direction::BOTH {
for port in self.hugr.node_ports(node, dir) {
self.validate_port(node, port, op_type, var_decls)?;
}
}

Expand Down
8 changes: 4 additions & 4 deletions hugr-passes/src/linearize_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ impl LinearizeArrayPass {

#[cfg(test)]
mod test {
use hugr_core::builder::ModuleBuilder;
use hugr_core::builder::{FunctionBuilder, ModuleBuilder};
use hugr_core::extension::prelude::{ConstUsize, Noop};
use hugr_core::ops::handle::NodeHandle;
use hugr_core::ops::{Const, OpType};
Expand Down Expand Up @@ -287,7 +287,7 @@ mod test {
),
};
let sig = Signature::new(src, tgt);
let mut builder = DFGBuilder::new(sig).unwrap();
let mut builder = FunctionBuilder::new("main", sig).unwrap();
let [arr] = builder.input_wires_arr();
let op: OpType = match dir {
INTO => VArrayToArray::new(elem_ty.clone(), size).into(),
Expand All @@ -313,7 +313,7 @@ mod test {
#[case(value_array_type(2, Type::new_tuple(vec![usize_t(), value_array_type(4, usize_t())])))]
fn implicit_clone(#[case] array_ty: Type) {
let sig = Signature::new(array_ty.clone(), vec![array_ty; 2]);
let mut builder = DFGBuilder::new(sig).unwrap();
let mut builder = FunctionBuilder::new("main", sig).unwrap();
let [arr] = builder.input_wires_arr();
builder.set_outputs(vec![arr, arr]).unwrap();

Expand All @@ -329,7 +329,7 @@ mod test {
#[case(value_array_type(2, Type::new_tuple(vec![usize_t(), value_array_type(4, usize_t())])))]
fn implicit_discard(#[case] array_ty: Type) {
let sig = Signature::new(array_ty, Type::EMPTY_TYPEROW);
let mut builder = DFGBuilder::new(sig).unwrap();
let mut builder = FunctionBuilder::new("main", sig).unwrap();
builder.set_outputs(vec![]).unwrap();

let mut hugr = builder.finish_hugr().unwrap();
Expand Down
165 changes: 106 additions & 59 deletions hugr-passes/src/normalize_cfgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,16 @@ pub fn normalize_cfg<H: HugrMut>(
_ => unreachable!(), // Checked at entry to normalize_cfg
}
}
let ancestor_block = |h: &H, mut n: H::Node| {
while let Some(p) = h.get_parent(n) {
if p == cfg_node {
return Some(n);
}
n = p;
}
None
};

// Further normalizations with effects outside the CFG
let [entry, exit] = h.children(cfg_node).take(2).collect_array().unwrap();
let entry_blk = h.get_optype(entry).as_dataflow_block().unwrap();
Expand Down Expand Up @@ -209,49 +219,58 @@ pub fn normalize_cfg<H: HugrMut>(
}
// 1b. Move entry block outside/before the CFG into a DFG; its successor becomes the entry block.
let new_cfg_inputs = entry_blk.successor_input(0).unwrap();
let dfg = h.add_node_with_parent(
cfg_parent,
DFG {
signature: Signature::new(entry_blk.inputs.clone(), new_cfg_inputs.clone()),
},
);
let [_, entry_output] = h.get_io(entry).unwrap();
while let Some(n) = h.first_child(entry) {
h.set_parent(n, dfg);
}
h.move_before_sibling(succ, entry);
h.remove_node(entry);
// Look for nonlocal `Dom` edges from the entry block.
let has_nonlocals = h
.children(entry)
.flat_map(|n| h.output_neighbours(n))
.any(|succ| ancestor_block(h, succ).unwrap() != entry);
if !has_nonlocals {
// Move entry block contents into DFG.
let dfg = h.add_node_with_parent(
cfg_parent,
DFG {
signature: Signature::new(entry_blk.inputs.clone(), new_cfg_inputs.clone()),
},
);
let [_, entry_output] = h.get_io(entry).unwrap();
while let Some(n) = h.first_child(entry) {
h.set_parent(n, dfg);
}
h.move_before_sibling(succ, entry);
h.remove_node(entry);

unpack_before_output(h, entry_output, new_cfg_inputs.clone());
unpack_before_output(h, entry_output, new_cfg_inputs.clone());

// Inputs to CFG go directly to DFG
for inp in h.node_inputs(cfg_node).collect::<Vec<_>>() {
for src in h.linked_outputs(cfg_node, inp).collect::<Vec<_>>() {
h.connect(src.0, src.1, dfg, inp.index());
// Inputs to CFG go directly to DFG
for inp in h.node_inputs(cfg_node).collect::<Vec<_>>() {
for src in h.linked_outputs(cfg_node, inp).collect::<Vec<_>>() {
h.connect(src.0, src.1, dfg, inp.index());
}
h.disconnect(cfg_node, inp);
}
h.disconnect(cfg_node, inp);
}

// Update input ports
let cfg_ty = cfg_ty_mut(h, cfg_node);
let inputs_to_add = new_cfg_inputs.len() as isize - cfg_ty.signature.input.len() as isize;
cfg_ty.signature.input = new_cfg_inputs;
h.add_ports(cfg_node, Direction::Incoming, inputs_to_add);
// Update input ports
let cfg_ty = cfg_ty_mut(h, cfg_node);
let inputs_to_add =
new_cfg_inputs.len() as isize - cfg_ty.signature.input.len() as isize;
cfg_ty.signature.input = new_cfg_inputs;
h.add_ports(cfg_node, Direction::Incoming, inputs_to_add);

// Wire outputs of DFG directly to CFG
for src in h.node_outputs(dfg).collect::<Vec<_>>() {
h.connect(dfg, src, cfg_node, src.index());
// Wire outputs of DFG directly to CFG
for src in h.node_outputs(dfg).collect::<Vec<_>>() {
h.connect(dfg, src, cfg_node, src.index());
}
entry_dfg = Some(dfg);
}
entry_dfg = Some(dfg);
}
// 2. If the exit node has a single predecessor and that predecessor has no other successors...
let mut exit_dfg = None;
if let Some(pred) = h
.input_neighbours(exit)
.exactly_one()
.ok()
.filter(|pred| h.output_neighbours(*pred).count() == 1)
{
if let Some(pred) = h.input_neighbours(exit).exactly_one().ok().filter(|pred| {
h.output_neighbours(*pred).count() == 1
&& // Allow only if no node in `pred` has nonlocal inputs
h.children(*pred)
.all(|ch| h.input_neighbours(ch).all(|n| ancestor_block(h, n).is_none_or(|src| src == *pred)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this quite an involved one-liner, could be good to split up/share code with the has_nonlocals check above

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

split up a bit (and added dedup), but not really possible to share (much) - the two examine opposite directions and this makes Ext edges look very different

}) {
// Code in that predecessor can be moved outside (into a new DFG after the CFG),
// and the predecessor deleted
let [_, output] = h.get_io(pred).unwrap();
Expand Down Expand Up @@ -866,17 +885,19 @@ mod test {
Ok(())
}

#[test]
fn nested_cfgs_pass() {
#[rstest]
fn nested_cfgs_pass(#[values(true, false)] nonlocal: bool) {
// --> Entry --> Loop --> Tail --> EXIT
// | / \
// (E->X) \<-/
let e = extension();
let tst_op = e.instantiate_extension_op("Test", []).unwrap();
let qqu = vec![qb_t(), qb_t(), usize_t()];
let qqu = TypeRow::from(vec![qb_t(), qb_t(), usize_t()]);
let qq = TypeRow::from(vec![qb_t(); 2]);
let mut outer = CFGBuilder::new(inout_sig(qqu.clone(), vec![usize_t(), qb_t()])).unwrap();
let mut entry = outer.entry_builder(vec![qq.clone()], type_row![]).unwrap();
let mut entry = outer
.entry_builder(vec![qq.clone()], usize_t().into())
.unwrap();
let [q1, q2, u] = entry.input_wires_arr();
let (inner, inner_pred) = {
let mut inner = entry
Expand All @@ -897,38 +918,41 @@ mod test {
.add_dataflow_op(Tag::new(0, vec![qq.clone()]), [q1, q2])
.unwrap()
.outputs_arr();
let entry = entry.finish_with_outputs(entry_pred, []).unwrap();
let entry = entry.finish_with_outputs(entry_pred, [u]).unwrap();

let loop_b = {
let qu = [qb_t(), usize_t()];
let mut loop_b = outer
.block_builder(qq.clone(), [qb_t().into(), usize_t().into()], qb_t().into())
.block_builder(qqu, qu.clone().map(TypeRow::from), Vec::from(qu).into())
.unwrap();
let [q1, q2] = loop_b.input_wires_arr();
let [q1, q2, u_local] = loop_b.input_wires_arr();
// u here is `dom` edge from entry block
let [pred] = loop_b
.add_dataflow_op(tst_op, [q1, u])
.add_dataflow_op(tst_op, [q1, if nonlocal { u } else { u_local }])
.unwrap()
.outputs_arr();
loop_b.finish_with_outputs(pred, [q2]).unwrap()
loop_b.finish_with_outputs(pred, [q2, u_local]).unwrap()
};
outer.branch(&entry, 0, &loop_b).unwrap();
outer.branch(&loop_b, 0, &loop_b).unwrap();

let (tail_b, tail_pred) = {
let uq = TypeRow::from(vec![usize_t(), qb_t()]);
let uqu = vec![usize_t(), qb_t(), usize_t()].into();
let mut tail_b = outer
.block_builder(uq.clone(), vec![uq.clone()], type_row![])
.block_builder(uqu, vec![uq.clone()], type_row![])
.unwrap();
let [u, q] = tail_b.input_wires_arr();
let [u, q, _] = tail_b.input_wires_arr();
let [br] = tail_b
.add_dataflow_op(Tag::new(0, vec![uq.clone()]), [u, q])
.add_dataflow_op(Tag::new(0, vec![uq]), [u, q])
.unwrap()
.outputs_arr();
(tail_b.finish_with_outputs(br, []).unwrap(), br.node())
};
outer.branch(&loop_b, 1, &tail_b).unwrap();
outer.branch(&tail_b, 0, &outer.exit_block()).unwrap();
let mut h = outer.finish_hugr().unwrap();
// Sanity checks:
assert_eq!(
h.get_parent(h.get_parent(inner_pred).unwrap()),
Some(inner.node())
Expand All @@ -943,36 +967,59 @@ mod test {
Some(NormalizeCFGResult::CFGToDFG)
);
let Some(NormalizeCFGResult::CFGPreserved {
entry_dfg: Some(entry_dfg),
entry_dfg,
exit_dfg: Some(tail_dfg),
num_merged: 0,
}) = res.remove(&h.entrypoint())
else {
panic!("Unexpected result")
};

assert!(res.is_empty());
// Now contains only one CFG with one BB (self-loop)

assert_eq!(
h.nodes()
.filter(|n| h.get_optype(*n).is_cfg())
.exactly_one()
.ok(),
Some(h.entrypoint())
.collect_vec(),
vec![h.entrypoint()]
);
let [entry, exit] = h.children(h.entrypoint()).collect_array().unwrap();
assert_eq!(h.output_neighbours(entry).collect_vec(), [entry, exit]);
let [loop_, exit] = if nonlocal {
let [entry, exit, loop_] = h.children(h.entrypoint()).collect_array().unwrap();
assert_eq!(h.get_parent(entry_pred.node()), Some(entry));
[loop_, exit]
} else {
h.children(h.entrypoint()).collect_array().unwrap()
};

assert_eq!(h.output_neighbours(loop_).collect_vec(), [loop_, exit]);

// Inner CFG is now a DFG (and still sibling of entry_pred)...
assert_eq!(h.get_parent(inner_pred), Some(inner.node()));
assert_eq!(h.get_optype(inner.node()).tag(), OpTag::Dfg);
assert_eq!(h.get_parent(inner.node()), h.get_parent(entry_pred.node()));
// Predicates lifted appropriately...
for (n, parent) in [(entry_pred.node(), entry_dfg), (tail_pred.node(), tail_dfg)] {
assert_eq!(h.get_parent(n), Some(parent));
assert_eq!(h.get_optype(parent).tag(), OpTag::Dfg);
assert_eq!(h.get_parent(parent), h.get_parent(h.entrypoint()));
}
let func = h.get_parent(h.entrypoint()).unwrap();

assert_eq!(h.get_parent(tail_pred.node()), Some(tail_dfg));
assert_eq!(h.get_optype(tail_dfg).tag(), OpTag::Dfg);
assert_eq!(h.get_parent(tail_dfg), Some(func));
let lifted_preds = if nonlocal {
assert!(entry_dfg.is_none());
// entry_pred not lifted, still connected to output
let [output] = h
.output_neighbours(entry_pred.node())
.collect_array()
.unwrap();
assert_eq!(h.get_optype(output).tag(), OpTag::Output);
vec![inner_pred.node(), tail_pred.node()]
} else {
assert_eq!(h.get_parent(entry_dfg.unwrap()), Some(func));
assert_eq!(h.get_parent(entry_pred.node()), entry_dfg);
vec![inner_pred.node(), entry_pred.node(), tail_pred.node()]
};

// ...and followed by UnpackTuple's
for n in [inner_pred, entry_pred.node(), tail_pred.node()] {
for n in lifted_preds {
let [unpack] = h.output_neighbours(n).collect_array().unwrap();
assert!(
h.get_optype(unpack)
Expand Down
4 changes: 2 additions & 2 deletions hugr-passes/src/replace_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1117,8 +1117,8 @@ mod test {
where
GenericArrayValue<AK>: CustomConst,
{
let mut dfb =
DFGBuilder::new(inout_sig(type_row![], AK::ty(vals.len() as _, usize_t()))).unwrap();
let sig = inout_sig(type_row![], AK::ty(vals.len() as _, usize_t()));
let mut dfb = FunctionBuilder::new("main", sig).unwrap();
let c = dfb.add_load_value(GenericArrayValue::<AK>::new(
usize_t(),
vals.iter().map(|u| ConstUsize::new(*u).into()),
Expand Down
4 changes: 2 additions & 2 deletions hugr-passes/src/replace_types/linearize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ mod test {

use hugr_core::builder::{
BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer,
HugrBuilder, inout_sig,
FunctionBuilder, HugrBuilder, inout_sig,
};

use hugr_core::extension::prelude::{option_type, qb_t, usize_t};
Expand Down Expand Up @@ -912,7 +912,7 @@ mod test {
);

let build_hugr = |ty: Type| {
let mut dfb = DFGBuilder::new(Signature::new(ty.clone(), vec![])).unwrap();
let mut dfb = FunctionBuilder::new("main", Signature::new(ty.clone(), vec![])).unwrap();
let [inp] = dfb.input_wires_arr();
let drop_op = drop_ext
.instantiate_extension_op("drop", [ty.into()])
Expand Down
Loading