@@ -341,8 +341,6 @@ void AnalysisPredictor::MkldnnPreSet(
341341 platform::MKLDNNDeviceContext::tls ().set_cur_mkldnn_session_id (
342342 platform::MKLDNNDeviceContextThreadLocals::
343343 kMKLDNNSessionID_CacheClearing );
344- platform::MKLDNNDeviceContext::tls ().set_cur_input_shape_cache_capacity (
345- config_.mkldnn_cache_capacity_ );
346344 // Set current_input_shape for caching dynamic shape.
347345 std::stringstream ss;
348346 for (size_t i = 0 ; i < inputs_shape.size (); ++i) {
@@ -353,6 +351,9 @@ void AnalysisPredictor::MkldnnPreSet(
353351 VLOG (2 ) << " Set input shape=" << ss.str ();
354352 platform::MKLDNNDeviceContext::tls ().set_cur_input_shape_str (ss.str ());
355353 }
354+ platform::MKLDNNDeviceContext::tls ().set_cur_input_shape_cache_capacity (
355+ config_.mkldnn_cache_capacity_ );
356+
356357#endif
357358}
358359
@@ -368,10 +369,9 @@ void AnalysisPredictor::MkldnnPostReset() {
368369 CHECK_LE (shape_blob_size,
369370 static_cast <size_t >(config_.mkldnn_cache_capacity_ ));
370371 }
371- paddle::platform::MKLDNNDeviceContext::tls ().set_cur_mkldnn_session_id (
372- platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default );
373- platform::MKLDNNDeviceContext::tls ().set_cur_input_shape_cache_capacity (0 );
374- platform::MKLDNNDeviceContext::tls ().set_cur_input_shape_str (" " );
372+ // We cannot reset to the default cache settings
373+ // as there maybe CopyToCPU method used and oneDNN
374+ // primitives are used there so cache would grow
375375 }
376376#endif
377377}
0 commit comments