Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions hugr-core/src/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,6 @@ impl<'a> Context<'a> {
pub fn export_block_signature(&mut self, block: &DataflowBlock) -> table::TermId {
let inputs = {
let inputs = self.export_type_row(&block.inputs);
let inputs = self.make_term_apply(model::CORE_CTRL, &[inputs]);
self.make_term(table::Term::List(
self.bump.alloc_slice_copy(&[table::SeqPart::Item(inputs)]),
))
Expand All @@ -590,13 +589,12 @@ impl<'a> Context<'a> {
let mut outputs = BumpVec::with_capacity_in(block.sum_rows.len(), self.bump);
for sum_row in &block.sum_rows {
let variant = self.export_type_row_with_tail(sum_row, Some(tail));
let control = self.make_term_apply(model::CORE_CTRL, &[variant]);
outputs.push(table::SeqPart::Item(control));
outputs.push(table::SeqPart::Item(variant));
}
self.make_term(table::Term::List(outputs.into_bump_slice()))
};

self.make_term_apply(model::CORE_FN, &[inputs, outputs])
self.make_term_apply(model::CORE_CTRL, &[inputs, outputs])
}

/// Creates a data flow region from the given node's children.
Expand Down Expand Up @@ -742,16 +740,14 @@ impl<'a> Context<'a> {

let mut wrap_ctrl = |types: &TypeRow| {
let types = self.export_type_row(types);
let types_ctrl = self.make_term_apply(model::CORE_CTRL, &[types]);
self.make_term(table::Term::List(
self.bump
.alloc_slice_copy(&[table::SeqPart::Item(types_ctrl)]),
self.bump.alloc_slice_copy(&[table::SeqPart::Item(types)]),
))
};

let inputs = wrap_ctrl(node_signature.input());
let outputs = wrap_ctrl(node_signature.output());
Some(self.make_term_apply(model::CORE_FN, &[inputs, outputs]))
Some(self.make_term_apply(model::CORE_CTRL, &[inputs, outputs]))
};

let scope = match closure {
Expand Down
45 changes: 18 additions & 27 deletions hugr-core/src/import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -826,7 +826,7 @@ impl<'a> Context<'a> {
}

let region_target_types = (|| {
let [_, region_targets] = self.get_func_type(
let [_, region_targets] = self.get_ctrl_type(
region_data
.signature
.ok_or_else(|| error_uninferred!("region signature"))?,
Expand Down Expand Up @@ -860,25 +860,18 @@ impl<'a> Context<'a> {
};

// The entry node in core control flow regions is identified by being
// the first child node of the CFG node. We therefore import the entry
// node first and follow it up by every other node.
// the first child node of the CFG node. We therefore import the entry node first.
self.import_node(entry_node, node)?;

for child in region_data.children {
if *child != entry_node {
self.import_node(*child, node)?;
}
}

// Create the exit node for the control flow region.
// Create the exit node for the control flow region. This always needs
// to be second in the node list.
{
let cfg_outputs = {
let [ctrl_type] = region_target_types.as_slice() else {
let [target_types] = region_target_types.as_slice() else {
return Err(error_invalid!("cfg region expects a single target"));
};

let [types] = self.expect_symbol(*ctrl_type, model::CORE_CTRL)?;
self.import_type_row(types)?
self.import_type_row(*target_types)?
};

let exit = self
Expand All @@ -887,6 +880,13 @@ impl<'a> Context<'a> {
self.record_links(exit, Direction::Incoming, region_data.targets);
}

// Finally we import all other nodes.
for child in region_data.children {
if *child != entry_node {
self.import_node(*child, node)?;
}
}

for meta_item in region_data.meta {
self.import_node_metadata(node, *meta_item)
.map_err(|err| error_context!(err, "node metadata"))?;
Expand Down Expand Up @@ -1245,13 +1245,6 @@ impl<'a> Context<'a> {
return Err(error_unsupported!("`{}` as `TypeParam`", model::CORE_CONST));
}

if let Some([]) = self.match_symbol(term_id, model::CORE_CTRL_TYPE)? {
return Err(error_unsupported!(
"`{}` as `TypeParam`",
model::CORE_CTRL_TYPE
));
}

if let Some([item_type]) = self.match_symbol(term_id, model::CORE_LIST_TYPE)? {
// At present `hugr-model` has no way to express that the item
// type of a list must be copyable. Therefore we import it as `Any`.
Expand Down Expand Up @@ -1339,13 +1332,6 @@ impl<'a> Context<'a> {
return Err(error_unsupported!("`{}` as `TypeArg`", model::CORE_STATIC));
}

if let Some([]) = self.match_symbol(term_id, model::CORE_CTRL_TYPE)? {
return Err(error_unsupported!(
"`{}` as `TypeArg`",
model::CORE_CTRL_TYPE
));
}

if let Some([]) = self.match_symbol(term_id, model::CORE_CONST)? {
return Err(error_unsupported!("`{}` as `TypeArg`", model::CORE_CONST));
}
Expand Down Expand Up @@ -1510,6 +1496,11 @@ impl<'a> Context<'a> {
.ok_or(error_invalid!("expected a function type"))
}

fn get_ctrl_type(&mut self, term_id: table::TermId) -> Result<[table::TermId; 2], ImportError> {
self.match_symbol(term_id, model::CORE_CTRL)?
.ok_or(error_invalid!("expected a control type"))
}

fn import_func_type<RV: MaybeRV>(
&mut self,
term_id: table::TermId,
Expand Down
13 changes: 6 additions & 7 deletions hugr-core/tests/snapshots/model__roundtrip_cfg.snap
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
---
source: hugr-core/tests/model.rs
expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cfg.edn\"))"
expression: ast
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what's going on here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There was some refactoring in an earlier PR that changed how the snapshot test is called. The snapshot itself didn't change, and so the snapshot file was still considered valid. Now the snapshot changed, which also led to a change in the metadata that captures how the snapshot test is called.

---
(hugr 0)

Expand All @@ -22,10 +22,9 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cfg.
(cfg [%0] [%1]
(signature (core.fn [?0] [?0]))
(cfg [%2] [%3]
(signature (core.fn [(core.ctrl [?0])] [(core.ctrl [?0])]))
(signature (core.ctrl [[?0]] [[?0]]))
(block [%2] [%3 %2]
(signature
(core.fn [(core.ctrl [?0])] [(core.ctrl [?0]) (core.ctrl [?0])]))
(signature (core.ctrl [[?0]] [[?0] [?0]]))
(dfg [%4] [%5]
(signature (core.fn [?0] [(core.adt [[?0] [?0]])]))
((core.make_adt 0) [%4] [%5]
Expand All @@ -37,15 +36,15 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cfg.
(cfg [%0] [%1]
(signature (core.fn [?0] [?0]))
(cfg [%2] [%3]
(signature (core.fn [(core.ctrl [?0])] [(core.ctrl [?0])]))
(signature (core.ctrl [[?0]] [[?0]]))
(block [%2] [%6]
(signature (core.fn [(core.ctrl [?0])] [(core.ctrl [?0])]))
(signature (core.ctrl [[?0]] [[?0]]))
(dfg [%4] [%5]
(signature (core.fn [?0] [(core.adt [[?0]])]))
((core.make_adt 0) [%4] [%5]
(signature (core.fn [?0] [(core.adt [[?0]])])))))
(block [%6] [%3]
(signature (core.fn [(core.ctrl [?0])] [(core.ctrl [?0])]))
(signature (core.ctrl [[?0]] [[?0]]))
(dfg [%7] [%8]
(signature (core.fn [?0] [(core.adt [[?0]])]))
((core.make_adt 0) [%7] [%8]
Expand Down
6 changes: 3 additions & 3 deletions hugr-core/tests/snapshots/model__roundtrip_entrypoint.snap
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
---
source: hugr-core/tests/model.rs
expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-entrypoint.edn\"))"
expression: ast
---
(hugr 0)

Expand Down Expand Up @@ -40,10 +40,10 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-entr
(cfg
(signature (core.fn [] []))
(cfg [%0] [%1]
(signature (core.fn [(core.ctrl [])] [(core.ctrl [])]))
(signature (core.ctrl [[]] [[]]))
(meta core.entrypoint)
(block [%0] [%1]
(signature (core.fn [(core.ctrl [])] [(core.ctrl [])]))
(signature (core.ctrl [[]] [[]]))
(dfg [] [%2]
(signature (core.fn [] [(core.adt [[]])]))
((core.make_adt 0) [] [%2]
Expand Down
12 changes: 4 additions & 8 deletions hugr-model/src/v0/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,17 +163,13 @@ pub const CORE_BYTES_TYPE: &str = "core.bytes";
/// - **Result:** `core.static`
pub const CORE_FLOAT_TYPE: &str = "core.float";

/// Type of a control flow edge.
/// Type of control flow regions.
///
/// - **Parameter:** `?types : (core.list core.type)`
/// - **Result:** `core.ctrl_type`
/// - **Parameter:** `?inputs : (core.list (core.list core.type))`
/// - **Parameter:** `?outputs : (core.list (core.list core.type))`
/// - **Result:** `core.type`
pub const CORE_CTRL: &str = "core.ctrl";

/// The type of the types for control flow edges.
///
/// - **Result:** `?type : core.static`
pub const CORE_CTRL_TYPE: &str = "core.ctrl_type";

/// The type for runtime constants.
///
/// - **Parameter:** `?type : core.type`
Expand Down
28 changes: 14 additions & 14 deletions hugr-model/tests/fixtures/model-cfg.edn
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@
(param ?a core.type)
(core.fn [?a] [?a])
(dfg [%0] [%1]
(signature (core.fn [?a] [?a]))
(cfg [%0] [%1]
(signature (core.fn [?a] [?a]))
(cfg [%2] [%4]
(signature (core.fn [(core.ctrl [?a])] [(core.ctrl [?a])]))
(block [%2] [%4 %2]
(signature (core.fn [(core.ctrl [?a])] [(core.ctrl [?a]) (core.ctrl [?a])]))
(dfg [%5] [%6]
(signature (core.fn [?a] [(core.adt [[?a] [?a]])]))
((core.make_adt 0) [%5] [%6]
(signature (core.fn [?a] [(core.adt [[?a] [?a]])])))))))))
(signature (core.fn [?a] [?a]))
(cfg [%0] [%1]
(signature (core.fn [?a] [?a]))
(cfg [%2] [%3]
(signature (core.ctrl [[?a]] [[?a]]))
(block [%2] [%3 %2]
(signature (core.ctrl [[?a]] [[?a] [?a]]))
(dfg [%4] [%5]
(signature (core.fn [?a] [(core.adt [[?a] [?a]])]))
((core.make_adt 0) [%4] [%5]
(signature (core.fn [?a] [(core.adt [[?a] [?a]])])))))))))

(define-func example.cfg_order
(param ?a core.type)
Expand All @@ -26,15 +26,15 @@
(cfg [%0] [%1]
(signature (core.fn [?a] [?a]))
(cfg [%2] [%4]
(signature (core.fn [(core.ctrl [?a])] [(core.ctrl [?a])]))
(signature (core.ctrl [[?a]] [[?a]]))
(block [%3] [%4]
(signature (core.fn [(core.ctrl [?a])] [(core.ctrl [?a])]))
(signature (core.ctrl [[?a]] [[?a]]))
(dfg [%5] [%6]
(signature (core.fn [?a] [(core.adt [[?a]])]))
((core.make_adt _ _ 0) [%5] [%6]
(signature (core.fn [?a] [(core.adt [[?a]])])))))
(block [%2] [%3]
(signature (core.fn [(core.ctrl [?a])] [(core.ctrl [?a])]))
(signature (core.ctrl [[?a]] [[?a]]))
(dfg [%7] [%8]
(signature (core.fn [?a] [(core.adt [[?a]])]))
((core.make_adt _ _ 0) [%7] [%8]
Expand Down
4 changes: 2 additions & 2 deletions hugr-model/tests/fixtures/model-entrypoint.edn
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
(cfg [] []
(signature (core.fn [] []))
(cfg [%entry] [%exit]
(signature (core.fn [(core.ctrl [])] [(core.ctrl [])]))
(signature (core.ctrl [[]] [[]]))
(meta core.entrypoint)
(block [%entry] [%exit]
(signature (core.fn [(core.ctrl [])] [(core.ctrl [])]))
(signature (core.ctrl [[]] [[]]))
(dfg [] [%value]
(signature (core.fn [] [(core.adt [[]])]))
((core.make_adt _ _ 0) [] [%value]
Expand Down
49 changes: 28 additions & 21 deletions hugr-py/src/hugr/model/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
self.hugr = hugr
self.link_ports: _UnionFind[InPort | OutPort] = _UnionFind()
self.link_names: dict[InPort | OutPort, str] = {}
self.link_next = 0

Check warning on line 42 in hugr-py/src/hugr/model/export.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/model/export.py#L42

Added line #L42 was not covered by tests
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could use itertools.count() for this -
Then call next(self.link_next) and the counter goes up automatically

(Parroting a comment Mark gave to me recently)


# TODO: Store the hugr entrypoint

Expand All @@ -52,15 +53,20 @@
if root in self.link_names:
return self.link_names[root]
else:
index = str(len(self.link_names))
index = str(self.link_next)
self.link_next += 1

Check warning on line 57 in hugr-py/src/hugr/model/export.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/model/export.py#L56-L57

Added lines #L56 - L57 were not covered by tests
self.link_names[root] = index
return index

def export_node(self, node: Node) -> model.Node | None:
def export_node(
self, node: Node, virtual_input_links: Sequence[str] = []
) -> model.Node | None:
"""Export the node with the given node id."""
node_data = self.hugr[node]

inputs = [self.link_name(InPort(node, i)) for i in range(node_data._num_inps)]
inputs = [*inputs, *virtual_input_links]

Check warning on line 68 in hugr-py/src/hugr/model/export.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/model/export.py#L68

Added line #L68 was not covered by tests

outputs = [self.link_name(OutPort(node, i)) for i in range(node_data._num_outs)]
meta = []

Expand Down Expand Up @@ -308,31 +314,21 @@
case DataflowBlock() as op:
region = self.export_region_dfg(node)

input_types = [
model.Apply(
"core.ctrl",
[model.List([type.to_model() for type in op.inputs])],
)
]
input_types = [model.List([type.to_model() for type in op.inputs])]

Check warning on line 317 in hugr-py/src/hugr/model/export.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/model/export.py#L317

Added line #L317 was not covered by tests

other_output_types = [type.to_model() for type in op.other_outputs]
output_types = [
model.Apply(
"core.ctrl",
model.List(
[
model.List(
[
*[type.to_model() for type in row],
*other_output_types,
]
)
],
*[type.to_model() for type in row],
*other_output_types,
]
)
for row in op.sum_ty.variant_rows
]

signature = model.Apply(
"core.fn",
"core.ctrl",
[model.List(input_types), model.List(output_types)],
)

Expand Down Expand Up @@ -469,9 +465,14 @@
source_types = model.List(
[type.to_model() for type in op.inputs]
)
source = self.link_name(OutPort(child, 0))
source = str(self.link_next)
self.link_next += 1

Check warning on line 469 in hugr-py/src/hugr/model/export.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/model/export.py#L468-L469

Added lines #L468 - L469 were not covered by tests

child_node = self.export_node(child)
child_node = self.export_node(

Check warning on line 471 in hugr-py/src/hugr/model/export.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/model/export.py#L471

Added line #L471 was not covered by tests
child, virtual_input_links=[source]
)
else:
child_node = self.export_node(child)

Check warning on line 475 in hugr-py/src/hugr/model/export.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/model/export.py#L475

Added line #L475 was not covered by tests

if child_node is not None:
children.append(child_node)
Expand All @@ -483,7 +484,13 @@
error = f"CFG {node} has no entry block."
raise ValueError(error)

signature = model.Apply("core.fn", [source_types, target_types])
signature = model.Apply(

Check warning on line 487 in hugr-py/src/hugr/model/export.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/model/export.py#L487

Added line #L487 was not covered by tests
"core.ctrl",
[
model.List([source_types]),
model.List([target_types]),
],
)

return model.Region(
kind=model.RegionKind.CONTROL_FLOW,
Expand Down
Loading