Skip to content

Commit d4674da

Browse files
committed
Cache the chosen kernel of operators.
test=develop
1 parent 31d830d commit d4674da

File tree

2 files changed

+47
-28
lines changed

2 files changed

+47
-28
lines changed

paddle/fluid/framework/operator.cc

Lines changed: 45 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -921,57 +921,74 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
921921
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
922922
auto* dev_ctx = pool.Get(place);
923923

924-
// check if op[type] has kernel registered.
925-
auto& all_op_kernels = AllOpKernels();
926-
auto kernels_iter = all_op_kernels.find(type_);
927-
if (kernels_iter == all_op_kernels.end()) {
928-
PADDLE_THROW(
929-
"There are no kernels which are registered in the %s operator.", type_);
930-
}
924+
if (!kernel_type_) {
925+
// LOG(INFO) << "1, kernel_type is not set.";
926+
// check if op[type] has kernel registered.
927+
auto& all_op_kernels = AllOpKernels();
928+
auto kernels_iter = all_op_kernels.find(type_);
929+
if (kernels_iter == all_op_kernels.end()) {
930+
PADDLE_THROW(
931+
"There are no kernels which are registered in the %s operator.",
932+
type_);
933+
}
931934

932-
OpKernelMap& kernels = kernels_iter->second;
935+
OpKernelMap& kernels = kernels_iter->second;
933936

934-
auto expected_kernel_key = this->GetExpectedKernelType(
935-
ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr));
936-
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
937+
auto expected_kernel_key = this->GetExpectedKernelType(
938+
ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr));
939+
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
937940

938-
auto kernel_iter = kernels.find(expected_kernel_key);
941+
auto kernel_iter = kernels.find(expected_kernel_key);
939942
#ifdef PADDLE_WITH_MKLDNN
940-
// workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
941-
if (kernel_iter == kernels.end() &&
942-
expected_kernel_key.library_type_ == LibraryType::kMKLDNN) {
943-
VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one";
944-
expected_kernel_key.library_type_ = LibraryType::kPlain;
945-
expected_kernel_key.data_layout_ = DataLayout::kAnyLayout;
946-
kernel_iter = kernels.find(expected_kernel_key);
947-
}
943+
// workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
944+
if (kernel_iter == kernels.end() &&
945+
expected_kernel_key.library_type_ == LibraryType::kMKLDNN) {
946+
VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one";
947+
expected_kernel_key.library_type_ = LibraryType::kPlain;
948+
expected_kernel_key.data_layout_ = DataLayout::kAnyLayout;
949+
kernel_iter = kernels.find(expected_kernel_key);
950+
}
948951
#endif
949-
if (kernel_iter == kernels.end()) {
950-
PADDLE_THROW("op %s does not have kernel for %s", type_,
951-
KernelTypeToString(expected_kernel_key));
952+
if (kernel_iter == kernels.end()) {
953+
PADDLE_THROW("op %s does not have kernel for %s", type_,
954+
KernelTypeToString(expected_kernel_key));
955+
}
956+
957+
kernel_type_.reset(new OpKernelType(expected_kernel_key));
958+
kernel_func_.reset(new OpKernelFunc(kernel_iter->second));
952959
}
953960

961+
// std::shared_ptr<OpKernelType> kernel_type = kernel_type_;
962+
// std::shared_ptr<OpKernelFunc> kernel_func = kernel_func_;
963+
954964
std::vector<KernelConfig>* kernel_configs =
955-
GetKernelConfig(expected_kernel_key);
965+
// GetKernelConfig(expected_kernel_key);
966+
GetKernelConfig(*kernel_type_);
956967

957968
// do data transformScope &transfer_scope;
958969
std::vector<std::string> transfered_inplace_vars;
959970
auto* transfer_scope =
960-
PrepareData(scope, expected_kernel_key, &transfered_inplace_vars, &ctx);
971+
// PrepareData(scope, expected_kernel_key, &transfered_inplace_vars,
972+
// &ctx);
973+
PrepareData(scope, *kernel_type_, &transfered_inplace_vars, &ctx);
961974

962975
// exec scope is the scope that kernel actually executed on.
963976
const Scope& exec_scope =
964977
(transfer_scope == nullptr ? scope : *transfer_scope);
965978

966-
if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) {
967-
dev_ctx = pool.Get(expected_kernel_key.place_);
979+
// if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) {
980+
// dev_ctx = pool.Get(expected_kernel_key.place_);
981+
if (!(kernel_type_->place_ == dev_ctx->GetPlace())) {
982+
dev_ctx = pool.Get(kernel_type_->place_);
968983
}
969984

970985
RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, ctx);
971986
this->InferShape(&infer_shape_ctx);
972987
// TODO(panyx0718): ExecutionContext should only depend on RuntimeContext
973988
// not Scope. Imperative mode only pass inputs and get outputs.
974-
kernel_iter->second(
989+
// kernel_iter->second(
990+
// ExecutionContext(*this, exec_scope, *dev_ctx, ctx, kernel_configs));
991+
(*kernel_func_)(
975992
ExecutionContext(*this, exec_scope, *dev_ctx, ctx, kernel_configs));
976993

977994
if (!transfered_inplace_vars.empty()) {

paddle/fluid/framework/operator.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,8 @@ class OperatorWithKernel : public OperatorBase {
541541

542542
protected:
543543
mutable OpKernelConfigsMap kernel_configs_map_;
544+
mutable std::shared_ptr<OpKernelType> kernel_type_;
545+
mutable std::shared_ptr<OpKernelFunc> kernel_func_;
544546
};
545547

546548
extern bool OpSupportGPU(const std::string& op_type);

0 commit comments

Comments
 (0)