Skip to content

Commit 184768e

Browse files
authored
Merge pull request #4455 from reyoung/feature/make_paddle_support_double
Support double precision
2 parents 5d6d2bc + d53b38e commit 184768e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

72 files changed

+341
-142
lines changed

paddle/framework/data_type.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include <typeindex>
17+
#include "paddle/framework/framework.pb.h"
18+
19+
namespace paddle {
20+
namespace framework {
21+
22+
// Maps a C++ element type, identified by its std::type_index, to the
// matching framework DataType enum value.
//
// Supported mappings: float -> FP32, double -> FP64, int -> INT32.
// Any other type is reported through PADDLE_THROW("Not supported").
inline DataType ToDataType(std::type_index type) {
  // Compare the type_index objects themselves. hash_code() values are allowed
  // to collide for distinct types, so equal hashes do not prove equal types.
  if (type == std::type_index(typeid(float))) {
    return DataType::FP32;
  }
  if (type == std::type_index(typeid(double))) {
    return DataType::FP64;
  }
  if (type == std::type_index(typeid(int))) {
    return DataType::INT32;
  }
  PADDLE_THROW("Not supported");
  // Presumably never reached once PADDLE_THROW fires; kept so every control
  // path visibly returns a value.
  return static_cast<DataType>(-1);
}
34+
35+
} // namespace framework
36+
} // namespace paddle

paddle/framework/op_registry.h

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,13 +100,39 @@ class OpRegistrar : public Registrar {
100100
}
101101
};
102102

103-
template <typename PlaceType, typename KernelType>
103+
template <typename PlaceType, bool at_end, size_t I, typename... KernelType>
104+
struct OpKernelRegistrarFunctor;
105+
106+
template <typename PlaceType, size_t I, typename... KernelTypes>
107+
struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
108+
using KERNEL_TYPE =
109+
typename std::tuple_element<I, std::tuple<KernelTypes...>>::type;
110+
111+
void operator()(const char* op_type) const {
112+
using T = typename KERNEL_TYPE::ELEMENT_TYPE;
113+
OperatorWithKernel::OpKernelKey key(ToDataType(std::type_index(typeid(T))),
114+
PlaceType());
115+
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KERNEL_TYPE);
116+
117+
constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
118+
OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
119+
func;
120+
func(op_type);
121+
}
122+
};
123+
124+
template <typename PlaceType, size_t I, typename... KernelType>
125+
struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> {
126+
void operator()(const char* op_type) const {}
127+
};
128+
129+
// Users can register multiple kernels in one place; their data types may differ.
130+
template <typename PlaceType, typename... KernelType>
104131
class OpKernelRegistrar : public Registrar {
105132
public:
106133
explicit OpKernelRegistrar(const char* op_type) {
107-
OperatorWithKernel::OpKernelKey key;
108-
key.place_ = PlaceType();
109-
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KernelType);
134+
OpKernelRegistrarFunctor<PlaceType, false, 0, KernelType...> func;
135+
func(op_type);
110136
}
111137
};
112138

paddle/framework/operator.h

Lines changed: 62 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ limitations under the License. */
2222

2323
#include "op_info.h"
2424
#include "paddle/framework/attribute.h"
25+
#include "paddle/framework/data_type.h"
2526
#include "paddle/framework/framework.pb.h"
2627
#include "paddle/framework/lod_tensor.h"
2728
#include "paddle/framework/scope.h"
@@ -403,7 +404,7 @@ class RuntimeInferShapeContext : public InferShapeContextBase {
403404
const Scope& scope_;
404405
};
405406

406-
class OpKernel {
407+
class OpKernelBase {
407408
public:
408409
/**
409410
* ExecutionContext is the only parameter of Kernel Run function.
@@ -414,33 +415,47 @@ class OpKernel {
414415

415416
virtual void Compute(const ExecutionContext& context) const = 0;
416417

417-
virtual ~OpKernel() {}
418+
virtual ~OpKernelBase() = default;
419+
};
420+
421+
template <typename T>
422+
class OpKernel : public OpKernelBase {
423+
public:
424+
using ELEMENT_TYPE = T;
418425
};
419426

420427
class OperatorWithKernel : public OperatorBase {
421428
public:
422429
struct OpKernelKey {
423430
platform::Place place_;
431+
DataType data_type_;
424432

425-
OpKernelKey() = default;
426-
explicit OpKernelKey(const platform::DeviceContext& dev_ctx) {
427-
place_ = dev_ctx.GetPlace();
428-
}
433+
OpKernelKey(DataType data_type, platform::Place place)
434+
: place_(place), data_type_(data_type) {}
435+
436+
OpKernelKey(DataType data_type, const platform::DeviceContext& dev_ctx)
437+
: place_(dev_ctx.GetPlace()), data_type_(data_type) {}
429438

430439
bool operator==(const OpKernelKey& o) const {
431-
return platform::places_are_same_class(place_, o.place_);
440+
return platform::places_are_same_class(place_, o.place_) &&
441+
data_type_ == o.data_type_;
432442
}
433443
};
434444

435445
struct OpKernelHash {
436-
std::hash<bool> hash_;
446+
std::hash<int> hash_;
437447
size_t operator()(const OpKernelKey& key) const {
438-
return hash_(platform::is_gpu_place(key.place_));
448+
int place = key.place_.which();
449+
int data_type = static_cast<int>(key.data_type_);
450+
int pre_hash = data_type << NUM_PLACE_TYPE_LIMIT_IN_BIT |
451+
(place & ((1 << NUM_PLACE_TYPE_LIMIT_IN_BIT) - 1));
452+
return hash_(pre_hash);
439453
}
440454
};
441455

442456
using OpKernelMap =
443-
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
457+
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernelBase>,
458+
OpKernelHash>;
444459

445460
OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
446461
const VariableNameMap& outputs, const AttributeMap& attrs)
@@ -451,8 +466,10 @@ class OperatorWithKernel : public OperatorBase {
451466
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
452467
this->InferShape(&infer_shape_ctx);
453468

454-
auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
455-
opKernel->Compute(ExecutionContext(*this, scope, dev_ctx));
469+
ExecutionContext ctx(*this, scope, dev_ctx);
470+
auto& opKernel = AllOpKernels().at(type_).at(
471+
OpKernelKey(IndicateDataType(ctx), dev_ctx));
472+
opKernel->Compute(ctx);
456473
}
457474

458475
static std::unordered_map<std::string /* op_type */, OpKernelMap>&
@@ -462,13 +479,43 @@ class OperatorWithKernel : public OperatorBase {
462479
}
463480

464481
bool SupportGPU() const override {
465-
OperatorWithKernel::OpKernelKey key;
466-
key.place_ = platform::GPUPlace();
467-
return OperatorWithKernel::AllOpKernels().at(type_).count(key) != 0;
482+
auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
483+
return std::any_of(op_kernels.begin(), op_kernels.end(),
484+
[](OpKernelMap::const_reference kern_pair) {
485+
return platform::is_gpu_place(kern_pair.first.place_);
486+
});
468487
}
469488

470489
protected:
471490
virtual void InferShape(InferShapeContextBase* ctx) const = 0;
491+
492+
// indicate kernel DataType by input data. Defaultly all input data must be
493+
// same.
494+
virtual DataType IndicateDataType(const ExecutionContext& ctx) const {
495+
auto& scope = ctx.scope();
496+
int data_type = -1;
497+
for (auto& input : this->inputs_) {
498+
for (auto& ipt_name : input.second) {
499+
auto* var = scope.FindVar(ipt_name);
500+
if (var != nullptr) {
501+
const Tensor* t = nullptr;
502+
if (var->IsType<Tensor>()) {
503+
t = &var->Get<Tensor>();
504+
} else if (var->IsType<LoDTensor>()) {
505+
t = &var->Get<LoDTensor>();
506+
}
507+
if (t != nullptr) {
508+
int tmp = static_cast<int>(ToDataType(t->type()));
509+
PADDLE_ENFORCE(tmp == data_type || data_type == -1,
510+
"DataType of Paddle Op must be same.");
511+
data_type = tmp;
512+
}
513+
}
514+
}
515+
}
516+
PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
517+
return static_cast<DataType>(data_type);
518+
}
472519
};
473520

474521
} // namespace framework

paddle/framework/operator_test.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,13 @@ class OpWithKernelTest : public OperatorWithKernel {
114114

115115
protected:
116116
void InferShape(framework::InferShapeContextBase* ctx) const override {}
117+
DataType IndicateDataType(const ExecutionContext& ctx) const override {
118+
return DataType::FP32;
119+
}
117120
};
118121

119122
template <typename T1, typename T2>
120-
class CPUKernelTest : public OpKernel {
123+
class CPUKernelTest : public OpKernel<float> {
121124
public:
122125
void Compute(const ExecutionContext& ctx) const {
123126
std::cout << "this is cpu kernel" << std::endl;
@@ -144,7 +147,7 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker
144147
}
145148
};
146149

147-
class CPUKernalMultiInputsTest : public OpKernel {
150+
class CPUKernalMultiInputsTest : public OpKernel<float> {
148151
public:
149152
void Compute(const ExecutionContext& ctx) const {
150153
auto xs = ctx.op().Inputs("xs");

paddle/framework/tensor.h

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,10 @@ limitations under the License. */
2929

3030
namespace paddle {
3131

32-
namespace pybind {
33-
namespace details {
34-
template <bool less, size_t i, typename... args>
35-
struct CastToPyBufferImpl;
36-
}
37-
} // namespace pybind
38-
3932
namespace framework {
4033

4134
class Tensor {
4235
public:
43-
template <bool less, size_t i, typename... args>
44-
friend struct pybind::details::CastToPyBufferImpl;
45-
4636
template <typename T, size_t D, int MajorType, typename IndexType>
4737
friend struct EigenTensor;
4838

@@ -119,6 +109,8 @@ class Tensor {
119109
return holder_->place();
120110
}
121111

112+
std::type_index type() const { return holder_->type(); }
113+
122114
private:
123115
template <typename T>
124116
inline void check_memory_size() const;

paddle/operators/accuracy_op.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ __global__ void AccuracyCudaKernel(const int N, const int D, const int* Xdata,
4747
}
4848

4949
template <typename T>
50-
class AccuracyOpCUDAKernel : public framework::OpKernel {
50+
class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
5151
public:
5252
void Compute(const framework::ExecutionContext& ctx) const override {
5353
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),

paddle/operators/accuracy_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
3535
using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
3636

3737
template <typename Place, typename T>
38-
class AccuracyKernel : public framework::OpKernel {
38+
class AccuracyKernel : public framework::OpKernel<T> {
3939
public:
4040
void Compute(const framework::ExecutionContext& ctx) const override {
4141
auto* inference = ctx.Input<Tensor>("Inference");

paddle/operators/activation_op.h

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ namespace paddle {
2020
namespace operators {
2121

2222
template <typename Place, typename T, typename Functor>
23-
class ActivationKernel : public framework::OpKernel {
23+
class ActivationKernel : public framework::OpKernel<T> {
2424
public:
2525
void Compute(const framework::ExecutionContext& context) const override {
2626
auto* X = context.Input<framework::Tensor>("X");
@@ -36,7 +36,7 @@ class ActivationKernel : public framework::OpKernel {
3636
};
3737

3838
template <typename Place, typename T, typename Functor>
39-
class ActivationGradKernel : public framework::OpKernel {
39+
class ActivationGradKernel : public framework::OpKernel<T> {
4040
public:
4141
void Compute(const framework::ExecutionContext& context) const override {
4242
auto* X = context.Input<framework::Tensor>("X");
@@ -202,7 +202,7 @@ struct SquareGradFunctor {
202202
};
203203

204204
template <typename Place, typename T, typename AttrType = T>
205-
class BReluKernel : public framework::OpKernel {
205+
class BReluKernel : public framework::OpKernel<T> {
206206
public:
207207
void Compute(const framework::ExecutionContext& context) const override {
208208
auto* X = context.Input<framework::Tensor>("X");
@@ -219,7 +219,7 @@ class BReluKernel : public framework::OpKernel {
219219
};
220220

221221
template <typename Place, typename T, typename AttrType = T>
222-
class BReluGradKernel : public framework::OpKernel {
222+
class BReluGradKernel : public framework::OpKernel<T> {
223223
public:
224224
void Compute(const framework::ExecutionContext& context) const override {
225225
auto* X = context.Input<framework::Tensor>("X");
@@ -239,7 +239,7 @@ class BReluGradKernel : public framework::OpKernel {
239239
};
240240

241241
template <typename Place, typename T, typename AttrType = T>
242-
class SoftReluKernel : public framework::OpKernel {
242+
class SoftReluKernel : public framework::OpKernel<T> {
243243
public:
244244
void Compute(const framework::ExecutionContext& context) const override {
245245
auto* X = context.Input<framework::Tensor>("X");
@@ -256,7 +256,7 @@ class SoftReluKernel : public framework::OpKernel {
256256
};
257257

258258
template <typename Place, typename T, typename AttrType = T>
259-
class SoftReluGradKernel : public framework::OpKernel {
259+
class SoftReluGradKernel : public framework::OpKernel<T> {
260260
public:
261261
void Compute(const framework::ExecutionContext& context) const override {
262262
auto* X = context.Input<framework::Tensor>("X");
@@ -277,7 +277,7 @@ class SoftReluGradKernel : public framework::OpKernel {
277277
};
278278

279279
template <typename Place, typename T, typename AttrType = T>
280-
class PowKernel : public framework::OpKernel {
280+
class PowKernel : public framework::OpKernel<T> {
281281
public:
282282
void Compute(const framework::ExecutionContext& context) const override {
283283
auto* X = context.Input<framework::Tensor>("X");
@@ -293,7 +293,7 @@ class PowKernel : public framework::OpKernel {
293293
};
294294

295295
template <typename Place, typename T, typename AttrType = T>
296-
class PowGradKernel : public framework::OpKernel {
296+
class PowGradKernel : public framework::OpKernel<T> {
297297
public:
298298
void Compute(const framework::ExecutionContext& context) const override {
299299
auto* X = context.Input<framework::Tensor>("X");
@@ -312,7 +312,7 @@ class PowGradKernel : public framework::OpKernel {
312312
};
313313

314314
template <typename Place, typename T, typename AttrType = T>
315-
class STanhKernel : public framework::OpKernel {
315+
class STanhKernel : public framework::OpKernel<T> {
316316
public:
317317
void Compute(const framework::ExecutionContext& context) const override {
318318
auto* X = context.Input<framework::Tensor>("X");
@@ -329,7 +329,7 @@ class STanhKernel : public framework::OpKernel {
329329
};
330330

331331
template <typename Place, typename T, typename AttrType = T>
332-
class STanhGradKernel : public framework::OpKernel {
332+
class STanhGradKernel : public framework::OpKernel<T> {
333333
public:
334334
void Compute(const framework::ExecutionContext& context) const override {
335335
auto* X = context.Input<framework::Tensor>("X");

paddle/operators/add_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
2525
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
2626

2727
template <typename Place, typename T>
28-
class AddKernel : public framework::OpKernel {
28+
class AddKernel : public framework::OpKernel<T> {
2929
public:
3030
void Compute(const framework::ExecutionContext& context) const override {
3131
auto* input0 = context.Input<Tensor>("X");

0 commit comments

Comments
 (0)