Skip to content

Commit f8c6792

Browse files
committed
Extract DevPtrCast to device_ptr_cast.h
1 parent 54d88d4 commit f8c6792

File tree

2 files changed

+63
-33
lines changed

2 files changed

+63
-33
lines changed
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#ifndef __NVCC__
18+
#error device_ptr_cast must be include by .cu file
19+
#endif
20+
21+
#include <thrust/device_ptr.h>
22+
23+
namespace paddle {
24+
namespace platform {
25+
namespace details {
26+
template <typename T, bool is_ptr>
27+
struct DevicePtrCast;
28+
29+
template <typename T>
30+
struct DevicePtrCast<T, true> {
31+
using ELEM = typename std::remove_pointer<T>::type;
32+
using RTYPE = thrust::device_ptr<ELEM>;
33+
34+
inline thrust::device_ptr<ELEM> operator()(ELEM* ele) const {
35+
return thrust::device_pointer_cast(ele);
36+
}
37+
};
38+
39+
template <typename T>
40+
struct DevicePtrCast<T, false> {
41+
using RTYPE = T;
42+
inline RTYPE operator()(RTYPE it) const { return it; }
43+
};
44+
45+
// Cast T to thrust::device_ptr if T is a pointer.
46+
// Otherwise, e.g., T is a iterator, return T itself.
47+
template <typename T>
48+
auto DevPtrCast(T t) ->
49+
typename DevicePtrCast<T, std::is_pointer<T>::value>::RTYPE {
50+
DevicePtrCast<T, std::is_pointer<T>::value> cast;
51+
return cast(t);
52+
}
53+
54+
} // namespace details
55+
} // namespace platform
56+
} // namespace paddle

paddle/platform/transform.h

Lines changed: 7 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -21,41 +21,12 @@
2121
#include <algorithm>
2222
#include <type_traits>
2323
#ifdef __NVCC__
24-
#include <thrust/device_ptr.h>
2524
#include <thrust/transform.h>
25+
#include "paddle/platform/details/device_ptr_cast.h"
2626
#endif
2727

2828
namespace paddle {
2929
namespace platform {
30-
31-
#ifdef __NVCC__
32-
template <typename T, bool is_ptr>
33-
struct DevicePtrCast;
34-
35-
template <typename T>
36-
struct DevicePtrCast<T, true> {
37-
using ELEM = typename std::remove_pointer<T>::type;
38-
using RTYPE = thrust::device_ptr<ELEM>;
39-
40-
inline thrust::device_ptr<ELEM> operator()(ELEM* ele) const {
41-
return thrust::device_pointer_cast(ele);
42-
}
43-
};
44-
45-
template <typename T>
46-
struct DevicePtrCast<T, false> {
47-
using RTYPE = T;
48-
inline RTYPE operator()(RTYPE it) const { return it; }
49-
};
50-
51-
template <typename T>
52-
auto DevCast(T t) ->
53-
typename DevicePtrCast<T, std::is_pointer<T>::value>::RTYPE {
54-
DevicePtrCast<T, std::is_pointer<T>::value> cast;
55-
return cast(t);
56-
}
57-
#endif
58-
5930
// Transform on host or device. It provides the same API in std library.
6031
template <typename Place, typename InputIter, typename OutputIter,
6132
typename UnaryOperation>
@@ -65,7 +36,9 @@ void Transform(Place place, InputIter first, InputIter last, OutputIter result,
6536
std::transform(first, last, result, op);
6637
} else {
6738
#ifdef __NVCC__
68-
thrust::transform(DevCast(first), DevCast(last), DevCast(result), op);
39+
using namespace details;
40+
thrust::transform(DevPtrCast(first), DevPtrCast(last), DevPtrCast(result),
41+
op);
6942
#else
7043
PADDLE_THROW("Do not invoke `Transform<GPUPlace>` in .cc file");
7144
#endif
@@ -80,8 +53,9 @@ void Transform(Place place, InputIter1 first1, InputIter1 last1,
8053
std::transform(first1, last1, first2, result, op);
8154
} else {
8255
#ifdef __NVCC__
83-
thrust::transform(DevCast(first1), DevCast(last1), DevCast(first2),
84-
DevCast(result), op);
56+
using namespace details;
57+
thrust::transform(DevPtrCast(first1), DevPtrCast(last1), DevPtrCast(first2),
58+
DevPtrCast(result), op);
8559
#else
8660
PADDLE_THROW("Do not invoke `Transform<GPUPlace>` in .cc file");
8761
#endif

0 commit comments

Comments
 (0)