Skip to content

Commit d80e05d

Browse files
committed
Import input and output order hint keys.
1 parent eaae72d commit d80e05d

1 file changed

Lines changed: 51 additions & 23 deletions

File tree

hugr-core/src/import.rs

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ impl From<ExtensionError> for ImportError {
114114
enum OrderHintError {
115115
/// Duplicate order hint key in the same region.
116116
#[error("duplicate order hint key {0}")]
117-
DuplicateKey(table::NodeId, u64),
117+
DuplicateKey(table::RegionId, u64),
118118
/// Order hint including a key not defined in the region.
119119
#[error("order hint with unknown key {0}")]
120120
UnknownKey(u64),
@@ -608,7 +608,7 @@ impl<'a> Context<'a> {
608608
self.import_node(*child, node)?;
609609
}
610610

611-
self.create_order_edges(region)?;
611+
self.create_order_edges(region, input, output)?;
612612

613613
for meta_item in region_data.meta {
614614
self.import_node_metadata(node, *meta_item)?;
@@ -622,13 +622,18 @@ impl<'a> Context<'a> {
622622
/// Create order edges between nodes of a dataflow region based on order hint metadata.
623623
///
624624
/// This method assumes that the nodes for the children of the region have already been imported.
625-
fn create_order_edges(&mut self, region_id: table::RegionId) -> Result<(), ImportError> {
625+
fn create_order_edges(
626+
&mut self,
627+
region_id: table::RegionId,
628+
input: Node,
629+
output: Node,
630+
) -> Result<(), ImportError> {
626631
let region_data = self.get_region(region_id)?;
627632
debug_assert_eq!(region_data.kind, model::RegionKind::DataFlow);
628633

629634
// Collect order hint keys
630635
// PERFORMANCE: It might be worthwhile to reuse the map to avoid allocations.
631-
let mut order_keys = FxHashMap::<u64, table::NodeId>::default();
636+
let mut order_keys = FxHashMap::<u64, Node>::default();
632637

633638
for child_id in region_data.children {
634639
let child_data = self.get_node(*child_id)?;
@@ -642,8 +647,42 @@ impl<'a> Context<'a> {
642647
continue;
643648
};
644649

645-
if order_keys.insert(*key, *child_id).is_some() {
646-
return Err(OrderHintError::DuplicateKey(*child_id, *key).into());
650+
// NOTE: The lookups here are expected to succeed since we only
651+
// process the order metadata after we have imported the nodes.
652+
let child_node = self.nodes[child_id];
653+
let child_optype = self.hugr.get_optype(child_node);
654+
655+
// Check that the node has order ports.
656+
// NOTE: This assumes that a node has an input order port iff it has an output one.
657+
if !child_optype.other_output_port().is_some() {
658+
return Err(OrderHintError::NoOrderPort(*child_id).into());
659+
}
660+
661+
if order_keys.insert(*key, child_node).is_some() {
662+
return Err(OrderHintError::DuplicateKey(region_id, *key).into());
663+
}
664+
}
665+
}
666+
667+
// Collect the order hint keys for the input and output nodes
668+
for meta_id in region_data.meta {
669+
if let Some([key]) = self.match_symbol(*meta_id, model::ORDER_HINT_INPUT_KEY)? {
670+
let table::Term::Literal(model::Literal::Nat(key)) = self.get_term(key)? else {
671+
continue;
672+
};
673+
674+
if order_keys.insert(*key, input).is_some() {
675+
return Err(OrderHintError::DuplicateKey(region_id, *key).into());
676+
}
677+
}
678+
679+
if let Some([key]) = self.match_symbol(*meta_id, model::ORDER_HINT_OUTPUT_KEY)? {
680+
let table::Term::Literal(model::Literal::Nat(key)) = self.get_term(key)? else {
681+
continue;
682+
};
683+
684+
if order_keys.insert(*key, output).is_some() {
685+
return Err(OrderHintError::DuplicateKey(region_id, *key).into());
647686
}
648687
}
649688
}
@@ -665,24 +704,13 @@ impl<'a> Context<'a> {
665704
let a = order_keys.get(a).ok_or(OrderHintError::UnknownKey(*a))?;
666705
let b = order_keys.get(b).ok_or(OrderHintError::UnknownKey(*b))?;
667706

668-
// NOTE: The lookups here are expected to succeed since we only
669-
// process the order metadata after we have imported the nodes.
670-
let a_node = self.nodes[a];
671-
let b_node = self.nodes[b];
672-
673-
let a_port = self
674-
.hugr
675-
.get_optype(a_node)
676-
.other_output_port()
677-
.ok_or(OrderHintError::NoOrderPort(*a))?;
678-
679-
let b_port = self
680-
.hugr
681-
.get_optype(b_node)
682-
.other_input_port()
683-
.ok_or(OrderHintError::NoOrderPort(*b))?;
707+
// NOTE: The unwrap here must succeed:
708+
// - For all ordinary nodes we checked that they have an order port.
709+
// - Input and output nodes always have an order port.
710+
let a_port = self.hugr.get_optype(*a).other_output_port().unwrap();
711+
let b_port = self.hugr.get_optype(*b).other_input_port().unwrap();
684712

685-
self.hugr.connect(a_node, a_port, b_node, b_port);
713+
self.hugr.connect(*a, a_port, *b, b_port);
686714
}
687715

688716
Ok(())

0 commit comments

Comments
 (0)