Commit d865b04

Merge pull request #4201 from qingqing01/fix_prelu
Refine platform::Transform function and fix prelu_op testing.
2 parents: a9202e8 + 2aa4d32

4 files changed: 74 additions & 40 deletions

paddle/operators/prelu_op.h

Lines changed: 6 additions & 4 deletions
@@ -54,8 +54,9 @@ class PReluKernel : public framework::OpKernel {
 
     int numel = x->numel();
 
-    Transform(context.device_context(), x_ptr, x_ptr + numel, o_ptr,
-              PReluFunctor<T>(alpha_ptr));
+    Transform<Place> trans;
+    trans(context.device_context(), x_ptr, x_ptr + numel, o_ptr,
+          PReluFunctor<T>(alpha_ptr));
   }
 };
 
@@ -91,8 +92,9 @@ class PReluGradKernel : public framework::OpKernel {
     const T* out_ptr = out->data<T>();
     int numel = dx->numel();
 
-    Transform(context.device_context(), out_ptr, out_ptr + numel, dout_ptr,
-              dx_ptr, PReluGradFunctor<T>(alpha_ptr));
+    Transform<Place> trans;
+    trans(context.device_context(), out_ptr, out_ptr + numel, dout_ptr, dx_ptr,
+          PReluGradFunctor<T>(alpha_ptr));
 
     // TODO (Zhuoyuan): add dalpha upgrade when GPU kernels ready
   }
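
The kernel change above swaps the old free-function call for a two-step pattern: construct a Transform<Place> object, then invoke its operator() with the input range, the output pointer, and an element-wise functor. The standalone sketch below (plain C++, not Paddle code; CpuTransform and PReluLikeFunctor are invented stand-ins for Transform<platform::CPUPlace> and PReluFunctor<T>, and the DeviceContext argument is omitted) mirrors that call pattern:

#include <algorithm>
#include <cstdio>
#include <vector>

// Stand-in for PReluFunctor<T>: out = x if x > 0, else alpha * x.
template <typename T>
struct PReluLikeFunctor {
  explicit PReluLikeFunctor(T alpha) : alpha_(alpha) {}
  T operator()(T x) const { return x > 0 ? x : alpha_ * x; }
  T alpha_;
};

// Stand-in for Transform<platform::CPUPlace>: a callable object that applies
// an element-wise functor over [first, last) into result.
struct CpuTransform {
  template <typename InputIter, typename OutputIter, typename UnaryOp>
  void operator()(InputIter first, InputIter last, OutputIter result,
                  UnaryOp op) const {
    std::transform(first, last, result, op);
  }
};

int main() {
  std::vector<float> x = {-1.0f, 2.0f, -3.0f, 4.0f};
  std::vector<float> out(x.size());
  CpuTransform trans;  // construct the transform object, as the kernel now does
  trans(x.begin(), x.end(), out.begin(), PReluLikeFunctor<float>(0.1f));
  for (float v : out) std::printf("%g ", v);  // prints: -0.1 2 -0.3 4
  std::printf("\n");
  return 0;
}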

paddle/platform/transform.h

Lines changed: 55 additions & 29 deletions
@@ -29,45 +29,71 @@
 
 namespace paddle {
 namespace platform {
+
 // Transform on host or device. It provides the same API in std library.
-template <typename InputIter, typename OutputIter, typename UnaryOperation>
-void Transform(const DeviceContext& context, InputIter first, InputIter last,
-               OutputIter result, UnaryOperation op) {
-  auto place = context.GetPlace();
-  if (is_cpu_place(place)) {
+template <typename Place>
+struct Transform {
+  template <typename InputIter, typename OutputIter, typename UnaryOperation>
+  void operator()(const DeviceContext& context, InputIter first, InputIter last,
+                  OutputIter result, UnaryOperation op);
+
+  template <typename InputIter1, typename InputIter2, typename OutputIter,
+            typename BinaryOperation>
+  void operator()(const DeviceContext& context, InputIter1 first1,
+                  InputIter1 last1, InputIter2 first2, OutputIter result,
+                  BinaryOperation op);
+};
+
+template <>
+struct Transform<platform::CPUPlace> {
+  template <typename InputIter, typename OutputIter, typename UnaryOperation>
+  void operator()(const DeviceContext& context, InputIter first, InputIter last,
+                  OutputIter result, UnaryOperation op) {
+    auto place = context.GetPlace();
+    PADDLE_ENFORCE(is_cpu_place(place), "It must use CPU place.");
     std::transform(first, last, result, op);
-  } else {
-#ifdef __NVCC__
-    auto& ctx = reinterpret_cast<const CUDADeviceContext&>(context);
-    using namespace details;
-    thrust::transform(thrust::cuda::par.on(ctx.stream()), DevPtrCast(first),
-                      DevPtrCast(last), DevPtrCast(result), op);
-#else
-    PADDLE_THROW("Do not invoke `Transform<GPUPlace>` in .cc file");
-#endif
   }
-}
 
-template <typename InputIter1, typename InputIter2, typename OutputIter,
-          typename BinaryOperation>
-void Transform(const DeviceContext& context, InputIter1 first1,
-               InputIter1 last1, InputIter2 first2, OutputIter result,
-               BinaryOperation op) {
-  auto place = context.GetPlace();
-  if (is_cpu_place(place)) {
+  template <typename InputIter1, typename InputIter2, typename OutputIter,
+            typename BinaryOperation>
+  void operator()(const DeviceContext& context, InputIter1 first1,
+                  InputIter1 last1, InputIter2 first2, OutputIter result,
+                  BinaryOperation op) {
+    auto place = context.GetPlace();
+    PADDLE_ENFORCE(is_cpu_place(place), "It must use CPU place.");
     std::transform(first1, last1, first2, result, op);
-  } else {
+  }
+};
+
 #ifdef __NVCC__
+template <>
+struct Transform<platform::GPUPlace> {
+  template <typename InputIter, typename OutputIter, typename UnaryOperation>
+  void operator()(const DeviceContext& context, InputIter first, InputIter last,
+                  OutputIter result, UnaryOperation op) {
+    auto place = context.GetPlace();
+    PADDLE_ENFORCE(is_gpu_place(place), "It must use GPU place.");
     auto& ctx = reinterpret_cast<const CUDADeviceContext&>(context);
-    using namespace details;
-    thrust::transform(thrust::cuda::par.on(ctx.stream()), DevPtrCast(first1),
-                      DevPtrCast(last1), DevPtrCast(first2), DevPtrCast(result),
+    thrust::transform(thrust::cuda::par.on(ctx.stream()),
+                      details::DevPtrCast(first), details::DevPtrCast(last),
+                      details::DevPtrCast(result), op);
+  }
+
+  template <typename InputIter1, typename InputIter2, typename OutputIter,
+            typename BinaryOperation>
+  void operator()(const DeviceContext& context, InputIter1 first1,
+                  InputIter1 last1, InputIter2 first2, OutputIter result,
+                  BinaryOperation op) {
+    auto place = context.GetPlace();
+    PADDLE_ENFORCE(is_gpu_place(place), "It must use GPU place.");
+    auto& ctx = reinterpret_cast<const CUDADeviceContext&>(context);
+    thrust::transform(thrust::cuda::par.on(ctx.stream()),
+                      details::DevPtrCast(first1), details::DevPtrCast(last1),
+                      details::DevPtrCast(first2), details::DevPtrCast(result),
                       op);
-#else
-    PADDLE_THROW("Do not invoke `Transform<GPUPlace>` in .cc file");
-#endif
   }
 };
+#endif
 
 }  // namespace platform
 }  // namespace paddle
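
The rewrite above turns the run-time `if (is_cpu_place(place))` branch into compile-time selection: the primary Transform<Place> template only declares the two operator() overloads, the CPUPlace specialization uses std::transform, and the GPUPlace specialization sits entirely behind #ifdef __NVCC__, so a plain .cc translation unit never touches thrust. A minimal self-contained sketch of that tag-dispatch shape (generic C++, not Paddle's actual types; CpuTag, GpuTag, and MyTransform are invented names for illustration) looks like this:

#include <algorithm>
#include <cstdio>

// Invented stand-ins for platform::CPUPlace / platform::GPUPlace tag types.
struct CpuTag {};
struct GpuTag {};

// Primary template is declared but never defined, so instantiating it with an
// unsupported place tag fails at compile time instead of falling back silently.
template <typename Place>
struct MyTransform;

// CPU specialization: plain std::transform, usable in every translation unit.
template <>
struct MyTransform<CpuTag> {
  template <typename InIt, typename OutIt, typename UnaryOp>
  void operator()(InIt first, InIt last, OutIt result, UnaryOp op) const {
    std::transform(first, last, result, op);
  }
};

// A GPU specialization would live entirely behind the CUDA-compiler guard,
// mirroring the "#ifdef __NVCC__" around Transform<GPUPlace> above.
#ifdef __NVCC__
template <>
struct MyTransform<GpuTag> {
  // thrust::transform on the device stream would go here.
};
#endif

int main() {
  int buf[4] = {1, 2, 3, 4};
  MyTransform<CpuTag> trans;
  trans(buf, buf + 4, buf, [](int v) { return v * 10; });
  std::printf("%d %d %d %d\n", buf[0], buf[1], buf[2], buf[3]);  // 10 20 30 40
  return 0;
}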

paddle/platform/transform_test.cu

Lines changed: 11 additions & 5 deletions
@@ -15,6 +15,7 @@
 #include <gtest/gtest.h>
 #include "paddle/memory/memcpy.h"
 #include "paddle/memory/memory.h"
+#include "paddle/platform/hostdevice.h"
 #include "paddle/platform/transform.h"
 
 template <typename T>
@@ -38,7 +39,8 @@ TEST(Transform, CPUUnary) {
   using namespace paddle::platform;
   CPUDeviceContext ctx;
   float buf[4] = {0.1, 0.2, 0.3, 0.4};
-  Transform(ctx, buf, buf + 4, buf, Scale<float>(10));
+  Transform<paddle::platform::CPUPlace> trans;
+  trans(ctx, buf, buf + 4, buf, Scale<float>(10));
   for (int i = 0; i < 4; ++i) {
     ASSERT_NEAR(buf[i], static_cast<float>(i + 1), 1e-5);
   }
@@ -52,7 +54,8 @@ TEST(Transform, GPUUnary) {
   float cpu_buf[4] = {0.1, 0.2, 0.3, 0.4};
   float* gpu_buf = static_cast<float*>(Alloc(gpu0, sizeof(float) * 4));
   Copy(gpu0, gpu_buf, CPUPlace(), cpu_buf, sizeof(cpu_buf));
-  Transform(ctx, gpu_buf, gpu_buf + 4, gpu_buf, Scale<float>(10));
+  Transform<paddle::platform::GPUPlace> trans;
+  trans(ctx, gpu_buf, gpu_buf + 4, gpu_buf, Scale<float>(10));
   ctx.Wait();
   Copy(CPUPlace(), cpu_buf, gpu0, gpu_buf, sizeof(cpu_buf));
   Free(gpu0, gpu_buf);
@@ -65,7 +68,9 @@ TEST(Transform, CPUBinary) {
   using namespace paddle::platform;
   using namespace paddle::memory;
   int buf[4] = {1, 2, 3, 4};
-  Transform(CPUDeviceContext(), buf, buf + 4, buf, buf, Multiply<int>());
+  Transform<paddle::platform::CPUPlace> trans;
+  CPUDeviceContext ctx;
+  trans(ctx, buf, buf + 4, buf, buf, Multiply<int>());
   for (int i = 0; i < 4; ++i) {
     ASSERT_EQ((i + 1) * (i + 1), buf[i]);
   }
@@ -79,11 +84,12 @@ TEST(Transform, GPUBinary) {
   CUDADeviceContext ctx(gpu0);
   int* gpu_buf = static_cast<int*>(Alloc(gpu0, sizeof(buf)));
   Copy(gpu0, gpu_buf, CPUPlace(), buf, sizeof(buf));
-  Transform(ctx, gpu_buf, gpu_buf + 4, gpu_buf, gpu_buf, Multiply<int>());
+  Transform<paddle::platform::GPUPlace> trans;
+  trans(ctx, gpu_buf, gpu_buf + 4, gpu_buf, gpu_buf, Multiply<int>());
   ctx.Wait();
   Copy(CPUPlace(), buf, gpu0, gpu_buf, sizeof(buf));
   Free(gpu0, gpu_buf);
   for (int i = 0; i < 4; ++i) {
     ASSERT_EQ((i + 1) * (i + 1), buf[i]);
   }
-}
+}

python/paddle/v2/framework/tests/test_prelu_op.py

Lines changed: 2 additions & 2 deletions
@@ -17,10 +17,10 @@ def setUp(self):
         assert out_np is not self.inputs['X']
         self.outputs = {'Out': out_np}
 
-    def not_test_check_output(self):
+    def test_check_output(self):
         self.check_output()
 
-    def not_test_check_grad(self):
+    def test_check_grad(self):
         self.check_grad(['X'], 'Out')
 
