Skip to content

Commit b3e049f

Browse files
authored
infershaped autogen (PR #1), test=develop (PaddlePaddle#39405)
1 parent 1bd7a14 commit b3e049f

7 files changed

Lines changed: 123 additions & 32 deletions

paddle/infrt/naive/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,4 +4,5 @@ cc_library(infrt_naive SRCS meta_tensor.cc
44
infershaped/infershaped_kernel_launchers.cc
55
)
66

7-
cc_test_tiny(test_infrt_infershape_launchers SRCS infershaped/infershape_launchers_test.cc DEPS infrt)
7+
cc_test_tiny(test_infrt_infershape_launchers SRCS
8+
infershaped/infershape_launchers_test.cc DEPS infrt)

paddle/infrt/naive/infershaped/elementwise_add.h

Lines changed: 14 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717

1818
#include "paddle/infrt/host_context/kernel_utils.h"
1919
#include "paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h"
20+
#include "paddle/infrt/naive/infershaped/infershaped_utils.h"
2021

2122
// This file contains a example of the infershape ElementwiseAdd kernel.
2223
// Some of the following code should be generated from PTEN by script.
@@ -32,39 +33,36 @@ static void ElementwiseAddInferShape(const MetaTensor& a,
3233
*c->mutable_shape() = a.shape();
3334
}
3435

35-
static void ElementwiseAdd(const tensor::DenseHostTensor& a,
36+
static void ElementwiseAdd(tensor::DenseHostTensor* /*Context*/,
37+
const tensor::DenseHostTensor& a,
3638
const tensor::DenseHostTensor& b,
3739
tensor::DenseHostTensor* c) {}
3840

39-
// TODO(zhiqiang) This class should be generated by a script offline.
40-
class ElementwiseAddLauncher : public InferShapedKernelLauncher {
41+
template <typename KernelFunc,
42+
KernelFunc kernel,
43+
typename InferShapedFunc,
44+
InferShapedFunc infershape>
45+
class KernelLauncher : public InferShapedKernelLauncher {
4146
public:
42-
static const uint16_t input_tensor_indices[2];
43-
static const uint16_t num_input_tensors{2};
47+
static const uint16_t num_input_tensors{InferShapeHelper<KernelFunc>::count};
4448
static const bool turn_on_infer_shape_cache{true};
45-
4649
void Invoke(host_context::KernelFrame* frame) override {
4750
// Build the infershape KernelFrame if needed.
4851
// TODO(Superjomn) add unlikely here.
4952
if (infershape_kernel_frame_builder.IsEmpty()) {
5053
CreateKernelFrameForInferShape(frame);
5154
}
5255
if (turn_on_infer_shape_cache) {
53-
if (IsShapeChanged(input_tensor_indices, num_input_tensors)) {
54-
INFRT_KERNEL(ElementwiseAddInferShape)
55-
(&infershape_kernel_frame_builder);
56-
BuildInferShapeCache(input_tensor_indices, num_input_tensors);
56+
if (!turn_on_infer_shape_cache || IsShapeChanged(num_input_tensors)) {
57+
::infrt::host_context::KernelImpl<InferShapedFunc, infershape>::Invoke(
58+
&infershape_kernel_frame_builder);
59+
BuildInferShapeCache(num_input_tensors);
5760
}
58-
} else {
59-
INFRT_KERNEL(ElementwiseAddInferShape)(&infershape_kernel_frame_builder);
60-
BuildInferShapeCache(input_tensor_indices, num_input_tensors);
6161
}
6262

63-
INFRT_KERNEL(ElementwiseAdd)(frame);
63+
::infrt::host_context::KernelImpl<KernelFunc, kernel>::Invoke(frame);
6464
}
6565
};
6666

67-
const uint16_t ElementwiseAddLauncher::input_tensor_indices[2] = {0, 1};
68-
6967
} // namespace naive
7068
} // namespace infrt

paddle/infrt/naive/infershaped/infershape_launchers_test.cc

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,11 +17,24 @@
1717
#include "paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h"
1818
#include "paddle/infrt/naive/infershaped/infershaped_kernel_launchers.h"
1919
#include "paddle/infrt/naive/infershaped/infershaped_registry.h"
20+
#include "paddle/infrt/naive/infershaped/infershaped_utils.h"
2021
#include "paddle/infrt/tensor/dense_host_tensor.h"
2122

2223
namespace infrt {
2324
namespace naive {
2425

26+
namespace {
27+
static void ElementwiseAddTest(const tensor::DenseHostTensor& a,
28+
const tensor::DenseHostTensor& b,
29+
tensor::DenseHostTensor* c);
30+
}
31+
32+
TEST(utils, registry) {
33+
constexpr uint8_t count =
34+
InferShapeHelper<decltype(&ElementwiseAddTest)>::count;
35+
CHECK_EQ(count, 2U);
36+
}
37+
2538
TEST(ElementwiseAdd, registry) {
2639
InferShapedKernelRegistry registry;
2740
RegisterInferShapeLaunchers(&registry);
@@ -35,6 +48,7 @@ TEST(ElementwiseAdd, registry) {
3548
tensor::DenseHostTensor c({2, 8}, GetDType<float>());
3649

3750
host_context::KernelFrameBuilder kernel_frame_builder;
51+
kernel_frame_builder.AddArgument(new host_context::Value(0));
3852
kernel_frame_builder.AddArgument(new host_context::Value(std::move(a)));
3953
kernel_frame_builder.AddArgument(new host_context::Value(std::move(b)));
4054
kernel_frame_builder.SetResults({new host_context::Value(std::move(c))});

paddle/infrt/naive/infershaped/infershaped_kernel_launcher.cc

Lines changed: 7 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@ namespace naive {
2020
void InferShapedKernelLauncher::CreateKernelFrameForInferShape(
2121
host_context::KernelFrame* frame) {
2222
for (host_context::Value* value :
23-
frame->GetValues(0, frame->GetNumElements())) {
23+
frame->GetValues(1, frame->GetNumElements() - 1)) {
2424
// TODO(Superjomn) To extend this.
2525
if (value->is_type<tensor::DenseHostTensor>()) {
2626
values.emplace_back(MetaTensor{&value->get<tensor::DenseHostTensor>()});
@@ -32,27 +32,24 @@ void InferShapedKernelLauncher::CreateKernelFrameForInferShape(
3232
}
3333

3434
void InferShapedKernelLauncher::BuildInferShapeCache(
35-
const uint16_t* input_indices, const uint16_t num_inputs) {
35+
const uint16_t num_inputs) {
3636
tensor_shape_cache.resize(num_inputs);
3737
for (uint16_t i = 0; i < num_inputs; i++) {
3838
tensor_shape_cache[i] =
39-
infershape_kernel_frame_builder.GetArgAt(input_indices[i])
40-
->get<MetaTensor>()
41-
.shape();
39+
infershape_kernel_frame_builder.GetArgAt(i)->get<MetaTensor>().shape();
4240
}
4341
}
4442

4543
bool InferShapedKernelLauncher::IsShapeChanged(
46-
const uint16_t* input_indices, const uint16_t num_inputs) const {
44+
const uint16_t num_inputs) const {
4745
if (tensor_shape_cache.empty() && !infershape_kernel_frame_builder.IsEmpty())
4846
return true;
4947

5048
bool changed = false;
5149
for (uint16_t i = 0; i < num_inputs && !changed; i++) {
52-
changed = changed || (tensor_shape_cache[i] !=
53-
infershape_kernel_frame_builder
54-
.GetArgAt<MetaTensor>(input_indices[i])
55-
.shape());
50+
changed = changed ||
51+
(tensor_shape_cache[i] !=
52+
infershape_kernel_frame_builder.GetArgAt<MetaTensor>(i).shape());
5653
}
5754
return changed;
5855
}

paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -39,12 +39,10 @@ struct InferShapedKernelLauncher {
3939

4040
//! Build or update the infer-shape cache using the latest shape from
4141
//! InferShapeFrame.
42-
void BuildInferShapeCache(const uint16_t* input_indices,
43-
const uint16_t num_inputs);
42+
void BuildInferShapeCache(const uint16_t num_inputs);
4443

4544
//! Compare the latest shape with the shape cache.
46-
bool IsShapeChanged(const uint16_t* input_indices,
47-
const uint16_t num_inputs) const;
45+
bool IsShapeChanged(const uint16_t num_inputs) const;
4846

4947
// values to hold the TensorMeta.
5048
llvm::SmallVector<host_context::ValueRef, 3> values;

paddle/infrt/naive/infershaped/infershaped_kernel_launchers.cc

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,12 +13,18 @@
1313
// limitations under the License.
1414

1515
#include "paddle/infrt/naive/infershaped/infershaped_kernel_launchers.h"
16-
1716
#include "paddle/infrt/naive/infershaped/elementwise_add.h"
1817
#include "paddle/infrt/naive/infershaped/infershaped_registry.h"
18+
1919
namespace infrt {
2020
namespace naive {
2121

22+
using ElementwiseAddLauncher =
23+
KernelLauncher<decltype(&ElementwiseAdd),
24+
&ElementwiseAdd,
25+
decltype(&ElementwiseAddInferShape),
26+
&ElementwiseAddInferShape>;
27+
2228
void RegisterInferShapeLaunchers(InferShapedKernelRegistry* registry) {
2329
registry->AddKernel("elementwise_add",
2430
INFERSHAPED_KERNEL_CREATOR(ElementwiseAddLauncher));
Lines changed: 77 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,77 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <type_traits>
18+
#include "paddle/infrt/tensor/dense_host_tensor.h"
19+
20+
namespace infrt {
21+
namespace naive {
22+
namespace infershaped {
23+
24+
using KeyType = const tensor::DenseHostTensor&;
25+
using CountType = uint8_t;
26+
27+
constexpr CountType value(std::true_type) { return 1; }
28+
29+
constexpr CountType value(std::false_type) { return 0; }
30+
31+
template <typename T>
32+
constexpr CountType value() {
33+
return value(std::integral_constant<bool, std::is_same<T, KeyType>::value>{});
34+
}
35+
36+
template <typename FirstArg>
37+
constexpr CountType count(CountType num) {
38+
return num;
39+
}
40+
41+
template <typename FirstArg>
42+
constexpr CountType count() {
43+
return 0;
44+
}
45+
46+
template <>
47+
constexpr CountType count<KeyType>(CountType num) {
48+
return num + 1;
49+
}
50+
51+
template <>
52+
constexpr CountType count<KeyType>() {
53+
return 1;
54+
}
55+
56+
template <typename FirstArg, typename SecondArg, typename... RestOfArgs>
57+
constexpr CountType count(CountType num) {
58+
return count<SecondArg, RestOfArgs...>(num + value<FirstArg>());
59+
}
60+
61+
template <typename FirstArg, typename SecondArg, typename... RestOfArgs>
62+
constexpr CountType count() {
63+
return count<SecondArg, RestOfArgs...>(value<FirstArg>());
64+
}
65+
66+
} // namespace infershaped
67+
68+
template <typename F>
69+
struct InferShapeHelper;
70+
71+
template <typename Return, typename... Args>
72+
struct InferShapeHelper<Return (*)(Args...)> {
73+
static constexpr int count = infershaped::count<Args...>();
74+
};
75+
76+
} // namespace naive
77+
} // namespace infrt

0 commit comments

Comments (0)