@@ -26,14 +26,14 @@ namespace paddle {
2626namespace operators {
2727
2828template <typename T>
29- struct TruncatedNormal {
29+ struct GPUTruncatedNormal {
3030 T mean, std;
3131 T a_normal_cdf;
3232 T b_normal_cdf;
3333 unsigned int seed;
3434 T numeric_min;
3535
36- __host__ __device__ TruncatedNormal (T mean, T std, T numeric_min, int seed)
36+ __host__ __device__ GPUTruncatedNormal (T mean, T std, T numeric_min, int seed)
3737 : mean(mean), std(std), seed(seed), numeric_min(numeric_min) {
3838 a_normal_cdf = (1.0 + erff (-2.0 / sqrtf (2.0 ))) / 2.0 ;
3939 b_normal_cdf = (1.0 + erff (2.0 / sqrtf (2.0 ))) / 2.0 ;
@@ -113,10 +113,10 @@ class GPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
113113 TruncatedNormalOffset<T>(mean, std, std::numeric_limits<T>::min (),
114114 seed_offset.first , gen_offset));
115115 } else {
116- thrust::transform (
117- index_sequence_begin, index_sequence_begin + size ,
118- thrust::device_ptr <T>(data),
119- TruncatedNormal<T>( mean, std, std::numeric_limits<T>::min (), seed));
116+ thrust::transform (index_sequence_begin, index_sequence_begin + size,
117+ thrust::device_ptr<T>(data) ,
118+ GPUTruncatedNormal <T>(
119+ mean, std, std::numeric_limits<T>::min (), seed));
120120 }
121121 }
122122};
0 commit comments