Skip to content

Commit 6da6ff6

Browse files
authored
SimplifyWithBasicOpsPass (#33637)
* simplify_with_basic * fix * scale factor
1 parent 478ea78 commit 6da6ff6

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

paddle/fluid/framework/ir/simplify_with_basic_ops_pass.cc

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,26 @@ namespace ir {
3434
*/
3535
class Graph;
3636

37+
SimplifyWithBasicOpsPass::SimplifyWithBasicOpsPass() {
38+
AddOpCompat(OpCompat("scale"))
39+
.AddInput("X")
40+
.IsTensor()
41+
.End()
42+
.AddOutput("Out")
43+
.IsTensor()
44+
.End()
45+
.AddAttr("scale")
46+
.IsNumGE(0.f)
47+
.IsNumLE(1.f)
48+
.End()
49+
.AddAttr("bias")
50+
.IsNumEQ(0.f)
51+
.End()
52+
.AddAttr("bias_after_scale")
53+
.IsNumEQ(true)
54+
.End();
55+
}
56+
3757
void SimplifyWithBasicOpsPass::ApplyImpl(Graph* graph) const {
3858
VLOG(3) << "Simplify the Graph with basic ops.";
3959
std::unordered_set<const Node*> del_node_set;
@@ -145,6 +165,11 @@ bool SimplifyWithBasicOpsPass::SimplifyDropout(
145165
new_op_desc.SetAttr("bias", static_cast<float>(0));
146166
new_op_desc.SetAttr("bias_after_scale", true);
147167

168+
if (!IsCompat(new_op_desc)) {
169+
LOG(WARNING) << "Basic ops pass in scale op compat failed.";
170+
return false;
171+
}
172+
148173
auto* scale_op_node = graph->CreateOpNode(&new_op_desc);
149174
IR_NODE_LINK_TO(dropout_x, scale_op_node);
150175
IR_NODE_LINK_TO(scale_op_node, dropout_out);

paddle/fluid/framework/ir/simplify_with_basic_ops_pass.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ limitations under the License. */
1717
#include <string>
1818
#include <unordered_set>
1919

20-
#include "paddle/fluid/framework/ir/pass.h"
20+
#include "paddle/fluid/framework/ir/op_compat_sensible_pass.h"
2121

2222
namespace paddle {
2323
namespace framework {
@@ -26,7 +26,10 @@ namespace ir {
2626
class Graph;
2727
class Node;
2828

29-
class SimplifyWithBasicOpsPass : public Pass {
29+
class SimplifyWithBasicOpsPass : public OpCompatSensiblePass {
30+
public:
31+
SimplifyWithBasicOpsPass();
32+
3033
protected:
3134
void ApplyImpl(Graph* graph) const override;
3235

0 commit comments

Comments
 (0)