Skip to content

Commit 2ca34a7

Browse files
authored
[PIR] Support wrap_type_interface for AlloctedDenseTensorType AllocatedSelectedRowsType and AllocatedDenseTensorArrayType (#62451)
* refine code * fix
1 parent 7bfde24 commit 2ca34a7

10 files changed

Lines changed: 93 additions & 441 deletions

File tree

paddle/fluid/pir/dialect/kernel/ir/kernel_type.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717
namespace paddle {
1818
namespace dialect {
1919

20+
pir::Type AllocatedDenseTensorType::prim_type() {
21+
return storage()->dense_tensor_type_;
22+
}
23+
2024
const phi::Place& AllocatedDenseTensorType::place() const {
2125
return storage()->place_;
2226
}
@@ -41,6 +45,10 @@ size_t AllocatedDenseTensorType::offset() const {
4145
return storage()->dense_tensor_type_.offset();
4246
}
4347

48+
pir::Type AllocatedSelectedRowsType::prim_type() {
49+
return storage()->selected_rows_type_;
50+
}
51+
4452
const phi::Place& AllocatedSelectedRowsType::place() const {
4553
return storage()->place_;
4654
}
@@ -65,6 +73,10 @@ size_t AllocatedSelectedRowsType::offset() const {
6573
return storage()->selected_rows_type_.offset();
6674
}
6775

76+
pir::Type AllocatedDenseTensorArrayType::prim_type() {
77+
return storage()->dense_tensor_array_type_;
78+
}
79+
6880
const phi::Place& AllocatedDenseTensorArrayType::place() const {
6981
return storage()->place_;
7082
}

paddle/fluid/pir/dialect/kernel/ir/kernel_type.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ namespace dialect {
2424
class AllocatedDenseTensorType
2525
: public pir::Type::TypeBase<AllocatedDenseTensorType,
2626
pir::Type,
27-
AllocatedDenseTensorTypeStorage> {
27+
AllocatedDenseTensorTypeStorage,
28+
pir::WrapTypeInterface> {
2829
public:
2930
using Base::Base;
3031

@@ -49,6 +50,8 @@ class AllocatedDenseTensorType
4950
ctx, place, dense_tensor_type);
5051
}
5152

53+
pir::Type prim_type();
54+
5255
const phi::Place &place() const;
5356

5457
pir::Type dtype() const;
@@ -65,7 +68,8 @@ class AllocatedDenseTensorType
6568
class AllocatedSelectedRowsType
6669
: public pir::Type::TypeBase<AllocatedSelectedRowsType,
6770
pir::Type,
68-
AllocatedSelectedRowsTypeStorage> {
71+
AllocatedSelectedRowsTypeStorage,
72+
pir::WrapTypeInterface> {
6973
public:
7074
using Base::Base;
7175

@@ -90,6 +94,8 @@ class AllocatedSelectedRowsType
9094
ctx, place, type);
9195
}
9296

97+
pir::Type prim_type();
98+
9399
const phi::Place &place() const;
94100

95101
pir::Type dtype() const;
@@ -106,7 +112,8 @@ class AllocatedSelectedRowsType
106112
class AllocatedDenseTensorArrayType
107113
: public pir::Type::TypeBase<AllocatedDenseTensorArrayType,
108114
pir::Type,
109-
AllocatedDenseTensorArrayTypeStorage> {
115+
AllocatedDenseTensorArrayTypeStorage,
116+
pir::WrapTypeInterface> {
110117
public:
111118
using Base::Base;
112119

@@ -129,6 +136,8 @@ class AllocatedDenseTensorArrayType
129136
ctx, place, type);
130137
}
131138

139+
pir::Type prim_type();
140+
132141
const phi::Place &place() const;
133142

134143
const pir::Type &dtype() const;

paddle/fluid/pir/dialect/op_generator/op_infermeta_gen.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,6 @@
4444
{type} {name};
4545
if ({name}_.type().isa<{type}>()) {{
4646
{name} = {name}_.type().dyn_cast<{type}>(); (void){name};
47-
}} else if ({name}_.type().isa<{allocated_type}>()) {{
48-
{allocated_type} allocated_{name} = {name}_.type().dyn_cast<{allocated_type}>();
49-
{name} = {type}::get(pir::IrContext::Instance(),
50-
allocated_{name}.dtype(),
51-
allocated_{name}.dims(),
52-
allocated_{name}.data_layout(),
53-
allocated_{name}.lod(),
54-
allocated_{name}.offset());
55-
(void){name};
5647
}} else {{
5748
PADDLE_THROW(phi::errors::Unimplemented("Only support {type} or {allocated_type}"));
5849
}}
@@ -158,20 +149,11 @@ def GenBuildOutputsPart2(
158149
paddle::dialect::IrMetaTensor meta_{name};
159150
paddle::dialect::IrTensor ir_tensor_{name};
160151
161-
162152
if ({name}_.impl() != nullptr) {{
163153
VLOG(4) << "Builder construction dense_{name}";
164154
{type} {name};
165155
if ({name}_.type().isa<{type}>()) {{
166156
{name} = {name}_.type().dyn_cast<{type}>();
167-
}} else if ({name}_.type().isa<{allocated_type}>()) {{
168-
{allocated_type} allocated_{name} = {name}_.type().dyn_cast<{allocated_type}>();
169-
{name} = {type}::get(pir::IrContext::Instance(),
170-
allocated_{name}.dtype(),
171-
allocated_{name}.dims(),
172-
allocated_{name}.data_layout(),
173-
allocated_{name}.lod(),
174-
allocated_{name}.offset());
175157
}} else {{
176158
PADDLE_THROW(phi::errors::Unimplemented("Only support {type} or {allocated_type}"));
177159
}}
@@ -195,13 +177,6 @@ def GenBuildOutputsPart2(
195177
{name}_type.data_layout(),
196178
{name}_type.lod(),
197179
{name}_type.offset()));
198-
}} else if({name}[i].isa<paddle::dialect::AllocatedDenseTensorType>()){{
199-
auto {name}_type = {name}[i].dyn_cast<paddle::dialect::AllocatedDenseTensorType>();
200-
vec_ir_tensor_{name}.push_back(paddle::dialect::IrTensor(paddle::dialect::TransToPhiDataType({name}_type.dtype()),
201-
{name}_type.dims(),
202-
{name}_type.data_layout(),
203-
{name}_type.lod(),
204-
{name}_type.offset()));
205180
}} else {{
206181
PADDLE_THROW(phi::errors::Unimplemented("Only support DenseTensorType or AllocatedDenseTensorType"));
207182
}}
@@ -228,13 +203,6 @@ def GenBuildOutputsPart2(
228203
{name}_type.data_layout(),
229204
{name}_type.lod(),
230205
{name}_type.offset()));
231-
}} else if({name}[i].isa<paddle::dialect::AllocatedDenseTensorType>()){{
232-
auto {name}_type = {name}[i].dyn_cast<paddle::dialect::AllocatedDenseTensorType>();
233-
vec_ir_tensor_{name}.push_back(paddle::dialect::IrTensor(paddle::dialect::TransToPhiDataType({name}_type.dtype()),
234-
{name}_type.dims(),
235-
{name}_type.data_layout(),
236-
{name}_type.lod(),
237-
{name}_type.offset()));
238206
}} else {{
239207
PADDLE_THROW(phi::errors::Unimplemented("Only support DenseTensorType or AllocatedDenseTensorType"));
240208
}}
@@ -273,13 +241,6 @@ def GenBuildOutputsPart2(
273241
{name}_size = 1;
274242
}}
275243
{name} = std::vector<int64_t>({name}_size, -1);
276-
}} else if ({name}_.type().isa<paddle::dialect::AllocatedDenseTensorType>()) {{
277-
common::DDim {name}_dim = {name}_.type().dyn_cast<paddle::dialect::AllocatedDenseTensorType>().dims();
278-
size_t {name}_size = common::product({name}_dim);
279-
if (common::contain_unknown_dim({name}_dim)) {{
280-
{name}_size = 1;
281-
}}
282-
{name} = std::vector<int64_t>({name}_size, -1);
283244
}} else {{
284245
PADDLE_THROW(phi::errors::Unimplemented("Only support VectorType or DenseTensorType or AllocatedDenseTensorType"));
285246
}}\n"""

paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -575,14 +575,6 @@ void WhileOp::VerifySig() {
575575
phi::errors::PreconditionNotMet(
576576
"Type validation failed for the 0th input, it should be a "
577577
"bool DenseTensorType."));
578-
} else if (auto cond_type =
579-
operand_type(0).dyn_cast<AllocatedDenseTensorType>()) {
580-
PADDLE_ENFORCE_EQ(
581-
cond_type.dtype().isa<pir::BoolType>(),
582-
true,
583-
phi::errors::PreconditionNotMet(
584-
"Type validation failed for the 0th input, it should be a "
585-
"bool DenseTensorType."));
586578
} else {
587579
PADDLE_THROW(phi::errors::PreconditionNotMet(
588580
"Currently, the while op cond input only support bool dense_tensor "
@@ -803,8 +795,7 @@ void HasElementsOp::VerifySig() {
803795

804796
// Verify outputs:
805797
IR_ENFORCE(num_results() == 1u, "The size of outputs must be equal to 1.");
806-
IR_ENFORCE((*this)->result_type(0).isa<DenseTensorType>() ||
807-
(*this)->result_type(0).isa<AllocatedDenseTensorType>(),
798+
IR_ENFORCE((*this)->result_type(0).isa<DenseTensorType>(),
808799
"The type of cf.has_elements' output is not correct.");
809800
}
810801

@@ -874,8 +865,7 @@ void AssertOp::VerifySig() {
874865
(*this)->operand(1).type().dyn_cast<pir::VectorType>()) {
875866
for (size_t i = 0; i < vec_type.size(); ++i) {
876867
IR_ENFORCE(vec_type[i].isa<paddle::dialect::DenseTensorType>() ||
877-
vec_type[i].isa<paddle::dialect::SelectedRowsType>() ||
878-
vec_type[i].isa<AllocatedDenseTensorType>(),
868+
vec_type[i].isa<paddle::dialect::SelectedRowsType>(),
879869
"Type validation failed for the 1th input.");
880870
}
881871
} else {
@@ -885,7 +875,6 @@ void AssertOp::VerifySig() {
885875
->operand(1)
886876
.type()
887877
.isa<paddle::dialect::SelectedRowsType>(),
888-
(*this)->operand(1).type().isa<AllocatedDenseTensorType>(),
889878
"Type validation failed for the 1th input.");
890879
}
891880
}

paddle/fluid/pir/dialect/operator/ir/manual_onednn_op.cc

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -255,15 +255,6 @@ std::vector<pir::Type> ExpandOp::InferMeta(
255255
paddle::dialect::DenseTensorType x;
256256
if (x_.type().isa<paddle::dialect::DenseTensorType>()) {
257257
x = x_.type().dyn_cast<paddle::dialect::DenseTensorType>();
258-
} else if (x_.type().isa<paddle::dialect::AllocatedDenseTensorType>()) {
259-
paddle::dialect::AllocatedDenseTensorType allocated_x =
260-
x_.type().dyn_cast<paddle::dialect::AllocatedDenseTensorType>();
261-
x = paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(),
262-
allocated_x.dtype(),
263-
allocated_x.dims(),
264-
allocated_x.data_layout(),
265-
allocated_x.lod(),
266-
allocated_x.offset());
267258
} else {
268259
PADDLE_THROW(phi::errors::Unimplemented(
269260
"Only support paddle::dialect::DenseTensorType or "

0 commit comments

Comments
 (0)