@@ -62,90 +62,102 @@ void DebugCinnCompiledResult(const CinnCompiledObject& result) {
6262 << " ]" ;
6363}
6464
65- std::vector<std::string> MapPaddleVariablesToCinn (
66- const std::vector<std::string>& paddle_names,
67- const std::unordered_map<std::string, std::string>& paddle2cinn_varmap) {
68- std::vector<std::string> result;
69- result.reserve (result.size ());
65+ void LaunchCinnExecution (const CinnCompiledObject& compiled_obj,
66+ const CinnLaunchContext& context) {
67+ compiled_obj.runtime_program ->Execute (&context.FinalizeArguments ());
68+ }
69+
70+ CinnLaunchContext::CinnLaunchContext (const CinnCompiledObject& compiled_obj)
71+ : paddle2cinn_varmap_(compiled_obj.paddle2cinn_varmap),
72+ cinn_scope_ (compiled_obj.scope) {
73+ auto var_names = cinn_scope_->var_names ();
74+ cinn_variable_names_.reserve (var_names.size ());
7075 std::transform (
71- paddle_names.begin (), paddle_names.end (), std::back_inserter (result),
72- [&paddle2cinn_varmap](const std::string& pd_name) {
73- PADDLE_ENFORCE_GT (paddle2cinn_varmap.count (pd_name), 0 ,
74- platform::errors::NotFound (
75- " Not found the corresponding cinn variable "
76- " of paddle variable(%s) in compilation result." ,
77- pd_name));
78- return paddle2cinn_varmap.at (pd_name);
79- });
80- return result;
76+ var_names.begin (), var_names.end (),
77+ std::inserter (cinn_variable_names_, cinn_variable_names_.end ()),
78+ [](const auto & name_view) { return std::string (name_view.data ()); });
8179}
8280
83- std::vector<CinnTensor> GetCinnTensorsFromCompiledScope (
84- const std::vector<std::string>& cinn_names, const CinnScope& cinn_scope) {
85- std::vector<CinnTensor> result;
86- result.reserve (cinn_names.size ());
87- std::transform (cinn_names.begin (), cinn_names.end (),
88- std::back_inserter (result),
89- [&cinn_scope](const std::string& var_name) {
90- PADDLE_ENFORCE_NOT_NULL (
91- cinn_scope.FindVar (var_name),
92- platform::errors::NotFound (
93- " Variable(%s) not found in cinn scope." , var_name));
94- return cinn_scope.GetTensor (var_name);
95- });
96- return result;
81+ bool CinnLaunchContext::IsVariableUsed (const std::string& paddle_name) {
82+ return paddle2cinn_varmap_.count (paddle_name) > 0 &&
83+ cinn_variable_names_.count (paddle2cinn_varmap_.at (paddle_name)) > 0 ;
84+ }
85+
86+ CinnTensor CinnLaunchContext::GetCinnTensor (const std::string& var_name) {
87+ PADDLE_ENFORCE_GT (cinn_variable_names_.count (var_name), 0 ,
88+ platform::errors::NotFound (
89+ " Variable(%s) not found in cinn scope." , var_name));
90+ return cinn_scope_->GetTensor (var_name);
91+ }
92+
93+ std::vector<std::string> CinnLaunchContext::GetInternalVariableNames () {
94+ std::unordered_set<std::string> all_parameters (cinn_variable_names_);
95+ std::for_each (name2argument_.begin (), name2argument_.end (),
96+ [&all_parameters](const auto & name2arg) {
97+ all_parameters.erase (name2arg.first );
98+ });
99+ return {all_parameters.begin (), all_parameters.end ()};
100+ }
101+
102+ void CinnLaunchContext::MutableTensorData (const std::string& var_name,
103+ const platform::Place& place,
104+ LoDTensor* paddle_tensor,
105+ bool is_internal_var) {
106+ auto cinn_name = var_name;
107+ if (!is_internal_var) {
108+ PADDLE_ENFORCE_EQ (IsVariableUsed (var_name), true ,
109+ platform::errors::InvalidArgument (
110+ " Paddle variable(%s) not used by cinn" , var_name));
111+ cinn_name = paddle2cinn_varmap_.at (var_name);
112+ }
113+
114+ auto cinn_tensor = GetCinnTensor (cinn_name);
115+ // TODO(CtfGo): support mutable corresponding c++ type after CINN ready
116+ paddle_tensor->mutable_data <float >(
117+ framework::make_ddim (cinn_tensor->shape ().data ()), place);
97118}
98119
99- void CheckTensorEquivalent (const std::string& paddle_name,
100- const LoDTensor* paddle_tensor,
101- const CinnTensor& cinn_tensor) {
120+ void CinnLaunchContext:: CheckTensorEquivalent (const std::string& paddle_name,
121+ const LoDTensor& paddle_tensor,
122+ const CinnTensor& cinn_tensor) {
102123 PADDLE_ENFORCE_EQ (
103- paddle_tensor-> IsInitialized (), true ,
124+ paddle_tensor. IsInitialized (), true ,
104125 platform::errors::InvalidArgument (
105- " The tensor in variable(%s) is not initialized." , paddle_name));
126+ " Tensor in variable(%s) is not initialized." , paddle_name));
106127
107128 // check dimension
108129 auto cinn_dims = framework::make_ddim (cinn_tensor->shape ().data ());
109- PADDLE_ENFORCE_EQ (paddle_tensor->dims (), cinn_dims,
110- platform::errors::InvalidArgument (
111- " The tensor dimension in variable(%s) "
112- " is not equivalent, paddle is [%s] "
113- " but cinn is [%s]." ,
114- paddle_name, paddle_tensor->dims (), cinn_dims));
130+ PADDLE_ENFORCE_EQ (paddle_tensor.dims (), cinn_dims,
131+ platform::errors::PreconditionNotMet (
132+ " Tensors' shape in variable(%s) are not equivalent, "
133+ " paddle's shape = [%s], but cinn's shape = [%s]." ,
134+ paddle_name, paddle_tensor.dims (), cinn_dims));
115135
116136 // TODO(CtfGo): check the underlying data type after CINN ready
117137}
118138
119- void TensorMutableDataWithCinnInfo (const platform::Place& place,
120- const CinnTensor& cinn_tensor,
121- LoDTensor* paddle_tensor) {
122- // TODO(CtfGo): support mutable corresponding c++ type after CINN ready
123- paddle_tensor->mutable_data <float >(
124- framework::make_ddim (cinn_tensor->shape ().data ()), place);
125- }
126-
127- std::vector<std::string> SeperateTempVar (
128- const CinnScope& cinn_scope,
129- const std::vector<std::string>& input_cinn_names,
130- const std::vector<std::string>& output_cinn_names) {
131- auto cinn_var_names = cinn_scope.var_names ();
132- std::unordered_set<std::string> all_cinn_names;
133- all_cinn_names.reserve (cinn_var_names.size ());
134- std::transform (
135- cinn_var_names.begin (), cinn_var_names.end (),
136- std::inserter (all_cinn_names, all_cinn_names.end ()),
137- [](const auto & name_view) { return std::string (name_view.data ()); });
139+ void CinnLaunchContext::AssignExternalVariable (const std::string& paddle_name,
140+ LoDTensor* paddle_tensor) {
141+ PADDLE_ENFORCE_EQ (IsVariableUsed (paddle_name), true ,
142+ platform::errors::InvalidArgument (
143+ " Paddle variable(%s) not used by cinn" , paddle_name));
138144
139- auto exclude_fn = [&all_cinn_names](const auto & cinn_name) {
140- all_cinn_names.erase (cinn_name);
141- };
145+ const auto & cinn_name = paddle2cinn_varmap_.at (paddle_name);
146+ CheckTensorEquivalent (paddle_name, *paddle_tensor, GetCinnTensor (cinn_name));
147+ return SetArgument (cinn_name, paddle_tensor);
148+ }
142149
143- std::for_each (input_cinn_names.begin (), input_cinn_names.end (), exclude_fn);
144- std::for_each (output_cinn_names.begin (), output_cinn_names.end (), exclude_fn);
145- return {all_cinn_names.begin (), all_cinn_names.end ()};
150+ void CinnLaunchContext::AssignInternalVariable (const std::string& cinn_name,
151+ LoDTensor* paddle_tensor) {
152+ PADDLE_ENFORCE_GT (cinn_variable_names_.count (cinn_name), 0 ,
153+ platform::errors::InvalidArgument (
154+ " Variable(%s) not found in cinn socpe." , cinn_name));
155+ CheckTensorEquivalent (cinn_name, *paddle_tensor, GetCinnTensor (cinn_name));
156+ return SetArgument (cinn_name, paddle_tensor);
146157}
147158
148- std::unique_ptr<cinn_buffer_t > ShareTensorWithCinnBuffer (LoDTensor* tensor) {
159+ std::unique_ptr<cinn_buffer_t > CinnLaunchContext::ShareTensorWithCinnBuffer (
160+ LoDTensor* tensor) {
149161 // convert paddle dimensions array to cinn format
150162 std::vector<cinn_dimension_t > cinn_dims (tensor->dims ().size ());
151163 for (auto i = 0 ; i < tensor->dims ().size (); ++i) {
@@ -159,17 +171,29 @@ std::unique_ptr<cinn_buffer_t> ShareTensorWithCinnBuffer(LoDTensor* tensor) {
159171 return cinn_buffer;
160172}
161173
162- void CheckArgumentsNotMissed (
163- const CinnScope& cinn_scope,
164- const std::map<std::string, cinn_pod_value_t >& name2argument) {
165- auto cinn_var_names = cinn_scope.var_names ();
166- std::for_each (cinn_var_names.begin (), cinn_var_names.end (),
167- [&name2argument](const auto & name_view) {
168- PADDLE_ENFORCE_GT (
169- name2argument.count (name_view.data ()), 0 ,
170- platform::errors::InvalidArgument (
171- " Parameter(%s) is not assgined." , name_view.data ()));
174+ void CinnLaunchContext::SetArgument (const std::string& cinn_name,
175+ LoDTensor* paddle_tensor) {
176+ auto buffer = ShareTensorWithCinnBuffer (paddle_tensor);
177+ name2argument_.emplace (cinn_name, buffer.get ());
178+ hold_buffers_.emplace_back (std::move (buffer));
179+ VLOG (4 ) << " SetArgument-" << name2argument_.size () << " : "
180+ << " name(" << cinn_name << " ), "
181+ << " type(" << framework::DataTypeToString (paddle_tensor->type ())
182+ << " ), dims(" << paddle_tensor->dims () << " )." ;
183+ }
184+
185+ const std::map<std::string, cinn_pod_value_t >&
186+ CinnLaunchContext::FinalizeArguments () const {
187+ // Check all execution parameters are assigned valued.
188+ std::for_each (cinn_variable_names_.begin (), cinn_variable_names_.end (),
189+ [this ](const auto & var_name) {
190+ PADDLE_ENFORCE_GT (name2argument_.count (var_name), 0 ,
191+ platform::errors::InvalidArgument (
192+ " Variable(%s) is missed for launching "
193+ " compiled program execution" ,
194+ var_name));
172195 });
196+ return name2argument_;
173197}
174198
175199} // namespace details
0 commit comments