Commit 228eb89

cinn_launch_op: skip checking input variables must be used (#37119)
Modify several implementations of CinnLaunchOp: 1. Skip checking that input variables must be used. 2. Move the current helper functions into a CinnLaunchContext class.
1 parent 6486e24 commit 228eb89
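
To make the intent of the refactor concrete, here is a minimal, hypothetical sketch of how the new CinnLaunchContext is meant to be driven from the operator kernel. The kernel itself is not shown in this section, so names such as launch_context, compiled_obj, scope and input_variable_names are illustrative only:

  // Hypothetical usage sketch of details::CinnLaunchContext; not the actual
  // kernel code from this commit.
  auto launch_context =
      std::make_unique<details::CinnLaunchContext>(compiled_obj);

  // Bind external (input/output) variables. Inputs that the compiled CINN
  // program does not use are simply skipped now, instead of failing a check.
  for (const auto& var_name : input_variable_names) {
    if (!launch_context->IsVariableUsed(var_name)) {
      continue;  // unused by CINN: skip rather than enforce
    }
    auto* tensor = scope.FindVar(var_name)->GetMutable<LoDTensor>();
    launch_context->AssignExternalVariable(var_name, tensor);
  }

  // Once every CINN variable has an argument bound, finalize and execute.
  details::LaunchCinnExecution(compiled_obj, *launch_context);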

4 files changed: 298 additions, 266 deletions


paddle/fluid/operators/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -171,6 +171,7 @@ endif()
 if (WITH_CINN)
     op_library(cinn_launch_op SRCS cinn_launch_op.cc cinn_launch_op.cu.cc DEPS transform_desc cinn_compiler cinn ${OP_HEADER_DEPS})
     cc_test(cinn_launch_op_test SRCS cinn_launch_op_test.cc DEPS cinn_compiler cinn_launch_op elementwise_add_op)
+    set_tests_properties(cinn_launch_op_test PROPERTIES ENVIRONMENT OMP_NUM_THREADS=1)
 endif()

 # FIXME(typhoonzero): operator deps may not needed.

paddle/fluid/operators/cinn_launch_op.cc

Lines changed: 100 additions & 76 deletions
@@ -62,90 +62,102 @@ void DebugCinnCompiledResult(const CinnCompiledObject& result) {
           << "]";
 }

-std::vector<std::string> MapPaddleVariablesToCinn(
-    const std::vector<std::string>& paddle_names,
-    const std::unordered_map<std::string, std::string>& paddle2cinn_varmap) {
-  std::vector<std::string> result;
-  result.reserve(result.size());
+void LaunchCinnExecution(const CinnCompiledObject& compiled_obj,
+                         const CinnLaunchContext& context) {
+  compiled_obj.runtime_program->Execute(&context.FinalizeArguments());
+}
+
+CinnLaunchContext::CinnLaunchContext(const CinnCompiledObject& compiled_obj)
+    : paddle2cinn_varmap_(compiled_obj.paddle2cinn_varmap),
+      cinn_scope_(compiled_obj.scope) {
+  auto var_names = cinn_scope_->var_names();
+  cinn_variable_names_.reserve(var_names.size());
   std::transform(
-      paddle_names.begin(), paddle_names.end(), std::back_inserter(result),
-      [&paddle2cinn_varmap](const std::string& pd_name) {
-        PADDLE_ENFORCE_GT(paddle2cinn_varmap.count(pd_name), 0,
-                          platform::errors::NotFound(
-                              "Not found the corresponding cinn variable "
-                              "of paddle variable(%s) in compilation result.",
-                              pd_name));
-        return paddle2cinn_varmap.at(pd_name);
-      });
-  return result;
+      var_names.begin(), var_names.end(),
+      std::inserter(cinn_variable_names_, cinn_variable_names_.end()),
+      [](const auto& name_view) { return std::string(name_view.data()); });
 }

-std::vector<CinnTensor> GetCinnTensorsFromCompiledScope(
-    const std::vector<std::string>& cinn_names, const CinnScope& cinn_scope) {
-  std::vector<CinnTensor> result;
-  result.reserve(cinn_names.size());
-  std::transform(cinn_names.begin(), cinn_names.end(),
-                 std::back_inserter(result),
-                 [&cinn_scope](const std::string& var_name) {
-                   PADDLE_ENFORCE_NOT_NULL(
-                       cinn_scope.FindVar(var_name),
-                       platform::errors::NotFound(
-                           "Variable(%s) not found in cinn scope.", var_name));
-                   return cinn_scope.GetTensor(var_name);
-                 });
-  return result;
+bool CinnLaunchContext::IsVariableUsed(const std::string& paddle_name) {
+  return paddle2cinn_varmap_.count(paddle_name) > 0 &&
+         cinn_variable_names_.count(paddle2cinn_varmap_.at(paddle_name)) > 0;
+}
+
+CinnTensor CinnLaunchContext::GetCinnTensor(const std::string& var_name) {
+  PADDLE_ENFORCE_GT(cinn_variable_names_.count(var_name), 0,
+                    platform::errors::NotFound(
+                        "Variable(%s) not found in cinn scope.", var_name));
+  return cinn_scope_->GetTensor(var_name);
+}
+
+std::vector<std::string> CinnLaunchContext::GetInternalVariableNames() {
+  std::unordered_set<std::string> all_parameters(cinn_variable_names_);
+  std::for_each(name2argument_.begin(), name2argument_.end(),
+                [&all_parameters](const auto& name2arg) {
+                  all_parameters.erase(name2arg.first);
+                });
+  return {all_parameters.begin(), all_parameters.end()};
+}
+
+void CinnLaunchContext::MutableTensorData(const std::string& var_name,
+                                          const platform::Place& place,
+                                          LoDTensor* paddle_tensor,
+                                          bool is_internal_var) {
+  auto cinn_name = var_name;
+  if (!is_internal_var) {
+    PADDLE_ENFORCE_EQ(IsVariableUsed(var_name), true,
+                      platform::errors::InvalidArgument(
+                          "Paddle variable(%s) not used by cinn", var_name));
+    cinn_name = paddle2cinn_varmap_.at(var_name);
+  }
+
+  auto cinn_tensor = GetCinnTensor(cinn_name);
+  // TODO(CtfGo): support mutable corresponding c++ type after CINN ready
+  paddle_tensor->mutable_data<float>(
+      framework::make_ddim(cinn_tensor->shape().data()), place);
 }

-void CheckTensorEquivalent(const std::string& paddle_name,
-                           const LoDTensor* paddle_tensor,
-                           const CinnTensor& cinn_tensor) {
+void CinnLaunchContext::CheckTensorEquivalent(const std::string& paddle_name,
+                                              const LoDTensor& paddle_tensor,
+                                              const CinnTensor& cinn_tensor) {
   PADDLE_ENFORCE_EQ(
-      paddle_tensor->IsInitialized(), true,
+      paddle_tensor.IsInitialized(), true,
       platform::errors::InvalidArgument(
-          "The tensor in variable(%s) is not initialized.", paddle_name));
+          "Tensor in variable(%s) is not initialized.", paddle_name));

   // check dimension
   auto cinn_dims = framework::make_ddim(cinn_tensor->shape().data());
-  PADDLE_ENFORCE_EQ(paddle_tensor->dims(), cinn_dims,
-                    platform::errors::InvalidArgument(
-                        "The tensor dimension in variable(%s) "
-                        "is not equivalent, paddle is [%s] "
-                        "but cinn is [%s].",
-                        paddle_name, paddle_tensor->dims(), cinn_dims));
+  PADDLE_ENFORCE_EQ(paddle_tensor.dims(), cinn_dims,
+                    platform::errors::PreconditionNotMet(
+                        "Tensors' shape in variable(%s) are not equivalent, "
+                        "paddle's shape = [%s], but cinn's shape = [%s].",
+                        paddle_name, paddle_tensor.dims(), cinn_dims));

   // TODO(CtfGo): check the underlying data type after CINN ready
 }

-void TensorMutableDataWithCinnInfo(const platform::Place& place,
-                                   const CinnTensor& cinn_tensor,
-                                   LoDTensor* paddle_tensor) {
-  // TODO(CtfGo): support mutable corresponding c++ type after CINN ready
-  paddle_tensor->mutable_data<float>(
-      framework::make_ddim(cinn_tensor->shape().data()), place);
-}
-
-std::vector<std::string> SeperateTempVar(
-    const CinnScope& cinn_scope,
-    const std::vector<std::string>& input_cinn_names,
-    const std::vector<std::string>& output_cinn_names) {
-  auto cinn_var_names = cinn_scope.var_names();
-  std::unordered_set<std::string> all_cinn_names;
-  all_cinn_names.reserve(cinn_var_names.size());
-  std::transform(
-      cinn_var_names.begin(), cinn_var_names.end(),
-      std::inserter(all_cinn_names, all_cinn_names.end()),
-      [](const auto& name_view) { return std::string(name_view.data()); });
+void CinnLaunchContext::AssignExternalVariable(const std::string& paddle_name,
+                                               LoDTensor* paddle_tensor) {
+  PADDLE_ENFORCE_EQ(IsVariableUsed(paddle_name), true,
+                    platform::errors::InvalidArgument(
+                        "Paddle variable(%s) not used by cinn", paddle_name));

-  auto exclude_fn = [&all_cinn_names](const auto& cinn_name) {
-    all_cinn_names.erase(cinn_name);
-  };
+  const auto& cinn_name = paddle2cinn_varmap_.at(paddle_name);
+  CheckTensorEquivalent(paddle_name, *paddle_tensor, GetCinnTensor(cinn_name));
+  return SetArgument(cinn_name, paddle_tensor);
+}

-  std::for_each(input_cinn_names.begin(), input_cinn_names.end(), exclude_fn);
-  std::for_each(output_cinn_names.begin(), output_cinn_names.end(), exclude_fn);
-  return {all_cinn_names.begin(), all_cinn_names.end()};
+void CinnLaunchContext::AssignInternalVariable(const std::string& cinn_name,
+                                               LoDTensor* paddle_tensor) {
+  PADDLE_ENFORCE_GT(cinn_variable_names_.count(cinn_name), 0,
+                    platform::errors::InvalidArgument(
+                        "Variable(%s) not found in cinn socpe.", cinn_name));
+  CheckTensorEquivalent(cinn_name, *paddle_tensor, GetCinnTensor(cinn_name));
+  return SetArgument(cinn_name, paddle_tensor);
 }

-std::unique_ptr<cinn_buffer_t> ShareTensorWithCinnBuffer(LoDTensor* tensor) {
+std::unique_ptr<cinn_buffer_t> CinnLaunchContext::ShareTensorWithCinnBuffer(
+    LoDTensor* tensor) {
   // convert paddle dimensions array to cinn format
   std::vector<cinn_dimension_t> cinn_dims(tensor->dims().size());
   for (auto i = 0; i < tensor->dims().size(); ++i) {
@@ -159,17 +171,29 @@ std::unique_ptr<cinn_buffer_t> ShareTensorWithCinnBuffer(LoDTensor* tensor) {
   return cinn_buffer;
 }

-void CheckArgumentsNotMissed(
-    const CinnScope& cinn_scope,
-    const std::map<std::string, cinn_pod_value_t>& name2argument) {
-  auto cinn_var_names = cinn_scope.var_names();
-  std::for_each(cinn_var_names.begin(), cinn_var_names.end(),
-                [&name2argument](const auto& name_view) {
-                  PADDLE_ENFORCE_GT(
-                      name2argument.count(name_view.data()), 0,
-                      platform::errors::InvalidArgument(
-                          "Parameter(%s) is not assgined.", name_view.data()));
+void CinnLaunchContext::SetArgument(const std::string& cinn_name,
+                                    LoDTensor* paddle_tensor) {
+  auto buffer = ShareTensorWithCinnBuffer(paddle_tensor);
+  name2argument_.emplace(cinn_name, buffer.get());
+  hold_buffers_.emplace_back(std::move(buffer));
+  VLOG(4) << "SetArgument-" << name2argument_.size() << ": "
+          << "name(" << cinn_name << "), "
+          << "type(" << framework::DataTypeToString(paddle_tensor->type())
+          << "), dims(" << paddle_tensor->dims() << ").";
+}
+
+const std::map<std::string, cinn_pod_value_t>&
+CinnLaunchContext::FinalizeArguments() const {
+  // Check all execution parameters are assigned valued.
+  std::for_each(cinn_variable_names_.begin(), cinn_variable_names_.end(),
+                [this](const auto& var_name) {
+                  PADDLE_ENFORCE_GT(name2argument_.count(var_name), 0,
+                                    platform::errors::InvalidArgument(
+                                        "Variable(%s) is missed for launching "
+                                        "compiled program execution",
+                                        var_name));
                 });
+  return name2argument_;
 }

 } // namespace details
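
For CINN-internal temporaries, the helpers added above are designed to compose in the same way; a hedged sketch under the assumption of a temporary scope (temp_scope, place and launch_context below are illustrative names, not taken from this diff):

  // Hypothetical sketch: allocate and bind the remaining CINN-internal
  // variables, i.e. those no paddle variable maps to.
  for (const auto& cinn_name : launch_context->GetInternalVariableNames()) {
    auto* tensor = temp_scope->Var(cinn_name)->GetMutable<LoDTensor>();
    // Allocate with the shape recorded in the compiled CINN scope.
    launch_context->MutableTensorData(cinn_name, place, tensor,
                                      /*is_internal_var=*/true);
    launch_context->AssignInternalVariable(cinn_name, tensor);
  }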
