Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/framework/library_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace framework {
// For more details about the design of LibraryType, Please refer to
// https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/operator_kernel_type.md#library

enum LibraryType { kPlain = 0; kMKLDNN = 1; kCUDNN = 2; }
// Library selector that forms part of a kernel's type. Every enumerator has
// an explicitly pinned integer value.
enum LibraryType {
  kPlain = 0,
  kMKLDNN = 1,
  kCUDNN = 2,
};

}  // namespace framework
}  // namespace paddle
82 changes: 82 additions & 0 deletions paddle/framework/op_kernel_type.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/data_layout.h"
#include "paddle/framework/data_type.h"
#include "paddle/framework/library_type.h"
#include "paddle/platform/place.h"

namespace paddle {
namespace framework {

/*
Refer to https://stackoverflow.com/questions/35985960/
c-why-is-boosthash-combine-the-best-way-to-combine-hash-values
*/
// Folds the hash of `v` into `seed` in place, so that several values can be
// reduced to one hash by calling this once per value.
//
// seed: running hash; it is read and then overwritten with the combined value.
// v:    value to mix in. It is hashed with std::hash<T>, so T must have a
//       std::hash specialization.
template <class T>
inline void hash_combine(std::size_t& seed, const T& v) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Google Style.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

std::hash<T> hasher;
// XOR the seed with hash(v) plus a constant plus two shifted copies of the
// old seed; the outcome therefore depends on the order of the calls.
seed ^= hasher(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

// Describes one kernel variant by four fields: data type, data layout, place
// and library type. It provides a nested Hash functor and an operator== over
// those fields.
// NOTE(review): the second constructor takes a platform::DeviceContext, but
// device_context.h is not among the headers this file includes -- confirm the
// type is declared transitively or add the include.
struct OpKernelType {
// Hash functor over OpKernelType: each field is converted to an int and
// folded into a single seed with hash_combine.
struct Hash {
size_t operator()(const OpKernelType& key) const {
// Only the index of the alternative currently held by place_ is hashed, not
// the place value itself.
// NOTE(review): presumably this mirrors operator==, which compares places
// through places_are_same_class -- confirm.
int place = key.place_.which();
int data_type = static_cast<int>(key.data_type_);
int data_layout = static_cast<int>(key.data_layout_);
int library_type = static_cast<int>(key.library_type_);

size_t seed = 0;
hash_combine(seed, place);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about combine these four fields as one integer, then make a hash_combine function call?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can not simply combine these four fields together. HashCombine is the function to do such work.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我的意思是,四次HashCombine调用开销可以省略。

KernelType种类并不多,不会发生碰撞,hash只是让分布均匀一下。
防止碰撞可以这样

constexpr int SHIFT = 8; // every kind of type less than 2^8
int data_type = static_cast<int>(key.data_type_) + 1<<(SHIFT);
int data_layout = static_cast<int>(key.data_layout_) + 1 <<(SHIFT + 1);
int library_type = static_cast<int>(key.library_type_) + 1 << (SHIFT + 2);
int kernel_id = data_type + data_layout + library_type + place.which();
std::hash<int> hasher;
return hasher(kernel_id);

hash_combine(seed, data_type);
hash_combine(seed, data_layout);
hash_combine(seed, library_type);
return seed;
}
};

// The four fields that make up the key.
proto::DataType data_type_;
DataLayout data_layout_;
platform::Place place_;
LibraryType library_type_;

// Builds a key from an explicit place. When omitted, the layout defaults to
// DataLayout::kAnyLayout and the library to LibraryType::kPlain.
OpKernelType(proto::DataType data_type, platform::Place place,
DataLayout data_layout = DataLayout::kAnyLayout,
LibraryType library_type = LibraryType::kPlain)
: data_type_(data_type),
data_layout_(data_layout),
place_(place),
library_type_(library_type) {}

// Same as above, except that the place is taken from dev_ctx.GetPlace().
OpKernelType(proto::DataType data_type,
const platform::DeviceContext& dev_ctx,
DataLayout data_layout = DataLayout::kAnyLayout,
LibraryType library_type = LibraryType::kPlain)
: data_type_(data_type),
data_layout_(data_layout),
place_(dev_ctx.GetPlace()),
library_type_(library_type) {}

// Two keys are equal when places_are_same_class accepts their places and the
// data type, data layout and library type all compare equal.
bool operator==(const OpKernelType& o) const {
return platform::places_are_same_class(place_, o.place_) &&
data_type_ == o.data_type_ && data_layout_ == o.data_layout_ &&
library_type_ == o.library_type_;
}
};

} // namespace framework
} // namespace paddle
5 changes: 3 additions & 2 deletions paddle/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,9 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
}

// Writes a bracketed, colon-separated description of kernel_key to os and
// returns os so that insertions can be chained.
// NOTE(review): this diff listing shows two consecutive `os <<` statements,
// the earlier place/data_type form and the later four-field form -- confirm
// which of them the file actually keeps.
std::ostream& operator<<(std::ostream& os, const OpKernelType& kernel_key) {
os << "place[" << kernel_key.place_ << "]:data_type[" << kernel_key.data_type_
<< "]";
os << "data_type[" << kernel_key.data_type_ << "]:data_layout["
<< kernel_key.data_layout_ << "]:place[" << kernel_key.place_
<< "]:library_type[" << kernel_key.library_type_ << "]";
return os;
}

Expand Down
31 changes: 1 addition & 30 deletions paddle/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,14 @@ limitations under the License. */
#include "glog/logging.h" // For VLOG
#include "paddle/framework/attribute.h"
#include "paddle/framework/block_desc.h"
#include "paddle/framework/data_type.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/op_kernel_type.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
#include "paddle/platform/variant.h"
#include "paddle/utils/Error.h"

Expand Down Expand Up @@ -345,34 +344,6 @@ class OpKernel : public OpKernelBase {
using ELEMENT_TYPE = T;
};

// Pairs a place with a data type. The nested Hash functor and operator==
// allow the pair to serve as a key in hashed containers.
struct OpKernelType {
  // Packs the data type above the low NUM_PLACE_TYPE_LIMIT_IN_BIT bits, which
  // hold the index of the alternative stored in the place, and hashes the
  // packed integer.
  struct Hash {
    std::hash<int> hash_;

    size_t operator()(const OpKernelType& key) const {
      const int place_mask = (1 << NUM_PLACE_TYPE_LIMIT_IN_BIT) - 1;
      const int place_bits = key.place_.which() & place_mask;
      const int type_bits = static_cast<int>(key.data_type_)
                            << NUM_PLACE_TYPE_LIMIT_IN_BIT;
      return hash_(type_bits | place_bits);
    }
  };

  platform::Place place_;
  proto::DataType data_type_;

  // Builds the key from an explicit place.
  OpKernelType(proto::DataType data_type, platform::Place place)
      : place_(place), data_type_(data_type) {}

  // Builds the key from the place reported by the device context.
  OpKernelType(proto::DataType data_type,
               const platform::DeviceContext& dev_ctx)
      : OpKernelType(data_type, dev_ctx.GetPlace()) {}

  // Keys match when places_are_same_class accepts both places and the data
  // types are equal.
  bool operator==(const OpKernelType& o) const {
    if (!platform::places_are_same_class(place_, o.place_)) {
      return false;
    }
    return data_type_ == o.data_type_;
  }
};

class OperatorWithKernel : public OperatorBase {
public:
using OpKernelMap =
Expand Down
10 changes: 0 additions & 10 deletions paddle/platform/place.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,18 +68,8 @@ struct IsMKLDNNPlace : public boost::static_visitor<bool> {
// Visitor overload of IsMKLDNNPlace: a GPUPlace is not an MKLDNN place.
bool operator()(const GPUPlace &) const {
  return false;
}
};

// Maximum number of place types, expressed as a bit length: the number of
// alternatives in Place must be less than or equal to
// 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT).
#define NUM_PLACE_TYPE_LIMIT_IN_BIT 4

// A Place holds exactly one of the listed place alternatives.
typedef boost::variant<CUDNNPlace, GPUPlace, CPUPlace, MKLDNNPlace> Place;

// Compile-time check that the number of alternatives in Place is less than or
// equal to 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT).
BOOST_MPL_ASSERT((boost::mpl::less_equal<
Place::types::size,
boost::mpl::long_<1 << NUM_PLACE_TYPE_LIMIT_IN_BIT>>));

void set_place(const Place &);
const Place &get_place();

Expand Down