Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions paddle/fluid/inference/tests/api/analyzer_mm_dnn_tester.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,5 +172,61 @@ TEST(Analyzer_MM_DNN, compare_determine) {
input_slots_all);
}

#ifdef PADDLE_WITH_MKLDNN
void TestMkldnnCacheClear(int mkldnn_input_shape_cache_capacity) {
AnalysisConfig config;
SetConfig(&config);
config.EnableMKLDNN();
// TODO(luotao): explicit following settings will be deprecated after enhance
// config.EnableMKLDNN() interface.
if (mkldnn_input_shape_cache_capacity > 0) {
platform::set_cur_mkldnn_session_id(
platform::kMKLDNNSessionID_CacheClearing);
platform::set_cur_input_shape_cache_capacity(
mkldnn_input_shape_cache_capacity);
}

std::vector<PaddleTensor> input, output;
auto predictor = CreatePaddlePredictor<AnalysisConfig>(config);

int sample_num = 10;
DataRecord data(FLAGS_infer_data, FLAGS_batch_size);

auto &pool = platform::DeviceContextPool::Instance();
auto *dev_ctx = dynamic_cast<platform::MKLDNNDeviceContext *>(
pool.Get(platform::CPUPlace()));
for (int i = 0; i < sample_num; i++) {
PrepareInputs(&input, &data, FLAGS_batch_size);
if (mkldnn_input_shape_cache_capacity > 0) {
std::stringstream ss;
for (size_t i = 0; i < input.size(); i++) {
for (size_t j = 0; j < input[i].shape.size(); ++j) {
ss << input[i].shape[j] << "-";
}
}
// TODO(luotao): explicit following settings will be deprecated after
// enhance config.EnableMKLDNN() interface.
platform::set_cur_input_shape_str(ss.str());
}
predictor->Run(input, &output, 1);
}
if (mkldnn_input_shape_cache_capacity > 0) {
PADDLE_ENFORCE_EQ(dev_ctx->GetShapeBlobSize(),
mkldnn_input_shape_cache_capacity);
} else {
PADDLE_ENFORCE_EQ(dev_ctx->GetShapeBlobSize(), 1UL);
}
dev_ctx->ResetBlobMap();
}

TEST(Analyzer_MM_DNN, mkldnn_cache_clear) {
  // Exercise both modes: capacity 0 disables the cache-clearing strategy,
  // while 4 enables it with an input-shape cache capacity of 4.
  for (int capacity : {0, 4}) {
    TestMkldnnCacheClear(capacity);
  }
}
#endif

} // namespace inference
} // namespace paddle
54 changes: 38 additions & 16 deletions paddle/fluid/platform/device_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -407,44 +407,67 @@ thread_local size_t cur_mkldnn_session_id = kMKLDNNSessionID_Default;
// - For fixed-shape, it's a null string in default.
// - For dynamic-shape, it's user specific.
thread_local std::string cur_input_shape_str = "";
// the cache capacity of different input shapes for MKLDNN.
// Default 1 means fixed input shape, not dynamic shape.
thread_local int cur_input_shape_cache_capacity = 1;
} // namespace

// Accessors for the thread-local MKL-DNN cache-control state declared above.
void set_cur_mkldnn_session_id(size_t sid) { cur_mkldnn_session_id = sid; }
size_t get_cur_mkldnn_session_id(void) { return cur_mkldnn_session_id; }
void set_cur_input_shape_str(std::string input_shape_str) {
  // The parameter is already a by-value copy: swap it into the thread-local
  // instead of copy-assigning, which would allocate and copy a second time.
  cur_input_shape_str.swap(input_shape_str);
}
std::string get_cur_input_shape_str(void) { return cur_input_shape_str; }
void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity) {
  cur_input_shape_cache_capacity = input_shape_cache_capacity;
}

void MKLDNNDeviceContext::ResetBlobMap() const { p_blobmap_->clear(); }

size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
BlobMap* pMap = p_blobmap_.get();
auto map_it = pMap->find(cur_mkldnn_session_id);
Copy link
Contributor

@jczaja jczaja Jul 5, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

currently it is only for UT purposes? Perhaps it would be safer to guard whole function with critical section , the same mutex as for SetBlob and GetBlob. In particular if there is plan that to use GetShapeBlobSize with parallel executor. Apart from that LGTM

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. You are right.

if (map_it == pMap->end()) {
LOG(FATAL) << "MKLDNNDeviceContext don't find cur_mkldnn_session_id : "
<< cur_mkldnn_session_id;
}
return map_it->second->size();
}

void MKLDNNDeviceContext::SetBlob(const std::string& name,
std::shared_ptr<void> data) const {
BlobMap* pMap = p_blobmap_.get();
std::shared_ptr<ShapeBlob> sBlob = nullptr;
std::shared_ptr<KeyBlob> pBlob = nullptr;

int tid = platform::get_cur_mkldnn_session_id();
int sid = platform::get_cur_mkldnn_session_id();

std::lock_guard<std::mutex> lock(*p_mutex_);

// Find ShapeBlob for current thread
auto map_it = pMap->find(tid);
// Find ShapeBlob for current mkldnn session id.
auto map_it = pMap->find(sid);

if (map_it == pMap->end()) {
// 1st time to set blob in current thread
sBlob = std::shared_ptr<ShapeBlob>(new ShapeBlob());
(*pMap)[tid] = sBlob;
VLOG(2) << "SetBlob: tid=" << tid << ", add new tid\n";
(*pMap)[sid] = sBlob;
VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
} else {
sBlob = map_it->second;
}

// Find KeyBlob for current input shape
std::string cur_input_shape_str = platform::get_cur_input_shape_str();
auto key_it = sBlob->find(cur_input_shape_str);

if (key_it == sBlob->end()) {
// In cache clearing mode, cur_input_shape_cache_capacity defines
// max pblob capacity
if ((sid == kMKLDNNSessionID_CacheClearing) &&
(sBlob->size() ==
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you think">=" is better here ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when ==, it will clean. Thus > does not work.
What's the case when > works?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So far I still do not see a case where > could occur, but at the code level it looks safer to use >=.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it.

static_cast<size_t>(cur_input_shape_cache_capacity))) {
VLOG(2) << "sid=" << sid
<< ", remove all blobs of shape: " << sBlob->begin()->first;
sBlob->erase(sBlob->begin()->first);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why erase `sBlob->begin()->first` here? It seems wrong — it should be `erase(sBlob->begin())`, since `sBlob->begin()->first` is a string (the key).
BTW, this does not really remove the "head" or "first" element, because `sBlob` is a `std::unordered_map`, whose iteration order is unspecified, unlike a vector or queue.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

map could erase by key, see: https://www.geeksforgeeks.org/map-erase-function-in-c-stl/

it doesn't really mean removing the head or first one because sBlob is

Got it. I will change the LOG. For code, since we can erase any one of the sBlob, it runs successfully.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you are right, there are 3 erase functions.

}
pBlob = std::shared_ptr<KeyBlob>(new KeyBlob());
(*sBlob)[cur_input_shape_str] = pBlob;
} else {
Expand All @@ -458,7 +481,7 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
} else {
blob_it->second = data; // set data to existing blob
}
VLOG(2) << "SetBlob: tid=" << tid << ", add blob=" << name << "\n";
VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
// lock will be automatically released when out of scope
return;
}
Expand All @@ -469,23 +492,22 @@ std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
std::shared_ptr<ShapeBlob> sBlob = nullptr;
std::shared_ptr<KeyBlob> pBlob = nullptr;

int tid = platform::get_cur_mkldnn_session_id();
int sid = platform::get_cur_mkldnn_session_id();

std::lock_guard<std::mutex> lock(*p_mutex_);

// Find ShapeBlob for current thread firstly
auto map_it = pMap->find(tid);
// Find ShapeBlob for current mkldnn session id firstly
auto map_it = pMap->find(sid);
if (map_it == pMap->end()) {
VLOG(2) << "GetBlob: tid=" << tid << ", miss tid\n";
VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
return nullptr;
}
std::string cur_input_shape_str = platform::get_cur_input_shape_str();
sBlob = map_it->second;

// Find KeyBlob for current input shape secondly
auto sBlob_it = sBlob->find(cur_input_shape_str);
if (sBlob_it == sBlob->end()) {
VLOG(2) << "GetBlob: tid=" << cur_input_shape_str
VLOG(2) << "GetBlob: sid=" << cur_input_shape_str
<< ", miss input_shape_str\n";
return nullptr;
}
Expand All @@ -495,11 +517,11 @@ std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
auto key_it = pBlob->find(name);

if (key_it == pBlob->end()) {
VLOG(2) << "GetBlob tid=" << tid << ", miss blob=" << name << "\n";
VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
return nullptr;
}

VLOG(2) << "GetBlob tid=" << tid << ", get blob=" << name << "\n";
VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
// lock will be automatically released when out of scope
return key_it->second;
}
Expand Down
5 changes: 4 additions & 1 deletion paddle/fluid/platform/device_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ constexpr size_t kMKLDNNSessionID_CacheClearing = -1;
void set_cur_mkldnn_session_id(size_t);
size_t get_cur_mkldnn_session_id(void);
void set_cur_input_shape_str(std::string input_shape_str);
std::string get_cur_input_shape_str(void);
void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity);

class MKLDNNDeviceContext : public CPUDeviceContext {
public:
Expand All @@ -408,6 +408,9 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
// Remove all entries from the blob map
void ResetBlobMap() const;

// Get the ShapeBlob size in cur_mkldnn_session_id.
size_t GetShapeBlobSize() const;

// Set data to blob (i.e. name/data pair). Create blob if not existing
void SetBlob(const std::string& name, std::shared_ptr<void> data) const;

Expand Down