diff --git a/src/utils/environment.h b/src/utils/environment.h index 2630ff74..ddcb2df4 100644 --- a/src/utils/environment.h +++ b/src/utils/environment.h @@ -68,6 +68,9 @@ class Env { // get OneCCL Enabled bool getOneCCLEnabled() { return oneCCLEnabled; } + // get Primitive Cache M + int getPrimitiveCacheM() { return primitiveCacheM; } + private: Env() { // init Verbose @@ -105,6 +108,9 @@ class Env { // init OneCCL Enabled initOneCCLEnabled(); + + // init Primitive Cache M + initPrimitiveCacheM(); } // Verbose @@ -260,4 +266,19 @@ class Env { char *xftOneCCLValue = getenv("XFT_ONECCL"); oneCCLEnabled = xftOneCCLValue != nullptr ? std::atoi(xftOneCCLValue) : false; } + + // XFT_PRIMITIVE_CACHE_M + int primitiveCacheM = 256; + void initPrimitiveCacheM() { + char *xFTPrimitiveCacheMValue = getenv("XFT_PRIMITIVE_CACHE_M"); + if (xFTPrimitiveCacheMValue != NULL) { + int value = atoi(xFTPrimitiveCacheMValue); + if (value >= 0) + primitiveCacheM = value; + else + printf("[ERROR] XFT_PRIMITIVE_CACHE_M value need to be greater than or equal to 0.\n"); + } else { + primitiveCacheM = 256; + } + } }; \ No newline at end of file diff --git a/src/utils/matmul_helper.h b/src/utils/matmul_helper.h index a8cc291c..d1d44464 100644 --- a/src/utils/matmul_helper.h +++ b/src/utils/matmul_helper.h @@ -54,6 +54,7 @@ class MMHelper { } AMXThresholdM = Env::getInstance().getAMXThresholdM(); + primitiveCacheM = Env::getInstance().getPrimitiveCacheM(); cpu_engine = new dnnl::engine(dnnl::engine::kind::cpu, 0); cpu_stream = new dnnl::stream(*cpu_engine); } @@ -1559,6 +1560,7 @@ class MMHelper { dnnl::stream *cpu_stream; int AMXThresholdM; + int primitiveCacheM; enum matmul_kinds { Basic = 0, @@ -1571,13 +1573,26 @@ class MMHelper { Resext, }; - template - std::string create_key(bool transA, int M, int N, int K, int matmul_kind, const Twei *packedB) { + std::string create_key(bool transA, int M, int N, int K, int matmul_kind) { std::stringstream key; - key << transA << "_" << M << "_" << N << "_" << K << "_" << matmul_kind << "_" << packedB; + key << transA << "_" << M << "_" << N << "_" << K << "_" << matmul_kind; return key.str(); } + // Cache primitive_desc and matmul + bool cache_matmul_primitive(dnnl::matmul::primitive_desc *matmul_pd, dnnl::matmul *matmul_prim, bool transA, int M, + int N, int K, int matmul_kind) { + // If M < primitiveCacheM or a power of 2, then cache. + if (M <= primitiveCacheM || ((M & (M - 1)) == 0)) { + std::string key = create_key(transA, M, N, K, matmul_kind); + std::tuple value(matmul_pd, matmul_prim); + matmul_hub[key] = value; + return true; + } else { + return false; + } + } + dnnl::memory::format_tag get_onednn_input_layout(dnnl::memory::data_type dt) { if (this->kind == dnnl::engine::kind::cpu) { return dnnl::memory::format_tag::ab; @@ -1722,7 +1737,7 @@ class MMHelper { matmul::primitive_desc *matmul_pd; matmul *matmul_prim; - std::string key = create_key(transA, M, N, K, postAlg, (const float *)nullptr); + std::string key = create_key(transA, M, N, K, postAlg); auto it = matmul_hub.find(key); if (it != matmul_hub.end()) { matmul_pd = std::get<0>(it->second); @@ -1797,9 +1812,7 @@ class MMHelper { matmul_prim = new matmul(*matmul_pd); // Cache primitive_desc and matmul - std::string key = create_key(transA, M, N, K, postAlg, (const float *)nullptr); - std::tuple value(matmul_pd, matmul_prim); - matmul_hub[key] = value; + cache_matmul_primitive(matmul_pd, matmul_prim, transA, M, N, K, postAlg); } // Repack and convert input data. @@ -1875,7 +1888,7 @@ class MMHelper { matmul::primitive_desc *matmul_pd; matmul *matmul_prim; - std::string key = create_key(transA, M, N, K, postAlg, packedB); + std::string key = create_key(transA, M, N, K, postAlg); auto it = matmul_hub.find(key); if (it != matmul_hub.end()) { matmul_pd = std::get<0>(it->second); @@ -1937,9 +1950,7 @@ class MMHelper { } matmul_prim = new matmul(*matmul_pd); // Cache primitive_desc and matmul - std::string key = create_key(transA, M, N, K, postAlg, packedB); - std::tuple value(matmul_pd, matmul_prim); - matmul_hub[key] = value; + cache_matmul_primitive(matmul_pd, matmul_prim, transA, M, N, K, postAlg); } // Repack and convert input data. @@ -1993,7 +2004,7 @@ class MMHelper { matmul::primitive_desc *matmul_pd; matmul *matmul_prim; - std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd, packedB); + std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd); auto it = matmul_hub.find(key); if (it != matmul_hub.end()) { matmul_pd = std::get<0>(it->second); @@ -2028,9 +2039,7 @@ class MMHelper { matmul_prim = new matmul(*matmul_pd); // Cache primitive_desc and matmul - std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd, packedB); - std::tuple value(matmul_pd, matmul_prim); - matmul_hub[key] = value; + cache_matmul_primitive(matmul_pd, matmul_prim, transA, M, N, K, matmul_kinds::BiasAdd); } // Repack and convert input data. @@ -2086,7 +2095,7 @@ class MMHelper { matmul::primitive_desc *matmul_pd; matmul *matmul_prim; - std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd_Relu, packedB); + std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd_Relu); auto it = matmul_hub.find(key); if (it != matmul_hub.end()) { matmul_pd = std::get<0>(it->second); @@ -2129,9 +2138,7 @@ class MMHelper { matmul_prim = new matmul(*matmul_pd); // Cache primitive_desc and matmul - std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd_Relu, packedB); - std::tuple value(matmul_pd, matmul_prim); - matmul_hub[key] = value; + cache_matmul_primitive(matmul_pd, matmul_prim, transA, M, N, K, matmul_kinds::BiasAdd_Relu); } // Repack and convert input data. @@ -2187,7 +2194,7 @@ class MMHelper { matmul::primitive_desc *matmul_pd; matmul *matmul_prim; - std::string key = create_key(transA, M, N, K, matmul_kinds::Resmul, packedB); + std::string key = create_key(transA, M, N, K, matmul_kinds::Resmul); auto it = matmul_hub.find(key); if (it != matmul_hub.end()) { matmul_pd = std::get<0>(it->second); @@ -2232,9 +2239,7 @@ class MMHelper { matmul_prim = new matmul(*matmul_pd); // Cache primitive_desc and matmul - std::string key = create_key(transA, M, N, K, matmul_kinds::Resmul, packedB); - std::tuple value(matmul_pd, matmul_prim); - matmul_hub[key] = value; + cache_matmul_primitive(matmul_pd, matmul_prim, transA, M, N, K, matmul_kinds::Resmul); } // Repack and convert input data. @@ -2307,7 +2312,7 @@ class MMHelper { matmul::primitive_desc *matmul_pd; matmul *matmul_prim; - std::string key = create_key(transA, M, N, K, matmul_kinds::Residential, packedB); + std::string key = create_key(transA, M, N, K, matmul_kinds::Residential); auto it = matmul_hub.find(key); if (it != matmul_hub.end()) { matmul_pd = std::get<0>(it->second); @@ -2359,9 +2364,7 @@ class MMHelper { } // Cache primitive_desc and matmul - std::string key = create_key(transA, M, N, K, matmul_kinds::Residential, packedB); - std::tuple value(matmul_pd, matmul_prim); - matmul_hub[key] = value; + cache_matmul_primitive(matmul_pd, matmul_prim, transA, M, N, K, matmul_kinds::Residential); } // Repack and convert input data. @@ -2425,7 +2428,7 @@ class MMHelper { matmul::primitive_desc *matmul_pd; matmul *matmul_prim; - std::string key = create_key(transA, M, N, K, matmul_kinds::Basic, B); + std::string key = create_key(transA, M, N, K, matmul_kinds::Basic); auto it = matmul_hub.find(key); if (it != matmul_hub.end()) { matmul_pd = std::get<0>(it->second); @@ -2447,9 +2450,7 @@ class MMHelper { matmul_prim = new matmul(*matmul_pd); // Cache primitive_desc and matmul - std::string key = create_key(transA, M, N, K, matmul_kinds::Basic, B); - std::tuple value(matmul_pd, matmul_prim); - matmul_hub[key] = value; + cache_matmul_primitive(matmul_pd, matmul_prim, transA, M, N, K, matmul_kinds::Basic); } auto input_mem = memory(matmul_pd->src_desc(), *engine, const_cast(A));