Skip to content
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
001d799
optimize backward
xiaoguoguo626807 Dec 8, 2023
05ca298
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Dec 11, 2023
4fd113e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Dec 12, 2023
8f60538
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Dec 13, 2023
8854896
[PIR] add vjp interface for while op
winter-wang Dec 12, 2023
7e177f6
[PIR] fix ci error.
winter-wang Dec 13, 2023
11c8656
modify while stopgradient
xiaoguoguo626807 Dec 14, 2023
d8c3936
merge
xiaoguoguo626807 Dec 14, 2023
da62e16
merge
xiaoguoguo626807 Dec 15, 2023
67ed811
merge
xiaoguoguo626807 Dec 15, 2023
30bba32
modify while grad bug
xiaoguoguo626807 Dec 18, 2023
53f2920
merge
xiaoguoguo626807 Dec 18, 2023
fde161c
modify while grad op
xiaoguoguo626807 Dec 18, 2023
fdc12c7
modify
xiaoguoguo626807 Dec 18, 2023
e3d19b9
increment vp
xiaoguoguo626807 Dec 19, 2023
600d99c
merge
xiaoguoguo626807 Dec 20, 2023
63344b7
while case
xiaoguoguo626807 Dec 20, 2023
59ad2fc
delete print
xiaoguoguo626807 Dec 20, 2023
f4eceb6
delete print
xiaoguoguo626807 Dec 20, 2023
1c9eb96
Update python/paddle/autograd/ir_backward.py
xiaoguoguo626807 Dec 20, 2023
4beaa79
Merge branch 'develop' into while_2
xiaoguoguo626807 Dec 20, 2023
65083df
modify while_loop
xiaoguoguo626807 Dec 21, 2023
f2f4fa0
Merge branch 'while_2' of https://github.com/xiaoguoguo626807/Paddle …
xiaoguoguo626807 Dec 21, 2023
f8e3ac4
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Dec 21, 2023
95bc3d7
code_style
xiaoguoguo626807 Dec 21, 2023
37e807c
modify ci bug
xiaoguoguo626807 Dec 21, 2023
b7a003f
modify ci bug
xiaoguoguo626807 Dec 21, 2023
048a942
delete print
xiaoguoguo626807 Dec 21, 2023
5d98f91
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Dec 21, 2023
9bd0512
modify ci bug
xiaoguoguo626807 Dec 22, 2023
bfbfa45
modify
xiaoguoguo626807 Dec 22, 2023
53ec0db
Update python/paddle/autograd/ir_backward.py
xiaoguoguo626807 Dec 22, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/op_generator/op_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,8 @@ class {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{
'add_n_with_kernel',
'split_grad',
'expand',
'increment',
'increment_',
}

attr_types_map = {
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/pir/dialect/op_generator/op_interface_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
}"""

OP_VJP_DEFINE_TEMPLATE = """
std::vector<std::vector<pir::OpResult>> {op_class_name}::Vjp(pir::Operation* op, const std::vector<std::vector<pir::Value>>& inputs_, const std::vector<std::vector<pir::OpResult>>& outputs, const std::vector<std::vector<pir::Value>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients){{
std::vector<std::vector<pir::OpResult>> {op_class_name}::Vjp(pir::Operation* op, const std::vector<std::vector<pir::Value>>& inputs_, const std::vector<std::vector<pir::Value>>& outputs, const std::vector<std::vector<pir::Value>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients){{
{check_param}
VLOG(6) << "Prepare inputs of {op_grad_name}";
{backward_input_code}
Expand Down Expand Up @@ -302,5 +302,5 @@ def gen_exclusive_interface_str(op_info, op_info_items):
" static void InferMeta( phi::InferMetaContext *infer_meta );"
)
if op_info.op_phi_name[0] not in vjp_interface_black_list:
exclusive_interface_str += "\n static std::vector<std::vector<pir::OpResult>> Vjp(pir::Operation* op, const std::vector<std::vector<pir::Value>>& inputs_, const std::vector<std::vector<pir::OpResult>>& outputs, const std::vector<std::vector<pir::Value>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients);"
exclusive_interface_str += "\n static std::vector<std::vector<pir::OpResult>> Vjp(pir::Operation* op, const std::vector<std::vector<pir::Value>>& inputs_, const std::vector<std::vector<pir::Value>>& outputs, const std::vector<std::vector<pir::Value>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients);"
return exclusive_interface_str
20 changes: 1 addition & 19 deletions paddle/fluid/pir/dialect/operator/interface/vjp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,6 @@

#include "paddle/fluid/pir/dialect/operator/interface/vjp.h"

namespace paddle::dialect {
std::vector<std::vector<pir::OpResult>> VjpInterface::Vjp(
pir::Operation* op,
const std::vector<std::vector<pir::Value>>& inputs,
const std::vector<std::vector<pir::OpResult>>& outputs,
const std::vector<std::vector<pir::OpResult>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients) {
std::vector<std::vector<pir::Value>> out_grads_value;
for (const auto& grad : out_grads) {
std::vector<pir::Value> grad_value;
for (auto op_result : grad) {
grad_value.emplace_back(op_result);
}
out_grads_value.emplace_back(std::move(grad_value));
}
return impl_->vjp_(op, inputs, outputs, out_grads_value, stop_gradients);
}

} // namespace paddle::dialect
namespace paddle::dialect {} // namespace paddle::dialect

IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::VjpInterface)
15 changes: 4 additions & 11 deletions paddle/fluid/pir/dialect/operator/interface/vjp.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ class VjpInterface : public pir::OpInterfaceBase<VjpInterface> {
explicit Concept(std::vector<std::vector<pir::OpResult>> (*vjp)(
pir::Operation* op,
const std::vector<std::vector<pir::Value>>& inputs,
const std::vector<std::vector<pir::OpResult>>& outputs,
const std::vector<std::vector<pir::Value>>& outputs,
const std::vector<std::vector<pir::Value>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients))
: vjp_(vjp) {}
std::vector<std::vector<pir::OpResult>> (*vjp_)(
pir::Operation* op,
const std::vector<std::vector<pir::Value>>& inputs,
const std::vector<std::vector<pir::OpResult>>& outputs,
const std::vector<std::vector<pir::Value>>& outputs,
const std::vector<std::vector<pir::Value>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients);
};
Expand All @@ -40,7 +40,7 @@ class VjpInterface : public pir::OpInterfaceBase<VjpInterface> {
static std::vector<std::vector<pir::OpResult>> Vjp(
pir::Operation* op,
const std::vector<std::vector<pir::Value>>& inputs,
const std::vector<std::vector<pir::OpResult>>& outputs,
const std::vector<std::vector<pir::Value>>& outputs,
const std::vector<std::vector<pir::Value>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients) {
return ConcreteOp::Vjp(op, inputs, outputs, out_grads, stop_gradients);
Expand All @@ -56,19 +56,12 @@ class VjpInterface : public pir::OpInterfaceBase<VjpInterface> {
std::vector<std::vector<pir::OpResult>> Vjp(
pir::Operation* op,
const std::vector<std::vector<pir::Value>>& inputs,
const std::vector<std::vector<pir::OpResult>>& outputs,
const std::vector<std::vector<pir::Value>>& outputs,
const std::vector<std::vector<pir::Value>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients) {
return impl_->vjp_(op, inputs, outputs, out_grads, stop_gradients);
}

std::vector<std::vector<pir::OpResult>> Vjp(
pir::Operation* op,
const std::vector<std::vector<pir::Value>>& inputs,
const std::vector<std::vector<pir::OpResult>>& outputs,
const std::vector<std::vector<pir::OpResult>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients);

private:
Concept* impl_;
};
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ void IfOp::VerifyRegion() {
std::vector<std::vector<pir::OpResult>> IfOp::Vjp(
pir::Operation *op,
const std::vector<std::vector<pir::Value>> &inputs_,
const std::vector<std::vector<pir::OpResult>> &outputs,
const std::vector<std::vector<pir::Value>> &outputs,
const std::vector<std::vector<pir::Value>> &out_grads,
const std::vector<std::vector<bool>> &stop_gradients) {
PADDLE_ENFORCE_EQ(
Expand Down Expand Up @@ -345,7 +345,7 @@ void WhileOp::Print(pir::IrPrinter &printer) {
std::vector<std::vector<pir::OpResult>> WhileOp::Vjp(
pir::Operation *op,
const std::vector<std::vector<pir::Value>> &inputs,
const std::vector<std::vector<pir::OpResult>> &outputs,
const std::vector<std::vector<pir::Value>> &outputs,
const std::vector<std::vector<pir::Value>> &out_grads,
const std::vector<std::vector<bool>> &stop_gradients) {
auto fwd_op = WhileOp::dyn_cast(op);
Expand Down Expand Up @@ -416,7 +416,7 @@ std::vector<std::vector<pir::OpResult>> WhileOp::Vjp(
std::vector<std::vector<pir::OpResult>> TuplePushOpVjpInterfaceModel::Vjp(
pir::Operation *op,
const std::vector<std::vector<pir::Value>> &inputs,
const std::vector<std::vector<pir::OpResult>> &outputs,
const std::vector<std::vector<pir::Value>> &outputs,
const std::vector<std::vector<pir::Value>> &out_grads,
const std::vector<std::vector<bool>> &stop_gradients) {
PADDLE_ENFORCE_EQ(
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/pir/dialect/operator/ir/control_flow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class IfOp : public pir::Op<IfOp, VjpInterface> {
static std::vector<std::vector<pir::OpResult>> Vjp(
pir::Operation *op,
const std::vector<std::vector<pir::Value>> &inputs_,
const std::vector<std::vector<pir::OpResult>> &outputs,
const std::vector<std::vector<pir::Value>> &outputs,
const std::vector<std::vector<pir::Value>> &out_grads,
const std::vector<std::vector<bool>> &stop_gradients);
};
Expand Down Expand Up @@ -86,7 +86,7 @@ class WhileOp : public pir::Op<WhileOp, VjpInterface> {
static std::vector<std::vector<pir::OpResult>> Vjp(
pir::Operation *op,
const std::vector<std::vector<pir::Value>> &inputs_,
const std::vector<std::vector<pir::OpResult>> &outputs,
const std::vector<std::vector<pir::Value>> &outputs,
const std::vector<std::vector<pir::Value>> &out_grads,
const std::vector<std::vector<bool>> &stop_gradients);
};
Expand All @@ -95,7 +95,7 @@ struct TuplePushOpVjpInterfaceModel : public VjpInterface::Concept {
static std::vector<std::vector<pir::OpResult>> Vjp(
pir::Operation *op,
const std::vector<std::vector<pir::Value>> &inputs,
const std::vector<std::vector<pir::OpResult>> &outputs,
const std::vector<std::vector<pir::Value>> &outputs,
const std::vector<std::vector<pir::Value>> &out_grads,
const std::vector<std::vector<bool>> &stop_gradients);

Expand Down
Loading