Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions paddle/fluid/operators/instance_norm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,22 @@ class InstanceNormKernel<platform::CPUDeviceContext, T>
auto &dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
auto *place = dev_ctx.eigen_device();

Eigen::DSizes<int, 2> shape(NxC, sample_size);
// Once eigen on Windows is updated, the if branch can be removed.
#ifndef EIGEN_HAS_INDEX_LIST
Eigen::DSizes<int, 2> bcast(1, sample_size);
Eigen::DSizes<int, 2> C_shape(C, 1);
Eigen::DSizes<int, 2> NxC_shape(NxC, 1);
Eigen::DSizes<int, 2> shape(NxC, sample_size);
Eigen::DSizes<int, 1> rdims(1);
#else
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以直接替换成IndexList么?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

目前由于windows上eigen版本未升级,因此直接替换为IndexList会导致编译错误。待windows也升级了eigen后,可以直接使用IndexList

Eigen::IndexList<Eigen::type2index<1>, int> bcast;
bcast.set(1, sample_size);
Eigen::IndexList<int, Eigen::type2index<1>> C_shape;
C_shape.set(0, C);
Eigen::IndexList<int, Eigen::type2index<1>> NxC_shape;
NxC_shape.set(0, NxC);
Eigen::IndexList<Eigen::type2index<1>> rdims;
#endif

math::SetConstant<platform::CPUDeviceContext, T> set_constant;

Expand All @@ -201,8 +213,6 @@ class InstanceNormKernel<platform::CPUDeviceContext, T>
auto x_e = framework::EigenVector<T>::Flatten(*x);
auto x_arr = x_e.reshape(shape);

Eigen::DSizes<int, 1> rdims(1);

saved_mean_e.device(*place) = x_arr.mean(rdims);
auto saved_variance_arr =
(x_arr - saved_mean_e.broadcast(bcast)).square().mean(rdims) + epsilon;
Expand Down Expand Up @@ -316,14 +326,25 @@ class InstanceNormGradKernel<platform::CPUDeviceContext, T>
auto &dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
auto *place = dev_ctx.eigen_device();

Eigen::DSizes<int, 2> rshape(NxC, sample_size);
Eigen::DSizes<int, 2> param_shape(N, C);
Eigen::DSizes<int, 2> shape(NxC, sample_size);
#ifndef EIGEN_HAS_INDEX_LIST
Eigen::DSizes<int, 1> rdims(0);
Eigen::DSizes<int, 1> mean_rdims(1);
Eigen::DSizes<int, 2> rshape(NxC, sample_size);
Eigen::DSizes<int, 2> bcast(1, sample_size);
Eigen::DSizes<int, 2> C_shape(C, 1);
Eigen::DSizes<int, 2> NxC_shape(NxC, 1);
Eigen::DSizes<int, 2> param_shape(N, C);
Eigen::DSizes<int, 2> shape(NxC, sample_size);
#else
Eigen::IndexList<Eigen::type2index<0>> rdims;
Eigen::IndexList<Eigen::type2index<1>> mean_rdims;
Eigen::IndexList<Eigen::type2index<1>, int> bcast;
bcast.set(1, sample_size);
Eigen::IndexList<int, Eigen::type2index<1>> C_shape;
C_shape.set(0, C);
Eigen::IndexList<int, Eigen::type2index<1>> NxC_shape;
NxC_shape.set(0, NxC);
#endif

math::SetConstant<platform::CPUDeviceContext, T> set_constant;

Expand Down