[Auto-parallel] Fix sharding all_gather overlap in auto_dy #73717
Xing-lil merged 14 commits into PaddlePaddle:develop
Conversation
```python
def fuse_all_gather_hook_func(param_storage, comm_group):
    @paddle.autograd.no_grad()
    def fuse_comm(*_):
        shard_size = param_storage._numel() // comm_group.nranks
```
What happens here if `param_storage._numel()` is not evenly divisible?
The `get_padded_size` call in `_build_fuse_param_view` guarantees that `param_storage._numel()` is an integer multiple of `comm_group.nranks`, so this case cannot occur.
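A minimal sketch of the padding invariant being described; the name `get_padded_size` comes from this thread, but the rounding rule shown is an assumption:

```python
def get_padded_size(numel: int, nranks: int) -> int:
    # Round numel up to the next multiple of nranks so the fused
    # buffer always splits into equal per-rank shards.
    return ((numel + nranks - 1) // nranks) * nranks

# e.g. 10 elements across 4 ranks are padded to 12, so shard_size == 3
assert get_padded_size(10, 4) == 12
```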
```python
task = paddle.distributed.all_gather(
    param_storage,
    slice_buffer,
    group=self._sharding_group,
```
Why is `comm_group` passed in but `self._sharding_group` actually used?
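If the intent is to use the group captured by the hook factory, the call would presumably read as below (a sketch of the reviewer's reading, not a confirmed fix; `sync_op=False` is an assumption for async overlap):

```python
task = paddle.distributed.all_gather(
    param_storage,
    slice_buffer,
    group=comm_group,  # the group passed into fuse_all_gather_hook_func
    sync_op=False,     # assumed async so the gather overlaps with compute
)
```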
```python
def _set_sharding_overlap(self, enable_sharding_overlap, layers):
    self.enable_sharding_overlap = enable_sharding_overlap
    self._layers = layers
```
1. `self._layers` is later used for parameter lookup and hook registration, so the `layers` argument should be validated here, e.g. checked to be a `paddle.nn.Layer` (see the sketch below).
2. This function is only called when `enable_sharding_overlap` is True in the first place, so is there any need to pass that flag as a parameter?
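A minimal sketch of the suggested validation, folding in both points above (the flag-free signature is the reviewer's proposal, not the merged code):

```python
import paddle

def _set_sharding_overlap(self, layers):
    # Validate early: later code walks layers.sublayers() and registers
    # forward pre-hooks, both of which require a paddle.nn.Layer.
    if not isinstance(layers, paddle.nn.Layer):
        raise TypeError(
            f"layers must be a paddle.nn.Layer, got {type(layers).__name__}"
        )
    # Calling this method already implies the feature is on,
    # so no enable_sharding_overlap parameter is needed.
    self.enable_sharding_overlap = True
    self._layers = layers
```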
```python
        'param'
    ]
    layer = _find_layer_containing_param(first_param)
    layer.register_forward_pre_hook(
```
- Each call to `_find_layer_containing_param` traverses all sublayers; consider caching the param-to-layer mapping (`param2layer`).
- Handle the case where `layer` is None.
Updated to cache the mapping in a local `param2layer = {}` dict; an error is already raised when `self._layers` is None.
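A minimal sketch of that caching approach (the loop structure and error message are assumptions based on this thread, not the merged code):

```python
if self._layers is None:
    raise ValueError("self._layers is None; call _set_sharding_overlap first")

# Build the param-name -> layer cache once instead of re-walking
# all sublayers on every lookup.
param2layer = {}
for layer in self._layers.sublayers():
    for p in layer.parameters(include_sublayers=False):
        param2layer[p.name] = layer

layer = param2layer.get(first_param.name)
```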
```python
def _set_tensor_fusion(self, enable_tensor_fusion):
    self.enable_tensor_fusion = enable_tensor_fusion
```
This function is only ever called when `enable_tensor_fusion` is True, so there is no need to pass `enable_tensor_fusion` as a parameter. Suggestion:

```python
def _enable_tensor_fusion(self):
    self.enable_tensor_fusion = True
```
```python
    )
    for layer in self._layers.sublayers():
        for p in layer.parameters(include_sublayers=False):
            if param.name == p.name:
```
Is matching by `name` the only option here? Could the parameter name be modified by the user?
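One identity-based alternative the question points at (a sketch; presenting `_find_layer_containing_param` as a method here is an assumption):

```python
def _find_layer_containing_param(self, param):
    # Compare by object identity rather than by the user-mutable name.
    for layer in self._layers.sublayers():
        for p in layer.parameters(include_sublayers=False):
            if p is param:
                return layer
    return None
```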
```
@@ -1516,6 +1531,16 @@ def _reduce_scatter_gradients(self, grad_storage):
            ).wait()

    def _async_reduce_scatter(self):
```
Codecov Report

Attention: Patch coverage is 55.81% with 19 lines in your changes missing coverage.

Additional details and impacted files

```
@@            Coverage Diff             @@
##             develop   #73717   +/-   ##
==========================================
  Coverage           ?   55.81%
==========================================
  Files              ?        1
  Lines              ?       43
  Branches           ?        0
==========================================
  Hits               ?       24
  Misses             ?       19
  Partials           ?        0
```

View full report in Codecov by Sentry.
PR Category
Auto Parallel
PR Types
Bug fixes
Description
Launching all `all_gather` ops at once blocks overlap with other sync/comm ops. Fix: prefetch 1 buffer ahead via a hook to enable overlap (see the sketch below).
Ref: Same fix in dynamic_hand #73406
Pcard-70448
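A minimal sketch of the prefetch-one-ahead pattern described above; `register_prefetch_hooks`, the `buffers` layout, and the hook bodies are illustrative assumptions, not the merged code:

```python
import paddle

def register_prefetch_hooks(layers, buffers, comm_group):
    # buffers[i] is the (param_storage, slice_buffer) pair needed by
    # layers[i]; buffers[0] is assumed to be gathered eagerly before the
    # first forward. Before layers[i] runs, launch the async all_gather
    # for buffers[i + 1] so communication overlaps with layers[i]'s compute.
    tasks = {}

    def make_hook(next_idx):
        @paddle.autograd.no_grad()
        def prefetch(layer, inputs):
            if next_idx < len(buffers):
                param_storage, slice_buffer = buffers[next_idx]
                tasks[next_idx] = paddle.distributed.all_gather(
                    param_storage,
                    slice_buffer,
                    group=comm_group,
                    sync_op=False,
                )
        return prefetch

    for i, layer in enumerate(layers):
        layer.register_forward_pre_hook(make_hook(i + 1))
    return tasks
```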