Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/framework/library_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace framework {
// For more details about the design of LibraryType, please refer to
// https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/operator_kernel_type.md#library

enum LibraryType { kPlain = 0; kMKLDNN = 1; kCUDNN = 2; }
// Library types a kernel can be tagged with. The numeric values are spelled
// out explicitly: kPlain is 0, kMKLDNN is 1 and kCUDNN is 2.
enum LibraryType {
  kPlain = 0,
  kMKLDNN = 1,
  kCUDNN = 2
};

} // namespace
} // framework
82 changes: 82 additions & 0 deletions paddle/framework/op_kernel_type.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/data_layout.h"
#include "paddle/framework/data_type.h"
#include "paddle/framework/library_type.h"
#include "paddle/platform/place.h"

namespace paddle {
namespace framework {

/*
Refer to https://stackoverflow.com/questions/35985960/
c-why-is-boosthash-combine-the-best-way-to-combine-hash-values
*/
// Mixes the std::hash of `v` into the running hash stored at `*seed`.
//
// The mixing term is hash(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2), which
// is then XOR-ed into the previous seed. Because the old seed takes part in
// the mix, the result depends on the order in which values are combined.
//
// v:    value to fold in; std::hash<T> must be usable for its type.
// seed: in/out accumulator; must point to a valid std::size_t.
template <class T>
inline void HashCombine(const T& v, std::size_t* seed) {
  const std::size_t previous = *seed;
  const std::size_t mix =
      std::hash<T>()(v) + 0x9e3779b9 + (previous << 6) + (previous >> 2);
  *seed = previous ^ mix;
}

// Key describing an operator kernel by four attributes: data type, data
// layout, place and library type. It provides a hash functor and an equality
// operator so it can serve as a key in hashed containers.
struct OpKernelType {
  // Hash functor for OpKernelType.
  //
  // Each attribute is reduced to an int (the place contributes only the
  // index returned by which()) and folded into a single seed with
  // HashCombine, in the order: place, data type, data layout, library type.
  struct Hash {
    size_t operator()(const OpKernelType& key) const {
      const int parts[] = {static_cast<int>(key.place_.which()),
                           static_cast<int>(key.data_type_),
                           static_cast<int>(key.data_layout_),
                           static_cast<int>(key.library_type_)};

      size_t seed = 0;
      for (const int part : parts) {
        HashCombine(part, &seed);
      }
      return seed;
    }
  };

  proto::DataType data_type_;
  DataLayout data_layout_;
  platform::Place place_;
  LibraryType library_type_;

  // Builds a key from an explicit place. The layout defaults to
  // DataLayout::kAnyLayout and the library type to LibraryType::kPlain.
  OpKernelType(proto::DataType data_type, platform::Place place,
               DataLayout data_layout = DataLayout::kAnyLayout,
               LibraryType library_type = LibraryType::kPlain)
      : data_type_(data_type),
        data_layout_(data_layout),
        place_(place),
        library_type_(library_type) {}

  // Builds a key whose place is taken from dev_ctx.GetPlace(). The layout and
  // library type have the same defaults as the constructor above.
  OpKernelType(proto::DataType data_type,
               const platform::DeviceContext& dev_ctx,
               DataLayout data_layout = DataLayout::kAnyLayout,
               LibraryType library_type = LibraryType::kPlain)
      : data_type_(data_type),
        data_layout_(data_layout),
        place_(dev_ctx.GetPlace()),
        library_type_(library_type) {}

  // Two keys are equal when platform::places_are_same_class accepts their
  // places and the data type, data layout and library type all compare equal.
  bool operator==(const OpKernelType& o) const {
    if (!platform::places_are_same_class(place_, o.place_)) {
      return false;
    }
    if (!(data_type_ == o.data_type_)) {
      return false;
    }
    if (!(data_layout_ == o.data_layout_)) {
      return false;
    }
    return library_type_ == o.library_type_;
  }
};

} // namespace framework
} // namespace paddle
5 changes: 3 additions & 2 deletions paddle/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,9 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
}

// Writes a bracketed, colon-separated description of `kernel_key` to `os`
// and returns `os` so that calls can be chained.
std::ostream& operator<<(std::ostream& os, const OpKernelType& kernel_key) {
// NOTE(review): this first statement (place + data_type only) looks like the
// pre-change version kept alongside the new one by the diff view - confirm
// that only one of the two statements is meant to remain.
os << "place[" << kernel_key.place_ << "]:data_type[" << kernel_key.data_type_
<< "]";
// Emits all four fields: data_type, data_layout, place, library_type.
os << "data_type[" << kernel_key.data_type_ << "]:data_layout["
<< kernel_key.data_layout_ << "]:place[" << kernel_key.place_
<< "]:library_type[" << kernel_key.library_type_ << "]";
return os;
}

Expand Down
31 changes: 1 addition & 30 deletions paddle/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,14 @@ limitations under the License. */
#include "glog/logging.h" // For VLOG
#include "paddle/framework/attribute.h"
#include "paddle/framework/block_desc.h"
#include "paddle/framework/data_type.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/op_kernel_type.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
#include "paddle/platform/variant.h"
#include "paddle/utils/Error.h"

Expand Down Expand Up @@ -345,34 +344,6 @@ class OpKernel : public OpKernelBase {
using ELEMENT_TYPE = T;
};

// Key describing an operator kernel by its place and data type. It provides a
// hash functor and an equality operator.
struct OpKernelType {
  // Hash functor for OpKernelType.
  //
  // Packs the two attributes into one int: the data type is shifted left by
  // NUM_PLACE_TYPE_LIMIT_IN_BIT bits and the low bits hold the place index
  // from which(), masked to that many bits. The packed int is then passed
  // through std::hash<int>.
  struct Hash {
    std::hash<int> hash_;
    size_t operator()(const OpKernelType& key) const {
      const int place_mask = (1 << NUM_PLACE_TYPE_LIMIT_IN_BIT) - 1;
      const int place_bits = key.place_.which() & place_mask;
      const int type_bits = static_cast<int>(key.data_type_)
                            << NUM_PLACE_TYPE_LIMIT_IN_BIT;
      return hash_(type_bits | place_bits);
    }
  };

  platform::Place place_;
  proto::DataType data_type_;

  // Builds a key from an explicit place.
  OpKernelType(proto::DataType data_type, platform::Place place)
      : place_(place), data_type_(data_type) {}

  // Builds a key whose place is taken from dev_ctx.GetPlace().
  OpKernelType(proto::DataType data_type,
               const platform::DeviceContext& dev_ctx)
      : place_(dev_ctx.GetPlace()), data_type_(data_type) {}

  // Two keys are equal when platform::places_are_same_class accepts their
  // places and their data types compare equal.
  bool operator==(const OpKernelType& o) const {
    if (!platform::places_are_same_class(place_, o.place_)) {
      return false;
    }
    return data_type_ == o.data_type_;
  }
};

class OperatorWithKernel : public OperatorBase {
public:
using OpKernelMap =
Expand Down
10 changes: 0 additions & 10 deletions paddle/platform/place.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,18 +68,8 @@ struct IsMKLDNNPlace : public boost::static_visitor<bool> {
bool operator()(const GPUPlace &) const { return false; }
};

// Define the maximum number of Place types as a bit length, i.e., the number
// of place types must be less than or equal to 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT).
#define NUM_PLACE_TYPE_LIMIT_IN_BIT 4

typedef boost::variant<CUDNNPlace, GPUPlace, CPUPlace, MKLDNNPlace> Place;

// Statically check that the number of place types is less than or equal to
// 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT).
BOOST_MPL_ASSERT((boost::mpl::less_equal<
Place::types::size,
boost::mpl::long_<1 << NUM_PLACE_TYPE_LIMIT_IN_BIT>>));

void set_place(const Place &);
const Place &get_place();

Expand Down