@@ -21,7 +21,6 @@ enum class ReshapeKernelOpName {
2121 reshape,
2222 reshape2,
2323 squeeze,
24- squeeze2,
2524 flatten,
2625 flatten2,
2726};
@@ -106,9 +105,6 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
106105 case ReshapeKernelOpName::squeeze:
107106 InferShapeSqueezeOp (ctx, x_dims, out_dims);
108107 break ;
109- case ReshapeKernelOpName::squeeze2:
110- InferShapeSqueeze2Op (ctx, x_dims, out_dims);
111- break ;
112108 case ReshapeKernelOpName::flatten:
113109 InferShapeFlattenOp (ctx, x_dims, out_dims);
114110 break ;
@@ -172,16 +168,6 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
172168 out_dims = GetOutputShape (axes, x_dims, true );
173169 }
174170
175- void InferShapeSqueeze2Op (const framework::ExecutionContext& ctx,
176- framework::DDim& x_dims, // NOLINT
177- framework::DDim& out_dims) const { // NOLINT
178- auto * out = ctx.Output <phi::DenseTensor>(" Out" );
179- auto * xshape = ctx.Output <phi::DenseTensor>(" XShape" );
180- auto xshape_dims = xshape->dims ();
181- x_dims = phi::slice_ddim (xshape_dims, 1 , xshape_dims.size ());
182- out_dims = out->dims ();
183- }
184-
185171 void InferShapeFlattenOp (const framework::ExecutionContext& ctx,
186172 framework::DDim& x_dims, // NOLINT
187173 framework::DDim& out_dims) const { // NOLINT
@@ -342,19 +328,16 @@ class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel<T, op_name> {
342328 InferShapeReshapeSqueezeGradOp (ctx, x_dims);
343329 break ;
344330 case ReshapeKernelOpName::reshape2:
345- InferShapeReshape2Squeeze2Flatten2GradOp (ctx, x_dims);
331+ InferShapeReshape2Flatten2GradOp (ctx, x_dims);
346332 break ;
347333 case ReshapeKernelOpName::squeeze:
348334 InferShapeReshapeSqueezeGradOp (ctx, x_dims);
349335 break ;
350- case ReshapeKernelOpName::squeeze2:
351- InferShapeReshape2Squeeze2Flatten2GradOp (ctx, x_dims);
352- break ;
353336 case ReshapeKernelOpName::flatten:
354337 InferShapeFlattenGradOp (ctx, x_dims);
355338 break ;
356339 case ReshapeKernelOpName::flatten2:
357- InferShapeReshape2Squeeze2Flatten2GradOp (ctx, x_dims);
340+ InferShapeReshape2Flatten2GradOp (ctx, x_dims);
358341 break ;
359342 default :
360343 PADDLE_THROW (paddle::platform::errors::OutOfRange (
@@ -369,7 +352,7 @@ class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel<T, op_name> {
369352 dx_dims = dx->dims ();
370353 }
371354
372- void InferShapeReshape2Squeeze2Flatten2GradOp (
355+ void InferShapeReshape2Flatten2GradOp (
373356 const framework::ExecutionContext& ctx,
374357 framework::DDim& dx_dims) const { // NOLINT
375358 auto xshape_dims = ctx.Input <phi::DenseTensor>(" XShape" )->dims ();
@@ -401,22 +384,6 @@ REGISTER_OP_KERNEL(
401384 ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16,
402385 ReshapeKernelOpName::squeeze>);
403386
404- REGISTER_OP_KERNEL (
405- squeeze2,
406- MKLDNN,
407- paddle::platform::CPUPlace,
408- ops::ReshapeMKLDNNKernel<float , ReshapeKernelOpName::squeeze2>,
409- ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16,
410- ReshapeKernelOpName::squeeze2>);
411-
412- REGISTER_OP_KERNEL (
413- squeeze2_grad,
414- MKLDNN,
415- paddle::platform::CPUPlace,
416- ops::ReshapeGradMKLDNNKernel<float , ReshapeKernelOpName::squeeze2>,
417- ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16,
418- ReshapeKernelOpName::squeeze2>);
419-
420387REGISTER_OP_KERNEL (
421388 reshape,
422389 MKLDNN,
0 commit comments