from functools import wraps
from typing import List
from triton.language import core
from triton.language.math import _add_math_1arg_docstr, _add_math_2arg_docstr, _add_math_3arg_docstr
from triton.language import semantic

T = core.TypeVar('T')


def _check_dtype(dtypes: List[str]) -> T:
    """
    We follow libdevice's convention when checking the data types accepted by math functions.
    It is not good practice to support all data types, as accelerators/GPUs do not support
    many float16 and bfloat16 math operations.
    We should let users know which data type they are using and have them invoke an explicit
    cast to convert it to a supported one.
    """

    def wrapper(fn):

        @wraps(fn)
        def check(*args, **kwargs):
            # concatenate args and kwargs
            all_args = list(args) + list(kwargs.values())
            for arg in [a for a in all_args if isinstance(a, core.tensor)]:
                arg_type = arg.type.scalar.name
                if hasattr(arg, 'was_bool_to_int8') and arg.was_bool_to_int8:
                    # In Triton, int1 maps to the boolean type
                    arg_type = 'int1'
                if arg_type not in dtypes:
                    raise ValueError(f"Expected dtype {dtypes} but got {arg_type}")
            return fn(*args, **kwargs)

        return check

    return wrapper
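
# A minimal sketch of how _check_dtype is used (my_op below is a hypothetical
# example, not part of this module): decorating a function restricts the tensor
# dtypes it accepts, and any other dtype fails loudly instead of silently
# producing an unsupported operation.
#
#     @_check_dtype(dtypes=["fp32"])
#     def my_op(x, _builder=None):
#         ...
#
#     my_op(int32_tensor)  # raises ValueError: Expected dtype ['fp32'] but got int32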


@core.extern
@_check_dtype(dtypes=["int32", "uint32"])
@_add_math_2arg_docstr("most significant N bits of the 2N-bit product")
def umulhi(x, y, _builder=None):
    x = semantic.to_tensor(x, _builder)
    y = semantic.to_tensor(y, _builder)
    x, y = core.binary_op_type_legalization(x, y, _builder)
    return core.tensor(_builder.create_umulhi(x.handle, y.handle), x.type)
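# For two 32-bit operands the full product is 64 bits wide; umulhi returns its
# upper 32 bits, i.e. conceptually (x * y) >> 32 computed without overflow
# (e.g. umulhi(0x80000000, 2) == 1 for uint32 inputs).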

@core.extern
@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
@_add_math_1arg_docstr("exponential")
@core._tensor_member_fn
def exp(x, _builder=None):
    x = semantic.to_tensor(x, _builder)
    return core.tensor(_builder.create_exp(x.handle), x.type)

@core.extern
@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
@_add_math_1arg_docstr("exponential (base 2)")
@core._tensor_member_fn
def exp2(x, _builder=None):
    x = semantic.to_tensor(x, _builder)
    return core.tensor(_builder.create_exp2(x.handle), x.type)

@core.extern
@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
@_add_math_1arg_docstr("natural logarithm")
@core._tensor_member_fn
def log(x, _builder=None):
    x = semantic.to_tensor(x, _builder)
    return core.tensor(_builder.create_log(x.handle), x.type)

@core.extern
@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
@_add_math_1arg_docstr("logarithm (base 2)")
@core._tensor_member_fn
def log2(x, _builder=None):
    x = semantic.to_tensor(x, _builder)
    return core.tensor(_builder.create_log2(x.handle), x.type)

@core.extern
@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
@_add_math_1arg_docstr("cosine")
@core._tensor_member_fn
def cos(x, _builder=None):
    x = semantic.to_tensor(x, _builder)
    return core.tensor(_builder.create_cos(x.handle), x.type)

@core.extern
@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
@_add_math_1arg_docstr("sine")
@core._tensor_member_fn
def sin(x, _builder=None):
    x = semantic.to_tensor(x, _builder)
    return core.tensor(_builder.create_sin(x.handle), x.type)

@core.extern
@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
@_add_math_1arg_docstr("fast square root")
@core._tensor_member_fn
def sqrt(x, _builder=None):
    x = semantic.to_tensor(x, _builder)
    return core.tensor(_builder.create_sqrt(x.handle), x.type)

@core.extern
@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
@_add_math_1arg_docstr("precise square root (rounding to nearest wrt the IEEE standard)")
@core._tensor_member_fn
def sqrt_rn(x, _builder=None):
    x = semantic.to_tensor(x, _builder)
    return core.tensor(_builder.create_precise_sqrt(x.handle), x.type)
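# Note: sqrt above lowers to a fast, approximate square root, while sqrt_rn
# requests the precise IEEE round-to-nearest result (typically slower).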

@core.extern
@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
@_add_math_1arg_docstr("inverse square root")
@core._tensor_member_fn
def rsqrt(x, _builder=None):
    x = semantic.to_tensor(x, _builder)
    return core.tensor(_builder.create_rsqrt(x.handle), x.type)

@core.extern
@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
@_add_math_2arg_docstr("precise division (rounding to nearest wrt the IEEE standard)")
def div_rn(x, y, _builder=None):
    x = semantic.to_tensor(x, _builder)
    y = semantic.to_tensor(y, _builder)
    x, y = core.binary_op_type_legalization(x, y, _builder)
    return core.tensor(_builder.create_precise_divf(x.handle, y.handle), x.type)

@core.extern
@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
@_add_math_1arg_docstr("error function")
@core._tensor_member_fn
def erf(x, _builder=None):
    x = semantic.to_tensor(x, _builder)
    return core.tensor(_builder.create_erf(x.handle), x.type)

@core.extern
@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
@_add_math_1arg_docstr("hyperbolic tangent")
@core._tensor_member_fn
def tanh(x, _builder=None):
    x = semantic.to_tensor(x, _builder)
    return core.tensor(_builder.create_tanh(x.handle), x.type)

@core.extern
@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
@_add_math_1arg_docstr("floor")
@core._tensor_member_fn
def floor(x, _builder=None):
    x = semantic.to_tensor(x, _builder)
    return core.tensor(_builder.create_floor(x.handle), x.type)


@core.extern
@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
@_add_math_1arg_docstr("ceil")
@core._tensor_member_fn
def ceil(x, _builder=None):
    x = semantic.to_tensor(x, _builder)
    return core.tensor(_builder.create_ceil(x.handle), x.type)


@core.extern
@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
@_add_math_3arg_docstr("fused multiply-add")
def fma(x, y, z, _builder=None):
    x = semantic.to_tensor(x, _builder)
    y = semantic.to_tensor(y, _builder)
    z = semantic.to_tensor(z, _builder)
    x, y = core.binary_op_type_legalization(x, y, _builder)
    z, x = core.binary_op_type_legalization(z, x, _builder)
    z, y = core.binary_op_type_legalization(z, y, _builder)
    return core.tensor(_builder.create_fma(x.handle, y.handle, z.handle), x.type)
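# fma(x, y, z) computes x * y + z as a single fused operation, i.e. with one
# rounding step instead of two, which avoids the intermediate rounding error
# and is usually faster on hardware with a native FMA unit.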


@core.extern
def reciprocal(arg0, _builder=None):
@@ -151,5 +324,5 @@ def trunc(arg0, _builder=None):
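# The extern_elementwise table below maps a tuple of argument dtypes to a
# (device-library symbol, return dtype) pair: fp32 inputs dispatch to
# __hmf_roundf and produce an fp32 result.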
def round(arg0, _builder=None):
    return core.extern_elementwise(
        "", "", [arg0], {
            (core.dtype("fp32"), ): ("__hmf_roundf", core.dtype("fp32")),
        }, is_pure=True, _builder=_builder)