diff --git a/hugr-core/src/core.rs b/hugr-core/src/core.rs index 06366822ae..300210223a 100644 --- a/hugr-core/src/core.rs +++ b/hugr-core/src/core.rs @@ -7,7 +7,7 @@ pub use itertools::Either; use derive_more::From; use itertools::Either::{Left, Right}; -use crate::hugr::HugrError; +use crate::{HugrView, hugr::HugrError}; /// A handle to a node in the HUGR. #[derive( @@ -219,17 +219,55 @@ impl Wire { Self(node, port.into()) } - /// The node that this wire is connected to. + /// Create a new wire from a node and a port that is connected to the wire. + /// + /// If `port` is an incoming port, the wire is traversed to find the unique + /// outgoing port that is connected to the wire. Otherwise, this is + /// equivalent to constructing a wire using [`Wire::new`]. + /// + /// ## Panics + /// + /// This will panic if the wire is not connected to a unique outgoing port. + #[inline] + pub fn from_connected_port( + node: N, + port: impl Into, + hugr: &impl HugrView, + ) -> Self { + let (node, outgoing) = match port.into().as_directed() { + Either::Left(incoming) => hugr + .single_linked_output(node, incoming) + .expect("invalid dfg port"), + Either::Right(outgoing) => (node, outgoing), + }; + Self::new(node, outgoing) + } + + /// The node of the unique outgoing port that the wire is connected to. #[inline] pub fn node(&self) -> N { self.0 } - /// The output port that this wire is connected to. + /// The unique outgoing port that the wire is connected to. #[inline] pub fn source(&self) -> OutgoingPort { self.1 } + + /// Get all ports connected to the wire. + /// + /// Return a chained iterator of the unique outgoing port, followed by all + /// incoming ports connected to the wire. + pub fn all_connected_ports<'h, H: HugrView>( + &self, + hugr: &'h H, + ) -> impl Iterator + use<'h, N, H> { + let node = self.node(); + let out_port = self.source(); + + std::iter::once((node, out_port.into())).chain(hugr.linked_ports(node, out_port)) + } } impl std::fmt::Display for Wire { diff --git a/hugr-core/src/hugr/views/root_checked/dfg.rs b/hugr-core/src/hugr/views/root_checked/dfg.rs index 1b485dcc7c..684f3f0037 100644 --- a/hugr-core/src/hugr/views/root_checked/dfg.rs +++ b/hugr-core/src/hugr/views/root_checked/dfg.rs @@ -28,7 +28,7 @@ macro_rules! impl_dataflow_parent_methods { .expect("valid DFG graph") } - /// Rewire the inputs and outputs of the DFG to modify its signature. + /// Rewire the inputs and outputs of the nested DFG to modify its signature. /// /// Reorder the outgoing resp. incoming wires at the input resp. output /// node of the DFG to modify the signature of the DFG HUGR. This will diff --git a/hugr-persistent/src/lib.rs b/hugr-persistent/src/lib.rs index 1a320c88fd..a91ddc97d2 100644 --- a/hugr-persistent/src/lib.rs +++ b/hugr-persistent/src/lib.rs @@ -72,11 +72,13 @@ mod resolver; pub mod state_space; mod trait_impls; pub mod walker; +mod wire; pub use persistent_hugr::{Commit, PersistentHugr}; pub use resolver::{PointerEqResolver, Resolver, SerdeHashResolver}; pub use state_space::{CommitId, CommitStateSpace, InvalidCommit, PatchNode}; -pub use walker::{PinnedWire, Walker}; +pub use walker::Walker; +pub use wire::PersistentWire; /// A replacement operation that can be applied to a [`PersistentHugr`]. pub type PersistentReplacement = hugr_core::SimpleReplacement; diff --git a/hugr-persistent/src/persistent_hugr.rs b/hugr-persistent/src/persistent_hugr.rs index 789bcba1c7..452cbff3df 100644 --- a/hugr-persistent/src/persistent_hugr.rs +++ b/hugr-persistent/src/persistent_hugr.rs @@ -1,12 +1,12 @@ use std::{ - collections::{BTreeSet, HashMap, VecDeque}, + collections::{BTreeSet, HashMap}, mem, vec, }; use delegate::delegate; use derive_more::derive::From; use hugr_core::{ - Hugr, HugrView, IncomingPort, Node, OutgoingPort, Port, SimpleReplacement, + Direction, Hugr, HugrView, IncomingPort, Node, OutgoingPort, Port, SimpleReplacement, hugr::patch::{Patch, simple_replace}, }; use itertools::{Either, Itertools}; @@ -394,68 +394,14 @@ impl PersistentHugr { /// /// Panics if `node` is not in `self` (in particular if it is deleted) or if /// `port` is not a value port in `node`. - pub(crate) fn get_single_outgoing_port( + pub(crate) fn single_outgoing_port( &self, node: PatchNode, port: impl Into, ) -> (PatchNode, OutgoingPort) { - let mut in_port = port.into(); - let PatchNode(commit_id, mut in_node) = node; - - assert!(self.is_value_port(node, in_port), "not a dataflow wire"); - assert!(self.contains_node(node), "node not in self"); - - let hugr = self.commit_hugr(commit_id); - let (mut out_node, mut out_port) = hugr - .single_linked_output(in_node, in_port) - .map(|(n, p)| (PatchNode(commit_id, n), p)) - .expect("invalid HUGR"); - - // invariant: (out_node, out_port) -> (in_node, in_port) is a boundary - // edge, i.e. it never is the case that both are deleted by the same - // child commit - loop { - let commit_id = out_node.0; - - if let Some(deleted_by) = self.find_deleting_commit(out_node) { - (out_node, out_port) = self - .state_space - .linked_child_output(PatchNode(commit_id, in_node), in_port, deleted_by) - .expect("valid boundary edge"); - // update (in_node, in_port) - (in_node, in_port) = { - let new_commit_id = out_node.0; - let hugr = self.commit_hugr(new_commit_id); - hugr.linked_inputs(out_node.1, out_port) - .find(|&(n, _)| { - self.find_deleting_commit(PatchNode(commit_id, n)).is_none() - }) - .expect("out_node is connected to output node (which is never deleted)") - }; - } else if self - .replacement(commit_id) - .is_some_and(|repl| repl.get_replacement_io()[0] == out_node.1) - { - // out_node is an input node - (out_node, out_port) = self - .as_state_space() - .linked_parent_input(PatchNode(commit_id, in_node), in_port); - // update (in_node, in_port) - (in_node, in_port) = { - let new_commit_id = out_node.0; - let hugr = self.commit_hugr(new_commit_id); - hugr.linked_inputs(out_node.1, out_port) - .find(|&(n, _)| { - self.find_deleting_commit(PatchNode(new_commit_id, n)) - == Some(commit_id) - }) - .expect("boundary edge must connect out_node to deleted node") - }; - } else { - // valid outgoing node! - return (out_node, out_port); - } - } + let w = self.get_wire(node, port.into()); + w.single_outgoing_port(self) + .expect("found invalid dfg wire") } /// All incoming ports that the given outgoing port is attached to. @@ -464,99 +410,14 @@ impl PersistentHugr { /// /// Panics if `out_node` is not in `self` (in particular if it is deleted) /// or if `out_port` is not a value port in `out_node`. - pub(crate) fn get_all_incoming_ports( + pub(crate) fn all_incoming_ports( &self, out_node: PatchNode, out_port: OutgoingPort, ) -> impl Iterator { - assert!( - self.is_value_port(out_node, out_port), - "not a dataflow wire" - ); - assert!(self.contains_node(out_node), "node not in self"); - - let mut visited = BTreeSet::new(); - // enqueue the outport and initialise the set of valid incoming ports - // to the valid incoming ports in this commit - let mut queue = VecDeque::from([(out_node, out_port)]); - let mut valid_incoming_ports = BTreeSet::from_iter( - self.commit_hugr(out_node.0) - .linked_inputs(out_node.1, out_port) - .map(|(in_node, in_port)| (PatchNode(out_node.0, in_node), in_port)) - .filter(|(in_node, _)| self.contains_node(*in_node)), - ); - - // A simple BFS across the commit history to find all equivalent incoming ports. - while let Some((out_node, out_port)) = queue.pop_front() { - if !visited.insert((out_node, out_port)) { - continue; - } - let commit_id = out_node.0; - let hugr = self.commit_hugr(commit_id); - let out_deleted_by = self.find_deleting_commit(out_node); - let curr_repl_out = { - let repl = self.replacement(commit_id); - repl.map(|r| r.get_replacement_io()[1]) - }; - // incoming ports are of interest to us if - // (i) they are connected to the output of a replacement (then there will be a - // linked port in a parent commit), or - // (ii) they are deleted by a child commit and are not equal to the out_node - // (then there will be a linked port in a child commit) - let is_linked_to_output = curr_repl_out.is_some_and(|curr_repl_out| { - hugr.linked_inputs(out_node.1, out_port) - .any(|(in_node, _)| in_node == curr_repl_out) - }); - - let deleted_by_child: BTreeSet<_> = hugr - .linked_inputs(out_node.1, out_port) - .filter(|(in_node, _)| Some(in_node) != curr_repl_out.as_ref()) - .filter_map(|(in_node, _)| { - self.find_deleting_commit(PatchNode(commit_id, in_node)) - .filter(|other_deleted_by| - // (out_node, out_port) -> (in_node, in_port) is a boundary edge - // into the child commit `other_deleted_by` - (Some(other_deleted_by) != out_deleted_by.as_ref())) - }) - .collect(); - - // Convert an incoming port to the unique outgoing port that it is linked to - let to_outgoing_port = |(PatchNode(commit_id, in_node), in_port)| { - let hugr = self.commit_hugr(commit_id); - let (out_node, out_port) = hugr - .single_linked_output(in_node, in_port) - .expect("valid dfg wire"); - (PatchNode(commit_id, out_node), out_port) - }; - - if is_linked_to_output { - // Traverse boundary to parent(s) - let new_ins = self - .as_state_space() - .linked_parent_outputs(out_node, out_port); - for (in_node, in_port) in new_ins { - if self.contains_node(in_node) { - valid_incoming_ports.insert((in_node, in_port)); - } - queue.push_back(to_outgoing_port((in_node, in_port))); - } - } - - for child in deleted_by_child { - // Traverse boundary to `child` - let new_ins = self - .as_state_space() - .linked_child_inputs(out_node, out_port, child); - for (in_node, in_port) in new_ins { - if self.contains_node(in_node) { - valid_incoming_ports.insert((in_node, in_port)); - } - queue.push_back(to_outgoing_port((in_node, in_port))); - } - } - } - - valid_incoming_ports.into_iter() + let w = self.get_wire(out_node, out_port); + w.into_all_ports(self, Direction::Incoming) + .map(|(node, port)| (node, port.as_incoming().unwrap())) } delegate! { @@ -578,7 +439,7 @@ impl PersistentHugr { /// All nodes will be PatchNodes with commit ID `commit_id`. pub fn inserted_nodes(&self, commit_id: CommitId) -> impl Iterator + '_; /// Get the replacement for `commit_id`. - fn replacement(&self, commit_id: CommitId) -> Option<&SimpleReplacement>; + pub(crate) fn replacement(&self, commit_id: CommitId) -> Option<&SimpleReplacement>; /// Get the Hugr inserted by `commit_id`. /// /// This is either the replacement Hugr of a [`CommitData::Replacement`] or @@ -628,7 +489,11 @@ impl PersistentHugr { .unique() } - fn find_deleting_commit(&self, node @ PatchNode(commit_id, _): PatchNode) -> Option { + /// Get the child commit that deletes `node`. + pub(crate) fn find_deleting_commit( + &self, + node @ PatchNode(commit_id, _): PatchNode, + ) -> Option { let mut children = self.state_space.children(commit_id); children.find(move |&child_id| { let child = self.get_commit(child_id); @@ -636,6 +501,12 @@ impl PersistentHugr { }) } + /// Convert a node ID specific to a commit HUGR into a patch node in the + /// [`PersistentHugr`]. + pub(crate) fn to_persistent_node(&self, node: Node, commit_id: CommitId) -> PatchNode { + PatchNode(commit_id, node) + } + /// Check if a patch node is in the PersistentHugr, that is, it belongs to /// a commit in the state space and is not deleted by any child commit. pub fn contains_node(&self, PatchNode(commit_id, node): PatchNode) -> bool { diff --git a/hugr-persistent/src/state_space.rs b/hugr-persistent/src/state_space.rs index 23bf1b1761..cdf5e734e8 100644 --- a/hugr-persistent/src/state_space.rs +++ b/hugr-persistent/src/state_space.rs @@ -9,7 +9,7 @@ use hugr_core::{ hugr::{self, internal::HugrInternals, patch::BoundaryPort}, ops::OpType, }; -use itertools::Itertools; +use itertools::{Either, Itertools}; use relrc::{HistoryGraph, RelRc}; use thiserror::Error; @@ -42,7 +42,8 @@ mod hidden { /// other commits apply), or a [`PersistentReplacement`] /// /// This is a "unnamable" type: we do not expose this struct publicly in our - /// API, but we can still use it in public trait bounds (see [`Resolver`](crate::resolver::Resolver)). + /// API, but we can still use it in public trait bounds (see + /// [`Resolver`](crate::resolver::Resolver)). #[derive(Debug, Clone, From)] pub enum CommitData { Base(Hugr), @@ -307,10 +308,12 @@ impl CommitStateSpace { /// Get the boundary inputs linked to `(node, port)` in `child`. /// + /// `child` should be a child commit of the owner of `node`. + /// /// ## Panics /// - /// Panics if `(node, port)` is not a boundary edge, or if `child` is not - /// a valid commit ID. + /// Panics if `(node, port)` is not a boundary edge, if `child` is not + /// a valid commit ID or if it is the base commit. pub(crate) fn linked_child_inputs( &self, node: PatchNode, @@ -335,6 +338,9 @@ impl CommitStateSpace { /// Get the single boundary output linked to `(node, port)` in `child`. /// + /// `child` should be a child commit of the owner of `node` (or `None` will + /// be returned). + /// /// ## Panics /// /// Panics if `child` is not a valid commit ID. @@ -345,7 +351,7 @@ impl CommitStateSpace { child: CommitId, ) -> Option<(PatchNode, OutgoingPort)> { let parent_hugrs = ParentsView::from_commit(child, self); - let repl = self.replacement(child).expect("valid child commit"); + let repl = self.replacement(child)?; match repl.linked_replacement_output((node, port), &parent_hugrs)? { BoundaryPort::Host(patch_node, port) => (patch_node, port), BoundaryPort::Replacement(node, port) => (PatchNode(child, node), port), @@ -353,6 +359,31 @@ impl CommitStateSpace { .into() } + /// Get the boundary ports linked to `(node, port)` in `child`. + /// + /// `child` should be a child commit of the owner of `node`. + /// + /// See [`Self::linked_child_inputs`] and [`Self::linked_child_output`] for + /// more details. + pub(crate) fn linked_child_ports( + &self, + node: PatchNode, + port: impl Into, + child: CommitId, + ) -> impl Iterator + '_ { + match port.into().as_directed() { + Either::Left(incoming) => Either::Left( + self.linked_child_output(node, incoming, child) + .into_iter() + .map(|(node, port)| (node, port.into())), + ), + Either::Right(outgoing) => Either::Right( + self.linked_child_inputs(node, outgoing, child) + .map(|(node, port)| (node, port.into())), + ), + } + } + /// Get the single output boundary port linked to `(node, port)` in a /// parent of the commit of `node`. /// @@ -399,6 +430,33 @@ impl CommitStateSpace { .into_iter() } + /// Get the ports linked to `(node, port)` in a parent of the commit of + /// `node`. + /// + /// See [`Self::linked_parent_input`] and [`Self::linked_parent_outputs`] + /// for more details. + /// + /// ## Panics + /// + /// Panics if `(node, port)` is not connected to an IO node in the commit + /// of `node`, or if the node is not valid. + pub fn linked_parent_ports( + &self, + node: PatchNode, + port: impl Into, + ) -> impl Iterator + '_ { + match port.into().as_directed() { + Either::Left(incoming) => { + let (node, port) = self.linked_parent_input(node, incoming); + Either::Left(std::iter::once((node, port.into()))) + } + Either::Right(outgoing) => Either::Right( + self.linked_parent_outputs(node, outgoing) + .map(|(node, port)| (node, port.into())), + ), + } + } + /// Get the replacement for `commit_id`. pub(crate) fn replacement(&self, commit_id: CommitId) -> Option<&SimpleReplacement> { let commit = self.get_commit(commit_id); diff --git a/hugr-persistent/src/trait_impls.rs b/hugr-persistent/src/trait_impls.rs index f4d4fac759..1d3331f269 100644 --- a/hugr-persistent/src/trait_impls.rs +++ b/hugr-persistent/src/trait_impls.rs @@ -175,11 +175,11 @@ impl HugrView for PersistentHugr { } else { match port.as_directed() { Either::Left(incoming) => { - let (out_node, out_port) = self.get_single_outgoing_port(node, incoming); + let (out_node, out_port) = self.single_outgoing_port(node, incoming); ret_ports.push((out_node, out_port.into())) } Either::Right(outgoing) => ret_ports.extend( - self.get_all_incoming_ports(node, outgoing) + self.all_incoming_ports(node, outgoing) .map(|(node, port)| (node, port.into())), ), } diff --git a/hugr-persistent/src/walker.rs b/hugr-persistent/src/walker.rs index 4365318ba5..f346921493 100644 --- a/hugr-persistent/src/walker.rs +++ b/hugr-persistent/src/walker.rs @@ -55,9 +55,6 @@ //! versions of the graph simultaneously, without having to materialize //! each version separately. -mod pinned; -pub use pinned::PinnedWire; - use std::{borrow::Cow, collections::BTreeSet}; use itertools::{Either, Itertools}; @@ -65,7 +62,7 @@ use thiserror::Error; use hugr_core::{Direction, HugrView, Port}; -use crate::{PointerEqResolver, resolver::Resolver}; +use crate::{PersistentWire, PointerEqResolver, resolver::Resolver}; use super::{CommitStateSpace, InvalidCommit, PatchNode, PersistentHugr, state_space::CommitId}; @@ -197,13 +194,13 @@ impl<'a, R: Resolver> Walker<'a, R> { /// true, then an empty iterator is returned. pub fn expand<'b>( &'b self, - wire: &'b PinnedWire, + wire: &'b PersistentWire, dir: impl Into>, ) -> impl Iterator> + 'b { let dir = dir.into(); // Find unpinned ports on the wire (satisfying the direction constraint) - let unpinned_ports = wire.unpinned_ports(dir); + let unpinned_ports = self.wire_unpinned_ports(wire, dir); // Obtain set of pinnable nodes by considering all ports (in descendant // commits) equivalent to currently unpinned ports. @@ -231,8 +228,9 @@ impl Walker<'_, R> { /// /// # Panics /// Panics if `node` is not already pinned in this Walker. - pub fn get_wire(&self, node: PatchNode, port: impl Into) -> PinnedWire { - PinnedWire::from_pinned_port(node, port, self) + pub fn get_wire(&self, node: PatchNode, port: impl Into) -> PersistentWire { + assert!(self.is_pinned(node), "node must be pinned"); + self.selected_commits.get_wire(node, port) } /// Materialise the [`PersistentHugr`] containing all the compatible commits @@ -290,7 +288,7 @@ impl Walker<'_, R> { all_ports } - fn is_pinned(&self, node: PatchNode) -> bool { + pub(crate) fn is_pinned(&self, node: PatchNode) -> bool { self.pinned_nodes.contains(&node) } } @@ -402,11 +400,11 @@ mod tests { for new_walker in out_walkers { // new wire is complete (and thus cannot be expanded) let in0 = new_walker.get_wire(base_and_node, IncomingPort::from(0)); - assert!(in0.is_complete(None)); + assert!(new_walker.is_complete(&in0, None)); assert!(new_walker.expand(&in0, None).next().is_none()); // all nodes on wire are pinned - let (not_node, _) = in0.pinned_outport().unwrap(); + let (not_node, _) = in0.single_outgoing_port(new_walker.as_hugr_view()).unwrap(); assert!(new_walker.is_pinned(base_and_node)); assert!(new_walker.is_pinned(not_node)); @@ -486,11 +484,15 @@ mod tests { // new wire is complete (and thus cannot be expanded) let not4_out = new_walker.get_wire(not4_node, OutgoingPort::from(0)); - assert!(not4_out.is_complete(None)); + assert!(new_walker.is_complete(¬4_out, None)); assert!(new_walker.expand(¬4_out, None).next().is_none()); // all nodes on wire are pinned - let (next_node, _) = not4_out.pinned_inports().exactly_one().ok().unwrap(); + let (next_node, _) = not4_out + .all_incoming_ports(new_walker.as_hugr_view()) + .exactly_one() + .ok() + .unwrap(); assert!(new_walker.is_pinned(not4_node)); assert!(new_walker.is_pinned(next_node)); @@ -529,7 +531,7 @@ mod tests { let hugr = state_space.try_extract_hugr([commit4]).unwrap(); let (second_not_node, out_port) = - hugr.get_single_outgoing_port(base_and_node, IncomingPort::from(1)); + hugr.single_outgoing_port(base_and_node, IncomingPort::from(1)); assert_eq!(second_not_node.0, commit4); assert_eq!(out_port, OutgoingPort::from(0)); @@ -537,7 +539,7 @@ mod tests { .try_extract_hugr([commit1, commit2, commit4]) .unwrap(); let (new_and_node, in_port) = hugr - .get_all_incoming_ports(second_not_node, out_port) + .all_incoming_ports(second_not_node, out_port) .exactly_one() .ok() .unwrap(); diff --git a/hugr-persistent/src/walker/pinned.rs b/hugr-persistent/src/walker/pinned.rs deleted file mode 100644 index d2111cc2be..0000000000 --- a/hugr-persistent/src/walker/pinned.rs +++ /dev/null @@ -1,169 +0,0 @@ -//! Utilities for pinned ports and pinned wires. -//! -//! Encapsulation: we only ever expose pinned values publicly. - -use itertools::Either; - -use crate::PatchNode; -use hugr_core::{Direction, IncomingPort, OutgoingPort, Port}; - -use super::Walker; - -/// A wire in the current HUGR of a [`Walker`] with some of its endpoints -/// pinned. -/// -/// Just like a normal HUGR [`Wire`](hugr_core::Wire), a [`PinnedWire`] has -/// endpoints: the ports that are linked together by the wire. A [`PinnedWire`] -/// however distinguishes itself in that each of its ports is specified either -/// as "pinned" or "unpinned". A port is pinned if and only if the node it is -/// attached to is pinned in the walker. -/// -/// A [`PinnedWire`] always has at least one pinned port. -/// -/// All pinned ports of a [`PinnedWire`] can be retrieved using -/// [`PinnedWire::pinned_inports`] and [`PinnedWire::pinned_outport`]. Unpinned -/// ports, on the other hand, represent undetermined connections, which may -/// still change as the walker is expanded (see [`Walker::expand`]). -/// -/// Whether all incoming or outgoing ports are pinned can be checked using -/// [`PinnedWire::is_complete`]. -#[derive(Debug, Clone)] -pub struct PinnedWire { - outgoing: MaybePinned, - incoming: Vec>, -} - -/// A private enum to track whether a port is pinned. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -enum MaybePinned

{ - Pinned(PatchNode, P), - Unpinned(PatchNode, P), -} - -impl

MaybePinned

{ - fn new(node: PatchNode, port: P, walker: &Walker) -> Self { - debug_assert!( - walker.selected_commits.contains_node(node), - "pinned node not in walker" - ); - if walker.is_pinned(node) { - MaybePinned::Pinned(node, port) - } else { - MaybePinned::Unpinned(node, port) - } - } - - fn is_pinned(&self) -> bool { - matches!(self, MaybePinned::Pinned(_, _)) - } - - fn into_unpinned>(self) -> Option<(PatchNode, PP)> { - match self { - MaybePinned::Pinned(_, _) => None, - MaybePinned::Unpinned(node, port) => Some((node, port.into())), - } - } - - fn into_pinned>(self) -> Option<(PatchNode, PP)> { - match self { - MaybePinned::Pinned(node, port) => Some((node, port.into())), - MaybePinned::Unpinned(_, _) => None, - } - } -} - -impl PinnedWire { - /// Create a new pinned wire in `walker` from a pinned node and a port. - /// - /// # Panics - /// Panics if `node` is not pinned in `walker`. - pub fn from_pinned_port( - node: PatchNode, - port: impl Into, - walker: &Walker, - ) -> Self { - assert!(walker.is_pinned(node), "node must be pinned"); - - let (outgoing_node, outgoing_port) = match port.into().as_directed() { - Either::Left(incoming) => walker - .selected_commits - .get_single_outgoing_port(node, incoming), - Either::Right(outgoing) => (node, outgoing), - }; - - let outgoing = MaybePinned::new(outgoing_node, outgoing_port, walker); - - let incoming = walker - .selected_commits - .get_all_incoming_ports(outgoing_node, outgoing_port) - .map(|(n, p)| MaybePinned::new(n, p, walker)) - .collect(); - - Self { outgoing, incoming } - } - - /// Check if all ports on the wire in the given direction are pinned. - /// - /// A wire is complete in a direction if and only if expanding the wire - /// in that direction would yield no new walkers. If no direction is - /// specified, checks if the wire is complete in both directions. - pub fn is_complete(&self, dir: impl Into>) -> bool { - match dir.into() { - Some(Direction::Outgoing) => self.outgoing.is_pinned(), - Some(Direction::Incoming) => self.incoming.iter().all(|p| p.is_pinned()), - None => self.outgoing.is_pinned() && self.incoming.iter().all(|p| p.is_pinned()), - } - } - - /// Get the outgoing port of the wire, if it is pinned. - /// - /// Returns `None` if the outgoing port is not pinned. - pub fn pinned_outport(&self) -> Option<(PatchNode, OutgoingPort)> { - self.outgoing.into_pinned() - } - - /// Get all pinned incoming ports of the wire. - /// - /// Returns an iterator over all pinned incoming ports. - pub fn pinned_inports(&self) -> impl Iterator + '_ { - self.incoming.iter().filter_map(|&p| p.into_pinned()) - } - - /// Get all pinned ports of the wire. - pub fn all_pinned_ports(&self) -> impl Iterator + '_ { - fn to_port((node, port): (PatchNode, impl Into)) -> (PatchNode, Port) { - (node, port.into()) - } - self.pinned_outport() - .into_iter() - .map(to_port) - .chain(self.pinned_inports().map(to_port)) - } - - /// Get all unpinned ports of the wire, optionally filtering to only those - /// in the given direction. - pub(crate) fn unpinned_ports( - &self, - dir: impl Into>, - ) -> impl Iterator + '_ { - let incoming = self - .incoming - .iter() - .filter_map(|p| p.into_unpinned::()); - let outgoing = self.outgoing.into_unpinned::(); - let dir = dir.into(); - mask_iter(incoming, dir != Some(Direction::Outgoing)) - .chain(mask_iter(outgoing, dir != Some(Direction::Incoming))) - } -} - -/// Return an iterator over the items in `iter` if `mask` is true, otherwise -/// return an empty iterator. -#[inline] -fn mask_iter(iter: impl IntoIterator, mask: bool) -> impl Iterator { - match mask { - true => Either::Left(iter.into_iter()), - false => Either::Right(std::iter::empty()), - } - .into_iter() -} diff --git a/hugr-persistent/src/wire.rs b/hugr-persistent/src/wire.rs new file mode 100644 index 0000000000..7342b81be6 --- /dev/null +++ b/hugr-persistent/src/wire.rs @@ -0,0 +1,285 @@ +use std::collections::{BTreeSet, VecDeque}; + +use hugr_core::{Direction, HugrView, IncomingPort, OutgoingPort, Port, Wire}; +use itertools::Itertools; + +use crate::{CommitId, PatchNode, PersistentHugr, Resolver, Walker}; + +/// A wire in a [`PersistentHugr`]. +/// +/// A wire may be composed of multiple wires in the underlying commits +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct PersistentWire { + wires: BTreeSet, +} + +/// A wire within a commit HUGR of a [`PersistentHugr`]. +/// +/// This is a `Wire` valid in a commit HUGR of a [`PersistentHugr`], along with +/// the ID of the commit that contains the wire; this is equivalent to storing a +/// wire of type `Wire`. +/// +/// Note that it does not correspond to a valid wire in a [`PersistentHugr`] +/// (see [`PersistentWire`]): some of its connected ports may be on deleted or +/// IO nodes that are not valid in the [`PersistentHugr`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +struct CommitWire(Wire); + +impl CommitWire { + fn from_connected_port( + PatchNode(commit_id, node): PatchNode, + port: impl Into, + hugr: &PersistentHugr, + ) -> Self { + let commit_hugr = hugr.commit_hugr(commit_id); + let wire = Wire::from_connected_port(node, port, commit_hugr); + Self(Wire::new(PatchNode(commit_id, wire.node()), wire.source())) + } + + fn all_connected_ports<'h, R>( + &self, + hugr: &'h PersistentHugr, + ) -> impl Iterator + use<'h, R> { + let wire = Wire::new(self.0.node().1, self.0.source()); + let commit_id = self.commit_id(); + wire.all_connected_ports(hugr.commit_hugr(commit_id)) + .map(move |(node, port)| (hugr.to_persistent_node(node, commit_id), port)) + } + + fn commit_id(&self) -> CommitId { + self.0.node().0 + } +} + +/// A node in a commit of a [`PersistentHugr`] is either a valid node of the +/// HUGR, a node deleted by a child commit in that [`PersistentHugr`], or an +/// input or output node in a replacement graph. +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +enum NodeStatus { + /// A node deleted by a child commit in that [`PersistentHugr`]. + /// + /// The ID of the child commit is stored in the variant. + Deleted(CommitId), + /// An input or output node in the replacement graph of a Commit + ReplacementIO, + /// A valid node in the [`PersistentHugr`] + Valid, +} + +impl PersistentHugr { + pub fn get_wire(&self, node: PatchNode, port: impl Into) -> PersistentWire { + PersistentWire::from_port(node, port, self) + } + + /// Whether a node is valid in `self`, is deleted or is an IO node in a + /// replacement graph. + fn node_status(&self, per_node @ PatchNode(commit_id, node): PatchNode) -> NodeStatus { + debug_assert!(self.contains_id(commit_id), "unknown commit"); + if self + .replacement(commit_id) + .is_some_and(|repl| repl.get_replacement_io().contains(&node)) + { + NodeStatus::ReplacementIO + } else if let Some(commit_id) = self.find_deleting_commit(per_node) { + NodeStatus::Deleted(commit_id) + } else { + NodeStatus::Valid + } + } +} + +impl PersistentWire { + /// Get the wire connected to a specified port of a pinned node in `hugr`. + fn from_port(node: PatchNode, port: impl Into, per_hugr: &PersistentHugr) -> Self { + assert!(per_hugr.contains_node(node), "node not in hugr"); + + // Queue of wires within each commit HUGR, that combined will form the + // persistent wire. + let mut commit_wires = + BTreeSet::from_iter([CommitWire::from_connected_port(node, port, per_hugr)]); + let mut queue = VecDeque::from_iter(commit_wires.iter().copied()); + + while let Some(wire) = queue.pop_front() { + let commit_id = wire.commit_id(); + let commit_hugr = per_hugr.commit_hugr(commit_id); + let all_ports = wire.all_connected_ports(per_hugr); + + for (per_node @ PatchNode(_, node), port) in all_ports { + match per_hugr.node_status(per_node) { + NodeStatus::Deleted(deleted_by) => { + // If node is deleted, check if there are wires between + // ports on the opposite end of the wire and boundary + // ports in the child commit that deleted the node. + for (opp_node, opp_port) in commit_hugr.linked_ports(node, port) { + let opp_node = per_hugr.to_persistent_node(opp_node, commit_id); + for (child_node, child_port) in per_hugr + .as_state_space() + .linked_child_ports(opp_node, opp_port, deleted_by) + { + let w = CommitWire::from_connected_port( + child_node, child_port, per_hugr, + ); + if commit_wires.insert(w) { + queue.push_back(w); + } + } + } + } + NodeStatus::ReplacementIO => { + // If node is an input (resp. output) node in a replacement graph, there + // must be (at least) one wire between the incoming (resp. outgoing) + // boundary ports of the commit (i.e. the ports connected to + // the input resp. output) and ports in a parent commit. + for (opp_node, opp_port) in commit_hugr.linked_ports(node, port) { + let opp_node = per_hugr.to_persistent_node(opp_node, commit_id); + for (parent_node, parent_port) in per_hugr + .as_state_space() + .linked_parent_ports(opp_node, opp_port) + { + let w = CommitWire::from_connected_port( + parent_node, + parent_port, + per_hugr, + ); + if commit_wires.insert(w) { + queue.push_back(w); + } + } + } + } + NodeStatus::Valid => {} + } + } + } + + Self { + wires: commit_wires, + } + } + + /// Get all ports attached to a wire in `hugr`. + /// + /// All ports returned are on nodes that are contained in `hugr`. + pub fn all_ports( + &self, + hugr: &PersistentHugr, + dir: impl Into>, + ) -> impl Iterator { + all_ports_impl(self.wires.iter().copied(), dir.into(), hugr) + } + + /// Consume the wire and return all ports attached to a wire in `hugr`. + /// + /// All ports returned are on nodes that are contained in `hugr`. + pub fn into_all_ports( + self, + hugr: &PersistentHugr, + dir: impl Into>, + ) -> impl Iterator { + all_ports_impl(self.wires.into_iter(), dir.into(), hugr) + } + + pub fn single_outgoing_port( + &self, + hugr: &PersistentHugr, + ) -> Option<(PatchNode, OutgoingPort)> { + single_outgoing(self.all_ports(hugr, Direction::Outgoing)) + } + + pub fn all_incoming_ports( + &self, + hugr: &PersistentHugr, + ) -> impl Iterator { + self.all_ports(hugr, Direction::Incoming) + .map(|(node, port)| (node, port.as_incoming().unwrap())) + } +} + +impl Walker<'_, R> { + /// Get all ports on a wire that are not pinned in `self`. + pub(crate) fn wire_unpinned_ports( + &self, + wire: &PersistentWire, + dir: impl Into>, + ) -> impl Iterator { + let ports = wire.all_ports(self.as_hugr_view(), dir); + ports.filter(|(node, _)| !self.is_pinned(*node)) + } + + /// Get the ports of the wire that are on pinned nodes of `self`. + pub fn wire_pinned_ports( + &self, + wire: &PersistentWire, + dir: impl Into>, + ) -> impl Iterator { + let ports = wire.all_ports(self.as_hugr_view(), dir); + ports.filter(|(node, _)| self.is_pinned(*node)) + } + + /// Get the outgoing port of a wire if it is pinned in `walker`. + pub fn wire_pinned_outport(&self, wire: &PersistentWire) -> Option<(PatchNode, OutgoingPort)> { + single_outgoing(self.wire_pinned_ports(wire, Direction::Outgoing)) + } + + /// Get all pinned incoming ports of a wire. + pub fn wire_pinned_inports( + &self, + wire: &PersistentWire, + ) -> impl Iterator { + self.wire_pinned_ports(wire, Direction::Incoming) + .map(|(node, port)| (node, port.as_incoming().expect("incoming port"))) + } + + /// Whether a wire is complete in the specified direction, i.e. there are no + /// unpinned ports left. + pub fn is_complete(&self, wire: &PersistentWire, dir: impl Into>) -> bool { + self.wire_unpinned_ports(wire, dir).next().is_none() + } +} + +/// Implementation of the (shared) body of [`PersistentWire::all_ports`] and +/// [`PersistentWire::into_all_ports`]. +fn all_ports_impl( + wires: impl Iterator, + dir: Option, + per_hugr: &PersistentHugr, +) -> impl Iterator { + let all_ports = wires.flat_map(move |w| w.all_connected_ports(per_hugr)); + + // Filter out invalid and wrong direction ports + all_ports + .filter(move |(_, port)| dir.is_none_or(|dir| port.direction() == dir)) + .filter(|&(node, _)| per_hugr.node_status(node) == NodeStatus::Valid) +} + +fn single_outgoing(iter: impl Iterator) -> Option<(N, OutgoingPort)> { + let (node, port) = iter.exactly_one().ok()?; + Some((node, port.as_outgoing().ok()?)) +} + +#[cfg(test)] +mod tests { + use std::collections::BTreeSet; + + use crate::{CommitId, CommitStateSpace, PatchNode, tests::test_state_space}; + use hugr_core::{HugrView, OutgoingPort}; + use itertools::Itertools; + use rstest::rstest; + + #[rstest] + fn test_all_ports(test_state_space: (CommitStateSpace, [CommitId; 4])) { + let (state_space, [_, _, cm3, cm4]) = test_state_space; + let hugr = state_space.try_extract_hugr([cm3, cm4]).unwrap(); + let cm4_not = { + let hugr4 = state_space.commit_hugr(cm4); + let out = state_space.replacement(cm4).unwrap().get_replacement_io()[1]; + let node = hugr4.input_neighbours(out).exactly_one().ok().unwrap(); + PatchNode(cm4, node) + }; + let w = hugr.get_wire(cm4_not, OutgoingPort::from(0)); + assert_eq!( + BTreeSet::from_iter(w.wires.iter().map(|w| w.0.node().0)), + BTreeSet::from_iter([cm3, cm4, state_space.base(),]) + ); + } +} diff --git a/hugr-persistent/tests/persistent_walker_example.rs b/hugr-persistent/tests/persistent_walker_example.rs index 1c2df6bd82..faf637df57 100644 --- a/hugr-persistent/tests/persistent_walker_example.rs +++ b/hugr-persistent/tests/persistent_walker_example.rs @@ -12,7 +12,7 @@ use hugr_core::{ types::EdgeKind, }; -use hugr_persistent::{CommitStateSpace, PersistentReplacement, PinnedWire, Walker}; +use hugr_persistent::{CommitStateSpace, PersistentReplacement, PersistentWire, Walker}; /// The maximum commit depth that we will consider in this example const MAX_COMMITS: usize = 2; @@ -133,7 +133,7 @@ fn two_cz_3qb_hugr() -> Hugr { /// Traverse all commits in state space, enqueueing all outgoing wires of /// CZ nodes fn enqueue_all( - queue: &mut VecDeque<(PinnedWire, Walker<'static>)>, + queue: &mut VecDeque<(PersistentWire, Walker<'static>)>, state_space: &CommitStateSpace, ) { for id in state_space.all_commit_ids() { @@ -169,10 +169,10 @@ fn build_state_space() -> CommitStateSpace { enqueue_all(&mut wire_queue, &state_space); while let Some((wire, walker)) = wire_queue.pop_front() { - if !wire.is_complete(None) { + if !walker.is_complete(&wire, None) { // expand the wire in all possible ways - let (pinned_node, pinned_port) = wire - .all_pinned_ports() + let (pinned_node, pinned_port) = walker + .wire_pinned_ports(&wire, None) .next() .expect("at least one port was already pinned"); assert!( @@ -190,7 +190,10 @@ fn build_state_space() -> CommitStateSpace { // we have a complete wire, so we can commute the CZ gates (or // cancel them out) - let patch_nodes: BTreeSet<_> = wire.all_pinned_ports().map(|(n, _)| n).collect(); + let patch_nodes: BTreeSet<_> = walker + .wire_pinned_ports(&wire, None) + .map(|(n, _)| n) + .collect(); // check that the patch applies to more than one commit (or the base), // otherwise we have infinite commutations back and forth let patch_owners: BTreeSet<_> = patch_nodes.iter().map(|n| n.0).collect(); @@ -230,14 +233,14 @@ fn build_state_space() -> CommitStateSpace { state_space } -fn create_replacement(wire: PinnedWire, walker: &Walker) -> Option { +fn create_replacement(wire: PersistentWire, walker: &Walker) -> Option { let hugr = walker.clone().into_persistent_hugr(); let (out_node, _) = wire - .pinned_outport() + .single_outgoing_port(&hugr) .expect("outgoing port was already pinned (and is unique)"); let (in_node, _) = wire - .pinned_inports() + .all_incoming_ports(&hugr) .exactly_one() .ok() .expect("all our wires have exactly one incoming port");