
Commit 66d7f98

[Accuracy diff No.168] Fix accuracy (output type) diff for paddle.floor and paddle.ceil API (#74598)
* fix(activation_kernel.cc/cu): fix output type diff for floor/ceil kernel
* fix(test_activation_op.py): add unit test
* fix(full_kernel.cc/cu): add int8 support for full_like
* fix(activation_functor.h): fix floor/ceil functor for int dtype input
1 parent 0c0e353 commit 66d7f98

12 files changed: +212 −26 lines
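In user-facing terms, paddle.floor and paddle.ceil previously returned a float32 tensor for integer inputs; with the kernels registered below they keep the input dtype. A minimal sketch of the fixed behavior (the printed dtypes are expectations based on this diff, not captured output):

import paddle

x = paddle.to_tensor([1, 2, 3], dtype='int32')
print(paddle.floor(x).dtype)  # expected after this fix: paddle.int32 (was float32)
print(paddle.ceil(x).dtype)   # expected: paddle.int32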

paddle/fluid/eager/auto_code_generator/generator/eager_gen.py

Lines changed: 1 addition & 6 deletions
@@ -152,13 +152,11 @@
     "asinh": ["x"],
     "atan": ["x"],
     "atanh": ["x"],
-    "ceil": ["x"],
     "cos": ["x"],
     "cosh": ["x"],
     "digamma": ["x"],
     "erf": ["x"],
     "erfinv": ["x"],
-    "floor": ["x"],
     "i0": ["x"],
     "i0e": ["x"],
     "i1": ["x"],
@@ -181,10 +179,7 @@

 # ops support casting int tensor into float32 to do forward calculation,
 # and it is valid to cast float32 gradient back to int tensor.
-type_autocast_valid_grad_op_list = {
-    "ceil",
-    "floor",
-}
+type_autocast_valid_grad_op_list = {}

 # dict of special api that forward api's output will affect backward api's output
 # backward api's output usually affected by backward api's input
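Emptying type_autocast_valid_grad_op_list retires the old workaround in which an integer input to floor/ceil was cast to float32 for the forward pass (producing a float32 output) and the float32 gradient was cast back to the integer dtype. A hypothetical numpy model of that retired path, for illustration only (not Paddle code):

import numpy as np

# Hypothetical model of the retired autocast path:
def autocast_op(op, x: np.ndarray) -> np.ndarray:
    if np.issubdtype(x.dtype, np.integer):   # integer input?
        return op(x.astype(np.float32))      # forward in float32 -> float32 output
    return op(x)

print(autocast_op(np.floor, np.array([1, 2], dtype=np.int32)).dtype)  # float32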

paddle/fluid/pir/dialect/op_generator/api_gen.py

Lines changed: 1 addition & 6 deletions
@@ -83,13 +83,11 @@
     "asinh": ["x"],
     "atan": ["x"],
     "atanh": ["x"],
-    "ceil": ["x"],
     "cos": ["x"],
     "cosh": ["x"],
     "digamma": ["x"],
     "erf": ["x"],
     "erfinv": ["x"],
-    "floor": ["x"],
     "i0": ["x"],
     "i0e": ["x"],
     "i1": ["x"],
@@ -112,10 +110,7 @@

 # ops support casting int tensor into float32 to do forward calculation,
 # and it is valid to cast float32 gradient back to int tensor.
-type_autocast_valid_grad_op_list = {
-    "ceil",
-    "floor",
-}
+type_autocast_valid_grad_op_list = {}

 PD_MANUAL_API_LIST = {
     'embedding_grad',

paddle/phi/kernels/cpu/activation_grad_kernel.cc

Lines changed: 28 additions & 2 deletions
@@ -483,8 +483,6 @@ PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL_WITH_COMPLEX(log_double_grad,
 PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(hardswish_grad,
                                                 HardSwishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(floor_grad, FloorGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(ceil_grad, CeilGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(celu_grad, CeluGradKernel)
 PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(celu_double_grad,
                                           CeluDoubleGradKernel)
@@ -541,3 +539,31 @@ PD_REGISTER_KERNEL(pow_triple_grad,
                    int64_t,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
+
+PD_REGISTER_KERNEL(ceil_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::CeilGradKernel,
+                   float,
+                   double,
+                   uint8_t,
+                   int8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
+
+PD_REGISTER_KERNEL(floor_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::FloorGradKernel,
+                   float,
+                   double,
+                   uint8_t,
+                   int8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
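floor and ceil are piecewise constant, so their derivative is zero almost everywhere; registering the grad kernels for the integer and low-precision dtypes above keeps the backward pass in the input dtype as well. A hedged sketch of the expected gradient:

import paddle

x = paddle.to_tensor([1.2, 3.7], stop_gradient=False)
paddle.floor(x).sum().backward()
print(x.grad)  # expected: [0., 0.] -- floor is flat almost everywhere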

paddle/phi/kernels/cpu/activation_kernel.cc

Lines changed: 28 additions & 2 deletions
@@ -254,8 +254,6 @@ PD_REGISTER_ACTIVATION_KERNEL(hardsigmoid, HardSigmoidKernel)
 PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
 PD_REGISTER_ACTIVATION_KERNEL(relu6, Relu6Kernel)
 PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(hardswish, HardSwishKernel)
-PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
-PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel)
 PD_REGISTER_ACTIVATION_KERNEL(celu, CeluKernel)

 PD_REGISTER_KERNEL(
@@ -381,3 +379,31 @@ PD_REGISTER_KERNEL(pow,
                    int64_t,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
+
+PD_REGISTER_KERNEL(ceil,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::CeilKernel,
+                   float,
+                   double,
+                   uint8_t,
+                   int8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
+
+PD_REGISTER_KERNEL(floor,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::FloorKernel,
+                   float,
+                   double,
+                   uint8_t,
+                   int8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
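Switching from the PD_REGISTER_ACTIVATION_KERNEL macro to explicit PD_REGISTER_KERNEL calls lets ceil and floor declare the full dtype list above rather than the macro's fixed list. A quick dtype-preservation check on CPU (a sketch; it assumes these Paddle dtype strings and that paddle.ones supports each of them):

import paddle

paddle.set_device('cpu')
for name in ['float32', 'float64', 'uint8', 'int8', 'int16', 'int32', 'int64']:
    x = paddle.ones([4], dtype=name)
    assert paddle.ceil(x).dtype == x.dtype   # expected to hold after this commit
    assert paddle.floor(x).dtype == x.dtype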

paddle/phi/kernels/cpu/full_kernel.cc

Lines changed: 1 addition & 0 deletions
@@ -142,6 +142,7 @@ PD_REGISTER_KERNEL(full_like,
                    float,
                    double,
                    uint8_t,
+                   int8_t,
                    int16_t,
                    int,
                    int64_t,
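full_like gains int8 on CPU, which the new int8 floor/ceil paths presumably rely on for building filled tensors (e.g., expected outputs in the unit tests). A hedged sketch:

import paddle

x = paddle.zeros([2, 3], dtype='int8')
y = paddle.full_like(x, 1)   # previously unregistered for int8 on CPU
print(y.dtype)               # expected: paddle.int8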

paddle/phi/kernels/funcs/activation_functor.h

Lines changed: 40 additions & 4 deletions
@@ -3036,7 +3036,16 @@ template <typename T>
 struct FloorFunctor : public BaseActivationFunctor<T> {
   template <typename Device, typename X, typename Out>
   void operator()(Device d, X x, Out out) const {
-    out.device(d) = x.floor();
+    if constexpr ((std::is_same<T, uint8_t>::value) ||
+                  (std::is_same<T, int8_t>::value) ||
+                  (std::is_same<T, uint16_t>::value) ||
+                  (std::is_same<T, int16_t>::value) ||
+                  (std::is_same<T, int>::value) ||
+                  (std::is_same<T, int64_t>::value)) {
+      out.device(d) = x;
+    } else {
+      out.device(d) = x.floor();
+    }
   }
 };

@@ -3160,7 +3169,16 @@ template <typename T>
 struct CeilFunctor : public BaseActivationFunctor<T> {
   template <typename Device, typename X, typename Out>
   void operator()(Device d, X x, Out out) const {
-    out.device(d) = x.ceil();
+    if constexpr ((std::is_same<T, uint8_t>::value) ||
+                  (std::is_same<T, int8_t>::value) ||
+                  (std::is_same<T, uint16_t>::value) ||
+                  (std::is_same<T, int16_t>::value) ||
+                  (std::is_same<T, int>::value) ||
+                  (std::is_same<T, int64_t>::value)) {
+      out.device(d) = x;
+    } else {
+      out.device(d) = x.ceil();
+    }
   }
 };

@@ -5403,7 +5421,16 @@ struct CudaCeilFunctor : public BaseActivationFunctor<T> {
   // ceil(x) = ceil(x)
   __device__ __forceinline__ T operator()(const T arg_x) const {
     MPType x = static_cast<MPType>(arg_x);
-    return static_cast<T>(ceil(x));
+    if constexpr ((std::is_same<T, uint8_t>::value) ||
+                  (std::is_same<T, int8_t>::value) ||
+                  (std::is_same<T, uint16_t>::value) ||
+                  (std::is_same<T, int16_t>::value) ||
+                  (std::is_same<T, int>::value) ||
+                  (std::is_same<T, int64_t>::value)) {
+      return static_cast<T>(x);
+    } else {
+      return static_cast<T>(ceil(x));
+    }
   }
 };

@@ -5492,7 +5519,16 @@ struct CudaFloorFunctor : public BaseActivationFunctor<T> {
   // floor(x) = floor(x)
   __device__ __forceinline__ T operator()(const T arg_x) const {
     MPType x = static_cast<MPType>(arg_x);
-    return static_cast<T>(floor(x));
+    if constexpr ((std::is_same<T, uint8_t>::value) ||
+                  (std::is_same<T, int8_t>::value) ||
+                  (std::is_same<T, uint16_t>::value) ||
+                  (std::is_same<T, int16_t>::value) ||
+                  (std::is_same<T, int>::value) ||
+                  (std::is_same<T, int64_t>::value)) {
+      return static_cast<T>(x);
+    } else {
+      return static_cast<T>(floor(x));
+    }
   }
 };
paddle/phi/kernels/gpu/activation_grad_kernel.cu

Lines changed: 26 additions & 2 deletions
@@ -554,8 +554,6 @@ PD_REGISTER_KERNEL(log_double_grad,
 PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(hardswish_grad,
                                                 HardSwishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(floor_grad, FloorGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(ceil_grad, CeilGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(celu_grad, CeluGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(celu_double_grad, CeluDoubleGradKernel)

@@ -617,3 +615,29 @@ PD_REGISTER_KERNEL(pow_triple_grad,
                    phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
+PD_REGISTER_KERNEL(ceil_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::CeilGradKernel,
+                   float,
+                   double,
+                   uint8_t,
+                   int8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
+PD_REGISTER_KERNEL(floor_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::FloorGradKernel,
+                   float,
+                   double,
+                   uint8_t,
+                   int8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}

paddle/phi/kernels/gpu/activation_kernel.cu

Lines changed: 26 additions & 2 deletions
@@ -347,8 +347,6 @@ PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(logsigmoid, LogSigmoidKernel)
 PD_REGISTER_ACTIVATION_KERNEL(hardsigmoid, HardSigmoidKernel)
 PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(hardswish, HardSwishKernel)
 PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
-PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
-PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel)
 PD_REGISTER_ACTIVATION_KERNEL(celu, CeluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(selu, SeluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(logit, LogitCUDAKernel)
@@ -435,3 +433,29 @@ PD_REGISTER_KERNEL(pow,
                    phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
+PD_REGISTER_KERNEL(ceil,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::CeilKernel,
+                   float,
+                   double,
+                   uint8_t,
+                   int8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
+PD_REGISTER_KERNEL(floor,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::FloorKernel,
+                   float,
+                   double,
+                   uint8_t,
+                   int8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}

paddle/phi/kernels/xpu/activation_kernel.cc

Lines changed: 10 additions & 1 deletion
@@ -777,7 +777,16 @@ PD_REGISTER_KERNEL(acos,
 #define PD_REGISTER_ACTIVATION_KERNEL(name, func) \
   PD_REGISTER_KERNEL(name, XPU, ALL_LAYOUT, phi::func, float) {}

-PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
 PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
 PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
 PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel)
+
+PD_REGISTER_KERNEL(floor,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::FloorKernel,
+                   float,
+                   int,
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}

paddle/phi/kernels/xpu/full_kernel.cc

Lines changed: 1 addition & 0 deletions
@@ -174,6 +174,7 @@ PD_REGISTER_KERNEL(full_like,
                    float,
                    double,
                    uint8_t,
+                   int8_t,
                    int16_t,
                    int,
                    int64_t,
