@@ -15,7 +15,12 @@ limitations under the License. */
 #include <string>
 
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/pten_utils.h"
 
+// Only headers from the paddle/pten/api dirs may be included here.
+#include "paddle/pten/api/lib/utils/tensor_utils.h"
+#include "paddle/pten/include/core.h"
+#include "paddle/pten/include/manipulation.h"
 namespace paddle {
 namespace framework {
 class InferShapeContext;
@@ -248,13 +253,6 @@ class ReshapeOp : public framework::OperatorWithKernel {
     auto input_data_type =
         framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
 
-    // #ifdef PADDLE_WITH_MKLDNN
-    //   if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
-    //     return framework::OpKernelType(input_data_type, ctx.GetPlace(),
-    //                                    framework::DataLayout::kMKLDNN,
-    //                                    framework::LibraryType::kMKLDNN);
-    //   }
-    // #endif
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 
@@ -366,13 +364,6 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
     auto input_data_type =
         framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
 
-    // #ifdef PADDLE_WITH_MKLDNN
-    //   if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
-    //     return framework::OpKernelType(input_data_type, ctx.GetPlace(),
-    //                                    framework::DataLayout::kMKLDNN,
-    //                                    framework::LibraryType::kMKLDNN);
-    //   }
-    // #endif
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 };
@@ -382,42 +373,117 @@ class ReshapeKernel {
   void operator()(const framework::ExecutionContext &ctx) const {
     auto *out = ctx.Output<framework::LoDTensor>("Out");
     auto *in = ctx.Input<framework::LoDTensor>("X");
-
-    framework::DDim out_dims = out->dims();
+    // framework::DDim out_dims = out->dims();
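+    // Wrap the input LoDTensor as a pten DenseTensor; the wrapper shares the
+    // input's storage rather than copying it.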
+    auto pt_x = paddle::experimental::MakePtenDenseTensor(*in);
+
+    // We can't build the output with MakePtenDenseTensor, because reshape may
+    // reallocate memory, and reallocating shared memory throws an error
+    // ("can't realloc shared memory") under the current DenseTensor design.
+    // So the code below creates a temporary DenseTensor for the output.
+    // TODO(YuanRisheng): use MakePtenDenseTensor once #36916 is merged.
+    const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
+        paddle::platform::CPUPlace());
+    pten::DenseTensorMeta meta{pten::TransToPtenDataType(in->type()),
+                               in->dims(),
+                               pten::TransToPtenDataLayout(in->layout())};
+    auto pt_out_tmp =
+        std::make_shared<pten::DenseTensor>(alloc, std::move(meta));
+    pten::DenseTensor *pt_out = nullptr;
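+    // In-place reshape writes through the input's pten tensor; otherwise the
+    // temporary output tensor created above is used.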
+    if (in == out) {
+      pt_out = pt_x.get();
+    } else {
+      pt_out = pt_out_tmp.get();
+    }
 
     auto list_new_shape_tensor =
         ctx.MultiInput<framework::Tensor>("ShapeTensor");
+    auto *shape_tensor = ctx.HasInput("Shape")
+                             ? ctx.Input<framework::LoDTensor>("Shape")
+                             : nullptr;
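+    // Shape sources are consulted in priority order: the ShapeTensor list,
+    // then the Shape input, then the "shape" attribute.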
     if (list_new_shape_tensor.size() > 0) {
       // have shape tensor
-      auto new_shape = get_new_shape(list_new_shape_tensor);
-      out_dims = ReshapeOp::ValidateShape(new_shape, in->dims());
+      std::vector<pten::DenseTensor> pt_vec_shape;
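+      // Shape values must be readable on the host, so GPU/XPU tensors are
+      // copied to CPU before conversion.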
+      for (auto &tensor : list_new_shape_tensor) {
+        if (platform::is_gpu_place(tensor->place()) ||
+            platform::is_xpu_place(tensor->place())) {
+          framework::Tensor temp;
+          TensorCopySync(*tensor, platform::CPUPlace(), &temp);
+          pt_vec_shape.push_back(
+              std::move(*(paddle::experimental::MakePtenDenseTensor(temp))));
+        } else {
+          pt_vec_shape.push_back(
+              std::move(*(paddle::experimental::MakePtenDenseTensor(*tensor))));
+        }
+      }
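+      // Run the pten reshape kernel on the device context that matches the
+      // current place.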
+      if (platform::is_cpu_place(ctx.GetPlace())) {
+        auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
+        pten::ReshapeFromVectorDT(dev_ctx, *pt_x.get(), pt_vec_shape, pt_out);
+      }
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+      if (platform::is_gpu_place(ctx.GetPlace())) {
+        auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
+        pten::ReshapeFromVectorDT(dev_ctx, *pt_x.get(), pt_vec_shape, pt_out);
+      }
+#endif
+#ifdef PADDLE_WITH_XPU
+      if (platform::is_xpu_place(ctx.GetPlace())) {
+        auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>();
+        pten::ReshapeFromVectorDT(dev_ctx, *pt_x.get(), pt_vec_shape, pt_out);
+      }
+#endif
+    } else if (shape_tensor) {
+      std::unique_ptr<pten::DenseTensor> pt_shape;
+      if (platform::is_gpu_place(shape_tensor->place()) ||
+          platform::is_xpu_place(shape_tensor->place())) {
+        framework::Tensor temp;
+        TensorCopySync(*shape_tensor, platform::CPUPlace(), &temp);
+        pt_shape = paddle::experimental::MakePtenDenseTensor(temp);
+      } else {
+        pt_shape = paddle::experimental::MakePtenDenseTensor(*shape_tensor);
+      }
 
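+      // Same per-place dispatch, driven by the single Shape tensor.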
+      if (platform::is_cpu_place(ctx.GetPlace())) {
+        auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
+        pten::ReshapeFromDT(dev_ctx, *pt_x.get(), *pt_shape.get(), pt_out);
+      }
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+      if (platform::is_gpu_place(ctx.GetPlace())) {
+        auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
+        pten::ReshapeFromDT(dev_ctx, *pt_x.get(), *pt_shape.get(), pt_out);
+      }
+#endif
+#ifdef PADDLE_WITH_XPU
+      if (platform::is_xpu_place(ctx.GetPlace())) {
+        auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>();
+        pten::ReshapeFromDT(dev_ctx, *pt_x.get(), *pt_shape.get(), pt_out);
+      }
+#endif
     } else {
-      auto *shape_tensor = ctx.HasInput("Shape")
-                               ? ctx.Input<framework::LoDTensor>("Shape")
-                               : nullptr;
-
-      if (shape_tensor) {
-        auto *shape_data = shape_tensor->data<int>();
-        framework::Tensor cpu_shape_tensor;
-        if (platform::is_gpu_place(shape_tensor->place()) ||
-            platform::is_xpu_place(shape_tensor->place())) {
-          TensorCopySync(*shape_tensor, platform::CPUPlace(),
-                         &cpu_shape_tensor);
-          shape_data = cpu_shape_tensor.data<int>();
-        }
-        auto shape =
-            std::vector<int>(shape_data, shape_data + shape_tensor->numel());
-        out_dims = ReshapeOp::ValidateShape(shape, in->dims());
+      auto &shape_vec = ctx.Attr<std::vector<int>>("shape");
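+      // Dispatch with the shape attribute values directly.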
+      if (platform::is_cpu_place(ctx.GetPlace())) {
+        auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
+        pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out);
+      }
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+      if (platform::is_gpu_place(ctx.GetPlace())) {
+        auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
+        pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out);
       }
+#endif
+#ifdef PADDLE_WITH_XPU
+      if (platform::is_xpu_place(ctx.GetPlace())) {
+        auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>();
+        pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out);
+      }
+#endif
+    }
+    // The non-inplace path must move the result from pt_out into out; the
+    // inplace path only needs to set the output dims.
+    if (in != out) {
+      paddle::experimental::MovesStorage(pt_out, static_cast<Tensor *>(out));
+    } else {
+      out->Resize(pt_out->dims());
     }
-
-    out->Resize(out_dims);
-    out->mutable_data(ctx.GetPlace(), in->type());
-    framework::TensorCopy(
-        *in, ctx.GetPlace(),
-        ctx.template device_context<platform::DeviceContext>(), out);
-    out->Resize(out_dims);
   }
 };
 
@@ -479,6 +545,21 @@ class Reshape2Op : public ReshapeOp {
 
     ReshapeOp::InferShape(ctx);
   }
+
+  framework::KernelSignature GetExpectedPtenKernelArgs(
+      const framework::ExecutionContext &ctx) const override {
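+    // Pick the pten kernel signature that matches the available shape source.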
+    auto multi_inputs = ctx.MultiInput<framework::Tensor>("ShapeTensor");
+    if (multi_inputs.size() > 0) {
+      return framework::KernelSignature(
+          "reshape2.mulhost.mid", {"X", "ShapeTensor"}, {}, {"XShape", "Out"});
+    } else if (ctx.HasInput("Shape")) {
+      return framework::KernelSignature("reshape2.host.mid", {"X", "Shape"}, {},
+                                        {"XShape", "Out"});
+    } else {
+      return framework::KernelSignature("reshape2.mid", {"X"}, {"shape"},
+                                        {"XShape", "Out"});
+    }
+  }
 };
 
 class Reshape2OpMaker : public ReshapeOpMaker {
@@ -557,13 +638,6 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
     auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
         ctx, framework::GradVarName("Out"));
 
-    // #ifdef PADDLE_WITH_MKLDNN
-    //   if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
-    //     return framework::OpKernelType(input_data_type, ctx.GetPlace(),
-    //                                    framework::DataLayout::kMKLDNN,
-    //                                    framework::LibraryType::kMKLDNN);
-    //   }
-    // #endif
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 