@@ -80,7 +80,6 @@ class ConcatGradNPUKernel : public framework::OpKernel<T> {
8080 axis = ComputeAxis (static_cast <int64_t >(axis),
8181 static_cast <int64_t >(ins[0 ]->dims ().size ()));
8282
83- std::vector<int > sizes;
8483 int offset = 0 ;
8584 auto stream =
8685 ctx.template device_context <paddle::platform::NPUDeviceContext>()
@@ -91,7 +90,6 @@ class ConcatGradNPUKernel : public framework::OpKernel<T> {
9190 if (out_var_names[j] != framework::kEmptyVarName &&
9291 outs[j]->numel () != 0UL ) {
9392 outs[j]->mutable_data <T>(ctx.GetPlace ());
94- sizes.push_back (outs[j]->dims ()[axis]);
9593 std::vector<int > offsets;
9694 std::vector<int > sizes;
9795 for (int dim = 0 ; dim < ins[j]->dims ().size (); ++dim) {
@@ -103,9 +101,8 @@ class ConcatGradNPUKernel : public framework::OpKernel<T> {
103101 sizes.push_back (ins[j]->dims ()[dim]);
104102 }
105103 }
106- auto runner =
107- NpuOpRunner (" SliceD" , {*out_grad}, {*outs[j]},
108- {{" offsets" , offset}, {" size" , ins[j]->dims ()[axis]}});
104+ auto runner = NpuOpRunner (" SliceD" , {*out_grad}, {*outs[j]},
105+ {{" offsets" , offsets}, {" size" , sizes}});
109106 runner.Run (stream);
110107 }
111108 if (ins[j]->numel () != 0UL ) {
0 commit comments