Skip to content

Commit 73750d9

Browse files
authored
rm convertToSSA API,test=huawei_ascend_npu test=nvidia_tensorrt test=verisilicon_timvx (#8988) (#9233)
1 parent 8f379b4 commit 73750d9

5 files changed

Lines changed: 192 additions & 23 deletions

File tree

CMakeLists.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,6 @@ lite_option(LITE_WITH_XCODE "when debug in xcode, its ON."
8888
lite_option(LITE_WITH_ARM82_FP16 "when compile with arm v8.2 fp16, it's ON." OFF)
8989
lite_option(LITE_WITH_ARM82_INT8_SDOT "when compile with arm v8.2 int8, it's ON." OFF)
9090
lite_option(LITE_WITH_CODE_META_INFO "include git version in the header file." ON)
91-
# whether convert input model which is not a DAG to SSA graph
92-
lite_option(WITH_CONVERT_TO_SSA "whether convert input model which is not a DAG to SSA graph" ON)
9391

9492
# Thirdparty
9593
set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING

cmake/configure.cmake

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,3 @@ if (LITE_WITH_M1)
301301
add_definitions("-DLITE_WITH_M1")
302302
endif(LITE_WITH_M1)
303303

304-
if (WITH_CONVERT_TO_SSA STREQUAL ON)
305-
add_definitions("-DWITH_CONVERT_TO_SSA")
306-
endif(WITH_CONVERT_TO_SSA)

lite/core/optimizer/mir/type_target_cast_pass.cc

Lines changed: 173 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ void TypeTargetTransformPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
3939

4040
// record the copied node.
4141
std::map<std::string, Node*> copied_nodes;
42+
// record the origin node.
43+
std::map<std::string, Node*> input_nodes;
4244
std::vector<std::string> skip_ops = {
4345
"while", "conditional_block", "write_back"};
4446

@@ -48,8 +50,14 @@ void TypeTargetTransformPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
4850
if (!node->IsStmt() || iter != skip_ops.end()) continue;
4951
auto inlinks = node->inlinks;
5052
for (auto* in : inlinks) {
53+
if (!input_nodes.count(in->AsArg().name))
54+
input_nodes[in->AsArg().name] = in;
5155
ComplementInputs(graph.get(), node, in, &copied_nodes);
5256
}
57+
auto outlinks = node->outlinks;
58+
for (auto* out : outlinks) {
59+
ComplementOutputs(graph.get(), node, out, &input_nodes);
60+
}
5361
}
5462
}
5563

@@ -78,17 +86,174 @@ void TypeTargetTransformPass::ComplementInputs(
7886
<< " for kernel " << inst.op()->DebugString() << " "
7987
<< *in->AsArg().type << " -> " << *decl_arg_type;
8088
// Add an IoCopy instruction to make the input compatible with other dist.
81-
AddIoCopyInst(*in->AsArg().type,
82-
*decl_arg_type,
83-
in,
84-
graph,
85-
inst_node,
86-
copied_nodes,
87-
valid_places_);
89+
AddInputIoCopyInst(*in->AsArg().type,
90+
*decl_arg_type,
91+
in,
92+
graph,
93+
inst_node,
94+
copied_nodes,
95+
valid_places_);
96+
}
97+
}
98+
99+
void TypeTargetTransformPass::AddOutputIoCopyInst(
100+
const Type& from,
101+
const Type& to,
102+
Node* out,
103+
SSAGraph* graph,
104+
Node* inst_node,
105+
const std::vector<Place>& valid_places) {
106+
CHECK(!valid_places.empty()) << "valid_place should be set";
107+
// inst -> out node(new_name) -> io_copy_op -> new_var_node(out->AsArg().name)
108+
// So there will be a new Argument node and a new IoCopy Statement Node.
109+
CHECK(out->IsArg());
110+
auto new_name = string_format("%s/target_trans", out->AsArg().name.c_str());
111+
auto* new_var_node = graph->NewArgumentNode(out->AsArg().name);
112+
113+
// Set the place for new var node, the target should be equal to to.target()
114+
// The precision and layout should be equal to from.precision(), from.layout()
115+
bool is_tensor = from.IsTensor();
116+
if (!is_tensor) {
117+
CHECK(from.IsTensorList()) << "only support tensor or tensor_array.";
118+
}
119+
if (is_tensor) {
120+
new_var_node->AsArg().type =
121+
LiteType::GetTensorTy(to.target(), from.precision(), from.layout());
122+
} else {
123+
new_var_node->AsArg().type =
124+
LiteType::GetTensorListTy(to.target(), from.precision(), from.layout());
125+
}
126+
auto* io_copy_inst = graph->NewInstructNode();
127+
std::string io_copy_type = "io_copy";
128+
// create Op and kernels.
129+
auto io_copy_op = LiteOpRegistry::Global().Create(io_copy_type);
130+
CHECK(io_copy_op) << "create op [" << io_copy_op << "] failed";
131+
// CHECK(io_copy_op);
132+
// Create the new var manually.
133+
inst_node->AsStmt().op()->scope()->Var(new_name);
134+
135+
// Create IoCopy Instruction.
136+
cpp::OpDesc op_desc;
137+
op_desc.SetType(io_copy_type);
138+
if (is_tensor) {
139+
op_desc.SetInput("Input", {new_name});
140+
op_desc.SetOutput("Out", {out->AsArg().name});
141+
} else {
142+
op_desc.SetInput("InputArray", {new_name});
143+
op_desc.SetOutput("OutArray", {out->AsArg().name});
144+
}
145+
io_copy_op->Attach(op_desc, inst_node->AsStmt().op()->scope());
146+
auto kernels = io_copy_op->CreateKernels(valid_places);
147+
bool is_found = false;
148+
std::vector<std::unique_ptr<KernelBase>> selected_kernels;
149+
for (auto& kernel : kernels) {
150+
const Type* in_arg_ty = nullptr;
151+
const Type* out_arg_ty = nullptr;
152+
if (is_tensor) {
153+
in_arg_ty = kernel->GetInputDeclType("Input");
154+
out_arg_ty = kernel->GetOutputDeclType("Out");
155+
} else {
156+
in_arg_ty = kernel->GetInputDeclType("InputArray");
157+
out_arg_ty = kernel->GetOutputDeclType("OutArray");
158+
}
159+
160+
VLOG(4) << "------ kernel info -------";
161+
VLOG(4) << "*in_arg_ty(io_copy kernel input):" << *in_arg_ty;
162+
VLOG(4) << "from(last kernel output):" << from;
163+
VLOG(4) << "out_arg_ty(io_copy kernel output):" << *out_arg_ty;
164+
VLOG(4) << "to:" << to << "\n";
165+
166+
if (TypeCompatible(*in_arg_ty, from) &&
167+
TargetCompatibleTo(*out_arg_ty, to)) {
168+
VLOG(4) << "picked";
169+
is_found = true;
170+
}
171+
172+
if (is_found) {
173+
selected_kernels.emplace_back(std::move(kernel));
174+
// we pick the kernel
175+
io_copy_inst->AsStmt(
176+
io_copy_type, std::move(selected_kernels), io_copy_op);
177+
break;
178+
}
179+
VLOG(4) << "not picked";
180+
}
181+
182+
CHECK(is_found) << "Can't find a io_copy kernel for io_copy op: " << from
183+
<< ":" << inst_node->AsStmt().op_info()->Type() << " -> "
184+
<< to << ":" << out->AsArg().name;
185+
// Add new link, inst -> var -> io_copy_op -> new_var_node
186+
DirectedLink(out, io_copy_inst);
187+
DirectedLink(io_copy_inst, new_var_node);
188+
189+
// Update the original instruction OpDesc.
190+
// Update its output var name to the io_copy_output_name
191+
auto* inst_node_op_desc = inst_node->AsStmt().op()->mutable_op_info();
192+
for (auto& op_output : *inst_node_op_desc->mutable_outputs()) {
193+
for (auto& var_name : op_output.second)
194+
if (var_name == out->AsArg().name) var_name = new_name;
195+
}
196+
// Update the input name of Ops whose input var is out var node
197+
for (auto& op : out->outlinks) {
198+
if (!op->IsStmt()) continue;
199+
auto* op_desc = op->AsStmt().op()->mutable_op_info();
200+
for (auto& op_input : *op_desc->mutable_inputs())
201+
for (auto& var_name : op_input.second)
202+
if (var_name == out->AsArg().name) var_name = new_name;
203+
}
204+
// reset opdesc and update kernel information
205+
out->AsArg().name = new_name;
206+
auto original_selected_kernel =
207+
std::move(inst_node->AsStmt().kernels().front());
208+
auto update_op_info = *inst_node->AsStmt().op_info();
209+
inst_node->AsStmt().ResetOp(update_op_info, graph->valid_places());
210+
inst_node->AsStmt().kernels().clear();
211+
inst_node->AsStmt().kernels().emplace_back(
212+
std::move(original_selected_kernel));
213+
214+
for (auto& kernel : inst_node->AsStmt().kernels()) {
215+
VLOG(4) << "kernel info: " << kernel->name();
216+
inst_node->AsStmt().op()->AttachKernel(kernel.get());
217+
}
218+
219+
graph->CheckValid();
220+
}
221+
222+
void TypeTargetTransformPass::ComplementOutputs(
223+
SSAGraph* graph,
224+
Node* inst_node,
225+
Node* out,
226+
std::map<std::string, Node*>* input_nodes) {
227+
// If this output is out of date.
228+
if (inst_node->outlinks.end() ==
229+
std::find(inst_node->outlinks.begin(), inst_node->outlinks.end(), out))
230+
return;
231+
232+
CHECK(inst_node->IsStmt());
233+
auto& inst = inst_node->AsStmt();
234+
VLOG(3) << "found Target tensor: " << out->AsArg().name;
235+
CHECK(out->IsRoleSet());
236+
CHECK(out->IsArg());
237+
CHECK(out->AsArg().type);
238+
if (input_nodes->count(out->AsArg().name)) {
239+
if (!TargetCompatibleTo(
240+
*out->AsArg().type,
241+
*input_nodes->at(out->AsArg().name)->AsArg().type)) {
242+
VLOG(3) << "found Output Target unmatched tensor: " << out->AsArg().name
243+
<< " for kernel " << inst.op()->DebugString() << " "
244+
<< *out->AsArg().type << " -> "
245+
<< *(input_nodes->at(out->AsArg().name))->AsArg().type;
246+
AddOutputIoCopyInst(*out->AsArg().type,
247+
*input_nodes->at(out->AsArg().name)->AsArg().type,
248+
out,
249+
graph,
250+
inst_node,
251+
valid_places_);
252+
}
88253
}
89254
}
90255

91-
void TypeTargetTransformPass::AddIoCopyInst(
256+
void TypeTargetTransformPass::AddInputIoCopyInst(
92257
const Type& from,
93258
const Type& to,
94259
Node* in,

lite/core/optimizer/mir/type_target_cast_pass.h

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,25 @@ class TypeTargetTransformPass : public ProgramPass {
3838
Node* in,
3939
std::map<std::string, Node*>* copied_nodes);
4040

41-
void AddIoCopyInst(const Type& from,
42-
const Type& to,
43-
Node* in,
44-
SSAGraph* graph,
45-
Node* inst_node,
46-
std::map<std::string, Node*>* copied_nodes,
47-
const std::vector<Place>& valid_places);
41+
void ComplementOutputs(SSAGraph* graph,
42+
Node* inst_node,
43+
Node* out,
44+
std::map<std::string, Node*>* input_nodes);
45+
46+
void AddInputIoCopyInst(const Type& from,
47+
const Type& to,
48+
Node* in,
49+
SSAGraph* graph,
50+
Node* inst_node,
51+
std::map<std::string, Node*>* copied_nodes,
52+
const std::vector<Place>& valid_places);
53+
54+
void AddOutputIoCopyInst(const Type& from,
55+
const Type& to,
56+
Node* out,
57+
SSAGraph* graph,
58+
Node* inst_node,
59+
const std::vector<Place>& valid_places);
4860

4961
void SetValidPlaces(const std::vector<Place>& valid_places);
5062

lite/model_parser/model_parser.cc

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -245,9 +245,6 @@ void LoadModelPb(const std::string &model_dir,
245245
pb::ProgramDesc pb_prog(&pb_proto_prog);
246246
// Transform to cpp::ProgramDesc
247247
TransformProgramDescAnyToCpp(pb_prog, cpp_prog);
248-
#ifdef WITH_CONVERT_TO_SSA
249-
general::ssa::ConvertToSSA(cpp_prog);
250-
#endif
251248

252249
// Load params data from file.
253250
// NOTE: Only main block be used now.

0 commit comments

Comments
 (0)