@@ -32,6 +32,11 @@ __global__ void GPUBCELossForward(const T* x_data, const T* label_data,
3232 T one = static_cast <T>(1 .);
3333 T neg_100 = static_cast <T>(-100 .);
3434
35+ PADDLE_ENFORCE (
36+ (x >= static_cast <T>(0 )) && (x <= one),
37+ " Input is expected to be within the interval [0, 1], but recieved %f." ,
38+ x);
39+
3540 T term1 = max (real_log (x), neg_100);
3641 T term2 = max (real_log (one - x), neg_100);
3742
@@ -64,29 +69,13 @@ class BCELossCUDAKernel : public framework::OpKernel<T> {
6469 auto * labels = ctx.Input <Tensor>(" Label" );
6570 auto * out = ctx.Output <Tensor>(" Out" );
6671
67- auto x_data = x->data <T>();
68- auto out_data = out->mutable_data <T>(ctx.GetPlace ());
72+ const auto * x_data = x->data <T>();
73+ auto * out_data = out->mutable_data <T>(ctx.GetPlace ());
6974 auto x_numel = x->numel ();
7075
71- platform::GpuLaunchConfig config =
72- platform::GetGpuLaunchConfig1D (ctx.cuda_device_context (), x_numel);
73-
74- Tensor x_cpu;
75- framework::TensorCopy (*x, platform::CPUPlace (), &x_cpu);
76- T* x_cpu_data = x_cpu.data <T>();
77-
78- for (int64_t i = 0 ; i < x_numel; ++i) {
79- PADDLE_ENFORCE_GE (
80- x_cpu_data[i], static_cast <T>(0 ),
81- platform::errors::InvalidArgument (
82- " Illegal input, input must be greater than or equal to 0" ));
83- PADDLE_ENFORCE_LE (
84- x_cpu_data[i], static_cast <T>(1 ),
85- platform::errors::InvalidArgument (
86- " Illegal input, input must be less than or equal to 1" ));
87- }
88-
8976 auto & dev_ctx = ctx.cuda_device_context ();
77+ platform::GpuLaunchConfig config =
78+ platform::GetGpuLaunchConfig1D (dev_ctx, x_numel);
9079
9180 GPUBCELossForward<T><<<config.block_per_grid, config.thread_per_block, 0 ,
9281 dev_ctx.stream()>>> (x_data, labels->data <T>(),
@@ -102,9 +91,10 @@ class BCELossGradCUDAKernel : public framework::OpKernel<T> {
10291 auto * labels = ctx.Input <Tensor>(" Label" );
10392 auto * dout = ctx.Input <Tensor>(framework::GradVarName (" Out" ));
10493 auto * dx = ctx.Output <Tensor>(framework::GradVarName (" X" ));
105- auto dx_data = dx->mutable_data <T>(ctx.GetPlace ());
10694
10795 int x_numel = x->numel ();
96+ auto * dx_data = dx->mutable_data <T>(ctx.GetPlace ());
97+
10898 auto & dev_ctx = ctx.cuda_device_context ();
10999 platform::GpuLaunchConfig config =
110100 platform::GetGpuLaunchConfig1D (dev_ctx, x_numel);
0 commit comments