Skip to content

Commit beba53f

Browse files
committed
Merge commit 'refs/pull/59448/head' of github.com:PaddlePaddle/Paddle into develop
2 parents 50afa0b + 4b8abc4 commit beba53f

10 files changed

Lines changed: 398 additions & 22 deletions

File tree

paddle/fluid/pir/dialect/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,8 @@ set(op_dialect_vjp_srcs
162162
${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/manual_op_vjp.cc
163163
${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/op_dialect.cc
164164
${op_decomp_source_file}
165-
${op_vjp_source_file})
165+
${op_vjp_source_file}
166+
${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/base/decomp_trans.cc)
166167
set(op_dialect_vjp_deps primitive_vjp_experimental op_dialect)
167168

168169
cc_library(
Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
// #include <sstream>
16+
// #include <string>
17+
18+
#include "paddle/fluid/primitive/base/decomp_trans.h"
19+
#include "paddle/fluid/pir/dialect/operator/ir/api_builder.h"
20+
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
21+
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
22+
#include "paddle/fluid/prim/utils/utils.h"
23+
#include "paddle/pir/core/builtin_dialect.h"
24+
#include "paddle/pir/core/program.h"
25+
26+
PHI_DECLARE_bool(prim_skip_dynamic);
27+
28+
using paddle::dialect::DenseTensorType;
29+
using paddle::dialect::SelectedRowsType;
30+
31+
namespace paddle {
32+
33+
using Program = pir::Program;
34+
35+
// Returns true when `value` occurs anywhere in `vec`.
static bool find_value(const std::vector<int64_t>& vec, int64_t value) {
  const auto hit = std::find(vec.begin(), vec.end(), value);
  return hit != vec.end();
}
42+
43+
// Returns the dims recorded in the type of `value`.
//
// Two value types are handled: DenseTensorType and SelectedRowsType. For any
// other type the function fails through PADDLE_THROW with an InvalidArgument
// error.
static const phi::DDim& GetValueDims(pir::Value value) {
  if (value.type().isa<DenseTensorType>()) {
    return value.type().dyn_cast<DenseTensorType>().dims();
  } else if (value.type().isa<SelectedRowsType>()) {
    return value.type().dyn_cast<SelectedRowsType>().dims();
  } else {
    // The message names both supported types so it matches the branches
    // above (the previous text mentioned dense tensors only).
    PADDLE_THROW(phi::errors::InvalidArgument(
        "Currently, we can only get shape for dense "
        "tensor or selected rows."));
  }
}
54+
55+
// Returns true when the value feeding `item` has a -1 entry in its dims.
//
// `op` is only used to name the consuming operation in the warning that is
// logged when such a dynamic dim is found.
static bool check_dynamic_shape(const pir::OpOperand& item,
                                const pir::Operation& op) {
  auto dims = GetValueDims(item.source());
  const std::vector<int64_t> dim_values = common::vectorize<int64_t>(dims);
  if (!find_value(dim_values, -1)) {
    return false;
  }
  LOG(WARNING)
      << "[Prim] Decomp op does not support dynamic shape -1, but got "
         "shape ["
      << dims << "] in inputs of op " << op.name();
  return true;
}
69+
70+
// Reports whether the op info registered under `op.name()` in the IrContext
// returned by pir::IrContext::Instance() provides a DecompInterface
// implementation.
bool has_decomp_rule(const pir::Operation& op) {
  pir::IrContext* context = pir::IrContext::Instance();
  pir::OpInfo registered_info = context->GetRegisteredOpInfo(op.name());
  return registered_info.GetInterfaceImpl<paddle::dialect::DecompInterface>() !=
         nullptr;
}
78+
79+
// Returns true when any initialized input of `op` has a dynamic (-1) dim.
//
// An input produced by a "builtin.combine" op is inspected through the
// operands of that combine op; every other input is inspected directly.
// Inputs whose value or type is not initialized (optional inputs) are
// skipped.
bool DecompProgram::check_decomp_dynamic_shape(pir::Operation* op) {
  for (auto item : op->operands()) {
    auto value = item.source();
    // check if initialized in case of optional input.
    if (!value.impl() || !value.type().storage()) {
      continue;
    }
    // The operand may not be the result of an operation; in that case the
    // cast is expected to give an empty OpResult, whose owner() must not be
    // queried. NOTE(review): confirm dyn_cast returns an empty handle for
    // non-result values.
    pir::OpResult producer_result = value.dyn_cast<pir::OpResult>();
    pir::Operation* prev_op =
        producer_result ? producer_result.owner() : nullptr;
    if (prev_op != nullptr && prev_op->name() == "builtin.combine") {
      for (pir::OpOperand& sub_item : prev_op->operands()) {
        if (check_dynamic_shape(sub_item, *op)) {
          return true;
        }
      }
    } else if (check_dynamic_shape(item, *op)) {
      return true;
    }
  }
  return false;
}
103+
104+
void DecompProgram::check_decomp_outputs(
105+
const std::string& op_name,
106+
const std::vector<pir::OpResult>& orig_outs,
107+
const std::vector<pir::OpResult>& decomp_outs) {
108+
return;
109+
}
110+
111+
std::vector<pir::OpResult> DecompProgram::format_decomp_res(
112+
const std::string& op_name,
113+
const std::vector<pir::OpResult>& orig_outs,
114+
const std::vector<std::vector<pir::OpResult>>& decomp_outs) {
115+
PADDLE_ENFORCE_EQ(orig_outs.size(),
116+
decomp_outs.size(),
117+
paddle::platform::errors::PreconditionNotMet(
118+
"For op %s, its origin output num %d is not equal to "
119+
"decomp output num %d ",
120+
op_name,
121+
orig_outs.size(),
122+
decomp_outs.size()));
123+
std::vector<pir::OpResult> new_decomp_outs(orig_outs.size());
124+
for (size_t i = 0; i < orig_outs.size(); i++) {
125+
if (orig_outs[i]) {
126+
PADDLE_ENFORCE_EQ(decomp_outs[i].size(),
127+
1,
128+
paddle::platform::errors::PreconditionNotMet(
129+
"For op %s, each element of decomp output num must "
130+
"be 1, but num of index %d is %d ",
131+
op_name,
132+
i,
133+
decomp_outs[i].size()));
134+
new_decomp_outs[i] = decomp_outs[i][0];
135+
}
136+
}
137+
return new_decomp_outs;
138+
}
139+
140+
std::vector<pir::OpResult> DecompProgram::construct_dst_vars(
141+
const std::string& op_name,
142+
const std::vector<pir::OpResult>& orig_outs,
143+
const std::vector<pir::OpResult>& decomp_outs,
144+
std::unordered_map<pir::OpResult, int> orig_vars_dict) {
145+
std::vector<pir::OpResult> tar_vars(src_vars_.size());
146+
PADDLE_ENFORCE_EQ(orig_outs.size(),
147+
decomp_outs.size(),
148+
paddle::platform::errors::PreconditionNotMet(
149+
"For op %s, its origin output num %d is not equal to "
150+
"decomp output num %d ",
151+
op_name,
152+
orig_outs.size(),
153+
decomp_outs.size()));
154+
for (size_t i = 0; i < orig_outs.size(); i++) {
155+
VLOG(4) << "decomp construct idx -------- " << i;
156+
if (orig_vars_dict.find(orig_outs[i]) != orig_vars_dict.end()) {
157+
VLOG(4) << "decomp construct in idx -------- " << i;
158+
tar_vars[orig_vars_dict[orig_outs[i]]] = decomp_outs[i];
159+
}
160+
}
161+
return tar_vars;
162+
}
163+
164+
bool DecompProgram::enable_decomp_by_filter(const std::string& op_name) {
165+
bool flag = true;
166+
167+
if (whitelist_.size() > 0) {
168+
if (whitelist_.find(op_name) == whitelist_.end()) {
169+
flag = false;
170+
}
171+
}
172+
if (blacklist_.size() > 0) {
173+
if (blacklist_.find(op_name) != blacklist_.end()) {
174+
flag = false;
175+
}
176+
}
177+
return flag;
178+
}
179+
180+
// Invokes the decomp rule of `op` through its DecompInterface and returns the
// nested results. Fails with an InvalidArgument error when casting `op` to
// DecompInterface does not yield a usable interface.
std::vector<std::vector<pir::OpResult>> call_decomp_rule(pir::Operation* op) {
  paddle::dialect::DecompInterface iface =
      op->dyn_cast<paddle::dialect::DecompInterface>();
  PADDLE_ENFORCE(
      iface,
      phi::errors::InvalidArgument(
          "The decomp function is not registered in %s op ", op->name()));
  return iface.Decomp(op);
}
191+
192+
// Stores the program pointer and copies the source variables and both filter
// sets; no decomposition work happens at construction time.
DecompProgram::DecompProgram(pir::Program* program,
                             const std::vector<pir::OpResult>& src_vars,
                             const std::set<std::string>& blacklist,
                             const std::set<std::string>& whitelist)
    : program_(program),
      src_vars_(src_vars),
      blacklist_(blacklist),
      whitelist_(whitelist) {
}
200+
201+
std::vector<pir::OpResult> DecompProgram::decomp_program() {
202+
std::ostringstream print_stream;
203+
std::unordered_map<pir::OpResult, int> orig_vars_dict;
204+
for (size_t i = 0; i < src_vars_.size(); i++) {
205+
orig_vars_dict[src_vars_[i]] = static_cast<int>(i);
206+
}
207+
program_->Print(print_stream);
208+
VLOG(4) << "program in sink decomp ------" << print_stream.str();
209+
if (!paddle::prim::PrimCommonUtils::IsFwdPrimEnabled()) {
210+
return src_vars_;
211+
}
212+
std::vector<pir::OpResult> tar_vars(src_vars_.size());
213+
pir::Block* block = program_->block();
214+
std::vector<pir::Operation*> ops_list;
215+
for (auto& op : *block) {
216+
ops_list.push_back(&op);
217+
}
218+
for (size_t i = 0; i < ops_list.size(); i++) {
219+
auto op = ops_list[i];
220+
bool enable_prim =
221+
has_decomp_rule(*op) && enable_decomp_by_filter(op->name());
222+
if (enable_prim && FLAGS_prim_skip_dynamic &&
223+
check_decomp_dynamic_shape(op)) {
224+
enable_prim = false;
225+
}
226+
VLOG(4) << "enable_prim flag ======= " << enable_prim;
227+
if (enable_prim) {
228+
VLOG(4) << "decomp op name ======= " << op->name();
229+
230+
auto& builder = *(paddle::dialect::ApiBuilder::Instance().GetBuilder());
231+
builder.set_insertion_point(op);
232+
std::vector<std::vector<pir::OpResult>> decomp_res = call_decomp_rule(op);
233+
std::vector<pir::OpResult> orig_outs = op->results();
234+
std::vector<pir::OpResult> standard_decomp_res =
235+
format_decomp_res(op->name(), orig_outs, decomp_res);
236+
tar_vars = construct_dst_vars(
237+
op->name(), orig_outs, standard_decomp_res, orig_vars_dict);
238+
239+
VLOG(4) << "decomp out size ======= " << decomp_res.size();
240+
op->ReplaceAllUsesWith(standard_decomp_res);
241+
std::ostringstream print_stream2;
242+
program_->Print(print_stream2);
243+
VLOG(4) << "program out sink decomp ------ index " << i << ". "
244+
<< print_stream2.str();
245+
bool remove_op = true;
246+
for (auto& item : op->results()) {
247+
if (item.HasOneUse()) {
248+
remove_op = false;
249+
break;
250+
}
251+
}
252+
VLOG(4) << "program remove op ----------- " << remove_op << ". "
253+
<< op->name();
254+
if (remove_op) {
255+
auto op_iter = std::find(block->begin(), block->end(), *op);
256+
block->erase(op_iter);
257+
}
258+
}
259+
}
260+
for (size_t i = 0; i < tar_vars.size(); i++) {
261+
if (!tar_vars[i]) {
262+
VLOG(4) << "assign tar_vars =========== " << i;
263+
tar_vars[i] = src_vars_[i];
264+
}
265+
}
266+
auto& builder = *(paddle::dialect::ApiBuilder::Instance().GetBuilder());
267+
builder.SetInsertionPointToEnd(block);
268+
std::ostringstream print_stream3;
269+
program_->Print(print_stream3);
270+
VLOG(4) << "program out final ************ " << print_stream3.str();
271+
return tar_vars;
272+
}
273+
274+
} // namespace paddle
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <memory>
18+
19+
#include "paddle/fluid/framework/program_desc.h"
20+
#include "paddle/fluid/pir/dialect/operator/interface/decomp.h"
21+
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
22+
#include "paddle/pir/core/block.h"
23+
#include "paddle/pir/core/program.h"
24+
25+
namespace paddle {
26+
27+
class DecompProgram {
28+
public:
29+
DecompProgram(pir::Program* program,
30+
const std::vector<pir::OpResult>& src_vars,
31+
const std::set<std::string>& blacklist,
32+
const std::set<std::string>& whitelist);
33+
34+
std::vector<pir::OpResult> decomp_program();
35+
bool check_decomp_dynamic_shape(pir::Operation* op);
36+
void check_decomp_outputs(const std::string& op_name,
37+
const std::vector<pir::OpResult>& orig_outs,
38+
const std::vector<pir::OpResult>& decomp_outs);
39+
std::vector<pir::OpResult> format_decomp_res(
40+
const std::string& op_name,
41+
const std::vector<pir::OpResult>& orig_outs,
42+
const std::vector<std::vector<pir::OpResult>>& decomp_outs);
43+
std::vector<pir::OpResult> construct_dst_vars(
44+
const std::string& op_name,
45+
const std::vector<pir::OpResult>& orig_outs,
46+
const std::vector<pir::OpResult>& decomp_outs,
47+
std::unordered_map<pir::OpResult, int> orig_vars_dict);
48+
bool enable_decomp_by_filter(const std::string& op_name);
49+
50+
private:
51+
pir::Program* program_;
52+
std::vector<pir::OpResult> src_vars_;
53+
std::set<std::string> blacklist_;
54+
std::set<std::string> whitelist_;
55+
};
56+
57+
// Returns true when a DecompInterface implementation is registered for the
// op type of `op`.
bool has_decomp_rule(const pir::Operation& op);
58+
59+
// Calls the decomp rule of `op` and returns its results as one inner vector
// per output of the rule.
std::vector<std::vector<pir::OpResult>> call_decomp_rule(pir::Operation* op);
60+
61+
} // namespace paddle

0 commit comments

Comments
 (0)