Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions paddle/fluid/pir/dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,9 @@ set(op_dialect_srcs
${PADDLE_SOURCE_DIR}/paddle/fluid/pir/transforms/shape_optimization_pass.cc)

if(WITH_MKLDNN)
set(op_dialect_srcs ${op_dialect_srcs} ${onednn_op_source_file}
${op_onednn_info_file})
set(op_dialect_srcs
${op_dialect_srcs} ${onednn_op_source_file} ${op_onednn_info_file}
${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/manual_onednn_op.cc)
endif()

set(op_dialect_deps phi common pir type_info string_helper)
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@

{op_header}
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
#ifdef PADDLE_WITH_DNNL
#include "paddle/fluid/pir/dialect/operator/ir/manual_onednn_op.h"
#endif

namespace paddle {{
namespace drr {{
Expand Down
9 changes: 8 additions & 1 deletion paddle/fluid/pir/dialect/op_generator/op_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,11 @@ class {TEST_API} {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{
'assign_out_',
}

ONEDNN_MANUAL_OP_LIST = {
'split_grad',
'expand',
}

attr_types_map = {
'IntArray': ['paddle::dialect::IntArrayAttribute', 'IntArray'],
'Scalar': ['paddle::dialect::ScalarAttribute', 'Scalar'],
Expand Down Expand Up @@ -1345,7 +1350,9 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):
if len(op_traits) > 0:
op_traits_str = "," + ",".join(op_traits)

if op_name in PD_MANUAL_OP_LIST:
if dialect_name == "onednn_op" and op_name in ONEDNN_MANUAL_OP_LIST:
continue
elif dialect_name != "onednn_op" and op_name in PD_MANUAL_OP_LIST:
continue
if op_kernel_map is None:
func_list = [None]
Expand Down
346 changes: 346 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_onednn_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,346 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef GET_OP_LIST
#undef GET_OP_LIST
paddle::onednn::dialect::ExpandOp
#else

#include "paddle/fluid/pir/dialect/operator/ir/manual_onednn_op.h"
#include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h"
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.h"
#include "paddle/fluid/pir/dialect/operator/ir/ir_meta_tensor.h"
#include "paddle/fluid/pir/dialect/operator/ir/ir_selected_rows.h"
#include "paddle/fluid/pir/dialect/operator/ir/ir_tensor.h"
#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/primitive/rule/vjp/vjp.h"
#include "paddle/phi/api/lib/data_type_set.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/fusion.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/infermeta/nullary.h"
#include "paddle/phi/infermeta/ternary.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/pir/include/core/builtin_attribute.h"
#include "paddle/pir/include/core/builtin_op.h"
#include "paddle/pir/include/core/builtin_type.h"
#include "paddle/pir/include/core/ir_context.h"
#include "paddle/pir/include/core/op_base.h"

namespace paddle {
namespace onednn {
namespace dialect {

// Names of the non-mutable attributes declared on ExpandOp.
// NOTE(review): the array size (1) is expected to match the attributes_num
// declared in manual_onednn_op.h — confirm against the header.
const char* ExpandOp::attributes_name[1] = {"mkldnn_data_type"};

// Returns the static op description tuple for the oneDNN expand op:
// (inputs, attributes, outputs, runtime info, op name).
OpInfoTuple ExpandOp::GetOpInfo() {
  // Two inputs: the tensor `x` and the mutable int-array attribute `shape`,
  // which is passed as an operand rather than a stored attribute.
  // NOTE(review): the boolean flags follow OpInputInfo's positional order —
  // confirm their meaning against the OpInputInfo declaration.
  std::vector<paddle::dialect::OpInputInfo> inputs = {
      paddle::dialect::OpInputInfo(
          "x", "paddle::dialect::DenseTensorType", false, false, false, true),
      paddle::dialect::OpInputInfo("shape",
                                   "paddle::dialect::IntArrayAttribute",
                                   false,
                                   false,
                                   true,
                                   false)};
  // Single oneDNN-specific attribute selecting the kernel data type.
  std::vector<paddle::dialect::OpAttributeInfo> attributes = {
      paddle::dialect::OpAttributeInfo(
          "mkldnn_data_type", "pir::StrAttribute", "")};
  std::vector<paddle::dialect::OpOutputInfo> outputs = {
      paddle::dialect::OpOutputInfo(
          "out", "paddle::dialect::DenseTensorType", false, false)};
  // Default value used when "mkldnn_data_type" is not supplied explicitly.
  pir::AttributeMap extra_attr_default_value;
  pir::Attribute attr_mkldnn_data_type =
      pir::StrAttribute::get(pir::IrContext::Instance(), "float32");
  extra_attr_default_value["mkldnn_data_type"] = attr_mkldnn_data_type;

  // Binds the op to phi's ExpandInferMeta and the "expand" kernel;
  // "mkldnn_data_type" is forwarded as an extra (non-kernel-signature)
  // attribute.
  paddle::dialect::OpRunTimeInfo run_time_info =
      paddle::dialect::OpRunTimeInfo("ExpandInferMeta",
                                     {"x", "shape"},
                                     "expand",
                                     {"x", "shape"},
                                     {"x"},
                                     {},
                                     {},
                                     {},
                                     {"mkldnn_data_type"},
                                     {},
                                     extra_attr_default_value,
                                     {},
                                     false,
                                     false);
  return std::make_tuple(inputs, attributes, outputs, run_time_info, "expand");
}

// Build overload taking the expand shape as a compile-time-known vector.
// The shape is materialized as a FullIntArrayOp operand so the op always
// receives it through its second input.
void ExpandOp::Build(pir::Builder& builder,
                     pir::OperationArgument& argument,
                     pir::Value x_,
                     const std::vector<int64_t>& shape,
                     const std::string& mkldnn_data_type) {
  VLOG(4) << "Start build ExpandOp";

  // Generate int_array mutable attribute: shape
  auto shape_op = builder.Build<paddle::dialect::FullIntArrayOp>(
      shape, phi::DataType::INT64, phi::CPUPlace());
  pir::Value shape_value = shape_op->result(0);

  VLOG(4) << "Builder construction inputs";
  std::vector<pir::Value> op_inputs{x_, shape_value};
  argument.AddInputs(op_inputs);

  VLOG(4) << "Builder construction attributes";
  pir::AttributeMap op_attrs;
  pir::Attribute data_type_attr =
      pir::StrAttribute::get(pir::IrContext::Instance(), mkldnn_data_type);
  argument.AddAttribute("mkldnn_data_type", data_type_attr);
  op_attrs.insert({"mkldnn_data_type", data_type_attr});

  // Derive output types from the inputs/attributes and finish the argument.
  std::vector<pir::Type> output_types = ExpandOp::InferMeta(op_inputs, op_attrs);
  argument.AddOutputs(output_types.begin(), output_types.end());
  ::pir::PassStopGradientsDefaultly(argument);
}

// Build overload that pulls both "shape" and "mkldnn_data_type" out of an
// attribute map (e.g. when the op is reconstructed from stored attributes).
void ExpandOp::Build(pir::Builder& builder,
                     pir::OperationArgument& argument,
                     pir::Value x_,
                     pir::AttributeMap attributes) {
  VLOG(4) << "Start build ExpandOp";

  IR_ENFORCE(attributes.find("shape") != attributes.end(),
             "'shape' Attribute is expected for ExpandOp. ");
  const std::vector<int64_t> shape =
      attributes.at("shape")
          .dyn_cast<paddle::dialect::IntArrayAttribute>()
          .data()
          .GetData();

  IR_ENFORCE(attributes.find("mkldnn_data_type") != attributes.end(),
             "'mkldnn_data_type' Attribute is expected for ExpandOp. ");
  const std::string mkldnn_data_type = attributes.at("mkldnn_data_type")
                                           .dyn_cast<pir::StrAttribute>()
                                           .AsString();

  // Generate int_array mutable attribute: shape
  auto shape_op = builder.Build<paddle::dialect::FullIntArrayOp>(
      shape, phi::DataType::INT64, phi::CPUPlace());
  pir::Value shape_value = shape_op->result(0);

  VLOG(4) << "Builder construction inputs";
  std::vector<pir::Value> op_inputs{x_, shape_value};
  argument.AddInputs(op_inputs);

  VLOG(4) << "Builder construction attributes";
  pir::AttributeMap op_attrs;
  pir::Attribute data_type_attr =
      pir::StrAttribute::get(pir::IrContext::Instance(), mkldnn_data_type);
  argument.AddAttribute("mkldnn_data_type", data_type_attr);
  op_attrs.insert({"mkldnn_data_type", data_type_attr});

  // Derive output types from the inputs/attributes and finish the argument.
  std::vector<pir::Type> output_types = ExpandOp::InferMeta(op_inputs, op_attrs);
  argument.AddOutputs(output_types.begin(), output_types.end());
  ::pir::PassStopGradientsDefaultly(argument);
}

// Build overload taking the shape as an already-built pir::Value operand
// (e.g. the result of a FullIntArrayOp or a runtime shape tensor).
void ExpandOp::Build(pir::Builder& builder,
                     pir::OperationArgument& argument,
                     pir::Value x_,
                     pir::Value shape_,
                     const std::string& mkldnn_data_type) {
  VLOG(4) << "Start build ExpandOp";

  VLOG(4) << "Builder construction inputs";
  std::vector<pir::Value> op_inputs{x_, shape_};
  argument.AddInputs(op_inputs);

  VLOG(4) << "Builder construction attributes";
  pir::AttributeMap op_attrs;
  pir::Attribute data_type_attr =
      pir::StrAttribute::get(pir::IrContext::Instance(), mkldnn_data_type);
  argument.AddAttribute("mkldnn_data_type", data_type_attr);
  op_attrs.insert({"mkldnn_data_type", data_type_attr});

  // Derive output types from the inputs/attributes and finish the argument.
  std::vector<pir::Type> output_types = ExpandOp::InferMeta(op_inputs, op_attrs);
  argument.AddOutputs(output_types.begin(), output_types.end());
  ::pir::PassStopGradientsDefaultly(argument);
}

// Structural verification of an ExpandOp instance: operand count/types,
// required attributes, and result count/type. Raises via IR_ENFORCE on any
// mismatch.
void ExpandOp::VerifySig() {
  VLOG(4) << "Start Verifying inputs, outputs and attributes for: ExpandOp.";
  VLOG(4) << "Verifying inputs:";
  {
    auto input_size = num_operands();
    IR_ENFORCE(input_size == 2u,
               "The size %d of inputs must be equal to 2.",
               input_size);
    // Input 0 (x) must be a dense tensor.
    IR_ENFORCE((*this)
                   ->operand_source(0)
                   .type()
                   .isa<paddle::dialect::DenseTensorType>(),
               "Type validation failed for the 0th input, got %s.",
               (*this)->operand_source(0).type());
    // Input 1 (shape) may be either a vector of dense tensors or a single
    // dense tensor; each element of a vector type is checked individually.
    if (auto vec_type =
            (*this)->operand_source(1).type().dyn_cast<pir::VectorType>()) {
      for (size_t i = 0; i < vec_type.size(); ++i) {
        IR_ENFORCE(vec_type[i].isa<paddle::dialect::DenseTensorType>(),
                   "Type validation failed for the 1th input, got %s.",
                   (*this)->operand_source(1).type());
      }
    } else {
      IR_ENFORCE((*this)
                     ->operand_source(1)
                     .type()
                     .isa<paddle::dialect::DenseTensorType>(),
                 "Type validation failed for the 1th input, got %s.",
                 (*this)->operand_source(1).type());
    }
  }
  VLOG(4) << "Verifying attributes:";
  {
    // "mkldnn_data_type" must be present and be a string attribute.
    auto& attributes = this->attributes();
    IR_ENFORCE(attributes.count("mkldnn_data_type") > 0,
               "mkldnn_data_type does not exist.");
    IR_ENFORCE(attributes.at("mkldnn_data_type").isa<pir::StrAttribute>(),
               "Type of attribute: mkldnn_data_type is not pir::StrAttribute.");
  }
  VLOG(4) << "Verifying outputs:";
  {
    // Exactly one dense-tensor result.
    auto output_size = num_results();
    IR_ENFORCE(output_size == 1u,
               "The size %d of outputs must be equal to 1.",
               output_size);
    IR_ENFORCE(
        (*this)->result(0).type().isa<paddle::dialect::DenseTensorType>(),
        "Type validation failed for the 0th output.");
  }
  VLOG(4) << "End Verifying for: ExpandOp.";
}

// Runtime infermeta entry point: dispatches the InferMetaContext through
// phi::ExpandInferMeta via the PD_INFER_META adapter macro.
void ExpandOp::InferMeta(phi::InferMetaContext* infer_meta) {
  auto fn = PD_INFER_META(phi::ExpandInferMeta);
  fn(infer_meta);
}

// Computes the output type(s) of ExpandOp from its two input values
// (x, shape) by running phi::ExpandInferMeta on IR-level meta tensors.
//
// input_values: exactly two values — x (dense tensor) and shape (either a
//               FullIntArrayOp result, a VectorType, or a dense tensor).
// attributes:   unused here; kept for the common InferMeta signature.
// Returns a single-element vector holding the inferred DenseTensorType.
std::vector<pir::Type> ExpandOp::InferMeta(
    const std::vector<pir::Value>& input_values,
    const pir::AttributeMap& attributes) {
  IR_ENFORCE(input_values.size() == 2,
             "Num of inputs is expected to be 2 but got %d.",
             input_values.size());

  pir::Value x_ = input_values[0];
  pir::Value shape_ = input_values[1];
  VLOG(4) << "Builder construction outputs";

  // Normalize x's type to DenseTensorType (unwrapping the allocated variant).
  paddle::dialect::DenseTensorType x;
  if (x_.type().isa<paddle::dialect::DenseTensorType>()) {
    x = x_.type().dyn_cast<paddle::dialect::DenseTensorType>();
  } else if (x_.type().isa<paddle::dialect::AllocatedDenseTensorType>()) {
    paddle::dialect::AllocatedDenseTensorType allocated_x =
        x_.type().dyn_cast<paddle::dialect::AllocatedDenseTensorType>();
    x = paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(),
                                              allocated_x.dtype(),
                                              allocated_x.dims(),
                                              allocated_x.data_layout(),
                                              allocated_x.lod(),
                                              allocated_x.offset());
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Only support paddle::dialect::DenseTensorType or "
        "paddle::dialect::AllocatedDenseTensorType"));
  }

  // Recover the shape IntArray. Fix: guard against a null defining_op()
  // (shape_ may be a block argument), which previously dereferenced null;
  // such values now fall through to the type-based branches below.
  phi::IntArray shape;
  if (shape_.defining_op() != nullptr &&
      shape_.defining_op()->isa<paddle::dialect::FullIntArrayOp>()) {
    // Constant shape: read it straight out of the producing FullIntArrayOp.
    // Fix: dropped the pessimizing std::move on a prvalue temporary
    // (clang-tidy performance-move-const-arg); plain assignment already
    // moves from the temporary.
    shape = phi::IntArray(paddle::dialect::GetInt64Vector(
        shape_.defining_op()
            ->dyn_cast<paddle::dialect::FullIntArrayOp>()
            .attribute("value")));
  } else if (shape_.type().isa<pir::VectorType>()) {
    size_t shape_size = shape_.type().dyn_cast<pir::VectorType>().size();
    // In ExpandInferMeta use -2 to represent the element in expand_shape is a
    // var.
    shape = phi::IntArray(std::vector<int64_t>(shape_size, -2));
    shape.SetFromTensor(true);
  } else if (shape_.type().isa<paddle::dialect::DenseTensorType>()) {
    size_t shape_size = common::product(
        shape_.type().dyn_cast<paddle::dialect::DenseTensorType>().dims());
    // In ExpandInferMeta use -2 to represent the element in expand_shape is a
    // var.
    shape = phi::IntArray(std::vector<int64_t>(shape_size, -2));
    shape.SetFromTensor(true);
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Only support VectorType or DenseTensorType"));
  }

  // Wrap x in an IrTensor/IrMetaTensor so the phi infermeta can run on IR
  // metadata without real storage.
  VLOG(4) << "Builder construction dense_x";
  paddle::dialect::IrTensor ir_tensor_x(
      paddle::dialect::TransToPhiDataType(x.dtype()),
      x.dims(),
      x.data_layout(),
      x.lod(),
      x.offset());
  VLOG(4) << "Builder construction meta_x";
  paddle::dialect::IrMetaTensor meta_x(&ir_tensor_x);
  paddle::dialect::IrTensor dense_out;
  paddle::dialect::IrMetaTensor meta_out(&dense_out);

  phi::ExpandInferMeta(meta_x, shape, &meta_out);

  // Translate the inferred phi metadata back into a pir type.
  std::vector<pir::Type> argument_outputs;
  pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get(
      pir::IrContext::Instance(),
      paddle::dialect::TransToIrDataType(dense_out.dtype()),
      dense_out.dims(),
      dense_out.layout(),
      dense_out.lod(),
      dense_out.offset());
  argument_outputs.push_back(out_dense_tensor_type);

  return argument_outputs;
}

// Returns the kernel data type to use for a given variable. ExpandOp applies
// no per-variable override: it always keeps the expected kernel dtype, so
// var_name and tensor_dtype are intentionally unused.
phi::DataType ExpandOp::GetKernelTypeForVar(
    const std::string& var_name,
    const phi::DataType& tensor_dtype,
    const phi::DataType& expected_kernel_dtype) {
  VLOG(4) << "Get KernelType for Var of op: ExpandOp";

  return expected_kernel_dtype;
}

// Delegates symbolic shape inference to the shared paddle-dialect
// implementation for expand, so the oneDNN op behaves identically to the
// non-oneDNN ExpandOp under shape analysis.
bool ExpandOp::InferSymbolicShape(
    pir::ShapeConstraintIRAnalysis* shape_analysis) {
  VLOG(4) << "Infer symbolic shape for op: ExpandOp";
  return paddle::dialect::ExpandOpInferSymbolicShape(this->operation(),
                                                     shape_analysis);
}

} // namespace dialect
} // namespace onednn
} // namespace paddle

IR_DEFINE_EXPLICIT_TYPE_ID(paddle::onednn::dialect::ExpandOp)
#endif
Loading