diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 8bae0f1cdc..9a392dfaaf 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -14,7 +14,7 @@ use crate::{ }, types::{ CustomType, EdgeKind, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType, - TypeArg, TypeBase, TypeBound, TypeEnum, TypeRow, + TypeArg, TypeBase, TypeBound, TypeEnum, type_param::{TypeArgVariable, TypeParam}, type_row::TypeRowBase, }, @@ -578,7 +578,6 @@ impl<'a> Context<'a> { pub fn export_block_signature(&mut self, block: &DataflowBlock) -> table::TermId { let inputs = { let inputs = self.export_type_row(&block.inputs); - let inputs = self.make_term_apply(model::CORE_CTRL, &[inputs]); self.make_term(table::Term::List( self.bump.alloc_slice_copy(&[table::SeqPart::Item(inputs)]), )) @@ -590,13 +589,12 @@ impl<'a> Context<'a> { let mut outputs = BumpVec::with_capacity_in(block.sum_rows.len(), self.bump); for sum_row in &block.sum_rows { let variant = self.export_type_row_with_tail(sum_row, Some(tail)); - let control = self.make_term_apply(model::CORE_CTRL, &[variant]); - outputs.push(table::SeqPart::Item(control)); + outputs.push(table::SeqPart::Item(variant)); } self.make_term(table::Term::List(outputs.into_bump_slice())) }; - self.make_term_apply(model::CORE_FN, &[inputs, outputs]) + self.make_term_apply(model::CORE_CTRL, &[inputs, outputs]) } /// Creates a data flow region from the given node's children. @@ -740,18 +738,21 @@ impl<'a> Context<'a> { let signature = { let node_signature = self.hugr.signature(node).unwrap(); - let mut wrap_ctrl = |types: &TypeRow| { - let types = self.export_type_row(types); - let types_ctrl = self.make_term_apply(model::CORE_CTRL, &[types]); + let inputs = { + let types = self.export_type_row(node_signature.input()); self.make_term(table::Term::List( - self.bump - .alloc_slice_copy(&[table::SeqPart::Item(types_ctrl)]), + self.bump.alloc_slice_copy(&[table::SeqPart::Item(types)]), )) }; - let inputs = wrap_ctrl(node_signature.input()); - let outputs = wrap_ctrl(node_signature.output()); - Some(self.make_term_apply(model::CORE_FN, &[inputs, outputs])) + let outputs = { + let types = self.export_type_row(node_signature.output()); + self.make_term(table::Term::List( + self.bump.alloc_slice_copy(&[table::SeqPart::Item(types)]), + )) + }; + + Some(self.make_term_apply(model::CORE_CTRL, &[inputs, outputs])) }; let scope = match closure { diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 430c819218..e9c614c590 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -826,7 +826,7 @@ impl<'a> Context<'a> { } let region_target_types = (|| { - let [_, region_targets] = self.get_func_type( + let [_, region_targets] = self.get_ctrl_type( region_data .signature .ok_or_else(|| error_uninferred!("region signature"))?, @@ -860,25 +860,18 @@ impl<'a> Context<'a> { }; // The entry node in core control flow regions is identified by being - // the first child node of the CFG node. We therefore import the entry - // node first and follow it up by every other node. + // the first child node of the CFG node. We therefore import the entry node first. self.import_node(entry_node, node)?; - for child in region_data.children { - if *child != entry_node { - self.import_node(*child, node)?; - } - } - - // Create the exit node for the control flow region. + // Create the exit node for the control flow region. This always needs + // to be second in the node list. { let cfg_outputs = { - let [ctrl_type] = region_target_types.as_slice() else { + let [target_types] = region_target_types.as_slice() else { return Err(error_invalid!("cfg region expects a single target")); }; - let [types] = self.expect_symbol(*ctrl_type, model::CORE_CTRL)?; - self.import_type_row(types)? + self.import_type_row(*target_types)? }; let exit = self @@ -887,6 +880,13 @@ impl<'a> Context<'a> { self.record_links(exit, Direction::Incoming, region_data.targets); } + // Finally we import all other nodes. + for child in region_data.children { + if *child != entry_node { + self.import_node(*child, node)?; + } + } + for meta_item in region_data.meta { self.import_node_metadata(node, *meta_item) .map_err(|err| error_context!(err, "node metadata"))?; @@ -1245,13 +1245,6 @@ impl<'a> Context<'a> { return Err(error_unsupported!("`{}` as `TypeParam`", model::CORE_CONST)); } - if let Some([]) = self.match_symbol(term_id, model::CORE_CTRL_TYPE)? { - return Err(error_unsupported!( - "`{}` as `TypeParam`", - model::CORE_CTRL_TYPE - )); - } - if let Some([item_type]) = self.match_symbol(term_id, model::CORE_LIST_TYPE)? { // At present `hugr-model` has no way to express that the item // type of a list must be copyable. Therefore we import it as `Any`. @@ -1339,13 +1332,6 @@ impl<'a> Context<'a> { return Err(error_unsupported!("`{}` as `TypeArg`", model::CORE_STATIC)); } - if let Some([]) = self.match_symbol(term_id, model::CORE_CTRL_TYPE)? { - return Err(error_unsupported!( - "`{}` as `TypeArg`", - model::CORE_CTRL_TYPE - )); - } - if let Some([]) = self.match_symbol(term_id, model::CORE_CONST)? { return Err(error_unsupported!("`{}` as `TypeArg`", model::CORE_CONST)); } @@ -1510,6 +1496,11 @@ impl<'a> Context<'a> { .ok_or(error_invalid!("expected a function type")) } + fn get_ctrl_type(&mut self, term_id: table::TermId) -> Result<[table::TermId; 2], ImportError> { + self.match_symbol(term_id, model::CORE_CTRL)? + .ok_or(error_invalid!("expected a control type")) + } + fn import_func_type( &mut self, term_id: table::TermId, diff --git a/hugr-core/tests/snapshots/model__roundtrip_cfg.snap b/hugr-core/tests/snapshots/model__roundtrip_cfg.snap index 7a6136bdb2..f3f272fd21 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_cfg.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_cfg.snap @@ -1,6 +1,6 @@ --- source: hugr-core/tests/model.rs -expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cfg.edn\"))" +expression: ast --- (hugr 0) @@ -22,10 +22,9 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cfg. (cfg [%0] [%1] (signature (core.fn [?0] [?0])) (cfg [%2] [%3] - (signature (core.fn [(core.ctrl [?0])] [(core.ctrl [?0])])) + (signature (core.ctrl [[?0]] [[?0]])) (block [%2] [%3 %2] - (signature - (core.fn [(core.ctrl [?0])] [(core.ctrl [?0]) (core.ctrl [?0])])) + (signature (core.ctrl [[?0]] [[?0] [?0]])) (dfg [%4] [%5] (signature (core.fn [?0] [(core.adt [[?0] [?0]])])) ((core.make_adt 0) [%4] [%5] @@ -37,15 +36,15 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cfg. (cfg [%0] [%1] (signature (core.fn [?0] [?0])) (cfg [%2] [%3] - (signature (core.fn [(core.ctrl [?0])] [(core.ctrl [?0])])) + (signature (core.ctrl [[?0]] [[?0]])) (block [%2] [%6] - (signature (core.fn [(core.ctrl [?0])] [(core.ctrl [?0])])) + (signature (core.ctrl [[?0]] [[?0]])) (dfg [%4] [%5] (signature (core.fn [?0] [(core.adt [[?0]])])) ((core.make_adt 0) [%4] [%5] (signature (core.fn [?0] [(core.adt [[?0]])]))))) (block [%6] [%3] - (signature (core.fn [(core.ctrl [?0])] [(core.ctrl [?0])])) + (signature (core.ctrl [[?0]] [[?0]])) (dfg [%7] [%8] (signature (core.fn [?0] [(core.adt [[?0]])])) ((core.make_adt 0) [%7] [%8] diff --git a/hugr-core/tests/snapshots/model__roundtrip_entrypoint.snap b/hugr-core/tests/snapshots/model__roundtrip_entrypoint.snap index 1db0b9d1cd..5313b7257b 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_entrypoint.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_entrypoint.snap @@ -1,6 +1,6 @@ --- source: hugr-core/tests/model.rs -expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-entrypoint.edn\"))" +expression: ast --- (hugr 0) @@ -40,10 +40,10 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-entr (cfg (signature (core.fn [] [])) (cfg [%0] [%1] - (signature (core.fn [(core.ctrl [])] [(core.ctrl [])])) + (signature (core.ctrl [[]] [[]])) (meta core.entrypoint) (block [%0] [%1] - (signature (core.fn [(core.ctrl [])] [(core.ctrl [])])) + (signature (core.ctrl [[]] [[]])) (dfg [] [%2] (signature (core.fn [] [(core.adt [[]])])) ((core.make_adt 0) [] [%2] diff --git a/hugr-model/src/v0/mod.rs b/hugr-model/src/v0/mod.rs index 15e29f3bde..27e6605ae7 100644 --- a/hugr-model/src/v0/mod.rs +++ b/hugr-model/src/v0/mod.rs @@ -163,17 +163,13 @@ pub const CORE_BYTES_TYPE: &str = "core.bytes"; /// - **Result:** `core.static` pub const CORE_FLOAT_TYPE: &str = "core.float"; -/// Type of a control flow edge. +/// Type of control flow regions. /// -/// - **Parameter:** `?types : (core.list core.type)` -/// - **Result:** `core.ctrl_type` +/// - **Parameter:** `?inputs : (core.list (core.list core.type))` +/// - **Parameter:** `?outputs : (core.list (core.list core.type))` +/// - **Result:** `core.type` pub const CORE_CTRL: &str = "core.ctrl"; -/// The type of the types for control flow edges. -/// -/// - **Result:** `?type : core.static` -pub const CORE_CTRL_TYPE: &str = "core.ctrl_type"; - /// The type for runtime constants. /// /// - **Parameter:** `?type : core.type` diff --git a/hugr-model/tests/fixtures/model-cfg.edn b/hugr-model/tests/fixtures/model-cfg.edn index e2a760f5c0..b105eb7d91 100644 --- a/hugr-model/tests/fixtures/model-cfg.edn +++ b/hugr-model/tests/fixtures/model-cfg.edn @@ -6,17 +6,17 @@ (param ?a core.type) (core.fn [?a] [?a]) (dfg [%0] [%1] - (signature (core.fn [?a] [?a])) - (cfg [%0] [%1] - (signature (core.fn [?a] [?a])) - (cfg [%2] [%4] - (signature (core.fn [(core.ctrl [?a])] [(core.ctrl [?a])])) - (block [%2] [%4 %2] - (signature (core.fn [(core.ctrl [?a])] [(core.ctrl [?a]) (core.ctrl [?a])])) - (dfg [%5] [%6] - (signature (core.fn [?a] [(core.adt [[?a] [?a]])])) - ((core.make_adt 0) [%5] [%6] - (signature (core.fn [?a] [(core.adt [[?a] [?a]])]))))))))) + (signature (core.fn [?a] [?a])) + (cfg [%0] [%1] + (signature (core.fn [?a] [?a])) + (cfg [%2] [%3] + (signature (core.ctrl [[?a]] [[?a]])) + (block [%2] [%3 %2] + (signature (core.ctrl [[?a]] [[?a] [?a]])) + (dfg [%4] [%5] + (signature (core.fn [?a] [(core.adt [[?a] [?a]])])) + ((core.make_adt 0) [%4] [%5] + (signature (core.fn [?a] [(core.adt [[?a] [?a]])]))))))))) (define-func example.cfg_order (param ?a core.type) @@ -26,15 +26,15 @@ (cfg [%0] [%1] (signature (core.fn [?a] [?a])) (cfg [%2] [%4] - (signature (core.fn [(core.ctrl [?a])] [(core.ctrl [?a])])) + (signature (core.ctrl [[?a]] [[?a]])) (block [%3] [%4] - (signature (core.fn [(core.ctrl [?a])] [(core.ctrl [?a])])) + (signature (core.ctrl [[?a]] [[?a]])) (dfg [%5] [%6] (signature (core.fn [?a] [(core.adt [[?a]])])) ((core.make_adt _ _ 0) [%5] [%6] (signature (core.fn [?a] [(core.adt [[?a]])]))))) (block [%2] [%3] - (signature (core.fn [(core.ctrl [?a])] [(core.ctrl [?a])])) + (signature (core.ctrl [[?a]] [[?a]])) (dfg [%7] [%8] (signature (core.fn [?a] [(core.adt [[?a]])])) ((core.make_adt _ _ 0) [%7] [%8] diff --git a/hugr-model/tests/fixtures/model-entrypoint.edn b/hugr-model/tests/fixtures/model-entrypoint.edn index 10cab9173b..b95cb06992 100644 --- a/hugr-model/tests/fixtures/model-entrypoint.edn +++ b/hugr-model/tests/fixtures/model-entrypoint.edn @@ -25,10 +25,10 @@ (cfg [] [] (signature (core.fn [] [])) (cfg [%entry] [%exit] - (signature (core.fn [(core.ctrl [])] [(core.ctrl [])])) + (signature (core.ctrl [[]] [[]])) (meta core.entrypoint) (block [%entry] [%exit] - (signature (core.fn [(core.ctrl [])] [(core.ctrl [])])) + (signature (core.ctrl [[]] [[]])) (dfg [] [%value] (signature (core.fn [] [(core.adt [[]])])) ((core.make_adt _ _ 0) [] [%value] diff --git a/hugr-py/src/hugr/model/export.py b/hugr-py/src/hugr/model/export.py index d1216d8370..950b7c776d 100644 --- a/hugr-py/src/hugr/model/export.py +++ b/hugr-py/src/hugr/model/export.py @@ -39,6 +39,7 @@ def __init__(self, hugr: Hugr): self.hugr = hugr self.link_ports: _UnionFind[InPort | OutPort] = _UnionFind() self.link_names: dict[InPort | OutPort, str] = {} + self.link_next = 0 # TODO: Store the hugr entrypoint @@ -52,15 +53,20 @@ def link_name(self, port: InPort | OutPort) -> str: if root in self.link_names: return self.link_names[root] else: - index = str(len(self.link_names)) + index = str(self.link_next) + self.link_next += 1 self.link_names[root] = index return index - def export_node(self, node: Node) -> model.Node | None: + def export_node( + self, node: Node, virtual_input_links: Sequence[str] = [] + ) -> model.Node | None: """Export the node with the given node id.""" node_data = self.hugr[node] inputs = [self.link_name(InPort(node, i)) for i in range(node_data._num_inps)] + inputs = [*inputs, *virtual_input_links] + outputs = [self.link_name(OutPort(node, i)) for i in range(node_data._num_outs)] meta = [] @@ -308,31 +314,21 @@ def export_node(self, node: Node) -> model.Node | None: case DataflowBlock() as op: region = self.export_region_dfg(node) - input_types = [ - model.Apply( - "core.ctrl", - [model.List([type.to_model() for type in op.inputs])], - ) - ] + input_types = [model.List([type.to_model() for type in op.inputs])] other_output_types = [type.to_model() for type in op.other_outputs] output_types = [ - model.Apply( - "core.ctrl", + model.List( [ - model.List( - [ - *[type.to_model() for type in row], - *other_output_types, - ] - ) - ], + *[type.to_model() for type in row], + *other_output_types, + ] ) for row in op.sum_ty.variant_rows ] signature = model.Apply( - "core.fn", + "core.ctrl", [model.List(input_types), model.List(output_types)], ) @@ -469,9 +465,14 @@ def export_region_cfg(self, node: Node) -> model.Region: source_types = model.List( [type.to_model() for type in op.inputs] ) - source = self.link_name(OutPort(child, 0)) + source = str(self.link_next) + self.link_next += 1 - child_node = self.export_node(child) + child_node = self.export_node( + child, virtual_input_links=[source] + ) + else: + child_node = self.export_node(child) if child_node is not None: children.append(child_node) @@ -483,7 +484,13 @@ def export_region_cfg(self, node: Node) -> model.Region: error = f"CFG {node} has no entry block." raise ValueError(error) - signature = model.Apply("core.fn", [source_types, target_types]) + signature = model.Apply( + "core.ctrl", + [ + model.List([source_types]), + model.List([target_types]), + ], + ) return model.Region( kind=model.RegionKind.CONTROL_FLOW,