9 changes: 7 additions & 2 deletions paddle/fluid/framework/device_worker.h
@@ -548,6 +548,7 @@ class SectionWorker : public DeviceWorker {
~SectionWorker() override {}

void Initialize(const TrainerDesc& desc) override;
void PrepareUnusedVar();

void BindingDataFeedMemory() override {}
void CreateDeviceResource(const ProgramDesc& main_prog) override{};
@@ -581,7 +582,8 @@
void RunUpdate(
std::unique_ptr<GarbageCollector>&,
std::unordered_map<const OperatorBase*, std::vector<std::string>>&);
void PrepareUnusedVar();
void RunFThenB(std::unique_ptr<GarbageCollector>&);
void Run1F1B(std::unique_ptr<GarbageCollector>&);

protected:
int section_id_;
@@ -591,9 +593,12 @@
int pipeline_stage_;
int schedule_mode_; // 0 for F-then-B and 1 for 1F1B
std::vector<Scope*> microbatch_scopes_;
std::vector<std::string> skip_vars_;
const Scope* minibatch_scope_;

// skip_vars_ and backward_send_vars_ are only used by the 1F1B scheduler
std::vector<std::string> skip_vars_;
std::vector<std::string> backward_send_vars_;

std::vector<std::unique_ptr<OperatorBase>> ops_;
std::shared_ptr<framework::ProgramDesc> program_;
std::unordered_map<const OperatorBase*, std::vector<std::string>>
29 changes: 17 additions & 12 deletions paddle/fluid/framework/executor_gc_helper.cc
@@ -146,18 +146,9 @@ GetUnusedVars(const BlockDesc &block,
return result;
}

void DeleteUnusedTensors(
const Scope &scope, const OperatorBase *op,
const std::unordered_map<const OperatorBase *, std::vector<std::string>>
&delete_vars_map,
GarbageCollector *gc) {
auto iter = delete_vars_map.find(op);
if (iter == delete_vars_map.end()) {
return;
}

auto &delete_vars = iter->second;

void DeleteUnusedTensors(const Scope &scope,
const std::vector<std::string> &delete_vars,
GarbageCollector *gc) {
std::deque<std::shared_ptr<memory::Allocation>> garbages;

for (auto &var_name : delete_vars) {
@@ -189,6 +180,20 @@ void DeleteUnusedTensors(
}
}

void DeleteUnusedTensors(
const Scope &scope, const OperatorBase *op,
const std::unordered_map<const OperatorBase *, std::vector<std::string>>
&delete_vars_map,
GarbageCollector *gc) {
auto iter = delete_vars_map.find(op);
if (iter == delete_vars_map.end()) {
return;
}

auto &delete_vars = iter->second;
DeleteUnusedTensors(scope, delete_vars, gc);
}

static std::vector<std::unique_ptr<OperatorBase>> CreateOpsFromBlock(
const BlockDesc &block) {
std::vector<std::unique_ptr<OperatorBase>> ops;
5 changes: 5 additions & 0 deletions paddle/fluid/framework/executor_gc_helper.h
@@ -36,6 +36,11 @@ GetUnusedVars(const BlockDesc &block,
const std::vector<std::unique_ptr<OperatorBase>> &ops,
const std::vector<std::string> &skip_vars);

// Collect unused tensors given an explicit list of variables
void DeleteUnusedTensors(const Scope &scope,
const std::vector<std::string> &delete_vars,
GarbageCollector *gc);

// Collect unused tensors after op runs
void DeleteUnusedTensors(
const Scope &scope, const OperatorBase *op,
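The list-based overload lets a caller free an explicit set of variables outside the per-op GC map; the map-based overload now just looks up the op's list and forwards to it. A minimal sketch of a call site (the variable name is illustrative; this mirrors how Run1F1B frees backward_send_vars_ per microbatch scope):

// Hypothetical call site; scope is a Scope&, gc a GarbageCollector*.
std::vector<std::string> vars_to_free = {"linear_0.w_0@GRAD@SEND"};  // assumed name
DeleteUnusedTensors(scope, vars_to_free, gc);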
2 changes: 1 addition & 1 deletion paddle/fluid/framework/pipeline_trainer.cc
@@ -45,11 +45,11 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
auto this_worker =
std::dynamic_pointer_cast<paddle::framework::SectionWorker>(worker_);
this_worker->SetPlace(place_);
this_worker->Initialize(trainer_desc);
this_worker->SetMicrobatchNum(num_microbatches_);
this_worker->SetPipelineStageNum(num_pipeline_stages_);
this_worker->SetPipelineStage(pipeline_stage_);
this_worker->SetScheduleMode(schedule_mode_);
this_worker->Initialize(trainer_desc);
}
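Note: this reorder matters. Initialize now reads schedule_mode_ and pipeline_stage_ (see section_worker.cc below) to collect the backward send vars for 1F1B, so those fields must be set before Initialize runs.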

void PipelineTrainer::InitOtherEnv(const ProgramDesc& main_program) {
157 changes: 106 additions & 51 deletions paddle/fluid/framework/section_worker.cc
@@ -30,6 +30,36 @@ void SectionWorker::Initialize(const TrainerDesc &desc) {
for (auto &op_desc : program_->Block(0).AllOps()) {
ops_.push_back(OpRegistry::CreateOp(*op_desc));
}

// the remaining setup only applies to the 1F1B scheduler
if (schedule_mode_ != 1) return;

bool is_first_stage = (pipeline_stage_ == 0);
int BACKWARD = static_cast<int>(OpRole::kBackward);
for (auto &op : ops_) {
int op_role = op->Attr<int>("op_role");
auto op_type = op->Type();

// pipeline backward send op
if (op_role != BACKWARD) continue;
if (op_type != "send_v2" && op_type != "partial_send") continue;

auto var_name = op->InputVars()[0];
VLOG(3) << "Pipeline backward send var " << var_name;
PADDLE_ENFORCE_NE(is_first_stage, true,
platform::errors::PreconditionNotMet(
"The first pipeline stage must do not have a "
"backward send var, please check var %s",
var_name));

backward_send_vars_.push_back(var_name);
skip_vars_.push_back(var_name);
}
}

void SectionWorker::PrepareUnusedVar() {
VLOG(5) << "begin prepare the unsed vars";
unused_vars_ = GetUnusedVars(program_->Block(0), ops_, skip_vars_);
}
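Passing skip_vars_ to GetUnusedVars keeps the backward send vars out of the per-op GC map, so RunBackward does not free them; Run1F1B later frees them explicitly through the new list-based DeleteUnusedTensors overload.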

void SectionWorker::RunForward(
@@ -96,9 +126,79 @@ void SectionWorker::RunUpdate(
}
}

void SectionWorker::PrepareUnusedVar() {
VLOG(5) << "begin prepare the unsed vars";
unused_vars_ = GetUnusedVars(program_->Block(0), ops_, skip_vars_);
void SectionWorker::RunFThenB(std::unique_ptr<GarbageCollector> &gc) {
// F-then-B scheduler which runs Forward phase for all microbatches,
// then runs Backward phase for all microbatches.
// step1: run forward
for (int i = 0; i < num_microbatches_; ++i) {
RunForward(i, gc, unused_vars_);
}
// step2: run backward
for (int i = 0; i < num_microbatches_; ++i) {
RunBackward(i, gc, unused_vars_);
}
// step3: run update
RunUpdate(gc, unused_vars_);
}

void SectionWorker::Run1F1B(std::unique_ptr<GarbageCollector> &gc) {
// 1F1B scheduler, which runs forward phase and backward phase alternately
// after startup phase. For a stage, the number of microbatches for
// startup is num_pipeline_stages_ - pipeline_stage_ - 1, where
// num_pipeline_stages_ is the total number of pipeline stages and
// pipeline_stage_ is the pipeline stage of the current device.
auto startup_steps = num_pipeline_stages_ - pipeline_stage_ - 1;
VLOG(3) << "startup_steps:" << startup_steps
<< ", num_stages: " << num_pipeline_stages_
<< ", stage:" << pipeline_stage_;
PADDLE_ENFORCE_GT(
num_microbatches_, startup_steps,
platform::errors::InvalidArgument(
"To use pipeline with 1F1B scheduler, please make sure number of "
"microbatches (%d) is than startup steps (%d).",
num_microbatches_, startup_steps));
int fw_step = 0;
int bw_step = 0;

// startup phase
while (fw_step < startup_steps) {
RunForward(fw_step, gc, unused_vars_);
fw_step += 1;
}

// 1f1b phase
while (fw_step < num_microbatches_) {
RunForward(fw_step, gc, unused_vars_);

// free the backward send vars of microbatch (bw_step - 2)
if (gc && bw_step >= 2) {
DeleteUnusedTensors(*microbatch_scopes_[bw_step - 2], backward_send_vars_,
gc.get());
}

RunBackward(bw_step, gc, unused_vars_);

fw_step += 1;
bw_step += 1;
}

int reserve_bw_send_step = bw_step - 2;
// backward phase
while (bw_step < num_microbatches_) {
RunBackward(bw_step, gc, unused_vars_);
bw_step += 1;
}

RunUpdate(gc, unused_vars_);

if (gc) {
// NOTE(wangxi): the program must sync the backward send comm at update,
// so the remaining backward send vars can be deleted safely here
for (int i = reserve_bw_send_step; i < num_microbatches_; ++i) {
DeleteUnusedTensors(*microbatch_scopes_[i], backward_send_vars_,
gc.get());
}
}
}
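To make the schedule concrete, the following standalone sketch (independent of Paddle, with assumed values num_stages = 4, stage = 1, num_microbatches = 8) prints the order in which Run1F1B executes its phases: startup forwards, alternating 1F1B steps, draining backwards, then the update.

#include <cstdio>

int main() {
  const int num_stages = 4, stage = 1, num_microbatches = 8;  // assumed values
  const int startup_steps = num_stages - stage - 1;           // 2 here
  int fw = 0, bw = 0;
  while (fw < startup_steps) std::printf("F%d ", fw++);       // startup phase
  while (fw < num_microbatches) {                             // steady 1F1B phase
    std::printf("F%d B%d ", fw, bw);
    ++fw;
    ++bw;
  }
  while (bw < num_microbatches) std::printf("B%d ", bw++);    // drain phase
  std::printf("U\n");                                         // parameter update
  // Prints: F0 F1 F2 B0 F3 B1 F4 B2 F5 B3 F6 B4 F7 B5 B6 B7 U
  return 0;
}

The (bw_step - 2) offset in the real code defers freeing each microbatch's backward send vars by two backward steps, and the vars from reserve_bw_send_step onward are freed only after RunUpdate, once the program has synced the backward send comm (per the NOTE above).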

void SectionWorker::TrainFiles() {
@@ -132,56 +232,11 @@ void SectionWorker::TrainFiles() {
} // max_memory_size >= 0

if (schedule_mode_ == 0) {
// F-then-B scheduler which runs Forward phase for all microbatches,
// then runs Backward phase for all microbatches.
// step1: run forward
for (int i = 0; i < num_microbatches_; ++i) {
RunForward(i, gc, unused_vars_);
}
// step2: run backward
for (int i = 0; i < num_microbatches_; ++i) {
RunBackward(i, gc, unused_vars_);
}
// step3: run update
RunUpdate(gc, unused_vars_);
RunFThenB(gc);
} else {
// 1F1B scheduler, which runs forward phase and backward phase alternately
// after startup phase. For a stage, the number of microbatches for
// startup is num_pipeline_stages_ - pipeline_stage_ - 1, where
// num_pipeline_stages_ is the total number of pipeline stages and
// pipeline_stage_ is the pipeline stage of the current device.
auto startup_steps = num_pipeline_stages_ - pipeline_stage_ - 1;
VLOG(3) << "startup_steps:" << startup_steps
<< ", num_stages: " << num_pipeline_stages_
<< ", stage:" << pipeline_stage_;
PADDLE_ENFORCE_GT(
num_microbatches_, startup_steps,
platform::errors::InvalidArgument(
"To use pipeline with 1F1B scheduler, please make sure number of "
"microbatches (%d) is than startup steps (%d).",
num_microbatches_, startup_steps));
int fw_step = 0;
int bw_step = 0;
// startup phase
while (fw_step < startup_steps) {
RunForward(fw_step, gc, unused_vars_);
fw_step += 1;
}

// 1f1b phase
while (fw_step < num_microbatches_) {
RunForward(fw_step, gc, unused_vars_);
fw_step += 1;
RunBackward(bw_step, gc, unused_vars_);
bw_step += 1;
}
// backward phase
while (bw_step < num_microbatches_) {
RunBackward(bw_step, gc, unused_vars_);
bw_step += 1;
}
RunUpdate(gc, unused_vars_);
Run1F1B(gc);
}

dev_ctx_->Wait();
++batch_id_;
}
1 change: 1 addition & 0 deletions python/paddle/fluid/optimizer.py
@@ -5277,6 +5277,7 @@ def _optimize_forward_send_sync(self, program):
backward_recv_index = index
break

# no backward recv op means this is the last pipeline stage
if backward_recv_index is None: return

offset = 0
2 changes: 1 addition & 1 deletion python/paddle/fluid/tests/unittests/test_dist_base.py
@@ -1258,7 +1258,7 @@ def _get_required_envs(self, check_error_log=False, need_envs={}):
"fused_all_reduce_op_handle=10,all_reduce_op_handle=10,alloc_continuous_space_op=10,fuse_all_reduce_op_pass=10," \
"alloc_continuous_space_for_grad_pass=10,fast_threaded_ssa_graph_executor=10,executor=10,operator=10," \
"sparse_all_reduce_op_handle=10,gen_nccl_id_op=10,gen_nccl_id_op_help=10,nccl_helper=10,grpc_client=10," \
"grpc_server=10,request_handler_impl=10"
"grpc_server=10,request_handler_impl=10,section_worker=10"
required_envs["GLOG_logtostderr"] = "1"

required_envs.update(need_envs)