Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class OneDNNPhiKernelInstruction : public InstructionBase {

const std::string& Name() const override { return phi_op_name_; }

private:
protected:
paddle::dialect::InferMetaInterface::Concept* infer_meta_interface_{
nullptr}; // not owned

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,73 @@ OneDNNMixedPhiKernelInstruction::OneDNNMixedPhiKernelInstruction(
const platform::Place& place,
pir::Operation* op,
const ValueExecutionInfo* value_exec_info)
: OneDNNPhiKernelInstruction(id, place, op, value_exec_info) {}
: OneDNNPhiKernelInstruction(id, place, op, value_exec_info) {
auto op_attributes = op->attributes();
kernel_name_ =
op_attributes.at("kernel_name").dyn_cast<pir::StrAttribute>().AsString();
kernel_key_ = op_attributes.at("kernel_key")
.dyn_cast<paddle::dialect::KernelAttribute>()
.data();
}

void OneDNNMixedPhiKernelInstruction::Run() {
  // Step1. Mixed Dynamic Choose Kernel.
  // On the first Run(), ask the oneDNN kernel whether it supports this
  // kernel context; if not, select the plain phi kernel with the same
  // name/key and replace phi_kernel_ so the fallback is used from now on.
  // todo if (input_tensor.layout() != phi::DataLayout::ONEDNN)
  if (!has_choose_kernel_) {
    has_choose_kernel_ = true;
    use_onednn_kernel_ =
        phi_kernel_->check_if_onednn_kernel_support_(&kernel_context_);
    if (!use_onednn_kernel_) {
      auto kernel_result =
          phi::KernelFactory::Instance().SelectKernelOrThrowError(kernel_name_,
                                                                  kernel_key_);
      delete phi_kernel_;
      phi_kernel_ = new phi::Kernel(kernel_result.kernel);
    }
  }

  // Step2. Run Kernel.
  if (use_onednn_kernel_) {
    OneDNNPhiKernelInstruction::Run();
  } else {
    // Fallback path: first transform any inputs still carrying the ONEDNN
    // layout back to a paddle data layout the non-oneDNN kernel understands.
    auto inputs = kernel_context_.InputsBetween<phi::DenseTensor>(
        size_t(0), kernel_context_.InputsSize());

    for (size_t i = 0; i < inputs.size(); ++i) {
      auto input = inputs[i];
      if (input->layout() == phi::DataLayout::ONEDNN) {
        DataLayout tmp_layout =
            phi::OneDNNContext::tls().get_cur_paddle_data_layout();

        // NOTE(zhiqiu): to handle the special case in ApplyDataTransform() in
        // data_transfer.cc
        if (!input->IsInitialized() && tmp_layout == DataLayout::NHWC) {
          auto transed_tensor = const_cast<phi::DenseTensor*>(input);
          transed_tensor->set_layout(tmp_layout);
          phi::funcs::MatchShapeToLayout(
              transed_tensor, phi::DataLayout::ONEDNN, tmp_layout);
        } else {
          phi::DenseTensor transed_tensor;
          transed_tensor.set_meta(input->meta());
          phi::funcs::TransDataLayoutFromOneDNN(phi::DataLayout::ONEDNN,
                                                tmp_layout,
                                                *input,
                                                &transed_tensor,
                                                phi::CPUPlace());
          *(const_cast<phi::DenseTensor*>(input)) = transed_tensor;
        }
      }
    }

    // FIX: the original additionally called OneDNNPhiKernelInstruction::Run()
    // right before the infer-meta + kernel invocation below, which would
    // execute the op twice on the fallback path. Run only the selected
    // (non-oneDNN) phi kernel here.
    VLOG(6) << "Begin run op " << phi_op_name_ << " infer meta.";
    if (infer_meta_interface_) {
      infer_meta_interface_->infer_meta_(&(infer_meta_context_));
    }
    VLOG(6) << "End run op " << phi_op_name_ << " infer meta.";
    VLOG(6) << "Begin run op " << phi_op_name_ << " kernel.";
    (*(phi_kernel_))(&(kernel_context_));
    VLOG(6) << "End run op " << phi_op_name_ << " kernel.";
  }
}

} // namespace framework
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ class OneDNNMixedPhiKernelInstruction : public OneDNNPhiKernelInstruction {
const ValueExecutionInfo* value_exec_info);

void Run() override;

private:
std::string kernel_name_;
phi::KernelKey kernel_key_;
bool has_choose_kernel_{false};
bool use_onednn_kernel_{true};
};

} // namespace framework
Expand Down
12 changes: 5 additions & 7 deletions paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,17 @@
extra_args : bool is_test=false
data_format_tensors : x, out, mid_out, out_grad

- op : pad3d
extra_args :
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

没有extra_args仍需要指定这个key吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个底层能力上是可以不指定的。但是我写这里的时候,想了一下,还是想放个空的。大多数op都有extra_args,写一个空的extra_args更能让看代码的人一眼就知道他的extra_args是空的。

data_format_tensors : x
dynamic_fallback : True

# - op : matmul
# extra_args : str mkldnn_data_type="float32"
# layout_transform :
# arg_name: cur_paddle_data_layout
# tensors: x, y

# - op : pad3d
# extra_args :
# layout_transform :
# arg_name: data_format
# tensors: x
# dynamic_fallback : True

# - op : batch_norm
# extra_args : bool fuse_with_relu=false
# layout_transform :
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/core/kernel_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,8 @@ class Kernel {
}

GetKernelTypeForVarFn get_kerneltype_forvar_fn_{nullptr};
std::function<bool(const KernelContext* ctx)> check_if_onednn_kernel_support_{
nullptr};

private:
KernelFn fn_{nullptr};
Expand Down
3 changes: 0 additions & 3 deletions paddle/phi/kernels/cpu/onednn_to_paddle_layout_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,6 @@ void OneDNN2PaddleLayout(const Context& dev_ctx,
}

DataLayout tmp_layout = static_cast<DataLayout>(dst_layout);
if (static_cast<DataLayout>(dst_layout) == DataLayout::ANY) {
tmp_layout = phi::OneDNNContext::tls().get_cur_paddle_data_layout();
}

if (tmp_layout == DataLayout::ANY) {
tmp_layout = phi::OneDNNContext::tls().get_cur_paddle_data_layout();
Expand Down
10 changes: 10 additions & 0 deletions paddle/phi/kernels/onednn/pad3d_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ KernelKey Pad3dGetKernelTypeForVar(const GetKernelTypeForVarContext* ctx) {
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}

// Returns true when the oneDNN pad3d kernel can handle this kernel context.
// Only "constant" padding mode and non-blocked input memory layouts are
// supported by the oneDNN implementation; anything else must fall back to
// the plain CPU kernel.
bool Pad3dCheckIfOneDNNSupport(const KernelContext* ctx) {
  // AttrAt(1) is compared against "constant" (presumably the pad3d `mode`
  // attribute — confirm against the op's attribute order); inner_nblks == 0
  // means the input uses a plain, non-blocked oneDNN memory descriptor.
  return ctx->AttrAt<std::string>(1) == "constant" &&
         ctx->InputAt<phi::DenseTensor>(0).mem_desc().get_inner_nblks() == 0;
}

template <typename T, typename Context>
void Pad3dKernel(const Context& dev_ctx,
const DenseTensor& x,
Expand All @@ -58,4 +67,5 @@ PD_REGISTER_KERNEL(pad3d,
phi::dtype::bfloat16,
float) {
kernel->get_kerneltype_forvar_fn_ = phi::Pad3dGetKernelTypeForVar;
kernel->check_if_onednn_kernel_support_ = phi::Pad3dCheckIfOneDNNSupport;
}
20 changes: 20 additions & 0 deletions test/ir/inference/test_mkldnn_pad3d_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from auto_scan_test import MkldnnAutoScanTest
from hypothesis import given
from program_config import OpConfig, ProgramConfig, TensorConfig
from utils import compare_legacy_with_pt


class TestOneDNNPad3DOp(MkldnnAutoScanTest):
Expand Down Expand Up @@ -77,6 +78,25 @@ def sample_predictor_configs(self, program_config):
def test(self, *args, **kwargs):
    # Non-quantized run of the auto-scan harness with the sampled inputs.
    self.run_test(*args, quant=False, **kwargs)

# PIR-mode variant of the test: hypothesis sweeps layouts, shapes and
# padding configurations, and compare_legacy_with_pt checks the legacy
# executor against the PIR executor for each sampled case.
# NOTE(review): indentation appears stripped by the diff scrape — these
# lines should sit at method level inside TestOneDNNPad3DOp; confirm
# against the real file.
@given(
data_format=st.sampled_from(['NCDHW', 'NDHWC']),
use_paddings_tensor=st.sampled_from([True, False]),
in_shape=st.sampled_from(
[[2, 3, 4, 5, 6], [1, 4, 1, 3, 2], [4, 3, 2, 1, 1], [1, 1, 1, 1, 1]]
),
paddings=st.sampled_from(
[
[0, 0, 0, 0, 0, 0],
[1, 2, 0, 1, 2, 1],
[2, 5, 11, 3, 4, 3],
[0, 5, 0, 1, 0, 2],
]
),
)
@compare_legacy_with_pt
def test_pir(self, *args, **kwargs):
# quant=False: this sweep exercises the fp32 path only.
self.run_test(quant=False, *args, **kwargs)


if __name__ == "__main__":
unittest.main()