Skip to content

Commit dba3976

Browse files
【pir】add test_while_api.py gradient case (#59999)
* optimize backward
* [PIR] add vjp interface for while op
* [PIR] fix ci error.
* modify while stopgradient
* merge
* modify while grad bug
* modify while grad op
* modify
* code style
* code style

---------

Co-authored-by: winter-wang <[email protected]>
1 parent 413b654 commit dba3976

File tree

7 files changed

+262
-50
lines changed

7 files changed

+262
-50
lines changed

paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc

Lines changed: 29 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,7 @@ void IfOp::Build(pir::Builder &builder, // NOLINT
4848
argument.output_types.swap(output_types);
4949
argument.AddRegion().emplace_back();
5050
argument.AddRegion().emplace_back();
51+
cond.set_attribute(kStopGradientAttrName, builder.bool_attr(true));
5152
}
5253

5354
void IfOp::Build(pir::Builder &builder, // NOLINT
@@ -289,17 +290,30 @@ void WhileOp::Build(pir::Builder &builder, // NOLINT
289290
argument.AddInput(cond);
290291
argument.AddInputs(inputs);
291292
auto &body = argument.AddRegion().emplace_back();
293+
std::vector<pir::Attribute> outs_stop_gradient;
292294
for (auto val : inputs) {
293295
argument.AddOutput(val.type());
294-
body.AddArgument(val.type());
296+
auto arg = body.AddArgument(val.type());
297+
298+
auto bool_attr = val.attribute<pir::BoolAttribute>(kStopGradientAttrName);
299+
arg.set_attribute(kStopGradientAttrName,
300+
bool_attr ? bool_attr : builder.bool_attr(false));
301+
outs_stop_gradient.push_back(bool_attr ? bool_attr
302+
: builder.bool_attr(false));
295303
}
304+
305+
argument.AddAttribute(
306+
kStopGradientAttrName,
307+
pir::ArrayAttribute::get(builder.ir_context(), outs_stop_gradient));
308+
296309
cond.set_attribute(kStopGradientAttrName, builder.bool_attr(true));
297310
}
298311
pir::Block &WhileOp::body() {
299312
pir::Region &body_region = (*this)->region(0);
300313
if (body_region.empty()) body_region.emplace_back();
301314
return body_region.front();
302315
}
316+
303317
pir::Value WhileOp::cond() { return (*this)->operand_source(0); }
304318

305319
void WhileOp::Print(pir::IrPrinter &printer) {
@@ -367,6 +381,14 @@ std::vector<std::vector<pir::OpResult>> WhileOp::Vjp(
367381
"the outputs size is %d.",
368382
inputs.size(),
369383
outputs.size()));
384+
PADDLE_ENFORCE_EQ(inputs.size(),
385+
out_grads.size() + 1,
386+
phi::errors::InvalidArgument(
387+
"while op's inputs' size should equal to "
388+
"output_grads' size + 1, Now the inputs's size is %d ."
389+
"the output_grads size is %d.",
390+
inputs.size(),
391+
out_grads.size()));
370392
PADDLE_ENFORCE_EQ(stop_gradients[0][0],
371393
true,
372394
phi::errors::InvalidArgument(
@@ -377,27 +399,12 @@ std::vector<std::vector<pir::OpResult>> WhileOp::Vjp(
377399

378400
std::vector<pir::Type> output_types;
379401
std::vector<pir::Value> loop_vars;
380-
size_t index = 0;
381402

382-
for (; index < outputs.size(); ++index) {
403+
for (size_t index = 0; index < out_grads.size(); ++index) {
383404
if (!stop_gradients[index + 1][0]) {
384405
loop_vars.push_back(out_grads[index][0]);
385406
}
386407
}
387-
for (++index; index < inputs.size(); ++index) {
388-
if (!stop_gradients[index][0]) {
389-
auto fwd_type = inputs[index][0].type().dyn_cast<DenseTensorType>();
390-
PADDLE_ENFORCE_NE(
391-
fwd_type,
392-
pir::Type(),
393-
phi::errors::InvalidArgument(
394-
"The forward value type must be dense tensor type."));
395-
auto shape = vectorize(fwd_type.dims());
396-
auto dtype = TransToPhiDataType(fwd_type.dtype());
397-
auto full_op = builder.Build<FullOp>(shape, 0.0, dtype, phi::CPUPlace());
398-
loop_vars.push_back(full_op.out());
399-
}
400-
}
401408
auto while_grad = builder.Build<WhileOp>(cond_val, loop_vars);
402409

403410
std::vector<std::vector<pir::OpResult>> res(inputs.size());
@@ -426,9 +433,7 @@ std::vector<std::vector<pir::OpResult>> TuplePushOpVjpInterfaceModel::Vjp(
426433
res[0].resize(1);
427434
for (size_t i = 1u; i < inputs.size(); ++i) {
428435
res[i].resize(1);
429-
if (!stop_gradients[i][0]) {
430-
res[i][0] = pop_op.result(i - 1);
431-
}
436+
res[i][0] = pop_op.result(i - 1);
432437
}
433438
return res;
434439
}
@@ -439,6 +444,10 @@ void HasElementsOp::Build(pir::Builder &builder, // NOLINT
439444
argument.AddInput(container);
440445
argument.AddOutput(
441446
DenseTensorType::get(builder.ir_context(), builder.bool_type(), {1}));
447+
std::vector<pir::Attribute> outs_stop_gradient{builder.bool_attr(true)};
448+
argument.AddAttribute(
449+
kStopGradientAttrName,
450+
pir::ArrayAttribute::get(pir::IrContext::Instance(), outs_stop_gradient));
442451
}
443452
void HasElementsOp::VerifySig() {
444453
VLOG(4) << "Verifying inputs, outputs ,attributes for: HasElementsOp.";

paddle/fluid/pir/dialect/operator/ir/control_flow_op.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
#include <vector>
1717

1818
#include "paddle/fluid/pir/dialect/operator/interface/vjp.h"
19+
#include "paddle/pir/core/block.h"
1920
#include "paddle/pir/core/op_base.h"
2021

2122
namespace paddle {
@@ -78,6 +79,7 @@ class WhileOp : public pir::Op<WhileOp, VjpInterface> {
7879
const std::vector<pir::Value> &inputs);
7980
pir::Block &body();
8081
pir::Value cond();
82+
const pir::Block::ArgListType &block_args() { return body().args(); }
8183
void Print(pir::IrPrinter &printer); // NOLINT
8284
void VerifySig() {}
8385
void VerifyRegion() {}

paddle/fluid/pybind/control_flow_api.cc

Lines changed: 31 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@
3535

3636
namespace py = pybind11;
3737
using paddle::dialect::ApiBuilder;
38+
using paddle::dialect::HasElementsOp;
3839
using paddle::dialect::IfOp;
3940
using paddle::dialect::WhileOp;
4041
using pir::Block;
@@ -88,7 +89,10 @@ void BindWhileOp(py::module* m) {
8889
WhileOp in python api.
8990
)DOC");
9091
while_op.def("body", &WhileOp::body, return_value_policy::reference)
91-
.def("as_operation", &WhileOp::operation, return_value_policy::reference);
92+
.def("as_operation", &WhileOp::operation, return_value_policy::reference)
93+
.def("block_arguments",
94+
&WhileOp::block_args,
95+
return_value_policy::reference);
9296
}
9397

9498
void GetUsedExternalValueImpl(
@@ -126,6 +130,30 @@ std::vector<Value> GetUsedExternalValue(const Operation& op) {
126130
return used_values;
127131
}
128132

133+
Value BuildHasElementsOp(Operation& fwd_op) { // NOLINT
134+
PADDLE_ENFORCE(fwd_op.isa<WhileOp>(),
135+
phi::errors::PreconditionNotMet(
136+
"param op of BuildHasElementsOp must be while op."));
137+
auto fwdop = fwd_op.dyn_cast<WhileOp>();
138+
TuplePushOp push_op;
139+
for (auto iter = fwdop.body().rbegin(); iter != fwdop.body().rend(); ++iter) {
140+
if (iter->isa<TuplePushOp>()) {
141+
push_op = iter->dyn_cast<TuplePushOp>();
142+
PADDLE_ENFORCE_EQ(push_op.container().use_empty(),
143+
false,
144+
phi::errors::InvalidArgument(
145+
"The last container in foward while op must used "
146+
"after construct while_grad op"));
147+
break;
148+
}
149+
}
150+
auto new_cond = ApiBuilder::Instance()
151+
.GetBuilder()
152+
->Build<HasElementsOp>(push_op.container())
153+
.out();
154+
return new_cond;
155+
}
156+
129157
void BuildPipeForBlock(Block* block) {
130158
PADDLE_ENFORCE_NOT_NULL(
131159
block,
@@ -193,16 +221,17 @@ void PyIfOp::UpdateOutput() {
193221
void BindControlFlowApi(py::module* m) {
194222
m->def("get_used_external_value", GetUsedExternalValue);
195223
m->def("build_pipe_for_block", BuildPipeForBlock);
224+
m->def("cf_has_elements", BuildHasElementsOp);
196225
m->def("cf_yield", [](py::list inputs) {
197226
std::vector<Value> input_values;
198227
for (auto input : inputs) {
199228
input_values.push_back(input.cast<Value>());
200229
}
201230
ApiBuilder::Instance().GetBuilder()->Build<YieldOp>(input_values);
202231
});
203-
204232
BindIfOp(m);
205233
BindWhileOp(m);
206234
}
235+
207236
} // namespace pybind
208237
} // namespace paddle

paddle/fluid/pybind/pir.cc

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -320,12 +320,16 @@ void BindBlock(py::module *m) {
320320
"front",
321321
[](Block &self) { return &self.front(); },
322322
return_value_policy::reference)
323+
.def_property_readonly(
324+
"parent_op",
325+
[](Block &self) { return self.GetParentOp(); },
326+
return_value_policy::reference)
323327
.def_property_readonly(
324328
"program",
325329
[](Block &self) { return self.GetParentOp()->GetParentProgram(); },
326330
return_value_policy::reference)
327331
.def_property_readonly(
328-
"get_parent",
332+
"parent_block",
329333
[](Block &self) { return self.GetParentOp()->GetParent(); },
330334
return_value_policy::reference)
331335
.def_property_readonly("ops",
@@ -717,6 +721,14 @@ void BindValue(py::module *m) {
717721
kAttrIsPersisable,
718722
BoolAttribute::get(pir::IrContext::Instance(), persistable));
719723
})
724+
.def("all_used_ops",
725+
[](Value &self) -> py::list {
726+
py::list op_list;
727+
for (auto it = self.use_begin(); it != self.use_end(); ++it) {
728+
op_list.append(it.owner());
729+
}
730+
return op_list;
731+
})
720732
.def(
721733
"get_defining_op",
722734
[](Value self) -> pir::Operation * {

0 commit comments

Comments
 (0)