diff --git a/src/onehot.jl b/src/onehot.jl index c9f0b145a0..5bdf53c5f4 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -9,6 +9,8 @@ OneHotArray(indices::T, L::Integer) where {T<:Integer} = OneHotArray{T, L, 0, T} OneHotArray(indices::AbstractArray{T, N}, L::Integer) where {T, N} = OneHotArray{T, L, N, typeof(indices)}(indices) _indices(x::OneHotArray) = x.indices +_indices(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray}) = + reshape(parent(x).indices, x.dims[2:end]) const OneHotVector{T, L} = OneHotArray{T, L, 0, 1, T} const OneHotMatrix{T, L, I} = OneHotArray{T, L, 1, 2, I} @@ -16,6 +18,15 @@ const OneHotMatrix{T, L, I} = OneHotArray{T, L, 1, 2, I} OneHotVector(idx, L) = OneHotArray(idx, L) OneHotMatrix(indices, L) = OneHotArray(indices, L) +# use this type so reshaped arrays hit fast paths +# e.g. argmax +const OneHotLike{T, L, N, var"N+1", I} = + Union{OneHotArray{T, L, N, var"N+1", I}, + Base.ReshapedArray{Bool, var"N+1", <:OneHotArray{T, L, <:Any, <:Any, I}}} + +_isonehot(x::OneHotArray) = true +_isonehot(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray{<:Any, L}}) where L = (size(x, 1) == L) + Base.size(x::OneHotArray{<:Any, L}) where L = (Int(L), size(x.indices)...) _onehotindex(x, i) = (x == i) @@ -28,34 +39,30 @@ Base.getindex(x::OneHotArray{<:Any, L}, ::Colon, I...) where L = OneHotArray(x.i Base.getindex(x::OneHotArray{<:Any, <:Any, <:Any, N}, ::Vararg{Colon, N}) where N = x Base.getindex(x::OneHotArray, I::CartesianIndex{N}) where N = x[I[1], Tuple(I)[2:N]...] -_onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N} -_onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N} +_onehot_bool_type(x::OneHotLike{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N} +_onehot_bool_type(x::OneHotLike{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N} -function Base.cat(xs::OneHotArray{<:Any, L}...; dims::Int) where L - if isone(dims) - return throw(ArgumentError("Cannot concat OneHotArray along first dimension. Use collect to convert to Bool array first.")) +function Base.cat(xs::OneHotLike{<:Any, L}...; dims::Int) where L + if isone(dims) || any(x -> !_isonehot(x), xs) + return cat(map(x -> convert(_onehot_bool_type(x), x), xs)...; dims = dims) else return OneHotArray(cat(_indices.(xs)...; dims = dims - 1), L) end end -Base.hcat(xs::OneHotArray...) = cat(xs...; dims = 2) -Base.vcat(xs::OneHotArray...) = cat(xs...; dims = 1) - -Base.reshape(x::OneHotArray{<:Any, L}, dims::Dims) where L = - (first(dims) == L) ? OneHotArray(reshape(x.indices, dims[2:end]...), L) : - throw(ArgumentError("Cannot reshape OneHotArray if first(dims) != size(x, 1)")) -Base._reshape(x::OneHotArray, dims::Tuple{Vararg{Int}}) = reshape(x, dims) +Base.hcat(xs::OneHotLike...) = cat(xs...; dims = 2) +Base.vcat(xs::OneHotLike...) = cat(xs...; dims = 1) batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotArray(_indices.(xs), L) -Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, x.indices), L) +Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, _indices(x)), L) Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}}) where N = CUDA.CuArrayStyle{N}() -Base.argmax(x::OneHotArray; dims = Colon()) = - (dims == 1) ? reshape(CartesianIndex.(x.indices, CartesianIndices(x.indices)), 1, size(x.indices)...) : - argmax(convert(_onehot_bool_type(x), x); dims = dims) +Base.argmax(x::OneHotLike; dims = Colon()) = + (_isonehot(x) && dims == 1) ? + reshape(CartesianIndex.(_indices(x), CartesianIndices(_indices(x))), 1, size(_indices(x))...) : + invoke(argmax, Tuple{AbstractArray}, x; dims = dims) """ onehot(l, labels[, unk]) @@ -135,11 +142,18 @@ function onecold(y::AbstractArray, labels = 1:size(y, 1)) end _fast_argmax(x::AbstractArray) = dropdims(argmax(x; dims = 1); dims = 1) -_fast_argmax(x::OneHotArray) = x.indices +function _fast_argmax(x::OneHotLike) + if _isonehot(x) + return _indices(x) + else + return _fast_argmax(convert(_onehot_bool_type(x), x)) + end +end @nograd OneHotArray, onecold, onehot, onehotbatch -function Base.:(*)(A::AbstractMatrix, B::OneHotArray{<:Any, L}) where L +function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L}) where L + _isonehot(B) || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B) size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L")) return A[:, onecold(B)] end diff --git a/test/onehot.jl b/test/onehot.jl index 9461bc816d..ce30534ec9 100644 --- a/test/onehot.jl +++ b/test/onehot.jl @@ -35,7 +35,7 @@ end end @testset "OneHotArray" begin - using Flux: OneHotArray, OneHotVector, OneHotMatrix + using Flux: OneHotArray, OneHotVector, OneHotMatrix, OneHotLike ov = OneHotVector(rand(1:10), 10) om = OneHotMatrix(rand(1:10, 5), 10) @@ -74,27 +74,39 @@ end @testset "Concatenating" begin # vector cat @test hcat(ov, ov) == OneHotMatrix(vcat(ov.indices, ov.indices), 10) - @test_throws ArgumentError vcat(ov, ov) + @test vcat(ov, ov) == vcat(collect(ov), collect(ov)) @test cat(ov, ov; dims = 3) == OneHotArray(cat(ov.indices, ov.indices; dims = 2), 10) # matrix cat @test hcat(om, om) == OneHotMatrix(vcat(om.indices, om.indices), 10) - @test_throws ArgumentError vcat(om, om) + @test vcat(om, om) == vcat(collect(om), collect(om)) @test cat(om, om; dims = 3) == OneHotArray(cat(om.indices, om.indices; dims = 2), 10) # array cat @test cat(oa, oa; dims = 3) == OneHotArray(cat(oa.indices, oa.indices; dims = 2), 10) - @test_throws ArgumentError cat(oa, oa; dims = 1) + @test cat(oa, oa; dims = 1) == cat(collect(oa), collect(oa); dims = 1) end @testset "Base.reshape" begin # reshape test - @test reshape(oa, 10, 25) isa OneHotArray - @test reshape(oa, 10, :) isa OneHotArray - @test reshape(oa, :, 25) isa OneHotArray - @test_throws ArgumentError reshape(oa, 50, :) - @test_throws ArgumentError reshape(oa, 5, 10, 5) - @test reshape(oa, (10, 25)) isa OneHotArray + @test reshape(oa, 10, 25) isa OneHotLike + @test reshape(oa, 10, :) isa OneHotLike + @test reshape(oa, :, 25) isa OneHotLike + @test reshape(oa, 50, :) isa OneHotLike + @test reshape(oa, 5, 10, 5) isa OneHotLike + @test reshape(oa, (10, 25)) isa OneHotLike + + @testset "w/ cat" begin + r = reshape(oa, 10, :) + @test hcat(r, r) isa OneHotArray + @test vcat(r, r) isa Array{Bool} + end + + @testset "w/ argmax" begin + r = reshape(oa, 10, :) + @test argmax(r) == argmax(OneHotMatrix(reshape(oa.indices, :), 10)) + @test Flux._fast_argmax(r) == collect(reshape(oa.indices, :)) + end end @testset "Base.argmax" begin