Skip to content

Commit 097e8c5

Browse files
committed
feat: add llvm codegen for is_borrowed
1 parent a680c89 commit 097e8c5

1 file changed

Lines changed: 109 additions & 5 deletions

File tree

hugr-llvm/src/extension/collections/borrow_array.rs

Lines changed: 109 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,8 @@ enum MaskCheck {
572572
CheckNotBorrowed,
573573
/// Check the element is not borrowed, panicking if it is; then mark as borrowed.
574574
Borrow,
575+
/// Check if the element is borrowed, returning a boolean. (Do not change the bit.)
576+
IsBorrowed,
575577
}
576578

577579
impl MaskCheck {
@@ -580,9 +582,17 @@ impl MaskCheck {
580582
MaskCheck::Return => "__barray_mask_return",
581583
MaskCheck::CheckNotBorrowed => "__barray_mask_check_not_borrowed",
582584
MaskCheck::Borrow => "__barray_mask_borrow",
585+
MaskCheck::IsBorrowed => "__barray_mask_is_borrowed",
583586
}
584587
}
585588

589+
fn return_type<'c, H: HugrView<Node = Node>>(
590+
&self,
591+
ctx: &EmitFuncContext<'c, '_, H>,
592+
) -> Option<BasicTypeEnum<'c>> {
593+
(*self == MaskCheck::IsBorrowed).then(|| ctx.iw_context().bool_type().into())
594+
}
595+
586596
/// Generate code to perform the check on the specified bit of the mask.
587597
/// (Does not check the index is in bounds.)
588598
fn emit<'c, H: HugrView<Node = Node>>(
@@ -591,12 +601,12 @@ impl MaskCheck {
591601
ctx: &mut EmitFuncContext<'c, '_, H>,
592602
mask_ptr: PointerValue<'c>,
593603
idx: IntValue<'c>,
594-
) -> Result<()> {
604+
) -> Result<Option<BasicValueEnum<'c>>> {
595605
get_or_make_function(
596606
ctx,
597607
self.func_name(),
598608
[mask_ptr.into(), idx.into()],
599-
None,
609+
self.return_type(ctx),
600610
|ctx, [mask_ptr, idx]| {
601611
// Compute mask bitarray block index via `idx // BLOCK_SIZE`
602612
let mask_ptr = mask_ptr.into_pointer_value();
@@ -613,10 +623,15 @@ impl MaskCheck {
613623
let block_shifted = builder.build_right_shift(block, idx_in_block, false, "")?;
614624
let bit =
615625
builder.build_int_truncate(block_shifted, ctx.iw_context().bool_type(), "")?;
626+
if *self == MaskCheck::IsBorrowed {
627+
// Just return the bit
628+
return Ok(Some(bit.as_basic_value_enum()));
629+
}
616630
let panic_bb = ctx.build_positioned_new_block("panic", None, |ctx, panic_bb| {
617631
let err: &ConstError = match self {
618632
MaskCheck::CheckNotBorrowed | MaskCheck::Borrow => &ERR_ALREADY_BORROWED,
619633
MaskCheck::Return => &ERR_NOT_BORROWED,
634+
MaskCheck::IsBorrowed => unreachable!("handled above"),
620635
};
621636
let err_val = ctx.emit_custom_const(err).unwrap();
622637
ccg.emit_panic(ctx, err_val)?;
@@ -641,13 +656,13 @@ impl MaskCheck {
641656
let (if_borrowed, if_present) = match self {
642657
MaskCheck::CheckNotBorrowed | MaskCheck::Borrow => (panic_bb, ok_bb),
643658
MaskCheck::Return => (ok_bb, panic_bb),
659+
MaskCheck::IsBorrowed => unreachable!("handled above"),
644660
};
645661
ctx.builder()
646662
.build_conditional_branch(bit, if_borrowed, if_present)?;
647663
Ok(None)
648664
},
649-
)?;
650-
Ok(())
665+
)
651666
}
652667
}
653668

@@ -1570,6 +1585,22 @@ pub fn emit_barray_unsafe_op<'c, H: HugrView<Node = Node>>(
15701585
let (_, array_v) = build_barray_alloc(ctx, ccg, elem_ty, size, true)?;
15711586
outputs.finish(ctx.builder(), [array_v.into()])
15721587
}
1588+
BArrayUnsafeOpDef::is_borrowed => {
1589+
let [array_v, index_v] = inputs
1590+
.try_into()
1591+
.map_err(|_| anyhow!("BArrayUnsafeOpDef::is_borrowed expects two arguments"))?;
1592+
let BArrayFatPtrComponents {
1593+
mask_ptr, offset, ..
1594+
} = decompose_barray_fat_pointer(builder, array_v)?;
1595+
let index_v = index_v.into_int_value();
1596+
build_bounds_check(ccg, ctx, size, index_v)?;
1597+
let offset_index_v = ctx.builder().build_int_add(index_v, offset, "")?;
1598+
// let bit = build_is_borrowed_check(ctx, mask_ptr, offset_index_v)?;
1599+
let bit = MaskCheck::IsBorrowed
1600+
.emit(ccg, ctx, mask_ptr, offset_index_v)?
1601+
.expect("IsBorrowed always returns a value");
1602+
outputs.finish(ctx.builder(), [bit.into(), array_v])
1603+
}
15731604
_ => todo!(),
15741605
}
15751606
}
@@ -2634,7 +2665,7 @@ mod test {
26342665
// - Pops specified numbers from the left to introduce an offset
26352666
// - Converts it into a regular array
26362667
// - Converts it back into a borrow array
2637-
// - Borrows alls elements, sums them up, and returns the sum
2668+
// - Borrows all elements, sums them up, and returns the sum
26382669

26392670
let int_ty = int_type(6);
26402671
let hugr = SimpleHugrConfig::new()
@@ -2908,4 +2939,77 @@ mod test {
29082939
let msg = "Some array elements have been borrowed";
29092940
assert_eq!(&exec_ctx.exec_hugr_panicking(hugr, "main"), msg);
29102941
}
2942+
2943+
#[rstest]
2944+
fn exec_is_borrowed_basic(mut exec_ctx: TestContext) {
2945+
// We build a HUGR that:
2946+
// - Creates a borrow array [1,2,3]
2947+
// - Borrows index 1
2948+
// - Checks is_borrowed for indices 0, 1
2949+
// - Returns 1 if [false, true], else 0
2950+
let int_ty = int_type(6);
2951+
let size = 3;
2952+
let hugr = SimpleHugrConfig::new()
2953+
.with_outs(int_ty.clone())
2954+
.with_extensions(exec_registry())
2955+
.finish(|mut builder| {
2956+
let barray = borrow_array::BArrayValue::new(
2957+
int_ty.clone(),
2958+
(1..=3)
2959+
.map(|i| ConstInt::new_u(6, i).unwrap().into())
2960+
.collect_vec(),
2961+
);
2962+
let barray = builder.add_load_value(barray);
2963+
let idx1 = builder.add_load_value(ConstUsize::new(1));
2964+
let (_, barray) = builder
2965+
.add_borrow_array_borrow(int_ty.clone(), size, barray, idx1)
2966+
.unwrap();
2967+
2968+
let idx0 = builder.add_load_value(ConstUsize::new(0));
2969+
let (arr, b0_bools) =
2970+
[idx0, idx1]
2971+
.iter()
2972+
.fold((barray, Vec::new()), |(arr, mut bools), idx| {
2973+
let (b, arr) = builder
2974+
.add_is_borrowed(int_ty.clone(), size, arr, *idx)
2975+
.unwrap();
2976+
bools.push(b);
2977+
(arr, bools)
2978+
});
2979+
let [b0, b1] = b0_bools.try_into().unwrap();
2980+
2981+
let b0 = builder.add_not(b0).unwrap(); // flip b0 to true
2982+
let and01 = builder.add_and(b0, b1).unwrap();
2983+
let one = builder.add_load_value(ConstInt::new_u(6, 1).unwrap());
2984+
let zero = builder.add_load_value(ConstInt::new_u(6, 0).unwrap());
2985+
let mut cond = builder
2986+
.conditional_builder(
2987+
([type_row![], type_row![]], and01),
2988+
[],
2989+
int_ty.clone().into(),
2990+
)
2991+
.unwrap();
2992+
cond.case_builder(0)
2993+
.unwrap()
2994+
.finish_with_outputs([zero])
2995+
.unwrap();
2996+
cond.case_builder(1)
2997+
.unwrap()
2998+
.finish_with_outputs([one])
2999+
.unwrap();
3000+
let out = cond.finish_sub_container().unwrap().out_wire(0);
3001+
builder
3002+
.add_borrow_array_discard(int_ty.clone(), size, arr)
3003+
.unwrap();
3004+
builder.finish_hugr_with_outputs([out]).unwrap()
3005+
});
3006+
3007+
exec_ctx.add_extensions(|cge| {
3008+
cge.add_default_prelude_extensions()
3009+
.add_logic_extensions()
3010+
.add_default_borrow_array_extensions(DefaultPreludeCodegen)
3011+
.add_default_int_extensions()
3012+
});
3013+
assert_eq!(1, exec_ctx.exec_hugr_u64(hugr, "main"));
3014+
}
29113015
}

0 commit comments

Comments
 (0)