@@ -22,10 +22,6 @@ limitations under the License. */
2222#include " paddle/fluid/operators/strided_memcpy.h"
2323#include " paddle/fluid/operators/utils.h"
2424
25- #ifdef PADDLE_WITH_MKLDNN
26- #include " paddle/fluid/platform/mkldnn_helper.h"
27- #endif
28-
2925namespace paddle {
3026namespace operators {
3127using Tensor = framework::Tensor;
@@ -41,11 +37,11 @@ inline framework::DDim ComputeAndCheckShape(
4137 bool is_vector = false ;
4238 framework::DDim out_dim;
4339
44- if (first_dim. size () > 2 ) {
45- PADDLE_THROW ( platform::errors::InvalidArgument (
46- " multi_dot: the first input tensor must be 1D or 2D but got[%d]! " ,
47- static_cast < int >(first_dim. size ())));
48- }
40+ PADDLE_ENFORCE_LT (
41+ first_dim. size (), static_cast < size_t >( 3 ),
42+ platform::errors::InvalidArgument (
43+ " multi_dot: the first input tensor must be 1D or 2D but got[%d]! " ,
44+ static_cast < int >(first_dim. size ())));
4945
5046 // If the first tensor is 1D of size n view it as a row vector (1, n)
5147 if (first_dim.size () == 1 ) {
@@ -54,11 +50,11 @@ inline framework::DDim ComputeAndCheckShape(
5450 }
5551
5652 auto last_dim = inputs_dims[n - 1 ];
57- if (last_dim. size () > 2 ) {
58- PADDLE_THROW ( platform::errors::InvalidArgument (
59- " the last input tensor of multi_dot op must be 1D or 2D but got[%d]! " ,
60- static_cast < int >(last_dim. size ())));
61- }
53+ PADDLE_ENFORCE_LT (
54+ last_dim. size (), static_cast < size_t >( 3 ),
55+ platform::errors::InvalidArgument (
56+ " the last input tensor of multi_dot must be 1D or 2D but got[%d]! " ,
57+ static_cast < int >(first_dim. size ())));
6258
6359 // If the last tensor is 1D of size n view it as a column vector (n, 1)
6460 if (last_dim.size () == 1 ) {
@@ -226,10 +222,6 @@ class MultiDotOpMaker : public framework::OpProtoAndCheckerMaker {
226222 void Make () override {
227223 AddInput (" X" , " The input tensors of multi_dot operator." ).AsDuplicable ();
228224 AddOutput (" Out" , " The output tensor of multi_dot operator" );
229- AddAttr<bool >(
230- " use_mkldnn" ,
231- " (bool, default false) Indicates if MKL-DNN kernel will be used" )
232- .SetDefault (false );
233225 AddComment (R"DOC(
234226Compute the dot product of two or more arrays in a single function call, while automatically selecting the fastest evaluation order.
235227
@@ -259,44 +251,6 @@ class MultiDotOp : public framework::OperatorWithKernel {
259251 ctx->SetOutputDim (" Out" , out_dims);
260252 ctx->ShareLoD (" X" , " Out" );
261253 }
262-
263- protected:
264- framework::OpKernelType GetExpectedKernelType (
265- const framework::ExecutionContext& ctx) const override {
266- auto inputs = ctx.MultiInput <Tensor>(" X" );
267- auto input_data_type = framework::proto::VarType::Type (0 );
268- for (auto * input : inputs) {
269- if (!input->IsInitialized ()) {
270- PADDLE_THROW (platform::errors::InvalidArgument (
271- " The inputs of multi_dot OP are Empty!" ));
272- break ;
273- }
274- }
275- input_data_type = inputs[0 ]->type ();
276-
277- #ifdef PADDLE_WITH_MKLDNN
278- using mkldnn::memory;
279- if (this ->CanMKLDNNBeUsed (ctx, input_data_type)) {
280- return framework::OpKernelType (input_data_type, ctx.GetPlace (),
281- framework::DataLayout::kMKLDNN ,
282- framework::LibraryType::kMKLDNN );
283- }
284- #endif
285- return framework::OpKernelType (input_data_type, ctx.GetPlace ());
286- }
287-
288- framework::OpKernelType GetKernelTypeForVar (
289- const std::string& var_name, const framework::Tensor& tensor,
290- const framework::OpKernelType& expected_kernel_type) const {
291- if (framework::IsComplexType (expected_kernel_type.data_type_ )) {
292- // only promote inputs’s types when contains complex input
293- return framework::OpKernelType (tensor.type (), tensor.place (),
294- tensor.layout ());
295- } else {
296- return framework::OpKernelType (expected_kernel_type.data_type_ ,
297- tensor.place (), tensor.layout ());
298- }
299- }
300254};
301255
302256/* *
@@ -379,21 +333,6 @@ class MultiDotOpGrad : public framework::OperatorWithKernel {
379333 ctx->SetOutputsDim (out_x_g_n, ins_dims);
380334 ctx->ShareAllLoD (in_x, out_x_g_n);
381335 }
382-
383- protected:
384- framework::OpKernelType GetExpectedKernelType (
385- const framework::ExecutionContext& ctx) const override {
386- return framework::OpKernelType (OperatorWithKernel::IndicateVarDataType (
387- ctx, framework::GradVarName (" Out" )),
388- ctx.GetPlace ());
389- }
390-
391- framework::OpKernelType GetKernelTypeForVar (
392- const std::string& var_name, const Tensor& tensor,
393- const framework::OpKernelType& expected_kernel_type) const override {
394- return framework::OpKernelType (expected_kernel_type.data_type_ ,
395- tensor.place (), tensor.layout ());
396- }
397336};
398337
399338template <typename DeviceContext, typename T>
0 commit comments