diff --git a/paddle/phi/api/lib/api_gen_utils.cc b/paddle/phi/api/lib/api_gen_utils.cc index 87e6f9af430755..ef5cfc90727ff5 100644 --- a/paddle/phi/api/lib/api_gen_utils.cc +++ b/paddle/phi/api/lib/api_gen_utils.cc @@ -24,6 +24,7 @@ PHI_DECLARE_bool(use_stride_kernel); #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" #include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" +#include "paddle/phi/core/kernel_factory.h" namespace paddle { namespace experimental { @@ -416,6 +417,32 @@ void TransStride(phi::DeviceContext* dev_ctx, delete from; return; } +#endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE + auto* custom_ctx = dynamic_cast<phi::CustomContext*>(dev_ctx); + if (custom_ctx) { + const phi::KernelKey& kernel_key = {phi::TransToPhiBackend(to->place()), + phi::DataLayout::ALL_LAYOUT, + to->dtype()}; + using kernel_signature = void (*)(const phi::DeviceContext&, + const phi::DenseTensor&, + const std::vector<int64_t>&, + const std::vector<int64_t>&, + int64_t, + phi::DenseTensor*); + PD_VISIT_KERNEL("strided_copy", + kernel_key, + kernel_signature, + false, + *custom_ctx, + *from, + common::vectorize(to->dims()), + common::vectorize(to->strides()), + to->offset(), + to); + delete from; + return; + } #endif } } @@ -466,6 +493,31 @@ void TransStrideLegacy(phi::DeviceContext* dev_ctx, })); return; } +#endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE + auto* custom_ctx = dynamic_cast<phi::CustomContext*>(dev_ctx); + if (custom_ctx) { + const phi::KernelKey& kernel_key = {phi::TransToPhiBackend(to->place()), + phi::DataLayout::ALL_LAYOUT, + to->dtype()}; + using kernel_signature = void (*)(const phi::DeviceContext&, + const phi::DenseTensor&, + const std::vector<int64_t>&, + const std::vector<int64_t>&, + int64_t, + phi::DenseTensor*); + PD_VISIT_KERNEL("strided_copy", + kernel_key, + kernel_signature, + false, + *custom_ctx, + *from, + common::vectorize(to->dims()), + common::vectorize(to->strides()), + to->offset(), + to); + return; + } #endif } } @@ -520,6 +572,33 
@@ void TransStride(phi::DeviceContext* dev_ctx, delete from[i]; continue; } +#endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE + auto* custom_ctx = dynamic_cast<phi::CustomContext*>(dev_ctx); + if (custom_ctx) { + const phi::KernelKey& kernel_key = { + phi::TransToPhiBackend(to[i]->place()), + phi::DataLayout::ALL_LAYOUT, + to[i]->dtype()}; + using kernel_signature = void (*)(const phi::DeviceContext&, + const phi::DenseTensor&, + const std::vector<int64_t>&, + const std::vector<int64_t>&, + int64_t, + phi::DenseTensor*); + PD_VISIT_KERNEL("strided_copy", + kernel_key, + kernel_signature, + false, + *custom_ctx, + *from[i], + common::vectorize(to[i]->dims()), + common::vectorize(to[i]->strides()), + to[i]->offset(), + to[i]); + delete from[i]; + continue; + } #endif } } diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index 80bb9f44475734..d310d43f4b7e08 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -255,6 +255,27 @@ phi::DenseTensor Trans2Contiguous(const phi::DenseTensor& tensor) { } else if (tensor.place().GetType() == phi::AllocationType::XPU) { auto* dev_ctx = static_cast<phi::XPUContext*>(pool.Get(tensor.place())); return TensorContiguous(*dev_ctx, tensor); +#endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE + } else if (tensor.place().GetType() == phi::AllocationType::CUSTOM) { + auto* dev_ctx = static_cast<phi::CustomContext*>(pool.Get(tensor.place())); + phi::DenseTensor dense_out; + phi::MetaTensor meta_input(tensor); + phi::MetaTensor meta_out(&dense_out); + UnchangedInferMeta(meta_input, &meta_out); + const phi::KernelKey& kernel_key = {phi::TransToPhiBackend(tensor.place()), + phi::DataLayout::ALL_LAYOUT, + tensor.dtype()}; + using kernel_signature = void (*)( + const phi::DeviceContext&, const phi::DenseTensor&, phi::DenseTensor*); + PD_VISIT_KERNEL("contiguous", + kernel_key, + kernel_signature, + false, + *dev_ctx, + tensor, + &dense_out); + return dense_out; #endif } else { PADDLE_THROW(phi::errors::Unimplemented( diff --git 
a/paddle/phi/backends/device_ext.h b/paddle/phi/backends/device_ext.h index bd3f5f687f29b1..3e7497b2209dcd 100644 --- a/paddle/phi/backends/device_ext.h +++ b/paddle/phi/backends/device_ext.h @@ -50,6 +50,7 @@ typedef enum { NCHW, NCDHW, NDHWC, + STRIDED, NUM_DATA_LAYOUTS, ALL_LAYOUT = ANY, } C_DataLayout; diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index f04c1b2c880bd0..32644cfe8bf631 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -249,6 +249,17 @@ KernelResult KernelFactory::SelectKernelOrThrowError( if (stride_kernel_iter != iter->second.end()) { return {stride_kernel_iter->second, false, true}; } +#ifdef PADDLE_WITH_CUSTOM_DEVICE + if (stride_kernel_iter == iter->second.end() && + const_kernel_key.backend() > phi::Backend::NUM_BACKENDS) { + stride_kernel_iter = iter->second.find({phi::Backend::CUSTOM, + phi::DataLayout::STRIDED, + const_kernel_key.dtype()}); + if (stride_kernel_iter != iter->second.end()) { + return {stride_kernel_iter->second, false, true}; + } + } +#endif } KernelKey kernel_key = KernelKey(const_kernel_key.backend(), diff --git a/paddle/phi/core/visit_type.h b/paddle/phi/core/visit_type.h index 7ee12e26d7d0ef..ad30da4ddcd6f0 100644 --- a/paddle/phi/core/visit_type.h +++ b/paddle/phi/core/visit_type.h @@ -471,4 +471,20 @@ namespace phi { } \ }() +#define PD_VISIT_KERNEL( \ + kernel_name, kernel_key, kernel_signature, use_strided_kernel, ...) 
\ + [&] { \ + auto kernel_result = \ + phi::KernelFactory::Instance().SelectKernelOrThrowError( \ + kernel_name, kernel_key, use_strided_kernel); \ + const auto& kernel = kernel_result.kernel; \ + auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>(); \ + if (kernel_result.has_fallback_cpu) { \ + PADDLE_THROW(phi::errors::NotFound( \ + "The kernel with key %s of kernel `%s` is not registered.", \ + kernel_key, \ + kernel_name)); \ + } \ + (*kernel_fn)(__VA_ARGS__); \ + }() } // namespace phi diff --git a/paddle/phi/kernels/stride/slice_grad_kernel.cc b/paddle/phi/kernels/stride/slice_grad_kernel.cc index 171c20b3b83acd..4504c9a1fda6fb 100644 --- a/paddle/phi/kernels/stride/slice_grad_kernel.cc +++ b/paddle/phi/kernels/stride/slice_grad_kernel.cc @@ -14,10 +14,11 @@ #include "paddle/phi/kernels/slice_grad_kernel.h" #include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/fill_kernel.h" #include "paddle/phi/kernels/slice_kernel.h" -#include "paddle/phi/kernels/strided_copy_kernel.h" +#include "paddle/phi/kernels/stride_funcs.h" namespace phi { @@ -33,10 +34,12 @@ void SliceGradStridedKernel(const Context& dev_ctx, DenseTensor* input_grad) { dev_ctx.Alloc(input_grad, input_grad->dtype()); input_grad->set_strides(DenseTensorMeta::calc_strides(input_grad->dims())); - PD_VISIT_ALL_TYPES(input.dtype(), "SliceGradStridedKernel", ([&] { - phi::FillKernel<data_t, Context>( - dev_ctx, *input_grad, 0, input_grad); - })); + phi::StridedTensorFill(input.dtype(), + "SliceGradStridedKernel", + dev_ctx, + *input_grad, + 0, + input_grad); DenseTensor tmp; tmp.set_meta(out_grad.meta()); SliceStridedKernel(dev_ctx, @@ -47,17 +50,22 @@ void SliceGradStridedKernel(const Context& dev_ctx, infer_flags, decrease_axis, &tmp); - PD_VISIT_ALL_TYPES(input.dtype(), "SliceGradStridedKernel", ([&] { - phi::StridedCopyKernel<data_t, Context>( - dev_ctx, - out_grad, - common::vectorize(tmp.dims()), - common::vectorize(tmp.strides()), 
- tmp.offset(), - &tmp); - })); + phi::StridedTensorCopy(input.dtype(), + "SliceGradStridedKernel", + dev_ctx, + out_grad, + common::vectorize(tmp.dims()), + common::vectorize(tmp.strides()), + tmp.offset(), + &tmp); } - } // namespace phi + +#ifndef PADDLE_WITH_CUSTOM_DEVICE PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( slice_grad, STRIDED, phi::SliceGradStridedKernel) {} +#else +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(slice_grad, + STRIDED, + phi::SliceGradStridedKernel) {} +#endif diff --git a/paddle/phi/kernels/stride/slice_kernel.cc b/paddle/phi/kernels/stride/slice_kernel.cc index 132fb30c314aa7..8961ee039b9828 100644 --- a/paddle/phi/kernels/stride/slice_kernel.cc +++ b/paddle/phi/kernels/stride/slice_kernel.cc @@ -95,5 +95,6 @@ void SliceStridedKernel(const Context& ctx, } } // namespace phi -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - slice, STRIDED, phi::SliceStridedKernel) {} +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(slice, + STRIDED, + phi::SliceStridedKernel) {} diff --git a/paddle/phi/kernels/stride_funcs.h b/paddle/phi/kernels/stride_funcs.h new file mode 100644 index 00000000000000..a8654428adb7e2 --- /dev/null +++ b/paddle/phi/kernels/stride_funcs.h @@ -0,0 +1,88 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_factory.h" +#include "paddle/phi/core/visit_type.h" +#include "paddle/phi/kernels/fill_kernel.h" +#include "paddle/phi/kernels/strided_copy_kernel.h" + +namespace phi { + +template <typename Context> +inline void StridedTensorCopy(const phi::DataType input_dtype, + std::string kernel_name, + const Context& dev_ctx, + const phi::DenseTensor& input, + const std::vector<int64_t>& dims, + const std::vector<int64_t>& out_stride, + int64_t offset, + phi::DenseTensor* out) { +#ifndef PADDLE_WITH_CUSTOM_DEVICE + PD_VISIT_ALL_TYPES(input_dtype, kernel_name, ([&] { + phi::StridedCopyKernel<data_t, Context>( + dev_ctx, input, dims, out_stride, offset, out); + })); +#else + (void)kernel_name; + const phi::KernelKey& strided_copy_key = { + phi::TransToPhiBackend(dev_ctx.GetPlace()), + phi::DataLayout::ALL_LAYOUT, + input_dtype}; + using strided_copy_signature = void (*)(const phi::DeviceContext&, + const phi::DenseTensor&, + const std::vector<int64_t>&, + const std::vector<int64_t>&, + int64_t, + phi::DenseTensor*); + PD_VISIT_KERNEL("strided_copy", + strided_copy_key, + strided_copy_signature, + false, + dev_ctx, + input, + dims, + out_stride, + offset, + out); +#endif +} + +template <typename Context> +inline void StridedTensorFill(const phi::DataType input_dtype, + std::string kernel_name, + const Context& dev_ctx, + const phi::DenseTensor& x, + const phi::Scalar& value, + phi::DenseTensor* out) { +#ifndef PADDLE_WITH_CUSTOM_DEVICE + PD_VISIT_ALL_TYPES(input_dtype, kernel_name, ([&] { + phi::FillKernel<data_t, Context>(dev_ctx, x, value, out); + })); +#else + (void)kernel_name; + const phi::KernelKey& fill_key = {phi::TransToPhiBackend(dev_ctx.GetPlace()), + phi::DataLayout::ALL_LAYOUT, + input_dtype}; + using fill_signature = void (*)(const phi::DeviceContext&, + const phi::DenseTensor&, + const phi::Scalar&, + phi::DenseTensor*); + + PD_VISIT_KERNEL( + "fill", fill_key, fill_signature, false, dev_ctx, x, value, out); +#endif +} +} // namespace phi