|
4 | 4 | using LinearAlgebra |
5 | 5 | using LinearAlgebra.BLAS: libblas, BlasInt, @blasfunc |
6 | 6 |
|
| 7 | +using Compat: get_num_threads, set_num_threads |
| 8 | + |
7 | 9 | """ |
8 | 10 | gemm!() |
9 | 11 |
|
@@ -89,22 +91,28 @@ for (gemm, elt) in gemm_datatype_mappings |
89 | 91 | strB = size(B, 3) == 1 ? 0 : Base.stride(B, 3) |
90 | 92 | strC = Base.stride(C, 3) |
91 | 93 |
|
92 | | - for k in 1:size(A, 3) |
| 94 | + old_threads = get_num_threads() |
| 95 | + set_num_threads(1) |
| 96 | + |
| 97 | + Threads.@threads for k in 1:size(C, 3) |
| 98 | + |
| 99 | + ptrAk = ptrA + (k-1) * strA * sizeof($elt) |
| 100 | + ptrBk = ptrB + (k-1) * strB * sizeof($elt) |
| 101 | + ptrCk = ptrC + (k-1) * strC * sizeof($elt) |
| 102 | + |
93 | 103 | ccall((@blasfunc($(gemm)), libblas), Nothing, |
94 | 104 | (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, |
95 | 105 | Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, Ref{BlasInt}, |
96 | 106 | Ptr{$elt}, Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, |
97 | 107 | Ref{BlasInt}), |
98 | 108 | transA, transB, m, n, |
99 | | - ka, alpha, ptrA, max(1,Base.stride(A,2)), |
100 | | - ptrB, max(1,Base.stride(B,2)), beta, ptrC, |
| 109 | + ka, alpha, ptrAk, max(1,Base.stride(A,2)), |
| 110 | + ptrBk, max(1,Base.stride(B,2)), beta, ptrCk, |
101 | 111 | max(1,Base.stride(C,2))) |
102 | | - |
103 | | - ptrA += strA * sizeof($elt) |
104 | | - ptrB += strB * sizeof($elt) |
105 | | - ptrC += Base.stride(C, 3) * sizeof($elt) |
106 | 112 | end |
107 | 113 |
|
| 114 | + set_num_threads(old_threads) |
| 115 | + |
108 | 116 | C |
109 | 117 | end |
110 | 118 | end |
|
0 commit comments