Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion paddle/fluid/framework/ir/fc_gru_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include <string>

#include "paddle/fluid/framework/op_version_registry.h"

#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
class Scope;
Expand Down Expand Up @@ -335,6 +335,9 @@ void FCGRUFusePass::ApplyImpl(ir::Graph* graph) const {
graph, name_scope_, param_scope(), true /*with_fc_bias*/);

AddStatis(fusion_count);

string::PrettyLogDetail("--- fused %d pairs of fc gru patterns",
fusion_count);
}

} // namespace ir
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <string>

#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/string/pretty_log.h"

namespace paddle {
namespace framework {
Expand Down Expand Up @@ -348,6 +349,9 @@ void FCLstmFusePass::ApplyImpl(ir::Graph* graph) const {
BuildFusion(graph, name_scope_, param_scope(), true /*with_fc_bias*/);

AddStatis(fusion_count);

string::PrettyLogDetail("--- fused %d pairs of fc lstm patterns",
fusion_count);
}

} // namespace ir
Expand Down
85 changes: 85 additions & 0 deletions paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,19 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetInput("WeightX", {inputs[2]});
op->SetInput("WeightH", {inputs[3]});
op->SetOutput("Hidden", {outputs[0]});
op->SetAttr("mkldnn_data_type", mkldnn_data_type);
op->SetAttr("Scale_data", 1.0f);
op->SetAttr("Shift_data", 0.0f);
op->SetAttr("Weight_scale", std::vector<float>{1.0f});
} else if (type == "fusion_lstm") {
op->SetInput("X", {inputs[0]});
op->SetInput("Bias", {inputs[1]});
op->SetInput("WeightX", {inputs[2]});
op->SetInput("WeightH", {inputs[3]});

op->SetOutput("Hidden", {outputs[0]});
op->SetOutput("Cell", {outputs[1]});

op->SetAttr("mkldnn_data_type", mkldnn_data_type);
op->SetAttr("Scale_data", 1.0f);
op->SetAttr("Shift_data", 0.0f);
Expand Down Expand Up @@ -418,6 +431,25 @@ ProgramDesc BuildProgramDescFusionGru() {
return prog;
}

// Variable names for the fusion_lstm quantization test; this list is what
// MainTestFusionLSTM hands to PreparePass.
static const std::initializer_list<std::string> variable_names_fusion_lstm = {
"x", "wx", "wh", "b", "h", "c"};

// (x, wx, wh, b)->Fusion_lstm_1->(h, c)
//
// Builds a program that contains a single "fusion_lstm" operator named
// "Fusion_lstm_1". Every name in variable_names_fusion_lstm is declared in
// block 0; the variables whose names start with "wx", "wh" or "b" (the
// weights and the bias) are marked persistable, the others are left untouched.
ProgramDesc BuildProgramDescFusionLSTM() {
  ProgramDesc prog;
  // Declare the variables of this test, not those of the transpose test.
  for (auto& v : variable_names_fusion_lstm) {
    auto* var = prog.MutableBlock(0)->Var(v);
    // std::string::find returns npos (non-zero, hence truthy) when there is no
    // match and 0 (falsy) for a match at the start, so each prefix check needs
    // an explicit comparison with 0.
    const bool is_parameter =
        v.find("wx") == 0 || v.find("wh") == 0 || v.find("b") == 0;
    if (is_parameter) {
      var->SetPersistable(true);
    }
  }

  SetOp(&prog, "fusion_lstm", "Fusion_lstm_1", {"x", "wx", "wh", "b"},
        {"h", "c"}, true, "int8");

  return prog;
}

void MainTestFusionGru(const ProgramDesc& prog, int gru_count, int quant_count,
int dequant_count, int added_nodes_count, float scale,
float shift) {
Expand Down Expand Up @@ -470,6 +502,59 @@ TEST(CpuQuantizePass, fusion_gru) {
dequant_count, added_nodes_count, 2. * 127, 128.);
}

// Builds a graph from `prog`, passes it through PreparePass together with
// variable_names_fusion_lstm, then walks the resulting nodes and checks:
//   * each fusion_lstm op carries Scale_data == scale, Shift_data == shift,
//     Scale_weights[0] == scale and force_fp32_output == true;
//   * the numbers of fusion_lstm, quantize and dequantize ops equal
//     expect_lstm_count, quant_count and dequant_count;
//   * the node count reported after PreparePass equals the one reported before
//     it plus added_nodes_count.
void MainTestFusionLSTM(const ProgramDesc& prog, int expect_lstm_count,
                        int quant_count, int dequant_count,
                        int added_nodes_count, float scale, float shift) {
  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
  int original_nodes_num, current_nodes_num;
  PreparePass(&graph, prog, variable_names_fusion_lstm, &original_nodes_num,
              &current_nodes_num);

  int seen_lstm = 0;
  int seen_quantize = 0;
  int seen_dequantize = 0;
  for (auto* node : graph->Nodes()) {
    // Only operator nodes are of interest here.
    if (!node->IsOp()) continue;

    auto* op = node->Op();
    const auto& op_type = op->Type();
    if (op_type == "quantize") {
      ++seen_quantize;
      continue;
    }
    if (op_type == "dequantize") {
      ++seen_dequantize;
      continue;
    }
    if (op_type != "fusion_lstm") continue;

    ++seen_lstm;

    // The op name is only used to make failure messages easier to read.
    auto op_name = BOOST_GET_CONST(std::string, op->GetAttr("name"));
    EXPECT_EQ(BOOST_GET_CONST(float, op->GetAttr("Scale_data")), scale)
        << "Scale_data for node '" + op_name + "'.";
    EXPECT_EQ(BOOST_GET_CONST(float, op->GetAttr("Shift_data")), shift)
        << "Shift_data for node '" + op_name + "'.";
    EXPECT_EQ(BOOST_GET_CONST(std::vector<float>,
                              op->GetAttr("Scale_weights"))[0],
              scale)
        << "Scale_weights for node '" + op_name + "'.";
    EXPECT_EQ(BOOST_GET_CONST(bool, op->GetAttr("force_fp32_output")), true)
        << "force_fp32_output for node '" + op_name + "'.";
  }

  EXPECT_EQ(seen_lstm, expect_lstm_count);
  EXPECT_EQ(seen_quantize, quant_count);
  EXPECT_EQ(seen_dequantize, dequant_count);
  EXPECT_EQ(original_nodes_num + added_nodes_count, current_nodes_num);
}

// Checks the program built by BuildProgramDescFusionLSTM: one fusion_lstm op,
// one quantize op, no dequantize op and two added nodes, with an expected
// scale of 2 * 127 and an expected shift of 128.
TEST(CpuQuantizePass, fusion_lstm) {
  // (x, wx, wh, b)->Fusion_lstm->h
  const int lstm_ops = 1;
  const int quantize_ops = 1;
  const int dequantize_ops = 0;
  // 1 Quant + 1 IN + 0 DeQuant + 0 OUT
  const int extra_nodes = 1 + 1 + 0 + 0;
  MainTestFusionLSTM(BuildProgramDescFusionLSTM(), lstm_ops, quantize_ops,
                     dequantize_ops, extra_nodes, 2. * 127, 128.);
}

const std::vector<std::string> churn_out_vars(ProgramDesc* prog,
const std::string& prefix,
int number) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(self,
self._relu_ops = ['relu', 'relu6']
self._matmul_ops = ['matmul']
self._gru_ops = ['fusion_gru', 'multi_gru']
self._lstm_ops = ['fusion_lstm']
self._weight_thresholds = {}
# Collect the Input and Output scales from Fake quant models
self._var_quant_scales = {}
Expand Down Expand Up @@ -534,10 +535,38 @@ def _compute_gru_weight_scales(wx_name, wh_name):
self._var_quant_scales[wx_var_name] = (use_unsigned_int,
lod_tensor)

def _compute_single_lstm_weight_scales(wx_var_name, wh_var_name):
    """Compute the weight scales for one WeightX/WeightH pair.

    Both parameters are loaded from ``self._scope``, stacked along axis 0,
    and the scale of each column is the reciprocal of the largest absolute
    value found in that column.

    Args:
        wx_var_name: name of the WeightX parameter variable.
        wh_var_name: name of the WeightH parameter variable.

    Returns:
        Whatever ``self._convert_scale2tensor`` produces for the scales.
    """
    input_weights = np.array(self._load_param(self._scope, wx_var_name))
    hidden_weights = np.array(self._load_param(self._scope, wh_var_name))

    # The [:, :] slices require both arrays to have at least two dimensions.
    stacked = np.concatenate(
        [input_weights[:, :], hidden_weights[:, :]], axis=0)
    column_abs_max = np.max(np.abs(stacked), axis=0)
    scales = (1.0 / column_abs_max).astype('float')

    return self._convert_scale2tensor(scales)

def _compute_lstm_weight_scales(wx_name, wh_name):
    """Record weight scales for every op whose type is in ``self._lstm_ops``.

    For each such op the i-th variable of input ``wx_name`` is paired with
    the i-th variable of input ``wh_name``; the scales computed for the pair
    are stored in ``self._var_quant_scales`` under the WeightX variable name.

    Args:
        wx_name: name of the op input holding the WeightX variables.
        wh_name: name of the op input holding the WeightH variables.
    """
    for op in graph.all_op_nodes():
        if op.op().type() not in self._lstm_ops:
            continue

        wx_var_names = op.input(wx_name)
        wh_var_names = op.input(wh_name)
        assert len(wx_var_names) == len(
            wh_var_names
        ), 'Mismatch in number of weights inputs ({} for WeightX vs. {} for WeightH).'.format(
            len(wx_var_names), len(wh_var_names))

        for idx, wx_var_name in enumerate(wx_var_names):
            is_unsigned = False
            scale_tensor = _compute_single_lstm_weight_scales(
                wx_var_name, wh_var_names[idx])
            self._var_quant_scales[wx_var_name] = (is_unsigned, scale_tensor)

_compute_var_scales(self._conv_ops, "Filter", axis=1)
_compute_var_scales(self._fc_ops, "W", axis=0)
_compute_var_scales(self._gru_ops, "WeightH", axis=0)
_compute_var_scales(self._lstm_ops, "WeightH", axis=0)
_compute_gru_weight_scales("WeightX", "WeightH")
_compute_lstm_weight_scales("WeightX", "WeightH")
return graph

def _find_avg_pooling_ids(self, graph):
Expand Down
9 changes: 9 additions & 0 deletions python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,12 @@ if(LINUX AND WITH_MKLDNN)
download_quant_model(${QUANT2_GRU_MODEL_DIR} ${QUANT2_GRU_MODEL_ARCHIVE} cf207f8076dcfb8b74d8b6bdddf9090c)
set(QUANT2_GRU_OPS_TO_QUANTIZE "multi_gru")

# Quant2 LSTM
set(QUANT2_LSTM_MODEL_ARCHIVE "lstm_quant.tar.gz")
set(QUANT2_LSTM_MODEL_DIR "${QUANT_INSTALL_DIR}/lstm_quant_test")
download_quant_model(${QUANT2_LSTM_MODEL_DIR} ${QUANT2_LSTM_MODEL_ARCHIVE} 40a693803b12ee9e251258f32559abcb)
set(QUANT2_LSTM_OPS_TO_QUANTIZE "fusion_lstm")

### Save FP32 model or INT8 model from Quant model

set(QUANT2_INT8_RESNET50_SAVE_PATH "${QUANT_INSTALL_DIR}/ResNet50_quant2_int8")
Expand All @@ -276,6 +282,9 @@ if(LINUX AND WITH_MKLDNN)
set(QUANT2_INT8_GRU_SAVE_PATH "${QUANT_INSTALL_DIR}/GRU_quant2_int8")
save_quant_nlp_model_test(save_quant2_model_gru ${QUANT2_GRU_MODEL_DIR}/GRU_quant_acc ${QUANT2_INT8_GRU_SAVE_PATH} ${QUANT2_GRU_OPS_TO_QUANTIZE})

set(QUANT2_INT8_LSTM_SAVE_PATH "${QUANT_INSTALL_DIR}/lstm_quant2_int8")
save_quant_nlp_model_test(save_quant2_model_lstm ${QUANT2_LSTM_MODEL_DIR}/lstm_quant ${QUANT2_INT8_LSTM_SAVE_PATH} ${QUANT2_LSTM_OPS_TO_QUANTIZE})

# Convert Quant2 model to dot and pdf files
set(QUANT2_INT8_ERNIE_DOT_SAVE_PATH "${QUANT_INSTALL_DIR}/Ernie_quant2_int8_dot_file")
convert_model2dot_test(convert_model2dot_ernie ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QUANT2_INT8_ERNIE_DOT_SAVE_PATH} "Ernie_quant2_int8")
Expand Down