diff --git a/paddle/infrt/naive/CMakeLists.txt b/paddle/infrt/naive/CMakeLists.txt
index edb7b8a9121c8d..c90c6e7ba7b88e 100644
--- a/paddle/infrt/naive/CMakeLists.txt
+++ b/paddle/infrt/naive/CMakeLists.txt
@@ -4,4 +4,5 @@ cc_library(infrt_naive SRCS meta_tensor.cc
     infershaped/infershaped_kernel_launchers.cc
     )
 
-cc_test_tiny(test_infrt_infershape_launchers SRCS infershaped/infershape_launchers_test.cc DEPS infrt)
+cc_test_tiny(test_infrt_infershape_launchers SRCS
+infershaped/infershape_launchers_test.cc DEPS infrt)
diff --git a/paddle/infrt/naive/infershaped/elementwise_add.h b/paddle/infrt/naive/infershaped/elementwise_add.h
index c79929822b9a3b..ee044e38da03dd 100644
--- a/paddle/infrt/naive/infershaped/elementwise_add.h
+++ b/paddle/infrt/naive/infershaped/elementwise_add.h
@@ -17,6 +17,7 @@
 
 #include "paddle/infrt/host_context/kernel_utils.h"
 #include "paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h"
+#include "paddle/infrt/naive/infershaped/infershaped_utils.h"
 
 // This file contains a example of the infershape ElementwiseAdd kernel.
 // Some of the following code should be generated from PTEN by script.
@@ -32,17 +33,19 @@ static void ElementwiseAddInferShape(const MetaTensor& a,
   *c->mutable_shape() = a.shape();
 }
 
-static void ElementwiseAdd(const tensor::DenseHostTensor& a,
+static void ElementwiseAdd(tensor::DenseHostTensor* /*Context*/,
+                           const tensor::DenseHostTensor& a,
                            const tensor::DenseHostTensor& b,
                            tensor::DenseHostTensor* c) {}
 
-// TODO(zhiqiang) This class should be generated by a script offline.
-class ElementwiseAddLauncher : public InferShapedKernelLauncher {
+template <typename KernelFunc,
+          KernelFunc kernel,
+          typename InferShapedFunc,
+          InferShapedFunc infershape>
+class KernelLauncher : public InferShapedKernelLauncher {
  public:
-  static const uint16_t input_tensor_indices[2];
-  static const uint16_t num_input_tensors{2};
+  static const uint16_t num_input_tensors{InferShapeHelper<KernelFunc>::count};
   static const bool turn_on_infer_shape_cache{true};
-
   void Invoke(host_context::KernelFrame* frame) override {
     // Build the infershape KernelFrame if needed.
     // TODO(Superjomn) add unlikely here.
@@ -50,21 +53,16 @@ class ElementwiseAddLauncher : public InferShapedKernelLauncher {
       CreateKernelFrameForInferShape(frame);
     }
 
     if (turn_on_infer_shape_cache) {
-      if (IsShapeChanged(input_tensor_indices, num_input_tensors)) {
-        INFRT_KERNEL(ElementwiseAddInferShape)
-        (&infershape_kernel_frame_builder);
-        BuildInferShapeCache(input_tensor_indices, num_input_tensors);
+      if (!turn_on_infer_shape_cache || IsShapeChanged(num_input_tensors)) {
+        ::infrt::host_context::KernelImpl<InferShapedFunc, infershape>::Invoke(
+            &infershape_kernel_frame_builder);
+        BuildInferShapeCache(num_input_tensors);
       }
-    } else {
-      INFRT_KERNEL(ElementwiseAddInferShape)(&infershape_kernel_frame_builder);
-      BuildInferShapeCache(input_tensor_indices, num_input_tensors);
     }
-    INFRT_KERNEL(ElementwiseAdd)(frame);
+    ::infrt::host_context::KernelImpl<KernelFunc, kernel>::Invoke(frame);
   }
 };
 
-const uint16_t ElementwiseAddLauncher::input_tensor_indices[2] = {0, 1};
-
 }  // namespace naive
 }  // namespace infrt
diff --git a/paddle/infrt/naive/infershaped/infershape_launchers_test.cc b/paddle/infrt/naive/infershaped/infershape_launchers_test.cc
index 317323d7c5f519..ba6fdbdd5783f5 100644
--- a/paddle/infrt/naive/infershaped/infershape_launchers_test.cc
+++ b/paddle/infrt/naive/infershaped/infershape_launchers_test.cc
@@ -17,11 +17,24 @@
 #include "paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h"
 #include "paddle/infrt/naive/infershaped/infershaped_kernel_launchers.h"
 #include "paddle/infrt/naive/infershaped/infershaped_registry.h"
+#include "paddle/infrt/naive/infershaped/infershaped_utils.h"
 #include "paddle/infrt/tensor/dense_host_tensor.h"
 
 namespace infrt {
 namespace naive {
 
+namespace {
+static void ElementwiseAddTest(const tensor::DenseHostTensor& a,
+                               const tensor::DenseHostTensor& b,
+                               tensor::DenseHostTensor* c);
+}
+
+TEST(utils, registry) {
+  constexpr uint8_t count =
+      InferShapeHelper<decltype(&ElementwiseAddTest)>::count;
+  CHECK_EQ(count, 2U);
+}
+
 TEST(ElementwiseAdd, registry) {
   InferShapedKernelRegistry registry;
   RegisterInferShapeLaunchers(&registry);
@@ -35,6 +48,7 @@ TEST(ElementwiseAdd, registry) {
   tensor::DenseHostTensor c({2, 8}, GetDType<float>());
 
   host_context::KernelFrameBuilder kernel_frame_builder;
+  kernel_frame_builder.AddArgument(new host_context::Value(0));
   kernel_frame_builder.AddArgument(new host_context::Value(std::move(a)));
   kernel_frame_builder.AddArgument(new host_context::Value(std::move(b)));
   kernel_frame_builder.SetResults({new host_context::Value(std::move(c))});
diff --git a/paddle/infrt/naive/infershaped/infershaped_kernel_launcher.cc b/paddle/infrt/naive/infershaped/infershaped_kernel_launcher.cc
index 9ef9d9f2b7ba20..6a2c4a51ecdb29 100644
--- a/paddle/infrt/naive/infershaped/infershaped_kernel_launcher.cc
+++ b/paddle/infrt/naive/infershaped/infershaped_kernel_launcher.cc
@@ -20,7 +20,7 @@ namespace naive {
 void InferShapedKernelLauncher::CreateKernelFrameForInferShape(
     host_context::KernelFrame* frame) {
   for (host_context::Value* value :
-       frame->GetValues(0, frame->GetNumElements())) {
+       frame->GetValues(1, frame->GetNumElements() - 1)) {
     // TODO(Superjomn) To extend this.
     if (value->is_type<tensor::DenseHostTensor>()) {
       values.emplace_back(MetaTensor{&value->get<tensor::DenseHostTensor>()});
@@ -32,27 +32,24 @@ void InferShapedKernelLauncher::CreateKernelFrameForInferShape(
 }
 
 void InferShapedKernelLauncher::BuildInferShapeCache(
-    const uint16_t* input_indices, const uint16_t num_inputs) {
+    const uint16_t num_inputs) {
   tensor_shape_cache.resize(num_inputs);
   for (uint16_t i = 0; i < num_inputs; i++) {
     tensor_shape_cache[i] =
-        infershape_kernel_frame_builder.GetArgAt(input_indices[i])
-            ->get<MetaTensor>()
-            .shape();
+        infershape_kernel_frame_builder.GetArgAt(i)->get<MetaTensor>().shape();
   }
 }
 
 bool InferShapedKernelLauncher::IsShapeChanged(
-    const uint16_t* input_indices, const uint16_t num_inputs) const {
+    const uint16_t num_inputs) const {
   if (tensor_shape_cache.empty() && !infershape_kernel_frame_builder.IsEmpty())
     return true;
 
   bool changed = false;
   for (uint16_t i = 0; i < num_inputs && !changed; i++) {
-    changed = changed || (tensor_shape_cache[i] !=
-                          infershape_kernel_frame_builder
-                              .GetArgAt<MetaTensor>(input_indices[i])
-                              .shape());
+    changed = changed ||
+              (tensor_shape_cache[i] !=
+               infershape_kernel_frame_builder.GetArgAt<MetaTensor>(i).shape());
   }
   return changed;
 }
diff --git a/paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h b/paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h
index 14c4beaf937d0a..890a779ed24032 100644
--- a/paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h
+++ b/paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h
@@ -39,12 +39,10 @@ struct InferShapedKernelLauncher {
 
   //! Build or update the infer-shape cache using the latest shape from
   //! InferShapeFrame.
-  void BuildInferShapeCache(const uint16_t* input_indices,
-                            const uint16_t num_inputs);
+  void BuildInferShapeCache(const uint16_t num_inputs);
 
   //! Compare the latest shape with the shape cache.
-  bool IsShapeChanged(const uint16_t* input_indices,
-                      const uint16_t num_inputs) const;
+  bool IsShapeChanged(const uint16_t num_inputs) const;
 
   // values to hold the TensorMeta.
   llvm::SmallVector<host_context::ValueRef, 3> values;
diff --git a/paddle/infrt/naive/infershaped/infershaped_kernel_launchers.cc b/paddle/infrt/naive/infershaped/infershaped_kernel_launchers.cc
index 928a43da3e2191..e570b3521b795a 100644
--- a/paddle/infrt/naive/infershaped/infershaped_kernel_launchers.cc
+++ b/paddle/infrt/naive/infershaped/infershaped_kernel_launchers.cc
@@ -13,12 +13,18 @@
 // limitations under the License.
 
 #include "paddle/infrt/naive/infershaped/infershaped_kernel_launchers.h"
-
 #include "paddle/infrt/naive/infershaped/elementwise_add.h"
 #include "paddle/infrt/naive/infershaped/infershaped_registry.h"
+
 namespace infrt {
 namespace naive {
 
+using ElementwiseAddLauncher =
+    KernelLauncher<decltype(&ElementwiseAdd),
+                   &ElementwiseAdd,
+                   decltype(&ElementwiseAddInferShape),
+                   &ElementwiseAddInferShape>;
+
 void RegisterInferShapeLaunchers(InferShapedKernelRegistry* registry) {
   registry->AddKernel("elementwise_add",
                       INFERSHAPED_KERNEL_CREATOR(ElementwiseAddLauncher));
diff --git a/paddle/infrt/naive/infershaped/infershaped_utils.h b/paddle/infrt/naive/infershaped/infershaped_utils.h
new file mode 100644
index 00000000000000..8155d87231a8f9
--- /dev/null
+++ b/paddle/infrt/naive/infershaped/infershaped_utils.h
@@ -0,0 +1,77 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <type_traits>
+#include "paddle/infrt/tensor/dense_host_tensor.h"
+
+namespace infrt {
+namespace naive {
+namespace infershaped {
+
+using KeyType = const tensor::DenseHostTensor&;
+using CountType = uint8_t;
+
+constexpr CountType value(std::true_type) { return 1; }
+
+constexpr CountType value(std::false_type) { return 0; }
+
+template <typename T>
+constexpr CountType value() {
+  return value(std::integral_constant<bool, std::is_same<T, KeyType>::value>{});
+}
+
+template <typename T>
+constexpr CountType count(CountType num) {
+  return num;
+}
+
+template <typename T>
+constexpr CountType count() {
+  return 0;
+}
+
+template <>
+constexpr CountType count<KeyType>(CountType num) {
+  return num + 1;
+}
+
+template <>
+constexpr CountType count<KeyType>() {
+  return 1;
+}
+
+template <typename FirstArg, typename SecondArg, typename... RestOfArgs>
+constexpr CountType count(CountType num) {
+  return count<SecondArg, RestOfArgs...>(num + value<FirstArg>());
+}
+
+template <typename FirstArg, typename SecondArg, typename... RestOfArgs>
+constexpr CountType count() {
+  return count<SecondArg, RestOfArgs...>(value<FirstArg>());
+}
+
+}  // namespace infershaped
+
+template <typename F>
+struct InferShapeHelper;
+
+template <typename Return, typename... Args>
+struct InferShapeHelper<Return (*)(Args...)> {
+  static constexpr int count = infershaped::count<Args...>();
+};
+
+}  // namespace naive
+}  // namespace infrt