@@ -12,25 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212See the License for the specific language governing permissions and
1313limitations under the License. */
1414
15+ #include < thrust/device_vector.h>
16+ #include < thrust/host_vector.h>
1517#include < thrust/random.h>
1618#include < thrust/transform.h>
1719#include < limits>
1820#include " paddle/fluid/framework/generator.h"
1921#include " paddle/fluid/framework/op_registry.h"
2022#include " paddle/fluid/framework/operator.h"
23+ #include " paddle/fluid/operators/truncated_gaussian_random_op.h"
2124
2225namespace paddle {
2326namespace operators {
2427
2528template <typename T>
26- struct TruncatedNormal {
29+ struct GPUTruncatedNormal {
2730 T mean, std;
2831 T a_normal_cdf;
2932 T b_normal_cdf;
3033 unsigned int seed;
3134 T numeric_min;
3235
33- __host__ __device__ TruncatedNormal (T mean, T std, T numeric_min, int seed)
36+ __host__ __device__ GPUTruncatedNormal (T mean, T std, T numeric_min, int seed)
3437 : mean(mean), std(std), seed(seed), numeric_min(numeric_min) {
3538 a_normal_cdf = (1.0 + erff (-2.0 / sqrtf (2.0 ))) / 2.0 ;
3639 b_normal_cdf = (1.0 + erff (2.0 / sqrtf (2.0 ))) / 2.0 ;
@@ -110,10 +113,10 @@ class GPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
110113 TruncatedNormalOffset<T>(mean, std, std::numeric_limits<T>::min (),
111114 seed_offset.first , gen_offset));
112115 } else {
113- thrust::transform (
114- index_sequence_begin, index_sequence_begin + size ,
115- thrust::device_ptr <T>(data),
116- TruncatedNormal<T>( mean, std, std::numeric_limits<T>::min (), seed));
116+ thrust::transform (index_sequence_begin, index_sequence_begin + size,
117+ thrust::device_ptr<T>(data) ,
118+ GPUTruncatedNormal <T>(
119+ mean, std, std::numeric_limits<T>::min (), seed));
117120 }
118121 }
119122};
0 commit comments