diff --git a/Project.toml b/Project.toml index 1224bc9..0cd548f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,26 +1,36 @@ name = "PyCallChainRules" uuid = "b12ccfe2-7326-416f-9f4f-cd3183bd9fe8" authors = ["rejuvyesh and contributors"] -version = "0.1.1" +version = "0.2.0" [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +DLPack = "53c2dc0f-f7d5-43fd-8906-6c0220547083" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Requires = "ae029012-a4dd-5104-9daa-d747884805df" [compat] -ChainRulesCore = "1" -ChainRulesTestUtils = "1.5" +Adapt = "3" +ChainRulesCore = "1.9" +CUDA = "≥ 1.3" +DLPack = "0.1" +FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13" +Functors = "0.2" PyCall = "1" +Requires = "1.3" Zygote = "0.6" julia = "1.6" [extras] -ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Test", "SafeTestsets", "Zygote", "Flux", "ChainRulesTestUtils"] +test = ["Test", "SafeTestsets", "Zygote", "Flux", "CUDA"] diff --git a/README.md b/README.md index 83026e8..2e32eb2 100644 --- a/README.md +++ b/README.md @@ -5,18 +5,25 @@ While Julia is great, there are still a lot of existing useful differentiable python code in PyTorch, Jax, etc. Given PyCall.jl is already so great and seamless, one might wonder what it takes to differentiate through those `pycall`s. This library aims for that ideal. +Thanks to [@pabloferz](https://github.cim/pabloferz), this works on both CPU and GPU without any array copies via [DLPack.jl](https://github.com/pabloferz/DLPack.jl). + ## Basic Usage ### PyTorch -**Install Python dependencies**: +#### CPU only + +##### Install Python dependencies + ```julia using PyCall run(`$(PyCall.pyprogramname) -m pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --upgrade`) run(`$(PyCall.pyprgramname) -m pip install "git+https://github.com/pytorch/functorch.git"`) ``` +##### Example + ```julia using PyCallChainRules.Torch: TorchModuleWrapper, torch using Zygote @@ -35,16 +42,61 @@ loss(m, x, y) = sum(m(x) .- target) grad, = Zygote.gradient(m->loss(m, input, target), jlwrap) ``` +#### GPU + +##### Install Python dependencies + +```julia +using PyCall +# For CUDA 11 +run(`$(PyCall.pyprogramname) -m pip install --pre torch -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html +--upgrade`) +run(`$(PyCall.pyprgramname) -m pip install "git+https://github.com/pytorch/functorch.git"`) +``` + +##### Example + +```julia +using CUDA +using PyCallChainRules.Torch: TorchModuleWrapper, torch +using Zygote + +@assert CUDA.isfunctional() + +indim = 32 +outdim = 16 +torch_module = torch.nn.Linear(indim, outdim).to(device=torch.device("cuda:0")) # Can be anything subclassing torch.nn.Module +jlwrap = TorchModuleWrapper(torch_module) + +batchsize = 64 +input = CUDA.cu(randn(Float32, indim, batchsize)) +output = jlwrap(input) + +target = CUDA.cu(randn(Float32, outdim, batchsize)) +loss(m, x, y) = sum(m(x) .- target) +grad, = Zygote.gradient(m->loss(m, input, target), jlwrap) +``` + + ### Jax -**Install Python dependencies**: +#### CPU only + +##### Install Python dependencies +```julia +using PyCall +run(`$(PyCall.pyprogramname) -m pip install jax\["cpu"\]`) # for cpu version +``` + +#### GPU + +##### Install Python dependencies ```julia using PyCall -run(`$(PyCall.pyprogramname) -m pip install jax\["cpu"\]) +run(`$(PyCall.pyprogramname) -m pip install jax\["cuda"\] -f https://storage.googleapis.com/jax-releases/jax_releases.html`) ``` ## Current Limitations / TODO -- CPU only -- Lots of array copies -- Assumes wrapped python functions are single output only \ No newline at end of file +- Assumes wrapped python functions are single output only +- No keyword argument support \ No newline at end of file diff --git a/src/PyCallChainRules.jl b/src/PyCallChainRules.jl index 15d978f..3fef687 100644 --- a/src/PyCallChainRules.jl +++ b/src/PyCallChainRules.jl @@ -1,8 +1,27 @@ module PyCallChainRules +using DLPack +using Requires + +import FillArrays +import Adapt + +struct PyAdaptor{T} end +Adapt.adapt_storage(to::PyAdaptor{T}, x::AbstractArray) where {T} = convert(Array, x) +Adapt.adapt_storage(to::PyAdaptor{T}, x::StridedArray) where {T} = x +Adapt.adapt_storage(to::PyAdaptor{T}, x::FillArrays.AbstractFill) where {T} = collect(x) + + + # Write your package code here. include("pytorch.jl") include("jax.jl") +function __init__() + @require CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" begin + include("cuda.jl") + end +end + end diff --git a/src/cuda.jl b/src/cuda.jl new file mode 100644 index 0000000..7f2c840 --- /dev/null +++ b/src/cuda.jl @@ -0,0 +1,3 @@ +Adapt.adapt_storage(to::PyAdaptor{T}, x::CUDA.AnyCuArray) where {T} = CUDA.CuArray(x) +Adapt.adapt_storage(to::PyAdaptor{T}, x::CUDA.StridedCuArray) where {T} = x +Adapt.adapt_storage(to::PyAdaptor{<:CUDA.CuArray}, x::FillArrays.AbstractFill) = CUDA.cu(collect(x)) diff --git a/src/jax.jl b/src/jax.jl index 4301bab..c150b4f 100644 --- a/src/jax.jl +++ b/src/jax.jl @@ -2,46 +2,51 @@ module Jax using PyCall using ChainRulesCore +using DLPack +using Functors: fmap +using Adapt + +using ..PyCallChainRules: PyAdaptor const inspect = PyNULL() const jax = PyNULL() +const dlpack = PyNULL() const stax = PyNULL() const numpy = PyNULL() const ispysetup = Ref{Bool}(false) -function reversedims(a::AbstractArray{T,N}) where {T<:AbstractFloat,N} - permutedims(a, N:-1:1) -end - -mapover(f, iselement, x) = - iselement(x) ? f(x) : map(e -> mapover(f, iselement, e), x) +pyto_dlpack(x) = @pycall dlpack.to_dlpack(x)::PyObject +pyfrom_dlpack(x) = @pycall dlpack.from_dlpack(x)::PyObject struct JaxFunctionWrapper jaxfn::PyObject end -function (wrap::JaxFunctionWrapper)(args...) +function (wrap::JaxFunctionWrapper)(args...; kwargs...) # TODO: handle multiple outputs - out = numpy.array(wrap.jaxfn(mapover(x->jax.numpy.asarray(PyReverseDims(x)), x-> x isa Array, args)...)) - return reversedims((out)) + out = (wrap.jaxfn(fmap(x->DLPack.share(x, PyObject, pyfrom_dlpack), args)...)) + return (DLPack.wrap(out, pyto_dlpack)) end -function ChainRulesCore.rrule(wrap::JaxFunctionWrapper, args...) +function ChainRulesCore.rrule(wrap::JaxFunctionWrapper, args...; kwargs...) + T = typeof(first(args)) project = ProjectTo(args) - jax_primal, jax_vjpfun = jax.vjp(wrap.jaxfn, mapover(x->jax.numpy.asarray(PyReverseDims(x)), x-> x isa Array, args)...) + jax_primal, jax_vjpfun = jax.vjp(wrap.jaxfn, fmap(x->DLPack.share(x, PyObject, pyfrom_dlpack), args)...; kwargs...) function JaxFunctionWrapper_pullback(Δ) - tangent_vals = mapover(x->reversedims(numpy.array(x)), x-> x isa PyObject,jax_vjpfun(jax.numpy.array(PyReverseDims(Δ)))) - + cΔ = Adapt.adapt(PyAdaptor{T}, Δ) + dlΔ = DLPack.share(cΔ, PyObject, pyfrom_dlpack) + tangent_vals = fmap(x->(DLPack.wrap(x, pyto_dlpack)), jax_vjpfun(dlΔ)) return (NoTangent(), project(tangent_vals)...) end - return reversedims(numpy.array(jax_primal)), JaxFunctionWrapper_pullback + return (DLPack.wrap(jax_primal, pyto_dlpack)), JaxFunctionWrapper_pullback end function __init__() try copy!(jax, pyimport("jax")) + copy!(dlpack, pyimport("jax.dlpack")) copy!(numpy, pyimport("numpy")) copy!(stax, pyimport("jax.example_libraries.stax")) copy!(inspect, pyimport("inspect")) diff --git a/src/pytorch.jl b/src/pytorch.jl index 82f43e0..95ac709 100644 --- a/src/pytorch.jl +++ b/src/pytorch.jl @@ -1,74 +1,91 @@ module Torch using PyCall + using ChainRulesCore +using DLPack +using Functors: @functor +using Adapt + + +using ..PyCallChainRules: PyAdaptor const inspect = PyNULL() const torch = PyNULL() const functorch = PyNULL() - +const dlpack = PyNULL() const ispysetup = Ref{Bool}(false) -function reversedims(a::AbstractArray{T,N}) where {T<:AbstractFloat,N} - permutedims(a, N:-1:1) -end +pyto_dlpack(x) = @pycall dlpack.to_dlpack(x)::PyObject +pyfrom_dlpack(x) = @pycall dlpack.from_dlpack(x)::PyObject + struct TorchModuleWrapper torch_stateless_module::PyObject dtype::PyObject - device::PyObject params::Tuple buffers::Tuple end +@functor TorchModuleWrapper (params,) -Base.show(io::IO, f::TorchModuleWrapper) = print(io, f.torch_stateless_module, " ", f.dtype, " ", f.device, "params size=", size.(f.params)) +Base.show(io::IO, f::TorchModuleWrapper) = print(io, f.torch_stateless_module, " ", f.dtype, " ", " params size=", size.(f.params)) Base.length(f::TorchModuleWrapper) = length(f.params) Base.iterate(f::TorchModuleWrapper) = iterate(f.params) Base.iterate(f::TorchModuleWrapper, state) = iterate(f.params, state) -function TorchModuleWrapper(torch_module, device) +function TorchModuleWrapper(torch_module) pybuiltin("isinstance")(torch_module, torch.nn.Module) || error("Not a torch.nn.Module") - torch_module = torch_module.to(device) funmod, params, buffers = functorch.make_functional_with_buffers(torch_module) dtype = params[1].dtype - jlparams = map(x -> x.detach().numpy(), params) - return TorchModuleWrapper(funmod, dtype, device, jlparams, buffers) + jlparams = map(params) do x + DLPack.wrap(x, pyto_dlpack) + end + return TorchModuleWrapper(funmod, dtype, jlparams, buffers) end -function TorchModuleWrapper(torch_module) - device = torch.cuda.is_available() ? torch.device("cuda:0") : torch.device("cpu") - TorchModuleWrapper(torch_module, device) -end -function (wrap::TorchModuleWrapper)(args...) +function (wrap::TorchModuleWrapper)(args...; kwargs...) # TODO: handle multiple outputs - tensor_out = wrap.torch_stateless_module(Tuple(map(x -> torch.as_tensor(x).to(device = wrap.device, dtype = wrap.dtype).requires_grad_(true), wrap.params)), - wrap.buffers, map(x -> torch.as_tensor(PyReverseDims(x)).to(dtype = wrap.dtype, device = wrap.device), args)...) - return reversedims(tensor_out.detach().numpy()) + params = wrap.params + tensor_out = wrap.torch_stateless_module(Tuple(map(x -> DLPack.share(x, PyObject, pyfrom_dlpack).requires_grad_(true), params)), + wrap.buffers, map(x -> DLPack.share(x, PyObject, pyfrom_dlpack), args)...; kwargs...) + res = DLPack.wrap(tensor_out, pyto_dlpack) + return res end -function ChainRulesCore.rrule(wrap::TorchModuleWrapper, args...) - torch_primal, torch_vjpfun = functorch.vjp(wrap.torch_stateless_module, Tuple(map(x -> torch.as_tensor(x).to(device = wrap.device, dtype = wrap.dtype).requires_grad_(true), wrap.params)), - wrap.buffers, map(x -> torch.as_tensor(PyReverseDims(x)).to(dtype = wrap.dtype, device = wrap.device).requires_grad_(true), args)...) +function ChainRulesCore.rrule(wrap::TorchModuleWrapper, args...; kwargs...) + T = typeof(first(args)) + params = wrap.params + torch_primal, torch_vjpfun = functorch.vjp(py"buffer_implicit"(wrap.torch_stateless_module, wrap.buffers), Tuple(map(x -> DLPack.share(x, PyObject, pyfrom_dlpack).requires_grad_(true), params)), + map(x -> DLPack.share(x, PyObject, pyfrom_dlpack).requires_grad_(true), args)...; kwargs...) project = ProjectTo(args) function TorchModuleWrapper_pullback(Δ) - torch_tangent_vals = torch_vjpfun(torch.as_tensor(PyReverseDims(Δ)).to(dtype = wrap.dtype, device = wrap.device)) - jlparams_tangents = map(x -> x.detach().numpy(), torch_tangent_vals[1]) - args_tangents = project(map(x -> reversedims(x.detach().numpy()), torch_tangent_vals[3:end])) - return (Tangent{TorchModuleWrapper}(; torch_stateless_module = NoTangent(), dtype = NoTangent(), device = NoTangent(), params = jlparams_tangents, buffers = NoTangent()), args_tangents...) + torch_tangent_vals = torch_vjpfun(DLPack.share(Adapt.adapt(PyAdaptor{T}, Δ), PyObject, pyfrom_dlpack)) + jlparams_tangents = map(x -> (DLPack.wrap(x, pyto_dlpack)), torch_tangent_vals[1]) + args_tangents = project(map(x -> (DLPack.wrap(x, pyto_dlpack)), torch_tangent_vals[2:end])) + return (Tangent{TorchModuleWrapper}(; torch_stateless_module = NoTangent(), dtype = NoTangent(), params = jlparams_tangents, buffers = NoTangent()), args_tangents...) end - return reversedims(torch_primal.detach().numpy()), TorchModuleWrapper_pullback + res = DLPack.wrap(torch_primal, pyto_dlpack) + return res, TorchModuleWrapper_pullback end function __init__() try copy!(torch, pyimport("torch")) + copy!(dlpack, pyimport("torch.utils.dlpack")) copy!(functorch, pyimport("functorch")) copy!(inspect, pyimport("inspect")) ispysetup[] = true + py""" + def buffer_implicit(fn, buffers): + def newfn(params, inputs): + return fn(params, buffers, inputs) + + return newfn + """ catch err @warn """PyCallChainRules.jl has failed to import torch and functorch from Python. Please make sure these are installed. diff --git a/test/test_jax.jl b/test/test_jax.jl index 8620b18..9c5d936 100644 --- a/test/test_jax.jl +++ b/test/test_jax.jl @@ -1,36 +1,78 @@ -using PyCallChainRules.Jax: JaxFunctionWrapper, jax, numpy, stax, reversedims, ispysetup +using PyCallChainRules.Jax: JaxFunctionWrapper, jax, numpy, stax, pyto_dlpack, pyfrom_dlpack, ispysetup using Test -using ChainRulesTestUtils + using Zygote -using ChainRulesCore: NoTangent +using CUDA using Random +using PyCall +using DLPack #using Flux + + if !ispysetup[] return end +@testset "dlpack" begin + key = jax.random.PRNGKey(0) + for dims in ((10,), (1, 10), (2, 3, 5), (2, 3, 4, 5)) + xto = jax.random.normal(key, dims) + xjl = DLPack.wrap(xto, pyto_dlpack) + @test Tuple(xto.shape) == reverse(size(xjl)) + @test isapprox(sum(numpy.array(xto)), sum(xjl)) + end +end + batchsize = 1 indim = 3 outdim = 2 init_lin, apply_lin = stax.Dense(outdim) _, params = init_lin(jax.random.PRNGKey(0), (-1, indim)) -params_np = map(reversedims ∘ numpy.array, params) +params_np = map(x->((DLPack.wrap(x, pyto_dlpack))), params) linwrap = JaxFunctionWrapper(apply_lin) x = randn(Float32, indim, batchsize) +if CUDA.functional() + params_np = map(cu, params_np) + x = cu(x) +end y = linwrap(params_np, x) @test size(y) == (outdim, batchsize) # CRTU check TODO -test_rrule(linwrap, params_np, x; check_inferred=false, check_thunked_output_tangent=false, rtol=1e-4, atol=1e-4) +#test_rrule(linwrap, params_np, x; check_inferred=false, check_thunked_output_tangent=false, rtol=1e-4, atol=1e-4) # Zygote check +if CUDA.functional() + params_np = map(cu, params_np) + x = cu(x) +end + grad, = Zygote.gradient(p->sum(linwrap(p, x)), params_np) +py""" +import jax +import jax.numpy as jnp +def grad(fn, params, x): + f2 = lambda p, z: jnp.sum(fn(p, z)) + return jax.grad(f2)(params, x) +""" +jaxgrad = map(x->(DLPack.wrap(x, pyto_dlpack)), (py"grad")(apply_lin, params, DLPack.share(x, PyObject, pyfrom_dlpack))) @test length(grad) == length(params_np) @test size(grad[1]) == size(params_np[1]) @test size(grad[2]) == size(params_np[2]) +@test isapprox(Array(grad[1]), Array(jaxgrad[1])) +@test isapprox(Array(grad[2]), Array(jaxgrad[2])) grad, = Zygote.gradient(z->sum(linwrap(params_np, z)), x) -@test size(grad) == size(x) \ No newline at end of file +@test size(grad) == size(x) +py""" +import jax +import jax.numpy as jnp +def gradx(fn, params, x): + f2 = lambda p, z: jnp.sum(fn(p, z)) + return jax.grad(f2, argnums=(1,))(params, x) +""" +jaxgrad = map(x->(DLPack.wrap(x, pyto_dlpack)), (py"gradx")(apply_lin, params, DLPack.share(x, PyObject, pyfrom_dlpack))) +@test isapprox(Array(jaxgrad[1]), Array(grad)) \ No newline at end of file diff --git a/test/test_pytorch.jl b/test/test_pytorch.jl index 30fe04c..6dd5889 100644 --- a/test/test_pytorch.jl +++ b/test/test_pytorch.jl @@ -1,45 +1,106 @@ -using PyCallChainRules.Torch: TorchModuleWrapper, torch, functorch, ispysetup +using PyCallChainRules.Torch: TorchModuleWrapper, torch, functorch, dlpack, pyto_dlpack, pyfrom_dlpack, ispysetup using Test -using ChainRulesTestUtils using Zygote using Flux using ChainRulesCore: NoTangent, AbstractZero import Random using PyCall +using CUDA +using DLPack if !ispysetup[] return end -Random.seed!(42) +if CUDA.functional() + device = torch.device("cuda:0") +else + device = torch.device("cpu") +end + +#CUDA.allowscalar(true) + +function compare_grad_wrt_params(modelwrap, inputs...) + params = map(x -> DLPack.share(x, PyObject, pyfrom_dlpack).to(device = device, dtype = modelwrap.dtype).requires_grad_(true), (modelwrap.params)) + torch_out = modelwrap.torch_stateless_module(params, modelwrap.buffers, map(z->DLPack.share(z, PyObject, pyfrom_dlpack).to(dtype=modelwrap.dtype, device=device), inputs)...).sum() + torchgrad = map(x-> (x.cpu().numpy()), torch.autograd.grad(torch_out, params)) + grad, = Zygote.gradient(m->sum(m(inputs...)), modelwrap) + @test length(torchgrad) == length(grad.params) + for i in 1:length(grad.params) + @test isapprox(sum(torchgrad[i]), sum(grad.params[i])) + end + @test length(grad.params) == length(modelwrap.params) + @test grad.params[1] !== nothing + @test grad.params[2] !== nothing + @test size(grad.params[1]) == size(modelwrap.params[1]) + @test size(grad.params[2]) == size(modelwrap.params[2]) +end -ChainRulesTestUtils.test_approx(::AbstractZero, x::PyObject, msg=""; kwargs...) = @test true -ChainRulesTestUtils.test_approx(::AbstractZero, x::Tuple{}, msg=""; kwargs...) = @test true +function compare_grad_wrt_inputs(modelwrap, x) + params = map(z -> DLPack.share(z, PyObject, pyfrom_dlpack).to(device = device, dtype = modelwrap.dtype).requires_grad_(true), (modelwrap.params)) + xtorch = DLPack.share(copy(x), PyObject, pyfrom_dlpack).to(dtype=modelwrap.dtype, device=device).requires_grad_(true) + torch_out = modelwrap.torch_stateless_module(params, modelwrap.buffers, xtorch).sum() + torchgrad = map(z-> (copy(z.cpu().numpy())), torch.autograd.grad(torch_out, xtorch))[1] + grad, = Zygote.gradient(z->sum(modelwrap(z)), x) + @test size(grad) == size(x) + @test length(torchgrad) == length(grad) + @test isapprox(sum(torchgrad), sum(grad)) +end + +# Random.seed!(42) +# torch.manual_seed(42) -function ChainRulesTestUtils.FiniteDifferences.to_vec(x::TorchModuleWrapper) - params_vec, back = ChainRulesTestUtils.FiniteDifferences.to_vec(x.params) - function TorchModuleWrapper_from_vec(params_vec) - TorchModuleWrapper(x.torch_stateless_module, x.dtype, x.device, back(params_vec), x.buffers) +@testset "dlpack" begin + for dims in ((10,), (1, 10), (2, 3, 5), (2, 3, 4, 5)) + xto = torch.randn(dims..., device=device) + xjl = DLPack.wrap(xto, pyto_dlpack) + @test Tuple(xto.size()) == reverse(size(xjl)) + @test isapprox(sum(xto.cpu().numpy()), sum(xjl)) end - return params_vec, TorchModuleWrapper_from_vec end batchsize = 1 indim = 3 outdim = 2 hiddendim = 4 -lin = torch.nn.Sequential(torch.nn.Linear(indim, hiddendim), torch.nn.ReLU(), torch.nn.Linear(hiddendim, outdim)) -linwrap = TorchModuleWrapper(lin) +@testset "linear" begin + lin = torch.nn.Linear(indim, outdim).to(device=device) + linwrap = TorchModuleWrapper(lin) + if CUDA.functional() + linwrap = fmap(CUDA.cu, linwrap) + end + x = randn(Float32, indim, batchsize) + if CUDA.functional() + x = cu(x) + end + y = linwrap(x) + @test size(y) == (outdim, batchsize) + compare_grad_wrt_params(linwrap, x) + compare_grad_wrt_inputs(linwrap, x) + +end -x = randn(Float32, indim, batchsize) -y = linwrap(x) -@test size(y) == (outdim, batchsize) +@testset "mlp" begin + mlp = torch.nn.Sequential(torch.nn.Linear(indim, hiddendim), torch.nn.ReLU(), torch.nn.Linear(hiddendim, outdim)).to(device=device) + mlpwrap = TorchModuleWrapper(mlp) + if CUDA.functional() + mlpwrap = fmap(CUDA.cu, mlpwrap) + end + x = randn(Float32, indim, batchsize) + if CUDA.functional() + x = cu(x) + end + y = mlpwrap(x) + @test size(y) == (outdim, batchsize) + compare_grad_wrt_params(mlpwrap, x) + compare_grad_wrt_inputs(mlpwrap, x) +end # CRTU check -x = randn(Float32, indim, batchsize) -test_rrule(linwrap, x; check_inferred=false, check_thunked_output_tangent=false, atol=1e-4, rtol=1e-4) +# x = randn(Float32, indim, batchsize) +# test_rrule(linwrap, x; check_inferred=false, check_thunked_output_tangent=false, atol=1e-4, rtol=1e-4) # const CRTU = ChainRulesTestUtils # primals_and_tangents = CRTU.auto_primal_and_tangent((linwrap, x)) # CRTU.primal(primals_and_tangents) @@ -66,30 +127,64 @@ test_rrule(linwrap, x; check_inferred=false, check_thunked_output_tangent=false, # Zygote check -grad, = Zygote.gradient(m->sum(m(x)), linwrap) -@test length(grad.params) == length(linwrap.params) -@test grad.params[1] !== nothing -@test grad.params[2] !== nothing -@test size(grad.params[1]) == size(linwrap.params[1]) -@test size(grad.params[2]) == size(linwrap.params[2]) -grad, = Zygote.gradient(z->sum(linwrap(z)), x) -@test size(grad) == size(x) +# params = map(x -> torch.as_tensor(copy(x)).to(device = linwrap.device, dtype = linwrap.dtype).requires_grad_(true), linwrap.params) +# torch_out = linwrap.torch_stateless_module(params, linwrap.buffers, map(z->torch.as_tensor(PyReverseDims(z)).to(dtype=linwrap.dtype), [x])...).sum() +# torchgrad = map(x-> copy(x.numpy()), torch.autograd.grad(torch_out, params)) +# grad, = Zygote.gradient(m->sum(m(x)), linwrap) +# @test length(torchgrad) == length(grad.params) +# for i in 1:length(grad.params) +# @test isapprox(torchgrad[i], grad.params[i]) +# end +# @test length(grad.params) == length(linwrap.params) +# @test grad.params[1] !== nothing +# @test grad.params[2] !== nothing +# @test size(grad.params[1]) == size(linwrap.params[1]) +# @test size(grad.params[2]) == size(linwrap.params[2]) + +# params = map(x -> torch.as_tensor(copy(x)).to(device = linwrap.device, dtype = linwrap.dtype).requires_grad_(true), linwrap.params) +# xtorch = torch.as_tensor(PyReverseDims(copy(x))).to(dtype=linwrap.dtype).requires_grad_(true) +# torch_out = linwrap.torch_stateless_module(params, linwrap.buffers, xtorch).sum() +# torchgrad = map(x-> ReverseDimsArray(x.numpy()), torch.autograd.grad(torch_out, xtorch))[1] +# grad, = Zygote.gradient(z->sum(linwrap(z)), copy(x)) +# @test size(grad) == size(x) +# @test length(torchgrad) == length(grad) +# @test isapprox(torchgrad, grad) # Flux check -nn = Chain(Dense(4, 3), linwrap) -x2 = randn(Float32, 4, batchsize) -grad, = Zygote.gradient(m->sum(m(x2)), nn) - - -model = torch.nn.Sequential( - torch.nn.Conv2d(1,2,5), - torch.nn.ReLU(), - torch.nn.Conv2d(2,6,5), - torch.nn.ReLU() - ) -modelwrap = TorchModuleWrapper(model) - -input = randn(Float32, 12, 12, 1, batchsize) -output = modelwrap(input) -test_rrule(modelwrap, input; check_inferred=false, check_thunked_output_tangent=false, atol=1e-2, rtol=1e-2) \ No newline at end of file +@testset "flux" begin + lin = torch.nn.Linear(indim, outdim).to(device=device) + linwrap = TorchModuleWrapper(lin) + nn = Chain(Dense(4, 3), linwrap) + if CUDA.functional() + nn = Flux.gpu(nn) + end + x2 = randn(Float32, 4, batchsize) + if CUDA.functional() + x2 = cu(x2) + end + grad, = Zygote.gradient(m->sum(m(x2)), nn) + @test grad !== nothing +end + + +@testset "conv" begin + model = torch.nn.Sequential( + torch.nn.Conv2d(1,2,5), + torch.nn.ReLU(), + torch.nn.Conv2d(2,6,5), + torch.nn.ReLU() + ).to(device=device) + modelwrap = TorchModuleWrapper(model) + if CUDA.functional() + modelwrap = fmap(CUDA.cu, modelwrap) + end + input = randn(Float32, 12, 12, 1, batchsize) + if CUDA.functional() + input = cu(input) + end + output = modelwrap(input) + + compare_grad_wrt_params(modelwrap, input) + compare_grad_wrt_inputs(modelwrap, input) +end \ No newline at end of file diff --git a/test/test_pytorch_hub.jl b/test/test_pytorch_hub.jl index 43303e6..0cd551c 100644 --- a/test/test_pytorch_hub.jl +++ b/test/test_pytorch_hub.jl @@ -1,17 +1,23 @@ -using PyCallChainRules.Torch: TorchModuleWrapper, torch, functorch, ispysetup +using PyCallChainRules.Torch: TorchModuleWrapper, torch, functorch, ispysetup, pyfrom_dlpack using Test -using ChainRulesTestUtils using Zygote using Flux using ChainRulesCore: NoTangent, AbstractZero import Random using PyCall +using Functors: fmap +using DLPack +using CUDA if !ispysetup[] return end - +if CUDA.functional() + device = torch.device("cuda:0") +else + device = torch.device("cpu") +end py""" import torch def bn2group(module): @@ -41,18 +47,23 @@ def bn2group(module): """ model = torch.hub.load("pytorch/vision", "resnet18", pretrained=true) -model_gn = py"bn2group"(model) +model_gn = py"bn2group"(model).to(device=device) modelwrap = TorchModuleWrapper(model_gn) - +if CUDA.functional() + modelwrap = fmap(CUDA.cu, modelwrap) +end x = randn(Float32, reverse((1, 3, 224, 224))) -y = modelwrap(x) +if CUDA.functional() + x = CUDA.cu(x) +end +#y = modelwrap(x) grad, = Zygote.gradient(m->sum(m(x)), modelwrap) @test length(grad.params) == length(modelwrap.params) -params = map(x -> torch.as_tensor(x).to(device = modelwrap.device, dtype = modelwrap.dtype).requires_grad_(true), modelwrap.params) -torch_out = modelwrap.torch_stateless_module(params, modelwrap.buffers, map(z->torch.as_tensor(PyReverseDims(z)).to(dtype=modelwrap.dtype), [x])...).sum() -torchgrad = map(x-> x.numpy(), torch.autograd.grad(torch_out, params)) +params = map(x -> DLPack.share(x, PyObject, pyfrom_dlpack).to(device = device, dtype = modelwrap.dtype).requires_grad_(true), modelwrap.params) +torch_out = modelwrap.torch_stateless_module(params, modelwrap.buffers, map(z-> DLPack.share(z, PyObject, pyfrom_dlpack).to(dtype=modelwrap.dtype, device=device), [x])...).sum() +torchgrad = map(x-> x.cpu().numpy(), torch.autograd.grad(torch_out, params)) @test length(torchgrad) == length(grad.params) for i in 1:length(grad.params) - @test isapprox(torchgrad[i], grad.params[i], atol=1e-3, rtol=1e-3) + @test isapprox(sum(torchgrad[i]), sum(grad.params[i]), atol=1e-3, rtol=1e-3) end