Skip to content

Commit d1fe04b

Browse files
committed
supports mp and pp hybrid, test=allcase
1 parent b5d8f43 commit d1fe04b

1 file changed

Lines changed: 6 additions & 0 deletions

File tree

python/paddle/distributed/fleet/base/fleet_base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,12 @@ def init(self, role_maker=None, is_collective=False, strategy=None):
             sharding_configs = self._user_defined_strategy.sharding_configs
             mp_degree = int(sharding_configs['mp_degree'])

+        use_tensor_parallel = self._user_defined_strategy.tensor_parallel
+        if use_tensor_parallel:
+            tensor_parallel_configs = self._user_defined_strategy.tensor_parallel_configs
+            mp_degree = int(tensor_parallel_configs[
+                'tensor_parallel_degree'])
+
         if mp_degree > 1:
             assert global_world_size % mp_degree == 0
             # NOTE(wangxi): mp_ring_id sync with sharding_optimizer.py _build_groups

(Note: exact leading indentation of the context/added lines is not recoverable from the scraped page — reconstructed to plausible nesting; verify against the repository file.)

0 commit comments

Comments (0)