Skip to content

Commit 7e659a0

Browse files
authored
Merge pull request #6932 from dzhwinter/fix/kernelkey
"remove hash combine"
2 parents 37e9626 + a521ace commit 7e659a0

File tree

2 files changed

+13
-23
lines changed

2 files changed

+13
-23
lines changed

paddle/framework/op_kernel_type.h

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -22,33 +22,23 @@ limitations under the License. */
2222
namespace paddle {
2323
namespace framework {
2424

25-
/*
26-
Refer to https://stackoverflow.com/questions/35985960/
27-
c-why-is-boosthash-combine-the-best-way-to-combine-hash-values
28-
*/
29-
template <class T>
30-
inline void HashCombine(const T& v, std::size_t* seed) {
31-
std::hash<T> hasher;
32-
*seed ^= hasher(v) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
33-
}
34-
3525
struct OpKernelType {
3626
struct Hash {
3727
size_t operator()(const OpKernelType& key) const {
38-
int place = key.place_.which();
39-
int data_type = static_cast<int>(key.data_type_);
40-
int data_layout = static_cast<int>(key.data_layout_);
41-
int library_type = static_cast<int>(key.library_type_);
42-
43-
size_t seed = 0;
44-
HashCombine(place, &seed);
45-
HashCombine(data_type, &seed);
46-
HashCombine(data_layout, &seed);
47-
HashCombine(library_type, &seed);
48-
return seed;
28+
int place = key.place_.which() + (1 << LEFT_SHIFT);
29+
int data_type =
30+
static_cast<int>(key.data_type_) + (1 << (LEFT_SHIFT + 1));
31+
int data_layout =
32+
static_cast<int>(key.data_layout_) + (1 << (LEFT_SHIFT + 2));
33+
int library_type =
34+
static_cast<int>(key.library_type_) + (1 << (LEFT_SHIFT + 3));
35+
std::hash<int> hasher;
36+
return hasher(place + data_type + data_layout + library_type);
4937
}
5038
};
5139

40+
// place, data_type, library_type kinds less than 2^8
41+
constexpr static int LEFT_SHIFT = 8;
5242
proto::DataType data_type_;
5343
DataLayout data_layout_;
5444
platform::Place place_;

paddle/platform/device_context.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,11 @@ class DeviceContextPool {
137137

138138
private:
139139
static DeviceContextPool* pool;
140+
constexpr static int LEFT_SHIFT = 8;
140141
struct Hash {
141142
std::hash<int> hash_;
142143
size_t operator()(const platform::Place& place) const {
143-
int pre_hash = place.which()
144-
<< (sizeof(int) * 8 - NUM_PLACE_TYPE_LIMIT_IN_BIT);
144+
int pre_hash = place.which() + (1 << LEFT_SHIFT);
145145
if (platform::is_gpu_place(place)) {
146146
pre_hash += boost::get<platform::GPUPlace>(place).GetDeviceId();
147147
}

0 commit comments

Comments
 (0)