
Commit 294dfd2

[HybridParallel]Add SharedLayerDesc for PipelineParallel (#33578)
* add pplayer
* add sharedlayerdesc
1 parent 07197fb commit 294dfd2

9 files changed: 358 additions & 7 deletions


python/paddle/distributed/collective.py

Lines changed: 3 additions & 1 deletion
@@ -267,7 +267,9 @@ def new_group(ranks=None, backend=None):

     # TODO(shenliang03): This is a temporary solution to solve the problem of
     # hang caused by cross-creation of new_group
-    tmp = fill_constant([0], dtype="int32", value="1")
+    tmp = paddle.to_tensor(
+        [1], dtype="int32") if in_dygraph_mode() else fill_constant(
+            [0], dtype="int32", value="1")
     paddle.distributed.all_reduce(tmp, use_calc_stream=True)
     paddle.distributed.wait(tmp)
     return gp
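The pattern this hunk introduces: build the synchronization tensor eagerly when the program runs in dynamic-graph mode, and fall back to the static-graph fill_constant op otherwise. A minimal standalone sketch of that pattern; the imports follow the surrounding file and are assumptions here, not part of the commit:

import paddle
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.layers import fill_constant

def _make_sync_tensor():
    # Eager tensor in dygraph mode; a static-graph op otherwise.
    if in_dygraph_mode():
        return paddle.to_tensor([1], dtype="int32")
    return fill_constant([0], dtype="int32", value="1")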

python/paddle/distributed/fleet/base/topology.py

Lines changed: 8 additions & 4 deletions
@@ -107,6 +107,11 @@ def get_comm_list(self, axis_name):

         return all_result

+    def get_rank_from_stage(self, global_rank, **kwargs):
+        coord = self.get_coord(global_rank)
+        tf = coord._replace(**kwargs)._asdict()
+        return self.get_rank(**tf)
+

 class HybridCommunicateGroup(object):
     def __init__(self, topology):

@@ -254,7 +259,6 @@ def get_pipe_parallel_group(self):
     def get_check_parallel_group(self):
         return self._check_comm_group

-    def get_rank_from_stage(self, stage_id):
-        coord = self._topo.get_coord(self.global_rank)
-        tf = coord._replace(pipe=stage_id)._asdict()
-        return self._topo.get_rank(**tf)
+    def get_rank_from_stage(self, stage_id, **kwargs):
+        return self._topo.get_rank_from_stage(
+            self.global_rank, pipe=stage_id, **kwargs)
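In words, the generalized helper takes the coordinate of a global rank in the hybrid topology, overwrites whichever axes are passed as keyword arguments, and maps the resulting coordinate back to a global rank. A hedged usage sketch; the topology construction, axis names, and degrees are illustrative assumptions, not values from this commit:

from paddle.distributed.fleet.base.topology import CommunicateTopology

# Hypothetical 2-way data-parallel x 2-stage pipeline topology (4 ranks).
topo = CommunicateTopology(
    hybrid_group_names=["data", "pipe", "model"],
    dims=[2, 2, 1])

# "Which global rank sits at my coordinate, but on pipeline stage 1 of
# data-parallel replica 0?"
peer = topo.get_rank_from_stage(global_rank=0, pipe=1, data=0)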

python/paddle/distributed/fleet/meta_parallel/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@
 from .parallel_layers import RowParallelLinear  # noqa: F401
 from .parallel_layers import ParallelCrossEntropy  # noqa: F401
 from .parallel_layers import LayerDesc  # noqa: F401
+from .parallel_layers import SharedLayerDesc  # noqa: F401
 from .parallel_layers import PipelineLayer  # noqa: F401
 from .parallel_layers import RNGStatesTracker  # noqa: F401
 from .parallel_layers import model_parallel_random_seed  # noqa: F401

python/paddle/distributed/fleet/meta_parallel/parallel_layers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@
 from .mp_layers import RowParallelLinear  # noqa: F401
 from .mp_layers import ParallelCrossEntropy  # noqa: F401
 from .pp_layers import LayerDesc  # noqa: F401
+from .pp_layers import SharedLayerDesc  # noqa: F401
 from .pp_layers import PipelineLayer  # noqa: F401
 from .random import RNGStatesTracker  # noqa: F401
 from .random import model_parallel_random_seed  # noqa: F401

python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py

Lines changed: 105 additions & 0 deletions
@@ -15,6 +15,7 @@
 import paddle
 from paddle.fluid.dygraph.layers import Layer
 from ...utils.log_util import logger, layer_to_str
+from functools import partial

 __all__ = []

@@ -58,6 +59,20 @@ def __repr__(self):
             **self.kwargs)


+class SharedLayerDesc(LayerDesc):
+    def __init__(self,
+                 key,
+                 layer_func,
+                 forward_func=None,
+                 shared_weight_attr='weight',
+                 *inputs,
+                 **kwargs):
+        super(SharedLayerDesc, self).__init__(layer_func, *inputs, **kwargs)
+        self.layer_name = key
+        self.forward_func = forward_func
+        self.shared_weight_attr = shared_weight_attr
+
+
 class PipelineLayer(Layer):
     def __init__(self,
                  layers,

@@ -104,11 +119,86 @@ def __init__(self,
         self._start_pos = 0
         self._end_pos = self._num_layers - 1
         self._segment_network(seg_method)
+        self.shared_layers = paddle.nn.LayerDict()
+        self.shared_weight_attrs = {}

         # construct layer
         self.run_function = []
         self._build_layer()

+        self.shared_comm = self._construct_shared_comm()
+        self._synchronize_shared_weights()
+
+    def get_stage_from_index(self, layer_idx):
+        assert 0 <= layer_idx < self._num_layers, "layer_idx is out of bound"
+        for stage in range(self._topo.get_dim('pipe')):
+            if self.segment_parts[stage] <= layer_idx < self.segment_parts[stage
+                                                                            + 1]:
+                return stage
+
+    def _construct_shared_comm(self):
+        shared_comm = {}
+        if self._topo.get_dim("pipe") == 1:
+            return
+
+        layers_desc = self._layers_desc
+        shared_layer_names = set(
+            s.layer_name for s in layers_desc if isinstance(s, SharedLayerDesc))
+        for key in shared_layer_names:
+            shared_layers = []
+            for idx, layer in enumerate(layers_desc):
+                if isinstance(layer,
+                              SharedLayerDesc) and layer.layer_name == key:
+                    shared_layers.append(idx)
+
+            shared_stages = set(
+                self.get_stage_from_index(idx) for idx in shared_layers)
+            self._dp_degree = self._topo.get_dim('data')
+            self._mp_degree = self._topo.get_dim('model')
+
+            shared_ranks = []
+            for dp in range(self._dp_degree):
+                for mp in range(self._mp_degree):
+                    shared_ranks = []
+                    for s in sorted(shared_stages):
+                        shared_ranks.append(
+                            self._topo.get_rank_from_stage(
+                                self.global_rank, pipe=s, data=dp, model=mp))
+
+                    group = paddle.distributed.new_group(ranks=shared_ranks)
+                    if self.global_rank in shared_ranks:
+                        assert key in self.shared_layers
+                        if key in self.shared_layers:
+                            shared_comm[key] = {
+                                'ranks': shared_ranks,
+                                'group': group,
+                                'weight_attr': self.shared_weight_attrs[key],
+                                'layer': self.shared_layers[key],
+                            }
+        return shared_comm
+
+    def _synchronize_shared_weights(self):
+        for key, comm in self.shared_comm.items():
+            with paddle.framework.no_grad():
+                paddle.distributed.broadcast(
+                    getattr(comm['layer'], comm['weight_attr']),
+                    src=min(comm['ranks']),
+                    group=comm['group'])
+
+    def allreduce_shared_weight_gradients(self):
+        for key, comm in self.shared_comm.items():
+            param = getattr(self.shared_layers[key], comm['weight_attr'])
+            # need use trace_op to allreduce weight
+            with paddle.framework.no_grad():
+                paddle.fluid.framework._dygraph_tracer().trace_op(
+                    type="c_allreduce_sum",
+                    inputs={'X': param._grad_ivar()},
+                    outputs={'Out': param._grad_ivar()},
+                    attrs={
+                        'ring_id': comm['group'].id,
+                        'use_calc_stream': True
+                    })
+
     def _segment_network(self, seg_method):
         logger.info("start segment network..")
         seg = SegmentLayers(

@@ -142,6 +232,21 @@ def _build_layer(self):
             if isinstance(layer, Layer):
                 self.run_function.append(layer)
                 self.add_sublayer(str(layer_index), layer)
+            elif isinstance(layer, SharedLayerDesc):
+                if layer.layer_name not in self.shared_layers:
+                    self.shared_layers[layer.layer_name] = layer.build_layer()
+                    self.shared_weight_attrs[
+                        layer.layer_name] = layer.shared_weight_attr
+
+                if layer.forward_func is None:
+                    self.run_function.append(self.shared_layers[
+                        layer.layer_name])
+
+                else:
+                    self.run_function.append(
+                        partial(layer.forward_func, self.shared_layers[
+                            layer.layer_name]))
+
             elif isinstance(layer, LayerDesc):
                 model = layer.build_layer()
                 self.run_function.append(model)
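To make the intent of SharedLayerDesc concrete, here is a hedged usage sketch in the spirit of a weight-tied embedding and output head for a pipeline-parallel model. Descriptors that reuse the same key on different stages are built once per owning stage, broadcast from the lowest owning rank at construction time, and have their gradients summed by allreduce_shared_weight_gradients during training. The EmbeddingPipe layer, logits_forward function, layer sizes, and the two-stage layout below are illustrative assumptions, not code from this commit:

import paddle
import paddle.nn as nn
from paddle.distributed.fleet.meta_parallel import (
    LayerDesc, SharedLayerDesc, PipelineLayer)


class EmbeddingPipe(nn.Layer):
    """Embedding whose weight is reused as the output projection."""

    def __init__(self, vocab_size, hidden_size):
        super(EmbeddingPipe, self).__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)

    @property
    def embedding_weight(self):
        return self.word_embeddings.weight

    def forward(self, x):
        return self.word_embeddings(x)


def logits_forward(embedding_layer, hidden):
    # Reuse the shared embedding weight as the output projection.
    return paddle.matmul(
        hidden, embedding_layer.embedding_weight, transpose_y=True)


descs = [
    # First pipeline stage: embedding lookup, default forward.
    SharedLayerDesc(
        'embed', EmbeddingPipe, shared_weight_attr='embedding_weight',
        vocab_size=1000, hidden_size=64),
    LayerDesc(nn.Linear, 64, 64),
    # Last pipeline stage: same shared parameter, different forward function.
    SharedLayerDesc(
        'embed', EmbeddingPipe, forward_func=logits_forward,
        shared_weight_attr='embedding_weight',
        vocab_size=1000, hidden_size=64),
]

# model = PipelineLayer(layers=descs, num_stages=2)  # needs an initialized
# fleet hybrid-parallel environment, so only sketched here.

The forward_func hook is what lets the last stage run a different computation (logits projection) over the same parameter that the first stage uses for the embedding lookup.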

python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py

Lines changed: 2 additions & 0 deletions
@@ -138,6 +138,8 @@ def train_batch(self, data, optimizer, lr_scheduler=None):
             self._backward(cache_id=backward_steps)
             backward_steps += 1

+        self._layers.allreduce_shared_weight_gradients()
+
         # optimizer
         self._step()
         self.train_loss = self._reduce_final_loss()
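A much-simplified sketch of where the new call sits in a training step: after the backward passes for all micro-batches have produced gradients, the gradients of parameters replicated across pipeline stages are summed, and only then does the optimizer consume them. Everything here except allreduce_shared_weight_gradients is schematic placeholder code, not the real pipeline scheduler:

def train_batch_sketch(pipeline_layers, micro_batches, optimizer):
    # Schematic stand-in for the interleaved forward/backward schedule.
    for batch in micro_batches:
        loss = pipeline_layers(batch)   # forward on the local stage
        loss.backward()                 # backward on the local stage

    # New in this commit: sum gradients of weights shared across pipeline
    # stages (e.g. tied embeddings) before the optimizer step.
    pipeline_layers.allreduce_shared_weight_gradients()

    optimizer.step()
    optimizer.clear_grad()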

python/paddle/fluid/tests/unittests/hybrid_parallel_mp_layers.py

Lines changed: 2 additions & 2 deletions
@@ -270,8 +270,8 @@ def test_parallel_embedding(self):
         np.testing.assert_allclose(loss_a.numpy(), loss_b.numpy())

     def test_parallel_cross_entropy(self):
-        batch_size = 2
-        seq_length = 1
+        batch_size = 8
+        seq_length = 16
         class_size_per_card = 2
         vocab_size = class_size_per_card * self.model_parallel_size
         seed = 1025
