@@ -12,11 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include <sstream>
+
 #include "paddle/fluid/framework/tcmpt_utils.h"
 
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/framework/variable.h"
+#include "paddle/fluid/string/string_helper.h"
 
 namespace paddle {
 namespace framework {
@@ -62,7 +65,7 @@ std::shared_ptr<pt::DenseTensor> MakeTensorImpl<pt::DenseTensor>(
     proto::VarType::Type type) {
   return MakeTensorImpl<pt::DenseTensor, LoDTensor>(
       tensor, pt::TransToPtBackend(place), pt::TransToPtDataType(type),
-      pt::TransToPtLayout(tensor.layout()));
+      pt::TransToPtDataLayout(tensor.layout()));
 }
 
 template <>
@@ -71,7 +74,7 @@ std::shared_ptr<pt::DenseTensor> MakeTensorImpl<pt::DenseTensor>(
     proto::VarType::Type type) {
   return MakeTensorImpl<pt::DenseTensor, Tensor>(
       tensor, pt::TransToPtBackend(place), pt::TransToPtDataType(type),
-      pt::TransToPtLayout(tensor.layout()));
+      pt::TransToPtDataLayout(tensor.layout()));
 }
 
 std::shared_ptr<tcmpt::TensorBase> InputVariableToPtTensor(
@@ -150,5 +153,115 @@ std::shared_ptr<tcmpt::TensorBase> OutputVariableToPtTensor(
   return nullptr;
 }
 
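+// Translates a pt::KernelKey (backend + layout + dtype) back into the fluid
+// OpKernelType. The MKLDNN and CUDNN backends map to the matching
+// LibraryType; every other backend falls back to LibraryType::kPlain.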
+OpKernelType TransPtKernelKeyToOpKernelType(const pt::KernelKey& kernel_key) {
+  proto::VarType::Type data_type = pt::TransToProtoVarType(kernel_key.dtype());
+  platform::Place place = pt::TransToFluidPlace(kernel_key.backend());
+  DataLayout data_layout = pt::TransToFluidDataLayout(kernel_key.layout());
+  LibraryType library_type = LibraryType::kPlain;
+  if (kernel_key.backend() == pt::Backend::kMKLDNN) {
+    library_type = LibraryType::kMKLDNN;
+  } else if (kernel_key.backend() == pt::Backend::kCUDNN) {
+    library_type = LibraryType::kCUDNN;
+  } else {
+    // do nothing
+  }
+  // TODO(chenweihang): the customized_type_value is lost
+  return OpKernelType(data_type, place, data_layout, library_type);
+}
+
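+// The reverse translation: builds a pt::KernelKey from an OpKernelType,
+// selecting the MKLDNN/CUDNN backend when the library type requires it.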
+pt::KernelKey TransOpKernelTypeToPtKernelKey(const OpKernelType& kernel_type) {
+  pt::Backend backend = pt::TransToPtBackend(kernel_type.place_);
+  if (kernel_type.library_type_ == LibraryType::kMKLDNN) {
+    backend = pt::Backend::kMKLDNN;
+  } else if (kernel_type.library_type_ == LibraryType::kCUDNN) {
+    backend = pt::Backend::kCUDNN;
+  } else {
+    // do nothing
+  }
+  pt::DataLayout layout = pt::TransToPtDataLayout(kernel_type.data_layout_);
+  pt::DataType dtype = pt::TransToPtDataType(kernel_type.data_type_);
+  return pt::KernelKey(backend, layout, dtype);
+}
+
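+// Accessor for the process-wide kernel signature map (Meyers' singleton, so
+// initialization is thread-safe in C++11 and later).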
+KernelSignatureMap& KernelSignatureMap::Instance() {
+  static KernelSignatureMap g_kernel_signature_map;
+  return g_kernel_signature_map;
+}
+
+const paddle::SmallVector<std::string>&
+KernelArgsNameMakerByOpProto::GetInputArgsNames() {
+  for (int i = 0; i < op_proto_->inputs_size(); ++i) {
+    auto& in = op_proto_->inputs()[i];
+    auto& in_name = in.name();
+    if ((in.has_extra() && in.extra()) || (in.has_quant() && in.quant())) {
+      VLOG(1) << "Parse PtKernel input: skip extra & quant input - " << in_name;
+      continue;
+    }
+    // If the op contains a dispensable input, it should override the
+    // GetExpectedPtKernelArgs method itself
+    if (in.has_dispensable() && in.dispensable()) {
+      VLOG(1) << "Parse PtKernel input: skip dispensable input - " << in_name;
+      continue;
+    }
+    VLOG(1) << "Parse PtKernel input: " << in_name;
+    input_names_.emplace_back(in_name);
+  }
+  return input_names_;
+}
+
+const paddle::SmallVector<std::string>&
+KernelArgsNameMakerByOpProto::GetOutputArgsNames() {
+  for (int i = 0; i < op_proto_->outputs_size(); ++i) {
+    auto& out = op_proto_->outputs()[i];
+    auto& out_name = out.name();
+    // TODO(chenweihang): outputs also need to skip some cases
+    VLOG(1) << "Parse PtKernel output: " << out_name;
+    output_names_.emplace_back(out_name);
+  }
+  return output_names_;
+}
+
+const paddle::SmallVector<std::string>&
+KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
+  for (int i = 0; i < op_proto_->attrs_size(); ++i) {
+    auto& attr = op_proto_->attrs()[i];
+    auto& attr_name = attr.name();
+    if (attr_name == "use_mkldnn" || attr_name == "op_role" ||
+        attr_name == "op_role_var" || attr_name == "op_namescope" ||
+        attr_name == "op_callstack" || attr_name == "op_device") {
+      VLOG(1) << "Parse PtKernel attribute: skip needless attr - " << attr_name;
+      continue;
+    }
+    if ((attr.has_extra() && attr.extra()) ||
+        (attr.has_quant() && attr.quant())) {
+      VLOG(1) << "Parse PtKernel attribute: skip extra & quant attr - "
+              << attr_name;
+      continue;
+    }
+    VLOG(1) << "Parse PtKernel attribute: " << attr_name;
+    attr_names_.emplace_back(attr_name);
+  }
+
+  return attr_names_;
+}
+
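+// Assembles the final signature: the op type name paired with the
+// (inputs, attributes, outputs) name lists gathered above.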
+KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() {
+  return std::make_pair(
+      op_proto_->type(),
+      std::make_tuple(GetInputArgsNames(), GetAttrsArgsNames(),
+                      GetOutputArgsNames()));
+}
+
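+// Renders a KernelSignature as a human-readable string, e.g. (illustrative):
+// "Kernel Signature - name: scale; inputs: X; attributes: scale; outputs: Out"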
+std::string KernelSignatureToString(const KernelSignature& signature) {
+  std::stringstream os;
+  os << "Kernel Signature - name: " << signature.first << "; inputs: "
+     << string::join_strings(std::get<0>(signature.second), ", ")
+     << "; attributes: "
+     << string::join_strings(std::get<1>(signature.second), ", ")
+     << "; outputs: "
+     << string::join_strings(std::get<2>(signature.second), ", ");
+  return os.str();
+}
+
 }  // namespace framework
 }  // namespace paddle