diff --git a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
index 959f9eb49f40ff..b86c5f8df71747 100644
--- a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
+++ b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
@@ -20,6 +20,7 @@

 import paddle
 from paddle.framework import (
+    _current_expected_place_,
     base as imperative_base,
     core,
 )
@@ -33,6 +34,7 @@ class HOOK_ACTION:

 alignment = {
     "gpu": 256,
+    "npu": 256,
 }

 align = {
@@ -42,6 +44,28 @@ class HOOK_ACTION:
 }

+
+__current_device_type__ = None
+
+
+def get_current_device_type():
+    global __current_device_type__
+    if __current_device_type__ is None:
+        if paddle.is_compiled_with_cuda():
+            device_type = "gpu"
+        elif paddle.is_compiled_with_xpu():
+            device_type = "xpu"
+        elif paddle.is_compiled_with_custom_device():
+            current_device = _current_expected_place_()
+            device_type = current_device.get_device_type()
+        else:
+            device_type = "unknown"
+        assert (
+            device_type in alignment.keys()
+        ), f"tensor fusion helper now only support {alignment.keys()}, but got device {device_type} instead."
+        __current_device_type__ = device_type
+    return __current_device_type__
+

 def assign_group_by_size(parameters, group_size=128 * 1024 * 1024):
     is_sparse_gradient = [False] * len(parameters)

@@ -76,8 +100,12 @@ def flatten_dense_tensors(
     for param in parameters:
         assert param.trainable, "param must be trainable..."
         size = np.prod(param.shape) * align[dtype]
-        remaining = size % alignment["gpu"]
-        ali = 0 if remaining == 0 else alignment["gpu"] - remaining
+        remaining = size % alignment[get_current_device_type()]
+        ali = (
+            0
+            if remaining == 0
+            else alignment[get_current_device_type()] - remaining
+        )
         align_ = ali // align[dtype]
         _param2offset[param.name] = _buffer_size
         _buffer_size += np.prod(param.shape) + align_
@@ -88,7 +116,7 @@ def flatten_dense_tensors(

     if fuse_param:
         param_storage = ParamStorage(
-            size=_buffer_size, dtype=dtype, device="gpu"
+            size=_buffer_size, dtype=dtype, device=get_current_device_type()
         )
         param_storage.add_rank_params(parameters, _param2align)

@@ -97,7 +125,7 @@ def flatten_dense_tensors(
     grad_storage = GradStorage(
         size=_buffer_size,
         dtype=grad_dtype,
-        device="gpu",
+        device=get_current_device_type(),
         destination="0",
         parm2align=_param2align,
     )
@@ -261,7 +289,7 @@ def build_reduce_scatter_buffer(

     def get_padded_size(param):
         size = np.prod(param.shape)
-        align_size = alignment["gpu"] // align[dtype]
+        align_size = alignment[get_current_device_type()] // align[dtype]
         align_size = align_size * sharding_degree
         padded_size = ((size + align_size - 1) // align_size) * align_size
         return padded_size
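
The patch routes every hard-coded `"gpu"` alignment lookup through a cached `get_current_device_type()`, so the same buffer-padding logic applies on NPU and other custom devices. The sketch below is a standalone illustration, not part of the patch, of how that byte-alignment padding determines each parameter's offset in the fused buffer; the helper name `compute_padded_offsets` and the simplified `alignment`/`align` tables are assumptions made for the example.

```python
import numpy as np

# Simplified per-device byte alignment and per-dtype element sizes,
# mirroring the `alignment` / `align` tables in the module above.
alignment = {"gpu": 256, "npu": 256}
align = {"float16": 2, "float32": 4}


def compute_padded_offsets(shapes, dtype="float16", device_type="gpu"):
    """Return (offsets, total_size), both counted in elements of `dtype`.

    Each parameter is padded so that the next one starts on the
    device-specific byte boundary, as in flatten_dense_tensors.
    """
    offsets = []
    buffer_size = 0
    for shape in shapes:
        offsets.append(buffer_size)
        size_bytes = int(np.prod(shape)) * align[dtype]
        remaining = size_bytes % alignment[device_type]
        pad_bytes = 0 if remaining == 0 else alignment[device_type] - remaining
        # grow the buffer by the parameter itself plus its padding,
        # converted back from bytes to elements
        buffer_size += int(np.prod(shape)) + pad_bytes // align[dtype]
    return offsets, buffer_size


# A (3, 5) float16 tensor occupies 30 bytes, so it is padded out to 256 bytes
# (128 elements) before the next parameter begins.
print(compute_padded_offsets([(3, 5), (4, 4)]))  # ([0, 128], 256)
```

This is the same `remaining` / `ali` arithmetic the patch parameterizes by device type; only the table lookup key changes, the padding formula itself is untouched.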