Skip to content

Commit a8d4bd3

Browse files
zhangbo9674Wanglongzhi2001
authored andcommitted
[PIR] add slice_array_dense api (PaddlePaddle#60433)
* fix * fix
1 parent 4e1a2f7 commit a8d4bd3

File tree

6 files changed

+123
-2
lines changed

6 files changed

+123
-2
lines changed

paddle/fluid/pir/dialect/operator/ir/manual_api.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,5 +165,12 @@ std::tuple<pir::OpResult, pir::OpResult> array_to_tensor(pir::Value x,
165165
return std::make_tuple(array_to_tensor.result(0), array_to_tensor.result(1));
166166
}
167167

168+
pir::OpResult slice_array_dense(pir::Value input, pir::Value starts) {
169+
auto op = ApiBuilder::Instance()
170+
.GetBuilder()
171+
->Build<paddle::dialect::SliceArrayDenseOp>(input, starts);
172+
return op.result(0);
173+
}
174+
168175
} // namespace dialect
169176
} // namespace paddle

paddle/fluid/pir/dialect/operator/ir/manual_api.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,5 +72,7 @@ std::tuple<pir::OpResult, pir::OpResult> array_to_tensor(pir::Value x,
7272
int axis,
7373
bool use_stack);
7474

75+
pir::OpResult slice_array_dense(pir::Value input, pir::Value starts);
76+
7577
} // namespace dialect
7678
} // namespace paddle

paddle/fluid/pir/dialect/operator/ir/manual_op.cc

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1807,7 +1807,7 @@ OpInfoTuple SliceArrayDenseOp::GetOpInfo() {
18071807
paddle::dialect::OpOutputInfo(
18081808
"out", "paddle::dialect::DenseTensorType", false, false)};
18091809
paddle::dialect::OpRunTimeInfo run_time_info =
1810-
paddle::dialect::OpRunTimeInfo("SliceArrayInferMeta",
1810+
paddle::dialect::OpRunTimeInfo("SliceArrayDenseInferMeta",
18111811
{"input", "starts"},
18121812
"slice_array_dense",
18131813
{"input", "starts"},
@@ -1855,6 +1855,71 @@ void SliceArrayDenseOp::VerifySig() {
18551855
VLOG(4) << "End Verifying for: SliceArrayOp.";
18561856
}
18571857

1858+
void SliceArrayDenseOp::Build(pir::Builder &builder, // NOLINT
1859+
pir::OperationArgument &argument, // NOLINT
1860+
pir::Value input,
1861+
pir::Value starts) {
1862+
VLOG(4) << "Start build SliceArrayDenseOp";
1863+
VLOG(4) << "Builder construction inputs";
1864+
argument.AddInputs({input, starts});
1865+
VLOG(4) << "Builder construction attributes";
1866+
VLOG(4) << "Builder construction outputs";
1867+
paddle::dialect::DenseTensorArrayType input_type =
1868+
input.type().dyn_cast<paddle::dialect::DenseTensorArrayType>();
1869+
paddle::dialect::IrTensor dense_input(
1870+
paddle::dialect::TransToPhiDataType(input_type.dtype()),
1871+
{},
1872+
input_type.data_layout(),
1873+
{});
1874+
paddle::dialect::IrMetaTensor meta_input(&dense_input);
1875+
1876+
phi::IntArray starts_list;
1877+
if (starts.dyn_cast<pir::OpResult>()
1878+
.owner()
1879+
->isa<paddle::dialect::FullIntArrayOp>()) {
1880+
starts_list = std::move(phi::IntArray(paddle::dialect::GetInt64Vector(
1881+
starts.dyn_cast<pir::OpResult>()
1882+
.owner()
1883+
->dyn_cast<paddle::dialect::FullIntArrayOp>()
1884+
.attribute("value"))));
1885+
} else if (starts.type().isa<pir::VectorType>()) {
1886+
size_t starts_size = starts.type().dyn_cast<pir::VectorType>().size();
1887+
starts_list =
1888+
std::move(phi::IntArray(std::vector<int64_t>(starts_size, -1)));
1889+
starts_list.SetFromTensor(true);
1890+
} else if (starts.type().isa<paddle::dialect::DenseTensorType>()) {
1891+
common::DDim starts_dim =
1892+
starts.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();
1893+
size_t starts_size = common::product(starts_dim);
1894+
if (common::contain_unknown_dim(starts_dim)) {
1895+
starts_size = 1;
1896+
}
1897+
starts_list =
1898+
std::move(phi::IntArray(std::vector<int64_t>(starts_size, -1)));
1899+
starts_list.SetFromTensor(true);
1900+
} else {
1901+
PADDLE_THROW(phi::errors::Unimplemented(
1902+
"Only support VectorType or DenseTensorType"));
1903+
}
1904+
1905+
paddle::dialect::IrTensor dense_out;
1906+
paddle::dialect::IrMetaTensor meta_out(&dense_out);
1907+
1908+
phi::SliceArrayDenseInferMeta(
1909+
meta_input, starts_list, &meta_out, phi::MetaConfig(false, false));
1910+
1911+
std::vector<pir::Type> argument_outputs;
1912+
pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get(
1913+
pir::IrContext::Instance(),
1914+
paddle::dialect::TransToIrDataType(dense_out.dtype()),
1915+
dense_out.dims(),
1916+
dense_out.layout(),
1917+
dense_out.lod(),
1918+
dense_out.offset());
1919+
argument_outputs.push_back(out_dense_tensor_type);
1920+
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
1921+
}
1922+
18581923
void SliceArrayDenseOp::InferMeta(phi::InferMetaContext *infer_meta) {
18591924
auto fn = PD_INFER_META(phi::SliceArrayDenseInferMeta);
18601925
fn(infer_meta);

paddle/fluid/pir/dialect/operator/ir/manual_op.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,11 @@ class SliceArrayDenseOp
334334
static OpInfoTuple GetOpInfo();
335335
void VerifySig();
336336

337+
static void Build(pir::Builder &builder, // NOLINT
338+
pir::OperationArgument &argument, // NOLINT
339+
pir::Value input,
340+
pir::Value starts);
341+
337342
static phi::DataType GetKernelTypeForVar(
338343
const std::string &var_name,
339344
const phi::DataType &tensor_dtype,

paddle/fluid/pybind/manual_static_op_function.h

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,43 @@ static PyObject *static_api_array_to_tensor(PyObject *self,
274274
}
275275
}
276276

277+
static PyObject *static_api_slice_array_dense(PyObject *self,
278+
PyObject *args,
279+
PyObject *kwargs) {
280+
try {
281+
VLOG(6) << "Add slice_array_dense op into program";
282+
VLOG(8) << "args count: " << (PyTuple_Size(args) / 2);
283+
284+
// Get Value from args
285+
PyObject *input_obj = PyTuple_GET_ITEM(args, 0);
286+
auto input = CastPyArg2Value(input_obj, "slice_array_dense", 0);
287+
288+
PyObject *starts_obj = PyTuple_GET_ITEM(args, 1);
289+
pir::Value starts;
290+
if (PyObject_CheckIROpResult(starts_obj)) {
291+
starts = CastPyArg2Value(starts_obj, "slice_array_dense", 1);
292+
} else if (PyObject_CheckIRVectorOfOpResult(starts_obj)) {
293+
std::vector<pir::Value> starts_tmp =
294+
CastPyArg2VectorOfValue(starts_obj, "slice_array_dense", 1);
295+
starts = paddle::dialect::stack(starts_tmp, /*axis*/ 0);
296+
297+
} else {
298+
std::vector<int64_t> starts_tmp =
299+
CastPyArg2Longs(starts_obj, "slice_array_dense", 1);
300+
starts = paddle::dialect::full_int_array(
301+
starts_tmp, phi::DataType::INT64, phi::CPUPlace());
302+
}
303+
304+
// Call ir static api
305+
auto static_api_out = paddle::dialect::slice_array_dense(input, starts);
306+
307+
return ToPyObject(static_api_out);
308+
} catch (...) {
309+
ThrowExceptionToPython(std::current_exception());
310+
return nullptr;
311+
}
312+
}
313+
277314
static PyMethodDef ManualOpsAPI[] = {
278315
{"set_parameter",
279316
(PyCFunction)(void (*)(void))static_api_set_parameter,
@@ -303,6 +340,10 @@ static PyMethodDef ManualOpsAPI[] = {
303340
(PyCFunction)(void (*)(void))static_api_array_to_tensor,
304341
METH_VARARGS | METH_KEYWORDS,
305342
"C++ interface function for array_to_tensor."},
343+
{"slice_array_dense",
344+
(PyCFunction)(void (*)(void))static_api_slice_array_dense,
345+
METH_VARARGS | METH_KEYWORDS,
346+
"C++ interface function for slice_array_dense."},
306347
{nullptr, nullptr, 0, nullptr}};
307348

308349
} // namespace pybind

paddle/phi/infermeta/unary.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3687,7 +3687,8 @@ void SliceArrayDenseInferMeta(const MetaTensor& input,
36873687
if (config.is_runtime) {
36883688
return;
36893689
}
3690-
out->set_dims(input.dims());
3690+
// out->set_dims(input.dims());
3691+
out->set_dtype(input.dtype());
36913692
}
36923693

36933694
void SliceRawInferMeta(const MetaTensor& input,

0 commit comments

Comments
 (0)