Skip to content

Commit 7ae5373

Browse files
ggggxmwanghuancoder
authored andcommitted
[PHI] Fix flatten and split kernel for big tensor (PaddlePaddle#72634)
1 parent 1bf96d4 commit 7ae5373

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

paddle/fluid/primitive/decomp_rule/decomp_rule/composite.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -909,7 +909,7 @@ Tensor flatten_decomp(const Tensor& x, int start_axis, int end_axis) {
909909
return reshape<T>(x, x_dim);
910910
}
911911

912-
int slice_numel = 1;
912+
int64_t slice_numel = 1;
913913
for (int i = start_axis; i <= end_axis; ++i) {
914914
slice_numel *= x_dim[i];
915915
}

paddle/phi/infermeta/unary.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1617,7 +1617,7 @@ void FlattenInferMeta(const MetaTensor& x,
16171617
}
16181618

16191619
int64_t outer = 1;
1620-
std::vector<int32_t> out_shape;
1620+
std::vector<int64_t> out_shape;
16211621
out_shape.reserve(in_dims_size - stop_axis + start_axis + 1);
16221622

16231623
for (int i = 0; i < start_axis; ++i) {
@@ -4380,7 +4380,7 @@ void SplitInferMeta(const MetaTensor& x,
43804380
axis_value == -1) { // NOLINT
43814381
out_dims = std::vector<phi::DDim>(
43824382
sections_data.size(),
4383-
common::make_ddim(std::vector<int>(x.dims().size(), -1)));
4383+
common::make_ddim(std::vector<int64_t>(x.dims().size(), -1)));
43844384
} else {
43854385
out_dims = std::vector<phi::DDim>(sections_data.size(), x.dims());
43864386
}
@@ -4403,7 +4403,7 @@ void SplitInferMeta(const MetaTensor& x,
44034403
const int unknow_dim_val = -1;
44044404
int unknow_dim_idx = -1;
44054405
int num_of_unknow = 0;
4406-
int sum_of_section = 0;
4406+
int64_t sum_of_section = 0;
44074407

44084408
for (int i = 0; i < static_cast<int>(sections_data.size()); ++i) {
44094409
sections_vec.push_back(sections_data[i]);
@@ -4412,7 +4412,7 @@ void SplitInferMeta(const MetaTensor& x,
44124412
num_of_unknow++;
44134413
unknow_dim_idx = i;
44144414
} else {
4415-
sum_of_section += static_cast<int>(sections_data[i]);
4415+
sum_of_section += static_cast<int64_t>(sections_data[i]);
44164416
}
44174417
}
44184418

0 commit comments

Comments
 (0)