@@ -572,6 +572,8 @@ enum MaskCheck {
572572 CheckNotBorrowed ,
573573 /// Check the element is not borrowed, panicking if it is; then mark as borrowed.
574574 Borrow ,
575+ /// Check if the element is borrowed, returning a boolean. (Do not change the bit.)
576+ IsBorrowed ,
575577}
576578
577579impl MaskCheck {
@@ -580,9 +582,17 @@ impl MaskCheck {
580582 MaskCheck :: Return => "__barray_mask_return" ,
581583 MaskCheck :: CheckNotBorrowed => "__barray_mask_check_not_borrowed" ,
582584 MaskCheck :: Borrow => "__barray_mask_borrow" ,
585+ MaskCheck :: IsBorrowed => "__barray_mask_is_borrowed" ,
583586 }
584587 }
585588
589+ fn return_type < ' c , H : HugrView < Node = Node > > (
590+ & self ,
591+ ctx : & EmitFuncContext < ' c , ' _ , H > ,
592+ ) -> Option < BasicTypeEnum < ' c > > {
593+ ( * self == MaskCheck :: IsBorrowed ) . then ( || ctx. iw_context ( ) . bool_type ( ) . into ( ) )
594+ }
595+
586596 /// Generate code to perform the check on the specified bit of the mask.
587597 /// (Does not check the index is in bounds.)
588598 fn emit < ' c , H : HugrView < Node = Node > > (
@@ -591,12 +601,12 @@ impl MaskCheck {
591601 ctx : & mut EmitFuncContext < ' c , ' _ , H > ,
592602 mask_ptr : PointerValue < ' c > ,
593603 idx : IntValue < ' c > ,
594- ) -> Result < ( ) > {
604+ ) -> Result < Option < BasicValueEnum < ' c > > > {
595605 get_or_make_function (
596606 ctx,
597607 self . func_name ( ) ,
598608 [ mask_ptr. into ( ) , idx. into ( ) ] ,
599- None ,
609+ self . return_type ( ctx ) ,
600610 |ctx, [ mask_ptr, idx] | {
601611 // Compute mask bitarray block index via `idx // BLOCK_SIZE`
602612 let mask_ptr = mask_ptr. into_pointer_value ( ) ;
@@ -613,10 +623,15 @@ impl MaskCheck {
613623 let block_shifted = builder. build_right_shift ( block, idx_in_block, false , "" ) ?;
614624 let bit =
615625 builder. build_int_truncate ( block_shifted, ctx. iw_context ( ) . bool_type ( ) , "" ) ?;
626+ if * self == MaskCheck :: IsBorrowed {
627+ // Just return the bit
628+ return Ok ( Some ( bit. as_basic_value_enum ( ) ) ) ;
629+ }
616630 let panic_bb = ctx. build_positioned_new_block ( "panic" , None , |ctx, panic_bb| {
617631 let err: & ConstError = match self {
618632 MaskCheck :: CheckNotBorrowed | MaskCheck :: Borrow => & ERR_ALREADY_BORROWED ,
619633 MaskCheck :: Return => & ERR_NOT_BORROWED ,
634+ MaskCheck :: IsBorrowed => unreachable ! ( "handled above" ) ,
620635 } ;
621636 let err_val = ctx. emit_custom_const ( err) . unwrap ( ) ;
622637 ccg. emit_panic ( ctx, err_val) ?;
@@ -641,13 +656,13 @@ impl MaskCheck {
641656 let ( if_borrowed, if_present) = match self {
642657 MaskCheck :: CheckNotBorrowed | MaskCheck :: Borrow => ( panic_bb, ok_bb) ,
643658 MaskCheck :: Return => ( ok_bb, panic_bb) ,
659+ MaskCheck :: IsBorrowed => unreachable ! ( "handled above" ) ,
644660 } ;
645661 ctx. builder ( )
646662 . build_conditional_branch ( bit, if_borrowed, if_present) ?;
647663 Ok ( None )
648664 } ,
649- ) ?;
650- Ok ( ( ) )
665+ )
651666 }
652667}
653668
@@ -1570,6 +1585,22 @@ pub fn emit_barray_unsafe_op<'c, H: HugrView<Node = Node>>(
15701585 let ( _, array_v) = build_barray_alloc ( ctx, ccg, elem_ty, size, true ) ?;
15711586 outputs. finish ( ctx. builder ( ) , [ array_v. into ( ) ] )
15721587 }
1588+ BArrayUnsafeOpDef :: is_borrowed => {
1589+ let [ array_v, index_v] = inputs
1590+ . try_into ( )
1591+ . map_err ( |_| anyhow ! ( "BArrayUnsafeOpDef::is_borrowed expects two arguments" ) ) ?;
1592+ let BArrayFatPtrComponents {
1593+ mask_ptr, offset, ..
1594+ } = decompose_barray_fat_pointer ( builder, array_v) ?;
1595+ let index_v = index_v. into_int_value ( ) ;
1596+ build_bounds_check ( ccg, ctx, size, index_v) ?;
1597+ let offset_index_v = ctx. builder ( ) . build_int_add ( index_v, offset, "" ) ?;
1598+ // let bit = build_is_borrowed_check(ctx, mask_ptr, offset_index_v)?;
1599+ let bit = MaskCheck :: IsBorrowed
1600+ . emit ( ccg, ctx, mask_ptr, offset_index_v) ?
1601+ . expect ( "IsBorrowed always returns a value" ) ;
1602+ outputs. finish ( ctx. builder ( ) , [ bit. into ( ) , array_v] )
1603+ }
15731604 _ => todo ! ( ) ,
15741605 }
15751606}
@@ -2634,7 +2665,7 @@ mod test {
26342665 // - Pops specified numbers from the left to introduce an offset
26352666 // - Converts it into a regular array
26362667 // - Converts it back into a borrow array
2637- // - Borrows alls elements, sums them up, and returns the sum
2668+ // - Borrows all elements, sums them up, and returns the sum
26382669
26392670 let int_ty = int_type ( 6 ) ;
26402671 let hugr = SimpleHugrConfig :: new ( )
@@ -2908,4 +2939,77 @@ mod test {
29082939 let msg = "Some array elements have been borrowed" ;
29092940 assert_eq ! ( & exec_ctx. exec_hugr_panicking( hugr, "main" ) , msg) ;
29102941 }
2942+
2943+ #[ rstest]
2944+ fn exec_is_borrowed_basic ( mut exec_ctx : TestContext ) {
2945+ // We build a HUGR that:
2946+ // - Creates a borrow array [1,2,3]
2947+ // - Borrows index 1
2948+ // - Checks is_borrowed for indices 0, 1
2949+ // - Returns 1 if [false, true], else 0
2950+ let int_ty = int_type ( 6 ) ;
2951+ let size = 3 ;
2952+ let hugr = SimpleHugrConfig :: new ( )
2953+ . with_outs ( int_ty. clone ( ) )
2954+ . with_extensions ( exec_registry ( ) )
2955+ . finish ( |mut builder| {
2956+ let barray = borrow_array:: BArrayValue :: new (
2957+ int_ty. clone ( ) ,
2958+ ( 1 ..=3 )
2959+ . map ( |i| ConstInt :: new_u ( 6 , i) . unwrap ( ) . into ( ) )
2960+ . collect_vec ( ) ,
2961+ ) ;
2962+ let barray = builder. add_load_value ( barray) ;
2963+ let idx1 = builder. add_load_value ( ConstUsize :: new ( 1 ) ) ;
2964+ let ( _, barray) = builder
2965+ . add_borrow_array_borrow ( int_ty. clone ( ) , size, barray, idx1)
2966+ . unwrap ( ) ;
2967+
2968+ let idx0 = builder. add_load_value ( ConstUsize :: new ( 0 ) ) ;
2969+ let ( arr, b0_bools) =
2970+ [ idx0, idx1]
2971+ . iter ( )
2972+ . fold ( ( barray, Vec :: new ( ) ) , |( arr, mut bools) , idx| {
2973+ let ( b, arr) = builder
2974+ . add_is_borrowed ( int_ty. clone ( ) , size, arr, * idx)
2975+ . unwrap ( ) ;
2976+ bools. push ( b) ;
2977+ ( arr, bools)
2978+ } ) ;
2979+ let [ b0, b1] = b0_bools. try_into ( ) . unwrap ( ) ;
2980+
2981+ let b0 = builder. add_not ( b0) . unwrap ( ) ; // flip b0 to true
2982+ let and01 = builder. add_and ( b0, b1) . unwrap ( ) ;
2983+ let one = builder. add_load_value ( ConstInt :: new_u ( 6 , 1 ) . unwrap ( ) ) ;
2984+ let zero = builder. add_load_value ( ConstInt :: new_u ( 6 , 0 ) . unwrap ( ) ) ;
2985+ let mut cond = builder
2986+ . conditional_builder (
2987+ ( [ type_row ! [ ] , type_row ! [ ] ] , and01) ,
2988+ [ ] ,
2989+ int_ty. clone ( ) . into ( ) ,
2990+ )
2991+ . unwrap ( ) ;
2992+ cond. case_builder ( 0 )
2993+ . unwrap ( )
2994+ . finish_with_outputs ( [ zero] )
2995+ . unwrap ( ) ;
2996+ cond. case_builder ( 1 )
2997+ . unwrap ( )
2998+ . finish_with_outputs ( [ one] )
2999+ . unwrap ( ) ;
3000+ let out = cond. finish_sub_container ( ) . unwrap ( ) . out_wire ( 0 ) ;
3001+ builder
3002+ . add_borrow_array_discard ( int_ty. clone ( ) , size, arr)
3003+ . unwrap ( ) ;
3004+ builder. finish_hugr_with_outputs ( [ out] ) . unwrap ( )
3005+ } ) ;
3006+
3007+ exec_ctx. add_extensions ( |cge| {
3008+ cge. add_default_prelude_extensions ( )
3009+ . add_logic_extensions ( )
3010+ . add_default_borrow_array_extensions ( DefaultPreludeCodegen )
3011+ . add_default_int_extensions ( )
3012+ } ) ;
3013+ assert_eq ! ( 1 , exec_ctx. exec_hugr_u64( hugr, "main" ) ) ;
3014+ }
29113015}
0 commit comments