@@ -1735,6 +1735,37 @@ def filter_params(model_to_save, state_dict, is_optimizer=False):
     return filter_tensor_list


+def merge_large_tensor_parallel(tensor, tp_group, tp_action, dst_rank, is_dst):
+    num_rows = tensor.shape[0]
+    num_splits = 4
+    parts = np.array_split(np.arange(num_rows), num_splits)
+    splits = [len(part) for part in parts]
+    split_parts = np.insert(np.cumsum(splits), 0, 0)
+    split_tensors = []
+    for i in range(num_splits):
+        if get_env_device() == "xpu":
+            ret = distributed_allgather(tensor[split_parts[i] : split_parts[i + 1], :], group=tp_group, offload=False)
+        else:
+            ret = distributed_gather(
+                tensor[split_parts[i] : split_parts[i + 1], :], dst=dst_rank, group=tp_group, offload=False
+            )
+        # Copy to CPUPlace temporarily, may lower speed.
+        if ret is not None:
+            ret = [t.cpu() for t in ret]
+        split_tensors.append(ret)
+    concat_tensors = []
+    if is_dst:
+        for i in range(tp_group.nranks):
+            tmp = []
+            for j in range(num_splits):
+                tmp.append(split_tensors[j][i])
+            concat_tensors.append(paddle.concat(tmp))
+        tensor = tp_action(concat_tensors)
+    else:
+        tensor = None
+    return tensor
+
+
 def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
     hcg = fleet.get_hybrid_communicate_group()
     tp_group = hcg.get_model_parallel_group()
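For intuition, here is a minimal numpy-only sketch of the chunk/reassemble bookkeeping that `merge_large_tensor_parallel` performs. The collective gather is replaced by plain list indexing, and `shards`/`nranks` are illustrative stand-ins, not names from this patch:

import numpy as np

nranks = 2
num_splits = 4
# Two fake tp shards of shape (6, 2); values differ per "rank".
shards = [np.arange(12).reshape(6, 2) + 100 * r for r in range(nranks)]

num_rows = shards[0].shape[0]
parts = np.array_split(np.arange(num_rows), num_splits)
split_parts = np.insert(np.cumsum([len(p) for p in parts]), 0, 0)

# "Gather" each row chunk across ranks (list indexing stands in for the
# collective), then re-concatenate the chunks per rank.
split_tensors = [
    [s[split_parts[i] : split_parts[i + 1], :] for s in shards]
    for i in range(num_splits)
]
merged = [
    np.concatenate([split_tensors[j][r] for j in range(num_splits)])
    for r in range(nranks)
]
assert all(np.array_equal(m, s) for m, s in zip(merged, shards))

The assert checks the invariant the gather/concat loop relies on: concatenating the row chunks of one rank, in split order, reconstructs that rank's original shard.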
@@ -1757,12 +1788,17 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
             key = filter_keys[i]
             tensor = state_dict[key]
             if key in tp_actions:
-                if get_env_device() == "xpu":
-                    ret = distributed_allgather(tensor, group=tp_group, offload=False)
+                # Get tensor size
+                tensor_bytes = tensor.numel().item() * dtype_byte_size(tensor.dtype) * tp_group.nranks
+                if tensor_bytes >= 5368709120:  # temporarily set 5GB as threshold
+                    tensor = merge_large_tensor_parallel(tensor, tp_group, tp_actions[key], j, is_dst)
                 else:
-                    ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
-                action = tp_actions.pop(key)
-                tensor = action(ret) if is_dst else None
+                    if get_env_device() == "xpu":
+                        ret = distributed_allgather(tensor, group=tp_group, offload=False)
+                    else:
+                        ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
+                    action = tp_actions.pop(key)
+                    tensor = action(ret) if is_dst else None
             else:
                 tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None

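As a sanity check on the new 5 GB cutoff, the following sketch mirrors the size arithmetic the branch performs; the dtype width, shapes, and rank count are illustrative assumptions, not values taken from this PR:

THRESHOLD = 5368709120  # 5 GB, matching the constant above

def merged_bytes(numel, bytes_per_elem, nranks):
    # Mirrors: tensor.numel().item() * dtype_byte_size(tensor.dtype) * tp_group.nranks,
    # i.e. the size of the tensor once merged across tp ranks.
    return numel * bytes_per_elem * nranks

# A bf16 (2-byte) shard of shape 50,000 x 8,192 over 8 tp ranks:
print(merged_bytes(50_000 * 8_192, 2, 8) >= THRESHOLD)  # True  -> chunked path
# A bf16 shard of shape 4,096 x 4,096 over 8 tp ranks:
print(merged_bytes(4_096 * 4_096, 2, 8) >= THRESHOLD)   # False -> one-shot gather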
@@ -1797,12 +1833,17 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys)
                 if tensor.numel().item() == 1:
                     tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None  # Need broadcast when loaded
                 else:
-                    if get_env_device() == "xpu":
-                        ret = distributed_allgather(tensor, group=tp_group, offload=False)
+                    # Get tensor size
+                    tensor_bytes = tensor.numel().item() * dtype_byte_size(tensor.dtype) * tp_group.nranks
+                    if tensor_bytes >= 5368709120:  # temporarily set 5GB as threshold
+                        tensor = merge_large_tensor_parallel(tensor, tp_group, tp_actions[model_key], j, is_dst)
                     else:
-                        ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
-                    action = tp_actions[model_key]
-                    tensor = action(ret) if is_dst else None
+                        if get_env_device() == "xpu":
+                            ret = distributed_allgather(tensor, group=tp_group, offload=False)
+                        else:
+                            ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
+                        action = tp_actions[model_key]
+                        tensor = action(ret) if is_dst else None
             else:
                 tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None

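Rough motivation for the chunked path, as illustrative arithmetic only: because each gathered chunk is copied to CPU before the next gather, the destination rank holds roughly one row chunk of the merged tensor on device at a time instead of all of it. Assuming a hypothetical 6 GB merged tensor:

merged_gb = 6.0                # hypothetical size after merging across tp ranks
num_splits = 4                 # matches merge_large_tensor_parallel above
one_shot_peak = merged_gb      # whole merged tensor gathered on device at once
chunked_peak = merged_gb / num_splits  # only one row chunk in flight on device
print(one_shot_peak, chunked_peak)     # 6.0 vs 1.5 (GB, device-side)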