From d3e3e31c61f7e534ba009472d30d71a27dd6684c Mon Sep 17 00:00:00 2001 From: tongkai <1092019531@qq.com> Date: Thu, 7 Mar 2024 19:47:26 +0800 Subject: [PATCH 01/13] customdevice support stride slice & slice_grad --- paddle/phi/api/lib/api_gen_utils.cc | 75 +++++++++++++++++++ paddle/phi/api/lib/data_transform.cc | 20 +++++ paddle/phi/backends/device_ext.h | 1 + paddle/phi/core/kernel_factory.cc | 11 +++ paddle/phi/core/visit_type.h | 15 ++++ .../phi/kernels/stride/slice_grad_kernel.cc | 58 ++++++++++---- paddle/phi/kernels/stride/slice_kernel.cc | 5 +- 7 files changed, 168 insertions(+), 17 deletions(-) diff --git a/paddle/phi/api/lib/api_gen_utils.cc b/paddle/phi/api/lib/api_gen_utils.cc index 87e6f9af430755..331ceec3c53c70 100644 --- a/paddle/phi/api/lib/api_gen_utils.cc +++ b/paddle/phi/api/lib/api_gen_utils.cc @@ -416,6 +416,31 @@ void TransStride(phi::DeviceContext* dev_ctx, delete from; return; } +#endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE + auto* custom_ctx = dynamic_cast(dev_ctx); + if (custom_ctx) { + const phi::KernelKey& kernel_key = {phi::TransToPhiBackend(to->place()), + phi::DataLayout::ALL_LAYOUT, + to->dtype()}; + using kernel_signature = void (*)(const phi::DeviceContext&, + const phi::DenseTensor&, + const std::vector&, + const std::vector&, + int64_t, + phi::DenseTensor*); + PD_VISIT_KERNEL("strided_copy", + kernel_key, + kernel_signature, + *custom_ctx, + *from, + common::vectorize(to->dims()), + common::vectorize(to->strides()), + to->offset(), + to); + delete from; + return; + } #endif } } @@ -466,6 +491,30 @@ void TransStrideLegacy(phi::DeviceContext* dev_ctx, })); return; } +#endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE + auto* custom_ctx = dynamic_cast(dev_ctx); + if (custom_ctx) { + const phi::KernelKey& kernel_key = {phi::TransToPhiBackend(to->place()), + phi::DataLayout::ALL_LAYOUT, + to->dtype()}; + using kernel_signature = void (*)(const phi::DeviceContext&, + const phi::DenseTensor&, + const std::vector&, + const std::vector&, + int64_t, + phi::DenseTensor*); + PD_VISIT_KERNEL("strided_copy", + kernel_key, + kernel_signature, + *custom_ctx, + *from, + common::vectorize(to->dims()), + common::vectorize(to->strides()), + to->offset(), + to); + return; + } #endif } } @@ -520,6 +569,32 @@ void TransStride(phi::DeviceContext* dev_ctx, delete from[i]; continue; } +#endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE + auto* custom_ctx = dynamic_cast(dev_ctx); + if (custom_ctx) { + const phi::KernelKey& kernel_key = { + phi::TransToPhiBackend(to[i]->place()), + phi::DataLayout::ALL_LAYOUT, + to[i]->dtype()}; + using kernel_signature = void (*)(const phi::DeviceContext&, + const phi::DenseTensor&, + const std::vector&, + const std::vector&, + int64_t, + phi::DenseTensor*); + PD_VISIT_KERNEL("strided_copy", + kernel_key, + kernel_signature, + *custom_ctx, + *from[i], + common::vectorize(to[i]->dims()), + common::vectorize(to[i]->strides()), + to[i]->offset(), + to[i]); + delete from[i]; + return; + } #endif } } diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index 80bb9f44475734..98f7543cd3a7bf 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -255,6 +255,26 @@ phi::DenseTensor Trans2Contiguous(const phi::DenseTensor& tensor) { } else if (tensor.place().GetType() == phi::AllocationType::XPU) { auto* dev_ctx = static_cast(pool.Get(tensor.place())); return TensorContiguous(*dev_ctx, tensor); +#endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE + } else if (tensor.place().GetType() == phi::AllocationType::CUSTOM) { + auto* dev_ctx = static_cast(pool.Get(tensor.place())); + phi::DenseTensor dense_out; + phi::MetaTensor meta_input(tensor); + phi::MetaTensor meta_out(&dense_out); + UnchangedInferMeta(meta_input, &meta_out); + const phi::KernelKey& kernel_key = {phi::TransToPhiBackend(tensor.place()), + phi::DataLayout::ALL_LAYOUT, + tensor.dtype()}; + using kernel_signature = void (*)( + const phi::DeviceContext&, const phi::DenseTensor&, phi::DenseTensor*); + PD_VISIT_KERNEL("contiguous", + kernel_key, + kernel_signature, + *dev_ctx, + tensor, + &dense_out); + return dense_out; #endif } else { PADDLE_THROW(phi::errors::Unimplemented( diff --git a/paddle/phi/backends/device_ext.h b/paddle/phi/backends/device_ext.h index bd3f5f687f29b1..3e7497b2209dcd 100644 --- a/paddle/phi/backends/device_ext.h +++ b/paddle/phi/backends/device_ext.h @@ -50,6 +50,7 @@ typedef enum { NCHW, NCDHW, NDHWC, + STRIDED, NUM_DATA_LAYOUTS, ALL_LAYOUT = ANY, } C_DataLayout; diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index f04c1b2c880bd0..32644cfe8bf631 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -249,6 +249,17 @@ KernelResult KernelFactory::SelectKernelOrThrowError( if (stride_kernel_iter != iter->second.end()) { return {stride_kernel_iter->second, false, true}; } +#ifdef PADDLE_WITH_CUSTOM_DEVICE + if (stride_kernel_iter == iter->second.end() && + const_kernel_key.backend() > phi::Backend::NUM_BACKENDS) { + stride_kernel_iter = iter->second.find({phi::Backend::CUSTOM, + phi::DataLayout::STRIDED, + const_kernel_key.dtype()}); + if (stride_kernel_iter != iter->second.end()) { + return {stride_kernel_iter->second, false, true}; + } + } +#endif } KernelKey kernel_key = KernelKey(const_kernel_key.backend(), diff --git a/paddle/phi/core/visit_type.h b/paddle/phi/core/visit_type.h index 7ee12e26d7d0ef..53afe73af64438 100644 --- a/paddle/phi/core/visit_type.h +++ b/paddle/phi/core/visit_type.h @@ -14,8 +14,10 @@ limitations under the License. */ #pragma once +#include #include "paddle/common/exception.h" #include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/kernel_factory.h" namespace phi { @@ -471,4 +473,17 @@ namespace phi { } \ }() +#define PD_VISIT_KERNEL(kernel_name, kernel_key, kernel_signature, ...) \ + [&] { \ + auto kernel_result = \ + phi::KernelFactory::Instance().SelectKernelOrThrowError(kernel_name, \ + kernel_key); \ + const auto& kernel = kernel_result.kernel; \ + if (kernel_result.has_fallback_cpu) { \ + VLOG(6) << "missing kernel:" << kernel_name; \ + } \ + VLOG(6) << kernel_name << "kernel: " << kernel; \ + auto* kernel_fn = kernel.GetVariadicKernelFn(); \ + (*kernel_fn)(__VA_ARGS__); \ + }() } // namespace phi diff --git a/paddle/phi/kernels/stride/slice_grad_kernel.cc b/paddle/phi/kernels/stride/slice_grad_kernel.cc index 171c20b3b83acd..00a66f03809c60 100644 --- a/paddle/phi/kernels/stride/slice_grad_kernel.cc +++ b/paddle/phi/kernels/stride/slice_grad_kernel.cc @@ -33,10 +33,22 @@ void SliceGradStridedKernel(const Context& dev_ctx, DenseTensor* input_grad) { dev_ctx.Alloc(input_grad, input_grad->dtype()); input_grad->set_strides(DenseTensorMeta::calc_strides(input_grad->dims())); - PD_VISIT_ALL_TYPES(input.dtype(), "SliceGradStridedKernel", ([&] { - phi::FillKernel( - dev_ctx, *input_grad, 0, input_grad); - })); + const phi::KernelKey& kernel_key = {phi::TransToPhiBackend(input.place()), + phi::DataLayout::ALL_LAYOUT, + input.dtype()}; + using kernel_signature_fill = void (*)( + const DeviceContext&, const DenseTensor&, const Scalar&, DenseTensor*); + PD_VISIT_KERNEL("fill", + kernel_key, + kernel_signature_fill, + dev_ctx, + *input_grad, + 0, + input_grad); + // PD_VISIT_ALL_TYPES(input.dtype(), "SliceGradStridedKernel", ([&] { + // phi::FillKernel( + // dev_ctx, *input_grad, 0, input_grad); + // })); DenseTensor tmp; tmp.set_meta(out_grad.meta()); SliceStridedKernel(dev_ctx, @@ -47,17 +59,33 @@ void SliceGradStridedKernel(const Context& dev_ctx, infer_flags, decrease_axis, &tmp); - PD_VISIT_ALL_TYPES(input.dtype(), "SliceGradStridedKernel", ([&] { - phi::StridedCopyKernel( - dev_ctx, - out_grad, - common::vectorize(tmp.dims()), - common::vectorize(tmp.strides()), - tmp.offset(), - &tmp); - })); + using kernel_signature_strided_copy = void (*)(const DeviceContext&, + const DenseTensor&, + const std::vector&, + const std::vector&, + int64_t, + DenseTensor*); + PD_VISIT_KERNEL("strided_copy", + kernel_key, + kernel_signature_strided_copy, + dev_ctx, + out_grad, + common::vectorize(tmp.dims()), + common::vectorize(tmp.strides()), + tmp.offset(), + &tmp); + // PD_VISIT_ALL_TYPES(input.dtype(), "SliceGradStridedKernel", ([&] { + // phi::StridedCopyKernel( + // dev_ctx, + // out_grad, + // common::vectorize(tmp.dims()), + // common::vectorize(tmp.strides()), + // tmp.offset(), + // &tmp); + // })); } } // namespace phi -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - slice_grad, STRIDED, phi::SliceGradStridedKernel) {} +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(slice_grad, + STRIDED, + phi::SliceGradStridedKernel) {} diff --git a/paddle/phi/kernels/stride/slice_kernel.cc b/paddle/phi/kernels/stride/slice_kernel.cc index 132fb30c314aa7..8961ee039b9828 100644 --- a/paddle/phi/kernels/stride/slice_kernel.cc +++ b/paddle/phi/kernels/stride/slice_kernel.cc @@ -95,5 +95,6 @@ void SliceStridedKernel(const Context& ctx, } } // namespace phi -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - slice, STRIDED, phi::SliceStridedKernel) {} +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(slice, + STRIDED, + phi::SliceStridedKernel) {} From 74db2fdff8c026c83e24c40a33e4da9191d36317 Mon Sep 17 00:00:00 2001 From: tongkai <1092019531@qq.com> Date: Thu, 7 Mar 2024 20:16:14 +0800 Subject: [PATCH 02/13] refine & remove redundant code --- paddle/phi/core/visit_type.h | 2 +- paddle/phi/kernels/stride/slice_grad_kernel.cc | 14 -------------- 2 files changed, 1 insertion(+), 15 deletions(-) diff --git a/paddle/phi/core/visit_type.h b/paddle/phi/core/visit_type.h index 53afe73af64438..5f643681dafe24 100644 --- a/paddle/phi/core/visit_type.h +++ b/paddle/phi/core/visit_type.h @@ -480,7 +480,7 @@ namespace phi { kernel_key); \ const auto& kernel = kernel_result.kernel; \ if (kernel_result.has_fallback_cpu) { \ - VLOG(6) << "missing kernel:" << kernel_name; \ + VLOG(6) << "missing kernel: " << kernel_name; \ } \ VLOG(6) << kernel_name << "kernel: " << kernel; \ auto* kernel_fn = kernel.GetVariadicKernelFn(); \ diff --git a/paddle/phi/kernels/stride/slice_grad_kernel.cc b/paddle/phi/kernels/stride/slice_grad_kernel.cc index 00a66f03809c60..21ec151230bfad 100644 --- a/paddle/phi/kernels/stride/slice_grad_kernel.cc +++ b/paddle/phi/kernels/stride/slice_grad_kernel.cc @@ -45,10 +45,6 @@ void SliceGradStridedKernel(const Context& dev_ctx, *input_grad, 0, input_grad); - // PD_VISIT_ALL_TYPES(input.dtype(), "SliceGradStridedKernel", ([&] { - // phi::FillKernel( - // dev_ctx, *input_grad, 0, input_grad); - // })); DenseTensor tmp; tmp.set_meta(out_grad.meta()); SliceStridedKernel(dev_ctx, @@ -74,17 +70,7 @@ void SliceGradStridedKernel(const Context& dev_ctx, common::vectorize(tmp.strides()), tmp.offset(), &tmp); - // PD_VISIT_ALL_TYPES(input.dtype(), "SliceGradStridedKernel", ([&] { - // phi::StridedCopyKernel( - // dev_ctx, - // out_grad, - // common::vectorize(tmp.dims()), - // common::vectorize(tmp.strides()), - // tmp.offset(), - // &tmp); - // })); } - } // namespace phi PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(slice_grad, STRIDED, From ec846beb154f4e645370145cc9be534ff575f342 Mon Sep 17 00:00:00 2001 From: tongkai <1092019531@qq.com> Date: Thu, 7 Mar 2024 20:20:19 +0800 Subject: [PATCH 03/13] refine --- paddle/phi/core/visit_type.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/core/visit_type.h b/paddle/phi/core/visit_type.h index 5f643681dafe24..d8bf8850c3594d 100644 --- a/paddle/phi/core/visit_type.h +++ b/paddle/phi/core/visit_type.h @@ -482,7 +482,7 @@ namespace phi { if (kernel_result.has_fallback_cpu) { \ VLOG(6) << "missing kernel: " << kernel_name; \ } \ - VLOG(6) << kernel_name << "kernel: " << kernel; \ + VLOG(6) << kernel_name << " kernel: " << kernel; \ auto* kernel_fn = kernel.GetVariadicKernelFn(); \ (*kernel_fn)(__VA_ARGS__); \ }() From d23aa3f2f169a70dc09873b73f2e50ab521deb4f Mon Sep 17 00:00:00 2001 From: tongkai <1092019531@qq.com> Date: Thu, 7 Mar 2024 21:49:41 +0800 Subject: [PATCH 04/13] remove VLOG --- paddle/phi/core/visit_type.h | 5 ----- 1 file changed, 5 deletions(-) diff --git a/paddle/phi/core/visit_type.h b/paddle/phi/core/visit_type.h index d8bf8850c3594d..3e10d2d651aa04 100644 --- a/paddle/phi/core/visit_type.h +++ b/paddle/phi/core/visit_type.h @@ -14,7 +14,6 @@ limitations under the License. */ #pragma once -#include #include "paddle/common/exception.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/kernel_factory.h" @@ -479,10 +478,6 @@ namespace phi { phi::KernelFactory::Instance().SelectKernelOrThrowError(kernel_name, \ kernel_key); \ const auto& kernel = kernel_result.kernel; \ - if (kernel_result.has_fallback_cpu) { \ - VLOG(6) << "missing kernel: " << kernel_name; \ - } \ - VLOG(6) << kernel_name << " kernel: " << kernel; \ auto* kernel_fn = kernel.GetVariadicKernelFn(); \ (*kernel_fn)(__VA_ARGS__); \ }() From 28375d8a98dc668fe5475f80180cd34c13aea370 Mon Sep 17 00:00:00 2001 From: tongkai <1092019531@qq.com> Date: Thu, 7 Mar 2024 22:56:05 +0800 Subject: [PATCH 05/13] replace input.place() with input_grad->place()& include kernel_factory.h --- paddle/phi/kernels/stride/slice_grad_kernel.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/paddle/phi/kernels/stride/slice_grad_kernel.cc b/paddle/phi/kernels/stride/slice_grad_kernel.cc index 21ec151230bfad..4bfb56551cc308 100644 --- a/paddle/phi/kernels/stride/slice_grad_kernel.cc +++ b/paddle/phi/kernels/stride/slice_grad_kernel.cc @@ -33,9 +33,10 @@ void SliceGradStridedKernel(const Context& dev_ctx, DenseTensor* input_grad) { dev_ctx.Alloc(input_grad, input_grad->dtype()); input_grad->set_strides(DenseTensorMeta::calc_strides(input_grad->dims())); - const phi::KernelKey& kernel_key = {phi::TransToPhiBackend(input.place()), - phi::DataLayout::ALL_LAYOUT, - input.dtype()}; + const phi::KernelKey& kernel_key = { + phi::TransToPhiBackend(input_grad->place()), + phi::DataLayout::ALL_LAYOUT, + input.dtype()}; using kernel_signature_fill = void (*)( const DeviceContext&, const DenseTensor&, const Scalar&, DenseTensor*); PD_VISIT_KERNEL("fill", From 6b3a137170b3291b59913169b74abe196146b989 Mon Sep 17 00:00:00 2001 From: tongkai <1092019531@qq.com> Date: Thu, 7 Mar 2024 23:02:53 +0800 Subject: [PATCH 06/13] include kernel_factory.h --- paddle/phi/api/lib/api_gen_utils.cc | 1 + paddle/phi/core/visit_type.h | 1 - paddle/phi/kernels/stride/slice_grad_kernel.cc | 1 + 3 files changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/phi/api/lib/api_gen_utils.cc b/paddle/phi/api/lib/api_gen_utils.cc index 331ceec3c53c70..e3520223669470 100644 --- a/paddle/phi/api/lib/api_gen_utils.cc +++ b/paddle/phi/api/lib/api_gen_utils.cc @@ -24,6 +24,7 @@ PHI_DECLARE_bool(use_stride_kernel); #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" #include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" +#include "paddle/phi/core/kernel_factory.h" namespace paddle { namespace experimental { diff --git a/paddle/phi/core/visit_type.h b/paddle/phi/core/visit_type.h index 3e10d2d651aa04..298cd859676cfd 100644 --- a/paddle/phi/core/visit_type.h +++ b/paddle/phi/core/visit_type.h @@ -16,7 +16,6 @@ limitations under the License. */ #include "paddle/common/exception.h" #include "paddle/phi/common/data_type.h" -#include "paddle/phi/core/kernel_factory.h" namespace phi { diff --git a/paddle/phi/kernels/stride/slice_grad_kernel.cc b/paddle/phi/kernels/stride/slice_grad_kernel.cc index 4bfb56551cc308..32066fa297f85f 100644 --- a/paddle/phi/kernels/stride/slice_grad_kernel.cc +++ b/paddle/phi/kernels/stride/slice_grad_kernel.cc @@ -14,6 +14,7 @@ #include "paddle/phi/kernels/slice_grad_kernel.h" #include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/fill_kernel.h" #include "paddle/phi/kernels/slice_kernel.h" From e6a45b1f70c58d367a7e516e1cb605b1e80712dc Mon Sep 17 00:00:00 2001 From: tongkai <1092019531@qq.com> Date: Fri, 8 Mar 2024 09:44:40 +0800 Subject: [PATCH 07/13] set different kernel_keys for different kernels --- .../phi/kernels/stride/slice_grad_kernel.cc | 38 +++++++++---------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/paddle/phi/kernels/stride/slice_grad_kernel.cc b/paddle/phi/kernels/stride/slice_grad_kernel.cc index 32066fa297f85f..d14cd2ac63ab08 100644 --- a/paddle/phi/kernels/stride/slice_grad_kernel.cc +++ b/paddle/phi/kernels/stride/slice_grad_kernel.cc @@ -34,19 +34,13 @@ void SliceGradStridedKernel(const Context& dev_ctx, DenseTensor* input_grad) { dev_ctx.Alloc(input_grad, input_grad->dtype()); input_grad->set_strides(DenseTensorMeta::calc_strides(input_grad->dims())); - const phi::KernelKey& kernel_key = { - phi::TransToPhiBackend(input_grad->place()), - phi::DataLayout::ALL_LAYOUT, - input.dtype()}; - using kernel_signature_fill = void (*)( + const phi::KernelKey& fill_key = {phi::TransToPhiBackend(input_grad->place()), + phi::DataLayout::ALL_LAYOUT, + input.dtype()}; + using fill_signature = void (*)( const DeviceContext&, const DenseTensor&, const Scalar&, DenseTensor*); - PD_VISIT_KERNEL("fill", - kernel_key, - kernel_signature_fill, - dev_ctx, - *input_grad, - 0, - input_grad); + PD_VISIT_KERNEL( + "fill", fill_key, fill_signature, dev_ctx, *input_grad, 0, input_grad); DenseTensor tmp; tmp.set_meta(out_grad.meta()); SliceStridedKernel(dev_ctx, @@ -57,15 +51,19 @@ void SliceGradStridedKernel(const Context& dev_ctx, infer_flags, decrease_axis, &tmp); - using kernel_signature_strided_copy = void (*)(const DeviceContext&, - const DenseTensor&, - const std::vector&, - const std::vector&, - int64_t, - DenseTensor*); + const phi::KernelKey& strided_copy_key = { + phi::TransToPhiBackend(out_grad->place()), + phi::DataLayout::ALL_LAYOUT, + input.dtype()}; + using strided_copy_signature = void (*)(const DeviceContext&, + const DenseTensor&, + const std::vector&, + const std::vector&, + int64_t, + DenseTensor*); PD_VISIT_KERNEL("strided_copy", - kernel_key, - kernel_signature_strided_copy, + strided_copy_key, + strided_copy_signature, dev_ctx, out_grad, common::vectorize(tmp.dims()), From 7f52d27bbbc3e571d3eeab3b32bfb14716e589bd Mon Sep 17 00:00:00 2001 From: tongkai <1092019531@qq.com> Date: Fri, 8 Mar 2024 11:00:35 +0800 Subject: [PATCH 08/13] fix syntax error --- paddle/phi/kernels/stride/slice_grad_kernel.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/kernels/stride/slice_grad_kernel.cc b/paddle/phi/kernels/stride/slice_grad_kernel.cc index d14cd2ac63ab08..fbd1c16733bbf4 100644 --- a/paddle/phi/kernels/stride/slice_grad_kernel.cc +++ b/paddle/phi/kernels/stride/slice_grad_kernel.cc @@ -52,7 +52,7 @@ void SliceGradStridedKernel(const Context& dev_ctx, decrease_axis, &tmp); const phi::KernelKey& strided_copy_key = { - phi::TransToPhiBackend(out_grad->place()), + phi::TransToPhiBackend(out_grad.place()), phi::DataLayout::ALL_LAYOUT, input.dtype()}; using strided_copy_signature = void (*)(const DeviceContext&, From 6532edd55424f7586cf86196ad60f53e17a69403 Mon Sep 17 00:00:00 2001 From: tongkai <1092019531@qq.com> Date: Fri, 8 Mar 2024 11:30:18 +0800 Subject: [PATCH 09/13] if missing kernel, throw error --- paddle/phi/core/visit_type.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/paddle/phi/core/visit_type.h b/paddle/phi/core/visit_type.h index 298cd859676cfd..c8a3a4ef0dbb3b 100644 --- a/paddle/phi/core/visit_type.h +++ b/paddle/phi/core/visit_type.h @@ -478,6 +478,10 @@ namespace phi { kernel_key); \ const auto& kernel = kernel_result.kernel; \ auto* kernel_fn = kernel.GetVariadicKernelFn(); \ + if (kernel_result.has_fallback_cpu) { \ + PADDLE_THROW( \ + phi::errors::Unimplemented("Missing kernel: %s", kernel_name)); \ + } \ (*kernel_fn)(__VA_ARGS__); \ }() } // namespace phi From c005dbb3e53f5375f0a44afbac7ca4ce966e4935 Mon Sep 17 00:00:00 2001 From: tongkai <1092019531@qq.com> Date: Fri, 8 Mar 2024 12:56:12 +0800 Subject: [PATCH 10/13] use dev_ctx.GetPlace() to get place --- paddle/phi/kernels/stride/slice_grad_kernel.cc | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/paddle/phi/kernels/stride/slice_grad_kernel.cc b/paddle/phi/kernels/stride/slice_grad_kernel.cc index fbd1c16733bbf4..ba3f048c1634aa 100644 --- a/paddle/phi/kernels/stride/slice_grad_kernel.cc +++ b/paddle/phi/kernels/stride/slice_grad_kernel.cc @@ -34,13 +34,14 @@ void SliceGradStridedKernel(const Context& dev_ctx, DenseTensor* input_grad) { dev_ctx.Alloc(input_grad, input_grad->dtype()); input_grad->set_strides(DenseTensorMeta::calc_strides(input_grad->dims())); - const phi::KernelKey& fill_key = {phi::TransToPhiBackend(input_grad->place()), - phi::DataLayout::ALL_LAYOUT, - input.dtype()}; + const phi::KernelKey& kernel_key = { + phi::TransToPhiBackend(dev_ctx.GetPlace()), + phi::DataLayout::ALL_LAYOUT, + input.dtype()}; using fill_signature = void (*)( const DeviceContext&, const DenseTensor&, const Scalar&, DenseTensor*); PD_VISIT_KERNEL( - "fill", fill_key, fill_signature, dev_ctx, *input_grad, 0, input_grad); + "fill", kernel_key, fill_signature, dev_ctx, *input_grad, 0, input_grad); DenseTensor tmp; tmp.set_meta(out_grad.meta()); SliceStridedKernel(dev_ctx, @@ -51,10 +52,6 @@ void SliceGradStridedKernel(const Context& dev_ctx, infer_flags, decrease_axis, &tmp); - const phi::KernelKey& strided_copy_key = { - phi::TransToPhiBackend(out_grad.place()), - phi::DataLayout::ALL_LAYOUT, - input.dtype()}; using strided_copy_signature = void (*)(const DeviceContext&, const DenseTensor&, const std::vector&, @@ -62,7 +59,7 @@ void SliceGradStridedKernel(const Context& dev_ctx, int64_t, DenseTensor*); PD_VISIT_KERNEL("strided_copy", - strided_copy_key, + kernel_key, strided_copy_signature, dev_ctx, out_grad, From eaa1d6501b61b6d3dc056be4b96ed6b9f01d0315 Mon Sep 17 00:00:00 2001 From: tongkai <1092019531@qq.com> Date: Fri, 8 Mar 2024 16:23:36 +0800 Subject: [PATCH 11/13] modify error message --- paddle/phi/api/lib/api_gen_utils.cc | 3 +++ paddle/phi/api/lib/data_transform.cc | 1 + paddle/phi/core/visit_type.h | 27 ++++++++++--------- .../phi/kernels/stride/slice_grad_kernel.cc | 11 ++++++-- 4 files changed, 28 insertions(+), 14 deletions(-) diff --git a/paddle/phi/api/lib/api_gen_utils.cc b/paddle/phi/api/lib/api_gen_utils.cc index e3520223669470..ef5cfc90727ff5 100644 --- a/paddle/phi/api/lib/api_gen_utils.cc +++ b/paddle/phi/api/lib/api_gen_utils.cc @@ -433,6 +433,7 @@ void TransStride(phi::DeviceContext* dev_ctx, PD_VISIT_KERNEL("strided_copy", kernel_key, kernel_signature, + false, *custom_ctx, *from, common::vectorize(to->dims()), @@ -508,6 +509,7 @@ void TransStrideLegacy(phi::DeviceContext* dev_ctx, PD_VISIT_KERNEL("strided_copy", kernel_key, kernel_signature, + false, *custom_ctx, *from, common::vectorize(to->dims()), @@ -587,6 +589,7 @@ void TransStride(phi::DeviceContext* dev_ctx, PD_VISIT_KERNEL("strided_copy", kernel_key, kernel_signature, + false, *custom_ctx, *from[i], common::vectorize(to[i]->dims()), diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index 98f7543cd3a7bf..d310d43f4b7e08 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -271,6 +271,7 @@ phi::DenseTensor Trans2Contiguous(const phi::DenseTensor& tensor) { PD_VISIT_KERNEL("contiguous", kernel_key, kernel_signature, + false, *dev_ctx, tensor, &dense_out); diff --git a/paddle/phi/core/visit_type.h b/paddle/phi/core/visit_type.h index c8a3a4ef0dbb3b..ad30da4ddcd6f0 100644 --- a/paddle/phi/core/visit_type.h +++ b/paddle/phi/core/visit_type.h @@ -471,17 +471,20 @@ namespace phi { } \ }() -#define PD_VISIT_KERNEL(kernel_name, kernel_key, kernel_signature, ...) \ - [&] { \ - auto kernel_result = \ - phi::KernelFactory::Instance().SelectKernelOrThrowError(kernel_name, \ - kernel_key); \ - const auto& kernel = kernel_result.kernel; \ - auto* kernel_fn = kernel.GetVariadicKernelFn(); \ - if (kernel_result.has_fallback_cpu) { \ - PADDLE_THROW( \ - phi::errors::Unimplemented("Missing kernel: %s", kernel_name)); \ - } \ - (*kernel_fn)(__VA_ARGS__); \ +#define PD_VISIT_KERNEL( \ + kernel_name, kernel_key, kernel_signature, use_strided_kernel, ...) \ + [&] { \ + auto kernel_result = \ + phi::KernelFactory::Instance().SelectKernelOrThrowError( \ + kernel_name, kernel_key, use_strided_kernel); \ + const auto& kernel = kernel_result.kernel; \ + auto* kernel_fn = kernel.GetVariadicKernelFn(); \ + if (kernel_result.has_fallback_cpu) { \ + PADDLE_THROW(phi::errors::NotFound( \ + "The kernel with key %s of kernel `%s` is not registered.", \ + kernel_key, \ + kernel_name)); \ + } \ + (*kernel_fn)(__VA_ARGS__); \ }() } // namespace phi diff --git a/paddle/phi/kernels/stride/slice_grad_kernel.cc b/paddle/phi/kernels/stride/slice_grad_kernel.cc index ba3f048c1634aa..a61826ffb69b71 100644 --- a/paddle/phi/kernels/stride/slice_grad_kernel.cc +++ b/paddle/phi/kernels/stride/slice_grad_kernel.cc @@ -40,8 +40,14 @@ void SliceGradStridedKernel(const Context& dev_ctx, input.dtype()}; using fill_signature = void (*)( const DeviceContext&, const DenseTensor&, const Scalar&, DenseTensor*); - PD_VISIT_KERNEL( - "fill", kernel_key, fill_signature, dev_ctx, *input_grad, 0, input_grad); + PD_VISIT_KERNEL("fill", + kernel_key, + fill_signature, + false, + dev_ctx, + *input_grad, + 0, + input_grad); DenseTensor tmp; tmp.set_meta(out_grad.meta()); SliceStridedKernel(dev_ctx, @@ -61,6 +67,7 @@ void SliceGradStridedKernel(const Context& dev_ctx, PD_VISIT_KERNEL("strided_copy", kernel_key, strided_copy_signature, + false, dev_ctx, out_grad, common::vectorize(tmp.dims()), From fec9ace3407449ce39a95809ff5bd182ccde2bde Mon Sep 17 00:00:00 2001 From: tongkai <1092019531@qq.com> Date: Sat, 9 Mar 2024 16:07:10 +0800 Subject: [PATCH 12/13] move stride_funcs.h to kernels dir --- paddle/phi/kernels/stride/slice_grad_kernel.cc | 2 +- paddle/phi/kernels/{stride/stride_func.h => stride_funcs.h} | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) rename paddle/phi/kernels/{stride/stride_func.h => stride_funcs.h} (97%) diff --git a/paddle/phi/kernels/stride/slice_grad_kernel.cc b/paddle/phi/kernels/stride/slice_grad_kernel.cc index 76ce30611671b1..4504c9a1fda6fb 100644 --- a/paddle/phi/kernels/stride/slice_grad_kernel.cc +++ b/paddle/phi/kernels/stride/slice_grad_kernel.cc @@ -18,7 +18,7 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/fill_kernel.h" #include "paddle/phi/kernels/slice_kernel.h" -#include "paddle/phi/kernels/stride/stride_funcs.h" +#include "paddle/phi/kernels/stride_funcs.h" namespace phi { diff --git a/paddle/phi/kernels/stride/stride_func.h b/paddle/phi/kernels/stride_funcs.h similarity index 97% rename from paddle/phi/kernels/stride/stride_func.h rename to paddle/phi/kernels/stride_funcs.h index 87fc5cfb880d2c..a8654428adb7e2 100644 --- a/paddle/phi/kernels/stride/stride_func.h +++ b/paddle/phi/kernels/stride_funcs.h @@ -1,4 +1,4 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/fill_kernel.h" #include "paddle/phi/kernels/strided_copy_kernel.h" From ef214c183320ead4b3c952b9bdaa083e5766c18b Mon Sep 17 00:00:00 2001 From: tongkai <1092019531@qq.com> Date: Sun, 10 Mar 2024 23:52:05 +0800 Subject: [PATCH 13/13] reset use_stride_kernel for customdevice --- paddle/phi/core/kernel_factory.cc | 6 ------ 1 file changed, 6 deletions(-) diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index 8457ff2030ddf7..32644cfe8bf631 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -28,15 +28,9 @@ #include "paddle/phi/core/compat/op_utils.h" #include "paddle/utils/string/string_helper.h" -#if defined(PADDLE_WITH_CUSTOM_DEVICE) -PHI_DEFINE_EXPORTED_bool(use_stride_kernel, - false, - "Whether to use stride kernel if op support stride."); -#else PHI_DEFINE_EXPORTED_bool(use_stride_kernel, true, "Whether to use stride kernel if op support stride."); -#endif COMMON_DECLARE_int32(low_precision_op_list); COMMON_DECLARE_bool(enable_api_kernel_fallback);