Skip to content

Commit b751514

Browse files
committed
Add bfloat16 support for square, sin, cos and elementwise_pow kernels on XPU
1 parent 9952846 commit b751514

5 files changed

Lines changed: 94 additions & 82 deletions

File tree

paddle/phi/backends/xpu/xpu3_op_list.cc

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,9 @@ XPUOpMap& get_kl3_ops() {
293293
phi::DataType::INT32,
294294
phi::DataType::INT64})},
295295
{"elementwise_pow",
296-
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
296+
XPUKernelSet({phi::DataType::FLOAT32,
297+
phi::DataType::FLOAT16,
298+
phi::DataType::BFLOAT16})},
297299
{"elementwise_sub_grad",
298300
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
299301
{"elementwise_sub",
@@ -885,7 +887,9 @@ XPUOpMap& get_kl3_ops() {
885887
{"square_grad",
886888
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
887889
{"square",
888-
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
890+
XPUKernelSet({phi::DataType::FLOAT32,
891+
phi::DataType::FLOAT16,
892+
phi::DataType::BFLOAT16})},
889893
{"squared_l2_norm",
890894
XPUKernelSet({phi::DataType::FLOAT32,
891895
phi::DataType::FLOAT16,
@@ -1136,9 +1140,15 @@ XPUOpMap& get_kl3_ops() {
11361140
phi::DataType::FLOAT32,
11371141
phi::DataType::FLOAT16,
11381142
phi::DataType::BFLOAT16})},
1139-
{"sin", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
1143+
{"sin",
1144+
XPUKernelSet({phi::DataType::FLOAT32,
1145+
phi::DataType::FLOAT16,
1146+
phi::DataType::BFLOAT16})},
11401147
{"sin_grad", XPUKernelSet({phi::DataType::FLOAT32})},
1141-
{"cos", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
1148+
{"cos",
1149+
XPUKernelSet({phi::DataType::FLOAT32,
1150+
phi::DataType::FLOAT16,
1151+
phi::DataType::BFLOAT16})},
11421152
{"cos_grad", XPUKernelSet({phi::DataType::FLOAT32})},
11431153
{"linspace",
11441154
XPUKernelSet({phi::DataType::FLOAT32,

paddle/phi/kernels/xpu/activation_kernel.cc

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -624,19 +624,34 @@ PD_REGISTER_KERNEL(sqrt,
624624
PD_REGISTER_KERNEL(
625625
tanh, XPU, ALL_LAYOUT, phi::TanhKernel, float, phi::dtype::float16) {}
626626

627-
PD_REGISTER_KERNEL(
628-
square, XPU, ALL_LAYOUT, phi::SquareKernel, float, phi::dtype::float16) {}
627+
PD_REGISTER_KERNEL(square,
628+
XPU,
629+
ALL_LAYOUT,
630+
phi::SquareKernel,
631+
float,
632+
phi::dtype::float16,
633+
phi::dtype::bfloat16) {}
629634

630635
PD_REGISTER_KERNEL(
631636
log, XPU, ALL_LAYOUT, phi::LogKernel, float, phi::dtype::float16) {}
632637

633638
PD_REGISTER_KERNEL(
634639
relu6, XPU, ALL_LAYOUT, phi::Relu6Kernel, float, phi::dtype::float16) {}
635640

636-
PD_REGISTER_KERNEL(
637-
sin, XPU, ALL_LAYOUT, phi::SinKernel, float, phi::dtype::float16) {}
638-
PD_REGISTER_KERNEL(
639-
cos, XPU, ALL_LAYOUT, phi::CosKernel, float, phi::dtype::float16) {}
641+
PD_REGISTER_KERNEL(sin,
642+
XPU,
643+
ALL_LAYOUT,
644+
phi::SinKernel,
645+
float,
646+
phi::dtype::float16,
647+
phi::dtype::bfloat16) {}
648+
PD_REGISTER_KERNEL(cos,
649+
XPU,
650+
ALL_LAYOUT,
651+
phi::CosKernel,
652+
float,
653+
phi::dtype::float16,
654+
phi::dtype::bfloat16) {}
640655

641656
#define PD_REGISTER_ACTIVATION_KERNEL(name, func) \
642657
PD_REGISTER_KERNEL(name, XPU, ALL_LAYOUT, phi::func, float) {}

paddle/phi/kernels/xpu/elementwise_kernel.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,4 +114,5 @@ PD_REGISTER_KERNEL(elementwise_pow,
114114
ALL_LAYOUT,
115115
phi::ElementwisePowKernel,
116116
float,
117-
phi::dtype::float16) {}
117+
phi::dtype::float16,
118+
phi::dtype::bfloat16) {}

test/xpu/test_activation_op_xpu.py

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -521,34 +521,39 @@ def set_case(self):
521521
self.op_type = "square"
522522
self.dtype = self.in_type
523523
self.init_config()
524+
if self.dtype == np.uint16:
525+
# bfloat16 actually
526+
self.x = convert_float_to_uint16(self.tmp_x)
527+
else:
528+
self.x = self.tmp_x.astype(self.dtype)
524529
out = np.square(self.x)
525530

526531
self.attrs = {'use_xpu': True}
527532
self.inputs = {'X': OpTest.np_dtype_to_base_dtype(self.x)}
528533
self.outputs = {'Out': out}
529534

530535
def init_config(self):
531-
self.x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
536+
self.tmp_x = np.random.uniform(-1, 1, [11, 17])
532537

533538
class XPUTestSquare_ZeroDim(XPUTestSquare):
534539
def init_config(self):
535-
self.x = np.random.uniform(-2, 2, []).astype(self.dtype)
540+
self.tmp_x = np.random.uniform(-2, 2, [])
536541

537542
class XPUTestSquare2(XPUTestSquare):
538543
def init_config(self):
539-
self.x = np.random.uniform(-2, 2, [100]).astype(self.dtype)
544+
self.tmp_x = np.random.uniform(-2, 2, [100])
540545

541546
class XPUTestSquare3(XPUTestSquare):
542547
def init_config(self):
543-
self.x = np.random.uniform(-2, 2, [1, 15, 19]).astype(self.dtype)
548+
self.tmp_x = np.random.uniform(-2, 2, [1, 15, 19])
544549

545550
class XPUTestSquare4(XPUTestSquare):
546551
def init_config(self):
547-
self.x = np.random.uniform(-2, 2, [100, 10]).astype(self.dtype)
552+
self.tmp_x = np.random.uniform(-2, 2, [100, 10])
548553

549554
class XPUTestSquare5(XPUTestSquare):
550555
def init_config(self):
551-
self.x = np.random.uniform(-2, 2, [1, 2, 5, 17]).astype(self.dtype)
556+
self.tmp_x = np.random.uniform(-2, 2, [1, 2, 5, 17])
552557

553558

554559
support_types = get_xpu_op_support_types('square')
@@ -1297,38 +1302,35 @@ def set_case(self):
12971302
self.dtype = self.in_type
12981303

12991304
self.init_config()
1305+
if self.dtype == np.uint16:
1306+
# bfloat16 actually
1307+
self.x = convert_float_to_uint16(self.tmp_x)
1308+
else:
1309+
self.x = self.tmp_x.astype(self.dtype)
13001310
out = np.sin(self.x)
13011311

13021312
self.inputs = {'X': self.x}
13031313
self.outputs = {'Out': out}
13041314
self.attrs = {'use_xpu': True}
13051315

13061316
def init_config(self):
1307-
self.x = np.random.uniform(-np.pi, np.pi, [11, 17]).astype(
1308-
self.dtype
1309-
)
1317+
self.tmp_x = np.random.uniform(-np.pi, np.pi, [11, 17])
13101318

13111319
class XPUTestSin_ZeroDim(XPUTestSinBase):
13121320
def init_config(self):
1313-
self.x = np.random.uniform(-np.pi, np.pi, []).astype(self.dtype)
1321+
self.tmp_x = np.random.uniform(-np.pi, np.pi, [])
13141322

13151323
class XPUTestSin2(XPUTestSinBase):
13161324
def init_config(self):
1317-
self.x = np.random.uniform(-np.pi, np.pi, [1024, 8]).astype(
1318-
self.dtype
1319-
)
1325+
self.tmp_x = np.random.uniform(-np.pi, np.pi, [1024, 8])
13201326

13211327
class XPUTestSin3(XPUTestSinBase):
13221328
def init_config(self):
1323-
self.x = np.random.uniform(-np.pi, np.pi, [4, 512, 15, 15]).astype(
1324-
self.dtype
1325-
)
1329+
self.tmp_x = np.random.uniform(-np.pi, np.pi, [4, 512, 15, 15])
13261330

13271331
class XPUTestSin4(XPUTestSinBase):
13281332
def init_config(self):
1329-
self.x = np.random.uniform(-np.pi, np.pi, [4, 256, 22, 22]).astype(
1330-
self.dtype
1331-
)
1333+
self.tmp_x = np.random.uniform(-np.pi, np.pi, [4, 256, 22, 22])
13321334

13331335

13341336
support_types = get_xpu_op_support_types('sin')
@@ -1347,38 +1349,35 @@ def set_case(self):
13471349
self.dtype = self.in_type
13481350

13491351
self.init_config()
1352+
if self.dtype == np.uint16:
1353+
# bfloat16 actually
1354+
self.x = convert_float_to_uint16(self.tmp_x)
1355+
else:
1356+
self.x = self.tmp_x.astype(self.dtype)
13501357
out = np.cos(self.x)
13511358

13521359
self.inputs = {'X': self.x}
13531360
self.outputs = {'Out': out}
13541361
self.attrs = {'use_xpu': True}
13551362

13561363
def init_config(self):
1357-
self.x = np.random.uniform(-np.pi, np.pi, [11, 17]).astype(
1358-
self.dtype
1359-
)
1364+
self.tmp_x = np.random.uniform(-np.pi, np.pi, [11, 17])
13601365

13611366
class XPUTestCos_ZeroDim(XPUTestCosBase):
13621367
def init_config(self):
1363-
self.x = np.random.uniform(-np.pi, np.pi, []).astype(self.dtype)
1368+
self.tmp_x = np.random.uniform(-np.pi, np.pi, [])
13641369

13651370
class XPUTestCos2(XPUTestCosBase):
13661371
def init_config(self):
1367-
self.x = np.random.uniform(-np.pi, np.pi, [1024, 8]).astype(
1368-
self.dtype
1369-
)
1372+
self.tmp_x = np.random.uniform(-np.pi, np.pi, [1024, 8])
13701373

13711374
class XPUTestCos3(XPUTestCosBase):
13721375
def init_config(self):
1373-
self.x = np.random.uniform(-np.pi, np.pi, [4, 512, 15, 15]).astype(
1374-
self.dtype
1375-
)
1376+
self.tmp_x = np.random.uniform(-np.pi, np.pi, [4, 512, 15, 15])
13761377

13771378
class XPUTestCos4(XPUTestCosBase):
13781379
def init_config(self):
1379-
self.x = np.random.uniform(-np.pi, np.pi, [4, 256, 22, 22]).astype(
1380-
self.dtype
1381-
)
1380+
self.tmp_x = np.random.uniform(-np.pi, np.pi, [4, 256, 22, 22])
13821381

13831382

13841383
support_types = get_xpu_op_support_types('cos')

test/xpu/test_elementwise_pow_op_xpu.py

Lines changed: 26 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
create_test_class,
2121
get_xpu_op_support_types,
2222
)
23-
from op_test import OpTest, skip_check_grad_ci
23+
from op_test import OpTest, convert_float_to_uint16, skip_check_grad_ci
2424
from op_test_xpu import XPUOpTest
2525

2626
import paddle
@@ -40,73 +40,60 @@ def setUp(self):
4040
self.dtype = self.in_type
4141
self.__class__.no_need_check_grad = True
4242
self.compute_input_output()
43-
44-
def compute_input_output(self):
43+
if self.dtype == np.uint16:
44+
# bfloat16 actually
45+
self.x = convert_float_to_uint16(self.tmp_x)
46+
self.y = convert_float_to_uint16(self.tmp_y)
47+
else:
48+
self.x = self.tmp_x.astype(self.dtype)
49+
self.y = self.tmp_y.astype(self.dtype)
4550
self.inputs = {
46-
'X': np.random.uniform(1, 2, [20, 5]).astype(self.dtype),
47-
'Y': np.random.uniform(1, 2, [20, 5]).astype(self.dtype),
51+
'X': self.x,
52+
'Y': self.y,
4853
}
4954
self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])}
5055

56+
def compute_input_output(self):
57+
self.tmp_x = np.random.uniform(1, 2, [20, 5])
58+
self.tmp_y = np.random.uniform(1, 2, [20, 5])
59+
5160
def test_check_output(self):
5261
if paddle.is_compiled_with_xpu():
5362
place = paddle.XPUPlace(0)
5463
self.check_output_with_place(place, check_dygraph=False)
5564

5665
class TestElementwisePowOp_big_shape_1(TestElementwisePowOp):
5766
def compute_input_output(self):
58-
self.inputs = {
59-
'X': np.random.uniform(1, 2, [10, 10]).astype(self.dtype),
60-
'Y': np.random.uniform(0.1, 1, [10, 10]).astype(self.dtype),
61-
}
62-
self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])}
67+
self.tmp_x = np.random.uniform(1, 2, [10, 10])
68+
self.tmp_y = np.random.uniform(0.1, 1, [10, 10])
6369

6470
class TestElementwisePowOp_big_shape_2(TestElementwisePowOp):
6571
def compute_input_output(self):
66-
self.inputs = {
67-
'X': np.random.uniform(1, 2, [10, 10]).astype(self.dtype),
68-
'Y': np.random.uniform(0.2, 2, [10, 10]).astype(self.dtype),
69-
}
70-
self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])}
72+
self.tmp_x = np.random.uniform(1, 2, [10, 10])
73+
self.tmp_y = np.random.uniform(0.2, 2, [10, 10])
7174

7275
@skip_check_grad_ci(
7376
reason="[skip shape check] Use y_shape(1) to test broadcast."
7477
)
7578
class TestElementwisePowOp_scalar(TestElementwisePowOp):
7679
def compute_input_output(self):
77-
self.inputs = {
78-
'X': np.random.uniform(0.1, 1, [3, 3, 4]).astype(self.dtype),
79-
'Y': np.random.uniform(0.1, 1, [1]).astype(self.dtype),
80-
}
81-
self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])}
80+
self.tmp_x = np.random.uniform(0.1, 1, [3, 3, 4])
81+
self.tmp_y = np.random.uniform(0.1, 1, [1])
8282

8383
class TestElementwisePowOp_tensor(TestElementwisePowOp):
8484
def compute_input_output(self):
85-
self.inputs = {
86-
'X': np.random.uniform(0.1, 1, [100]).astype(self.dtype),
87-
'Y': np.random.uniform(1, 3, [100]).astype(self.dtype),
88-
}
89-
self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])}
85+
self.tmp_x = np.random.uniform(0.1, 1, [100])
86+
self.tmp_y = np.random.uniform(1, 3, [100])
9087

9188
class TestElementwisePowOp_broadcast_0(TestElementwisePowOp):
9289
def compute_input_output(self):
93-
self.inputs = {
94-
'X': np.random.uniform(0.1, 1, [2, 1, 100]).astype(self.dtype),
95-
'Y': np.random.uniform(0.1, 1, [100]).astype(self.dtype),
96-
}
97-
self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])}
90+
self.tmp_x = np.random.uniform(0.1, 1, [2, 1, 100])
91+
self.tmp_y = np.random.uniform(0.1, 1, [100])
9892

9993
class TestElementwisePowOp_broadcast_4(TestElementwisePowOp):
10094
def compute_input_output(self):
101-
self.inputs = {
102-
'X': np.random.uniform(0.1, 1, [2, 10, 3, 5]).astype(
103-
self.dtype
104-
),
105-
'Y': np.random.uniform(0.1, 1, [2, 10, 1, 5]).astype(
106-
self.dtype
107-
),
108-
}
109-
self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])}
95+
self.tmp_x = np.random.uniform(0.1, 1, [2, 10, 3, 5])
96+
self.tmp_y = np.random.uniform(0.1, 1, [2, 10, 1, 5])
11097

11198
class TestElementwisePowOpInt(OpTest):
11299
def setUp(self):

0 commit comments

Comments
 (0)