diff --git a/paddle/infrt/dialect/init_infrt_dialects.cc b/paddle/infrt/dialect/init_infrt_dialects.cc index 090f1aea289109..8c858008473c96 100644 --- a/paddle/infrt/dialect/init_infrt_dialects.cc +++ b/paddle/infrt/dialect/init_infrt_dialects.cc @@ -21,6 +21,7 @@ #include "paddle/infrt/dialect/infrt/infrt_dialect.h" #include "paddle/infrt/dialect/infrt_base.h" #include "paddle/infrt/dialect/pd_ops.h" +#include "paddle/infrt/dialect/pten/infrt_pten_kernel.h" #include "paddle/infrt/dialect/pten/infrt_pten_tensor.h" #include "paddle/infrt/dialect/pten/pten_base.h" #include "paddle/infrt/dialect/tensor_shape.h" @@ -34,6 +35,7 @@ void registerCinnDialects(mlir::DialectRegistry ®istry) { // NOLINT mlir::pd::PaddleDialect, #ifdef INFRT_WITH_PTEN pten::PTENDenseTensorDialect, + pten::PTENKernelDialect, pten::PTENDialect #endif >(); diff --git a/paddle/infrt/dialect/pten/CMakeLists.txt b/paddle/infrt/dialect/pten/CMakeLists.txt index b4ed5cdc1d82fd..61d6fa2852c8df 100644 --- a/paddle/infrt/dialect/pten/CMakeLists.txt +++ b/paddle/infrt/dialect/pten/CMakeLists.txt @@ -10,4 +10,6 @@ add_mlir_dialect(infrt_pten_kernel pten_kernel) gather_srcs(infrt_src SRCS pten_base.cc infrt_pten_tensor.cc - infrt_pten_tensor.cc) + infrt_pten_tensor.cc + infrt_pten_kernel.cc + ) diff --git a/paddle/infrt/dialect/pten/infrt_pten_kernel.cc b/paddle/infrt/dialect/pten/infrt_pten_kernel.cc new file mode 100644 index 00000000000000..d0dfea5f4f6491 --- /dev/null +++ b/paddle/infrt/dialect/pten/infrt_pten_kernel.cc @@ -0,0 +1,38 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/dialect/pten/infrt_pten_kernel.h" + +#include + +#include "paddle/infrt/dialect/dense_tensor.h" +#define GET_OP_CLASSES +#include "paddle/infrt/dialect/pten/infrt_pten_kernel.h.inc" +#include "paddle/infrt/dialect/pten/infrt_pten_kernelDialect.cpp.inc" + +namespace infrt { +namespace pten { + +void PTENKernelDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "paddle/infrt/dialect/pten/infrt_pten_kernel.cpp.inc" + >(); +} + +} // namespace pten +} // namespace infrt + +#define GET_OP_CLASSES +#include "paddle/infrt/dialect/pten/infrt_pten_kernel.cpp.inc" // NOLINT diff --git a/paddle/infrt/dialect/pten/infrt_pten_kernel.h b/paddle/infrt/dialect/pten/infrt_pten_kernel.h new file mode 100644 index 00000000000000..3f44043b0fb3d3 --- /dev/null +++ b/paddle/infrt/dialect/pten/infrt_pten_kernel.h @@ -0,0 +1,39 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/infrt/dialect/pten/infrt_pten_kernelTypes.h.inc" +#include "paddle/infrt/dialect/pten/pten_base.h" + +#include "paddle/infrt/dialect/pten/infrt_pten_kernelDialect.h.inc" + +namespace infrt { +namespace pten {} // namespace pten +} // namespace infrt diff --git a/paddle/infrt/dialect/pten/infrt_pten_kernel.td b/paddle/infrt/dialect/pten/infrt_pten_kernel.td index a3a1609d9918ae..a35c0d5b78b6a1 100644 --- a/paddle/infrt/dialect/pten/infrt_pten_kernel.td +++ b/paddle/infrt/dialect/pten/infrt_pten_kernel.td @@ -1,7 +1,11 @@ -#ifndef PTEN_KERNEL +#ifdef PTEN_KERNEL +#else #define PTEN_KERNEL -include "paddle/infrt/dialect/pten/infrt_pten_tensor.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpBase.td" +include "paddle/infrt/dialect/infrt_base.td" +include "paddle/infrt/dialect/pten/infrt_pten_base.td" def PTEN_KernelDialect : Dialect { let name = "pten_kernel"; @@ -17,10 +21,4 @@ def PTEN_KernelDialect : Dialect { class PDT_Kernel traits = []> : Op { } -def FakeKernelOp : PDT_Kernel<"pten.matmul.host.fp32"> { - let arguments = (ins CPU_Context:$dev_ctx, TensorType:$x, TensorType:$y, BoolAttr:$transpose_x, BoolAttr:$transpose_y); - let results = (outs TensorType:$output); -} - #endif - diff --git a/paddle/infrt/dialect/pten/infrt_pten_tensor.td b/paddle/infrt/dialect/pten/infrt_pten_tensor.td index 528f0f919680d6..eff584cc44a354 100644 --- a/paddle/infrt/dialect/pten/infrt_pten_tensor.td +++ b/paddle/infrt/dialect/pten/infrt_pten_tensor.td @@ -53,4 +53,9 @@ def PDT_FillDenseTensorOp_f32 : FillDenseTensorOp; def PDT_CreateAllocatorOp_cpu : CreateCPUAllocatorOp; def PDT_CreateContextOp_cpu : CreateCPUContextOp; +def FakeKernelOp : PDT_Op<"fake_pten_kernel"> { + let arguments = (ins CPU_Context:$dev_ctx, TensorType:$x, TensorType:$y, BoolAttr:$transpose_x); + let results = (outs TensorType:$output); +} + #endif diff --git a/paddle/infrt/host_context/CMakeLists.txt b/paddle/infrt/host_context/CMakeLists.txt index 11304742ecd413..ec85318c74ad5b 100644 --- a/paddle/infrt/host_context/CMakeLists.txt +++ b/paddle/infrt/host_context/CMakeLists.txt @@ -16,6 +16,7 @@ gather_srcs(infrt_src SRCS cc_test_tiny(test_infrt_host_context_value SRCS value_test.cc DEPS infrt ${MLIR_IR_LIBS}) cc_test_tiny(test_infrt_kernel_utils SRCS kernel_utils_test.cc DEPS infrt ${MLIR_IR_LIBS}) +cc_test_tiny(test_infrt_kernel_frame SRCS kernel_frame_test.cc DEPS infrt ${MLIR_IR_LIBS}) cc_test_tiny(test_infrt_kernel_registry SRCS kernel_registry_test.cc DEPS infrt ${MLIR_IR_LIBS}) cc_test_tiny(test_infrt_op_executable SRCS op_executable_test.cc DEPS infrt ${MLIR_IR_LIBS}) cc_test_tiny(test_infrt_core_runtime SRCS core_runtime_test.cc DEPS infrt ${MLIR_IR_LIBS}) diff --git a/paddle/infrt/host_context/kernel_frame.cc b/paddle/infrt/host_context/kernel_frame.cc index 1acb35e898308a..24be2bbfdddac5 100644 --- a/paddle/infrt/host_context/kernel_frame.cc +++ b/paddle/infrt/host_context/kernel_frame.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/infrt/host_context/kernel_frame.h" +#include #include @@ -25,5 +26,35 @@ std::ostream& operator<<(std::ostream& os, const KernelFrame& frame) { return os; } +#ifndef NDEBUG +std::string KernelFrame::DumpArgTypes() const { + std::stringstream ss; + for (auto* value : GetValues(0, GetNumElements())) { + if (value->is_type()) { + ss << "bool,"; + } else if (value->is_type()) { + ss << "DenseHostTensor,"; + } else if (value->is_type()) { + ss << "float,"; + } else if (value->is_type()) { + ss << "int,"; + } else if (value->is_type()) { + ss << "pten::DenseTensor,"; + } else if (value->is_type()) { + ss << "pten::MetaTensor,"; + } else if (value->is_type<::pten::CPUContext>()) { + ss << "pten::CPUContext,"; + } else if (value->is_type()) { + ss << "none,"; + } else if (value->is_type()) { + ss << "CpuPtenContext,"; + } else { + ss << "unk,"; + } + } + return ss.str(); +} +#endif + } // namespace host_context } // namespace infrt diff --git a/paddle/infrt/host_context/kernel_frame.h b/paddle/infrt/host_context/kernel_frame.h index 35527872e624f7..7cef05c9c26d5a 100644 --- a/paddle/infrt/host_context/kernel_frame.h +++ b/paddle/infrt/host_context/kernel_frame.h @@ -31,20 +31,24 @@ namespace host_context { class KernelFrame { public: int GetNumArgs() const { return num_arguments_; } - int GetNumResults() const { return num_results_ == -1 ? 0 : num_results_; } - int GetNumAttributes() const { - return value_or_attrs_.size() - num_arguments_ - - (num_results_ == -1 ? 0 : num_results_); + int GetNumResults() const { + return value_or_attrs_.size() - num_arguments_ - GetNumAttributes(); } + int GetNumAttributes() const { return num_attrs_ == -1 ? 0 : num_attrs_; } //! Get something at a specific position \p index. The element might be an //! argument, an attribute or a result. template T& GetElementAt(int index) { - CHECK_LT(index, GetNumArgs() + GetNumAttributes() + GetNumResults()); + CHECK_LT(static_cast(index), GetNumElements()); return value_or_attrs_[index]->template get_or_default(); } + Value* GetElementAt(int index) { + CHECK_LT(static_cast(index), GetNumElements()); + return value_or_attrs_[index]; + } + // Get number of elements, either input, attributes or results. size_t GetNumElements() const { return value_or_attrs_.size(); } @@ -70,18 +74,21 @@ class KernelFrame { } Value* GetAttributeAt(int idx) { - CHECK_NE(num_results_, -1) - << "Must call SetNumResults before GetAttributeAt"; - CHECK_LT(idx, - static_cast(value_or_attrs_.size() - num_arguments_ - - num_results_)); - return value_or_attrs_[num_arguments_ + num_results_ + idx]; + // CHECK_NE(num_results_, -1) + //<< "Must call SetNumResults before GetAttributeAt"; + CHECK_LT(idx, GetNumAttributes()); + return value_or_attrs_[num_arguments_ + idx]; } void AddAttribute(Value* v) { - CHECK_NE(num_results_, -1) - << "Must call SetNumResults before calling AddAttribute"; + CHECK_EQ(num_results_, -1) + << "Must call SetNumResults after calling AddAttribute"; value_or_attrs_.emplace_back(v); + if (num_attrs_ == -1) num_attrs_ = 0; + num_attrs_++; + + CHECK_EQ(value_or_attrs_.size(), + static_cast(num_arguments_ + num_attrs_)); } template @@ -96,35 +103,41 @@ class KernelFrame { template void SetResultAt(int index, T&& value) { - CHECK_LT(index, num_results_) << "Invalid result index"; - CHECK(value_or_attrs_[num_arguments_ + index]); - value_or_attrs_[num_arguments_ + index]->set(std::move(value)); + CHECK_LT(index, GetNumResults()) << "Invalid result index"; + CHECK(value_or_attrs_[num_arguments_ + GetNumAttributes() + index]); + value_or_attrs_[num_arguments_ + GetNumAttributes() + index]->set( + std::move(value)); } llvm::ArrayRef GetResults() const { - return GetValues(num_arguments_, num_results_); + return GetValues(num_arguments_ + GetNumAttributes(), num_results_); } llvm::MutableArrayRef GetResults() { - return GetMutableValues(num_arguments_, num_results_); + return GetMutableValues(num_arguments_ + GetNumAttributes(), num_results_); } llvm::ArrayRef GetValues(size_t from, size_t length) const { - CHECK_LE(static_cast(from + length), num_arguments_ + num_results_); + CHECK_LE(from + length, GetNumElements()); if (length == 0) return {}; return llvm::makeArrayRef(&value_or_attrs_[from], length); } llvm::MutableArrayRef GetMutableValues(size_t from, size_t length) { - CHECK_LE(static_cast(from + length), num_arguments_ + num_results_); + CHECK_LE(from + length, GetNumElements()); if (length == 0) return {}; return llvm::makeMutableArrayRef(&value_or_attrs_[from], length); } +#ifndef NDEBUG + std::string DumpArgTypes() const; +#endif + bool IsEmpty() const { return value_or_attrs_.empty(); } protected: int num_arguments_{}; + int num_attrs_{-1}; int num_results_{-1}; llvm::SmallVector value_or_attrs_; @@ -136,15 +149,15 @@ class KernelFrameBuilder : public KernelFrame { public: void AddArgument(Value* value) { CHECK(value); - CHECK_EQ(num_results_, -1) - << "Should call AddArgument before calling SetNumResults"; + CHECK_EQ(num_attrs_, -1) + << "Should call AddArgument before calling SetAttributes"; value_or_attrs_.push_back(value); ++num_arguments_; } void SetResults(llvm::ArrayRef values) { - CHECK_EQ(num_arguments_, static_cast(value_or_attrs_.size())); - CHECK_EQ(num_results_, -1); + CHECK_EQ(num_arguments_ + GetNumAttributes(), + static_cast(value_or_attrs_.size())); for (Value* x : values) { value_or_attrs_.push_back(x); } @@ -152,9 +165,8 @@ class KernelFrameBuilder : public KernelFrame { } void SetNumResults(size_t n) { - CHECK_EQ(num_arguments_, static_cast(value_or_attrs_.size())); - CHECK_EQ(num_results_, -1); - num_results_ = n; + CHECK_EQ(num_arguments_ + GetNumAttributes(), + static_cast(value_or_attrs_.size())); for (size_t i = 0; i < n; i++) { value_or_attrs_.emplace_back(new Value); } @@ -162,18 +174,20 @@ class KernelFrameBuilder : public KernelFrame { void SetResultAt(int result_id, Value* value) { CHECK_EQ(static_cast(value_or_attrs_.size()), - num_arguments_ + num_results_) + num_arguments_ + GetNumAttributes() + num_results_) << "Call SetNumResults first"; - CHECK_LT(result_id + num_arguments_, + CHECK_LT(result_id + num_arguments_ + GetNumAttributes(), static_cast(value_or_attrs_.size())); CHECK(value); - value_or_attrs_[num_arguments_ + result_id]->set(value); + value_or_attrs_[num_arguments_ + GetNumAttributes() + result_id]->set( + value); } void Reset() { value_or_attrs_.clear(); num_arguments_ = 0; num_results_ = -1; + num_attrs_ = -1; } }; diff --git a/paddle/infrt/host_context/kernel_frame_test.cc b/paddle/infrt/host_context/kernel_frame_test.cc new file mode 100644 index 00000000000000..d67ee6a9af730c --- /dev/null +++ b/paddle/infrt/host_context/kernel_frame_test.cc @@ -0,0 +1,76 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/infrt/host_context/kernel_frame.h" + +namespace infrt { +namespace host_context { +/* +TEST(KernelRegistry, basic) { + KernelFrameBuilder kernel_frame; + + Value arg_0(std::string{"arg_0"}); + Value arg_1(std::string{"arg_1"}); + Value arg_2(std::string{"arg_2"}); + Value res_0(std::string{"res_0"}); + Value res_1(std::string{"res_1"}); + Value attr_0(std::string{"attr_0"}); + + kernel_frame.AddArgument(&arg_0); + kernel_frame.AddArgument(&arg_1); + kernel_frame.AddArgument(&arg_2); + kernel_frame.SetResults({&res_0, &res_1}); + kernel_frame.AddAttribute(&attr_0); + + CHECK_EQ(kernel_frame.GetNumArgs(), 3); + CHECK_EQ(kernel_frame.GetNumResults(), 2); + CHECK_EQ(kernel_frame.GetNumAttributes(), 1); + CHECK_EQ(kernel_frame.GetNumElements(), 6UL); + + CHECK_EQ(kernel_frame.GetArgAt(2), "arg_2"); + CHECK_EQ(kernel_frame.GetAttributeAt(0)->get(), "attr_0"); + CHECK_EQ(kernel_frame.GetResults()[1]->get(), "res_1"); +} +*/ + +TEST(KernelRegistry, basic) { + KernelFrameBuilder kernel_frame; + + Value arg_0(std::string{"arg_0"}); + Value arg_1(std::string{"arg_1"}); + Value arg_2(std::string{"arg_2"}); + Value attr_0(std::string{"attr_0"}); + Value res_0(std::string{"res_0"}); + Value res_1(std::string{"res_1"}); + + kernel_frame.AddArgument(&arg_0); + kernel_frame.AddArgument(&arg_1); + kernel_frame.AddArgument(&arg_2); + kernel_frame.AddAttribute(&attr_0); + kernel_frame.SetResults({&res_0, &res_1}); + + CHECK_EQ(kernel_frame.GetNumArgs(), 3); + CHECK_EQ(kernel_frame.GetNumResults(), 2); + CHECK_EQ(kernel_frame.GetNumAttributes(), 1); + CHECK_EQ(kernel_frame.GetNumElements(), 6UL); + + CHECK_EQ(kernel_frame.GetArgAt(2), "arg_2"); + CHECK_EQ(kernel_frame.GetAttributeAt(0)->get(), "attr_0"); + CHECK_EQ(kernel_frame.GetResults()[1]->get(), "res_1"); +} + +} // namespace host_context +} // namespace infrt diff --git a/paddle/infrt/host_context/kernel_utils.h b/paddle/infrt/host_context/kernel_utils.h index 31d411006d2378..7973325ea9fa65 100644 --- a/paddle/infrt/host_context/kernel_utils.h +++ b/paddle/infrt/host_context/kernel_utils.h @@ -265,7 +265,7 @@ struct KernelImpl { static_assert(const_idx == 0, "Arguments and results should appear before attributes."); - auto* value = frame->GetArgAt(in_idx); + auto* value = frame->GetElementAt(in_idx); auto&& arg = value->get(); KernelCallHelper< diff --git a/paddle/infrt/host_context/mlir_to_runtime_translate.cc b/paddle/infrt/host_context/mlir_to_runtime_translate.cc index 3dbc7a702be38d..dc792a8f9a50e4 100644 --- a/paddle/infrt/host_context/mlir_to_runtime_translate.cc +++ b/paddle/infrt/host_context/mlir_to_runtime_translate.cc @@ -150,6 +150,17 @@ boost::optional MlirToRuntimeTranslator::EmitAttribute( return boost::none; } +template <> +boost::optional MlirToRuntimeTranslator::EmitAttribute( + const mlir::Attribute& attr) { + if (!attr.isa()) return boost::none; + if (attr.isa()) { + auto val = attr.cast(); + return val.getValue(); + } + return boost::none; +} + template <> boost::optional MlirToRuntimeTranslator::EmitAttribute( const mlir::Attribute& attr) { @@ -187,6 +198,7 @@ boost::optional MlirToRuntimeTranslator::EmitAttribute( return res; \ } +PROCESS_ARRAY_INT(bool, 1); PROCESS_ARRAY_INT(int16_t, 16); PROCESS_ARRAY_INT(int32_t, 32); PROCESS_ARRAY_INT(int64_t, 64); @@ -262,25 +274,6 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) { << GetValue(operand) << " vs " << arg_value; } - // process results - llvm::SmallVector res_values; - for (int i = 0, e = op->getNumResults(); i < e; i++) { - auto res = op->getResult(i); - res_values.push_back(AddValue(res)); - - VLOG(3) << "* op mlir res: " << DumpToString(res) << " " << GetValue(res); - } - impl_->cur_op->SetResults(res_values); - -#ifdef INFRT_DEBUG - { - VLOG(3) << "check result"; - for (int i = 0; i < impl_->cur_op->frame().GetNumResults(); i++) { - VLOG(3) << "+ res value: " << impl_->cur_op->frame().GetResults()[i]; - } - } -#endif - // process attributes auto attrs = op->getAttrs(); @@ -296,6 +289,8 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) { impl_->cur_op->AppendAttribute(new Value(*v)); } else if (auto v = EmitAttribute(attr.getValue())) { impl_->cur_op->AppendAttribute(new Value(std::move(*v))); + } else if (auto v = EmitAttribute(attr.getValue())) { + impl_->cur_op->AppendAttribute(new Value(*v)); } else if (auto v = EmitAttribute>(attr.getValue())) { impl_->cur_op->AppendAttribute(new Value(std::move(*v))); } else if (auto v = EmitAttribute>(attr.getValue())) { @@ -311,6 +306,25 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) { } } + // process results + llvm::SmallVector res_values; + for (int i = 0, e = op->getNumResults(); i < e; i++) { + auto res = op->getResult(i); + res_values.push_back(AddValue(res)); + + VLOG(3) << "* op mlir res: " << DumpToString(res) << " " << GetValue(res); + } + impl_->cur_op->SetResults(res_values); + +#ifdef INFRT_DEBUG + { + VLOG(3) << "check result"; + for (int i = 0; i < impl_->cur_op->frame().GetNumResults(); i++) { + VLOG(3) << "+ res value: " << impl_->cur_op->frame().GetResults()[i]; + } + } +#endif + // process regions, we treat regions as attribute. auto num_regions = op->getNumRegions(); if (num_regions > 0) { diff --git a/paddle/infrt/host_context/op_executable.cc b/paddle/infrt/host_context/op_executable.cc index cf40d7315c6a58..59a73e71083286 100644 --- a/paddle/infrt/host_context/op_executable.cc +++ b/paddle/infrt/host_context/op_executable.cc @@ -133,7 +133,8 @@ void OpExecutable::Execute() { VLOG(3) << "execute " << name() << " --- frame args: " << impl_->frame.GetNumArgs() << " results " << impl_->frame.GetNumResults() << " attributes " - << impl_->frame.GetNumAttributes(); + << impl_->frame.GetNumAttributes() << "\n" + << frame().DumpArgTypes(); for (int i = 0; i < impl_->frame.GetNumArgs(); i++) { VLOG(3) << "function arg: " << impl_->frame.GetArgAt(i); } diff --git a/paddle/infrt/host_context/value.h b/paddle/infrt/host_context/value.h index f0478583f7cfd5..015577b4094f70 100644 --- a/paddle/infrt/host_context/value.h +++ b/paddle/infrt/host_context/value.h @@ -47,8 +47,11 @@ namespace host_context { struct MlirFunctionExecutable; +struct None {}; + using ValueVariantType = - Variant T& get() { - CHECK(data.template is()); + CHECK(data.template is()) << "typeid: " << data.index() + << " != " << ValueVariantType::IndexOf; return data.get(); } diff --git a/paddle/infrt/kernel/pten/infershaped/infershaped_kernel_launcher.cc b/paddle/infrt/kernel/pten/infershaped/infershaped_kernel_launcher.cc index 4d91cda04152f1..e1f5abb3e0750f 100644 --- a/paddle/infrt/kernel/pten/infershaped/infershaped_kernel_launcher.cc +++ b/paddle/infrt/kernel/pten/infershaped/infershaped_kernel_launcher.cc @@ -27,9 +27,6 @@ void InferShapedKernelLauncher::CreateKernelFrameForInferShape( values.emplace_back( ::pten::MetaTensor{&value->get<::pten::DenseTensor>()}); infershape_kernel_frame_builder.AddArgument(values.back().get()); - } else if (value->is_type()) { - values.emplace_back(pten::MetaTensor{&value->get()}); - infershape_kernel_frame_builder.AddArgument(values.back().get()); } else { infershape_kernel_frame_builder.AddArgument(value); } diff --git a/paddle/infrt/kernel/pten/infershaped/pten_kernel_launcher.h b/paddle/infrt/kernel/pten/infershaped/pten_kernel_launcher.h index 9a3e978e966b07..ab01e2d6227c57 100644 --- a/paddle/infrt/kernel/pten/infershaped/pten_kernel_launcher.h +++ b/paddle/infrt/kernel/pten/infershaped/pten_kernel_launcher.h @@ -15,6 +15,9 @@ #include +#include + +#include "paddle/infrt/backends/host/pten_context.h" #include "paddle/infrt/host_context/kernel_utils.h" #include "paddle/infrt/kernel/pten/infershaped/infershaped_kernel_launcher.h" #include "paddle/infrt/kernel/pten/infershaped/infershaped_utils.h" @@ -22,6 +25,19 @@ namespace infrt { namespace kernel { +static void FakePtenInferShape(const ::pten::MetaTensor& a, + const ::pten::MetaTensor& b, + bool arg_0, + ::pten::MetaTensor* c) {} + +static void FakePtenKernel(const backends::CpuPtenContext& /*Context*/, + const ::pten::DenseTensor& a, + const ::pten::DenseTensor& b, + bool arg_0, + ::pten::DenseTensor* c) { + std::cout << "@FakePtenKernel@" << std::endl; +} + template ::count}; static const bool turn_on_infer_shape_cache{true}; void Invoke(host_context::KernelFrame* frame) override { +#ifndef NDEBUG + LOG(INFO) << "Kernel.frame: " << frame->DumpArgTypes(); +#endif // Build the infershape KernelFrame if needed. // TODO(Superjomn) add unlikely here. if (infershape_kernel_frame_builder.IsEmpty()) { CreateKernelFrameForInferShape(frame); +#ifndef NDEBUG + LOG(INFO) << "infershape.frame: " + << infershape_kernel_frame_builder.DumpArgTypes(); +#endif } + if (turn_on_infer_shape_cache) { if (!turn_on_infer_shape_cache || IsShapeChanged(num_input_tensors)) { + // INFRT_KERNEL(infershape)::Invoke(&infershape_kernel_frame_builder); ::infrt::host_context::KernelImpl::Invoke( &infershape_kernel_frame_builder); BuildInferShapeCache(num_input_tensors); } } ::infrt::host_context::KernelImpl::Invoke(frame); + // INFRT_KERNEL(kernel)::Invoke(frame); } }; diff --git a/paddle/infrt/kernel/pten/registry.cc b/paddle/infrt/kernel/pten/registry.cc index 1b410b5fc40083..43067e186fe184 100644 --- a/paddle/infrt/kernel/pten/registry.cc +++ b/paddle/infrt/kernel/pten/registry.cc @@ -43,17 +43,15 @@ void RegisterPtenKernels(host_context::KernelRegistry* registry) { registry->AddKernel("pten_dt.fill_dense_tensor.f32", INFRT_KERNEL(infrt::kernel::pten::FillDenseTensorF32)); registry->AddKernel( - "pten.matmul.host.fp32", - std::bind(&kernel::KernelLauncherFunc< - decltype(&::pten::MatmulKernel), - &::pten::MatmulKernel, - decltype(&::pten::MatmulInferMeta), - &::pten::MatmulInferMeta>, - kernel::KernelLauncher< - decltype(&::pten::MatmulKernel), - &::pten::MatmulKernel, - decltype(&::pten::MatmulInferMeta), - &::pten::MatmulInferMeta>(), + "pten_dt.fake_pten_kernel", + std::bind(&KernelLauncherFunc, + KernelLauncher(), std::placeholders::_1)); } diff --git a/paddle/infrt/support/variant.h b/paddle/infrt/support/variant.h index 2f415b21c80109..218960735d442d 100644 --- a/paddle/infrt/support/variant.h +++ b/paddle/infrt/support/variant.h @@ -138,10 +138,10 @@ class Variant { IndexT index() { return index_; } - private: template static constexpr size_t IndexOf = TupleIndexOf::value; + private: static constexpr size_t kStorageSize = std::max({sizeof(Ts)...}); static constexpr size_t kAlignment = std::max({alignof(Ts)...}); diff --git a/paddle/infrt/tests/dialect/pten/dense_tensor.mlir b/paddle/infrt/tests/dialect/pten/dense_tensor.mlir index 88f5b289fd9f84..c089b5f3c2be47 100644 --- a/paddle/infrt/tests/dialect/pten/dense_tensor.mlir +++ b/paddle/infrt/tests/dialect/pten/dense_tensor.mlir @@ -1,11 +1,12 @@ -// RUN: infrtopt %s | FileCheck %s +// RUN: infrtexec %s | FileCheck %s -// CHECK-LABEL: @basic_tensor -func @basic_tensor() { - %a = "pten_dt.create_allocator.cpu" (): () -> !pten.CPU_allocator - %b = "pten_dt.create_context.cpu" (): () -> !pten.CPU_context - %c = "pten_dt.create_dense_tensor.cpu.f32.nchw" (%a) {dims=[1:i64], lod=[1:i64]}: (!pten.CPU_allocator) -> (!infrt.tensor) - // "pten_dt.fill_dense_tensor.f32" (%c) {value=[1.0:f32]} : (!infrt.tensor) -> () +// CHECK-LABEL: @fake_pten_kernel_execute +func @fake_pten_kernel_execute() { + %allocator = "pten_dt.create_allocator.cpu" (): () -> !pten.CPU_allocator + %ctx = "pten_dt.create_context.cpu" (): () -> !pten.CPU_context + %t = "pten_dt.create_dense_tensor.cpu.f32.nchw" (%allocator) {dims=[1:i64], lod=[1:i64]}: (!pten.CPU_allocator) -> (!infrt.tensor) + // CHECK: @FakePtenKernel@ + %d = "pten_dt.fake_pten_kernel" (%ctx, %t, %t) {transpose_x=false} : (!pten.CPU_context, !infrt.tensor, !infrt.tensor) -> (!infrt.tensor) infrt.return }