Skip to content

Commit 43547fb

Browse files
committed
refine
1 parent 849eb85 commit 43547fb

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

paddle/fluid/distributed/fleet_executor/runtime_graph.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
16-
#include <tuple>
1716
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
1817
#include "paddle/fluid/framework/op_registry.h"
1918
#include "paddle/fluid/framework/operator.h"
@@ -113,7 +112,7 @@ void RuntimeGraph::SplitProgramBasedFunctionality(const ProgramDesc& program) {
113112
for (const auto& op_desc : program.Block(0).AllOps()) {
114113
ops_.emplace_back(OpRegistry::CreateOp(*op_desc));
115114
}
116-
std::unordered_map<OpRole, std::vector<OperatorBase*>> role_to_ops;
115+
std::unordered_map<int64_t, std::vector<OperatorBase*>> role_to_ops;
117116
for (const auto& op : ops_) {
118117
int64_t op_role = op->Attr<int64_t>("op_role");
119118
OpRole new_op_role;
@@ -130,22 +129,23 @@ void RuntimeGraph::SplitProgramBasedFunctionality(const ProgramDesc& program) {
130129
"The op %s is None of LRSched, Forward, Backward or Optimize.",
131130
op->Type()));
132131
}
133-
if (role_to_ops.find(new_op_role) == role_to_ops.end()) {
134-
role_to_ops.insert({new_op_role, {}});
132+
int64_t new_op_role_id = static_cast<int64_t>(new_op_role);
133+
if (role_to_ops.find(new_op_role_id) == role_to_ops.end()) {
134+
role_to_ops.insert({new_op_role_id, {}});
135135
}
136-
role_to_ops.at(new_op_role).emplace_back(op.get());
136+
role_to_ops.at(new_op_role_id).emplace_back(op.get());
137137
}
138138
int64_t cur_rank = exe_desc_.cur_rank();
139139
int64_t task_id = cur_rank * functionality_order.size();
140140
for (std::size_t i = 0; i < functionality_order.size(); ++i) {
141141
OpRole role = functionality_order[i];
142142
int64_t role_id = static_cast<int64_t>(role);
143-
if (role_to_ops.find(role) == role_to_ops.end()) {
143+
if (role_to_ops.find(role_id) == role_to_ops.end()) {
144144
task_nodes_.emplace_back(
145145
TaskNode::CreateEmptyTaskNode(role_id, cur_rank, task_id));
146146
} else {
147147
task_nodes_.emplace_back(TaskNode::CreateTaskNode(
148-
role_id, role_to_ops.at(role), cur_rank, task_id));
148+
role_id, role_to_ops.at(role_id), cur_rank, task_id));
149149
}
150150
++task_id;
151151
}

0 commit comments

Comments
 (0)