Commit fb224ab

polish kernel factory and kernel registry
1 parent 76a588e commit fb224ab

File tree

6 files changed: +54 −71 lines changed


paddle/fluid/framework/operator.cc

Lines changed: 4 additions & 21 deletions

@@ -1080,20 +1080,6 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
   this->InferShape(&infer_shape_ctx);
 }
 
-static std::string RuntimeContextDebugString(const RuntimeContext& ctx) {
-  std::stringstream ss;
-  ss << "RuntimeContext(Inputs: ";
-  for (auto& var_pair : ctx.inputs) {
-    ss << var_pair.first << ", ";
-  }
-  ss << "Outputs: ";
-  for (auto& var_pair : ctx.outputs) {
-    ss << var_pair.first << ", ";
-  }
-  ss << ")";
-  return ss.str();
-}
-
 void OperatorWithKernel::RunImpl(const Scope& scope,
                                  const platform::Place& place) const {
   // To reduce the elapsed time of HasAttr, we use bool variable to record the

@@ -1144,7 +1130,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   // and RCOM backend, the XPU, NPU and MKLDNN will be supported in the second
   // phase
   if (FLAGS_run_pt_kernel &&
-      pten::KernelFactory::Instance().ContainsKernel(type_.c_str())) {
+      pten::KernelFactory::Instance().HasCompatiblePtenKernel(type_)) {
     if (pt_kernel_signature_.get() == nullptr || pt_kernel_.get() == nullptr) {
       ChoosePtenKernel(exe_ctx);
     }

@@ -1651,10 +1637,9 @@ void OperatorWithKernel::ParseInputDataType(
     if (t != nullptr) {
       PADDLE_ENFORCE_EQ(
           t->IsInitialized(), true,
-          platform::errors::InvalidArgument(
-              "The Tensor in the %s Op's Input Variable %s(%s) is "
-              "not initialized.",
-              Type(), name, Inputs().at(name).at(i)));
+          platform::errors::InvalidArgument("The %s Op's Input Variable `%s` "
+                                            "contains uninitialized Tensor.",
+                                            Type(), name));
       proto::VarType::Type tmp = t->type();
       PADDLE_ENFORCE(tmp == *data_type || *data_type == default_data_type,
                      platform::errors::InvalidArgument(

@@ -1789,8 +1774,6 @@ KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(
 
 pten::KernelContext OperatorWithKernel::BuildPtenKernelContext(
     const RuntimeContext& ctx, const platform::DeviceContext& dev_ctx) const {
-  VLOG(1) << RuntimeContextDebugString(ctx);
-
   // TODO(chenweihang): now only work for very simple case,
   // many cases need to be deal with later:
   // 1. the input and output are not tensor

paddle/fluid/imperative/prepared_operator.cc

Lines changed: 1 addition & 1 deletion

@@ -153,7 +153,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
   VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
 
   if (FLAGS_run_pt_kernel &&
-      pten::KernelFactory::Instance().ContainsKernel(op.Type().c_str())) {
+      pten::KernelFactory::Instance().HasCompatiblePtenKernel(op.Type())) {
     auto pt_kernel_signature = op.GetExpectedPtenKernelArgs(dygraph_exe_ctx);
 
     VLOG(1) << framework::KernelSignatureToString(pt_kernel_signature);

paddle/fluid/pybind/op_function_generator.cc

Lines changed: 1 addition & 1 deletion

@@ -557,7 +557,7 @@ GenerateOpFunctions() {
   // since only OperatorWithKernel can run in dygraph mode.
   // if the pten lib contains op kernel, we still generate ops method
   if (!all_kernels.count(op_type) &&
-      !pten::KernelFactory::Instance().ContainsKernel(op_type.c_str())) {
+      !pten::KernelFactory::Instance().HasCompatiblePtenKernel(op_type)) {
     continue;
   }

paddle/pten/core/kernel_factory.cc

Lines changed: 13 additions & 5 deletions

@@ -19,16 +19,24 @@
 
 namespace pten {
 
+uint32_t KernelKey::Hash::operator()(const KernelKey& key) const {
+  uint32_t hash_value = 0;
+  // |----31-20------|---19-12---|---11-8----|---7-0---|
+  // | For extension | DataType  | DataLayout | Backend |
+  hash_value |= static_cast<uint8_t>(key.backend());
+  hash_value |=
+      (static_cast<uint8_t>(key.layout()) << KernelKey::kBackendBitLength);
+  hash_value |=
+      (static_cast<uint16_t>(key.dtype())
+       << (KernelKey::kBackendBitLength + KernelKey::kDataTypeBitLength));
+  return hash_value;
+}
+
 KernelFactory& KernelFactory::Instance() {
   static KernelFactory g_op_kernel_factory;
   return g_op_kernel_factory;
 }
 
-bool KernelFactory::ContainsKernel(const char* kernel_name) const {
-  auto iter = kernels_.find(KernelName(kernel_name, ""));
-  return (iter != kernels_.end());
-}
-
 Kernel KernelFactory::SelectKernel(const KernelName& kernel_name,
                                    const KernelKey& kernel_key) const {
   auto iter = kernels_.find(kernel_name);
paddle/pten/core/kernel_factory.h

Lines changed: 34 additions & 43 deletions

@@ -17,6 +17,7 @@
 #include <ostream>
 #include <string>
 #include <unordered_map>
+#include <unordered_set>
 #include <utility>
 
 #include "paddle/pten/common/backend.h"

@@ -37,10 +38,10 @@ using DataLayout = paddle::experimental::DataLayout;
 /**
  * [ Naming considerations ]
  *
- * The tensor Compute library contains many kernels, and the computation
+ * The tensor operation library contains many kernels, and the computation
  * in each specific scenario is represented by an kernel.
  *
- * We directly named it `Kernel` instead of `Kernel`, the tensor Compute
+ * We directly named it `Kernel` instead of `Kernel`, the tensor operation
 * library here and fluid are independent, avoiding developers from
 * misunderstanding the relationship between the two concepts.
 */

@@ -52,10 +53,7 @@ using KernelFn = void (*)(KernelContext* ctx);
 class KernelName final {
  public:
   KernelName(std::string name, std::string overload_name)
-      : name_(std::move(name)), overload_name_(std::move(overload_name)) {
-    hash_value_ = std::hash<std::string>()(name_) ^
-                  (std::hash<std::string>()(overload_name_) << 1);
-  }
+      : name_(std::move(name)), overload_name_(std::move(overload_name)) {}
 
   KernelName(const std::string& kernel_name) {
     ParseNameAndOverloadNameFromString(kernel_name);

@@ -68,24 +66,26 @@ class KernelName final {
 
   const std::string& name() const { return name_; }
   const std::string& overload_name() const { return overload_name_; }
-  size_t hash_value() const { return hash_value_; }
 
   struct Hash {
     size_t operator()(const KernelName& kernel_name) const {
-      return kernel_name.hash_value();
+      return std::hash<std::string>()(kernel_name.name()) ^
+             (std::hash<std::string>()(kernel_name.overload_name()) << 1);
     }
   };
 
+  size_t hash_value() const { return Hash()(*this); }
+
   bool operator<(const KernelName& kernel_name) const {
-    return hash_value_ < kernel_name.hash_value();
+    return hash_value() < kernel_name.hash_value();
   }
 
   bool operator==(const KernelName& kernel_name) const {
-    return hash_value_ == kernel_name.hash_value();
+    return hash_value() == kernel_name.hash_value();
   }
 
   bool operator!=(const KernelName& kernel_name) const {
-    return hash_value_ != kernel_name.hash_value();
+    return hash_value() != kernel_name.hash_value();
   }
 
  private:

@@ -98,57 +98,45 @@ class KernelName final {
       name_ = kernel_name.substr(0, pos);
       overload_name_ = kernel_name.substr(pos + 1, kernel_name.size());
     }
-    hash_value_ = std::hash<std::string>()(name_) ^
-                  (std::hash<std::string>()(overload_name_) << 1);
   }
 
-  // The members cannot be modified except by constructing,
-  // because the hash value need to be re calculated
-  // TODO(chenweihang): use string_view later?
+  // TODO(chenweihang): use string_view to improve performance later
   std::string name_;
   std::string overload_name_;
-  // Avoid calculating Hash value at runtime
-  size_t hash_value_;
 };
 
 class KernelKey {
  public:
   KernelKey() = default;
 
   KernelKey(Backend backend, DataLayout layout, DataType dtype)
-      : backend_(backend), layout_(layout), dtype_(dtype) {
-    // |----31-20------|---19-12---|---11-8----|---7-0---|
-    // | For extension | DataType | DataLayout | Backend |
-
-    hash_value_ = 0;
-    hash_value_ |= static_cast<uint8_t>(backend_);
-    hash_value_ |= (static_cast<uint8_t>(layout_) << kBackendBitLength);
-    hash_value_ |= (static_cast<uint16_t>(dtype_)
-                    << (kBackendBitLength + kDataTypeBitLength));
-  }
+      : backend_(backend), layout_(layout), dtype_(dtype) {}
 
   Backend backend() const { return backend_; }
   DataLayout layout() const { return layout_; }
   DataType dtype() const { return dtype_; }
 
-  uint32_t hash_value() const { return hash_value_; }
+  struct Hash {
+    // Note: Now the number of bits we need does not exceed 32 bits, so there is
+    // no need to use 64 bits. If needed in the future, it can be expanded,
+    // but now we don’t over-design.
+    uint32_t operator()(const KernelKey& key) const;
+  };
+
+  uint32_t hash_value() const { return Hash()(*this); }
 
   bool operator<(const KernelKey& key) const {
-    return hash_value_ < key.hash_value();
+    return hash_value() < key.hash_value();
   }
 
   bool operator==(const KernelKey& key) const {
-    return hash_value_ == key.hash_value();
+    return hash_value() == key.hash_value();
   }
 
   bool operator!=(const KernelKey& key) const {
-    return hash_value_ != key.hash_value();
+    return hash_value() != key.hash_value();
   }
 
-  struct Hash {
-    uint32_t operator()(const KernelKey& key) const { return key.hash_value(); }
-  };
-
  private:
   // In total should be smaller than 32.
   constexpr static int kBackendBitLength = 8;

@@ -158,12 +146,6 @@ class KernelKey {
   Backend backend_{Backend::UNDEFINED};
   DataLayout layout_{DataLayout::UNDEFINED};
   DataType dtype_{DataType::UNDEFINED};
-
-  // Avoid calculating Hash value at runtime.
-  // Note: Now the number of bits we need does not exceed 32 bits, so there is
-  // no need to use 64 bits. If needed in the future, it can be expanded,
-  // but now we don’t over-design.
-  uint32_t hash_value_;
 };
 
 // TODO(chenweihang): how deal with vector<Param>?

@@ -282,7 +264,13 @@ class KernelFactory {
 
   KernelMap& kernels() { return kernels_; }
 
-  bool ContainsKernel(const char* name) const;
+  void InsertCompatibleOpType(const std::string& op_type) {
+    compatible_op_types_.insert(op_type);
+  }
+
+  bool HasCompatiblePtenKernel(const std::string& op_type) const {
+    return compatible_op_types_.count(op_type) > 0;
+  }
 
   const Kernel& SelectKernelOrThrowError(const KernelName& kernel_name,
                                          const KernelKey& kernel_key) const;

@@ -299,6 +287,9 @@ class KernelFactory {
   KernelFactory() = default;
 
   KernelMap kernels_;
+  // Used to be compatible with the original execution system and
+  // quickly confirm whether the new kernel can be called
+  std::unordered_set<std::string> compatible_op_types_;
 };
 
 /** operator << overload **/
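
Two structural changes run through this header: the cached hash_value_ members are gone from both KernelName and KernelKey, with the nested Hash functors now the single source of truth (hash_value() simply calls Hash()(*this)), and KernelFactory gains the compatible_op_types_ set behind the new InsertCompatibleOpType/HasCompatiblePtenKernel pair. A minimal sketch of the hash pattern, using illustrative names rather than pten's actual types:

#include <functional>
#include <string>
#include <unordered_map>

// Sketch of the pattern KernelName now follows: no cached hash field;
// the nested Hash functor owns the computation, and everything else
// (hash_value(), the comparison operators) recomputes on demand.
class DemoName final {
 public:
  DemoName(std::string name, std::string overload_name)
      : name_(std::move(name)), overload_name_(std::move(overload_name)) {}

  const std::string& name() const { return name_; }
  const std::string& overload_name() const { return overload_name_; }

  struct Hash {
    size_t operator()(const DemoName& n) const {
      // Same combining scheme as KernelName::Hash in the diff:
      // XOR the two string hashes, shifting the second by one bit.
      return std::hash<std::string>()(n.name()) ^
             (std::hash<std::string>()(n.overload_name()) << 1);
    }
  };

  size_t hash_value() const { return Hash()(*this); }

  // As in the diff, equality is defined on the hash values themselves.
  bool operator==(const DemoName& other) const {
    return hash_value() == other.hash_value();
  }

 private:
  std::string name_;
  std::string overload_name_;
};

// The functor slots directly into unordered containers, mirroring how
// KernelFactory's KernelMap stays keyed after the refactor.
using DemoMap = std::unordered_map<DemoName, int, DemoName::Hash>;

The trade-off is deliberate: each lookup now rehashes two strings, but the members shed the "cannot be modified except by constructing" caveat that the old cached field imposed.
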

paddle/pten/core/kernel_registry.h

Lines changed: 1 addition & 0 deletions

@@ -149,6 +149,7 @@ struct KernelRegistrar {
     args_parse_fn(kernel_key, kernel.mutable_args_def());
     args_def_fn(&kernel);
 
+    KernelFactory::Instance().InsertCompatibleOpType(kernel_name.name());
     KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel;
   }
 };
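
Taken together with the kernel_factory.h changes, registration now does double duty: every registrar expansion records the kernel's base name in compatible_op_types_, and the fluid side (RunImpl, PrepareImpl, the pybind generator) asks HasCompatiblePtenKernel instead of probing the kernel map with an empty overload name, as the removed ContainsKernel did, which could miss kernels registered under a non-empty overload name. A toy model of the flow, with simplified types assumed for illustration:

#include <iostream>
#include <string>
#include <unordered_set>

// Toy factory keeping only the compatibility set; the real
// KernelFactory also owns the (KernelName, KernelKey) -> Kernel map.
class DemoFactory {
 public:
  static DemoFactory& Instance() {
    static DemoFactory factory;  // same Meyers-singleton shape as pten's
    return factory;
  }

  // Called by the registrar during static initialization.
  void InsertCompatibleOpType(const std::string& op_type) {
    compatible_op_types_.insert(op_type);
  }

  // Called on the dispatch path; a plain string-set lookup.
  bool HasCompatiblePtenKernel(const std::string& op_type) const {
    return compatible_op_types_.count(op_type) > 0;
  }

 private:
  DemoFactory() = default;
  std::unordered_set<std::string> compatible_op_types_;
};

int main() {
  // A registrar for a hypothetical "sign" kernel would run:
  DemoFactory::Instance().InsertCompatibleOpType("sign");

  std::cout << DemoFactory::Instance().HasCompatiblePtenKernel("sign")    // 1
            << DemoFactory::Instance().HasCompatiblePtenKernel("matmul")  // 0
            << std::endl;
  return 0;
}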
