Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 78 additions & 24 deletions hugr-core/src/hugr/patch/simple_replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,16 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
/// of `self`.
///
/// The returned port will be in `replacement`, unless the wire in the
/// replacement is empty, in which case it will another `host` port.
/// replacement is empty and `return_invalid` is
/// [`BoundaryMode::SnapToHost`] (the default), in which case it
/// will be another `host` port. If [`BoundaryMode::IncludeIO`] is
/// passed, the returned port will always be in `replacement` even if it
/// is invalid (i.e. it is an IO node in the replacement).
pub fn linked_replacement_output(
&self,
port: impl Into<HostPort<HostNode, IncomingPort>>,
host: &impl HugrView<Node = HostNode>,
return_invalid: BoundaryMode,
) -> Option<BoundaryPort<HostNode, OutgoingPort>> {
let HostPort(node, port) = port.into();
let pos = self
Expand All @@ -139,7 +144,7 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
.iter()
.position(move |&(n, p)| host.linked_inputs(n, p).contains(&(node, port)))?;

Some(self.linked_replacement_output_by_position(pos, host))
Some(self.linked_replacement_output_by_position(pos, host, return_invalid))
}

/// The outgoing port linked to the i-th output boundary edge of `subgraph`.
Expand All @@ -150,6 +155,7 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
&self,
pos: usize,
host: &impl HugrView<Node = HostNode>,
return_invalid: BoundaryMode,
) -> BoundaryPort<HostNode, OutgoingPort> {
debug_assert!(pos < self.subgraph().signature(host).output_count());

Expand All @@ -160,7 +166,7 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
.single_linked_output(repl_out, pos)
.expect("valid dfg wire");

if out_node != repl_inp {
if out_node != repl_inp || return_invalid == BoundaryMode::IncludeIO {
BoundaryPort::Replacement(out_node, out_port)
} else {
let (in_node, in_port) = *self.subgraph.incoming_ports()[out_port.index()]
Expand Down Expand Up @@ -207,11 +213,17 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
/// of `self`.
///
/// The returned ports will be in `replacement`, unless the wires in the
/// replacement are empty, in which case they are other `host` ports.
/// replacement are empty and `return_invalid` is
/// [`BoundaryMode::SnapToHost`] (the default), in which case they
/// will be other `host` ports. If [`BoundaryMode::IncludeIO`] is
/// passed, the returned ports will always be in
/// `replacement` even if they are invalid (i.e. they are an IO node in
/// the replacement).
pub fn linked_replacement_inputs<'a>(
&'a self,
port: impl Into<HostPort<HostNode, OutgoingPort>>,
host: &'a impl HugrView<Node = HostNode>,
return_invalid: BoundaryMode,
) -> impl Iterator<Item = BoundaryPort<HostNode, IncomingPort>> + 'a {
let HostPort(node, port) = port.into();
let positions = self
Expand All @@ -223,26 +235,25 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
host.single_linked_output(n, p).expect("valid dfg wire") == (node, port)
});

positions.flat_map(|pos| self.linked_replacement_inputs_by_position(pos, host))
positions.flat_map(move |pos| {
self.linked_replacement_inputs_by_position(pos, host, return_invalid)
})
}

/// The incoming ports linked to the i-th input boundary edge of `subgraph`.
///
/// The ports will be in `replacement` for all endpoints of the i-th input
/// wire that are not the output node of `replacement` and be in `host`
/// otherwise.
fn linked_replacement_inputs_by_position(
&self,
pos: usize,
host: &impl HugrView<Node = HostNode>,
return_invalid: BoundaryMode,
) -> impl Iterator<Item = BoundaryPort<HostNode, IncomingPort>> {
debug_assert!(pos < self.subgraph().signature(host).input_count());

let [repl_inp, repl_out] = self.get_replacement_io();
self.replacement
.linked_inputs(repl_inp, pos)
.flat_map(move |(in_node, in_port)| {
if in_node != repl_out {
if in_node != repl_out || return_invalid == BoundaryMode::IncludeIO {
Either::Left(std::iter::once(BoundaryPort::Replacement(in_node, in_port)))
} else {
let (out_node, out_port) = self.subgraph.outgoing_ports()[in_port.index()];
Expand Down Expand Up @@ -316,7 +327,7 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
subgraph_outgoing_ports
.enumerate()
.flat_map(|(pos, subg_np)| {
self.linked_replacement_inputs_by_position(pos, host)
self.linked_replacement_inputs_by_position(pos, host, BoundaryMode::SnapToHost)
.filter_map(move |np| Some((np.as_replacement()?, subg_np)))
})
.map(|((repl_node, repl_port), (subgraph_node, subgraph_port))| {
Expand Down Expand Up @@ -359,7 +370,7 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
.enumerate()
.filter_map(|(pos, subg_all)| {
let np = self
.linked_replacement_output_by_position(pos, host)
.linked_replacement_output_by_position(pos, host, BoundaryMode::SnapToHost)
.as_replacement()?;
Some((np, subg_all))
})
Expand Down Expand Up @@ -406,7 +417,7 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
.enumerate()
.filter_map(|(pos, subg_all)| {
Some((
self.linked_replacement_output_by_position(pos, host)
self.linked_replacement_output_by_position(pos, host, BoundaryMode::SnapToHost)
.as_host()?,
subg_all,
))
Expand Down Expand Up @@ -517,7 +528,8 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
SimpleReplacement::try_new(subgraph, new_host, replacement.clone())
}

/// Allows to get the [Self::invalidated_nodes] without requiring a [HugrView].
/// Allows to get the [Self::invalidated_nodes] without requiring a
/// [HugrView].
pub fn invalidation_set(&self) -> impl Iterator<Item = HostNode> {
self.subgraph.nodes().iter().copied()
}
Expand All @@ -540,6 +552,24 @@ impl<HostNode: HugrNode> PatchVerification for SimpleReplacement<HostNode> {
}
}

/// In [`SimpleReplacement`], some nodes in the replacement may not be valid
/// after the replacement is applied.
///
/// This enum allows specifying whether these invalid nodes on the boundary
/// should be returned or should be resolved to valid nodes in the host.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum BoundaryMode {
/// Only consider nodes that are valid after the replacement is applied.
///
/// This means that nodes in hosts may be returned in places where nodes in
/// the replacement would be typically expected.
#[default]
SnapToHost,
/// Include all nodes, including potentially invalid ones (inputs and
/// outputs of replacements).
IncludeIO,
}

/// Result of applying a [`SimpleReplacement`].
pub struct Outcome<HostNode = Node> {
/// Map from Node in replacement to corresponding Node in the result Hugr
Expand Down Expand Up @@ -652,7 +682,7 @@ pub(in crate::hugr::patch) mod test {
ModuleBuilder, endo_sig, inout_sig,
};
use crate::extension::prelude::{bool_t, qb_t};
use crate::hugr::patch::simple_replace::Outcome;
use crate::hugr::patch::simple_replace::{BoundaryMode, Outcome};
use crate::hugr::patch::{BoundaryPort, HostPort, PatchVerification, ReplacementPort};
use crate::hugr::views::{HugrView, SiblingSubgraph};
use crate::hugr::{Hugr, HugrMut, Patch};
Expand Down Expand Up @@ -1145,7 +1175,11 @@ pub(in crate::hugr::patch) mod test {

// Test linked_replacement_inputs with empty replacement
let replacement_inputs: Vec<_> = repl
.linked_replacement_inputs((inp, OutgoingPort::from(0)), &hugr)
.linked_replacement_inputs(
(inp, OutgoingPort::from(0)),
&hugr,
BoundaryMode::SnapToHost,
)
.collect();

assert_eq!(
Expand All @@ -1158,8 +1192,12 @@ pub(in crate::hugr::patch) mod test {
// Test linked_replacement_output with empty replacement
let replacement_output = (0..4)
.map(|i| {
repl.linked_replacement_output((out, IncomingPort::from(i)), &hugr)
.unwrap()
repl.linked_replacement_output(
(out, IncomingPort::from(i)),
&hugr,
BoundaryMode::SnapToHost,
)
.unwrap()
})
.collect_vec();

Expand Down Expand Up @@ -1191,7 +1229,11 @@ pub(in crate::hugr::patch) mod test {
};

let replacement_inputs: Vec<_> = repl
.linked_replacement_inputs((inp, OutgoingPort::from(0)), &hugr)
.linked_replacement_inputs(
(inp, OutgoingPort::from(0)),
&hugr,
BoundaryMode::SnapToHost,
)
.collect();

assert_eq!(
Expand All @@ -1203,8 +1245,12 @@ pub(in crate::hugr::patch) mod test {

let replacement_output = (0..4)
.map(|i| {
repl.linked_replacement_output((out, IncomingPort::from(i)), &hugr)
.unwrap()
repl.linked_replacement_output(
(out, IncomingPort::from(i)),
&hugr,
BoundaryMode::SnapToHost,
)
.unwrap()
})
.collect_vec();

Expand Down Expand Up @@ -1241,7 +1287,11 @@ pub(in crate::hugr::patch) mod test {
};

let replacement_inputs: Vec<_> = repl
.linked_replacement_inputs((inp, OutgoingPort::from(0)), &hugr)
.linked_replacement_inputs(
(inp, OutgoingPort::from(0)),
&hugr,
BoundaryMode::SnapToHost,
)
.collect();

assert_eq!(
Expand All @@ -1257,8 +1307,12 @@ pub(in crate::hugr::patch) mod test {

let replacement_output = (0..4)
.map(|i| {
repl.linked_replacement_output((out, IncomingPort::from(i)), &hugr)
.unwrap()
repl.linked_replacement_output(
(out, IncomingPort::from(i)),
&hugr,
BoundaryMode::SnapToHost,
)
.unwrap()
})
.collect_vec();

Expand Down
87 changes: 75 additions & 12 deletions hugr-persistent/src/persistent_hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ impl Commit {
/// Requires a reference to the commit state space that the nodes in
/// `replacement` refer to.
///
/// Use [`Self::try_new`] instead if the parents of the commit cannot be
/// inferred from the invalidation set of `replacement` alone.
///
/// The replacement must act on a non-empty subgraph, otherwise this
/// function will return an [`InvalidCommit::EmptyReplacement`] error.
///
Expand All @@ -47,20 +50,37 @@ impl Commit {
pub fn try_from_replacement<R>(
replacement: PersistentReplacement,
graph: &CommitStateSpace<R>,
) -> Result<Commit, InvalidCommit> {
Self::try_new(replacement, [], graph)
}

/// Create a new commit
///
/// Requires a reference to the commit state space that the nodes in
/// `replacement` refer to.
///
/// The returned commit will correspond to the application of `replacement`
/// and will be the child of the commits in `parents` as well as of all
/// the commits in the invalidation set of `replacement`.
///
/// The replacement must act on a non-empty subgraph, otherwise this
/// function will return an [`InvalidCommit::EmptyReplacement`] error.
/// If any of the parents of the replacement are not in the commit state
/// space, this function will return an [`InvalidCommit::UnknownParent`]
/// error.
pub fn try_new<R>(
replacement: PersistentReplacement,
parents: impl IntoIterator<Item = Commit>,
graph: &CommitStateSpace<R>,
) -> Result<Commit, InvalidCommit> {
if replacement.subgraph().nodes().is_empty() {
return Err(InvalidCommit::EmptyReplacement);
}
let parent_ids = replacement.invalidation_set().map(|n| n.0).unique();
let parents = parent_ids
.map(|id| {
if graph.contains_id(id) {
Ok(graph.get_commit(id).clone())
} else {
Err(InvalidCommit::UnknownParent(id))
}
})
.collect::<Result<Vec<_>, _>>()?;
let repl_parents = get_parent_commits(&replacement, graph)?;
let parents = parents
.into_iter()
.chain(repl_parents)
.unique_by(|p| p.as_ptr());
let rc = RelRc::with_parents(
replacement.into(),
parents.into_iter().map(|p| (p.into(), ())),
Expand Down Expand Up @@ -434,6 +454,8 @@ impl<R> PersistentHugr<R> {
pub fn base_commit(&self) -> &Commit;
/// Get the commit with ID `commit_id`.
pub fn get_commit(&self, commit_id: CommitId) -> &Commit;
/// Check whether `commit_id` exists and return it.
pub fn try_get_commit(&self, commit_id: CommitId) -> Option<&Commit>;
/// Get an iterator over all nodes inserted by `commit_id`.
///
/// All nodes will be PatchNodes with commit ID `commit_id`.
Expand Down Expand Up @@ -529,6 +551,32 @@ impl<R> PersistentHugr<R> {
.expect("invalid port")
.is_value()
}

pub(super) fn value_ports(
&self,
patch_node @ PatchNode(commit_id, node): PatchNode,
dir: Direction,
) -> impl Iterator<Item = (PatchNode, Port)> + '_ {
let hugr = self.commit_hugr(commit_id);
let ports = hugr.node_ports(node, dir);
ports.filter_map(move |p| self.is_value_port(patch_node, p).then_some((patch_node, p)))
}

pub(super) fn output_value_ports(
&self,
patch_node: PatchNode,
) -> impl Iterator<Item = (PatchNode, OutgoingPort)> + '_ {
self.value_ports(patch_node, Direction::Outgoing)
.map(|(n, p)| (n, p.as_outgoing().expect("unexpected port direction")))
}

pub(super) fn input_value_ports(
&self,
patch_node: PatchNode,
) -> impl Iterator<Item = (PatchNode, IncomingPort)> + '_ {
self.value_ports(patch_node, Direction::Incoming)
.map(|(n, p)| (n, p.as_incoming().expect("unexpected port direction")))
}
}

impl<R> IntoIterator for PersistentHugr<R> {
Expand All @@ -549,11 +597,11 @@ impl<R> IntoIterator for PersistentHugr<R> {
/// among `children`.
pub(crate) fn find_conflicting_node<'a>(
commit_id: CommitId,
mut children: impl Iterator<Item = &'a Commit>,
children: impl IntoIterator<Item = &'a Commit>,
) -> Option<Node> {
let mut all_invalidated = BTreeSet::new();

children.find_map(|child| {
children.into_iter().find_map(|child| {
let mut new_invalidated =
child
.invalidation_set()
Expand All @@ -567,3 +615,18 @@ pub(crate) fn find_conflicting_node<'a>(
new_invalidated.find(|&n| !all_invalidated.insert(n))
})
}

fn get_parent_commits<R>(
replacement: &PersistentReplacement,
graph: &CommitStateSpace<R>,
) -> Result<Vec<Commit>, InvalidCommit> {
let parent_ids = replacement.invalidation_set().map(|n| n.owner()).unique();
parent_ids
.map(|id| {
graph
.try_get_commit(id)
.cloned()
.ok_or(InvalidCommit::UnknownParent(id))
})
.collect()
}
Loading
Loading