@@ -1807,7 +1807,7 @@ OpInfoTuple SliceArrayDenseOp::GetOpInfo() {
18071807 paddle::dialect::OpOutputInfo (
18081808 " out" , " paddle::dialect::DenseTensorType" , false , false )};
18091809 paddle::dialect::OpRunTimeInfo run_time_info =
1810- paddle::dialect::OpRunTimeInfo (" SliceArrayInferMeta " ,
1810+ paddle::dialect::OpRunTimeInfo (" SliceArrayDenseInferMeta " ,
18111811 {" input" , " starts" },
18121812 " slice_array_dense" ,
18131813 {" input" , " starts" },
@@ -1855,6 +1855,71 @@ void SliceArrayDenseOp::VerifySig() {
18551855 VLOG (4 ) << " End Verifying for: SliceArrayOp." ;
18561856}
18571857
1858+ void SliceArrayDenseOp::Build (pir::Builder &builder, // NOLINT
1859+ pir::OperationArgument &argument, // NOLINT
1860+ pir::Value input,
1861+ pir::Value starts) {
1862+ VLOG (4 ) << " Start build SliceArrayDenseOp" ;
1863+ VLOG (4 ) << " Builder construction inputs" ;
1864+ argument.AddInputs ({input, starts});
1865+ VLOG (4 ) << " Builder construction attributes" ;
1866+ VLOG (4 ) << " Builder construction outputs" ;
1867+ paddle::dialect::DenseTensorArrayType input_type =
1868+ input.type ().dyn_cast <paddle::dialect::DenseTensorArrayType>();
1869+ paddle::dialect::IrTensor dense_input (
1870+ paddle::dialect::TransToPhiDataType (input_type.dtype ()),
1871+ {},
1872+ input_type.data_layout (),
1873+ {});
1874+ paddle::dialect::IrMetaTensor meta_input (&dense_input);
1875+
1876+ phi::IntArray starts_list;
1877+ if (starts.dyn_cast <pir::OpResult>()
1878+ .owner ()
1879+ ->isa <paddle::dialect::FullIntArrayOp>()) {
1880+ starts_list = std::move (phi::IntArray (paddle::dialect::GetInt64Vector (
1881+ starts.dyn_cast <pir::OpResult>()
1882+ .owner ()
1883+ ->dyn_cast <paddle::dialect::FullIntArrayOp>()
1884+ .attribute (" value" ))));
1885+ } else if (starts.type ().isa <pir::VectorType>()) {
1886+ size_t starts_size = starts.type ().dyn_cast <pir::VectorType>().size ();
1887+ starts_list =
1888+ std::move (phi::IntArray (std::vector<int64_t >(starts_size, -1 )));
1889+ starts_list.SetFromTensor (true );
1890+ } else if (starts.type ().isa <paddle::dialect::DenseTensorType>()) {
1891+ common::DDim starts_dim =
1892+ starts.type ().dyn_cast <paddle::dialect::DenseTensorType>().dims ();
1893+ size_t starts_size = common::product (starts_dim);
1894+ if (common::contain_unknown_dim (starts_dim)) {
1895+ starts_size = 1 ;
1896+ }
1897+ starts_list =
1898+ std::move (phi::IntArray (std::vector<int64_t >(starts_size, -1 )));
1899+ starts_list.SetFromTensor (true );
1900+ } else {
1901+ PADDLE_THROW (phi::errors::Unimplemented (
1902+ " Only support VectorType or DenseTensorType" ));
1903+ }
1904+
1905+ paddle::dialect::IrTensor dense_out;
1906+ paddle::dialect::IrMetaTensor meta_out (&dense_out);
1907+
1908+ phi::SliceArrayDenseInferMeta (
1909+ meta_input, starts_list, &meta_out, phi::MetaConfig (false , false ));
1910+
1911+ std::vector<pir::Type> argument_outputs;
1912+ pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get (
1913+ pir::IrContext::Instance (),
1914+ paddle::dialect::TransToIrDataType (dense_out.dtype ()),
1915+ dense_out.dims (),
1916+ dense_out.layout (),
1917+ dense_out.lod (),
1918+ dense_out.offset ());
1919+ argument_outputs.push_back (out_dense_tensor_type);
1920+ argument.AddOutputs (argument_outputs.begin (), argument_outputs.end ());
1921+ }
1922+
18581923void SliceArrayDenseOp::InferMeta (phi::InferMetaContext *infer_meta) {
18591924 auto fn = PD_INFER_META (phi::SliceArrayDenseInferMeta);
18601925 fn (infer_meta);
0 commit comments