@@ -309,11 +309,33 @@ OneDNNPhiKernelInstruction::OneDNNPhiKernelInstruction(
.at("extra_args")
.dyn_cast<pir::ArrayAttribute>()
.AsVector();
std::vector<std::string> extra_args;
auto& op_normalizer = paddle::translator::OpNameNormalizer::instance();
std::string fluid_op_name = yaml_info_parser.GetOriginOpName();

for (auto& attr : extra_args_attr) {
auto attr_name = attr.dyn_cast<pir::StrAttribute>().AsString();
extra_attr_[attr_name] = ConvertPirAttribute2RuntimeAttribute(
op_attributes.at(attr_name), attr_name, yaml_info_parser);
auto legacy_attr_name =
op_normalizer.GetLegacyAttrName(fluid_op_name, attr_name);
if (legacy_attr_name != attr_name) {
extra_attr_[legacy_attr_name] = extra_attr_[attr_name];
Contributor:

Do extra_attr_ and ctx_attr_ really have to use legacy_attr_name? When the legacy operators are retired later, won't this code have to be adapted again?

Contributor Author:

The attrs that the underlying OneDNN kernels read through the Ctx behave as follows:

  1. They consume both the extra attrs and, via the ctx, the attrs from the PHI kernel signature.
  2. They can be invoked both through the Operators path and through the PHI path, so when a kernel fetches an attr it uses the Operators attribute naming convention (first letter capitalized). That is why legacy_attr_name also has to be handed to the Ctx; a simplified sketch of this dual registration follows after this hunk.

}
}
auto attr_name_list = yaml_info_parser.AttrParams(true);
for (auto& attr : attr_name_list) {
auto attr_name = attr;
if (!op_attributes.count(attr_name)) {
// In PIR, an IntArray attr becomes an input rather than an attribute.
continue;
}
ctx_attr_[attr_name] = ConvertPirAttribute2RuntimeAttribute(
op_attributes.at(attr_name), attr_name, yaml_info_parser);
auto legacy_attr_name =
op_normalizer.GetLegacyAttrName(fluid_op_name, attr_name);
if (legacy_attr_name != attr_name) {
ctx_attr_[legacy_attr_name] = ctx_attr_[attr_name];
}
}
}
TensorNameMap(op, *value_exec_info_, yaml_info_parser, inputs_, outputs_);
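To make the review answer above concrete, here is a minimal standalone sketch of the dual registration the constructor performs. It is not the actual Paddle API: a plain std::map stands in for the OneDNN device context's SetDnnAttr storage, and the scale_in / Scale_in name pair is taken from the op_compat.yaml mapping added later in this PR. The point is only that the same value ends up reachable under both the PIR snake_case name and the legacy Operators-style name.

```cpp
#include <iostream>
#include <map>
#include <string>

int main() {
  // PIR / PHI-side attribute name and the legacy Operators-style name
  // (mapping taken from op_compat.yaml: scale_in -> Scale_in).
  const std::string attr_name = "scale_in";
  const std::string legacy_attr_name = "Scale_in";

  // Stand-in for the OneDNN device context attr storage (SetDnnAttr).
  std::map<std::string, float> dnn_attrs;
  dnn_attrs[attr_name] = 0.5f;
  if (legacy_attr_name != attr_name) {
    // Mirror the value under the legacy name so kernels written against
    // the Operators naming convention still find it.
    dnn_attrs[legacy_attr_name] = dnn_attrs[attr_name];
  }

  // A lookup through either naming convention succeeds.
  std::cout << dnn_attrs.at("scale_in") << " "
            << dnn_attrs.at("Scale_in") << "\n";
  return 0;
}
```

The duplicate entry is only created when the two names actually differ, so ops whose attr names already match the PHI naming are unaffected.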
@@ -331,6 +353,9 @@ void OneDNNPhiKernelInstruction::Run() {
size_t(0), kernel_context_.InputsSize());
for (size_t i = 0; i < inputs.size(); ++i) {
auto input = inputs[i];
if (input == nullptr) {
continue;
}
if (input->layout() != phi::DataLayout::ONEDNN) {
phi::DataLayout from_layout = input->layout();

@@ -370,6 +395,9 @@ void OneDNNPhiKernelInstruction::Run() {
for (auto& attr : extra_attr_) {
one_dnn_ctx->SetDnnAttr(attr.first, attr.second);
}
for (auto& attr : ctx_attr_) {
one_dnn_ctx->SetDnnAttr(attr.first, attr.second);
}
one_dnn_ctx->SetInputsName(inputs_);
one_dnn_ctx->SetOutputsName(outputs_);

@@ -71,6 +71,7 @@ class OneDNNPhiKernelInstruction : public InstructionBase {
std::set<int> data_format_tensors_{};
phi::DataLayout input_layout_{phi::DataLayout::kAnyLayout};
std::map<std::string, phi::Attribute> extra_attr_{};
std::map<std::string, phi::Attribute> ctx_attr_{};
std::map<std::string, std::vector<std::string>> inputs_{};
std::map<std::string, std::vector<std::string>> outputs_{};
};
9 changes: 9 additions & 0 deletions paddle/fluid/pir/dialect/op_generator/op_gen.py
@@ -1990,6 +1990,15 @@ def OpGenerator(
if first_file:
op["is_onednn_only"] = True
onednn_only_op_list.append("\"" + op['name'] + "\"")
if op['name'] in ops_onednn_extra_map:
onednn_item = ops_onednn_extra_map[op['name']]
op["is_onednn_only"] = onednn_item["is_onednn_only"]
op["extra_args"] = onednn_item["extra_args"]
op["data_format_tensors"] = onednn_item[
"data_format_tensors"
]
op["dynamic_fallback"] = onednn_item["dynamic_fallback"]
op["attrs"] = op["attrs"] + onednn_item["attrs"]
elif op['name'] in ops_onednn_extra_map:
onednn_item = ops_onednn_extra_map[op['name']]
op["is_onednn_only"] = onednn_item["is_onednn_only"]
10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/onednn.yaml
@@ -8,6 +8,16 @@
func : dequantize
data_type : input

- op : fused_conv2d
args : (Tensor input, Tensor filter, Tensor bias, Tensor residual_param, int[] strides={1, 1}, int[] paddings={0, 0}, str padding_algorithm="EXPLICIT", int[] dilations={1, 1}, int groups=1, str data_format="NCHW", str mkldnn_data_type="float32", str fuse_activation="", bool fuse_residual_connection=false, bool force_fp32_output=false)
output : Tensor(output)
infer_meta :
func : FusedConvInferMeta
kernel :
func : fused_conv2d
data_type : input
optional : bias, residual_param

- op : quantize
args : (Tensor input, bool is_negative_input=false, float scale=1.0, float shift=0.0, str output_format="NHWC", bool bfloat16=false)
output : Tensor(output)
4 changes: 4 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml
@@ -19,6 +19,10 @@
extra_args : bool is_test=false
data_format_tensors : input, out_grad

- op : fused_conv2d
extra_args : float fuse_alpha = 0.0, float fuse_beta = 0.0, float scale_in=1.0, float scale_out=1.0, float scale_in_eltwise=1.0, float[] scale_weights={1.0f}
data_format_tensors : input

- op : lrn
extra_args : bool is_test=false
data_format_tensors : x
6 changes: 6 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
@@ -1307,6 +1307,12 @@
reserve_space: ReserveSpace

- op : fused_conv2d
inputs :
{input : Input, filter : Filter, bias : Bias, residual_param : ResidualData}
outputs :
{output : Output}
attrs :
{scale_in : Scale_in, scale_out : Scale_out, scale_in_eltwise : Scale_in_eltwise, scale_weights : Scale_weights}
extra :
attrs : [bool use_cudnn = false, float fuse_alpha = 0.0f, float fuse_beta = 0.0f, float Scale_in = 1.0f,
float Scale_out = 1.0f, float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}', bool use_mkldnn = true, str mkldnn_data_type = "float32"]
3 changes: 3 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
@@ -567,6 +567,7 @@
func : ConvInferMeta
kernel :
func : conv2d
data_type : input
backward : conv2d_grad

- op : conv3d
@@ -576,6 +577,7 @@
func : Conv3DInferMeta
kernel :
func : conv3d
data_type : input
backward : conv3d_grad

- op : conv3d_transpose
@@ -713,6 +715,7 @@
func : DepthwiseConvInferMeta
kernel :
func : depthwise_conv2d
data_type : input
backward : depthwise_conv2d_grad

- op : det
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/multiary.h
@@ -827,7 +827,7 @@ void FusedConvInferMeta(const MetaTensor& input,
bool fuse_residual_conn,
bool force_fp32_output,
MetaTensor* out,
MetaConfig config);
MetaConfig config = MetaConfig());

void MoeInferMeta(const MetaTensor& x,
const MetaTensor& gate,
4 changes: 4 additions & 0 deletions paddle/phi/kernels/onednn/conv_function.h
@@ -60,6 +60,10 @@ static dnnl::memory::data_type GetDstType(
NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE( \
NAME, ::paddle::DataType::INT8, int8_t, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, \
::paddle::DataType::BFLOAT16, \
::phi::dtype::bfloat16, \
__VA_ARGS__) \
default: \
PD_THROW("function " #NAME " is not implemented for data type `", \
__dtype__, \
2 changes: 1 addition & 1 deletion test/legacy_test/CMakeLists.txt
@@ -1065,7 +1065,7 @@ set_tests_properties(test_sigmoid_cross_entropy_with_logits_op
PROPERTIES TIMEOUT 120)
set_tests_properties(test_imperative_optimizer_v2 PROPERTIES TIMEOUT 150)
set_tests_properties(test_partial_sum_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_cond PROPERTIES TIMEOUT 120)
set_tests_properties(test_cond PROPERTIES TIMEOUT 240)
set_tests_properties(test_sgd_op PROPERTIES TIMEOUT 250)
set_tests_properties(test_parallel_executor_seresnext_base_gpu
PROPERTIES TIMEOUT 120)
1 change: 1 addition & 0 deletions test/legacy_test/op_test.py
@@ -431,6 +431,7 @@ def setUpClass(cls):
cls.check_prim = False
cls.check_prim_pir = False
cls._check_cinn = False
cls.check_pir_onednn = False

np.random.seed(123)
random.seed(124)
16 changes: 14 additions & 2 deletions test/legacy_test/test_conv2d_op.py
@@ -499,7 +499,10 @@ def test_check_output(self):
place = core.CUDAPlace(0) if self.has_cuda() else core.CPUPlace()
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output_with_place(
place, atol=1e-5, check_dygraph=(not self.use_mkldnn)
place,
atol=1e-5,
check_dygraph=(not self.use_mkldnn),
check_pir_onednn=self.check_pir_onednn,
)

def test_check_grad(self):
@@ -515,6 +518,7 @@ def test_check_grad(self):
'Output',
max_relative_error=0.02,
check_dygraph=(not self.use_mkldnn),
check_pir_onednn=self.check_pir_onednn,
)

def test_check_grad_no_filter(self):
@@ -531,6 +535,7 @@ def test_check_grad_no_filter(self):
max_relative_error=0.02,
no_grad_set={'Filter'},
check_dygraph=(not self.use_mkldnn),
check_pir_onednn=self.check_pir_onednn,
)

def test_check_grad_no_input(self):
@@ -546,6 +551,7 @@ def test_check_grad_no_input(self):
'Output',
no_grad_set={'Input'},
check_dygraph=(not self.use_mkldnn),
check_pir_onednn=self.check_pir_onednn,
)

def init_test_case(self):
@@ -824,7 +830,10 @@ def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
place = core.CUDAPlace(0) if self.has_cuda() else core.CPUPlace()
self.check_output_with_place(
place, atol=1e-5, check_dygraph=(not self.use_mkldnn)
place,
atol=1e-5,
check_dygraph=(not self.use_mkldnn),
check_pir_onednn=self.check_pir_onednn,
)

def test_check_grad(self):
@@ -838,6 +847,7 @@ def test_check_grad(self):
'Output',
max_relative_error=0.02,
check_dygraph=(not self.use_mkldnn),
check_pir_onednn=self.check_pir_onednn,
)

def test_check_grad_no_filter(self):
@@ -852,6 +862,7 @@ def test_check_grad_no_filter(self):
max_relative_error=0.02,
no_grad_set={'Filter'},
check_dygraph=(not self.use_mkldnn),
check_pir_onednn=self.check_pir_onednn,
)

def test_check_grad_no_input(self):
@@ -865,6 +876,7 @@ def test_check_grad_no_input(self):
'Output',
no_grad_set={'Input'},
check_dygraph=(not self.use_mkldnn),
check_pir_onednn=self.check_pir_onednn,
)

def init_test_case(self):
1 change: 0 additions & 1 deletion test/legacy_test/test_elementwise_add_op.py
@@ -30,7 +30,6 @@
class TestElementwiseAddOp(OpTest):
def init_kernel_type(self):
self.use_mkldnn = False
self.check_pir_onednn = False

def setUp(self):
self.op_type = "elementwise_add"
8 changes: 7 additions & 1 deletion test/mkldnn/test_conv2d_bf16_mkldnn_op.py
@@ -50,6 +50,7 @@ def setUp(self):
self.init_data_type()
self.init_force_fp32_output()
self.init_infer_or_train()
self.check_pir_onednn = True

self.conv2d_param = {
'stride': self.stride,
@@ -117,7 +118,9 @@ def setUp(self):
self.init_additional_attrs()

def test_check_output(self):
self.check_output_with_place(core.CPUPlace())
self.check_output_with_place(
core.CPUPlace(), check_pir_onednn=self.check_pir_onednn
)

def test_check_grad(self):
pass
@@ -186,6 +189,7 @@ def test_check_grad(self):
"Output",
user_defined_grads=[dx, dweights],
user_defined_grad_outputs=[convert_float_to_uint16(dout)],
check_pir_onednn=self.check_pir_onednn,
)

def test_check_grad_no_filter(self):
Expand All @@ -202,6 +206,7 @@ def test_check_grad_no_filter(self):
{'Filter'},
user_defined_grads=[dx],
user_defined_grad_outputs=[convert_float_to_uint16(dout)],
check_pir_onednn=self.check_pir_onednn,
)

def test_check_grad_no_input(self):
Expand All @@ -218,6 +223,7 @@ def test_check_grad_no_input(self):
{'Input'},
user_defined_grads=[dweights],
user_defined_grad_outputs=[convert_float_to_uint16(dout)],
check_pir_onednn=self.check_pir_onednn,
)


6 changes: 5 additions & 1 deletion test/mkldnn/test_conv2d_int8_mkldnn_op.py
@@ -48,6 +48,7 @@ def setUp(self):
self.init_fuse_activation()
self.init_fuse_residual()
self.init_data_type()
self.check_pir_onednn = True

conv2d_param = {
'stride': self.stride,
@@ -184,7 +185,10 @@ def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
# the atol for integer tests should be 1
self.check_output_with_place(
core.CPUPlace(), atol=1, check_dygraph=False
core.CPUPlace(),
atol=1,
check_dygraph=False,
check_pir_onednn=self.check_pir_onednn,
)

def test_check_grad(self):
13 changes: 2 additions & 11 deletions test/mkldnn/test_conv2d_mkldnn_op.py
@@ -17,9 +17,6 @@
import numpy as np
from op_test import OpTest, skip_check_grad_ci
from test_conv2d_op import TestConv2DOp, TestConv2DOp_v2
from utils import compare_legacy_with_pt

from paddle.base import core


def conv2d_bias_naive(out, bias):
@@ -64,6 +61,7 @@ def setUp(self):
self.input_residual_size = None

TestConv2DOp.setUp(self)
self.check_pir_onednn = True

output = self.outputs['Output']

@@ -144,6 +142,7 @@ def setUp(self):
self.input_residual_size = None

TestConv2DOp.setUp(self)
self.check_pir_onednn = True

output = self.outputs['Output']

@@ -195,14 +194,6 @@ def setUp(self):

self.outputs['Output'] = output

@compare_legacy_with_pt
def test_check_output(self):
place = core.CUDAPlace(0) if self.has_cuda() else core.CPUPlace()
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output_with_place(
place, atol=1e-5, check_dygraph=(not self.use_mkldnn)
)


@skip_check_grad_ci(
reason="Fusion is for inference only, check_grad is not required."