Skip to content

Commit 0f15496

Browse files
authored
Reimplement the comparison binary ops using the new optimized CUDA function (#33064)
1 parent e8d6ff5 commit 0f15496

File tree

4 files changed

+120
-36
lines changed

4 files changed

+120
-36
lines changed

paddle/fluid/operators/controlflow/compare_op.cu

Lines changed: 81 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,85 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/controlflow/compare_op.h"
16+
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
1617

17-
REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor,
18-
paddle::operators::GreaterThanFunctor);
19-
REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor,
20-
paddle::operators::GreaterEqualFunctor);
21-
REGISTER_COMPARE_KERNEL(greater_than, CUDA,
22-
paddle::operators::GreaterThanFunctor,
23-
paddle::operators::LessThanFunctor);
24-
REGISTER_COMPARE_KERNEL(greater_equal, CUDA,
25-
paddle::operators::GreaterEqualFunctor,
26-
paddle::operators::LessEqualFunctor);
27-
REGISTER_COMPARE_KERNEL(equal, CUDA, paddle::operators::EqualFunctor,
28-
paddle::operators::EqualFunctor);
29-
REGISTER_COMPARE_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor,
30-
paddle::operators::NotEqualFunctor);
18+
namespace ops = paddle::operators;
19+
namespace plat = paddle::platform;
20+
21+
namespace paddle {
22+
namespace operators {
23+
24+
#define DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(Func, op) \
25+
template <typename T, typename Enable = void> \
26+
struct Func##Functor { \
27+
using ELEMENT_TYPE = T; \
28+
inline HOSTDEVICE bool operator()(const T* args) const { \
29+
return args[0] op args[1]; \
30+
} \
31+
};
32+
33+
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessThan, <)
34+
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessEqual, <=)
35+
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterThan, >)
36+
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterEqual, >=)
37+
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaEqual, ==)
38+
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaNotEqual, !=)
39+
#undef DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT
40+
41+
template <typename T>
42+
struct CudaEqualFunctor<
43+
T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
44+
using ELEMENT_TYPE = T;
45+
HOSTDEVICE bool operator()(const T* args) const {
46+
return fabs(static_cast<double>(args[0] - args[1])) < 1e-8;
47+
}
48+
};
49+
50+
template <typename T>
51+
struct CudaNotEqualFunctor<
52+
T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
53+
using ELEMENT_TYPE = T;
54+
HOSTDEVICE bool operator()(const T* args) const {
55+
return fabs(static_cast<double>(args[0] - args[1])) > 1e-8;
56+
}
57+
};
58+
59+
template <typename Functor, typename InverseFunctor>
60+
class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor>
61+
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
62+
public:
63+
public:
64+
using InT = typename Functor::ELEMENT_TYPE;
65+
using OutT = bool;
66+
void Compute(const framework::ExecutionContext& ctx) const override {
67+
auto functor = Functor();
68+
std::vector<const framework::Tensor*> ins;
69+
std::vector<framework::Tensor*> outs;
70+
71+
PackTensorsIntoVector<OutT>(ctx, &ins, &outs);
72+
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, InT, OutT>(
73+
ctx, ins, &outs, functor);
74+
}
75+
};
76+
77+
} // namespace operators
78+
} // namespace paddle
79+
80+
// Registers the CUDA compare kernel for an op over int, int64_t, float and
// double element types. The InverseFunctor template slot of CompareOpKernel
// is unused on CUDA, so it is instantiated with void.
#define REGISTER_CUDA_COMPARE_KERNEL(op_type, func)                          \
  REGISTER_OP_CUDA_KERNEL(                                                   \
      op_type, ops::CompareOpKernel<plat::CUDADeviceContext,                 \
                                    ops::func##Functor<int>, void>,          \
      ops::CompareOpKernel<plat::CUDADeviceContext,                          \
                           ops::func##Functor<int64_t>, void>,               \
      ops::CompareOpKernel<plat::CUDADeviceContext, ops::func##Functor<float>, \
                           void>,                                            \
      ops::CompareOpKernel<plat::CUDADeviceContext,                          \
                           ops::func##Functor<double>, void>);

// NOTE(review): only int/int64_t/float/double are registered here; confirm
// the removed REGISTER_COMPARE_KERNEL path did not additionally register
// types such as bool or int16 that would be lost by this change.
REGISTER_CUDA_COMPARE_KERNEL(equal, CudaEqual)
REGISTER_CUDA_COMPARE_KERNEL(not_equal, CudaNotEqual)
REGISTER_CUDA_COMPARE_KERNEL(less_than, CudaLessThan)
REGISTER_CUDA_COMPARE_KERNEL(less_equal, CudaLessEqual)
REGISTER_CUDA_COMPARE_KERNEL(greater_than, CudaGreaterThan)
REGISTER_CUDA_COMPARE_KERNEL(greater_equal, CudaGreaterEqual)
#undef REGISTER_CUDA_COMPARE_KERNEL

paddle/fluid/operators/elementwise/elementwise_add_op.cu

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -42,20 +42,11 @@ class ElementwiseAddKernel<platform::CUDADeviceContext, T>
4242
: public framework::OpKernel<T> {
4343
public:
4444
void Compute(const framework::ExecutionContext& ctx) const override {
45-
auto* x = ctx.Input<framework::LoDTensor>("X");
46-
auto* y = ctx.Input<framework::LoDTensor>("Y");
47-
auto* z = ctx.Output<framework::LoDTensor>("Out");
48-
z->mutable_data<T>(ctx.GetPlace());
49-
int axis = ctx.Attr<int>("axis");
50-
axis = axis == -1 ? std::abs(x->dims().size() - y->dims().size()) : axis;
51-
52-
std::vector<const framework::Tensor*> ins = {x, y};
53-
std::vector<framework::Tensor*> outs = {z};
54-
const auto& cuda_ctx =
55-
ctx.template device_context<platform::CUDADeviceContext>();
56-
45+
std::vector<const framework::Tensor*> ins;
46+
std::vector<framework::Tensor*> outs;
47+
PackTensorsIntoVector<T>(ctx, &ins, &outs);
5748
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
58-
cuda_ctx, ins, &outs, axis, CudaAddFunctor<T>());
49+
ctx, ins, &outs, CudaAddFunctor<T>());
5950
}
6051
};
6152

paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,6 @@ template <typename InT, typename OutT, typename BroadcastArgsWarpper,
343343
__global__ void ElementwiseBroadcastKernel(
344344
BroadcastArgsWarpper broadcast_warpper, int main_tid, int tail_tid) {
345345
int tid = threadIdx.x + blockIdx.x * blockDim.x;
346-
347346
// Vectorized calculation of major data whose length is the max multiple of
348347
// VecSize,
349348
// eg: Calculating the front 1024-length data in total 1027 data once VecSize
@@ -501,23 +500,30 @@ void LaunchBroadcastElementwiseCudaKernel(
501500
}
502501
}
503502

504-
template <ElementwiseType ET, typename InT, typename OutType, typename Functor>
503+
template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
505504
void LaunchElementwiseCudaKernel(
506-
const platform::CUDADeviceContext &cuda_ctx,
505+
const framework::ExecutionContext &ctx,
507506
const std::vector<const framework::Tensor *> &ins,
508-
std::vector<framework::Tensor *> *outs, int axis, Functor func) {
507+
std::vector<framework::Tensor *> *outs, Functor func) {
508+
std::vector<int> dims_size;
509509
bool no_broadcast_flag = true;
510510
for (auto *in : ins) {
511511
no_broadcast_flag = ins[0]->dims() == in->dims();
512+
dims_size.emplace_back(in->dims().size());
512513
}
513-
514+
const auto &cuda_ctx =
515+
ctx.template device_context<platform::CUDADeviceContext>();
514516
if (no_broadcast_flag) {
515-
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, InT, OutType>(
517+
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, InT, OutT>(
516518
cuda_ctx, ins, outs, func);
517519
} else {
518-
LaunchBroadcastElementwiseCudaKernel<ElementwiseType::kBinary, InT,
519-
OutType>(cuda_ctx, ins, outs, axis,
520-
func);
520+
int axis = ctx.HasAttr("axis") ? ctx.Attr<int>("axis") : -1;
521+
axis = axis == -1
522+
? *std::max_element(dims_size.begin(), dims_size.end()) -
523+
*std::min_element(dims_size.begin(), dims_size.end())
524+
: axis;
525+
LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT>(cuda_ctx, ins, outs,
526+
axis, func);
521527
}
522528
}
523529

paddle/fluid/operators/elementwise/elementwise_op_function.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,26 @@ constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
6060
namespace paddle {
6161
namespace operators {
6262

63+
/*
64+
* To pack the input and output tnesors into vector for
65+
* LaunchElementwiseCudaKernel
66+
*/
67+
template <typename T>
68+
void PackTensorsIntoVector(const framework::ExecutionContext &ctx,
69+
std::vector<const framework::Tensor *> *ins,
70+
std::vector<framework::Tensor *> *outs) {
71+
auto *x = ctx.Input<framework::LoDTensor>("X");
72+
auto *y = ctx.Input<framework::LoDTensor>("Y");
73+
auto *z = ctx.Output<framework::LoDTensor>("Out");
74+
z->mutable_data<T>(ctx.GetPlace());
75+
ins->emplace_back(x);
76+
outs->emplace_back(z);
77+
78+
if (y != nullptr) {
79+
ins->emplace_back(y);
80+
}
81+
}
82+
6383
/*
6484
* Out = X ⊙ Y
6585
* If Y's shape does not match X' shape, they will be reshaped.

0 commit comments

Comments
 (0)