@@ -17,11 +17,8 @@ limitations under the License. */
1717#include < gflags/gflags.h>
1818#include < algorithm>
1919
20- #include " paddle/fluid/framework/ir/graph_helper.h"
2120#include " paddle/fluid/framework/op_proto_maker.h"
2221
23- DECLARE_bool (convert_all_blocks);
24-
2522namespace paddle {
2623namespace framework {
2724class ProgramDesc ;
@@ -33,116 +30,12 @@ namespace framework {
3330namespace ir {
3431
3532void GraphToProgramPass::ApplyImpl (ir::Graph* graph) const {
36- PADDLE_ENFORCE_EQ (graph->IsMainGraph (), true ,
37- platform::errors::InvalidArgument (
38- " This graph is a sub_graph, "
39- " and can't convert to program individually" ));
40-
41- ProgramDesc& program = Get<ProgramDesc>(" program" );
42-
43- std::unique_ptr<proto::ProgramDesc> program_pb (
44- new proto::ProgramDesc (*program.Proto ()));
45-
46- auto block = program_pb->mutable_blocks (kRootBlockIndex );
47- block->set_idx (kRootBlockIndex );
48-
49- if (FLAGS_convert_all_blocks) {
50- GraphToBlock (graph->GetSubGraph (kRootBlockIndex ), block);
51-
52- VLOG (3 ) << " Graph to program need convert " << graph->SubGraphsSize ()
53- << " sub graph" ;
54- for (size_t idx = 0 ; idx < graph->SubGraphsSize (); ++idx) {
55- // avoid kRootBlockIndex not 0
56- if (idx == kRootBlockIndex ) continue ;
57-
58- block = program_pb->add_blocks ();
59- block->set_idx (idx);
60- GraphToBlock (graph->GetSubGraph (idx), block);
61- }
62- } else {
63- GraphToBlock (graph, block);
64- }
65-
66- program.CopyFrom (*program_pb);
67- }
68-
69- OpDesc* ReplaceScaleLossGradOp (ir::Node* node, OpDesc* desc) {
70- desc->SetType (" fill_constant" );
71- desc->SetAttr (
72- OpProtoAndCheckerMaker::OpRoleAttrName (),
73- (static_cast <int >(OpRole::kBackward ) | static_cast <int >(OpRole::kLoss )));
74- desc->SetAttr (" value" , 1 .0f );
75- std::vector<std::string> output_names;
76- for (auto out : node->outputs ) {
77- output_names.emplace_back (out->Name ());
78- }
79- desc->SetOutput (" Out" , output_names);
80- return desc;
81- }
82-
83- std::vector<OpDesc>* GetGraphOpDesc (const std::vector<ir::Node*>& nodes,
84- std::vector<OpDesc>* ops) {
85- for (ir::Node* n : nodes) {
86- // if node is not Op, skip
87- if (!n->IsOp ()) continue ;
88-
89- // create fill_constant op
90- if (n->Name () == " scale_loss_grad" ) {
91- ops->emplace_back ();
92- auto & desc = ops->back ();
93- ReplaceScaleLossGradOp (n, &desc);
94- } else if (n->Op ()) {
95- ops->emplace_back (*n->Op ());
96- } else {
97- // delete no OpDesc op
98- }
99- }
100- return ops;
101- }
102-
103- void GraphToProgramPass::GraphToBlock (const Graph* graph,
104- proto::BlockDesc* block) const {
105- // Remove the unneeded variables after memory optimization.
106- std::unordered_set<std::string> vars2remove;
107- if (graph->Has (kGraphToProgramVarsToRemove )) {
108- vars2remove = graph->Get <std::unordered_set<std::string>>(
109- kGraphToProgramVarsToRemove );
110- VLOG (2 ) << " graph (id: " << block->idx () << " ) to program remove "
111- << vars2remove.size () << " nodes" ;
112- }
113-
114- block->clear_vars ();
115- std::unordered_set<std::string> visited_vars;
116- for (ir::Node* n : graph->Nodes ()) {
117- if (n->IsVar ()) {
118- if (n->Var () && visited_vars.count (n->Var ()->Name ()) == 0 &&
119- !vars2remove.count (n->Var ()->Name ()) &&
120- n->GetVarNodeBlockId () == graph->GetBlockId ()) {
121- visited_vars.insert (n->Var ()->Name ());
122- block->add_vars ()->MergeFrom (*n->Var ()->Proto ());
123- }
124- }
125- }
126- block->clear_ops ();
127-
128- std::vector<ir::Node*> nodes;
33+ auto & program = Get<ProgramDesc>(" program" );
12934 if (Has (kGraphToProgramSortKind )) {
130- // Inference Memory Optimize relays on this branch.
131- int sort_kind = Get<int >(kGraphToProgramSortKind );
132- nodes = TopologyVarientSort (
133- *graph, static_cast <framework::ir::SortKind>(sort_kind));
35+ auto sort_kind = static_cast <SortKind>(Get<int >(kGraphToProgramSortKind ));
36+ GraphToProgram (*graph, &program, &sort_kind);
13437 } else {
135- if (FLAGS_convert_all_blocks) {
136- nodes = TopologySortGraphByDescOrder (*graph);
137- } else {
138- nodes = TopologySortOperations (*graph);
139- }
140- }
141-
142- std::vector<OpDesc> ops;
143- GetGraphOpDesc (nodes, &ops);
144- for (auto & op : ops) {
145- block->add_ops ()->MergeFrom (*op.Proto ());
38+ GraphToProgram (*graph, &program, nullptr );
14639 }
14740}
14841
0 commit comments