104 changes: 103 additions & 1 deletion paddle/fluid/pir/dialect/op_generator/op_gen.py
@@ -260,7 +260,9 @@ class {TEST_API} {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{
std::vector<paddle::dialect::OpInputInfo> inputs = {{ {inputs} }};
std::vector<paddle::dialect::OpAttributeInfo> attributes = {{ {attributes} }};
std::vector<paddle::dialect::OpOutputInfo> outputs = {{ {outputs} }};
-paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, "{kernel_func}", {{"{kernel_param}"}}, {{{kernel_key_dtype}}}, {{{kernel_key_backend}}}, {{{inplace}}}, {{{view}}}, {{{extra_args}}}, {{{data_format_tensors}}}, {is_onednn_only}, {dynamic_fallback});
+pir::AttributeMap extra_attr_default_value;
+{extra_attr_default_value_code}
+paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, "{kernel_func}", {{"{kernel_param}"}}, {{{kernel_key_dtype}}}, {{{kernel_key_backend}}}, {{{inplace}}}, {{{view}}}, {{{extra_args}}}, extra_attr_default_value, {{{data_format_tensors}}}, {is_onednn_only}, {dynamic_fallback});
return std::make_tuple(inputs, attributes, outputs, run_time_info, "{origin_op_name}");
}}
"""
@@ -1126,6 +1128,94 @@ def get_mutable_attribute_grad_semantic(op_info, op_info_items):
    return mutable_attribute_grad_semantics


def GenOneDnnExtraAttrsDefaultValue(onednn_extra_args):
    INTARRAY_STR_TEMPLATE = """ pir::Attribute attr_{attr_name} = {op_attribute_type}::get(pir::IrContext::Instance(), phi::IntArray({attr}));
"""
    SCALAR_STR_TEMPLATE = """ pir::Attribute attr_{attr_name} = paddle::dialect::TransToIrAttribute({attr}, pir::IrContext::Instance());
"""
    STR_TEMPLATE = """ pir::Attribute attr_{attr_name} = {op_attribute_type}::get(pir::IrContext::Instance(), {attr});
"""
    ARRAY_ATTRIBUTE_TEMPLATE = """ std::vector<pir::Attribute> vec_{attr_name};
std::vector<{cpp_type}> vec_values = {attr_values};
for (size_t i = 0; i < static_cast<size_t>(vec_values.size()); i++) {{
{create_attribute}
vec_{attr_name}.push_back(attr_{attr_name});
}}
pir::Attribute attr_{attr_name} = pir::ArrayAttribute::get(pir::IrContext::Instance(), vec_{attr_name});
"""
    attr_str = ""
    array_attr_type = "pir::ArrayAttribute<"
    for idx in range(len(onednn_extra_args)):
        assert (
            onednn_extra_args[idx]['typename'] in attr_types_map
        ), f"{onednn_extra_args[idx]['typename']} : Attr type error."
        extra_arg_type = attr_types_map[onednn_extra_args[idx]['typename']][0]

        if array_attr_type in extra_arg_type:
            inner_attribute_type = extra_arg_type[len(array_attr_type) : -1]
            if inner_attribute_type == "paddle::dialect::IntArrayAttribute":
                attr_str += ARRAY_ATTRIBUTE_TEMPLATE.format(
                    attr_name=onednn_extra_args[idx]['name'],
                    cpp_type=onednn_extra_args[idx]['typename'].replace(
                        '[]', ''
                    ),
                    attr_values=onednn_extra_args[idx]['default_value'],
                    create_attribute=INTARRAY_STR_TEMPLATE.format(
                        attr_name=onednn_extra_args[idx]['name'],
                        op_attribute_type=inner_attribute_type,
                        attr="vec_values[i]",
                    ),
                )
            elif inner_attribute_type == "paddle::dialect::ScalarAttribute":
                attr_str += ARRAY_ATTRIBUTE_TEMPLATE.format(
                    attr_name=onednn_extra_args[idx]['name'],
                    cpp_type=onednn_extra_args[idx]['typename'].replace(
                        '[]', ''
                    ),
                    attr_values=onednn_extra_args[idx]['default_value'],
                    create_attribute=SCALAR_STR_TEMPLATE.format(
                        attr_name=onednn_extra_args[idx]['name'],
                        attr="vec_values[i]",
                    ),
                )
            else:
                attr_str += ARRAY_ATTRIBUTE_TEMPLATE.format(
                    attr_name=onednn_extra_args[idx]['name'],
                    cpp_type=onednn_extra_args[idx]['typename'].replace(
                        '[]', ''
                    ),
                    attr_values=onednn_extra_args[idx]['default_value'],
                    create_attribute=STR_TEMPLATE.format(
                        attr_name=onednn_extra_args[idx]['name'],
                        op_attribute_type=inner_attribute_type,
                        attr="vec_values[i]",
                    ),
                )
        elif extra_arg_type == "paddle::dialect::IntArrayAttribute":
            attr_str += INTARRAY_STR_TEMPLATE.format(
                attr_name=onednn_extra_args[idx]['name'],
                op_attribute_type=extra_arg_type,
                attr=onednn_extra_args[idx]['name'],
            )
        elif extra_arg_type == "paddle::dialect::ScalarAttribute":
            attr_str += SCALAR_STR_TEMPLATE.format(
                attr_name=onednn_extra_args[idx]['name'],
                attr=onednn_extra_args[idx]['name'],
            )
        else:
            attr_str += STR_TEMPLATE.format(
                attr_name=onednn_extra_args[idx]['name'],
                op_attribute_type=extra_arg_type,
                attr=onednn_extra_args[idx]['default_value'],
            )

        attr_str += """extra_attr_default_value["{attr_name}"] = attr_{attr_name};\n""".format(
            attr_name=onednn_extra_args[idx]['name']
        )

    return attr_str
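For context, here is a minimal sketch of the kind of C++ this helper is meant to emit, assuming an op whose ops_onednn_extra.yaml entry declares `extra_args : bool fuse_with_relu=false` and assuming `attr_types_map` maps `bool` to `pir::BoolAttribute`; the wrapper function name is invented for illustration (the generator inlines this directly into `GetOpInfo()`):

```cpp
#include "paddle/pir/core/builtin_attribute.h"
#include "paddle/pir/core/ir_context.h"
#include "paddle/pir/core/operation_utils.h"

// Hypothetical standalone version of the generated snippet: build the
// default-value map that GetOpInfo() hands to OpRunTimeInfo.
pir::AttributeMap BuildExtraAttrDefaults() {
  pir::AttributeMap extra_attr_default_value;
  // Emitted by STR_TEMPLATE for `bool fuse_with_relu=false`.
  pir::Attribute attr_fuse_with_relu =
      pir::BoolAttribute::get(pir::IrContext::Instance(), false);
  extra_attr_default_value["fuse_with_relu"] = attr_fuse_with_relu;
  return extra_attr_default_value;
}
```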


def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):
    # (3) CodeGen: Traverse op_info_items and generate
    ops_name_list = []  # all op class names are stored in this list
@@ -1676,12 +1766,24 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):
                data_format_tensors = (
                    '"' + '", "'.join(data_format_tensors) + '"'
                )
                if (
                    op_info.onednn_extra_args is not None
                    and len(op_info.onednn_extra_args) > 0
                ):
                    extra_attr_default_value_code_str = (
                        GenOneDnnExtraAttrsDefaultValue(
                            op_info.onednn_extra_args
                        )
                    )
                else:
                    extra_attr_default_value_code_str = ""

                op_info_func_str = OP_INFO_ONEDNN_TEMPLATE.format(
                    op_name=op_class_name,
                    inputs=inputs_info_str,
                    attributes=attribute_info_str,
                    outputs=outputs_info_str,
                    extra_attr_default_value_code=extra_attr_default_value_code_str,
                    infer_meta_func=infer_meta_func_str,
                    infer_meta_param=infer_meta_param_str,
                    kernel_func=kernel_func_str,
3 changes: 2 additions & 1 deletion paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml
@@ -25,7 +25,8 @@
  extra_args : bool fuse_with_relu=false
  data_format_tensors : x, out_grad

-# - op : bilinear_interp
+- op : bilinear_interp
+  data_format_tensors : x

# - op : cast

4 changes: 4 additions & 0 deletions paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h
@@ -17,6 +17,7 @@
#include "paddle/fluid/pir/dialect/operator/ir/type_storage.h"
#include "paddle/pir/core/builtin_attribute.h"
#include "paddle/pir/core/builtin_type.h"
#include "paddle/pir/core/operation_utils.h"

namespace paddle {
namespace dialect {
@@ -94,6 +95,7 @@ struct OpRunTimeInfo {
  std::vector<std::pair<std::string, std::string>> inplace;
  std::vector<std::pair<std::string, std::string>> view;
  std::vector<std::string> extra_args;
  pir::AttributeMap extra_args_default_value;
  std::vector<std::string> data_format_tensors;
  bool is_onednn_only;
  bool dynamic_fallback;
@@ -107,6 +109,7 @@ struct OpRunTimeInfo {
      const std::vector<std::pair<std::string, std::string>>& inplace,
      const std::vector<std::pair<std::string, std::string>>& view,
      const std::vector<std::string>& extra_args = {},
      const pir::AttributeMap& extra_args_default_value = {},
      const std::vector<std::string>& data_format_tensors = {},
      bool is_onednn_only = false,
      bool dynamic_fallback = false)
@@ -119,6 +122,7 @@
        inplace(inplace),
        view(view),
        extra_args(extra_args),
        extra_args_default_value(extra_args_default_value),
        data_format_tensors(data_format_tensors),
        is_onednn_only(is_onednn_only),
        dynamic_fallback(dynamic_fallback) {}
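Below is a minimal sketch of a call to the widened constructor, mirroring the argument order the generated GetOpInfo() uses (see op_gen.py above); the infer-meta and kernel names are placeholders, not taken from any real op:

```cpp
#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h"

// Illustrative only: construct an OpRunTimeInfo whose OneDNN extra args
// carry generator-produced defaults.
paddle::dialect::OpRunTimeInfo MakeExampleRunTimeInfo(
    const pir::AttributeMap& extra_attr_default_value) {
  return paddle::dialect::OpRunTimeInfo(
      /*infer_meta_func=*/"UnchangedInferMeta",
      /*infer_meta_param=*/{"x"},
      /*kernel_func=*/"relu",
      /*kernel_param=*/{"x"},
      /*kernel_key_dtype=*/{},
      /*kernel_key_backend=*/{},
      /*inplace=*/{},
      /*view=*/{},
      /*extra_args=*/{"fuse_with_relu"},
      extra_attr_default_value,  // the field this PR introduces
      /*data_format_tensors=*/{"x"},
      /*is_onednn_only=*/false,
      /*dynamic_fallback=*/false);
}
```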
85 changes: 83 additions & 2 deletions paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
@@ -50,6 +50,7 @@
#include "paddle/pir/dialect/control_flow/ir/cf_op.h"

#ifdef PADDLE_WITH_DNNL
#include "build/paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h"
#include "paddle/fluid/pir/dialect/operator/trait/onednn.h"
@@ -855,6 +856,45 @@ bool SupportsMKLDNN(const std::string& kernel_name,
    }
  }
}

bool SupportsCPUBF16(const std::string& kernel_name) {
  auto phi_kernels =
      phi::KernelFactory::Instance().SelectKernelMap(kernel_name);
  auto has_phi_kernel =
      std::any_of(phi_kernels.begin(),
                  phi_kernels.end(),
                  [](phi::KernelKeyMap::const_reference kern_pair) {
                    return kern_pair.first.backend() == phi::Backend::CPU &&
                           kern_pair.first.dtype() == phi::DataType::BFLOAT16;
                  });
  if (has_phi_kernel) {
    return true;
  } else {
    auto op_kernel_iter =
        paddle::framework::OperatorWithKernel::AllOpKernels().find(
            phi::TransToFluidOpName(kernel_name));
    if (op_kernel_iter ==
        paddle::framework::OperatorWithKernel::AllOpKernels().end()) {
      return false;
    } else {
      auto& op_kernels = op_kernel_iter->second;
      return std::any_of(
          op_kernels.begin(),
          op_kernels.end(),
          [](std::unordered_map<
              paddle::framework::OpKernelType,
              std::function<void(const paddle::framework::ExecutionContext&)>,
              paddle::framework::OpKernelType::Hash>::const_reference
                 kern_pair) {
            return platform::is_cpu_place(kern_pair.first.place_) &&
                   kern_pair.first.place_ == platform::CPUPlace() &&
                   kern_pair.first.data_type_ ==
                       paddle::framework::proto::VarType::Type::
                           VarType_Type_BF16;
          });
    }
  }
}
#endif

phi::KernelKey GetKernelKey(
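Read together with the existing SupportsMKLDNN check, the new bf16 fallback condition used in ProcessBlock below reduces to a predicate along these lines (a sketch; no such helper exists in the file):

```cpp
// Sketch: rewrite to OneDNN only when there is no native CPU bf16 kernel
// registration (neither phi nor fluid) but a OneDNN bf16 kernel exists.
bool ShouldFallBackToOneDNN(const std::string& kernel_name) {
  return !SupportsCPUBF16(kernel_name) &&
         SupportsMKLDNN(kernel_name, phi::DataType::BFLOAT16);
}
```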
@@ -1795,8 +1835,8 @@ std::vector<pir::Type> BuildOutputs(
  std::vector<pir::Type> op_output_types;
  pir::AttributeMap attribute_map = op_item->attributes();

-  auto phi_kernel = phi::KernelFactory::Instance().SelectKernelWithGPUDNN(
-      kernel_fn_str, kernel_key);
+  auto phi_kernel =
+      phi::KernelFactory::Instance().SelectKernel(kernel_fn_str, kernel_key);
  VLOG(6) << "[" << kernel_fn_str
          << "] selected kernel(is_valid: " << phi_kernel.IsValid()
          << " ): " << kernel_key;
@@ -2437,6 +2477,47 @@ void ProcessBlock(
      op_item = op_item_inner;
      op_info_parser = GetOpYamlInfoParser(op_item_inner);
    }

    // Use OneDNN if the CPU does not support bf16
    if (kernel_key.dtype() == phi::DataType::BFLOAT16 &&
        kernel_key.backend() == phi::Backend::CPU &&
        !op_item->HasTrait<OneDNNTrait>() && !SupportsCPUBF16(kernel_name) &&
        SupportsMKLDNN(kernel_name, phi::DataType::BFLOAT16)) {
      std::string target_op_name = op_item->name();
      target_op_name.replace(0, 5, "onednn_op");
      auto op_info = ctx->GetRegisteredOpInfo(target_op_name);
      if (op_info) {
        std::vector<pir::Type> op_item_inner_output_types;
        if (op_item->num_results() > 0) {
          for (size_t i = 0; i < op_item->num_results(); ++i) {
            op_item_inner_output_types.push_back(op_item->result_type(i));
          }
        }
        auto attributes = op_item->attributes();
        auto yaml_interface =
            op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>();
        OpRunTimeInfo runtime_info =
            std::get<3>(yaml_interface->get_op_info_(target_op_name));
        for (auto& attr : runtime_info.extra_args_default_value) {
          attributes[attr.first] = attr.second;
        }
        pir::Operation* op_item_inner =
            pir::Operation::Create(op_item->operands_source(),
                                   attributes,
                                   op_item_inner_output_types,
                                   op_info);
        op_item->ReplaceAllUsesWith(op_item_inner->results());
        for (auto iter = block->begin(); iter != block->end(); ++iter) {
          if (*iter == *op_item) {
            block->Assign(iter, op_item_inner);
            break;
          }
        }
        op_item = op_item_inner;
        op_info_parser = GetOpYamlInfoParser(op_item_inner);
        kernel_key.set_backend(phi::Backend::ONEDNN);
      }
    }
#endif
    // build input
    auto new_vec_inputs = BuildInputs(op_item,
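The net effect for an op such as bilinear_interp (enabled in the YAML change above) is that a CPU bf16 pd_op.bilinear_interp is recreated as onednn_op.bilinear_interp with its kernel key retargeted to the ONEDNN backend. The attribute merge in the middle of that block behaves like the following sketch (a hypothetical free function; the pass itself writes into its local attributes map directly):

```cpp
#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h"

// Sketch: overlay the OneDNN extra-arg defaults from the op's runtime info
// onto the attributes copied from the original (non-OneDNN) op.
pir::AttributeMap MergeExtraDefaults(
    pir::AttributeMap attributes,
    const paddle::dialect::OpRunTimeInfo& runtime_info) {
  for (const auto& attr : runtime_info.extra_args_default_value) {
    attributes[attr.first] = attr.second;  // insert or overwrite
  }
  return attributes;
}
```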
2 changes: 1 addition & 1 deletion test/mkldnn/test_bilinear_interp_v2_mkldnn_op.py
@@ -150,7 +150,7 @@ def setUp(self):
        self.outputs = {'Out': output_np}

    def test_check_output(self):
-        self.check_output(check_dygraph=False)
+        self.check_output(check_dygraph=False, check_pir_onednn=True)


class TestBilinearInterpOpOneDNNNHWC(TestBilinearInterpOneDNNOp):