ParallelDyGraph with GPU collective mode #16827

Merged
Yancey0623 merged 11 commits into PaddlePaddle:develop from Yancey0623:prallel_dygraph_gpu
Apr 25, 2019

Conversation

@Yancey0623
Contributor

@Yancey0623 Yancey0623 commented Apr 12, 2019

Implement dygraph.parallel.DataParallel, which hooks collective ops into the backward pass.

TODO:

  • Add an interface to broadcast parameters from node 0.
  • Refine the DataParallel API to make it simpler.
  • Implement a VarBase hook so that we can hook an op's output gradients, filtered by parameter type.
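
The first TODO item, broadcasting parameters from node 0, could look roughly like the sketch below. It simulates the ranks as plain Python dicts in one process; `broadcast_from_rank0` and the parameter layout are hypothetical names for illustration, not the interface this PR or Paddle provides.

```python
import copy


def broadcast_from_rank0(per_rank_params):
    """Overwrite every rank's parameters with a copy of rank 0's parameters.

    Hypothetical helper: ranks are simulated as list entries in one process.
    """
    source = per_rank_params[0]
    for rank in range(1, len(per_rank_params)):
        # Each rank gets its own copy, as it would on a separate device.
        per_rank_params[rank] = copy.deepcopy(source)
    return per_rank_params


# Three simulated ranks that start from different initial values.
ranks = [{"w": [0.1, 0.2]}, {"w": [0.9, 0.8]}, {"w": [0.5, 0.5]}]
broadcast_from_rank0(ranks)
```

After the call, every simulated rank starts training from the same parameter values, which is the usual precondition for data-parallel training.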

@Yancey0623 Yancey0623 changed the title from "[WIP] ParallelDyGraph with GPU collective mode" to "ParallelDyGraph with GPU collective mode" Apr 18, 2019
@velconia
Collaborator

Cool job~

junjun315 previously approved these changes Apr 24, 2019
Contributor

@junjun315 junjun315 left a comment

LGTM, with some comments.

@junjun315 junjun315 self-requested a review April 24, 2019 02:27
}

-void OpBase::RegisterBackwardHooks(const py::object& callable) {
+void OpBase::RegisterBackwardHooks(const py::object& callable, bool front) {
Collaborator

Why do we need front?

Contributor Author

trace.py appends release_op as the first grad hook, so a hook that needs the op object has to be inserted in front of it. Maybe we could store the hooks in a stack instead of a vector?
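
To illustrate the ordering issue behind the front flag, here is a minimal sketch in plain Python. `HookList`, `FakeOp`, and `allreduce_hook` are hypothetical stand-ins, not the actual OpBase implementation; only the idea of a release hook being registered first comes from the comment above.

```python
class HookList:
    """Hypothetical container for backward hooks (not Paddle's OpBase)."""

    def __init__(self):
        self._hooks = []

    def register(self, hook, front=False):
        # front=True places the hook ahead of all previously registered hooks.
        if front:
            self._hooks.insert(0, hook)
        else:
            self._hooks.append(hook)

    def run(self, op):
        # Hooks run in list order.
        for hook in self._hooks:
            hook(op)


class FakeOp:
    """Stand-in for an op whose resources are freed by a release hook."""

    def __init__(self):
        self.released = False


calls = []


def release_op(op):
    op.released = True
    calls.append("release_op")


def allreduce_hook(op):
    # Record whether the op was still alive when this hook ran.
    calls.append("allreduce(released)" if op.released else "allreduce(alive)")


hooks = HookList()
hooks.register(release_op)                  # registered first, as described above
hooks.register(allreduce_hook, front=True)  # must still see the live op
hooks.run(FakeOp())
```

With front=True the later hook runs before release_op; without it, the hook would only see an op that has already been released. A stack (last registered, first run) would give the same ordering without an extra flag.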

.def("_grad_ivar",
[](const imperative::VarBase &self) { return self.grads_; },
py::return_value_policy::reference)
.def("_set_grad_ivar", [](imperative::VarBase &self,
Collaborator

Please remove this unused method.

Contributor Author

Will do that, thx.



-def prepare_context(parallel_strategy, place):
+def prepare_context(parallel_strategy, place=None):
Collaborator

Do we need place here?

Contributor Author

Not necessary so far; I can remove it.

collective._allreduce(g_var, g_var, sync_mode=True)

outs = self._layers(*inputs, **kwargs)
for _, op in six.iteritems(_dygraph_tracer()._ops):
Collaborator

Should the outputs of all ops be allreduced?

Contributor Author

@Yancey0623 Yancey0623 Apr 24, 2019

Maybe not? The trainable parameters only appear as op inputs?

name=ivar._grad_name(),
stop_gradient=True,
ivar=g)
collective._allreduce(g_var, g_var, sync_mode=True)
Collaborator

Why is the output also g_var?

Contributor Author

The allreduce is applied to the grad var in place.

Collaborator

Cool
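
As a rough illustration of what "in place" means for this exchange, the sketch below simulates a sum-allreduce over plain Python lists, one list per rank. `allreduce_sum_inplace` is a hypothetical helper, and summation is an assumption here; the thread does not state which reduction collective._allreduce applies.

```python
def allreduce_sum_inplace(buffers):
    """Simulate a sum-allreduce across ranks, writing the result back in place.

    Hypothetical helper: each entry of `buffers` is one rank's gradient buffer.
    """
    # Element-wise sum across all ranks.
    total = [sum(values) for values in zip(*buffers)]
    for buf in buffers:
        # Slice assignment overwrites the existing list rather than rebinding it,
        # which mirrors passing the same variable as both input and output.
        buf[:] = total


grad_rank0 = [1.0, 2.0]
grad_rank1 = [3.0, 4.0]
alias_rank0 = grad_rank0  # another reference to the same buffer

allreduce_sum_inplace([grad_rank0, grad_rank1])
```

Because the buffer is overwritten rather than replaced, any other holder of the gradient variable (here `alias_rank0`) sees the reduced values without extra bookkeeping.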

Collaborator

@velconia velconia left a comment

LGTM with a few minor comments.

@velconia
Collaborator

  1. Backward hooks should be supported on parameters.
  2. Should an op's output gradients be filtered by parameter type?

@Yancey0623
Contributor Author

Backward hooks should be supported on parameters.
Should an op's output gradients be filtered by parameter type?

Added in the TODO list.
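
One way the "filter by parameter type" idea could look is sketched below. The `Var` class, its fields, and `grads_to_sync` are hypothetical; the "@GRAD" suffix is used only as an illustrative naming convention and is not taken from this PR.

```python
from dataclasses import dataclass


@dataclass
class Var:
    """Hypothetical stand-in for a variable seen by the tracer."""
    name: str
    is_parameter: bool
    stop_gradient: bool = False


def grads_to_sync(variables):
    """Return gradient names only for trainable parameters.

    Intermediate activations and frozen parameters are skipped, so the
    collective op would run only where it is needed.
    """
    return [
        v.name + "@GRAD"  # assumed naming convention, for illustration only
        for v in variables
        if v.is_parameter and not v.stop_gradient
    ]


variables = [
    Var("fc_0.w_0", is_parameter=True),
    Var("tmp_0", is_parameter=False),                        # activation
    Var("emb_0.w_0", is_parameter=True, stop_gradient=True), # frozen
]
selected = grads_to_sync(variables)
```

Only the gradient of `fc_0.w_0` is selected; the activation and the frozen parameter are left out of synchronization.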

@Yancey0623 Yancey0623 merged commit 0b07eef into PaddlePaddle:develop Apr 25, 2019
@Yancey0623 Yancey0623 deleted the prallel_dygraph_gpu branch April 25, 2019 09:41
sneaxiy pushed a commit to sneaxiy/Paddle that referenced this pull request Apr 28, 2019
# The first commit's message is:
remove ut test_dist_word2vec in mac ci, will fix it in private, test=develop (PaddlePaddle#17066)

# This is the 2nd commit message:

Fleet unify distributed training (PaddlePaddle#16791)

* implement distributed transpiler with fleet
# This is the 3rd commit message:

ParallelDyGraph with GPU collective mode (PaddlePaddle#16827)

implement dygraph.parallel.DataParallel to hook reduce op.

# This is the 4th commit message:

Init mixed precision training interface (PaddlePaddle#16856)

* Init mixed precision training interface

* Add fp16 test script

test=develop

* All initializers support float16

test=develop

* Code cleanup & add more code annotations

test=develop

* Update API spec

test=develop

* Add usage example in doc

test=develop

# This is the 5th commit message:

fix reference_count_pass,test=develop (PaddlePaddle#17060)

test=develop
# This is the 6th commit message:

Speedup roi_perspective_transform op by caching the information of linear interpolation in forward (PaddlePaddle#17090)

* Cache the information of linear interpolation in forward and use it in backward.
test=develop

* Fix cuda kernel.
test=develop

# This is the 7th commit message:

remove unnecessary prepare_data (PaddlePaddle#17080)

test=develop
# This is the 8th commit message:

fix interpolate cu. test=develop (PaddlePaddle#17101)

# This is the 9th commit message:

test=develop, double backward leaky_relu (PaddlePaddle#17067)

backward of backward: leaky_relu
# This is the 10th commit message:

fix fuse optimizer ops (PaddlePaddle#17102)

test=develop
# This is the 11th commit message:

truncated_gaussian_random supported in distributed training, test=develop (PaddlePaddle#17091)

# This is the 12th commit message:

 Detailed coordinate description for yolov3 loss (PaddlePaddle#17007)

* Detailed coordinate description for yolov3 loss

test=develop

* modified api.spec

test=develop

* modified loss name

* fix api.spec

test=develop

* polish description

test=develop

* modified api.spec

test=develop

# This is the 13th commit message:

fix test_weight_decay (PaddlePaddle#17109)

test=develop
# This is the 14th commit message:

Path flag (PaddlePaddle#17105)

* fix python/paddle/fluid/__init__.py detecting problems
sneaxiy added a commit that referenced this pull request Apr 28, 2019
* refine_dropout_mem,test=develop

* # This is a combination of 14 commits.

3 participants