Skip to content

Commit 24c85bf

Browse files
committed
fix concat when axis < 0; test=develop
1 parent e3a4de0 commit 24c85bf

File tree

5 files changed

+26
-15
lines changed

5 files changed

+26
-15
lines changed

lite/kernels/arm/concat_compute.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ void ConcatCompute::Run() {
     axis = axis_tensor_data[0];
   }
   if (axis < 0) {
-    axis += inputs[0]->dims().size();
+    axis += static_cast<int>(inputs[0]->dims().size());
   }

   lite_api::PrecisionType type = PRECISION(kUnk);

lite/kernels/x86/concat_compute.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,17 @@ class ConcatCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
       return;
     }

-    int64_t axis = static_cast<int64_t>(param.axis);
+    int axis = param.axis;
     auto* axis_tensor = param.axis_tensor;
     if (axis_tensor != nullptr) {
       auto* axis_tensor_data = axis_tensor->template data<int>();
-      axis = static_cast<int64_t>(axis_tensor_data[0]);
+      axis = axis_tensor_data[0];
     }
-
     const auto& x_dims = param.x[0]->dims();
+    if (axis < 0) {
+      axis += static_cast<int>(x_dims.size());
+    }
+
     auto* out = param.output;
     T* output_data = param.output->template mutable_data<T>();

lite/kernels/xpu/concat_compute.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ void ConcatCompute<InType>::Run() {

   auto ins = param.x;
   auto out = param.output;
-  int64_t axis = param.axis;
+  int64_t axis = param.axis < 0
+                     ? param.axis + static_cast<int>(ins[0]->dims().size())
+                     : param.axis;

   std::vector<const float*> x_list;
   std::vector<std::vector<int>> xdims_list;

lite/operators/concat_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ bool ConcatOpLite::InferShapeImpl() const {
     axis = axis_tensor_val[0];
   }
   if (axis < 0) {
-    axis += inputs[0]->dims().size();
+    axis += static_cast<int>(inputs[0]->dims().size());
   }

   auto out_dims = inputs[0]->dims();

lite/tests/kernels/concat_compute_test.cc

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -74,23 +74,24 @@ class ConcateComputeTester : public arena::TestCase {
       x_vct.push_back(scope->FindTensor(name));
     }

+    int axis = axis_ < 0 ? axis_ + static_cast<int>(x_dims_.size()) : axis_;
     auto* out = scope->NewTensor(out_);
-    DDim output_dims = infer_shape(x_vct, axis_);
+    DDim output_dims = infer_shape(x_vct, axis);
     out->Resize(output_dims);
     auto* output_data = out->mutable_data<float>();

     int num = x_vct.size();
     int rows = 1;
     auto dim_0 = x_vct[0]->dims();
-    for (int i = 0; i < axis_; ++i) {
+    for (int i = 0; i < axis; ++i) {
       rows *= dim_0[i];
     }
     int out_rows = rows, out_cols = 0;

     std::vector<int> input_cols(x_vct.size());
     for (int i = 0; i < num; ++i) {
       int input_i_numel = x_vct[i]->dims().size() == 0 ? 0 : 1;
-      for (int didx = 0; didx < x_vct[i]->dims().size(); ++didx) {
+      for (size_t didx = 0; didx < x_vct[i]->dims().size(); ++didx) {
         input_i_numel *= x_vct[i]->dims()[didx];
       }
       int t_cols = input_i_numel / rows;
@@ -142,12 +143,20 @@ class ConcateComputeTester : public arena::TestCase {
 TEST(Concat, precision) {
   Place place;
   float abs_error = 2e-5;
-#if defined(LITE_WITH_NPU)
+  std::vector<int> axes{-1, 1, 2};
+  std::vector<bool> use_axis_tensor{false, true};
+#if defined(LITE_WITH_XPU) && !defined(LITE_WITH_XTCL)
+  place = TARGET(kXPU);
+  use_axis_tensor = std::vector<bool>{false};
+#elif defined(LITE_WITH_NPU)
   place = TARGET(kNPU);
   abs_error = 1e-2;  // use fp16 in npu
+  axes = std::vector<int>{1, 2};
+  use_axis_tensor = std::vector<bool>{false};
 #elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
   place = TARGET(kHuaweiAscendNPU);
   abs_error = 1e-2;  // precision_mode default is force_fp16
+  axes = std::vector<int>{1, 2};
 #elif defined(LITE_WITH_ARM)
   place = TARGET(kARM);
 #elif defined(LITE_WITH_X86)
@@ -156,11 +165,8 @@ TEST(Concat, precision) {
   return;
 #endif

-  for (int axis : {1, 2}) {
-    for (bool is_use_axis_tensor : {false, true}) {
-#ifdef LITE_WITH_NPU
-      if (is_use_axis_tensor) continue;
-#endif
+  for (int axis : axes) {
+    for (bool is_use_axis_tensor : use_axis_tensor) {
       std::unique_ptr<arena::TestCase> tester(
           new ConcateComputeTester(place, "def", axis, is_use_axis_tensor));
       arena::Arena arena(std::move(tester), place, abs_error);

0 commit comments

Comments (0)