-
Notifications
You must be signed in to change notification settings - Fork 6k
[Auto-parallel] Fix sharding all_gather overlap in auto_dy #73717
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
fe3d757
24d448e
02f206f
ea1ec63
a00a2c5
6bb3f44
c06c077
13c3d8f
58ded16
93d2dc6
e8e50da
3c9457f
74a19bf
bcd98bd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1486,21 +1486,36 @@ def get_mesh(pp_idx=0): | |
| # overlap all_gather with optimizer op | ||
| if hasattr(param_and_grad[0], 'last_idx'): | ||
| idx = param_and_grad[0].last_idx | ||
| shard_size = ( | ||
| self.param_storage[idx]._numel() | ||
| // self._sharding_group.nranks | ||
| ) | ||
| begin = shard_size * max(self._sharding_group.rank, 0) | ||
| end = begin + shard_size | ||
| slice_buffer = paddle._C_ops.view_slice( | ||
| self.param_storage[idx], begin, end | ||
| ) | ||
| task = paddle.distributed.all_gather( | ||
| self.param_storage[idx], | ||
| slice_buffer, | ||
| group=self._sharding_group, | ||
| sync_op=False, | ||
| ) | ||
| if param_and_grad[0].last_idx == 0: | ||
| shard_size = ( | ||
| self.param_storage[idx]._numel() | ||
| // self._sharding_group.nranks | ||
| ) | ||
| begin = shard_size * max(self._sharding_group.rank, 0) | ||
| end = begin + shard_size | ||
| slice_buffer = paddle._C_ops.view_slice( | ||
| self.param_storage[idx], begin, end | ||
| ) | ||
| task = paddle.distributed.all_gather( | ||
| self.param_storage[idx], | ||
| slice_buffer, | ||
| group=self._sharding_group, | ||
| sync_op=False, | ||
| ) | ||
| self.param_storage[idx].is_sync = True | ||
| else: | ||
| self.param_storage[idx].is_sync = False | ||
|
|
||
| def _enable_tensor_fusion(self): | ||
| self.enable_tensor_fusion = True | ||
|
|
||
| def _enable_sharding_overlap(self, layers): | ||
| self.enable_sharding_overlap = True | ||
| if not isinstance(layers, paddle.nn.Layer): | ||
| raise RuntimeError( | ||
| f"`layers` must be `paddle.nn.Layer` but got {type(layers)}" | ||
| ) | ||
| self._layers = layers | ||
|
|
||
| def _reduce_scatter_gradients(self, grad_storage): | ||
| shard_size = grad_storage._numel() // self._sharding_group.nranks | ||
|
|
@@ -1516,6 +1531,16 @@ def _reduce_scatter_gradients(self, grad_storage): | |
| ).wait() | ||
|
|
||
| def _async_reduce_scatter(self): | ||
|
||
| if not self._layers: | ||
| raise RuntimeError( | ||
| "Sharding overlap requires an initialized model. " | ||
| "Call `_enable_sharding_overlap()` to set model." | ||
| ) | ||
| param2layer = {} | ||
| for layer in self._layers.sublayers(): | ||
| for p in layer.parameters(include_sublayers=False): | ||
| param2layer[id(p)] = layer | ||
|
|
||
| for i in range(len(self.fuse_param_view)): | ||
| self._reduce_scatter_gradients(self.grad_storage[i]) | ||
|
|
||
|
|
@@ -1541,6 +1566,26 @@ def fuse_comm(*_): | |
|
|
||
| return fuse_comm | ||
|
|
||
| def fuse_all_gather_hook_func(param_storage, comm_group): | ||
| @paddle.autograd.no_grad() | ||
| def fuse_comm(*_): | ||
| if not param_storage.is_sync: | ||
| shard_size = param_storage._numel() // comm_group.nranks | ||
| begin = shard_size * max(comm_group.rank, 0) | ||
| end = begin + shard_size | ||
| slice_buffer = paddle._C_ops.view_slice( | ||
| param_storage, begin, end | ||
| ) | ||
| task = paddle.distributed.all_gather( | ||
| param_storage, | ||
| slice_buffer, | ||
| group=comm_group, | ||
| sync_op=False, | ||
| ) | ||
| param_storage.is_sync = True | ||
|
|
||
| return fuse_comm | ||
|
|
||
| param_group_len = ( | ||
| len(self.fuse_param_view[i]) * self.gradient_accumulation_steps | ||
| ) | ||
|
|
@@ -1557,6 +1602,18 @@ def fuse_comm(*_): | |
| ) | ||
| ) | ||
|
|
||
| if i < len(self.fuse_param_view) - 1: | ||
| first_param = next(iter(self.fuse_param_view[i].values()))[ | ||
| 'param' | ||
| ] | ||
| layer = param2layer.get(id(first_param)) | ||
| layer.register_forward_pre_hook( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改为用局部变量 |
||
| fuse_all_gather_hook_func( | ||
| self.param_storage[i + 1], | ||
| self._sharding_group, | ||
| ) | ||
| ) | ||
|
|
||
| def _build_fuse_param_view( | ||
| self, | ||
| params_and_grads, | ||
|
|
@@ -1579,6 +1636,7 @@ def get_padded_size(param): | |
| param_buffer = paddle.zeros( | ||
| shape=[total_buffer_size], dtype=params_and_grads[0][0].dtype | ||
| ) | ||
| param_buffer.is_sync = False | ||
| grad_dtype = paddle.float32 | ||
| grad_buffer = paddle.zeros(shape=[total_buffer_size], dtype=grad_dtype) | ||
| grad_buffer.check_in = 0 | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
1、后续要用到 self._layers 做参数查找和注册 hook，这里需要对 layers 参数做检查，比如，类型是 paddle.nn.Layer；
2、这个函数本身就是 enable_sharding_overlap 为 True 时才会调用吧，是否有必要再传这个参数？
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
1和2均已做修改,感谢!