@@ -22,6 +22,7 @@ limitations under the License. */
 
 #include "op_info.h"
 #include "paddle/framework/attribute.h"
+#include "paddle/framework/data_type.h"
 #include "paddle/framework/framework.pb.h"
 #include "paddle/framework/lod_tensor.h"
 #include "paddle/framework/scope.h"
@@ -403,7 +404,7 @@ class RuntimeInferShapeContext : public InferShapeContextBase {
   const Scope& scope_;
 };
 
-class OpKernel {
+class OpKernelBase {
  public:
   /**
    * ExecutionContext is the only parameter of Kernel Run function.
@@ -414,33 +415,47 @@ class OpKernel {
 
   virtual void Compute(const ExecutionContext& context) const = 0;
 
-  virtual ~OpKernel() {}
+  virtual ~OpKernelBase() = default;
+};
+
+template <typename T>
+class OpKernel : public OpKernelBase {
+ public:
+  using ELEMENT_TYPE = T;
 };
 
 class OperatorWithKernel : public OperatorBase {
  public:
   struct OpKernelKey {
     platform::Place place_;
+    DataType data_type_;
 
-    OpKernelKey() = default;
-    explicit OpKernelKey(const platform::DeviceContext& dev_ctx) {
-      place_ = dev_ctx.GetPlace();
-    }
+    OpKernelKey(DataType data_type, platform::Place place)
+        : place_(place), data_type_(data_type) {}
+
+    OpKernelKey(DataType data_type, const platform::DeviceContext& dev_ctx)
+        : place_(dev_ctx.GetPlace()), data_type_(data_type) {}
 
     bool operator==(const OpKernelKey& o) const {
-      return platform::places_are_same_class(place_, o.place_);
+      return platform::places_are_same_class(place_, o.place_) &&
+             data_type_ == o.data_type_;
     }
   };
 
   struct OpKernelHash {
-    std::hash<bool> hash_;
+    std::hash<int> hash_;
     size_t operator()(const OpKernelKey& key) const {
-      return hash_(platform::is_gpu_place(key.place_));
+      int place = key.place_.which();
+      int data_type = static_cast<int>(key.data_type_);
+      int pre_hash = data_type << NUM_PLACE_TYPE_LIMIT_IN_BIT |
+                     (place & ((1 << NUM_PLACE_TYPE_LIMIT_IN_BIT) - 1));
+      return hash_(pre_hash);
     }
   };
 
   using OpKernelMap =
-      std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
+      std::unordered_map<OpKernelKey, std::unique_ptr<OpKernelBase>,
+                         OpKernelHash>;
 
   OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
                      const VariableNameMap& outputs, const AttributeMap& attrs)
@@ -451,8 +466,10 @@ class OperatorWithKernel : public OperatorBase {
     RuntimeInferShapeContext infer_shape_ctx(*this, scope);
     this->InferShape(&infer_shape_ctx);
 
-    auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
-    opKernel->Compute(ExecutionContext(*this, scope, dev_ctx));
+    ExecutionContext ctx(*this, scope, dev_ctx);
+    auto& opKernel = AllOpKernels().at(type_).at(
+        OpKernelKey(IndicateDataType(ctx), dev_ctx));
+    opKernel->Compute(ctx);
   }
 
   static std::unordered_map<std::string /* op_type */, OpKernelMap>&
@@ -462,13 +479,43 @@ class OperatorWithKernel : public OperatorBase {
   }
 
   bool SupportGPU() const override {
-    OperatorWithKernel::OpKernelKey key;
-    key.place_ = platform::GPUPlace();
-    return OperatorWithKernel::AllOpKernels().at(type_).count(key) != 0;
+    auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
+    return std::any_of(op_kernels.begin(), op_kernels.end(),
+                       [](OpKernelMap::const_reference kern_pair) {
+                         return platform::is_gpu_place(kern_pair.first.place_);
+                       });
   }
 
  protected:
   virtual void InferShape(InferShapeContextBase* ctx) const = 0;
+
+  // Indicate the kernel DataType by input data. By default, all input data
+  // must have the same type.
+  virtual DataType IndicateDataType(const ExecutionContext& ctx) const {
+    auto& scope = ctx.scope();
+    int data_type = -1;
+    for (auto& input : this->inputs_) {
+      for (auto& ipt_name : input.second) {
+        auto* var = scope.FindVar(ipt_name);
+        if (var != nullptr) {
+          const Tensor* t = nullptr;
+          if (var->IsType<Tensor>()) {
+            t = &var->Get<Tensor>();
+          } else if (var->IsType<LoDTensor>()) {
+            t = &var->Get<LoDTensor>();
+          }
+          if (t != nullptr) {
+            int tmp = static_cast<int>(ToDataType(t->type()));
+            PADDLE_ENFORCE(tmp == data_type || data_type == -1,
+                           "DataType of Paddle Op must be same.");
+            data_type = tmp;
+          }
+        }
+      }
+    }
+    PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
+    return static_cast<DataType>(data_type);
+  }
 };
 
 }  // namespace framework
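
For reference, a minimal standalone sketch of the hashing scheme introduced in OpKernelHash above: the data type is shifted past the bits reserved for the place index, and the place index is masked into that reserved range, so each (data type, place) pair maps to a distinct integer before std::hash<int> is applied. The DataType values, the place indices, and the value of NUM_PLACE_TYPE_LIMIT_IN_BIT below are illustrative assumptions, not Paddle's actual definitions; only the shift-and-mask combination mirrors the operator() added in this diff.

#include <cstddef>
#include <functional>
#include <iostream>

// Stand-in for framework::DataType; the real enum comes from framework.pb.h.
enum class DataType : int { FP32 = 0, FP64 = 1, INT32 = 2 };

// Assumed bit width reserved for the place index (the real constant is
// defined in the platform headers).
constexpr int NUM_PLACE_TYPE_LIMIT_IN_BIT = 4;

// Mirrors OpKernelHash: pack the data type above the place bits, mask the
// place index into the reserved range, then hash the combined integer.
std::size_t HashKernelKey(DataType data_type, int place_index) {
  std::hash<int> hash;
  int pre_hash = static_cast<int>(data_type) << NUM_PLACE_TYPE_LIMIT_IN_BIT |
                 (place_index & ((1 << NUM_PLACE_TYPE_LIMIT_IN_BIT) - 1));
  return hash(pre_hash);
}

int main() {
  // Place indices 0 and 1 are assumed stand-ins for CPUPlace and GPUPlace
  // (in Paddle the index comes from place_.which() on the boost::variant).
  std::cout << HashKernelKey(DataType::FP32, /*CPUPlace*/ 0) << "\n";
  std::cout << HashKernelKey(DataType::FP32, /*GPUPlace*/ 1) << "\n";
  std::cout << HashKernelKey(DataType::FP64, /*CPUPlace*/ 0) << "\n";
  return 0;
}

The mask assumes the place index fits in NUM_PLACE_TYPE_LIMIT_IN_BIT bits, which bounds how many place types the key can distinguish; collisions beyond that are still resolved by OpKernelKey::operator==.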