diff --git a/Cargo.lock b/Cargo.lock index 25787d63cd..faa509b828 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1267,6 +1267,7 @@ dependencies = [ "derive_more 2.0.1", "hugr-core", "hugr-llvm", + "hugr-passes", "inkwell", "insta", "itertools 0.14.0", diff --git a/hugr-llvm/Cargo.toml b/hugr-llvm/Cargo.toml index 9c0f6d8788..0282b9a6a6 100644 --- a/hugr-llvm/Cargo.toml +++ b/hugr-llvm/Cargo.toml @@ -40,6 +40,7 @@ derive_more = { workspace = true, features = ["debug"] } [dev-dependencies] hugr-llvm = { "path" = ".", features = ["test-utils"] } +hugr-passes = { path = "../hugr-passes" } [build-dependencies] cc = "1.2.41" diff --git a/hugr-llvm/src/extension/collections/borrow_array.rs b/hugr-llvm/src/extension/collections/borrow_array.rs index f69bbee247..df2e18aefa 100644 --- a/hugr-llvm/src/extension/collections/borrow_array.rs +++ b/hugr-llvm/src/extension/collections/borrow_array.rs @@ -3039,4 +3039,48 @@ mod test { }); assert_eq!(1, exec_ctx.exec_hugr_u64(hugr, "main")); } + + #[rstest] + fn exec_discard_part_borrowed(mut exec_ctx: TestContext) { + use hugr_passes::replace_types::{DelegatingLinearizer, Linearizer}; + // Builds a HUGR that: + // - Creates a borrow array [1,2,3] + // - Borrows index 1 + // - Discards the borrow array using the ReplaceTypes linearizer + // And then runs this, i.e. to check that it does not panic. + let inn_arr_ty = borrow_array_type(2, usize_t()); + let arr_ty = borrow_array_type(3, inn_arr_ty.clone()); + let hugr = SimpleHugrConfig::new() + .with_outs(int_type(6)) + .with_extensions(exec_registry()) + .finish(|mut builder| { + let inner_arrays = [1, 3, 5].map(|i| { + let elems = [i, i + 1].map(|v| builder.add_load_value(ConstUsize::new(v))); + builder.add_new_borrow_array(usize_t(), elems).unwrap() + }); + let outer = builder + .add_new_borrow_array(inn_arr_ty.clone(), inner_arrays) + .unwrap(); + let idx = builder.add_load_value(ConstUsize::new(0)); + let (outer, inner) = builder + .add_borrow_array_borrow(inn_arr_ty, 3, outer, idx) + .unwrap(); + builder + .add_borrow_array_discard(usize_t(), 2, inner) + .unwrap(); + let dl = DelegatingLinearizer::default(); + let nt = dl.copy_discard_op(&arr_ty, 0).unwrap(); + nt.add(&mut builder, [outer]).unwrap(); + let res = builder.add_load_value(ConstInt::new_u(6, 17).unwrap()); + builder.finish_hugr_with_outputs([res]).unwrap() + }); + exec_ctx.add_extensions(|cge| { + cge.add_default_prelude_extensions() + .add_logic_extensions() + .add_conversion_extensions() + .add_default_borrow_array_extensions(DefaultPreludeCodegen) + .add_default_int_extensions() + }); + assert_eq!(17, exec_ctx.exec_hugr_u64(hugr, "main")); + } } diff --git a/hugr-passes/src/replace_types/handlers.rs b/hugr-passes/src/replace_types/handlers.rs index 327c78d28c..b11c49e90a 100644 --- a/hugr-passes/src/replace_types/handlers.rs +++ b/hugr-passes/src/replace_types/handlers.rs @@ -1,7 +1,9 @@ //! Callbacks for use with [`ReplaceTypes::replace_consts_parametrized`] //! and [`DelegatingLinearizer::register_callback`](super::DelegatingLinearizer::register_callback) -use hugr_core::builder::{DFGBuilder, Dataflow, DataflowHugr, endo_sig, inout_sig}; +use hugr_core::builder::{ + DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer, endo_sig, inout_sig, +}; use hugr_core::extension::prelude::{UnwrapBuilder, option_type}; use hugr_core::ops::constant::CustomConst; use hugr_core::ops::{OpTrait, OpType, Tag}; @@ -389,8 +391,119 @@ pub fn copy_discard_borrow_array( dfb.finish_hugr_with_outputs(outs).unwrap() }))) } + } else if num_outports == 0 { + // Override "generic" array discard to only discard non-borrowed elements. + let elem_discard = lin.copy_discard_op(ty, 0)?; + let array_ty = || borrow_array_type(*n, ty.clone()); + let i64_t = || INT_TYPES[6].clone(); + let mut dfb = DFGBuilder::new(inout_sig(array_ty(), type_row![])).unwrap(); + let [in_array] = dfb.input_wires_arr(); + let zero = dfb.add_load_value(ConstInt::new_u(6, 0).unwrap()); + let one = dfb.add_load_value(ConstInt::new_u(6, 1).unwrap()); + let len = dfb.add_load_value(ConstInt::new_u(6, *n).unwrap()); + + // Loop through the elements, discarding as necessary + let mut tl = dfb + .tail_loop_builder([(i64_t(), zero), (array_ty(), in_array)], [], type_row![]) + .unwrap(); + let [idx, arr] = tl.input_wires_arr(); + let [in_range] = tl + .add_dataflow_op(IntOpDef::ilt_u.with_log_width(6), [idx, len]) + .unwrap() + .outputs_arr(); + let loop_variants = vec![vec![i64_t(), array_ty()].into(), type_row![]]; + let mut cond = tl + .conditional_builder( + (vec![type_row![]; 2], in_range), + [(array_ty(), arr)], + Type::new_sum(loop_variants.clone()).into(), + ) + .unwrap(); + { + // reached end of the array - discard_all_borrowed and exit loop + let mut out_range = cond.case_builder(0).unwrap(); + let [arr] = out_range.input_wires_arr(); + let () = out_range + .add_discard_all_borrowed(ty.clone(), *n, arr) + .unwrap(); + let res = out_range + .add_dataflow_op(Tag::new(1, loop_variants.clone()), []) + .unwrap(); + out_range.finish_with_outputs(res.outputs()).unwrap(); + } + { + // Valid index - check if borrowed + let mut in_range = cond.case_builder(1).unwrap(); + let [arr] = in_range.input_wires_arr(); + let [idx_u] = in_range + .add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [idx]) + .unwrap() + .outputs_arr(); + let (arr, is_borrowed) = in_range + .add_is_borrowed(ty.clone(), *n, arr, idx_u) + .unwrap(); + let mut cond2 = in_range + .conditional_builder( + (vec![type_row![]; 2], is_borrowed), + [(array_ty(), arr)], + array_ty().into(), + ) + .unwrap(); + { + // borrowed - do nothing + let borrowed_case = cond2.case_builder(1).unwrap(); + let [arr] = borrowed_case.input_wires_arr(); + borrowed_case.finish_with_outputs([arr]).unwrap(); + } + { + // not borrowed - discard element + let mut not_borrowed_case = cond2.case_builder(0).unwrap(); + let [arr] = not_borrowed_case.input_wires_arr(); + let (arr, elem) = not_borrowed_case + .add_borrow_array_borrow(ty.clone(), *n, arr, idx_u) + .unwrap(); + elem_discard.add(&mut not_borrowed_case, [elem]).unwrap(); + not_borrowed_case.finish_with_outputs([arr]).unwrap(); + } + let [arr_out] = cond2.finish_sub_container().unwrap().outputs_arr(); + let [idx_out] = in_range + .add_dataflow_op(IntOpDef::iadd.with_log_width(6), [idx, one]) + .unwrap() + .outputs_arr(); + let res = in_range + .add_dataflow_op(Tag::new(0, loop_variants), [idx_out, arr_out]) + .unwrap(); + in_range.finish_with_outputs(res.outputs()).unwrap(); + } + let [loop_pred] = cond.finish_sub_container().unwrap().outputs_arr(); + let [] = tl.finish_with_outputs(loop_pred, []).unwrap().outputs_arr(); + let h = dfb.finish_hugr_with_outputs([]).unwrap(); + Ok(NodeTemplate::CompoundOp(Box::new(h))) } else { // For linear elements we have to fall back to the generic linearization implementation linearize_generic_array::(args, num_outports, lin) } } + +#[cfg(test)] +mod test { + use hugr_core::builder::{DFGBuilder, Dataflow, DataflowHugr}; + use hugr_core::{ + extension::prelude::usize_t, std_extensions::collections::borrow_array::borrow_array_type, + type_row, types::Signature, + }; + + use crate::replace_types::{DelegatingLinearizer, Linearizer}; + + #[test] + fn test_borrow_array_discard() { + let arr_ty = borrow_array_type(5, borrow_array_type(7, usize_t())); + let dl = DelegatingLinearizer::default(); + let mut dfb = DFGBuilder::new(Signature::new(arr_ty.clone(), type_row![])).unwrap(); + let nt = dl.copy_discard_op(&arr_ty, 0).unwrap(); + let ins = dfb.input_wires(); + nt.add(&mut dfb, ins).unwrap(); + + dfb.finish_hugr_with_outputs([]).unwrap(); + } +} diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index edc6128813..d7d2fab0c3 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -934,14 +934,14 @@ mod test { lowerer.run(&mut h).unwrap(); h.validate().unwrap(); let mut exts = h.nodes().filter_map(|n| h.get_optype(n).as_extension_op()); - assert!(exts.any(|eo| eo.qualified_id() == "collections.borrow_arr.discard")); + assert!(exts.any(|eo| eo.qualified_id() == "collections.borrow_arr.discard_all_borrowed")); // We can drop a borrow array of usize let mut h = build_hugr(borrow_array_type(4, usize_t())); lowerer.run(&mut h).unwrap(); h.validate().unwrap(); let mut exts = h.nodes().filter_map(|n| h.get_optype(n).as_extension_op()); - assert!(exts.any(|eo| eo.qualified_id() == "collections.borrow_arr.discard")); + assert!(exts.any(|eo| eo.qualified_id() == "collections.borrow_arr.discard_all_borrowed")); // We cannot drop a qubit let mut h = build_hugr(qb_t());