diff --git a/Cargo.lock b/Cargo.lock index b63e090a2..9c2b5fff4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1033,9 +1033,9 @@ dependencies = [ [[package]] name = "hugr" -version = "0.22.0" +version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc1bdd463fd409d55a1fb3a2e3aeded7567fe7852129d8b40bd09c5070661694" +checksum = "ccc992a69300fb5835dd9fd259620b21b726bc22ade4aea9fcd20fbdd580e653" dependencies = [ "hugr-core", "hugr-llvm", @@ -1045,9 +1045,9 @@ dependencies = [ [[package]] name = "hugr-cli" -version = "0.22.0" +version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2566c6379afde1b8648b4f9e7477c6c6fae2af15d8c662af23ec40c0201ba42" +checksum = "e095693e2c78868aebdc16919deba50fb3e6c458f2fba37c05efceb3fc926e4e" dependencies = [ "anyhow", "clap", @@ -1063,9 +1063,9 @@ dependencies = [ [[package]] name = "hugr-core" -version = "0.22.0" +version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2418f5d99fa346fe47943da854b8a4d5dde8550b4b311a1a9eb76d4fe78d5f23" +checksum = "8f7cae0bac05d0551d805552db1c6000262f2234897b174e467a394a85a058af" dependencies = [ "base64", "cgmath", @@ -1101,9 +1101,9 @@ dependencies = [ [[package]] name = "hugr-llvm" -version = "0.22.0" +version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e53465db105b2567727464d2cda16caf9fceec7539ad0d13fbfacfd901f971d" +checksum = "a986b3b9cdb4c61f0a0c5980eee0c0e89839edfb1bc361025eb5aab767d110b7" dependencies = [ "anyhow", "delegate 0.13.4", @@ -1121,9 +1121,9 @@ dependencies = [ [[package]] name = "hugr-model" -version = "0.22.0" +version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb50c9602d7c8423e89a643b69a24c9247c8d5f6a43547344c1968eec14a682c" +checksum = "64698533246838bfcd11f5ad5823871e03135e98fe858b8971c7da37df328c8a" dependencies = [ "base64", "bumpalo", @@ -1143,9 +1143,9 @@ dependencies = [ [[package]] name = "hugr-passes" -version = "0.22.0" +version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d3b96fcb729cea2c49172ecfc649d908dd3574938cce95e3096a3dc9132c73b" +checksum = "b4561dfe756a0bc7e25c905255d056d7c73bbd6ac88552b4c1d9778b6140cc87" dependencies = [ "ascent", "derive_more 1.0.0", diff --git a/Cargo.toml b/Cargo.toml index 69a8105cb..012152589 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,9 +47,9 @@ large_enum_variant = "allow" [workspace.dependencies] # Make sure to run `just recompile-eccs` if the hugr serialisation format changes. -hugr = "0.22.0" -hugr-core = "0.22.0" -hugr-cli = "0.22.0" +hugr = "0.22.1" +hugr-core = "0.22.1" +hugr-cli = "0.22.1" portgraph = "0.15.1" pyo3 = ">= 0.23.4, < 0.26" itertools = "0.14.0" diff --git a/tket-qsystem/src/lib.rs b/tket-qsystem/src/lib.rs index 7701f371a..6bb4f8fa3 100644 --- a/tket-qsystem/src/lib.rs +++ b/tket-qsystem/src/lib.rs @@ -13,6 +13,7 @@ use hugr::{ hugr::{hugrmut::HugrMut, HugrError}, Hugr, HugrView, Node, }; +use lower_drops::LowerDropsPass; use replace_bools::{ReplaceBoolPass, ReplaceBoolPassError}; use tket::TketOp; @@ -29,7 +30,7 @@ pub mod cli; pub mod extension; #[cfg(feature = "llvm")] pub mod llvm; - +mod lower_drops; pub mod replace_bools; /// Modify a [hugr::Hugr] into a form that is acceptable for ingress into a @@ -117,7 +118,12 @@ impl QSystemPass { if self.lazify { self.replace_bools().run(hugr)?; } + + // We expect any Hugr will have *either* drop ops, or ValueArrays (without drops), + // so only one of these passes will do anything; the order is thus immaterial. + self.lower_drops().run(hugr)?; self.linearize_arrays().run(hugr)?; + #[cfg(feature = "llvm")] { // TODO: Remove "llvm" feature gate once `inline_constant_functions` is moved to @@ -191,6 +197,10 @@ impl QSystemPass { MonomorphizePass } + fn lower_drops(&self) -> LowerDropsPass { + LowerDropsPass + } + fn linearize_arrays(&self) -> LinearizeArrayPass { LinearizeArrayPass::default() } diff --git a/tket-qsystem/src/lower_drops.rs b/tket-qsystem/src/lower_drops.rs new file mode 100644 index 000000000..fa09ce130 --- /dev/null +++ b/tket-qsystem/src/lower_drops.rs @@ -0,0 +1,72 @@ +/// Contains a pass to lower "drop" ops from the Guppy extension +use hugr::algorithms::replace_types::{NodeTemplate, ReplaceTypesError, ReplacementOptions}; +use hugr::algorithms::{ComposablePass, ReplaceTypes}; +use hugr::builder::{Container, DFGBuilder}; +use hugr::types::{Signature, Term}; +use hugr::{hugr::hugrmut::HugrMut, Node}; +use tket::extension::guppy::{DROP_OP_NAME, GUPPY_EXTENSION}; + +/// A pass that lowers "drop" ops from [GUPPY_EXTENSION] +#[derive(Default, Debug, Clone)] +pub struct LowerDropsPass; + +impl> ComposablePass for LowerDropsPass { + type Error = ReplaceTypesError; + + /// Returns whether any drops were lowered + type Result = bool; + + fn run(&self, hugr: &mut H) -> Result { + let mut rt = ReplaceTypes::default(); + rt.replace_parametrized_op_with( + GUPPY_EXTENSION.get_op(DROP_OP_NAME.as_str()).unwrap(), + |targs| { + let [Term::Runtime(ty)] = targs else { + panic!("Expected just one type") + }; + // The Hugr here is invalid, so we have to pull it out manually + let mut dfb = DFGBuilder::new(Signature::new(ty.clone(), vec![])).unwrap(); + let h = std::mem::take(dfb.hugr_mut()); + Some(NodeTemplate::CompoundOp(Box::new(h))) + }, + ReplacementOptions::default().with_linearization(true), + ); + rt.run(hugr) + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use hugr::builder::{inout_sig, Dataflow, DataflowHugr}; + use hugr::ops::ExtensionOp; + use hugr::{extension::prelude::usize_t, std_extensions::collections::array::array_type}; + use hugr::{Hugr, HugrView}; + + use super::*; + + #[test] + fn test_lower_drop() { + let arr_type = array_type(2, usize_t()); + let drop_op = GUPPY_EXTENSION.get_op(DROP_OP_NAME.as_str()).unwrap(); + let drop_node = ExtensionOp::new(drop_op.clone(), [arr_type.clone().into()]).unwrap(); + let mut b = DFGBuilder::new(inout_sig(arr_type, vec![])).unwrap(); + let inp = b.input_wires(); + b.add_dataflow_op(drop_node, inp).unwrap(); + let mut h = b.finish_hugr_with_outputs([]).unwrap(); + let count_drops = |h: &Hugr| { + h.nodes() + .filter(|n| { + h.get_optype(*n) + .as_extension_op() + .is_some_and(|e| Arc::ptr_eq(e.def_arc(), drop_op)) + }) + .count() + }; + assert_eq!(count_drops(&h), 1); + LowerDropsPass.run(&mut h).unwrap(); + h.validate().unwrap(); + assert_eq!(count_drops(&h), 0); + } +}