Skip to content
88 changes: 73 additions & 15 deletions python/paddle/distributed/auto_parallel/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1486,21 +1486,36 @@ def get_mesh(pp_idx=0):
# overlap all_gather with optimizer op
if hasattr(param_and_grad[0], 'last_idx'):
idx = param_and_grad[0].last_idx
shard_size = (
self.param_storage[idx]._numel()
// self._sharding_group.nranks
)
begin = shard_size * max(self._sharding_group.rank, 0)
end = begin + shard_size
slice_buffer = paddle._C_ops.view_slice(
self.param_storage[idx], begin, end
)
task = paddle.distributed.all_gather(
self.param_storage[idx],
slice_buffer,
group=self._sharding_group,
sync_op=False,
)
if param_and_grad[0].last_idx == 0:
shard_size = (
self.param_storage[idx]._numel()
// self._sharding_group.nranks
)
begin = shard_size * max(self._sharding_group.rank, 0)
end = begin + shard_size
slice_buffer = paddle._C_ops.view_slice(
self.param_storage[idx], begin, end
)
task = paddle.distributed.all_gather(
self.param_storage[idx],
slice_buffer,
group=self._sharding_group,
sync_op=False,
)
self.param_storage[idx].is_sync = True
else:
self.param_storage[idx].is_sync = False

def _enable_tensor_fusion(self):
self.enable_tensor_fusion = True

def _enable_sharding_overlap(self, layers):
    """Enable overlapping sharding communication with computation.

    Args:
        layers (paddle.nn.Layer): the model whose sub-layers are later
            scanned to build the param->layer map and to register
            forward-pre hooks for async all_gather.

    Raises:
        RuntimeError: if ``layers`` is not a ``paddle.nn.Layer``.
    """
    # Validate BEFORE mutating state: the original set
    # ``enable_sharding_overlap = True`` first, so a failed type check
    # raised but still left the overlap flag enabled on this instance.
    if not isinstance(layers, paddle.nn.Layer):
        raise RuntimeError(
            f"`layers` must be `paddle.nn.Layer` but got {type(layers)}"
        )
    self.enable_sharding_overlap = True
    self._layers = layers
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1、 后续要用到 self._layers 做参数查找和注册 hook,这里需要对 layers 参数做检查,比如,类型是 paddle.nn.Layer
2、这个函数本身就是 enable_sharding_overlap 为 True 时才会调用吧,是有有必要再传这个参数?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1和2均已做修改,感谢!


def _reduce_scatter_gradients(self, grad_storage):
shard_size = grad_storage._numel() // self._sharding_group.nranks
Expand All @@ -1516,6 +1531,16 @@ def _reduce_scatter_gradients(self, grad_storage):
).wait()

def _async_reduce_scatter(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如线下沟通,还有以下问题:

  1. 函数命名
  2. 增加注释

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已做相应修改,感谢!

if not self._layers:
raise RuntimeError(
"Sharding overlap requires an initialized model. "
"Call `_enable_sharding_overlap()` to set model."
)
param2layer = {}
for layer in self._layers.sublayers():
for p in layer.parameters(include_sublayers=False):
param2layer[id(p)] = layer

for i in range(len(self.fuse_param_view)):
self._reduce_scatter_gradients(self.grad_storage[i])

Expand All @@ -1541,6 +1566,26 @@ def fuse_comm(*_):

return fuse_comm

# Factory for a forward-pre hook that launches an async all_gather of the
# next fused parameter buffer, overlapping communication with compute.
def fuse_all_gather_hook_func(param_storage, comm_group):
    @paddle.autograd.no_grad()
    def fuse_comm(*_):
        # Launch the collective at most once while the buffer is marked
        # out-of-sync; ``is_sync`` is cleared elsewhere when the sharded
        # slice is updated (e.g. after the optimizer step).
        if not param_storage.is_sync:
            # Each rank owns one contiguous shard of the fused buffer;
            # max(rank, 0) guards against a negative rank sentinel.
            shard_size = param_storage._numel() // comm_group.nranks
            begin = shard_size * max(comm_group.rank, 0)
            end = begin + shard_size
            # Zero-copy view of this rank's shard inside the fused buffer.
            slice_buffer = paddle._C_ops.view_slice(
                param_storage, begin, end
            )
            # Async all_gather into the full buffer. NOTE(review): the
            # returned ``task`` is never waited on here — presumably the
            # consumer synchronizes before reading; confirm.
            task = paddle.distributed.all_gather(
                param_storage,
                slice_buffer,
                group=comm_group,
                sync_op=False,
            )
            param_storage.is_sync = True

    return fuse_comm

param_group_len = (
len(self.fuse_param_view[i]) * self.gradient_accumulation_steps
)
Expand All @@ -1557,6 +1602,18 @@ def fuse_comm(*_):
)
)

if i < len(self.fuse_param_view) - 1:
first_param = next(iter(self.fuse_param_view[i].values()))[
'param'
]
layer = param2layer.get(id(first_param))
layer.register_forward_pre_hook(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. 这里每次调用 _find_layer_containing_param 都会遍历所有子layer,建议缓存 param2layer 的关系
  2. 考虑 layer 为 None 的情况

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改为用局部变量 param2layer = {} 缓存,已有 self._layers 为 None 时的报错提醒。

fuse_all_gather_hook_func(
self.param_storage[i + 1],
self._sharding_group,
)
)

def _build_fuse_param_view(
self,
params_and_grads,
Expand All @@ -1579,6 +1636,7 @@ def get_padded_size(param):
param_buffer = paddle.zeros(
shape=[total_buffer_size], dtype=params_and_grads[0][0].dtype
)
param_buffer.is_sync = False
grad_dtype = paddle.float32
grad_buffer = paddle.zeros(shape=[total_buffer_size], dtype=grad_dtype)
grad_buffer.check_in = 0
Expand Down
Loading