Conversation

@p-zubieta
Owner

p-zubieta commented Feb 5, 2022

Fixes #10 (Exposing dlpack interface for Julia Array/CuArray).

TODO

  • Add CUDA support
  • Add PythonCall support
  • Add tests
  • Remove DLArray

p-zubieta marked this pull request as draft February 5, 2022 01:47
p-zubieta force-pushed the export branch 2 times, most recently from 608fb29 to 4c4b3e4, February 5, 2022 03:23
p-zubieta changed the title from "Add a to_dlpack interface" to "Add a share interface" Feb 7, 2022
@rejuvyesh
Contributor

signal (11): Segmentation fault
in expression starting at REPL[1]:1
unknown function (ip: 0x7f31c58049d0)
Allocations: 107518953 (Pool: 107475126; Big: 43827); GC: 77

Just so you're aware, I'm seeing intermittent segfaults like the one above with this branch.

@p-zubieta
Owner Author

Do you have a reproducible example?

@rejuvyesh
Contributor

rejuvyesh commented Feb 7, 2022

Unfortunately, it's just intermittent. Here is a run with https://gist.github.com/rejuvyesh/0c0995ac81d8c75efada7797a292f611:

julia> include("test/stresstest_dlpack.jl")
[ Info: Precompiling Zygote [e88e6eb3-aa80-5325-afca-941959d7151f]
┌ Warning: `vendor()` is deprecated, use `BLAS.get_config()` and inspect the output instead
│   caller = npyinitialize() at numpy.jl:67
└ @ PyCall ~/.julia/packages/PyCall/L0fLP/src/numpy.jl:67
Test Passed
  Expression: size(grad.params[2]) == size(modelwrap.params[2])
   Evaluated: (2,) == (2,)

julia> include("test/stresstest_dlpack.jl")
Test Passed
  Expression: size(grad.params[2]) == size(modelwrap.params[2])
   Evaluated: (2,) == (2,)

julia> include("test/stresstest_dlpack.jl")
Test Passed
  Expression: size(grad.params[2]) == size(modelwrap.params[2])
   Evaluated: (2,) == (2,)

julia> include("test/stresstest_dlpack.jl")

signal (11): Segmentation fault
in expression starting at /home/jagupt/.julia/dev/PyCallChainRules/test/stresstest_dlpack.jl:102
unknown function (ip: 0x7f7076f69e30)
Allocations: 62477905 (Pool: 62457694; Big: 20211); GC: 69

It's still quite a big example and I'm trying to reduce it to something smaller, but I'm posting it in case it's useful by itself.

@rejuvyesh
Contributor

rejuvyesh commented Feb 7, 2022

Also, I haven't been able to reproduce the segfault with jax's dlpack.

EDIT: I have been able to get segfaults on jax when used with CUDA. See https://gist.github.com/rejuvyesh/be230c57faa1bffeffc57f6d4f9a9514

@p-zubieta
Owner Author

Can you try again @rejuvyesh?

@rejuvyesh
Contributor

Yep, this fixes the issues! Amazing sleuthing!

@p-zubieta
Owner Author

To support both PyCall and PythonCall, I need to change the signature of share to something like share(::StridedArray, ::Type{PyObject}, from_dlpack) and share(::StridedArray, ::Type{Py}, from_dlpack) (the shorter forms share(::StridedArray, from_dlpack::PyObject) and share(::StridedArray, from_dlpack::Py) would still work, though).
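
A minimal sketch of how the two proposed forms would be called on the PyCall side (assuming jax.dlpack as in the snippets below; this just mirrors the signatures above, not final code):

using DLPack
using PyCall

dlpack = pyimport("jax.dlpack")

v = rand(Float32, 4)

# Explicit target type, as proposed above:
py_v = DLPack.share(v, PyObject, dlpack.from_dlpack)

# Shorter form, dispatching on the PyObject-typed from_dlpack:
py_w = DLPack.share(v, dlpack.from_dlpack)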

@rejuvyesh
Contributor

I think that for proper CUDA support of DLArrays (with CUDA.allowscalar(false)), we will need to hook into Adapt.adapt so that DLArray operations dispatch to the CUDA implementations.
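
A rough sketch of the kind of Adapt hook being suggested; MyDLArray is a hypothetical stand-in for DLArray, not the actual type:

using Adapt

struct MyDLArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
    data::A
    manager::Any  # whatever keeps the exporting tensor alive
end

Base.size(x::MyDLArray) = size(x.data)
Base.getindex(x::MyDLArray, i::Int...) = x.data[i...]

# Unwrap to the underlying storage (e.g. a CuArray) when adapting, so
# GPU code paths see the native array type instead of the wrapper.
Adapt.adapt_structure(to, x::MyDLArray) = Adapt.adapt(to, x.data)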

@rejuvyesh
Contributor

using CUDA
using DLPack
using PyCall

CUDA.allowscalar(false)  # forbid slow scalar getindex on GPU arrays

jax = pyimport("jax")
dlpack = pyimport("jax.dlpack")

pyto_dlpack(x) = @pycall dlpack.to_dlpack(x)::PyObject

key = jax.random.PRNGKey(0)
jax_x = jax.random.normal(key, (2, 3))
jax_sum = jax.numpy.sum(jax_x)

# Wrap the jax array as a Julia array backed by the same GPU buffer.
jl_x = DLPack.DLArray(jax_x, pyto_dlpack)
jl_sum = sum(jl_x)  # falls back to scalar iteration, hence the error below

@assert isapprox(jax_sum.item(), jl_sum)

This results in:

julia> include("test/stresstest_dlpack_cuda.jl")
ERROR: LoadError: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] assertscalar(op::String)
    @ GPUArrays ~/.julia/packages/GPUArrays/umZob/src/host/indexing.jl:53
  [3] getindex(::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Int64, ::Int64)
    @ GPUArrays ~/.julia/packages/GPUArrays/umZob/src/host/indexing.jl:86
  [4] getindex
    @ ./permuteddimsarray.jl:71 [inlined]
  [5] _getindex
    @ ./abstractarray.jl:1262 [inlined]
  [6] getindex
    @ ./abstractarray.jl:1218 [inlined]
  [7] getindex
    @ ~/.julia/packages/DLPack/OdZvu/src/DLPack.jl:304 [inlined]
  [8] iterate
    @ ./abstractarray.jl:1144 [inlined]
  [9] iterate
    @ ./abstractarray.jl:1142 [inlined]
 [10] _foldl_impl(op::Base.BottomRF{typeof(Base.add_sum)}, init::Base._InitialValue, itr::DLMatrix{Float32, PermutedDimsArray{Float32, 2, (2, 1), (2, 1), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, PyObject})
    @ Base ./reduce.jl:56
 [11] foldl_impl
    @ ./reduce.jl:48 [inlined]
 [12] mapfoldl_impl
    @ ./reduce.jl:44 [inlined]
 [13] #mapfoldl#244
    @ ./reduce.jl:162 [inlined]
 [14] mapfoldl
    @ ./reduce.jl:162 [inlined]
 [15] _mapreduce
    @ ./reduce.jl:423 [inlined]
 [16] _mapreduce_dim
    @ ./reducedim.jl:330 [inlined]
 [17] #mapreduce#725
    @ ./reducedim.jl:322 [inlined]
 [18] mapreduce
    @ ./reducedim.jl:322 [inlined]
 [19] #_sum#735
    @ ./reducedim.jl:894 [inlined]
 [20] _sum
    @ ./reducedim.jl:894 [inlined]
 [21] #_sum#734
    @ ./reducedim.jl:893 [inlined]
 [22] _sum
    @ ./reducedim.jl:893 [inlined]
 [23] #sum#732
    @ ./reducedim.jl:889 [inlined]
 [24] sum(a::DLMatrix{Float32, PermutedDimsArray{Float32, 2, (2, 1), (2, 1), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, PyObject})
    @ Base ./reducedim.jl:889
 [25] top-level scope
    @ ~/.julia/dev/PyCallChainRules/test/stresstest_dlpack_cuda.jl:19
 [26] include(fname::String)
    @ Base.MainInclude ./client.jl:451
 [27] top-level scope
    @ REPL[2]:1
 [28] top-level scope
    @ ~/.julia/packages/CUDA/bki2w/src/initialization.jl:52
in expression starting at /home/jagupt/.julia/dev/PyCallChainRules/test/stresstest_dlpack_cuda.jl:19

@rejuvyesh
Contributor

It seems we will need a separate wrapper type for CUDA arrays, rather than DLArray, to avoid the AbstractArray fallback: JuliaLang/julia#31563

@p-zubieta
Owner Author

p-zubieta commented Feb 13, 2022

Ok. I decided to write a wrap method to replace DLArray. This method returns Arrays and CuArrays directly, and instead stores the Python objects in a global pool, which should avoid any issues with CUDA.allowscalar(false) without requiring Adapt at all.
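
Something along these lines, as a purely hypothetical sketch of the pool idea (the names and details are illustrative, not the actual implementation):

using PyCall: PyObject

# Keep each exporting Python object alive while its memory is in use.
const POOL = Dict{UInt,PyObject}()

function wrap_sketch(po::PyObject, ptr::Ptr{Float32}, dims::Dims)
    A = unsafe_wrap(Array, ptr, dims)  # no copy: a view over the tensor data
    POOL[objectid(A)] = po             # root the Python object
    finalizer(a -> delete!(POOL, objectid(a)), A)  # release it when A is GC'd
    return A
end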

The only caveat is that I dropped the PermutedDimsArray wrappers for now (I could add them again if really necessary). But I always do reshape(a, reverse(size(a))) when the layout of the imported tensors is row major anyway, as illustrated below.
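
To illustrate, with a plain Vector standing in for an imported row-major buffer:

mem = collect(1:6)        # memory of a row-major 2×3 tensor [1 2 3; 4 5 6]
a = reshape(mem, (3, 2))  # wrap with the dims reversed
# Read column-major, a == [1 4; 2 5; 3 6], the transpose of the original
# tensor: the data is shared, only the indices are flipped.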

So now the interface looks like the following (a short usage sketch follows the list):

  • DLPack.wrap (similar to from_dlpack in python libraries)
  • DLPack.share (similar to to_dlpack in python libraries)
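
A minimal usage sketch (pyto_dlpack/pyfrom_dlpack are the same illustrative jax helpers used in the snippets above, not part of the package):

using DLPack
using PyCall

jax = pyimport("jax")
dlpack = pyimport("jax.dlpack")

pyto_dlpack(x) = @pycall dlpack.to_dlpack(x)::PyObject
pyfrom_dlpack(x) = @pycall dlpack.from_dlpack(x)::PyObject

# Python -> Julia: wrap returns an Array/CuArray over the same memory.
x = jax.random.normal(jax.random.PRNGKey(0), (2, 3))
jl_x = DLPack.wrap(x, pyto_dlpack)

# Julia -> Python: share exports a Julia array to the Python side.
v = rand(Float32, 4)
py_v = DLPack.share(v, pyfrom_dlpack)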

p-zubieta changed the title from "Add a share interface" to "Add share and wrap interfaces" Feb 13, 2022
@p-zubieta
Owner Author

p-zubieta commented Feb 13, 2022

Once all Python libraries support exporting and importing via __dlpack__ methods, we can write to_dlpack and from_dlpack methods that match the semantics of https://data-apis.org/array-api/latest/design_topics/data_interchange.html. But I think DLPack.wrap and DLPack.share should work fine for now.
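
As a hypothetical sketch (this method does not exist in this PR), from_dlpack could then simply ask the producer for its capsule directly:

using DLPack
using PyCall: PyObject

from_dlpack(x::PyObject) = DLPack.wrap(x, o -> o.__dlpack__())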

@rejuvyesh
Contributor

Let me know if you need any help with the tests; apart from that, this works quite well! We should ask on #gpu to see whether we can get access to CUDA Buildkite runners for testing.
And again, thank you for all this! The hard part for PyCallChainRules was all here in DLPack :)

p-zubieta marked this pull request as ready for review February 15, 2022 17:10
@p-zubieta
Owner Author

One last heads-up: I have removed the DLArray interface altogether, but I don't think that will be an issue given wrap.

@p-zubieta
Owner Author

Thank you @rejuvyesh for testing this over rejuvyesh/PyCallChainRules.jl#10! That guided the design of the interface quite well and helped make the PR more robust.

p-zubieta merged commit 93dbfc9 into main Feb 15, 2022
p-zubieta deleted the export branch February 15, 2022 17:31