Skip to content
49 changes: 46 additions & 3 deletions hugr-core/src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ impl Hugr {
let mut ordered = Vec::with_capacity(self.num_nodes());
let root = self.module_root();
let mut new_entrypoint = self.entrypoint;
ordered.extend(self.as_mut().canonical_order(root));
ordered.extend(self.canonical_order(root));

// Permute the nodes in the graph to match the order.
//
Expand All @@ -332,8 +332,10 @@ impl Hugr {
self.hierarchy.swap_nodes(pg_target, pg_source);
rekey(source, target);

if source.into_portgraph() == self.entrypoint {
if source.into_portgraph() == new_entrypoint {
new_entrypoint = target.into_portgraph();
} else if target.into_portgraph() == new_entrypoint {
new_entrypoint = source.into_portgraph();
}
}
}
Expand Down Expand Up @@ -492,9 +494,13 @@ pub(crate) mod test {

use super::*;

use crate::builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder};
use crate::envelope::{EnvelopeError, PackageEncodingError};
use crate::extension::prelude::bool_t;
use crate::ops::OpaqueOp;
use crate::ops::handle::{FuncID, NodeHandle};
use crate::test_file;
use crate::types::Signature;
use cool_asserts::assert_matches;
use portgraph::LinkView;

Expand Down Expand Up @@ -557,7 +563,6 @@ pub(crate) mod test {
#[test]
fn io_node() {
use crate::builder::test::simple_dfg_hugr;
use cool_asserts::assert_matches;

let hugr = simple_dfg_hugr();
assert_matches!(hugr.get_io(hugr.entrypoint()), Some(_));
Expand Down Expand Up @@ -612,4 +617,42 @@ pub(crate) mod test {
);
assert_matches!(&hugr, Ok(_));
}

#[test]
fn canonicalize_nodes() {
let sig = Signature::new(vec![bool_t(); 2], bool_t());
let mut mb = ModuleBuilder::new();
let mut fa = mb.define_function("a", sig.clone()).unwrap();
// Recursive call requires getting the handle from builder
let f_id = FuncID::<true>::from(fa.container_node());
let mut dfg = fa.dfg_builder(sig.clone(), fa.input_wires()).unwrap();
let call = dfg.call(&f_id, &[], dfg.input_wires()).unwrap();
let dfg = dfg.finish_with_outputs(call.outputs()).unwrap();
let fa = fa.finish_with_outputs(dfg.outputs()).unwrap();
let fb = mb.define_function("b", sig).unwrap();
let [fst, _] = fb.input_wires_arr();
let fb = fb.finish_with_outputs([fst]).unwrap();
let mut h = mb.finish_hugr().unwrap();
h.set_entrypoint(dfg.node());
let static_in = h.get_optype(call.node()).static_input_port().unwrap();
let static_out = h.get_optype(fb.node()).static_output_port().unwrap();
assert_eq!(
h.single_linked_output(call.node(), static_in)
.map(|(n, _p)| n),
Some(fa.node())
);
h.disconnect(call.node(), static_in);
h.connect(fb.node(), static_out, call.node(), static_in);

fn find_dfgs(h: &Hugr) -> Vec<Node> {
h.nodes().filter(|n| h.get_optype(*n).is_dfg()).collect()
}
assert_eq!(find_dfgs(&h), [dfg.node()]);
assert_eq!(h.entrypoint(), dfg.node());

h.canonicalize_nodes(|_, _| ());
let [dfg] = find_dfgs(&h).try_into().unwrap();
// This was failing, https://github.com/CQCL/hugr/issues/2262:
assert_eq!(h.entrypoint(), dfg);
}
}
Loading