From 85626933ef8b4bdfae00eb782fba652efd13ddcd Mon Sep 17 00:00:00 2001 From: irestonelib Date: Mon, 23 Aug 2021 20:33:46 +0800 Subject: [PATCH 1/3] fix bmm bug --- paddle/fluid/operators/bmm_op.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/paddle/fluid/operators/bmm_op.h b/paddle/fluid/operators/bmm_op.h index 49104d4f08d288..b285fd5a840736 100644 --- a/paddle/fluid/operators/bmm_op.h +++ b/paddle/fluid/operators/bmm_op.h @@ -64,6 +64,10 @@ class BmmKernel : public framework::OpKernel { Tensor *out = context.Output("Out"); out->mutable_data(context.GetPlace()); + if(x.dims()[0]==0 || y.dims()[0]==0){ + return; + } + auto blas = math::GetBlas(context); auto mat_dim_a = math::CreateMatrixDescriptor(x.dims(), 0, false); From 56bd869534f87410d057a579e3664fa2490cab6f Mon Sep 17 00:00:00 2001 From: irestonelib Date: Tue, 24 Aug 2021 10:21:36 +0800 Subject: [PATCH 2/3] bmm style --- paddle/fluid/operators/bmm_op.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/bmm_op.h b/paddle/fluid/operators/bmm_op.h index b285fd5a840736..22edb6ad1d4666 100644 --- a/paddle/fluid/operators/bmm_op.h +++ b/paddle/fluid/operators/bmm_op.h @@ -64,7 +64,7 @@ class BmmKernel : public framework::OpKernel { Tensor *out = context.Output("Out"); out->mutable_data(context.GetPlace()); - if(x.dims()[0]==0 || y.dims()[0]==0){ + if (x.dims()[0] == 0 || y.dims()[0] == 0) { return; } From ce19e55505bf6212eee5793042928e30b67b388b Mon Sep 17 00:00:00 2001 From: irestonelib Date: Tue, 24 Aug 2021 13:08:52 +0800 Subject: [PATCH 3/3] fix bmm --- paddle/fluid/operators/bmm_op.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/bmm_op.h b/paddle/fluid/operators/bmm_op.h index 22edb6ad1d4666..15cd6de91365e0 100644 --- a/paddle/fluid/operators/bmm_op.h +++ b/paddle/fluid/operators/bmm_op.h @@ -64,7 +64,7 @@ class BmmKernel : public framework::OpKernel { Tensor *out = context.Output("Out"); out->mutable_data(context.GetPlace()); - if (x.dims()[0] == 0 || y.dims()[0] == 0) { + if (x.numel() == 0 || y.numel() == 0) { return; }