2 changes: 1 addition & 1 deletion paddle/fluid/framework/details/build_strategy.cc
@@ -321,7 +321,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
         continue;
       }
     } else if (pass->Type() == "onednn_placement_pass") {
-      pass->Set("mkldnn_enabled_op_types",
+      pass->Set("onednn_enabled_op_types",
                 new std::unordered_set<std::string>(onednn_enabled_op_types_));
     }
     VLOG(1) << "Start Apply Pass " << pass->Type();

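The key passed to `pass->Set(...)` here must match the key the placement pass reads back with `Get<...>(...)` (see onednn_placement_pass.h below), which is why the rename has to land in every caller within the same PR. A minimal sketch of that contract, using a hypothetical `DemoPass` rather than Paddle's real `ir::Pass`:

```cpp
#include <cassert>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>

// Hypothetical stand-in for a pass's attribute storage: Set() stores a
// heap-allocated value under a string key; Get<T>() retrieves it by the
// same key and fails if the key was never set.
class DemoPass {
 public:
  template <typename T>
  void Set(const std::string& key, T* value) {
    attrs_[key] = std::shared_ptr<void>(value);
  }
  template <typename T>
  const T& Get(const std::string& key) const {
    auto it = attrs_.find(key);
    assert(it != attrs_.end() && "pass attribute not set");  // key mismatch
    return *static_cast<T*>(it->second.get());
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<void>> attrs_;
};

int main() {
  DemoPass pass;
  // Caller side (build_strategy.cc after this PR): store under the new key.
  pass.Set("onednn_enabled_op_types",
           new std::unordered_set<std::string>({"conv2d", "relu"}));
  // Pass side (onednn_placement_pass.h): read back with the identical key.
  const auto& types =
      pass.Get<std::unordered_set<std::string>>("onednn_enabled_op_types");
  std::cout << types.size() << " op types enabled\n";
}
```

The `.RequirePassAttr("onednn_enabled_op_types")` in the registration below is what makes a missing or misspelled key fail fast when the pass runs, rather than misbehaving silently.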
6 changes: 3 additions & 3 deletions paddle/fluid/framework/ir/onednn/onednn_placement_pass.cc
@@ -64,7 +64,7 @@ inline bool FoundPhiOneDNNKernelWithCorrectDataType(
   return false;
 }
 
-bool MKLDNNPlacementPass::IsSupport(const Node* op) const {
+bool ONEDNNPlacementPass::IsSupport(const Node* op) const {
   if (FoundOneDNNKernelWithCorrectDataType(op) ||
       FoundPhiOneDNNKernelWithCorrectDataType(op)) {
     // For interpolate ops, there's a little difference between Paddle and
@@ -89,8 +89,8 @@ bool MKLDNNPlacementPass::IsSupport(const Node* op) const {
 
 }  // namespace paddle::framework::ir
 
-REGISTER_PASS(onednn_placement_pass, paddle::framework::ir::MKLDNNPlacementPass)
-    .RequirePassAttr("mkldnn_enabled_op_types");
+REGISTER_PASS(onednn_placement_pass, paddle::framework::ir::ONEDNNPlacementPass)
+    .RequirePassAttr("onednn_enabled_op_types");
 
 REGISTER_PASS_CAPABILITY(onednn_placement_pass)
     .AddCombination(

6 changes: 3 additions & 3 deletions paddle/fluid/framework/ir/onednn/onednn_placement_pass.h
@@ -26,17 +26,17 @@ namespace ir {
 /*
  * Specifies which operators should use MKLDNN.
  */
-class MKLDNNPlacementPass : public PlacementPassBase {
+class ONEDNNPlacementPass : public PlacementPassBase {
  protected:
   bool IsSupport(const Node* op) const override;
 
  private:
-  const std::string GetPlacementName() const override { return "MKLDNN"; }
+  const std::string GetPlacementName() const override { return "ONEDNN"; }
 
   const std::string GetAttrName() const override { return "use_mkldnn"; }
 
   const std::unordered_set<std::string> GetOpTypesList() const override {
-    return Get<std::unordered_set<std::string>>("mkldnn_enabled_op_types");
+    return Get<std::unordered_set<std::string>>("onednn_enabled_op_types");
   }
 };
 

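The three overrides above are all the pass-specific logic; the shared `PlacementPassBase` (not shown in this diff) is what walks the graph and flips the attribute. A rough model of that division of labor, where the base-class body is an assumption for illustration and not Paddle's actual implementation:

```cpp
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Toy op: a type name plus the set of boolean attributes currently true.
struct Op {
  std::string type;
  std::unordered_set<std::string> true_attrs;
};

// Assumed shape of the base class: it owns the traversal; subclasses only
// say which attribute, which op types, and which placement name.
class PlacementPassBase {
 public:
  virtual ~PlacementPassBase() = default;
  void Apply(std::vector<Op>* ops) const {
    const auto allowed = GetOpTypesList();
    for (auto& op : *ops) {
      // An empty list means "enable everywhere the op is supported".
      if ((allowed.empty() || allowed.count(op.type)) && IsSupport(op)) {
        op.true_attrs.insert(GetAttrName());  // e.g. use_mkldnn = true
      }
    }
  }

 protected:
  virtual bool IsSupport(const Op& op) const = 0;
  virtual const std::string GetPlacementName() const = 0;
  virtual const std::string GetAttrName() const = 0;
  virtual const std::unordered_set<std::string> GetOpTypesList() const = 0;
};

class ONEDNNPlacementPass : public PlacementPassBase {
 protected:
  bool IsSupport(const Op&) const override { return true; }  // stub
  const std::string GetPlacementName() const override { return "ONEDNN"; }
  const std::string GetAttrName() const override { return "use_mkldnn"; }
  const std::unordered_set<std::string> GetOpTypesList() const override {
    return {"conv2d", "relu"};
  }
};

int main() {
  std::vector<Op> ops = {{"conv2d", {}}, {"pool2d", {}}};
  ONEDNNPlacementPass().Apply(&ops);
  std::cout << ops[0].true_attrs.count("use_mkldnn") << " "    // 1
            << ops[1].true_attrs.count("use_mkldnn") << "\n";  // 0
}
```

Note that `GetAttrName()` still returns `"use_mkldnn"` after this PR, which is likely why the tester below counts ops with either `use_mkldnn` or `use_onednn` set to true.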
18 changes: 10 additions & 8 deletions paddle/fluid/framework/ir/onednn/onednn_placement_pass_tester.cc
@@ -133,7 +133,7 @@ class PlacementPassTest {
 
     auto pass = PassRegistry::Instance().Get("onednn_placement_pass");
 
-    pass->Set("mkldnn_enabled_op_types",
+    pass->Set("onednn_enabled_op_types",
               new std::unordered_set<std::string>(onednn_enabled_op_types));
 
     graph.reset(pass->Apply(graph.release()));
@@ -143,8 +143,10 @@ class PlacementPassTest {
     for (auto* node : graph->Nodes()) {
       if (node->IsOp()) {
         auto* op = node->Op();
-        if (op->HasAttr("use_mkldnn") &&
-            PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn"))) {
+        if ((op->HasAttr("use_mkldnn") &&
+             PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn"))) ||
+            (op->HasAttr("use_onednn") &&
+             PADDLE_GET_CONST(bool, op->GetAttr("use_onednn")))) {
           ++use_onednn_true_count;
         }
       }
@@ -156,27 +158,27 @@ class PlacementPassTest {
   void PlacementNameTest() {
     auto pass = PassRegistry::Instance().Get("onednn_placement_pass");
     EXPECT_EQ(static_cast<PlacementPassBase*>(pass.get())->GetPlacementName(),
-              "MKLDNN");
+              "ONEDNN");
   }
 };
 
-TEST(MKLDNNPlacementPass, enable_conv_relu) {
+TEST(ONEDNNPlacementPass, enable_conv_relu) {
   // 2 conv (1 conv is always true) + 2 relu (1 relu is always true) + 0 pool
   PlacementPassTest().MainTest({"conv2d", "relu"}, 4);
 }
 
-TEST(MKLDNNPlacementPass, enable_relu_pool) {
+TEST(ONEDNNPlacementPass, enable_relu_pool) {
   // 1 conv (1 conv is always true) + 2 relu (1 relu is always true) + 1 pool
   PlacementPassTest().MainTest({"relu", "pool2d"}, 4);
 }
 
-TEST(MKLDNNPlacementPass, enable_all) {
+TEST(ONEDNNPlacementPass, enable_all) {
   // 2 conv (1 conv is always true) + 2 relu (1 relu is always true) + 1 pool +
   // 1 concat
   PlacementPassTest().MainTest({}, 6);
 }
 
-TEST(MKLDNNPlacementPass, placement_name) {
+TEST(ONEDNNPlacementPass, placement_name) {
   PlacementPassTest().PlacementNameTest();
 }
 

12 changes: 6 additions & 6 deletions paddle/fluid/inference/analysis/argument.h
@@ -193,12 +193,12 @@ struct Argument {
   // whether to mute all logs in inference.
   DECL_ARGUMENT_FIELD(disable_logs, DisableLogs, bool);
 
-  // Pass a set of op types to enable its mkldnn kernel
-  DECL_ARGUMENT_FIELD(mkldnn_enabled_op_types,
-                      MKLDNNEnabledOpTypes,
+  // Pass a set of op types to enable its onednn kernel
+  DECL_ARGUMENT_FIELD(onednn_enabled_op_types,
+                      ONEDNNEnabledOpTypes,
                       std::unordered_set<std::string>);
-  // The cache capacity of different input shapes for mkldnn.
-  DECL_ARGUMENT_FIELD(mkldnn_cache_capacity, MkldnnCacheCapacity, int);
+  // The cache capacity of different input shapes for onednn.
+  DECL_ARGUMENT_FIELD(mkldnn_cache_capacity, OnednnCacheCapacity, int);
 
 #ifdef PADDLE_WITH_DNNL
   // A set of op types to enable their quantized kernels
@@ -219,7 +219,7 @@ struct Argument {
                       Bfloat16EnabledOpTypes,
                       std::unordered_set<std::string>);
 
-  DECL_ARGUMENT_FIELD(use_onednn_int8, UseMkldnnInt8, bool);
+  DECL_ARGUMENT_FIELD(use_onednn_int8, UseOnednnInt8, bool);
 #endif
 
   // Passed from config.

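`DECL_ARGUMENT_FIELD` pairs a snake_case field name with a PascalCase accessor name, so both halves of the rename matter: `onednn_enabled_op_types` is what ir_pass_manager.cc reads, and `SetONEDNNEnabledOpTypes` is what analysis_predictor.cc calls. A rough sketch of the kind of boilerplate such a macro can generate (the real macro in argument.h also tracks whether a field has been set, and differs in detail):

```cpp
#include <iostream>
#include <string>
#include <unordered_set>

// Illustrative expansion only: a stored value plus a snake_case getter and
// a PascalCase setter, assumed for this sketch.
#define DECL_ARGUMENT_FIELD(field__, Field__, type__)     \
 public:                                                  \
  const type__& field__() const { return field__##_; }   \
  void Set##Field__(const type__& v) { field__##_ = v; } \
                                                          \
 private:                                                 \
  type__ field__##_;

struct Argument {
  DECL_ARGUMENT_FIELD(onednn_enabled_op_types,
                      ONEDNNEnabledOpTypes,
                      std::unordered_set<std::string>);
};

int main() {
  Argument arg;
  // PascalCase names the setter (used by analysis_predictor.cc);
  // snake_case names the getter (used by ir_pass_manager.cc).
  arg.SetONEDNNEnabledOpTypes({"conv2d", "relu"});
  std::cout << arg.onednn_enabled_op_types().size() << "\n";  // prints 2
}
```

This split also explains the asymmetric `mkldnn_cache_capacity` line above: only the accessor half (`OnednnCacheCapacity`) is renamed in this PR, while the stored field keeps its old name for now.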
4 changes: 2 additions & 2 deletions paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -131,9 +131,9 @@ void IRPassManager::CreatePasses(Argument *argument,
       pass->Set("optim_cache_dir", new std::string(std::move(optim_cache_dir)));
       pass_num++;
     } else if (pass_name == "onednn_placement_pass") {
-      pass->Set("mkldnn_enabled_op_types",
+      pass->Set("onednn_enabled_op_types",
                 new std::unordered_set<std::string>(
-                    argument->mkldnn_enabled_op_types()));
+                    argument->onednn_enabled_op_types()));
     } else if (pass_name == "cudnn_placement_pass") {
       pass->Set("cudnn_enabled_op_types",
                 new std::unordered_set<std::string>());

14 changes: 7 additions & 7 deletions paddle/fluid/inference/api/analysis_predictor.cc
@@ -1031,8 +1031,8 @@ void AnalysisPredictor::OptimizeInferencePirProgram() {
     }
 #endif
 #ifdef PADDLE_WITH_DNNL
-  } else if (config_.mkldnn_enabled()) {
-    // mkldnn
+  } else if (config_.onednn_enabled()) {
+    // onednn
     pir::IrContext *ctx = pir::IrContext::Instance();
     ctx->GetOrRegisterDialect<paddle::dialect::OneDNNOperatorDialect>();
     if (!config_.custom_pass_only_) {
@@ -2100,9 +2100,9 @@ void AnalysisPredictor::PrepareArgument() {
   argument_->SetIpuCustomPatterns(config_.ipu_custom_patterns_);
 #endif
 
-  if (config_.mkldnn_enabled() && !config_.use_gpu()) {
-    LOG(INFO) << "MKLDNN is enabled";
-    argument_->SetMKLDNNEnabledOpTypes(config_.onednn_enabled_op_types_);
+  if (config_.onednn_enabled() && !config_.use_gpu()) {
+    LOG(INFO) << "ONEDNN is enabled";
+    argument_->SetONEDNNEnabledOpTypes(config_.onednn_enabled_op_types_);
   }
 
   if (config_.cinn_enabled()) {
@@ -2115,7 +2115,7 @@ void AnalysisPredictor::PrepareArgument() {
     argument_->SetBfloat16EnabledOpTypes(config_.bfloat16_enabled_op_types_);
   }
 
-  if (config_.mkldnn_int8_enabled()) {
+  if (config_.onednn_int8_enabled()) {
     LOG(INFO) << "Int8 is enabled";
     argument_->SetQuantizeEnabledOpTypes(config_.quantize_enabled_op_types_);
     argument_->SetQuantizeExcludedOpIds(config_.quantize_excluded_op_ids_);
@@ -2296,7 +2296,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
 #if defined(_WIN32)
   argument_->PartiallyRelease();
 #else
-  if (config_.mkldnn_enabled() ||
+  if (config_.onednn_enabled() ||
       (config_.tensorrt_engine_enabled() &&
        config_.tensorrt_precision_mode_ ==
            AnalysisConfig::Precision::kInt8)) {  // NOLINT

2 changes: 1 addition & 1 deletion paddle/fluid/inference/capi/pd_config.cc
@@ -311,7 +311,7 @@ bool PD_OnednnEnabled(const PD_AnalysisConfig* config) {
       config,
       common::errors::InvalidArgument(
           "The pointer of analysis configuration shouldn't be nullptr"));
-  return config->config.mkldnn_enabled();
+  return config->config.onednn_enabled();
 }
 
 void PD_SetCpuMathLibraryNumThreads(PD_AnalysisConfig* config,

2 changes: 1 addition & 1 deletion paddle/fluid/jit/engine/interpreter_engine.cc
@@ -53,7 +53,7 @@ void InterpreterEngine::CreateInterpreterCore() {
 #ifdef PADDLE_WITH_DNNL
   auto onednn_pass =
       framework::ir::PassRegistry::Instance().Get("onednn_placement_pass");
-  onednn_pass->Set("mkldnn_enabled_op_types",
+  onednn_pass->Set("onednn_enabled_op_types",
                    new std::unordered_set<std::string>({}));
   onednn_pass->Apply(&graph);
 #endif

6 changes: 3 additions & 3 deletions paddle/fluid/operators/generator/get_expected_kernel_func.cc
@@ -61,7 +61,7 @@ static bool ReduceOpHasOptimizedOneDNNKernel(
 }
 
 // only poolop
-bool CanMKLDNNSupportPool(const framework::ExecutionContext& ctx) {
+bool CanONEDNNSupportPool(const framework::ExecutionContext& ctx) {
   if (ctx.Attr<bool>("adaptive") == false) return true;
   // oneDNN is supporting only unchangeable in size pool window
   auto src_tz = common::vectorize(ctx.Input<phi::DenseTensor>("X")->dims());
@@ -181,7 +181,7 @@ phi::KernelKey GetPoolExpectedKernelType(
   auto data_type = op_ptr->OperatorWithKernel::IndicateVarDataType(ctx, "X");
 
   // NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_DNNL
-  op_ptr->SetDnnFallback(!CanMKLDNNSupportPool(ctx));
+  op_ptr->SetDnnFallback(!CanONEDNNSupportPool(ctx));
   // NOTE(jiahongyu) END: Above codes originally enclosed by PADDLE_WITH_DNNL
 
   return phi::KernelKey(data_type, ctx.GetPlace());
@@ -194,7 +194,7 @@ phi::KernelKey GetPoolDoubleGradExpectedKernelType(
       op_ptr->OperatorWithKernel::IndicateVarDataType(ctx, "grad_x@GRAD");
 
   // NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_DNNL
-  op_ptr->SetDnnFallback(!CanMKLDNNSupportPool(ctx));
+  op_ptr->SetDnnFallback(!CanONEDNNSupportPool(ctx));
   // NOTE(jiahongyu) END: Above codes originally enclosed by PADDLE_WITH_DNNL
 
   return phi::KernelKey(data_type, ctx.GetPlace());

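The diff truncates `CanONEDNNSupportPool` after the early return, so the adaptive branch isn't visible here. The usual constraint is that an adaptive pool maps to a fixed-window oneDNN pool only when each spatial input extent divides evenly by the requested output extent; the sketch below illustrates that check as an assumption about the hidden body, not a copy of it:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Assumed constraint behind CanONEDNNSupportPool's adaptive branch: the
// pooling window stays constant in size only when every spatial input
// extent is evenly divisible by the corresponding output extent.
bool AdaptivePoolHasFixedWindow(const std::vector<int64_t>& in_hw,
                                const std::vector<int64_t>& out_hw) {
  if (in_hw.size() != out_hw.size()) return false;
  for (size_t i = 0; i < in_hw.size(); ++i) {
    if (out_hw[i] == 0 || in_hw[i] % out_hw[i] != 0) return false;
  }
  return true;
}

int main() {
  std::cout << AdaptivePoolHasFixedWindow({32, 32}, {8, 8}) << "\n";  // 1
  std::cout << AdaptivePoolHasFixedWindow({30, 30}, {7, 7}) << "\n";  // 0
}
```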
9 changes: 9 additions & 0 deletions paddle/fluid/pybind/compiled_program.cc
@@ -824,6 +824,15 @@ void BindCompiledProgram(pybind11::module &m) {  // NOLINT
             const std::unordered_set<std::string> &onednn_enabled_op_types) {
             self.onednn_enabled_op_types_ = onednn_enabled_op_types;
           })
+      .def_property(
+          "onednn_enabled_op_types",
+          [](const BuildStrategy &self) {
+            return self.onednn_enabled_op_types_;
+          },
+          [](BuildStrategy &self,
+             const std::unordered_set<std::string> &onednn_enabled_op_types) {
+            self.onednn_enabled_op_types_ = onednn_enabled_op_types;
+          })
       .def_property(
           "allow_cuda_graph_capture",
           [](const BuildStrategy &self) {

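The nine added lines register a second `def_property` while the tail of the existing one (visible in the context lines, with an identical setter body) is left untouched, so `BuildStrategy` presumably exposes the same underlying `onednn_enabled_op_types_` set under both the previously registered Python attribute name and the new `onednn_enabled_op_types` during the transition.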
@@ -429,7 +429,7 @@ def _optimize_fp32_graph(self, graph):
         graph = self._update_activations(graph)
         graph = self._remove_ctrl_vars(graph)
         graph = self._apply_pass(
-            graph, 'onednn_placement_pass', ['mkldnn_enabled_op_types'], [set()]
+            graph, 'onednn_placement_pass', ['onednn_enabled_op_types'], [set()]
         )
         # remove dropout ops
         graph = self._apply_pass(graph, 'simplify_with_basic_ops_pass')