Skip to content

Commit 9146c1e

Browse files
authored
Fix compatibility with NPU. (#8409)
1 parent 2619f17 commit 9146c1e

1 file changed

Lines changed: 8 additions & 5 deletions

File tree

paddlenlp/trainer/plugins/unified_checkpoint.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,11 @@
9393
async_save_queue = []
9494

9595

96+
DEST_PLACE = paddle.CPUPlace()
97+
if paddle.device.is_compiled_with_cuda():
98+
DEST_PLACE = paddle.CUDAPinnedPlace()
99+
100+
96101
class UnifiedCheckpointOption(ExplicitEnum):
97102
"""
98103
"- skip_save_model_weight: do not save model weights when the masters weight exist\n"
@@ -1746,7 +1751,7 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
17461751
action = tp_actions.pop(key)
17471752
tensor = action(ret) if is_dst else None
17481753
else:
1749-
tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False) if is_dst else None
1754+
tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None
17501755

17511756
if is_dst:
17521757
state_dict_to_save[key] = tensor
@@ -1777,15 +1782,13 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys)
17771782
if model_key in tp_actions:
17781783
# for example: beta1, beta2
17791784
if tensor.numel().item() == 1:
1780-
tensor = (
1781-
tensor._copy_to(paddle.CUDAPinnedPlace(), False) if is_dst else None
1782-
) # Need broadcast when loaded
1785+
tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None # Need broadcast when loaded
17831786
else:
17841787
ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
17851788
action = tp_actions[model_key]
17861789
tensor = action(ret) if is_dst else None
17871790
else:
1788-
tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False) if is_dst else None
1791+
tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None
17891792

17901793
if is_dst:
17911794
state_dict_to_save[filter_keys[i]] = tensor

0 commit comments

Comments (0)