From ad0dedc79bcd38f908c1f19b73760022169aca62 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Thu, 28 Dec 2023 05:36:54 +0000 Subject: [PATCH 1/2] fix --- .../pir/dialect/operator/ir/manual_api.cc | 7 ++ .../pir/dialect/operator/ir/manual_api.h | 2 + .../pir/dialect/operator/ir/manual_op.cc | 68 ++++++++++++++++++- .../fluid/pir/dialect/operator/ir/manual_op.h | 5 ++ .../fluid/pybind/manual_static_op_function.h | 41 +++++++++++ paddle/phi/infermeta/unary.cc | 3 +- 6 files changed, 124 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_api.cc b/paddle/fluid/pir/dialect/operator/ir/manual_api.cc index d9c5debe92ee66..33fecafdbb0258 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_api.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_api.cc @@ -165,5 +165,12 @@ std::tuple array_to_tensor(pir::Value x, return std::make_tuple(array_to_tensor.result(0), array_to_tensor.result(1)); } +pir::OpResult slice_array_dense(pir::Value input, pir::Value starts) { + auto op = ApiBuilder::Instance() + .GetBuilder() + ->Build(input, starts); + return op.result(0); +} + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_api.h b/paddle/fluid/pir/dialect/operator/ir/manual_api.h index 680cd5b54ab905..347e10494696c0 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_api.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_api.h @@ -72,5 +72,7 @@ std::tuple array_to_tensor(pir::Value x, int axis, bool use_stack); +pir::OpResult slice_array_dense(pir::Value input, pir::Value starts); + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc index d3d8c46111bbb2..97dc594b6bda38 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -1807,7 +1807,7 @@ OpInfoTuple SliceArrayDenseOp::GetOpInfo() { paddle::dialect::OpOutputInfo( "out", "paddle::dialect::DenseTensorType", false, false)}; paddle::dialect::OpRunTimeInfo run_time_info = - paddle::dialect::OpRunTimeInfo("SliceArrayInferMeta", + paddle::dialect::OpRunTimeInfo("SliceArrayDenseInferMeta", {"input", "starts"}, "slice_array_dense", {"input", "starts"}, @@ -1855,6 +1855,72 @@ void SliceArrayDenseOp::VerifySig() { VLOG(4) << "End Verifying for: SliceArrayOp."; } +static void SliceArrayDenseOp::Build( + pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value input, + pir::Value starts) { + VLOG(4) << "Start build SliceArrayDenseOp"; + VLOG(4) << "Builder construction inputs"; + argument.AddInputs({input, starts}); + VLOG(4) << "Builder construction attributes"; + VLOG(4) << "Builder construction outputs"; + paddle::dialect::DenseTensorArrayType input_type = + input.type().dyn_cast(); + paddle::dialect::IrTensor dense_input( + paddle::dialect::TransToPhiDataType(input_type.dtype()), + {}, + input_type.data_layout(), + {}); + paddle::dialect::IrMetaTensor meta_input(&dense_input); + + phi::IntArray starts_list; + if (starts.dyn_cast() + .owner() + ->isa()) { + starts_list = std::move(phi::IntArray(paddle::dialect::GetInt64Vector( + starts.dyn_cast() + .owner() + ->dyn_cast() + .attribute("value")))); + } else if (starts.type().isa()) { + size_t starts_size = starts.type().dyn_cast().size(); + starts_list = + std::move(phi::IntArray(std::vector(starts_size, -1))); + starts_list.SetFromTensor(true); + } else if (starts.type().isa()) { + common::DDim starts_dim = + starts.type().dyn_cast().dims(); + size_t starts_size = common::product(starts_dim); + if (common::contain_unknown_dim(starts_dim)) { + starts_size = 1; + } + starts_list = + std::move(phi::IntArray(std::vector(starts_size, -1))); + starts_list.SetFromTensor(true); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support VectorType or DenseTensorType")); + } + + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + phi::SliceArrayDenseInferMeta( + meta_input, starts_list, &meta_out, phi::MetaConfig(false, false)); + + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_out.dtype()), + dense_out.dims(), + dense_out.layout(), + dense_out.lod(), + dense_out.offset()); + argument_outputs.push_back(out_dense_tensor_type); + argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); +} + void SliceArrayDenseOp::InferMeta(phi::InferMetaContext *infer_meta) { auto fn = PD_INFER_META(phi::SliceArrayDenseInferMeta); fn(infer_meta); diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_op.h index 4d001206409512..121c95dee169aa 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.h @@ -334,6 +334,11 @@ class SliceArrayDenseOp static OpInfoTuple GetOpInfo(); void VerifySig(); + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value input, + pir::Value starts); + static phi::DataType GetKernelTypeForVar( const std::string &var_name, const phi::DataType &tensor_dtype, diff --git a/paddle/fluid/pybind/manual_static_op_function.h b/paddle/fluid/pybind/manual_static_op_function.h index 21285163dd64f1..dc09d539f39ffb 100644 --- a/paddle/fluid/pybind/manual_static_op_function.h +++ b/paddle/fluid/pybind/manual_static_op_function.h @@ -274,6 +274,43 @@ static PyObject *static_api_array_to_tensor(PyObject *self, } } +static PyObject *static_api_slice_array_dense(PyObject *self, + PyObject *args, + PyObject *kwargs) { + try { + VLOG(6) << "Add slice_array_dense op into program"; + VLOG(8) << "args count: " << (PyTuple_Size(args) / 2); + + // Get Value from args + PyObject *input_obj = PyTuple_GET_ITEM(args, 0); + auto input = CastPyArg2Value(input_obj, "slice_array_dense", 0); + + PyObject *starts_obj = PyTuple_GET_ITEM(args, 1); + pir::Value starts; + if (PyObject_CheckIROpResult(starts_obj)) { + starts = CastPyArg2Value(starts_obj, "slice_array_dense", 1); + } else if (PyObject_CheckIRVectorOfOpResult(starts_obj)) { + std::vector starts_tmp = + CastPyArg2VectorOfValue(starts_obj, "slice_array_dense", 1); + starts = paddle::dialect::stack(starts_tmp, /*axis*/ 0); + + } else { + std::vector starts_tmp = + CastPyArg2Longs(starts_obj, "slice_array_dense", 1); + starts = paddle::dialect::full_int_array( + starts_tmp, phi::DataType::INT64, phi::CPUPlace()); + } + + // Call ir static api + auto static_api_out = paddle::dialect::slice_array_dense(input, starts); + + return ToPyObject(static_api_out); + } catch (...) { + ThrowExceptionToPython(std::current_exception()); + return nullptr; + } +} + static PyMethodDef ManualOpsAPI[] = { {"set_parameter", (PyCFunction)(void (*)(void))static_api_set_parameter, @@ -303,6 +340,10 @@ static PyMethodDef ManualOpsAPI[] = { (PyCFunction)(void (*)(void))static_api_array_to_tensor, METH_VARARGS | METH_KEYWORDS, "C++ interface function for array_to_tensor."}, + {"slice_array_dense", + (PyCFunction)(void (*)(void))static_api_slice_array_dense, + METH_VARARGS | METH_KEYWORDS, + "C++ interface function for slice_array_dense."}, {nullptr, nullptr, 0, nullptr}}; } // namespace pybind diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 90987398057fe9..a75cd4170e2785 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -3687,7 +3687,8 @@ void SliceArrayDenseInferMeta(const MetaTensor& input, if (config.is_runtime) { return; } - out->set_dims(input.dims()); + // out->set_dims(input.dims()); + out->set_dtype(input.dtype()); } void SliceRawInferMeta(const MetaTensor& input, From d06422b93fc1a6e6a916d3cc795b9b355192536f Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Thu, 28 Dec 2023 05:56:37 +0000 Subject: [PATCH 2/2] fix --- paddle/fluid/pir/dialect/operator/ir/manual_op.cc | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc index 97dc594b6bda38..0a60b4c7d7d819 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -1855,11 +1855,10 @@ void SliceArrayDenseOp::VerifySig() { VLOG(4) << "End Verifying for: SliceArrayOp."; } -static void SliceArrayDenseOp::Build( - pir::Builder &builder, // NOLINT - pir::OperationArgument &argument, // NOLINT - pir::Value input, - pir::Value starts) { +void SliceArrayDenseOp::Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value input, + pir::Value starts) { VLOG(4) << "Start build SliceArrayDenseOp"; VLOG(4) << "Builder construction inputs"; argument.AddInputs({input, starts});