-
Notifications
You must be signed in to change notification settings - Fork 5.9k
[Prim][PIR] Sink decomp graph #59448
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 34 commits
Commits
Show all changes
36 commits
Select commit
Hold shift + click to select a range
638d847
sink decomp 0
cyber-pioneer 3f71ea9
add log
cyber-pioneer e8df014
sink decomp 1
cyber-pioneer d6220ef
add cmake file
cyber-pioneer f02a4e7
support whole framework
cyber-pioneer 08503d9
move call_decomp_rule
cyber-pioneer 866ea16
fix code
cyber-pioneer f8b7b53
remove op
cyber-pioneer 65f175f
fix runtime bug
cyber-pioneer a6b170c
support prim flag
cyber-pioneer 372d5d2
fix checkout output
cyber-pioneer 38ee0fa
support recover tar_vars
cyber-pioneer dfe7e7d
add blacklist
cyber-pioneer 3b3c7ae
add blacklist and whitelist
cyber-pioneer 4276f06
replace origin decomp
cyber-pioneer 0b5fba4
remove const
cyber-pioneer 10ac3a3
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into sink
cyber-pioneer 82d0db2
fix blacklist
cyber-pioneer 96fd50e
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into sink
cyber-pioneer 3c56a7f
add check dynamic shape
cyber-pioneer fa72dac
dynamic flag come into effect
cyber-pioneer cc2765c
test case change flag
cyber-pioneer 4b8abc4
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into sink
cyber-pioneer 7ccbf2c
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into sink
cyber-pioneer b0fb895
add decomp sink guard
cyber-pioneer 1ea5cff
polish code
cyber-pioneer 7b9ba74
polish code
cyber-pioneer 2755091
polish code
cyber-pioneer 5446331
fix bug
cyber-pioneer 9794b93
add decomp output check
cyber-pioneer e2c2e96
add error log
cyber-pioneer d18f57c
add log
cyber-pioneer 95b8d4c
polish code
cyber-pioneer 5aee5e4
remove test case
cyber-pioneer 07a2834
polish code
cyber-pioneer 8298074
add blacklist test case
cyber-pioneer File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,315 @@ | ||
| // Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
| #include "paddle/fluid/primitive/base/decomp_trans.h" | ||
| #include "paddle/fluid/pir/dialect/operator/ir/api_builder.h" | ||
| #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" | ||
| #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" | ||
| #include "paddle/fluid/pir/dialect/operator/utils/utils.h" | ||
| #include "paddle/fluid/prim/utils/utils.h" | ||
| #include "paddle/pir/core/builtin_dialect.h" | ||
| #include "paddle/pir/core/program.h" | ||
|
|
||
| PHI_DECLARE_bool(prim_skip_dynamic); | ||
|
|
||
| using paddle::dialect::DenseTensorType; | ||
| using paddle::dialect::SelectedRowsType; | ||
|
|
||
| namespace paddle { | ||
|
|
||
| using Program = pir::Program; | ||
|
|
||
// Returns true iff `value` occurs in `vec`.
static bool find_value(const std::vector<int64_t>& vec, int64_t value) {
  return std::find(vec.begin(), vec.end(), value) != vec.end();
}
|
|
||
| static const phi::DDim& GetValueDims(pir::Value value) { | ||
| if (value.type().isa<DenseTensorType>()) { | ||
| return value.type().dyn_cast<DenseTensorType>().dims(); | ||
| } else if (value.type().isa<SelectedRowsType>()) { | ||
| return value.type().dyn_cast<SelectedRowsType>().dims(); | ||
| } else { | ||
| PADDLE_THROW(phi::errors::InvalidArgument( | ||
| "[Prim] Currently, we can only get shape for dense " | ||
| "tensor.")); | ||
| } | ||
| } | ||
|
|
||
| static phi::DataType GetValueDtype(pir::Value value) { | ||
| if (value.type().isa<DenseTensorType>()) { | ||
| return paddle::dialect::TransToPhiDataType( | ||
| value.type().dyn_cast<DenseTensorType>().dtype()); | ||
| } else if (value.type().isa<SelectedRowsType>()) { | ||
| return paddle::dialect::TransToPhiDataType( | ||
| value.type().dyn_cast<SelectedRowsType>().dtype()); | ||
| } else { | ||
| PADDLE_THROW(phi::errors::InvalidArgument( | ||
| "Currently, we can only get phi::DataType from DenseTensorType and " | ||
| "SelectedRowsType.")); | ||
| } | ||
| } | ||
|
|
||
| static bool check_dynamic_shape(const pir::OpOperand& item, | ||
| const pir::Operation& op) { | ||
| auto dims = GetValueDims(item.source()); | ||
| std::vector<int64_t> shape = common::vectorize<int64_t>(dims); | ||
| if (find_value(shape, -1)) { | ||
| LOG(WARNING) | ||
| << "[Prim] Decomp op does not support dynamic shape -1, but got " | ||
| "shape [" | ||
| << dims << "] in inputs of op " << op.name(); | ||
| return true; | ||
| } else { | ||
| return false; | ||
| } | ||
| } | ||
|
|
||
| bool has_decomp_rule(const pir::Operation& op) { | ||
| pir::IrContext* ctx = pir::IrContext::Instance(); | ||
| pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op.name()); | ||
| auto decomp_interface_impl = | ||
| op_info.GetInterfaceImpl<paddle::dialect::DecompInterface>(); | ||
| if (decomp_interface_impl == nullptr) return false; | ||
| return true; | ||
| } | ||
|
|
||
| bool DecompProgram::check_decomp_dynamic_shape(pir::Operation* op) { | ||
| for (auto item : op->operands()) { | ||
| auto value = item.source(); | ||
| // check if initialized in case of optional input. | ||
| if (value.impl() && value.type().storage()) { | ||
| pir::Operation* prev_op = value.dyn_cast<pir::OpResult>().owner(); | ||
| if (prev_op->name() == "builtin.combine") { | ||
| for (pir::OpOperand& sub_item : prev_op->operands()) { | ||
| if (check_dynamic_shape(sub_item, *op)) { | ||
| return true; | ||
| } | ||
| } | ||
| } else { | ||
| if (check_dynamic_shape(item, *op)) { | ||
| return true; | ||
| } | ||
| } | ||
| } | ||
| } | ||
| return false; | ||
| } | ||
|
|
||
| void DecompProgram::check_decomp_outputs( | ||
| const std::string& op_name, | ||
| const std::vector<pir::OpResult>& orig_outs, | ||
| const std::vector<pir::OpResult>& decomp_outs) { | ||
| for (size_t i = 0; i < orig_outs.size(); i++) { | ||
| auto orig_dtype = GetValueDtype(orig_outs[i]); | ||
| auto decomp_dtype = GetValueDtype(decomp_outs[i]); | ||
|
|
||
| PADDLE_ENFORCE( | ||
| orig_dtype == decomp_dtype, | ||
| paddle::platform::errors::PreconditionNotMet( | ||
| "[Prim] For op %s, its origin output dtype %s is not equal to " | ||
| "decomp output dtype %s ", | ||
| op_name, | ||
| orig_dtype, | ||
| decomp_dtype)); | ||
|
|
||
| auto orig_dim = GetValueDims(orig_outs[i]); | ||
| auto decomp_dim = GetValueDims(decomp_outs[i]); | ||
| std::vector<int64_t> shape = common::vectorize<int64_t>(orig_dim); | ||
| if (find_value(common::vectorize<int64_t>(orig_dim), -1)) { | ||
| LOG(WARNING) | ||
| << "[Prim] Decomp op does not support dynamic shape -1, but got " | ||
| "shape [" | ||
| << orig_dim << "] in output of origin op " << op_name; | ||
| } | ||
| if (find_value(common::vectorize<int64_t>(decomp_dim), -1)) { | ||
| LOG(WARNING) | ||
| << "[Prim] Decomp op does not support dynamic shape -1, but got " | ||
| "shape [" | ||
| << decomp_dim << "] in output of decomp op " << op_name; | ||
| } | ||
|
|
||
| PADDLE_ENFORCE( | ||
| orig_dim == decomp_dim, | ||
| paddle::platform::errors::PreconditionNotMet( | ||
| "[Prim] For op %s, its origin output shape [%s] is not equal to " | ||
| "decomp output shape [%s] ", | ||
| op_name, | ||
| orig_dim, | ||
| decomp_dim)); | ||
| } | ||
| return; | ||
| } | ||
|
|
||
| std::vector<pir::OpResult> DecompProgram::format_decomp_res( | ||
| const std::string& op_name, | ||
| const std::vector<pir::OpResult>& orig_outs, | ||
| const std::vector<std::vector<pir::OpResult>>& decomp_outs) { | ||
| PADDLE_ENFORCE_EQ( | ||
| orig_outs.size(), | ||
| decomp_outs.size(), | ||
| paddle::platform::errors::PreconditionNotMet( | ||
| "[Prim] For op %s, its origin output num %d is not equal to " | ||
| "decomp output num %d ", | ||
| op_name, | ||
| orig_outs.size(), | ||
| decomp_outs.size())); | ||
| std::vector<pir::OpResult> new_decomp_outs(orig_outs.size()); | ||
| for (size_t i = 0; i < orig_outs.size(); i++) { | ||
| if (orig_outs[i]) { | ||
| PADDLE_ENFORCE_EQ( | ||
| decomp_outs[i].size(), | ||
| 1, | ||
| paddle::platform::errors::PreconditionNotMet( | ||
| "[Prim] For op %s, each element of decomp output num must " | ||
| "be 1, but num of index %d is %d ", | ||
| op_name, | ||
| i, | ||
| decomp_outs[i].size())); | ||
| new_decomp_outs[i] = decomp_outs[i][0]; | ||
| } | ||
| } | ||
| return new_decomp_outs; | ||
| } | ||
|
|
||
| std::vector<pir::OpResult> DecompProgram::construct_dst_vars( | ||
| const std::string& op_name, | ||
| const std::vector<pir::OpResult>& orig_outs, | ||
| const std::vector<pir::OpResult>& decomp_outs, | ||
| std::unordered_map<pir::OpResult, int> orig_vars_dict) { | ||
| std::vector<pir::OpResult> tar_vars(src_vars_.size()); | ||
| PADDLE_ENFORCE_EQ( | ||
| orig_outs.size(), | ||
| decomp_outs.size(), | ||
| paddle::platform::errors::PreconditionNotMet( | ||
| "[Prim] For op %s, its origin output num %d is not equal to " | ||
| "decomp output num %d ", | ||
| op_name, | ||
| orig_outs.size(), | ||
| decomp_outs.size())); | ||
| for (size_t i = 0; i < orig_outs.size(); i++) { | ||
| if (orig_vars_dict.find(orig_outs[i]) != orig_vars_dict.end()) { | ||
| tar_vars[orig_vars_dict[orig_outs[i]]] = decomp_outs[i]; | ||
| } | ||
| } | ||
| return tar_vars; | ||
| } | ||
|
|
||
| bool DecompProgram::enable_decomp_by_filter(const std::string& op_name) { | ||
| bool flag = true; | ||
|
|
||
| if (whitelist_.size() > 0) { | ||
| if (whitelist_.find(op_name) == whitelist_.end()) { | ||
| flag = false; | ||
| } | ||
| } | ||
| if (blacklist_.size() > 0) { | ||
| if (blacklist_.find(op_name) != blacklist_.end()) { | ||
| flag = false; | ||
| } | ||
| } | ||
| return flag; | ||
| } | ||
|
|
||
| std::vector<std::vector<pir::OpResult>> call_decomp_rule(pir::Operation* op) { | ||
| paddle::dialect::DecompInterface decomp_interface = | ||
| op->dyn_cast<paddle::dialect::DecompInterface>(); | ||
| PADDLE_ENFORCE(decomp_interface, | ||
| phi::errors::InvalidArgument( | ||
| "[Prim] The decomp function is not registered in %s op ", | ||
| op->name())); | ||
| std::vector<std::vector<pir::OpResult>> decomp_res = | ||
| decomp_interface.Decomp(op); | ||
| return decomp_res; | ||
| } | ||
|
|
||
// Constructs a decomposition pass over `program`. `src_vars` are the
// values whose (possibly replaced) counterparts decomp_program() returns;
// `blacklist`/`whitelist` filter which ops are decomposed. The program
// pointer is stored, not owned.
DecompProgram::DecompProgram(pir::Program* program,
                             const std::vector<pir::OpResult>& src_vars,
                             const std::set<std::string>& blacklist,
                             const std::set<std::string>& whitelist)
    : program_(program),
      src_vars_(src_vars),
      blacklist_(blacklist),
      whitelist_(whitelist) {}
|
|
||
| std::vector<pir::OpResult> DecompProgram::decomp_program() { | ||
| std::ostringstream orig_prog_stream; | ||
| std::unordered_map<pir::OpResult, int> orig_vars_dict; | ||
| for (size_t i = 0; i < src_vars_.size(); i++) { | ||
| orig_vars_dict[src_vars_[i]] = static_cast<int>(i); | ||
| } | ||
| program_->Print(orig_prog_stream); | ||
| VLOG(4) << "[Prim] Origin program bofore decomp :\n" | ||
| << orig_prog_stream.str(); | ||
| if (!paddle::prim::PrimCommonUtils::IsFwdPrimEnabled()) { | ||
| return src_vars_; | ||
| } | ||
| std::vector<pir::OpResult> tar_vars(src_vars_.size()); | ||
| pir::Block* block = program_->block(); | ||
| std::vector<pir::Operation*> ops_list; | ||
| for (auto& op : *block) { | ||
| ops_list.push_back(&op); | ||
| } | ||
| for (size_t i = 0; i < ops_list.size(); i++) { | ||
| auto op = ops_list[i]; | ||
| bool enable_prim = | ||
| has_decomp_rule(*op) && enable_decomp_by_filter(op->name()); | ||
| if (enable_prim && FLAGS_prim_skip_dynamic && | ||
| check_decomp_dynamic_shape(op)) { | ||
| enable_prim = false; | ||
| } | ||
| if (enable_prim) { | ||
| VLOG(4) << "[Prim] decomp op name " << op->name(); | ||
| check_decomp_dynamic_shape(op); | ||
| auto& builder = *(paddle::dialect::ApiBuilder::Instance().GetBuilder()); | ||
| builder.set_insertion_point(op); | ||
| std::vector<std::vector<pir::OpResult>> decomp_res = call_decomp_rule(op); | ||
| std::vector<pir::OpResult> orig_outs = op->results(); | ||
| std::vector<pir::OpResult> standard_decomp_res = | ||
| format_decomp_res(op->name(), orig_outs, decomp_res); | ||
| check_decomp_outputs(op->name(), orig_outs, standard_decomp_res); | ||
| tar_vars = construct_dst_vars( | ||
| op->name(), orig_outs, standard_decomp_res, orig_vars_dict); | ||
|
|
||
| op->ReplaceAllUsesWith(standard_decomp_res); | ||
| bool remove_op = true; | ||
| for (auto& item : op->results()) { | ||
| if (item.HasOneUse()) { | ||
| remove_op = false; | ||
| break; | ||
| } | ||
| } | ||
| if (remove_op) { | ||
| auto op_iter = std::find(block->begin(), block->end(), *op); | ||
| block->erase(op_iter); | ||
| } | ||
| } | ||
| } | ||
| for (size_t i = 0; i < tar_vars.size(); i++) { | ||
| if (!tar_vars[i]) { | ||
| tar_vars[i] = src_vars_[i]; | ||
| } | ||
| } | ||
| auto& builder = *(paddle::dialect::ApiBuilder::Instance().GetBuilder()); | ||
| builder.SetInsertionPointToBlockEnd(block); | ||
| std::ostringstream decomp_prog_stream; | ||
| program_->Print(decomp_prog_stream); | ||
| VLOG(4) << "[Prim] New program after decomp :\n" << decomp_prog_stream.str(); | ||
| return tar_vars; | ||
| } | ||
|
|
||
| } // namespace paddle | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,61 @@ | ||
| // Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <memory> | ||
|
|
||
| #include "paddle/fluid/framework/program_desc.h" | ||
| #include "paddle/fluid/pir/dialect/operator/interface/decomp.h" | ||
| #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" | ||
| #include "paddle/pir/core/block.h" | ||
| #include "paddle/pir/core/program.h" | ||
|
|
||
| namespace paddle { | ||
|
|
||
| class DecompProgram { | ||
| public: | ||
| DecompProgram(pir::Program* program, | ||
| const std::vector<pir::OpResult>& src_vars, | ||
| const std::set<std::string>& blacklist, | ||
| const std::set<std::string>& whitelist); | ||
|
|
||
| std::vector<pir::OpResult> decomp_program(); | ||
| bool check_decomp_dynamic_shape(pir::Operation* op); | ||
| void check_decomp_outputs(const std::string& op_name, | ||
| const std::vector<pir::OpResult>& orig_outs, | ||
| const std::vector<pir::OpResult>& decomp_outs); | ||
| std::vector<pir::OpResult> format_decomp_res( | ||
| const std::string& op_name, | ||
| const std::vector<pir::OpResult>& orig_outs, | ||
| const std::vector<std::vector<pir::OpResult>>& decomp_outs); | ||
| std::vector<pir::OpResult> construct_dst_vars( | ||
| const std::string& op_name, | ||
| const std::vector<pir::OpResult>& orig_outs, | ||
| const std::vector<pir::OpResult>& decomp_outs, | ||
| std::unordered_map<pir::OpResult, int> orig_vars_dict); | ||
| bool enable_decomp_by_filter(const std::string& op_name); | ||
|
|
||
| private: | ||
| pir::Program* program_; | ||
| std::vector<pir::OpResult> src_vars_; | ||
| std::set<std::string> blacklist_; | ||
| std::set<std::string> whitelist_; | ||
| }; | ||
|
|
||
// True iff `op` has a registered DecompInterface implementation.
bool has_decomp_rule(const pir::Operation& op);
|
|
||
// Invokes the decomposition rule of `op`; one result vector per original
// output. Throws when no rule is registered.
std::vector<std::vector<pir::OpResult>> call_decomp_rule(pir::Operation* op);
|
|
||
| } // namespace paddle |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done