diff --git a/paddle/fluid/framework/new_executor/instruction/onednn/onednn_instruction.cc b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_instruction.cc index 8803ddfe56defb..fab561df7a6e99 100644 --- a/paddle/fluid/framework/new_executor/instruction/onednn/onednn_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_instruction.cc @@ -361,6 +361,26 @@ OneDNNPhiKernelInstruction::OneDNNPhiKernelInstruction( phi::MetaConfig new_config = infer_meta_context_.GetMetaConfig(); new_config.is_run_mkldnn_kernel = true; infer_meta_context_.SetMetaConfig(new_config); + + // Step5: Handle skip_transform_inputs + if (op_attributes.count("skip_transform_inputs")) { + std::vector skip_transform_inputs = + op->attributes() + .at("skip_transform_inputs") + .dyn_cast() + .AsVector(); + + for (auto& input : skip_transform_inputs) { + auto input_name = input.dyn_cast().AsString(); + auto pair = kernel_context_.InputRangeAt( + yaml_info_parser.InputName2Id().at(input_name)); + VLOG(6) << "skip_transform_input = " << input_name; + for (int i = pair.first; i < pair.second; ++i) { + skip_format_tensors_.insert(i); + VLOG(6) << input_name << " index = " << i; + } + } + } } OneDNNPhiKernelInstruction::~OneDNNPhiKernelInstruction() { @@ -381,6 +401,9 @@ void OneDNNPhiKernelInstruction::Run() { if (!input->initialized()) { continue; } + if (skip_format_tensors_.count(i)) { + continue; + } VLOG(6) << "input[" << i << "].layout() = " << input->layout(); if (input->layout() != phi::DataLayout::ONEDNN) { phi::DataLayout from_layout = input->layout(); diff --git a/paddle/fluid/framework/new_executor/instruction/onednn/onednn_instruction.h b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_instruction.h index 9cca848549f2b3..cae045044ed3c5 100644 --- a/paddle/fluid/framework/new_executor/instruction/onednn/onednn_instruction.h +++ b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_instruction.h @@ -69,6 +69,7 @@ class OneDNNPhiKernelInstruction : public InstructionBase { const ValueExecutionInfo* value_exec_info_; // not owned std::set data_format_tensors_{}; + std::set skip_format_tensors_{}; phi::DataLayout input_layout_{phi::DataLayout::kAnyLayout}; std::map extra_attr_{}; std::map ctx_attr_{}; diff --git a/paddle/fluid/framework/new_executor/instruction/onednn/onednn_legacy_instruction.cc b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_legacy_instruction.cc index 446c0c364c4419..0d14d59bcd35bd 100644 --- a/paddle/fluid/framework/new_executor/instruction/onednn/onednn_legacy_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_legacy_instruction.cc @@ -217,7 +217,7 @@ OneDNNLegacyKernelInstruction::OneDNNLegacyKernelInstruction( .at("data_format_tensors") .dyn_cast() .AsVector(); - std::vector layout_transform_inputs; + auto& op_normalizer = paddle::translator::OpNameNormalizer::instance(); std::string fluid_op_name = yaml_info_parser.GetOriginOpName(); for (auto& attr : data_format_tensors_attr) { @@ -231,6 +231,24 @@ OneDNNLegacyKernelInstruction::OneDNNLegacyKernelInstruction( phi::MetaConfig new_config = infer_meta_context_.GetMetaConfig(); new_config.is_run_mkldnn_kernel = true; infer_meta_context_.SetMetaConfig(new_config); + + // Step4: Handle skip_transform_inputs + if (op_attributes.count("skip_transform_inputs")) { + std::vector skip_transform_inputs = + op->attributes() + .at("skip_transform_inputs") + .dyn_cast() + .AsVector(); + + auto& op_normalizer = paddle::translator::OpNameNormalizer::instance(); + std::string fluid_op_name = yaml_info_parser.GetOriginOpName(); + + for (auto& input : skip_transform_inputs) { + auto input_name = input.dyn_cast().AsString(); + skip_format_tensors_.insert( + op_normalizer.GetLegacyArgName(fluid_op_name, input_name)); + } + } } OneDNNLegacyKernelInstruction::~OneDNNLegacyKernelInstruction() { @@ -247,6 +265,9 @@ void OneDNNLegacyKernelInstruction::Run() { // Step1. TransLayout auto inputs = kernel_context_->InNameList(); for (auto& input_name : inputs) { + if (skip_format_tensors_.count(*input_name)) { + continue; + } auto input_vars = kernel_context_->MultiInputVar(*input_name); for (auto& var : input_vars) { if (var->IsType()) { diff --git a/paddle/fluid/framework/new_executor/instruction/onednn/onednn_legacy_instruction.h b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_legacy_instruction.h index 86192e737750d3..ccd255598b8465 100644 --- a/paddle/fluid/framework/new_executor/instruction/onednn/onednn_legacy_instruction.h +++ b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_legacy_instruction.h @@ -68,6 +68,7 @@ class OneDNNLegacyKernelInstruction : public InstructionBase { const ValueExecutionInfo* value_exec_info_; // not owned std::set data_format_tensors_{}; + std::set skip_format_tensors_{}; phi::DataLayout input_layout_{phi::DataLayout::kAnyLayout}; }; diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index 92583472e10025..e5696c36b3a2a2 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -273,7 +273,7 @@ class {TEST_API} {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{ std::vector outputs = {{ {outputs} }}; pir::AttributeMap extra_attr_default_value; {extra_attr_default_value_code} - paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, "{kernel_func}", {{"{kernel_param}"}}, {{{kernel_key_dtype}}}, {{{kernel_key_backend}}}, {{{inplace}}}, {{{view}}}, {{{extra_args}}}, extra_attr_default_value, {{{data_format_tensors}}}, {is_onednn_only}, {dynamic_fallback}); + paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, "{kernel_func}", {{"{kernel_param}"}}, {{{kernel_key_dtype}}}, {{{kernel_key_backend}}}, {{{inplace}}}, {{{view}}}, {{{extra_args}}}, {{{skip_transform_inputs}}}, extra_attr_default_value, {{{data_format_tensors}}}, {is_onednn_only}, {dynamic_fallback}); return std::make_tuple(inputs, attributes, outputs, run_time_info, "{origin_op_name}"); }} """ @@ -1788,6 +1788,22 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): ) else: extra_attr_default_value_code_str = "" + skip_transform_inputs = "" + if op_info.data_transform_map is not None: + if "skip_transform" in op_info.data_transform_map: + skip_transform = op_info.data_transform_map[ + "skip_transform" + ] + if skip_transform is not None: + skip_transform_input_names = [] + for input in skip_transform: + skip_transform_input_names.append(input) + + skip_transform_inputs = ( + '"' + + '", "'.join(skip_transform_input_names) + + '"' + ) op_info_func_str = OP_INFO_ONEDNN_TEMPLATE.format( op_name=op_class_name, @@ -1805,6 +1821,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): view=view_str, origin_op_name=op_info.op_yaml_item['name'], extra_args=extra_args, + skip_transform_inputs=skip_transform_inputs, data_format_tensors=data_format_tensors, is_onednn_only="true" if op_info.is_onednn_only diff --git a/paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml b/paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml index d9d8fe69990249..10e521e9023aa1 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml @@ -224,8 +224,8 @@ # - op : sgd_dense_param_sparse_grad -# - op : shape -# extra_args : str mkldnn_data_type="float32" +- op : shape + extra_args : str mkldnn_data_type="float32" - op : shuffle_channel diff --git a/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h b/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h index aacf9a69861528..86370dd0cc6c19 100644 --- a/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h +++ b/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h @@ -95,6 +95,7 @@ struct OpRunTimeInfo { std::vector> inplace; std::vector> view; std::vector extra_args; + std::vector skip_transform_inputs; pir::AttributeMap extra_args_default_value; std::vector data_format_tensors; bool is_onednn_only; @@ -109,6 +110,7 @@ struct OpRunTimeInfo { const std::vector>& inplace, const std::vector>& view, const std::vector& extra_args = {}, + const std::vector& skip_transform_inputs = {}, const pir::AttributeMap& extra_args_default_value = {{}}, const std::vector& data_format_tensors = {}, bool is_onednn_only = false, @@ -122,6 +124,7 @@ struct OpRunTimeInfo { inplace(inplace), view(view), extra_args(extra_args), + skip_transform_inputs(skip_transform_inputs), extra_args_default_value(extra_args_default_value), data_format_tensors(data_format_tensors), is_onednn_only(is_onednn_only), diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc index 688c461a72091d..9c2c21f8c9d218 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc @@ -2327,6 +2327,13 @@ pir::Operation* BuildKernelOp( op_attribute.emplace( "extra_args", pir::ArrayAttribute::get(pir::IrContext::Instance(), extra_args)); + std::vector skip_transform_inputs; + for (auto& arg : op_info_parser->OpRuntimeInfo().skip_transform_inputs) { + skip_transform_inputs.push_back(pir::StrAttribute::get(ctx, arg)); + } + op_attribute.emplace("skip_transform_inputs", + pir::ArrayAttribute::get(pir::IrContext::Instance(), + skip_transform_inputs)); std::vector data_format_tensors; for (auto& input : op_info_parser->OpRuntimeInfo().data_format_tensors) { data_format_tensors.push_back(pir::StrAttribute::get(ctx, input)); diff --git a/test/mkldnn/test_shape_mkldnn_op.py b/test/mkldnn/test_shape_mkldnn_op.py index ed9b81f8e35995..4ae0e02b98f99e 100644 --- a/test/mkldnn/test_shape_mkldnn_op.py +++ b/test/mkldnn/test_shape_mkldnn_op.py @@ -35,7 +35,7 @@ def config(self): self.dtype = np.float32 def test_check_output(self): - self.check_output_with_place(core.CPUPlace()) + self.check_output_with_place(core.CPUPlace(), check_pir_onednn=True) class TestShape0DFP32OneDNNOp(TestShape3DFP32OneDNNOp):