Commit fcd93b3

Support Div and FloorDiv functor in elementwise system (#33053)
1 parent cd95ea8 commit fcd93b3

File tree

4 files changed: +114 −46 lines changed

paddle/fluid/operators/elementwise/elementwise_div_op.cu

Lines changed: 28 additions & 30 deletions
@@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
-#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
-#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 #include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"

@@ -23,38 +22,37 @@ namespace plat = paddle::platform;
 namespace paddle {
 namespace operators {

+template <typename T, typename Enable = void>
+struct CudaDivFunctor {
+  inline HOSTDEVICE T operator()(const T* args) const {
+    return args[0] / args[1];
+  }
+};
+
 template <typename T>
-struct SameDimsElemwiseDiv<platform::CUDADeviceContext, T> {
-  void operator()(const framework::ExecutionContext& ctx,
-                  const framework::Tensor* x, const framework::Tensor* y,
-                  framework::Tensor* z) {
-    DivRangeFunctor<T> functor(x->data<T>(), y->data<T>(), z->data<T>());
-    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-    platform::ForRange<platform::CUDADeviceContext> for_range(dev_ctx,
-                                                              x->numel());
-    for_range(functor);
+struct CudaDivFunctor<T,
+                      typename std::enable_if_t<std::is_integral<T>::value>> {
+  inline HOSTDEVICE T operator()(const T* args) const {
+    PADDLE_ENFORCE(args[1] != 0,
+                   "Invalid Argument Error: Integer division by zero "
+                   "encountered in divide. Please check the input value.");
+    return args[0] / args[1];
   }
 };

-template <>
-struct SameDimsElemwiseDiv<platform::CUDADeviceContext, platform::float16> {
-  void operator()(const framework::ExecutionContext& ctx,
-                  const framework::Tensor* x, const framework::Tensor* y,
-                  framework::Tensor* z) {
-    auto size = x->numel();
-    dim3 grid_size = dim3(((size + 7) / 8 + PADDLE_CUDA_THREAD_SIZE - 1) /
-                              PADDLE_CUDA_THREAD_SIZE,
-                          1);
-    dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
-    const half* x2 =
-        reinterpret_cast<const half*>(x->data<platform::float16>());
-    const half* y2 =
-        reinterpret_cast<const half*>(y->data<platform::float16>());
-    half* z2 = reinterpret_cast<half*>(z->data<platform::float16>());
-    SameDimsElemwiseDivCUDAKernel<<<
-        grid_size, block_size, 0,
-        ctx.template device_context<platform::CUDADeviceContext>().stream()>>>(
-        x2, y2, z2, size);
+template <typename T>
+class ElementwiseDivKernel<platform::CUDADeviceContext, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    std::vector<const framework::Tensor*> ins;
+    std::vector<framework::Tensor*> outs;
+    const auto& cuda_ctx =
+        ctx.template device_context<platform::CUDADeviceContext>();
+
+    int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
+    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+        cuda_ctx, ins, &outs, axis, CudaDivFunctor<T>());
   }
 };
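
The new CudaDivFunctor uses a default Enable template parameter plus a std::enable_if_t partial specialization so that integral types get a divide-by-zero check while floating-point types divide directly. Below is a minimal host-only sketch of the same dispatch pattern (not Paddle code: a hypothetical DivFunctor name, a plain assert standing in for PADDLE_ENFORCE, and no HOSTDEVICE qualifier).

#include <cassert>
#include <cstdio>
#include <type_traits>

// Primary template: plain division, used for non-integral types.
template <typename T, typename Enable = void>
struct DivFunctor {
  T operator()(const T* args) const { return args[0] / args[1]; }
};

// Partial specialization selected when T is integral: adds a zero check,
// mirroring the CudaDivFunctor specialization in the hunk above.
template <typename T>
struct DivFunctor<T, typename std::enable_if_t<std::is_integral<T>::value>> {
  T operator()(const T* args) const {
    assert(args[1] != 0 && "Integer division by zero encountered in divide.");
    return args[0] / args[1];
  }
};

int main() {
  int i[2] = {7, 2};
  double d[2] = {7.0, 2.0};
  std::printf("%d\n", DivFunctor<int>()(i));     // 3   (integral path, checked)
  std::printf("%f\n", DivFunctor<double>()(d));  // 3.5 (floating-point path)
  return 0;
}

The CUDA kernel itself no longer hand-writes a same-dims path: it packs its inputs with PackTensorsIntoVector and hands the functor to LaunchElementwiseCudaKernel, so Div shares the common elementwise launch machinery.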

paddle/fluid/operators/elementwise/elementwise_floordiv_op.cu

Lines changed: 33 additions & 1 deletion
@@ -12,11 +12,43 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_floordiv_op.h"
-#include "paddle/fluid/platform/float16.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"

 namespace ops = paddle::operators;
 namespace plat = paddle::platform;

+namespace paddle {
+namespace operators {
+
+template <typename T>
+struct CudaFloorDivFunctor {
+  inline HOSTDEVICE T operator()(const T argv[]) const {
+    PADDLE_ENFORCE(argv[1] != 0,
+                   "InvalidArgument: divide by zero "
+                   "encountered in floor-divide ops, please check.\n");
+    return static_cast<T>(std::trunc(argv[0] / argv[1]));
+  }
+};
+
+template <typename T>
+class ElementwiseFloorDivKernel<platform::CUDADeviceContext, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    std::vector<const framework::Tensor*> ins;
+    std::vector<framework::Tensor*> outs;
+    const auto& cuda_ctx =
+        ctx.template device_context<platform::CUDADeviceContext>();
+
+    int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
+    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+        cuda_ctx, ins, &outs, axis, CudaFloorDivFunctor<T>());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
 REGISTER_OP_CUDA_KERNEL(
     elementwise_floordiv,
     ops::ElementwiseFloorDivKernel<plat::CUDADeviceContext, int>,
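
CudaFloorDivFunctor combines a divide-by-zero check with std::trunc over the quotient. A small host-only sketch of that arithmetic (not Paddle code: a hypothetical FloorDivLike helper with assert in place of PADDLE_ENFORCE) shows what this computes for integral and floating-point operands; note that for integral T the quotient argv[0] / argv[1] is already truncated toward zero before std::trunc is applied.

#include <cassert>
#include <cmath>
#include <cstdio>

// Mirrors the body of CudaFloorDivFunctor on the host.
template <typename T>
T FloorDivLike(T a, T b) {
  assert(b != 0 && "divide by zero encountered in floor-divide");
  return static_cast<T>(std::trunc(a / b));
}

int main() {
  std::printf("%d\n", FloorDivLike(7, 2));        // 3
  std::printf("%d\n", FloorDivLike(-7, 2));       // -3 (integer division truncates toward zero)
  std::printf("%.1f\n", FloorDivLike(7.5, 2.0));  // 3.0
  return 0;
}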

paddle/fluid/operators/elementwise/elementwise_floordiv_op.h

Lines changed: 0 additions & 1 deletion
@@ -16,7 +16,6 @@ limitations under the License. */

 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op.h"
-#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
 #include "paddle/fluid/operators/math/blas.h"

 namespace paddle {

paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h

Lines changed: 53 additions & 14 deletions
@@ -14,7 +14,7 @@ limitations under the License. */
 #pragma once

 #include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/fast_divmod.h"

 #ifdef __HIPCC__
@@ -28,19 +28,62 @@ namespace operators {

 enum ElementwiseType { kUnary = 1, kBinary = 2 };

+/*
+* According to NVIDIA, CUDA performs better when the number of threads per
+* block is 64/128/256/512, and the number of blocks should be at least
+* 2x~4x the number of SMs. Hence, the SM count is taken into account in
+* this function to determine the right number of threads per block.
+*/
+inline int GetThreadsConfig(const platform::CUDADeviceContext &ctx,
+                            int64_t numel, int vec_size) {
+  int threads = ELEMENTWISE_BLOCK_SIZE;
+  int sm_count = ctx.GetSMCount();
+  int active_threads_num = numel / vec_size;
+  if (active_threads_num / (sm_count << 1) < ELEMENTWISE_BLOCK_SIZE) {
+    // Round the thread count up to a power of 2 while keeping the number of
+    // active blocks at roughly twice the SM count, for better performance.
+    threads = platform::RoundToPowerOfTwo(active_threads_num / (sm_count << 1));
+  } else if (active_threads_num / (sm_count << 2) < ELEMENTWISE_BLOCK_SIZE) {
+    // Round the thread count up to a power of 2 while keeping the number of
+    // active blocks at roughly four times the SM count, for better performance.
+    threads = platform::RoundToPowerOfTwo(active_threads_num / (sm_count << 2));
+  }
+  // The number of threads per block shall be no smaller than 64.
+  return std::max(64, threads);
+}
+
+/*
+* A vectorized load is only possible when the address of the input data is a
+* multiple of the 1/2/4-element vector width, and a single vectorized load is
+* at most 128 bits. The valid vectorization length is therefore determined
+* under both constraints.
+*/
 template <typename T>
 int GetVectorizedSizeImpl(const T *pointer) {
+  constexpr int max_load_bits = 128;
+  int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T);
   uint64_t address = reinterpret_cast<uint64_t>(pointer);
+  constexpr int vec8 =
+      std::alignment_of<CudaAlignedVector<T, 8>>::value;  // NOLINT
   constexpr int vec4 =
       std::alignment_of<CudaAlignedVector<T, 4>>::value;  // NOLINT
   constexpr int vec2 =
       std::alignment_of<CudaAlignedVector<T, 2>>::value;  // NOLINT
-  if (address % vec4 == 0) {
-    return 4;
+  if (address % vec8 == 0) {
+    /*
+    * Currently we handle at most 4 elements at a time in vectorized
+    * load/store. If performance tests show that handling 8 elements at a
+    * time is faster, the return statement below can be changed to
+    * "return std::min(8, valid_vec_size);".
+    */
+    return std::min(4, valid_vec_size);
+  } else if (address % vec4 == 0) {
+    return std::min(4, valid_vec_size);
   } else if (address % vec2 == 0) {
-    return 2;
+    return std::min(2, valid_vec_size);
+  } else {
+    return 1;
   }
-  return 1;
 }

 template <typename InT, typename OutT>
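
The two helpers added above encode the launch heuristics: GetThreadsConfig picks a power-of-two block size so that roughly 2x~4x as many blocks as SMs stay active (never fewer than 64 threads per block), and GetVectorizedSizeImpl picks a vector width limited both by the pointer's alignment and by the 128-bit maximum load. A host-only sketch of the vector-width selection (not Paddle code: hypothetical names, with sizeof(T) * N standing in for the alignment of CudaAlignedVector<T, N>) is:

#include <algorithm>
#include <climits>
#include <cstdint>
#include <cstdio>

// Largest vector width (in elements) that both fits in a 128-bit load and
// matches the alignment of pointer, following GetVectorizedSizeImpl.
template <typename T>
int VectorizedSizeSketch(const T *pointer) {
  constexpr int kMaxLoadBits = 128;
  const int valid_vec_size = kMaxLoadBits / CHAR_BIT / sizeof(T);  // 4 for float
  const std::uintptr_t address = reinterpret_cast<std::uintptr_t>(pointer);
  if (address % (sizeof(T) * 8) == 0) return std::min(4, valid_vec_size);
  if (address % (sizeof(T) * 4) == 0) return std::min(4, valid_vec_size);
  if (address % (sizeof(T) * 2) == 0) return std::min(2, valid_vec_size);
  return 1;
}

int main() {
  alignas(32) float buf[16] = {};
  std::printf("%d\n", VectorizedSizeSketch(buf));      // 4: 32-byte-aligned float*
  std::printf("%d\n", VectorizedSizeSketch(buf + 1));  // 1: only 4-byte-aligned
  std::printf("%d\n", VectorizedSizeSketch(buf + 2));  // 2: 8-byte-aligned
  return 0;
}
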
@@ -96,42 +139,38 @@ struct ElementwiseDataWrapper {

 template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
           typename Functor>
-__device__ void VectorizedKernelImpl(
+__device__ inline void VectorizedKernelImpl(
     ElementwiseDataWrapper<ET, VecSize, InT, OutT> data, Functor func,
     int tid) {
   using InVecType = CudaAlignedVector<InT, VecSize>;
   using OutVecType = CudaAlignedVector<OutT, VecSize>;
   InVecType ins_vec[ET];
   OutVecType out_vec;
   InT *ins_ptr[ET];
-  OutT *out_ptr;
+  InT ins[ET];
 #pragma unroll
   for (int i = 0; i < ET; ++i) {
     ins_ptr[i] = reinterpret_cast<InT *>(&(ins_vec[i]));
   }
-  out_ptr = reinterpret_cast<OutT *>(&out_vec);
-
   // load
   data.load_vector(ins_vec, tid);

 // compute
 #pragma unroll
   for (int i = 0; i < VecSize; ++i) {
-    InT ins[ET];
 #pragma unroll
     for (int j = 0; j < ET; ++j) {
       ins[j] = ins_ptr[j][i];
     }
-    out_ptr[i] = func(ins);
+    out_vec.val[i] = func(ins);
   }
-
   // store
   data.store_vector(out_vec, tid);
 }

 template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
           typename Functor>
-__device__ void ScalarKernelImpl(
+__device__ inline void ScalarKernelImpl(
     ElementwiseDataWrapper<ET, VecSize, InT, OutT> data, Functor func,
     int start, int remain) {
   InT ins[ET];
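
The revised VectorizedKernelImpl hoists the per-lane buffer InT ins[ET] out of the inner loop and writes each result through out_vec.val[i] instead of a reinterpret_cast'ed output pointer. A CUDA-free sketch of the same load-compute-store pattern for a binary op (not Paddle code: a hypothetical AlignedVector stand-in for CudaAlignedVector, with the input reinterpret_cast indirection dropped) is:

#include <cstdio>

// Stand-in for CudaAlignedVector<T, N>: N packed values accessed via .val[i].
template <typename T, int N>
struct alignas(sizeof(T) * N) AlignedVector {
  T val[N];
};

// Apply a binary functor lane by lane, mirroring the vectorized kernel body:
// gather one scalar per input into ins, then write func(ins) into out->val[i].
template <int VecSize, typename T, typename Functor>
void VectorizedCompute(const AlignedVector<T, VecSize> (&ins_vec)[2],
                       AlignedVector<T, VecSize> *out, Functor func) {
  T ins[2];
  for (int i = 0; i < VecSize; ++i) {
    for (int j = 0; j < 2; ++j) {
      ins[j] = ins_vec[j].val[i];
    }
    out->val[i] = func(ins);
  }
}

int main() {
  AlignedVector<float, 4> in[2] = {{{8.f, 6.f, 4.f, 2.f}}, {{2.f, 2.f, 2.f, 2.f}}};
  AlignedVector<float, 4> out;
  VectorizedCompute(in, &out, [](const float *args) { return args[0] / args[1]; });
  for (int i = 0; i < 4; ++i) std::printf("%g ", out.val[i]);  // 4 3 2 1
  std::printf("\n");
  return 0;
}
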
@@ -182,7 +221,7 @@ void LaunchSameDimsElementwiseCudaKernel(
   // calculate the max vec_size for all ins and outs
   auto size = ins[0]->numel();
   int vec_size = GetVectorizedSize<InT, OutT>(ins, *outs);
-  int block_size = ELEMENTWISE_BLOCK_SIZE;
+  int block_size = GetThreadsConfig(ctx, size, vec_size);
   int grid_size =
       ((size + vec_size - 1) / vec_size + block_size - 1) / block_size;
   const InT *in0 = ins[0]->data<InT>();
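
As a worked example with hypothetical numbers: for size = 1,000,000 elements, vec_size = 4 gives 250,000 vectorized work items, and with a block_size of 256 returned by GetThreadsConfig the grid_size formula above yields (250,000 + 255) / 256 = 977 blocks.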
