Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,9 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
Node* input0, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
Node* mul1_out, Node* mul2_out, Node* mul0_w, Node* mul1_w, Node* mul2_w,
Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b, Node* eltadd_qk_b,
Node* reshape2, Node* reshape2_qkv_out, Node* scale, Node* scale_out) {
Node* reshape2, Node* reshape2_qkv_out, Node* scale, Node* scale_out,
Node* softmax_qk, Node* eltadd0, Node* eltadd1, Node* eltadd2,
Node* matmul_qk) {
auto scale_attr = BOOST_GET_CONST(float, scale->Op()->GetAttr("scale"));

// mul (B * S * Hidden) x (Hidden * 3 * N * H) = (B * S * 3 * N * H)
Expand Down Expand Up @@ -876,19 +878,35 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
weight_max = std::max(weight_max, weight_scale2);
multihead_op_desc.SetAttr("weight_scale", weight_max);

if (mul0_op_desc->HasAttr("out_threshold")) {
auto* add0_op_desc = eltadd0->Op();
auto* add1_op_desc = eltadd1->Op();
auto* add2_op_desc = eltadd2->Op();
if (add0_op_desc->HasAttr("out_threshold")) {
auto out_scale0 =
BOOST_GET_CONST(float, mul0_op_desc->GetAttr("out_threshold"));
BOOST_GET_CONST(float, add0_op_desc->GetAttr("out_threshold"));
auto out_scale1 =
BOOST_GET_CONST(float, mul1_op_desc->GetAttr("out_threshold"));
BOOST_GET_CONST(float, add1_op_desc->GetAttr("out_threshold"));
auto out_scale2 =
BOOST_GET_CONST(float, mul2_op_desc->GetAttr("out_threshold"));
BOOST_GET_CONST(float, add2_op_desc->GetAttr("out_threshold"));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

out_scale0、out_scale1、out_scale2分别是q,k,v的scale,将其最大值设置为 multihead_op的"out_threshold", multihead_op中包含FC和QKVToContextPlugin,这个"out_threshold"是用作FC的output scale.
问题是:能给"out_threshold"换个名字么?原因如下:

  1. 当前命名不太直观;容易被误以为multihead_op的output scale;
  2. "out_threshold"用于fusion pass向TRT op convert传递信息,所以这个命名不用受限于量化模型的存储格式。另外,也不受OP Maker中对attr的约束。

auto out_scale_max = std::max(out_scale0, out_scale1);
out_scale_max = std::max(out_scale_max, out_scale2);
multihead_op_desc.SetAttr("out_threshold", out_scale_max);
multihead_op_desc.SetAttr("fc_out_threshold", out_scale_max);
}
}

auto* softmax_qk_op_desc = softmax_qk->Op();
auto* matmul_qk_op_desc = matmul_qk->Op();
if (matmul_qk_op_desc->HasAttr("X_scale")) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果这个判断为False,line897~902还有必要么?

multihead_op_desc.SetAttr("qkv2context_plugin_int8", true);
if (softmax_qk_op_desc->HasAttr("out_threshold")) {
auto qkv_plugin_scale = BOOST_GET_CONST(
float, softmax_qk_op_desc->GetAttr("out_threshold"));
multihead_op_desc.SetAttr("dp_probs", qkv_plugin_scale);
}
} else {
multihead_op_desc.SetAttr("qkv2context_plugin_int8", false);
}

auto* multihead = graph->CreateOpNode(&multihead_op_desc);

IR_NODE_LINK_TO(input0, multihead);
Expand Down Expand Up @@ -990,7 +1008,8 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
}
fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w,
mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b,
reshape2_0, reshape2_qkv_out, scale, scale_out);
reshape2_0, reshape2_qkv_out, scale, scale_out, softmax_qk,
eltadd0, eltadd1, eltadd2, matmul_qk);

std::unordered_set<const Node*> marked_nodes({eltadd0,
eltadd1,
Expand Down
3 changes: 1 addition & 2 deletions paddle/fluid/inference/tensorrt/convert/fc_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,9 @@ class FcOpConverter : public OpConverter {
auto* fc_layer_int8 =
TRT_ENGINE_ADD_LAYER(engine_, Convolution, *inputs, n_output,
nv_ksize, weight.get(), bias.get());
engine_->SetTensorDynamicRange(fc_layer_int8->getOutput(0), out_scale);
auto* fc_after_reshape_int8 = reshape_after_fc(
fc_layer_int8->getOutput(0), x_dim, x_num_col_dims);
engine_->SetTensorDynamicRange(fc_after_reshape_int8->getOutput(0),
out_scale);
if (activation_type == "relu") {
nvinfer1::IActivationLayer* relu_layer_int8 = TRT_ENGINE_ADD_LAYER(
engine_, Activation, *(fc_after_reshape_int8->getOutput(0)),
Expand Down
26 changes: 20 additions & 6 deletions paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class MultiheadMatMulOpConverter : public OpConverter {

float* weight_data = nullptr;
bool enable_int8 = op_desc.HasAttr("enable_int8");
bool qkv2context_plugin_int8 =
BOOST_GET_CONST(bool, op_desc.GetAttr("qkv2context_plugin_int8"));
float in_scale = 0.;

if (enable_int8) {
Expand Down Expand Up @@ -147,13 +149,16 @@ class MultiheadMatMulOpConverter : public OpConverter {

if (enable_int8) {
PADDLE_ENFORCE_EQ(
op_desc.HasAttr("out_threshold"), true,
op_desc.HasAttr("fc_out_threshold"), true,
platform::errors::InvalidArgument(
"must have out threshold in multihead layers in int8 mode"));
float out_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
BOOST_GET_CONST(float, op_desc.GetAttr("fc_out_threshold"));
engine_->SetTensorDynamicRange(fc_layer->getOutput(0), out_scale);
dp_probs = out_scale / 127.0;
if (qkv2context_plugin_int8) {
dp_probs =
BOOST_GET_CONST(float, op_desc.GetAttr("dp_probs")) / 127.0;
}
}

auto mask_tensor = engine_->GetITensor("qkv_plugin_mask");
Expand All @@ -166,16 +171,25 @@ class MultiheadMatMulOpConverter : public OpConverter {
: nvinfer1::DataType::kFLOAT);
if (enable_int8) {
type = static_cast<int>(nvinfer1::DataType::kHALF);
if (qkv2context_plugin_int8) {
type = static_cast<int>(nvinfer1::DataType::kINT8);
}
}
bool has_mask = true;
int var_seqlen = 1;
const std::vector<nvinfer1::PluginField> fields{
std::vector<nvinfer1::PluginField> fields{
{"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1},
{"hidden_size", &hidden_out, nvinfer1::PluginFieldType::kINT32, 1},
{"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, 1},
{"has_mask", &has_mask, nvinfer1::PluginFieldType::kINT32, 1},
{"var_seqlen", &var_seqlen, nvinfer1::PluginFieldType::kINT32, 1},
{ "dq_probs", &dp_probs, nvinfer1::PluginFieldType::kFLOAT32, 1 }};
{ "var_seqlen",
&var_seqlen,
nvinfer1::PluginFieldType::kINT32,
1 }};
if (qkv2context_plugin_int8) {
fields.push_back(
{"dq_probs", &dp_probs, nvinfer1::PluginFieldType::kFLOAT32, 1});
}
nvinfer1::PluginFieldCollection* plugin_collection =
static_cast<nvinfer1::PluginFieldCollection*>(
malloc(sizeof(*plugin_collection) +
Expand Down