7 changes: 6 additions & 1 deletion paddle/fluid/pybind/tensor.cc
@@ -1077,7 +1077,12 @@ void BindTensor(pybind11::module &m) { // NOLINT
[](DistTensor &self, const DistTensor &src) {
self.unsafe_set_dims(src.dims());
self.unsafe_set_dist_attr(src.dist_attr());
self.unsafe_mutable_value()->ShareDataWith(src.value());
if (!IsCurRankInMesh(self.process_mesh()) &&
!IsCurRankInMesh(src.dist_attr().process_mesh())) {
self.unsafe_mutable_value()->ShareDataNoCheckWith(src.value());
} else {
self.unsafe_mutable_value()->ShareDataWith(src.value());
}
return self;
})
.def("_clear", &DistTensor::clear);
65 changes: 33 additions & 32 deletions python/paddle/distributed/auto_parallel/api.py
@@ -657,7 +657,9 @@ def __init__(self, optimizer, shard_fn=None):
self._shard_fn._shard_parameter(param)

def _set_and_check_sharding_prop_from_param(self):
if len(self._shard_fn._mesh._shape) == 1:
if (self._shard_fn._mesh is not None) and (
len(self._shard_fn._mesh._shape) == 1
):
self._sharding_degree = self._shard_fn._mesh.get_dim_size(0)
self._sharding_mesh_axis = 0
else:
@@ -684,16 +686,12 @@ def _set_and_check_sharding_prop_from_param(self):
assert isinstance(
placements[self._sharding_mesh_axis], dist.Replicate
), "The placement on sharding_mesh_axis should be Replicate"

# check the sharding degree since it has already been set
if any(
isinstance(placement, dist.Shard)
for placement in placements
):
for idx, placement in enumerate(placements):
if isinstance(placement, dist.Replicate):
assert (
mesh.dim_size(idx) == self._sharding_degree
), "The sharding degree of all parameters must be equal currently."
assert (
mesh.dim_size(self._sharding_mesh_axis)
== self._sharding_degree
), "The sharding degree of all parameters must be equal currently."

assert (
self._sharding_degree is not None
@@ -889,7 +887,7 @@ class ShardingStage1(_ShardingStageBase):
A builtin shard_fn for shard_optimizer interface, users can pass it to shard_optimizer to implement sharding optimization with stage 1.

Args:
mesh(paddle.distributed.ProcessMesh): The `ProcessMesh` object describes the Cartesian topology of the used processes.
mesh(None|paddle.distributed.ProcessMesh): If mesh is not None, the `ProcessMesh` object describes the Cartesian topology of the processes used for dense-type parameters. Note: currently only one mesh configuration is supported for all dense parameters; if multiple mesh configurations are needed, configure them yourself in the upper-layer networking code.
Contributor:
Please also describe the behavior when the default value is None.

Examples:
.. code-block:: python
@@ -922,7 +920,7 @@ class ShardingStage1(_ShardingStageBase):
>>> # python -m paddle.distributed.launch --gpus=0,1 {test_case}.py
"""

def __init__(self, mesh):
def __init__(self, mesh=None):
super().__init__(mesh)
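A minimal usage sketch for the new default, assuming a two-rank launch and the toy Linear/AdamW setup below (an illustration, not the docstring's own example): with mesh=None the built-in shard_fn no longer re-places dense parameters on a mesh, so parameters should already be distributed, e.g. via dist.shard_layer.

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])

# Replicate the layer's parameters on the mesh up front.
layer = dist.shard_layer(paddle.nn.Linear(8, 8), mesh)

opt = paddle.optimizer.AdamW(parameters=layer.parameters())
opt = dist.shard_optimizer(opt, dist.ShardingStage1())  # mesh defaults to None

batch = paddle.rand(shape=[8, 8])
loss = layer(batch).mean()
loss.backward()
opt.step()
opt.clear_grad()
# launch with: python -m paddle.distributed.launch --gpus=0,1 demo.py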

def __call__(self, key, param, accumulator):
Expand Down Expand Up @@ -950,7 +948,7 @@ class ShardingStage2(_ShardingStageBase):
A builtin shard_fn for shard_optimizer interface, users can pass it to shard_optimizer to implement sharding optimization with stage 2.

Args:
mesh(paddle.distributed.ProcessMesh): The `ProcessMesh` object describes the Cartesian topology of the used processes.
mesh(None|paddle.distributed.ProcessMesh): If mesh is not None, the `ProcessMesh` object describes the Cartesian topology of the processes used for dense-type parameters. Note: currently only one mesh configuration is supported for all dense parameters; if multiple mesh configurations are needed, configure them yourself in the upper-layer networking code.

Examples:
.. code-block:: python
@@ -983,7 +981,7 @@ class ShardingStage2(_ShardingStageBase):
>>> # python -m paddle.distributed.launch --gpus=0,1 {test_case}.py
"""

def __init__(self, mesh):
def __init__(self, mesh=None):
super().__init__(mesh)

def __call__(self, key, param, accumulator):
Expand Down Expand Up @@ -1022,21 +1020,21 @@ def _grad_hook(grad):
return grad

def _register_hook_for_param_grad(self, param):
if param.is_dense():
if param.is_dense() and self._mesh is not None:
placements = []
for _ in range(len(self._mesh.shape)):
placements.append(dist.Replicate())
param._to_dist_(placements, self._mesh)

param.register_hook(ShardingStage2._grad_hook)
if param.is_dist():
param.register_hook(ShardingStage2._grad_hook)
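For reference, a standalone sketch of what the dense branch above does when a mesh is supplied (illustrative only; `_to_dist_` is the in-place conversion already used in this file, and the script is meant to run under paddle.distributed.launch):

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])
param = paddle.create_parameter(shape=[8, 8], dtype="float32")

if param.is_dense() and mesh is not None:
    # Re-place the dense parameter as fully replicated on the mesh, so it
    # becomes a dist tensor before the stage-2 gradient hook is attached.
    placements = [dist.Replicate() for _ in range(len(mesh.shape))]
    param._to_dist_(placements, mesh)

print(param.is_dist())  # True after the conversion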


class ShardingStage3(_ShardingStageBase):
"""
A builtin shard_fn for shard_optimizer interface, users can pass it to shard_optimizer to implement sharding optimization with stage 3.

Args:
mesh(paddle.distributed.ProcessMesh): The `ProcessMesh` object describes the Cartesian topology of the used processes.
mesh(None|paddle.distributed.ProcessMesh): If mesh is not None, the `ProcessMesh` object describes the Cartesian topology of the processes used for dense-type parameters. Note: currently only one mesh configuration is supported for all dense parameters; if multiple mesh configurations are needed, configure them yourself in the upper-layer networking code.

Examples:
.. code-block:: python
@@ -1069,30 +1067,33 @@ class ShardingStage3(_ShardingStageBase):
>>> # python -m paddle.distributed.launch --gpus=0,1 {test_case}.py
"""

def __init__(self, mesh):
def __init__(self, mesh=None):
super().__init__(mesh)

def _shard_parameter(self, param):
if param.is_dense():
if param.is_dense() and self._mesh is not None:
Contributor:
else: raise error

Contributor Author:
If mesh is not specified, are dense params not allowed? In the llama2-13B model we are currently testing under the unified dynamic-static graph setup, there are indeed cases where a param is dense.
placements = []
for _ in range(len(self._mesh.shape)):
placements.append(dist.Replicate())
param._to_dist_(placements, self._mesh)

new_placements = get_placement_with_sharding(
param, self._sharding_mesh_axis
)
shard_param = dist.reshard(param, param.process_mesh, new_placements)
# change the holder of param to new shard_param
param.get_tensor()._share_data_with(shard_param.get_tensor())
if param.is_dist():
new_placements = get_placement_with_sharding(
param, self._sharding_mesh_axis
)
shard_param = dist.reshard(
param, param.process_mesh, new_placements
)
# change the holder of param to new shard_param
param.get_tensor()._share_data_with(shard_param.get_tensor())

def _unshard_parameter(self, param):
new_placements = param.placements
if isinstance(new_placements[self._sharding_mesh_axis], dist.Shard):
new_placements[self._sharding_mesh_axis] = dist.Replicate()
if param.is_dist():
new_placements = param.placements
if isinstance(new_placements[self._sharding_mesh_axis], dist.Shard):
new_placements[self._sharding_mesh_axis] = dist.Replicate()

new_param = dist.reshard(param, param.process_mesh, new_placements)
param.get_tensor()._share_data_with(new_param.get_tensor())
new_param = dist.reshard(param, param.process_mesh, new_placements)
param.get_tensor()._share_data_with(new_param.get_tensor())
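The two methods above toggle the placement on the sharding mesh axis. A standalone sketch of that placement arithmetic, with illustrative helper names (a simplification of get_placement_with_sharding, which also picks a tensor dim that is not yet sharded):

import paddle.distributed as dist

def with_sharding(placements, sharding_mesh_axis, shard_dim=0):
    # _shard_parameter direction: shard the parameter along `shard_dim`
    # on the sharding mesh axis.
    out = list(placements)
    out[sharding_mesh_axis] = dist.Shard(shard_dim)
    return out

def without_sharding(placements, sharding_mesh_axis):
    # _unshard_parameter direction: gather back to a full replica.
    out = list(placements)
    if isinstance(out[sharding_mesh_axis], dist.Shard):
        out[sharding_mesh_axis] = dist.Replicate()
    return out

p = [dist.Replicate(), dist.Replicate()]
sharded = with_sharding(p, sharding_mesh_axis=0)            # [Shard(0), Replicate()]
restored = without_sharding(sharded, sharding_mesh_axis=0)  # [Replicate(), Replicate()]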

def __call__(self, key, param, accumulator):
if param.is_dist():