@@ -11,11 +11,11 @@ use inkwell::{
1111 basic_block:: BasicBlock ,
1212 builder:: Builder ,
1313 context:: Context ,
14- module:: Module ,
14+ module:: { Linkage , Module } ,
1515 types:: { BasicType , BasicTypeEnum , FunctionType } ,
16- values:: { BasicValueEnum , FunctionValue , GlobalValue , IntValue } ,
16+ values:: { BasicValue , BasicValueEnum , FunctionValue , GlobalValue , IntValue } ,
1717} ;
18- use itertools:: zip_eq;
18+ use itertools:: { Itertools , zip_eq} ;
1919
2020use crate :: types:: { HugrFuncType , HugrSumType , HugrType , TypingSession } ;
2121use crate :: { custom:: CodegenExtsMap , types:: LLVMSumType , utils:: fat:: FatNode } ;
@@ -357,6 +357,70 @@ pub fn build_ok_or_else<'c, H: HugrView<Node = Node>>(
357357 let either = builder. build_select ( is_ok, right, left, "" ) ?;
358358 Ok ( either)
359359}
360+ /// Helper to outline LLVM IR into a function call instead of inlining it every time.
361+ ///
362+ /// The first time this helper is called with a given function name, a function is built
363+ /// using the provided closure. Future invocations with the same name will just emit calls
364+ /// to this function.
365+ ///
366+ /// The return type is specified by `ret_type`, and if `Some` then the closure must return
367+ /// a value of that type, which will be returned from the function. Otherwise, the function
368+ /// will return void.
369+ pub fn get_or_make_function < ' c , H : HugrView < Node = Node > , const N : usize > (
370+ ctx : & mut EmitFuncContext < ' c , ' _ , H > ,
371+ func_name : & str ,
372+ args : [ BasicValueEnum < ' c > ; N ] ,
373+ ret_type : Option < BasicTypeEnum < ' c > > ,
374+ go : impl FnOnce (
375+ & mut EmitFuncContext < ' c , ' _ , H > ,
376+ [ BasicValueEnum < ' c > ; N ] ,
377+ ) -> Result < Option < BasicValueEnum < ' c > > > ,
378+ ) -> Result < Option < BasicValueEnum < ' c > > > {
379+ let func = match ctx. get_current_module ( ) . get_function ( func_name) {
380+ Some ( func) => func,
381+ None => {
382+ let arg_tys = args. iter ( ) . map ( |v| v. get_type ( ) . into ( ) ) . collect_vec ( ) ;
383+ let sig = match ret_type {
384+ Some ( ret_ty) => ret_ty. fn_type ( & arg_tys, false ) ,
385+ None => ctx. iw_context ( ) . void_type ( ) . fn_type ( & arg_tys, false ) ,
386+ } ;
387+ let func =
388+ ctx. get_current_module ( )
389+ . add_function ( func_name, sig, Some ( Linkage :: Internal ) ) ;
390+ let bb = ctx. iw_context ( ) . append_basic_block ( func, "" ) ;
391+ let args = ( 0 ..N )
392+ . map ( |i| func. get_nth_param ( i as u32 ) . unwrap ( ) )
393+ . collect_array ( )
394+ . unwrap ( ) ;
395+
396+ let curr_bb = ctx. builder ( ) . get_insert_block ( ) . unwrap ( ) ;
397+ let curr_func = ctx. func ;
398+
399+ ctx. builder ( ) . position_at_end ( bb) ;
400+ ctx. func = func;
401+ let ret_val = go ( ctx, args) ?;
402+ if ctx
403+ . builder ( )
404+ . get_insert_block ( )
405+ . unwrap ( )
406+ . get_terminator ( )
407+ . is_none ( )
408+ {
409+ ctx. builder ( )
410+ . build_return ( ret_val. as_ref ( ) . map :: < & dyn BasicValue , _ > ( |v| v) ) ?;
411+ }
412+
413+ ctx. builder ( ) . position_at_end ( curr_bb) ;
414+ ctx. func = curr_func;
415+ func
416+ }
417+ } ;
418+ let call_site =
419+ ctx. builder ( )
420+ . build_call ( func, & args. iter ( ) . map ( |& a| a. into ( ) ) . collect_vec ( ) , "" ) ?;
421+ let result = call_site. try_as_basic_value ( ) . left ( ) ;
422+ Ok ( result)
423+ }
360424
361425#[ cfg( test) ]
362426mod tests {
0 commit comments