diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cu b/paddle/fluid/operators/elementwise/elementwise_add_op.cu
index 313607d975e60a..830eae1a51d4a2 100644
--- a/paddle/fluid/operators/elementwise/elementwise_add_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cu
@@ -112,22 +112,40 @@ elementwise_add_grad(const framework::ExecutionContext& ctx,
                      const framework::Tensor* out,
                      const framework::Tensor* dout,
                      framework::Tensor* dx, framework::Tensor* dy) {
-  auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
-  auto* dy_data = dy->mutable_data<T>(ctx.GetPlace());
-  auto* dout_data = dout->data<T>();
-  if (dx_data == dout_data && dy_data != dout_data) {
-    VLOG(4) << "Special case when dx_data is the same as dout_data, "
-               "only need copy dout to dy";
-    framework::TensorCopy(
-        *dout, ctx.GetPlace(),
-        ctx.template device_context<platform::DeviceContext>(), dy);
-  } else if (dx_data != dout_data && dy_data == dout_data) {
-    VLOG(4) << "Special case when dy_data is the same as dout_data, "
-               "only need copy dout to dx";
-    framework::TensorCopy(
-        *dout, ctx.GetPlace(),
-        ctx.template device_context<platform::DeviceContext>(), dx);
-  } else if (dx_data != dout_data && dy_data != dout_data) {
+  bool dims_same = (dx->dims() == dout->dims()) && (dy->dims() == dout->dims());
+  bool using_tensorcopy = false;
+  if (dims_same) {
+    auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
+    auto* dy_data = dy->mutable_data<T>(ctx.GetPlace());
+    auto* dout_data = dout->data<T>();
+    bool dx_data_same = (dx_data == dout_data);
+    bool dy_data_same = (dy_data == dout_data);
+    if (dx_data_same && !dy_data_same) {
+      using_tensorcopy = true;
+      VLOG(4) << "Special case when dx_data is the same as dout_data, "
+                 "only need copy dout to dy";
+      framework::TensorCopy(
+          *dout, ctx.GetPlace(),
+          ctx.template device_context<platform::DeviceContext>(), dy);
+    } else if (dy_data_same && !dx_data_same) {
+      using_tensorcopy = true;
+      VLOG(4) << "Special case when dy_data is the same as dout_data, "
+                 "only need copy dout to dx";
+      framework::TensorCopy(
+          *dout, ctx.GetPlace(),
+          ctx.template device_context<platform::DeviceContext>(), dx);
+    } else if (dy_data_same && dx_data_same) {
+      using_tensorcopy = true;
+      VLOG(4) << "Special case when dy_data is the same as dout_data, "
+                 "and dx_data is the same as dout_data, do not need "
+                 "any operator";
+    } else {
+      // need to copy dout into both dx and dy;
+      // using SimpleElemwiseAddGradCUDAKernel is faster for this case
+    }
+  }
+
+  if (!using_tensorcopy) {
     auto size = x->numel();
     int vec_size = max(static_cast<int>(sizeof(float4) / sizeof(T)), 1);
     dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
@@ -140,10 +158,6 @@ elementwise_add_grad(const framework::ExecutionContext& ctx,
              ctx.template device_context<platform::CUDADeviceContext>().stream()>>>(
         dout->data<T>(), size, vec_size, dx->mutable_data<T>(ctx.GetPlace()),
         dy->mutable_data<T>(ctx.GetPlace()));
-  } else {
-    VLOG(4) << "Special case when dy_data is the same as dout_data, "
-               "and dx_data is the same as dout_data, do not need "
-               "any operator";
   }
 }
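
For review context: the pointer comparisons in the hunk are only meaningful once dx, dy and dout are known to share a shape, which is exactly what the new dims_same guard enforces. The sketch below is not Paddle code; the names AddGradPlan and ChoosePlan are made up for illustration. It spells out the four aliasing cases the patch distinguishes and the action each one implies.

// Hedged sketch of the decision the patch makes; dims_match mirrors dims_same
// above, and dx/dy/dout are the raw gradient buffers.
enum class AddGradPlan { kCopyDoutToDy, kCopyDoutToDx, kNothingToDo, kLaunchKernel };

AddGradPlan ChoosePlan(const void* dx, const void* dy, const void* dout,
                       bool dims_match) {
  if (!dims_match) return AddGradPlan::kLaunchKernel;  // aliasing check not meaningful
  const bool dx_is_dout = (dx == dout);
  const bool dy_is_dout = (dy == dout);
  if (dx_is_dout && dy_is_dout) return AddGradPlan::kNothingToDo;  // both already hold dout
  if (dx_is_dout) return AddGradPlan::kCopyDoutToDy;               // dx is in-place, fill dy
  if (dy_is_dout) return AddGradPlan::kCopyDoutToDx;               // dy is in-place, fill dx
  return AddGradPlan::kLaunchKernel;  // two distinct buffers: one kernel writes both
}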
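The fallback path keeps the existing SimpleElemwiseAddGradCUDAKernel launch, whose body is not part of this hunk. Below is a minimal standalone sketch of that launch math under stated assumptions: THREADS stands in for PADDLE_CUDA_THREAD_SIZE (its value here is assumed, not taken from the patch), the kernel name CopyDoutToDxDy is hypothetical, and the kernel groups vec_size elements per thread without the aligned float4 loads the real kernel can use.

// Hedged sketch, not Paddle's kernel: grid-stride copy of dout into both dx
// and dy, sized the same way as the launch in the hunk.
#include <cuda_runtime.h>
#include <algorithm>

constexpr int THREADS = 512;  // assumed stand-in for PADDLE_CUDA_THREAD_SIZE

template <typename T>
__global__ void CopyDoutToDxDy(const T* __restrict__ dout, int size,
                               int vec_size, T* dx, T* dy) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  // Each thread handles vec_size consecutive elements per step.
  for (int i = tid * vec_size; i < size; i += stride * vec_size) {
    int end = (i + vec_size < size) ? i + vec_size : size;
    for (int j = i; j < end; ++j) {
      dx[j] = dout[j];
      dy[j] = dout[j];
    }
  }
}

template <typename T>
void LaunchCopy(const T* dout, int size, T* dx, T* dy, cudaStream_t stream) {
  // Same shape as the patch: pack up to sizeof(float4)/sizeof(T) elements per
  // thread, then size the grid for size/vec_size work items.
  int vec_size = std::max(static_cast<int>(sizeof(float4) / sizeof(T)), 1);
  dim3 block(THREADS, 1);
  dim3 grid(((size + vec_size - 1) / vec_size + THREADS - 1) / THREADS, 1);
  CopyDoutToDxDy<T><<<grid, block, 0, stream>>>(dout, size, vec_size, dx, dy);
}

Grouping sizeof(float4)/sizeof(T) elements per thread is what lets the same launch shape serve float (4 elements per thread) and double (2 per thread) without a separate configuration path.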