Support double precision #4455
paddle/framework/data_type.h (new file):

```cpp
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <typeindex>
#include "paddle/framework/framework.pb.h"

namespace paddle {
namespace framework {

inline DataType ToDataType(std::type_index type) {
  if (typeid(float).hash_code() == type.hash_code()) {
    return DataType::FP32;
  } else if (typeid(double).hash_code() == type.hash_code()) {
    return DataType::FP64;
  } else if (typeid(int).hash_code() == type.hash_code()) {
    return DataType::INT32;
  } else {
    PADDLE_THROW("Not supported");
    return static_cast<DataType>(-1);
  }
}

}  // namespace framework
}  // namespace paddle
```
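As a quick illustration of the new mapping, here is a minimal usage sketch. The `main` driver and the printed checks are not part of the patch; it assumes a Paddle source tree where this header and the generated `framework.pb.h` are available.

```cpp
// Illustrative only: exercises ToDataType from the new header above.
#include <iostream>
#include "paddle/framework/data_type.h"

int main() {
  using namespace paddle::framework;
  // typeid(...) converts implicitly to std::type_index.
  std::cout << (ToDataType(typeid(float)) == DataType::FP32) << "\n";   // prints 1
  std::cout << (ToDataType(typeid(double)) == DataType::FP64) << "\n";  // prints 1
  // Unsupported types (e.g. bool) hit PADDLE_THROW("Not supported").
  return 0;
}
```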
paddle/framework/operator.h:

```diff
@@ -21,6 +21,7 @@ limitations under the License. */
 
 #include "op_info.h"
 #include "paddle/framework/attribute.h"
+#include "paddle/framework/data_type.h"
 #include "paddle/framework/framework.pb.h"
 #include "paddle/framework/lod_tensor.h"
 #include "paddle/framework/scope.h"
@@ -407,7 +408,7 @@ class RuntimeInferShapeContext : public InferShapeContextBase {
   const Scope& scope_;
 };
 
-class OpKernel {
+class OpKernelBase {
  public:
   /**
    * ExecutionContext is the only parameter of Kernel Run function.
@@ -418,33 +419,47 @@ class OpKernel {
 
   virtual void Compute(const ExecutionContext& context) const = 0;
 
-  virtual ~OpKernel() {}
+  virtual ~OpKernelBase() = default;
 };
 
+template <typename T>
+class OpKernel : public OpKernelBase {
+ public:
+  using ELEMENT_TYPE = T;
+};
+
 class OperatorWithKernel : public OperatorBase {
  public:
   struct OpKernelKey {
     platform::Place place_;
+    DataType data_type_;
 
-    OpKernelKey() = default;
-    explicit OpKernelKey(const platform::DeviceContext& dev_ctx) {
-      place_ = dev_ctx.GetPlace();
-    }
+    OpKernelKey(DataType data_type, platform::Place place)
+        : place_(place), data_type_(data_type) {}
+
+    OpKernelKey(DataType data_type, const platform::DeviceContext& dev_ctx)
+        : place_(dev_ctx.GetPlace()), data_type_(data_type) {}
 
     bool operator==(const OpKernelKey& o) const {
-      return platform::places_are_same_class(place_, o.place_);
+      return platform::places_are_same_class(place_, o.place_) &&
+             data_type_ == o.data_type_;
     }
   };
 
   struct OpKernelHash {
-    std::hash<bool> hash_;
+    std::hash<int> hash_;
     size_t operator()(const OpKernelKey& key) const {
-      return hash_(platform::is_gpu_place(key.place_));
+      int place = key.place_.which();
+      int data_type = static_cast<int>(key.data_type_);
+      // NOTE: Number of places limit to 16.
+      int pre_hash = data_type << 4 | (place & 0x0F);
+      return hash_(pre_hash);
     }
   };
 
   using OpKernelMap =
-      std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
+      std::unordered_map<OpKernelKey, std::unique_ptr<OpKernelBase>,
+                         OpKernelHash>;
 
   OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
                      const VariableNameMap& outputs, const AttributeMap& attrs)
@@ -458,8 +473,10 @@ class OperatorWithKernel : public OperatorBase {
 
   void Run(const Scope& scope,
            const platform::DeviceContext& dev_ctx) const final {
-    auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
-    opKernel->Compute(ExecutionContext(*this, scope, dev_ctx));
+    ExecutionContext ctx(*this, scope, dev_ctx);
+    auto& opKernel = AllOpKernels().at(type_).at(
+        OpKernelKey(IndicateDataType(ctx), dev_ctx));
+    opKernel->Compute(ctx);
   }
 
   static std::unordered_map<std::string /* op_type */, OpKernelMap>&
@@ -469,13 +486,43 @@ class OperatorWithKernel : public OperatorBase {
   }
 
   bool SupportGPU() const override {
-    OperatorWithKernel::OpKernelKey key;
-    key.place_ = platform::GPUPlace();
-    return OperatorWithKernel::AllOpKernels().at(type_).count(key) != 0;
+    auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
+    return std::any_of(op_kernels.begin(), op_kernels.end(),
+                       [](OpKernelMap::const_reference kern_pair) {
+                         return platform::is_gpu_place(kern_pair.first.place_);
+                       });
   }
 
  protected:
   virtual void InferShape(InferShapeContextBase* ctx) const = 0;
 
+  // Indicate the kernel DataType by the input data. By default, all inputs
+  // must have the same data type.
+  virtual DataType IndicateDataType(const ExecutionContext& ctx) const {
+    auto& scope = ctx.scope();
+    int data_type = -1;
+    for (auto& input : this->inputs_) {
+      for (auto& ipt_name : input.second) {
+        auto* var = scope.FindVar(ipt_name);
+        if (var != nullptr) {
+          const Tensor* t = nullptr;
+          if (var->IsType<Tensor>()) {
+            t = &var->Get<Tensor>();
+          } else if (var->IsType<LoDTensor>()) {
+            t = &var->Get<LoDTensor>();
+          }
+          if (t != nullptr) {
+            int tmp = static_cast<int>(ToDataType(t->type()));
+            PADDLE_ENFORCE(tmp == data_type || data_type == -1,
+                           "DataType of Paddle Op must be same.");
+            data_type = tmp;
+          }
+        }
+      }
+    }
+    PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
+    return static_cast<DataType>(data_type);
+  }
 };
 
 }  // namespace framework
```
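To summarize the dispatch pattern this change sets up, here is a self-contained sketch. All names below are local to the snippet rather than Paddle's actual API: a kernel map keyed by both place and element type, with a templated kernel class standing in for `OpKernel<T>`.

```cpp
#include <cstddef>
#include <functional>
#include <iostream>
#include <memory>
#include <unordered_map>

enum class DataType { FP32 = 0, FP64 = 1 };

struct KernelKey {
  int place;  // stands in for platform::Place::which()
  DataType data_type;
  bool operator==(const KernelKey& o) const {
    return place == o.place && data_type == o.data_type;
  }
};

struct KernelKeyHash {
  std::hash<int> hash_;
  size_t operator()(const KernelKey& key) const {
    // Same combination scheme as OpKernelHash above: data type in the
    // high bits, place in the low 4 bits (so at most 16 places).
    return hash_(static_cast<int>(key.data_type) << 4 | (key.place & 0x0F));
  }
};

struct KernelBase {
  virtual ~KernelBase() = default;
  virtual void Compute() const = 0;
};

template <typename T>
struct AddKernel : KernelBase {
  void Compute() const override {
    std::cout << "add kernel, element size " << sizeof(T) << "\n";
  }
};

int main() {
  std::unordered_map<KernelKey, std::unique_ptr<KernelBase>, KernelKeyHash>
      kernels;
  kernels[{0, DataType::FP32}] = std::make_unique<AddKernel<float>>();
  kernels[{0, DataType::FP64}] = std::make_unique<AddKernel<double>>();
  // Run() picks the kernel from the inputs' data type (IndicateDataType):
  kernels.at({0, DataType::FP64})->Compute();  // add kernel, element size 8
}
```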
Why do we need a `pre_hash` here?
Because we need to hash two member fields together, so I combine them into a single value manually before hashing.
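In other words, `std::hash<int>` only accepts one value, so the two fields are first packed into a single int. A minimal sketch of that packing (the concrete values are illustrative):

```cpp
#include <cassert>
#include <functional>

int main() {
  // Data type goes in the high bits, place in the low 4 bits; this assumes
  // the place index fits in 4 bits, i.e. at most 16 places.
  int data_type = 1;  // e.g. static_cast<int>(DataType::FP64)
  int place = 2;      // e.g. key.place_.which() for some place
  int pre_hash = data_type << 4 | (place & 0x0F);
  assert(pre_hash == 0x12);  // both fields survive in distinct bit ranges
  size_t h = std::hash<int>{}(pre_hash);
  (void)h;  // OpKernelHash returns this as the bucket hash
  return 0;
}
```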