Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ArrayLayouts"
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
authors = ["Sheehan Olver <[email protected]>"]
version = "1.4"
version = "1.4.1"

[deps]
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Expand Down
26 changes: 24 additions & 2 deletions src/muladd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,23 @@ function default_blasmul!(α, A::AbstractMatrix, B::AbstractMatrix, β, C::Abstr
C
end

function default_blasmul!(α, A::AbstractVector, B::AbstractMatrix, β, C::AbstractMatrix)
mA, = size(A)
mB, nB = size(B)
1 == mB || throw(DimensionMismatch("Dimensions must match"))
size(C) == (mA, nB) || throw(DimensionMismatch("Dimensions must match"))

lmul!(β, C)

(iszero(mA) || iszero(nB)) && return C

for k in colsupport(A), j in rowsupport(B)
_default_blasmul_loop!(α, A, B, β, C, k, j)
end
C
end


function _default_blasmul!(::IndexLinear, α, A::AbstractMatrix, B::AbstractVector, β, C::AbstractVector)
mA, nA = size(A)
mB = length(B)
Expand Down Expand Up @@ -266,6 +283,11 @@ function materialize!(M::MatMulVecAdd)
default_blasmul!(α, unalias(C,A), unalias(C,B), iszero(β) ? false : β, C)
end

function materialize!(M::VecMulMatAdd)
α, A, B, β, C = M.α, M.A, M.B, M.β, M.C
default_blasmul!(α, unalias(C,A), unalias(C,B), iszero(β) ? false : β, C)
end

@inline _gemv!(tA, α, A, x, β, y) = BLAS.gemv!(tA, α, unalias(y,A), unalias(y,x), β, y)
@inline _gemm!(tA, tB, α, A, B, β, C) = BLAS.gemm!(tA, tB, α, unalias(C,A), unalias(C,B), β, C)

Expand Down Expand Up @@ -424,8 +446,7 @@ function similar(M::MulAdd{<:DualLayout,<:Any,ZerosLayout}, ::Type{T}, (x,y)) wh
trans(similar(trans(M.A), T, y))
end

function similar(M::MulAdd{<:Any,<:DualLayout,ZerosLayout}, ::Type{T}, (x,y)) where T
@assert length(x) == 1
function similar(M::MulAdd{ScalarLayout,<:DualLayout,ZerosLayout}, ::Type{T}, (x,y)) where T
trans = transtype(M.B)
trans(similar(trans(M.B), T, y))
end
Expand All @@ -434,3 +455,4 @@ const ZerosLayouts = Union{ZerosLayout,DualLayout{ZerosLayout}}
copy(M::MulAdd{<:ZerosLayouts, <:ZerosLayouts, <:ZerosLayouts}) = M.C
copy(M::MulAdd{<:ZerosLayouts, <:Any, <:ZerosLayouts}) = M.C
copy(M::MulAdd{<:Any, <:ZerosLayouts, <:ZerosLayouts}) = M.C

4 changes: 4 additions & 0 deletions test/test_muladd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -731,4 +731,8 @@ Random.seed!(0)
Y = randn(rng, 8, 2)
@test mul(Y',X) ≈ Y'X
end

@testset "Vec * Adj" begin
@test ArrayLayouts.mul(1:5, (1:4)') == (1:5) * (1:4)'
end
end