Skip to content

Commit 937e21a

Browse files
authored
supports mp and dp hybrid (#34377)
1 parent 846be13 commit 937e21a

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

python/paddle/distributed/fleet/base/fleet_base.py

Lines changed: 19 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -269,11 +269,27 @@ def init(self, role_maker=None, is_collective=False, strategy=None):
269269
cg.set_comm_group('global', global_rank, global_world_size,
270270
global_ring_id, global_ranks)
271271

272+
use_tensor_parallel = self._user_defined_strategy.tensor_parallel
273+
use_mp = use_sharding or use_tensor_parallel
274+
272275
# hybrid group
273-
if use_sharding is False: return
276+
if use_mp is False: return
277+
278+
mp_degree_sharding = 1
279+
mp_degree_tensor_parallel = 1
280+
if use_sharding:
281+
sharding_configs = self._user_defined_strategy.sharding_configs
282+
mp_degree_sharding = int(sharding_configs['mp_degree'])
283+
284+
if use_tensor_parallel:
285+
tensor_parallel_configs = self._user_defined_strategy.tensor_parallel_configs
286+
mp_degree_tensor_parallel = int(tensor_parallel_configs[
287+
'tensor_parallel_degree'])
288+
289+
if use_sharding and use_tensor_parallel:
290+
assert mp_degree_sharding == mp_degree_tensor_parallel
274291

275-
sharding_configs = self._user_defined_strategy.sharding_configs
276-
mp_degree = int(sharding_configs['mp_degree'])
292+
mp_degree = mp_degree_sharding if use_sharding else mp_degree_tensor_parallel
277293

278294
if mp_degree > 1:
279295
assert global_world_size % mp_degree == 0

python/paddle/fluid/tests/unittests/test_fleet_static_mp_layers.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -84,6 +84,8 @@ def setUp(self):
8484
"mp_degree": self.model_parallel_size,
8585
"sharding_degree": 2,
8686
}
87+
strategy.tensor_parallel = True
88+
strategy.tensor_parallel_configs = {"tensor_parallel_degree": 2}
8789
fleet.init(is_collective=True, strategy=strategy)
8890

8991
def get_program(self):

0 commit comments

Comments (0)