Skip to content

Commit 2d1c8c3

Browse files
committed
Add codegen for conversion ops
1 parent 330f3ba commit 2d1c8c3

File tree

1 file changed

+214
-17
lines changed

1 file changed

+214
-17
lines changed

hugr-llvm/src/extension/collections/borrow_array.rs

Lines changed: 214 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,13 @@ use std::sync::LazyLock;
2323

2424
use anyhow::{Ok, Result, anyhow};
2525
use hugr_core::extension::prelude::{ConstError, option_type, usize_t};
26-
use hugr_core::extension::simple_op::{MakeExtensionOp, MakeRegisteredOp};
26+
use hugr_core::extension::simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp};
2727
use hugr_core::ops::DataflowOpTrait;
2828
use hugr_core::std_extensions::collections::array;
2929
use hugr_core::std_extensions::collections::borrow_array::{
30-
self, BArrayClone, BArrayDiscard, BArrayOp, BArrayOpDef, BArrayRepeat, BArrayScan,
31-
BArrayUnsafeOp, BArrayUnsafeOpDef, borrow_array_type,
30+
self, BArrayClone, BArrayDiscard, BArrayFromArray, BArrayFromArrayDef, BArrayOp, BArrayOpDef,
31+
BArrayRepeat, BArrayScan, BArrayToArray, BArrayToArrayDef, BArrayUnsafeOp, BArrayUnsafeOpDef,
32+
borrow_array_type,
3233
};
3334
use hugr_core::types::{TypeArg, TypeEnum};
3435
use hugr_core::{HugrView, Node};
@@ -45,6 +46,7 @@ use crate::emit::emit_value;
4546
use crate::emit::func::get_or_make_function;
4647
use crate::emit::libc::{emit_libc_free, emit_libc_malloc};
4748
use crate::extension::PreludeCodegen;
49+
use crate::extension::collections::array::{build_array_fat_pointer, decompose_array_fat_pointer};
4850
use crate::{CodegenExtension, CodegenExtsBuilder};
4951
use crate::{
5052
emit::{EmitFuncContext, RowPromise, deaggregate_call_result},
@@ -225,6 +227,26 @@ pub trait BorrowArrayCodegen: Clone {
225227
initial_accs,
226228
)
227229
}
230+
231+
/// Emit a [`hugr_core::std_extensions::collections::borrow_array::BArrayToArray`].
232+
fn emit_to_array_op<'c, H: HugrView<Node = Node>>(
233+
&self,
234+
ctx: &mut EmitFuncContext<'c, '_, H>,
235+
op: BArrayToArray,
236+
barray_v: BasicValueEnum<'c>,
237+
) -> Result<BasicValueEnum<'c>> {
238+
emit_to_array_op(self, ctx, op, barray_v)
239+
}
240+
241+
/// Emit a [`hugr_core::std_extensions::collections::borrow_array::BArrayFromArray`].
242+
fn emit_from_array_op<'c, H: HugrView<Node = Node>>(
243+
&self,
244+
ctx: &mut EmitFuncContext<'c, '_, H>,
245+
op: BArrayFromArray,
246+
array_v: BasicValueEnum<'c>,
247+
) -> Result<BasicValueEnum<'c>> {
248+
emit_from_array_op(self, ctx, op, array_v)
249+
}
228250
}
229251

230252
/// A trivial implementation of [`BorrowArrayCodegen`] which passes all methods
@@ -349,6 +371,32 @@ impl<CCG: BorrowArrayCodegen> CodegenExtension for BorrowArrayCodegenExtension<C
349371
.finish(context.builder(), iter::once(tgt_array).chain(final_accs))
350372
}
351373
})
374+
.extension_op(
375+
borrow_array::EXTENSION_ID,
376+
BArrayToArrayDef::new().opdef_id(),
377+
{
378+
let ccg = self.0.clone();
379+
move |context, args| {
380+
let barray = args.inputs[0];
381+
let op = BArrayToArray::from_extension_op(args.node().as_ref())?;
382+
let array = ccg.emit_to_array_op(context, op, barray)?;
383+
args.outputs.finish(context.builder(), [array])
384+
}
385+
},
386+
)
387+
.extension_op(
388+
borrow_array::EXTENSION_ID,
389+
BArrayFromArrayDef::new().opdef_id(),
390+
{
391+
let ccg = self.0.clone();
392+
move |context, args| {
393+
let array = args.inputs[0];
394+
let op = BArrayFromArray::from_extension_op(args.node().as_ref())?;
395+
let barray = ccg.emit_from_array_op(context, op, array)?;
396+
args.outputs.finish(context.builder(), [barray])
397+
}
398+
},
399+
)
352400
}
353401
}
354402

@@ -452,34 +500,39 @@ pub fn build_barray_alloc<'c, H: HugrView<Node = Node>>(
452500
.builder()
453501
.build_bit_cast(mask_ptr, usize_t.ptr_type(AddressSpace::default()), "")?
454502
.into_pointer_value();
455-
// Initialise mask using memset
503+
fill_mask(ctx, mask_ptr, mask_size_value, set_borrowed)?;
504+
505+
let offset = usize_t.const_zero();
506+
let array_v = build_barray_fat_pointer(ctx, elem_ptr, mask_ptr, offset)?;
507+
Ok((elem_ptr, mask_ptr, array_v))
508+
}
509+
510+
/// Emits instructions to fill the entire mask with a bit value.
511+
fn fill_mask<H: HugrView<Node = Node>>(
512+
ctx: &mut EmitFuncContext<H>,
513+
mask_ptr: PointerValue,
514+
size: IntValue,
515+
value: bool,
516+
) -> Result<()> {
456517
let memset_intrinsic = Intrinsic::find("llvm.memset").unwrap();
457518
let memset = memset_intrinsic
458519
.get_declaration(
459520
ctx.get_current_module(),
460-
&[mask_ptr.get_type().into(), usize_t.into()],
521+
&[mask_ptr.get_type().into(), size.get_type().into()],
461522
)
462523
.unwrap();
463-
let val = if set_borrowed {
524+
let val = if value {
464525
ctx.iw_context().i8_type().const_all_ones()
465526
} else {
466527
ctx.iw_context().i8_type().const_zero()
467528
};
468529
let volatile = ctx.iw_context().bool_type().const_zero().into();
469530
ctx.builder().build_call(
470531
memset,
471-
&[
472-
mask_ptr.into(),
473-
val.into(),
474-
mask_size_value.into(),
475-
volatile,
476-
],
532+
&[mask_ptr.into(), val.into(), size.into(), volatile],
477533
"",
478534
)?;
479-
480-
let offset = usize_t.const_zero();
481-
let array_v = build_barray_fat_pointer(ctx, elem_ptr, mask_ptr, offset)?;
482-
Ok((elem_ptr, mask_ptr, array_v))
535+
Ok(())
483536
}
484537

485538
fn inspect_mask_idx<'c, H: HugrView<Node = Node>>(
@@ -1489,16 +1542,56 @@ pub fn emit_barray_unsafe_op<'c, H: HugrView<Node = Node>>(
14891542
}
14901543
}
14911544

1545+
/// Emits an [`BArrayToArray`] op.
1546+
pub fn emit_to_array_op<'c, H: HugrView<Node = Node>>(
1547+
ccg: &impl BorrowArrayCodegen,
1548+
ctx: &mut EmitFuncContext<'c, '_, H>,
1549+
op: BArrayToArray,
1550+
barray_v: BasicValueEnum<'c>,
1551+
) -> Result<BasicValueEnum<'c>> {
1552+
let (ptr, mask_ptr, offset) = decompose_barray_fat_pointer(ctx.builder(), barray_v)?;
1553+
build_none_borrowed_check(ccg, ctx, mask_ptr, offset, op.size)?;
1554+
Ok(build_array_fat_pointer(ctx, ptr, offset)?.into())
1555+
}
1556+
1557+
/// Emits an [`BArrayFromArray`] op.
1558+
pub fn emit_from_array_op<'c, H: HugrView<Node = Node>>(
1559+
ccg: &impl BorrowArrayCodegen,
1560+
ctx: &mut EmitFuncContext<'c, '_, H>,
1561+
op: BArrayFromArray,
1562+
array_v: BasicValueEnum<'c>,
1563+
) -> Result<BasicValueEnum<'c>> {
1564+
// We reuse the allocation from the array but we have to allocate a fresh mask.
1565+
// Note that the mask must have size at least `size + offset` so the offsets match up.
1566+
let usize_t = usize_ty(&ctx.typing_session());
1567+
let builder = ctx.builder();
1568+
let (ptr, offset) = decompose_array_fat_pointer(builder, array_v)?;
1569+
let size = usize_t.const_int(op.size, false);
1570+
let mask_bits = builder.build_int_add(size, offset, "")?;
1571+
let mask_blocks = builder.build_int_unsigned_div(mask_bits, usize_t.size_of(), "")?;
1572+
// Increment by one to account for potential rounding down
1573+
let mask_blocks = builder.build_int_add(mask_blocks, usize_t.const_int(1, false), "")?;
1574+
let mask_size = builder.build_int_mul(mask_blocks, usize_t.size_of(), "")?;
1575+
let mask_ptr = ccg.emit_allocate_array(ctx, mask_size)?;
1576+
let mask_ptr = ctx
1577+
.builder()
1578+
.build_bit_cast(mask_ptr, usize_t.ptr_type(AddressSpace::default()), "")?
1579+
.into_pointer_value();
1580+
fill_mask(ctx, mask_ptr, mask_size, false)?;
1581+
Ok(build_barray_fat_pointer(ctx, ptr, mask_ptr, offset)?.into())
1582+
}
1583+
14921584
#[cfg(test)]
14931585
mod test {
14941586
use hugr_core::Wire;
14951587
use hugr_core::builder::{DataflowHugr, HugrBuilder};
14961588
use hugr_core::extension::prelude::either_type;
14971589
use hugr_core::ops::Tag;
14981590
use hugr_core::std_extensions::STD_REG;
1591+
use hugr_core::std_extensions::collections::array::ArrayOpBuilder;
14991592
use hugr_core::std_extensions::collections::array::op_builder::build_all_borrow_array_ops;
15001593
use hugr_core::std_extensions::collections::borrow_array::{
1501-
self, BArrayOpBuilder, BArrayRepeat, BArrayScan, borrow_array_type,
1594+
self, BArrayOpBuilder, BArrayRepeat, BArrayScan, BArrayToArray, borrow_array_type,
15021595
};
15031596
use hugr_core::types::Type;
15041597
use hugr_core::{
@@ -2446,6 +2539,72 @@ mod test {
24462539
assert_eq!(expected, exec_ctx.exec_hugr_u64(hugr, "main"));
24472540
}
24482541

2542+
#[rstest]
2543+
#[case::basic(32, 0)]
2544+
#[case::boundary(65, 0)]
2545+
#[case::pop1(65, 10)]
2546+
#[case::pop2(200, 32)]
2547+
fn exec_conversion_roundtrip(
2548+
mut exec_ctx: TestContext,
2549+
#[case] mut size: u64,
2550+
#[case] num_pops: u64,
2551+
) {
2552+
// We build a HUGR that:
2553+
// - Loads a borrow array filled with 0..size
2554+
// - Pops specified numbers from the left to introduce an offset
2555+
// - Converts it into a regular array
2556+
// - Converts it back into a borrow array
2557+
// - Borrows alls elements, sums them up, and returns the sum
2558+
2559+
let int_ty = int_type(6);
2560+
let hugr = SimpleHugrConfig::new()
2561+
.with_outs(int_ty.clone())
2562+
.with_extensions(exec_registry())
2563+
.finish(|mut builder| {
2564+
use hugr_core::std_extensions::collections::borrow_array::BArrayFromArray;
2565+
2566+
let barray = borrow_array::BArrayValue::new(
2567+
int_ty.clone(),
2568+
(0..size)
2569+
.map(|i| ConstInt::new_u(6, i).unwrap().into())
2570+
.collect_vec(),
2571+
);
2572+
let barray = builder.add_load_value(barray);
2573+
let barray = build_pops(&mut builder, int_ty.clone(), size, barray, num_pops);
2574+
size -= num_pops;
2575+
let array = builder
2576+
.add_dataflow_op(BArrayToArray::new(int_ty.clone(), size), [barray])
2577+
.unwrap()
2578+
.out_wire(0);
2579+
let mut barray = builder
2580+
.add_dataflow_op(BArrayFromArray::new(int_ty.clone(), size), [array])
2581+
.unwrap()
2582+
.out_wire(0);
2583+
2584+
let mut r = builder.add_load_value(ConstInt::new_u(6, 0).unwrap());
2585+
for i in 0..size {
2586+
let i = builder.add_load_value(ConstUsize::new(i));
2587+
let (val, arr) = builder
2588+
.add_borrow_array_borrow(int_ty.clone(), size, barray, i)
2589+
.unwrap();
2590+
r = builder.add_iadd(6, r, val).unwrap();
2591+
barray = arr;
2592+
}
2593+
builder
2594+
.add_discard_all_borrowed(int_ty.clone(), size, barray)
2595+
.unwrap();
2596+
builder.finish_hugr_with_outputs([r]).unwrap()
2597+
});
2598+
exec_ctx.add_extensions(|cge| {
2599+
cge.add_default_prelude_extensions()
2600+
.add_default_borrow_array_extensions(DefaultPreludeCodegen)
2601+
.add_default_array_extensions()
2602+
.add_default_int_extensions()
2603+
});
2604+
let expected: u64 = (num_pops..size + num_pops).sum();
2605+
assert_eq!(expected, exec_ctx.exec_hugr_u64(hugr, "main"));
2606+
}
2607+
24492608
#[rstest]
24502609
#[case::oob(1, 0, [0, 1], "Index out of bounds")]
24512610
#[case::double_borrow(32, 0, [0, 0], "Array element is already borrowed")]
@@ -2580,4 +2739,42 @@ mod test {
25802739
let msg = "Array contains non-borrowed elements and cannot be discarded";
25812740
assert_eq!(&exec_ctx.exec_hugr_panicking(hugr, "main"), msg);
25822741
}
2742+
2743+
#[rstest]
2744+
fn exec_to_array_panic(mut exec_ctx: TestContext) {
2745+
let int_ty = int_type(6);
2746+
let size = 10;
2747+
let hugr = SimpleHugrConfig::new()
2748+
.with_extensions(exec_registry())
2749+
.finish(|mut builder| {
2750+
let barray = borrow_array::BArrayValue::new(
2751+
int_ty.clone(),
2752+
(0..size)
2753+
.map(|i| ConstInt::new_u(6, i).unwrap().into())
2754+
.collect_vec(),
2755+
);
2756+
let barray = builder.add_load_value(barray);
2757+
let idx = builder.add_load_value(ConstUsize::new(0));
2758+
let (_, barray) = builder
2759+
.add_borrow_array_borrow(int_ty.clone(), size, barray, idx)
2760+
.unwrap();
2761+
let array = builder
2762+
.add_dataflow_op(BArrayToArray::new(int_ty.clone(), size), [barray])
2763+
.unwrap()
2764+
.out_wire(0);
2765+
builder
2766+
.add_array_discard(int_ty.clone(), size, array)
2767+
.unwrap();
2768+
builder.finish_hugr_with_outputs([]).unwrap()
2769+
});
2770+
2771+
exec_ctx.add_extensions(|cge| {
2772+
cge.add_prelude_extensions(PanicTestPreludeCodegen)
2773+
.add_default_borrow_array_extensions(PanicTestPreludeCodegen)
2774+
.add_default_array_extensions()
2775+
.add_default_int_extensions()
2776+
});
2777+
let msg = "Some array elements have been borrowed";
2778+
assert_eq!(&exec_ctx.exec_hugr_panicking(hugr, "main"), msg);
2779+
}
25832780
}

0 commit comments

Comments
 (0)