Merged
7 changes: 7 additions & 0 deletions paddle/fluid/operators/group_norm_op.cc
@@ -37,6 +37,13 @@ class GroupNormOp : public framework::OperatorWithKernel {
"GroupNorm");

auto x_dim = ctx->GetInputDim("X");
+    PADDLE_ENFORCE_GE(
+        x_dim.size(), 2,
+        platform::errors::InvalidArgument(
+            "The Input(X)'s dimension of Op(group_norm) must be "
+            "greater than 1. But received: %u-D Tensor, which shape is [%s].",
+            x_dim.size(), x_dim));
+
const std::string data_layout_str =
ctx->Attrs().Get<std::string>("data_layout");
const framework::DataLayout data_layout =
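Editor's note on the new check: group_norm needs at least a batch axis and a channel axis to form groups over, so InferShape now rejects anything below 2-D up front. A minimal Python sketch of the same rule, for illustration only (the helper name is invented and is not part of the patch):

import numpy as np

def check_group_norm_rank(x):
    # Mirrors the PADDLE_ENFORCE_GE above: inputs must be at least (N, C).
    if x.ndim < 2:
        raise ValueError(
            "group_norm expects an input of rank >= 2, got a %d-D tensor "
            "with shape %s" % (x.ndim, list(x.shape)))

check_group_norm_rank(np.zeros((4, 6)))    # OK: 2-D is the minimum
# check_group_norm_rank(np.zeros((4,)))    # would raise ValueError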
25 changes: 20 additions & 5 deletions paddle/fluid/operators/group_norm_op.cu
@@ -171,9 +171,16 @@ class GroupNormKernel<platform::CUDADeviceContext, T>
const T* bias_data = nullptr;
if (bias) bias_data = bias->data<T>();

-    int imsize = (data_layout == DataLayout::kNCHW ? x_dims[2] * x_dims[3]
-                                                    : x_dims[1] * x_dims[2]);
-
+    int imsize = 1;
+    if (data_layout == DataLayout::kNCHW) {
+      for (int i = 2; i < x_dims.size(); ++i) {
+        imsize *= x_dims[i];
+      }
+    } else {
+      for (int i = 1; i < x_dims.size() - 1; ++i) {
+        imsize *= x_dims[i];
+      }
+    }
#ifdef __HIPCC__
int block_size = std::max(std::min(256, imsize), 64);
#else
@@ -349,8 +356,16 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
const T* bias_data = nullptr;
if (bias) bias_data = bias->data<T>();

-    int imsize = (data_layout == DataLayout::kNCHW ? x_dims[2] * x_dims[3]
-                                                    : x_dims[1] * x_dims[2]);
+    int imsize = 1;
+    if (data_layout == DataLayout::kNCHW) {
+      for (int i = 2; i < x_dims.size(); ++i) {
+        imsize *= x_dims[i];
+      }
+    } else {
+      for (int i = 1; i < x_dims.size() - 1; ++i) {
+        imsize *= x_dims[i];
+      }
+    }

#ifdef __HIPCC__
int block_size = std::max(std::min(256, imsize), 64);
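Editor's note: both CUDA kernels above replace the hard-coded H * W with the product of all spatial dimensions, so 2-D through 6-D inputs work for both NCHW-style layouts (channels at index 1, spatial dims after it) and NHWC-style layouts (channels last, spatial dims in between); the CUDA block size is then clamped into [64, 256] from that product. A rough Python equivalent of the new loops, for illustration only (the helper name is invented):

import numpy as np

def spatial_size(shape, data_layout="NCHW"):
    # NCHW / NCDHW: spatial dims sit after the channel axis.
    # NHWC-style: spatial dims sit between batch and the trailing channel axis.
    dims = shape[2:] if data_layout == "NCHW" else shape[1:-1]
    return int(np.prod(dims, dtype=np.int64))  # empty product -> 1, matching imsize = 1

assert spatial_size([2, 6, 4, 4]) == 16                # classic 4-D case: H * W
assert spatial_size([4, 6, 6, 6, 2]) == 72             # 5-D case: D * H * W
assert spatial_size([4, 2]) == 1                       # 2-D case: no spatial dims left
assert spatial_size([2, 2, 4], "NHWC") == 2            # channels-last 3-D case
assert max(min(256, spatial_size([4, 2])), 64) == 64   # block_size clamp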
25 changes: 20 additions & 5 deletions paddle/fluid/operators/group_norm_op.h
@@ -68,9 +68,16 @@ class GroupNormKernel : public framework::OpKernel<T> {
const T* bias_data = nullptr;
if (bias) bias_data = bias->data<T>();

-    int imsize = (data_layout == DataLayout::kNCHW ? x_dims[2] * x_dims[3]
-                                                    : x_dims[1] * x_dims[2]);
-
+    int imsize = 1;
+    if (data_layout == DataLayout::kNCHW) {
+      for (int i = 2; i < x_dims.size(); ++i) {
+        imsize *= x_dims[i];
+      }
+    } else {
+      for (int i = 1; i < x_dims.size() - 1; ++i) {
+        imsize *= x_dims[i];
+      }
+    }
auto* iter_x_data = x_data;
auto* iter_y_data = y_data;
for (int bid = 0; bid < x_dims[0]; bid++) {
@@ -257,8 +264,16 @@ class GroupNormGradKernel : public framework::OpKernel<T> {
const T* bias_data = nullptr;
if (bias) bias_data = bias->data<T>();

-    int imsize = (data_layout == DataLayout::kNCHW ? x_dims[2] * x_dims[3]
-                                                    : x_dims[1] * x_dims[2]);
+    int imsize = 1;
+    if (data_layout == DataLayout::kNCHW) {
+      for (int i = 2; i < x_dims.size(); ++i) {
+        imsize *= x_dims[i];
+      }
+    } else {
+      for (int i = 1; i < x_dims.size() - 1; ++i) {
+        imsize *= x_dims[i];
+      }
+    }
auto* iter_x_data = x_data;
auto* iter_d_x_data = d_x_data;
auto* iter_y_data = y_data;
78 changes: 63 additions & 15 deletions python/paddle/fluid/tests/unittests/test_group_norm_op_v2.py
@@ -25,13 +25,29 @@
import paddle


+def group_norm_naive_for_general_dimension(x, scale, bias, epsilon, groups):
+    # the original group_norm reference only supported 4-D tensors;
+    # this function generalizes it to tensors of other dimensions (>= 2-D)
+    input_shape = x.shape
+    N, C = x.shape[0], x.shape[1]
+    G = groups
+    x = x.reshape((N * G, -1))
+    mean = np.mean(x, axis=1, keepdims=True)
+    var = np.var(x, axis=1, keepdims=True)
+    output = (x - mean) / np.sqrt(var + epsilon)
+    output = output.reshape(input_shape) * scale.reshape(
+        (-1, 1, 1)) + bias.reshape((-1, 1, 1))
+    return output
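Editor's note on the reference above: reshaping to (N * G, -1) puts each sample's group of C / G channels (times all spatial elements) on one row, so one mean/variance per row is exactly the per-group statistic. A worked shape example, illustrative only and not part of the test file:

import numpy as np

x = np.random.randn(2, 6, 4, 4).astype("float32")  # N=2, C=6, spatial 4x4
rows = x.reshape((2 * 2, -1))                      # groups=2 -> (N*G, C/G * H * W)
print(rows.shape)                                  # (4, 48): 3 channels * 16 elements per row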


class TestDygraphGroupNormv2(unittest.TestCase):
def test_dygraph(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu("group_norm"):
places.append(fluid.CUDAPlace(0))
+        shapes = [[2, 2, 2, 2], [2, 2, 4], [4, 2], [4, 2, 6, 6, 2],
+                  [2, 2, 2, 2, 2, 2]]
for p in places:
-            shape = [2, 2, 2, 2]

def compute_v1(x):
with fluid.dygraph.guard(p):
@@ -62,23 +78,26 @@ def attr_data_format():

self.assertRaises(ValueError, attr_data_format)

-            x = np.random.randn(*shape).astype("float32")
-            y1 = compute_v1(x)
-            y2 = compute_v2(x)
-            result = np.allclose(y1, y2, atol=1e-5)
-            if not result:
-                print("y1:", y1, "\ty2:", y2)
-            self.assertTrue(result)
-            test_weight_bias_false()
-            test_nn_exception()
+            for shape in shapes:
+                x = np.random.randn(*shape).astype("float32")
+                y1 = compute_v1(x)
+                y2 = compute_v2(x)
+                result = np.allclose(y1, y2, atol=1e-5)
+                if not result:
+                    print("y1:", y1, "\ty2:", y2)
+                self.assertTrue(result)
+                test_weight_bias_false()
+                test_nn_exception()

def test_static(self):
paddle.enable_static()
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu("group_norm"):
places.append(fluid.CUDAPlace(0))
+        shapes = [[2, 6, 2, 2], [2, 6, 4], [4, 6], [4, 6, 6, 6, 2],
+                  [4, 6, 2, 2, 2, 2]]
for p in places:
exe = fluid.Executor(p)
-            shape = [2, 6, 2, 2]

def compute_v1(x_np):
with program_guard(Program(), Program()):
@@ -98,10 +117,39 @@ def compute_v2(x_np):
r = exe.run(feed={'x': x_np}, fetch_list=[y])[0]
return r

-            x = np.random.randn(*shape).astype("float32")
-            y1 = compute_v1(x)
-            y2 = compute_v2(x)
-            self.assertTrue(np.allclose(y1, y2, atol=1e-5))
+            for shape in shapes:
+                x = np.random.randn(*shape).astype("float32")
+                y1 = compute_v1(x)
+                y2 = compute_v2(x)
+                self.assertTrue(np.allclose(y1, y2, atol=1e-5))


+class TestGroupNormAPIV2_With_General_Dimensions(unittest.TestCase):
+    def test_numerical_accuracy(self):
+        paddle.disable_static()
+        shapes = [(2, 6), (2, 6, 4), (2, 6, 4, 4), (2, 6, 6, 6, 2),
+                  (2, 6, 6, 6, 2, 3)]
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda() and core.op_support_gpu("group_norm"):
+            places.append(fluid.CUDAPlace(0))
+
+        for place in places:
+            for shape in shapes:
+                scale = np.array([1]).astype("float32")
+                bias = np.array([0]).astype("float32")
+                data = np.random.random(shape).astype("float32")
+                expect_res1 = group_norm_naive_for_general_dimension(
+                    data, scale, bias, epsilon=1e-5, groups=6)
+                expect_res2 = group_norm_naive_for_general_dimension(
+                    data, scale, bias, epsilon=1e-5, groups=2)
+
+                gn1 = paddle.nn.GroupNorm(num_channels=6, num_groups=6)
+                gn2 = paddle.nn.GroupNorm(num_channels=6, num_groups=2)
+                data_pd = paddle.to_tensor(data)
+                result1 = gn1(data_pd).numpy()
+                result2 = gn2(data_pd).numpy()
+                self.assertTrue(np.allclose(result1, expect_res1, atol=1e-5))
+                self.assertTrue(np.allclose(result2, expect_res2, atol=1e-5))


if __name__ == '__main__':
4 changes: 2 additions & 2 deletions python/paddle/nn/layer/norm.py
@@ -338,8 +338,8 @@ class GroupNorm(Layer):
name(str, optional): Name for the GroupNorm, default is None. For more information, please refer to :ref:`api_guide_Name`..

Shape:
-        - x: 4-D tensor with shape: (batch, num_features, height, weight).
-        - output: 4-D tensor with same shape as input x.
+        - x: Tensor with shape: (batch, num_features, *).
+        - output: The same shape as input x.

Returns:
None
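Editor's note on the relaxed shape contract: with this change GroupNorm accepts any input of rank >= 2 with the channel axis at index 1, not only 4-D tensors. A short dygraph usage sketch along the lines of the new unit tests, assuming a Paddle build that includes this change:

import numpy as np
import paddle

gn = paddle.nn.GroupNorm(num_channels=6, num_groups=2)
x = paddle.to_tensor(np.random.random((2, 6, 4)).astype("float32"))  # 3-D input, no longer 4-D only
y = gn(x)
print(y.shape)  # [2, 6, 4] -- output keeps the input shape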