diff --git a/tket2-exts/src/tket2_exts/__init__.py b/tket2-exts/src/tket2_exts/__init__.py index b332cedf5..1c86a3344 100644 --- a/tket2-exts/src/tket2_exts/__init__.py +++ b/tket2-exts/src/tket2_exts/__init__.py @@ -16,6 +16,11 @@ def opaque_bool() -> Extension: return load_extension("tket2.bool") +@functools.cache +def debug() -> Extension: + return load_extension("tket2.debug") + + @functools.cache def rotation() -> Extension: return load_extension("tket2.rotation") diff --git a/tket2-exts/src/tket2_exts/data/tket2/debug.json b/tket2-exts/src/tket2_exts/data/tket2/debug.json new file mode 100644 index 000000000..fdfa89ec5 --- /dev/null +++ b/tket2-exts/src/tket2_exts/data/tket2/debug.json @@ -0,0 +1,77 @@ +{ + "version": "0.1.0", + "name": "tket2.debug", + "runtime_reqs": [], + "types": {}, + "values": {}, + "operations": { + "StateResult": { + "extension": "tket2.debug", + "name": "StateResult", + "description": "Report the state of given qubits in the given order.", + "signature": { + "params": [ + { + "tp": "String" + }, + { + "tp": "BoundedNat", + "bound": null + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "collections.array", + "id": "array", + "args": [ + { + "tya": "Variable", + "idx": 1, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "Q" + } + } + ], + "bound": "A" + } + ], + "output": [ + { + "t": "Opaque", + "extension": "collections.array", + "id": "array", + "args": [ + { + "tya": "Variable", + "idx": 1, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "Q" + } + } + ], + "bound": "A" + } + ], + "runtime_reqs": [] + } + }, + "binary": false + } + } +} diff --git a/tket2-hseries/src/bin/tket2-hseries.rs b/tket2-hseries/src/bin/tket2-hseries.rs index 99388d29a..6ccee9588 100644 --- a/tket2-hseries/src/bin/tket2-hseries.rs +++ b/tket2-hseries/src/bin/tket2-hseries.rs @@ -11,6 +11,7 @@ fn main() { tket2::extension::TKET2_EXTENSION.to_owned(), tket2::extension::rotation::ROTATION_EXTENSION.to_owned(), tket2::extension::bool::BOOL_EXTENSION.to_owned(), + tket2::extension::debug::DEBUG_EXTENSION.to_owned(), tket2_hseries::extension::qsystem::EXTENSION.to_owned(), tket2_hseries::extension::futures::EXTENSION.to_owned(), tket2_hseries::extension::random::EXTENSION.to_owned(), diff --git a/tket2/src/extension.rs b/tket2/src/extension.rs index 3ed72ec09..43724d867 100644 --- a/tket2/src/extension.rs +++ b/tket2/src/extension.rs @@ -20,6 +20,8 @@ use smol_str::SmolStr; /// Definition for bool type and ops. pub mod bool; +/// Definition for debug ops. +pub mod debug; /// Definition for Angle ops and types. pub mod rotation; pub mod sympy; @@ -61,6 +63,7 @@ pub(crate) static ref REGISTRY: ExtensionRegistry = ExtensionRegistry::new( TKET1_EXTENSION.to_owned(), TKET2_EXTENSION.to_owned(), bool::BOOL_EXTENSION.to_owned(), + debug::DEBUG_EXTENSION.to_owned(), rotation::ROTATION_EXTENSION.to_owned() ])); diff --git a/tket2/src/extension/debug.rs b/tket2/src/extension/debug.rs new file mode 100644 index 000000000..4e39b5902 --- /dev/null +++ b/tket2/src/extension/debug.rs @@ -0,0 +1,194 @@ +//! This module defines a Hugr extension for operations to be used by users debugging +//! with a simulator. +use std::sync::{Arc, Weak}; + +use hugr::{ + extension::{ + prelude::qb_t, + simple_op::{ + try_from_name, HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, + OpLoadError, + }, + ExtensionId, SignatureError, SignatureFunc, Version, + }, + ops::{NamedOp, OpName}, + std_extensions::collections::array::array_type_parametric, + types::{type_param::TypeParam, FuncValueType, PolyFuncTypeRV, TypeArg}, + Extension, +}; +use lazy_static::lazy_static; + +/// The ID of the `tket2.debug` extension. +pub const DEBUG_EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("tket2.debug"); +/// The "tket2.debug" extension version +pub const DEBUG_EXTENSION_VERSION: Version = Version::new(0, 1, 0); + +lazy_static! { + /// The "tket2.bool" extension. + pub static ref DEBUG_EXTENSION: Arc = { + Extension::new_arc(DEBUG_EXTENSION_ID, DEBUG_EXTENSION_VERSION, |ext, ext_ref| { + StateResultDef.add_to_extension(ext, ext_ref).unwrap(); + }) + }; +} + +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +/// A `tket2.StateResult` operation definition. +pub struct StateResultDef; + +/// Name of the `tket2.StateResult` operation. +pub const STATE_RESULT_OP_ID: OpName = OpName::new_inline("StateResult"); +impl NamedOp for StateResultDef { + fn name(&self) -> OpName { + STATE_RESULT_OP_ID + } +} + +impl std::str::FromStr for StateResultDef { + type Err = (); + + fn from_str(s: &str) -> Result { + if s == StateResultDef.name() { + Ok(Self) + } else { + Err(()) + } + } +} + +impl MakeOpDef for StateResultDef { + fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { + PolyFuncTypeRV::new( + vec![TypeParam::String, TypeParam::max_nat()], + FuncValueType::new( + vec![ + array_type_parametric(TypeArg::new_var_use(1, TypeParam::max_nat()), qb_t()) + .unwrap(), + ], + vec![ + array_type_parametric(TypeArg::new_var_use(1, TypeParam::max_nat()), qb_t()) + .unwrap(), + ], + ), + ) + .into() + } + + fn from_def( + op_def: &hugr::extension::OpDef, + ) -> Result { + try_from_name(op_def.name(), op_def.extension_id()) + } + + fn extension(&self) -> ExtensionId { + DEBUG_EXTENSION_ID + } + + fn description(&self) -> String { + "Report the state of given qubits in the given order.".to_string() + } + + fn extension_ref(&self) -> Weak { + Arc::downgrade(&DEBUG_EXTENSION) + } +} + +#[derive(Debug, Clone, PartialEq)] +/// A debug operation for requesting the state of some qubits to be recorded if the +/// program is executed on a simulator. +pub struct StateResult { + /// Static string tag for the result. + pub tag: String, + /// The number of qubits in the result. + pub num_qubits: u64, +} + +impl StateResult { + /// Create a new `StateResult` operation. + pub fn new(tag: String, num_qubits: u64) -> Self { + StateResult { tag, num_qubits } + } +} + +impl NamedOp for StateResult { + fn name(&self) -> OpName { + STATE_RESULT_OP_ID + } +} + +impl MakeExtensionOp for StateResult { + fn from_extension_op(ext_op: &hugr::ops::ExtensionOp) -> Result + where + Self: Sized, + { + let def = StateResultDef::from_def(ext_op.def())?; + def.instantiate(ext_op.args()) + } + + fn type_args(&self) -> Vec { + vec![ + TypeArg::String { + arg: self.tag.clone(), + }, + TypeArg::BoundedNat { n: self.num_qubits }, + ] + } +} + +impl MakeRegisteredOp for StateResult { + fn extension_id(&self) -> ExtensionId { + DEBUG_EXTENSION_ID + } + + fn extension_ref(&self) -> Weak { + Arc::downgrade(&DEBUG_EXTENSION) + } +} + +impl HasDef for StateResult { + type Def = StateResultDef; +} + +impl HasConcrete for StateResultDef { + type Concrete = StateResult; + + fn instantiate(&self, type_args: &[TypeArg]) -> Result { + let [TypeArg::String { arg }, TypeArg::BoundedNat { n }] = type_args else { + return Err(SignatureError::InvalidTypeArgs)?; + }; + Ok(StateResult { + tag: arg.to_string(), + num_qubits: *n, + }) + } +} + +#[cfg(test)] +pub(crate) mod test { + use hugr::{ + builder::{DFGBuilder, Dataflow, DataflowHugr}, + ops::OpType, + std_extensions::collections::array::array_type, + types::Signature, + }; + + use super::*; + + #[test] + fn test_state_result() { + let op = StateResult::new("test".into(), 22); + let optype: OpType = op.clone().into(); + let new_op = StateResult::from_extension_op(optype.as_extension_op().unwrap()).unwrap(); + assert_eq!(new_op, op); + + let qb_array_type = array_type(22, qb_t()); + let hugr = { + let mut builder = + DFGBuilder::new(Signature::new(qb_array_type.clone(), qb_array_type)).unwrap(); + let inputs: [hugr::Wire; 1] = builder.input_wires_arr(); + let output = builder.add_dataflow_op(op, inputs).unwrap(); + builder.finish_hugr_with_outputs(output.outputs()).unwrap() + }; + hugr.validate().unwrap(); + } +}