Skip to content

Commit 0a21315

Browse files
committed
update_uc_merge
1 parent d406e56 commit 0a21315

1 file changed

Lines changed: 51 additions & 10 deletions

File tree

paddlenlp/trainer/plugins/unified_checkpoint.py

Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1735,6 +1735,37 @@ def filter_params(model_to_save, state_dict, is_optimizer=False):
17351735
return filter_tensor_list
17361736

17371737

1738+
def merge_large_tensor_parallel(tensor, tp_group, tp_action, dst_rank, is_dst, num_splits=4):
    """Gather a large tensor-parallel-sharded tensor in row chunks, then merge it.

    Instead of gathering the whole tensor across the TP group in one call (which
    can spike peak device memory for multi-GB tensors), the tensor is split along
    dim 0 into ``num_splits`` chunks; each chunk is gathered separately and moved
    to CPU before the next chunk is gathered. On the destination rank the per-rank
    chunks are re-concatenated and handed to ``tp_action`` for the final merge.

    Args:
        tensor: local TP shard; assumed 2-D (sliced as ``[rows, :]``) — confirm callers.
        tp_group: tensor-parallel communication group (provides ``nranks``).
        tp_action: callable merging the list of per-rank full tensors into one.
        dst_rank: destination rank for ``distributed_gather`` (non-xpu path).
        is_dst: whether this rank performs the merge; other ranks return None.
        num_splits: number of row chunks to gather sequentially (default 4,
            matching the original hard-coded value).

    Returns:
        The merged tensor on the destination rank, ``None`` elsewhere.
    """
    num_rows = tensor.shape[0]
    # Partition row indices as evenly as possible; np.array_split tolerates
    # num_rows < num_splits (some chunks are simply empty).
    parts = np.array_split(np.arange(num_rows), num_splits)
    splits = [len(part) for part in parts]
    # Prefix-sum boundaries: split_parts[i] : split_parts[i + 1] selects chunk i.
    split_parts = np.insert(np.cumsum(splits), 0, 0)
    split_tensors = []
    for i in range(num_splits):
        # xpu has no gather-to-one primitive here, so allgather is used instead;
        # every rank then holds the chunk, but only is_dst merges it below.
        if get_env_device() == "xpu":
            ret = distributed_allgather(tensor[split_parts[i] : split_parts[i + 1], :], group=tp_group, offload=False)
        else:
            ret = distributed_gather(
                tensor[split_parts[i] : split_parts[i + 1], :], dst=dst_rank, group=tp_group, offload=False
            )
        # Copy to CPUPlace temporarily, may lower speed, but caps device memory
        # to one chunk's worth of gathered data at a time.
        if ret is not None:
            ret = [t.cpu() for t in ret]
        split_tensors.append(ret)
    concat_tensors = []
    if is_dst:
        # Reassemble each rank's full shard from its chunks, then let the
        # TP action fuse the per-rank shards into the final tensor.
        for i in range(tp_group.nranks):
            tmp = []
            for j in range(num_splits):
                tmp.append(split_tensors[j][i])
            concat_tensors.append(paddle.concat(tmp))
        tensor = tp_action(concat_tensors)
    else:
        tensor = None
    return tensor
1767+
1768+
17381769
def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
17391770
hcg = fleet.get_hybrid_communicate_group()
17401771
tp_group = hcg.get_model_parallel_group()
@@ -1757,12 +1788,17 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
17571788
key = filter_keys[i]
17581789
tensor = state_dict[key]
17591790
if key in tp_actions:
1760-
if get_env_device() == "xpu":
1761-
ret = distributed_allgather(tensor, group=tp_group, offload=False)
1791+
# Get tensor size
1792+
tensor_bytes = tensor.numel().item() * dtype_byte_size(tensor.dtype) * tp_group.nranks
1793+
if tensor_bytes >= 5368709120: # temporarily set 5GB as threshold
1794+
tensor = merge_large_tensor_parallel(tensor, tp_group, tp_actions[key], j, is_dst)
17621795
else:
1763-
ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
1764-
action = tp_actions.pop(key)
1765-
tensor = action(ret) if is_dst else None
1796+
if get_env_device() == "xpu":
1797+
ret = distributed_allgather(tensor, group=tp_group, offload=False)
1798+
else:
1799+
ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
1800+
action = tp_actions.pop(key)
1801+
tensor = action(ret) if is_dst else None
17661802
else:
17671803
tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None
17681804

@@ -1797,12 +1833,17 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys)
17971833
if tensor.numel().item() == 1:
17981834
tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None # Need broadcast when loaded
17991835
else:
1800-
if get_env_device() == "xpu":
1801-
ret = distributed_allgather(tensor, group=tp_group, offload=False)
1836+
# Get tensor size
1837+
tensor_bytes = tensor.numel().item() * dtype_byte_size(tensor.dtype) * tp_group.nranks
1838+
if tensor_bytes >= 5368709120: # temporarily set 5GB as threshold
1839+
tensor = merge_large_tensor_parallel(tensor, tp_group, tp_actions[model_key], j, is_dst)
18021840
else:
1803-
ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
1804-
action = tp_actions[model_key]
1805-
tensor = action(ret) if is_dst else None
1841+
if get_env_device() == "xpu":
1842+
ret = distributed_allgather(tensor, group=tp_group, offload=False)
1843+
else:
1844+
ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
1845+
action = tp_actions[model_key]
1846+
tensor = action(ret) if is_dst else None
18061847
else:
18071848
tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None
18081849

0 commit comments

Comments
 (0)