@@ -16,6 +16,7 @@ limitations under the License. */
1616
1717#include < algorithm>
1818#include " paddle/fluid/framework/ir/graph_helper.h"
19+ #include " paddle/fluid/framework/op_proto_maker.h"
1920
2021namespace paddle {
2122namespace framework {
@@ -72,19 +73,6 @@ Graph *Pass::Apply(Graph *graph) const {
7273 return graph;
7374}
7475
75- void Pass::Apply (ProgramDesc *main_program,
76- ProgramDesc *startup_program) const {
77- VLOG (10 ) << " apply pass " << Type () << " to program" ;
78- PADDLE_ENFORCE_NOT_NULL (main_program, platform::errors::InvalidArgument (
79- " main program must be provided" ));
80- PADDLE_ENFORCE_NOT_NULL (
81- startup_program,
82- platform::errors::InvalidArgument (" startup program must be provided" ));
83-
84- ApplyImpl (main_program, startup_program);
85- VLOG (10 ) << " finish to apply pass " << Type () << " to program" ;
86- }
87-
8876template <typename Container, typename Visitor>
8977static void VisitAllElements (Container &&container, Visitor &&visitor,
9078 bool reverse) {
@@ -95,8 +83,8 @@ static void VisitAllElements(Container &&container, Visitor &&visitor,
9583 }
9684}
9785
98- void Pass:: MergePrograms (ProgramDesc *dst, const details::ProgramDescs &srcs,
99- bool append) {
86+ static void MergePrograms (ProgramDesc *dst, const details::ProgramDescs &srcs,
87+ bool append) {
10088 PADDLE_ENFORCE_NOT_NULL (
10189 dst, platform::errors::InvalidArgument (" Dst program must be provided." ));
10290 bool reverse = !append;
@@ -137,27 +125,105 @@ void Pass::MergePrograms(ProgramDesc *dst, const details::ProgramDescs &srcs,
137125 VisitAllElements (srcs, create_op_visitor, reverse);
138126}
139127
128+ static void FillNotSpecifiedOpRole (const ProgramDesc &main_program) {
129+ for (size_t block_idx = 0 ; block_idx < main_program.Size (); ++block_idx) {
130+ auto ops = main_program.Block (block_idx).AllOps ();
131+ size_t n = ops.size ();
132+ std::vector<OpRole> roles;
133+ roles.reserve (n);
134+ auto op_role_attr = OpProtoAndCheckerMaker::OpRoleAttrName ();
135+ for (auto *op : ops) {
136+ OpRole role;
137+ if (op->HasAttr (op_role_attr)) {
138+ role = static_cast <OpRole>(op->GetAttrIfExists <int >(op_role_attr));
139+ } else {
140+ role = OpRole::kNotSpecified ;
141+ }
142+ roles.emplace_back (role);
143+ }
144+
145+ // NOTE: The following codes may be wrong in some cases.
146+ // But how can we get the right OpRole? The right way
147+ // is that all passes should deal with unspecified OpRole.
148+ auto prev_role = OpRole::kForward ;
149+ for (size_t i = 0 ; i < n; ++i) {
150+ if (roles[i] == OpRole::kNotSpecified ) {
151+ VLOG (10 ) << " Fill op role of " << ops[i]->Type () << " as "
152+ << static_cast <int >(prev_role);
153+ ops[i]->SetAttr (op_role_attr, static_cast <int >(prev_role));
154+ } else {
155+ prev_role = roles[i];
156+ }
157+ }
158+ }
159+ }
160+
161+ void Pass::ApplyPassesToProgram (const std::vector<const Pass *> &passes,
162+ ProgramDesc *main_program,
163+ ProgramDesc *startup_program) {
164+ VLOG (10 ) << " ApplyPassesToProgram is called" ;
165+ PADDLE_ENFORCE_NOT_NULL (
166+ main_program,
167+ platform::errors::InvalidArgument (" The main program must be provided." ));
168+
169+ PADDLE_ENFORCE_NOT_NULL (startup_program,
170+ platform::errors::InvalidArgument (
171+ " The startup program must be provided." ));
172+
173+ for (auto *p : passes) {
174+ PADDLE_ENFORCE_NOT_NULL (p, platform::errors::InvalidArgument (
175+ " The provided pass cannot be nullptr." ));
176+ VLOG (10 ) << " Pass " << p->Type ();
177+ if (passes.size () > 1 ) {
178+ PADDLE_ENFORCE_EQ (p->SupportApplyProgramViaGraph (), true ,
179+ platform::errors::PermissionDenied (
180+ " Each pass must support to be applied via Graph if "
181+ " multi-passes are applied." ));
182+ }
183+ }
184+
185+ if (passes.size () == 1 && !passes[0 ]->SupportApplyProgramViaGraph ()) {
186+ VLOG (10 ) << " apply pass " << passes[0 ]->Type () << " to program" ;
187+ passes[0 ]->ApplyImpl (main_program, startup_program);
188+ FillNotSpecifiedOpRole (*main_program);
189+ VLOG (10 ) << " finish to apply pass " << passes[0 ]->Type () << " to program" ;
190+ return ;
191+ }
192+
193+ Graph graph (*main_program);
194+ for (auto *p : passes) {
195+ p->Apply (&graph);
196+ }
197+ ConvertToPrograms (&graph, main_program, startup_program);
198+ FillNotSpecifiedOpRole (*main_program);
199+ }
200+
140201void Pass::ApplyImpl (ProgramDesc *main_program,
141202 ProgramDesc *startup_program) const {
142- Graph graph (*main_program);
143- Apply (&graph);
203+ PADDLE_THROW (platform::errors::Unimplemented (
204+ " The pass %s does not support to apply ProgramDesc directly" , Type ()));
205+ }
144206
207+ void Pass::ConvertToPrograms (Graph *graph, ProgramDesc *main_program,
208+ ProgramDesc *startup_program) {
145209 ProgramDesc new_main_program;
146- GraphToProgram (graph, &new_main_program);
210+ GraphToProgram (* graph, &new_main_program);
147211 main_program->CopyFrom (*new_main_program.Proto ());
148212
149- if (graph. Has (details::kStartupProgramDescs )) {
213+ if (graph-> Has (details::kStartupProgramDescs )) {
150214 const auto &startups =
151- graph. Get <details::ProgramDescs>(details::kStartupProgramDescs );
215+ graph-> Get <details::ProgramDescs>(details::kStartupProgramDescs );
152216 VLOG (10 ) << " Merge startup programs" ;
153217 MergePrograms (startup_program, startups, /* append=*/ true );
218+ graph->Erase (details::kStartupProgramDescs );
154219 }
155220
156- if (graph. Has (details::kProgramDescs )) {
221+ if (graph-> Has (details::kProgramDescs )) {
157222 const auto &mains =
158- graph. Get <details::ProgramDescs>(details::kProgramDescs );
223+ graph-> Get <details::ProgramDescs>(details::kProgramDescs );
159224 VLOG (10 ) << " Merge main programs" ;
160225 MergePrograms (main_program, mains, /* append=*/ false );
226+ graph->Erase (details::kProgramDescs );
161227 }
162228
163229 startup_program->Flush ();
0 commit comments