@@ -93,6 +93,11 @@
93 | 93 | async_save_queue = [] |
94 | 94 |
95 | 95 |
| 96 | +DEST_PLACE = paddle.CPUPlace() |
| 97 | +if paddle.device.is_compiled_with_cuda(): |
| 98 | + DEST_PLACE = paddle.CUDAPinnedPlace() |
| 99 | + |
| 100 | + |
96 | 101 | class UnifiedCheckpointOption(ExplicitEnum): |
97 | 102 | """ |
98 | 103 | "- skip_save_model_weight: do not save model weights when the masters weight exist\n" |
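
This first hunk makes the checkpoint offload target build-aware: `DEST_PLACE` is pinned CUDA host memory on GPU builds and plain host memory otherwise, presumably so the merge paths below also run on CPU-only installs where pinned memory is unavailable. A minimal runnable sketch of the same selection pattern (the tensor here is illustrative only):

```python
import paddle

# Prefer page-locked (pinned) host memory on CUDA builds: pinned buffers
# allow fast, potentially asynchronous GPU<->host copies. Fall back to
# ordinary CPU memory on CPU-only builds.
DEST_PLACE = paddle.CPUPlace()
if paddle.device.is_compiled_with_cuda():
    DEST_PLACE = paddle.CUDAPinnedPlace()

t = paddle.ones([2, 2])
# `_copy_to(place, blocking)` is the internal copy used throughout this
# diff; blocking=False requests a non-blocking copy where supported.
host_t = t._copy_to(DEST_PLACE, False)
print(host_t.place)
```
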
@@ -1746,7 +1751,7 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys): |
1746 | 1751 | action = tp_actions.pop(key) |
1747 | 1752 | tensor = action(ret) if is_dst else None |
1748 | 1753 | else: |
1749 | | - tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False) if is_dst else None |
| 1754 | + tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None |
1750 | 1755 |
1751 | 1756 | if is_dst: |
1752 | 1757 | state_dict_to_save[key] = tensor |
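
In this hunk, keys with a tensor-parallel action are merged from gathered shards, while unsharded keys are simply offloaded on the destination rank; the change only swaps the hardcoded pinned place for `DEST_PLACE`. A hedged sketch of that branch structure, using a hypothetical helper name (`resolve_for_dst` is not part of the PaddleNLP API):

```python
def resolve_for_dst(tensor, is_dst, dest_place, action=None, gathered_shards=None):
    # Hypothetical helper mirroring the branch above, not the real API.
    if not is_dst:
        return None  # non-destination ranks drop their reference
    if action is not None:
        # Tensor-parallel key: merge the shards gathered on this rank.
        return action(gathered_shards)
    # Unsharded key: offload to pinned/CPU memory ahead of saving.
    return tensor._copy_to(dest_place, False)
```
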
@@ -1777,15 +1782,13 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys) |
1777 | 1782 | if model_key in tp_actions: |
1778 | 1783 | # for example: beta1, beta2 |
1779 | 1784 | if tensor.numel().item() == 1: |
1780 | | - tensor = ( |
1781 | | - tensor._copy_to(paddle.CUDAPinnedPlace(), False) if is_dst else None |
1782 | | - ) # Need broadcast when loaded |
| 1785 | + tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None # Need broadcast when loaded |
1783 | 1786 | else: |
1784 | 1787 | ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False) |
1785 | 1788 | action = tp_actions[model_key] |
1786 | 1789 | tensor = action(ret) if is_dst else None |
1787 | 1790 | else: |
1788 | | - tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False) if is_dst else None |
| 1791 | + tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None |
1789 | 1792 |
1790 | 1793 | if is_dst: |
1791 | 1794 | state_dict_to_save[filter_keys[i]] = tensor |
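
The optimizer hunk has one extra branch: per its "Need broadcast when loaded" comment, scalar states such as the beta1/beta2 accumulators (`numel() == 1`) are replicated across tensor-parallel ranks, so they are copied to `DEST_PLACE` rather than gathered and merged. A small sketch of that path, assuming the `DEST_PLACE` fallback from the first hunk:

```python
import paddle

# Same place selection as the first hunk.
DEST_PLACE = paddle.CUDAPinnedPlace() if paddle.device.is_compiled_with_cuda() else paddle.CPUPlace()

beta1_pow = paddle.to_tensor([0.9])  # illustrative single-element state
is_dst = True  # stands in for the destination-rank check in the real code
if beta1_pow.numel().item() == 1:
    # Replicated scalar: copy on the dst rank; re-broadcast at load time.
    beta1_pow = beta1_pow._copy_to(DEST_PLACE, False) if is_dst else None
```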