Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 40 additions & 3 deletions paddle/cinn/hlir/framework/op_lowering_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,45 @@ using cinn::common::Type;

using cinn::hlir::op::ExternalApiRegistry;

// Collect the temporary tensors from a computational graph.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个如果没有额外依赖的话,建议放到cinn/lang/lower.cc中,与lang::GetTempBuffers 其他函数放到一起。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

看了下lang的namespace下已经有一个同名函数了,但可以在function name区分下?

std::vector<ir::Buffer> GetTempBuffers(
const std::vector<cinn::ir::Tensor>& tensor_args, Expr body) {
std::unordered_set<std::string> tensor_arg_names;
std::unordered_set<std::string> buffer_arg_names;
for (auto& tensor : tensor_args) {
tensor_arg_names.insert(tensor->name);
if (tensor->buffer.defined()) {
buffer_arg_names.insert(tensor->buffer->name);
}
}
std::map<std::string, ir::Buffer>
name_to_buffer; // used to avoid duplication.

auto all_temp_tensors =
ir::ir_utils::CollectIRNodesWithoutTensor(body, [&](const Expr* x) {
return x->as_tensor() && x->as_tensor()->buffer.defined() &&
((!buffer_arg_names.count(x->as_tensor()->buffer->name) &&
!tensor_arg_names.count(x->as_tensor()->name)) ||
utils::Endswith(x->as_tensor()->buffer->name, "temp_buffer"));
});
for (auto& e : all_temp_tensors) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里相对于旧的逻辑少了CollectIRNodesWithoutTensor,这个是符合预期的吗?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我理解是符合预期的

auto buffer_name = e.as_tensor()->buffer->name;
if (!name_to_buffer.count(buffer_name)) {
name_to_buffer[buffer_name] = e.as_tensor()->buffer;
} else {
// TODO(phlrain): why update
if (e.as_tensor()->buffer->numel() <
name_to_buffer[buffer_name]->numel()) {
name_to_buffer[buffer_name] = e.as_tensor()->buffer;
}
}
}

std::vector<ir::Buffer> temp_buffers;
for (auto& i : name_to_buffer) temp_buffers.push_back(i.second);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

虽然for的body只有一行,也建议用{}包裹一下

return temp_buffers;
}

OpLowererImpl::OpLowererImpl(
const absl::flat_hash_map<std::string, Type>& type_dict,
const absl::flat_hash_map<std::string, shape_t>& shape_dict,
Expand Down Expand Up @@ -300,9 +339,7 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
}
#endif
// 2.Prepare temp buffers
poly::StageMap stages;
auto temp_buffers =
lang::GetTempBuffers(*group_func_arg_tensors, stages, func_body);
auto temp_buffers = GetTempBuffers(*group_func_arg_tensors, func_body);
// 3.Building LoweredFunc
auto func = ir::_LoweredFunc_::Make(group->GetFuncName(),
group_func_args,
Expand Down