Skip to content

Commit 18deacd

Browse files
Michael Abbott
authored and committed
multi-thread loop + single-thread BLAS
1 parent d8c1761 commit 18deacd

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,13 +3,15 @@ uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
33
version = "0.7.5"
44

55
[deps]
6+
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
67
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
78
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
89
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
910
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1011
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1112

1213
[compat]
14+
Compat = "3.13"
1315
Requires = "0.5, 1.0"
1416
julia = "1.3"
1517

src/gemm.jl

Lines changed: 15 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,8 @@
44
using LinearAlgebra
55
using LinearAlgebra.BLAS: libblas, BlasInt, @blasfunc
66

7+
using Compat: get_num_threads, set_num_threads
8+
79
"""
810
gemm!()
911
@@ -89,22 +91,28 @@ for (gemm, elt) in gemm_datatype_mappings
8991
strB = size(B, 3) == 1 ? 0 : Base.stride(B, 3)
9092
strC = Base.stride(C, 3)
9193

92-
for k in 1:size(A, 3)
94+
old_threads = get_num_threads()
95+
set_num_threads(1)
96+
97+
Threads.@threads for k in 1:size(C, 3)
98+
99+
ptrAk = ptrA + (k-1) * strA * sizeof($elt)
100+
ptrBk = ptrB + (k-1) * strB * sizeof($elt)
101+
ptrCk = ptrC + (k-1) * strC * sizeof($elt)
102+
93103
ccall((@blasfunc($(gemm)), libblas), Nothing,
94104
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
95105
Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, Ref{BlasInt},
96106
Ptr{$elt}, Ref{BlasInt}, Ref{$elt}, Ptr{$elt},
97107
Ref{BlasInt}),
98108
transA, transB, m, n,
99-
ka, alpha, ptrA, max(1,Base.stride(A,2)),
100-
ptrB, max(1,Base.stride(B,2)), beta, ptrC,
109+
ka, alpha, ptrAk, max(1,Base.stride(A,2)),
110+
ptrBk, max(1,Base.stride(B,2)), beta, ptrCk,
101111
max(1,Base.stride(C,2)))
102-
103-
ptrA += strA * sizeof($elt)
104-
ptrB += strB * sizeof($elt)
105-
ptrC += Base.stride(C, 3) * sizeof($elt)
106112
end
107113

114+
set_num_threads(old_threads)
115+
108116
C
109117
end
110118
end

0 commit comments

Comments
 (0)