
Commit 85d4e88

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into variable-keeping-with-ensor
2 parents: d88cec9 + 9bf00cd

51 files changed: +2523 −94 lines

Note: large commits hide some content by default; only a subset of the changed files is shown below.

cmake/external/mkldnn.cmake

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ SET(MKLDNN_SOURCE_DIR ${THIRD_PARTY_PATH}/mkldnn/src/extern_mkldnn)
 SET(MKLDNN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/mkldnn)
 SET(MKLDNN_INC_DIR "${MKLDNN_INSTALL_DIR}/include" CACHE PATH "mkldnn include directory." FORCE)
 SET(MKLDNN_REPOSITORY ${GIT_URL}/oneapi-src/oneDNN.git)
-SET(MKLDNN_TAG 748528a2d3204b5f401c14a9aacdec16accd5ead)
+SET(MKLDNN_TAG bbaf5d24dde1b6760435d5034d6f48feae7a30b9)


 # Introduce variables:

paddle/fluid/framework/fleet/ps_gpu_wrapper.h

Lines changed: 9 additions & 0 deletions
@@ -209,6 +209,15 @@ class PSGPUWrapper {
   void EndPass() { HeterPs_->end_pass(); }
   void ShowOneTable(int index) { HeterPs_->show_one_table(index); }

+  void Finalize() {
+    VLOG(3) << "PSGPUWrapper Begin Finalize.";
+    if (s_instance_ == nullptr) {
+      return;
+    }
+    s_instance_ = nullptr;
+    VLOG(3) << "PSGPUWrapper Finalize Finished.";
+  }
+
  private:
   static std::shared_ptr<PSGPUWrapper> s_instance_;
   Dataset* dataset_;
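
The new Finalize() above releases the class's owning reference to its singleton instance. Below is a minimal standalone sketch of that pattern, not Paddle's actual API: the Wrapper and GetInstance names are illustrative, since PSGPUWrapper's real accessor is not part of this diff.

#include <iostream>
#include <memory>

class Wrapper {
 public:
  // Hypothetical accessor, shown only to make the sketch self-contained.
  static std::shared_ptr<Wrapper> GetInstance() {
    if (s_instance_ == nullptr) {
      s_instance_ = std::shared_ptr<Wrapper>(new Wrapper());
    }
    return s_instance_;
  }
  void Finalize() {
    if (s_instance_ == nullptr) {
      return;  // already finalized; the early return makes this idempotent
    }
    s_instance_ = nullptr;  // drop the owning reference held by the class
  }
  ~Wrapper() { std::cout << "Wrapper destroyed\n"; }

 private:
  Wrapper() = default;
  static std::shared_ptr<Wrapper> s_instance_;
};

std::shared_ptr<Wrapper> Wrapper::s_instance_ = nullptr;

int main() {
  // The temporary shared_ptr from GetInstance() is the last remaining holder;
  // once it dies at the end of the statement, the destructor runs.
  Wrapper::GetInstance()->Finalize();
}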

paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc

Lines changed: 60 additions & 1 deletion
@@ -23,6 +23,61 @@ namespace paddle {
 namespace framework {
 namespace ir {

+AttentionLSTMFusePass::AttentionLSTMFusePass() {
+  AddOpCompat(OpCompat("while"))
+      .AddInput("X")  // A set of variables, unconstrained
+      .End()
+      .AddInput("Condition")  // An scalar
+      .IsTensor()
+      .End()
+      .AddOutput("Out")  // A set of variables, unconstrained
+      .End()
+      .AddOutput("StepScopes")  // A vector of local scope, unconstrained
+      .End()
+      .AddAttr("sub_block")
+      .IsType<framework::BlockDesc*>()
+      .End();
+
+  AddOpCompat(OpCompat("fill_constant"))
+      .AddInput("ValueTensor")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddInput("ShapeTensor")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddInput("ShapeTensorList")  // vector<Tensor<int>>
+      .IsOptional()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("dtype")
+      .IsNumGE(0)
+      .IsNumLE(25)
+      .End()
+      .AddAttr("shape")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("value")
+      .IsType<float>()
+      .End();
+
+  AddOpCompat(OpCompat("sequence_expand"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("ref_level")
+      .IsNumGE(-1)
+      .End();
+}
 struct Param {
   std::string X = "concat_0.tmp_0";
   std::string C0 = "cell_init";
@@ -43,7 +98,7 @@ struct Param {

 void PrepareParameters(Graph* graph, const Param& param, ir::Node* lstm_op);

-void FindWhileOp(Graph* graph) {
+void AttentionLSTMFusePass::FindWhileOp(Graph* graph) const {
   GraphPatternDetector gpd;
   std::unordered_set<int> fused_external_ops(
       {35, 36, 37, 38, 43, 44, 49, 45, 46, 47, 41, 42, 53, 54, 48,
@@ -60,6 +115,10 @@ void FindWhileOp(Graph* graph) {

   auto handle = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING) << "Pass in op compat failed.";
+      return;
+    }
     auto* while_pat_node = gpd.pattern().RetrieveNode("while");
     auto* while_node = subgraph.at(while_pat_node);
     marked_nodes.insert(while_node);
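
The constructor added above registers op-compat rules with a fluent, chained builder: each AddInput()/AddOutput()/AddAttr() returns a field-level checker, and End() hands control back to the parent OpCompat so one statement describes the whole op. A minimal standalone sketch of that shape follows (simplified types, not Paddle's OpCompat implementation):

#include <deque>
#include <string>

class OpRule;

// Field-level rule: records constraints and chains back to its parent.
class FieldRule {
 public:
  FieldRule(OpRule* parent, std::string name)
      : parent_(parent), name_(std::move(name)) {}
  FieldRule& IsTensor() {
    must_be_tensor_ = true;  // record the constraint, return *this to chain
    return *this;
  }
  FieldRule& IsOptional() {
    optional_ = true;
    return *this;
  }
  OpRule& End() { return *parent_; }  // back to the op-level builder

 private:
  OpRule* parent_;
  std::string name_;
  bool must_be_tensor_ = false;
  bool optional_ = false;
};

// Op-level rule: owns its field rules.
class OpRule {
 public:
  explicit OpRule(std::string type) : type_(std::move(type)) {}
  FieldRule& AddInput(const std::string& name) {
    // std::deque keeps references to earlier elements valid on push_back.
    fields_.emplace_back(this, name);
    return fields_.back();
  }

 private:
  std::string type_;
  std::deque<FieldRule> fields_;
};

int main() {
  OpRule rule("sequence_expand");
  rule.AddInput("X").IsTensor().End().AddInput("Y").IsTensor().End();
}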

paddle/fluid/framework/ir/attention_lstm_fuse_pass.h

Lines changed: 6 additions & 0 deletions
@@ -23,8 +23,14 @@ namespace ir {
 class Graph;

 class AttentionLSTMFusePass : public FusePassBase {
+ public:
+  AttentionLSTMFusePass();
+
  protected:
   void ApplyImpl(ir::Graph* graph) const override;
+
+ private:
+  void FindWhileOp(Graph* graph) const;
 };

 }  // namespace ir

paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc

Lines changed: 36 additions & 0 deletions
@@ -32,6 +32,37 @@ namespace ir {
   GET_IR_NODE(quant_dequant_op_outscale); \
   GET_IR_NODE(any_op2);

+DeleteQuantDequantFilterOpPass::DeleteQuantDequantFilterOpPass() {
+  AddOpCompat(OpCompat("fake_quantize_dequantize_abs_max"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddOutput("OutScale")
+      .IsTensor()
+      .End()
+      .AddAttr("bit_length")
+      .IsIntIn({8, 16})
+      .End();
+  AddOpCompat(OpCompat("fake_channel_wise_quantize_dequantize_abs_max"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddOutput("OutScale")
+      .IsTensor()
+      .End()
+      .AddAttr("bit_length")
+      .IsIntIn({8, 16})
+      .End()
+      .AddAttr("quant_axis")
+      .IsIntIn({0, 1})
+      .End();
+}
 // Delete quant_dequant_op, then quantize and dequantize weight
 void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
   const std::string pattern_name = "delete_quantdequant_filter_op_pattern";
@@ -50,6 +81,11 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
                     Graph* g) {
     GET_NODES;

+    if (!IsCompat(*quant_dequant_op->Op())) {
+      LOG(WARNING) << "quant_dequant_op in delete_quant_dequant_filter_op_pass "
+                      "compat check failed.";
+      return;
+    }
     std::unordered_set<const Node*> nodes2rm = {};
     int bit_length =
         BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("bit_length"));

paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.h

Lines changed: 1 addition & 3 deletions
@@ -16,16 +16,14 @@
 #include <vector>

 #include "paddle/fluid/framework/ir/fuse_pass_base.h"
-#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

 namespace paddle {
 namespace framework {
 namespace ir {

-class Graph;
-
 class DeleteQuantDequantFilterOpPass : public FusePassBase {
  public:
+  DeleteQuantDequantFilterOpPass();
   virtual ~DeleteQuantDequantFilterOpPass() {}

  protected:

paddle/fluid/framework/ir/op_compat_sensible_pass.cc

Lines changed: 1 addition & 1 deletion
@@ -260,7 +260,7 @@ bool OpCompatSensiblePass::IsCompat(
     auto op_type = node_pair.second->Op()->Type();
     if (!op_compat_judgers_.count(op_type)) {
       if (HasOpDef(op_type)) {
-        LOG(WARNING) << op_type << "compat not registered!";
+        LOG(WARNING) << op_type << " compat not registered!";
         return false;
       }
       continue;

paddle/fluid/framework/ir/op_compat_sensible_pass.h

Lines changed: 11 additions & 0 deletions
@@ -31,6 +31,10 @@ class AttrCompat {
   AttrCompat(const std::string& attr_name, OpCompat* op_compat)
       : optional_(false), attr_name_(attr_name), op_compat_(op_compat) {}

+  //! Assert the attribute type is `T`.
+  template <typename T>
+  AttrCompat& IsType();
+
   // @{ String-related methods
   //! Assert the attribute is an string in the `candidates` domain.
   AttrCompat& IsStringIn(const std::set<std::string>& candidates);
@@ -207,6 +211,13 @@ class OpCompatSensiblePass : public Pass {
   std::map<std::string, std::unique_ptr<OpCompat>> op_compat_judgers_;
 };

+template <typename T>
+AttrCompat& AttrCompat::IsType() {
+  conditions_.emplace_back(
+      [](const Attribute& attr) -> bool { return attr.type() == typeid(T); });
+  return *this;
+}
+
 template <typename T>
 AttrCompat& AttrCompat::IsNumGT(T v) {
   conditions_.emplace_back([v](const Attribute& attr) -> bool {
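
The out-of-line IsType<T>() definition above shows the mechanism behind all the Is*() checks: each call appends a predicate to conditions_, and a later pass over that vector accepts the attribute only if every predicate holds. A standalone sketch of the same idea, using std::any in place of Paddle's Attribute variant (both expose a type() comparable against typeid(T), as the new code itself does):

#include <any>
#include <functional>
#include <iostream>
#include <typeinfo>
#include <vector>

class AttrChecker {
 public:
  // Append a type predicate; returning *this lets checks chain.
  template <typename T>
  AttrChecker& IsType() {
    conditions_.emplace_back(
        [](const std::any& attr) { return attr.type() == typeid(T); });
    return *this;
  }
  // An attribute is compatible only if every recorded predicate holds.
  bool Check(const std::any& attr) const {
    for (const auto& cond : conditions_) {
      if (!cond(attr)) return false;
    }
    return true;
  }

 private:
  std::vector<std::function<bool(const std::any&)>> conditions_;
};

int main() {
  AttrChecker checker;
  checker.IsType<float>();
  std::cout << checker.Check(std::any(1.0f)) << "\n";  // 1: type matches
  std::cout << checker.Check(std::any(42)) << "\n";    // 0: int, not float
}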

paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.cc

Lines changed: 46 additions & 0 deletions
@@ -52,6 +52,52 @@ static void GetConcatNodes(ir::Graph* graph, std::vector<Node*>* concat_nodes) {
   }
 }  // anonymous namespace

+SeqPoolCVMConcatFusePass::SeqPoolCVMConcatFusePass() {
+  AddOpCompat(OpCompat("sequence_pool"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddOutput("MaxIndex")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddAttr("pooltype")
+      .IsStringIn({"AVERAGE", "SUM", "SQRT", "LAST", "FIRST", "MAX"})
+      .End()
+      .AddAttr("pad_value")
+      .End();
+  AddOpCompat(OpCompat("cvm"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("CVM")
+      .IsTensor()
+      .End()
+      .AddOutput("Y")
+      .IsTensor()
+      .End()
+      .AddAttr("use_cvm")
+      .IsBoolEQ(true)
+      .End();
+  AddOpCompat(OpCompat("concat"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("AxisTensor")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsNumGE(1)
+      .End();
+}
+
 void SeqPoolCVMConcatFusePass::ApplyImpl(ir::Graph* graph) const {
   FusePassBase::Init("seqpool_cvm_concat_fuse", graph);
   std::vector<Node*> concat_nodes;

paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.h

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ class Graph;

 class SeqPoolCVMConcatFusePass : public FusePassBase {
  public:
-  virtual ~SeqPoolCVMConcatFusePass() {}
+  SeqPoolCVMConcatFusePass();

  protected:
   void ApplyImpl(ir::Graph* graph) const override;
