
Make fuse_optimizer_op_pass also work when the model contains sparse gradients.#18664

Merged
chengduoZH merged 11 commits into PaddlePaddle:develop from chengduoZH:support_sparse_gradient
Jul 23, 2019

Conversation

Contributor

@chengduoZH chengduoZH commented Jul 17, 2019

  1. Make fuse_optimizer_op_pass also work when the model contains sparse gradients.
  2. Polish the code used to support fuse_optimizer_op_pass.

chengduozh added 2 commits July 17, 2019 13:13
test=develop
@chengduoZH chengduoZH requested review from Xreki and gongweibao July 17, 2019 08:22
chengduozh added 3 commits July 17, 2019 17:14
@chengduoZH chengduoZH force-pushed the support_sparse_gradient branch 2 times, most recently from 7bdea4e to 4a73988 Compare July 17, 2019 12:51
@chengduoZH chengduoZH force-pushed the support_sparse_gradient branch from 4a73988 to 419f342 Compare July 18, 2019 05:22
@chengduoZH chengduoZH changed the title Support sparse gradients for fuse_optimizer_op_pass Make fuse_optimizer_op_pass also work when the model contains sparse gradients. Jul 18, 2019
result.Get<details::ParamsAndGrads>(details::kParamsAndSparseGrads);

for (auto &param_grad : params_grads) {
if (IsSupportedVarType(GetTypeOfVar(vars_info, param_grad.second))) {
Contributor

IsLodTensorVartype or IsDenseGradVarType

Contributor Author

Done

if (node->Op()->Type() == fuse_op_type) {
auto grad_name = node->Op()->Input(kGrad);
PADDLE_ENFORCE_EQ(grad_name.size(), static_cast<size_t>(1));
if (GetTypeOfVar(vars_info, grad_name[0]) == proto::VarType::LOD_TENSOR) {
Contributor

IsDenseGradVarType

Contributor Author

Done

const std::string prefix(details::kFusedVarNamePrefix);
// NOTE: the fused_var_name should be unique.
for (auto &var_name : aux_var_names) {
// NOTE: the fused_var_name should be unique.
Contributor

How to guarantee it?

Contributor Author

Line 81 is used to check this.

"The VarDescs of the persistable variables are not consistent.");
PADDLE_ENFORCE(graph == native_graph,
"Pass::Apply() cannot delete the passed graph and shouldn't "
"return a new graph.(For the need of pybind11)");
Contributor Author

This check is unnecessary.


if (result.Has(details::kParamsAndGrads)) {
auto &params_grads =
result.Get<details::ParamsAndGrads>(details::kParamsAndGrads);
if (result.Has(details::kParamsAndDenseGrads)) {
Contributor

This nested if is too long.

@chengduoZH chengduoZH force-pushed the support_sparse_gradient branch from 864393f to 3d011e7 Compare July 18, 2019 09:45
gongweibao previously approved these changes Jul 18, 2019

@gongweibao gongweibao left a comment

LGTM

test=develop
@chengduoZH chengduoZH force-pushed the support_sparse_gradient branch from 464b882 to 126d0a0 Compare July 21, 2019 07:42
}
}
}
}
Contributor Author

Define the fused variables in the local execution scope.
Some models consist of more than one program, and those programs may share parameters. Under the previous strategy, the fused gradients of those shared parameters were also shared across programs, which is problematic, so the fused gradient variables should be defined in each program's local execution scope.

Contributor

Copying these lines into a code comment may be better.
And which unit test covers this?


gongweibao previously approved these changes Jul 22, 2019

@gongweibao gongweibao left a comment

LGTM

test=develop
@gongweibao gongweibao left a comment

LGTM

@Xreki Xreki left a comment

LGTM

@chengduoZH chengduoZH merged commit fd3aad6 into PaddlePaddle:develop Jul 23, 2019