Commit 0205316

add fp16 unittests for kl2

1 parent 19b02d9 commit 0205316

File tree: 9 files changed, +490 -307 lines changed


paddle/fluid/operators/scale_op_xpu.cc

Lines changed: 13 additions & 6 deletions
@@ -22,12 +22,14 @@ namespace paddle {
 namespace operators {
 template <typename DeviceContext, typename T>
 class ScaleXPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
  public:
   virtual void Compute(const framework::ExecutionContext& ctx) const {
     auto* in_var = ctx.InputVar("X");
     auto* in = framework::GetLoDTensorOrSelectedRowsValueFromVar(*in_var);
-    auto scale = static_cast<T>(ctx.Attr<float>("scale"));
-    auto bias = static_cast<T>(ctx.Attr<float>("bias"));
+    auto scale = static_cast<float>(ctx.Attr<float>("scale"));
+    auto bias = static_cast<float>(ctx.Attr<float>("bias"));
     auto bias_after_scale = ctx.Attr<bool>("bias_after_scale");
     auto* out_var = ctx.OutputVar("Out");
     if (in_var->IsType<framework::SelectedRows>() && in_var != out_var) {
@@ -46,9 +48,10 @@ class ScaleXPUKernel : public framework::OpKernel<T> {
                           in->dims().to_str().c_str(),
                           out->dims().to_str().c_str()));
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    int r =
-        xpu::scale(dev_ctx.x_context(), in->data<float>(), out->data<float>(),
-                   in->numel(), bias_after_scale, scale, bias);
+    int r = xpu::scale(dev_ctx.x_context(),
+                       reinterpret_cast<const XPUType*>(in->data<T>()),
+                       reinterpret_cast<XPUType*>(out->data<T>()), in->numel(),
+                       bias_after_scale, scale, bias);
     PADDLE_ENFORCE_EQ(
         r, XPU_SUCCESS,
         platform::errors::External("XPU scale kernel return wrong value[%d %s]",
@@ -60,7 +63,11 @@ class ScaleXPUKernel : public framework::OpKernel<T> {
 }  // namespace paddle

 namespace ops = paddle::operators;
+
 REGISTER_OP_XPU_KERNEL(
-    scale, ops::ScaleXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    scale, ops::ScaleXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::ScaleXPUKernel<paddle::platform::XPUDeviceContext,
+                        paddle::platform::float16>,
+    ops::ScaleXPUKernel<paddle::platform::XPUDeviceContext, int64_t>);

 #endif
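
Editor's note: to see what the widened registration buys at the Python level, here is a minimal usage sketch (not part of this commit). It assumes a Paddle build with XPU support and an XPU2 (kl2) device visible as xpu:0.

    import numpy as np
    import paddle

    # Run the scale op on an XPU with a float16 input; before this commit
    # only float32 was registered for the XPU scale kernel.
    paddle.set_device("xpu:0")
    x = paddle.to_tensor(np.random.rand(4, 16).astype("float16"))
    y = paddle.scale(x, scale=2.0, bias=1.0, bias_after_scale=True)
    print(y.dtype)  # paddle.float16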

paddle/fluid/platform/xpu/xpu2_op_list.h

Lines changed: 3 additions & 0 deletions
@@ -184,6 +184,9 @@ XPUOpMap& get_kl2_ops() {
                     pOpKernelType(vartype::INT8, XPUPlace()),
                     pOpKernelType(vartype::FP16, XPUPlace()),
                     pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"scale", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                            pOpKernelType(vartype::FP16, XPUPlace()),
+                            pOpKernelType(vartype::INT64, XPUPlace())})}
     // AddMore
 };
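
Editor's note: an illustrative way to sanity-check the new kl2 op-list entry from Python is to inspect the registered kernels. This sketch reuses the private helper core._get_all_register_op_kernels() that the XPU test harness already calls; the exact kernel-string format is an assumption, only the presence of "XPU" in the string is relied on here.

    import paddle.fluid.core as core

    # List every kernel registered for "scale" and keep the XPU ones; on an
    # XPU2 build the float16 and int64 variants added here should appear too.
    scale_kernels = core._get_all_register_op_kernels().get("scale", [])
    xpu_scale_kernels = [k for k in scale_kernels if "XPU" in k]
    print(xpu_scale_kernels)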

paddle/fluid/pybind/pybind.cc

Lines changed: 8 additions & 0 deletions
@@ -1709,6 +1709,14 @@ All parameter, weight, gradient are variables in Paddle.
   m.def("get_xpu_device_count", platform::GetXPUDeviceCount);
   m.def("get_xpu_device_version",
         [](int device_id) { return platform::get_xpu_version(device_id); });
+  m.def("is_float16_supported", [](const platform::XPUPlace &place) -> bool {
+    // XPUs with Compute Capability > xpu2 support float16 and bfloat16
+    return platform::get_xpu_version(place.device) > platform::XPUVersion::XPU1;
+  });
+  m.def("is_bfloat16_supported", [](const platform::XPUPlace &place) -> bool {
+    // XPUs with Compute Capability > xpu2 support float16 and bfloat16
+    return platform::get_xpu_version(place.device) > platform::XPUVersion::XPU1;
+  });
 #endif

   py::class_<paddle::platform::CPUPlace>(m, "CPUPlace", R"DOC(
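
Editor's note: these two bindings are what the Python test harness keys its skip logic on. A usage sketch, assuming a Paddle build compiled with XPU support:

    import paddle
    import paddle.fluid.core as core

    place = paddle.XPUPlace(0)
    # True on XPU2 (kl2) and newer devices, False on XPU1, per the check above.
    if core.is_float16_supported(place):
        print("fp16 (and bf16) kernels can be exercised on this XPU")
    else:
        print("XPU1 device detected: fp16 unittests will be skipped")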

python/paddle/fluid/tests/unittests/op_test_xpu.py

Lines changed: 45 additions & 230 deletions
@@ -44,86 +44,33 @@ class XPUOpTest(OpTest):
     @classmethod
     def setUpClass(cls):
         '''Fix random seeds to remove randomness from tests'''
-        cls._np_rand_state = np.random.get_state()
-        cls._py_rand_state = random.getstate()
-        cls.call_once = False
-        cls.dtype = np.float32
-        cls.outputs = {}
-        cls.input_shape_is_large = True
-
-        np.random.seed(123)
-        random.seed(124)
-
-        cls._use_system_allocator = _set_use_system_allocator(True)
+        cls.use_xpu = True
+        cls.use_mkldnn = False
+        super().setUpClass()

     @classmethod
     def tearDownClass(cls):
         """Restore random seeds"""
-        np.random.set_state(cls._np_rand_state)
-        random.setstate(cls._py_rand_state)
-
-        _set_use_system_allocator(cls._use_system_allocator)

         def is_empty_grad_op(op_type):
             all_op_kernels = core._get_all_register_op_kernels()
             grad_op = op_type + '_grad'
             if grad_op in all_op_kernels.keys():
-                if is_mkldnn_op_test():
-                    grad_op_kernels = all_op_kernels[grad_op]
-                    for grad_op_kernel in grad_op_kernels:
-                        if 'MKLDNN' in grad_op_kernel:
-                            return False
-                else:
-                    return False
+                grad_op_kernels = all_op_kernels[grad_op]
+                for grad_op_kernel in grad_op_kernels:
+                    if 'XPU' in grad_op_kernel:
+                        return False
             return True

-        def is_xpu_op_test():
-            return True
-
-        def is_mkldnn_op_test():
-            return False
-
-        if not hasattr(cls, "op_type"):
-            raise AssertionError(
-                "This test do not have op_type in class attrs, "
-                "please set self.__class__.op_type=the_real_op_type manually.")
+        if cls.dtype == np.float16:
+            place = paddle.XPUPlace(0)
+            if core.is_float16_supported(place) == False:
+                return
+        super().tearDownClass()

-        # case in NO_FP64_CHECK_GRAD_CASES and op in NO_FP64_CHECK_GRAD_OP_LIST should be fixed
-        if not hasattr(cls, "no_need_check_grad") \
-                and not is_empty_grad_op(cls.op_type):
-            if cls.dtype is None or \
-                (cls.dtype == np.float16 \
-                    and cls.op_type not in op_accuracy_white_list.NO_FP16_CHECK_GRAD_OP_LIST \
-                    and not hasattr(cls, "exist_check_grad")):
-                raise AssertionError("This test of %s op needs check_grad." %
-                                     cls.op_type)
-
-            # check for op test with fp64 precision, but not check mkldnn op test for now
-            if cls.dtype in [np.float32, np.float64] \
-                    and cls.op_type not in op_accuracy_white_list.NO_FP64_CHECK_GRAD_OP_LIST \
-                    and not hasattr(cls, 'exist_fp64_check_grad') \
-                    and not is_xpu_op_test() \
-                    and not is_mkldnn_op_test() \
-                    and not is_rocm_op_test() \
-                    and not is_npu_op_test():
-                raise AssertionError(
-                    "This test of %s op needs check_grad with fp64 precision." %
-                    cls.op_type)
-
-            if not cls.input_shape_is_large \
-                    and cls.op_type not in check_shape_white_list.NEED_TO_FIX_OP_LIST:
-                raise AssertionError(
-                    "Input's shape should be large than or equal to 100 for " +
-                    cls.op_type + " Op.")
-
-    def try_call_once(self, data_type):
-        if not self.call_once:
-            self.call_once = True
-            if data_type is not None and \
-                    data_type != np.float32:
-                raise AssertionError("Unsupport data type %s in xpu" %
-                                     data_type)
-            self.dtype = data_type
+    def _get_places(self):
+        places = [fluid.XPUPlace(0)]
+        return places

     def check_output_with_place(self,
                                 place,
@@ -133,166 +80,17 @@ def check_output_with_place(self,
                                 check_dygraph=True,
                                 inplace_atol=None):
         self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
-        if self.dtype == np.float64 and \
-                self.op_type not in op_threshold_white_list.NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST:
-            atol = 0
-
-        if self.is_bfloat16_op():
-            check_dygraph = False
-            if hasattr(self, 'force_fp32_output') and getattr(
-                    self, 'force_fp32_output'):
-                atol = 1e-2
-            else:
-                atol = 2
-
-        if no_check_set is not None:
-            if self.op_type not in no_check_set_white_list.no_check_set_white_list:
-                raise AssertionError(
-                    "no_check_set of op %s must be set to None." % self.op_type)
-
-        if check_dygraph:
-            dygraph_outs = self._calc_dygraph_output(
-                place, no_check_set=no_check_set)
-        outs, fetch_list = self._calc_output(place, no_check_set=no_check_set)
-        for out_name, out_dup in Operator.get_op_outputs(self.op_type):
-            if out_name not in self.outputs:
-                continue
-            if no_check_set is not None and out_name in no_check_set:
-                continue
-
-            def find_imperative_actual(target_name, dygraph_outs, place):
-                with fluid.dygraph.base.guard(place=place):
-                    for name in dygraph_outs:
-                        if name == target_name:
-                            return dygraph_outs[name][0]
-                        var_list = dygraph_outs[name]
-                        for i, var in enumerate(var_list):
-                            if var.name == target_name:
-                                return dygraph_outs[name][i]
-                    self.assertTrue(False, "Found failed {} {}".format(
-                        dygraph_outs.keys(), target_name))
-
-            def find_actual(target_name, fetch_list):
-                found = [
-                    i for i, var_name in enumerate(fetch_list)
-                    if var_name == target_name
-                ]
-                self.assertTrue(
-                    len(found) == 1, "Found {} {}".format(
-                        len(found), target_name))
-                return found[0]
-
-            if out_dup:
-                sub_out = self.outputs[out_name]
-                if not isinstance(sub_out, list):
-                    raise AssertionError("sub_out type %s is not list",
-                                         type(sub_out))
-                for item in sub_out:
-                    sub_out_name, expect = item[0], item[1]
-                    if check_dygraph:
-                        imperative_actual = find_imperative_actual(
-                            sub_out_name, dygraph_outs, place)
-                        imperative_actual_t = np.array(imperative_actual.value()
-                                                       .get_tensor())
-                    idx = find_actual(sub_out_name, fetch_list)
-                    actual = outs[idx]
-                    actual_t = np.array(actual)
-                    expect_t = expect[0] \
-                        if isinstance(expect, tuple) else expect
-                    self.assertTrue(
-                        np.allclose(
-                            actual_t, expect_t, atol=atol, equal_nan=equal_nan),
-                        "Output (" + sub_out_name + ") has diff at " +
-                        str(place))
-                    if check_dygraph:
-                        self.assertTrue(
-                            np.allclose(
-                                imperative_actual_t,
-                                expect_t,
-                                atol=atol,
-                                equal_nan=equal_nan),
-                            "Output (" + sub_out_name + ") has diff at " +
-                            str(place) + " in dygraph mode")
-                    if isinstance(expect, tuple):
-                        self.assertListEqual(
-                            actual.recursive_sequence_lengths(), expect[1],
-                            "Output (" + sub_out_name +
-                            ") has different lod at " + str(place))
-                        if check_dygraph:
-                            self.assertListEqual(
-                                imperative_actual.value().get_tensor()
-                                .recursive_sequence_lengths(), expect[1],
-                                "Output (" + out_name +
-                                ") has different lod at " + str(place) +
-                                " in dygraph mode")
-            else:
-                if check_dygraph:
-                    imperative_actual = find_imperative_actual(
-                        out_name, dygraph_outs, place)
-                    imperative_actual_t = np.array(imperative_actual.value()
-                                                   .get_tensor())
-                idx = find_actual(out_name, fetch_list)
-                actual = outs[idx]
-                actual_t = np.array(actual)
-                expect = self.outputs[out_name]
-                expect_t = expect[0] if isinstance(expect, tuple) else expect
-                self.assertTrue(
-                    np.allclose(
-                        actual_t, expect_t, atol=atol, equal_nan=equal_nan),
-                    "Output (" + out_name + ") has diff at " + str(place) +
-                    "\nExpect " + str(expect_t) + "\n" + "But Got" +
-                    str(actual_t) + " in class " + self.__class__.__name__ + " "
-                    + str(atol) + " " + str(expect_t - actual_t))
-                if check_dygraph:
-                    if six.moves.reduce(
-                            lambda x, y: x * y, imperative_actual_t.shape,
-                            1) == 0 and six.moves.reduce(
-                                lambda x, y: x * y, expect_t.shape, 1) == 0:
-                        pass
-                    else:
-                        self.assertTrue(
-                            np.allclose(
-                                imperative_actual_t,
-                                expect_t,
-                                atol=atol,
-                                equal_nan=equal_nan),
-                            "Output (" + out_name + ") has diff at " +
-                            str(place) + "\nExpect " + str(expect_t) + "\n" +
-                            "But Got" + str(imperative_actual_t) + " in class "
-                            + self.__class__.__name__)
-                if isinstance(expect, tuple):
-                    self.assertListEqual(actual.recursive_sequence_lengths(),
-                                         expect[1], "Output (" + out_name +
-                                         ") has different lod at " + str(place))
-                    if check_dygraph:
-                        self.assertListEqual(
-                            imperative_actual.value().get_tensor()
-                            .recursive_sequence_lengths(), expect[1],
-                            "Output (" + out_name + ") has different lod at " +
-                            str(place) + " in dygraph mode")
-
-        # Note(zhiqiu): inplace_atol should be only set when op doesn't ensure
-        # computational consistency.
-        # For example, group_norm uses AtomicAdd on CUDAPlace, which do not ensure
-        # computation order when multiple threads write the same address. So the
-        # result of group_norm is non-deterministic when datatype is float.
-        # When inplace_atol is not None, the inplace check uses numpy.allclose
-        # to check inplace result instead of numpy.array_equal.
-        if inplace_atol is not None:
-            warnings.warn(
-                "inplace_atol should only be set when op doesn't ensure computational consistency, please check it!"
-            )
-        # Check inplace for given op, its grad op, its grad_grad op, etc.
-        # No effect on original OpTest
-        # Currently not support ParallelExecutor on XPUPlace.
-        if not paddle.is_compiled_with_xpu():
-            self.check_inplace_output_with_place(
-                place, no_check_set=no_check_set, inplace_atol=inplace_atol)
-
-        if check_dygraph:
-            return outs
-        else:
-            return outs
+        #xpu not support float64
+        if self.dtype == np.float64:
+            return
+        if place == None:
+            place = paddle.XPUPlace(0)
+
+        if self.dtype == np.float16:
+            if core.is_float16_supported(place) == False:
+                return
+        return super().check_output_with_place(
+            place, atol, no_check_set, equal_nan, check_dygraph, inplace_atol)

     def check_grad_with_place(self,
                               place,
@@ -303,8 +101,25 @@ def check_grad_with_place(self,
                               in_place=False,
                               max_relative_error=0.005,
                               user_defined_grads=None,
-                              check_dygraph=True):
-        place = paddle.XPUPlace(0)
+                              user_defined_grad_outputs=None,
+                              check_dygraph=True,
+                              numeric_place=None):
+        if place == None:
+            place = paddle.XPUPlace(0)
+
+        if self.dtype == np.float64:
+            return
+
+        if self.dtype == np.float16:
+            if core.is_float16_supported(place) == False:
+                return
+
+        if self.dtype == np.float16:
+            return super().check_grad_with_place(
+                place, inputs_to_check, output_names, no_grad_set,
+                numeric_grad_delta, in_place, max_relative_error,
+                user_defined_grads, user_defined_grads, check_dygraph)
+
         a1 = self.get_grad_with_place(
            place, inputs_to_check, output_names, no_grad_set=no_grad_set)
         a2 = self.get_grad_with_place(
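
Editor's note: for context, a hypothetical fp16 test case built on the refactored harness might look like the sketch below (the class name and attribute values are illustrative, not taken from this commit). check_output_with_place now defaults the place to paddle.XPUPlace(0), skips fp64 outright, and skips fp16 when core.is_float16_supported reports the device cannot run it.

    import unittest
    import numpy as np
    import paddle
    from op_test_xpu import XPUOpTest

    paddle.enable_static()


    class TestScaleFp16XPUOp(XPUOpTest):
        def setUp(self):
            self.op_type = "scale"
            self.dtype = np.float16
            # scale=2.0 is exactly representable in fp16, so the reference
            # output below matches the kernel result without tolerance issues.
            self.inputs = {'X': np.random.random((10, 10)).astype(self.dtype)}
            self.attrs = {'scale': 2.0}
            self.outputs = {
                'Out': self.inputs['X'] * self.dtype(self.attrs['scale'])
            }

        def test_check_output(self):
            if paddle.is_compiled_with_xpu():
                self.check_output_with_place(paddle.XPUPlace(0))


    if __name__ == "__main__":
        unittest.main()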
