diff --git a/docs/src/utils.md b/docs/src/utils.md index c46b646ce..84596357f 100644 --- a/docs/src/utils.md +++ b/docs/src/utils.md @@ -42,8 +42,9 @@ gs = gs1 .+ gs2 @test gs[w] ≈ gs1[w] + gs2[w] @test gs[b] ≈ gs1[b] + gs2[b] -# gradients and dictionaries interact nicely -gs .+= Dict(p => randn(size(p)) for p in keys(gs)) +# gradients and IdDict interact nicely +# note that an IdDict must be used for gradient algebra on the GPU +gs .+= IdDict(p => randn(size(p)) for p in keys(gs)) # clip gradients map(x -> clamp.(x, -0.1, 0.1), gs) diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index 7d72f9537..3c62d6586 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -80,6 +80,9 @@ end Base.copy(ps::Params) = union!(Params(), ps) Base.union(ps::Params, itrs...) = union!(copy(ps), itrs...) +Base.issetequal(ps1::Params, ps2::Params) = issetequal(ps1.params, ps2.params) +Base.issetequal(ps1::Params, x::Base.AbstractSet) = issetequal(ps1.params, x) +Base.issetequal(x::Base.AbstractSet, ps1::Params) = issetequal(x, ps1.params) function Base.intersect!(ps::Params, itrs...) for itr in itrs diff --git a/test/cuda.jl b/test/cuda.jl index 0766ff986..a54402999 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -1,4 +1,7 @@ using CUDA +using Zygote: Grads +using Random: randn! +CUDA.allowscalar(false) # Test GPU movement inside the call to `gradient` @testset "GPU movement" begin @@ -21,7 +24,6 @@ end g_gpu = gradient(x -> w(x), a_gpu)[1] @test g_gpu isa CuArray @test g_gpu |> collect ≈ g - end @testset "jacobian" begin @@ -37,3 +39,44 @@ end @test j2[v1] isa CuArray @test j2[v1] ≈ cu(res2) end + +@testset "gradient algebra" begin + w, b = rand(2) |> cu, rand(2) |> cu + x1, x2 = rand(2) |> cu, rand(2) |> cu + + gs1 = gradient(() -> sum(w .* x1), Params([w])) + gs2 = gradient(() -> sum(w .* x2), Params([w])) + + @test .- gs1 isa Grads + @test gs1 .- gs2 isa Grads + @test .+ gs1 isa Grads + @test gs1 .+ gs2 isa Grads + @test 2 .* gs1 isa Grads + @test (2 .* gs1)[w] ≈ 2 * gs1[w] + @test gs1 .* 2 isa Grads + @test gs1 ./ 2 isa Grads + @test (gs1 .+ gs2)[w] ≈ gs1[w] .+ gs2[w] + + gs12 = gs1 .+ gs2 + gs1 .+= gs2 + @test gs12[w] ≈ gs1[w] + + gs3 = gradient(() -> sum(w .* x1), Params([w, b])) # grad nothing with respect to b + gs4 = gradient(() -> sum(w .* x2 .+ b), Params([w, b])) + + @test .- gs3 isa Grads + @test gs3 .- gs4 isa Grads + @test .+ gs3 isa Grads + @test gs3 .+ gs4 isa Grads + @test 2 .* gs3 isa Grads + @test gs3 .* 2 isa Grads + @test gs3 ./ 2 isa Grads + @test (gs3 .+ gs4)[w] ≈ gs3[w] .+ gs4[w] + @test (gs3 .+ gs4)[b] ≈ gs4[b] + + @test gs3 .+ IdDict(w => similar(w), b => similar(b)) isa Grads + gs3 .+= IdDict(p => randn!(similar(p)) for p in keys(gs3)) + @test gs3 isa Grads + + @test_throws ArgumentError gs1 .+ gs4 +end diff --git a/test/interface.jl b/test/interface.jl index 087da74f3..0ffb933f6 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -116,11 +116,11 @@ end @test (gs3 .+ gs4)[w] ≈ gs3[w] .+ gs4[w] @test (gs3 .+ gs4)[b] ≈ gs4[b] - @test gs3 .+ Dict(w => similar(w), b => similar(b)) isa Grads - gs3 .+= Dict(p => randn(size(p)) for p in keys(gs3)) + @test gs3 .+ IdDict(w => similar(w), b => similar(b)) isa Grads + gs3 .+= IdDict(p => randn(size(p)) for p in keys(gs3)) @test gs3 isa Grads - @test_throws ArgumentError gs1 .+ gs4 + @test_throws ArgumentError gs1 .+ gs4 end @testset "map and broadcast" begin