Skip to content

Commit e5f4294

Browse files
committed
feat: Make NodeHandle generic
1 parent 195f30c commit e5f4294

File tree

2 files changed

+54
-35
lines changed

2 files changed

+54
-35
lines changed

hugr-core/src/ops.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ pub mod module;
99
pub mod sum;
1010
pub mod tag;
1111
pub mod validate;
12+
use crate::core::HugrNode;
1213
use crate::extension::resolution::{
1314
collect_op_extension, collect_op_types_extensions, ExtensionCollectionError,
1415
};
@@ -20,6 +21,7 @@ use crate::types::{EdgeKind, Signature, Substitution};
2021
use crate::{Direction, OutgoingPort, Port};
2122
use crate::{IncomingPort, PortIndex};
2223
use derive_more::Display;
24+
use handle::NodeHandle;
2325
use paste::paste;
2426
use portgraph::NodeIndex;
2527

@@ -41,7 +43,6 @@ pub use tag::OpTag;
4143
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
4244
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
4345
/// The concrete operation types for a node in the HUGR.
44-
// TODO: Link the NodeHandles to the OpType.
4546
#[non_exhaustive]
4647
#[allow(missing_docs)]
4748
#[serde(tag = "op")]
@@ -377,6 +378,19 @@ pub trait OpTrait: Sized + Clone {
377378
/// Tag identifying the operation.
378379
fn tag(&self) -> OpTag;
379380

381+
/// Tries to create a specific [`NodeHandle`] for a node with this operation
382+
/// type.
383+
///
384+
/// Fails if the operation's [`OpTrait::tag`] does not match the
385+
/// [`NodeHandle::TAG`] of the requested handle.
386+
fn try_node_handle<N, H>(&self, node: N) -> Option<H>
387+
where
388+
N: HugrNode,
389+
H: NodeHandle<N> + From<N>,
390+
{
391+
H::TAG.is_superset(self.tag()).then(|| node.into())
392+
}
393+
380394
/// The signature of the operation.
381395
///
382396
/// Only dataflow operations have a signature, otherwise returns None.

hugr-core/src/ops/handle.rs

Lines changed: 39 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
//! Handles to nodes in HUGR.
2+
use crate::core::HugrNode;
23
use crate::types::{Type, TypeBound};
34
use crate::Node;
45

@@ -9,12 +10,12 @@ use super::{AliasDecl, OpTag};
910

1011
/// Common trait for handles to a node.
1112
/// Typically wrappers around [`Node`].
12-
pub trait NodeHandle: Clone {
13+
pub trait NodeHandle<N = Node>: Clone {
1314
/// The most specific operation tag associated with the handle.
1415
const TAG: OpTag;
1516

1617
/// Index of underlying node.
17-
fn node(&self) -> Node;
18+
fn node(&self) -> N;
1819

1920
/// Operation tag for the handle.
2021
#[inline]
@@ -23,7 +24,7 @@ pub trait NodeHandle: Clone {
2324
}
2425

2526
/// Cast the handle to a different more general tag.
26-
fn try_cast<T: NodeHandle + From<Node>>(&self) -> Option<T> {
27+
fn try_cast<T: NodeHandle<N> + From<N>>(&self) -> Option<T> {
2728
T::TAG.is_superset(Self::TAG).then(|| self.node().into())
2829
}
2930

@@ -36,54 +37,54 @@ pub trait NodeHandle: Clone {
3637
/// Trait for handles that contain children.
3738
///
3839
/// The allowed children handles are defined by the associated type.
39-
pub trait ContainerHandle: NodeHandle {
40+
pub trait ContainerHandle<N = Node>: NodeHandle<N> {
4041
/// Handle type for the children of this node.
41-
type ChildrenHandle: NodeHandle;
42+
type ChildrenHandle: NodeHandle<N>;
4243
}
4344

4445
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
4546
/// Handle to a [DataflowOp](crate::ops::dataflow).
46-
pub struct DataflowOpID(Node);
47+
pub struct DataflowOpID<N = Node>(N);
4748

4849
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
4950
/// Handle to a [DFG](crate::ops::DFG) node.
50-
pub struct DfgID(Node);
51+
pub struct DfgID<N = Node>(N);
5152

5253
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
5354
/// Handle to a [CFG](crate::ops::CFG) node.
54-
pub struct CfgID(Node);
55+
pub struct CfgID<N = Node>(N);
5556

5657
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
5758
/// Handle to a module [Module](crate::ops::Module) node.
58-
pub struct ModuleRootID(Node);
59+
pub struct ModuleRootID<N = Node>(N);
5960

6061
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
6162
/// Handle to a [module op](crate::ops::module) node.
62-
pub struct ModuleID(Node);
63+
pub struct ModuleID<N = Node>(N);
6364

6465
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
6566
/// Handle to a [def](crate::ops::OpType::FuncDefn)
6667
/// or [declare](crate::ops::OpType::FuncDecl) node.
6768
///
6869
/// The `DEF` const generic is used to indicate whether the function is
6970
/// defined or just declared.
70-
pub struct FuncID<const DEF: bool>(Node);
71+
pub struct FuncID<const DEF: bool, N = Node>(N);
7172

7273
#[derive(Debug, Clone, PartialEq, Eq)]
7374
/// Handle to an [AliasDefn](crate::ops::OpType::AliasDefn)
7475
/// or [AliasDecl](crate::ops::OpType::AliasDecl) node.
7576
///
7677
/// The `DEF` const generic is used to indicate whether the function is
7778
/// defined or just declared.
78-
pub struct AliasID<const DEF: bool> {
79-
node: Node,
79+
pub struct AliasID<const DEF: bool, N = Node> {
80+
node: N,
8081
name: SmolStr,
8182
bound: TypeBound,
8283
}
8384

84-
impl<const DEF: bool> AliasID<DEF> {
85+
impl<const DEF: bool, N> AliasID<DEF, N> {
8586
/// Construct new AliasID
86-
pub fn new(node: Node, name: SmolStr, bound: TypeBound) -> Self {
87+
pub fn new(node: N, name: SmolStr, bound: TypeBound) -> Self {
8788
Self { node, name, bound }
8889
}
8990

@@ -99,27 +100,27 @@ impl<const DEF: bool> AliasID<DEF> {
99100

100101
#[derive(DerFrom, Debug, Clone, PartialEq, Eq)]
101102
/// Handle to a [Const](crate::ops::OpType::Const) node.
102-
pub struct ConstID(Node);
103+
pub struct ConstID<N = Node>(N);
103104

104105
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
105106
/// Handle to a [DataflowBlock](crate::ops::DataflowBlock) or [Exit](crate::ops::ExitBlock) node.
106-
pub struct BasicBlockID(Node);
107+
pub struct BasicBlockID<N = Node>(N);
107108

108109
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
109110
/// Handle to a [Case](crate::ops::Case) node.
110-
pub struct CaseID(Node);
111+
pub struct CaseID<N = Node>(N);
111112

112113
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
113114
/// Handle to a [TailLoop](crate::ops::TailLoop) node.
114-
pub struct TailLoopID(Node);
115+
pub struct TailLoopID<N = Node>(N);
115116

116117
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
117118
/// Handle to a [Conditional](crate::ops::Conditional) node.
118-
pub struct ConditionalID(Node);
119+
pub struct ConditionalID<N = Node>(N);
119120

120121
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
121122
/// Handle to a dataflow container node.
122-
pub struct DataflowParentID(Node);
123+
pub struct DataflowParentID<N = Node>(N);
123124

124125
/// Implements the `NodeHandle` trait for a tuple struct that contains just a
125126
/// NodeIndex. Takes the name of the struct, and the corresponding OpTag.
@@ -131,11 +132,11 @@ macro_rules! impl_nodehandle {
131132
impl_nodehandle!($name, $tag, 0);
132133
};
133134
($name:ident, $tag:expr, $node_attr:tt) => {
134-
impl NodeHandle for $name {
135+
impl<N: HugrNode> NodeHandle<N> for $name<N> {
135136
const TAG: OpTag = $tag;
136137

137138
#[inline]
138-
fn node(&self) -> Node {
139+
fn node(&self) -> N {
139140
self.$node_attr
140141
}
141142
}
@@ -156,35 +157,35 @@ impl_nodehandle!(ConstID, OpTag::Const);
156157

157158
impl_nodehandle!(BasicBlockID, OpTag::DataflowBlock);
158159

159-
impl<const DEF: bool> NodeHandle for FuncID<DEF> {
160+
impl<const DEF: bool, N: HugrNode> NodeHandle<N> for FuncID<DEF, N> {
160161
const TAG: OpTag = OpTag::Function;
161162
#[inline]
162-
fn node(&self) -> Node {
163+
fn node(&self) -> N {
163164
self.0
164165
}
165166
}
166167

167-
impl<const DEF: bool> NodeHandle for AliasID<DEF> {
168+
impl<const DEF: bool, N: HugrNode> NodeHandle<N> for AliasID<DEF, N> {
168169
const TAG: OpTag = OpTag::Alias;
169170
#[inline]
170-
fn node(&self) -> Node {
171+
fn node(&self) -> N {
171172
self.node
172173
}
173174
}
174175

175-
impl NodeHandle for Node {
176+
impl<N: HugrNode> NodeHandle<N> for N {
176177
const TAG: OpTag = OpTag::Any;
177178
#[inline]
178-
fn node(&self) -> Node {
179+
fn node(&self) -> N {
179180
*self
180181
}
181182
}
182183

183184
/// Implements the `ContainerHandle` trait, with the given child handle type.
184185
macro_rules! impl_containerHandle {
185-
($name:path, $children:ident) => {
186-
impl ContainerHandle for $name {
187-
type ChildrenHandle = $children;
186+
($name:ident, $children:ident) => {
187+
impl<N: HugrNode> ContainerHandle<N> for $name<N> {
188+
type ChildrenHandle = $children<N>;
188189
}
189190
};
190191
}
@@ -197,5 +198,9 @@ impl_containerHandle!(CaseID, DataflowOpID);
197198
impl_containerHandle!(ModuleRootID, ModuleID);
198199
impl_containerHandle!(CfgID, BasicBlockID);
199200
impl_containerHandle!(BasicBlockID, DataflowOpID);
200-
impl_containerHandle!(FuncID<true>, DataflowOpID);
201-
impl_containerHandle!(AliasID<true>, DataflowOpID);
201+
impl<N: HugrNode> ContainerHandle<N> for FuncID<true, N> {
202+
type ChildrenHandle = DataflowOpID<N>;
203+
}
204+
impl<N: HugrNode> ContainerHandle<N> for AliasID<true, N> {
205+
type ChildrenHandle = DataflowOpID<N>;
206+
}

0 commit comments

Comments
 (0)