1313// limitations under the License.
1414
1515#include " paddle/fluid/distributed/fleet_executor/runtime_graph.h"
16- #include < tuple>
1716#include " paddle/fluid/distributed/fleet_executor/task_node.h"
1817#include " paddle/fluid/framework/op_registry.h"
1918#include " paddle/fluid/framework/operator.h"
@@ -113,7 +112,7 @@ void RuntimeGraph::SplitProgramBasedFunctionality(const ProgramDesc& program) {
113112 for (const auto & op_desc : program.Block (0 ).AllOps ()) {
114113 ops_.emplace_back (OpRegistry::CreateOp (*op_desc));
115114 }
116- std::unordered_map<OpRole , std::vector<OperatorBase*>> role_to_ops;
115+ std::unordered_map<int64_t , std::vector<OperatorBase*>> role_to_ops;
117116 for (const auto & op : ops_) {
118117 int64_t op_role = op->Attr <int64_t >(" op_role" );
119118 OpRole new_op_role;
@@ -130,22 +129,23 @@ void RuntimeGraph::SplitProgramBasedFunctionality(const ProgramDesc& program) {
130129 " The op %s is None of LRSched, Forward, Backward or Optimize." ,
131130 op->Type ()));
132131 }
133- if (role_to_ops.find (new_op_role) == role_to_ops.end ()) {
134- role_to_ops.insert ({new_op_role, {}});
132+ int64_t new_op_role_id = static_cast <int64_t >(new_op_role);
133+ if (role_to_ops.find (new_op_role_id) == role_to_ops.end ()) {
134+ role_to_ops.insert ({new_op_role_id, {}});
135135 }
136- role_to_ops.at (new_op_role ).emplace_back (op.get ());
136+ role_to_ops.at (new_op_role_id ).emplace_back (op.get ());
137137 }
138138 int64_t cur_rank = exe_desc_.cur_rank ();
139139 int64_t task_id = cur_rank * functionality_order.size ();
140140 for (std::size_t i = 0 ; i < functionality_order.size (); ++i) {
141141 OpRole role = functionality_order[i];
142142 int64_t role_id = static_cast <int64_t >(role);
143- if (role_to_ops.find (role ) == role_to_ops.end ()) {
143+ if (role_to_ops.find (role_id ) == role_to_ops.end ()) {
144144 task_nodes_.emplace_back (
145145 TaskNode::CreateEmptyTaskNode (role_id, cur_rank, task_id));
146146 } else {
147147 task_nodes_.emplace_back (TaskNode::CreateTaskNode (
148- role_id, role_to_ops.at (role ), cur_rank, task_id));
148+ role_id, role_to_ops.at (role_id ), cur_rank, task_id));
149149 }
150150 ++task_id;
151151 }
0 commit comments