From 6a831b24ac8854a8642f9e9405000bb227e0beef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=AD=A3=E5=9C=A8=E5=AD=A6=E4=B9=A0?= <1181749441@qq.com> Date: Mon, 11 Aug 2025 11:12:32 +0800 Subject: [PATCH 1/4] fix norm max grad --- paddle/phi/kernels/gpu/p_norm_grad_kernel.cu | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu index fdfed25b3dda8f..71b5aa918d5108 100644 --- a/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu @@ -35,8 +35,12 @@ struct AbsMaxAndMinGradFunctor { DY* dy, const Dim& dim, int size) { - dx->device(place) = dy->broadcast(dim) * (*x).sign() * - ((*x).abs() == y->broadcast(dim)).template cast(); + auto abs_x = x->abs(); + auto y_bc = y->broadcast(dim); + auto mask = (abs_x == y_bc).template cast(); + auto count = mask.sum(dim); + + dx->device(place) = dy->broadcast(dim) * x->sign() * mask / count; } }; From 48e15b6401e08ac325223d79d48c4607e410eef1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=AD=A3=E5=9C=A8=E5=AD=A6=E4=B9=A0?= <1181749441@qq.com> Date: Mon, 11 Aug 2025 11:16:21 +0800 Subject: [PATCH 2/4] fix --- paddle/phi/kernels/gpu/p_norm_grad_kernel.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu index 71b5aa918d5108..21dd1c2da3a1a7 100644 --- a/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu @@ -40,7 +40,7 @@ struct AbsMaxAndMinGradFunctor { auto mask = (abs_x == y_bc).template cast(); auto count = mask.sum(dim); - dx->device(place) = dy->broadcast(dim) * x->sign() * mask / count; + dx->device(place) = y_bc * x->sign() * mask / count; } }; From af295308159c195a172baf7be7e0b956b43b06f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=AD=A3=E5=9C=A8=E5=AD=A6=E4=B9=A0?= <1181749441@qq.com> Date: Mon, 11 Aug 2025 15:26:08 +0800 Subject: [PATCH 3/4] fix --- paddle/phi/kernels/gpu/p_norm_grad_kernel.cu | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu index 21dd1c2da3a1a7..9dc279b4332a73 100644 --- a/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu @@ -35,12 +35,20 @@ struct AbsMaxAndMinGradFunctor { DY* dy, const Dim& dim, int size) { - auto abs_x = x->abs(); + auto abs_x = (*x).abs(); auto y_bc = y->broadcast(dim); + auto dy_bc = dy->broadcast(dim); auto mask = (abs_x == y_bc).template cast(); - auto count = mask.sum(dim); - dx->device(place) = y_bc * x->sign() * mask / count; + Eigen::array reduce_dim = {static_cast(dim.size() - 1)}; + auto shape1 = (*x).dimensions(); + shape1[shape1.size() - 1] = 1; + auto shape2 = (*x).dimensions(); + for (size_t i = 0; i < shape2.size() - 1; i++) shape2[i] = 1; + + auto count = mask.sum(reduce_dim).reshape(shape1).broadcast(shape2); + + dx->device(place) = dy_bc * (*x).sign() * mask / count; } }; From 74d52c9a80e92802d7e50ac5788b71ce1ba3ad07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=AD=A3=E5=9C=A8=E5=AD=A6=E4=B9=A0?= <1181749441@qq.com> Date: Mon, 11 Aug 2025 23:50:08 +0800 Subject: [PATCH 4/4] fix --- paddle/phi/kernels/gpu/p_norm_grad_kernel.cu | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu index 9dc279b4332a73..9606dc157483da 100644 --- a/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu @@ -35,20 +35,20 @@ struct AbsMaxAndMinGradFunctor { DY* dy, const Dim& dim, int size) { - auto abs_x = (*x).abs(); - auto y_bc = y->broadcast(dim); - auto dy_bc = dy->broadcast(dim); - auto mask = (abs_x == y_bc).template cast(); + using MT = typename phi::dtype::MPTypeTrait::Type; + auto abs_x = (*x).abs().template cast(); + auto y_bc = y->broadcast(dim).template cast(); + auto dy_bc = dy->broadcast(dim).template cast(); + auto mask = (abs_x == y_bc).template cast(); Eigen::array reduce_dim = {static_cast(dim.size() - 1)}; auto shape1 = (*x).dimensions(); shape1[shape1.size() - 1] = 1; auto shape2 = (*x).dimensions(); for (size_t i = 0; i < shape2.size() - 1; i++) shape2[i] = 1; - auto count = mask.sum(reduce_dim).reshape(shape1).broadcast(shape2); - dx->device(place) = dy_bc * (*x).sign() * mask / count; + dx->device(place) = (dy_bc * (*x).sign() * mask / count).template cast(); } };