
Commit 3508bd2

Add the op def for elementwise_mul and enhance layer_norm_fuse_pass (#33560)
1 parent 11f5a40 commit 3508bd2

3 files changed: +158, -18 lines
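For context (reconstructed from the tester graph below; the variable names are illustrative), the subgraph this pass fuses computes layer normalization from primitive ops:

    out = ((x - mean(x)) / sqrt(mean((x - mean(x))^2) + eps)) * gamma + beta

The commit gives layer_norm and each primitive op on that path an OpCompat definition, and ApplyImpl now skips fusion whenever the matched subgraph, or the layer_norm op it would emit, fails those checks.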

paddle/fluid/framework/ir/layer_norm_fuse_pass.cc

Lines changed: 126 additions & 0 deletions
@@ -99,6 +99,122 @@ void addIntermediateOut(Node* op_node, const std::string& out_name,
 
 }  // namespace
 
+LayerNormFusePass::LayerNormFusePass() {
+  AddOpCompat(OpCompat("layer_norm"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Scale")
+      .IsTensor()
+      .End()
+      .AddInput("Bias")
+      .IsTensor()
+      .End()
+      .AddOutput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Mean")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Variance")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddAttr("epsilon")
+      .IsNumGE(0.0f)
+      .IsNumLE(0.001f)
+      .End()
+      .AddAttr("begin_norm_axis")
+      .IsNumGT(0)
+      .End();
+  AddOpCompat(OpCompat("reduce_mean"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("dim")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("keep_dim")
+      .IsBoolEQ(true)
+      .End();
+  AddOpCompat(OpCompat("sqrt"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End();
+  AddOpCompat(OpCompat("elementwise_sub"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsNumEQ(1)
+      .End();
+  AddOpCompat(OpCompat("elementwise_pow"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsNumEQ(1)
+      .End();
+  AddOpCompat(OpCompat("elementwise_add"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsNumEQ(1)
+      .End();
+  AddOpCompat(OpCompat("elementwise_div"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsNumEQ(1)
+      .End();
+  AddOpCompat(OpCompat("elementwise_mul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsNumEQ(1)
+      .End();
+}
+
 void LayerNormFusePass::ApplyImpl(Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(graph,
                           platform::errors::InvalidArgument(
@@ -117,6 +233,10 @@ void LayerNormFusePass::ApplyImpl(Graph* graph) const {
   int found_layer_norm_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING) << "Pass in op compat failed.";
+      return;
+    }
     VLOG(4) << "Fuse LayerNorm from subgraph.";
     GET_IR_NODE_FROM_SUBGRAPH(x, x, layer_norm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(x_mean, x_mean, layer_norm_pattern);
@@ -205,6 +325,12 @@ void LayerNormFusePass::ApplyImpl(Graph* graph) const {
     ln_op_desc.SetAttr("begin_norm_axis", static_cast<int>(x_shape.size() - 1));
     ln_op_desc.SetAttr("epsilon", *(eps_tensor->data<float>()));
     ln_op_desc.SetAttr("is_test", true);
+
+    if (!IsCompat(ln_op_desc)) {
+      LOG(WARNING) << "layer norm pass in out layer_norm op compat failed.";
+      return;
+    }
+
     Node* ln_op = g->CreateOpNode(&ln_op_desc);
 
     addIntermediateOut(ln_op, "Mean", scope_name_, g);
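As a usage note (not part of the commit): a fuse pass like this is normally obtained and run through Paddle's pass registry. A minimal sketch, assuming the standard ir::Pass API; RunLayerNormFuse is an invented helper name:

    #include <memory>

    #include "paddle/fluid/framework/ir/graph.h"
    #include "paddle/fluid/framework/ir/pass.h"

    namespace fw = paddle::framework;

    // Invented helper, for illustration only: apply the registered pass.
    // With this commit's IsCompat guards, any matched subgraph whose ops
    // violate the OpCompat definitions is left unfused instead of rewritten.
    void RunLayerNormFuse(std::unique_ptr<fw::ir::Graph>* graph) {
      auto pass = fw::ir::PassRegistry::Instance().Get("layer_norm_fuse_pass");
      graph->reset(pass->Apply(graph->release()));
    }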

paddle/fluid/framework/ir/layer_norm_fuse_pass.h

Lines changed: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ namespace ir {
  */
 class LayerNormFusePass : public FusePassBase {
  public:
+  LayerNormFusePass();
   virtual ~LayerNormFusePass() {}
 
  protected:

paddle/fluid/framework/ir/layer_norm_fuse_pass_tester.cc

Lines changed: 31 additions & 18 deletions
@@ -66,33 +66,46 @@ class LayerNormFuseTest {
     x_mean->SetAttr("keep_dim", true);
     x_mean->SetAttr("reduce_all", false);
 
-    test::CreateOp(&m_prog, "elementwise_sub",
-                   {{"X", "x"}, {"Y", "x_mean_out"}},
-                   {{"Out", "x_sub_mean_out"}}, false);
-    test::CreateOp(&m_prog, "elementwise_pow",
-                   {{"X", "x_sub_mean_out"}, {"Y", "sqr_pow"}},
-                   {{"Out", "x_sub_mean_sqr_out"}}, false);
+    auto* x_sub = test::CreateOp(&m_prog, "elementwise_sub",
+                                 {{"X", "x"}, {"Y", "x_mean_out"}},
+                                 {{"Out", "x_sub_mean_out"}}, false);
+    x_sub->SetAttr("axis", 1);
+
+    auto* x_pow = test::CreateOp(&m_prog, "elementwise_pow",
+                                 {{"X", "x_sub_mean_out"}, {"Y", "sqr_pow"}},
+                                 {{"Out", "x_sub_mean_sqr_out"}}, false);
+    x_pow->SetAttr("axis", 1);
+
     auto* std_dev =
         test::CreateOp(&m_prog, "reduce_mean", {{"X", "x_sub_mean_sqr_out"}},
                        {{"Out", "std_dev_out"}}, false);
     std_dev->SetAttr("dim", std::vector<int>{-1});
     std_dev->SetAttr("keep_dim", true);
     std_dev->SetAttr("reduce_all", false);
 
-    test::CreateOp(&m_prog, "elementwise_add",
-                   {{"X", "std_dev_out"}, {"Y", "eps"}},
-                   {{"Out", "std_dev_eps_out"}}, false);
+    auto* x_add = test::CreateOp(&m_prog, "elementwise_add",
+                                 {{"X", "std_dev_out"}, {"Y", "eps"}},
+                                 {{"Out", "std_dev_eps_out"}}, false);
+    x_add->SetAttr("axis", 1);
+
     test::CreateOp(&m_prog, "sqrt", {{"X", "std_dev_eps_out"}},
                    {{"Out", "std_dev_eps_sqrt_out"}}, false);
-    test::CreateOp(&m_prog, "elementwise_div",
-                   {{"X", "x_sub_mean_out"}, {"Y", "std_dev_eps_sqrt_out"}},
-                   {{"Out", "division_out"}}, false);
-    test::CreateOp(&m_prog, "elementwise_mul",
-                   {{"X", "division_out"}, {"Y", "gamma"}},
-                   {{"Out", "scale_out"}}, false);
-    test::CreateOp(&m_prog, "elementwise_add",
-                   {{"X", "scale_out"}, {"Y", "beta"}}, {{"Out", "shift_out"}},
-                   false);
+
+    auto* x_div =
+        test::CreateOp(&m_prog, "elementwise_div",
+                       {{"X", "x_sub_mean_out"}, {"Y", "std_dev_eps_sqrt_out"}},
+                       {{"Out", "division_out"}}, false);
+    x_div->SetAttr("axis", 1);
+
+    auto* x_mul = test::CreateOp(&m_prog, "elementwise_mul",
+                                 {{"X", "division_out"}, {"Y", "gamma"}},
+                                 {{"Out", "scale_out"}}, false);
+    x_mul->SetAttr("axis", 1);
+
+    auto* x_add_v1 = test::CreateOp(&m_prog, "elementwise_add",
+                                    {{"X", "scale_out"}, {"Y", "beta"}},
+                                    {{"Out", "shift_out"}}, false);
+    x_add_v1->SetAttr("axis", 1);
   }
 
   template <typename Func>
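One consequence worth noting: the tester now sets axis on every elementwise op because the OpCompat definitions above pin the attribute with .IsNumEQ(1). A hypothetical op that mis-sets it would fail the compat check and block fusion; a sketch (bad_mul is invented, not in the commit):

    // Hypothetical negative case: axis != 1 violates the elementwise_mul
    // OpCompat rule .IsNumEQ(1), so IsCompat() returns false, the handler
    // logs "Pass in op compat failed.", and the subgraph stays unfused.
    auto* bad_mul = test::CreateOp(&m_prog, "elementwise_mul",
                                   {{"X", "division_out"}, {"Y", "gamma"}},
                                   {{"Out", "scale_out"}}, false);
    bad_mul->SetAttr("axis", 0);  // fails the .IsNumEQ(1) compat check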
