Skip to content

Commit ab6f874

Browse files
authored
remove thrust include files (#32395)
* remove thrust includes, test=develop * fix compilation error, test=develop * fix compilation of truncated_gaussian_random_op, test=develop
1 parent 2194ad1 commit ab6f874

File tree

7 files changed

+20
-13
lines changed

7 files changed

+20
-13
lines changed

paddle/fluid/framework/lod_tensor.h

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,11 @@ limitations under the License. */
1414

1515
#pragma once
1616

17+
#include <glog/logging.h>
1718
#include <memory>
1819
#include <string>
1920
#include <utility>
2021
#include <vector>
21-
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
22-
#include <thrust/device_vector.h>
23-
#include <thrust/host_vector.h>
24-
#endif
25-
26-
#include <glog/logging.h>
2722

2823
#include "paddle/fluid/framework/ddim.h"
2924
#include "paddle/fluid/framework/mixed_vector.h"

paddle/fluid/operators/diag_embed_op.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
#include <thrust/device_vector.h>
16+
#include <thrust/host_vector.h>
1517
#include "paddle/fluid/framework/op_registry.h"
1618
#include "paddle/fluid/operators/diag_embed_op.h"
1719

paddle/fluid/operators/gaussian_random_op.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ distributed under the License is distributed on an "AS IS" BASIS,
1111
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
14+
#include <thrust/device_vector.h>
15+
#include <thrust/host_vector.h>
1416
#include <thrust/random.h>
1517
#include <thrust/transform.h>
1618
#include "paddle/fluid/framework/generator.h"

paddle/fluid/operators/modified_huber_loss_op.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ distributed under the License is distributed on an "AS IS" BASIS,
1111
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
14+
#include <thrust/device_vector.h>
1415
#include <thrust/for_each.h>
16+
#include <thrust/host_vector.h>
1517
#include <thrust/tuple.h>
1618
#include "paddle/fluid/framework/op_registry.h"
1719
#include "paddle/fluid/operators/modified_huber_loss_op.h"

paddle/fluid/operators/trace_op.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
#include <thrust/device_vector.h>
16+
#include <thrust/host_vector.h>
1517
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
1618
#include "paddle/fluid/operators/trace_op.h"
1719

paddle/fluid/operators/truncated_gaussian_random_op.cu

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15+
#include <thrust/device_vector.h>
16+
#include <thrust/host_vector.h>
1517
#include <thrust/random.h>
1618
#include <thrust/transform.h>
1719
#include <limits>
1820
#include "paddle/fluid/framework/generator.h"
1921
#include "paddle/fluid/framework/op_registry.h"
2022
#include "paddle/fluid/framework/operator.h"
23+
#include "paddle/fluid/operators/truncated_gaussian_random_op.h"
2124

2225
namespace paddle {
2326
namespace operators {
2427

2528
template <typename T>
26-
struct TruncatedNormal {
29+
struct GPUTruncatedNormal {
2730
T mean, std;
2831
T a_normal_cdf;
2932
T b_normal_cdf;
3033
unsigned int seed;
3134
T numeric_min;
3235

33-
__host__ __device__ TruncatedNormal(T mean, T std, T numeric_min, int seed)
36+
__host__ __device__ GPUTruncatedNormal(T mean, T std, T numeric_min, int seed)
3437
: mean(mean), std(std), seed(seed), numeric_min(numeric_min) {
3538
a_normal_cdf = (1.0 + erff(-2.0 / sqrtf(2.0))) / 2.0;
3639
b_normal_cdf = (1.0 + erff(2.0 / sqrtf(2.0))) / 2.0;
@@ -110,10 +113,10 @@ class GPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
110113
TruncatedNormalOffset<T>(mean, std, std::numeric_limits<T>::min(),
111114
seed_offset.first, gen_offset));
112115
} else {
113-
thrust::transform(
114-
index_sequence_begin, index_sequence_begin + size,
115-
thrust::device_ptr<T>(data),
116-
TruncatedNormal<T>(mean, std, std::numeric_limits<T>::min(), seed));
116+
thrust::transform(index_sequence_begin, index_sequence_begin + size,
117+
thrust::device_ptr<T>(data),
118+
GPUTruncatedNormal<T>(
119+
mean, std, std::numeric_limits<T>::min(), seed));
117120
}
118121
}
119122
};

paddle/fluid/operators/uniform_random_op.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@ distributed under the License is distributed on an "AS IS" BASIS,
1111
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
14+
#include <thrust/device_vector.h>
15+
#include <thrust/host_vector.h>
1416
#include <thrust/random.h>
1517
#include <thrust/transform.h>
16-
1718
#include "paddle/fluid/framework/generator.h"
1819
#include "paddle/fluid/framework/op_registry.h"
1920
#include "paddle/fluid/framework/operator.h"

0 commit comments

Comments
 (0)