Skip to content

Commit 4379a0f

Browse files
mark-kochss2165acl-cqc
authored
feat: LLVM lowering for borrow arrays using bitmasks (#2574)
Closes #2551 Substantial duplication between existing hugr-llvm/src/extension/collections/array.rs and new ...borrow_array.rs. --------- Co-authored-by: Seyon Sivarajah <[email protected]> Co-authored-by: Alan Lawrence <[email protected]>
1 parent cd04cf8 commit 4379a0f

12 files changed

+4340
-4
lines changed

hugr-core/src/std_extensions/collections/borrow_array.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -622,7 +622,7 @@ pub trait BArrayOpBuilder: GenericArrayOpBuilder {
622622
elem_ty: Type,
623623
input: Wire,
624624
) -> Result<(), BuildError> {
625-
self.add_generic_array_discard_empty::<Array>(elem_ty, input)
625+
self.add_generic_array_discard_empty::<BorrowArray>(elem_ty, input)
626626
}
627627

628628
/// Adds a borrow array borrow operation to the dataflow graph.

hugr-llvm/src/emit/func.rs

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ use inkwell::{
1111
basic_block::BasicBlock,
1212
builder::Builder,
1313
context::Context,
14-
module::Module,
14+
module::{Linkage, Module},
1515
types::{BasicType, BasicTypeEnum, FunctionType},
16-
values::{BasicValueEnum, FunctionValue, GlobalValue, IntValue},
16+
values::{BasicValue, BasicValueEnum, FunctionValue, GlobalValue, IntValue},
1717
};
18-
use itertools::zip_eq;
18+
use itertools::{Itertools, zip_eq};
1919

2020
use crate::types::{HugrFuncType, HugrSumType, HugrType, TypingSession};
2121
use crate::{custom::CodegenExtsMap, types::LLVMSumType, utils::fat::FatNode};
@@ -357,6 +357,70 @@ pub fn build_ok_or_else<'c, H: HugrView<Node = Node>>(
357357
let either = builder.build_select(is_ok, right, left, "")?;
358358
Ok(either)
359359
}
360+
/// Helper to outline LLVM IR into a function call instead of inlining it every time.
361+
///
362+
/// The first time this helper is called with a given function name, a function is built
363+
/// using the provided closure. Future invocations with the same name will just emit calls
364+
/// to this function.
365+
///
366+
/// The return type is specified by `ret_type`, and if `Some` then the closure must return
367+
/// a value of that type, which will be returned from the function. Otherwise, the function
368+
/// will return void.
369+
pub fn get_or_make_function<'c, H: HugrView<Node = Node>, const N: usize>(
370+
ctx: &mut EmitFuncContext<'c, '_, H>,
371+
func_name: &str,
372+
args: [BasicValueEnum<'c>; N],
373+
ret_type: Option<BasicTypeEnum<'c>>,
374+
go: impl FnOnce(
375+
&mut EmitFuncContext<'c, '_, H>,
376+
[BasicValueEnum<'c>; N],
377+
) -> Result<Option<BasicValueEnum<'c>>>,
378+
) -> Result<Option<BasicValueEnum<'c>>> {
379+
let func = match ctx.get_current_module().get_function(func_name) {
380+
Some(func) => func,
381+
None => {
382+
let arg_tys = args.iter().map(|v| v.get_type().into()).collect_vec();
383+
let sig = match ret_type {
384+
Some(ret_ty) => ret_ty.fn_type(&arg_tys, false),
385+
None => ctx.iw_context().void_type().fn_type(&arg_tys, false),
386+
};
387+
let func =
388+
ctx.get_current_module()
389+
.add_function(func_name, sig, Some(Linkage::Internal));
390+
let bb = ctx.iw_context().append_basic_block(func, "");
391+
let args = (0..N)
392+
.map(|i| func.get_nth_param(i as u32).unwrap())
393+
.collect_array()
394+
.unwrap();
395+
396+
let curr_bb = ctx.builder().get_insert_block().unwrap();
397+
let curr_func = ctx.func;
398+
399+
ctx.builder().position_at_end(bb);
400+
ctx.func = func;
401+
let ret_val = go(ctx, args)?;
402+
if ctx
403+
.builder()
404+
.get_insert_block()
405+
.unwrap()
406+
.get_terminator()
407+
.is_none()
408+
{
409+
ctx.builder()
410+
.build_return(ret_val.as_ref().map::<&dyn BasicValue, _>(|v| v))?;
411+
}
412+
413+
ctx.builder().position_at_end(curr_bb);
414+
ctx.func = curr_func;
415+
func
416+
}
417+
};
418+
let call_site =
419+
ctx.builder()
420+
.build_call(func, &args.iter().map(|&a| a.into()).collect_vec(), "")?;
421+
let result = call_site.try_as_basic_value().left();
422+
Ok(result)
423+
}
360424

361425
#[cfg(test)]
362426
mod tests {
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! Emission logic for collections.
22
33
pub mod array;
4+
pub mod borrow_array;
45
pub mod list;
56
pub mod stack_array;
67
pub mod static_array;

0 commit comments

Comments
 (0)