Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
c60c4e0
First_Commit.
JamesLim-sy May 21, 2021
8a002d8
Fixs
JamesLim-sy May 21, 2021
9e89caf
Fix bugs
JamesLim-sy May 23, 2021
a495b1a
Change the assert codes
JamesLim-sy May 23, 2021
629ceaa
Update the flood_div op
JamesLim-sy May 23, 2021
d42eeb0
Update the flood_div op
JamesLim-sy May 23, 2021
62b221a
Update the flood_div op
JamesLim-sy May 23, 2021
67b034b
Update the flood_div op, change back into HOSTDEVICE.
JamesLim-sy May 23, 2021
937a297
Merge branch 'Adding_div_functors' of https://github.com/JamesLim-sy/…
JamesLim-sy May 23, 2021
2d371bc
Fix written bugs.
JamesLim-sy May 23, 2021
bc2b805
Fix bugs
JamesLim-sy May 23, 2021
9d46543
Fisrt commit
JamesLim-sy May 25, 2021
74e4179
Trigger of rerun
JamesLim-sy May 25, 2021
656ac99
To avoid spartial specification bugs which happened in PR-CI-ROCM
JamesLim-sy May 26, 2021
585566f
Avoid kUnary instantiation of LaunchElementwiseCudaKernel at compile …
JamesLim-sy May 30, 2021
d9c70ec
refine warpper of broadcast and add cuda op
JamesLim-sy May 31, 2021
20f9e36
Merge branch 'Fix_bugs_in_elementwise_warpprer' into Adding_div_functors
JamesLim-sy Jun 1, 2021
fb49a21
Fix bugs
JamesLim-sy Jun 1, 2021
e2692ee
revert the changes in host exe codes in .h files
JamesLim-sy Jun 1, 2021
2f72406
Merge branch 'develop' into Adding_div_functors
JamesLim-sy Jun 4, 2021
b7a41e9
Only support div currently.
JamesLim-sy Jun 5, 2021
8f26886
refine the vectorized_load length function
JamesLim-sy Jun 8, 2021
7b9c9a3
refine the vectorized_load length function
JamesLim-sy Jun 8, 2021
4f5b5c1
refine vectorized load determination function
JamesLim-sy Jun 8, 2021
5aaf95d
refine codes
JamesLim-sy Jun 9, 2021
f7d35b7
refine codes
JamesLim-sy Jun 9, 2021
da1cd7f
refine codes
JamesLim-sy Jun 9, 2021
85d954c
refine threads config.
JamesLim-sy Jun 10, 2021
13698a6
fix multiple errors with inline syntax
JamesLim-sy Jun 11, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 28 additions & 30 deletions paddle/fluid/operators/elementwise/elementwise_div_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"

Expand All @@ -23,38 +22,37 @@ namespace plat = paddle::platform;
namespace paddle {
namespace operators {

template <typename T, typename Enable = void>
struct CudaDivFunctor {
inline HOSTDEVICE T operator()(const T* args) const {
return args[0] / args[1];
}
};

template <typename T>
struct SameDimsElemwiseDiv<platform::CUDADeviceContext, T> {
void operator()(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z) {
DivRangeFunctor<T> functor(x->data<T>(), y->data<T>(), z->data<T>());
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
platform::ForRange<platform::CUDADeviceContext> for_range(dev_ctx,
x->numel());
for_range(functor);
struct CudaDivFunctor<T,
typename std::enable_if_t<std::is_integral<T>::value>> {
inline HOSTDEVICE T operator()(const T* args) const {
PADDLE_ENFORCE(args[1] != 0,
"Invalid Argument Error: Integer division by zero "
"encountered in divide. Please check the input value.");
return args[0] / args[1];
}
};

template <>
struct SameDimsElemwiseDiv<platform::CUDADeviceContext, platform::float16> {
void operator()(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z) {
auto size = x->numel();
dim3 grid_size = dim3(((size + 7) / 8 + PADDLE_CUDA_THREAD_SIZE - 1) /
PADDLE_CUDA_THREAD_SIZE,
1);
dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
const half* x2 =
reinterpret_cast<const half*>(x->data<platform::float16>());
const half* y2 =
reinterpret_cast<const half*>(y->data<platform::float16>());
half* z2 = reinterpret_cast<half*>(z->data<platform::float16>());
SameDimsElemwiseDivCUDAKernel<<<
grid_size, block_size, 0,
ctx.template device_context<platform::CUDADeviceContext>().stream()>>>(
x2, y2, z2, size);
template <typename T>
class ElementwiseDivKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();

int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
cuda_ctx, ins, &outs, axis, CudaDivFunctor<T>());
}
};

Expand Down
34 changes: 33 additions & 1 deletion paddle/fluid/operators/elementwise/elementwise_floordiv_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,43 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_floordiv_op.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

namespace paddle {
namespace operators {

template <typename T>
struct CudaFloorDivFunctor {
inline HOSTDEVICE T operator()(const T argv[]) const {
PADDLE_ENFORCE(argv[1] != 0,
"InvalidArgument: divide by zero "
"encountered in floor-divide ops, please check.\n");
return static_cast<T>(std::trunc(argv[0] / argv[1]));
}
};

template <typename T>
class ElementwiseFloorDivKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();

int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
cuda_ctx, ins, &outs, axis, CudaFloorDivFunctor<T>());
}
};

} // namespace operators
} // namespace paddle

REGISTER_OP_CUDA_KERNEL(
elementwise_floordiv,
ops::ElementwiseFloorDivKernel<plat::CUDADeviceContext, int>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ limitations under the License. */

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"

namespace paddle {
Expand Down
67 changes: 53 additions & 14 deletions paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ limitations under the License. */
#pragma once

#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/fast_divmod.h"

#ifdef __HIPCC__
Expand All @@ -28,19 +28,62 @@ namespace operators {

enum ElementwiseType { kUnary = 1, kBinary = 2 };

/*
* According to NVIDIA, if number of threads per block is 64/128/256/512,
* cuda performs better. And number of blocks should be greater (at least
* 2x~4x) than number of SMs. Hence, SM count is took into account within
* this function to determine the right number of threads per block.
*/
inline int GetThreadsConfig(const platform::CUDADeviceContext &ctx,
int64_t numel, int vec_size) {
int threads = ELEMENTWISE_BLOCK_SIZE;
int sm_count = ctx.GetSMCount();
int active_threads_num = numel / vec_size;
if (active_threads_num / (sm_count << 1) < ELEMENTWISE_BLOCK_SIZE) {
// Round up threads number into an exponential multiple of 2, while number
// of acitve blocks is about twice of SM, to acquire better performance.
threads = platform::RoundToPowerOfTwo(active_threads_num / (sm_count << 1));
} else if (active_threads_num / (sm_count << 2) < ELEMENTWISE_BLOCK_SIZE) {
// Round up threads number into an exponential multiple of 2, while number
// of acitve blocks is about 4 times of SM, to acquire better performance.
threads = platform::RoundToPowerOfTwo(active_threads_num / (sm_count << 2));
}
// Number of threads per block shall be larger than 64.
return std::max(64, threads);
}

/*
* Only the address of input data is the multiplier of 1,2,4, vectorized load
* with corresponding multiplier-value is possible. Moreover, the maximum length
* of vectorized load is 128 bits once. Hence, valid length of vectorized load
* shall be determined under both former constraints.
*/
template <typename T>
int GetVectorizedSizeImpl(const T *pointer) {
constexpr int max_load_bits = 128;
int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T);
uint64_t address = reinterpret_cast<uint64_t>(pointer);
constexpr int vec8 =
std::alignment_of<CudaAlignedVector<T, 8>>::value; // NOLINT
constexpr int vec4 =
std::alignment_of<CudaAlignedVector<T, 4>>::value; // NOLINT
constexpr int vec2 =
std::alignment_of<CudaAlignedVector<T, 2>>::value; // NOLINT
if (address % vec4 == 0) {
return 4;
if (address % vec8 == 0) {
/*
* Currently, decide to deal with no more than 4 data once while adopting
* vectorization load/store, if performance test shows that dealing with
* 8 data once in vectorization load/store does get optimized, return code
* below can be changed into " return std::min(8, valid_vec_size); " .
*/
return std::min(4, valid_vec_size);
} else if (address % vec4 == 0) {
return std::min(4, valid_vec_size);
} else if (address % vec2 == 0) {
return 2;
return std::min(2, valid_vec_size);
} else {
return 1;
}
return 1;
}

template <typename InT, typename OutT>
Expand Down Expand Up @@ -96,42 +139,38 @@ struct ElementwiseDataWrapper {

template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
typename Functor>
__device__ void VectorizedKernelImpl(
__device__ inline void VectorizedKernelImpl(
ElementwiseDataWrapper<ET, VecSize, InT, OutT> data, Functor func,
int tid) {
using InVecType = CudaAlignedVector<InT, VecSize>;
using OutVecType = CudaAlignedVector<OutT, VecSize>;
InVecType ins_vec[ET];
OutVecType out_vec;
InT *ins_ptr[ET];
OutT *out_ptr;
InT ins[ET];
#pragma unroll
for (int i = 0; i < ET; ++i) {
ins_ptr[i] = reinterpret_cast<InT *>(&(ins_vec[i]));
}
out_ptr = reinterpret_cast<OutT *>(&out_vec);

// load
data.load_vector(ins_vec, tid);

// compute
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
InT ins[ET];
#pragma unroll
for (int j = 0; j < ET; ++j) {
ins[j] = ins_ptr[j][i];
}
out_ptr[i] = func(ins);
out_vec.val[i] = func(ins);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我们的AlignedVector或许可以改成继承Array类,定义AlignedArray

template <typename T, size_t N>
class Array {
public:
static constexpr size_t kSize = N;
HOSTDEVICE inline Array() {}
template <typename... Args>
HOSTDEVICE inline explicit Array(const T &val, Args... args) {
static_assert(N == sizeof...(Args) + 1, "Invalid argument");
UnrollVarArgsAssign<T>::Run(data_, val, args...);
}

}

// store
data.store_vector(out_vec, tid);
}

template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
typename Functor>
__device__ void ScalarKernelImpl(
__device__ inline void ScalarKernelImpl(
ElementwiseDataWrapper<ET, VecSize, InT, OutT> data, Functor func,
int start, int remain) {
InT ins[ET];
Expand Down Expand Up @@ -182,7 +221,7 @@ void LaunchSameDimsElementwiseCudaKernel(
// calculate the max vec_size for all ins and outs
auto size = ins[0]->numel();
int vec_size = GetVectorizedSize<InT, OutT>(ins, *outs);
int block_size = ELEMENTWISE_BLOCK_SIZE;
int block_size = GetThreadsConfig(ctx, size, vec_size);
int grid_size =
((size + vec_size - 1) / vec_size + block_size - 1) / block_size;
const InT *in0 = ins[0]->data<InT>();
Expand Down