Commit fce77e3

correct adamw bf16 unit test and the way to get data type (#60565)
1 parent 25a7b2b

3 files changed: +33, -7 lines

test/xpu/op_test_xpu.py

Lines changed: 10 additions & 1 deletion
@@ -106,8 +106,13 @@ def check_output_with_place(
             if not core.is_float16_supported(place):
                 return
 
-        if self.dtype == np.float16:
+        if self.dtype == np.uint16:
+            if not core.is_bfloat16_supported(place):
+                return
+
+        if self.dtype == np.float16 or self.dtype == np.uint16:
             atol = 0.1
+
         return super().check_output_with_place(
             place,
             atol,
@@ -183,6 +188,10 @@ def check_grad_with_place(
             if not core.is_float16_supported(place):
                 return
 
+        if self.dtype == np.uint16:
+            if not core.is_bfloat16_supported(place):
+                return
+
         if self.dtype == np.float16 or self.dtype == np.uint16:
             max_relative_error = 0.1
         return super().check_grad_with_place(
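A note on the dtype checks above: NumPy has no native bfloat16 type, so these tests represent bf16 tensors as np.uint16 arrays holding the raw bf16 bit pattern, which is why bf16 support is gated on self.dtype == np.uint16. A minimal sketch of that representation, assuming plain truncation (Paddle's real op_test.convert_float_to_uint16 helper may handle rounding differently):

import numpy as np

def float_to_bf16_bits(arr):
    # bfloat16 is the high half of an IEEE-754 float32, so its bit
    # pattern fits in a uint16 even though NumPy lacks a bf16 dtype.
    arr = np.asarray(arr, dtype=np.float32)
    return (arr.view(np.uint32) >> 16).astype(np.uint16)

x = np.array([1.0, -2.5], dtype=np.float32)
print(float_to_bf16_bits(x))  # [16256 49184]

With that encoding in mind, the hunks above mirror the existing fp16 guard: if the declared dtype is the uint16 stand-in for bf16 and the XPU place lacks bf16 support, the check is skipped, and bf16 shares fp16's relaxed tolerance of 0.1.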

test/xpu/test_adamw_op_xpu.py

Lines changed: 19 additions & 3 deletions
@@ -21,6 +21,7 @@
     create_test_class,
     get_xpu_op_support_types,
 )
+from op_test import convert_float_to_uint16
 from op_test_xpu import XPUOpTest
 
 import paddle
@@ -85,8 +86,8 @@ def setUp(self):
         self.op_type = "adamw"
         self.init_shape()
         self.dtype = self.in_type
-        param = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
-        grad = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
+        param = np.random.uniform(-1, 1, self.shape)
+        grad = np.random.uniform(-1, 1, self.shape)
         moment1 = np.random.uniform(-1, 1, self.shape).astype("float32")
         # The second moment is positive
         moment2 = np.random.random(self.shape).astype("float32")
@@ -97,7 +98,9 @@ def setUp(self):
         epsilon = 1e-4
         beta1_pow = beta1**10
         beta2_pow = beta2**10
-
+        if self.dtype != np.uint16:
+            param = param.astype(self.dtype)
+            grad = grad.astype(self.dtype)
         self.inputs = {
             'Param': param,
             'Grad': grad,
@@ -128,13 +131,26 @@ def setUp(self):
             'Beta2PowOut': np.array([beta2_pow]).astype("float32") * beta2,
         }
 
+        if self.dtype == np.uint16:
+            self.inputs['Param'] = convert_float_to_uint16(
+                self.inputs['Param']
+            )
+            self.inputs['Grad'] = convert_float_to_uint16(
+                self.inputs['Grad']
+            )
+            self.outputs['ParamOut'] = convert_float_to_uint16(param_out)
+
     def init_shape(self):
         self.shape = [102, 105]
 
     def test_check_output(self):
         paddle.enable_static()
         self.check_output_with_place(place=paddle.XPUPlace(0))
 
+    def infer_dtype_from_inputs_outputs(self, inputs, outputs):
+        self.__class__.dtype = self.dtype
+        self.output_dtype = self.dtype
+
 class TestAdamW2(TestAdamW):
     def init_shape(self):
         self.shape = [
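Why the .astype(self.dtype) cast is now deferred and skipped for uint16: astype performs a value cast, not a bit cast, so applying it to the bf16 stand-in type would truncate the random floats to integers instead of producing bf16 encodings; the data must stay floating point until convert_float_to_uint16 bit-casts it. A standalone illustration of the difference (the shift-based bit cast mirrors the bf16 encoding; it is not Paddle's exact helper):

import numpy as np

x = np.array([0.75, 2.5], dtype=np.float32)

# Value cast -- what astype(np.uint16) does: fractions truncate to
# the integers 0 and 2, destroying the test data.
print(x.astype(np.uint16))  # [0 2]

# Bit cast -- what convert_float_to_uint16 exists for: keep the high
# 16 bits of the float32 representation, i.e. the bf16 encoding.
print((x.view(np.uint32) >> 16).astype(np.uint16))  # [16192 16416]

The infer_dtype_from_inputs_outputs override addresses the second half of the commit title: it pins the test's dtype to the declared self.dtype instead of letting the base class re-derive it from the (already bit-cast, hence uint16) input arrays.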

test/xpu/test_elementwise_sub_op_xpu.py

Lines changed: 4 additions & 3 deletions
@@ -66,16 +66,17 @@ def init_input_output(self):
         if self.dtype == np.uint16:
             tmp_x = self.reshape_data(self.x, self.y)
             tmp_y = self.reshape_data(self.y, self.x)
-            self.outputs = {'Out': tmp_x - tmp_y}
+            tmp_out = tmp_x - tmp_y
+            self.outputs = {'Out': convert_float_to_uint16(tmp_out)}
             self.x = convert_float_to_uint16(self.x)
             self.y = convert_float_to_uint16(self.y)
         else:
             tmp_x = self.reshape_data(self.x, self.y).astype(self.dtype)
             tmp_y = self.reshape_data(self.y, self.x).astype(self.dtype)
             self.outputs = {'Out': tmp_x - tmp_y}
         self.inputs = {
-            'X': self.x,
-            'Y': self.y,
+            'X': self.x.astype(self.dtype),
+            'Y': self.y.astype(self.dtype),
         }
 
     def init_shape(self):
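The elementwise_sub fix applies the same discipline end to end: the float32 reference result is itself encoded with convert_float_to_uint16, so expected and actual bf16 outputs are compared in one representation, and the non-bf16 branch casts X and Y explicitly. A rough round-trip sketch of that comparison (to_bf16_bits / from_bf16_bits are illustrative stand-ins, not Paddle's API):

import numpy as np

def to_bf16_bits(a):
    return (np.asarray(a, np.float32).view(np.uint32) >> 16).astype(np.uint16)

def from_bf16_bits(u):
    return (np.asarray(u, np.uint16).astype(np.uint32) << 16).view(np.float32)

x = np.random.uniform(-1, 1, 4).astype(np.float32)
y = np.random.uniform(-1, 1, 4).astype(np.float32)

# Encode the float32 reference the same way a bf16 kernel output is
# stored, then compare both sides after decoding back to float32.
expected = from_bf16_bits(to_bf16_bits(x - y))
actual = from_bf16_bits(to_bf16_bits(x)) - from_bf16_bits(to_bf16_bits(y))
print(np.allclose(expected, actual, atol=0.1))  # True within the tests' 0.1 tolerance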
