Skip to content

Commit a593bf2

Browse files
[Accuracy diff No.43-44] Accuracy grid sample (#74555)
* fix accuracy for grid_sample * fix grid_sample accuracy * fix grid_sample test
1 parent 7f0baf6 commit a593bf2

File tree

5 files changed

+179
-68
lines changed

5 files changed

+179
-68
lines changed

paddle/phi/kernels/cpu/grid_sample_grad_kernel.cc

Lines changed: 47 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -248,13 +248,15 @@ static void GatherOutputGradToInputGrad(const DenseTensor& output_grad,
248248
for (int i = 0; i < n; i++) {
249249
for (int k = 0; k < out_h; k++) {
250250
for (int l = 0; l < out_w; l++) {
251-
if (IsInBound(
252-
x_t(i, k, l), y_t(i, k, l), (T)(in_w - 1), (T)(in_h - 1))) {
251+
if (IsInBound<int>(static_cast<int>(x_t(i, k, l)),
252+
static_cast<int>(y_t(i, k, l)),
253+
(in_w - 1),
254+
(in_h - 1))) {
253255
for (int j = 0; j < c; j++) {
254256
input_grad_t(i,
255257
j,
256-
static_cast<int>(round(y_t(i, k, l))),
257-
static_cast<int>(round(x_t(i, k, l)))) +=
258+
static_cast<int>(y_t(i, k, l)),
259+
static_cast<int>(x_t(i, k, l))) +=
258260
output_grad_t(i, j, k, l) * d1_t(i, k, l) * d2_t(i, k, l);
259261
}
260262
}
@@ -293,18 +295,18 @@ static void Gather3DOutputGradToInputGrad(const DenseTensor& output_grad,
293295
for (int m = 0; m < out_d; m++) {
294296
for (int k = 0; k < out_h; k++) {
295297
for (int l = 0; l < out_w; l++) {
296-
if (IsInBound3D(x_t(i, m, k, l),
297-
y_t(i, m, k, l),
298-
z_t(i, m, k, l),
299-
(T)(in_w - 1),
300-
(T)(in_h - 1),
301-
(T)(in_d - 1))) {
298+
if (IsInBound3D<int>(static_cast<int>(x_t(i, m, k, l)),
299+
static_cast<int>(y_t(i, m, k, l)),
300+
static_cast<int>(z_t(i, m, k, l)),
301+
(in_w - 1),
302+
(in_h - 1),
303+
(in_d - 1))) {
302304
for (int j = 0; j < c; j++) {
303305
input_grad_t(i,
304306
j,
305-
static_cast<int>(round(z_t(i, m, k, l))),
306-
static_cast<int>(round(y_t(i, m, k, l))),
307-
static_cast<int>(round(x_t(i, m, k, l)))) +=
307+
static_cast<int>(z_t(i, m, k, l)),
308+
static_cast<int>(y_t(i, m, k, l)),
309+
static_cast<int>(x_t(i, m, k, l))) +=
308310
output_grad_t(i, j, m, k, l) * d1_t(i, m, k, l) *
309311
d2_t(i, m, k, l) * d3_t(i, m, k, l);
310312
}
@@ -590,13 +592,15 @@ static void GatherOutputGradToInputGrad(const DenseTensor& output_grad,
590592
for (int i = 0; i < n; i++) {
591593
for (int k = 0; k < out_h; k++) {
592594
for (int l = 0; l < out_w; l++) {
593-
if (IsInBound(
594-
x_t(i, k, l), y_t(i, k, l), (T)(in_w - 1), (T)(in_h - 1))) {
595+
if (IsInBound<int>(static_cast<int>(std::nearbyint(x_t(i, k, l))),
596+
static_cast<int>(std::nearbyint(y_t(i, k, l))),
597+
(in_w - 1),
598+
(in_h - 1))) {
595599
for (int j = 0; j < c; j++) {
596600
input_grad_t(i,
597601
j,
598-
static_cast<int>(round(y_t(i, k, l))),
599-
static_cast<int>(round(x_t(i, k, l)))) +=
602+
static_cast<int>(std::nearbyint(y_t(i, k, l))),
603+
static_cast<int>(std::nearbyint(x_t(i, k, l)))) +=
600604
output_grad_t(i, j, k, l);
601605
}
602606
}
@@ -628,18 +632,19 @@ static void Gather3DOutputGradToInputGrad(const DenseTensor& output_grad,
628632
for (int m = 0; m < out_d; m++) {
629633
for (int k = 0; k < out_h; k++) {
630634
for (int l = 0; l < out_w; l++) {
631-
if (IsInBound3D(x_t(i, m, k, l),
632-
y_t(i, m, k, l),
633-
z_t(i, m, k, l),
634-
(T)(in_w - 1),
635-
(T)(in_h - 1),
636-
(T)(in_d - 1))) {
635+
if (IsInBound3D<int>(
636+
static_cast<int>(std::nearbyint(x_t(i, m, k, l))),
637+
static_cast<int>(std::nearbyint(y_t(i, m, k, l))),
638+
static_cast<int>(std::nearbyint(z_t(i, m, k, l))),
639+
(in_w - 1),
640+
(in_h - 1),
641+
(in_d - 1))) {
637642
for (int j = 0; j < c; j++) {
638643
input_grad_t(i,
639644
j,
640-
static_cast<int>(round(z_t(i, m, k, l))),
641-
static_cast<int>(round(y_t(i, m, k, l))),
642-
static_cast<int>(round(x_t(i, m, k, l)))) +=
645+
static_cast<int>(std::nearbyint(z_t(i, m, k, l))),
646+
static_cast<int>(std::nearbyint(y_t(i, m, k, l))),
647+
static_cast<int>(std::nearbyint(x_t(i, m, k, l)))) +=
643648
output_grad_t(i, j, m, k, l);
644649
}
645650
}
@@ -673,6 +678,13 @@ void GridSampleGradKernel(const Context& dev_ctx,
673678
return;
674679
}
675680

681+
std::string enum_mode;
682+
if (mode == "nearest") {
683+
enum_mode = "nearest";
684+
} else {
685+
enum_mode = "bilinear";
686+
}
687+
676688
if (x.dims().size() == 4) {
677689
const int n = static_cast<int>(grid.dims()[0]);
678690
const int out_h = static_cast<int>(grid.dims()[1]);
@@ -704,7 +716,10 @@ void GridSampleGradKernel(const Context& dev_ctx,
704716
&grid_y,
705717
&grid_x_scale,
706718
&grid_y_scale);
707-
if (mode == "bilinear") {
719+
if (enum_mode == "nearest") {
720+
GatherOutputGradToInputGrad<T>(out_grad, x_grad, grid_x, grid_y);
721+
722+
} else if (enum_mode == "bilinear") {
708723
GatherBilinearGrad<T>(dev_ctx,
709724
x,
710725
out_grad,
@@ -714,12 +729,6 @@ void GridSampleGradKernel(const Context& dev_ctx,
714729
&grid_y_scale,
715730
x_grad,
716731
grid_grad);
717-
} else {
718-
auto grid_x_t = EigenTensor<T, 3>::From(grid_x);
719-
auto grid_y_t = EigenTensor<T, 3>::From(grid_y);
720-
grid_x_t = grid_x_t.round();
721-
grid_y_t = grid_y_t.round();
722-
GatherOutputGradToInputGrad<T>(out_grad, x_grad, grid_x, grid_y);
723732
}
724733
} else {
725734
const int n = static_cast<int>(grid.dims()[0]);
@@ -757,7 +766,11 @@ void GridSampleGradKernel(const Context& dev_ctx,
757766
&grid_x_scale,
758767
&grid_y_scale,
759768
&grid_z_scale);
760-
if (mode == "bilinear") {
769+
if (enum_mode == "nearest") {
770+
Gather3DOutputGradToInputGrad<T>(
771+
out_grad, x_grad, grid_x, grid_y, grid_z);
772+
773+
} else if (enum_mode == "bilinear") {
761774
Gather3DBilinearGrad<T>(dev_ctx,
762775
x,
763776
out_grad,
@@ -769,9 +782,6 @@ void GridSampleGradKernel(const Context& dev_ctx,
769782
&grid_z_scale,
770783
x_grad,
771784
grid_grad);
772-
} else {
773-
Gather3DOutputGradToInputGrad<T>(
774-
out_grad, x_grad, grid_x, grid_y, grid_z);
775785
}
776786
}
777787
}

paddle/phi/kernels/cpu/grid_sample_kernel.cc

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,14 @@ void GridSampleKernel(const Context& dev_ctx,
316316
dev_ctx.template Alloc<T>(out);
317317
return;
318318
}
319+
320+
std::string enum_mode;
321+
if (mode == "nearest") {
322+
enum_mode = "nearest";
323+
} else {
324+
enum_mode = "bilinear";
325+
}
326+
319327
if (x.dims().size() == 4) {
320328
const int n = static_cast<int>(grid.dims()[0]);
321329
const int out_h = static_cast<int>(grid.dims()[1]);
@@ -338,14 +346,10 @@ void GridSampleKernel(const Context& dev_ctx,
338346
&grid_x,
339347
&grid_y);
340348

341-
if (mode == "bilinear") {
349+
if (enum_mode == "bilinear") {
342350
BilinearInter<T>(dev_ctx, x, &grid_x, &grid_y, out);
343-
} else if (mode == "nearest") {
344-
auto grid_x_t = EigenTensor<T, 3>::From(grid_x);
345-
auto grid_y_t = EigenTensor<T, 3>::From(grid_y);
346-
grid_x_t = grid_x_t.round();
347-
grid_y_t = grid_y_t.round();
348-
GetGridPointValue<T>(x, out, grid_x, grid_y);
351+
} else if (enum_mode == "nearest") {
352+
GetGridPointValue_nearest<T>(x, out, grid_x, grid_y);
349353
}
350354
} else {
351355
const int n = static_cast<int>(grid.dims()[0]);
@@ -372,10 +376,10 @@ void GridSampleKernel(const Context& dev_ctx,
372376
&grid_x,
373377
&grid_y,
374378
&grid_z);
375-
if (mode == "bilinear") {
379+
if (enum_mode == "bilinear") {
376380
Bilinear3DInter<T>(dev_ctx, x, &grid_x, &grid_y, &grid_z, out);
377-
} else if (mode == "nearest") {
378-
Get3DGridPointValue<T>(x, out, grid_x, grid_y, grid_z);
381+
} else if (enum_mode == "nearest") {
382+
Get3DGridPointValue_nearest<T>(x, out, grid_x, grid_y, grid_z);
379383
}
380384
}
381385
}

paddle/phi/kernels/cpu/grid_sample_utils.h

Lines changed: 101 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@ void Unnormalize(const CPUContext& dev_ctx,
2626
auto& place = *dev_ctx.eigen_device();
2727
auto grid_slice_t = EigenTensor<T, 3>::From(*grid_slice);
2828

29-
if (!align_corners) {
29+
if (align_corners) {
30+
auto factor = static_cast<T>(max_val * 0.5);
31+
grid_slice_t.device(place) = (grid_slice_t + static_cast<T>(1)) * factor;
32+
} else {
3033
auto factor = static_cast<T>((max_val + 1) * 0.5);
3134
grid_slice_t.device(place) =
3235
(grid_slice_t + static_cast<T>(1)) * factor - static_cast<T>(0.5);
33-
} else {
34-
auto factor = static_cast<T>(max_val * 0.5);
35-
grid_slice_t.device(place) = (grid_slice_t + static_cast<T>(1)) * factor;
3636
}
3737
}
3838

@@ -89,14 +89,51 @@ void GetGridPointValue(const DenseTensor& input,
8989
for (int i = 0; i < n; i++) {
9090
for (int k = 0; k < out_h; k++) {
9191
for (int l = 0; l < out_w; l++) {
92-
if (IsInBound(
93-
x_t(i, k, l), y_t(i, k, l), (T)(in_w - 1), (T)(in_h - 1))) {
92+
if (IsInBound<int>(static_cast<int>(x_t(i, k, l)),
93+
static_cast<int>(y_t(i, k, l)),
94+
(in_w - 1),
95+
(in_h - 1))) {
96+
for (int j = 0; j < c; j++) {
97+
output_t(i, j, k, l) = input_t(i,
98+
j,
99+
static_cast<int>(y_t(i, k, l)),
100+
static_cast<int>(x_t(i, k, l)));
101+
}
102+
}
103+
}
104+
}
105+
}
106+
}
107+
108+
template <typename T>
109+
void GetGridPointValue_nearest(const DenseTensor& input,
110+
DenseTensor* output,
111+
const DenseTensor& x,
112+
const DenseTensor& y) {
113+
const int n = input.dims()[0];
114+
const int c = input.dims()[1];
115+
const int in_h = input.dims()[2];
116+
const int in_w = input.dims()[3];
117+
const int out_h = x.dims()[1];
118+
const int out_w = x.dims()[2];
119+
auto x_t = EigenTensor<T, 3>::From(x);
120+
auto y_t = EigenTensor<T, 3>::From(y);
121+
auto output_t = EigenTensor<T, 4>::From(*output).setConstant((T)0);
122+
auto input_t = EigenTensor<T, 4>::From(input);
123+
124+
for (int i = 0; i < n; i++) {
125+
for (int k = 0; k < out_h; k++) {
126+
for (int l = 0; l < out_w; l++) {
127+
if (IsInBound<int>(static_cast<int>(std::nearbyint(x_t(i, k, l))),
128+
static_cast<int>(std::nearbyint(y_t(i, k, l))),
129+
(in_w - 1),
130+
(in_h - 1))) {
94131
for (int j = 0; j < c; j++) {
95132
output_t(i, j, k, l) =
96133
input_t(i,
97134
j,
98-
static_cast<int>(round(y_t(i, k, l))),
99-
static_cast<int>(round(x_t(i, k, l))));
135+
static_cast<int>(std::nearbyint(y_t(i, k, l))),
136+
static_cast<int>(std::nearbyint(x_t(i, k, l))));
100137
}
101138
}
102139
}
@@ -207,19 +244,66 @@ void Get3DGridPointValue(const DenseTensor& input,
207244
for (int m = 0; m < out_d; m++) {
208245
for (int k = 0; k < out_h; k++) {
209246
for (int l = 0; l < out_w; l++) {
210-
if (IsInBound3D(x_t(i, m, k, l),
211-
y_t(i, m, k, l),
212-
z_t(i, m, k, l),
213-
(T)(in_w - 1),
214-
(T)(in_h - 1),
215-
(T)(in_d - 1))) {
247+
if (IsInBound3D<int>(static_cast<int>(x_t(i, m, k, l)),
248+
static_cast<int>(y_t(i, m, k, l)),
249+
static_cast<int>(z_t(i, m, k, l)),
250+
(in_w - 1),
251+
(in_h - 1),
252+
(in_d - 1))) {
253+
for (int j = 0; j < c; j++) {
254+
output_t(i, j, m, k, l) =
255+
input_t(i,
256+
j,
257+
static_cast<int>(z_t(i, m, k, l)),
258+
static_cast<int>(y_t(i, m, k, l)),
259+
static_cast<int>(x_t(i, m, k, l)));
260+
}
261+
}
262+
}
263+
}
264+
}
265+
}
266+
}
267+
268+
template <typename T>
269+
void Get3DGridPointValue_nearest(const DenseTensor& input,
270+
DenseTensor* output,
271+
const DenseTensor& x,
272+
const DenseTensor& y,
273+
const DenseTensor& z) {
274+
const int n = input.dims()[0];
275+
const int c = input.dims()[1];
276+
const int in_d = input.dims()[2];
277+
const int in_h = input.dims()[3];
278+
const int in_w = input.dims()[4];
279+
const int out_d = x.dims()[1];
280+
const int out_h = x.dims()[2];
281+
const int out_w = x.dims()[3];
282+
auto x_t = EigenTensor<T, 4>::From(x);
283+
auto y_t = EigenTensor<T, 4>::From(y);
284+
auto z_t = EigenTensor<T, 4>::From(z);
285+
auto output_t =
286+
EigenTensor<T, 5>::From(*output).setConstant(static_cast<T>(0.0));
287+
auto input_t = EigenTensor<T, 5>::From(input);
288+
289+
for (int i = 0; i < n; i++) {
290+
for (int m = 0; m < out_d; m++) {
291+
for (int k = 0; k < out_h; k++) {
292+
for (int l = 0; l < out_w; l++) {
293+
if (IsInBound3D<int>(
294+
static_cast<int>(std::nearbyint(x_t(i, m, k, l))),
295+
static_cast<int>(std::nearbyint(y_t(i, m, k, l))),
296+
static_cast<int>(std::nearbyint(z_t(i, m, k, l))),
297+
(in_w - 1),
298+
(in_h - 1),
299+
(in_d - 1))) {
216300
for (int j = 0; j < c; j++) {
217301
output_t(i, j, m, k, l) =
218302
input_t(i,
219303
j,
220-
static_cast<int>(round(z_t(i, m, k, l))),
221-
static_cast<int>(round(y_t(i, m, k, l))),
222-
static_cast<int>(round(x_t(i, m, k, l))));
304+
static_cast<int>(std::nearbyint(z_t(i, m, k, l))),
305+
static_cast<int>(std::nearbyint(y_t(i, m, k, l))),
306+
static_cast<int>(std::nearbyint(x_t(i, m, k, l))));
223307
}
224308
}
225309
}

paddle/phi/kernels/gpu/grid_sample_kernel.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -293,9 +293,9 @@ __global__ void GridSample3DCudaKernel(const IndexT nthreads,
293293
}
294294
}
295295
} else if (interpolation_mode == Mode::nearest) {
296-
IndexT ix_nearest = static_cast<IndexT>(std::round(ix));
297-
IndexT iy_nearest = static_cast<IndexT>(std::round(iy));
298-
IndexT iz_nearest = static_cast<IndexT>(std::round(iz));
296+
IndexT ix_nearest = static_cast<IndexT>(std::nearbyint(ix));
297+
IndexT iy_nearest = static_cast<IndexT>(std::nearbyint(iy));
298+
IndexT iz_nearest = static_cast<IndexT>(std::nearbyint(iz));
299299

300300
// assign nearest neighbor pixel value to output pixel
301301
const T* inp_ptr_NC = input + n * inp_sN;

test/legacy_test/test_grid_sampler_op.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,16 +379,29 @@ def setUp(self):
379379
}
380380

381381
def test_check_output(self):
382+
self.check_output_with_place(core.CPUPlace(), check_pir=True)
383+
if core.is_compiled_with_cuda():
384+
self.check_output_with_place(core.CUDAPlace(0), check_pir=True)
382385
self.check_output(check_pir=True)
383386

384387
def test_check_grad_normal(self):
385-
self.check_grad(
388+
self.check_grad_with_place(
389+
core.CPUPlace(),
386390
['X', 'Grid'],
387391
'Output',
388392
max_relative_error=0.01,
389393
numeric_grad_delta=self.numeric_grad_delta,
390394
check_pir=True,
391395
)
396+
if core.is_compiled_with_cuda():
397+
self.check_grad_with_place(
398+
core.CUDAPlace(0),
399+
['X', 'Grid'],
400+
'Output',
401+
max_relative_error=0.01,
402+
numeric_grad_delta=self.numeric_grad_delta,
403+
check_pir=True,
404+
)
392405

393406
def initTestCase(self):
394407
self.x_shape = (2, 3, 8, 8)

0 commit comments

Comments
 (0)