-
Notifications
You must be signed in to change notification settings - Fork 5.9k
[oneDNN] Disable caching of Reorder operation #35664
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
be98a0f
24d5473
bb0522f
8bfa1f3
072121c
a330dfb
d7e3bc3
5fdbd9a
13439fb
0ab02ff
0c7640a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1071,138 +1071,73 @@ class ActivationMKLDNNHandler | |
| } | ||
| }; | ||
|
|
||
| class ReorderMKLDNNHandler : public MKLDNNHandler { | ||
| class ReorderMKLDNNHandler { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would encourage you to extract this class to separate file. It's not a template so it would make a binary size smaller and compilation faster. |
||
| public: | ||
| ReorderMKLDNNHandler(std::vector<int64_t>& dims, // NOLINT | ||
| framework::proto::VarType::Type vtype, | ||
| mkldnn::memory::data_type dtype, | ||
| const platform::MKLDNNDeviceContext& dev_ctx, | ||
| mkldnn::engine engine, const std::string& base_key) | ||
| : platform::MKLDNNHandler(dev_ctx, engine, base_key), | ||
| dims_(dims), | ||
| mkldnn::memory::data_type dtype, mkldnn::engine engine) | ||
| : dims_(dims), | ||
| vtype_(vtype), | ||
| vtype_dst_(vtype), | ||
| dtype_(dtype), | ||
| dtype_dst_(dtype) {} | ||
| dtype_dst_(dtype), | ||
| engine_(engine) {} | ||
|
|
||
| ReorderMKLDNNHandler(std::vector<int64_t>& dims, // NOLINT | ||
| framework::proto::VarType::Type vtype, | ||
| mkldnn::memory::data_type dtype, | ||
| framework::proto::VarType::Type vtype_dst, | ||
| mkldnn::memory::data_type dtype_dst, | ||
| const platform::MKLDNNDeviceContext& dev_ctx, | ||
| mkldnn::engine engine, const std::string& base_key) | ||
| : platform::MKLDNNHandler(dev_ctx, engine, base_key), | ||
| dims_(dims), | ||
| mkldnn::engine engine) | ||
| : dims_(dims), | ||
| vtype_(vtype), | ||
| vtype_dst_(vtype_dst), | ||
| dtype_(dtype), | ||
| dtype_dst_(dtype_dst) {} | ||
| dtype_dst_(dtype_dst), | ||
| engine_(engine) {} | ||
|
|
||
| std::shared_ptr<mkldnn::memory> AcquireSrcMemory( | ||
| const MKLDNNMemoryFormat& fmt, void* ptr) { | ||
| return this->AcquireMemory(dims_, dtype_, fmt, ptr, "@user_src_mem_p"); | ||
| auto md = mkldnn::memory::desc(dims_, dtype_, fmt); | ||
| return std::make_shared<mkldnn::memory>(md, engine_, ptr); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we're no longer caching those |
||
| } | ||
|
|
||
| std::shared_ptr<mkldnn::memory> AcquireSubmemory( | ||
| const std::vector<int64_t>& dims, const std::vector<int64_t>& offset, | ||
| const std::shared_ptr<mkldnn::memory>& mem_p, int submemory_number = 0) { | ||
| std::string local_key = key_; | ||
| local_key.append("@submem") | ||
| .append(std::to_string(submemory_number)) | ||
| .append("_p"); | ||
|
|
||
| auto sub_mem_p = | ||
| std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key)); | ||
| if (sub_mem_p == nullptr) { | ||
| auto sub_md = mem_p->get_desc().submemory_desc(dims, {offset}); | ||
| sub_mem_p = std::make_shared<mkldnn::memory>(sub_md, engine_, | ||
| mem_p->get_data_handle()); | ||
| dev_ctx_.SetBlob(local_key, sub_mem_p); | ||
| } else { | ||
| sub_mem_p->set_data_handle(mem_p->get_data_handle()); | ||
| } | ||
| const std::shared_ptr<mkldnn::memory>& mem_p) { | ||
| auto sub_md = mem_p->get_desc().submemory_desc(dims, {offset}); | ||
| auto sub_mem_p = std::make_shared<mkldnn::memory>(sub_md, engine_, | ||
| mem_p->get_data_handle()); | ||
| return sub_mem_p; | ||
| } | ||
|
|
||
| std::shared_ptr<mkldnn::memory> AcquireDstMemory( | ||
| framework::Tensor* output, const MKLDNNMemoryFormat& fmt, | ||
| platform::Place place) { | ||
| auto local_key = key_ + "@user_dst_mem_p"; | ||
| auto mem_p = | ||
| std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key)); | ||
| if (mem_p == nullptr) { | ||
| auto dst_md = platform::MKLDNNMemDesc(dims_, dtype_dst_, fmt); | ||
| auto dst_data = | ||
| output->mutable_data(place, vtype_dst_, dst_md.get_size()); | ||
|
|
||
| mem_p = std::make_shared<mkldnn::memory>(dst_md, engine_, dst_data); | ||
| dev_ctx_.SetBlob(local_key, mem_p); | ||
| } else { | ||
| // Even if memory object exists , we may be using it for diffrent tensor | ||
| auto dst_data = | ||
| output->mutable_data(place, vtype_dst_, mem_p->get_desc().get_size()); | ||
| mem_p->set_data_handle(dst_data); | ||
| } | ||
| return mem_p; | ||
| auto dst_md = platform::MKLDNNMemDesc(dims_, dtype_dst_, fmt); | ||
| auto dst_data = output->mutable_data(place, vtype_dst_, dst_md.get_size()); | ||
| return std::make_shared<mkldnn::memory>(dst_md, engine_, dst_data); | ||
| } | ||
|
|
||
| std::shared_ptr<mkldnn::memory> AcquireDstMemory( | ||
| framework::Tensor* output, const std::vector<int64_t>& dims, | ||
| const int memory_number, const MKLDNNMemoryFormat& fmt, | ||
| platform::Place place) { | ||
| auto local_key = | ||
| key_ + "@user_dst_mem" + std::to_string(memory_number) + "_p"; | ||
| auto mem_p = | ||
| std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key)); | ||
| if (mem_p == nullptr) { | ||
| auto dst_md = platform::MKLDNNMemDesc(dims, dtype_dst_, fmt); | ||
| auto dst_data = | ||
| output->mutable_data(place, vtype_dst_, dst_md.get_size()); | ||
|
|
||
| mem_p = std::make_shared<mkldnn::memory>(dst_md, engine_, dst_data); | ||
| dev_ctx_.SetBlob(local_key, mem_p); | ||
| } else { | ||
| // Even if memory object exists , we may be using it for diffrent tensor | ||
| auto dst_data = | ||
| output->mutable_data(place, vtype_dst_, mem_p->get_desc().get_size()); | ||
| mem_p->set_data_handle(dst_data); | ||
| } | ||
| return mem_p; | ||
| } | ||
|
|
||
| std::shared_ptr<mkldnn::reorder> AcquireReorder( | ||
| std::shared_ptr<mkldnn::memory> dst_memory_p, | ||
| std::shared_ptr<mkldnn::memory> src_memory_p, int reorder_number) { | ||
| auto prim_key = key_ + "@reorder" + std::to_string(reorder_number) + "_p"; | ||
| auto reorder_p = | ||
| std::static_pointer_cast<mkldnn::reorder>(dev_ctx_.GetBlob(prim_key)); | ||
| if (reorder_p == nullptr) { | ||
| reorder_p = | ||
| std::make_shared<mkldnn::reorder>(*(src_memory_p), *(dst_memory_p)); | ||
| dev_ctx_.SetBlob(prim_key, reorder_p); | ||
| } | ||
| return reorder_p; | ||
| const MKLDNNMemoryFormat& fmt, platform::Place place) { | ||
| auto dst_md = platform::MKLDNNMemDesc(dims, dtype_dst_, fmt); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the above overload of this function you use this class member |
||
| auto dst_data = output->mutable_data(place, vtype_dst_, dst_md.get_size()); | ||
| return std::make_shared<mkldnn::memory>(dst_md, engine_, dst_data); | ||
| } | ||
|
|
||
| std::shared_ptr<mkldnn::reorder> AcquireReorder( | ||
| std::shared_ptr<mkldnn::memory> dst_memory_p, | ||
| std::shared_ptr<mkldnn::memory> src_memory_p) { | ||
| auto prim_key = key_ + "@reorder_p"; | ||
| auto reorder_p = | ||
| std::static_pointer_cast<mkldnn::reorder>(dev_ctx_.GetBlob(prim_key)); | ||
| if (reorder_p == nullptr) { | ||
| reorder_p = | ||
| std::make_shared<mkldnn::reorder>(*(src_memory_p), *(dst_memory_p)); | ||
| dev_ctx_.SetBlob(prim_key, reorder_p); | ||
| } | ||
| return reorder_p; | ||
| return std::make_shared<mkldnn::reorder>(*(src_memory_p), *(dst_memory_p)); | ||
| } | ||
|
|
||
| private: | ||
| std::vector<int64_t> dims_; | ||
| framework::proto::VarType::Type vtype_, vtype_dst_; | ||
| mkldnn::memory::data_type dtype_, dtype_dst_; | ||
|
Comment on lines
1138
to
1139
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the benefit of storing both |
||
| mkldnn::engine engine_; | ||
| }; | ||
|
|
||
| template <typename T> | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please, revert it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My bad when resolving conflicts. Good catch! Thanks