Skip to content

Commit c252b1d

Browse files
authored
Simplify size op impl (PaddlePaddle#45808)
* simplify size op * trans to cuda manuly * fix copy error
1 parent 7d00011 commit c252b1d

5 files changed

Lines changed: 23 additions & 78 deletions

File tree

paddle/phi/kernels/cpu/size_kernel.cc

Lines changed: 0 additions & 32 deletions
This file was deleted.

paddle/phi/kernels/gpu/size_kernel.cu

Lines changed: 0 additions & 31 deletions
This file was deleted.

paddle/phi/kernels/impl/size_kernel_impl.h renamed to paddle/phi/kernels/size_kernel.cc

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,28 +12,33 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
#pragma once
15+
#include "paddle/phi/kernels/size_kernel.h"
1616

17+
#include "paddle/phi/core/kernel_registry.h"
1718
#include "paddle/phi/core/tensor_utils.h"
1819

1920
namespace phi {
2021

21-
template <typename T, typename Context>
22+
template <typename Context>
2223
void SizeKernel(const Context& ctx,
2324
const DenseTensor& input,
2425
DenseTensor* out) {
25-
auto place = ctx.GetPlace();
26-
auto out_data = ctx.template Alloc<int64_t>(out);
27-
auto cpu_place = phi::CPUPlace();
28-
if (place == cpu_place) {
29-
out_data[0] = input.numel();
30-
} else {
31-
DenseTensor cpu_tensor;
32-
cpu_tensor.Resize(out->dims());
33-
auto cpu_data = ctx.template HostAlloc<int64_t>(&cpu_tensor);
34-
cpu_data[0] = input.numel();
35-
phi::Copy(ctx, cpu_tensor, place, false, out);
36-
}
26+
auto* out_data = ctx.template HostAlloc<int64_t>(out);
27+
out_data[0] = input.numel();
3728
}
3829

3930
} // namespace phi
31+
32+
PD_REGISTER_GENERAL_KERNEL(
33+
size, CPU, ALL_LAYOUT, phi::SizeKernel<phi::CPUContext>, ALL_DTYPE) {
34+
kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
35+
}
36+
37+
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
38+
PD_REGISTER_GENERAL_KERNEL(
39+
size, GPU, ALL_LAYOUT, phi::SizeKernel<phi::GPUContext>, ALL_DTYPE) {
40+
kernel->OutputAt(0)
41+
.SetBackend(phi::Backend::CPU)
42+
.SetDataType(phi::DataType::INT64);
43+
}
44+
#endif

paddle/phi/kernels/size_kernel.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
namespace phi {
2020

21-
template <typename T, typename Context>
21+
template <typename Context>
2222
void SizeKernel(const Context& ctx, const DenseTensor& input, DenseTensor* out);
2323

2424
} // namespace phi

python/paddle/distributed/collective.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,6 +1140,9 @@ def all_gather_object(object_list, obj, group=None):
11401140
), "all_gather_object doesn't support static graph mode."
11411141

11421142
tensor, len_of_tensor = _convert_object_to_tensor(obj)
1143+
if paddle.get_device() != "cpu":
1144+
len_of_tensor = len_of_tensor._copy_to(
1145+
paddle.framework._current_expected_place(), False)
11431146

11441147
# gather len_of_tensor from all ranks
11451148
list_len_of_tensor = []

0 commit comments

Comments
 (0)