Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 111 additions & 11 deletions paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -287,20 +287,30 @@ std::vector<std::vector<pir::OpResult>> IfOp::Vjp(
void WhileOp::Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
pir::Value cond,
const std::vector<pir::Value> &inputs) {
const std::vector<pir::Value> &inputs,
bool construct_body) {
argument.AddInput(cond);
argument.AddInputs(inputs);
auto &body = argument.AddRegion().emplace_back();
std::vector<pir::Attribute> outs_stop_gradient;
for (auto val : inputs) {
argument.AddOutput(val.type());
auto arg = body.AddArgument(val.type());

auto bool_attr = val.attribute<pir::BoolAttribute>(kStopGradientAttrName);
arg.set_attribute(kStopGradientAttrName,
bool_attr ? bool_attr : builder.bool_attr(false));
outs_stop_gradient.push_back(bool_attr ? bool_attr
: builder.bool_attr(false));
if (construct_body) {
auto &body = argument.AddRegion().emplace_back();
for (auto val : inputs) {
argument.AddOutput(val.type());
auto arg = body.AddArgument(val.type());
auto bool_attr = val.attribute<pir::BoolAttribute>(kStopGradientAttrName);
outs_stop_gradient.push_back(bool_attr ? bool_attr
: builder.bool_attr(false));
arg.set_attribute(kStopGradientAttrName,
bool_attr ? bool_attr : builder.bool_attr(false));
}
} else {
argument.AddRegion(nullptr);
for (auto val : inputs) {
argument.AddOutput(val.type());
auto bool_attr = val.attribute<pir::BoolAttribute>(kStopGradientAttrName);
outs_stop_gradient.push_back(bool_attr ? bool_attr
: builder.bool_attr(false));
}
}

argument.AddAttribute(
Expand Down Expand Up @@ -343,6 +353,96 @@ void WhileOp::Print(pir::IrPrinter &printer) {
os << "\n }";
}

// Verifies the operand/result signature of a while_op:
//   operands = [cond, loop_var_0, ..., loop_var_n-1]
//   results  = [loop_var_0, ..., loop_var_n-1]
// so num_results() + 1 == num_operands(), and result i has the type of
// operand i + 1. The cond operand must be a bool tensor.
void WhileOp::VerifySig() {
  VLOG(4) << "Start Verifying inputs, outputs and attributes for: WhileOp.";
  auto input_size = num_operands();
  PADDLE_ENFORCE_GE(
      input_size,
      1u,
      phi::errors::PreconditionNotMet(
          "The size %d of inputs must be greater or equal to 1.", input_size));

  // cond may be either a plain DenseTensorType or an AllocatedDenseTensorType,
  // but in both cases its element type must be bool.
  if (auto cond_type = operand_type(0).dyn_cast<pir::DenseTensorType>()) {
    PADDLE_ENFORCE_EQ(
        cond_type.dtype().isa<pir::BoolType>(),
        true,
        phi::errors::PreconditionNotMet(
            "Type validation failed for the 0th input, it should be a "
            "bool DenseTensorType."));
  } else if (auto cond_type =
                 operand_type(0).dyn_cast<AllocatedDenseTensorType>()) {
    PADDLE_ENFORCE_EQ(
        cond_type.dtype().isa<pir::BoolType>(),
        true,
        phi::errors::PreconditionNotMet(
            "Type validation failed for the 0th input, it should be a "
            "bool AllocatedDenseTensorType."));
  } else {
    PADDLE_THROW(phi::errors::PreconditionNotMet(
        "Currently, the while op cond input only support bool dense_tensor "
        "and bool allocated_dense_tensor."));
  }
  PADDLE_ENFORCE_EQ((*this)->num_regions(),
                    1u,
                    phi::errors::PreconditionNotMet(
                        "The size %d of regions must be equal to 1.",
                        (*this)->num_regions()));
  auto output_size = num_results();
  // Every operand except cond produces exactly one result, hence the "+ 1".
  PADDLE_ENFORCE_EQ(output_size + 1,
                    input_size,
                    phi::errors::PreconditionNotMet(
                        "The result size (%d) plus 1 is not equal to the "
                        "input size (%d).",
                        num_results(),
                        input_size));
  // Result i must carry through the type of loop variable operand i + 1.
  for (size_t index = 0; index < output_size; ++index) {
    PADDLE_ENFORCE_EQ(
        operand_type(index + 1),
        result_type(index),
        phi::errors::PreconditionNotMet(
            "The (%d) result and operand type is not equal.", index));
  }
}

// Verifies the body region of a while_op: exactly one non-empty block whose
// argument count matches the op's result count, terminated by a yield whose
// operand count matches the op's operand count (next cond + updated loop
// variables).
void WhileOp::VerifyRegion() {
  VLOG(4) << "Start verifying sub regions for: WhileOp.";
  PADDLE_ENFORCE_EQ(
      (*this)->region(0).size(),
      1u,
      phi::errors::PreconditionNotMet("The size %d of body_region must be 1.",
                                      (*this)->region(0).size()));
  auto &body_block = body();
  auto output_size = num_results();
  // One block argument per loop-variable result (the check is plain equality;
  // cond is an operand only and has no matching block argument).
  PADDLE_ENFORCE_EQ(
      body_block.args_size(),
      output_size,
      phi::errors::PreconditionNotMet(
          "The result size (%d) is not equal to block args size (%d).",
          output_size,
          body_block.args_size()));

  PADDLE_ENFORCE_EQ(
      body_block.empty(),
      false,
      phi::errors::PreconditionNotMet("The body block is empty."));

  // The terminator must be a yield carrying [next_cond, updated loop vars],
  // i.e. as many operands as the while_op itself has.
  auto yield_op = body_block.back().dyn_cast<pir::YieldOp>();
  auto input_size = num_operands();
  PADDLE_ENFORCE_EQ(
      yield_op && yield_op.num_operands() == input_size,
      true,
      phi::errors::PreconditionNotMet(
          "The body block yield size not equal to operands size."));
  // TODO: fix other bugs and make the following code work.
  // for (size_t index = 0; index < input_size; ++index) {
  //   PADDLE_ENFORCE_EQ(
  //       operand_type(index),
  //       yield_op.operand_type(index),
  //       phi::errors::PreconditionNotMet(
  //           "The (%d) operand and block yield type is not equal.", index));
  // }
  VLOG(4) << "Successful end verifying sub regions for: WhileOp.";
}

std::vector<std::vector<pir::OpResult>> WhileOp::Vjp(
pir::Operation *op,
const std::vector<std::vector<pir::Value>> &inputs,
Expand Down
7 changes: 4 additions & 3 deletions paddle/fluid/pir/dialect/operator/ir/control_flow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,14 @@ class WhileOp : public pir::Op<WhileOp, VjpInterface> {
static void Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
pir::Value cond,
const std::vector<pir::Value> &inputs);
const std::vector<pir::Value> &inputs,
bool construct_body = true);
TEST_API pir::Block &body();
pir::Value cond();
const pir::Block::ArgListType &block_args() { return body().args(); }
void Print(pir::IrPrinter &printer); // NOLINT
void VerifySig() {}
void VerifyRegion() {}
void VerifySig();
void VerifyRegion();
static std::vector<std::vector<pir::OpResult>> Vjp(
pir::Operation *op,
const std::vector<std::vector<pir::Value>> &inputs_,
Expand Down
3 changes: 1 addition & 2 deletions paddle/fluid/pir/dialect/operator/ir/op_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ OperatorDialect::OperatorDialect(pir::IrContext *ctx)
ctx->GetOrRegisterDialect<::pir::ControlFlowDialect>();
auto info = ctx->GetRegisteredOpInfo(pir::TuplePushOp::name());
info.AttachInterface(std::move(
pir::InterfaceValue::
Get<pir::TuplePushOp, VjpInterface, TuplePushOpVjpInterfaceModel>()));
pir::InterfaceValue::Get<VjpInterface, TuplePushOpVjpInterfaceModel>()));
}

void OperatorDialect::initialize() {
Expand Down
85 changes: 74 additions & 11 deletions paddle/fluid/pybind/control_flow_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ using paddle::dialect::AssertOp;
using paddle::dialect::HasElementsOp;
using paddle::dialect::IfOp;
using paddle::dialect::WhileOp;
using paddle::pybind::PyIfOp;
using paddle::pybind::PyWhileOp;
using pir::Block;
using pir::Builder;
using pir::Operation;
Expand All @@ -50,8 +52,6 @@ using pir::Type;
using pir::Value;
using pir::YieldOp;
using pybind11::return_value_policy;

using paddle::pybind::PyIfOp;
namespace {

void BindIfOp(py::module* m) {
Expand All @@ -78,22 +78,24 @@ void BindIfOp(py::module* m) {
}

void BindWhileOp(py::module* m) {
m->def("build_while_op", [](Value cond, py::list loop_vars) {
m->def("build_while_op", [](Value cond, py::list loop_vars) -> PyWhileOp {
std::vector<Value> loop_values;
for (auto var : loop_vars) {
loop_values.push_back(var.cast<Value>());
}
return ApiBuilder::Instance().GetBuilder()->Build<WhileOp>(cond,
loop_values);
return PyWhileOp(
ApiBuilder::Instance().GetBuilder()->Build<WhileOp>(cond, loop_values));
});
py::class_<WhileOp> while_op(*m, "WhileOp", R"DOC(
py::class_<PyWhileOp> while_op(*m, "WhileOp", R"DOC(
WhileOp in python api.
)DOC");
while_op.def("body", &WhileOp::body, return_value_policy::reference)
.def("as_operation", &WhileOp::operation, return_value_policy::reference)
while_op.def("body", &PyWhileOp::body, return_value_policy::reference)
.def(
"as_operation", &PyWhileOp::operation, return_value_policy::reference)
.def("block_arguments",
&WhileOp::block_args,
return_value_policy::reference);
return_value_policy::reference)
.def("optimize_update", &PyWhileOp::OptimizeUpdate);
}

void BindAssertOp(py::module* m) {
Expand Down Expand Up @@ -227,7 +229,7 @@ PyIfOp::PyIfOp(IfOp if_op) : IfOp(if_op) {

void PyIfOp::UpdateOutput() {
PADDLE_ENFORCE_NOT_NULL(
*this,
operation_,
paddle::platform::errors::InvalidArgument(
"The if_op in PyIfOp used to update output can't be nullptr"));
auto block = parent();
Expand All @@ -241,7 +243,68 @@ void PyIfOp::UpdateOutput() {
cond(), true_region().TakeBack(), false_region().TakeBack());
block->Assign(iter, new_if_op);
IfOp::operator=(new_if_op);
VerifyRegion();
operation_->Verify();
}

// Wraps an existing WhileOp for use from the Python API.
// The wrapped op must be valid: a null underlying operation is rejected with
// InvalidArgument rather than deferring the failure to a later method call.
PyWhileOp::PyWhileOp(WhileOp while_op) : WhileOp(while_op) {
  PADDLE_ENFORCE_NOT_NULL(
      operation_,
      paddle::platform::errors::InvalidArgument(
          "The while_op used to construct PyWhileOp can't be nullptr"));
}

// Rebuilds this while_op without "read-only" loop variables, i.e. those whose
// yield operand is exactly the corresponding block argument (never updated in
// the body). Returns one Value per ORIGINAL result: the original operand for
// a read-only variable, or the matching result of the rebuilt while_op for a
// variable that is actually updated. If no variable is read-only, the op is
// left untouched and its current results are returned.
std::vector<Value> PyWhileOp::OptimizeUpdate() {
  PADDLE_ENFORCE_NOT_NULL(operation_,
                          paddle::platform::errors::InvalidArgument(
                              "The while_op in PyWhileOp used to remove unused "
                              "loop vars can't be nullptr"));
  auto parent_block = parent();
  PADDLE_ENFORCE_NOT_NULL(
      parent_block,
      paddle::platform::errors::InvalidArgument(
          "The parent block of while_op which used to remove "
          "unused loop vars can't be nullptr"));

  operation_->Verify();
  auto& body_block = body();
  // The verified terminator of the body; operand 0 is the next-iteration cond.
  auto yield_op = body_block.back().dyn_cast<YieldOp>();
  auto operand_num = operation_->num_operands();
  bool no_change = true;
  // index_vec maps result index of the NEW op -> slot in `res` (original
  // result index); new_yield_val starts with the cond yielded by the body.
  std::vector<size_t> index_vec;
  std::vector<Value> res, new_input, new_yield_val{yield_op.operand_source(0)};
  for (uint32_t i = 0; i < num_results(); ++i) {
    res.push_back(result(i));
  }
  // operand_index walks the op's loop-var operands (starting after cond);
  // arg_index tracks the current block argument and is NOT advanced when an
  // argument is erased, since erasure shifts the remaining arguments down.
  for (size_t operand_index = 1u, arg_index = 0u; operand_index < operand_num;
       ++operand_index) {
    if (yield_op.operand_source(operand_index) == body_block.arg(arg_index)) {
      // Read-only variable: route all body uses to the outer operand, then
      // drop the argument and report the operand itself as this result.
      body_block.arg(arg_index).ReplaceAllUsesWith(
          operand_source(operand_index));
      body_block.EraseArgument(arg_index);
      no_change = false;
      res[operand_index - 1u] = operand_source(operand_index);
    } else {
      // Updated variable: keep it in the rebuilt op's operand/yield lists.
      new_input.push_back(operand_source(operand_index));
      index_vec.push_back(operand_index - 1u);
      new_yield_val.push_back(yield_op.operand_source(operand_index));
      ++arg_index;
    }
  }
  if (no_change) return res;
  Block::Iterator iter = **this;
  Builder builder(ir_context(), false);
  // Build the replacement with construct_body=false and adopt the existing
  // (already pruned) body by swapping the region over.
  auto new_while_op = builder.Build<WhileOp>(cond(), new_input, false);
  new_while_op->region(0).swap(std::move(operation_->region(0)));
  parent_block->Assign(iter, new_while_op);
  WhileOp::operator=(new_while_op);
  // Replace the old terminator with a yield carrying only the kept values.
  body_block.pop_back();
  builder.SetInsertionPointToBlockEnd(&body_block);
  builder.Build<YieldOp>(new_yield_val);
  operation_->Verify();
  // Scatter the new op's results back into their original result slots.
  for (size_t result_index = 0; result_index < num_results(); ++result_index) {
    res[index_vec[result_index]] = result(result_index);
  }
  return res;
}

void BindControlFlowApi(py::module* m) {
Expand Down
16 changes: 16 additions & 0 deletions paddle/fluid/pybind/control_flow_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,22 @@ class PyIfOp : public dialect::IfOp {
void UpdateOutput();
};

// Python-facing wrapper around dialect::WhileOp; adds loop-variable pruning
// used by the Python control-flow API.
class PyWhileOp : public dialect::WhileOp {
 public:
  explicit PyWhileOp(dialect::WhileOp while_op);

  ///
  /// \brief Construct a new while_op to replace the original while_op. The
  /// input, output, and parameters of the new while_op no longer contain the
  /// variables that have not been modified in the loop. The size of the return
  /// value is equal to the output size of the original while_op, where the
  /// value of a read-only loop variable is the corresponding operand of the
  /// original while_op, and the value of a non-read-only loop variable is the
  /// corresponding output of the new while_op.
  ///
  std::vector<pir::Value> OptimizeUpdate();
};

void BindControlFlowApi(pybind11::module *m);
} // namespace pybind
} // namespace paddle
10 changes: 2 additions & 8 deletions paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -527,14 +527,8 @@ void BindOperation(py::module *m) {
})
.def("as_if_op",
[](Operation &self) { return PyIfOp(self.dyn_cast<IfOp>()); })
.def("as_while_op", [](Operation &self) -> WhileOp {
auto while_op = self.dyn_cast<WhileOp>();
if (!while_op) {
PADDLE_THROW(phi::errors::InvalidArgument(
"Can't cast non-while type Operation to WhileOp."));
}
return while_op;
});
.def("as_while_op",
[](Operation &self) { return PyWhileOp(self.dyn_cast<WhileOp>()); });
py::class_<Operation::BlockContainer> block_container(
*m, "Operation_BlockContainer", R"DOC(
The Operation_BlockContainer only use to walk all blocks in the operation.
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1859,6 +1859,7 @@ void IncrementInferMeta(const MetaTensor& x, float value, MetaTensor* out) {
product(x.dims())));
out->set_dims(x.dims());
out->share_lod(x);
out->set_layout(x.layout());
out->set_dtype(x.dtype());
}

Expand Down
16 changes: 14 additions & 2 deletions paddle/pir/core/block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ void Block::push_back(Operation *op) { insert(ops_.end(), op); }

void Block::push_front(Operation *op) { insert(ops_.begin(), op); }

// Removes the last operation from the block and destroys it.
// Enforces that the block is non-empty before popping.
void Block::pop_back() {
  IR_ENFORCE(!ops_.empty(), "can't pop back from empty block.");
  ops_.back()->Destroy();
  ops_.pop_back();
}

// Returns the operation owning this block's parent region, or nullptr when
// the block is not attached to a region.
Operation *Block::GetParentOp() const {
  return parent_ ? parent_->GetParent() : nullptr;
}
Expand All @@ -50,8 +56,7 @@ Block::Iterator Block::erase(ConstIterator position) {

void Block::clear() {
while (!empty()) {
ops_.back()->Destroy();
ops_.pop_back();
pop_back();
}
}

Expand Down Expand Up @@ -103,6 +108,13 @@ Value Block::AddArgument(Type type) {
return argument;
}

// Removes and destroys the block argument at `index`.
// The argument must have no remaining uses; callers are expected to rewrite
// uses first (e.g. via ReplaceAllUsesWith). Later arguments shift down by one.
void Block::EraseArgument(uint32_t index) {
  auto argument = arg(index);
  IR_ENFORCE(argument.use_empty(),
             "Erase a block argument that is still in use.");
  argument.dyn_cast<BlockArgument>().Destroy();
  arguments_.erase(arguments_.begin() + index);
}
bool Block::TopoOrderCheck(const OpListType &op_list) {
std::unordered_set<Value> visited_values;
for (Operation *op : op_list) {
Expand Down
2 changes: 2 additions & 0 deletions paddle/pir/core/block.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class IR_API Block {

void push_back(Operation *op);
void push_front(Operation *op);
void pop_back();
Iterator insert(ConstIterator iterator, Operation *op);
Iterator erase(ConstIterator position);
void clear();
Expand Down Expand Up @@ -111,6 +112,7 @@ class IR_API Block {
Type arg_type(uint32_t index) const { return arguments_[index].type(); }
void ClearArguments();
Value AddArgument(Type type);
void EraseArgument(uint32_t index);
template <class TypeIter>
void AddArguments(TypeIter first, TypeIter last);
template <class TypeContainer>
Expand Down
Loading