Skip to content
This repository was archived by the owner on Aug 11, 2020. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mshadow/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ extern "C" {
#include <mkl_cblas.h>
#include <mkl_vsl.h>
#include <mkl_vsl_functions.h>
#include <mkl_version.h>
#endif

#if MSHADOW_USE_CUDA
Expand Down
75 changes: 75 additions & 0 deletions mshadow/dot_engine-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#ifndef MSHADOW_DOT_ENGINE_INL_H_
#define MSHADOW_DOT_ENGINE_INL_H_

#include <vector>
#include "./base.h"
#include "./extension/implicit_gemm.h"

Expand Down Expand Up @@ -291,11 +292,48 @@ struct BLASEngine<cpu, float> {
const float *A, int lda, const float *B, int ldb,
float beta, float *C, int ldc, int batch_count,
float **workspace) {
#if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000)
  // Grouped batch GEMM path. Every matrix in the batch shares the same shape,
  // leading dimensions, transpose flags and scalars, so the whole batch is
  // submitted as ONE group: each per-group parameter is a single value, and
  // only the per-matrix pointer arrays scale with batch_count.
  //
  // The batch_count guard keeps a zero or negative count a no-op, which is
  // what the plain loop in the #else branch does as well.
  if (batch_count > 0) {
    // MKL_INT (not int) so the call also matches the ILP64 interface.
    const MKL_INT group_count = 1;
    const MKL_INT group_size = batch_count;
    const MKL_INT group_m = m;
    const MKL_INT group_n = n;
    const MKL_INT group_k = k;
    const MKL_INT group_lda = lda;
    const MKL_INT group_ldb = ldb;
    const MKL_INT group_ldc = ldc;
    const float group_alpha = alpha;
    const float group_beta = beta;
    const CBLAS_TRANSPOSE group_transa = GetT(transa);
    const CBLAS_TRANSPOSE group_transb = GetT(transb);

    // Element strides between consecutive matrices of the batch; these are
    // the same offsets the #else branch uses.
    const auto stride_a = m * k;
    const auto stride_b = k * n;
    const auto stride_c = m * n;

    // Pointer tables are sized up front so filling them never reallocates.
    std::vector<const float*> pp_A(batch_count);
    std::vector<const float*> pp_B(batch_count);
    std::vector<float*> pp_C(batch_count);
    for (int i = 0; i < batch_count; ++i) {
      pp_A[i] = A + i * stride_a;
      pp_B[i] = B + i * stride_b;
      pp_C[i] = C + i * stride_c;
    }

    cblas_sgemm_batch(CblasColMajor, &group_transa, &group_transb,
                      &group_m, &group_n, &group_k,
                      &group_alpha, pp_A.data(), &group_lda, pp_B.data(),
                      &group_ldb, &group_beta, pp_C.data(), &group_ldc,
                      group_count, &group_size);
  }
#else
  // Fallback: issue one ordinary gemm per matrix in the batch.
  for (int i = 0; i < batch_count; ++i) {
    gemm(stream, transa, transb, m, n, k, alpha,
         A + i * m * k, lda, B + i * k * n, ldb,
         beta, C + i * m * n, ldc);
  }
#endif
}
inline static void gemv(Stream<cpu> *stream,
bool trans, int m, int n,
Expand Down Expand Up @@ -361,11 +399,48 @@ struct BLASEngine<cpu, double> {
const double *A, int lda, const double *B, int ldb,
double beta, double *C, int ldc, int batch_count,
double **workspace) {
#if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000)
  // Grouped batch GEMM path. Every matrix in the batch shares the same shape,
  // leading dimensions, transpose flags and scalars, so the whole batch is
  // submitted as ONE group: each per-group parameter is a single value, and
  // only the per-matrix pointer arrays scale with batch_count.
  //
  // The batch_count guard keeps a zero or negative count a no-op, which is
  // what the plain loop in the #else branch does as well.
  if (batch_count > 0) {
    // MKL_INT (not int) so the call also matches the ILP64 interface.
    const MKL_INT group_count = 1;
    const MKL_INT group_size = batch_count;
    const MKL_INT group_m = m;
    const MKL_INT group_n = n;
    const MKL_INT group_k = k;
    const MKL_INT group_lda = lda;
    const MKL_INT group_ldb = ldb;
    const MKL_INT group_ldc = ldc;
    const double group_alpha = alpha;
    const double group_beta = beta;
    const CBLAS_TRANSPOSE group_transa = GetT(transa);
    const CBLAS_TRANSPOSE group_transb = GetT(transb);

    // Element strides between consecutive matrices of the batch; these are
    // the same offsets the #else branch uses.
    const auto stride_a = m * k;
    const auto stride_b = k * n;
    const auto stride_c = m * n;

    // Pointer tables are sized up front so filling them never reallocates.
    std::vector<const double*> pp_A(batch_count);
    std::vector<const double*> pp_B(batch_count);
    std::vector<double*> pp_C(batch_count);
    for (int i = 0; i < batch_count; ++i) {
      pp_A[i] = A + i * stride_a;
      pp_B[i] = B + i * stride_b;
      pp_C[i] = C + i * stride_c;
    }

    cblas_dgemm_batch(CblasColMajor, &group_transa, &group_transb,
                      &group_m, &group_n, &group_k,
                      &group_alpha, pp_A.data(), &group_lda, pp_B.data(),
                      &group_ldb, &group_beta, pp_C.data(), &group_ldc,
                      group_count, &group_size);
  }
#else
  // Fallback: issue one ordinary gemm per matrix in the batch.
  for (int i = 0; i < batch_count; ++i) {
    gemm(stream, transa, transb, m, n, k, alpha,
         A + i * m * k, lda, B + i * k * n, ldb,
         beta, C + i * m * n, ldc);
  }
#endif
}
inline static void gemv(Stream<cpu> *stream,
bool trans, int m, int n, double alpha,
Expand Down