From 5117618d8d1b2270770a35b15f121ef9a97033ca Mon Sep 17 00:00:00 2001
From: Jianbang Yang
Date: Wed, 17 Apr 2024 13:59:08 +0800
Subject: [PATCH] [XPU] add bfloat16 support for compare_kernel and add
 reduce_all_kernel

---
 paddle/phi/backends/xpu/xpu2_op_list.cc     |  1 +
 paddle/phi/backends/xpu/xpu3_op_list.cc     |  7 +++
 paddle/phi/kernels/reduce_all_kernel.cc     |  4 ++
 paddle/phi/kernels/xpu/compare_kernel.cc    |  4 +-
 paddle/phi/kernels/xpu/reduce_all_kernel.cc | 51 +++++++++++++++++++++
 test/xpu/test_reduce_all_op_xpu.py          | 30 ++++++++----
 6 files changed, 86 insertions(+), 11 deletions(-)
 create mode 100644 paddle/phi/kernels/xpu/reduce_all_kernel.cc

diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc
index 9698544b3738fd..167dcee1f88cb2 100644
--- a/paddle/phi/backends/xpu/xpu2_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -741,6 +741,7 @@ XPUOpMap& get_kl2_ops() {
     {"reciprocal", XPUKernelSet({phi::DataType::FLOAT32})},
     {"reciprocal_grad",
      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
+    {"reduce_all", XPUKernelSet({phi::DataType::BOOL})},
     {"reduce_any", XPUKernelSet({phi::DataType::BOOL})},
     {"reduce_max_grad", XPUKernelSet({phi::DataType::FLOAT32})},
     {"reduce_max",
diff --git a/paddle/phi/backends/xpu/xpu3_op_list.cc b/paddle/phi/backends/xpu/xpu3_op_list.cc
index 779f35a483bc71..35f9f8c359bc41 100644
--- a/paddle/phi/backends/xpu/xpu3_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu3_op_list.cc
@@ -349,6 +349,7 @@ XPUOpMap& get_kl3_ops() {
       XPUKernelSet({phi::DataType::INT64,
                     phi::DataType::INT32,
                     phi::DataType::FLOAT16,
+                    phi::DataType::BFLOAT16,
                     phi::DataType::FLOAT32,
                     phi::DataType::BOOL})},
     {"exp_grad", XPUKernelSet({phi::DataType::FLOAT32})},
@@ -517,11 +518,13 @@ XPUOpMap& get_kl3_ops() {
       XPUKernelSet({phi::DataType::INT64,
                     phi::DataType::INT32,
                     phi::DataType::FLOAT16,
+                    phi::DataType::BFLOAT16,
                     phi::DataType::FLOAT32})},
     {"greater_than",
      XPUKernelSet({phi::DataType::INT64,
                    phi::DataType::INT32,
                    phi::DataType::FLOAT16,
+                   phi::DataType::BFLOAT16,
                    phi::DataType::FLOAT32})},
     {"grid_sampler_grad", XPUKernelSet({phi::DataType::FLOAT32})},
     {"group_norm_silu_xpu",
@@ -576,11 +579,13 @@ XPUOpMap& get_kl3_ops() {
       XPUKernelSet({phi::DataType::INT64,
                     phi::DataType::INT32,
                     phi::DataType::FLOAT16,
+                    phi::DataType::BFLOAT16,
                     phi::DataType::FLOAT32})},
     {"less_than",
      XPUKernelSet({phi::DataType::INT64,
                    phi::DataType::INT32,
                    phi::DataType::FLOAT16,
+                   phi::DataType::BFLOAT16,
                    phi::DataType::FLOAT32})},
     {"load", XPUKernelSet({phi::DataType::FLOAT32})},
     {"load_combine",
@@ -669,6 +674,7 @@ XPUOpMap& get_kl3_ops() {
       XPUKernelSet({phi::DataType::INT64,
                     phi::DataType::INT32,
                     phi::DataType::FLOAT16,
+                    phi::DataType::BFLOAT16,
                     phi::DataType::FLOAT32})},
     {"one_hot", XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})},
     {"one_hot_v2",
@@ -716,6 +722,7 @@ XPUOpMap& get_kl3_ops() {
     {"reciprocal", XPUKernelSet({phi::DataType::FLOAT32})},
     {"reciprocal_grad",
      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
+    {"reduce_all", XPUKernelSet({phi::DataType::BOOL})},
     {"reduce_any", XPUKernelSet({phi::DataType::BOOL})},
     {"reduce_max_grad", XPUKernelSet({phi::DataType::FLOAT32})},
     {"reduce_max",
diff --git a/paddle/phi/kernels/reduce_all_kernel.cc b/paddle/phi/kernels/reduce_all_kernel.cc
index d6f88a596af3ac..92bc5e97cc0211 100644
--- a/paddle/phi/kernels/reduce_all_kernel.cc
+++ b/paddle/phi/kernels/reduce_all_kernel.cc
@@ -53,3 +53,7 @@ PD_REGISTER_KERNEL(
 #if defined(PADDLE_WITH_XPU_KP)
 PD_REGISTER_KERNEL(all, KPS, ALL_LAYOUT, phi::AllKernel, bool) {}
 #endif
+
+#if defined(PADDLE_WITH_XPU)
+PD_REGISTER_KERNEL(all, XPU, ALL_LAYOUT, phi::AllKernel, bool) {}
+#endif
diff --git a/paddle/phi/kernels/xpu/compare_kernel.cc b/paddle/phi/kernels/xpu/compare_kernel.cc
index 2732823fd94282..d0878e6749711e 100644
--- a/paddle/phi/kernels/xpu/compare_kernel.cc
+++ b/paddle/phi/kernels/xpu/compare_kernel.cc
@@ -88,7 +88,8 @@ PD_REGISTER_KERNEL(less_than,
                    int,
                    int64_t,
                    float,
-                   phi::dtype::float16) {
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
 }

@@ -101,6 +102,7 @@ PD_REGISTER_KERNEL(less_than,
                      int64_t,                                  \
                      float,                                    \
                      phi::dtype::float16,                      \
+                     phi::dtype::bfloat16,                     \
                      bool) {                                   \
     kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);      \
   }
diff --git a/paddle/phi/kernels/xpu/reduce_all_kernel.cc b/paddle/phi/kernels/xpu/reduce_all_kernel.cc
new file mode 100644
index 00000000000000..e9731db88c7a03
--- /dev/null
+++ b/paddle/phi/kernels/xpu/reduce_all_kernel.cc
@@ -0,0 +1,51 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/reduce_all_kernel.h"
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/backends/xpu/xpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/xpu/reduce.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void AllRawKernel(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  const std::vector<int64_t>& dims,
+                  bool keep_dim,
+                  bool reduce_all,
+                  DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims);
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  auto f = [](xpu::Context* ctx,
+              const T* x,
+              T* y,
+              const std::vector<int>& xdims,
+              const std::vector<int>& reduce_dims) {
+    return xpu::reduce_all(ctx,
+                           reinterpret_cast<const XPUType*>(x),
+                           reinterpret_cast<XPUType*>(y),
+                           xdims,
+                           reduce_dims);
+  };
+
+  int r = XPUReduce<Context, T>(dev_ctx, x, dims, keep_dim, reduce_all, out, f);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_all");
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(all_raw, XPU, ALL_LAYOUT, phi::AllRawKernel, bool) {}
diff --git a/test/xpu/test_reduce_all_op_xpu.py b/test/xpu/test_reduce_all_op_xpu.py
index 313d8297a17054..2d11d04ad63db3 100644
--- a/test/xpu/test_reduce_all_op_xpu.py
+++ b/test/xpu/test_reduce_all_op_xpu.py
@@ -40,8 +40,8 @@ def set_case(self):
         self.op_type = 'reduce_all'
         self.attrs = {
             'use_xpu': True,
-            'reduce_all': True,
-            'keep_dim': True,
+            'reduce_all': False,
+            'keep_dim': False,
             'dim': (3, 5, 4),
         }
         self.inputs = {
@@ -49,7 +49,11 @@ def set_case(self):
                 "bool"
             )
         }
-        self.outputs = {'Out': self.inputs['X'].all(axis=self.attrs['dim'])}
+        self.outputs = {
+            'Out': self.inputs['X'].all(
+                axis=self.attrs['dim'], keepdims=self.attrs['keep_dim']
+            )
+        }

     def test_check_output(self):
         self.check_output_with_place(self.place)
@@ -63,7 +67,7 @@ def set_case(self):
         self.attrs = {
             'use_xpu': True,
             'reduce_all': True,
-            'keep_dim': True,
+            'keep_dim': False,
             'dim': [1],
         }
         self.inputs = {
@@ -76,8 +80,8 @@ def set_case(self):
         self.op_type = 'reduce_all'
         self.attrs = {
             'use_xpu': True,
-            'reduce_all': True,
-            'keep_dim': False,
+            'reduce_all': False,
+            'keep_dim': True,
             'dim': (3, 6),
         }
         self.inputs = {
@@ -85,22 +89,28 @@ def set_case(self):
                 "bool"
            )
         }
-        self.outputs = {'Out': self.inputs['X'].all(axis=self.attrs['dim'])}
+        self.outputs = {
+            'Out': self.inputs['X'].all(
+                axis=self.attrs['dim'], keepdims=self.attrs['keep_dim']
+            )
+        }


 class XPUTestReduceAllCase3(XPUTestReduceAllBase):
     def set_case(self):
         self.op_type = 'reduce_all'
         self.attrs = {
             'use_xpu': True,
+            'reduce_all': True,
             'keep_dim': True,
-            'dim': [1]
-            # 'reduce_all': True,
+            'dim': [1],
         }
         self.inputs = {
             'X': np.random.randint(0, 2, (5, 6, 10)).astype("bool")
         }
         self.outputs = {
-            'Out': np.expand_dims(self.inputs['X'].all(axis=1), axis=1)
+            'Out': self.inputs['X'].all(
+                axis=(0, 1, 2), keepdims=self.attrs['keep_dim']
+            )
         }
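
Note for reviewers (not part of the patch): the sketch below is a minimal usage example of the two behaviors this change enables, assuming a Paddle build with XPU support and at least one XPU device visible to the runtime — paddle.all on a bool tensor (served by the new XPU reduce_all kernel) and bfloat16 operands for the comparison kernels.

# Usage sketch, not part of the patch. Assumes a Paddle build with XPU
# support and an available "xpu" device.
import paddle

paddle.set_device("xpu")

# paddle.all on a bool tensor now dispatches to the XPU reduce_all kernel.
x = paddle.to_tensor([[True, False], [True, True]])
print(paddle.all(x))          # False: not every element is True
print(paddle.all(x, axis=1))  # per-row reduction -> [False, True]

# Comparison kernels now accept bfloat16 operands on XPU.
a = paddle.to_tensor([1.0, 2.0, 3.0], dtype="bfloat16")
b = paddle.to_tensor([2.0, 2.0, 2.0], dtype="bfloat16")
print(a < b)  # elementwise less_than, bool output -> [True, False, False]

Dispatch relies on the op-list entries above: an op/dtype pair absent from the XPU op list is not placed on the XPU device, which is why the patch registers "reduce_all" for BOOL and adds BFLOAT16 to each comparison op alongside the kernel changes.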