diff --git a/paddle/fluid/pir/dialect/op_generator/api_gen.py b/paddle/fluid/pir/dialect/op_generator/api_gen.py
index 39324fe9b1a99f..719ff6957162dc 100644
--- a/paddle/fluid/pir/dialect/op_generator/api_gen.py
+++ b/paddle/fluid/pir/dialect/op_generator/api_gen.py
@@ -27,6 +27,7 @@
 PD_MANUAL_API_LIST = {
     'embedding_grad',
+    'assign',
 }
 
 H_FILE_TEMPLATE = """
@@ -241,7 +242,7 @@ def _parse_yaml(self, op_yaml_files, op_compat_yaml_file):
     def _need_skip(self, op_info, op_name):
         return (
             op_info.infer_meta_func is None and op_name not in PD_MANUAL_OP_LIST
-        ) or op_name in PD_MANUAL_API_LIST
+        )
 
     def _is_optional_input(self, op_info, input_name):
         name_list = op_info.input_name_list
@@ -377,7 +378,10 @@ def _gen_h_file(self, op_info_items, namespaces, h_file_path):
             for op_name in op_info.op_phi_name:
                 # NOTE:When infer_meta_func is None, the Build() function generated in pd_op
                 # is wrong, so temporarily skip the automatic generation of these APIs
-                if self._need_skip(op_info, op_name):
+                if (
+                    self._need_skip(op_info, op_name)
+                    or op_name in PD_MANUAL_API_LIST
+                ):
                     continue
                 declare_str += self._gen_one_declare(
                     op_info, op_name, False, False
@@ -828,7 +832,10 @@ def _gen_cpp_file(self, op_info_items, namespaces, cpp_file_path):
             for op_name in op_info.op_phi_name:
                 # NOTE:When infer_meta_func is None, the Build() function generated in pd_op
                 # is wrong, so temporarily skip the automatic generation of these APIs
-                if self._need_skip(op_info, op_name):
+                if (
+                    self._need_skip(op_info, op_name)
+                    or op_name in PD_MANUAL_API_LIST
+                ):
                     continue
                 impl_str += self._gen_one_impl(op_info, op_name, False, False)
             if len(op_info.mutable_attribute_name_list) > 0:
diff --git a/paddle/fluid/pir/dialect/op_generator/python_c_gen.py b/paddle/fluid/pir/dialect/op_generator/python_c_gen.py
index 5db36e3fb06e76..d1284f0c9866dc 100644
--- a/paddle/fluid/pir/dialect/op_generator/python_c_gen.py
+++ b/paddle/fluid/pir/dialect/op_generator/python_c_gen.py
@@ -441,6 +441,7 @@ def _gen_one_impl(self, op_info, op_name):
     def _need_skip(self, op_info, op_name):
         return (
             super()._need_skip(op_info, op_name)
+            or op_name.endswith(('_grad', '_grad_', 'xpu'))
             or op_name in MANUAL_STATIC_OP_FUNCTION_LIST
         )
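With this split, `PD_MANUAL_API_LIST` now only suppresses the auto-generated C++ API, while the base `_need_skip` that `python_c_gen.py` inherits keeps those ops visible, so the Python-C binding for a manual API such as `assign` is still emitted. A minimal sketch of the resulting control flow, with illustrative names and values only (the real generators pass `op_info` objects rather than raw strings):

```python
# Illustrative sketch of the generator skip logic after this change.
PD_MANUAL_OP_LIST = {'add_n'}  # ops with a hand-written Build()
PD_MANUAL_API_LIST = {'embedding_grad', 'assign'}  # hand-written C++ APIs

def need_skip(infer_meta_func, op_name):
    # Base rule: skip ops whose auto-generated Build() would be wrong.
    return infer_meta_func is None and op_name not in PD_MANUAL_OP_LIST

def gen_cpp_api(infer_meta_func, op_name):
    # The C++ API generator additionally skips hand-written APIs.
    return not (
        need_skip(infer_meta_func, op_name)
        or op_name in PD_MANUAL_API_LIST
    )

# `assign` gets no auto-generated C++ API, but is no longer skipped by the
# base rule, so the Python-C binding for it is still generated.
assert gen_cpp_api('UnchangedInferMeta', 'assign') is False
assert need_skip('UnchangedInferMeta', 'assign') is False
```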
diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_api.cc b/paddle/fluid/pir/dialect/operator/ir/manual_api.cc
index 138cba40dadd7e..4d7c86f1eb8d5d 100644
--- a/paddle/fluid/pir/dialect/operator/ir/manual_api.cc
+++ b/paddle/fluid/pir/dialect/operator/ir/manual_api.cc
@@ -200,5 +200,25 @@ pir::OpResult slice_array_dense(pir::Value input, pir::Value starts) {
   return op.result(0);
 }
 
+pir::OpResult assign(const pir::Value& x) {
+  CheckValueDataType(x, "x", "assign");
+  if (x.type().isa<paddle::dialect::DenseTensorType>()) {
+    paddle::dialect::AssignOp assign_op =
+        ApiBuilder::Instance().GetBuilder()->Build<paddle::dialect::AssignOp>(
+            x);
+    return assign_op.result(0);
+  } else if (x.type().isa<paddle::dialect::DenseTensorArrayType>()) {
+    paddle::dialect::AssignArrayOp assign_array_op =
+        ApiBuilder::Instance()
+            .GetBuilder()
+            ->Build<paddle::dialect::AssignArrayOp>(x);
+    return assign_array_op.result(0);
+  } else {
+    PADDLE_THROW(phi::errors::Unimplemented(
+        "Currently, assign only supports DenseTensorType and "
+        "DenseTensorArrayType."));
+  }
+}
+
 }  // namespace dialect
 }  // namespace paddle
diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_api.h b/paddle/fluid/pir/dialect/operator/ir/manual_api.h
index a476290dcbc3d7..044ed5097a32f2 100644
--- a/paddle/fluid/pir/dialect/operator/ir/manual_api.h
+++ b/paddle/fluid/pir/dialect/operator/ir/manual_api.h
@@ -83,5 +83,7 @@ pir::OpResult add_n_array(const std::vector<pir::Value>& inputs);
 
 pir::OpResult slice_array_dense(pir::Value input, pir::Value starts);
 
+pir::OpResult assign(const pir::Value& x);
+
 }  // namespace dialect
 }  // namespace paddle
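The manual `assign` dispatches on the input's IR type, so one user-facing call covers both dense tensors and tensor arrays. A rough usage sketch from the Python side, assuming a PIR-enabled build (e.g. with `FLAGS_enable_pir_api=1`):

```python
import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data(name='x', shape=[2, 3], dtype='float32')
    y = paddle.assign(x)  # DenseTensorType -> builds pd_op.assign

    i = paddle.tensor.fill_constant(shape=[1], dtype='int64', value=0)
    arr = paddle.tensor.array_write(x=x, i=i)
    # DenseTensorArrayType -> builds the new pd_op.assign_array
    arr_copy = paddle.assign(arr)
```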
diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc
index 5c7387dc22e6c0..f17695c035845f 100644
--- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc
+++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc
@@ -21,10 +21,11 @@
     paddle::dialect::AddNOp, paddle::dialect::AddN_Op,
     paddle::dialect::CreateArrayLikeOp, paddle::dialect::ArrayLengthOp,
     paddle::dialect::ArrayReadOp, paddle::dialect::ArrayWrite_Op,
     paddle::dialect::SliceArrayOp, paddle::dialect::SliceArrayDenseOp,
-    paddle::dialect::AssignArray_Op, paddle::dialect::ArrayToTensorOp,
-    paddle::dialect::TensorToArrayOp, paddle::dialect::SelectInputOp,
-    paddle::dialect::IncrementOp, paddle::dialect::Increment_Op,
-    paddle::dialect::ShapeBroadcastOp, paddle::dialect::MemcpyD2hMultiIoOp
+    paddle::dialect::AssignArrayOp, paddle::dialect::AssignArray_Op,
+    paddle::dialect::ArrayToTensorOp, paddle::dialect::TensorToArrayOp,
+    paddle::dialect::SelectInputOp, paddle::dialect::IncrementOp,
+    paddle::dialect::Increment_Op, paddle::dialect::ShapeBroadcastOp,
+    paddle::dialect::MemcpyD2hMultiIoOp
 #else
 
 #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
@@ -3687,6 +3688,170 @@ phi::DataType SliceArrayDenseOp::GetKernelTypeForVar(
   return expected_kernel_dtype;
 }
 
+OpInfoTuple AssignArrayOp::GetOpInfo() {
+  std::vector<paddle::dialect::OpInputInfo> inputs = {
+      paddle::dialect::OpInputInfo("x",
+                                   "paddle::dialect::DenseTensorArrayType",
+                                   false,
+                                   false,
+                                   false,
+                                   true)};
+  std::vector<paddle::dialect::OpAttributeInfo> attributes = {};
+  std::vector<paddle::dialect::OpOutputInfo> outputs = {
+      paddle::dialect::OpOutputInfo(
+          "out", "paddle::dialect::DenseTensorArrayType", false, false)};
+  paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo(
+      "UnchangedArrayInferMeta", {"x"}, "assign_array", {"x"}, {}, {}, {}, {});
+  return std::make_tuple(
+      inputs, attributes, outputs, run_time_info, "assign_array");
+}
+
+void AssignArrayOp::Build(pir::Builder &builder,
+                          pir::OperationArgument &argument,
+                          pir::Value x_) {
+  VLOG(4) << "Start build AssignArrayOp";
+
+  VLOG(4) << "Builder construction inputs";
+  std::vector<pir::Value> argument_inputs = {x_};
+  argument.AddInputs(argument_inputs);
+
+  VLOG(4) << "Builder construction attributes";
+
+  VLOG(4) << "Builder construction outputs";
+  paddle::dialect::DenseTensorArrayType x_type;
+  if (x_.type().isa<paddle::dialect::DenseTensorArrayType>()) {
+    x_type = x_.type().dyn_cast<paddle::dialect::DenseTensorArrayType>();
+  } else if (x_.type().isa<paddle::dialect::AllocatedDenseTensorArrayType>()) {
+    paddle::dialect::AllocatedDenseTensorArrayType allocated_input =
+        x_.type().dyn_cast<paddle::dialect::AllocatedDenseTensorArrayType>();
+    x_type = paddle::dialect::DenseTensorArrayType::get(
+        pir::IrContext::Instance(),
+        allocated_input.dtype(),
+        allocated_input.data_layout());
+  } else {
+    PADDLE_THROW(phi::errors::Unimplemented(
+        "Only support paddle::dialect::DenseTensorArrayType or "
+        "paddle::dialect::AllocatedDenseTensorArrayType"));
+  }
+
+  VLOG(4) << "Builder construction dense_tensor_array_x";
+  paddle::dialect::IrTensor ir_tensor_x(
+      paddle::dialect::TransToPhiDataType(x_type.dtype()),
+      {},
+      x_type.data_layout(),
+      {});
+  VLOG(4) << "Builder construction meta_x";
+  paddle::dialect::IrMetaTensor meta_x(&ir_tensor_x);
+  paddle::dialect::IrTensor dense_out;
+  paddle::dialect::IrMetaTensor meta_out(&dense_out);
+
+  phi::UnchangedArrayInferMeta(meta_x, &meta_out);
+
+  std::vector<pir::Type> argument_outputs;
+  pir::Type out_dense_tensor_array_type =
+      paddle::dialect::DenseTensorArrayType::get(
+          pir::IrContext::Instance(),
+          TransToIrDataType(dense_out.dtype()),
+          dense_out.layout());
+  argument_outputs.push_back(out_dense_tensor_array_type);
+  argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
+  ::pir::PassStopGradientsDefaultly(argument);
+}
+
+void AssignArrayOp::VerifySig() {
+  VLOG(4)
+      << "Start Verifying inputs, outputs and attributes for: AssignArrayOp.";
+  VLOG(4) << "Verifying inputs:";
+  {
+    auto input_size = num_operands();
+    IR_ENFORCE(input_size == 1u,
+               "The size %d of inputs must be equal to 1.",
+               input_size);
+    IR_ENFORCE((*this)
+                   ->operand_source(0)
+                   .type()
+                   .isa<paddle::dialect::DenseTensorArrayType>(),
+               "Type validation failed for the 0th input, got %s.",
+               (*this)->operand_source(0).type());
+  }
+  VLOG(4) << "Verifying attributes:";
+  {
+    // Attributes num is 0, not need to check attributes type.
+  }
+  VLOG(4) << "Verifying outputs:";
+  {
+    auto output_size = num_results();
+    IR_ENFORCE(output_size == 1u,
+               "The size %d of outputs must be equal to 1.",
+               output_size);
+    IR_ENFORCE(
+        (*this)->result(0).type().isa<paddle::dialect::DenseTensorArrayType>(),
+        "Type validation failed for the 0th output.");
+  }
+  VLOG(4) << "End Verifying for: AssignArrayOp.";
+}
+
+void AssignArrayOp::InferMeta(phi::InferMetaContext *infer_meta) {
+  auto fn = PD_INFER_META(phi::UnchangedArrayInferMeta);
+  fn(infer_meta);
+}
+
+phi::DataType AssignArrayOp::GetKernelTypeForVar(
+    const std::string &var_name,
+    const phi::DataType &tensor_dtype,
+    const phi::DataType &expected_kernel_dtype) {
+  VLOG(4) << "Get KernelType for Var of op: AssignArrayOp";
+
+  return expected_kernel_dtype;
+}
+
+std::vector<pir::Type> AssignArrayOp::InferMeta(
+    const std::vector<pir::Value> &input_values,
+    const pir::AttributeMap &attributes) {
+  VLOG(4) << "Start infermeta AssignArrayOp";
+  IR_ENFORCE(input_values.size() == 1,
+             "Num of inputs is expected to be 1 but got %d.",
+             input_values.size());
+  pir::Value x_ = input_values[0];
+
+  VLOG(4) << "Builder construction outputs";
+  paddle::dialect::DenseTensorArrayType x_type;
+  if (x_.type().isa<paddle::dialect::DenseTensorArrayType>()) {
+    x_type = x_.type().dyn_cast<paddle::dialect::DenseTensorArrayType>();
+  } else if (x_.type().isa<paddle::dialect::AllocatedDenseTensorArrayType>()) {
+    paddle::dialect::AllocatedDenseTensorArrayType allocated_input =
+        x_.type().dyn_cast<paddle::dialect::AllocatedDenseTensorArrayType>();
+    x_type = paddle::dialect::DenseTensorArrayType::get(
+        pir::IrContext::Instance(),
+        allocated_input.dtype(),
+        allocated_input.data_layout());
+  } else {
+    PADDLE_THROW(phi::errors::Unimplemented(
+        "Only support paddle::dialect::DenseTensorArrayType or "
+        "paddle::dialect::AllocatedDenseTensorArrayType"));
+  }
+  paddle::dialect::IrTensor dense_input(
+      paddle::dialect::TransToPhiDataType(x_type.dtype()),
+      {},
+      x_type.data_layout(),
+      {});
+  paddle::dialect::IrMetaTensor meta_input(&dense_input);
+
+  paddle::dialect::IrTensor dense_out;
+  paddle::dialect::IrMetaTensor meta_out(&dense_out);
+
+  phi::UnchangedArrayInferMeta(meta_input, &meta_out);
+
+  std::vector<pir::Type> argument_outputs;
+  pir::Type out_dense_tensor_array_type =
+      paddle::dialect::DenseTensorArrayType::get(
+          pir::IrContext::Instance(),
+          TransToIrDataType(dense_out.dtype()),
+          dense_out.layout());
+  argument_outputs.push_back(out_dense_tensor_array_type);
+  return argument_outputs;
+}
+
 OpInfoTuple AssignArray_Op::GetOpInfo() {
   std::vector<paddle::dialect::OpInputInfo> inputs = {
       paddle::dialect::OpInputInfo("x",
@@ -4953,7 +5118,6 @@ void MemcpyD2hMultiIoOp::VerifySig() {
     IR_ENFORCE(input_size == 1u,
                "The size %d of inputs must be equal to 1.",
                input_size);
-
     IR_ENFORCE((*this)
                    ->operand_source(0)
                    .type()
@@ -5065,6 +5229,7 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ArrayReadOp)
 IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ArrayWrite_Op)
 IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SliceArrayOp)
 IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SliceArrayDenseOp)
+IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AssignArrayOp)
 IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AssignArray_Op)
 IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ArrayToTensorOp)
 IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::TensorToArrayOp)
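`AssignArrayOp` wires its output meta through `phi::UnchangedArrayInferMeta`: the copied array keeps the input's dtype and layout. In spirit (a Python paraphrase of the pass-through behaviour, not the real PHI API):

```python
from dataclasses import dataclass

@dataclass
class ArrayMeta:
    dtype: str
    layout: str

def unchanged_array_infer_meta(x: ArrayMeta) -> ArrayMeta:
    # assign_array copies the array, so the meta passes through untouched.
    return ArrayMeta(dtype=x.dtype, layout=x.layout)

assert unchanged_array_infer_meta(ArrayMeta("float32", "NCHW")) == ArrayMeta(
    "float32", "NCHW"
)
```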
diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_op.h
index 60766e45842cdb..d385ce92cacfe2 100644
--- a/paddle/fluid/pir/dialect/operator/ir/manual_op.h
+++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.h
@@ -468,6 +468,37 @@ class SliceArrayDenseOp
                                          const pir::AttributeMap &attributes);
 };
 
+class AssignArrayOp
+    : public pir::Op<AssignArrayOp,
+                     paddle::dialect::OpYamlInfoInterface,
+                     paddle::dialect::InferMetaInterface,
+                     paddle::dialect::GetKernelTypeForVarInterface> {
+ public:
+  using Op::Op;
+  static const char *name() { return "pd_op.assign_array"; }
+  static constexpr const char **attributes_name = nullptr;
+  static constexpr uint32_t attributes_num = 0;
+  static OpInfoTuple GetOpInfo();
+  static void Build(pir::Builder &builder,             // NOLINT
+                    pir::OperationArgument &argument,  // NOLINT
+                    pir::Value x_);
+
+  void VerifySig();
+
+  static phi::DataType GetKernelTypeForVar(
+      const std::string &var_name,
+      const phi::DataType &tensor_dtype,
+      const phi::DataType &expected_kernel_dtype);
+
+  pir::Value x() { return operand_source(0); }
+  pir::OpResult out() { return result(0); }
+
+  static void InferMeta(phi::InferMetaContext *infer_meta);
+  static std::vector<pir::Type> InferMeta(
+      const std::vector<pir::Value> &input_values,
+      const pir::AttributeMap &attributes);
+};
+
 class AssignArray_Op
     : public pir::Op<AssignArray_Op,
diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py
--- a/python/paddle/tensor/creation.py
+++ b/python/paddle/tensor/creation.py
@@ ... @@ def assign(x, output=None):
         if len(input) > 0 and any(
-            isinstance(x, (Variable, paddle.pir.OpResult)) for x in input
+            isinstance(x, (Variable, paddle.pir.Value)) for x in input
         ):
             # We only deal with the case where the list is nested one level, convert all scalars into variables, and then use stack to process. It is necessary to ensure the consistency of types.
             if not all(
                 x.shape == (1,)
                 for x in input
                 if isinstance(
-                    x, (Variable, core.eager.Tensor, paddle.pir.OpResult)
+                    x, (Variable, core.eager.Tensor, paddle.pir.Value)
                 )
             ):
                 raise TypeError(
@@ -2392,7 +2392,7 @@ def assign(x, output=None):
 
         def convert_scalar(x):
             if not isinstance(
-                x, (Variable, core.eager.Tensor, paddle.pir.OpResult)
+                x, (Variable, core.eager.Tensor, paddle.pir.Value)
             ):
                 return assign(x)
             return x
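These `isinstance` updates move the list/scalar handling in `paddle.assign` from the legacy `OpResult` to `pir.Value`. For reference, a dygraph-flavoured sketch of the input kinds this path accepts (values are arbitrary):

```python
import numpy as np
import paddle

t = paddle.assign(np.random.random((4, 3)).astype('float32'))  # ndarray
u = paddle.assign(t)           # Tensor -> Tensor copy
v = paddle.assign([1.0, 2.0])  # plain Python scalars are converted first
```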
diff --git a/test/legacy_test/test_assign_op.py b/test/legacy_test/test_assign_op.py
index 8b2bdac20499a5..3bff65836286af 100644
--- a/test/legacy_test/test_assign_op.py
+++ b/test/legacy_test/test_assign_op.py
@@ -24,6 +24,7 @@
 from paddle import base
 from paddle.base import Program, core, program_guard
 from paddle.base.backward import append_backward
+from paddle.framework import in_pir_mode
 from paddle.pir_utils import test_with_pir_api
 
 
@@ -113,12 +114,13 @@ def test_backward(self):
         paddle.disable_static()
 
 
-class TestAssignOpWithLoDTensorArray(unittest.TestCase):
-    def test_assign_LoDTensorArray(self):
+class TestAssignOpWithTensorArray(unittest.TestCase):
+    @test_with_pir_api
+    def test_assign_tensor_array(self):
         paddle.enable_static()
-        main_program = Program()
-        startup_program = Program()
-        with program_guard(main_program):
+        main_program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+        with paddle.static.program_guard(main_program, startup_program):
             x = paddle.static.data(name='x', shape=[100, 10], dtype='float32')
             x.stop_gradient = False
             y = paddle.tensor.fill_constant(
@@ -127,25 +129,41 @@ def test_assign_LoDTensorArray(self):
                 shape=[100, 10], dtype='float32', value=1
             )
             z = paddle.add(x=x, y=y)
             i = paddle.tensor.fill_constant(shape=[1], dtype='int64', value=0)
             init_array = paddle.tensor.array_write(x=z, i=i)
+            # TODO(xiaoguoguo626807): Remove this stop_gradient=False.
+            init_array.stop_gradient = False
             array = paddle.assign(init_array)
             sums = paddle.tensor.array_read(array=init_array, i=i)
             mean = paddle.mean(sums)
             append_backward(mean)
 
             place = (
-                base.CUDAPlace(0)
-                if core.is_compiled_with_cuda()
-                else base.CPUPlace()
+                paddle.CUDAPlace(0)
+                if paddle.is_compiled_with_cuda()
+                else paddle.CPUPlace()
             )
-            exe = base.Executor(place)
+            exe = paddle.static.Executor(place)
             feed_x = np.random.random(size=(100, 10)).astype('float32')
             ones = np.ones((100, 10)).astype('float32')
             feed_add = feed_x + ones
-            res = exe.run(
-                main_program,
-                feed={'x': feed_x},
-                fetch_list=[sums.name, x.grad_name],
-            )
+            if in_pir_mode():
+                x_grad = None
+                for op in main_program.global_block().ops:
+                    if "grad" not in op.name():
+                        continue
+                    if op.operands()[0].source().is_same(x):
+                        x_grad = op.results()[0]
+                assert x_grad is not None, "Can not find x_grad in main_program"
+                res = exe.run(
+                    main_program,
+                    feed={'x': feed_x},
+                    fetch_list=[sums, x_grad],
+                )
+            else:
+                res = exe.run(
+                    main_program,
+                    feed={'x': feed_x},
+                    fetch_list=[sums.name, x.grad_name],
+                )
             np.testing.assert_allclose(res[0], feed_add, rtol=1e-05)
             np.testing.assert_allclose(res[1], ones / 1000.0, rtol=1e-05)
             paddle.disable_static()
@@ -166,44 +184,8 @@ def test_errors(self):
         paddle.disable_static()
 
 
-class TestAssignOApi(unittest.TestCase):
-    def test_assign_LoDTensorArray(self):
-        paddle.enable_static()
-        main_program = Program()
-        startup_program = Program()
-        with program_guard(main_program):
-            x = paddle.static.data(name='x', shape=[100, 10], dtype='float32')
-            x.stop_gradient = False
-            y = paddle.tensor.fill_constant(
-                shape=[100, 10], dtype='float32', value=1
-            )
-            z = paddle.add(x=x, y=y)
-            i = paddle.tensor.fill_constant(shape=[1], dtype='int64', value=0)
-            init_array = paddle.tensor.array_write(x=z, i=i)
-            array = paddle.assign(init_array)
-            sums = paddle.tensor.array_read(array=init_array, i=i)
-            mean = paddle.mean(sums)
-            append_backward(mean)
-
-            place = (
-                base.CUDAPlace(0)
-                if core.is_compiled_with_cuda()
-                else base.CPUPlace()
-            )
-            exe = base.Executor(place)
-            feed_x = np.random.random(size=(100, 10)).astype('float32')
-            ones = np.ones((100, 10)).astype('float32')
-            feed_add = feed_x + ones
-            res = exe.run(
-                main_program,
-                feed={'x': feed_x},
-                fetch_list=[sums.name, x.grad_name],
-            )
-            np.testing.assert_allclose(res[0], feed_add, rtol=1e-05)
-            np.testing.assert_allclose(res[1], ones / 1000.0, rtol=1e-05)
-        paddle.disable_static()
-
-    def test_assign_NumpyArray(self):
+class TestAssignOpApi(unittest.TestCase):
+    def test_assign_numpy_array(self):
         for dtype in [np.bool_, np.float32, np.int32, np.int64]:
             with base.dygraph.guard():
                 array = np.random.random(size=(100, 10)).astype(dtype)
@@ -259,7 +241,7 @@ def test_clone(self):
 @unittest.skipIf(
     not paddle.is_compiled_with_cuda(), "FP16 test runs only on GPU"
 )
-class TestAssignOApiFP16(unittest.TestCase):
+class TestAssignOpApiFP16(unittest.TestCase):
     def test_assign_fp16(self):
         x = np.random.uniform(0, 10, [3, 3]).astype(np.float16)
         x = paddle.to_tensor(x)
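The PIR branch in `test_assign_tensor_array` locates `x`'s gradient by scanning the grad ops, since `x.grad_name` does not exist for a `pir.Value`. A reusable form of that lookup (a hypothetical helper mirroring the test logic one-to-one):

```python
def find_grad_value(program, fwd_value):
    """Return the gradient Value produced for `fwd_value`, or None.

    Mirrors the PIR branch above: walk all ops whose name contains "grad"
    and match the one whose first operand is the forward value.
    """
    for op in program.global_block().ops:
        if "grad" not in op.name():
            continue
        if op.operands()[0].source().is_same(fwd_value):
            return op.results()[0]
    return None
```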