Skip to content
92 changes: 75 additions & 17 deletions hugr-core/src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,38 +307,42 @@ impl Hugr {
/// preserve the indices.
pub fn canonicalize_nodes(&mut self, mut rekey: impl FnMut(Node, Node)) {
// Generate the ordered list of nodes
let mut ordered = Vec::with_capacity(self.num_nodes());
let root = self.module_root();
let mut new_entrypoint = self.entrypoint;
ordered.extend(self.as_mut().canonical_order(root));
let ordered = {
let mut v = Vec::with_capacity(self.num_nodes());
v.extend(self.canonical_order(self.module_root()));
v
};
let mut new_entrypoint = None;

// Permute the nodes in the graph to match the order.
//
// Invariant: All the elements before `position` are in the correct place.
for position in 0..ordered.len() {
// Find the element's location. If it originally came from a previous position
// then it has been swapped somewhere else, so we follow the permutation chain.
let pg_target = portgraph::NodeIndex::new(position);
let mut source: Node = ordered[position];

// The (old) entrypoint appears exactly once in `ordered`:
if source.into_portgraph() == self.entrypoint {
let old = new_entrypoint.replace(pg_target);
debug_assert!(old.is_none());
}

// Find the element's current location. If it originally came from an earlier
// position then it has been swapped somewhere else, so follow the permutation chain.
while position > source.index() {
source = ordered[source.index()];
}

let target: Node = portgraph::NodeIndex::new(position).into();
if target != source {
let pg_target = target.into_portgraph();
let pg_source = source.into_portgraph();
let pg_source = source.into_portgraph();
if pg_target != pg_source {
self.graph.swap_nodes(pg_target, pg_source);
self.op_types.swap(pg_target, pg_source);
self.hierarchy.swap_nodes(pg_target, pg_source);
rekey(source, target);

if source.into_portgraph() == self.entrypoint {
new_entrypoint = target.into_portgraph();
}
rekey(source, pg_target.into());
}
}
self.module_root = portgraph::NodeIndex::new(0);
self.entrypoint = new_entrypoint;
self.entrypoint = new_entrypoint.unwrap();

// Finish by compacting the copy nodes.
// The operation nodes will be left in place.
Expand Down Expand Up @@ -492,11 +496,17 @@ pub(crate) mod test {

use super::*;

use crate::builder::{Container, Dataflow, DataflowSubContainer, ModuleBuilder};
use crate::envelope::{EnvelopeError, PackageEncodingError};
use crate::extension::prelude::bool_t;
use crate::ops::OpaqueOp;
use crate::ops::handle::NodeHandle;
use crate::test_file;
use crate::types::Signature;
use cool_asserts::assert_matches;
use itertools::Either;
use portgraph::LinkView;
use rstest::rstest;

/// Check that two HUGRs are equivalent, up to node renumbering.
pub(crate) fn check_hugr_equality(lhs: &Hugr, rhs: &Hugr) {
Expand Down Expand Up @@ -557,7 +567,6 @@ pub(crate) mod test {
#[test]
fn io_node() {
use crate::builder::test::simple_dfg_hugr;
use cool_asserts::assert_matches;

let hugr = simple_dfg_hugr();
assert_matches!(hugr.get_io(hugr.entrypoint()), Some(_));
Expand Down Expand Up @@ -612,4 +621,53 @@ pub(crate) mod test {
);
assert_matches!(&hugr, Ok(_));
}

fn hugr_failing_2262() -> Hugr {
let sig = Signature::new(vec![bool_t(); 2], bool_t());
let mut mb = ModuleBuilder::new();
let mut fa = mb.define_function("a", sig.clone()).unwrap();
let mut dfg = fa.dfg_builder(sig.clone(), fa.input_wires()).unwrap();
// Add Call node - without a static target as we'll create that later
let call_op = ops::Call::try_new(sig.clone().into(), []).unwrap();
let call = dfg.add_dataflow_op(call_op, dfg.input_wires()).unwrap();
let dfg = dfg.finish_with_outputs(call.outputs()).unwrap();
fa.finish_with_outputs(dfg.outputs()).unwrap();
let fb = mb.define_function("b", sig).unwrap();
let [fst, _] = fb.input_wires_arr();
let fb = fb.finish_with_outputs([fst]).unwrap();
let mut h = mb.hugr().clone();

h.set_entrypoint(dfg.node()); // Entrypoint unused, but to highlight failing case
let static_in = h.get_optype(call.node()).static_input_port().unwrap();
let static_out = h.get_optype(fb.node()).static_output_port().unwrap();
assert_eq!(h.single_linked_output(call.node(), static_in), None);
h.disconnect(call.node(), static_in);
h.connect(fb.node(), static_out, call.node(), static_in);
h
}

#[rstest]
// Opening files is not supported in (isolated) miri
#[cfg_attr(not(miri), case(Either::Left(test_file!("hugr-1.hugr"))))]
#[cfg_attr(not(miri), case(Either::Left(test_file!("hugr-3.hugr"))))]
// Next was failing, https://github.com/CQCL/hugr/issues/2262:
#[case(Either::Right(hugr_failing_2262()))]
fn canonicalize_entrypoint(#[case] file_or_hugr: Either<&str, Hugr>) {
let hugr = match file_or_hugr {
Either::Left(file) => {
Hugr::load(BufReader::new(File::open(file).unwrap()), None).unwrap()
}
Either::Right(hugr) => hugr,
};
hugr.validate().unwrap();

for n in hugr.nodes() {
let mut h2 = hugr.clone();
h2.set_entrypoint(n);
if h2.validate().is_ok() {
h2.canonicalize_nodes(|_, _| {});
assert_eq!(hugr.get_optype(n), h2.entrypoint_optype());
}
}
}
}
Loading