Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 125 additions & 54 deletions paddle/fluid/distributed/fleet_executor/fleet_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"

#include <algorithm>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/distributed/fleet_executor/global.h"
Expand Down Expand Up @@ -53,40 +54,40 @@ FleetExecutor::~FleetExecutor() {
}
}

void FleetExecutor::Init(
const std::string& carrier_id,
const framework::ProgramDesc& program_desc,
framework::Scope* scope,
const platform::Place& place,
int64_t num_micro_batches,
const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank,
const std::vector<std::string>& inference_root_scope_vars,
const std::vector<framework::Scope*>& micro_scope_list) {
PADDLE_ENFORCE_GT(task_nodes.size(),
0,
platform::errors::InvalidArgument(
"Fleet executor is inited with empty task node"));
// TODO(fleet_exe devs): the unused_vars should be got from run time graph
std::vector<std::unique_ptr<framework::OperatorBase>> ops;
for (const auto& desc : program_desc.Block(0).AllOps()) {
ops.emplace_back(framework::OpRegistry::CreateOp(*desc));
namespace {
// Recursively collects every task reachable from `cur_task` through
// downstream edges whose dependency type is DependType::NORMAL.
//
// tasks:          all candidate task nodes; downstream ids are resolved
//                 against this list via task_id().
// cur_task:       the node whose downstream edges are walked.
// sub_block_task: output set; every matching task is inserted into it.
//
// A task's own downstream edges are expanded only the first time the task is
// inserted into `sub_block_task`. The resulting set is unchanged, but tasks
// reachable through several paths are not walked again and a cycle of NORMAL
// edges cannot cause unbounded recursion.
void GetSubBlockTask(const std::vector<TaskNode*>& tasks,
                     TaskNode* cur_task,
                     std::set<TaskNode*>* sub_block_task) {
  const auto& downstream = cur_task->downstream();
  const auto& id_to_dep_type = cur_task->id_to_dep_type();
  for (const auto& down : downstream) {
    const int64_t task_id = down.first;
    if (id_to_dep_type.at(task_id) != DependType::NORMAL) {
      continue;
    }
    for (TaskNode* task : tasks) {
      if (task->task_id() != task_id) {
        continue;
      }
      // emplace().second is true only for a newly inserted task.
      if (sub_block_task->emplace(task).second) {
        GetSubBlockTask(tasks, task, sub_block_task);
      }
    }
  }
}

// NOTE: For inference, the vars in inference_root_scope_vars must not be
// deleted during inference, because they may hold the inference results.
// If they are GCed, ZeroCopy-ing the results will fail.
void PreventVarsDelete(
std::unordered_map<const framework::OperatorBase*,
std::vector<std::string>>* unused_vars,
const std::vector<std::string>& vars_not_gc) {
std::vector<const framework::OperatorBase*> changed_ops;
for (auto pair : unused_vars) {

for (const auto& pair : *unused_vars) {
const framework::OperatorBase* op = pair.first;
std::vector<std::string> unused = pair.second;
for (auto name : inference_root_scope_vars) {
auto iter = std::find(unused.begin(), unused.end(), name);
if (iter != unused.end()) {
std::vector<std::string> cur_unused = pair.second;
for (auto name : vars_not_gc) {
auto iter = std::find(cur_unused.begin(), cur_unused.end(), name);
if (iter != cur_unused.end()) {
VLOG(3) << "Removing var: [" << name
<< "] from the unused vars list of op: [" << op->Type() << "]";
unused.erase(iter);
cur_unused.erase(iter);
if (std::find(changed_ops.begin(), changed_ops.end(), op) ==
changed_ops.end()) {
// record the op whose unused vars have been updated
Expand All @@ -95,48 +96,118 @@ void FleetExecutor::Init(
}
}
// update the unused vars list in the map
unused_vars[op] = unused;
unused_vars->at(op) = cur_unused;
}
for (auto op : changed_ops) {
auto iter = unused_vars.find(op);
const auto& iter = unused_vars->find(op);
if (iter->second.empty()) {
// remove those ops in the map that have empty unused vars list
VLOG(3) << "Removing op: [" << op->Type() << "] from unused_vars map.";
unused_vars.erase(iter);
unused_vars->erase(iter);
}
}
runtime_graph_ = std::make_shared<RuntimeGraph>();
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task;
for (auto task_node : task_nodes) {
task_node->SetUnusedVars(unused_vars);
if (task_node->type() == "Cond") {
std::vector<std::string> while_block_vars;
VLOG(3) << "Vars in while sub block:";
for (auto& var : program_desc.Block(1).AllVars()) {
VLOG(3) << var->Name();
while_block_vars.emplace_back(var->Name());
}
for (const auto& pair : unused_vars) {
if (pair.first->Type() == "while") {
for (const auto& var_name : pair.second) {
while_block_vars.emplace_back(var_name);
}
}
}

// Returns the names of the variables that the unused-variable analysis of
// block 0 attaches to `while` ops, after the names in `vars_not_gc` have been
// filtered out by PreventVarsDelete.
//
// The operators of block 0 are re-created from `program_desc`, analysed with
// framework::GetUnusedVars, and the unused-variable lists of every op whose
// type is "while" are concatenated into the result.
//
// program_desc: program whose block 0 is analysed.
// cond_task:    not used by this function; kept so call sites stay valid.
// vars_not_gc:  variable names that must be kept out of the unused lists.
std::vector<std::string> GetUnusedVarsAfterWhile(
    const framework::ProgramDesc& program_desc,
    TaskNode* cond_task,
    const std::vector<std::string>& vars_not_gc) {
  (void)cond_task;  // intentionally unused

  // `ops` owns the operators; the keys of `unused_vars` point into it, so it
  // must outlive every use of `unused_vars` below.
  std::vector<std::unique_ptr<framework::OperatorBase>> ops;
  for (const auto& desc : program_desc.Block(0).AllOps()) {
    ops.emplace_back(framework::OpRegistry::CreateOp(*desc));
  }
  auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {});
  PreventVarsDelete(&unused_vars, vars_not_gc);

  std::vector<std::string> while_block_vars;
  for (const auto& pair : unused_vars) {
    if (pair.first->Type() != "while") {
      continue;
    }
    while_block_vars.insert(
        while_block_vars.end(), pair.second.begin(), pair.second.end());
  }
  return while_block_vars;
}

} // namespace

void FleetExecutor::Init(
const std::string& carrier_id,
const framework::ProgramDesc& program_desc,
framework::Scope* scope,
const platform::Place& place,
int64_t num_micro_batches,
const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank,
const std::vector<std::string>& inference_root_scope_vars,
const std::vector<framework::Scope*>& micro_scope_list) {
PADDLE_ENFORCE_GT(task_nodes.size(),
0,
platform::errors::InvalidArgument(
"Fleet executor is inited with empty task node"));
// Set the unused var after running while op
std::set<TaskNode*> sub_block_tasks;
std::vector<std::string> while_block_vars;
for (const auto& task_node : task_nodes) {
if (task_node->type() == "Cond") {
GetSubBlockTask(task_nodes, task_node, &sub_block_tasks);
while_block_vars = GetUnusedVarsAfterWhile(
program_desc, task_node, inference_root_scope_vars);
VLOG(3) << "Vars will be gced after while op";
for (auto var : while_block_vars) {
VLOG(3) << var;
}
task_node->SetWhileBlockVars(while_block_vars);
}
}
std::vector<framework::OperatorBase*> sub_block_ops;
for (const auto& task_node : sub_block_tasks) {
for (const auto& op : task_node->ops()) {
sub_block_ops.emplace_back(op);
}
}
// Analyse the unused vars in block 0. The operators in block 1 should be
// passed in first, to prevent vars from being released and then removed soon
// after, since the unused vars in block 1 need to be analysed separately.
std::vector<std::unique_ptr<framework::OperatorBase>> ops;
for (const auto& task_node : task_nodes) {
for (const auto& op : task_node->ops()) {
ops.emplace_back(std::unique_ptr<framework::OperatorBase>(op));
}
}
auto global_unused_vars =
framework::GetUnusedVars(program_desc.Block(0), ops, {});

// Analyse the unused vars in block 1.
std::unordered_map<const framework::OperatorBase*, std::vector<std::string>>
sub_unused_vars;
if (program_desc.Size() > 1) {
sub_unused_vars = framework::GetUnusedVars(program_desc.Block(1), ops, {});
PreventVarsDelete(&sub_unused_vars, while_block_vars);
}
for (auto& unique_op : ops) {
unique_op.release();
}

// NOTE: For inference, the vars in inference_root_scope_vars must not be
// deleted during inference, because they may hold the inference results.
// If they are GCed, ZeroCopy-ing the results will fail.
PreventVarsDelete(&global_unused_vars, inference_root_scope_vars);

runtime_graph_ = std::make_shared<RuntimeGraph>();
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task;
for (auto task_node : task_nodes) {
if (sub_block_tasks.find(task_node) == sub_block_tasks.end()) {
task_node->SetUnusedVars(global_unused_vars);
} else {
task_node->SetUnusedVars(sub_unused_vars);
}
int64_t interceptor_id = task_node->task_id();
interceptor_id_to_task.emplace(interceptor_id, task_node);
}
runtime_graph_->SetInterceptorIdToRank(task_id_to_rank);
runtime_graph_->SetInterceptorIdToNode(interceptor_id_to_task);
for (auto& unique_op : ops) {
unique_op.release();
}

VLOG(5) << runtime_graph_->DebugString();
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
Expand Down