Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 169 additions & 0 deletions paddle/phi/kernels/stride/bitwise_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,32 @@
COMMON_DECLARE_bool(use_stride_kernel);
COMMON_DECLARE_bool(use_stride_compute_kernel);
namespace phi {

template <typename T, typename Context, typename Functor>
void LaunchBinaryElementwiseStrideKernel(const Context &dev_ctx,
const DenseTensor &x,
const DenseTensor &y,
Functor func,
int axis,
DenseTensor *out) {
std::vector<const DenseTensor *> inputs = {&x, &y};
std::vector<DenseTensor *> outputs = {out};
dev_ctx.template Alloc<T>(out);
BinaryStrideBroadcastKernel<T, Context>(
dev_ctx, inputs, &outputs, func, axis);
}

template <typename T, typename Context, typename Functor>
void LaunchUnaryElementwiseStrideKernel(const Context &dev_ctx,
const DenseTensor &x,
Functor func,
DenseTensor *out) {
std::vector<const DenseTensor *> inputs = {&x};
std::vector<DenseTensor *> outputs = {out};
dev_ctx.template Alloc<T>(out);
UnaryStrideElementwiseKernel<T, Context>(dev_ctx, inputs, &outputs, func);
}

#define DEFINE_CUDA_BINARY_ELEMENTWISE_STRIDE_OP(name) \
template <typename T, typename Context> \
void name##StrideKernel(const Context &dev_ctx, \
Expand Down Expand Up @@ -73,6 +99,118 @@ namespace phi {
DEFINE_CUDA_BINARY_ELEMENTWISE_STRIDE_OP(BitwiseAnd)
DEFINE_CUDA_BINARY_ELEMENTWISE_STRIDE_OP(BitwiseOr)
DEFINE_CUDA_BINARY_ELEMENTWISE_STRIDE_OP(BitwiseXor)

#define DEFINE_CUDA_BINARY_ELEMENTWISE_WITH_BOOL_STRIDE_OP(name) \
template <typename T, typename Context> \
void Bitwise##name##StrideKernel(const Context &dev_ctx, \
const DenseTensor &x, \
const DenseTensor &y, \
bool is_arithmetic, \
DenseTensor *out) { \
if (!FLAGS_use_stride_kernel) { \
PADDLE_THROW(common::errors::Fatal( \
"FLAGS_use_stride_kernel is closed. Strided kernel " \
"be called, something wrong has happened!")); \
} \
DenseTensor x_; \
DenseTensor y_; \
if (!FLAGS_use_stride_compute_kernel || x.offset() != 0 || \
y.offset() != 0) { \
if (!x.meta().is_contiguous() || x.offset() != 0) { \
x_ = Tensor2Contiguous<Context>(dev_ctx, x); \
} else { \
x_ = x; \
} \
if (!y.meta().is_contiguous() || y.offset() != 0) { \
y_ = Tensor2Contiguous<Context>(dev_ctx, y); \
} else { \
y_ = y; \
} \
} else { \
x_ = x; \
y_ = y; \
} \
if (x_.meta().is_contiguous() && y_.meta().is_contiguous()) { \
auto meta = out->meta(); \
meta.strides = meta.calc_strides(out->dims()); \
out->set_meta(meta); \
dev_ctx.template Alloc<T>(out); \
std::vector<const DenseTensor *> ins = {&x_, &y_}; \
std::vector<DenseTensor *> outs = {out}; \
if (is_arithmetic) { \
funcs::Bitwise##name##ArithmeticFunctor<T> func; \
funcs::BroadcastKernel<T>(dev_ctx, ins, &outs, func); \
} else { \
funcs::Bitwise##name##LogicFunctor<T> func; \
funcs::BroadcastKernel<T>(dev_ctx, ins, &outs, func); \
} \
return; \
} \
if (!FLAGS_use_stride_compute_kernel) { \
PADDLE_THROW( \
common::errors::Fatal("FLAGS_use_stride_compute_kernel is closed. " \
"Kernel using DenseTensorIterator " \
"be called, something wrong has happened!")); \
} \
if (is_arithmetic) { \
LaunchBinaryElementwiseStrideKernel<T, Context>( \
dev_ctx, \
x_, \
y_, \
funcs::Bitwise##name##ArithmeticFunctor<T>(), \
-1, \
out); \
} else { \
LaunchBinaryElementwiseStrideKernel<T, Context>( \
dev_ctx, x_, y_, funcs::Bitwise##name##LogicFunctor<T>(), -1, out); \
} \
}

DEFINE_CUDA_BINARY_ELEMENTWISE_WITH_BOOL_STRIDE_OP(LeftShift)
DEFINE_CUDA_BINARY_ELEMENTWISE_WITH_BOOL_STRIDE_OP(RightShift)
#undef DEFINE_CUDA_BINARY_ELEMENTWISE_WITH_BOOL_STRIDE_OP

template <typename T, typename Context>
void BitwiseNotStrideKernel(const Context &dev_ctx,
const DenseTensor &x,
DenseTensor *out) {
if (!FLAGS_use_stride_kernel) {
PADDLE_THROW(common::errors::Fatal(
"FLAGS_use_stride_kernel is closed. Strided kernel "
"be called, something wrong has happened!"));
}
DenseTensor x_;
if (!FLAGS_use_stride_compute_kernel || x.offset() != 0) {
if (!x.meta().is_contiguous() || x.offset() != 0) {
x_ = Tensor2Contiguous<Context>(dev_ctx, x);
} else {
x_ = x;
}
} else {
x_ = x;
}
if (x_.meta().is_contiguous()) {
auto meta = out->meta();
meta.strides = meta.calc_strides(out->dims());
out->set_meta(meta);
dev_ctx.template Alloc<T>(out);
std::vector<const DenseTensor *> ins = {&x_};
std::vector<DenseTensor *> outs = {out};
funcs::BitwiseNotFunctor<T> unary_func;
funcs::ElementwiseKernel<T, funcs::BitwiseNotFunctor<T>>(
dev_ctx, ins, &outs, unary_func);
return;
}
if (!FLAGS_use_stride_compute_kernel) {
PADDLE_THROW(
common::errors::Fatal("FLAGS_use_stride_compute_kernel is closed. "
"Kernel using DenseTensorIterator "
"be called, something wrong has happened!"));
}
LaunchUnaryElementwiseStrideKernel<T, Context>(
dev_ctx, x_, funcs::BitwiseNotFunctor<T>(), out);
}

} // namespace phi
using float16 = phi::dtype::float16;
using bfloat16 = phi::dtype::bfloat16;
Expand Down Expand Up @@ -108,4 +246,35 @@ PD_REGISTER_KERNEL(bitwise_xor,
int16_t,
int,
int64_t) {}

PD_REGISTER_KERNEL(bitwise_left_shift,
GPU,
STRIDED,
phi::BitwiseLeftShiftStrideKernel,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}

PD_REGISTER_KERNEL(bitwise_right_shift,
GPU,
STRIDED,
phi::BitwiseRightShiftStrideKernel,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}

PD_REGISTER_KERNEL(bitwise_not,
GPU,
STRIDED,
phi::BitwiseNotStrideKernel,
bool,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
#endif
Loading
Loading