Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 51 additions & 10 deletions paddlenlp/trainer/plugins/unified_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1735,6 +1735,37 @@ def filter_params(model_to_save, state_dict, is_optimizer=False):
return filter_tensor_list


def merge_large_tensor_parallel(tensor, tp_group, tp_action, dst_rank, is_dst):
num_rows = tensor.shape[0]
num_splits = 4
parts = np.array_split(np.arange(num_rows), num_splits)
splits = [len(part) for part in parts]
split_parts = np.insert(np.cumsum(splits), 0, 0)
split_tensors = []
for i in range(num_splits):
if get_env_device() == "xpu":
ret = distributed_allgather(tensor[split_parts[i] : split_parts[i + 1], :], group=tp_group, offload=False)
else:
ret = distributed_gather(
tensor[split_parts[i] : split_parts[i + 1], :], dst=dst_rank, group=tp_group, offload=False
)
# Copy to CPUPlace temporarily, may lower speed.
if ret is not None:
ret = [t.cpu() for t in ret]
split_tensors.append(ret)
concat_tensors = []
if is_dst:
for i in range(tp_group.nranks):
tmp = []
for j in range(num_splits):
tmp.append(split_tensors[j][i])
concat_tensors.append(paddle.concat(tmp))
tensor = tp_action(concat_tensors)
else:
tensor = None
return tensor


def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
hcg = fleet.get_hybrid_communicate_group()
tp_group = hcg.get_model_parallel_group()
Expand All @@ -1757,12 +1788,17 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
key = filter_keys[i]
tensor = state_dict[key]
if key in tp_actions:
if get_env_device() == "xpu":
ret = distributed_allgather(tensor, group=tp_group, offload=False)
# Get tensor size
tensor_bytes = tensor.numel().item() * dtype_byte_size(tensor.dtype) * tp_group.nranks
if tensor_bytes >= 5 * 1024 * 1024 * 1024: # temporarily set 5GB as threshold
tensor = merge_large_tensor_parallel(tensor, tp_group, tp_actions[key], j, is_dst)
else:
ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
action = tp_actions.pop(key)
tensor = action(ret) if is_dst else None
if get_env_device() == "xpu":
ret = distributed_allgather(tensor, group=tp_group, offload=False)
else:
ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
action = tp_actions.pop(key)
tensor = action(ret) if is_dst else None
else:
tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None

Expand Down Expand Up @@ -1797,12 +1833,17 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys)
if tensor.numel().item() == 1:
tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None # Need broadcast when loaded
else:
if get_env_device() == "xpu":
ret = distributed_allgather(tensor, group=tp_group, offload=False)
# Get tensor size
tensor_bytes = tensor.numel().item() * dtype_byte_size(tensor.dtype) * tp_group.nranks
if tensor_bytes >= 5 * 1024 * 1024 * 1024: # temporarily set 5GB as threshold
tensor = merge_large_tensor_parallel(tensor, tp_group, tp_actions[model_key], j, is_dst)
else:
ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
action = tp_actions[model_key]
tensor = action(ret) if is_dst else None
if get_env_device() == "xpu":
ret = distributed_allgather(tensor, group=tp_group, offload=False)
else:
ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
action = tp_actions[model_key]
tensor = action(ret) if is_dst else None
else:
tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None

Expand Down