Skip to content

Commit 054f8d1

Browse files
committed
Import and export order edges via metadata.
1 parent 2e69e9a commit 054f8d1

File tree

7 files changed

+359
-51
lines changed

7 files changed

+359
-51
lines changed

hugr-core/src/export.rs

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//! Exporting HUGR graphs to their `hugr-model` representation.
22
use crate::{
33
extension::{ExtensionId, OpDef, SignatureFunc},
4-
hugr::{IdentList, NodeMetadataMap},
4+
hugr::IdentList,
55
ops::{
66
constant::CustomSerialized, DataflowBlock, DataflowOpTrait, OpName, OpTrait, OpType, Value,
77
},
@@ -12,10 +12,10 @@ use crate::{
1212
types::{
1313
type_param::{TypeArgVariable, TypeParam},
1414
type_row::TypeRowBase,
15-
CustomType, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType, TypeArg,
16-
TypeBase, TypeBound, TypeEnum, TypeRow,
15+
CustomType, EdgeKind, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType,
16+
TypeArg, TypeBase, TypeBound, TypeEnum, TypeRow,
1717
},
18-
Direction, Hugr, HugrView, IncomingPort, Node, Port,
18+
Direction, Hugr, HugrView, IncomingPort, Node, NodeIndex as _, Port,
1919
};
2020

2121
use fxhash::{FxBuildHasher, FxHashMap};
@@ -467,10 +467,7 @@ impl<'a> Context<'a> {
467467
let inputs = self.make_ports(node, Direction::Incoming, num_inputs);
468468
let outputs = self.make_ports(node, Direction::Outgoing, num_outputs);
469469

470-
let meta = match self.hugr.get_node_metadata(node) {
471-
Some(metadata_map) => self.export_node_metadata(metadata_map),
472-
None => &[],
473-
};
470+
let meta = self.export_node_metadata(node);
474471

475472
self.module.nodes[node_id.index()] = table::Node {
476473
operation,
@@ -577,6 +574,7 @@ impl<'a> Context<'a> {
577574
let mut targets: &[_] = &[];
578575
let mut input_types = None;
579576
let mut output_types = None;
577+
let mut meta = Vec::new();
580578

581579
let children = self.hugr.children(node);
582580
let mut region_children = BumpVec::with_capacity_in(children.size_hint().0 - 2, self.bump);
@@ -591,9 +589,25 @@ impl<'a> Context<'a> {
591589
targets = self.make_ports(child, Direction::Incoming, output.types.len());
592590
output_types = Some(&output.types);
593591
}
594-
_ => {
592+
child_optype => {
595593
if let Some(child_id) = self.export_node_shallow(child) {
596594
region_children.push(child_id);
595+
596+
// Record all order edges that originate from this node in metadata.
597+
let successors = child_optype
598+
.other_output_port()
599+
.into_iter()
600+
.flat_map(|port| self.hugr.linked_inputs(child, port))
601+
.map(|(successor, _)| successor)
602+
.filter(|successor| !self.hugr.get_optype(*successor).is_output());
603+
604+
for successor in successors {
605+
let a =
606+
self.make_term(model::Literal::Nat(child.index() as u64).into());
607+
let b = self
608+
.make_term(model::Literal::Nat(successor.index() as u64).into());
609+
meta.push(self.make_term_apply(model::ORDER_HINT_ORDER, &[a, b]));
610+
}
597611
}
598612
}
599613
}
@@ -623,7 +637,7 @@ impl<'a> Context<'a> {
623637
sources,
624638
targets,
625639
children: region_children.into_bump_slice(),
626-
meta: &[], // TODO: Export metadata
640+
meta: self.bump.alloc_slice_copy(&meta),
627641
signature,
628642
scope,
629643
};
@@ -1002,11 +1016,37 @@ impl<'a> Context<'a> {
10021016
}
10031017
}
10041018

1005-
pub fn export_node_metadata(&mut self, metadata_map: &NodeMetadataMap) -> &'a [table::TermId] {
1006-
let mut meta = BumpVec::with_capacity_in(metadata_map.len(), self.bump);
1019+
pub fn export_node_metadata(&mut self, node: Node) -> &'a [table::TermId] {
1020+
let metadata_map = self.hugr.get_node_metadata(node);
1021+
1022+
let has_order_edges = {
1023+
fn is_relevant_node(hugr: &Hugr, node: Node) -> bool {
1024+
let optype = hugr.get_optype(node);
1025+
!optype.is_input() && !optype.is_output()
1026+
}
1027+
1028+
let optype = self.hugr.get_optype(node);
1029+
1030+
Direction::BOTH
1031+
.iter()
1032+
.filter(|dir| optype.other_port_kind(**dir) == Some(EdgeKind::StateOrder))
1033+
.filter_map(|dir| optype.other_port(*dir))
1034+
.flat_map(|port| self.hugr.linked_ports(node, port))
1035+
.any(|(other, _)| is_relevant_node(self.hugr, other))
1036+
};
1037+
1038+
let meta_capacity = metadata_map.map_or(0, |map| map.len()) + has_order_edges as usize;
1039+
let mut meta = BumpVec::with_capacity_in(meta_capacity, self.bump);
1040+
1041+
if let Some(metadata_map) = metadata_map {
1042+
for (name, value) in metadata_map {
1043+
meta.push(self.export_json_meta(name, value));
1044+
}
1045+
}
10071046

1008-
for (name, value) in metadata_map {
1009-
meta.push(self.export_json_meta(name, value));
1047+
if has_order_edges {
1048+
let key = self.make_term(model::Literal::Nat(node.index() as u64).into());
1049+
meta.push(self.make_term_apply(model::ORDER_HINT_KEY, &[key]));
10101050
}
10111051

10121052
meta.into_bump_slice()

hugr-core/src/import.rs

Lines changed: 94 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use std::sync::Arc;
77

88
use crate::{
99
extension::{ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError},
10-
hugr::HugrMut,
10+
hugr::{HugrMut, IdentList, NodeMetadata},
1111
ops::{
1212
constant::{CustomConst, CustomSerialized, OpaqueValue},
1313
AliasDecl, AliasDefn, Call, CallIndirect, Case, Conditional, Const, DataflowBlock,
@@ -180,11 +180,27 @@ impl<'a> Context<'a> {
180180
self.record_links(node, Direction::Incoming, node_data.inputs);
181181
self.record_links(node, Direction::Outgoing, node_data.outputs);
182182

183+
// Import the JSON metadata
183184
for meta_item in node_data.meta {
184-
// TODO: For now we expect all metadata to be JSON since this is how
185-
// it is handled in `hugr-core`.
186-
let (name, value) = self.import_json_meta(*meta_item)?;
187-
self.hugr.set_metadata(node, name, value);
185+
let Some([name_arg, json_arg]) =
186+
self.match_symbol(*meta_item, model::COMPAT_META_JSON)?
187+
else {
188+
continue;
189+
};
190+
191+
let table::Term::Literal(model::Literal::Str(name)) = self.get_term(name_arg)? else {
192+
return Err(table::ModelError::TypeError(*meta_item).into());
193+
};
194+
195+
let table::Term::Literal(model::Literal::Str(json_str)) = self.get_term(json_arg)?
196+
else {
197+
return Err(table::ModelError::TypeError(*meta_item).into());
198+
};
199+
200+
let json_value: NodeMetadata = serde_json::from_str(json_str)
201+
.map_err(|_| table::ModelError::TypeError(*meta_item))?;
202+
203+
self.hugr.set_metadata(node, name, json_value);
188204
}
189205

190206
Ok(node)
@@ -600,11 +616,84 @@ impl<'a> Context<'a> {
600616
self.import_node(*child, node)?;
601617
}
602618

619+
self.create_order_edges(region)?;
620+
603621
self.region_scope = prev_region;
604622

605623
Ok(())
606624
}
607625

626+
/// Create order edges between nodes of a dataflow region based on order hint metadata.
627+
///
628+
/// This method assumes that the nodes for the children of the region have already been imported.
629+
fn create_order_edges(&mut self, region_id: table::RegionId) -> Result<(), ImportError> {
630+
let region_data = self.get_region(region_id)?;
631+
debug_assert_eq!(region_data.kind, model::RegionKind::DataFlow);
632+
633+
// Collect order hint keys
634+
// PERFORMANCE: It might be worthwhile to reuse the map to avoid allocations.
635+
let mut order_keys = FxHashMap::<u64, table::NodeId>::default();
636+
637+
for child_id in region_data.children {
638+
let child_data = self.get_node(*child_id)?;
639+
640+
for meta_id in child_data.meta {
641+
let Some([key]) = self.match_symbol(*meta_id, model::ORDER_HINT_KEY)? else {
642+
continue;
643+
};
644+
645+
let table::Term::Literal(model::Literal::Nat(key)) = self.get_term(key)? else {
646+
continue;
647+
};
648+
649+
// TODO: Error on duplicate key
650+
order_keys.insert(*key, *child_id);
651+
}
652+
}
653+
654+
// Insert order edges
655+
for meta_id in region_data.meta {
656+
let Some([a, b]) = self.match_symbol(*meta_id, model::ORDER_HINT_ORDER)? else {
657+
continue;
658+
};
659+
660+
let table::Term::Literal(model::Literal::Nat(a)) = self.get_term(a)? else {
661+
continue;
662+
};
663+
664+
let table::Term::Literal(model::Literal::Nat(b)) = self.get_term(b)? else {
665+
continue;
666+
};
667+
668+
// TODO: Proper error on non-existing key
669+
let a = order_keys.get(a).expect("unknown order key");
670+
let b = order_keys.get(b).expect("unknown order key");
671+
672+
// NOTE: The lookups here are expected to succeed since we only
673+
// process the order metadata after we have imported the nodes.
674+
let a_node = self.nodes[a];
675+
let b_node = self.nodes[b];
676+
677+
// Find the order ports
678+
// TODO: Proper error on non-existing order port
679+
let a_port = self
680+
.hugr
681+
.get_optype(a_node)
682+
.other_output_port()
683+
.expect("order hint on node without order port");
684+
685+
let b_port = self
686+
.hugr
687+
.get_optype(b_node)
688+
.other_input_port()
689+
.expect("order hint on node without order port");
690+
691+
self.hugr.connect(a_node, a_port, b_node, b_port);
692+
}
693+
694+
Ok(())
695+
}
696+
608697
fn import_adt_and_rest(
609698
&mut self,
610699
node_id: table::NodeId,
@@ -1341,28 +1430,6 @@ impl<'a> Context<'a> {
13411430
}
13421431
}
13431432

1344-
fn import_json_meta(
1345-
&mut self,
1346-
term_id: table::TermId,
1347-
) -> Result<(&'a str, serde_json::Value), ImportError> {
1348-
let [name_arg, json_arg] = self
1349-
.match_symbol(term_id, model::COMPAT_META_JSON)?
1350-
.ok_or(table::ModelError::TypeError(term_id))?;
1351-
1352-
let table::Term::Literal(model::Literal::Str(name)) = self.get_term(name_arg)? else {
1353-
return Err(table::ModelError::TypeError(term_id).into());
1354-
};
1355-
1356-
let table::Term::Literal(model::Literal::Str(json_str)) = self.get_term(json_arg)? else {
1357-
return Err(table::ModelError::TypeError(term_id).into());
1358-
};
1359-
1360-
let json_value =
1361-
serde_json::from_str(json_str).map_err(|_| table::ModelError::TypeError(term_id))?;
1362-
1363-
Ok((name, json_value))
1364-
}
1365-
13661433
fn import_value(
13671434
&mut self,
13681435
term_id: table::TermId,

hugr-core/tests/model.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,10 @@ pub fn test_roundtrip_const() {
7777
"../../hugr-model/tests/fixtures/model-const.edn"
7878
)));
7979
}
80+
81+
#[test]
82+
pub fn test_roundtrip_order() {
83+
insta::assert_snapshot!(roundtrip(include_str!(
84+
"../../hugr-model/tests/fixtures/model-order.edn"
85+
)));
86+
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
---
2+
source: hugr-core/tests/model.rs
3+
expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-order.edn\"))"
4+
---
5+
(hugr 0)
6+
7+
(import order_hint.order)
8+
9+
(import order_hint.key)
10+
11+
(import core.fn)
12+
13+
(import arithmetic.int.types.int)
14+
15+
(import arithmetic.int.ineg)
16+
17+
(define-func
18+
main
19+
(core.fn
20+
[arithmetic.int.types.int
21+
arithmetic.int.types.int
22+
arithmetic.int.types.int
23+
arithmetic.int.types.int]
24+
[arithmetic.int.types.int
25+
arithmetic.int.types.int
26+
arithmetic.int.types.int
27+
arithmetic.int.types.int]
28+
(ext))
29+
(dfg [%0 %1 %2 %3] [%4 %5 %6 %7]
30+
(signature
31+
(core.fn
32+
[arithmetic.int.types.int
33+
arithmetic.int.types.int
34+
arithmetic.int.types.int
35+
arithmetic.int.types.int]
36+
[arithmetic.int.types.int
37+
arithmetic.int.types.int
38+
arithmetic.int.types.int
39+
arithmetic.int.types.int]
40+
(ext)))
41+
(meta (order_hint.order 4 7))
42+
(meta (order_hint.order 5 6))
43+
(meta (order_hint.order 5 4))
44+
(meta (order_hint.order 6 7))
45+
(arithmetic.int.ineg [%0] [%4]
46+
(signature
47+
(core.fn [arithmetic.int.types.int] [arithmetic.int.types.int] (ext)))
48+
(meta (order_hint.key 4)))
49+
(arithmetic.int.ineg [%1] [%5]
50+
(signature
51+
(core.fn [arithmetic.int.types.int] [arithmetic.int.types.int] (ext)))
52+
(meta (order_hint.key 5)))
53+
(arithmetic.int.ineg [%2] [%6]
54+
(signature
55+
(core.fn [arithmetic.int.types.int] [arithmetic.int.types.int] (ext)))
56+
(meta (order_hint.key 6)))
57+
(arithmetic.int.ineg [%3] [%7]
58+
(signature
59+
(core.fn [arithmetic.int.types.int] [arithmetic.int.types.int] (ext)))
60+
(meta (order_hint.key 7)))))

hugr-model/src/v0/mod.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,33 @@ pub const COMPAT_META_JSON: &str = "compat.meta_json";
265265
/// - **Result:** `(core.const ?type)`
266266
pub const COMPAT_CONST_JSON: &str = "compat.const_json";
267267

268+
/// Metadata constructor for order hint keys.
269+
///
270+
/// Nodes in a dataflow region can be annotated with a key. Each node may have
271+
/// at most one key and the key must be unique among all nodes in the same
272+
/// dataflow region. The parent dataflow graph can then use the
273+
/// `order_hint.order` metadata to imply a desired ordering relation, referring
274+
/// to the nodes by their key.
275+
///
276+
/// - **Parameter:** `?key : core.nat`
277+
/// - **Result:** `core.meta`
278+
pub const ORDER_HINT_KEY: &str = "order_hint.key";
279+
280+
/// Metadata constructor for order hints.
281+
///
282+
/// When this metadata is attached to a dataflow region, it can indicate a
283+
/// preferred ordering relation between child nodes. Code generation must take
284+
/// this into account when deciding on an execution order. The child nodes are
285+
/// identified by a key, using the `order_hint.key` metadata.
286+
///
287+
/// The graph consisting of both value dependencies between nodes and order
288+
/// hints must be directed acyclic.
289+
///
290+
/// - **Parameter:** `?before : core.nat`
291+
/// - **Parameter:** `?after : core.nat`
292+
/// - **Result:** `core.meta`
293+
pub const ORDER_HINT_ORDER: &str = "order_hint.order";
294+
268295
pub mod ast;
269296
pub mod binary;
270297
pub mod scope;

0 commit comments

Comments
 (0)