Merged
Commits (36)
638d847  sink decomp 0 (cyber-pioneer, Nov 27, 2023)
3f71ea9  add log (cyber-pioneer, Nov 27, 2023)
e8df014  sink decomp 1 (cyber-pioneer, Nov 27, 2023)
d6220ef  add cmake file (cyber-pioneer, Nov 28, 2023)
f02a4e7  support whole framework (cyber-pioneer, Nov 28, 2023)
08503d9  move call_decomp_rule (cyber-pioneer, Nov 28, 2023)
866ea16  fix code (cyber-pioneer, Nov 28, 2023)
f8b7b53  remove op (cyber-pioneer, Nov 28, 2023)
65f175f  fix runtime bug (cyber-pioneer, Nov 28, 2023)
a6b170c  support prim flag (cyber-pioneer, Nov 29, 2023)
372d5d2  fix checkout output (cyber-pioneer, Nov 29, 2023)
38ee0fa  support recover tar_vars (cyber-pioneer, Dec 1, 2023)
dfe7e7d  add blacklist (cyber-pioneer, Dec 1, 2023)
3b3c7ae  add blacklist and whitelist (cyber-pioneer, Dec 1, 2023)
4276f06  replace origin decomp (cyber-pioneer, Dec 6, 2023)
0b5fba4  remove const (cyber-pioneer, Dec 6, 2023)
10ac3a3  Merge branch 'develop' of github.com:PaddlePaddle/Paddle into sink (cyber-pioneer, Dec 7, 2023)
82d0db2  fix blacklist (cyber-pioneer, Dec 8, 2023)
96fd50e  Merge branch 'develop' of github.com:PaddlePaddle/Paddle into sink (cyber-pioneer, Dec 8, 2023)
3c56a7f  add check dynamic shape (cyber-pioneer, Dec 8, 2023)
fa72dac  dynamic flag come into effect (cyber-pioneer, Dec 8, 2023)
cc2765c  test case change flag (cyber-pioneer, Dec 10, 2023)
4b8abc4  Merge branch 'develop' of github.com:PaddlePaddle/Paddle into sink (cyber-pioneer, Dec 10, 2023)
7ccbf2c  Merge branch 'develop' of github.com:PaddlePaddle/Paddle into sink (cyber-pioneer, Dec 13, 2023)
b0fb895  add decomp sink guard (cyber-pioneer, Dec 13, 2023)
1ea5cff  polish code (cyber-pioneer, Dec 13, 2023)
7b9ba74  polish code (cyber-pioneer, Dec 13, 2023)
2755091  polish code (cyber-pioneer, Dec 13, 2023)
5446331  fix bug (cyber-pioneer, Dec 13, 2023)
9794b93  add decomp output check (cyber-pioneer, Dec 13, 2023)
e2c2e96  add error log (cyber-pioneer, Dec 13, 2023)
d18f57c  add log (cyber-pioneer, Dec 13, 2023)
95b8d4c  polish code (cyber-pioneer, Dec 13, 2023)
5aee5e4  remove test case (cyber-pioneer, Dec 13, 2023)
07a2834  polish code (cyber-pioneer, Dec 14, 2023)
8298074  add blacklist test case (cyber-pioneer, Dec 14, 2023)
3 changes: 2 additions & 1 deletion paddle/fluid/pir/dialect/CMakeLists.txt
@@ -162,7 +162,8 @@ set(op_dialect_vjp_srcs
${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/manual_op_vjp.cc
${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/op_dialect.cc
${op_decomp_source_file}
${op_vjp_source_file})
${op_vjp_source_file}
${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/base/decomp_trans.cc)
set(op_dialect_vjp_deps primitive_vjp_experimental op_dialect)

cc_library(
315 changes: 315 additions & 0 deletions paddle/fluid/primitive/base/decomp_trans.cc
@@ -0,0 +1,315 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/primitive/base/decomp_trans.h"
#include "paddle/fluid/pir/dialect/operator/ir/api_builder.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
#include "paddle/fluid/prim/utils/utils.h"
#include "paddle/pir/core/builtin_dialect.h"
#include "paddle/pir/core/program.h"

PHI_DECLARE_bool(prim_skip_dynamic);

using paddle::dialect::DenseTensorType;
using paddle::dialect::SelectedRowsType;

namespace paddle {

using Program = pir::Program;

static bool find_value(const std::vector<int64_t>& vec, int64_t value) {
return std::find(vec.begin(), vec.end(), value) != vec.end();
}

static const phi::DDim& GetValueDims(pir::Value value) {
if (value.type().isa<DenseTensorType>()) {
return value.type().dyn_cast<DenseTensorType>().dims();
} else if (value.type().isa<SelectedRowsType>()) {
return value.type().dyn_cast<SelectedRowsType>().dims();
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"[Prim] Currently, we can only get shape from DenseTensorType and "
"SelectedRowsType."));
}
}

static phi::DataType GetValueDtype(pir::Value value) {
if (value.type().isa<DenseTensorType>()) {
return paddle::dialect::TransToPhiDataType(
value.type().dyn_cast<DenseTensorType>().dtype());
} else if (value.type().isa<SelectedRowsType>()) {
return paddle::dialect::TransToPhiDataType(
value.type().dyn_cast<SelectedRowsType>().dtype());
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Currently, we can only get phi::DataType from DenseTensorType and "
"SelectedRowsType."));
}
}

static bool check_dynamic_shape(const pir::OpOperand& item,
const pir::Operation& op) {
auto dims = GetValueDims(item.source());
std::vector<int64_t> shape = common::vectorize<int64_t>(dims);
if (find_value(shape, -1)) {
LOG(WARNING)
<< "[Prim] Decomp op does not support dynamic shape -1, but got "
"shape ["
<< dims << "] in inputs of op " << op.name();
return true;
} else {
return false;
}
}

bool has_decomp_rule(const pir::Operation& op) {
pir::IrContext* ctx = pir::IrContext::Instance();
pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op.name());
auto decomp_interface_impl =
op_info.GetInterfaceImpl<paddle::dialect::DecompInterface>();
return decomp_interface_impl != nullptr;
}

bool DecompProgram::check_decomp_dynamic_shape(pir::Operation* op) {
for (auto item : op->operands()) {
auto value = item.source();
// check if initialized in case of optional input.
if (value.impl() && value.type().storage()) {
[Review comment, Contributor] Suggested change:
-      if (value.impl() && value.type().storage()) {
+      if (IsEmptyValue(value)) {
[Reply, Contributor Author] done

pir::Operation* prev_op = value.dyn_cast<pir::OpResult>().owner();
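// builtin.combine packs several tensors into one vector-typed operand;
// inspect the shape of each packed input rather than the combined value.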
if (prev_op->name() == "builtin.combine") {
for (pir::OpOperand& sub_item : prev_op->operands()) {
if (check_dynamic_shape(sub_item, *op)) {
return true;
}
}
} else {
if (check_dynamic_shape(item, *op)) {
return true;
}
}
}
}
return false;
}

void DecompProgram::check_decomp_outputs(
const std::string& op_name,
const std::vector<pir::OpResult>& orig_outs,
const std::vector<pir::OpResult>& decomp_outs) {
for (size_t i = 0; i < orig_outs.size(); i++) {
auto orig_dtype = GetValueDtype(orig_outs[i]);
auto decomp_dtype = GetValueDtype(decomp_outs[i]);

PADDLE_ENFORCE(
orig_dtype == decomp_dtype,
paddle::platform::errors::PreconditionNotMet(
"[Prim] For op %s, its origin output dtype %s is not equal to "
"decomp output dtype %s ",
op_name,
orig_dtype,
decomp_dtype));

auto orig_dim = GetValueDims(orig_outs[i]);
auto decomp_dim = GetValueDims(decomp_outs[i]);
if (find_value(common::vectorize<int64_t>(orig_dim), -1)) {
LOG(WARNING)
<< "[Prim] Decomp op does not support dynamic shape -1, but got "
"shape ["
<< orig_dim << "] in output of origin op " << op_name;
}
if (find_value(common::vectorize<int64_t>(decomp_dim), -1)) {
LOG(WARNING)
<< "[Prim] Decomp op does not support dynamic shape -1, but got "
"shape ["
<< decomp_dim << "] in output of decomp op " << op_name;
}

PADDLE_ENFORCE(
orig_dim == decomp_dim,
paddle::platform::errors::PreconditionNotMet(
"[Prim] For op %s, its origin output shape [%s] is not equal to "
"decomp output shape [%s] ",
op_name,
orig_dim,
decomp_dim));
}
return;
}

std::vector<pir::OpResult> DecompProgram::format_decomp_res(
const std::string& op_name,
const std::vector<pir::OpResult>& orig_outs,
const std::vector<std::vector<pir::OpResult>>& decomp_outs) {
PADDLE_ENFORCE_EQ(
orig_outs.size(),
decomp_outs.size(),
paddle::platform::errors::PreconditionNotMet(
"[Prim] For op %s, its origin output num %d is not equal to "
"decomp output num %d ",
op_name,
orig_outs.size(),
decomp_outs.size()));
std::vector<pir::OpResult> new_decomp_outs(orig_outs.size());
for (size_t i = 0; i < orig_outs.size(); i++) {
if (orig_outs[i]) {
PADDLE_ENFORCE_EQ(
decomp_outs[i].size(),
1,
paddle::platform::errors::PreconditionNotMet(
"[Prim] For op %s, each element of decomp output num must "
"be 1, but num of index %d is %d ",
op_name,
i,
decomp_outs[i].size()));
new_decomp_outs[i] = decomp_outs[i][0];
}
}
return new_decomp_outs;
}

std::vector<pir::OpResult> DecompProgram::construct_dst_vars(
const std::string& op_name,
const std::vector<pir::OpResult>& orig_outs,
const std::vector<pir::OpResult>& decomp_outs,
std::unordered_map<pir::OpResult, int> orig_vars_dict) {
std::vector<pir::OpResult> tar_vars(src_vars_.size());
PADDLE_ENFORCE_EQ(
orig_outs.size(),
decomp_outs.size(),
paddle::platform::errors::PreconditionNotMet(
"[Prim] For op %s, its origin output num %d is not equal to "
"decomp output num %d ",
op_name,
orig_outs.size(),
decomp_outs.size()));
for (size_t i = 0; i < orig_outs.size(); i++) {
if (orig_vars_dict.find(orig_outs[i]) != orig_vars_dict.end()) {
tar_vars[orig_vars_dict[orig_outs[i]]] = decomp_outs[i];
}
}
return tar_vars;
}

bool DecompProgram::enable_decomp_by_filter(const std::string& op_name) {
bool flag = true;

if (whitelist_.size() > 0) {
if (whitelist_.find(op_name) == whitelist_.end()) {
flag = false;
}
}
if (blacklist_.size() > 0) {
if (blacklist_.find(op_name) != blacklist_.end()) {
flag = false;
}
}
return flag;
}

std::vector<std::vector<pir::OpResult>> call_decomp_rule(pir::Operation* op) {
paddle::dialect::DecompInterface decomp_interface =
op->dyn_cast<paddle::dialect::DecompInterface>();
PADDLE_ENFORCE(decomp_interface,
phi::errors::InvalidArgument(
"[Prim] The decomp function is not registered in %s op ",
op->name()));
std::vector<std::vector<pir::OpResult>> decomp_res =
decomp_interface.Decomp(op);
return decomp_res;
}

DecompProgram::DecompProgram(pir::Program* program,
const std::vector<pir::OpResult>& src_vars,
const std::set<std::string>& blacklist,
const std::set<std::string>& whitelist)
: program_(program),
src_vars_(src_vars),
blacklist_(blacklist),
whitelist_(whitelist) {}

std::vector<pir::OpResult> DecompProgram::decomp_program() {
std::ostringstream orig_prog_stream;
std::unordered_map<pir::OpResult, int> orig_vars_dict;
for (size_t i = 0; i < src_vars_.size(); i++) {
orig_vars_dict[src_vars_[i]] = static_cast<int>(i);
}
program_->Print(orig_prog_stream);
VLOG(4) << "[Prim] Origin program before decomp :\n"
<< orig_prog_stream.str();
if (!paddle::prim::PrimCommonUtils::IsFwdPrimEnabled()) {
return src_vars_;
}
std::vector<pir::OpResult> tar_vars(src_vars_.size());
pir::Block* block = program_->block();
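// Snapshot the block's ops before rewriting: decomposition inserts new ops
// and erases replaced ones, so iterating the block directly while mutating
// it would be unsafe.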
std::vector<pir::Operation*> ops_list;
for (auto& op : *block) {
ops_list.push_back(&op);
}
for (size_t i = 0; i < ops_list.size(); i++) {
auto op = ops_list[i];
bool enable_prim =
has_decomp_rule(*op) && enable_decomp_by_filter(op->name());
if (enable_prim && FLAGS_prim_skip_dynamic &&
check_decomp_dynamic_shape(op)) {
enable_prim = false;
}
if (enable_prim) {
VLOG(4) << "[Prim] decomp op name " << op->name();
check_decomp_dynamic_shape(op);
auto& builder = *(paddle::dialect::ApiBuilder::Instance().GetBuilder());
builder.set_insertion_point(op);
std::vector<std::vector<pir::OpResult>> decomp_res = call_decomp_rule(op);
std::vector<pir::OpResult> orig_outs = op->results();
std::vector<pir::OpResult> standard_decomp_res =
format_decomp_res(op->name(), orig_outs, decomp_res);
check_decomp_outputs(op->name(), orig_outs, standard_decomp_res);
tar_vars = construct_dst_vars(
op->name(), orig_outs, standard_decomp_res, orig_vars_dict);

op->ReplaceAllUsesWith(standard_decomp_res);
bool remove_op = true;
for (auto& item : op->results()) {
if (item.HasOneUse()) {
remove_op = false;
break;
}
}
if (remove_op) {
auto op_iter = std::find(block->begin(), block->end(), *op);
block->erase(op_iter);
}
}
}
for (size_t i = 0; i < tar_vars.size(); i++) {
if (!tar_vars[i]) {
tar_vars[i] = src_vars_[i];
}
}
auto& builder = *(paddle::dialect::ApiBuilder::Instance().GetBuilder());
builder.SetInsertionPointToBlockEnd(block);
std::ostringstream decomp_prog_stream;
program_->Print(decomp_prog_stream);
VLOG(4) << "[Prim] New program after decomp :\n" << decomp_prog_stream.str();
return tar_vars;
}

} // namespace paddle
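
On the filter semantics above: an empty whitelist allows every op that has a decomp rule, a non-empty whitelist restricts decomposition to the listed ops, and the blacklist always wins over the whitelist. A minimal standalone sketch of the same rule follows (plain C++; the op names are hypothetical examples and this helper is an illustration, not code from this PR):

#include <iostream>
#include <set>
#include <string>

// Mirrors the logic of DecompProgram::enable_decomp_by_filter above.
bool enable_decomp_by_filter(const std::string& op_name,
                             const std::set<std::string>& whitelist,
                             const std::set<std::string>& blacklist) {
  if (!whitelist.empty() && whitelist.count(op_name) == 0) return false;
  if (blacklist.count(op_name) > 0) return false;
  return true;
}

int main() {
  std::set<std::string> whitelist = {"pd_op.mean", "pd_op.gelu"};
  std::set<std::string> blacklist = {"pd_op.gelu"};
  std::cout << enable_decomp_by_filter("pd_op.mean", whitelist, blacklist)   // 1
            << enable_decomp_by_filter("pd_op.gelu", whitelist, blacklist)   // 0: blacklist wins
            << enable_decomp_by_filter("pd_op.relu", whitelist, blacklist);  // 0: not whitelisted
  return 0;
}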
61 changes: 61 additions & 0 deletions paddle/fluid/primitive/base/decomp_trans.h
@@ -0,0 +1,61 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>

#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/pir/dialect/operator/interface/decomp.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/pir/core/block.h"
#include "paddle/pir/core/program.h"

namespace paddle {

class DecompProgram {
public:
DecompProgram(pir::Program* program,
const std::vector<pir::OpResult>& src_vars,
const std::set<std::string>& blacklist,
const std::set<std::string>& whitelist);

std::vector<pir::OpResult> decomp_program();
bool check_decomp_dynamic_shape(pir::Operation* op);
void check_decomp_outputs(const std::string& op_name,
const std::vector<pir::OpResult>& orig_outs,
const std::vector<pir::OpResult>& decomp_outs);
std::vector<pir::OpResult> format_decomp_res(
const std::string& op_name,
const std::vector<pir::OpResult>& orig_outs,
const std::vector<std::vector<pir::OpResult>>& decomp_outs);
std::vector<pir::OpResult> construct_dst_vars(
const std::string& op_name,
const std::vector<pir::OpResult>& orig_outs,
const std::vector<pir::OpResult>& decomp_outs,
std::unordered_map<pir::OpResult, int> orig_vars_dict);
bool enable_decomp_by_filter(const std::string& op_name);

private:
pir::Program* program_;
std::vector<pir::OpResult> src_vars_;
std::set<std::string> blacklist_;
std::set<std::string> whitelist_;
};

bool has_decomp_rule(const pir::Operation& op);

std::vector<std::vector<pir::OpResult>> call_decomp_rule(pir::Operation* op);

} // namespace paddle
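
A rough sketch of how a caller might drive this pass, assuming a pir::Program and its fetch values already exist (the RunDecompPass wrapper and its variable names are hypothetical; only the DecompProgram API comes from this PR):

#include <set>
#include <string>
#include <vector>

#include "paddle/fluid/primitive/base/decomp_trans.h"

// Hypothetical driver; `program` and `fetch_vars` are assumed to come from
// an earlier program-construction step.
std::vector<pir::OpResult> RunDecompPass(
    pir::Program* program, const std::vector<pir::OpResult>& fetch_vars) {
  std::set<std::string> blacklist;  // ops to keep un-decomposed
  std::set<std::string> whitelist;  // empty: every op with a decomp rule is eligible
  paddle::DecompProgram decomp(program, fetch_vars, blacklist, whitelist);
  // Rewrites the program in place and returns the fetch list rebased onto
  // decomposed outputs; values whose producing op was skipped (or the whole
  // pass, when the forward prim flag is off) are passed through unchanged.
  return decomp.decomp_program();
}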