|
23 | 23 |
|
24 | 24 | namespace phi { |
25 | 25 |
|
26 | | -template <typename T, typename Context> |
27 | | -void ExpandAsKernel(const Context& ctx, |
28 | | - const DenseTensor& x, |
29 | | - const paddle::optional<DenseTensor>& y, |
30 | | - const std::vector<int>& target_shape_t, |
31 | | - DenseTensor* out) { |
32 | | - std::vector<int> target_shape = target_shape_t; |
| 26 | +template <typename T, typename Context, typename ShapeType> |
| 27 | +void ExpandAsKernelImpl(const Context& ctx, |
| 28 | + const DenseTensor& x, |
| 29 | + const paddle::optional<DenseTensor>& y, |
| 30 | + const std::vector<ShapeType>& target_shape_t, |
| 31 | + DenseTensor* out) { |
| 32 | + std::vector<ShapeType> target_shape = target_shape_t; |
33 | 33 |
|
34 | 34 | if (y.get_ptr()) { |
35 | | - target_shape = phi::vectorize<int>(y.get_ptr()->dims()); |
| 35 | + target_shape = phi::vectorize<ShapeType>(y.get_ptr()->dims()); |
36 | 36 | } |
37 | 37 |
|
38 | 38 | int rank = x.dims().size(); |
39 | 39 | int target_rank = static_cast<int>(target_shape.size()); |
40 | | - auto vec_in_dims = common::vectorize<int>(x.dims()); |
| 40 | + auto vec_in_dims = common::vectorize<ShapeType>(x.dims()); |
41 | 41 |
|
42 | 42 | unsigned int diff = target_rank - rank; |
43 | 43 | vec_in_dims.insert(vec_in_dims.begin(), diff, 1); |
@@ -80,6 +80,34 @@ void ExpandAsKernel(const Context& ctx, |
80 | 80 | ExpandKernel<T, Context>(ctx, x, target_shape, out); |
81 | 81 | } |
82 | 82 |
|
| 83 | +static inline std::vector<int> convert_to_int_vec(std::vector<int64_t> a) { |
| 84 | + std::vector<int> ret; |
| 85 | + for (size_t i = 0; i < a.size(); i++) { |
| 86 | + ret.emplace_back(static_cast<int>(a[i])); |
| 87 | + } |
| 88 | + |
| 89 | + return ret; |
| 90 | +} |
| 91 | + |
| 92 | +template <typename T, typename Context> |
| 93 | +void ExpandAsKernel(const Context& ctx, |
| 94 | + const DenseTensor& x, |
| 95 | + const paddle::optional<DenseTensor>& y, |
| 96 | + const std::vector<int64_t>& target_shape_t, |
| 97 | + DenseTensor* out) { |
| 98 | + bool use_int64 = |
| 99 | + std::any_of(target_shape_t.begin(), target_shape_t.end(), [](int64_t v) { |
| 100 | + return v > static_cast<int64_t>(std::numeric_limits<int32_t>::max()); |
| 101 | + }); |
| 102 | + |
| 103 | + if (use_int64) { |
| 104 | + ExpandAsKernelImpl<T, Context, int64_t>(ctx, x, y, target_shape_t, out); |
| 105 | + } else { |
| 106 | + ExpandAsKernelImpl<T, Context, int32_t>( |
| 107 | + ctx, x, y, convert_to_int_vec(target_shape_t), out); |
| 108 | + } |
| 109 | +} |
| 110 | + |
83 | 111 | } // namespace phi |
84 | 112 |
|
85 | 113 | PD_REGISTER_KERNEL(expand_as, |
|
0 commit comments