Skip to content

Commit 978d289

Browse files
committed
[xpu] multi_encoder_xpu support smooth quant, skip quant and local quant
1 parent 4f1bffe commit 978d289

12 files changed

Lines changed: 510 additions & 168 deletions

File tree

paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc

Lines changed: 267 additions & 82 deletions
Large diffs are not rendered by default.

paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ struct PatternParam {
128128
bool norm_before;
129129
bool with_q_scale;
130130
bool with_mask;
131+
bool is_smooth_quant;
131132
};
132133

133134
class MultiEncoderXPUFusePass : public FusePassBase {
@@ -142,7 +143,8 @@ class MultiEncoderXPUFusePass : public FusePassBase {
142143
const std::string& matmul_type_2,
143144
bool norm_before,
144145
bool with_q_scale,
145-
bool with_mask) const;
146+
bool with_mask,
147+
bool is_smooth_qunat) const;
146148

147149
bool ApplyMultiEncoderXPUFuse(ir::Graph* graph) const;
148150

@@ -152,7 +154,7 @@ class MultiEncoderXPUFusePass : public FusePassBase {
152154
// 1. Transpose q_w, k_w, v_w
153155
// 2. Concat q_w, k_w, v_w
154156
// 3. Generate qkv_w_max tensor
155-
// 4. Quant qkv_w to int16
157+
// 4. Quant qkv_w to int16/int8 or cast to float16 (local quant)
156158
void PrepareQKVWeight(
157159
Graph* graph,
158160
Scope* scope,
@@ -161,6 +163,7 @@ class MultiEncoderXPUFusePass : public FusePassBase {
161163
Node* k_w,
162164
Node* v_w,
163165
bool enable_int8,
166+
bool local_quant,
164167
std::unordered_map<std::string, std::vector<float>>* var_quant_scales,
165168
Node** qkv_w,
166169
Node** qkv_w_max,
@@ -171,7 +174,9 @@ class MultiEncoderXPUFusePass : public FusePassBase {
171174
BlockDesc* block,
172175
std::unordered_map<std::string, std::vector<Node*>>* node_maps,
173176
std::unordered_map<std::string, std::vector<float>>* var_quant_scales,
174-
std::vector<Node*>* input_max_nodes) const;
177+
std::vector<Node*>* input_max_nodes,
178+
std::vector<std::string>* quant_types,
179+
const std::string* act_type) const;
175180

176181
// 1. Cast bias to fp32
177182
// 2. Concat q/k/v bias

paddle/fluid/framework/ir/xpu/pass_utils.cc

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
1616
#include "paddle/fluid/platform/enforce.h"
17+
#include "paddle/phi/kernels/cast_kernel.h"
1718
#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"
1819

1920
namespace paddle {
@@ -123,11 +124,68 @@ template size_t HashTensor<int16_t>(const phi::DenseTensor& in);
123124
template size_t HashTensor<float>(const phi::DenseTensor& in);
124125
template size_t HashTensor<int8_t>(const phi::DenseTensor& in);
125126

127+
template <>
128+
size_t HashTensor<float16>(const phi::DenseTensor& in) {
129+
phi::DenseTensor dst_tensor;
130+
auto* cpu_ctx = static_cast<phi::CPUContext*>(
131+
platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
132+
dst_tensor.Resize(in.dims());
133+
dst_tensor.set_type(phi::DataType::FLOAT32);
134+
dst_tensor.set_layout(in.layout());
135+
phi::CastKernel<float16>(*cpu_ctx, in, phi::DataType::FLOAT32, &dst_tensor);
136+
return HashTensor<float>(dst_tensor);
137+
}
138+
126139
std::string GetPrefixWithoutHash(const std::string& name) {
127140
std::size_t found = name.find("_#");
128141
return found == std::string::npos ? name : name.substr(0, found);
129142
}
130143

144+
void ConvertFromFp32ToFp16(phi::DenseTensor* weight,
145+
phi::DenseTensor* weight_max,
146+
bool transpose) {
147+
// Convert fp16 to fp32
148+
phi::DenseTensor weight_fp32;
149+
CastToFp32(weight, &weight_fp32);
150+
151+
if (transpose) { // (k, n) -> (n, k)
152+
Transpose2D(&weight_fp32);
153+
}
154+
155+
auto FindMaxAbs = [](const float* data, int len) {
156+
float max_f = 0.0f;
157+
for (int i = 0; i < len; ++i) {
158+
float max = std::abs(data[i]);
159+
if (max > max_f) {
160+
max_f = max;
161+
}
162+
}
163+
return max_f;
164+
};
165+
166+
auto* cpu_ctx = static_cast<phi::CPUContext*>(
167+
platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
168+
// Convert to fp16
169+
phi::DenseTensor weight_fp16;
170+
CastToFp16(&weight_fp32, &weight_fp16);
171+
// Find max
172+
int max_ptr_size = phi::backends::xpu::get_xpu_max_ptr_size(-1);
173+
int size = weight_fp32.numel();
174+
float max_val = FindMaxAbs(weight_fp32.data<float>(), size);
175+
std::vector<float> max_vec(max_ptr_size, max_val);
176+
weight_max->set_type(phi::DataType::FLOAT32);
177+
weight_max->Resize({max_ptr_size});
178+
memcpy(cpu_ctx->Alloc<float>(weight_max),
179+
max_vec.data(),
180+
max_ptr_size * sizeof(float));
181+
weight->clear();
182+
weight->set_type(phi::DataType::FLOAT16);
183+
weight->Resize({size});
184+
memcpy(cpu_ctx->Alloc<float16>(weight),
185+
weight_fp16.data<float16>(),
186+
size * sizeof(float16));
187+
}
188+
131189
template <typename Tcpu, typename Txpu>
132190
void PrepareWeight(Graph* graph,
133191
Scope* scope,
@@ -268,6 +326,18 @@ template void PrepareWeight<float, float>(
268326
const std::vector<float>& weight_scales,
269327
bool per_channel_quant = false);
270328

329+
template void PrepareWeight<float, float16>(
330+
Graph* graph,
331+
Scope* scope,
332+
BlockDesc* block,
333+
Node* weight,
334+
Node** dst_weight,
335+
Node** dst_weight_max,
336+
Node** dst_scale_max,
337+
bool transpose,
338+
const std::vector<float>& weight_scales,
339+
bool per_channel_quant = false);
340+
271341
template void PrepareWeight<float, int16_t>(
272342
Graph* graph,
273343
Scope* scope,

paddle/fluid/framework/ir/xpu/pass_utils.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ std::vector<Node*> FindOpNodeByInputName(Graph* graph,
5757
template <typename T>
5858
size_t HashTensor(const phi::DenseTensor& in);
5959

60+
void ConvertFromFp32ToFp16(phi::DenseTensor* weight,
61+
phi::DenseTensor* weight_max,
62+
bool transpose);
63+
6064
template <typename Tcpu,
6165
typename Txpu,
6266
typename std::enable_if<!std::is_same<Tcpu, Txpu>::value, Tcpu>::type*
@@ -67,8 +71,12 @@ void ConvertWeightWrapper(phi::DenseTensor* weight,
6771
bool transpose,
6872
const std::vector<float>& weight_scales,
6973
bool per_channel_quant) {
70-
ConvertWithQuant<Tcpu, Txpu>(
71-
weight, weight_max, scale_max, transpose, per_channel_quant);
74+
if (std::is_same<Tcpu, float>::value && std::is_same<Txpu, float16>::value) {
75+
ConvertFromFp32ToFp16(weight, weight_max, transpose);
76+
} else {
77+
ConvertWithQuant<Tcpu, Txpu>(
78+
weight, weight_max, scale_max, transpose, per_channel_quant);
79+
}
7280
}
7381

7482
template <typename Tcpu,

paddle/fluid/framework/ir/xpu/quant_dequant_xpu_pass.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,6 @@ void QuantDequantXPUPass::CollectInputScalesFromQuantize(
191191
if (out->Name() == out_var_name) {
192192
for (auto* var : out->outputs) {
193193
auto op_desc = var->Op();
194-
std::string quantized_op_type = op_desc->Type();
195194
op_desc->SetAttr("enable_int8", true);
196195
op_desc->Flush();
197196
}

paddle/fluid/framework/ir/xpu/quant_utils.cc

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -115,41 +115,47 @@ void CastToInt32(phi::DenseTensor* in, phi::DenseTensor* out) {
115115
Assign(*out_ptr, in);
116116
}
117117
}
118-
119-
void CastToFp32(phi::DenseTensor* in, phi::DenseTensor* out) {
118+
void CastTo(phi::DenseTensor* in, phi::DenseTensor* out, DataType out_dtype) {
120119
auto* cpu_ctx = static_cast<phi::CPUContext*>(
121120
platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
122121

123-
paddle::experimental::CheckAndTrans2Contiguous(in);
122+
if (in->dtype() != phi::DataType::FLOAT16 &&
123+
in->dtype() != phi::DataType::FLOAT32) {
124+
PADDLE_THROW(platform::errors::InvalidArgument(
125+
"Only support fp16 and fp32, but received dtype is %s.",
126+
phi::DataTypeToString(in->dtype())));
127+
}
124128

125-
phi::DenseTensor fp32_tensor;
126-
phi::DenseTensor* out_ptr = out == nullptr ? &fp32_tensor : out;
129+
paddle::experimental::CheckAndTrans2Contiguous(in);
130+
phi::DenseTensor ori_tensor;
131+
phi::DenseTensor* out_ptr = out == nullptr ? &ori_tensor : out;
127132
out_ptr->Resize(in->dims());
128-
out_ptr->set_type(phi::DataType::FLOAT32);
133+
out_ptr->set_type(out_dtype);
129134
out_ptr->set_layout(in->layout());
130-
131-
switch (in->dtype()) {
132-
case phi::DataType::FLOAT16:
133-
phi::CastKernel<phi::dtype::float16>(
134-
*cpu_ctx, *in, phi::DataType::FLOAT32, out_ptr);
135-
break;
136-
case phi::DataType::FLOAT32:
137-
if (out == nullptr) {
138-
return;
139-
} else {
140-
phi::AssignKernel(*cpu_ctx, *in, out_ptr);
141-
}
142-
break;
143-
default:
144-
PADDLE_THROW(platform::errors::InvalidArgument(
145-
"Only support fp16 and fp32, but received dtype is %s.",
146-
phi::DataTypeToString(in->dtype())));
147-
break;
135+
if (in->dtype() == out_dtype) {
136+
if (out == nullptr) {
137+
return;
138+
} else {
139+
phi::AssignKernel(*cpu_ctx, *in, out_ptr);
140+
}
141+
} else {
142+
if (in->dtype() == phi::DataType::FLOAT16) {
143+
phi::CastKernel<float16>(*cpu_ctx, *in, out_dtype, out_ptr);
144+
} else {
145+
phi::CastKernel<float>(*cpu_ctx, *in, out_dtype, out_ptr);
146+
}
147+
if (out == nullptr) {
148+
Assign(*out_ptr, in);
149+
}
148150
}
151+
}
149152

150-
if (out == nullptr) {
151-
Assign(*out_ptr, in);
152-
}
153+
void CastToFp32(phi::DenseTensor* in, phi::DenseTensor* out) {
154+
CastTo(in, out, phi::DataType::FLOAT32);
155+
}
156+
157+
void CastToFp16(phi::DenseTensor* in, phi::DenseTensor* out) {
158+
CastTo(in, out, phi::DataType::FLOAT16);
153159
}
154160

155161
static float FindMaxAbs(const float* data, int len) {

paddle/fluid/framework/ir/xpu/quant_utils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,12 @@ void Assign(const phi::DenseTensor& in, phi::DenseTensor* out);
2323

2424
void Transpose2D(phi::DenseTensor* in, phi::DenseTensor* out = nullptr);
2525

26+
void CastTo(phi::DenseTensor* in, phi::DenseTensor* out, DataType dtype);
27+
2628
void CastToFp32(phi::DenseTensor* in, phi::DenseTensor* out = nullptr);
2729

30+
void CastToFp16(phi::DenseTensor* in, phi::DenseTensor* out = nullptr);
31+
2832
void CastToInt32(phi::DenseTensor* in, phi::DenseTensor* out = nullptr);
2933

3034
template <typename T>

paddle/fluid/inference/api/paddle_pass_builder.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ void CpuPassStrategy::EraseFcMkldnnPasses() {
521521

522522
XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
523523
passes_.assign({
524-
"quant_dequant_xpu_pass",
524+
// "quant_dequant_xpu_pass", open this pass when use old int8 model
525525
"delete_quant_dequant_linear_op_pass",
526526
"delete_weight_dequant_linear_op_pass",
527527
"delete_assign_op_pass",

paddle/phi/api/yaml/fused_ops.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@
399399
backward : max_pool2d_v2_grad
400400

401401
- op : multi_encoder_xpu
402-
args : (Tensor x, Tensor[] fc_input_max, Tensor[] fc_weight, Tensor[] fc_weight_max, Tensor[] fc_bias, Tensor[] ln_scale, Tensor[] ln_bias, Tensor mask, Tensor seq_lod, Tensor max_seq_len, int layer_num, bool norm_before, int hidden_dim, int head_num, int size_per_head, int ffn_hidden_dim_scale, int act_type, int relative_type, int slice_idx, bool is_per_channel)
402+
args : (Tensor x, Tensor[] fc_input_max, Tensor[] fc_weight, Tensor[] fc_weight_max, Tensor[] fc_bias, Tensor[] ln_scale, Tensor[] ln_bias, Tensor[] smooth_scale_weight, Tensor mask, Tensor seq_lod, Tensor max_seq_len, int layer_num, bool norm_before, int hidden_dim, int head_num, int size_per_head, int ffn_hidden_dim_scale, int act_type, int relative_type, int slice_idx, bool is_per_channel, float[] softmax_max_value, str[] quant_types)
403403
output : Tensor(out), Tensor(x_fp16), Tensor(out_fp16)
404404
infer_meta :
405405
func : MultiEncoderXPUInferMeta

paddle/phi/infermeta/fusion.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1446,6 +1446,7 @@ void MultiEncoderXPUInferMeta(
14461446
const std::vector<const MetaTensor*>& fc_bias,
14471447
const std::vector<const MetaTensor*>& ln_scale,
14481448
const std::vector<const MetaTensor*>& ln_bias,
1449+
const std::vector<const MetaTensor*>& smooth_scale_weight,
14491450
const MetaTensor& mask,
14501451
const MetaTensor& seq_lod,
14511452
const MetaTensor& max_seq_len,
@@ -1459,6 +1460,8 @@ void MultiEncoderXPUInferMeta(
14591460
int relative_type,
14601461
int slice_idx,
14611462
bool is_per_channel,
1463+
const std::vector<float>& softmax_max_value,
1464+
const std::vector<std::string>& quant_types,
14621465
MetaTensor* out,
14631466
MetaTensor* x_fp16,
14641467
MetaTensor* out_fp16) {

0 commit comments

Comments
 (0)