-
Notifications
You must be signed in to change notification settings - Fork 6k
[Hybrid Parallel] Add Topology for hybrid communicate #32011
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,176 @@ | ||
| # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import paddle | ||
| import collections | ||
| import numpy as np | ||
| from itertools import product | ||
| from functools import reduce | ||
| __all__ = ['CommunicateTopology', 'HybridCommunicateGroup'] | ||
|
|
||
|
|
||
| class CommunicateTopology(object): | ||
| def __init__(self, hybrid_group_names, dims): | ||
| self._parallel_names = hybrid_group_names | ||
| self._dims = dims | ||
| self.coordinate = collections.namedtuple('Coordinate', | ||
| self._parallel_names) | ||
| self._world_size = reduce(lambda x, y: x * y, self._dims) | ||
|
|
||
| ranges = [range(d) for d in self._dims] | ||
| all_coordinate = [self.coordinate(*x) for x in product(*ranges)] | ||
|
|
||
| self._coord2rank = dict(zip(all_coordinate, range(len(all_coordinate)))) | ||
| self._rank2coord = dict( | ||
| zip(self._coord2rank.values(), self._coord2rank.keys())) | ||
|
|
||
| def get_hybrid_group_names(self): | ||
| return self._parallel_names | ||
|
|
||
| def get_dim(self, axis_name): | ||
| return self._dims[self._parallel_names.index(axis_name)] | ||
|
|
||
| def world_size(self): | ||
| return self._world_size | ||
|
|
||
| def get_rank(self, **args): | ||
| assert len(args) == len(self._dims) | ||
| key = self.coordinate(**args) | ||
| assert key in self._coord2rank.keys() | ||
| return self._coord2rank[key] | ||
|
|
||
| def get_coord(self, rank): | ||
| assert rank < self._world_size | ||
| assert rank in self._rank2coord.keys() | ||
| return self._rank2coord[rank] | ||
|
|
||
| def get_axis_list(self, axis_name, index): | ||
| axis = self._parallel_names.index(axis_name) | ||
| ranks = [ | ||
| self._coord2rank[coord] for coord in self._coord2rank.keys() | ||
| if coord[axis] == index | ||
| ] | ||
| ranks.sort() | ||
| return ranks | ||
|
|
||
| def get_dim_size(self, axis_name): | ||
| assert axis_name in self._parallel_names | ||
| return self._dims[self._parallel_names.index(axis_name)] | ||
|
|
||
| def get_comm_list(self, axis_name): | ||
| assert axis_name in self._parallel_names | ||
| other_axis_names = [ | ||
| name for name in self._parallel_names if name != axis_name | ||
| ] | ||
|
|
||
| ranges = [] | ||
| for name in other_axis_names: | ||
| dim_num = self.get_dim_size(name) | ||
| ranges.append(range(dim_num)) | ||
|
|
||
| all_result = [] | ||
| for x in product(*ranges): | ||
| key_coord = {} | ||
| for other_name in other_axis_names: | ||
| key_coord[other_name] = x[other_axis_names.index(other_name)] | ||
|
|
||
| result = [] | ||
| for i in range(0, self.get_dim_size(axis_name)): | ||
| key_coord[axis_name] = i | ||
| result.append(self._coord2rank[self.coordinate(**key_coord)]) | ||
| all_result.append(result) | ||
|
|
||
| return all_result | ||
|
|
||
|
|
||
| class HybridCommunicateGroup(object): | ||
| def __init__(self, topology): | ||
| self.nranks = paddle.distributed.get_world_size() | ||
| self.global_rank = paddle.distributed.get_rank() | ||
| self._topo = topology | ||
|
|
||
| self._num_data_parallel = self._topo.get_dim('data') | ||
| self._num_model_parallel = self._topo.get_dim('model') | ||
| self._num_pipe_parallel = self._topo.get_dim('pipe') | ||
|
|
||
| self._data_parallel_id = self._get_data_parallel_id() | ||
| self._model_parallel_id = self._get_model_parallel_id() | ||
|
|
||
| assert self._check_vaild_topo( | ||
| ), "Here is an unreasonable topogy setting" | ||
|
|
||
| # create comm group for data parallel | ||
| self._dp_group, self._dp_comm_group = self._set_comm_group("data") | ||
| print("data parallel group", self._dp_group) | ||
|
|
||
| # create comm group for model parallel | ||
| self._mp_group, self._mp_comm_group = self._set_comm_group("model") | ||
| print("model parallel group", self._mp_group) | ||
|
|
||
| def _check_vaild_topo(self): | ||
| return self._num_data_parallel * self._num_model_parallel * self._num_pipe_parallel == self.nranks | ||
|
|
||
| def _set_comm_group(self, parallel_method="data"): | ||
| parallel_group = [] | ||
| parallel_comm_group = None | ||
| parallel_groups = self._topo.get_comm_list(parallel_method) | ||
|
|
||
| for group in parallel_groups: | ||
| comm_group = paddle.distributed.new_group(ranks=group) | ||
| if self.global_rank in group: | ||
| parallel_group = group | ||
| parallel_comm_group = comm_group | ||
|
|
||
| assert len(parallel_group) > 0 | ||
| assert parallel_comm_group is not None | ||
|
|
||
| return parallel_group, parallel_comm_group | ||
|
|
||
| def topology(self): | ||
| return self._topo | ||
|
|
||
| def get_global_rank(self): | ||
| return self.global_rank | ||
|
|
||
| # data parallel message: | ||
| def _get_data_parallel_id(self): | ||
| return self._topo.get_coord(self.global_rank).data | ||
|
|
||
| def get_data_parallel_rank(self): | ||
| return self._data_parallel_id | ||
|
|
||
| def get_data_parallel_world_size(self): | ||
| return self._num_data_parallel | ||
|
|
||
| def get_data_parallel_group(self): | ||
| return self._dp_comm_group | ||
|
|
||
| def get_data_parallel_group_src_rank(self): | ||
| return self._dp_comm_group.ranks[0] | ||
|
|
||
| # model parallel message: | ||
| def _get_model_parallel_id(self): | ||
| return self._topo.get_coord(self.global_rank).model | ||
|
|
||
| def get_model_parallel_rank(self): | ||
| return self._model_parallel_id | ||
|
|
||
| def get_model_parallel_world_size(self): | ||
| return self._num_model_parallel | ||
|
|
||
| def get_model_parallel_group(self): | ||
| return self._mp_comm_group | ||
|
|
||
| def get_model_parallel_group_src_rank(self): | ||
| return self._mp_comm_group.ranks[0] | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
101 changes: 101 additions & 0 deletions
101
python/paddle/fluid/tests/unittests/hybrid_communicate_group.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,101 @@ | ||
| # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import numpy as np | ||
| import os | ||
| import paddle | ||
| from paddle.distributed import fleet | ||
|
|
||
|
|
||
| class TestNewGroupAPI(object): | ||
| def __init__(self): | ||
| paddle.distributed.init_parallel_env() | ||
| topo = fleet.CommunicateTopology(["data", "model", "pipe"], [2, 1, 1]) | ||
| self.hcg = fleet.HybridCommunicateGroup(topo) | ||
|
|
||
| d1 = np.array([1, 2, 3]) | ||
| d2 = np.array([2, 3, 4]) | ||
| self.tensor1 = paddle.to_tensor(d1) | ||
| self.tensor2 = paddle.to_tensor(d2) | ||
|
|
||
| def test_all(self): | ||
| topo = self.hcg.topology() | ||
| global_rank = self.hcg.get_data_parallel_rank() | ||
|
|
||
| dp_rank = self.hcg.get_data_parallel_rank() | ||
| dp_gp = self.hcg.get_data_parallel_group() | ||
| dp_world_size = self.hcg.get_data_parallel_world_size() | ||
| dp_src_rank = self.hcg.get_data_parallel_group_src_rank() | ||
| np.testing.assert_array_equal(dp_world_size, 2) | ||
| np.testing.assert_array_equal(dp_src_rank, 0) | ||
|
|
||
| mp_rank = self.hcg.get_model_parallel_rank() | ||
| mp_gp = self.hcg.get_model_parallel_group() | ||
| mp_world_size = self.hcg.get_model_parallel_world_size() | ||
| mp_src_rank = self.hcg.get_model_parallel_group_src_rank() | ||
| np.testing.assert_array_equal(mp_world_size, 1) | ||
|
|
||
| tmp = np.array([0, 0, 0]) | ||
| result = paddle.to_tensor(tmp) | ||
| paddle.distributed.scatter( | ||
| result, [self.tensor2, self.tensor1], | ||
| src=dp_src_rank, | ||
| group=dp_gp, | ||
| use_calc_stream=True) | ||
| if dp_rank == 0: | ||
| assert np.array_equal(result, self.tensor2) | ||
| elif dp_rank == 1: | ||
| assert np.array_equal(result, self.tensor1) | ||
| print("test scatter api ok") | ||
|
|
||
| paddle.distributed.broadcast( | ||
| result, src=1, group=dp_gp, use_calc_stream=True) | ||
| assert np.array_equal(result, self.tensor1) | ||
| print("test broadcast api ok") | ||
|
|
||
| paddle.distributed.reduce( | ||
| result, dst=dp_src_rank, group=dp_gp, use_calc_stream=True) | ||
| if dp_rank == 0: | ||
| assert np.array_equal(result, | ||
| paddle.add(self.tensor1, self.tensor1)) | ||
| elif dp_rank == 1: | ||
| assert np.array_equal(result, self.tensor1) | ||
| print("test reduce api ok") | ||
|
|
||
| paddle.distributed.all_reduce(result, use_calc_stream=True) | ||
| assert np.array_equal( | ||
| result, | ||
| paddle.add(paddle.add(self.tensor1, self.tensor1), self.tensor1)) | ||
| print("test all_reduce api ok") | ||
|
|
||
| paddle.distributed.wait(result, dp_gp, use_calc_stream=True) | ||
| paddle.distributed.wait(result, dp_gp, use_calc_stream=False) | ||
| print("test wait api ok") | ||
|
|
||
| result = [] | ||
| paddle.distributed.all_gather( | ||
| result, self.tensor1, group=dp_gp, use_calc_stream=True) | ||
| assert np.array_equal(result[0], self.tensor1) | ||
| assert np.array_equal(result[1], self.tensor1) | ||
| print("test all_gather api ok") | ||
|
|
||
| paddle.distributed.barrier(group=dp_gp) | ||
| print("test barrier api ok") | ||
|
|
||
| return | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| gpt = TestNewGroupAPI() | ||
| gpt.test_all() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a comment to describe what it is?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done