Skip to content

Commit cd1fef9

Browse files
committed
Export order edges via metadata in Python.
1 parent e4363b6 commit cd1fef9

File tree

1 file changed

+67
-10
lines changed

1 file changed

+67
-10
lines changed

hugr-py/src/hugr/model/export.py

Lines changed: 67 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Helpers to export hugr graphs from their python representation to hugr model."""
22

3+
import json
34
from collections.abc import Sequence
45
from typing import Generic, TypeVar, cast
56

@@ -59,6 +60,22 @@ def export_node(self, node: Node) -> model.Node | None:
5960

6061
inputs = [self.link_name(InPort(node, i)) for i in range(node_data._num_inps)]
6162
outputs = [self.link_name(OutPort(node, i)) for i in range(node_data._num_outs)]
63+
meta = []
64+
65+
# Export JSON metadata
66+
for meta_name, meta_value in node_data.metadata.items():
67+
# TODO: Is this the correct way to convert the metadata as JSON?
68+
meta_json = json.dumps(meta_value)
69+
meta.append(
70+
model.Apply(
71+
"compat.meta_json",
72+
[model.Literal(meta_name), model.Literal(meta_json)],
73+
)
74+
)
75+
76+
# Add an order hint key to the node if necessary
77+
if _needs_order_key(self.hugr, node):
78+
meta.append(model.Apply("order_hint.key", [model.Literal(node.idx)]))
6279

6380
match node_data.op:
6481
case DFG() as op:
@@ -70,6 +87,7 @@ def export_node(self, node: Node) -> model.Node | None:
7087
signature=op.outer_signature().to_model(),
7188
inputs=inputs,
7289
outputs=outputs,
90+
meta=meta,
7391
)
7492

7593
case Custom() as op:
@@ -82,6 +100,7 @@ def export_node(self, node: Node) -> model.Node | None:
82100
signature=signature,
83101
inputs=inputs,
84102
outputs=outputs,
103+
meta=meta,
85104
)
86105

87106
case AsExtOp() as op:
@@ -96,6 +115,7 @@ def export_node(self, node: Node) -> model.Node | None:
96115
signature=signature,
97116
inputs=inputs,
98117
outputs=outputs,
118+
meta=meta,
99119
)
100120

101121
case Conditional() as op:
@@ -111,6 +131,7 @@ def export_node(self, node: Node) -> model.Node | None:
111131
signature=signature,
112132
inputs=inputs,
113133
outputs=outputs,
134+
meta=meta,
114135
)
115136

116137
case TailLoop() as op:
@@ -122,6 +143,7 @@ def export_node(self, node: Node) -> model.Node | None:
122143
signature=signature,
123144
inputs=inputs,
124145
outputs=outputs,
146+
meta=meta,
125147
)
126148

127149
case FuncDefn() as op:
@@ -132,30 +154,29 @@ def export_node(self, node: Node) -> model.Node | None:
132154
region = self.export_region_dfg(node)
133155

134156
return model.Node(
135-
operation=model.DefineFunc(symbol),
136-
regions=[region],
157+
operation=model.DefineFunc(symbol), regions=[region], meta=meta
137158
)
138159

139160
case FuncDecl() as op:
140161
name = _mangle_name(node, op.f_name)
141162
symbol = self.export_symbol(
142163
name, op.signature.params, op.signature.body
143164
)
144-
return model.Node(
145-
operation=model.DeclareFunc(symbol),
146-
)
165+
return model.Node(operation=model.DeclareFunc(symbol), meta=meta)
147166

148167
case AliasDecl() as op:
149168
symbol = model.Symbol(name=op.alias, signature=model.Apply("core.type"))
150169

151-
return model.Node(operation=model.DeclareAlias(symbol))
170+
return model.Node(operation=model.DeclareAlias(symbol), meta=meta)
152171

153172
case AliasDefn() as op:
154173
symbol = model.Symbol(name=op.alias, signature=model.Apply("core.type"))
155174

156175
alias_value = cast(model.Term, op.definition.to_model())
157176

158-
return model.Node(operation=model.DefineAlias(symbol, alias_value))
177+
return model.Node(
178+
operation=model.DefineAlias(symbol, alias_value), meta=meta
179+
)
159180

160181
case Call() as op:
161182
input_types = [type.to_model() for type in op.instantiation.input]
@@ -187,6 +208,7 @@ def export_node(self, node: Node) -> model.Node | None:
187208
signature=signature,
188209
inputs=inputs,
189210
outputs=outputs,
211+
meta=meta,
190212
)
191213

192214
case LoadFunc() as op:
@@ -211,6 +233,7 @@ def export_node(self, node: Node) -> model.Node | None:
211233
signature=signature,
212234
inputs=inputs,
213235
outputs=outputs,
236+
meta=meta,
214237
)
215238

216239
case CallIndirect() as op:
@@ -245,6 +268,7 @@ def export_node(self, node: Node) -> model.Node | None:
245268
signature=signature,
246269
inputs=inputs,
247270
outputs=outputs,
271+
meta=meta,
248272
)
249273

250274
case LoadConst() as op:
@@ -264,6 +288,7 @@ def export_node(self, node: Node) -> model.Node | None:
264288
signature=signature,
265289
inputs=inputs,
266290
outputs=outputs,
291+
meta=meta,
267292
)
268293

269294
case Const() as op:
@@ -273,13 +298,13 @@ def export_node(self, node: Node) -> model.Node | None:
273298
signature = op.outer_signature().to_model()
274299
region = self.export_region_cfg(node)
275300

276-
# TODO: Export CFGs
277301
return model.Node(
278302
operation=model.Cfg(),
279303
signature=signature,
280304
inputs=inputs,
281305
outputs=outputs,
282306
regions=[region],
307+
meta=meta,
283308
)
284309

285310
case DataflowBlock() as op:
@@ -319,6 +344,7 @@ def export_node(self, node: Node) -> model.Node | None:
319344
outputs=outputs,
320345
regions=[region],
321346
signature=signature,
347+
meta=meta,
322348
)
323349

324350
case Tag() as op:
@@ -343,6 +369,7 @@ def export_node(self, node: Node) -> model.Node | None:
343369
inputs=inputs,
344370
outputs=outputs,
345371
signature=signature,
372+
meta=meta,
346373
)
347374

348375
case op:
@@ -370,6 +397,7 @@ def export_region_dfg(self, node: Node) -> model.Region:
370397
target_types: model.Term = model.Wildcard()
371398
sources = []
372399
targets = []
400+
meta = []
373401

374402
for child in node_data.children:
375403
child_data = self.hugr[child]
@@ -392,8 +420,19 @@ def export_region_dfg(self, node: Node) -> model.Region:
392420
case _:
393421
child_node = self.export_node(child)
394422

395-
if child_node is not None:
396-
children.append(child_node)
423+
if child_node is None:
424+
continue
425+
426+
children.append(child_node)
427+
428+
meta += [
429+
model.Apply(
430+
"order_hint.order",
431+
[model.Literal(child.idx), model.Literal(successor.idx)],
432+
)
433+
for successor in self.hugr.outgoing_order_links(child)
434+
if not isinstance(self.hugr[successor].op, Output)
435+
]
397436

398437
signature = model.Apply("core.fn", [source_types, target_types, model.ExtSet()])
399438

@@ -564,3 +603,21 @@ def union(self, a: T, b: T):
564603

565604
self.parents[b] = a
566605
self.sizes[a] += self.sizes[b]
606+
607+
608+
def _needs_order_key(hugr: Hugr, node: Node) -> bool:
609+
"""Checks whether the node has any order links for the purposes of
610+
exporting order hint metadata. Order links to `Input` or `Output`
611+
operations are ignored, since they are not present in the model format.
612+
"""
613+
for succ in hugr.outgoing_order_links(node):
614+
succ_op = hugr[succ].op
615+
if not isinstance(succ_op, Output):
616+
return True
617+
618+
for pred in hugr.incoming_order_links(node):
619+
pred_op = hugr[pred].op
620+
if not isinstance(pred_op, Input):
621+
return True
622+
623+
return False

0 commit comments

Comments
 (0)