@@ -921,57 +921,74 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
921921 platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance ();
922922 auto * dev_ctx = pool.Get (place);
923923
924- // check if op[type] has kernel registered.
925- auto & all_op_kernels = AllOpKernels ();
926- auto kernels_iter = all_op_kernels.find (type_);
927- if (kernels_iter == all_op_kernels.end ()) {
928- PADDLE_THROW (
929- " There are no kernels which are registered in the %s operator." , type_);
930- }
924+ if (!kernel_type_) {
925+ // LOG(INFO) << "1, kernel_type is not set.";
926+ // check if op[type] has kernel registered.
927+ auto & all_op_kernels = AllOpKernels ();
928+ auto kernels_iter = all_op_kernels.find (type_);
929+ if (kernels_iter == all_op_kernels.end ()) {
930+ PADDLE_THROW (
931+ " There are no kernels which are registered in the %s operator." ,
932+ type_);
933+ }
931934
932- OpKernelMap& kernels = kernels_iter->second ;
935+ OpKernelMap& kernels = kernels_iter->second ;
933936
934- auto expected_kernel_key = this ->GetExpectedKernelType (
935- ExecutionContext (*this , scope, *dev_ctx, ctx, nullptr ));
936- VLOG (3 ) << " expected_kernel_key:" << expected_kernel_key;
937+ auto expected_kernel_key = this ->GetExpectedKernelType (
938+ ExecutionContext (*this , scope, *dev_ctx, ctx, nullptr ));
939+ VLOG (3 ) << " expected_kernel_key:" << expected_kernel_key;
937940
938- auto kernel_iter = kernels.find (expected_kernel_key);
941+ auto kernel_iter = kernels.find (expected_kernel_key);
939942#ifdef PADDLE_WITH_MKLDNN
940- // workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
941- if (kernel_iter == kernels.end () &&
942- expected_kernel_key.library_type_ == LibraryType::kMKLDNN ) {
943- VLOG (3 ) << " missing MKLDNN kernel: fallbacking to PLAIN one" ;
944- expected_kernel_key.library_type_ = LibraryType::kPlain ;
945- expected_kernel_key.data_layout_ = DataLayout::kAnyLayout ;
946- kernel_iter = kernels.find (expected_kernel_key);
947- }
943+ // workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
944+ if (kernel_iter == kernels.end () &&
945+ expected_kernel_key.library_type_ == LibraryType::kMKLDNN ) {
946+ VLOG (3 ) << " missing MKLDNN kernel: fallbacking to PLAIN one" ;
947+ expected_kernel_key.library_type_ = LibraryType::kPlain ;
948+ expected_kernel_key.data_layout_ = DataLayout::kAnyLayout ;
949+ kernel_iter = kernels.find (expected_kernel_key);
950+ }
948951#endif
949- if (kernel_iter == kernels.end ()) {
950- PADDLE_THROW (" op %s does not have kernel for %s" , type_,
951- KernelTypeToString (expected_kernel_key));
952+ if (kernel_iter == kernels.end ()) {
953+ PADDLE_THROW (" op %s does not have kernel for %s" , type_,
954+ KernelTypeToString (expected_kernel_key));
955+ }
956+
957+ kernel_type_.reset (new OpKernelType (expected_kernel_key));
958+ kernel_func_.reset (new OpKernelFunc (kernel_iter->second ));
952959 }
953960
961+ // std::shared_ptr<OpKernelType> kernel_type = kernel_type_;
962+ // std::shared_ptr<OpKernelFunc> kernel_func = kernel_func_;
963+
954964 std::vector<KernelConfig>* kernel_configs =
955- GetKernelConfig (expected_kernel_key);
965+ // GetKernelConfig(expected_kernel_key);
966+ GetKernelConfig (*kernel_type_);
956967
957968 // Do data transform; PrepareData returns the transfer scope (may be nullptr).
958969 std::vector<std::string> transfered_inplace_vars;
959970 auto * transfer_scope =
960- PrepareData (scope, expected_kernel_key, &transfered_inplace_vars, &ctx);
971+ // PrepareData(scope, expected_kernel_key, &transfered_inplace_vars,
972+ // &ctx);
973+ PrepareData (scope, *kernel_type_, &transfered_inplace_vars, &ctx);
961974
962975 // exec scope is the scope that kernel actually executed on.
963976 const Scope& exec_scope =
964977 (transfer_scope == nullptr ? scope : *transfer_scope);
965978
966- if (!(expected_kernel_key.place_ == dev_ctx->GetPlace ())) {
967- dev_ctx = pool.Get (expected_kernel_key.place_ );
979+ // if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) {
980+ // dev_ctx = pool.Get(expected_kernel_key.place_);
981+ if (!(kernel_type_->place_ == dev_ctx->GetPlace ())) {
982+ dev_ctx = pool.Get (kernel_type_->place_ );
968983 }
969984
970985 RuntimeInferShapeContext infer_shape_ctx (*this , exec_scope, ctx);
971986 this ->InferShape (&infer_shape_ctx);
972987 // TODO(panyx0718): ExecutionContext should only depend on RuntimeContext
973988 // not Scope. Imperative mode only pass inputs and get outputs.
974- kernel_iter->second (
989+ // kernel_iter->second(
990+ // ExecutionContext(*this, exec_scope, *dev_ctx, ctx, kernel_configs));
991+ (*kernel_func_)(
975992 ExecutionContext (*this , exec_scope, *dev_ctx, ctx, kernel_configs));
976993
977994 if (!transfered_inplace_vars.empty ()) {
0 commit comments