1717#include < ostream>
1818#include < string>
1919#include < unordered_map>
20+ #include < unordered_set>
2021#include < utility>
2122
2223#include " paddle/pten/common/backend.h"
@@ -37,10 +38,10 @@ using DataLayout = paddle::experimental::DataLayout;
3738/* *
3839 * [ Naming considerations ]
3940 *
40- * The tensor Compute library contains many kernels, and the computation
41+ * The tensor operation library contains many kernels, and the computation
4142 * in each specific scenario is represented by a kernel.
4243 *
43- * We directly named it `Kernel` instead of `Kernel`, the tensor Compute
44+ * We directly named it `Kernel` instead of `Kernel`, the tensor operation
4445 * library here and fluid are independent, which keeps developers from
4546 * misunderstanding the relationship between the two concepts.
4647 */
@@ -52,10 +53,7 @@ using KernelFn = void (*)(KernelContext* ctx);
5253class KernelName final {
5354 public:
5455 KernelName (std::string name, std::string overload_name)
55- : name_(std::move(name)), overload_name_(std::move(overload_name)) {
56- hash_value_ = std::hash<std::string>()(name_) ^
57- (std::hash<std::string>()(overload_name_) << 1 );
58- }
56+ : name_(std::move(name)), overload_name_(std::move(overload_name)) {}
5957
6058 KernelName (const std::string& kernel_name) {
6159 ParseNameAndOverloadNameFromString (kernel_name);
@@ -68,24 +66,26 @@ class KernelName final {
6866
6967 const std::string& name () const { return name_; }
7068 const std::string& overload_name () const { return overload_name_; }
71- size_t hash_value () const { return hash_value_; }
7269
7370 struct Hash {
7471 size_t operator ()(const KernelName& kernel_name) const {
75- return kernel_name.hash_value ();
72+ return std::hash<std::string>()(kernel_name.name ()) ^
73+ (std::hash<std::string>()(kernel_name.overload_name ()) << 1 );
7674 }
7775 };
7876
77+ size_t hash_value () const { return Hash ()(*this ); }
78+
7979 bool operator <(const KernelName& kernel_name) const {
80- return hash_value_ < kernel_name.hash_value ();
80+ return hash_value () < kernel_name.hash_value ();
8181 }
8282
8383 bool operator ==(const KernelName& kernel_name) const {
84- return hash_value_ == kernel_name.hash_value ();
84+ return hash_value () == kernel_name.hash_value ();
8585 }
8686
8787 bool operator !=(const KernelName& kernel_name) const {
88- return hash_value_ != kernel_name.hash_value ();
88+ return hash_value () != kernel_name.hash_value ();
8989 }
9090
9191 private:
@@ -98,57 +98,45 @@ class KernelName final {
9898 name_ = kernel_name.substr (0 , pos);
9999 overload_name_ = kernel_name.substr (pos + 1 , kernel_name.size ());
100100 }
101- hash_value_ = std::hash<std::string>()(name_) ^
102- (std::hash<std::string>()(overload_name_) << 1 );
103101 }
104102
105- // The members cannot be modified except by constructing,
106- // because the hash value need to be re calculated
107- // TODO(chenweihang): use string_view later?
103+ // TODO(chenweihang): use string_view to improve performance later
108104 std::string name_;
109105 std::string overload_name_;
110- // Avoid calculating Hash value at runtime
111- size_t hash_value_;
112106};
113107
114108class KernelKey {
115109 public:
116110 KernelKey () = default ;
117111
118112 KernelKey (Backend backend, DataLayout layout, DataType dtype)
119- : backend_(backend), layout_(layout), dtype_(dtype) {
120- // |----31-20------|---19-12---|---11-8----|---7-0---|
121- // | For extension | DataType | DataLayout | Backend |
122-
123- hash_value_ = 0 ;
124- hash_value_ |= static_cast <uint8_t >(backend_);
125- hash_value_ |= (static_cast <uint8_t >(layout_) << kBackendBitLength );
126- hash_value_ |= (static_cast <uint16_t >(dtype_)
127- << (kBackendBitLength + kDataTypeBitLength ));
128- }
113+ : backend_(backend), layout_(layout), dtype_(dtype) {}
129114
130115 Backend backend () const { return backend_; }
131116 DataLayout layout () const { return layout_; }
132117 DataType dtype () const { return dtype_; }
133118
134- uint32_t hash_value () const { return hash_value_; }
119+ struct Hash {
120+ // Note: Now the number of bits we need does not exceed 32 bits, so there is
121+ // no need to use 64 bits. If needed in the future, it can be expanded,
122+ // but now we don't over-design.
123+ uint32_t operator ()(const KernelKey& key) const ;
124+ };
125+
126+ uint32_t hash_value () const { return Hash ()(*this ); }
135127
136128 bool operator <(const KernelKey& key) const {
137- return hash_value_ < key.hash_value ();
129+ return hash_value () < key.hash_value ();
138130 }
139131
140132 bool operator ==(const KernelKey& key) const {
141- return hash_value_ == key.hash_value ();
133+ return hash_value () == key.hash_value ();
142134 }
143135
144136 bool operator !=(const KernelKey& key) const {
145- return hash_value_ != key.hash_value ();
137+ return hash_value () != key.hash_value ();
146138 }
147139
148- struct Hash {
149- uint32_t operator ()(const KernelKey& key) const { return key.hash_value (); }
150- };
151-
152140 private:
153141 // In total should be smaller than 32.
154142 constexpr static int kBackendBitLength = 8 ;
@@ -158,12 +146,6 @@ class KernelKey {
158146 Backend backend_{Backend::UNDEFINED};
159147 DataLayout layout_{DataLayout::UNDEFINED};
160148 DataType dtype_{DataType::UNDEFINED};
161-
162- // Avoid calculating Hash value at runtime.
163- // Note: Now the number of bits we need does not exceed 32 bits, so there is
164- // no need to use 64 bits. If needed in the future, it can be expanded,
165- // but now we don’t over-design.
166- uint32_t hash_value_;
167149};
168150
169151// TODO(chenweihang): how to deal with vector<Param>?
@@ -282,7 +264,13 @@ class KernelFactory {
282264
283265 KernelMap& kernels () { return kernels_; }
284266
285- bool ContainsKernel (const char * name) const ;
267+ void InsertCompatibleOpType (const std::string& op_type) {
268+ compatible_op_types_.insert (op_type);
269+ }
270+
271+ bool HasCompatiblePtenKernel (const std::string& op_type) const {
272+ return compatible_op_types_.count (op_type) > 0 ;
273+ }
286274
287275 const Kernel& SelectKernelOrThrowError (const KernelName& kernel_name,
288276 const KernelKey& kernel_key) const ;
@@ -299,6 +287,9 @@ class KernelFactory {
299287 KernelFactory () = default ;
300288
301289 KernelMap kernels_;
290+ // Kept for compatibility with the original execution system: used to
291+ // quickly confirm whether the new kernel can be called
292+ std::unordered_set<std::string> compatible_op_types_;
302293};
303294
304295/* * operator << overload **/
0 commit comments