diff --git a/hugr-core/src/hugr/rewrite.rs b/hugr-core/src/hugr/rewrite.rs index 1f2be7a9a1..7c4374b651 100644 --- a/hugr-core/src/hugr/rewrite.rs +++ b/hugr-core/src/hugr/rewrite.rs @@ -5,10 +5,12 @@ pub mod inline_call; pub mod inline_dfg; pub mod insert_identity; pub mod outline_cfg; +mod port_types; pub mod replace; pub mod simple_replace; use crate::{Hugr, HugrView, Node}; +pub use port_types::{BoundaryPort, HostPort, ReplacementPort}; pub use simple_replace::{SimpleReplacement, SimpleReplacementError}; use super::HugrMut; diff --git a/hugr-core/src/hugr/rewrite/port_types.rs b/hugr-core/src/hugr/rewrite/port_types.rs new file mode 100644 index 0000000000..3aeafa4aea --- /dev/null +++ b/hugr-core/src/hugr/rewrite/port_types.rs @@ -0,0 +1,74 @@ +//! Types to distinguish ports in the host and replacement graphs. + +use std::collections::HashMap; + +use crate::{IncomingPort, Node, OutgoingPort, Port}; + +use derive_more::From; + +/// A port in either the host or replacement graph. +/// +/// This is used to represent boundary edges that will be added between the host and +/// replacement graphs when applying a rewrite. +#[derive(Debug, Clone, Copy)] +pub enum BoundaryPort { + /// A port in the host graph. + Host(HostNode, P), + /// A port in the replacement graph. + Replacement(Node, P), +} + +/// A port in the host graph. +#[derive(Debug, Clone, Copy, From)] +pub struct HostPort(pub N, pub P); + +/// A port in the replacement graph. +#[derive(Debug, Clone, Copy, From)] +pub struct ReplacementPort

(pub Node, pub P); + +impl BoundaryPort { + /// Maps a boundary port according to the insertion mapping. + /// Host ports are unchanged, while Replacement ports are mapped according to the index_map. + pub fn map_replacement(self, index_map: &HashMap) -> (HostNode, P) { + match self { + BoundaryPort::Host(node, port) => (node, port), + BoundaryPort::Replacement(node, port) => (*index_map.get(&node).unwrap(), port), + } + } +} + +impl From> for BoundaryPort { + fn from(HostPort(node, port): HostPort) -> Self { + BoundaryPort::Host(node, port) + } +} + +impl From> for BoundaryPort { + fn from(ReplacementPort(node, port): ReplacementPort

) -> Self { + BoundaryPort::Replacement(node, port) + } +} + +impl From> for HostPort { + fn from(HostPort(node, port): HostPort) -> Self { + HostPort(node, port.into()) + } +} + +impl From> for HostPort { + fn from(HostPort(node, port): HostPort) -> Self { + HostPort(node, port.into()) + } +} + +impl From> for ReplacementPort { + fn from(ReplacementPort(node, port): ReplacementPort) -> Self { + ReplacementPort(node, port.into()) + } +} + +impl From> for ReplacementPort { + fn from(ReplacementPort(node, port): ReplacementPort) -> Self { + ReplacementPort(node, port.into()) + } +} diff --git a/hugr-core/src/hugr/rewrite/simple_replace.rs b/hugr-core/src/hugr/rewrite/simple_replace.rs index 5187b4fa1b..cf7f2922a1 100644 --- a/hugr-core/src/hugr/rewrite/simple_replace.rs +++ b/hugr-core/src/hugr/rewrite/simple_replace.rs @@ -8,11 +8,14 @@ pub use crate::hugr::internal::HugrMutInternals; use crate::hugr::views::SiblingSubgraph; use crate::hugr::{HugrMut, HugrView, Rewrite}; use crate::ops::{OpTag, OpTrait, OpType}; -use crate::{Hugr, IncomingPort, Node}; +use crate::{Hugr, IncomingPort, Node, OutgoingPort}; + +use itertools::Itertools; use thiserror::Error; use super::inline_dfg::InlineDFGError; +use super::{BoundaryPort, HostPort, ReplacementPort}; /// Specification of a simple replacement operation. /// @@ -25,11 +28,11 @@ pub struct SimpleReplacement { subgraph: SiblingSubgraph, /// A hugr with DFG root (consisting of replacement nodes). replacement: Hugr, - /// A map from (target ports of edges from the Input node of `replacement`) to (target ports of - /// edges from nodes not in `removal` to nodes in `removal`). + /// A map from (target ports of edges from the Input node of `replacement`) + /// to (target ports of edges from nodes not in `subgraph` to nodes in `subgraph`). nu_inp: HashMap<(Node, IncomingPort), (HostNode, IncomingPort)>, - /// A map from (target ports of edges from nodes in `removal` to nodes not in `removal`) to - /// (input ports of the Output node of `replacement`). + /// A map from (target ports of edges from nodes in `subgraph` to nodes not + /// in `subgraph`) to (input ports of the Output node of `replacement`). nu_out: HashMap<(HostNode, IncomingPort), IncomingPort>, } @@ -56,97 +59,252 @@ impl SimpleReplacement { &self.replacement } + /// Consume self and return the replacement hugr. + #[inline] + pub fn into_replacement(self) -> Hugr { + self.replacement + } + /// Subgraph to be replaced. #[inline] pub fn subgraph(&self) -> &SiblingSubgraph { &self.subgraph } -} -impl Rewrite for SimpleReplacement { - type Error = SimpleReplacementError; - type ApplyResult = Vec<(Node, OpType)>; - const UNCHANGED_ON_FAILURE: bool = true; - - fn verify(&self, _h: &impl HugrView) -> Result<(), SimpleReplacementError> { - unimplemented!() - } + /// Check if the replacement can be applied to the given hugr. + pub fn is_valid_rewrite( + &self, + h: &impl HugrView, + ) -> Result<(), SimpleReplacementError> { + let parent = self.subgraph.get_parent(h); - fn apply(self, h: &mut impl HugrMut) -> Result { - let Self { - subgraph, - replacement, - nu_inp, - nu_out, - } = self; - let parent = subgraph.get_parent(h); // 1. Check the parent node exists and is a DataflowParent. if !OpTag::DataflowParent.is_superset(h.get_optype(parent).tag()) { return Err(SimpleReplacementError::InvalidParentNode()); } + // 2. Check that all the to-be-removed nodes are children of it and are leaves. - for node in subgraph.nodes() { + for node in self.subgraph.nodes() { if h.get_parent(*node) != Some(parent) || h.children(*node).next().is_some() { return Err(SimpleReplacementError::InvalidRemovedNode()); } } - let replacement_output_node = replacement - .get_io(replacement.root()) - .expect("parent already checked.")[1]; + Ok(()) + } - // 3. Do the replacement. - // Now we proceed to connect the edges between the newly inserted - // replacement and the rest of the graph. - // - // Existing connections to the removed subgraph will be automatically - // removed when the nodes are removed. + /// Get the input and output nodes of the replacement hugr. + pub fn get_replacement_io(&self) -> Result<[Node; 2], SimpleReplacementError> { + self.replacement + .get_io(self.replacement.root()) + .ok_or(SimpleReplacementError::InvalidParentNode()) + } - // 3.1. For each p = self.nu_inp[q] such that q is not an Output port, add an edge from the - // predecessor of p to (the new copy of) q. - let nu_inp_connects: Vec<_> = nu_inp + /// Get all edges that the replacement would add from outgoing ports in + /// `host` to incoming ports in `self.replacement`. + /// + /// The incoming ports returned are always connected to outputs of + /// the [`OpTag::Input`] node of `self.replacement`. + /// + /// For each pair in the returned vector, the first element is a port in + /// `host` and the second is a port in `self.replacement`. + pub fn incoming_boundary<'a>( + &'a self, + host: &'a impl HugrView, + ) -> impl Iterator< + Item = ( + HostPort, + ReplacementPort, + ), + > + 'a { + // For each p = self.nu_inp[q] such that q is not an Output port, + // there will be an edge from the predecessor of p to (the new copy of) q. + self.nu_inp .iter() .filter(|&((rep_inp_node, _), _)| { - replacement.get_optype(*rep_inp_node).tag() != OpTag::Output + self.replacement.get_optype(*rep_inp_node).tag() != OpTag::Output }) .map( - |((rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port))| { + |(&(rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port))| { // add edge from predecessor of (s_inp_node, s_inp_port) to (new_inp_node, n_inp_port) - let (rem_inp_pred_node, rem_inp_pred_port) = h + let (rem_inp_pred_node, rem_inp_pred_port) = host .single_linked_output(*rem_inp_node, *rem_inp_port) .unwrap(); ( - rem_inp_pred_node, - rem_inp_pred_port, - // the new input node will be updated after insertion - rep_inp_node, - rep_inp_port, + HostPort(rem_inp_pred_node, rem_inp_pred_port), + ReplacementPort(rep_inp_node, rep_inp_port), ) }, ) - .collect(); + } - // 3.2. For each q = self.nu_out[p] such that the predecessor of q is not an Input port, add an - // edge from (the new copy of) the predecessor of q to p. - let nu_out_connects: Vec<_> = nu_out + /// Get all edges that the replacement would add from outgoing ports in + /// `self.replacement` to incoming ports in `host`. + /// + /// The outgoing ports returned are always connected to inputs of + /// the [`OpTag::Output`] node of `self.replacement`. + /// + /// For each pair in the returned vector, the first element is a port in + /// `self.replacement` and the second is a port in `host`. + /// + /// This panics if self.replacement is not a DFG. + pub fn outgoing_boundary<'a>( + &'a self, + _host: &'a impl HugrView, + ) -> impl Iterator< + Item = ( + ReplacementPort, + HostPort, + ), + > + 'a { + let [_, replacement_output_node] = self.get_replacement_io().expect("replacement is a DFG"); + + // For each q = self.nu_out[p] such that the predecessor of q is not an Input port, + // there will be an edge from (the new copy of) the predecessor of q to p. + self.nu_out .iter() - .filter_map(|((rem_out_node, rem_out_port), rep_out_port)| { - let (rep_out_pred_node, rep_out_pred_port) = replacement + .filter_map(move |(&(rem_out_node, rem_out_port), rep_out_port)| { + let (rep_out_pred_node, rep_out_pred_port) = self + .replacement .single_linked_output(replacement_output_node, *rep_out_port) .unwrap(); - (replacement.get_optype(rep_out_pred_node).tag() != OpTag::Input).then_some({ + (self.replacement.get_optype(rep_out_pred_node).tag() != OpTag::Input).then_some({ ( // the new output node will be updated after insertion - rep_out_pred_node, - rep_out_pred_port, - rem_out_node, - rem_out_port, + ReplacementPort(rep_out_pred_node, rep_out_pred_port), + HostPort(rem_out_node, rem_out_port), ) }) }) - .collect(); + } + + /// Get all edges that the replacement would add between ports in `host`. + /// + /// These correspond to direct edges between the input and output nodes + /// in the replacement graph. + /// + /// For each pair in the returned vector, the both ports are in `host`. + /// + /// This panics if self.replacement is not a DFG. + pub fn host_to_host_boundary<'a>( + &'a self, + host: &'a impl HugrView, + ) -> impl Iterator< + Item = ( + HostPort, + HostPort, + ), + > + 'a { + let [_, replacement_output_node] = self.get_replacement_io().expect("replacement is a DFG"); + + // For each q = self.nu_out[p1], p0 = self.nu_inp[q], add an edge from the predecessor of p0 + // to p1. + self.nu_out + .iter() + .filter_map(move |(&(rem_out_node, rem_out_port), &rep_out_port)| { + self.nu_inp + .get(&(replacement_output_node, rep_out_port)) + .map(|&(rem_inp_node, rem_inp_port)| { + let (rem_inp_pred_node, rem_inp_pred_port) = host + .single_linked_output(rem_inp_node, rem_inp_port) + .unwrap(); + ( + HostPort(rem_inp_pred_node, rem_inp_pred_port), + HostPort(rem_out_node, rem_out_port), + ) + }) + }) + } - // 3.3. Insert the replacement as a whole. + /// Get the incoming port at the output node of `self.replacement` that + /// corresponds to the given host output port. + /// + /// This panics if self.replacement is not a DFG. + pub fn map_host_output( + &self, + port: impl Into>, + ) -> Option> { + let HostPort(node, port) = port.into(); + let [_, rep_output] = self.get_replacement_io().expect("replacement is a DFG"); + self.nu_out + .get(&(node, port)) + .map(|&rep_out_port| ReplacementPort(rep_output, rep_out_port)) + } + + /// Get the incoming port in `subgraph` that corresponds to the given + /// replacement input port. + /// + /// This panics if self.replacement is not a DFG. + pub fn map_replacement_input( + &self, + port: impl Into>, + ) -> Option> { + let ReplacementPort(node, port) = port.into(); + self.nu_inp.get(&(node, port)).copied().map(Into::into) + } + + /// Get all edges that the replacement would add between `host` and + /// `self.replacement`. + /// + /// This is equivalent to chaining the results of [`Self::incoming_boundary`], + /// [`Self::outgoing_boundary`], and [`Self::host_to_host_boundary`]. + /// + /// This panics if self.replacement is not a DFG. + pub fn all_boundary_edges<'a>( + &'a self, + host: &'a impl HugrView, + ) -> impl Iterator< + Item = ( + BoundaryPort, + BoundaryPort, + ), + > + 'a { + let incoming_boundary = self + .incoming_boundary(host) + .map(|(src, tgt)| (src.into(), tgt.into())); + let outgoing_boundary = self + .outgoing_boundary(host) + .map(|(src, tgt)| (src.into(), tgt.into())); + let host_to_host_boundary = self + .host_to_host_boundary(host) + .map(|(src, tgt)| (src.into(), tgt.into())); + + incoming_boundary + .chain(outgoing_boundary) + .chain(host_to_host_boundary) + } +} + +impl Rewrite for SimpleReplacement { + type Error = SimpleReplacementError; + type ApplyResult = Vec<(Node, OpType)>; + const UNCHANGED_ON_FAILURE: bool = true; + + fn verify(&self, h: &impl HugrView) -> Result<(), SimpleReplacementError> { + self.is_valid_rewrite(h) + } + + fn apply(self, h: &mut impl HugrMut) -> Result { + self.is_valid_rewrite(h)?; + + let parent = self.subgraph.get_parent(h); + + // We proceed to connect the edges between the newly inserted + // replacement and the rest of the graph. + // + // Existing connections to the removed subgraph will be automatically + // removed when the nodes are removed. + + // 1. Get the boundary edges + let boundary_edges = self.all_boundary_edges(h).collect_vec(); + + let Self { + replacement, + subgraph, + .. + } = self; + + // 2. Insert the replacement as a whole. let InsertionResult { new_root, node_map: index_map, @@ -164,46 +322,14 @@ impl Rewrite for SimpleReplacement { // remove the replacement root (which now has no children and no edges) h.remove_node(new_root); - // 3.4. Update replacement nodes according to insertion mapping and connect - for (src_node, src_port, tgt_node, tgt_port) in nu_inp_connects { - h.connect( - src_node, - src_port, - *index_map.get(tgt_node).unwrap(), - *tgt_port, - ) - } - - for (src_node, src_port, tgt_node, tgt_port) in nu_out_connects { - h.connect( - *index_map.get(&src_node).unwrap(), - src_port, - *tgt_node, - *tgt_port, - ) - } - // 3.5. For each q = self.nu_out[p1], p0 = self.nu_inp[q], add an edge from the predecessor of p0 - // to p1. - // - // i.e. the replacement graph has direct edges between the input and output nodes. - for ((rem_out_node, rem_out_port), &rep_out_port) in &nu_out { - let rem_inp_nodeport = nu_inp.get(&(replacement_output_node, rep_out_port)); - if let Some((rem_inp_node, rem_inp_port)) = rem_inp_nodeport { - // add edge from predecessor of (rem_inp_node, rem_inp_port) to (rem_out_node, rem_out_port): - let (rem_inp_pred_node, rem_inp_pred_port) = h - .single_linked_output(*rem_inp_node, *rem_inp_port) - .unwrap(); - - h.connect( - rem_inp_pred_node, - rem_inp_pred_port, - *rem_out_node, - *rem_out_port, - ); - } + // 3. Insert all boundary edges. + for (src, tgt) in boundary_edges { + let (src_node, src_port) = src.map_replacement(&index_map); + let (tgt_node, tgt_port) = tgt.map_replacement(&index_map); + h.connect(src_node, src_port, tgt_node, tgt_port); } - // 3.6. Remove all nodes in subgraph and edges between them. + // 4. Remove all nodes in subgraph and edges between them. Ok(subgraph .nodes() .iter()