-
Notifications
You must be signed in to change notification settings - Fork 6k
add mkldnn shapeblob cache clear strategy #18513
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -407,44 +407,67 @@ thread_local size_t cur_mkldnn_session_id = kMKLDNNSessionID_Default; | |
| // - For fixed-shape, it's a null string in default. | ||
| // - For dynamic-shape, it's user specific. | ||
| thread_local std::string cur_input_shape_str = ""; | ||
| // the cache capacity of different input shapes for MKLDNN. | ||
| // Default 1 means fixed input shape, not dynamic shape. | ||
| thread_local int cur_input_shape_cache_capacity = 1; | ||
| } // namespace | ||
|
|
||
| void set_cur_mkldnn_session_id(size_t sid) { cur_mkldnn_session_id = sid; } | ||
| size_t get_cur_mkldnn_session_id(void) { return cur_mkldnn_session_id; } | ||
| void set_cur_input_shape_str(std::string input_shape_str) { | ||
| cur_input_shape_str = input_shape_str; | ||
| } | ||
| std::string get_cur_input_shape_str(void) { return cur_input_shape_str; } | ||
| void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity) { | ||
| cur_input_shape_cache_capacity = input_shape_cache_capacity; | ||
| } | ||
|
|
||
| void MKLDNNDeviceContext::ResetBlobMap() const { p_blobmap_->clear(); } | ||
|
|
||
| size_t MKLDNNDeviceContext::GetShapeBlobSize() const { | ||
| BlobMap* pMap = p_blobmap_.get(); | ||
| auto map_it = pMap->find(cur_mkldnn_session_id); | ||
| if (map_it == pMap->end()) { | ||
| LOG(FATAL) << "MKLDNNDeviceContext don't find cur_mkldnn_session_id : " | ||
| << cur_mkldnn_session_id; | ||
| } | ||
| return map_it->second->size(); | ||
| } | ||
|
|
||
| void MKLDNNDeviceContext::SetBlob(const std::string& name, | ||
| std::shared_ptr<void> data) const { | ||
| BlobMap* pMap = p_blobmap_.get(); | ||
| std::shared_ptr<ShapeBlob> sBlob = nullptr; | ||
| std::shared_ptr<KeyBlob> pBlob = nullptr; | ||
|
|
||
| int tid = platform::get_cur_mkldnn_session_id(); | ||
| int sid = platform::get_cur_mkldnn_session_id(); | ||
|
|
||
| std::lock_guard<std::mutex> lock(*p_mutex_); | ||
|
|
||
| // Find ShapeBlob for current thread | ||
| auto map_it = pMap->find(tid); | ||
| // Find ShapeBlob for current mkldnn session id. | ||
| auto map_it = pMap->find(sid); | ||
|
|
||
| if (map_it == pMap->end()) { | ||
| // 1st time to set blob in current thread | ||
| sBlob = std::shared_ptr<ShapeBlob>(new ShapeBlob()); | ||
| (*pMap)[tid] = sBlob; | ||
| VLOG(2) << "SetBlob: tid=" << tid << ", add new tid\n"; | ||
| (*pMap)[sid] = sBlob; | ||
| VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n"; | ||
| } else { | ||
| sBlob = map_it->second; | ||
| } | ||
|
|
||
| // Find KeyBlob for current input shape | ||
| std::string cur_input_shape_str = platform::get_cur_input_shape_str(); | ||
| auto key_it = sBlob->find(cur_input_shape_str); | ||
|
|
||
| if (key_it == sBlob->end()) { | ||
| // In cache clearing mode, cur_input_shape_cache_capacity defines | ||
| // max pblob capacity | ||
| if ((sid == kMKLDNNSessionID_CacheClearing) && | ||
| (sBlob->size() == | ||
|
||
| static_cast<size_t>(cur_input_shape_cache_capacity))) { | ||
| VLOG(2) << "sid=" << sid | ||
| << ", remove all blobs of shape: " << sBlob->begin()->first; | ||
| sBlob->erase(sBlob->begin()->first); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why here remove sBlob->begin()->first? seems it is wrong, should be erase(sBlob->begin()), sBlob->begin()->first is string.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. map could erase by key, see: https://www.geeksforgeeks.org/map-erase-function-in-c-stl/
Got it. I will change the LOG. For code, since we can erase any one of the sBlob, it runs successfully.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you are right, there are 3 erase functions. |
||
| } | ||
| pBlob = std::shared_ptr<KeyBlob>(new KeyBlob()); | ||
| (*sBlob)[cur_input_shape_str] = pBlob; | ||
| } else { | ||
|
|
@@ -458,7 +481,7 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name, | |
| } else { | ||
| blob_it->second = data; // set data to existing blob | ||
| } | ||
| VLOG(2) << "SetBlob: tid=" << tid << ", add blob=" << name << "\n"; | ||
| VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n"; | ||
| // lock will be automatically released when out of scope | ||
| return; | ||
| } | ||
|
|
@@ -469,23 +492,22 @@ std::shared_ptr<void> MKLDNNDeviceContext::GetBlob( | |
| std::shared_ptr<ShapeBlob> sBlob = nullptr; | ||
| std::shared_ptr<KeyBlob> pBlob = nullptr; | ||
|
|
||
| int tid = platform::get_cur_mkldnn_session_id(); | ||
| int sid = platform::get_cur_mkldnn_session_id(); | ||
|
|
||
| std::lock_guard<std::mutex> lock(*p_mutex_); | ||
|
|
||
| // Find ShapeBlob for current thread firstly | ||
| auto map_it = pMap->find(tid); | ||
| // Find ShapeBlob for current mkldnn session id firstly | ||
| auto map_it = pMap->find(sid); | ||
| if (map_it == pMap->end()) { | ||
| VLOG(2) << "GetBlob: tid=" << tid << ", miss tid\n"; | ||
| VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n"; | ||
| return nullptr; | ||
| } | ||
| std::string cur_input_shape_str = platform::get_cur_input_shape_str(); | ||
| sBlob = map_it->second; | ||
|
|
||
| // Find KeyBlob for current input shape secondly | ||
| auto sBlob_it = sBlob->find(cur_input_shape_str); | ||
| if (sBlob_it == sBlob->end()) { | ||
| VLOG(2) << "GetBlob: tid=" << cur_input_shape_str | ||
| VLOG(2) << "GetBlob: sid=" << cur_input_shape_str | ||
| << ", miss input_shape_str\n"; | ||
| return nullptr; | ||
| } | ||
|
|
@@ -495,11 +517,11 @@ std::shared_ptr<void> MKLDNNDeviceContext::GetBlob( | |
| auto key_it = pBlob->find(name); | ||
|
|
||
| if (key_it == pBlob->end()) { | ||
| VLOG(2) << "GetBlob tid=" << tid << ", miss blob=" << name << "\n"; | ||
| VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n"; | ||
| return nullptr; | ||
| } | ||
|
|
||
| VLOG(2) << "GetBlob tid=" << tid << ", get blob=" << name << "\n"; | ||
| VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n"; | ||
| // lock will be automatically released when out of scope | ||
| return key_it->second; | ||
| } | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
currently it is only for UT purposes? Perhaps it would be safer to guard whole function with critical section , the same mutex as for SetBlob and GetBlob. In particular if there is plan that to use GetShapeBlobSize with parallel executor. Apart from that LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got it. You are right.