Skip to content

Commit 9c44a25

Browse files
committed
[rocSPARSE] Update the interface for sparse products
1 parent e61e088 commit 9c44a25

File tree

1 file changed

+46
-42
lines changed

1 file changed

+46
-42
lines changed

src/sparse/interfaces.jl

Lines changed: 46 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -22,48 +22,52 @@ function mm_wrapper(
2222
mm!(transa, transb, alpha, A, B, beta, C, 'O')
2323
end
2424

25-
tag_wrappers = (
26-
(identity, identity),
27-
(T -> :(HermOrSym{T, <:$T}), A -> :(parent($A))))
28-
29-
op_wrappers = (
30-
(identity, T -> 'N', identity),
31-
(T -> :(Transpose{<:T, <:$T}), T -> 'T', A -> :(parent($A))),
32-
(T -> :(Adjoint{<:T, <:$T}), T -> T <: Real ? 'T' : 'C', A -> :(parent($A))))
33-
34-
for (taga, untaga) in tag_wrappers, (wrapa, transa, unwrapa) in op_wrappers
35-
TypeA = wrapa(taga(:(ROCSparseMatrix{T})))
36-
37-
@eval begin
38-
function LinearAlgebra.mul!(
39-
C::ROCVector{T}, A::$TypeA, B::DenseROCVector{T},
40-
alpha::Number, beta::Number,
41-
) where T <: Union{Float16, ComplexF16, BlasFloat}
42-
mv_wrapper($transa(T), alpha, $(untaga(unwrapa(:A))), B, beta, C)
43-
end
44-
45-
function LinearAlgebra.mul!(
46-
C::ROCVector{Complex{T}}, A::$TypeA, B::DenseROCVector{Complex{T}},
47-
alpha::Number, beta::Number,
48-
) where T <: Union{Float16, BlasFloat}
49-
mv_wrapper($transa(T), alpha, $(untaga(unwrapa(:A))), B, beta, C)
50-
end
51-
end
52-
53-
for (tagb, untagb) in tag_wrappers, (wrapb, transb, unwrapb) in op_wrappers
54-
TypeB = wrapb(tagb(:(DenseROCMatrix{T})))
55-
56-
@eval begin
57-
function LinearAlgebra.mul!(
58-
C::ROCMatrix{T}, A::$TypeA, B::$TypeB,
59-
alpha::Number, beta::Number,
60-
) where T <: Union{Float16, ComplexF16, BlasFloat}
61-
mm_wrapper(
62-
$transa(T), $transb(T), alpha,
63-
$(untaga(unwrapa(:A))), $(untagb(unwrapb(:B))), beta, C)
64-
end
65-
end
66-
end
25+
# legacy methods with final MulAddMul argument
26+
LinearAlgebra.generic_matvecmul!(C::ROCVector{T}, tA::AbstractChar, A::ROCSparseMatrix{T}, B::DenseROCVector{T}, _add::MulAddMul) where T <: BlasFloat =
27+
LinearAlgebra.generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
28+
LinearAlgebra.generic_matvecmul!(C::ROCVector{T}, tA::AbstractChar, A::ROCSparseMatrix{T}, B::ROCSparseVector{T}, _add::MulAddMul) where T <: BlasFloat =
29+
LinearAlgebra.generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
30+
LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::ROCSparseMatrix{T}, B::DenseROCMatrix{T}, _add::MulAddMul) where T <: BlasFloat =
31+
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
32+
33+
function LinearAlgebra.generic_matvecmul!(C::ROCVector{T}, tA::AbstractChar, A::ROCSparseMatrix{T}, B::DenseROCVector{T}, alpha::Number, beta::Number) where T <: BlasFloat
34+
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
35+
mv_wrapper(tA, alpha, A, B, beta, C)
36+
end
37+
38+
function LinearAlgebra.generic_matvecmul!(C::ROCVector{T}, tA::AbstractChar, A::ROCSparseMatrix{T}, B::ROCSparseVector{T}, alpha::Number, beta::Number) where T <: BlasFloat
39+
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
40+
mv_wrapper(tA, alpha, A, ROCVector{T}(B), beta, C)
41+
end
42+
43+
function LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::ROCSparseMatrix{T}, B::DenseROCMatrix{T}, alpha::Number, beta::Number) where T <: BlasFloat
44+
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
45+
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
46+
mm_wrapper(tA, tB, alpha, A, B, beta, C)
47+
end
48+
49+
# legacy methods with final MulAddMul argument
50+
LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCSC{T}, _add::MulAddMul) where T <: BlasFloat =
51+
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
52+
LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCSR{T}, _add::MulAddMul) where T <: BlasFloat =
53+
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
54+
LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCOO{T}, _add::MulAddMul) where T <: BlasFloat =
55+
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
56+
57+
function LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCSC{T}, alpha::Number, beta::Number) where T <: BlasFloat
58+
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
59+
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
60+
mm!(tA, tB, alpha, A, B, beta, C, 'O')
61+
end
62+
function LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCSR{T}, alpha::Number, beta::Number) where T <: BlasFloat
63+
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
64+
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
65+
mm!(tA, tB, alpha, A, B, beta, C, 'O')
66+
end
67+
function LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCOO{T}, alpha::Number, beta::Number) where T <: BlasFloat
68+
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
69+
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
70+
mm!(tA, tB, alpha, A, B, beta, C, 'O')
6771
end
6872

6973
Base.:(+)(A::ROCSparseMatrixCSR, B::ROCSparseMatrixCSR) = geam(one(eltype(A)), A, one(eltype(A)), B, 'O')

0 commit comments

Comments
 (0)