19 changes: 13 additions & 6 deletions paddle/fluid/operators/scale_op_xpu.cc
@@ -22,12 +22,14 @@ namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class ScaleXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;

public:
virtual void Compute(const framework::ExecutionContext& ctx) const {
auto* in_var = ctx.InputVar("X");
auto* in = framework::GetLoDTensorOrSelectedRowsValueFromVar(*in_var);
auto scale = static_cast<T>(ctx.Attr<float>("scale"));
auto bias = static_cast<T>(ctx.Attr<float>("bias"));
auto scale = static_cast<float>(ctx.Attr<float>("scale"));
auto bias = static_cast<float>(ctx.Attr<float>("bias"));
auto bias_after_scale = ctx.Attr<bool>("bias_after_scale");
auto* out_var = ctx.OutputVar("Out");
if (in_var->IsType<framework::SelectedRows>() && in_var != out_var) {
@@ -46,9 +48,10 @@ class ScaleXPUKernel : public framework::OpKernel<T> {
in->dims().to_str().c_str(),
out->dims().to_str().c_str()));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int r =
xpu::scale(dev_ctx.x_context(), in->data<float>(), out->data<float>(),
in->numel(), bias_after_scale, scale, bias);
int r = xpu::scale(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(in->data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()), in->numel(),
bias_after_scale, scale, bias);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU scale kernel return wrong value[%d %s]",
@@ -60,7 +63,11 @@ class ScaleXPUKernel : public framework::OpKernel<T> {
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_XPU_KERNEL(
scale, ops::ScaleXPUKernel<paddle::platform::XPUDeviceContext, float>);
scale, ops::ScaleXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::ScaleXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>,
ops::ScaleXPUKernel<paddle::platform::XPUDeviceContext, int64_t>);

#endif
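With ScaleXPUKernel now templated on XPUType and registered for float16 and int64, a quick smoke test along the following lines could exercise the new dtypes. This is a minimal Python sketch, assuming a Paddle build with XPU support and a visible XPU device; the shapes and attribute values are illustrative and not part of this PR.

import numpy as np
import paddle

paddle.set_device("xpu")  # assumes an XPU build; dispatch eager ops to the XPU

# float16 path, newly registered above
x_fp16 = paddle.to_tensor(np.random.rand(4, 8).astype("float16"))
y_fp16 = paddle.scale(x_fp16, scale=2.0, bias=1.0, bias_after_scale=True)

# int64 path, newly registered above
x_int64 = paddle.to_tensor(np.arange(16, dtype="int64"))
y_int64 = paddle.scale(x_int64, scale=3.0, bias=0.0)

print(y_fp16.dtype, y_int64.dtype)  # paddle.float16, paddle.int64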
3 changes: 3 additions & 0 deletions paddle/fluid/platform/xpu/xpu2_op_list.h
@@ -184,6 +184,9 @@ XPUOpMap& get_kl2_ops() {
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"scale", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace())})}
// AddMore
};

8 changes: 8 additions & 0 deletions paddle/fluid/pybind/pybind.cc
@@ -1709,6 +1709,14 @@ All parameter, weight, gradient are variables in Paddle.
m.def("get_xpu_device_count", platform::GetXPUDeviceCount);
m.def("get_xpu_device_version",
[](int device_id) { return platform::get_xpu_version(device_id); });
m.def("is_float16_supported", [](const platform::XPUPlace &place) -> bool {
Contributor: Could "xpu" be added to the interface name to make it clear that this interface is meant specifically for XPU devices?

Contributor Author: This is to stay consistent with GPU; GPU has the same interface with the same name, so I reused it here.

// XPUs newer than XPU1 (i.e., XPU2 and later) support float16 and bfloat16
return platform::get_xpu_version(place.device) > platform::XPUVersion::XPU1;
});
m.def("is_bfloat16_supported", [](const platform::XPUPlace &place) -> bool {
// XPUs newer than XPU1 (i.e., XPU2 and later) support float16 and bfloat16
return platform::get_xpu_version(place.device) > platform::XPUVersion::XPU1;
});
#endif

py::class_<paddle::platform::CPUPlace>(m, "CPUPlace", R"DOC(
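The two bindings above mirror the GPU helpers of the same name, as noted in the review thread. From Python they could be queried roughly as follows; a hedged sketch that assumes an XPU build of Paddle, matching how op_test_xpu.py calls them below.

import paddle
from paddle.fluid import core

place = paddle.XPUPlace(0)
# Both return True on XPU2 and later, per the version check in the bindings.
print(core.is_float16_supported(place))
print(core.is_bfloat16_supported(place))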
275 changes: 45 additions & 230 deletions python/paddle/fluid/tests/unittests/op_test_xpu.py
@@ -44,86 +44,33 @@ class XPUOpTest(OpTest):
@classmethod
def setUpClass(cls):
'''Fix random seeds to remove randomness from tests'''
cls._np_rand_state = np.random.get_state()
cls._py_rand_state = random.getstate()
cls.call_once = False
cls.dtype = np.float32
cls.outputs = {}
cls.input_shape_is_large = True

np.random.seed(123)
random.seed(124)

cls._use_system_allocator = _set_use_system_allocator(True)
cls.use_xpu = True
cls.use_mkldnn = False
super().setUpClass()

@classmethod
def tearDownClass(cls):
"""Restore random seeds"""
np.random.set_state(cls._np_rand_state)
random.setstate(cls._py_rand_state)

_set_use_system_allocator(cls._use_system_allocator)

def is_empty_grad_op(op_type):
all_op_kernels = core._get_all_register_op_kernels()
grad_op = op_type + '_grad'
if grad_op in all_op_kernels.keys():
if is_mkldnn_op_test():
grad_op_kernels = all_op_kernels[grad_op]
for grad_op_kernel in grad_op_kernels:
if 'MKLDNN' in grad_op_kernel:
return False
else:
return False
grad_op_kernels = all_op_kernels[grad_op]
for grad_op_kernel in grad_op_kernels:
if 'XPU' in grad_op_kernel:
return False
return True

def is_xpu_op_test():
return True

def is_mkldnn_op_test():
return False

if not hasattr(cls, "op_type"):
raise AssertionError(
"This test do not have op_type in class attrs, "
"please set self.__class__.op_type=the_real_op_type manually.")
if cls.dtype == np.float16:
place = paddle.XPUPlace(0)
if core.is_float16_supported(place) == False:
return
super().tearDownClass()

# case in NO_FP64_CHECK_GRAD_CASES and op in NO_FP64_CHECK_GRAD_OP_LIST should be fixed
if not hasattr(cls, "no_need_check_grad") \
and not is_empty_grad_op(cls.op_type):
if cls.dtype is None or \
(cls.dtype == np.float16 \
and cls.op_type not in op_accuracy_white_list.NO_FP16_CHECK_GRAD_OP_LIST \
and not hasattr(cls, "exist_check_grad")):
raise AssertionError("This test of %s op needs check_grad." %
cls.op_type)

# check for op test with fp64 precision, but not check mkldnn op test for now
if cls.dtype in [np.float32, np.float64] \
and cls.op_type not in op_accuracy_white_list.NO_FP64_CHECK_GRAD_OP_LIST \
and not hasattr(cls, 'exist_fp64_check_grad') \
and not is_xpu_op_test() \
and not is_mkldnn_op_test() \
and not is_rocm_op_test() \
and not is_npu_op_test():
raise AssertionError(
"This test of %s op needs check_grad with fp64 precision." %
cls.op_type)

if not cls.input_shape_is_large \
and cls.op_type not in check_shape_white_list.NEED_TO_FIX_OP_LIST:
raise AssertionError(
"Input's shape should be large than or equal to 100 for " +
cls.op_type + " Op.")

def try_call_once(self, data_type):
if not self.call_once:
self.call_once = True
if data_type is not None and \
data_type != np.float32:
raise AssertionError("Unsupport data type %s in xpu" %
data_type)
self.dtype = data_type
def _get_places(self):
places = [fluid.XPUPlace(0)]
return places

def check_output_with_place(self,
place,
@@ -133,166 +80,17 @@ def check_output_with_place(self,
check_dygraph=True,
inplace_atol=None):
self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
if self.dtype == np.float64 and \
self.op_type not in op_threshold_white_list.NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST:
atol = 0

if self.is_bfloat16_op():
check_dygraph = False
if hasattr(self, 'force_fp32_output') and getattr(
self, 'force_fp32_output'):
atol = 1e-2
else:
atol = 2

if no_check_set is not None:
if self.op_type not in no_check_set_white_list.no_check_set_white_list:
raise AssertionError(
"no_check_set of op %s must be set to None." % self.op_type)

if check_dygraph:
dygraph_outs = self._calc_dygraph_output(
place, no_check_set=no_check_set)
outs, fetch_list = self._calc_output(place, no_check_set=no_check_set)
for out_name, out_dup in Operator.get_op_outputs(self.op_type):
if out_name not in self.outputs:
continue
if no_check_set is not None and out_name in no_check_set:
continue

def find_imperative_actual(target_name, dygraph_outs, place):
with fluid.dygraph.base.guard(place=place):
for name in dygraph_outs:
if name == target_name:
return dygraph_outs[name][0]
var_list = dygraph_outs[name]
for i, var in enumerate(var_list):
if var.name == target_name:
return dygraph_outs[name][i]
self.assertTrue(False, "Found failed {} {}".format(
dygraph_outs.keys(), target_name))

def find_actual(target_name, fetch_list):
found = [
i for i, var_name in enumerate(fetch_list)
if var_name == target_name
]
self.assertTrue(
len(found) == 1, "Found {} {}".format(
len(found), target_name))
return found[0]

if out_dup:
sub_out = self.outputs[out_name]
if not isinstance(sub_out, list):
raise AssertionError("sub_out type %s is not list",
type(sub_out))
for item in sub_out:
sub_out_name, expect = item[0], item[1]
if check_dygraph:
imperative_actual = find_imperative_actual(
sub_out_name, dygraph_outs, place)
imperative_actual_t = np.array(imperative_actual.value()
.get_tensor())
idx = find_actual(sub_out_name, fetch_list)
actual = outs[idx]
actual_t = np.array(actual)
expect_t = expect[0] \
if isinstance(expect, tuple) else expect
self.assertTrue(
np.allclose(
actual_t, expect_t, atol=atol, equal_nan=equal_nan),
"Output (" + sub_out_name + ") has diff at " +
str(place))
if check_dygraph:
self.assertTrue(
np.allclose(
imperative_actual_t,
expect_t,
atol=atol,
equal_nan=equal_nan),
"Output (" + sub_out_name + ") has diff at " +
str(place) + " in dygraph mode")
if isinstance(expect, tuple):
self.assertListEqual(
actual.recursive_sequence_lengths(), expect[1],
"Output (" + sub_out_name +
") has different lod at " + str(place))
if check_dygraph:
self.assertListEqual(
imperative_actual.value().get_tensor()
.recursive_sequence_lengths(), expect[1],
"Output (" + out_name +
") has different lod at " + str(place) +
" in dygraph mode")
else:
if check_dygraph:
imperative_actual = find_imperative_actual(
out_name, dygraph_outs, place)
imperative_actual_t = np.array(imperative_actual.value()
.get_tensor())
idx = find_actual(out_name, fetch_list)
actual = outs[idx]
actual_t = np.array(actual)
expect = self.outputs[out_name]
expect_t = expect[0] if isinstance(expect, tuple) else expect
self.assertTrue(
np.allclose(
actual_t, expect_t, atol=atol, equal_nan=equal_nan),
"Output (" + out_name + ") has diff at " + str(place) +
"\nExpect " + str(expect_t) + "\n" + "But Got" +
str(actual_t) + " in class " + self.__class__.__name__ + " "
+ str(atol) + " " + str(expect_t - actual_t))
if check_dygraph:
if six.moves.reduce(
lambda x, y: x * y, imperative_actual_t.shape,
1) == 0 and six.moves.reduce(
lambda x, y: x * y, expect_t.shape, 1) == 0:
pass
else:
self.assertTrue(
np.allclose(
imperative_actual_t,
expect_t,
atol=atol,
equal_nan=equal_nan),
"Output (" + out_name + ") has diff at " +
str(place) + "\nExpect " + str(expect_t) + "\n" +
"But Got" + str(imperative_actual_t) + " in class "
+ self.__class__.__name__)
if isinstance(expect, tuple):
self.assertListEqual(actual.recursive_sequence_lengths(),
expect[1], "Output (" + out_name +
") has different lod at " + str(place))
if check_dygraph:
self.assertListEqual(
imperative_actual.value().get_tensor()
.recursive_sequence_lengths(), expect[1],
"Output (" + out_name + ") has different lod at " +
str(place) + " in dygraph mode")

# Note(zhiqiu): inplace_atol should be only set when op doesn't ensure
# computational consistency.
# For example, group_norm uses AtomicAdd on CUDAPlace, which do not ensure
# computation order when multiple threads write the same address. So the
# result of group_norm is non-deterministic when datatype is float.
# When inplace_atol is not None, the inplace check uses numpy.allclose
# to check inplace result instead of numpy.array_equal.
if inplace_atol is not None:
warnings.warn(
"inplace_atol should only be set when op doesn't ensure computational consistency, please check it!"
)
# Check inplace for given op, its grad op, its grad_grad op, etc.
# No effect on original OpTest
# Currently not support ParallelExecutor on XPUPlace.
if not paddle.is_compiled_with_xpu():
self.check_inplace_output_with_place(
place, no_check_set=no_check_set, inplace_atol=inplace_atol)

if check_dygraph:
return outs
else:
return outs
# XPU does not support float64
if self.dtype == np.float64:
return
if place == None:
place = paddle.XPUPlace(0)

if self.dtype == np.float16:
if core.is_float16_supported(place) == False:
return
return super().check_output_with_place(
place, atol, no_check_set, equal_nan, check_dygraph, inplace_atol)

def check_grad_with_place(self,
place,
@@ -303,8 +101,25 @@ def check_grad_with_place(self,
in_place=False,
max_relative_error=0.005,
user_defined_grads=None,
check_dygraph=True):
place = paddle.XPUPlace(0)
user_defined_grad_outputs=None,
check_dygraph=True,
numeric_place=None):
if place == None:
place = paddle.XPUPlace(0)

if self.dtype == np.float64:
return

if self.dtype == np.float16:
if core.is_float16_supported(place) == False:
return

if self.dtype == np.float16:
return super().check_grad_with_place(
place, inputs_to_check, output_names, no_grad_set,
numeric_grad_delta, in_place, max_relative_error,
user_defined_grads, user_defined_grad_outputs, check_dygraph)

a1 = self.get_grad_with_place(
place, inputs_to_check, output_names, no_grad_set=no_grad_set)
a2 = self.get_grad_with_place(
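For reference, a test written against the refactored harness would look roughly like the sketch below. The class name, shapes, and expected values are hypothetical and not part of this PR; it only illustrates how the slimmed-down check_output_with_place and check_grad_with_place are meant to be driven.

import unittest

import numpy as np
import paddle
from op_test_xpu import XPUOpTest


class TestScaleFp16(XPUOpTest):
    def setUp(self):
        self.op_type = "scale"
        self.__class__.op_type = "scale"  # tearDownClass checks the class attribute
        self.dtype = np.float16
        x = np.random.random((10, 10)).astype(self.dtype)
        self.inputs = {'X': x}
        self.attrs = {'scale': -2.3, 'bias': 0.0, 'bias_after_scale': True}
        self.outputs = {'Out': x * np.float16(self.attrs['scale'])}

    def test_check_output(self):
        # Returns early (test skipped) on devices where fp16 is unsupported.
        self.check_output_with_place(paddle.XPUPlace(0))

    def test_check_grad(self):
        self.check_grad_with_place(paddle.XPUPlace(0), ['X'], 'Out')


if __name__ == '__main__':
    unittest.main()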