@@ -24,148 +24,6 @@ limitations under the License. */
2424namespace paddle {
2525namespace framework {
2626
27- // TODO(chenweihang, shixiaowei): adapt SelectedRows
28- template <>
29- std::shared_ptr<pten::DenseTensor> MakeTensorImpl<pten::DenseTensor, LoDTensor>(
30- const LoDTensor& tensor, pten::Backend backend,
31- paddle::experimental::DataType dtype,
32- paddle::experimental::DataLayout layout) {
33- auto holder = tensor.Holder ();
34- auto tensor_impl = std::make_shared<pten::DenseTensor>(
35- pten::TensorMeta (tensor.dims (), backend, dtype, layout, tensor.offset ()),
36- pten::TensorStatus ());
37-
38- if (holder != nullptr ) {
39- tensor_impl->ShareAllocation (tensor.Holder ());
40- }
41- return tensor_impl;
42- }
43-
44- template <>
45- std::shared_ptr<pten::DenseTensor> MakeTensorImpl<pten::DenseTensor, Tensor>(
46- const Tensor& tensor, pten::Backend backend,
47- paddle::experimental::DataType dtype,
48- paddle::experimental::DataLayout layout) {
49- auto holder = tensor.Holder ();
50- auto tensor_impl = std::make_shared<pten::DenseTensor>(
51- pten::TensorMeta (tensor.dims (), backend, dtype, layout, tensor.offset ()),
52- pten::TensorStatus ());
53-
54- if (holder != nullptr ) {
55- tensor_impl->ShareAllocation (tensor.Holder ());
56- }
57- return tensor_impl;
58- }
59-
60- template <>
61- std::shared_ptr<pten::DenseTensor> MakeTensorImpl<pten::DenseTensor>(
62- const LoDTensor& tensor, const platform::Place& place,
63- proto::VarType::Type type) {
64- return MakeTensorImpl<pten::DenseTensor, LoDTensor>(
65- tensor, pten::TransToPtenBackend (place), pten::TransToPtenDataType (type),
66- pten::TransToPtenDataLayout (tensor.layout ()));
67- }
68-
69- template <>
70- std::shared_ptr<pten::DenseTensor> MakeTensorImpl<pten::DenseTensor>(
71- const Tensor& tensor, const platform::Place& place,
72- proto::VarType::Type type) {
73- return MakeTensorImpl<pten::DenseTensor, Tensor>(
74- tensor, pten::TransToPtenBackend (place), pten::TransToPtenDataType (type),
75- pten::TransToPtenDataLayout (tensor.layout ()));
76- }
77-
78- template <>
79- void ShareTensorImpl<pten::DenseTensor>(pten::DenseTensor* tensor_impl,
80- LoDTensor* out) {
81- out->ResetHolderWithType (tensor_impl->allocation (),
82- pten::TransToProtoVarType (tensor_impl->data_type ()));
83- }
84-
85- template <>
86- void ShareTensorImpl<pten::DenseTensor>(pten::DenseTensor* tensor_impl,
87- Tensor* out) {
88- out->ResetHolderWithType (tensor_impl->allocation (),
89- pten::TransToProtoVarType (tensor_impl->data_type ()));
90- }
91-
92- std::shared_ptr<pten::TensorBase> InputVariableToPtenTensor (
93- const framework::Variable& variable, const pten::TensorArgDef& arg_def) {
94- auto expected_place = pten::TransToFluidPlace (arg_def.backend );
95-
96- if (variable.template IsType <framework::LoDTensor>()) {
97- const auto & tensor = variable.template Get <framework::LoDTensor>();
98- if (!platform::is_same_place (tensor.place (), expected_place)) {
99- framework::LoDTensor tmp_tensor;
100- framework::TensorCopySync (tensor, expected_place, &tmp_tensor);
101- auto pt_in =
102- framework::MakeTensorImpl<pten::DenseTensor, framework::LoDTensor>(
103- tmp_tensor, arg_def.backend , arg_def.dtype , arg_def.layout );
104- return pt_in;
105- } else {
106- auto pt_in =
107- framework::MakeTensorImpl<pten::DenseTensor, framework::LoDTensor>(
108- tensor, arg_def.backend , arg_def.dtype , arg_def.layout );
109- return pt_in;
110- }
111- } else if (variable.template IsType <framework::SelectedRows>()) {
112- // TODO(chenweihang): now we don't deal with row and height
113- // by xiaowei's advice
114- const auto & tensor = variable.template Get <framework::SelectedRows>();
115- if (!platform::is_same_place (tensor.value ().place (), expected_place)) {
116- framework::Tensor tmp_tensor;
117- TensorCopySync (tensor.value (), expected_place, &tmp_tensor);
118- // TODO(chenweihang): adapt SelectedRows by xiaowei's design
119- auto pt_in =
120- framework::MakeTensorImpl<pten::DenseTensor, framework::Tensor>(
121- tmp_tensor, arg_def.backend , arg_def.dtype , arg_def.layout );
122- return pt_in;
123- } else {
124- auto pt_in =
125- framework::MakeTensorImpl<pten::DenseTensor, framework::Tensor>(
126- tensor.value (), arg_def.backend , arg_def.dtype , arg_def.layout );
127- return pt_in;
128- }
129- } else {
130- PADDLE_THROW (platform::errors::Unimplemented (
131- " Unsupported shared input `%s` type now when call pt kernel." ,
132- framework::ToTypeName (variable.Type ())));
133- }
134- return nullptr ;
135- }
136-
137- std::shared_ptr<pten::TensorBase> OutputVariableToPtenTensor (
138- framework::Variable* variable, const pten::TensorArgDef& arg_def) {
139- // mutable_data before run kernel, to avoid share output form
140- // KernelContext to original tensor
141- if (variable->template IsType <framework::LoDTensor>()) {
142- auto * tensor = variable->template GetMutable <framework::LoDTensor>();
143- tensor->mutable_data (pten::TransToFluidPlace (arg_def.backend ),
144- pten::TransToProtoVarType (arg_def.dtype ));
145- auto pt_out =
146- framework::MakeTensorImpl<pten::DenseTensor, framework::LoDTensor>(
147- *tensor, arg_def.backend , arg_def.dtype , arg_def.layout );
148- return pt_out;
149- } else if (variable->template IsType <framework::SelectedRows>()) {
150- auto * tensor = variable->template GetMutable <framework::SelectedRows>();
151- tensor->mutable_value ()->mutable_data (
152- pten::TransToFluidPlace (arg_def.backend ),
153- pten::TransToProtoVarType (arg_def.dtype ));
154- // TODO(chenweihang): adapt SelectedRows by xiaowei's design,
155- // here the row and height will lost in output!
156- auto pt_out =
157- framework::MakeTensorImpl<pten::DenseTensor, framework::Tensor>(
158- tensor->value (), arg_def.backend , arg_def.dtype , arg_def.layout );
159- return pt_out;
160- } else {
161- PADDLE_THROW (platform::errors::Unimplemented (
162- " Unsupported shared output `%s` type now when call pt kernel." ,
163- framework::ToTypeName (variable->Type ())));
164- }
165-
166- return nullptr ;
167- }
168-
16927OpKernelType TransPtenKernelKeyToOpKernelType (
17028 const pten::KernelKey& kernel_key) {
17129 proto::VarType::Type data_type =
0 commit comments