
Views of Flux OneHotArrays #1349

@RS-Coop


Views and related functionality behave as expected on CUDA arrays, but cause problems on Flux OneHotArrays. Specifically, the indices of a view into a GPU-backed OneHotArray are not stored in a CUDA array, which leads to scalar indexing issues and an isbits error. The following example gives some insight into the issue, although it does not reproduce the errors that eventually crop up:

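```julia
using Flux, CUDA

data = rand(1:10, 10, 4)  # 10×4 matrix of labels drawn from 1:10

x = data |> gpu                             # CuArray{Int64, 2}
xhot = Flux.onehotbatch(data, 1:10) |> gpu  # 10×10×4 OneHotArray backed by a CuArray

# View of a CuArray: the index vector [1,3] ends up on the GPU
typeof(view(x, :, [1,3]))
#=
SubArray{Int64, 2, CuArray{Int64, 2, CUDA.Mem.DeviceBuffer},
Tuple{Base.Slice{Base.OneTo{Int64}}, CuArray{Int64, 1, CUDA.Mem.DeviceBuffer}}, false}
=#

# View of the GPU-backed OneHotArray: the index vector stays a CPU Vector{Int64}
typeof(view(xhot, :, :, [1,3]))
#=
SubArray{Bool, 3, Flux.OneHotArray{UInt32, 10, 2, 3, CuArray{UInt32, 2, CUDA.Mem.DeviceBuffer}},
Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Vector{Int64}}, false}
=#
```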
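The difference between the two views can be checked directly by inspecting their indices with `parentindices`. A minimal sketch, where `has_cpu_index` is a hypothetical helper name (not part of Flux or CUDA.jl): it flags any index stored as a plain CPU `Array`. The commented results follow from the `typeof` output shown above.

```julia
using Flux, CUDA

# Hypothetical helper: true if any index of the view is a plain CPU Array
has_cpu_index(v::SubArray) = any(i -> i isa Array, parentindices(v))

data = rand(1:10, 10, 4)
x = data |> gpu
xhot = Flux.onehotbatch(data, 1:10) |> gpu

has_cpu_index(view(x, :, [1,3]))        # false: the index is a CuArray
has_cpu_index(view(xhot, :, :, [1,3]))  # true: the index is a Vector{Int64}
```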

This may be too niche a problem, but it could plausibly show up elsewhere too: anywhere a structure is not a CUDA array itself but stores its data in CUDA arrays. It would be nice if view could detect this automatically and put all indices on CUDA arrays.
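One way the automatic behaviour could look, for this one type, is a `view` method that moves plain CPU index arrays to the GPU before building the `SubArray`. This is a sketch under several assumptions: it is type piracy (it extends `Base.view` for `Flux.OneHotArray` from user code), it assumes the five type parameters seen in the `typeof` output above with the index storage in the last position (this may differ between Flux versions), and it may need adjusting if Flux or CUDA.jl already define more specific `view` methods. Like the `unsafe_view` workaround, it relies on the internal `Base.unsafe_view`.

```julia
using Flux, CUDA

# Sketch only: applies to OneHotArrays whose index storage (last type parameter,
# as in the typeof output above) is a CuArray.
function Base.view(A::Flux.OneHotArray{<:Any, <:Any, <:Any, <:Any, <:CuArray},
                   I::Vararg{Any, N}) where {N}
    J = Base.to_indices(A, I)          # normalize Colon, integers, etc.
    @boundscheck checkbounds(A, J...)  # check while the indices are still on the CPU
    # Move plain CPU index arrays to the GPU; ranges and Slices are left as they are
    J_gpu = map(j -> j isa Array ? CuArray(j) : j, J)
    return Base.unsafe_view(A, J_gpu...)
end
```

With this method defined, `view(xhot, :, :, [1,3])` should produce a `SubArray` whose last index is a `CuArray`, the same structure the `CuArray` view gets.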

The following code gets around the issue, but it doesn't seem ideal:

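```julia
xv = view(x, :, [1,3])  # CuArray view: its last index is already a CuArray
# Build the one-hot view directly, reusing that GPU index (unsafe_view skips bounds checks)
typeof(Base.unsafe_view(xhot, Base.Slice(Base.OneTo(10)), Base.Slice(Base.OneTo(10)), xv.indices[ndims(x)]))
```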
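The workaround above can be generalized so it does not need a matching `CuArray` view to borrow the index from. A sketch only: `lastdim_gpuview` is a hypothetical name, it handles views along the last dimension only, and like the workaround it relies on the internal `Base.unsafe_view` and `Base.Slice`. Bounds are checked up front while the index is still on the CPU, since `unsafe_view` does not check them.

```julia
using Flux, CUDA

# Hypothetical helper generalizing the unsafe_view workaround: view `A` at
# positions `idx` along its last dimension, with `idx` copied to the GPU.
function lastdim_gpuview(A::AbstractArray{T, N}, idx::AbstractVector{<:Integer}) where {T, N}
    # Full slices over every leading dimension, as in the workaround
    slices = ntuple(d -> Base.Slice(axes(A, d)), N - 1)
    checkbounds(A, slices..., idx)  # bounds check on the CPU index
    return Base.unsafe_view(A, slices..., CuArray(collect(idx)))
end

data = rand(1:10, 10, 4)
xhot = Flux.onehotbatch(data, 1:10) |> gpu
v = lastdim_gpuview(xhot, [1, 3])
typeof(parentindices(v)[end])  # should be a CuArray rather than a Vector
```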

Labels: enhancement (New feature or request)
