@@ -18,6 +18,11 @@ namespace paddle {
1818namespace operators {
1919namespace math {
2020
21+ /*
22+ * All tensors are in NCHW format.
23+ * Ksize, strides, paddings are two elements. These two elements represent
24+ * height and width, respectively.
25+ */
2126template <typename PoolProcess, typename T>
2227class Pool2dFunctor <platform::CPUPlace, PoolProcess, T> {
2328 public:
@@ -73,6 +78,11 @@ class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> {
7378 }
7479};
7580
81+ /*
82+ * All tensors are in NCHW format.
83+ * Ksize, strides, paddings are two elements. These two elements represent height
84+ * and width, respectively.
85+ */
7686template <typename PoolProcess, class T >
7787class Pool2dGradFunctor <platform::CPUPlace, PoolProcess, T> {
7888 public:
@@ -135,6 +145,11 @@ class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> {
135145 }
136146};
137147
148+ /*
149+ * All tensors are in NCHW format.
150+ * Ksize, strides, paddings are two elements. These two elements represent
151+ * height and width, respectively.
152+ */
138153template <class T >
139154class MaxPool2dGradFunctor <platform::CPUPlace, T> {
140155 public:
@@ -197,7 +212,7 @@ class MaxPool2dGradFunctor<platform::CPUPlace, T> {
197212};
198213
199214template class MaxPool2dGradFunctor <platform::CPUPlace, float >;
200- // template class MaxPool2dGradFunctor<platform::CPUPlace, double>;
215+ template class MaxPool2dGradFunctor <platform::CPUPlace, double >;
201216
202217template class Pool2dFunctor <platform::CPUPlace,
203218 paddle::operators::math::MaxPool<float >, float >;
@@ -216,6 +231,11 @@ template class Pool2dGradFunctor<
216231template class Pool2dGradFunctor <
217232 platform::CPUPlace, paddle::operators::math::AvgPoolGrad<double >, double >;
218233
234+ /*
235+ * All tensors are in NCDHW format.
236+ * Ksize, strides, paddings are three elements. These three elements represent
237+ * depth, height and width, respectively.
238+ */
219239template <typename PoolProcess, class T >
220240class Pool3dFunctor <platform::CPUPlace, PoolProcess, T> {
221241 public:
@@ -286,6 +306,11 @@ class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
286306 }
287307};
288308
309+ /*
310+ * All tensors are in NCDHW format.
311+ * Ksize, strides, paddings are three elements. These three elements represent
312+ * depth, height and width, respectively.
313+ */
289314template <typename PoolProcess, class T >
290315class Pool3dGradFunctor <platform::CPUPlace, PoolProcess, T> {
291316 public:
@@ -364,6 +389,11 @@ class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> {
364389 }
365390};
366391
392+ /*
393+ * All tensors are in NCDHW format.
394+ * Ksize, strides, paddings are three elements. These three elements represent
395+ * depth, height and width, respectively.
396+ */
367397template <class T >
368398class MaxPool3dGradFunctor <platform::CPUPlace, T> {
369399 public:
@@ -440,7 +470,7 @@ class MaxPool3dGradFunctor<platform::CPUPlace, T> {
440470};
441471
442472template class MaxPool3dGradFunctor <platform::CPUPlace, float >;
443- // template class MaxPool3dGradFunctor<platform::CPUPlace, double>;
473+ template class MaxPool3dGradFunctor <platform::CPUPlace, double >;
444474
445475template class Pool3dFunctor <platform::CPUPlace,
446476 paddle::operators::math::MaxPool<float >, float >;
@@ -459,6 +489,11 @@ template class Pool3dGradFunctor<
459489template class Pool3dGradFunctor <
460490 platform::CPUPlace, paddle::operators::math::AvgPoolGrad<double >, double >;
461491
492+ /*
493+ * All tensors are in NCHW format.
494+ * Ksize, strides, paddings are two elements. These two elements represent
495+ * height and width, respectively.
496+ */
462497template <typename T>
463498class MaxPool2dWithIndexFunctor <platform::CPUPlace, T> {
464499 public:
@@ -519,6 +554,11 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
519554 }
520555};
521556
557+ /*
558+ * All tensors are in NCHW format.
559+ * Ksize, strides, paddings are two elements. These two elements represent
560+ * height and width, respectively.
561+ */
522562template <typename T>
523563class MaxPool2dWithIndexGradFunctor <platform::CPUPlace, T> {
524564 public:
@@ -563,6 +603,11 @@ template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, float>;
563603template class MaxPool2dWithIndexFunctor <platform::CPUPlace, double >;
564604template class MaxPool2dWithIndexGradFunctor <platform::CPUPlace, double >;
565605
606+ /*
607+ * All tensors are in NCDHW format.
608+ * Ksize, strides, paddings are three elements. These three elements represent
609+ * depth, height and width, respectively.
610+ */
566611template <typename T>
567612class MaxPool3dWithIndexFunctor <platform::CPUPlace, T> {
568613 public:
@@ -637,6 +682,11 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
637682 }
638683};
639684
685+ /*
686+ * All tensors are in NCDHW format.
687+ * Ksize, strides, paddings are three elements. These three elements represent
688+ * depth, height and width, respectively.
689+ */
640690template <typename T>
641691class MaxPool3dWithIndexGradFunctor <platform::CPUPlace, T> {
642692 public:
0 commit comments