Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions hugr-core/src/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ impl<'a> Context<'a> {
node,
model::ScopeClosure::Open,
false,
false,
)]);
table::Operation::Dfg
}
Expand All @@ -333,6 +334,7 @@ impl<'a> Context<'a> {
node,
model::ScopeClosure::Open,
false,
false,
)]);
table::Operation::Block
}
Expand All @@ -349,6 +351,7 @@ impl<'a> Context<'a> {
node,
model::ScopeClosure::Closed,
false,
false,
)]);
table::Operation::DefineFunc(symbol)
}),
Expand Down Expand Up @@ -462,6 +465,7 @@ impl<'a> Context<'a> {
node,
model::ScopeClosure::Open,
false,
false,
)]);
table::Operation::TailLoop
}
Expand Down Expand Up @@ -515,6 +519,7 @@ impl<'a> Context<'a> {

self.export_node_json_metadata(node, &mut meta);
self.export_node_order_metadata(node, &mut meta);
self.export_node_entrypoint_metadata(node, &mut meta);
let meta = self.bump.alloc_slice_copy(&meta);

self.module.nodes[node_id.index()] = table::Node {
Expand Down Expand Up @@ -613,6 +618,7 @@ impl<'a> Context<'a> {
node: Node,
closure: model::ScopeClosure,
export_json_meta: bool,
export_entrypoint_meta: bool,
) -> table::RegionId {
let region = self.module.insert_region(table::Region::default());

Expand All @@ -631,7 +637,9 @@ impl<'a> Context<'a> {
if export_json_meta {
self.export_node_json_metadata(node, &mut meta);
}
self.export_node_entrypoint_metadata(node, &mut meta);
if export_entrypoint_meta {
self.export_node_entrypoint_metadata(node, &mut meta);
}

let children = self.hugr.children(node);
let mut region_children = BumpVec::with_capacity_in(children.size_hint().0 - 2, self.bump);
Expand Down Expand Up @@ -801,7 +809,7 @@ impl<'a> Context<'a> {
panic!("expected a `Case` node as a child of a `Conditional` node");
};

regions.push(self.export_dfg(child, model::ScopeClosure::Open, true));
regions.push(self.export_dfg(child, model::ScopeClosure::Open, true, true));
}

regions.into_bump_slice()
Expand Down Expand Up @@ -1076,7 +1084,7 @@ impl<'a> Context<'a> {

let region = match hugr.entrypoint_optype() {
OpType::DFG(_) => {
self.export_dfg(hugr.entrypoint(), model::ScopeClosure::Closed, true)
self.export_dfg(hugr.entrypoint(), model::ScopeClosure::Closed, true, true)
}
_ => panic!("Value::Function root must be a DFG"),
};
Expand Down
7 changes: 5 additions & 2 deletions hugr-core/tests/snapshots/model__roundtrip_entrypoint.snap
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ expression: ast
(import core.entrypoint)

(define-func public main (core.fn [] [])
(dfg (signature (core.fn [] [])) (meta core.entrypoint)))
(meta core.entrypoint)
(dfg (signature (core.fn [] []))))

(mod)

Expand All @@ -20,7 +21,8 @@ expression: ast
(import core.entrypoint)

(define-func public wrapper_dfg (core.fn [] [])
(dfg (signature (core.fn [] [])) (meta core.entrypoint)))
(meta core.entrypoint)
(dfg (signature (core.fn [] []))))

(mod)

Expand All @@ -39,6 +41,7 @@ expression: ast
(signature (core.fn [] []))
(cfg
(signature (core.fn [] []))
(meta core.entrypoint)
(cfg [%0] [%1]
(signature (core.ctrl [[]] [[]]))
(meta core.entrypoint)
Expand Down
10 changes: 7 additions & 3 deletions hugr-py/src/hugr/model/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def export_node(

outputs = [self.link_name(OutPort(node, i)) for i in range(node_data._num_outs)]
meta = self.export_json_meta(node)
meta += self.export_entrypoint_meta(node)

# Add an order hint key to the node if necessary
if _has_order_links(self.hugr, node):
Expand Down Expand Up @@ -123,7 +124,8 @@ def export_node(

case Conditional() as op:
regions = [
self.export_region_dfg(child) for child in node_data.children
self.export_region_dfg(child, entrypoint_meta=True)
for child in node_data.children
]

signature = op.outer_signature().to_model()
Expand Down Expand Up @@ -424,16 +426,18 @@ def export_region_module(self, node: Node) -> model.Region:

return model.Region(kind=model.RegionKind.MODULE, children=children, meta=meta)

def export_region_dfg(self, node: Node) -> model.Region:
def export_region_dfg(self, node: Node, entrypoint_meta=False) -> model.Region:
"""Export the children of a node as a dataflow region."""
node_data = self.hugr[node]
children: list[model.Node] = []
source_types: model.Term = model.Wildcard()
target_types: model.Term = model.Wildcard()
sources = []
targets = []
meta = []

meta = self.export_entrypoint_meta(node)
if entrypoint_meta:
meta += self.export_entrypoint_meta(node)

for child in node_data.children:
child_data = self.hugr[child]
Expand Down
Loading