Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions paddle/framework/data_type.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <typeindex>
#include "paddle/framework/framework.pb.h"

namespace paddle {
namespace framework {

// Map a C++ run-time type to the corresponding framework DataType.
//
// The match is done by comparing std::type_index values directly instead of
// comparing hash_code() results: the standard permits two different types to
// share the same hash code, so a hash comparison could silently map an
// unsupported type onto FP32/FP64/INT32.
//
// @param type  type_index of the element type, e.g.
//              std::type_index(typeid(float)).
// @return DataType::FP32 for float, DataType::FP64 for double and
//         DataType::INT32 for int.
// Any other type is reported through PADDLE_THROW("Not supported").
inline DataType ToDataType(std::type_index type) {
  if (type == std::type_index(typeid(float))) {
    return DataType::FP32;
  } else if (type == std::type_index(typeid(double))) {
    return DataType::FP64;
  } else if (type == std::type_index(typeid(int))) {
    return DataType::INT32;
  } else {
    PADDLE_THROW("Not supported");
    // Kept from the original so every control path yields a value;
    // presumably unreachable when PADDLE_THROW raises - TODO confirm.
    return static_cast<DataType>(-1);
  }
}

} // namespace framework
} // namespace paddle
5 changes: 3 additions & 2 deletions paddle/framework/op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,9 @@ template <typename PlaceType, typename KernelType>
// Registrar whose constructor stores a newly allocated KernelType in
// OperatorWithKernel::AllOpKernels() under [op_type][key].
// NOTE(review): `key` is declared twice in the constructor body below (once
// default-constructed with place_ assigned afterwards, once built from a
// DataType plus PlaceType()). This looks like the removed and the added lines
// of a diff shown together - confirm against the real file.
class OpKernelRegistrar : public Registrar {
public:
// op_type: name under which the kernel is registered.
explicit OpKernelRegistrar(const char* op_type) {
OperatorWithKernel::OpKernelKey key;
key.place_ = PlaceType();
// The kernel's element type (KernelType::ELEMENT_TYPE) is converted to a
// DataType with ToDataType and combined with a PlaceType instance to form
// the lookup key.
using T = typename KernelType::ELEMENT_TYPE;
OperatorWithKernel::OpKernelKey key(ToDataType(std::type_index(typeid(T))),
PlaceType());
// operator[] creates the map entries on first use; reset() hands ownership
// of the new kernel object to the stored smart pointer.
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KernelType);
}
};
Expand Down
77 changes: 62 additions & 15 deletions paddle/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License. */

#include "op_info.h"
#include "paddle/framework/attribute.h"
#include "paddle/framework/data_type.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/scope.h"
Expand Down Expand Up @@ -407,7 +408,7 @@ class RuntimeInferShapeContext : public InferShapeContextBase {
const Scope& scope_;
};

class OpKernel {
class OpKernelBase {
public:
/**
* ExecutionContext is the only parameter of Kernel Run function.
Expand All @@ -418,33 +419,47 @@ class OpKernel {

virtual void Compute(const ExecutionContext& context) const = 0;

virtual ~OpKernel() {}
virtual ~OpKernelBase() = default;
};

// Typed kernel base class: derives from OpKernelBase and publishes its
// template argument T as the member alias ELEMENT_TYPE.
template <typename T>
class OpKernel : public OpKernelBase {
public:
// Element type this kernel class was instantiated with.
using ELEMENT_TYPE = T;
};

class OperatorWithKernel : public OperatorBase {
public:
// Key used to look up a kernel: a place plus a data type.
// NOTE(review): this excerpt contains a DeviceContext-only constructor and two
// consecutive return statements in operator==; it appears to show removed and
// added diff lines together - confirm against the real file.
struct OpKernelKey {
// Place the kernel is bound to.
platform::Place place_;
// Data type the kernel is bound to.
DataType data_type_;

OpKernelKey() = default;
// Builds a key from a device context; only place_ is assigned here,
// data_type_ is not set by this constructor.
explicit OpKernelKey(const platform::DeviceContext& dev_ctx) {
place_ = dev_ctx.GetPlace();
}
// Builds a key from an explicit data type and place.
OpKernelKey(DataType data_type, platform::Place place)
: place_(place), data_type_(data_type) {}

// Builds a key from a data type and the place reported by dev_ctx.
OpKernelKey(DataType data_type, const platform::DeviceContext& dev_ctx)
: place_(dev_ctx.GetPlace()), data_type_(data_type) {}

// Equality as used by the kernel map. The first return compares only the
// place class; the second additionally requires equal data types.
bool operator==(const OpKernelKey& o) const {
return platform::places_are_same_class(place_, o.place_);
return platform::places_are_same_class(place_, o.place_) &&
data_type_ == o.data_type_;
}
};

// Hash functor for OpKernelKey, used as the hasher of OpKernelMap.
// NOTE(review): two hash_ members and two return paths are visible below; the
// excerpt appears to show removed and added diff lines together, and review
// comments from the page are embedded in the middle of operator() - confirm
// against the real file.
struct OpKernelHash {
std::hash<bool> hash_;
std::hash<int> hash_;
size_t operator()(const OpKernelKey& key) const {
return hash_(platform::is_gpu_place(key.place_));
// Combine both key fields into one int before hashing: the value returned by
// place_.which() is masked to its low 4 bits and the data type is shifted
// left by 4. `<<` binds tighter than `|`, so this is
// (data_type << 4) | (place & 0x0F).
int place = key.place_.which();
int data_type = static_cast<int>(key.data_type_);
// NOTE: Number of places limit to 16.
int pre_hash = data_type << 4 | (place & 0x0F);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why we need a pre_hash here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because we should hash two private data together. So I combine them manually.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

define 4 and 0x0f somewhere.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can have a universal method, HashCombine.
Please refer to Hash64Combine in tensorflow

// Hash the packed integer.
return hash_(pre_hash);
}
};

using OpKernelMap =
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernelBase>,
OpKernelHash>;

OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs)
Expand All @@ -458,8 +473,10 @@ class OperatorWithKernel : public OperatorBase {

// Looks up the kernel registered for this operator's type_ and a key derived
// from dev_ctx, then calls its Compute with an ExecutionContext built from
// this operator, scope and dev_ctx.
// Both lookups use .at(), which throws std::out_of_range when no entry
// exists for type_ or for the key.
// NOTE(review): `opKernel` is declared twice below; the excerpt appears to
// show removed and added diff lines together - confirm against the real file.
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const final {
auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
opKernel->Compute(ExecutionContext(*this, scope, dev_ctx));
// The context is built first so IndicateDataType can inspect it when
// forming the (data type, place) key.
ExecutionContext ctx(*this, scope, dev_ctx);
auto& opKernel = AllOpKernels().at(type_).at(
OpKernelKey(IndicateDataType(ctx), dev_ctx));
opKernel->Compute(ctx);
}

static std::unordered_map<std::string /* op_type */, OpKernelMap>&
Expand All @@ -469,13 +486,43 @@ class OperatorWithKernel : public OperatorBase {
}

// Reports whether a GPU kernel is registered for this operator's type_.
// AllOpKernels().at(type_) throws std::out_of_range when type_ has no entry.
// NOTE(review): two return statements are visible below; the excerpt appears
// to show removed and added diff lines together - confirm against the real
// file.
bool SupportGPU() const override {
OperatorWithKernel::OpKernelKey key;
key.place_ = platform::GPUPlace();
return OperatorWithKernel::AllOpKernels().at(type_).count(key) != 0;
// Scan every registered (key, kernel) pair and accept as soon as one key's
// place is a GPU place, whatever its data type.
auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
return std::any_of(op_kernels.begin(), op_kernels.end(),
[](OpKernelMap::const_reference kern_pair) {
return platform::is_gpu_place(kern_pair.first.place_);
});
}

protected:
virtual void InferShape(InferShapeContextBase* ctx) const = 0;

// Determines the kernel DataType from the operator's input variables.
// By default every input tensor found in the scope must have the same data
// type; subclasses may override this to choose the type differently.
// Enforces (PADDLE_ENFORCE) that all inspected inputs agree and that at
// least one input provided a data type.
// NOTE(review): review comments from the page are embedded between the
// signature and the body below.
virtual DataType IndicateDataType(const ExecutionContext& ctx) const {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe InferDataType is a better name.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, it is not an inference, i.e., it is not inference the output data type by inputs. It indicates the kernel data type.

auto& scope = ctx.scope();
// -1 is the sentinel for "no data type seen yet".
int data_type = -1;
// inputs_ maps each input slot to a list of variable names; visit them all.
for (auto& input : this->inputs_) {
for (auto& ipt_name : input.second) {
auto* var = scope.FindVar(ipt_name);
// Names that are not found in the scope are skipped.
if (var != nullptr) {
// Only variables holding a Tensor or a LoDTensor are inspected; any other
// variable type leaves t null and is ignored.
const Tensor* t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
}
if (t != nullptr) {
// NOTE(review): t->type() is called without checking that the tensor
// already holds data - confirm this is safe for unallocated inputs.
int tmp = static_cast<int>(ToDataType(t->type()));
// Accept the first type seen; every later one must match it.
PADDLE_ENFORCE(tmp == data_type || data_type == -1,
"DataType of Paddle Op must be same.");
data_type = tmp;
}
}
}
}
// The sentinel survives only if no input tensor was inspected.
PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
return static_cast<DataType>(data_type);
}
};

} // namespace framework
Expand Down
7 changes: 5 additions & 2 deletions paddle/framework/operator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,13 @@ class OpWithKernelTest : public OperatorWithKernel {

protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {}
// Test override: always reports FP32 and never looks at ctx.
DataType IndicateDataType(const ExecutionContext& ctx) const override {
return DataType::FP32;
}
};

template <typename T1, typename T2>
class CPUKernelTest : public OpKernel {
class CPUKernelTest : public OpKernel<float> {
public:
void Compute(const ExecutionContext& ctx) const {
std::cout << "this is cpu kernel" << std::endl;
Expand All @@ -146,7 +149,7 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker
}
};

class CPUKernalMultiInputsTest : public OpKernel {
class CPUKernalMultiInputsTest : public OpKernel<float> {
public:
void Compute(const ExecutionContext& ctx) const {
auto xs = ctx.op().Inputs("xs");
Expand Down
12 changes: 2 additions & 10 deletions paddle/framework/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,10 @@ limitations under the License. */

namespace paddle {

namespace pybind {
namespace details {
template <bool less, size_t i, typename... args>
struct CastToPyBufferImpl;
}
} // namespace pybind

namespace framework {

class Tensor {
public:
template <bool less, size_t i, typename... args>
friend struct pybind::details::CastToPyBufferImpl;

template <typename T, size_t D, int MajorType, typename IndexType>
friend struct EigenTensor;

Expand Down Expand Up @@ -119,6 +109,8 @@ class Tensor {
return holder_->place();
}

// Returns the type_index reported by holder_->type().
// NOTE(review): holder_ is dereferenced without a null check here - confirm
// callers only invoke this once holder_ is set.
std::type_index type() const { return holder_->type(); }

private:
template <typename T>
inline void check_memory_size() const;
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/accuracy_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ __global__ void AccuracyCudaKernel(const int N, const int D, const int* Xdata,
}

template <typename T>
class AccuracyOpCUDAKernel : public framework::OpKernel {
class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/accuracy_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;

template <typename Place, typename T>
class AccuracyKernel : public framework::OpKernel {
class AccuracyKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* inference = ctx.Input<Tensor>("Inference");
Expand Down
20 changes: 10 additions & 10 deletions paddle/operators/activation_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace paddle {
namespace operators {

template <typename Place, typename T, typename Functor>
class ActivationKernel : public framework::OpKernel {
class ActivationKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
Expand All @@ -36,7 +36,7 @@ class ActivationKernel : public framework::OpKernel {
};

template <typename Place, typename T, typename Functor>
class ActivationGradKernel : public framework::OpKernel {
class ActivationGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
Expand Down Expand Up @@ -202,7 +202,7 @@ struct SquareGradFunctor {
};

template <typename Place, typename T, typename AttrType = T>
class BReluKernel : public framework::OpKernel {
class BReluKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
Expand All @@ -219,7 +219,7 @@ class BReluKernel : public framework::OpKernel {
};

template <typename Place, typename T, typename AttrType = T>
class BReluGradKernel : public framework::OpKernel {
class BReluGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
Expand All @@ -239,7 +239,7 @@ class BReluGradKernel : public framework::OpKernel {
};

template <typename Place, typename T, typename AttrType = T>
class SoftReluKernel : public framework::OpKernel {
class SoftReluKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
Expand All @@ -256,7 +256,7 @@ class SoftReluKernel : public framework::OpKernel {
};

template <typename Place, typename T, typename AttrType = T>
class SoftReluGradKernel : public framework::OpKernel {
class SoftReluGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
Expand All @@ -277,7 +277,7 @@ class SoftReluGradKernel : public framework::OpKernel {
};

template <typename Place, typename T, typename AttrType = T>
class PowKernel : public framework::OpKernel {
class PowKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
Expand All @@ -293,7 +293,7 @@ class PowKernel : public framework::OpKernel {
};

template <typename Place, typename T, typename AttrType = T>
class PowGradKernel : public framework::OpKernel {
class PowGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
Expand All @@ -312,7 +312,7 @@ class PowGradKernel : public framework::OpKernel {
};

template <typename Place, typename T, typename AttrType = T>
class STanhKernel : public framework::OpKernel {
class STanhKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
Expand All @@ -329,7 +329,7 @@ class STanhKernel : public framework::OpKernel {
};

template <typename Place, typename T, typename AttrType = T>
class STanhGradKernel : public framework::OpKernel {
class STanhGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/add_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

template <typename Place, typename T>
class AddKernel : public framework::OpKernel {
class AddKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input0 = context.Input<Tensor>("X");
Expand Down
4 changes: 2 additions & 2 deletions paddle/operators/clip_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class ClipGradFunctor {
};

template <typename Place, typename T>
class ClipKernel : public framework::OpKernel {
class ClipKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto max = context.Attr<T>("max");
Expand All @@ -73,7 +73,7 @@ class ClipKernel : public framework::OpKernel {
};

template <typename Place, typename T>
class ClipGradKernel : public framework::OpKernel {
class ClipGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto max = context.Attr<T>("max");
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/concat_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace paddle {
namespace operators {

template <typename Place, typename T>
class ConcatKernel : public framework::OpKernel {
class ConcatKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<framework::Tensor>("X");
Expand Down
4 changes: 2 additions & 2 deletions paddle/operators/cos_sim_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

template <typename Place, typename T>
class CosSimKernel : public framework::OpKernel {
class CosSimKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
// get Tensor
Expand Down Expand Up @@ -67,7 +67,7 @@ class CosSimKernel : public framework::OpKernel {
};

template <typename Place, typename T>
class CosSimGradKernel : public framework::OpKernel {
class CosSimGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
// get Tensor
Expand Down
Loading