@@ -49,11 +49,28 @@ bool IsTarget(const proto::OpDesc& op_desc) {
4949 return false ;
5050}
5151
52- void prune_impl (const proto::ProgramDesc& input, proto::ProgramDesc* output,
53- int block_id) {
54- // TODO(tonyyang-svail):
55- // - will change to use multiple blocks for RNN op and Cond Op
52+ int GetSubBlockIndex (const proto::OpDesc& op_desc) {
53+ for (auto & attr : op_desc.attrs ()) {
54+ if (attr.type () == proto::AttrType::BLOCK) {
55+ PADDLE_ENFORCE (attr.has_block_idx ());
56+ return attr.block_idx ();
57+ }
58+ }
59+ return -1 ;
60+ }
61+
62+ bool HasSubBlock (const proto::OpDesc& op_desc) {
63+ return GetSubBlockIndex (op_desc) > 0 ;
64+ }
5665
66+ // block_id is the idx of the current block in the input desc
67+ // parent_block_id is the idx of the parent of the current block
68+ // in the output desc; -1 means the current block is the global block
69+ // dependent_vars is passed recursively from the parent block to
70+ // the child block to help pruning
71+ void prune_impl (const proto::ProgramDesc& input, proto::ProgramDesc* output,
72+ int block_id, int parent_block_id,
73+ std::set<std::string>& dependent_vars) {
5774 auto & block = input.blocks (block_id);
5875 auto & ops = block.ops ();
5976
@@ -72,19 +89,16 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
7289 expect_fetch = (op_desc.type () == kFetchOpType );
7390 }
7491
75- std::set<std::string> dependent_vars;
7692 std::vector<bool > should_run;
7793 for (auto op_iter = ops.rbegin (); op_iter != ops.rend (); ++op_iter) {
7894 auto & op_desc = *op_iter;
79-
8095 if (IsTarget (op_desc) || HasDependentVar (op_desc, dependent_vars)) {
8196 // insert its input to the dependency graph
8297 for (auto & var : op_desc.inputs ()) {
8398 for (auto & argu : var.arguments ()) {
8499 dependent_vars.insert (argu);
85100 }
86101 }
87-
88102 should_run.push_back (true );
89103 } else {
90104 should_run.push_back (false );
@@ -95,45 +109,81 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
95109 // we reverse the should_run vector
96110 std::reverse (should_run.begin (), should_run.end ());
97111
98- *output = input;
99- auto * op_field = output->mutable_blocks (block_id)->mutable_ops ();
112+ // copy the current block from input to output
113+ auto * block_field = output->mutable_blocks ();
114+ *block_field->Add () = input.blocks (block_id);
115+
116+ int output_block_id = output->blocks_size () - 1 ;
117+ auto * output_block = output->mutable_blocks (output_block_id);
118+ output_block->set_idx (output_block_id);
119+ output_block->set_parent_idx (parent_block_id);
120+
121+ auto * op_field = output_block->mutable_ops ();
100122 op_field->Clear ();
101123 for (size_t i = 0 ; i < should_run.size (); ++i) {
102124 if (should_run[i]) {
103- *op_field->Add () = input.blocks (block_id).ops (i);
125+ auto * op = op_field->Add ();
126+ *op = input.blocks (block_id).ops (i);
127+ if (HasSubBlock (*op)) {
128+ // create sub_block_dependent_vars here to help prune the sub block
129+ std::set<std::string> sub_block_dependent_vars;
130+ for (auto & var : op->inputs ()) {
131+ for (auto & argu : var.arguments ()) {
132+ sub_block_dependent_vars.insert (argu);
133+ }
134+ }
135+ for (auto & var : op->outputs ()) {
136+ for (auto & argu : var.arguments ()) {
137+ sub_block_dependent_vars.insert (argu);
138+ }
139+ }
140+ // GetSubBlockIndex(*op) is the idx of the sub_block in the input desc
141+ // output_block_id is the idx of the current block in the output desc
142+ prune_impl (input, output, GetSubBlockIndex (*op), output_block_id,
143+ sub_block_dependent_vars);
144+ }
104145 }
105146 }
106147
107148 // remove the VarDescs in BlockDesc that are not referenced in
108149 // the pruned OpDescs
109150 std::unordered_map<std::string, proto::VarDesc> var_map;
110- auto * var_field = output->mutable_blocks (block_id )->mutable_vars ();
151+ auto * var_field = output->mutable_blocks (output_block_id )->mutable_vars ();
111152 for (const auto & var : *var_field) {
112153 var_map[var.name ()] = var;
113154 }
114155
115- var_field-> Clear () ;
156+ std::set<std::string> var_names ;
116157 for (const auto & op : *op_field) {
117- // add VarDescs of all input arguments for each OpDesc
118158 auto & input_field = op.inputs ();
119159 for (auto & input_var : input_field) {
120160 for (auto & arg : input_var.arguments ()) {
121- *var_field->Add () = var_map[arg];
161+ if (var_map.count (arg) != 0 ) {
162+ var_names.insert (arg);
163+ }
122164 }
123165 }
124- // add VarDescs of all output arguments for each OpDesc
125166 auto & output_field = op.outputs ();
126167 for (auto & output_var : output_field) {
127168 for (auto & arg : output_var.arguments ()) {
128- *var_field->Add () = var_map[arg];
169+ if (var_map.count (arg) != 0 ) {
170+ var_names.insert (arg);
171+ }
129172 }
130173 }
131174 }
175+
176+ var_field->Clear ();
177+ for (const auto & name : var_names) {
178+ *var_field->Add () = var_map[name];
179+ }
132180}
133181
134182// TODO(fengjiayi): Prune() could be inplaced to avoid unnecessary copies
135183void Prune (const proto::ProgramDesc& input, proto::ProgramDesc* output) {
136- prune_impl (input, output, 0 );
184+ std::set<std::string> dependent_vars;
185+ output->clear_blocks ();
186+ prune_impl (input, output, 0 , -1 , dependent_vars);
137187}
138188
139189void inference_optimize_impl (const proto::ProgramDesc& input,
0 commit comments