@@ -293,8 +293,9 @@ OneDNNPhiKernelInstruction::OneDNNPhiKernelInstruction(
.AsVector();

for (auto& attr : data_format_tensors_attr) {
auto pair = kernel_context_.InputRangeAt(value_exec_info_->GetIdByName(
attr.dyn_cast<pir::StrAttribute>().AsString()));
auto pair =
kernel_context_.InputRangeAt(yaml_info_parser.InputName2Id().at(
attr.dyn_cast<pir::StrAttribute>().AsString()));
for (int i = pair.first; i < pair.second; ++i) {
data_format_tensors_.insert(i);
}
@@ -52,7 +52,7 @@ class OneDNNPhiKernelInstruction : public InstructionBase {

const std::string& Name() const override { return phi_op_name_; }

private:
protected:
paddle::dialect::InferMetaInterface::Concept* infer_meta_interface_{
nullptr}; // not owned

@@ -48,13 +48,73 @@ OneDNNMixedPhiKernelInstruction::OneDNNMixedPhiKernelInstruction(
const platform::Place& place,
pir::Operation* op,
const ValueExecutionInfo* value_exec_info)
: OneDNNPhiKernelInstruction(id, place, op, value_exec_info) {}
: OneDNNPhiKernelInstruction(id, place, op, value_exec_info) {
auto op_attributes = op->attributes();
kernel_name_ =
op_attributes.at("kernel_name").dyn_cast<pir::StrAttribute>().AsString();
kernel_key_ = op_attributes.at("kernel_key")
.dyn_cast<paddle::dialect::KernelAttribute>()
.data();
}

void OneDNNMixedPhiKernelInstruction::Run() {
// Step1. Mixed Dynamic Choose Kernel
// todo if (input_tensor.layout() != phi::DataLayout::ONEDNN)
if (!has_choose_kernel_) {
has_choose_kernel_ = true;
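// Ask the oneDNN kernel whether it can handle the current inputs; the
// per-kernel predicate is registered through
// Kernel::check_if_onednn_kernel_support_ (see kernel_factory.h and
// pad3d_kernel.cc below).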
use_onednn_kernel_ =
phi_kernel_->check_if_onednn_kernel_support_(&kernel_context_);
if (!use_onednn_kernel_) {
auto kernel_result =
phi::KernelFactory::Instance().SelectKernelOrThrowError(kernel_name_,
kernel_key_);
delete phi_kernel_;
phi_kernel_ = new phi::Kernel(kernel_result.kernel);
}
}

// Step2. Run Kernel
if (use_onednn_kernel_) {
OneDNNPhiKernelInstruction::Run();
} else {
// TransLayout first
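// Inputs produced by upstream oneDNN kernels may still carry the ONEDNN
// layout; convert them back to a Paddle layout before running the plain
// CPU kernel.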
auto inputs = kernel_context_.InputsBetween<phi::DenseTensor>(
size_t(0), kernel_context_.InputsSize());

for (size_t i = 0; i < inputs.size(); ++i) {
auto input = inputs[i];
if (input->layout() == phi::DataLayout::ONEDNN) {
DataLayout tmp_layout =
phi::OneDNNContext::tls().get_cur_paddle_data_layout();

// NOTE(zhiqiu): to handle the special case in ApplyDataTransform() in
// data_transfer.cc
if (!input->IsInitialized() && tmp_layout == DataLayout::NHWC) {
auto transed_tensor = const_cast<phi::DenseTensor*>(input);
transed_tensor->set_layout(tmp_layout);
phi::funcs::MatchShapeToLayout(
transed_tensor, phi::DataLayout::ONEDNN, tmp_layout);
} else {
phi::DenseTensor transed_tensor;
transed_tensor.set_meta(input->meta());
phi::funcs::TransDataLayoutFromOneDNN(phi::DataLayout::ONEDNN,
tmp_layout,
*input,
&transed_tensor,
phi::CPUPlace());
*(const_cast<phi::DenseTensor*>(input)) = transed_tensor;
}
}
}

OneDNNPhiKernelInstruction::Run();
VLOG(6) << "Begin run op " << phi_op_name_ << " infer meta.";
if (infer_meta_interface_) {
infer_meta_interface_->infer_meta_(&(infer_meta_context_));
}
VLOG(6) << "End run op " << phi_op_name_ << " infer meta.";
VLOG(6) << "Begin run op " << phi_op_name_ << " kernel.";
(*(phi_kernel_))(&(kernel_context_));
VLOG(6) << "End run op " << phi_op_name_ << " kernel.";
}
}

} // namespace framework
@@ -33,6 +33,12 @@ class OneDNNMixedPhiKernelInstruction : public OneDNNPhiKernelInstruction {
const ValueExecutionInfo* value_exec_info);

void Run() override;

private:
std::string kernel_name_;
phi::KernelKey kernel_key_;
bool has_choose_kernel_{false};
bool use_onednn_kernel_{true};
};

} // namespace framework
26 changes: 22 additions & 4 deletions paddle/fluid/framework/new_executor/pir_interpreter.cc
@@ -588,6 +588,21 @@ void PirInterpreter::UpdateNcclOpNum() {
VLOG(4) << "Update nccl op num, nccl op num is: " << nccl_op_num;
}

void PirInterpreter::UpdateOneDNNOpNum() {
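// Count instructions that run (or may run) oneDNN kernels; a nonzero count
// forces the interpreter into trace-run mode in Run() below.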
int64_t onednn_op_num = 0;
#ifdef PADDLE_WITH_DNNL
for (auto& ins : vec_instruction_base_) {
if (dynamic_cast<OneDNNPhiKernelInstruction*>(ins.get()) != nullptr ||
dynamic_cast<OneDNNLegacyKernelInstruction*>(ins.get()) != nullptr ||
dynamic_cast<OneDNNMixedPhiKernelInstruction*>(ins.get()) != nullptr) {
onednn_op_num = onednn_op_num + 1;
}
}
#endif
onednn_op_num_ = onednn_op_num;
VLOG(4) << "Update onednn op num, onednn op num is: " << onednn_op_num;
}

// Note(zhangbo):
// When there is a KQueueSync type OP in the model, breadth traversal is better
// than depth traversal. For example: OP(O) ->(direct_run)-> OP(A)
@@ -1305,7 +1320,7 @@ paddle::framework::FetchList PirInterpreter::Run(

// Run
if (FLAGS_enable_pir_in_executor_trace_run || nccl_op_num_ > 1 ||
execution_config_.used_for_inference ||
onednn_op_num_ || execution_config_.used_for_inference ||
((execution_config_.used_for_jit || execution_config_.used_for_cinn) &&
(sync_op_num_ == 0))) {
LOG_FIRST_N(INFO, 1) << "pir interpreter is running by trace mode ...";
@@ -1326,7 +1341,7 @@
}
#endif
if (FLAGS_enable_pir_in_executor_trace_run || nccl_op_num_ > 1 ||
execution_config_.used_for_inference ||
onednn_op_num_ || execution_config_.used_for_inference ||
((execution_config_.used_for_jit || execution_config_.used_for_cinn) &&
(sync_op_num_ == 0))) {
TraceRunImpl();
@@ -1395,7 +1410,7 @@ FetchList PirInterpreter::Run(const std::vector<std::string>& feed_names,

// Run
if (FLAGS_enable_pir_in_executor_trace_run || nccl_op_num_ > 1 ||
execution_config_.used_for_inference ||
onednn_op_num_ || execution_config_.used_for_inference ||
((execution_config_.used_for_jit || execution_config_.used_for_cinn) &&
(sync_op_num_ == 0))) {
LOG_FIRST_N(INFO, 1) << "pir interpreter is running by trace mode ...";
@@ -1416,7 +1431,7 @@
}
#endif
if (FLAGS_enable_pir_in_executor_trace_run || nccl_op_num_ > 1 ||
execution_config_.used_for_inference ||
onednn_op_num_ || execution_config_.used_for_inference ||
((execution_config_.used_for_jit || execution_config_.used_for_cinn) &&
(sync_op_num_ == 0))) {
TraceRunImpl();
@@ -1804,6 +1819,9 @@ void PirInterpreter::PreAnalysis() {

UpdateNcclOpNum();
VLOG(4) << "Done UpdateNcclOpNum";

UpdateOneDNNOpNum();
VLOG(4) << "Done UpdateOneDNNOpNum";
}

::pir::Value PirInterpreter::GetValueByName(const std::string& var_name) {
2 changes: 2 additions & 0 deletions paddle/fluid/framework/new_executor/pir_interpreter.h
@@ -124,6 +124,7 @@ class PirInterpreter : public InterpreterBaseImpl {
// build graph
void UpdateSyncOpNum();
void UpdateNcclOpNum();
void UpdateOneDNNOpNum();
void AnalyseExecuteOrderForTrace(
std::map<size_t, std::set<size_t>> op_downstream_map,
InstructionSchedulingPriorityLess compare);
Expand Down Expand Up @@ -196,6 +197,7 @@ class PirInterpreter : public InterpreterBaseImpl {
// used for Trace
int64_t sync_op_num_{-1};
int64_t nccl_op_num_{-1};
int64_t onednn_op_num_{-1};
std::vector<size_t> trace_execute_order_;

std::vector<HookFunc> output_hookfuncs_;
18 changes: 18 additions & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
@@ -833,6 +833,24 @@ bool AnalysisPredictor::PrepareExecutor() {
gpu_pm.EnableIRPrinting();
}
gpu_pm.Run(pir_program_.get());
} else {
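// CPU path: apply constant folding, dead code elimination, and
// fetch -> shadow_output replacement to the PIR program.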
::pir::PassManager cpu_pm(::pir::IrContext::Instance(), 2);

auto constant_folding_pass = ::pir::CreateConstantFoldingPass();
constant_folding_pass->SetNotOwned(pir::kPlaceAttr, &place_);
constant_folding_pass->SetNotOwned(pir::kParamScopeAttr, sub_scope_);

cpu_pm.AddPass(std::move(constant_folding_pass));
cpu_pm.AddPass(::pir::CreateDeadCodeEliminationPass());
cpu_pm.AddPass(::pir::CreateReplaceFetchWithShadowOutputPass());
//----------------------------------------------------------------------------------------------//
if (!config_.glog_info_disabled()) {
cpu_pm.EnablePrintStatistics();
}
if (config_.ir_debug_) {
cpu_pm.EnableIRPrinting();
}
cpu_pm.Run(pir_program_.get());
}

pir_program_ = std::move(
12 changes: 5 additions & 7 deletions paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml
@@ -15,19 +15,17 @@
extra_args : bool is_test=false
data_format_tensors : x, out, mid_out, out_grad

- op : pad3d
extra_args :
Contributor: Is this key still required when there are no extra_args?

Contributor Author: The underlying mechanism does allow omitting it, but when writing this I thought it over and decided to keep an empty one: most ops have extra_args, so an explicit empty extra_args makes it obvious at a glance that this op's extra_args is empty.

data_format_tensors : x
dynamic_fallback : True

# - op : matmul
# extra_args : str mkldnn_data_type="float32"
# layout_transform :
# arg_name: cur_paddle_data_layout
# tensors: x, y

# - op : pad3d
# extra_args :
# layout_transform :
# arg_name: data_format
# tensors: x
# dynamic_fallback : True

# - op : batch_norm
# extra_args : bool fuse_with_relu=false
# layout_transform :
28 changes: 27 additions & 1 deletion paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
@@ -1345,7 +1345,7 @@ void HandleForSpecialOp(
}
}

if (op_item->isa<::pir::YieldOp>() || op_item->isa<::pir::ShadowOutputOp>()) {
if (op_item->isa<::pir::YieldOp>()) {
if (op_item->num_operands() > 0) {
for (size_t i = 0; i < op_item->num_operands(); ++i) {
auto cur_in = op_item->operand_source(i);
@@ -1360,6 +1360,32 @@
}
}

if (op_item->isa<::pir::ShadowOutputOp>()) {
if (op_item->num_operands() > 0) {
for (size_t i = 0; i < op_item->num_operands(); ++i) {
auto cur_in = op_item->operand_source(i);
if (!cur_in) {
vec_inputs.emplace_back();
continue;
}
auto new_in = GetNewInput(
cur_in, *map_value_pair, static_cast<int>(i), op_item->name());
// layout transfer(only for onednn)
#ifdef PADDLE_WITH_DNNL
auto new_in_type = new_in.type();
if (new_in_type.isa<AllocatedDenseTensorType>()) {
if (new_in_type.dyn_cast<AllocatedDenseTensorType>().data_layout() ==
phi::DataLayout::ONEDNN) {
new_in = AddOneDNN2PaddleLayoutTransferOp(
new_in, phi::DataLayout::ANY, block);
}
}
#endif
Comment on lines +1374 to +1383

Contributor: This doesn't look quite right. People may build with WITH_MKLDNN but not run in mkldnn mode; won't this affect the logic of the other modes?

Contributor: Misread it. With the check if (new_in_type.dyn_cast<AllocatedDenseTensorType>().data_layout() == phi::DataLayout::ONEDNN) it won't.

Contributor Author: When not running in mkldnn mode, new_in_type.dyn_cast<AllocatedDenseTensorType>().data_layout() == phi::DataLayout::ONEDNN is false, so the other modes are unaffected. But given the current structure of pd_op_to_kernel_pass.cc, this is the only place the layout transfer can go.

vec_inputs.push_back(new_in);
}
}
}

if (op_item->isa<::pir::SetParameterOp>()) {
if (op_item->num_operands() > 0) {
for (size_t i = 0; i < op_item->num_operands(); ++i) {
2 changes: 2 additions & 0 deletions paddle/phi/core/kernel_factory.h
@@ -282,6 +282,8 @@ class Kernel {
}

GetKernelTypeForVarFn get_kerneltype_forvar_fn_{nullptr};
std::function<bool(const KernelContext* ctx)> check_if_onednn_kernel_support_{
nullptr};

private:
KernelFn fn_{nullptr};
5 changes: 2 additions & 3 deletions paddle/phi/kernels/cpu/onednn_to_paddle_layout_kernel.cc
@@ -62,14 +62,13 @@ void OneDNN2PaddleLayout(const Context& dev_ctx,
}

DataLayout tmp_layout = static_cast<DataLayout>(dst_layout);
if (static_cast<DataLayout>(dst_layout) == DataLayout::ANY) {
tmp_layout = phi::OneDNNContext::tls().get_cur_paddle_data_layout();
}

if (tmp_layout == DataLayout::ANY) {
tmp_layout = phi::OneDNNContext::tls().get_cur_paddle_data_layout();
}

VLOG(4) << "src_layout: " << src_layout << ", tmp_layout: " << tmp_layout;

// NOTE(zhiqiu): to handle the special case in ApplyDataTransform() in
// data_transfer.cc
if (!x.IsInitialized() && src_layout == DataLayout::ONEDNN &&
10 changes: 10 additions & 0 deletions paddle/phi/kernels/onednn/pad3d_kernel.cc
@@ -38,6 +38,15 @@ KernelKey Pad3dGetKernelTypeForVar(const GetKernelTypeForVarContext* ctx) {
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}

bool Pad3dCheckIfOneDNNSupport(const KernelContext* ctx) {
// only constant mode and non-blocked layouts are supported for oneDNN
if (ctx->AttrAt<std::string>(1) == "constant" &&
ctx->InputAt<phi::DenseTensor>(0).mem_desc().get_inner_nblks() == 0) {
return true;
}
return false;
}

template <typename T, typename Context>
void Pad3dKernel(const Context& dev_ctx,
const DenseTensor& x,
Expand All @@ -58,4 +67,5 @@ PD_REGISTER_KERNEL(pad3d,
phi::dtype::bfloat16,
float) {
kernel->get_kerneltype_forvar_fn_ = phi::Pad3dGetKernelTypeForVar;
kernel->check_if_onednn_kernel_support_ = phi::Pad3dCheckIfOneDNNSupport;
}
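The pad3d entry with dynamic_fallback : True in ops_onednn_extra.yaml and the predicate registered here work together: OneDNNMixedPhiKernelInstruction::Run() calls the predicate once at runtime and falls back to the plain CPU kernel when it returns false. A minimal sketch of how another oneDNN kernel could opt in to the same mechanism, assuming the same headers as pad3d_kernel.cc above; the SoftmaxCheckIfOneDNNSupport name and its condition are hypothetical, not part of this change:

namespace phi {

// Hypothetical runtime-support check for another oneDNN kernel. It rejects
// blocked oneDNN memory formats, mirroring the pad3d check above; a real
// kernel would encode whatever conditions it actually supports.
bool SoftmaxCheckIfOneDNNSupport(const KernelContext* ctx) {
  return ctx->InputAt<phi::DenseTensor>(0).mem_desc().get_inner_nblks() == 0;
}

}  // namespace phi

// In the kernel's PD_REGISTER_KERNEL block, assign it the same way pad3d does:
//   kernel->check_if_onednn_kernel_support_ = phi::SoftmaxCheckIfOneDNNSupport;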
20 changes: 20 additions & 0 deletions test/ir/inference/auto_scan_test.py
@@ -342,6 +342,26 @@ def inference_config_str(self, config) -> str:
return str(dic)


class PirMkldnnAutoScanTest(MkldnnAutoScanTest):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def run_test_config(
self, model, params, prog_config, pred_config, feed_data
) -> Dict[str, np.ndarray]:
"""
Test a single case.
"""
paddle.set_flags({'FLAGS_enable_pir_in_executor': True})
pred_config.switch_ir_optim(False)
pred_config.enable_new_executor()
result = super().run_test_config(
model, params, prog_config, pred_config, feed_data
)
paddle.set_flags({'FLAGS_enable_pir_in_executor': False})
return result


class PassAutoScanTest(AutoScanTest):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
1 change: 1 addition & 0 deletions test/ir/inference/program_config.py
@@ -344,6 +344,7 @@ def _cast(self) -> None:

def create_fake_model(program_config):
'''Create a Paddle model(in memory) according to the given config.'''
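# Build the in-memory model with the PIR executor disabled; tests that need the
# PIR path (e.g. PirMkldnnAutoScanTest above) re-enable the flag for the actual
# inference run.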
paddle.set_flags({'FLAGS_enable_pir_in_executor': False})
program_config = copy.deepcopy(program_config)
program_config._cast()
paddle.enable_static()