diff --git a/paddle/fluid/pir/dialect/CMakeLists.txt b/paddle/fluid/pir/dialect/CMakeLists.txt
index b95a2b17a3696c..78968aaee3e95c 100644
--- a/paddle/fluid/pir/dialect/CMakeLists.txt
+++ b/paddle/fluid/pir/dialect/CMakeLists.txt
@@ -314,8 +314,10 @@ cc_library(
 
 #Note(risemeup1):compile some *.cc files which depend on primitive_vjp_experimental into op_dialect_vjp.a/lib
 set(op_decomp_source_file ${PD_DIALECT_SOURCE_DIR}/op_decomp.cc)
+# set(op_decomp_vjp_source_file ${PD_DIALECT_SOURCE_DIR}/op_decomp_vjp.cc)
 set(op_dialect_vjp_srcs
     ${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/manual_op_decomp.cc
+    ${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/manual_op_decomp_vjp.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/manual_op_vjp.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/op_dialect.cc
     ${op_decomp_source_file}
diff --git a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py
index 21bfd1bd3a0ed7..70e1171b069255 100644
--- a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py
+++ b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py
@@ -97,6 +97,11 @@
     "unsqueeze",
 ]
 
-
 # xshape output will no longer used after decomp, but return none to keep output num the same as origin op
 decomp_ops_contain_unused_output = ["squeeze", "unsqueeze"]
+
+decomp_vjp_interface_declare_gen_op_list = [
+    "add_grad",
+    "matmul_grad",
+    "relu_grad",
+]
diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py
index a792a920328f18..897cf162622186 100644
--- a/paddle/fluid/pir/dialect/op_generator/op_gen.py
+++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py
@@ -20,7 +20,10 @@
 from distutils.util import strtobool
 
 import yaml
-from decomp_interface_gen_op_list import decomp_interface_declare_gen_op_list
+from decomp_interface_gen_op_list import (
+    decomp_interface_declare_gen_op_list,
+    decomp_vjp_interface_declare_gen_op_list,
+)
 from gen_utils import to_pascal_case
 from infer_symbolic_shape_gen import gen_infer_symbolic_shape_str
 from op_all_func_gen import gen_op_all_func
@@ -85,6 +88,7 @@
 #include <vector>
 
 #include "paddle/fluid/pir/dialect/operator/interface/decomp.h"
+#include "paddle/fluid/pir/dialect/operator/interface/decomp_vjp.h"
 #include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_symbolic_shape.h"
 #include "paddle/fluid/pir/dialect/operator/interface/infermeta.h"
 #include "paddle/fluid/pir/dialect/operator/interface/layout_transformation.h"
@@ -1348,6 +1352,8 @@ def AutoCodeGen(
         exclusive_interface_str_tmp = exclusive_interface_str
         decomp_interface_str = "paddle::dialect::DecompInterface"
         decomp_interface_declare_str = "\n  static std::vector<std::vector<pir::Value>> Decomp(pir::Operation* op);"
+        decomp_vjp_interface_str = "paddle::dialect::DecompVjpInterface"
+        decomp_vjp_interface_declare_str = "\n  static std::vector<std::vector<pir::Value>> DecompVjp(pir::Operation* op);"
 
         # If op has inplace info, we will generate inplace op and non-inplace op.
         for op_name in op_info.op_phi_name:
@@ -1392,6 +1398,23 @@ def AutoCodeGen(
                     not in exclusive_interface_str
                 ):
                     exclusive_interface_str += decomp_interface_declare_str
+            elif (
+                op_name in decomp_vjp_interface_declare_gen_op_list
+                and kernel_func_name
+                in decomp_vjp_interface_declare_gen_op_list
+                and dialect_name != "onednn_op"
+            ):
+                if decomp_vjp_interface_str not in op_interfaces:
+                    op_interfaces = op_interfaces + [
+                        decomp_vjp_interface_str
+                    ]
+                if (
+                    decomp_vjp_interface_declare_str
+                    not in exclusive_interface_str
+                ):
+                    exclusive_interface_str += (
+                        decomp_vjp_interface_declare_str
+                    )
             else:
                 op_interfaces = op_interfaces_tmp
                 exclusive_interface_str = exclusive_interface_str_tmp
diff --git a/paddle/fluid/pir/dialect/operator/interface/decomp_vjp.h b/paddle/fluid/pir/dialect/operator/interface/decomp_vjp.h
new file mode 100644
index 00000000000000..bc8c8254df6b5a
--- /dev/null
+++ b/paddle/fluid/pir/dialect/operator/interface/decomp_vjp.h
@@ -0,0 +1,52 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include "paddle/pir/include/core/op_base.h"
+
+namespace paddle {
+namespace dialect {
+class DecompVjpInterface : public pir::OpInterfaceBase<DecompVjpInterface> {
+ public:
+  struct Concept {
+    explicit Concept(
+        std::vector<std::vector<pir::Value>> (*decomp)(pir::Operation* op))
+        : decomp_(decomp) {}
+    std::vector<std::vector<pir::Value>> (*decomp_)(pir::Operation* op);
+  };
+
+  template <class ConcreteOp>
+  struct Model : public Concept {
+    static std::vector<std::vector<pir::Value>> DecompVjp(pir::Operation* op) {
+      return ConcreteOp::DecompVjp(op);
+    }
+    Model() : Concept(DecompVjp) {}
+  };
+
+  /// Constructor
+  DecompVjpInterface(pir::Operation* op, Concept* impl)
+      : pir::OpInterfaceBase<DecompVjpInterface>(op), impl_(impl) {}
+
+  std::vector<std::vector<pir::Value>> DecompVjp(pir::Operation* op) {
+    return impl_->decomp_(op);
+  }
+
+ private:
+  Concept* impl_;
+};
+
+}  // namespace dialect
+}  // namespace paddle
+
+IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::DecompVjpInterface)
diff --git a/paddle/fluid/pir/dialect/operator/interface/interface.cc b/paddle/fluid/pir/dialect/operator/interface/interface.cc
index 3c697a409e5d25..74b7ad583a48d1 100644
--- a/paddle/fluid/pir/dialect/operator/interface/interface.cc
+++ b/paddle/fluid/pir/dialect/operator/interface/interface.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
#include "paddle/fluid/pir/dialect/operator/interface/decomp.h" +#include "paddle/fluid/pir/dialect/operator/interface/decomp_vjp.h" #include "paddle/fluid/pir/dialect/operator/interface/get_kernel_type_for_var.h" #include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" @@ -20,4 +21,5 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::InferMetaInterface) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OpYamlInfoInterface) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::DecompInterface) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::DecompVjpInterface) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::GetKernelTypeForVarInterface) diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op_decomp_vjp.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op_decomp_vjp.cc new file mode 100644 index 00000000000000..836e0787c92e31 --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op_decomp_vjp.cc @@ -0,0 +1,213 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" +#include "paddle/fluid/primitive/composite/composite.h" +#include "paddle/fluid/primitive/rule/vjp/details.h" +#include "paddle/fluid/primitive/rule/vjp/generated/generated_vjp.h" +#include "paddle/fluid/primitive/type/lazy_tensor.h" +#include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/common/int_array.h" +#include "paddle/pir/include/core/builtin_op.h" +#include "paddle/pir/include/core/op_base.h" + +// TODO(chenzhuo) +// this file will be generated in pd_op_decomp_vjp.cc + +namespace paddle { +namespace dialect { +using IntArray = paddle::experimental::IntArray; + +std::vector> AddGradOp::DecompVjp(pir::Operation* op) { + VLOG(4) << "Decomp call add_grad's decomp interface begin"; + + AddGradOp op_obj = op->dyn_cast(); + (void)op_obj; + + FLAGS_tensor_operants_mode = "static"; + + VLOG(6) << "Decomp Prepare inputs of add_grad"; + + Tensor x(std::make_shared(op_obj.x())); + Tensor y(std::make_shared(op_obj.y())); + Tensor out_grad(std::make_shared(op_obj.out_grad())); + + VLOG(6) << "Decomp prepare attributes of add_grad"; + int axis = op->attribute("axis").dyn_cast().data(); + + VLOG(6) << "Decomp call add_grad's composite rule prepare"; + + std::vector> stop_gradients(op->results().size()); + if (op->HasAttribute(kAttrStopGradients)) { + auto stop_gradients_attr = op->attribute(kAttrStopGradients) + .dyn_cast() + .AsVector(); + stop_gradients[0].push_back( + stop_gradients_attr[0].dyn_cast().data()); + stop_gradients[1].push_back( + stop_gradients_attr[1].dyn_cast().data()); + VLOG(0) << " stop_gradients is set "; + } else { + stop_gradients[0].push_back(false); + stop_gradients[1].push_back(false); + VLOG(0) << " stop_gradients is not set "; + } + + std::vector> tensor_res; + for (auto arg : 
+    tensor_res.push_back(std::vector<paddle::Tensor>(arg.size()));
+  }
+  std::string op_name = "add_grad";
+  FLAGS_tensor_operants_mode = "static";
+  VLOG(4) << "Call Pir Decomposed backward op add_grad";
+  paddle::Tensor* x_grad = !stop_gradients[0][0] ? &tensor_res[0][0] : nullptr;
+  paddle::Tensor* y_grad = !stop_gradients[1][0] ? &tensor_res[1][0] : nullptr;
+  paddle::primitive::details::add_grad<primitive::LazyTensor>(
+      x, y, out_grad, axis, x_grad, y_grad);
+  std::vector<std::vector<pir::Value>> res(tensor_res.size());
+  for (size_t i = 0; i < tensor_res.size(); ++i) {
+    res[i].resize(tensor_res[i].size());
+    for (size_t j = 0; j < tensor_res[i].size(); ++j) {
+      if (tensor_res[i][j].defined()) {
+        res[i][j] = std::static_pointer_cast<primitive::LazyTensor>(
+                        tensor_res[i][j].impl())
+                        ->value();
+      }
+    }
+  }
+  return res;
+}
+
+std::vector<std::vector<pir::Value>> ReluGradOp::DecompVjp(pir::Operation* op) {
+  VLOG(4) << "Decomp call relu_grad's decomp interface begin";
+
+  ReluGradOp op_obj = op->dyn_cast<ReluGradOp>();
+  (void)op_obj;
+
+  FLAGS_tensor_operants_mode = "static";
+
+  VLOG(6) << "Decomp Prepare inputs of relu_grad";
+
+  Tensor out(std::make_shared<primitive::LazyTensor>(op_obj.out()));
+  Tensor out_grad(std::make_shared<primitive::LazyTensor>(op_obj.out_grad()));
+
+  VLOG(6) << "Decomp prepare attributes of relu_grad";
+
+  VLOG(6) << "Decomp call relu_grad's composite rule prepare";
+
+  std::vector<std::vector<bool>> stop_gradients(op->results().size());
+  if (op->HasAttribute(kAttrStopGradients)) {
+    auto stop_gradients_attr = op->attribute(kAttrStopGradients)
+                                   .dyn_cast<pir::ArrayAttribute>()
+                                   .AsVector();
+    stop_gradients[0].push_back(
+        stop_gradients_attr[0].dyn_cast<pir::BoolAttribute>().data());
+    VLOG(0) << " stop_gradients is set ";
+  } else {
+    stop_gradients[0].push_back(false);
+    VLOG(0) << " stop_gradients is not set ";
+  }
+
+  std::vector<std::vector<paddle::Tensor>> tensor_res;
+  for (auto arg : stop_gradients) {
+    tensor_res.push_back(std::vector<paddle::Tensor>(arg.size()));
+  }
+  std::string op_name = "relu_grad";
+  FLAGS_tensor_operants_mode = "static";
+  VLOG(4) << "Call Pir Decomposed backward op relu_grad";
+  paddle::Tensor* x_grad = !stop_gradients[0][0] ? &tensor_res[0][0] : nullptr;
&tensor_res[0][0] : nullptr; + paddle::primitive::details::relu_grad( + out, out_grad, x_grad); + std::vector> res(tensor_res.size()); + for (size_t i = 0; i < tensor_res.size(); ++i) { + res[i].resize(tensor_res[i].size()); + for (size_t j = 0; j < tensor_res[i].size(); ++j) { + if (tensor_res[i][j].defined()) { + res[i][j] = std::static_pointer_cast( + tensor_res[i][j].impl()) + ->value(); + } + } + } + return res; +} + +std::vector> MatmulGradOp::DecompVjp( + pir::Operation* op) { + VLOG(4) << "Decomp call matmul_grad's decomp interface begin"; + + MatmulGradOp op_obj = op->dyn_cast(); + (void)op_obj; + + FLAGS_tensor_operants_mode = "static"; + + VLOG(6) << "Decomp Prepare inputs of matmul_grad"; + + Tensor x(std::make_shared(op_obj.x())); + Tensor y(std::make_shared(op_obj.y())); + Tensor out_grad(std::make_shared(op_obj.out_grad())); + + VLOG(6) << "Decomp prepare attributes of matmul_grad"; + bool transpose_x = + op->attribute("transpose_x").dyn_cast().data(); + bool transpose_y = + op->attribute("transpose_y").dyn_cast().data(); + + VLOG(6) << "Decomp call matmul_grad's composite rule prepare"; + + std::vector> stop_gradients(op->results().size()); + if (op->HasAttribute(kAttrStopGradients)) { + auto stop_gradients_attr = op->attribute(kAttrStopGradients) + .dyn_cast() + .AsVector(); + stop_gradients[0].push_back( + stop_gradients_attr[0].dyn_cast().data()); + stop_gradients[1].push_back( + stop_gradients_attr[1].dyn_cast().data()); + VLOG(0) << " stop_gradients is set "; + } else { + stop_gradients[0].push_back(false); + stop_gradients[1].push_back(false); + VLOG(0) << " stop_gradients is not set "; + } + + std::vector> tensor_res; + for (auto arg : stop_gradients) { + tensor_res.push_back(std::vector(arg.size())); + } + std::string op_name = "matmul_grad"; + FLAGS_tensor_operants_mode = "static"; + VLOG(4) << "Call Pir Decomposed backward op matmul_grad"; + paddle::Tensor* x_grad = !stop_gradients[0][0] ? &tensor_res[0][0] : nullptr; + paddle::Tensor* y_grad = !stop_gradients[1][0] ? &tensor_res[1][0] : nullptr; + paddle::primitive::details::matmul_grad( + x, y, out_grad, transpose_x, transpose_y, x_grad, y_grad); + std::vector> res(tensor_res.size()); + for (size_t i = 0; i < tensor_res.size(); ++i) { + res[i].resize(tensor_res[i].size()); + for (size_t j = 0; j < tensor_res[i].size(); ++j) { + if (tensor_res[i][j].defined()) { + res[i][j] = std::static_pointer_cast( + tensor_res[i][j].impl()) + ->value(); + } + } + } + return res; +} + +} // namespace dialect +} // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index c04cd803249132..617fecfa4a4f94 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -205,6 +205,7 @@ limitations under the License. 
 #include "paddle/fluid/eager/nan_inf_utils.h"
 #include "paddle/fluid/imperative/layout_autotune.h"
 #include "paddle/fluid/pir/dialect/operator/interface/decomp.h"
+#include "paddle/fluid/pir/dialect/operator/interface/decomp_vjp.h"
 #include "paddle/fluid/pir/dialect/operator/interface/vjp.h"
 #include "paddle/fluid/pir/dialect/operator/ir/api_builder.h"
 #include "paddle/fluid/pir/dialect/operator/ir/manual_pylayer_op.h"
@@ -1028,6 +1029,42 @@ void BindDecomp(pybind11::module *m) {
   });
 }
 
+void BindDecompVjp(pybind11::module *m) {
+  m->def("call_decomp_vjp", [](pir::Operation &vjp_op) {
+    py::list res;
+    paddle::dialect::DecompVjpInterface decomp_vjp_interface =
+        vjp_op.dyn_cast<paddle::dialect::DecompVjpInterface>();
+    PADDLE_ENFORCE(
+        decomp_vjp_interface,
+        phi::errors::InvalidArgument(
+            "[Prim] The decomp_vjp function is not registered in %s vjp_op ",
+            vjp_op.name()));
+    std::vector<std::vector<pir::Value>> decomp_res =
+        decomp_vjp_interface.DecompVjp(&vjp_op);
+
+    for (size_t i = 0; i < decomp_res.size(); ++i) {
+      py::list sub_res;
+      for (size_t j = 0; j < decomp_res[i].size(); ++j) {
+        if (!decomp_res[i][j]) {
+          sub_res.append(nullptr);
+        } else {
+          sub_res.append(decomp_res[i][j]);
+        }
+      }
+      res.append(sub_res);
+    }
+    return res;
+  });
+
+  m->def("has_decomp_vjp", [](pir::Operation &vjp_op) {
+    pir::IrContext *ctx = pir::IrContext::Instance();
+    pir::OpInfo vjp_op_info = ctx->GetRegisteredOpInfo(vjp_op.name());
+    auto decomp_vjp_interface_impl =
+        vjp_op_info.GetInterfaceImpl<paddle::dialect::DecompVjpInterface>();
+    return decomp_vjp_interface_impl != nullptr;
+  });
+}
+
 PYBIND11_MODULE(libpaddle, m) {
   BindImperative(&m);
   BindEager(&m);
@@ -3261,6 +3298,7 @@ All parameter, weight, gradient are variables in Paddle.
   BindPir(&m);
   BindVjp(&m);
   BindDecomp(&m);
+  BindDecompVjp(&m);
 #ifdef PADDLE_WITH_DISTRIBUTE
   BindDistApi(&m);
 #endif
diff --git a/python/paddle/decomposition/decomp.py b/python/paddle/decomposition/decomp.py
index ba6dc8e862e2a4..a031642b30448f 100644
--- a/python/paddle/decomposition/decomp.py
+++ b/python/paddle/decomposition/decomp.py
@@ -21,8 +21,10 @@
 from paddle.autograd.backward_utils import ValueDict, ValueSet
 from paddle.base.core import (
     call_decomp,
+    call_decomp_vjp,
     decomp_ops_contain_unused_output,
     has_decomp,
+    has_decomp_vjp,
 )
 from paddle.base.libpaddle.pir import Block, Operation
 from paddle.base.wrapped_decorator import signature_safe_contextmanager
@@ -844,9 +846,23 @@ def decompose_dist_program(pir_program):
     '''
     Decompose all non-primitive ops into primitive ops in a pir program. It may contain forward ops and backward ops.
     '''
-    # Todo(CZ): Decompose backward ops.
+    # decomp forward composite ops
     decompose(pir_program, [])
 
+    # decomp backward ops
+    block = pir_program.global_block()
+    with paddle.pir.core.program_guard(pir_program):
+        ops = pir_program.global_block().ops
+        for op in ops:
+            bwd_op_name = op.name()
+            if has_decomp_vjp(op):
+                pir.set_insertion_point(op)
+                orig_outs = op.results()
+                decomp_outs = call_decomp_vjp(op)
+                new_outs = _analyse_decomp_results(orig_outs, decomp_outs, op)
+                op.replace_all_uses_with(new_outs)
+                block.remove_op(op)
+
 
 def decompose_pir_program(pir_program, param_mapping, grad_var_to_var):
     '''
diff --git a/test/prim/pir_prim/test_decomp_whole_program.py b/test/prim/pir_prim/test_decomp_whole_program.py
new file mode 100644
index 00000000000000..f8c58ef7c24692
--- /dev/null
+++ b/test/prim/pir_prim/test_decomp_whole_program.py
@@ -0,0 +1,71 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+import paddle
+from paddle.autograd.ir_backward import grad
+from paddle.decomposition import decomp
+
+paddle.enable_static()
+
+
+class TestPrimMode(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(2023)
+        self.shape_x = [32, 32]
+        self.shape_y = [32, 32]
+        self.x = np.random.random(self.shape_x).astype("float32")
+        self.y = np.random.random(self.shape_y).astype("float32")
+
+    def base_net(self, flag=None):
+        main_program = paddle.static.Program()
+        with paddle.static.program_guard(main_program):
+            x = paddle.static.data('x', self.shape_x, dtype='float32')
+            y = paddle.static.data('y', self.shape_y, dtype='float32')
+            x.stop_gradient = False
+            y.stop_gradient = False
+            x1 = paddle.sin(x)
+            y1 = paddle.cos(y)
+            tmp1 = paddle.matmul(x1, y1)
+            tmp2 = paddle.mean(tmp1)
+            sum_out = paddle.sin(tmp2)
+            gradients = grad(sum_out, (x, y))
+            if flag == "prim":
+                with decomp.prim_guard():
+                    decomp.decompose_dist_program(main_program)
+            exe = paddle.static.Executor()
+            [fwd, dx, dy] = exe.run(
+                feed={'x': self.x, 'y': self.y}, fetch_list=[sum_out, gradients]
+            )
+
+        whole_ops = [op.name() for op in main_program.global_block().ops]
+        if flag == "prim":
+            assert 'pd_op.matmul_grad' not in whole_ops
+        else:
+            assert 'pd_op.matmul_grad' in whole_ops
+
+        return fwd, dx, dy
+
+    def test_prim_all(self):
+        res_ref = self.base_net()
+        res = self.base_net("prim")
+        for ref, actual in zip(res_ref, res):
+            np.testing.assert_allclose(ref, actual, rtol=1e-6)
+
+
+if __name__ == "__main__":
+    unittest.main()
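
Usage note (not part of the patch): the sketch below shows one way to drive the new has_decomp_vjp / call_decomp_vjp bindings directly, mirroring the backward loop added to decompose_dist_program in decomp.py above. It is a minimal probe under stated assumptions, not the PR's API surface: it assumes a Paddle build that already contains this patch, running with PIR enabled (e.g. FLAGS_enable_pir_api=1), and the network setup follows the unit test above.

# Minimal sketch (assumption: a Paddle build containing this patch, PIR enabled).
import paddle
from paddle.autograd.ir_backward import grad
from paddle.base.core import call_decomp_vjp, has_decomp_vjp

paddle.enable_static()

main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
    x = paddle.static.data('x', [4, 4], dtype='float32')
    y = paddle.static.data('y', [4, 4], dtype='float32')
    x.stop_gradient = False
    y.stop_gradient = False
    loss = paddle.mean(paddle.nn.functional.relu(paddle.matmul(x, y)))
    grad(loss, (x, y))

with paddle.pir.core.program_guard(main_program):
    for op in main_program.global_block().ops:
        if has_decomp_vjp(op):  # registered for add_grad, matmul_grad, relu_grad
            paddle.pir.set_insertion_point(op)
            # One list of pir.Value per original result; entries whose gradient
            # has stop_gradient set come back as None (see BindDecompVjp above).
            new_grads = call_decomp_vjp(op)
            print(op.name(), '->', new_grads)
            # decompose_dist_program (decomp.py above) additionally rewires uses
            # via op.replace_all_uses_with(...) and removes the original grad op.

The full replacement path, including use rewiring and removal of the original grad ops, is what the new test exercises end to end through decompose_dist_program.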