Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 31 additions & 18 deletions src/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,24 @@ OneHotArray(indices::T, L::Integer) where {T<:Integer} = OneHotArray{T, L, 0, T}
OneHotArray(indices::AbstractArray{T, N}, L::Integer) where {T, N} = OneHotArray{T, L, N, typeof(indices)}(indices)

# Index storage backing a one-hot array.
# A plain OneHotArray exposes its `indices` field directly.
_indices(x::OneHotArray) = x.indices
# A reshaped OneHotArray reuses the parent's indices, reshaped to every
# dimension of the wrapper after the first one.
function _indices(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray})
  trailing_dims = x.dims[2:end]
  return reshape(parent(x).indices, trailing_dims)
end

const OneHotVector{T, L} = OneHotArray{T, L, 0, 1, T}
const OneHotMatrix{T, L, I} = OneHotArray{T, L, 1, 2, I}

OneHotVector(idx, L) = OneHotArray(idx, L)
OneHotMatrix(indices, L) = OneHotArray(indices, L)

# use this type so reshaped arrays hit fast paths
# e.g. argmax
const OneHotLike{T, L, N, var"N+1", I} =
Union{OneHotArray{T, L, N, var"N+1", I},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Man, this N+1 is tripping me up; I would say we need to remove it soon. Where exactly is it used?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think we could calculate var"N+1" during runtime?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like it either! It can't be done at runtime since N and var"N+1" are used in the type specification. N is used to specify the type of the index array, and var"N+1" is used to inherit from AbstractArray{Bool, var"N+1"}. Neither is evaluated at runtime.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could change it to another variable. I don't have strong feelings, but a part of me says that at least this naming signals the intent of the type parameter.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be fair, I did mean that we would have to switch it out during construction, because I don't think it is any better for dispatch to do checks on ints than on types. To me, this suggests that it is a quantity known ahead of time, so adding it to the type doesn't win us much.

Base.ReshapedArray{Bool, var"N+1", <:OneHotArray{T, L, <:Any, <:Any, I}}}

# Whether `x` still has the one-hot layout.
# A genuine OneHotArray always does.
_isonehot(::OneHotArray) = true
# A reshaped OneHotArray only does when its first dimension is still `L`,
# the label count carried in the parent's type.
function _isonehot(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray{<:Any, L}}) where L
  return size(x, 1) == L
end

Base.size(x::OneHotArray{<:Any, L}) where L = (Int(L), size(x.indices)...)

_onehotindex(x, i) = (x == i)
Expand All @@ -28,34 +39,30 @@ Base.getindex(x::OneHotArray{<:Any, L}, ::Colon, I...) where L = OneHotArray(x.i
Base.getindex(x::OneHotArray{<:Any, <:Any, <:Any, N}, ::Vararg{Colon, N}) where N = x
Base.getindex(x::OneHotArray, I::CartesianIndex{N}) where N = x[I[1], Tuple(I)[2:N]...]

_onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N}
_onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N}
# Dense Bool array type that a one-hot-like array of rank `N` is converted to
# when a fast path does not apply: `Array{Bool, N}` when the index storage is an
# `Integer` or a generic `AbstractArray`, `CuArray{Bool, N}` when it is a `CuArray`.
_onehot_bool_type(x::OneHotLike{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N}
_onehot_bool_type(x::OneHotLike{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N}

function Base.cat(xs::OneHotArray{<:Any, L}...; dims::Int) where L
if isone(dims)
return throw(ArgumentError("Cannot concat OneHotArray along first dimension. Use collect to convert to Bool array first."))
# Concatenate one-hot-like arrays along `dims`.
#
# When every input still has the one-hot layout and `dims` is not the first
# (label) dimension, the index arrays are concatenated along `dims - 1` and the
# result is re-wrapped as a OneHotArray. Otherwise each input is converted to
# its dense Bool array type and concatenated generically.
function Base.cat(xs::OneHotLike{<:Any, L}...; dims::Int) where L
  if !isone(dims) && all(_isonehot, xs)
    return OneHotArray(cat(map(_indices, xs)...; dims = dims - 1), L)
  end
  dense = map(x -> convert(_onehot_bool_type(x), x), xs)
  return cat(dense...; dims = dims)
end

Base.hcat(xs::OneHotArray...) = cat(xs...; dims = 2)
Base.vcat(xs::OneHotArray...) = cat(xs...; dims = 1)

Base.reshape(x::OneHotArray{<:Any, L}, dims::Dims) where L =
(first(dims) == L) ? OneHotArray(reshape(x.indices, dims[2:end]...), L) :
throw(ArgumentError("Cannot reshape OneHotArray if first(dims) != size(x, 1)"))
Base._reshape(x::OneHotArray, dims::Tuple{Vararg{Int}}) = reshape(x, dims)
# hcat/vcat on one-hot-like arrays forward to `cat` with dims = 2 and dims = 1
# respectively, so they share its handling of one-hot and non-one-hot inputs.
Base.hcat(xs::OneHotLike...) = cat(xs...; dims = 2)
Base.vcat(xs::OneHotLike...) = cat(xs...; dims = 1)

batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotArray(_indices.(xs), L)

Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, x.indices), L)
Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, _indices(x)), L)

Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}}) where N = CUDA.CuArrayStyle{N}()

Base.argmax(x::OneHotArray; dims = Colon()) =
(dims == 1) ? reshape(CartesianIndex.(x.indices, CartesianIndices(x.indices)), 1, size(x.indices)...) :
argmax(convert(_onehot_bool_type(x), x); dims = dims)
# argmax for one-hot-like arrays.
#
# For a one-hot layout reduced over `dims == 1`, the answer is built directly
# from the stored indices: each stored index is paired with its position in the
# index array and the result gets a leading singleton dimension. Any other case
# converts to the dense Bool array type and uses the generic argmax.
function Base.argmax(x::OneHotLike; dims = Colon())
  if _isonehot(x) && dims == 1
    idx = _indices(x)
    return reshape(CartesianIndex.(idx, CartesianIndices(idx)), 1, size(idx)...)
  end
  return argmax(convert(_onehot_bool_type(x), x); dims = dims)
end

"""
onehot(l, labels[, unk])
Expand Down Expand Up @@ -135,11 +142,17 @@ function onecold(y::AbstractArray, labels = 1:size(y, 1))
end

_fast_argmax(x::AbstractArray) = dropdims(argmax(x; dims = 1); dims = 1)
_fast_argmax(x::OneHotArray) = x.indices
# Fast argmax over the first dimension for one-hot-like arrays.
# With the one-hot layout intact the stored indices are already the answer;
# otherwise convert to the dense Bool array type and use the generic method.
function _fast_argmax(x::OneHotLike)
  _isonehot(x) && return _indices(x)
  dense = convert(_onehot_bool_type(x), x)
  return _fast_argmax(dense)
end

@nograd OneHotArray, onecold, onehot, onehotbatch

function Base.:(*)(A::AbstractMatrix, B::OneHotArray{<:Any, L}) where L
# Multiply a matrix `A` by a one-hot-like array `B`.
#
# When `B` still has the one-hot layout (see `_isonehot`), the product is a
# column gather: `onecold(B)` gives, per one-hot column, the column of `A` to
# select. `size(A, 2)` must equal the label count `L`, otherwise a
# `DimensionMismatch` is thrown.
#
# A reshaped `B` whose first dimension is no longer `L` is not one-hot, so
# gathering by `onecold(B)` would return wrong values or index out of bounds.
# That case is converted to its dense Bool array type and handed to the
# generic `*`, mirroring the fallback used by `cat` and `argmax`.
function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L}) where L
  _isonehot(B) || return A * convert(_onehot_bool_type(B), B)
  size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
  return A[:, onecold(B)]
end
32 changes: 22 additions & 10 deletions test/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ end
end

@testset "OneHotArray" begin
using Flux: OneHotArray, OneHotVector, OneHotMatrix
using Flux: OneHotArray, OneHotVector, OneHotMatrix, OneHotLike

ov = OneHotVector(rand(1:10), 10)
om = OneHotMatrix(rand(1:10, 5), 10)
Expand Down Expand Up @@ -74,27 +74,39 @@ end
@testset "Concatenating" begin
# vector cat
@test hcat(ov, ov) == OneHotMatrix(vcat(ov.indices, ov.indices), 10)
@test_throws ArgumentError vcat(ov, ov)
@test vcat(ov, ov) == vcat(collect(ov), collect(ov))
@test cat(ov, ov; dims = 3) == OneHotArray(cat(ov.indices, ov.indices; dims = 2), 10)

# matrix cat
@test hcat(om, om) == OneHotMatrix(vcat(om.indices, om.indices), 10)
@test_throws ArgumentError vcat(om, om)
@test vcat(om, om) == vcat(collect(om), collect(om))
@test cat(om, om; dims = 3) == OneHotArray(cat(om.indices, om.indices; dims = 2), 10)

# array cat
@test cat(oa, oa; dims = 3) == OneHotArray(cat(oa.indices, oa.indices; dims = 2), 10)
@test_throws ArgumentError cat(oa, oa; dims = 1)
@test cat(oa, oa; dims = 1) == cat(collect(oa), collect(oa); dims = 1)
end

@testset "Base.reshape" begin
# reshape test
@test reshape(oa, 10, 25) isa OneHotArray
@test reshape(oa, 10, :) isa OneHotArray
@test reshape(oa, :, 25) isa OneHotArray
@test_throws ArgumentError reshape(oa, 50, :)
@test_throws ArgumentError reshape(oa, 5, 10, 5)
@test reshape(oa, (10, 25)) isa OneHotArray
@test reshape(oa, 10, 25) isa OneHotLike
@test reshape(oa, 10, :) isa OneHotLike
@test reshape(oa, :, 25) isa OneHotLike
@test reshape(oa, 50, :) isa OneHotLike
@test reshape(oa, 5, 10, 5) isa OneHotLike
@test reshape(oa, (10, 25)) isa OneHotLike

# Concatenating a reshaped one-hot array: hcat is expected to stay a
# OneHotArray, while vcat (along the first dimension) is expected to give a
# dense Bool array.
# NOTE(review): assumes `oa` (defined outside this view) has first dimension 10,
# so `reshape(oa, 10, :)` keeps the one-hot layout; confirm against its definition.
@testset "w/ cat" begin
r = reshape(oa, 10, :)
@test hcat(r, r) isa OneHotArray
@test vcat(r, r) isa Array{Bool}
end

# argmax and the internal `_fast_argmax` on a reshaped one-hot array must agree
# with the results obtained from the correspondingly flattened indices.
# NOTE(review): assumes `oa` (defined outside this view) has first dimension 10; confirm.
@testset "w/ argmax" begin
r = reshape(oa, 10, :)
@test argmax(r) == argmax(OneHotMatrix(reshape(oa.indices, :), 10))
@test Flux._fast_argmax(r) == collect(reshape(oa.indices, :))
end
end

@testset "Base.argmax" begin
Expand Down