Skip to content

Commit f07de67

Browse files
committed
fix concat_grad
1 parent 4b36ee2 commit f07de67

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

paddle/fluid/operators/concat_op_npu.cc

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ class ConcatGradNPUKernel : public framework::OpKernel<T> {
8080
axis = ComputeAxis(static_cast<int64_t>(axis),
8181
static_cast<int64_t>(ins[0]->dims().size()));
8282

83-
std::vector<int> sizes;
8483
int offset = 0;
8584
auto stream =
8685
ctx.template device_context<paddle::platform::NPUDeviceContext>()
@@ -91,7 +90,6 @@ class ConcatGradNPUKernel : public framework::OpKernel<T> {
9190
if (out_var_names[j] != framework::kEmptyVarName &&
9291
outs[j]->numel() != 0UL) {
9392
outs[j]->mutable_data<T>(ctx.GetPlace());
94-
sizes.push_back(outs[j]->dims()[axis]);
9593
std::vector<int> offsets;
9694
std::vector<int> sizes;
9795
for (int dim = 0; dim < ins[j]->dims().size(); ++dim) {
@@ -103,9 +101,8 @@ class ConcatGradNPUKernel : public framework::OpKernel<T> {
103101
sizes.push_back(ins[j]->dims()[dim]);
104102
}
105103
}
106-
auto runner =
107-
NpuOpRunner("SliceD", {*out_grad}, {*outs[j]},
108-
{{"offsets", offset}, {"size", ins[j]->dims()[axis]}});
104+
auto runner = NpuOpRunner("SliceD", {*out_grad}, {*outs[j]},
105+
{{"offsets", offsets}, {"size", sizes}});
109106
runner.Run(stream);
110107
}
111108
if (ins[j]->numel() != 0UL) {

0 commit comments

Comments
 (0)