2121#include < algorithm>
2222#include < type_traits>
2323#ifdef __NVCC__
24- #include < thrust/device_ptr.h>
2524#include < thrust/transform.h>
25+ #include " paddle/platform/details/device_ptr_cast.h"
2626#endif
2727
2828namespace paddle {
2929namespace platform {
30-
// NOTE(review): every line of this block carries a `-` diff marker, i.e. it is
// shown as removed. The Transform bodies further down call DevPtrCast after
// `using namespace details;` -- presumably the replacement helper lives in
// paddle/platform/details/device_ptr_cast.h; confirm against that header.
//
// The helpers below are compiled only when __NVCC__ is defined.
31- #ifdef __NVCC__
// Primary template: declared but never defined. The bool parameter selects
// one of the two partial specialisations that follow.
32- template <typename T, bool is_ptr>
33- struct DevicePtrCast ;
34-
// Specialisation for is_ptr == true.
//   ELEM  - T with one level of pointer removed.
//   RTYPE - thrust::device_ptr<ELEM>, the type produced by operator().
// operator() passes the raw pointer through thrust::device_pointer_cast.
35- template <typename T>
36- struct DevicePtrCast <T, true > {
37- using ELEM = typename std::remove_pointer<T>::type;
38- using RTYPE = thrust::device_ptr<ELEM>;
39-
40- inline thrust::device_ptr<ELEM> operator ()(ELEM* ele) const {
41- return thrust::device_pointer_cast (ele);
42- }
43- };
44-
// Specialisation for is_ptr == false: RTYPE is T itself and operator()
// returns its argument unchanged.
45- template <typename T>
46- struct DevicePtrCast <T, false > {
47- using RTYPE = T;
48- inline RTYPE operator ()(RTYPE it) const { return it; }
49- };
50-
// DevCast(t): picks the DevicePtrCast specialisation with
// std::is_pointer<T>::value and applies it to t. The return type is that
// specialisation's RTYPE, so pointer arguments come back as
// thrust::device_ptr and every other argument type comes back as T.
51- template <typename T>
52- auto DevCast (T t) ->
53- typename DevicePtrCast<T, std::is_pointer<T>::value>::RTYPE {
54- DevicePtrCast<T, std::is_pointer<T>::value> cast;
55- return cast (t);
56- }
57- #endif
58-
5930// Transform on host or device. It provides the same API as std::transform in the standard library.
6031template <typename Place, typename InputIter, typename OutputIter,
6132 typename UnaryOperation>
@@ -65,7 +36,9 @@ void Transform(Place place, InputIter first, InputIter last, OutputIter result,
6536 std::transform (first, last, result, op);
6637 } else {
6738#ifdef __NVCC__
68- thrust::transform (DevCast (first), DevCast (last), DevCast (result), op);
39+ using namespace details ;
40+ thrust::transform (DevPtrCast (first), DevPtrCast (last), DevPtrCast (result),
41+ op);
6942#else
7043 PADDLE_THROW (" Do not invoke `Transform<GPUPlace>` in .cc file" );
7144#endif
@@ -80,8 +53,9 @@ void Transform(Place place, InputIter1 first1, InputIter1 last1,
8053 std::transform (first1, last1, first2, result, op);
8154 } else {
8255#ifdef __NVCC__
83- thrust::transform (DevCast (first1), DevCast (last1), DevCast (first2),
84- DevCast (result), op);
56+ using namespace details ;
57+ thrust::transform (DevPtrCast (first1), DevPtrCast (last1), DevPtrCast (first2),
58+ DevPtrCast (result), op);
8559#else
8660 PADDLE_THROW (" Do not invoke `Transform<GPUPlace>` in .cc file" );
8761#endif
0 commit comments