Commit 7fa3e2d
[Hackathon 6th No.52] Move the quantize and dequantize ops to phi (#64494)
1 parent 3375cfb commit 7fa3e2d

File tree

6 files changed: +163, -138 lines changed

paddle/fluid/operators/onednn/quantize_onednn_op.cc

Lines changed: 0 additions & 129 deletions
This file was deleted.

paddle/fluid/pir/dialect/operator/utils/utils.cc

Lines changed: 0 additions & 1 deletion
@@ -84,7 +84,6 @@ const std::unordered_set<std::string> LegacyOpList = {
 #ifdef PADDLE_WITH_DNNL
     paddle::onednn::dialect::LrnOp::name(),
     paddle::onednn::dialect::LrnGradOp::name(),
-    paddle::onednn::dialect::QuantizeOp::name(),
     paddle::onednn::dialect::MultiGruOp::name(),
     paddle::onednn::dialect::FusionLstmOp::name(),
 #endif

paddle/phi/kernels/onednn/quantize_kernel.cc

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/quantize_kernel.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/expect.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"

namespace phi {

using dnnl::memory;

template <typename T, typename Context>
void QuantOpKernel(const Context& dev_ctx,
                   const DenseTensor& input,
                   bool is_negative_input,
                   const float scale,
                   const float shift,
                   const std::string& output_format,
                   bool bfloat16,
                   DenseTensor* output) {
  const auto quantization_shift = static_cast<int32_t>(shift);
  const bool with_scale = scale != 1.0f;
  const bool with_shift = quantization_shift != 0;

  PADDLE_ENFORCE_NE(scale,
                    0.0f,
                    phi::errors::InvalidArgument(
                        "Quantization scale must be different than 0.0f"));
  PADDLE_ENFORCE(quantization_shift <= 255 && quantization_shift >= 0,
                 phi::errors::InvalidArgument(
                     "Quantization shift must be lower or equal to "
                     "255 and greater or equal to 0, but got %d",
                     quantization_shift));

  auto x_tz = common::vectorize<int64_t>(input.dims());
  dnnl::primitive_attr attrs;
  static constexpr int32_t mask = 0;

  if (with_scale) {
    attrs.set_scales_mask(DNNL_ARG_SRC, mask);
  }

  if (with_shift) {
    attrs.set_zero_points_mask(DNNL_ARG_DST, mask);
  }

  auto x_type = phi::funcs::ToOneDNNDataType(input.dtype());
  DataType out_dtype;

  // int8 output only when the input may be negative and no shift is used;
  // a shift re-centers the range, so uint8 is required.
  if (bfloat16) {
    out_dtype = DataType::BFLOAT16;
  } else if (is_negative_input && !with_shift) {
    out_dtype = DataType::INT8;
  } else {
    out_dtype = DataType::UINT8;
  }

  auto out_type = phi::funcs::ToOneDNNDataType(out_dtype);

  phi::funcs::ReorderOneDNNHandler reorder_handler(
      x_tz, input.dtype(), x_type, out_dtype, out_type, dev_ctx.GetEngine());

  auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
      input.mem_desc(), phi::funcs::to_void_cast(input.data<T>()));
  auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
      output, input.mem_desc(), dev_ctx.GetPlace());

  auto reorder_p = reorder_handler.AcquireReorder(
      reorder_dst_memory_p, reorder_src_memory_p, attrs);

  auto& astream = phi::OneDNNContext::tls().get_stream();

  // The runtime scale and zero point are passed to the reorder primitive as
  // 1-element oneDNN memory objects.
  auto scales_md = dnnl::memory::desc(
      {1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
  auto scales_mem = dnnl::memory(
      scales_md, dev_ctx.GetEngine(), phi::funcs::to_void_cast<float>(&scale));
  auto zero_points_md = dnnl::memory::desc(
      {1}, dnnl::memory::data_type::s32, dnnl::memory::format_tag::x);
  auto zero_points_mem =
      dnnl::memory(zero_points_md,
                   dev_ctx.GetEngine(),
                   phi::funcs::to_void_cast<int32_t>(&quantization_shift));

  std::unordered_map<int, dnnl::memory> reorder_args;
  reorder_args.insert({DNNL_ARG_SRC, *reorder_src_memory_p});
  reorder_args.insert({DNNL_ARG_DST, *reorder_dst_memory_p});
  if (with_scale) {
    reorder_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, scales_mem});
  }
  if (with_shift) {
    reorder_args.insert(
        {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, zero_points_mem});
  }

  reorder_p->execute(astream, reorder_args);
  astream.wait();

  output->set_mem_desc(reorder_dst_memory_p->get_desc());
}
}  // namespace phi

PD_REGISTER_KERNEL(quantize, OneDNN, ONEDNN, phi::QuantOpKernel, float) {}
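
Note: the kernel above is a single oneDNN reorder whose attributes carry the runtime scale and zero point, i.e. it computes q = round(x * scale + shift) and emits int8 only for negative input with no shift. A minimal NumPy sketch of that dtype selection and rounding (illustrative only; quantize_reference is not part of this commit and mirrors prepare_output() in the updated test):

```python
import numpy as np

def quantize_reference(x, scale, shift, is_negative_input, bfloat16=False):
    """NumPy model of the reorder performed by phi::QuantOpKernel."""
    assert scale != 0.0, "quantization scale must be non-zero"
    assert 0 <= int(shift) <= 255, "quantization shift must be in [0, 255]"
    if bfloat16:
        raise NotImplementedError("NumPy has no native bfloat16 type")
    # int8 output only for negative input with no shift, mirroring the
    # out_dtype selection in the kernel.
    out_dtype = np.int8 if (is_negative_input and int(shift) == 0) else np.uint8
    return np.rint(x * scale + shift).astype(out_dtype)

x = 2 * np.random.random_sample((1, 1, 5, 5)) - 1  # values in [-1.0, 1.0)
q = quantize_reference(x, scale=127.0, shift=0.0, is_negative_input=True)
```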

paddle/phi/kernels/quantize_kernel.h

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void QuantOpKernel(const Context& dev_ctx,
                   const DenseTensor& input,
                   bool is_negative_input,
                   const float scale,
                   const float shift,
                   const std::string& output_format,
                   bool bfloat16,
                   DenseTensor* output);

}  // namespace phi

test/mkldnn/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
@@ -20,6 +20,9 @@ if(WITH_ONEDNN AND NOT WIN32)
   py_test_modules(
     test_dequantize_mkldnn_op_static_build MODULES test_dequantize_mkldnn_op
     ENVS FLAGS_new_executor_static_build=true)
+  py_test_modules(
+    test_quantize_mkldnn_op_static_build MODULES test_quantize_mkldnn_op ENVS
+    FLAGS_new_executor_static_build=true)
 endif()

 set_tests_properties(test_concat_mkldnn_op PROPERTIES TIMEOUT 120)
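
The ENVS clause maps to an environment variable that Paddle reads at process start-up, so the new static-build variant can also be run by hand. A sketch of an equivalent manual invocation (module name and path taken from this diff; running through CTest is the supported route):

```python
# Hypothetical manual run of the static-build variant; py_test_modules
# normally wires this up through CTest. FLAGS_* environment variables
# are picked up by Paddle when the interpreter starts.
import os
import subprocess

env = dict(os.environ, FLAGS_new_executor_static_build="true")
subprocess.run(
    ["python", "-m", "unittest", "test_quantize_mkldnn_op"],
    cwd="test/mkldnn",  # assumed to be run from the repository root
    env=env,
    check=True,
)
```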

test/mkldnn/test_quantize_mkldnn_op.py

Lines changed: 13 additions & 8 deletions
@@ -26,8 +26,9 @@ def setUp(self):
         self.scale = 255.0
         self.shift = 0.0
         self.input_size = [1, 1, 5, 5]  # Naive nChw16c
-        self.is_negative = False
+        self.is_negative_input = False
         self.output_format = 'NCHW'
+        self.bfloat16 = False
         self.set_scale()
         self.set_shift()
         self.set_is_negative()
@@ -37,7 +38,7 @@ def setUp(self):
         self.prepare_output()

     def prepare_input(self):
-        if self.is_negative:
+        if self.is_negative_input:
             # input data values are from interval [-1.0, 1.0)
             self.input = (
                 2 * np.random.random_sample(self.input_size) - 1
@@ -50,14 +51,18 @@ def prepare_input(self):

         self.inputs = {'Input': OpTest.np_dtype_to_base_dtype(self.input)}
         self.attrs = {
+            'is_negative_input': self.is_negative_input,
             'Scale': self.scale,
             'Shift': self.shift,
-            'is_negative_input': self.is_negative,
             'output_format': self.output_format,
+            'bfloat16': self.bfloat16,
         }

     def prepare_output(self):
-        input_data_type = 'int8' if self.is_negative else 'uint8'
+        if self.is_negative_input and self.shift == 0.0:
+            input_data_type = 'int8'
+        else:
+            input_data_type = 'uint8'
         output = np.rint(self.input * self.scale + self.shift).astype(
             input_data_type
         )
@@ -97,15 +102,15 @@ def set_scale(self):
         self.scale = 127.0

     def set_is_negative(self):
-        self.is_nagative = True
+        self.is_negative_input = True


 class TestQuantizeOp2(TestQuantizeOp):
     def set_scale(self):
         self.scale = 255.0

     def set_is_negative(self):
-        self.is_nagative = False
+        self.is_negative_input = False


 # 2-dim input
@@ -115,7 +120,7 @@ def set_output_format(self):
         self.output_format = 'NCHW'

     def set_is_negative(self):
-        self.is_nagative = False
+        self.is_negative_input = False

     def set_scale(self):
         self.scale = 255.0
@@ -131,7 +136,7 @@ def set_input_size(self):
 # N - negative input
 class TestQuantizeOpShift_NCHW_2_N(TestQuantizeOpShift_NCHW_2_P):
     def set_is_negative(self):
-        self.is_nagative = True
+        self.is_negative_input = True

     def set_scale(self):
         self.scale = 127.0
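
With the attribute renamed consistently (the old spelling `is_nagative` was a typo that silently created a dead attribute), new cases follow the same override pattern. A hypothetical subclass (not part of this commit; class name and values are illustrative) that would exercise the shifted-uint8 path introduced in prepare_output():

```python
# Hypothetical test case in the style of the classes above. With a
# non-zero shift the kernel quantizes to uint8 even for negative input,
# matching the new out_dtype selection in phi::QuantOpKernel.
class TestQuantizeOpShiftedNegativeInput(TestQuantizeOp):
    def set_is_negative(self):
        self.is_negative_input = True

    def set_scale(self):
        self.scale = 127.0

    def set_shift(self):
        self.shift = 128.0  # forces the uint8 branch despite negative input
```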
