Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
298 changes: 177 additions & 121 deletions hugr-core/src/hugr/views/root_checked/dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,138 +8,149 @@ use thiserror::Error;
use crate::{
IncomingPort, OutgoingPort, PortIndex,
hugr::HugrMut,
ops::{DFG, FuncDefn, Input, OpTrait, OpType, Output, dataflow::IOTrait, handle::DfgID},
ops::{
DFG, FuncDefn, Input, OpTrait, OpType, Output,
dataflow::IOTrait,
handle::{DataflowParentID, DfgID},
},
types::{NoRV, Signature, TypeBase},
};

use super::RootChecked;

impl<H: HugrMut> RootChecked<H, DfgID<H::Node>> {
/// Get the input and output nodes of the DFG at the entrypoint node.
pub fn get_io(&self) -> [H::Node; 2] {
self.hugr()
.get_io(self.hugr().entrypoint())
.expect("valid DFG graph")
}

/// Rewire the inputs and outputs of the DFG to modify its signature.
///
/// Reorder the outgoing resp. incoming wires at the input resp. output
/// node of the DFG to modify the signature of the DFG HUGR. This will
/// recursively update the signatures of all ancestors of the entrypoint.
///
/// ### Arguments
///
/// * `new_inputs`: The new input signature. After the map, the i-th input
/// wire will be connected to the ports connected to the
/// `new_inputs[i]`-th input of the old DFG.
/// * `new_outputs`: The new output signature. After the map, the i-th
/// output wire will be connected to the ports connected to the
/// `new_outputs[i]`-th output of the old DFG.
///
/// Returns an `InvalidSignature` error if the new_inputs and new_outputs
/// map are not valid signatures.
///
/// ### Panics
///
/// Panics if the DFG is not trivially nested, i.e. if there is an ancestor
/// DFG of the entrypoint that has more than one inner DFG.
pub fn map_function_type(
&mut self,
new_inputs: &[usize],
new_outputs: &[usize],
) -> Result<(), InvalidSignature> {
let [inp, out] = self.get_io();
let Self(hugr, _) = self;

// Record the old connections from and to the input and output nodes
let old_inputs_incoming = hugr
.node_outputs(inp)
.map(|p| hugr.linked_inputs(inp, p).collect_vec())
.collect_vec();
let old_outputs_outgoing = hugr
.node_inputs(out)
.map(|p| hugr.linked_outputs(out, p).collect_vec())
.collect_vec();

// The old signature types
let old_inp_sig = hugr
.get_optype(inp)
.dataflow_signature()
.expect("input has signature");
let old_inp_sig = old_inp_sig.output_types();
let old_out_sig = hugr
.get_optype(out)
.dataflow_signature()
.expect("output has signature");
let old_out_sig = old_out_sig.input_types();

// Check if the signature map is valid
check_valid_inputs(&old_inputs_incoming, old_inp_sig, new_inputs)?;
check_valid_outputs(old_out_sig, new_outputs)?;

// The new signature types
let new_inp_sig = new_inputs
.iter()
.map(|&i| old_inp_sig[i].clone())
.collect_vec();
let new_out_sig = new_outputs
.iter()
.map(|&i| old_out_sig[i].clone())
.collect_vec();
let new_sig = Signature::new(new_inp_sig, new_out_sig);

// Remove all edges of the input and output nodes
disconnect_all(hugr, inp);
disconnect_all(hugr, out);

// Update the signatures of the IO and their ancestors
let mut is_ancestor = false;
let mut node = hugr.entrypoint();
while matches!(hugr.get_optype(node), OpType::FuncDefn(_) | OpType::DFG(_)) {
let [inner_inp, inner_out] = hugr.get_io(node).expect("valid DFG graph");
for node in [node, inner_inp, inner_out] {
update_signature(hugr, node, &new_sig);
macro_rules! impl_dataflow_parent_methods {
($handle_type:ident) => {
impl<H: HugrMut> RootChecked<H, $handle_type<H::Node>> {
/// Get the input and output nodes of the DFG at the entrypoint node.
pub fn get_io(&self) -> [H::Node; 2] {
self.hugr()
.get_io(self.hugr().entrypoint())
.expect("valid DFG graph")
}
if is_ancestor {
update_inner_dfg_links(hugr, node);
}
if let Some(parent) = hugr.get_parent(node) {
node = parent;
is_ancestor = true;
} else {
break;
}
}

// Insert the new edges at the input
let mut old_output_to_new_input = BTreeMap::<IncomingPort, OutgoingPort>::new();
for (inp_pos, &old_pos) in new_inputs.iter().enumerate() {
for &(node, port) in &old_inputs_incoming[old_pos] {
if node != out {
hugr.connect(inp, inp_pos, node, port);
} else {
old_output_to_new_input.insert(port, inp_pos.into());
/// Rewire the inputs and outputs of the DFG to modify its signature.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this docstring be updated to be more generic?

///
/// Reorder the outgoing resp. incoming wires at the input resp. output
/// node of the DFG to modify the signature of the DFG HUGR. This will
/// recursively update the signatures of all ancestors of the entrypoint.
///
/// ### Arguments
///
/// * `new_inputs`: The new input signature. After the map, the i-th input
/// wire will be connected to the ports connected to the
/// `new_inputs[i]`-th input of the old DFG.
/// * `new_outputs`: The new output signature. After the map, the i-th
/// output wire will be connected to the ports connected to the
/// `new_outputs[i]`-th output of the old DFG.
///
/// Returns an `InvalidSignature` error if the new_inputs and new_outputs
/// map are not valid signatures.
///
/// ### Panics
///
/// Panics if the DFG is not trivially nested, i.e. if there is an ancestor
/// DFG of the entrypoint that has more than one inner DFG.
pub fn map_function_type(
&mut self,
new_inputs: &[usize],
new_outputs: &[usize],
) -> Result<(), InvalidSignature> {
let [inp, out] = self.get_io();
let Self(hugr, _) = self;

// Record the old connections from and to the input and output nodes
let old_inputs_incoming = hugr
.node_outputs(inp)
.map(|p| hugr.linked_inputs(inp, p).collect_vec())
.collect_vec();
let old_outputs_outgoing = hugr
.node_inputs(out)
.map(|p| hugr.linked_outputs(out, p).collect_vec())
.collect_vec();

// The old signature types
let old_inp_sig = hugr
.get_optype(inp)
.dataflow_signature()
.expect("input has signature");
let old_inp_sig = old_inp_sig.output_types();
let old_out_sig = hugr
.get_optype(out)
.dataflow_signature()
.expect("output has signature");
let old_out_sig = old_out_sig.input_types();

// Check if the signature map is valid
check_valid_inputs(&old_inputs_incoming, old_inp_sig, new_inputs)?;
check_valid_outputs(old_out_sig, new_outputs)?;

// The new signature types
let new_inp_sig = new_inputs
.iter()
.map(|&i| old_inp_sig[i].clone())
.collect_vec();
let new_out_sig = new_outputs
.iter()
.map(|&i| old_out_sig[i].clone())
.collect_vec();
let new_sig = Signature::new(new_inp_sig, new_out_sig);

// Remove all edges of the input and output nodes
disconnect_all(hugr, inp);
disconnect_all(hugr, out);

// Update the signatures of the IO and their ancestors
let mut is_ancestor = false;
let mut node = hugr.entrypoint();
while matches!(hugr.get_optype(node), OpType::FuncDefn(_) | OpType::DFG(_)) {
let [inner_inp, inner_out] = hugr.get_io(node).expect("valid DFG graph");
for node in [node, inner_inp, inner_out] {
update_signature(hugr, node, &new_sig);
}
if is_ancestor {
update_inner_dfg_links(hugr, node);
}
if let Some(parent) = hugr.get_parent(node) {
node = parent;
is_ancestor = true;
} else {
break;
}
}

// Insert the new edges at the input
let mut old_output_to_new_input = BTreeMap::<IncomingPort, OutgoingPort>::new();
for (inp_pos, &old_pos) in new_inputs.iter().enumerate() {
for &(node, port) in &old_inputs_incoming[old_pos] {
if node != out {
hugr.connect(inp, inp_pos, node, port);
} else {
old_output_to_new_input.insert(port, inp_pos.into());
}
}
}
}
}

// Insert the new edges at the output
for (out_pos, &old_pos) in new_outputs.iter().enumerate() {
for &(node, port) in &old_outputs_outgoing[old_pos] {
if node != inp {
hugr.connect(node, port, out, out_pos);
} else {
let &inp_pos = old_output_to_new_input.get(&old_pos.into()).unwrap();
hugr.connect(inp, inp_pos, out, out_pos);
// Insert the new edges at the output
for (out_pos, &old_pos) in new_outputs.iter().enumerate() {
for &(node, port) in &old_outputs_outgoing[old_pos] {
if node != inp {
hugr.connect(node, port, out, out_pos);
} else {
let &inp_pos = old_output_to_new_input.get(&old_pos.into()).unwrap();
hugr.connect(inp, inp_pos, out, out_pos);
}
}
}

Ok(())
}
}

Ok(())
}
};
}

impl_dataflow_parent_methods!(DataflowParentID);
impl_dataflow_parent_methods!(DfgID);

/// Panics if the DFG within `node` is not a single inner DFG.
fn update_inner_dfg_links<H: HugrMut>(hugr: &mut H, node: H::Node) {
// connect all edges of the inner DFG to the input and output nodes
Expand Down Expand Up @@ -272,7 +283,7 @@ mod test {
};
use crate::extension::prelude::{bool_t, qb_t};
use crate::hugr::views::root_checked::RootChecked;
use crate::ops::handle::{DfgID, NodeHandle};
use crate::ops::handle::NodeHandle;
use crate::ops::{NamedOp, OpParent};
use crate::types::Signature;
use crate::utils::test_quantum_extension::cx_gate;
Expand All @@ -290,6 +301,51 @@ mod test {
let sig = Signature::new_endo(vec![qb_t(), qb_t()]);
let mut hugr = new_empty_dfg(sig);

// Wrap in RootChecked
let mut dfg_view = RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut hugr).unwrap();

// Test mapping inputs: [0,1] -> [1,0]
let input_map = vec![1, 0];
let output_map = vec![0, 1];

// Map the I/O
dfg_view.map_function_type(&input_map, &output_map).unwrap();

// Verify the new signature
let dfg_hugr = dfg_view.hugr();
let new_sig = dfg_hugr
.get_optype(dfg_hugr.entrypoint())
.dataflow_signature()
.unwrap();
assert_eq!(new_sig.input_count(), 2);
assert_eq!(new_sig.output_count(), 2);

// Test invalid mapping - missing input
let invalid_input_map = vec![0, 0];
let err = dfg_view.map_function_type(&invalid_input_map, &output_map);
assert!(matches!(err, Err(InvalidSignature::MissingIO(1, "input"))));

// Test invalid mapping - duplicate input
let invalid_input_map = vec![0, 0, 1];
assert!(matches!(
dfg_view.map_function_type(&invalid_input_map, &output_map),
Err(InvalidSignature::DuplicateInput(0))
));

// Test invalid mapping - unknown output
let invalid_output_map = vec![0, 2];
assert!(matches!(
dfg_view.map_function_type(&input_map, &invalid_output_map),
Err(InvalidSignature::UnknownIO(2, "output"))
));
}

#[test]
fn test_map_io_dfg_id() {
// Create a DFG with 2 inputs and 2 outputs
let sig = Signature::new_endo(vec![qb_t(), qb_t()]);
let mut hugr = new_empty_dfg(sig);

// Wrap in RootChecked
let mut dfg_view = RootChecked::<&mut Hugr, DfgID>::try_new(&mut hugr).unwrap();

Expand Down Expand Up @@ -337,7 +393,7 @@ mod test {
let mut hugr = new_empty_dfg(sig);

// Wrap in RootChecked
let mut dfg_view = RootChecked::<&mut Hugr, DfgID>::try_new(&mut hugr).unwrap();
let mut dfg_view = RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut hugr).unwrap();

// Test mapping outputs: [0] -> [0,0] (duplicating the output)
let input_map = vec![0];
Expand Down Expand Up @@ -377,7 +433,7 @@ mod test {
.unwrap();

// Wrap in RootChecked
let mut dfg_view = RootChecked::<&mut Hugr, DfgID>::try_new(&mut hugr).unwrap();
let mut dfg_view = RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut hugr).unwrap();

// Test mapping inputs: [0,1] -> [1,0] (swapping inputs)
let input_map = vec![1, 0];
Expand Down
Loading