Skip to content

Commit d55f3b6

Browse files
authored
add compat precondition for attention_lstm_fuse_pass, test=develop (#33711)
1 parent 1017180 commit d55f3b6

File tree

5 files changed

+78
-3
lines changed

5 files changed

+78
-3
lines changed

paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,61 @@ namespace paddle {
2323
namespace framework {
2424
namespace ir {
2525

26+
AttentionLSTMFusePass::AttentionLSTMFusePass() {
27+
AddOpCompat(OpCompat("while"))
28+
.AddInput("X") // A set of variables, unconstrained
29+
.End()
30+
.AddInput("Condition") // An scalar
31+
.IsTensor()
32+
.End()
33+
.AddOutput("Out") // A set of variables, unconstrained
34+
.End()
35+
.AddOutput("StepScopes") // A vector of local scope, unconstrained
36+
.End()
37+
.AddAttr("sub_block")
38+
.IsType<framework::BlockDesc*>()
39+
.End();
40+
41+
AddOpCompat(OpCompat("fill_constant"))
42+
.AddInput("ValueTensor")
43+
.IsTensor()
44+
.IsOptional()
45+
.End()
46+
.AddInput("ShapeTensor")
47+
.IsTensor()
48+
.IsOptional()
49+
.End()
50+
.AddInput("ShapeTensorList") // vector<Tensor<int>>
51+
.IsOptional()
52+
.End()
53+
.AddOutput("Out")
54+
.IsTensor()
55+
.End()
56+
.AddAttr("dtype")
57+
.IsNumGE(0)
58+
.IsNumLE(25)
59+
.End()
60+
.AddAttr("shape")
61+
.IsType<std::vector<int>>()
62+
.End()
63+
.AddAttr("value")
64+
.IsType<float>()
65+
.End();
66+
67+
AddOpCompat(OpCompat("sequence_expand"))
68+
.AddInput("X")
69+
.IsTensor()
70+
.End()
71+
.AddInput("Y")
72+
.IsTensor()
73+
.End()
74+
.AddOutput("Out")
75+
.IsTensor()
76+
.End()
77+
.AddAttr("ref_level")
78+
.IsNumGE(-1)
79+
.End();
80+
}
2681
struct Param {
2782
std::string X = "concat_0.tmp_0";
2883
std::string C0 = "cell_init";
@@ -43,7 +98,7 @@ struct Param {
4398

4499
void PrepareParameters(Graph* graph, const Param& param, ir::Node* lstm_op);
45100

46-
void FindWhileOp(Graph* graph) {
101+
void AttentionLSTMFusePass::FindWhileOp(Graph* graph) const {
47102
GraphPatternDetector gpd;
48103
std::unordered_set<int> fused_external_ops(
49104
{35, 36, 37, 38, 43, 44, 49, 45, 46, 47, 41, 42, 53, 54, 48,
@@ -60,6 +115,10 @@ void FindWhileOp(Graph* graph) {
60115

61116
auto handle = [&](const GraphPatternDetector::subgraph_t& subgraph,
62117
Graph* g) {
118+
if (!IsCompat(subgraph, g)) {
119+
LOG(WARNING) << "Pass in op compat failed.";
120+
return;
121+
}
63122
auto* while_pat_node = gpd.pattern().RetrieveNode("while");
64123
auto* while_node = subgraph.at(while_pat_node);
65124
marked_nodes.insert(while_node);

paddle/fluid/framework/ir/attention_lstm_fuse_pass.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,14 @@ namespace ir {
2323
class Graph;
2424

2525
class AttentionLSTMFusePass : public FusePassBase {
26+
public:
27+
AttentionLSTMFusePass();
28+
2629
protected:
2730
void ApplyImpl(ir::Graph* graph) const override;
31+
32+
private:
33+
void FindWhileOp(Graph* graph) const;
2834
};
2935

3036
} // namespace ir

paddle/fluid/framework/ir/op_compat_sensible_pass.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ bool OpCompatSensiblePass::IsCompat(
260260
auto op_type = node_pair.second->Op()->Type();
261261
if (!op_compat_judgers_.count(op_type)) {
262262
if (HasOpDef(op_type)) {
263-
LOG(WARNING) << op_type << "compat not registered!";
263+
LOG(WARNING) << op_type << " compat not registered!";
264264
return false;
265265
}
266266
continue;

paddle/fluid/framework/ir/op_compat_sensible_pass.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ class AttrCompat {
3131
AttrCompat(const std::string& attr_name, OpCompat* op_compat)
3232
: optional_(false), attr_name_(attr_name), op_compat_(op_compat) {}
3333

34+
//! Assert the attribute type is `T`.
35+
template <typename T>
36+
AttrCompat& IsType();
37+
3438
// @{ String-related methods
3539
//! Assert the attribute is an string in the `candidates` domain.
3640
AttrCompat& IsStringIn(const std::set<std::string>& candidates);
@@ -207,6 +211,13 @@ class OpCompatSensiblePass : public Pass {
207211
std::map<std::string, std::unique_ptr<OpCompat>> op_compat_judgers_;
208212
};
209213

214+
template <typename T>
215+
AttrCompat& AttrCompat::IsType() {
216+
conditions_.emplace_back(
217+
[](const Attribute& attr) -> bool { return attr.type() == typeid(T); });
218+
return *this;
219+
}
220+
210221
template <typename T>
211222
AttrCompat& AttrCompat::IsNumGT(T v) {
212223
conditions_.emplace_back([v](const Attribute& attr) -> bool {

paddle/fluid/operators/compat/fill_constant.pbtxt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ def {
2424
name: "value"
2525
type: FLOAT
2626
}
27-
2827
}
2928
extra {
3029
attrs {

0 commit comments

Comments
 (0)