Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
d3e3e31
customdevice support stride slice & slice_grad
Tongkaio Mar 7, 2024
cbbed37
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Tongkaio Mar 7, 2024
74db2fd
refine & remove redundant code
Tongkaio Mar 7, 2024
ec846be
refine
Tongkaio Mar 7, 2024
d23aa3f
remove VLOG
Tongkaio Mar 7, 2024
28375d8
replace input.place() with input_grad->place() & include kernel_factory.h
Tongkaio Mar 7, 2024
6b3a137
include kernel_factory.h
Tongkaio Mar 7, 2024
e6a45b1
set different kernel_keys for different kernels
Tongkaio Mar 8, 2024
4e49c71
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Tongkaio Mar 8, 2024
7f52d27
fix syntax error
Tongkaio Mar 8, 2024
6532edd
if missing kernel, throw error
Tongkaio Mar 8, 2024
c005dbb
use dev_ctx.GetPlace() to get place
Tongkaio Mar 8, 2024
91c9065
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Tongkaio Mar 8, 2024
eaa1d65
modify error message
Tongkaio Mar 8, 2024
5cca2a8
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Tongkaio Mar 8, 2024
c24b010
add stride_func.h
Tongkaio Mar 9, 2024
fec9ace
move stride_funcs.h to kernels dir
Tongkaio Mar 9, 2024
438337c
set FLAGS_use_stride_kernel false for customdevice
Tongkaio Mar 9, 2024
ef214c1
reset use_stride_kernel for customdevice
Tongkaio Mar 10, 2024
c35113f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Tongkaio Mar 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions paddle/phi/api/lib/api_gen_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ PHI_DECLARE_bool(use_stride_kernel);
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/kernel_factory.h"

namespace paddle {
namespace experimental {
Expand Down Expand Up @@ -416,6 +417,32 @@ void TransStride(phi::DeviceContext* dev_ctx,
delete from;
return;
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto* custom_ctx = dynamic_cast<phi::CustomContext*>(dev_ctx);
if (custom_ctx) {
const phi::KernelKey& kernel_key = {phi::TransToPhiBackend(to->place()),
phi::DataLayout::ALL_LAYOUT,
to->dtype()};
using kernel_signature = void (*)(const phi::DeviceContext&,
const phi::DenseTensor&,
const std::vector<int64_t>&,
const std::vector<int64_t>&,
int64_t,
phi::DenseTensor*);
PD_VISIT_KERNEL("strided_copy",
kernel_key,
kernel_signature,
false,
*custom_ctx,
*from,
common::vectorize<int64_t>(to->dims()),
common::vectorize<int64_t>(to->strides()),
to->offset(),
to);
delete from;
return;
}
#endif
}
}
Expand Down Expand Up @@ -466,6 +493,31 @@ void TransStrideLegacy(phi::DeviceContext* dev_ctx,
}));
return;
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto* custom_ctx = dynamic_cast<phi::CustomContext*>(dev_ctx);
if (custom_ctx) {
const phi::KernelKey& kernel_key = {phi::TransToPhiBackend(to->place()),
phi::DataLayout::ALL_LAYOUT,
to->dtype()};
using kernel_signature = void (*)(const phi::DeviceContext&,
const phi::DenseTensor&,
const std::vector<int64_t>&,
const std::vector<int64_t>&,
int64_t,
phi::DenseTensor*);
PD_VISIT_KERNEL("strided_copy",
kernel_key,
kernel_signature,
false,
*custom_ctx,
*from,
common::vectorize<int64_t>(to->dims()),
common::vectorize<int64_t>(to->strides()),
to->offset(),
to);
return;
}
#endif
}
}
Expand Down Expand Up @@ -520,6 +572,33 @@ void TransStride(phi::DeviceContext* dev_ctx,
delete from[i];
continue;
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto* custom_ctx = dynamic_cast<phi::CustomContext*>(dev_ctx);
if (custom_ctx) {
const phi::KernelKey& kernel_key = {
phi::TransToPhiBackend(to[i]->place()),
phi::DataLayout::ALL_LAYOUT,
to[i]->dtype()};
using kernel_signature = void (*)(const phi::DeviceContext&,
const phi::DenseTensor&,
const std::vector<int64_t>&,
const std::vector<int64_t>&,
int64_t,
phi::DenseTensor*);
PD_VISIT_KERNEL("strided_copy",
kernel_key,
kernel_signature,
false,
*custom_ctx,
*from[i],
common::vectorize<int64_t>(to[i]->dims()),
common::vectorize<int64_t>(to[i]->strides()),
to[i]->offset(),
to[i]);
delete from[i];
return;
}
#endif
}
}
Expand Down
21 changes: 21 additions & 0 deletions paddle/phi/api/lib/data_transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,27 @@ phi::DenseTensor Trans2Contiguous(const phi::DenseTensor& tensor) {
} else if (tensor.place().GetType() == phi::AllocationType::XPU) {
auto* dev_ctx = static_cast<phi::XPUContext*>(pool.Get(tensor.place()));
return TensorContiguous<phi::XPUContext>(*dev_ctx, tensor);
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
} else if (tensor.place().GetType() == phi::AllocationType::CUSTOM) {
auto* dev_ctx = static_cast<phi::CustomContext*>(pool.Get(tensor.place()));
phi::DenseTensor dense_out;
phi::MetaTensor meta_input(tensor);
phi::MetaTensor meta_out(&dense_out);
UnchangedInferMeta(meta_input, &meta_out);
const phi::KernelKey& kernel_key = {phi::TransToPhiBackend(tensor.place()),
phi::DataLayout::ALL_LAYOUT,
tensor.dtype()};
using kernel_signature = void (*)(
const phi::DeviceContext&, const phi::DenseTensor&, phi::DenseTensor*);
PD_VISIT_KERNEL("contiguous",
kernel_key,
kernel_signature,
false,
*dev_ctx,
tensor,
&dense_out);
return dense_out;
#endif
} else {
PADDLE_THROW(phi::errors::Unimplemented(
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/backends/device_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ typedef enum {
NCHW,
NCDHW,
NDHWC,
STRIDED,
NUM_DATA_LAYOUTS,
ALL_LAYOUT = ANY,
} C_DataLayout;
Expand Down
11 changes: 11 additions & 0 deletions paddle/phi/core/kernel_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,17 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
if (stride_kernel_iter != iter->second.end()) {
return {stride_kernel_iter->second, false, true};
}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (stride_kernel_iter == iter->second.end() &&
const_kernel_key.backend() > phi::Backend::NUM_BACKENDS) {
stride_kernel_iter = iter->second.find({phi::Backend::CUSTOM,
phi::DataLayout::STRIDED,
const_kernel_key.dtype()});
if (stride_kernel_iter != iter->second.end()) {
return {stride_kernel_iter->second, false, true};
}
}
#endif
}

KernelKey kernel_key = KernelKey(const_kernel_key.backend(),
Expand Down
16 changes: 16 additions & 0 deletions paddle/phi/core/visit_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -471,4 +471,20 @@ namespace phi {
} \
}()

// Looks up the kernel `kernel_name` under `kernel_key` in the global
// KernelFactory and invokes it with the remaining variadic arguments through
// a function pointer of type `kernel_signature`.
// A CPU fallback is treated as a hard error: callers on this path (custom
// devices) require the kernel to be registered for the requested backend, so
// NotFound is thrown instead of silently running the CPU implementation.
// NOTE(review): kernel_fn is fetched before the fallback check — presumably
// GetVariadicKernelFn on the fallback kernel is harmless; confirm.
#define PD_VISIT_KERNEL( \
kernel_name, kernel_key, kernel_signature, use_strided_kernel, ...) \
[&] { \
auto kernel_result = \
phi::KernelFactory::Instance().SelectKernelOrThrowError( \
kernel_name, kernel_key, use_strided_kernel); \
const auto& kernel = kernel_result.kernel; \
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>(); \
if (kernel_result.has_fallback_cpu) { \
PADDLE_THROW(phi::errors::NotFound( \
"The kernel with key %s of kernel `%s` is not registered.", \
kernel_key, \
kernel_name)); \
} \
(*kernel_fn)(__VA_ARGS__); \
}()
} // namespace phi
38 changes: 23 additions & 15 deletions paddle/phi/kernels/stride/slice_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@

#include "paddle/phi/kernels/slice_grad_kernel.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/fill_kernel.h"
#include "paddle/phi/kernels/slice_kernel.h"
#include "paddle/phi/kernels/strided_copy_kernel.h"
#include "paddle/phi/kernels/stride_funcs.h"

namespace phi {

Expand All @@ -33,10 +34,12 @@ void SliceGradStridedKernel(const Context& dev_ctx,
DenseTensor* input_grad) {
dev_ctx.Alloc(input_grad, input_grad->dtype());
input_grad->set_strides(DenseTensorMeta::calc_strides(input_grad->dims()));
PD_VISIT_ALL_TYPES(input.dtype(), "SliceGradStridedKernel", ([&] {
phi::FillKernel<data_t, Context>(
dev_ctx, *input_grad, 0, input_grad);
}));
phi::StridedTensorFill<Context>(input.dtype(),
"SliceGradStridedKernel",
dev_ctx,
*input_grad,
0,
input_grad);
DenseTensor tmp;
tmp.set_meta(out_grad.meta());
SliceStridedKernel<Context>(dev_ctx,
Expand All @@ -47,17 +50,22 @@ void SliceGradStridedKernel(const Context& dev_ctx,
infer_flags,
decrease_axis,
&tmp);
PD_VISIT_ALL_TYPES(input.dtype(), "SliceGradStridedKernel", ([&] {
phi::StridedCopyKernel<data_t, Context>(
dev_ctx,
out_grad,
common::vectorize<int64_t>(tmp.dims()),
common::vectorize<int64_t>(tmp.strides()),
tmp.offset(),
&tmp);
}));
phi::StridedTensorCopy<Context>(input.dtype(),
"SliceGradStridedKernel",
dev_ctx,
out_grad,
common::vectorize<int64_t>(tmp.dims()),
common::vectorize<int64_t>(tmp.strides()),
tmp.offset(),
&tmp);
}

} // namespace phi

#ifndef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM(
slice_grad, STRIDED, phi::SliceGradStridedKernel) {}
#else
PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(slice_grad,
STRIDED,
phi::SliceGradStridedKernel) {}
#endif
5 changes: 3 additions & 2 deletions paddle/phi/kernels/stride/slice_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,5 +95,6 @@ void SliceStridedKernel(const Context& ctx,
}

} // namespace phi
PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM(
slice, STRIDED, phi::SliceStridedKernel) {}
PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(slice,
STRIDED,
phi::SliceStridedKernel) {}
88 changes: 88 additions & 0 deletions paddle/phi/kernels/stride_funcs.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/fill_kernel.h"
#include "paddle/phi/kernels/strided_copy_kernel.h"

namespace phi {

// Copies `input` into `out`, reinterpreting the destination with the given
// `dims`, `out_stride` and element `offset`, by dispatching to the
// "strided_copy" kernel.
//
// On regular backends this visits all dtypes at compile time and calls the
// templated StridedCopyKernel directly. For custom devices
// (PADDLE_WITH_CUSTOM_DEVICE) the kernel must come from the plugin registry,
// so it is looked up at runtime through the kernel factory instead.
//
// `kernel_name` is only used for error reporting by PD_VISIT_ALL_TYPES in the
// non-custom-device path; it is passed by const reference to avoid copying a
// std::string on every call.
template <typename Context>
inline void StridedTensorCopy(const phi::DataType input_dtype,
                              const std::string& kernel_name,
                              const Context& dev_ctx,
                              const phi::DenseTensor& input,
                              const std::vector<int64_t>& dims,
                              const std::vector<int64_t>& out_stride,
                              int64_t offset,
                              phi::DenseTensor* out) {
#ifndef PADDLE_WITH_CUSTOM_DEVICE
  PD_VISIT_ALL_TYPES(input_dtype, kernel_name, ([&] {
                       phi::StridedCopyKernel<data_t, Context>(
                           dev_ctx, input, dims, out_stride, offset, out);
                     }));
#else
  (void)kernel_name;  // only used for diagnostics in the other branch
  // Resolve the plugin-registered "strided_copy" kernel for this device/dtype.
  const phi::KernelKey& strided_copy_key = {
      phi::TransToPhiBackend(dev_ctx.GetPlace()),
      phi::DataLayout::ALL_LAYOUT,
      input_dtype};
  using strided_copy_signature = void (*)(const phi::DeviceContext&,
                                          const phi::DenseTensor&,
                                          const std::vector<int64_t>&,
                                          const std::vector<int64_t>&,
                                          int64_t,
                                          phi::DenseTensor*);
  PD_VISIT_KERNEL("strided_copy",
                  strided_copy_key,
                  strided_copy_signature,
                  false,
                  dev_ctx,
                  input,
                  dims,
                  out_stride,
                  offset,
                  out);
#endif
}

// Fills `out` with `value`, using `x` as the shape/meta source, by dispatching
// to the "fill" kernel.
//
// On regular backends this visits all dtypes at compile time and calls the
// templated FillKernel directly. For custom devices
// (PADDLE_WITH_CUSTOM_DEVICE) the kernel must come from the plugin registry,
// so it is looked up at runtime through the kernel factory instead.
//
// `kernel_name` is only used for error reporting by PD_VISIT_ALL_TYPES in the
// non-custom-device path; it is passed by const reference to avoid copying a
// std::string on every call.
template <typename Context>
inline void StridedTensorFill(const phi::DataType input_dtype,
                              const std::string& kernel_name,
                              const Context& dev_ctx,
                              const phi::DenseTensor& x,
                              const phi::Scalar& value,
                              phi::DenseTensor* out) {
#ifndef PADDLE_WITH_CUSTOM_DEVICE
  PD_VISIT_ALL_TYPES(input_dtype, kernel_name, ([&] {
                       phi::FillKernel<data_t, Context>(dev_ctx, x, value, out);
                     }));
#else
  (void)kernel_name;  // only used for diagnostics in the other branch
  // Resolve the plugin-registered "fill" kernel for this device/dtype.
  const phi::KernelKey& fill_key = {phi::TransToPhiBackend(dev_ctx.GetPlace()),
                                    phi::DataLayout::ALL_LAYOUT,
                                    input_dtype};
  using fill_signature = void (*)(const phi::DeviceContext&,
                                  const phi::DenseTensor&,
                                  const phi::Scalar&,
                                  phi::DenseTensor*);

  PD_VISIT_KERNEL(
      "fill", fill_key, fill_signature, false, dev_ctx, x, value, out);
#endif
}
} // namespace phi