@@ -74,23 +74,24 @@ class ConcateComputeTester : public arena::TestCase {
7474 x_vct.push_back (scope->FindTensor (name));
7575 }
7676
77+ int axis = axis_ < 0 ? axis_ + static_cast <int >(x_dims_.size ()) : axis_;
7778 auto * out = scope->NewTensor (out_);
78- DDim output_dims = infer_shape (x_vct, axis_ );
79+ DDim output_dims = infer_shape (x_vct, axis );
7980 out->Resize (output_dims);
8081 auto * output_data = out->mutable_data <float >();
8182
8283 int num = x_vct.size ();
8384 int rows = 1 ;
8485 auto dim_0 = x_vct[0 ]->dims ();
85- for (int i = 0 ; i < axis_ ; ++i) {
86+ for (int i = 0 ; i < axis ; ++i) {
8687 rows *= dim_0[i];
8788 }
8889 int out_rows = rows, out_cols = 0 ;
8990
9091 std::vector<int > input_cols (x_vct.size ());
9192 for (int i = 0 ; i < num; ++i) {
9293 int input_i_numel = x_vct[i]->dims ().size () == 0 ? 0 : 1 ;
93- for (int didx = 0 ; didx < x_vct[i]->dims ().size (); ++didx) {
94+ for (size_t didx = 0 ; didx < x_vct[i]->dims ().size (); ++didx) {
9495 input_i_numel *= x_vct[i]->dims ()[didx];
9596 }
9697 int t_cols = input_i_numel / rows;
@@ -142,12 +143,20 @@ class ConcateComputeTester : public arena::TestCase {
142143TEST (Concat, precision) {
143144 Place place;
144145 float abs_error = 2e-5 ;
145- #if defined(LITE_WITH_NPU)
146+ std::vector<int > axes{-1 , 1 , 2 };
147+ std::vector<bool > use_axis_tensor{false , true };
148+ #if defined(LITE_WITH_XPU) && !defined(LITE_WITH_XTCL)
149+ place = TARGET (kXPU );
150+ use_axis_tensor = std::vector<bool >{false };
151+ #elif defined(LITE_WITH_NPU)
146152 place = TARGET (kNPU );
147153 abs_error = 1e-2 ; // use fp16 in npu
154+ axes = std::vector<int >{1 , 2 };
155+ use_axis_tensor = std::vector<bool >{false };
148156#elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
149157 place = TARGET (kHuaweiAscendNPU );
150158 abs_error = 1e-2 ; // precision_mode default is force_fp16
159+ axes = std::vector<int >{1 , 2 };
151160#elif defined(LITE_WITH_ARM)
152161 place = TARGET (kARM );
153162#elif defined(LITE_WITH_X86)
@@ -156,11 +165,8 @@ TEST(Concat, precision) {
156165 return ;
157166#endif
158167
159- for (int axis : {1 , 2 }) {
160- for (bool is_use_axis_tensor : {false , true }) {
161- #ifdef LITE_WITH_NPU
162- if (is_use_axis_tensor) continue ;
163- #endif
168+ for (int axis : axes) {
169+ for (bool is_use_axis_tensor : use_axis_tensor) {
164170 std::unique_ptr<arena::TestCase> tester (
165171 new ConcateComputeTester (place, " def" , axis, is_use_axis_tensor));
166172 arena::Arena arena (std::move (tester), place, abs_error);
0 commit comments