Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
6 changes: 1 addition & 5 deletions example/quantization/imagenet_gen_qsym_mkldnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,11 +225,7 @@ def save_params(fname, arg_params, aux_params, logger=None):
rgb_mean = '123.68,116.779,103.939'
rgb_std = '58.393, 57.12, 57.375'
calib_layer = lambda name: name.endswith('_output')
excluded_sym_names += ['squeezenet0_flatten0_flatten0',
'squeezenet0_pool0_fwd',
'squeezenet0_pool1_fwd',
'squeezenet0_pool2_fwd',
'squeezenet0_pool3_fwd']
excluded_sym_names += ['squeezenet0_flatten0_flatten0']
if exclude_first_conv:
excluded_sym_names += ['squeezenet0_conv0_fwd']
elif args.model == 'mobilenet1.0':
Expand Down
23 changes: 18 additions & 5 deletions src/operator/quantization/quantized_pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,30 @@ bool QuantizedPoolingShape(const nnvm::NodeAttrs& attrs,
<< "kernel size (" << param.kernel[1]
<< ") exceeds input (" << dshape[W]
<< " padded to " << (dshape[W] + 2*param.pad[1]) << ")";
// only support valid convention

oshape[N] = dshape[N];
oshape[C] = dshape[C];
if (param.global_pool) {
oshape[H] = 1;
oshape[W] = 1;
} else {
oshape[H] = 1 + (dshape[H] + 2 * param.pad[0] - param.kernel[0]) /
param.stride[0];
oshape[W] = 1 + (dshape[W] + 2 * param.pad[1] - param.kernel[1]) /
param.stride[1];
if (param.pooling_convention == pool_enum::kValid) {
oshape[H] = 1 +
(dshape[H] + 2 * param.pad[0] - param.kernel[0]) /
param.stride[0];
oshape[W] = 1 +
(dshape[W] + 2 * param.pad[1] - param.kernel[1]) /
param.stride[1];
} else {
oshape[H] = 1 + static_cast<int>(std::ceil(
static_cast<float>(dshape[H] + 2 * param.pad[0] -
param.kernel[0]) /
param.stride[0]));
oshape[W] = 1 + static_cast<int>(std::ceil(
static_cast<float>(dshape[W] + 2 * param.pad[1] -
param.kernel[1]) /
param.stride[1]));
}
}

SHAPE_ASSIGN_CHECK(*in_shape, 1, TShape{1});
Expand Down
23 changes: 16 additions & 7 deletions tests/python/quantization/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def check_quantized_conv(data_shape, kernel, num_filter, pad, stride, no_bias, q

@with_seed()
def test_quantized_pooling():
def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_pool, qdtype):
def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_pool, qdtype, convention):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest we keep this test to test the default behavior without passing in convention. Add another test to pass in convention as an extra argument.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding another check function such as check_quantized_pooling_with_convention would duplicate a lot of code here. Do you think we could give the convention argument a default value instead?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's okay to keep only this test method, but, as @yuxihu mentioned, you may want to test the backward compatibility of the operator without passing this argument.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the default convention before this PR? By adding a new argument, you are testing "valid" and "full". Has the default case been covered with your change?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default convention for mxnet pooling is "valid".
https://github.com/apache/incubator-mxnet/blob/master/src/operator/nn/pooling-inl.h#L74

I will set the default value of the "convention" argument to "valid" here and revert the change to L264-L267, so the behavior of these four test cases will remain as before.

if is_test_for_native_cpu():
print('skipped testing quantized_pooling for native cpu since it is not supported yet')
return
Expand All @@ -224,7 +224,8 @@ def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_p

data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')
pooling_fp32 = mx.sym.Pooling(data=data, kernel=kernel, pad=pad, stride=stride,
pool_type=pool_type, global_pool=global_pool, cudnn_off=False)
pool_type=pool_type, global_pool=global_pool, cudnn_off=False,
pooling_convention=convention)
arg_shapes, _, _ = pooling_fp32.infer_shape(data=data_shape)
arg_names = pooling_fp32.list_arguments()
pooling_fp32_exe = pooling_fp32.simple_bind(ctx=mx.current_context(), grad_req='null')
Expand All @@ -244,7 +245,8 @@ def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_p
quantized_pooling = mx.sym.contrib.quantized_pooling(data=qdata, min_data=min_data,
max_data=max_data, kernel=kernel,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It may also be good to fix the alignment of L246-L249.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

pad=pad, stride=stride, pool_type=pool_type,
global_pool=global_pool)
global_pool=global_pool,
pooling_convention=convention)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same question here: what was the default value for pooling_convention? Make sure the current case is covered.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pooling_int8_exe = quantized_pooling.simple_bind(ctx=mx.current_context(), grad_req='null')
qarg_names = quantized_pooling.list_arguments()
pooling_int8_exe.arg_dict[qarg_names[0]][:] = pooling_fp32_exe.arg_dict[arg_names[0]].astype(qdtype)
Expand All @@ -261,10 +263,17 @@ def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_p
assert cond == 0

for qdtype in ['int8', 'uint8']:
check_quantized_pooling((3, 4, 56, 56), (3, 3), 'max', (0, 0), (2, 2), False, qdtype)
check_quantized_pooling((3, 4, 56, 56), (3, 3), 'max', (0, 0), (2, 2), True, qdtype)
check_quantized_pooling((3, 512, 7, 7), (7, 7), 'avg', (0, 0), (1, 1), False, qdtype)
check_quantized_pooling((3, 512, 7, 7), (7, 7), 'avg', (0, 0), (1, 1), True, qdtype)
check_quantized_pooling((3, 4, 56, 56), (3, 3), 'max', (0, 0), (2, 2), False, qdtype, 'valid')
check_quantized_pooling((3, 4, 56, 56), (3, 3), 'max', (0, 0), (2, 2), True, qdtype , 'valid')
check_quantized_pooling((3, 512, 7, 7), (7, 7), 'avg', (0, 0), (1, 1), False, qdtype, 'valid')
check_quantized_pooling((3, 512, 7, 7), (7, 7), 'avg', (0, 0), (1, 1), True, qdtype, 'valid')

check_quantized_pooling((3, 4, 56, 56), (3, 3), 'max', (0, 0), (2, 2), False, qdtype, 'full')
check_quantized_pooling((3, 4, 56, 56), (3, 3), 'max', (0, 0), (2, 2), True, qdtype, 'full')
check_quantized_pooling((3, 512, 7, 7), (7, 7), 'avg', (0, 0), (1, 1), False, qdtype, 'full')
check_quantized_pooling((3, 512, 7, 7), (7, 7), 'avg', (0, 0), (1, 1), True, qdtype, 'full')



@with_seed()
def test_quantized_fc():
Expand Down