add mkldnn shapeblob cache clear strategy #18513
luotao1 merged 4 commits into PaddlePaddle:develop from luotao1:shape_blob_clear_cache
test=develop
@LeoZhao-Intel Please review!
thread_local std::string cur_input_shape_str = "";
// the cache size of different input shapes for MKLDNN.
// Default 1 means fixed input shape, not dynamic shape.
thread_local int cur_input_shape_cache_size = 1;
How about cur_input_shape_cache_capacity instead of size? "size" suggests the number of entries currently stored, whereas this variable, as I understand it, is the maximum cache size.
Got it. I will change.
  cur_input_shape_str = input_shape_str;
}
std::string get_cur_input_shape_str(void) { return cur_input_shape_str; }
void set_cur_input_shape_cache_size(int input_shape_cache_size) {
set_cur_input_shape_cache_capacity?
Got it. I will change.
void MKLDNNDeviceContext::ResetBlobMap() const { p_blobmap_->clear(); }

size_t MKLDNNDeviceContext::GetShapeBlobSize(int mkldnn_session_id) const {
Why do we need an explicit input parameter mkldnn_session_id? Can we just get it from thread_local? There seems to be no case where a user needs the shape blob size for a session ID other than the one stored in thread_local.
> can we just get it from thread_local
Could you give more details on how to get it from thread_local?
> why need a explicit input parameter mkldnn_session_id?
Without the mkldnn_session_id parameter, the lookup pMap->find(mkldnn_session_id) at line428 would have nothing to search for. Or do you mean that there is only one session_id in pMap?
mkldnn_session_id is already stored in cur_mkldnn_session_id, which is a thread_local variable. So if the user calls this function in the same thread as predictor.run, it can simply read the ID from cur_mkldnn_session_id.
If the user calls GetShapeBlobSize in the same thread as predictor.run, then the code looks like:
size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
  BlobMap* pMap = p_blobmap_.get();
  // Look up the session ID stored in the thread_local variable.
  auto map_it = pMap->find(cur_mkldnn_session_id);
  if (map_it == pMap->end()) {
    LOG(FATAL) << "MKLDNNDeviceContext don't find mkldnn_session_id : "
               << cur_mkldnn_session_id;
  }
  return map_it->second->size();
}
If we still want users to be able to call this function from other threads, then the parameter is necessary.
Got it. I will change it.
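The trade-off discussed in this thread can be sketched outside of Paddle as follows. This is a minimal illustration under stated assumptions, not the actual implementation: `BlobMap`, `ShapeBlob`, `cur_session_id`, and both free functions are simplified, hypothetical stand-ins. The parameterless overload reads the calling thread's `thread_local` ID, while the overload with an explicit ID is what a caller on a different thread would need.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>

// Simplified stand-ins for the structures in the PR (assumptions, not Paddle code).
using ShapeBlob = std::unordered_map<std::string, int>;     // shape string -> blob
using BlobMap = std::unordered_map<std::size_t, ShapeBlob>; // session ID -> shapes

// Every thread has its own copy of this variable, initialised to 0.
thread_local std::size_t cur_session_id = 0;

// Explicit-ID variant: usable from any thread, because the caller names the session.
std::size_t GetShapeBlobSize(const BlobMap& blobs, std::size_t session_id) {
  auto it = blobs.find(session_id);
  return it == blobs.end() ? 0 : it->second.size();
}

// Parameterless variant: only meaningful on the thread that set cur_session_id.
std::size_t GetShapeBlobSize(const BlobMap& blobs) {
  return GetShapeBlobSize(blobs, cur_session_id);
}
```

The parameterless form is more convenient, but it silently depends on the caller running on the thread that set the session ID.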
if (key_it == sBlob->end()) {
  // In cache clearing mode, cur_input_shape_cache_size defines max pblob
  // capacity
  if ((tid == kMKLDNNSessionID_CacheClearing) &&
It would be better to rename tid to sid (session ID) to align with kMKLDNNSessionID_xxx.
Got it. I will change.
// default mkldnn session id
-constexpr size_t kMKLDNNSessionID_Default = 0;
+constexpr int kMKLDNNSessionID_Default = 0;
Why change it to int? size_t is used to match the result type of std::hash applied to the thread ID (std::this_thread::get_id()); otherwise a static_cast may be needed.
Got it. I will change.
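The type argument can be checked with a small standalone sketch. It assumes the session ID is derived by hashing the current thread's ID, as the review comment suggests; `kSessionIdDefault` and `SessionIdFromThread` are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <thread>
#include <type_traits>

// Hypothetical default session ID, kept as size_t to match the hash result type.
constexpr std::size_t kSessionIdDefault = 0;

// Derives a session ID from the calling thread; the hash result is std::size_t.
std::size_t SessionIdFromThread() {
  return std::hash<std::thread::id>()(std::this_thread::get_id());
}

// Compile-time check that no cast is needed when the ID type is size_t.
static_assert(
    std::is_same<decltype(std::hash<std::thread::id>()(std::this_thread::get_id())),
                 std::size_t>::value,
    "std::hash returns std::size_t");
```

With an `int` constant, comparing against or storing such a hash value would mix signed and unsigned types and typically call for a `static_cast`.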
    (sBlob->size() == static_cast<size_t>(cur_input_shape_cache_size))) {
  VLOG(2) << "tid=" << tid
          << ", remove all head blob of shape: " << sBlob->begin()->first;
  sBlob->erase(sBlob->begin()->first);
Why erase sBlob->begin()->first here? It seems wrong; it should be erase(sBlob->begin()), since sBlob->begin()->first is a string.
BTW, this doesn't really remove the head or first element, because sBlob is a std::unordered_map, which does not keep its elements in insertion order the way a vector or queue does.
A map can erase by key; see: https://www.geeksforgeeks.org/map-erase-function-in-c-stl/
> it doesn't really mean removing the head or first one because sBlob is
Got it. I will change the LOG. As for the code, since we can erase any one entry of sBlob, it works correctly.
You are right, there are three erase functions.
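For reference, the three ways of erasing from a `std::unordered_map` mentioned above can be sketched as follows. `ShapeBlob` here is a simplified, hypothetical stand-in that maps a shape string to an `int` rather than to a real blob.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>

using ShapeBlob = std::unordered_map<std::string, int>;  // hypothetical stand-in

// Erase by key: returns the number of elements removed (0 or 1).
std::size_t EraseByKey(ShapeBlob& blobs, const std::string& key) {
  return blobs.erase(key);
}

// Erase by iterator: removes whichever element begin() refers to.
// For an unordered_map this is an arbitrary element, not the oldest one.
void EraseAny(ShapeBlob& blobs) {
  if (!blobs.empty()) blobs.erase(blobs.begin());
}

// Erase by iterator range: removes every element in [begin, end).
void EraseAll(ShapeBlob& blobs) { blobs.erase(blobs.begin(), blobs.end()); }
```

Both the key form and the iterator form remove exactly one entry when it exists, which is why the original line runs correctly even though "head" is a misleading description of what gets removed.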
test=develop
All done, please review again @LeoZhao-Intel
// In cache clearing mode, cur_input_shape_cache_capacity defines
// max pblob capacity
if ((sid == kMKLDNNSessionID_CacheClearing) &&
    (sBlob->size() ==
Do you think ">=" would be better here?
When the size == the capacity, an entry is erased, so the size can never exceed the capacity.
In what case would > apply?
So far I don't see any case for >, but at the code level it looks safer to use >=.
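The capacity check under discussion can be illustrated with a small bounded cache. This is a sketch under stated assumptions, not the Paddle code: `BoundedShapeCache` and its members are hypothetical, blobs are reduced to `int`, and the capacity is assumed to be at least 1. It uses `>=` as suggested, so the cache still shrinks back if the size ever ends up above the capacity.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>

// Hypothetical bounded cache keyed by input-shape string.
class BoundedShapeCache {
 public:
  explicit BoundedShapeCache(std::size_t capacity) : capacity_(capacity) {}

  // Inserts or updates a shape. A new shape evicts one arbitrary entry
  // when the cache is already at (or above) capacity.
  void Insert(const std::string& shape, int blob) {
    const bool is_new = cache_.find(shape) == cache_.end();
    if (is_new && !cache_.empty() && cache_.size() >= capacity_) {
      cache_.erase(cache_.begin());  // arbitrary entry, not necessarily the oldest
    }
    cache_[shape] = blob;
  }

  bool Contains(const std::string& shape) const {
    return cache_.find(shape) != cache_.end();
  }

  std::size_t Size() const { return cache_.size(); }

 private:
  std::size_t capacity_;
  std::unordered_map<std::string, int> cache_;
};
```

With `==`, the invariant holds only as long as every insertion goes through this check; `>=` costs nothing extra and tolerates a capacity that is lowered later or an entry added through another path.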
test=develop
LGTM
@jczaja Please review!
size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
  BlobMap* pMap = p_blobmap_.get();
  auto map_it = pMap->find(cur_mkldnn_session_id);
Is this currently only for unit-test purposes? It would perhaps be safer to guard the whole function with a critical section, using the same mutex as SetBlob and GetBlob, particularly if there is a plan to use GetShapeBlobSize with the parallel executor. Apart from that, LGTM.
Got it. You are right.
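The suggestion to share one mutex between the writers and the size query can be sketched as follows. The class and its members are hypothetical and heavily simplified (blobs are plain `int` values); the only point is that `GetShapeBlobSize` takes the same lock as `SetBlob`, so it never reads the map while another thread is modifying it.

```cpp
#include <cassert>
#include <cstddef>
#include <mutex>
#include <string>
#include <unordered_map>

// Hypothetical, simplified cache whose accessors all share one mutex.
class GuardedShapeCache {
 public:
  void SetBlob(const std::string& shape, int blob) {
    std::lock_guard<std::mutex> lock(mutex_);
    cache_[shape] = blob;
  }

  // Guarded by the same mutex as SetBlob, so the size is read consistently.
  std::size_t GetShapeBlobSize() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return cache_.size();
  }

 private:
  mutable std::mutex mutex_;  // mutable so that const methods can lock it
  std::unordered_map<std::string, int> cache_;
};
```

Declaring the mutex `mutable` is what allows a `const` query method to lock it.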
@LeoZhao-Intel @jianhang-liu The number of
test=develop
@jczaja @LeoZhao-Intel I add the mutex, and
LGTM
ghost left a comment
Please consider whether we really need VLOG(2), given that these logs are just detailed information about cache usage.
Since those logs are used by MKLDNN only, what's your opinion about it? @LeoZhao-Intel @jczaja
I am fine with VLOG(2), but I am not sure whether there is a clear rule for defining log levels.

This PR
- When shapeblob.size() == mkldnn_input_shape_cache_size, erase the first one of shapeblob.
- TestMkldnnCacheClear to ensure this cache clear strategy:
  - shapeblob.size() always = 1.
  - shapeblob.size() = mkldnn_input_shape_cache_size.
- platform::get_cur_input_shape_str function, and add size_t MKLDNNDeviceContext::GetShapeBlobSize(int mkldnn_session_id) to get the ShapeBlob size by mkldnn_session_id.

TODO
- platform::set_cur_input_shape_str(ss.str()); and platform::set_cur_mkldnn_session_id(platform::kMKLDNNSessionID_CacheClearing); will be deprecated after the config.EnableMKLDNN() interface is enhanced.
- platform::set_cur_input_shape_str(ss.str());, platform::set_cur_mkldnn_session_id(platform::kMKLDNNSessionID_CacheClearing); etc.: make them class functions of MKLDNNDeviceContext.