Skip to content

Commit 572f972

Browse files
author
zhangting2020
committed
use IndexList only when EIGEN_HAS_INDEX_LIST is true
1 parent 1636a11 commit 572f972

1 file changed

Lines changed: 17 additions & 1 deletion

File tree

paddle/fluid/operators/instance_norm_op.cc

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,13 +181,22 @@ class InstanceNormKernel<platform::CPUDeviceContext, T>
181181
auto &dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
182182
auto *place = dev_ctx.eigen_device();
183183

184+
Eigen::DSizes<int, 2> shape(NxC, sample_size);
185+
// Once eigen on Windows is updated, the if branch can be removed.
186+
#ifndef EIGEN_HAS_INDEX_LIST
187+
Eigen::DSizes<int, 2> bcast(1, sample_size);
188+
Eigen::DSizes<int, 2> C_shape(C, 1);
189+
Eigen::DSizes<int, 2> NxC_shape(NxC, 1);
190+
Eigen::DSizes<int, 1> rdims(1);
191+
#else
184192
Eigen::IndexList<Eigen::type2index<1>, int> bcast;
185193
bcast.set(1, sample_size);
186194
Eigen::IndexList<int, Eigen::type2index<1>> C_shape;
187195
C_shape.set(0, C);
188196
Eigen::IndexList<int, Eigen::type2index<1>> NxC_shape;
189197
NxC_shape.set(0, NxC);
190198
Eigen::IndexList<Eigen::type2index<1>> rdims;
199+
#endif
191200

192201
math::SetConstant<platform::CPUDeviceContext, T> set_constant;
193202

@@ -201,7 +210,6 @@ class InstanceNormKernel<platform::CPUDeviceContext, T>
201210
auto saved_variance_a = framework::EigenVector<T>::Flatten(*saved_variance);
202211
auto saved_variance_e = saved_variance_a.reshape(NxC_shape);
203212

204-
Eigen::DSizes<int, 2> shape(NxC, sample_size);
205213
auto x_e = framework::EigenVector<T>::Flatten(*x);
206214
auto x_arr = x_e.reshape(shape);
207215

@@ -321,6 +329,13 @@ class InstanceNormGradKernel<platform::CPUDeviceContext, T>
321329
Eigen::DSizes<int, 2> rshape(NxC, sample_size);
322330
Eigen::DSizes<int, 2> param_shape(N, C);
323331
Eigen::DSizes<int, 2> shape(NxC, sample_size);
332+
#ifndef EIGEN_HAS_INDEX_LIST
333+
Eigen::DSizes<int, 1> rdims(0);
334+
Eigen::DSizes<int, 1> mean_rdims(1);
335+
Eigen::DSizes<int, 2> bcast(1, sample_size);
336+
Eigen::DSizes<int, 2> C_shape(C, 1);
337+
Eigen::DSizes<int, 2> NxC_shape(NxC, 1);
338+
#else
324339
Eigen::IndexList<Eigen::type2index<0>> rdims;
325340
Eigen::IndexList<Eigen::type2index<1>> mean_rdims;
326341
Eigen::IndexList<Eigen::type2index<1>, int> bcast;
@@ -329,6 +344,7 @@ class InstanceNormGradKernel<platform::CPUDeviceContext, T>
329344
C_shape.set(0, C);
330345
Eigen::IndexList<int, Eigen::type2index<1>> NxC_shape;
331346
NxC_shape.set(0, NxC);
347+
#endif
332348

333349
math::SetConstant<platform::CPUDeviceContext, T> set_constant;
334350

0 commit comments

Comments
 (0)