diff --git a/paddle/fluid/pir/dialect/CMakeLists.txt b/paddle/fluid/pir/dialect/CMakeLists.txt
index e563831e96e61a..3f82b40eeff587 100644
--- a/paddle/fluid/pir/dialect/CMakeLists.txt
+++ b/paddle/fluid/pir/dialect/CMakeLists.txt
@@ -162,7 +162,8 @@ set(op_dialect_vjp_srcs
     ${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/manual_op_vjp.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/op_dialect.cc
     ${op_decomp_source_file}
-    ${op_vjp_source_file})
+    ${op_vjp_source_file}
+    ${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/base/decomp_trans.cc)
 set(op_dialect_vjp_deps primitive_vjp_experimental op_dialect)
 
 cc_library(
diff --git a/paddle/fluid/primitive/base/decomp_trans.cc b/paddle/fluid/primitive/base/decomp_trans.cc
new file mode 100644
index 00000000000000..877cadb4d9befb
--- /dev/null
+++ b/paddle/fluid/primitive/base/decomp_trans.cc
@@ -0,0 +1,315 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/primitive/base/decomp_trans.h"
+#include "paddle/fluid/pir/dialect/operator/ir/api_builder.h"
+#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
+#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
+#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
+#include "paddle/fluid/prim/utils/utils.h"
+#include "paddle/pir/core/builtin_dialect.h"
+#include "paddle/pir/core/program.h"
+
+PHI_DECLARE_bool(prim_skip_dynamic);
+
+using paddle::dialect::DenseTensorType;
+using paddle::dialect::SelectedRowsType;
+
+namespace paddle {
+
+using Program = pir::Program;
+
+static bool find_value(const std::vector<int64_t>& vec, int64_t value) {
+  if (std::find(vec.begin(), vec.end(), value) != vec.end()) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
+static const phi::DDim& GetValueDims(pir::Value value) {
+  if (value.type().isa<DenseTensorType>()) {
+    return value.type().dyn_cast<DenseTensorType>().dims();
+  } else if (value.type().isa<SelectedRowsType>()) {
+    return value.type().dyn_cast<SelectedRowsType>().dims();
+  } else {
+    PADDLE_THROW(phi::errors::InvalidArgument(
+        "[Prim] Currently, we can only get shape for dense "
+        "tensor."));
+  }
+}
+
+static phi::DataType GetValueDtype(pir::Value value) {
+  if (value.type().isa<DenseTensorType>()) {
+    return paddle::dialect::TransToPhiDataType(
+        value.type().dyn_cast<DenseTensorType>().dtype());
+  } else if (value.type().isa<SelectedRowsType>()) {
+    return paddle::dialect::TransToPhiDataType(
+        value.type().dyn_cast<SelectedRowsType>().dtype());
+  } else {
+    PADDLE_THROW(phi::errors::InvalidArgument(
+        "Currently, we can only get phi::DataType from DenseTensorType and "
+        "SelectedRowsType."));
+  }
+}
+
+static bool check_dynamic_shape(const pir::OpOperand& item,
+                                const pir::Operation& op) {
+  auto dims = GetValueDims(item.source());
+  std::vector<int64_t> shape = common::vectorize(dims);
+  if (find_value(shape, -1)) {
+    LOG(WARNING)
+        << "[Prim] Decomp op does not support dynamic shape -1, but got "
+           "shape ["
+        << dims << "] in inputs of op " << op.name();
+    return true;
+  } else {
+    return false;
+  }
+}
+
+bool has_decomp_rule(const pir::Operation& op) {
+  pir::IrContext* ctx = pir::IrContext::Instance();
+  pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op.name());
+  auto decomp_interface_impl =
+      op_info.GetInterfaceImpl<paddle::dialect::DecompInterface>();
+  if (decomp_interface_impl == nullptr) return false;
+  return true;
+}
+
+bool DecompProgram::check_decomp_dynamic_shape(pir::Operation* op) {
+  for (auto item : op->operands()) {
+    auto value = item.source();
+    // Check whether the value is initialized, in case of optional inputs.
+    if (!paddle::dialect::IsEmptyValue(value)) {
+      pir::Operation* prev_op = value.dyn_cast<pir::OpResult>().owner();
+      if (prev_op->name() == "builtin.combine") {
+        for (pir::OpOperand& sub_item : prev_op->operands()) {
+          if (check_dynamic_shape(sub_item, *op)) {
+            return true;
+          }
+        }
+      } else {
+        if (check_dynamic_shape(item, *op)) {
+          return true;
+        }
+      }
+    }
+  }
+  return false;
+}
+
+void DecompProgram::check_decomp_outputs(
+    const std::string& op_name,
+    const std::vector<pir::OpResult>& orig_outs,
+    const std::vector<pir::OpResult>& decomp_outs) {
+  for (size_t i = 0; i < orig_outs.size(); i++) {
+    auto orig_dtype = GetValueDtype(orig_outs[i]);
+    auto decomp_dtype = GetValueDtype(decomp_outs[i]);
+
+    PADDLE_ENFORCE(
+        orig_dtype == decomp_dtype,
+        paddle::platform::errors::PreconditionNotMet(
+            "[Prim] For op %s, its origin output dtype %s is not equal to "
+            "decomp output dtype %s ",
+            op_name,
+            orig_dtype,
+            decomp_dtype));
+
+    auto orig_dim = GetValueDims(orig_outs[i]);
+    auto decomp_dim = GetValueDims(decomp_outs[i]);
+    std::vector<int64_t> shape = common::vectorize(orig_dim);
+    if (find_value(shape, -1)) {
+      LOG(WARNING)
+          << "[Prim] Decomp op does not support dynamic shape -1, but got "
+             "shape ["
+          << orig_dim << "] in output of origin op " << op_name;
+    }
+    if (find_value(common::vectorize(decomp_dim), -1)) {
+      LOG(WARNING)
+          << "[Prim] Decomp op does not support dynamic shape -1, but got "
+             "shape ["
+          << decomp_dim << "] in output of decomp op " << op_name;
+    }
+
+    PADDLE_ENFORCE(
+        orig_dim == decomp_dim,
+        paddle::platform::errors::PreconditionNotMet(
+            "[Prim] For op %s, its origin output shape [%s] is not equal to "
+            "decomp output shape [%s] ",
+            op_name,
+            orig_dim,
+            decomp_dim));
+  }
+  return;
+}
+
+std::vector<pir::OpResult> DecompProgram::format_decomp_res(
+    const std::string& op_name,
+    const std::vector<pir::OpResult>& orig_outs,
+    const std::vector<std::vector<pir::OpResult>>& decomp_outs) {
+  PADDLE_ENFORCE_EQ(
+      orig_outs.size(),
+      decomp_outs.size(),
+      paddle::platform::errors::PreconditionNotMet(
+          "[Prim] For op %s, its origin output num %d is not equal to "
+          "decomp output num %d ",
+          op_name,
+          orig_outs.size(),
+          decomp_outs.size()));
+  std::vector<pir::OpResult> new_decomp_outs(orig_outs.size());
+  for (size_t i = 0; i < orig_outs.size(); i++) {
+    if (orig_outs[i]) {
+      PADDLE_ENFORCE_EQ(
+          decomp_outs[i].size(),
+          1,
+          paddle::platform::errors::PreconditionNotMet(
+              "[Prim] For op %s, each element of decomp output num must "
+              "be 1, but num of index %d is %d ",
+              op_name,
+              i,
+              decomp_outs[i].size()));
+      new_decomp_outs[i] = decomp_outs[i][0];
+    }
+  }
+  return new_decomp_outs;
+}
+
+std::vector<pir::OpResult> DecompProgram::construct_dst_vars(
+    const std::string& op_name,
+    const std::vector<pir::OpResult>& orig_outs,
+    const std::vector<pir::OpResult>& decomp_outs,
+    std::unordered_map<pir::OpResult, int> orig_vars_dict) {
+  std::vector<pir::OpResult> tar_vars(src_vars_.size());
+  PADDLE_ENFORCE_EQ(
+      orig_outs.size(),
+      decomp_outs.size(),
+      paddle::platform::errors::PreconditionNotMet(
+          "[Prim] For op %s, its origin output num %d is not equal to "
+          "decomp output num %d ",
+          op_name,
+          orig_outs.size(),
+          decomp_outs.size()));
+  for (size_t i = 0; i < orig_outs.size(); i++) {
+    if (orig_vars_dict.find(orig_outs[i]) != orig_vars_dict.end()) {
+      tar_vars[orig_vars_dict[orig_outs[i]]] = decomp_outs[i];
+    }
+  }
+  return tar_vars;
+}
+
+bool DecompProgram::enable_decomp_by_filter(const std::string& op_name) {
+  bool flag = true;
+
+  if (whitelist_.size() > 0) {
+    if (whitelist_.find(op_name) == whitelist_.end()) {
+      flag = false;
+    }
+  }
+  if (blacklist_.size() > 0) {
+    if (blacklist_.find(op_name) != blacklist_.end()) {
+      flag = false;
+    }
+  }
+  return flag;
+}
+
+std::vector<std::vector<pir::OpResult>> call_decomp_rule(pir::Operation* op) {
+  paddle::dialect::DecompInterface decomp_interface =
+      op->dyn_cast<paddle::dialect::DecompInterface>();
+  PADDLE_ENFORCE(decomp_interface,
+                 phi::errors::InvalidArgument(
+                     "[Prim] The decomp function is not registered in %s op ",
+                     op->name()));
+  std::vector<std::vector<pir::OpResult>> decomp_res =
+      decomp_interface.Decomp(op);
+  return decomp_res;
+}
+
+DecompProgram::DecompProgram(pir::Program* program,
+                             const std::vector<pir::OpResult>& src_vars,
+                             const std::set<std::string>& blacklist,
+                             const std::set<std::string>& whitelist)
+    : program_(program),
+      src_vars_(src_vars),
+      blacklist_(blacklist),
+      whitelist_(whitelist) {}
+
+std::vector<pir::OpResult> DecompProgram::decomp_program() {
+  std::ostringstream orig_prog_stream;
+  std::unordered_map<pir::OpResult, int> orig_vars_dict;
+  for (size_t i = 0; i < src_vars_.size(); i++) {
+    orig_vars_dict[src_vars_[i]] = static_cast<int>(i);
+  }
+  program_->Print(orig_prog_stream);
+  VLOG(4) << "[Prim] Origin program before decomp :\n"
+          << orig_prog_stream.str();
+  if (!paddle::prim::PrimCommonUtils::IsFwdPrimEnabled()) {
+    return src_vars_;
+  }
+  std::vector<pir::OpResult> tar_vars(src_vars_.size());
+  pir::Block* block = program_->block();
+  std::vector<pir::Operation*> ops_list;
+  for (auto& op : *block) {
+    ops_list.push_back(&op);
+  }
+  for (size_t i = 0; i < ops_list.size(); i++) {
+    auto op = ops_list[i];
+    bool enable_prim =
+        has_decomp_rule(*op) && enable_decomp_by_filter(op->name());
+    if (enable_prim && FLAGS_prim_skip_dynamic &&
+        check_decomp_dynamic_shape(op)) {
+      enable_prim = false;
+    }
+    if (enable_prim) {
+      VLOG(4) << "[Prim] decomp op name " << op->name();
+      // Warn about dynamic shapes even when FLAGS_prim_skip_dynamic is off.
+      check_decomp_dynamic_shape(op);
+      auto& builder = *(paddle::dialect::ApiBuilder::Instance().GetBuilder());
+      builder.set_insertion_point(op);
+      std::vector<std::vector<pir::OpResult>> decomp_res =
+          call_decomp_rule(op);
+      std::vector<pir::OpResult> orig_outs = op->results();
+      std::vector<pir::OpResult> standard_decomp_res =
+          format_decomp_res(op->name(), orig_outs, decomp_res);
+      check_decomp_outputs(op->name(), orig_outs, standard_decomp_res);
+      tar_vars = construct_dst_vars(
+          op->name(), orig_outs, standard_decomp_res, orig_vars_dict);
+
+      op->ReplaceAllUsesWith(standard_decomp_res);
+      bool remove_op = true;
+      for (auto& item : op->results()) {
+        if (item.HasOneUse()) {
+          remove_op = false;
+          break;
+        }
+      }
+      if (remove_op) {
+        auto op_iter = std::find(block->begin(), block->end(), *op);
+        block->erase(op_iter);
+      }
+    }
+  }
+  for (size_t i = 0; i < tar_vars.size(); i++) {
+    if (!tar_vars[i]) {
+      tar_vars[i] = src_vars_[i];
+    }
+  }
+  auto& builder = *(paddle::dialect::ApiBuilder::Instance().GetBuilder());
+  builder.SetInsertionPointToBlockEnd(block);
+  std::ostringstream decomp_prog_stream;
+  program_->Print(decomp_prog_stream);
+  VLOG(4) << "[Prim] New program after decomp :\n" << decomp_prog_stream.str();
+  return tar_vars;
+}
+
+}  // namespace paddle
diff --git a/paddle/fluid/primitive/base/decomp_trans.h b/paddle/fluid/primitive/base/decomp_trans.h
new file mode 100644
index 00000000000000..550d8beab80314
--- /dev/null
+++ b/paddle/fluid/primitive/base/decomp_trans.h
@@ -0,0 +1,61 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <set>
+
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/pir/dialect/operator/interface/decomp.h"
+#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
+#include "paddle/pir/core/block.h"
+#include "paddle/pir/core/program.h"
+
+namespace paddle {
+
+class DecompProgram {
+ public:
+  DecompProgram(pir::Program* program,
+                const std::vector<pir::OpResult>& src_vars,
+                const std::set<std::string>& blacklist,
+                const std::set<std::string>& whitelist);
+
+  std::vector<pir::OpResult> decomp_program();
+  bool check_decomp_dynamic_shape(pir::Operation* op);
+  void check_decomp_outputs(const std::string& op_name,
+                            const std::vector<pir::OpResult>& orig_outs,
+                            const std::vector<pir::OpResult>& decomp_outs);
+  std::vector<pir::OpResult> format_decomp_res(
+      const std::string& op_name,
+      const std::vector<pir::OpResult>& orig_outs,
+      const std::vector<std::vector<pir::OpResult>>& decomp_outs);
+  std::vector<pir::OpResult> construct_dst_vars(
+      const std::string& op_name,
+      const std::vector<pir::OpResult>& orig_outs,
+      const std::vector<pir::OpResult>& decomp_outs,
+      std::unordered_map<pir::OpResult, int> orig_vars_dict);
+  bool enable_decomp_by_filter(const std::string& op_name);
+
+ private:
+  pir::Program* program_;
+  std::vector<pir::OpResult> src_vars_;
+  std::set<std::string> blacklist_;
+  std::set<std::string> whitelist_;
+};
+
+bool has_decomp_rule(const pir::Operation& op);
+
+std::vector<std::vector<pir::OpResult>> call_decomp_rule(pir::Operation* op);
+
+}  // namespace paddle
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index a215945e0001e6..6d50682b8ab572 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -201,6 +201,7 @@ limitations under the License. */
 #include "paddle/fluid/pir/dialect/operator/trait/custom_vjp.h"
 #include "paddle/fluid/prim/utils/eager/eager_tensor_operants.h"
 #include "paddle/fluid/prim/utils/static/static_tensor_operants.h"
+#include "paddle/fluid/primitive/base/decomp_trans.h"
 #include "paddle/fluid/pybind/eager_utils.h"
 #include "paddle/phi/api/ext/op_meta_info.h"
 #include "paddle/phi/api/include/operants_manager.h"
@@ -769,16 +770,30 @@ void BindVjp(pybind11::module *m) {
 }
 
 void BindDecomp(pybind11::module *m) {
+  m->def("sinking_decomp",
+         [](pir::Program *program,
+            std::vector<pir::OpResult> &src_vars,
+            std::set<std::string> &blacklist,
+            std::set<std::string> &whitelist) {
+           VLOG(4) << "[Prim] Bind Decomp sinking_decomp begin.";
+           py::list res;
+           DecompProgram decomp_object(program, src_vars, blacklist, whitelist);
+           auto tar_vars = decomp_object.decomp_program();
+           for (size_t i = 0; i < tar_vars.size(); ++i) {
+             if (!tar_vars[i]) {
+               res.append(nullptr);
+             } else {
+               res.append(tar_vars[i]);
+             }
+           }
+           VLOG(4) << "[Prim] Bind Decomp sinking_decomp end.";
+           return res;
+         });
+
   m->def("call_decomp", [](pir::Operation &fwd_op) {
     py::list res;
-    paddle::dialect::DecompInterface decomp_interface =
-        fwd_op.dyn_cast<paddle::dialect::DecompInterface>();
-    PADDLE_ENFORCE(
-        decomp_interface,
-        phi::errors::InvalidArgument(
-            "The decomp function is not registered in %s op ", fwd_op.name()));
     std::vector<std::vector<pir::OpResult>> decomp_res =
-        decomp_interface.Decomp(&fwd_op);
+        call_decomp_rule(&fwd_op);
     for (size_t i = 0; i < decomp_res.size(); ++i) {
       py::list sub_res;
       for (size_t j = 0; j < decomp_res[i].size(); ++j) {
@@ -794,12 +809,7 @@ void BindDecomp(pybind11::module *m) {
   });
 
   m->def("has_decomp", [](pir::Operation &fwd_op) {
-    pir::IrContext *ctx = pir::IrContext::Instance();
-    pir::OpInfo fwd_op_info = ctx->GetRegisteredOpInfo(fwd_op.name());
-    auto decomp_interface_impl =
-        fwd_op_info.GetInterfaceImpl<paddle::dialect::DecompInterface>();
-    if (decomp_interface_impl == nullptr) return false;
-    return true;
+    return paddle::has_decomp_rule(fwd_op);
   });
 }
 
diff --git a/paddle/phi/core/flags.cc b/paddle/phi/core/flags.cc
index 242132db6e2654..ea1af5eee4d0b9 100644
--- a/paddle/phi/core/flags.cc
+++ b/paddle/phi/core/flags.cc
@@ -1452,6 +1452,9 @@ PHI_DEFINE_EXPORTED_int32(
     "been dropped when you are profiling, try increasing this value.");
 
 PHI_DEFINE_EXPORTED_bool(print_ir, false, "Whether print ir debug str.");
+PHI_DEFINE_EXPORTED_bool(prim_skip_dynamic,
+                         false,
+                         "Whether to skip decomposing ops with dynamic shape.");
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
     defined(PADDLE_WITH_XPU_BKCL)
diff --git a/python/paddle/base/core.py b/python/paddle/base/core.py
index b88b873fcfa9dd..fc461109c1e059 100644
--- a/python/paddle/base/core.py
+++ b/python/paddle/base/core.py
@@ -517,7 +517,17 @@ def _get_batch_norm_none_var(op):
 
 # This api is used for development for dynamic shape in prim, and will be removed in future.
 def _enable_prim_dynamic_shape():
-    if os.getenv("FLAGS_prim_skip_dynamic") == "1":
+    flag = os.getenv("FLAGS_prim_skip_dynamic")
+    if flag and flag.lower() in ("1", "true"):
+        return True
+    else:
+        return False
+
+
+# This API is used while developing sinking decomp in C++, and will be removed in the future.
+def _enable_sink_decomp():
+    flag = os.getenv("FLAGS_sink_decomp")
+    if flag and flag.lower() in ("1", "true"):
         return True
     else:
         return False
@@ -531,6 +541,10 @@ def _set_prim_forward_blacklist(*args):
         prim_config["forward_blacklist"].add(item)
 
 
+def _reset_prim_forward_blacklist():
+    prim_config["forward_blacklist"] = set()
+
+
 def _set_prim_backward_blacklist(*args):
     ops = set(args)
     for item in ops:
diff --git a/python/paddle/decomposition/decomp.py b/python/paddle/decomposition/decomp.py
index c9b017a111d3e6..fddfe6c72d5622 100644
--- a/python/paddle/decomposition/decomp.py
+++ b/python/paddle/decomposition/decomp.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import os
 import typing
 import warnings
 
@@ -25,6 +26,7 @@
     has_decomp,
 )
 from paddle.base.libpaddle.pir import Block, Operation, Program
+from paddle.base.wrapped_decorator import signature_safe_contextmanager
 from paddle.framework import core
 
 from . import register
@@ -32,6 +34,19 @@
 logger = log_helper.get_logger(__name__, logging.DEBUG)
 
 
+# For sinking decomp into C++. In the future, sinking decomp will run in C++ by default and this API will be removed.
+@signature_safe_contextmanager
+def sink_decomp_guard():
+    sink_decomp = core._enable_sink_decomp()
+    try:
+        if not sink_decomp:
+            os.environ['FLAGS_sink_decomp'] = 'true'
+        yield
+    finally:
+        if not sink_decomp:
+            os.environ['FLAGS_sink_decomp'] = 'false'
+
+
 def _build_tensor_tuple(xs):
     if isinstance(xs, pir.Value):
         return (xs,)
@@ -201,6 +216,9 @@ def decompose(
     Returns:
         dst_vars (list): A list contains all vars which replace origin ones in src_vars.
     """
+    if core._enable_sink_decomp():
+        blacklist = core.prim_config["forward_blacklist"] | blacklist
+        return core.sinking_decomp(program, src_vars, blacklist, whitelist)
     if not core._is_fwd_prim_enabled():
         return src_vars
     if not isinstance(program, Program):
diff --git a/test/prim/pir_prim/CMakeLists.txt b/test/prim/pir_prim/CMakeLists.txt
index e19bbdfd235a78..28c33771155296 100644
--- a/test/prim/pir_prim/CMakeLists.txt
+++ b/test/prim/pir_prim/CMakeLists.txt
@@ -8,8 +8,14 @@ set(TEST_PRIM_PURE_PIR_CASES
     test_prim_dynamic)
 
 foreach(target ${TEST_PRIM_PURE_PIR_CASES})
-  py_test_modules(${target} MODULES ${target} ENVS GLOG_v=1
-                  FLAGS_enable_pir_api=true)
+  py_test_modules(
+    ${target}
+    MODULES
+    ${target}
+    ENVS
+    GLOG_v=1
+    FLAGS_enable_pir_api=true
+    FLAGS_prim_skip_dynamic=1)
 endforeach()
 
 file(
diff --git a/test/prim/pir_prim/test_pir_prim_flags.py b/test/prim/pir_prim/test_pir_prim_flags.py
index 0305274011b50c..62240c334c0718 100644
--- a/test/prim/pir_prim/test_pir_prim_flags.py
+++ b/test/prim/pir_prim/test_pir_prim_flags.py
@@ -19,7 +19,7 @@
 import paddle
 import paddle.nn.functional as F
 from paddle.base import core
-from paddle.decomposition import decompose
+from paddle.decomposition import decomp
 
 
 class TestPrimBlacklistFlags(unittest.TestCase):
@@ -39,7 +39,7 @@ def not_in_blacklist(self):
         # Ensure that tanh in original block
         self.assertTrue('pd_op.gelu' in fwd_ops)
 
-        [y] = decompose(main_program, [y])
+        [y] = decomp.decompose(main_program, [y])
 
         fwd_ops_new = [op.name() for op in main_program.global_block().ops]
         # Ensure that tanh is splitted into small ops
@@ -67,7 +67,7 @@ def in_blacklist(self):
         # Ensure that tanh in original block
         self.assertTrue('pd_op.gelu' in fwd_ops)
 
-        _ = decompose(main_program, [y])
+        _ = decomp.decompose(main_program, [y])
 
         fwd_ops_new = [op.name() for op in main_program.global_block().ops]
         # Ensure that tanh is splitted into small ops
@@ -84,6 +84,13 @@ def test_prim_forward_blacklist(self):
         core._set_prim_forward_blacklist("pd_op.gelu")
         self.in_blacklist()
 
+    def test_prim_forward_blacklist_sink(self):
+        with decomp.sink_decomp_guard():
+            core._reset_prim_forward_blacklist()
+            self.not_in_blacklist()
+            core._set_prim_forward_blacklist("pd_op.gelu")
+            self.in_blacklist()
+
 
 class PrimeNet(paddle.nn.Layer):
     def __init__(self):
diff --git a/test/prim/pir_prim/test_prim_dynamic.py b/test/prim/pir_prim/test_prim_dynamic.py
index a7d4d6424eda44..0d697718b169a1 100644
--- a/test/prim/pir_prim/test_prim_dynamic.py
+++ b/test/prim/pir_prim/test_prim_dynamic.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
 import unittest
 
 import numpy as np
@@ -66,12 +65,18 @@ def base_net(self, flag=None):
         return outs
 
     def test_prim_all_dynamic(self):
-        os.environ["FLAGS_prim_skip_dynamic"] = "1"
         res_ref = self.base_net()
         res = self.base_net("all")
         for ref, actual in zip(res_ref, res):
             np.testing.assert_allclose(ref, actual, rtol=1e-6)
 
+    def test_prim_all_dynamic_sink(self):
+        with decomp.sink_decomp_guard():
+            res_ref = self.base_net()
+            res = self.base_net("all")
+            for ref, actual in zip(res_ref, res):
+                np.testing.assert_allclose(ref, actual, rtol=1e-6)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/prim/pir_prim/test_prim_program.py b/test/prim/pir_prim/test_prim_program.py
index ca754d12e88b1b..75f1b7cdf5cdc8 100644
--- a/test/prim/pir_prim/test_prim_program.py
+++ b/test/prim/pir_prim/test_prim_program.py
@@ -18,7 +18,7 @@
 
 import paddle
 from paddle.autograd.ir_backward import grad
-from paddle.decomposition import decompose
+from paddle.decomposition import decomp
 from paddle.framework import core
 
 paddle.enable_static()
@@ -47,7 +47,7 @@ def base_net(self, flag=None):
             y.stop_gradient = False
             divide_out = paddle.divide(x, y)
             sum_out = paddle.mean(divide_out, axis=0)
-            [new_out] = decompose(main_program, [sum_out])
+            [new_out] = decomp.decompose(main_program, [sum_out])
             gradients = grad(new_out, (x, y))
 
             exe = paddle.static.Executor()
@@ -98,6 +98,27 @@ def test_prim_all(self):
         for ref, actual in zip(res_ref, res):
             np.testing.assert_allclose(ref, actual, rtol=1e-6)
 
+    def test_prim_forward_sink(self):
+        res_ref = self.base_net()
+        with decomp.sink_decomp_guard():
+            res = self.base_net("forward")
+            for ref, actual in zip(res_ref, res):
+                np.testing.assert_equal(ref, actual)
+
+    def test_prim_backward_sink(self):
+        res_ref = self.base_net()
+        with decomp.sink_decomp_guard():
+            res = self.base_net("backward")
+            for ref, actual in zip(res_ref, res):
+                np.testing.assert_allclose(ref, actual, rtol=1e-6)
+
+    def test_prim_all_sink(self):
+        res_ref = self.base_net()
+        with decomp.sink_decomp_guard():
+            res = self.base_net("all")
+            for ref, actual in zip(res_ref, res):
+                np.testing.assert_allclose(ref, actual, rtol=1e-6)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/prim/pir_prim/test_prim_simpnet.py b/test/prim/pir_prim/test_prim_simpnet.py
index 85051a26c350e7..83527a28e0996a 100644
--- a/test/prim/pir_prim/test_prim_simpnet.py
+++ b/test/prim/pir_prim/test_prim_simpnet.py
@@ -19,7 +19,7 @@
 import paddle
 from paddle import _pir_ops, nn
 from paddle.autograd.ir_backward import grad
-from paddle.decomposition import decompose
+from paddle.decomposition import decomp
 from paddle.framework import core
 
 paddle.enable_static()
@@ -62,7 +62,7 @@ def base_net(self, flag=None):
             l2_w = paddle.static.data('l2_w', self.shape_l2_w, dtype='float32')
             divide_out = paddle.divide(x, y)
             res = net(divide_out, l1_w, l2_w)
-            [res2] = decompose(main_program, [res])
+            [res2] = decomp.decompose(main_program, [res])
             gradients = grad(res2, (x, y))
             exe = paddle.static.Executor()
             outs = exe.run(
@@ -90,6 +90,13 @@ def test_prim_all(self):
         for ref, actual in zip(res_ref, res):
             np.testing.assert_allclose(ref, actual, rtol=1e-6)
 
+    def test_prim_all_sink(self):
+        res_ref = self.base_net()
+        with decomp.sink_decomp_guard():
+            res = self.base_net("all")
+            for ref, actual in zip(res_ref, res):
+                np.testing.assert_allclose(ref, actual, rtol=1e-6)
+
 
 if __name__ == "__main__":
     unittest.main()
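
Reviewer note: the guard-plus-decompose flow that the new tests exercise can be reproduced standalone. Below is a minimal sketch, assuming a PIR-enabled build (FLAGS_enable_pir_api=true); the one-op network here is illustrative and not part of this patch.

```python
import paddle
import paddle.nn.functional as F
from paddle.decomposition import decomp
from paddle.framework import core

paddle.enable_static()

main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
    x = paddle.static.data('x', [8, 16], dtype='float32')
    y = paddle.mean(F.gelu(x), axis=0)

core._set_prim_forward_enabled(True)
with decomp.sink_decomp_guard():
    # The guard sets FLAGS_sink_decomp, so decompose() dispatches to the
    # new C++ core.sinking_decomp binding instead of the Python rewriter.
    [y] = decomp.decompose(main_program, [y])
core._set_prim_forward_enabled(False)

# pd_op.gelu should now be rewritten into primitive ops.
print(main_program)
```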
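
On the C++ side, the `sinking_decomp` binding is a thin wrapper over `DecompProgram`, so the pass can also be driven directly. A sketch mirroring what the binding does, assuming the caller already owns a `pir::Program` and the values to track (`RunSinkingDecomp` is a hypothetical helper, not part of this patch):

```cpp
#include <set>
#include <string>
#include <vector>

#include "paddle/fluid/primitive/base/decomp_trans.h"

std::vector<pir::OpResult> RunSinkingDecomp(
    pir::Program* program, const std::vector<pir::OpResult>& src_vars) {
  std::set<std::string> blacklist;  // ops that must keep their original form
  std::set<std::string> whitelist;  // empty means: every op with a decomp rule
  paddle::DecompProgram decomp_object(program, src_vars, blacklist, whitelist);
  // Rewrites `program` in place; the returned values replace src_vars, and an
  // entry falls back to the original value when its op was not decomposed.
  return decomp_object.decomp_program();
}
```

Note that, unlike the Python `decompose` entry point, calling `DecompProgram` directly does not merge `prim_config["forward_blacklist"]` into the blacklist; that union happens in `decomp.py` before `core.sinking_decomp` is invoked.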