@@ -39,6 +39,8 @@ void TypeTargetTransformPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
3939
4040 // record the copied node.
4141 std::map<std::string, Node*> copied_nodes;
42+ // record the origin node.
43+ std::map<std::string, Node*> input_nodes;
4244 std::vector<std::string> skip_ops = {
4345 " while" , " conditional_block" , " write_back" };
4446
@@ -48,8 +50,14 @@ void TypeTargetTransformPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
4850 if (!node->IsStmt () || iter != skip_ops.end ()) continue ;
4951 auto inlinks = node->inlinks ;
5052 for (auto * in : inlinks) {
53+ if (!input_nodes.count (in->AsArg ().name ))
54+ input_nodes[in->AsArg ().name ] = in;
5155 ComplementInputs (graph.get (), node, in, &copied_nodes);
5256 }
57+ auto outlinks = node->outlinks ;
58+ for (auto * out : outlinks) {
59+ ComplementOutputs (graph.get (), node, out, &input_nodes);
60+ }
5361 }
5462}
5563
@@ -78,17 +86,174 @@ void TypeTargetTransformPass::ComplementInputs(
7886 << " for kernel " << inst.op ()->DebugString () << " "
7987 << *in->AsArg ().type << " -> " << *decl_arg_type;
8088 // Add an IoCopy instruction to make the input compatible with other dist.
81- AddIoCopyInst (*in->AsArg ().type ,
82- *decl_arg_type,
83- in,
84- graph,
85- inst_node,
86- copied_nodes,
87- valid_places_);
89+ AddInputIoCopyInst (*in->AsArg ().type ,
90+ *decl_arg_type,
91+ in,
92+ graph,
93+ inst_node,
94+ copied_nodes,
95+ valid_places_);
96+ }
97+ }
98+
99+ void TypeTargetTransformPass::AddOutputIoCopyInst (
100+ const Type& from,
101+ const Type& to,
102+ Node* out,
103+ SSAGraph* graph,
104+ Node* inst_node,
105+ const std::vector<Place>& valid_places) {
106+ CHECK (!valid_places.empty ()) << " valid_place should be set" ;
107+ // inst -> out node(new_name) -> io_copy_op -> new_var_node(out->AsArg().name)
108+ // So there will be a new Argument node and a new IoCopy Statement Node.
109+ CHECK (out->IsArg ());
110+ auto new_name = string_format (" %s/target_trans" , out->AsArg ().name .c_str ());
111+ auto * new_var_node = graph->NewArgumentNode (out->AsArg ().name );
112+
113+ // Set the place for new var node, the target should be equal to to.target()
114+ // The precision and layout should be equal to from.precision(), from.layout()
115+ bool is_tensor = from.IsTensor ();
116+ if (!is_tensor) {
117+ CHECK (from.IsTensorList ()) << " only support tensor or tensor_array." ;
118+ }
119+ if (is_tensor) {
120+ new_var_node->AsArg ().type =
121+ LiteType::GetTensorTy (to.target (), from.precision (), from.layout ());
122+ } else {
123+ new_var_node->AsArg ().type =
124+ LiteType::GetTensorListTy (to.target (), from.precision (), from.layout ());
125+ }
126+ auto * io_copy_inst = graph->NewInstructNode ();
127+ std::string io_copy_type = " io_copy" ;
128+ // create Op and kernels.
129+ auto io_copy_op = LiteOpRegistry::Global ().Create (io_copy_type);
130+ CHECK (io_copy_op) << " create op [" << io_copy_op << " ] failed" ;
131+ // CHECK(io_copy_op);
132+ // Create the new var manually.
133+ inst_node->AsStmt ().op ()->scope ()->Var (new_name);
134+
135+ // Create IoCopy Instruction.
136+ cpp::OpDesc op_desc;
137+ op_desc.SetType (io_copy_type);
138+ if (is_tensor) {
139+ op_desc.SetInput (" Input" , {new_name});
140+ op_desc.SetOutput (" Out" , {out->AsArg ().name });
141+ } else {
142+ op_desc.SetInput (" InputArray" , {new_name});
143+ op_desc.SetOutput (" OutArray" , {out->AsArg ().name });
144+ }
145+ io_copy_op->Attach (op_desc, inst_node->AsStmt ().op ()->scope ());
146+ auto kernels = io_copy_op->CreateKernels (valid_places);
147+ bool is_found = false ;
148+ std::vector<std::unique_ptr<KernelBase>> selected_kernels;
149+ for (auto & kernel : kernels) {
150+ const Type* in_arg_ty = nullptr ;
151+ const Type* out_arg_ty = nullptr ;
152+ if (is_tensor) {
153+ in_arg_ty = kernel->GetInputDeclType (" Input" );
154+ out_arg_ty = kernel->GetOutputDeclType (" Out" );
155+ } else {
156+ in_arg_ty = kernel->GetInputDeclType (" InputArray" );
157+ out_arg_ty = kernel->GetOutputDeclType (" OutArray" );
158+ }
159+
160+ VLOG (4 ) << " ------ kernel info -------" ;
161+ VLOG (4 ) << " *in_arg_ty(io_copy kernel input):" << *in_arg_ty;
162+ VLOG (4 ) << " from(last kernel output):" << from;
163+ VLOG (4 ) << " out_arg_ty(io_copy kernel output):" << *out_arg_ty;
164+ VLOG (4 ) << " to:" << to << " \n " ;
165+
166+ if (TypeCompatible (*in_arg_ty, from) &&
167+ TargetCompatibleTo (*out_arg_ty, to)) {
168+ VLOG (4 ) << " picked" ;
169+ is_found = true ;
170+ }
171+
172+ if (is_found) {
173+ selected_kernels.emplace_back (std::move (kernel));
174+ // we pick the kernel
175+ io_copy_inst->AsStmt (
176+ io_copy_type, std::move (selected_kernels), io_copy_op);
177+ break ;
178+ }
179+ VLOG (4 ) << " not picked" ;
180+ }
181+
182+ CHECK (is_found) << " Can't find a io_copy kernel for io_copy op: " << from
183+ << " :" << inst_node->AsStmt ().op_info ()->Type () << " -> "
184+ << to << " :" << out->AsArg ().name ;
185+ // Add new link, inst -> var -> io_copy_op -> new_var_node
186+ DirectedLink (out, io_copy_inst);
187+ DirectedLink (io_copy_inst, new_var_node);
188+
189+ // Update the original instruction OpDesc.
190+ // Update its output var name to the io_copy_output_name
191+ auto * inst_node_op_desc = inst_node->AsStmt ().op ()->mutable_op_info ();
192+ for (auto & op_output : *inst_node_op_desc->mutable_outputs ()) {
193+ for (auto & var_name : op_output.second )
194+ if (var_name == out->AsArg ().name ) var_name = new_name;
195+ }
196+ // Update the input name of Ops whose input var is out var node
197+ for (auto & op : out->outlinks ) {
198+ if (!op->IsStmt ()) continue ;
199+ auto * op_desc = op->AsStmt ().op ()->mutable_op_info ();
200+ for (auto & op_input : *op_desc->mutable_inputs ())
201+ for (auto & var_name : op_input.second )
202+ if (var_name == out->AsArg ().name ) var_name = new_name;
203+ }
204+ // reset opdesc and update kernel information
205+ out->AsArg ().name = new_name;
206+ auto original_selected_kernel =
207+ std::move (inst_node->AsStmt ().kernels ().front ());
208+ auto update_op_info = *inst_node->AsStmt ().op_info ();
209+ inst_node->AsStmt ().ResetOp (update_op_info, graph->valid_places ());
210+ inst_node->AsStmt ().kernels ().clear ();
211+ inst_node->AsStmt ().kernels ().emplace_back (
212+ std::move (original_selected_kernel));
213+
214+ for (auto & kernel : inst_node->AsStmt ().kernels ()) {
215+ VLOG (4 ) << " kernel info: " << kernel->name ();
216+ inst_node->AsStmt ().op ()->AttachKernel (kernel.get ());
217+ }
218+
219+ graph->CheckValid ();
220+ }
221+
222+ void TypeTargetTransformPass::ComplementOutputs (
223+ SSAGraph* graph,
224+ Node* inst_node,
225+ Node* out,
226+ std::map<std::string, Node*>* input_nodes) {
227+ // If this output is out of date.
228+ if (inst_node->outlinks .end () ==
229+ std::find (inst_node->outlinks .begin (), inst_node->outlinks .end (), out))
230+ return ;
231+
232+ CHECK (inst_node->IsStmt ());
233+ auto & inst = inst_node->AsStmt ();
234+ VLOG (3 ) << " found Target tensor: " << out->AsArg ().name ;
235+ CHECK (out->IsRoleSet ());
236+ CHECK (out->IsArg ());
237+ CHECK (out->AsArg ().type );
238+ if (input_nodes->count (out->AsArg ().name )) {
239+ if (!TargetCompatibleTo (
240+ *out->AsArg ().type ,
241+ *input_nodes->at (out->AsArg ().name )->AsArg ().type )) {
242+ VLOG (3 ) << " found Output Target unmatched tensor: " << out->AsArg ().name
243+ << " for kernel " << inst.op ()->DebugString () << " "
244+ << *out->AsArg ().type << " -> "
245+ << *(input_nodes->at (out->AsArg ().name ))->AsArg ().type ;
246+ AddOutputIoCopyInst (*out->AsArg ().type ,
247+ *input_nodes->at (out->AsArg ().name )->AsArg ().type ,
248+ out,
249+ graph,
250+ inst_node,
251+ valid_places_);
252+ }
88253 }
89254}
90255
91- void TypeTargetTransformPass::AddIoCopyInst (
256+ void TypeTargetTransformPass::AddInputIoCopyInst (
92257 const Type& from,
93258 const Type& to,
94259 Node* in,
0 commit comments