Skip to content
1 change: 1 addition & 0 deletions paddle/fluid/framework/distributed_strategy.proto
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ message HybridConfig {
optional int32 dp_degree = 1 [ default = -1 ];
optional int32 mp_degree = 2 [ default = 1 ];
optional int32 pp_degree = 3 [ default = 1 ];
optional int32 sharding_degree = 4 [ default = 1 ];
}

message AMPConfig {
Expand Down
18 changes: 14 additions & 4 deletions python/paddle/distributed/fleet/base/fleet_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from . import topology as tp
from .topology import ParallelMode
from ..meta_parallel import TensorParallel, model_parallel_random_seed
from ..meta_parallel import PipelineParallel
from ..meta_parallel import PipelineParallel, ShardingParallel
from ..meta_optimizers import HybridParallelOptimizer
from ..meta_optimizers import HybridParallelGradScaler

Expand Down Expand Up @@ -295,9 +295,11 @@ def _init_hybrid_parallel_env(self):
self.dp_degree = self.hybrid_configs["dp_degree"]
self.mp_degree = self.hybrid_configs["mp_degree"]
self.pp_degree = self.hybrid_configs["pp_degree"]
self.sharding_degree = self.hybrid_configs["sharding_degree"]

assert self.mp_degree >= 0, "mp_degree should be greater or equal to 0"
assert self.pp_degree >= 0, "pp_degree should be greater or equal to 0"
assert self.sharding_degree >= 0, "sharding_degree should be greater or equal to 0"

self.mp_degree = max(self.mp_degree, 1)
self.pp_degree = max(self.pp_degree, 1)
Expand All @@ -309,8 +311,11 @@ def _init_hybrid_parallel_env(self):
self.dp_degree = max(self.dp_degree, 1)

self._topology = tp.CommunicateTopology(
hybrid_group_names=["data", "pipe", "model"],
dims=[self.dp_degree, self.pp_degree, self.mp_degree])
hybrid_group_names=["data", "pipe", "sharding", "model"],
dims=[
self.dp_degree, self.pp_degree, self.sharding_degree,
self.mp_degree
])

self._hcg = tp.HybridCommunicateGroup(self._topology)

Expand Down Expand Up @@ -886,7 +891,11 @@ def forward(self, x):
assert model is not None, "model should not be None"
if self.worker_num() <= 1:
return model
if self._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL:

if self._hcg.get_parallel_mode() == ParallelMode.SHARDING_PARALLEL:
distributed_model = ShardingParallel(
model, self._hcg, strategy=self._user_defined_strategy)
elif self._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL:
distributed_model = paddle.DataParallel(
model,
comm_buffer_size=self._user_defined_strategy.
Expand All @@ -901,6 +910,7 @@ def forward(self, x):
elif self._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL:
distributed_model = PipelineParallel(
model, self._hcg, strategy=self._user_defined_strategy)

return distributed_model

@dygraph_only
Expand Down
55 changes: 43 additions & 12 deletions python/paddle/distributed/fleet/base/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@ class ParallelMode(object):
DATA_PARALLEL = 0
TENSOR_PARALLEL = 1
PIPELINE_PARALLEL = 2
SHARDING_PARALLEL = 3


class CommunicateTopology(object):
def __init__(self,
hybrid_group_names=["data", "pipe", "model"],
dims=[1, 1, 1]):
hybrid_group_names=["data", "pipe", "sharding", "model"],
dims=[1, 1, 1, 1]):
self._parallel_names = hybrid_group_names
self._dims = dims
self.coordinate = collections.namedtuple('Coordinate',
Expand Down Expand Up @@ -122,15 +123,17 @@ def __init__(self, topology):
self._dp_degree = self._topo.get_dim('data')
self._mp_degree = self._topo.get_dim('model')
self._pp_degree = self._topo.get_dim('pipe')
self._sharding_degree = self._topo.get_dim('sharding')

self._data_parallel_id = self._get_data_parallel_id()
self._model_parallel_id = self._get_model_parallel_id()
self._sharding_parallel_id = self._get_sharding_parallel_id()
self.stage_id = self._get_pipe_parallel_id()

assert self._check_vaild_topo(
), "Here is an unreasonable topogy setting. world_size: {}, but" \
"dp_num: {}, mp_num: {}, pp_num: {}".format(self.nranks, self._dp_degree,
self._mp_degree, self._pp_degree)
"mp_num: {}, sharding_num: {}, pp_num: {}, dp_num: {}".format(self.nranks,
self._mp_degree, self._sharding_degree, self._pp_degree, self._dp_degree)

# create comm group for data parallel
self._dp_group, self._dp_comm_group = self._set_comm_group("data")
Expand All @@ -141,6 +144,10 @@ def __init__(self, topology):
# create comm group for pipe parallel
self._pp_group, self._pp_comm_group = self._set_comm_group("pipe")

# create comm group for sharding parallel
self._sharding_group, self._sharding_comm_group = self._set_comm_group(
"sharding")

# create global group for check inf_nan / clip global norm
self._check_group, self._check_comm_group = self._set_check_group(
"data")
Expand All @@ -149,19 +156,26 @@ def __init__(self, topology):
self.is_first_stage = (self.stage_id == 0)
self.is_last_stage = (self.stage_id == (self._pp_degree - 1))

debug_str = "HybridParallelInfo: rank_id: %d, dp_degree: %d, " \
"mp_degree: %d, pp_degree: %d" % (self.global_rank, self._dp_degree,
self._mp_degree,self._pp_degree)
debug_str += ", dp_group: %s, mp_group: %s, pp_group: %s, check/clip group: %s" % (
self._dp_group, self._mp_group, self._pp_group, self._check_group)
debug_str = "HybridParallelInfo: rank_id: %d, mp_degree: %d, " \
"sharding_degree: %d, pp_degree: %d, dp_degree: %d" % (self.global_rank, self._mp_degree,
self._sharding_degree, self._pp_degree, self._dp_degree)
debug_str += ", mp_group: %s, sharding_group: %s, pp_group: %s, dp_group: %s, check/clip group: %s" % (
self._mp_group, self._sharding_group, self._pp_group,
self._dp_group, self._check_group)
logger.info(debug_str)

global _HYBRID_PARALLEL_GROUP
_HYBRID_PARALLEL_GROUP = self

def get_parallel_mode(self):
# there are three modes : DataParallel / TensorParallel / PipelineParallel
if self._mp_degree == 1 and self._pp_degree == 1:
# there are four modes : DataParallel / TensorParallel / PipelineParallel / ShardingParallel
# NOTE when sharding conjugates with other parallel, sharding should act like a optimizer and
# adding its parallel logic within that parallelism
# when use sharding alone, it should have its own parallelism for its parallel logic
# TODO modify 3 others parallel to support sharding
if self._mp_degree == 1 and self._pp_degree == 1 and self._dp_degree == 1 and self._sharding_degree > 1:
return ParallelMode.SHARDING_PARALLEL
elif self._mp_degree == 1 and self._pp_degree == 1:
return ParallelMode.DATA_PARALLEL
elif self._mp_degree > 1 and self._pp_degree == 1:
# initialize the seed
Expand All @@ -170,7 +184,7 @@ def get_parallel_mode(self):
return ParallelMode.PIPELINE_PARALLEL

def _check_vaild_topo(self):
return self._dp_degree * self._mp_degree * self._pp_degree == self.nranks
return self._dp_degree * self._mp_degree * self._pp_degree * self._sharding_degree == self.nranks

def _set_comm_group(self, parallel_method="data"):
parallel_group = []
Expand Down Expand Up @@ -255,6 +269,23 @@ def get_pipe_parallel_world_size(self):
def get_pipe_parallel_group(self):
return self._pp_comm_group

# sharding parallel message:
def _get_sharding_parallel_id(self):
return self._topo.get_coord(self.global_rank).sharding

def get_sharding_parallel_rank(self):
return self._sharding_parallel_id

def get_sharding_parallel_world_size(self):
return self._sharding_degree

def get_sharding_parallel_group(self):
return self._sharding_comm_group

def get_sharding_parallel_group_src_rank(self):
# TODO should the src rank related to the shard rank for each parameter ?
return self._sharding_comm_group.ranks[0]

# check parallel group
def get_check_parallel_group(self):
return self._check_comm_group
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

######
from functools import reduce

import paddle
from paddle import framework
from ...utils.log_util import logger


def _is_trainable(param: paddle.Tensor) -> bool:
return not param.stop_gradient


class DygraphShardingOptimizer(object):
    """
    A wrapper for Sharding Optimizer in Dygraph.

    Each rank of the sharding group keeps optimizer states for (and updates)
    only its own shard of the trainable parameters; after every update the
    parameters are broadcast across the sharding group so all ranks end the
    step holding a full, consistent copy of the model.

    .. warning: DygraphShardingOptimizer is experimental and subject to change.

    .. ZeRO: https://arxiv.org/abs/1910.02054

    """

    # TODO (JZ-LIANG)
    # TO support following features in future:
    # 1. fused update parameter sync
    # 2. parameters_groups
    # 3. dynamic trainable params, which is the case between pretraining and finetuning
    # 4. option to choose fuse comm (more GPU MEM need) or un-fuse comm

    def __init__(
            self,
            hcg,
            user_defined_strategy,
            params,
            inner_optimizer_class,
            **inner_optimizer_kargs, ):
        """
        Args:
            hcg: hybrid communicate group; supplies the sharding group's
                world size and this rank's id within it.
            user_defined_strategy: fleet DistributedStrategy (stored for
                later use).
            params (list): the full list of model parameters.
            inner_optimizer_class: optimizer class instantiated over this
                rank's parameter shard only.
            **inner_optimizer_kargs: extra kwargs forwarded to
                ``inner_optimizer_class``.

        Raises:
            TypeError: if ``params`` is not a list.
        """
        if not isinstance(params, list):
            raise TypeError(
                "`parameters` argument given to the DygraphShardingOptimizer should be "
                "an iterable of paddle Tensors, but got argument type is `{}`.".
                format(type(params)))
        self._parameter_list = params
        # snapshot of each param's trainable-ness, so a future
        # _update_trainable() can detect changes between training phases
        self._reference_is_trainable_params = [
            not p.stop_gradient for p in self._parameter_list
        ]

        self._inner_optimizer_class = inner_optimizer_class
        self._inner_optimizer_kargs = inner_optimizer_kargs

        # sharding parallel information
        # TODO better way to get the hcg & user_defined_strategy
        self._hcg = hcg
        self._user_defined_strategy = user_defined_strategy
        self._sharding_world_size = self._hcg.get_sharding_parallel_world_size()
        self._sharding_rank = self._hcg.get_sharding_parallel_rank()

        # logic partitioning: assign every parameter to exactly one rank
        self._build_sharding_mapping()

        # actually create opt ops
        self._buid_inner_optimizer()

    def clear_grad(self):
        """
        Clear gradients of ALL trainable parameters in the model, not only the
        shard owned by this rank (backward still produces full gradients).
        """
        for p in self._parameter_list:
            if not p.stop_gradient:
                p.clear_gradient()

    def _build_sharding_mapping(self):
        # two views of the same partition: rank -> params, param name -> rank
        self._rank2params = self._partition_parameters()
        self._param2rank = self._map_param_to_rank()

    def _partition_parameters(self):
        """
        Partitions parameters among sharding ranks.

        Greedy balancing: each parameter is appended to the rank currently
        holding the fewest elements, which evens out shard sizes but does not
        preserve parameter order across ranks.

        Return:
            Dict[int, List]
        """
        # TODO(JZ-LIANG) support multiple partition methods
        # method1: greedy even but unorder
        # method2: roughly even with order
        mapping = {rank: [] for rank in range(self._sharding_world_size)}
        sizes = [0] * self._sharding_world_size
        for param in self._parameter_list:
            rank = sizes.index(min(sizes))
            mapping[rank].append(param)
            numel = reduce(lambda x, y: x * y, param.shape)
            assert numel > 0, "param [{}] should larger than 0, but it is [{}]".format(
                param.name, numel)
            sizes[rank] += numel

        return mapping

    def _map_param_to_rank(self):
        """
        mapping parameters to the shard which holds it.

        Return:
            Dict[str, int]
        """
        return {
            param.name: rank
            for rank, params in self._rank2params.items() for param in params
        }

    def _buid_inner_optimizer(self):
        # NOTE(review): method name keeps its historical (misspelled) form so
        # external callers/subclasses are not broken.
        # we rely on the inner opt to determine whether a parameter is stop_gradient or not:
        # create moment
        # update related ops: clip, regular, opt
        self._inner_optimizer = self._inner_optimizer_class(
            parameters=self._rank2params[self._sharding_rank],
            **self._inner_optimizer_kargs)

    def _sharding_sync_parameters(self):
        """
        sync parameter across sharding group
        """
        # TODO speed up this functional
        logger.debug("sharding start sync parameters")
        with framework.no_grad():
            # TODO detach not need (?)
            for rank, params in self._rank2params.items():
                for param in params:
                    paddle.distributed.broadcast(
                        param,
                        # the collective API need src rank to be the global rank id
                        # instead of the relative logic rank id within group
                        src=self._hcg.get_sharding_parallel_group().ranks[rank],
                        group=self._hcg.get_sharding_parallel_group(),
                        use_calc_stream=True)

    def _update_trainable(self):
        """
        allow user to update trainable parameters list during training
        """
        raise NotImplementedError

    def minimize(self,
                 loss,
                 startup_program=None,
                 parameters=None,
                 no_grad_set=None):
        """
        Backward + update this rank's shard, then re-sync parameters.

        When ``parameters`` is given it restricts the update to the listed
        parameters (intersected with this rank's shard); when ``None`` the
        whole local shard is updated.
        """
        # NOTE in dygraph mode, the only different between step and minimize is that minimize
        # allow user to customize the parameters for updating on each step
        if parameters is None:
            # BUGFIX: the default case used to iterate over None and crash;
            # fall back to updating this rank's whole shard
            parameters = self._rank2params[self._sharding_rank]
        else:
            input_param_names = set(param.name for param in parameters)
            parameters = list(
                filter(lambda x: x.name in input_param_names, self._rank2params[
                    self._sharding_rank]))
        result = self._inner_optimizer.minimize(loss, startup_program,
                                                parameters, no_grad_set)

        # sync parameters across sharding ranks
        self._sharding_sync_parameters()

        return result

    def step(self):
        """Update this rank's shard with the inner optimizer, then re-sync."""
        # TODO Check whether the model trainable param changed and update state accordingly

        # actually updating
        self._inner_optimizer.step()

        # sync parameters across sharding ranks
        self._sharding_sync_parameters()

    # TODO is it a good way to make _grad_clip a property
    @property
    def _grad_clip(self):
        assert self._inner_optimizer is not None, "inner opt of sharding is not initialized."
        return self._inner_optimizer._grad_clip

    def __getattr__(self, item):
        # guard against infinite recursion when an attribute is looked up
        # before _inner_optimizer has been assigned in __init__
        if item == "_inner_optimizer":
            raise AttributeError(item)
        return getattr(self._inner_optimizer, item)
Loading