1515#include < string>
1616
1717#include " glog/logging.h"
18+ #include " paddle/fluid/framework/executor_gc_helper.h"
1819#include " paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.h"
1920#include " paddle/fluid/framework/ir/pass.h"
2021#include " paddle/fluid/platform/enforce.h"
@@ -30,6 +31,9 @@ class BufferSharedInplaceOpPass : public MemoryReusePass {
3031 std::string ReuseType () const override { return " inplace" ; }
3132
3233 void Run (Graph *graph) const override ;
34+
35+ void ApplyImpl (ProgramDesc *main_program,
36+ ProgramDesc *startup_program) const override ;
3337};
3438
3539void BufferSharedInplaceOpPass::Run (Graph *graph) const {
@@ -149,6 +153,141 @@ void BufferSharedInplaceOpPass::Run(Graph *graph) const {
149153 }
150154}
151155
156+ static std::string GetFirstVarName (const OpDesc &op, const std::string &slot,
157+ bool is_input) {
158+ const auto &name_map = is_input ? op.Inputs () : op.Outputs ();
159+ auto iter = name_map.find (slot);
160+ if (iter != name_map.end () && !iter->second .empty ()) {
161+ return iter->second [0 ];
162+ }
163+ return kEmptyVarName ;
164+ }
165+
166+ static std::vector<std::vector<std::pair<std::string, std::string>>>
167+ GetInplaceVars (const BlockDesc &block, bool use_cuda,
168+ const std::vector<std::string> &skip_vars) {
169+ PADDLE_ENFORCE_EQ (block.ID (), 0 , platform::errors::Unimplemented (
170+ " Inplace can only perform in block 0." ));
171+ // only take block 0 gc_vars
172+ const auto op_gc_vars =
173+ GetEagerDeletionCleanVars (*block.Program (), skip_vars)[0 ];
174+ const auto all_ops = block.AllOps ();
175+ PADDLE_ENFORCE_EQ (op_gc_vars.size (), all_ops.size (),
176+ platform::errors::PermissionDenied (
177+ " GC analysis error: op number not match." ));
178+ size_t n = all_ops.size ();
179+ std::unordered_set<std::string> visited_vars;
180+ std::unordered_set<std::string> reused_in_vars (skip_vars.begin (),
181+ skip_vars.end ());
182+ std::unordered_set<std::string> reused_out_vars (skip_vars.begin (),
183+ skip_vars.end ());
184+ for (const auto *op : all_ops) {
185+ if (op->Type () == " share_buffer" || op->Type () == " share_data" ) {
186+ const auto &inputs = op->Input (" X" );
187+ const auto &outputs = op->Output (" Out" );
188+ reused_in_vars.insert (inputs.begin (), inputs.end ());
189+ reused_out_vars.insert (outputs.begin (), outputs.end ());
190+ }
191+ }
192+
193+ std::vector<std::vector<std::pair<std::string, std::string>>> result (n);
194+ for (size_t i = 0 ; i < n; ++i) {
195+ const auto &op = *all_ops[i];
196+ const auto &gc_vars = op_gc_vars[i];
197+ const auto inputs = op.InputArgumentNames ();
198+ const auto outputs = op.OutputArgumentNames ();
199+ visited_vars.insert (inputs.begin (), inputs.end ());
200+
201+ auto &infer_inplace = OpInfoMap::Instance ().Get (op.Type ()).infer_inplace_ ;
202+ if (gc_vars.empty () || !infer_inplace) {
203+ visited_vars.insert (outputs.begin (), outputs.end ());
204+ continue ;
205+ }
206+
207+ const auto var_pair = infer_inplace (use_cuda);
208+ std::unordered_multiset<std::string> input_set (inputs.begin (),
209+ inputs.end ());
210+ std::unordered_multiset<std::string> output_set (outputs.begin (),
211+ outputs.end ());
212+ std::unordered_set<std::string> valid_vars;
213+ for (const auto &var : gc_vars) {
214+ if (var != kEmptyVarName && input_set.count (var) == 1 &&
215+ output_set.count (var) == 0 &&
216+ block.FindVar (var)->GetType () == proto::VarType::LOD_TENSOR) {
217+ valid_vars.insert (var);
218+ }
219+ }
220+
221+ if (valid_vars.empty ()) {
222+ visited_vars.insert (outputs.begin (), outputs.end ());
223+ continue ;
224+ }
225+
226+ for (const auto &pair : var_pair) {
227+ const auto &input_slot = pair.first ;
228+ const auto &output_slot = pair.second ;
229+ auto input_var = GetFirstVarName (op, input_slot, /* is_input=*/ true );
230+ if (input_var == kEmptyVarName || valid_vars.count (input_var) == 0 ) {
231+ continue ;
232+ }
233+ auto output_var = GetFirstVarName (op, output_slot, /* is_input=*/ false );
234+ if (output_var == kEmptyVarName || visited_vars.count (output_var) > 0 ) {
235+ continue ;
236+ }
237+ auto output_var_desc = block.FindVar (output_var);
238+ if (output_var_desc == nullptr || output_var_desc->Persistable () ||
239+ output_var_desc->GetType () != proto::VarType::LOD_TENSOR) {
240+ continue ;
241+ }
242+
243+ if (reused_in_vars.count (input_var) > 0 ||
244+ reused_out_vars.count (output_var) > 0 ) {
245+ continue ;
246+ }
247+
248+ // input_var -> output_var is reusable
249+ VLOG (10 ) << " inplace occurs at op " << i << " " << op.Type () << " : "
250+ << input_var << " -> " << output_var;
251+ result[i].emplace_back (input_var, output_var);
252+ reused_in_vars.insert (input_var);
253+ reused_out_vars.insert (output_var);
254+ }
255+ visited_vars.insert (outputs.begin (), outputs.end ());
256+ std::sort (result[i].begin (), result[i].end ());
257+ }
258+ return result;
259+ }
260+
261+ void BufferSharedInplaceOpPass::ApplyImpl (ProgramDesc *main_program,
262+ ProgramDesc *startup_program) const {
263+ bool use_cuda = Get<bool >(kUseCuda );
264+ auto skip_vars = Get<std::vector<std::string>>(" mem_opt_skip_vars" );
265+
266+ auto *block = main_program->MutableBlock (0 );
267+ auto inplace_vars = GetInplaceVars (*block, use_cuda, skip_vars);
268+ PADDLE_ENFORCE_EQ (inplace_vars.size (), block->OpSize (),
269+ platform::errors::PermissionDenied (
270+ " Inplace analysis error: op number not match." ));
271+ int64_t n = static_cast <int64_t >(inplace_vars.size ());
272+ for (int64_t i = n - 1 ; i >= 0 ; --i) {
273+ if (inplace_vars[i].empty ()) continue ;
274+ auto *op = block->InsertOp (i);
275+ std::vector<std::string> inputs, outputs;
276+ inputs.reserve (inplace_vars[i].size ());
277+ outputs.reserve (inplace_vars[i].size ());
278+ for (const auto &pair : inplace_vars[i]) {
279+ inputs.push_back (pair.first );
280+ outputs.push_back (pair.second );
281+ }
282+ op->SetType (" share_buffer" );
283+ op->SetInput (" X" , inputs);
284+ op->SetOutput (" Out" , outputs);
285+ op->SetOutput (" XOut" , inputs); // add necessary dependency
286+ op->SetAttr (" share_dims" , std::vector<bool >(inputs.size (), false ));
287+ }
288+ block->Flush ();
289+ }
290+
152291} // namespace ir
153292} // namespace framework
154293} // namespace paddle