Skip to content

Commit 18096c7

Browse files
committed
[Decoupling] Decouple language.math extension for ascend
1 parent aad18af commit 18096c7

File tree

2 files changed

+206
-87
lines changed

2 files changed

+206
-87
lines changed

python/triton/testing.py

Lines changed: 32 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -523,25 +523,6 @@ def get_max_simd_tflops(dtype, clock_rate, device=None):
523523
tensor_descriptor_type,
524524
)
525525
from .language.standard_ext import flip, sigmoid, softmax, isfinited, finitef, rint, atan2
526-
from .language.math_ext import (
527-
umulhi,
528-
exp,
529-
exp2,
530-
log,
531-
log2,
532-
cos,
533-
sin,
534-
sqrt,
535-
sqrt_rn,
536-
rsqrt,
537-
div_rn,
538-
erf,
539-
tanh,
540-
floor,
541-
ceil,
542-
_check_dtype,
543-
fma,
544-
)
545526
from . import language
546527

547528
language.flip = flip
@@ -550,75 +531,40 @@ def get_max_simd_tflops(dtype, clock_rate, device=None):
550531
language.tensor_descriptor = tensor_descriptor
551532
language.tensor_descriptor_type = tensor_descriptor_type
552533

553-
language.umulhi = umulhi
554-
language.exp = exp
555-
language.exp2 = exp2
556-
language.log = log
557-
language.log2 = log2
558-
language.cos = cos
559-
language.sin = sin
560-
language.sqrt = sqrt
561-
language.sqrt_rn = sqrt_rn
562-
language.rsqrt = rsqrt
563-
language.div_rn = div_rn
564-
language.erf = erf
565-
language.tanh = tanh
566-
language.floor = floor
567-
language.ceil = ceil
568-
language.fma = fma
569-
language.math.umulhi = umulhi
570-
language.math.exp = exp
571-
language.math.exp2 = exp2
572-
language.math.log = log
573-
language.math.log2 = log2
574-
language.math.cos = cos
575-
language.math.sin = sin
576-
language.math.sqrt = sqrt
577-
language.math.sqrt_rn = sqrt_rn
578-
language.math.rsqrt = rsqrt
579-
language.math.div_rn = div_rn
580-
language.math.erf = erf
581-
language.math.tanh = tanh
582-
language.math.floor = floor
583-
language.math.ceil = ceil
584-
language.math._check_dtype = _check_dtype
585-
language.math.fma = fma
586-
language.math.isnan = language.extra.ascend.libdevice.isnan
587-
language.math.isinf = language.extra.ascend.libdevice.isinf
588-
language.math.reciprocal = language.extra.ascend.libdevice.reciprocal
589-
language.math.log1p = language.extra.ascend.libdevice.log1p
590-
language.math.relu = language.extra.ascend.libdevice.relu
591-
language.math.tan = language.extra.ascend.libdevice.tan
592-
language.math.atan = language.extra.ascend.libdevice.atan
534+
language.umulhi = language.extra.ascend.libdevice.umulhi
535+
language.exp = language.extra.ascend.libdevice.exp
536+
language.exp2 = language.extra.ascend.libdevice.exp2
537+
language.log = language.extra.ascend.libdevice.log
538+
language.log2 = language.extra.ascend.libdevice.log2
539+
language.cos = language.extra.ascend.libdevice.cos
540+
language.sin = language.extra.ascend.libdevice.sin
541+
language.sqrt = language.extra.ascend.libdevice.sqrt
542+
language.sqrt_rn = language.extra.ascend.libdevice.sqrt_rn
543+
language.rsqrt = language.extra.ascend.libdevice.rsqrt
544+
language.div_rn = language.extra.ascend.libdevice.div_rn
545+
language.erf = language.extra.ascend.libdevice.erf
546+
language.tanh = language.extra.ascend.libdevice.tanh
547+
language.floor = language.extra.ascend.libdevice.floor
548+
language.ceil = language.extra.ascend.libdevice.ceil
549+
language.fma = language.extra.ascend.libdevice.fma
550+
language.math.umulhi = language.extra.ascend.libdevice.umulhi
551+
language.math.exp = language.extra.ascend.libdevice.exp
552+
language.math.exp2 = language.extra.ascend.libdevice.exp2
553+
language.math.log = language.extra.ascend.libdevice.log
554+
language.math.log2 = language.extra.ascend.libdevice.log2
555+
language.math.cos = language.extra.ascend.libdevice.cos
556+
language.math.sin = language.extra.ascend.libdevice.sin
557+
language.math.sqrt = language.extra.ascend.libdevice.sqrt
558+
language.math.sqrt_rn = language.extra.ascend.libdevice.sqrt_rn
559+
language.math.rsqrt = language.extra.ascend.libdevice.rsqrt
560+
language.math.div_rn = language.extra.ascend.libdevice.div_rn
561+
language.math.erf = language.extra.ascend.libdevice.erf
593562
language.math.tanh = language.extra.ascend.libdevice.tanh
594-
language.math.ilogb = language.extra.ascend.libdevice.ilogb
595-
language.math.ldexp = language.extra.ascend.libdevice.ldexp
596-
language.math.pow = language.extra.ascend.libdevice.pow
597-
language.math.flip = language.extra.ascend.libdevice.flip
598-
language.math.atan2 = language.extra.ascend.libdevice.atan2
599-
language.math.div_rz = language.extra.ascend.libdevice.div_rz
600-
language.math.fmod = language.extra.ascend.libdevice.fmod
601-
language.math.trunc = language.extra.ascend.libdevice.trunc
602-
language.math.round = language.extra.ascend.libdevice.round
563+
language.math.floor = language.extra.ascend.libdevice.floor
564+
language.math.ceil = language.extra.ascend.libdevice.ceil
565+
language.math._check_dtype = language.extra.ascend.libdevice._check_dtype
566+
language.math.fma = language.extra.ascend.libdevice.fma
603567
language.math.finitef = finitef
604568
language.math.isfinited = isfinited
605569
language.math.rint = rint
606570
language.math.atan2 = atan2
607-
language.extra.ascend.libdevice.umulhi = language.math.umulhi
608-
language.extra.ascend.libdevice.exp = language.math.exp
609-
language.extra.ascend.libdevice.exp2 = language.math.exp2
610-
language.extra.ascend.libdevice.log = language.math.log
611-
language.extra.ascend.libdevice.log2 = language.math.log2
612-
language.extra.ascend.libdevice.cos = language.math.cos
613-
language.extra.ascend.libdevice.sin = language.math.sin
614-
language.extra.ascend.libdevice.sqrt = language.math.sqrt
615-
language.extra.ascend.libdevice.sqrt_rn = language.math.sqrt_rn
616-
language.extra.ascend.libdevice.rsqrt = language.math.rsqrt
617-
language.extra.ascend.libdevice.div_rn = language.math.div_rn
618-
language.extra.ascend.libdevice.erf = language.math.erf
619-
language.extra.ascend.libdevice.tanh = language.math.tanh
620-
language.extra.ascend.libdevice.floor = language.math.floor
621-
language.extra.ascend.libdevice.ceil = language.math.ceil
622-
language.extra.ascend.libdevice.fdiv = language.math.fdiv
623-
language.extra.ascend.libdevice.fma = language.math.fma
624-
language.extra.ascend.libdevice.abs = language.math.abs

third_party/ascend/language/ascend/libdevice.py

Lines changed: 174 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,177 @@
1+
from functools import wraps
2+
from typing import List
13
from triton.language import core
4+
from triton.language.math import _add_math_1arg_docstr, _add_math_2arg_docstr, _add_math_3arg_docstr
5+
from triton.language import semantic
6+
7+
T = core.TypeVar('T')
8+
9+
10+
def _check_dtype(dtypes: List[str]) -> T:
11+
"""
12+
We're following libdevice's convention to check accepted data types for math functions.
13+
It is not a good practice to support all data types as accelerators/GPUs don't support
14+
many float16 and bfloat16 math operations.
15+
We should let the users know that they are using and invoke explicit cast to convert
16+
the data type to the supported one.
17+
"""
18+
19+
def wrapper(fn):
20+
21+
@wraps(fn)
22+
def check(*args, **kwargs):
23+
# concatenate args and kwargs
24+
all_args = list(args) + list(kwargs.values())
25+
for arg in [a for a in all_args if isinstance(a, core.tensor)]:
26+
arg_type = arg.type.scalar.name
27+
if hasattr(arg, 'was_bool_to_int8') and arg.was_bool_to_int8:
28+
# In Triton, int1 maps to the boolean type
29+
arg_type = 'int1'
30+
if arg_type not in dtypes:
31+
raise ValueError(f"Expected dtype {dtypes} but got {arg_type}")
32+
return fn(*args, **kwargs)
33+
34+
return check
35+
36+
return wrapper
37+
38+
39+
@core.extern
40+
@_check_dtype(dtypes=["int32", "uint32"])
41+
@_add_math_2arg_docstr("most significant N bits of the 2N-bit product")
42+
def umulhi(x, y, _builder=None):
43+
x = semantic.to_tensor(x, _builder)
44+
y = semantic.to_tensor(y, _builder)
45+
x, y = core.binary_op_type_legalization(x, y, _builder)
46+
return core.tensor(_builder.create_umulhi(x.handle, y.handle), x.type)
47+
48+
@core.extern
49+
@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
50+
@_add_math_1arg_docstr("exponential")
51+
@core._tensor_member_fn
52+
def exp(x, _builder=None):
53+
x = semantic.to_tensor(x, _builder)
54+
return core.tensor(_builder.create_exp(x.handle), x.type)
55+
56+
@core.extern
57+
@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
58+
@_add_math_1arg_docstr("exponential (base 2)")
59+
@core._tensor_member_fn
60+
def exp2(x, _builder=None):
61+
x = semantic.to_tensor(x, _builder)
62+
return core.tensor(_builder.create_exp2(x.handle), x.type)
63+
64+
@core.extern
65+
@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
66+
@_add_math_1arg_docstr("natural logarithm")
67+
@core._tensor_member_fn
68+
def log(x, _builder=None):
69+
x = semantic.to_tensor(x, _builder)
70+
return core.tensor(_builder.create_log(x.handle), x.type)
71+
72+
@core.extern
73+
@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
74+
@_add_math_1arg_docstr("logarithm (base 2)")
75+
@core._tensor_member_fn
76+
def log2(x, _builder=None):
77+
x = semantic.to_tensor(x, _builder)
78+
return core.tensor(_builder.create_log2(x.handle), x.type)
79+
80+
@core.extern
81+
@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
82+
@_add_math_1arg_docstr("cosine")
83+
@core._tensor_member_fn
84+
def cos(x, _builder=None):
85+
x = semantic.to_tensor(x, _builder)
86+
return core.tensor(_builder.create_cos(x.handle), x.type)
87+
88+
@core.extern
89+
@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
90+
@_add_math_1arg_docstr("sine")
91+
@core._tensor_member_fn
92+
def sin(x, _builder=None):
93+
x = semantic.to_tensor(x, _builder)
94+
return core.tensor(_builder.create_sin(x.handle), x.type)
95+
96+
@core.extern
97+
@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
98+
@_add_math_1arg_docstr("fast square root")
99+
@core._tensor_member_fn
100+
def sqrt(x, _builder=None):
101+
x = semantic.to_tensor(x, _builder)
102+
return core.tensor(_builder.create_sqrt(x.handle), x.type)
103+
104+
@core.extern
105+
@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
106+
@_add_math_1arg_docstr("precise square root (rounding to nearest wrt the IEEE standard)")
107+
@core._tensor_member_fn
108+
def sqrt_rn(x, _builder=None):
109+
x = semantic.to_tensor(x, _builder)
110+
return core.tensor(_builder.create_precise_sqrt(x.handle), x.type)
111+
112+
@core.extern
113+
@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
114+
@_add_math_1arg_docstr("inverse square root")
115+
@core._tensor_member_fn
116+
def rsqrt(x, _builder=None):
117+
x = semantic.to_tensor(x, _builder)
118+
return core.tensor(_builder.create_rsqrt(x.handle), x.type)
119+
120+
@core.extern
121+
@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
122+
@_add_math_2arg_docstr("precise division (rounding to nearest wrt the IEEE standard)")
123+
def div_rn(x, y, _builder=None):
124+
x = semantic.to_tensor(x, _builder)
125+
y = semantic.to_tensor(y, _builder)
126+
x, y = core.binary_op_type_legalization(x, y, _builder)
127+
return core.tensor(_builder.create_precise_divf(x.handle, y.handle), x.type)
128+
129+
@core.extern
130+
@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
131+
@_add_math_1arg_docstr("error function")
132+
@core._tensor_member_fn
133+
def erf(x, _builder=None):
134+
x = semantic.to_tensor(x, _builder)
135+
return core.tensor(_builder.create_erf(x.handle), x.type)
136+
137+
@core.extern
138+
@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
139+
@_add_math_1arg_docstr("error function")
140+
@core._tensor_member_fn
141+
def tanh(x, _builder=None):
142+
x = semantic.to_tensor(x, _builder)
143+
return core.tensor(_builder.create_tanh(x.handle), x.type)
144+
145+
@core.extern
146+
@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
147+
@_add_math_1arg_docstr("floor")
148+
@core._tensor_member_fn
149+
def floor(x, _builder=None):
150+
x = semantic.to_tensor(x, _builder)
151+
return core.tensor(_builder.create_floor(x.handle), x.type)
152+
153+
154+
@core.extern
155+
@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
156+
@_add_math_1arg_docstr("ceil")
157+
@core._tensor_member_fn
158+
def ceil(x, _builder=None):
159+
x = semantic.to_tensor(x, _builder)
160+
return core.tensor(_builder.create_ceil(x.handle), x.type)
161+
162+
163+
@core.extern
164+
@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
165+
@_add_math_3arg_docstr("fused multiply-add")
166+
def fma(x, y, z, _builder=None):
167+
x = semantic.to_tensor(x, _builder)
168+
y = semantic.to_tensor(y, _builder)
169+
z = semantic.to_tensor(z, _builder)
170+
x, y = core.binary_op_type_legalization(x, y, _builder)
171+
z, x = core.binary_op_type_legalization(z, x, _builder)
172+
z, y = core.binary_op_type_legalization(z, y, _builder)
173+
return core.tensor(_builder.create_fma(x.handle, y.handle, z.handle), x.type)
174+
2175

3176
@core.extern
4177
def reciprocal(arg0, _builder=None):
@@ -151,5 +324,5 @@ def trunc(arg0, _builder=None):
151324
def round(arg0, _builder=None):
152325
return core.extern_elementwise(
153326
"", "", [arg0], {
154-
(core.dtype("fp32"), ): ("__hmf_roundf", core.dtype("fp32")),
327+
(core.dtype("fp32"), ): ("__hmf_roundf", core.dtype("fp32")),
155328
}, is_pure=True, _builder=_builder)

0 commit comments

Comments
 (0)