-
Notifications
You must be signed in to change notification settings - Fork 6.7k
Support full convention in quantized pooling #13260
Changes from 6 commits
9aebcec
d2c8787
5198063
867dbfa
fccd6c2
16899fb
fb10846
4c673e6
2abd701
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -214,7 +214,7 @@ def check_quantized_conv(data_shape, kernel, num_filter, pad, stride, no_bias, q | |
|
|
||
| @with_seed() | ||
| def test_quantized_pooling(): | ||
| def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_pool, qdtype): | ||
| def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_pool, qdtype, convention): | ||
|
||
| if is_test_for_native_cpu(): | ||
| print('skipped testing quantized_pooling for native cpu since it is not supported yet') | ||
| return | ||
|
|
@@ -224,7 +224,8 @@ def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_p | |
|
|
||
| data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32') | ||
| pooling_fp32 = mx.sym.Pooling(data=data, kernel=kernel, pad=pad, stride=stride, | ||
| pool_type=pool_type, global_pool=global_pool, cudnn_off=False) | ||
| pool_type=pool_type, global_pool=global_pool, cudnn_off=False, | ||
| pooling_convention=convention) | ||
| arg_shapes, _, _ = pooling_fp32.infer_shape(data=data_shape) | ||
| arg_names = pooling_fp32.list_arguments() | ||
| pooling_fp32_exe = pooling_fp32.simple_bind(ctx=mx.current_context(), grad_req='null') | ||
|
|
@@ -244,7 +245,8 @@ def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_p | |
| quantized_pooling = mx.sym.contrib.quantized_pooling(data=qdata, min_data=min_data, | ||
| max_data=max_data, kernel=kernel, | ||
|
||
| pad=pad, stride=stride, pool_type=pool_type, | ||
| global_pool=global_pool) | ||
| global_pool=global_pool, | ||
| pooling_convention=convention) | ||
|
||
| pooling_int8_exe = quantized_pooling.simple_bind(ctx=mx.current_context(), grad_req='null') | ||
| qarg_names = quantized_pooling.list_arguments() | ||
| pooling_int8_exe.arg_dict[qarg_names[0]][:] = pooling_fp32_exe.arg_dict[arg_names[0]].astype(qdtype) | ||
|
|
@@ -261,10 +263,17 @@ def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_p | |
| assert cond == 0 | ||
|
|
||
| for qdtype in ['int8', 'uint8']: | ||
| check_quantized_pooling((3, 4, 56, 56), (3, 3), 'max', (0, 0), (2, 2), False, qdtype) | ||
| check_quantized_pooling((3, 4, 56, 56), (3, 3), 'max', (0, 0), (2, 2), True, qdtype) | ||
| check_quantized_pooling((3, 512, 7, 7), (7, 7), 'avg', (0, 0), (1, 1), False, qdtype) | ||
| check_quantized_pooling((3, 512, 7, 7), (7, 7), 'avg', (0, 0), (1, 1), True, qdtype) | ||
| check_quantized_pooling((3, 4, 56, 56), (3, 3), 'max', (0, 0), (2, 2), False, qdtype, 'valid') | ||
| check_quantized_pooling((3, 4, 56, 56), (3, 3), 'max', (0, 0), (2, 2), True, qdtype , 'valid') | ||
| check_quantized_pooling((3, 512, 7, 7), (7, 7), 'avg', (0, 0), (1, 1), False, qdtype, 'valid') | ||
| check_quantized_pooling((3, 512, 7, 7), (7, 7), 'avg', (0, 0), (1, 1), True, qdtype, 'valid') | ||
|
|
||
| check_quantized_pooling((3, 4, 56, 56), (3, 3), 'max', (0, 0), (2, 2), False, qdtype, 'full') | ||
| check_quantized_pooling((3, 4, 56, 56), (3, 3), 'max', (0, 0), (2, 2), True, qdtype, 'full') | ||
| check_quantized_pooling((3, 512, 7, 7), (7, 7), 'avg', (0, 0), (1, 1), False, qdtype, 'full') | ||
| check_quantized_pooling((3, 512, 7, 7), (7, 7), 'avg', (0, 0), (1, 1), True, qdtype, 'full') | ||
|
|
||
|
|
||
|
|
||
| @with_seed() | ||
| def test_quantized_fc(): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suggest we keep this test to test the default behavior without passing in convention. Add another test to pass in convention as an extra argument.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Having another check function like
check_quantized_pooling_with_conventionmay need much redundant code here. Do you think we can give the argumentconventiona default value here?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's okay to only keep this test method. But As @yuxihu mentioned, you may want to test the backward compatibility of the operator without passing this argument.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure.