Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/src/utils.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ gs = gs1 .+ gs2
@test gs[w] ≈ gs1[w] + gs2[w]
@test gs[b] ≈ gs1[b] + gs2[b]

# gradients and dictionaries interact nicely
gs .+= Dict(p => randn(size(p)) for p in keys(gs))
# gradients and IdDict interact nicely
gs .+= IdDict(p => randn(size(p)) for p in keys(gs))

# clip gradients
map(x -> clamp.(x, -0.1, 0.1), gs)
Expand Down
3 changes: 3 additions & 0 deletions src/compiler/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ end

# Copy a `Params`: start from a newly constructed `Params()` and merge `ps`
# into it with `union!`, returning whatever that call returns.
function Base.copy(ps::Params)
    fresh = Params()
    return union!(fresh, ps)
end

# Non-mutating union: `union!` is applied to a copy of `ps`, not to `ps` itself,
# with every iterable in `itrs` forwarded to the same call.
function Base.union(ps::Params, itrs...)
    duplicate = copy(ps)
    return union!(duplicate, itrs...)
end
# Set-equality for `Params`. Each method unwraps the `params` field of the
# `Params` argument(s) and delegates to `issetequal` on the unwrapped values.

# `Params` against `Params`: compare the two `params` fields.
function Base.issetequal(ps1::Params, ps2::Params)
    return issetequal(ps1.params, ps2.params)
end

# `Params` on the left, an arbitrary set on the right.
function Base.issetequal(ps1::Params, x::Base.AbstractSet)
    return issetequal(ps1.params, x)
end

# An arbitrary set on the left, `Params` on the right (argument order is kept).
function Base.issetequal(x::Base.AbstractSet, ps1::Params)
    return issetequal(x, ps1.params)
end

function Base.intersect!(ps::Params, itrs...)
for itr in itrs
Expand Down
45 changes: 44 additions & 1 deletion test/cuda.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
using CUDA
using Zygote: Grads
using Random: randn!
CUDA.allowscalar(false)

# Test GPU movement inside the call to `gradient`
@testset "GPU movement" begin
Expand All @@ -21,7 +24,6 @@ end
g_gpu = gradient(x -> w(x), a_gpu)[1]
@test g_gpu isa CuArray
@test g_gpu |> collect ≈ g

end

@testset "jacobian" begin
Expand All @@ -37,3 +39,44 @@ end
@test j2[v1] isa CuArray
@test j2[v1] ≈ cu(res2)
end

@testset "gradient algebra" begin
    # Parameters and inputs, each passed through `cu`.
    w = cu(rand(2))
    b = cu(rand(2))
    x1 = cu(rand(2))
    x2 = cu(rand(2))

    # Gradients taken with respect to `w` only.
    g1 = gradient(Params([w])) do
        sum(w .* x1)
    end
    g2 = gradient(Params([w])) do
        sum(w .* x2)
    end

    # Broadcasted arithmetic on `Grads` is expected to yield `Grads`.
    @test (.-g1) isa Grads
    @test (g1 .- g2) isa Grads
    @test (.+g1) isa Grads
    @test (g1 .+ g2) isa Grads
    @test (2 .* g1) isa Grads
    @test (2 .* g1)[w] ≈ 2 * g1[w]
    @test (g1 .* 2) isa Grads
    @test (g1 ./ 2) isa Grads
    @test (g1 .+ g2)[w] ≈ g1[w] .+ g2[w]

    # In-place accumulation must agree with the out-of-place sum.
    total = g1 .+ g2
    g1 .+= g2
    @test total[w] ≈ g1[w]

    # Gradients over both `w` and `b`. The first loss never uses `b`, so its
    # gradient with respect to `b` is expected to be `nothing`.
    g3 = gradient(Params([w, b])) do
        sum(w .* x1)
    end
    g4 = gradient(Params([w, b])) do
        sum(w .* x2 .+ b)
    end

    @test (.-g3) isa Grads
    @test (g3 .- g4) isa Grads
    @test (.+g3) isa Grads
    @test (g3 .+ g4) isa Grads
    @test (2 .* g3) isa Grads
    @test (g3 .* 2) isa Grads
    @test (g3 ./ 2) isa Grads
    @test (g3 .+ g4)[w] ≈ g3[w] .+ g4[w]
    @test (g3 .+ g4)[b] ≈ g4[b]

    # `Grads` can be combined with an `IdDict` keyed by the same parameters,
    # both out of place and in place.
    @test (g3 .+ IdDict(w => similar(w), b => similar(b))) isa Grads
    g3 .+= IdDict(p => randn!(similar(p)) for p in keys(g3))
    @test g3 isa Grads

    # `g1` was taken over `Params([w])` and `g4` over `Params([w, b])`;
    # combining them must raise an `ArgumentError`.
    @test_throws ArgumentError g1 .+ g4
end
6 changes: 3 additions & 3 deletions test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,11 @@ end
@test (gs3 .+ gs4)[w] ≈ gs3[w] .+ gs4[w]
@test (gs3 .+ gs4)[b] ≈ gs4[b]

@test gs3 .+ Dict(w => similar(w), b => similar(b)) isa Grads
gs3 .+= Dict(p => randn(size(p)) for p in keys(gs3))
@test gs3 .+ IdDict(w => similar(w), b => similar(b)) isa Grads
gs3 .+= IdDict(p => randn(size(p)) for p in keys(gs3))
@test gs3 isa Grads

@test_throws ArgumentError gs1 .+ gs4
@test_throws ArgumentError gs1 .+ gs4
end

@testset "map and broadcast" begin
Expand Down