@@ -386,21 +386,18 @@ def transform(t, device, dtype, blocking):
386386 device = t .place
387387 if dtype is None :
388388 dtype = t .dtype
389+ if type (dtype ) is str :
390+ dtype = framework .convert_np_dtype_to_dtype_ (dtype )
389391
390392 # 1. gpu place need to determine whether the memory is sufficient for allocation.
391393 if t .place .is_gpu_place ():
392- gpu_memory_available = core .gpu_memory_available ()
393- # for gpu, minimum memory allocation unit is 256 bytes.
394- if type (dtype ) is str :
395- size_dtype = core .size_of_dtype (
396- framework .convert_np_dtype_to_dtype_ (dtype ))
397- else :
398- size_dtype = core .size_of_dtype (dtype )
394+ size_dtype = core .size_of_dtype (dtype )
399395 # Note(weilong wu): Paddle GPU minimum memory allocation unit is 256 bytes,
400396 # waiting_alloc_memory will compute the memory space occupied by 't'.
401397 # Coefficient 1.2 is used to avoid OOM that may occur in this critical state when the memory is just enough.
402398 waiting_alloc_memory = (
403399 (t ._numel () * size_dtype ) / 256 + 1 ) * 256 * 1.2
400+ gpu_memory_available = core .gpu_memory_available ()
404401 if gpu_memory_available < waiting_alloc_memory :
405402 # Copy Tensor to cpu
406403 t_used = t ._copy_to (paddle .CPUPlace (), blocking )
@@ -414,12 +411,17 @@ def transform(t, device, dtype, blocking):
414411
415412 # 2. cast Tensor to dtype
416413 if dtype is not None and dtype != t_used .dtype :
417- t_casted = t_used .cast (dtype = dtype )
414+ with paddle .fluid .framework ._dygraph_place_guard (
415+ place = t_used .place ):
416+ t_casted = t_used .cast (dtype = dtype )
418417 else :
419418 t_casted = t_used
420419
421420 # 3. Copy casted Tensor(in CPU or GPU) to device
422- new_t = t_casted ._copy_to (device , blocking )
421+ if device is not None and not t_casted .place ._equals (device ):
422+ new_t = t_casted ._copy_to (device , blocking )
423+ else :
424+ new_t = t_casted
423425
424426 # 4. Share Tensor to origin Tensor
425427 dst_tensor = t .value ().get_tensor ()
0 commit comments