Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions hugr-core/src/import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ use crate::{
},
};
use fxhash::FxHashMap;
use hugr_model::v0 as model;
use hugr_model::v0::table;
use hugr_model::v0::{self as model};
use itertools::{Either, Itertools};
use smol_str::{SmolStr, ToSmolStr};
use thiserror::Error;
Expand Down Expand Up @@ -433,7 +433,7 @@ impl<'a> Context<'a> {
let region_data = self.get_region(self.module.root)?;

for node in region_data.children {
self.import_node(*node, self.hugr.entrypoint())?;
self.import_node(*node, self.hugr.module_root())?;
}

for meta_item in region_data.meta {
Expand Down
15 changes: 12 additions & 3 deletions hugr-py/src/hugr/model/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ def __init__(self, hugr: Hugr):
self.link_names: dict[InPort | OutPort, str] = {}
self.link_next = 0

# TODO: Store the hugr entrypoint

for a, b in self.hugr.links():
self.link_ports.union(a, b)

Expand Down Expand Up @@ -397,10 +395,18 @@ def export_json_meta(self, node: Node) -> list[model.Term]:

return meta

def export_entrypoint_meta(self, node: Node) -> list[model.Term]:
"""Export entrypoint metadata if the node is the module entrypoint."""
if self.hugr.entrypoint == node:
return [model.Apply("core.entrypoint")]
else:
return []

def export_region_module(self, node: Node) -> model.Region:
"""Export a module node as a module region."""
node_data = self.hugr[node]
meta = self.export_json_meta(node)
meta += self.export_entrypoint_meta(node)
children = []

for child in node_data.children:
Expand All @@ -419,7 +425,8 @@ def export_region_dfg(self, node: Node) -> model.Region:
target_types: model.Term = model.Wildcard()
sources = []
targets = []
meta = []

meta = self.export_entrypoint_meta(node)

for child in node_data.children:
child_data = self.hugr[child]
Expand Down Expand Up @@ -489,6 +496,7 @@ def export_region_cfg(self, node: Node) -> model.Region:
source_types: model.Term = model.Wildcard()
target_types: model.Term = model.Wildcard()
children = []
meta = self.export_entrypoint_meta(node)

for child in node_data.children:
child_data = self.hugr[child]
Expand Down Expand Up @@ -540,6 +548,7 @@ def export_region_cfg(self, node: Node) -> model.Region:
sources=[source],
signature=signature,
children=children,
meta=meta,
)

def export_symbol(
Expand Down
Loading