Skip to content

Commit b103868

Browse files
bukejiyu and Mangodadada
authored and committed
update cpu inference (PaddlePaddle#8984)
1 parent ac7c17f commit b103868

17 files changed

Lines changed: 2282 additions & 918 deletions

csrc/cpu/0001-fp16_bf16.patch

Lines changed: 619 additions & 0 deletions
Large diffs are not rendered by default.

csrc/cpu/0001-fp32.patch

Lines changed: 647 additions & 0 deletions
Large diffs are not rendered by default.

csrc/cpu/0001-patch-fp16-and-bf16.patch

Lines changed: 0 additions & 280 deletions
This file was deleted.

csrc/cpu/0001-patch-fp32.patch

Lines changed: 0 additions & 302 deletions
This file was deleted.

csrc/cpu/README.md

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
# cpu-custom-ops
22

33
## 快速开始
4-
# 构建 cpu 自定义算子库
5-
```
6-
$ 前提条件:机器支持avx指令
7-
$ bash setup.sh
4+
5+
### 1.环境准备
6+
```shell
7+
# 查询机器是否支持 avx512指令
8+
lscpu | grep avx512*
89
```
10+
11+
### 2.安装 cpu 自定义算子和第三方库
12+
```shell
13+
#建议在 gcc 9.4.0 下安装第三方库
14+
bash setup.sh

csrc/cpu/setup.sh

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,31 +12,35 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
#1. download XFT
15+
#0.环境准备:安装numactl
16+
# apt-get update
17+
# apt-get install numactl
18+
19+
# 1. download XFT
1620
if [ ! -d xFasterTransformer ]; then
17-
git clone --branch v1.7.2 https://github.com/intel/xFasterTransformer.git
21+
git clone https://github.com/intel/xFasterTransformer.git
1822
fi
1923

2024
#2.cp patch
2125
cd xFasterTransformer
22-
git checkout .
26+
git reset --hard 420a493f5c3c74f5fdd786f5399aacd04e021df7
2327
cd ..
2428

2529
if lscpu | grep -q "avx512_bf16"; then
2630
echo "apply bf16 and fp16."
27-
if [ ! -f 0001-patch-fp16-and-bf16.patch ]; then
28-
echo "Error: 0001-patch-fp16-and-bf16.patch not exist."
31+
if [ ! -f 0001-fp16_bf16.patch ]; then
32+
echo "Error: 0001-fp16_bf16.patch not exist."
2933
exit 1
3034
fi
3135
# apply patch
32-
cp ./0001-patch-fp16-and-bf16.patch ./xFasterTransformer/paddle.patch
36+
cp ./0001-fp16_bf16.patch ./xFasterTransformer/paddle.patch
3337
else
3438
echo "apply fp32 "
35-
if [ ! -f 0001-patch-fp32.patch ]; then
36-
echo "Error: does 0001-patch-fp32.patch not exist."
39+
if [ ! -f 0001-fp32.patch ]; then
40+
echo "Error: 0001-fp32.patch does not exist."
3741
exit 1
3842
fi
39-
cp ./0001-patch-fp32.patch ./xFasterTransformer/paddle.patch
43+
cp ./0001-fp32.patch ./xFasterTransformer/paddle.patch
4044
fi
4145

4246
#3. apply patch

csrc/cpu/src/avx_weight_only.cc

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
#include "dtype.h"
15+
#include "matmul_helper.h"
16+
#include "my_types.h"
17+
#include "paddle/extension.h"
18+
git adtemplate <typename T>
19+
void AvxCompute(const paddle::Tensor &x,
20+
const paddle::Tensor &weight,
21+
bool trans,
22+
const std::string alog,
23+
paddle::Tensor &out,
24+
xft::Matrix<T> &quantizedWeight,
25+
xft::Vector<float> &WeightScale,
26+
xft::Vector<float> &WeightZero,
27+
xft::Vector<float> &WeightSum,
28+
MMHelper *mmHelper) {
29+
auto out_data = out.data<float>();
30+
const float *x_data = reinterpret_cast<const float *>(x.data<float>());
31+
const float *bias_data = nullptr;
32+
int m = 1;
33+
for (int i = 0; i < x.shape().size() - 1; i++) {
34+
m = m * x.shape()[i];
35+
}
36+
int k = x.shape()[x.shape().size() - 1];
37+
int l = weight.shape()[1];
38+
int n = weight.shape()[1];
39+
40+
mmHelper->compute(false,
41+
m,
42+
n,
43+
k,
44+
1.0f,
45+
x_data,
46+
k,
47+
quantizedWeight.Data(),
48+
WeightScale.Data(),
49+
WeightZero.Data(),
50+
WeightSum.Data(),
51+
0.0,
52+
out_data,
53+
l);
54+
};
55+
template <typename T>
56+
void AvxWeightOnly(const paddle::Tensor &x,
57+
const paddle::Tensor &weight,
58+
bool trans,
59+
const std::string alog,
60+
paddle::Tensor &out) {
61+
static std::unordered_map<std::string,
62+
std::tuple<xft::Matrix<T> *,
63+
xft::Vector<float> *,
64+
xft::Vector<float> *,
65+
xft::Vector<float> *>>
66+
weight_only_hub;
67+
std::stringstream weights_addr;
68+
weights_addr << weight.data<float>() << alog;
69+
std::string weight_only_key = weights_addr.str();
70+
auto it_created = weight_only_hub.find(weight_only_key);
71+
static MMHelper *mmHelper;
72+
int rows = weight.shape()[0], cols = weight.shape()[1];
73+
xft::Vector<float> *WeightScale =
74+
new xft::Vector<float>(); // if weight is int8
75+
xft::Vector<float> *WeightZero =
76+
new xft::Vector<float>(); // if weight is int8
77+
xft::Vector<float> *WeightSum =
78+
new xft::Vector<float>(); // if weight is int8
79+
xft::Matrix<T> *quantizedWeight = new xft::Matrix<T>();
80+
if (it_created == weight_only_hub.end()) {
81+
auto weight_ptr = reinterpret_cast<const float *>(weight.data<float>());
82+
xft::Matrix<T> convertedWeight;
83+
mmHelper = new MMHelper(xft::DeviceKind::iCPU, 0);
84+
mmHelper->convertWeight(trans,
85+
rows,
86+
cols,
87+
weight_ptr,
88+
nullptr,
89+
nullptr,
90+
convertedWeight,
91+
*WeightScale,
92+
*WeightZero,
93+
*WeightSum);
94+
quantizedWeight->Resize(rows, cols);
95+
mmHelper->packWeight(trans, convertedWeight, *quantizedWeight);
96+
weight_only_hub[weight_only_key] =
97+
std::make_tuple(quantizedWeight, WeightScale, WeightZero, WeightSum);
98+
AvxCompute<T>(x,
99+
weight,
100+
trans,
101+
alog,
102+
out,
103+
*quantizedWeight,
104+
*WeightScale,
105+
*WeightZero,
106+
*WeightSum,
107+
mmHelper);
108+
} else {
109+
AvxCompute<T>(x,
110+
weight,
111+
trans,
112+
alog,
113+
out,
114+
*(std::get<0>(it_created->second)),
115+
*(std::get<1>(it_created->second)),
116+
*(std::get<2>(it_created->second)),
117+
*(std::get<3>(it_created->second)),
118+
mmHelper);
119+
}
120+
}
121+
// Kernel entry for the avx_weight_only op: allocates the output tensor and
// dispatches on the requested algorithm.
//
// alog  : "int8" selects int8 weight quantization; anything else (including
//         "fp16") uses fp16 weights.
// trans : whether the weight is transposed when packed.
// Returns a single output tensor shaped like x with its last dim replaced by
// weight.shape()[1].
std::vector<paddle::Tensor> InvokeAvxWeightOnly(const paddle::Tensor &x,
                                                const paddle::Tensor &weight,
                                                const std::string &alog,
                                                bool trans) {
  auto out_shape = x.shape();
  out_shape[out_shape.size() - 1] = weight.shape()[1];
  auto out = paddle::empty(out_shape, x.dtype(), paddle::CPUPlace());
  if (alog == "int8") {
    AvxWeightOnly<int8_t>(x, weight, trans, alog, out);
  } else {
    // "fp16" and any unrecognized tag both fall back to fp16 weights — the
    // original's "fp16" branch and else branch were identical duplicates.
    AvxWeightOnly<float16_t>(x, weight, trans, alog, out);
  }
  return {out};
}
// Infer the output shape of avx_weight_only: collapse all leading dims of
// x_shape into one row count M, and take the weight's second dim as N.
// Returns {{M, weigh_shape[1]}}.
std::vector<std::vector<int64_t>> AvxWeightOnlyInferShape(
    std::vector<int64_t> x_shape,
    std::vector<int64_t> weigh_shape) {
  int64_t m = 1;
  // `i + 1 < size()` avoids the original `int i < size() - 1`, which mixed
  // signed/unsigned and had size()-1 wrap to SIZE_MAX on an empty shape.
  for (size_t i = 0; i + 1 < x_shape.size(); i++) {
    m = m * x_shape[i];
  }
  return {std::vector<int64_t>{m, weigh_shape[1]}};
}
std::vector<paddle::DataType> AvxWeightOnlyInferDtype(
149+
paddle::DataType x_dtype,
150+
paddle::DataType weight_dtype) {
151+
return {x_dtype};
152+
}
153+
154+
// Register the avx_weight_only custom op with PaddlePaddle.
// Attrs: "alog" selects the weight algorithm tag (dispatched on "int8" in the
// kernel; other values use fp16 weights); "trans" marks whether the weight is
// transposed when packed.
PD_BUILD_OP(avx_weight_only)
    .Inputs({"x", "weight"})
    .Outputs({"out"})
    .Attrs({"alog: std::string", "trans:bool"})
    .SetKernelFn(PD_KERNEL(InvokeAvxWeightOnly))
    .SetInferShapeFn(PD_INFER_SHAPE(AvxWeightOnlyInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(AvxWeightOnlyInferDtype));

csrc/cpu/src/setup_cpu.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -53,19 +53,28 @@ def check_avx512_bf16__support():
5353
return False
5454

5555

56-
# cc flags
5756
paddle_extra_compile_args = [
5857
"-std=c++17",
5958
"-shared",
6059
"-fPIC",
6160
"-Wno-parentheses",
6261
"-DPADDLE_WITH_CUSTOM_KERNEL",
62+
"-mavx512f",
63+
"-mavx512vl",
64+
"-fopenmp",
65+
"-mavx512bw",
66+
"-mno-mmx",
67+
"-Wall",
68+
"-march=skylake-avx512",
69+
"-O3",
70+
"-g",
6371
]
6472

6573
if check_avx512_bf16__support():
6674
paddle_extra_compile_args += [
6775
"-DAVX512_BF16_WEIGHT_ONLY_BF16=true",
68-
"-DAVX512_BF16_WEIGHT_ONLY_BF16=true",
76+
"-DAVX512_FP16_WEIGHT_ONLY_INT8=true",
77+
"-DAVX512_FP16_WEIGHT_ONLY_FP16=true",
6978
]
7079
else:
7180
paddle_extra_compile_args += [
@@ -81,15 +90,17 @@ def check_avx512_bf16__support():
8190

8291
# include path third_party
8392
paddle_custom_kernel_include += [
84-
os.path.join(XFT_INCLUDE_DIR, "include"), # glog
85-
os.path.join(XFT_INCLUDE_DIR, "src/common"), # src
86-
os.path.join(XFT_INCLUDE_DIR, "src/kernel"), # src
87-
os.path.join(XFT_INCLUDE_DIR, "src/layers"), # src
88-
os.path.join(XFT_INCLUDE_DIR, "src/models"), # src
89-
os.path.join(XFT_INCLUDE_DIR, "src/utils"), # src
90-
os.path.join(XFT_INCLUDE_DIR, "3rdparty/onednn/include"), # src
91-
os.path.join(XFT_INCLUDE_DIR, "3rdparty/onednn/build/include"), # src
92-
os.path.join(XFT_INCLUDE_DIR, "3rdparty/xdnn"), # src
93+
os.path.join(XFT_INCLUDE_DIR, "include"),
94+
os.path.join(XFT_INCLUDE_DIR, "src/common"),
95+
os.path.join(XFT_INCLUDE_DIR, "src/kernel"),
96+
os.path.join(XFT_INCLUDE_DIR, "src/layers"),
97+
os.path.join(XFT_INCLUDE_DIR, "src/models"),
98+
os.path.join(XFT_INCLUDE_DIR, "src/utils"),
99+
os.path.join(XFT_INCLUDE_DIR, "3rdparty/onednn/include"),
100+
os.path.join(XFT_INCLUDE_DIR, "3rdparty/onednn/build/include"),
101+
os.path.join(XFT_INCLUDE_DIR, "3rdparty/xdnn"),
102+
os.path.join(XFT_INCLUDE_DIR, "3rdparty"),
103+
os.path.join(XFT_INCLUDE_DIR, "3rdparty/mkl/include"),
93104
]
94105

95106
# libs path
@@ -101,11 +112,13 @@ def check_avx512_bf16__support():
101112

102113
custom_kernel_dot_module = CppExtension(
103114
sources=[
104-
"./src/xft_llama_layer.cc",
105115
"../generation/save_with_output.cc",
106116
"./src/token_penalty_multi_scores.cc",
107117
"./src/stop_generation_multi_ends.cc",
108118
"./src/set_value_by_flags.cc",
119+
"./src/xft_transformer.cc",
120+
"./src/avx_weight_only.cc",
121+
"./src/xft_greedy_search.cc",
109122
],
110123
include_dirs=paddle_custom_kernel_include,
111124
library_dirs=paddle_custom_kernel_library_dir,

csrc/cpu/src/token_penalty_multi_scores.cc

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,20 @@
11
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2-
//
2+
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
55
// You may obtain a copy of the License at
6-
//
6+
//
77
// http://www.apache.org/licenses/LICENSE-2.0
8-
//
8+
//
99
// Unless required by applicable law or agreed to in writing, software
1010
// distributed under the License is distributed on an "AS IS" BASIS,
1111
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

1515
#include <paddle/extension.h>
16-
1716
#include <vector>
1817

19-
#include "paddle/phi/core/kernel_registry.h"
20-
2118
template <typename T>
2219
void min_length_logits_process(T* logits,
2320
const int64_t* cur_len,

0 commit comments

Comments
 (0)