Skip to content

Commit 91a0acd

Browse files
authored
static support mp_layers (#33700)
1 parent 58e465a commit 91a0acd

File tree

6 files changed

+273
-6
lines changed

6 files changed

+273
-6
lines changed

python/paddle/distributed/collective.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,6 @@ def is_member(self):
9292
return True
9393

9494
def get_group_rank(self, rank):
95-
if self.id == 0:
96-
return rank
9795
if self.is_member() and rank in self.ranks:
9896
return self.ranks.index(rank)
9997
else:
@@ -126,7 +124,8 @@ def _get_group_map():
126124
global _group_map
127125
if not _group_map:
128126
genv = _get_global_env()
129-
_group_map[0] = Group(genv.rank, genv.world_size, 0)
127+
_group_map[0] = Group(genv.rank, genv.world_size,
128+
list(range(genv.world_size)))
130129
return _group_map
131130

132131

@@ -1014,6 +1013,27 @@ def _c_softmax_with_cross_entropy(logits,
10141013
else:
10151014
return loss, softmax
10161015

1016+
attrs = {
1017+
'ring_id': ring_id,
1018+
'rank': rank,
1019+
'nranks': nranks,
1020+
}
1021+
helper = LayerHelper('c_softmax_with_cross_entropy', **locals())
1022+
softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
1023+
loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
1024+
helper.append_op(
1025+
type='c_softmax_with_cross_entropy',
1026+
inputs={'Logits': logits,
1027+
'Label': label},
1028+
outputs={'Softmax': softmax,
1029+
'Loss': loss},
1030+
attrs=attrs)
1031+
1032+
if return_softmax:
1033+
return loss, softmax
1034+
1035+
return loss
1036+
10171037

10181038
def _linear(x, weight, bias=None, name=None):
10191039
"""

python/paddle/distributed/fleet/base/fleet_base.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,40 @@ def init(self, role_maker=None, is_collective=False, strategy=None):
253253
warnings.warn(
254254
"The dygraph hybrid parallel environment has been initialized."
255255
)
256+
elif self._is_collective:
257+
use_sharding = self._user_defined_strategy.sharding
258+
259+
# global group
260+
global_rank = self.worker_index()
261+
global_world_size = self.worker_num()
262+
# NOTE(wangxi): see sharding_optimizer
263+
global_ring_id = 3 if use_sharding else 0
264+
global_ranks = list(range(global_world_size))
265+
266+
if tp._HYBRID_PARALLEL_GROUP is None: tp._CommunicateGroup()
267+
cg = tp._HYBRID_PARALLEL_GROUP
268+
self._hcg = cg
269+
cg.set_comm_group('global', global_rank, global_world_size,
270+
global_ring_id, global_ranks)
271+
272+
# hybrid group
273+
if use_sharding is False: return
274+
275+
sharding_configs = self._user_defined_strategy.sharding_configs
276+
mp_degree = int(sharding_configs['mp_degree'])
277+
278+
if mp_degree > 1:
279+
assert global_world_size % mp_degree == 0
280+
# NOTE(wangxi): mp_ring_id sync with sharding_optimizer.py _build_groups
281+
mp_ring_id = 0
282+
mp_rank = global_rank % mp_degree
283+
mp_group_id = global_rank // mp_degree
284+
mp_group_ranks = [
285+
idx for idx in global_ranks
286+
if idx // mp_degree == mp_group_id
287+
]
288+
cg.set_comm_group('model', mp_rank, mp_degree, mp_ring_id,
289+
mp_group_ranks)
256290

257291
def _init_hybrid_parallel_env(self):
258292
"""initialize the hybrid environment

python/paddle/distributed/fleet/base/topology.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,3 +262,31 @@ def get_check_parallel_group(self):
262262
def get_rank_from_stage(self, stage_id, **kwargs):
263263
return self._topo.get_rank_from_stage(
264264
self.global_rank, pipe=stage_id, **kwargs)
265+
266+
267+
class _CommunicateGroup(object):
    """Minimal communication-group registry for static-graph mode.

    On construction it publishes itself as the module-level
    ``_HYBRID_PARALLEL_GROUP`` so static-graph code can answer the same
    model-parallel queries as the dynamic hybrid communicate group.
    """

    def __init__(self):
        # Publish this instance as the process-wide hybrid parallel group.
        global _HYBRID_PARALLEL_GROUP
        _HYBRID_PARALLEL_GROUP = self
        self.groups = {}

    def set_comm_group(self, group_name, group_rank, group_size, ring_id,
                       group_ranks):
        """Create a collective Group and register it under *group_name*."""
        self.groups[group_name] = paddle.distributed.collective.Group(
            group_rank, group_size, ring_id, group_ranks)

    def get_group(self, group_name):
        """Return the group registered under *group_name* (must exist)."""
        assert group_name in self.groups
        return self.groups[group_name]

    def get_model_parallel_group(self):
        """The registered 'model' (tensor-parallel) group."""
        return self.get_group('model')

    def get_model_parallel_world_size(self):
        """Number of ranks in the 'model' group."""
        return self.get_group('model').nranks

    def get_model_parallel_rank(self):
        """This process's rank within the 'model' group."""
        return self.get_group('model').rank

python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(self,
5656
self._weight_attr = weight_attr
5757
self._name = name
5858

59-
if self.is_mp:
59+
if self.is_mp and paddle.in_dynamic_mode():
6060
with get_rng_state_tracker().rng_state():
6161
self.weight = self.create_parameter(
6262
attr=self._weight_attr,
@@ -121,7 +121,7 @@ def __init__(self,
121121
self._weight_attr = weight_attr
122122
self._dtype = self._helper.get_default_dtype()
123123

124-
if self.is_mp:
124+
if self.is_mp and paddle.in_dynamic_mode():
125125
with get_rng_state_tracker().rng_state():
126126
self.weight = self.create_parameter(
127127
shape=[in_features, self.output_size_per_partition],
@@ -198,7 +198,7 @@ def __init__(self,
198198

199199
self.input_size_per_partition = in_features // self.world_size
200200

201-
if self.is_mp:
201+
if self.is_mp and paddle.in_dynamic_mode():
202202
with get_rng_state_tracker().rng_state():
203203
self.weight = self.create_parameter(
204204
shape=[self.input_size_per_partition, self.out_features],

python/paddle/fluid/tests/unittests/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_graph_executor)
7070
list(APPEND MIXED_DIST_TEST_OPS test_fleet_meta_optimizer_base)
7171
list(APPEND MIXED_DIST_TEST_OPS test_fleet_distributed_strategy)
7272
list(APPEND MIXED_DIST_TEST_OPS test_fleet_auto)
73+
list(APPEND MIXED_DIST_TEST_OPS test_fleet_static_mp_layers)
7374
foreach(TEST_OP ${MIXED_DIST_TEST_OPS})
7475
list(REMOVE_ITEM TEST_OPS ${TEST_OP})
7576
endforeach()
@@ -525,6 +526,7 @@ if(WITH_DISTRIBUTE)
525526
py_test_modules(test_fleet_private_function MODULES test_fleet_private_function ENVS ${dist_ENVS})
526527
py_test_modules(test_fleet_meta_optimizer_base MODULES test_fleet_meta_optimizer_base ENVS ${dist_ENVS})
527528
py_test_modules(test_fleet_distributed_strategy MODULES test_fleet_distributed_strategy)
529+
py_test_modules(test_fleet_static_mp_layers MODULES test_fleet_static_mp_layers)
528530
#py_test_modules(test_fleet_auto MODULES test_fleet_auto ENVS ${dist_ENVS})
529531
if(NOT WIN32)
530532
py_test_modules(test_fleet_localsgd_meta_optimizer MODULES test_fleet_localsgd_meta_optimizer ENVS ${dist_ENVS})
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import division
16+
from __future__ import print_function
17+
18+
import unittest
19+
20+
import paddle
21+
import numpy as np
22+
import random
23+
import paddle.distributed as dist
24+
import paddle.fluid as fluid
25+
import paddle.distributed.fleet as fleet
26+
from paddle import framework
27+
import os
28+
29+
paddle.enable_static()
30+
31+
32+
class ColumnLinearNet(fluid.dygraph.Layer):
    """Wraps a single ColumnParallelLinear layer for the static-graph tests."""

    def __init__(self, input_size, output_size):
        super(ColumnLinearNet, self).__init__()
        self.parallel_linear = fleet.meta_parallel.ColumnParallelLinear(
            in_features=input_size,
            out_features=output_size,
            weight_attr=None,
            has_bias=True,
            gather_output=True,
            name="test_column_linear")

    def forward(self, x):
        # Delegate straight to the wrapped parallel layer.
        return self.parallel_linear(x)
46+
47+
48+
class RowLinearNet(fluid.dygraph.Layer):
    """Wraps a single RowParallelLinear layer for the static-graph tests."""

    def __init__(self, input_size, output_size):
        super(RowLinearNet, self).__init__()
        self.parallel_linear = fleet.meta_parallel.RowParallelLinear(
            in_features=input_size,
            out_features=output_size,
            has_bias=True,
            input_is_parallel=False,
            name="test_row_linear")

    def forward(self, x):
        # Delegate straight to the wrapped parallel layer.
        return self.parallel_linear(x)
61+
62+
63+
class EmbeddingNet(fluid.dygraph.Layer):
    """Wraps a single VocabParallelEmbedding layer for the static-graph tests."""

    def __init__(self, vocab_size, hidden_size):
        super(EmbeddingNet, self).__init__()
        self.embedding = fleet.meta_parallel.VocabParallelEmbedding(
            vocab_size, hidden_size)

    def forward(self, x):
        # Delegate straight to the wrapped parallel layer.
        return self.embedding(x)
72+
73+
74+
class TestDistTraning(unittest.TestCase):
    """Static-graph unit tests for fleet model-parallel layers.

    Each test builds a fresh program pair, constructs one parallel layer,
    and checks (a) the op sequence emitted into the main program and
    (b) the partitioned parameter shapes for mp_degree == 2.
    """

    def setUp(self):
        # Fake a 4-trainer job; this process acts as trainer 2.
        os.environ["PADDLE_TRAINER_ID"] = "2"
        os.environ[
            "PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002,127.0.0.1:36003,127.0.0.1:36004"

        self.model_parallel_size = 2
        strategy = fleet.DistributedStrategy()
        strategy.sharding = True
        strategy.sharding_configs = {
            "mp_degree": self.model_parallel_size,
            "sharding_degree": 2,
        }
        fleet.init(is_collective=True, strategy=strategy)

    def get_program(self):
        # A fresh (main, startup) program pair for each test case.
        return paddle.static.Program(), paddle.static.Program()

    def _op_types(self, program):
        # Op-type sequence of the program's global block.
        return [op.type for op in program.global_block().ops]

    def test_column_parallel_layer(self):
        main_program, startup_program = self.get_program()
        with paddle.static.program_guard(main_program, startup_program):
            input_size, output_size = 28, 64
            net = ColumnLinearNet(input_size, output_size)

            x = paddle.static.data(name='x', shape=[None, input_size])
            net(x)

            self.assertEqual(
                self._op_types(main_program),
                ['c_identity', 'matmul', 'elementwise_add', 'c_concat'])

            # Column parallel: the output dimension is split across ranks.
            weight = net.parallel_linear.weight
            bias = net.parallel_linear.bias
            self.assertEqual(
                weight.shape,
                (input_size, output_size // self.model_parallel_size))
            self.assertEqual(bias.shape,
                             (output_size // self.model_parallel_size, ))

    def test_row_parallel_layer(self):
        main_program, startup_program = self.get_program()
        with paddle.static.program_guard(main_program, startup_program):
            input_size, output_size = 28, 64
            net = RowLinearNet(input_size, output_size)

            x = paddle.static.data(name='x', shape=[None, input_size])
            net(x)

            self.assertEqual(
                self._op_types(main_program),
                ['c_split', 'matmul', 'c_allreduce_sum', 'elementwise_add'])

            # Row parallel: the input dimension is split across ranks.
            weight = net.parallel_linear.weight
            bias = net.parallel_linear.bias
            self.assertEqual(
                weight.shape,
                (input_size // self.model_parallel_size, output_size))
            self.assertEqual(bias.shape, (output_size, ))

    def test_parallel_embedding(self):
        main_program, startup_program = self.get_program()
        with paddle.static.program_guard(main_program, startup_program):
            vocab_size, hidden_size = 1000, 512
            seq_len = 128

            net = EmbeddingNet(vocab_size, hidden_size)

            x = paddle.static.data(
                name='x', shape=[None, seq_len], dtype='int64')
            net(x)

            self.assertEqual(
                self._op_types(main_program),
                ['c_embedding', 'c_allreduce_sum'])

            # The vocabulary dimension is split across ranks.
            weight = net.embedding.weight
            self.assertEqual(
                weight.shape,
                (vocab_size // self.model_parallel_size, hidden_size))

    def test_parallel_cross_entropy(self):
        main_program, startup_program = self.get_program()
        with paddle.static.program_guard(main_program, startup_program):
            batch_size, seq_length, class_size = 8, 16, 1000
            class_size_per_card = class_size // self.model_parallel_size

            loss_layer = fleet.meta_parallel.ParallelCrossEntropy()

            # Logits arrive already partitioned along the class dimension.
            x = paddle.static.data(
                name='x', shape=[batch_size, seq_length, class_size_per_card])
            label = paddle.static.data(
                name='label', shape=[batch_size, seq_length], dtype='int64')
            loss_layer(x, label)

            self.assertEqual(
                self._op_types(main_program),
                ['unsqueeze2', 'c_softmax_with_cross_entropy'])
180+
181+
182+
# Allow running this test file directly with `python`.
if __name__ == '__main__':
    unittest.main()

0 commit comments

Comments
 (0)