
Support control flow in DataParallel#31625

Merged
ForFishes merged 42 commits into PaddlePaddle:develop from ForFishes:support_control_flow on Apr 1, 2021

Conversation

ForFishes (Member) commented on Mar 15, 2021

PR types

New features

PR changes

Others

Describe

Support control flow in DataParallel
DataParallel now exposes the find_unused_parameters interface, which controls whether the reducer detects parameters that receive no gradient during the backward pass. This is a backward-compatible upgrade.
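Why this option matters under control flow can be sketched in plain Python. The `Reducer` class and method names below are illustrative only, not Paddle's actual implementation:

```python
# Illustrative sketch: a data-parallel reducer waits for every parameter's
# gradient before launching the fused all-reduce. If a control-flow branch
# skips some parameters, the reducer would wait forever -- unless unused
# parameters are detected and marked ready proactively.

class Reducer:
    def __init__(self, param_names, find_unused_parameters=False):
        self.find_unused_parameters = find_unused_parameters
        self.ready = {name: False for name in param_names}

    def mark_ready(self, name):
        # Called from the gradient hook of each parameter.
        self.ready[name] = True

    def can_allreduce(self):
        if self.find_unused_parameters:
            # Unused parameters are found after forward and marked
            # ready up front, so the group is always complete.
            return True
        # Otherwise every parameter must produce a gradient.
        return all(self.ready.values())


r = Reducer(["w1", "w2"], find_unused_parameters=False)
r.mark_ready("w1")          # the branch taken this step never touched w2
print(r.can_allreduce())    # False: the all-reduce would block

r2 = Reducer(["w1", "w2"], find_unused_parameters=True)
r2.mark_ready("w1")
print(r2.can_allreduce())   # True
```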

This PR also did the following work:

  1. Improve the overall error messages, and report failures earlier when the program does not meet DataParallel's execution requirements (for example, multiple forward passes followed by a single backward pass).
  2. Support control-flow branches, covering these cases:
  • within the same step, gradients differ across cards;
  • across different steps, gradients differ on the same card;
  • both of the above under gradient accumulation.
  3. Support syncing parameters and buffers in the reducer.
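Item 3 can be sketched as follows. `broadcast_from_rank0` is a hypothetical stand-in for the real broadcast collective, used here only to show the intent of syncing parameters and buffers across cards:

```python
# Illustrative sketch, not Paddle code: before training, rank 0's parameter
# and buffer values are broadcast so every card starts from identical state.

def broadcast_from_rank0(per_rank_state):
    """Emulate a broadcast collective over a list of per-rank state dicts."""
    source = per_rank_state[0]
    for state in per_rank_state[1:]:
        for name, value in source.items():
            state[name] = value  # overwrite with rank 0's value
    return per_rank_state


# Two "cards" whose parameters and batch-norm buffers have drifted apart.
ranks = [{"w": 1.0, "bn_mean": 0.5}, {"w": 3.0, "bn_mean": 0.9}]
synced = broadcast_from_rank0(ranks)
print(synced[1])  # {'w': 1.0, 'bn_mean': 0.5}
```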

@paddle-bot-old commented:

Thanks for your contribution!
Please wait for the CI result first. See the Paddle CI Manual for details.

// TODO(liuyuhui) support XPU set constant
VLOG(3) << "XPU doesn't support set_constant";
}
if (platform::is_xpu_place(group_tensor.place())) {
Contributor:

Move this into the else branch below?

Member (Author):

Done.

VLOG(3) << "Local used vars : "
<< string::join_strings(local_used_vars_, ',');
// TODO(liuyuhui): support bkcl in using TensorToVector/TensorFromVector
#if defined(PADDLE_WITH_NCCL)
Contributor:

Add RCCL here too?

});
#elif defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)
FusedAllReduceSchedule(run_order, group);
FusedAllReduceSchedule(run_order, group, next_group_);
Contributor:

BKCL was missed here.

Member (Author):

The BKCL code is above; only NCCL is handled here.

Contributor:

next_group_ wasn't added in the BKCL path =.=

Member (Author):

Ah, so that's why the unit test failed. Fixed.


if (find_unused_vars_) {
// TODO(liuyuhui) support xpu about Tensorcopy
#if defined(PADDLE_WITH_NCCL)
Contributor:

Add RCCL here too?

strategy.find_unused_parameters = True
"""

return self.strategy.sync_batch_norm
Contributor:

This still returns sync_batch_norm.

last_comm_group_size_MB)
last_comm_group_size_MB,
find_unused_parameters=self._user_defined_strategy.
find_unused_parameters)
Contributor:

This looks duplicated.

Member (Author):

That's just automatic line wrapping; the parameter isn't actually duplicated.

}

void BKCLParallelContext::SynchronizeCompute() {
// TODO(wangxi16): [Performance optimize] Maybe need to put Wait and
Contributor:

Please delete this TODO and the one above 0.0

Member (Author):

OK.

virtual void WaitComm(int ring_id) = 0;

// synchronize compute stream
virtual void SynchronizeCompute() = 0;
Contributor:

While you're at it, please also add SynchronizeCompute in bkcl_context~

Member (Author):

Added.

}

if (find_unused_vars_) {
// TODO(liuyuhui) support xpu about Tensorcopy
Contributor:

Doesn't TensorCopy already support the XPU place?~

Member (Author):

OK, but the two cases above are still unsupported, so let's remove this TODO along with them the next time XPU support is added.

class TestFleetDygraphControlFlowSame(TestDygraphControlFlowSame):
def _setup_config(self):
self._sync_mode = False
self._nccl2_mode = True
Contributor:

Add a new UT case about bkcl_mode for XPU?

Member (Author):

XPU is not supported yet.

self._dygraph = True

def test_mnist(self):
def test_net(self):
Contributor:

Add a new UT case about bkcl_mode for XPU?

Member (Author):

Same as above; XPU is not supported yet.

auto dtype = var->DataType();
auto place = var->Place();
const auto dtype = var->DataType();
const auto place = var->Place();
Contributor:
Use const&?

Member (Author):

OK, done.

void Reducer::PrepareForBackward(
const std::vector<std::shared_ptr<imperative::VarBase>> &outputs) {
VLOG(3) << "start reseting count..";
VLOG(3) << "start forward and reset count.";
Contributor @hutuxian commented on Mar 22, 2021:

Why start forward? This function is called after forward, right?

auto tensor =
var_warpper->MutableVar()->GetMutable<framework::LoDTensor>();
auto var_base = vars_[var_index]->GradVarBase();
auto tensor = var_base->MutableVar()->GetMutable<framework::LoDTensor>();
Contributor:
what's the difference?


void Reducer::FusedAllReduceSchedule(int run_order, Group &group) {
void Reducer::FusedAllReduceSchedule(const int run_order, Group &group,
const int curr_group_index) {
Contributor:

Why add curr_group_index?

Member (Author):

Just for printing a log.

std::vector<bool> vars_marked_ready_;

// Following variables are to help control flow,
std::vector<int> local_used_vars_;
Contributor:

Could we find a better name, or add a comment explaining this vector? After the all-reduce, what it denotes is global_used_vars rather than local_used_vars.
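The distinction the reviewer raises can be sketched in plain Python; the helper below emulates the elementwise all-reduce over per-rank usage masks and is not Paddle code:

```python
# Illustrative sketch: each rank records locally which variables produced a
# gradient (1 = used, 0 = unused). An elementwise sum across ranks turns the
# local masks into a global usage mask -- so after the all-reduce the buffer
# no longer holds "local" usage, which is the naming concern above.

def allreduce_used(local_masks):
    """Emulate an elementwise sum all-reduce across ranks."""
    return [sum(vals) for vals in zip(*local_masks)]


rank0 = [1, 0, 1]   # rank 0 used vars 0 and 2
rank1 = [0, 0, 1]   # rank 1 used only var 2
global_used = allreduce_used([rank0, rank1])
print(global_used)  # [1, 0, 2]: var 1 is unused on every rank
```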

Member (Author):
done

Contributor @hutuxian left a comment:

LGTM

Contributor @XiaoguangHu01 left a comment:

LGTM

Contributor @jzhang533 left a comment:

LGTM for the find_unused_parameters API exposed in Python.

@ForFishes ForFishes merged commit 8460698 into PaddlePaddle:develop Apr 1, 2021
@ForFishes ForFishes deleted the support_control_flow branch April 1, 2021 08:11