Skip to content

Commit da39ea4

Browse files
committed
fix compilation of truncated_gaussian_random_op, test=develop
1 parent ffaf62e commit da39ea4

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

paddle/fluid/operators/truncated_gaussian_random_op.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,14 @@ namespace paddle {
2626
namespace operators {
2727

2828
template <typename T>
29-
struct TruncatedNormal {
29+
struct GPUTruncatedNormal {
3030
T mean, std;
3131
T a_normal_cdf;
3232
T b_normal_cdf;
3333
unsigned int seed;
3434
T numeric_min;
3535

36-
__host__ __device__ TruncatedNormal(T mean, T std, T numeric_min, int seed)
36+
__host__ __device__ GPUTruncatedNormal(T mean, T std, T numeric_min, int seed)
3737
: mean(mean), std(std), seed(seed), numeric_min(numeric_min) {
3838
a_normal_cdf = (1.0 + erff(-2.0 / sqrtf(2.0))) / 2.0;
3939
b_normal_cdf = (1.0 + erff(2.0 / sqrtf(2.0))) / 2.0;
@@ -113,10 +113,10 @@ class GPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
113113
TruncatedNormalOffset<T>(mean, std, std::numeric_limits<T>::min(),
114114
seed_offset.first, gen_offset));
115115
} else {
116-
thrust::transform(
117-
index_sequence_begin, index_sequence_begin + size,
118-
thrust::device_ptr<T>(data),
119-
TruncatedNormal<T>(mean, std, std::numeric_limits<T>::min(), seed));
116+
thrust::transform(index_sequence_begin, index_sequence_begin + size,
117+
thrust::device_ptr<T>(data),
118+
GPUTruncatedNormal<T>(
119+
mean, std, std::numeric_limits<T>::min(), seed));
120120
}
121121
}
122122
};

0 commit comments

Comments
 (0)