diff --git a/hugr-core/src/hugr/patch/simple_replace.rs b/hugr-core/src/hugr/patch/simple_replace.rs index eeba47409e..46e5bde205 100644 --- a/hugr-core/src/hugr/patch/simple_replace.rs +++ b/hugr-core/src/hugr/patch/simple_replace.rs @@ -126,11 +126,16 @@ impl SimpleReplacement { /// of `self`. /// /// The returned port will be in `replacement`, unless the wire in the - /// replacement is empty, in which case it will another `host` port. + /// replacement is empty and `boundary` is [`BoundaryMode::SnapToHost`] (the + /// default), in which case it will be another `host` port. If + /// [`BoundaryMode::IncludeIO`] is passed, the returned port will always + /// be in `replacement` even if it is invalid (i.e. it is an IO node in + /// the replacement). pub fn linked_replacement_output( &self, port: impl Into>, host: &impl HugrView, + boundary: BoundaryMode, ) -> Option> { let HostPort(node, port) = port.into(); let pos = self @@ -139,7 +144,7 @@ impl SimpleReplacement { .iter() .position(move |&(n, p)| host.linked_inputs(n, p).contains(&(node, port)))?; - Some(self.linked_replacement_output_by_position(pos, host)) + Some(self.linked_replacement_output_by_position(pos, host, boundary)) } /// The outgoing port linked to the i-th output boundary edge of `subgraph`. @@ -150,6 +155,7 @@ impl SimpleReplacement { &self, pos: usize, host: &impl HugrView, + boundary: BoundaryMode, ) -> BoundaryPort { debug_assert!(pos < self.subgraph().signature(host).output_count()); @@ -160,7 +166,7 @@ impl SimpleReplacement { .single_linked_output(repl_out, pos) .expect("valid dfg wire"); - if out_node != repl_inp { + if out_node != repl_inp || boundary == BoundaryMode::IncludeIO { BoundaryPort::Replacement(out_node, out_port) } else { let (in_node, in_port) = *self.subgraph.incoming_ports()[out_port.index()] @@ -207,11 +213,16 @@ impl SimpleReplacement { /// of `self`. /// /// The returned ports will be in `replacement`, unless the wires in the - /// replacement are empty, in which case they are other `host` ports. + /// replacement are empty and `boundary` is [`BoundaryMode::SnapToHost`] + /// (the default), in which case they will be other `host` ports. If + /// [`BoundaryMode::IncludeIO`] is passed, the returned ports will + /// always be in `replacement` even if they are invalid (i.e. they are + /// an IO node in the replacement). pub fn linked_replacement_inputs<'a>( &'a self, port: impl Into>, host: &'a impl HugrView, + boundary: BoundaryMode, ) -> impl Iterator> + 'a { let HostPort(node, port) = port.into(); let positions = self @@ -223,18 +234,16 @@ impl SimpleReplacement { host.single_linked_output(n, p).expect("valid dfg wire") == (node, port) }); - positions.flat_map(|pos| self.linked_replacement_inputs_by_position(pos, host)) + positions + .flat_map(move |pos| self.linked_replacement_inputs_by_position(pos, host, boundary)) } /// The incoming ports linked to the i-th input boundary edge of `subgraph`. - /// - /// The ports will be in `replacement` for all endpoints of the i-th input - /// wire that are not the output node of `replacement` and be in `host` - /// otherwise. fn linked_replacement_inputs_by_position( &self, pos: usize, host: &impl HugrView, + boundary: BoundaryMode, ) -> impl Iterator> { debug_assert!(pos < self.subgraph().signature(host).input_count()); @@ -242,7 +251,7 @@ impl SimpleReplacement { self.replacement .linked_inputs(repl_inp, pos) .flat_map(move |(in_node, in_port)| { - if in_node != repl_out { + if in_node != repl_out || boundary == BoundaryMode::IncludeIO { Either::Left(std::iter::once(BoundaryPort::Replacement(in_node, in_port))) } else { let (out_node, out_port) = self.subgraph.outgoing_ports()[in_port.index()]; @@ -316,7 +325,7 @@ impl SimpleReplacement { subgraph_outgoing_ports .enumerate() .flat_map(|(pos, subg_np)| { - self.linked_replacement_inputs_by_position(pos, host) + self.linked_replacement_inputs_by_position(pos, host, BoundaryMode::SnapToHost) .filter_map(move |np| Some((np.as_replacement()?, subg_np))) }) .map(|((repl_node, repl_port), (subgraph_node, subgraph_port))| { @@ -359,7 +368,7 @@ impl SimpleReplacement { .enumerate() .filter_map(|(pos, subg_all)| { let np = self - .linked_replacement_output_by_position(pos, host) + .linked_replacement_output_by_position(pos, host, BoundaryMode::SnapToHost) .as_replacement()?; Some((np, subg_all)) }) @@ -406,7 +415,7 @@ impl SimpleReplacement { .enumerate() .filter_map(|(pos, subg_all)| { Some(( - self.linked_replacement_output_by_position(pos, host) + self.linked_replacement_output_by_position(pos, host, BoundaryMode::SnapToHost) .as_host()?, subg_all, )) @@ -517,7 +526,8 @@ impl SimpleReplacement { SimpleReplacement::try_new(subgraph, new_host, replacement.clone()) } - /// Allows to get the [Self::invalidated_nodes] without requiring a [HugrView]. + /// Allows to get the [Self::invalidated_nodes] without requiring a + /// [HugrView]. pub fn invalidation_set(&self) -> impl Iterator { self.subgraph.nodes().iter().copied() } @@ -540,6 +550,24 @@ impl PatchVerification for SimpleReplacement { } } +/// In [`SimpleReplacement::replacement`], IO nodes marking the boundary will +/// not be valid nodes in the host after the replacement is applied. +/// +/// This enum allows specifying whether these invalid nodes on the boundary +/// should be returned or should be resolved to valid nodes in the host. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] +pub enum BoundaryMode { + /// Only consider nodes that are valid after the replacement is applied. + /// + /// This means that nodes in hosts may be returned in places where nodes in + /// the replacement would be typically expected. + #[default] + SnapToHost, + /// Include all nodes, including potentially invalid ones (inputs and + /// outputs of replacements). + IncludeIO, +} + /// Result of applying a [`SimpleReplacement`]. pub struct Outcome { /// Map from Node in replacement to corresponding Node in the result Hugr @@ -652,7 +680,7 @@ pub(in crate::hugr::patch) mod test { ModuleBuilder, endo_sig, inout_sig, }; use crate::extension::prelude::{bool_t, qb_t}; - use crate::hugr::patch::simple_replace::Outcome; + use crate::hugr::patch::simple_replace::{BoundaryMode, Outcome}; use crate::hugr::patch::{BoundaryPort, HostPort, PatchVerification, ReplacementPort}; use crate::hugr::views::{HugrView, SiblingSubgraph}; use crate::hugr::{Hugr, HugrMut, Patch}; @@ -1145,7 +1173,11 @@ pub(in crate::hugr::patch) mod test { // Test linked_replacement_inputs with empty replacement let replacement_inputs: Vec<_> = repl - .linked_replacement_inputs((inp, OutgoingPort::from(0)), &hugr) + .linked_replacement_inputs( + (inp, OutgoingPort::from(0)), + &hugr, + BoundaryMode::SnapToHost, + ) .collect(); assert_eq!( @@ -1158,8 +1190,12 @@ pub(in crate::hugr::patch) mod test { // Test linked_replacement_output with empty replacement let replacement_output = (0..4) .map(|i| { - repl.linked_replacement_output((out, IncomingPort::from(i)), &hugr) - .unwrap() + repl.linked_replacement_output( + (out, IncomingPort::from(i)), + &hugr, + BoundaryMode::SnapToHost, + ) + .unwrap() }) .collect_vec(); @@ -1191,7 +1227,11 @@ pub(in crate::hugr::patch) mod test { }; let replacement_inputs: Vec<_> = repl - .linked_replacement_inputs((inp, OutgoingPort::from(0)), &hugr) + .linked_replacement_inputs( + (inp, OutgoingPort::from(0)), + &hugr, + BoundaryMode::SnapToHost, + ) .collect(); assert_eq!( @@ -1203,8 +1243,12 @@ pub(in crate::hugr::patch) mod test { let replacement_output = (0..4) .map(|i| { - repl.linked_replacement_output((out, IncomingPort::from(i)), &hugr) - .unwrap() + repl.linked_replacement_output( + (out, IncomingPort::from(i)), + &hugr, + BoundaryMode::SnapToHost, + ) + .unwrap() }) .collect_vec(); @@ -1241,7 +1285,11 @@ pub(in crate::hugr::patch) mod test { }; let replacement_inputs: Vec<_> = repl - .linked_replacement_inputs((inp, OutgoingPort::from(0)), &hugr) + .linked_replacement_inputs( + (inp, OutgoingPort::from(0)), + &hugr, + BoundaryMode::SnapToHost, + ) .collect(); assert_eq!( @@ -1257,8 +1305,12 @@ pub(in crate::hugr::patch) mod test { let replacement_output = (0..4) .map(|i| { - repl.linked_replacement_output((out, IncomingPort::from(i)), &hugr) - .unwrap() + repl.linked_replacement_output( + (out, IncomingPort::from(i)), + &hugr, + BoundaryMode::SnapToHost, + ) + .unwrap() }) .collect_vec(); diff --git a/hugr-persistent/src/persistent_hugr.rs b/hugr-persistent/src/persistent_hugr.rs index 452cbff3df..bdca32ec1f 100644 --- a/hugr-persistent/src/persistent_hugr.rs +++ b/hugr-persistent/src/persistent_hugr.rs @@ -38,6 +38,9 @@ impl Commit { /// Requires a reference to the commit state space that the nodes in /// `replacement` refer to. /// + /// Use [`Self::try_new`] instead if the parents of the commit cannot be + /// inferred from the invalidation set of `replacement` alone. + /// /// The replacement must act on a non-empty subgraph, otherwise this /// function will return an [`InvalidCommit::EmptyReplacement`] error. /// @@ -47,20 +50,37 @@ impl Commit { pub fn try_from_replacement( replacement: PersistentReplacement, graph: &CommitStateSpace, + ) -> Result { + Self::try_new(replacement, [], graph) + } + + /// Create a new commit + /// + /// Requires a reference to the commit state space that the nodes in + /// `replacement` refer to. + /// + /// The returned commit will correspond to the application of `replacement` + /// and will be the child of the commits in `parents` as well as of all + /// the commits in the invalidation set of `replacement`. + /// + /// The replacement must act on a non-empty subgraph, otherwise this + /// function will return an [`InvalidCommit::EmptyReplacement`] error. + /// If any of the parents of the replacement are not in the commit state + /// space, this function will return an [`InvalidCommit::UnknownParent`] + /// error. + pub fn try_new( + replacement: PersistentReplacement, + parents: impl IntoIterator, + graph: &CommitStateSpace, ) -> Result { if replacement.subgraph().nodes().is_empty() { return Err(InvalidCommit::EmptyReplacement); } - let parent_ids = replacement.invalidation_set().map(|n| n.0).unique(); - let parents = parent_ids - .map(|id| { - if graph.contains_id(id) { - Ok(graph.get_commit(id).clone()) - } else { - Err(InvalidCommit::UnknownParent(id)) - } - }) - .collect::, _>>()?; + let repl_parents = get_parent_commits(&replacement, graph)?; + let parents = parents + .into_iter() + .chain(repl_parents) + .unique_by(|p| p.as_ptr()); let rc = RelRc::with_parents( replacement.into(), parents.into_iter().map(|p| (p.into(), ())), @@ -434,6 +454,8 @@ impl PersistentHugr { pub fn base_commit(&self) -> &Commit; /// Get the commit with ID `commit_id`. pub fn get_commit(&self, commit_id: CommitId) -> &Commit; + /// Check whether `commit_id` exists and return it. + pub fn try_get_commit(&self, commit_id: CommitId) -> Option<&Commit>; /// Get an iterator over all nodes inserted by `commit_id`. /// /// All nodes will be PatchNodes with commit ID `commit_id`. @@ -529,6 +551,32 @@ impl PersistentHugr { .expect("invalid port") .is_value() } + + pub(super) fn value_ports( + &self, + patch_node @ PatchNode(commit_id, node): PatchNode, + dir: Direction, + ) -> impl Iterator + '_ { + let hugr = self.commit_hugr(commit_id); + let ports = hugr.node_ports(node, dir); + ports.filter_map(move |p| self.is_value_port(patch_node, p).then_some((patch_node, p))) + } + + pub(super) fn output_value_ports( + &self, + patch_node: PatchNode, + ) -> impl Iterator + '_ { + self.value_ports(patch_node, Direction::Outgoing) + .map(|(n, p)| (n, p.as_outgoing().expect("unexpected port direction"))) + } + + pub(super) fn input_value_ports( + &self, + patch_node: PatchNode, + ) -> impl Iterator + '_ { + self.value_ports(patch_node, Direction::Incoming) + .map(|(n, p)| (n, p.as_incoming().expect("unexpected port direction"))) + } } impl IntoIterator for PersistentHugr { @@ -549,11 +597,11 @@ impl IntoIterator for PersistentHugr { /// among `children`. pub(crate) fn find_conflicting_node<'a>( commit_id: CommitId, - mut children: impl Iterator, + children: impl IntoIterator, ) -> Option { let mut all_invalidated = BTreeSet::new(); - children.find_map(|child| { + children.into_iter().find_map(|child| { let mut new_invalidated = child .invalidation_set() @@ -567,3 +615,18 @@ pub(crate) fn find_conflicting_node<'a>( new_invalidated.find(|&n| !all_invalidated.insert(n)) }) } + +fn get_parent_commits( + replacement: &PersistentReplacement, + graph: &CommitStateSpace, +) -> Result, InvalidCommit> { + let parent_ids = replacement.invalidation_set().map(|n| n.owner()).unique(); + parent_ids + .map(|id| { + graph + .try_get_commit(id) + .cloned() + .ok_or(InvalidCommit::UnknownParent(id)) + }) + .collect() +} diff --git a/hugr-persistent/src/state_space.rs b/hugr-persistent/src/state_space.rs index cdf5e734e8..0fb6d61da7 100644 --- a/hugr-persistent/src/state_space.rs +++ b/hugr-persistent/src/state_space.rs @@ -6,7 +6,15 @@ use delegate::delegate; use derive_more::From; use hugr_core::{ Direction, Hugr, HugrView, IncomingPort, Node, OutgoingPort, Port, SimpleReplacement, - hugr::{self, internal::HugrInternals, patch::BoundaryPort}, + hugr::{ + self, + internal::HugrInternals, + patch::{ + BoundaryPort, + simple_replace::{BoundaryMode, InvalidReplacement}, + }, + views::{InvalidSignature, sibling_subgraph::InvalidSubgraph}, + }, ops::OpType, }; use itertools::{Either, Itertools}; @@ -25,10 +33,24 @@ pub type CommitId = relrc::NodeId; /// A HUGR node within a commit of the commit state space #[derive( - Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Debug, Hash, serde::Serialize, serde::Deserialize, + Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, serde::Serialize, serde::Deserialize, )] pub struct PatchNode(pub CommitId, pub Node); +impl PatchNode { + /// Get the commit ID of the commit that owns this node. + pub fn owner(&self) -> CommitId { + self.0 + } +} + +// Print out PatchNodes as `Node(x)@commit_hex` +impl std::fmt::Debug for PatchNode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}@{}", self.1, self.0) + } +} + impl std::fmt::Display for PatchNode { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{self:?}") @@ -231,6 +253,12 @@ impl CommitStateSpace { self.graph.get_node(commit_id).into() } + /// Check whether `commit_id` exists and return it. + pub fn try_get_commit(&self, commit_id: CommitId) -> Option<&Commit> { + self.contains_id(commit_id) + .then(|| self.get_commit(commit_id)) + } + /// Get an iterator over all commit IDs in the state space. pub fn all_commit_ids(&self) -> impl Iterator + Clone + '_ { let vec = self.graph.all_node_ids().collect_vec(); @@ -308,6 +336,11 @@ impl CommitStateSpace { /// Get the boundary inputs linked to `(node, port)` in `child`. /// + /// The returned ports will be ports on successors of the input node in the + /// `child` commit, unless (node, port) is connected to a passthrough wire + /// in `child` (i.e. a wire from input node to output node), in which + /// case they will be in one of the parents of `child`. + /// /// `child` should be a child commit of the owner of `node`. /// /// ## Panics @@ -319,6 +352,7 @@ impl CommitStateSpace { node: PatchNode, port: OutgoingPort, child: CommitId, + return_invalid: BoundaryMode, ) -> impl Iterator + '_ { assert!( self.is_boundary_edge(node, port, child), @@ -327,7 +361,7 @@ impl CommitStateSpace { let parent_hugrs = ParentsView::from_commit(child, self); let repl = self.replacement(child).expect("valid child commit"); - repl.linked_replacement_inputs((node, port), &parent_hugrs) + repl.linked_replacement_inputs((node, port), &parent_hugrs, return_invalid) .collect_vec() .into_iter() .map(move |np| match np { @@ -338,6 +372,11 @@ impl CommitStateSpace { /// Get the single boundary output linked to `(node, port)` in `child`. /// + /// The returned port will be a port on a predecessor of the output node in + /// the `child` commit, unless (node, port) is connected to a passthrough + /// wire in `child` (i.e. a wire from input node to output node), in + /// which case it will be in one of the parents of `child`. + /// /// `child` should be a child commit of the owner of `node` (or `None` will /// be returned). /// @@ -349,10 +388,11 @@ impl CommitStateSpace { node: PatchNode, port: IncomingPort, child: CommitId, + return_invalid: BoundaryMode, ) -> Option<(PatchNode, OutgoingPort)> { let parent_hugrs = ParentsView::from_commit(child, self); let repl = self.replacement(child)?; - match repl.linked_replacement_output((node, port), &parent_hugrs)? { + match repl.linked_replacement_output((node, port), &parent_hugrs, return_invalid)? { BoundaryPort::Host(patch_node, port) => (patch_node, port), BoundaryPort::Replacement(node, port) => (PatchNode(child, node), port), } @@ -370,28 +410,32 @@ impl CommitStateSpace { node: PatchNode, port: impl Into, child: CommitId, + return_invalid: BoundaryMode, ) -> impl Iterator + '_ { match port.into().as_directed() { Either::Left(incoming) => Either::Left( - self.linked_child_output(node, incoming, child) + self.linked_child_output(node, incoming, child, return_invalid) .into_iter() .map(|(node, port)| (node, port.into())), ), Either::Right(outgoing) => Either::Right( - self.linked_child_inputs(node, outgoing, child) + self.linked_child_inputs(node, outgoing, child, return_invalid) .map(|(node, port)| (node, port.into())), ), } } - /// Get the single output boundary port linked to `(node, port)` in a - /// parent of the commit of `node`. + /// Get the single output port linked to `(node, port)` in a parent of the + /// commit of `node`. + /// + /// The returned port belongs to the input boundary of the subgraph in + /// parent. /// /// ## Panics /// /// Panics if `(node, port)` is not connected to the input node in the /// commit of `node`, or if the node is not valid. - pub(crate) fn linked_parent_input( + pub fn linked_parent_input( &self, PatchNode(commit_id, node): PatchNode, port: IncomingPort, @@ -409,7 +453,17 @@ impl CommitStateSpace { repl.linked_host_input((node, port), &parent_hugrs).into() } - pub(crate) fn linked_parent_outputs( + /// Get the input ports linked to `(node, port)` in a parent of the commit + /// of `node`. + /// + /// The returned ports belong to the output boundary of the subgraph in + /// parent. + /// + /// ## Panics + /// + /// Panics if `(node, port)` is not connected to the output node in the + /// commit of `node`, or if the node is not valid. + pub fn linked_parent_outputs( &self, PatchNode(commit_id, node): PatchNode, port: OutgoingPort, @@ -561,4 +615,20 @@ pub enum InvalidCommit { /// The commit is an empty replacement. #[error("Not allowed: empty replacement")] EmptyReplacement, + + #[error("Invalid subgraph: {0}")] + /// The subgraph of the replacement is not convex. + InvalidSubgraph(#[from] InvalidSubgraph), + + /// The replacement of the commit is invalid. + #[error("Invalid replacement: {0}")] + InvalidReplacement(#[from] InvalidReplacement), + + /// The signature of the replacement is invalid. + #[error("Invalid signature: {0}")] + InvalidSignature(#[from] InvalidSignature), + + /// A wire has an unpinned port. + #[error("Incomplete wire: {0} is unpinned")] + IncompleteWire(PatchNode, Port), } diff --git a/hugr-persistent/src/tests.rs b/hugr-persistent/src/tests.rs index 0a6578aaa5..77b26be8ac 100644 --- a/hugr-persistent/src/tests.rs +++ b/hugr-persistent/src/tests.rs @@ -3,7 +3,7 @@ use std::collections::{BTreeMap, HashMap}; use derive_more::derive::{From, Into}; use hugr_core::{ IncomingPort, Node, OutgoingPort, SimpleReplacement, - builder::{DFGBuilder, Dataflow, DataflowHugr, inout_sig}, + builder::{DFGBuilder, Dataflow, DataflowHugr, endo_sig, inout_sig}, extension::prelude::bool_t, hugr::{Hugr, HugrView, patch::Patch, views::SiblingSubgraph}, ops::handle::NodeHandle, @@ -11,7 +11,10 @@ use hugr_core::{ }; use rstest::*; -use crate::{Commit, CommitStateSpace, PatchNode, Resolver, state_space::CommitId}; +use crate::{ + Commit, CommitStateSpace, PatchNode, PersistentHugr, PersistentReplacement, Resolver, + state_space::CommitId, +}; /// Creates a simple test Hugr with a DFG that contains a small boolean circuit /// @@ -291,6 +294,44 @@ pub(crate) fn test_state_space() -> (CommitStateSpace, [CommitId (state_space, [commit1, commit2, commit3, commit4]) } +#[fixture] +pub(super) fn persistent_hugr_empty_child() -> (PersistentHugr, [CommitId; 2], [PatchNode; 3]) { + let (triple_not_hugr, not_nodes) = { + let mut dfg_builder = DFGBuilder::new(endo_sig(bool_t())).unwrap(); + let [mut w] = dfg_builder.input_wires_arr(); + let not_nodes = [(); 3].map(|()| { + let handle = dfg_builder.add_dataflow_op(LogicOp::Not, vec![w]).unwrap(); + [w] = handle.outputs_arr(); + handle.node() + }); + ( + dfg_builder.finish_hugr_with_outputs([w]).unwrap(), + not_nodes, + ) + }; + let mut hugr = PersistentHugr::with_base(triple_not_hugr); + let empty_hugr = { + let dfg_builder = DFGBuilder::new(endo_sig(bool_t())).unwrap(); + let inputs = dfg_builder.input_wires(); + dfg_builder.finish_hugr_with_outputs(inputs).unwrap() + }; + let subg_nodes = [PatchNode(hugr.base(), not_nodes[1])]; + let repl = PersistentReplacement::try_new( + SiblingSubgraph::try_from_nodes(subg_nodes, &hugr).unwrap(), + &hugr, + empty_hugr, + ) + .unwrap(); + + let empty_commit = hugr.try_add_replacement(repl).unwrap(); + let base_commit = hugr.base(); + ( + hugr, + [base_commit, empty_commit], + not_nodes.map(|n| PatchNode(base_commit, n)), + ) +} + #[rstest] fn test_successive_replacements(test_state_space: (CommitStateSpace, [CommitId; 4])) { let (state_space, [commit1, commit2, _commit3, _commit4]) = test_state_space; diff --git a/hugr-persistent/src/walker.rs b/hugr-persistent/src/walker.rs index f346921493..1915f2b292 100644 --- a/hugr-persistent/src/walker.rs +++ b/hugr-persistent/src/walker.rs @@ -55,12 +55,24 @@ //! versions of the graph simultaneously, without having to materialize //! each version separately. -use std::{borrow::Cow, collections::BTreeSet}; +use std::{ + borrow::Cow, + collections::{BTreeMap, BTreeSet}, +}; use itertools::{Either, Itertools}; use thiserror::Error; -use hugr_core::{Direction, HugrView, Port}; +use hugr_core::{ + Direction, Hugr, HugrView, Port, PortIndex, + hugr::{ + patch::simple_replace::BoundaryMode, + views::{RootCheckable, SiblingSubgraph}, + }, + ops::handle::DfgID, +}; + +use crate::{Commit, PersistentReplacement}; use crate::{PersistentWire, PointerEqResolver, resolver::Resolver}; @@ -146,23 +158,35 @@ impl<'a, R: Resolver> Walker<'a, R> { } } else { let commit = self.state_space.get_commit(commit_id).clone(); - // TODO/Optimize: we should be able to check for an AlreadyPinned error at - // the same time that we check the ancestors are compatible in - // `PersistentHugr`, with e.g. a callback, instead of storing a backup - let backup = self.selected_commits.clone(); - self.selected_commits.try_add_commit(commit)?; - if let Some(&pinned_node) = self - .pinned_nodes - .iter() - .find(|&&n| !self.selected_commits.contains_node(n)) - { - self.selected_commits = backup; - return Err(PinNodeError::AlreadyPinned(pinned_node)); - } + self.try_select_commit(commit)?; } Ok(self.pinned_nodes.insert(node)) } + /// Add a commit to the selected commits of the Walker. + /// + /// Return the ID of the added commit if it was added successfully, or the + /// existing ID of the commit if it was already selected. + /// + /// Return an error if the commit is not compatible with the current set of + /// selected commits, or if the commit deletes an already pinned node. + pub fn try_select_commit(&mut self, commit: Commit) -> Result { + // TODO: we should be able to check for an AlreadyPinned error at + // the same time that we check the ancestors are compatible in + // `PersistentHugr`, with e.g. a callback, instead of storing a backup + let backup = self.selected_commits.clone(); + let commit_id = self.selected_commits.try_add_commit(commit)?; + if let Some(&pinned_node) = self + .pinned_nodes + .iter() + .find(|&&n| !self.selected_commits.contains_node(n)) + { + self.selected_commits = backup; + return Err(PinNodeError::AlreadyPinned(pinned_node)); + } + Ok(commit_id) + } + /// Expand the Walker by pinning a node connected to the given wire. /// /// To understand how Walkers are expanded, it is useful to understand how @@ -179,10 +203,14 @@ impl<'a, R: Resolver> Walker<'a, R> { /// walkers, which together cover the same space of possible HUGRs, each /// having a different additional node pinned. /// - /// Return an iterator over all possible [`Walker`]s that can be created by - /// pinning exactly one additional node connected to `wire`. Each returned - /// [`Walker`] represents a different alternative Hugr in the exploration - /// space. + /// If the wire is not complete yet, return an iterator over all possible + /// [`Walker`]s that can be created by pinning exactly one additional + /// node (or one additonal commit with an empty wire) connected to + /// `wire`. Each returned [`Walker`] represents a different alternative + /// Hugr in the exploration space. + /// + /// If the wire is already complete, return an iterator containing one + /// walker: the current walker unchanged. /// /// Optionally, the expansion can be restricted to only ports with the given /// direction (incoming or outgoing). @@ -199,6 +227,10 @@ impl<'a, R: Resolver> Walker<'a, R> { ) -> impl Iterator> + 'b { let dir = dir.into(); + if self.is_complete(wire, dir) { + return Either::Left(std::iter::once(self.clone())); + } + // Find unpinned ports on the wire (satisfying the direction constraint) let unpinned_ports = self.wire_unpinned_ports(wire, dir); @@ -206,20 +238,135 @@ impl<'a, R: Resolver> Walker<'a, R> { // commits) equivalent to currently unpinned ports. let pinnable_nodes = unpinned_ports .flat_map(|(node, port)| self.equivalent_descendant_ports(node, port)) - .map(|(n, _)| n) + .map(|(n, _, commits)| (n, commits)) .unique(); - pinnable_nodes.filter_map(|pinnable_node| { + let new_walkers = pinnable_nodes.filter_map(|(pinnable_node, new_commits)| { + let contains_new_commit = || { + new_commits + .iter() + .any(|&cm| !self.selected_commits.contains_id(cm)) + }; debug_assert!( - !self.is_pinned(pinnable_node), - "trying to pin already pinned node" + !self.is_pinned(pinnable_node) || contains_new_commit(), + "trying to pin already pinned node and no new commit is selected" ); - // Construct a new walker by pinning `pinnable_node` (if possible). - let mut new_walker = self.clone(); + // Update the selected commits to include the new commits. + let new_selected_commits = self + .state_space + .try_extract_hugr(self.selected_commits.all_commit_ids().chain(new_commits)) + .ok()?; + + // Make sure that the pinned nodes are still valid after including the new + // selected commits. + if self + .pinned_nodes + .iter() + .any(|&pnode| !new_selected_commits.contains_node(pnode)) + { + return None; + } + + // Construct a new walker and pin `pinnable_node`. + let mut new_walker = Walker { + state_space: self.state_space.clone(), + selected_commits: new_selected_commits, + pinned_nodes: self.pinned_nodes.clone(), + }; new_walker.try_pin_node(pinnable_node).ok()?; Some(new_walker) - }) + }); + + Either::Right(new_walkers) + } + + /// Create a new commit from a set of complete pinned wires and a + /// replacement. + /// + /// The subgraph of the commit is the subgraph given by the set of edges + /// in `wires`. `map_boundary` must provide a map from the boundary ports + /// of the subgraph to the inputs/output ports in `repl`. The returned port + /// must be of the opposite direction as the port passed as argument: + /// - an incoming subgraph port must be mapped to an outgoing port of the + /// input node of `repl` + /// - an outgoing subgraph port must be mapped to an incoming port of the + /// output node of `repl` + /// + /// ## Panics + /// + /// This will panic if repl is not a DFG graph. + pub fn try_create_commit( + &self, + wires: impl IntoIterator, + repl: impl RootCheckable, + map_boundary: impl Fn(PatchNode, Port) -> Port, + ) -> Result { + let mut wire_ports_incoming = BTreeSet::new(); + let mut wire_ports_outgoing = BTreeSet::new(); + let mut additional_parents = BTreeMap::new(); + + for w in wires { + if let Some((n, p)) = self.wire_unpinned_ports(&w, None).next() { + return Err(InvalidCommit::IncompleteWire(n, p)); + } + wire_ports_incoming.extend(w.all_incoming_ports(self.as_hugr_view())); + wire_ports_outgoing.extend(w.single_outgoing_port(self.as_hugr_view())); + for id in w.owners() { + let commit = self + .state_space + .try_get_commit(id) + .ok_or(InvalidCommit::UnknownParent(id))? + .clone(); + additional_parents.insert(id, commit); + } + } + + let mut all_nodes = BTreeSet::new(); + all_nodes.extend(wire_ports_incoming.iter().map(|&(n, _)| n)); + all_nodes.extend(wire_ports_outgoing.iter().map(|&(n, _)| n)); + + // (in/out) boundary: all in/out ports on the nodes of the wire, minus ports + // that are part of the wires + let incoming = all_nodes + .iter() + .flat_map(|&n| self.as_hugr_view().input_value_ports(n)) + .filter(|node_port| !wire_ports_incoming.contains(node_port)) + .map(|np| vec![np]) + .collect_vec(); + let outgoing = all_nodes + .iter() + .flat_map(|&n| self.as_hugr_view().output_value_ports(n)) + .filter(|node_port| !wire_ports_outgoing.contains(node_port)) + .collect_vec(); + + let repl = { + let mut repl = repl.try_into_checked().expect("replacement is not DFG"); + let new_inputs = incoming + .iter() + .flatten() // because of singleton-vec wrapping above + .map(|&(n, p)| { + map_boundary(n, p.into()) + .as_outgoing() + .expect("unexpected port direction returned by map_boundary") + .index() + }) + .collect_vec(); + let new_outputs = outgoing + .iter() + .map(|&(n, p)| { + map_boundary(n, p.into()) + .as_incoming() + .expect("unexpected port direction returned by map_boundary") + .index() + }) + .collect_vec(); + repl.map_function_type(&new_inputs, &new_outputs)?; + let subgraph = SiblingSubgraph::try_new(incoming, outgoing, self.as_hugr_view())?; + PersistentReplacement::try_new(subgraph, self.as_hugr_view(), repl.into_hugr())? + }; + + Commit::try_new(repl, additional_parents.into_values(), &self.state_space) } } @@ -250,46 +397,73 @@ impl Walker<'_, R> { &self.selected_commits } + /// Check if a node is pinned in the [`Walker`]. + pub fn is_pinned(&self, node: PatchNode) -> bool { + self.pinned_nodes.contains(&node) + } + + /// Iterate over all pinned nodes in the [`Walker`]. + pub fn pinned_nodes(&self) -> impl Iterator + '_ { + self.pinned_nodes.iter().copied() + } + /// Get all equivalent ports among the commits that are descendants of the /// current commit. /// /// The ports in the returned iterator will be in the same direction as - /// `port`. - fn equivalent_descendant_ports(&self, node: PatchNode, port: Port) -> Vec<(PatchNode, Port)> { + /// `port`. For each equivalent port, also return the set of empty commits + /// that were visited to find it. + fn equivalent_descendant_ports( + &self, + node: PatchNode, + port: Port, + ) -> Vec<(PatchNode, Port, BTreeSet)> { // Now, perform a BFS to find all equivalent ports - let mut all_ports = vec![(node, port)]; + let mut all_ports = vec![(node, port, BTreeSet::new())]; let mut index = 0; while index < all_ports.len() { - let (node, port) = all_ports[index]; + let (node, port, empty_commits) = all_ports[index].clone(); index += 1; for (child_id, (opp_node, opp_port)) in self.state_space.children_at_boundary_port(node, port) { - match opp_port.as_directed() { - Either::Left(in_port) => { - if let Some((n, p)) = self - .state_space - .linked_child_output(opp_node, in_port, child_id) - { - all_ports.push((n, p.into())); - } - } - Either::Right(out_port) => { - all_ports.extend( - self.state_space - .linked_child_inputs(opp_node, out_port, child_id) - .map(|(n, p)| (n, p.into())), - ); + for (node, port) in self.state_space.linked_child_ports( + opp_node, + opp_port, + child_id, + BoundaryMode::SnapToHost, + ) { + let mut empty_commits = empty_commits.clone(); + if node.0 != child_id { + empty_commits.insert(child_id); } + all_ports.push((node, port, empty_commits)); } } } all_ports } +} - pub(crate) fn is_pinned(&self, node: PatchNode) -> bool { - self.pinned_nodes.contains(&node) +#[cfg(test)] +impl Walker<'_, R> { + // Check walker equality by comparing pointers to the state space and + // other fields. Only for testing purposes. + fn component_wise_ptr_eq(&self, other: &Self) -> bool { + std::ptr::eq(self.state_space.as_ref(), other.state_space.as_ref()) + && self.pinned_nodes == other.pinned_nodes + && BTreeSet::from_iter(self.selected_commits.all_commit_ids()) + == BTreeSet::from_iter(other.selected_commits.all_commit_ids()) + } + + /// Check if the Walker cannot be expanded further, i.e. expanding it + /// returns the same Walker. + fn no_more_expansion(&self, wire: &PersistentWire, dir: impl Into>) -> bool { + let Some([new_walker]) = self.expand(wire, dir).collect_array() else { + return false; + }; + new_walker.component_wise_ptr_eq(self) } } @@ -366,12 +540,23 @@ impl From> for Cow<'_, CommitStateSpace> { #[cfg(test)] mod tests { + use std::collections::BTreeSet; + + use hugr_core::{ + Direction, HugrView, IncomingPort, OutgoingPort, + builder::{DFGBuilder, Dataflow, DataflowHugr, endo_sig}, + extension::prelude::bool_t, + std_extensions::logic::LogicOp, + }; + use itertools::Itertools; use rstest::rstest; - use crate::{state_space::CommitId, tests::test_state_space}; - use hugr_core::{IncomingPort, OutgoingPort, std_extensions::logic::LogicOp}; - use super::*; + use crate::{ + PersistentHugr, Walker, + state_space::CommitId, + tests::{persistent_hugr_empty_child, test_state_space}, + }; #[rstest] fn test_walker_base_or_child_expansion(test_state_space: (CommitStateSpace, [CommitId; 4])) { @@ -392,7 +577,8 @@ mod tests { let in0 = walker.get_wire(base_and_node, IncomingPort::from(0)); // a single incoming port (already pinned) => no more expansion - assert!(walker.expand(&in0, Direction::Incoming).next().is_none()); + assert!(walker.no_more_expansion(&in0, Direction::Incoming)); + // commit 2 cannot be applied, because AND is pinned // => only base commit, or commit1 let out_walkers = walker.expand(&in0, Direction::Outgoing).collect_vec(); @@ -401,7 +587,7 @@ mod tests { // new wire is complete (and thus cannot be expanded) let in0 = new_walker.get_wire(base_and_node, IncomingPort::from(0)); assert!(new_walker.is_complete(&in0, None)); - assert!(new_walker.expand(&in0, None).next().is_none()); + assert!(new_walker.no_more_expansion(&in0, None)); // all nodes on wire are pinned let (not_node, _) = in0.single_outgoing_port(new_walker.as_hugr_view()).unwrap(); @@ -458,9 +644,8 @@ mod tests { assert!(walker.is_pinned(not4_node)); let not4_out = walker.get_wire(not4_node, OutgoingPort::from(0)); - let expanded_out = walker.expand(¬4_out, Direction::Outgoing).collect_vec(); // a single outgoing port (already pinned) => no more expansion - assert!(expanded_out.is_empty()); + assert!(walker.no_more_expansion(¬4_out, Direction::Outgoing)); // Three options: // - AND gate from base @@ -485,7 +670,7 @@ mod tests { // new wire is complete (and thus cannot be expanded) let not4_out = new_walker.get_wire(not4_node, OutgoingPort::from(0)); assert!(new_walker.is_complete(¬4_out, None)); - assert!(new_walker.expand(¬4_out, None).next().is_none()); + assert!(new_walker.no_more_expansion(¬4_out, None)); // all nodes on wire are pinned let (next_node, _) = not4_out @@ -546,4 +731,142 @@ mod tests { assert_eq!(new_and_node.0, commit2); assert_eq!(in_port, 1.into()); } + + /// Test that the walker handles empty replacements correctly. + /// + /// The base hugr is a sequence of 3 NOT gates, with a single input/output + /// boolean. A single replacement exists in the state space, which replaces + /// the middle NOT gate with nothing. + #[rstest] + fn test_walk_over_empty_repls( + persistent_hugr_empty_child: (PersistentHugr, [CommitId; 2], [PatchNode; 3]), + ) { + let (hugr, [base_commit, empty_commit], [not0, not1, not2]) = persistent_hugr_empty_child; + let walker = Walker::from_pinned_node(not0, hugr.as_state_space()); + + let not0_outwire = walker.get_wire(not0, OutgoingPort::from(0)); + let expanded_wires = walker + .expand(¬0_outwire, Direction::Incoming) + .collect_vec(); + + assert_eq!(expanded_wires.len(), 2); + + let connected_inports: BTreeSet<_> = expanded_wires + .iter() + .map(|new_walker| { + let wire = new_walker.get_wire(not0, OutgoingPort::from(0)); + wire.all_incoming_ports(new_walker.as_hugr_view()) + .exactly_one() + .ok() + .unwrap() + }) + .collect(); + + assert_eq!( + connected_inports, + BTreeSet::from_iter([(not1, IncomingPort::from(0)), (not2, IncomingPort::from(0))]) + ); + + let traversed_commits: BTreeSet> = expanded_wires + .iter() + .map(|new_walker| { + let wire = new_walker.get_wire(not0, OutgoingPort::from(0)); + wire.owners().collect() + }) + .collect(); + + assert_eq!( + traversed_commits, + BTreeSet::from_iter([ + BTreeSet::from_iter([base_commit]), + BTreeSet::from_iter([base_commit, empty_commit]) + ]) + ); + } + + #[rstest] + fn test_create_commit_over_empty( + persistent_hugr_empty_child: (PersistentHugr, [CommitId; 2], [PatchNode; 3]), + ) { + let (hugr, [base_commit, empty_commit], [not0, _not1, not2]) = persistent_hugr_empty_child; + let mut walker = Walker { + state_space: hugr.as_state_space().into(), + selected_commits: hugr.clone(), + pinned_nodes: BTreeSet::from_iter([not0]), + }; + + // wire: Not0 -> Not2 (bridging over Not1) + let wire = walker.get_wire(not0, OutgoingPort::from(0)); + walker = walker.expand(&wire, None).exactly_one().ok().unwrap(); + let wire = walker.get_wire(not0, OutgoingPort::from(0)); + assert!(walker.is_complete(&wire, None)); + + let empty_hugr = { + let dfg_builder = DFGBuilder::new(endo_sig(bool_t())).unwrap(); + let inputs = dfg_builder.input_wires(); + dfg_builder.finish_hugr_with_outputs(inputs).unwrap() + }; + let commit = walker + .try_create_commit(vec![wire], empty_hugr, |node, port| { + assert_eq!(port.index(), 0); + assert!([not0, not2].contains(&node)); + match port.direction() { + Direction::Incoming => OutgoingPort::from(0).into(), + Direction::Outgoing => IncomingPort::from(0).into(), + } + }) + .unwrap(); + + let mut new_state_space = hugr.as_state_space().to_owned(); + let commit_id = new_state_space.try_add_commit(commit.clone()).unwrap(); + assert_eq!( + new_state_space.parents(commit_id).collect::>(), + BTreeSet::from_iter([base_commit, empty_commit]) + ); + + let res_hugr: PersistentHugr = PersistentHugr::from_commit(commit); + assert!(res_hugr.validate().is_ok()); + + // should be an empty DFG hugr + // module root + function def + func I/O nodes + DFG entrypoint + I/O nodes + assert_eq!(res_hugr.num_nodes(), 1 + 1 + 2 + 1 + 2); + } + + /// Test that the walker handles empty replacements correctly. + /// + /// The base hugr is a sequence of 3 NOT gates, with a single input/output + /// boolean. A single replacement exists in the state space, which replaces + /// the middle NOT gate with nothing. + /// + /// In this test, we pin both the first and third NOT and see if the walker + /// suggests to possible wires as outgoing from the first NOT. This tests + /// the edge case in which a new wire already has all its ports pinned. + #[rstest] + fn test_walk_over_two_pinned_nodes( + persistent_hugr_empty_child: (PersistentHugr, [CommitId; 2], [PatchNode; 3]), + ) { + let (hugr, [base_commit, empty_commit], [not0, _not1, not2]) = persistent_hugr_empty_child; + let mut walker = Walker::from_pinned_node(not0, hugr.as_state_space()); + assert!(walker.try_pin_node(not2).unwrap()); + + let not0_outwire = walker.get_wire(not0, OutgoingPort::from(0)); + let expanded_walkers = walker.expand(¬0_outwire, Direction::Incoming); + + let expanded_wires: BTreeSet> = expanded_walkers + .map(|new_walker| { + new_walker + .get_wire(not0, OutgoingPort::from(0)) + .owners() + .collect() + }) + .collect(); + + assert_eq!( + expanded_wires, + BTreeSet::from_iter([ + BTreeSet::from_iter([base_commit]), + BTreeSet::from_iter([base_commit, empty_commit]) + ]) + ); + } } diff --git a/hugr-persistent/src/wire.rs b/hugr-persistent/src/wire.rs index a0ffb3e142..a84d4e6923 100644 --- a/hugr-persistent/src/wire.rs +++ b/hugr-persistent/src/wire.rs @@ -1,6 +1,9 @@ use std::collections::{BTreeSet, VecDeque}; -use hugr_core::{Direction, HugrView, IncomingPort, OutgoingPort, Port, Wire}; +use hugr_core::{ + Direction, HugrView, IncomingPort, OutgoingPort, Port, Wire, + hugr::patch::simple_replace::BoundaryMode, +}; use itertools::Itertools; use crate::{CommitId, PatchNode, PersistentHugr, Resolver, Walker}; @@ -48,6 +51,12 @@ impl CommitWire { fn commit_id(&self) -> CommitId { self.0.node().0 } + + delegate::delegate! { + to self.0 { + fn node(&self) -> PatchNode; + } + } } /// A node in a commit of a [`PersistentHugr`] is either a valid node of the @@ -111,10 +120,15 @@ impl PersistentWire { // ports in the child commit that deleted the node. for (opp_node, opp_port) in commit_hugr.linked_ports(node, port) { let opp_node = per_hugr.to_persistent_node(opp_node, commit_id); - for (child_node, child_port) in per_hugr - .as_state_space() - .linked_child_ports(opp_node, opp_port, deleted_by) + for (child_node, child_port) in + per_hugr.as_state_space().linked_child_ports( + opp_node, + opp_port, + deleted_by, + BoundaryMode::IncludeIO, + ) { + debug_assert_eq!(child_node.owner(), deleted_by); let w = CommitWire::from_connected_port( child_node, child_port, per_hugr, ); @@ -167,6 +181,11 @@ impl PersistentWire { all_ports_impl(self.wires.iter().copied(), dir.into(), hugr) } + /// All commit IDs that the wire traverses. + pub fn owners(&self) -> impl Iterator { + self.wires.iter().map(|w| w.node().owner()).unique() + } + /// Consume the wire and return all ports attached to a wire in `hugr`. /// /// All ports returned are on nodes that are contained in `hugr`. diff --git a/hugr-persistent/tests/persistent_walker_example.rs b/hugr-persistent/tests/persistent_walker_example.rs index faf637df57..dfed4c143a 100644 --- a/hugr-persistent/tests/persistent_walker_example.rs +++ b/hugr-persistent/tests/persistent_walker_example.rs @@ -2,24 +2,24 @@ use std::collections::{BTreeSet, VecDeque}; -use itertools::Itertools; +use itertools::{Either, Itertools}; use hugr_core::{ - Hugr, HugrView, PortIndex, SimpleReplacement, + Hugr, HugrView, IncomingPort, OutgoingPort, Port, PortIndex, builder::{DFGBuilder, Dataflow, DataflowHugr, endo_sig}, extension::prelude::qb_t, - hugr::views::SiblingSubgraph, + ops::OpType, types::EdgeKind, }; -use hugr_persistent::{CommitStateSpace, PersistentReplacement, PersistentWire, Walker}; +use hugr_persistent::{Commit, CommitStateSpace, PersistentWire, Walker}; /// The maximum commit depth that we will consider in this example -const MAX_COMMITS: usize = 2; +const MAX_COMMITS: usize = 4; // We define a HUGR extension within this file, with CZ and H gates. Normally, // you would use an existing extension (e.g. as provided by tket2). -use walker_example_extension::{cz_gate, h_gate}; +use walker_example_extension::cz_gate; mod walker_example_extension { use std::sync::Arc; @@ -33,10 +33,6 @@ mod walker_example_extension { use super::*; - fn one_qb_func() -> PolyFuncTypeRV { - FuncValueType::new_endo(qb_t()).into() - } - fn two_qb_func() -> PolyFuncTypeRV { FuncValueType::new_endo(vec![qb_t(), qb_t()]).into() } @@ -48,15 +44,6 @@ mod walker_example_extension { EXTENSION_ID, Version::new(0, 0, 0), |extension, extension_ref| { - extension - .add_op( - OpName::new_inline("H"), - "Hadamard".into(), - one_qb_func(), - extension_ref, - ) - .unwrap(); - extension .add_op( OpName::new_inline("CZ"), @@ -74,10 +61,6 @@ mod walker_example_extension { static ref EXTENSION: Arc = extension(); } - pub fn h_gate() -> ExtensionOp { - EXTENSION.instantiate_extension_op("H", []).unwrap() - } - pub fn cz_gate() -> ExtensionOp { EXTENSION.instantiate_extension_op("CZ", []).unwrap() } @@ -108,15 +91,12 @@ fn dfg_hugr() -> Hugr { builder.finish_hugr_with_outputs(vec![q0, q1, q2]).unwrap() } -// TODO: currently empty replacements are buggy, so we have temporarily added -// a single Hadamard gate on each qubit. -fn empty_2qb_hugr() -> Hugr { - let mut builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()])).unwrap(); - let [q0, q1] = builder.input_wires_arr(); - let h0 = builder.add_dataflow_op(h_gate(), vec![q0]).unwrap(); - let [q0] = h0.outputs_arr(); - let h1 = builder.add_dataflow_op(h_gate(), vec![q1]).unwrap(); - let [q1] = h1.outputs_arr(); +fn empty_2qb_hugr(flip_args: bool) -> Hugr { + let builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()])).unwrap(); + let [mut q0, mut q1] = builder.input_wires_arr(); + if flip_args { + (q0, q1) = (q1, q0); + } builder.finish_hugr_with_outputs(vec![q0, q1]).unwrap() } @@ -182,7 +162,7 @@ fn build_state_space() -> CommitStateSpace { for subwalker in walker.expand(&wire, None) { assert!( subwalker.as_hugr_view().contains_node(pinned_node), - "pinned node is deleted" + "pinned node {pinned_node:?} is deleted", ); wire_queue.push_back((subwalker.get_wire(pinned_node, pinned_port), subwalker)); } @@ -206,22 +186,16 @@ fn build_state_space() -> CommitStateSpace { continue; } - let Some(repl) = create_replacement(wire, &walker) else { + let Some(new_commit) = create_commit(wire, &walker) else { continue; }; assert_eq!( - repl.subgraph() - .nodes() - .iter() - .copied() - .collect::>(), + new_commit.deleted_nodes().collect::>(), patch_nodes ); - state_space - .try_add_replacement(repl) - .expect("repl acts on non-empty subgraph"); + state_space.try_add_commit(new_commit).unwrap(); // enqueue new wires added by the replacement // (this will also add a lot of already visited wires, but they will @@ -233,7 +207,7 @@ fn build_state_space() -> CommitStateSpace { state_space } -fn create_replacement(wire: PersistentWire, walker: &Walker) -> Option { +fn create_commit(wire: PersistentWire, walker: &Walker) -> Option { let hugr = walker.clone().into_persistent_hugr(); let (out_node, _) = wire .single_outgoing_port(&hugr) @@ -258,14 +232,27 @@ fn create_replacement(wire: PersistentWire, walker: &Walker) -> Option { // out_node and in_node act on the same qubits - // => cancel out the two CZ gates - ( - empty_2qb_hugr(), - SiblingSubgraph::try_from_nodes([out_node, in_node], &hugr).ok()?, - ) + // => replace the two CZ gates with the empty 2qb HUGR + + // If the two CZ gates have flipped port ordering, we need to insert + // a swap gate + let add_swap = all_edges[0][0].index() != all_edges[0][1].index(); + + // Get the wires between the two CZ gates + let wires = all_edges + .into_iter() + .map(|[out_port, _]| walker.get_wire(out_node, out_port)); + + // Create the commit + walker.try_create_commit(wires, empty_2qb_hugr(add_swap), |_, port| { + // the incoming/outgoing ports of the subgraph map trivially to the empty 2qb + // HUGR + let dir = port.direction(); + Port::new(dir.reverse(), port.index()) + }) } 1 => { // out_node and in_node share just one qubit @@ -275,32 +262,49 @@ fn create_replacement(wire: PersistentWire, walker: &Walker) -> Option establish which qubit is shared between the two CZ gates let [out_port, in_port] = all_edges.into_iter().exactly_one().unwrap(); - let shared_qb_on_out_node = out_port.index(); - let shared_qb_on_in_node = in_port.index(); - - let subgraph = SiblingSubgraph::try_new( - vec![ - vec![(out_node, shared_qb_on_out_node.into())], - vec![(out_node, (1 - shared_qb_on_out_node).into())], - vec![(in_node, (1 - shared_qb_on_in_node).into())], - ], - vec![ - (in_node, shared_qb_on_in_node.into()), - (out_node, (1 - shared_qb_on_out_node).into()), - (in_node, (1 - shared_qb_on_in_node).into()), - ], - &hugr, - ) - .ok()?; - - (repl_hugr, subgraph) + let shared_qb_out = out_port.index(); + let shared_qb_in = in_port.index(); + + walker.try_create_commit([wire], repl_hugr, |node, port| { + // map the incoming/outgoing ports of the subgraph to the replacement as + // follows: + // - the first qubit is the one that is shared between the two CZ gates + // - the second qubit only touches the first CZ (out_node) + // - the third qubit only touches the second CZ (in_node) + match port.as_directed() { + Either::Left(incoming) => { + let in_boundary: [(_, IncomingPort); 3] = [ + (out_node, shared_qb_out.into()), + (out_node, (1 - shared_qb_out).into()), + (in_node, (1 - shared_qb_in).into()), + ]; + let out_index = in_boundary + .iter() + .position(|&(n, p)| n == node && p == incoming) + .expect("invalid input port"); + OutgoingPort::from(out_index).into() + } + Either::Right(outgoing) => { + let out_boundary: [(_, OutgoingPort); 3] = [ + (in_node, shared_qb_in.into()), + (out_node, (1 - shared_qb_out).into()), + (in_node, (1 - shared_qb_in).into()), + ]; + let in_index = out_boundary + .iter() + .position(|&(n, p)| n == node && p == outgoing) + .expect("invalid output port"); + IncomingPort::from(in_index).into() + } + } + }) } _ => unreachable!(), - }; - - SimpleReplacement::try_new(subgraph, &hugr, repl_hugr).ok() + } + .ok() } +#[ignore = "takes 10s (todo: optimise)"] #[test] fn walker_example() { let state_space = build_state_space(); @@ -326,26 +330,16 @@ fn walker_example() { ); } - // assert_eq!(state_space.all_commit_ids().count(), 13); - let empty_commits = state_space .all_commit_ids() - // .filter(|&id| state_space.commit_hugr(id).num_nodes() == 3) - .filter(|&id| { - state_space - .inserted_nodes(id) - .filter(|&n| state_space.get_optype(n) == &h_gate().into()) - .count() - == 2 - }) + .filter(|&id| state_space.inserted_nodes(id).count() == 0) .collect_vec(); // there should be a combination of three empty commits that are compatible // and such that the resulting HUGR is empty let mut empty_hugr = None; - // for cs in empty_commits.iter().combinations(3) { - for cs in empty_commits.iter().combinations(2) { - let cs = cs.into_iter().copied().collect_vec(); + for cs in empty_commits.iter().combinations(3) { + let cs = cs.into_iter().copied(); if let Ok(hugr) = state_space.try_extract_hugr(cs) { empty_hugr = Some(hugr); } @@ -353,16 +347,23 @@ fn walker_example() { let empty_hugr = empty_hugr.unwrap().to_hugr(); - // assert_eq!(empty_hugr.num_nodes(), 3); - - let n_cz = empty_hugr - .nodes() - .filter(|&n| empty_hugr.get_optype(n) == &cz_gate().into()) - .count(); - let n_h = empty_hugr - .nodes() - .filter(|&n| empty_hugr.get_optype(n) == &h_gate().into()) - .count(); - assert_eq!(n_cz, 2); - assert_eq!(n_h, 4); + // The empty hugr should have 7 nodes: + // module root, funcdef, 2 func IO, DFG root, 2 DFG IO + assert_eq!(empty_hugr.num_nodes(), 7); + assert_eq!( + empty_hugr + .nodes() + .filter(|&n| { + !matches!( + empty_hugr.get_optype(n), + OpType::Input(_) + | OpType::Output(_) + | OpType::FuncDefn(_) + | OpType::Module(_) + | OpType::DFG(_) + ) + }) + .count(), + 0 + ); }