diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index 1560bcbde8d08d..8e56406583385c 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -238,6 +238,8 @@ class {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{ 'add_n_with_kernel', 'split_grad', 'expand', + 'increment', + 'increment_', } attr_types_map = { diff --git a/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py b/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py index 7d3476b1a65204..2a68a1ad430675 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py @@ -88,7 +88,7 @@ }""" OP_VJP_DEFINE_TEMPLATE = """ -std::vector> {op_class_name}::Vjp(pir::Operation* op, const std::vector>& inputs_, const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients){{ +std::vector> {op_class_name}::Vjp(pir::Operation* op, const std::vector>& inputs_, const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients){{ {check_param} VLOG(6) << "Prepare inputs of {op_grad_name}"; {backward_input_code} @@ -302,5 +302,5 @@ def gen_exclusive_interface_str(op_info, op_info_items): " static void InferMeta( phi::InferMetaContext *infer_meta );" ) if op_info.op_phi_name[0] not in vjp_interface_black_list: - exclusive_interface_str += "\n static std::vector> Vjp(pir::Operation* op, const std::vector>& inputs_, const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients);" + exclusive_interface_str += "\n static std::vector> Vjp(pir::Operation* op, const std::vector>& inputs_, const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients);" return exclusive_interface_str diff --git a/paddle/fluid/pir/dialect/operator/interface/vjp.cc b/paddle/fluid/pir/dialect/operator/interface/vjp.cc index 5a509b3dfc99e2..ea7854670449ea 100644 --- a/paddle/fluid/pir/dialect/operator/interface/vjp.cc +++ b/paddle/fluid/pir/dialect/operator/interface/vjp.cc @@ -14,24 +14,6 @@ #include "paddle/fluid/pir/dialect/operator/interface/vjp.h" -namespace paddle::dialect { -std::vector> VjpInterface::Vjp( - pir::Operation* op, - const std::vector>& inputs, - const std::vector>& outputs, - const std::vector>& out_grads, - const std::vector>& stop_gradients) { - std::vector> out_grads_value; - for (const auto& grad : out_grads) { - std::vector grad_value; - for (auto op_result : grad) { - grad_value.emplace_back(op_result); - } - out_grads_value.emplace_back(std::move(grad_value)); - } - return impl_->vjp_(op, inputs, outputs, out_grads_value, stop_gradients); -} - -} // namespace paddle::dialect +namespace paddle::dialect {} // namespace paddle::dialect IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::VjpInterface) diff --git a/paddle/fluid/pir/dialect/operator/interface/vjp.h b/paddle/fluid/pir/dialect/operator/interface/vjp.h index 44d1731359beb5..5246a2867665e4 100644 --- a/paddle/fluid/pir/dialect/operator/interface/vjp.h +++ b/paddle/fluid/pir/dialect/operator/interface/vjp.h @@ -23,14 +23,14 @@ class VjpInterface : public pir::OpInterfaceBase { explicit Concept(std::vector> (*vjp)( pir::Operation* op, const std::vector>& inputs, - const std::vector>& outputs, + const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients)) : vjp_(vjp) {} std::vector> (*vjp_)( pir::Operation* op, const 
std::vector>& inputs, - const std::vector>& outputs, + const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients); }; @@ -40,7 +40,7 @@ class VjpInterface : public pir::OpInterfaceBase { static std::vector> Vjp( pir::Operation* op, const std::vector>& inputs, - const std::vector>& outputs, + const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients) { return ConcreteOp::Vjp(op, inputs, outputs, out_grads, stop_gradients); @@ -56,19 +56,12 @@ class VjpInterface : public pir::OpInterfaceBase { std::vector> Vjp( pir::Operation* op, const std::vector>& inputs, - const std::vector>& outputs, + const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients) { return impl_->vjp_(op, inputs, outputs, out_grads, stop_gradients); } - std::vector> Vjp( - pir::Operation* op, - const std::vector>& inputs, - const std::vector>& outputs, - const std::vector>& out_grads, - const std::vector>& stop_gradients); - private: Concept* impl_; }; diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc index 28eeca552835eb..a898965f1f7025 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc @@ -246,7 +246,7 @@ void IfOp::VerifyRegion() { std::vector> IfOp::Vjp( pir::Operation *op, const std::vector> &inputs_, - const std::vector> &outputs, + const std::vector> &outputs, const std::vector> &out_grads, const std::vector> &stop_gradients) { PADDLE_ENFORCE_EQ( @@ -346,7 +346,7 @@ void WhileOp::Print(pir::IrPrinter &printer) { std::vector> WhileOp::Vjp( pir::Operation *op, const std::vector> &inputs, - const std::vector> &outputs, + const std::vector> &outputs, const std::vector> &out_grads, const std::vector> &stop_gradients) { auto fwd_op = WhileOp::dyn_cast(op); @@ -417,7 +417,7 @@ std::vector> WhileOp::Vjp( std::vector> TuplePushOpVjpInterfaceModel::Vjp( pir::Operation *op, const std::vector> &inputs, - const std::vector> &outputs, + const std::vector> &outputs, const std::vector> &out_grads, const std::vector> &stop_gradients) { PADDLE_ENFORCE_EQ( diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h index 4451bdfc84645d..f3b9b30c2e4002 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h @@ -52,7 +52,7 @@ class IfOp : public pir::Op { static std::vector> Vjp( pir::Operation *op, const std::vector> &inputs_, - const std::vector> &outputs, + const std::vector> &outputs, const std::vector> &out_grads, const std::vector> &stop_gradients); }; @@ -87,7 +87,7 @@ class WhileOp : public pir::Op { static std::vector> Vjp( pir::Operation *op, const std::vector> &inputs_, - const std::vector> &outputs, + const std::vector> &outputs, const std::vector> &out_grads, const std::vector> &stop_gradients); }; @@ -96,7 +96,7 @@ struct TuplePushOpVjpInterfaceModel : public VjpInterface::Concept { static std::vector> Vjp( pir::Operation *op, const std::vector> &inputs, - const std::vector> &outputs, + const std::vector> &outputs, const std::vector> &out_grads, const std::vector> &stop_gradients); diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc index 2160e56442d465..dad8c36e2f358a 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc +++ 
b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -20,7 +20,8 @@ paddle::dialect::AddNOp, paddle::dialect::AddN_Op, paddle::dialect::ArrayLengthOp, paddle::dialect::ArrayReadOp, paddle::dialect::ArrayWrite_Op, paddle::dialect::SliceArrayOp, paddle::dialect::SliceArrayDenseOp, paddle::dialect::AssignArray_Op, - paddle::dialect::ArrayToTensorOp, paddle::dialect::SelectInputOp + paddle::dialect::ArrayToTensorOp, paddle::dialect::SelectInputOp, + paddle::dialect::IncrementOp, paddle::dialect::Increment_Op #else #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" @@ -2270,6 +2271,349 @@ void SelectInputOp::VerifySig() { VLOG(4) << "End Verifying for: AssignArray_Op."; } +const char *IncrementOp::attributes_name[1] = {"value"}; + +OpInfoTuple IncrementOp::GetOpInfo() { + std::vector inputs = { + paddle::dialect::OpInputInfo( + "x", "paddle::dialect::DenseTensorType", false, false, false, false)}; + std::vector attributes = { + paddle::dialect::OpAttributeInfo("value", "pir::FloatAttribute", "")}; + std::vector outputs = { + paddle::dialect::OpOutputInfo( + "out", "paddle::dialect::DenseTensorType", false, false)}; + paddle::dialect::OpRunTimeInfo run_time_info = + paddle::dialect::OpRunTimeInfo("IncrementInferMeta", + {"x", "value"}, + "increment", + {"x", "value"}, + {}, + {}, + {}, + {}); + return std::make_tuple( + inputs, attributes, outputs, run_time_info, "increment"); +} + +void IncrementOp::Build(pir::Builder &builder, + pir::OperationArgument &argument, + pir::Value x_, + float value) { + VLOG(4) << "Start build IncrementOp"; + + VLOG(4) << "Builder construction inputs"; + std::vector argument_inputs = {x_}; + argument.AddInputs(argument_inputs); + + VLOG(4) << "Builder construction attributes"; + pir::Attribute attr_value = + pir::FloatAttribute::get(pir::IrContext::Instance(), value); + argument.AddAttribute("value", attr_value); + + VLOG(4) << "Builder construction outputs"; + paddle::dialect::DenseTensorType x = + x_.type().dyn_cast(); + (void)x; + + VLOG(4) << "Builder construction dense_x"; + paddle::dialect::IrTensor ir_tensor_x( + paddle::dialect::TransToPhiDataType(x.dtype()), + x.dims(), + x.data_layout(), + x.lod(), + x.offset()); + VLOG(4) << "Builder construction meta_x"; + paddle::dialect::IrMetaTensor meta_x(&ir_tensor_x); + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + phi::IncrementInferMeta(meta_x, value, &meta_out); + + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_out.dtype()), + dense_out.dims(), + dense_out.layout(), + dense_out.lod(), + dense_out.offset()); + argument_outputs.push_back(out_dense_tensor_type); + argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); + ::pir::PassStopGradientsDefaultly(argument); +} + +void IncrementOp::Build(pir::Builder &builder, + pir::OperationArgument &argument, + pir::Value x_, + pir::AttributeMap attributes) { + VLOG(4) << "Start build IncrementOp"; + + IR_ENFORCE(attributes.find("value") != attributes.end(), + "'value' Attribute is expected for IncrementOp. 
"); + float value = attributes.at("value").dyn_cast().data(); + + VLOG(4) << "Builder construction inputs"; + std::vector argument_inputs = {x_}; + argument.AddInputs(argument_inputs); + + VLOG(4) << "Builder construction attributes"; + pir::Attribute attr_value = + pir::FloatAttribute::get(pir::IrContext::Instance(), value); + argument.AddAttribute("value", attr_value); + + VLOG(4) << "Builder construction outputs"; + paddle::dialect::DenseTensorType x = + x_.type().dyn_cast(); + (void)x; + + VLOG(4) << "Builder construction dense_x"; + paddle::dialect::IrTensor ir_tensor_x( + paddle::dialect::TransToPhiDataType(x.dtype()), + x.dims(), + x.data_layout(), + x.lod(), + x.offset()); + VLOG(4) << "Builder construction meta_x"; + paddle::dialect::IrMetaTensor meta_x(&ir_tensor_x); + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + phi::IncrementInferMeta(meta_x, value, &meta_out); + + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_out.dtype()), + dense_out.dims(), + dense_out.layout(), + dense_out.lod(), + dense_out.offset()); + argument_outputs.push_back(out_dense_tensor_type); + argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); + ::pir::PassStopGradientsDefaultly(argument); +} + +void IncrementOp::VerifySig() { + VLOG(4) << "Start Verifying inputs, outputs and attributes for: IncrementOp."; + VLOG(4) << "Verifying inputs:"; + { + auto input_size = num_operands(); + IR_ENFORCE(input_size == 1u, + "The size %d of inputs must be equal to 1.", + input_size); + IR_ENFORCE((*this) + ->operand_source(0) + .type() + .isa(), + "Type validation failed for the 0th input, got %s.", + (*this)->operand_source(0).type()); + } + VLOG(4) << "Verifying attributes:"; + { + auto &attributes = this->attributes(); + IR_ENFORCE(attributes.count("value") > 0, "value does not exist."); + IR_ENFORCE(attributes.at("value").isa(), + "Type of attribute: value is not pir::FloatAttribute."); + } + VLOG(4) << "Verifying outputs:"; + { + auto output_size = num_results(); + IR_ENFORCE(output_size == 1u, + "The size %d of outputs must be equal to 1.", + output_size); + IR_ENFORCE( + (*this)->result(0).type().isa(), + "Type validation failed for the 0th output."); + } + VLOG(4) << "End Verifying for: IncrementOp."; +} + +void IncrementOp::InferMeta(phi::InferMetaContext *infer_meta) { + auto fn = PD_INFER_META(phi::IncrementInferMeta); + fn(infer_meta); +} + +phi::DataType IncrementOp::GetKernelTypeForVar( + const std::string &var_name, + const phi::DataType &tensor_dtype, + const phi::DataType &expected_kernel_dtype) { + VLOG(4) << "Get KernelType for Var of op: IncrementOp"; + + return expected_kernel_dtype; +} + +const char *Increment_Op::attributes_name[1] = {"value"}; + +OpInfoTuple Increment_Op::GetOpInfo() { + std::vector inputs = { + paddle::dialect::OpInputInfo( + "x", "paddle::dialect::DenseTensorType", false, false, false, false)}; + std::vector attributes = { + paddle::dialect::OpAttributeInfo("value", "pir::FloatAttribute", "")}; + std::vector outputs = { + paddle::dialect::OpOutputInfo( + "out", "paddle::dialect::DenseTensorType", false, false)}; + paddle::dialect::OpRunTimeInfo run_time_info = + paddle::dialect::OpRunTimeInfo("IncrementInferMeta", + {"x", "value"}, + "increment", + {"x", "value"}, + {}, + {}, + {{"out", "x"}}, + {}); + return std::make_tuple( + inputs, attributes, outputs, run_time_info, "increment"); 
+} + +void Increment_Op::Build(pir::Builder &builder, + pir::OperationArgument &argument, + pir::Value x_, + float value) { + VLOG(4) << "Start build Increment_Op"; + + VLOG(4) << "Builder construction inputs"; + std::vector argument_inputs = {x_}; + argument.AddInputs(argument_inputs); + + VLOG(4) << "Builder construction attributes"; + pir::Attribute attr_value = + pir::FloatAttribute::get(pir::IrContext::Instance(), value); + argument.AddAttribute("value", attr_value); + + VLOG(4) << "Builder construction outputs"; + paddle::dialect::DenseTensorType x = + x_.type().dyn_cast(); + (void)x; + + VLOG(4) << "Builder construction dense_x"; + paddle::dialect::IrTensor ir_tensor_x( + paddle::dialect::TransToPhiDataType(x.dtype()), + x.dims(), + x.data_layout(), + x.lod(), + x.offset()); + VLOG(4) << "Builder construction meta_x"; + paddle::dialect::IrMetaTensor meta_x(&ir_tensor_x); + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + phi::IncrementInferMeta(meta_x, value, &meta_out); + + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_out.dtype()), + dense_out.dims(), + dense_out.layout(), + dense_out.lod(), + dense_out.offset()); + argument_outputs.push_back(out_dense_tensor_type); + argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); + ::pir::PassStopGradientsDefaultly(argument); +} + +void Increment_Op::Build(pir::Builder &builder, + pir::OperationArgument &argument, + pir::Value x_, + pir::AttributeMap attributes) { + VLOG(4) << "Start build Increment_Op"; + + IR_ENFORCE(attributes.find("value") != attributes.end(), + "'value' Attribute is expected for Increment_Op. 
"); + float value = attributes.at("value").dyn_cast().data(); + + VLOG(4) << "Builder construction inputs"; + std::vector argument_inputs = {x_}; + argument.AddInputs(argument_inputs); + + VLOG(4) << "Builder construction attributes"; + pir::Attribute attr_value = + pir::FloatAttribute::get(pir::IrContext::Instance(), value); + argument.AddAttribute("value", attr_value); + + VLOG(4) << "Builder construction outputs"; + paddle::dialect::DenseTensorType x = + x_.type().dyn_cast(); + (void)x; + + VLOG(4) << "Builder construction dense_x"; + paddle::dialect::IrTensor ir_tensor_x( + paddle::dialect::TransToPhiDataType(x.dtype()), + x.dims(), + x.data_layout(), + x.lod(), + x.offset()); + VLOG(4) << "Builder construction meta_x"; + paddle::dialect::IrMetaTensor meta_x(&ir_tensor_x); + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + phi::IncrementInferMeta(meta_x, value, &meta_out); + + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_out.dtype()), + dense_out.dims(), + dense_out.layout(), + dense_out.lod(), + dense_out.offset()); + argument_outputs.push_back(out_dense_tensor_type); + argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); + ::pir::PassStopGradientsDefaultly(argument); +} + +void Increment_Op::VerifySig() { + VLOG(4) + << "Start Verifying inputs, outputs and attributes for: Increment_Op."; + VLOG(4) << "Verifying inputs:"; + { + auto input_size = num_operands(); + IR_ENFORCE(input_size == 1u, + "The size %d of inputs must be equal to 1.", + input_size); + IR_ENFORCE((*this) + ->operand_source(0) + .type() + .isa(), + "Type validation failed for the 0th input, got %s.", + (*this)->operand_source(0).type()); + } + VLOG(4) << "Verifying attributes:"; + { + auto &attributes = this->attributes(); + IR_ENFORCE(attributes.count("value") > 0, "value does not exist."); + IR_ENFORCE(attributes.at("value").isa(), + "Type of attribute: value is not pir::FloatAttribute."); + } + VLOG(4) << "Verifying outputs:"; + { + auto output_size = num_results(); + IR_ENFORCE(output_size == 1u, + "The size %d of outputs must be equal to 1.", + output_size); + IR_ENFORCE( + (*this)->result(0).type().isa(), + "Type validation failed for the 0th output."); + } + VLOG(4) << "End Verifying for: Increment_Op."; +} + +void Increment_Op::InferMeta(phi::InferMetaContext *infer_meta) { + auto fn = PD_INFER_META(phi::IncrementInferMeta); + fn(infer_meta); +} + +phi::DataType Increment_Op::GetKernelTypeForVar( + const std::string &var_name, + const phi::DataType &tensor_dtype, + const phi::DataType &expected_kernel_dtype) { + VLOG(4) << "Get KernelType for Var of op: Increment_Op"; + + return expected_kernel_dtype; +} + } // namespace dialect } // namespace paddle @@ -2289,4 +2633,6 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AssignArray_Op) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ArrayToTensorOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ExpandOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SelectInputOp) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::IncrementOp) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::Increment_Op) #endif diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_op.h index 460356039d84ab..1f367b4319d8c9 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.h @@ -55,7 +55,7 @@ class AddNOp 
: public pir::Op> Vjp( pir::Operation *op, const std::vector> &inputs_, - const std::vector> &outputs, + const std::vector> &outputs, const std::vector> &out_grads, const std::vector> &stop_gradients); static std::vector> Decomp(pir::Operation *op); @@ -396,7 +396,7 @@ class ExpandOp : public pir::Op> Vjp( pir::Operation *op, const std::vector> &inputs_, - const std::vector> &outputs, + const std::vector> &outputs, const std::vector> &out_grads, const std::vector> &stop_gradients); }; @@ -412,6 +412,89 @@ class SelectInputOp : public pir::Op { pir::OpResult out() { return result(0); } }; +class IncrementOp + : public pir::Op { + public: + using Op::Op; + static const char *name() { return "pd_op.increment"; } + static const char *attributes_name[1]; + static constexpr uint32_t attributes_num = 1; + static OpInfoTuple GetOpInfo(); + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value x_, // NOLINT + float value = 1.0); + + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value x_, // NOLINT + pir::AttributeMap attributes); + + void VerifySig(); + + static phi::DataType GetKernelTypeForVar( + const std::string &var_name, + const phi::DataType &tensor_dtype, + const phi::DataType &expected_kernel_dtype); + + pir::Value x() { return operand_source(0); } + pir::OpResult out() { return result(0); } + + static void InferMeta(phi::InferMetaContext *infer_meta); + static std::vector> Vjp( + pir::Operation *op, + const std::vector> &inputs_, + const std::vector> &outputs, + const std::vector> &out_grads, + const std::vector> &stop_gradients); +}; + +class Increment_Op + : public pir::Op { + public: + using Op::Op; + static const char *name() { return "pd_op.increment_"; } + static const char *attributes_name[1]; + static constexpr uint32_t attributes_num = 1; + static OpInfoTuple GetOpInfo(); + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value x_, // NOLINT + float value = 1.0); + + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value x_, // NOLINT + pir::AttributeMap attributes); + + void VerifySig(); + + static phi::DataType GetKernelTypeForVar( + const std::string &var_name, + const phi::DataType &tensor_dtype, + const phi::DataType &expected_kernel_dtype); + + pir::Value x() { return operand_source(0); } + pir::OpResult out() { return result(0); } + + static void InferMeta(phi::InferMetaContext *infer_meta); + static std::vector> Vjp( + pir::Operation *op, + const std::vector> &inputs_, + const std::vector> &outputs, + const std::vector> &out_grads, + const std::vector> &stop_gradients); +}; + } // namespace dialect } // namespace paddle @@ -431,3 +514,5 @@ IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AssignArray_Op) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ArrayToTensorOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ExpandOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SelectInputOp) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::IncrementOp) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::Increment_Op) diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc index a37cbd681d185d..f35ab01117d2a3 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc @@ -13,6 +13,7 @@ // limitations under the License. 
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_api.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/primitive/rule/vjp/vjp.h" #include "paddle/fluid/primitive/type/lazy_tensor.h" @@ -30,7 +31,7 @@ using IntArray = paddle::experimental::IntArray; std::vector> AddNOp::Vjp( pir::Operation* op, const std::vector>& inputs_, - const std::vector>& outputs, + const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients) { VLOG(6) << "Prepare inputs of add_n_grad"; @@ -82,7 +83,7 @@ std::vector> AddNOp::Vjp( std::vector> ExpandOp::Vjp( pir::Operation* op, const std::vector>& inputs_, - const std::vector>& outputs, + const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients) { PADDLE_ENFORCE_EQ(inputs_.size(), @@ -127,5 +128,69 @@ std::vector> ExpandOp::Vjp( return res; } +std::vector> IncrementOp::Vjp( + pir::Operation* op, + const std::vector>& inputs_, + const std::vector>& outputs, + const std::vector>& out_grads, + const std::vector>& stop_gradients) { + PADDLE_ENFORCE_EQ( + inputs_.size(), + 1, + platform::errors::InvalidArgument( + "Increment op's inputs size should be 2, but now is %d.", + inputs_.size())); + PADDLE_ENFORCE_EQ( + outputs.size(), + 1, + platform::errors::InvalidArgument( + "Increment op's outputs size should be 1, but now is %d.", + outputs.size())); + + VLOG(6) << "Vjp prepare Prepare attributes of increment_grad"; + + float value = op->attribute("value").dyn_cast().data(); + + VLOG(6) << "Vjp prepare call increment's vjp inteface"; + + pir::OpResult tensor_res = paddle::dialect::increment(inputs_[0][0], -value); + + std::vector> res{{tensor_res}}; + + return res; +} + +std::vector> Increment_Op::Vjp( + pir::Operation* op, + const std::vector>& inputs_, + const std::vector>& outputs, + const std::vector>& out_grads, + const std::vector>& stop_gradients) { + PADDLE_ENFORCE_EQ( + inputs_.size(), + 1, + platform::errors::InvalidArgument( + "Increment_ op's inputs size should be 2, but now is %d.", + inputs_.size())); + PADDLE_ENFORCE_EQ( + outputs.size(), + 1, + platform::errors::InvalidArgument( + "Increment_ op's outputs size should be 1, but now is %d.", + outputs.size())); + + VLOG(6) << "Vjp prepare Prepare attributes of increment__grad"; + + float value = op->attribute("value").dyn_cast().data(); + + VLOG(6) << "Vjp prepare call increment_'s vjp inteface"; + + pir::OpResult tensor_res = paddle::dialect::increment_(inputs_[0][0], -value); + + std::vector> res{{tensor_res}}; + + return res; +} + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 2221fba3754e0b..efeeb4855205e2 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -702,8 +702,8 @@ void BindVjp(pybind11::module *m) { "call_vjp", [](pir::Operation &fwd_op, const std::vector> &inputs, - const std::vector> &outputs, - const std::vector> &out_grads, + const std::vector> &outputs, + const std::vector> &out_grads, const std::vector> &stop_gradients) { py::list res; paddle::dialect::VjpInterface vjp_interface = diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py index 0961cc18a398f7..f804b3cdc91716 100644 --- a/python/paddle/autograd/ir_backward.py +++ b/python/paddle/autograd/ir_backward.py @@ -265,7 +265,6 @@ def update_no_grad_set_after_prune( from inputs to outputs add value not in the path to no_grad_set, 
from outputs to inputs add value not in the path to no_grad_set, ''' - inputs_set = ValueSet(inputs) for input in inputs: if not input.use_empty(): @@ -328,7 +327,6 @@ def inverse_sort_op(ops): while queue: op = queue.popleft() sorted_list.append(op) - for x in get_real_op_inputs(op): x_op = x.get_defining_op() pending_count[x_op] -= 1 @@ -339,6 +337,28 @@ def inverse_sort_op(ops): raise ValueError( "inverse_sort_op wrong, sorted_list size is not equal to origin_list size" ) + change_list = [] + for op in reversed(sorted_list): + if op.name() == 'pd_op.increment_': + idx_1 = sorted_list.index(op) + idx_2 = sorted_list.index(op) + + for op_in in reversed(sorted_list[: sorted_list.index(op)]): + if ( + some_in_set( + op.operands_source(), + ValueSet(get_real_op_inputs(op_in)), + ) + and op_in.name() != "cf.tuple_push" + ): + idx_2 = sorted_list.index(op_in) + if idx_1 != idx_2: + change_list.append((idx_1, idx_2)) + for idx_1, idx_2 in change_list: + sorted_list[idx_1], sorted_list[idx_2] = ( + sorted_list[idx_2], + sorted_list[idx_1], + ) return sorted_list @@ -353,6 +373,7 @@ def append_backward_ops( no_grad_set, backward_ops, state, + bwd_block_argument_to_value_map, ): ''' add grad_op in order of topological inverse sort @@ -394,6 +415,14 @@ def append_backward_ops( else continue to next op. ''' + def return_value_to_copyvalue_map( + value, control_flow_value_to_copyvalue_map + ): + output = value + while output in control_flow_value_to_copyvalue_map: + output = control_flow_value_to_copyvalue_map[output] + return output + def append_add_n(value): # value is input of more than one fwd_op, # so more than one bwd_op create input_grad, @@ -416,11 +445,11 @@ def make_output_with_output_grad(op): outputs = [] output_grads = [] for i, value in enumerate(op.results()): - new_value = ( - [control_flow_value_to_copyvalue_map[value]] - if value in control_flow_value_to_copyvalue_map - else [value] - ) + new_value = [ + return_value_to_copyvalue_map( + value, control_flow_value_to_copyvalue_map + ) + ] while value in state.inside_value_to_outside_value_map: value = state.inside_value_to_outside_value_map[value] @@ -465,7 +494,12 @@ def make_output_with_output_grad(op): zero_flag[i] = True outputs.append(new_value) - output_grads.append(state.value_to_valuegrad[value][0]) + grad_value = state.value_to_valuegrad[value][0] + output_grads.append( + bwd_block_argument_to_value_map[grad_value[0]] + if grad_value[0] in bwd_block_argument_to_value_map + else grad_value + ) if op.name() == "pd_op.while": for i, input in enumerate(get_real_op_inputs(op)): @@ -482,7 +516,13 @@ def make_output_with_output_grad(op): or state.value_to_valuegrad[input] == [] ): append_full_like(0.0, input, input, state, backward_ops) - output_grads.append(state.value_to_valuegrad[input][0]) + + grad_value = state.value_to_valuegrad[input][0] + output_grads.append( + bwd_block_argument_to_value_map[grad_value[0]] + if grad_value[0] in bwd_block_argument_to_value_map + else grad_value + ) return zero_flag, outputs, output_grads def get_grad_semantic_info(op): @@ -491,6 +531,8 @@ def get_grad_semantic_info(op): "pd_op.if", "pd_op.while", "cf.tuple_push", + "pd_op.increment_", + "pd_op.increment", ]: grad_semantic_info = [ True for _ in range(len(get_real_op_inputs(op))) @@ -513,18 +555,18 @@ def make_input_with_input_stopgradient(op): tmp_input = [] for tmp in input.get_defining_op().operands_source(): tmp_input.append( - control_flow_value_to_copyvalue_map[tmp] - if tmp in control_flow_value_to_copyvalue_map - else tmp + 
return_value_to_copyvalue_map( + tmp, control_flow_value_to_copyvalue_map + ) ) inputs.append(tmp_input) else: - tmp_input = ( - [control_flow_value_to_copyvalue_map[input]] - if input in control_flow_value_to_copyvalue_map - else [input] - ) + tmp_input = [ + return_value_to_copyvalue_map( + input, control_flow_value_to_copyvalue_map + ) + ] inputs.append(tmp_input) continue @@ -541,11 +583,11 @@ def make_input_with_input_stopgradient(op): [info[0] for info in combine_stop_gradient] ) else: - tmp_input = ( - [control_flow_value_to_copyvalue_map[input]] - if input in control_flow_value_to_copyvalue_map - else [input] - ) + tmp_input = [ + return_value_to_copyvalue_map( + input, control_flow_value_to_copyvalue_map + ) + ] inputs.append(tmp_input) if input in no_grad_set or input.stop_gradient is True: @@ -573,8 +615,8 @@ def update_input_grad_map(op, input_grads, origin_inputs): ) else: input_grad = input_grads[i] - if input in block_argument_to_value_map: - input = block_argument_to_value_map[input] + if input in fwd_block_argument_to_value_map: + input = fwd_block_argument_to_value_map[input] if isinstance(input_grad, list): state.value_to_valuegrad[input].append(input_grad) @@ -618,12 +660,12 @@ def append_yield( 0.0, value, value, state, backward_ops ) - if base_op.name() == "pd_op.while": - input_grad = paddle.add( - output_grad, state.value_to_valuegrad[value][0][0] - ) - else: - input_grad = state.value_to_valuegrad[value][0][0] + # if base_op.name() == "pd_op.while": + # input_grad = paddle.add( + # output_grad, state.value_to_valuegrad[value][0][0] + # ) + # else: + input_grad = state.value_to_valuegrad[value][0][0] inputs_grad.append(input_grad) @@ -633,13 +675,15 @@ def argument_to_value(while_op): assert len(while_op.as_while_op().block_arguments()) + 1 == len( while_op.operands_source() ), "while op's block_arguments size + 1 should same to whiel op's operands_source" - map = ValueDict() + arg_to_value_map = ValueDict() + value_to_arg_map = ValueDict() for arg, value in zip( while_op.as_while_op().block_arguments(), while_op.operands_source()[1:], ): - map[arg] = value - return map + arg_to_value_map[arg] = value + value_to_arg_map[value] = [arg] + return arg_to_value_map, value_to_arg_map # there are four patterns: # [builtin.combine , op1] (op1's one input is vectorType, outputs are not vectorType) @@ -651,7 +695,9 @@ def argument_to_value(while_op): # tuple_push value to pop value control_flow_value_to_copyvalue_map = ValueDict() control_flow_copyvalue_to_value_map = ValueDict() - block_argument_to_value_map = ValueDict() + # fwd_whileop's blockargument to fwd_whileop's input value + fwd_block_argument_to_value_map = ValueDict() + # bwd_whileop's input value to bwd_whileop's blockargument if ( len(effective_forward_ops) > 1 @@ -662,7 +708,7 @@ def argument_to_value(while_op): # while op yield [cond, loop_vars], # but outputs only has loop_vars. 
inside_outputs = yield_op.operands_source()[1:] - block_argument_to_value_map = argument_to_value(base_op) + fwd_block_argument_to_value_map, _ = argument_to_value(base_op) else: inside_outputs = yield_op.operands_source() @@ -695,7 +741,7 @@ def argument_to_value(while_op): input_grad_stopgradients, ) = make_input_with_input_stopgradient(op) - if op.name() == "cf.tuple_push": + if op.name() in ["cf.tuple_push", "pd_op.increment_"]: with dynamic_shape_prim_vjp_guard(op, inputs): copy_out = paddle.framework.core.call_vjp( op, @@ -704,9 +750,20 @@ def argument_to_value(while_op): output_grads, input_grad_stopgradients, ) + pop_op = bwd_block.ops[-1] bwd_ops = [pop_op] - for output, copy_output in zip(inputs[1:], copy_out[1:]): + tmp_inputs = ( + inputs + if op.name() == "pd_op.increment_" + else inputs[1:] + ) + tmp_copy_out = ( + copy_out + if op.name() == "pd_op.increment_" + else copy_out[1:] + ) + for output, copy_output in zip(tmp_inputs, tmp_copy_out): control_flow_value_to_copyvalue_map[ output[0] ] = copy_output[0] @@ -719,7 +776,7 @@ def argument_to_value(while_op): if len(output_grads) == 0 or all(zero_flag): continue - if op.name() == "pd_op.if" or op.name() == "pd_op.while": + if op.name() in ["pd_op.if", "pd_op.while"]: origin_inputs = get_used_external_value(op) for sub_block in op.blocks(): build_pipe_for_block(sub_block) @@ -737,6 +794,16 @@ def argument_to_value(while_op): for sub_fwd_block, sub_bwd_block in zip( op.blocks(), grad_op.blocks() ): + # update grad_op structure + if grad_op.name() == "pd_op.while": + ( + _, + sub_bwd_block_argument_to_value_map, + ) = argument_to_value(grad_op) + else: + sub_bwd_block_argument_to_value_map = ( + ValueDict() + ) sub_state = state.copy(sub_fwd_block) sub_backward_ops = [] append_backward_ops( @@ -749,6 +816,7 @@ def argument_to_value(while_op): no_grad_set, sub_backward_ops, sub_state, + sub_bwd_block_argument_to_value_map, ) # update input_grad map update_input_grad_map(op, input_grads, origin_inputs) @@ -908,17 +976,18 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set): no_grad_set, backward_ops, state, + ValueDict(), ) # now value_to_valuegrad should be value <-> value (add sum op for the same values's gradvalue) outputs_set, inputs_set, no_gradvar_set = create_backward_prune_set( outputs_fwd_set, inputs_fwd_set, no_grad_set, state ) + _, remove_ops = prune_ops( backward_ops, inputs_set, outputs_set, no_gradvar_set ) state.turn_map() - for bwd_op in inverse_sort_op(remove_ops): if bwd_op.result(0) in ValueSet(grad_outputs): continue diff --git a/test/cpp/prim/test_vjp.cc b/test/cpp/prim/test_vjp.cc index 2c3ba7073e6065..f9393bf6b9f548 100644 --- a/test/cpp/prim/test_vjp.cc +++ b/test/cpp/prim/test_vjp.cc @@ -70,7 +70,7 @@ TEST(VJP, TanhBackwardTest) { std::vector> stop_gradients{{false}}; std::vector> inputs{{op1.out()}}; - std::vector> outputs{{op2.out()}}; + std::vector> outputs{{op2.out()}}; std::vector> out_grads{{op3.out()}}; pir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd_op.tanh"); @@ -125,7 +125,7 @@ TEST(VJP, Tanh_BackwardTest) { std::vector> stop_gradients{{false}}; std::vector> inputs{{op1.out()}}; - std::vector> outputs{{op2.out()}}; + std::vector> outputs{{op2.out()}}; std::vector> out_grads{{op3.out()}}; pir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd_op.tanh_"); @@ -181,7 +181,7 @@ TEST(VJP, MeanBackwardTest) { std::vector> stop_gradients{{false}}; std::vector> inputs{{op1.out()}}; - std::vector> outputs{{op2.out()}}; + std::vector> outputs{{op2.out()}}; std::vector> 
out_grads{{op3.out()}}; pir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd_op.mean"); @@ -241,7 +241,7 @@ TEST(VJP, ConcatBackwardTest) { std::vector> stop_gradients{{false, false}}; std::vector> inputs{{op1.out(), op1.out()}, {op3.axis()}}; - std::vector> outputs{{op3.out()}}; + std::vector> outputs{{op3.out()}}; std::vector> out_grads{{op4.out()}}; pir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd_op.concat"); auto concat_vjp_interface_impl = @@ -308,7 +308,7 @@ TEST(VJP, AddBackwardTest) { std::vector> stop_gradients{{false}, {false}}; std::vector> inputs{{op1.out()}, {op2.out()}}; - std::vector> outputs{{op3.out()}}; + std::vector> outputs{{op3.out()}}; std::vector> out_grads{{op4.out()}}; pir::OpInfo op3_info = ctx->GetRegisteredOpInfo("pd_op.add"); @@ -373,7 +373,7 @@ TEST(VJP, Add_BackwardTest) { std::vector> stop_gradients{{false}, {false}}; std::vector> inputs{{op1.out()}, {op2.out()}}; - std::vector> outputs{{op3.out()}}; + std::vector> outputs{{op3.out()}}; std::vector> out_grads{{op4.out()}}; pir::OpInfo op3_info = ctx->GetRegisteredOpInfo("pd_op.add_"); @@ -441,7 +441,12 @@ TEST(VJP, SplitBackwardTest) { std::vector> stop_gradients{{false}}; std::vector> inputs{ {op2.x()}, {op2.sections()}, {op2.axis()}}; - std::vector> outputs{{op3.outputs()}}; + std::vector> outputs(1); + std::vector res; + for (uint32_t i = 0; i < op3.outputs().size(); i++) { + res.push_back(op3.outputs()[i]); + } + outputs[0] = res; std::vector> out_grads{{op3.result(0), op4.out()}}; pir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd_op.split"); diff --git a/test/ir/pir/test_while_api.py b/test/ir/pir/test_while_api.py index c067b4d174f14d..9165cec5ac077e 100644 --- a/test/ir/pir/test_while_api.py +++ b/test/ir/pir/test_while_api.py @@ -186,15 +186,6 @@ def test_add_n_program(self): .body() .ops[-2] .name(), - "pd_op.add", - ) - self.assertEqual( - main_program.global_block() - .ops[-1] - .as_while_op() - .body() - .ops[-4] - .name(), "cf.has_elements", ) diff --git a/test/legacy_test/test_while_loop_op.py b/test/legacy_test/test_while_loop_op.py index 0b73041559c0fd..4feddf5a7c2df6 100644 --- a/test/legacy_test/test_while_loop_op.py +++ b/test/legacy_test/test_while_loop_op.py @@ -22,7 +22,7 @@ from paddle import base from paddle.base import core from paddle.base.backward import append_backward -from paddle.base.framework import Program, program_guard +from paddle.base.framework import program_guard from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -116,8 +116,8 @@ def body(i, ten, test_dict, test_list, test_list_dict): i = paddle.increment(i) return [i, ten, test_dict, test_list, test_list_dict] - main_program = Program() - startup_program = Program() + main_program = paddle.static.Program() + startup_program = paddle.static.Program() with program_guard(main_program, startup_program): i = paddle.zeros(shape=[1], dtype='int64') ten = paddle.tensor.fill_constant( @@ -255,6 +255,7 @@ def internal_body(j, init, sums): class TestApiWhileLoop_Backward(unittest.TestCase): # TODO(zhangbo): Support while grad exe for pir + # @test_with_pir_api def test_while_loop_backward(self): def cond(i, x): return paddle.less_than(i, eleven) @@ -264,11 +265,12 @@ def body(i, x): i = paddle.increment(i) return [i, x] - main_program = Program() - startup_program = Program() - with base.program_guard(main_program, startup_program): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): i = 
paddle.static.data(name='i', shape=[1], dtype='float32') i.stop_gradient = False + i.persistable = True eleven = paddle.tensor.fill_constant( shape=[1], dtype='float32', value=11 ) @@ -277,10 +279,11 @@ def body(i, x): ) x = paddle.static.data(name='x', shape=[1], dtype='float32') x.stop_gradient = False + x.persistable = True out = paddle.static.nn.while_loop(cond, body, [i, x]) mean = paddle.mean(out[1]) - append_backward(mean) + grad_list = append_backward(mean) place = ( base.CUDAPlace(0) @@ -294,17 +297,31 @@ def body(i, x): data = np.asarray([100]).astype('float32') i_grad = np.asarray([110]).astype('float32') - res = exe.run( - main_program, - feed={'i': feed_i, 'x': feed_x}, - fetch_list=[mean.name, i.grad_name], - ) + if paddle.framework.in_pir_mode(): + for p, g in grad_list: + if p == i: + di = g + res = exe.run( + main_program, + feed={'i': feed_i, 'x': feed_x}, + fetch_list=[mean, di], + ) + else: + res = exe.run( + main_program, + feed={'i': feed_i, 'x': feed_x}, + fetch_list=[mean.name, i.grad_name, x.grad_name], + ) np.testing.assert_allclose(np.asarray(res[0]), data, rtol=1e-05) np.testing.assert_allclose(np.asarray(res[1]), i_grad, rtol=1e-05) # TODO(zhangbo): Support while grad exe for pir + # @test_with_pir_api def test_while_loop_backward2(self): - def cond(i, x): + def cond1(i, x): + return i < 2 + + def cond2(i, x): return i < 3 def body(i, x): @@ -312,17 +329,19 @@ def body(i, x): i = i + 1 return [i, x] - main_program = Program() - startup_program = Program() - with base.program_guard(main_program, startup_program): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): i = paddle.static.data(name='i', shape=[1], dtype='float32') i.stop_gradient = False + i.persistable = True x = paddle.static.data(name='x', shape=[1], dtype='float32') x.stop_gradient = False + x.persistable = True - out = paddle.static.nn.while_loop(cond, body, [i, x]) + out = paddle.static.nn.while_loop(cond1, body, [i, x]) mean = paddle.mean(out[1]) - append_backward(mean) + grad_list = append_backward(mean) place = ( base.CUDAPlace(0) @@ -334,21 +353,37 @@ def body(i, x): feed_i = np.ones(1).astype('float32') feed_x = np.ones(1).astype('float32') data = np.asarray([2]).astype('float32') + ans = np.asarray([1]).astype('float32') + x1_grad = np.asarray([1]).astype('float32') i_grad = np.asarray([3]).astype('float32') x_grad = np.asarray([2]).astype('float32') - res = exe.run( - main_program, - feed={'i': feed_i, 'x': feed_x}, - fetch_list=[mean.name, i.grad_name, x.grad_name], - ) - np.testing.assert_allclose(np.asarray(res[0]), data, rtol=1e-05) - np.testing.assert_allclose(np.asarray(res[1]), i_grad, rtol=1e-05) - np.testing.assert_allclose(np.asarray(res[2]), x_grad, rtol=1e-05) + if paddle.framework.in_pir_mode(): + for p, g in grad_list: + if p == i: + di = g + if p == x: + dx = g + res = exe.run( + main_program, + feed={'i': feed_i, 'x': feed_x}, + fetch_list=[out[1], di, dx], + ) + else: + res = exe.run( + main_program, + feed={'i': feed_i, 'x': feed_x}, + fetch_list=[out[1].name, i.grad_name, x.grad_name], + ) + + np.testing.assert_allclose(np.asarray(res[0]), ans, rtol=1e-05) + np.testing.assert_allclose(np.asarray(res[1]), ans, rtol=1e-05) + np.testing.assert_allclose(np.asarray(res[2]), ans, rtol=1e-05) class TestApiWhileLoop_NestedWithBackwardAndLoDTensorArray(unittest.TestCase): # TODO(zhangbo): Support while grad exe for pir + def test_nested_net_with_backward_and_lodtensor(self): def 
external_cond(i, j, x, mem_array): return paddle.less_than(i, array_len) @@ -377,9 +412,9 @@ def internal_body(j, x, mem_array): ) return [i, j, x, mem_array] - main_program = Program() - startup_program = Program() - with base.program_guard(main_program, startup_program): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): d0 = paddle.static.data(name='d0', shape=[10], dtype='float32') d1 = paddle.static.data(name='d1', shape=[10], dtype='float32') d2 = paddle.static.data(name='d2', shape=[10], dtype='float32') @@ -461,9 +496,9 @@ def fn_add_one(): default=fn_add_one, ) - main_program = Program() - startup_program = Program() - with base.program_guard(main_program, startup_program): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): i = paddle.tensor.fill_constant(shape=[1], dtype='int64', value=1) ten = paddle.tensor.fill_constant( shape=[1], dtype='int64', value=10 @@ -532,8 +567,8 @@ def body_returns_with_mutable_list(i, test_list): ) return paddle.increment(i), test_list - main_program = Program() - startup_program = Program() + main_program = paddle.static.Program() + startup_program = paddle.static.Program() with program_guard(main_program, startup_program): data = paddle.tensor.fill_constant( shape=[1], dtype='int64', value=1 @@ -666,8 +701,8 @@ def body(z, i): i += 1 return z, i - main_program = Program() - startup_program = Program() + main_program = paddle.static.Program() + startup_program = paddle.static.Program() with program_guard(main_program, startup_program): x = paddle.static.data(name='x', shape=[-1, 5], dtype='int32') z = paddle.tensor.fill_constant([], 'int32', 0)
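
A note on the diff above: the angle-bracketed template arguments did not survive extraction, so declarations such as `const std::vector>& outputs` originally carried nested template parameters (roughly `const std::vector<std::vector<...>>&`), and many `-`/`+` pairs that now look identical differ only in those stripped arguments. Read from the surrounding code, the central interface change appears to be that the `outputs` argument of each `Vjp` entry point moves from an `OpResult`-based container to a `Value`-based one, which is what lets the converting `VjpInterface::Vjp` overload in `vjp.cc` be deleted. The sketch below illustrates that reading; it is a minimal, self-contained reconstruction, not the real Paddle code: `sketch::Value`, `sketch::OpResult`, `sketch::Operation`, `ValueGrid`, and `ToValues` are hypothetical stand-ins, and the exact pre-change parameter and return types are an assumption that the garbled diff does not confirm.

#include <utility>
#include <vector>

namespace sketch {

// Stand-ins for pir::Value / pir::OpResult / pir::Operation. Only the
// container shapes and the "an OpResult is usable as a Value" relationship
// are modelled; the real pir types are far richer.
struct Value {};
struct OpResult : Value {};
struct Operation {};

using ValueGrid = std::vector<std::vector<Value>>;
using BoolGrid = std::vector<std::vector<bool>>;

// Post-change shape of the Vjp entry point as reconstructed here: every
// nested container is Value-based. The body is a trivial stub so the sketch
// compiles; the real interface dispatches to each op's generated Vjp.
ValueGrid Vjp(Operation* /*op*/,
              const ValueGrid& /*inputs*/,
              const ValueGrid& /*outputs*/,  // assumed: previously OpResult-based
              const ValueGrid& out_grads,
              const BoolGrid& /*stop_gradients*/) {
  return out_grads;
}

// The overload removed from vjp.cc did essentially this element-wise
// OpResult -> Value copy before forwarding. Once the parameters are typed on
// Value, the copy is unnecessary, because OpResult already converts to Value.
ValueGrid ToValues(const std::vector<std::vector<OpResult>>& out_grads) {
  ValueGrid out_grads_value;
  for (const auto& grad : out_grads) {
    std::vector<Value> grad_value;
    for (const auto& op_result : grad) {
      grad_value.emplace_back(op_result);
    }
    out_grads_value.emplace_back(std::move(grad_value));
  }
  return out_grads_value;
}

}  // namespace sketch

int main() {
  sketch::Operation op;
  std::vector<std::vector<sketch::OpResult>> grads_as_results{
      {sketch::OpResult{}}};
  // Old path (as reconstructed): convert OpResult grads to Value grads first.
  sketch::ValueGrid converted = sketch::ToValues(grads_as_results);
  // New path: OpResult converts to Value, so the grid can be built directly.
  sketch::ValueGrid direct{{sketch::OpResult{}}};
  sketch::Vjp(&op, direct, direct, converted, {{false}});
  return 0;
}

If that reading is right, the design choice is straightforward: since `pir::OpResult` is usable wherever a `pir::Value` is expected, typing the `Vjp` parameters on `Value` lets callers that hold `OpResult`s (as `pybind.cc` and the `test_vjp.cc` cases above do) pass them through directly, removing the per-element copy the deleted overload existed to perform.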