File tree Expand file tree Collapse file tree 2 files changed +30
-2
lines changed
paddle/fluid/framework/ir Expand file tree Collapse file tree 2 files changed +30
-2
lines changed Original file line number Diff line number Diff line change @@ -34,6 +34,26 @@ namespace ir {
3434 */
3535class Graph ;
3636
37+ SimplifyWithBasicOpsPass::SimplifyWithBasicOpsPass () {
38+ AddOpCompat (OpCompat (" scale" ))
39+ .AddInput (" X" )
40+ .IsTensor ()
41+ .End ()
42+ .AddOutput (" Out" )
43+ .IsTensor ()
44+ .End ()
45+ .AddAttr (" scale" )
46+ .IsNumGE (0 .f )
47+ .IsNumLE (1 .f )
48+ .End ()
49+ .AddAttr (" bias" )
50+ .IsNumEQ (0 .f )
51+ .End ()
52+ .AddAttr (" bias_after_scale" )
53+ .IsNumEQ (true )
54+ .End ();
55+ }
56+
3757void SimplifyWithBasicOpsPass::ApplyImpl (Graph* graph) const {
3858 VLOG (3 ) << " Simplify the Graph with basic ops." ;
3959 std::unordered_set<const Node*> del_node_set;
@@ -145,6 +165,11 @@ bool SimplifyWithBasicOpsPass::SimplifyDropout(
145165 new_op_desc.SetAttr (" bias" , static_cast <float >(0 ));
146166 new_op_desc.SetAttr (" bias_after_scale" , true );
147167
168+ if (!IsCompat (new_op_desc)) {
169+ LOG (WARNING) << " Basic ops pass in scale op compat failed." ;
170+ return false ;
171+ }
172+
148173 auto * scale_op_node = graph->CreateOpNode (&new_op_desc);
149174 IR_NODE_LINK_TO (dropout_x, scale_op_node);
150175 IR_NODE_LINK_TO (scale_op_node, dropout_out);
Original file line number Diff line number Diff line change @@ -17,7 +17,7 @@ limitations under the License. */
1717#include < string>
1818#include < unordered_set>
1919
20- #include " paddle/fluid/framework/ir/pass .h"
20+ #include " paddle/fluid/framework/ir/op_compat_sensible_pass .h"
2121
2222namespace paddle {
2323namespace framework {
@@ -26,7 +26,10 @@ namespace ir {
2626class Graph ;
2727class Node ;
2828
29- class SimplifyWithBasicOpsPass : public Pass {
29+ class SimplifyWithBasicOpsPass : public OpCompatSensiblePass {
30+ public:
31+ SimplifyWithBasicOpsPass ();
32+
3033 protected:
3134 void ApplyImpl (Graph* graph) const override ;
3235
You can’t perform that action at this time.
0 commit comments