33from contextlib import contextmanager
44from functools import singledispatch
55from textwrap import dedent
6- from typing import Union
6+ from typing import TYPE_CHECKING , Callable , Optional , Union , cast
77
88import numba
99import numba .np .unsafe .ndarray as numba_ndarray
2222from aesara .compile .ops import DeepCopyOp
2323from aesara .graph .basic import Apply , NoParams
2424from aesara .graph .fg import FunctionGraph
25+ from aesara .graph .op import Op
2526from aesara .graph .type import Type
2627from aesara .ifelse import IfElse
2728from aesara .link .utils import (
4849from aesara .tensor .type_other import MakeSlice , NoneConst
4950
5051
52+ if TYPE_CHECKING :
53+ from aesara .graph .op import StorageMapType
54+
55+
5156def numba_njit (* args , ** kwargs ):
5257
5358 if len (args ) > 0 and callable (args [0 ]):
@@ -335,9 +340,42 @@ def numba_const_convert(data, dtype=None, **kwargs):
335340 return data
336341
337342
343+ def numba_funcify (obj , node = None , storage_map = None , ** kwargs ) -> Callable :
344+ """Convert `obj` to a Numba-JITable object."""
345+ return _numba_funcify (obj , node = node , storage_map = storage_map , ** kwargs )
346+
347+
338348@singledispatch
339- def numba_funcify (op , node = None , storage_map = None , ** kwargs ):
340- """Create a Numba compatible function from an Aesara `Op`."""
349+ def _numba_funcify (
350+ obj ,
351+ node : Optional [Apply ] = None ,
352+ storage_map : Optional ["StorageMapType" ] = None ,
353+ ** kwargs ,
354+ ) -> Callable :
355+ r"""Dispatch on Aesara object types to perform Numba conversions.
356+
357+ Arguments
358+ ---------
359+ obj
360+ The object used to determine the appropriate conversion function based
361+ on its type. This is generally an `Op` instance, but `FunctionGraph`\s
362+ are also supported.
363+ node
364+ When `obj` is an `Op`, this value should be the corresponding `Apply` node.
365+ storage_map
366+ A storage map with, for example, the constant and `SharedVariable` values
367+ of the graph being converted.
368+
369+ Returns
370+ -------
371+ A `Callable` that can be JIT-compiled in Numba using `numba.jit`.
372+
373+ """
374+
375+
376+ @_numba_funcify .register (Op )
377+ def numba_funcify_perform (op , node , storage_map = None , ** kwargs ) -> Callable :
378+ """Create a Numba compatible function from an Aesara `Op.perform`."""
341379
342380 warnings .warn (
343381 f"Numba will use object mode to run { op } 's perform method" ,
@@ -388,10 +426,10 @@ def perform(*inputs):
388426 ret = py_perform_return (inputs )
389427 return ret
390428
391- return perform
429+ return cast ( Callable , perform )
392430
393431
394- @numba_funcify .register (OpFromGraph )
432+ @_numba_funcify .register (OpFromGraph )
395433def numba_funcify_OpFromGraph (op , node = None , ** kwargs ):
396434
397435 _ = kwargs .pop ("storage_map" , None )
@@ -413,7 +451,7 @@ def opfromgraph(*inputs):
413451 return opfromgraph
414452
415453
416- @numba_funcify .register (FunctionGraph )
454+ @_numba_funcify .register (FunctionGraph )
417455def numba_funcify_FunctionGraph (
418456 fgraph ,
419457 node = None ,
@@ -521,9 +559,9 @@ def {fn_name}({", ".join(input_names)}):
521559 return subtensor_def_src
522560
523561
524- @numba_funcify .register (Subtensor )
525- @numba_funcify .register (AdvancedSubtensor )
526- @numba_funcify .register (AdvancedSubtensor1 )
562+ @_numba_funcify .register (Subtensor )
563+ @_numba_funcify .register (AdvancedSubtensor )
564+ @_numba_funcify .register (AdvancedSubtensor1 )
527565def numba_funcify_Subtensor (op , node , ** kwargs ):
528566
529567 subtensor_def_src = create_index_func (
@@ -539,8 +577,8 @@ def numba_funcify_Subtensor(op, node, **kwargs):
539577 return numba_njit (subtensor_fn )
540578
541579
542- @numba_funcify .register (IncSubtensor )
543- @numba_funcify .register (AdvancedIncSubtensor )
580+ @_numba_funcify .register (IncSubtensor )
581+ @_numba_funcify .register (AdvancedIncSubtensor )
544582def numba_funcify_IncSubtensor (op , node , ** kwargs ):
545583
546584 incsubtensor_def_src = create_index_func (
@@ -556,7 +594,7 @@ def numba_funcify_IncSubtensor(op, node, **kwargs):
556594 return numba_njit (incsubtensor_fn )
557595
558596
559- @numba_funcify .register (AdvancedIncSubtensor1 )
597+ @_numba_funcify .register (AdvancedIncSubtensor1 )
560598def numba_funcify_AdvancedIncSubtensor1 (op , node , ** kwargs ):
561599 inplace = op .inplace
562600 set_instead_of_inc = op .set_instead_of_inc
@@ -589,7 +627,7 @@ def advancedincsubtensor1(x, vals, idxs):
589627 return advancedincsubtensor1
590628
591629
592- @numba_funcify .register (DeepCopyOp )
630+ @_numba_funcify .register (DeepCopyOp )
593631def numba_funcify_DeepCopyOp (op , node , ** kwargs ):
594632
595633 # Scalars are apparently returned as actual Python scalar types and not
@@ -611,26 +649,26 @@ def deepcopyop(x):
611649 return deepcopyop
612650
613651
614- @numba_funcify .register (MakeSlice )
615- def numba_funcify_MakeSlice (op , ** kwargs ):
652+ @_numba_funcify .register (MakeSlice )
653+ def numba_funcify_MakeSlice (op , node , ** kwargs ):
616654 @numba_njit
617655 def makeslice (* x ):
618656 return slice (* x )
619657
620658 return makeslice
621659
622660
623- @numba_funcify .register (Shape )
624- def numba_funcify_Shape (op , ** kwargs ):
661+ @_numba_funcify .register (Shape )
662+ def numba_funcify_Shape (op , node , ** kwargs ):
625663 @numba_njit (inline = "always" )
626664 def shape (x ):
627665 return np .asarray (np .shape (x ))
628666
629667 return shape
630668
631669
632- @numba_funcify .register (Shape_i )
633- def numba_funcify_Shape_i (op , ** kwargs ):
670+ @_numba_funcify .register (Shape_i )
671+ def numba_funcify_Shape_i (op , node , ** kwargs ):
634672 i = op .i
635673
636674 @numba_njit (inline = "always" )
@@ -660,8 +698,8 @@ def codegen(context, builder, signature, args):
660698 return sig , codegen
661699
662700
663- @numba_funcify .register (Reshape )
664- def numba_funcify_Reshape (op , ** kwargs ):
701+ @_numba_funcify .register (Reshape )
702+ def numba_funcify_Reshape (op , node , ** kwargs ):
665703 ndim = op .ndim
666704
667705 if ndim == 0 :
@@ -683,7 +721,7 @@ def reshape(x, shape):
683721 return reshape
684722
685723
686- @numba_funcify .register (SpecifyShape )
724+ @_numba_funcify .register (SpecifyShape )
687725def numba_funcify_SpecifyShape (op , node , ** kwargs ):
688726 shape_inputs = node .inputs [1 :]
689727 shape_input_names = ["shape_" + str (i ) for i in range (len (shape_inputs ))]
@@ -730,7 +768,7 @@ def inputs_cast(x):
730768 return inputs_cast
731769
732770
733- @numba_funcify .register (Dot )
771+ @_numba_funcify .register (Dot )
734772def numba_funcify_Dot (op , node , ** kwargs ):
735773 # Numba's `np.dot` does not support integer dtypes, so we need to cast to
736774 # float.
@@ -745,7 +783,7 @@ def dot(x, y):
745783 return dot
746784
747785
748- @numba_funcify .register (Softplus )
786+ @_numba_funcify .register (Softplus )
749787def numba_funcify_Softplus (op , node , ** kwargs ):
750788
751789 x_dtype = np .dtype (node .inputs [0 ].dtype )
@@ -764,7 +802,7 @@ def softplus(x):
764802 return softplus
765803
766804
767- @numba_funcify .register (Cholesky )
805+ @_numba_funcify .register (Cholesky )
768806def numba_funcify_Cholesky (op , node , ** kwargs ):
769807 lower = op .lower
770808
@@ -800,7 +838,7 @@ def cholesky(a):
800838 return cholesky
801839
802840
803- @numba_funcify .register (Solve )
841+ @_numba_funcify .register (Solve )
804842def numba_funcify_Solve (op , node , ** kwargs ):
805843
806844 assume_a = op .assume_a
@@ -847,7 +885,7 @@ def solve(a, b):
847885 return solve
848886
849887
850- @numba_funcify .register (BatchedDot )
888+ @_numba_funcify .register (BatchedDot )
851889def numba_funcify_BatchedDot (op , node , ** kwargs ):
852890 dtype = node .outputs [0 ].type .numpy_dtype
853891
@@ -868,7 +906,7 @@ def batched_dot(x, y):
868906# optimizations are apparently already performed by Numba
869907
870908
871- @numba_funcify .register (IfElse )
909+ @_numba_funcify .register (IfElse )
872910def numba_funcify_IfElse (op , ** kwargs ):
873911 n_outs = op .n_outs
874912
0 commit comments