Skip to content

Commit 1dfd857

Browse files
author
joanna.wozna.intel
authored
Fix format in requantize mkldnn op (#34137)
1 parent 9bc5967 commit 1dfd857

2 files changed

Lines changed: 21 additions & 6 deletions

File tree

paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,10 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
8484
auto src_dt = framework::ToMKLDNNDataType(input->type());
8585
auto dst_dt = with_shift ? framework::MKLDNNDataType::u8 : src_dt;
8686

87-
auto src_md =
88-
platform::MKLDNNMemDesc({src_tz}, src_dt, MKLDNNMemoryFormat::nhwc);
87+
auto src_md = platform::MKLDNNMemDesc({src_tz}, src_dt, input->format());
8988
src_memory = std::make_shared<dnnl::memory>(src_md, engine,
9089
to_void_cast<T>(input_data));
91-
auto dst_md =
92-
platform::MKLDNNMemDesc({dst_tz}, dst_dt, MKLDNNMemoryFormat::nhwc);
90+
auto dst_md = platform::MKLDNNMemDesc({dst_tz}, dst_dt, input->format());
9391

9492
dnnl::primitive_attr attri;
9593
int mask = 0;

python/paddle/fluid/tests/unittests/mkldnn/test_requantize_mkldnn_op.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,18 @@
2323

2424

2525
class TestReQuantizeOp(OpTest):
26+
def set_input_size(self):
27+
self.input_size = [1, 1, 10, 10]
28+
self.format_reorder = format_reorder
29+
2630
def setUp(self):
2731
self.op_type = 'requantize'
2832
self.scale_in = 127.0
2933
self.shift_in = 0.0
3034
self.scale_out = 100.0
3135
self.shift_out = 0.0
32-
self.input_size = [1, 1, 10, 10]
3336
self.input_data_type = 'int8'
37+
self.set_input_size()
3438
self.set_scales()
3539
self.set_shifts()
3640
self.set_input_data_type()
@@ -76,7 +80,7 @@ def prepare_output(self):
7680
np.rint(self.input.astype('float32') * scale_ratio + new_shift),
7781
type_min, type_max).astype(dst_type)
7882

79-
self.output = format_reorder(output_tmp, self.input_size)
83+
self.output = self.format_reorder(output_tmp, self.input_size)
8084
self.outputs = {'Output': self.output}
8185

8286
def test_check_output(self):
@@ -266,6 +270,18 @@ def set_shifts(self):
266270
self.shift_out = 128.0
267271

268272

273+
# ---------------test non-four dimentional formats--------------------------
274+
275+
276+
class TestReQuantizeOp_2DimFormat(TestReQuantizeOp):
277+
def format_reorder_2Dim(self, out, size):
278+
return out
279+
280+
def set_input_size(self):
281+
self.input_size = [10, 20]
282+
self.format_reorder = self.format_reorder_2Dim
283+
284+
269285
# ---------------test reused requantize op, no shift------------------------
270286

271287

@@ -274,6 +290,7 @@ def setUp(self):
274290
# self.input_size = [1, 1, 10, 10]
275291
self.input_size = [1, 1, 2, 2]
276292
self.input_data_type = 'int8'
293+
self.format_reorder = format_reorder
277294
self.set_scales()
278295
self.set_shifts()
279296
self.set_input_data_type()

0 commit comments

Comments
 (0)