Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions paddle/fluid/operators/group_norm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ class GroupNormOp : public framework::OperatorWithKernel {
"The Attr(groups) of Op(group_norm) must be "
"greater than or equal to 1. But received: groups is [%s].",
groups));
PADDLE_ENFORCE_EQ(
channel_num % groups, 0,
platform::errors::InvalidArgument(
"Expected number of channels in input to be divisible by "
"num_groups, but got input channel is %d and num_groups is %d",
channel_num, groups));

if (ctx->HasInput("Scale")) {
PADDLE_ENFORCE_EQ(
Expand Down
5 changes: 3 additions & 2 deletions paddle/fluid/operators/group_norm_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ class GroupNormKernel<platform::CUDADeviceContext, T>
const int C =
(data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
const int group_size = (C - 1) / groups + 1;
const int group_size = C / groups;

const int W =
(data_layout == DataLayout::kNCHW ? x_dims[x_dims.size() - 1]
: x_dims[x_dims.size() - 2]);
Expand Down Expand Up @@ -314,7 +315,7 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
const int C =
(data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
const int group_size = (C - 1) / groups + 1;
const int group_size = C / groups;
const int W =
(data_layout == DataLayout::kNCHW ? x_dims[x_dims.size() - 1]
: x_dims[x_dims.size() - 2]);
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/operators/group_norm_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class GroupNormKernel : public framework::OpKernel<T> {
const int C =
(data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
const int group_size = (C - 1) / groups + 1;
const int group_size = C / groups;

y->mutable_data<T>(ctx.GetPlace());
mean->mutable_data<T>(ctx.GetPlace());
Expand Down Expand Up @@ -100,7 +100,7 @@ class GroupNormKernel : public framework::OpKernel<T> {
int imid;
for (imid = 0; imid < imsize - (imsize % M);
imid += M, iter_x_data += M) {
// TODO(gaoxiang)Because AVX/AVX2/AVX512 can not directly used
// TODO(gaoxiang): Because AVX/AVX2/AVX512 can not directly used
// in template class/function, before we complete high
// performance cpu vector extension, temporarily unrolling
// loop to get high precision and performance
Expand Down Expand Up @@ -138,7 +138,7 @@ class GroupNormKernel : public framework::OpKernel<T> {
int imid;
for (imid = 0; imid < imsize - (imsize % M);
imid += M, iter_x_data += M * C) {
// TODO(gaoxiang)Because AVX/AVX2/AVX512 can not directly used
// TODO(gaoxiang): Because AVX/AVX2/AVX512 can not directly used
// in template class/function, before we complete high
// performance cpu vector extension, temporarily unrolling
// loop to get high precision and performance
Expand Down Expand Up @@ -236,7 +236,7 @@ class GroupNormGradKernel : public framework::OpKernel<T> {
const int C =
(data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
const int group_size = (C - 1) / groups + 1;
const int group_size = C / groups;

d_x->mutable_data<T>(ctx.GetPlace());
math::SetConstant<DeviceContext, T> set_zero;
Expand Down