@@ -31,6 +31,7 @@ namespace backends {
3131#define KERNEL_ARGS " kernel_args"
3232#define KERNEL_ARGS_NUM " kernel_args_num"
3333#define KERNEL_STREAM " kernel_stream"
34+ #define TENSOR_SHAPE_ARGS " tensor_shape_args"
3435
3536/* *
3637 * Split a CINN Module into two separate modules, one cantains the host
@@ -150,7 +151,8 @@ struct CollectBucketStrategyHostFunctionVisitor
150151 : CollectHostFunctionVisitor(module_name),
151152 kernel_args_(KERNEL_ARGS, type_of<void *>()),
152153 kernel_args_num_(KERNEL_ARGS_NUM, type_of<int >()),
153- kernel_stream_(KERNEL_STREAM, type_of<void *>()) {}
154+ kernel_stream_(KERNEL_STREAM, type_of<void *>()),
155+ tensor_shape_args_(TENSOR_SHAPE_ARGS, type_of<int32_t **>()) {}
154156
155157 std::tuple<ir::Module, ir::Module> operator ()(Expr* expr) {
156158 ir::IRMutator<>::Visit (expr, expr);
@@ -181,6 +183,25 @@ struct CollectBucketStrategyHostFunctionVisitor
181183 {});
182184 host_module_builder.AddFunctionWithoutOptim (
183185 host_func.as_lowered_func_ref ());
186+
187+ // Parse LoweredFunc to infer output tensor's shape
188+ std::vector<ir::Expr> infer_shape_func_body_stmts (arg_defs_);
189+ infer_shape_func_body_stmts.insert (
190+ infer_shape_func_body_stmts.end (),
191+ op->infer_shape_func .as_lowered_func ()->body );
192+
193+ std::vector<ir::Argument> infer_shape_arguments = {
194+ ir::Argument (kernel_args_, ir::Argument::IO::kOutput ),
195+ ir::Argument (kernel_args_num_, ir::Argument::IO::kInput ),
196+ ir::Argument (tensor_shape_args_, ir::Argument::IO::kOutput )};
197+
198+ ir::Expr host_infer_shape_func =
199+ ir::_LoweredFunc_::Make (op->infer_shape_func .as_lowered_func ()->name ,
200+ infer_shape_arguments,
201+ ir::Block::Make (infer_shape_func_body_stmts),
202+ {});
203+ host_module_builder.AddFunctionWithoutOptim (
204+ host_infer_shape_func.as_lowered_func_ref ());
184205 }
185206
186207 void ProcessLoweredFunc (ir::Expr func, ir::Expr predicate);
@@ -199,6 +220,7 @@ struct CollectBucketStrategyHostFunctionVisitor
199220 ir::Var kernel_args_;
200221 ir::Var kernel_args_num_;
201222 ir::Var kernel_stream_;
223+ ir::Var tensor_shape_args_;
202224};
203225
204226} // namespace detail
0 commit comments