Skip to content
Merged
49 changes: 49 additions & 0 deletions paddle/fluid/framework/transfer_scope_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,61 @@
namespace paddle {
namespace framework {

#ifdef PADDLE_WITH_MKLDNN
// Value types handed out by global_transfer_data_cache() and
// global_transfer_scope_cache() below.
using transfer_data_cache_map = std::unordered_map<size_t, Scope*>;
using transfer_scope_cache_map = std::unordered_set<Scope*>;
// Process-wide registries used when a non-default mkldnn session id is set.
// Keyed by the hash of the calling thread's id (see the accessors below).
// Entries are heap-allocated and never freed in the code shown here — they
// are intended to live for the whole process. Insertion/lookup is serialized
// by a function-local mutex inside each accessor.
static std::unordered_map<size_t, transfer_data_cache_map*>
static_transfer_data_caches;
static std::unordered_map<size_t, transfer_scope_cache_map*>
static_transfer_scope_caches;
#endif

// Returns the transfer-data cache for the calling context.
//
// When the user has set a non-default mkldnn session id, the cache is looked
// up (and lazily created under a mutex) in the process-wide registry, keyed
// by the hash of the current thread's id. Otherwise a thread_local cache is
// used. Either way the returned map is heap-allocated and never freed — it
// lives until process exit.
std::unordered_map<size_t, Scope*>& global_transfer_data_cache() {
#ifdef PADDLE_WITH_MKLDNN
  size_t sid = platform::get_cur_mkldnn_session_id();

  // A user-provided mkldnn session id selects the static-registry path.
  if (sid != platform::kMKLDNNSessionID_Default) {
    sid = std::hash<std::thread::id>()(std::this_thread::get_id());

    // Serialize lookup/creation so two threads cannot race on the registry.
    static std::mutex acquire_barrier;
    std::lock_guard<std::mutex> lock(acquire_barrier);

    // operator[] default-inserts a null slot for a new key; fill it lazily.
    auto*& slot = static_transfer_data_caches[sid];
    if (slot == nullptr) {
      slot = new transfer_data_cache_map;
    }
    return *slot;
  }
#endif
  thread_local auto* cache = new std::unordered_map<size_t, Scope*>;
  return *cache;
}

// Returns the transfer-scope cache for the calling context.
//
// Mirrors global_transfer_data_cache(): a non-default mkldnn session id
// routes to the process-wide registry (lazily populated under a mutex and
// keyed by the hash of the current thread's id); the default session id
// routes to a thread_local set. The returned set is heap-allocated and is
// never freed — it lives until process exit.
std::unordered_set<Scope*>& global_transfer_scope_cache() {
#ifdef PADDLE_WITH_MKLDNN
  size_t sid = platform::get_cur_mkldnn_session_id();

  // A user-provided mkldnn session id selects the static-registry path.
  if (sid != platform::kMKLDNNSessionID_Default) {
    sid = std::hash<std::thread::id>()(std::this_thread::get_id());

    // Guard the registry against concurrent first-time insertion.
    static std::mutex acquire_barrier;
    std::lock_guard<std::mutex> lock(acquire_barrier);

    // operator[] default-inserts a null slot for a new key; fill it lazily.
    auto*& entry = static_transfer_scope_caches[sid];
    if (entry == nullptr) {
      entry = new transfer_scope_cache_map;
    }
    return *entry;
  }
#endif
  thread_local auto* cache = new std::unordered_set<Scope*>;
  return *cache;
}
Expand Down
19 changes: 16 additions & 3 deletions paddle/fluid/inference/tests/api/analyzer_bert_tester.cc
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ TEST(Analyzer_bert, compare_determine) {
inputs);
}

TEST(Analyzer_bert, transfer_scope_cache) {
void verify_transfer_scope_cache(bool is_static = false) {
AnalysisConfig config;
SetConfig(&config);

Expand All @@ -251,6 +251,7 @@ TEST(Analyzer_bert, transfer_scope_cache) {
threads.emplace_back([&, i]() {
std::getline(fin, line);
ParseLine(line, &input);
if (is_static) platform::set_cur_mkldnn_session_id(1);
predictor->Run(input, &output, FLAGS_batch_size);
global_transfer_scope_cache.insert(
&paddle::framework::global_transfer_scope_cache());
Expand All @@ -264,9 +265,21 @@ TEST(Analyzer_bert, transfer_scope_cache) {
// Since paddle::framework::global_transfer_scope_cache() and
// paddle::framework::global_transfer_data_cache() are thread_local,
// their pointer should be different among different thread id.
PADDLE_ENFORCE(global_transfer_scope_cache.size(), threads_num);
PADDLE_ENFORCE(global_transfer_data_cache.size(), threads_num);
if (is_static) {
PADDLE_ENFORCE(global_transfer_scope_cache.size(), 1);
PADDLE_ENFORCE(global_transfer_data_cache.size(), 1);
} else {
PADDLE_ENFORCE(global_transfer_scope_cache.size(), threads_num);
PADDLE_ENFORCE(global_transfer_data_cache.size(), threads_num);
}
}

// Default mkldnn session id: every worker thread gets its own thread_local
// transfer scope/data cache, so the verifier expects one cache per thread.
// NOTE: the arguments were previously swapped between the two tests — the
// "threadlocal" test passed is_static=true and vice versa; fixed so each
// test name matches the mode it exercises.
TEST(Analyzer_bert, threadlocal_transfer_scope_cache) {
  verify_transfer_scope_cache();
}

// User-set mkldnn session id (is_static=true): the caches come from static
// storage, so the verifier expects a single cache instance (see the size-1
// checks in verify_transfer_scope_cache).
TEST(Analyzer_bert, static_transfer_scope_cache) {
  verify_transfer_scope_cache(true);
}
} // namespace inference
} // namespace paddle