Commit 18cd5f2

fix expand_as kernel for big tensor
1 parent 7415b76 commit 18cd5f2

8 files changed, +50 -31 lines changed
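
For context, the old int attribute cannot hold a dimension larger than INT32_MAX, which is the "big tensor" case the commit title refers to. A minimal standalone sketch (plain C++, not Paddle code; the 3,000,000,000 dimension is an illustrative assumption) of the truncation that the int -> int64_t widening avoids:

```cpp
#include <cstdint>
#include <iostream>
#include <limits>

int main() {
  // Hypothetical target dimension of a "big" tensor: larger than INT32_MAX.
  int64_t dim = 3000000000LL;
  // Storing it in the old 32-bit attribute type silently corrupts the value.
  int32_t as_int32 = static_cast<int32_t>(dim);
  std::cout << "int64_t dim : " << dim << "\n";
  std::cout << "as int32_t  : " << as_int32 << "\n";  // wrapped, no longer the real size
  std::cout << "fits int32? : "
            << (dim <= static_cast<int64_t>(std::numeric_limits<int32_t>::max()))
            << "\n";
}
```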

paddle/phi/infermeta/binary.cc

Lines changed: 1 addition & 1 deletion
@@ -1816,7 +1816,7 @@ void CEmbeddingInferMeta(const MetaTensor& weight,
 
 void ExpandAsInferMeta(const MetaTensor& x,
                        const MetaTensor& y,
-                       const std::vector<int>& target_shape,
+                       const std::vector<int64_t>& target_shape,
                        MetaTensor* out) {
 #define MAX_RANK_SUPPORTED 8
   auto x_dims = x.dims();

paddle/phi/infermeta/binary.h

Lines changed: 1 addition & 1 deletion
@@ -348,7 +348,7 @@ void CEmbeddingInferMeta(const MetaTensor& weight,
 
 void ExpandAsInferMeta(const MetaTensor& x,
                        const MetaTensor& y,
-                       const std::vector<int>& target_shape,
+                       const std::vector<int64_t>& target_shape,
                        MetaTensor* out);
 
 void FakeDequantizeMaxAbsInferMeta(const MetaTensor& x,

paddle/phi/kernels/expand_as_kernel.h

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ template <typename T, typename Context>
 void ExpandAsKernel(const Context& ctx,
                     const DenseTensor& x,
                     const paddle::optional<DenseTensor>& y,
-                    const std::vector<int>& target_shape,
+                    const std::vector<int64_t>& target_shape,
                     DenseTensor* out);
 
 }  // namespace phi

paddle/phi/kernels/gpu/expand_as_kernel.cu

Lines changed: 37 additions & 9 deletions
@@ -23,21 +23,21 @@
 
 namespace phi {
 
-template <typename T, typename Context>
-void ExpandAsKernel(const Context& ctx,
-                    const DenseTensor& x,
-                    const paddle::optional<DenseTensor>& y,
-                    const std::vector<int>& target_shape_t,
-                    DenseTensor* out) {
-  std::vector<int> target_shape = target_shape_t;
+template <typename T, typename Context, typename ShapeType>
+void ExpandAsKernelImpl(const Context& ctx,
+                        const DenseTensor& x,
+                        const paddle::optional<DenseTensor>& y,
+                        const std::vector<ShapeType>& target_shape_t,
+                        DenseTensor* out) {
+  std::vector<ShapeType> target_shape = target_shape_t;
 
   if (y.get_ptr()) {
-    target_shape = phi::vectorize<int>(y.get_ptr()->dims());
+    target_shape = phi::vectorize<ShapeType>(y.get_ptr()->dims());
   }
 
   int rank = x.dims().size();
   int target_rank = static_cast<int>(target_shape.size());
-  auto vec_in_dims = common::vectorize<int>(x.dims());
+  auto vec_in_dims = common::vectorize<ShapeType>(x.dims());
 
   unsigned int diff = target_rank - rank;
   vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
@@ -80,6 +80,34 @@ void ExpandAsKernel(const Context& ctx,
   ExpandKernel<T, Context>(ctx, x, target_shape, out);
 }
 
+static inline std::vector<int> convert_to_int_vec(std::vector<int64_t> a) {
+  std::vector<int> ret;
+  for (size_t i = 0; i < a.size(); i++) {
+    ret.emplace_back(static_cast<int>(a[i]));
+  }
+
+  return ret;
+}
+
+template <typename T, typename Context>
+void ExpandAsKernel(const Context& ctx,
+                    const DenseTensor& x,
+                    const paddle::optional<DenseTensor>& y,
+                    const std::vector<int64_t>& target_shape_t,
+                    DenseTensor* out) {
+  bool use_int64 =
+      std::any_of(target_shape_t.begin(), target_shape_t.end(), [](int64_t v) {
+        return v > static_cast<int64_t>(std::numeric_limits<int32_t>::max());
+      });
+
+  if (use_int64) {
+    ExpandAsKernelImpl<T, Context, int64_t>(ctx, x, y, target_shape_t, out);
+  } else {
+    ExpandAsKernelImpl<T, Context, int32_t>(
+        ctx, x, y, convert_to_int_vec(target_shape_t), out);
+  }
+}
+
 }  // namespace phi
 
 PD_REGISTER_KERNEL(expand_as,
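
The added dispatch keeps the existing 32-bit path and only switches to 64-bit shape arithmetic when some target dimension exceeds INT32_MAX. A self-contained sketch of the same selection logic (a stand-in ApplyExpandAs function replaces ExpandAsKernelImpl; the helper names here are assumptions, not the actual Paddle API):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

// Stand-in for ExpandAsKernelImpl<T, Context, ShapeType>: just reports
// which shape type was selected.
template <typename ShapeType>
void ApplyExpandAs(const std::vector<ShapeType>& shape) {
  std::cout << "expand_as dispatched with " << sizeof(ShapeType) * 8
            << "-bit dims, rank " << shape.size() << "\n";
}

// Mirrors convert_to_int_vec: narrow int64_t dims once we know they all fit.
static std::vector<int32_t> ToIntVec(const std::vector<int64_t>& a) {
  return std::vector<int32_t>(a.begin(), a.end());
}

void Dispatch(const std::vector<int64_t>& target_shape) {
  // Use the 64-bit path only if some dimension cannot be represented as int32_t.
  bool use_int64 = std::any_of(
      target_shape.begin(), target_shape.end(), [](int64_t v) {
        return v > static_cast<int64_t>(std::numeric_limits<int32_t>::max());
      });
  if (use_int64) {
    ApplyExpandAs<int64_t>(target_shape);
  } else {
    ApplyExpandAs<int32_t>(ToIntVec(target_shape));
  }
}

int main() {
  Dispatch({2, 3, 4});                    // small dims -> 32-bit path
  Dispatch({1, (int64_t{1} << 31) + 5});  // oversized dim -> 64-bit path
}
```

On the 32-bit path behaviour is unchanged from before the commit; the narrowing is safe only because the any_of check has already rejected oversized dims.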

paddle/phi/kernels/impl/expand_as_kernel_impl.h

Lines changed: 5 additions & 5 deletions
@@ -27,13 +27,13 @@ namespace phi {
 template <typename Context, typename T, int Rank>
 void ExpandAs(const Context& context,
               const DenseTensor& x,
-              const std::vector<int>& target_shape,
+              const std::vector<int64_t>& target_shape,
               DenseTensor* out) {
   auto in_dims = x.dims();
   auto vec_in_dims = common::vectorize<int>(in_dims);
   auto diff = target_shape.size() - vec_in_dims.size();
   vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
-  std::vector<int> repeat_times(vec_in_dims.size());
+  std::vector<int64_t> repeat_times(vec_in_dims.size());
   if (Rank == 0) {
     phi::Copy<Context>(context, x, context.GetPlace(), false, out);
     return;
@@ -98,7 +98,7 @@ template <typename T, typename Context>
 void ExpandAsKernel(const Context& ctx,
                     const DenseTensor& x,
                     const paddle::optional<DenseTensor>& y,
-                    const std::vector<int>& target_shape,
+                    const std::vector<int64_t>& target_shape,
                     DenseTensor* out) {
   auto rank = x.dims().size();
   auto target_rank = target_shape.size();
@@ -124,12 +124,12 @@ void ExpandAsKernel(const Context& ctx,
                         target_rank,
                         MAX_RANK_SUPPORTED));
 
-  std::vector<int> real_target_shape = target_shape;
+  std::vector<int64_t> real_target_shape = target_shape;
   for (size_t i = 0; i < target_shape.size(); ++i) {
     if (target_shape[i] == -1) {
       if (y) {
         if (y->IsInitialized()) {
-          real_target_shape = common::vectorize<int>(y->dims());
+          real_target_shape = common::vectorize<int64_t>(y->dims());
         }
       }
       break;
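
The -1 handling touched by the last hunk can be read in isolation: when any target_shape entry is -1 and an initialized y is present, the whole shape is taken from y's dims. A small sketch of that resolution (a plain pointer stands in for the optional y tensor; the helper name is hypothetical):

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the loop in ExpandAsKernel: resolve a
// target_shape that contains -1 by falling back to y's dims when available.
std::vector<int64_t> ResolveTargetShape(const std::vector<int64_t>& target_shape,
                                        const std::vector<int64_t>* y_dims) {
  std::vector<int64_t> real_target_shape = target_shape;
  for (int64_t d : target_shape) {
    if (d == -1) {
      if (y_dims != nullptr) {
        // Stands in for common::vectorize<int64_t>(y->dims()).
        real_target_shape = *y_dims;
      }
      break;  // only the first -1 matters; the shape is replaced wholesale
    }
  }
  return real_target_shape;
}
```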

paddle/phi/kernels/impl/solve_kernel_impl.h

Lines changed: 2 additions & 11 deletions
@@ -76,15 +76,6 @@ static std::vector<int64_t> get_broadcast_batch_portion(
   return batchPortion;
 }
 
-static inline std::vector<int> convert_to_int_vec(std::vector<int64_t> a) {
-  std::vector<int> ret;
-  for (size_t i = 0; i < a.size(); i++) {
-    ret.emplace_back(static_cast<int>(a[i]));
-  }
-
-  return ret;
-}
-
 // broadcast the batch dimensions of tensor x and tensor y.
 static inline std::tuple<std::vector<int64_t>, std::vector<int64_t>>
 get_broadcast_dims(const Tensor& x, const Tensor& y) {
@@ -150,11 +141,11 @@ static void linalg_solve(const Context& dev_ctx,
   Tensor tmp_x_bc;
 
   phi::ExpandAsKernel<T, Context>(
-      dev_ctx, tmp_x, nullptr, convert_to_int_vec(x_broadcast_dims), &tmp_x_bc);
+      dev_ctx, tmp_x, nullptr, x_broadcast_dims, &tmp_x_bc);
 
   Tensor tmp_y_bc;
   phi::ExpandAsKernel<T, Context>(
-      dev_ctx, tmp_y, nullptr, convert_to_int_vec(y_broadcast_dims), &tmp_y_bc);
+      dev_ctx, tmp_y, nullptr, y_broadcast_dims, &tmp_y_bc);
 
   auto x_dim = x.dims();
   auto y_dim = y.dims();

paddle/phi/ops/yaml/backward.yaml

Lines changed: 2 additions & 2 deletions
@@ -986,8 +986,8 @@
   composite : exp_grad(out, out_grad, x_grad)
 
 - backward_op : expand_as_grad
-  forward : expand_as (Tensor x, Tensor y, int[] target_shape = {}) -> Tensor(out)
-  args : (Tensor x, Tensor out_grad, int[] target_shape)
+  forward : expand_as (Tensor x, Tensor y, int64_t[] target_shape = {}) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad, int64_t[] target_shape)
   output : Tensor(x_grad)
   infer_meta :
     func : UnchangedInferMeta

paddle/phi/ops/yaml/ops.yaml

Lines changed: 1 addition & 1 deletion
@@ -1754,7 +1754,7 @@
   backward : expand_grad
 
 - op : expand_as
-  args : (Tensor x, Tensor y, int[] target_shape = {})
+  args : (Tensor x, Tensor y, int64_t[] target_shape = {})
   output : Tensor(out)
   infer_meta :
     func : ExpandAsInferMeta
