Skip to content

Commit b432d02

Browse files
authored
Support Add Sub Mul Max Min Pow binary functors in elementwise system (#33050)
1 parent 9c52ade commit b432d02

12 files changed

+231
-112
lines changed

paddle/fluid/operators/controlflow/compare_op.cu

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,21 +21,21 @@ namespace plat = paddle::platform;
2121
namespace paddle {
2222
namespace operators {
2323

24-
#define DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(Func, op) \
24+
#define DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(func, op) \
2525
template <typename T, typename Enable = void> \
26-
struct Func##Functor { \
26+
struct func { \
2727
using ELEMENT_TYPE = T; \
2828
inline HOSTDEVICE bool operator()(const T* args) const { \
2929
return args[0] op args[1]; \
3030
} \
3131
};
3232

33-
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessThan, <)
34-
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessEqual, <=)
35-
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterThan, >)
36-
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterEqual, >=)
37-
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaEqual, ==)
38-
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaNotEqual, !=)
33+
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessThanFunctor, <)
34+
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessEqualFunctor, <=)
35+
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterThanFunctor, >)
36+
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterEqualFunctor, >=)
37+
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaEqualFunctor, ==)
38+
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaNotEqualFunctor, !=)
3939
#undef DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT
4040

4141
template <typename T>
@@ -67,10 +67,12 @@ class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor>
6767
auto functor = Functor();
6868
std::vector<const framework::Tensor*> ins;
6969
std::vector<framework::Tensor*> outs;
70+
const auto& cuda_ctx =
71+
ctx.template device_context<platform::CUDADeviceContext>();
7072

71-
PackTensorsIntoVector<OutT>(ctx, &ins, &outs);
73+
int axis = PackTensorsIntoVector<OutT>(ctx, &ins, &outs);
7274
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, InT, OutT>(
73-
ctx, ins, &outs, functor);
75+
cuda_ctx, ins, &outs, axis, functor);
7476
}
7577
};
7678

@@ -79,19 +81,16 @@ class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor>
7981

8082
#define REGISTER_CUDA_COMPARE_KERNEL(op_type, func) \
8183
REGISTER_OP_CUDA_KERNEL( \
82-
op_type, ops::CompareOpKernel<plat::CUDADeviceContext, \
83-
ops::func##Functor<int>, void>, \
84-
ops::CompareOpKernel<plat::CUDADeviceContext, \
85-
ops::func##Functor<int64_t>, void>, \
86-
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func##Functor<float>, \
87-
void>, \
88-
ops::CompareOpKernel<plat::CUDADeviceContext, \
89-
ops::func##Functor<double>, void>);
84+
op_type, \
85+
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<int>, void>, \
86+
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<int64_t>, void>, \
87+
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<float>, void>, \
88+
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<double>, void>);
9089

91-
REGISTER_CUDA_COMPARE_KERNEL(equal, CudaEqual)
92-
REGISTER_CUDA_COMPARE_KERNEL(not_equal, CudaNotEqual)
93-
REGISTER_CUDA_COMPARE_KERNEL(less_than, CudaLessThan)
94-
REGISTER_CUDA_COMPARE_KERNEL(less_equal, CudaLessEqual)
95-
REGISTER_CUDA_COMPARE_KERNEL(greater_than, CudaGreaterThan)
96-
REGISTER_CUDA_COMPARE_KERNEL(greater_equal, CudaGreaterEqual)
90+
REGISTER_CUDA_COMPARE_KERNEL(equal, CudaEqualFunctor)
91+
REGISTER_CUDA_COMPARE_KERNEL(not_equal, CudaNotEqualFunctor)
92+
REGISTER_CUDA_COMPARE_KERNEL(less_than, CudaLessThanFunctor)
93+
REGISTER_CUDA_COMPARE_KERNEL(less_equal, CudaLessEqualFunctor)
94+
REGISTER_CUDA_COMPARE_KERNEL(greater_than, CudaGreaterThanFunctor)
95+
REGISTER_CUDA_COMPARE_KERNEL(greater_equal, CudaGreaterEqualFunctor)
9796
#undef REGISTER_CUDA_COMPARE_KERNEL

paddle/fluid/operators/elementwise/elementwise_add_op.cu

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@ namespace operators {
2828
1. For Unary Op, the length of input array is 1,
2929
e.g. Relu: return args[0] > 0 ? args[0] : 0;
3030
2. For Binary Op, the length of input array is 2,
31-
e.g. Add: return args[0] + args[1];
31+
e.g. Add: return args[0] expr args[1]; (expr is + for Add)
3232
*/
3333
template <typename T>
3434
struct CudaAddFunctor {
35-
__device__ __forceinline__ T operator()(const T* args) const {
35+
inline HOSTDEVICE T operator()(const T* args) const {
3636
return args[0] + args[1];
3737
}
3838
};
@@ -44,9 +44,12 @@ class ElementwiseAddKernel<platform::CUDADeviceContext, T>
4444
void Compute(const framework::ExecutionContext& ctx) const override {
4545
std::vector<const framework::Tensor*> ins;
4646
std::vector<framework::Tensor*> outs;
47-
PackTensorsIntoVector<T>(ctx, &ins, &outs);
47+
const auto& cuda_ctx =
48+
ctx.template device_context<platform::CUDADeviceContext>();
49+
50+
int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
4851
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
49-
ctx, ins, &outs, CudaAddFunctor<T>());
52+
cuda_ctx, ins, &outs, axis, CudaAddFunctor<T>());
5053
}
5154
};
5255

paddle/fluid/operators/elementwise/elementwise_add_op.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,10 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
7272
auto *z = ctx.Output<framework::LoDTensor>("Out");
7373
z->mutable_data<T>(ctx.GetPlace());
7474
if (x->dims() == y->dims()) {
75-
SameDimsElemwiseAdd<platform::CPUDeviceContext, T>
76-
LaunchElementwiseCpuKernel;
75+
SameDimsElemwiseAdd<DeviceContext, T> LaunchElementwiseCpuKernel;
7776
LaunchElementwiseCpuKernel(ctx, x, y, z);
7877
} else {
79-
LaunchBroadcastElementwiseCpuKernel<platform::CPUDeviceContext, T>(ctx, x,
80-
y, z);
78+
LaunchBroadcastElementwiseCpuKernel<DeviceContext, T>(ctx, x, y, z);
8179
}
8280
}
8381
};

paddle/fluid/operators/elementwise/elementwise_max_op.cu

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,40 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414
#include "paddle/fluid/operators/elementwise/elementwise_max_op.h"
15+
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
16+
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
1517

1618
namespace ops = paddle::operators;
1719

20+
namespace paddle {
21+
namespace operators {
22+
23+
template <typename T>
24+
struct CudaMaxFunctor {
25+
inline HOSTDEVICE T operator()(const T* args) const {
26+
return (args[0] > args[1] ? args[0] : args[1]);
27+
}
28+
};
29+
30+
template <typename T>
31+
class ElementwiseMaxKernel<platform::CUDADeviceContext, T>
32+
: public framework::OpKernel<T> {
33+
public:
34+
void Compute(const framework::ExecutionContext& ctx) const override {
35+
std::vector<const framework::Tensor*> ins;
36+
std::vector<framework::Tensor*> outs;
37+
const auto& cuda_ctx =
38+
ctx.template device_context<platform::CUDADeviceContext>();
39+
40+
int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
41+
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
42+
cuda_ctx, ins, &outs, axis, CudaMaxFunctor<T>());
43+
}
44+
};
45+
46+
} // namespace operators
47+
} // namespace paddle
48+
1849
REGISTER_OP_CUDA_KERNEL(
1950
elementwise_max,
2051
ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext, float>,

paddle/fluid/operators/elementwise/elementwise_min_op.cu

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,40 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414
#include "paddle/fluid/operators/elementwise/elementwise_min_op.h"
15+
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
16+
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
1517

1618
namespace ops = paddle::operators;
1719

20+
namespace paddle {
21+
namespace operators {
22+
23+
template <typename T>
24+
struct CudaMinFunctor {
25+
inline HOSTDEVICE T operator()(const T* args) const {
26+
return (args[0] > args[1] ? args[1] : args[0]);
27+
}
28+
};
29+
30+
template <typename T>
31+
class ElementwiseMinKernel<platform::CUDADeviceContext, T>
32+
: public framework::OpKernel<T> {
33+
public:
34+
void Compute(const framework::ExecutionContext& ctx) const override {
35+
std::vector<const framework::Tensor*> ins;
36+
std::vector<framework::Tensor*> outs;
37+
const auto& cuda_ctx =
38+
ctx.template device_context<platform::CUDADeviceContext>();
39+
40+
int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
41+
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
42+
cuda_ctx, ins, &outs, axis, CudaMinFunctor<T>());
43+
}
44+
};
45+
46+
} // namespace operators
47+
} // namespace paddle
48+
1849
REGISTER_OP_CUDA_KERNEL(
1950
elementwise_min,
2051
ops::ElementwiseMinKernel<paddle::platform::CUDADeviceContext, float>,

paddle/fluid/operators/elementwise/elementwise_mul_op.cu

Lines changed: 57 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
16+
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
1617
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
1718
#include "paddle/fluid/platform/complex.h"
1819
#include "paddle/fluid/platform/float16.h"
@@ -24,37 +25,65 @@ namespace paddle {
2425
namespace operators {
2526

2627
template <typename T>
27-
struct SameDimsElemwiseMul<platform::CUDADeviceContext, T> {
28-
void operator()(const framework::ExecutionContext& ctx,
29-
const framework::Tensor* x, const framework::Tensor* y,
30-
framework::Tensor* z) {
31-
MulRangeFunctor<T> functor(x->data<T>(), y->data<T>(), z->data<T>());
32-
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
33-
platform::ForRange<platform::CUDADeviceContext> for_range(dev_ctx,
34-
x->numel());
35-
for_range(functor);
28+
struct CudaMulFunctor {
29+
inline HOSTDEVICE T operator()(const T* args) const {
30+
return args[0] * args[1];
3631
}
3732
};
3833

39-
template <>
40-
struct SameDimsElemwiseMul<platform::CUDADeviceContext, platform::float16> {
41-
void operator()(const framework::ExecutionContext& ctx,
42-
const framework::Tensor* x, const framework::Tensor* y,
43-
framework::Tensor* z) {
44-
auto size = x->numel();
45-
dim3 grid_size = dim3(((size + 7) / 8 + PADDLE_CUDA_THREAD_SIZE - 1) /
46-
PADDLE_CUDA_THREAD_SIZE,
47-
1);
48-
dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
49-
const half* x2 =
50-
reinterpret_cast<const half*>(x->data<platform::float16>());
51-
const half* y2 =
52-
reinterpret_cast<const half*>(y->data<platform::float16>());
53-
half* z2 = reinterpret_cast<half*>(z->data<platform::float16>());
54-
SameDimsElemwiseMulCUDAKernel<<<
55-
grid_size, block_size, 0,
56-
ctx.template device_context<platform::CUDADeviceContext>().stream()>>>(
57-
x2, y2, z2, size);
34+
template <typename T>
35+
class ElementwiseMulKernel<platform::CUDADeviceContext, T>
36+
: public framework::OpKernel<T> {
37+
public:
38+
void Compute(const framework::ExecutionContext& ctx) const override {
39+
int axis = -1;
40+
auto x_var = ctx.InputVar("X");
41+
PADDLE_ENFORCE_NOT_NULL(
42+
x_var, platform::errors::InvalidArgument(
43+
"Cannot get input Variable X, Variable name = %s.",
44+
ctx.InputName("X")));
45+
auto* y = ctx.Input<framework::LoDTensor>("Y");
46+
47+
framework::Tensor x, *z;
48+
std::vector<const framework::Tensor*> ins;
49+
std::vector<framework::Tensor*> outs;
50+
const auto& cuda_ctx =
51+
ctx.template device_context<platform::CUDADeviceContext>();
52+
53+
if (x_var->IsType<framework::LoDTensor>()) {
54+
x = x_var->Get<framework::LoDTensor>();
55+
z = ctx.Output<framework::LoDTensor>("Out");
56+
axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
57+
} else if (x_var->IsType<framework::SelectedRows>()) {
58+
PADDLE_ENFORCE_EQ(y->dims().size() == 1 && y->dims()[0] == 1, true,
59+
platform::errors::InvalidArgument(
60+
"For elementwise_op, if X is Sparse, Y must be "
61+
"scalar. But reveived the size of Y = %s.",
62+
y->dims().size()));
63+
auto& x_sele = x_var->Get<framework::SelectedRows>();
64+
auto out_sele = ctx.Output<framework::SelectedRows>("Out");
65+
x = x_sele.value();
66+
out_sele->set_rows(x_sele.rows());
67+
out_sele->set_height(x_sele.height());
68+
out_sele->mutable_value()->Resize(x_sele.value().dims());
69+
out_sele->mutable_value()->mutable_data(ctx.GetPlace(), x.type());
70+
z = ctx.Output<framework::SelectedRows>("Out")->mutable_value();
71+
z->mutable_data<T>(ctx.GetPlace());
72+
outs.emplace_back(z);
73+
ins.emplace_back(&x);
74+
ins.emplace_back(y);
75+
76+
axis = ctx.HasAttr("axis") ? ctx.Attr<int>("axis") : -1;
77+
axis = axis == -1 ? std::abs(y->dims().size() - x.dims().size()) : axis;
78+
} else {
79+
PADDLE_THROW(platform::errors::InvalidArgument(
80+
"X's type[%s] is not supported by elementwise_op. X's type should be "
81+
"LoDTensor or SelectedRows.",
82+
framework::ToTypeName(x_var->Type())));
83+
}
84+
85+
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
86+
cuda_ctx, ins, &outs, axis, CudaMulFunctor<T>());
5887
}
5988
};
6089

paddle/fluid/operators/elementwise/elementwise_mul_op.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,6 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
126126
}
127127
}
128128
};
129-
130129
template <typename T>
131130
struct MulGradDX {
132131
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * y; }

paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,11 @@ void LaunchBroadcastElementwiseCudaKernel(
465465
const platform::CUDADeviceContext &ctx,
466466
const std::vector<const framework::Tensor *> &ins,
467467
std::vector<framework::Tensor *> *outs, int axis, Functor func) {
468-
static_assert(ET == (ElementwiseType)2, "Only Support binary calculation.");
468+
PADDLE_ENFORCE_EQ(ET, ElementwiseType::kBinary,
469+
platform::errors::InvalidArgument(
470+
"Currently, only Support binary calculation, "
471+
"but received %d input tensors.\n",
472+
static_cast<int>(ET)));
469473
int in_vec_size = 4;
470474
framework::Tensor *out = (*outs)[0];
471475
for (auto *in : ins) {
@@ -502,26 +506,18 @@ void LaunchBroadcastElementwiseCudaKernel(
502506

503507
template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
504508
void LaunchElementwiseCudaKernel(
505-
const framework::ExecutionContext &ctx,
509+
const platform::CUDADeviceContext &cuda_ctx,
506510
const std::vector<const framework::Tensor *> &ins,
507-
std::vector<framework::Tensor *> *outs, Functor func) {
508-
std::vector<int> dims_size;
511+
std::vector<framework::Tensor *> *outs, int axis, Functor func) {
509512
bool no_broadcast_flag = true;
510513
for (auto *in : ins) {
511514
no_broadcast_flag = ins[0]->dims() == in->dims();
512-
dims_size.emplace_back(in->dims().size());
513515
}
514-
const auto &cuda_ctx =
515-
ctx.template device_context<platform::CUDADeviceContext>();
516+
516517
if (no_broadcast_flag) {
517-
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, InT, OutT>(
518-
cuda_ctx, ins, outs, func);
518+
LaunchSameDimsElementwiseCudaKernel<ET, InT, OutT>(cuda_ctx, ins, outs,
519+
func);
519520
} else {
520-
int axis = ctx.HasAttr("axis") ? ctx.Attr<int>("axis") : -1;
521-
axis = axis == -1
522-
? *std::max_element(dims_size.begin(), dims_size.end()) -
523-
*std::min_element(dims_size.begin(), dims_size.end())
524-
: axis;
525521
LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT>(cuda_ctx, ins, outs,
526522
axis, func);
527523
}

paddle/fluid/operators/elementwise/elementwise_op_function.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,20 +64,24 @@ namespace operators {
6464
* To pack the input and output tensors into vector for
6565
* LaunchElementwiseCudaKernel
6666
*/
67-
template <typename T>
68-
void PackTensorsIntoVector(const framework::ExecutionContext &ctx,
69-
std::vector<const framework::Tensor *> *ins,
70-
std::vector<framework::Tensor *> *outs) {
67+
template <typename OutT>
68+
int PackTensorsIntoVector(const framework::ExecutionContext &ctx,
69+
std::vector<const framework::Tensor *> *ins,
70+
std::vector<framework::Tensor *> *outs) {
71+
int axis = -1;
7172
auto *x = ctx.Input<framework::LoDTensor>("X");
7273
auto *y = ctx.Input<framework::LoDTensor>("Y");
7374
auto *z = ctx.Output<framework::LoDTensor>("Out");
74-
z->mutable_data<T>(ctx.GetPlace());
75-
ins->emplace_back(x);
75+
z->mutable_data<OutT>(ctx.GetPlace());
7676
outs->emplace_back(z);
77+
ins->emplace_back(x);
7778

7879
if (y != nullptr) {
7980
ins->emplace_back(y);
81+
axis = ctx.HasAttr("axis") ? ctx.Attr<int>("axis") : -1;
82+
axis = axis == -1 ? std::abs(y->dims().size() - x->dims().size()) : axis;
8083
}
84+
return axis;
8185
}
8286

8387
/*

0 commit comments

Comments
 (0)