Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/utils/environment.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ class Env {
// get OneCCL Enabled
bool getOneCCLEnabled() { return oneCCLEnabled; }

// get Primitive Cache M
int getPrimitiveCacheM() { return primitiveCacheM; }

private:
Env() {
// init Verbose
Expand Down Expand Up @@ -105,6 +108,9 @@ class Env {

// init OneCCL Enabled
initOneCCLEnabled();

// init Primitive Cache M
initPrimitiveCacheM();
}

// Verbose
Expand Down Expand Up @@ -260,4 +266,19 @@ class Env {
char *xftOneCCLValue = getenv("XFT_ONECCL");
oneCCLEnabled = xftOneCCLValue != nullptr ? std::atoi(xftOneCCLValue) : false;
}

// XFT_PRIMITIVE_CACHE_M
int primitiveCacheM = 256;
void initPrimitiveCacheM() {
char *xFTPrimitiveCacheMValue = getenv("XFT_PRIMITIVE_CACHE_M");
if (xFTPrimitiveCacheMValue != NULL) {
int value = atoi(xFTPrimitiveCacheMValue);
if (value >= 0)
primitiveCacheM = value;
else
printf("[ERROR] XFT_PRIMITIVE_CACHE_M value need to be greater than or equal to 0.\n");
} else {
primitiveCacheM = 256;
}
}
};
63 changes: 32 additions & 31 deletions src/utils/matmul_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class MMHelper {
}

AMXThresholdM = Env::getInstance().getAMXThresholdM();
primitiveCacheM = Env::getInstance().getPrimitiveCacheM();
cpu_engine = new dnnl::engine(dnnl::engine::kind::cpu, 0);
cpu_stream = new dnnl::stream(*cpu_engine);
}
Expand Down Expand Up @@ -1559,6 +1560,7 @@ class MMHelper {
dnnl::stream *cpu_stream;

int AMXThresholdM;
int primitiveCacheM;

enum matmul_kinds {
Basic = 0,
Expand All @@ -1571,13 +1573,26 @@ class MMHelper {
Resext,
};

template <typename Twei>
std::string create_key(bool transA, int M, int N, int K, int matmul_kind, const Twei *packedB) {
std::string create_key(bool transA, int M, int N, int K, int matmul_kind) {
std::stringstream key;
key << transA << "_" << M << "_" << N << "_" << K << "_" << matmul_kind << "_" << packedB;
key << transA << "_" << M << "_" << N << "_" << K << "_" << matmul_kind;
return key.str();
}

// Cache primitive_desc and matmul
bool cache_matmul_primitive(dnnl::matmul::primitive_desc *matmul_pd, dnnl::matmul *matmul_prim, bool transA, int M,
int N, int K, int matmul_kind) {
// If M < primitiveCacheM or a power of 2, then cache.
if (M <= primitiveCacheM || ((M & (M - 1)) == 0)) {
std::string key = create_key(transA, M, N, K, matmul_kind);
std::tuple<dnnl::matmul::primitive_desc *, dnnl::matmul *> value(matmul_pd, matmul_prim);
matmul_hub[key] = value;
return true;
} else {
return false;
}
}

dnnl::memory::format_tag get_onednn_input_layout(dnnl::memory::data_type dt) {
if (this->kind == dnnl::engine::kind::cpu) {
return dnnl::memory::format_tag::ab;
Expand Down Expand Up @@ -1722,7 +1737,7 @@ class MMHelper {

matmul::primitive_desc *matmul_pd;
matmul *matmul_prim;
std::string key = create_key(transA, M, N, K, postAlg, (const float *)nullptr);
std::string key = create_key(transA, M, N, K, postAlg);
auto it = matmul_hub.find(key);
if (it != matmul_hub.end()) {
matmul_pd = std::get<0>(it->second);
Expand Down Expand Up @@ -1797,9 +1812,7 @@ class MMHelper {
matmul_prim = new matmul(*matmul_pd);

// Cache primitive_desc and matmul
std::string key = create_key(transA, M, N, K, postAlg, (const float *)nullptr);
std::tuple<dnnl::matmul::primitive_desc *, dnnl::matmul *> value(matmul_pd, matmul_prim);
matmul_hub[key] = value;
cache_matmul_primitive(matmul_pd, matmul_prim, transA, M, N, K, postAlg);
}

// Repack and convert input data.
Expand Down Expand Up @@ -1875,7 +1888,7 @@ class MMHelper {

matmul::primitive_desc *matmul_pd;
matmul *matmul_prim;
std::string key = create_key(transA, M, N, K, postAlg, packedB);
std::string key = create_key(transA, M, N, K, postAlg);
auto it = matmul_hub.find(key);
if (it != matmul_hub.end()) {
matmul_pd = std::get<0>(it->second);
Expand Down Expand Up @@ -1937,9 +1950,7 @@ class MMHelper {
}
matmul_prim = new matmul(*matmul_pd);
// Cache primitive_desc and matmul
std::string key = create_key(transA, M, N, K, postAlg, packedB);
std::tuple<dnnl::matmul::primitive_desc *, dnnl::matmul *> value(matmul_pd, matmul_prim);
matmul_hub[key] = value;
cache_matmul_primitive(matmul_pd, matmul_prim, transA, M, N, K, postAlg);
}

// Repack and convert input data.
Expand Down Expand Up @@ -1993,7 +2004,7 @@ class MMHelper {

matmul::primitive_desc *matmul_pd;
matmul *matmul_prim;
std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd, packedB);
std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd);
auto it = matmul_hub.find(key);
if (it != matmul_hub.end()) {
matmul_pd = std::get<0>(it->second);
Expand Down Expand Up @@ -2028,9 +2039,7 @@ class MMHelper {
matmul_prim = new matmul(*matmul_pd);

// Cache primitive_desc and matmul
std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd, packedB);
std::tuple<dnnl::matmul::primitive_desc *, dnnl::matmul *> value(matmul_pd, matmul_prim);
matmul_hub[key] = value;
cache_matmul_primitive(matmul_pd, matmul_prim, transA, M, N, K, matmul_kinds::BiasAdd);
}

// Repack and convert input data.
Expand Down Expand Up @@ -2086,7 +2095,7 @@ class MMHelper {

matmul::primitive_desc *matmul_pd;
matmul *matmul_prim;
std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd_Relu, packedB);
std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd_Relu);
auto it = matmul_hub.find(key);
if (it != matmul_hub.end()) {
matmul_pd = std::get<0>(it->second);
Expand Down Expand Up @@ -2129,9 +2138,7 @@ class MMHelper {
matmul_prim = new matmul(*matmul_pd);

// Cache primitive_desc and matmul
std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd_Relu, packedB);
std::tuple<dnnl::matmul::primitive_desc *, dnnl::matmul *> value(matmul_pd, matmul_prim);
matmul_hub[key] = value;
cache_matmul_primitive(matmul_pd, matmul_prim, transA, M, N, K, matmul_kinds::BiasAdd_Relu);
}

// Repack and convert input data.
Expand Down Expand Up @@ -2187,7 +2194,7 @@ class MMHelper {

matmul::primitive_desc *matmul_pd;
matmul *matmul_prim;
std::string key = create_key(transA, M, N, K, matmul_kinds::Resmul, packedB);
std::string key = create_key(transA, M, N, K, matmul_kinds::Resmul);
auto it = matmul_hub.find(key);
if (it != matmul_hub.end()) {
matmul_pd = std::get<0>(it->second);
Expand Down Expand Up @@ -2232,9 +2239,7 @@ class MMHelper {
matmul_prim = new matmul(*matmul_pd);

// Cache primitive_desc and matmul
std::string key = create_key(transA, M, N, K, matmul_kinds::Resmul, packedB);
std::tuple<dnnl::matmul::primitive_desc *, dnnl::matmul *> value(matmul_pd, matmul_prim);
matmul_hub[key] = value;
cache_matmul_primitive(matmul_pd, matmul_prim, transA, M, N, K, matmul_kinds::Resmul);
}

// Repack and convert input data.
Expand Down Expand Up @@ -2307,7 +2312,7 @@ class MMHelper {

matmul::primitive_desc *matmul_pd;
matmul *matmul_prim;
std::string key = create_key(transA, M, N, K, matmul_kinds::Residential, packedB);
std::string key = create_key(transA, M, N, K, matmul_kinds::Residential);
auto it = matmul_hub.find(key);
if (it != matmul_hub.end()) {
matmul_pd = std::get<0>(it->second);
Expand Down Expand Up @@ -2359,9 +2364,7 @@ class MMHelper {
}

// Cache primitive_desc and matmul
std::string key = create_key(transA, M, N, K, matmul_kinds::Residential, packedB);
std::tuple<dnnl::matmul::primitive_desc *, dnnl::matmul *> value(matmul_pd, matmul_prim);
matmul_hub[key] = value;
cache_matmul_primitive(matmul_pd, matmul_prim, transA, M, N, K, matmul_kinds::Residential);
}

// Repack and convert input data.
Expand Down Expand Up @@ -2425,7 +2428,7 @@ class MMHelper {

matmul::primitive_desc *matmul_pd;
matmul *matmul_prim;
std::string key = create_key(transA, M, N, K, matmul_kinds::Basic, B);
std::string key = create_key(transA, M, N, K, matmul_kinds::Basic);
auto it = matmul_hub.find(key);
if (it != matmul_hub.end()) {
matmul_pd = std::get<0>(it->second);
Expand All @@ -2447,9 +2450,7 @@ class MMHelper {
matmul_prim = new matmul(*matmul_pd);

// Cache primitive_desc and matmul
std::string key = create_key(transA, M, N, K, matmul_kinds::Basic, B);
std::tuple<dnnl::matmul::primitive_desc *, dnnl::matmul *> value(matmul_pd, matmul_prim);
matmul_hub[key] = value;
cache_matmul_primitive(matmul_pd, matmul_prim, transA, M, N, K, matmul_kinds::Basic);
}

auto input_mem = memory(matmul_pd->src_desc(), *engine, const_cast<int8_t *>(A));
Expand Down