|
29 | 29 |
|
30 | 30 | namespace paddle { |
31 | 31 | namespace platform { |
| 32 | + |
32 | 33 | // Transform on host or device. It provides the same API as the standard library. |
33 | | -template <typename InputIter, typename OutputIter, typename UnaryOperation> |
34 | | -void Transform(const DeviceContext& context, InputIter first, InputIter last, |
35 | | - OutputIter result, UnaryOperation op) { |
36 | | - auto place = context.GetPlace(); |
37 | | - if (is_cpu_place(place)) { |
| 34 | +template <typename Place> |
| 35 | +struct Transform { |
| 36 | + template <typename InputIter, typename OutputIter, typename UnaryOperation> |
| 37 | + void operator()(const DeviceContext& context, InputIter first, InputIter last, |
| 38 | + OutputIter result, UnaryOperation op); |
| 39 | + |
| 40 | + template <typename InputIter1, typename InputIter2, typename OutputIter, |
| 41 | + typename BinaryOperation> |
| 42 | + void operator()(const DeviceContext& context, InputIter1 first1, |
| 43 | + InputIter1 last1, InputIter2 first2, OutputIter result, |
| 44 | + BinaryOperation op); |
| 45 | +}; |
| 46 | + |
| 47 | +template <> |
| 48 | +struct Transform<platform::CPUPlace> { |
| 49 | + template <typename InputIter, typename OutputIter, typename UnaryOperation> |
| 50 | + void operator()(const DeviceContext& context, InputIter first, InputIter last, |
| 51 | + OutputIter result, UnaryOperation op) { |
| 52 | + auto place = context.GetPlace(); |
| 53 | + PADDLE_ENFORCE(is_cpu_place(place), "It must use CPU place."); |
38 | 54 | std::transform(first, last, result, op); |
39 | | - } else { |
40 | | -#ifdef __NVCC__ |
41 | | - auto& ctx = reinterpret_cast<const CUDADeviceContext&>(context); |
42 | | - using namespace details; |
43 | | - thrust::transform(thrust::cuda::par.on(ctx.stream()), DevPtrCast(first), |
44 | | - DevPtrCast(last), DevPtrCast(result), op); |
45 | | -#else |
46 | | - PADDLE_THROW("Do not invoke `Transform<GPUPlace>` in .cc file"); |
47 | | -#endif |
48 | 55 | } |
49 | | -} |
50 | 56 |
51 | | -template <typename InputIter1, typename InputIter2, typename OutputIter, |
52 | | - typename BinaryOperation> |
53 | | -void Transform(const DeviceContext& context, InputIter1 first1, |
54 | | - InputIter1 last1, InputIter2 first2, OutputIter result, |
55 | | - BinaryOperation op) { |
56 | | - auto place = context.GetPlace(); |
57 | | - if (is_cpu_place(place)) { |
| 57 | + template <typename InputIter1, typename InputIter2, typename OutputIter, |
| 58 | + typename BinaryOperation> |
| 59 | + void operator()(const DeviceContext& context, InputIter1 first1, |
| 60 | + InputIter1 last1, InputIter2 first2, OutputIter result, |
| 61 | + BinaryOperation op) { |
| 62 | + auto place = context.GetPlace(); |
| 63 | + PADDLE_ENFORCE(is_cpu_place(place), "It must use CPU place."); |
58 | 64 | std::transform(first1, last1, first2, result, op); |
59 | | - } else { |
| 65 | + } |
| 66 | +}; |
| 67 | + |
60 | 68 | #ifdef __NVCC__ |
| 69 | +template <> |
| 70 | +struct Transform<platform::GPUPlace> { |
| 71 | + template <typename InputIter, typename OutputIter, typename UnaryOperation> |
| 72 | + void operator()(const DeviceContext& context, InputIter first, InputIter last, |
| 73 | + OutputIter result, UnaryOperation op) { |
| 74 | + auto place = context.GetPlace(); |
| 75 | + PADDLE_ENFORCE(is_gpu_place(place), "It must use GPU place."); |
61 | 76 | auto& ctx = reinterpret_cast<const CUDADeviceContext&>(context); |
62 | | - using namespace details; |
63 | | - thrust::transform(thrust::cuda::par.on(ctx.stream()), DevPtrCast(first1), |
64 | | - DevPtrCast(last1), DevPtrCast(first2), DevPtrCast(result), |
| 77 | + thrust::transform(thrust::cuda::par.on(ctx.stream()), |
| 78 | + details::DevPtrCast(first), details::DevPtrCast(last), |
| 79 | + details::DevPtrCast(result), op); |
| 80 | + } |
| 81 | + |
| 82 | + template <typename InputIter1, typename InputIter2, typename OutputIter, |
| 83 | + typename BinaryOperation> |
| 84 | + void operator()(const DeviceContext& context, InputIter1 first1, |
| 85 | + InputIter1 last1, InputIter2 first2, OutputIter result, |
| 86 | + BinaryOperation op) { |
| 87 | + auto place = context.GetPlace(); |
| 88 | + PADDLE_ENFORCE(is_gpu_place(place), "It must use GPU place."); |
| 89 | + auto& ctx = reinterpret_cast<const CUDADeviceContext&>(context); |
| 90 | + thrust::transform(thrust::cuda::par.on(ctx.stream()), |
| 91 | + details::DevPtrCast(first1), details::DevPtrCast(last1), |
| 92 | + details::DevPtrCast(first2), details::DevPtrCast(result), |
65 | 93 | op); |
66 | | -#else |
67 | | - PADDLE_THROW("Do not invoke `Transform<GPUPlace>` in .cc file"); |
68 | | -#endif |
69 | 94 | } |
70 | 95 | }; |
| 96 | +#endif |
71 | 97 |
72 | 98 | } // namespace platform |
73 | 99 | } // namespace paddle |
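
For reference, a minimal usage sketch of the new functor interface on the CPU path. The `Scale` functor, the `ScaleInPlace` helper, and the buffer setup are illustrative assumptions, not part of this diff; the `HOSTDEVICE` macro is assumed to come from `paddle/platform/hostdevice.h`.

```cpp
#include "paddle/platform/device_context.h"
#include "paddle/platform/hostdevice.h"
#include "paddle/platform/transform.h"

// Hypothetical element-wise functor; HOSTDEVICE lets the same functor
// compile for both the CPU specialization and the CUDA specialization.
template <typename T>
struct Scale {
  explicit Scale(T s) : scale(s) {}
  HOSTDEVICE T operator()(const T& x) const { return x * scale; }
  T scale;
};

// Hypothetical helper: scales a host buffer in place through the functor.
void ScaleInPlace(const paddle::platform::CPUDeviceContext& ctx, float* buf,
                  int n) {
  // The Place template argument selects the backend at compile time;
  // the context passed at runtime must match it, or PADDLE_ENFORCE fires.
  paddle::platform::Transform<paddle::platform::CPUPlace> trans;
  trans(ctx, buf, buf + n, buf, Scale<float>(10.0f));
}
```

The same calling code, placed in a `.cu` file with a `CUDADeviceContext` and `GPUPlace`, would exercise the thrust-based specialization, which is only compiled under `__NVCC__`.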