Skip to content

Commit b647c1f

Browse files
committed
modify code according to the review
1 parent 335bed0 commit b647c1f

File tree

1 file changed

+10
-71
lines changed

1 file changed

+10
-71
lines changed

paddle/fluid/operators/multi_dot_op.cc

Lines changed: 10 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,6 @@ limitations under the License. */
2222
#include "paddle/fluid/operators/strided_memcpy.h"
2323
#include "paddle/fluid/operators/utils.h"
2424

25-
#ifdef PADDLE_WITH_MKLDNN
26-
#include "paddle/fluid/platform/mkldnn_helper.h"
27-
#endif
28-
2925
namespace paddle {
3026
namespace operators {
3127
using Tensor = framework::Tensor;
@@ -41,11 +37,11 @@ inline framework::DDim ComputeAndCheckShape(
4137
bool is_vector = false;
4238
framework::DDim out_dim;
4339

44-
if (first_dim.size() > 2) {
45-
PADDLE_THROW(platform::errors::InvalidArgument(
46-
"multi_dot: the first input tensor must be 1D or 2D but got[%d]!",
47-
static_cast<int>(first_dim.size())));
48-
}
40+
PADDLE_ENFORCE_LT(
41+
first_dim.size(), static_cast<size_t>(3),
42+
platform::errors::InvalidArgument(
43+
"multi_dot: the first input tensor must be 1D or 2D but got[%d]!",
44+
static_cast<int>(first_dim.size())));
4945

5046
// If the first tensor is 1D of size n view it as a row vector (1, n)
5147
if (first_dim.size() == 1) {
@@ -54,11 +50,11 @@ inline framework::DDim ComputeAndCheckShape(
5450
}
5551

5652
auto last_dim = inputs_dims[n - 1];
57-
if (last_dim.size() > 2) {
58-
PADDLE_THROW(platform::errors::InvalidArgument(
59-
"the last input tensor of multi_dot op must be 1D or 2D but got[%d]!",
60-
static_cast<int>(last_dim.size())));
61-
}
53+
PADDLE_ENFORCE_LT(
54+
last_dim.size(), static_cast<size_t>(3),
55+
platform::errors::InvalidArgument(
56+
"the last input tensor of multi_dot must be 1D or 2D but got[%d]!",
57+
static_cast<int>(last_dim.size())));
6258

6359
// If the last tensor is 1D of size n view it as a column vector (n, 1)
6460
if (last_dim.size() == 1) {
@@ -226,10 +222,6 @@ class MultiDotOpMaker : public framework::OpProtoAndCheckerMaker {
226222
void Make() override {
227223
AddInput("X", "The input tensors of multi_dot operator.").AsDuplicable();
228224
AddOutput("Out", "The output tensor of multi_dot operator");
229-
AddAttr<bool>(
230-
"use_mkldnn",
231-
"(bool, default false) Indicates if MKL-DNN kernel will be used")
232-
.SetDefault(false);
233225
AddComment(R"DOC(
234226
Compute the dot product of two or more arrays in a single function call, while automatically selecting the fastest evaluation order.
235227
@@ -259,44 +251,6 @@ class MultiDotOp : public framework::OperatorWithKernel {
259251
ctx->SetOutputDim("Out", out_dims);
260252
ctx->ShareLoD("X", "Out");
261253
}
262-
263-
protected:
264-
framework::OpKernelType GetExpectedKernelType(
265-
const framework::ExecutionContext& ctx) const override {
266-
auto inputs = ctx.MultiInput<Tensor>("X");
267-
auto input_data_type = framework::proto::VarType::Type(0);
268-
for (auto* input : inputs) {
269-
if (!input->IsInitialized()) {
270-
PADDLE_THROW(platform::errors::InvalidArgument(
271-
"The inputs of multi_dot OP are Empty!"));
272-
break;
273-
}
274-
}
275-
input_data_type = inputs[0]->type();
276-
277-
#ifdef PADDLE_WITH_MKLDNN
278-
using mkldnn::memory;
279-
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
280-
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
281-
framework::DataLayout::kMKLDNN,
282-
framework::LibraryType::kMKLDNN);
283-
}
284-
#endif
285-
return framework::OpKernelType(input_data_type, ctx.GetPlace());
286-
}
287-
288-
framework::OpKernelType GetKernelTypeForVar(
289-
const std::string& var_name, const framework::Tensor& tensor,
290-
const framework::OpKernelType& expected_kernel_type) const {
291-
if (framework::IsComplexType(expected_kernel_type.data_type_)) {
292-
// only promote inputs’s types when contains complex input
293-
return framework::OpKernelType(tensor.type(), tensor.place(),
294-
tensor.layout());
295-
} else {
296-
return framework::OpKernelType(expected_kernel_type.data_type_,
297-
tensor.place(), tensor.layout());
298-
}
299-
}
300254
};
301255

302256
/**
@@ -379,21 +333,6 @@ class MultiDotOpGrad : public framework::OperatorWithKernel {
379333
ctx->SetOutputsDim(out_x_g_n, ins_dims);
380334
ctx->ShareAllLoD(in_x, out_x_g_n);
381335
}
382-
383-
protected:
384-
framework::OpKernelType GetExpectedKernelType(
385-
const framework::ExecutionContext& ctx) const override {
386-
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
387-
ctx, framework::GradVarName("Out")),
388-
ctx.GetPlace());
389-
}
390-
391-
framework::OpKernelType GetKernelTypeForVar(
392-
const std::string& var_name, const Tensor& tensor,
393-
const framework::OpKernelType& expected_kernel_type) const override {
394-
return framework::OpKernelType(expected_kernel_type.data_type_,
395-
tensor.place(), tensor.layout());
396-
}
397336
};
398337

399338
template <typename DeviceContext, typename T>

0 commit comments

Comments
 (0)