Merged

Commits (29 total; the diff below shows changes from 15 commits)
- 702ec12 Use dlpack for array interop (rejuvyesh, Jan 25, 2022)
- 22c5967 more tests (rejuvyesh, Jan 25, 2022)
- 65d8036 rebase but tests are flaky (rejuvyesh, Jan 27, 2022)
- 890c517 refactor to avoid repetition (rejuvyesh, Jan 28, 2022)
- a3b3995 missed import (rejuvyesh, Jan 28, 2022)
- 42a1fc5 small change to enable later use for gpu (rejuvyesh, Jan 28, 2022)
- ddf2b7e update for DLPack.jl refactor (rejuvyesh, Jan 31, 2022)
- 7888b40 update jax for new DLPack.jl (rejuvyesh, Jan 31, 2022)
- 4d09900 latest DLPack; all issues resolved (rejuvyesh, Jan 31, 2022)
- 5e93561 minor cleanup (rejuvyesh, Jan 31, 2022)
- 923a8b5 update for upcoming DLPack interface for sharing jlarrays to python (rejuvyesh, Feb 7, 2022)
- 8e10397 get jax cuda working (rejuvyesh, Feb 9, 2022)
- 3de2ec2 start getting ready for GPUs (rejuvyesh, Feb 10, 2022)
- fd81252 fix for new version (rejuvyesh, Feb 15, 2022)
- afb9f5b acknowledgement (rejuvyesh, Feb 15, 2022)
- ab108fc minor cleanup (rejuvyesh, Feb 15, 2022)
- 4b29709 use device (rejuvyesh, Feb 15, 2022)
- 8dcff72 fix for DLPack's share interface for PyCall (rejuvyesh, Feb 18, 2022)
- 58c2b00 update version since we are DLPack based now (rejuvyesh, Feb 21, 2022)
- 7158777 add link (rejuvyesh, Feb 21, 2022)
- 5eefee9 make CUDA optional (rejuvyesh, Feb 21, 2022)
- 3873560 relax tolerance (rejuvyesh, Feb 21, 2022)
- b2ef824 relax CUDA versions (rejuvyesh, Feb 21, 2022)
- 4c064a7 update version requirements (rejuvyesh, Feb 21, 2022)
- 5ab994d simply adapt (rejuvyesh, Feb 21, 2022)
- 18c56f3 update readme for gpu (rejuvyesh, Feb 21, 2022)
- 6493802 update readme (rejuvyesh, Feb 21, 2022)
- fe00e4c improve jax install instructions (rejuvyesh, Feb 21, 2022)
- c1223df add basic kwargs support (rejuvyesh, Feb 21, 2022)
4 changes: 4 additions & 0 deletions Project.toml
@@ -4,7 +4,11 @@ authors = ["rejuvyesh <[email protected]> and contributors"]
version = "0.1.1"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DLPack = "53c2dc0f-f7d5-43fd-8906-6c0220547083"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

9 changes: 5 additions & 4 deletions README.md
@@ -5,6 +5,8 @@

While Julia is great, there is still a lot of useful differentiable Python code out there in PyTorch, Jax, etc. Given that PyCall.jl is already so great and seamless, one might wonder what it takes to differentiate through those `pycall`s. This library aims for that ideal.

Thanks to @pabloferz, this works on both CPU and GPU without any array copies via [DLPack.jl](https://github.com/pabloferz/DLPack.jl).
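
As a rough illustration of what this buys (a sketch only, not part of the package API; the `share`/`wrap` calls below mirror how this package uses DLPack.jl internally, and their exact signatures may differ between DLPack.jl versions):

```julia
using PyCall, DLPack

torch  = pyimport("torch")
dlpack = pyimport("torch.utils.dlpack")

# Helpers in the style this package defines internally.
pyfrom_dlpack(x) = @pycall dlpack.from_dlpack(x)::PyObject
pyto_dlpack(x)   = @pycall dlpack.to_dlpack(x)::PyObject

x  = rand(Float32, 3, 4)
xt = DLPack.share(x, PyObject, pyfrom_dlpack)    # torch.Tensor viewing the same memory
y  = DLPack.wrap(torch.ones(2, 2), pyto_dlpack)  # Julia array viewing the torch buffer
```

Note that sizes appear reversed across the language boundary (Julia is column-major, PyTorch and JAX are row-major), which is why the old `permutedims` copies are no longer needed.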

## Basic Usage


@@ -40,11 +42,10 @@ grad, = Zygote.gradient(m->loss(m, input, target), jlwrap)
**Install Python dependencies**:
```julia
using PyCall
run(`$(PyCall.pyprogramname) -m pip install jax\["cpu"\]`)
run(`$(PyCall.pyprogramname) -m pip install jax\["cpu"\]`) # for the CPU version
```

## Current Limitations / TODO

- CPU only
- Lots of array copies
- Assumes wrapped python functions are single output only
- No keyword argument support
23 changes: 23 additions & 0 deletions src/PyCallChainRules.jl
@@ -1,5 +1,28 @@
module PyCallChainRules

using DLPack
using Functors
using FillArrays
using CUDA



function ReverseDimsArray(a::AbstractArray{T,N}) where {T<:AbstractFloat,N}
PermutedDimsArray(a, N:-1:1)
end

# Only materialize a contiguous copy when DLPack sharing needs it.
maybecontiguous(x::AbstractArray) = Array(x)
maybecontiguous(x::StridedArray) = x
function maybecontiguous(x::FillArrays.AbstractFill)
(Inline review comment from rejuvyesh, the repo owner and PR author: "This can't be the best way to handle FillArrays?")

x = collect(x)
if CUDA.functional()
x = CUDA.cu(x)
end
return x
end
maybecontiguous(x::AnyCuArray) = CuArray(x)
maybecontiguous(x::StridedCuArray) = x

# Write your package code here.
include("pytorch.jl")

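For orientation, a rough sketch of how the new helpers above behave (illustrative only, not part of the diff; `Fill` comes from FillArrays.jl, and the CUDA branch only fires when a GPU is functional):

```julia
using FillArrays
using PyCallChainRules: maybecontiguous, ReverseDimsArray

maybecontiguous(rand(Float32, 3, 2))          # StridedArray: passed through, no copy
maybecontiguous(Fill(1f0, 3, 2))              # Fill gradient: collected (and moved to GPU when CUDA.functional())
maybecontiguous(view(rand(4, 4), [1, 3], :))  # non-strided view: materialized via Array(x)

ReverseDimsArray(rand(Float32, 2, 3, 4))      # lazy 4×3×2 permuted view, no data movement
```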
30 changes: 18 additions & 12 deletions src/jax.jl
@@ -2,46 +2,52 @@ module Jax

using PyCall
using ChainRulesCore
using DLPack
using Functors: fmap

using ..PyCallChainRules: ReverseDimsArray, maybecontiguous

import DLPack
DLPack.share(A::StridedArray, from_dlpack::Function) = DLPack.share(A, PyObject, from_dlpack)

const inspect = PyNULL()
const jax = PyNULL()
const dlpack = PyNULL()
const stax = PyNULL()
const numpy = PyNULL()

const ispysetup = Ref{Bool}(false)

function reversedims(a::AbstractArray{T,N}) where {T<:AbstractFloat,N}
permutedims(a, N:-1:1)
end

mapover(f, iselement, x) =
iselement(x) ? f(x) : map(e -> mapover(f, iselement, e), x)
pyto_dlpack(x) = @pycall dlpack.to_dlpack(x)::PyObject
pyfrom_dlpack(x) = @pycall dlpack.from_dlpack(x)::PyObject

struct JaxFunctionWrapper
jaxfn::PyObject
end

function (wrap::JaxFunctionWrapper)(args...)
# TODO: handle multiple outputs
out = numpy.array(wrap.jaxfn(mapover(x->jax.numpy.asarray(PyReverseDims(x)), x-> x isa Array, args)...))
return reversedims((out))
out = (wrap.jaxfn(fmap(x->DLPack.share(x, pyfrom_dlpack), args)...))
return (DLPack.wrap(out, pyto_dlpack))
end

function ChainRulesCore.rrule(wrap::JaxFunctionWrapper, args...)
project = ProjectTo(args)
jax_primal, jax_vjpfun = jax.vjp(wrap.jaxfn, mapover(x->jax.numpy.asarray(PyReverseDims(x)), x-> x isa Array, args)...)
jax_primal, jax_vjpfun = jax.vjp(wrap.jaxfn, fmap(x->DLPack.share(x, pyfrom_dlpack), args)...)
function JaxFunctionWrapper_pullback(Δ)
tangent_vals = mapover(x->reversedims(numpy.array(x)), x-> x isa PyObject,jax_vjpfun(jax.numpy.array(PyReverseDims(Δ))))

cΔ = maybecontiguous(Δ)
dlΔ = DLPack.share(cΔ, pyfrom_dlpack)
tangent_vals = fmap(x->(DLPack.wrap(x, pyto_dlpack)), jax_vjpfun(dlΔ))
return (NoTangent(), project(tangent_vals)...)
end
return reversedims(numpy.array(jax_primal)), JaxFunctionWrapper_pullback
return (DLPack.wrap(jax_primal, pyto_dlpack)), JaxFunctionWrapper_pullback
end


function __init__()
try
copy!(jax, pyimport("jax"))
copy!(dlpack, pyimport("jax.dlpack"))
copy!(numpy, pyimport("numpy"))
copy!(stax, pyimport("jax.example_libraries.stax"))
copy!(inspect, pyimport("inspect"))
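Condensing the new Jax path into one picture, here is a sketch adapted from the updated `test/test_jax.jl` below (so the names match the test; it assumes `jax` is installed and `ispysetup[]` is true):

```julia
using PyCallChainRules.Jax: JaxFunctionWrapper, jax, stax, pyto_dlpack
using DLPack, Zygote

# Build a small stax layer and bring its parameters over to Julia without copies.
init_lin, apply_lin = stax.Dense(2)
_, params = init_lin(jax.random.PRNGKey(0), (-1, 3))
params_jl = map(p -> DLPack.wrap(p, pyto_dlpack), params)

linwrap = JaxFunctionWrapper(apply_lin)
x = randn(Float32, 3, 1)   # (indim, batchsize) on the Julia side
y = linwrap(params_jl, x)  # (outdim, batchsize)

# The DLPack-based rrule lets Zygote differentiate straight through jax.vjp.
grad, = Zygote.gradient(p -> sum(linwrap(p, x)), params_jl)
```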
65 changes: 41 additions & 24 deletions src/pytorch.jl
@@ -1,74 +1,91 @@
module Torch

using PyCall

using ChainRulesCore
using DLPack
using Functors: @functor

import DLPack
DLPack.share(A::StridedArray, from_dlpack::Function) = DLPack.share(A, PyObject, from_dlpack)

using ..PyCallChainRules: ReverseDimsArray, maybecontiguous

const inspect = PyNULL()
const torch = PyNULL()
const functorch = PyNULL()

const dlpack = PyNULL()
const ispysetup = Ref{Bool}(false)

function reversedims(a::AbstractArray{T,N}) where {T<:AbstractFloat,N}
permutedims(a, N:-1:1)
end
pyto_dlpack(x) = @pycall dlpack.to_dlpack(x)::PyObject
pyfrom_dlpack(x) = @pycall dlpack.from_dlpack(x)::PyObject


struct TorchModuleWrapper
torch_stateless_module::PyObject
dtype::PyObject
device::PyObject
params::Tuple
buffers::Tuple
end

@functor TorchModuleWrapper (params,)

Base.show(io::IO, f::TorchModuleWrapper) = print(io, f.torch_stateless_module, " ", f.dtype, " ", f.device, "params size=", size.(f.params))
Base.show(io::IO, f::TorchModuleWrapper) = print(io, f.torch_stateless_module, " ", f.dtype, " ", " params size=", size.(f.params))

Base.length(f::TorchModuleWrapper) = length(f.params)
Base.iterate(f::TorchModuleWrapper) = iterate(f.params)
Base.iterate(f::TorchModuleWrapper, state) = iterate(f.params, state)

function TorchModuleWrapper(torch_module, device)
function TorchModuleWrapper(torch_module)
pybuiltin("isinstance")(torch_module, torch.nn.Module) || error("Not a torch.nn.Module")
torch_module = torch_module.to(device)
funmod, params, buffers = functorch.make_functional_with_buffers(torch_module)
dtype = params[1].dtype
jlparams = map(x -> x.detach().numpy(), params)
return TorchModuleWrapper(funmod, dtype, device, jlparams, buffers)
jlparams = map(params) do x
(convert(Array, DLPack.wrap(x.cpu(), pyto_dlpack)))
end
return TorchModuleWrapper(funmod, dtype, jlparams, buffers)
end

function TorchModuleWrapper(torch_module)
device = torch.cuda.is_available() ? torch.device("cuda:0") : torch.device("cpu")
TorchModuleWrapper(torch_module, device)
end

function (wrap::TorchModuleWrapper)(args...)
# TODO: handle multiple outputs
tensor_out = wrap.torch_stateless_module(Tuple(map(x -> torch.as_tensor(x).to(device = wrap.device, dtype = wrap.dtype).requires_grad_(true), wrap.params)),
wrap.buffers, map(x -> torch.as_tensor(PyReverseDims(x)).to(dtype = wrap.dtype, device = wrap.device), args)...)
return reversedims(tensor_out.detach().numpy())
params = wrap.params
tensor_out = wrap.torch_stateless_module(Tuple(map(x -> DLPack.share(x, pyfrom_dlpack).requires_grad_(true), params)),
wrap.buffers, map(x -> DLPack.share(x, pyfrom_dlpack), args)...)
res = (DLPack.wrap(tensor_out, pyto_dlpack))
return res
end

function ChainRulesCore.rrule(wrap::TorchModuleWrapper, args...)
torch_primal, torch_vjpfun = functorch.vjp(wrap.torch_stateless_module, Tuple(map(x -> torch.as_tensor(x).to(device = wrap.device, dtype = wrap.dtype).requires_grad_(true), wrap.params)),
wrap.buffers, map(x -> torch.as_tensor(PyReverseDims(x)).to(dtype = wrap.dtype, device = wrap.device).requires_grad_(true), args)...)
params = wrap.params
torch_primal, torch_vjpfun = functorch.vjp(py"buffer_implicit"(wrap.torch_stateless_module, wrap.buffers), Tuple(map(x -> DLPack.share((x), pyfrom_dlpack).requires_grad_(true), params)),
map(x -> DLPack.share((x), pyfrom_dlpack).requires_grad_(true), args)...)
project = ProjectTo(args)
function TorchModuleWrapper_pullback(Δ)
torch_tangent_vals = torch_vjpfun(torch.as_tensor(PyReverseDims(Δ)).to(dtype = wrap.dtype, device = wrap.device))
jlparams_tangents = map(x -> x.detach().numpy(), torch_tangent_vals[1])
args_tangents = project(map(x -> reversedims(x.detach().numpy()), torch_tangent_vals[3:end]))
return (Tangent{TorchModuleWrapper}(; torch_stateless_module = NoTangent(), dtype = NoTangent(), device = NoTangent(), params = jlparams_tangents, buffers = NoTangent()), args_tangents...)
torch_tangent_vals = torch_vjpfun(DLPack.share((maybecontiguous(Δ)), pyfrom_dlpack))
jlparams_tangents = map(x -> (DLPack.wrap(x, pyto_dlpack)), torch_tangent_vals[1])
args_tangents = project(map(x -> (DLPack.wrap(x, pyto_dlpack)), torch_tangent_vals[2:end]))
return (Tangent{TorchModuleWrapper}(; torch_stateless_module = NoTangent(), dtype = NoTangent(), params = jlparams_tangents, buffers = NoTangent()), args_tangents...)
end
return reversedims(torch_primal.detach().numpy()), TorchModuleWrapper_pullback
res = (DLPack.wrap(torch_primal, pyto_dlpack))
return res, TorchModuleWrapper_pullback
end


function __init__()
try
copy!(torch, pyimport("torch"))
copy!(dlpack, pyimport("torch.utils.dlpack"))
copy!(functorch, pyimport("functorch"))
copy!(inspect, pyimport("inspect"))
ispysetup[] = true
py"""
def buffer_implicit(fn, buffers):
def newfn(params, inputs):
return fn(params, buffers, inputs)

return newfn
"""
catch err
@warn """PyCallChainRules.jl has failed to import torch and functorch from Python.
Please make sure these are installed.
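The Torch side looks much the same as before from the user's point of view, only backed by DLPack now. A sketch following the pattern in the README's basic-usage section (the `Linear(3, 2)` module and shapes here are just examples):

```julia
using PyCallChainRules.Torch: TorchModuleWrapper, torch
using Zygote

lin = torch.nn.Linear(3, 2)
wrapped = TorchModuleWrapper(lin)  # functorch extracts params; they live as Julia arrays

x = randn(Float32, 3, 1)           # (features, batch) on the Julia side
y = wrapped(x)                     # forward pass runs in PyTorch, tensors shared via DLPack

# The rrule returns Tangent{TorchModuleWrapper} with tangents for `params`,
# so Zygote can compute gradients of the wrapped module directly.
grad, = Zygote.gradient(m -> sum(m(x)), wrapped)
```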
61 changes: 55 additions & 6 deletions test/test_jax.jl
@@ -1,36 +1,85 @@
using PyCallChainRules.Jax: JaxFunctionWrapper, jax, numpy, stax, reversedims, ispysetup
using PyCallChainRules: ReverseDimsArray

using PyCallChainRules.Jax: JaxFunctionWrapper, jax, numpy, stax, pyto_dlpack, pyfrom_dlpack, ispysetup

using Test
using ChainRulesTestUtils

using Zygote
using ChainRulesCore: NoTangent
using CUDA
using Random
using PyCall
using DLPack
#using Flux



if !ispysetup[]
return
end


function reversedims(a::AbstractArray{T,N}) where {T<:AbstractFloat,N}
permutedims(a, N:-1:1)
end

@testset "dlpack" begin
key = jax.random.PRNGKey(0)
for dims in ((10,), (1, 10), (2, 3, 5), (2, 3, 4, 5))
xto = jax.random.normal(key, dims)
xjl = DLPack.wrap(xto, pyto_dlpack)
@test Tuple(xto.shape) == reverse(size(xjl))
@test isapprox(sum(numpy.array(xto)), sum(xjl))
end
end

batchsize = 1
indim = 3
outdim = 2

init_lin, apply_lin = stax.Dense(outdim)
_, params = init_lin(jax.random.PRNGKey(0), (-1, indim))
params_np = map(reversedims ∘ numpy.array, params)
params_np = map(x->((DLPack.wrap(x, pyto_dlpack))), params)
linwrap = JaxFunctionWrapper(apply_lin)
x = randn(Float32, indim, batchsize)
if CUDA.functional()
params_np = map(cu, params_np)
x = cu(x)
end
y = linwrap(params_np, x)
@test size(y) == (outdim, batchsize)

# CRTU check TODO
test_rrule(linwrap, params_np, x; check_inferred=false, check_thunked_output_tangent=false, rtol=1e-4, atol=1e-4)
#test_rrule(linwrap, params_np, x; check_inferred=false, check_thunked_output_tangent=false, rtol=1e-4, atol=1e-4)

# Zygote check
if CUDA.functional()
params_np = map(cu, params_np)
x = cu(x)
end

grad, = Zygote.gradient(p->sum(linwrap(p, x)), params_np)
py"""
import jax
import jax.numpy as jnp
def grad(fn, params, x):
f2 = lambda p, z: jnp.sum(fn(p, z))
return jax.grad(f2)(params, x)
"""
jaxgrad = map(x->(DLPack.wrap(x, pyto_dlpack)), (py"grad")(apply_lin, params, DLPack.share(x, pyfrom_dlpack)))
@test length(grad) == length(params_np)
@test size(grad[1]) == size(params_np[1])
@test size(grad[2]) == size(params_np[2])
@test isapprox(Array(grad[1]), Array(jaxgrad[1]))
@test isapprox(Array(grad[2]), Array(jaxgrad[2]))

grad, = Zygote.gradient(z->sum(linwrap(params_np, z)), x)
@test size(grad) == size(x)
py"""
import jax
import jax.numpy as jnp
def gradx(fn, params, x):
f2 = lambda p, z: jnp.sum(fn(p, z))
return jax.grad(f2, argnums=(1,))(params, x)
"""
jaxgrad = map(x->(DLPack.wrap(x, pyto_dlpack)), (py"gradx")(apply_lin, params, DLPack.share(x, pyfrom_dlpack)))
@test isapprox(Array(jaxgrad[1]), Array(grad))