Closed
Labels
enhancement (New feature or request)
Description
Views and related functionality behave as expected on CUDA arrays, but cause problems on Flux `OneHotArray`s. Specifically, the indices of the view are not stored in a CUDA array, which leads to scalar indexing issues and an isbits error. The following example gives some insight into the issue, although it doesn't show the errors that eventually crop up:
using Flux, CUDA
data = rand(1:10, 10, 4)
x = data |> gpu
xhot = Flux.onehotbatch(data, 1:10) |> gpu
typeof(view(x, :, [1,3]))
#=
SubArray{Int64, 2, CuArray{Int64, 2, CUDA.Mem.DeviceBuffer},
Tuple{Base.Slice{Base.OneTo{Int64}}, CuArray{Int64, 1, CUDA.Mem.DeviceBuffer}}, false}
=#
typeof(view(xhot, :, :, [1,3]))
#=
SubArray{Bool, 3, Flux.OneHotArray{UInt32, 10, 2, 3, CuArray{UInt32, 2, CUDA.Mem.DeviceBuffer}},
Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Vector{Int64}}, false}
=#
This may be too niche a problem, but it could plausibly arise elsewhere, wherever a structure is not itself a CUDA array but wraps CUDA arrays. It would be nice if this could be detected automatically, so that `view` would move all indices onto CUDA arrays.
The following code gets around the issue, but it doesn't seem ideal:
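# View of the plain CuArray: here the index vector is stored as a CuArray
xv = view(x, :, [1,3])
# Reuse that GPU index vector to view the one-hot array (Base.unsafe_view skips bounds checks)
typeof(Base.unsafe_view(xhot, Base.Slice(Base.OneTo(10)), Base.Slice(Base.OneTo(10)), xv.indices[ndims(x)]))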
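One possible direction for the automatic handling suggested above is a small wrapper that converts integer index vectors with a caller-supplied function before calling `view`. The sketch below uses a hypothetical helper named `view_with_indices`, which is not part of Flux, CUDA.jl, or Base. On the GPU one might pass `CuArray` as the converter, but whether the resulting `SubArray` of a `OneHotArray` then avoids scalar indexing has not been verified here. Note also that it calls `view`, which performs bounds checking, whereas the workaround above uses `Base.unsafe_view`, which does not.

```julia
# Hypothetical helper (not part of Flux, CUDA.jl, or Base): convert every
# CPU `Vector` of integers with `convert_index` before building the view.
# On the GPU, `convert_index` might be `CuArray` (an untested assumption).
function view_with_indices(A::AbstractArray, convert_index, inds...)
    converted = map(inds) do i
        # Only dense CPU integer vectors are converted; colons, ranges
        # and scalar integers are passed through unchanged.
        i isa Vector{<:Integer} ? convert_index(i) : i
    end
    return view(A, converted...)
end

# CPU usage: convert the index vector to Int32 to show that the
# converted index is what ends up stored in the SubArray.
A = reshape(collect(1:12), 3, 4)
v = view_with_indices(A, i -> Int32.(i), :, [1, 3])
```

Ranges are left alone because they are already isbits and cheap to pass to a kernel; only materialized `Vector` indices need to be moved to the device.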